io_wq: add get/put_work handlers to io_wq_create()
authorJens Axboe <axboe@kernel.dk>
Wed, 13 Nov 2019 05:31:31 +0000 (22:31 -0700)
committerJens Axboe <axboe@kernel.dk>
Wed, 13 Nov 2019 18:37:54 +0000 (11:37 -0700)
For cancellation, we need to ensure that the work item stays valid for
as long as ->cur_work is valid. Right now we can't safely dereference
the work item even under the wqe->lock, because while the ->cur_work
pointer will remain valid, the work could be completing and be freed
in parallel.

Only invoke ->get/put_work() on items we know that the caller queued
themselves. Add IO_WQ_WORK_INTERNAL for io-wq to use, which is needed
when we're queueing a flush item, for instance.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
fs/io-wq.c
fs/io-wq.h
fs/io_uring.c

index 33b14b8..26d8154 100644 (file)
@@ -106,6 +106,9 @@ struct io_wq {
        unsigned long state;
        unsigned nr_wqes;
 
+       get_work_fn *get_work;
+       put_work_fn *put_work;
+
        struct task_struct *manager;
        struct user_struct *user;
        struct mm_struct *mm;
@@ -392,7 +395,7 @@ static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash)
 static void io_worker_handle_work(struct io_worker *worker)
        __releases(wqe->lock)
 {
-       struct io_wq_work *work, *old_work;
+       struct io_wq_work *work, *old_work = NULL, *put_work = NULL;
        struct io_wqe *wqe = worker->wqe;
        struct io_wq *wq = wqe->wq;
 
@@ -424,6 +427,8 @@ static void io_worker_handle_work(struct io_worker *worker)
                        wqe->flags |= IO_WQE_FLAG_STALLED;
 
                spin_unlock_irq(&wqe->lock);
+               if (put_work && wq->put_work)
+                       wq->put_work(old_work);
                if (!work)
                        break;
 next:
@@ -444,6 +449,11 @@ next:
                if (worker->mm)
                        work->flags |= IO_WQ_WORK_HAS_MM;
 
+               if (wq->get_work && !(work->flags & IO_WQ_WORK_INTERNAL)) {
+                       put_work = work;
+                       wq->get_work(work);
+               }
+
                old_work = work;
                work->func(&work);
 
@@ -455,6 +465,12 @@ next:
                }
                if (work && work != old_work) {
                        spin_unlock_irq(&wqe->lock);
+
+                       if (put_work && wq->put_work) {
+                               wq->put_work(put_work);
+                               put_work = NULL;
+                       }
+
                        /* dependent work not hashed */
                        hash = -1U;
                        goto next;
@@ -950,13 +966,15 @@ void io_wq_flush(struct io_wq *wq)
 
                init_completion(&data.done);
                INIT_IO_WORK(&data.work, io_wq_flush_func);
+               data.work.flags |= IO_WQ_WORK_INTERNAL;
                io_wqe_enqueue(wqe, &data.work);
                wait_for_completion(&data.done);
        }
 }
 
 struct io_wq *io_wq_create(unsigned bounded, struct mm_struct *mm,
-                          struct user_struct *user)
+                          struct user_struct *user, get_work_fn *get_work,
+                          put_work_fn *put_work)
 {
        int ret = -ENOMEM, i, node;
        struct io_wq *wq;
@@ -972,6 +990,9 @@ struct io_wq *io_wq_create(unsigned bounded, struct mm_struct *mm,
                return ERR_PTR(-ENOMEM);
        }
 
+       wq->get_work = get_work;
+       wq->put_work = put_work;
+
        /* caller must already hold a reference to this */
        wq->user = user;
 
index cc50754..4b29f92 100644 (file)
@@ -10,6 +10,7 @@ enum {
        IO_WQ_WORK_NEEDS_USER   = 8,
        IO_WQ_WORK_NEEDS_FILES  = 16,
        IO_WQ_WORK_UNBOUND      = 32,
+       IO_WQ_WORK_INTERNAL     = 64,
 
        IO_WQ_HASH_SHIFT        = 24,   /* upper 8 bits are used for hash key */
 };
@@ -34,8 +35,12 @@ struct io_wq_work {
                (work)->files = NULL;                   \
        } while (0)                                     \
 
+typedef void (get_work_fn)(struct io_wq_work *);
+typedef void (put_work_fn)(struct io_wq_work *);
+
 struct io_wq *io_wq_create(unsigned bounded, struct mm_struct *mm,
-                               struct user_struct *user);
+                               struct user_struct *user,
+                               get_work_fn *get_work, put_work_fn *put_work);
 void io_wq_destroy(struct io_wq *wq);
 
 void io_wq_enqueue(struct io_wq *wq, struct io_wq_work *work);
index 99822bf..e1a3b8b 100644 (file)
@@ -3822,6 +3822,20 @@ static int io_sqe_files_update(struct io_ring_ctx *ctx, void __user *arg,
        return done ? done : err;
 }
 
+static void io_put_work(struct io_wq_work *work)
+{
+       struct io_kiocb *req = container_of(work, struct io_kiocb, work);
+
+       io_put_req(req);
+}
+
+static void io_get_work(struct io_wq_work *work)
+{
+       struct io_kiocb *req = container_of(work, struct io_kiocb, work);
+
+       refcount_inc(&req->refs);
+}
+
 static int io_sq_offload_start(struct io_ring_ctx *ctx,
                               struct io_uring_params *p)
 {
@@ -3871,7 +3885,8 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx,
 
        /* Do QD, or 4 * CPUS, whatever is smallest */
        concurrency = min(ctx->sq_entries, 4 * num_online_cpus());
-       ctx->io_wq = io_wq_create(concurrency, ctx->sqo_mm, ctx->user);
+       ctx->io_wq = io_wq_create(concurrency, ctx->sqo_mm, ctx->user,
+                                       io_get_work, io_put_work);
        if (IS_ERR(ctx->io_wq)) {
                ret = PTR_ERR(ctx->io_wq);
                ctx->io_wq = NULL;