net/sched: cls_api: Fix lockup on flushing explicitly created chain
[platform/kernel/linux-starfive.git] / net / sched / cls_api.c
index 50566db..445ab1b 100644 (file)
@@ -41,8 +41,6 @@
 #include <net/tc_act/tc_gate.h>
 #include <net/flow_offload.h>
 
-extern const struct nla_policy rtm_tca_policy[TCA_MAX + 1];
-
 /* The list of all installed classifier types */
 static LIST_HEAD(tcf_proto_base);
 
@@ -487,7 +485,8 @@ static struct tcf_chain *tcf_chain_lookup_rcu(const struct tcf_block *block,
 #endif
 
 static int tc_chain_notify(struct tcf_chain *chain, struct sk_buff *oskb,
-                          u32 seq, u16 flags, int event, bool unicast);
+                          u32 seq, u16 flags, int event, bool unicast,
+                          struct netlink_ext_ack *extack);
 
 static struct tcf_chain *__tcf_chain_get(struct tcf_block *block,
                                         u32 chain_index, bool create,
@@ -520,7 +519,7 @@ static struct tcf_chain *__tcf_chain_get(struct tcf_block *block,
         */
        if (is_first_reference && !by_act)
                tc_chain_notify(chain, NULL, 0, NLM_F_CREATE | NLM_F_EXCL,
-                               RTM_NEWCHAIN, false);
+                               RTM_NEWCHAIN, false, NULL);
 
        return chain;
 
@@ -553,8 +552,8 @@ static void __tcf_chain_put(struct tcf_chain *chain, bool by_act,
 {
        struct tcf_block *block = chain->block;
        const struct tcf_proto_ops *tmplt_ops;
+       unsigned int refcnt, non_act_refcnt;
        bool free_block = false;
-       unsigned int refcnt;
        void *tmplt_priv;
 
        mutex_lock(&block->lock);
@@ -574,13 +573,15 @@ static void __tcf_chain_put(struct tcf_chain *chain, bool by_act,
         * save these to temporary variables.
         */
        refcnt = --chain->refcnt;
+       non_act_refcnt = refcnt - chain->action_refcnt;
        tmplt_ops = chain->tmplt_ops;
        tmplt_priv = chain->tmplt_priv;
 
-       /* The last dropped non-action reference will trigger notification. */
-       if (refcnt - chain->action_refcnt == 0 && !by_act) {
-               tc_chain_notify_delete(tmplt_ops, tmplt_priv, chain->index,
-                                      block, NULL, 0, 0, false);
+       if (non_act_refcnt == chain->explicitly_created && !by_act) {
+               if (non_act_refcnt == 0)
+                       tc_chain_notify_delete(tmplt_ops, tmplt_priv,
+                                              chain->index, block, NULL, 0, 0,
+                                              false);
                /* Last reference to chain, no need to lock. */
                chain->flushing = false;
        }
@@ -1483,6 +1484,7 @@ static int tcf_block_bind(struct tcf_block *block,
 
 err_unroll:
        list_for_each_entry_safe(block_cb, next, &bo->cb_list, list) {
+               list_del(&block_cb->driver_list);
                if (i-- > 0) {
                        list_del(&block_cb->list);
                        tcf_block_playback_offloads(block, block_cb->cb,
@@ -1816,7 +1818,8 @@ static int tcf_fill_node(struct net *net, struct sk_buff *skb,
                         struct tcf_proto *tp, struct tcf_block *block,
                         struct Qdisc *q, u32 parent, void *fh,
                         u32 portid, u32 seq, u16 flags, int event,
-                        bool terse_dump, bool rtnl_held)
+                        bool terse_dump, bool rtnl_held,
+                        struct netlink_ext_ack *extack)
 {
        struct tcmsg *tcm;
        struct nlmsghdr  *nlh;
@@ -1856,7 +1859,13 @@ static int tcf_fill_node(struct net *net, struct sk_buff *skb,
                    tp->ops->dump(net, tp, fh, skb, tcm, rtnl_held) < 0)
                        goto nla_put_failure;
        }
+
+       if (extack && extack->_msg &&
+           nla_put_string(skb, TCA_EXT_WARN_MSG, extack->_msg))
+               goto nla_put_failure;
+
        nlh->nlmsg_len = skb_tail_pointer(skb) - b;
+
        return skb->len;
 
 out_nlmsg_trim:
@@ -1870,7 +1879,7 @@ static int tfilter_notify(struct net *net, struct sk_buff *oskb,
                          struct nlmsghdr *n, struct tcf_proto *tp,
                          struct tcf_block *block, struct Qdisc *q,
                          u32 parent, void *fh, int event, bool unicast,
-                         bool rtnl_held)
+                         bool rtnl_held, struct netlink_ext_ack *extack)
 {
        struct sk_buff *skb;
        u32 portid = oskb ? NETLINK_CB(oskb).portid : 0;
@@ -1882,7 +1891,7 @@ static int tfilter_notify(struct net *net, struct sk_buff *oskb,
 
        if (tcf_fill_node(net, skb, tp, block, q, parent, fh, portid,
                          n->nlmsg_seq, n->nlmsg_flags, event,
-                         false, rtnl_held) <= 0) {
+                         false, rtnl_held, extack) <= 0) {
                kfree_skb(skb);
                return -EINVAL;
        }
@@ -1911,7 +1920,7 @@ static int tfilter_del_notify(struct net *net, struct sk_buff *oskb,
 
        if (tcf_fill_node(net, skb, tp, block, q, parent, fh, portid,
                          n->nlmsg_seq, n->nlmsg_flags, RTM_DELTFILTER,
-                         false, rtnl_held) <= 0) {
+                         false, rtnl_held, extack) <= 0) {
                NL_SET_ERR_MSG(extack, "Failed to build del event notification");
                kfree_skb(skb);
                return -EINVAL;
@@ -1937,14 +1946,15 @@ static int tfilter_del_notify(struct net *net, struct sk_buff *oskb,
 static void tfilter_notify_chain(struct net *net, struct sk_buff *oskb,
                                 struct tcf_block *block, struct Qdisc *q,
                                 u32 parent, struct nlmsghdr *n,
-                                struct tcf_chain *chain, int event)
+                                struct tcf_chain *chain, int event,
+                                struct netlink_ext_ack *extack)
 {
        struct tcf_proto *tp;
 
        for (tp = tcf_get_next_proto(chain, NULL);
             tp; tp = tcf_get_next_proto(chain, tp))
-               tfilter_notify(net, oskb, n, tp, block,
-                              q, parent, NULL, event, false, true);
+               tfilter_notify(net, oskb, n, tp, block, q, parent, NULL,
+                              event, false, true, extack);
 }
 
 static void tfilter_put(struct tcf_proto *tp, void *fh)
@@ -2148,7 +2158,7 @@ replay:
                              flags, extack);
        if (err == 0) {
                tfilter_notify(net, skb, n, tp, block, q, parent, fh,
-                              RTM_NEWTFILTER, false, rtnl_held);
+                              RTM_NEWTFILTER, false, rtnl_held, extack);
                tfilter_put(tp, fh);
                /* q pointer is NULL for shared blocks */
                if (q)
@@ -2276,7 +2286,7 @@ static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n,
 
        if (prio == 0) {
                tfilter_notify_chain(net, skb, block, q, parent, n,
-                                    chain, RTM_DELTFILTER);
+                                    chain, RTM_DELTFILTER, extack);
                tcf_chain_flush(chain, rtnl_held);
                err = 0;
                goto errout;
@@ -2300,7 +2310,7 @@ static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n,
 
                tcf_proto_put(tp, rtnl_held, NULL);
                tfilter_notify(net, skb, n, tp, block, q, parent, fh,
-                              RTM_DELTFILTER, false, rtnl_held);
+                              RTM_DELTFILTER, false, rtnl_held, extack);
                err = 0;
                goto errout;
        }
@@ -2444,7 +2454,7 @@ static int tc_get_tfilter(struct sk_buff *skb, struct nlmsghdr *n,
                err = -ENOENT;
        } else {
                err = tfilter_notify(net, skb, n, tp, block, q, parent,
-                                    fh, RTM_NEWTFILTER, true, rtnl_held);
+                                    fh, RTM_NEWTFILTER, true, rtnl_held, NULL);
                if (err < 0)
                        NL_SET_ERR_MSG(extack, "Failed to send filter notify message");
        }
@@ -2482,7 +2492,7 @@ static int tcf_node_dump(struct tcf_proto *tp, void *n, struct tcf_walker *arg)
        return tcf_fill_node(net, a->skb, tp, a->block, a->q, a->parent,
                             n, NETLINK_CB(a->cb->skb).portid,
                             a->cb->nlh->nlmsg_seq, NLM_F_MULTI,
-                            RTM_NEWTFILTER, a->terse_dump, true);
+                            RTM_NEWTFILTER, a->terse_dump, true, NULL);
 }
 
 static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent,
@@ -2516,7 +2526,7 @@ static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent,
                        if (tcf_fill_node(net, skb, tp, block, q, parent, NULL,
                                          NETLINK_CB(cb->skb).portid,
                                          cb->nlh->nlmsg_seq, NLM_F_MULTI,
-                                         RTM_NEWTFILTER, false, true) <= 0)
+                                         RTM_NEWTFILTER, false, true, NULL) <= 0)
                                goto errout;
                        cb->args[1] = 1;
                }
@@ -2659,7 +2669,8 @@ static int tc_chain_fill_node(const struct tcf_proto_ops *tmplt_ops,
                              void *tmplt_priv, u32 chain_index,
                              struct net *net, struct sk_buff *skb,
                              struct tcf_block *block,
-                             u32 portid, u32 seq, u16 flags, int event)
+                             u32 portid, u32 seq, u16 flags, int event,
+                             struct netlink_ext_ack *extack)
 {
        unsigned char *b = skb_tail_pointer(skb);
        const struct tcf_proto_ops *ops;
@@ -2696,7 +2707,12 @@ static int tc_chain_fill_node(const struct tcf_proto_ops *tmplt_ops,
                        goto nla_put_failure;
        }
 
+       if (extack && extack->_msg &&
+           nla_put_string(skb, TCA_EXT_WARN_MSG, extack->_msg))
+               goto out_nlmsg_trim;
+
        nlh->nlmsg_len = skb_tail_pointer(skb) - b;
+
        return skb->len;
 
 out_nlmsg_trim:
@@ -2706,7 +2722,8 @@ nla_put_failure:
 }
 
 static int tc_chain_notify(struct tcf_chain *chain, struct sk_buff *oskb,
-                          u32 seq, u16 flags, int event, bool unicast)
+                          u32 seq, u16 flags, int event, bool unicast,
+                          struct netlink_ext_ack *extack)
 {
        u32 portid = oskb ? NETLINK_CB(oskb).portid : 0;
        struct tcf_block *block = chain->block;
@@ -2720,7 +2737,7 @@ static int tc_chain_notify(struct tcf_chain *chain, struct sk_buff *oskb,
 
        if (tc_chain_fill_node(chain->tmplt_ops, chain->tmplt_priv,
                               chain->index, net, skb, block, portid,
-                              seq, flags, event) <= 0) {
+                              seq, flags, event, extack) <= 0) {
                kfree_skb(skb);
                return -EINVAL;
        }
@@ -2748,7 +2765,7 @@ static int tc_chain_notify_delete(const struct tcf_proto_ops *tmplt_ops,
                return -ENOBUFS;
 
        if (tc_chain_fill_node(tmplt_ops, tmplt_priv, chain_index, net, skb,
-                              block, portid, seq, flags, RTM_DELCHAIN) <= 0) {
+                              block, portid, seq, flags, RTM_DELCHAIN, NULL) <= 0) {
                kfree_skb(skb);
                return -EINVAL;
        }
@@ -2781,6 +2798,7 @@ static int tc_chain_tmplt_add(struct tcf_chain *chain, struct net *net,
                return PTR_ERR(ops);
        if (!ops->tmplt_create || !ops->tmplt_destroy || !ops->tmplt_dump) {
                NL_SET_ERR_MSG(extack, "Chain templates are not supported with specified classifier");
+               module_put(ops->owner);
                return -EOPNOTSUPP;
        }
 
@@ -2900,11 +2918,11 @@ replay:
                }
 
                tc_chain_notify(chain, NULL, 0, NLM_F_CREATE | NLM_F_EXCL,
-                               RTM_NEWCHAIN, false);
+                               RTM_NEWCHAIN, false, extack);
                break;
        case RTM_DELCHAIN:
                tfilter_notify_chain(net, skb, block, q, parent, n,
-                                    chain, RTM_DELTFILTER);
+                                    chain, RTM_DELTFILTER, extack);
                /* Flush the chain first as the user requested chain removal. */
                tcf_chain_flush(chain, true);
                /* In case the chain was successfully deleted, put a reference
@@ -2914,7 +2932,7 @@ replay:
                break;
        case RTM_GETCHAIN:
                err = tc_chain_notify(chain, skb, n->nlmsg_seq,
-                                     n->nlmsg_flags, n->nlmsg_type, true);
+                                     n->nlmsg_flags, n->nlmsg_type, true, extack);
                if (err < 0)
                        NL_SET_ERR_MSG(extack, "Failed to send chain notify message");
                break;
@@ -3014,7 +3032,7 @@ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb)
                                         chain->index, net, skb, block,
                                         NETLINK_CB(cb->skb).portid,
                                         cb->nlh->nlmsg_seq, NLM_F_MULTI,
-                                        RTM_NEWCHAIN);
+                                        RTM_NEWCHAIN, NULL);
                if (err <= 0)
                        break;
                index++;