os: wrap recvmsg(MSG_CMSG_CLOEXEC)
authorPekka Paalanen <ppaalanen@gmail.com>
Wed, 21 Mar 2012 14:00:23 +0000 (16:00 +0200)
committerPekka Paalanen <ppaalanen@gmail.com>
Wed, 25 Apr 2012 06:37:42 +0000 (09:37 +0300)
Some system C libraries do not have MSG_CMSG_CLOEXEC. This flag would
automatically set O_CLOEXEC flag on any received file descriptors.

Provide a fallback that does it manually. If setting CLOEXEC fails, the
file descriptor is closed immediately, which will lead to failures but
avoid leaks. However, setting CLOEXEC is not really expected to fail
occasionally.

Add tests for the wrapper. The setup is copied from connection-test.c.

Signed-off-by: Pekka Paalanen <ppaalanen@gmail.com>
src/connection.c
src/wayland-os.c
src/wayland-os.h
tests/os-wrappers-test.c

index c49ca3d..06cc66f 100644 (file)
@@ -307,7 +307,7 @@ wl_connection_data(struct wl_connection *connection, uint32_t mask)
                msg.msg_flags = 0;
 
                do {
-                       len = recvmsg(connection->fd, &msg, MSG_CMSG_CLOEXEC);
+                       len = wl_os_recvmsg_cloexec(connection->fd, &msg, 0);
                } while (len < 0 && errno == EINTR);
 
                if (len < 0) {
index 4a19da6..eb53eec 100644 (file)
@@ -79,3 +79,48 @@ wl_os_dupfd_cloexec(int fd, long minfd)
        newfd = fcntl(fd, F_DUPFD, minfd);
        return set_cloexec_or_close(newfd);
 }
+
+static ssize_t
+recvmsg_cloexec_fallback(int sockfd, struct msghdr *msg, int flags)
+{
+       ssize_t len;
+       struct cmsghdr *cmsg;
+       unsigned char *data;
+       int *fd;
+       int *end;
+
+       len = recvmsg(sockfd, msg, flags);
+       if (len == -1)
+               return -1;
+
+       if (!msg->msg_control || msg->msg_controllen == 0)
+               return len;
+
+       cmsg = CMSG_FIRSTHDR(msg);
+       for (; cmsg != NULL; cmsg = CMSG_NXTHDR(msg, cmsg)) {
+               if (cmsg->cmsg_level != SOL_SOCKET ||
+                   cmsg->cmsg_type != SCM_RIGHTS)
+                       continue;
+
+               data = CMSG_DATA(cmsg);
+               end = (int *)(data + cmsg->cmsg_len - CMSG_LEN(0));
+               for (fd = (int *)data; fd < end; ++fd)
+                       *fd = set_cloexec_or_close(*fd);
+       }
+
+       return len;
+}
+
+ssize_t
+wl_os_recvmsg_cloexec(int sockfd, struct msghdr *msg, int flags)
+{
+       ssize_t len;
+
+       len = recvmsg(sockfd, msg, flags | MSG_CMSG_CLOEXEC);
+       if (len >= 0)
+               return len;
+       if (errno != EINVAL)
+               return -1;
+
+       return recvmsg_cloexec_fallback(sockfd, msg, flags);
+}
index 456d8b0..43c317b 100644 (file)
@@ -29,6 +29,10 @@ wl_os_socket_cloexec(int domain, int type, int protocol);
 int
 wl_os_dupfd_cloexec(int fd, long minfd);
 
+ssize_t
+wl_os_recvmsg_cloexec(int sockfd, struct msghdr *msg, int flags);
+
+
 /*
  * The following are for wayland-os.c and the unit tests.
  * Do not use them elsewhere.
@@ -44,6 +48,10 @@ wl_os_dupfd_cloexec(int fd, long minfd);
 #define F_DUPFD_CLOEXEC 1030
 #endif
 
+#ifndef MSG_CMSG_CLOEXEC
+#define MSG_CMSG_CLOEXEC 0x40000000
+#endif
+
 #endif /* __linux__ */
 
 #endif
index 2272b73..657f1fe 100644 (file)
@@ -1,5 +1,6 @@
 /*
  * Copyright © 2012 Collabora, Ltd.
+ * Copyright © 2012 Intel Corporation
  *
  * Permission to use, copy, modify, distribute, and sell this software and its
  * documentation for any purpose is hereby granted without fee, provided that
 #include <assert.h>
 #include <sys/types.h>
 #include <sys/socket.h>
+#include <sys/stat.h>
 #include <unistd.h>
 #include <dlfcn.h>
 #include <errno.h>
 #include <stdarg.h>
 #include <fcntl.h>
+#include <stdio.h>
 
+#include "../src/wayland-private.h"
 #include "test-runner.h"
 #include "../src/wayland-os.h"
 
@@ -43,12 +47,16 @@ static int wrapped_calls_socket;
 static int (*real_fcntl)(int, int, ...);
 static int wrapped_calls_fcntl;
 
+static ssize_t (*real_recvmsg)(int, struct msghdr *, int);
+static int wrapped_calls_recvmsg;
+
 static void
 init_fallbacks(int do_fallbacks)
 {
        fall_back = do_fallbacks;
        real_socket = dlsym(RTLD_NEXT, "socket");
        real_fcntl = dlsym(RTLD_NEXT, "fcntl");
+       real_recvmsg = dlsym(RTLD_NEXT, "recvmsg");
 }
 
 __attribute__ ((visibility("default"))) int
@@ -84,6 +92,19 @@ fcntl(int fd, int cmd, ...)
        return real_fcntl(fd, cmd, arg);
 }
 
+__attribute__ ((visibility("default"))) ssize_t
+recvmsg(int sockfd, struct msghdr *msg, int flags)
+{
+       wrapped_calls_recvmsg++;
+
+       if (fall_back && (flags & MSG_CMSG_CLOEXEC)) {
+               errno = EINVAL;
+               return -1;
+       }
+
+       return real_recvmsg(sockfd, msg, flags);
+}
+
 static void
 do_os_wrappers_socket_cloexec(int n)
 {
@@ -156,3 +177,157 @@ TEST(os_wrappers_dupfd_cloexec_fallback)
        init_fallbacks(1);
        do_os_wrappers_dupfd_cloexec(3);
 }
+
+struct marshal_data {
+       struct wl_connection *read_connection;
+       struct wl_connection *write_connection;
+       int s[2];
+       uint32_t read_mask;
+       uint32_t write_mask;
+       union {
+               int h[3];
+       } value;
+       int nr_fds_begin;
+       int nr_fds_conn;
+       int wrapped_calls;
+};
+
+static int
+update_func(struct wl_connection *connection, uint32_t mask, void *data)
+{
+       uint32_t *m = data;
+
+       *m = mask;
+
+       return 0;
+}
+
+static void
+setup_marshal_data(struct marshal_data *data)
+{
+       assert(socketpair(AF_UNIX,
+                         SOCK_STREAM | SOCK_CLOEXEC, 0, data->s) == 0);
+
+       data->read_connection =
+               wl_connection_create(data->s[0],
+                                    update_func, &data->read_mask);
+       assert(data->read_connection);
+       assert(data->read_mask == WL_CONNECTION_READABLE);
+
+       data->write_connection =
+               wl_connection_create(data->s[1],
+                                    update_func, &data->write_mask);
+       assert(data->write_connection);
+       assert(data->write_mask == WL_CONNECTION_READABLE);
+}
+
+static void
+marshal_demarshal(struct marshal_data *data, 
+                 void (*func)(void), int size, const char *format, ...)
+{
+       struct wl_closure closure;
+       static const int opcode = 4444;
+       static struct wl_object sender = { NULL, NULL, 1234 };
+       struct wl_message message = { "test", format, NULL };
+       struct wl_map objects;
+       struct wl_object object;
+       va_list ap;
+       uint32_t msg[1] = { 1234 };
+       int ret;
+
+       va_start(ap, format);
+       ret = wl_closure_vmarshal(&closure, &sender, opcode, ap, &message);
+       va_end(ap);
+
+       assert(ret == 0);
+       assert(wl_closure_send(&closure, data->write_connection) == 0);
+       wl_closure_destroy(&closure);
+       assert(data->write_mask ==
+              (WL_CONNECTION_WRITABLE | WL_CONNECTION_READABLE));
+       assert(wl_connection_data(data->write_connection,
+                                 WL_CONNECTION_WRITABLE) == 0);
+       assert(data->write_mask == WL_CONNECTION_READABLE);
+
+       assert(wl_connection_data(data->read_connection,
+                                 WL_CONNECTION_READABLE) == size);
+
+       wl_map_init(&objects);
+       object.id = msg[0];
+       ret  = wl_connection_demarshal(data->read_connection,
+                                      &closure, size, &objects, &message);
+       wl_closure_invoke(&closure, &object, func, data);
+       wl_closure_destroy(&closure);
+}
+
+static void
+validate_recvmsg_h(struct marshal_data *data,
+                  struct wl_object *object, int fd1, int fd2, int fd3)
+{
+       struct stat buf1, buf2;
+
+       assert(fd1 >= 0);
+       assert(fd2 >= 0);
+       assert(fd3 >= 0);
+
+       assert(fd1 != data->value.h[0]);
+       assert(fd2 != data->value.h[1]);
+       assert(fd3 != data->value.h[2]);
+
+       assert(fstat(fd3, &buf1) == 0);
+       assert(fstat(data->value.h[2], &buf2) == 0);
+       assert(buf1.st_dev == buf2.st_dev);
+       assert(buf1.st_ino == buf2.st_ino);
+
+       /* close the original file descriptors */
+       close(data->value.h[0]);
+       close(data->value.h[1]);
+       close(data->value.h[2]);
+
+       /* the dup'd (received) fds should still be open */
+       assert(count_open_fds() == data->nr_fds_conn + 3);
+
+       /*
+        * Must have 2 calls if falling back, but must also allow
+        * falling back without a forced fallback.
+        */
+       assert(wrapped_calls_recvmsg > data->wrapped_calls);
+
+       if (data->wrapped_calls == 0 && wrapped_calls_recvmsg > 1)
+               printf("recvmsg fell back unforced.\n");
+
+       /* all fds opened during the test in any way should be gone on exec */
+       exec_fd_leak_check(data->nr_fds_begin);
+}
+
+static void
+do_os_wrappers_recvmsg_cloexec(int n)
+{
+       struct marshal_data data;
+
+       data.nr_fds_begin = count_open_fds();
+       data.wrapped_calls = n;
+
+       setup_marshal_data(&data);
+       data.nr_fds_conn = count_open_fds();
+
+       assert(pipe(data.value.h) >= 0);
+
+       data.value.h[2] = open("/dev/zero", O_RDONLY);
+       assert(data.value.h[2] >= 0);
+
+       marshal_demarshal(&data, (void *) validate_recvmsg_h,
+                         8, "hhh", data.value.h[0], data.value.h[1],
+                         data.value.h[2]);
+}
+
+TEST(os_wrappers_recvmsg_cloexec)
+{
+       init_fallbacks(0);
+       do_os_wrappers_recvmsg_cloexec(0);
+}
+
+TEST(os_wrappers_recvmsg_cloexec_fallback)
+{
+       init_fallbacks(1);
+       do_os_wrappers_recvmsg_cloexec(1);
+}