io_uring: stash ctx task reference for SQPOLL
authorJens Axboe <axboe@kernel.dk>
Mon, 14 Sep 2020 16:45:53 +0000 (10:45 -0600)
committerJens Axboe <axboe@kernel.dk>
Thu, 1 Oct 2020 02:32:32 +0000 (20:32 -0600)
We can grab a reference to the task instead of stashing away the task
files_struct. This is doable without creating a circular reference
between the ring fd and the task itself.

Reviewed-by: Pavel Begunkov <asml.silence@gmail.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
fs/io_uring.c

index 73c5dbb1591d3fcde734d6feaa2144cc5794284f..d24e0322bd1d43f1fea1af63c959ca8fb2c7b6ce 100644 (file)
@@ -265,7 +265,16 @@ struct io_ring_ctx {
        /* IO offload */
        struct io_wq            *io_wq;
        struct task_struct      *sqo_thread;    /* if using sq thread polling */
-       struct mm_struct        *sqo_mm;
+
+       /*
+        * For SQPOLL usage - we hold a reference to the parent task, so we
+        * have access to the ->files
+        */
+       struct task_struct      *sqo_task;
+
+       /* Only used for accounting purposes */
+       struct mm_struct        *mm_account;
+
        wait_queue_head_t       sqo_wait;
 
        /*
@@ -969,9 +978,10 @@ static int __io_sq_thread_acquire_mm(struct io_ring_ctx *ctx)
 {
        if (!current->mm) {
                if (unlikely(!(ctx->flags & IORING_SETUP_SQPOLL) ||
-                            !mmget_not_zero(ctx->sqo_mm)))
+                            !ctx->sqo_task->mm ||
+                            !mmget_not_zero(ctx->sqo_task->mm)))
                        return -EFAULT;
-               kthread_use_mm(ctx->sqo_mm);
+               kthread_use_mm(ctx->sqo_task->mm);
        }
 
        return 0;
@@ -7591,11 +7601,11 @@ static void io_unaccount_mem(struct io_ring_ctx *ctx, unsigned long nr_pages,
        if (ctx->limit_mem)
                __io_unaccount_mem(ctx->user, nr_pages);
 
-       if (ctx->sqo_mm) {
+       if (ctx->mm_account) {
                if (acct == ACCT_LOCKED)
-                       ctx->sqo_mm->locked_vm -= nr_pages;
+                       ctx->mm_account->locked_vm -= nr_pages;
                else if (acct == ACCT_PINNED)
-                       atomic64_sub(nr_pages, &ctx->sqo_mm->pinned_vm);
+                       atomic64_sub(nr_pages, &ctx->mm_account->pinned_vm);
        }
 }
 
@@ -7610,11 +7620,11 @@ static int io_account_mem(struct io_ring_ctx *ctx, unsigned long nr_pages,
                        return ret;
        }
 
-       if (ctx->sqo_mm) {
+       if (ctx->mm_account) {
                if (acct == ACCT_LOCKED)
-                       ctx->sqo_mm->locked_vm += nr_pages;
+                       ctx->mm_account->locked_vm += nr_pages;
                else if (acct == ACCT_PINNED)
-                       atomic64_add(nr_pages, &ctx->sqo_mm->pinned_vm);
+                       atomic64_add(nr_pages, &ctx->mm_account->pinned_vm);
        }
 
        return 0;
@@ -7918,9 +7928,12 @@ static void io_ring_ctx_free(struct io_ring_ctx *ctx)
 {
        io_finish_async(ctx);
        io_sqe_buffer_unregister(ctx);
-       if (ctx->sqo_mm) {
-               mmdrop(ctx->sqo_mm);
-               ctx->sqo_mm = NULL;
+
+       if (ctx->sqo_task) {
+               put_task_struct(ctx->sqo_task);
+               ctx->sqo_task = NULL;
+               mmdrop(ctx->mm_account);
+               ctx->mm_account = NULL;
        }
 
        io_sqe_files_unregister(ctx);
@@ -8665,8 +8678,16 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p,
        ctx->user = user;
        ctx->creds = get_current_cred();
 
+       ctx->sqo_task = get_task_struct(current);
+
+       /*
+        * This is just grabbed for accounting purposes. When a process exits,
+        * the mm is exited and dropped before the files, hence we need to hang
+        * on to this mm purely for the purposes of being able to unaccount
+        * memory (locked/pinned vm). It's not used for anything else.
+        */
        mmgrab(current->mm);
-       ctx->sqo_mm = current->mm;
+       ctx->mm_account = current->mm;
 
        /*
         * Account memory _before_ installing the file descriptor. Once