diff --git a/src/header/netdb/dns/mod.rs b/src/header/netdb/dns/mod.rs
index 9d7e44b..f5bc21b 100644
--- a/src/header/netdb/dns/mod.rs
+++ b/src/header/netdb/dns/mod.rs
@@ -15,7 +15,48 @@ use alloc::{string::String, vec::Vec};
 
 mod answer;
 mod query;
 
+/// QR bit of the header flags: set when the message is a response.
+const DNS_FLAG_QR: u16 = 0x8000;
+/// TC bit of the header flags: set when the message was truncated.
+const DNS_FLAG_TC: u16 = 0x0200;
+/// Mask selecting the response code (RCODE) from the header flags.
+const DNS_RCODE_MASK: u16 = 0x000F;
+
+/// Why a DNS message was rejected by [`Dns::parse`] or [`Dns::parse_reply`].
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub(super) enum DnsError {
+    /// The message could not be decoded, e.g. it ended in the middle of a field.
+    MalformedResponse,
+    /// The transaction id differs from the one the caller expected.
+    TransactionIdMismatch,
+    /// The QR flag is clear, so the message is a query rather than a response.
+    NotResponse,
+    /// The TC flag is set, so the message was truncated.
+    Truncated,
+    /// RCODE 2: the server failed to process the query.
+    ServerFailure,
+    /// RCODE 3: the queried name does not exist.
+    NameError,
+    /// Any other non-zero RCODE.
+    ResponseCode(u8),
+}
+
+impl DnsError {
+    /// Human-readable description; this is the error string of [`Dns::parse`].
+    fn as_str(self) -> &'static str {
+        match self {
+            Self::MalformedResponse => "malformed dns response",
+            Self::TransactionIdMismatch => "dns transaction id mismatch",
+            Self::NotResponse => "dns packet was not a response",
+            Self::Truncated => "truncated dns response",
+            Self::ServerFailure => "dns server failure",
+            Self::NameError => "dns name error",
+            Self::ResponseCode(_) => "dns server returned an error response",
+        }
+    }
+}
+
 #[derive(Clone, Debug)]
 pub struct Dns {
     pub transaction_id: u16,
@@ -59,6 +100,63 @@ impl Dns {
     }
 
+    /// Decodes `data` as a DNS response message.
+    ///
+    /// Fails if the header marks the message as a query, as truncated or as
+    /// carrying a non-zero RCODE, or if the message cannot be decoded. The error
+    /// is the description of the corresponding [`DnsError`].
     pub fn parse(data: &[u8]) -> Result<Self, String> {
+        Self::parse_impl(data, None).map_err(|err| err.as_str().into())
+    }
+
+    /// Decodes `data` as the reply to the query sent with `expected_transaction_id`.
+    ///
+    /// Performs the checks of [`Dns::parse`] and also rejects a different
+    /// transaction id, reporting the reason as a [`DnsError`].
+    pub(super) fn parse_reply(data: &[u8], expected_transaction_id: u16) -> Result<Self, DnsError> {
+        Self::parse_impl(data, Some(expected_transaction_id))
+    }
+
+    /// Validates the transaction id and flags held in the first four bytes of `data`.
+    ///
+    /// Only the header is read, so a reply whose body is cut short (as a truncated
+    /// or error reply may be) is still classified by its flags, and a packet for a
+    /// different query is reported as a mismatch rather than as malformed.
+    fn check_header(data: &[u8], expected_transaction_id: Option<u16>) -> Result<(), DnsError> {
+        let (transaction_id, flags) = match data {
+            [id_hi, id_lo, flags_hi, flags_lo, ..] => (
+                u16::from_be_bytes([*id_hi, *id_lo]),
+                u16::from_be_bytes([*flags_hi, *flags_lo]),
+            ),
+            _ => return Err(DnsError::MalformedResponse),
+        };
+
+        if let Some(expected_transaction_id) = expected_transaction_id {
+            if transaction_id != expected_transaction_id {
+                return Err(DnsError::TransactionIdMismatch);
+            }
+        }
+
+        if flags & DNS_FLAG_QR == 0 {
+            return Err(DnsError::NotResponse);
+        }
+
+        if flags & DNS_FLAG_TC != 0 {
+            return Err(DnsError::Truncated);
+        }
+
+        match (flags & DNS_RCODE_MASK) as u8 {
+            0 => Ok(()),
+            2 => Err(DnsError::ServerFailure),
+            3 => Err(DnsError::NameError),
+            rcode => Err(DnsError::ResponseCode(rcode)),
+        }
+    }
+
+    /// Shared implementation of [`Dns::parse`] and [`Dns::parse_reply`]. The header
+    /// is validated before the question and answer sections are decoded.
+    fn parse_impl(data: &[u8], expected_transaction_id: Option<u16>) -> Result<Self, DnsError> {
+        Self::check_header(data, expected_transaction_id)?;
+
         let name_ind = 0b1100_0000;
         let mut i = 0;
 
@@ -66,7 +164,7 @@ impl Dns {
             () => {{
                 i += 1;
                 if i > data.len() {
-                    return Err(format!("{}: {}: pop_u8", file!(), line!()));
+                    return Err(DnsError::MalformedResponse);
                 }
                 data[i - 1]
             }};
@@ -77,9 +175,11 @@ impl Dns {
                 use core::convert::TryInto;
                 i += 2;
                 if i > data.len() {
-                    return Err(format!("{}: {}: pop_n16", file!(), line!()));
+                    return Err(DnsError::MalformedResponse);
                 }
-                let bytes: [u8; 2] = data[i - 2..i].try_into().unwrap();
+                let bytes: [u8; 2] = data[i - 2..i]
+                    .try_into()
+                    .map_err(|_| DnsError::MalformedResponse)?;
                 u16::from_be_bytes(bytes)
             }};
         }
@@ -164,3 +264,75 @@ impl Dns {
         })
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use alloc::{string::ToString, vec::Vec};
+
+    use super::{Dns, DnsError, DnsQuery};
+
+    /// Builds a one-question message with the given header fields.
+    fn packet(transaction_id: u16, flags: u16) -> Vec<u8> {
+        Dns {
+            transaction_id,
+            flags,
+            queries: vec![DnsQuery {
+                name: "example.com".to_string(),
+                q_type: 0x0001,
+                q_class: 0x0001,
+            }],
+            answers: vec![],
+        }
+        .compile()
+    }
+
+    #[test]
+    fn parse_reply_accepts_valid_response() {
+        let response = Dns::parse_reply(&packet(0x1234, 0x8180), 0x1234).unwrap();
+        assert_eq!(response.transaction_id, 0x1234);
+    }
+
+    #[test]
+    fn parse_reply_rejects_transaction_id_mismatch() {
+        let err = Dns::parse_reply(&packet(0x1234, 0x8180), 0x4321).unwrap_err();
+        assert_eq!(err, DnsError::TransactionIdMismatch);
+    }
+
+    #[test]
+    fn parse_reply_classifies_header_before_body() {
+        let mut data = packet(0x1234, 0x8380);
+        // Claim one answer record that is not present, as a truncated reply may.
+        data[6..8].copy_from_slice(&1u16.to_be_bytes());
+        assert_eq!(Dns::parse_reply(&data, 0x1234).unwrap_err(), DnsError::Truncated);
+
+        // Only the header of a foreign reply is needed to reject it.
+        assert_eq!(
+            Dns::parse_reply(&data[..4], 0x4321).unwrap_err(),
+            DnsError::TransactionIdMismatch
+        );
+    }
+
+    #[test]
+    fn parse_reply_rejects_short_header() {
+        let err = Dns::parse_reply(&[0x12, 0x34, 0x81], 0x1234).unwrap_err();
+        assert_eq!(err, DnsError::MalformedResponse);
+    }
+
+    #[test]
+    fn parse_rejects_query_packets() {
+        let err = Dns::parse(&packet(0x1234, 0x0100)).unwrap_err();
+        assert_eq!(err, DnsError::NotResponse.as_str());
+    }
+
+    #[test]
+    fn parse_rejects_truncated_response() {
+        let err = Dns::parse(&packet(0x1234, 0x8380)).unwrap_err();
+        assert_eq!(err, DnsError::Truncated.as_str());
+    }
+
+    #[test]
+    fn parse_rejects_name_error_response() {
+        let err = Dns::parse(&packet(0x1234, 0x8183)).unwrap_err();
+        assert_eq!(err, DnsError::NameError.as_str());
+    }
+}
diff --git a/src/header/netdb/lookup.rs
b/src/header/netdb/lookup.rs
index c2b6cdb..af25f97 100644
--- a/src/header/netdb/lookup.rs
+++ b/src/header/netdb/lookup.rs
@@ -1,10 +1,10 @@
-use alloc::{boxed::Box, string::ToString, vec::Vec};
+use alloc::{string::ToString, vec::Vec};
 use core::{mem, ptr};
 
 use crate::{
     out::Out,
     platform::{
-        Pal, Sys,
+        self, Pal, Sys,
         types::{c_int, c_void},
     },
 };
@@ -25,12 +25,148 @@ use crate::header::{
 };
 
 use super::{
-    dns::{Dns, DnsQuery},
+    dns::{Dns, DnsError, DnsQuery},
     sys::get_dns_server,
 };
 
 pub type LookupHost = Vec<in_addr>;
 pub type LookupHostV6 = Vec<in6_addr>;
+
+/// Number of times a query is sent before the lookup gives up.
+const DNS_SEND_ATTEMPTS: usize = 2;
+
+/// Maximum number of replies discarded because they do not answer the current
+/// query; once exceeded, the next such reply fails the lookup.
+const DNS_MAX_IGNORED_REPLIES: usize = 8;
+
+/// Closes `sock` if it is a valid descriptor. A failed close is ignored because
+/// the descriptor is being abandoned either way.
+fn close_socket(sock: c_int) {
+    if sock >= 0 {
+        let _ = Sys::close(sock);
+    }
+}
+
+/// Closes the wrapped socket when dropped, so every return path releases it.
+struct SocketGuard(c_int);
+
+impl Drop for SocketGuard {
+    fn drop(&mut self) {
+        close_socket(self.0);
+    }
+}
+
+/// Returns the current `errno`, or `default` when `errno` is zero.
+fn last_socket_error(default: c_int) -> c_int {
+    match platform::ERRNO.get() {
+        0 => default,
+        err => err,
+    }
+}
+
+/// Maps the reason a DNS reply was rejected to the errno returned by the lookups.
+fn map_dns_error(err: DnsError) -> c_int {
+    match err {
+        DnsError::NameError => ENOENT,
+        DnsError::ServerFailure => EAGAIN,
+        DnsError::Truncated => EMSGSIZE,
+        DnsError::MalformedResponse
+        | DnsError::TransactionIdMismatch
+        | DnsError::NotResponse
+        | DnsError::ResponseCode(_) => EREMOTEIO,
+    }
+}
+
+/// Sends `packet` over UDP to port 53 of `dns_addr` and returns the parsed reply.
+///
+/// `dns_addr` is stored as-is in `sin_addr.s_addr`. The query is sent up to
+/// `DNS_SEND_ATTEMPTS` times; after each send, replies are read with a 5 second
+/// receive timeout. Replies carrying another transaction id or lacking the QR
+/// flag cannot answer this query, so they are discarded and reading continues
+/// (up to `DNS_MAX_IGNORED_REPLIES` of them per lookup).
+///
+/// # Errors
+///
+/// Returns the errno of a failed `socket`, `connect`, `send` or `recv` call (`EIO`
+/// if errno is unset), or the value from `map_dns_error` for a rejected reply.
+fn lookup_dns_response(packet: &Dns, dns_addr: u32) -> Result<Dns, c_int> {
+    let packet_data = packet.compile();
+    let packet_data_ptr = packet_data.as_ptr().cast::<c_void>();
+
+    let dest = sockaddr_in {
+        sin_family: AF_INET as u16,
+        sin_port: htons(53),
+        sin_addr: in_addr { s_addr: dns_addr },
+        ..Default::default()
+    };
+    let dest_ptr = ptr::from_ref(&dest).cast::<sockaddr>();
+    let dest_len = mem::size_of_val(&dest) as socklen_t;
+
+    // SAFETY: no pointer arguments.
+    let sock = unsafe { sys_socket::socket(AF_INET, SOCK_DGRAM, i32::from(IPPROTO_UDP)) };
+    if sock < 0 {
+        return Err(last_socket_error(EIO));
+    }
+    // Closes `sock` on every return path below.
+    let _sock_guard = SocketGuard(sock);
+
+    // SAFETY: `dest_ptr` points to `dest`, which is live and `dest_len` bytes long.
+    if unsafe { sys_socket::connect(sock, dest_ptr, dest_len) } < 0 {
+        return Err(last_socket_error(EIO));
+    }
+
+    // Best effort: if setting the timeout fails, recv may block longer.
+    let tv = timeval {
+        tv_sec: 5,
+        tv_usec: 0,
+    };
+    // SAFETY: the option value points to `tv`, which is live for the call.
+    unsafe {
+        sys_socket::setsockopt(
+            sock,
+            SOL_SOCKET,
+            SO_RCVTIMEO,
+            ptr::from_ref(&tv).cast::<c_void>(),
+            mem::size_of::<timeval>() as socklen_t,
+        );
+    }
+
+    let mut buf = vec![0u8; 65536];
+    let mut last_error = EIO;
+    let mut ignored_replies = 0;
+
+    for _ in 0..DNS_SEND_ATTEMPTS {
+        // SAFETY: `packet_data` outlives the loop and is `packet_data.len()` bytes long.
+        if unsafe { sys_socket::send(sock, packet_data_ptr, packet_data.len(), 0) } < 0 {
+            return Err(last_socket_error(EIO));
+        }
+
+        // Read until a reply to this query arrives or recv fails (e.g. times out).
+        loop {
+            // SAFETY: `buf` is writable for `buf.len()` bytes.
+            let count = unsafe {
+                sys_socket::recv(sock, buf.as_mut_ptr().cast::<c_void>(), buf.len(), 0)
+            };
+            if count < 0 {
+                last_error = last_socket_error(EIO);
+                break;
+            }
+
+            match Dns::parse_reply(&buf[..count as usize], packet.transaction_id) {
+                Ok(response) => return Ok(response),
+                // Not an answer to this query: keep waiting for the real one.
+                Err(DnsError::TransactionIdMismatch | DnsError::NotResponse)
+                    if ignored_replies < DNS_MAX_IGNORED_REPLIES =>
+                {
+                    ignored_replies += 1;
+                }
+                Err(err) => return Err(map_dns_error(err)),
+            }
+        }
+    }
+
+    Err(last_error)
+}
 
 pub fn lookup_host(host: &str) -> Result<LookupHost, c_int> {
     if let Some(host_direct_addr) = parse_ipv4_string(host) {
@@ -61,97 +134,28 @@ pub fn lookup_host(host: &str) -> Result<LookupHost, c_int> {
             answers: vec![],
         };
 
-        let packet_data = packet.compile();
-        let packet_data_len = packet_data.len();
-
-        let packet_data_box = packet_data.into_boxed_slice();
-        let packet_data_ptr = Box::into_raw(packet_data_box) as *mut _ as *mut c_void;
-
-        let dest = sockaddr_in {
-            sin_family: AF_INET as u16,
-            sin_port: htons(53),
-            sin_addr: in_addr { s_addr: dns_addr },
-            ..Default::default()
-        };
-        let dest_ptr = ptr::from_ref(&dest).cast::<sockaddr>();
-
-        let sock = unsafe {
-            let sock = sys_socket::socket(AF_INET, SOCK_DGRAM, i32::from(IPPROTO_UDP));
-            if sys_socket::connect(sock, dest_ptr, mem::size_of_val(&dest) as socklen_t) < 0 {
-                return Err(EIO);
-            }
-            if sys_socket::send(sock, packet_data_ptr, packet_data_len, 0) < 0 {
-                drop(Box::from_raw(packet_data_ptr));
-                return Err(EIO);
-            }
-            sock
-        };
-
-        unsafe {
-            drop(Box::from_raw(packet_data_ptr));
- } - - let mut buf = vec![0u8; 65536]; - let buf_ptr = buf.as_mut_ptr().cast::(); - - // Set 5s recv timeout (best-effort; if this fails, recv may block longer). - let tv = timeval { - tv_sec: 5, - tv_usec: 0, - }; - unsafe { - sys_socket::setsockopt( - sock, - SOL_SOCKET, - SO_RCVTIMEO, - &tv as *const timeval as *const c_void, - core::mem::size_of::() as socklen_t, - ); - } - - let mut count: isize = -1; - for _attempt in 0..2 { - count = unsafe { sys_socket::recv(sock, buf_ptr, 65536, 0) }; - if count >= 0 { - break; - } - if unsafe { sys_socket::send(sock, packet_data_ptr, packet_data_len, 0) } < 0 { - break; - } - } - if count < 0 { - return Err(EIO); - } - - match Dns::parse(&buf[..count as usize]) { - Ok(response) => { - let addrs: Vec<_> = response - .answers - .into_iter() - .filter_map(|answer| { - if answer.a_type == 0x0001 - && answer.a_class == 0x0001 - && answer.data.len() == 4 - { - let addr = in_addr { - s_addr: u32::from_ne_bytes([ - answer.data[0], - answer.data[1], - answer.data[2], - answer.data[3], - ]), - }; - Some(addr) - } else { - None - } - }) - .collect(); - - Ok(addrs) - } - Err(_err) => Err(EINVAL), - } + let response = lookup_dns_response(&packet, dns_addr)?; /* socket, timeout and rejected-reply failures return early as errno values */ + let addrs: Vec<_> = response + .answers + .into_iter() + .filter_map(|answer| { + if answer.a_type == 0x0001 && answer.a_class == 0x0001 && answer.data.len() == 4 { /* type A, class IN: a 4-byte IPv4 address */ + let addr = in_addr { + s_addr: u32::from_ne_bytes([ /* native-endian read keeps the bytes in wire (network) order */ + answer.data[0], + answer.data[1], + answer.data[2], + answer.data[3], + ]), + }; + Some(addr) + } else { + None + } + }) + .collect(); + + Ok(addrs) } else { Err(EINVAL) } @@ -186,91 +192,22 @@ pub fn lookup_host_v6(host: &str) -> Result { answers: vec![], }; - let packet_data = packet.compile(); - let packet_data_len = packet_data.len(); - - let packet_data_box = packet_data.into_boxed_slice(); - let packet_data_ptr = Box::into_raw(packet_data_box) as *mut _ as *mut c_void; - - let dest = sockaddr_in { - sin_family: AF_INET as u16, - sin_port: htons(53), -
sin_addr: in_addr { s_addr: dns_addr }, - ..Default::default() - }; - let dest_ptr = ptr::from_ref(&dest).cast::(); - - let sock = unsafe { - let sock = sys_socket::socket(AF_INET, SOCK_DGRAM, i32::from(IPPROTO_UDP)); - if sys_socket::connect(sock, dest_ptr, mem::size_of_val(&dest) as socklen_t) < 0 { - return Err(EIO); - } - if sys_socket::send(sock, packet_data_ptr, packet_data_len, 0) < 0 { - drop(Box::from_raw(packet_data_ptr)); - return Err(EIO); - } - sock - }; - - unsafe { - drop(Box::from_raw(packet_data_ptr)); - } - - let mut buf = vec![0u8; 65536]; - let buf_ptr = buf.as_mut_ptr().cast::(); - - // Set 5s recv timeout (best-effort; if this fails, recv may block longer). - let tv = timeval { - tv_sec: 5, - tv_usec: 0, - }; - unsafe { - sys_socket::setsockopt( - sock, - SOL_SOCKET, - SO_RCVTIMEO, - &tv as *const timeval as *const c_void, - core::mem::size_of::() as socklen_t, - ); - } - - let mut count: isize = -1; - for _attempt in 0..2 { - count = unsafe { sys_socket::recv(sock, buf_ptr, 65536, 0) }; - if count >= 0 { - break; - } - if unsafe { sys_socket::send(sock, packet_data_ptr, packet_data_len, 0) } < 0 { - break; - } - } - if count < 0 { - return Err(EIO); - } - - match Dns::parse(&buf[..count as usize]) { - Ok(response) => { - let addrs: Vec<_> = response - .answers - .into_iter() - .filter_map(|answer| { - if answer.a_type == 0x001c - && answer.a_class == 0x0001 - && answer.data.len() == 16 - { - let mut s6_addr = [0u8; 16]; - s6_addr.copy_from_slice(&answer.data[..16]); - Some(in6_addr { s6_addr }) - } else { - None - } - }) - .collect(); - - Ok(addrs) - } - Err(_err) => Err(EINVAL), - } + let response = lookup_dns_response(&packet, dns_addr)?; /* socket, timeout and rejected-reply failures return early as errno values */ + let addrs: Vec<_> = response + .answers + .into_iter() + .filter_map(|answer| { + if answer.a_type == 0x001c && answer.a_class == 0x0001 && answer.data.len() == 16 { /* type AAAA, class IN: a 16-byte IPv6 address */ + let mut s6_addr = [0u8; 16]; + s6_addr.copy_from_slice(&answer.data[..16]); + Some(in6_addr { s6_addr }) + } else { + None + } + }) +
.collect(); + + Ok(addrs) } else { Err(EINVAL) } @@ -315,92 +254,24 @@ pub fn lookup_addr(addr: in_addr) -> Result>, c_int> { answers: vec![], }; - let packet_data = packet.compile(); - let packet_data_len = packet_data.len(); - let packet_data_box = packet_data.into_boxed_slice(); - let packet_data_ptr = Box::into_raw(packet_data_box) as *mut _ as *mut c_void; - - let dest = sockaddr_in { - sin_family: AF_INET as u16, - sin_port: htons(53), - sin_addr: in_addr { s_addr: dns_addr }, - ..Default::default() - }; - - let dest_ptr = ptr::from_ref(&dest).cast::(); - - let sock = unsafe { - let sock = sys_socket::socket(AF_INET, SOCK_DGRAM, i32::from(IPPROTO_UDP)); - if sys_socket::connect(sock, dest_ptr, mem::size_of_val(&dest) as socklen_t) < 0 { - return Err(EIO); - } - sock - }; - - unsafe { - if sys_socket::send(sock, packet_data_ptr, packet_data_len, 0) < 0 { - return Err(EIO); - } - } - - unsafe { - drop(Box::from_raw(packet_data_ptr)); - } - - let mut buf = [0u8; 65536]; - let buf_ptr = buf.as_mut_ptr().cast::(); - - // Set 5s recv timeout (best-effort; if this fails, recv may block longer). - let tv = timeval { - tv_sec: 5, - tv_usec: 0, - }; - unsafe { - sys_socket::setsockopt( - sock, - SOL_SOCKET, - SO_RCVTIMEO, - &tv as *const timeval as *const c_void, - core::mem::size_of::() as socklen_t, - ); - } - - let mut count: isize = -1; - for _attempt in 0..2 { - count = unsafe { sys_socket::recv(sock, buf_ptr, 65536, 0) }; - if count >= 0 { - break; - } - if unsafe { sys_socket::send(sock, packet_data_ptr, packet_data_len, 0) } < 0 { - break; - } - } - if count < 0 { - return Err(EIO); - } - - match Dns::parse(&buf[..count as usize]) { - Ok(response) => { - let names = response - .answers - .into_iter() - .filter_map(|answer| { - if answer.a_type == 0x000C && answer.a_class == 0x0001 { - // answer.data is encoded kinda weird. - // Basically length-prefixed strings for each - // subsection of the domain.
-                        // We need to parse this to insert periods where
-                        // they belong (ie at the end of each string)
-                        Some(parse_revdns_answer(&answer.data))
-                    } else {
-                        None
-                    }
-                })
-                .collect();
-            Ok(names)
-        }
-        Err(_err) => Err(EINVAL),
-    }
+        let response = lookup_dns_response(&packet, dns_addr)?;
+        let names = response
+            .answers
+            .into_iter()
+            .filter_map(|answer| {
+                if answer.a_type == 0x000C && answer.a_class == 0x0001 {
+                    // answer.data is encoded kinda weird.
+                    // Basically length-prefixed strings for each
+                    // subsection of the domain.
+                    // We need to parse this to insert periods where
+                    // they belong (ie at the end of each string)
+                    Some(parse_revdns_answer(&answer.data))
+                } else {
+                    None
+                }
+            })
+            .collect();
+        Ok(names)
     } else {
         Err(EINVAL)
     }
diff --git a/src/header/netdb/mod.rs b/src/header/netdb/mod.rs
index ba58b6e..cdcc10e 100644
--- a/src/header/netdb/mod.rs
+++ b/src/header/netdb/mod.rs
@@ -180,6 +180,36 @@ fn bytes_to_box_str(bytes: &[u8]) -> Box<str> {
     Box::from(core::str::from_utf8(bytes).unwrap_or(""))
 }
 
+/// Translates an errno returned by `lookup_host`/`lookup_host_v6` into a
+/// `getaddrinfo` error code.
+fn lookup_error_to_eai(err: c_int) -> c_int {
+    match err {
+        ENOENT => EAI_NONAME,
+        ETIMEDOUT | EAGAIN => EAI_AGAIN,
+        _ => EAI_FAIL,
+    }
+}
+
+/// Ranks `getaddrinfo` error codes; when more than one lookup fails, the code
+/// with the higher rank is the one reported.
+fn lookup_error_priority(err: c_int) -> u8 {
+    match err {
+        EAI_NONAME => 1,
+        EAI_FAIL => 2,
+        EAI_AGAIN => 3,
+        _ => 0,
+    }
+}
+
+/// Folds the errno `err` into the error recorded so far, keeping whichever
+/// `EAI_*` code ranks higher (the one recorded earlier on a tie).
+fn combine_lookup_error(current: Option<c_int>, err: c_int) -> c_int {
+    let mapped = lookup_error_to_eai(err);
+    current
+        .filter(|&existing| lookup_error_priority(existing) >= lookup_error_priority(mapped))
+        .unwrap_or(mapped)
+}
+
 /// See <https://pubs.opengroup.org/onlinepubs/9799919799/functions/endnetent.html>.
#[unsafe(no_mangle)] pub unsafe extern "C" fn endnetent() { @@ -926,6 +951,8 @@ pub unsafe extern "C" fn getaddrinfo( let want_inet4 = requested_family == AF_INET || requested_family == AF_UNSPEC; let want_inet6 = requested_family == AF_INET6 || requested_family == AF_UNSPEC; + let mut lookup_error = None; /* highest-ranked EAI_* code from a failed host lookup; reported only when no address was found */ + let lookuphost_v4: Vec = if want_inet4 { if ai_flags & AI_NUMERICHOST > 0 { match parse_ipv4_string(node_str) { @@ -937,7 +964,10 @@ pub unsafe extern "C" fn getaddrinfo( } else { match lookup_host(node_str) { Ok(addrs) => addrs, - Err(_) => vec![], + Err(err) => { + lookup_error = Some(combine_lookup_error(lookup_error, err)); + vec![] + } } } } else { @@ -955,7 +985,10 @@ pub unsafe extern "C" fn getaddrinfo( } else { match lookup_host_v6(node_str) { Ok(addrs) => addrs, - Err(_) => vec![], + Err(err) => { + lookup_error = Some(combine_lookup_error(lookup_error, err)); + vec![] + } } } } else { @@ -963,5 +996,5 @@ pub unsafe extern "C" fn getaddrinfo( }; if lookuphost_v4.is_empty() && lookuphost_v6.is_empty() { - return EAI_NONAME; + return lookup_error.unwrap_or(EAI_NONAME); /* EAI_NONAME when no lookup failure was recorded, e.g. the lookups succeeded without addresses */ }