Skip to content
Extraits de code Groupes Projets
Valider 34999d96 rédigé par François De Keersmaeker's avatar François De Keersmaeker
Parcourir les fichiers

Tiny refactor

parent 70ba5997
Aucune branche associée trouvée
Aucune étiquette associée trouvée
Aucune requête de fusion associée trouvée
from .dns_unbound_cache_reader import read_unbound_cache
from .dns_unbound_cache_reader import read_dns_cache
......@@ -55,7 +55,10 @@ class DnsTableKeys(Enum):
SERVICE = "service"
def read_unbound_cache(host: str = "127.0.0.1"):
def read_dns_cache(
host: str = "127.0.0.1",
file: str = None
) -> dict:
"""
Read the Unbound DNS cache and return it as a dictionary,
in the format:
......@@ -72,21 +75,35 @@ def read_unbound_cache(host: str = "127.0.0.1"):
Args:
host (str): IP address of the Unbound DNS server. Default is localhost.
file (str): Path to a file containing the Unbound DNS cache. Default is None.
If specified, the function reads the cache from the file instead of the server.
"""
if host in localhost:
## Unbound runs on localhost
proc = subprocess.run(cmd.split(), capture_output=True)
dns_cache = proc.stdout.decode().strip().split("\n")
### Get DNS cache ###
dns_cache = None
if file is None:
# Get DNS cache from Unbound
if host in localhost:
## Unbound runs on localhost
proc = subprocess.run(cmd.split(), capture_output=True)
dns_cache = proc.stdout.decode().strip().split("\n")
else:
## Unbound runs on a remote host
# SSH connection with remote host
ssh_config = Config(overrides={"run": {"hide": True}})
remote = Connection(host, config=ssh_config)
# Get the DNS cache
result = remote.run(cmd)
dns_cache = result.stdout.strip().split("\n")
else:
## Unbound runs on a remote host
# SSH connection with remote host
ssh_config = Config(overrides={"run": {"hide": True}})
remote = Connection(host, config=ssh_config)
# Get the DNS cache
result = remote.run(cmd)
dns_cache = result.stdout.strip().split("\n")
# Read DNS cache from file
with open(file, "r") as f:
dns_cache = f.read().strip().split("\n")
### Parse DNS cache ###
......@@ -131,10 +148,10 @@ def read_unbound_cache(host: str = "127.0.0.1"):
# A (IPv4) and AAAA (IPv6) records
if rtype == DnsRtype.A.name or rtype == DnsRtype.AAAA.name:
ip = rdata
if DnsTableKeys.IP in dns_table:
dns_table[DnsTableKeys.IP][ip] = name
if DnsTableKeys.IP.name in dns_table:
dns_table[DnsTableKeys.IP.name][ip] = name
else:
dns_table[DnsTableKeys.IP] = {ip: name}
dns_table[DnsTableKeys.IP.name] = {ip: name}
# PTR records
if rtype == DnsRtype.PTR.name:
......@@ -142,17 +159,17 @@ def read_unbound_cache(host: str = "127.0.0.1"):
if match_ptr:
# PTR record is a reverse DNS lookup
ip = ".".join(reversed(match_ptr.groups()))
if ip not in dns_table.get(DnsTableKeys.IP, {}):
if DnsTableKeys.IP in dns_table:
dns_table[DnsTableKeys.IP][ip] = rdata
if ip not in dns_table.get(DnsTableKeys.IP.name, {}):
if DnsTableKeys.IP.name in dns_table:
dns_table[DnsTableKeys.IP.name][ip] = rdata
else:
dns_table[DnsTableKeys.IP] = {ip: rdata}
dns_table[DnsTableKeys.IP.name] = {ip: rdata}
else:
# PTR record contains generic RDATA
if DnsTableKeys.SERVICE in dns_table:
dns_table[DnsTableKeys.SERVICE][name] = rdata
if DnsTableKeys.SERVICE.name in dns_table:
dns_table[DnsTableKeys.SERVICE.name][name] = rdata
else:
dns_table[DnsTableKeys.SERVICE] = {name: rdata}
dns_table[DnsTableKeys.SERVICE.name] = {name: rdata}
# SRV records
if rtype == DnsRtype.SRV.name:
......@@ -163,10 +180,10 @@ def read_unbound_cache(host: str = "127.0.0.1"):
service = match_srv.group(4)
if service.endswith("."):
service = service[:-1]
if DnsTableKeys.SERVICE in dns_table:
dns_table[DnsTableKeys.SERVICE][service] = name
if DnsTableKeys.SERVICE.name in dns_table:
dns_table[DnsTableKeys.SERVICE.name][service] = name
else:
dns_table[DnsTableKeys.SERVICE] = {service: name}
dns_table[DnsTableKeys.SERVICE.name] = {service: name}
return dns_cache
return dns_table
0% Chargement en cours ou .
You are about to add 0 people to the discussion. Proceed with caution.
Terminez d'abord l'édition de ce message.
Veuillez vous inscrire ou vous pour commenter