diff --git a/.ci_scripts/install_packages.sh b/.ci_scripts/install_packages.sh new file mode 100755 index 0000000000000000000000000000000000000000..a7b016c806e62a7b608fe356c795bbafb80ec2ae --- /dev/null +++ b/.ci_scripts/install_packages.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +apt update +apt install -y gcc make cmake libcunit1 libcunit1-dev net-tools valgrind cppcheck diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e377e7550c6f067b9c71c7251891f9bed7e03307 --- /dev/null +++ b/.github/workflows/unit-tests.yaml @@ -0,0 +1,27 @@ +name: Unit tests for source files +on: [push] + + +jobs: + + test: + runs-on: ubuntu-latest + steps: + + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Install required packages + run: sudo $GITHUB_WORKSPACE/.ci_scripts/full-test/install_packages.sh + + - name: Build project with CMake + run: $GITHUB_WORKSPACE/build.sh -d $GITHUB_WORKSPACE + + - name: Run CUnit tests + run: $GITHUB_WORKSPACE/.ci_scripts/full-test/run_tests.sh + + - name: Run Valgrind on CUnit tests + run: $GITHUB_WORKSPACE/.ci_scripts/full-test/run_tests.sh valgrind + + - name: Run cppcheck on source files + run: $GITHUB_WORKSPACE/.ci_scripts/full-test/run_cppcheck.sh diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..1e9e4b2504216ca0052d8221a6e237293811e0b9 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +# Config folders +.vscode + +# Build folders +build +bin diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..8634ffc63d5f3ae63fe83265acf800dd750021e9 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,31 @@ +# Minimum required CMake version +cmake_minimum_required(VERSION 3.20) + +# Project name +project(protocol-parsers C) + +# Set project directories +link_directories($ENV{LD_LIBRARY_PATH}) +set(CMAKE_INSTALL_PREFIX ${PROJECT_SOURCE_DIR}) +set(INCLUDE_DIR ${PROJECT_SOURCE_DIR}/include) +set(BIN_DIR ${PROJECT_SOURCE_DIR}/bin) +set(EXECUTABLE_OUTPUT_PATH ${BIN_DIR}) + +# Set compiler flags +#add_compile_options(-Wall -Werror -Wno-unused-variable -DDEBUG) # Debug +#add_compile_options(-Wall -Werror -Wno-unused-variable -DLOG) # Logging +#add_compile_options(-Wall -Werror -Wno-unused-variable) # Production +# With optimisation +#add_compile_options(-Wall -Werror -Wno-unused-variable -O3 -DDEBUG) # Debug +#add_compile_options(-Wall -Werror -Wno-unused-variable -O3 -DLOG) # Packet Logging +add_compile_options(-Wall -Werror -Wno-unused-variable -O3) # Production +# With debug symbols +#add_compile_options(-Wall -Werror -Wno-unused-variable -g) # Without debug logging +#add_compile_options(-Wall -Werror -Wno-unused-variable -DDEBUG -g) # With debug logging + +# Custom parsers +set(PARSERS header dns dhcp http igmp ssdp coap) + +# Subdirectories containing code +add_subdirectory(src) +add_subdirectory(test) diff --git a/include/coap.h b/include/coap.h new file mode 100644 index 0000000000000000000000000000000000000000..d56a6978215faf60f35afc187dcfd7f8ab5cc5cf --- /dev/null +++ b/include/coap.h @@ -0,0 +1,87 @@ +/** + * @file include/coap.h + * @brief CoAP message parser + * @date 2022-11-30 + * + * @copyright Copyright (c) 2022 + * + */ + +#ifndef _PROTOCOL_PARSERS_COAP_ +#define _PROTOCOL_PARSERS_COAP_ + +#include <stdlib.h> +#include <stdio.h> +#include <stdint.h> +#include <string.h> +#include <arpa/inet.h> +#include "http.h" + + +/** + * @brief CoAP message type + */ +typedef enum +{ + COAP_CON = 0, + COAP_NON = 1, + COAP_ACK = 2, + COAP_RST = 3 +} coap_type_t; + +/** + * @brief CoAP Option number + */ +typedef enum +{ + COAP_URI_PATH = 11, + COAP_URI_QUERY = 15 +} coap_option_t; + +/** + * @brief Abstraction of a CoAP message + */ +typedef struct coap_message +{ + coap_type_t type; // CoAP message type + http_method_t method; // CoAP method, analogous to HTTP + char *uri; // Message URI + uint16_t uri_len; // URI length +} coap_message_t; + + +////////// FUNCTIONS ////////// + +///// PARSING ///// + +/** + * @brief Parse a CoAP message. + * + * @param data pointer to the start of the CoAP message + * @param length length of the CoAP message, in bytes + * @return the parsed CoAP message + */ +coap_message_t coap_parse_message(uint8_t *data, uint16_t length); + + +///// DESTROY ///// + +/** + * @brief Free the memory allocated for a CoAP message. + * + * @param message the CoAP message to free + */ +void coap_free_message(coap_message_t message); + + +///// PRINTING ///// + +/** + * @brief Print a CoAP message. + * + * @param message the CoAP message to print + */ +void coap_print_message(coap_message_t message); + + +#endif /* _PROTOCOL_PARSERS_COAP_ */ diff --git a/include/dhcp.h b/include/dhcp.h new file mode 100644 index 0000000000000000000000000000000000000000..1e500a5ba120dcf12e13cdab139853c86a865927 --- /dev/null +++ b/include/dhcp.h @@ -0,0 +1,174 @@ +/** + * @file include/dhcp.h + * @brief DHCP message parser + * @date 2022-09-12 + * + * @copyright Copyright (c) 2022 + * + */ + +#ifndef _PROTOCOL_PARSERS_DHCP_ +#define _PROTOCOL_PARSERS_DHCP_ + +#include <stdlib.h> +#include <stdio.h> +#include <stdint.h> +#include <string.h> +#include <arpa/inet.h> + +#define MAX_HW_LEN 16 +#define DHCP_HEADER_LEN 236 +#define DHCP_MAX_OPTION_COUNT 20 +#define DHCP_MAGIC_COOKIE 0x63825363 + + +////////// TYPE DEFINITIONS ////////// + +/** + * DHCP opcode + */ +typedef enum +{ + DHCP_BOOTREQUEST = 1, + DHCP_BOOTREPLY = 2 +} dhcp_opcode_t; + +/** + * Useful DHCP option codes + */ +typedef enum +{ + DHCP_PAD = 0, + DHCP_MESSAGE_TYPE = 53, + DHCP_END = 255 +} dhcp_option_code_t; + +/** + * DHCP message type + */ +typedef enum +{ + DHCP_DISCOVER = 1, + DHCP_OFFER = 2, + DHCP_REQUEST = 3, + DHCP_DECLINE = 4, + DHCP_ACK = 5, + DHCP_NAK = 6, + DHCP_RELEASE = 7, + DHCP_INFORM = 8 +} dhcp_message_type_t; + +/** + * DHCP Option + */ +typedef struct dhcp_option { + dhcp_option_code_t code; + uint8_t length; + uint8_t *value; +} dhcp_option_t; + +/** + * DHCP Options + */ +typedef struct dhcp_options { + uint8_t count; // Number of options + dhcp_message_type_t message_type; // DHCP Message type (stored for convenience) + dhcp_option_t *options; // List of options +} dhcp_options_t; + +/** + * DHCP Message + */ +typedef struct dhcp_message { + dhcp_opcode_t op; // DHCP opcode + uint8_t htype; // Hardware address type + uint8_t hlen; // Hardware address length + uint8_t hops; // Number of hops + uint32_t xid; // Transaction ID + uint16_t secs; // Seconds elapsed since client began address acquisition or renewal process + uint16_t flags; // DHCP flags + uint32_t ciaddr; // Client IP address + uint32_t yiaddr; // Your (client) IP address + uint32_t siaddr; // Next server IP address + uint32_t giaddr; // Relay agent IP address + uint8_t chaddr[16]; // Client hardware address + uint8_t sname[64]; // Optional server host name + uint8_t file[128]; // Boot file name + dhcp_options_t options; // DHCP options +} dhcp_message_t; + + +////////// FUNCTIONS ////////// + +///// PARSING ///// + +/** + * @brief Parse the header of a DHCP message (not including options) + * + * @param data a pointer to the start of the DHCP message + * @return the parsed DHCP message with the header fields filled in + */ +dhcp_message_t dhcp_parse_header(uint8_t *data); + +/** + * @brief Parse a DHCP option + * + * @param data a pointer to the start of the DHCP option + * @param offset a pointer to the current offset inside the DHCP message + * Its value will be updated to point to the next option + * @return the parsed DHCP option + */ +dhcp_option_t dhcp_parse_option(uint8_t *data, uint16_t *offset); + +/** + * @brief Parse DHCP options + * + * @param data a pointer to the start of the DHCP options list + * @return a pointer to the start of the parsed DHCP options + */ +dhcp_options_t dhcp_parse_options(uint8_t *data); + +/** + * @brief Parse a DHCP message + * + * @param data a pointer to the start of the DHCP message + * @return the parsed DHCP message + */ +dhcp_message_t dhcp_parse_message(uint8_t *data); + + +///// DESTROY ////// + +/** + * @brief Free the memory allocated for a DHCP message. + * + * @param message the DHCP message to free + */ +void dhcp_free_message(dhcp_message_t message); + + +///// PRINTING ///// + +/** + * @brief Print the header of a DHCP message + * + * @param message the DHCP message to print the header of + */ +void dhcp_print_header(dhcp_message_t message); + +/** + * @brief Print a DHCP option + * + * @param option the DHCP option to print + */ +void dhcp_print_option(dhcp_option_t option); + +/** + * @brief Print a DHCP message + * + * @param message the DHCP message to print + */ +void dhcp_print_message(dhcp_message_t message); + + +#endif /* _PROTOCOL_PARSERS_DHCP_ */ diff --git a/include/dns.h b/include/dns.h new file mode 100644 index 0000000000000000000000000000000000000000..7b6975ea1b8835199ff93868ab879542929b9efd --- /dev/null +++ b/include/dns.h @@ -0,0 +1,271 @@ +/** + * @file include/dns.h + * @brief DNS message parser + * @date 2022-09-09 + * + * @copyright Copyright (c) 2022 + * + */ + +#ifndef _PROTOCOL_PARSERS_DNS_ +#define _PROTOCOL_PARSERS_DNS_ + +#include <stdlib.h> +#include <stdio.h> +#include <stdint.h> +#include <stdbool.h> +#include <string.h> +#include <arpa/inet.h> +#include "packet_utils.h" +#include "dns_map.h" + +#define DNS_HEADER_SIZE 12 +#define DNS_MAX_DOMAIN_NAME_LENGTH 100 +#define DNS_QR_FLAG_MASK 0x8000 +#define DNS_CLASS_MASK 0x7fff +#define DNS_COMPRESSION_MASK 0x3fff + + +////////// TYPE DEFINITIONS ////////// + +/** + * DNS types + */ +typedef enum { + A = 1, + NS = 2, + MD = 3, + MF = 4, + CNAME = 5, + SOA = 6, + MB = 7, + MG = 8, + MR = 9, + NULL_ = 10, + WKS = 11, + PTR = 12, + HINFO = 13, + MINFO = 14, + MX = 15, + TXT = 16, + AAAA = 28, + OPT = 41, // Used to specify extensions + ANY = 255 // Used to query any type +} dns_rr_type_t; + +/** + * DNS Header + */ +typedef struct dns_header { + uint16_t id; + uint16_t flags; + bool qr; // 0 if the message is a query, 1 if it is a response + uint16_t qdcount; // Number of entries in Question section + uint16_t ancount; // Number of Resource Records in Answer section + uint16_t nscount; // Number of Resource Records in Authority section + uint16_t arcount; // Number of Resource Records in Additional section +} dns_header_t; + +/** + * DNS Question + */ +typedef struct dns_question { + char *qname; + dns_rr_type_t qtype; + uint16_t qclass; +} dns_question_t; + +/** + * RDATA field of a DNS Resource Record + */ +typedef union { + char *domain_name; // Domain name, character string + ip_addr_t ip; // IP (v4 or v6) address + uint8_t *data; // Generic data, series of bytes +} rdata_t; + +/** + * DNS Resource Record + */ +typedef struct dns_resource_record { + char *name; + dns_rr_type_t rtype; + uint16_t rclass; + uint32_t ttl; + uint16_t rdlength; + rdata_t rdata; +} dns_resource_record_t; + +/** + * DNS Message + */ +typedef struct dns_message { + dns_header_t header; + dns_question_t *questions; + dns_resource_record_t *answers; + dns_resource_record_t *authorities; + dns_resource_record_t *additionals; +} dns_message_t; + + +////////// FUNCTIONS ////////// + +///// PARSING ///// + +/** + * Parse a DNS header. + * A DNS header is always 12 bytes. + * + * @param data a pointer pointing to the start of the DNS message + * @param offset a pointer to the current parsing offset + * @return the parsed header + */ +dns_header_t dns_parse_header(uint8_t *data, uint16_t *offset); + +/** + * Parse a DNS question section. + * + * @param qdcount the number of questions present in the question section + * @param data a pointer pointing to the start of the DNS message + * @param offset a pointer to the current parsing offset + * @return the parsed question section + */ +dns_question_t* dns_parse_questions(uint16_t qdcount, uint8_t *data, uint16_t *offset); + +/** + * Parse a DNS resource record list. + * + * @param count the number of resource records present in the section + * @param data a pointer pointing to the start of the DNS message + * @param offset a pointer to the current parsing offset + * @return the parsed resource records list + */ +dns_resource_record_t* dns_parse_rrs(uint16_t count, uint8_t *data, uint16_t *offset); + +/** + * Parse a DNS message. + * + * @param data a pointer to the start of the DNS message + * @return the parsed DNS message + */ +dns_message_t dns_parse_message(uint8_t *data); + + +///// LOOKUP ///// + +/** + * @brief Check if a given DNS Questions list contains a domain name which has a given suffix. + * + * @param questions DNS Questions list + * @param qdcount number of Questions in the list + * @param suffix the domain name suffix to search for + * @param suffix_length the length of the domain name suffix + * @return true if a domain name with the given suffix is found is found in the Questions list, + * false otherwise + */ +bool dns_contains_suffix_domain_name(dns_question_t *questions, uint16_t qdcount, char *suffix, uint16_t suffix_length); + +/** + * @brief Check if a given domain name is fully contained in a DNS Questions list. + * + * @param questions DNS Questions list + * @param qdcount number of Questions in the list + * @param domain_name the domain name to search for + * @return true if the full domain name is found in the Questions list, false otherwise + */ +bool dns_contains_full_domain_name(dns_question_t *questions, uint16_t qdcount, char *domain_name); + +/** + * @brief Search for a specific domain name in a DNS Questions list. + * + * @param questions DNS Questions list + * @param qdcount number of Suestions in the list + * @param domain_name the domain name to search for + * @return the DNS Question related to the given domain name, or NULL if not found + */ +dns_question_t* dns_get_question(dns_question_t *questions, uint16_t qdcount, char *domain_name); + +/** + * @brief Retrieve the IP addresses corresponding to a given domain name in a DNS Answers list. + * + * Searches a DNS Answer list for a specific domain name and returns the corresponding IP address. + * Processes each Answer recursively if the Answer Type is a CNAME. + * + * @param answers DNS Answers list to search in + * @param ancount number of Answers in the list + * @param domain_name domain name to search for + * @return struct ip_list representing the list of corresponding IP addresses + */ +ip_list_t dns_get_ip_from_name(dns_resource_record_t *answers, uint16_t ancount, char *domain_name); + + +///// DESTROY ///// + +/** + * Free the memory allocated for a DNS message. + * + * @param question the DNS message to free + */ +void dns_free_message(dns_message_t message); + + +///// PRINTING ///// + +/** + * Print a DNS header. + * + * @param message the DNS header + */ +void dns_print_header(dns_header_t header); + +/** + * Print a DNS Question + * + * @param question the DNS Question + */ +void dns_print_question(dns_question_t question); + +/** + * Print a DNS Question section. + * + * @param qdcount the number of Questions in the Question section + * @param questions the list of DNS Questions + */ +void dns_print_questions(uint16_t qdcount, dns_question_t *questions); + +/** + * Return a string representation of the given RDATA value. + * + * @param rtype the type corresponding to the RDATA value + * @param rdlength the length, in bytes, of the RDATA value + * @param rdata the RDATA value, stored as a union type + * @return a string representation of the RDATA value + */ +char* dns_rdata_to_str(dns_rr_type_t rtype, uint16_t rdlength, rdata_t rdata); + +/** + * Print a DNS Resource Record. + * + * @param section_name the name of the Resource Record section + * @param rr the DNS Resource Record + */ +void dns_print_rr(char* section_name, dns_resource_record_t rr); + +/** + * Print a DNS Resource Records section. + * + * @param section_name the name of the Resource Record section + * @param count the number of Resource Records in the section + * @param rrs the list of DNS Resource Records + */ +void dns_print_rrs(char* section_name, uint16_t count, dns_resource_record_t *rrs); + +/** + * Print a DNS message. + * + * @param message the DNS message + */ +void dns_print_message(dns_message_t message); + + +#endif /* _PROTOCOL_PARSERS_DNS_ */ diff --git a/include/dns_map.h b/include/dns_map.h new file mode 100644 index 0000000000000000000000000000000000000000..b302a4879c74c52b43d21506fd739378e26dd8d3 --- /dev/null +++ b/include/dns_map.h @@ -0,0 +1,132 @@ +/** + * @file include/dns_map.h + * @brief Implementation of a DNS domain name to IP addresses mapping, using Joshua J Baker's hashmap.c (https://github.com/tidwall/hashmap.c) + * @date 2022-09-06 + * + * @copyright Copyright (c) 2022 + * + */ + +#ifndef _PROTOCOL_PARSERS_DNS_MAP_ +#define _PROTOCOL_PARSERS_DNS_MAP_ + +#include <stdlib.h> +#include <stdint.h> +#include <stdbool.h> +#include <string.h> +#include "hashmap.h" +#include "packet_utils.h" + +// Initial size of the DNS table +// If set to 0, the default size will be 16 +#define DNS_MAP_INIT_SIZE 0 + + +////////// TYPE DEFINITIONS ////////// + +/** + * List of IP addresses + */ +typedef struct ip_list { + uint8_t ip_count; // Number of IP addresses + ip_addr_t *ip_addresses; // List of IP addresses +} ip_list_t; + +/** + * DNS table entry: + * mapping between domain name and a list of IP addresses. + */ +typedef struct dns_entry { + char *domain_name; // Domain name + ip_list_t ip_list; // List of IP addresses +} dns_entry_t; + +/** + * Alias for the hashmap structure. + */ +typedef struct hashmap dns_map_t; + + +////////// FUNCTIONS ////////// + +/** + * @brief Initialize an ip_list_t structure. + * + * Creates an empty list of IP addresses. + * The `ip_count` field is set to 0, + * and the `ip_addresses` field is set to NULL. + * + * @return ip_list_t newly initialized structure + */ +ip_list_t ip_list_init(); + +/** + * @brief Checks if a dns_entry_t structure contains a given IP address. + * + * @param dns_entry pointer to the DNS entry to process + * @param ip_address IP address to check the presence of + * @return true if the IP address is present in the DNS entry, false otherwise + */ +bool dns_entry_contains(dns_entry_t *dns_entry, ip_addr_t ip_address); + +/** + * Create a new DNS table. + * + * @return the newly created DNS table + */ +dns_map_t* dns_map_create(); + +/** + * Destroy (free) a DNS table. + * + * @param table the DNS table to free + */ +void dns_map_free(dns_map_t *table); + +/** + * Add IP addresses corresponding to a given domain name in the DNS table. + * If the domain name was already present, its IP addresses will be replaced by the new ones. + * + * @param table the DNS table to add the entry to + * @param domain_name the domain name of the entry + * @param ip_list an ip_list_t structure containing the list of IP addresses + */ +void dns_map_add(dns_map_t *table, char *domain_name, ip_list_t ip_list); + +/** + * Remove a domain name (and its corresponding IP addresses) from the DNS table. + * + * @param table the DNS table to remove the entry from + * @param domain_name the domain name of the entry to remove + */ +void dns_map_remove(dns_map_t *table, char *domain_name); + +/** + * Retrieve the IP addresses corresponding to a given domain name in the DNS table. + * + * @param table the DNS table to retrieve the entry from + * @param domain_name the domain name of the entry to retrieve + * @return a pointer to a dns_entry structure containing the IP addresses corresponding to the domain name, + * or NULL if the domain name was not found in the DNS table + */ +dns_entry_t* dns_map_get(dns_map_t *table, char *domain_name); + +/** + * Retrieve the IP addresses corresponding to a given domain name, + * and remove the domain name from the DNS table. + * + * @param table the DNS table to retrieve the entry from + * @param domain_name the domain name of the entry to retrieve + * @return a pointer to a dns_entry structure containing the IP addresses corresponding to the domain name, + * or NULL if the domain name was not found in the DNS table + */ +dns_entry_t* dns_map_pop(dns_map_t *table, char *domain_name); + +/** + * @brief Print a DNS table entry. + * + * @param dns_entry the DNS table entry to print + */ +void dns_entry_print(dns_entry_t *dns_entry); + +#endif /* _PROTOCOL_PARSERS_DNS_MAP_ */ diff --git a/include/hashmap.h b/include/hashmap.h new file mode 100644 index 0000000000000000000000000000000000000000..197a6627eeef06d40209ae42f13543213bb7474b --- /dev/null +++ b/include/hashmap.h @@ -0,0 +1,55 @@ +// Copyright 2020 Joshua J Baker. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +#ifndef HASHMAP_H +#define HASHMAP_H + +#include <stdbool.h> +#include <stddef.h> +#include <stdint.h> + +typedef struct hashmap hashmap; + +struct hashmap *hashmap_new(size_t elsize, size_t cap, + uint64_t seed0, uint64_t seed1, + uint64_t (*hash)(const void *item, + uint64_t seed0, uint64_t seed1), + int (*compare)(const void *a, const void *b, + void *udata), + void (*elfree)(void *item), + void *udata); +struct hashmap *hashmap_new_with_allocator( + void *(*malloc)(size_t), + void *(*realloc)(void *, size_t), + void (*free)(void*), + size_t elsize, size_t cap, + uint64_t seed0, uint64_t seed1, + uint64_t (*hash)(const void *item, + uint64_t seed0, uint64_t seed1), + int (*compare)(const void *a, const void *b, + void *udata), + void (*elfree)(void *item), + void *udata); +void hashmap_free(struct hashmap *map); +void hashmap_clear(struct hashmap *map, bool update_cap); +size_t hashmap_count(struct hashmap *map); +bool hashmap_oom(struct hashmap *map); +void *hashmap_get(struct hashmap *map, const void *item); +void *hashmap_set(struct hashmap *map, const void *item); +void *hashmap_delete(struct hashmap *map, void *item); +void *hashmap_probe(struct hashmap *map, uint64_t position); +bool hashmap_scan(struct hashmap *map, + bool (*iter)(const void *item, void *udata), void *udata); +bool hashmap_iter(struct hashmap *map, size_t *i, void **item); + +uint64_t hashmap_sip(const void *data, size_t len, + uint64_t seed0, uint64_t seed1); +uint64_t hashmap_murmur(const void *data, size_t len, + uint64_t seed0, uint64_t seed1); + + +// DEPRECATED: use `hashmap_new_with_allocator` +void hashmap_set_allocator(void *(*malloc)(size_t), void (*free)(void*)); + +#endif /* HASHMAP_H */ diff --git a/include/header.h b/include/header.h new file mode 100644 index 0000000000000000000000000000000000000000..738a3868a6829e62814268099e40b02d72761164 --- /dev/null +++ b/include/header.h @@ -0,0 +1,140 @@ +/** + * @file include/header.h + * @brief Parser for layer 3 and 4 headers (currently only IP, UDP and TCP) + * + * Parser for layer 3 and 4 headers. + * Currently supported protocols: + * - Layer 3: + * - IP + * - Layer 4: + * - UDP + * - TCP + * + * @date 2022-09-09 + * + * @copyright Copyright (c) 2022 + * + */ + +#ifndef _PROTOCOL_PARSERS_HEADER_ +#define _PROTOCOL_PARSERS_HEADER_ + +#include <stdlib.h> +#include <stdio.h> +#include <stdint.h> +#include <string.h> +#include <arpa/inet.h> +#include "packet_utils.h" + +#define IPV6_HEADER_LENGTH 40 +#define UDP_HEADER_LENGTH 8 + + +/** + * IP protocols assigned to their protocol number + */ +typedef enum { + ICMP = 1, + IGMP = 2, + TCP = 6, + UDP = 17 +} ip_protocol_t; + +/** + * Retrieve the length of a packet's IPv4 header. + * + * @param data a pointer to the start of the packet's IPv4 header + * @return the size, in bytes, of the IPv4 header + */ +size_t get_ipv4_header_length(uint8_t *data); + +/** + * Retrieve the length of a packet's IPv6 header. + * + * @param data a pointer to the start of the packet's IPv6 header + * @return the size, in bytes, of the IPv6 header + */ +size_t get_ipv6_header_length(uint8_t *data); + +/** + * Retrieve the length of a packet's UDP header. + * + * @param data a pointer to the start of the packet's UDP (layer 4) header + * @return the size, in bytes, of the UDP header + */ +size_t get_udp_header_length(uint8_t *data); + +/** + * Retrieve the length of a packet's TCP header. + * + * @param data a pointer to the start of the packet's TCP (layer 4) header + * @return the size, in bytes, of the UDP header + */ +size_t get_tcp_header_length(uint8_t *data); + +/** + * Retrieve the length of a packet's layer 3 header (IPv4 or IPv6). + * + * @param data a pointer to the start of the packet's layer 3 header + * @return the size, in bytes, of the layer 3 header + */ +size_t get_l3_header_length(uint8_t *data); + +/** + * Retrieve the length of a packet's layer-3 and layer-4 headers. + * + * @param data a pointer to the start of the packet's layer-3 header + * @return the size, in bytes, of the UDP header + */ +size_t get_headers_length(uint8_t* data); + +/** + * @brief Retrieve the length of a UDP payload. + * + * @param data pointer to the start of the UDP header + * @return length of the UDP payload, in bytes + */ +uint16_t get_udp_payload_length(uint8_t *data); + +/** + * @brief Retrieve the source port from a layer 4 header. + * + * @param data pointer to the start of the layer 4 header + * @return destination port + */ +uint16_t get_dst_port(uint8_t* data); + +/** + * @brief Retrieve the source address from an IPv4 header. + * + * @param data pointer to the start of the IPv4 header + * @return source IPv4 address, in network byte order + */ +uint32_t get_ipv4_src_addr(uint8_t *data); + +/** + * @brief Retrieve the destination address from an IPv4 header. + * + * @param data pointer to the start of the IPv4 header + * @return destination IPv4 address, in network byte order + */ +uint32_t get_ipv4_dst_addr(uint8_t *data); + +/** + * @brief Retrieve the source address from an IPv6 header. + * + * @param data pointer to the start of the IPv6 header + * @return source IPv6 address, as a 16-byte array + */ +uint8_t* get_ipv6_src_addr(uint8_t *data); + +/** + * @brief Retrieve the destination address from an IPv6 header. + * + * @param data pointer to the start of the IPv6 header + * @return destination IPv6 address, as a 16-byte array + */ +uint8_t* get_ipv6_dst_addr(uint8_t *data); + + +#endif /* _PROTOCOL_PARSERS_HEADER_ */ diff --git a/include/http.h b/include/http.h new file mode 100644 index 0000000000000000000000000000000000000000..909efe5d12b8f327b2a538fd40fe50df5aede4e0 --- /dev/null +++ b/include/http.h @@ -0,0 +1,101 @@ +/** + * @file include/http.h + * @brief HTTP message parser + * @date 2022-09-09 + * + * @copyright Copyright (c) 2022 + * + */ + +#ifndef _PROTOCOL_PARSERS_HTTP_ +#define _PROTOCOL_PARSERS_HTTP_ + +#include <stdlib.h> +#include <stdio.h> +#include <stdint.h> +#include <stdbool.h> + +#define HTTP_MESSAGE_MIN_LEN 16 // Minimum length of a HTTP message +#define HTTP_METHOD_MAX_LEN 7 // Maximum length of a HTTP method +#define HTTP_URI_DEFAULT_LEN 100 // Default length of a HTTP URI + + +/** + * HTTP methods + */ +typedef enum +{ + HTTP_GET, + HTTP_HEAD, + HTTP_POST, + HTTP_PUT, + HTTP_DELETE, + HTTP_CONNECT, + HTTP_OPTIONS, + HTTP_TRACE, + HTTP_UNKNOWN +} http_method_t; + +/** + * Abstraction of a HTTP message + */ +typedef struct http_message { + bool is_request; // True if the message is a request, false if it is a response + http_method_t method; // HTTP method (GET, POST, etc.) + char *uri; // Message URI +} http_message_t; + + +////////// FUNCTIONS ////////// + +///// PARSING ///// + +/** + * @brief Check if a TCP message is a HTTP message. + * + * @param data pointer to the start of the TCP payload + * @param dst_port TCP destination port + * @return true if the message is a HTTP message + * @return false if the message is not a HTTP message + */ +bool is_http(uint8_t *data); + +/** + * @brief Parse the method and URI of HTTP message. + * + * @param data pointer to the start of the HTTP message + * @param src_port TCP destination port + * @return the parsed HTTP message + */ +http_message_t http_parse_message(uint8_t *data, uint16_t dst_port); + + +///// DESTROY ///// + +/** + * @brief Free the memory allocated for a HTTP message. + * + * @param message the HTTP message to free + */ +void http_free_message(http_message_t message); + + +///// PRINTING ///// + +/** + * @brief Converts a HTTP method from enum value to character string. + * + * @param method the HTTP method in enum value + * @return the same HTTP method as a character string + */ +char* http_method_to_str(http_method_t method); + +/** + * @brief Print an HTTP message. + * + * @param message the HTTP message to print + */ +void http_print_message(http_message_t message); + + +#endif /* _PROTOCOL_PARSERS_HTTP_ */ diff --git a/include/igmp.h b/include/igmp.h new file mode 100644 index 0000000000000000000000000000000000000000..2b04cf2ccf5fab9398b936672f779e29fd8a2334 --- /dev/null +++ b/include/igmp.h @@ -0,0 +1,121 @@ +/** + * @file include/igmp.h + * @brief IGMP message parser + * @date 2022-10-05 + * + * IGMP message parser. + * Supports v1 and v2, and v3 Membership Report messages. + * TODO: support v3 Membership Query messages. + * + * @copyright Copyright (c) 2022 + * + */ + +#ifndef _PROTOCOL_PARSERS_IGMP_ +#define _PROTOCOL_PARSERS_IGMP_ + +#include <stdio.h> +#include <stdint.h> +#include "packet_utils.h" + + +/** + * @brief IGMP message types + */ +typedef enum { + MEMBERSHIP_QUERY = 0x11, + V1_MEMBERSHIP_REPORT = 0x12, + V2_MEMBERSHIP_REPORT = 0x16, + LEAVE_GROUP = 0x17, + V3_MEMBERSHIP_REPORT = 0x22 +} igmp_message_type_t; + +/** + * @brief IGMPv2 message + */ +typedef struct { + uint8_t max_resp_time; + uint16_t checksum; + uint32_t group_address; // IPv4 group address, in network byte order +} igmp_v2_message_t; + +/** + * @brief IGMPv3 membership query + */ +typedef struct { + uint8_t max_resp_code; + uint16_t checksum; + uint32_t group_address; // IPv4 group address, in network byte order + uint8_t flags; // Resv, S, QRV + uint8_t qqic; + uint16_t num_sources; + uint32_t *sources; // Array of IPv4 addresses, in network byte order +} igmp_v3_membership_query_t; + +/** + * @brief IGMPv3 Group Record + */ +typedef struct { + uint8_t type; + uint8_t aux_data_len; + uint16_t num_sources; + uint32_t group_address; // IPv4 group address, in network byte order + uint32_t *sources; // Array of IPv4 addresses, in network byte order +} igmp_v3_group_record_t; + +/** + * @brief IGMPv3 membership report + */ +typedef struct { + uint16_t checksum; + uint16_t num_groups; + igmp_v3_group_record_t *groups; // Array of group records +} igmp_v3_membership_report_t; + +/** + * @brief IGMP message body. + */ +typedef union +{ + igmp_v2_message_t v2_message; + igmp_v3_membership_query_t v3_membership_query; + igmp_v3_membership_report_t v3_membership_report; +} igmp_message_body_t; + +/** + * @brief Generic IGMP message + */ +typedef struct +{ + uint8_t version; + igmp_message_type_t type; + igmp_message_body_t body; +} igmp_message_t; + + +////////// FUNCTIONS ////////// + +/** + * @brief Parse an IGMP message. + * + * @param data pointer to the start of the IGMP message + * @return the parsed IGMP message + */ +igmp_message_t igmp_parse_message(uint8_t *data); + +/** + * @brief Free the memory allocated for an IGMP message. + * + * @param message the IGMP message to free + */ +void igmp_free_message(igmp_message_t message); + +/** + * @brief Print an IGMP message. + * + * @param message the IGMP message to print + */ +void igmp_print_message(igmp_message_t message); + + +#endif /* _PROTOCOL_PARSERS_IGMP_ */ diff --git a/include/packet_utils.h b/include/packet_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..9f6819f676ee233ebae756234ff71b3b3cabc716 --- /dev/null +++ b/include/packet_utils.h @@ -0,0 +1,184 @@ +/** + * @file include/packet_utils.h + * @brief Utilitaries for payload manipulation and display + * @date 2022-09-09 + * + * @copyright Copyright (c) 2022 + * + */ + +#ifndef _PROTOCOL_PARSERS_PACKET_UTILS_ +#define _PROTOCOL_PARSERS_PACKET_UTILS_ + +#include <stdlib.h> +#include <stdio.h> +#include <stdint.h> +#include <stdbool.h> +#include <string.h> +#include <arpa/inet.h> +#include "sha256.h" + +#define MAC_ADDR_LENGTH 6 +#define MAC_ADDR_STRLEN 18 +#define IPV4_ADDR_LENGTH 4 +#define IPV6_ADDR_LENGTH 16 + +/** + * @brief IP (v4 or v6) address value + */ +typedef union { + uint32_t ipv4; // IPv4 address, as a 32-bit unsigned integer in network byte order + uint8_t ipv6[IPV6_ADDR_LENGTH]; // IPv6 address, as a 16-byte array +} ip_val_t; + +/** + * @brief IP (v4 or v6) address + */ +typedef struct { + uint8_t version; // IP version (4 or 6, 0 if not set) + ip_val_t value; // IP address value (0 if not set) +} ip_addr_t; + +/** + * Print a packet payload. + * + * @param length length of the payload in bytes + * @param data pointer to the start of the payload + */ +void print_payload(int length, uint8_t *data); + +/** + * Converts a hexstring payload to a data buffer. + * + * @param hexstring the hexstring to convert + * @param payload a double pointer to the payload, which will be set to the start of the payload + * @return the length of the payload in bytes + */ +size_t hexstr_to_payload(char *hexstring, uint8_t **payload); + +/** + * Converts a MAC address from its hexadecimal representation + * to its string representation. + * + * @param mac_hex MAC address in hexadecimal representation + * @return the same MAC address in string representation + */ +char *mac_hex_to_str(uint8_t mac_hex[]); + +/** + * Converts a MAC address from its string representation + * to its hexadecimal representation. + * + * @param mac_str MAC address in string representation + * @return the same MAC address in hexadecimal representation + */ +uint8_t *mac_str_to_hex(char *mac_str); + +/** + * Converts an IPv4 address from its network order numerical representation + * to its string representation. + * (Wrapper arount inet_ntoa) + * + * @param ipv4_net IPv4 address in hexadecimal representation + * @return the same IPv4 address in string representation + */ +char* ipv4_net_to_str(uint32_t ipv4_net); + +/** + * Converts an IPv4 address from its string representation + * to its network order numerical representation. + * (Wrapper arount inet_aton) + * + * @param ipv4_str IPv4 address in string representation + * @return the same IPv4 address in network order numerical representation + */ +uint32_t ipv4_str_to_net(char *ipv4_str); + +/** + * Converts an IPv4 addres from its hexadecimal representation + * to its string representation. + * + * @param ipv4_hex IPv4 address in hexadecimal representation + * @return the same IPv4 address in string representation + */ +char* ipv4_hex_to_str(char *ipv4_hex); + +/** + * Converts an IPv4 address from its string representation + * to its hexadecimal representation. + * + * @param ipv4_str IPv4 address in string representation + * @return the same IPv4 address in hexadecimal representation + */ +char* ipv4_str_to_hex(char *ipv4_str); + +/** + * @brief Converts an IPv6 address to its string representation. + * + * @param ipv6 the IPv6 address + * @return the same IPv6 address in string representation + */ +char* ipv6_net_to_str(uint8_t ipv6[]); + +/** + * Converts an IPv6 address from its string representation + * to its network representation (a 16-byte array). + * + * @param ipv6_str IPv6 address in string representation + * @return the same IPv6 address as a 16-byte array + */ +uint8_t* ipv6_str_to_net(char *ipv6_str); + +/** + * @brief Converts an IP (v4 or v6) address to its string representation. + * + * @param ip_addr the IP address, as an ip_addr_t struct + * @return the same IP address in string representation + */ +char* ip_net_to_str(ip_addr_t ip_addr); + +/** + * Converts an IP (v4 or v6) address from its string representation + * to an ip_addr_t struct. + * + * @param ip_str IP (v4 or v6) address in string representation + * @return the same IP address as a ip_addr_t struct + */ +ip_addr_t ip_str_to_net(char *ip_str, uint8_t version); + +/** + * @brief Compare two IPv6 addresses. + * + * @param ipv6_1 first IPv6 address + * @param ipv6_2 second IPv6 address + * @return true if the two addresses are equal, false otherwise + */ +bool compare_ipv6(uint8_t *ipv6_1, uint8_t *ipv6_2); + +/** + * @brief Compare two IP (v4 or v6) addresses. + * + * @param ip_1 first IP address + * @param ip_2 second IP address + * @return true if the two addresses are equal, false otherwise + */ +bool compare_ip(ip_addr_t ip_1, ip_addr_t ip_2); + +/** + * @brief Compute SHA256 hash of a given payload. + * + * @param payload Payload to hash + * @param payload_len Payload length, including padding (in bytes) + * @return uint8_t* SHA256 hash of the payload + */ +uint8_t* compute_hash(uint8_t *payload, int payload_len); + +/** + * @brief Print a SHA256 hash. + * + * @param hash SHA256 hash to print + */ +void print_hash(uint8_t *hash); + + +#endif /* _PROTOCOL_PARSERS_PACKET_UTILS_ */ diff --git a/include/sha256.h b/include/sha256.h new file mode 100644 index 0000000000000000000000000000000000000000..7123a30dd49628d6ca15345c33968c50ca328cb7 --- /dev/null +++ b/include/sha256.h @@ -0,0 +1,34 @@ +/********************************************************************* +* Filename: sha256.h +* Author: Brad Conte (brad AT bradconte.com) +* Copyright: +* Disclaimer: This code is presented "as is" without any guarantees. +* Details: Defines the API for the corresponding SHA1 implementation. +*********************************************************************/ + +#ifndef SHA256_H +#define SHA256_H + +/*************************** HEADER FILES ***************************/ +#include <stddef.h> + +/****************************** MACROS ******************************/ +#define SHA256_BLOCK_SIZE 32 // SHA256 outputs a 32 byte digest + +/**************************** DATA TYPES ****************************/ +typedef unsigned char BYTE; // 8-bit byte +typedef unsigned int WORD; // 32-bit word, change to "long" for 16-bit machines + +typedef struct { + BYTE data[64]; + WORD datalen; + unsigned long long bitlen; + WORD state[8]; +} SHA256_CTX; + +/*********************** FUNCTION DECLARATIONS **********************/ +void sha256_init(SHA256_CTX *ctx); +void sha256_update(SHA256_CTX *ctx, const BYTE data[], size_t len); +void sha256_final(SHA256_CTX *ctx, BYTE hash[]); + +#endif // SHA256_H diff --git a/include/ssdp.h b/include/ssdp.h new file mode 100644 index 0000000000000000000000000000000000000000..4e7beafe853b1cf937d2da73700d465691c1c4eb --- /dev/null +++ b/include/ssdp.h @@ -0,0 +1,73 @@ +/** + * @file include/ssdp.h + * @brief SSDP message parser + * @date 2022-11-24 + * + * @copyright Copyright (c) 2022 + * + */ + +#ifndef _PROTOCOL_PARSERS_SSDP_ +#define _PROTOCOL_PARSERS_SSDP_ + +#include <stdlib.h> +#include <stdio.h> +#include <stdint.h> +#include <stdbool.h> +#include <arpa/inet.h> +#include "packet_utils.h" + +#define SSDP_METHOD_MAX_LEN 8 // Maximum length of a SSDP method +#define SSDP_MULTICAST_ADDR "239.255.255.250" // SSDP multicast group address + +/** + * SSDP methods + */ +typedef enum { + SSDP_M_SEARCH, + SSDP_NOTIFY, + SSDP_UNKNOWN +} ssdp_method_t; + +/** + * Abstraction of an SSDP message + */ +typedef struct ssdp_message { + bool is_request; // True if the message is a request, false if it is a response + ssdp_method_t method; // SSDP method (M-SEARCH or NOTIFY) +} ssdp_message_t; + + +////////// FUNCTIONS ////////// + +///// PARSING ///// + +/** + * @brief Parse the method and URI of SSDP message. + * + * @param data pointer to the start of the SSDP message + * @param dst_addr IPv4 destination address, in network byte order + * @return the parsed SSDP message + */ +ssdp_message_t ssdp_parse_message(uint8_t *data, uint32_t dst_addr); + + +///// PRINTING ///// + +/** + * @brief Converts a SSDP method from enum value to character string. + * + * @param method the SSDP method in enum value + * @return the same SSDP method as a character string + */ +char *ssdp_method_to_str(ssdp_method_t method); + +/** + * @brief Print the method and URI of a SSDP message. + * + * @param message the message to print + */ +void ssdp_print_message(ssdp_message_t message); + + +#endif /* _PROTOCOL_PARSERS_SSDP_ */ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..f066ccc22fc86928610a2182545687091a52af7c --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,53 @@ +# Minimum required CMake version +cmake_minimum_required(VERSION 3.20) + +# hashmap +add_library(hashmap STATIC ${INCLUDE_DIR}/hashmap.h hashmap.c) +target_include_directories(hashmap PRIVATE ${INCLUDE_DIR}) +install(TARGETS hashmap DESTINATION ${LIB_DIR}) + +# SHA256 +add_library(sha256 STATIC ${INCLUDE_DIR}/sha256.h sha256.c) +target_include_directories(sha256 PRIVATE ${INCLUDE_DIR}) +install(TARGETS sha256 DESTINATION ${LIB_DIR}) + +# packet_utils +add_library(packet_utils STATIC ${INCLUDE_DIR}/packet_utils.h packet_utils.c) +target_link_libraries(packet_utils sha256) +target_include_directories(packet_utils PRIVATE ${INCLUDE_DIR}) +install(TARGETS packet_utils DESTINATION ${LIB_DIR}) + +## Protocol parsers +# Header parser +add_library(header STATIC ${INCLUDE_DIR}/header.h header.c) +target_include_directories(header PRIVATE ${INCLUDE_DIR} ${INCLUDE_DIR}) +target_link_libraries(header packet_utils) +# DNS parser +add_library(dns STATIC ${INCLUDE_DIR}/dns.h dns.c) +target_include_directories(dns PRIVATE ${INCLUDE_DIR} ${INCLUDE_DIR}) +target_link_libraries(dns packet_utils dns_map) +# DHCP parser +add_library(dhcp STATIC ${INCLUDE_DIR}/dhcp.h dhcp.c) +target_include_directories(dhcp PRIVATE ${INCLUDE_DIR} ${INCLUDE_DIR}) +# HTTP parser +add_library(http STATIC ${INCLUDE_DIR}/http.h http.c) +target_include_directories(http PRIVATE ${INCLUDE_DIR} ${INCLUDE_DIR}) +# IGMP parser +add_library(igmp STATIC ${INCLUDE_DIR}/igmp.h igmp.c) +target_include_directories(igmp PRIVATE ${INCLUDE_DIR} ${INCLUDE_DIR}) +# SSDP parser +add_library(ssdp STATIC ${INCLUDE_DIR}/ssdp.h ssdp.c) +target_include_directories(ssdp PRIVATE ${INCLUDE_DIR} ${INCLUDE_DIR}) +# CoAP parser +add_library(coap STATIC ${INCLUDE_DIR}/coap.h coap.c) +target_include_directories(coap PRIVATE ${INCLUDE_DIR} ${INCLUDE_DIR}) +target_link_libraries(coap http) + +# DNS map +add_library(dns_map STATIC ${INCLUDE_DIR}/dns_map.h dns_map.c) +target_link_libraries(dns_map hashmap) +target_include_directories(dns_map PRIVATE ${INCLUDE_DIR}) +install(TARGETS dns_map DESTINATION ${LIB_DIR}) + +# Installation +install(TARGETS ${PARSERS} DESTINATION ${LIB_DIR}) diff --git a/src/coap.c b/src/coap.c new file mode 100644 index 0000000000000000000000000000000000000000..649f51d8538cff5222aabbb81b4b8395cd46b4d8 --- /dev/null +++ b/src/coap.c @@ -0,0 +1,201 @@ +/** + * @file src/coap.c + * @brief CoAP message parser + * @date 2022-11-30 + * + * @copyright Copyright (c) 2022 + * + */ + +#include "coap.h" + + +///// PARSING ///// + +/** + * @brief Parse the method of a CoAP message. + * + * @param code byte which encodes the CoAP method + * @return CoAP method + */ +static http_method_t coap_parse_method(uint8_t code) { + switch (code) { + case 1: + return HTTP_GET; + break; + case 2: + return HTTP_POST; + break; + case 3: + return HTTP_PUT; + break; + case 4: + return HTTP_DELETE; + break; + default: + // CoAP responses and all other codes are not supported + return HTTP_UNKNOWN; + } +} + +/** + * @brief Parse an URI option (Uri-Path or Uri-Query) of a CoAP message. + * + * @param message pointer to the CoAP message, which will be updated + * @param option CoAP option number (11 for Uri-Path, 15 for Uri-Query) + * @param length CoAP option length + * @param data pointer to the start of the URI option + */ +static void coap_parse_uri_option(coap_message_t *message, coap_option_t option_num, uint16_t length, uint8_t *data) { + char prefix = (option_num == COAP_URI_PATH) ? '/' : '?'; + if (message->uri == NULL) { + message->uri = malloc(length + 2); + } else { + message->uri = realloc(message->uri, message->uri_len + length + 2); + } + *(message->uri + message->uri_len) = prefix; + memcpy(message->uri + message->uri_len + 1, data, length); + message->uri_len += length + 1; + *(message->uri + message->uri_len) = '\0'; +} + +/** + * @brief Parse CoAP options. + * + * @param message pointer to the currently parsed CoAP message, which will be updated + * @param data pointer to the start of the options section of a CoAP message + * @param msg_length length of the rest of the CoAP message (after the header) + */ +static void coap_parse_options(coap_message_t *message, uint8_t *data, uint16_t msg_length) { + uint16_t option_num = 0; + uint16_t bytes_read = 0; + while (bytes_read < msg_length && *data != 0b11111111) + { + // Parse option delta + uint16_t delta = (*data) >> 4; + uint8_t delta_len = 0; // Length of the extended delta field + switch (delta) { + case 13: + delta = (*(data + 1)) + 13; + delta_len = 1; + break; + case 14: + delta = ntohs(*((uint16_t*) (data + 1))) + 269; + delta_len = 2; + break; + case 15: + continue; + break; + default: + break; + } + // Compute option number + option_num += delta; + + // Parse option length + uint16_t option_length = (*data) & 0b00001111; + uint8_t length_len = 0; // Length of the extended length field + switch (option_length) + { + case 13: + option_length = (*(data + 1 + delta_len)) + 13; + length_len = 1; + break; + case 14: + option_length = ntohs(*((uint16_t *)(data + 1 + delta_len))) + 269; + length_len = 2; + break; + case 15: + continue; + break; + default: + break; + } + + // Parse option value + data += 1 + delta_len + length_len; + if (option_num == COAP_URI_PATH || option_num == COAP_URI_QUERY) + { + // Option Uri-Path or Uri-Query + coap_parse_uri_option(message, option_num, option_length, data); + } + data += option_length; + bytes_read += 1 + delta_len + length_len + option_length; + // Other options are not supported (yet) + } +} + +/** + * @brief Parse a CoAP message. + * + * @param data pointer to the start of the CoAP message + * @param length length of the CoAP message, in bytes + * @return the parsed CoAP message + */ +coap_message_t coap_parse_message(uint8_t *data, uint16_t length) +{ + coap_message_t message; + message.type = (coap_type_t) (((*data) & 0b00110000) >> 4); // CoAP type is encoded in bits 2-3 + message.method = coap_parse_method(*(data + 1)); // CoAP method is encoded in byte 1 + uint8_t token_length = (*data) & 0b00001111; // CoAP token length is encoded in bits 4-7 + uint8_t header_length = 4 + token_length; // Length of the CoAP header + data += header_length; // Skip the header + message.uri = NULL; // Initialize the URI to NULL + message.uri_len = 0; + coap_parse_options(&message, data, length - header_length); // Parse CoAP options + return message; +} + + +///// DESTROY ///// + +/** + * @brief Free the memory allocated for a CoAP message. + * + * @param message the CoAP message to free + */ +void coap_free_message(coap_message_t message) { + if (message.uri != NULL) + free(message.uri); +} + + +///// PRINTING ///// + +/** + * @brief Converts a CoAP message type to its string representation. + * + * @param type CoAP message type + * @return string representation of the CoAP message type + */ +static char* coap_type_to_str(coap_type_t type) { + switch (type) { + case COAP_CON: + return "Confirmable"; + break; + case COAP_NON: + return "Non-Confirmable"; + break; + case COAP_ACK: + return "Acknowledgement"; + break; + case COAP_RST: + return "Reset"; + break; + default: + return "Unknown"; + } +} + +/** + * @brief Print a CoAP message. + * + * @param message the CoAP message to print + */ +void coap_print_message(coap_message_t message) +{ + printf("CoAP message:\n"); + printf(" Type: %s\n", coap_type_to_str(message.type)); + printf(" Method: %s\n", http_method_to_str(message.method)); + printf(" URI: %s\n", message.uri); +} diff --git a/src/dhcp.c b/src/dhcp.c new file mode 100644 index 0000000000000000000000000000000000000000..dc245dbc0429250e9ba6014c8c5cb31b443ef0aa --- /dev/null +++ b/src/dhcp.c @@ -0,0 +1,245 @@ +/** + * @file src/dhcp.c + * @brief DHCP message parser + * @date 2022-09-12 + * + * @copyright Copyright (c) 2022 + * + */ + +#include "dhcp.h" + + +///// PARSING ///// + +/** + * @brief Parse the header of a DHCP message (not including options). + * + * @param data a pointer to the start of the DHCP message + * @return the parsed DHCP message with the header fields filled in + */ +dhcp_message_t dhcp_parse_header(uint8_t *data) { + dhcp_message_t message; + // Opcode: 1 byte + message.op = *data; + // htype: 1 byte + message.htype = *(data + 1); + // hlen: 1 byte + message.hlen = *(data + 2); + // hops: 1 byte + message.hops = *(data + 3); + // xid: 4 bytes + message.xid = ntohl(*((uint32_t *) (data + 4))); + // secs: 2 bytes + message.secs = ntohs(*((uint16_t *) (data + 8))); + // flags: 2 bytes + message.flags = ntohs(*((uint16_t *) (data + 10))); + // The IP addresses are left in network byte order + // ciaddr: 4 bytes + message.ciaddr = *((uint32_t *) (data + 12)); + // yiaddr: 4 bytes + message.yiaddr = *((uint32_t *) (data + 16)); + // siaddr: 4 bytes + message.siaddr = *((uint32_t *) (data + 20)); + // giaddr: 4 bytes + message.giaddr = *((uint32_t *) (data + 24)); + // chaddr: 16 bytes + memcpy(message.chaddr, data + 28, sizeof(uint8_t) * 16); + // sname: 64 bytes + memcpy(message.sname, data + 44, sizeof(uint8_t) * 64); + // file: 128 bytes + memcpy(message.file, data + 108, sizeof(uint8_t) * 128); + return message; +} + +/** + * @brief Parse a DHCP option. + * + * @param data a pointer to the start of the DHCP option + * @param offset a pointer to the current offset inside the DHCP message + * Its value will be updated to point to the next option + * @return the parsed DHCP option + */ +dhcp_option_t dhcp_parse_option(uint8_t *data, uint16_t *offset) { + dhcp_option_t option; + option.code = *(data + *offset); + if (option.code == DHCP_PAD || option.code == DHCP_END) + { + option.length = 0; + option.value = NULL; + *offset += 1; + } + else + { + option.length = *(data + *offset + 1); + option.value = (uint8_t *) malloc(sizeof(uint8_t) * option.length); + memcpy(option.value, data + *offset + 2, option.length * sizeof(uint8_t)); + *offset += 2 + option.length; + } + return option; +} + +/** + * @brief Parse DHCP options. + * + * @param data a pointer to the start of the DHCP options list + * @return a pointer to the start of the parsed DHCP options + */ +dhcp_options_t dhcp_parse_options(uint8_t *data) { + // Init + uint8_t max_option_count = DHCP_MAX_OPTION_COUNT; + dhcp_options_t options; + options.count = 0; + // Check magic cookie is equal to 0x63825363 + uint32_t magic_cookie = ntohl(*((uint32_t *) data)); + if (magic_cookie != DHCP_MAGIC_COOKIE) { + fprintf(stderr, "Error: DHCP magic cookie is %#x, which is not equal to %#x\n", magic_cookie, DHCP_MAGIC_COOKIE); + return options; + } + // Parse options + options.options = (dhcp_option_t *) malloc(sizeof(dhcp_option_t) * max_option_count); + uint16_t offset = 4; + uint8_t code; + do { + if (options.count == max_option_count) { + // Realloc memory if too many options + max_option_count *= 2; + options.options = (dhcp_option_t *) realloc(options.options, sizeof(dhcp_option_t) * max_option_count); + } + dhcp_option_t option = dhcp_parse_option(data, &offset); + code = option.code; + if (code == DHCP_MESSAGE_TYPE) { + // Store DHCP message type + options.message_type = *option.value; + } + *(options.options + (options.count++)) = option; + } while (code != DHCP_END); + // Shrink allocated memory to the actual number of options, if needed + if (options.count < max_option_count) { + options.options = (dhcp_option_t *) realloc(options.options, sizeof(dhcp_option_t) * options.count); + } + return options; +} + +/** + * @brief Parse a DHCP message. + * + * @param data a pointer to the start of the DHCP message + * @return the parsed DHCP message + */ +dhcp_message_t dhcp_parse_message(uint8_t *data) { + // Parse constant fields + dhcp_message_t message = dhcp_parse_header(data); + // Parse DHCP options + message.options = dhcp_parse_options(data + DHCP_HEADER_LEN); + // Return + return message; +} + + +///// DESTROY ////// + +/** + * @brief Free the memory allocated for a DHCP message. + * + * @param message the DHCP message to free + */ +void dhcp_free_message(dhcp_message_t message) { + if (message.options.count > 0) { + for (uint8_t i = 0; i < message.options.count; i++) { + dhcp_option_t option = *(message.options.options + i); + if (option.length > 0) { + free(option.value); + } + } + free(message.options.options); + } +} + + +///// PRINTING ///// + +/** + * @brief Print a hardware address. + * + * @param htype hardware type + * @param chaddr the hardware address to print + */ +static void dhcp_print_chaddr(uint8_t htype, uint8_t chaddr[]) { + printf(" Client hardware address: "); + uint8_t length = (htype == 1) ? 6 : 16; + printf("%02hhx", chaddr[0]); + for (uint8_t i = 1; i < length; i++) { + printf(":%02hhx", chaddr[i]); + } + printf("\n"); +} + +/** + * @brief Print the header of a DHCP message. + * + * @param message the DHCP message to print the header of + */ +void dhcp_print_header(dhcp_message_t message) { + // Opcode + printf(" Opcode: %hhu\n", message.op); + // htype + printf(" Hardware type: %hhu\n", message.htype); + // hlen + printf(" Hardware address length: %hhu\n", message.hlen); + // hops + printf(" Hops: %hhu\n", message.hops); + // xid + printf(" Transaction ID: %#x\n", message.xid); + // secs + printf(" Seconds elapsed: %hu\n", message.secs); + // flags + printf(" Flags: 0x%04x\n", message.flags); + // ciaddr + printf(" Client IP address: %s\n", inet_ntoa((struct in_addr) {message.ciaddr})); + // yiaddr + printf(" Your IP address: %s\n", inet_ntoa((struct in_addr) {message.yiaddr})); + // siaddr + printf(" Server IP address: %s\n", inet_ntoa((struct in_addr) {message.siaddr})); + // giaddr + printf(" Gateway IP address: %s\n", inet_ntoa((struct in_addr) {message.giaddr})); + // chaddr + dhcp_print_chaddr(message.htype, message.chaddr); + // sname + if (strlen((char *) message.sname) > 0) { + printf(" Server name: %s\n", message.sname); + } + // file + if (strlen((char *) message.file) > 0) { + printf(" Boot file name: %s\n", message.file); + } +} + +/** + * @brief Print a DHCP option. + * + * @param option the DHCP option to print + */ +void dhcp_print_option(dhcp_option_t option) { + printf(" Code: %hhu; Length: %hhu; Value: ", option.code, option.length); + for (uint8_t i = 0; i < option.length; i++) { + printf("%02hhx ", *(option.value + i)); + } + printf("\n"); +} + +/** + * @brief Print a DHCP message. + * + * @param message the DHCP message to print + */ +void dhcp_print_message(dhcp_message_t message) { + printf("DHCP message\n"); + // Print header fields + dhcp_print_header(message); + // Print DHCP options + printf(" DHCP options:\n"); + for (uint8_t i = 0; i < message.options.count; i++) { + dhcp_print_option(*(message.options.options + i)); + } +} diff --git a/src/dns.c b/src/dns.c new file mode 100644 index 0000000000000000000000000000000000000000..5c3126517eaa2b5b23cf42f513f2fa13a07390cd --- /dev/null +++ b/src/dns.c @@ -0,0 +1,584 @@ +/** + * @file src/dns.c + * @brief DNS message parser + * @date 2022-09-09 + * + * @copyright Copyright (c) 2022 + * + */ + +#include "dns.h" + + +///// PARSING ///// + +/** + * Parse a DNS header. + * A DNS header is always 12 bytes. + * + * @param data a pointer pointing to the start of the DNS message + * @param offset a pointer to the current parsing offset + * @return the parsed header + */ +dns_header_t dns_parse_header(uint8_t *data, uint16_t *offset) { + // Init + dns_header_t header; + // Parse fields + header.id = ntohs(*((uint16_t *) (data + *offset))); + header.flags = ntohs(*((uint16_t *) (data + *offset + 2))); + header.qr = (header.flags & DNS_QR_FLAG_MASK); + header.qdcount = ntohs(*((uint16_t *) (data + *offset + 4))); + header.ancount = ntohs(*((uint16_t *) (data + *offset + 6))); + header.nscount = ntohs(*((uint16_t *) (data + *offset + 8))); + header.arcount = ntohs(*((uint16_t *) (data + *offset + 10))); + // Update offset to point after header + *offset += DNS_HEADER_SIZE; + + return header; +} + +/** + * Parse a DNS Domain Name. + * + * @param data a pointer pointing to the start of the DNS message + * @param offset a pointer to the current parsing offset + * @return the parsed domain name + */ +static char* dns_parse_domain_name(uint8_t *data, uint16_t *offset) { + if (*(data + *offset) == '\0') { + // Domain name is ROOT + (*offset)++; + return ""; + } + uint16_t current_length = 0; + uint16_t max_length = DNS_MAX_DOMAIN_NAME_LENGTH; + char* domain_name = (char *) malloc(sizeof(char) * max_length); + bool compression = false; + uint16_t domain_name_offset = *offset; // Other offset, might be useful for domain name compression + while (*(data + domain_name_offset) != '\0') { + uint8_t length_byte = *((uint8_t *) (data + domain_name_offset)); + if (length_byte >> 6 == 3) { // Length byte starts with 0b11 + // Domain name compression + // Advance offset by 2 bytes, and do not update it again + if(!compression) { + *offset += 2; + } + compression = true; + // Retrieve new offset to parse domain name from + domain_name_offset = ntohs(*((uint16_t *) (data + domain_name_offset))) & DNS_COMPRESSION_MASK; + } else { + // Fully written label, parse it + for (int i = 1; i <= length_byte; i++) { + if (current_length == max_length) { + // Realloc buffer + max_length *= 2; + void *realloc_ptr = realloc(domain_name, sizeof(char) * max_length); + if (realloc_ptr == NULL) { + // Handle realloc error + fprintf(stderr, "Error reallocating memory for domain name %s\n", domain_name); + free(domain_name); + return NULL; + } else { + domain_name = (char*) realloc_ptr; + } + } + char c = *(data + domain_name_offset + i); + *(domain_name + (current_length++)) = c; + } + *(domain_name + (current_length++)) = '.'; + domain_name_offset += length_byte + 1; + if (!compression) { + *offset = domain_name_offset; + } + } + } + // Domain name was fully parsed + // Overwrite last '.' written with NULL byte + *(domain_name + (--current_length)) = '\0'; + // Shrink allocated memory to fit domain name, if needed + if (current_length + 1 < max_length) { + void* realloc_ptr = realloc(domain_name, sizeof(char) * (current_length + 1)); + if (realloc_ptr == NULL) { + fprintf(stderr, "Error shrinking memory for domain name %s\n", domain_name); + } else { + domain_name = (char*) realloc_ptr; + } + } + // Advance offset after NULL terminator, if domain name compression was not used + if (!compression) { + (*offset)++; + } + return domain_name; +} + +/** + * Parse a DNS Question section. + * + * @param qdcount the number of questions present in the question section + * @param data a pointer pointing to the start of the DNS message + * @param offset a pointer to the current parsing offset + * @return the parsed question section + */ +dns_question_t* dns_parse_questions(uint16_t qdcount, uint8_t *data, uint16_t *offset) { + // Init + dns_question_t *questions = (dns_question_t *) malloc(qdcount * sizeof(dns_question_t)); + // Iterate over all questions + for (uint16_t i = 0; i < qdcount; i++) { + // Parse domain name + (questions + i)->qname = dns_parse_domain_name(data, offset); + // Parse rtype and rclass + (questions + i)->qtype = ntohs(*((uint16_t *) (data + *offset))); + (questions + i)->qclass = ntohs(*((uint16_t *) (data + *offset + 2))) & DNS_CLASS_MASK; + *offset += 4; + } + return questions; +} + +/** + * Parse a DNS Resource Record RDATA field. + * + * @param rdlength the length, in bytes, of the RDATA field + * @param data a pointer pointing to the start of the DNS message + * @param offset a pointer to the current parsing offset + * @return the parsed RDATA field + */ +static rdata_t dns_parse_rdata(dns_rr_type_t rtype, uint16_t rdlength, uint8_t *data, uint16_t *offset) { + rdata_t rdata; + if (rdlength == 0) { + // RDATA field is empty + rdata.data = NULL; + } else { + // RDATA field is not empty + switch (rtype) { + case A: + // RDATA contains an IPv4 address + rdata.ip.version = 4; + rdata.ip.value.ipv4 = *((uint32_t *) (data + *offset)); // Stored in network byte order + *offset += rdlength; + break; + case AAAA: + // RDATA contains an IPv6 address + rdata.ip.version = 6; + memcpy(rdata.ip.value.ipv6, data + *offset, rdlength); + *offset += rdlength; + break; + case NS: + case CNAME: + case PTR: + // RDATA contains is a domain name + rdata.domain_name = dns_parse_domain_name(data, offset); + break; + default: + // RDATA contains is generic data + rdata.data = (uint8_t *) malloc(sizeof(char) * rdlength); + memcpy(rdata.data, data + *offset, rdlength); + *offset += rdlength; + } + } + return rdata; +} + +/** + * Parse a DNS Resource Record list. + * @param count the number of resource records present in the section + * @param data a pointer pointing to the start of the DNS message + * @param offset a pointer to the current parsing offset + * @return the parsed resource records list + */ +dns_resource_record_t* dns_parse_rrs(uint16_t count, uint8_t *data, uint16_t *offset) { + dns_resource_record_t *rrs = (dns_resource_record_t *) malloc(count * sizeof(dns_resource_record_t)); + for (uint16_t i = 0; i < count; i++) { + // Parse domain name + (rrs + i)->name = dns_parse_domain_name(data, offset); + // Parse rtype, rclass and TTL + dns_rr_type_t rtype = ntohs(*((uint16_t *) (data + *offset))); + (rrs + i)->rtype = rtype; + (rrs + i)->rclass = ntohs(*((uint16_t *) (data + *offset + 2))) & DNS_CLASS_MASK; + (rrs + i)->ttl = ntohl(*((uint32_t *) (data + *offset + 4))); + // Parse rdata + uint16_t rdlength = ntohs(*((uint16_t *) (data + *offset + 8))); + (rrs + i)->rdlength = rdlength; + *offset += 10; + (rrs + i)->rdata = dns_parse_rdata(rtype, rdlength, data, offset); + } + return rrs; +} + +/** + * Parse a DNS message. + * + * @param data a pointer to the start of the DNS message + * @return the parsed DNS message + */ +dns_message_t dns_parse_message(uint8_t *data) { + // Init + dns_message_t message; + uint16_t offset = 0; + message.questions = NULL; + message.answers = NULL; + message.authorities = NULL; + message.additionals = NULL; + + // Parse DNS header + message.header = dns_parse_header(data, &offset); + // If present, parse DNS Question section + if (message.header.qdcount > 0) + { + message.questions = dns_parse_questions(message.header.qdcount, data, &offset); + } + // If message is a response and section is present, parse DNS Answer section + if (message.header.qr == 1 && message.header.ancount > 0) + { + message.answers = dns_parse_rrs(message.header.ancount, data, &offset); + } + + /* Parsing other sections is not necessary for this project + + // If message is a response and section is present, parse DNS Authority section + if (message.header.qr == 1 && message.header.nscount > 0) + { + message.authorities = dns_parse_rrs(message.header.nscount, data, &offset); + } + // If message is a response and section is present, parse DNS Additional section + if (message.header.qr == 1 && message.header.arcount > 0) + { + message.additionals = dns_parse_rrs(message.header.arcount, data, &offset); + } + + */ + + return message; +} + + +///// LOOKUP ///// + +/** + * @brief Check if a given string ends with a given suffix. + * + * @param str the string to check + * @param suffix the suffix to search for + * @param suffix_length the length of the suffix + * @return true if the string ends with the suffix + * @return false if the string does not end with the suffix + */ +static bool ends_with(char* str, char* suffix, uint16_t suffix_length) { + uint16_t str_length = strlen(str); + if (str_length < suffix_length) { + return false; + } + return strncmp(str + str_length - suffix_length, suffix, suffix_length) == 0; +} + +/** + * @brief Check if a given DNS Questions list contains a domain name which has a given suffix. + * + * @param questions DNS Questions list + * @param qdcount number of Questions in the list + * @param suffix the domain name suffix to search for + * @param suffix_length the length of the domain name suffix + * @return true if a domain name with the given suffix is found is found in the Questions list, + * false otherwise + */ +bool dns_contains_suffix_domain_name(dns_question_t *questions, uint16_t qdcount, char *suffix, uint16_t suffix_length) { + for (uint16_t i = 0; i < qdcount; i++) { + if (ends_with((questions + i)-> qname, suffix, suffix_length)) { + return true; + } + } + return false; +} + +/** + * @brief Check if a given domain name is fully contained in a DNS Questions list. + * + * @param questions DNS Questions list + * @param qdcount number of Questions in the list + * @param domain_name the domain name to search for + * @return true if the full domain name is found in the Questions list, false otherwise + */ +bool dns_contains_full_domain_name(dns_question_t *questions, uint16_t qdcount, char *domain_name) +{ + for (uint16_t i = 0; i < qdcount; i++) { + if (strcmp((questions + i)->qname, domain_name) == 0) { + return true; + } + } + return false; +} + +/** + * @brief Search for a specific domain name in a DNS Questions list. + * + * @param questions DNS Questions list + * @param qdcount number of Suestions in the list + * @param domain_name the domain name to search for + * @return the DNS Question related to the given domain name, or NULL if not found + */ +dns_question_t* dns_get_question(dns_question_t *questions, uint16_t qdcount, char *domain_name) { + for (uint16_t i = 0; i < qdcount; i++) { + if (strcmp((questions + i)->qname, domain_name) == 0) { + return questions + i; + } + } + return NULL; +} + +/** + * @brief Retrieve the IP addresses corresponding to a given domain name in a DNS Answers list. + * + * Searches a DNS Answer list for a specific domain name and returns the corresponding IP address. + * Processes each Answer recursively if the Answer Type is a CNAME. + * + * @param answers DNS Answers list to search in + * @param ancount number of Answers in the list + * @param domain_name domain name to search for + * @return struct ip_list representing the list of corresponding IP addresses + */ +ip_list_t dns_get_ip_from_name(dns_resource_record_t *answers, uint16_t ancount, char *domain_name) { + ip_list_t ip_list; + ip_list.ip_count = 0; + ip_list.ip_addresses = NULL; + char *cname = domain_name; + for (uint16_t i = 0; i < ancount; i++) { + if (strcmp((answers + i)->name, cname) == 0) { + dns_rr_type_t rtype = (answers + i)->rtype; + if (rtype == A || rtype == AAAA) + { + // Handle IP list length + if (ip_list.ip_addresses == NULL) { + ip_list.ip_addresses = (ip_addr_t *) malloc(sizeof(ip_addr_t)); + } else { + void *realloc_ptr = realloc(ip_list.ip_addresses, (ip_list.ip_count + 1) * sizeof(ip_addr_t)); + if (realloc_ptr == NULL) { + // Handle realloc error + free(ip_list.ip_addresses); + fprintf(stderr, "Error reallocating memory for IP list.\n"); + ip_list.ip_count = 0; + ip_list.ip_addresses = NULL; + return ip_list; + } else { + ip_list.ip_addresses = (ip_addr_t*) realloc_ptr; + } + } + // Handle IP version and value + *(ip_list.ip_addresses + ip_list.ip_count) = (answers + i)->rdata.ip; + ip_list.ip_count++; + } + else if ((answers + i)->rtype == CNAME) + { + cname = (answers + i)->rdata.domain_name; + } + } + } + return ip_list; +} + + +///// DESTROY ///// + +/** + * @brief Free the memory allocated for a DNS RDATA field. + * + * @param rdata the DNS RDATA field to free + * @param rtype the DNS Resource Record Type of the RDATA field + */ +static void dns_free_rdata(rdata_t rdata, dns_rr_type_t rtype) { + switch (rtype) { + case A: + case AAAA: + break; // Nothing to free for IP addresses + case NS: + case CNAME: + case PTR: + free(rdata.domain_name); + break; + default: + free(rdata.data); + } +} + +/** + * @brief Free the memory allocated for a list of DNS Resource Records. + * + * @param rr the list of DNS Resource Records to free + * @param count the number of Resource Records in the list + */ +static void dns_free_rrs(dns_resource_record_t *rrs, uint16_t count) { + if (rrs != NULL && count > 0) { + for (uint16_t i = 0; i < count; i++) { + dns_resource_record_t rr = *(rrs + i); + if (rr.rdlength > 0) { + free(rr.name); + dns_free_rdata(rr.rdata, rr.rtype); + } + } + free(rrs); + } +} + +/** + * Free the memory allocated for a DNS message. + * + * @param question the DNS message to free + */ +void dns_free_message(dns_message_t message) { + // Free DNS Questions + if (message.header.qdcount > 0) { + for (uint16_t i = 0; i < message.header.qdcount; i++) { + free((message.questions + i)->qname); + } + free(message.questions); + } + + // Free DNS Answers + dns_free_rrs(message.answers, message.header.ancount); + + /* Other sections are not used in this project + + // Free DNS Authorities + dns_free_rrs(message.authorities, message.header.nscount); + // Free DNS Additionals + dns_free_rrs(message.additionals, message.header.arcount); + + */ +} + + +///// PRINTING ///// + +/** + * Print a DNS header. + * + * @param message the DNS header + */ +void dns_print_header(dns_header_t header) { + printf("DNS Header:\n"); + printf(" ID: %#hx\n", header.id); + printf(" Flags: %#hx\n", header.flags); + printf(" QR: %d\n", header.qr); + printf(" Questions count: %hd\n", header.qdcount); + printf(" Answers count: %hd\n", header.ancount); + printf(" Authority name servers count: %hd\n", header.nscount); + printf(" Additional records count: %hd\n", header.arcount); +} + +/** + * Print a DNS Question + * + * @param question the DNS Question + */ +void dns_print_question(dns_question_t question) { + printf(" Question:\n"); + printf(" Domain name: %s\n", question.qname); + printf(" Type: %hd\n", question.qtype); + printf(" Class: %hd\n", question.qclass); +} + +/** + * Print a DNS Question section. + * + * @param qdcount the number of Questions in the Question section + * @param questions the list of DNS Questions + */ +void dns_print_questions(uint16_t qdcount, dns_question_t *questions) { + printf("DNS Question section:\n"); + for (uint16_t i = 0; i < qdcount; i++) { + dns_question_t *question = questions + i; + if (question != NULL) { + dns_print_question(*question); + } + } +} + +/** + * Return a string representation of the given RDATA value. + * + * @param rtype the type corresponding to the RDATA value + * @param rdlength the length, in bytes, of the RDATA value + * @param rdata the RDATA value, stored as a union type + * @return a string representation of the RDATA value + */ +char* dns_rdata_to_str(dns_rr_type_t rtype, uint16_t rdlength, rdata_t rdata) { + if (rdlength == 0) { + // RDATA is empty + return ""; + } + switch (rtype) { + case A: + case AAAA: + // RDATA is an IP (v4 or v6) address + return ip_net_to_str(rdata.ip); + break; + case NS: + case CNAME: + case PTR: + // RDATA is a domain name + return rdata.domain_name; + break; + default: ; + // Generic RDATA + char *buffer = (char *) malloc(rdlength * 4 + 1); // Allocate memory for each byte (4 characters) + the NULL terminator + for (uint8_t i = 0; i < rdlength; i++) { + snprintf(buffer + (i * 4), 5, "\\x%02x", *(rdata.data + i)); + } + return buffer; + } +} + +/** + * Print a DNS Resource Record. + * + * @param section_name the name of the Resource Record section + * @param rr the DNS Resource Record + */ +void dns_print_rr(char* section_name, dns_resource_record_t rr) { + printf(" %s RR:\n", section_name); + printf(" Name: %s\n", rr.name); + printf(" Type: %hd\n", rr.rtype); + printf(" Class: %hd\n", rr.rclass); + printf(" TTL [s]: %d\n", rr.ttl); + printf(" Data length: %hd\n", rr.rdlength); + printf(" RDATA: %s\n", dns_rdata_to_str(rr.rtype, rr.rdlength, rr.rdata)); +} + +/** + * Print a DNS Resource Records section. + * + * @param section_name the name of the Resource Record section + * @param count the number of Resource Records in the section + * @param rrs the list of DNS Resource Records + */ +void dns_print_rrs(char* section_name, uint16_t count, dns_resource_record_t *rrs) { + printf("%s RRs:\n", section_name); + for (uint16_t i = 0; i < count; i++) { + dns_resource_record_t *rr = rrs + i; + if (rr != NULL) + dns_print_rr(section_name, *rr); + } +} + +/** + * Print a DNS message. + * + * @param message the DNS message + */ +void dns_print_message(dns_message_t message) { + // Print DNS Header + dns_print_header(message.header); + + // Print DNS Questions, if any + if (message.header.qdcount > 0) + dns_print_questions(message.header.qdcount, message.questions); + + // Print DNS Answers, if message is a response and has answers + if (message.header.qr == 1 && message.header.ancount > 0) + dns_print_rrs("Answer", message.header.ancount, message.answers); + + /* Other sections are not used in this project + + dns_print_rrs("Authority", message.header.nscount, message.authorities); + dns_print_rrs("Additional", message.header.arcount, message.additionals); + + */ +} diff --git a/src/dns_map.c b/src/dns_map.c new file mode 100644 index 0000000000000000000000000000000000000000..b5406b8f17adec239c60ae6bf716cb803870aec1 --- /dev/null +++ b/src/dns_map.c @@ -0,0 +1,203 @@ +/** + * @file src/dns_map.c + * @brief Implementation of a DNS domain name to IP addresses mapping, using Joshua J Baker's hashmap.c (https://github.com/tidwall/hashmap.c) + * @date 2022-09-06 + * + * @copyright Copyright (c) 2022 + * + */ + +#include "dns_map.h" + + +/*** Static functions for hashmap ****/ + +/** + * Hash function for the DNS table. + * + * @param item DNS table entry to hash + * @param seed0 first seed + * @param seed1 second seed + * @return hash value for the given DNS table entry + */ +static uint64_t dns_hash(const void *item, uint64_t seed0, uint64_t seed1) { + const dns_entry_t *entry = (dns_entry_t *) item; + return hashmap_sip(entry->domain_name, strlen(entry->domain_name), seed0, seed1); +} + +/** + * Compare function for the DNS table. + * + * @param a first DNS table entry to compare + * @param a second DNS table entry to compare + * @param udata user data, unused + * @return an integer which takes the following value: + * - 0 if a and b are equal + * - less than 0 if a is smaller than b + * - greater than 0 if a is greater than b + */ +static int dns_compare(const void *a, const void *b, void *udata) { + const dns_entry_t *entry1 = (dns_entry_t *) a; + const dns_entry_t *entry2 = (dns_entry_t *) b; + return strcmp(entry1->domain_name, entry2->domain_name); +} + +/** + * Free an entry of the DNS table. + * + * @param item the entry to free + */ +static void dns_free(void *item) { + free(((dns_entry_t *) item)->ip_list.ip_addresses); +} + + +/*** Visible functions ***/ + +/** + * @brief Initialize an ip_list_t structure. + * + * Creates an empty list of IP addresses. + * The `ip_count` field is set to 0, + * and the `ip_addresses` field is set to NULL. + * + * @return ip_list_t newly initialized structure + */ +ip_list_t ip_list_init() { + ip_list_t ip_list; + ip_list.ip_count = 0; + ip_list.ip_addresses = NULL; + return ip_list; +} + +/** + * @brief Checks if a dns_entry_t structure contains a given IP address. + * + * @param dns_entry pointer to the DNS entry to process + * @param ip_address IP address to check the presence of + * @return true if the IP address is present in the DNS entry, false otherwise + */ +bool dns_entry_contains(dns_entry_t *dns_entry, ip_addr_t ip_address) { + if (dns_entry == NULL || dns_entry->ip_list.ip_addresses == NULL) { + // DNS entry or IP address list is NULL + return false; + } + + // Not NULL, search for the IP address + for (uint8_t i = 0; i < dns_entry->ip_list.ip_count; i++) { + if (compare_ip(*(dns_entry->ip_list.ip_addresses + i), ip_address)) { + // IP address found + return true; + } + } + + // IP address not found + return false; +} + +/** + * Create a new DNS table. + * Uses random seeds for the hash function. + * + * @return the newly created DNS table, or NULL if creation failed + */ +dns_map_t* dns_map_create() { + return hashmap_new( + sizeof(dns_entry_t), // Size of one entry + DNS_MAP_INIT_SIZE, // Hashmap initial size + rand(), // Optional seed 1 + rand(), // Optional seed 2 + &dns_hash, // Hash function + &dns_compare, // Compare function + &dns_free, // Element free function + NULL // User data, unused + ); +} + +/** + * Free the memory allocated for a DNS table. + * + * @param table the DNS table to free + */ +void dns_map_free(dns_map_t *table) { + hashmap_free(table); +} + +/** + * Add IP addresses corresponding to a given domain name in the DNS table. + * If the domain name was already present, its IP addresses will be replaced by the new ones. + * + * @param table the DNS table to add the entry to + * @param domain_name the domain name of the entry + * @param ip_list an ip_list_t structure containing the list of IP addresses + */ +void dns_map_add(dns_map_t *table, char *domain_name, ip_list_t ip_list) { + dns_entry_t *dns_entry = dns_map_get(table, domain_name); + if (dns_entry != NULL) { + // Domain name already present, add given IP addresses to the already existing ones + ip_list_t old_ip_list = dns_entry->ip_list; + ip_list_t new_ip_list; + new_ip_list.ip_count = old_ip_list.ip_count + ip_list.ip_count; + new_ip_list.ip_addresses = (ip_addr_t *) malloc(new_ip_list.ip_count * sizeof(ip_addr_t)); + memcpy(new_ip_list.ip_addresses, old_ip_list.ip_addresses, old_ip_list.ip_count * sizeof(ip_addr_t)); + memcpy(new_ip_list.ip_addresses + old_ip_list.ip_count, ip_list.ip_addresses, ip_list.ip_count * sizeof(ip_addr_t)); + dns_entry->ip_list = new_ip_list; + free(old_ip_list.ip_addresses); + free(ip_list.ip_addresses); + } else { + // Domain name not present, create a new entry with given IP addresses + hashmap_set(table, &(dns_entry_t){.domain_name = domain_name, .ip_list = ip_list}); + } +} + +/** + * Remove a domain name, and its corresponding IP addresses, from the DNS table. + * + * @param table the DNS table to remove the entry from + * @param domain_name the domain name of the entry to remove + */ +void dns_map_remove(dns_map_t *table, char *domain_name) { + dns_entry_t *entry = hashmap_delete(table, &(dns_entry_t){ .domain_name = domain_name }); + if (entry != NULL) + dns_free(entry); +} + +/** + * Retrieve the IP addresses corresponding to a given domain name in the DNS table. + * + * @param table the DNS table to retrieve the entry from + * @param domain_name the domain name of the entry to retrieve + * @return a pointer to a dns_entry structure containing the IP addresses corresponding to the domain name, + * or NULL if the domain name was not found in the DNS table + */ +dns_entry_t* dns_map_get(dns_map_t *table, char *domain_name) { + return (dns_entry_t *) hashmap_get(table, &(dns_entry_t){ .domain_name = domain_name }); +} + +/** + * Retrieve the IP addresses corresponding to a given domain name, + * and remove the domain name from the DNS table. + * + * @param table the DNS table to retrieve the entry from + * @param domain_name the domain name of the entry to retrieve + * @return a pointer to a dns_entry structure containing the IP addresses corresponding to the domain name, + * or NULL if the domain name was not found in the DNS table + */ +dns_entry_t* dns_map_pop(dns_map_t *table, char *domain_name) { + return (dns_entry_t *) hashmap_delete(table, &(dns_entry_t){ .domain_name = domain_name }); +} + +/** + * @brief Print a DNS table entry. + * + * @param dns_entry the DNS table entry to print + */ +void dns_entry_print(dns_entry_t *dns_entry) { + if (dns_entry != NULL) { + printf("Domain name: %s\n", dns_entry->domain_name); + printf("IP addresses:\n"); + for (uint8_t i = 0; i < dns_entry->ip_list.ip_count; i++) { + printf(" %s\n", ip_net_to_str(*(dns_entry->ip_list.ip_addresses + i))); + } + } +} diff --git a/src/hashmap.c b/src/hashmap.c new file mode 100644 index 0000000000000000000000000000000000000000..6633e35317a920c8cb468192584704889ebc6990 --- /dev/null +++ b/src/hashmap.c @@ -0,0 +1,980 @@ +// Copyright 2020 Joshua J Baker. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +#include <stdio.h> +#include <string.h> +#include <stdlib.h> +#include <stdint.h> +#include <stddef.h> +#include "hashmap.h" + +static void *(*_malloc)(size_t) = NULL; +static void *(*_realloc)(void *, size_t) = NULL; +static void (*_free)(void *) = NULL; + +// hashmap_set_allocator allows for configuring a custom allocator for +// all hashmap library operations. This function, if needed, should be called +// only once at startup and a prior to calling hashmap_new(). +void hashmap_set_allocator(void *(*malloc)(size_t), void (*free)(void*)) +{ + _malloc = malloc; + _free = free; +} + +#define panic(_msg_) { \ + fprintf(stderr, "panic: %s (%s:%d)\n", (_msg_), __FILE__, __LINE__); \ + exit(1); \ +} + +struct bucket { + uint64_t hash:48; + uint64_t dib:16; +}; + +// hashmap is an open addressed hash map using robinhood hashing. +struct hashmap { + void *(*malloc)(size_t); + void *(*realloc)(void *, size_t); + void (*free)(void *); + bool oom; + size_t elsize; + size_t cap; + uint64_t seed0; + uint64_t seed1; + uint64_t (*hash)(const void *item, uint64_t seed0, uint64_t seed1); + int (*compare)(const void *a, const void *b, void *udata); + void (*elfree)(void *item); + void *udata; + size_t bucketsz; + size_t nbuckets; + size_t count; + size_t mask; + size_t growat; + size_t shrinkat; + void *buckets; + void *spare; + void *edata; +}; + +static struct bucket *bucket_at(struct hashmap *map, size_t index) { + return (struct bucket*)(((char*)map->buckets)+(map->bucketsz*index)); +} + +static void *bucket_item(struct bucket *entry) { + return ((char*)entry)+sizeof(struct bucket); +} + +static uint64_t get_hash(struct hashmap *map, const void *key) { + return map->hash(key, map->seed0, map->seed1) << 16 >> 16; +} + +// hashmap_new_with_allocator returns a new hash map using a custom allocator. +// See hashmap_new for more information information +struct hashmap *hashmap_new_with_allocator( + void *(*_malloc)(size_t), + void *(*_realloc)(void*, size_t), + void (*_free)(void*), + size_t elsize, size_t cap, + uint64_t seed0, uint64_t seed1, + uint64_t (*hash)(const void *item, + uint64_t seed0, uint64_t seed1), + int (*compare)(const void *a, const void *b, + void *udata), + void (*elfree)(void *item), + void *udata) +{ + _malloc = _malloc ? _malloc : malloc; + _realloc = _realloc ? _realloc : realloc; + _free = _free ? _free : free; + int ncap = 16; + if (cap < ncap) { + cap = ncap; + } else { + while (ncap < cap) { + ncap *= 2; + } + cap = ncap; + } + size_t bucketsz = sizeof(struct bucket) + elsize; + while (bucketsz & (sizeof(uintptr_t)-1)) { + bucketsz++; + } + // hashmap + spare + edata + size_t size = sizeof(struct hashmap)+bucketsz*2; + struct hashmap *map = _malloc(size); + if (!map) { + return NULL; + } + memset(map, 0, sizeof(struct hashmap)); + map->elsize = elsize; + map->bucketsz = bucketsz; + map->seed0 = seed0; + map->seed1 = seed1; + map->hash = hash; + map->compare = compare; + map->elfree = elfree; + map->udata = udata; + map->spare = ((char*)map)+sizeof(struct hashmap); + map->edata = (char*)map->spare+bucketsz; + map->cap = cap; + map->nbuckets = cap; + map->mask = map->nbuckets-1; + map->buckets = _malloc(map->bucketsz*map->nbuckets); + if (!map->buckets) { + _free(map); + return NULL; + } + memset(map->buckets, 0, map->bucketsz*map->nbuckets); + map->growat = map->nbuckets*0.75; + map->shrinkat = map->nbuckets*0.10; + map->malloc = _malloc; + map->realloc = _realloc; + map->free = _free; + return map; +} + + +// hashmap_new returns a new hash map. +// Param `elsize` is the size of each element in the tree. Every element that +// is inserted, deleted, or retrieved will be this size. +// Param `cap` is the default lower capacity of the hashmap. Setting this to +// zero will default to 16. +// Params `seed0` and `seed1` are optional seed values that are passed to the +// following `hash` function. These can be any value you wish but it's often +// best to use randomly generated values. +// Param `hash` is a function that generates a hash value for an item. It's +// important that you provide a good hash function, otherwise it will perform +// poorly or be vulnerable to Denial-of-service attacks. This implementation +// comes with two helper functions `hashmap_sip()` and `hashmap_murmur()`. +// Param `compare` is a function that compares items in the tree. See the +// qsort stdlib function for an example of how this function works. +// The hashmap must be freed with hashmap_free(). +// Param `elfree` is a function that frees a specific item. This should be NULL +// unless you're storing some kind of reference data in the hash. +struct hashmap *hashmap_new(size_t elsize, size_t cap, + uint64_t seed0, uint64_t seed1, + uint64_t (*hash)(const void *item, + uint64_t seed0, uint64_t seed1), + int (*compare)(const void *a, const void *b, + void *udata), + void (*elfree)(void *item), + void *udata) +{ + return hashmap_new_with_allocator( + (_malloc?_malloc:malloc), + (_realloc?_realloc:realloc), + (_free?_free:free), + elsize, cap, seed0, seed1, hash, compare, elfree, udata + ); +} + +static void free_elements(struct hashmap *map) { + if (map->elfree) { + for (size_t i = 0; i < map->nbuckets; i++) { + struct bucket *bucket = bucket_at(map, i); + if (bucket->dib) map->elfree(bucket_item(bucket)); + } + } +} + + +// hashmap_clear quickly clears the map. +// Every item is called with the element-freeing function given in hashmap_new, +// if present, to free any data referenced in the elements of the hashmap. +// When the update_cap is provided, the map's capacity will be updated to match +// the currently number of allocated buckets. This is an optimization to ensure +// that this operation does not perform any allocations. +void hashmap_clear(struct hashmap *map, bool update_cap) { + map->count = 0; + free_elements(map); + if (update_cap) { + map->cap = map->nbuckets; + } else if (map->nbuckets != map->cap) { + void *new_buckets = map->malloc(map->bucketsz*map->cap); + if (new_buckets) { + map->free(map->buckets); + map->buckets = new_buckets; + } + map->nbuckets = map->cap; + } + memset(map->buckets, 0, map->bucketsz*map->nbuckets); + map->mask = map->nbuckets-1; + map->growat = map->nbuckets*0.75; + map->shrinkat = map->nbuckets*0.10; +} + + +static bool resize(struct hashmap *map, size_t new_cap) { + struct hashmap *map2 = hashmap_new_with_allocator(map->malloc, map->realloc, map->free, + map->elsize, new_cap, map->seed0, + map->seed1, map->hash, map->compare, + map->elfree, map->udata); + if (!map2) { + return false; + } + for (size_t i = 0; i < map->nbuckets; i++) { + struct bucket *entry = bucket_at(map, i); + if (!entry->dib) { + continue; + } + entry->dib = 1; + size_t j = entry->hash & map2->mask; + for (;;) { + struct bucket *bucket = bucket_at(map2, j); + if (bucket->dib == 0) { + memcpy(bucket, entry, map->bucketsz); + break; + } + if (bucket->dib < entry->dib) { + memcpy(map2->spare, bucket, map->bucketsz); + memcpy(bucket, entry, map->bucketsz); + memcpy(entry, map2->spare, map->bucketsz); + } + j = (j + 1) & map2->mask; + entry->dib += 1; + } + } + map->free(map->buckets); + map->buckets = map2->buckets; + map->nbuckets = map2->nbuckets; + map->mask = map2->mask; + map->growat = map2->growat; + map->shrinkat = map2->shrinkat; + map->free(map2); + return true; +} + +// hashmap_set inserts or replaces an item in the hash map. If an item is +// replaced then it is returned otherwise NULL is returned. This operation +// may allocate memory. If the system is unable to allocate additional +// memory then NULL is returned and hashmap_oom() returns true. +void *hashmap_set(struct hashmap *map, const void *item) { + if (!item) { + panic("item is null"); + } + map->oom = false; + if (map->count == map->growat) { + if (!resize(map, map->nbuckets*2)) { + map->oom = true; + return NULL; + } + } + + + struct bucket *entry = map->edata; + entry->hash = get_hash(map, item); + entry->dib = 1; + memcpy(bucket_item(entry), item, map->elsize); + + size_t i = entry->hash & map->mask; + for (;;) { + struct bucket *bucket = bucket_at(map, i); + if (bucket->dib == 0) { + memcpy(bucket, entry, map->bucketsz); + map->count++; + return NULL; + } + if (entry->hash == bucket->hash && + map->compare(bucket_item(entry), bucket_item(bucket), + map->udata) == 0) + { + memcpy(map->spare, bucket_item(bucket), map->elsize); + memcpy(bucket_item(bucket), bucket_item(entry), map->elsize); + return map->spare; + } + if (bucket->dib < entry->dib) { + memcpy(map->spare, bucket, map->bucketsz); + memcpy(bucket, entry, map->bucketsz); + memcpy(entry, map->spare, map->bucketsz); + } + i = (i + 1) & map->mask; + entry->dib += 1; + } +} + +// hashmap_get returns the item based on the provided key. If the item is not +// found then NULL is returned. +void *hashmap_get(struct hashmap *map, const void *key) { + if (!key) { + panic("key is null"); + } + uint64_t hash = get_hash(map, key); + size_t i = hash & map->mask; + for (;;) { + struct bucket *bucket = bucket_at(map, i); + if (!bucket->dib) { + return NULL; + } + if (bucket->hash == hash && + map->compare(key, bucket_item(bucket), map->udata) == 0) + { + return bucket_item(bucket); + } + i = (i + 1) & map->mask; + } +} + +// hashmap_probe returns the item in the bucket at position or NULL if an item +// is not set for that bucket. The position is 'moduloed' by the number of +// buckets in the hashmap. +void *hashmap_probe(struct hashmap *map, uint64_t position) { + size_t i = position & map->mask; + struct bucket *bucket = bucket_at(map, i); + if (!bucket->dib) { + return NULL; + } + return bucket_item(bucket); +} + + +// hashmap_delete removes an item from the hash map and returns it. If the +// item is not found then NULL is returned. +void *hashmap_delete(struct hashmap *map, void *key) { + if (!key) { + panic("key is null"); + } + map->oom = false; + uint64_t hash = get_hash(map, key); + size_t i = hash & map->mask; + for (;;) { + struct bucket *bucket = bucket_at(map, i); + if (!bucket->dib) { + return NULL; + } + if (bucket->hash == hash && + map->compare(key, bucket_item(bucket), map->udata) == 0) + { + memcpy(map->spare, bucket_item(bucket), map->elsize); + bucket->dib = 0; + for (;;) { + struct bucket *prev = bucket; + i = (i + 1) & map->mask; + bucket = bucket_at(map, i); + if (bucket->dib <= 1) { + prev->dib = 0; + break; + } + memcpy(prev, bucket, map->bucketsz); + prev->dib--; + } + map->count--; + if (map->nbuckets > map->cap && map->count <= map->shrinkat) { + // Ignore the return value. It's ok for the resize operation to + // fail to allocate enough memory because a shrink operation + // does not change the integrity of the data. + resize(map, map->nbuckets/2); + } + return map->spare; + } + i = (i + 1) & map->mask; + } +} + +// hashmap_count returns the number of items in the hash map. +size_t hashmap_count(struct hashmap *map) { + return map->count; +} + +// hashmap_free frees the hash map +// Every item is called with the element-freeing function given in hashmap_new, +// if present, to free any data referenced in the elements of the hashmap. +void hashmap_free(struct hashmap *map) { + if (!map) return; + free_elements(map); + map->free(map->buckets); + map->free(map); +} + +// hashmap_oom returns true if the last hashmap_set() call failed due to the +// system being out of memory. +bool hashmap_oom(struct hashmap *map) { + return map->oom; +} + +// hashmap_scan iterates over all items in the hash map +// Param `iter` can return false to stop iteration early. +// Returns false if the iteration has been stopped early. +bool hashmap_scan(struct hashmap *map, + bool (*iter)(const void *item, void *udata), void *udata) +{ + for (size_t i = 0; i < map->nbuckets; i++) { + struct bucket *bucket = bucket_at(map, i); + if (bucket->dib) { + if (!iter(bucket_item(bucket), udata)) { + return false; + } + } + } + return true; +} + + +// hashmap_iter iterates one key at a time yielding a reference to an +// entry at each iteration. Useful to write simple loops and avoid writing +// dedicated callbacks and udata structures, as in hashmap_scan. +// +// map is a hash map handle. i is a pointer to a size_t cursor that +// should be initialized to 0 at the beginning of the loop. item is a void +// pointer pointer that is populated with the retrieved item. Note that this +// is NOT a copy of the item stored in the hash map and can be directly +// modified. +// +// Note that if hashmap_delete() is called on the hashmap being iterated, +// the buckets are rearranged and the iterator must be reset to 0, otherwise +// unexpected results may be returned after deletion. +// +// This function has not been tested for thread safety. +// +// The function returns true if an item was retrieved; false if the end of the +// iteration has been reached. +bool hashmap_iter(struct hashmap *map, size_t *i, void **item) +{ + struct bucket *bucket; + + do { + if (*i >= map->nbuckets) return false; + + bucket = bucket_at(map, *i); + (*i)++; + } while (!bucket->dib); + + *item = bucket_item(bucket); + + return true; +} + + +//----------------------------------------------------------------------------- +// SipHash reference C implementation +// +// Copyright (c) 2012-2016 Jean-Philippe Aumasson +// <jeanphilippe.aumasson@gmail.com> +// Copyright (c) 2012-2014 Daniel J. Bernstein <djb@cr.yp.to> +// +// To the extent possible under law, the author(s) have dedicated all copyright +// and related and neighboring rights to this software to the public domain +// worldwide. This software is distributed without any warranty. +// +// You should have received a copy of the CC0 Public Domain Dedication along +// with this software. If not, see +// <http://creativecommons.org/publicdomain/zero/1.0/>. +// +// default: SipHash-2-4 +//----------------------------------------------------------------------------- +static uint64_t SIP64(const uint8_t *in, const size_t inlen, + uint64_t seed0, uint64_t seed1) +{ +#define U8TO64_LE(p) \ + { (((uint64_t)((p)[0])) | ((uint64_t)((p)[1]) << 8) | \ + ((uint64_t)((p)[2]) << 16) | ((uint64_t)((p)[3]) << 24) | \ + ((uint64_t)((p)[4]) << 32) | ((uint64_t)((p)[5]) << 40) | \ + ((uint64_t)((p)[6]) << 48) | ((uint64_t)((p)[7]) << 56)) } +#define U64TO8_LE(p, v) \ + { U32TO8_LE((p), (uint32_t)((v))); \ + U32TO8_LE((p) + 4, (uint32_t)((v) >> 32)); } +#define U32TO8_LE(p, v) \ + { (p)[0] = (uint8_t)((v)); \ + (p)[1] = (uint8_t)((v) >> 8); \ + (p)[2] = (uint8_t)((v) >> 16); \ + (p)[3] = (uint8_t)((v) >> 24); } +#define ROTL(x, b) (uint64_t)(((x) << (b)) | ((x) >> (64 - (b)))) +#define SIPROUND \ + { v0 += v1; v1 = ROTL(v1, 13); \ + v1 ^= v0; v0 = ROTL(v0, 32); \ + v2 += v3; v3 = ROTL(v3, 16); \ + v3 ^= v2; \ + v0 += v3; v3 = ROTL(v3, 21); \ + v3 ^= v0; \ + v2 += v1; v1 = ROTL(v1, 17); \ + v1 ^= v2; v2 = ROTL(v2, 32); } + uint64_t k0 = U8TO64_LE((uint8_t*)&seed0); + uint64_t k1 = U8TO64_LE((uint8_t*)&seed1); + uint64_t v3 = UINT64_C(0x7465646279746573) ^ k1; + uint64_t v2 = UINT64_C(0x6c7967656e657261) ^ k0; + uint64_t v1 = UINT64_C(0x646f72616e646f6d) ^ k1; + uint64_t v0 = UINT64_C(0x736f6d6570736575) ^ k0; + const uint8_t *end = in + inlen - (inlen % sizeof(uint64_t)); + for (; in != end; in += 8) { + uint64_t m = U8TO64_LE(in); + v3 ^= m; + SIPROUND; SIPROUND; + v0 ^= m; + } + const int left = inlen & 7; + uint64_t b = ((uint64_t)inlen) << 56; + switch (left) { + case 7: b |= ((uint64_t)in[6]) << 48; + case 6: b |= ((uint64_t)in[5]) << 40; + case 5: b |= ((uint64_t)in[4]) << 32; + case 4: b |= ((uint64_t)in[3]) << 24; + case 3: b |= ((uint64_t)in[2]) << 16; + case 2: b |= ((uint64_t)in[1]) << 8; + case 1: b |= ((uint64_t)in[0]); break; + case 0: break; + } + v3 ^= b; + SIPROUND; SIPROUND; + v0 ^= b; + v2 ^= 0xff; + SIPROUND; SIPROUND; SIPROUND; SIPROUND; + b = v0 ^ v1 ^ v2 ^ v3; + uint64_t out = 0; + U64TO8_LE((uint8_t*)&out, b); + return out; +} + +//----------------------------------------------------------------------------- +// MurmurHash3 was written by Austin Appleby, and is placed in the public +// domain. The author hereby disclaims copyright to this source code. +// +// Murmur3_86_128 +//----------------------------------------------------------------------------- +static void MM86128(const void *key, const int len, uint32_t seed, void *out) { +#define ROTL32(x, r) ((x << r) | (x >> (32 - r))) +#define FMIX32(h) h^=h>>16; h*=0x85ebca6b; h^=h>>13; h*=0xc2b2ae35; h^=h>>16; + const uint8_t * data = (const uint8_t*)key; + const int nblocks = len / 16; + uint32_t h1 = seed; + uint32_t h2 = seed; + uint32_t h3 = seed; + uint32_t h4 = seed; + uint32_t c1 = 0x239b961b; + uint32_t c2 = 0xab0e9789; + uint32_t c3 = 0x38b34ae5; + uint32_t c4 = 0xa1e38b93; + const uint32_t * blocks = (const uint32_t *)(data + nblocks*16); + for (int i = -nblocks; i; i++) { + uint32_t k1 = blocks[i*4+0]; + uint32_t k2 = blocks[i*4+1]; + uint32_t k3 = blocks[i*4+2]; + uint32_t k4 = blocks[i*4+3]; + k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; + h1 = ROTL32(h1,19); h1 += h2; h1 = h1*5+0x561ccd1b; + k2 *= c2; k2 = ROTL32(k2,16); k2 *= c3; h2 ^= k2; + h2 = ROTL32(h2,17); h2 += h3; h2 = h2*5+0x0bcaa747; + k3 *= c3; k3 = ROTL32(k3,17); k3 *= c4; h3 ^= k3; + h3 = ROTL32(h3,15); h3 += h4; h3 = h3*5+0x96cd1c35; + k4 *= c4; k4 = ROTL32(k4,18); k4 *= c1; h4 ^= k4; + h4 = ROTL32(h4,13); h4 += h1; h4 = h4*5+0x32ac3b17; + } + const uint8_t * tail = (const uint8_t*)(data + nblocks*16); + uint32_t k1 = 0; + uint32_t k2 = 0; + uint32_t k3 = 0; + uint32_t k4 = 0; + switch(len & 15) { + case 15: k4 ^= tail[14] << 16; + case 14: k4 ^= tail[13] << 8; + case 13: k4 ^= tail[12] << 0; + k4 *= c4; k4 = ROTL32(k4,18); k4 *= c1; h4 ^= k4; + case 12: k3 ^= tail[11] << 24; + case 11: k3 ^= tail[10] << 16; + case 10: k3 ^= tail[ 9] << 8; + case 9: k3 ^= tail[ 8] << 0; + k3 *= c3; k3 = ROTL32(k3,17); k3 *= c4; h3 ^= k3; + case 8: k2 ^= tail[ 7] << 24; + case 7: k2 ^= tail[ 6] << 16; + case 6: k2 ^= tail[ 5] << 8; + case 5: k2 ^= tail[ 4] << 0; + k2 *= c2; k2 = ROTL32(k2,16); k2 *= c3; h2 ^= k2; + case 4: k1 ^= tail[ 3] << 24; + case 3: k1 ^= tail[ 2] << 16; + case 2: k1 ^= tail[ 1] << 8; + case 1: k1 ^= tail[ 0] << 0; + k1 *= c1; k1 = ROTL32(k1,15); k1 *= c2; h1 ^= k1; + }; + h1 ^= len; h2 ^= len; h3 ^= len; h4 ^= len; + h1 += h2; h1 += h3; h1 += h4; + h2 += h1; h3 += h1; h4 += h1; + FMIX32(h1); FMIX32(h2); FMIX32(h3); FMIX32(h4); + h1 += h2; h1 += h3; h1 += h4; + h2 += h1; h3 += h1; h4 += h1; + ((uint32_t*)out)[0] = h1; + ((uint32_t*)out)[1] = h2; + ((uint32_t*)out)[2] = h3; + ((uint32_t*)out)[3] = h4; +} + +// hashmap_sip returns a hash value for `data` using SipHash-2-4. +uint64_t hashmap_sip(const void *data, size_t len, + uint64_t seed0, uint64_t seed1) +{ + return SIP64((uint8_t*)data, len, seed0, seed1); +} + +// hashmap_murmur returns a hash value for `data` using Murmur3_86_128. +uint64_t hashmap_murmur(const void *data, size_t len, + uint64_t seed0, uint64_t seed1) +{ + char out[16]; + MM86128(data, len, seed0, &out); + return *(uint64_t*)out; +} + +//============================================================================== +// TESTS AND BENCHMARKS +// $ cc -DHASHMAP_TEST hashmap.c && ./a.out # run tests +// $ cc -DHASHMAP_TEST -O3 hashmap.c && BENCH=1 ./a.out # run benchmarks +//============================================================================== +#ifdef HASHMAP_TEST + +static size_t deepcount(struct hashmap *map) { + size_t count = 0; + for (size_t i = 0; i < map->nbuckets; i++) { + if (bucket_at(map, i)->dib) { + count++; + } + } + return count; +} + + +#pragma GCC diagnostic ignored "-Wextra" + + +#include <stdlib.h> +#include <string.h> +#include <time.h> +#include <assert.h> +#include <stdio.h> +#include "hashmap.h" + +static bool rand_alloc_fail = false; +static int rand_alloc_fail_odds = 3; // 1 in 3 chance malloc will fail. +static uintptr_t total_allocs = 0; +static uintptr_t total_mem = 0; + +static void *xmalloc(size_t size) { + if (rand_alloc_fail && rand()%rand_alloc_fail_odds == 0) { + return NULL; + } + void *mem = malloc(sizeof(uintptr_t)+size); + assert(mem); + *(uintptr_t*)mem = size; + total_allocs++; + total_mem += size; + return (char*)mem+sizeof(uintptr_t); +} + +static void xfree(void *ptr) { + if (ptr) { + total_mem -= *(uintptr_t*)((char*)ptr-sizeof(uintptr_t)); + free((char*)ptr-sizeof(uintptr_t)); + total_allocs--; + } +} + +static void shuffle(void *array, size_t numels, size_t elsize) { + char tmp[elsize]; + char *arr = array; + for (size_t i = 0; i < numels - 1; i++) { + int j = i + rand() / (RAND_MAX / (numels - i) + 1); + memcpy(tmp, arr + j * elsize, elsize); + memcpy(arr + j * elsize, arr + i * elsize, elsize); + memcpy(arr + i * elsize, tmp, elsize); + } +} + +static bool iter_ints(const void *item, void *udata) { + int *vals = *(int**)udata; + vals[*(int*)item] = 1; + return true; +} + +static int compare_ints(const void *a, const void *b) { + return *(int*)a - *(int*)b; +} + +static int compare_ints_udata(const void *a, const void *b, void *udata) { + return *(int*)a - *(int*)b; +} + +static int compare_strs(const void *a, const void *b, void *udata) { + return strcmp(*(char**)a, *(char**)b); +} + +static uint64_t hash_int(const void *item, uint64_t seed0, uint64_t seed1) { + return hashmap_murmur(item, sizeof(int), seed0, seed1); +} + +static uint64_t hash_str(const void *item, uint64_t seed0, uint64_t seed1) { + return hashmap_murmur(*(char**)item, strlen(*(char**)item), seed0, seed1); +} + +static void free_str(void *item) { + xfree(*(char**)item); +} + +static void all() { + int seed = getenv("SEED")?atoi(getenv("SEED")):time(NULL); + int N = getenv("N")?atoi(getenv("N")):2000; + printf("seed=%d, count=%d, item_size=%zu\n", seed, N, sizeof(int)); + srand(seed); + + rand_alloc_fail = true; + + // test sip and murmur hashes + assert(hashmap_sip("hello", 5, 1, 2) == 2957200328589801622); + assert(hashmap_murmur("hello", 5, 1, 2) == 1682575153221130884); + + int *vals; + while (!(vals = xmalloc(N * sizeof(int)))) {} + for (int i = 0; i < N; i++) { + vals[i] = i; + } + + struct hashmap *map; + + while (!(map = hashmap_new(sizeof(int), 0, seed, seed, + hash_int, compare_ints_udata, NULL, NULL))) {} + shuffle(vals, N, sizeof(int)); + for (int i = 0; i < N; i++) { + // // printf("== %d ==\n", vals[i]); + assert(map->count == i); + assert(map->count == hashmap_count(map)); + assert(map->count == deepcount(map)); + int *v; + assert(!hashmap_get(map, &vals[i])); + assert(!hashmap_delete(map, &vals[i])); + while (true) { + assert(!hashmap_set(map, &vals[i])); + if (!hashmap_oom(map)) { + break; + } + } + + for (int j = 0; j < i; j++) { + v = hashmap_get(map, &vals[j]); + assert(v && *v == vals[j]); + } + while (true) { + v = hashmap_set(map, &vals[i]); + if (!v) { + assert(hashmap_oom(map)); + continue; + } else { + assert(!hashmap_oom(map)); + assert(v && *v == vals[i]); + break; + } + } + v = hashmap_get(map, &vals[i]); + assert(v && *v == vals[i]); + v = hashmap_delete(map, &vals[i]); + assert(v && *v == vals[i]); + assert(!hashmap_get(map, &vals[i])); + assert(!hashmap_delete(map, &vals[i])); + assert(!hashmap_set(map, &vals[i])); + assert(map->count == i+1); + assert(map->count == hashmap_count(map)); + assert(map->count == deepcount(map)); + } + + int *vals2; + while (!(vals2 = xmalloc(N * sizeof(int)))) {} + memset(vals2, 0, N * sizeof(int)); + assert(hashmap_scan(map, iter_ints, &vals2)); + + // Test hashmap_iter. This does the same as hashmap_scan above. + size_t iter = 0; + void *iter_val; + while (hashmap_iter (map, &iter, &iter_val)) { + assert (iter_ints(iter_val, &vals2)); + } + for (int i = 0; i < N; i++) { + assert(vals2[i] == 1); + } + xfree(vals2); + + shuffle(vals, N, sizeof(int)); + for (int i = 0; i < N; i++) { + int *v; + v = hashmap_delete(map, &vals[i]); + assert(v && *v == vals[i]); + assert(!hashmap_get(map, &vals[i])); + assert(map->count == N-i-1); + assert(map->count == hashmap_count(map)); + assert(map->count == deepcount(map)); + for (int j = N-1; j > i; j--) { + v = hashmap_get(map, &vals[j]); + assert(v && *v == vals[j]); + } + } + + for (int i = 0; i < N; i++) { + while (true) { + assert(!hashmap_set(map, &vals[i])); + if (!hashmap_oom(map)) { + break; + } + } + } + + assert(map->count != 0); + size_t prev_cap = map->cap; + hashmap_clear(map, true); + assert(prev_cap < map->cap); + assert(map->count == 0); + + + for (int i = 0; i < N; i++) { + while (true) { + assert(!hashmap_set(map, &vals[i])); + if (!hashmap_oom(map)) { + break; + } + } + } + + prev_cap = map->cap; + hashmap_clear(map, false); + assert(prev_cap == map->cap); + + hashmap_free(map); + + xfree(vals); + + + while (!(map = hashmap_new(sizeof(char*), 0, seed, seed, + hash_str, compare_strs, free_str, NULL))); + + for (int i = 0; i < N; i++) { + char *str; + while (!(str = xmalloc(16))); + sprintf(str, "s%i", i); + while(!hashmap_set(map, &str)); + } + + hashmap_clear(map, false); + assert(hashmap_count(map) == 0); + + for (int i = 0; i < N; i++) { + char *str; + while (!(str = xmalloc(16))); + sprintf(str, "s%i", i); + while(!hashmap_set(map, &str)); + } + + hashmap_free(map); + + if (total_allocs != 0) { + fprintf(stderr, "total_allocs: expected 0, got %lu\n", total_allocs); + exit(1); + } +} + +#define bench(name, N, code) {{ \ + if (strlen(name) > 0) { \ + printf("%-14s ", name); \ + } \ + size_t tmem = total_mem; \ + size_t tallocs = total_allocs; \ + uint64_t bytes = 0; \ + clock_t begin = clock(); \ + for (int i = 0; i < N; i++) { \ + (code); \ + } \ + clock_t end = clock(); \ + double elapsed_secs = (double)(end - begin) / CLOCKS_PER_SEC; \ + double bytes_sec = (double)bytes/elapsed_secs; \ + printf("%d ops in %.3f secs, %.0f ns/op, %.0f op/sec", \ + N, elapsed_secs, \ + elapsed_secs/(double)N*1e9, \ + (double)N/elapsed_secs \ + ); \ + if (bytes > 0) { \ + printf(", %.1f GB/sec", bytes_sec/1024/1024/1024); \ + } \ + if (total_mem > tmem) { \ + size_t used_mem = total_mem-tmem; \ + printf(", %.2f bytes/op", (double)used_mem/N); \ + } \ + if (total_allocs > tallocs) { \ + size_t used_allocs = total_allocs-tallocs; \ + printf(", %.2f allocs/op", (double)used_allocs/N); \ + } \ + printf("\n"); \ +}} + +static void benchmarks() { + int seed = getenv("SEED")?atoi(getenv("SEED")):time(NULL); + int N = getenv("N")?atoi(getenv("N")):5000000; + printf("seed=%d, count=%d, item_size=%zu\n", seed, N, sizeof(int)); + srand(seed); + + + int *vals = xmalloc(N * sizeof(int)); + for (int i = 0; i < N; i++) { + vals[i] = i; + } + + shuffle(vals, N, sizeof(int)); + + struct hashmap *map; + shuffle(vals, N, sizeof(int)); + + map = hashmap_new(sizeof(int), 0, seed, seed, hash_int, compare_ints_udata, + NULL, NULL); + bench("set", N, { + int *v = hashmap_set(map, &vals[i]); + assert(!v); + }) + shuffle(vals, N, sizeof(int)); + bench("get", N, { + int *v = hashmap_get(map, &vals[i]); + assert(v && *v == vals[i]); + }) + shuffle(vals, N, sizeof(int)); + bench("delete", N, { + int *v = hashmap_delete(map, &vals[i]); + assert(v && *v == vals[i]); + }) + hashmap_free(map); + + map = hashmap_new(sizeof(int), N, seed, seed, hash_int, compare_ints_udata, + NULL, NULL); + bench("set (cap)", N, { + int *v = hashmap_set(map, &vals[i]); + assert(!v); + }) + shuffle(vals, N, sizeof(int)); + bench("get (cap)", N, { + int *v = hashmap_get(map, &vals[i]); + assert(v && *v == vals[i]); + }) + shuffle(vals, N, sizeof(int)); + bench("delete (cap)" , N, { + int *v = hashmap_delete(map, &vals[i]); + assert(v && *v == vals[i]); + }) + + hashmap_free(map); + + + xfree(vals); + + if (total_allocs != 0) { + fprintf(stderr, "total_allocs: expected 0, got %lu\n", total_allocs); + exit(1); + } +} + +int main() { + hashmap_set_allocator(xmalloc, xfree); + + if (getenv("BENCH")) { + printf("Running hashmap.c benchmarks...\n"); + benchmarks(); + } else { + printf("Running hashmap.c tests...\n"); + all(); + printf("PASSED\n"); + } +} + + +#endif + + + diff --git a/src/header.c b/src/header.c new file mode 100644 index 0000000000000000000000000000000000000000..c1062eff93d541a2493c591e6351773f1f1a74b2 --- /dev/null +++ b/src/header.c @@ -0,0 +1,204 @@ +/** + * @file src/header.c + * @brief Parser for layer 3 and 4 headers (currently only IPv4, IPv6, UDP and TCP) + * + * Parser for layer 3 and 4 headers. + * Currently supported protocols: + * - Layer 3: + * - IPv4 + * - IPv6 + * - Layer 4: + * - UDP + * - TCP + * + * @date 2022-09-09 + * + * @copyright Copyright (c) 2022 + * + */ + +#include "header.h" + + +/** + * Retrieve the length of a packet's IPv4 header. + * + * @param data a pointer to the start of the packet's IPv4 header + * @return the size, in bytes, of the IPv4 header + */ +size_t get_ipv4_header_length(uint8_t *data) { + // 4-bit IPv4 header length is encoded in the last 4 bits of byte 0. + // It indicates the number of 32-bit words. + // It must be multiplied by 4 to obtain the header size in bytes. + uint8_t length = (*data & 0x0f) * 4; + return length; +} + +/** + * Retrieve the length of a packet's IPv6 header. + * + * @param data a pointer to the start of the packet's IPv6 header + * @return the size, in bytes, of the IPv6 header + */ +size_t get_ipv6_header_length(uint8_t *data) { + // An IPv6 header has a fixed length of 40 bytes + return IPV6_HEADER_LENGTH; +} + +/** + * Retrieve the length of a packet's UDP header. + * + * @param data a pointer to the start of the packet's UDP (layer 4) header + * @return the size, in bytes, of the UDP header + */ +size_t get_udp_header_length(uint8_t *data) { + // A UDP header has a fixed length of 8 bytes + return UDP_HEADER_LENGTH; +} + +/** + * Retrieve the length of a packet's TCP header. + * + * @param data a pointer to the start of the packet's TCP (layer 4) header + * @return the size, in bytes, of the UDP header + */ +size_t get_tcp_header_length(uint8_t *data) { + // 4-bit TCP header data offset is encoded in the first 4 bits of byte 12. + // It indicates the number of 32-bit words. + // It must be multiplied by 4 to obtain the header size in bytes. + uint8_t length = (*((data) + 12) >> 4) * 4; + return length; +} + +/** + * Retrieve the length of a packet's layer 3 header (IPv4 or IPv6). + * + * @param data a pointer to the start of the packet's layer 3 header + * @return the size, in bytes, of the layer 3 header + */ +size_t get_l3_header_length(uint8_t *data) { + uint8_t ip_version = (*data) >> 4; + switch (ip_version) { + case 4: + return get_ipv4_header_length(data); + break; + case 6: + return get_ipv6_header_length(data); + break; + default: + return 0; + break; + } +} + +/** + * Retrieve the length of a packet's layer-3 and layer-4 headers. + * + * @param data a pointer to the start of the packet's layer-3 header + * @return the size, in bytes, of the UDP header + */ +size_t get_headers_length(uint8_t* data) { + size_t length = 0; + + // Layer 3: Network + // Retrieve the IP version, which is encoded in the first 4 bits of byte 0 + uint8_t ip_version = (*data) >> 4; + ip_protocol_t protocol = 0; + switch (ip_version) { + case 4: + length += get_ipv4_header_length(data); + protocol = *((data) + 9); // In IPv4, the protocol number is encoded in byte 9 + break; + case 6: + length += get_ipv6_header_length(data); + protocol = *((data) + 6); // In IPv6, the protocol number is encoded in byte 6 + break; + default: + break; + } + + // Layer 4: Transport + switch (protocol) { + case TCP: + length += get_tcp_header_length(data + length); + break; + case UDP: + length += get_udp_header_length(data + length); + break; + default: + break; + } + return length; +} + +/** + * @brief Retrieve the length of a UDP payload. + * + * @param data pointer to the start of the UDP header + * @return length of the UDP payload, in bytes + */ +uint16_t get_udp_payload_length(uint8_t *data) +{ + // The 16-bit length of the complete UDP datagram is encoded in bytes 4 and 5 of the UDP header. + // The length of the UDP header (8 bytes) must then be subtracted to obtain the length of the UDP payload. + return ntohs(*((uint16_t *) (data + 4))) - UDP_HEADER_LENGTH; +} + +/** + * @brief Retrieve the source port from a layer 4 header. + * + * @param data pointer to the start of the layer 4 header + * @return destination port + */ +uint16_t get_dst_port(uint8_t *data) { + // Source port is encoded in bytes 2 and 3 + return ntohs(*((uint16_t*) (data + 2))); +} + +/** + * @brief Retrieve the source address from an IPv4 header. + * + * @param data pointer to the start of the IPv4 header + * @return source IPv4 address, in network byte order + */ +uint32_t get_ipv4_src_addr(uint8_t *data) { + // Source address is encoded in bytes 12 to 15 + return *((uint32_t*) (data + 12)); +} + +/** + * @brief Retrieve the destination address from an IPv4 header. + * + * @param data pointer to the start of the IPv4 header + * @return destination IPv4 address, in network byte order + */ +uint32_t get_ipv4_dst_addr(uint8_t* data) { + // Destination address is encoded in bytes 16 to 19 + return *((uint32_t*) (data + 16)); +} + +/** + * @brief Retrieve the source address from an IPv6 header. + * + * @param data pointer to the start of the IPv6 header + * @return source IPv6 address, as a 16-byte array + */ +uint8_t* get_ipv6_src_addr(uint8_t *data) { + // Source address is encoded in bytes 8 to 23 + uint8_t *addr = (uint8_t *) malloc(IPV6_ADDR_LENGTH); + memcpy(addr, data + 8, IPV6_ADDR_LENGTH); + return addr; +} + +/** + * @brief Retrieve the destination address from an IPv6 header. + * + * @param data pointer to the start of the IPv6 header + * @return destination IPv6 address, as a 16-byte array + */ +uint8_t* get_ipv6_dst_addr(uint8_t *data) { + // Source address is encoded in bytes 24 to 39 + uint8_t *addr = (uint8_t *) malloc(IPV6_ADDR_LENGTH); + memcpy(addr, data + 24, IPV6_ADDR_LENGTH); + return addr; +} diff --git a/src/http.c b/src/http.c new file mode 100644 index 0000000000000000000000000000000000000000..5d0c002b5fcd3ccd8faf139b6e9e130b9d305505 --- /dev/null +++ b/src/http.c @@ -0,0 +1,229 @@ +/** + * @file src/http.c + * @brief HTTP message parser + * @date 2022-09-19 + * + * @copyright Copyright (c) 2022 + * + */ + +#include "http.h" + + +///// PARSING ///// + +/** + * @brief Parse the method of an HTTP message. + * + * Parse a HTTP message to retrieve its method, + * and convert it to a http_message_t. + * Only the two first characters need to be parsed. + * Advances the offset value after parsing. + * + * @param data pointer to the start of the HTTP message + * @param offset current offset in the message + * @return parsed HTTP method + */ +static http_method_t http_parse_method(uint8_t *data, uint16_t *offset) { + switch (*(data + *offset)) { + case 'G': + // Method is GET + *offset += 4; + return HTTP_GET; + break; + case 'H': + // Method is HEAD + *offset += 5; + return HTTP_HEAD; + break; + case 'P': + // Method is POST or PUT + switch (*(data + *offset + 1)) { + case 'O': + // Method is POST + *offset += 5; + return HTTP_POST; + break; + case 'U': + // Method is PUT + *offset += 4; + return HTTP_PUT; + break; + default: + // Unknown method + return HTTP_UNKNOWN; + } + case 'D': + // Method is DELETE + *offset += 7; + return HTTP_DELETE; + break; + case 'C': + // Method is CONNECT + *offset += 8; + return HTTP_CONNECT; + break; + case 'O': + // Method is OPTIONS + *offset += 8; + return HTTP_OPTIONS; + break; + case 'T': + // Method is TRACE + *offset += 6; + return HTTP_TRACE; + break; + default: + // Unknown method + return HTTP_UNKNOWN; + } +} + +/** + * @brief Check if a TCP message is a HTTP message. + * + * @param data pointer to the start of the TCP payload + * @param dst_port TCP destination port + * @return true if the message is a HTTP message + * @return false if the message is not a HTTP message + */ +bool is_http(uint8_t *data) +{ + uint16_t offset = 0; + return http_parse_method(data, &offset) != HTTP_UNKNOWN; +} + +/** + * @brief Parse an URI in an HTTP message. + * + * Parse a HTTP message to retrieve its URI, + * and convert it to a character string. + * Advances the offset value after parsing. + * + * @param data pointer to the start of the HTTP message + * @param offset current offset in the message + * @return parsed URI + */ +static char* http_parse_uri(uint8_t *data, uint16_t *offset) { + uint16_t length = 1; + uint16_t max_length = HTTP_METHOD_MAX_LEN; + char *uri = (char *) malloc(sizeof(char) * max_length); + while (*(data + *offset) != ' ') { + if (length == max_length) { + // URI is too long, increase buffer size + max_length *= 2; + void* realloc_ptr = realloc(uri, sizeof(char) * max_length); + if (realloc_ptr == NULL) { + // Handle realloc error + fprintf(stderr, "Error reallocating memory for URI %s\n", uri); + free(uri); + return NULL; + } else { + uri = (char*) realloc_ptr; + } + } + *(uri + (length - 1)) = *(data + (*offset)++); + length++; + } + if (length < max_length) { + // URI is shorter than allocated buffer, shrink buffer + void *realloc_ptr = realloc(uri, sizeof(char) * length); + if (realloc_ptr == NULL) { + fprintf(stderr, "Error shrinking memory for URI %s\n", uri); + } else { + uri = (char*) realloc_ptr; + } + } + // Add NULL terminating character + *(uri + length - 1) = '\0'; + return uri; +} + +/** + * @brief Parse the method and URI of HTTP message. + * + * @param data pointer to the start of the HTTP message + * @param dst_port TCP destination port + * @return the parsed HTTP message + */ +http_message_t http_parse_message(uint8_t *data, uint16_t dst_port) { + http_message_t message; + uint16_t offset = 0; + http_method_t http_method = http_parse_method(data, &offset); + message.is_request = dst_port == 80 && http_method != HTTP_UNKNOWN; + if (message.is_request) { + message.method = http_method; + message.uri = http_parse_uri(data, &offset); + } else { + message.method = HTTP_UNKNOWN; + message.uri = NULL; + } + return message; +} + + +///// DESTROY ///// + +/** + * @brief Free the memory allocated for a HTTP message. + * + * @param message the HTTP message to free + */ +void http_free_message(http_message_t message) { + if (message.uri != NULL) + free(message.uri); +} + + +///// PRINTING ///// + +/** + * @brief Converts a HTTP method from enum value to character string. + * + * @param method the HTTP method in enum value + * @return the same HTTP method as a character string + */ +char* http_method_to_str(http_method_t method) { + switch (method) { + case HTTP_GET: + return "GET"; + break; + case HTTP_HEAD: + return "HEAD"; + break; + case HTTP_POST: + return "POST"; + break; + case HTTP_PUT: + return "PUT"; + break; + case HTTP_DELETE: + return "DELETE"; + break; + case HTTP_CONNECT: + return "CONNECT"; + break; + case HTTP_OPTIONS: + return "OPTIONS"; + break; + case HTTP_TRACE: + return "TRACE"; + break; + default: + return "UNKNOWN"; + } +} + +/** + * @brief Print the method and URI of a HTTP message. + * + * @param message the message to print + */ +void http_print_message(http_message_t message) { + printf("HTTP message:\n"); + printf(" is request ?: %d\n", message.is_request); + if (message.is_request) { + printf(" Method: %s\n", http_method_to_str(message.method)); + printf(" URI: %s\n", message.uri); + } +} diff --git a/src/igmp.c b/src/igmp.c new file mode 100644 index 0000000000000000000000000000000000000000..18bf25d960a9bd1f4aeb7bd3c4e47d675fcf0335 --- /dev/null +++ b/src/igmp.c @@ -0,0 +1,177 @@ +/** + * @file src/igmp.c + * @brief IGMP message parser + * @date 2022-10-05 + * + * IGMP message parser. + * Supports v1 and v2, and v3 Membership Report messages. + * TODO: support v3 Membership Query messages. + * + * @copyright Copyright (c) 2022 + * + */ + +#include "igmp.h" + + +///// PARSING ///// + +/** + * @brief Parse an IGMPv2 message. + * + * @param data pointer to the start of the IGMPv2 message + * @return the parsed IGMPv2 message + */ +static igmp_v2_message_t igmp_v2_parse_message(uint8_t *data) { + igmp_v2_message_t message; + message.max_resp_time = *(data + 1); + message.checksum = ntohs(*((uint16_t *)(data + 2))); + message.group_address = *((uint32_t *)(data + 4)); // Stored in network byte order + return message; +} + +/** + * @brief Parse an array of IGMPv3 group records. + * + * @param num_groups number of group records + * @param data pointer to the start of the group records + * @return pointer to the array of parsed group records + */ +static igmp_v3_group_record_t* igmp_v3_parse_groups(uint16_t num_groups, uint8_t *data) { + // If num_groups is 0, group list is NULL + if (num_groups == 0) + return NULL; + + // num_groups is greater than 0 + igmp_v3_group_record_t *groups = malloc(num_groups * sizeof(igmp_v3_group_record_t)); + for (uint16_t i = 0; i < num_groups; i++) { + igmp_v3_group_record_t *group = groups + i; + group->type = *data; + group->aux_data_len = *(data + 1); + group->num_sources = ntohs(*((uint16_t *)(data + 2))); + group->group_address = *((uint32_t *)(data + 4)); // Stored in network byte order + if (group->num_sources > 0) { + group->sources = malloc(group->num_sources * sizeof(uint32_t)); + for (uint16_t j = 0; j < group->num_sources; j++) { + *((group->sources) + j) = *((uint32_t *)(data + 8 + j * 4)); // Stored in network byte order + } + } else { + group->sources = NULL; + } + data += 8 + group->num_sources * 4; + } + return groups; +} + +/** + * @brief Parse an IGMPv3 Membership Report message. + * + * @param data pointer to the start of the IGMPv3 Membership Report message + * @return the parsed IGMPv3 Membership Report message + */ +static igmp_v3_membership_report_t igmp_v3_parse_membership_report(uint8_t *data) { + igmp_v3_membership_report_t message; + message.checksum = ntohs(*((uint16_t *)(data + 2))); + message.num_groups = ntohs(*((uint16_t *)(data + 6))); + message.groups = igmp_v3_parse_groups(message.num_groups, data + 8); + return message; +} + +/** + * @brief Parse an IGMP message. + * + * @param data pointer to the start of the IGMP message + * @return the parsed IGMP message + */ +igmp_message_t igmp_parse_message(uint8_t *data) { + igmp_message_t message; + message.type = (igmp_message_type_t) *data; + // Dispatch on IGMP message type + switch (message.type) { + case MEMBERSHIP_QUERY: + case V1_MEMBERSHIP_REPORT: + case V2_MEMBERSHIP_REPORT: + case LEAVE_GROUP: + message.version = 2; + message.body.v2_message = igmp_v2_parse_message(data); + break; + case V3_MEMBERSHIP_REPORT: + message.version = 3; + message.body.v3_membership_report = igmp_v3_parse_membership_report(data); + break; + default: + break; + } + return message; +} + +/** + * @brief Free the memory allocated for an IGMP message. + * + * @param message the IGMP message to free + */ +void igmp_free_message(igmp_message_t message) { + if (message.version == 3 && message.body.v3_membership_report.num_groups > 0) { + for (uint16_t i = 0; i < message.body.v3_membership_report.num_groups; i++) + { + igmp_v3_group_record_t group = *(message.body.v3_membership_report.groups + i); + if (group.num_sources > 0) + free(group.sources); + } + free(message.body.v3_membership_report.groups); + } +} + + +///// PRINTING ///// + +/** + * @brief Print an IGMPv2 message. + * + * @param v2_message the IGMPv2 message to print + */ +static void igmp_v2_print_message(igmp_v2_message_t v2_message) { + printf(" Max resp time: %hhu\n", v2_message.max_resp_time); + printf(" Checksum: %#hx\n", v2_message.checksum); + printf(" Group address: %s\n", ipv4_net_to_str(v2_message.group_address)); +} + +/** + * @brief Print an IGMPv3 Membership Report message. + * + * @param group the IGMPv3 Membership Report message to print + */ +static void igmp_v3_print_membership_report(igmp_v3_membership_report_t v3_message) { + printf(" Checksum: %#hx\n", v3_message.checksum); + printf(" Number of groups: %hu\n", v3_message.num_groups); + for (uint16_t i = 0; i < v3_message.num_groups; i++) { + igmp_v3_group_record_t group = *(v3_message.groups + i); + printf(" Group %d:\n", i); + printf(" Type: %#hhx\n", group.type); + printf(" Aux data len: %hhu\n", group.aux_data_len); + printf(" Number of sources: %hu\n", group.num_sources); + printf(" Group address: %s\n", ipv4_net_to_str(group.group_address)); + for (uint16_t j = 0; j < group.num_sources; j++) { + printf(" Source %d: %s\n", j, ipv4_net_to_str(*(group.sources + j))); + } + } +} + +/** + * @brief Print an IGMP message. + * + * @param message the IGMP message to print + */ +void igmp_print_message(igmp_message_t message) { + printf("IGMP message:\n"); + printf(" Version: %hhu\n", message.version); + printf(" Type: %#hhx\n", message.type); + switch (message.version) { + case 2: + igmp_v2_print_message(message.body.v2_message); + break; + case 3: + igmp_v3_print_membership_report(message.body.v3_membership_report); + break; + } +} diff --git a/src/packet_utils.c b/src/packet_utils.c new file mode 100644 index 0000000000000000000000000000000000000000..8b0aec2358ec417e898759dfdd46ec602b86f90c --- /dev/null +++ b/src/packet_utils.c @@ -0,0 +1,294 @@ +/** + * @file src/packet_utils.c + * @brief Utilitaries for payload manipulation and display + * @date 2022-09-09 + * + * @copyright Copyright (c) 2022 + * + */ + +#include "packet_utils.h" + + +/** + * Print a packet payload. + * + * @param length length of the payload in bytes + * @param data pointer to the start of the payload + */ +void print_payload(int length, uint8_t *data) { + char trailing = ' '; + // Iterate on the whole payload + for (int i = 0; i < length; i++) { + if (i == length - 1) { + // Insert newline after last byte + trailing = '\n'; + } + + uint8_t c = *(data + i); + if (c == 0) { + printf("0x00%c", trailing); + } else { + printf("%#.2x%c", c, trailing); + } + } +} + +/** + * Converts a hexstring payload to a data buffer. + * + * @param hexstring the hexstring to convert + * @param payload a double pointer to the payload, which will be set to the start of the payload + * @return the length of the payload in bytes + */ +size_t hexstr_to_payload(char *hexstring, uint8_t **payload) { + size_t length = strlen(hexstring) / 2; // Size of the payload in bytes, one byte is two characters + *payload = (uint8_t *) malloc(length * sizeof(uint8_t)); // Allocate memory for the payload + + // WARNING: no sanitization or error-checking whatsoever + for (size_t count = 0; count < length; count++) { + sscanf(hexstring + 2*count, "%2hhx", (*payload) + count); // Convert two characters to one byte + } + + return length; +} + +/** + * Converts a MAC address from its hexadecimal representation + * to its string representation. + * + * @param mac_hex MAC address in hexadecimal representation + * @return the same MAC address in string representation + */ +char *mac_hex_to_str(uint8_t mac_hex[]) +{ + char *mac_str = (char *) malloc(MAC_ADDR_STRLEN * sizeof(char)); // A string representation of a MAC address is 17 characters long + null terminator + int ret = snprintf(mac_str, MAC_ADDR_STRLEN, "%02hhx:%02hhx:%02hhx:%02hhx:%02hhx:%02hhx", mac_hex[0], mac_hex[1], mac_hex[2], mac_hex[3], mac_hex[4], mac_hex[5]); + // Error handling + if (ret != MAC_ADDR_STRLEN - 1) + { + free(mac_str); + fprintf(stderr, "Error converting MAC address \\x%2x\\x%2x\\x%2x\\x%2x\\x%2x\\x%2x to string representation.\n", mac_hex[0], mac_hex[1], mac_hex[2], mac_hex[3], mac_hex[4], mac_hex[5]); + return NULL; + } + return mac_str; +} + +/** + * Converts a MAC address from its string representation + * to its hexadecimal representation. + * + * @param mac_str MAC address in string representation + * @return the same MAC address in hexadecimal representation + */ +uint8_t *mac_str_to_hex(char *mac_str) +{ + uint8_t *mac_hex = (uint8_t *) malloc(MAC_ADDR_LENGTH * sizeof(uint8_t)); // A MAC address is 6 bytes long + int ret = sscanf(mac_str, "%hhx:%hhx:%hhx:%hhx:%hhx:%hhx", mac_hex, mac_hex + 1, mac_hex + 2, mac_hex + 3, mac_hex + 4, mac_hex + 5); + // Error handling + if (ret != MAC_ADDR_LENGTH) + { + free(mac_hex); + fprintf(stderr, "Error converting MAC address %s to hexadecimal representation.\n", mac_str); + return NULL; + } + return mac_hex; +} + +/** + * Converts an IPv4 address from its network order numerical representation + * to its string representation. + * (Wrapper arount inet_ntoa) + * + * @param ipv4_net IPv4 address in hexadecimal representation + * @return the same IPv4 address in string representation + */ +char* ipv4_net_to_str(uint32_t ipv4_net) { + return inet_ntoa((struct in_addr) {ipv4_net}); +} + +/** + * Converts an IPv4 address from its string representation + * to its network order numerical representation. + * (Wrapper arount inet_aton) + * + * @param ipv4_str IPv4 address in string representation + * @return the same IPv4 address in network order numerical representation + */ +uint32_t ipv4_str_to_net(char *ipv4_str) { + struct in_addr ipv4_addr; + inet_aton(ipv4_str, &ipv4_addr); + return ipv4_addr.s_addr; +} + +/** + * Converts an IPv4 addres from its hexadecimal representation + * to its string representation. + * + * @param ipv4_hex IPv4 address in hexadecimal representation + * @return the same IPv4 address in string representation + */ +char* ipv4_hex_to_str(char *ipv4_hex) { + char* ipv4_str = (char *) malloc(INET_ADDRSTRLEN * sizeof(char)); // A string representation of an IPv4 address is at most 15 characters long + null terminator + int ret = snprintf(ipv4_str, INET_ADDRSTRLEN, "%hhu.%hhu.%hhu.%hhu", *ipv4_hex, *(ipv4_hex + 1), *(ipv4_hex + 2), *(ipv4_hex + 3)); + // Error handling + if (ret < 0) { + free(ipv4_str); + fprintf(stderr, "Error converting IPv4 address \\x%2x\\x%2x\\x%2x\\x%2x to string representation.\n", *ipv4_hex, *(ipv4_hex + 1), *(ipv4_hex + 2), *(ipv4_hex + 3)); + return NULL; + } + return ipv4_str; +} + +/** + * Converts an IPv4 address from its string representation + * to its hexadecimal representation. + * + * @param ipv4_str IPv4 address in string representation + * @return the same IPv4 address in hexadecimal representation + */ +char* ipv4_str_to_hex(char *ipv4_str) { + char* ipv4_hex = (char *) malloc(4 * sizeof(char)); // An IPv4 address is 4 bytes long + int ret = sscanf(ipv4_str, "%hhu.%hhu.%hhu.%hhu", ipv4_hex, ipv4_hex + 1, ipv4_hex + 2, ipv4_hex + 3); + // Error handling + if (ret != 4) { + free(ipv4_hex); + fprintf(stderr, "Error converting IPv4 address %s to hexadecimal representation.\n", ipv4_str); + return NULL; + } + return ipv4_hex; +} + +/** + * @brief Converts an IPv6 to its string representation. + * + * @param ipv6 the IPv6 address + * @return the same IPv6 address in string representation + */ +char* ipv6_net_to_str(uint8_t ipv6[]) { + char *ipv6_str = (char *) malloc(INET6_ADDRSTRLEN * sizeof(char)); + const char *ret = inet_ntop(AF_INET6, ipv6, ipv6_str, INET6_ADDRSTRLEN); + // Error handling + if (ret == NULL) { + fprintf(stderr, "Error converting IPv6 address \\x%2x\\x%2x\\x%2x\\x%2x\\x%2x\\x%2x\\x%2x\\x%2x\\x%2x\\x%2x\\x%2x\\x%2x\\x%2x\\x%2x\\x%2x\\x%2x to its string representation.\n", ipv6[0], ipv6[1], ipv6[2], ipv6[3], ipv6[4], ipv6[5], ipv6[6], ipv6[7], ipv6[8], ipv6[9], ipv6[10], ipv6[11], ipv6[12], ipv6[13], ipv6[14], ipv6[15]); + } + return ipv6_str; +} + +/** + * Converts an IPv6 address from its string representation + * to its network representation (a 16-byte array). + * + * @param ipv6_str IPv6 address in string representation + * @return the same IPv6 address as a 16-byte array + */ +uint8_t *ipv6_str_to_net(char *ipv6_str) { + uint8_t *ipv6 = (uint8_t *) malloc(IPV6_ADDR_LENGTH * sizeof(uint8_t)); // An IPv6 address is 16 bytes long + int err = inet_pton(AF_INET6, ipv6_str, ipv6); + // Error handling + if (err != 1) { + fprintf(stderr, "Error converting IPv6 address %s to its network representation.\n", ipv6_str); + return NULL; + } + return ipv6; +} + +/** + * @brief Converts an IP (v4 or v6) address to its string representation. + * + * Converts an IP (v4 or v6) address to its string representation. + * If it is an IPv6 address, it must be freed after use. + * + * @param ip_addr the IP address, as an ip_addr_t struct + * @return the same IP address in string representation + */ +char* ip_net_to_str(ip_addr_t ip_addr) { + switch (ip_addr.version) { + case 4: + return ipv4_net_to_str(ip_addr.value.ipv4); + break; + case 6: + return ipv6_net_to_str(ip_addr.value.ipv6); + break; + default: + fprintf(stderr, "Unknown IP version: %hhu.\n", ip_addr.version); + return ""; + } +} + +/** + * Converts an IP (v4 or v6) address from its string representation + * to an ip_addr_t struct. + * + * @param ip_str IP (v4 or v6) address in string representation + * @return the same IP address as a ip_addr_t struct + */ +ip_addr_t ip_str_to_net(char *ip_str, uint8_t version) { + ip_addr_t ip_addr; + ip_addr.version = version; + if (version == 4) { + ip_addr.value.ipv4 = ipv4_str_to_net(ip_str); + } else if (version == 6) { + uint8_t *ipv6_net = ipv6_str_to_net(ip_str); + memcpy(ip_addr.value.ipv6, ipv6_net, IPV6_ADDR_LENGTH); + free(ipv6_net); + } else { + fprintf(stderr, "Error converting address %s to ip_addr_t.\n", ip_str); + } + return ip_addr; +} + +/** + * @brief Compare two IPv6 addresses. + * + * @param ipv6_1 first IPv6 address + * @param ipv6_2 second IPv6 address + * @return true if the two addresses are equal, false otherwise + */ +bool compare_ipv6(uint8_t *ipv6_1, uint8_t *ipv6_2) { + return memcmp(ipv6_1, ipv6_2, 16) == 0; +} + +/** + * @brief Compare two IP (v4 or v6) addresses. + * + * @param ip_1 first IP address + * @param ip_2 second IP address + * @return true if the two addresses are equal, false otherwise + */ +bool compare_ip(ip_addr_t ip_1, ip_addr_t ip_2) { + if (ip_1.version == 4 && ip_2.version == 4) { + return ip_1.value.ipv4 == ip_2.value.ipv4; + } else if (ip_1.version == 6 && ip_2.version == 6) { + return compare_ipv6(ip_1.value.ipv6, ip_2.value.ipv6); + } else { + return false; + } +} + +/** + * @brief Compute SHA256 hash of a given payload. + * + * @param payload Payload to hash + * @param payload_len Payload length, including padding (in bytes) + * @return uint8_t* SHA256 hash of the payload + */ +uint8_t* compute_hash(uint8_t *payload, int payload_len) { + uint8_t *hash = (uint8_t *) malloc(SHA256_BLOCK_SIZE * sizeof(uint8_t)); + SHA256_CTX ctx; + sha256_init(&ctx); + sha256_update(&ctx, payload, payload_len); + sha256_final(&ctx, hash); + return hash; +} + +/** + * @brief Print a SHA256 hash. + * + * @param hash SHA256 hash to print + */ +void print_hash(uint8_t *hash) { + for (uint16_t i = 0; i < SHA256_BLOCK_SIZE; i++) { + printf("%02x", *(hash + i)); + } +} diff --git a/src/sha256.c b/src/sha256.c new file mode 100644 index 0000000000000000000000000000000000000000..eb9c5c0733e7a6234998e1dff49300c4b58e7d71 --- /dev/null +++ b/src/sha256.c @@ -0,0 +1,158 @@ +/********************************************************************* +* Filename: sha256.c +* Author: Brad Conte (brad AT bradconte.com) +* Copyright: +* Disclaimer: This code is presented "as is" without any guarantees. +* Details: Implementation of the SHA-256 hashing algorithm. + SHA-256 is one of the three algorithms in the SHA2 + specification. The others, SHA-384 and SHA-512, are not + offered in this implementation. + Algorithm specification can be found here: + * http://csrc.nist.gov/publications/fips/fips180-2/fips180-2withchangenotice.pdf + This implementation uses little endian byte order. +*********************************************************************/ + +/*************************** HEADER FILES ***************************/ +#include <stdlib.h> +#include <memory.h> +#include "sha256.h" + +/****************************** MACROS ******************************/ +#define ROTLEFT(a,b) (((a) << (b)) | ((a) >> (32-(b)))) +#define ROTRIGHT(a,b) (((a) >> (b)) | ((a) << (32-(b)))) + +#define CH(x,y,z) (((x) & (y)) ^ (~(x) & (z))) +#define MAJ(x,y,z) (((x) & (y)) ^ ((x) & (z)) ^ ((y) & (z))) +#define EP0(x) (ROTRIGHT(x,2) ^ ROTRIGHT(x,13) ^ ROTRIGHT(x,22)) +#define EP1(x) (ROTRIGHT(x,6) ^ ROTRIGHT(x,11) ^ ROTRIGHT(x,25)) +#define SIG0(x) (ROTRIGHT(x,7) ^ ROTRIGHT(x,18) ^ ((x) >> 3)) +#define SIG1(x) (ROTRIGHT(x,17) ^ ROTRIGHT(x,19) ^ ((x) >> 10)) + +/**************************** VARIABLES *****************************/ +static const WORD k[64] = { + 0x428a2f98,0x71374491,0xb5c0fbcf,0xe9b5dba5,0x3956c25b,0x59f111f1,0x923f82a4,0xab1c5ed5, + 0xd807aa98,0x12835b01,0x243185be,0x550c7dc3,0x72be5d74,0x80deb1fe,0x9bdc06a7,0xc19bf174, + 0xe49b69c1,0xefbe4786,0x0fc19dc6,0x240ca1cc,0x2de92c6f,0x4a7484aa,0x5cb0a9dc,0x76f988da, + 0x983e5152,0xa831c66d,0xb00327c8,0xbf597fc7,0xc6e00bf3,0xd5a79147,0x06ca6351,0x14292967, + 0x27b70a85,0x2e1b2138,0x4d2c6dfc,0x53380d13,0x650a7354,0x766a0abb,0x81c2c92e,0x92722c85, + 0xa2bfe8a1,0xa81a664b,0xc24b8b70,0xc76c51a3,0xd192e819,0xd6990624,0xf40e3585,0x106aa070, + 0x19a4c116,0x1e376c08,0x2748774c,0x34b0bcb5,0x391c0cb3,0x4ed8aa4a,0x5b9cca4f,0x682e6ff3, + 0x748f82ee,0x78a5636f,0x84c87814,0x8cc70208,0x90befffa,0xa4506ceb,0xbef9a3f7,0xc67178f2 +}; + +/*********************** FUNCTION DEFINITIONS ***********************/ +void sha256_transform(SHA256_CTX *ctx, const BYTE data[]) +{ + WORD a, b, c, d, e, f, g, h, i, j, t1, t2, m[64]; + + for (i = 0, j = 0; i < 16; ++i, j += 4) + m[i] = (data[j] << 24) | (data[j + 1] << 16) | (data[j + 2] << 8) | (data[j + 3]); + for ( ; i < 64; ++i) + m[i] = SIG1(m[i - 2]) + m[i - 7] + SIG0(m[i - 15]) + m[i - 16]; + + a = ctx->state[0]; + b = ctx->state[1]; + c = ctx->state[2]; + d = ctx->state[3]; + e = ctx->state[4]; + f = ctx->state[5]; + g = ctx->state[6]; + h = ctx->state[7]; + + for (i = 0; i < 64; ++i) { + t1 = h + EP1(e) + CH(e,f,g) + k[i] + m[i]; + t2 = EP0(a) + MAJ(a,b,c); + h = g; + g = f; + f = e; + e = d + t1; + d = c; + c = b; + b = a; + a = t1 + t2; + } + + ctx->state[0] += a; + ctx->state[1] += b; + ctx->state[2] += c; + ctx->state[3] += d; + ctx->state[4] += e; + ctx->state[5] += f; + ctx->state[6] += g; + ctx->state[7] += h; +} + +void sha256_init(SHA256_CTX *ctx) +{ + ctx->datalen = 0; + ctx->bitlen = 0; + ctx->state[0] = 0x6a09e667; + ctx->state[1] = 0xbb67ae85; + ctx->state[2] = 0x3c6ef372; + ctx->state[3] = 0xa54ff53a; + ctx->state[4] = 0x510e527f; + ctx->state[5] = 0x9b05688c; + ctx->state[6] = 0x1f83d9ab; + ctx->state[7] = 0x5be0cd19; +} + +void sha256_update(SHA256_CTX *ctx, const BYTE data[], size_t len) +{ + WORD i; + + for (i = 0; i < len; ++i) { + ctx->data[ctx->datalen] = data[i]; + ctx->datalen++; + if (ctx->datalen == 64) { + sha256_transform(ctx, ctx->data); + ctx->bitlen += 512; + ctx->datalen = 0; + } + } +} + +void sha256_final(SHA256_CTX *ctx, BYTE hash[]) +{ + WORD i; + + i = ctx->datalen; + + // Pad whatever data is left in the buffer. + if (ctx->datalen < 56) { + ctx->data[i++] = 0x80; + while (i < 56) + ctx->data[i++] = 0x00; + } + else { + ctx->data[i++] = 0x80; + while (i < 64) + ctx->data[i++] = 0x00; + sha256_transform(ctx, ctx->data); + memset(ctx->data, 0, 56); + } + + // Append to the padding the total message's length in bits and transform. + ctx->bitlen += ctx->datalen * 8; + ctx->data[63] = ctx->bitlen; + ctx->data[62] = ctx->bitlen >> 8; + ctx->data[61] = ctx->bitlen >> 16; + ctx->data[60] = ctx->bitlen >> 24; + ctx->data[59] = ctx->bitlen >> 32; + ctx->data[58] = ctx->bitlen >> 40; + ctx->data[57] = ctx->bitlen >> 48; + ctx->data[56] = ctx->bitlen >> 56; + sha256_transform(ctx, ctx->data); + + // Since this implementation uses little endian byte ordering and SHA uses big endian, + // reverse all the bytes when copying the final state to the output hash. + for (i = 0; i < 4; ++i) { + hash[i] = (ctx->state[0] >> (24 - i * 8)) & 0x000000ff; + hash[i + 4] = (ctx->state[1] >> (24 - i * 8)) & 0x000000ff; + hash[i + 8] = (ctx->state[2] >> (24 - i * 8)) & 0x000000ff; + hash[i + 12] = (ctx->state[3] >> (24 - i * 8)) & 0x000000ff; + hash[i + 16] = (ctx->state[4] >> (24 - i * 8)) & 0x000000ff; + hash[i + 20] = (ctx->state[5] >> (24 - i * 8)) & 0x000000ff; + hash[i + 24] = (ctx->state[6] >> (24 - i * 8)) & 0x000000ff; + hash[i + 28] = (ctx->state[7] >> (24 - i * 8)) & 0x000000ff; + } +} diff --git a/src/ssdp.c b/src/ssdp.c new file mode 100644 index 0000000000000000000000000000000000000000..48b0fbb61ffbdfedf917a4e9faedde2aa4b5e294 --- /dev/null +++ b/src/ssdp.c @@ -0,0 +1,91 @@ +/** + * @file src/ssdp.c + * @brief SSDP message parser + * @date 2022-11-24 + * + * @copyright Copyright (c) 2022 + * + */ + +#include "ssdp.h" + + +///// PARSING ///// + +/** + * @brief Parse the method of an SSDP message. + * + * Parse a SSDP message to retrieve its method, + * and convert it to a ssdp_message_t. + * Only the two first characters need to be parsed. + * Advances the offset value after parsing. + * + * @param data pointer to the start of the SSDP message + * @param offset current offset in the message + * @return parsed SSDP method + */ +static ssdp_method_t ssdp_parse_method(uint8_t *data, uint16_t *offset) { + switch (*(data + *offset)) { + case 'M': + // Method is M-SEARCH + *offset += 9; + return SSDP_M_SEARCH; + break; + case 'N': + // Method is NOTIFY + *offset += 7; + return SSDP_NOTIFY; + break; + default: + // Unknown method + return SSDP_UNKNOWN; + } +} + +/** + * @brief Parse the method and URI of SSDP message. + * + * @param data pointer to the start of the SSDP message + * @param dst_addr IPv4 destination address, in network byte order + * @return the parsed SSDP message + */ +ssdp_message_t ssdp_parse_message(uint8_t *data, uint32_t dst_addr) { + ssdp_message_t message; + message.is_request = dst_addr == ipv4_str_to_net(SSDP_MULTICAST_ADDR); + uint16_t offset = 0; + message.method = ssdp_parse_method(data, &offset); + return message; +} + + +///// PRINTING ///// + +/** + * @brief Converts a SSDP method from enum value to character string. + * + * @param method the SSDP method in enum value + * @return the same SSDP method as a character string + */ +char *ssdp_method_to_str(ssdp_method_t method) { + switch (method) { + case SSDP_M_SEARCH: + return "M-SEARCH"; + break; + case SSDP_NOTIFY: + return "NOTIFY"; + break; + default: + return "UNKNOWN"; + } +} + +/** + * @brief Print the method and URI of a SSDP message. + * + * @param message the message to print + */ +void ssdp_print_message(ssdp_message_t message) { + printf("SSDP message:\n"); + printf(" is request ?: %d\n", message.is_request); + printf(" Method: %s\n", ssdp_method_to_str(message.method)); +} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ff5f8e6f9f62d2aeff75b378b286146695722984 --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,82 @@ +# Minimum required CMake version +cmake_minimum_required(VERSION 3.20) + +# Set test output directory +set(TEST_BIN_DIR ${BIN_DIR}/test) +set(EXECUTABLE_OUTPUT_PATH ${TEST_BIN_DIR}) + +# Packet utils test +add_executable(packet_utils-test packet_utils.c) +target_include_directories(packet_utils-test PRIVATE ${INCLUDE_DIR}) +target_link_libraries(packet_utils-test cunit) +target_link_libraries(packet_utils-test packet_utils) +install(TARGETS packet_utils-test DESTINATION ${TEST_BIN_DIR}) + + +## Protocol parsers + +# Header +add_executable(header-test header.c) +target_include_directories(header-test PRIVATE ${INCLUDE_DIR}) +target_link_libraries(header-test cunit) +target_link_libraries(header-test packet_utils) +target_link_libraries(header-test header) +install(TARGETS header-test DESTINATION ${TEST_BIN_DIR}) + +# DNS +add_executable(dns-test dns.c) +target_include_directories(dns-test PRIVATE ${INCLUDE_DIR}) +target_link_libraries(dns-test cunit) +target_link_libraries(dns-test packet_utils) +target_link_libraries(dns-test header dns) +install(TARGETS dns-test DESTINATION ${TEST_BIN_DIR}) + +# DHCP +add_executable(dhcp-test dhcp.c) +target_include_directories(dhcp-test PRIVATE ${INCLUDE_DIR}) +target_link_libraries(dhcp-test cunit) +target_link_libraries(dhcp-test packet_utils) +target_link_libraries(dhcp-test header dhcp) +install(TARGETS dhcp-test DESTINATION ${TEST_BIN_DIR}) + +# HTTP +add_executable(http-test http.c) +target_include_directories(http-test PRIVATE ${INCLUDE_DIR}) +target_link_libraries(http-test cunit) +target_link_libraries(http-test packet_utils) +target_link_libraries(http-test header http) +install(TARGETS http-test DESTINATION ${TEST_BIN_DIR}) + +# IGMP +add_executable(igmp-test igmp.c) +target_include_directories(igmp-test PRIVATE ${INCLUDE_DIR}) +target_link_libraries(igmp-test cunit) +target_link_libraries(igmp-test packet_utils) +target_link_libraries(igmp-test header igmp) +install(TARGETS igmp-test DESTINATION ${TEST_BIN_DIR}) + +# SSDP +add_executable(ssdp-test ssdp.c) +target_include_directories(ssdp-test PRIVATE ${INCLUDE_DIR}) +target_link_libraries(ssdp-test cunit) +target_link_libraries(ssdp-test packet_utils) +target_link_libraries(ssdp-test header ssdp) +install(TARGETS ssdp-test DESTINATION ${TEST_BIN_DIR}) + +# CoAP +add_executable(coap-test coap.c) +target_include_directories(coap-test PRIVATE ${INCLUDE_DIR}) +target_link_libraries(coap-test cunit) +target_link_libraries(coap-test packet_utils) +target_link_libraries(coap-test header coap) +install(TARGETS coap-test DESTINATION ${TEST_BIN_DIR}) + + +## DNS map +add_executable(dns_map-test dns_map.c) +target_include_directories(dns_map-test PRIVATE ${INCLUDE_DIR}) +target_link_libraries(dns_map-test cunit) +target_link_libraries(dns_map-test hashmap) +target_link_libraries(dns_map-test packet_utils) +target_link_libraries(dns_map-test dns_map) +install(TARGETS dns_map-test DESTINATION ${TEST_BIN_DIR}) diff --git a/test/coap.c b/test/coap.c new file mode 100644 index 0000000000000000000000000000000000000000..59ff864efa580545d67890b4575c3e67f61b3777 --- /dev/null +++ b/test/coap.c @@ -0,0 +1,71 @@ +/** + * @file test/coap.c + * @brief Unit tests for the CoAP parser + * @date 2022-11-30 + * + * @copyright Copyright (c) 2022 + * + */ + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +// Custom libraries +#include "packet_utils.h" +#include "header.h" +#include "coap.h" +// CUnit +#include <CUnit/CUnit.h> +#include <CUnit/Basic.h> + + +/** + * @brief Unit test for the CoAP parser, using a Non-Confirmable GET message. + */ +void test_coap_non_get() { + + char *hexstring = "60017a1800451102fe80000000000000db22fbeca6b444feff0200000000000000000000000001588b5316330045c374580175f2d55892c87b38f0fbb36f6963037265734d1472743d782e636f6d2e73616d73756e672e70726f766973696f6e696e67696e666f213ce1fed6c0"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + + // Actual message + size_t skipped = get_ipv6_header_length(payload); + uint16_t coap_length = get_udp_payload_length(payload + skipped); + skipped += get_udp_header_length(payload + skipped); + coap_message_t actual = coap_parse_message(payload + skipped, coap_length); + free(payload); + //coap_print_message(actual); + + // Expected message + coap_message_t expected; + expected.type = COAP_NON; + expected.method = HTTP_GET; + expected.uri = "/oic/res?rt=x.com.samsung.provisioninginfo"; + + // Compare messages + CU_ASSERT_EQUAL(actual.type, expected.type); + CU_ASSERT_EQUAL(actual.method, expected.method); + CU_ASSERT_STRING_EQUAL(actual.uri, expected.uri); + + coap_free_message(actual); + +} + + +/** + * Main function for the unit tests. + */ +int main(int argc, char const *argv[]) +{ + // Initialize registry and suite + if (CU_initialize_registry() != CUE_SUCCESS) + return CU_get_error(); + CU_pSuite suite = CU_add_suite("coap", NULL, NULL); + // Run tests + CU_add_test(suite, "coap-non-get", test_coap_non_get); + CU_basic_run_tests(); + CU_cleanup_registry(); + return 0; +} diff --git a/test/dhcp.c b/test/dhcp.c new file mode 100644 index 0000000000000000000000000000000000000000..b81174c6575cc701da3003c89edbaa50d80bdf43 --- /dev/null +++ b/test/dhcp.c @@ -0,0 +1,254 @@ +/** + * @file test/dhcp.c + * @brief Unit tests for the DHCP parser + * @date 2022-09-12 + * + * @copyright Copyright (c) 2022 + * + */ + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +// Custom libraries +#include "packet_utils.h" +#include "header.h" +#include "dhcp.h" +// CUnit +#include <CUnit/CUnit.h> +#include <CUnit/Basic.h> + + +/** + * @brief Compare the headers of two DHCP messages. + * + * @param actual actual DHCP message + * @param expected expected DHCP message + */ +void compare_headers(dhcp_message_t actual, dhcp_message_t expected) { + CU_ASSERT_EQUAL(actual.op, expected.op); + CU_ASSERT_EQUAL(actual.htype, expected.htype); + CU_ASSERT_EQUAL(actual.hlen, expected.hlen); + CU_ASSERT_EQUAL(actual.hops, expected.hops); + CU_ASSERT_EQUAL(actual.xid, expected.xid); + CU_ASSERT_EQUAL(actual.secs, expected.secs); + CU_ASSERT_EQUAL(actual.flags, expected.flags); + CU_ASSERT_EQUAL(actual.ciaddr, expected.ciaddr); + CU_ASSERT_EQUAL(actual.yiaddr, expected.yiaddr); + CU_ASSERT_EQUAL(actual.siaddr, expected.siaddr); + CU_ASSERT_EQUAL(actual.giaddr, expected.giaddr); + for (uint8_t i = 0; i < MAX_HW_LEN; i++) { + CU_ASSERT_EQUAL(actual.chaddr[i], expected.chaddr[i]); + } +} + +/** + * @brief Compare two DHCP options lists. + * + * @param actual actual DHCP options list + * @param expected expected DHCP options list + */ +void compare_options(dhcp_options_t actual, dhcp_options_t expected) { + for (uint8_t i = 0; i < expected.count; i++) { + CU_ASSERT_EQUAL((actual.options + i)->code, (expected.options + i)->code); + CU_ASSERT_EQUAL((actual.options + i)->length, (expected.options + i)->length); + for (uint8_t j = 0; j < (actual.options + i)->length; j++) { + CU_ASSERT_EQUAL(*(((actual.options + i)->value) + j), *(((expected.options + i)->value) + j)); + } + } +} + +/** + * DHCP Unit test, with a DHCP Discover message. + */ +void test_dhcp_discover() { + char *hexstring = "4500014c00000000401179a200000000ffffffff004400430138dc40010106006617ca540000000000000000000000000000000000000000788b2ab220ea00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000638253633501013d0701788b2ab220ea3902024037070103060c0f1c2a3c0c756468637020312e32382e310c16636875616e676d695f63616d6572615f697063303139ff"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + + size_t skipped = get_headers_length(payload); + dhcp_message_t message = dhcp_parse_message(payload + skipped); + free(payload); + //dhcp_print_message(message); + + // Test different sections of the DHCP message + + // Header + dhcp_message_t expected; + expected.op = DHCP_BOOTREQUEST; + expected.htype = 1; + expected.hlen = 6; + expected.hops = 0; + expected.xid = 0x6617ca54; + expected.secs = 0; + expected.flags = 0x0000; + expected.ciaddr = ipv4_str_to_net("0.0.0.0"); + expected.yiaddr = ipv4_str_to_net("0.0.0.0"); + expected.siaddr = ipv4_str_to_net("0.0.0.0"); + expected.giaddr = ipv4_str_to_net("0.0.0.0"); + memcpy(expected.chaddr, "\x78\x8b\x2a\xb2\x20\xea\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", MAX_HW_LEN); + compare_headers(message, expected); + + // Options + expected.options.count = 7; + expected.options.options = (dhcp_option_t *) malloc(sizeof(dhcp_option_t) * expected.options.count); + // Option 53: DHCP Message Type + expected.options.options->code = 53; + expected.options.options->length = 1; + expected.options.options->value = (uint8_t *) malloc(sizeof(uint8_t) * expected.options.options->length); + *(expected.options.options->value) = DHCP_DISCOVER; + CU_ASSERT_EQUAL(message.options.message_type, DHCP_DISCOVER); + // Option 61: Client Identifier + (expected.options.options + 1)->code = 61; + (expected.options.options + 1)->length = 7; + (expected.options.options + 1)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 1)->length); + memcpy((expected.options.options + 1)->value, "\x01\x78\x8b\x2a\xb2\x20\xea", (expected.options.options + 1)->length); + // Option 57: Maximum DHCP Message Size + (expected.options.options + 2)->code = 57; + (expected.options.options + 2)->length = 2; + (expected.options.options + 2)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 2)->length); + memcpy((expected.options.options + 2)->value, "\x02\x40", (expected.options.options + 2)->length); + // Option 55: Parameter Request List + (expected.options.options + 3)->code = 55; + (expected.options.options + 3)->length = 7; + (expected.options.options + 3)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 3)->length); + memcpy((expected.options.options + 3)->value, "\x01\x03\x06\x0c\x0f\x1c\x2a", (expected.options.options + 3)->length); + // Option 60: Vendor Class Identifier + (expected.options.options + 4)->code = 60; + (expected.options.options + 4)->length = 12; + (expected.options.options + 4)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 4)->length); + memcpy((expected.options.options + 4)->value, "\x75\x64\x68\x63\x70\x20\x31\x2e\x32\x38\x2e\x31", (expected.options.options + 4)->length); + // Option 12: Host Name + (expected.options.options + 5)->code = 12; + (expected.options.options + 5)->length = 22; + (expected.options.options + 5)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 5)->length); + memcpy((expected.options.options + 5)->value, "\x63\x68\x75\x61\x6e\x67\x6d\x69\x5f\x63\x61\x6d\x65\x72\x61\x5f\x69\x70\x63\x30\x31\x39", (expected.options.options + 5)->length); + // Option 255: End + (expected.options.options + 6)->code = 255; + (expected.options.options + 6)->length = 0; + (expected.options.options + 6)->value = NULL; + // Compare and free options + compare_options(message.options, expected.options); + + // Free messages + dhcp_free_message(message); + dhcp_free_message(expected); +} + +/** + * DHCP Unit test, with a DHCP Offer message. + */ +void test_dhcp_offer() { + char *hexstring = "45c0014820a000004011d452c0a80101c0a801a10043004401341617020106006617ca540000000000000000c0a801a1c0a8010100000000788b2ab220ea00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000638253633501023604c0a8010133040000a8c03a04000054603b04000093a80104ffffff001c04c0a801ff0304c0a801010604c0a801010f036c616eff000000"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + + size_t skipped = get_headers_length(payload); + dhcp_message_t message = dhcp_parse_message(payload + skipped); + free(payload); + //dhcp_print_message(message); + + // Test different sections of the DHCP message + + // Header + dhcp_message_t expected; + expected.op = DHCP_BOOTREPLY; + expected.htype = 1; + expected.hlen = 6; + expected.hops = 0; + expected.xid = 0x6617ca54; + expected.secs = 0; + expected.flags = 0x0000; + expected.ciaddr = ipv4_str_to_net("0.0.0.0"); + expected.yiaddr = ipv4_str_to_net("192.168.1.161"); + expected.siaddr = ipv4_str_to_net("192.168.1.1"); + expected.giaddr = ipv4_str_to_net("0.0.0.0"); + memcpy(expected.chaddr, "\x78\x8b\x2a\xb2\x20\xea\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", MAX_HW_LEN); + compare_headers(message, expected); + + // Options + expected.options.count = 11; + expected.options.options = (dhcp_option_t *) malloc(sizeof(dhcp_option_t) * expected.options.count); + // Option 53: DHCP Message Type + expected.options.options->code = 53; + expected.options.options->length = 1; + expected.options.options->value = (uint8_t *) malloc(sizeof(uint8_t) * expected.options.options->length); + *(expected.options.options->value) = DHCP_OFFER; + CU_ASSERT_EQUAL(message.options.message_type, DHCP_OFFER); + // Option 54: Server Identifier + (expected.options.options + 1)->code = 54; + (expected.options.options + 1)->length = 4; + (expected.options.options + 1)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 1)->length); + memcpy((expected.options.options + 1)->value, "\xc0\xa8\x01\x01", (expected.options.options + 1)->length); + // Option 51: IP Address Lease Time + (expected.options.options + 2)->code = 51; + (expected.options.options + 2)->length = 4; + (expected.options.options + 2)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 2)->length); + memcpy((expected.options.options + 2)->value, "\x00\x00\xa8\xc0", (expected.options.options + 2)->length); + // Option 58: Renewal Time Value + (expected.options.options + 3)->code = 58; + (expected.options.options + 3)->length = 4; + (expected.options.options + 3)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 3)->length); + memcpy((expected.options.options + 3)->value, "\x00\x00\x54\x60", (expected.options.options + 3)->length); + // Option 59: Rebinding Time Value + (expected.options.options + 4)->code = 59; + (expected.options.options + 4)->length = 4; + (expected.options.options + 4)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 4)->length); + memcpy((expected.options.options + 4)->value, "\x00\x00\x93\xa8", (expected.options.options + 4)->length); + // Option 1: Subnet Mask + (expected.options.options + 5)->code = 1; + (expected.options.options + 5)->length = 4; + (expected.options.options + 5)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 5)->length); + memcpy((expected.options.options + 5)->value, "\xff\xff\xff\x00", (expected.options.options + 5)->length); + // Option 28: Broadcast Address + (expected.options.options + 6)->code = 28; + (expected.options.options + 6)->length = 4; + (expected.options.options + 6)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 6)->length); + memcpy((expected.options.options + 6)->value, "\xc0\xa8\x01\xff", (expected.options.options + 6)->length); + // Option 3: Router + (expected.options.options + 7)->code = 3; + (expected.options.options + 7)->length = 4; + (expected.options.options + 7)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 7)->length); + memcpy((expected.options.options + 7)->value, "\xc0\xa8\x01\x01", (expected.options.options + 7)->length); + // Option 6: Domain Name Server + (expected.options.options + 8)->code = 6; + (expected.options.options + 8)->length = 4; + (expected.options.options + 8)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 8)->length); + memcpy((expected.options.options + 8)->value, "\xc0\xa8\x01\x01", (expected.options.options + 8)->length); + // Option 15: Domain Name + (expected.options.options + 9)->code = 15; + (expected.options.options + 9)->length = 3; + (expected.options.options + 9)->value = (uint8_t *) malloc(sizeof(uint8_t) * (expected.options.options + 9)->length); + memcpy((expected.options.options + 9)->value, "\x6c\x61\x6e", (expected.options.options + 9)->length); + // Option 255: End + (expected.options.options + 10)->code = 255; + (expected.options.options + 10)->length = 0; + (expected.options.options + 10)->value = NULL; + // Compare and free options + compare_options(message.options, expected.options); + + // Free messages + dhcp_free_message(message); + dhcp_free_message(expected); +} + +/** + * Main function for the unit tests. + */ +int main(int argc, char const *argv[]) +{ + // Initialize registry and suite + if (CU_initialize_registry() != CUE_SUCCESS) + return CU_get_error(); + CU_pSuite suite = CU_add_suite("dhcp", NULL, NULL); + // Run tests + CU_add_test(suite, "dhcp-discover", test_dhcp_discover); + CU_add_test(suite, "dhcp-offer", test_dhcp_offer); + CU_basic_run_tests(); + CU_cleanup_registry(); + return 0; +} diff --git a/test/dns.c b/test/dns.c new file mode 100644 index 0000000000000000000000000000000000000000..80a311155f603c0392d3b612bccd3cb117651ef3 --- /dev/null +++ b/test/dns.c @@ -0,0 +1,333 @@ +/** + * @file test/dns.c + * @brief Unit tests for the DNS parser + * @date 2022-09-09 + * + * @copyright Copyright (c) 2022 + * + */ + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +// Custom libraries +#include "packet_utils.h" +#include "header.h" +#include "dns.h" +// CUnit +#include <CUnit/CUnit.h> +#include <CUnit/Basic.h> + + +/** + * Unit test for the header section of a DNS message. + * Verify that each header field is as expected. + */ +void compare_headers(dns_header_t actual, dns_header_t expected) { + CU_ASSERT_EQUAL(actual.id, expected.id); + CU_ASSERT_EQUAL(actual.flags, expected.flags); + CU_ASSERT_EQUAL(actual.qr, expected.qr); + CU_ASSERT_EQUAL(actual.qdcount, expected.qdcount); + CU_ASSERT_EQUAL(actual.ancount, expected.ancount); + CU_ASSERT_EQUAL(actual.nscount, expected.nscount); + CU_ASSERT_EQUAL(actual.arcount, expected.arcount); +} + +/** + * Unit test for the questions section + * of a DNS message. + */ +void compare_questions(uint16_t qdcount, dns_question_t *actual, dns_question_t *expected) { + for (int i = 0; i < qdcount; i++) { + CU_ASSERT_STRING_EQUAL((actual + i)->qname, (expected + i)->qname); + CU_ASSERT_EQUAL((actual + i)->qtype, (expected + i)->qtype); + CU_ASSERT_EQUAL((actual + i)->qclass, (expected + i)->qclass); + } +} + +/** + * Unit test for a resource records section + * of a DNS message. + */ +void compare_rrs(uint16_t count, dns_resource_record_t *actual, dns_resource_record_t *expected) { + for (int i = 0; i < count; i++) { + CU_ASSERT_STRING_EQUAL((actual + i)->name, (expected + i)->name); + CU_ASSERT_EQUAL((actual + i)->rtype, (expected + i)->rtype); + CU_ASSERT_EQUAL((actual + i)->rclass, (expected + i)->rclass); + CU_ASSERT_EQUAL((actual + i)->ttl, (expected + i)->ttl); + CU_ASSERT_EQUAL((actual + i)->rdlength, (expected + i)->rdlength); + CU_ASSERT_STRING_EQUAL( + dns_rdata_to_str((actual + i)->rtype, (actual + i)->rdlength, (actual + i)->rdata), + dns_rdata_to_str((expected + i)->rtype, (expected + i)->rdlength, (expected + i)->rdata) + ); + } +} + +/** + * Unit test for the DNS parser. + */ +void test_dns_xiaomi() { + + char *hexstring = "450000912ecc40004011879dc0a80101c0a801a10035a6b5007d76b46dca8180000100020000000008627573696e6573730b736d61727463616d6572610361706902696f026d6903636f6d0000010001c00c0005000100000258002516636e616d652d6170702d636f6d2d616d7370726f78790177066d692d64756e03636f6d00c04000010001000000930004142f61e7"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + + size_t skipped = get_headers_length(payload); + dns_message_t message = dns_parse_message(payload + skipped); + free(payload); + //dns_print_message(message); + + // Test different sections of the DNS message + + // Header + dns_header_t expected_header; + expected_header.id = 0x6dca; + expected_header.flags = 0x8180; + expected_header.qr = 1; + expected_header.qdcount = 1; + expected_header.ancount = 2; + expected_header.nscount = 0; + expected_header.arcount = 0; + compare_headers(message.header, expected_header); + + // Questions + dns_question_t *expected_question; + expected_question = malloc(sizeof(dns_question_t) * message.header.qdcount); + expected_question->qname = "business.smartcamera.api.io.mi.com"; + expected_question->qtype = 1; + expected_question->qclass = 1; + compare_questions(message.header.qdcount, message.questions, expected_question); + free(expected_question); + + // Answer resource records + dns_resource_record_t *expected_answer; + expected_answer = malloc(sizeof(dns_resource_record_t) * message.header.ancount); + // Answer n°0 + expected_answer->name = "business.smartcamera.api.io.mi.com"; + expected_answer->rtype = 5; + expected_answer->rclass = 1; + expected_answer->ttl = 600; + expected_answer->rdlength = 37; + expected_answer->rdata.domain_name = "cname-app-com-amsproxy.w.mi-dun.com"; + // Answer n°1 + (expected_answer + 1)->name = "cname-app-com-amsproxy.w.mi-dun.com"; + (expected_answer + 1)->rtype = 1; + (expected_answer + 1)->rclass = 1; + (expected_answer + 1)->ttl = 147; + (expected_answer + 1)->rdlength = 4; + (expected_answer + 1)->rdata.ip.version = 4; + (expected_answer + 1)->rdata.ip.value.ipv4 = ipv4_str_to_net("20.47.97.231"); + compare_rrs(message.header.ancount, message.answers, expected_answer); + free(expected_answer); + + + // Lookup functions + + // Search for domain name + char *domain_name = "business.smartcamera.api.io.mi.com"; + CU_ASSERT_TRUE(dns_contains_full_domain_name(message.questions, message.header.qdcount, domain_name)); + char *suffix = "api.io.mi.com"; + CU_ASSERT_TRUE(dns_contains_suffix_domain_name(message.questions, message.header.qdcount, suffix, strlen(suffix))); + domain_name = "www.example.org"; + CU_ASSERT_FALSE(dns_contains_full_domain_name(message.questions, message.header.qdcount, domain_name)); + suffix = "example.org"; + CU_ASSERT_FALSE(dns_contains_suffix_domain_name(message.questions, message.header.qdcount, suffix, strlen(suffix))); + + // Get question from domain name + domain_name = "business.smartcamera.api.io.mi.com"; + dns_question_t *question_lookup = dns_get_question(message.questions, message.header.qdcount, domain_name); + CU_ASSERT_PTR_NOT_NULL(question_lookup); + domain_name = "www.example.org"; + question_lookup = dns_get_question(message.questions, message.header.qdcount, domain_name); + CU_ASSERT_PTR_NULL(question_lookup); + + // Get IP addresses from domain name + domain_name = "business.smartcamera.api.io.mi.com"; + ip_list_t ip_list = dns_get_ip_from_name(message.answers, message.header.ancount, domain_name); + char *ip_address = "20.47.97.231"; + CU_ASSERT_EQUAL(ip_list.ip_count, 1); + CU_ASSERT_STRING_EQUAL(ipv4_net_to_str(ip_list.ip_addresses->value.ipv4), ip_address); + free(ip_list.ip_addresses); + domain_name = "www.example.org"; + ip_list = dns_get_ip_from_name(message.answers, message.header.ancount, domain_name); + CU_ASSERT_EQUAL(ip_list.ip_count, 0); + CU_ASSERT_PTR_NULL(ip_list.ip_addresses); + + // Free memory + dns_free_message(message); +} + +/** + * Unit test for the DNS parser. + */ +void test_dns_office() { + char *hexstring = "4500012a4aa900003e114737826801018268e4110035d7550116a82b3ebf81800001000900000001076f75746c6f6f6b066f666669636503636f6d0000010001c00c0005000100000007000c09737562737472617465c014c03000050001000000500017076f75746c6f6f6b096f666669636533363503636f6d00c0480005000100000093001a076f75746c6f6f6b026861096f666669636533363503636f6d00c06b000500010000000b001c076f75746c6f6f6b076d732d61636463066f666669636503636f6d00c091000500010000001b000a07414d532d65667ac099c0b90001000100000004000434619ea2c0b90001000100000004000428650c62c0b9000100010000000400042863cc22c0b9000100010000000400042865791200002904d0000000000000"; + + // Create payload from hexstring + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + size_t skipped = get_headers_length(payload); + dns_message_t message = dns_parse_message(payload + skipped); + free(payload); + //dns_print_message(message); + + // Test different sections of the DNS message + + // Header + dns_header_t expected_header; + expected_header.id = 0x3ebf; + expected_header.flags = 0x8180; + expected_header.qr = 1; + expected_header.qdcount = 1; + expected_header.ancount = 9; + expected_header.nscount = 0; + expected_header.arcount = 1; + compare_headers(message.header, expected_header); + + // Questions + dns_question_t *expected_question; + expected_question = malloc(sizeof(dns_question_t) * message.header.qdcount); + expected_question->qname = "outlook.office.com"; + expected_question->qtype = 1; + expected_question->qclass = 1; + compare_questions(message.header.qdcount, message.questions, expected_question); + free(expected_question); + + // Answer resource records + dns_resource_record_t *expected_answer; + expected_answer = malloc(sizeof(dns_resource_record_t) * message.header.ancount); + // Answer n°0 + expected_answer->name = "outlook.office.com"; + expected_answer->rtype = 5; + expected_answer->rclass = 1; + expected_answer->ttl = 7; + expected_answer->rdlength = 12; + expected_answer->rdata.domain_name = "substrate.office.com"; + // Answer n°1 + (expected_answer + 1)->name = "substrate.office.com"; + (expected_answer + 1)->rtype = 5; + (expected_answer + 1)->rclass = 1; + (expected_answer + 1)->ttl = 80; + (expected_answer + 1)->rdlength = 23; + (expected_answer + 1)->rdata.domain_name = "outlook.office365.com"; + // Answer n°2 + (expected_answer + 2)->name = "outlook.office365.com"; + (expected_answer + 2)->rtype = 5; + (expected_answer + 2)->rclass = 1; + (expected_answer + 2)->ttl = 147; + (expected_answer + 2)->rdlength = 26; + (expected_answer + 2)->rdata.domain_name = "outlook.ha.office365.com"; + // Answer n°3 + (expected_answer + 3)->name = "outlook.ha.office365.com"; + (expected_answer + 3)->rtype = 5; + (expected_answer + 3)->rclass = 1; + (expected_answer + 3)->ttl = 11; + (expected_answer + 3)->rdlength = 28; + (expected_answer + 3)->rdata.domain_name = "outlook.ms-acdc.office.com"; + // Answer n°4 + (expected_answer + 4)->name = "outlook.ms-acdc.office.com"; + (expected_answer + 4)->rtype = CNAME; + (expected_answer + 4)->rclass = 1; + (expected_answer + 4)->ttl = 27; + (expected_answer + 4)->rdlength = 10; + (expected_answer + 4)->rdata.domain_name = "AMS-efz.ms-acdc.office.com"; + // Answer n°5 + (expected_answer + 5)->name = "AMS-efz.ms-acdc.office.com"; + (expected_answer + 5)->rtype = A; + (expected_answer + 5)->rclass = 1; + (expected_answer + 5)->ttl = 4; + (expected_answer + 5)->rdlength = 4; + (expected_answer + 5)->rdata.ip.version = 4; + (expected_answer + 5)->rdata.ip.value.ipv4 = ipv4_str_to_net("52.97.158.162"); + // Answer n°6 + (expected_answer + 6)->name = "AMS-efz.ms-acdc.office.com"; + (expected_answer + 6)->rtype = A; + (expected_answer + 6)->rclass = 1; + (expected_answer + 6)->ttl = 4; + (expected_answer + 6)->rdlength = 4; + (expected_answer + 6)->rdata.ip.version = 4; + (expected_answer + 6)->rdata.ip.value.ipv4 = ipv4_str_to_net("40.101.12.98"); + // Answer n°7 + (expected_answer + 7)->name = "AMS-efz.ms-acdc.office.com"; + (expected_answer + 7)->rtype = A; + (expected_answer + 7)->rclass = 1; + (expected_answer + 7)->ttl = 4; + (expected_answer + 7)->rdlength = 4; + (expected_answer + 7)->rdata.ip.version = 4; + (expected_answer + 7)->rdata.ip.value.ipv4 = ipv4_str_to_net("40.99.204.34"); + // Answer n°8 + (expected_answer + 8)->name = "AMS-efz.ms-acdc.office.com"; + (expected_answer + 8)->rtype = A; + (expected_answer + 8)->rclass = 1; + (expected_answer + 8)->ttl = 4; + (expected_answer + 8)->rdlength = 4; + (expected_answer + 8)->rdata.ip.version = 4; + (expected_answer + 8)->rdata.ip.value.ipv4 = ipv4_str_to_net("40.101.121.18"); + // Compare and free answer + compare_rrs(message.header.ancount, message.answers, expected_answer); + free(expected_answer); + + + // Lookup functions + + // Search for domain name + char *domain_name = "outlook.office.com"; + CU_ASSERT_TRUE(dns_contains_full_domain_name(message.questions, message.header.qdcount, domain_name)); + char* suffix = "office.com"; + CU_ASSERT_TRUE(dns_contains_suffix_domain_name(message.questions, message.header.qdcount, suffix, strlen(suffix))); + domain_name = "www.example.org"; + CU_ASSERT_FALSE(dns_contains_full_domain_name(message.questions, message.header.qdcount, domain_name)); + suffix = "example.org"; + CU_ASSERT_FALSE(dns_contains_suffix_domain_name(message.questions, message.header.qdcount, suffix, strlen(suffix))); + + // Get question from domain name + domain_name = "outlook.office.com"; + dns_question_t *question_lookup = dns_get_question(message.questions, message.header.qdcount, domain_name); + CU_ASSERT_PTR_NOT_NULL(question_lookup); + domain_name = "www.example.org"; + question_lookup = dns_get_question(message.questions, message.header.qdcount, domain_name); + CU_ASSERT_PTR_NULL(question_lookup); + + // Get IP addresses from domain name + domain_name = "outlook.office.com"; + ip_list_t ip_list = dns_get_ip_from_name(message.answers, message.header.ancount, domain_name); + char* ip_addresses[] = { + "52.97.158.162", + "40.101.12.98", + "40.99.204.34", + "40.101.121.18" + }; + CU_ASSERT_EQUAL(ip_list.ip_count, 4); + for (uint8_t i = 0; i < 4; i++) { + CU_ASSERT_STRING_EQUAL(ipv4_net_to_str((ip_list.ip_addresses + i)->value.ipv4), ip_addresses[i]); + } + free(ip_list.ip_addresses); + domain_name = "www.example.org"; + ip_list = dns_get_ip_from_name(message.answers, message.header.ancount, domain_name); + CU_ASSERT_EQUAL(ip_list.ip_count, 0); + CU_ASSERT_PTR_NULL(ip_list.ip_addresses); + + // Free memory + dns_free_message(message); +} + +/** + * Main function for the unit tests. + */ +int main(int argc, char const *argv[]) +{ + // Initialize registry and suite + if (CU_initialize_registry() != CUE_SUCCESS) + return CU_get_error(); + printf("Test suite: dns\n"); + CU_pSuite suite = CU_add_suite("dns", NULL, NULL); + // Run tests + CU_add_test(suite, "dns-xiaomi", test_dns_xiaomi); + CU_add_test(suite, "dns-office", test_dns_office); + CU_basic_run_tests(); + CU_cleanup_registry(); + return 0; +} diff --git a/test/dns_map.c b/test/dns_map.c new file mode 100644 index 0000000000000000000000000000000000000000..2c367738dde37899de4e1d803d915edb2f32e6f3 --- /dev/null +++ b/test/dns_map.c @@ -0,0 +1,244 @@ +/** + * @file test/dns_map.c + * @brief Unit tests for the mapping structure from DNS domain names to IP addresses + * @date 2022-09-06 + * + * @copyright Copyright (c) 2022 + * + */ + +#include <string.h> +// Custom libraries +#include "hashmap.h" +#include "packet_utils.h" +#include "dns_map.h" +// CUnit +#include <CUnit/CUnit.h> +#include <CUnit/Basic.h> + +/** + * Test the creation of a DNS table. + */ +void test_dns_map_create() { + printf("test_dns_map_create\n"); + dns_map_t *table = dns_map_create(); + CU_ASSERT_PTR_NOT_NULL(table); + CU_ASSERT_EQUAL(hashmap_count(table), 0); + dns_map_free(table); +} + +/** + * Test operations on an empty DNS table. + */ +void test_dns_map_empty() { + printf("test_dns_map_empty\n"); + dns_map_t *table = dns_map_create(); + dns_entry_t* entry = dns_map_get(table, "www.google.com"); + CU_ASSERT_PTR_NULL(entry); + entry = dns_map_pop(table, "www.google.com"); + CU_ASSERT_PTR_NULL(entry); + dns_map_remove(table, "www.google.com"); // Does nothing, but should not crash + dns_map_free(table); +} + +/** + * Test adding and removing entries in a DNS table. + */ +void test_dns_map_add_remove() { + printf("test_dns_map_add_remove\n"); + dns_map_t *table = dns_map_create(); + + // Add IP addresses for www.google.com + ip_addr_t *google_ips = (ip_addr_t *) malloc(2 * sizeof(ip_addr_t)); + *google_ips = (ip_addr_t) {.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.1")}; + *(google_ips + 1) = (ip_addr_t) {.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.2")}; + ip_list_t ip_list_google = { .ip_count = 2, .ip_addresses = google_ips }; + dns_map_add(table, "www.google.com", ip_list_google); + CU_ASSERT_EQUAL(hashmap_count(table), 1); + + // Add IP addresses for www.example.com + ip_addr_t *example_ips = (ip_addr_t *) malloc(2 * sizeof(ip_addr_t)); + *example_ips = (ip_addr_t) {.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.3")}; + *(example_ips + 1) = (ip_addr_t) {.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.4")}; + ip_list_t ip_list_example = {.ip_count = 2, .ip_addresses = example_ips}; + dns_map_add(table, "www.example.com", ip_list_example); + CU_ASSERT_EQUAL(hashmap_count(table), 2); + + // Add a new IP address for www.google.com + ip_addr_t *google_ips_new = (ip_addr_t *) malloc(sizeof(ip_addr_t)); + *google_ips_new = (ip_addr_t) {.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.5")}; + ip_list_t ip_list_google_new = { .ip_count = 1, .ip_addresses = google_ips_new }; + dns_map_add(table, "www.google.com", ip_list_google_new); + CU_ASSERT_EQUAL(hashmap_count(table), 2); + + // Remove all IP addresses + dns_map_remove(table, "www.google.com"); + CU_ASSERT_EQUAL(hashmap_count(table), 1); + dns_map_remove(table, "www.example.com"); + CU_ASSERT_EQUAL(hashmap_count(table), 0); + dns_map_free(table); +} + +/** + * Test retrieving entries from a DNS table. + */ +void test_dns_map_get() { + printf("test_dns_map_get\n"); + dns_map_t *table = dns_map_create(); + + // Add IP addresses for www.google.com + ip_addr_t *google_ips = (ip_addr_t *)malloc(2 * sizeof(ip_addr_t)); + *google_ips = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.1")}; + *(google_ips + 1) = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.2")}; + ip_list_t ip_list_google = {.ip_count = 2, .ip_addresses = google_ips}; + dns_map_add(table, "www.google.com", ip_list_google); + + // Verify getting IP addresses for www.google.com + dns_entry_t *actual = dns_map_get(table, "www.google.com"); + CU_ASSERT_PTR_NOT_NULL(actual); + CU_ASSERT_EQUAL(actual->ip_list.ip_count, 2); + for (int i = 0; i < actual->ip_list.ip_count; i++) { + CU_ASSERT_TRUE(compare_ip(*(actual->ip_list.ip_addresses + i), *(google_ips + i))); + } + + // Add IP addresses for www.example.com + ip_addr_t *example_ips = (ip_addr_t *)malloc(2 * sizeof(ip_addr_t)); + *example_ips = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.3")}; + *(example_ips + 1) = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.4")}; + ip_list_t ip_list_example = {.ip_count = 2, .ip_addresses = example_ips}; + dns_map_add(table, "www.example.com", ip_list_example); + + // Verify getting IP addresses for www.example.com + actual = dns_map_get(table, "www.example.com"); + CU_ASSERT_PTR_NOT_NULL(actual); + CU_ASSERT_EQUAL(actual->ip_list.ip_count, 2); + for (int i = 0; i < actual->ip_list.ip_count; i++) { + CU_ASSERT_TRUE(compare_ip(*(actual->ip_list.ip_addresses + i), *(example_ips + i))); + } + + // Add a new IP address for www.google.com + ip_addr_t *google_ips_new = (ip_addr_t *)malloc(sizeof(ip_addr_t)); + *google_ips_new = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.5")}; + ip_list_t ip_list_google_new = {.ip_count = 1, .ip_addresses = google_ips_new}; + dns_map_add(table, "www.google.com", ip_list_google_new); + + // Verify getting IP addresses for www.google.com + actual = dns_map_get(table, "www.google.com"); + CU_ASSERT_PTR_NOT_NULL(actual); + CU_ASSERT_EQUAL(actual->ip_list.ip_count, 3); + ip_addr_t *google_all_ips = (ip_addr_t *)malloc(3 * sizeof(ip_addr_t)); + *google_all_ips = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.1")}; + *(google_all_ips + 1) = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.2")}; + *(google_all_ips + 2) = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.5")}; + for (int i = 0; i < actual->ip_list.ip_count; i++) + { + CU_ASSERT_TRUE(compare_ip(*(actual->ip_list.ip_addresses + i), *(google_all_ips + i))); + } + + free(google_all_ips); + dns_map_free(table); +} + +/** + * Test popping entries from a DNS table. + */ +void test_dns_map_pop() { + printf("test_dns_map_pop\n"); + dns_map_t *table = dns_map_create(); + + // Add IP addresses for www.google.com + ip_addr_t *google_ips = (ip_addr_t *)malloc(2 * sizeof(ip_addr_t)); + *google_ips = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.1")}; + *(google_ips + 1) = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.2")}; + ip_list_t ip_list_google = {.ip_count = 2, .ip_addresses = google_ips}; + dns_map_add(table, "www.google.com", ip_list_google); + + // Add IP addresses for www.example.com + ip_addr_t *example_ips = (ip_addr_t *)malloc(2 * sizeof(ip_addr_t)); + *example_ips = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.3")}; + *(example_ips + 1) = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.4")}; + ip_list_t ip_list_example = {.ip_count = 2, .ip_addresses = example_ips}; + dns_map_add(table, "www.example.com", ip_list_example); + + // Verify popping IP addresses for www.google.com + dns_entry_t *actual = dns_map_pop(table, "www.google.com"); + CU_ASSERT_PTR_NOT_NULL(actual); + CU_ASSERT_EQUAL(actual->ip_list.ip_count, 2); + for (int i = 0; i < actual->ip_list.ip_count; i++) + { + CU_ASSERT_TRUE(compare_ip(*(actual->ip_list.ip_addresses + i), *(google_ips + i))); + } + free(actual->ip_list.ip_addresses); + CU_ASSERT_EQUAL(hashmap_count(table), 1); + actual = dns_map_pop(table, "www.google.com"); + CU_ASSERT_PTR_NULL(actual); + + // Verify popping IP addresses for www.example.com + actual = dns_map_pop(table, "www.example.com"); + CU_ASSERT_PTR_NOT_NULL(actual); + CU_ASSERT_EQUAL(actual->ip_list.ip_count, 2); + for (int i = 0; i < actual->ip_list.ip_count; i++) + { + CU_ASSERT_TRUE(compare_ip(*(actual->ip_list.ip_addresses + i), *(example_ips + i))); + } + free(actual->ip_list.ip_addresses); + CU_ASSERT_EQUAL(hashmap_count(table), 0); + actual = dns_map_pop(table, "www.example.com"); + CU_ASSERT_PTR_NULL(actual); + + dns_map_free(table); +} + +/** + * Test printing entries from a DNS table. + */ +void test_dns_entry_print() { + printf("test_dns_entry_print\n"); + dns_map_t *table = dns_map_create(); + + // Add IP addresses for www.google.com + ip_addr_t *google_ips = (ip_addr_t *)malloc(2 * sizeof(ip_addr_t)); + *google_ips = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.1")}; + *(google_ips + 1) = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.2")}; + ip_list_t ip_list_google = {.ip_count = 2, .ip_addresses = google_ips}; + dns_map_add(table, "www.google.com", ip_list_google); + + // Add IP addresses for www.example.com + ip_addr_t *example_ips = (ip_addr_t *)malloc(2 * sizeof(ip_addr_t)); + *example_ips = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.3")}; + *(example_ips + 1) = (ip_addr_t){.version = 4, .value.ipv4 = ipv4_str_to_net("192.168.1.4")}; + ip_list_t ip_list_example = {.ip_count = 2, .ip_addresses = example_ips}; + dns_map_add(table, "www.example.com", ip_list_example); + + // Print entries + dns_entry_t *dns_entry = dns_map_get(table, "www.google.com"); + dns_entry_print(dns_entry); + dns_entry = dns_map_get(table, "www.example.com"); + dns_entry_print(dns_entry); + + // Destroy DNS table + dns_map_free(table); +} + + +/** + * Test suite entry point. + */ +int main(int argc, char const *argv[]) +{ + // Initialize the CUnit test registry and suite + printf("Test suite: dns_map\n"); + if (CU_initialize_registry() != CUE_SUCCESS) + return CU_get_error(); + CU_pSuite suite = CU_add_suite("dns_map", NULL, NULL); + // Add and run tests + CU_add_test(suite, "dns_map_create", test_dns_map_create); + CU_add_test(suite, "dns_map_empty", test_dns_map_empty); + CU_add_test(suite, "dns_map_add_remove", test_dns_map_add_remove); + CU_add_test(suite, "dns_map_get", test_dns_map_get); + CU_add_test(suite, "dns_map_pop", test_dns_map_pop); + CU_add_test(suite, "dns_entry_print", test_dns_entry_print); + CU_basic_run_tests(); + CU_cleanup_registry(); + return 0; +} diff --git a/test/header.c b/test/header.c new file mode 100644 index 0000000000000000000000000000000000000000..a40d1e1aef2644c0327a10b4189147a46ea66616 --- /dev/null +++ b/test/header.c @@ -0,0 +1,165 @@ +/** + * @file test/header.c + * @brief Unit test for the header parser (OSI layers 3 and 4) + * @date 2022-12-01 + * + * @copyright Copyright (c) 2022 + * + */ + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +// Custom libraries +#include "packet_utils.h" +#include "header.h" +// CUnit +#include <CUnit/CUnit.h> +#include <CUnit/Basic.h> + + +/** + * @brief Unit test using a TCP SYN packet. + */ +void test_tcp_syn() { + + char *hexstring = "4500003cbcd2400040066e0fc0a801966c8ae111c67f005004f77abb00000000a002ffff2b380000020405b40402080a0003c6690000000001030306"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify packet length + + // Layer-3 header length + uint16_t l3_header_length = get_l3_header_length(payload); + CU_ASSERT_EQUAL(l3_header_length, 20); + + // IPv4 destination address + uint32_t ipv4_src_addr = get_ipv4_src_addr(payload); + CU_ASSERT_STRING_EQUAL(ipv4_net_to_str(ipv4_src_addr), "192.168.1.150"); + + // IPv4 destination address + uint32_t ipv4_dst_addr = get_ipv4_dst_addr(payload); + CU_ASSERT_STRING_EQUAL(ipv4_net_to_str(ipv4_dst_addr), "108.138.225.17"); + + // TCP header length + uint16_t tcp_header_length = get_tcp_header_length(payload + l3_header_length); + CU_ASSERT_EQUAL(tcp_header_length, 40); + + // Layers 3 and 4 headers length + uint16_t headers_length = get_headers_length(payload); + CU_ASSERT_EQUAL(headers_length, 20 + 40); + + // Destination port + uint16_t dst_port = get_dst_port(payload + l3_header_length); + CU_ASSERT_EQUAL(dst_port, 80); + + // Contains payload ? + CU_ASSERT_FALSE(length - headers_length > 0); + + free(payload); +} + +/** + * @brief Unit test using an HTTPS data packet. + */ +void test_https_data() { + + char *hexstring = "450001613b64400040067977c0a801dec0a8018d8da801bbec035d653f25b250501808065ff2000017030301340000000000000087884ca5c237291279d20249e09c2848a56615a0fda66e788fdc5a04cb96d7be52b00302e4956118ec87e74ad1e3e20192689876cc821e6c95087fbc160163edd6a48b5f1f06752e3b0b0ee4c9c1f208508ba36fd57499c3a1d95805f33a5e5b89edb06e8b70615eb3f531a375537674e298b7692d78bd5e407738597097285a1205a2d3f4ba183bbd7f609ec1a9464934dd9999b8955c6a537a28a03118ac8a3391fdc378413bfcacba2a3995f54b45ea05126f1d906bbad2629a8d16e88b531f2d047a7f8b5199c5db819f76eac6d83e1e428b97b71721f3280e4eab6fb1c10dd58dfad004d11061aff1ee559c4704930a4dac9e33f32707f80823438990457dafdd5d325dda22f2fab0863cbbb45cafc11c5209370e23d5bc779506f5621d75afa003932c8bdb72ff5f9a2f"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify packet length + + // Layer-3 header length + uint16_t l3_header_length = get_l3_header_length(payload); + CU_ASSERT_EQUAL(l3_header_length, 20); + + // IPv4 destination address + uint32_t ipv4_src_addr = get_ipv4_src_addr(payload); + CU_ASSERT_STRING_EQUAL(ipv4_net_to_str(ipv4_src_addr), "192.168.1.222"); + + // IPv4 destination address + uint32_t ipv4_dst_addr = get_ipv4_dst_addr(payload); + CU_ASSERT_STRING_EQUAL(ipv4_net_to_str(ipv4_dst_addr), "192.168.1.141"); + + // TCP header length + uint16_t tcp_header_length = get_tcp_header_length(payload + l3_header_length); + CU_ASSERT_EQUAL(tcp_header_length, 20); + + // Layers 3 and 4 headers length + uint16_t headers_length = get_headers_length(payload); + CU_ASSERT_EQUAL(headers_length, 20 + 20); + + // Destination port + uint16_t dst_port = get_dst_port(payload + l3_header_length); + CU_ASSERT_EQUAL(dst_port, 443); + + // Contains payload ? + CU_ASSERT_TRUE(length - headers_length > 0); + + free(payload); +} + +/** + * @brief Unit test using a DNS message over IPv6. + */ +void test_dns_ipv6() { + + char *hexstring = "6002ec1b002d1140fddded18f05b0000d8a3adc0f68fe5cffddded18f05b00000000000000000001b0f20035002d5388ac4a01000001000000000000036170690b736d6172747468696e677303636f6d00001c0001"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify packet length + + // Layer-3 header length + uint16_t l3_header_length = get_l3_header_length(payload); + CU_ASSERT_EQUAL(l3_header_length, IPV6_HEADER_LENGTH); + + // IPv6 source address + uint8_t *ipv6_src_addr = get_ipv6_src_addr(payload); + uint8_t expected_src[IPV6_ADDR_LENGTH] = {0xfd, 0xdd, 0xed, 0x18, 0xf0, 0x5b, 0x00, 0x00, 0xd8, 0xa3, 0xad, 0xc0, 0xf6, 0x8f, 0xe5, 0xcf}; + CU_ASSERT_TRUE(compare_ipv6(ipv6_src_addr, expected_src)); + free(ipv6_src_addr); + + // IPv6 destination address + uint8_t *ipv6_dst_addr = get_ipv6_dst_addr(payload); + uint8_t expected_dst[IPV6_ADDR_LENGTH] = {0xfd, 0xdd, 0xed, 0x18, 0xf0, 0x5b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}; + CU_ASSERT_TRUE(compare_ipv6(ipv6_dst_addr, expected_dst)); + free(ipv6_dst_addr); + + // UDP header length + uint16_t udp_header_length = get_udp_header_length(payload + l3_header_length); + CU_ASSERT_EQUAL(udp_header_length, UDP_HEADER_LENGTH); + + // Layers 3 and 4 headers length + uint16_t headers_length = get_headers_length(payload); + CU_ASSERT_EQUAL(headers_length, IPV6_HEADER_LENGTH + UDP_HEADER_LENGTH); + + // Destination port + uint16_t dst_port = get_dst_port(payload + l3_header_length); + CU_ASSERT_EQUAL(dst_port, 53); + + // UDP payload length + uint16_t udp_payload_length = get_udp_payload_length(payload + l3_header_length); + CU_ASSERT_EQUAL(udp_payload_length, 45 - UDP_HEADER_LENGTH); + + free(payload); +} + +/** + * Driver function for the unit tests. + */ +int main(int argc, char const *argv[]) +{ + // Initialize registry and suite + if (CU_initialize_registry() != CUE_SUCCESS) + return CU_get_error(); + CU_pSuite suite = CU_add_suite("header", NULL, NULL); + // Run tests + CU_add_test(suite, "tcp-syn", test_tcp_syn); + CU_add_test(suite, "https-data", test_https_data); + CU_add_test(suite, "dns-ipv6", test_dns_ipv6); + CU_basic_run_tests(); + CU_cleanup_registry(); + return 0; +} diff --git a/test/http.c b/test/http.c new file mode 100644 index 0000000000000000000000000000000000000000..d740b0c7d2ab8dde1d4abc077865a6262f841b5d --- /dev/null +++ b/test/http.c @@ -0,0 +1,92 @@ +/** + * @file test/http.c + * @brief Unit tests for the HTTP parser + * @date 2022-20-09 + * + * @copyright Copyright (c) 2022 + * + */ + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +// Custom libraries +#include "packet_utils.h" +#include "header.h" +#include "http.h" +// CUnit +#include <CUnit/CUnit.h> +#include <CUnit/Basic.h> + + +/** + * @brief Unit test for the HTTP parser. + */ +void test_http_request() { + + char *hexstring = "450000ccb11f400040065845c0a801a16e2b005387b8005023882026a6ab695450180e4278860000474554202f67736c623f747665723d322669643d33363932313536313726646d3d6f74732e696f2e6d692e636f6d2674696d657374616d703d38267369676e3d6a327a743325324270624177637872786f765155467443795a3644556d47706c584e4b723169386a746552623425334420485454502f312e310d0a486f73743a20646e732e696f2e6d692e636f6d0d0a557365722d4167656e743a204d496f540d0a0d0a"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + + size_t l3_header_length = get_l3_header_length(payload); + uint16_t dst_port = get_dst_port(payload + l3_header_length); + size_t skipped = get_headers_length(payload); + http_message_t actual = http_parse_message(payload + skipped, dst_port); + free(payload); + //http_print_message(actual); + + // Test if HTTP message has been correctly parsed + http_message_t expected; + expected.is_request = true; + expected.method = HTTP_GET; + expected.uri = "/gslb?tver=2&id=369215617&dm=ots.io.mi.com×tamp=8&sign=j2zt3%2BpbAwcxrxovQUFtCyZ6DUmGplXNKr1i8jteRb4%3D"; + CU_ASSERT_EQUAL(actual.is_request, expected.is_request); + CU_ASSERT_EQUAL(actual.method, expected.method); + CU_ASSERT_STRING_EQUAL(actual.uri, expected.uri); + + http_free_message(actual); + +} + +void test_http_response() { + + char *hexstring = "450001a42fc540002f06e9c76e2b0053c0a801a1005087b8a6ab6954238820ca501803b8e92e0000485454502f312e3120323030204f4b0d0a5365727665723a2054656e67696e650d0a446174653a205765642c203330204d617220323032322031323a30353a323420474d540d0a436f6e74656e742d547970653a206170706c69636174696f6e2f6a736f6e3b20636861727365743d7574662d380d0a436f6e74656e742d4c656e6774683a203231350d0a436f6e6e656374696f6e3a206b6565702d616c6976650d0a0d0a7b22696e666f223a7b22656e61626c65223a312c22686f73745f6c697374223a5b7b226970223a223132302e39322e39362e313535222c22706f7274223a3434337d2c7b226970223a223132302e39322e3134352e313430222c22706f7274223a3434337d2c7b226970223a223132302e39322e36352e323431222c22706f7274223a3434337d5d7d2c227369676e223a225a757856496a2b337858303362654a4b5936684e385668454f7a65485630446a6753654471656d2b7032413d222c2274696d657374616d70223a313634383634313932347d"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + + size_t skipped = get_ipv4_header_length(payload); + uint16_t dst_port = get_dst_port(payload + skipped); + skipped += get_tcp_header_length(payload + skipped); + http_message_t actual = http_parse_message(payload + skipped, dst_port); + free(payload); + //http_print_message(actual); + + // Test if HTTP message has been correctly parsed + http_message_t expected; + expected.is_request = false; + CU_ASSERT_EQUAL(actual.is_request, expected.is_request); + + http_free_message(actual); + +} + +/** + * Driver function for the unit tests. + */ +int main(int argc, char const *argv[]) +{ + // Initialize registry and suite + if (CU_initialize_registry() != CUE_SUCCESS) + return CU_get_error(); + CU_pSuite suite = CU_add_suite("http", NULL, NULL); + // Run tests + CU_add_test(suite, "http-request", test_http_request); + CU_add_test(suite, "http-response", test_http_response); + CU_basic_run_tests(); + CU_cleanup_registry(); + return 0; +} diff --git a/test/igmp.c b/test/igmp.c new file mode 100644 index 0000000000000000000000000000000000000000..291cbf9e07225d04295d41b2cdc99706861f6eaf --- /dev/null +++ b/test/igmp.c @@ -0,0 +1,195 @@ +/** + * @file test/igmp.c + * @brief Unit tests for the IGMP parser + * @date 2022-10-05 + * + * @copyright Copyright (c) 2022 + * + */ + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +// Custom libraries +#include "packet_utils.h" +#include "header.h" +#include "igmp.h" +// CUnit +#include <CUnit/CUnit.h> +#include <CUnit/Basic.h> + + +/** + * @brief Compare two IGMPv2 messages. + * + * @param actual actual IGMPv2 message + * @param expected expected IGMPv2 message + */ +void compare_igmp_v2_messages(igmp_v2_message_t actual, igmp_v2_message_t expected) { + CU_ASSERT_EQUAL(actual.max_resp_time, expected.max_resp_time); + CU_ASSERT_EQUAL(actual.checksum, expected.checksum); + CU_ASSERT_EQUAL(actual.group_address, expected.group_address); +} + +/** + * @brief Compare two IGMPv3 Membership Report messages. + * + * @param actual actual IGMPv3 Membership Report message + * @param expected expected IGMPv3 Membership Report message + */ +void compare_igmp_v3_messages(igmp_v3_membership_report_t actual, igmp_v3_membership_report_t expected) { + CU_ASSERT_EQUAL(actual.checksum, expected.checksum); + CU_ASSERT_EQUAL(actual.num_groups, expected.num_groups); + for (uint16_t i = 0; i < actual.num_groups; i++) { + igmp_v3_group_record_t actual_group = *(actual.groups + i); + igmp_v3_group_record_t expected_group = *(expected.groups + i); + CU_ASSERT_EQUAL(actual_group.type, expected_group.type); + CU_ASSERT_EQUAL(actual_group.aux_data_len, expected_group.aux_data_len); + CU_ASSERT_EQUAL(actual_group.num_sources, expected_group.num_sources); + CU_ASSERT_EQUAL(actual_group.group_address, expected_group.group_address); + for (uint16_t j = 0; j < actual_group.num_sources; j++) { + CU_ASSERT_EQUAL(*(actual_group.sources + j), *(expected_group.sources + j)); + } + } +} + +/** + * @brief Compare two IGMP messages. + * + * @param actual actual IGMP message + * @param expected expected IGMP message + */ +void compare_igmp_messages(igmp_message_t actual, igmp_message_t expected) +{ + CU_ASSERT_EQUAL(actual.version, expected.version); + if (actual.version != expected.version) + return; + + CU_ASSERT_EQUAL(actual.type, expected.type); + if (actual.type != expected.type) + return; + + switch (actual.version) + { + case 2: + compare_igmp_v2_messages(actual.body.v2_message, expected.body.v2_message); + break; + case 3: + compare_igmp_v3_messages(actual.body.v3_membership_report, expected.body.v3_membership_report); + break; + default: + CU_FAIL("Unknown IGMP version"); + } +} + +/** + * @brief Unit test with an IGMPv2 Membership Report message. + */ +void test_igmp_v2_membership_report() { + + char *hexstring = "46c000200000400001024096c0a801dee00000fb9404000016000904e00000fb"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + + size_t skipped = get_headers_length(payload); + igmp_message_t actual = igmp_parse_message(payload + skipped); + free(payload); + //igmp_print_message(actual); + + // Expected message + igmp_message_t expected; + expected.version = 2; + expected.type = V2_MEMBERSHIP_REPORT; + expected.body.v2_message.max_resp_time = 0; + expected.body.v2_message.checksum = 0x0904; + expected.body.v2_message.group_address = ipv4_str_to_net("224.0.0.251"); + + // Compare messages + compare_igmp_messages(actual, expected); + +} + +/** + * @brief Unit test with an IGMPv2 Leave Group message. + */ +void test_igmp_v2_leave_group() { + + char *hexstring = "46c00020000040000102418fc0a801dee00000029404000017000804e00000fb"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + + size_t skipped = get_headers_length(payload); + igmp_message_t actual = igmp_parse_message(payload + skipped); + free(payload); + //igmp_print_message(actual); + + // Expected message + igmp_message_t expected; + expected.version = 2; + expected.type = LEAVE_GROUP; + expected.body.v2_message.max_resp_time = 0; + expected.body.v2_message.checksum = 0x0804; + expected.body.v2_message.group_address = ipv4_str_to_net("224.0.0.251"); + + // Compare messages + compare_igmp_messages(actual, expected); + +} + +/** + * @brief Unit test with an IGMPv3 Membership Report message. + */ +void test_igmp_v3_membership_report() { + + char *hexstring = "46c0002800004000010241dec0a80173e0000016940400002200f9020000000104000000e00000fb"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + + size_t skipped = get_headers_length(payload); + igmp_message_t actual = igmp_parse_message(payload + skipped); + free(payload); + //igmp_print_message(actual); + + // Expected message + igmp_message_t expected; + expected.version = 3; + expected.type = V3_MEMBERSHIP_REPORT; + expected.body.v3_membership_report.checksum = 0xf902; + expected.body.v3_membership_report.num_groups = 1; + expected.body.v3_membership_report.groups = malloc(sizeof(igmp_v3_group_record_t)); + expected.body.v3_membership_report.groups->type = 4; + expected.body.v3_membership_report.groups->aux_data_len = 0; + expected.body.v3_membership_report.groups->num_sources = 0; + expected.body.v3_membership_report.groups->group_address = ipv4_str_to_net("224.0.0.251"); + + // Compare messages + compare_igmp_messages(actual, expected); + + // Free messages + igmp_free_message(actual); + igmp_free_message(expected); +} + +/** + * Main function for the unit tests. + */ +int main(int argc, char const *argv[]) +{ + // Initialize registry and suite + if (CU_initialize_registry() != CUE_SUCCESS) + return CU_get_error(); + CU_pSuite suite = CU_add_suite("igmp", NULL, NULL); + // Run tests + CU_add_test(suite, "igmp-v2-membership-report", test_igmp_v2_membership_report); + CU_add_test(suite, "igmp-leave-group", test_igmp_v2_leave_group); + CU_add_test(suite, "igmp-v3-membership-report", test_igmp_v3_membership_report); + CU_basic_run_tests(); + CU_cleanup_registry(); + return 0; +} diff --git a/test/packet_utils.c b/test/packet_utils.c new file mode 100644 index 0000000000000000000000000000000000000000..673dbb55254ad4c959d8457f9745394522202bce --- /dev/null +++ b/test/packet_utils.c @@ -0,0 +1,263 @@ +/** + * @file test/packet_utils.c + * @brief Unit tests for the packet utilities + * @date 2022-09-13 + * + * @copyright Copyright (c) 2022 + * + */ + +// Standard libraries +#include <stdlib.h> +#include <string.h> +// Custom libraries +#include "packet_utils.h" +// CUnit +#include <CUnit/CUnit.h> +#include <CUnit/Basic.h> + + +/** + * @brief Unit test for the function hexstr_to_payload. + */ +void test_hexstr_to_payload() { + char *hexstr = "48656c6c6f20576f726c6421"; + uint8_t expected[] = {0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21}; + uint8_t *actual; + size_t length = hexstr_to_payload(hexstr, &actual); + CU_ASSERT_EQUAL(length, strlen(hexstr) / 2); // Verify payload length + // Verify payload byte by byte + for (uint8_t i = 0; i < length; i++) { + CU_ASSERT_EQUAL(*(actual + i), expected[i]); + } + free(actual); +} + +/** + * @brief Unit test for the function mac_hex_to_str. + */ +void test_mac_hex_to_str() +{ + uint8_t mac_hex[] = {0x00, 0x0c, 0x29, 0x6b, 0x9f, 0x5a}; + char *expected = "00:0c:29:6b:9f:5a"; + char *actual = mac_hex_to_str(mac_hex); + CU_ASSERT_STRING_EQUAL(actual, expected); + free(actual); +} + +/** + * @brief Unit test for the function mac_str_to_hex. + */ +void test_mac_str_to_hex() +{ + char *mac_str = "00:0c:29:6b:9f:5a"; + uint8_t *expected = (uint8_t *) malloc(sizeof(uint8_t) * 6); + memcpy(expected, "\x00\x0c\x29\x6b\x9f\x5a", 6); + uint8_t *actual = mac_str_to_hex(mac_str); + for (uint8_t i = 0; i < 6; i++) + { + CU_ASSERT_EQUAL(*(actual + i), *(expected + i)) + } + free(actual); + free(expected); +} + +/** + * @brief Unit test for the function ipv4_net_to_str. + */ +void test_ipv4_net_to_str() { + uint32_t ipv4_net = 0xa101a8c0; + char *expected = "192.168.1.161"; + char *actual = ipv4_net_to_str(ipv4_net); + CU_ASSERT_STRING_EQUAL(actual, expected); +} + +/** + * @brief Unit test for the function ipv4_str_to_net. + */ +void test_ipv4_str_to_net() { + char *ipv4_str = "192.168.1.161"; + uint32_t expected = 0xa101a8c0; + uint32_t actual = ipv4_str_to_net(ipv4_str); + CU_ASSERT_EQUAL(actual, expected); +} + +/** + * @brief Unit test for the function ipv4_hex_to_str. + */ +void test_ipv4_hex_to_str() { + char *ipv4_hex = "\xc0\xa8\x01\xa1"; + char *expected = "192.168.1.161"; + char *actual = ipv4_hex_to_str(ipv4_hex); + CU_ASSERT_STRING_EQUAL(actual, expected); + free(actual); +} + +/** + * @brief Unit test for the function ipv4_str_to_hex. + */ +void test_ipv4_str_to_hex() { + char *ipv4_str = "192.168.1.161"; + char *expected = "\xc0\xa8\x01\xa1"; + char *actual = ipv4_str_to_hex(ipv4_str); + for (uint8_t i = 0; i < 4; i++) { + CU_ASSERT_EQUAL(*(actual + i), *(expected + i)) + } + free(actual); +} + +/** + * @brief Unit test for the function ipv6_net_to_str. + */ +void test_ipv6_net_to_str() { + // Full textual representation + uint8_t ipv6_1[] = {0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x11}; + char *actual = ipv6_net_to_str(ipv6_1); + char *expected = "1122:3344:5566:7788:99aa:bbcc:ddee:ff11"; + CU_ASSERT_STRING_EQUAL(actual, expected); + free(actual); + + // Compressed textual representation + uint8_t ipv6_2[] = {0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}; + actual = ipv6_net_to_str(ipv6_2); + expected = "1::1"; + CU_ASSERT_STRING_EQUAL(actual, expected); + free(actual); +} + +/** + * @brief Unit test for the function ipv6_str_to_net. + * + */ +void test_ipv6_str_to_net() { + // Full textual representation + char *ipv6_1 = "1122:3344:5566:7788:99aa:bbcc:ddee:ff11"; + uint8_t *expected = (uint8_t *) malloc(IPV6_ADDR_LENGTH * sizeof(uint8_t)); + memcpy(expected, "\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff\x11", IPV6_ADDR_LENGTH); + uint8_t *actual = ipv6_str_to_net(ipv6_1); + for (uint8_t i = 0; i < IPV6_ADDR_LENGTH; i++) { + CU_ASSERT_EQUAL(*(actual + i), *(expected + i)) + } + free(actual); + + // Compressed textual representation + char *ipv6_2 = "1::1"; + memcpy(expected, "\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", IPV6_ADDR_LENGTH); + actual = ipv6_str_to_net(ipv6_2); + for (uint8_t i = 0; i < IPV6_ADDR_LENGTH; i++) { + CU_ASSERT_EQUAL(*(actual + i), *(expected + i)) + } + free(actual); + free(expected); +} + +/** + * @brief Unit test for the function ip_net_to_str. + */ +void test_ip_net_to_str() { + // IPv4 + ip_addr_t ipv4 = {.version = 4, .value.ipv4 = 0x0101a8c0}; + char *actual = ip_net_to_str(ipv4); + char *expected = "192.168.1.1"; + CU_ASSERT_STRING_EQUAL(actual, expected); + + // IPv6 + ip_addr_t ipv6 = {.version = 6, .value.ipv6 = {0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x11}}; + actual = ip_net_to_str(ipv6); + expected = "1122:3344:5566:7788:99aa:bbcc:ddee:ff11"; + CU_ASSERT_STRING_EQUAL(actual, expected); + free(actual); +} + +/** + * @brief Unit test for the function ip_str_to_net. + * + */ +void test_ip_str_to_net() +{ + // IPv4 + char *ipv4_str = "192.168.1.161"; + ip_addr_t actual = ip_str_to_net(ipv4_str, 4); + ip_addr_t expected = (ip_addr_t) {.version = 4, .value.ipv4 = 0xa101a8c0}; + CU_ASSERT_EQUAL(actual.version, expected.version); + CU_ASSERT_EQUAL(actual.value.ipv4, expected.value.ipv4); + + // IPv6 + char *ipv6_str = "1122:3344:5566:7788:99aa:bbcc:ddee:ff11"; + actual = ip_str_to_net(ipv6_str, 6); + expected.version = 6; + memcpy(expected.value.ipv6, "\x11\x22\x33\x44\x55\x66\x77\x88\x99\xaa\xbb\xcc\xdd\xee\xff\x11", IPV6_ADDR_LENGTH); + CU_ASSERT_EQUAL(actual.version, expected.version); + for (uint8_t i = 0; i < IPV6_ADDR_LENGTH; i++) { + CU_ASSERT_EQUAL(actual.value.ipv6[i], expected.value.ipv6[i]); + } +} + +/** + * @brief Unit test for the function compare_ipv6. + */ +void test_compare_ipv6() { + uint8_t ipv6_1[] = {0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}; + uint8_t ipv6_2[] = {0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}; + uint8_t ipv6_3[] = {0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02}; + CU_ASSERT_TRUE(compare_ipv6(ipv6_1, ipv6_2)); + CU_ASSERT_TRUE(compare_ipv6(ipv6_2, ipv6_1)); + CU_ASSERT_FALSE(compare_ipv6(ipv6_1, ipv6_3)); + CU_ASSERT_FALSE(compare_ipv6(ipv6_3, ipv6_1)); +} + +/** + * @brief Unit test for the function compare_ip. + */ +void test_compare_ip() { + // Compare IPv4 + ip_addr_t ipv4_1 = { .version = 4, .value.ipv4 = 0xa101a8c0 }; + ip_addr_t ipv4_2 = {.version = 4, .value.ipv4 = 0xa101a8c0}; + ip_addr_t ipv4_3 = {.version = 4, .value.ipv4 = 0xa201a8c0}; + CU_ASSERT_TRUE(compare_ip(ipv4_1, ipv4_2)); + CU_ASSERT_TRUE(compare_ip(ipv4_2, ipv4_1)); + CU_ASSERT_FALSE(compare_ip(ipv4_1, ipv4_3)); + CU_ASSERT_FALSE(compare_ip(ipv4_3, ipv4_1)); + + // Compare IPv6 + ip_addr_t ipv6_1 = {.version = 6, .value.ipv6 = {0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}}; + ip_addr_t ipv6_2 = {.version = 6, .value.ipv6 = {0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}}; + ip_addr_t ipv6_3 = {.version = 6, .value.ipv6 = {0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02}}; + CU_ASSERT_TRUE(compare_ip(ipv6_1, ipv6_2)); + CU_ASSERT_TRUE(compare_ip(ipv6_2, ipv6_1)); + CU_ASSERT_FALSE(compare_ip(ipv6_1, ipv6_3)); + CU_ASSERT_FALSE(compare_ip(ipv6_3, ipv6_1)); + + // Compare IPv4 and IPv6 + CU_ASSERT_FALSE(compare_ip(ipv4_1, ipv6_1)); + CU_ASSERT_FALSE(compare_ip(ipv6_1, ipv4_1)); +} + +/** + * Test suite entry point. + */ +int main(int argc, char const *argv[]) +{ + // Initialize the CUnit test registry and suite + printf("Test suite: packet_utils\n"); + if (CU_initialize_registry() != CUE_SUCCESS) + return CU_get_error(); + CU_pSuite suite = CU_add_suite("packet_utils", NULL, NULL); + // Add and run tests + CU_add_test(suite, "hexstr_to_payload", test_hexstr_to_payload); + CU_add_test(suite, "mac_hex_to_str", test_mac_hex_to_str); + CU_add_test(suite, "mac_str_to_hex", test_mac_str_to_hex); + CU_add_test(suite, "ipv4_net_to_str", test_ipv4_net_to_str); + CU_add_test(suite, "ipv4_str_to_net", test_ipv4_str_to_net); + CU_add_test(suite, "ipv4_hex_to_str", test_ipv4_hex_to_str); + CU_add_test(suite, "ipv4_str_to_hex", test_ipv4_str_to_hex); + CU_add_test(suite, "ipv6_net_to_str", test_ipv6_net_to_str); + CU_add_test(suite, "ipv6_str_to_net", test_ipv6_str_to_net); + CU_add_test(suite, "ip_net_to_str", test_ip_net_to_str); + CU_add_test(suite, "ip_str_to_net", test_ip_str_to_net); + CU_add_test(suite, "compare_ipv6", test_compare_ipv6); + CU_add_test(suite, "compare_ip", test_compare_ip); + CU_basic_run_tests(); + CU_cleanup_registry(); + return 0; +} diff --git a/test/ssdp.c b/test/ssdp.c new file mode 100644 index 0000000000000000000000000000000000000000..44a16a41b63b6d6ce87bc1dd06d1036e07333c1d --- /dev/null +++ b/test/ssdp.c @@ -0,0 +1,117 @@ +/** + * @file test/ssdp.c + * @brief Unit tests for the SSDP parser + * @date 2022-11-24 + * + * @copyright Copyright (c) 2022 + * + */ + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +// Custom libraries +#include "packet_utils.h" +#include "header.h" +#include "ssdp.h" +// CUnit +#include <CUnit/CUnit.h> +#include <CUnit/Basic.h> + + +/** + * @brief Unit test for an SSDP M-SEARCH message. + */ +void test_ssdp_msearch() { + + char *hexstring = "45000095dba640000111eb7bc0a80193effffffad741076c008163124d2d534541524348202a20485454502f312e310d0a4d583a20340d0a4d414e3a2022737364703a646973636f766572220d0a484f53543a203233392e3235352e3235352e3235303a313930300d0a53543a2075726e3a736368656d61732d75706e702d6f72673a6465766963653a62617369633a310d0a0d0a"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + + uint32_t dst_addr = get_ipv4_dst_addr(payload); // IPv4 destination address, in network byte order + size_t skipped = get_ipv4_header_length(payload); + skipped += get_udp_header_length(payload + skipped); + ssdp_message_t actual = ssdp_parse_message(payload + skipped, dst_addr); + free(payload); + //ssdp_print_message(actual); + + // Test if SSDP message has been correctly parsed + ssdp_message_t expected; + expected.is_request = true; + expected.method = SSDP_M_SEARCH; + CU_ASSERT_EQUAL(actual.is_request, expected.is_request); + CU_ASSERT_EQUAL(actual.method, expected.method); + +} + +/** + * @brief Unit test for an SSDP NOTIFY message. + */ +void test_ssdp_notify() { + + char *hexstring = "4500014db3ea4000ff111485c0a8018deffffffa076c076c01399a564e4f54494659202a20485454502f312e310d0a484f53543a203233392e3235352e3235352e3235303a313930300d0a43414348452d434f4e54524f4c3a206d61782d6167653d3130300d0a4c4f434154494f4e3a20687474703a2f2f3139322e3136382e312e3134313a38302f6465736372697074696f6e2e786d6c0d0a5345525645523a204875652f312e302055506e502f312e3020332e31342e302f49704272696467650d0a4e54533a20737364703a616c6976650d0a6875652d62726964676569643a20303031373838464646453734433244430d0a4e543a2075706e703a726f6f746465766963650d0a55534e3a20757569643a32663430326638302d646135302d313165312d396232332d3030313738383734633264633a3a75706e703a726f6f746465766963650d0a0d0a"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + + uint32_t dst_addr = get_ipv4_dst_addr(payload); // IPv4 destination address, in network byte order + size_t skipped = get_ipv4_header_length(payload); + skipped += get_udp_header_length(payload + skipped); + ssdp_message_t actual = ssdp_parse_message(payload + skipped, dst_addr); + free(payload); + //ssdp_print_message(actual); + + // Test if SSDP message has been correctly parsed + ssdp_message_t expected; + expected.is_request = true; + expected.method = SSDP_NOTIFY; + CU_ASSERT_EQUAL(actual.is_request, expected.is_request); + CU_ASSERT_EQUAL(actual.method, expected.method); +} + +/** + * @brief Unit test for an SSDP response. + */ +void test_ssdp_response() { + + char *hexstring = "45000140456c400040116f85c0a8018dc0a801de076c0f66012cdcc8485454502f312e3120323030204f4b0d0a484f53543a203233392e3235352e3235352e3235303a313930300d0a4558543a0d0a43414348452d434f4e54524f4c3a206d61782d6167653d3130300d0a4c4f434154494f4e3a20687474703a2f2f3139322e3136382e312e3134313a38302f6465736372697074696f6e2e786d6c0d0a5345525645523a204875652f312e302055506e502f312e302049704272696467652f312e34382e300d0a6875652d62726964676569643a20303031373838464646453734433244430d0a53543a2075706e703a726f6f746465766963650d0a55534e3a20757569643a32663430326638302d646135302d313165312d396232332d3030313738383734633264633a3a75706e703a726f6f746465766963650d0a0d0a"; + + uint8_t *payload; + size_t length = hexstr_to_payload(hexstring, &payload); + CU_ASSERT_EQUAL(length, strlen(hexstring) / 2); // Verify message length + + uint32_t dst_addr = get_ipv4_dst_addr(payload); // IPv4 destination address, in network byte order + size_t skipped = get_ipv4_header_length(payload); + skipped += get_udp_header_length(payload + skipped); + ssdp_message_t actual = ssdp_parse_message(payload + skipped, dst_addr); + free(payload); + //ssdp_print_message(actual); + + // Test if SSDP message has been correctly parsed + ssdp_message_t expected; + expected.is_request = false; + expected.method = SSDP_UNKNOWN; + CU_ASSERT_EQUAL(actual.is_request, expected.is_request); + CU_ASSERT_EQUAL(actual.method, expected.method); + +} + +/** + * Main function for the unit tests. + */ +int main(int argc, char const *argv[]) { + // Initialize registry and suite + if (CU_initialize_registry() != CUE_SUCCESS) + return CU_get_error(); + CU_pSuite suite = CU_add_suite("ssdp", NULL, NULL); + // Run tests + CU_add_test(suite, "ssdp-msearch", test_ssdp_msearch); + CU_add_test(suite, "ssdp-notify", test_ssdp_notify); + CU_add_test(suite, "ssdp-response", test_ssdp_response); + CU_basic_run_tests(); + CU_cleanup_registry(); + return 0; +}