io_uring: modularize io_sqe_buffer_register
author Bijan Mottahedeh <bijan.mottahedeh@oracle.com>
Wed, 6 Jan 2021 20:39:10 +0000 (12:39 -0800)
committer Jens Axboe <axboe@kernel.dk>
Mon, 1 Feb 2021 17:02:41 +0000 (10:02 -0700)
Split io_sqe_buffer_register() into two routines:

- io_sqe_buffer_register() registers a single buffer
- io_sqe_buffers_register() iterates over all user-specified buffers

Reviewed-by: Pavel Begunkov <asml.silence@gmail.com>
Signed-off-by: Bijan Mottahedeh <bijan.mottahedeh@oracle.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
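
For context, a minimal userspace sketch of the call this patch services
(assuming kernel headers that expose __NR_io_uring_register; ring setup,
buffer allocation, and error handling are elided): a single
IORING_REGISTER_BUFFERS registration of an iovec table, which
io_sqe_buffers_register() now walks, calling io_sqe_buffer_register()
once per entry.

/*
 * Sketch only: register `nr` fixed buffers with the ring at `ring_fd`.
 * Each iovec becomes one io_sqe_buffer_register() call in the kernel;
 * any failure part-way through tears down the buffers already pinned
 * via io_sqe_buffers_unregister().
 */
#include <sys/syscall.h>
#include <sys/uio.h>
#include <unistd.h>
#include <linux/io_uring.h>

static int register_fixed_buffers(int ring_fd, struct iovec *iovs,
				  unsigned int nr)
{
	return syscall(__NR_io_uring_register, ring_fd,
		       IORING_REGISTER_BUFFERS, iovs, nr);
}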
fs/io_uring.c

index fbc6d2f..ec70ba0 100644
@@ -8370,7 +8370,7 @@ static unsigned long ring_pages(unsigned sq_entries, unsigned cq_entries)
        return pages;
 }
 
-static int io_sqe_buffer_unregister(struct io_ring_ctx *ctx)
+static int io_sqe_buffers_unregister(struct io_ring_ctx *ctx)
 {
        int i, j;
 
@@ -8488,14 +8488,103 @@ static int io_buffer_account_pin(struct io_ring_ctx *ctx, struct page **pages,
        return ret;
 }
 
-static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
-                                 unsigned nr_args)
+static int io_sqe_buffer_register(struct io_ring_ctx *ctx, struct iovec *iov,
+                                 struct io_mapped_ubuf *imu,
+                                 struct page **last_hpage)
 {
        struct vm_area_struct **vmas = NULL;
        struct page **pages = NULL;
+       unsigned long off, start, end, ubuf;
+       size_t size;
+       int ret, pret, nr_pages, i;
+
+       ubuf = (unsigned long) iov->iov_base;
+       end = (ubuf + iov->iov_len + PAGE_SIZE - 1) >> PAGE_SHIFT;
+       start = ubuf >> PAGE_SHIFT;
+       nr_pages = end - start;
+
+       ret = -ENOMEM;
+
+       pages = kvmalloc_array(nr_pages, sizeof(struct page *), GFP_KERNEL);
+       if (!pages)
+               goto done;
+
+       vmas = kvmalloc_array(nr_pages, sizeof(struct vm_area_struct *),
+                             GFP_KERNEL);
+       if (!vmas)
+               goto done;
+
+       imu->bvec = kvmalloc_array(nr_pages, sizeof(struct bio_vec),
+                                  GFP_KERNEL);
+       if (!imu->bvec)
+               goto done;
+
+       ret = 0;
+       mmap_read_lock(current->mm);
+       pret = pin_user_pages(ubuf, nr_pages, FOLL_WRITE | FOLL_LONGTERM,
+                             pages, vmas);
+       if (pret == nr_pages) {
+               /* don't support file backed memory */
+               for (i = 0; i < nr_pages; i++) {
+                       struct vm_area_struct *vma = vmas[i];
+
+                       if (vma->vm_file &&
+                           !is_file_hugepages(vma->vm_file)) {
+                               ret = -EOPNOTSUPP;
+                               break;
+                       }
+               }
+       } else {
+               ret = pret < 0 ? pret : -EFAULT;
+       }
+       mmap_read_unlock(current->mm);
+       if (ret) {
+               /*
+                * if we did partial map, or found file backed vmas,
+                * release any pages we did get
+                */
+               if (pret > 0)
+                       unpin_user_pages(pages, pret);
+               kvfree(imu->bvec);
+               goto done;
+       }
+
+       ret = io_buffer_account_pin(ctx, pages, pret, imu, last_hpage);
+       if (ret) {
+               unpin_user_pages(pages, pret);
+               kvfree(imu->bvec);
+               goto done;
+       }
+
+       off = ubuf & ~PAGE_MASK;
+       size = iov->iov_len;
+       for (i = 0; i < nr_pages; i++) {
+               size_t vec_len;
+
+               vec_len = min_t(size_t, size, PAGE_SIZE - off);
+               imu->bvec[i].bv_page = pages[i];
+               imu->bvec[i].bv_len = vec_len;
+               imu->bvec[i].bv_offset = off;
+               off = 0;
+               size -= vec_len;
+       }
+       /* store original address for later verification */
+       imu->ubuf = ubuf;
+       imu->len = iov->iov_len;
+       imu->nr_bvecs = nr_pages;
+       ret = 0;
+done:
+       kvfree(pages);
+       kvfree(vmas);
+       return ret;
+}
+
+static int io_sqe_buffers_register(struct io_ring_ctx *ctx, void __user *arg,
+                                  unsigned int nr_args)
+{
+       int i, ret;
+       struct iovec iov;
        struct page *last_hpage = NULL;
-       int i, j, got_pages = 0;
-       int ret = -EINVAL;
 
        if (ctx->user_bufs)
                return -EBUSY;
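
A worked example of the pinning math in io_sqe_buffer_register() above
(illustrative values, 4 KiB pages): a buffer at ubuf = 0x1003 with
iov_len = 0x2000 gives start = 0x1003 >> 12 = 1 and
end = (0x1003 + 0x2000 + 0xfff) >> 12 = 4, so nr_pages = 3; pages 1..3
span 0x1000-0x3fff and enclose the buffer. The bvec fill then yields
bvec[0] = {offset 3, len 4093}, bvec[1] = {offset 0, len 4096}, and
bvec[2] = {offset 0, len 3}, summing to iov_len = 0x2000.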
@@ -8509,14 +8598,10 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
 
        for (i = 0; i < nr_args; i++) {
                struct io_mapped_ubuf *imu = &ctx->user_bufs[i];
-               unsigned long off, start, end, ubuf;
-               int pret, nr_pages;
-               struct iovec iov;
-               size_t size;
 
                ret = io_copy_iov(ctx, &iov, arg, i);
                if (ret)
-                       goto err;
+                       break;
 
                /*
                 * Don't impose further limits on the size and buffer
@@ -8525,103 +8610,22 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
                 */
                ret = -EFAULT;
                if (!iov.iov_base || !iov.iov_len)
-                       goto err;
+                       break;
 
                /* arbitrary limit, but we need something */
                if (iov.iov_len > SZ_1G)
-                       goto err;
-
-               ubuf = (unsigned long) iov.iov_base;
-               end = (ubuf + iov.iov_len + PAGE_SIZE - 1) >> PAGE_SHIFT;
-               start = ubuf >> PAGE_SHIFT;
-               nr_pages = end - start;
-
-               ret = 0;
-               if (!pages || nr_pages > got_pages) {
-                       kvfree(vmas);
-                       kvfree(pages);
-                       pages = kvmalloc_array(nr_pages, sizeof(struct page *),
-                                               GFP_KERNEL);
-                       vmas = kvmalloc_array(nr_pages,
-                                       sizeof(struct vm_area_struct *),
-                                       GFP_KERNEL);
-                       if (!pages || !vmas) {
-                               ret = -ENOMEM;
-                               goto err;
-                       }
-                       got_pages = nr_pages;
-               }
-
-               imu->bvec = kvmalloc_array(nr_pages, sizeof(struct bio_vec),
-                                               GFP_KERNEL);
-               ret = -ENOMEM;
-               if (!imu->bvec)
-                       goto err;
-
-               ret = 0;
-               mmap_read_lock(current->mm);
-               pret = pin_user_pages(ubuf, nr_pages,
-                                     FOLL_WRITE | FOLL_LONGTERM,
-                                     pages, vmas);
-               if (pret == nr_pages) {
-                       /* don't support file backed memory */
-                       for (j = 0; j < nr_pages; j++) {
-                               struct vm_area_struct *vma = vmas[j];
-
-                               if (vma->vm_file &&
-                                   !is_file_hugepages(vma->vm_file)) {
-                                       ret = -EOPNOTSUPP;
-                                       break;
-                               }
-                       }
-               } else {
-                       ret = pret < 0 ? pret : -EFAULT;
-               }
-               mmap_read_unlock(current->mm);
-               if (ret) {
-                       /*
-                        * if we did partial map, or found file backed vmas,
-                        * release any pages we did get
-                        */
-                       if (pret > 0)
-                               unpin_user_pages(pages, pret);
-                       kvfree(imu->bvec);
-                       goto err;
-               }
-
-               ret = io_buffer_account_pin(ctx, pages, pret, imu, &last_hpage);
-               if (ret) {
-                       unpin_user_pages(pages, pret);
-                       kvfree(imu->bvec);
-                       goto err;
-               }
+                       break;
 
-               off = ubuf & ~PAGE_MASK;
-               size = iov.iov_len;
-               for (j = 0; j < nr_pages; j++) {
-                       size_t vec_len;
-
-                       vec_len = min_t(size_t, size, PAGE_SIZE - off);
-                       imu->bvec[j].bv_page = pages[j];
-                       imu->bvec[j].bv_len = vec_len;
-                       imu->bvec[j].bv_offset = off;
-                       off = 0;
-                       size -= vec_len;
-               }
-               /* store original address for later verification */
-               imu->ubuf = ubuf;
-               imu->len = iov.iov_len;
-               imu->nr_bvecs = nr_pages;
+               ret = io_sqe_buffer_register(ctx, &iov, imu, &last_hpage);
+               if (ret)
+                       break;
 
                ctx->nr_user_bufs++;
        }
-       kvfree(pages);
-       kvfree(vmas);
-       return 0;
-err:
-       kvfree(pages);
-       kvfree(vmas);
-       io_sqe_buffer_unregister(ctx);
+
+       if (ret)
+               io_sqe_buffers_unregister(ctx);
+
        return ret;
 }
 
@@ -8675,7 +8679,7 @@ static void io_destroy_buffers(struct io_ring_ctx *ctx)
 static void io_ring_ctx_free(struct io_ring_ctx *ctx)
 {
        io_finish_async(ctx);
-       io_sqe_buffer_unregister(ctx);
+       io_sqe_buffers_unregister(ctx);
 
        if (ctx->sqo_task) {
                put_task_struct(ctx->sqo_task);
@@ -10057,13 +10061,13 @@ static int __io_uring_register(struct io_ring_ctx *ctx, unsigned opcode,
 
        switch (opcode) {
        case IORING_REGISTER_BUFFERS:
-               ret = io_sqe_buffer_register(ctx, arg, nr_args);
+               ret = io_sqe_buffers_register(ctx, arg, nr_args);
                break;
        case IORING_UNREGISTER_BUFFERS:
                ret = -EINVAL;
                if (arg || nr_args)
                        break;
-               ret = io_sqe_buffer_unregister(ctx);
+               ret = io_sqe_buffers_unregister(ctx);
                break;
        case IORING_REGISTER_FILES:
                ret = io_sqe_files_register(ctx, arg, nr_args);