inet: read sk->sk_family once in inet_recv_error()
diff --git a/net/core/sock.c b/net/core/sock.c
index 666a17c..383e30f 100644
 #include <linux/interrupt.h>
 #include <linux/poll.h>
 #include <linux/tcp.h>
+#include <linux/udp.h>
 #include <linux/init.h>
 #include <linux/highmem.h>
 #include <linux/user_namespace.h>
@@ -600,7 +601,7 @@ struct dst_entry *__sk_dst_check(struct sock *sk, u32 cookie)
            INDIRECT_CALL_INET(dst->ops->check, ip6_dst_check, ipv4_dst_check,
                               dst, cookie) == NULL) {
                sk_tx_queue_clear(sk);
-               sk->sk_dst_pending_confirm = 0;
+               WRITE_ONCE(sk->sk_dst_pending_confirm, 0);
                RCU_INIT_POINTER(sk->sk_dst_cache, NULL);
                dst_release(dst);
                return NULL;
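
The WRITE_ONCE() pairs with lockless readers of sk->sk_dst_pending_confirm elsewhere (e.g. in the transmit path) and tells the compiler and KCSAN that the data race is intentional. The once-accessors boil down to a single volatile access; below is a minimal userspace sketch of the pattern, with a stub standing in for struct sock (names illustrative, not the kernel's):

#include <stdint.h>

/* Userspace stand-ins for the kernel's accessors in
 * include/asm-generic/rwonce.h: a single volatile access keeps the
 * compiler from tearing, fusing, or re-issuing the load/store.
 */
#define READ_ONCE(x)     (*(const volatile typeof(x) *)&(x))
#define WRITE_ONCE(x, v) (*(volatile typeof(x) *)&(x) = (v))

struct sock_stub {
        uint32_t sk_dst_pending_confirm;
};

static void clear_pending_confirm(struct sock_stub *sk)
{
        WRITE_ONCE(sk->sk_dst_pending_confirm, 0);      /* one store, no tearing */
}

static uint32_t pending_confirm(const struct sock_stub *sk)
{
        return READ_ONCE(sk->sk_dst_pending_confirm);   /* one load, no refetch */
}
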
@@ -765,7 +766,8 @@ bool sk_mc_loop(struct sock *sk)
                return false;
        if (!sk)
                return true;
-       switch (sk->sk_family) {
+       /* IPV6_ADDRFORM can change sk->sk_family under us. */
+       switch (READ_ONCE(sk->sk_family)) {
        case AF_INET:
                return inet_test_bit(MC_LOOP, sk);
 #if IS_ENABLED(CONFIG_IPV6)
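
READ_ONCE() matters here because setsockopt(IPV6_ADDRFORM) can rewrite sk->sk_family from AF_INET6 to AF_INET while sk_mc_loop() runs without the socket lock; an unannotated read may be refetched by the compiler, letting one call observe two different families. A sketch of the snapshot idiom under that assumption (stub types, not the kernel's):

#include <stdbool.h>
#include <sys/socket.h>         /* AF_INET, AF_INET6, sa_family_t */

#define READ_ONCE(x) (*(const volatile typeof(x) *)&(x))

struct sock_stub { sa_family_t sk_family; };

static bool mc_loop(const struct sock_stub *sk)
{
        /* One snapshot: every branch sees the same family even if a
         * concurrent IPV6_ADDRFORM flips the field mid-call.
         */
        switch (READ_ONCE(sk->sk_family)) {
        case AF_INET:
                return true;    /* kernel: per-socket IP multicast-loop bit */
        case AF_INET6:
                return true;    /* kernel: per-socket IPv6 mc_loop flag */
        }
        return true;
}
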
@@ -893,7 +895,7 @@ static int sock_timestamping_bind_phc(struct sock *sk, int phc_index)
        if (!match)
                return -EINVAL;
 
-       sk->sk_bind_phc = phc_index;
+       WRITE_ONCE(sk->sk_bind_phc, phc_index);
 
        return 0;
 }
@@ -936,7 +938,7 @@ int sock_set_timestamping(struct sock *sk, int optname,
                        return ret;
        }
 
-       sk->sk_tsflags = val;
+       WRITE_ONCE(sk->sk_tsflags, val);
        sock_valbool_flag(sk, SOCK_TSTAMP_NEW, optname == SO_TIMESTAMPING_NEW);
 
        if (val & SOF_TIMESTAMPING_RX_SOFTWARE)
@@ -1044,7 +1046,7 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
                mem_cgroup_uncharge_skmem(sk->sk_memcg, pages);
                return -ENOMEM;
        }
-       sk->sk_forward_alloc += pages << PAGE_SHIFT;
+       sk_forward_alloc_add(sk, pages << PAGE_SHIFT);
 
        WRITE_ONCE(sk->sk_reserved_mem,
                   sk->sk_reserved_mem + (pages << PAGE_SHIFT));
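
sk_forward_alloc_add() replaces the bare "+=" so the store side of sk->sk_forward_alloc is annotated; the read side shows up below in sk_get_meminfo() as sk_forward_alloc_get(). A sketch of what such helpers plausibly look like (the real ones live in include/net/sock.h, and the getter may additionally defer to a protocol hook, e.g. for MPTCP):

/* Writer: keep the read-modify-write under the owning lock, but
 * publish through WRITE_ONCE() so lockless readers never see a torn
 * value.
 */
static inline void sk_forward_alloc_add(struct sock *sk, int val)
{
        WRITE_ONCE(sk->sk_forward_alloc, sk->sk_forward_alloc + val);
}

/* Reader: a single annotated load for the lockless side. */
static inline int sk_forward_alloc_get(const struct sock *sk)
{
        return READ_ONCE(sk->sk_forward_alloc);
}
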
@@ -1717,9 +1719,16 @@ int sk_getsockopt(struct sock *sk, int level, int optname,
                break;
 
        case SO_TIMESTAMPING_OLD:
+       case SO_TIMESTAMPING_NEW:
                lv = sizeof(v.timestamping);
-               v.timestamping.flags = sk->sk_tsflags;
-               v.timestamping.bind_phc = sk->sk_bind_phc;
+               /* For the later-added case SO_TIMESTAMPING_NEW: Be strict about only
+                * returning the flags when they were set through the same option.
+                * Don't change the behaviour for the old case SO_TIMESTAMPING_OLD.
+                */
+               if (optname == SO_TIMESTAMPING_OLD || sock_flag(sk, SOCK_TSTAMP_NEW)) {
+                       v.timestamping.flags = READ_ONCE(sk->sk_tsflags);
+                       v.timestamping.bind_phc = READ_ONCE(sk->sk_bind_phc);
+               }
                break;
 
        case SO_RCVTIMEO_OLD:
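
The strictness rule is visible from userspace: flags configured through SO_TIMESTAMPING_OLD read back unchanged through _OLD, but come back zeroed through _NEW because SOCK_TSTAMP_NEW was never set. A hedged sketch (error handling elided; assumes the uapi spellings SO_TIMESTAMPING_OLD/_NEW are in scope, e.g. via <asm/socket.h>):

#include <sys/socket.h>
#include <asm/socket.h>         /* SO_TIMESTAMPING_OLD / SO_TIMESTAMPING_NEW */
#include <linux/net_tstamp.h>   /* struct so_timestamping, SOF_* flags */

int main(void)
{
        int fd = socket(AF_INET, SOCK_DGRAM, 0);
        int flags = SOF_TIMESTAMPING_RX_SOFTWARE;
        struct so_timestamping ts = { 0 };
        socklen_t len = sizeof(ts);

        setsockopt(fd, SOL_SOCKET, SO_TIMESTAMPING_OLD, &flags, sizeof(flags));

        getsockopt(fd, SOL_SOCKET, SO_TIMESTAMPING_OLD, &ts, &len);
        /* ts.flags == SOF_TIMESTAMPING_RX_SOFTWARE: same option name */

        len = sizeof(ts);
        getsockopt(fd, SOL_SOCKET, SO_TIMESTAMPING_NEW, &ts, &len);
        /* ts.flags == 0: the flags were set via _OLD, so _NEW reports none */

        return 0;
}
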
@@ -2746,9 +2755,9 @@ static long sock_wait_for_wmem(struct sock *sk, long timeo)
                prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
                if (refcount_read(&sk->sk_wmem_alloc) < READ_ONCE(sk->sk_sndbuf))
                        break;
-               if (sk->sk_shutdown & SEND_SHUTDOWN)
+               if (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN)
                        break;
-               if (sk->sk_err)
+               if (READ_ONCE(sk->sk_err))
                        break;
                timeo = schedule_timeout(timeo);
        }
@@ -2776,7 +2785,7 @@ struct sk_buff *sock_alloc_send_pskb(struct sock *sk, unsigned long header_len,
                        goto failure;
 
                err = -EPIPE;
-               if (sk->sk_shutdown & SEND_SHUTDOWN)
+               if (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN)
                        goto failure;
 
                if (sk_wmem_alloc_get(sk) < READ_ONCE(sk->sk_sndbuf))
@@ -2820,6 +2829,7 @@ int __sock_cmsg_send(struct sock *sk, struct cmsghdr *cmsg,
                sockc->mark = *(u32 *)CMSG_DATA(cmsg);
                break;
        case SO_TIMESTAMPING_OLD:
+       case SO_TIMESTAMPING_NEW:
                if (cmsg->cmsg_len != CMSG_LEN(sizeof(u32)))
                        return -EINVAL;
 
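
With the added case, per-packet timestamping flags can be attached under either option name; the payload must be exactly a 32-bit flag word, or the CMSG_LEN() check above fails with -EINVAL. A hedged userspace sketch (socket creation, the iovec, and the sendmsg() call elided):

#include <stdint.h>
#include <string.h>
#include <sys/socket.h>
#include <asm/socket.h>         /* SO_TIMESTAMPING_NEW */
#include <linux/net_tstamp.h>   /* SOF_* flags */

void attach_tsflags(struct msghdr *msg, char *control, size_t clen)
{
        uint32_t tsflags = SOF_TIMESTAMPING_TX_SOFTWARE |
                           SOF_TIMESTAMPING_OPT_ID;
        struct cmsghdr *cmsg;

        memset(control, 0, clen);
        msg->msg_control = control;
        msg->msg_controllen = clen;     /* at least CMSG_SPACE(sizeof(tsflags)) */

        cmsg = CMSG_FIRSTHDR(msg);
        cmsg->cmsg_level = SOL_SOCKET;
        cmsg->cmsg_type  = SO_TIMESTAMPING_NEW; /* _OLD only, before this hunk */
        cmsg->cmsg_len   = CMSG_LEN(sizeof(tsflags));
        memcpy(CMSG_DATA(cmsg), &tsflags, sizeof(tsflags));
}
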
@@ -3138,10 +3148,10 @@ int __sk_mem_schedule(struct sock *sk, int size, int kind)
 {
        int ret, amt = sk_mem_pages(size);
 
-       sk->sk_forward_alloc += amt << PAGE_SHIFT;
+       sk_forward_alloc_add(sk, amt << PAGE_SHIFT);
        ret = __sk_mem_raise_allocated(sk, size, amt, kind);
        if (!ret)
-               sk->sk_forward_alloc -= amt << PAGE_SHIFT;
+               sk_forward_alloc_add(sk, -(amt << PAGE_SHIFT));
        return ret;
 }
 EXPORT_SYMBOL(__sk_mem_schedule);
@@ -3173,7 +3183,7 @@ void __sk_mem_reduce_allocated(struct sock *sk, int amount)
 void __sk_mem_reclaim(struct sock *sk, int amount)
 {
        amount >>= PAGE_SHIFT;
-       sk->sk_forward_alloc -= amount << PAGE_SHIFT;
+       sk_forward_alloc_add(sk, -(amount << PAGE_SHIFT));
        __sk_mem_reduce_allocated(sk, amount);
 }
 EXPORT_SYMBOL(__sk_mem_reclaim);
@@ -3742,7 +3752,7 @@ void sk_get_meminfo(const struct sock *sk, u32 *mem)
        mem[SK_MEMINFO_RCVBUF] = READ_ONCE(sk->sk_rcvbuf);
        mem[SK_MEMINFO_WMEM_ALLOC] = sk_wmem_alloc_get(sk);
        mem[SK_MEMINFO_SNDBUF] = READ_ONCE(sk->sk_sndbuf);
-       mem[SK_MEMINFO_FWD_ALLOC] = sk->sk_forward_alloc;
+       mem[SK_MEMINFO_FWD_ALLOC] = sk_forward_alloc_get(sk);
        mem[SK_MEMINFO_WMEM_QUEUED] = READ_ONCE(sk->sk_wmem_queued);
        mem[SK_MEMINFO_OPTMEM] = atomic_read(&sk->sk_omem_alloc);
        mem[SK_MEMINFO_BACKLOG] = READ_ONCE(sk->sk_backlog.len);
@@ -4127,8 +4137,14 @@ bool sk_busy_loop_end(void *p, unsigned long start_time)
 {
        struct sock *sk = p;
 
-       return !skb_queue_empty_lockless(&sk->sk_receive_queue) ||
-              sk_busy_loop_timeout(sk, start_time);
+       if (!skb_queue_empty_lockless(&sk->sk_receive_queue))
+               return true;
+
+       if (sk_is_udp(sk) &&
+           !skb_queue_empty_lockless(&udp_sk(sk)->reader_queue))
+               return true;
+
+       return sk_busy_loop_timeout(sk, start_time);
 }
 EXPORT_SYMBOL(sk_busy_loop_end);
 #endif /* CONFIG_NET_RX_BUSY_POLL */
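
The extra queue test exists because UDP moves received skbs from sk->sk_receive_queue to a second queue, udp_sk(sk)->reader_queue, in __skb_recv_udp(); checking only the first queue lets busy polling spin to timeout even though data is already readable. sk_is_udp() is a plain type test; a plausible sketch (the real helper lives in the kernel headers):

/* Sketch: true for inet datagram sockets speaking UDP. */
static inline bool sk_is_udp(const struct sock *sk)
{
        return sk_is_inet(sk) &&
               sk->sk_type == SOCK_DGRAM &&
               sk->sk_protocol == IPPROTO_UDP;
}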