rtnl: message - read group membership of incoming messages
authorTom Gundersen <teg@jklm.no>
Sat, 10 May 2014 18:15:52 +0000 (20:15 +0200)
committerTom Gundersen <teg@jklm.no>
Sat, 10 May 2014 18:56:37 +0000 (20:56 +0200)
src/libsystemd/sd-rtnl/rtnl-message.c
src/libsystemd/sd-rtnl/sd-rtnl.c
src/shared/missing.h

index 9558e11..84b46af 100644 (file)
@@ -1098,8 +1098,9 @@ int socket_write_message(sd_rtnl *nl, sd_rtnl_message *m) {
         return k;
 }
 
-static int socket_recv_message(int fd, struct iovec *iov, bool peek) {
-        uint8_t cred_buffer[CMSG_SPACE(sizeof(struct ucred))];
+static int socket_recv_message(int fd, struct iovec *iov, uint32_t *_group, bool peek) {
+        uint8_t cred_buffer[CMSG_SPACE(sizeof(struct ucred)) +
+                            CMSG_SPACE(sizeof(struct nl_pktinfo))];
         struct msghdr msg = {
                 .msg_iov = iov,
                 .msg_iovlen = 1,
@@ -1107,6 +1108,7 @@ static int socket_recv_message(int fd, struct iovec *iov, bool peek) {
                 .msg_controllen = sizeof(cred_buffer),
         };
         struct cmsghdr *cmsg;
+        uint32_t group = 0;
         bool auth = false;
         int r;
 
@@ -1128,10 +1130,15 @@ static int socket_recv_message(int fd, struct iovec *iov, bool peek) {
                         struct ucred *ucred = (void *)CMSG_DATA(cmsg);
 
                         /* from the kernel */
-                        if (ucred->uid == 0 && ucred->pid == 0) {
+                        if (ucred->uid == 0 && ucred->pid == 0)
                                 auth = true;
-                                break;
-                        }
+                } else if (cmsg->cmsg_level == SOL_NETLINK &&
+                           cmsg->cmsg_type == NETLINK_PKTINFO &&
+                           cmsg->cmsg_len == CMSG_LEN(sizeof(struct nl_pktinfo))) {
+                        struct nl_pktinfo *pktinfo = (void *)CMSG_DATA(cmsg);
+
+                        /* multi-cast group */
+                        group = pktinfo->group;
                 }
         }
 
@@ -1139,6 +1146,9 @@ static int socket_recv_message(int fd, struct iovec *iov, bool peek) {
                 /* not from the kernel, ignore */
                 return 0;
 
+        if (group)
+                *_group = group;
+
         return r;
 }
 
@@ -1150,6 +1160,7 @@ static int socket_recv_message(int fd, struct iovec *iov, bool peek) {
 int socket_read_message(sd_rtnl *rtnl) {
         _cleanup_rtnl_message_unref_ sd_rtnl_message *first = NULL;
         struct iovec iov = {};
+        uint32_t group = 0;
         bool multi_part = false, done = false;
         struct nlmsghdr *new_msg;
         size_t len;
@@ -1161,7 +1172,7 @@ int socket_read_message(sd_rtnl *rtnl) {
         assert(rtnl->rbuffer_allocated >= sizeof(struct nlmsghdr));
 
         /* read nothing, just get the pending message size */
-        r = socket_recv_message(rtnl->fd, &iov, true);
+        r = socket_recv_message(rtnl->fd, &iov, &group, true);
         if (r <= 0)
                 return r;
         else
@@ -1177,7 +1188,7 @@ int socket_read_message(sd_rtnl *rtnl) {
         iov.iov_len = rtnl->rbuffer_allocated;
 
         /* read the pending message */
-        r = socket_recv_message(rtnl->fd, &iov, false);
+        r = socket_recv_message(rtnl->fd, &iov, &group, false);
         if (r <= 0)
                 return r;
         else
index 4ee360c..b91d080 100644 (file)
@@ -22,6 +22,7 @@
 #include <sys/socket.h>
 #include <poll.h>
 
+#include "missing.h"
 #include "macro.h"
 #include "util.h"
 #include "hashmap.h"
@@ -109,7 +110,12 @@ int sd_rtnl_open(sd_rtnl **ret, unsigned n_groups, ...) {
         if (rtnl->fd < 0)
                 return -errno;
 
-        if (setsockopt(rtnl->fd, SOL_SOCKET, SO_PASSCRED, &one, sizeof(one)) < 0)
+        r = setsockopt(rtnl->fd, SOL_SOCKET, SO_PASSCRED, &one, sizeof(one));
+        if (r < 0)
+                return -errno;
+
+        r = setsockopt(rtnl->fd, SOL_NETLINK, NETLINK_PKTINFO, &one, sizeof(one));
+        if (r < 0)
                 return -errno;
 
         va_start(ap, n_groups);
index d5ec2f8..716d3b8 100644 (file)
   #endif
 #endif
 
+#ifndef SOL_NETLINK
+#define SOL_NETLINK 270
+#endif
+
 #if !HAVE_DECL_PIVOT_ROOT
 static inline int pivot_root(const char *new_root, const char *put_old) {
         return syscall(SYS_pivot_root, new_root, put_old);