net: devlink: limit flash component name to match version returned by info_get()
authorJiri Pirko <jiri@nvidia.com>
Wed, 24 Aug 2022 12:20:11 +0000 (14:20 +0200)
committerJakub Kicinski <kuba@kernel.org>
Thu, 25 Aug 2022 20:22:53 +0000 (13:22 -0700)
Limit the acceptance of component name passed to cmd_flash_update() to
match one of the versions returned by info_get(), marked by version type.
This makes things clearer and enforces 1:1 mapping between exposed
version and accepted flash component.

Check VERSION_TYPE_COMPONENT version type during cmd_flash_update()
execution by calling info_get() with different "req" context.
That causes info_get() to lookup the component name instead of
filling-up the netlink message.

Remove "UPDATE_COMPONENT" flag which becomes used.

Signed-off-by: Jiri Pirko <jiri@nvidia.com>
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
drivers/net/netdevsim/dev.c
include/net/devlink.h
net/core/devlink.c

index d6938fa..efea94c 100644 (file)
@@ -1322,8 +1322,7 @@ nsim_dev_devlink_trap_drop_counter_get(struct devlink *devlink,
 static const struct devlink_ops nsim_dev_devlink_ops = {
        .eswitch_mode_set = nsim_devlink_eswitch_mode_set,
        .eswitch_mode_get = nsim_devlink_eswitch_mode_get,
-       .supported_flash_update_params = DEVLINK_SUPPORT_FLASH_UPDATE_COMPONENT |
-                                        DEVLINK_SUPPORT_FLASH_UPDATE_OVERWRITE_MASK,
+       .supported_flash_update_params = DEVLINK_SUPPORT_FLASH_UPDATE_OVERWRITE_MASK,
        .reload_actions = BIT(DEVLINK_RELOAD_ACTION_DRIVER_REINIT),
        .reload_down = nsim_dev_reload_down,
        .reload_up = nsim_dev_reload_up,
index f50a002..1f70260 100644 (file)
@@ -624,8 +624,7 @@ struct devlink_flash_update_params {
        u32 overwrite_mask;
 };
 
-#define DEVLINK_SUPPORT_FLASH_UPDATE_COMPONENT         BIT(0)
-#define DEVLINK_SUPPORT_FLASH_UPDATE_OVERWRITE_MASK    BIT(1)
+#define DEVLINK_SUPPORT_FLASH_UPDATE_OVERWRITE_MASK    BIT(0)
 
 struct devlink_region;
 struct devlink_info_req;
index 43c75b5..0f7078d 100644 (file)
@@ -4742,10 +4742,76 @@ void devlink_flash_update_timeout_notify(struct devlink *devlink,
 }
 EXPORT_SYMBOL_GPL(devlink_flash_update_timeout_notify);
 
+struct devlink_info_req {
+       struct sk_buff *msg;
+       void (*version_cb)(const char *version_name,
+                          enum devlink_info_version_type version_type,
+                          void *version_cb_priv);
+       void *version_cb_priv;
+};
+
+struct devlink_flash_component_lookup_ctx {
+       const char *lookup_name;
+       bool lookup_name_found;
+};
+
+static void
+devlink_flash_component_lookup_cb(const char *version_name,
+                                 enum devlink_info_version_type version_type,
+                                 void *version_cb_priv)
+{
+       struct devlink_flash_component_lookup_ctx *lookup_ctx = version_cb_priv;
+
+       if (version_type != DEVLINK_INFO_VERSION_TYPE_COMPONENT ||
+           lookup_ctx->lookup_name_found)
+               return;
+
+       lookup_ctx->lookup_name_found =
+               !strcmp(lookup_ctx->lookup_name, version_name);
+}
+
+static int devlink_flash_component_get(struct devlink *devlink,
+                                      struct nlattr *nla_component,
+                                      const char **p_component,
+                                      struct netlink_ext_ack *extack)
+{
+       struct devlink_flash_component_lookup_ctx lookup_ctx = {};
+       struct devlink_info_req req = {};
+       const char *component;
+       int ret;
+
+       if (!nla_component)
+               return 0;
+
+       component = nla_data(nla_component);
+
+       if (!devlink->ops->info_get) {
+               NL_SET_ERR_MSG_ATTR(extack, nla_component,
+                                   "component update is not supported by this device");
+               return -EOPNOTSUPP;
+       }
+
+       lookup_ctx.lookup_name = component;
+       req.version_cb = devlink_flash_component_lookup_cb;
+       req.version_cb_priv = &lookup_ctx;
+
+       ret = devlink->ops->info_get(devlink, &req, NULL);
+       if (ret)
+               return ret;
+
+       if (!lookup_ctx.lookup_name_found) {
+               NL_SET_ERR_MSG_ATTR(extack, nla_component,
+                                   "selected component is not supported by this device");
+               return -EINVAL;
+       }
+       *p_component = component;
+       return 0;
+}
+
 static int devlink_nl_cmd_flash_update(struct sk_buff *skb,
                                       struct genl_info *info)
 {
-       struct nlattr *nla_component, *nla_overwrite_mask, *nla_file_name;
+       struct nlattr *nla_overwrite_mask, *nla_file_name;
        struct devlink_flash_update_params params = {};
        struct devlink *devlink = info->user_ptr[0];
        const char *file_name;
@@ -4758,17 +4824,13 @@ static int devlink_nl_cmd_flash_update(struct sk_buff *skb,
        if (!info->attrs[DEVLINK_ATTR_FLASH_UPDATE_FILE_NAME])
                return -EINVAL;
 
-       supported_params = devlink->ops->supported_flash_update_params;
+       ret = devlink_flash_component_get(devlink,
+                                         info->attrs[DEVLINK_ATTR_FLASH_UPDATE_COMPONENT],
+                                         &params.component, info->extack);
+       if (ret)
+               return ret;
 
-       nla_component = info->attrs[DEVLINK_ATTR_FLASH_UPDATE_COMPONENT];
-       if (nla_component) {
-               if (!(supported_params & DEVLINK_SUPPORT_FLASH_UPDATE_COMPONENT)) {
-                       NL_SET_ERR_MSG_ATTR(info->extack, nla_component,
-                                           "component update is not supported by this device");
-                       return -EOPNOTSUPP;
-               }
-               params.component = nla_data(nla_component);
-       }
+       supported_params = devlink->ops->supported_flash_update_params;
 
        nla_overwrite_mask = info->attrs[DEVLINK_ATTR_FLASH_UPDATE_OVERWRITE_MASK];
        if (nla_overwrite_mask) {
@@ -6553,18 +6615,18 @@ out_unlock:
        return err;
 }
 
-struct devlink_info_req {
-       struct sk_buff *msg;
-};
-
 int devlink_info_driver_name_put(struct devlink_info_req *req, const char *name)
 {
+       if (!req->msg)
+               return 0;
        return nla_put_string(req->msg, DEVLINK_ATTR_INFO_DRIVER_NAME, name);
 }
 EXPORT_SYMBOL_GPL(devlink_info_driver_name_put);
 
 int devlink_info_serial_number_put(struct devlink_info_req *req, const char *sn)
 {
+       if (!req->msg)
+               return 0;
        return nla_put_string(req->msg, DEVLINK_ATTR_INFO_SERIAL_NUMBER, sn);
 }
 EXPORT_SYMBOL_GPL(devlink_info_serial_number_put);
@@ -6572,6 +6634,8 @@ EXPORT_SYMBOL_GPL(devlink_info_serial_number_put);
 int devlink_info_board_serial_number_put(struct devlink_info_req *req,
                                         const char *bsn)
 {
+       if (!req->msg)
+               return 0;
        return nla_put_string(req->msg, DEVLINK_ATTR_INFO_BOARD_SERIAL_NUMBER,
                              bsn);
 }
@@ -6585,6 +6649,13 @@ static int devlink_info_version_put(struct devlink_info_req *req, int attr,
        struct nlattr *nest;
        int err;
 
+       if (req->version_cb)
+               req->version_cb(version_name, version_type,
+                               req->version_cb_priv);
+
+       if (!req->msg)
+               return 0;
+
        nest = nla_nest_start_noflag(req->msg, attr);
        if (!nest)
                return -EMSGSIZE;
@@ -6665,7 +6736,7 @@ devlink_nl_info_fill(struct sk_buff *msg, struct devlink *devlink,
                     enum devlink_command cmd, u32 portid,
                     u32 seq, int flags, struct netlink_ext_ack *extack)
 {
-       struct devlink_info_req req;
+       struct devlink_info_req req = {};
        void *hdr;
        int err;
 
@@ -12332,8 +12403,8 @@ EXPORT_SYMBOL_GPL(devl_trap_policers_unregister);
 static void __devlink_compat_running_version(struct devlink *devlink,
                                             char *buf, size_t len)
 {
+       struct devlink_info_req req = {};
        const struct nlattr *nlattr;
-       struct devlink_info_req req;
        struct sk_buff *msg;
        int rem, err;