Skip to content
Extraits de code Groupes Projets
Valider 34999d96 rédigé par François De Keersmaeker's avatar François De Keersmaeker
Parcourir les fichiers

Tiny refactor

parent 70ba5997
Aucune branche associée trouvée
Aucune étiquette associée trouvée
Aucune requête de fusion associée trouvée
from .dns_unbound_cache_reader import read_unbound_cache from .dns_unbound_cache_reader import read_dns_cache
...@@ -55,7 +55,10 @@ class DnsTableKeys(Enum): ...@@ -55,7 +55,10 @@ class DnsTableKeys(Enum):
SERVICE = "service" SERVICE = "service"
def read_unbound_cache(host: str = "127.0.0.1"): def read_dns_cache(
host: str = "127.0.0.1",
file: str = None
) -> dict:
""" """
Read the Unbound DNS cache and return it as a dictionary, Read the Unbound DNS cache and return it as a dictionary,
in the format: in the format:
...@@ -72,21 +75,35 @@ def read_unbound_cache(host: str = "127.0.0.1"): ...@@ -72,21 +75,35 @@ def read_unbound_cache(host: str = "127.0.0.1"):
Args: Args:
host (str): IP address of the Unbound DNS server. Default is localhost. host (str): IP address of the Unbound DNS server. Default is localhost.
file (str): Path to a file containing the Unbound DNS cache. Default is None.
If specified, the function reads the cache from the file instead of the server.
""" """
if host in localhost:
## Unbound runs on localhost ### Get DNS cache ###
proc = subprocess.run(cmd.split(), capture_output=True) dns_cache = None
dns_cache = proc.stdout.decode().strip().split("\n")
if file is None:
# Get DNS cache from Unbound
if host in localhost:
## Unbound runs on localhost
proc = subprocess.run(cmd.split(), capture_output=True)
dns_cache = proc.stdout.decode().strip().split("\n")
else:
## Unbound runs on a remote host
# SSH connection with remote host
ssh_config = Config(overrides={"run": {"hide": True}})
remote = Connection(host, config=ssh_config)
# Get the DNS cache
result = remote.run(cmd)
dns_cache = result.stdout.strip().split("\n")
else: else:
## Unbound runs on a remote host # Read DNS cache from file
# SSH connection with remote host with open(file, "r") as f:
ssh_config = Config(overrides={"run": {"hide": True}}) dns_cache = f.read().strip().split("\n")
remote = Connection(host, config=ssh_config)
# Get the DNS cache
result = remote.run(cmd)
dns_cache = result.stdout.strip().split("\n")
### Parse DNS cache ### ### Parse DNS cache ###
...@@ -131,10 +148,10 @@ def read_unbound_cache(host: str = "127.0.0.1"): ...@@ -131,10 +148,10 @@ def read_unbound_cache(host: str = "127.0.0.1"):
# A (IPv4) and AAAA (IPv6) records # A (IPv4) and AAAA (IPv6) records
if rtype == DnsRtype.A.name or rtype == DnsRtype.AAAA.name: if rtype == DnsRtype.A.name or rtype == DnsRtype.AAAA.name:
ip = rdata ip = rdata
if DnsTableKeys.IP in dns_table: if DnsTableKeys.IP.name in dns_table:
dns_table[DnsTableKeys.IP][ip] = name dns_table[DnsTableKeys.IP.name][ip] = name
else: else:
dns_table[DnsTableKeys.IP] = {ip: name} dns_table[DnsTableKeys.IP.name] = {ip: name}
# PTR records # PTR records
if rtype == DnsRtype.PTR.name: if rtype == DnsRtype.PTR.name:
...@@ -142,17 +159,17 @@ def read_unbound_cache(host: str = "127.0.0.1"): ...@@ -142,17 +159,17 @@ def read_unbound_cache(host: str = "127.0.0.1"):
if match_ptr: if match_ptr:
# PTR record is a reverse DNS lookup # PTR record is a reverse DNS lookup
ip = ".".join(reversed(match_ptr.groups())) ip = ".".join(reversed(match_ptr.groups()))
if ip not in dns_table.get(DnsTableKeys.IP, {}): if ip not in dns_table.get(DnsTableKeys.IP.name, {}):
if DnsTableKeys.IP in dns_table: if DnsTableKeys.IP.name in dns_table:
dns_table[DnsTableKeys.IP][ip] = rdata dns_table[DnsTableKeys.IP.name][ip] = rdata
else: else:
dns_table[DnsTableKeys.IP] = {ip: rdata} dns_table[DnsTableKeys.IP.name] = {ip: rdata}
else: else:
# PTR record contains generic RDATA # PTR record contains generic RDATA
if DnsTableKeys.SERVICE in dns_table: if DnsTableKeys.SERVICE.name in dns_table:
dns_table[DnsTableKeys.SERVICE][name] = rdata dns_table[DnsTableKeys.SERVICE.name][name] = rdata
else: else:
dns_table[DnsTableKeys.SERVICE] = {name: rdata} dns_table[DnsTableKeys.SERVICE.name] = {name: rdata}
# SRV records # SRV records
if rtype == DnsRtype.SRV.name: if rtype == DnsRtype.SRV.name:
...@@ -163,10 +180,10 @@ def read_unbound_cache(host: str = "127.0.0.1"): ...@@ -163,10 +180,10 @@ def read_unbound_cache(host: str = "127.0.0.1"):
service = match_srv.group(4) service = match_srv.group(4)
if service.endswith("."): if service.endswith("."):
service = service[:-1] service = service[:-1]
if DnsTableKeys.SERVICE in dns_table: if DnsTableKeys.SERVICE.name in dns_table:
dns_table[DnsTableKeys.SERVICE][service] = name dns_table[DnsTableKeys.SERVICE.name][service] = name
else: else:
dns_table[DnsTableKeys.SERVICE] = {service: name} dns_table[DnsTableKeys.SERVICE.name] = {service: name}
return dns_cache return dns_table
0% Chargement en cours ou .
You are about to add 0 people to the discussion. Proceed with caution.
Terminez d'abord l'édition de ce message.
Veuillez vous inscrire ou vous pour commenter