diff --git a/ConfigureChecks.cmake b/ConfigureChecks.cmake index 2c78b83f168dabde65f1c1b20dc7c03bc0ff0b31..b820a65d2dd6f204466ed2cd82cb0ca5c6d50b92 100644 --- a/ConfigureChecks.cmake +++ b/ConfigureChecks.cmake @@ -48,8 +48,10 @@ check_include_file(sys/filio.h HAVE_SYS_FILIO_H) check_include_file(sys/signalfd.h HAVE_SYS_SIGNALFD_H) check_include_file(sys/eventfd.h HAVE_SYS_EVENTFD_H) check_include_file(sys/timerfd.h HAVE_SYS_TIMERFD_H) +check_include_file(sys/syscall.h HAVE_SYS_SYSCALL_H) check_include_file(gnu/lib-names.h HAVE_GNU_LIB_NAMES_H) check_include_file(rpc/rpc.h HAVE_RPC_RPC_H) +check_include_file(syscall.h HAVE_SYSCALL_H) # SYMBOLS set(CMAKE_REQUIRED_FLAGS -D_GNU_SOURCE) @@ -75,6 +77,9 @@ check_function_exists(pledge HAVE_PLEDGE) check_function_exists(_socket HAVE__SOCKET) check_function_exists(_close HAVE__CLOSE) check_function_exists(__close_nocancel HAVE___CLOSE_NOCANCEL) +check_function_exists(recvmmsg HAVE_RECVMMSG) +check_function_exists(sendmmsg HAVE_SENDMMSG) +check_function_exists(syscall HAVE_SYSCALL) if (UNIX) find_library(DLFCN_LIBRARY dl) @@ -147,6 +152,44 @@ if (HAVE_EVENTFD) HAVE_EVENTFD_UNSIGNED_INT) endif (HAVE_EVENTFD) +if (HAVE_SYSCALL) + set(CMAKE_REQUIRED_DEFINITIONS -D_GNU_SOURCE) + + check_prototype_definition(syscall + "int syscall(int sysno, ...)" + "-1" + "unistd.h;sys/syscall.h" + HAVE_SYSCALL_INT) + set(CMAKE_REQUIRED_DEFINITIONS) +endif (HAVE_SYSCALL) + +if (HAVE_RECVMMSG) + # Linux legacy glibc < 2.21 + set(CMAKE_REQUIRED_DEFINITIONS -D_GNU_SOURCE) + check_prototype_definition(recvmmsg + "int recvmmsg(int __fd, struct mmsghdr *__vmessages, unsigned int __vlen, int __flags, const struct timespec *__tmo)" + "-1" + "sys/types.h;sys/socket.h" + HAVE_RECVMMSG_CONST_TIMEOUT) + set(CMAKE_REQUIRED_DEFINITIONS) + + # FreeBSD + check_prototype_definition(recvmmsg + "ssize_t recvmmsg(int __fd, struct mmsghdr * __restrict __vmessages, size_t __vlen, int __flags, const struct timespec * __restrict __tmo)" + "-1" + "sys/types.h;sys/socket.h" + HAVE_RECVMMSG_SSIZE_T_CONST_TIMEOUT) +endif (HAVE_RECVMMSG) + +if (HAVE_SENDMMSG) + # FreeBSD + check_prototype_definition(sendmmsg + "ssize_t sendmmsg(int __fd, struct mmsghdr * __restrict __vmessages, size_t __vlen, int __flags)" + "-1" + "sys/types.h;sys/socket.h" + HAVE_SENDMMSG_SSIZE_T) +endif (HAVE_SENDMMSG) + # IPV6 check_c_source_compiles(" #include diff --git a/config.h.cmake b/config.h.cmake index 0f2fb09e58b224c13f9e6b47661dc9faf59b0525..a637a34da5afbff5a8d04147db4fed09d2727701 100644 --- a/config.h.cmake +++ b/config.h.cmake @@ -13,9 +13,11 @@ #cmakedefine HAVE_SYS_FILIO_H 1 #cmakedefine HAVE_SYS_SIGNALFD_H 1 #cmakedefine HAVE_SYS_EVENTFD_H 1 +#cmakedefine HAVE_SYS_SYSCALL_H 1 #cmakedefine HAVE_SYS_TIMERFD_H 1 #cmakedefine HAVE_GNU_LIB_NAMES_H 1 #cmakedefine HAVE_RPC_RPC_H 1 +#cmakedefine HAVE_SYSCALL_H 1 /**************************** STRUCTS ****************************/ @@ -52,6 +54,13 @@ #cmakedefine HAVE_ACCEPT_PSOCKLEN_T 1 #cmakedefine HAVE_IOCTL_INT 1 #cmakedefine HAVE_EVENTFD_UNSIGNED_INT 1 +#cmakedefine HAVE_RECVMMSG 1 +#cmakedefine HAVE_RECVMMSG_CONST_TIMEOUT 1 +#cmakedefine HAVE_RECVMMSG_SSIZE_T_CONST_TIMEOUT 1 +#cmakedefine HAVE_SENDMMSG 1 +#cmakedefine HAVE_SENDMMSG_SSIZE_T 1 +#cmakedefine HAVE_SYSCALL 1 +#cmakedefine HAVE_SYSCALL_INT 1 /*************************** LIBRARIES ***************************/ diff --git a/src/socket_wrapper.c b/src/socket_wrapper.c index bedda07a72b939cb8a302b40bca174d048925009..ab90c22c922a1150afc7c291968628fc3f359fbb 100644 --- a/src/socket_wrapper.c +++ b/src/socket_wrapper.c @@ -47,6 +47,12 @@ #include #include #include +#ifdef HAVE_SYS_SYSCALL_H +#include +#endif +#ifdef HAVE_SYSCALL_H +#include +#endif #include #include #ifdef HAVE_SYS_FILIO_H @@ -537,8 +543,29 @@ typedef int (*__libc_recvfrom)(int sockfd, struct sockaddr *src_addr, socklen_t *addrlen); typedef int (*__libc_recvmsg)(int sockfd, const struct msghdr *msg, int flags); +#ifdef HAVE_RECVMMSG +#if defined(HAVE_RECVMMSG_SSIZE_T_CONST_TIMEOUT) +/* FreeBSD */ +typedef ssize_t (*__libc_recvmmsg)(int sockfd, struct mmsghdr *msgvec, size_t vlen, int flags, const struct timespec *timeout); +#elif defined(HAVE_RECVMMSG_CONST_TIMEOUT) +/* Linux legacy glibc < 2.21 */ +typedef int (*__libc_recvmmsg)(int sockfd, struct mmsghdr *msgvec, unsigned int vlen, int flags, const struct timespec *timeout); +#else +/* Linux glibc >= 2.21 */ +typedef int (*__libc_recvmmsg)(int sockfd, struct mmsghdr *msgvec, unsigned int vlen, int flags, struct timespec *timeout); +#endif +#endif /* HAVE_RECVMMSG */ typedef int (*__libc_send)(int sockfd, const void *buf, size_t len, int flags); typedef int (*__libc_sendmsg)(int sockfd, const struct msghdr *msg, int flags); +#ifdef HAVE_SENDMMSG +#if defined(HAVE_SENDMMSG_SSIZE_T) +/* FreeBSD */ +typedef ssize_t (*__libc_sendmmsg)(int sockfd, struct mmsghdr *msgvec, size_t vlen, int flags); +#else +/* Linux */ +typedef int (*__libc_sendmmsg)(int sockfd, struct mmsghdr *msgvec, unsigned int vlen, int flags); +#endif +#endif /* HAVE_SENDMMSG */ typedef int (*__libc_sendto)(int sockfd, const void *buf, size_t len, @@ -560,6 +587,9 @@ typedef int (*__libc_timerfd_create)(int clockid, int flags); #endif typedef ssize_t (*__libc_write)(int fd, const void *buf, size_t count); typedef ssize_t (*__libc_writev)(int fd, const struct iovec *iov, int iovcnt); +#ifdef HAVE_SYSCALL +typedef long int (*__libc_syscall)(long int sysno, ...); +#endif #define SWRAP_SYMBOL_ENTRY(i) \ union { \ @@ -605,8 +635,14 @@ struct swrap_libc_symbols { SWRAP_SYMBOL_ENTRY(recv); SWRAP_SYMBOL_ENTRY(recvfrom); SWRAP_SYMBOL_ENTRY(recvmsg); +#ifdef HAVE_RECVMMSG + SWRAP_SYMBOL_ENTRY(recvmmsg); +#endif SWRAP_SYMBOL_ENTRY(send); SWRAP_SYMBOL_ENTRY(sendmsg); +#ifdef HAVE_SENDMMSG + SWRAP_SYMBOL_ENTRY(sendmmsg); +#endif SWRAP_SYMBOL_ENTRY(sendto); SWRAP_SYMBOL_ENTRY(setsockopt); #ifdef HAVE_SIGNALFD @@ -619,7 +655,32 @@ struct swrap_libc_symbols { #endif SWRAP_SYMBOL_ENTRY(write); SWRAP_SYMBOL_ENTRY(writev); +#ifdef HAVE_SYSCALL + SWRAP_SYMBOL_ENTRY(syscall); +#endif }; +#undef SWRAP_SYMBOL_ENTRY + +#define SWRAP_SYMBOL_ENTRY(i) \ + union { \ + __rtld_default_##i f; \ + void *obj; \ + } _rtld_default_##i + +#ifdef HAVE_SYSCALL +typedef bool (*__rtld_default_uid_wrapper_syscall_valid)(long int sysno); +typedef long int (*__rtld_default_uid_wrapper_syscall_va)(long int sysno, va_list va); +#endif + +struct swrap_rtld_default_symbols { +#ifdef HAVE_SYSCALL + SWRAP_SYMBOL_ENTRY(uid_wrapper_syscall_valid); + SWRAP_SYMBOL_ENTRY(uid_wrapper_syscall_va); +#else + uint8_t dummy; +#endif +}; +#undef SWRAP_SYMBOL_ENTRY struct swrap { struct { @@ -627,6 +688,10 @@ struct swrap { void *socket_handle; struct swrap_libc_symbols symbols; } libc; + + struct { + struct swrap_rtld_default_symbols symbols; + } rtld_default; }; static struct swrap swrap; @@ -807,6 +872,11 @@ static void _swrap_mutex_unlock(pthread_mutex_t *mutex, const char *name, const #define swrap_bind_symbol_libsocket(sym_name) \ _swrap_bind_symbol_generic(SWRAP_LIBSOCKET, sym_name) +#define swrap_bind_symbol_rtld_default_optional(sym_name) do { \ + swrap.rtld_default.symbols._rtld_default_##sym_name.obj = \ + dlsym(RTLD_DEFAULT, #sym_name); \ +} while(0); + static void swrap_bind_symbol_all(void); /**************************************************************************** @@ -1131,6 +1201,24 @@ static int libc_recvmsg(int sockfd, struct msghdr *msg, int flags) return swrap.libc.symbols._libc_recvmsg.f(sockfd, msg, flags); } +#ifdef HAVE_RECVMMSG +#if defined(HAVE_RECVMMSG_SSIZE_T_CONST_TIMEOUT) +/* FreeBSD */ +static ssize_t libc_recvmmsg(int sockfd, struct mmsghdr *msgvec, size_t vlen, int flags, const struct timespec *timeout) +#elif defined(HAVE_RECVMMSG_CONST_TIMEOUT) +/* Linux legacy glibc < 2.21 */ +static int libc_recvmmsg(int sockfd, struct mmsghdr *msgvec, unsigned int vlen, int flags, const struct timespec *timeout) +#else +/* Linux glibc >= 2.21 */ +static int libc_recvmmsg(int sockfd, struct mmsghdr *msgvec, unsigned int vlen, int flags, struct timespec *timeout) +#endif +{ + swrap_bind_symbol_all(); + + return swrap.libc.symbols._libc_recvmmsg.f(sockfd, msgvec, vlen, flags, timeout); +} +#endif + static int libc_send(int sockfd, const void *buf, size_t len, int flags) { swrap_bind_symbol_all(); @@ -1145,6 +1233,21 @@ static int libc_sendmsg(int sockfd, const struct msghdr *msg, int flags) return swrap.libc.symbols._libc_sendmsg.f(sockfd, msg, flags); } +#ifdef HAVE_SENDMMSG +#if defined(HAVE_SENDMMSG_SSIZE_T) +/* FreeBSD */ +static ssize_t libc_sendmmsg(int sockfd, struct mmsghdr *msgvec, size_t vlen, int flags) +#else +/* Linux */ +static int libc_sendmmsg(int sockfd, struct mmsghdr *msgvec, unsigned int vlen, int flags) +#endif +{ + swrap_bind_symbol_all(); + + return swrap.libc.symbols._libc_sendmmsg.f(sockfd, msgvec, vlen, flags); +} +#endif + static int libc_sendto(int sockfd, const void *buf, size_t len, @@ -1223,6 +1326,64 @@ static ssize_t libc_writev(int fd, const struct iovec *iov, int iovcnt) return swrap.libc.symbols._libc_writev.f(fd, iov, iovcnt); } +#ifdef HAVE_SYSCALL +DO_NOT_SANITIZE_ADDRESS_ATTRIBUTE +static long int libc_vsyscall(long int sysno, va_list va) +{ + long int args[8]; + long int rc; + int i; + + swrap_bind_symbol_all(); + + for (i = 0; i < 8; i++) { + args[i] = va_arg(va, long int); + } + + rc = swrap.libc.symbols._libc_syscall.f(sysno, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[6], + args[7]); + + return rc; +} + +static bool swrap_uwrap_syscall_valid(long int sysno) +{ + swrap_bind_symbol_all(); + + if (swrap.rtld_default.symbols._rtld_default_uid_wrapper_syscall_valid.f == NULL) { + return false; + } + + return swrap.rtld_default.symbols._rtld_default_uid_wrapper_syscall_valid.f( + sysno); +} + +DO_NOT_SANITIZE_ADDRESS_ATTRIBUTE +static long int swrap_uwrap_syscall_va(long int sysno, va_list va) +{ + swrap_bind_symbol_all(); + + if (swrap.rtld_default.symbols._rtld_default_uid_wrapper_syscall_va.f == NULL) { + /* + * Fallback to libc, if uid_wrapper_syscall_va is not + * available. + */ + return libc_vsyscall(sysno, va); + } + + return swrap.rtld_default.symbols._rtld_default_uid_wrapper_syscall_va.f( + sysno, + va); +} +#endif /* HAVE_SYSCALL */ + /* DO NOT call this function during library initialization! */ static void __swrap_bind_symbol_all_once(void) { @@ -1263,8 +1424,14 @@ static void __swrap_bind_symbol_all_once(void) swrap_bind_symbol_libsocket(recv); swrap_bind_symbol_libsocket(recvfrom); swrap_bind_symbol_libsocket(recvmsg); +#ifdef HAVE_RECVMMSG + swrap_bind_symbol_libsocket(recvmmsg); +#endif swrap_bind_symbol_libsocket(send); swrap_bind_symbol_libsocket(sendmsg); +#ifdef HAVE_SENDMMSG + swrap_bind_symbol_libsocket(sendmmsg); +#endif swrap_bind_symbol_libsocket(sendto); swrap_bind_symbol_libsocket(setsockopt); #ifdef HAVE_SIGNALFD @@ -1277,6 +1444,11 @@ static void __swrap_bind_symbol_all_once(void) #endif swrap_bind_symbol_libc(write); swrap_bind_symbol_libsocket(writev); +#ifdef HAVE_SYSCALL + swrap_bind_symbol_libc(syscall); + swrap_bind_symbol_rtld_default_optional(uid_wrapper_syscall_valid); + swrap_bind_symbol_rtld_default_optional(uid_wrapper_syscall_va); +#endif } static void swrap_bind_symbol_all(void) @@ -1429,6 +1601,55 @@ static size_t socket_length(int family) return 0; } +struct swrap_sockaddr_buf { + char str[128]; +}; + +static const char *swrap_sockaddr_string(struct swrap_sockaddr_buf *buf, + const struct sockaddr *saddr) +{ + unsigned int port = 0; + char addr[64] = {0,}; + + switch (saddr->sa_family) { + case AF_INET: { + const struct sockaddr_in *in = + (const struct sockaddr_in *)(const void *)saddr; + + port = ntohs(in->sin_port); + + inet_ntop(saddr->sa_family, + &in->sin_addr, + addr, sizeof(addr)); + break; + } +#ifdef HAVE_IPV6 + case AF_INET6: { + const struct sockaddr_in6 *in6 = + (const struct sockaddr_in6 *)(const void *)saddr; + + port = ntohs(in6->sin6_port); + + inet_ntop(saddr->sa_family, + &in6->sin6_addr, + addr, sizeof(addr)); + break; + } +#endif + default: + snprintf(addr, sizeof(addr), + "", + saddr->sa_family); + break; + } + + snprintf(buf->str, sizeof(buf->str), + "addr[%s]/port[%u]", + addr, port); + + return buf->str; +} + static struct socket_info *swrap_get_socket_info(int si_index) { return (struct socket_info *)(&(sockets[si_index].info)); @@ -2064,13 +2285,10 @@ static int convert_in_un_remote(struct socket_info *si, const struct sockaddr *i type = u_type; iface = (addr & 0x000000FF); } else { - char str[256] = {0,}; - inet_ntop(inaddr->sa_family, - &in->sin_addr, - str, sizeof(str)); + struct swrap_sockaddr_buf buf = {}; SWRAP_LOG(SWRAP_LOG_WARN, - "str[%s] prt[%u]", - str, (unsigned)prt); + "%s", + swrap_sockaddr_string(&buf, inaddr)); errno = ENETUNREACH; return -1; } @@ -2106,13 +2324,10 @@ static int convert_in_un_remote(struct socket_info *si, const struct sockaddr *i if (IN6_ARE_ADDR_EQUAL(&cmp1, &cmp2)) { iface = in->sin6_addr.s6_addr[15]; } else { - char str[256] = {0,}; - inet_ntop(inaddr->sa_family, - &in->sin6_addr, - str, sizeof(str)); + struct swrap_sockaddr_buf buf = {}; SWRAP_LOG(SWRAP_LOG_WARN, - "str[%s] prt[%u]", - str, (unsigned)prt); + "%s", + swrap_sockaddr_string(&buf, inaddr)); errno = ENETUNREACH; return -1; } @@ -3984,6 +4199,7 @@ static int swrap_connect(int s, const struct sockaddr *serv_addr, .sa_socklen = sizeof(struct sockaddr_un), }; struct socket_info *si = find_socket_info(s); + struct swrap_sockaddr_buf buf = {}; int bcast = 0; if (!si) { @@ -4032,7 +4248,8 @@ static int swrap_connect(int s, const struct sockaddr *serv_addr, } SWRAP_LOG(SWRAP_LOG_TRACE, - "connect() path=%s, fd=%d", + "connect(%s) path=%s, fd=%d", + swrap_sockaddr_string(&buf, serv_addr), un_addr.sa.un.sun_path, s); @@ -4098,6 +4315,8 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) .sa_socklen = sizeof(struct sockaddr_un), }; struct socket_info *si = find_socket_info(s); + struct swrap_sockaddr_buf buf = {}; + int ret_errno = errno; int bind_error = 0; #if 0 /* FIXME */ bool in_use; @@ -4155,7 +4374,7 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) } if (bind_error != 0) { - errno = bind_error; + ret_errno = bind_error; ret = -1; goto out; } @@ -4179,16 +4398,21 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) 1, &si->bcast); if (ret == -1) { + ret_errno = errno; goto out; } unlink(un_addr.sa.un.sun_path); ret = libc_bind(s, &un_addr.sa.s, un_addr.sa_socklen); + if (ret == -1) { + ret_errno = errno; + } SWRAP_LOG(SWRAP_LOG_TRACE, - "bind() path=%s, fd=%d", - un_addr.sa.un.sun_path, s); + "bind(%s) path=%s, fd=%d ret=%d ret_errno=%d", + swrap_sockaddr_string(&buf, myaddr), + un_addr.sa.un.sun_path, s, ret, ret_errno); if (ret == 0) { si->bound = 1; @@ -4196,7 +4420,7 @@ static int swrap_bind(int s, const struct sockaddr *myaddr, socklen_t addrlen) out: SWRAP_UNLOCK_SI(si); - + errno = ret_errno; return ret; } @@ -6120,6 +6344,7 @@ static ssize_t swrap_sendmsg_before(int fd, { size_t i, len = 0; ssize_t ret = -1; + struct swrap_sockaddr_buf buf = {}; if (to_un) { *to_un = NULL; @@ -6181,6 +6406,10 @@ static ssize_t swrap_sendmsg_before(int fd, msg->msg_name = NULL; msg->msg_namelen = 0; } + SWRAP_LOG(SWRAP_LOG_TRACE, + "connected(%s) fd=%d", + swrap_sockaddr_string(&buf, &si->peername.sa.s), + fd); } else { const struct sockaddr *msg_name; msg_name = (const struct sockaddr *)msg->msg_name; @@ -6235,6 +6464,11 @@ static ssize_t swrap_sendmsg_before(int fd, goto out; } + SWRAP_LOG(SWRAP_LOG_TRACE, + "deferred connect(%s) path=%s, fd=%d", + swrap_sockaddr_string(&buf, &si->peername.sa.s), + tmp_un->sun_path, fd); + ret = libc_connect(fd, (struct sockaddr *)(void *)tmp_un, sizeof(*tmp_un)); @@ -7135,6 +7369,219 @@ ssize_t recvmsg(int sockfd, struct msghdr *msg, int flags) return swrap_recvmsg(sockfd, msg, flags); } +/**************************************************************************** + * RECVMMSG + ***************************************************************************/ + +#ifdef HAVE_RECVMMSG +#if defined(HAVE_RECVMMSG_SSIZE_T_CONST_TIMEOUT) +/* FreeBSD */ +static ssize_t swrap_recvmmsg(int s, struct mmsghdr *omsgvec, size_t _vlen, int flags, const struct timespec *timeout) +#elif defined(HAVE_RECVMMSG_CONST_TIMEOUT) +/* Linux legacy glibc < 2.21 */ +static int swrap_recvmmsg(int s, struct mmsghdr *omsgvec, unsigned int _vlen, int flags, const struct timespec *timeout) +#else +/* Linux glibc >= 2.21 */ +static int swrap_recvmmsg(int s, struct mmsghdr *omsgvec, unsigned int _vlen, int flags, struct timespec *timeout) +#endif +{ + struct socket_info *si = find_socket_info(s); +#define __SWRAP_RECVMMSG_MAX_VLEN 16 + struct mmsghdr msgvec[__SWRAP_RECVMMSG_MAX_VLEN] = {}; + struct { + struct iovec iov; + struct swrap_address from_addr; + struct swrap_address convert_addr; +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + size_t msg_ctrllen_filled; + size_t msg_ctrllen_left; +#endif + } tmp[__SWRAP_RECVMMSG_MAX_VLEN] = {}; + int vlen; + int i; + int ret; + int rc; + int saved_errno; + + if (_vlen > __SWRAP_RECVMMSG_MAX_VLEN) { + vlen = __SWRAP_RECVMMSG_MAX_VLEN; + } else { + vlen = _vlen; + } + + if (si == NULL) { + uint8_t *tmp_control[__SWRAP_RECVMMSG_MAX_VLEN] = { NULL, }; + + for (i = 0; i < vlen; i++) { + struct msghdr *omsg = &omsgvec[i].msg_hdr; + struct msghdr *msg = &msgvec[i].msg_hdr; + + rc = swrap_recvmsg_before_unix(omsg, msg, + &tmp_control[i]); + if (rc < 0) { + ret = rc; + goto fail_libc; + } + } + + ret = libc_recvmmsg(s, msgvec, vlen, flags, timeout); + if (ret < 0) { + goto fail_libc; + } + + for (i = 0; i < ret; i++) { + omsgvec[i].msg_len = msgvec[i].msg_len; + } + +fail_libc: + saved_errno = errno; + for (i = 0; i < vlen; i++) { + struct msghdr *omsg = &omsgvec[i].msg_hdr; + struct msghdr *msg = &msgvec[i].msg_hdr; + + if (i == 0 || i < ret) { + swrap_recvmsg_after_unix(msg, &tmp_control[i], omsg, ret); + } + SAFE_FREE(tmp_control[i]); + } + errno = saved_errno; + + return ret; + } + + for (i = 0; i < vlen; i++) { + struct msghdr *omsg = &omsgvec[i].msg_hdr; + struct msghdr *msg = &msgvec[i].msg_hdr; + + tmp[i].from_addr.sa_socklen = sizeof(struct sockaddr_un); + tmp[i].convert_addr.sa_socklen = sizeof(struct sockaddr_storage); + + msg->msg_name = &tmp[i].from_addr.sa; /* optional address */ + msg->msg_namelen = tmp[i].from_addr.sa_socklen; /* size of address */ + msg->msg_iov = omsg->msg_iov; /* scatter/gather array */ + msg->msg_iovlen = omsg->msg_iovlen; /* # elements in msg_iov */ +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + tmp[i].msg_ctrllen_filled = 0; + tmp[i].msg_ctrllen_left = omsg->msg_controllen; + + msg->msg_control = omsg->msg_control; /* ancillary data, see below */ + msg->msg_controllen = omsg->msg_controllen; /* ancillary data buffer len */ + msg->msg_flags = omsg->msg_flags; /* flags on received message */ +#endif + + rc = swrap_recvmsg_before(s, si, msg, &tmp[i].iov); + if (rc < 0) { + ret = rc; + goto fail_swrap; + } + } + + ret = libc_recvmmsg(s, msgvec, vlen, flags, timeout); + if (ret < 0) { + goto fail_swrap; + } + + for (i = 0; i < ret; i++) { + omsgvec[i].msg_len = msgvec[i].msg_len; + } + +fail_swrap: + + saved_errno = errno; + for (i = 0; i < vlen; i++) { + struct msghdr *omsg = &omsgvec[i].msg_hdr; + struct msghdr *msg = &msgvec[i].msg_hdr; + + if (!(i == 0 || i < ret)) { + break; + } + +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + tmp[i].msg_ctrllen_filled += msg->msg_controllen; + tmp[i].msg_ctrllen_left -= msg->msg_controllen; + + if (omsg->msg_control != NULL) { + uint8_t *p; + + p = omsg->msg_control; + p += tmp[i].msg_ctrllen_filled; + + msg->msg_control = p; + msg->msg_controllen = tmp[i].msg_ctrllen_left; + } else { + msg->msg_control = NULL; + msg->msg_controllen = 0; + } +#endif + + /* + * We convert the unix address to a IP address so we need a buffer + * which can store the address in case of SOCK_DGRAM, see below. + */ + msg->msg_name = &tmp[i].convert_addr.sa; + msg->msg_namelen = tmp[i].convert_addr.sa_socklen; + + swrap_recvmsg_after(s, si, msg, + &tmp[i].from_addr.sa.un, + tmp[i].from_addr.sa_socklen, + ret); + +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + if (omsg->msg_control != NULL) { + /* msg->msg_controllen = space left */ + tmp[i].msg_ctrllen_left = msg->msg_controllen; + tmp[i].msg_ctrllen_filled = omsg->msg_controllen - tmp[i].msg_ctrllen_left; + } + + /* Update the original message length */ + omsg->msg_controllen = tmp[i].msg_ctrllen_filled; + omsg->msg_flags = msg->msg_flags; +#endif + omsg->msg_iovlen = msg->msg_iovlen; + + SWRAP_LOCK_SI(si); + + /* + * From the manpage: + * + * The msg_name field points to a caller-allocated buffer that is + * used to return the source address if the socket is unconnected. The + * caller should set msg_namelen to the size of this buffer before this + * call; upon return from a successful call, msg_name will contain the + * length of the returned address. If the application does not need + * to know the source address, msg_name can be specified as NULL. + */ + if (si->type == SOCK_STREAM) { + omsg->msg_namelen = 0; + } else if (omsg->msg_name != NULL && + omsg->msg_namelen != 0 && + omsg->msg_namelen >= msg->msg_namelen) { + memcpy(omsg->msg_name, msg->msg_name, msg->msg_namelen); + omsg->msg_namelen = msg->msg_namelen; + } + + SWRAP_UNLOCK_SI(si); + } + errno = saved_errno; + + return ret; +} + +#if defined(HAVE_RECVMMSG_SSIZE_T_CONST_TIMEOUT) +/* FreeBSD */ +ssize_t recvmmsg(int sockfd, struct mmsghdr *msgvec, size_t vlen, int flags, const struct timespec *timeout) +#elif defined(HAVE_RECVMMSG_CONST_TIMEOUT) +/* Linux legacy glibc < 2.21 */ +int recvmmsg(int sockfd, struct mmsghdr *msgvec, unsigned int vlen, int flags, const struct timespec *timeout) +#else +/* Linux glibc >= 2.21 */ +int recvmmsg(int sockfd, struct mmsghdr *msgvec, unsigned int vlen, int flags, struct timespec *timeout) +#endif +{ + return swrap_recvmmsg(sockfd, msgvec, vlen, flags, timeout); +} +#endif /* HAVE_RECVMMSG */ + /**************************************************************************** * SENDMSG ***************************************************************************/ @@ -7306,6 +7753,249 @@ ssize_t sendmsg(int s, const struct msghdr *omsg, int flags) return swrap_sendmsg(s, omsg, flags); } +/**************************************************************************** + * SENDMMSG + ***************************************************************************/ + +#ifdef HAVE_SENDMMSG +#if defined(HAVE_SENDMMSG_SSIZE_T) +/* FreeBSD */ +static ssize_t swrap_sendmmsg(int s, struct mmsghdr *omsgvec, size_t _vlen, int flags) +#else +/* Linux */ +static int swrap_sendmmsg(int s, struct mmsghdr *omsgvec, unsigned int _vlen, int flags) +#endif +{ + struct socket_info *si = find_socket_info(s); +#define __SWRAP_SENDMMSG_MAX_VLEN 16 + struct mmsghdr msgvec[__SWRAP_SENDMMSG_MAX_VLEN] = {}; + struct { + struct iovec iov; + struct sockaddr_un un_addr; + const struct sockaddr_un *to_un; + const struct sockaddr *to; + int bcast; + } tmp[__SWRAP_SENDMMSG_MAX_VLEN] = {}; + int vlen; + int i; + char *swrap_dir = NULL; + int connected = 0; + int found_bcast = 0; + int ret; + int rc; + int saved_errno; + + if (_vlen > __SWRAP_SENDMMSG_MAX_VLEN) { + vlen = __SWRAP_SENDMMSG_MAX_VLEN; + } else { + vlen = _vlen; + } + + if (!si) { + int scm_rights_pipe_fd[__SWRAP_SENDMMSG_MAX_VLEN]; + + for (i = 0; i < __SWRAP_SENDMMSG_MAX_VLEN; i++) { + scm_rights_pipe_fd[i] = -1; + } + + for (i = 0; i < vlen; i++) { + struct msghdr *omsg = &omsgvec[i].msg_hdr; + struct msghdr *msg = &msgvec[i].msg_hdr; + + rc = swrap_sendmsg_before_unix(omsg, msg, + &scm_rights_pipe_fd[i]); + if (rc < 0) { + ret = rc; + goto fail_libc; + } + } + + ret = libc_sendmmsg(s, msgvec, vlen, flags); + if (ret < 0) { + goto fail_libc; + } + + for (i = 0; i < ret; i++) { + omsgvec[i].msg_len = msgvec[i].msg_len; + } + +fail_libc: + saved_errno = errno; + for (i = 0; i < vlen; i++) { + struct msghdr *msg = &msgvec[i].msg_hdr; + + swrap_sendmsg_after_unix(msg, ret, + scm_rights_pipe_fd[i]); + } + errno = saved_errno; + + return ret; + } + + SWRAP_LOCK_SI(si); + connected = si->connected; + SWRAP_UNLOCK_SI(si); + + for (i = 0; i < vlen; i++) { + struct msghdr *omsg = &omsgvec[i].msg_hdr; + struct msghdr *msg = &msgvec[i].msg_hdr; + + if (connected == 0) { + msg->msg_name = omsg->msg_name; /* optional address */ + msg->msg_namelen = omsg->msg_namelen; /* size of address */ + } + msg->msg_iov = omsg->msg_iov; /* scatter/gather array */ + msg->msg_iovlen = omsg->msg_iovlen; /* # elements in msg_iov */ + +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + if (omsg->msg_controllen > 0 && omsg->msg_control != NULL) { + uint8_t *cmbuf = NULL; + size_t cmlen = 0; + + rc = swrap_sendmsg_filter_cmsghdr(omsg, &cmbuf, &cmlen); + if (rc < 0) { + ret = rc; + goto fail_swrap; + } + + if (cmlen != 0) { + msg->msg_control = cmbuf; + msg->msg_controllen = cmlen; + } + } + msg->msg_flags = omsg->msg_flags; /* flags on received message */ +#endif + + rc = swrap_sendmsg_before(s, si, msg, + &tmp[i].iov, + &tmp[i].un_addr, + &tmp[i].to_un, + &tmp[i].to, + &tmp[i].bcast); + if (rc < 0) { + ret = rc; + goto fail_swrap; + } + + if (tmp[i].bcast) { + found_bcast = 1; + } + } + + if (found_bcast) { + + swrap_dir = socket_wrapper_dir(); + if (swrap_dir == NULL) { + ret = -1; + goto fail_swrap; + } + + for (i = 0; i < vlen; i++) { + struct msghdr *msg = &msgvec[i].msg_hdr; + struct sockaddr_un *un_addr = &tmp[i].un_addr; + const struct sockaddr *to = tmp[i].to; + struct stat st; + unsigned int iface; + unsigned int prt = ntohs(((const struct sockaddr_in *)(const void *)to)->sin_port); + char type; + size_t l, len = 0; + uint8_t *buf; + off_t ofs = 0; + size_t avail = 0; + size_t remain; + + for (l = 0; l < (size_t)msg->msg_iovlen; l++) { + avail += msg->msg_iov[l].iov_len; + } + + len = avail; + remain = avail; + + /* we capture it as one single packet */ + buf = (uint8_t *)malloc(remain); + if (!buf) { + ret = -1; + goto fail_swrap; + } + + for (l = 0; l < (size_t)msg->msg_iovlen; l++) { + size_t this_time = MIN(remain, (size_t)msg->msg_iov[l].iov_len); + memcpy(buf + ofs, + msg->msg_iov[l].iov_base, + this_time); + ofs += this_time; + remain -= this_time; + } + + type = SOCKET_TYPE_CHAR_UDP; + + for(iface=0; iface <= MAX_WRAPPED_INTERFACES; iface++) { + swrap_un_path(un_addr, swrap_dir, type, iface, prt); + if (stat(un_addr->sun_path, &st) != 0) continue; + + msg->msg_name = un_addr; /* optional address */ + msg->msg_namelen = sizeof(*un_addr); /* size of address */ + + /* + * ignore the any errors in broadcast sends and + * do a single sendmsg instead of sendmmsg + */ + libc_sendmsg(s, msg, flags); + } + + SWRAP_LOCK_SI(si); + swrap_pcap_dump_packet(si, to, SWRAP_SENDTO, buf, len); + SWRAP_UNLOCK_SI(si); + + SAFE_FREE(buf); + + msgvec[i].msg_len = len; + } + + ret = vlen; + goto bcast_done; + } + + ret = libc_sendmmsg(s, msgvec, vlen, flags); + if (ret < 0) { + goto fail_swrap; + } + +bcast_done: + for (i = 0; i < ret; i++) { + omsgvec[i].msg_len = msgvec[i].msg_len; + } + +fail_swrap: + saved_errno = errno; + for (i = 0; i < vlen; i++) { + struct msghdr *msg = &msgvec[i].msg_hdr; + + if (i == 0 || i < ret) { + swrap_sendmsg_after(s, si, msg, tmp[i].to, ret); + } +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + SAFE_FREE(msg->msg_control); +#endif + } + SAFE_FREE(swrap_dir); + errno = saved_errno; + + return ret; +} + +#if defined(HAVE_SENDMMSG_SSIZE_T) +/* FreeBSD */ +ssize_t sendmmsg(int s, struct mmsghdr *msgvec, size_t vlen, int flags) +#else +/* Linux */ +int sendmmsg(int s, struct mmsghdr *msgvec, unsigned int vlen, int flags) +#endif +{ + return swrap_sendmmsg(s, msgvec, vlen, flags); +} +#endif /* HAVE_SENDMMSG */ + /**************************************************************************** * READV ***************************************************************************/ @@ -7773,6 +8463,169 @@ int pledge(const char *promises, const char *paths[]) } #endif /* HAVE_PLEDGE */ +#ifdef HAVE_SYSCALL +static bool swrap_is_swrap_related_syscall(long int sysno) +{ + switch (sysno) { +#ifdef SYS_close + case SYS_close: + return true; +#endif /* SYS_close */ + +#ifdef SYS_recvmmsg + case SYS_recvmmsg: + return true; +#endif /* SYS_recvmmsg */ + +#ifdef SYS_sendmmsg + case SYS_sendmmsg: + return true; +#endif /* SYS_sendmmsg */ + + default: + return false; + } +} + +static long int swrap_syscall(long int sysno, va_list vp) +{ + long int rc; + + switch (sysno) { +#ifdef SYS_close + case SYS_close: + { + int fd = (int) va_arg(vp, int); + + SWRAP_LOG(SWRAP_LOG_TRACE, + "calling swrap_close syscall %lu", + sysno); + rc = swrap_close(fd); + } + break; +#endif /* SYS_close */ + +#ifdef SYS_recvmmsg + case SYS_recvmmsg: + { + int fd = (int) va_arg(vp, int); + struct mmsghdr *msgvec = va_arg(vp, struct mmsghdr *); + unsigned int vlen = va_arg(vp, unsigned int); + int flags = va_arg(vp, int); + struct timespec *timeout = va_arg(vp, struct timespec *); + + SWRAP_LOG(SWRAP_LOG_TRACE, + "calling swrap_recvmmsg syscall %lu", + sysno); + rc = swrap_recvmmsg(fd, msgvec, vlen, flags, timeout); + } + break; +#endif /* SYS_recvmmsg */ + +#ifdef SYS_sendmmsg + case SYS_sendmmsg: + { + int fd = (int) va_arg(vp, int); + struct mmsghdr *msgvec = va_arg(vp, struct mmsghdr *); + unsigned int vlen = va_arg(vp, unsigned int); + int flags = va_arg(vp, int); + + SWRAP_LOG(SWRAP_LOG_TRACE, + "calling swrap_sendmmsg syscall %lu", + sysno); + rc = swrap_sendmmsg(fd, msgvec, vlen, flags); + } + break; +#endif /* SYS_sendmmsg */ + + default: + rc = -1; + errno = ENOSYS; + break; + } + + return rc; +} + +#ifdef HAVE_SYSCALL_INT +int syscall (int sysno, ...) +#else +long int syscall (long int sysno, ...) +#endif +{ +#ifdef HAVE_SYSCALL_INT + int rc; +#else + long int rc; +#endif + va_list va; + + va_start(va, sysno); + + /* + * We should only handle the syscall numbers + * we care about... + */ + if (!swrap_is_swrap_related_syscall(sysno)) { + /* + * We need to give socket_wrapper a + * chance to take over... + */ + if (swrap_uwrap_syscall_valid(sysno)) { + rc = swrap_uwrap_syscall_va(sysno, va); + va_end(va); + return rc; + } + + rc = libc_vsyscall(sysno, va); + va_end(va); + return rc; + } + + if (!socket_wrapper_enabled()) { + rc = libc_vsyscall(sysno, va); + va_end(va); + return rc; + } + + rc = swrap_syscall(sysno, va); + va_end(va); + + return rc; +} + +/* used by uid_wrapper */ +bool socket_wrapper_syscall_valid(long int sysno); +bool socket_wrapper_syscall_valid(long int sysno) +{ + if (!swrap_is_swrap_related_syscall(sysno)) { + return false; + } + + if (!socket_wrapper_enabled()) { + return false; + } + + return true; +} + +/* used by uid_wrapper */ +long int socket_wrapper_syscall_va(long int sysno, va_list va); +long int socket_wrapper_syscall_va(long int sysno, va_list va) +{ + if (!swrap_is_swrap_related_syscall(sysno)) { + errno = ENOSYS; + return -1; + } + + if (!socket_wrapper_enabled()) { + return libc_vsyscall(sysno, va); + } + + return swrap_syscall(sysno, va); +} +#endif /* HAVE_SYSCALL */ + static void swrap_thread_prepare(void) { /* diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 2f98af1460163dad8b6da891160b05da7b41c5b8..6c3aae9026b2b33709f2e3162dc093daa1d37443 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -32,6 +32,15 @@ target_link_libraries(${TORTURE_LIBRARY} ${SWRAP_REQUIRED_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) +add_library(swrap_fake_uid_wrapper SHARED swrap_fake_uid_wrapper.c) +target_compile_options(swrap_fake_uid_wrapper + PRIVATE + ${DEFAULT_C_COMPILE_FLAGS} + -D_GNU_SOURCE) +#target_include_directories(swrap_fake_uid_wrapper +# PRIVATE ${CMAKE_BINARY_DIR} ${CMOCKA_INCLUDE_DIR}) +set(SWRAP_FAKE_UID_WRAPPER_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_SHARED_LIBRARY_PREFIX}swrap_fake_uid_wrapper${CMAKE_SHARED_LIBRARY_SUFFIX}") + set(SWRAP_THREADED_TESTS test_thread_sockets test_thread_echo_tcp_connect @@ -48,6 +57,7 @@ set(SWRAP_TESTS test_echo_tcp_bind test_echo_tcp_socket_options test_echo_tcp_sendmsg_recvmsg + test_echo_tcp_sendmmsg_recvmmsg test_echo_tcp_write_read test_echo_tcp_writev_readv test_echo_tcp_get_peer_sock_name @@ -59,6 +69,7 @@ set(SWRAP_TESTS test_public_functions test_close_failure test_tcp_socket_overwrite + test_syscall_uwrap ${SWRAP_THREADED_TESTS}) if (HAVE_STRUCT_MSGHDR_MSG_CONTROL) @@ -84,6 +95,7 @@ function(ADD_CMOCKA_TEST_ENVIRONMENT _TEST_NAME) if (ASAN_LIBRARY) list(APPEND PRELOAD_LIBRARIES ${ASAN_LIBRARY}) endif() + list(APPEND PRELOAD_LIBRARIES ${SWRAP_FAKE_UID_WRAPPER_LOCATION}) list(APPEND PRELOAD_LIBRARIES ${SOCKET_WRAPPER_LOCATION}) if (OSX) diff --git a/tests/swrap_fake_uid_wrapper.c b/tests/swrap_fake_uid_wrapper.c new file mode 100644 index 0000000000000000000000000000000000000000..286d7addfd8ce75867d4feac853aa4708075e51c --- /dev/null +++ b/tests/swrap_fake_uid_wrapper.c @@ -0,0 +1,44 @@ +#include "config.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#ifdef HAVE_SYS_SYSCALL_H +#include +#endif +#ifdef HAVE_SYSCALL_H +#include +#endif + +#include "swrap_fake_uid_wrapper.h" + +/* simulate uid_wrapper hooks */ +bool uid_wrapper_syscall_valid(long int sysno) +{ + if (sysno == __FAKE_UID_WRAPPER_SYSCALL_NO) { + return true; + } + + return false; +} + +long int uid_wrapper_syscall_va(long int sysno, va_list va) +{ + (void) va; /* unused */ + + if (sysno == __FAKE_UID_WRAPPER_SYSCALL_NO) { + errno = 0; + return __FAKE_UID_WRAPPER_SYSCALL_RC; + } + + errno = ENOSYS; + return -1; +} diff --git a/tests/swrap_fake_uid_wrapper.h b/tests/swrap_fake_uid_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..70ac1d02b78bed8a8ca4a316bf1703ed22878690 --- /dev/null +++ b/tests/swrap_fake_uid_wrapper.h @@ -0,0 +1,7 @@ +#include + +/* simulate socket_wrapper hooks */ +#define __FAKE_UID_WRAPPER_SYSCALL_NO 123456789 +#define __FAKE_UID_WRAPPER_SYSCALL_RC 987654321 +bool uid_wrapper_syscall_valid(long int sysno); +long int uid_wrapper_syscall_va(long int sysno, va_list va); diff --git a/tests/test_echo_tcp_sendmmsg_recvmmsg.c b/tests/test_echo_tcp_sendmmsg_recvmmsg.c new file mode 100644 index 0000000000000000000000000000000000000000..5715ab27ddd5acdcdbae0be5e30e1be70b94d51a --- /dev/null +++ b/tests/test_echo_tcp_sendmmsg_recvmmsg.c @@ -0,0 +1,427 @@ +#include +#include +#include +#include + +#include "config.h" +#include "torture.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef HAVE_SYS_SYSCALL_H +#include +#endif +#ifdef HAVE_SYSCALL_H +#include +#endif + +static int setup_echo_srv_tcp_ipv4(void **state) +{ + torture_setup_echo_srv_tcp_ipv4(state); + + return 0; +} + +#ifdef HAVE_IPV6 +static int setup_echo_srv_tcp_ipv6(void **state) +{ + torture_setup_echo_srv_tcp_ipv6(state); + + return 0; +} +#endif + +static int teardown(void **state) +{ + torture_teardown_echo_srv(state); + + return 0; +} + +static void test_sendmmsg_recvmmsg_ipv4_ignore(void **state) +{ + struct torture_address addr = { + .sa_socklen = sizeof(struct sockaddr_storage), + }; + struct { + struct torture_address reply_addr; + struct iovec s_iov; + struct iovec r_iov; + char send_buf[64]; + char recv_buf[64]; + } tmsgs[10] = {}; + struct mmsghdr s_msgs[10] = {}; + struct mmsghdr r_msgs[10] = {}; + ssize_t ret; + int rc; + int i; + int s; + + (void) state; /* unused */ + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + assert_int_not_equal(s, -1); + + addr.sa.in.sin_family = AF_INET; + addr.sa.in.sin_port = htons(torture_server_port()); + + rc = inet_pton(AF_INET, + torture_server_address(AF_INET), + &addr.sa.in.sin_addr); + assert_int_equal(rc, 1); + + rc = connect(s, &addr.sa.s, addr.sa_socklen); + assert_return_code(rc, errno); + + /* This should be ignored */ + rc = inet_pton(AF_INET, + "127.0.0.1", + &addr.sa.in.sin_addr); + assert_int_equal(rc, 1); + + for (i = 0; i < 10; i++) { + tmsgs[i].reply_addr = (struct torture_address){ + .sa_socklen = sizeof(struct sockaddr_storage), + }; + + snprintf(tmsgs[i].send_buf, sizeof(tmsgs[i].send_buf), "packet.%d", i); + + tmsgs[i].s_iov.iov_base = tmsgs[i].send_buf; + tmsgs[i].s_iov.iov_len = sizeof(tmsgs[i].send_buf); + + s_msgs[i].msg_hdr.msg_name = &addr.sa.s; + s_msgs[i].msg_hdr.msg_namelen = addr.sa_socklen; + s_msgs[i].msg_hdr.msg_iov = &tmsgs[i].s_iov; + s_msgs[i].msg_hdr.msg_iovlen = 1; + + tmsgs[i].r_iov.iov_base = tmsgs[i].recv_buf; + tmsgs[i].r_iov.iov_len = sizeof(tmsgs[i].recv_buf); + + r_msgs[i].msg_hdr.msg_name = &tmsgs[i].reply_addr.sa.s; + r_msgs[i].msg_hdr.msg_namelen = tmsgs[i].reply_addr.sa_socklen; + r_msgs[i].msg_hdr.msg_iov = &tmsgs[i].r_iov; + r_msgs[i].msg_hdr.msg_iovlen = 1; + } + + ret = sendmmsg(s, s_msgs, 10, 0); + assert_int_equal(ret, 10); + + ret = recvmmsg(s, r_msgs, 10, 0, NULL); + assert_int_equal(ret, 10); + + for (i = 0; i < 10; i++) { + assert_int_equal(r_msgs[i].msg_hdr.msg_namelen, 0); + assert_ptr_equal(r_msgs[i].msg_hdr.msg_name, &tmsgs[i].reply_addr.sa.s); + + assert_int_equal(r_msgs[i].msg_len, tmsgs[i].s_iov.iov_len); + assert_memory_equal(tmsgs[i].send_buf, tmsgs[i].recv_buf, sizeof(tmsgs[i].send_buf)); + } + + rc = close(s); + assert_int_equal(rc, 0); +} + +#ifdef HAVE_IPV6 +static void test_sendmmsg_recvmmsg_ipv6(void **state) +{ + struct torture_address addr = { + .sa_socklen = sizeof(struct sockaddr_storage), + }; + struct { + struct torture_address reply_addr; + struct iovec s_iov; + struct iovec r_iov; + char send_buf[64]; + char recv_buf[64]; + } tmsgs[10] = {}; + struct mmsghdr s_msgs[10] = {}; + struct mmsghdr r_msgs[10] = {}; + ssize_t ret; + int rc; + int i; + int s; + + (void) state; /* unused */ + + s = socket(AF_INET6, SOCK_STREAM, IPPROTO_TCP); + assert_int_not_equal(s, -1); + + addr.sa.in.sin_family = AF_INET6; + addr.sa.in.sin_port = htons(torture_server_port()); + + rc = inet_pton(AF_INET6, + torture_server_address(AF_INET6), + &addr.sa.in6.sin6_addr); + assert_int_equal(rc, 1); + + rc = connect(s, &addr.sa.s, addr.sa_socklen); + assert_return_code(rc, errno); + + for (i = 0; i < 10; i++) { + tmsgs[i].reply_addr = (struct torture_address){ + .sa_socklen = sizeof(struct sockaddr_storage), + }; + + snprintf(tmsgs[i].send_buf, sizeof(tmsgs[i].send_buf), "packet.%d", i); + + tmsgs[i].s_iov.iov_base = tmsgs[i].send_buf; + tmsgs[i].s_iov.iov_len = sizeof(tmsgs[i].send_buf); + + s_msgs[i].msg_hdr.msg_name = &addr.sa.s; + s_msgs[i].msg_hdr.msg_namelen = addr.sa_socklen; + s_msgs[i].msg_hdr.msg_iov = &tmsgs[i].s_iov; + s_msgs[i].msg_hdr.msg_iovlen = 1; + + tmsgs[i].r_iov.iov_base = tmsgs[i].recv_buf; + tmsgs[i].r_iov.iov_len = sizeof(tmsgs[i].recv_buf); + + r_msgs[i].msg_hdr.msg_name = &tmsgs[i].reply_addr.sa.s; + r_msgs[i].msg_hdr.msg_namelen = tmsgs[i].reply_addr.sa_socklen; + r_msgs[i].msg_hdr.msg_iov = &tmsgs[i].r_iov; + r_msgs[i].msg_hdr.msg_iovlen = 1; + } + + ret = sendmmsg(s, s_msgs, 10, 0); + assert_int_equal(ret, 10); + + ret = recvmmsg(s, r_msgs, 10, 0, NULL); + assert_int_equal(ret, 10); + + for (i = 0; i < 10; i++) { + assert_int_equal(r_msgs[i].msg_hdr.msg_namelen, 0); + assert_ptr_equal(r_msgs[i].msg_hdr.msg_name, &tmsgs[i].reply_addr.sa.s); + + assert_int_equal(r_msgs[i].msg_len, tmsgs[i].s_iov.iov_len); + assert_memory_equal(tmsgs[i].send_buf, tmsgs[i].recv_buf, sizeof(tmsgs[i].send_buf)); + } + + rc = close(s); + assert_int_equal(rc, 0); +} +#endif + +static void test_sendmmsg_recvmmsg_ipv4_null(void **state) +{ + struct torture_address addr = { + .sa_socklen = sizeof(struct sockaddr_storage), + }; + struct { + struct torture_address reply_addr; + struct iovec s_iov; + struct iovec r_iov; + char send_buf[64]; + char recv_buf[64]; + } tmsgs[10] = {}; + struct mmsghdr s_msgs[10] = {}; + struct mmsghdr r_msgs[10] = {}; + ssize_t ret; + int rc; + int i; + int s; + + (void) state; /* unused */ + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + assert_int_not_equal(s, -1); + + addr.sa.in.sin_family = AF_INET; + addr.sa.in.sin_port = htons(torture_server_port()); + + rc = inet_pton(AF_INET, + torture_server_address(AF_INET), + &addr.sa.in.sin_addr); + assert_int_equal(rc, 1); + + rc = connect(s, &addr.sa.s, addr.sa_socklen); + assert_return_code(rc, errno); + + for (i = 0; i < 10; i++) { + tmsgs[i].reply_addr = (struct torture_address){ + .sa_socklen = sizeof(struct sockaddr_storage), + }; + + snprintf(tmsgs[i].send_buf, sizeof(tmsgs[i].send_buf), "packet.%d", i); + + tmsgs[i].s_iov.iov_base = tmsgs[i].send_buf; + tmsgs[i].s_iov.iov_len = sizeof(tmsgs[i].send_buf); + + s_msgs[i].msg_hdr.msg_name = NULL; + s_msgs[i].msg_hdr.msg_namelen = 0; + s_msgs[i].msg_hdr.msg_iov = &tmsgs[i].s_iov; + s_msgs[i].msg_hdr.msg_iovlen = 1; + + tmsgs[i].r_iov.iov_base = tmsgs[i].recv_buf; + tmsgs[i].r_iov.iov_len = sizeof(tmsgs[i].recv_buf); + + r_msgs[i].msg_hdr.msg_name = NULL; + r_msgs[i].msg_hdr.msg_namelen = 0; + r_msgs[i].msg_hdr.msg_iov = &tmsgs[i].r_iov; + r_msgs[i].msg_hdr.msg_iovlen = 1; + } + + ret = sendmmsg(s, s_msgs, 10, 0); + assert_int_equal(ret, 10); + + ret = recvmmsg(s, r_msgs, 10, 0, NULL); + assert_int_equal(ret, 10); + + for (i = 0; i < 10; i++) { + assert_int_equal(r_msgs[i].msg_hdr.msg_namelen, 0); + assert_null(r_msgs[i].msg_hdr.msg_name); + + assert_int_equal(r_msgs[i].msg_len, tmsgs[i].s_iov.iov_len); + assert_memory_equal(tmsgs[i].send_buf, tmsgs[i].recv_buf, sizeof(tmsgs[i].send_buf)); + } + + rc = close(s); + assert_int_equal(rc, 0); +} + +#ifdef SYS_recvmmsg +static int __raw_syscall_close(int sockfd) +{ + return syscall(SYS_close, sockfd); +} + +static int __raw_syscall_recvmmsg(int sockfd, + struct mmsghdr *msgvec, + unsigned int vlen, + int flags, + struct timespec *timeout) +{ + return syscall(SYS_recvmmsg, + sockfd, + (uintptr_t)msgvec, + vlen, + flags, + (uintptr_t)timeout); +} + +static int __raw_syscall_sendmmsg(int sockfd, + struct mmsghdr *msgvec, + unsigned int vlen, + int flags) +{ + return syscall(SYS_sendmmsg, + sockfd, + (uintptr_t)msgvec, + vlen, + flags); +} + +static void test_sendmmsg_recvmmsg_ipv4_raw(void **state) +{ + struct torture_address addr = { + .sa_socklen = sizeof(struct sockaddr_storage), + }; + struct { + struct torture_address reply_addr; + struct iovec s_iov; + struct iovec r_iov; + char send_buf[64]; + char recv_buf[64]; + } tmsgs[10] = {}; + struct mmsghdr s_msgs[10] = {}; + struct mmsghdr r_msgs[10] = {}; + ssize_t ret; + int rc; + int i; + int s; + + (void) state; /* unused */ + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + assert_int_not_equal(s, -1); + + addr.sa.in.sin_family = AF_INET; + addr.sa.in.sin_port = htons(torture_server_port()); + + rc = inet_pton(AF_INET, + torture_server_address(AF_INET), + &addr.sa.in.sin_addr); + assert_int_equal(rc, 1); + + rc = connect(s, &addr.sa.s, addr.sa_socklen); + assert_return_code(rc, errno); + + /* This should be ignored */ + rc = inet_pton(AF_INET, + "127.0.0.1", + &addr.sa.in.sin_addr); + assert_int_equal(rc, 1); + + for (i = 0; i < 10; i++) { + tmsgs[i].reply_addr = (struct torture_address){ + .sa_socklen = sizeof(struct sockaddr_storage), + }; + + snprintf(tmsgs[i].send_buf, sizeof(tmsgs[i].send_buf), "packet.%d", i); + + tmsgs[i].s_iov.iov_base = tmsgs[i].send_buf; + tmsgs[i].s_iov.iov_len = sizeof(tmsgs[i].send_buf); + + s_msgs[i].msg_hdr.msg_name = &addr.sa.s; + s_msgs[i].msg_hdr.msg_namelen = addr.sa_socklen; + s_msgs[i].msg_hdr.msg_iov = &tmsgs[i].s_iov; + s_msgs[i].msg_hdr.msg_iovlen = 1; + + tmsgs[i].r_iov.iov_base = tmsgs[i].recv_buf; + tmsgs[i].r_iov.iov_len = sizeof(tmsgs[i].recv_buf); + + r_msgs[i].msg_hdr.msg_name = &tmsgs[i].reply_addr.sa.s; + r_msgs[i].msg_hdr.msg_namelen = tmsgs[i].reply_addr.sa_socklen; + r_msgs[i].msg_hdr.msg_iov = &tmsgs[i].r_iov; + r_msgs[i].msg_hdr.msg_iovlen = 1; + } + + ret = __raw_syscall_sendmmsg(s, s_msgs, 10, 0); + assert_int_equal(ret, 10); + + ret = __raw_syscall_recvmmsg(s, r_msgs, 10, 0, NULL); + assert_int_equal(ret, 10); + + for (i = 0; i < 10; i++) { + assert_int_equal(r_msgs[i].msg_hdr.msg_namelen, 0); + assert_ptr_equal(r_msgs[i].msg_hdr.msg_name, &tmsgs[i].reply_addr.sa.s); + + assert_int_equal(r_msgs[i].msg_len, tmsgs[i].s_iov.iov_len); + assert_memory_equal(tmsgs[i].send_buf, tmsgs[i].recv_buf, sizeof(tmsgs[i].send_buf)); + } + + rc = __raw_syscall_close(s); + assert_int_equal(rc, 0); +} +#endif /* SYS_recvmmsg */ + +int main(void) { + int rc; + + const struct CMUnitTest sendmsg_tests[] = { + cmocka_unit_test_setup_teardown(test_sendmmsg_recvmmsg_ipv4_ignore, + setup_echo_srv_tcp_ipv4, + teardown), + cmocka_unit_test_setup_teardown(test_sendmmsg_recvmmsg_ipv4_null, + setup_echo_srv_tcp_ipv4, + teardown), +#ifdef SYS_recvmmsg + cmocka_unit_test_setup_teardown(test_sendmmsg_recvmmsg_ipv4_raw, + setup_echo_srv_tcp_ipv4, + teardown), +#endif /* SYS_recvmmsg */ +#ifdef HAVE_IPV6 + cmocka_unit_test_setup_teardown(test_sendmmsg_recvmmsg_ipv6, + setup_echo_srv_tcp_ipv6, + teardown), +#endif + }; + + rc = cmocka_run_group_tests(sendmsg_tests, NULL, NULL); + + return rc; +} diff --git a/tests/test_syscall_uwrap.c b/tests/test_syscall_uwrap.c new file mode 100644 index 0000000000000000000000000000000000000000..4ded634be2ba71ca9d78d13a5ba940d5d51c6850 --- /dev/null +++ b/tests/test_syscall_uwrap.c @@ -0,0 +1,50 @@ +#include "config.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#ifdef HAVE_SYS_SYSCALL_H +#include +#endif +#ifdef HAVE_SYSCALL_H +#include +#endif + +#include "swrap_fake_uid_wrapper.h" + +static void test_swrap_syscall_uwrap(void **state) +{ + long int rc; + + (void) state; /* unused */ + + rc = syscall(__FAKE_UID_WRAPPER_SYSCALL_NO); + assert_int_equal(rc, __FAKE_UID_WRAPPER_SYSCALL_RC); + + signal(SIGSYS, SIG_IGN); + rc = syscall(__FAKE_UID_WRAPPER_SYSCALL_NO+1); + signal(SIGSYS, SIG_DFL); + assert_int_equal(rc, -1); + assert_int_equal(errno, ENOSYS); +} + +int main(void) { + int rc; + + const struct CMUnitTest swrap_tests[] = { + cmocka_unit_test(test_swrap_syscall_uwrap), + }; + + rc = cmocka_run_group_tests(swrap_tests, NULL, NULL); + + return rc; +}