io_uring: get rid of hashed provided buffer groups
authorJens Axboe <axboe@kernel.dk>
Sun, 1 May 2022 16:52:44 +0000 (10:52 -0600)
committerJens Axboe <axboe@kernel.dk>
Mon, 9 May 2022 12:29:06 +0000 (06:29 -0600)
Use a plain array for any group ID that's less than 64, and punt
anything beyond that to an xarray. 64 fits in a page even for 4KB
page sizes and with the planned additions.

This makes the expected group usage faster by avoiding a hash and lookup
to find our list, and it uses less memory upfront by not allocating any
memory for provided buffers unless it's actually being used.

Suggested-by: Pavel Begunkov <asml.silence@gmail.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
fs/io_uring.c

index b5efabc917c64b03fcbdddce4aa74ed0b091e1b7..0672ce0bd832ca21b91ed565aa40b611beef883b 100644 (file)
@@ -282,7 +282,6 @@ struct io_rsrc_data {
 };
 
 struct io_buffer_list {
-       struct list_head list;
        struct list_head buf_list;
        __u16 bgid;
 };
@@ -357,7 +356,7 @@ struct io_ev_fd {
        struct rcu_head         rcu;
 };
 
-#define IO_BUFFERS_HASH_BITS   5
+#define BGID_ARRAY     64
 
 struct io_ring_ctx {
        /* const or read-mostly hot data */
@@ -414,7 +413,8 @@ struct io_ring_ctx {
                struct list_head        timeout_list;
                struct list_head        ltimeout_list;
                struct list_head        cq_overflow_list;
-               struct list_head        *io_buffers;
+               struct io_buffer_list   *io_bl;
+               struct xarray           io_bl_xa;
                struct list_head        io_buffers_cache;
                struct list_head        apoll_cache;
                struct xarray           personalities;
@@ -1507,15 +1507,10 @@ static inline unsigned int io_put_kbuf(struct io_kiocb *req,
 static struct io_buffer_list *io_buffer_get_list(struct io_ring_ctx *ctx,
                                                 unsigned int bgid)
 {
-       struct list_head *hash_list;
-       struct io_buffer_list *bl;
-
-       hash_list = &ctx->io_buffers[hash_32(bgid, IO_BUFFERS_HASH_BITS)];
-       list_for_each_entry(bl, hash_list, list)
-               if (bl->bgid == bgid || bgid == -1U)
-                       return bl;
+       if (ctx->io_bl && bgid < BGID_ARRAY)
+               return &ctx->io_bl[bgid];
 
-       return NULL;
+       return xa_load(&ctx->io_bl_xa, bgid);
 }
 
 static void io_kbuf_recycle(struct io_kiocb *req, unsigned issue_flags)
@@ -1621,12 +1616,14 @@ static __cold void io_fallback_req_func(struct work_struct *work)
 static __cold struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
 {
        struct io_ring_ctx *ctx;
-       int i, hash_bits;
+       int hash_bits;
 
        ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
        if (!ctx)
                return NULL;
 
+       xa_init(&ctx->io_bl_xa);
+
        /*
         * Use 5 bits less than the max cq entries, that should give us around
         * 32 entries per hash list if totally full and uniformly spread.
@@ -1648,13 +1645,6 @@ static __cold struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
        /* set invalid range, so io_import_fixed() fails meeting it */
        ctx->dummy_ubuf->ubuf = -1UL;
 
-       ctx->io_buffers = kcalloc(1U << IO_BUFFERS_HASH_BITS,
-                                       sizeof(struct list_head), GFP_KERNEL);
-       if (!ctx->io_buffers)
-               goto err;
-       for (i = 0; i < (1U << IO_BUFFERS_HASH_BITS); i++)
-               INIT_LIST_HEAD(&ctx->io_buffers[i]);
-
        if (percpu_ref_init(&ctx->refs, io_ring_ctx_ref_free,
                            PERCPU_REF_ALLOW_REINIT, GFP_KERNEL))
                goto err;
@@ -1690,7 +1680,8 @@ static __cold struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
 err:
        kfree(ctx->dummy_ubuf);
        kfree(ctx->cancel_hash);
-       kfree(ctx->io_buffers);
+       kfree(ctx->io_bl);
+       xa_destroy(&ctx->io_bl_xa);
        kfree(ctx);
        return NULL;
 }
@@ -3447,15 +3438,14 @@ static int io_import_fixed(struct io_kiocb *req, int rw, struct iov_iter *iter,
        return __io_import_fixed(req, rw, iter, imu);
 }
 
-static void io_buffer_add_list(struct io_ring_ctx *ctx,
-                              struct io_buffer_list *bl, unsigned int bgid)
+static int io_buffer_add_list(struct io_ring_ctx *ctx,
+                             struct io_buffer_list *bl, unsigned int bgid)
 {
-       struct list_head *list;
-
-       list = &ctx->io_buffers[hash_32(bgid, IO_BUFFERS_HASH_BITS)];
-       INIT_LIST_HEAD(&bl->buf_list);
        bl->bgid = bgid;
-       list_add(&bl->list, list);
+       if (bgid < BGID_ARRAY)
+               return 0;
+
+       return xa_err(xa_store(&ctx->io_bl_xa, bgid, bl, GFP_KERNEL));
 }
 
 static void __user *io_buffer_select(struct io_kiocb *req, size_t *len,
@@ -4921,6 +4911,23 @@ static int io_add_buffers(struct io_ring_ctx *ctx, struct io_provide_buf *pbuf,
        return i ? 0 : -ENOMEM;
 }
 
+static __cold int io_init_bl_list(struct io_ring_ctx *ctx)
+{
+       int i;
+
+       ctx->io_bl = kcalloc(BGID_ARRAY, sizeof(struct io_buffer_list),
+                               GFP_KERNEL);
+       if (!ctx->io_bl)
+               return -ENOMEM;
+
+       for (i = 0; i < BGID_ARRAY; i++) {
+               INIT_LIST_HEAD(&ctx->io_bl[i].buf_list);
+               ctx->io_bl[i].bgid = i;
+       }
+
+       return 0;
+}
+
 static int io_provide_buffers(struct io_kiocb *req, unsigned int issue_flags)
 {
        struct io_provide_buf *p = &req->pbuf;
@@ -4930,6 +4937,12 @@ static int io_provide_buffers(struct io_kiocb *req, unsigned int issue_flags)
 
        io_ring_submit_lock(ctx, issue_flags);
 
+       if (unlikely(p->bgid < BGID_ARRAY && !ctx->io_bl)) {
+               ret = io_init_bl_list(ctx);
+               if (ret)
+                       goto err;
+       }
+
        bl = io_buffer_get_list(ctx, p->bgid);
        if (unlikely(!bl)) {
                bl = kmalloc(sizeof(*bl), GFP_KERNEL);
@@ -4937,7 +4950,11 @@ static int io_provide_buffers(struct io_kiocb *req, unsigned int issue_flags)
                        ret = -ENOMEM;
                        goto err;
                }
-               io_buffer_add_list(ctx, bl, p->bgid);
+               ret = io_buffer_add_list(ctx, bl, p->bgid);
+               if (ret) {
+                       kfree(bl);
+                       goto err;
+               }
        }
 
        ret = io_add_buffers(ctx, p, bl);
@@ -9931,19 +9948,19 @@ static int io_eventfd_unregister(struct io_ring_ctx *ctx)
 
 static void io_destroy_buffers(struct io_ring_ctx *ctx)
 {
+       struct io_buffer_list *bl;
+       unsigned long index;
        int i;
 
-       for (i = 0; i < (1U << IO_BUFFERS_HASH_BITS); i++) {
-               struct list_head *list = &ctx->io_buffers[i];
-
-               while (!list_empty(list)) {
-                       struct io_buffer_list *bl;
+       for (i = 0; i < BGID_ARRAY; i++) {
+               if (!ctx->io_bl)
+                       break;
+               __io_remove_buffers(ctx, &ctx->io_bl[i], -1U);
+       }
 
-                       bl = list_first_entry(list, struct io_buffer_list, list);
-                       __io_remove_buffers(ctx, bl, -1U);
-                       list_del(&bl->list);
-                       kfree(bl);
-               }
+       xa_for_each(&ctx->io_bl_xa, index, bl) {
+               xa_erase(&ctx->io_bl_xa, bl->bgid);
+               __io_remove_buffers(ctx, bl, -1U);
        }
 
        while (!list_empty(&ctx->io_buffers_pages)) {
@@ -10052,7 +10069,8 @@ static __cold void io_ring_ctx_free(struct io_ring_ctx *ctx)
                io_wq_put_hash(ctx->hash_map);
        kfree(ctx->cancel_hash);
        kfree(ctx->dummy_ubuf);
-       kfree(ctx->io_buffers);
+       kfree(ctx->io_bl);
+       xa_destroy(&ctx->io_bl_xa);
        kfree(ctx);
 }
 
@@ -11980,6 +11998,7 @@ static int __init io_uring_init(void)
 
        /* ->buf_index is u16 */
        BUILD_BUG_ON(IORING_MAX_REG_BUFFERS >= (1u << 16));
+       BUILD_BUG_ON(BGID_ARRAY * sizeof(struct io_buffer_list) > PAGE_SIZE);
 
        /* should fit into one byte */
        BUILD_BUG_ON(SQE_VALID_FLAGS >= (1 << 8));