sd-netlink: save dynamic general netlink message type
authorYu Watanabe <watanabe.yu+github@gmail.com>
Tue, 23 Jul 2019 06:54:06 +0000 (15:54 +0900)
committerYu Watanabe <watanabe.yu+github@gmail.com>
Mon, 14 Oct 2019 16:57:19 +0000 (01:57 +0900)
src/libsystemd/meson.build
src/libsystemd/sd-netlink/generic-netlink.c
src/libsystemd/sd-netlink/generic-netlink.h [new file with mode: 0644]
src/libsystemd/sd-netlink/netlink-internal.h
src/libsystemd/sd-netlink/netlink-message.c
src/libsystemd/sd-netlink/netlink-socket.c
src/libsystemd/sd-netlink/netlink-types.c
src/libsystemd/sd-netlink/netlink-types.h
src/libsystemd/sd-netlink/sd-netlink.c
src/libsystemd/sd-netlink/test-netlink.c
src/systemd/sd-netlink.h

index 77fe6e7..aa1ed9b 100644 (file)
@@ -71,6 +71,7 @@ libsystemd_sources = files('''
         sd-hwdb/hwdb-util.h
         sd-hwdb/sd-hwdb.c
         sd-netlink/generic-netlink.c
+        sd-netlink/generic-netlink.h
         sd-netlink/netlink-internal.h
         sd-netlink/netlink-message.c
         sd-netlink/netlink-slot.c
index 32af79f..bfbfb0f 100644 (file)
@@ -3,8 +3,10 @@
 #include <linux/genetlink.h>
 
 #include "sd-netlink.h"
-#include "netlink-internal.h"
+
 #include "alloc-util.h"
+#include "generic-netlink.h"
+#include "netlink-internal.h"
 
 typedef struct {
         const char* name;
@@ -25,12 +27,12 @@ int sd_genl_socket_open(sd_netlink **ret) {
 static int lookup_id(sd_netlink *nl, sd_genl_family family, uint16_t *id);
 
 static int genl_message_new(sd_netlink *nl, sd_genl_family family, uint16_t nlmsg_type, uint8_t cmd, sd_netlink_message **ret) {
-        int r;
-        struct genlmsghdr *genl;
+        _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
         const NLType *genl_cmd_type, *nl_type;
         const NLTypeSystem *type_system;
+        struct genlmsghdr *genl;
         size_t size;
-        _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
+        int r;
 
         assert_return(nl->protocol == NETLINK_GENERIC, -EINVAL);
 
@@ -69,21 +71,33 @@ static int genl_message_new(sd_netlink *nl, sd_genl_family family, uint16_t nlms
 }
 
 int sd_genl_message_new(sd_netlink *nl, sd_genl_family family, uint8_t cmd, sd_netlink_message **ret) {
+        uint16_t id;
         int r;
-        uint16_t id = GENL_ID_CTRL;
 
-        if (family != SD_GENL_ID_CTRL) {
-                r = lookup_id(nl, family, &id);
-                if (r < 0)
-                        return r;
-        }
+        r = lookup_id(nl, family, &id);
+        if (r < 0)
+                return r;
 
         return genl_message_new(nl, family, id, cmd, ret);
 }
 
 static int lookup_id(sd_netlink *nl, sd_genl_family family, uint16_t *id) {
-        int r;
         _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *req = NULL, *reply = NULL;
+        uint16_t u;
+        void *v;
+        int r;
+
+        if (family == SD_GENL_ID_CTRL) {
+                *id = GENL_ID_CTRL;
+                return 0;
+        }
+
+        v = hashmap_get(nl->genl_family_to_nlmsg_type, INT_TO_PTR(family));
+        if (v) {
+                *id = PTR_TO_UINT(v);
+                return 0;
+        }
+
 
         r = sd_genl_message_new(nl, SD_GENL_ID_CTRL, CTRL_CMD_GETFAMILY, &req);
         if (r < 0)
@@ -97,5 +111,48 @@ static int lookup_id(sd_netlink *nl, sd_genl_family family, uint16_t *id) {
         if (r < 0)
                 return r;
 
-        return sd_netlink_message_read_u16(reply, CTRL_ATTR_FAMILY_ID, id);
+        r = sd_netlink_message_read_u16(reply, CTRL_ATTR_FAMILY_ID, &u);
+        if (r < 0)
+                return r;
+
+        r = hashmap_ensure_allocated(&nl->genl_family_to_nlmsg_type, NULL);
+        if (r < 0)
+                return r;
+
+        r = hashmap_ensure_allocated(&nl->nlmsg_type_to_genl_family, NULL);
+        if (r < 0)
+                return r;
+
+        r = hashmap_put(nl->genl_family_to_nlmsg_type, INT_TO_PTR(family), UINT_TO_PTR(u));
+        if (r < 0)
+                return r;
+
+        r = hashmap_put(nl->nlmsg_type_to_genl_family, UINT_TO_PTR(u), INT_TO_PTR(family));
+        if (r < 0)
+                return r;
+
+        *id = u;
+        return 0;
+}
+
+int nlmsg_type_to_genl_family(sd_netlink *nl, uint16_t type, sd_genl_family *ret) {
+        void *p;
+
+        assert_return(nl, -EINVAL);
+        assert_return(nl->protocol == NETLINK_GENERIC, -EINVAL);
+        assert(ret);
+
+        if (type == NLMSG_ERROR)
+                *ret = SD_GENL_ERROR;
+        else if (type == GENL_ID_CTRL)
+                *ret = SD_GENL_ID_CTRL;
+        else {
+                p = hashmap_get(nl->nlmsg_type_to_genl_family, UINT_TO_PTR(type));
+                if (!p)
+                        return -EOPNOTSUPP;
+
+                *ret = PTR_TO_INT(p);
+        }
+
+        return 0;
 }
diff --git a/src/libsystemd/sd-netlink/generic-netlink.h b/src/libsystemd/sd-netlink/generic-netlink.h
new file mode 100644 (file)
index 0000000..82afe4e
--- /dev/null
@@ -0,0 +1,6 @@
+/* SPDX-License-Identifier: LGPL-2.1+ */
+#pragma once
+
+#include "sd-netlink.h"
+
+int nlmsg_type_to_genl_family(sd_netlink *nl, uint16_t type, sd_genl_family *ret);
index 13e7ab6..93f495f 100644 (file)
@@ -98,6 +98,9 @@ struct sd_netlink {
         sd_event_source *time_event_source;
         sd_event_source *exit_event_source;
         sd_event *event;
+
+        Hashmap *genl_family_to_nlmsg_type;
+        Hashmap *nlmsg_type_to_genl_family;
 };
 
 struct netlink_attribute {
index a30911b..bfbfcb2 100644 (file)
@@ -49,15 +49,12 @@ int message_new_empty(sd_netlink *rtnl, sd_netlink_message **ret) {
 int message_new(sd_netlink *rtnl, sd_netlink_message **ret, uint16_t type) {
         _cleanup_(sd_netlink_message_unrefp) sd_netlink_message *m = NULL;
         const NLType *nl_type;
-        const NLTypeSystem *type_system_root;
         size_t size;
         int r;
 
         assert_return(rtnl, -EINVAL);
 
-        type_system_root = type_system_get_root(rtnl->protocol);
-
-        r = type_system_get_type(type_system_root, &nl_type, type);
+        r = type_system_root_get_type(rtnl, &nl_type, type);
         if (r < 0)
                 return r;
 
@@ -1025,22 +1022,20 @@ int sd_netlink_message_get_errno(sd_netlink_message *m) {
         return err->error;
 }
 
-int sd_netlink_message_rewind(sd_netlink_message *m) {
+int sd_netlink_message_rewind(sd_netlink_message *m, sd_netlink *genl) {
         const NLType *nl_type;
-        const NLTypeSystem *type_system_root;
         uint16_t type;
         size_t size;
         unsigned i;
         int r;
 
         assert_return(m, -EINVAL);
+        assert_return(genl || m->protocol != NETLINK_GENERIC, -EINVAL);
 
         /* don't allow appending to message once parsed */
         if (!m->sealed)
                 rtnl_message_seal(m);
 
-        type_system_root = type_system_get_root(m->protocol);
-
         for (i = 1; i <= m->n_containers; i++)
                 m->containers[i].attributes = mfree(m->containers[i].attributes);
 
@@ -1052,7 +1047,7 @@ int sd_netlink_message_rewind(sd_netlink_message *m) {
 
         assert(m->hdr);
 
-        r = type_system_get_type(type_system_root, &nl_type, m->hdr->nlmsg_type);
+        r = type_system_root_get_type(genl, &nl_type, m->hdr->nlmsg_type);
         if (r < 0)
                 return r;
 
index 98edb7e..7331aa1 100644 (file)
@@ -313,14 +313,11 @@ int socket_read_message(sd_netlink *rtnl) {
         size_t len;
         int r;
         unsigned i = 0;
-        const NLTypeSystem *type_system_root;
 
         assert(rtnl);
         assert(rtnl->rbuffer);
         assert(rtnl->rbuffer_allocated >= sizeof(struct nlmsghdr));
 
-        type_system_root = type_system_get_root(rtnl->protocol);
-
         /* read nothing, just get the pending message size */
         r = socket_recv_message(rtnl->fd, &iov, NULL, true);
         if (r <= 0)
@@ -381,7 +378,7 @@ int socket_read_message(sd_netlink *rtnl) {
                 }
 
                 /* check that we support this message type */
-                r = type_system_get_type(type_system_root, &nl_type, new_msg->nlmsg_type);
+                r = type_system_root_get_type(rtnl, &nl_type, new_msg->nlmsg_type);
                 if (r < 0) {
                         if (r == -EOPNOTSUPP)
                                 log_debug("sd-netlink: ignored message with unknown type: %i",
@@ -407,7 +404,7 @@ int socket_read_message(sd_netlink *rtnl) {
                         return -ENOMEM;
 
                 /* seal and parse the top-level message */
-                r = sd_netlink_message_rewind(m);
+                r = sd_netlink_message_rewind(m, rtnl);
                 if (r < 0)
                         return r;
 
index 6db0ffd..ae4b6c7 100644 (file)
 #include <linux/veth.h>
 #include <linux/wireguard.h>
 
+#include "sd-netlink.h"
+
+#include "generic-netlink.h"
+#include "hashmap.h"
 #include "macro.h"
 #include "missing.h"
+#include "netlink-internal.h"
 #include "netlink-types.h"
-#include "sd-netlink.h"
 #include "string-table.h"
 #include "util.h"
 
@@ -992,16 +996,18 @@ static const NLType genl_families[] = {
         [SD_GENL_MACSEC]    = { .type = NETLINK_TYPE_NESTED, .type_system = &genl_macsec_device_type_system },
 };
 
+/* Mainly used when sending message */
 const NLTypeSystem genl_family_type_system_root = {
         .count = ELEMENTSOF(genl_families),
         .types = genl_families,
 };
 
 static const NLType genl_types[] = {
-        [NLMSG_ERROR]  = { .type = NETLINK_TYPE_NESTED, .type_system = &empty_type_system, .size = sizeof(struct nlmsgerr) },
-        [GENL_ID_CTRL] = { .type = NETLINK_TYPE_NESTED, .type_system = &genl_get_family_type_system, .size = sizeof(struct genlmsghdr) },
+        [SD_GENL_ERROR]   = { .type = NETLINK_TYPE_NESTED, .type_system = &empty_type_system, .size = sizeof(struct nlmsgerr) },
+        [SD_GENL_ID_CTRL] = { .type = NETLINK_TYPE_NESTED, .type_system = &genl_get_family_type_system, .size = sizeof(struct genlmsghdr) },
 };
 
+/* Mainly used when message received */
 const NLTypeSystem genl_type_system_root = {
         .count = ELEMENTSOF(genl_types),
         .types = genl_types,
@@ -1049,6 +1055,31 @@ const NLTypeSystem *type_system_get_root(int protocol) {
         }
 }
 
+int type_system_root_get_type(sd_netlink *nl, const NLType **ret, uint16_t type) {
+        sd_genl_family family;
+        const NLType *nl_type;
+        int r;
+
+        if (!nl || nl->protocol != NETLINK_GENERIC)
+                return type_system_get_type(&rtnl_type_system_root, ret, type);
+
+        r = nlmsg_type_to_genl_family(nl, type, &family);
+        if (r < 0)
+                return r;
+
+        if (family >= genl_type_system_root.count)
+                return -EOPNOTSUPP;
+
+        nl_type = &genl_type_system_root.types[family];
+
+        if (nl_type->type == NETLINK_TYPE_UNSPEC)
+                return -EOPNOTSUPP;
+
+        *ret = nl_type;
+
+        return 0;
+}
+
 int type_system_get_type(const NLTypeSystem *type_system, const NLType **ret, uint16_t type) {
         const NLType *nl_type;
 
index 45a2a38..9bc6f68 100644 (file)
@@ -45,6 +45,7 @@ void type_get_type_system_union(const NLType *type, const NLTypeSystemUnion **re
 
 const NLTypeSystem* type_system_get_root(int protocol);
 uint16_t type_system_get_count(const NLTypeSystem *type_system);
+int type_system_root_get_type(sd_netlink *nl, const NLType **ret, uint16_t type);
 int type_system_get_type(const NLTypeSystem *type_system, const NLType **ret, uint16_t type);
 int type_system_get_type_system(const NLTypeSystem *type_system, const NLTypeSystem **ret, uint16_t type);
 int type_system_get_type_system_union(const NLTypeSystem *type_system, const NLTypeSystemUnion **ret, uint16_t type);
index f3366d1..ce2ad36 100644 (file)
@@ -178,6 +178,9 @@ static sd_netlink *netlink_free(sd_netlink *rtnl) {
 
         hashmap_free(rtnl->broadcast_group_refs);
 
+        hashmap_free(rtnl->genl_family_to_nlmsg_type);
+        hashmap_free(rtnl->nlmsg_type_to_genl_family);
+
         safe_close(rtnl->fd);
         return mfree(rtnl);
 }
index 868fcd0..379ad30 100644 (file)
@@ -26,7 +26,7 @@ static void test_message_link_bridge(sd_netlink *rtnl) {
         assert_se(sd_netlink_message_append_u32(message, IFLA_BRPORT_COST, 10) >= 0);
         assert_se(sd_netlink_message_close_container(message) >= 0);
 
-        assert_se(sd_netlink_message_rewind(message) >= 0);
+        assert_se(sd_netlink_message_rewind(message, NULL) >= 0);
 
         assert_se(sd_netlink_message_enter_container(message, IFLA_PROTINFO) >= 0);
         assert_se(sd_netlink_message_read_u32(message, IFLA_BRPORT_COST, &cost) >= 0);
@@ -49,7 +49,7 @@ static void test_link_configure(sd_netlink *rtnl, int ifindex) {
         assert_se(sd_netlink_message_append_u32(message, IFLA_MTU, mtu) >= 0);
 
         assert_se(sd_netlink_call(rtnl, message, 0, NULL) == 1);
-        assert_se(sd_netlink_message_rewind(message) >= 0);
+        assert_se(sd_netlink_message_rewind(message, NULL) >= 0);
 
         assert_se(sd_netlink_message_read_string(message, IFLA_IFNAME, &name_out) >= 0);
         assert_se(streq(name, name_out));
@@ -153,7 +153,7 @@ static void test_route(sd_netlink *rtnl) {
                 return;
         }
 
-        assert_se(sd_netlink_message_rewind(req) >= 0);
+        assert_se(sd_netlink_message_rewind(req, NULL) >= 0);
 
         assert_se(sd_netlink_message_read_in_addr(req, RTA_GATEWAY, &addr_data) >= 0);
         assert_se(addr_data.s_addr == addr.s_addr);
@@ -439,7 +439,7 @@ static void test_container(sd_netlink *rtnl) {
         assert_se(sd_netlink_message_close_container(m) >= 0);
         assert_se(sd_netlink_message_close_container(m) == -EINVAL);
 
-        assert_se(sd_netlink_message_rewind(m) >= 0);
+        assert_se(sd_netlink_message_rewind(m, NULL) >= 0);
 
         assert_se(sd_netlink_message_enter_container(m, IFLA_LINKINFO) >= 0);
         assert_se(sd_netlink_message_read_string(m, IFLA_INFO_KIND, &string_data) >= 0);
@@ -530,7 +530,7 @@ static void test_array(void) {
         assert_se(sd_netlink_message_close_container(m) >= 0);
 
         rtnl_message_seal(m);
-        assert_se(sd_netlink_message_rewind(m) >= 0);
+        assert_se(sd_netlink_message_rewind(m, genl) >= 0);
 
         assert_se(sd_netlink_message_enter_container(m, CTRL_ATTR_MCAST_GROUPS) >= 0);
         for (unsigned i = 0; i < 10; i++) {
index 62c5ef0..04bb2e5 100644 (file)
@@ -35,6 +35,7 @@ typedef struct sd_netlink_message sd_netlink_message;
 typedef struct sd_netlink_slot sd_netlink_slot;
 
 typedef enum sd_gen_family {
+        SD_GENL_ERROR,
         SD_GENL_ID_CTRL,
         SD_GENL_WIREGUARD,
         SD_GENL_FOU,
@@ -111,7 +112,7 @@ int sd_netlink_message_exit_container(sd_netlink_message *m);
 int sd_netlink_message_open_array(sd_netlink_message *m, uint16_t type);
 int sd_netlink_message_cancel_array(sd_netlink_message *m);
 
-int sd_netlink_message_rewind(sd_netlink_message *m);
+int sd_netlink_message_rewind(sd_netlink_message *m, sd_netlink *genl);
 
 sd_netlink_message *sd_netlink_message_next(sd_netlink_message *m);