diff --git a/src/header/sys_socket/mod.rs b/src/header/sys_socket/mod.rs
--- a/src/header/sys_socket/mod.rs
+++ b/src/header/sys_socket/mod.rs
@@ -5,7 +5,7 @@
 use core::{mem, ptr};
 
 use crate::{
-    error::ResultExt,
+    error::{Errno, ResultExt},
     header::{bits_iovec::iovec, bits_safamily_t::sa_family_t, bits_socklen_t::socklen_t},
     platform::{
         PalSocket, Sys,
@@ -236,6 +236,58 @@
         socket,
         address,
         address_len
+    )
+}
+
+/// Non-POSIX, see <https://man.freebsd.org/cgi/man.cgi?query=getpeereid&sektion=3>.
+///
+/// Stores the `uid`/`gid` that `getsockopt(SOL_SOCKET, SO_PEERCRED)` reports for
+/// `socket` into `*euid`/`*egid`. Fails with `EFAULT` if either output pointer is
+/// null, with `EINVAL` if the returned option length is shorter than a `ucred`, and
+/// with whatever error `getsockopt` itself reports.
+#[unsafe(no_mangle)]
+pub unsafe extern "C" fn getpeereid(socket: c_int, euid: *mut uid_t, egid: *mut gid_t) -> c_int {
+    trace_expr!(
+        (|| {
+            if euid.is_null() || egid.is_null() {
+                return Err(Errno(crate::header::errno::EFAULT));
+            }
+
+            let mut cred = ucred {
+                pid: 0,
+                uid: 0,
+                gid: 0,
+            };
+            let mut len = mem::size_of::<ucred>() as socklen_t;
+            // SAFETY: `cred` is a live, writable `ucred` and `len` is initialised to its
+            // size, which bounds how many bytes getsockopt may store into it.
+            unsafe {
+                Sys::getsockopt(
+                    socket,
+                    constants::SOL_SOCKET,
+                    constants::SO_PEERCRED,
+                    &mut cred as *mut ucred as *mut c_void,
+                    &mut len,
+                )?;
+            }
+
+            if (len as usize) < mem::size_of::<ucred>() {
+                return Err(Errno(crate::header::errno::EINVAL));
+            }
+
+            // SAFETY: both pointers were checked for null above; the caller must pass
+            // pointers to writable `uid_t`/`gid_t` storage, as with getpeereid(3).
+            unsafe {
+                *euid = cred.uid;
+                *egid = cred.gid;
+            }
+            Ok(0)
+        })()
+        .or_minus_one_errno(),
+        "getpeereid({}, {:p}, {:p})",
+        socket,
+        euid,
+        egid
     )
 }
 
diff --git a/tests/Makefile.tests.mk b/tests/Makefile.tests.mk
--- a/tests/Makefile.tests.mk
+++ b/tests/Makefile.tests.mk
@@ -137,6 +137,8 @@
 	string/stpcpy \
 	string/stpncpy \
 	strings \
+	sys_socket/passcred \
+	sys_socket/peercred \
 	sys_socket/recv \
 	sys_socket/recvfrom \
 	sys_socket/unixrecv \
diff --git a/tests/sys_socket/peercred.c b/tests/sys_socket/peercred.c
new file mode 100644
--- /dev/null
+++ b/tests/sys_socket/peercred.c
@@ -0,0 +1,35 @@
+#include <assert.h>
+#include <stdio.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include "test_helpers.h"
+
+int main(void)
+{
+    int sv[2];
+    int status = socketpair(AF_UNIX, SOCK_STREAM, 0, sv);
+    ERROR_IF(socketpair, status, == -1);
+
+    // getpeereid() is built on SO_PEERCRED and reports the peer's effective
+    // IDs, so both are checked against geteuid()/getegid().
+    struct ucred cred;
+    socklen_t len = sizeof(cred);
+    status = getsockopt(sv[0], SOL_SOCKET, SO_PEERCRED, &cred, &len);
+    ERROR_IF(getsockopt, status, == -1);
+    assert(len == sizeof(cred));
+    assert(cred.uid == geteuid());
+    assert(cred.gid == getegid());
+    assert(cred.pid > 0);
+
+    uid_t euid;
+    gid_t egid;
+    status = getpeereid(sv[0], &euid, &egid);
+    ERROR_IF(getpeereid, status, == -1);
+    assert(euid == geteuid());
+    assert(egid == getegid());
+
+    close(sv[0]);
+    close(sv[1]);
+    return 0;
+}
diff --git a/tests/sys_socket/passcred.c b/tests/sys_socket/passcred.c
new file mode 100644
--- /dev/null
+++ b/tests/sys_socket/passcred.c
@@ -0,0 +1,70 @@
+#include <assert.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include "test_helpers.h"
+
+int main(void)
+{
+    int sv[2];
+    int status = socketpair(AF_UNIX, SOCK_STREAM, 0, sv);
+    ERROR_IF(socketpair, status, == -1);
+
+    int one = 1;
+    status = setsockopt(sv[1], SOL_SOCKET, SO_PASSCRED, &one, sizeof(one));
+    ERROR_IF(setsockopt, status, == -1);
+
+    const char payload[] = "x";
+    ssize_t sent = send(sv[0], payload, sizeof(payload), 0);
+    ERROR_IF(send, sent, == -1);
+    assert(sent == (ssize_t)sizeof(payload));
+
+    // The receive buffer must hold the whole payload (including its NUL);
+    // a stream recvmsg never returns more bytes than the iovec can hold.
+    char data[sizeof(payload)] = {0};
+    struct iovec iov = {
+        .iov_base = data,
+        .iov_len = sizeof(data),
+    };
+    union {
+        char buf[CMSG_SPACE(sizeof(struct ucred))];
+        struct cmsghdr align;
+    } control;
+    memset(&control, 0, sizeof(control));
+
+    struct msghdr msg = {
+        .msg_name = NULL,
+        .msg_namelen = 0,
+        .msg_iov = &iov,
+        .msg_iovlen = 1,
+        .msg_control = control.buf,
+        .msg_controllen = sizeof(control.buf),
+        .msg_flags = 0,
+    };
+
+    ssize_t recvd = recvmsg(sv[1], &msg, 0);
+    ERROR_IF(recvmsg, recvd, == -1);
+    assert(recvd == (ssize_t)sizeof(payload));
+    assert(memcmp(data, payload, sizeof(payload)) == 0);
+    assert((msg.msg_flags & MSG_CTRUNC) == 0);
+
+    struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
+    assert(cmsg != NULL);
+    assert(cmsg->cmsg_level == SOL_SOCKET);
+    assert(cmsg->cmsg_type == SCM_CREDENTIALS);
+    assert(cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred)));
+
+    // CMSG_DATA() need not be suitably aligned for struct ucred, so copy the
+    // payload out with memcpy() instead of dereferencing it in place.
+    struct ucred cred;
+    memcpy(&cred, CMSG_DATA(cmsg), sizeof(cred));
+    assert(cred.uid == getuid());
+    assert(cred.gid == getgid());
+    assert(cred.pid > 0);
+
+    close(sv[0]);
+    close(sv[1]);
+    return 0;
+}