bpf: tcp: Bpf iter batching and lock_sock
authorMartin KaFai Lau <kafai@fb.com>
Thu, 1 Jul 2021 20:06:13 +0000 (13:06 -0700)
committerAndrii Nakryiko <andrii@kernel.org>
Fri, 23 Jul 2021 23:45:00 +0000 (16:45 -0700)
This patch does batching and lock_sock for the bpf tcp iter.
It does not affect the proc fs iteration.

With bpf-tcp-cc, new algo rollout happens more often.  Instead of
restarting the application to pick up the new tcp-cc, the next patch
will allow bpf iter to do setsockopt(TCP_CONGESTION).  This requires
locking the sock.

Also, unlike the proc iteration (cat /proc/net/tcp[6]), the bpf iter
can inspect all fields of a tcp_sock.  It will be useful to have a
consistent view on some of the fields (e.g. the ones reported in
tcp_get_info() that also acquires the sock lock).

Double lock: locking the bucket first and then locking the sock could
lead to deadlock.  This patch takes a batching approach similar to
inet_diag.  While holding the bucket lock, it batch a number of sockets
into an array first and then unlock the bucket.  Before doing show(),
it then calls lock_sock_fast().

In a machine with ~400k connections, the maximum number of
sk in a bucket of the established hashtable is 7.  0.02% of
the established connections fall into this bucket size.

For listen hash (port+addr lhash2), the bucket is usually very
small also except for the SO_REUSEPORT use case which the
userspace could have one SO_REUSEPORT socket per thread.

While batching is used, it can also minimize the chance of missing
sock in the setsockopt use case if the whole bucket is batched.
This patch will start with a batch array with INIT_BATCH_SZ (16)
which will be enough for the most common cases.  bpf_iter_tcp_batch()
will try to realloc to a larger array to handle exception case (e.g.
the SO_REUSEPORT case in the lhash2).

Signed-off-by: Martin KaFai Lau <kafai@fb.com>
Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Acked-by: Kuniyuki Iwashima <kuniyu@amazon.co.jp>
Acked-by: Yonghong Song <yhs@fb.com>
Link: https://lore.kernel.org/bpf/20210701200613.1036157-1-kafai@fb.com
net/ipv4/tcp_ipv4.c

index d38b4379dca4f40317202874341b8153e295c6f4..84ac0135d3895719457f2d4040d13c5903023d33 100644 (file)
@@ -2687,6 +2687,15 @@ out:
 }
 
 #ifdef CONFIG_BPF_SYSCALL
+struct bpf_tcp_iter_state {
+       struct tcp_iter_state state;
+       unsigned int cur_sk;
+       unsigned int end_sk;
+       unsigned int max_sk;
+       struct sock **batch;
+       bool st_bucket_done;
+};
+
 struct bpf_iter__tcp {
        __bpf_md_ptr(struct bpf_iter_meta *, meta);
        __bpf_md_ptr(struct sock_common *, sk_common);
@@ -2705,16 +2714,204 @@ static int tcp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta,
        return bpf_iter_run_prog(prog, &ctx);
 }
 
+static void bpf_iter_tcp_put_batch(struct bpf_tcp_iter_state *iter)
+{
+       while (iter->cur_sk < iter->end_sk)
+               sock_put(iter->batch[iter->cur_sk++]);
+}
+
+static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter,
+                                     unsigned int new_batch_sz)
+{
+       struct sock **new_batch;
+
+       new_batch = kvmalloc(sizeof(*new_batch) * new_batch_sz,
+                            GFP_USER | __GFP_NOWARN);
+       if (!new_batch)
+               return -ENOMEM;
+
+       bpf_iter_tcp_put_batch(iter);
+       kvfree(iter->batch);
+       iter->batch = new_batch;
+       iter->max_sk = new_batch_sz;
+
+       return 0;
+}
+
+static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq,
+                                                struct sock *start_sk)
+{
+       struct bpf_tcp_iter_state *iter = seq->private;
+       struct tcp_iter_state *st = &iter->state;
+       struct inet_connection_sock *icsk;
+       unsigned int expected = 1;
+       struct sock *sk;
+
+       sock_hold(start_sk);
+       iter->batch[iter->end_sk++] = start_sk;
+
+       icsk = inet_csk(start_sk);
+       inet_lhash2_for_each_icsk_continue(icsk) {
+               sk = (struct sock *)icsk;
+               if (seq_sk_match(seq, sk)) {
+                       if (iter->end_sk < iter->max_sk) {
+                               sock_hold(sk);
+                               iter->batch[iter->end_sk++] = sk;
+                       }
+                       expected++;
+               }
+       }
+       spin_unlock(&tcp_hashinfo.lhash2[st->bucket].lock);
+
+       return expected;
+}
+
+static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq,
+                                                  struct sock *start_sk)
+{
+       struct bpf_tcp_iter_state *iter = seq->private;
+       struct tcp_iter_state *st = &iter->state;
+       struct hlist_nulls_node *node;
+       unsigned int expected = 1;
+       struct sock *sk;
+
+       sock_hold(start_sk);
+       iter->batch[iter->end_sk++] = start_sk;
+
+       sk = sk_nulls_next(start_sk);
+       sk_nulls_for_each_from(sk, node) {
+               if (seq_sk_match(seq, sk)) {
+                       if (iter->end_sk < iter->max_sk) {
+                               sock_hold(sk);
+                               iter->batch[iter->end_sk++] = sk;
+                       }
+                       expected++;
+               }
+       }
+       spin_unlock_bh(inet_ehash_lockp(&tcp_hashinfo, st->bucket));
+
+       return expected;
+}
+
+static struct sock *bpf_iter_tcp_batch(struct seq_file *seq)
+{
+       struct bpf_tcp_iter_state *iter = seq->private;
+       struct tcp_iter_state *st = &iter->state;
+       unsigned int expected;
+       bool resized = false;
+       struct sock *sk;
+
+       /* The st->bucket is done.  Directly advance to the next
+        * bucket instead of having the tcp_seek_last_pos() to skip
+        * one by one in the current bucket and eventually find out
+        * it has to advance to the next bucket.
+        */
+       if (iter->st_bucket_done) {
+               st->offset = 0;
+               st->bucket++;
+               if (st->state == TCP_SEQ_STATE_LISTENING &&
+                   st->bucket > tcp_hashinfo.lhash2_mask) {
+                       st->state = TCP_SEQ_STATE_ESTABLISHED;
+                       st->bucket = 0;
+               }
+       }
+
+again:
+       /* Get a new batch */
+       iter->cur_sk = 0;
+       iter->end_sk = 0;
+       iter->st_bucket_done = false;
+
+       sk = tcp_seek_last_pos(seq);
+       if (!sk)
+               return NULL; /* Done */
+
+       if (st->state == TCP_SEQ_STATE_LISTENING)
+               expected = bpf_iter_tcp_listening_batch(seq, sk);
+       else
+               expected = bpf_iter_tcp_established_batch(seq, sk);
+
+       if (iter->end_sk == expected) {
+               iter->st_bucket_done = true;
+               return sk;
+       }
+
+       if (!resized && !bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2)) {
+               resized = true;
+               goto again;
+       }
+
+       return sk;
+}
+
+static void *bpf_iter_tcp_seq_start(struct seq_file *seq, loff_t *pos)
+{
+       /* bpf iter does not support lseek, so it always
+        * continue from where it was stop()-ped.
+        */
+       if (*pos)
+               return bpf_iter_tcp_batch(seq);
+
+       return SEQ_START_TOKEN;
+}
+
+static void *bpf_iter_tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos)
+{
+       struct bpf_tcp_iter_state *iter = seq->private;
+       struct tcp_iter_state *st = &iter->state;
+       struct sock *sk;
+
+       /* Whenever seq_next() is called, the iter->cur_sk is
+        * done with seq_show(), so advance to the next sk in
+        * the batch.
+        */
+       if (iter->cur_sk < iter->end_sk) {
+               /* Keeping st->num consistent in tcp_iter_state.
+                * bpf_iter_tcp does not use st->num.
+                * meta.seq_num is used instead.
+                */
+               st->num++;
+               /* Move st->offset to the next sk in the bucket such that
+                * the future start() will resume at st->offset in
+                * st->bucket.  See tcp_seek_last_pos().
+                */
+               st->offset++;
+               sock_put(iter->batch[iter->cur_sk++]);
+       }
+
+       if (iter->cur_sk < iter->end_sk)
+               sk = iter->batch[iter->cur_sk];
+       else
+               sk = bpf_iter_tcp_batch(seq);
+
+       ++*pos;
+       /* Keeping st->last_pos consistent in tcp_iter_state.
+        * bpf iter does not do lseek, so st->last_pos always equals to *pos.
+        */
+       st->last_pos = *pos;
+       return sk;
+}
+
 static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v)
 {
        struct bpf_iter_meta meta;
        struct bpf_prog *prog;
        struct sock *sk = v;
+       bool slow;
        uid_t uid;
+       int ret;
 
        if (v == SEQ_START_TOKEN)
                return 0;
 
+       if (sk_fullsock(sk))
+               slow = lock_sock_fast(sk);
+
+       if (unlikely(sk_unhashed(sk))) {
+               ret = SEQ_SKIP;
+               goto unlock;
+       }
+
        if (sk->sk_state == TCP_TIME_WAIT) {
                uid = 0;
        } else if (sk->sk_state == TCP_NEW_SYN_RECV) {
@@ -2728,11 +2925,18 @@ static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v)
 
        meta.seq = seq;
        prog = bpf_iter_get_info(&meta, false);
-       return tcp_prog_seq_show(prog, &meta, v, uid);
+       ret = tcp_prog_seq_show(prog, &meta, v, uid);
+
+unlock:
+       if (sk_fullsock(sk))
+               unlock_sock_fast(sk, slow);
+       return ret;
+
 }
 
 static void bpf_iter_tcp_seq_stop(struct seq_file *seq, void *v)
 {
+       struct bpf_tcp_iter_state *iter = seq->private;
        struct bpf_iter_meta meta;
        struct bpf_prog *prog;
 
@@ -2743,13 +2947,16 @@ static void bpf_iter_tcp_seq_stop(struct seq_file *seq, void *v)
                        (void)tcp_prog_seq_show(prog, &meta, v, 0);
        }
 
-       tcp_seq_stop(seq, v);
+       if (iter->cur_sk < iter->end_sk) {
+               bpf_iter_tcp_put_batch(iter);
+               iter->st_bucket_done = false;
+       }
 }
 
 static const struct seq_operations bpf_iter_tcp_seq_ops = {
        .show           = bpf_iter_tcp_seq_show,
-       .start          = tcp_seq_start,
-       .next           = tcp_seq_next,
+       .start          = bpf_iter_tcp_seq_start,
+       .next           = bpf_iter_tcp_seq_next,
        .stop           = bpf_iter_tcp_seq_stop,
 };
 #endif
@@ -3017,21 +3224,39 @@ static struct pernet_operations __net_initdata tcp_sk_ops = {
 DEFINE_BPF_ITER_FUNC(tcp, struct bpf_iter_meta *meta,
                     struct sock_common *sk_common, uid_t uid)
 
+#define INIT_BATCH_SZ 16
+
 static int bpf_iter_init_tcp(void *priv_data, struct bpf_iter_aux_info *aux)
 {
-       return bpf_iter_init_seq_net(priv_data, aux);
+       struct bpf_tcp_iter_state *iter = priv_data;
+       int err;
+
+       err = bpf_iter_init_seq_net(priv_data, aux);
+       if (err)
+               return err;
+
+       err = bpf_iter_tcp_realloc_batch(iter, INIT_BATCH_SZ);
+       if (err) {
+               bpf_iter_fini_seq_net(priv_data);
+               return err;
+       }
+
+       return 0;
 }
 
 static void bpf_iter_fini_tcp(void *priv_data)
 {
+       struct bpf_tcp_iter_state *iter = priv_data;
+
        bpf_iter_fini_seq_net(priv_data);
+       kvfree(iter->batch);
 }
 
 static const struct bpf_iter_seq_info tcp_seq_info = {
        .seq_ops                = &bpf_iter_tcp_seq_ops,
        .init_seq_private       = bpf_iter_init_tcp,
        .fini_seq_private       = bpf_iter_fini_tcp,
-       .seq_priv_size          = sizeof(struct tcp_iter_state),
+       .seq_priv_size          = sizeof(struct bpf_tcp_iter_state),
 };
 
 static struct bpf_iter_reg tcp_reg_info = {