io_uring/net: save msghdr->msg_control for retries
[platform/kernel/linux-starfive.git] / io_uring / net.c
index b7f190c..51b0f7f 100644 (file)
@@ -47,6 +47,7 @@ struct io_connect {
        struct sockaddr __user          *addr;
        int                             addr_len;
        bool                            in_progress;
+       bool                            seen_econnaborted;
 };
 
 struct io_sr_msg {
@@ -64,6 +65,7 @@ struct io_sr_msg {
        u16                             addr_len;
        u16                             buf_group;
        void __user                     *addr;
+       void __user                     *msg_control;
        /* used only for send zerocopy */
        struct io_kiocb                 *notif;
 };
@@ -183,8 +185,8 @@ static int io_setup_async_msg(struct io_kiocb *req,
                async_msg->msg.msg_name = &async_msg->addr;
        /* if were using fast_iov, set it to the new one */
        if (iter_is_iovec(&kmsg->msg.msg_iter) && !kmsg->free_iov) {
-               size_t fast_idx = kmsg->msg.msg_iter.iov - kmsg->fast_iov;
-               async_msg->msg.msg_iter.iov = &async_msg->fast_iov[fast_idx];
+               size_t fast_idx = iter_iov(&kmsg->msg.msg_iter) - kmsg->fast_iov;
+               async_msg->msg.msg_iter.__iov = &async_msg->fast_iov[fast_idx];
        }
 
        return -EAGAIN;
@@ -194,11 +196,15 @@ static int io_sendmsg_copy_hdr(struct io_kiocb *req,
                               struct io_async_msghdr *iomsg)
 {
        struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
+       int ret;
 
        iomsg->msg.msg_name = &iomsg->addr;
        iomsg->free_iov = iomsg->fast_iov;
-       return sendmsg_copy_msghdr(&iomsg->msg, sr->umsg, sr->msg_flags,
+       ret = sendmsg_copy_msghdr(&iomsg->msg, sr->umsg, sr->msg_flags,
                                        &iomsg->free_iov);
+       /* save msg_control as sys_sendmsg() overwrites it */
+       sr->msg_control = iomsg->msg.msg_control;
+       return ret;
 }
 
 int io_send_prep_async(struct io_kiocb *req)
@@ -296,6 +302,7 @@ int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
 
        if (req_has_async_data(req)) {
                kmsg = req->async_data;
+               kmsg->msg.msg_control = sr->msg_control;
        } else {
                ret = io_sendmsg_copy_hdr(req, &iomsg);
                if (ret)
@@ -1424,7 +1431,7 @@ int io_connect_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 
        conn->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
        conn->addr_len =  READ_ONCE(sqe->addr2);
-       conn->in_progress = false;
+       conn->in_progress = conn->seen_econnaborted = false;
        return 0;
 }
 
@@ -1461,18 +1468,24 @@ int io_connect(struct io_kiocb *req, unsigned int issue_flags)
 
        ret = __sys_connect_file(req->file, &io->address,
                                        connect->addr_len, file_flags);
-       if ((ret == -EAGAIN || ret == -EINPROGRESS) && force_nonblock) {
+       if ((ret == -EAGAIN || ret == -EINPROGRESS || ret == -ECONNABORTED)
+           && force_nonblock) {
                if (ret == -EINPROGRESS) {
                        connect->in_progress = true;
-               } else {
-                       if (req_has_async_data(req))
-                               return -EAGAIN;
-                       if (io_alloc_async_data(req)) {
-                               ret = -ENOMEM;
+                       return -EAGAIN;
+               }
+               if (ret == -ECONNABORTED) {
+                       if (connect->seen_econnaborted)
                                goto out;
-                       }
-                       memcpy(req->async_data, &__io, sizeof(__io));
+                       connect->seen_econnaborted = true;
+               }
+               if (req_has_async_data(req))
+                       return -EAGAIN;
+               if (io_alloc_async_data(req)) {
+                       ret = -ENOMEM;
+                       goto out;
                }
+               memcpy(req->async_data, &__io, sizeof(__io));
                return -EAGAIN;
        }
        if (ret == -ERESTARTSYS)