drivers: media: arducam_64mp: Add V4L2_CID_LINK_FREQ control
[platform/kernel/linux-rpi.git] / drivers / iommu / iommufd / pages.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * The iopt_pages is the center of the storage and motion of PFNs. Each
5  * iopt_pages represents a logical linear array of full PFNs. The array is 0
6  * based and has npages in it. Accessors use 'index' to refer to the entry in
7  * this logical array, regardless of its storage location.
8  *
9  * PFNs are stored in a tiered scheme:
10  *  1) iopt_pages::pinned_pfns xarray
11  *  2) An iommu_domain
12  *  3) The origin of the PFNs, i.e. the userspace pointer
13  *
14  * PFNs have to be copied between all combinations of tiers, depending on the
15  * configuration.
16  *
17  * When a PFN is taken out of the userspace pointer it is pinned exactly once.
18  * The storage locations of the PFN's index are tracked in the two interval
19  * trees. If no interval includes the index then it is not pinned.
20  *
21  * If access_itree includes the PFN's index then an in-kernel access has
22  * requested the page. The PFN is stored in the xarray so other requestors can
23  * continue to find it.
24  *
25  * If the domains_itree includes the PFN's index then an iommu_domain is storing
26  * the PFN and it can be read back using iommu_iova_to_phys(). To avoid
27  * duplicating storage the xarray is not used if only iommu_domains are using
28  * the PFN's index.
29  *
30  * As a general principle this is designed so that destroy never fails. This
31  * means removing an iommu_domain or releasing an in-kernel access will not fail
32  * due to insufficient memory. In practice this means some cases have to hold
33  * PFNs in the xarray even though they are also being stored in an iommu_domain.
34  *
35  * While the iopt_pages can use an iommu_domain as storage, it does not have an
36  * IOVA itself. Instead the iopt_area represents a range of IOVA and uses the
37  * iopt_pages as the PFN provider. Multiple iopt_areas can share the iopt_pages
38  * and reference their own slice of the PFN array, with sub page granularity.
39  *
40  * In this file the term 'last' indicates an inclusive and closed interval, e.g.
41  * [0,0] refers to a single PFN. 'end' means a half-open range, e.g. [0,0)
42  * refers to no PFNs.
43  *
44  * Be cautious of overflow. An IOVA can go all the way up to U64_MAX, so
45  * last_iova + 1 can overflow. An iopt_pages index will always be much less than
46  * ULONG_MAX so last_index + 1 cannot overflow.
47  */
48 #include <linux/overflow.h>
49 #include <linux/slab.h>
50 #include <linux/iommu.h>
51 #include <linux/sched/mm.h>
52 #include <linux/highmem.h>
53 #include <linux/kthread.h>
54 #include <linux/iommufd.h>
55
56 #include "io_pagetable.h"
57 #include "double_span.h"
58
59 #ifndef CONFIG_IOMMUFD_TEST
60 #define TEMP_MEMORY_LIMIT 65536
61 #else
62 #define TEMP_MEMORY_LIMIT iommufd_test_memory_limit
63 #endif
64 #define BATCH_BACKUP_SIZE 32
65
66 /*
67  * More memory makes pin_user_pages() and the batching more efficient, but as
68  * this is only a performance optimization don't try too hard to get it. A 64k
69  * allocation can hold about 26M of 4k pages and 13G of 2M pages in an
70  * pfn_batch. Various destroy paths cannot fail and provide a small amount of
71  * stack memory as a backup contingency. If backup_len is given this cannot
72  * fail.
73  */
74 static void *temp_kmalloc(size_t *size, void *backup, size_t backup_len)
75 {
76         void *res;
77
78         if (WARN_ON(*size == 0))
79                 return NULL;
80
81         if (*size < backup_len)
82                 return backup;
83
84         if (!backup && iommufd_should_fail())
85                 return NULL;
86
87         *size = min_t(size_t, *size, TEMP_MEMORY_LIMIT);
88         res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
89         if (res)
90                 return res;
91         *size = PAGE_SIZE;
92         if (backup_len) {
93                 res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
94                 if (res)
95                         return res;
96                 *size = backup_len;
97                 return backup;
98         }
99         return kmalloc(*size, GFP_KERNEL);
100 }
101
102 void interval_tree_double_span_iter_update(
103         struct interval_tree_double_span_iter *iter)
104 {
105         unsigned long last_hole = ULONG_MAX;
106         unsigned int i;
107
108         for (i = 0; i != ARRAY_SIZE(iter->spans); i++) {
109                 if (interval_tree_span_iter_done(&iter->spans[i])) {
110                         iter->is_used = -1;
111                         return;
112                 }
113
114                 if (iter->spans[i].is_hole) {
115                         last_hole = min(last_hole, iter->spans[i].last_hole);
116                         continue;
117                 }
118
119                 iter->is_used = i + 1;
120                 iter->start_used = iter->spans[i].start_used;
121                 iter->last_used = min(iter->spans[i].last_used, last_hole);
122                 return;
123         }
124
125         iter->is_used = 0;
126         iter->start_hole = iter->spans[0].start_hole;
127         iter->last_hole =
128                 min(iter->spans[0].last_hole, iter->spans[1].last_hole);
129 }
130
131 void interval_tree_double_span_iter_first(
132         struct interval_tree_double_span_iter *iter,
133         struct rb_root_cached *itree1, struct rb_root_cached *itree2,
134         unsigned long first_index, unsigned long last_index)
135 {
136         unsigned int i;
137
138         iter->itrees[0] = itree1;
139         iter->itrees[1] = itree2;
140         for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
141                 interval_tree_span_iter_first(&iter->spans[i], iter->itrees[i],
142                                               first_index, last_index);
143         interval_tree_double_span_iter_update(iter);
144 }
145
146 void interval_tree_double_span_iter_next(
147         struct interval_tree_double_span_iter *iter)
148 {
149         unsigned int i;
150
151         if (iter->is_used == -1 ||
152             iter->last_hole == iter->spans[0].last_index) {
153                 iter->is_used = -1;
154                 return;
155         }
156
157         for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
158                 interval_tree_span_iter_advance(
159                         &iter->spans[i], iter->itrees[i], iter->last_hole + 1);
160         interval_tree_double_span_iter_update(iter);
161 }
162
163 static void iopt_pages_add_npinned(struct iopt_pages *pages, size_t npages)
164 {
165         int rc;
166
167         rc = check_add_overflow(pages->npinned, npages, &pages->npinned);
168         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
169                 WARN_ON(rc || pages->npinned > pages->npages);
170 }
171
172 static void iopt_pages_sub_npinned(struct iopt_pages *pages, size_t npages)
173 {
174         int rc;
175
176         rc = check_sub_overflow(pages->npinned, npages, &pages->npinned);
177         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
178                 WARN_ON(rc || pages->npinned > pages->npages);
179 }
180
/*
 * Unpin the pages in @page_list that correspond to the inclusive index range
 * [start_index, last_index], and remove them from the pinned count.
 */
static void iopt_pages_err_unpin(struct iopt_pages *pages,
				 unsigned long start_index,
				 unsigned long last_index,
				 struct page **page_list)
{
	const unsigned long count = last_index - start_index + 1;

	unpin_user_pages(page_list, count);
	iopt_pages_sub_npinned(pages, count);
}
191
192 /*
193  * index is the number of PAGE_SIZE units from the start of the area's
194  * iopt_pages. If the iova is sub page-size then the area has an iova that
195  * covers a portion of the first and last pages in the range.
196  */
197 static unsigned long iopt_area_index_to_iova(struct iopt_area *area,
198                                              unsigned long index)
199 {
200         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
201                 WARN_ON(index < iopt_area_index(area) ||
202                         index > iopt_area_last_index(area));
203         index -= iopt_area_index(area);
204         if (index == 0)
205                 return iopt_area_iova(area);
206         return iopt_area_iova(area) - area->page_offset + index * PAGE_SIZE;
207 }
208
209 static unsigned long iopt_area_index_to_iova_last(struct iopt_area *area,
210                                                   unsigned long index)
211 {
212         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
213                 WARN_ON(index < iopt_area_index(area) ||
214                         index > iopt_area_last_index(area));
215         if (index == iopt_area_last_index(area))
216                 return iopt_area_last_iova(area);
217         return iopt_area_iova(area) - area->page_offset +
218                (index - iopt_area_index(area) + 1) * PAGE_SIZE - 1;
219 }
220
/*
 * Unmap exactly @size bytes at @iova from @domain. There is no error return.
 * An unmap that does not cover exactly the requested size is reported with
 * WARN_ON().
 */
static void iommu_unmap_nofail(struct iommu_domain *domain, unsigned long iova,
			       size_t size)
{
	/*
	 * It is a logic error in this code or a driver bug if the IOMMU unmaps
	 * something other than exactly as requested. This implies that the
	 * iommu driver may not fail unmap for reasons beyond bad arguments.
	 * Particularly, the iommu driver may not do a memory allocation on the
	 * unmap path.
	 */
	WARN_ON(iommu_unmap(domain, iova, size) != size);
}
236
/*
 * Unmap from @domain the IOVA range that @area uses for the inclusive index
 * range [start_index, last_index], including any partial first/last page.
 */
static void iopt_area_unmap_domain_range(struct iopt_area *area,
					 struct iommu_domain *domain,
					 unsigned long start_index,
					 unsigned long last_index)
{
	unsigned long first_iova = iopt_area_index_to_iova(area, start_index);
	unsigned long final_iova =
		iopt_area_index_to_iova_last(area, last_index);

	iommu_unmap_nofail(domain, first_iova, final_iova - first_iova + 1);
}
248
249 static struct iopt_area *iopt_pages_find_domain_area(struct iopt_pages *pages,
250                                                      unsigned long index)
251 {
252         struct interval_tree_node *node;
253
254         node = interval_tree_iter_first(&pages->domains_itree, index, index);
255         if (!node)
256                 return NULL;
257         return container_of(node, struct iopt_area, pages_node);
258 }
259
260 /*
261  * A simple datastructure to hold a vector of PFNs, optimized for contiguous
262  * PFNs. This is used as a temporary holding memory for shuttling pfns from one
263  * place to another. Generally everything is made more efficient if operations
264  * work on the largest possible grouping of pfns. eg fewer lock/unlock cycles,
265  * better cache locality, etc
266  */
267 struct pfn_batch {
268         unsigned long *pfns;
269         u32 *npfns;
270         unsigned int array_size;
271         unsigned int end;
272         unsigned int total_pfns;
273 };
274
275 static void batch_clear(struct pfn_batch *batch)
276 {
277         batch->total_pfns = 0;
278         batch->end = 0;
279         batch->pfns[0] = 0;
280         batch->npfns[0] = 0;
281 }
282
283 /*
284  * Carry means we carry a portion of the final hugepage over to the front of the
285  * batch
286  */
287 static void batch_clear_carry(struct pfn_batch *batch, unsigned int keep_pfns)
288 {
289         if (!keep_pfns)
290                 return batch_clear(batch);
291
292         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
293                 WARN_ON(!batch->end ||
294                         batch->npfns[batch->end - 1] < keep_pfns);
295
296         batch->total_pfns = keep_pfns;
297         batch->pfns[0] = batch->pfns[batch->end - 1] +
298                          (batch->npfns[batch->end - 1] - keep_pfns);
299         batch->npfns[0] = keep_pfns;
300         batch->end = 1;
301 }
302
303 static void batch_skip_carry(struct pfn_batch *batch, unsigned int skip_pfns)
304 {
305         if (!batch->total_pfns)
306                 return;
307         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
308                 WARN_ON(batch->total_pfns != batch->npfns[0]);
309         skip_pfns = min(batch->total_pfns, skip_pfns);
310         batch->pfns[0] += skip_pfns;
311         batch->npfns[0] -= skip_pfns;
312         batch->total_pfns -= skip_pfns;
313 }
314
315 static int __batch_init(struct pfn_batch *batch, size_t max_pages, void *backup,
316                         size_t backup_len)
317 {
318         const size_t elmsz = sizeof(*batch->pfns) + sizeof(*batch->npfns);
319         size_t size = max_pages * elmsz;
320
321         batch->pfns = temp_kmalloc(&size, backup, backup_len);
322         if (!batch->pfns)
323                 return -ENOMEM;
324         if (IS_ENABLED(CONFIG_IOMMUFD_TEST) && WARN_ON(size < elmsz))
325                 return -EINVAL;
326         batch->array_size = size / elmsz;
327         batch->npfns = (u32 *)(batch->pfns + batch->array_size);
328         batch_clear(batch);
329         return 0;
330 }
331
/* Heap-only batch setup with no backup buffer; may fail with -ENOMEM */
static int batch_init(struct pfn_batch *batch, size_t max_pages)
{
	const size_t no_backup_len = 0;

	return __batch_init(batch, max_pages, NULL, no_backup_len);
}
336
/*
 * Batch setup with a caller-supplied fallback buffer. Per temp_kmalloc(),
 * allocation cannot fail when @backup_len is non-zero. The __batch_init()
 * result is deliberately discarded.
 */
static void batch_init_backup(struct pfn_batch *batch, size_t max_pages,
			      void *backup, size_t backup_len)
{
	(void)__batch_init(batch, max_pages, backup, backup_len);
}
342
343 static void batch_destroy(struct pfn_batch *batch, void *backup)
344 {
345         if (batch->pfns != backup)
346                 kfree(batch->pfns);
347 }
348
349 /* true if the pfn was added, false otherwise */
350 static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
351 {
352         const unsigned int MAX_NPFNS = type_max(typeof(*batch->npfns));
353
354         if (batch->end &&
355             pfn == batch->pfns[batch->end - 1] + batch->npfns[batch->end - 1] &&
356             batch->npfns[batch->end - 1] != MAX_NPFNS) {
357                 batch->npfns[batch->end - 1]++;
358                 batch->total_pfns++;
359                 return true;
360         }
361         if (batch->end == batch->array_size)
362                 return false;
363         batch->total_pfns++;
364         batch->pfns[batch->end] = pfn;
365         batch->npfns[batch->end] = 1;
366         batch->end++;
367         return true;
368 }
369
370 /*
371  * Fill the batch with pfns from the domain. When the batch is full, or it
372  * reaches last_index, the function will return. The caller should use
373  * batch->total_pfns to determine the starting point for the next iteration.
374  */
375 static void batch_from_domain(struct pfn_batch *batch,
376                               struct iommu_domain *domain,
377                               struct iopt_area *area, unsigned long start_index,
378                               unsigned long last_index)
379 {
380         unsigned int page_offset = 0;
381         unsigned long iova;
382         phys_addr_t phys;
383
384         iova = iopt_area_index_to_iova(area, start_index);
385         if (start_index == iopt_area_index(area))
386                 page_offset = area->page_offset;
387         while (start_index <= last_index) {
388                 /*
389                  * This is pretty slow, it would be nice to get the page size
390                  * back from the driver, or have the driver directly fill the
391                  * batch.
392                  */
393                 phys = iommu_iova_to_phys(domain, iova) - page_offset;
394                 if (!batch_add_pfn(batch, PHYS_PFN(phys)))
395                         return;
396                 iova += PAGE_SIZE - page_offset;
397                 page_offset = 0;
398                 start_index++;
399         }
400 }
401
402 static struct page **raw_pages_from_domain(struct iommu_domain *domain,
403                                            struct iopt_area *area,
404                                            unsigned long start_index,
405                                            unsigned long last_index,
406                                            struct page **out_pages)
407 {
408         unsigned int page_offset = 0;
409         unsigned long iova;
410         phys_addr_t phys;
411
412         iova = iopt_area_index_to_iova(area, start_index);
413         if (start_index == iopt_area_index(area))
414                 page_offset = area->page_offset;
415         while (start_index <= last_index) {
416                 phys = iommu_iova_to_phys(domain, iova) - page_offset;
417                 *(out_pages++) = pfn_to_page(PHYS_PFN(phys));
418                 iova += PAGE_SIZE - page_offset;
419                 page_offset = 0;
420                 start_index++;
421         }
422         return out_pages;
423 }
424
425 /* Continues reading a domain until we reach a discontinuity in the pfns. */
426 static void batch_from_domain_continue(struct pfn_batch *batch,
427                                        struct iommu_domain *domain,
428                                        struct iopt_area *area,
429                                        unsigned long start_index,
430                                        unsigned long last_index)
431 {
432         unsigned int array_size = batch->array_size;
433
434         batch->array_size = batch->end;
435         batch_from_domain(batch, domain, area, start_index, last_index);
436         batch->array_size = array_size;
437 }
438
439 /*
440  * This is part of the VFIO compatibility support for VFIO_TYPE1_IOMMU. That
441  * mode permits splitting a mapped area up, and then one of the splits is
442  * unmapped. Doing this normally would cause us to violate our invariant of
443  * pairing map/unmap. Thus, to support old VFIO compatibility disable support
444  * for batching consecutive PFNs. All PFNs mapped into the iommu are done in
445  * PAGE_SIZE units, not larger or smaller.
446  */
447 static int batch_iommu_map_small(struct iommu_domain *domain,
448                                  unsigned long iova, phys_addr_t paddr,
449                                  size_t size, int prot)
450 {
451         unsigned long start_iova = iova;
452         int rc;
453
454         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
455                 WARN_ON(paddr % PAGE_SIZE || iova % PAGE_SIZE ||
456                         size % PAGE_SIZE);
457
458         while (size) {
459                 rc = iommu_map(domain, iova, paddr, PAGE_SIZE, prot,
460                                GFP_KERNEL_ACCOUNT);
461                 if (rc)
462                         goto err_unmap;
463                 iova += PAGE_SIZE;
464                 paddr += PAGE_SIZE;
465                 size -= PAGE_SIZE;
466         }
467         return 0;
468
469 err_unmap:
470         if (start_iova != iova)
471                 iommu_unmap_nofail(domain, start_iova, iova - start_iova);
472         return rc;
473 }
474
475 static int batch_to_domain(struct pfn_batch *batch, struct iommu_domain *domain,
476                            struct iopt_area *area, unsigned long start_index)
477 {
478         bool disable_large_pages = area->iopt->disable_large_pages;
479         unsigned long last_iova = iopt_area_last_iova(area);
480         unsigned int page_offset = 0;
481         unsigned long start_iova;
482         unsigned long next_iova;
483         unsigned int cur = 0;
484         unsigned long iova;
485         int rc;
486
487         /* The first index might be a partial page */
488         if (start_index == iopt_area_index(area))
489                 page_offset = area->page_offset;
490         next_iova = iova = start_iova =
491                 iopt_area_index_to_iova(area, start_index);
492         while (cur < batch->end) {
493                 next_iova = min(last_iova + 1,
494                                 next_iova + batch->npfns[cur] * PAGE_SIZE -
495                                         page_offset);
496                 if (disable_large_pages)
497                         rc = batch_iommu_map_small(
498                                 domain, iova,
499                                 PFN_PHYS(batch->pfns[cur]) + page_offset,
500                                 next_iova - iova, area->iommu_prot);
501                 else
502                         rc = iommu_map(domain, iova,
503                                        PFN_PHYS(batch->pfns[cur]) + page_offset,
504                                        next_iova - iova, area->iommu_prot,
505                                        GFP_KERNEL_ACCOUNT);
506                 if (rc)
507                         goto err_unmap;
508                 iova = next_iova;
509                 page_offset = 0;
510                 cur++;
511         }
512         return 0;
513 err_unmap:
514         if (start_iova != iova)
515                 iommu_unmap_nofail(domain, start_iova, iova - start_iova);
516         return rc;
517 }
518
519 static void batch_from_xarray(struct pfn_batch *batch, struct xarray *xa,
520                               unsigned long start_index,
521                               unsigned long last_index)
522 {
523         XA_STATE(xas, xa, start_index);
524         void *entry;
525
526         rcu_read_lock();
527         while (true) {
528                 entry = xas_next(&xas);
529                 if (xas_retry(&xas, entry))
530                         continue;
531                 WARN_ON(!xa_is_value(entry));
532                 if (!batch_add_pfn(batch, xa_to_value(entry)) ||
533                     start_index == last_index)
534                         break;
535                 start_index++;
536         }
537         rcu_read_unlock();
538 }
539
540 static void batch_from_xarray_clear(struct pfn_batch *batch, struct xarray *xa,
541                                     unsigned long start_index,
542                                     unsigned long last_index)
543 {
544         XA_STATE(xas, xa, start_index);
545         void *entry;
546
547         xas_lock(&xas);
548         while (true) {
549                 entry = xas_next(&xas);
550                 if (xas_retry(&xas, entry))
551                         continue;
552                 WARN_ON(!xa_is_value(entry));
553                 if (!batch_add_pfn(batch, xa_to_value(entry)))
554                         break;
555                 xas_store(&xas, NULL);
556                 if (start_index == last_index)
557                         break;
558                 start_index++;
559         }
560         xas_unlock(&xas);
561 }
562
563 static void clear_xarray(struct xarray *xa, unsigned long start_index,
564                          unsigned long last_index)
565 {
566         XA_STATE(xas, xa, start_index);
567         void *entry;
568
569         xas_lock(&xas);
570         xas_for_each(&xas, entry, last_index)
571                 xas_store(&xas, NULL);
572         xas_unlock(&xas);
573 }
574
575 static int pages_to_xarray(struct xarray *xa, unsigned long start_index,
576                            unsigned long last_index, struct page **pages)
577 {
578         struct page **end_pages = pages + (last_index - start_index) + 1;
579         struct page **half_pages = pages + (end_pages - pages) / 2;
580         XA_STATE(xas, xa, start_index);
581
582         do {
583                 void *old;
584
585                 xas_lock(&xas);
586                 while (pages != end_pages) {
587                         /* xarray does not participate in fault injection */
588                         if (pages == half_pages && iommufd_should_fail()) {
589                                 xas_set_err(&xas, -EINVAL);
590                                 xas_unlock(&xas);
591                                 /* aka xas_destroy() */
592                                 xas_nomem(&xas, GFP_KERNEL);
593                                 goto err_clear;
594                         }
595
596                         old = xas_store(&xas, xa_mk_value(page_to_pfn(*pages)));
597                         if (xas_error(&xas))
598                                 break;
599                         WARN_ON(old);
600                         pages++;
601                         xas_next(&xas);
602                 }
603                 xas_unlock(&xas);
604         } while (xas_nomem(&xas, GFP_KERNEL));
605
606 err_clear:
607         if (xas_error(&xas)) {
608                 if (xas.xa_index != start_index)
609                         clear_xarray(xa, start_index, xas.xa_index - 1);
610                 return xas_error(&xas);
611         }
612         return 0;
613 }
614
/* Append the PFN of each of the @npages pages to @batch until it fills up */
static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
			     size_t npages)
{
	size_t i;

	for (i = 0; i < npages; i++) {
		if (!batch_add_pfn(batch, page_to_pfn(pages[i])))
			return;
	}
}
624
625 static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
626                         unsigned int first_page_off, size_t npages)
627 {
628         unsigned int cur = 0;
629
630         while (first_page_off) {
631                 if (batch->npfns[cur] > first_page_off)
632                         break;
633                 first_page_off -= batch->npfns[cur];
634                 cur++;
635         }
636
637         while (npages) {
638                 size_t to_unpin = min_t(size_t, npages,
639                                         batch->npfns[cur] - first_page_off);
640
641                 unpin_user_page_range_dirty_lock(
642                         pfn_to_page(batch->pfns[cur] + first_page_off),
643                         to_unpin, pages->writable);
644                 iopt_pages_sub_npinned(pages, to_unpin);
645                 cur++;
646                 first_page_off = 0;
647                 npages -= to_unpin;
648         }
649 }
650
651 static void copy_data_page(struct page *page, void *data, unsigned long offset,
652                            size_t length, unsigned int flags)
653 {
654         void *mem;
655
656         mem = kmap_local_page(page);
657         if (flags & IOMMUFD_ACCESS_RW_WRITE) {
658                 memcpy(mem + offset, data, length);
659                 set_page_dirty_lock(page);
660         } else {
661                 memcpy(data, mem + offset, length);
662         }
663         kunmap_local(mem);
664 }
665
666 static unsigned long batch_rw(struct pfn_batch *batch, void *data,
667                               unsigned long offset, unsigned long length,
668                               unsigned int flags)
669 {
670         unsigned long copied = 0;
671         unsigned int npage = 0;
672         unsigned int cur = 0;
673
674         while (cur < batch->end) {
675                 unsigned long bytes = min(length, PAGE_SIZE - offset);
676
677                 copy_data_page(pfn_to_page(batch->pfns[cur] + npage), data,
678                                offset, bytes, flags);
679                 offset = 0;
680                 length -= bytes;
681                 data += bytes;
682                 copied += bytes;
683                 npage++;
684                 if (npage == batch->npfns[cur]) {
685                         npage = 0;
686                         cur++;
687                 }
688                 if (!length)
689                         break;
690         }
691         return copied;
692 }
693
/*
 * pfn_reader_user is just the pin_user_pages() path: the state used to pin
 * pages directly from an iopt_pages' userspace pointer.
 */
struct pfn_reader_user {
	/* Buffer that receives the struct page pointers of pinned pages */
	struct page **upages;
	/* Size of upages, in bytes */
	size_t upages_len;
	/* Index range [upages_start, upages_end) covered by the last pin */
	unsigned long upages_start;
	unsigned long upages_end;
	/* FOLL_* flags passed to the pin_user_pages*() calls */
	unsigned int gup_flags;
	/*
	 * 1 means mmget() and mmap_read_lock(), 0 means only mmget(), -1 is
	 * neither
	 */
	int locked;
};
707
708 static void pfn_reader_user_init(struct pfn_reader_user *user,
709                                  struct iopt_pages *pages)
710 {
711         user->upages = NULL;
712         user->upages_start = 0;
713         user->upages_end = 0;
714         user->locked = -1;
715
716         user->gup_flags = FOLL_LONGTERM;
717         if (pages->writable)
718                 user->gup_flags |= FOLL_WRITE;
719 }
720
721 static void pfn_reader_user_destroy(struct pfn_reader_user *user,
722                                     struct iopt_pages *pages)
723 {
724         if (user->locked != -1) {
725                 if (user->locked)
726                         mmap_read_unlock(pages->source_mm);
727                 if (pages->source_mm != current->mm)
728                         mmput(pages->source_mm);
729                 user->locked = -1;
730         }
731
732         kfree(user->upages);
733         user->upages = NULL;
734 }
735
736 static int pfn_reader_user_pin(struct pfn_reader_user *user,
737                                struct iopt_pages *pages,
738                                unsigned long start_index,
739                                unsigned long last_index)
740 {
741         bool remote_mm = pages->source_mm != current->mm;
742         unsigned long npages;
743         uintptr_t uptr;
744         long rc;
745
746         if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
747             WARN_ON(last_index < start_index))
748                 return -EINVAL;
749
750         if (!user->upages) {
751                 /* All undone in pfn_reader_destroy() */
752                 user->upages_len =
753                         (last_index - start_index + 1) * sizeof(*user->upages);
754                 user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
755                 if (!user->upages)
756                         return -ENOMEM;
757         }
758
759         if (user->locked == -1) {
760                 /*
761                  * The majority of usages will run the map task within the mm
762                  * providing the pages, so we can optimize into
763                  * get_user_pages_fast()
764                  */
765                 if (remote_mm) {
766                         if (!mmget_not_zero(pages->source_mm))
767                                 return -EFAULT;
768                 }
769                 user->locked = 0;
770         }
771
772         npages = min_t(unsigned long, last_index - start_index + 1,
773                        user->upages_len / sizeof(*user->upages));
774
775
776         if (iommufd_should_fail())
777                 return -EFAULT;
778
779         uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
780         if (!remote_mm)
781                 rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
782                                          user->upages);
783         else {
784                 if (!user->locked) {
785                         mmap_read_lock(pages->source_mm);
786                         user->locked = 1;
787                 }
788                 rc = pin_user_pages_remote(pages->source_mm, uptr, npages,
789                                            user->gup_flags, user->upages,
790                                            &user->locked);
791         }
792         if (rc <= 0) {
793                 if (WARN_ON(!rc))
794                         return -EFAULT;
795                 return rc;
796         }
797         iopt_pages_add_npinned(pages, rc);
798         user->upages_start = start_index;
799         user->upages_end = start_index + rc;
800         return 0;
801 }
802
803 /* This is the "modern" and faster accounting method used by io_uring */
804 static int incr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
805 {
806         unsigned long lock_limit;
807         unsigned long cur_pages;
808         unsigned long new_pages;
809
810         lock_limit = task_rlimit(pages->source_task, RLIMIT_MEMLOCK) >>
811                      PAGE_SHIFT;
812         do {
813                 cur_pages = atomic_long_read(&pages->source_user->locked_vm);
814                 new_pages = cur_pages + npages;
815                 if (new_pages > lock_limit)
816                         return -ENOMEM;
817         } while (atomic_long_cmpxchg(&pages->source_user->locked_vm, cur_pages,
818                                      new_pages) != cur_pages);
819         return 0;
820 }
821
822 static void decr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
823 {
824         if (WARN_ON(atomic_long_read(&pages->source_user->locked_vm) < npages))
825                 return;
826         atomic_long_sub(npages, &pages->source_user->locked_vm);
827 }
828
829 /* This is the accounting method used for compatibility with VFIO */
830 static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
831                                bool inc, struct pfn_reader_user *user)
832 {
833         bool do_put = false;
834         int rc;
835
836         if (user && user->locked) {
837                 mmap_read_unlock(pages->source_mm);
838                 user->locked = 0;
839                 /* If we had the lock then we also have a get */
840         } else if ((!user || !user->upages) &&
841                    pages->source_mm != current->mm) {
842                 if (!mmget_not_zero(pages->source_mm))
843                         return -EINVAL;
844                 do_put = true;
845         }
846
847         mmap_write_lock(pages->source_mm);
848         rc = __account_locked_vm(pages->source_mm, npages, inc,
849                                  pages->source_task, false);
850         mmap_write_unlock(pages->source_mm);
851
852         if (do_put)
853                 mmput(pages->source_mm);
854         return rc;
855 }
856
857 static int do_update_pinned(struct iopt_pages *pages, unsigned long npages,
858                             bool inc, struct pfn_reader_user *user)
859 {
860         int rc = 0;
861
862         switch (pages->account_mode) {
863         case IOPT_PAGES_ACCOUNT_NONE:
864                 break;
865         case IOPT_PAGES_ACCOUNT_USER:
866                 if (inc)
867                         rc = incr_user_locked_vm(pages, npages);
868                 else
869                         decr_user_locked_vm(pages, npages);
870                 break;
871         case IOPT_PAGES_ACCOUNT_MM:
872                 rc = update_mm_locked_vm(pages, npages, inc, user);
873                 break;
874         }
875         if (rc)
876                 return rc;
877
878         pages->last_npinned = pages->npinned;
879         if (inc)
880                 atomic64_add(npages, &pages->source_mm->pinned_vm);
881         else
882                 atomic64_sub(npages, &pages->source_mm->pinned_vm);
883         return 0;
884 }
885
886 static void update_unpinned(struct iopt_pages *pages)
887 {
888         if (WARN_ON(pages->npinned > pages->last_npinned))
889                 return;
890         if (pages->npinned == pages->last_npinned)
891                 return;
892         do_update_pinned(pages, pages->last_npinned - pages->npinned, false,
893                          NULL);
894 }
895
896 /*
897  * Changes in the number of pages pinned is done after the pages have been read
898  * and processed. If the user lacked the limit then the error unwind will unpin
899  * everything that was just pinned. This is because it is expensive to calculate
900  * how many pages we have already pinned within a range to generate an accurate
901  * prediction in advance of doing the work to actually pin them.
902  */
903 static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
904                                          struct iopt_pages *pages)
905 {
906         unsigned long npages;
907         bool inc;
908
909         lockdep_assert_held(&pages->mutex);
910
911         if (pages->npinned == pages->last_npinned)
912                 return 0;
913
914         if (pages->npinned < pages->last_npinned) {
915                 npages = pages->last_npinned - pages->npinned;
916                 inc = false;
917         } else {
918                 if (iommufd_should_fail())
919                         return -ENOMEM;
920                 npages = pages->npinned - pages->last_npinned;
921                 inc = true;
922         }
923         return do_update_pinned(pages, npages, inc, user);
924 }
925
926 /*
927  * PFNs are stored in three places, in order of preference:
928  * - The iopt_pages xarray. This is only populated if there is a
929  *   iopt_pages_access
930  * - The iommu_domain under an area
931  * - The original PFN source, ie pages->source_mm
932  *
933  * This iterator reads the pfns optimizing to load according to the
934  * above order.
935  */
936 struct pfn_reader {
937         struct iopt_pages *pages;
938         struct interval_tree_double_span_iter span;
939         struct pfn_batch batch;
940         unsigned long batch_start_index;
941         unsigned long batch_end_index;
942         unsigned long last_index;
943
944         struct pfn_reader_user user;
945 };
946
947 static int pfn_reader_update_pinned(struct pfn_reader *pfns)
948 {
949         return pfn_reader_user_update_pinned(&pfns->user, pfns->pages);
950 }
951
952 /*
953  * The batch can contain a mixture of pages that are still in use and pages that
954  * need to be unpinned. Unpin only pages that are not held anywhere else.
955  */
956 static void pfn_reader_unpin(struct pfn_reader *pfns)
957 {
958         unsigned long last = pfns->batch_end_index - 1;
959         unsigned long start = pfns->batch_start_index;
960         struct interval_tree_double_span_iter span;
961         struct iopt_pages *pages = pfns->pages;
962
963         lockdep_assert_held(&pages->mutex);
964
965         interval_tree_for_each_double_span(&span, &pages->access_itree,
966                                            &pages->domains_itree, start, last) {
967                 if (span.is_used)
968                         continue;
969
970                 batch_unpin(&pfns->batch, pages, span.start_hole - start,
971                             span.last_hole - span.start_hole + 1);
972         }
973 }
974
975 /* Process a single span to load it from the proper storage */
976 static int pfn_reader_fill_span(struct pfn_reader *pfns)
977 {
978         struct interval_tree_double_span_iter *span = &pfns->span;
979         unsigned long start_index = pfns->batch_end_index;
980         struct iopt_area *area;
981         int rc;
982
983         if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
984             WARN_ON(span->last_used < start_index))
985                 return -EINVAL;
986
987         if (span->is_used == 1) {
988                 batch_from_xarray(&pfns->batch, &pfns->pages->pinned_pfns,
989                                   start_index, span->last_used);
990                 return 0;
991         }
992
993         if (span->is_used == 2) {
994                 /*
995                  * Pull as many pages from the first domain we find in the
996                  * target span. If it is too small then we will be called again
997                  * and we'll find another area.
998                  */
999                 area = iopt_pages_find_domain_area(pfns->pages, start_index);
1000                 if (WARN_ON(!area))
1001                         return -EINVAL;
1002
1003                 /* The storage_domain cannot change without the pages mutex */
1004                 batch_from_domain(
1005                         &pfns->batch, area->storage_domain, area, start_index,
1006                         min(iopt_area_last_index(area), span->last_used));
1007                 return 0;
1008         }
1009
1010         if (start_index >= pfns->user.upages_end) {
1011                 rc = pfn_reader_user_pin(&pfns->user, pfns->pages, start_index,
1012                                          span->last_hole);
1013                 if (rc)
1014                         return rc;
1015         }
1016
1017         batch_from_pages(&pfns->batch,
1018                          pfns->user.upages +
1019                                  (start_index - pfns->user.upages_start),
1020                          pfns->user.upages_end - start_index);
1021         return 0;
1022 }
1023
1024 static bool pfn_reader_done(struct pfn_reader *pfns)
1025 {
1026         return pfns->batch_start_index == pfns->last_index + 1;
1027 }
1028
1029 static int pfn_reader_next(struct pfn_reader *pfns)
1030 {
1031         int rc;
1032
1033         batch_clear(&pfns->batch);
1034         pfns->batch_start_index = pfns->batch_end_index;
1035
1036         while (pfns->batch_end_index != pfns->last_index + 1) {
1037                 unsigned int npfns = pfns->batch.total_pfns;
1038
1039                 if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1040                     WARN_ON(interval_tree_double_span_iter_done(&pfns->span)))
1041                         return -EINVAL;
1042
1043                 rc = pfn_reader_fill_span(pfns);
1044                 if (rc)
1045                         return rc;
1046
1047                 if (WARN_ON(!pfns->batch.total_pfns))
1048                         return -EINVAL;
1049
1050                 pfns->batch_end_index =
1051                         pfns->batch_start_index + pfns->batch.total_pfns;
1052                 if (pfns->batch_end_index == pfns->span.last_used + 1)
1053                         interval_tree_double_span_iter_next(&pfns->span);
1054
1055                 /* Batch is full */
1056                 if (npfns == pfns->batch.total_pfns)
1057                         return 0;
1058         }
1059         return 0;
1060 }
1061
1062 static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
1063                            unsigned long start_index, unsigned long last_index)
1064 {
1065         int rc;
1066
1067         lockdep_assert_held(&pages->mutex);
1068
1069         pfns->pages = pages;
1070         pfns->batch_start_index = start_index;
1071         pfns->batch_end_index = start_index;
1072         pfns->last_index = last_index;
1073         pfn_reader_user_init(&pfns->user, pages);
1074         rc = batch_init(&pfns->batch, last_index - start_index + 1);
1075         if (rc)
1076                 return rc;
1077         interval_tree_double_span_iter_first(&pfns->span, &pages->access_itree,
1078                                              &pages->domains_itree, start_index,
1079                                              last_index);
1080         return 0;
1081 }
1082
1083 /*
1084  * There are many assertions regarding the state of pages->npinned vs
1085  * pages->last_pinned, for instance something like unmapping a domain must only
1086  * decrement the npinned, and pfn_reader_destroy() must be called only after all
1087  * the pins are updated. This is fine for success flows, but error flows
1088  * sometimes need to release the pins held inside the pfn_reader before going on
1089  * to complete unmapping and releasing pins held in domains.
1090  */
1091 static void pfn_reader_release_pins(struct pfn_reader *pfns)
1092 {
1093         struct iopt_pages *pages = pfns->pages;
1094
1095         if (pfns->user.upages_end > pfns->batch_end_index) {
1096                 size_t npages = pfns->user.upages_end - pfns->batch_end_index;
1097
1098                 /* Any pages not transferred to the batch are just unpinned */
1099                 unpin_user_pages(pfns->user.upages + (pfns->batch_end_index -
1100                                                       pfns->user.upages_start),
1101                                  npages);
1102                 iopt_pages_sub_npinned(pages, npages);
1103                 pfns->user.upages_end = pfns->batch_end_index;
1104         }
1105         if (pfns->batch_start_index != pfns->batch_end_index) {
1106                 pfn_reader_unpin(pfns);
1107                 pfns->batch_start_index = pfns->batch_end_index;
1108         }
1109 }
1110
1111 static void pfn_reader_destroy(struct pfn_reader *pfns)
1112 {
1113         struct iopt_pages *pages = pfns->pages;
1114
1115         pfn_reader_release_pins(pfns);
1116         pfn_reader_user_destroy(&pfns->user, pfns->pages);
1117         batch_destroy(&pfns->batch, NULL);
1118         WARN_ON(pages->last_npinned != pages->npinned);
1119 }
1120
1121 static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
1122                             unsigned long start_index, unsigned long last_index)
1123 {
1124         int rc;
1125
1126         if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1127             WARN_ON(last_index < start_index))
1128                 return -EINVAL;
1129
1130         rc = pfn_reader_init(pfns, pages, start_index, last_index);
1131         if (rc)
1132                 return rc;
1133         rc = pfn_reader_next(pfns);
1134         if (rc) {
1135                 pfn_reader_destroy(pfns);
1136                 return rc;
1137         }
1138         return 0;
1139 }
1140
1141 struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
1142                                     bool writable)
1143 {
1144         struct iopt_pages *pages;
1145         unsigned long end;
1146
1147         /*
1148          * The iommu API uses size_t as the length, and protect the DIV_ROUND_UP
1149          * below from overflow
1150          */
1151         if (length > SIZE_MAX - PAGE_SIZE || length == 0)
1152                 return ERR_PTR(-EINVAL);
1153
1154         if (check_add_overflow((unsigned long)uptr, length, &end))
1155                 return ERR_PTR(-EOVERFLOW);
1156
1157         pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
1158         if (!pages)
1159                 return ERR_PTR(-ENOMEM);
1160
1161         kref_init(&pages->kref);
1162         xa_init_flags(&pages->pinned_pfns, XA_FLAGS_ACCOUNT);
1163         mutex_init(&pages->mutex);
1164         pages->source_mm = current->mm;
1165         mmgrab(pages->source_mm);
1166         pages->uptr = (void __user *)ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
1167         pages->npages = DIV_ROUND_UP(length + (uptr - pages->uptr), PAGE_SIZE);
1168         pages->access_itree = RB_ROOT_CACHED;
1169         pages->domains_itree = RB_ROOT_CACHED;
1170         pages->writable = writable;
1171         if (capable(CAP_IPC_LOCK))
1172                 pages->account_mode = IOPT_PAGES_ACCOUNT_NONE;
1173         else
1174                 pages->account_mode = IOPT_PAGES_ACCOUNT_USER;
1175         pages->source_task = current->group_leader;
1176         get_task_struct(current->group_leader);
1177         pages->source_user = get_uid(current_user());
1178         return pages;
1179 }
1180
1181 void iopt_release_pages(struct kref *kref)
1182 {
1183         struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
1184
1185         WARN_ON(!RB_EMPTY_ROOT(&pages->access_itree.rb_root));
1186         WARN_ON(!RB_EMPTY_ROOT(&pages->domains_itree.rb_root));
1187         WARN_ON(pages->npinned);
1188         WARN_ON(!xa_empty(&pages->pinned_pfns));
1189         mmdrop(pages->source_mm);
1190         mutex_destroy(&pages->mutex);
1191         put_task_struct(pages->source_task);
1192         free_uid(pages->source_user);
1193         kfree(pages);
1194 }
1195
/*
 * Unmap the hole [start_index, last_index] of @area from @domain and unpin
 * its PFNs. Called by __iopt_area_unfill_domain() for each span that no access
 * or other domain uses.
 *
 * @batch: carries PFNs between calls. If it is non-empty on entry, it holds
 *         PFNs already read (and unmapped) from @domain covering
 *         [start_index, *unmapped_end_index); this is asserted under
 *         CONFIG_IOMMUFD_TEST.
 * @unmapped_end_index: in/out, the first index not yet unmapped from @domain.
 * @real_last_index: last index of the whole unfill. PFNs may be read ahead
 *                   up to here to find a point where the unmap can be cut.
 *
 * On return, read-ahead PFNs past last_index that were unmapped but not
 * unpinned are left in @batch as the carry for the next span.
 */
1196 static void
1197 iopt_area_unpin_domain(struct pfn_batch *batch, struct iopt_area *area,
1198                        struct iopt_pages *pages, struct iommu_domain *domain,
1199                        unsigned long start_index, unsigned long last_index,
1200                        unsigned long *unmapped_end_index,
1201                        unsigned long real_last_index)
1202 {
1203         while (start_index <= last_index) {
1204                 unsigned long batch_last_index;
1205
	/*
	 * Part of the hole is still mapped: read PFNs back from the domain
	 * starting at the first index not already held in the batch.
	 */
1206                 if (*unmapped_end_index <= last_index) {
1207                         unsigned long start =
1208                                 max(start_index, *unmapped_end_index);
1209
1210                         if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1211                             batch->total_pfns)
1212                                 WARN_ON(*unmapped_end_index -
1213                                                 batch->total_pfns !=
1214                                         start_index);
1215                         batch_from_domain(batch, domain, area, start,
1216                                           last_index);
1217                         batch_last_index = start_index + batch->total_pfns - 1;
1218                 } else {
	/* The carry from an earlier read-ahead already covers the whole hole */
1219                         batch_last_index = last_index;
1220                 }
1221
1222                 if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
1223                         WARN_ON(batch_last_index > real_last_index);
1224
1225                 /*
1226                  * unmaps must always 'cut' at a place where the pfns are not
1227                  * contiguous to pair with the maps that always install
1228                  * contiguous pages. Thus, if we have to stop unpinning in the
1229                  * middle of the domains we need to keep reading pfns until we
1230                  * find a cut point to do the unmap. The pfns we read are
1231                  * carried over and either skipped or integrated into the next
1232                  * batch.
1233                  */
1234                 if (batch_last_index == last_index &&
1235                     last_index != real_last_index)
1236                         batch_from_domain_continue(batch, domain, area,
1237                                                    last_index + 1,
1238                                                    real_last_index);
1239
	/* Unmap everything newly read, including any read-ahead */
1240                 if (*unmapped_end_index <= batch_last_index) {
1241                         iopt_area_unmap_domain_range(
1242                                 area, domain, *unmapped_end_index,
1243                                 start_index + batch->total_pfns - 1);
1244                         *unmapped_end_index = start_index + batch->total_pfns;
1245                 }
1246
1247                 /* unpin must follow unmap */
1248                 batch_unpin(batch, pages, 0,
1249                             batch_last_index - start_index + 1);
1250                 start_index = batch_last_index + 1;
1251
	/*
	 * Drop the unpinned PFNs and keep the read-ahead ones, which are
	 * unmapped but not yet unpinned, as the carry.
	 */
1252                 batch_clear_carry(batch,
1253                                   *unmapped_end_index - batch_last_index - 1);
1254         }
1255 }
1256
1257 static void __iopt_area_unfill_domain(struct iopt_area *area,
1258                                       struct iopt_pages *pages,
1259                                       struct iommu_domain *domain,
1260                                       unsigned long last_index)
1261 {
1262         struct interval_tree_double_span_iter span;
1263         unsigned long start_index = iopt_area_index(area);
1264         unsigned long unmapped_end_index = start_index;
1265         u64 backup[BATCH_BACKUP_SIZE];
1266         struct pfn_batch batch;
1267
1268         lockdep_assert_held(&pages->mutex);
1269
1270         /*
1271          * For security we must not unpin something that is still DMA mapped,
1272          * so this must unmap any IOVA before we go ahead and unpin the pages.
1273          * This creates a complexity where we need to skip over unpinning pages
1274          * held in the xarray, but continue to unmap from the domain.
1275          *
1276          * The domain unmap cannot stop in the middle of a contiguous range of
1277          * PFNs. To solve this problem the unpinning step will read ahead to the
1278          * end of any contiguous span, unmap that whole span, and then only
1279          * unpin the leading part that does not have any accesses. The residual
1280          * PFNs that were unmapped but not unpinned are called a "carry" in the
1281          * batch as they are moved to the front of the PFN list and continue on
1282          * to the next iteration(s).
1283          */
1284         batch_init_backup(&batch, last_index + 1, backup, sizeof(backup));
1285         interval_tree_for_each_double_span(&span, &pages->domains_itree,
1286                                            &pages->access_itree, start_index,
1287                                            last_index) {
1288                 if (span.is_used) {
1289                         batch_skip_carry(&batch,
1290                                          span.last_used - span.start_used + 1);
1291                         continue;
1292                 }
1293                 iopt_area_unpin_domain(&batch, area, pages, domain,
1294                                        span.start_hole, span.last_hole,
1295                                        &unmapped_end_index, last_index);
1296         }
1297         /*
1298          * If the range ends in a access then we do the residual unmap without
1299          * any unpins.
1300          */
1301         if (unmapped_end_index != last_index + 1)
1302                 iopt_area_unmap_domain_range(area, domain, unmapped_end_index,
1303                                              last_index);
1304         WARN_ON(batch.total_pfns);
1305         batch_destroy(&batch, backup);
1306         update_unpinned(pages);
1307 }
1308
/*
 * Unmap and unpin only the prefix [iopt_area_index(area), end_index) of
 * @area from @domain. end_index is exclusive; an empty prefix is a no-op.
 */
static void iopt_area_unfill_partial_domain(struct iopt_area *area,
					    struct iopt_pages *pages,
					    struct iommu_domain *domain,
					    unsigned long end_index)
{
	if (end_index == iopt_area_index(area))
		return;
	__iopt_area_unfill_domain(area, pages, domain, end_index - 1);
}
1317
/**
 * iopt_area_unmap_domain() - Unmap without unpinning PFNs in a domain
 * @area: The IOVA range to unmap
 * @domain: The domain to unmap
 *
 * Removes the area's whole IOVA range from @domain. The caller must know
 * that unpinning is not required, usually because there are other domains
 * in the iopt.
 */
void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain)
{
	iommu_unmap_nofail(domain, iopt_area_iova(area), iopt_area_length(area));
}
1331
/**
 * iopt_area_unfill_domain() - Unmap and unpin PFNs in a domain
 * @area: IOVA area to use
 * @pages: page supplier for the area (area->pages is NULL)
 * @domain: Domain to unmap from
 *
 * Acts on the area's full index range. The domain should be removed from the
 * domains_itree before calling. The domain will always be unmapped, but the
 * PFNs may not be unpinned if there are still accesses.
 */
void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages,
			     struct iommu_domain *domain)
{
	unsigned long last = iopt_area_last_index(area);

	__iopt_area_unfill_domain(area, pages, domain, last);
}
1348
1349 /**
1350  * iopt_area_fill_domain() - Map PFNs from the area into a domain
1351  * @area: IOVA area to use
1352  * @domain: Domain to load PFNs into
1353  *
1354  * Read the pfns from the area's underlying iopt_pages and map them into the
1355  * given domain. Called when attaching a new domain to an io_pagetable.
1356  */
1357 int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
1358 {
1359         unsigned long done_end_index;
1360         struct pfn_reader pfns;
1361         int rc;
1362
1363         lockdep_assert_held(&area->pages->mutex);
1364
1365         rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
1366                               iopt_area_last_index(area));
1367         if (rc)
1368                 return rc;
1369
1370         while (!pfn_reader_done(&pfns)) {
1371                 done_end_index = pfns.batch_start_index;
1372                 rc = batch_to_domain(&pfns.batch, domain, area,
1373                                      pfns.batch_start_index);
1374                 if (rc)
1375                         goto out_unmap;
1376                 done_end_index = pfns.batch_end_index;
1377
1378                 rc = pfn_reader_next(&pfns);
1379                 if (rc)
1380                         goto out_unmap;
1381         }
1382
1383         rc = pfn_reader_update_pinned(&pfns);
1384         if (rc)
1385                 goto out_unmap;
1386         goto out_destroy;
1387
1388 out_unmap:
1389         pfn_reader_release_pins(&pfns);
1390         iopt_area_unfill_partial_domain(area, area->pages, domain,
1391                                         done_end_index);
1392 out_destroy:
1393         pfn_reader_destroy(&pfns);
1394         return rc;
1395 }
1396
1397 /**
 * iopt_area_fill_domains() - Install PFNs into the area's domains
 * @area: The area to act on
 * @pages: The pages associated with the area (area->pages is NULL)
 *
 * Called during area creation. The area is freshly created and not inserted in
 * the domains_itree yet. PFNs are read and loaded into every domain held in the
 * area's io_pagetable and the area is installed in the domains_itree.
 *
 * The caller must hold area->iopt->domains_rwsem; pages->mutex is taken and
 * released here. If the io_pagetable has no domains nothing is done. On
 * success the domain at xarray index 0 becomes area->storage_domain.
 *
 * On failure all domains are left unchanged.
 *
 * Return: 0 on success or a negative errno.
 */
1408 int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages)
1409 {
	/*
	 * Progress for the error unwind, as exclusive end indexes:
	 * done_first_end_index - mapped into the domains iterated before the
	 *                        one that failed (the current batch's end)
	 * done_all_end_index   - mapped into every domain
	 */
1410         unsigned long done_first_end_index;
1411         unsigned long done_all_end_index;
1412         struct iommu_domain *domain;
1413         unsigned long unmap_index;
1414         struct pfn_reader pfns;
1415         unsigned long index;
1416         int rc;
1417
1418         lockdep_assert_held(&area->iopt->domains_rwsem);
1419
1420         if (xa_empty(&area->iopt->domains))
1421                 return 0;
1422
1423         mutex_lock(&pages->mutex);
1424         rc = pfn_reader_first(&pfns, pages, iopt_area_index(area),
1425                               iopt_area_last_index(area));
1426         if (rc)
1427                 goto out_unlock;
1428
1429         while (!pfn_reader_done(&pfns)) {
1430                 done_first_end_index = pfns.batch_end_index;
1431                 done_all_end_index = pfns.batch_start_index;
1432                 xa_for_each(&area->iopt->domains, index, domain) {
1433                         rc = batch_to_domain(&pfns.batch, domain, area,
1434                                              pfns.batch_start_index);
1435                         if (rc)
1436                                 goto out_unmap;
1437                 }
	/* The whole batch is now in every domain */
1438                 done_all_end_index = done_first_end_index;
1439
1440                 rc = pfn_reader_next(&pfns);
1441                 if (rc)
1442                         goto out_unmap;
1443         }
1444         rc = pfn_reader_update_pinned(&pfns);
1445         if (rc)
1446                 goto out_unmap;
1447
	/*
	 * Use the domain at xarray index 0 as the storage_domain, then publish
	 * the area in domains_itree.
	 */
1448         area->storage_domain = xa_load(&area->iopt->domains, 0);
1449         interval_tree_insert(&area->pages_node, &pages->domains_itree);
1450         goto out_destroy;
1451
1452 out_unmap:
1453         pfn_reader_release_pins(&pfns);
1454         xa_for_each(&area->iopt->domains, unmap_index, domain) {
1455                 unsigned long end_index;
1456
	/*
	 * 'index' is where the domain loop stopped. Domains before it
	 * received the current batch; the rest did not. When the failure
	 * came after the domain loop completed, both end indexes are equal.
	 */
1457                 if (unmap_index < index)
1458                         end_index = done_first_end_index;
1459                 else
1460                         end_index = done_all_end_index;
1461
1462                 /*
1463                  * The area is not yet part of the domains_itree so we have to
1464                  * manage the unpinning specially. The last domain does the
1465                  * unpin, every other domain is just unmapped.
1466                  */
1467                 if (unmap_index != area->iopt->next_domain_id - 1) {
1468                         if (end_index != iopt_area_index(area))
1469                                 iopt_area_unmap_domain_range(
1470                                         area, domain, iopt_area_index(area),
1471                                         end_index - 1);
1472                 } else {
1473                         iopt_area_unfill_partial_domain(area, pages, domain,
1474                                                         end_index);
1475                 }
1476         }
1477 out_destroy:
1478         pfn_reader_destroy(&pfns);
1479 out_unlock:
1480         mutex_unlock(&pages->mutex);
1481         return rc;
1482 }
1483
1484 /**
1485  * iopt_area_unfill_domains() - unmap PFNs from the area's domains
1486  * @area: The area to act on
1487  * @pages: The pages associated with the area (area->pages is NULL)
1488  *
1489  * Called during area destruction. This unmaps the iova's covered by all the
1490  * area's domains and releases the PFNs.
1491  */
1492 void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages)
1493 {
1494         struct io_pagetable *iopt = area->iopt;
1495         struct iommu_domain *domain;
1496         unsigned long index;
1497
1498         lockdep_assert_held(&iopt->domains_rwsem);
1499
1500         mutex_lock(&pages->mutex);
1501         if (!area->storage_domain)
1502                 goto out_unlock;
1503
1504         xa_for_each(&iopt->domains, index, domain)
1505                 if (domain != area->storage_domain)
1506                         iopt_area_unmap_domain_range(
1507                                 area, domain, iopt_area_index(area),
1508                                 iopt_area_last_index(area));
1509
1510         if (IS_ENABLED(CONFIG_IOMMUFD_TEST))
1511                 WARN_ON(RB_EMPTY_NODE(&area->pages_node.rb));
1512         interval_tree_remove(&area->pages_node, &pages->domains_itree);
1513         iopt_area_unfill_domain(area, pages, area->storage_domain);
1514         area->storage_domain = NULL;
1515 out_unlock:
1516         mutex_unlock(&pages->mutex);
1517 }
1518
1519 static void iopt_pages_unpin_xarray(struct pfn_batch *batch,
1520                                     struct iopt_pages *pages,
1521                                     unsigned long start_index,
1522                                     unsigned long end_index)
1523 {
1524         while (start_index <= end_index) {
1525                 batch_from_xarray_clear(batch, &pages->pinned_pfns, start_index,
1526                                         end_index);
1527                 batch_unpin(batch, pages, 0, batch->total_pfns);
1528                 start_index += batch->total_pfns;
1529                 batch_clear(batch);
1530         }
1531 }
1532
1533 /**
1534  * iopt_pages_unfill_xarray() - Update the xarry after removing an access
1535  * @pages: The pages to act on
1536  * @start_index: Starting PFN index
1537  * @last_index: Last PFN index
1538  *
1539  * Called when an iopt_pages_access is removed, removes pages from the itree.
1540  * The access should already be removed from the access_itree.
1541  */
1542 void iopt_pages_unfill_xarray(struct iopt_pages *pages,
1543                               unsigned long start_index,
1544                               unsigned long last_index)
1545 {
1546         struct interval_tree_double_span_iter span;
1547         u64 backup[BATCH_BACKUP_SIZE];
1548         struct pfn_batch batch;
1549         bool batch_inited = false;
1550
1551         lockdep_assert_held(&pages->mutex);
1552
1553         interval_tree_for_each_double_span(&span, &pages->access_itree,
1554                                            &pages->domains_itree, start_index,
1555                                            last_index) {
1556                 if (!span.is_used) {
1557                         if (!batch_inited) {
1558                                 batch_init_backup(&batch,
1559                                                   last_index - start_index + 1,
1560                                                   backup, sizeof(backup));
1561                                 batch_inited = true;
1562                         }
1563                         iopt_pages_unpin_xarray(&batch, pages, span.start_hole,
1564                                                 span.last_hole);
1565                 } else if (span.is_used == 2) {
1566                         /* Covered by a domain */
1567                         clear_xarray(&pages->pinned_pfns, span.start_used,
1568                                      span.last_used);
1569                 }
1570                 /* Otherwise covered by an existing access */
1571         }
1572         if (batch_inited)
1573                 batch_destroy(&batch, backup);
1574         update_unpinned(pages);
1575 }
1576
1577 /**
1578  * iopt_pages_fill_from_xarray() - Fast path for reading PFNs
1579  * @pages: The pages to act on
1580  * @start_index: The first page index in the range
1581  * @last_index: The last page index in the range
1582  * @out_pages: The output array to return the pages
1583  *
1584  * This can be called if the caller is holding a refcount on an
1585  * iopt_pages_access that is known to have already been filled. It quickly reads
1586  * the pages directly from the xarray.
1587  *
1588  * This is part of the SW iommu interface to read pages for in-kernel use.
1589  */
1590 void iopt_pages_fill_from_xarray(struct iopt_pages *pages,
1591                                  unsigned long start_index,
1592                                  unsigned long last_index,
1593                                  struct page **out_pages)
1594 {
1595         XA_STATE(xas, &pages->pinned_pfns, start_index);
1596         void *entry;
1597
1598         rcu_read_lock();
1599         while (start_index <= last_index) {
1600                 entry = xas_next(&xas);
1601                 if (xas_retry(&xas, entry))
1602                         continue;
1603                 WARN_ON(!xa_is_value(entry));
1604                 *(out_pages++) = pfn_to_page(xa_to_value(entry));
1605                 start_index++;
1606         }
1607         rcu_read_unlock();
1608 }
1609
1610 static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
1611                                        unsigned long start_index,
1612                                        unsigned long last_index,
1613                                        struct page **out_pages)
1614 {
1615         while (start_index != last_index + 1) {
1616                 unsigned long domain_last;
1617                 struct iopt_area *area;
1618
1619                 area = iopt_pages_find_domain_area(pages, start_index);
1620                 if (WARN_ON(!area))
1621                         return -EINVAL;
1622
1623                 domain_last = min(iopt_area_last_index(area), last_index);
1624                 out_pages = raw_pages_from_domain(area->storage_domain, area,
1625                                                   start_index, domain_last,
1626                                                   out_pages);
1627                 start_index = domain_last + 1;
1628         }
1629         return 0;
1630 }
1631
1632 static int iopt_pages_fill_from_mm(struct iopt_pages *pages,
1633                                    struct pfn_reader_user *user,
1634                                    unsigned long start_index,
1635                                    unsigned long last_index,
1636                                    struct page **out_pages)
1637 {
1638         unsigned long cur_index = start_index;
1639         int rc;
1640
1641         while (cur_index != last_index + 1) {
1642                 user->upages = out_pages + (cur_index - start_index);
1643                 rc = pfn_reader_user_pin(user, pages, cur_index, last_index);
1644                 if (rc)
1645                         goto out_unpin;
1646                 cur_index = user->upages_end;
1647         }
1648         return 0;
1649
1650 out_unpin:
1651         if (start_index != cur_index)
1652                 iopt_pages_err_unpin(pages, start_index, cur_index - 1,
1653                                      out_pages);
1654         return rc;
1655 }
1656
1657 /**
1658  * iopt_pages_fill_xarray() - Read PFNs
1659  * @pages: The pages to act on
1660  * @start_index: The first page index in the range
1661  * @last_index: The last page index in the range
1662  * @out_pages: The output array to return the pages, may be NULL
1663  *
1664  * This populates the xarray and returns the pages in out_pages. As the slow
1665  * path this is able to copy pages from other storage tiers into the xarray.
1666  *
1667  * On failure the xarray is left unchanged.
1668  *
1669  * This is part of the SW iommu interface to read pages for in-kernel use.
1670  */
1671 int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
1672                            unsigned long last_index, struct page **out_pages)
1673 {
1674         struct interval_tree_double_span_iter span;
1675         unsigned long xa_end = start_index;
1676         struct pfn_reader_user user;
1677         int rc;
1678
1679         lockdep_assert_held(&pages->mutex);
1680
1681         pfn_reader_user_init(&user, pages);
1682         user.upages_len = (last_index - start_index + 1) * sizeof(*out_pages);
1683         interval_tree_for_each_double_span(&span, &pages->access_itree,
1684                                            &pages->domains_itree, start_index,
1685                                            last_index) {
1686                 struct page **cur_pages;
1687
1688                 if (span.is_used == 1) {
1689                         cur_pages = out_pages + (span.start_used - start_index);
1690                         iopt_pages_fill_from_xarray(pages, span.start_used,
1691                                                     span.last_used, cur_pages);
1692                         continue;
1693                 }
1694
1695                 if (span.is_used == 2) {
1696                         cur_pages = out_pages + (span.start_used - start_index);
1697                         iopt_pages_fill_from_domain(pages, span.start_used,
1698                                                     span.last_used, cur_pages);
1699                         rc = pages_to_xarray(&pages->pinned_pfns,
1700                                              span.start_used, span.last_used,
1701                                              cur_pages);
1702                         if (rc)
1703                                 goto out_clean_xa;
1704                         xa_end = span.last_used + 1;
1705                         continue;
1706                 }
1707
1708                 /* hole */
1709                 cur_pages = out_pages + (span.start_hole - start_index);
1710                 rc = iopt_pages_fill_from_mm(pages, &user, span.start_hole,
1711                                              span.last_hole, cur_pages);
1712                 if (rc)
1713                         goto out_clean_xa;
1714                 rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
1715                                      span.last_hole, cur_pages);
1716                 if (rc) {
1717                         iopt_pages_err_unpin(pages, span.start_hole,
1718                                              span.last_hole, cur_pages);
1719                         goto out_clean_xa;
1720                 }
1721                 xa_end = span.last_hole + 1;
1722         }
1723         rc = pfn_reader_user_update_pinned(&user, pages);
1724         if (rc)
1725                 goto out_clean_xa;
1726         user.upages = NULL;
1727         pfn_reader_user_destroy(&user, pages);
1728         return 0;
1729
1730 out_clean_xa:
1731         if (start_index != xa_end)
1732                 iopt_pages_unfill_xarray(pages, start_index, xa_end - 1);
1733         user.upages = NULL;
1734         pfn_reader_user_destroy(&user, pages);
1735         return rc;
1736 }
1737
1738 /*
1739  * This uses the pfn_reader instead of taking a shortcut by using the mm. It can
1740  * do every scenario and is fully consistent with what an iommu_domain would
1741  * see.
1742  */
1743 static int iopt_pages_rw_slow(struct iopt_pages *pages,
1744                               unsigned long start_index,
1745                               unsigned long last_index, unsigned long offset,
1746                               void *data, unsigned long length,
1747                               unsigned int flags)
1748 {
1749         struct pfn_reader pfns;
1750         int rc;
1751
1752         mutex_lock(&pages->mutex);
1753
1754         rc = pfn_reader_first(&pfns, pages, start_index, last_index);
1755         if (rc)
1756                 goto out_unlock;
1757
1758         while (!pfn_reader_done(&pfns)) {
1759                 unsigned long done;
1760
1761                 done = batch_rw(&pfns.batch, data, offset, length, flags);
1762                 data += done;
1763                 length -= done;
1764                 offset = 0;
1765                 pfn_reader_unpin(&pfns);
1766
1767                 rc = pfn_reader_next(&pfns);
1768                 if (rc)
1769                         goto out_destroy;
1770         }
1771         if (WARN_ON(length != 0))
1772                 rc = -EINVAL;
1773 out_destroy:
1774         pfn_reader_destroy(&pfns);
1775 out_unlock:
1776         mutex_unlock(&pages->mutex);
1777         return rc;
1778 }
1779
1780 /*
1781  * A medium speed path that still allows DMA inconsistencies, but doesn't do any
1782  * memory allocations or interval tree searches.
1783  */
1784 static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
1785                               unsigned long offset, void *data,
1786                               unsigned long length, unsigned int flags)
1787 {
1788         struct page *page = NULL;
1789         int rc;
1790
1791         if (!mmget_not_zero(pages->source_mm))
1792                 return iopt_pages_rw_slow(pages, index, index, offset, data,
1793                                           length, flags);
1794
1795         if (iommufd_should_fail()) {
1796                 rc = -EINVAL;
1797                 goto out_mmput;
1798         }
1799
1800         mmap_read_lock(pages->source_mm);
1801         rc = pin_user_pages_remote(
1802                 pages->source_mm, (uintptr_t)(pages->uptr + index * PAGE_SIZE),
1803                 1, (flags & IOMMUFD_ACCESS_RW_WRITE) ? FOLL_WRITE : 0, &page,
1804                 NULL);
1805         mmap_read_unlock(pages->source_mm);
1806         if (rc != 1) {
1807                 if (WARN_ON(rc >= 0))
1808                         rc = -EINVAL;
1809                 goto out_mmput;
1810         }
1811         copy_data_page(page, data, offset, length, flags);
1812         unpin_user_page(page);
1813         rc = 0;
1814
1815 out_mmput:
1816         mmput(pages->source_mm);
1817         return rc;
1818 }
1819
1820 /**
1821  * iopt_pages_rw_access - Copy to/from a linear slice of the pages
1822  * @pages: pages to act on
1823  * @start_byte: First byte of pages to copy to/from
1824  * @data: Kernel buffer to get/put the data
1825  * @length: Number of bytes to copy
1826  * @flags: IOMMUFD_ACCESS_RW_* flags
1827  *
1828  * This will find each page in the range, kmap it and then memcpy to/from
1829  * the given kernel buffer.
1830  */
1831 int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
1832                          void *data, unsigned long length, unsigned int flags)
1833 {
1834         unsigned long start_index = start_byte / PAGE_SIZE;
1835         unsigned long last_index = (start_byte + length - 1) / PAGE_SIZE;
1836         bool change_mm = current->mm != pages->source_mm;
1837         int rc = 0;
1838
1839         if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1840             (flags & __IOMMUFD_ACCESS_RW_SLOW_PATH))
1841                 change_mm = true;
1842
1843         if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1844                 return -EPERM;
1845
1846         if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
1847                 if (start_index == last_index)
1848                         return iopt_pages_rw_page(pages, start_index,
1849                                                   start_byte % PAGE_SIZE, data,
1850                                                   length, flags);
1851                 return iopt_pages_rw_slow(pages, start_index, last_index,
1852                                           start_byte % PAGE_SIZE, data, length,
1853                                           flags);
1854         }
1855
1856         /*
1857          * Try to copy using copy_to_user(). We do this as a fast path and
1858          * ignore any pinning inconsistencies, unlike a real DMA path.
1859          */
1860         if (change_mm) {
1861                 if (!mmget_not_zero(pages->source_mm))
1862                         return iopt_pages_rw_slow(pages, start_index,
1863                                                   last_index,
1864                                                   start_byte % PAGE_SIZE, data,
1865                                                   length, flags);
1866                 kthread_use_mm(pages->source_mm);
1867         }
1868
1869         if (flags & IOMMUFD_ACCESS_RW_WRITE) {
1870                 if (copy_to_user(pages->uptr + start_byte, data, length))
1871                         rc = -EFAULT;
1872         } else {
1873                 if (copy_from_user(data, pages->uptr + start_byte, length))
1874                         rc = -EFAULT;
1875         }
1876
1877         if (change_mm) {
1878                 kthread_unuse_mm(pages->source_mm);
1879                 mmput(pages->source_mm);
1880         }
1881
1882         return rc;
1883 }
1884
1885 static struct iopt_pages_access *
1886 iopt_pages_get_exact_access(struct iopt_pages *pages, unsigned long index,
1887                             unsigned long last)
1888 {
1889         struct interval_tree_node *node;
1890
1891         lockdep_assert_held(&pages->mutex);
1892
1893         /* There can be overlapping ranges in this interval tree */
1894         for (node = interval_tree_iter_first(&pages->access_itree, index, last);
1895              node; node = interval_tree_iter_next(node, index, last))
1896                 if (node->start == index && node->last == last)
1897                         return container_of(node, struct iopt_pages_access,
1898                                             node);
1899         return NULL;
1900 }
1901
1902 /**
1903  * iopt_area_add_access() - Record an in-knerel access for PFNs
1904  * @area: The source of PFNs
1905  * @start_index: First page index
1906  * @last_index: Inclusive last page index
1907  * @out_pages: Output list of struct page's representing the PFNs
1908  * @flags: IOMMUFD_ACCESS_RW_* flags
1909  *
1910  * Record that an in-kernel access will be accessing the pages, ensure they are
1911  * pinned, and return the PFNs as a simple list of 'struct page *'.
1912  *
1913  * This should be undone through a matching call to iopt_area_remove_access()
1914  */
1915 int iopt_area_add_access(struct iopt_area *area, unsigned long start_index,
1916                           unsigned long last_index, struct page **out_pages,
1917                           unsigned int flags)
1918 {
1919         struct iopt_pages *pages = area->pages;
1920         struct iopt_pages_access *access;
1921         int rc;
1922
1923         if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1924                 return -EPERM;
1925
1926         mutex_lock(&pages->mutex);
1927         access = iopt_pages_get_exact_access(pages, start_index, last_index);
1928         if (access) {
1929                 area->num_accesses++;
1930                 access->users++;
1931                 iopt_pages_fill_from_xarray(pages, start_index, last_index,
1932                                             out_pages);
1933                 mutex_unlock(&pages->mutex);
1934                 return 0;
1935         }
1936
1937         access = kzalloc(sizeof(*access), GFP_KERNEL_ACCOUNT);
1938         if (!access) {
1939                 rc = -ENOMEM;
1940                 goto err_unlock;
1941         }
1942
1943         rc = iopt_pages_fill_xarray(pages, start_index, last_index, out_pages);
1944         if (rc)
1945                 goto err_free;
1946
1947         access->node.start = start_index;
1948         access->node.last = last_index;
1949         access->users = 1;
1950         area->num_accesses++;
1951         interval_tree_insert(&access->node, &pages->access_itree);
1952         mutex_unlock(&pages->mutex);
1953         return 0;
1954
1955 err_free:
1956         kfree(access);
1957 err_unlock:
1958         mutex_unlock(&pages->mutex);
1959         return rc;
1960 }
1961
1962 /**
1963  * iopt_area_remove_access() - Release an in-kernel access for PFNs
1964  * @area: The source of PFNs
1965  * @start_index: First page index
1966  * @last_index: Inclusive last page index
1967  *
1968  * Undo iopt_area_add_access() and unpin the pages if necessary. The caller
1969  * must stop using the PFNs before calling this.
1970  */
1971 void iopt_area_remove_access(struct iopt_area *area, unsigned long start_index,
1972                              unsigned long last_index)
1973 {
1974         struct iopt_pages *pages = area->pages;
1975         struct iopt_pages_access *access;
1976
1977         mutex_lock(&pages->mutex);
1978         access = iopt_pages_get_exact_access(pages, start_index, last_index);
1979         if (WARN_ON(!access))
1980                 goto out_unlock;
1981
1982         WARN_ON(area->num_accesses == 0 || access->users == 0);
1983         area->num_accesses--;
1984         access->users--;
1985         if (access->users)
1986                 goto out_unlock;
1987
1988         interval_tree_remove(&access->node, &pages->access_itree);
1989         iopt_pages_unfill_xarray(pages, start_index, last_index);
1990         kfree(access);
1991 out_unlock:
1992         mutex_unlock(&pages->mutex);
1993 }