Merge tag 'io_uring-6.0-2022-08-13' of git://git.kernel.dk/linux-block
[platform/kernel/linux-starfive.git] / io_uring / net.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/kernel.h>
3 #include <linux/errno.h>
4 #include <linux/file.h>
5 #include <linux/slab.h>
6 #include <linux/net.h>
7 #include <linux/compat.h>
8 #include <net/compat.h>
9 #include <linux/io_uring.h>
10
11 #include <uapi/linux/io_uring.h>
12
13 #include "io_uring.h"
14 #include "kbuf.h"
15 #include "alloc_cache.h"
16 #include "net.h"
17 #include "notif.h"
18 #include "rsrc.h"
19
20 #if defined(CONFIG_NET)
21 struct io_shutdown {
22         struct file                     *file;
23         int                             how;
24 };
25
26 struct io_accept {
27         struct file                     *file;
28         struct sockaddr __user          *addr;
29         int __user                      *addr_len;
30         int                             flags;
31         u32                             file_slot;
32         unsigned long                   nofile;
33 };
34
35 struct io_socket {
36         struct file                     *file;
37         int                             domain;
38         int                             type;
39         int                             protocol;
40         int                             flags;
41         u32                             file_slot;
42         unsigned long                   nofile;
43 };
44
45 struct io_connect {
46         struct file                     *file;
47         struct sockaddr __user          *addr;
48         int                             addr_len;
49 };
50
51 struct io_sr_msg {
52         struct file                     *file;
53         union {
54                 struct compat_msghdr __user     *umsg_compat;
55                 struct user_msghdr __user       *umsg;
56                 void __user                     *buf;
57         };
58         unsigned                        msg_flags;
59         unsigned                        flags;
60         size_t                          len;
61         size_t                          done_io;
62 };
63
64 struct io_sendzc {
65         struct file                     *file;
66         void __user                     *buf;
67         size_t                          len;
68         u16                             slot_idx;
69         unsigned                        msg_flags;
70         unsigned                        flags;
71         unsigned                        addr_len;
72         void __user                     *addr;
73         size_t                          done_io;
74 };
75
76 #define IO_APOLL_MULTI_POLLED (REQ_F_APOLL_MULTISHOT | REQ_F_POLLED)
77
78 int io_shutdown_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
79 {
80         struct io_shutdown *shutdown = io_kiocb_to_cmd(req, struct io_shutdown);
81
82         if (unlikely(sqe->off || sqe->addr || sqe->rw_flags ||
83                      sqe->buf_index || sqe->splice_fd_in))
84                 return -EINVAL;
85
86         shutdown->how = READ_ONCE(sqe->len);
87         return 0;
88 }
89
90 int io_shutdown(struct io_kiocb *req, unsigned int issue_flags)
91 {
92         struct io_shutdown *shutdown = io_kiocb_to_cmd(req, struct io_shutdown);
93         struct socket *sock;
94         int ret;
95
96         if (issue_flags & IO_URING_F_NONBLOCK)
97                 return -EAGAIN;
98
99         sock = sock_from_file(req->file);
100         if (unlikely(!sock))
101                 return -ENOTSOCK;
102
103         ret = __sys_shutdown_sock(sock, shutdown->how);
104         io_req_set_res(req, ret, 0);
105         return IOU_OK;
106 }
107
108 static bool io_net_retry(struct socket *sock, int flags)
109 {
110         if (!(flags & MSG_WAITALL))
111                 return false;
112         return sock->type == SOCK_STREAM || sock->type == SOCK_SEQPACKET;
113 }
114
115 static void io_netmsg_recycle(struct io_kiocb *req, unsigned int issue_flags)
116 {
117         struct io_async_msghdr *hdr = req->async_data;
118
119         if (!hdr || issue_flags & IO_URING_F_UNLOCKED)
120                 return;
121
122         /* Let normal cleanup path reap it if we fail adding to the cache */
123         if (io_alloc_cache_put(&req->ctx->netmsg_cache, &hdr->cache)) {
124                 req->async_data = NULL;
125                 req->flags &= ~REQ_F_ASYNC_DATA;
126         }
127 }
128
129 static struct io_async_msghdr *io_recvmsg_alloc_async(struct io_kiocb *req,
130                                                       unsigned int issue_flags)
131 {
132         struct io_ring_ctx *ctx = req->ctx;
133         struct io_cache_entry *entry;
134
135         if (!(issue_flags & IO_URING_F_UNLOCKED) &&
136             (entry = io_alloc_cache_get(&ctx->netmsg_cache)) != NULL) {
137                 struct io_async_msghdr *hdr;
138
139                 hdr = container_of(entry, struct io_async_msghdr, cache);
140                 req->flags |= REQ_F_ASYNC_DATA;
141                 req->async_data = hdr;
142                 return hdr;
143         }
144
145         if (!io_alloc_async_data(req))
146                 return req->async_data;
147
148         return NULL;
149 }
150
151 static int io_setup_async_msg(struct io_kiocb *req,
152                               struct io_async_msghdr *kmsg,
153                               unsigned int issue_flags)
154 {
155         struct io_async_msghdr *async_msg = req->async_data;
156
157         if (async_msg)
158                 return -EAGAIN;
159         async_msg = io_recvmsg_alloc_async(req, issue_flags);
160         if (!async_msg) {
161                 kfree(kmsg->free_iov);
162                 return -ENOMEM;
163         }
164         req->flags |= REQ_F_NEED_CLEANUP;
165         memcpy(async_msg, kmsg, sizeof(*kmsg));
166         async_msg->msg.msg_name = &async_msg->addr;
167         /* if were using fast_iov, set it to the new one */
168         if (!async_msg->free_iov)
169                 async_msg->msg.msg_iter.iov = async_msg->fast_iov;
170
171         return -EAGAIN;
172 }
173
174 static int io_sendmsg_copy_hdr(struct io_kiocb *req,
175                                struct io_async_msghdr *iomsg)
176 {
177         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
178
179         iomsg->msg.msg_name = &iomsg->addr;
180         iomsg->free_iov = iomsg->fast_iov;
181         return sendmsg_copy_msghdr(&iomsg->msg, sr->umsg, sr->msg_flags,
182                                         &iomsg->free_iov);
183 }
184
185 int io_sendmsg_prep_async(struct io_kiocb *req)
186 {
187         int ret;
188
189         ret = io_sendmsg_copy_hdr(req, req->async_data);
190         if (!ret)
191                 req->flags |= REQ_F_NEED_CLEANUP;
192         return ret;
193 }
194
195 void io_sendmsg_recvmsg_cleanup(struct io_kiocb *req)
196 {
197         struct io_async_msghdr *io = req->async_data;
198
199         kfree(io->free_iov);
200 }
201
202 int io_sendmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
203 {
204         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
205
206         if (unlikely(sqe->file_index || sqe->addr2))
207                 return -EINVAL;
208
209         sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr));
210         sr->len = READ_ONCE(sqe->len);
211         sr->flags = READ_ONCE(sqe->ioprio);
212         if (sr->flags & ~IORING_RECVSEND_POLL_FIRST)
213                 return -EINVAL;
214         sr->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
215         if (sr->msg_flags & MSG_DONTWAIT)
216                 req->flags |= REQ_F_NOWAIT;
217
218 #ifdef CONFIG_COMPAT
219         if (req->ctx->compat)
220                 sr->msg_flags |= MSG_CMSG_COMPAT;
221 #endif
222         sr->done_io = 0;
223         return 0;
224 }
225
226 int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
227 {
228         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
229         struct io_async_msghdr iomsg, *kmsg;
230         struct socket *sock;
231         unsigned flags;
232         int min_ret = 0;
233         int ret;
234
235         sock = sock_from_file(req->file);
236         if (unlikely(!sock))
237                 return -ENOTSOCK;
238
239         if (req_has_async_data(req)) {
240                 kmsg = req->async_data;
241         } else {
242                 ret = io_sendmsg_copy_hdr(req, &iomsg);
243                 if (ret)
244                         return ret;
245                 kmsg = &iomsg;
246         }
247
248         if (!(req->flags & REQ_F_POLLED) &&
249             (sr->flags & IORING_RECVSEND_POLL_FIRST))
250                 return io_setup_async_msg(req, kmsg, issue_flags);
251
252         flags = sr->msg_flags;
253         if (issue_flags & IO_URING_F_NONBLOCK)
254                 flags |= MSG_DONTWAIT;
255         if (flags & MSG_WAITALL)
256                 min_ret = iov_iter_count(&kmsg->msg.msg_iter);
257
258         ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
259
260         if (ret < min_ret) {
261                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
262                         return io_setup_async_msg(req, kmsg, issue_flags);
263                 if (ret == -ERESTARTSYS)
264                         ret = -EINTR;
265                 if (ret > 0 && io_net_retry(sock, flags)) {
266                         sr->done_io += ret;
267                         req->flags |= REQ_F_PARTIAL_IO;
268                         return io_setup_async_msg(req, kmsg, issue_flags);
269                 }
270                 req_set_fail(req);
271         }
272         /* fast path, check for non-NULL to avoid function call */
273         if (kmsg->free_iov)
274                 kfree(kmsg->free_iov);
275         req->flags &= ~REQ_F_NEED_CLEANUP;
276         io_netmsg_recycle(req, issue_flags);
277         if (ret >= 0)
278                 ret += sr->done_io;
279         else if (sr->done_io)
280                 ret = sr->done_io;
281         io_req_set_res(req, ret, 0);
282         return IOU_OK;
283 }
284
285 int io_send(struct io_kiocb *req, unsigned int issue_flags)
286 {
287         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
288         struct msghdr msg;
289         struct iovec iov;
290         struct socket *sock;
291         unsigned flags;
292         int min_ret = 0;
293         int ret;
294
295         if (!(req->flags & REQ_F_POLLED) &&
296             (sr->flags & IORING_RECVSEND_POLL_FIRST))
297                 return -EAGAIN;
298
299         sock = sock_from_file(req->file);
300         if (unlikely(!sock))
301                 return -ENOTSOCK;
302
303         ret = import_single_range(WRITE, sr->buf, sr->len, &iov, &msg.msg_iter);
304         if (unlikely(ret))
305                 return ret;
306
307         msg.msg_name = NULL;
308         msg.msg_control = NULL;
309         msg.msg_controllen = 0;
310         msg.msg_namelen = 0;
311         msg.msg_ubuf = NULL;
312
313         flags = sr->msg_flags;
314         if (issue_flags & IO_URING_F_NONBLOCK)
315                 flags |= MSG_DONTWAIT;
316         if (flags & MSG_WAITALL)
317                 min_ret = iov_iter_count(&msg.msg_iter);
318
319         msg.msg_flags = flags;
320         ret = sock_sendmsg(sock, &msg);
321         if (ret < min_ret) {
322                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
323                         return -EAGAIN;
324                 if (ret == -ERESTARTSYS)
325                         ret = -EINTR;
326                 if (ret > 0 && io_net_retry(sock, flags)) {
327                         sr->len -= ret;
328                         sr->buf += ret;
329                         sr->done_io += ret;
330                         req->flags |= REQ_F_PARTIAL_IO;
331                         return -EAGAIN;
332                 }
333                 req_set_fail(req);
334         }
335         if (ret >= 0)
336                 ret += sr->done_io;
337         else if (sr->done_io)
338                 ret = sr->done_io;
339         io_req_set_res(req, ret, 0);
340         return IOU_OK;
341 }
342
343 static bool io_recvmsg_multishot_overflow(struct io_async_msghdr *iomsg)
344 {
345         int hdr;
346
347         if (iomsg->namelen < 0)
348                 return true;
349         if (check_add_overflow((int)sizeof(struct io_uring_recvmsg_out),
350                                iomsg->namelen, &hdr))
351                 return true;
352         if (check_add_overflow(hdr, (int)iomsg->controllen, &hdr))
353                 return true;
354
355         return false;
356 }
357
358 static int __io_recvmsg_copy_hdr(struct io_kiocb *req,
359                                  struct io_async_msghdr *iomsg)
360 {
361         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
362         struct user_msghdr msg;
363         int ret;
364
365         if (copy_from_user(&msg, sr->umsg, sizeof(*sr->umsg)))
366                 return -EFAULT;
367
368         ret = __copy_msghdr(&iomsg->msg, &msg, &iomsg->uaddr);
369         if (ret)
370                 return ret;
371
372         if (req->flags & REQ_F_BUFFER_SELECT) {
373                 if (msg.msg_iovlen == 0) {
374                         sr->len = iomsg->fast_iov[0].iov_len = 0;
375                         iomsg->fast_iov[0].iov_base = NULL;
376                         iomsg->free_iov = NULL;
377                 } else if (msg.msg_iovlen > 1) {
378                         return -EINVAL;
379                 } else {
380                         if (copy_from_user(iomsg->fast_iov, msg.msg_iov, sizeof(*msg.msg_iov)))
381                                 return -EFAULT;
382                         sr->len = iomsg->fast_iov[0].iov_len;
383                         iomsg->free_iov = NULL;
384                 }
385
386                 if (req->flags & REQ_F_APOLL_MULTISHOT) {
387                         iomsg->namelen = msg.msg_namelen;
388                         iomsg->controllen = msg.msg_controllen;
389                         if (io_recvmsg_multishot_overflow(iomsg))
390                                 return -EOVERFLOW;
391                 }
392         } else {
393                 iomsg->free_iov = iomsg->fast_iov;
394                 ret = __import_iovec(READ, msg.msg_iov, msg.msg_iovlen, UIO_FASTIOV,
395                                      &iomsg->free_iov, &iomsg->msg.msg_iter,
396                                      false);
397                 if (ret > 0)
398                         ret = 0;
399         }
400
401         return ret;
402 }
403
404 #ifdef CONFIG_COMPAT
405 static int __io_compat_recvmsg_copy_hdr(struct io_kiocb *req,
406                                         struct io_async_msghdr *iomsg)
407 {
408         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
409         struct compat_msghdr msg;
410         struct compat_iovec __user *uiov;
411         int ret;
412
413         if (copy_from_user(&msg, sr->umsg_compat, sizeof(msg)))
414                 return -EFAULT;
415
416         ret = __get_compat_msghdr(&iomsg->msg, &msg, &iomsg->uaddr);
417         if (ret)
418                 return ret;
419
420         uiov = compat_ptr(msg.msg_iov);
421         if (req->flags & REQ_F_BUFFER_SELECT) {
422                 compat_ssize_t clen;
423
424                 if (msg.msg_iovlen == 0) {
425                         sr->len = 0;
426                         iomsg->free_iov = NULL;
427                 } else if (msg.msg_iovlen > 1) {
428                         return -EINVAL;
429                 } else {
430                         if (!access_ok(uiov, sizeof(*uiov)))
431                                 return -EFAULT;
432                         if (__get_user(clen, &uiov->iov_len))
433                                 return -EFAULT;
434                         if (clen < 0)
435                                 return -EINVAL;
436                         sr->len = clen;
437                         iomsg->free_iov = NULL;
438                 }
439
440                 if (req->flags & REQ_F_APOLL_MULTISHOT) {
441                         iomsg->namelen = msg.msg_namelen;
442                         iomsg->controllen = msg.msg_controllen;
443                         if (io_recvmsg_multishot_overflow(iomsg))
444                                 return -EOVERFLOW;
445                 }
446         } else {
447                 iomsg->free_iov = iomsg->fast_iov;
448                 ret = __import_iovec(READ, (struct iovec __user *)uiov, msg.msg_iovlen,
449                                    UIO_FASTIOV, &iomsg->free_iov,
450                                    &iomsg->msg.msg_iter, true);
451                 if (ret < 0)
452                         return ret;
453         }
454
455         return 0;
456 }
457 #endif
458
459 static int io_recvmsg_copy_hdr(struct io_kiocb *req,
460                                struct io_async_msghdr *iomsg)
461 {
462         iomsg->msg.msg_name = &iomsg->addr;
463
464 #ifdef CONFIG_COMPAT
465         if (req->ctx->compat)
466                 return __io_compat_recvmsg_copy_hdr(req, iomsg);
467 #endif
468
469         return __io_recvmsg_copy_hdr(req, iomsg);
470 }
471
472 int io_recvmsg_prep_async(struct io_kiocb *req)
473 {
474         int ret;
475
476         ret = io_recvmsg_copy_hdr(req, req->async_data);
477         if (!ret)
478                 req->flags |= REQ_F_NEED_CLEANUP;
479         return ret;
480 }
481
482 #define RECVMSG_FLAGS (IORING_RECVSEND_POLL_FIRST | IORING_RECV_MULTISHOT)
483
484 int io_recvmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
485 {
486         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
487
488         if (unlikely(sqe->file_index || sqe->addr2))
489                 return -EINVAL;
490
491         sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr));
492         sr->len = READ_ONCE(sqe->len);
493         sr->flags = READ_ONCE(sqe->ioprio);
494         if (sr->flags & ~(RECVMSG_FLAGS))
495                 return -EINVAL;
496         sr->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
497         if (sr->msg_flags & MSG_DONTWAIT)
498                 req->flags |= REQ_F_NOWAIT;
499         if (sr->msg_flags & MSG_ERRQUEUE)
500                 req->flags |= REQ_F_CLEAR_POLLIN;
501         if (sr->flags & IORING_RECV_MULTISHOT) {
502                 if (!(req->flags & REQ_F_BUFFER_SELECT))
503                         return -EINVAL;
504                 if (sr->msg_flags & MSG_WAITALL)
505                         return -EINVAL;
506                 if (req->opcode == IORING_OP_RECV && sr->len)
507                         return -EINVAL;
508                 req->flags |= REQ_F_APOLL_MULTISHOT;
509         }
510
511 #ifdef CONFIG_COMPAT
512         if (req->ctx->compat)
513                 sr->msg_flags |= MSG_CMSG_COMPAT;
514 #endif
515         sr->done_io = 0;
516         return 0;
517 }
518
519 static inline void io_recv_prep_retry(struct io_kiocb *req)
520 {
521         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
522
523         sr->done_io = 0;
524         sr->len = 0; /* get from the provided buffer */
525 }
526
527 /*
528  * Finishes io_recv and io_recvmsg.
529  *
530  * Returns true if it is actually finished, or false if it should run
531  * again (for multishot).
532  */
533 static inline bool io_recv_finish(struct io_kiocb *req, int *ret,
534                                   unsigned int cflags, bool mshot_finished)
535 {
536         if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
537                 io_req_set_res(req, *ret, cflags);
538                 *ret = IOU_OK;
539                 return true;
540         }
541
542         if (!mshot_finished) {
543                 if (io_post_aux_cqe(req->ctx, req->cqe.user_data, *ret,
544                                     cflags | IORING_CQE_F_MORE, false)) {
545                         io_recv_prep_retry(req);
546                         return false;
547                 }
548                 /*
549                  * Otherwise stop multishot but use the current result.
550                  * Probably will end up going into overflow, but this means
551                  * we cannot trust the ordering anymore
552                  */
553         }
554
555         io_req_set_res(req, *ret, cflags);
556
557         if (req->flags & REQ_F_POLLED)
558                 *ret = IOU_STOP_MULTISHOT;
559         else
560                 *ret = IOU_OK;
561         return true;
562 }
563
564 static int io_recvmsg_prep_multishot(struct io_async_msghdr *kmsg,
565                                      struct io_sr_msg *sr, void __user **buf,
566                                      size_t *len)
567 {
568         unsigned long ubuf = (unsigned long) *buf;
569         unsigned long hdr;
570
571         hdr = sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
572                 kmsg->controllen;
573         if (*len < hdr)
574                 return -EFAULT;
575
576         if (kmsg->controllen) {
577                 unsigned long control = ubuf + hdr - kmsg->controllen;
578
579                 kmsg->msg.msg_control_user = (void __user *) control;
580                 kmsg->msg.msg_controllen = kmsg->controllen;
581         }
582
583         sr->buf = *buf; /* stash for later copy */
584         *buf = (void __user *) (ubuf + hdr);
585         kmsg->payloadlen = *len = *len - hdr;
586         return 0;
587 }
588
589 struct io_recvmsg_multishot_hdr {
590         struct io_uring_recvmsg_out msg;
591         struct sockaddr_storage addr;
592 };
593
594 static int io_recvmsg_multishot(struct socket *sock, struct io_sr_msg *io,
595                                 struct io_async_msghdr *kmsg,
596                                 unsigned int flags, bool *finished)
597 {
598         int err;
599         int copy_len;
600         struct io_recvmsg_multishot_hdr hdr;
601
602         if (kmsg->namelen)
603                 kmsg->msg.msg_name = &hdr.addr;
604         kmsg->msg.msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT);
605         kmsg->msg.msg_namelen = 0;
606
607         if (sock->file->f_flags & O_NONBLOCK)
608                 flags |= MSG_DONTWAIT;
609
610         err = sock_recvmsg(sock, &kmsg->msg, flags);
611         *finished = err <= 0;
612         if (err < 0)
613                 return err;
614
615         hdr.msg = (struct io_uring_recvmsg_out) {
616                 .controllen = kmsg->controllen - kmsg->msg.msg_controllen,
617                 .flags = kmsg->msg.msg_flags & ~MSG_CMSG_COMPAT
618         };
619
620         hdr.msg.payloadlen = err;
621         if (err > kmsg->payloadlen)
622                 err = kmsg->payloadlen;
623
624         copy_len = sizeof(struct io_uring_recvmsg_out);
625         if (kmsg->msg.msg_namelen > kmsg->namelen)
626                 copy_len += kmsg->namelen;
627         else
628                 copy_len += kmsg->msg.msg_namelen;
629
630         /*
631          *      "fromlen shall refer to the value before truncation.."
632          *                      1003.1g
633          */
634         hdr.msg.namelen = kmsg->msg.msg_namelen;
635
636         /* ensure that there is no gap between hdr and sockaddr_storage */
637         BUILD_BUG_ON(offsetof(struct io_recvmsg_multishot_hdr, addr) !=
638                      sizeof(struct io_uring_recvmsg_out));
639         if (copy_to_user(io->buf, &hdr, copy_len)) {
640                 *finished = true;
641                 return -EFAULT;
642         }
643
644         return sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
645                         kmsg->controllen + err;
646 }
647
648 int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
649 {
650         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
651         struct io_async_msghdr iomsg, *kmsg;
652         struct socket *sock;
653         unsigned int cflags;
654         unsigned flags;
655         int ret, min_ret = 0;
656         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
657         bool mshot_finished = true;
658
659         sock = sock_from_file(req->file);
660         if (unlikely(!sock))
661                 return -ENOTSOCK;
662
663         if (req_has_async_data(req)) {
664                 kmsg = req->async_data;
665         } else {
666                 ret = io_recvmsg_copy_hdr(req, &iomsg);
667                 if (ret)
668                         return ret;
669                 kmsg = &iomsg;
670         }
671
672         if (!(req->flags & REQ_F_POLLED) &&
673             (sr->flags & IORING_RECVSEND_POLL_FIRST))
674                 return io_setup_async_msg(req, kmsg, issue_flags);
675
676 retry_multishot:
677         if (io_do_buffer_select(req)) {
678                 void __user *buf;
679                 size_t len = sr->len;
680
681                 buf = io_buffer_select(req, &len, issue_flags);
682                 if (!buf)
683                         return -ENOBUFS;
684
685                 if (req->flags & REQ_F_APOLL_MULTISHOT) {
686                         ret = io_recvmsg_prep_multishot(kmsg, sr, &buf, &len);
687                         if (ret) {
688                                 io_kbuf_recycle(req, issue_flags);
689                                 return ret;
690                         }
691                 }
692
693                 kmsg->fast_iov[0].iov_base = buf;
694                 kmsg->fast_iov[0].iov_len = len;
695                 iov_iter_init(&kmsg->msg.msg_iter, READ, kmsg->fast_iov, 1,
696                                 len);
697         }
698
699         flags = sr->msg_flags;
700         if (force_nonblock)
701                 flags |= MSG_DONTWAIT;
702         if (flags & MSG_WAITALL)
703                 min_ret = iov_iter_count(&kmsg->msg.msg_iter);
704
705         kmsg->msg.msg_get_inq = 1;
706         if (req->flags & REQ_F_APOLL_MULTISHOT)
707                 ret = io_recvmsg_multishot(sock, sr, kmsg, flags,
708                                            &mshot_finished);
709         else
710                 ret = __sys_recvmsg_sock(sock, &kmsg->msg, sr->umsg,
711                                          kmsg->uaddr, flags);
712
713         if (ret < min_ret) {
714                 if (ret == -EAGAIN && force_nonblock) {
715                         ret = io_setup_async_msg(req, kmsg, issue_flags);
716                         if (ret == -EAGAIN && (req->flags & IO_APOLL_MULTI_POLLED) ==
717                                                IO_APOLL_MULTI_POLLED) {
718                                 io_kbuf_recycle(req, issue_flags);
719                                 return IOU_ISSUE_SKIP_COMPLETE;
720                         }
721                         return ret;
722                 }
723                 if (ret == -ERESTARTSYS)
724                         ret = -EINTR;
725                 if (ret > 0 && io_net_retry(sock, flags)) {
726                         sr->done_io += ret;
727                         req->flags |= REQ_F_PARTIAL_IO;
728                         return io_setup_async_msg(req, kmsg, issue_flags);
729                 }
730                 req_set_fail(req);
731         } else if ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
732                 req_set_fail(req);
733         }
734
735         if (ret > 0)
736                 ret += sr->done_io;
737         else if (sr->done_io)
738                 ret = sr->done_io;
739         else
740                 io_kbuf_recycle(req, issue_flags);
741
742         cflags = io_put_kbuf(req, issue_flags);
743         if (kmsg->msg.msg_inq)
744                 cflags |= IORING_CQE_F_SOCK_NONEMPTY;
745
746         if (!io_recv_finish(req, &ret, cflags, mshot_finished))
747                 goto retry_multishot;
748
749         if (mshot_finished) {
750                 io_netmsg_recycle(req, issue_flags);
751                 /* fast path, check for non-NULL to avoid function call */
752                 if (kmsg->free_iov)
753                         kfree(kmsg->free_iov);
754                 req->flags &= ~REQ_F_NEED_CLEANUP;
755         }
756
757         return ret;
758 }
759
760 int io_recv(struct io_kiocb *req, unsigned int issue_flags)
761 {
762         struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
763         struct msghdr msg;
764         struct socket *sock;
765         struct iovec iov;
766         unsigned int cflags;
767         unsigned flags;
768         int ret, min_ret = 0;
769         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
770         size_t len = sr->len;
771
772         if (!(req->flags & REQ_F_POLLED) &&
773             (sr->flags & IORING_RECVSEND_POLL_FIRST))
774                 return -EAGAIN;
775
776         sock = sock_from_file(req->file);
777         if (unlikely(!sock))
778                 return -ENOTSOCK;
779
780 retry_multishot:
781         if (io_do_buffer_select(req)) {
782                 void __user *buf;
783
784                 buf = io_buffer_select(req, &len, issue_flags);
785                 if (!buf)
786                         return -ENOBUFS;
787                 sr->buf = buf;
788         }
789
790         ret = import_single_range(READ, sr->buf, len, &iov, &msg.msg_iter);
791         if (unlikely(ret))
792                 goto out_free;
793
794         msg.msg_name = NULL;
795         msg.msg_namelen = 0;
796         msg.msg_control = NULL;
797         msg.msg_get_inq = 1;
798         msg.msg_flags = 0;
799         msg.msg_controllen = 0;
800         msg.msg_iocb = NULL;
801         msg.msg_ubuf = NULL;
802
803         flags = sr->msg_flags;
804         if (force_nonblock)
805                 flags |= MSG_DONTWAIT;
806         if (flags & MSG_WAITALL)
807                 min_ret = iov_iter_count(&msg.msg_iter);
808
809         ret = sock_recvmsg(sock, &msg, flags);
810         if (ret < min_ret) {
811                 if (ret == -EAGAIN && force_nonblock) {
812                         if ((req->flags & IO_APOLL_MULTI_POLLED) == IO_APOLL_MULTI_POLLED) {
813                                 io_kbuf_recycle(req, issue_flags);
814                                 return IOU_ISSUE_SKIP_COMPLETE;
815                         }
816
817                         return -EAGAIN;
818                 }
819                 if (ret == -ERESTARTSYS)
820                         ret = -EINTR;
821                 if (ret > 0 && io_net_retry(sock, flags)) {
822                         sr->len -= ret;
823                         sr->buf += ret;
824                         sr->done_io += ret;
825                         req->flags |= REQ_F_PARTIAL_IO;
826                         return -EAGAIN;
827                 }
828                 req_set_fail(req);
829         } else if ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
830 out_free:
831                 req_set_fail(req);
832         }
833
834         if (ret > 0)
835                 ret += sr->done_io;
836         else if (sr->done_io)
837                 ret = sr->done_io;
838         else
839                 io_kbuf_recycle(req, issue_flags);
840
841         cflags = io_put_kbuf(req, issue_flags);
842         if (msg.msg_inq)
843                 cflags |= IORING_CQE_F_SOCK_NONEMPTY;
844
845         if (!io_recv_finish(req, &ret, cflags, ret <= 0))
846                 goto retry_multishot;
847
848         return ret;
849 }
850
851 int io_sendzc_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
852 {
853         struct io_sendzc *zc = io_kiocb_to_cmd(req, struct io_sendzc);
854         struct io_ring_ctx *ctx = req->ctx;
855
856         if (READ_ONCE(sqe->__pad2[0]) || READ_ONCE(sqe->addr3))
857                 return -EINVAL;
858
859         zc->flags = READ_ONCE(sqe->ioprio);
860         if (zc->flags & ~(IORING_RECVSEND_POLL_FIRST |
861                           IORING_RECVSEND_FIXED_BUF | IORING_RECVSEND_NOTIF_FLUSH))
862                 return -EINVAL;
863         if (zc->flags & IORING_RECVSEND_FIXED_BUF) {
864                 unsigned idx = READ_ONCE(sqe->buf_index);
865
866                 if (unlikely(idx >= ctx->nr_user_bufs))
867                         return -EFAULT;
868                 idx = array_index_nospec(idx, ctx->nr_user_bufs);
869                 req->imu = READ_ONCE(ctx->user_bufs[idx]);
870                 io_req_set_rsrc_node(req, ctx, 0);
871         }
872
873         zc->buf = u64_to_user_ptr(READ_ONCE(sqe->addr));
874         zc->len = READ_ONCE(sqe->len);
875         zc->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
876         zc->slot_idx = READ_ONCE(sqe->notification_idx);
877         if (zc->msg_flags & MSG_DONTWAIT)
878                 req->flags |= REQ_F_NOWAIT;
879
880         zc->addr = u64_to_user_ptr(READ_ONCE(sqe->addr2));
881         zc->addr_len = READ_ONCE(sqe->addr_len);
882         zc->done_io = 0;
883
884 #ifdef CONFIG_COMPAT
885         if (req->ctx->compat)
886                 zc->msg_flags |= MSG_CMSG_COMPAT;
887 #endif
888         return 0;
889 }
890
891 static int io_sg_from_iter(struct sock *sk, struct sk_buff *skb,
892                            struct iov_iter *from, size_t length)
893 {
894         struct skb_shared_info *shinfo = skb_shinfo(skb);
895         int frag = shinfo->nr_frags;
896         int ret = 0;
897         struct bvec_iter bi;
898         ssize_t copied = 0;
899         unsigned long truesize = 0;
900
901         if (!shinfo->nr_frags)
902                 shinfo->flags |= SKBFL_MANAGED_FRAG_REFS;
903
904         if (!skb_zcopy_managed(skb) || !iov_iter_is_bvec(from)) {
905                 skb_zcopy_downgrade_managed(skb);
906                 return __zerocopy_sg_from_iter(NULL, sk, skb, from, length);
907         }
908
909         bi.bi_size = min(from->count, length);
910         bi.bi_bvec_done = from->iov_offset;
911         bi.bi_idx = 0;
912
913         while (bi.bi_size && frag < MAX_SKB_FRAGS) {
914                 struct bio_vec v = mp_bvec_iter_bvec(from->bvec, bi);
915
916                 copied += v.bv_len;
917                 truesize += PAGE_ALIGN(v.bv_len + v.bv_offset);
918                 __skb_fill_page_desc_noacc(shinfo, frag++, v.bv_page,
919                                            v.bv_offset, v.bv_len);
920                 bvec_iter_advance_single(from->bvec, &bi, v.bv_len);
921         }
922         if (bi.bi_size)
923                 ret = -EMSGSIZE;
924
925         shinfo->nr_frags = frag;
926         from->bvec += bi.bi_idx;
927         from->nr_segs -= bi.bi_idx;
928         from->count = bi.bi_size;
929         from->iov_offset = bi.bi_bvec_done;
930
931         skb->data_len += copied;
932         skb->len += copied;
933         skb->truesize += truesize;
934
935         if (sk && sk->sk_type == SOCK_STREAM) {
936                 sk_wmem_queued_add(sk, truesize);
937                 if (!skb_zcopy_pure(skb))
938                         sk_mem_charge(sk, truesize);
939         } else {
940                 refcount_add(truesize, &skb->sk->sk_wmem_alloc);
941         }
942         return ret;
943 }
944
945 int io_sendzc(struct io_kiocb *req, unsigned int issue_flags)
946 {
947         struct sockaddr_storage address;
948         struct io_ring_ctx *ctx = req->ctx;
949         struct io_sendzc *zc = io_kiocb_to_cmd(req, struct io_sendzc);
950         struct io_notif_slot *notif_slot;
951         struct io_kiocb *notif;
952         struct msghdr msg;
953         struct iovec iov;
954         struct socket *sock;
955         unsigned msg_flags;
956         int ret, min_ret = 0;
957
958         if (!(req->flags & REQ_F_POLLED) &&
959             (zc->flags & IORING_RECVSEND_POLL_FIRST))
960                 return -EAGAIN;
961
962         if (issue_flags & IO_URING_F_UNLOCKED)
963                 return -EAGAIN;
964         sock = sock_from_file(req->file);
965         if (unlikely(!sock))
966                 return -ENOTSOCK;
967
968         notif_slot = io_get_notif_slot(ctx, zc->slot_idx);
969         if (!notif_slot)
970                 return -EINVAL;
971         notif = io_get_notif(ctx, notif_slot);
972         if (!notif)
973                 return -ENOMEM;
974
975         msg.msg_name = NULL;
976         msg.msg_control = NULL;
977         msg.msg_controllen = 0;
978         msg.msg_namelen = 0;
979
980         if (zc->flags & IORING_RECVSEND_FIXED_BUF) {
981                 ret = io_import_fixed(WRITE, &msg.msg_iter, req->imu,
982                                         (u64)(uintptr_t)zc->buf, zc->len);
983                 if (unlikely(ret))
984                                 return ret;
985         } else {
986                 ret = import_single_range(WRITE, zc->buf, zc->len, &iov,
987                                           &msg.msg_iter);
988                 if (unlikely(ret))
989                         return ret;
990                 ret = io_notif_account_mem(notif, zc->len);
991                 if (unlikely(ret))
992                         return ret;
993         }
994
995         if (zc->addr) {
996                 ret = move_addr_to_kernel(zc->addr, zc->addr_len, &address);
997                 if (unlikely(ret < 0))
998                         return ret;
999                 msg.msg_name = (struct sockaddr *)&address;
1000                 msg.msg_namelen = zc->addr_len;
1001         }
1002
1003         msg_flags = zc->msg_flags | MSG_ZEROCOPY;
1004         if (issue_flags & IO_URING_F_NONBLOCK)
1005                 msg_flags |= MSG_DONTWAIT;
1006         if (msg_flags & MSG_WAITALL)
1007                 min_ret = iov_iter_count(&msg.msg_iter);
1008
1009         msg.msg_flags = msg_flags;
1010         msg.msg_ubuf = &io_notif_to_data(notif)->uarg;
1011         msg.sg_from_iter = io_sg_from_iter;
1012         ret = sock_sendmsg(sock, &msg);
1013
1014         if (unlikely(ret < min_ret)) {
1015                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
1016                         return -EAGAIN;
1017                 if (ret > 0 && io_net_retry(sock, msg.msg_flags)) {
1018                         zc->len -= ret;
1019                         zc->buf += ret;
1020                         zc->done_io += ret;
1021                         req->flags |= REQ_F_PARTIAL_IO;
1022                         return -EAGAIN;
1023                 }
1024                 if (ret == -ERESTARTSYS)
1025                         ret = -EINTR;
1026         } else if (zc->flags & IORING_RECVSEND_NOTIF_FLUSH) {
1027                 io_notif_slot_flush_submit(notif_slot, 0);
1028         }
1029
1030         if (ret >= 0)
1031                 ret += zc->done_io;
1032         else if (zc->done_io)
1033                 ret = zc->done_io;
1034         io_req_set_res(req, ret, 0);
1035         return IOU_OK;
1036 }
1037
1038 int io_accept_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
1039 {
1040         struct io_accept *accept = io_kiocb_to_cmd(req, struct io_accept);
1041         unsigned flags;
1042
1043         if (sqe->len || sqe->buf_index)
1044                 return -EINVAL;
1045
1046         accept->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
1047         accept->addr_len = u64_to_user_ptr(READ_ONCE(sqe->addr2));
1048         accept->flags = READ_ONCE(sqe->accept_flags);
1049         accept->nofile = rlimit(RLIMIT_NOFILE);
1050         flags = READ_ONCE(sqe->ioprio);
1051         if (flags & ~IORING_ACCEPT_MULTISHOT)
1052                 return -EINVAL;
1053
1054         accept->file_slot = READ_ONCE(sqe->file_index);
1055         if (accept->file_slot) {
1056                 if (accept->flags & SOCK_CLOEXEC)
1057                         return -EINVAL;
1058                 if (flags & IORING_ACCEPT_MULTISHOT &&
1059                     accept->file_slot != IORING_FILE_INDEX_ALLOC)
1060                         return -EINVAL;
1061         }
1062         if (accept->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
1063                 return -EINVAL;
1064         if (SOCK_NONBLOCK != O_NONBLOCK && (accept->flags & SOCK_NONBLOCK))
1065                 accept->flags = (accept->flags & ~SOCK_NONBLOCK) | O_NONBLOCK;
1066         if (flags & IORING_ACCEPT_MULTISHOT)
1067                 req->flags |= REQ_F_APOLL_MULTISHOT;
1068         return 0;
1069 }
1070
1071 int io_accept(struct io_kiocb *req, unsigned int issue_flags)
1072 {
1073         struct io_ring_ctx *ctx = req->ctx;
1074         struct io_accept *accept = io_kiocb_to_cmd(req, struct io_accept);
1075         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
1076         unsigned int file_flags = force_nonblock ? O_NONBLOCK : 0;
1077         bool fixed = !!accept->file_slot;
1078         struct file *file;
1079         int ret, fd;
1080
1081 retry:
1082         if (!fixed) {
1083                 fd = __get_unused_fd_flags(accept->flags, accept->nofile);
1084                 if (unlikely(fd < 0))
1085                         return fd;
1086         }
1087         file = do_accept(req->file, file_flags, accept->addr, accept->addr_len,
1088                          accept->flags);
1089         if (IS_ERR(file)) {
1090                 if (!fixed)
1091                         put_unused_fd(fd);
1092                 ret = PTR_ERR(file);
1093                 if (ret == -EAGAIN && force_nonblock) {
1094                         /*
1095                          * if it's multishot and polled, we don't need to
1096                          * return EAGAIN to arm the poll infra since it
1097                          * has already been done
1098                          */
1099                         if ((req->flags & IO_APOLL_MULTI_POLLED) ==
1100                             IO_APOLL_MULTI_POLLED)
1101                                 ret = IOU_ISSUE_SKIP_COMPLETE;
1102                         return ret;
1103                 }
1104                 if (ret == -ERESTARTSYS)
1105                         ret = -EINTR;
1106                 req_set_fail(req);
1107         } else if (!fixed) {
1108                 fd_install(fd, file);
1109                 ret = fd;
1110         } else {
1111                 ret = io_fixed_fd_install(req, issue_flags, file,
1112                                                 accept->file_slot);
1113         }
1114
1115         if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
1116                 io_req_set_res(req, ret, 0);
1117                 return IOU_OK;
1118         }
1119
1120         if (ret >= 0 &&
1121             io_post_aux_cqe(ctx, req->cqe.user_data, ret, IORING_CQE_F_MORE, false))
1122                 goto retry;
1123
1124         io_req_set_res(req, ret, 0);
1125         if (req->flags & REQ_F_POLLED)
1126                 return IOU_STOP_MULTISHOT;
1127         return IOU_OK;
1128 }
1129
1130 int io_socket_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
1131 {
1132         struct io_socket *sock = io_kiocb_to_cmd(req, struct io_socket);
1133
1134         if (sqe->addr || sqe->rw_flags || sqe->buf_index)
1135                 return -EINVAL;
1136
1137         sock->domain = READ_ONCE(sqe->fd);
1138         sock->type = READ_ONCE(sqe->off);
1139         sock->protocol = READ_ONCE(sqe->len);
1140         sock->file_slot = READ_ONCE(sqe->file_index);
1141         sock->nofile = rlimit(RLIMIT_NOFILE);
1142
1143         sock->flags = sock->type & ~SOCK_TYPE_MASK;
1144         if (sock->file_slot && (sock->flags & SOCK_CLOEXEC))
1145                 return -EINVAL;
1146         if (sock->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
1147                 return -EINVAL;
1148         return 0;
1149 }
1150
1151 int io_socket(struct io_kiocb *req, unsigned int issue_flags)
1152 {
1153         struct io_socket *sock = io_kiocb_to_cmd(req, struct io_socket);
1154         bool fixed = !!sock->file_slot;
1155         struct file *file;
1156         int ret, fd;
1157
1158         if (!fixed) {
1159                 fd = __get_unused_fd_flags(sock->flags, sock->nofile);
1160                 if (unlikely(fd < 0))
1161                         return fd;
1162         }
1163         file = __sys_socket_file(sock->domain, sock->type, sock->protocol);
1164         if (IS_ERR(file)) {
1165                 if (!fixed)
1166                         put_unused_fd(fd);
1167                 ret = PTR_ERR(file);
1168                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
1169                         return -EAGAIN;
1170                 if (ret == -ERESTARTSYS)
1171                         ret = -EINTR;
1172                 req_set_fail(req);
1173         } else if (!fixed) {
1174                 fd_install(fd, file);
1175                 ret = fd;
1176         } else {
1177                 ret = io_fixed_fd_install(req, issue_flags, file,
1178                                             sock->file_slot);
1179         }
1180         io_req_set_res(req, ret, 0);
1181         return IOU_OK;
1182 }
1183
1184 int io_connect_prep_async(struct io_kiocb *req)
1185 {
1186         struct io_async_connect *io = req->async_data;
1187         struct io_connect *conn = io_kiocb_to_cmd(req, struct io_connect);
1188
1189         return move_addr_to_kernel(conn->addr, conn->addr_len, &io->address);
1190 }
1191
1192 int io_connect_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
1193 {
1194         struct io_connect *conn = io_kiocb_to_cmd(req, struct io_connect);
1195
1196         if (sqe->len || sqe->buf_index || sqe->rw_flags || sqe->splice_fd_in)
1197                 return -EINVAL;
1198
1199         conn->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
1200         conn->addr_len =  READ_ONCE(sqe->addr2);
1201         return 0;
1202 }
1203
1204 int io_connect(struct io_kiocb *req, unsigned int issue_flags)
1205 {
1206         struct io_connect *connect = io_kiocb_to_cmd(req, struct io_connect);
1207         struct io_async_connect __io, *io;
1208         unsigned file_flags;
1209         int ret;
1210         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
1211
1212         if (req_has_async_data(req)) {
1213                 io = req->async_data;
1214         } else {
1215                 ret = move_addr_to_kernel(connect->addr,
1216                                                 connect->addr_len,
1217                                                 &__io.address);
1218                 if (ret)
1219                         goto out;
1220                 io = &__io;
1221         }
1222
1223         file_flags = force_nonblock ? O_NONBLOCK : 0;
1224
1225         ret = __sys_connect_file(req->file, &io->address,
1226                                         connect->addr_len, file_flags);
1227         if ((ret == -EAGAIN || ret == -EINPROGRESS) && force_nonblock) {
1228                 if (req_has_async_data(req))
1229                         return -EAGAIN;
1230                 if (io_alloc_async_data(req)) {
1231                         ret = -ENOMEM;
1232                         goto out;
1233                 }
1234                 memcpy(req->async_data, &__io, sizeof(__io));
1235                 return -EAGAIN;
1236         }
1237         if (ret == -ERESTARTSYS)
1238                 ret = -EINTR;
1239 out:
1240         if (ret < 0)
1241                 req_set_fail(req);
1242         io_req_set_res(req, ret, 0);
1243         return IOU_OK;
1244 }
1245
1246 void io_netmsg_cache_free(struct io_cache_entry *entry)
1247 {
1248         kfree(container_of(entry, struct io_async_msghdr, cache));
1249 }
1250 #endif