tcp: Clean up some functions.
[platform/kernel/linux-starfive.git] / net / ipv4 / inet_hashtables.c
index 545f91b..29dce78 100644 (file)
@@ -81,55 +81,59 @@ struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep,
        return tb;
 }
 
-struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep,
-                                                  struct net *net,
-                                                  struct inet_bind2_hashbucket *head,
-                                                  const unsigned short port,
-                                                  int l3mdev,
-                                                  const struct sock *sk)
+/*
+ * Unlink @tb from its hash chain and free it back to @cachep, but only when
+ * its owners list is empty; a bucket that still has owners is left untouched.
+ *
+ * Caller must hold hashbucket lock for this tb with local BH disabled
+ */
+void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket *tb)
 {
-       struct inet_bind2_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC);
-
-       if (tb) {
-               write_pnet(&tb->ib_net, net);
-               tb->l3mdev    = l3mdev;
-               tb->port      = port;
-#if IS_ENABLED(CONFIG_IPV6)
-               if (sk->sk_family == AF_INET6)
-                       tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr;
-               else
-#endif
-                       tb->rcv_saddr = sk->sk_rcv_saddr;
-               INIT_HLIST_HEAD(&tb->owners);
-               hlist_add_head(&tb->node, &head->chain);
+       if (hlist_empty(&tb->owners)) {
+               __hlist_del(&tb->node);
+               kmem_cache_free(cachep, tb);
        }
-       return tb;
 }
 
-static bool bind2_bucket_addr_match(struct inet_bind2_bucket *tb2, struct sock *sk)
+/*
+ * Return true when @tb belongs to network namespace @net and was created for
+ * the same @port and the same @l3mdev; all three conditions must hold.
+ */
+bool inet_bind_bucket_match(const struct inet_bind_bucket *tb, const struct net *net,
+                           unsigned short port, int l3mdev)
+{
+       return net_eq(ib_net(tb), net) && tb->port == port &&
+               tb->l3mdev == l3mdev;
+}
+
+/*
+ * Fill in an already-allocated bind2 bucket and link it into @head.
+ *
+ * Stores @net, @l3mdev and @port in @tb, copies the receive address of @sk
+ * (the IPv6 address when CONFIG_IPV6 is enabled and @sk is AF_INET6,
+ * otherwise the IPv4 address), starts @tb with an empty owners list and adds
+ * it at the head of @head->chain.
+ *
+ * NOTE(review): takes no lock itself; the callers visible in this change
+ * reach it with @head's lock held - confirm for any other caller.
+ */
+static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb,
+                                  struct net *net,
+                                  struct inet_bind_hashbucket *head,
+                                  unsigned short port, int l3mdev,
+                                  const struct sock *sk)
 {
+       write_pnet(&tb->ib_net, net);
+       tb->l3mdev    = l3mdev;
+       tb->port      = port;
 #if IS_ENABLED(CONFIG_IPV6)
        if (sk->sk_family == AF_INET6)
-               return ipv6_addr_equal(&tb2->v6_rcv_saddr,
-                                      &sk->sk_v6_rcv_saddr);
+               tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr;
+       else
 #endif
-       return tb2->rcv_saddr == sk->sk_rcv_saddr;
+               tb->rcv_saddr = sk->sk_rcv_saddr;
+       INIT_HLIST_HEAD(&tb->owners);
+       hlist_add_head(&tb->node, &head->chain);
 }
 
-/*
- * Caller must hold hashbucket lock for this tb with local BH disabled
- */
-void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket *tb)
+/*
+ * Allocate a bind2 bucket from @cachep with GFP_ATOMIC and, on success,
+ * initialise it and link it into @head via inet_bind2_bucket_init().
+ *
+ * Returns the new bucket, or NULL when the allocation fails (in which case
+ * nothing is linked into @head).
+ */
+struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep,
+                                                  struct net *net,
+                                                  struct inet_bind_hashbucket *head,
+                                                  unsigned short port,
+                                                  int l3mdev,
+                                                  const struct sock *sk)
 {
-       if (hlist_empty(&tb->owners)) {
-               __hlist_del(&tb->node);
-               kmem_cache_free(cachep, tb);
-       }
+       struct inet_bind2_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC);
+
+       if (tb)
+               inet_bind2_bucket_init(tb, net, head, port, l3mdev, sk);
+
+       return tb;
 }
 
-/* Caller must hold the lock for the corresponding hashbucket in the bhash table
- * with local BH disabled
- */
+/* Caller must hold hashbucket lock for this tb with local BH disabled */
 void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb)
 {
        if (hlist_empty(&tb->owners)) {
@@ -138,10 +142,21 @@ void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_buck
        }
 }
 
+/*
+ * Return true when the receive address stored in @tb2 equals the receive
+ * address of @sk.  With CONFIG_IPV6 enabled and an AF_INET6 socket the IPv6
+ * addresses are compared; in every other case the IPv4 addresses are.
+ */
+static bool inet_bind2_bucket_addr_match(const struct inet_bind2_bucket *tb2,
+                                        const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+       if (sk->sk_family == AF_INET6)
+               return ipv6_addr_equal(&tb2->v6_rcv_saddr,
+                                      &sk->sk_v6_rcv_saddr);
+#endif
+       return tb2->rcv_saddr == sk->sk_rcv_saddr;
+}
+
 void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
-                   struct inet_bind2_bucket *tb2, const unsigned short snum)
+                   struct inet_bind2_bucket *tb2, unsigned short port)
 {
-       inet_sk(sk)->inet_num = snum;
+       inet_sk(sk)->inet_num = port;
        sk_add_bind_node(sk, &tb->owners);
        inet_csk(sk)->icsk_bind_hash = tb;
        sk_add_bind2_node(sk, &tb2->owners);
@@ -154,11 +169,14 @@ void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
 static void __inet_put_port(struct sock *sk)
 {
        struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
-       const int bhash = inet_bhashfn(sock_net(sk), inet_sk(sk)->inet_num,
-                       hashinfo->bhash_size);
-       struct inet_bind_hashbucket *head = &hashinfo->bhash[bhash];
-       struct inet_bind2_bucket *tb2;
+       struct inet_bind_hashbucket *head, *head2;
+       struct net *net = sock_net(sk);
        struct inet_bind_bucket *tb;
+       int bhash;
+
+       bhash = inet_bhashfn(net, inet_sk(sk)->inet_num, hashinfo->bhash_size);
+       head = &hashinfo->bhash[bhash];
+       head2 = inet_bhashfn_portaddr(hashinfo, sk, net, inet_sk(sk)->inet_num);
 
        spin_lock(&head->lock);
        tb = inet_csk(sk)->icsk_bind_hash;
@@ -167,12 +185,16 @@ static void __inet_put_port(struct sock *sk)
        inet_sk(sk)->inet_num = 0;
        inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);
 
+       spin_lock(&head2->lock);
        if (inet_csk(sk)->icsk_bind2_hash) {
-               tb2 = inet_csk(sk)->icsk_bind2_hash;
+               struct inet_bind2_bucket *tb2 = inet_csk(sk)->icsk_bind2_hash;
+
                __sk_del_bind2_node(sk);
                inet_csk(sk)->icsk_bind2_hash = NULL;
                inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2);
        }
+       spin_unlock(&head2->lock);
+
        spin_unlock(&head->lock);
 }
 
@@ -188,20 +210,24 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
 {
        struct inet_hashinfo *table = sk->sk_prot->h.hashinfo;
        unsigned short port = inet_sk(child)->inet_num;
-       const int bhash = inet_bhashfn(sock_net(sk), port,
-                                      table->bhash_size);
-       struct inet_bind_hashbucket *head = &table->bhash[bhash];
-       struct inet_bind2_hashbucket *head_bhash2;
+       struct inet_bind_hashbucket *head, *head2;
        bool created_inet_bind_bucket = false;
        struct net *net = sock_net(sk);
+       bool update_fastreuse = false;
        struct inet_bind2_bucket *tb2;
        struct inet_bind_bucket *tb;
-       int l3mdev;
+       int bhash, l3mdev;
+
+       bhash = inet_bhashfn(net, port, table->bhash_size);
+       head = &table->bhash[bhash];
+       head2 = inet_bhashfn_portaddr(table, child, net, port);
 
        spin_lock(&head->lock);
+       spin_lock(&head2->lock);
        tb = inet_csk(sk)->icsk_bind_hash;
        tb2 = inet_csk(sk)->icsk_bind2_hash;
        if (unlikely(!tb || !tb2)) {
+               spin_unlock(&head2->lock);
                spin_unlock(&head->lock);
                return -ENOENT;
        }
@@ -214,36 +240,39 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
                 * as that of the child socket. We have to look up or
                 * create a new bind bucket for the child here. */
                inet_bind_bucket_for_each(tb, &head->chain) {
-                       if (check_bind_bucket_match(tb, net, port, l3mdev))
+                       if (inet_bind_bucket_match(tb, net, port, l3mdev))
                                break;
                }
                if (!tb) {
                        tb = inet_bind_bucket_create(table->bind_bucket_cachep,
                                                     net, head, port, l3mdev);
                        if (!tb) {
+                               spin_unlock(&head2->lock);
                                spin_unlock(&head->lock);
                                return -ENOMEM;
                        }
                        created_inet_bind_bucket = true;
                }
-               inet_csk_update_fastreuse(tb, child);
+               update_fastreuse = true;
 
                goto bhash2_find;
-       } else if (!bind2_bucket_addr_match(tb2, child)) {
+       } else if (!inet_bind2_bucket_addr_match(tb2, child)) {
                l3mdev = inet_sk_bound_l3mdev(sk);
 
 bhash2_find:
-               tb2 = inet_bind2_bucket_find(table, net, port, l3mdev, child,
-                                            &head_bhash2);
+               tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, child);
                if (!tb2) {
                        tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep,
-                                                      net, head_bhash2, port,
+                                                      net, head2, port,
                                                       l3mdev, child);
                        if (!tb2)
                                goto error;
                }
        }
+       if (update_fastreuse)
+               inet_csk_update_fastreuse(tb, child);
        inet_bind_hash(child, tb, tb2, port);
+       spin_unlock(&head2->lock);
        spin_unlock(&head->lock);
 
        return 0;
@@ -251,6 +280,7 @@ bhash2_find:
 error:
        if (created_inet_bind_bucket)
                inet_bind_bucket_destroy(table->bind_bucket_cachep, tb);
+       spin_unlock(&head2->lock);
        spin_unlock(&head->lock);
        return -ENOMEM;
 }
@@ -600,8 +630,8 @@ static bool inet_ehash_lookup_by_sk(struct sock *sk,
 bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk)
 {
        struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
-       struct hlist_nulls_head *list;
        struct inet_ehash_bucket *head;
+       struct hlist_nulls_head *list;
        spinlock_t *lock;
        bool ret = true;
 
@@ -756,9 +786,9 @@ void inet_unhash(struct sock *sk)
 }
 EXPORT_SYMBOL_GPL(inet_unhash);
 
-static bool check_bind2_bucket_match(struct inet_bind2_bucket *tb,
-                                    struct net *net, unsigned short port,
-                                    int l3mdev, struct sock *sk)
+static bool inet_bind2_bucket_match(const struct inet_bind2_bucket *tb,
+                                   const struct net *net, unsigned short port,
+                                   int l3mdev, const struct sock *sk)
 {
 #if IS_ENABLED(CONFIG_IPV6)
        if (sk->sk_family == AF_INET6)
@@ -771,60 +801,96 @@ static bool check_bind2_bucket_match(struct inet_bind2_bucket *tb,
                        tb->l3mdev == l3mdev && tb->rcv_saddr == sk->sk_rcv_saddr;
 }
 
-bool check_bind2_bucket_match_nulladdr(struct inet_bind2_bucket *tb,
-                                      struct net *net, const unsigned short port,
-                                      int l3mdev, const struct sock *sk)
+/*
+ * Return true when @tb is in @net, was created for @port and @l3mdev, and
+ * its stored receive address is all zeroes.  @sk is consulted only for its
+ * address family: AF_INET6 sockets (with CONFIG_IPV6 enabled) are checked
+ * against a zero-initialised in6_addr, everything else against rcv_saddr == 0.
+ */
+bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb, const struct net *net,
+                                     unsigned short port, int l3mdev, const struct sock *sk)
 {
 #if IS_ENABLED(CONFIG_IPV6)
-       struct in6_addr nulladdr = {};
+       struct in6_addr addr_any = {};
 
        if (sk->sk_family == AF_INET6)
                return net_eq(ib2_net(tb), net) && tb->port == port &&
                        tb->l3mdev == l3mdev &&
-                       ipv6_addr_equal(&tb->v6_rcv_saddr, &nulladdr);
+                       ipv6_addr_equal(&tb->v6_rcv_saddr, &addr_any);
        else
 #endif
                return net_eq(ib2_net(tb), net) && tb->port == port &&
                        tb->l3mdev == l3mdev && tb->rcv_saddr == 0;
 }
 
-static struct inet_bind2_hashbucket *
-inet_bhashfn_portaddr(struct inet_hashinfo *hinfo, const struct sock *sk,
-                     const struct net *net, unsigned short port)
+/* Walk @head->chain and return the first bind2 bucket that
+ * inet_bind2_bucket_match() accepts for (@net, @port, @l3mdev, @sk).
+ * The lookup itself takes no lock and does not modify the chain.
+ *
+ * NOTE(review): the "not found" result is expected to be NULL, which relies
+ * on inet_bind_bucket_for_each() leaving the cursor NULL after a full walk;
+ * the callers' !tb2 checks assume this - confirm against the macro.
+ *
+ * The socket's bhash2 hashbucket spinlock must be held when this is called
+ */
+struct inet_bind2_bucket *
+inet_bind2_bucket_find(const struct inet_bind_hashbucket *head, const struct net *net,
+                      unsigned short port, int l3mdev, const struct sock *sk)
 {
-       u32 hash;
+       struct inet_bind2_bucket *bhash2 = NULL;
 
+       inet_bind_bucket_for_each(bhash2, &head->chain)
+               if (inet_bind2_bucket_match(bhash2, net, port, l3mdev, sk))
+                       break;
+
+       return bhash2;
+}
+
+/*
+ * Return the bhash2 hash bucket that (@net, all-zero address, @port) hashes
+ * to, using the hashinfo reached through @sk->sk_prot.  @sk's own bound
+ * address is not used: AF_INET6 sockets (with CONFIG_IPV6 enabled) hash a
+ * zero-initialised in6_addr, all others hash the IPv4 address 0.
+ *
+ * NOTE(review): the index is taken as hash & (bhash_size - 1), which
+ * presumes bhash_size is a power of two - confirm where it is set up.
+ */
+struct inet_bind_hashbucket *
+inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port)
+{
+       struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
+       u32 hash;
 #if IS_ENABLED(CONFIG_IPV6)
+       struct in6_addr addr_any = {};
+
        if (sk->sk_family == AF_INET6)
-               hash = ipv6_portaddr_hash(net, &sk->sk_v6_rcv_saddr, port);
+               hash = ipv6_portaddr_hash(net, &addr_any, port);
        else
 #endif
-               hash = ipv4_portaddr_hash(net, sk->sk_rcv_saddr, port);
+               hash = ipv4_portaddr_hash(net, 0, port);
+
        return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
 }
 
-/* This should only be called when the spinlock for the socket's corresponding
- * bind_hashbucket is held
- */
-struct inet_bind2_bucket *
-inet_bind2_bucket_find(struct inet_hashinfo *hinfo, struct net *net,
-                      const unsigned short port, int l3mdev, struct sock *sk,
-                      struct inet_bind2_hashbucket **head)
+/*
+ * Re-home @sk in the bhash2 table so that it sits in the bucket matching its
+ * current receive address and port.
+ *
+ * @prev_saddr: hash bucket @sk was linked under before its address changed,
+ *              or NULL to skip the removal step.
+ * @sk:         socket whose bhash2 membership is refreshed.
+ *
+ * Returns 0 on success, or -ENOMEM when the spare bucket cannot be allocated
+ * (nothing has been modified in that case).
+ *
+ * NOTE(review): the old and new buckets are updated under two separate lock
+ * sections, so between them @sk is linked in no bhash2 bucket - confirm the
+ * callers tolerate that window.
+ * NOTE(review): with @prev_saddr == NULL the previous owners link is not
+ * removed here; presumably @sk is not in bhash2 yet - verify against caller.
+ */
+int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct sock *sk)
 {
-       struct inet_bind2_bucket *bhash2 = NULL;
-       struct inet_bind2_hashbucket *h;
+       struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
+       struct inet_bind2_bucket *tb2, *new_tb2;
+       int l3mdev = inet_sk_bound_l3mdev(sk);
+       struct inet_bind_hashbucket *head2;
+       int port = inet_sk(sk)->inet_num;
+       struct net *net = sock_net(sk);
 
-       h = inet_bhashfn_portaddr(hinfo, sk, net, port);
-       inet_bind_bucket_for_each(bhash2, &h->chain) {
-               if (check_bind2_bucket_match(bhash2, net, port, l3mdev, sk))
-                       break;
+       /* Allocate a bind2 bucket ahead of time to avoid permanently putting
+        * the bhash2 table in an inconsistent state if a new tb2 bucket
+        * allocation fails.
+        */
+       new_tb2 = kmem_cache_alloc(hinfo->bind2_bucket_cachep, GFP_ATOMIC);
+       if (!new_tb2)
+               return -ENOMEM;
+
+       /* Destination bucket, derived from the address @sk holds now. */
+       head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+
+       if (prev_saddr) {
+               /* Drop @sk from its old bucket; the bucket itself is freed
+                * only if that leaves it without owners.
+                */
+               spin_lock_bh(&prev_saddr->lock);
+               __sk_del_bind2_node(sk);
+               inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep,
+                                         inet_csk(sk)->icsk_bind2_hash);
+               spin_unlock_bh(&prev_saddr->lock);
+       }
+
+       /* Reuse a matching bucket if one exists, otherwise consume the
+        * preallocated one.
+        */
+       spin_lock_bh(&head2->lock);
+       tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
+       if (!tb2) {
+               tb2 = new_tb2;
+               inet_bind2_bucket_init(tb2, net, head2, port, l3mdev, sk);
        }
+       sk_add_bind2_node(sk, &tb2->owners);
+       inet_csk(sk)->icsk_bind2_hash = tb2;
+       spin_unlock_bh(&head2->lock);
 
-       if (head)
-               *head = h;
+       /* An existing bucket was reused: give the spare back to the cache. */
+       if (tb2 != new_tb2)
+               kmem_cache_free(hinfo->bind2_bucket_cachep, new_tb2);
 
-       return bhash2;
+       return 0;
 }
+EXPORT_SYMBOL_GPL(inet_bhash2_update_saddr);
 
 /* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
  * Note that we use 32bit integers (vs RFC 'short integers')
@@ -845,9 +911,8 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
                        struct sock *, __u16, struct inet_timewait_sock **))
 {
        struct inet_hashinfo *hinfo = death_row->hashinfo;
+       struct inet_bind_hashbucket *head, *head2;
        struct inet_timewait_sock *tw = NULL;
-       struct inet_bind2_hashbucket *head2;
-       struct inet_bind_hashbucket *head;
        int port = inet_sk(sk)->inet_num;
        struct net *net = sock_net(sk);
        struct inet_bind2_bucket *tb2;
@@ -909,7 +974,7 @@ other_parity_scan:
                 * the established check is already unique enough.
                 */
                inet_bind_bucket_for_each(tb, &head->chain) {
-                       if (check_bind_bucket_match(tb, net, port, l3mdev)) {
+                       if (inet_bind_bucket_match(tb, net, port, l3mdev)) {
                                if (tb->fastreuse >= 0 ||
                                    tb->fastreuseport >= 0)
                                        goto next_port;
@@ -946,7 +1011,10 @@ ok:
        /* Find the corresponding tb2 bucket since we need to
         * add the socket to the bhash2 table as well
         */
-       tb2 = inet_bind2_bucket_find(hinfo, net, port, l3mdev, sk, &head2);
+       head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+       spin_lock(&head2->lock);
+
+       tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
        if (!tb2) {
                tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net,
                                               head2, port, l3mdev, sk);
@@ -964,6 +1032,9 @@ ok:
 
        /* Head lock still held and bh's disabled */
        inet_bind_hash(sk, tb, tb2, port);
+
+       spin_unlock(&head2->lock);
+
        if (sk_unhashed(sk)) {
                inet_sk(sk)->inet_sport = htons(port);
                inet_ehash_nolisten(sk, (struct sock *)tw, NULL);
@@ -977,6 +1048,7 @@ ok:
        return 0;
 
 error:
+       spin_unlock(&head2->lock);
        if (tb_created)
                inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb);
        spin_unlock_bh(&head->lock);