dccp/tcp: Update saddr under bhash's lock.
[platform/kernel/linux-rpi.git] / net / ipv4 / inet_hashtables.c
index d745f96..18ef370 100644 (file)
@@ -858,14 +858,34 @@ inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, in
        return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
 }
 
-int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct sock *sk)
+static void inet_update_saddr(struct sock *sk, void *saddr, int family)
+{
+       if (family == AF_INET) {
+               inet_sk(sk)->inet_saddr = *(__be32 *)saddr;
+               sk_rcv_saddr_set(sk, inet_sk(sk)->inet_saddr);
+       }
+#if IS_ENABLED(CONFIG_IPV6)
+       else {
+               sk->sk_v6_rcv_saddr = *(struct in6_addr *)saddr;
+       }
+#endif
+}
+
+int inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family)
 {
        struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk);
+       struct inet_bind_hashbucket *head, *head2;
        struct inet_bind2_bucket *tb2, *new_tb2;
        int l3mdev = inet_sk_bound_l3mdev(sk);
-       struct inet_bind_hashbucket *head2;
        int port = inet_sk(sk)->inet_num;
        struct net *net = sock_net(sk);
+       int bhash;
+
+       if (!inet_csk(sk)->icsk_bind2_hash) {
+               /* Not bind()ed before. */
+               inet_update_saddr(sk, saddr, family);
+               return 0;
+       }
 
        /* Allocate a bind2 bucket ahead of time to avoid permanently putting
         * the bhash2 table in an inconsistent state if a new tb2 bucket
@@ -875,14 +895,25 @@ int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct soc
        if (!new_tb2)
                return -ENOMEM;
 
+       bhash = inet_bhashfn(net, port, hinfo->bhash_size);
+       head = &hinfo->bhash[bhash];
        head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
 
-       spin_lock_bh(&prev_saddr->lock);
+       /* If we change saddr locklessly, another thread
+        * iterating over bhash might see corrupted address.
+        */
+       spin_lock_bh(&head->lock);
+
+       spin_lock(&head2->lock);
        __sk_del_bind2_node(sk);
        inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, inet_csk(sk)->icsk_bind2_hash);
-       spin_unlock_bh(&prev_saddr->lock);
+       spin_unlock(&head2->lock);
+
+       inet_update_saddr(sk, saddr, family);
 
-       spin_lock_bh(&head2->lock);
+       head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+
+       spin_lock(&head2->lock);
        tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
        if (!tb2) {
                tb2 = new_tb2;
@@ -890,7 +921,9 @@ int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct soc
        }
        sk_add_bind2_node(sk, &tb2->owners);
        inet_csk(sk)->icsk_bind2_hash = tb2;
-       spin_unlock_bh(&head2->lock);
+       spin_unlock(&head2->lock);
+
+       spin_unlock_bh(&head->lock);
 
        if (tb2 != new_tb2)
                kmem_cache_free(hinfo->bind2_bucket_cachep, new_tb2);