netlink: make NLA_BINARY validation more flexible
[platform/kernel/linux-starfive.git] / lib / nlattr.c
index bc5b5cf..665bdaf 100644 (file)
@@ -124,6 +124,7 @@ void nla_get_range_unsigned(const struct nla_policy *pt,
                range->max = U8_MAX;
                break;
        case NLA_U16:
+       case NLA_BINARY:
                range->max = U16_MAX;
                break;
        case NLA_U32:
@@ -140,6 +141,7 @@ void nla_get_range_unsigned(const struct nla_policy *pt,
 
        switch (pt->validation_type) {
        case NLA_VALIDATE_RANGE:
+       case NLA_VALIDATE_RANGE_WARN_TOO_LONG:
                range->min = pt->min;
                range->max = pt->max;
                break;
@@ -157,9 +159,10 @@ void nla_get_range_unsigned(const struct nla_policy *pt,
        }
 }
 
-static int nla_validate_int_range_unsigned(const struct nla_policy *pt,
-                                          const struct nlattr *nla,
-                                          struct netlink_ext_ack *extack)
+static int nla_validate_range_unsigned(const struct nla_policy *pt,
+                                      const struct nlattr *nla,
+                                      struct netlink_ext_ack *extack,
+                                      unsigned int validate)
 {
        struct netlink_range_validation range;
        u64 value;
@@ -178,15 +181,39 @@ static int nla_validate_int_range_unsigned(const struct nla_policy *pt,
        case NLA_MSECS:
                value = nla_get_u64(nla);
                break;
+       case NLA_BINARY:
+               value = nla_len(nla);
+               break;
        default:
                return -EINVAL;
        }
 
        nla_get_range_unsigned(pt, &range);
 
+       if (pt->validation_type == NLA_VALIDATE_RANGE_WARN_TOO_LONG &&
+           pt->type == NLA_BINARY && value > range.max) {
+               pr_warn_ratelimited("netlink: '%s': attribute type %d has an invalid length.\n",
+                                   current->comm, pt->type);
+               if (validate & NL_VALIDATE_STRICT_ATTRS) {
+                       NL_SET_ERR_MSG_ATTR(extack, nla,
+                                           "invalid attribute length");
+                       return -EINVAL;
+               }
+
+               /* this assumes min <= max (don't validate against min) */
+               return 0;
+       }
+
        if (value < range.min || value > range.max) {
-               NL_SET_ERR_MSG_ATTR(extack, nla,
-                                   "integer out of range");
+               bool binary = pt->type == NLA_BINARY;
+
+               if (binary)
+                       NL_SET_ERR_MSG_ATTR(extack, nla,
+                                           "binary attribute size out of range");
+               else
+                       NL_SET_ERR_MSG_ATTR(extack, nla,
+                                           "integer out of range");
+
                return -ERANGE;
        }
 
@@ -274,7 +301,8 @@ static int nla_validate_int_range_signed(const struct nla_policy *pt,
 
 static int nla_validate_int_range(const struct nla_policy *pt,
                                  const struct nlattr *nla,
-                                 struct netlink_ext_ack *extack)
+                                 struct netlink_ext_ack *extack,
+                                 unsigned int validate)
 {
        switch (pt->type) {
        case NLA_U8:
@@ -282,7 +310,8 @@ static int nla_validate_int_range(const struct nla_policy *pt,
        case NLA_U32:
        case NLA_U64:
        case NLA_MSECS:
-               return nla_validate_int_range_unsigned(pt, nla, extack);
+       case NLA_BINARY:
+               return nla_validate_range_unsigned(pt, nla, extack, validate);
        case NLA_S8:
        case NLA_S16:
        case NLA_S32:
@@ -313,10 +342,7 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
 
        BUG_ON(pt->type > NLA_TYPE_MAX);
 
-       if ((nla_attr_len[pt->type] && attrlen != nla_attr_len[pt->type]) ||
-           (pt->type == NLA_EXACT_LEN &&
-            pt->validation_type == NLA_VALIDATE_WARN_TOO_LONG &&
-            attrlen != pt->len)) {
+       if (nla_attr_len[pt->type] && attrlen != nla_attr_len[pt->type]) {
                pr_warn_ratelimited("netlink: '%s': attribute type %d has an invalid length.\n",
                                    current->comm, type);
                if (validate & NL_VALIDATE_STRICT_ATTRS) {
@@ -449,19 +475,10 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
                                            "Unsupported attribute");
                        return -EINVAL;
                }
-               /* fall through */
-       case NLA_MIN_LEN:
                if (attrlen < pt->len)
                        goto out_err;
                break;
 
-       case NLA_EXACT_LEN:
-               if (pt->validation_type != NLA_VALIDATE_WARN_TOO_LONG) {
-                       if (attrlen != pt->len)
-                               goto out_err;
-                       break;
-               }
-               /* fall through */
        default:
                if (pt->len)
                        minlen = pt->len;
@@ -479,9 +496,10 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
                break;
        case NLA_VALIDATE_RANGE_PTR:
        case NLA_VALIDATE_RANGE:
+       case NLA_VALIDATE_RANGE_WARN_TOO_LONG:
        case NLA_VALIDATE_MIN:
        case NLA_VALIDATE_MAX:
-               err = nla_validate_int_range(pt, nla, extack);
+               err = nla_validate_int_range(pt, nla, extack, validate);
                if (err)
                        return err;
                break;