net/sched: act_pedit: Add size check for TCA_PEDIT_PARMS_EX
[platform/kernel/linux-starfive.git] / net / sched / act_pedit.c
index 94ed585..aee2e13 100644 (file)
 #include <linux/rtnetlink.h>
 #include <linux/module.h>
 #include <linux/init.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
 #include <linux/slab.h>
+#include <net/ipv6.h>
 #include <net/netlink.h>
 #include <net/pkt_sched.h>
 #include <linux/tc_act/tc_pedit.h>
@@ -25,6 +28,7 @@ static struct tc_action_ops act_pedit_ops;
 
 static const struct nla_policy pedit_policy[TCA_PEDIT_MAX + 1] = {
        [TCA_PEDIT_PARMS]       = { .len = sizeof(struct tc_pedit) },
+       [TCA_PEDIT_PARMS_EX]    = { .len = sizeof(struct tc_pedit) },
        [TCA_PEDIT_KEYS_EX]   = { .type = NLA_NESTED },
 };
 
@@ -133,6 +137,17 @@ nla_failure:
        return -EINVAL;
 }
 
+static void tcf_pedit_cleanup_rcu(struct rcu_head *head)
+{
+       struct tcf_pedit_parms *parms =
+               container_of(head, struct tcf_pedit_parms, rcu);
+
+       kfree(parms->tcfp_keys_ex);
+       kfree(parms->tcfp_keys);
+
+       kfree(parms);
+}
+
 static int tcf_pedit_init(struct net *net, struct nlattr *nla,
                          struct nlattr *est, struct tc_action **a,
                          struct tcf_proto *tp, u32 flags,
@@ -140,10 +155,9 @@ static int tcf_pedit_init(struct net *net, struct nlattr *nla,
 {
        struct tc_action_net *tn = net_generic(net, act_pedit_ops.net_id);
        bool bind = flags & TCA_ACT_FLAGS_BIND;
-       struct nlattr *tb[TCA_PEDIT_MAX + 1];
        struct tcf_chain *goto_ch = NULL;
-       struct tc_pedit_key *keys = NULL;
-       struct tcf_pedit_key_ex *keys_ex;
+       struct tcf_pedit_parms *oparms, *nparms;
+       struct nlattr *tb[TCA_PEDIT_MAX + 1];
        struct tc_pedit *parm;
        struct nlattr *pattr;
        struct tcf_pedit *p;
@@ -170,109 +184,125 @@ static int tcf_pedit_init(struct net *net, struct nlattr *nla,
        }
 
        parm = nla_data(pattr);
-       if (!parm->nkeys) {
-               NL_SET_ERR_MSG_MOD(extack, "Pedit requires keys to be passed");
-               return -EINVAL;
-       }
-       ksize = parm->nkeys * sizeof(struct tc_pedit_key);
-       if (nla_len(pattr) < sizeof(*parm) + ksize) {
-               NL_SET_ERR_MSG_ATTR(extack, pattr, "Length of TCA_PEDIT_PARMS or TCA_PEDIT_PARMS_EX pedit attribute is invalid");
-               return -EINVAL;
-       }
-
-       keys_ex = tcf_pedit_keys_ex_parse(tb[TCA_PEDIT_KEYS_EX], parm->nkeys);
-       if (IS_ERR(keys_ex))
-               return PTR_ERR(keys_ex);
 
        index = parm->index;
        err = tcf_idr_check_alloc(tn, &index, a, bind);
        if (!err) {
-               ret = tcf_idr_create(tn, index, est, a,
-                                    &act_pedit_ops, bind, false, flags);
+               ret = tcf_idr_create_from_flags(tn, index, est, a,
+                                               &act_pedit_ops, bind, flags);
                if (ret) {
                        tcf_idr_cleanup(tn, index);
-                       goto out_free;
+                       return ret;
                }
                ret = ACT_P_CREATED;
        } else if (err > 0) {
                if (bind)
-                       goto out_free;
+                       return 0;
                if (!(flags & TCA_ACT_FLAGS_REPLACE)) {
                        ret = -EEXIST;
                        goto out_release;
                }
        } else {
-               ret = err;
+               return err;
+       }
+
+       if (!parm->nkeys) {
+               NL_SET_ERR_MSG_MOD(extack, "Pedit requires keys to be passed");
+               ret = -EINVAL;
+               goto out_release;
+       }
+       ksize = parm->nkeys * sizeof(struct tc_pedit_key);
+       if (nla_len(pattr) < sizeof(*parm) + ksize) {
+               NL_SET_ERR_MSG_ATTR(extack, pattr, "Length of TCA_PEDIT_PARMS or TCA_PEDIT_PARMS_EX pedit attribute is invalid");
+               ret = -EINVAL;
+               goto out_release;
+       }
+
+       nparms = kzalloc(sizeof(*nparms), GFP_KERNEL);
+       if (!nparms) {
+               ret = -ENOMEM;
+               goto out_release;
+       }
+
+       nparms->tcfp_keys_ex =
+               tcf_pedit_keys_ex_parse(tb[TCA_PEDIT_KEYS_EX], parm->nkeys);
+       if (IS_ERR(nparms->tcfp_keys_ex)) {
+               ret = PTR_ERR(nparms->tcfp_keys_ex);
                goto out_free;
        }
 
        err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack);
        if (err < 0) {
                ret = err;
-               goto out_release;
+               goto out_free_ex;
        }
-       p = to_pedit(*a);
-       spin_lock_bh(&p->tcf_lock);
 
-       if (ret == ACT_P_CREATED ||
-           (p->tcfp_nkeys && p->tcfp_nkeys != parm->nkeys)) {
-               keys = kmalloc(ksize, GFP_ATOMIC);
-               if (!keys) {
-                       spin_unlock_bh(&p->tcf_lock);
-                       ret = -ENOMEM;
-                       goto put_chain;
-               }
-               kfree(p->tcfp_keys);
-               p->tcfp_keys = keys;
-               p->tcfp_nkeys = parm->nkeys;
+       nparms->tcfp_off_max_hint = 0;
+       nparms->tcfp_flags = parm->flags;
+       nparms->tcfp_nkeys = parm->nkeys;
+
+       nparms->tcfp_keys = kmalloc(ksize, GFP_KERNEL);
+       if (!nparms->tcfp_keys) {
+               ret = -ENOMEM;
+               goto put_chain;
        }
-       memcpy(p->tcfp_keys, parm->keys, ksize);
-       p->tcfp_off_max_hint = 0;
-       for (i = 0; i < p->tcfp_nkeys; ++i) {
-               u32 cur = p->tcfp_keys[i].off;
+
+       memcpy(nparms->tcfp_keys, parm->keys, ksize);
+
+       for (i = 0; i < nparms->tcfp_nkeys; ++i) {
+               u32 cur = nparms->tcfp_keys[i].off;
 
                /* sanitize the shift value for any later use */
-               p->tcfp_keys[i].shift = min_t(size_t, BITS_PER_TYPE(int) - 1,
-                                             p->tcfp_keys[i].shift);
+               nparms->tcfp_keys[i].shift = min_t(size_t,
+                                                  BITS_PER_TYPE(int) - 1,
+                                                  nparms->tcfp_keys[i].shift);
 
                /* The AT option can read a single byte, we can bound the actual
                 * value with uchar max.
                 */
-               cur += (0xff & p->tcfp_keys[i].offmask) >> p->tcfp_keys[i].shift;
+               cur += (0xff & nparms->tcfp_keys[i].offmask) >> nparms->tcfp_keys[i].shift;
 
                /* Each key touches 4 bytes starting from the computed offset */
-               p->tcfp_off_max_hint = max(p->tcfp_off_max_hint, cur + 4);
+               nparms->tcfp_off_max_hint =
+                       max(nparms->tcfp_off_max_hint, cur + 4);
        }
 
-       p->tcfp_flags = parm->flags;
+       p = to_pedit(*a);
+
+       spin_lock_bh(&p->tcf_lock);
        goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch);
+       oparms = rcu_replace_pointer(p->parms, nparms, 1);
+       spin_unlock_bh(&p->tcf_lock);
 
-       kfree(p->tcfp_keys_ex);
-       p->tcfp_keys_ex = keys_ex;
+       if (oparms)
+               call_rcu(&oparms->rcu, tcf_pedit_cleanup_rcu);
 
-       spin_unlock_bh(&p->tcf_lock);
        if (goto_ch)
                tcf_chain_put_by_act(goto_ch);
+
        return ret;
 
 put_chain:
        if (goto_ch)
                tcf_chain_put_by_act(goto_ch);
+out_free_ex:
+       kfree(nparms->tcfp_keys_ex);
+out_free:
+       kfree(nparms);
 out_release:
        tcf_idr_release(*a, bind);
-out_free:
-       kfree(keys_ex);
        return ret;
-
 }
 
 static void tcf_pedit_cleanup(struct tc_action *a)
 {
        struct tcf_pedit *p = to_pedit(a);
-       struct tc_pedit_key *keys = p->tcfp_keys;
+       struct tcf_pedit_parms *parms;
 
-       kfree(keys);
-       kfree(p->tcfp_keys_ex);
+       parms = rcu_dereference_protected(p->parms, 1);
+
+       if (parms)
+               call_rcu(&parms->rcu, tcf_pedit_cleanup_rcu);
 }
 
 static bool offset_valid(struct sk_buff *skb, int offset)
@@ -286,11 +316,35 @@ static bool offset_valid(struct sk_buff *skb, int offset)
        return true;
 }
 
-static int pedit_skb_hdr_offset(struct sk_buff *skb,
-                               enum pedit_header_type htype, int *hoffset)
+static int pedit_l4_skb_offset(struct sk_buff *skb, int *hoffset, const int header_type)
 {
+       const int noff = skb_network_offset(skb);
        int ret = -EINVAL;
+       struct iphdr _iph;
+
+       switch (skb->protocol) {
+       case htons(ETH_P_IP): {
+               const struct iphdr *iph = skb_header_pointer(skb, noff, sizeof(_iph), &_iph);
+
+               if (!iph)
+                       goto out;
+               *hoffset = noff + iph->ihl * 4;
+               ret = 0;
+               break;
+       }
+       case htons(ETH_P_IPV6):
+               ret = ipv6_find_hdr(skb, hoffset, header_type, NULL, NULL) == header_type ? 0 : -EINVAL;
+               break;
+       }
+out:
+       return ret;
+}
 
+static int pedit_skb_hdr_offset(struct sk_buff *skb,
+                                enum pedit_header_type htype, int *hoffset)
+{
+       int ret = -EINVAL;
+       /* 'htype' is validated in the netlink parsing */
        switch (htype) {
        case TCA_PEDIT_KEY_EX_HDR_TYPE_ETH:
                if (skb_mac_header_was_set(skb)) {
@@ -305,126 +359,120 @@ static int pedit_skb_hdr_offset(struct sk_buff *skb,
                ret = 0;
                break;
        case TCA_PEDIT_KEY_EX_HDR_TYPE_TCP:
+               ret = pedit_l4_skb_offset(skb, hoffset, IPPROTO_TCP);
+               break;
        case TCA_PEDIT_KEY_EX_HDR_TYPE_UDP:
-               if (skb_transport_header_was_set(skb)) {
-                       *hoffset = skb_transport_offset(skb);
-                       ret = 0;
-               }
+               ret = pedit_l4_skb_offset(skb, hoffset, IPPROTO_UDP);
                break;
        default:
-               ret = -EINVAL;
                break;
        }
-
        return ret;
 }
 
 static int tcf_pedit_act(struct sk_buff *skb, const struct tc_action *a,
                         struct tcf_result *res)
 {
+       enum pedit_header_type htype = TCA_PEDIT_KEY_EX_HDR_TYPE_NETWORK;
+       enum pedit_cmd cmd = TCA_PEDIT_KEY_EX_CMD_SET;
        struct tcf_pedit *p = to_pedit(a);
+       struct tcf_pedit_key_ex *tkey_ex;
+       struct tcf_pedit_parms *parms;
+       struct tc_pedit_key *tkey;
        u32 max_offset;
        int i;
 
-       spin_lock(&p->tcf_lock);
+       parms = rcu_dereference_bh(p->parms);
 
        max_offset = (skb_transport_header_was_set(skb) ?
                      skb_transport_offset(skb) :
                      skb_network_offset(skb)) +
-                    p->tcfp_off_max_hint;
+                    parms->tcfp_off_max_hint;
        if (skb_ensure_writable(skb, min(skb->len, max_offset)))
-               goto unlock;
+               goto done;
 
        tcf_lastuse_update(&p->tcf_tm);
+       tcf_action_update_bstats(&p->common, skb);
 
-       if (p->tcfp_nkeys > 0) {
-               struct tc_pedit_key *tkey = p->tcfp_keys;
-               struct tcf_pedit_key_ex *tkey_ex = p->tcfp_keys_ex;
-               enum pedit_header_type htype =
-                       TCA_PEDIT_KEY_EX_HDR_TYPE_NETWORK;
-               enum pedit_cmd cmd = TCA_PEDIT_KEY_EX_CMD_SET;
-
-               for (i = p->tcfp_nkeys; i > 0; i--, tkey++) {
-                       u32 *ptr, hdata;
-                       int offset = tkey->off;
-                       int hoffset;
-                       u32 val;
-                       int rc;
-
-                       if (tkey_ex) {
-                               htype = tkey_ex->htype;
-                               cmd = tkey_ex->cmd;
-
-                               tkey_ex++;
-                       }
+       tkey = parms->tcfp_keys;
+       tkey_ex = parms->tcfp_keys_ex;
 
-                       rc = pedit_skb_hdr_offset(skb, htype, &hoffset);
-                       if (rc) {
-                               pr_info("tc action pedit bad header type specified (0x%x)\n",
-                                       htype);
-                               goto bad;
-                       }
+       for (i = parms->tcfp_nkeys; i > 0; i--, tkey++) {
+               int offset = tkey->off;
+               int hoffset = 0;
+               u32 *ptr, hdata;
+               u32 val;
+               int rc;
 
-                       if (tkey->offmask) {
-                               u8 *d, _d;
-
-                               if (!offset_valid(skb, hoffset + tkey->at)) {
-                                       pr_info("tc action pedit 'at' offset %d out of bounds\n",
-                                               hoffset + tkey->at);
-                                       goto bad;
-                               }
-                               d = skb_header_pointer(skb, hoffset + tkey->at,
-                                                      sizeof(_d), &_d);
-                               if (!d)
-                                       goto bad;
-                               offset += (*d & tkey->offmask) >> tkey->shift;
-                       }
+               if (tkey_ex) {
+                       htype = tkey_ex->htype;
+                       cmd = tkey_ex->cmd;
 
-                       if (offset % 4) {
-                               pr_info("tc action pedit offset must be on 32 bit boundaries\n");
-                               goto bad;
-                       }
+                       tkey_ex++;
+               }
 
-                       if (!offset_valid(skb, hoffset + offset)) {
-                               pr_info("tc action pedit offset %d out of bounds\n",
-                                       hoffset + offset);
-                               goto bad;
-                       }
+               rc = pedit_skb_hdr_offset(skb, htype, &hoffset);
+               if (rc) {
+                       pr_info_ratelimited("tc action pedit unable to extract header offset for header type (0x%x)\n", htype);
+                       goto bad;
+               }
 
-                       ptr = skb_header_pointer(skb, hoffset + offset,
-                                                sizeof(hdata), &hdata);
-                       if (!ptr)
-                               goto bad;
-                       /* just do it, baby */
-                       switch (cmd) {
-                       case TCA_PEDIT_KEY_EX_CMD_SET:
-                               val = tkey->val;
-                               break;
-                       case TCA_PEDIT_KEY_EX_CMD_ADD:
-                               val = (*ptr + tkey->val) & ~tkey->mask;
-                               break;
-                       default:
-                               pr_info("tc action pedit bad command (%d)\n",
-                                       cmd);
+               if (tkey->offmask) {
+                       u8 *d, _d;
+
+                       if (!offset_valid(skb, hoffset + tkey->at)) {
+                               pr_info("tc action pedit 'at' offset %d out of bounds\n",
+                                       hoffset + tkey->at);
                                goto bad;
                        }
+                       d = skb_header_pointer(skb, hoffset + tkey->at,
+                                              sizeof(_d), &_d);
+                       if (!d)
+                               goto bad;
+                       offset += (*d & tkey->offmask) >> tkey->shift;
+               }
 
-                       *ptr = ((*ptr & tkey->mask) ^ val);
-                       if (ptr == &hdata)
-                               skb_store_bits(skb, hoffset + offset, ptr, 4);
+               if (offset % 4) {
+                       pr_info("tc action pedit offset must be on 32 bit boundaries\n");
+                       goto bad;
                }
 
-               goto done;
-       } else {
-               WARN(1, "pedit BUG: index %d\n", p->tcf_index);
+               if (!offset_valid(skb, hoffset + offset)) {
+                       pr_info("tc action pedit offset %d out of bounds\n",
+                               hoffset + offset);
+                       goto bad;
+               }
+
+               ptr = skb_header_pointer(skb, hoffset + offset,
+                                        sizeof(hdata), &hdata);
+               if (!ptr)
+                       goto bad;
+               /* just do it, baby */
+               switch (cmd) {
+               case TCA_PEDIT_KEY_EX_CMD_SET:
+                       val = tkey->val;
+                       break;
+               case TCA_PEDIT_KEY_EX_CMD_ADD:
+                       val = (*ptr + tkey->val) & ~tkey->mask;
+                       break;
+               default:
+                       pr_info("tc action pedit bad command (%d)\n",
+                               cmd);
+                       goto bad;
+               }
+
+               *ptr = ((*ptr & tkey->mask) ^ val);
+               if (ptr == &hdata)
+                       skb_store_bits(skb, hoffset + offset, ptr, 4);
        }
 
+       goto done;
+
 bad:
+       spin_lock(&p->tcf_lock);
        p->tcf_qstats.overlimits++;
-done:
-       bstats_update(&p->tcf_bstats, skb);
-unlock:
        spin_unlock(&p->tcf_lock);
+done:
        return p->tcf_action;
 }
 
@@ -443,30 +491,33 @@ static int tcf_pedit_dump(struct sk_buff *skb, struct tc_action *a,
 {
        unsigned char *b = skb_tail_pointer(skb);
        struct tcf_pedit *p = to_pedit(a);
+       struct tcf_pedit_parms *parms;
        struct tc_pedit *opt;
        struct tcf_t t;
        int s;
 
-       s = struct_size(opt, keys, p->tcfp_nkeys);
+       spin_lock_bh(&p->tcf_lock);
+       parms = rcu_dereference_protected(p->parms, 1);
+       s = struct_size(opt, keys, parms->tcfp_nkeys);
 
-       /* netlink spinlocks held above us - must use ATOMIC */
        opt = kzalloc(s, GFP_ATOMIC);
-       if (unlikely(!opt))
+       if (unlikely(!opt)) {
+               spin_unlock_bh(&p->tcf_lock);
                return -ENOBUFS;
+       }
 
-       spin_lock_bh(&p->tcf_lock);
-       memcpy(opt->keys, p->tcfp_keys, flex_array_size(opt, keys, p->tcfp_nkeys));
+       memcpy(opt->keys, parms->tcfp_keys,
+              flex_array_size(opt, keys, parms->tcfp_nkeys));
        opt->index = p->tcf_index;
-       opt->nkeys = p->tcfp_nkeys;
-       opt->flags = p->tcfp_flags;
+       opt->nkeys = parms->tcfp_nkeys;
+       opt->flags = parms->tcfp_flags;
        opt->action = p->tcf_action;
        opt->refcnt = refcount_read(&p->tcf_refcnt) - ref;
        opt->bindcnt = atomic_read(&p->tcf_bindcnt) - bind;
 
-       if (p->tcfp_keys_ex) {
-               if (tcf_pedit_key_ex_dump(skb,
-                                         p->tcfp_keys_ex,
-                                         p->tcfp_nkeys))
+       if (parms->tcfp_keys_ex) {
+               if (tcf_pedit_key_ex_dump(skb, parms->tcfp_keys_ex,
+                                         parms->tcfp_nkeys))
                        goto nla_put_failure;
 
                if (nla_put(skb, TCA_PEDIT_PARMS_EX, s, opt))