virtchnl: fix fake 1-elem arrays in structs allocated as `nents + 1` - 1
authorAlexander Lobakin <aleksander.lobakin@intel.com>
Fri, 28 Jul 2023 15:52:05 +0000 (17:52 +0200)
committerTony Nguyen <anthony.l.nguyen@intel.com>
Wed, 16 Aug 2023 15:56:56 +0000 (08:56 -0700)
The two most problematic virtchnl structures are virtchnl_rss_key and
virtchnl_rss_lut. Their "flex" arrays have the type of u8, thus, when
allocating / checking, the actual size is calculated as `sizeof +
nents - 1 byte`. But their sizeof() is not 1 byte larger than the size
of such structure with proper flex array, it's two bytes larger due to
the padding. That said, their size is always 1 byte larger unless
there are no tail elements -- then it's +2 bytes.
Add virtchnl_struct_size() macro which will handle this case (and later
other cases as well). Make its calling conv the same as we call
struct_size() to allow it to be drop-in, even though it's unlikely to
become possible to switch to generic API. The macro will calculate a
proper size of a structure with a flex array at the end, so that it
becomes transparent for the compilers, but add the difference from the
old values, so that the real size of sorta-ABI-messages doesn't change.
Use it on the allocation side in IAVF and the receiving side (defined
as static inline in virtchnl.h) for the mentioned two structures.

Signed-off-by: Alexander Lobakin <aleksander.lobakin@intel.com>
Reviewed-by: Kees Cook <keescook@chromium.org>
Tested-by: Rafal Romanowski <rafal.romanowski@intel.com>
Signed-off-by: Tony Nguyen <anthony.l.nguyen@intel.com>
drivers/net/ethernet/intel/iavf/iavf_virtchnl.c
include/linux/avf/virtchnl.h

index be3c007..10f0305 100644 (file)
@@ -1085,8 +1085,7 @@ void iavf_set_rss_key(struct iavf_adapter *adapter)
                        adapter->current_op);
                return;
        }
-       len = sizeof(struct virtchnl_rss_key) +
-             (adapter->rss_key_size * sizeof(u8)) - 1;
+       len = virtchnl_struct_size(vrk, key, adapter->rss_key_size);
        vrk = kzalloc(len, GFP_KERNEL);
        if (!vrk)
                return;
@@ -1117,8 +1116,7 @@ void iavf_set_rss_lut(struct iavf_adapter *adapter)
                        adapter->current_op);
                return;
        }
-       len = sizeof(struct virtchnl_rss_lut) +
-             (adapter->rss_lut_size * sizeof(u8)) - 1;
+       len = virtchnl_struct_size(vrl, lut, adapter->rss_lut_size);
        vrl = kzalloc(len, GFP_KERNEL);
        if (!vrl)
                return;
index c15221d..3ab207c 100644 (file)
@@ -866,18 +866,20 @@ VIRTCHNL_CHECK_STRUCT_LEN(4, virtchnl_promisc_info);
 struct virtchnl_rss_key {
        u16 vsi_id;
        u16 key_len;
-       u8 key[1];         /* RSS hash key, packed bytes */
+       u8 key[];          /* RSS hash key, packed bytes */
 };
 
-VIRTCHNL_CHECK_STRUCT_LEN(6, virtchnl_rss_key);
+VIRTCHNL_CHECK_STRUCT_LEN(4, virtchnl_rss_key);
+#define virtchnl_rss_key_LEGACY_SIZEOF 6
 
 struct virtchnl_rss_lut {
        u16 vsi_id;
        u16 lut_entries;
-       u8 lut[1];        /* RSS lookup table */
+       u8 lut[];         /* RSS lookup table */
 };
 
-VIRTCHNL_CHECK_STRUCT_LEN(6, virtchnl_rss_lut);
+VIRTCHNL_CHECK_STRUCT_LEN(4, virtchnl_rss_lut);
+#define virtchnl_rss_lut_LEGACY_SIZEOF 6
 
 /* VIRTCHNL_OP_GET_RSS_HENA_CAPS
  * VIRTCHNL_OP_SET_RSS_HENA
@@ -1367,6 +1369,17 @@ struct virtchnl_fdir_del {
 
 VIRTCHNL_CHECK_STRUCT_LEN(12, virtchnl_fdir_del);
 
+#define __vss_byone(p, member, count, old)                                   \
+       (struct_size(p, member, count) + (old - 1 - struct_size(p, member, 0)))
+
+#define __vss(type, func, p, member, count)            \
+       struct type: func(p, member, count, type##_LEGACY_SIZEOF)
+
+#define virtchnl_struct_size(p, m, c)                                        \
+       _Generic(*p,                                                          \
+                __vss(virtchnl_rss_key, __vss_byone, p, m, c),               \
+                __vss(virtchnl_rss_lut, __vss_byone, p, m, c))
+
 /**
  * virtchnl_vc_validate_vf_msg
  * @ver: Virtchnl version info
@@ -1479,19 +1492,21 @@ virtchnl_vc_validate_vf_msg(struct virtchnl_version_info *ver, u32 v_opcode,
                }
                break;
        case VIRTCHNL_OP_CONFIG_RSS_KEY:
-               valid_len = sizeof(struct virtchnl_rss_key);
+               valid_len = virtchnl_rss_key_LEGACY_SIZEOF;
                if (msglen >= valid_len) {
                        struct virtchnl_rss_key *vrk =
                                (struct virtchnl_rss_key *)msg;
-                       valid_len += vrk->key_len - 1;
+                       valid_len = virtchnl_struct_size(vrk, key,
+                                                        vrk->key_len);
                }
                break;
        case VIRTCHNL_OP_CONFIG_RSS_LUT:
-               valid_len = sizeof(struct virtchnl_rss_lut);
+               valid_len = virtchnl_rss_lut_LEGACY_SIZEOF;
                if (msglen >= valid_len) {
                        struct virtchnl_rss_lut *vrl =
                                (struct virtchnl_rss_lut *)msg;
-                       valid_len += vrl->lut_entries - 1;
+                       valid_len = virtchnl_struct_size(vrl, lut,
+                                                        vrl->lut_entries);
                }
                break;
        case VIRTCHNL_OP_GET_RSS_HENA_CAPS: