From f04f5d176f8e8374a448df6d7afe263e74e582c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20De=20Keersmaeker?= <francois.dekeersmaeker@uclouvain.be> Date: Tue, 27 May 2025 21:51:44 +0200 Subject: [PATCH] Support for TLS & MQTT (Mehdi Laurent's master's thesis) --- Custom.py | 3 +- mqtt.py | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ tls.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 163 insertions(+), 1 deletion(-) create mode 100644 mqtt.py create mode 100644 tls.py diff --git a/Custom.py b/Custom.py index 1430e7f..ee2ab5d 100644 --- a/Custom.py +++ b/Custom.py @@ -65,7 +65,8 @@ class Custom(Protocol): rules = Custom.build_nfq_list_match(value, template_rules, is_backward, func, backward_func) else: # Value is a single element - value = Protocol.convert_value(value) + if type(value) != float: + value = Protocol.convert_value(value) if not is_backward: rules = {"template": template_rules["forward"], "match": func(value)} elif is_backward and "backward" in template_rules: diff --git a/mqtt.py b/mqtt.py new file mode 100644 index 0000000..d0675ee --- /dev/null +++ b/mqtt.py @@ -0,0 +1,87 @@ +from Custom import Custom + +class mqtt(Custom): + + # Class variables + layer = 7 # Protocol OSI layer + protocol_name = "mqtt" # Protocol name + + supported_keys = [ + "packet-type", + "topic-name", + "payload-length" + ] + + def parse(self, is_backward: bool = False, initiator: str = "src") -> dict: + """ + Parse the MQTT protocol. + + :param is_backward (optional): Whether the protocol must be parsed for a backward rule. + Optional, default is `False`. + :param initiator (optional): Connection initiator (src or dst). + Optional, default is "src". + :return: Dictionary containing the (forward and backward) nftables and nfqueue rules for this policy. + """ + # Handle MQTT packet type + packet_type = self.protocol_data.get("packet-type", None) + if packet_type is not None: + rule = {"forward": "mqtt_message.packet_type == {}"} + self.add_field("packet-type", rule, is_backward) + + # Handle MQTT client ID + client_id = self.protocol_data.get("client-id", None) + if client_id is not None: + rule = {"forward": 'strcmp(mqtt_message.client_id, "{}") == 0'} + self.add_field("client-id", rule, is_backward) + + # Handle MQTT client ID length + client_id_length = self.protocol_data.get("client-id-length", None) + if client_id_length is not None: + rule = {"forward": "mqtt_message.client_id_length == {}"} + self.add_field("client-id-length", rule, is_backward) + + # Handle MQTT clean session flag + clean_session = self.protocol_data.get("clean-session", None) + if clean_session is not None: + rule = {"forward": "mqtt_message.connect_flags.clean_session == {}"} + self.add_field("clean-session", rule, is_backward) + + # Handle MQTT Keep Alive + keep_alive = self.protocol_data.get("keep-alive", None) + if keep_alive is not None: + rule = {"forward": "mqtt_message.keep_alive == {}"} + self.add_field("keep-alive", rule, is_backward) + + # Handle MQTT topic name + topic_name = self.protocol_data.get("topic-name", None) + if topic_name is not None: + string = 'strcmp(mqtt_message.topic_name, "{}") == 0' + + if topic_name == "temperature": + string += "\n \t \t&& \n \t \tcheck_payload_regex(mqtt_message.payload, strlen((char *)mqtt_message.payload),\ +\"-?[0-9]?[0-9]\\\\.[0-9]°[CF]\") == 1" # floating point number with 1 decimal place and °C + elif topic_name == "humidity": + string += "\n \t \t&& \n \t \tcheck_payload_regex(mqtt_message.payload, strlen((char *)mqtt_message.payload),\ +\"[0-9]?[0-9]\\\\.[0-9]%\") == 1" # positive floating point number with 1 decimal place and % + + rule = {"forward": string} + self.add_field("topic-name", rule, is_backward) + + # Handle MQTT payload regex + payload_regex = self.protocol_data.get("payload-regex", None) + if payload_regex is not None: + rule = {"forward": 'check_payload_regex(mqtt_message.payload, strlen((char *)mqtt_message.payload), "{}") == 1'} + self.add_field("payload-regex", rule, is_backward) + + # Handle MQTT payload length + payload_length = self.protocol_data.get("payload-length", None) + if payload_length is not None: + payload_length = str(payload_length) + if '-' in payload_length: + min_length, max_length = payload_length.split('-') + rule = {"forward": "mqtt_message.payload_length >= {} && mqtt_message.payload_length <= {}".format(min_length, max_length)} + else: + rule = {"forward": "mqtt_message.payload_length == {}"} + self.add_field("payload-length", rule, is_backward) + + return self.rules diff --git a/tls.py b/tls.py new file mode 100644 index 0000000..35e5961 --- /dev/null +++ b/tls.py @@ -0,0 +1,74 @@ +from Custom import Custom + +class tls(Custom): + + # Class variables + layer = 5 # Protocol OSI layer (arbitrary) + protocol_name = "tls" # Protocol name + + supported_keys = [ + "content-type", + "handshake-type", + "tls-version", + "session-id" + ] + + def parse(self, is_backward: bool = False, initiator: str = "src") -> dict: + """ + Parse the TLS protocol. + + :param is_backward (optional): Whether the protocol must be parsed for a backward rule. + Optional, default is `False`. + :param initiator (optional): Connection initiator (src or dst). + Optional, default is "src". + :return: Dictionary containing the (forward and backward) nftables and nfqueue rules for this policy. + """ + # Handle TLS content type + content_type = self.protocol_data.get("content-type", None) + if content_type is not None: + rule = {"forward": "tls_packet->messages->message.content_type == {}"} + self.add_field("content-type", rule, is_backward) + + # Handle TLS handshake type + handshake_type = self.protocol_data.get("handshake-type", None) + if handshake_type is not None: + if not ',' in str(handshake_type): + rule = {"forward": "tls_packet != NULL && tls_packet->messages != NULL && tls_packet->messages->message.handshake_type == {}"} + else: + if '[' in str(handshake_type): # in case format in profile is '[ {n°}, {n°}, ... ]' + handshake_type = handshake_type.replace('[', '').replace(']', '') + lst = [int(x.strip(), 16 if x.strip().startswith('0x') else 10) for x in handshake_type.split(',')] + conditions = ["tls_packet != NULL"] + current_chain = "tls_packet->messages" + conditions.append(f"{current_chain} != NULL") + + for i, value in enumerate(lst): + if i > 0: + current_chain = f"{current_chain}->next" + conditions.append(f"{current_chain} != NULL") + + if value == 20: # change cipher spec message or much less likely finished handshake encryted message + conditions.append(f"( {current_chain}->message.content_type == {value} || {current_chain}->message.handshake_type == {value} )") + else: + conditions.append(f"{current_chain}->message.handshake_type == {value}") + + string = " && ".join(conditions) + rule = {"forward": string} + self.add_field("handshake-type", rule, is_backward) + + # Handle TLS version + tls_version = self.protocol_data.get("tls-version", None) + if tls_version is not None: + if '"' in str(tls_version): # in case written in profile as string, is float otherwise + tls_version = tls_version.replace('"', '') + rule = {"forward": "tls_packet->messages->message.tls_version == {}"} + func = lambda tls_version: str(769 + int(str(tls_version)[-1])) # 769 = 0x0301 + self.add_field("tls-version", rule, is_backward, func) + + # Handle TLS session ID + session_id = self.protocol_data.get("session-id", None) + if session_id is not None: + rule = {"forward": "tls_packet->messages->message.session_id_present == {}"} + self.add_field("session-id", rule, is_backward) + + return self.rules -- GitLab