sd-netlink: refcount multicast groups
author Tom Gundersen <teg@jklm.no>
Thu, 15 Oct 2015 15:59:10 +0000 (17:59 +0200)
committer Tom Gundersen <teg@jklm.no>
Thu, 15 Oct 2015 16:59:08 +0000 (18:59 +0200)
Track the number of matches installed for a given multicast group, and leave the
group once no matches depend on it.

In order to handle passed-in sockets that are already members of multicast groups
we initialize the refcount based on the membership once we take over the socket.
This way we will leave the socket in the state we found it once we finish with
it.

On kernels that do not fully support reading out the multicast group membership
we fall back to never leaving any groups (as before).

src/basic/missing.h
src/libsystemd/sd-netlink/netlink-internal.h
src/libsystemd/sd-netlink/netlink-socket.c
src/libsystemd/sd-netlink/sd-netlink.c

index 59e835a..5fb9951 100644 (file)
 #define SOL_NETLINK 270
 #endif
 
+/* Fallback definition for build environments whose headers do not provide
+ * NETLINK_LIST_MEMBERSHIPS. It is passed to getsockopt() at level SOL_NETLINK
+ * to read out which multicast groups a netlink socket is a member of. */
+#ifndef NETLINK_LIST_MEMBERSHIPS
+#define NETLINK_LIST_MEMBERSHIPS 9
+#endif
+
 #if !HAVE_DECL_PIVOT_ROOT
 static inline int pivot_root(const char *new_root, const char *put_old) {
         return syscall(SYS_pivot_root, new_root, put_old);
index 4026e2c..b9cb806 100644 (file)
@@ -64,6 +64,9 @@ struct sd_netlink {
                 struct sockaddr_nl nl;
         } sockaddr;
 
+        Hashmap *broadcast_group_refs;
+        bool broadcast_group_dont_leave:1; /* until we can rely on 4.2 */
+
         sd_netlink_message **rqueue;
         unsigned rqueue_size;
         size_t rqueue_allocated;
@@ -124,7 +127,8 @@ int message_new_empty(sd_netlink *rtnl, sd_netlink_message **ret);
 
 int socket_open(int family);
 int socket_bind(sd_netlink *nl);
-int socket_join_broadcast_group(sd_netlink *nl, unsigned group);
+int socket_broadcast_group_ref(sd_netlink *nl, unsigned group);
+int socket_broadcast_group_unref(sd_netlink *nl, unsigned group);
 int socket_write_message(sd_netlink *nl, sd_netlink_message *m);
 int socket_read_message(sd_netlink *nl);
 
index 84ff7c3..e1b14c3 100644 (file)
@@ -44,6 +44,65 @@ int socket_open(int family) {
         return fd;
 }
 
+/* Seeds nl->broadcast_group_refs with one reference for every multicast group
+ * the socket is already a member of, so that groups joined before we took over
+ * the socket are never left by us.
+ *
+ * If the kernel does not know NETLINK_LIST_MEMBERSHIPS (ENOPROTOOPT), the
+ * memberships cannot be read; in that case broadcast_group_dont_leave is set
+ * and 0 is returned.
+ *
+ * Returns 0 on success, a negative errno-style error code otherwise. */
+static int broadcast_groups_get(sd_netlink *nl) {
+        _cleanup_free_ uint32_t *groups = NULL;
+        socklen_t len = 0, old_len;
+        unsigned i, j, n_words;
+        int r;
+
+        assert(nl);
+        assert(nl->fd >= 0);
+
+        /* query with a NULL buffer first, to learn the required length */
+        r = getsockopt(nl->fd, SOL_NETLINK, NETLINK_LIST_MEMBERSHIPS, NULL, &len);
+        if (r < 0) {
+                if (errno == ENOPROTOOPT) {
+                        nl->broadcast_group_dont_leave = true;
+                        return 0;
+                } else
+                        return -errno;
+        }
+
+        if (len == 0)
+                return 0;
+
+        /* len is a size in bytes, the bitmap is made up of 32-bit words */
+        n_words = (len + sizeof(uint32_t) - 1) / sizeof(uint32_t);
+
+        groups = new0(uint32_t, n_words);
+        if (!groups)
+                return -ENOMEM;
+
+        old_len = len;
+
+        r = getsockopt(nl->fd, SOL_NETLINK, NETLINK_LIST_MEMBERSHIPS, groups, &len);
+        if (r < 0)
+                return -errno;
+
+        /* the membership bitmap changed size between the two calls */
+        if (old_len != len)
+                return -EIO;
+
+        r = hashmap_ensure_allocated(&nl->broadcast_group_refs, NULL);
+        if (r < 0)
+                return r;
+
+        for (i = 0; i < n_words; i++) {
+                for (j = 0; j < sizeof(uint32_t) * 8; j++) {
+                        unsigned group;
+
+                        if (!(groups[i] & (1U << j)))
+                                continue;
+
+                        /* group numbers are 1-based, bit 0 of word 0 is group 1 */
+                        group = i * sizeof(uint32_t) * 8 + j + 1;
+
+                        r = hashmap_put(nl->broadcast_group_refs, UINT_TO_PTR(group), UINT_TO_PTR(1));
+                        if (r < 0)
+                                return r;
+                }
+        }
+
+        return 0;
+}
+
 int socket_bind(sd_netlink *nl) {
         socklen_t addrlen;
         int r, one = 1;
@@ -63,11 +122,32 @@ int socket_bind(sd_netlink *nl) {
         if (r < 0)
                 return -errno;
 
+        r = broadcast_groups_get(nl);
+        if (r < 0)
+                return r;
+
         return 0;
 }
 
+/* Looks up the reference count stored for the given multicast group in
+ * nl->broadcast_group_refs and returns it as an unsigned integer. */
+static unsigned broadcast_group_get_ref(sd_netlink *nl, unsigned group) {
+        void *stored;
+
+        assert(nl);
+
+        stored = hashmap_get(nl->broadcast_group_refs, UINT_TO_PTR(group));
+
+        return PTR_TO_UINT(stored);
+}
 
-int socket_join_broadcast_group(sd_netlink *nl, unsigned group) {
+/* Stores n_ref as the reference count of the given multicast group in
+ * nl->broadcast_group_refs.
+ *
+ * Returns 0 on success, or the negative error code reported by
+ * hashmap_replace(). */
+static int broadcast_group_set_ref(sd_netlink *nl, unsigned group, unsigned n_ref) {
+        int rc;
+
+        assert(nl);
+
+        rc = hashmap_replace(nl->broadcast_group_refs, UINT_TO_PTR(group), UINT_TO_PTR(n_ref));
+
+        /* any non-negative result is reported as plain success */
+        return rc < 0 ? rc : 0;
+}
+
+static int broadcast_group_join(sd_netlink *nl, unsigned group) {
         int r;
 
         assert(nl);
@@ -81,6 +161,79 @@ int socket_join_broadcast_group(sd_netlink *nl, unsigned group) {
         return 0;
 }
 
+/* Takes one reference on the given multicast group. The socket joins the
+ * group when the first reference is taken; further references only bump the
+ * counter.
+ *
+ * Returns 0 on success, a negative errno-style error code otherwise. If
+ * joining the group fails the reference is dropped again, so that a later
+ * call retries the join. */
+int socket_broadcast_group_ref(sd_netlink *nl, unsigned group) {
+        unsigned n_ref;
+        int r;
+
+        assert(nl);
+
+        n_ref = broadcast_group_get_ref(nl, group);
+
+        n_ref++;
+
+        r = hashmap_ensure_allocated(&nl->broadcast_group_refs, NULL);
+        if (r < 0)
+                return r;
+
+        r = broadcast_group_set_ref(nl, group, n_ref);
+        if (r < 0)
+                return r;
+
+        if (n_ref > 1)
+                /* already in the group */
+                return 0;
+
+        r = broadcast_group_join(nl, group);
+        if (r < 0) {
+                /* undo the reference taken above; otherwise the counter would
+                 * claim membership in a group we never joined */
+                (void) broadcast_group_set_ref(nl, group, n_ref - 1);
+                return r;
+        }
+
+        return 0;
+}
+
+/* Drops the socket's membership in the given multicast group via
+ * NETLINK_DROP_MEMBERSHIP, unless broadcast_group_dont_leave is set, in which
+ * case nothing is done.
+ *
+ * Returns 0 on success (or when leaving is disabled), -errno if setsockopt()
+ * fails. */
+static int broadcast_group_leave(sd_netlink *nl, unsigned group) {
+        assert(nl);
+        assert(nl->fd >= 0);
+        assert(group > 0);
+
+        if (!nl->broadcast_group_dont_leave &&
+            setsockopt(nl->fd, SOL_NETLINK, NETLINK_DROP_MEMBERSHIP, &group, sizeof(group)) < 0)
+                return -errno;
+
+        return 0;
+}
+
+/* Releases one reference on the given multicast group and leaves the group
+ * once the last reference is gone. The group must currently hold at least one
+ * reference.
+ *
+ * Returns 0 on success, a negative errno-style error code otherwise. */
+int socket_broadcast_group_unref(sd_netlink *nl, unsigned group) {
+        unsigned remaining;
+        int r;
+
+        assert(nl);
+
+        remaining = broadcast_group_get_ref(nl, group);
+        assert(remaining > 0);
+        remaining--;
+
+        r = broadcast_group_set_ref(nl, group, remaining);
+        if (r < 0)
+                return r;
+
+        if (remaining == 0) {
+                /* that was the last reference, the membership is no longer needed */
+                r = broadcast_group_leave(nl, group);
+                if (r < 0)
+                        return r;
+        }
+
+        return 0;
+}
+
 /* returns the number of bytes sent, or a negative error code */
 int socket_write_message(sd_netlink *nl, sd_netlink_message *m) {
         union {
index f4a0a35..5af2860 100644 (file)
@@ -183,10 +183,11 @@ sd_netlink *sd_netlink_unref(sd_netlink *rtnl) {
                 sd_event_unref(rtnl->event);
 
                 while ((f = rtnl->match_callbacks)) {
-                        LIST_REMOVE(match_callbacks, rtnl->match_callbacks, f);
-                        free(f);
+                        sd_netlink_remove_match(rtnl, f->type, f->callback, f->userdata);
                 }
 
+                hashmap_free(rtnl->broadcast_group_refs);
+
                 safe_close(rtnl->fd);
                 free(rtnl);
         }
@@ -857,29 +858,29 @@ int sd_netlink_add_match(sd_netlink *rtnl,
         switch (type) {
                 case RTM_NEWLINK:
                 case RTM_DELLINK:
-                        r = socket_join_broadcast_group(rtnl, RTNLGRP_LINK);
+                        r = socket_broadcast_group_ref(rtnl, RTNLGRP_LINK);
                         if (r < 0)
                                 return r;
 
                         break;
                 case RTM_NEWADDR:
                 case RTM_DELADDR:
-                        r = socket_join_broadcast_group(rtnl, RTNLGRP_IPV4_IFADDR);
+                        r = socket_broadcast_group_ref(rtnl, RTNLGRP_IPV4_IFADDR);
                         if (r < 0)
                                 return r;
 
-                        r = socket_join_broadcast_group(rtnl, RTNLGRP_IPV6_IFADDR);
+                        r = socket_broadcast_group_ref(rtnl, RTNLGRP_IPV6_IFADDR);
                         if (r < 0)
                                 return r;
 
                         break;
                 case RTM_NEWROUTE:
                 case RTM_DELROUTE:
-                        r = socket_join_broadcast_group(rtnl, RTNLGRP_IPV4_ROUTE);
+                        r = socket_broadcast_group_ref(rtnl, RTNLGRP_IPV4_ROUTE);
                         if (r < 0)
                                 return r;
 
-                        r = socket_join_broadcast_group(rtnl, RTNLGRP_IPV6_ROUTE);
+                        r = socket_broadcast_group_ref(rtnl, RTNLGRP_IPV6_ROUTE);
                         if (r < 0)
                                 return r;
                         break;
@@ -899,23 +900,50 @@ int sd_netlink_remove_match(sd_netlink *rtnl,
                          sd_netlink_message_handler_t callback,
                          void *userdata) {
         struct match_callback *c;
+        int r;
 
         assert_return(rtnl, -EINVAL);
         assert_return(callback, -EINVAL);
         assert_return(!rtnl_pid_changed(rtnl), -ECHILD);
 
-        /* we should unsubscribe from the broadcast groups at this point, but it is not so
-           trivial for a few reasons: the refcounting is a bit of a mess and not obvious
-           how it will look like after we add genetlink support, and it is also not possible
-           to query what broadcast groups were subscribed to when we inherit the socket to get
-           the initial refcount. The latter could indeed be done for the first 32 broadcast
-           groups (which incidentally is all we currently support in .socket units anyway),
-           but we better not rely on only ever using 32 groups. */
         LIST_FOREACH(match_callbacks, c, rtnl->match_callbacks)
                 if (c->callback == callback && c->type == type && c->userdata == userdata) {
                         LIST_REMOVE(match_callbacks, rtnl->match_callbacks, c);
                         free(c);
 
+                        switch (type) {
+                                case RTM_NEWLINK:
+                                case RTM_DELLINK:
+                                        r = socket_broadcast_group_unref(rtnl, RTNLGRP_LINK);
+                                        if (r < 0)
+                                                return r;
+
+                                        break;
+                                case RTM_NEWADDR:
+                                case RTM_DELADDR:
+                                        r = socket_broadcast_group_unref(rtnl, RTNLGRP_IPV4_IFADDR);
+                                        if (r < 0)
+                                                return r;
+
+                                        r = socket_broadcast_group_unref(rtnl, RTNLGRP_IPV6_IFADDR);
+                                        if (r < 0)
+                                                return r;
+
+                                        break;
+                                case RTM_NEWROUTE:
+                                case RTM_DELROUTE:
+                                        r = socket_broadcast_group_unref(rtnl, RTNLGRP_IPV4_ROUTE);
+                                        if (r < 0)
+                                                return r;
+
+                                        r = socket_broadcast_group_unref(rtnl, RTNLGRP_IPV6_ROUTE);
+                                        if (r < 0)
+                                                return r;
+                                        break;
+                                default:
+                                        return -EOPNOTSUPP;
+                        }
+
                         return 1;
                 }