diff --git a/dns_unbound_cache_reader/__init__.py b/dns_unbound_cache_reader/__init__.py index 7e2a5e27d8a346195d4befe8a5bedf2218ba1d01..2281e6ca5a5c3f5803bd080c8b6e626d75585ad7 100644 --- a/dns_unbound_cache_reader/__init__.py +++ b/dns_unbound_cache_reader/__init__.py @@ -1 +1 @@ -from .dns_unbound_cache_reader import read_unbound_cache +from .dns_unbound_cache_reader import read_dns_cache diff --git a/dns_unbound_cache_reader/dns_unbound_cache_reader.py b/dns_unbound_cache_reader/dns_unbound_cache_reader.py index 827e5450c4e8fad1f635d79685976dd816130cd2..aef3c929dfa90ff89a8573b82a27e3b5b32d65a8 100644 --- a/dns_unbound_cache_reader/dns_unbound_cache_reader.py +++ b/dns_unbound_cache_reader/dns_unbound_cache_reader.py @@ -55,7 +55,10 @@ class DnsTableKeys(Enum): SERVICE = "service" -def read_unbound_cache(host: str = "127.0.0.1"): +def read_dns_cache( + host: str = "127.0.0.1", + file: str = None + ) -> dict: """ Read the Unbound DNS cache and return it as a dictionary, in the format: @@ -72,21 +75,35 @@ def read_unbound_cache(host: str = "127.0.0.1"): Args: host (str): IP address of the Unbound DNS server. Default is localhost. + file (str): Path to a file containing the Unbound DNS cache. Default is None. + If specified, the function reads the cache from the file instead of the server. """ - if host in localhost: - ## Unbound runs on localhost - proc = subprocess.run(cmd.split(), capture_output=True) - dns_cache = proc.stdout.decode().strip().split("\n") + + ### Get DNS cache ### + dns_cache = None + + if file is None: + # Get DNS cache from Unbound + + if host in localhost: + ## Unbound runs on localhost + proc = subprocess.run(cmd.split(), capture_output=True) + dns_cache = proc.stdout.decode().strip().split("\n") + + else: + ## Unbound runs on a remote host + # SSH connection with remote host + ssh_config = Config(overrides={"run": {"hide": True}}) + remote = Connection(host, config=ssh_config) + # Get the DNS cache + result = remote.run(cmd) + dns_cache = result.stdout.strip().split("\n") else: - ## Unbound runs on a remote host - # SSH connection with remote host - ssh_config = Config(overrides={"run": {"hide": True}}) - remote = Connection(host, config=ssh_config) - # Get the DNS cache - result = remote.run(cmd) - dns_cache = result.stdout.strip().split("\n") - + # Read DNS cache from file + with open(file, "r") as f: + dns_cache = f.read().strip().split("\n") + ### Parse DNS cache ### @@ -131,10 +148,10 @@ def read_unbound_cache(host: str = "127.0.0.1"): # A (IPv4) and AAAA (IPv6) records if rtype == DnsRtype.A.name or rtype == DnsRtype.AAAA.name: ip = rdata - if DnsTableKeys.IP in dns_table: - dns_table[DnsTableKeys.IP][ip] = name + if DnsTableKeys.IP.name in dns_table: + dns_table[DnsTableKeys.IP.name][ip] = name else: - dns_table[DnsTableKeys.IP] = {ip: name} + dns_table[DnsTableKeys.IP.name] = {ip: name} # PTR records if rtype == DnsRtype.PTR.name: @@ -142,17 +159,17 @@ def read_unbound_cache(host: str = "127.0.0.1"): if match_ptr: # PTR record is a reverse DNS lookup ip = ".".join(reversed(match_ptr.groups())) - if ip not in dns_table.get(DnsTableKeys.IP, {}): - if DnsTableKeys.IP in dns_table: - dns_table[DnsTableKeys.IP][ip] = rdata + if ip not in dns_table.get(DnsTableKeys.IP.name, {}): + if DnsTableKeys.IP.name in dns_table: + dns_table[DnsTableKeys.IP.name][ip] = rdata else: - dns_table[DnsTableKeys.IP] = {ip: rdata} + dns_table[DnsTableKeys.IP.name] = {ip: rdata} else: # PTR record contains generic RDATA - if DnsTableKeys.SERVICE in dns_table: - dns_table[DnsTableKeys.SERVICE][name] = rdata + if DnsTableKeys.SERVICE.name in dns_table: + dns_table[DnsTableKeys.SERVICE.name][name] = rdata else: - dns_table[DnsTableKeys.SERVICE] = {name: rdata} + dns_table[DnsTableKeys.SERVICE.name] = {name: rdata} # SRV records if rtype == DnsRtype.SRV.name: @@ -163,10 +180,10 @@ def read_unbound_cache(host: str = "127.0.0.1"): service = match_srv.group(4) if service.endswith("."): service = service[:-1] - if DnsTableKeys.SERVICE in dns_table: - dns_table[DnsTableKeys.SERVICE][service] = name + if DnsTableKeys.SERVICE.name in dns_table: + dns_table[DnsTableKeys.SERVICE.name][service] = name else: - dns_table[DnsTableKeys.SERVICE] = {service: name} + dns_table[DnsTableKeys.SERVICE.name] = {service: name} - return dns_cache + return dns_table