net/sched: sch_taprio: fix possible use-after-free
[platform/kernel/linux-rpi.git] / net / sched / sch_taprio.c
index b9fd18d..a76a2af 100644 (file)
@@ -66,6 +66,7 @@ struct taprio_sched {
        u32 flags;
        enum tk_offsets tk_offset;
        int clockid;
+       bool offloaded;
        atomic64_t picos_per_byte; /* Using picoseconds because for 10Gbps+
                                    * speeds it's sub-nanoseconds per byte
                                    */
@@ -95,18 +96,22 @@ static ktime_t sched_base_time(const struct sched_gate_list *sched)
        return ns_to_ktime(sched->base_time);
 }
 
-static ktime_t taprio_get_time(struct taprio_sched *q)
+static ktime_t taprio_mono_to_any(const struct taprio_sched *q, ktime_t mono)
 {
-       ktime_t mono = ktime_get();
+       /* This pairs with WRITE_ONCE() in taprio_parse_clockid() */
+       enum tk_offsets tk_offset = READ_ONCE(q->tk_offset);
 
-       switch (q->tk_offset) {
+       switch (tk_offset) {
        case TK_OFFS_MAX:
                return mono;
        default:
-               return ktime_mono_to_any(mono, q->tk_offset);
+               return ktime_mono_to_any(mono, tk_offset);
        }
+}
 
-       return KTIME_MAX;
+static ktime_t taprio_get_time(const struct taprio_sched *q)
+{
+       return taprio_mono_to_any(q, ktime_get());
 }
 
 static void taprio_free_sched_cb(struct rcu_head *head)
@@ -319,7 +324,7 @@ static ktime_t get_tcp_tstamp(struct taprio_sched *q, struct sk_buff *skb)
                return 0;
        }
 
-       return ktime_mono_to_any(skb->skb_mstamp_ns, q->tk_offset);
+       return taprio_mono_to_any(q, skb->skb_mstamp_ns);
 }
 
 /* There are a few scenarios where we will have to modify the txtime from
@@ -413,7 +418,8 @@ static int taprio_enqueue_one(struct sk_buff *skb, struct Qdisc *sch,
 {
        struct taprio_sched *q = qdisc_priv(sch);
 
-       if (skb->sk && sock_flag(skb->sk, SOCK_TXTIME)) {
+       /* sk_flags are only safe to use on full sockets. */
+       if (skb->sk && sk_fullsock(skb->sk) && sock_flag(skb->sk, SOCK_TXTIME)) {
                if (!is_valid_interval(skb, sch))
                        return qdisc_drop(skb, sch, to_free);
        } else if (TXTIME_ASSIST_IS_ENABLED(q->flags)) {
@@ -1273,6 +1279,8 @@ static int taprio_enable_offload(struct net_device *dev,
                goto done;
        }
 
+       q->offloaded = true;
+
 done:
        taprio_offload_free(offload);
 
@@ -1287,12 +1295,9 @@ static int taprio_disable_offload(struct net_device *dev,
        struct tc_taprio_qopt_offload *offload;
        int err;
 
-       if (!FULL_OFFLOAD_IS_ENABLED(q->flags))
+       if (!q->offloaded)
                return 0;
 
-       if (!ops->ndo_setup_tc)
-               return -EOPNOTSUPP;
-
        offload = taprio_offload_alloc(0);
        if (!offload) {
                NL_SET_ERR_MSG(extack,
@@ -1308,6 +1313,8 @@ static int taprio_disable_offload(struct net_device *dev,
                goto out;
        }
 
+       q->offloaded = false;
+
 out:
        taprio_offload_free(offload);
 
@@ -1352,6 +1359,7 @@ static int taprio_parse_clockid(struct Qdisc *sch, struct nlattr **tb,
                }
        } else if (tb[TCA_TAPRIO_ATTR_SCHED_CLOCKID]) {
                int clockid = nla_get_s32(tb[TCA_TAPRIO_ATTR_SCHED_CLOCKID]);
+               enum tk_offsets tk_offset;
 
                /* We only support static clockids and we don't allow
                 * for it to be modified after the first init.
@@ -1366,22 +1374,24 @@ static int taprio_parse_clockid(struct Qdisc *sch, struct nlattr **tb,
 
                switch (clockid) {
                case CLOCK_REALTIME:
-                       q->tk_offset = TK_OFFS_REAL;
+                       tk_offset = TK_OFFS_REAL;
                        break;
                case CLOCK_MONOTONIC:
-                       q->tk_offset = TK_OFFS_MAX;
+                       tk_offset = TK_OFFS_MAX;
                        break;
                case CLOCK_BOOTTIME:
-                       q->tk_offset = TK_OFFS_BOOT;
+                       tk_offset = TK_OFFS_BOOT;
                        break;
                case CLOCK_TAI:
-                       q->tk_offset = TK_OFFS_TAI;
+                       tk_offset = TK_OFFS_TAI;
                        break;
                default:
                        NL_SET_ERR_MSG(extack, "Invalid 'clockid'");
                        err = -EINVAL;
                        goto out;
                }
+               /* This pairs with READ_ONCE() in taprio_mono_to_any */
+               WRITE_ONCE(q->tk_offset, tk_offset);
 
                q->clockid = clockid;
        } else {
@@ -1622,13 +1632,13 @@ static void taprio_reset(struct Qdisc *sch)
        int i;
 
        hrtimer_cancel(&q->advance_timer);
+       qdisc_synchronize(sch);
+
        if (q->qdiscs) {
                for (i = 0; i < dev->num_tx_queues; i++)
                        if (q->qdiscs[i])
                                qdisc_reset(q->qdiscs[i]);
        }
-       sch->qstats.backlog = 0;
-       sch->q.qlen = 0;
 }
 
 static void taprio_destroy(struct Qdisc *sch)
@@ -1645,6 +1655,7 @@ static void taprio_destroy(struct Qdisc *sch)
         * happens in qdisc_create(), after taprio_init() has been called.
         */
        hrtimer_cancel(&q->advance_timer);
+       qdisc_synchronize(sch);
 
        taprio_disable_offload(dev, q, NULL);
 
@@ -1940,12 +1951,14 @@ start_error:
 
 static struct Qdisc *taprio_leaf(struct Qdisc *sch, unsigned long cl)
 {
-       struct netdev_queue *dev_queue = taprio_queue_get(sch, cl);
+       struct taprio_sched *q = qdisc_priv(sch);
+       struct net_device *dev = qdisc_dev(sch);
+       unsigned int ntx = cl - 1;
 
-       if (!dev_queue)
+       if (ntx >= dev->num_tx_queues)
                return NULL;
 
-       return dev_queue->qdisc_sleeping;
+       return q->qdiscs[ntx];
 }
 
 static unsigned long taprio_find(struct Qdisc *sch, u32 classid)