io_uring/rsrc: rename rsrc_list
[platform/kernel/linux-starfive.git] / io_uring / poll.c
index 8339a92..c90e47d 100644 (file)
@@ -51,6 +51,9 @@ struct io_poll_table {
 
 #define IO_WQE_F_DOUBLE                1
 
+static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync,
+                       void *key);
+
 static inline struct io_kiocb *wqe_to_req(struct wait_queue_entry *wqe)
 {
        unsigned long priv = (unsigned long)wqe->private;
@@ -145,7 +148,7 @@ static void io_poll_req_insert_locked(struct io_kiocb *req)
        hlist_add_head(&req->hash_node, &table->hbs[index].list);
 }
 
-static void io_poll_tw_hash_eject(struct io_kiocb *req, bool *locked)
+static void io_poll_tw_hash_eject(struct io_kiocb *req, struct io_tw_state *ts)
 {
        struct io_ring_ctx *ctx = req->ctx;
 
@@ -156,7 +159,7 @@ static void io_poll_tw_hash_eject(struct io_kiocb *req, bool *locked)
                 * already grabbed the mutex for us, but there is a chance it
                 * failed.
                 */
-               io_tw_lock(ctx, locked);
+               io_tw_lock(ctx, ts);
                hash_del(&req->hash_node);
                req->flags &= ~REQ_F_HASH_LOCKED;
        } else {
@@ -164,15 +167,14 @@ static void io_poll_tw_hash_eject(struct io_kiocb *req, bool *locked)
        }
 }
 
-static void io_init_poll_iocb(struct io_poll *poll, __poll_t events,
-                             wait_queue_func_t wake_func)
+static void io_init_poll_iocb(struct io_poll *poll, __poll_t events)
 {
        poll->head = NULL;
 #define IO_POLL_UNMASK (EPOLLERR|EPOLLHUP|EPOLLNVAL|EPOLLRDHUP)
        /* mask in events that we always want/need */
        poll->events = events | IO_POLL_UNMASK;
        INIT_LIST_HEAD(&poll->wait.entry);
-       init_waitqueue_func_entry(&poll->wait, wake_func);
+       init_waitqueue_func_entry(&poll->wait, io_poll_wake);
 }
 
 static inline void io_poll_remove_entry(struct io_poll *poll)
@@ -236,7 +238,7 @@ enum {
  * req->cqe.res. IOU_POLL_REMOVE_POLL_USE_RES indicates to remove multishot
  * poll and that the result is stored in req->cqe.
  */
-static int io_poll_check_events(struct io_kiocb *req, bool *locked)
+static int io_poll_check_events(struct io_kiocb *req, struct io_tw_state *ts)
 {
        int v;
 
@@ -298,13 +300,13 @@ static int io_poll_check_events(struct io_kiocb *req, bool *locked)
                        __poll_t mask = mangle_poll(req->cqe.res &
                                                    req->apoll_events);
 
-                       if (!io_aux_cqe(req->ctx, *locked, req->cqe.user_data,
+                       if (!io_aux_cqe(req->ctx, ts->locked, req->cqe.user_data,
                                        mask, IORING_CQE_F_MORE, false)) {
                                io_req_set_res(req, mask, 0);
                                return IOU_POLL_REMOVE_POLL_USE_RES;
                        }
                } else {
-                       int ret = io_poll_issue(req, locked);
+                       int ret = io_poll_issue(req, ts);
                        if (ret == IOU_STOP_MULTISHOT)
                                return IOU_POLL_REMOVE_POLL_USE_RES;
                        if (ret < 0)
@@ -324,15 +326,15 @@ static int io_poll_check_events(struct io_kiocb *req, bool *locked)
        return IOU_POLL_NO_ACTION;
 }
 
-static void io_poll_task_func(struct io_kiocb *req, bool *locked)
+static void io_poll_task_func(struct io_kiocb *req, struct io_tw_state *ts)
 {
        int ret;
 
-       ret = io_poll_check_events(req, locked);
+       ret = io_poll_check_events(req, ts);
        if (ret == IOU_POLL_NO_ACTION)
                return;
        io_poll_remove_entries(req);
-       io_poll_tw_hash_eject(req, locked);
+       io_poll_tw_hash_eject(req, ts);
 
        if (req->opcode == IORING_OP_POLL_ADD) {
                if (ret == IOU_POLL_DONE) {
@@ -341,7 +343,7 @@ static void io_poll_task_func(struct io_kiocb *req, bool *locked)
                        poll = io_kiocb_to_cmd(req, struct io_poll);
                        req->cqe.res = mangle_poll(req->cqe.res & poll->events);
                } else if (ret == IOU_POLL_REISSUE) {
-                       io_req_task_submit(req, locked);
+                       io_req_task_submit(req, ts);
                        return;
                } else if (ret != IOU_POLL_REMOVE_POLL_USE_RES) {
                        req->cqe.res = ret;
@@ -349,14 +351,14 @@ static void io_poll_task_func(struct io_kiocb *req, bool *locked)
                }
 
                io_req_set_res(req, req->cqe.res, 0);
-               io_req_task_complete(req, locked);
+               io_req_task_complete(req, ts);
        } else {
-               io_tw_lock(req->ctx, locked);
+               io_tw_lock(req->ctx, ts);
 
                if (ret == IOU_POLL_REMOVE_POLL_USE_RES)
-                       io_req_task_complete(req, locked);
+                       io_req_task_complete(req, ts);
                else if (ret == IOU_POLL_DONE || ret == IOU_POLL_REISSUE)
-                       io_req_task_submit(req, locked);
+                       io_req_task_submit(req, ts);
                else
                        io_req_defer_failed(req, ret);
        }
@@ -508,7 +510,7 @@ static void __io_queue_proc(struct io_poll *poll, struct io_poll_table *pt,
 
                /* mark as double wq entry */
                wqe_private |= IO_WQE_F_DOUBLE;
-               io_init_poll_iocb(poll, first->events, first->wait.func);
+               io_init_poll_iocb(poll, first->events);
                if (!io_poll_double_prepare(req)) {
                        /* the request is completing, just back off */
                        kfree(poll);
@@ -569,7 +571,7 @@ static int __io_arm_poll_handler(struct io_kiocb *req,
 
        INIT_HLIST_NODE(&req->hash_node);
        req->work.cancel_seq = atomic_read(&ctx->cancel_seq);
-       io_init_poll_iocb(poll, mask, io_poll_wake);
+       io_init_poll_iocb(poll, mask);
        poll->file = req->file;
        req->apoll_events = poll->events;
 
@@ -650,6 +652,14 @@ static void io_async_queue_proc(struct file *file, struct wait_queue_head *head,
        __io_queue_proc(&apoll->poll, pt, head, &apoll->double_poll);
 }
 
+/*
+ * We can't reliably detect loops in repeated poll triggers and issue
+ * subsequently failing. But rather than fail these immediately, allow a
+ * certain amount of retries before we give up. Given that this condition
+ * should _rarely_ trigger even once, we should be fine with a larger value.
+ */
+#define APOLL_MAX_RETRY                128
+
 static struct async_poll *io_req_alloc_apoll(struct io_kiocb *req,
                                             unsigned issue_flags)
 {
@@ -665,14 +675,18 @@ static struct async_poll *io_req_alloc_apoll(struct io_kiocb *req,
                if (entry == NULL)
                        goto alloc_apoll;
                apoll = container_of(entry, struct async_poll, cache);
+               apoll->poll.retries = APOLL_MAX_RETRY;
        } else {
 alloc_apoll:
                apoll = kmalloc(sizeof(*apoll), GFP_ATOMIC);
                if (unlikely(!apoll))
                        return NULL;
+               apoll->poll.retries = APOLL_MAX_RETRY;
        }
        apoll->double_poll = NULL;
        req->apoll = apoll;
+       if (unlikely(!--apoll->poll.retries))
+               return NULL;
        return apoll;
 }
 
@@ -694,8 +708,6 @@ int io_arm_poll_handler(struct io_kiocb *req, unsigned issue_flags)
                return IO_APOLL_ABORTED;
        if (!file_can_poll(req->file))
                return IO_APOLL_ABORTED;
-       if ((req->flags & (REQ_F_POLLED|REQ_F_PARTIAL_IO)) == REQ_F_POLLED)
-               return IO_APOLL_ABORTED;
        if (!(req->flags & REQ_F_APOLL_MULTISHOT))
                mask |= EPOLLONESHOT;
 
@@ -714,6 +726,7 @@ int io_arm_poll_handler(struct io_kiocb *req, unsigned issue_flags)
        apoll = io_req_alloc_apoll(req, issue_flags);
        if (!apoll)
                return IO_APOLL_ABORTED;
+       req->flags &= ~(REQ_F_SINGLE_POLL | REQ_F_DOUBLE_POLL);
        req->flags |= REQ_F_POLLED;
        ipt.pt._qproc = io_async_queue_proc;
 
@@ -964,7 +977,7 @@ int io_poll_remove(struct io_kiocb *req, unsigned int issue_flags)
        struct io_hash_bucket *bucket;
        struct io_kiocb *preq;
        int ret2, ret = 0;
-       bool locked;
+       struct io_tw_state ts = {};
 
        preq = io_poll_find(ctx, true, &cd, &ctx->cancel_table, &bucket);
        ret2 = io_poll_disarm(preq);
@@ -1014,8 +1027,8 @@ found:
 
        req_set_fail(preq);
        io_req_set_res(preq, -ECANCELED, 0);
-       locked = !(issue_flags & IO_URING_F_UNLOCKED);
-       io_req_task_complete(preq, &locked);
+       ts.locked = !(issue_flags & IO_URING_F_UNLOCKED);
+       io_req_task_complete(preq, &ts);
 out:
        if (ret < 0) {
                req_set_fail(req);