From f04f5d176f8e8374a448df6d7afe263e74e582c5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20De=20Keersmaeker?=
 <francois.dekeersmaeker@uclouvain.be>
Date: Tue, 27 May 2025 21:51:44 +0200
Subject: [PATCH] Support for TLS & MQTT (Mehdi Laurent's master's thesis)

---
 Custom.py |  3 +-
 mqtt.py   | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 tls.py    | 74 ++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 163 insertions(+), 1 deletion(-)
 create mode 100644 mqtt.py
 create mode 100644 tls.py

diff --git a/Custom.py b/Custom.py
index 1430e7f..ee2ab5d 100644
--- a/Custom.py
+++ b/Custom.py
@@ -65,7 +65,8 @@ class Custom(Protocol):
                 rules = Custom.build_nfq_list_match(value, template_rules, is_backward, func, backward_func)
             else:
                 # Value is a single element
-                value = Protocol.convert_value(value)
+                if type(value) != float:
+                    value = Protocol.convert_value(value)
                 if not is_backward:
                     rules = {"template": template_rules["forward"], "match": func(value)}
                 elif is_backward and "backward" in template_rules:
diff --git a/mqtt.py b/mqtt.py
new file mode 100644
index 0000000..d0675ee
--- /dev/null
+++ b/mqtt.py
@@ -0,0 +1,87 @@
+from Custom import Custom
+
+class mqtt(Custom):
+
+    # Class variables
+    layer = 7              # Protocol OSI layer
+    protocol_name = "mqtt" # Protocol name
+
+    supported_keys = [
+        "packet-type",
+        "topic-name",
+        "payload-length"
+    ]
+
+    def parse(self, is_backward: bool = False, initiator: str = "src") -> dict:
+        """
+        Parse the MQTT protocol.
+
+        :param is_backward (optional): Whether the protocol must be parsed for a backward rule.
+                                       Optional, default is `False`.
+        :param initiator (optional): Connection initiator (src or dst).
+                                     Optional, default is "src".
+        :return: Dictionary containing the (forward and backward) nftables and nfqueue rules for this policy.
+        """
+        # Handle MQTT packet type
+        packet_type = self.protocol_data.get("packet-type", None)
+        if packet_type is not None:
+            rule = {"forward": "mqtt_message.packet_type == {}"}
+            self.add_field("packet-type", rule, is_backward)
+
+        # Handle MQTT client ID
+        client_id = self.protocol_data.get("client-id", None)
+        if client_id is not None:
+            rule = {"forward": 'strcmp(mqtt_message.client_id, "{}") == 0'}
+            self.add_field("client-id", rule, is_backward)
+
+        # Handle MQTT client ID length
+        client_id_length = self.protocol_data.get("client-id-length", None)
+        if client_id_length is not None:
+            rule = {"forward": "mqtt_message.client_id_length == {}"}
+            self.add_field("client-id-length", rule, is_backward)
+
+        # Handle MQTT clean session flag
+        clean_session = self.protocol_data.get("clean-session", None)
+        if clean_session is not None:
+            rule = {"forward": "mqtt_message.connect_flags.clean_session == {}"}
+            self.add_field("clean-session", rule, is_backward)
+        
+        # Handle MQTT Keep Alive
+        keep_alive = self.protocol_data.get("keep-alive", None)
+        if keep_alive is not None:
+            rule = {"forward": "mqtt_message.keep_alive == {}"}
+            self.add_field("keep-alive", rule, is_backward)
+
+        # Handle MQTT topic name
+        topic_name = self.protocol_data.get("topic-name", None)
+        if topic_name is not None:
+            string = 'strcmp(mqtt_message.topic_name, "{}") == 0'
+
+            if topic_name == "temperature":
+                string += "\n \t \t&& \n \t \tcheck_payload_regex(mqtt_message.payload, strlen((char *)mqtt_message.payload),\
+\"-?[0-9]?[0-9]\\\\.[0-9]°[CF]\") == 1" # floating point number with 1 decimal place and °C
+            elif topic_name == "humidity":
+                string += "\n \t \t&& \n \t \tcheck_payload_regex(mqtt_message.payload, strlen((char *)mqtt_message.payload),\
+\"[0-9]?[0-9]\\\\.[0-9]%\") == 1" # positive floating point number with 1 decimal place and %
+
+            rule = {"forward": string}
+            self.add_field("topic-name", rule, is_backward)
+        
+        # Handle MQTT payload regex
+        payload_regex = self.protocol_data.get("payload-regex", None)
+        if payload_regex is not None:
+            rule = {"forward": 'check_payload_regex(mqtt_message.payload, strlen((char *)mqtt_message.payload), "{}") == 1'}
+            self.add_field("payload-regex", rule, is_backward)
+
+        # Handle MQTT payload length
+        payload_length = self.protocol_data.get("payload-length", None)
+        if payload_length is not None:
+            payload_length = str(payload_length)
+            if '-' in payload_length:
+                min_length, max_length = payload_length.split('-')
+                rule = {"forward": "mqtt_message.payload_length >= {} && mqtt_message.payload_length <= {}".format(min_length, max_length)}
+            else:
+                rule = {"forward": "mqtt_message.payload_length == {}"}
+            self.add_field("payload-length", rule, is_backward)
+
+        return self.rules
diff --git a/tls.py b/tls.py
new file mode 100644
index 0000000..35e5961
--- /dev/null
+++ b/tls.py
@@ -0,0 +1,74 @@
+from Custom import Custom
+
+class tls(Custom):
+
+    # Class variables
+    layer = 5               # Protocol OSI layer (arbitrary)
+    protocol_name = "tls"   # Protocol name
+
+    supported_keys = [
+        "content-type",
+        "handshake-type",
+        "tls-version",
+        "session-id"
+    ]
+
+    def parse(self, is_backward: bool = False, initiator: str = "src") -> dict:
+        """
+        Parse the TLS protocol.
+
+        :param is_backward (optional): Whether the protocol must be parsed for a backward rule.
+                                       Optional, default is `False`.
+        :param initiator (optional): Connection initiator (src or dst).
+                                     Optional, default is "src".
+        :return: Dictionary containing the (forward and backward) nftables and nfqueue rules for this policy.
+        """
+        # Handle TLS content type
+        content_type = self.protocol_data.get("content-type", None)
+        if content_type is not None:
+            rule = {"forward": "tls_packet->messages->message.content_type == {}"}
+            self.add_field("content-type", rule, is_backward)
+        
+        # Handle TLS handshake type
+        handshake_type = self.protocol_data.get("handshake-type", None)
+        if handshake_type is not None:
+            if not ',' in str(handshake_type):
+                rule = {"forward": "tls_packet != NULL && tls_packet->messages != NULL && tls_packet->messages->message.handshake_type == {}"}
+            else:
+                if '[' in str(handshake_type): # in case format in profile is '[ {n°}, {n°}, ... ]'
+                    handshake_type = handshake_type.replace('[', '').replace(']', '')
+                lst = [int(x.strip(), 16 if x.strip().startswith('0x') else 10) for x in handshake_type.split(',')]
+                conditions = ["tls_packet != NULL"]
+                current_chain = "tls_packet->messages"
+                conditions.append(f"{current_chain} != NULL")
+                
+                for i, value in enumerate(lst):
+                    if i > 0:
+                        current_chain = f"{current_chain}->next"
+                        conditions.append(f"{current_chain} != NULL")
+                    
+                    if value == 20:  # change cipher spec message or much less likely finished handshake encryted message
+                        conditions.append(f"( {current_chain}->message.content_type == {value} || {current_chain}->message.handshake_type == {value} )")
+                    else:
+                        conditions.append(f"{current_chain}->message.handshake_type == {value}")
+                        
+                string = " && ".join(conditions)
+                rule = {"forward": string}
+            self.add_field("handshake-type", rule, is_backward)
+        
+        # Handle TLS version
+        tls_version = self.protocol_data.get("tls-version", None)
+        if tls_version is not None:
+            if '"' in str(tls_version): # in case written in profile as string, is float otherwise
+                tls_version = tls_version.replace('"', '')
+            rule = {"forward": "tls_packet->messages->message.tls_version == {}"}
+            func = lambda tls_version: str(769 + int(str(tls_version)[-1])) # 769 = 0x0301
+            self.add_field("tls-version", rule, is_backward, func)
+        
+        # Handle TLS session ID
+        session_id = self.protocol_data.get("session-id", None)
+        if session_id is not None:
+            rule = {"forward": "tls_packet->messages->message.session_id_present == {}"}
+            self.add_field("session-id", rule, is_backward)
+        
+        return self.rules
-- 
GitLab