bpf: Avoid iter->offset making backward progress in bpf_iter_udp
[platform/kernel/linux-starfive.git] / net / ipv4 / udp.c
index f39b9c8..9cb22a6 100644 (file)
@@ -714,7 +714,7 @@ int __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
                               iph->saddr, uh->source, skb->dev->ifindex,
                               inet_sdif(skb), udptable, NULL);
 
-       if (!sk || udp_sk(sk)->encap_type) {
+       if (!sk || READ_ONCE(udp_sk(sk)->encap_type)) {
                /* No socket for error: try tunnels before discarding */
                if (static_branch_unlikely(&udp_encap_needed_key)) {
                        sk = __udp4_lib_err_encap(net, iph, uh, udptable, sk, skb,
@@ -805,7 +805,7 @@ void udp_flush_pending_frames(struct sock *sk)
 
        if (up->pending) {
                up->len = 0;
-               up->pending = 0;
+               WRITE_ONCE(up->pending, 0);
                ip_flush_pending_frames(sk);
        }
 }
@@ -993,7 +993,7 @@ int udp_push_pending_frames(struct sock *sk)
 
 out:
        up->len = 0;
-       up->pending = 0;
+       WRITE_ONCE(up->pending, 0);
        return err;
 }
 EXPORT_SYMBOL(udp_push_pending_frames);
@@ -1051,7 +1051,7 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
        u8 tos, scope;
        __be16 dport;
        int err, is_udplite = IS_UDPLITE(sk);
-       int corkreq = READ_ONCE(up->corkflag) || msg->msg_flags&MSG_MORE;
+       int corkreq = udp_test_bit(CORK, sk) || msg->msg_flags & MSG_MORE;
        int (*getfrag)(void *, char *, int, int, int, struct sk_buff *);
        struct sk_buff *skb;
        struct ip_options_data opt_copy;
@@ -1069,7 +1069,7 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
        getfrag = is_udplite ? udplite_getfrag : ip_generic_getfrag;
 
        fl4 = &inet->cork.fl.u.ip4;
-       if (up->pending) {
+       if (READ_ONCE(up->pending)) {
                /*
                 * There are pending frames.
                 * The socket lock must be held while it's corked.
@@ -1265,7 +1265,7 @@ back_from_confirm:
        fl4->saddr = saddr;
        fl4->fl4_dport = dport;
        fl4->fl4_sport = inet->inet_sport;
-       up->pending = AF_INET;
+       WRITE_ONCE(up->pending, AF_INET);
 
 do_append_data:
        up->len += ulen;
@@ -1277,7 +1277,7 @@ do_append_data:
        else if (!corkreq)
                err = udp_push_pending_frames(sk);
        else if (unlikely(skb_queue_empty(&sk->sk_write_queue)))
-               up->pending = 0;
+               WRITE_ONCE(up->pending, 0);
        release_sock(sk);
 
 out:
@@ -1315,11 +1315,11 @@ void udp_splice_eof(struct socket *sock)
        struct sock *sk = sock->sk;
        struct udp_sock *up = udp_sk(sk);
 
-       if (!up->pending || READ_ONCE(up->corkflag))
+       if (!READ_ONCE(up->pending) || udp_test_bit(CORK, sk))
                return;
 
        lock_sock(sk);
-       if (up->pending && !READ_ONCE(up->corkflag))
+       if (up->pending && !udp_test_bit(CORK, sk))
                udp_push_pending_frames(sk);
        release_sock(sk);
 }
@@ -1868,7 +1868,7 @@ try_again:
                                                      (struct sockaddr *)sin);
        }
 
-       if (udp_sk(sk)->gro_enabled)
+       if (udp_test_bit(GRO_ENABLED, sk))
                udp_cmsg_recv(msg, sk, skb);
 
        if (inet_cmsg_flags(inet))
@@ -2081,7 +2081,8 @@ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb)
        }
        nf_reset_ct(skb);
 
-       if (static_branch_unlikely(&udp_encap_needed_key) && up->encap_type) {
+       if (static_branch_unlikely(&udp_encap_needed_key) &&
+           READ_ONCE(up->encap_type)) {
                int (*encap_rcv)(struct sock *sk, struct sk_buff *skb);
 
                /*
@@ -2119,7 +2120,8 @@ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb)
        /*
         *      UDP-Lite specific tests, ignored on UDP sockets
         */
-       if ((up->pcflag & UDPLITE_RECV_CC)  &&  UDP_SKB_CB(skb)->partial_cov) {
+       if (udp_test_bit(UDPLITE_RECV_CC, sk) && UDP_SKB_CB(skb)->partial_cov) {
+               u16 pcrlen = READ_ONCE(up->pcrlen);
 
                /*
                 * MIB statistics other than incrementing the error count are
@@ -2132,7 +2134,7 @@ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb)
                 * delivery of packets with coverage values less than a value
                 * provided by the application."
                 */
-               if (up->pcrlen == 0) {          /* full coverage was set  */
+               if (pcrlen == 0) {          /* full coverage was set  */
                        net_dbg_ratelimited("UDPLite: partial coverage %d while full coverage %d requested\n",
                                            UDP_SKB_CB(skb)->cscov, skb->len);
                        goto drop;
@@ -2143,9 +2145,9 @@ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb)
                 * that it wants x while sender emits packets of smaller size y.
                 * Therefore the above ...()->partial_cov statement is essential.
                 */
-               if (UDP_SKB_CB(skb)->cscov  <  up->pcrlen) {
+               if (UDP_SKB_CB(skb)->cscov < pcrlen) {
                        net_dbg_ratelimited("UDPLite: coverage %d too small, need min %d\n",
-                                           UDP_SKB_CB(skb)->cscov, up->pcrlen);
+                                           UDP_SKB_CB(skb)->cscov, pcrlen);
                        goto drop;
                }
        }
@@ -2618,7 +2620,7 @@ void udp_destroy_sock(struct sock *sk)
                        if (encap_destroy)
                                encap_destroy(sk);
                }
-               if (up->encap_enabled)
+               if (udp_test_bit(ENCAP_ENABLED, sk))
                        static_branch_dec(&udp_encap_needed_key);
        }
 }
@@ -2658,9 +2660,9 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
        switch (optname) {
        case UDP_CORK:
                if (val != 0) {
-                       WRITE_ONCE(up->corkflag, 1);
+                       udp_set_bit(CORK, sk);
                } else {
-                       WRITE_ONCE(up->corkflag, 0);
+                       udp_clear_bit(CORK, sk);
                        lock_sock(sk);
                        push_pending_frames(sk);
                        release_sock(sk);
@@ -2675,17 +2677,17 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
                case UDP_ENCAP_ESPINUDP_NON_IKE:
 #if IS_ENABLED(CONFIG_IPV6)
                        if (sk->sk_family == AF_INET6)
-                               up->encap_rcv = ipv6_stub->xfrm6_udp_encap_rcv;
+                               WRITE_ONCE(up->encap_rcv,
+                                          ipv6_stub->xfrm6_udp_encap_rcv);
                        else
 #endif
-                               up->encap_rcv = xfrm4_udp_encap_rcv;
+                               WRITE_ONCE(up->encap_rcv,
+                                          xfrm4_udp_encap_rcv);
 #endif
                        fallthrough;
                case UDP_ENCAP_L2TPINUDP:
-                       up->encap_type = val;
-                       lock_sock(sk);
-                       udp_tunnel_encap_enable(sk->sk_socket);
-                       release_sock(sk);
+                       WRITE_ONCE(up->encap_type, val);
+                       udp_tunnel_encap_enable(sk);
                        break;
                default:
                        err = -ENOPROTOOPT;
@@ -2694,11 +2696,11 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
                break;
 
        case UDP_NO_CHECK6_TX:
-               up->no_check6_tx = valbool;
+               udp_set_no_check6_tx(sk, valbool);
                break;
 
        case UDP_NO_CHECK6_RX:
-               up->no_check6_rx = valbool;
+               udp_set_no_check6_rx(sk, valbool);
                break;
 
        case UDP_SEGMENT:
@@ -2708,14 +2710,12 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
                break;
 
        case UDP_GRO:
-               lock_sock(sk);
 
                /* when enabling GRO, accept the related GSO packet type */
                if (valbool)
-                       udp_tunnel_encap_enable(sk->sk_socket);
-               up->gro_enabled = valbool;
-               up->accept_udp_l4 = valbool;
-               release_sock(sk);
+                       udp_tunnel_encap_enable(sk);
+               udp_assign_bit(GRO_ENABLED, sk, valbool);
+               udp_assign_bit(ACCEPT_L4, sk, valbool);
                break;
 
        /*
@@ -2730,8 +2730,8 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
                        val = 8;
                else if (val > USHRT_MAX)
                        val = USHRT_MAX;
-               up->pcslen = val;
-               up->pcflag |= UDPLITE_SEND_CC;
+               WRITE_ONCE(up->pcslen, val);
+               udp_set_bit(UDPLITE_SEND_CC, sk);
                break;
 
        /* The receiver specifies a minimum checksum coverage value. To make
@@ -2744,8 +2744,8 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
                        val = 8;
                else if (val > USHRT_MAX)
                        val = USHRT_MAX;
-               up->pcrlen = val;
-               up->pcflag |= UDPLITE_RECV_CC;
+               WRITE_ONCE(up->pcrlen, val);
+               udp_set_bit(UDPLITE_RECV_CC, sk);
                break;
 
        default:
@@ -2783,19 +2783,19 @@ int udp_lib_getsockopt(struct sock *sk, int level, int optname,
 
        switch (optname) {
        case UDP_CORK:
-               val = READ_ONCE(up->corkflag);
+               val = udp_test_bit(CORK, sk);
                break;
 
        case UDP_ENCAP:
-               val = up->encap_type;
+               val = READ_ONCE(up->encap_type);
                break;
 
        case UDP_NO_CHECK6_TX:
-               val = up->no_check6_tx;
+               val = udp_get_no_check6_tx(sk);
                break;
 
        case UDP_NO_CHECK6_RX:
-               val = up->no_check6_rx;
+               val = udp_get_no_check6_rx(sk);
                break;
 
        case UDP_SEGMENT:
@@ -2803,17 +2803,17 @@ int udp_lib_getsockopt(struct sock *sk, int level, int optname,
                break;
 
        case UDP_GRO:
-               val = up->gro_enabled;
+               val = udp_test_bit(GRO_ENABLED, sk);
                break;
 
        /* The following two cannot be changed on UDP sockets, the return is
         * always 0 (which corresponds to the full checksum coverage of UDP). */
        case UDPLITE_SEND_CSCOV:
-               val = up->pcslen;
+               val = READ_ONCE(up->pcslen);
                break;
 
        case UDPLITE_RECV_CSCOV:
-               val = up->pcrlen;
+               val = READ_ONCE(up->pcrlen);
                break;
 
        default:
@@ -3116,16 +3116,18 @@ static struct sock *bpf_iter_udp_batch(struct seq_file *seq)
        struct bpf_udp_iter_state *iter = seq->private;
        struct udp_iter_state *state = &iter->state;
        struct net *net = seq_file_net(seq);
+       int resume_bucket, resume_offset;
        struct udp_table *udptable;
        unsigned int batch_sks = 0;
        bool resized = false;
        struct sock *sk;
 
+       resume_bucket = state->bucket;
+       resume_offset = iter->offset;
+
        /* The current batch is done, so advance the bucket. */
-       if (iter->st_bucket_done) {
+       if (iter->st_bucket_done)
                state->bucket++;
-               iter->offset = 0;
-       }
 
        udptable = udp_get_table_seq(seq, net);
 
@@ -3145,19 +3147,19 @@ again:
        for (; state->bucket <= udptable->mask; state->bucket++) {
                struct udp_hslot *hslot2 = &udptable->hash2[state->bucket];
 
-               if (hlist_empty(&hslot2->head)) {
-                       iter->offset = 0;
+               if (hlist_empty(&hslot2->head))
                        continue;
-               }
 
+               iter->offset = 0;
                spin_lock_bh(&hslot2->lock);
                udp_portaddr_for_each_entry(sk, &hslot2->head) {
                        if (seq_sk_match(seq, sk)) {
                                /* Resume from the last iterated socket at the
                                 * offset in the bucket before iterator was stopped.
                                 */
-                               if (iter->offset) {
-                                       --iter->offset;
+                               if (state->bucket == resume_bucket &&
+                                   iter->offset < resume_offset) {
+                                       ++iter->offset;
                                        continue;
                                }
                                if (iter->end_sk < iter->max_sk) {
@@ -3171,9 +3173,6 @@ again:
 
                if (iter->end_sk)
                        break;
-
-               /* Reset the current bucket's offset before moving to the next bucket. */
-               iter->offset = 0;
        }
 
        /* All done: no batch made. */
@@ -3192,7 +3191,6 @@ again:
                /* After allocating a larger batch, retry one more time to grab
                 * the whole bucket.
                 */
-               state->bucket--;
                goto again;
        }
 done: