Merge branch 'CR_2506_515_usb_wifi_jianlong' into 'vf2-515-devel'
[platform/kernel/linux-starfive.git] / fs / io_uring.c
index 6b9e702..bc18af5 100644 (file)
@@ -456,6 +456,8 @@ struct io_ring_ctx {
                struct work_struct              exit_work;
                struct list_head                tctx_list;
                struct completion               ref_comp;
+               u32                             iowq_limits[2];
+               bool                            iowq_limits_set;
        };
 };
 
@@ -1368,11 +1370,6 @@ static void io_req_track_inflight(struct io_kiocb *req)
        }
 }
 
-static inline void io_unprep_linked_timeout(struct io_kiocb *req)
-{
-       req->flags &= ~REQ_F_LINK_TIMEOUT;
-}
-
 static struct io_kiocb *__io_prep_linked_timeout(struct io_kiocb *req)
 {
        if (WARN_ON_ONCE(!req->link))
@@ -2949,7 +2946,7 @@ static void kiocb_done(struct kiocb *kiocb, ssize_t ret,
                        struct io_ring_ctx *ctx = req->ctx;
 
                        req_set_fail(req);
-                       if (issue_flags & IO_URING_F_NONBLOCK) {
+                       if (!(issue_flags & IO_URING_F_NONBLOCK)) {
                                mutex_lock(&ctx->uring_lock);
                                __io_req_complete(req, issue_flags, ret, cflags);
                                mutex_unlock(&ctx->uring_lock);
@@ -6983,7 +6980,7 @@ issue_sqe:
                switch (io_arm_poll_handler(req)) {
                case IO_APOLL_READY:
                        if (linked_timeout)
-                               io_unprep_linked_timeout(req);
+                               io_queue_linked_timeout(linked_timeout);
                        goto issue_sqe;
                case IO_APOLL_ABORTED:
                        /*
@@ -9638,7 +9635,16 @@ static int __io_uring_add_tctx_node(struct io_ring_ctx *ctx)
                ret = io_uring_alloc_task_context(current, ctx);
                if (unlikely(ret))
                        return ret;
+
                tctx = current->io_uring;
+               if (ctx->iowq_limits_set) {
+                       unsigned int limits[2] = { ctx->iowq_limits[0],
+                                                  ctx->iowq_limits[1], };
+
+                       ret = io_wq_max_workers(tctx->io_wq, limits);
+                       if (ret)
+                               return ret;
+               }
        }
        if (!xa_load(&tctx->xa, (unsigned long)ctx)) {
                node = kmalloc(sizeof(*node), GFP_KERNEL);
@@ -10643,7 +10649,9 @@ static int io_unregister_iowq_aff(struct io_ring_ctx *ctx)
 
 static int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
                                        void __user *arg)
+       __must_hold(&ctx->uring_lock)
 {
+       struct io_tctx_node *node;
        struct io_uring_task *tctx = NULL;
        struct io_sq_data *sqd = NULL;
        __u32 new_count[2];
@@ -10674,13 +10682,19 @@ static int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
                tctx = current->io_uring;
        }
 
-       ret = -EINVAL;
-       if (!tctx || !tctx->io_wq)
-               goto err;
+       BUILD_BUG_ON(sizeof(new_count) != sizeof(ctx->iowq_limits));
 
-       ret = io_wq_max_workers(tctx->io_wq, new_count);
-       if (ret)
-               goto err;
+       memcpy(ctx->iowq_limits, new_count, sizeof(new_count));
+       ctx->iowq_limits_set = true;
+
+       ret = -EINVAL;
+       if (tctx && tctx->io_wq) {
+               ret = io_wq_max_workers(tctx->io_wq, new_count);
+               if (ret)
+                       goto err;
+       } else {
+               memset(new_count, 0, sizeof(new_count));
+       }
 
        if (sqd) {
                mutex_unlock(&sqd->lock);
@@ -10690,6 +10704,22 @@ static int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
        if (copy_to_user(arg, new_count, sizeof(new_count)))
                return -EFAULT;
 
+       /* that's it for SQPOLL, only the SQPOLL task creates requests */
+       if (sqd)
+               return 0;
+
+       /* now propagate the restriction to all registered users */
+       list_for_each_entry(node, &ctx->tctx_list, ctx_node) {
+               struct io_uring_task *tctx = node->task->io_uring;
+
+               if (WARN_ON_ONCE(!tctx->io_wq))
+                       continue;
+
+               for (i = 0; i < ARRAY_SIZE(new_count); i++)
+                       new_count[i] = ctx->iowq_limits[i];
+               /* ignore errors, it always returns zero anyway */
+               (void)io_wq_max_workers(tctx->io_wq, new_count);
+       }
        return 0;
 err:
        if (sqd) {