net: devlink: move net check into devlinks_xa_for_each_registered_get()
authorJiri Pirko <jiri@nvidia.com>
Mon, 25 Jul 2022 08:29:15 +0000 (10:29 +0200)
committerJakub Kicinski <kuba@kernel.org>
Tue, 26 Jul 2022 20:50:50 +0000 (13:50 -0700)
Benefit from having devlinks iterator helper
devlinks_xa_for_each_registered_get() and move the net pointer
check inside.

Suggested-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: Jiri Pirko <jiri@nvidia.com>
Reviewed-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
net/core/devlink.c

index c7abd92..865232a 100644 (file)
@@ -289,7 +289,7 @@ void devl_unlock(struct devlink *devlink)
 EXPORT_SYMBOL_GPL(devl_unlock);
 
 static struct devlink *
-devlinks_xa_find_get(unsigned long *indexp, xa_mark_t filter,
+devlinks_xa_find_get(struct net *net, unsigned long *indexp, xa_mark_t filter,
                     void * (*xa_find_fn)(struct xarray *, unsigned long *,
                                          unsigned long, xa_mark_t))
 {
@@ -304,33 +304,40 @@ retry:
        xa_find_fn = xa_find_after;
        if (!devlink_try_get(devlink))
                goto retry;
+       if (!net_eq(devlink_net(devlink), net)) {
+               devlink_put(devlink);
+               goto retry;
+       }
 unlock:
        rcu_read_unlock();
        return devlink;
 }
 
-static struct devlink *devlinks_xa_find_get_first(unsigned long *indexp,
+static struct devlink *devlinks_xa_find_get_first(struct net *net,
+                                                 unsigned long *indexp,
                                                  xa_mark_t filter)
 {
-       return devlinks_xa_find_get(indexp, filter, xa_find);
+       return devlinks_xa_find_get(net, indexp, filter, xa_find);
 }
 
-static struct devlink *devlinks_xa_find_get_next(unsigned long *indexp,
+static struct devlink *devlinks_xa_find_get_next(struct net *net,
+                                                unsigned long *indexp,
                                                 xa_mark_t filter)
 {
-       return devlinks_xa_find_get(indexp, filter, xa_find_after);
+       return devlinks_xa_find_get(net, indexp, filter, xa_find_after);
 }
 
 /* Iterate over devlink pointers which were possible to get reference to.
  * devlink_put() needs to be called for each iterated devlink pointer
  * in loop body in order to release the reference.
  */
-#define devlinks_xa_for_each_get(index, devlink, filter)                       \
-       for (index = 0, devlink = devlinks_xa_find_get_first(&index, filter);   \
-            devlink; devlink = devlinks_xa_find_get_next(&index, filter))
+#define devlinks_xa_for_each_get(net, index, devlink, filter)                  \
+       for (index = 0,                                                         \
+            devlink = devlinks_xa_find_get_first(net, &index, filter);         \
+            devlink; devlink = devlinks_xa_find_get_next(net, &index, filter))
 
-#define devlinks_xa_for_each_registered_get(index, devlink)                    \
-       devlinks_xa_for_each_get(index, devlink, DEVLINK_REGISTERED)
+#define devlinks_xa_for_each_registered_get(net, index, devlink)               \
+       devlinks_xa_for_each_get(net, index, devlink, DEVLINK_REGISTERED)
 
 static struct devlink *devlink_get_from_attrs(struct net *net,
                                              struct nlattr **attrs)
@@ -346,10 +353,9 @@ static struct devlink *devlink_get_from_attrs(struct net *net,
        busname = nla_data(attrs[DEVLINK_ATTR_BUS_NAME]);
        devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]);
 
-       devlinks_xa_for_each_registered_get(index, devlink) {
+       devlinks_xa_for_each_registered_get(net, index, devlink) {
                if (strcmp(devlink->dev->bus->name, busname) == 0 &&
-                   strcmp(dev_name(devlink->dev), devname) == 0 &&
-                   net_eq(devlink_net(devlink), net))
+                   strcmp(dev_name(devlink->dev), devname) == 0)
                        return devlink;
                devlink_put(devlink);
        }
@@ -1376,10 +1382,7 @@ static int devlink_nl_cmd_rate_get_dumpit(struct sk_buff *msg,
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(devlink_rate, &devlink->rate_list, list) {
                        enum devlink_command cmd = DEVLINK_CMD_RATE_NEW;
@@ -1400,7 +1403,6 @@ static int devlink_nl_cmd_rate_get_dumpit(struct sk_buff *msg,
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
@@ -1476,12 +1478,7 @@ static int devlink_nl_cmd_get_dumpit(struct sk_buff *msg,
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) {
-                       devlink_put(devlink);
-                       continue;
-               }
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                if (idx < start) {
                        idx++;
                        devlink_put(devlink);
@@ -1536,10 +1533,7 @@ static int devlink_nl_cmd_port_get_dumpit(struct sk_buff *msg,
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(devlink_port, &devlink->port_list, list) {
                        if (idx < start) {
@@ -1559,7 +1553,6 @@ static int devlink_nl_cmd_port_get_dumpit(struct sk_buff *msg,
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
@@ -2215,10 +2208,7 @@ static int devlink_nl_cmd_linecard_get_dumpit(struct sk_buff *msg,
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                mutex_lock(&devlink->linecards_lock);
                list_for_each_entry(linecard, &devlink->linecard_list, list) {
                        if (idx < start) {
@@ -2241,7 +2231,6 @@ static int devlink_nl_cmd_linecard_get_dumpit(struct sk_buff *msg,
                        idx++;
                }
                mutex_unlock(&devlink->linecards_lock);
-retry:
                devlink_put(devlink);
        }
 out:
@@ -2484,10 +2473,7 @@ static int devlink_nl_cmd_sb_get_dumpit(struct sk_buff *msg,
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
                        if (idx < start) {
@@ -2507,7 +2493,6 @@ static int devlink_nl_cmd_sb_get_dumpit(struct sk_buff *msg,
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
@@ -2633,7 +2618,7 @@ static int devlink_nl_cmd_sb_pool_get_dumpit(struct sk_buff *msg,
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
                    !devlink->ops->sb_pool_get)
                        goto retry;
@@ -2851,9 +2836,8 @@ static int devlink_nl_cmd_sb_port_pool_get_dumpit(struct sk_buff *msg,
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
-                   !devlink->ops->sb_port_pool_get)
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
+               if (!devlink->ops->sb_port_pool_get)
                        goto retry;
 
                devl_lock(devlink);
@@ -3097,9 +3081,8 @@ devlink_nl_cmd_sb_tc_pool_bind_get_dumpit(struct sk_buff *msg,
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
-                   !devlink->ops->sb_tc_pool_bind_get)
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
+               if (!devlink->ops->sb_tc_pool_bind_get)
                        goto retry;
 
                devl_lock(devlink);
@@ -5181,10 +5164,7 @@ static int devlink_nl_cmd_param_get_dumpit(struct sk_buff *msg,
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(param_item, &devlink->param_list, list) {
                        if (idx < start) {
@@ -5206,7 +5186,6 @@ static int devlink_nl_cmd_param_get_dumpit(struct sk_buff *msg,
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
@@ -5413,10 +5392,7 @@ static int devlink_nl_cmd_port_param_get_dumpit(struct sk_buff *msg,
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(devlink_port, &devlink->port_list, list) {
                        list_for_each_entry(param_item,
@@ -5443,7 +5419,6 @@ static int devlink_nl_cmd_port_param_get_dumpit(struct sk_buff *msg,
                        }
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
@@ -5994,13 +5969,9 @@ static int devlink_nl_cmd_region_get_dumpit(struct sk_buff *msg,
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                err = devlink_nl_cmd_region_get_devlink_dumpit(msg, cb, devlink,
                                                               &idx, start);
-retry:
                devlink_put(devlink);
                if (err)
                        goto out;
@@ -6525,10 +6496,7 @@ static int devlink_nl_cmd_info_get_dumpit(struct sk_buff *msg,
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                if (idx < start || !devlink->ops->info_get)
                        goto inc;
 
@@ -6546,7 +6514,6 @@ static int devlink_nl_cmd_info_get_dumpit(struct sk_buff *msg,
                }
 inc:
                idx++;
-retry:
                devlink_put(devlink);
        }
        mutex_unlock(&devlink_mutex);
@@ -7702,10 +7669,7 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg,
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry_rep;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                mutex_lock(&devlink->reporters_lock);
                list_for_each_entry(reporter, &devlink->reporter_list,
                                    list) {
@@ -7725,14 +7689,10 @@ devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg,
                        idx++;
                }
                mutex_unlock(&devlink->reporters_lock);
-retry_rep:
                devlink_put(devlink);
        }
 
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry_port;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(port, &devlink->port_list, list) {
                        mutex_lock(&port->reporters_lock);
@@ -7757,7 +7717,6 @@ retry_rep:
                        mutex_unlock(&port->reporters_lock);
                }
                devl_unlock(devlink);
-retry_port:
                devlink_put(devlink);
        }
 out:
@@ -8296,10 +8255,7 @@ static int devlink_nl_cmd_trap_get_dumpit(struct sk_buff *msg,
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(trap_item, &devlink->trap_list, list) {
                        if (idx < start) {
@@ -8319,7 +8275,6 @@ static int devlink_nl_cmd_trap_get_dumpit(struct sk_buff *msg,
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
@@ -8520,10 +8475,7 @@ static int devlink_nl_cmd_trap_group_get_dumpit(struct sk_buff *msg,
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(group_item, &devlink->trap_group_list,
                                    list) {
@@ -8544,7 +8496,6 @@ static int devlink_nl_cmd_trap_group_get_dumpit(struct sk_buff *msg,
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
@@ -8831,10 +8782,7 @@ static int devlink_nl_cmd_trap_policer_get_dumpit(struct sk_buff *msg,
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(policer_item, &devlink->trap_policer_list,
                                    list) {
@@ -8855,7 +8803,6 @@ static int devlink_nl_cmd_trap_policer_get_dumpit(struct sk_buff *msg,
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
@@ -12273,10 +12220,7 @@ static void __net_exit devlink_pernet_pre_exit(struct net *net)
         * all devlink instances from this namespace into init_net.
         */
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), net))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(net, index, devlink) {
                WARN_ON(!(devlink->features & DEVLINK_F_RELOAD));
                err = devlink_reload(devlink, &init_net,
                                     DEVLINK_RELOAD_ACTION_DRIVER_REINIT,
@@ -12284,7 +12228,6 @@ static void __net_exit devlink_pernet_pre_exit(struct net *net)
                                     &actions_performed, NULL);
                if (err && err != -EOPNOTSUPP)
                        pr_warn("Failed to reload devlink instance into init_net\n");
-retry:
                devlink_put(devlink);
        }
        mutex_unlock(&devlink_mutex);