From a81b754a67b3f4e70288e4479c918945bc3e8966 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20De=20Keersmaeker?=
 <francois.dekeersmaeker@uclouvain.be>
Date: Thu, 14 Nov 2024 15:02:02 +0000
Subject: [PATCH] Added support for CNAME records + Renamed service => alias

---
 .../dns_unbound_cache_reader.py               | 80 ++++++++++++-------
 test/sample_dns_cache.txt                     |  9 ++-
 test/test_sample_file.py                      | 14 ++--
 3 files changed, 64 insertions(+), 39 deletions(-)

diff --git a/dns_unbound_cache_reader/dns_unbound_cache_reader.py b/dns_unbound_cache_reader/dns_unbound_cache_reader.py
index 2fe8119..7228e1a 100644
--- a/dns_unbound_cache_reader/dns_unbound_cache_reader.py
+++ b/dns_unbound_cache_reader/dns_unbound_cache_reader.py
@@ -22,7 +22,7 @@ to_skip = (
     "EOF"
 )
 # Regex patterns
-pattern_line      = r"^([a-zA-Z0-9._-]+)\s+(\d+)\s+IN\s+([A-Z]+)\s+(.+)$"                    # Generic DNS cache line
+pattern_line      = r"^([a-zA-Z0-9._-]+)\s+(\d+)\s+IN\s+([A-Z]+)\s+(.+)$"                   # Generic DNS cache line
 pattern_ipv4_byte = r"(25[0-5]|2[0-4][0-9]|1[0-9]{2}|[1-9]?[0-9])"                          # Single byte from an IPv4 address
 pattern_ptr       = (pattern_ipv4_byte + r"\.") * 3 + pattern_ipv4_byte + r".in-addr.arpa"  # Reverse DNS lookup qname
 pattern_srv       = r"^(\d+)\s+(\d+)\s+(\d+)\s+([a-zA-Z0-9.-]+)$"                           # SRV record target
@@ -40,18 +40,19 @@ class DnsRtype(IntEnum):
     """
     Enum class for the DNS resource record types.
     """
-    A    = 1   # IPv4 address
-    PTR  = 12  # Domain name pointer
-    AAAA = 28  # IPv6 address
-    SRV  = 33  # Service locator
+    A     = 1   # IPv4 address
+    CNAME = 5   # Canonical name
+    PTR   = 12  # Domain name pointer
+    AAAA  = 28  # IPv6 address
+    SRV   = 33  # Service locator
 
 
 class DnsTableKeys(Enum):
     """
     Enum class for the allowed dictionary keys.
     """
-    IP      = "ip"
-    SERVICE = "service"
+    IP    = "ip"
+    ALIAS = "alias"
 
 
 def read_dns_cache(
@@ -66,8 +67,8 @@ def read_dns_cache(
                 ip_address: domain_name,
                 ...
             },
-            DnsTableKeys.SERVICE: {
-                service_name: actual_name,
+            DnsTableKeys.ALIAS: {
+                canonical_name: alias,
                 ...
             }
         }
@@ -129,9 +130,9 @@ def read_dns_cache(
         if not match:
             continue
 
-        name  = match.group(1)
-        if name.endswith("."):
-            name = name[:-1]
+        qname  = match.group(1)
+        if qname.endswith("."):
+            qname = qname[:-1]
         rtype = match.group(3)
         rdata = match.group(4)
         if rdata.endswith("."):
@@ -148,13 +149,35 @@ def read_dns_cache(
         if rtype == DnsRtype.A.name or rtype == DnsRtype.AAAA.name:
             ip = rdata
             if DnsTableKeys.IP.name in dns_table:
-                dns_table[DnsTableKeys.IP.name][ip] = name
+                dns_table[DnsTableKeys.IP.name][ip] = qname
             else:
-                dns_table[DnsTableKeys.IP.name] = {ip: name}
+                dns_table[DnsTableKeys.IP.name] = {ip: qname}
+
+        # CNAME records
+        if rtype == DnsRtype.CNAME.name:
+            cname = rdata
+            if DnsTableKeys.ALIAS.name in dns_table:
+                dns_table[DnsTableKeys.ALIAS.name][cname] = qname
+            else:
+                dns_table[DnsTableKeys.ALIAS.name] = {cname: qname}
+
+        # SRV records
+        if rtype == DnsRtype.SRV.name:
+            # Parse target service
+            match_srv = re.match(pattern_srv, rdata)
+            if not match_srv:
+                continue
+            service = match_srv.group(4)
+            if service.endswith("."):
+                service = service[:-1]
+            if DnsTableKeys.ALIAS.name in dns_table:
+                dns_table[DnsTableKeys.ALIAS.name][service] = qname
+            else:
+                dns_table[DnsTableKeys.ALIAS.name] = {service: qname}
 
         # PTR records
         if rtype == DnsRtype.PTR.name:
-            match_ptr = re.match(pattern_ptr, name)
+            match_ptr = re.match(pattern_ptr, qname)
             if match_ptr:
                 # PTR record is a reverse DNS lookup
                 ip = ".".join(reversed(match_ptr.groups()))
@@ -165,24 +188,19 @@ def read_dns_cache(
                         dns_table[DnsTableKeys.IP.name] = {ip: rdata}
             else:
                 # PTR record contains generic RDATA
-                if DnsTableKeys.SERVICE.name in dns_table:
-                    dns_table[DnsTableKeys.SERVICE.name][name] = rdata
+                ptr = rdata
+                if DnsTableKeys.ALIAS.name in dns_table:
+                    dns_table[DnsTableKeys.ALIAS.name][qname] = ptr
                 else:
-                    dns_table[DnsTableKeys.SERVICE.name] = {name: rdata}
+                    dns_table[DnsTableKeys.ALIAS.name] = {qname: ptr}
 
-        # SRV records
-        if rtype == DnsRtype.SRV.name:
-            # Parse target service
-            match_srv = re.match(pattern_srv, rdata)
-            if not match_srv:
-                continue
-            service = match_srv.group(4)
-            if service.endswith("."):
-                service = service[:-1]
-            if DnsTableKeys.SERVICE.name in dns_table:
-                dns_table[DnsTableKeys.SERVICE.name][service] = name
-            else:
-                dns_table[DnsTableKeys.SERVICE.name] = {service: name}
+
+    ## Post-processing
+    # Replace all cnames with aliases
+    if DnsTableKeys.IP.name in dns_table and DnsTableKeys.ALIAS.name in dns_table:
+        for ip, cname in dns_table[DnsTableKeys.IP.name].items():
+            if cname in dns_table[DnsTableKeys.ALIAS.name]:
+                dns_table[DnsTableKeys.IP.name][ip] = dns_table[DnsTableKeys.ALIAS.name][cname]
 
 
     return dns_table
diff --git a/test/sample_dns_cache.txt b/test/sample_dns_cache.txt
index a78f062..69e0580 100644
--- a/test/sample_dns_cache.txt
+++ b/test/sample_dns_cache.txt
@@ -7,9 +7,14 @@ example.org.   600    IN    NS   ns1.example.org.
 example.org.   600    IN    NS   ns2.example.org.
 ns1.example.org.  3600  IN    A    192.0.2.1
 ns2.example.org.  3600  IN    A    198.51.100.1
-1.2.0.192.in-addr.arpa. 3600    IN  PTR example.com
-2.2.0.192.in-addr.arpa. 3600    IN  PTR example.com
+; CNAME records
+example1.local. 300 IN   A   192.168.1.100
+example.local.  300 IN   CNAME  example1.local.
+; SRV records
 _tcp_.matter.example.com.   3600    IN  SRV 10 60 5000 server1.example.com
 _tcp_.matter.example.com.   3600    IN  SRV 20 60 5000 server2.example.com
+; PTR records
+1.2.0.192.in-addr.arpa. 3600    IN  PTR example.com
+2.2.0.192.in-addr.arpa. 3600    IN  PTR example.com
 END_RRSET_CACHE
 EOF
diff --git a/test/test_sample_file.py b/test/test_sample_file.py
index b89fde7..5b3d727 100644
--- a/test/test_sample_file.py
+++ b/test/test_sample_file.py
@@ -19,17 +19,19 @@ def test_read_sample_cache_file() -> None:
     """
     dns_table = dns_reader.read_dns_cache(file=sample_cache_file)
     assert DnsTableKeys.IP.name in dns_table
-    assert DnsTableKeys.SERVICE.name in dns_table
+    assert DnsTableKeys.ALIAS.name in dns_table
 
     dns_table_ip = dns_table[DnsTableKeys.IP.name]
-    assert len(dns_table_ip) == 5
+    assert len(dns_table_ip) == 6
     assert dns_table_ip["93.184.216.34"] == "example.com"
     assert dns_table_ip["2606:2800:220:1:248:1893:25c8:1946"] == "example.com"
     assert dns_table_ip["192.0.2.1"] == "ns1.example.org"
     assert dns_table_ip["198.51.100.1"] == "ns2.example.org"
     assert dns_table_ip["192.0.2.2"] == "example.com"
+    assert dns_table_ip["192.168.1.100"] == "example.local"
     
-    dns_table_service = dns_table[DnsTableKeys.SERVICE.name]
-    assert len(dns_table_service) == 2
-    assert dns_table_service["server1.example.com"] == "_tcp_.matter.example.com"
-    assert dns_table_service["server2.example.com"] == "_tcp_.matter.example.com"
+    dns_table_alias = dns_table[DnsTableKeys.ALIAS.name]
+    assert len(dns_table_alias) == 3
+    assert dns_table_alias["example1.local"] == "example.local"
+    assert dns_table_alias["server1.example.com"] == "_tcp_.matter.example.com"
+    assert dns_table_alias["server2.example.com"] == "_tcp_.matter.example.com"
-- 
GitLab