udplite: fix various data-races
[platform/kernel/linux-starfive.git] / net / ipv4 / udp.c
index 0794a2c..c3ff984 100644 (file)
@@ -714,7 +714,7 @@ int __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
                               iph->saddr, uh->source, skb->dev->ifindex,
                               inet_sdif(skb), udptable, NULL);
 
-       if (!sk || udp_sk(sk)->encap_type) {
+       if (!sk || READ_ONCE(udp_sk(sk)->encap_type)) {
                /* No socket for error: try tunnels before discarding */
                if (static_branch_unlikely(&udp_encap_needed_key)) {
                        sk = __udp4_lib_err_encap(net, iph, uh, udptable, sk, skb,
@@ -1051,7 +1051,7 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
        u8 tos, scope;
        __be16 dport;
        int err, is_udplite = IS_UDPLITE(sk);
-       int corkreq = READ_ONCE(up->corkflag) || msg->msg_flags&MSG_MORE;
+       int corkreq = udp_test_bit(CORK, sk) || msg->msg_flags & MSG_MORE;
        int (*getfrag)(void *, char *, int, int, int, struct sk_buff *);
        struct sk_buff *skb;
        struct ip_options_data opt_copy;
@@ -1315,11 +1315,11 @@ void udp_splice_eof(struct socket *sock)
        struct sock *sk = sock->sk;
        struct udp_sock *up = udp_sk(sk);
 
-       if (!up->pending || READ_ONCE(up->corkflag))
+       if (!up->pending || udp_test_bit(CORK, sk))
                return;
 
        lock_sock(sk);
-       if (up->pending && !READ_ONCE(up->corkflag))
+       if (up->pending && !udp_test_bit(CORK, sk))
                udp_push_pending_frames(sk);
        release_sock(sk);
 }
@@ -1414,9 +1414,9 @@ static void udp_rmem_release(struct sock *sk, int size, int partial,
                spin_lock(&sk_queue->lock);
 
 
-       sk->sk_forward_alloc += size;
+       sk_forward_alloc_add(sk, size);
        amt = (sk->sk_forward_alloc - partial) & ~(PAGE_SIZE - 1);
-       sk->sk_forward_alloc -= amt;
+       sk_forward_alloc_add(sk, -amt);
 
        if (amt)
                __sk_mem_reduce_allocated(sk, amt >> PAGE_SHIFT);
@@ -1527,7 +1527,7 @@ int __udp_enqueue_schedule_skb(struct sock *sk, struct sk_buff *skb)
                goto uncharge_drop;
        }
 
-       sk->sk_forward_alloc -= size;
+       sk_forward_alloc_add(sk, -size);
 
        /* no need to setup a destructor, we will explicitly release the
         * forward allocated memory on dequeue
@@ -1868,7 +1868,7 @@ try_again:
                                                      (struct sockaddr *)sin);
        }
 
-       if (udp_sk(sk)->gro_enabled)
+       if (udp_test_bit(GRO_ENABLED, sk))
                udp_cmsg_recv(msg, sk, skb);
 
        if (inet_cmsg_flags(inet))
@@ -2081,7 +2081,8 @@ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb)
        }
        nf_reset_ct(skb);
 
-       if (static_branch_unlikely(&udp_encap_needed_key) && up->encap_type) {
+       if (static_branch_unlikely(&udp_encap_needed_key) &&
+           READ_ONCE(up->encap_type)) {
                int (*encap_rcv)(struct sock *sk, struct sk_buff *skb);
 
                /*
@@ -2119,7 +2120,8 @@ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb)
        /*
         *      UDP-Lite specific tests, ignored on UDP sockets
         */
-       if ((up->pcflag & UDPLITE_RECV_CC)  &&  UDP_SKB_CB(skb)->partial_cov) {
+       if (udp_test_bit(UDPLITE_RECV_CC, sk) && UDP_SKB_CB(skb)->partial_cov) {
+               u16 pcrlen = READ_ONCE(up->pcrlen);
 
                /*
                 * MIB statistics other than incrementing the error count are
@@ -2132,7 +2134,7 @@ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb)
                 * delivery of packets with coverage values less than a value
                 * provided by the application."
                 */
-               if (up->pcrlen == 0) {          /* full coverage was set  */
+               if (pcrlen == 0) {          /* full coverage was set  */
                        net_dbg_ratelimited("UDPLite: partial coverage %d while full coverage %d requested\n",
                                            UDP_SKB_CB(skb)->cscov, skb->len);
                        goto drop;
@@ -2143,9 +2145,9 @@ static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb)
                 * that it wants x while sender emits packets of smaller size y.
                 * Therefore the above ...()->partial_cov statement is essential.
                 */
-               if (UDP_SKB_CB(skb)->cscov  <  up->pcrlen) {
+               if (UDP_SKB_CB(skb)->cscov pcrlen) {
                        net_dbg_ratelimited("UDPLite: coverage %d too small, need min %d\n",
-                                           UDP_SKB_CB(skb)->cscov, up->pcrlen);
+                                           UDP_SKB_CB(skb)->cscov, pcrlen);
                        goto drop;
                }
        }
@@ -2618,7 +2620,7 @@ void udp_destroy_sock(struct sock *sk)
                        if (encap_destroy)
                                encap_destroy(sk);
                }
-               if (up->encap_enabled)
+               if (udp_test_bit(ENCAP_ENABLED, sk))
                        static_branch_dec(&udp_encap_needed_key);
        }
 }
@@ -2658,9 +2660,9 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
        switch (optname) {
        case UDP_CORK:
                if (val != 0) {
-                       WRITE_ONCE(up->corkflag, 1);
+                       udp_set_bit(CORK, sk);
                } else {
-                       WRITE_ONCE(up->corkflag, 0);
+                       udp_clear_bit(CORK, sk);
                        lock_sock(sk);
                        push_pending_frames(sk);
                        release_sock(sk);
@@ -2675,17 +2677,17 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
                case UDP_ENCAP_ESPINUDP_NON_IKE:
 #if IS_ENABLED(CONFIG_IPV6)
                        if (sk->sk_family == AF_INET6)
-                               up->encap_rcv = ipv6_stub->xfrm6_udp_encap_rcv;
+                               WRITE_ONCE(up->encap_rcv,
+                                          ipv6_stub->xfrm6_udp_encap_rcv);
                        else
 #endif
-                               up->encap_rcv = xfrm4_udp_encap_rcv;
+                               WRITE_ONCE(up->encap_rcv,
+                                          xfrm4_udp_encap_rcv);
 #endif
                        fallthrough;
                case UDP_ENCAP_L2TPINUDP:
-                       up->encap_type = val;
-                       lock_sock(sk);
-                       udp_tunnel_encap_enable(sk->sk_socket);
-                       release_sock(sk);
+                       WRITE_ONCE(up->encap_type, val);
+                       udp_tunnel_encap_enable(sk);
                        break;
                default:
                        err = -ENOPROTOOPT;
@@ -2694,11 +2696,11 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
                break;
 
        case UDP_NO_CHECK6_TX:
-               up->no_check6_tx = valbool;
+               udp_set_no_check6_tx(sk, valbool);
                break;
 
        case UDP_NO_CHECK6_RX:
-               up->no_check6_rx = valbool;
+               udp_set_no_check6_rx(sk, valbool);
                break;
 
        case UDP_SEGMENT:
@@ -2708,14 +2710,12 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
                break;
 
        case UDP_GRO:
-               lock_sock(sk);
 
                /* when enabling GRO, accept the related GSO packet type */
                if (valbool)
-                       udp_tunnel_encap_enable(sk->sk_socket);
-               up->gro_enabled = valbool;
-               up->accept_udp_l4 = valbool;
-               release_sock(sk);
+                       udp_tunnel_encap_enable(sk);
+               udp_assign_bit(GRO_ENABLED, sk, valbool);
+               udp_assign_bit(ACCEPT_L4, sk, valbool);
                break;
 
        /*
@@ -2730,8 +2730,8 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
                        val = 8;
                else if (val > USHRT_MAX)
                        val = USHRT_MAX;
-               up->pcslen = val;
-               up->pcflag |= UDPLITE_SEND_CC;
+               WRITE_ONCE(up->pcslen, val);
+               udp_set_bit(UDPLITE_SEND_CC, sk);
                break;
 
        /* The receiver specifies a minimum checksum coverage value. To make
@@ -2744,8 +2744,8 @@ int udp_lib_setsockopt(struct sock *sk, int level, int optname,
                        val = 8;
                else if (val > USHRT_MAX)
                        val = USHRT_MAX;
-               up->pcrlen = val;
-               up->pcflag |= UDPLITE_RECV_CC;
+               WRITE_ONCE(up->pcrlen, val);
+               udp_set_bit(UDPLITE_RECV_CC, sk);
                break;
 
        default:
@@ -2783,19 +2783,19 @@ int udp_lib_getsockopt(struct sock *sk, int level, int optname,
 
        switch (optname) {
        case UDP_CORK:
-               val = READ_ONCE(up->corkflag);
+               val = udp_test_bit(CORK, sk);
                break;
 
        case UDP_ENCAP:
-               val = up->encap_type;
+               val = READ_ONCE(up->encap_type);
                break;
 
        case UDP_NO_CHECK6_TX:
-               val = up->no_check6_tx;
+               val = udp_get_no_check6_tx(sk);
                break;
 
        case UDP_NO_CHECK6_RX:
-               val = up->no_check6_rx;
+               val = udp_get_no_check6_rx(sk);
                break;
 
        case UDP_SEGMENT:
@@ -2803,17 +2803,17 @@ int udp_lib_getsockopt(struct sock *sk, int level, int optname,
                break;
 
        case UDP_GRO:
-               val = up->gro_enabled;
+               val = udp_test_bit(GRO_ENABLED, sk);
                break;
 
        /* The following two cannot be changed on UDP sockets, the return is
         * always 0 (which corresponds to the full checksum coverage of UDP). */
        case UDPLITE_SEND_CSCOV:
-               val = up->pcslen;
+               val = READ_ONCE(up->pcslen);
                break;
 
        case UDPLITE_RECV_CSCOV:
-               val = up->pcrlen;
+               val = READ_ONCE(up->pcrlen);
                break;
 
        default: