io_uring: shut io_prep_async_work warning
[platform/kernel/linux-starfive.git] / io_uring / kbuf.c
index db5f189..79c2545 100644 (file)
@@ -137,7 +137,8 @@ static void __user *io_ring_buffer_select(struct io_kiocb *req, size_t *len,
                return NULL;
 
        head &= bl->mask;
-       if (head < IO_BUFFER_LIST_BUF_PER_PAGE) {
+       /* mmaped buffers are always contig */
+       if (bl->is_mmap || head < IO_BUFFER_LIST_BUF_PER_PAGE) {
                buf = &br->bufs[head];
        } else {
                int off = head & (IO_BUFFER_LIST_BUF_PER_PAGE - 1);
@@ -214,15 +215,27 @@ static int __io_remove_buffers(struct io_ring_ctx *ctx,
        if (!nbufs)
                return 0;
 
-       if (bl->is_mapped && bl->buf_nr_pages) {
-               int j;
-
+       if (bl->is_mapped) {
                i = bl->buf_ring->tail - bl->head;
-               for (j = 0; j < bl->buf_nr_pages; j++)
-                       unpin_user_page(bl->buf_pages[j]);
-               kvfree(bl->buf_pages);
-               bl->buf_pages = NULL;
-               bl->buf_nr_pages = 0;
+               if (bl->is_mmap) {
+                       if (bl->buf_ring) {
+                               struct page *page;
+
+                               page = virt_to_head_page(bl->buf_ring);
+                               if (put_page_testzero(page))
+                                       free_compound_page(page);
+                               bl->buf_ring = NULL;
+                       }
+                       bl->is_mmap = 0;
+               } else if (bl->buf_nr_pages) {
+                       int j;
+
+                       for (j = 0; j < bl->buf_nr_pages; j++)
+                               unpin_user_page(bl->buf_pages[j]);
+                       kvfree(bl->buf_pages);
+                       bl->buf_pages = NULL;
+                       bl->buf_nr_pages = 0;
+               }
                /* make sure it's seen as empty */
                INIT_LIST_HEAD(&bl->buf_list);
                bl->is_mapped = 0;
@@ -478,10 +491,47 @@ static int io_pin_pbuf_ring(struct io_uring_buf_reg *reg,
                return PTR_ERR(pages);
 
        br = page_address(pages[0]);
+#ifdef SHM_COLOUR
+       /*
+        * On platforms that have specific aliasing requirements, SHM_COLOUR
+        * is set and we must guarantee that the kernel and user side align
+        * nicely. We cannot do that if IOU_PBUF_RING_MMAP isn't set and
+        * the application mmap's the provided ring buffer. Fail the request
+        * if we, by chance, don't end up with aligned addresses. The app
+        * should use IOU_PBUF_RING_MMAP instead, and liburing will handle
+        * this transparently.
+        */
+       if ((reg->ring_addr | (unsigned long) br) & (SHM_COLOUR - 1)) {
+               int i;
+
+               for (i = 0; i < nr_pages; i++)
+                       unpin_user_page(pages[i]);
+               return -EINVAL;
+       }
+#endif
        bl->buf_pages = pages;
        bl->buf_nr_pages = nr_pages;
        bl->buf_ring = br;
        bl->is_mapped = 1;
+       bl->is_mmap = 0;
+       return 0;
+}
+
+static int io_alloc_pbuf_ring(struct io_uring_buf_reg *reg,
+                             struct io_buffer_list *bl)
+{
+       gfp_t gfp = GFP_KERNEL_ACCOUNT | __GFP_ZERO | __GFP_NOWARN | __GFP_COMP;
+       size_t ring_size;
+       void *ptr;
+
+       ring_size = reg->ring_entries * sizeof(struct io_uring_buf_ring);
+       ptr = (void *) __get_free_pages(gfp, get_order(ring_size));
+       if (!ptr)
+               return -ENOMEM;
+
+       bl->buf_ring = ptr;
+       bl->is_mapped = 1;
+       bl->is_mmap = 1;
        return 0;
 }
 
@@ -494,12 +544,20 @@ int io_register_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
        if (copy_from_user(&reg, arg, sizeof(reg)))
                return -EFAULT;
 
-       if (reg.pad || reg.resv[0] || reg.resv[1] || reg.resv[2])
+       if (reg.resv[0] || reg.resv[1] || reg.resv[2])
                return -EINVAL;
-       if (!reg.ring_addr)
-               return -EFAULT;
-       if (reg.ring_addr & ~PAGE_MASK)
+       if (reg.flags & ~IOU_PBUF_RING_MMAP)
                return -EINVAL;
+       if (!(reg.flags & IOU_PBUF_RING_MMAP)) {
+               if (!reg.ring_addr)
+                       return -EFAULT;
+               if (reg.ring_addr & ~PAGE_MASK)
+                       return -EINVAL;
+       } else {
+               if (reg.ring_addr)
+                       return -EINVAL;
+       }
+
        if (!is_power_of_2(reg.ring_entries))
                return -EINVAL;
 
@@ -524,17 +582,21 @@ int io_register_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
                        return -ENOMEM;
        }
 
-       ret = io_pin_pbuf_ring(&reg, bl);
-       if (ret) {
-               kfree(free_bl);
-               return ret;
-       }
+       if (!(reg.flags & IOU_PBUF_RING_MMAP))
+               ret = io_pin_pbuf_ring(&reg, bl);
+       else
+               ret = io_alloc_pbuf_ring(&reg, bl);
 
-       bl->nr_entries = reg.ring_entries;
-       bl->mask = reg.ring_entries - 1;
+       if (!ret) {
+               bl->nr_entries = reg.ring_entries;
+               bl->mask = reg.ring_entries - 1;
 
-       io_buffer_add_list(ctx, bl, reg.bgid);
-       return 0;
+               io_buffer_add_list(ctx, bl, reg.bgid);
+               return 0;
+       }
+
+       kfree(free_bl);
+       return ret;
 }
 
 int io_unregister_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
@@ -544,7 +606,9 @@ int io_unregister_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
 
        if (copy_from_user(&reg, arg, sizeof(reg)))
                return -EFAULT;
-       if (reg.pad || reg.resv[0] || reg.resv[1] || reg.resv[2])
+       if (reg.resv[0] || reg.resv[1] || reg.resv[2])
+               return -EINVAL;
+       if (reg.flags)
                return -EINVAL;
 
        bl = io_buffer_get_list(ctx, reg.bgid);
@@ -560,3 +624,14 @@ int io_unregister_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
        }
        return 0;
 }
+
+void *io_pbuf_get_address(struct io_ring_ctx *ctx, unsigned long bgid)
+{
+       struct io_buffer_list *bl;
+
+       bl = io_buffer_get_list(ctx, bgid);
+       if (!bl || !bl->is_mmap)
+               return NULL;
+
+       return bl->buf_ring;
+}