io_uring: check for valid register opcode earlier
[platform/kernel/linux-starfive.git] / io_uring / tctx.c
index 7f97d97..4324b1c 100644 (file)
@@ -91,32 +91,12 @@ __cold int io_uring_alloc_task_context(struct task_struct *task,
        return 0;
 }
 
-static int io_register_submitter(struct io_ring_ctx *ctx)
-{
-       int ret = 0;
-
-       mutex_lock(&ctx->uring_lock);
-       if (!ctx->submitter_task)
-               ctx->submitter_task = get_task_struct(current);
-       else if (ctx->submitter_task != current)
-               ret = -EEXIST;
-       mutex_unlock(&ctx->uring_lock);
-
-       return ret;
-}
-
-int __io_uring_add_tctx_node(struct io_ring_ctx *ctx, bool submitter)
+int __io_uring_add_tctx_node(struct io_ring_ctx *ctx)
 {
        struct io_uring_task *tctx = current->io_uring;
        struct io_tctx_node *node;
        int ret;
 
-       if ((ctx->flags & IORING_SETUP_SINGLE_ISSUER) && submitter) {
-               ret = io_register_submitter(ctx);
-               if (ret)
-                       return ret;
-       }
-
        if (unlikely(!tctx)) {
                ret = io_uring_alloc_task_context(current, ctx);
                if (unlikely(ret))
@@ -150,8 +130,22 @@ int __io_uring_add_tctx_node(struct io_ring_ctx *ctx, bool submitter)
                list_add(&node->ctx_node, &ctx->tctx_list);
                mutex_unlock(&ctx->uring_lock);
        }
-       if (submitter)
-               tctx->last = ctx;
+       return 0;
+}
+
+int __io_uring_add_tctx_node_from_submit(struct io_ring_ctx *ctx)
+{
+       int ret;
+
+       if (ctx->flags & IORING_SETUP_SINGLE_ISSUER
+           && ctx->submitter_task != current)
+               return -EEXIST;
+
+       ret = __io_uring_add_tctx_node(ctx);
+       if (ret)
+               return ret;
+
+       current->io_uring->last = ctx;
        return 0;
 }
 
@@ -259,7 +253,7 @@ int io_ringfd_register(struct io_ring_ctx *ctx, void __user *__arg,
                return -EINVAL;
 
        mutex_unlock(&ctx->uring_lock);
-       ret = __io_uring_add_tctx_node(ctx, false);
+       ret = __io_uring_add_tctx_node(ctx);
        mutex_lock(&ctx->uring_lock);
        if (ret)
                return ret;