ipcd/uds/stream: carry the connecting caller's credentials per pending connection

What this patch changes (as far as the visible hunks show):

* A listener's `awaiting` queue now holds `(client_id, ucred)` pairs
  instead of bare client ids.
* `handle_connect` now receives the caller's `CallerCtx`, builds a `ucred`
  from its pid/uid/gid, and passes it through `Socket::connect` to
  `Socket::connect_unchecked`, which pushes it onto the listener's queue.
* The accept loop pops the pair and forwards the `ucred` through
  `accept_connection` into `Socket::accept`. `Socket::accept` now builds
  the new socket with a struct literal whose `ucred` field is the queued
  client credential, instead of calling `Self::new(..., ctx)`.
* The other direct caller of `connect_unchecked` (hunk starting at old
  line 991; presumably socketpair-style setup, since the local is named
  `pair_ucred` -- the surrounding function is not shown) builds the
  `ucred` from its own caller's ctx.
* The listener-close path ignores the credential half of each queued
  entry when notifying waiting clients.

Review notes (to confirm before merging):

* `Socket::accept` still takes `ctx: &CallerCtx`, and the visible hunks no
  longer use it (it was only passed to `Self::new`). Unless the elided
  lines between the two `accept` hunks (old lines 247-249) reference it,
  rustc will emit an `unused_variables` warning. Consider `_ctx`, or
  dropping the parameter.
* `Socket::accept` no longer goes through `Self::new`. Any setup that
  `new` performs beyond filling these fields is now skipped; `new` is
  not visible here, so verify it has none.
* The literal `ucred { pid: ctx.pid as _, uid: ctx.uid as _, gid: ctx.gid as _ }`
  is written twice: once in `handle_connect` and once at the old-line-991
  path. A single shared constructor would keep the two in sync. The
  `as _` casts silently truncate or reinterpret if the `CallerCtx`
  field widths differ from `ucred`'s; confirm the types.
* The accepted (server-side) socket's `ucred` now holds the client's
  credentials. Confirm how the peer-credential query (e.g. SO_PEERCRED
  handling, not visible here) reads `ucred` on both ends of the
  connection. Make sure neither side ends up reporting its own
  credentials as the peer's.
* This copy of the patch is not applicable as-is. Its line breaks have
  been collapsed, and text inside angle brackets appears to have been
  stripped (e.g. `options: HashSet,`, `-> Result {`,
  `read_num::(token_buf)`). Regenerate it from the working tree with
  `git diff` rather than repairing it by hand.

diff --git a/ipcd/src/uds/stream.rs b/ipcd/src/uds/stream.rs index 81c846fd..d55d2bb7 100644 --- a/ipcd/src/uds/stream.rs +++ b/ipcd/src/uds/stream.rs @@ -180,7 +180,7 @@ pub struct Socket { options: HashSet, flags: usize, state: State, - awaiting: VecDeque, + awaiting: VecDeque<(usize, ucred)>, connection: Option, issued_token: Option, ucred: ucred, @@ -241,6 +241,7 @@ impl Socket { &mut self, primary_id: usize, awaiting_client_id: usize, + client_ucred: ucred, ctx: &CallerCtx, ) -> Result { if !self.is_listening() { @@ -250,15 +251,17 @@ impl Socket { ); return Err(Error::new(EINVAL)); } - Ok(Self::new( + Ok(Self { primary_id, - self.path.clone(), - State::Established, - self.options.clone(), - self.flags, - Some(Connection::new(awaiting_client_id)), - ctx, - )) + path: self.path.clone(), + state: State::Established, + options: self.options.clone(), + flags: self.flags, + awaiting: VecDeque::new(), + connection: Some(Connection::new(awaiting_client_id)), + issued_token: None, + ucred: client_ucred, + }) } fn establish(&mut self, new_socket: &mut Self, peer: usize) -> Result<()> { @@ -286,7 +289,7 @@ impl Socket { Ok(()) } - fn connect(&mut self, other: &mut Socket) -> Result<()> { + fn connect(&mut self, other: &mut Socket, client_ucred: ucred) -> Result<()> { match self.state { State::Unbound | State::Bound => { // If the socket is unbound or bound, wait for the listener to start listening.
@@ -302,12 +305,12 @@ impl Socket { } _ => return Err(Error::new(ECONNREFUSED)), } - self.connect_unchecked(other); + self.connect_unchecked(other, client_ucred); Ok(()) } - fn connect_unchecked(&mut self, other: &mut Socket) { - self.awaiting.push_back(other.primary_id); + fn connect_unchecked(&mut self, other: &mut Socket, client_ucred: ucred) { + self.awaiting.push_back((other.primary_id, client_ucred)); other.state = State::Connecting; other.connection = Some(Connection::new(self.primary_id)); } @@ -495,7 +498,7 @@ impl<'sock> UdsStreamScheme<'sock> { }; match verb { SocketCall::Bind => self.handle_bind(id, &payload), - SocketCall::Connect => self.handle_connect(id, &payload), + SocketCall::Connect => self.handle_connect(id, &payload, ctx), SocketCall::SetSockOpt => self.handle_setsockopt( id, *metadata.get(1).ok_or(Error::new(EINVAL))? as i32, @@ -588,7 +591,7 @@ impl<'sock> UdsStreamScheme<'sock> { // and changes its own state to `Established`. // // After these three phases, the socket connection is considered established.
- fn handle_connect(&mut self, id: usize, token_buf: &[u8]) -> Result { + fn handle_connect(&mut self, id: usize, token_buf: &[u8], ctx: &CallerCtx) -> Result { let token = read_num::(token_buf)?; let (listener_id, connecting_res) = { let listener_rc = self @@ -633,7 +636,8 @@ impl<'sock> UdsStreamScheme<'sock> { } // Phase 2: listener is now listening - listener.connect(&mut client)?; + let client_ucred = ucred { pid: ctx.pid as _, uid: ctx.uid as _, gid: ctx.gid as _ }; + listener.connect(&mut client, client_ucred)?; (listener_id, connecting_res) }; @@ -873,6 +877,7 @@ impl<'sock> UdsStreamScheme<'sock> { &mut self, listener_socket: &mut Socket, client_id: usize, + client_ucred: ucred, ctx: &CallerCtx, ) -> Result> { let (new_id, new) = { @@ -880,7 +885,7 @@ impl<'sock> UdsStreamScheme<'sock> { return Ok(None); // Client socket has been closed, nothing to accept }; let new_id = self.next_id; - let mut new = listener_socket.accept(new_id, client_id, ctx)?; + let mut new = listener_socket.accept(new_id, client_id, client_ucred, ctx)?; let mut client_socket = client_rc.borrow_mut(); client_socket.establish(&mut new, listener_socket.primary_id)?; @@ -912,14 +917,14 @@ impl<'sock> UdsStreamScheme<'sock> { } loop { // Try to accept a waiting connection - let Some(client_id) = socket.awaiting.pop_front() else { + let Some((client_id, client_ucred)) = socket.awaiting.pop_front() else { if flags & O_NONBLOCK == O_NONBLOCK { return Err(Error::new(EAGAIN)); } else { return Err(Error::new(EWOULDBLOCK)); } }; - return match self.accept_connection(socket, client_id, ctx) { + return match self.accept_connection(socket, client_id, client_ucred, ctx) { Ok(conn) => Ok(conn), Err(Error { errno: EAGAIN }) => continue, Err(e) => Err(e), @@ -991,7 +996,8 @@ impl<'sock> UdsStreamScheme<'sock> { ); return Err(Error::new(EPIPE)); } - socket.connect_unchecked(&mut new); + let pair_ucred = ucred { pid: ctx.pid as _, uid: ctx.uid as _, gid: ctx.gid as _ }; + socket.connect_unchecked(&mut
new, pair_ucred); } // smoltcp sends writeable whenever a listener gets a @@ -1186,7 +1192,7 @@ impl<'sock> UdsStreamScheme<'sock> { } // Notify all waiting clients about listener closure - for client_id in &socket.awaiting { + for (client_id, _) in &socket.awaiting { if let Ok(client_rc) = self.get_socket(*client_id) { { let mut client = client_rc.borrow_mut();