Merge branch 'io_uring/io_uring-5.19' of https://github.com/isilence/linux into io_ur...
[platform/kernel/linux-rpi.git] / fs / io_uring.c
index 4719eae..ca6170a 100644 (file)
                        IOSQE_IO_DRAIN | IOSQE_CQE_SKIP_SUCCESS)
 
 #define IO_REQ_CLEAN_FLAGS (REQ_F_BUFFER_SELECTED | REQ_F_NEED_CLEANUP | \
-                               REQ_F_POLLED | REQ_F_CREDS | REQ_F_ASYNC_DATA)
+                               REQ_F_POLLED | REQ_F_INFLIGHT | REQ_F_CREDS | \
+                               REQ_F_ASYNC_DATA)
 
 #define IO_REQ_CLEAN_SLOW_FLAGS (REQ_F_REFCOUNT | REQ_F_LINK | REQ_F_HARDLINK |\
                                 IO_REQ_CLEAN_FLAGS)
@@ -297,8 +298,8 @@ struct io_buffer_list {
        /* below is for ring provided buffers */
        __u16 buf_nr_pages;
        __u16 nr_entries;
-       __u32 head;
-       __u32 mask;
+       __u16 head;
+       __u16 mask;
 };
 
 struct io_buffer {
@@ -540,6 +541,7 @@ struct io_uring_task {
        const struct io_ring_ctx *last;
        struct io_wq            *io_wq;
        struct percpu_counter   inflight;
+       atomic_t                inflight_tracked;
        atomic_t                in_idle;
 
        spinlock_t              task_lock;
@@ -1356,8 +1358,6 @@ static void io_clean_op(struct io_kiocb *req);
 static inline struct file *io_file_get_fixed(struct io_kiocb *req, int fd,
                                             unsigned issue_flags);
 static struct file *io_file_get_normal(struct io_kiocb *req, int fd);
-static void io_drop_inflight_file(struct io_kiocb *req);
-static bool io_assign_file(struct io_kiocb *req, unsigned int issue_flags);
 static void io_queue_sqe(struct io_kiocb *req);
 static void io_rsrc_put_work(struct work_struct *work);
 
@@ -1772,9 +1772,29 @@ static bool io_match_task(struct io_kiocb *head, struct task_struct *task,
                          bool cancel_all)
        __must_hold(&req->ctx->timeout_lock)
 {
+       struct io_kiocb *req;
+
        if (task && head->task != task)
                return false;
-       return cancel_all;
+       if (cancel_all)
+               return true;
+
+       io_for_each_link(req, head) {
+               if (req->flags & REQ_F_INFLIGHT)
+                       return true;
+       }
+       return false;
+}
+
+static bool io_match_linked(struct io_kiocb *head)
+{
+       struct io_kiocb *req;
+
+       io_for_each_link(req, head) {
+               if (req->flags & REQ_F_INFLIGHT)
+                       return true;
+       }
+       return false;
 }
 
 /*
@@ -1784,9 +1804,24 @@ static bool io_match_task(struct io_kiocb *head, struct task_struct *task,
 static bool io_match_task_safe(struct io_kiocb *head, struct task_struct *task,
                               bool cancel_all)
 {
+       bool matched;
+
        if (task && head->task != task)
                return false;
-       return cancel_all;
+       if (cancel_all)
+               return true;
+
+       if (head->flags & REQ_F_LINK_TIMEOUT) {
+               struct io_ring_ctx *ctx = head->ctx;
+
+               /* protect against races with linked timeouts */
+               spin_lock_irq(&ctx->timeout_lock);
+               matched = io_match_linked(head);
+               spin_unlock_irq(&ctx->timeout_lock);
+       } else {
+               matched = io_match_linked(head);
+       }
+       return matched;
 }
 
 static inline bool req_has_async_data(struct io_kiocb *req)
@@ -1942,6 +1977,14 @@ static inline bool io_req_ffs_set(struct io_kiocb *req)
        return req->flags & REQ_F_FIXED_FILE;
 }
 
+static inline void io_req_track_inflight(struct io_kiocb *req)
+{
+       if (!(req->flags & REQ_F_INFLIGHT)) {
+               req->flags |= REQ_F_INFLIGHT;
+               atomic_inc(&current->io_uring->inflight_tracked);
+       }
+}
+
 static struct io_kiocb *__io_prep_linked_timeout(struct io_kiocb *req)
 {
        if (WARN_ON_ONCE(!req->link))
@@ -3003,8 +3046,6 @@ static void __io_req_task_work_add(struct io_kiocb *req,
        unsigned long flags;
        bool running;
 
-       io_drop_inflight_file(req);
-
        spin_lock_irqsave(&tctx->task_lock, flags);
        wq_list_add_tail(&req->io_task_work.node, list);
        running = tctx->task_running;
@@ -3847,7 +3888,7 @@ static void __user *io_ring_buffer_select(struct io_kiocb *req, size_t *len,
 {
        struct io_uring_buf_ring *br = bl->buf_ring;
        struct io_uring_buf *buf;
-       __u32 head = bl->head;
+       __u16 head = bl->head;
 
        if (unlikely(smp_load_acquire(&br->tail) == head))
                return NULL;
@@ -3857,7 +3898,7 @@ static void __user *io_ring_buffer_select(struct io_kiocb *req, size_t *len,
                buf = &br->bufs[head];
        } else {
                int off = head & (IO_BUFFER_LIST_BUF_PER_PAGE - 1);
-               int index = head / IO_BUFFER_LIST_BUF_PER_PAGE - 1;
+               int index = head / IO_BUFFER_LIST_BUF_PER_PAGE;
                buf = page_address(bl->buf_pages[index]);
                buf += off;
        }
@@ -5458,27 +5499,24 @@ static int io_fixed_fd_install(struct io_kiocb *req, unsigned int issue_flags,
        struct io_ring_ctx *ctx = req->ctx;
        int ret;
 
+       io_ring_submit_lock(ctx, issue_flags);
+
        if (alloc_slot) {
-               io_ring_submit_lock(ctx, issue_flags);
                ret = io_file_bitmap_get(ctx);
-               if (unlikely(ret < 0)) {
-                       io_ring_submit_unlock(ctx, issue_flags);
-                       fput(file);
-                       return ret;
-               }
-
+               if (unlikely(ret < 0))
+                       goto err;
                file_slot = ret;
        } else {
                file_slot--;
        }
 
        ret = io_install_fixed_file(req, file, issue_flags, file_slot);
-       if (alloc_slot) {
-               io_ring_submit_unlock(ctx, issue_flags);
-               if (!ret)
-                       return file_slot;
-       }
-
+       if (!ret && alloc_slot)
+               ret = file_slot;
+err:
+       io_ring_submit_unlock(ctx, issue_flags);
+       if (unlikely(ret < 0))
+               fput(file);
        return ret;
 }
 
@@ -5982,7 +6020,7 @@ static int io_close(struct io_kiocb *req, unsigned int issue_flags)
        struct files_struct *files = current->files;
        struct io_close *close = &req->close;
        struct fdtable *fdt;
-       struct file *file = NULL;
+       struct file *file;
        int ret = -EBADF;
 
        if (req->close.file_slot) {
@@ -6001,7 +6039,6 @@ static int io_close(struct io_kiocb *req, unsigned int issue_flags)
                        lockdep_is_held(&files->file_lock));
        if (!file || file->f_op == &io_uring_fops) {
                spin_unlock(&files->file_lock);
-               file = NULL;
                goto err;
        }
 
@@ -6011,21 +6048,16 @@ static int io_close(struct io_kiocb *req, unsigned int issue_flags)
                return -EAGAIN;
        }
 
-       ret = __close_fd_get_file(close->fd, &file);
+       file = __close_fd_get_file(close->fd);
        spin_unlock(&files->file_lock);
-       if (ret < 0) {
-               if (ret == -ENOENT)
-                       ret = -EBADF;
+       if (!file)
                goto err;
-       }
 
        /* No ->flush() or already async, safely close from here */
        ret = filp_close(file, current->files);
 err:
        if (ret < 0)
                req_set_fail(req);
-       if (file)
-               fput(file);
        __io_req_complete(req, issue_flags, ret, 0);
        return 0;
 }
@@ -6927,10 +6959,6 @@ static int io_poll_check_events(struct io_kiocb *req, bool *locked)
 
                if (!req->cqe.res) {
                        struct poll_table_struct pt = { ._key = req->apoll_events };
-                       unsigned flags = locked ? 0 : IO_URING_F_UNLOCKED;
-
-                       if (unlikely(!io_assign_file(req, flags)))
-                               return -EBADF;
                        req->cqe.res = vfs_poll(req->file, &pt) & req->apoll_events;
                }
 
@@ -8338,6 +8366,11 @@ static void io_clean_op(struct io_kiocb *req)
                kfree(req->apoll);
                req->apoll = NULL;
        }
+       if (req->flags & REQ_F_INFLIGHT) {
+               struct io_uring_task *tctx = req->task->io_uring;
+
+               atomic_dec(&tctx->inflight_tracked);
+       }
        if (req->flags & REQ_F_CREDS)
                put_cred(req->creds);
        if (req->flags & REQ_F_ASYNC_DATA) {
@@ -8644,19 +8677,6 @@ out:
        return file;
 }
 
-/*
- * Drop the file for requeue operations. Only used of req->file is the
- * io_uring descriptor itself.
- */
-static void io_drop_inflight_file(struct io_kiocb *req)
-{
-       if (unlikely(req->flags & REQ_F_INFLIGHT)) {
-               fput(req->file);
-               req->file = NULL;
-               req->flags &= ~REQ_F_INFLIGHT;
-       }
-}
-
 static struct file *io_file_get_normal(struct io_kiocb *req, int fd)
 {
        struct file *file = fget(fd);
@@ -8665,7 +8685,7 @@ static struct file *io_file_get_normal(struct io_kiocb *req, int fd)
 
        /* we don't allow fixed io_uring files */
        if (file && file->f_op == &io_uring_fops)
-               req->flags |= REQ_F_INFLIGHT;
+               io_req_track_inflight(req);
        return file;
 }
 
@@ -10197,21 +10217,19 @@ static int io_queue_rsrc_removal(struct io_rsrc_data *data, unsigned idx,
 
 static int io_install_fixed_file(struct io_kiocb *req, struct file *file,
                                 unsigned int issue_flags, u32 slot_index)
+       __must_hold(&req->ctx->uring_lock)
 {
        struct io_ring_ctx *ctx = req->ctx;
        bool needs_switch = false;
        struct io_fixed_file *file_slot;
-       int ret = -EBADF;
+       int ret;
 
-       io_ring_submit_lock(ctx, issue_flags);
        if (file->f_op == &io_uring_fops)
-               goto err;
-       ret = -ENXIO;
+               return -EBADF;
        if (!ctx->file_data)
-               goto err;
-       ret = -EINVAL;
+               return -ENXIO;
        if (slot_index >= ctx->nr_user_files)
-               goto err;
+               return -EINVAL;
 
        slot_index = array_index_nospec(slot_index, ctx->nr_user_files);
        file_slot = io_fixed_file_slot(&ctx->file_table, slot_index);
@@ -10242,7 +10260,6 @@ static int io_install_fixed_file(struct io_kiocb *req, struct file *file,
 err:
        if (needs_switch)
                io_rsrc_node_switch(ctx, ctx->file_data);
-       io_ring_submit_unlock(ctx, issue_flags);
        if (ret)
                fput(file);
        return ret;
@@ -10440,6 +10457,7 @@ static __cold int io_uring_alloc_task_context(struct task_struct *task,
        xa_init(&tctx->xa);
        init_waitqueue_head(&tctx->wait);
        atomic_set(&tctx->in_idle, 0);
+       atomic_set(&tctx->inflight_tracked, 0);
        task->io_uring = tctx;
        spin_lock_init(&tctx->task_lock);
        INIT_WQ_LIST(&tctx->task_list);
@@ -11678,7 +11696,7 @@ static __cold void io_uring_clean_tctx(struct io_uring_task *tctx)
 static s64 tctx_inflight(struct io_uring_task *tctx, bool tracked)
 {
        if (tracked)
-               return 0;
+               return atomic_read(&tctx->inflight_tracked);
        return percpu_counter_sum(&tctx->inflight);
 }
 
@@ -12054,14 +12072,14 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
                        return -EINVAL;
                fd = array_index_nospec(fd, IO_RINGFD_REG_MAX);
                f.file = tctx->registered_rings[fd];
-               if (unlikely(!f.file))
-                       return -EBADF;
+               f.flags = 0;
        } else {
                f = fdget(fd);
-               if (unlikely(!f.file))
-                       return -EBADF;
        }
 
+       if (unlikely(!f.file))
+               return -EBADF;
+
        ret = -EOPNOTSUPP;
        if (unlikely(f.file->f_op != &io_uring_fops))
                goto out_fput;
@@ -12159,8 +12177,7 @@ iopoll_locked:
 out:
        percpu_ref_put(&ctx->refs);
 out_fput:
-       if (!(flags & IORING_ENTER_REGISTERED_RING))
-               fdput(f);
+       fdput(f);
        return ret;
 }
 
@@ -13010,6 +13027,10 @@ static int io_register_pbuf_ring(struct io_ring_ctx *ctx, void __user *arg)
        if (!is_power_of_2(reg.ring_entries))
                return -EINVAL;
 
+       /* cannot disambiguate full vs empty due to head/tail size */
+       if (reg.ring_entries >= 65536)
+               return -EINVAL;
+
        if (unlikely(reg.bgid < BGID_ARRAY && !ctx->io_bl)) {
                int ret = io_init_bl_list(ctx);
                if (ret)