io_uring/rsrc: kill rsrc_ref_lock
[platform/kernel/linux-starfive.git] / io_uring / rsrc.c
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/kernel.h>
3 #include <linux/errno.h>
4 #include <linux/fs.h>
5 #include <linux/file.h>
6 #include <linux/mm.h>
7 #include <linux/slab.h>
8 #include <linux/nospec.h>
9 #include <linux/hugetlb.h>
10 #include <linux/compat.h>
11 #include <linux/io_uring.h>
12
13 #include <uapi/linux/io_uring.h>
14
15 #include "io_uring.h"
16 #include "openclose.h"
17 #include "rsrc.h"
18
19 struct io_rsrc_update {
20         struct file                     *file;
21         u64                             arg;
22         u32                             nr_args;
23         u32                             offset;
24 };
25
26 static int io_sqe_buffer_register(struct io_ring_ctx *ctx, struct iovec *iov,
27                                   struct io_mapped_ubuf **pimu,
28                                   struct page **last_hpage);
29
30 /* only define max */
31 #define IORING_MAX_FIXED_FILES  (1U << 20)
32 #define IORING_MAX_REG_BUFFERS  (1U << 14)
33
34 int __io_account_mem(struct user_struct *user, unsigned long nr_pages)
35 {
36         unsigned long page_limit, cur_pages, new_pages;
37
38         if (!nr_pages)
39                 return 0;
40
41         /* Don't allow more pages than we can safely lock */
42         page_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
43
44         cur_pages = atomic_long_read(&user->locked_vm);
45         do {
46                 new_pages = cur_pages + nr_pages;
47                 if (new_pages > page_limit)
48                         return -ENOMEM;
49         } while (!atomic_long_try_cmpxchg(&user->locked_vm,
50                                           &cur_pages, new_pages));
51         return 0;
52 }
53
54 static void io_unaccount_mem(struct io_ring_ctx *ctx, unsigned long nr_pages)
55 {
56         if (ctx->user)
57                 __io_unaccount_mem(ctx->user, nr_pages);
58
59         if (ctx->mm_account)
60                 atomic64_sub(nr_pages, &ctx->mm_account->pinned_vm);
61 }
62
63 static int io_account_mem(struct io_ring_ctx *ctx, unsigned long nr_pages)
64 {
65         int ret;
66
67         if (ctx->user) {
68                 ret = __io_account_mem(ctx->user, nr_pages);
69                 if (ret)
70                         return ret;
71         }
72
73         if (ctx->mm_account)
74                 atomic64_add(nr_pages, &ctx->mm_account->pinned_vm);
75
76         return 0;
77 }
78
79 static int io_copy_iov(struct io_ring_ctx *ctx, struct iovec *dst,
80                        void __user *arg, unsigned index)
81 {
82         struct iovec __user *src;
83
84 #ifdef CONFIG_COMPAT
85         if (ctx->compat) {
86                 struct compat_iovec __user *ciovs;
87                 struct compat_iovec ciov;
88
89                 ciovs = (struct compat_iovec __user *) arg;
90                 if (copy_from_user(&ciov, &ciovs[index], sizeof(ciov)))
91                         return -EFAULT;
92
93                 dst->iov_base = u64_to_user_ptr((u64)ciov.iov_base);
94                 dst->iov_len = ciov.iov_len;
95                 return 0;
96         }
97 #endif
98         src = (struct iovec __user *) arg;
99         if (copy_from_user(dst, &src[index], sizeof(*dst)))
100                 return -EFAULT;
101         return 0;
102 }
103
104 static int io_buffer_validate(struct iovec *iov)
105 {
106         unsigned long tmp, acct_len = iov->iov_len + (PAGE_SIZE - 1);
107
108         /*
109          * Don't impose further limits on the size and buffer
110          * constraints here, we'll -EINVAL later when IO is
111          * submitted if they are wrong.
112          */
113         if (!iov->iov_base)
114                 return iov->iov_len ? -EFAULT : 0;
115         if (!iov->iov_len)
116                 return -EFAULT;
117
118         /* arbitrary limit, but we need something */
119         if (iov->iov_len > SZ_1G)
120                 return -EFAULT;
121
122         if (check_add_overflow((unsigned long)iov->iov_base, acct_len, &tmp))
123                 return -EOVERFLOW;
124
125         return 0;
126 }
127
128 static void io_buffer_unmap(struct io_ring_ctx *ctx, struct io_mapped_ubuf **slot)
129 {
130         struct io_mapped_ubuf *imu = *slot;
131         unsigned int i;
132
133         if (imu != ctx->dummy_ubuf) {
134                 for (i = 0; i < imu->nr_bvecs; i++)
135                         unpin_user_page(imu->bvec[i].bv_page);
136                 if (imu->acct_pages)
137                         io_unaccount_mem(ctx, imu->acct_pages);
138                 kvfree(imu);
139         }
140         *slot = NULL;
141 }
142
143 static void __io_rsrc_put_work(struct io_rsrc_node *ref_node)
144 {
145         struct io_rsrc_data *rsrc_data = ref_node->rsrc_data;
146         struct io_ring_ctx *ctx = rsrc_data->ctx;
147         struct io_rsrc_put *prsrc, *tmp;
148
149         list_for_each_entry_safe(prsrc, tmp, &ref_node->rsrc_list, list) {
150                 list_del(&prsrc->list);
151
152                 if (prsrc->tag) {
153                         if (ctx->flags & IORING_SETUP_IOPOLL) {
154                                 mutex_lock(&ctx->uring_lock);
155                                 io_post_aux_cqe(ctx, prsrc->tag, 0, 0);
156                                 mutex_unlock(&ctx->uring_lock);
157                         } else {
158                                 io_post_aux_cqe(ctx, prsrc->tag, 0, 0);
159                         }
160                 }
161
162                 rsrc_data->do_put(ctx, prsrc);
163                 kfree(prsrc);
164         }
165
166         io_rsrc_node_destroy(ref_node);
167         if (atomic_dec_and_test(&rsrc_data->refs))
168                 complete(&rsrc_data->done);
169 }
170
171 void io_rsrc_put_work(struct work_struct *work)
172 {
173         struct io_ring_ctx *ctx;
174         struct llist_node *node;
175
176         ctx = container_of(work, struct io_ring_ctx, rsrc_put_work.work);
177         node = llist_del_all(&ctx->rsrc_put_llist);
178
179         while (node) {
180                 struct io_rsrc_node *ref_node;
181                 struct llist_node *next = node->next;
182
183                 ref_node = llist_entry(node, struct io_rsrc_node, llist);
184                 __io_rsrc_put_work(ref_node);
185                 node = next;
186         }
187 }
188
189 void io_rsrc_put_tw(struct callback_head *cb)
190 {
191         struct io_ring_ctx *ctx = container_of(cb, struct io_ring_ctx,
192                                                rsrc_put_tw);
193
194         io_rsrc_put_work(&ctx->rsrc_put_work.work);
195 }
196
197 void io_wait_rsrc_data(struct io_rsrc_data *data)
198 {
199         if (data && !atomic_dec_and_test(&data->refs))
200                 wait_for_completion(&data->done);
201 }
202
203 void io_rsrc_node_destroy(struct io_rsrc_node *ref_node)
204 {
205         kfree(ref_node);
206 }
207
208 void io_rsrc_node_ref_zero(struct io_rsrc_node *node)
209         __must_hold(&node->rsrc_data->ctx->uring_lock)
210 {
211         struct io_ring_ctx *ctx = node->rsrc_data->ctx;
212         bool first_add = false;
213         unsigned long delay = HZ;
214
215         node->done = true;
216
217         /* if we are mid-quiesce then do not delay */
218         if (node->rsrc_data->quiesce)
219                 delay = 0;
220
221         while (!list_empty(&ctx->rsrc_ref_list)) {
222                 node = list_first_entry(&ctx->rsrc_ref_list,
223                                             struct io_rsrc_node, node);
224                 /* recycle ref nodes in order */
225                 if (!node->done)
226                         break;
227                 list_del(&node->node);
228                 first_add |= llist_add(&node->llist, &ctx->rsrc_put_llist);
229         }
230
231         if (!first_add)
232                 return;
233
234         if (ctx->submitter_task) {
235                 if (!task_work_add(ctx->submitter_task, &ctx->rsrc_put_tw,
236                                    ctx->notify_method))
237                         return;
238         }
239         mod_delayed_work(system_wq, &ctx->rsrc_put_work, delay);
240 }
241
242 static struct io_rsrc_node *io_rsrc_node_alloc(void)
243 {
244         struct io_rsrc_node *ref_node;
245
246         ref_node = kzalloc(sizeof(*ref_node), GFP_KERNEL);
247         if (!ref_node)
248                 return NULL;
249
250         ref_node->refs = 1;
251         INIT_LIST_HEAD(&ref_node->node);
252         INIT_LIST_HEAD(&ref_node->rsrc_list);
253         ref_node->done = false;
254         return ref_node;
255 }
256
257 void io_rsrc_node_switch(struct io_ring_ctx *ctx,
258                          struct io_rsrc_data *data_to_kill)
259         __must_hold(&ctx->uring_lock)
260 {
261         WARN_ON_ONCE(!ctx->rsrc_backup_node);
262         WARN_ON_ONCE(data_to_kill && !ctx->rsrc_node);
263
264         if (data_to_kill) {
265                 struct io_rsrc_node *rsrc_node = ctx->rsrc_node;
266
267                 rsrc_node->rsrc_data = data_to_kill;
268                 list_add_tail(&rsrc_node->node, &ctx->rsrc_ref_list);
269
270                 atomic_inc(&data_to_kill->refs);
271                 /* put master ref */
272                 io_put_rsrc_node(rsrc_node);
273                 ctx->rsrc_node = NULL;
274         }
275
276         if (!ctx->rsrc_node) {
277                 ctx->rsrc_node = ctx->rsrc_backup_node;
278                 ctx->rsrc_backup_node = NULL;
279         }
280 }
281
282 int io_rsrc_node_switch_start(struct io_ring_ctx *ctx)
283 {
284         if (ctx->rsrc_backup_node)
285                 return 0;
286         ctx->rsrc_backup_node = io_rsrc_node_alloc();
287         return ctx->rsrc_backup_node ? 0 : -ENOMEM;
288 }
289
290 __cold static int io_rsrc_ref_quiesce(struct io_rsrc_data *data,
291                                       struct io_ring_ctx *ctx)
292 {
293         int ret;
294
295         /* As we may drop ->uring_lock, other task may have started quiesce */
296         if (data->quiesce)
297                 return -ENXIO;
298         ret = io_rsrc_node_switch_start(ctx);
299         if (ret)
300                 return ret;
301         io_rsrc_node_switch(ctx, data);
302
303         /* kill initial ref, already quiesced if zero */
304         if (atomic_dec_and_test(&data->refs))
305                 return 0;
306
307         data->quiesce = true;
308         mutex_unlock(&ctx->uring_lock);
309         do {
310                 ret = io_run_task_work_sig(ctx);
311                 if (ret < 0) {
312                         atomic_inc(&data->refs);
313                         /* wait for all works potentially completing data->done */
314                         flush_delayed_work(&ctx->rsrc_put_work);
315                         reinit_completion(&data->done);
316                         mutex_lock(&ctx->uring_lock);
317                         break;
318                 }
319
320                 flush_delayed_work(&ctx->rsrc_put_work);
321                 ret = wait_for_completion_interruptible(&data->done);
322                 if (!ret) {
323                         mutex_lock(&ctx->uring_lock);
324                         if (atomic_read(&data->refs) <= 0)
325                                 break;
326                         /*
327                          * it has been revived by another thread while
328                          * we were unlocked
329                          */
330                         mutex_unlock(&ctx->uring_lock);
331                 }
332         } while (1);
333         data->quiesce = false;
334
335         return ret;
336 }
337
338 static void io_free_page_table(void **table, size_t size)
339 {
340         unsigned i, nr_tables = DIV_ROUND_UP(size, PAGE_SIZE);
341
342         for (i = 0; i < nr_tables; i++)
343                 kfree(table[i]);
344         kfree(table);
345 }
346
347 static void io_rsrc_data_free(struct io_rsrc_data *data)
348 {
349         size_t size = data->nr * sizeof(data->tags[0][0]);
350
351         if (data->tags)
352                 io_free_page_table((void **)data->tags, size);
353         kfree(data);
354 }
355
356 static __cold void **io_alloc_page_table(size_t size)
357 {
358         unsigned i, nr_tables = DIV_ROUND_UP(size, PAGE_SIZE);
359         size_t init_size = size;
360         void **table;
361
362         table = kcalloc(nr_tables, sizeof(*table), GFP_KERNEL_ACCOUNT);
363         if (!table)
364                 return NULL;
365
366         for (i = 0; i < nr_tables; i++) {
367                 unsigned int this_size = min_t(size_t, size, PAGE_SIZE);
368
369                 table[i] = kzalloc(this_size, GFP_KERNEL_ACCOUNT);
370                 if (!table[i]) {
371                         io_free_page_table(table, init_size);
372                         return NULL;
373                 }
374                 size -= this_size;
375         }
376         return table;
377 }
378
379 __cold static int io_rsrc_data_alloc(struct io_ring_ctx *ctx,
380                                      rsrc_put_fn *do_put, u64 __user *utags,
381                                      unsigned nr, struct io_rsrc_data **pdata)
382 {
383         struct io_rsrc_data *data;
384         int ret = 0;
385         unsigned i;
386
387         data = kzalloc(sizeof(*data), GFP_KERNEL);
388         if (!data)
389                 return -ENOMEM;
390         data->tags = (u64 **)io_alloc_page_table(nr * sizeof(data->tags[0][0]));
391         if (!data->tags) {
392                 kfree(data);
393                 return -ENOMEM;
394         }
395
396         data->nr = nr;
397         data->ctx = ctx;
398         data->do_put = do_put;
399         if (utags) {
400                 ret = -EFAULT;
401                 for (i = 0; i < nr; i++) {
402                         u64 *tag_slot = io_get_tag_slot(data, i);
403
404                         if (copy_from_user(tag_slot, &utags[i],
405                                            sizeof(*tag_slot)))
406                                 goto fail;
407                 }
408         }
409
410         atomic_set(&data->refs, 1);
411         init_completion(&data->done);
412         *pdata = data;
413         return 0;
414 fail:
415         io_rsrc_data_free(data);
416         return ret;
417 }
418
419 static int __io_sqe_files_update(struct io_ring_ctx *ctx,
420                                  struct io_uring_rsrc_update2 *up,
421                                  unsigned nr_args)
422 {
423         u64 __user *tags = u64_to_user_ptr(up->tags);
424         __s32 __user *fds = u64_to_user_ptr(up->data);
425         struct io_rsrc_data *data = ctx->file_data;
426         struct io_fixed_file *file_slot;
427         struct file *file;
428         int fd, i, err = 0;
429         unsigned int done;
430         bool needs_switch = false;
431
432         if (!ctx->file_data)
433                 return -ENXIO;
434         if (up->offset + nr_args > ctx->nr_user_files)
435                 return -EINVAL;
436
437         for (done = 0; done < nr_args; done++) {
438                 u64 tag = 0;
439
440                 if ((tags && copy_from_user(&tag, &tags[done], sizeof(tag))) ||
441                     copy_from_user(&fd, &fds[done], sizeof(fd))) {
442                         err = -EFAULT;
443                         break;
444                 }
445                 if ((fd == IORING_REGISTER_FILES_SKIP || fd == -1) && tag) {
446                         err = -EINVAL;
447                         break;
448                 }
449                 if (fd == IORING_REGISTER_FILES_SKIP)
450                         continue;
451
452                 i = array_index_nospec(up->offset + done, ctx->nr_user_files);
453                 file_slot = io_fixed_file_slot(&ctx->file_table, i);
454
455                 if (file_slot->file_ptr) {
456                         file = (struct file *)(file_slot->file_ptr & FFS_MASK);
457                         err = io_queue_rsrc_removal(data, i, ctx->rsrc_node, file);
458                         if (err)
459                                 break;
460                         file_slot->file_ptr = 0;
461                         io_file_bitmap_clear(&ctx->file_table, i);
462                         needs_switch = true;
463                 }
464                 if (fd != -1) {
465                         file = fget(fd);
466                         if (!file) {
467                                 err = -EBADF;
468                                 break;
469                         }
470                         /*
471                          * Don't allow io_uring instances to be registered. If
472                          * UNIX isn't enabled, then this causes a reference
473                          * cycle and this instance can never get freed. If UNIX
474                          * is enabled we'll handle it just fine, but there's
475                          * still no point in allowing a ring fd as it doesn't
476                          * support regular read/write anyway.
477                          */
478                         if (io_is_uring_fops(file)) {
479                                 fput(file);
480                                 err = -EBADF;
481                                 break;
482                         }
483                         err = io_scm_file_account(ctx, file);
484                         if (err) {
485                                 fput(file);
486                                 break;
487                         }
488                         *io_get_tag_slot(data, i) = tag;
489                         io_fixed_file_set(file_slot, file);
490                         io_file_bitmap_set(&ctx->file_table, i);
491                 }
492         }
493
494         if (needs_switch)
495                 io_rsrc_node_switch(ctx, data);
496         return done ? done : err;
497 }
498
499 static int __io_sqe_buffers_update(struct io_ring_ctx *ctx,
500                                    struct io_uring_rsrc_update2 *up,
501                                    unsigned int nr_args)
502 {
503         u64 __user *tags = u64_to_user_ptr(up->tags);
504         struct iovec iov, __user *iovs = u64_to_user_ptr(up->data);
505         struct page *last_hpage = NULL;
506         bool needs_switch = false;
507         __u32 done;
508         int i, err;
509
510         if (!ctx->buf_data)
511                 return -ENXIO;
512         if (up->offset + nr_args > ctx->nr_user_bufs)
513                 return -EINVAL;
514
515         for (done = 0; done < nr_args; done++) {
516                 struct io_mapped_ubuf *imu;
517                 int offset = up->offset + done;
518                 u64 tag = 0;
519
520                 err = io_copy_iov(ctx, &iov, iovs, done);
521                 if (err)
522                         break;
523                 if (tags && copy_from_user(&tag, &tags[done], sizeof(tag))) {
524                         err = -EFAULT;
525                         break;
526                 }
527                 err = io_buffer_validate(&iov);
528                 if (err)
529                         break;
530                 if (!iov.iov_base && tag) {
531                         err = -EINVAL;
532                         break;
533                 }
534                 err = io_sqe_buffer_register(ctx, &iov, &imu, &last_hpage);
535                 if (err)
536                         break;
537
538                 i = array_index_nospec(offset, ctx->nr_user_bufs);
539                 if (ctx->user_bufs[i] != ctx->dummy_ubuf) {
540                         err = io_queue_rsrc_removal(ctx->buf_data, i,
541                                                     ctx->rsrc_node, ctx->user_bufs[i]);
542                         if (unlikely(err)) {
543                                 io_buffer_unmap(ctx, &imu);
544                                 break;
545                         }
546                         ctx->user_bufs[i] = ctx->dummy_ubuf;
547                         needs_switch = true;
548                 }
549
550                 ctx->user_bufs[i] = imu;
551                 *io_get_tag_slot(ctx->buf_data, offset) = tag;
552         }
553
554         if (needs_switch)
555                 io_rsrc_node_switch(ctx, ctx->buf_data);
556         return done ? done : err;
557 }
558
559 static int __io_register_rsrc_update(struct io_ring_ctx *ctx, unsigned type,
560                                      struct io_uring_rsrc_update2 *up,
561                                      unsigned nr_args)
562 {
563         __u32 tmp;
564         int err;
565
566         if (check_add_overflow(up->offset, nr_args, &tmp))
567                 return -EOVERFLOW;
568         err = io_rsrc_node_switch_start(ctx);
569         if (err)
570                 return err;
571
572         switch (type) {
573         case IORING_RSRC_FILE:
574                 return __io_sqe_files_update(ctx, up, nr_args);
575         case IORING_RSRC_BUFFER:
576                 return __io_sqe_buffers_update(ctx, up, nr_args);
577         }
578         return -EINVAL;
579 }
580
581 int io_register_files_update(struct io_ring_ctx *ctx, void __user *arg,
582                              unsigned nr_args)
583 {
584         struct io_uring_rsrc_update2 up;
585
586         if (!nr_args)
587                 return -EINVAL;
588         memset(&up, 0, sizeof(up));
589         if (copy_from_user(&up, arg, sizeof(struct io_uring_rsrc_update)))
590                 return -EFAULT;
591         if (up.resv || up.resv2)
592                 return -EINVAL;
593         return __io_register_rsrc_update(ctx, IORING_RSRC_FILE, &up, nr_args);
594 }
595
596 int io_register_rsrc_update(struct io_ring_ctx *ctx, void __user *arg,
597                             unsigned size, unsigned type)
598 {
599         struct io_uring_rsrc_update2 up;
600
601         if (size != sizeof(up))
602                 return -EINVAL;
603         if (copy_from_user(&up, arg, sizeof(up)))
604                 return -EFAULT;
605         if (!up.nr || up.resv || up.resv2)
606                 return -EINVAL;
607         return __io_register_rsrc_update(ctx, type, &up, up.nr);
608 }
609
610 __cold int io_register_rsrc(struct io_ring_ctx *ctx, void __user *arg,
611                             unsigned int size, unsigned int type)
612 {
613         struct io_uring_rsrc_register rr;
614
615         /* keep it extendible */
616         if (size != sizeof(rr))
617                 return -EINVAL;
618
619         memset(&rr, 0, sizeof(rr));
620         if (copy_from_user(&rr, arg, size))
621                 return -EFAULT;
622         if (!rr.nr || rr.resv2)
623                 return -EINVAL;
624         if (rr.flags & ~IORING_RSRC_REGISTER_SPARSE)
625                 return -EINVAL;
626
627         switch (type) {
628         case IORING_RSRC_FILE:
629                 if (rr.flags & IORING_RSRC_REGISTER_SPARSE && rr.data)
630                         break;
631                 return io_sqe_files_register(ctx, u64_to_user_ptr(rr.data),
632                                              rr.nr, u64_to_user_ptr(rr.tags));
633         case IORING_RSRC_BUFFER:
634                 if (rr.flags & IORING_RSRC_REGISTER_SPARSE && rr.data)
635                         break;
636                 return io_sqe_buffers_register(ctx, u64_to_user_ptr(rr.data),
637                                                rr.nr, u64_to_user_ptr(rr.tags));
638         }
639         return -EINVAL;
640 }
641
642 int io_files_update_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
643 {
644         struct io_rsrc_update *up = io_kiocb_to_cmd(req, struct io_rsrc_update);
645
646         if (unlikely(req->flags & (REQ_F_FIXED_FILE | REQ_F_BUFFER_SELECT)))
647                 return -EINVAL;
648         if (sqe->rw_flags || sqe->splice_fd_in)
649                 return -EINVAL;
650
651         up->offset = READ_ONCE(sqe->off);
652         up->nr_args = READ_ONCE(sqe->len);
653         if (!up->nr_args)
654                 return -EINVAL;
655         up->arg = READ_ONCE(sqe->addr);
656         return 0;
657 }
658
659 static int io_files_update_with_index_alloc(struct io_kiocb *req,
660                                             unsigned int issue_flags)
661 {
662         struct io_rsrc_update *up = io_kiocb_to_cmd(req, struct io_rsrc_update);
663         __s32 __user *fds = u64_to_user_ptr(up->arg);
664         unsigned int done;
665         struct file *file;
666         int ret, fd;
667
668         if (!req->ctx->file_data)
669                 return -ENXIO;
670
671         for (done = 0; done < up->nr_args; done++) {
672                 if (copy_from_user(&fd, &fds[done], sizeof(fd))) {
673                         ret = -EFAULT;
674                         break;
675                 }
676
677                 file = fget(fd);
678                 if (!file) {
679                         ret = -EBADF;
680                         break;
681                 }
682                 ret = io_fixed_fd_install(req, issue_flags, file,
683                                           IORING_FILE_INDEX_ALLOC);
684                 if (ret < 0)
685                         break;
686                 if (copy_to_user(&fds[done], &ret, sizeof(ret))) {
687                         __io_close_fixed(req->ctx, issue_flags, ret);
688                         ret = -EFAULT;
689                         break;
690                 }
691         }
692
693         if (done)
694                 return done;
695         return ret;
696 }
697
698 int io_files_update(struct io_kiocb *req, unsigned int issue_flags)
699 {
700         struct io_rsrc_update *up = io_kiocb_to_cmd(req, struct io_rsrc_update);
701         struct io_ring_ctx *ctx = req->ctx;
702         struct io_uring_rsrc_update2 up2;
703         int ret;
704
705         up2.offset = up->offset;
706         up2.data = up->arg;
707         up2.nr = 0;
708         up2.tags = 0;
709         up2.resv = 0;
710         up2.resv2 = 0;
711
712         if (up->offset == IORING_FILE_INDEX_ALLOC) {
713                 ret = io_files_update_with_index_alloc(req, issue_flags);
714         } else {
715                 io_ring_submit_lock(ctx, issue_flags);
716                 ret = __io_register_rsrc_update(ctx, IORING_RSRC_FILE,
717                                                 &up2, up->nr_args);
718                 io_ring_submit_unlock(ctx, issue_flags);
719         }
720
721         if (ret < 0)
722                 req_set_fail(req);
723         io_req_set_res(req, ret, 0);
724         return IOU_OK;
725 }
726
727 int io_queue_rsrc_removal(struct io_rsrc_data *data, unsigned idx,
728                           struct io_rsrc_node *node, void *rsrc)
729 {
730         u64 *tag_slot = io_get_tag_slot(data, idx);
731         struct io_rsrc_put *prsrc;
732
733         prsrc = kzalloc(sizeof(*prsrc), GFP_KERNEL);
734         if (!prsrc)
735                 return -ENOMEM;
736
737         prsrc->tag = *tag_slot;
738         *tag_slot = 0;
739         prsrc->rsrc = rsrc;
740         list_add(&prsrc->list, &node->rsrc_list);
741         return 0;
742 }
743
744 void __io_sqe_files_unregister(struct io_ring_ctx *ctx)
745 {
746         int i;
747
748         for (i = 0; i < ctx->nr_user_files; i++) {
749                 struct file *file = io_file_from_index(&ctx->file_table, i);
750
751                 /* skip scm accounted files, they'll be freed by ->ring_sock */
752                 if (!file || io_file_need_scm(file))
753                         continue;
754                 io_file_bitmap_clear(&ctx->file_table, i);
755                 fput(file);
756         }
757
758 #if defined(CONFIG_UNIX)
759         if (ctx->ring_sock) {
760                 struct sock *sock = ctx->ring_sock->sk;
761                 struct sk_buff *skb;
762
763                 while ((skb = skb_dequeue(&sock->sk_receive_queue)) != NULL)
764                         kfree_skb(skb);
765         }
766 #endif
767         io_free_file_tables(&ctx->file_table);
768         io_file_table_set_alloc_range(ctx, 0, 0);
769         io_rsrc_data_free(ctx->file_data);
770         ctx->file_data = NULL;
771         ctx->nr_user_files = 0;
772 }
773
774 int io_sqe_files_unregister(struct io_ring_ctx *ctx)
775 {
776         unsigned nr = ctx->nr_user_files;
777         int ret;
778
779         if (!ctx->file_data)
780                 return -ENXIO;
781
782         /*
783          * Quiesce may unlock ->uring_lock, and while it's not held
784          * prevent new requests using the table.
785          */
786         ctx->nr_user_files = 0;
787         ret = io_rsrc_ref_quiesce(ctx->file_data, ctx);
788         ctx->nr_user_files = nr;
789         if (!ret)
790                 __io_sqe_files_unregister(ctx);
791         return ret;
792 }
793
794 /*
795  * Ensure the UNIX gc is aware of our file set, so we are certain that
796  * the io_uring can be safely unregistered on process exit, even if we have
797  * loops in the file referencing. We account only files that can hold other
798  * files because otherwise they can't form a loop and so are not interesting
799  * for GC.
800  */
801 int __io_scm_file_account(struct io_ring_ctx *ctx, struct file *file)
802 {
803 #if defined(CONFIG_UNIX)
804         struct sock *sk = ctx->ring_sock->sk;
805         struct sk_buff_head *head = &sk->sk_receive_queue;
806         struct scm_fp_list *fpl;
807         struct sk_buff *skb;
808
809         if (likely(!io_file_need_scm(file)))
810                 return 0;
811
812         /*
813          * See if we can merge this file into an existing skb SCM_RIGHTS
814          * file set. If there's no room, fall back to allocating a new skb
815          * and filling it in.
816          */
817         spin_lock_irq(&head->lock);
818         skb = skb_peek(head);
819         if (skb && UNIXCB(skb).fp->count < SCM_MAX_FD)
820                 __skb_unlink(skb, head);
821         else
822                 skb = NULL;
823         spin_unlock_irq(&head->lock);
824
825         if (!skb) {
826                 fpl = kzalloc(sizeof(*fpl), GFP_KERNEL);
827                 if (!fpl)
828                         return -ENOMEM;
829
830                 skb = alloc_skb(0, GFP_KERNEL);
831                 if (!skb) {
832                         kfree(fpl);
833                         return -ENOMEM;
834                 }
835
836                 fpl->user = get_uid(current_user());
837                 fpl->max = SCM_MAX_FD;
838                 fpl->count = 0;
839
840                 UNIXCB(skb).fp = fpl;
841                 skb->sk = sk;
842                 skb->scm_io_uring = 1;
843                 skb->destructor = unix_destruct_scm;
844                 refcount_add(skb->truesize, &sk->sk_wmem_alloc);
845         }
846
847         fpl = UNIXCB(skb).fp;
848         fpl->fp[fpl->count++] = get_file(file);
849         unix_inflight(fpl->user, file);
850         skb_queue_head(head, skb);
851         fput(file);
852 #endif
853         return 0;
854 }
855
856 static void io_rsrc_file_put(struct io_ring_ctx *ctx, struct io_rsrc_put *prsrc)
857 {
858         struct file *file = prsrc->file;
859 #if defined(CONFIG_UNIX)
860         struct sock *sock = ctx->ring_sock->sk;
861         struct sk_buff_head list, *head = &sock->sk_receive_queue;
862         struct sk_buff *skb;
863         int i;
864
865         if (!io_file_need_scm(file)) {
866                 fput(file);
867                 return;
868         }
869
870         __skb_queue_head_init(&list);
871
872         /*
873          * Find the skb that holds this file in its SCM_RIGHTS. When found,
874          * remove this entry and rearrange the file array.
875          */
876         skb = skb_dequeue(head);
877         while (skb) {
878                 struct scm_fp_list *fp;
879
880                 fp = UNIXCB(skb).fp;
881                 for (i = 0; i < fp->count; i++) {
882                         int left;
883
884                         if (fp->fp[i] != file)
885                                 continue;
886
887                         unix_notinflight(fp->user, fp->fp[i]);
888                         left = fp->count - 1 - i;
889                         if (left) {
890                                 memmove(&fp->fp[i], &fp->fp[i + 1],
891                                                 left * sizeof(struct file *));
892                         }
893                         fp->count--;
894                         if (!fp->count) {
895                                 kfree_skb(skb);
896                                 skb = NULL;
897                         } else {
898                                 __skb_queue_tail(&list, skb);
899                         }
900                         fput(file);
901                         file = NULL;
902                         break;
903                 }
904
905                 if (!file)
906                         break;
907
908                 __skb_queue_tail(&list, skb);
909
910                 skb = skb_dequeue(head);
911         }
912
913         if (skb_peek(&list)) {
914                 spin_lock_irq(&head->lock);
915                 while ((skb = __skb_dequeue(&list)) != NULL)
916                         __skb_queue_tail(head, skb);
917                 spin_unlock_irq(&head->lock);
918         }
919 #else
920         fput(file);
921 #endif
922 }
923
924 int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
925                           unsigned nr_args, u64 __user *tags)
926 {
927         __s32 __user *fds = (__s32 __user *) arg;
928         struct file *file;
929         int fd, ret;
930         unsigned i;
931
932         if (ctx->file_data)
933                 return -EBUSY;
934         if (!nr_args)
935                 return -EINVAL;
936         if (nr_args > IORING_MAX_FIXED_FILES)
937                 return -EMFILE;
938         if (nr_args > rlimit(RLIMIT_NOFILE))
939                 return -EMFILE;
940         ret = io_rsrc_node_switch_start(ctx);
941         if (ret)
942                 return ret;
943         ret = io_rsrc_data_alloc(ctx, io_rsrc_file_put, tags, nr_args,
944                                  &ctx->file_data);
945         if (ret)
946                 return ret;
947
948         if (!io_alloc_file_tables(&ctx->file_table, nr_args)) {
949                 io_rsrc_data_free(ctx->file_data);
950                 ctx->file_data = NULL;
951                 return -ENOMEM;
952         }
953
954         for (i = 0; i < nr_args; i++, ctx->nr_user_files++) {
955                 struct io_fixed_file *file_slot;
956
957                 if (fds && copy_from_user(&fd, &fds[i], sizeof(fd))) {
958                         ret = -EFAULT;
959                         goto fail;
960                 }
961                 /* allow sparse sets */
962                 if (!fds || fd == -1) {
963                         ret = -EINVAL;
964                         if (unlikely(*io_get_tag_slot(ctx->file_data, i)))
965                                 goto fail;
966                         continue;
967                 }
968
969                 file = fget(fd);
970                 ret = -EBADF;
971                 if (unlikely(!file))
972                         goto fail;
973
974                 /*
975                  * Don't allow io_uring instances to be registered. If UNIX
976                  * isn't enabled, then this causes a reference cycle and this
977                  * instance can never get freed. If UNIX is enabled we'll
978                  * handle it just fine, but there's still no point in allowing
979                  * a ring fd as it doesn't support regular read/write anyway.
980                  */
981                 if (io_is_uring_fops(file)) {
982                         fput(file);
983                         goto fail;
984                 }
985                 ret = io_scm_file_account(ctx, file);
986                 if (ret) {
987                         fput(file);
988                         goto fail;
989                 }
990                 file_slot = io_fixed_file_slot(&ctx->file_table, i);
991                 io_fixed_file_set(file_slot, file);
992                 io_file_bitmap_set(&ctx->file_table, i);
993         }
994
995         /* default it to the whole table */
996         io_file_table_set_alloc_range(ctx, 0, ctx->nr_user_files);
997         io_rsrc_node_switch(ctx, NULL);
998         return 0;
999 fail:
1000         __io_sqe_files_unregister(ctx);
1001         return ret;
1002 }
1003
1004 static void io_rsrc_buf_put(struct io_ring_ctx *ctx, struct io_rsrc_put *prsrc)
1005 {
1006         io_buffer_unmap(ctx, &prsrc->buf);
1007         prsrc->buf = NULL;
1008 }
1009
1010 void __io_sqe_buffers_unregister(struct io_ring_ctx *ctx)
1011 {
1012         unsigned int i;
1013
1014         for (i = 0; i < ctx->nr_user_bufs; i++)
1015                 io_buffer_unmap(ctx, &ctx->user_bufs[i]);
1016         kfree(ctx->user_bufs);
1017         io_rsrc_data_free(ctx->buf_data);
1018         ctx->user_bufs = NULL;
1019         ctx->buf_data = NULL;
1020         ctx->nr_user_bufs = 0;
1021 }
1022
1023 int io_sqe_buffers_unregister(struct io_ring_ctx *ctx)
1024 {
1025         unsigned nr = ctx->nr_user_bufs;
1026         int ret;
1027
1028         if (!ctx->buf_data)
1029                 return -ENXIO;
1030
1031         /*
1032          * Quiesce may unlock ->uring_lock, and while it's not held
1033          * prevent new requests using the table.
1034          */
1035         ctx->nr_user_bufs = 0;
1036         ret = io_rsrc_ref_quiesce(ctx->buf_data, ctx);
1037         ctx->nr_user_bufs = nr;
1038         if (!ret)
1039                 __io_sqe_buffers_unregister(ctx);
1040         return ret;
1041 }
1042
1043 /*
1044  * Not super efficient, but this is just a registration time. And we do cache
1045  * the last compound head, so generally we'll only do a full search if we don't
1046  * match that one.
1047  *
1048  * We check if the given compound head page has already been accounted, to
1049  * avoid double accounting it. This allows us to account the full size of the
1050  * page, not just the constituent pages of a huge page.
1051  */
1052 static bool headpage_already_acct(struct io_ring_ctx *ctx, struct page **pages,
1053                                   int nr_pages, struct page *hpage)
1054 {
1055         int i, j;
1056
1057         /* check current page array */
1058         for (i = 0; i < nr_pages; i++) {
1059                 if (!PageCompound(pages[i]))
1060                         continue;
1061                 if (compound_head(pages[i]) == hpage)
1062                         return true;
1063         }
1064
1065         /* check previously registered pages */
1066         for (i = 0; i < ctx->nr_user_bufs; i++) {
1067                 struct io_mapped_ubuf *imu = ctx->user_bufs[i];
1068
1069                 for (j = 0; j < imu->nr_bvecs; j++) {
1070                         if (!PageCompound(imu->bvec[j].bv_page))
1071                                 continue;
1072                         if (compound_head(imu->bvec[j].bv_page) == hpage)
1073                                 return true;
1074                 }
1075         }
1076
1077         return false;
1078 }
1079
1080 static int io_buffer_account_pin(struct io_ring_ctx *ctx, struct page **pages,
1081                                  int nr_pages, struct io_mapped_ubuf *imu,
1082                                  struct page **last_hpage)
1083 {
1084         int i, ret;
1085
1086         imu->acct_pages = 0;
1087         for (i = 0; i < nr_pages; i++) {
1088                 if (!PageCompound(pages[i])) {
1089                         imu->acct_pages++;
1090                 } else {
1091                         struct page *hpage;
1092
1093                         hpage = compound_head(pages[i]);
1094                         if (hpage == *last_hpage)
1095                                 continue;
1096                         *last_hpage = hpage;
1097                         if (headpage_already_acct(ctx, pages, i, hpage))
1098                                 continue;
1099                         imu->acct_pages += page_size(hpage) >> PAGE_SHIFT;
1100                 }
1101         }
1102
1103         if (!imu->acct_pages)
1104                 return 0;
1105
1106         ret = io_account_mem(ctx, imu->acct_pages);
1107         if (ret)
1108                 imu->acct_pages = 0;
1109         return ret;
1110 }
1111
1112 struct page **io_pin_pages(unsigned long ubuf, unsigned long len, int *npages)
1113 {
1114         unsigned long start, end, nr_pages;
1115         struct vm_area_struct **vmas = NULL;
1116         struct page **pages = NULL;
1117         int i, pret, ret = -ENOMEM;
1118
1119         end = (ubuf + len + PAGE_SIZE - 1) >> PAGE_SHIFT;
1120         start = ubuf >> PAGE_SHIFT;
1121         nr_pages = end - start;
1122
1123         pages = kvmalloc_array(nr_pages, sizeof(struct page *), GFP_KERNEL);
1124         if (!pages)
1125                 goto done;
1126
1127         vmas = kvmalloc_array(nr_pages, sizeof(struct vm_area_struct *),
1128                               GFP_KERNEL);
1129         if (!vmas)
1130                 goto done;
1131
1132         ret = 0;
1133         mmap_read_lock(current->mm);
1134         pret = pin_user_pages(ubuf, nr_pages, FOLL_WRITE | FOLL_LONGTERM,
1135                               pages, vmas);
1136         if (pret == nr_pages) {
1137                 struct file *file = vmas[0]->vm_file;
1138
1139                 /* don't support file backed memory */
1140                 for (i = 0; i < nr_pages; i++) {
1141                         if (vmas[i]->vm_file != file) {
1142                                 ret = -EINVAL;
1143                                 break;
1144                         }
1145                         if (!file)
1146                                 continue;
1147                         if (!vma_is_shmem(vmas[i]) && !is_file_hugepages(file)) {
1148                                 ret = -EOPNOTSUPP;
1149                                 break;
1150                         }
1151                 }
1152                 *npages = nr_pages;
1153         } else {
1154                 ret = pret < 0 ? pret : -EFAULT;
1155         }
1156         mmap_read_unlock(current->mm);
1157         if (ret) {
1158                 /*
1159                  * if we did partial map, or found file backed vmas,
1160                  * release any pages we did get
1161                  */
1162                 if (pret > 0)
1163                         unpin_user_pages(pages, pret);
1164                 goto done;
1165         }
1166         ret = 0;
1167 done:
1168         kvfree(vmas);
1169         if (ret < 0) {
1170                 kvfree(pages);
1171                 pages = ERR_PTR(ret);
1172         }
1173         return pages;
1174 }
1175
1176 static int io_sqe_buffer_register(struct io_ring_ctx *ctx, struct iovec *iov,
1177                                   struct io_mapped_ubuf **pimu,
1178                                   struct page **last_hpage)
1179 {
1180         struct io_mapped_ubuf *imu = NULL;
1181         struct page **pages = NULL;
1182         unsigned long off;
1183         size_t size;
1184         int ret, nr_pages, i;
1185         struct folio *folio = NULL;
1186
1187         *pimu = ctx->dummy_ubuf;
1188         if (!iov->iov_base)
1189                 return 0;
1190
1191         ret = -ENOMEM;
1192         pages = io_pin_pages((unsigned long) iov->iov_base, iov->iov_len,
1193                                 &nr_pages);
1194         if (IS_ERR(pages)) {
1195                 ret = PTR_ERR(pages);
1196                 pages = NULL;
1197                 goto done;
1198         }
1199
1200         /* If it's a huge page, try to coalesce them into a single bvec entry */
1201         if (nr_pages > 1) {
1202                 folio = page_folio(pages[0]);
1203                 for (i = 1; i < nr_pages; i++) {
1204                         if (page_folio(pages[i]) != folio) {
1205                                 folio = NULL;
1206                                 break;
1207                         }
1208                 }
1209                 if (folio) {
1210                         /*
1211                          * The pages are bound to the folio, it doesn't
1212                          * actually unpin them but drops all but one reference,
1213                          * which is usually put down by io_buffer_unmap().
1214                          * Note, needs a better helper.
1215                          */
1216                         unpin_user_pages(&pages[1], nr_pages - 1);
1217                         nr_pages = 1;
1218                 }
1219         }
1220
1221         imu = kvmalloc(struct_size(imu, bvec, nr_pages), GFP_KERNEL);
1222         if (!imu)
1223                 goto done;
1224
1225         ret = io_buffer_account_pin(ctx, pages, nr_pages, imu, last_hpage);
1226         if (ret) {
1227                 unpin_user_pages(pages, nr_pages);
1228                 goto done;
1229         }
1230
1231         off = (unsigned long) iov->iov_base & ~PAGE_MASK;
1232         size = iov->iov_len;
1233         /* store original address for later verification */
1234         imu->ubuf = (unsigned long) iov->iov_base;
1235         imu->ubuf_end = imu->ubuf + iov->iov_len;
1236         imu->nr_bvecs = nr_pages;
1237         *pimu = imu;
1238         ret = 0;
1239
1240         if (folio) {
1241                 bvec_set_page(&imu->bvec[0], pages[0], size, off);
1242                 goto done;
1243         }
1244         for (i = 0; i < nr_pages; i++) {
1245                 size_t vec_len;
1246
1247                 vec_len = min_t(size_t, size, PAGE_SIZE - off);
1248                 bvec_set_page(&imu->bvec[i], pages[i], vec_len, off);
1249                 off = 0;
1250                 size -= vec_len;
1251         }
1252 done:
1253         if (ret)
1254                 kvfree(imu);
1255         kvfree(pages);
1256         return ret;
1257 }
1258
1259 static int io_buffers_map_alloc(struct io_ring_ctx *ctx, unsigned int nr_args)
1260 {
1261         ctx->user_bufs = kcalloc(nr_args, sizeof(*ctx->user_bufs), GFP_KERNEL);
1262         return ctx->user_bufs ? 0 : -ENOMEM;
1263 }
1264
1265 int io_sqe_buffers_register(struct io_ring_ctx *ctx, void __user *arg,
1266                             unsigned int nr_args, u64 __user *tags)
1267 {
1268         struct page *last_hpage = NULL;
1269         struct io_rsrc_data *data;
1270         int i, ret;
1271         struct iovec iov;
1272
1273         BUILD_BUG_ON(IORING_MAX_REG_BUFFERS >= (1u << 16));
1274
1275         if (ctx->user_bufs)
1276                 return -EBUSY;
1277         if (!nr_args || nr_args > IORING_MAX_REG_BUFFERS)
1278                 return -EINVAL;
1279         ret = io_rsrc_node_switch_start(ctx);
1280         if (ret)
1281                 return ret;
1282         ret = io_rsrc_data_alloc(ctx, io_rsrc_buf_put, tags, nr_args, &data);
1283         if (ret)
1284                 return ret;
1285         ret = io_buffers_map_alloc(ctx, nr_args);
1286         if (ret) {
1287                 io_rsrc_data_free(data);
1288                 return ret;
1289         }
1290
1291         for (i = 0; i < nr_args; i++, ctx->nr_user_bufs++) {
1292                 if (arg) {
1293                         ret = io_copy_iov(ctx, &iov, arg, i);
1294                         if (ret)
1295                                 break;
1296                         ret = io_buffer_validate(&iov);
1297                         if (ret)
1298                                 break;
1299                 } else {
1300                         memset(&iov, 0, sizeof(iov));
1301                 }
1302
1303                 if (!iov.iov_base && *io_get_tag_slot(data, i)) {
1304                         ret = -EINVAL;
1305                         break;
1306                 }
1307
1308                 ret = io_sqe_buffer_register(ctx, &iov, &ctx->user_bufs[i],
1309                                              &last_hpage);
1310                 if (ret)
1311                         break;
1312         }
1313
1314         WARN_ON_ONCE(ctx->buf_data);
1315
1316         ctx->buf_data = data;
1317         if (ret)
1318                 __io_sqe_buffers_unregister(ctx);
1319         else
1320                 io_rsrc_node_switch(ctx, NULL);
1321         return ret;
1322 }
1323
1324 int io_import_fixed(int ddir, struct iov_iter *iter,
1325                            struct io_mapped_ubuf *imu,
1326                            u64 buf_addr, size_t len)
1327 {
1328         u64 buf_end;
1329         size_t offset;
1330
1331         if (WARN_ON_ONCE(!imu))
1332                 return -EFAULT;
1333         if (unlikely(check_add_overflow(buf_addr, (u64)len, &buf_end)))
1334                 return -EFAULT;
1335         /* not inside the mapped region */
1336         if (unlikely(buf_addr < imu->ubuf || buf_end > imu->ubuf_end))
1337                 return -EFAULT;
1338
1339         /*
1340          * Might not be a start of buffer, set size appropriately
1341          * and advance us to the beginning.
1342          */
1343         offset = buf_addr - imu->ubuf;
1344         iov_iter_bvec(iter, ddir, imu->bvec, imu->nr_bvecs, offset + len);
1345
1346         if (offset) {
1347                 /*
1348                  * Don't use iov_iter_advance() here, as it's really slow for
1349                  * using the latter parts of a big fixed buffer - it iterates
1350                  * over each segment manually. We can cheat a bit here, because
1351                  * we know that:
1352                  *
1353                  * 1) it's a BVEC iter, we set it up
1354                  * 2) all bvecs are PAGE_SIZE in size, except potentially the
1355                  *    first and last bvec
1356                  *
1357                  * So just find our index, and adjust the iterator afterwards.
1358                  * If the offset is within the first bvec (or the whole first
1359                  * bvec, just use iov_iter_advance(). This makes it easier
1360                  * since we can just skip the first segment, which may not
1361                  * be PAGE_SIZE aligned.
1362                  */
1363                 const struct bio_vec *bvec = imu->bvec;
1364
1365                 if (offset <= bvec->bv_len) {
1366                         /*
1367                          * Note, huge pages buffers consists of one large
1368                          * bvec entry and should always go this way. The other
1369                          * branch doesn't expect non PAGE_SIZE'd chunks.
1370                          */
1371                         iter->bvec = bvec;
1372                         iter->nr_segs = bvec->bv_len;
1373                         iter->count -= offset;
1374                         iter->iov_offset = offset;
1375                 } else {
1376                         unsigned long seg_skip;
1377
1378                         /* skip first vec */
1379                         offset -= bvec->bv_len;
1380                         seg_skip = 1 + (offset >> PAGE_SHIFT);
1381
1382                         iter->bvec = bvec + seg_skip;
1383                         iter->nr_segs -= seg_skip;
1384                         iter->count -= bvec->bv_len + offset;
1385                         iter->iov_offset = offset & ~PAGE_MASK;
1386                 }
1387         }
1388
1389         return 0;
1390 }