Newer
Older
from __future__ import annotations
from typing import Union
import importlib
Generic protocol, inherited by all concrete protocols.
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""
def __init__(self, protocol_data: dict, device: dict) -> None:
    """
    Set up the state shared by every protocol object.

    :param protocol_data: Dictionary containing the protocol data.
    :param device: Dictionary containing the device metadata.
    """
    # Rules accumulator: one initially empty list per rule family
    # ("nft" and "nfq").
    self.rules = {family: [] for family in ("nft", "nfq")}
    self.device = device
    self.protocol_data = protocol_data
@staticmethod
def convert_value(value: str) -> Union[str, int]:
    """
    Convert a value to an int if possible.

    :param value: Value to convert, normally a string.
    :return: Converted value as int if possible, or the original value,
             unchanged, otherwise.
    """
    try:
        return int(value)
    except (ValueError, TypeError):
        # ValueError: the string does not spell an integer.
        # TypeError: the value has a type int() does not accept at all
        # (e.g. None or a list). Either way, hand the value back untouched.
        return value
@classmethod
def init_protocol(c, protocol_name: str, protocol_data: dict, device: dict) -> Protocol:
    """
    Factory method building the object for a specific protocol.

    The module called `protocol_name` is imported, then the attribute of
    that module bearing the same name is called with the protocol data and
    the device metadata.

    :param protocol_name: Name of the protocol; used both as the module name
                          and as the attribute name looked up in that module.
    :param protocol_data: Dictionary containing the protocol data.
    :param device: Dictionary containing the device metadata.
    :return: Object obtained by calling the looked-up attribute.
    """
    module_name = f"{protocol_name}"
    protocol_module = importlib.import_module(module_name)
    protocol_class = getattr(protocol_module, protocol_name)
    return protocol_class(protocol_data, device)
def format_list(self, l: list, func = lambda x: x) -> str:
    """
    Format a collection of values as a brace-delimited, comma-separated string.

    For instance, `[1, 2, 3]` is formatted as `"{ 1, 2, 3 }"`, and an empty
    collection as `"{  }"`.

    :param l: Values to format. Any iterable is accepted, not only lists.
    :param func: Function to apply to each value before it is converted to a string.
                 Optional, default is the identity function.
    :return: Formatted list.
    """
    # A single join replaces the index loop and repeated string concatenation.
    return "{ " + ", ".join(str(func(item)) for item in l) + " }"
def add_field(self, field: str, template_rules: dict, is_backward: bool = False, func = lambda x: x, backward_func = lambda x: x) -> None:
    """
    Add a new nftables rule to the nftables rules accumulator.

    Nothing is added when there is no protocol data, when `field` is absent
    from the protocol data, or when a backward rule is requested but
    `template_rules` has no "backward" entry.

    Args:
        field (str): Field to add the rule for.
        template_rules (dict): Dictionary containing the protocol-specific rule templates,
                               under the "forward" and (optionally) "backward" keys.
        is_backward (bool): Whether the field to add is for a backward rule.
                            Optional, default is `False`.
        func (lambda): Function to apply to the field value before writing it.
                       Optional, default is the identity function.
        backward_func (lambda): Function to apply to the field value in the case of a backward rule.
                                Will be applied after the forward function.
                                Optional, default is the identity function.

    Raises:
        KeyError: If a forward rule is requested and `template_rules` has no "forward" entry.
    """
    if self.protocol_data is None or field not in self.protocol_data:
        return

    value = self.protocol_data[field]
    if isinstance(value, list):
        # List value (including list subclasses): format all elements at once
        value = self.format_list(value, func)
    else:
        # Single element
        value = func(value)
    # Called through self rather than a hard-coded class name, so that an
    # overriding subclass is honoured.
    value = self.convert_value(value)

    # Build rule
    if not is_backward:
        rule = {"template": template_rules["forward"], "match": value}
    elif "backward" in template_rules:
        rule = {"template": template_rules["backward"], "match": backward_func(value)}
    else:
        # Backward rule requested, but no backward template: nothing to add
        return

    # Add rule to the list of rules
    self.rules["nft"].append(rule)
def parse(self, is_backward: bool = False, initiator: str = "src") -> dict:
    """
    Base parsing method, meant to be overridden by the children classes.

    This default version does not use its arguments: it simply hands back
    the rules accumulated so far.

    :param is_backward: Whether the protocol must be parsed for a backward rule.
                        Optional, default is `False`.
    :param initiator: Connection initiator (src or dst).
                      Optional, default is "src".
    :return: Dictionary containing the (forward and backward) nftables and nfqueue rules for this policy.
    """
    accumulated_rules = self.rules
    return accumulated_rules