inet: read sk->sk_family once in inet_recv_error()
[platform/kernel/linux-starfive.git] / net / ipv4 / af_inet.c
index 2713c9b..e59962f 100644 (file)
@@ -330,6 +330,9 @@ lookup_protocol:
        if (INET_PROTOSW_REUSE & answer_flags)
                sk->sk_reuse = SK_CAN_REUSE;
 
+       if (INET_PROTOSW_ICSK & answer_flags)
+               inet_init_csk_locks(sk);
+
        inet = inet_sk(sk);
        inet_assign_bit(IS_ICSK, sk, INET_PROTOSW_ICSK & answer_flags);
 
@@ -452,7 +455,7 @@ int inet_bind_sk(struct sock *sk, struct sockaddr *uaddr, int addr_len)
        /* BPF prog is run before any checks are done so that if the prog
         * changes context in a wrong way it will be caught.
         */
-       err = BPF_CGROUP_RUN_PROG_INET_BIND_LOCK(sk, uaddr,
+       err = BPF_CGROUP_RUN_PROG_INET_BIND_LOCK(sk, uaddr, &addr_len,
                                                 CGROUP_INET4_BIND, &flags);
        if (err)
                return err;
@@ -794,6 +797,7 @@ int inet_getname(struct socket *sock, struct sockaddr *uaddr,
        struct sock *sk         = sock->sk;
        struct inet_sock *inet  = inet_sk(sk);
        DECLARE_SOCKADDR(struct sockaddr_in *, sin, uaddr);
+       int sin_addr_len = sizeof(*sin);
 
        sin->sin_family = AF_INET;
        lock_sock(sk);
@@ -806,7 +810,7 @@ int inet_getname(struct socket *sock, struct sockaddr *uaddr,
                }
                sin->sin_port = inet->inet_dport;
                sin->sin_addr.s_addr = inet->inet_daddr;
-               BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin,
+               BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin, &sin_addr_len,
                                       CGROUP_INET4_GETPEERNAME);
        } else {
                __be32 addr = inet->inet_rcv_saddr;
@@ -814,12 +818,12 @@ int inet_getname(struct socket *sock, struct sockaddr *uaddr,
                        addr = inet->inet_saddr;
                sin->sin_port = inet->inet_sport;
                sin->sin_addr.s_addr = addr;
-               BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin,
+               BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin, &sin_addr_len,
                                       CGROUP_INET4_GETSOCKNAME);
        }
        release_sock(sk);
        memset(sin->sin_zero, 0, sizeof(sin->sin_zero));
-       return sizeof(*sin);
+       return sin_addr_len;
 }
 EXPORT_SYMBOL(inet_getname);
 
@@ -1624,14 +1628,17 @@ EXPORT_SYMBOL(inet_current_timestamp);
 
 int inet_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len)
 {
-       if (sk->sk_family == AF_INET)
+       unsigned int family = READ_ONCE(sk->sk_family);
+
+       if (family == AF_INET)
                return ip_recv_error(sk, msg, len, addr_len);
 #if IS_ENABLED(CONFIG_IPV6)
-       if (sk->sk_family == AF_INET6)
+       if (family == AF_INET6)
                return pingv6_ops.ipv6_recv_error(sk, msg, len, addr_len);
 #endif
        return -EINVAL;
 }
+EXPORT_SYMBOL(inet_recv_error);
 
 int inet_gro_complete(struct sk_buff *skb, int nhoff)
 {