diff --git a/src/header/arpa_inet/mod.rs b/src/header/arpa_inet/mod.rs
index e982353f..56806d4a 100644
--- a/src/header/arpa_inet/mod.rs
+++ b/src/header/arpa_inet/mod.rs
@@ -2,6 +2,7 @@
 //!
 //! See <https://pubs.opengroup.org/onlinepubs/9799919799/basedefs/arpa_inet.h.html>.
 
+use alloc::string::String;
 use core::{
     ptr, slice,
     str::{self, FromStr},
@@ -13,8 +14,8 @@ use crate::{
         bits_arpainet::ntohl,
         bits_socklen_t::socklen_t,
         errno::{EAFNOSUPPORT, ENOSPC},
-        netinet_in::{INADDR_NONE, in_addr, in_addr_t},
-        sys_socket::constants::AF_INET,
+        netinet_in::{INADDR_NONE, in6_addr, in_addr, in_addr_t},
+        sys_socket::constants::{AF_INET, AF_INET6},
     },
     platform::{
         self,
@@ -181,34 +182,116 @@ pub unsafe extern "C" fn inet_ntop(
     dst: *mut c_char,
     size: socklen_t,
 ) -> *const c_char {
-    if af != AF_INET {
-        platform::ERRNO.set(EAFNOSUPPORT);
-        ptr::null()
-    } else if size < 16 {
-        platform::ERRNO.set(ENOSPC);
-        ptr::null()
-    } else {
+    let text = if af == AF_INET {
         let s_addr = unsafe {
             slice::from_raw_parts(
                 ptr::from_ref(&(*(src.cast::<in_addr>())).s_addr).cast::<u8>(),
                 4,
             )
         };
-        let addr = format!("{}.{}.{}.{}\0", s_addr[0], s_addr[1], s_addr[2], s_addr[3]);
-        unsafe {
-            ptr::copy(addr.as_ptr().cast::<c_char>(), dst, addr.len());
-        }
-        dst
-    }
+        format!("{}.{}.{}.{}", s_addr[0], s_addr[1], s_addr[2], s_addr[3])
+    } else if af == AF_INET6 {
+        let s6_addr = unsafe { &(*(src.cast::<in6_addr>())).s6_addr };
+        inet_ntop6(s6_addr)
+    } else {
+        platform::ERRNO.set(EAFNOSUPPORT);
+        return ptr::null();
+    };
+
+    // Like glibc and musl, fail with ENOSPC only when this address's text plus
+    // its NUL terminator does not fit in `size` bytes, rather than demanding a
+    // worst-case INET_ADDRSTRLEN/INET6_ADDRSTRLEN buffer.
+    let bytes = text.as_bytes();
+    if bytes.len() >= size as usize {
+        platform::ERRNO.set(ENOSPC);
+        return ptr::null();
+    }
+    // SAFETY: the caller provides `size` writable bytes at `dst`, and the
+    // check above guarantees the text and its terminator fit.
+    unsafe {
+        ptr::copy(bytes.as_ptr().cast::<c_char>(), dst, bytes.len());
+        *dst.add(bytes.len()) = 0;
+    }
+    dst
 }
+
+/// Formats an IPv6 address in the canonical text form of RFC 5952.
+///
+/// Each 16-bit group is printed as lowercase hex without leading zeros, and
+/// the longest run of two or more zero groups (the first one on a tie) is
+/// replaced by `::`. IPv4-mapped addresses (`::ffff:a.b.c.d`) print their
+/// last 32 bits as a dotted quad, as recommended by RFC 5952 section 5.
+fn inet_ntop6(addr: &[u8; 16]) -> String {
+    let groups: [u16; 8] =
+        core::array::from_fn(|i| u16::from_be_bytes([addr[2 * i], addr[2 * i + 1]]));
+    let ipv4_mapped = addr[..10].iter().all(|&b| b == 0) && addr[10] == 0xff && addr[11] == 0xff;
+    // Groups 6 and 7 of an IPv4-mapped address are printed as a dotted quad.
+    let hex_count = if ipv4_mapped { 6 } else { 8 };
+
+    // Find the longest run of zero groups among those printed in hex.
+    let (mut best_start, mut best_len) = (0, 0);
+    let (mut run_start, mut run_len) = (0, 0);
+    for (i, &group) in groups[..hex_count].iter().enumerate() {
+        if group == 0 {
+            if run_len == 0 {
+                run_start = i;
+            }
+            run_len += 1;
+            // Strictly greater, so the first of several equal runs wins.
+            if run_len > best_len {
+                best_start = run_start;
+                best_len = run_len;
+            }
+        } else {
+            run_len = 0;
+        }
+    }
+    // A lone zero group is never shortened to `::` (RFC 5952 section 4.2.2).
+    if best_len < 2 {
+        best_len = 0;
+    }
+
+    let mut out = String::new();
+    let mut i = 0;
+    while i < hex_count {
+        if best_len > 0 && i == best_start {
+            out.push_str("::");
+            i += best_len;
+            continue;
+        }
+        // Separate from the previous group, except directly after `::`.
+        if !out.is_empty() && !out.ends_with(':') {
+            out.push(':');
+        }
+        out.push_str(&format!("{:x}", groups[i]));
+        i += 1;
+    }
+    if ipv4_mapped {
+        if !out.ends_with(':') {
+            out.push(':');
+        }
+        out.push_str(&format!("{}.{}.{}.{}", addr[12], addr[13], addr[14], addr[15]));
+    }
+    out
+}
 
 /// See <https://pubs.opengroup.org/onlinepubs/9799919799/functions/inet_ntop.html>.
 #[unsafe(no_mangle)]
 pub unsafe extern "C" fn inet_pton(af: c_int, src: *const c_char, dst: *mut c_void) -> c_int {
-    if af != AF_INET {
-        platform::ERRNO.set(EAFNOSUPPORT);
-        -1
-    } else {
+    if af == AF_INET6 {
+        let src_cstr = unsafe { CStr::from_ptr(src) };
+        // Text that is not valid UTF-8 cannot be an IPv6 address.
+        let src_str = match src_cstr.to_str() {
+            Ok(s) => s,
+            Err(_) => return 0,
+        };
+        let out = unsafe { &mut *(dst.cast::<in6_addr>()) };
+        if inet_pton6(src_str, &mut out.s6_addr) {
+            1
+        } else {
+            0
+        }
+    } else if af == AF_INET {
         let s_addr = unsafe {
             slice::from_raw_parts_mut(
                 ptr::from_mut(&mut (*dst.cast::<in_addr>()).s_addr).cast::<u8>(),
@@ -233,5 +316,107 @@ pub unsafe extern "C" fn inet_pton(af: c_int, src: *const c_char, dst: *mut c_vo
         } else {
             0
         }
+    } else {
+        platform::ERRNO.set(EAFNOSUPPORT);
+        -1
     }
 }
+
+/// Parses the textual IPv6 address `src` (RFC 4291 section 2.2) into `dst`.
+///
+/// Accepts up to eight colon-separated groups of one to four hex digits, at
+/// most one `::` standing for one or more zero groups, and an optional
+/// trailing dotted-quad IPv4 address for the final 32 bits. Returns `false`
+/// for anything else, in which case `dst` is left untouched.
+fn inet_pton6(src: &str, dst: &mut [u8; 16]) -> bool {
+    let mut addr = [0u8; 16];
+    match src.split_once("::") {
+        None => {
+            // Without `::` the groups must spell out all 128 bits.
+            if parse_ipv6_groups(src, true, &mut addr) != Some(16) {
+                return false;
+            }
+        }
+        Some((head, tail)) => {
+            // Only one `::` is allowed (`:::` is rejected by the group parser).
+            if tail.contains("::") {
+                return false;
+            }
+            // In this branch an IPv4 suffix can only follow the `::`.
+            let mut tail_buf = [0u8; 16];
+            let (Some(head_len), Some(tail_len)) = (
+                parse_ipv6_groups(head, false, &mut addr),
+                parse_ipv6_groups(tail, true, &mut tail_buf),
+            ) else {
+                return false;
+            };
+            // `::` must stand for at least one 16-bit group.
+            if head_len + tail_len > 14 {
+                return false;
+            }
+            addr[16 - tail_len..].copy_from_slice(&tail_buf[..tail_len]);
+        }
+    }
+    *dst = addr;
+    true
+}
+
+/// Parses `:`-separated hex groups into the front of `out`, returning how
+/// many bytes were written, or `None` if a group is malformed or the groups
+/// do not fit in 16 bytes. When `allow_ipv4` is set, the final group may be a
+/// dotted-quad IPv4 address, which fills 4 bytes. An empty `s` yields 0 bytes.
+fn parse_ipv6_groups(s: &str, allow_ipv4: bool, out: &mut [u8; 16]) -> Option<usize> {
+    if s.is_empty() {
+        return Some(0);
+    }
+    let mut len = 0;
+    let mut groups = s.split(':').peekable();
+    while let Some(group) = groups.next() {
+        let is_last = groups.peek().is_none();
+        if allow_ipv4 && is_last && group.contains('.') {
+            let octets = parse_ipv4_dotted(group)?;
+            out.get_mut(len..len + 4)?.copy_from_slice(&octets);
+            len += 4;
+        } else {
+            let value = parse_hex_group(group)?;
+            out.get_mut(len..len + 2)?.copy_from_slice(&value.to_be_bytes());
+            len += 2;
+        }
+    }
+    Some(len)
+}
+
+/// Parses a dotted-quad IPv4 address (`a.b.c.d`) embedded in an IPv6 address.
+///
+/// Each part must be one to three decimal digits, at most 255, and without a
+/// leading zero, which is as strict as the glibc and musl parsers.
+fn parse_ipv4_dotted(s: &str) -> Option<[u8; 4]> {
+    let mut octets = [0u8; 4];
+    let mut parts = s.split('.');
+    for octet in &mut octets {
+        let part = parts.next()?;
+        // The digit check also rejects the `+` sign that `u8::from_str` accepts.
+        if part.is_empty()
+            || part.len() > 3
+            || !part.bytes().all(|b| b.is_ascii_digit())
+            || (part.len() > 1 && part.starts_with('0'))
+        {
+            return None;
+        }
+        *octet = u8::from_str(part).ok()?;
+    }
+    // Exactly four parts: anything after the fourth is an error.
+    if parts.next().is_some() {
+        return None;
+    }
+    Some(octets)
+}
+
+/// Parses one IPv6 group of one to four hexadecimal digits (either case).
+fn parse_hex_group(s: &str) -> Option<u16> {
+    // The digit check also rejects the `+` sign that `from_str_radix` accepts.
+    if s.is_empty() || s.len() > 4 || !s.bytes().all(|b| b.is_ascii_hexdigit()) {
+        return None;
+    }
+    u16::from_str_radix(s, 16).ok()
+}