diff --git a/include/dns.h b/include/dns.h index 0384b9084a35f20eee222d5d59ee348ab1327e3d..d28d8dcf1df3551fdbc47f5c7481111a35a18a2d 100644 --- a/include/dns.h +++ b/include/dns.h @@ -17,6 +17,8 @@ #include <stdbool.h> #include <string.h> #include <unistd.h> +#include <errno.h> +#include <sys/types.h> #include <sys/socket.h> #include <netinet/in.h> #include <arpa/inet.h> @@ -219,16 +221,19 @@ void dns_convert_qname(char *dst, char *src, uint16_t len); * @param qname domain name to query for * @param sockfd socket file descriptor * @param server_addr DNS server IPv4 address + * @return 0 if the query was sent successfully, -1 otherwise */ -void dns_send_query(char *qname, int sockfd, struct sockaddr_in *server_addr); +int dns_send_query(char *qname, int sockfd, struct sockaddr_in *server_addr); /** - * @brief Receive a DNS response + * @brief Receive a DNS response. * * @param sockfd socket file descriptor * @param server_addr DNS server IPv4 address + * @param dns_message allocated buffer which will be filled with the DNS response message, upon success + * @return 0 if DNS response was received successfully, -1 otherwise */ -dns_message_t dns_receive_response(int sockfd, struct sockaddr_in *server_addr); +int dns_receive_response(int sockfd, struct sockaddr_in *server_addr, dns_message_t *dns_message); ///// DESTROY ///// diff --git a/src/dns.c b/src/dns.c index cf6950232a8b513666fd5ebd00fc4aa85e3e9475..b99c1d98d72e676197fd47956ea6171549197668 100644 --- a/src/dns.c +++ b/src/dns.c @@ -10,6 +10,13 @@ #include "dns.h" +// DNS message timeout +#define TIMEOUT 5 // Timeout value, in seconds +struct timeval timeout = { + .tv_sec = TIMEOUT, + .tv_usec = 0 +}; + ///// PARSING ///// @@ -423,7 +430,7 @@ void dns_convert_qname(char *dst, char *src, uint16_t len) { * @param sockfd socket file descriptor * @param server_addr DNS server IPv4 address */ -void dns_send_query(char *qname, int sockfd, struct sockaddr_in *server_addr) { +int dns_send_query(char *qname, int sockfd, struct sockaddr_in *server_addr) { // Buffer that will contain the message uint16_t qname_len = strlen(qname); uint16_t qname_labels_len = qname_len + sizeof(uint8_t) * 2; @@ -454,15 +461,38 @@ void dns_send_query(char *qname, int sockfd, struct sockaddr_in *server_addr) { memcpy(buffer + DNS_HEADER_SIZE, dns_question.qname, qname_labels_len); memcpy(buffer + DNS_HEADER_SIZE + qname_labels_len, &(dns_question.qtype), sizeof(uint16_t) * 2); + // Set socket timeout + if (setsockopt(sockfd, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) + { + perror("Error setting socket send timeout"); + // Free memory + free(dns_question.qname); + free(buffer); + return -1; + } + // Send DNS message + #ifdef DEBUG + printf("Sending DNS query for domain name %s to server %s\n", qname, inet_ntoa(server_addr->sin_addr)); + #endif /* DEBUG */ if (sendto(sockfd, buffer, dns_message_size, 0, (struct sockaddr *)server_addr, sizeof(*server_addr)) < 0) { - perror("Failed sending DNS query."); + if (errno == EWOULDBLOCK || errno == EAGAIN) + { + printf("DNS query for %s timed out.\n", qname); + } else { + perror("Failed sending DNS query."); + } + // Free memory + free(dns_question.qname); + free(buffer); + return -1; } - // Free memory + // DNS query was sent successfully free(dns_question.qname); free(buffer); + return 0; } /** @@ -470,26 +500,38 @@ void dns_send_query(char *qname, int sockfd, struct sockaddr_in *server_addr) { * * @param sockfd socket file descriptor * @param server_addr DNS server IPv4 address - * @return received DNS message + * @param dns_message allocated buffer which will be filled with the DNS response message, upon success + * @return 0 if DNS response was received successfully, -1 otherwise */ -dns_message_t dns_receive_response(int sockfd, struct sockaddr_in *server_addr) +int dns_receive_response(int sockfd, struct sockaddr_in *server_addr, dns_message_t* dns_message) { // Receiving buffer int bufsize = 65536; uint8_t *buffer = (uint8_t *)malloc(bufsize); + // Set socket timeout + if (setsockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) + { + perror("Error setting socket receive timeout"); + return -1; + } + // Await response int n = recvfrom(sockfd, (char *)buffer, bufsize, 0, NULL, NULL); if (n < 0) { - perror("Failed receiving DNS response."); - exit(-1); + if (errno == EWOULDBLOCK || errno == EAGAIN) + { + printf("DNS receive timed out\n"); + } else { + perror("Failed receiving DNS response."); + } + return -1; } - // Parse received DNS message - dns_message_t dns_message = dns_parse_message(buffer); - free(buffer); - return dns_message; + // DNS response was received successfully, parse it + *dns_message = dns_parse_message(buffer); + return 0; } ///// DESTROY ///// diff --git a/test/dns.c b/test/dns.c index 182f20de87e97d334b217b568a35be92bcff1334..03373a2674a0e6cd0b3b7f3d7266a1f6f31ee8ab 100644 --- a/test/dns.c +++ b/test/dns.c @@ -340,11 +340,14 @@ void test_dns_office() { * @brief Test the `dns_send_query` and `dns_receive_response` functions. */ void test_dns_send_receive() { - /* Initialize */ + // Initialize + int ret; char *domain_name = "www.google.com"; + // Open socket int sockfd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); CU_ASSERT_TRUE(sockfd > 0); + // Server address: network gateway struct sockaddr_in server_addr; memset(&server_addr, 0, sizeof(server_addr)); @@ -353,14 +356,17 @@ void test_dns_send_receive() { server_addr.sin_addr.s_addr = inet_addr("8.8.8.8"); // Send query for dummy domain name - dns_send_query(domain_name, sockfd, &server_addr); - dns_message_t dns_message = dns_receive_response(sockfd, &server_addr); + ret = dns_send_query(domain_name, sockfd, &server_addr); + CU_ASSERT_EQUAL(ret, 0); - // Verify DNS response's domain name - CU_ASSERT_STRING_EQUAL(dns_message.questions->qname, domain_name); + // Receive response + dns_message_t dns_response; + ret = dns_receive_response(sockfd, &server_addr, &dns_response); + CU_ASSERT_EQUAL(ret ,0); + CU_ASSERT_STRING_EQUAL(dns_response.questions->qname, domain_name); // Free memory - dns_free_message(dns_message); + dns_free_message(dns_response); } /**