inet: read sk->sk_family once in inet_recv_error()
[platform/kernel/linux-starfive.git] / net / ipv4 / af_inet.c
index 3d2e30e..e59962f 100644 (file)
@@ -330,6 +330,9 @@ lookup_protocol:
        if (INET_PROTOSW_REUSE & answer_flags)
                sk->sk_reuse = SK_CAN_REUSE;
 
+       if (INET_PROTOSW_ICSK & answer_flags)
+               inet_init_csk_locks(sk);
+
        inet = inet_sk(sk);
        inet_assign_bit(IS_ICSK, sk, INET_PROTOSW_ICSK & answer_flags);
 
@@ -452,7 +455,7 @@ int inet_bind_sk(struct sock *sk, struct sockaddr *uaddr, int addr_len)
        /* BPF prog is run before any checks are done so that if the prog
         * changes context in a wrong way it will be caught.
         */
-       err = BPF_CGROUP_RUN_PROG_INET_BIND_LOCK(sk, uaddr,
+       err = BPF_CGROUP_RUN_PROG_INET_BIND_LOCK(sk, uaddr, &addr_len,
                                                 CGROUP_INET4_BIND, &flags);
        if (err)
                return err;
@@ -597,7 +600,6 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias)
 
        add_wait_queue(sk_sleep(sk), &wait);
        sk->sk_write_pending += writebias;
-       sk->sk_wait_pending++;
 
        /* Basic assumption: if someone sets sk->sk_err, he _must_
         * change state of the socket from TCP_SYN_*.
@@ -613,7 +615,6 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias)
        }
        remove_wait_queue(sk_sleep(sk), &wait);
        sk->sk_write_pending -= writebias;
-       sk->sk_wait_pending--;
        return timeo;
 }
 
@@ -642,6 +643,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
                        return -EINVAL;
 
                if (uaddr->sa_family == AF_UNSPEC) {
+                       sk->sk_disconnects++;
                        err = sk->sk_prot->disconnect(sk, flags);
                        sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED;
                        goto out;
@@ -696,6 +698,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
                int writebias = (sk->sk_protocol == IPPROTO_TCP) &&
                                tcp_sk(sk)->fastopen_req &&
                                tcp_sk(sk)->fastopen_req->data ? 1 : 0;
+               int dis = sk->sk_disconnects;
 
                /* Error code is set above */
                if (!timeo || !inet_wait_for_connect(sk, timeo, writebias))
@@ -704,6 +707,11 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
                err = sock_intr_errno(timeo);
                if (signal_pending(current))
                        goto out;
+
+               if (dis != sk->sk_disconnects) {
+                       err = -EPIPE;
+                       goto out;
+               }
        }
 
        /* Connection was closed by RST, timeout, ICMP error
@@ -725,6 +733,7 @@ out:
 sock_error:
        err = sock_error(sk) ? : -ECONNABORTED;
        sock->state = SS_UNCONNECTED;
+       sk->sk_disconnects++;
        if (sk->sk_prot->disconnect(sk, flags))
                sock->state = SS_DISCONNECTING;
        goto out;
@@ -788,6 +797,7 @@ int inet_getname(struct socket *sock, struct sockaddr *uaddr,
        struct sock *sk         = sock->sk;
        struct inet_sock *inet  = inet_sk(sk);
        DECLARE_SOCKADDR(struct sockaddr_in *, sin, uaddr);
+       int sin_addr_len = sizeof(*sin);
 
        sin->sin_family = AF_INET;
        lock_sock(sk);
@@ -800,7 +810,7 @@ int inet_getname(struct socket *sock, struct sockaddr *uaddr,
                }
                sin->sin_port = inet->inet_dport;
                sin->sin_addr.s_addr = inet->inet_daddr;
-               BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin,
+               BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin, &sin_addr_len,
                                       CGROUP_INET4_GETPEERNAME);
        } else {
                __be32 addr = inet->inet_rcv_saddr;
@@ -808,12 +818,12 @@ int inet_getname(struct socket *sock, struct sockaddr *uaddr,
                        addr = inet->inet_saddr;
                sin->sin_port = inet->inet_sport;
                sin->sin_addr.s_addr = addr;
-               BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin,
+               BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin, &sin_addr_len,
                                       CGROUP_INET4_GETSOCKNAME);
        }
        release_sock(sk);
        memset(sin->sin_zero, 0, sizeof(sin->sin_zero));
-       return sizeof(*sin);
+       return sin_addr_len;
 }
 EXPORT_SYMBOL(inet_getname);
 
@@ -1618,14 +1628,17 @@ EXPORT_SYMBOL(inet_current_timestamp);
 
 int inet_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
 {
-       if (sk->sk_family == AF_INET)
+       unsigned int family = READ_ONCE(sk->sk_family);
+
+       if (family == AF_INET)
                return ip_recv_error(sk, msg, len, addr_len);
 #if IS_ENABLED(CONFIG_IPV6)
-       if (sk->sk_family == AF_INET6)
+       if (family == AF_INET6)
                return pingv6_ops.ipv6_recv_error(sk, msg, len, addr_len);
 #endif
        return -EINVAL;
 }
+EXPORT_SYMBOL(inet_recv_error);
 
 int inet_gro_complete(struct sk_buff *skb, int nhoff)
 {