tree-wide: port more code to use send_one_fd() and receive_one_fd()
authorLennart Poettering <lennart@poettering.net>
Tue, 22 Sep 2015 23:00:04 +0000 (01:00 +0200)
committerLennart Poettering <lennart@poettering.net>
Tue, 29 Sep 2015 19:08:37 +0000 (21:08 +0200)
Also, make it slightly more powerful, by accepting a flags argument, and
make it safe for handling if more than one cmsg attribute happens to be
attached.

src/basic/util.c
src/basic/util.h
src/core/namespace.c
src/import/importd.c
src/journal/journal-send.c
src/libsystemd/sd-bus/bus-container.c
src/nspawn/nspawn-expose-ports.c
src/nspawn/nspawn.c

index bc61ec0..c788a1d 100644 (file)
@@ -6201,15 +6201,6 @@ int ptsname_malloc(int fd, char **ret) {
 int openpt_in_namespace(pid_t pid, int flags) {
         _cleanup_close_ int pidnsfd = -1, mntnsfd = -1, usernsfd = -1, rootfd = -1;
         _cleanup_close_pair_ int pair[2] = { -1, -1 };
-        union {
-                struct cmsghdr cmsghdr;
-                uint8_t buf[CMSG_SPACE(sizeof(int))];
-        } control = {};
-        struct msghdr mh = {
-                .msg_control = &control,
-                .msg_controllen = sizeof(control),
-        };
-        struct cmsghdr *cmsg;
         siginfo_t si;
         pid_t child;
         int r;
@@ -6243,15 +6234,7 @@ int openpt_in_namespace(pid_t pid, int flags) {
                 if (unlockpt(master) < 0)
                         _exit(EXIT_FAILURE);
 
-                cmsg = CMSG_FIRSTHDR(&mh);
-                cmsg->cmsg_level = SOL_SOCKET;
-                cmsg->cmsg_type = SCM_RIGHTS;
-                cmsg->cmsg_len = CMSG_LEN(sizeof(int));
-                memcpy(CMSG_DATA(cmsg), &master, sizeof(int));
-
-                mh.msg_controllen = cmsg->cmsg_len;
-
-                if (sendmsg(pair[1], &mh, MSG_NOSIGNAL) < 0)
+                if (send_one_fd(pair[1], master, 0) < 0)
                         _exit(EXIT_FAILURE);
 
                 _exit(EXIT_SUCCESS);
@@ -6265,26 +6248,7 @@ int openpt_in_namespace(pid_t pid, int flags) {
         if (si.si_code != CLD_EXITED || si.si_status != EXIT_SUCCESS)
                 return -EIO;
 
-        if (recvmsg(pair[0], &mh, MSG_NOSIGNAL|MSG_CMSG_CLOEXEC) < 0)
-                return -errno;
-
-        CMSG_FOREACH(cmsg, &mh)
-                if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
-                        int *fds;
-                        unsigned n_fds;
-
-                        fds = (int*) CMSG_DATA(cmsg);
-                        n_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
-
-                        if (n_fds != 1) {
-                                close_many(fds, n_fds);
-                                return -EIO;
-                        }
-
-                        return fds[0];
-                }
-
-        return -EIO;
+        return receive_one_fd(pair[0], 0);
 }
 
 ssize_t fgetxattrat_fake(int dirfd, const char *filename, const char *attribute, void *value, size_t size, int flags) {
@@ -6884,7 +6848,7 @@ int fgetxattr_malloc(int fd, const char *name, char **value) {
         }
 }
 
-int send_one_fd(int transport_fd, int fd) {
+int send_one_fd(int transport_fd, int fd, int flags) {
         union {
                 struct cmsghdr cmsghdr;
                 uint8_t buf[CMSG_SPACE(sizeof(int))];
@@ -6894,7 +6858,6 @@ int send_one_fd(int transport_fd, int fd) {
                 .msg_controllen = sizeof(control),
         };
         struct cmsghdr *cmsg;
-        ssize_t k;
 
         assert(transport_fd >= 0);
         assert(fd >= 0);
@@ -6906,14 +6869,13 @@ int send_one_fd(int transport_fd, int fd) {
         memcpy(CMSG_DATA(cmsg), &fd, sizeof(int));
 
         mh.msg_controllen = CMSG_SPACE(sizeof(int));
-        k = sendmsg(transport_fd, &mh, MSG_NOSIGNAL);
-        if (k < 0)
+        if (sendmsg(transport_fd, &mh, MSG_NOSIGNAL | flags) < 0)
                 return -errno;
 
         return 0;
 }
 
-int receive_one_fd(int transport_fd) {
+int receive_one_fd(int transport_fd, int flags) {
         union {
                 struct cmsghdr cmsghdr;
                 uint8_t buf[CMSG_SPACE(sizeof(int))];
@@ -6922,33 +6884,35 @@ int receive_one_fd(int transport_fd) {
                 .msg_control = &control,
                 .msg_controllen = sizeof(control),
         };
-        struct cmsghdr *cmsg;
-        ssize_t k;
+        struct cmsghdr *cmsg, *found = NULL;
 
         assert(transport_fd >= 0);
 
         /*
-         * Receive a single FD via @transport_fd. We don't care for the
-         * transport-type, but the caller must assure that no other CMSG types
-         * than SCM_RIGHTS is enabled. We also retrieve a single FD at most, so
-         * for packet-based transports, the caller must ensure to send only a
-         * single FD per packet.
-         * This is best used in combination with send_one_fd().
+         * Receive a single FD via @transport_fd. We don't care for
+         * the transport-type. We retrieve a single FD at most, so for
+         * packet-based transports, the caller must ensure to send
+         * only a single FD per packet.  This is best used in
+         * combination with send_one_fd().
          */
 
-        k = recvmsg(transport_fd, &mh, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC);
-        if (k < 0)
+        if (recvmsg(transport_fd, &mh, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC | flags) < 0)
                 return -errno;
 
-        cmsg = CMSG_FIRSTHDR(&mh);
-        if (!cmsg || CMSG_NXTHDR(&mh, cmsg) ||
-            cmsg->cmsg_level != SOL_SOCKET ||
-            cmsg->cmsg_type != SCM_RIGHTS ||
-            cmsg->cmsg_len != CMSG_LEN(sizeof(int)) ||
-            *(const int *)CMSG_DATA(cmsg) < 0) {
+        CMSG_FOREACH(cmsg, &mh) {
+                if (cmsg->cmsg_level == SOL_SOCKET &&
+                    cmsg->cmsg_type == SCM_RIGHTS &&
+                    cmsg->cmsg_len == CMSG_LEN(sizeof(int))) {
+                        assert(!found);
+                        found = cmsg;
+                        break;
+                }
+        }
+
+        if (!found) {
                 cmsg_close_all(&mh);
                 return -EIO;
         }
 
-        return *(const int *)CMSG_DATA(cmsg);
+        return *(int*) CMSG_DATA(found);
 }
index 56d9f03..e6417b4 100644 (file)
@@ -942,5 +942,5 @@ int reset_uid_gid(void);
 int getxattr_malloc(const char *path, const char *name, char **value, bool allow_symlink);
 int fgetxattr_malloc(int fd, const char *name, char **value);
 
-int send_one_fd(int transport_fd, int fd);
-int receive_one_fd(int transport_fd);
+int send_one_fd(int transport_fd, int fd, int flags);
+int receive_one_fd(int transport_fd, int flags);
index eb88574..2b8b707 100644 (file)
@@ -643,16 +643,7 @@ int setup_tmp_dirs(const char *id, char **tmp_dir, char **var_tmp_dir) {
 
 int setup_netns(int netns_storage_socket[2]) {
         _cleanup_close_ int netns = -1;
-        union {
-                struct cmsghdr cmsghdr;
-                uint8_t buf[CMSG_SPACE(sizeof(int))];
-        } control = {};
-        struct msghdr mh = {
-                .msg_control = &control,
-                .msg_controllen = sizeof(control),
-        };
-        struct cmsghdr *cmsg;
-        int r;
+        int r, q;
 
         assert(netns_storage_socket);
         assert(netns_storage_socket[0] >= 0);
@@ -669,12 +660,8 @@ int setup_netns(int netns_storage_socket[2]) {
         if (lockf(netns_storage_socket[0], F_LOCK, 0) < 0)
                 return -errno;
 
-        if (recvmsg(netns_storage_socket[0], &mh, MSG_DONTWAIT|MSG_CMSG_CLOEXEC) < 0) {
-                if (errno != EAGAIN) {
-                        r = -errno;
-                        goto fail;
-                }
-
+        netns = receive_one_fd(netns_storage_socket[0], MSG_DONTWAIT);
+        if (netns == -EAGAIN) {
                 /* Nothing stored yet, so let's create a new namespace */
 
                 if (unshare(CLONE_NEWNET) < 0) {
@@ -691,15 +678,13 @@ int setup_netns(int netns_storage_socket[2]) {
                 }
 
                 r = 1;
-        } else {
-                /* Yay, found something, so let's join the namespace */
 
-                CMSG_FOREACH(cmsg, &mh)
-                        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
-                                assert(cmsg->cmsg_len == CMSG_LEN(sizeof(int)));
-                                netns = *(int*) CMSG_DATA(cmsg);
-                        }
+        } else if (netns < 0) {
+                r = netns;
+                goto fail;
 
+        } else {
+                /* Yay, found something, so let's join the namespace */
                 if (setns(netns, CLONE_NEWNET) < 0) {
                         r = -errno;
                         goto fail;
@@ -708,21 +693,14 @@ int setup_netns(int netns_storage_socket[2]) {
                 r = 0;
         }
 
-        cmsg = CMSG_FIRSTHDR(&mh);
-        cmsg->cmsg_level = SOL_SOCKET;
-        cmsg->cmsg_type = SCM_RIGHTS;
-        cmsg->cmsg_len = CMSG_LEN(sizeof(int));
-        memcpy(CMSG_DATA(cmsg), &netns, sizeof(int));
-        mh.msg_controllen = cmsg->cmsg_len;
-
-        if (sendmsg(netns_storage_socket[1], &mh, MSG_DONTWAIT|MSG_NOSIGNAL) < 0) {
-                r = -errno;
+        q = send_one_fd(netns_storage_socket[1], netns, MSG_DONTWAIT);
+        if (q < 0) {
+                r = q;
                 goto fail;
         }
 
 fail:
         lockf(netns_storage_socket[0], F_ULOCK, 0);
-
         return r;
 }
 
index c90ada5..a29e9d4 100644 (file)
@@ -600,11 +600,11 @@ static int manager_on_notify(sd_event_source *s, int fd, uint32_t revents, void
 
         cmsg_close_all(&msghdr);
 
-        CMSG_FOREACH(cmsg, &msghdr) {
-                if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS && cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred)))
-
+        CMSG_FOREACH(cmsg, &msghdr)
+                if (cmsg->cmsg_level == SOL_SOCKET &&
+                    cmsg->cmsg_type == SCM_CREDENTIALS &&
+                    cmsg->cmsg_len == CMSG_LEN(sizeof(struct ucred)))
                         ucred = (struct ucred*) CMSG_DATA(cmsg);
-        }
 
         if (msghdr.msg_flags & MSG_TRUNC) {
                 log_warning("Got overly long notification datagram, ignoring.");
index 1e3a463..dc1b210 100644 (file)
@@ -212,11 +212,6 @@ _public_ int sd_journal_sendv(const struct iovec *iov, int n) {
                 .msg_namelen = offsetof(struct sockaddr_un, sun_path) + strlen(sa.sun_path),
         };
         ssize_t k;
-        union {
-                struct cmsghdr cmsghdr;
-                uint8_t buf[CMSG_SPACE(sizeof(int))];
-        } control;
-        struct cmsghdr *cmsg;
         bool have_syslog_identifier = false;
         bool seal = true;
 
@@ -335,26 +330,7 @@ _public_ int sd_journal_sendv(const struct iovec *iov, int n) {
                         return r;
         }
 
-        mh.msg_iov = NULL;
-        mh.msg_iovlen = 0;
-
-        zero(control);
-        mh.msg_control = &control;
-        mh.msg_controllen = sizeof(control);
-
-        cmsg = CMSG_FIRSTHDR(&mh);
-        cmsg->cmsg_level = SOL_SOCKET;
-        cmsg->cmsg_type = SCM_RIGHTS;
-        cmsg->cmsg_len = CMSG_LEN(sizeof(int));
-        memcpy(CMSG_DATA(cmsg), &buffer_fd, sizeof(int));
-
-        mh.msg_controllen = cmsg->cmsg_len;
-
-        k = sendmsg(fd, &mh, MSG_NOSIGNAL);
-        if (k < 0)
-                return -errno;
-
-        return 0;
+        return send_one_fd(fd, buffer_fd, 0);
 }
 
 static int fill_iovec_perror_and_send(const char *message, int skip, struct iovec iov[]) {
index 5c607f4..435ec92 100644 (file)
@@ -217,15 +217,8 @@ int bus_container_connect_kernel(sd_bus *b) {
                                 _exit(EXIT_FAILURE);
                         }
 
-                        cmsg = CMSG_FIRSTHDR(&mh);
-                        cmsg->cmsg_level = SOL_SOCKET;
-                        cmsg->cmsg_type = SCM_RIGHTS;
-                        cmsg->cmsg_len = CMSG_LEN(sizeof(int));
-                        memcpy(CMSG_DATA(cmsg), &fd, sizeof(int));
-
-                        mh.msg_controllen = cmsg->cmsg_len;
-
-                        if (sendmsg(pair[1], &mh, MSG_NOSIGNAL) < 0)
+                        r = send_one_fd(pair[1], fd, 0);
+                        if (r < 0)
                                 _exit(EXIT_FAILURE);
 
                         _exit(EXIT_SUCCESS);
index 9e63d88..3658f45 100644 (file)
@@ -194,7 +194,7 @@ int expose_port_send_rtnl(int send_fd) {
 
         /* Store away the fd in the socket, so that it stays open as
          * long as we run the child */
-        r = send_one_fd(send_fd, fd);
+        r = send_one_fd(send_fd, fd, 0);
         if (r < 0)
                 return log_error_errno(r, "Failed to send netlink fd: %m");
 
@@ -214,7 +214,7 @@ int expose_port_watch_rtnl(
         assert(recv_fd >= 0);
         assert(ret);
 
-        fd = receive_one_fd(recv_fd);
+        fd = receive_one_fd(recv_fd, 0);
         if (fd < 0)
                 return log_error_errno(fd, "Failed to recv netlink fd: %m");
 
index 7451c2b..f4721a1 100644 (file)
@@ -1291,7 +1291,7 @@ static int setup_kmsg(const char *dest, int kmsg_socket) {
 
         /* Store away the fd in the socket, so that it stays open as
          * long as we run the child */
-        r = send_one_fd(kmsg_socket, fd);
+        r = send_one_fd(kmsg_socket, fd, 0);
         safe_close(fd);
 
         if (r < 0)