tcp: Clean up some functions.
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index b9d995b..29dce78 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -92,12 +92,75 @@ void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket
        }
 }
 
+bool inet_bind_bucket_match(const struct inet_bind_bucket *tb, const struct net *net,
+                           unsigned short port, int l3mdev)
+{
+       return net_eq(ib_net(tb), net) && tb->port == port &&
+               tb->l3mdev == l3mdev;
+}
+
+static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb,
+                                  struct net *net,
+                                  struct inet_bind_hashbucket *head,
+                                  unsigned short port, int l3mdev,
+                                  const struct sock *sk)
+{
+       write_pnet(&tb->ib_net, net);
+       tb->l3mdev    = l3mdev;
+       tb->port      = port;
+#if IS_ENABLED(CONFIG_IPV6)
+       if (sk->sk_family == AF_INET6)
+               tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr;
+       else
+#endif
+               tb->rcv_saddr = sk->sk_rcv_saddr;
+       INIT_HLIST_HEAD(&tb->owners);
+       hlist_add_head(&tb->node, &head->chain);
+}
+
+struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep,
+                                                  struct net *net,
+                                                  struct inet_bind_hashbucket *head,
+                                                  unsigned short port,
+                                                  int l3mdev,
+                                                  const struct sock *sk)
+{
+       struct inet_bind2_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC);
+
+       if (tb)
+               inet_bind2_bucket_init(tb, net, head, port, l3mdev, sk);
+
+       return tb;
+}
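
For orientation, here is a minimal sketch of the bucket layout implied by the initializer above; the real struct inet_bind2_bucket is declared in include/net/inet_hashtables.h, and this sketch only names the fields the code here touches:

	/* illustrative sketch only, not the kernel's definition */
	struct inet_bind2_bucket_sketch {
		possible_net_t    ib_net;       /* owning netns: write_pnet()/ib2_net() */
		int               l3mdev;       /* bound L3 master device, 0 if none */
		unsigned short    port;         /* local port number */
	#if IS_ENABLED(CONFIG_IPV6)
		struct in6_addr   v6_rcv_saddr; /* bound address for AF_INET6 sockets */
	#endif
		__be32            rcv_saddr;    /* bound address for AF_INET sockets */
		struct hlist_node node;         /* link in the bhash2 chain */
		struct hlist_head owners;       /* sockets bound to this port+address */
	};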
+
+/* Caller must hold hashbucket lock for this tb with local BH disabled */
+void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb)
+{
+       if (hlist_empty(&tb->owners)) {
+               __hlist_del(&tb->node);
+               kmem_cache_free(cachep, tb);
+       }
+}
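
Note the lifetime rule these helpers encode: bind2 buckets carry no reference count. A bucket lives exactly as long as its owners list is non-empty, so every unlink of a socket is paired with an opportunistic destroy that frees the bucket only once the last owner is gone. The __inet_put_port() hunk below follows exactly this shape:

	__sk_del_bind2_node(sk);                 /* drop one owner from tb2->owners */
	inet_bind2_bucket_destroy(cachep, tb2);  /* frees only if owners is now empty */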
+
+static bool inet_bind2_bucket_addr_match(const struct inet_bind2_bucket *tb2,
+                                        const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+       if (sk->sk_family == AF_INET6)
+               return ipv6_addr_equal(&tb2->v6_rcv_saddr,
+                                      &sk->sk_v6_rcv_saddr);
+#endif
+       return tb2->rcv_saddr == sk->sk_rcv_saddr;
+}
+
 void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
-                   const unsigned short snum)
+                   struct inet_bind2_bucket *tb2, unsigned short port)
 {
-       inet_sk(sk)->inet_num = snum;
+       inet_sk(sk)->inet_num = port;
        sk_add_bind_node(sk, &tb->owners);
        inet_csk(sk)->icsk_bind_hash = tb;
+       sk_add_bind2_node(sk, &tb2->owners);
+       inet_csk(sk)->icsk_bind2_hash = tb2;
 }
 
 /*
@@ -106,10 +169,14 @@ void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb,
 static void __inet_put_port(struct sock *sk)
 {
        struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
-       const int bhash = inet_bhashfn(sock_net(sk), inet_sk(sk)->inet_num,
-                       hashinfo->bhash_size);
-       struct inet_bind_hashbucket *head = &hashinfo->bhash[bhash];
+       struct inet_bind_hashbucket *head, *head2;
+       struct net *net = sock_net(sk);
        struct inet_bind_bucket *tb;
+       int bhash;
+
+       bhash = inet_bhashfn(net, inet_sk(sk)->inet_num, hashinfo->bhash_size);
+       head = &hashinfo->bhash[bhash];
+       head2 = inet_bhashfn_portaddr(hashinfo, sk, net, inet_sk(sk)->inet_num);
 
        spin_lock(&head->lock);
        tb = inet_csk(sk)->icsk_bind_hash;
@@ -117,6 +184,17 @@ static void __inet_put_port(struct sock *sk)
        inet_csk(sk)->icsk_bind_hash = NULL;
        inet_sk(sk)->inet_num = 0;
        inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb);
+
+       spin_lock(&head2->lock);
+       if (inet_csk(sk)->icsk_bind2_hash) {
+               struct inet_bind2_bucket *tb2 = inet_csk(sk)->icsk_bind2_hash;
+
+               __sk_del_bind2_node(sk);
+               inet_csk(sk)->icsk_bind2_hash = NULL;
+               inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2);
+       }
+       spin_unlock(&head2->lock);
+
        spin_unlock(&head->lock);
 }
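
Throughout this patch the two bucket locks nest in a fixed order: the bhash bucket lock is taken first and the bhash2 bucket lock is nested inside it, released in reverse. Condensed from __inet_put_port() above:

	spin_lock(&head->lock);    /* bhash bucket: always outermost */
	spin_lock(&head2->lock);   /* bhash2 bucket: always nested second */
	/* ... unlink sk from both tables ... */
	spin_unlock(&head2->lock);
	spin_unlock(&head->lock);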
 
@@ -132,15 +210,24 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
 {
        struct inet_hashinfo *table = sk->sk_prot->h.hashinfo;
        unsigned short port = inet_sk(child)->inet_num;
-       const int bhash = inet_bhashfn(sock_net(sk), port,
-                       table->bhash_size);
-       struct inet_bind_hashbucket *head = &table->bhash[bhash];
+       struct inet_bind_hashbucket *head, *head2;
+       bool created_inet_bind_bucket = false;
+       struct net *net = sock_net(sk);
+       bool update_fastreuse = false;
+       struct inet_bind2_bucket *tb2;
        struct inet_bind_bucket *tb;
-       int l3mdev;
+       int bhash, l3mdev;
+
+       bhash = inet_bhashfn(net, port, table->bhash_size);
+       head = &table->bhash[bhash];
+       head2 = inet_bhashfn_portaddr(table, child, net, port);
 
        spin_lock(&head->lock);
+       spin_lock(&head2->lock);
        tb = inet_csk(sk)->icsk_bind_hash;
-       if (unlikely(!tb)) {
+       tb2 = inet_csk(sk)->icsk_bind2_hash;
+       if (unlikely(!tb || !tb2)) {
+               spin_unlock(&head2->lock);
                spin_unlock(&head->lock);
                return -ENOENT;
        }
@@ -153,25 +240,49 @@ int __inet_inherit_port(const struct sock *sk, struct sock *child)
                 * as that of the child socket. We have to look up or
                 * create a new bind bucket for the child here. */
                inet_bind_bucket_for_each(tb, &head->chain) {
-                       if (net_eq(ib_net(tb), sock_net(sk)) &&
-                           tb->l3mdev == l3mdev && tb->port == port)
+                       if (inet_bind_bucket_match(tb, net, port, l3mdev))
                                break;
                }
                if (!tb) {
                        tb = inet_bind_bucket_create(table->bind_bucket_cachep,
-                                                    sock_net(sk), head, port,
-                                                    l3mdev);
+                                                    net, head, port, l3mdev);
                        if (!tb) {
+                               spin_unlock(&head2->lock);
                                spin_unlock(&head->lock);
                                return -ENOMEM;
                        }
+                       created_inet_bind_bucket = true;
+               }
+               update_fastreuse = true;
+
+               goto bhash2_find;
+       } else if (!inet_bind2_bucket_addr_match(tb2, child)) {
+               l3mdev = inet_sk_bound_l3mdev(sk);
+
+bhash2_find:
+               tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, child);
+               if (!tb2) {
+                       tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep,
+                                                      net, head2, port,
+                                                      l3mdev, child);
+                       if (!tb2)
+                               goto error;
                }
-               inet_csk_update_fastreuse(tb, child);
        }
-       inet_bind_hash(child, tb, port);
+       if (update_fastreuse)
+               inet_csk_update_fastreuse(tb, child);
+       inet_bind_hash(child, tb, tb2, port);
+       spin_unlock(&head2->lock);
        spin_unlock(&head->lock);
 
        return 0;
+
+error:
+       if (created_inet_bind_bucket)
+               inet_bind_bucket_destroy(table->bind_bucket_cachep, tb);
+       spin_unlock(&head2->lock);
+       spin_unlock(&head->lock);
+       return -ENOMEM;
 }
 EXPORT_SYMBOL_GPL(__inet_inherit_port);
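
The branch structure in __inet_inherit_port() reduces to three cases; condensed, with the same names as the function:

	if (tb->port != port) {
		/* child rebound to a different port: find or create tb,
		 * then fall through to bhash2_find for tb2 as well
		 */
	} else if (!inet_bind2_bucket_addr_match(tb2, child)) {
		/* same port, different local address: reuse the parent's tb,
		 * find or create only the tb2 bucket
		 */
	} else {
		/* same port and address: share the parent's tb and tb2 as-is */
	}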
 
@@ -519,8 +630,8 @@ static bool inet_ehash_lookup_by_sk(struct sock *sk,
 bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk)
 {
        struct inet_hashinfo *hashinfo = sk->sk_prot->h.hashinfo;
-       struct hlist_nulls_head *list;
        struct inet_ehash_bucket *head;
+       struct hlist_nulls_head *list;
        spinlock_t *lock;
        bool ret = true;
 
@@ -675,6 +786,112 @@ void inet_unhash(struct sock *sk)
 }
 EXPORT_SYMBOL_GPL(inet_unhash);
 
+static bool inet_bind2_bucket_match(const struct inet_bind2_bucket *tb,
+                                   const struct net *net, unsigned short port,
+                                   int l3mdev, const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+       if (sk->sk_family == AF_INET6)
+               return net_eq(ib2_net(tb), net) && tb->port == port &&
+                       tb->l3mdev == l3mdev &&
+                       ipv6_addr_equal(&tb->v6_rcv_saddr, &sk->sk_v6_rcv_saddr);
+       else
+#endif
+               return net_eq(ib2_net(tb), net) && tb->port == port &&
+                       tb->l3mdev == l3mdev && tb->rcv_saddr == sk->sk_rcv_saddr;
+}
+
+bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb, const struct net *net,
+                                     unsigned short port, int l3mdev, const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+       struct in6_addr addr_any = {};
+
+       if (sk->sk_family == AF_INET6)
+               return net_eq(ib2_net(tb), net) && tb->port == port &&
+                       tb->l3mdev == l3mdev &&
+                       ipv6_addr_equal(&tb->v6_rcv_saddr, &addr_any);
+       else
+#endif
+               return net_eq(ib2_net(tb), net) && tb->port == port &&
+                       tb->l3mdev == l3mdev && tb->rcv_saddr == 0;
+}
+
+/* The socket's bhash2 hashbucket spinlock must be held when this is called */
+struct inet_bind2_bucket *
+inet_bind2_bucket_find(const struct inet_bind_hashbucket *head, const struct net *net,
+                      unsigned short port, int l3mdev, const struct sock *sk)
+{
+       struct inet_bind2_bucket *bhash2 = NULL;
+
+       inet_bind_bucket_for_each(bhash2, &head->chain)
+               if (inet_bind2_bucket_match(bhash2, net, port, l3mdev, sk))
+                       break;
+
+       return bhash2;
+}
+
+struct inet_bind_hashbucket *
+inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port)
+{
+       struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
+       u32 hash;
+#if IS_ENABLED(CONFIG_IPV6)
+       struct in6_addr addr_any = {};
+
+       if (sk->sk_family == AF_INET6)
+               hash = ipv6_portaddr_hash(net, &addr_any, port);
+       else
+#endif
+               hash = ipv4_portaddr_hash(net, 0, port);
+
+       return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)];
+}
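
Because a bhash2 chain is selected by hashing (net, address, port), a socket bound to the wildcard address lands in a different chain than one bound to a specific address, so conflict checks have to probe both. A hedged sketch of how a caller might combine these helpers (the actual bind-conflict logic lives in inet_connection_sock.c; head2_any is a name invented here):

	/* chain for sk's own bound address */
	head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
	tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);

	/* separately, the chain for 0.0.0.0 / :: on the same port */
	head2_any = inet_bhash2_addr_any_hashbucket(sk, net, port);
	/* ... walk head2_any->chain with inet_bind2_bucket_match_addr_any() ... */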
+
+int inet_bhash2_update_saddr(struct inet_bind_hashbucket *prev_saddr, struct sock *sk)
+{
+       struct inet_hashinfo *hinfo = sk->sk_prot->h.hashinfo;
+       struct inet_bind2_bucket *tb2, *new_tb2;
+       int l3mdev = inet_sk_bound_l3mdev(sk);
+       struct inet_bind_hashbucket *head2;
+       int port = inet_sk(sk)->inet_num;
+       struct net *net = sock_net(sk);
+
+       /* Allocate a bind2 bucket ahead of time to avoid permanently putting
+        * the bhash2 table in an inconsistent state if a new tb2 bucket
+        * allocation fails.
+        */
+       new_tb2 = kmem_cache_alloc(hinfo->bind2_bucket_cachep, GFP_ATOMIC);
+       if (!new_tb2)
+               return -ENOMEM;
+
+       head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+
+       if (prev_saddr) {
+               spin_lock_bh(&prev_saddr->lock);
+               __sk_del_bind2_node(sk);
+               inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep,
+                                         inet_csk(sk)->icsk_bind2_hash);
+               spin_unlock_bh(&prev_saddr->lock);
+       }
+
+       spin_lock_bh(&head2->lock);
+       tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
+       if (!tb2) {
+               tb2 = new_tb2;
+               inet_bind2_bucket_init(tb2, net, head2, port, l3mdev, sk);
+       }
+       sk_add_bind2_node(sk, &tb2->owners);
+       inet_csk(sk)->icsk_bind2_hash = tb2;
+       spin_unlock_bh(&head2->lock);
+
+       if (tb2 != new_tb2)
+               kmem_cache_free(hinfo->bind2_bucket_cachep, new_tb2);
+
+       return 0;
+}
+EXPORT_SYMBOL_GPL(inet_bhash2_update_saddr);
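
The preallocation shape of inet_bhash2_update_saddr() is worth calling out: the replacement bucket is allocated before the old binding is torn down, so once __sk_del_bind2_node() runs there is no failure path left and the socket can never end up hashed in neither chain; an unused preallocation is simply handed back to the cache. In outline:

	new_tb2 = kmem_cache_alloc(cachep, GFP_ATOMIC); /* 1. reserve memory first */
	if (!new_tb2)
		return -ENOMEM;                         /*    old binding untouched */
	/* 2. unlink from the old chain -- nothing can fail from here on */
	/* 3. reuse a matching existing bucket, else consume new_tb2 */
	if (tb2 != new_tb2)
		kmem_cache_free(cachep, new_tb2);       /* 4. preallocation unused */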
+
 /* RFC 6056 3.3.4.  Algorithm 4: Double-Hash Port Selection Algorithm
  * Note that we use 32bit integers (vs RFC 'short integers')
  * because 2^16 is not a multiple of num_ephemeral and this
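
The comment above refers to RFC 6056 Algorithm 4. As a standalone illustration, here is a minimal userspace sketch of that double-hash scheme; the hash function, table size, and perturbation step are placeholders rather than the kernel's siphash-based versions, and the kernel additionally scans even and odd ports in two passes:

	#include <stdint.h>
	#include <stdio.h>

	#define PERTURB_SIZE 256
	static uint32_t table_perturb[PERTURB_SIZE];

	static uint32_t toy_hash(uint32_t x)        /* stand-in for RFC 6056's F()/G() */
	{
		x ^= x >> 16; x *= 0x7feb352d;
		x ^= x >> 15; x *= 0x846ca68b;
		return x ^ (x >> 16);
	}

	static uint16_t pick_port(uint32_t flow_key, uint16_t low, uint16_t high)
	{
		uint32_t remaining = (uint32_t)(high - low) + 1;
		uint32_t index = toy_hash(flow_key ^ 0x9e3779b9) % PERTURB_SIZE;
		uint32_t offset = toy_hash(flow_key) + table_perturb[index];

		table_perturb[index] += 2;          /* retries keep walking the range */
		return (uint16_t)(low + offset % remaining);
	}

	int main(void)
	{
		for (int i = 0; i < 4; i++)         /* same flow, successive attempts */
			printf("%u\n", pick_port(0x12345678u, 32768, 60999));
		return 0;
	}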
@@ -694,11 +911,13 @@ int __inet_hash_connect(struct inet_timewait_death_row *death_row,
                        struct sock *, __u16, struct inet_timewait_sock **))
 {
        struct inet_hashinfo *hinfo = death_row->hashinfo;
+       struct inet_bind_hashbucket *head, *head2;
        struct inet_timewait_sock *tw = NULL;
-       struct inet_bind_hashbucket *head;
        int port = inet_sk(sk)->inet_num;
        struct net *net = sock_net(sk);
+       struct inet_bind2_bucket *tb2;
        struct inet_bind_bucket *tb;
+       bool tb_created = false;
        u32 remaining, offset;
        int ret, i, low, high;
        int l3mdev;
@@ -755,8 +974,7 @@ other_parity_scan:
                 * the established check is already unique enough.
                 */
                inet_bind_bucket_for_each(tb, &head->chain) {
-                       if (net_eq(ib_net(tb), net) && tb->l3mdev == l3mdev &&
-                           tb->port == port) {
+                       if (inet_bind_bucket_match(tb, net, port, l3mdev)) {
                                if (tb->fastreuse >= 0 ||
                                    tb->fastreuseport >= 0)
                                        goto next_port;
@@ -774,6 +992,7 @@ other_parity_scan:
                        spin_unlock_bh(&head->lock);
                        return -ENOMEM;
                }
+               tb_created = true;
                tb->fastreuse = -1;
                tb->fastreuseport = -1;
                goto ok;
@@ -789,6 +1008,20 @@ next_port:
        return -EADDRNOTAVAIL;
 
 ok:
+       /* Find the corresponding tb2 bucket since we need to
+        * add the socket to the bhash2 table as well
+        */
+       head2 = inet_bhashfn_portaddr(hinfo, sk, net, port);
+       spin_lock(&head2->lock);
+
+       tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk);
+       if (!tb2) {
+               tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net,
+                                              head2, port, l3mdev, sk);
+               if (!tb2)
+                       goto error;
+       }
+
        /* Here we want to add a little bit of randomness to the next source
         * port that will be chosen. We use a max() with a random here so that
         * on low contention the randomness is maximal and on high contention
@@ -798,7 +1031,10 @@ ok:
        WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2);
 
        /* Head lock still held and bh's disabled */
-       inet_bind_hash(sk, tb, port);
+       inet_bind_hash(sk, tb, tb2, port);
+
+       spin_unlock(&head2->lock);
+
        if (sk_unhashed(sk)) {
                inet_sk(sk)->inet_sport = htons(port);
                inet_ehash_nolisten(sk, (struct sock *)tw, NULL);
@@ -810,6 +1046,13 @@ ok:
                inet_twsk_deschedule_put(tw);
        local_bh_enable();
        return 0;
+
+error:
+       spin_unlock(&head2->lock);
+       if (tb_created)
+               inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb);
+       spin_unlock_bh(&head->lock);
+       return -ENOMEM;
 }
 
 /*