net: fix compat pointer in get_compat_msghdr()
[platform/kernel/linux-starfive.git] / io_uring / net.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/kernel.h>
3 #include <linux/errno.h>
4 #include <linux/file.h>
5 #include <linux/slab.h>
6 #include <linux/net.h>
7 #include <linux/compat.h>
8 #include <net/compat.h>
9 #include <linux/io_uring.h>
10
11 #include <uapi/linux/io_uring.h>
12
13 #include "io_uring.h"
14 #include "kbuf.h"
15 #include "alloc_cache.h"
16 #include "net.h"
17
18 #if defined(CONFIG_NET)
19 struct io_shutdown {
20         struct file                     *file;
21         int                             how;
22 };
23
24 struct io_accept {
25         struct file                     *file;
26         struct sockaddr __user          *addr;
27         int __user                      *addr_len;
28         int                             flags;
29         u32                             file_slot;
30         unsigned long                   nofile;
31 };
32
33 struct io_socket {
34         struct file                     *file;
35         int                             domain;
36         int                             type;
37         int                             protocol;
38         int                             flags;
39         u32                             file_slot;
40         unsigned long                   nofile;
41 };
42
43 struct io_connect {
44         struct file                     *file;
45         struct sockaddr __user          *addr;
46         int                             addr_len;
47 };
48
49 struct io_sr_msg {
50         struct file                     *file;
51         union {
52                 struct compat_msghdr __user     *umsg_compat;
53                 struct user_msghdr __user       *umsg;
54                 void __user                     *buf;
55         };
56         int                             msg_flags;
57         size_t                          len;
58         size_t                          done_io;
59         unsigned int                    flags;
60 };
61
62 #define IO_APOLL_MULTI_POLLED (REQ_F_APOLL_MULTISHOT | REQ_F_POLLED)
63
64 int io_shutdown_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
65 {
66         struct io_shutdown *shutdown = io_kiocb_to_cmd(req);
67
68         if (unlikely(sqe->off || sqe->addr || sqe->rw_flags ||
69                      sqe->buf_index || sqe->splice_fd_in))
70                 return -EINVAL;
71
72         shutdown->how = READ_ONCE(sqe->len);
73         return 0;
74 }
75
76 int io_shutdown(struct io_kiocb *req, unsigned int issue_flags)
77 {
78         struct io_shutdown *shutdown = io_kiocb_to_cmd(req);
79         struct socket *sock;
80         int ret;
81
82         if (issue_flags & IO_URING_F_NONBLOCK)
83                 return -EAGAIN;
84
85         sock = sock_from_file(req->file);
86         if (unlikely(!sock))
87                 return -ENOTSOCK;
88
89         ret = __sys_shutdown_sock(sock, shutdown->how);
90         io_req_set_res(req, ret, 0);
91         return IOU_OK;
92 }
93
94 static bool io_net_retry(struct socket *sock, int flags)
95 {
96         if (!(flags & MSG_WAITALL))
97                 return false;
98         return sock->type == SOCK_STREAM || sock->type == SOCK_SEQPACKET;
99 }
100
101 static void io_netmsg_recycle(struct io_kiocb *req, unsigned int issue_flags)
102 {
103         struct io_async_msghdr *hdr = req->async_data;
104
105         if (!hdr || issue_flags & IO_URING_F_UNLOCKED)
106                 return;
107
108         /* Let normal cleanup path reap it if we fail adding to the cache */
109         if (io_alloc_cache_put(&req->ctx->netmsg_cache, &hdr->cache)) {
110                 req->async_data = NULL;
111                 req->flags &= ~REQ_F_ASYNC_DATA;
112         }
113 }
114
115 static struct io_async_msghdr *io_recvmsg_alloc_async(struct io_kiocb *req,
116                                                       unsigned int issue_flags)
117 {
118         struct io_ring_ctx *ctx = req->ctx;
119         struct io_cache_entry *entry;
120
121         if (!(issue_flags & IO_URING_F_UNLOCKED) &&
122             (entry = io_alloc_cache_get(&ctx->netmsg_cache)) != NULL) {
123                 struct io_async_msghdr *hdr;
124
125                 hdr = container_of(entry, struct io_async_msghdr, cache);
126                 req->flags |= REQ_F_ASYNC_DATA;
127                 req->async_data = hdr;
128                 return hdr;
129         }
130
131         if (!io_alloc_async_data(req))
132                 return req->async_data;
133
134         return NULL;
135 }
136
137 static int io_setup_async_msg(struct io_kiocb *req,
138                               struct io_async_msghdr *kmsg,
139                               unsigned int issue_flags)
140 {
141         struct io_async_msghdr *async_msg = req->async_data;
142
143         if (async_msg)
144                 return -EAGAIN;
145         async_msg = io_recvmsg_alloc_async(req, issue_flags);
146         if (!async_msg) {
147                 kfree(kmsg->free_iov);
148                 return -ENOMEM;
149         }
150         req->flags |= REQ_F_NEED_CLEANUP;
151         memcpy(async_msg, kmsg, sizeof(*kmsg));
152         async_msg->msg.msg_name = &async_msg->addr;
153         /* if were using fast_iov, set it to the new one */
154         if (!async_msg->free_iov)
155                 async_msg->msg.msg_iter.iov = async_msg->fast_iov;
156
157         return -EAGAIN;
158 }
159
160 static int io_sendmsg_copy_hdr(struct io_kiocb *req,
161                                struct io_async_msghdr *iomsg)
162 {
163         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
164
165         iomsg->msg.msg_name = &iomsg->addr;
166         iomsg->free_iov = iomsg->fast_iov;
167         return sendmsg_copy_msghdr(&iomsg->msg, sr->umsg, sr->msg_flags,
168                                         &iomsg->free_iov);
169 }
170
171 int io_sendmsg_prep_async(struct io_kiocb *req)
172 {
173         int ret;
174
175         ret = io_sendmsg_copy_hdr(req, req->async_data);
176         if (!ret)
177                 req->flags |= REQ_F_NEED_CLEANUP;
178         return ret;
179 }
180
181 void io_sendmsg_recvmsg_cleanup(struct io_kiocb *req)
182 {
183         struct io_async_msghdr *io = req->async_data;
184
185         kfree(io->free_iov);
186 }
187
188 int io_sendmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
189 {
190         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
191
192         if (unlikely(sqe->file_index || sqe->addr2))
193                 return -EINVAL;
194
195         sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr));
196         sr->len = READ_ONCE(sqe->len);
197         sr->flags = READ_ONCE(sqe->ioprio);
198         if (sr->flags & ~IORING_RECVSEND_POLL_FIRST)
199                 return -EINVAL;
200         sr->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
201         if (sr->msg_flags & MSG_DONTWAIT)
202                 req->flags |= REQ_F_NOWAIT;
203
204 #ifdef CONFIG_COMPAT
205         if (req->ctx->compat)
206                 sr->msg_flags |= MSG_CMSG_COMPAT;
207 #endif
208         sr->done_io = 0;
209         return 0;
210 }
211
212 int io_sendmsg(struct io_kiocb *req, unsigned int issue_flags)
213 {
214         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
215         struct io_async_msghdr iomsg, *kmsg;
216         struct socket *sock;
217         unsigned flags;
218         int min_ret = 0;
219         int ret;
220
221         sock = sock_from_file(req->file);
222         if (unlikely(!sock))
223                 return -ENOTSOCK;
224
225         if (req_has_async_data(req)) {
226                 kmsg = req->async_data;
227         } else {
228                 ret = io_sendmsg_copy_hdr(req, &iomsg);
229                 if (ret)
230                         return ret;
231                 kmsg = &iomsg;
232         }
233
234         if (!(req->flags & REQ_F_POLLED) &&
235             (sr->flags & IORING_RECVSEND_POLL_FIRST))
236                 return io_setup_async_msg(req, kmsg, issue_flags);
237
238         flags = sr->msg_flags;
239         if (issue_flags & IO_URING_F_NONBLOCK)
240                 flags |= MSG_DONTWAIT;
241         if (flags & MSG_WAITALL)
242                 min_ret = iov_iter_count(&kmsg->msg.msg_iter);
243
244         ret = __sys_sendmsg_sock(sock, &kmsg->msg, flags);
245
246         if (ret < min_ret) {
247                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
248                         return io_setup_async_msg(req, kmsg, issue_flags);
249                 if (ret == -ERESTARTSYS)
250                         ret = -EINTR;
251                 if (ret > 0 && io_net_retry(sock, flags)) {
252                         sr->done_io += ret;
253                         req->flags |= REQ_F_PARTIAL_IO;
254                         return io_setup_async_msg(req, kmsg, issue_flags);
255                 }
256                 req_set_fail(req);
257         }
258         /* fast path, check for non-NULL to avoid function call */
259         if (kmsg->free_iov)
260                 kfree(kmsg->free_iov);
261         req->flags &= ~REQ_F_NEED_CLEANUP;
262         io_netmsg_recycle(req, issue_flags);
263         if (ret >= 0)
264                 ret += sr->done_io;
265         else if (sr->done_io)
266                 ret = sr->done_io;
267         io_req_set_res(req, ret, 0);
268         return IOU_OK;
269 }
270
271 int io_send(struct io_kiocb *req, unsigned int issue_flags)
272 {
273         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
274         struct msghdr msg;
275         struct iovec iov;
276         struct socket *sock;
277         unsigned flags;
278         int min_ret = 0;
279         int ret;
280
281         if (!(req->flags & REQ_F_POLLED) &&
282             (sr->flags & IORING_RECVSEND_POLL_FIRST))
283                 return -EAGAIN;
284
285         sock = sock_from_file(req->file);
286         if (unlikely(!sock))
287                 return -ENOTSOCK;
288
289         ret = import_single_range(WRITE, sr->buf, sr->len, &iov, &msg.msg_iter);
290         if (unlikely(ret))
291                 return ret;
292
293         msg.msg_name = NULL;
294         msg.msg_control = NULL;
295         msg.msg_controllen = 0;
296         msg.msg_namelen = 0;
297
298         flags = sr->msg_flags;
299         if (issue_flags & IO_URING_F_NONBLOCK)
300                 flags |= MSG_DONTWAIT;
301         if (flags & MSG_WAITALL)
302                 min_ret = iov_iter_count(&msg.msg_iter);
303
304         msg.msg_flags = flags;
305         ret = sock_sendmsg(sock, &msg);
306         if (ret < min_ret) {
307                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
308                         return -EAGAIN;
309                 if (ret == -ERESTARTSYS)
310                         ret = -EINTR;
311                 if (ret > 0 && io_net_retry(sock, flags)) {
312                         sr->len -= ret;
313                         sr->buf += ret;
314                         sr->done_io += ret;
315                         req->flags |= REQ_F_PARTIAL_IO;
316                         return -EAGAIN;
317                 }
318                 req_set_fail(req);
319         }
320         if (ret >= 0)
321                 ret += sr->done_io;
322         else if (sr->done_io)
323                 ret = sr->done_io;
324         io_req_set_res(req, ret, 0);
325         return IOU_OK;
326 }
327
328 static bool io_recvmsg_multishot_overflow(struct io_async_msghdr *iomsg)
329 {
330         int hdr;
331
332         if (iomsg->namelen < 0)
333                 return true;
334         if (check_add_overflow((int)sizeof(struct io_uring_recvmsg_out),
335                                iomsg->namelen, &hdr))
336                 return true;
337         if (check_add_overflow(hdr, (int)iomsg->controllen, &hdr))
338                 return true;
339
340         return false;
341 }
342
343 static int __io_recvmsg_copy_hdr(struct io_kiocb *req,
344                                  struct io_async_msghdr *iomsg)
345 {
346         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
347         struct user_msghdr msg;
348         int ret;
349
350         if (copy_from_user(&msg, sr->umsg, sizeof(*sr->umsg)))
351                 return -EFAULT;
352
353         ret = __copy_msghdr(&iomsg->msg, &msg, &iomsg->uaddr);
354         if (ret)
355                 return ret;
356
357         if (req->flags & REQ_F_BUFFER_SELECT) {
358                 if (msg.msg_iovlen == 0) {
359                         sr->len = iomsg->fast_iov[0].iov_len = 0;
360                         iomsg->fast_iov[0].iov_base = NULL;
361                         iomsg->free_iov = NULL;
362                 } else if (msg.msg_iovlen > 1) {
363                         return -EINVAL;
364                 } else {
365                         if (copy_from_user(iomsg->fast_iov, msg.msg_iov, sizeof(*msg.msg_iov)))
366                                 return -EFAULT;
367                         sr->len = iomsg->fast_iov[0].iov_len;
368                         iomsg->free_iov = NULL;
369                 }
370
371                 if (req->flags & REQ_F_APOLL_MULTISHOT) {
372                         iomsg->namelen = msg.msg_namelen;
373                         iomsg->controllen = msg.msg_controllen;
374                         if (io_recvmsg_multishot_overflow(iomsg))
375                                 return -EOVERFLOW;
376                 }
377         } else {
378                 iomsg->free_iov = iomsg->fast_iov;
379                 ret = __import_iovec(READ, msg.msg_iov, msg.msg_iovlen, UIO_FASTIOV,
380                                      &iomsg->free_iov, &iomsg->msg.msg_iter,
381                                      false);
382                 if (ret > 0)
383                         ret = 0;
384         }
385
386         return ret;
387 }
388
389 #ifdef CONFIG_COMPAT
390 static int __io_compat_recvmsg_copy_hdr(struct io_kiocb *req,
391                                         struct io_async_msghdr *iomsg)
392 {
393         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
394         struct compat_msghdr msg;
395         struct compat_iovec __user *uiov;
396         int ret;
397
398         if (copy_from_user(&msg, sr->umsg_compat, sizeof(msg)))
399                 return -EFAULT;
400
401         ret = __get_compat_msghdr(&iomsg->msg, &msg, &iomsg->uaddr);
402         if (ret)
403                 return ret;
404
405         uiov = compat_ptr(msg.msg_iov);
406         if (req->flags & REQ_F_BUFFER_SELECT) {
407                 compat_ssize_t clen;
408
409                 if (msg.msg_iovlen == 0) {
410                         sr->len = 0;
411                         iomsg->free_iov = NULL;
412                 } else if (msg.msg_iovlen > 1) {
413                         return -EINVAL;
414                 } else {
415                         if (!access_ok(uiov, sizeof(*uiov)))
416                                 return -EFAULT;
417                         if (__get_user(clen, &uiov->iov_len))
418                                 return -EFAULT;
419                         if (clen < 0)
420                                 return -EINVAL;
421                         sr->len = clen;
422                         iomsg->free_iov = NULL;
423                 }
424
425                 if (req->flags & REQ_F_APOLL_MULTISHOT) {
426                         iomsg->namelen = msg.msg_namelen;
427                         iomsg->controllen = msg.msg_controllen;
428                         if (io_recvmsg_multishot_overflow(iomsg))
429                                 return -EOVERFLOW;
430                 }
431         } else {
432                 iomsg->free_iov = iomsg->fast_iov;
433                 ret = __import_iovec(READ, (struct iovec __user *)uiov, msg.msg_iovlen,
434                                    UIO_FASTIOV, &iomsg->free_iov,
435                                    &iomsg->msg.msg_iter, true);
436                 if (ret < 0)
437                         return ret;
438         }
439
440         return 0;
441 }
442 #endif
443
444 static int io_recvmsg_copy_hdr(struct io_kiocb *req,
445                                struct io_async_msghdr *iomsg)
446 {
447         iomsg->msg.msg_name = &iomsg->addr;
448
449 #ifdef CONFIG_COMPAT
450         if (req->ctx->compat)
451                 return __io_compat_recvmsg_copy_hdr(req, iomsg);
452 #endif
453
454         return __io_recvmsg_copy_hdr(req, iomsg);
455 }
456
457 int io_recvmsg_prep_async(struct io_kiocb *req)
458 {
459         int ret;
460
461         ret = io_recvmsg_copy_hdr(req, req->async_data);
462         if (!ret)
463                 req->flags |= REQ_F_NEED_CLEANUP;
464         return ret;
465 }
466
467 #define RECVMSG_FLAGS (IORING_RECVSEND_POLL_FIRST | IORING_RECV_MULTISHOT)
468
469 int io_recvmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
470 {
471         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
472
473         if (unlikely(sqe->file_index || sqe->addr2))
474                 return -EINVAL;
475
476         sr->umsg = u64_to_user_ptr(READ_ONCE(sqe->addr));
477         sr->len = READ_ONCE(sqe->len);
478         sr->flags = READ_ONCE(sqe->ioprio);
479         if (sr->flags & ~(RECVMSG_FLAGS))
480                 return -EINVAL;
481         sr->msg_flags = READ_ONCE(sqe->msg_flags) | MSG_NOSIGNAL;
482         if (sr->msg_flags & MSG_DONTWAIT)
483                 req->flags |= REQ_F_NOWAIT;
484         if (sr->msg_flags & MSG_ERRQUEUE)
485                 req->flags |= REQ_F_CLEAR_POLLIN;
486         if (sr->flags & IORING_RECV_MULTISHOT) {
487                 if (!(req->flags & REQ_F_BUFFER_SELECT))
488                         return -EINVAL;
489                 if (sr->msg_flags & MSG_WAITALL)
490                         return -EINVAL;
491                 if (req->opcode == IORING_OP_RECV && sr->len)
492                         return -EINVAL;
493                 req->flags |= REQ_F_APOLL_MULTISHOT;
494         }
495
496 #ifdef CONFIG_COMPAT
497         if (req->ctx->compat)
498                 sr->msg_flags |= MSG_CMSG_COMPAT;
499 #endif
500         sr->done_io = 0;
501         return 0;
502 }
503
504 static inline void io_recv_prep_retry(struct io_kiocb *req)
505 {
506         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
507
508         sr->done_io = 0;
509         sr->len = 0; /* get from the provided buffer */
510 }
511
512 /*
513  * Finishes io_recv and io_recvmsg.
514  *
515  * Returns true if it is actually finished, or false if it should run
516  * again (for multishot).
517  */
518 static inline bool io_recv_finish(struct io_kiocb *req, int *ret,
519                                   unsigned int cflags, bool mshot_finished)
520 {
521         if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
522                 io_req_set_res(req, *ret, cflags);
523                 *ret = IOU_OK;
524                 return true;
525         }
526
527         if (!mshot_finished) {
528                 if (io_post_aux_cqe(req->ctx, req->cqe.user_data, *ret,
529                                     cflags | IORING_CQE_F_MORE, false)) {
530                         io_recv_prep_retry(req);
531                         return false;
532                 }
533                 /*
534                  * Otherwise stop multishot but use the current result.
535                  * Probably will end up going into overflow, but this means
536                  * we cannot trust the ordering anymore
537                  */
538         }
539
540         io_req_set_res(req, *ret, cflags);
541
542         if (req->flags & REQ_F_POLLED)
543                 *ret = IOU_STOP_MULTISHOT;
544         else
545                 *ret = IOU_OK;
546         return true;
547 }
548
549 static int io_recvmsg_prep_multishot(struct io_async_msghdr *kmsg,
550                                      struct io_sr_msg *sr, void __user **buf,
551                                      size_t *len)
552 {
553         unsigned long ubuf = (unsigned long) *buf;
554         unsigned long hdr;
555
556         hdr = sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
557                 kmsg->controllen;
558         if (*len < hdr)
559                 return -EFAULT;
560
561         if (kmsg->controllen) {
562                 unsigned long control = ubuf + hdr - kmsg->controllen;
563
564                 kmsg->msg.msg_control_user = (void *) control;
565                 kmsg->msg.msg_controllen = kmsg->controllen;
566         }
567
568         sr->buf = *buf; /* stash for later copy */
569         *buf = (void *) (ubuf + hdr);
570         kmsg->payloadlen = *len = *len - hdr;
571         return 0;
572 }
573
574 struct io_recvmsg_multishot_hdr {
575         struct io_uring_recvmsg_out msg;
576         struct sockaddr_storage addr;
577 };
578
579 static int io_recvmsg_multishot(struct socket *sock, struct io_sr_msg *io,
580                                 struct io_async_msghdr *kmsg,
581                                 unsigned int flags, bool *finished)
582 {
583         int err;
584         int copy_len;
585         struct io_recvmsg_multishot_hdr hdr;
586
587         if (kmsg->namelen)
588                 kmsg->msg.msg_name = &hdr.addr;
589         kmsg->msg.msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT);
590         kmsg->msg.msg_namelen = 0;
591
592         if (sock->file->f_flags & O_NONBLOCK)
593                 flags |= MSG_DONTWAIT;
594
595         err = sock_recvmsg(sock, &kmsg->msg, flags);
596         *finished = err <= 0;
597         if (err < 0)
598                 return err;
599
600         hdr.msg = (struct io_uring_recvmsg_out) {
601                 .controllen = kmsg->controllen - kmsg->msg.msg_controllen,
602                 .flags = kmsg->msg.msg_flags & ~MSG_CMSG_COMPAT
603         };
604
605         hdr.msg.payloadlen = err;
606         if (err > kmsg->payloadlen)
607                 err = kmsg->payloadlen;
608
609         copy_len = sizeof(struct io_uring_recvmsg_out);
610         if (kmsg->msg.msg_namelen > kmsg->namelen)
611                 copy_len += kmsg->namelen;
612         else
613                 copy_len += kmsg->msg.msg_namelen;
614
615         /*
616          *      "fromlen shall refer to the value before truncation.."
617          *                      1003.1g
618          */
619         hdr.msg.namelen = kmsg->msg.msg_namelen;
620
621         /* ensure that there is no gap between hdr and sockaddr_storage */
622         BUILD_BUG_ON(offsetof(struct io_recvmsg_multishot_hdr, addr) !=
623                      sizeof(struct io_uring_recvmsg_out));
624         if (copy_to_user(io->buf, &hdr, copy_len)) {
625                 *finished = true;
626                 return -EFAULT;
627         }
628
629         return sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
630                         kmsg->controllen + err;
631 }
632
633 int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
634 {
635         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
636         struct io_async_msghdr iomsg, *kmsg;
637         struct socket *sock;
638         unsigned int cflags;
639         unsigned flags;
640         int ret, min_ret = 0;
641         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
642         bool mshot_finished = true;
643
644         sock = sock_from_file(req->file);
645         if (unlikely(!sock))
646                 return -ENOTSOCK;
647
648         if (req_has_async_data(req)) {
649                 kmsg = req->async_data;
650         } else {
651                 ret = io_recvmsg_copy_hdr(req, &iomsg);
652                 if (ret)
653                         return ret;
654                 kmsg = &iomsg;
655         }
656
657         if (!(req->flags & REQ_F_POLLED) &&
658             (sr->flags & IORING_RECVSEND_POLL_FIRST))
659                 return io_setup_async_msg(req, kmsg, issue_flags);
660
661 retry_multishot:
662         if (io_do_buffer_select(req)) {
663                 void __user *buf;
664                 size_t len = sr->len;
665
666                 buf = io_buffer_select(req, &len, issue_flags);
667                 if (!buf)
668                         return -ENOBUFS;
669
670                 if (req->flags & REQ_F_APOLL_MULTISHOT) {
671                         ret = io_recvmsg_prep_multishot(kmsg, sr, &buf, &len);
672                         if (ret) {
673                                 io_kbuf_recycle(req, issue_flags);
674                                 return ret;
675                         }
676                 }
677
678                 kmsg->fast_iov[0].iov_base = buf;
679                 kmsg->fast_iov[0].iov_len = len;
680                 iov_iter_init(&kmsg->msg.msg_iter, READ, kmsg->fast_iov, 1,
681                                 len);
682         }
683
684         flags = sr->msg_flags;
685         if (force_nonblock)
686                 flags |= MSG_DONTWAIT;
687         if (flags & MSG_WAITALL)
688                 min_ret = iov_iter_count(&kmsg->msg.msg_iter);
689
690         kmsg->msg.msg_get_inq = 1;
691         if (req->flags & REQ_F_APOLL_MULTISHOT)
692                 ret = io_recvmsg_multishot(sock, sr, kmsg, flags,
693                                            &mshot_finished);
694         else
695                 ret = __sys_recvmsg_sock(sock, &kmsg->msg, sr->umsg,
696                                          kmsg->uaddr, flags);
697
698         if (ret < min_ret) {
699                 if (ret == -EAGAIN && force_nonblock) {
700                         ret = io_setup_async_msg(req, kmsg, issue_flags);
701                         if (ret == -EAGAIN && (req->flags & IO_APOLL_MULTI_POLLED) ==
702                                                IO_APOLL_MULTI_POLLED) {
703                                 io_kbuf_recycle(req, issue_flags);
704                                 return IOU_ISSUE_SKIP_COMPLETE;
705                         }
706                         return ret;
707                 }
708                 if (ret == -ERESTARTSYS)
709                         ret = -EINTR;
710                 if (ret > 0 && io_net_retry(sock, flags)) {
711                         sr->done_io += ret;
712                         req->flags |= REQ_F_PARTIAL_IO;
713                         return io_setup_async_msg(req, kmsg, issue_flags);
714                 }
715                 req_set_fail(req);
716         } else if ((flags & MSG_WAITALL) && (kmsg->msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
717                 req_set_fail(req);
718         }
719
720         if (ret > 0)
721                 ret += sr->done_io;
722         else if (sr->done_io)
723                 ret = sr->done_io;
724         else
725                 io_kbuf_recycle(req, issue_flags);
726
727         cflags = io_put_kbuf(req, issue_flags);
728         if (kmsg->msg.msg_inq)
729                 cflags |= IORING_CQE_F_SOCK_NONEMPTY;
730
731         if (!io_recv_finish(req, &ret, cflags, mshot_finished))
732                 goto retry_multishot;
733
734         if (mshot_finished) {
735                 io_netmsg_recycle(req, issue_flags);
736                 /* fast path, check for non-NULL to avoid function call */
737                 if (kmsg->free_iov)
738                         kfree(kmsg->free_iov);
739                 req->flags &= ~REQ_F_NEED_CLEANUP;
740         }
741
742         return ret;
743 }
744
745 int io_recv(struct io_kiocb *req, unsigned int issue_flags)
746 {
747         struct io_sr_msg *sr = io_kiocb_to_cmd(req);
748         struct msghdr msg;
749         struct socket *sock;
750         struct iovec iov;
751         unsigned int cflags;
752         unsigned flags;
753         int ret, min_ret = 0;
754         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
755         size_t len = sr->len;
756
757         if (!(req->flags & REQ_F_POLLED) &&
758             (sr->flags & IORING_RECVSEND_POLL_FIRST))
759                 return -EAGAIN;
760
761         sock = sock_from_file(req->file);
762         if (unlikely(!sock))
763                 return -ENOTSOCK;
764
765 retry_multishot:
766         if (io_do_buffer_select(req)) {
767                 void __user *buf;
768
769                 buf = io_buffer_select(req, &len, issue_flags);
770                 if (!buf)
771                         return -ENOBUFS;
772                 sr->buf = buf;
773         }
774
775         ret = import_single_range(READ, sr->buf, len, &iov, &msg.msg_iter);
776         if (unlikely(ret))
777                 goto out_free;
778
779         msg.msg_name = NULL;
780         msg.msg_namelen = 0;
781         msg.msg_control = NULL;
782         msg.msg_get_inq = 1;
783         msg.msg_flags = 0;
784         msg.msg_controllen = 0;
785         msg.msg_iocb = NULL;
786
787         flags = sr->msg_flags;
788         if (force_nonblock)
789                 flags |= MSG_DONTWAIT;
790         if (flags & MSG_WAITALL)
791                 min_ret = iov_iter_count(&msg.msg_iter);
792
793         ret = sock_recvmsg(sock, &msg, flags);
794         if (ret < min_ret) {
795                 if (ret == -EAGAIN && force_nonblock) {
796                         if ((req->flags & IO_APOLL_MULTI_POLLED) == IO_APOLL_MULTI_POLLED) {
797                                 io_kbuf_recycle(req, issue_flags);
798                                 return IOU_ISSUE_SKIP_COMPLETE;
799                         }
800
801                         return -EAGAIN;
802                 }
803                 if (ret == -ERESTARTSYS)
804                         ret = -EINTR;
805                 if (ret > 0 && io_net_retry(sock, flags)) {
806                         sr->len -= ret;
807                         sr->buf += ret;
808                         sr->done_io += ret;
809                         req->flags |= REQ_F_PARTIAL_IO;
810                         return -EAGAIN;
811                 }
812                 req_set_fail(req);
813         } else if ((flags & MSG_WAITALL) && (msg.msg_flags & (MSG_TRUNC | MSG_CTRUNC))) {
814 out_free:
815                 req_set_fail(req);
816         }
817
818         if (ret > 0)
819                 ret += sr->done_io;
820         else if (sr->done_io)
821                 ret = sr->done_io;
822         else
823                 io_kbuf_recycle(req, issue_flags);
824
825         cflags = io_put_kbuf(req, issue_flags);
826         if (msg.msg_inq)
827                 cflags |= IORING_CQE_F_SOCK_NONEMPTY;
828
829         if (!io_recv_finish(req, &ret, cflags, ret <= 0))
830                 goto retry_multishot;
831
832         return ret;
833 }
834
835 int io_accept_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
836 {
837         struct io_accept *accept = io_kiocb_to_cmd(req);
838         unsigned flags;
839
840         if (sqe->len || sqe->buf_index)
841                 return -EINVAL;
842
843         accept->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
844         accept->addr_len = u64_to_user_ptr(READ_ONCE(sqe->addr2));
845         accept->flags = READ_ONCE(sqe->accept_flags);
846         accept->nofile = rlimit(RLIMIT_NOFILE);
847         flags = READ_ONCE(sqe->ioprio);
848         if (flags & ~IORING_ACCEPT_MULTISHOT)
849                 return -EINVAL;
850
851         accept->file_slot = READ_ONCE(sqe->file_index);
852         if (accept->file_slot) {
853                 if (accept->flags & SOCK_CLOEXEC)
854                         return -EINVAL;
855                 if (flags & IORING_ACCEPT_MULTISHOT &&
856                     accept->file_slot != IORING_FILE_INDEX_ALLOC)
857                         return -EINVAL;
858         }
859         if (accept->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
860                 return -EINVAL;
861         if (SOCK_NONBLOCK != O_NONBLOCK && (accept->flags & SOCK_NONBLOCK))
862                 accept->flags = (accept->flags & ~SOCK_NONBLOCK) | O_NONBLOCK;
863         if (flags & IORING_ACCEPT_MULTISHOT)
864                 req->flags |= REQ_F_APOLL_MULTISHOT;
865         return 0;
866 }
867
868 int io_accept(struct io_kiocb *req, unsigned int issue_flags)
869 {
870         struct io_ring_ctx *ctx = req->ctx;
871         struct io_accept *accept = io_kiocb_to_cmd(req);
872         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
873         unsigned int file_flags = force_nonblock ? O_NONBLOCK : 0;
874         bool fixed = !!accept->file_slot;
875         struct file *file;
876         int ret, fd;
877
878 retry:
879         if (!fixed) {
880                 fd = __get_unused_fd_flags(accept->flags, accept->nofile);
881                 if (unlikely(fd < 0))
882                         return fd;
883         }
884         file = do_accept(req->file, file_flags, accept->addr, accept->addr_len,
885                          accept->flags);
886         if (IS_ERR(file)) {
887                 if (!fixed)
888                         put_unused_fd(fd);
889                 ret = PTR_ERR(file);
890                 if (ret == -EAGAIN && force_nonblock) {
891                         /*
892                          * if it's multishot and polled, we don't need to
893                          * return EAGAIN to arm the poll infra since it
894                          * has already been done
895                          */
896                         if ((req->flags & IO_APOLL_MULTI_POLLED) ==
897                             IO_APOLL_MULTI_POLLED)
898                                 ret = IOU_ISSUE_SKIP_COMPLETE;
899                         return ret;
900                 }
901                 if (ret == -ERESTARTSYS)
902                         ret = -EINTR;
903                 req_set_fail(req);
904         } else if (!fixed) {
905                 fd_install(fd, file);
906                 ret = fd;
907         } else {
908                 ret = io_fixed_fd_install(req, issue_flags, file,
909                                                 accept->file_slot);
910         }
911
912         if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
913                 io_req_set_res(req, ret, 0);
914                 return IOU_OK;
915         }
916
917         if (ret >= 0 &&
918             io_post_aux_cqe(ctx, req->cqe.user_data, ret, IORING_CQE_F_MORE, false))
919                 goto retry;
920
921         io_req_set_res(req, ret, 0);
922         if (req->flags & REQ_F_POLLED)
923                 return IOU_STOP_MULTISHOT;
924         return IOU_OK;
925 }
926
927 int io_socket_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
928 {
929         struct io_socket *sock = io_kiocb_to_cmd(req);
930
931         if (sqe->addr || sqe->rw_flags || sqe->buf_index)
932                 return -EINVAL;
933
934         sock->domain = READ_ONCE(sqe->fd);
935         sock->type = READ_ONCE(sqe->off);
936         sock->protocol = READ_ONCE(sqe->len);
937         sock->file_slot = READ_ONCE(sqe->file_index);
938         sock->nofile = rlimit(RLIMIT_NOFILE);
939
940         sock->flags = sock->type & ~SOCK_TYPE_MASK;
941         if (sock->file_slot && (sock->flags & SOCK_CLOEXEC))
942                 return -EINVAL;
943         if (sock->flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
944                 return -EINVAL;
945         return 0;
946 }
947
948 int io_socket(struct io_kiocb *req, unsigned int issue_flags)
949 {
950         struct io_socket *sock = io_kiocb_to_cmd(req);
951         bool fixed = !!sock->file_slot;
952         struct file *file;
953         int ret, fd;
954
955         if (!fixed) {
956                 fd = __get_unused_fd_flags(sock->flags, sock->nofile);
957                 if (unlikely(fd < 0))
958                         return fd;
959         }
960         file = __sys_socket_file(sock->domain, sock->type, sock->protocol);
961         if (IS_ERR(file)) {
962                 if (!fixed)
963                         put_unused_fd(fd);
964                 ret = PTR_ERR(file);
965                 if (ret == -EAGAIN && (issue_flags & IO_URING_F_NONBLOCK))
966                         return -EAGAIN;
967                 if (ret == -ERESTARTSYS)
968                         ret = -EINTR;
969                 req_set_fail(req);
970         } else if (!fixed) {
971                 fd_install(fd, file);
972                 ret = fd;
973         } else {
974                 ret = io_fixed_fd_install(req, issue_flags, file,
975                                             sock->file_slot);
976         }
977         io_req_set_res(req, ret, 0);
978         return IOU_OK;
979 }
980
981 int io_connect_prep_async(struct io_kiocb *req)
982 {
983         struct io_async_connect *io = req->async_data;
984         struct io_connect *conn = io_kiocb_to_cmd(req);
985
986         return move_addr_to_kernel(conn->addr, conn->addr_len, &io->address);
987 }
988
989 int io_connect_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
990 {
991         struct io_connect *conn = io_kiocb_to_cmd(req);
992
993         if (sqe->len || sqe->buf_index || sqe->rw_flags || sqe->splice_fd_in)
994                 return -EINVAL;
995
996         conn->addr = u64_to_user_ptr(READ_ONCE(sqe->addr));
997         conn->addr_len =  READ_ONCE(sqe->addr2);
998         return 0;
999 }
1000
1001 int io_connect(struct io_kiocb *req, unsigned int issue_flags)
1002 {
1003         struct io_connect *connect = io_kiocb_to_cmd(req);
1004         struct io_async_connect __io, *io;
1005         unsigned file_flags;
1006         int ret;
1007         bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
1008
1009         if (req_has_async_data(req)) {
1010                 io = req->async_data;
1011         } else {
1012                 ret = move_addr_to_kernel(connect->addr,
1013                                                 connect->addr_len,
1014                                                 &__io.address);
1015                 if (ret)
1016                         goto out;
1017                 io = &__io;
1018         }
1019
1020         file_flags = force_nonblock ? O_NONBLOCK : 0;
1021
1022         ret = __sys_connect_file(req->file, &io->address,
1023                                         connect->addr_len, file_flags);
1024         if ((ret == -EAGAIN || ret == -EINPROGRESS) && force_nonblock) {
1025                 if (req_has_async_data(req))
1026                         return -EAGAIN;
1027                 if (io_alloc_async_data(req)) {
1028                         ret = -ENOMEM;
1029                         goto out;
1030                 }
1031                 memcpy(req->async_data, &__io, sizeof(__io));
1032                 return -EAGAIN;
1033         }
1034         if (ret == -ERESTARTSYS)
1035                 ret = -EINTR;
1036 out:
1037         if (ret < 0)
1038                 req_set_fail(req);
1039         io_req_set_res(req, ret, 0);
1040         return IOU_OK;
1041 }
1042
1043 void io_netmsg_cache_free(struct io_cache_entry *entry)
1044 {
1045         kfree(container_of(entry, struct io_async_msghdr, cache));
1046 }
1047 #endif