diff --git a/src/parsers b/src/parsers index 4639c2725912db87ab4e476b1cee26324db1bad2..46c0382f99c7006827f956ec0ccf2c3936c1737b 160000 --- a/src/parsers +++ b/src/parsers @@ -1 +1 @@ -Subproject commit 4639c2725912db87ab4e476b1cee26324db1bad2 +Subproject commit 46c0382f99c7006827f956ec0ccf2c3936c1737b diff --git a/src/translator/.gitignore b/src/translator/.gitignore deleted file mode 100644 index f47ac92c2adc6532d81de641e7fe0ac3fcb8b396..0000000000000000000000000000000000000000 --- a/src/translator/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -*.c -*.nft \ No newline at end of file diff --git a/src/translator/LogType.py b/src/translator/LogType.py deleted file mode 100644 index abac54f45eb7f1d4f344e3dda4d3106a4acb2841..0000000000000000000000000000000000000000 --- a/src/translator/LogType.py +++ /dev/null @@ -1,12 +0,0 @@ -from enum import IntEnum - -class LogType(IntEnum): - """ - Enum class for the type of logging to be used. - """ - NONE = 0 # No logging - CSV = 1 # Log to a CSV file - PCAP = 2 # Log to a PCAP file - - def __str__(self): - return self.name diff --git a/src/translator/NFQueue.py b/src/translator/NFQueue.py deleted file mode 100644 index 17b788680156ee5180fd297fb3864753cd8b9ed4..0000000000000000000000000000000000000000 --- a/src/translator/NFQueue.py +++ /dev/null @@ -1,264 +0,0 @@ -import re -from copy import deepcopy -from LogType import LogType -from Policy import Policy - -class NFQueue: - """ - Class which represents a single nfqueue. - """ - - # Class variables - time_units = { - "second": 1, - "minute": 60, - "hour": 60 * 60, - "day": 60 * 60 * 24, - "week": 60 * 60 * 24 * 7 - } - - - def __init__(self, name: str, nft_matches: list, queue_num: int = -1) -> None: - """ - Initialize a new NFQueue object. - - :param name: descriptive name for the nfqueue - :param nft_matches: list of nftables matches corresponding to this queue - :param queue_num: number of the nfqueue queue corresponding to this policy, - or a negative number if the policy is simply `drop` - """ - self.name = name # Descriptive name for this nfqueue (name of the first policy to be added) - self.queue_num = queue_num # Number of the corresponding nfqueue - self.policies = [] # List of policies associated to this nfqueue - self.nft_matches = deepcopy(nft_matches) # List of nftables matches associated to this nfqueue - self.nft_stats = {} - - - def __eq__(self, other: object) -> bool: - """ - Compare another object to this NFQueue object. - - :param other: object to compare to this NFQueue object - :return: True if the other object is an NFQueue object with the same nftables match, False otherwise - """ - if not isinstance(other, self.__class__): - return NotImplemented - key_func = lambda x: x["template"].format(x["match"]) - self_matches = sorted(self.nft_matches, key=key_func) - other_matches = sorted(other.nft_matches, key=key_func) - return ( self.name == other.name and - self.queue_num == other.queue_num and - self_matches == other_matches ) - - - def contains_policy_matches(self, policy: Policy) -> bool: - """ - Check if this NFQueue object contains the nftables matches of the given policy. - - :param policy: policy to check - :return: True if this NFQueue object contains the nftables matches of the given policy, False otherwise - """ - key_func = lambda x: x["template"].format(x["match"]) - policy_matches = sorted(policy.nft_matches, key=key_func) - self_matches = sorted(self.nft_matches, key=key_func) - return policy_matches == self_matches - - - @staticmethod - def parse_rate_match(match: str) -> dict: - """ - Parse the rate match and return a dictionary containing the rate and burst values. - - :param match: rate match to parse - :return: dictionary containing the rate and burst values, or None if the match could not be parsed - """ - # Try to match a rate of 0, which means no rate limit - if match == 0: - return {"value": 0, "unit": None} - - # Try to match a packet rate with burst - try: - return re.compile(r"\s*(?P<value>\d+)/(?P<unit>second|minute|hour|day|week)\s+burst\s+(?P<burst_value>\d+)\s+(?P<burst_unit>packets|.bytes)\s*").match(match).groupdict() - except AttributeError: - pass - - # Try to match a packet rate without burst - try: - return re.compile(r"\s*(?P<value>\d+)/(?P<unit>second|minute|hour|day|week)\s*").match(match).groupdict() - except AttributeError: - pass - - # Return None if the match could not be parsed - return None - - - def update_rate_match(self, new_match: str) -> None: - """ - Update the rate NFTables match for this NFQueue object, if needed. - - :param new_match: new match to be compared to the current one - """ - old_match = NFQueue.parse_rate_match(self.nft_stats["rate"]["match"]) - new_match = NFQueue.parse_rate_match(new_match) - - # One of the rates is 0, which means no rate limit - if old_match["value"] == 0 or new_match["value"] == 0: - self.nft_stats["rate"]["match"] = 0 - return - - # Both rates are specified - # Compute and update rate - old_rate = float(old_match["value"]) / NFQueue.time_units[old_match["unit"]] - new_rate = float(new_match["value"]) / NFQueue.time_units[new_match["unit"]] - rate_sum = int(old_rate + new_rate) - updated_rate = "{}/{}".format(rate_sum, "second") - - # Compute and update new burst, if needed - if "burst_value" in old_match and "burst_value" in new_match: - if old_match["burst_unit"] == new_match["burst_unit"]: - old_burst = int(old_match["burst_value"]) - new_burst = int(new_match["burst_value"]) - burst_sum = old_burst + new_burst - updated_rate += " burst {} {}".format(burst_sum, old_match["burst_unit"]) - else: - # Burst units are different, so we cannot sum them - # Keep the old burst - updated_rate += " burst {} {}".format(old_match["burst_value"], old_match["burst_unit"]) - elif "burst_value" in new_match: - updated_rate += " burst {} {}".format(new_match["burst_value"], new_match["burst_unit"]) - elif "burst_value" in old_match: - updated_rate += " burst {} {}".format(old_match["burst_value"], old_match["burst_unit"]) - - # Set updated rate - self.nft_stats["rate"]["match"] = updated_rate - - - @staticmethod - def parse_size_match(match: str) -> tuple: - """ - Parse the packet size match and return a tuple containing the lower and upper bounds. - - :param match: packet size match to parse - :return: tuple containing the lower and upper bounds of the packet size match, - or None if the match could not be parsed - """ - try: - # Try to match a single upper bound value - re_upper = int(re.compile(r"\s*<\s*(?P<upper>\d+)\s*").match(match).group("upper")) - re_lower = 0 - except AttributeError: - try: - # Try to match a range of values - re_range = re.compile(r"\s*(?P<lower>\d+)\s*-\s*(?P<upper>\d+)\s*").match(match) - re_lower = int(re_range.group("lower")) - re_upper = int(re_range.group("upper")) - except AttributeError: - # No match found - return None - return (re_lower, re_upper) - - - def update_size_match(self, new_match: str): - """ - Update the packet size NFTables match for this NFQueue object, if needed. - - :param new_match: new match to be compared to the current one - """ - old_values = NFQueue.parse_size_match(self.nft_stats["packet-size"]["match"]) - new_values = NFQueue.parse_size_match(new_match) - new_lower = min(old_values[0], new_values[0]) - new_upper = max(old_values[1], new_values[1]) - if new_lower == 0: - new_match = "< {}".format(new_upper) - else: - new_match = "{} - {}".format(new_lower, new_upper) - self.nft_stats["packet-size"]["match"] = new_match - - - def update_match(self, stat: str, new_match: str): - """ - Update the match for the given stat, if needed. - Stat match is set to the least restrictive match between the current and the new one. - - :param stat: name of the stat to update - :param new_match: new match to set, if needed - """ - if stat == "rate": - self.update_rate_match(new_match) - elif stat == "packet-size": - self.update_size_match(new_match) - - - def add_policy(self, policy: Policy) -> bool: - """ - Add a policy to this NFQueue object. - - :param policy: policy to add - :return: True if the nfqueue queue number has been updated, False otherwise - """ - result = False - - # Update nfqueue queue number if necessary - if self.queue_num < 0 and policy.queue_num >= 0: - self.queue_num = policy.queue_num - result = True - - # Create policy dictionary - policy_dict = { - "policy": policy, - "counters_idx": {} - } - - # Update NFT stat matches if necessary - nfq_stats = policy.get_nft_match_stats() - for stat, data in nfq_stats.items(): - if stat not in self.nft_stats: - self.nft_stats[stat] = data - else: - self.update_match(stat, data["match"]) - - # Append policy to the list of policies - self.policies.append(policy_dict) - # Sort list of policies, with default drop policies at the end - sort_key = lambda x: (x["policy"]) - self.policies.sort(key=sort_key) - return result - - - def get_nft_rule(self, drop_proba: float = 1.0, log_type: LogType = LogType.NONE, log_group: int = 100) -> str: - """ - Retrieve the complete nftables rule, composed of the complete nftables match - and the action, for this nfqueue. - - :return: complete nftables rule for this nfqueue - """ - # Set NFT rule match - nft_rule = "" - for nft_match in self.nft_matches: - nft_rule += nft_match["template"].format(nft_match["match"]) + " " - for stat in self.nft_stats.values(): - if stat["match"] != 0: - nft_rule += stat["template"].format(stat["match"]) + " " - - # Set NFT rule action and log verdict (queue, drop, accept) - nft_action = "" - verdict = "" - if self.queue_num >= 0: - nft_action = f"queue num {self.queue_num}" - verdict = "QUEUE" - elif drop_proba == 1.0: - nft_action = "drop" - verdict = "DROP" - elif drop_proba == 0.0: - nft_action = "accept" - verdict = "ACCEPT" - - # Set full NFT action, including logging (if specified) - if log_type == LogType.CSV: - log = f"log prefix \"{self.name},,{verdict}\" group {log_group}" - return f"{nft_rule}{log} {nft_action}" - elif log_type == LogType.PCAP: - log = f"log group {log_group}" - return f"{nft_rule}{log} {nft_action}" - else: - return f"{nft_rule}{nft_action}" diff --git a/src/translator/Policy.py b/src/translator/Policy.py deleted file mode 100644 index 463d5f2e0738abf186777358944bbd1d94c754bf..0000000000000000000000000000000000000000 --- a/src/translator/Policy.py +++ /dev/null @@ -1,444 +0,0 @@ -## Import packages -from enum import Enum -from typing import Tuple, Dict -import ipaddress -## Custom libraries -from LogType import LogType -# Protocol translators -from protocols.Protocol import Protocol -from protocols.ip import ip - - -class Policy: - """ - Class which represents a single access control policy. - """ - - class NftType(Enum): - """ - Enum: NFTables types. - Possible values: - - MATCH: nftables match - - ACTION: nftables action - """ - MATCH = 1 - ACTION = 2 - - # Metadata for supported nftables statistics - stats_metadata = { - "rate": {"nft_type": NftType.MATCH, "counter": False, "template": "limit rate over {}"}, - "packet-size": {"nft_type": NftType.MATCH, "counter": False, "template": "ip length {}"}, - "packet-count": {"counter": True}, - "duration": {"counter": True} - } - - def __init__(self, policy_name: str, profile_data: dict, device: dict, is_backward: bool = False) -> None: - """ - Initialize a new Policy object. - - :param name: Name of the policy - :param profile_data: Dictionary containing the policy data from the YAML profile - :param device: Dictionary containing the device metadata from the YAML profile - :param is_backward: Whether the policy is backwards (i.e. the source and destination are reversed) - """ - self.name = policy_name # Policy name - self.profile_data = profile_data # Policy data from the YAML profile - self.device = device # Dictionary containing data for the device this policy is linked to - self.is_backward = is_backward # Whether the policy is backwards (i.e. the source and destination are reversed) - self.custom_parser = "" # Name of the custom parser (if any) - self.nft_matches = [] # List of nftables matches (will be populated by parsing) - self.nft_match = "" # Complete nftables match (including rate and packet size) - self.nft_stats = {} # Dict of nftables statistics (will be populated by parsing) - self.queue_num = -1 # Number of the nfqueue queue corresponding (will be updated by parsing) - self.nft_action = "" # nftables action associated to this policy - self.nfq_matches = [] # List of nfqueue matches (will be populated by parsing) - self.is_bidirectional = self.profile_data.get("bidirectional", False) # Whether the policy is bidirectional - self.initiator = profile_data["initiator"] if "initiator" in profile_data else "" - - - def __eq__(self, other: object) -> bool: - """ - Check whether this Policy object is equal to another object. - - :param other: object to compare to this Policy object - :return: True if the other object represents the same policy, False otherwise - """ - if not isinstance(other, self.__class__): - return NotImplemented - # Other object is a Policy object - key_func = lambda x: x["template"].format(x["match"]) - self_matches = sorted(self.nft_matches, key=key_func) - other_matches = sorted(other.nft_matches, key=key_func) - return ( other.name == self.name and - other.is_backward == self.is_backward and - self.device == other.device and - self_matches == other_matches and - self.nft_stats == other.nft_stats and - self.nft_action == other.nft_action and - self.queue_num == other.queue_num and - self.is_bidirectional == other.is_bidirectional ) - - def __lt__(self, other: object) -> bool: - """ - Check whether this Policy object is less than another object. - - :param other: object to compare to this Policy object - :return: True if this Policy object is less than the other object, False otherwise - """ - if not isinstance(other, self.__class__): - return NotImplemented - # Other object is a Policy object - if self.queue_num >= 0 and other.queue_num >= 0: - return self.queue_num < other.queue_num - elif self.queue_num < 0: - return False - elif other.queue_num < 0: - return True - else: - return self.name < other.name - - - def __hash__(self) -> int: - """ - Compute a hash value for this Policy object. - - :return: hash value for this Policy object - """ - return hash((self.name, self.is_backward)) - - - @staticmethod - def get_field_static(var: any, field: str, parent_key: str = "") -> Tuple[any, any]: - """ - Retrieve the parent key and value for a given field in a dict. - Adapted from https://stackoverflow.com/questions/9807634/find-all-occurrences-of-a-key-in-nested-dictionaries-and-lists. - - :param var: Data structure to search in - :param field: Field to retrieve - :param parent_key: Parent key of the current data structure - :return: tuple containing the parent key and the value for the given field, - or None if the field is not found - """ - if hasattr(var, 'items'): - for k, v in var.items(): - if k == field: - return parent_key, v - if isinstance(v, dict): - result = Policy.get_field_static(v, field, k) - if result is not None: - return result - elif isinstance(v, list): - for d in v: - result = Policy.get_field_static(d, field, k) - if result is not None: - return result - return None - - - def get_field(self, field: str) -> Tuple[any, any]: - """ - Retrieve the value for a given field in the policy profile data. - Adapted from https://stackoverflow.com/questions/9807634/find-all-occurrences-of-a-key-in-nested-dictionaries-and-lists. - - :param field: Field to retrieve - :return: tuple containing the parent key and the value for the given field, - or None if the field is not found - """ - return Policy.get_field_static(self.profile_data, field, self.name) - - - def parse_stat(self, stat: str) -> Dict[str, str]: - """ - Parse a single statistic. - Add the corresponding counters and nftables matches. - - :param stat: Statistic to handle - :return: parsed stat, with the form {"template": ..., "match": ...} - """ - parsed_stat = None - value = self.profile_data["stats"][stat] - if type(value) == dict: - # Stat is a dictionary, and contains data for directions "fwd" and "bwd" - value_fwd = Policy.parse_duration(value["fwd"]) if stat == "duration" else value["fwd"] - value_bwd = Policy.parse_duration(value["bwd"]) if stat == "duration" else value["bwd"] - if Policy.stats_metadata[stat]["counter"]: - # Add counters for "fwd" and "bwd" directions - self.counters[stat] = { - "fwd": value_fwd, - "bwd": value_bwd - } - if stat in Policy.stats_metadata and "template" in Policy.stats_metadata[stat]: - parsed_stat = { - "template": Policy.stats_metadata[stat]["template"], - "match": value_bwd if self.is_backward else value_fwd, - } - else: - # Stat is a single value, which is used for both directions - if Policy.stats_metadata[stat]["counter"]: - value = Policy.parse_duration(value) if stat == "duration" else value - self.counters[stat] = {"default": value} - value = f"\"{self.name[:-len('-backward')] if self.is_backward else self.name}\"" - if stat in Policy.stats_metadata and "template" in Policy.stats_metadata[stat]: - parsed_stat = { - "template": Policy.stats_metadata[stat]["template"], - "match": value - } - - if parsed_stat is not None and "nft_type" in Policy.stats_metadata[stat]: - self.nft_stats[stat] = parsed_stat - - - def build_nft_rule(self, queue_num: int, drop_proba: float = 1.0, log_type: LogType = LogType.NONE, log_group: int = 100) -> str: - """ - Build and store the nftables match and action, as strings, for this policy. - - :param queue_num: number of the nfqueue queue corresponding to this policy, - or a negative number if the policy is simply `drop` - :param rate: rate limit, in packets/second, for this policy - :param log_type: type of logging to enable - :param log_group: log group number - :return: complete nftables rule for this policy - """ - self.queue_num = queue_num - - # nftables match - for i in range(len(self.nft_matches)): - if i > 0: - self.nft_match += " " - template = self.nft_matches[i]["template"] - data = self.nft_matches[i]["match"] - self.nft_match += template.format(*(data)) if type(data) == list else template.format(data) - - # nftables stats - for stat in self.nft_stats: - template = self.nft_stats[stat]["template"] - data = self.nft_stats[stat]["match"] - if Policy.stats_metadata[stat].get("nft_type", 0) == Policy.NftType.MATCH: - self.nft_match += " " + (template.format(*(data)) if type(data) == list else template.format(data)) - elif Policy.stats_metadata[stat].get("nft_type", 0) == Policy.NftType.ACTION: - if self.nft_action: - self.nft_action += " " - self.nft_action += (template.format(*(data)) if type(data) == list else template.format(data)) - - ## nftables action - if self.nft_action: - self.nft_action += " " - - # Log action - verdict = "" - if queue_num >= 0: - verdict = "QUEUE" - elif drop_proba == 1.0: - verdict = "DROP" - elif drop_proba == 0.0: - verdict = "ACCEPT" - if log_type == LogType.CSV: - self.nft_action += f"log prefix \\\"{self.name},,{verdict}\\\" group {log_group} " - elif log_type == LogType.PCAP: - self.nft_action += f"log group {log_group} " - - # Verdict action - if queue_num >= 0: - self.nft_action += f"queue num {queue_num}" - elif drop_proba == 1.0: - self.nft_action += "drop" - elif drop_proba == 0.0: - self.nft_action += "accept" - - return self.get_nft_rule() - - - def get_nft_rule(self) -> str: - """ - Retrieve the complete nftables rule, composed of the complete nftables match - and the action, for this policy. - - :return: complete nftables rule for this policy - """ - return f"{self.nft_match} {self.nft_action}" - - - def parse(self) -> None: - """ - Parse the policy and populate the related instance variables. - """ - # Parse protocols - for protocol_name in self.profile_data["protocols"]: - try: - profile_protocol = self.profile_data["protocols"][protocol_name] - protocol = Protocol.init_protocol(protocol_name, profile_protocol, self.device) - except ModuleNotFoundError: - # Unsupported protocol, skip it - continue - else: - # Protocol is supported, parse it - - # Add custom parser if needed - if protocol.custom_parser: - self.custom_parser = protocol_name - - ### Check involved devices - protocols = ["arp", "ipv4", "ipv6"] - # This device's addresses - addrs = ["mac", "ipv4", "ipv6"] - self_addrs = ["self"] - for addr in addrs: - device_addr = self.device.get(addr, None) - if device_addr is not None: - self_addrs.append(device_addr) - if protocol_name in protocols: - ip_proto = "ipv6" if protocol_name == "ipv6" else "ipv4" - src = profile_protocol.get("spa", None) if protocol_name == "arp" else profile_protocol.get("src", None) - dst = profile_protocol.get("tpa", None) if protocol_name == "arp" else profile_protocol.get("dst", None) - - # Check if device is involved - if src in self_addrs or dst in self_addrs: - self.is_device = True - - # Device is not involved - else: - # Try expliciting source address - try: - saddr = ipaddress.ip_network(protocol.explicit_address(src)) - except ValueError: - saddr = None - - # Try expliciting destination address - try: - daddr = ipaddress.ip_network(protocol.explicit_address(dst)) - except ValueError: - daddr = None - - # Check if the involved other host is in the local network - local_networks = ip.addrs[ip_proto]["local"] - if isinstance(local_networks, list): - lans = map(lambda cidr: ipaddress.ip_network(cidr), local_networks) - else: - lans = [ipaddress.ip_network(local_networks)] - if saddr is not None and any(lan.supernet_of(saddr) for lan in lans): - self.other_host["protocol"] = protocol_name - self.other_host["direction"] = "src" - self.other_host["address"] = saddr - elif daddr is not None and any(lan.supernet_of(daddr) for lan in lans): - self.other_host["protocol"] = protocol_name - self.other_host["direction"] = "dst" - self.other_host["address"] = daddr - - # Add nft rules - new_rules = protocol.parse(is_backward=self.is_backward, initiator=self.initiator) - self.nft_matches += new_rules["nft"] - - # Add nfqueue matches - for match in new_rules["nfq"]: - self.nfq_matches.append(match) - - # Parse statistics - if "stats" in self.profile_data: - for stat in self.profile_data["stats"]: - if stat in Policy.stats_metadata: - self.parse_stat(stat) - - - def get_domain_name_hosts(self) -> Tuple[str, dict]: - """ - Retrieve the domain names and IP addresses for this policy, if any. - - :return: tuple containing: - - the IP family nftables match (`ip` or `ip6`) - - a dictionary containing a mapping between the direction matches (`saddr` or `daddr`) - and the corresponding domain names or list of IP addresses - """ - result = {} - directions = { - "src": "daddr" if self.is_backward else "saddr", - "dst": "saddr" if self.is_backward else "daddr" - } - protocol = "ipv4" - for dir, match in directions.items(): - field = self.get_field(dir) - if field is None: - # Field is not present in the policy - continue - - protocol, addr = self.get_field(dir) - if not ip.is_ip_static(addr, protocol): - # Host is a domain name, or - # list of hosts includes domain names - if type(addr) is list: - # Field is a list of hosts - for host in addr: - if ip.is_ip_static(host, protocol): - # Host is an explicit or well-known address - if match not in result: - result[match] = {} - result[match]["ip_addresses"] = result[match].get("ip_addresses", []) + [host] - else: - # Address is not explicit or well-known, might be a domain name - if match not in result: - result[match] = {} - result[match]["domain_names"] = result[match].get("domain_names", []) + [host] - else: - # Field is a single host - if match not in result: - result[match] = {} - result[match]["domain_names"] = result[match].get("domain_names", []) + [addr] - protocol = "ip" if protocol == "ipv4" else "ip6" - return protocol, result - - - def is_base_for_counter(self, counter: str): - """ - Check if the policy is the base policy for a given counter. - - :param counter: Counter to check (packet-count or duration) - :return: True if the policy is the base policy for the given counter and direction, False otherwise - """ - if counter not in self.counters: - return False - - # Counter is present for this policy - direction = "bwd" if self.is_backward else "fwd" - return ( ("default" in self.counters[counter] and not self.is_backward) or - direction in self.counters[counter] ) - - - def is_backward_for_counter(self, counter: str): - """ - Check if the policy is the backward policy for a given counter. - - :param counter: Counter to check (packet-count or duration) - :return: True if the policy is the backward policy for the given counter and direction, False otherwise - """ - if counter not in self.counters: - return False - - # Counter is present for this policy - return "default" in self.counters[counter] and self.is_backward - - - def get_data_from_nfqueues(self, nfqueues: list) -> dict: - """ - Retrieve the policy dictionary from the nfqueue list. - - :param nfqueues: List of nfqueues - :return: dictionary containing the policy data, - or None if the policy is not found - """ - for nfqueue in nfqueues: - for policy_dict in nfqueue.policies: - if policy_dict["policy"] == self: - return policy_dict - return None - - - def get_nft_match_stats(self) -> dict: - """ - Retrieve this policy's stats which correspond to an NFTables match. - - :return: dictionary containing the policy match statistics - """ - result = {} - for stat, data in self.nft_stats.items(): - if Policy.stats_metadata.get(stat, {}).get("nft_type", None) == Policy.NftType.MATCH: - result[stat] = data - return result diff --git a/src/translator/expand.py b/src/translator/expand.py deleted file mode 100644 index efd8563a4a4e856d873a86d303b417d59b659e29..0000000000000000000000000000000000000000 --- a/src/translator/expand.py +++ /dev/null @@ -1,29 +0,0 @@ -import os -import argparse -import yaml -from yaml_loaders.IncludeLoader import IncludeLoader - - -##### MAIN ##### -if __name__ == "__main__": - - # Command line arguments - description = "Expand a device YAML profile to its full form." - parser = argparse.ArgumentParser(description=description) - parser.add_argument("profile", type=str, help="Path to the device YAML profile") - args = parser.parse_args() - - # Retrieve useful paths - script_path = os.path.abspath(os.path.dirname(__file__)) # This script's path - device_path = os.path.abspath(os.path.dirname(args.profile)) # Device profile's path - - # Load the device profile - with open(args.profile, "r") as f_a: - - # Load YAML profile with custom loader - profile = yaml.load(f_a, IncludeLoader) - - # Write the expanded profile to a new file - expanded_profile_path = os.path.join(device_path, "expanded_profile.yaml") - with open(expanded_profile_path, "w") as f_b: - yaml.dump(profile, f_b, default_flow_style=False) diff --git a/src/translator/protocols b/src/translator/protocols deleted file mode 160000 index baaf4b6850ce8a97e39c20882e60f2ef468b4ec6..0000000000000000000000000000000000000000 --- a/src/translator/protocols +++ /dev/null @@ -1 +0,0 @@ -Subproject commit baaf4b6850ce8a97e39c20882e60f2ef468b4ec6 diff --git a/src/translator/templates/CMakeLists.txt.j2 b/src/translator/templates/CMakeLists.txt.j2 deleted file mode 100644 index b29d1b4d718b3e51ad058d8a091b7e38e5048ad6..0000000000000000000000000000000000000000 --- a/src/translator/templates/CMakeLists.txt.j2 +++ /dev/null @@ -1,15 +0,0 @@ -# Minimum required CMake version -cmake_minimum_required(VERSION 3.20) - -set(EXECUTABLE_OUTPUT_PATH ${BIN_DIR}) - -# Nfqueue C file for device {{device}} -add_executable({{device}} nfqueues.c) -target_link_libraries({{device}} pthread) -IF( OPENWRT_CROSSCOMPILING ) -target_link_libraries({{device}} jansson mnl nfnetlink nftnl nftables netfilter_queue netfilter_log) -ENDIF() -target_link_libraries({{device}} nfqueue packet_utils rule_utils) -target_link_libraries({{device}} ${PARSERS}) -target_include_directories({{device}} PRIVATE ${INCLUDE_DIR} ${INCLUDE_PARSERS_DIR}) -install(TARGETS {{device}} DESTINATION ${EXECUTABLE_OUTPUT_PATH}) diff --git a/src/translator/templates/callback.c.j2 b/src/translator/templates/callback.c.j2 deleted file mode 100644 index 90766bb66b0f7f3e552c530c40c23c9ecdb94580..0000000000000000000000000000000000000000 --- a/src/translator/templates/callback.c.j2 +++ /dev/null @@ -1,274 +0,0 @@ -{% macro verdict(policy_name) %} -uint32_t old_verdict = verdict; - - {% if drop_proba == 0 %} - // Binary ACCEPT - verdict = NF_ACCEPT; - {% elif drop_proba == 1 %} - // Binary DROP - verdict = NF_DROP; - {% else %} - // Stochastic dropping - uint16_t thread_id = *((uint16_t *) arg); - float random_float = rand_r(&(thread_data[thread_id].seed)) / (RAND_MAX + 1.0); - verdict = (random_float < DROP_PROBA) ? NF_DROP : NF_ACCEPT; - #ifdef DEBUG - printf("Generated random float: %f. Drop probability: %f.\n", random_float, DROP_PROBA); - #endif /* DEBUG */ - {% endif %} - - #if defined LOG || defined DEBUG - if (verdict == NF_DROP) { - #ifdef LOG - print_hash(hash); - printf(",%ld.%06ld,{{policy_name}},,DROP\n", (long int)timestamp.tv_sec, (long int)timestamp.tv_usec); - #endif /* LOG */ - #ifdef DEBUG - printf("DROP - Policy: {{policy_name}}\n"); - if (old_verdict != NF_DROP) { - dropped_packets++; - printf("Dropped packets: %hu\n", dropped_packets); - } - #endif /* DEBUG */ - } - #endif /* LOG || DEBUG */ -{% endmacro %} - -{% macro write_callback_function(loop_index, nfqueue) %} -{% set nfqueue_name = nfqueue.name.replace('-', '_') %} -{% set nfqueue_name = nfqueue_name.replace('#', '_') %} -/** - * @brief {{nfqueue.name}} callback function, called when a packet enters the queue. - * - * @param pkt_id packet ID for netfilter queue - * @param hash packet payload SHA256 hash (only present if LOG is defined) - * @param timestamp packet timestamp (only present if LOG is defined) - * @param pkt_len packet length, in bytes - * @param payload pointer to the packet payload - * @param arg pointer to the argument passed to the callback function - * @return the verdict for the packet - */ -#ifdef LOG -uint32_t callback_{{nfqueue_name}}(int pkt_id, uint8_t *hash, struct timeval timestamp, int pkt_len, uint8_t *payload, void *arg) -#else -uint32_t callback_{{nfqueue_name}}(int pkt_id, int pkt_len, uint8_t *payload, void *arg) -#endif /* LOG */ -{ - #ifdef DEBUG - printf("Received packet from nfqueue {{nfqueue.queue_num}}\n"); - #endif - - {% set custom_parsers = [] %} - {% set need_src_addr = namespace(value=False) %} - {% set need_dst_addr = namespace(value=False) %} - {% for policy_dict in nfqueue.policies %} - {% set policy = policy_dict["policy"] %} - {% for nfq_match in policy.nfq_matches %} - {% if nfq_match["template"]|is_list %} - {% for template in nfq_match["template"] %} - {% if "compare_ip" in template or "dns_entry_contains" in template %} - {% if "src_addr" in template and not need_src_addr.value %} - {% set need_src_addr.value = True %} - uint32_t src_addr = get_ipv4_src_addr(payload); // IPv4 source address, in network byte order - {% endif %} - {% if "dst_addr" in template and not need_dst_addr.value %} - {% set need_dst_addr.value = True %} - uint32_t dst_addr = get_ipv4_dst_addr(payload); // IPv4 destination address, in network byte order - {% endif %} - {% endif %} - {% endfor %} - {% else %} - {% if "compare_ip" in nfq_match["template"] or "dns_entry_contains" in nfq_match["template"] %} - {% if "src_addr" in nfq_match["template"] and not need_src_addr.value %} - {% set need_src_addr.value = True %} - uint32_t src_addr = get_ipv4_src_addr(payload); // IPv4 source address, in network byte order - {% endif %} - {% if "dst_addr" in nfq_match["template"] and not need_dst_addr.value %} - {% set need_dst_addr.value = True %} - uint32_t dst_addr = get_ipv4_dst_addr(payload); // IPv4 destination address, in network byte order - {% endif %} - {% endif %} - {% endif %} - {% endfor %} - {% if policy.custom_parser and policy.custom_parser not in custom_parsers %} - {% if policy.custom_parser == 'ssdp' and not need_dst_addr.value %} - {% set need_dst_addr.value = True %} - uint32_t dst_addr = get_ipv4_dst_addr(payload); // IPv4 destination address, in network byte order - {% endif %} - {% if policy.nfq_matches %} - // Skip layer 3 and 4 headers - {% if policy.custom_parser == 'http' or policy.custom_parser == 'coap' %} - size_t l3_header_length = get_l3_header_length(payload); - {% if policy.custom_parser == 'http' %} - uint16_t dst_port = get_dst_port(payload + l3_header_length); - {% elif policy.custom_parser == 'coap' %} - uint16_t coap_length = get_udp_payload_length(payload + l3_header_length); - {% endif %} - {% endif %} - size_t skipped = get_headers_length(payload); - {% if policy.custom_parser == 'http' %} - bool has_payload = pkt_len - skipped >= HTTP_MESSAGE_MIN_LEN; - bool is_http_message = has_payload && is_http(payload + skipped); - {% endif %} - - {% if "dns" in policy.custom_parser %} - // Parse payload as DNS message - dns_message_t dns_message = dns_parse_message(payload + skipped); - #ifdef DEBUG - dns_print_message(dns_message); - #endif - {% elif policy.custom_parser %} - // Parse payload as {{policy.custom_parser|upper}} message - {{policy.custom_parser}}_message_t {{policy.custom_parser}}_message = {{policy.custom_parser}}_parse_message(payload + skipped - {%- if policy.custom_parser == 'http' -%} - , dst_port - {%- elif policy.custom_parser == 'ssdp' -%} - , dst_addr - {%- elif policy.custom_parser == 'coap' -%} - , coap_length - {%- endif -%} - ); - #ifdef DEBUG - {% if policy.custom_parser == 'http' %} - if (is_http_message) { - http_print_message(http_message); - } else { - printf("TCP message with destination port %hu corresponding to HTTP traffic.\n", dst_port); - } - {% else %} - {{policy.custom_parser}}_print_message({{policy.custom_parser}}_message); - {% endif %} - #endif - {% endif %} - {% endif %} - {% set tmp = custom_parsers.append(policy.custom_parser) %} - {% endif %} - {% endfor %} - uint32_t verdict = NF_ACCEPT; // Packet verdict: ACCEPT or DROP - - {% for policy_dict in nfqueue.policies %} - {% set policy_idx = policy_dict["policy_idx"] %} - {% set policy = policy_dict["policy"] %} - {% set policy_name = policy.name %} - /* Policy {{policy_name}} */ - {% if policy.nfq_matches %} - if ( - {% if policy.custom_parser == 'http' %} - !is_http_message || ( - {% endif %} - {% set rule = policy.nfq_matches[0] %} - {% if rule['template'] | is_list %} - ( - {% for i in range(rule['template']|length) %} - {% set template = rule['template'][i] %} - {% set match = rule['match'][i] %} - {{ template.format(match) }} - {% if i < rule['template']|length - 1 %} - || - {% endif %} - {% endfor %} - ) - {% else %} - {{ rule['template'].format(rule['match']) }} - {% endif %} - {% for rule in policy.nfq_matches[1:] %} - && - {% if rule['match'] | is_list %} - ( - {% for i in range(rule['template']|length) %} - {% set template = rule['template'][i] %} - {% set match = rule['match'][i] %} - {{ template.format(match) }} - {% if i < rule['template']|length - 1 %} - || - {% endif %} - {% endfor %} - ) - {% else %} - {{ rule['template'].format(rule['match']) }} - {% endif %} - {% endfor %} - {% if policy.custom_parser == 'http' %} - ) - {% endif %} - ) { - - {% set is_dns_response = namespace(value=False) %} - {% if policy.custom_parser == "dns" %} - {% for nfq_match in policy.nfq_matches %} - {% if "dns_message.header.qr == " in nfq_match["template"] and nfq_match["match"] == 1 and not is_dns_response.value %} - {% set is_dns_response.value = True %} - // Retrieve IP addresses corresponding to the given domain name from the DNS response - char *domain_name = NULL; - ip_list_t ip_list = ip_list_init(); - {% endif %} - {% if is_dns_response.value %} - {% if nfq_match['template'] | is_list %} - {% for i in range(nfq_match['template']|length) %} - {% set template = nfq_match['template'][i] %} - {% if "domain_name" in template %} - {% set domain_name = nfq_match['match'][i] %} - {% if loop.index == 1 %} - if ({{ template.format(domain_name) }}) { - {% else %} - else if ({{ template.format(domain_name) }}) { - {% endif %} - domain_name = "{{domain_name}}"; - ip_list = dns_get_ip_from_name(dns_message.answers, dns_message.header.ancount, domain_name); - } - {% endif %} - {% endfor %} - {% else %} - {% if "domain_name" in nfq_match["template"] %} - {% set domain_name = nfq_match["match"] %} - domain_name = "{{domain_name}}"; - ip_list = dns_get_ip_from_name(dns_message.answers, dns_message.header.ancount, domain_name); - {% endif %} - {% endif %} - {% endif %} - {% endfor %} - {% endif %} - - {% if is_dns_response.value %} - if (ip_list.ip_count > 0) { - // Add IP addresses to DNS map - dns_map_add(dns_map, domain_name, ip_list); - } - {% endif %} - - {{ verdict(policy_name) }} - } - {% elif loop.last %} - // No other policy matched for this nfqueue - {{ verdict(nfqueue.name) }} - {% endif %} - {% endfor %} - - {% for custom_parser in custom_parsers %} - // Free memory allocated for parsed messages - {% if "dns" in custom_parser %} - dns_free_message(dns_message); - {% elif custom_parser != "ssdp" %} - {{custom_parser}}_free_message({{custom_parser}}_message); - {% endif %} - {% endfor %} - - #ifdef LOG - if (verdict != NF_DROP) { - // Log packet as accepted - print_hash(hash); - printf(",%ld.%06ld,{{nfqueue.name}},,ACCEPT\n", (long int)timestamp.tv_sec, (long int)timestamp.tv_usec); - } - free(hash); - #endif /* LOG */ - - return verdict; -} - -{% endmacro %} - -{% for nfqueue in nfqueues if nfqueue.queue_num >= 0 %} - -{{ write_callback_function(loop.index, nfqueue) }} - -{% endfor %} diff --git a/src/translator/templates/firewall.nft.j2 b/src/translator/templates/firewall.nft.j2 deleted file mode 100644 index f7d1c5e752bca847a0c30922336dc56065265b82..0000000000000000000000000000000000000000 --- a/src/translator/templates/firewall.nft.j2 +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/sbin/nft -f - -{% if test %} -table netdev {{device["name"]}} { - - # Chain INGRESS, entry point for all traffic - chain ingress { -{% else %} -table bridge {{device["name"]}} { - - # Chain PREROUTING, entry point for all traffic - chain prerouting { -{% endif %} - - # Base chain, need configuration - # Default policy is ACCEPT - {% if test %} - type filter hook ingress device enp0s8 priority 0; policy accept; - {% else %} - type filter hook prerouting priority 0; policy accept; - {% endif %} - - {% for nfqueue in nfqueues %} - # NFQueue {{nfqueue.name}} - {{nfqueue.get_nft_rule(drop_proba, log_type, log_group)}} - - {% endfor %} - - } - -} - diff --git a/src/translator/templates/header.c.j2 b/src/translator/templates/header.c.j2 deleted file mode 100644 index 8693f2916d5267448f1001fa9d73f35551b1d2f3..0000000000000000000000000000000000000000 --- a/src/translator/templates/header.c.j2 +++ /dev/null @@ -1,63 +0,0 @@ -// THIS FILE HAS BEEN AUTOGENERATED. DO NOT EDIT. - -/** - * Nefilter queue for device {{device}} - */ - -// Standard libraries -#include <stdlib.h> -#include <stdio.h> -#include <stdint.h> -#include <stdbool.h> -#include <unistd.h> -#include <string.h> -#include <pthread.h> -#include <assert.h> -#include <signal.h> -#include <sys/time.h> -// Custom libraries -#include "nfqueue.h" -#include "packet_utils.h" -#include "rule_utils.h" -// Parsers -#include "header.h" -{% set dns_parser_included = namespace(value=False) %} -{% for parser in custom_parsers %} -{% if "dns" in parser %} -{% set dns_parser_included.value = True %} -#include "dns.h" -{% else %} -#include "{{parser}}.h" -{% endif %} -{% endfor %} -{% if domain_names|length > 0 and not dns_parser_included.value %} -#include "dns.h" -{% endif %} - - -/* CONSTANTS */ - -float DROP_PROBA = {{drop_proba}}; // Drop probability for random drop verdict mode - -{% if num_threads > 0 %} -#define NUM_THREADS {{num_threads}} - -/** - * Thread-specific data. - */ -typedef struct { - uint8_t id; // Thread ID - uint32_t seed; // Thread-specific seed for random number generation - pthread_t thread; // The thread itself -} thread_data_t; - -thread_data_t thread_data[NUM_THREADS]; -{% endif %} - -{% if "dns" in custom_parsers or "mdns" in custom_parsers or domain_names|length > 0 %} -dns_map_t *dns_map; // Domain name to IP address mapping -{% endif %} - -#ifdef DEBUG -uint16_t dropped_packets = 0; -#endif /* DEBUG */ diff --git a/src/translator/templates/main.c.j2 b/src/translator/templates/main.c.j2 deleted file mode 100644 index 0e498de3c4cb6ba14ed66aa21c51089bbee28fd0..0000000000000000000000000000000000000000 --- a/src/translator/templates/main.c.j2 +++ /dev/null @@ -1,156 +0,0 @@ -/** - * @brief SIGINT handler, flush stdout and exit. - * - * @param arg unused - */ -void sigint_handler(int arg) { - fflush(stdout); - exit(0); -} - - -/** - * @brief Print program usage. - * - * @param prog program name - */ -void usage(char* prog) { - fprintf(stderr, "Usage: %s [-s DNS_SERVER_IP] [-p DROP_PROBA]\n", prog); -} - - -/** - * @brief Program entry point - * - * @param argc number of command line arguments - * @param argv list of command line arguments - * @return exit code, 0 if success - */ -int main(int argc, char *argv[]) { - - // Initialize variables - int ret; - char *dns_server_ip = "8.8.8.8"; // Default DNS server: Google Quad8 - - // Setup SIGINT handler - signal(SIGINT, sigint_handler); - - - /* COMMAND LINE ARGUMENTS */ - int opt; - while ((opt = getopt(argc, argv, "hp:s:")) != -1) - { - switch (opt) - { - case 'h': - /* Help */ - usage(argv[0]); - exit(EXIT_SUCCESS); - case 'p': - /* Random verdict mode: drop probability (float between 0 and 1) */ - DROP_PROBA = atof(optarg); - break; - case 's': - /* IP address of the network gateway */ - dns_server_ip = optarg; - break; - default: - usage(argv[0]); - exit(EXIT_FAILURE); - } - } - #ifdef DEBUG - printf("Drop probability for random verdict mode: %f\n", DROP_PROBA); - #endif /* DEBUG */ - - - #ifdef LOG - // CSV log file header - printf("hash,timestamp,policy,state,verdict\n"); - #endif /* LOG */ - - - /* GLOBAL STRUCTURES INITIALIZATION */ - - {% if "dns" in custom_parsers or "mdns" in custom_parsers or domain_names|length > 0 %} - // Initialize variables for DNS - dns_map = dns_map_create(); - dns_message_t dns_response; - ip_list_t ip_list; - dns_entry_t *dns_entry; - - {% if domain_names|length > 0 %} - // Open socket for DNS - int sockfd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); - if (sockfd < 0) { - perror("Socket creation failed"); - exit(EXIT_FAILURE); - } - - // Server address: network gateway - struct sockaddr_in server_addr; - memset(&server_addr, 0, sizeof(server_addr)); - server_addr.sin_family = AF_INET; - server_addr.sin_port = htons(53); - server_addr.sin_addr.s_addr = inet_addr(dns_server_ip); - - {% for name in domain_names %} - // Add addresses for domain {{name}} to DNS map - ret = dns_send_query("{{name}}", sockfd, &server_addr); - if (ret == 0) { - ret = dns_receive_response(sockfd, &server_addr, &dns_response); - if (ret == 0) { - ip_list = dns_get_ip_from_name(dns_response.answers, dns_response.header.ancount, "{{name}}"); - dns_map_add(dns_map, "{{name}}", ip_list); - #ifdef DEBUG - // Check DNS map has been correctly updated - dns_entry = dns_map_get(dns_map, "{{name}}"); - dns_entry_print(dns_entry); - #endif /* DEBUG */ - } - } - - {% endfor %} - {% endif %} - {% endif %} - - - {% if num_threads > 0 %} - /* NFQUEUE THREADS LAUNCH */ - - // Create threads - uint8_t i = 0; - - {% for nfqueue in nfqueues if nfqueue.queue_num >= 0 %} - {% set nfqueue_name = nfqueue.name.replace('-', '_') %} - {% set nfqueue_name = nfqueue_name.replace('#', '_') %} - /* {{nfqueue_name}} */ - // Setup thread-specific data - thread_data[i].id = i; - thread_data[i].seed = time(NULL) + i; - thread_arg_t thread_arg_{{nfqueue_name}} = { - .queue_id = {{nfqueue.queue_num}}, - .func = &callback_{{nfqueue_name}}, - .arg = &(thread_data[i].id) - }; - ret = pthread_create(&(thread_data[i++].thread), NULL, nfqueue_thread, (void *) &thread_arg_{{nfqueue_name}}); - assert(ret == 0); - - {% endfor %} - // Wait forever for threads - for (i = 0; i < NUM_THREADS; i++) { - pthread_join(thread_data[i++].thread, NULL); - } - {% endif %} - - - /* FREE MEMORY */ - - {% if "dns" in custom_parsers or "mdns" in custom_parsers or domain_names|length > 0 %} - // Free DNS map - dns_map_free(dns_map); - {% endif %} - - return 0; -} - diff --git a/src/translator/translator.py b/src/translator/translator.py deleted file mode 100644 index d64b1ffbbc4b53f982dc1c3878ea5fe7d530e09f..0000000000000000000000000000000000000000 --- a/src/translator/translator.py +++ /dev/null @@ -1,293 +0,0 @@ -""" -Translate a device YAML profile to the corresponding pair -of NFTables firewall script and NFQueue C source code. -""" - -# Import packages -import os -import sys -from pathlib import Path -import argparse -import yaml -import jinja2 -from typing import Tuple - -# Paths -script_name = os.path.basename(__file__) -script_path = Path(os.path.abspath(__file__)) -script_dir = script_path.parents[0] -sys.path.insert(0, os.path.join(script_dir, "protocols")) - -# Import custom classes -from LogType import LogType -from Policy import Policy -from NFQueue import NFQueue -from pyyaml_loaders import IncludeLoader - - -##### Custom Argparse types ##### - -def uint16(value: str) -> int: - """ - Custom type for argparse, - to check whether a value is an unsigned 16-bit integer, - i.e. an integer between 0 and 65535. - - :param value: value to check - :return: the value, if it is an unsigned 16-bit integer - :raises argparse.ArgumentTypeError: if the value is not an unsigned 16-bit integer - """ - result = int(value) - if result < 0 or result > 65535: - raise argparse.ArgumentTypeError(f"{value} is not an unsigned 16-bit integer (must be between 0 and 65535)") - return result - - -def proba(value: str) -> float: - """ - Custom type for argparse, - to check whether a value is a valid probability, - i.e. a float between 0 and 1. - - :param value: value to check - :return: the value, if it is a valid probability - :raises argparse.ArgumentTypeError: if the value is not a valid probability - """ - result = float(value) - if result < 0 or result > 1: - raise argparse.ArgumentTypeError(f"{value} is not a valid probability (must be a float between 0 and 1)") - return result - - -##### Custom Jinja2 filters ##### - -def is_list(value: any) -> bool: - """ - Custom filter for Jinja2, to check whether a value is a list. - - :param value: value to check - :return: True if value is a list, False otherwise - """ - return isinstance(value, list) - - -def debug(value: any) -> str: - """ - Custom filter for Jinja2, to print a value. - - :param value: value to print - :return: an empty string - """ - print(str(value)) - return "" - - -##### Utility functions ##### - -def flatten_policies(single_policy_name: str, single_policy: dict, acc: dict = {}) -> None: - """ - Flatten a nested single policy into a list of single policies. - - :param single_policy_name: Name of the single policy to be flattened - :param single_policy: Single policy to be flattened - :param acc: Accumulator for the flattened policies - """ - if "protocols" in single_policy: - acc[single_policy_name] = single_policy - if single_policy.get("bidirectional", False): - acc[f"{single_policy_name}-backward"] = single_policy - else: - for subpolicy in single_policy: - flatten_policies(subpolicy, single_policy[subpolicy], acc) - - -def parse_policy(policy_data: dict, nfq_id: int, global_accs: dict, rate: int = None, drop_proba: float = 1.0, log_type: LogType = LogType.NONE, log_group: int = 100) -> Tuple[Policy, bool]: - """ - Parse a policy. - - :param policy_data: Dictionary containing all the necessary data to create a Policy object - :param global_accs: Dictionary containing the global accumulators - :param rate: Rate limit, in packets/second, to apply to matched traffic - :param drop_proba: Dropping probability, between 0 and 1, to apply to matched traffic - :param log_type: Type of packet logging to be used - :param log_group: Log group ID to be used - :return: the parsed policy, as a `Policy` object, and a boolean indicating whether a new NFQueue was created - """ - # If rate limit is given, add it to policy data - if rate is not None: - policy_data["profile_data"]["stats"] = {"rate": f"{rate}/second"} - - # Create and parse policy - policy = Policy(**policy_data) - policy.parse() - - # If policy has domain name match, - # add domain name to global list - _, hosts = policy.get_domain_name_hosts() - for direction in ["saddr", "daddr"]: - domain_names = hosts.get(direction, {}).get("domain_names", []) - for name in domain_names: - if name not in global_accs["domain_names"]: - global_accs["domain_names"].append(name) - - # Add nftables rules - not_nfq = not policy.nfq_matches and (drop_proba == 0.0 or drop_proba == 1.0) - nfq_id = -1 if not_nfq else nfq_id - policy.build_nft_rule(nfq_id, drop_proba, log_type, log_group) - new_nfq = False - try: - # Check if nft match is already stored - nfqueue = next(nfqueue for nfqueue in global_accs["nfqueues"] if nfqueue.contains_policy_matches(policy)) - except StopIteration: - # No nfqueue with this nft match - nfqueue = NFQueue(policy.name, policy.nft_matches, nfq_id) - global_accs["nfqueues"].append(nfqueue) - new_nfq = nfq_id != -1 - finally: - nfqueue.add_policy(policy) - - # Add custom parser (if any) - if policy.custom_parser: - global_accs["custom_parsers"].add(policy.custom_parser) - - return policy, new_nfq - - -##### MAIN ##### -if __name__ == "__main__": - - ## Command line arguments - description = "Translate a device YAML profile to the corresponding pair of NFTables firewall script and NFQueue C source code." - parser = argparse.ArgumentParser(description=description) - parser.add_argument("profile", type=str, help="Path to the device YAML profile") - parser.add_argument("nfq_id_base", type=uint16, help="NFQueue start index for this profile's policies (must be an integer between 0 and 65535)") - # Verdict modes - parser.add_argument("-r", "--rate", type=int, help="Rate limit, in packets/second, to apply to matched traffic, instead of a binary verdict. Cannot be used with dropping probability.") - parser.add_argument("-p", "--drop-proba", type=proba, help="Dropping probability to apply to matched traffic, instead of a binary verdict. Cannot be used with rate limiting.") - parser.add_argument("-l", "--log-type", type=lambda log_type: LogType[log_type], choices=list(LogType), default=LogType.NONE, help="Type of packet logging to be used") - parser.add_argument("-g", "--log-group", type=uint16, default=100, help="Log group number (must be an integer between 0 and 65535)") - parser.add_argument("-t", "--test", action="store_true", help="Test mode: use VM instead of router") - args = parser.parse_args() - - # Verify verdict mode - if args.rate is not None and args.drop_proba is not None: - parser.error("Arguments --rate and --drop-proba are mutually exclusive") - - # Set default value for drop probability - args.drop_proba = 1.0 if args.drop_proba is None else args.drop_proba - - # Retrieve device profile's path - device_path = os.path.abspath(os.path.dirname(args.profile)) - - # Jinja2 loader - loader = jinja2.FileSystemLoader(searchpath=f"{script_dir}/templates") - env = jinja2.Environment(loader=loader, trim_blocks=True, lstrip_blocks=True) - # Add custom Jinja2 filters - env.filters["debug"] = debug - env.filters["is_list"] = is_list - env.filters["any"] = any - env.filters["all"] = all - - # NFQueue ID increment - nfq_id_inc = 10 - - # Load the device profile - with open(args.profile, "r") as f: - - # Load YAML profile with custom loader - profile = yaml.load(f, IncludeLoader) - - # Get device info - device = profile["device-info"] - - # Base nfqueue id, will be incremented at each interaction - nfq_id = args.nfq_id_base - - # Global accumulators - global_accs = { - "custom_parsers": set(), - "nfqueues": [], - "domain_names": [] - } - - - # Loop over the device's individual policies - if "single-policies" in profile: - for policy_name in profile["single-policies"]: - profile_data = profile["single-policies"][policy_name] - - policy_data = { - "policy_name": policy_name, - "profile_data": profile_data, - "device": device, - "is_backward": False - } - - # Parse policy - is_backward = profile_data.get("bidirectional", False) - policy, new_nfq = parse_policy(policy_data, nfq_id, global_accs, args.rate, args.drop_proba, args.log_type, args.log_group) - - # Parse policy in backward direction, if needed - if is_backward: - policy_data_backward = { - "policy_name": f"{policy_name}-backward", - "profile_data": profile_data, - "device": device, - "is_backward": True - } - policy_backward, new_nfq = parse_policy(policy_data_backward, nfq_id + 1, global_accs, args.rate, args.drop_proba, args.log_type, args.log_group) - - # Update nfqueue variables if needed - if new_nfq: - nfq_id += nfq_id_inc - - - # Create nftables script - nft_dict = { - "device": device, - "nfqueues": global_accs["nfqueues"], - "drop_proba": args.drop_proba, - "log_type": args.log_type, - "log_group": args.log_group, - "test": args.test - } - env.get_template("firewall.nft.j2").stream(nft_dict).dump(f"{device_path}/firewall.nft") - - # If needed, create NFQueue-related files - num_threads = len([q for q in global_accs["nfqueues"] if q.queue_num >= 0]) - if num_threads > 0: - # Create nfqueue C file by rendering Jinja2 templates - header_dict = { - "device": device["name"], - "custom_parsers": global_accs["custom_parsers"], - "domain_names": global_accs["domain_names"], - "drop_proba": args.drop_proba, - "num_threads": num_threads, - } - header = env.get_template("header.c.j2").render(header_dict) - callback_dict = { - "nft_table": f"bridge {device['name']}", - "nfqueues": global_accs["nfqueues"], - "drop_proba": args.drop_proba - } - callback = env.get_template("callback.c.j2").render(callback_dict) - main_dict = { - "custom_parsers": global_accs["custom_parsers"], - "nfqueues": global_accs["nfqueues"], - "domain_names": global_accs["domain_names"], - "num_threads": num_threads - } - main = env.get_template("main.c.j2").render(main_dict) - - # Write policy C file - with open(f"{device_path}/nfqueues.c", "w+") as fw: - fw.write(header) - fw.write(callback) - fw.write(main) - - # Create CMake file - cmake_dict = {"device": device["name"]} - env.get_template("CMakeLists.txt.j2").stream(cmake_dict).dump(f"{device_path}/CMakeLists.txt") - - - print(f"Done translating {args.profile}.")