diff --git a/src/translator/Policy.py b/src/translator/Policy.py index 463d5f2e0738abf186777358944bbd1d94c754bf..f860ce4e2cb1fc88e62daf4d02663d8ce11f7f33 100644 --- a/src/translator/Policy.py +++ b/src/translator/Policy.py @@ -1,4 +1,5 @@ ## Import packages +from __future__ import annotations from enum import Enum from typing import Tuple, Dict import ipaddress @@ -32,17 +33,17 @@ class Policy: "duration": {"counter": True} } - def __init__(self, policy_name: str, profile_data: dict, device: dict, is_backward: bool = False) -> None: + + def __init__(self, profile_data: dict, device: dict, policy_name: str = None, is_backward: bool = False) -> None: """ Initialize a new Policy object. - :param name: Name of the policy :param profile_data: Dictionary containing the policy data from the YAML profile + :param policy_name: Name of the policy :param device: Dictionary containing the device metadata from the YAML profile :param is_backward: Whether the policy is backwards (i.e. the source and destination are reversed) """ - self.name = policy_name # Policy name - self.profile_data = profile_data # Policy data from the YAML profile + # Initialize attributes self.device = device # Dictionary containing data for the device this policy is linked to self.is_backward = is_backward # Whether the policy is backwards (i.e. the source and destination are reversed) self.custom_parser = "" # Name of the custom parser (if any) @@ -52,9 +53,97 @@ class Policy: self.queue_num = -1 # Number of the nfqueue queue corresponding (will be updated by parsing) self.nft_action = "" # nftables action associated to this policy self.nfq_matches = [] # List of nfqueue matches (will be populated by parsing) - self.is_bidirectional = self.profile_data.get("bidirectional", False) # Whether the policy is bidirectional + self.profile_data = profile_data # Policy data from the YAML profile self.initiator = profile_data["initiator"] if "initiator" in profile_data else "" + # Parse policy data + self.parse() + + # Set policy name + self.name = policy_name if policy_name is not None else self.get_policy_id() + + + def parse(self) -> None: + """ + Parse policy data to populate the policy's attributes. + """ + ### Parse protocols + for protocol_name in self.profile_data["protocols"]: + try: + profile_protocol = self.profile_data["protocols"][protocol_name] + protocol = Protocol.init_protocol(protocol_name, profile_protocol, self.device) + except ModuleNotFoundError: + # Unsupported protocol, skip it + continue + else: + # Protocol is supported, parse it + + # Add custom parser if needed + if protocol.custom_parser: + self.custom_parser = protocol_name + + ### Check involved devices + protocols = ["arp", "ipv4", "ipv6"] + # This device's addresses + addrs = ["mac", "ipv4", "ipv6"] + self_addrs = ["self"] + for addr in addrs: + device_addr = self.device.get(addr, None) + if device_addr is not None: + self_addrs.append(device_addr) + if protocol_name in protocols: + ip_proto = "ipv6" if protocol_name == "ipv6" else "ipv4" + src = profile_protocol.get("spa", None) if protocol_name == "arp" else profile_protocol.get("src", None) + dst = profile_protocol.get("tpa", None) if protocol_name == "arp" else profile_protocol.get("dst", None) + + # Check if device is involved + if src in self_addrs or dst in self_addrs: + self.is_device = True + + # Device is not involved + else: + # Try expliciting source address + try: + saddr = ipaddress.ip_network(protocol.explicit_address(src)) + except ValueError: + saddr = None + + # Try expliciting destination address + try: + daddr = ipaddress.ip_network(protocol.explicit_address(dst)) + except ValueError: + daddr = None + + # Check if the involved other host is in the local network + local_networks = ip.addrs[ip_proto]["local"] + if isinstance(local_networks, list): + lans = map(lambda cidr: ipaddress.ip_network(cidr), local_networks) + else: + lans = [ipaddress.ip_network(local_networks)] + if saddr is not None and any(lan.supernet_of(saddr) for lan in lans): + self.other_host["protocol"] = protocol_name + self.other_host["direction"] = "src" + self.other_host["address"] = saddr + elif daddr is not None and any(lan.supernet_of(daddr) for lan in lans): + self.other_host["protocol"] = protocol_name + self.other_host["direction"] = "dst" + self.other_host["address"] = daddr + + # Add nft rules + new_rules = protocol.parse(is_backward=self.is_backward, initiator=self.initiator) + self.nft_matches += new_rules["nft"] + + # Add nfqueue matches + for match in new_rules["nfq"]: + self.nfq_matches.append(match) + + + ### Parse statistics + if "stats" in self.profile_data: + for stat in self.profile_data["stats"]: + if stat in Policy.stats_metadata: + self.parse_stat(stat) + def __eq__(self, other: object) -> bool: """ @@ -75,8 +164,8 @@ class Policy: self_matches == other_matches and self.nft_stats == other.nft_stats and self.nft_action == other.nft_action and - self.queue_num == other.queue_num and - self.is_bidirectional == other.is_bidirectional ) + self.queue_num == other.queue_num ) + def __lt__(self, other: object) -> bool: """ @@ -442,3 +531,17 @@ class Policy: if Policy.stats_metadata.get(stat, {}).get("nft_type", None) == Policy.NftType.MATCH: result[stat] = data return result + + + def get_policy_id(self) -> str: + """ + Generate an identifier for this Policy. + + :return: str: Policy identifier + """ + highest_protocol = list(dict.keys(self.profile_data["protocols"]))[-1] + id = highest_protocol + for _, value in dict.items(self.profile_data["protocols"][highest_protocol]): + id += f"_{value}" + return id + \ No newline at end of file diff --git a/src/translator/protocols b/src/translator/protocols index baaf4b6850ce8a97e39c20882e60f2ef468b4ec6..a5005bced6d7edacb51f251ee772c3df011dbb08 160000 --- a/src/translator/protocols +++ b/src/translator/protocols @@ -1 +1 @@ -Subproject commit baaf4b6850ce8a97e39c20882e60f2ef468b4ec6 +Subproject commit a5005bced6d7edacb51f251ee772c3df011dbb08 diff --git a/src/translator/translator.py b/src/translator/translator.py index 044a6b1424594a9305220638823835c4ea6e12a5..a7fbf478428a72775026f53e8c2ee33a13da48ce 100644 --- a/src/translator/translator.py +++ b/src/translator/translator.py @@ -87,7 +87,6 @@ def parse_policy(policy_data: dict, nfq_id: int, global_accs: dict, rate: int = # Create and parse policy policy = Policy(**policy_data) - policy.parse() # If policy has domain name match, # add domain name to global list @@ -194,9 +193,9 @@ if __name__ == "__main__": profile_data = profile["single-policies"][policy_name] policy_data = { - "policy_name": policy_name, "profile_data": profile_data, "device": device, + "policy_name": policy_name, "is_backward": False } @@ -207,9 +206,9 @@ if __name__ == "__main__": # Parse policy in backward direction, if needed if is_backward: policy_data_backward = { - "policy_name": f"{policy_name}-backward", "profile_data": profile_data, "device": device, + "policy_name": f"{policy_name}-backward", "is_backward": True } policy_backward, new_nfq = parse_policy(policy_data_backward, nfq_id + 1, global_accs, args.rate, args.drop_proba, args.log_type, args.log_group)