Add UDP traceroute netstack patch

Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
2026-04-14 22:53:12 +01:00
parent d273bf718b
commit c76d4b8968
+324
View File
@@ -282,6 +282,330 @@ index 94a1eb17..3fd91156 100644
address: u64,
}
diff --git a/netstack/src/scheme/icmp.rs b/netstack/src/scheme/icmp.rs
index 6365e2a7..b391bae1 100644
--- a/netstack/src/scheme/icmp.rs
+++ b/netstack/src/scheme/icmp.rs
@@ -3,7 +3,7 @@ use smoltcp::socket::icmp::{
Endpoint as IcmpEndpoint, PacketBuffer as IcmpSocketBuffer,
PacketMetadata as IcmpPacketMetadata, Socket as IcmpSocket,
};
-use smoltcp::wire::{Icmpv4Packet, Icmpv4Repr, IpAddress, IpListenEndpoint};
+use smoltcp::wire::{Icmpv4Packet, Icmpv4Repr, IpAddress, IpListenEndpoint, IpProtocol, UdpPacket};
use std::mem;
use std::str;
use syscall;
@@ -16,6 +16,10 @@ use crate::router::Router;
pub type IcmpScheme = SchemeWrapper<IcmpSocket<'static>>;
+const ICMP_UDP_TRACE_EVENT_LEN: usize = 12;
+const ICMP_UDP_TRACE_KIND_TIME_EXCEEDED: u8 = 1;
+const ICMP_UDP_TRACE_KIND_DST_UNREACHABLE: u8 = 2;
+
enum IcmpSocketType {
Echo,
Udp,
@@ -27,6 +31,79 @@ pub struct IcmpData {
ident: u16,
}
+fn encode_udp_trace_event(
+ kind: u8,
+ code: u8,
+ responder: std::net::Ipv4Addr,
+ source_port: u16,
+ dest_port: u16,
+ buf: &mut [u8],
+) -> SyscallResult<usize> {
+ if buf.len() < ICMP_UDP_TRACE_EVENT_LEN {
+ return Err(SyscallError::new(syscall::EINVAL));
+ }
+
+ buf[0] = kind;
+ buf[1] = code;
+ buf[2..4].fill(0);
+ buf[4..8].copy_from_slice(&responder.octets());
+ buf[8..10].copy_from_slice(&source_port.to_be_bytes());
+ buf[10..12].copy_from_slice(&dest_port.to_be_bytes());
+ Ok(ICMP_UDP_TRACE_EVENT_LEN)
+}
+
+fn build_udp_trace_event(
+ icmp_repr: &Icmpv4Repr<'_>,
+ responder_ip: IpAddress,
+ remote_ip: IpAddress,
+ local_port: u16,
+ buf: &mut [u8],
+) -> Option<SyscallResult<usize>> {
+ let (kind, code, header, data) = match icmp_repr {
+ Icmpv4Repr::TimeExceeded {
+ reason,
+ header,
+ data,
+ } => (
+ ICMP_UDP_TRACE_KIND_TIME_EXCEEDED,
+ u8::from(*reason),
+ header,
+ *data,
+ ),
+ Icmpv4Repr::DstUnreachable {
+ reason,
+ header,
+ data,
+ } => (
+ ICMP_UDP_TRACE_KIND_DST_UNREACHABLE,
+ u8::from(*reason),
+ header,
+ *data,
+ ),
+ _ => return None,
+ };
+
+ if header.next_header != IpProtocol::Udp || remote_ip != IpAddress::Ipv4(header.dst_addr) {
+ return None;
+ }
+
+ let udp_packet = UdpPacket::new_checked(data).ok()?;
+ if udp_packet.src_port() != local_port {
+ return None;
+ }
+
+ let IpAddress::Ipv4(responder) = responder_ip;
+
+ Some(encode_udp_trace_event(
+ kind,
+ code,
+ responder,
+ udp_packet.src_port(),
+ udp_packet.dst_port(),
+ buf,
+ ))
+}
+
impl<'a> SchemeSocket for IcmpSocket<'a> {
type SchemeDataT = PortSet;
type DataT = IcmpData;
@@ -126,6 +203,21 @@ impl<'a> SchemeSocket for IcmpSocket<'a> {
let ip =
IpAddress::from_str(addr).map_err(|_| syscall::Error::new(syscall::EINVAL))?;
+ let ident = match parts.next() {
+ Some(port) if !port.is_empty() => {
+ let port = port
+ .parse::<u16>()
+ .map_err(|_| syscall::Error::new(syscall::EINVAL))?;
+ if !ident_set.claim_port(port) {
+ return Err(SyscallError::new(syscall::EADDRINUSE));
+ }
+ port
+ }
+ Some(_) | None => ident_set
+ .get_port()
+ .ok_or_else(|| SyscallError::new(syscall::EINVAL))?,
+ };
+
let socket = IcmpSocket::new(
IcmpSocketBuffer::new(
vec![IcmpPacketMetadata::EMPTY; Smolnetd::SOCKET_BUFFER_SIZE],
@@ -138,9 +230,6 @@ impl<'a> SchemeSocket for IcmpSocket<'a> {
);
let handle = socket_set.add(socket);
let icmp_socket = socket_set.get_mut::<IcmpSocket>(handle);
- let ident = ident_set
- .get_port()
- .ok_or_else(|| SyscallError::new(syscall::EINVAL))?;
icmp_socket
.bind(IcmpEndpoint::Udp(IpListenEndpoint::from(ident)))
.map_err(|_| syscall::Error::new(syscall::EINVAL))?;
@@ -213,22 +302,39 @@ impl<'a> SchemeSocket for IcmpSocket<'a> {
return Ok(0);
}
while self.can_recv(&file.data) {
- let (payload, _) = self.recv().expect("Can't recv icmp packet");
+ let (payload, responder_ip) = self.recv().expect("Can't recv icmp packet");
let icmp_packet = Icmpv4Packet::new_unchecked(&payload);
//TODO: replace default with actual caps
- let icmp_repr = Icmpv4Repr::parse(&icmp_packet, &Default::default()).unwrap();
+ let Ok(icmp_repr) = Icmpv4Repr::parse(&icmp_packet, &Default::default()) else {
+ continue;
+ };
- if let Icmpv4Repr::EchoReply { seq_no, data, .. } = icmp_repr {
- if buf.len() < mem::size_of::<u16>() + data.len() {
- return Err(SyscallError::new(syscall::EINVAL));
- }
- buf[0..2].copy_from_slice(&seq_no.to_be_bytes());
+ match file.data.socket_type {
+ IcmpSocketType::Echo => {
+ if let Icmpv4Repr::EchoReply { seq_no, data, .. } = icmp_repr {
+ if buf.len() < mem::size_of::<u16>() + data.len() {
+ return Err(SyscallError::new(syscall::EINVAL));
+ }
+ buf[0..2].copy_from_slice(&seq_no.to_be_bytes());
- for i in 0..data.len() {
- buf[mem::size_of::<u16>() + i] = data[i];
- }
+ for i in 0..data.len() {
+ buf[mem::size_of::<u16>() + i] = data[i];
+ }
- return Ok(mem::size_of::<u16>() + data.len());
+ return Ok(mem::size_of::<u16>() + data.len());
+ }
+ }
+ IcmpSocketType::Udp => {
+ if let Some(result) = build_udp_trace_event(
+ &icmp_repr,
+ responder_ip,
+ file.data.ip,
+ file.data.ident,
+ buf,
+ ) {
+ return result;
+ }
+ }
}
}
@@ -264,7 +370,10 @@ impl<'a> SchemeSocket for IcmpSocket<'a> {
Ok(i)
}
IcmpSocketType::Udp => {
- let path = format!("/scheme/icmp/udp/{}", socket_file.data.ip);
+ let path = format!(
+ "/scheme/icmp/udp/{}/{}",
+ socket_file.data.ip, socket_file.data.ident
+ );
let path = path.as_bytes();
let mut i = 0;
@@ -316,3 +425,123 @@ impl<'a> SchemeSocket for IcmpSocket<'a> {
Ok(0)
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::{
+ build_udp_trace_event, ICMP_UDP_TRACE_KIND_DST_UNREACHABLE,
+ ICMP_UDP_TRACE_KIND_TIME_EXCEEDED,
+ };
+ use smoltcp::wire::{
+ Icmpv4DstUnreachable, Icmpv4Repr, Icmpv4TimeExceeded, IpAddress, IpProtocol, Ipv4Address,
+ Ipv4Repr,
+ };
+
+ fn udp_header(source_port: u16, dest_port: u16) -> [u8; 8] {
+ let mut header = [0_u8; 8];
+ header[0..2].copy_from_slice(&source_port.to_be_bytes());
+ header[2..4].copy_from_slice(&dest_port.to_be_bytes());
+ header[4..6].copy_from_slice(&(8_u16).to_be_bytes());
+ header
+ }
+
+ #[test]
+ fn emits_time_exceeded_trace_event_for_matching_udp_probe() {
+ let remote = Ipv4Address::new(203, 0, 113, 9);
+ let responder = Ipv4Address::new(192, 0, 2, 1);
+ let local_port = 42_000;
+ let dest_port = 33_434;
+ let udp = udp_header(local_port, dest_port);
+ let icmp_repr = Icmpv4Repr::TimeExceeded {
+ reason: Icmpv4TimeExceeded::TtlExpired,
+ header: Ipv4Repr {
+ src_addr: Ipv4Address::new(10, 0, 2, 15),
+ dst_addr: remote,
+ next_header: IpProtocol::Udp,
+ payload_len: udp.len(),
+ hop_limit: 1,
+ },
+ data: &udp,
+ };
+ let mut encoded = [0_u8; 12];
+
+ let written = build_udp_trace_event(
+ &icmp_repr,
+ IpAddress::Ipv4(responder),
+ IpAddress::Ipv4(remote),
+ local_port,
+ &mut encoded,
+ )
+ .expect("expected traceroute event")
+ .expect("expected successful encoding");
+
+ assert_eq!(written, encoded.len());
+ assert_eq!(encoded[0], ICMP_UDP_TRACE_KIND_TIME_EXCEEDED);
+ assert_eq!(encoded[1], u8::from(Icmpv4TimeExceeded::TtlExpired));
+ assert_eq!(&encoded[4..8], &responder.octets());
+ assert_eq!(u16::from_be_bytes([encoded[8], encoded[9]]), local_port);
+ assert_eq!(u16::from_be_bytes([encoded[10], encoded[11]]), dest_port);
+ }
+
+ #[test]
+ fn ignores_udp_trace_event_for_wrong_destination() {
+ let remote = Ipv4Address::new(203, 0, 113, 9);
+ let other_remote = Ipv4Address::new(203, 0, 113, 19);
+ let udp = udp_header(42_000, 33_434);
+ let icmp_repr = Icmpv4Repr::TimeExceeded {
+ reason: Icmpv4TimeExceeded::TtlExpired,
+ header: Ipv4Repr {
+ src_addr: Ipv4Address::new(10, 0, 2, 15),
+ dst_addr: remote,
+ next_header: IpProtocol::Udp,
+ payload_len: udp.len(),
+ hop_limit: 1,
+ },
+ data: &udp,
+ };
+
+ assert!(build_udp_trace_event(
+ &icmp_repr,
+ IpAddress::Ipv4(Ipv4Address::new(192, 0, 2, 1)),
+ IpAddress::Ipv4(other_remote),
+ 42_000,
+ &mut [0_u8; 12],
+ )
+ .is_none());
+ }
+
+ #[test]
+ fn emits_destination_unreachable_trace_event() {
+ let remote = Ipv4Address::new(203, 0, 113, 9);
+ let responder = Ipv4Address::new(203, 0, 113, 9);
+ let local_port = 42_000;
+ let dest_port = 33_450;
+ let udp = udp_header(local_port, dest_port);
+ let icmp_repr = Icmpv4Repr::DstUnreachable {
+ reason: Icmpv4DstUnreachable::PortUnreachable,
+ header: Ipv4Repr {
+ src_addr: Ipv4Address::new(10, 0, 2, 15),
+ dst_addr: remote,
+ next_header: IpProtocol::Udp,
+ payload_len: udp.len(),
+ hop_limit: 64,
+ },
+ data: &udp,
+ };
+ let mut encoded = [0_u8; 12];
+
+ let written = build_udp_trace_event(
+ &icmp_repr,
+ IpAddress::Ipv4(responder),
+ IpAddress::Ipv4(remote),
+ local_port,
+ &mut encoded,
+ )
+ .expect("expected traceroute event")
+ .expect("expected successful encoding");
+
+ assert_eq!(written, encoded.len());
+ assert_eq!(encoded[0], ICMP_UDP_TRACE_KIND_DST_UNREACHABLE);
+ assert_eq!(encoded[1], u8::from(Icmpv4DstUnreachable::PortUnreachable));
+ }
+}
+impl GenericAddress {
+ pub fn is_empty(&self) -> bool {
+ self.address == 0