io_uring: iopoll protect complete_post
authorPavel Begunkov <asml.silence@gmail.com>
Wed, 23 Nov 2022 11:33:41 +0000 (11:33 +0000)
committerJens Axboe <axboe@kernel.dk>
Wed, 23 Nov 2022 17:45:31 +0000 (10:45 -0700)
io_req_complete_post() may be used by iopoll enabled rings, grab locks
in this case. That requires to pass issue_flags to propagate the locking
state.

Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
Link: https://lore.kernel.org/r/cc6d854065c57c838ca8e8806f707a226b70fd2d.1669203009.git.asml.silence@gmail.com
Signed-off-by: Jens Axboe <axboe@kernel.dk>
io_uring/io_uring.c
io_uring/io_uring.h
io_uring/kbuf.c
io_uring/poll.c
io_uring/uring_cmd.c

index a0c71a2..cc27413 100644 (file)
@@ -814,7 +814,7 @@ bool io_post_aux_cqe(struct io_ring_ctx *ctx,
        return filled;
 }
 
-void io_req_complete_post(struct io_kiocb *req)
+static void __io_req_complete_post(struct io_kiocb *req)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
@@ -850,9 +850,18 @@ void io_req_complete_post(struct io_kiocb *req)
        io_cq_unlock_post(ctx);
 }
 
-inline void __io_req_complete(struct io_kiocb *req, unsigned issue_flags)
+void io_req_complete_post(struct io_kiocb *req, unsigned issue_flags)
 {
-       io_req_complete_post(req);
+       if (!(issue_flags & IO_URING_F_UNLOCKED) ||
+           !(req->ctx->flags & IORING_SETUP_IOPOLL)) {
+               __io_req_complete_post(req);
+       } else {
+               struct io_ring_ctx *ctx = req->ctx;
+
+               mutex_lock(&ctx->uring_lock);
+               __io_req_complete_post(req);
+               mutex_unlock(&ctx->uring_lock);
+       }
 }
 
 void io_req_complete_failed(struct io_kiocb *req, s32 res)
@@ -866,7 +875,7 @@ void io_req_complete_failed(struct io_kiocb *req, s32 res)
        io_req_set_res(req, res, io_put_kbuf(req, IO_URING_F_UNLOCKED));
        if (def->fail)
                def->fail(req);
-       io_req_complete_post(req);
+       io_req_complete_post(req, 0);
 }
 
 /*
@@ -1450,7 +1459,7 @@ void io_req_task_complete(struct io_kiocb *req, bool *locked)
        if (*locked)
                io_req_complete_defer(req);
        else
-               io_req_complete_post(req);
+               io_req_complete_post_tw(req, locked);
 }
 
 /*
@@ -1718,7 +1727,7 @@ static int io_issue_sqe(struct io_kiocb *req, unsigned int issue_flags)
                if (issue_flags & IO_URING_F_COMPLETE_DEFER)
                        io_req_complete_defer(req);
                else
-                       io_req_complete_post(req);
+                       io_req_complete_post(req, issue_flags);
        } else if (ret != IOU_ISSUE_SKIP_COMPLETE)
                return ret;
 
index 222af88..b5b80bf 100644 (file)
@@ -31,14 +31,20 @@ int io_run_task_work_sig(struct io_ring_ctx *ctx);
 int __io_run_local_work(struct io_ring_ctx *ctx, bool *locked);
 int io_run_local_work(struct io_ring_ctx *ctx);
 void io_req_complete_failed(struct io_kiocb *req, s32 res);
-void __io_req_complete(struct io_kiocb *req, unsigned issue_flags);
-void io_req_complete_post(struct io_kiocb *req);
+void io_req_complete_post(struct io_kiocb *req, unsigned issue_flags);
 bool io_post_aux_cqe(struct io_ring_ctx *ctx, u64 user_data, s32 res, u32 cflags,
                     bool allow_overflow);
 bool io_fill_cqe_aux(struct io_ring_ctx *ctx, u64 user_data, s32 res, u32 cflags,
                     bool allow_overflow);
 void __io_commit_cqring_flush(struct io_ring_ctx *ctx);
 
+static inline void io_req_complete_post_tw(struct io_kiocb *req, bool *locked)
+{
+       unsigned flags = *locked ? 0 : IO_URING_F_UNLOCKED;
+
+       io_req_complete_post(req, flags);
+}
+
 struct page **io_pin_pages(unsigned long ubuf, unsigned long len, int *npages);
 
 struct file *io_file_get_normal(struct io_kiocb *req, int fd);
index e2c4688..e8150ed 100644 (file)
@@ -311,7 +311,7 @@ int io_remove_buffers(struct io_kiocb *req, unsigned int issue_flags)
 
        /* complete before unlock, IOPOLL may need the lock */
        io_req_set_res(req, ret, 0);
-       __io_req_complete(req, issue_flags);
+       io_req_complete_post(req, 0);
        io_ring_submit_unlock(ctx, issue_flags);
        return IOU_ISSUE_SKIP_COMPLETE;
 }
@@ -462,7 +462,7 @@ err:
                req_set_fail(req);
        /* complete before unlock, IOPOLL may need the lock */
        io_req_set_res(req, ret, 0);
-       __io_req_complete(req, issue_flags);
+       io_req_complete_post(req, 0);
        io_ring_submit_unlock(ctx, issue_flags);
        return IOU_ISSUE_SKIP_COMPLETE;
 }
index cd4d98d..4624e5e 100644 (file)
@@ -312,7 +312,7 @@ static void io_apoll_task_func(struct io_kiocb *req, bool *locked)
        io_poll_tw_hash_eject(req, locked);
 
        if (ret == IOU_POLL_REMOVE_POLL_USE_RES)
-               io_req_complete_post(req);
+               io_req_complete_post_tw(req, locked);
        else if (ret == IOU_POLL_DONE)
                io_req_task_submit(req, locked);
        else
index e50de0b..446a189 100644 (file)
@@ -56,7 +56,7 @@ void io_uring_cmd_done(struct io_uring_cmd *ioucmd, ssize_t ret, ssize_t res2)
                /* order with io_iopoll_req_issued() checking ->iopoll_complete */
                smp_store_release(&req->iopoll_completed, 1);
        else
-               __io_req_complete(req, 0);
+               io_req_complete_post(req, 0);
 }
 EXPORT_SYMBOL_GPL(io_uring_cmd_done);