thermal: Validate new state in cur_state_store()
[platform/kernel/linux-rpi.git] / drivers / vfio / vfio_iommu_type1.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * VFIO: IOMMU DMA mapping support for Type1 IOMMU
4  *
5  * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
6  *     Author: Alex Williamson <alex.williamson@redhat.com>
7  *
8  * Derived from original vfio:
9  * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
10  * Author: Tom Lyon, pugs@cisco.com
11  *
12  * We arbitrarily define a Type1 IOMMU as one matching the below code.
13  * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
14  * VT-d, but that makes it harder to re-use as theoretically anyone
15  * implementing a similar IOMMU could make use of this.  We expect the
16  * IOMMU to support the IOMMU API and have few to no restrictions around
17  * the IOVA range that can be mapped.  The Type1 IOMMU is currently
18  * optimized for relatively static mappings of a userspace process with
19  * userspace pages pinned into memory.  We also assume devices and IOMMU
20  * domains are PCI based as the IOMMU API is still centered around a
21  * device/bus interface rather than a group interface.
22  */
23
24 #include <linux/compat.h>
25 #include <linux/device.h>
26 #include <linux/fs.h>
27 #include <linux/highmem.h>
28 #include <linux/iommu.h>
29 #include <linux/module.h>
30 #include <linux/mm.h>
31 #include <linux/kthread.h>
32 #include <linux/rbtree.h>
33 #include <linux/sched/signal.h>
34 #include <linux/sched/mm.h>
35 #include <linux/slab.h>
36 #include <linux/uaccess.h>
37 #include <linux/vfio.h>
38 #include <linux/workqueue.h>
39 #include <linux/mdev.h>
40 #include <linux/notifier.h>
41 #include <linux/dma-iommu.h>
42 #include <linux/irqdomain.h>
43
44 #define DRIVER_VERSION  "0.2"
45 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
46 #define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
47
48 static bool allow_unsafe_interrupts;
49 module_param_named(allow_unsafe_interrupts,
50                    allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
51 MODULE_PARM_DESC(allow_unsafe_interrupts,
52                  "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
53
54 static bool disable_hugepages;
55 module_param_named(disable_hugepages,
56                    disable_hugepages, bool, S_IRUGO | S_IWUSR);
57 MODULE_PARM_DESC(disable_hugepages,
58                  "Disable VFIO IOMMU support for IOMMU hugepages.");
59
60 static unsigned int dma_entry_limit __read_mostly = U16_MAX;
61 module_param_named(dma_entry_limit, dma_entry_limit, uint, 0644);
62 MODULE_PARM_DESC(dma_entry_limit,
63                  "Maximum number of user DMA mappings per container (65535).");
64
65 struct vfio_iommu {
66         struct list_head        domain_list;
67         struct list_head        iova_list;
68         struct vfio_domain      *external_domain; /* domain for external user */
69         struct mutex            lock;
70         struct rb_root          dma_list;
71         struct blocking_notifier_head notifier;
72         unsigned int            dma_avail;
73         unsigned int            vaddr_invalid_count;
74         uint64_t                pgsize_bitmap;
75         uint64_t                num_non_pinned_groups;
76         wait_queue_head_t       vaddr_wait;
77         bool                    v2;
78         bool                    nesting;
79         bool                    dirty_page_tracking;
80         bool                    container_open;
81 };
82
83 struct vfio_domain {
84         struct iommu_domain     *domain;
85         struct list_head        next;
86         struct list_head        group_list;
87         int                     prot;           /* IOMMU_CACHE */
88         bool                    fgsp;           /* Fine-grained super pages */
89 };
90
91 struct vfio_dma {
92         struct rb_node          node;
93         dma_addr_t              iova;           /* Device address */
94         unsigned long           vaddr;          /* Process virtual addr */
95         size_t                  size;           /* Map size (bytes) */
96         int                     prot;           /* IOMMU_READ/WRITE */
97         bool                    iommu_mapped;
98         bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
99         bool                    vaddr_invalid;
100         struct task_struct      *task;
101         struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
102         unsigned long           *bitmap;
103 };
104
105 struct vfio_batch {
106         struct page             **pages;        /* for pin_user_pages_remote */
107         struct page             *fallback_page; /* if pages alloc fails */
108         int                     capacity;       /* length of pages array */
109         int                     size;           /* of batch currently */
110         int                     offset;         /* of next entry in pages */
111 };
112
113 struct vfio_iommu_group {
114         struct iommu_group      *iommu_group;
115         struct list_head        next;
116         bool                    mdev_group;     /* An mdev group */
117         bool                    pinned_page_dirty_scope;
118 };
119
120 struct vfio_iova {
121         struct list_head        list;
122         dma_addr_t              start;
123         dma_addr_t              end;
124 };
125
126 /*
127  * Guest RAM pinning working set or DMA target
128  */
129 struct vfio_pfn {
130         struct rb_node          node;
131         dma_addr_t              iova;           /* Device address */
132         unsigned long           pfn;            /* Host pfn */
133         unsigned int            ref_count;
134 };
135
136 struct vfio_regions {
137         struct list_head list;
138         dma_addr_t iova;
139         phys_addr_t phys;
140         size_t len;
141 };
142
143 #define IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu) \
144                                         (!list_empty(&iommu->domain_list))
145
146 #define DIRTY_BITMAP_BYTES(n)   (ALIGN(n, BITS_PER_TYPE(u64)) / BITS_PER_BYTE)
147
148 /*
149  * Input argument of number of bits to bitmap_set() is unsigned integer, which
150  * further casts to signed integer for unaligned multi-bit operation,
151  * __bitmap_set().
152  * Then maximum bitmap size supported is 2^31 bits divided by 2^3 bits/byte,
153  * that is 2^28 (256 MB) which maps to 2^31 * 2^12 = 2^43 (8TB) on 4K page
154  * system.
155  */
156 #define DIRTY_BITMAP_PAGES_MAX   ((u64)INT_MAX)
157 #define DIRTY_BITMAP_SIZE_MAX    DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
158
159 #define WAITED 1
160
161 static int put_pfn(unsigned long pfn, int prot);
162
163 static struct vfio_iommu_group*
164 vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
165                             struct iommu_group *iommu_group);
166
167 /*
168  * This code handles mapping and unmapping of user data buffers
169  * into DMA'ble space using the IOMMU
170  */
171
172 static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
173                                       dma_addr_t start, size_t size)
174 {
175         struct rb_node *node = iommu->dma_list.rb_node;
176
177         while (node) {
178                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
179
180                 if (start + size <= dma->iova)
181                         node = node->rb_left;
182                 else if (start >= dma->iova + dma->size)
183                         node = node->rb_right;
184                 else
185                         return dma;
186         }
187
188         return NULL;
189 }
190
191 static struct rb_node *vfio_find_dma_first_node(struct vfio_iommu *iommu,
192                                                 dma_addr_t start, u64 size)
193 {
194         struct rb_node *res = NULL;
195         struct rb_node *node = iommu->dma_list.rb_node;
196         struct vfio_dma *dma_res = NULL;
197
198         while (node) {
199                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
200
201                 if (start < dma->iova + dma->size) {
202                         res = node;
203                         dma_res = dma;
204                         if (start >= dma->iova)
205                                 break;
206                         node = node->rb_left;
207                 } else {
208                         node = node->rb_right;
209                 }
210         }
211         if (res && size && dma_res->iova >= start + size)
212                 res = NULL;
213         return res;
214 }
215
216 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
217 {
218         struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
219         struct vfio_dma *dma;
220
221         while (*link) {
222                 parent = *link;
223                 dma = rb_entry(parent, struct vfio_dma, node);
224
225                 if (new->iova + new->size <= dma->iova)
226                         link = &(*link)->rb_left;
227                 else
228                         link = &(*link)->rb_right;
229         }
230
231         rb_link_node(&new->node, parent, link);
232         rb_insert_color(&new->node, &iommu->dma_list);
233 }
234
235 static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
236 {
237         rb_erase(&old->node, &iommu->dma_list);
238 }
239
240
241 static int vfio_dma_bitmap_alloc(struct vfio_dma *dma, size_t pgsize)
242 {
243         uint64_t npages = dma->size / pgsize;
244
245         if (npages > DIRTY_BITMAP_PAGES_MAX)
246                 return -EINVAL;
247
248         /*
249          * Allocate extra 64 bits that are used to calculate shift required for
250          * bitmap_shift_left() to manipulate and club unaligned number of pages
251          * in adjacent vfio_dma ranges.
252          */
253         dma->bitmap = kvzalloc(DIRTY_BITMAP_BYTES(npages) + sizeof(u64),
254                                GFP_KERNEL);
255         if (!dma->bitmap)
256                 return -ENOMEM;
257
258         return 0;
259 }
260
261 static void vfio_dma_bitmap_free(struct vfio_dma *dma)
262 {
263         kfree(dma->bitmap);
264         dma->bitmap = NULL;
265 }
266
267 static void vfio_dma_populate_bitmap(struct vfio_dma *dma, size_t pgsize)
268 {
269         struct rb_node *p;
270         unsigned long pgshift = __ffs(pgsize);
271
272         for (p = rb_first(&dma->pfn_list); p; p = rb_next(p)) {
273                 struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn, node);
274
275                 bitmap_set(dma->bitmap, (vpfn->iova - dma->iova) >> pgshift, 1);
276         }
277 }
278
279 static void vfio_iommu_populate_bitmap_full(struct vfio_iommu *iommu)
280 {
281         struct rb_node *n;
282         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
283
284         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
285                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
286
287                 bitmap_set(dma->bitmap, 0, dma->size >> pgshift);
288         }
289 }
290
291 static int vfio_dma_bitmap_alloc_all(struct vfio_iommu *iommu, size_t pgsize)
292 {
293         struct rb_node *n;
294
295         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
296                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
297                 int ret;
298
299                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
300                 if (ret) {
301                         struct rb_node *p;
302
303                         for (p = rb_prev(n); p; p = rb_prev(p)) {
304                                 struct vfio_dma *dma = rb_entry(n,
305                                                         struct vfio_dma, node);
306
307                                 vfio_dma_bitmap_free(dma);
308                         }
309                         return ret;
310                 }
311                 vfio_dma_populate_bitmap(dma, pgsize);
312         }
313         return 0;
314 }
315
316 static void vfio_dma_bitmap_free_all(struct vfio_iommu *iommu)
317 {
318         struct rb_node *n;
319
320         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
321                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
322
323                 vfio_dma_bitmap_free(dma);
324         }
325 }
326
327 /*
328  * Helper Functions for host iova-pfn list
329  */
330 static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
331 {
332         struct vfio_pfn *vpfn;
333         struct rb_node *node = dma->pfn_list.rb_node;
334
335         while (node) {
336                 vpfn = rb_entry(node, struct vfio_pfn, node);
337
338                 if (iova < vpfn->iova)
339                         node = node->rb_left;
340                 else if (iova > vpfn->iova)
341                         node = node->rb_right;
342                 else
343                         return vpfn;
344         }
345         return NULL;
346 }
347
348 static void vfio_link_pfn(struct vfio_dma *dma,
349                           struct vfio_pfn *new)
350 {
351         struct rb_node **link, *parent = NULL;
352         struct vfio_pfn *vpfn;
353
354         link = &dma->pfn_list.rb_node;
355         while (*link) {
356                 parent = *link;
357                 vpfn = rb_entry(parent, struct vfio_pfn, node);
358
359                 if (new->iova < vpfn->iova)
360                         link = &(*link)->rb_left;
361                 else
362                         link = &(*link)->rb_right;
363         }
364
365         rb_link_node(&new->node, parent, link);
366         rb_insert_color(&new->node, &dma->pfn_list);
367 }
368
369 static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
370 {
371         rb_erase(&old->node, &dma->pfn_list);
372 }
373
374 static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
375                                 unsigned long pfn)
376 {
377         struct vfio_pfn *vpfn;
378
379         vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
380         if (!vpfn)
381                 return -ENOMEM;
382
383         vpfn->iova = iova;
384         vpfn->pfn = pfn;
385         vpfn->ref_count = 1;
386         vfio_link_pfn(dma, vpfn);
387         return 0;
388 }
389
390 static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
391                                       struct vfio_pfn *vpfn)
392 {
393         vfio_unlink_pfn(dma, vpfn);
394         kfree(vpfn);
395 }
396
397 static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
398                                                unsigned long iova)
399 {
400         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
401
402         if (vpfn)
403                 vpfn->ref_count++;
404         return vpfn;
405 }
406
407 static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
408 {
409         int ret = 0;
410
411         vpfn->ref_count--;
412         if (!vpfn->ref_count) {
413                 ret = put_pfn(vpfn->pfn, dma->prot);
414                 vfio_remove_from_pfn_list(dma, vpfn);
415         }
416         return ret;
417 }
418
419 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
420 {
421         struct mm_struct *mm;
422         int ret;
423
424         if (!npage)
425                 return 0;
426
427         mm = async ? get_task_mm(dma->task) : dma->task->mm;
428         if (!mm)
429                 return -ESRCH; /* process exited */
430
431         ret = mmap_write_lock_killable(mm);
432         if (!ret) {
433                 ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
434                                           dma->lock_cap);
435                 mmap_write_unlock(mm);
436         }
437
438         if (async)
439                 mmput(mm);
440
441         return ret;
442 }
443
444 /*
445  * Some mappings aren't backed by a struct page, for example an mmap'd
446  * MMIO range for our own or another device.  These use a different
447  * pfn conversion and shouldn't be tracked as locked pages.
448  * For compound pages, any driver that sets the reserved bit in head
449  * page needs to set the reserved bit in all subpages to be safe.
450  */
451 static bool is_invalid_reserved_pfn(unsigned long pfn)
452 {
453         if (pfn_valid(pfn))
454                 return PageReserved(pfn_to_page(pfn));
455
456         return true;
457 }
458
459 static int put_pfn(unsigned long pfn, int prot)
460 {
461         if (!is_invalid_reserved_pfn(pfn)) {
462                 struct page *page = pfn_to_page(pfn);
463
464                 unpin_user_pages_dirty_lock(&page, 1, prot & IOMMU_WRITE);
465                 return 1;
466         }
467         return 0;
468 }
469
470 #define VFIO_BATCH_MAX_CAPACITY (PAGE_SIZE / sizeof(struct page *))
471
472 static void vfio_batch_init(struct vfio_batch *batch)
473 {
474         batch->size = 0;
475         batch->offset = 0;
476
477         if (unlikely(disable_hugepages))
478                 goto fallback;
479
480         batch->pages = (struct page **) __get_free_page(GFP_KERNEL);
481         if (!batch->pages)
482                 goto fallback;
483
484         batch->capacity = VFIO_BATCH_MAX_CAPACITY;
485         return;
486
487 fallback:
488         batch->pages = &batch->fallback_page;
489         batch->capacity = 1;
490 }
491
492 static void vfio_batch_unpin(struct vfio_batch *batch, struct vfio_dma *dma)
493 {
494         while (batch->size) {
495                 unsigned long pfn = page_to_pfn(batch->pages[batch->offset]);
496
497                 put_pfn(pfn, dma->prot);
498                 batch->offset++;
499                 batch->size--;
500         }
501 }
502
503 static void vfio_batch_fini(struct vfio_batch *batch)
504 {
505         if (batch->capacity == VFIO_BATCH_MAX_CAPACITY)
506                 free_page((unsigned long)batch->pages);
507 }
508
509 static int follow_fault_pfn(struct vm_area_struct *vma, struct mm_struct *mm,
510                             unsigned long vaddr, unsigned long *pfn,
511                             bool write_fault)
512 {
513         pte_t *ptep;
514         spinlock_t *ptl;
515         int ret;
516
517         ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
518         if (ret) {
519                 bool unlocked = false;
520
521                 ret = fixup_user_fault(mm, vaddr,
522                                        FAULT_FLAG_REMOTE |
523                                        (write_fault ?  FAULT_FLAG_WRITE : 0),
524                                        &unlocked);
525                 if (unlocked)
526                         return -EAGAIN;
527
528                 if (ret)
529                         return ret;
530
531                 ret = follow_pte(vma->vm_mm, vaddr, &ptep, &ptl);
532                 if (ret)
533                         return ret;
534         }
535
536         if (write_fault && !pte_write(*ptep))
537                 ret = -EFAULT;
538         else
539                 *pfn = pte_pfn(*ptep);
540
541         pte_unmap_unlock(ptep, ptl);
542         return ret;
543 }
544
545 /*
546  * Returns the positive number of pfns successfully obtained or a negative
547  * error code.
548  */
549 static int vaddr_get_pfns(struct mm_struct *mm, unsigned long vaddr,
550                           long npages, int prot, unsigned long *pfn,
551                           struct page **pages)
552 {
553         struct vm_area_struct *vma;
554         unsigned int flags = 0;
555         int ret;
556
557         if (prot & IOMMU_WRITE)
558                 flags |= FOLL_WRITE;
559
560         mmap_read_lock(mm);
561         ret = pin_user_pages_remote(mm, vaddr, npages, flags | FOLL_LONGTERM,
562                                     pages, NULL, NULL);
563         if (ret > 0) {
564                 int i;
565
566                 /*
567                  * The zero page is always resident, we don't need to pin it
568                  * and it falls into our invalid/reserved test so we don't
569                  * unpin in put_pfn().  Unpin all zero pages in the batch here.
570                  */
571                 for (i = 0 ; i < ret; i++) {
572                         if (unlikely(is_zero_pfn(page_to_pfn(pages[i]))))
573                                 unpin_user_page(pages[i]);
574                 }
575
576                 *pfn = page_to_pfn(pages[0]);
577                 goto done;
578         }
579
580         vaddr = untagged_addr(vaddr);
581
582 retry:
583         vma = vma_lookup(mm, vaddr);
584
585         if (vma && vma->vm_flags & VM_PFNMAP) {
586                 ret = follow_fault_pfn(vma, mm, vaddr, pfn, prot & IOMMU_WRITE);
587                 if (ret == -EAGAIN)
588                         goto retry;
589
590                 if (!ret) {
591                         if (is_invalid_reserved_pfn(*pfn))
592                                 ret = 1;
593                         else
594                                 ret = -EFAULT;
595                 }
596         }
597 done:
598         mmap_read_unlock(mm);
599         return ret;
600 }
601
602 static int vfio_wait(struct vfio_iommu *iommu)
603 {
604         DEFINE_WAIT(wait);
605
606         prepare_to_wait(&iommu->vaddr_wait, &wait, TASK_KILLABLE);
607         mutex_unlock(&iommu->lock);
608         schedule();
609         mutex_lock(&iommu->lock);
610         finish_wait(&iommu->vaddr_wait, &wait);
611         if (kthread_should_stop() || !iommu->container_open ||
612             fatal_signal_pending(current)) {
613                 return -EFAULT;
614         }
615         return WAITED;
616 }
617
618 /*
619  * Find dma struct and wait for its vaddr to be valid.  iommu lock is dropped
620  * if the task waits, but is re-locked on return.  Return result in *dma_p.
621  * Return 0 on success with no waiting, WAITED on success if waited, and -errno
622  * on error.
623  */
624 static int vfio_find_dma_valid(struct vfio_iommu *iommu, dma_addr_t start,
625                                size_t size, struct vfio_dma **dma_p)
626 {
627         int ret = 0;
628
629         do {
630                 *dma_p = vfio_find_dma(iommu, start, size);
631                 if (!*dma_p)
632                         return -EINVAL;
633                 else if (!(*dma_p)->vaddr_invalid)
634                         return ret;
635                 else
636                         ret = vfio_wait(iommu);
637         } while (ret == WAITED);
638
639         return ret;
640 }
641
642 /*
643  * Wait for all vaddr in the dma_list to become valid.  iommu lock is dropped
644  * if the task waits, but is re-locked on return.  Return 0 on success with no
645  * waiting, WAITED on success if waited, and -errno on error.
646  */
647 static int vfio_wait_all_valid(struct vfio_iommu *iommu)
648 {
649         int ret = 0;
650
651         while (iommu->vaddr_invalid_count && ret >= 0)
652                 ret = vfio_wait(iommu);
653
654         return ret;
655 }
656
657 /*
658  * Attempt to pin pages.  We really don't want to track all the pfns and
659  * the iommu can only map chunks of consecutive pfns anyway, so get the
660  * first page and all consecutive pages with the same locking.
661  */
662 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
663                                   long npage, unsigned long *pfn_base,
664                                   unsigned long limit, struct vfio_batch *batch)
665 {
666         unsigned long pfn;
667         struct mm_struct *mm = current->mm;
668         long ret, pinned = 0, lock_acct = 0;
669         bool rsvd;
670         dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
671
672         /* This code path is only user initiated */
673         if (!mm)
674                 return -ENODEV;
675
676         if (batch->size) {
677                 /* Leftover pages in batch from an earlier call. */
678                 *pfn_base = page_to_pfn(batch->pages[batch->offset]);
679                 pfn = *pfn_base;
680                 rsvd = is_invalid_reserved_pfn(*pfn_base);
681         } else {
682                 *pfn_base = 0;
683         }
684
685         while (npage) {
686                 if (!batch->size) {
687                         /* Empty batch, so refill it. */
688                         long req_pages = min_t(long, npage, batch->capacity);
689
690                         ret = vaddr_get_pfns(mm, vaddr, req_pages, dma->prot,
691                                              &pfn, batch->pages);
692                         if (ret < 0)
693                                 goto unpin_out;
694
695                         batch->size = ret;
696                         batch->offset = 0;
697
698                         if (!*pfn_base) {
699                                 *pfn_base = pfn;
700                                 rsvd = is_invalid_reserved_pfn(*pfn_base);
701                         }
702                 }
703
704                 /*
705                  * pfn is preset for the first iteration of this inner loop and
706                  * updated at the end to handle a VM_PFNMAP pfn.  In that case,
707                  * batch->pages isn't valid (there's no struct page), so allow
708                  * batch->pages to be touched only when there's more than one
709                  * pfn to check, which guarantees the pfns are from a
710                  * !VM_PFNMAP vma.
711                  */
712                 while (true) {
713                         if (pfn != *pfn_base + pinned ||
714                             rsvd != is_invalid_reserved_pfn(pfn))
715                                 goto out;
716
717                         /*
718                          * Reserved pages aren't counted against the user,
719                          * externally pinned pages are already counted against
720                          * the user.
721                          */
722                         if (!rsvd && !vfio_find_vpfn(dma, iova)) {
723                                 if (!dma->lock_cap &&
724                                     mm->locked_vm + lock_acct + 1 > limit) {
725                                         pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
726                                                 __func__, limit << PAGE_SHIFT);
727                                         ret = -ENOMEM;
728                                         goto unpin_out;
729                                 }
730                                 lock_acct++;
731                         }
732
733                         pinned++;
734                         npage--;
735                         vaddr += PAGE_SIZE;
736                         iova += PAGE_SIZE;
737                         batch->offset++;
738                         batch->size--;
739
740                         if (!batch->size)
741                                 break;
742
743                         pfn = page_to_pfn(batch->pages[batch->offset]);
744                 }
745
746                 if (unlikely(disable_hugepages))
747                         break;
748         }
749
750 out:
751         ret = vfio_lock_acct(dma, lock_acct, false);
752
753 unpin_out:
754         if (batch->size == 1 && !batch->offset) {
755                 /* May be a VM_PFNMAP pfn, which the batch can't remember. */
756                 put_pfn(pfn, dma->prot);
757                 batch->size = 0;
758         }
759
760         if (ret < 0) {
761                 if (pinned && !rsvd) {
762                         for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
763                                 put_pfn(pfn, dma->prot);
764                 }
765                 vfio_batch_unpin(batch, dma);
766
767                 return ret;
768         }
769
770         return pinned;
771 }
772
773 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
774                                     unsigned long pfn, long npage,
775                                     bool do_accounting)
776 {
777         long unlocked = 0, locked = 0;
778         long i;
779
780         for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
781                 if (put_pfn(pfn++, dma->prot)) {
782                         unlocked++;
783                         if (vfio_find_vpfn(dma, iova))
784                                 locked++;
785                 }
786         }
787
788         if (do_accounting)
789                 vfio_lock_acct(dma, locked - unlocked, true);
790
791         return unlocked;
792 }
793
794 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
795                                   unsigned long *pfn_base, bool do_accounting)
796 {
797         struct page *pages[1];
798         struct mm_struct *mm;
799         int ret;
800
801         mm = get_task_mm(dma->task);
802         if (!mm)
803                 return -ENODEV;
804
805         ret = vaddr_get_pfns(mm, vaddr, 1, dma->prot, pfn_base, pages);
806         if (ret != 1)
807                 goto out;
808
809         ret = 0;
810
811         if (do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
812                 ret = vfio_lock_acct(dma, 1, true);
813                 if (ret) {
814                         put_pfn(*pfn_base, dma->prot);
815                         if (ret == -ENOMEM)
816                                 pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
817                                         "(%ld) exceeded\n", __func__,
818                                         dma->task->comm, task_pid_nr(dma->task),
819                                         task_rlimit(dma->task, RLIMIT_MEMLOCK));
820                 }
821         }
822
823 out:
824         mmput(mm);
825         return ret;
826 }
827
828 static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
829                                     bool do_accounting)
830 {
831         int unlocked;
832         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
833
834         if (!vpfn)
835                 return 0;
836
837         unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
838
839         if (do_accounting)
840                 vfio_lock_acct(dma, -unlocked, true);
841
842         return unlocked;
843 }
844
845 static int vfio_iommu_type1_pin_pages(void *iommu_data,
846                                       struct iommu_group *iommu_group,
847                                       unsigned long *user_pfn,
848                                       int npage, int prot,
849                                       unsigned long *phys_pfn)
850 {
851         struct vfio_iommu *iommu = iommu_data;
852         struct vfio_iommu_group *group;
853         int i, j, ret;
854         unsigned long remote_vaddr;
855         struct vfio_dma *dma;
856         bool do_accounting;
857         dma_addr_t iova;
858
859         if (!iommu || !user_pfn || !phys_pfn)
860                 return -EINVAL;
861
862         /* Supported for v2 version only */
863         if (!iommu->v2)
864                 return -EACCES;
865
866         mutex_lock(&iommu->lock);
867
868         /*
869          * Wait for all necessary vaddr's to be valid so they can be used in
870          * the main loop without dropping the lock, to avoid racing vs unmap.
871          */
872 again:
873         if (iommu->vaddr_invalid_count) {
874                 for (i = 0; i < npage; i++) {
875                         iova = user_pfn[i] << PAGE_SHIFT;
876                         ret = vfio_find_dma_valid(iommu, iova, PAGE_SIZE, &dma);
877                         if (ret < 0)
878                                 goto pin_done;
879                         if (ret == WAITED)
880                                 goto again;
881                 }
882         }
883
884         /* Fail if notifier list is empty */
885         if (!iommu->notifier.head) {
886                 ret = -EINVAL;
887                 goto pin_done;
888         }
889
890         /*
891          * If iommu capable domain exist in the container then all pages are
892          * already pinned and accounted. Accounting should be done if there is no
893          * iommu capable domain in the container.
894          */
895         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
896
897         for (i = 0; i < npage; i++) {
898                 struct vfio_pfn *vpfn;
899
900                 iova = user_pfn[i] << PAGE_SHIFT;
901                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
902                 if (!dma) {
903                         ret = -EINVAL;
904                         goto pin_unwind;
905                 }
906
907                 if ((dma->prot & prot) != prot) {
908                         ret = -EPERM;
909                         goto pin_unwind;
910                 }
911
912                 vpfn = vfio_iova_get_vfio_pfn(dma, iova);
913                 if (vpfn) {
914                         phys_pfn[i] = vpfn->pfn;
915                         continue;
916                 }
917
918                 remote_vaddr = dma->vaddr + (iova - dma->iova);
919                 ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
920                                              do_accounting);
921                 if (ret)
922                         goto pin_unwind;
923
924                 ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
925                 if (ret) {
926                         if (put_pfn(phys_pfn[i], dma->prot) && do_accounting)
927                                 vfio_lock_acct(dma, -1, true);
928                         goto pin_unwind;
929                 }
930
931                 if (iommu->dirty_page_tracking) {
932                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
933
934                         /*
935                          * Bitmap populated with the smallest supported page
936                          * size
937                          */
938                         bitmap_set(dma->bitmap,
939                                    (iova - dma->iova) >> pgshift, 1);
940                 }
941         }
942         ret = i;
943
944         group = vfio_iommu_find_iommu_group(iommu, iommu_group);
945         if (!group->pinned_page_dirty_scope) {
946                 group->pinned_page_dirty_scope = true;
947                 iommu->num_non_pinned_groups--;
948         }
949
950         goto pin_done;
951
952 pin_unwind:
953         phys_pfn[i] = 0;
954         for (j = 0; j < i; j++) {
955                 dma_addr_t iova;
956
957                 iova = user_pfn[j] << PAGE_SHIFT;
958                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
959                 vfio_unpin_page_external(dma, iova, do_accounting);
960                 phys_pfn[j] = 0;
961         }
962 pin_done:
963         mutex_unlock(&iommu->lock);
964         return ret;
965 }
966
967 static int vfio_iommu_type1_unpin_pages(void *iommu_data,
968                                         unsigned long *user_pfn,
969                                         int npage)
970 {
971         struct vfio_iommu *iommu = iommu_data;
972         bool do_accounting;
973         int i;
974
975         if (!iommu || !user_pfn || npage <= 0)
976                 return -EINVAL;
977
978         /* Supported for v2 version only */
979         if (!iommu->v2)
980                 return -EACCES;
981
982         mutex_lock(&iommu->lock);
983
984         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
985         for (i = 0; i < npage; i++) {
986                 struct vfio_dma *dma;
987                 dma_addr_t iova;
988
989                 iova = user_pfn[i] << PAGE_SHIFT;
990                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
991                 if (!dma)
992                         break;
993
994                 vfio_unpin_page_external(dma, iova, do_accounting);
995         }
996
997         mutex_unlock(&iommu->lock);
998         return i > 0 ? i : -EINVAL;
999 }
1000
1001 static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
1002                             struct list_head *regions,
1003                             struct iommu_iotlb_gather *iotlb_gather)
1004 {
1005         long unlocked = 0;
1006         struct vfio_regions *entry, *next;
1007
1008         iommu_iotlb_sync(domain->domain, iotlb_gather);
1009
1010         list_for_each_entry_safe(entry, next, regions, list) {
1011                 unlocked += vfio_unpin_pages_remote(dma,
1012                                                     entry->iova,
1013                                                     entry->phys >> PAGE_SHIFT,
1014                                                     entry->len >> PAGE_SHIFT,
1015                                                     false);
1016                 list_del(&entry->list);
1017                 kfree(entry);
1018         }
1019
1020         cond_resched();
1021
1022         return unlocked;
1023 }
1024
1025 /*
1026  * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
1027  * Therefore, when using IOTLB flush sync interface, VFIO need to keep track
1028  * of these regions (currently using a list).
1029  *
1030  * This value specifies maximum number of regions for each IOTLB flush sync.
1031  */
1032 #define VFIO_IOMMU_TLB_SYNC_MAX         512
1033
1034 static size_t unmap_unpin_fast(struct vfio_domain *domain,
1035                                struct vfio_dma *dma, dma_addr_t *iova,
1036                                size_t len, phys_addr_t phys, long *unlocked,
1037                                struct list_head *unmapped_list,
1038                                int *unmapped_cnt,
1039                                struct iommu_iotlb_gather *iotlb_gather)
1040 {
1041         size_t unmapped = 0;
1042         struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
1043
1044         if (entry) {
1045                 unmapped = iommu_unmap_fast(domain->domain, *iova, len,
1046                                             iotlb_gather);
1047
1048                 if (!unmapped) {
1049                         kfree(entry);
1050                 } else {
1051                         entry->iova = *iova;
1052                         entry->phys = phys;
1053                         entry->len  = unmapped;
1054                         list_add_tail(&entry->list, unmapped_list);
1055
1056                         *iova += unmapped;
1057                         (*unmapped_cnt)++;
1058                 }
1059         }
1060
1061         /*
1062          * Sync if the number of fast-unmap regions hits the limit
1063          * or in case of errors.
1064          */
1065         if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
1066                 *unlocked += vfio_sync_unpin(dma, domain, unmapped_list,
1067                                              iotlb_gather);
1068                 *unmapped_cnt = 0;
1069         }
1070
1071         return unmapped;
1072 }
1073
1074 static size_t unmap_unpin_slow(struct vfio_domain *domain,
1075                                struct vfio_dma *dma, dma_addr_t *iova,
1076                                size_t len, phys_addr_t phys,
1077                                long *unlocked)
1078 {
1079         size_t unmapped = iommu_unmap(domain->domain, *iova, len);
1080
1081         if (unmapped) {
1082                 *unlocked += vfio_unpin_pages_remote(dma, *iova,
1083                                                      phys >> PAGE_SHIFT,
1084                                                      unmapped >> PAGE_SHIFT,
1085                                                      false);
1086                 *iova += unmapped;
1087                 cond_resched();
1088         }
1089         return unmapped;
1090 }
1091
1092 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
1093                              bool do_accounting)
1094 {
1095         dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
1096         struct vfio_domain *domain, *d;
1097         LIST_HEAD(unmapped_region_list);
1098         struct iommu_iotlb_gather iotlb_gather;
1099         int unmapped_region_cnt = 0;
1100         long unlocked = 0;
1101
1102         if (!dma->size)
1103                 return 0;
1104
1105         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1106                 return 0;
1107
1108         /*
1109          * We use the IOMMU to track the physical addresses, otherwise we'd
1110          * need a much more complicated tracking system.  Unfortunately that
1111          * means we need to use one of the iommu domains to figure out the
1112          * pfns to unpin.  The rest need to be unmapped in advance so we have
1113          * no iommu translations remaining when the pages are unpinned.
1114          */
1115         domain = d = list_first_entry(&iommu->domain_list,
1116                                       struct vfio_domain, next);
1117
1118         list_for_each_entry_continue(d, &iommu->domain_list, next) {
1119                 iommu_unmap(d->domain, dma->iova, dma->size);
1120                 cond_resched();
1121         }
1122
1123         iommu_iotlb_gather_init(&iotlb_gather);
1124         while (iova < end) {
1125                 size_t unmapped, len;
1126                 phys_addr_t phys, next;
1127
1128                 phys = iommu_iova_to_phys(domain->domain, iova);
1129                 if (WARN_ON(!phys)) {
1130                         iova += PAGE_SIZE;
1131                         continue;
1132                 }
1133
1134                 /*
1135                  * To optimize for fewer iommu_unmap() calls, each of which
1136                  * may require hardware cache flushing, try to find the
1137                  * largest contiguous physical memory chunk to unmap.
1138                  */
1139                 for (len = PAGE_SIZE;
1140                      !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
1141                         next = iommu_iova_to_phys(domain->domain, iova + len);
1142                         if (next != phys + len)
1143                                 break;
1144                 }
1145
1146                 /*
1147                  * First, try to use fast unmap/unpin. In case of failure,
1148                  * switch to slow unmap/unpin path.
1149                  */
1150                 unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
1151                                             &unlocked, &unmapped_region_list,
1152                                             &unmapped_region_cnt,
1153                                             &iotlb_gather);
1154                 if (!unmapped) {
1155                         unmapped = unmap_unpin_slow(domain, dma, &iova, len,
1156                                                     phys, &unlocked);
1157                         if (WARN_ON(!unmapped))
1158                                 break;
1159                 }
1160         }
1161
1162         dma->iommu_mapped = false;
1163
1164         if (unmapped_region_cnt) {
1165                 unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list,
1166                                             &iotlb_gather);
1167         }
1168
1169         if (do_accounting) {
1170                 vfio_lock_acct(dma, -unlocked, true);
1171                 return 0;
1172         }
1173         return unlocked;
1174 }
1175
1176 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
1177 {
1178         WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list));
1179         vfio_unmap_unpin(iommu, dma, true);
1180         vfio_unlink_dma(iommu, dma);
1181         put_task_struct(dma->task);
1182         vfio_dma_bitmap_free(dma);
1183         if (dma->vaddr_invalid) {
1184                 iommu->vaddr_invalid_count--;
1185                 wake_up_all(&iommu->vaddr_wait);
1186         }
1187         kfree(dma);
1188         iommu->dma_avail++;
1189 }
1190
1191 static void vfio_update_pgsize_bitmap(struct vfio_iommu *iommu)
1192 {
1193         struct vfio_domain *domain;
1194
1195         iommu->pgsize_bitmap = ULONG_MAX;
1196
1197         list_for_each_entry(domain, &iommu->domain_list, next)
1198                 iommu->pgsize_bitmap &= domain->domain->pgsize_bitmap;
1199
1200         /*
1201          * In case the IOMMU supports page sizes smaller than PAGE_SIZE
1202          * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
1203          * That way the user will be able to map/unmap buffers whose size/
1204          * start address is aligned with PAGE_SIZE. Pinning code uses that
1205          * granularity while iommu driver can use the sub-PAGE_SIZE size
1206          * to map the buffer.
1207          */
1208         if (iommu->pgsize_bitmap & ~PAGE_MASK) {
1209                 iommu->pgsize_bitmap &= PAGE_MASK;
1210                 iommu->pgsize_bitmap |= PAGE_SIZE;
1211         }
1212 }
1213
1214 static int update_user_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1215                               struct vfio_dma *dma, dma_addr_t base_iova,
1216                               size_t pgsize)
1217 {
1218         unsigned long pgshift = __ffs(pgsize);
1219         unsigned long nbits = dma->size >> pgshift;
1220         unsigned long bit_offset = (dma->iova - base_iova) >> pgshift;
1221         unsigned long copy_offset = bit_offset / BITS_PER_LONG;
1222         unsigned long shift = bit_offset % BITS_PER_LONG;
1223         unsigned long leftover;
1224
1225         /*
1226          * mark all pages dirty if any IOMMU capable device is not able
1227          * to report dirty pages and all pages are pinned and mapped.
1228          */
1229         if (iommu->num_non_pinned_groups && dma->iommu_mapped)
1230                 bitmap_set(dma->bitmap, 0, nbits);
1231
1232         if (shift) {
1233                 bitmap_shift_left(dma->bitmap, dma->bitmap, shift,
1234                                   nbits + shift);
1235
1236                 if (copy_from_user(&leftover,
1237                                    (void __user *)(bitmap + copy_offset),
1238                                    sizeof(leftover)))
1239                         return -EFAULT;
1240
1241                 bitmap_or(dma->bitmap, dma->bitmap, &leftover, shift);
1242         }
1243
1244         if (copy_to_user((void __user *)(bitmap + copy_offset), dma->bitmap,
1245                          DIRTY_BITMAP_BYTES(nbits + shift)))
1246                 return -EFAULT;
1247
1248         return 0;
1249 }
1250
1251 static int vfio_iova_dirty_bitmap(u64 __user *bitmap, struct vfio_iommu *iommu,
1252                                   dma_addr_t iova, size_t size, size_t pgsize)
1253 {
1254         struct vfio_dma *dma;
1255         struct rb_node *n;
1256         unsigned long pgshift = __ffs(pgsize);
1257         int ret;
1258
1259         /*
1260          * GET_BITMAP request must fully cover vfio_dma mappings.  Multiple
1261          * vfio_dma mappings may be clubbed by specifying large ranges, but
1262          * there must not be any previous mappings bisected by the range.
1263          * An error will be returned if these conditions are not met.
1264          */
1265         dma = vfio_find_dma(iommu, iova, 1);
1266         if (dma && dma->iova != iova)
1267                 return -EINVAL;
1268
1269         dma = vfio_find_dma(iommu, iova + size - 1, 0);
1270         if (dma && dma->iova + dma->size != iova + size)
1271                 return -EINVAL;
1272
1273         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1274                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1275
1276                 if (dma->iova < iova)
1277                         continue;
1278
1279                 if (dma->iova > iova + size - 1)
1280                         break;
1281
1282                 ret = update_user_bitmap(bitmap, iommu, dma, iova, pgsize);
1283                 if (ret)
1284                         return ret;
1285
1286                 /*
1287                  * Re-populate bitmap to include all pinned pages which are
1288                  * considered as dirty but exclude pages which are unpinned and
1289                  * pages which are marked dirty by vfio_dma_rw()
1290                  */
1291                 bitmap_clear(dma->bitmap, 0, dma->size >> pgshift);
1292                 vfio_dma_populate_bitmap(dma, pgsize);
1293         }
1294         return 0;
1295 }
1296
1297 static int verify_bitmap_size(uint64_t npages, uint64_t bitmap_size)
1298 {
1299         if (!npages || !bitmap_size || (bitmap_size > DIRTY_BITMAP_SIZE_MAX) ||
1300             (bitmap_size < DIRTY_BITMAP_BYTES(npages)))
1301                 return -EINVAL;
1302
1303         return 0;
1304 }
1305
1306 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
1307                              struct vfio_iommu_type1_dma_unmap *unmap,
1308                              struct vfio_bitmap *bitmap)
1309 {
1310         struct vfio_dma *dma, *dma_last = NULL;
1311         size_t unmapped = 0, pgsize;
1312         int ret = -EINVAL, retries = 0;
1313         unsigned long pgshift;
1314         dma_addr_t iova = unmap->iova;
1315         u64 size = unmap->size;
1316         bool unmap_all = unmap->flags & VFIO_DMA_UNMAP_FLAG_ALL;
1317         bool invalidate_vaddr = unmap->flags & VFIO_DMA_UNMAP_FLAG_VADDR;
1318         struct rb_node *n, *first_n;
1319
1320         mutex_lock(&iommu->lock);
1321
1322         pgshift = __ffs(iommu->pgsize_bitmap);
1323         pgsize = (size_t)1 << pgshift;
1324
1325         if (iova & (pgsize - 1))
1326                 goto unlock;
1327
1328         if (unmap_all) {
1329                 if (iova || size)
1330                         goto unlock;
1331                 size = U64_MAX;
1332         } else if (!size || size & (pgsize - 1) ||
1333                    iova + size - 1 < iova || size > SIZE_MAX) {
1334                 goto unlock;
1335         }
1336
1337         /* When dirty tracking is enabled, allow only min supported pgsize */
1338         if ((unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
1339             (!iommu->dirty_page_tracking || (bitmap->pgsize != pgsize))) {
1340                 goto unlock;
1341         }
1342
1343         WARN_ON((pgsize - 1) & PAGE_MASK);
1344 again:
1345         /*
1346          * vfio-iommu-type1 (v1) - User mappings were coalesced together to
1347          * avoid tracking individual mappings.  This means that the granularity
1348          * of the original mapping was lost and the user was allowed to attempt
1349          * to unmap any range.  Depending on the contiguousness of physical
1350          * memory and page sizes supported by the IOMMU, arbitrary unmaps may
1351          * or may not have worked.  We only guaranteed unmap granularity
1352          * matching the original mapping; even though it was untracked here,
1353          * the original mappings are reflected in IOMMU mappings.  This
1354          * resulted in a couple unusual behaviors.  First, if a range is not
1355          * able to be unmapped, ex. a set of 4k pages that was mapped as a
1356          * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
1357          * a zero sized unmap.  Also, if an unmap request overlaps the first
1358          * address of a hugepage, the IOMMU will unmap the entire hugepage.
1359          * This also returns success and the returned unmap size reflects the
1360          * actual size unmapped.
1361          *
1362          * We attempt to maintain compatibility with this "v1" interface, but
1363          * we take control out of the hands of the IOMMU.  Therefore, an unmap
1364          * request offset from the beginning of the original mapping will
1365          * return success with zero sized unmap.  And an unmap request covering
1366          * the first iova of mapping will unmap the entire range.
1367          *
1368          * The v2 version of this interface intends to be more deterministic.
1369          * Unmap requests must fully cover previous mappings.  Multiple
1370          * mappings may still be unmaped by specifying large ranges, but there
1371          * must not be any previous mappings bisected by the range.  An error
1372          * will be returned if these conditions are not met.  The v2 interface
1373          * will only return success and a size of zero if there were no
1374          * mappings within the range.
1375          */
1376         if (iommu->v2 && !unmap_all) {
1377                 dma = vfio_find_dma(iommu, iova, 1);
1378                 if (dma && dma->iova != iova)
1379                         goto unlock;
1380
1381                 dma = vfio_find_dma(iommu, iova + size - 1, 0);
1382                 if (dma && dma->iova + dma->size != iova + size)
1383                         goto unlock;
1384         }
1385
1386         ret = 0;
1387         n = first_n = vfio_find_dma_first_node(iommu, iova, size);
1388
1389         while (n) {
1390                 dma = rb_entry(n, struct vfio_dma, node);
1391                 if (dma->iova >= iova + size)
1392                         break;
1393
1394                 if (!iommu->v2 && iova > dma->iova)
1395                         break;
1396                 /*
1397                  * Task with same address space who mapped this iova range is
1398                  * allowed to unmap the iova range.
1399                  */
1400                 if (dma->task->mm != current->mm)
1401                         break;
1402
1403                 if (invalidate_vaddr) {
1404                         if (dma->vaddr_invalid) {
1405                                 struct rb_node *last_n = n;
1406
1407                                 for (n = first_n; n != last_n; n = rb_next(n)) {
1408                                         dma = rb_entry(n,
1409                                                        struct vfio_dma, node);
1410                                         dma->vaddr_invalid = false;
1411                                         iommu->vaddr_invalid_count--;
1412                                 }
1413                                 ret = -EINVAL;
1414                                 unmapped = 0;
1415                                 break;
1416                         }
1417                         dma->vaddr_invalid = true;
1418                         iommu->vaddr_invalid_count++;
1419                         unmapped += dma->size;
1420                         n = rb_next(n);
1421                         continue;
1422                 }
1423
1424                 if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
1425                         struct vfio_iommu_type1_dma_unmap nb_unmap;
1426
1427                         if (dma_last == dma) {
1428                                 BUG_ON(++retries > 10);
1429                         } else {
1430                                 dma_last = dma;
1431                                 retries = 0;
1432                         }
1433
1434                         nb_unmap.iova = dma->iova;
1435                         nb_unmap.size = dma->size;
1436
1437                         /*
1438                          * Notify anyone (mdev vendor drivers) to invalidate and
1439                          * unmap iovas within the range we're about to unmap.
1440                          * Vendor drivers MUST unpin pages in response to an
1441                          * invalidation.
1442                          */
1443                         mutex_unlock(&iommu->lock);
1444                         blocking_notifier_call_chain(&iommu->notifier,
1445                                                     VFIO_IOMMU_NOTIFY_DMA_UNMAP,
1446                                                     &nb_unmap);
1447                         mutex_lock(&iommu->lock);
1448                         goto again;
1449                 }
1450
1451                 if (unmap->flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
1452                         ret = update_user_bitmap(bitmap->data, iommu, dma,
1453                                                  iova, pgsize);
1454                         if (ret)
1455                                 break;
1456                 }
1457
1458                 unmapped += dma->size;
1459                 n = rb_next(n);
1460                 vfio_remove_dma(iommu, dma);
1461         }
1462
1463 unlock:
1464         mutex_unlock(&iommu->lock);
1465
1466         /* Report how much was unmapped */
1467         unmap->size = unmapped;
1468
1469         return ret;
1470 }
1471
1472 static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
1473                           unsigned long pfn, long npage, int prot)
1474 {
1475         struct vfio_domain *d;
1476         int ret;
1477
1478         list_for_each_entry(d, &iommu->domain_list, next) {
1479                 ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
1480                                 npage << PAGE_SHIFT, prot | d->prot);
1481                 if (ret)
1482                         goto unwind;
1483
1484                 cond_resched();
1485         }
1486
1487         return 0;
1488
1489 unwind:
1490         list_for_each_entry_continue_reverse(d, &iommu->domain_list, next) {
1491                 iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
1492                 cond_resched();
1493         }
1494
1495         return ret;
1496 }
1497
1498 static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1499                             size_t map_size)
1500 {
1501         dma_addr_t iova = dma->iova;
1502         unsigned long vaddr = dma->vaddr;
1503         struct vfio_batch batch;
1504         size_t size = map_size;
1505         long npage;
1506         unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1507         int ret = 0;
1508
1509         vfio_batch_init(&batch);
1510
1511         while (size) {
1512                 /* Pin a contiguous chunk of memory */
1513                 npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1514                                               size >> PAGE_SHIFT, &pfn, limit,
1515                                               &batch);
1516                 if (npage <= 0) {
1517                         WARN_ON(!npage);
1518                         ret = (int)npage;
1519                         break;
1520                 }
1521
1522                 /* Map it! */
1523                 ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1524                                      dma->prot);
1525                 if (ret) {
1526                         vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1527                                                 npage, true);
1528                         vfio_batch_unpin(&batch, dma);
1529                         break;
1530                 }
1531
1532                 size -= npage << PAGE_SHIFT;
1533                 dma->size += npage << PAGE_SHIFT;
1534         }
1535
1536         vfio_batch_fini(&batch);
1537         dma->iommu_mapped = true;
1538
1539         if (ret)
1540                 vfio_remove_dma(iommu, dma);
1541
1542         return ret;
1543 }
1544
1545 /*
1546  * Check dma map request is within a valid iova range
1547  */
1548 static bool vfio_iommu_iova_dma_valid(struct vfio_iommu *iommu,
1549                                       dma_addr_t start, dma_addr_t end)
1550 {
1551         struct list_head *iova = &iommu->iova_list;
1552         struct vfio_iova *node;
1553
1554         list_for_each_entry(node, iova, list) {
1555                 if (start >= node->start && end <= node->end)
1556                         return true;
1557         }
1558
1559         /*
1560          * Check for list_empty() as well since a container with
1561          * a single mdev device will have an empty list.
1562          */
1563         return list_empty(iova);
1564 }
1565
1566 static int vfio_dma_do_map(struct vfio_iommu *iommu,
1567                            struct vfio_iommu_type1_dma_map *map)
1568 {
1569         bool set_vaddr = map->flags & VFIO_DMA_MAP_FLAG_VADDR;
1570         dma_addr_t iova = map->iova;
1571         unsigned long vaddr = map->vaddr;
1572         size_t size = map->size;
1573         int ret = 0, prot = 0;
1574         size_t pgsize;
1575         struct vfio_dma *dma;
1576
1577         /* Verify that none of our __u64 fields overflow */
1578         if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1579                 return -EINVAL;
1580
1581         /* READ/WRITE from device perspective */
1582         if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1583                 prot |= IOMMU_WRITE;
1584         if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1585                 prot |= IOMMU_READ;
1586
1587         if ((prot && set_vaddr) || (!prot && !set_vaddr))
1588                 return -EINVAL;
1589
1590         mutex_lock(&iommu->lock);
1591
1592         pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
1593
1594         WARN_ON((pgsize - 1) & PAGE_MASK);
1595
1596         if (!size || (size | iova | vaddr) & (pgsize - 1)) {
1597                 ret = -EINVAL;
1598                 goto out_unlock;
1599         }
1600
1601         /* Don't allow IOVA or virtual address wrap */
1602         if (iova + size - 1 < iova || vaddr + size - 1 < vaddr) {
1603                 ret = -EINVAL;
1604                 goto out_unlock;
1605         }
1606
1607         dma = vfio_find_dma(iommu, iova, size);
1608         if (set_vaddr) {
1609                 if (!dma) {
1610                         ret = -ENOENT;
1611                 } else if (!dma->vaddr_invalid || dma->iova != iova ||
1612                            dma->size != size) {
1613                         ret = -EINVAL;
1614                 } else {
1615                         dma->vaddr = vaddr;
1616                         dma->vaddr_invalid = false;
1617                         iommu->vaddr_invalid_count--;
1618                         wake_up_all(&iommu->vaddr_wait);
1619                 }
1620                 goto out_unlock;
1621         } else if (dma) {
1622                 ret = -EEXIST;
1623                 goto out_unlock;
1624         }
1625
1626         if (!iommu->dma_avail) {
1627                 ret = -ENOSPC;
1628                 goto out_unlock;
1629         }
1630
1631         if (!vfio_iommu_iova_dma_valid(iommu, iova, iova + size - 1)) {
1632                 ret = -EINVAL;
1633                 goto out_unlock;
1634         }
1635
1636         dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1637         if (!dma) {
1638                 ret = -ENOMEM;
1639                 goto out_unlock;
1640         }
1641
1642         iommu->dma_avail--;
1643         dma->iova = iova;
1644         dma->vaddr = vaddr;
1645         dma->prot = prot;
1646
1647         /*
1648          * We need to be able to both add to a task's locked memory and test
1649          * against the locked memory limit and we need to be able to do both
1650          * outside of this call path as pinning can be asynchronous via the
1651          * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1652          * task_struct and VM locked pages requires an mm_struct, however
1653          * holding an indefinite mm reference is not recommended, therefore we
1654          * only hold a reference to a task.  We could hold a reference to
1655          * current, however QEMU uses this call path through vCPU threads,
1656          * which can be killed resulting in a NULL mm and failure in the unmap
1657          * path when called via a different thread.  Avoid this problem by
1658          * using the group_leader as threads within the same group require
1659          * both CLONE_THREAD and CLONE_VM and will therefore use the same
1660          * mm_struct.
1661          *
1662          * Previously we also used the task for testing CAP_IPC_LOCK at the
1663          * time of pinning and accounting, however has_capability() makes use
1664          * of real_cred, a copy-on-write field, so we can't guarantee that it
1665          * matches group_leader, or in fact that it might not change by the
1666          * time it's evaluated.  If a process were to call MAP_DMA with
1667          * CAP_IPC_LOCK but later drop it, it doesn't make sense that they
1668          * possibly see different results for an iommu_mapped vfio_dma vs
1669          * externally mapped.  Therefore track CAP_IPC_LOCK in vfio_dma at the
1670          * time of calling MAP_DMA.
1671          */
1672         get_task_struct(current->group_leader);
1673         dma->task = current->group_leader;
1674         dma->lock_cap = capable(CAP_IPC_LOCK);
1675
1676         dma->pfn_list = RB_ROOT;
1677
1678         /* Insert zero-sized and grow as we map chunks of it */
1679         vfio_link_dma(iommu, dma);
1680
1681         /* Don't pin and map if container doesn't contain IOMMU capable domain*/
1682         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1683                 dma->size = size;
1684         else
1685                 ret = vfio_pin_map_dma(iommu, dma, size);
1686
1687         if (!ret && iommu->dirty_page_tracking) {
1688                 ret = vfio_dma_bitmap_alloc(dma, pgsize);
1689                 if (ret)
1690                         vfio_remove_dma(iommu, dma);
1691         }
1692
1693 out_unlock:
1694         mutex_unlock(&iommu->lock);
1695         return ret;
1696 }
1697
1698 static int vfio_bus_type(struct device *dev, void *data)
1699 {
1700         struct bus_type **bus = data;
1701
1702         if (*bus && *bus != dev->bus)
1703                 return -EINVAL;
1704
1705         *bus = dev->bus;
1706
1707         return 0;
1708 }
1709
1710 static int vfio_iommu_replay(struct vfio_iommu *iommu,
1711                              struct vfio_domain *domain)
1712 {
1713         struct vfio_batch batch;
1714         struct vfio_domain *d = NULL;
1715         struct rb_node *n;
1716         unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1717         int ret;
1718
1719         ret = vfio_wait_all_valid(iommu);
1720         if (ret < 0)
1721                 return ret;
1722
1723         /* Arbitrarily pick the first domain in the list for lookups */
1724         if (!list_empty(&iommu->domain_list))
1725                 d = list_first_entry(&iommu->domain_list,
1726                                      struct vfio_domain, next);
1727
1728         vfio_batch_init(&batch);
1729
1730         n = rb_first(&iommu->dma_list);
1731
1732         for (; n; n = rb_next(n)) {
1733                 struct vfio_dma *dma;
1734                 dma_addr_t iova;
1735
1736                 dma = rb_entry(n, struct vfio_dma, node);
1737                 iova = dma->iova;
1738
1739                 while (iova < dma->iova + dma->size) {
1740                         phys_addr_t phys;
1741                         size_t size;
1742
1743                         if (dma->iommu_mapped) {
1744                                 phys_addr_t p;
1745                                 dma_addr_t i;
1746
1747                                 if (WARN_ON(!d)) { /* mapped w/o a domain?! */
1748                                         ret = -EINVAL;
1749                                         goto unwind;
1750                                 }
1751
1752                                 phys = iommu_iova_to_phys(d->domain, iova);
1753
1754                                 if (WARN_ON(!phys)) {
1755                                         iova += PAGE_SIZE;
1756                                         continue;
1757                                 }
1758
1759                                 size = PAGE_SIZE;
1760                                 p = phys + size;
1761                                 i = iova + size;
1762                                 while (i < dma->iova + dma->size &&
1763                                        p == iommu_iova_to_phys(d->domain, i)) {
1764                                         size += PAGE_SIZE;
1765                                         p += PAGE_SIZE;
1766                                         i += PAGE_SIZE;
1767                                 }
1768                         } else {
1769                                 unsigned long pfn;
1770                                 unsigned long vaddr = dma->vaddr +
1771                                                      (iova - dma->iova);
1772                                 size_t n = dma->iova + dma->size - iova;
1773                                 long npage;
1774
1775                                 npage = vfio_pin_pages_remote(dma, vaddr,
1776                                                               n >> PAGE_SHIFT,
1777                                                               &pfn, limit,
1778                                                               &batch);
1779                                 if (npage <= 0) {
1780                                         WARN_ON(!npage);
1781                                         ret = (int)npage;
1782                                         goto unwind;
1783                                 }
1784
1785                                 phys = pfn << PAGE_SHIFT;
1786                                 size = npage << PAGE_SHIFT;
1787                         }
1788
1789                         ret = iommu_map(domain->domain, iova, phys,
1790                                         size, dma->prot | domain->prot);
1791                         if (ret) {
1792                                 if (!dma->iommu_mapped) {
1793                                         vfio_unpin_pages_remote(dma, iova,
1794                                                         phys >> PAGE_SHIFT,
1795                                                         size >> PAGE_SHIFT,
1796                                                         true);
1797                                         vfio_batch_unpin(&batch, dma);
1798                                 }
1799                                 goto unwind;
1800                         }
1801
1802                         iova += size;
1803                 }
1804         }
1805
1806         /* All dmas are now mapped, defer to second tree walk for unwind */
1807         for (n = rb_first(&iommu->dma_list); n; n = rb_next(n)) {
1808                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1809
1810                 dma->iommu_mapped = true;
1811         }
1812
1813         vfio_batch_fini(&batch);
1814         return 0;
1815
1816 unwind:
1817         for (; n; n = rb_prev(n)) {
1818                 struct vfio_dma *dma = rb_entry(n, struct vfio_dma, node);
1819                 dma_addr_t iova;
1820
1821                 if (dma->iommu_mapped) {
1822                         iommu_unmap(domain->domain, dma->iova, dma->size);
1823                         continue;
1824                 }
1825
1826                 iova = dma->iova;
1827                 while (iova < dma->iova + dma->size) {
1828                         phys_addr_t phys, p;
1829                         size_t size;
1830                         dma_addr_t i;
1831
1832                         phys = iommu_iova_to_phys(domain->domain, iova);
1833                         if (!phys) {
1834                                 iova += PAGE_SIZE;
1835                                 continue;
1836                         }
1837
1838                         size = PAGE_SIZE;
1839                         p = phys + size;
1840                         i = iova + size;
1841                         while (i < dma->iova + dma->size &&
1842                                p == iommu_iova_to_phys(domain->domain, i)) {
1843                                 size += PAGE_SIZE;
1844                                 p += PAGE_SIZE;
1845                                 i += PAGE_SIZE;
1846                         }
1847
1848                         iommu_unmap(domain->domain, iova, size);
1849                         vfio_unpin_pages_remote(dma, iova, phys >> PAGE_SHIFT,
1850                                                 size >> PAGE_SHIFT, true);
1851                 }
1852         }
1853
1854         vfio_batch_fini(&batch);
1855         return ret;
1856 }
1857
1858 /*
1859  * We change our unmap behavior slightly depending on whether the IOMMU
1860  * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1861  * for practically any contiguous power-of-two mapping we give it.  This means
1862  * we don't need to look for contiguous chunks ourselves to make unmapping
1863  * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1864  * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1865  * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1866  * hugetlbfs is in use.
1867  */
1868 static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1869 {
1870         struct page *pages;
1871         int ret, order = get_order(PAGE_SIZE * 2);
1872
1873         pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1874         if (!pages)
1875                 return;
1876
1877         ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1878                         IOMMU_READ | IOMMU_WRITE | domain->prot);
1879         if (!ret) {
1880                 size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1881
1882                 if (unmapped == PAGE_SIZE)
1883                         iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1884                 else
1885                         domain->fgsp = true;
1886         }
1887
1888         __free_pages(pages, order);
1889 }
1890
1891 static struct vfio_iommu_group *find_iommu_group(struct vfio_domain *domain,
1892                                                  struct iommu_group *iommu_group)
1893 {
1894         struct vfio_iommu_group *g;
1895
1896         list_for_each_entry(g, &domain->group_list, next) {
1897                 if (g->iommu_group == iommu_group)
1898                         return g;
1899         }
1900
1901         return NULL;
1902 }
1903
1904 static struct vfio_iommu_group*
1905 vfio_iommu_find_iommu_group(struct vfio_iommu *iommu,
1906                             struct iommu_group *iommu_group)
1907 {
1908         struct vfio_domain *domain;
1909         struct vfio_iommu_group *group = NULL;
1910
1911         list_for_each_entry(domain, &iommu->domain_list, next) {
1912                 group = find_iommu_group(domain, iommu_group);
1913                 if (group)
1914                         return group;
1915         }
1916
1917         if (iommu->external_domain)
1918                 group = find_iommu_group(iommu->external_domain, iommu_group);
1919
1920         return group;
1921 }
1922
1923 static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
1924                                   phys_addr_t *base)
1925 {
1926         struct iommu_resv_region *region;
1927         bool ret = false;
1928
1929         list_for_each_entry(region, group_resv_regions, list) {
1930                 /*
1931                  * The presence of any 'real' MSI regions should take
1932                  * precedence over the software-managed one if the
1933                  * IOMMU driver happens to advertise both types.
1934                  */
1935                 if (region->type == IOMMU_RESV_MSI) {
1936                         ret = false;
1937                         break;
1938                 }
1939
1940                 if (region->type == IOMMU_RESV_SW_MSI) {
1941                         *base = region->start;
1942                         ret = true;
1943                 }
1944         }
1945
1946         return ret;
1947 }
1948
1949 static int vfio_mdev_attach_domain(struct device *dev, void *data)
1950 {
1951         struct mdev_device *mdev = to_mdev_device(dev);
1952         struct iommu_domain *domain = data;
1953         struct device *iommu_device;
1954
1955         iommu_device = mdev_get_iommu_device(mdev);
1956         if (iommu_device) {
1957                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1958                         return iommu_aux_attach_device(domain, iommu_device);
1959                 else
1960                         return iommu_attach_device(domain, iommu_device);
1961         }
1962
1963         return -EINVAL;
1964 }
1965
1966 static int vfio_mdev_detach_domain(struct device *dev, void *data)
1967 {
1968         struct mdev_device *mdev = to_mdev_device(dev);
1969         struct iommu_domain *domain = data;
1970         struct device *iommu_device;
1971
1972         iommu_device = mdev_get_iommu_device(mdev);
1973         if (iommu_device) {
1974                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1975                         iommu_aux_detach_device(domain, iommu_device);
1976                 else
1977                         iommu_detach_device(domain, iommu_device);
1978         }
1979
1980         return 0;
1981 }
1982
1983 static int vfio_iommu_attach_group(struct vfio_domain *domain,
1984                                    struct vfio_iommu_group *group)
1985 {
1986         if (group->mdev_group)
1987                 return iommu_group_for_each_dev(group->iommu_group,
1988                                                 domain->domain,
1989                                                 vfio_mdev_attach_domain);
1990         else
1991                 return iommu_attach_group(domain->domain, group->iommu_group);
1992 }
1993
1994 static void vfio_iommu_detach_group(struct vfio_domain *domain,
1995                                     struct vfio_iommu_group *group)
1996 {
1997         if (group->mdev_group)
1998                 iommu_group_for_each_dev(group->iommu_group, domain->domain,
1999                                          vfio_mdev_detach_domain);
2000         else
2001                 iommu_detach_group(domain->domain, group->iommu_group);
2002 }
2003
2004 static bool vfio_bus_is_mdev(struct bus_type *bus)
2005 {
2006         struct bus_type *mdev_bus;
2007         bool ret = false;
2008
2009         mdev_bus = symbol_get(mdev_bus_type);
2010         if (mdev_bus) {
2011                 ret = (bus == mdev_bus);
2012                 symbol_put(mdev_bus_type);
2013         }
2014
2015         return ret;
2016 }
2017
2018 static int vfio_mdev_iommu_device(struct device *dev, void *data)
2019 {
2020         struct mdev_device *mdev = to_mdev_device(dev);
2021         struct device **old = data, *new;
2022
2023         new = mdev_get_iommu_device(mdev);
2024         if (!new || (*old && *old != new))
2025                 return -EINVAL;
2026
2027         *old = new;
2028
2029         return 0;
2030 }
2031
2032 /*
2033  * This is a helper function to insert an address range to iova list.
2034  * The list is initially created with a single entry corresponding to
2035  * the IOMMU domain geometry to which the device group is attached.
2036  * The list aperture gets modified when a new domain is added to the
2037  * container if the new aperture doesn't conflict with the current one
2038  * or with any existing dma mappings. The list is also modified to
2039  * exclude any reserved regions associated with the device group.
2040  */
2041 static int vfio_iommu_iova_insert(struct list_head *head,
2042                                   dma_addr_t start, dma_addr_t end)
2043 {
2044         struct vfio_iova *region;
2045
2046         region = kmalloc(sizeof(*region), GFP_KERNEL);
2047         if (!region)
2048                 return -ENOMEM;
2049
2050         INIT_LIST_HEAD(&region->list);
2051         region->start = start;
2052         region->end = end;
2053
2054         list_add_tail(&region->list, head);
2055         return 0;
2056 }
2057
2058 /*
2059  * Check the new iommu aperture conflicts with existing aper or with any
2060  * existing dma mappings.
2061  */
2062 static bool vfio_iommu_aper_conflict(struct vfio_iommu *iommu,
2063                                      dma_addr_t start, dma_addr_t end)
2064 {
2065         struct vfio_iova *first, *last;
2066         struct list_head *iova = &iommu->iova_list;
2067
2068         if (list_empty(iova))
2069                 return false;
2070
2071         /* Disjoint sets, return conflict */
2072         first = list_first_entry(iova, struct vfio_iova, list);
2073         last = list_last_entry(iova, struct vfio_iova, list);
2074         if (start > last->end || end < first->start)
2075                 return true;
2076
2077         /* Check for any existing dma mappings below the new start */
2078         if (start > first->start) {
2079                 if (vfio_find_dma(iommu, first->start, start - first->start))
2080                         return true;
2081         }
2082
2083         /* Check for any existing dma mappings beyond the new end */
2084         if (end < last->end) {
2085                 if (vfio_find_dma(iommu, end + 1, last->end - end))
2086                         return true;
2087         }
2088
2089         return false;
2090 }
2091
2092 /*
2093  * Resize iommu iova aperture window. This is called only if the new
2094  * aperture has no conflict with existing aperture and dma mappings.
2095  */
2096 static int vfio_iommu_aper_resize(struct list_head *iova,
2097                                   dma_addr_t start, dma_addr_t end)
2098 {
2099         struct vfio_iova *node, *next;
2100
2101         if (list_empty(iova))
2102                 return vfio_iommu_iova_insert(iova, start, end);
2103
2104         /* Adjust iova list start */
2105         list_for_each_entry_safe(node, next, iova, list) {
2106                 if (start < node->start)
2107                         break;
2108                 if (start >= node->start && start < node->end) {
2109                         node->start = start;
2110                         break;
2111                 }
2112                 /* Delete nodes before new start */
2113                 list_del(&node->list);
2114                 kfree(node);
2115         }
2116
2117         /* Adjust iova list end */
2118         list_for_each_entry_safe(node, next, iova, list) {
2119                 if (end > node->end)
2120                         continue;
2121                 if (end > node->start && end <= node->end) {
2122                         node->end = end;
2123                         continue;
2124                 }
2125                 /* Delete nodes after new end */
2126                 list_del(&node->list);
2127                 kfree(node);
2128         }
2129
2130         return 0;
2131 }
2132
2133 /*
2134  * Check reserved region conflicts with existing dma mappings
2135  */
2136 static bool vfio_iommu_resv_conflict(struct vfio_iommu *iommu,
2137                                      struct list_head *resv_regions)
2138 {
2139         struct iommu_resv_region *region;
2140
2141         /* Check for conflict with existing dma mappings */
2142         list_for_each_entry(region, resv_regions, list) {
2143                 if (region->type == IOMMU_RESV_DIRECT_RELAXABLE)
2144                         continue;
2145
2146                 if (vfio_find_dma(iommu, region->start, region->length))
2147                         return true;
2148         }
2149
2150         return false;
2151 }
2152
2153 /*
2154  * Check iova region overlap with  reserved regions and
2155  * exclude them from the iommu iova range
2156  */
2157 static int vfio_iommu_resv_exclude(struct list_head *iova,
2158                                    struct list_head *resv_regions)
2159 {
2160         struct iommu_resv_region *resv;
2161         struct vfio_iova *n, *next;
2162
2163         list_for_each_entry(resv, resv_regions, list) {
2164                 phys_addr_t start, end;
2165
2166                 if (resv->type == IOMMU_RESV_DIRECT_RELAXABLE)
2167                         continue;
2168
2169                 start = resv->start;
2170                 end = resv->start + resv->length - 1;
2171
2172                 list_for_each_entry_safe(n, next, iova, list) {
2173                         int ret = 0;
2174
2175                         /* No overlap */
2176                         if (start > n->end || end < n->start)
2177                                 continue;
2178                         /*
2179                          * Insert a new node if current node overlaps with the
2180                          * reserve region to exclude that from valid iova range.
2181                          * Note that, new node is inserted before the current
2182                          * node and finally the current node is deleted keeping
2183                          * the list updated and sorted.
2184                          */
2185                         if (start > n->start)
2186                                 ret = vfio_iommu_iova_insert(&n->list, n->start,
2187                                                              start - 1);
2188                         if (!ret && end < n->end)
2189                                 ret = vfio_iommu_iova_insert(&n->list, end + 1,
2190                                                              n->end);
2191                         if (ret)
2192                                 return ret;
2193
2194                         list_del(&n->list);
2195                         kfree(n);
2196                 }
2197         }
2198
2199         if (list_empty(iova))
2200                 return -EINVAL;
2201
2202         return 0;
2203 }
2204
2205 static void vfio_iommu_resv_free(struct list_head *resv_regions)
2206 {
2207         struct iommu_resv_region *n, *next;
2208
2209         list_for_each_entry_safe(n, next, resv_regions, list) {
2210                 list_del(&n->list);
2211                 kfree(n);
2212         }
2213 }
2214
2215 static void vfio_iommu_iova_free(struct list_head *iova)
2216 {
2217         struct vfio_iova *n, *next;
2218
2219         list_for_each_entry_safe(n, next, iova, list) {
2220                 list_del(&n->list);
2221                 kfree(n);
2222         }
2223 }
2224
2225 static int vfio_iommu_iova_get_copy(struct vfio_iommu *iommu,
2226                                     struct list_head *iova_copy)
2227 {
2228         struct list_head *iova = &iommu->iova_list;
2229         struct vfio_iova *n;
2230         int ret;
2231
2232         list_for_each_entry(n, iova, list) {
2233                 ret = vfio_iommu_iova_insert(iova_copy, n->start, n->end);
2234                 if (ret)
2235                         goto out_free;
2236         }
2237
2238         return 0;
2239
2240 out_free:
2241         vfio_iommu_iova_free(iova_copy);
2242         return ret;
2243 }
2244
2245 static void vfio_iommu_iova_insert_copy(struct vfio_iommu *iommu,
2246                                         struct list_head *iova_copy)
2247 {
2248         struct list_head *iova = &iommu->iova_list;
2249
2250         vfio_iommu_iova_free(iova);
2251
2252         list_splice_tail(iova_copy, iova);
2253 }
2254
2255 static int vfio_iommu_type1_attach_group(void *iommu_data,
2256                                          struct iommu_group *iommu_group)
2257 {
2258         struct vfio_iommu *iommu = iommu_data;
2259         struct vfio_iommu_group *group;
2260         struct vfio_domain *domain, *d;
2261         struct bus_type *bus = NULL;
2262         int ret;
2263         bool resv_msi, msi_remap;
2264         phys_addr_t resv_msi_base = 0;
2265         struct iommu_domain_geometry *geo;
2266         LIST_HEAD(iova_copy);
2267         LIST_HEAD(group_resv_regions);
2268
2269         mutex_lock(&iommu->lock);
2270
2271         /* Check for duplicates */
2272         if (vfio_iommu_find_iommu_group(iommu, iommu_group)) {
2273                 mutex_unlock(&iommu->lock);
2274                 return -EINVAL;
2275         }
2276
2277         group = kzalloc(sizeof(*group), GFP_KERNEL);
2278         domain = kzalloc(sizeof(*domain), GFP_KERNEL);
2279         if (!group || !domain) {
2280                 ret = -ENOMEM;
2281                 goto out_free;
2282         }
2283
2284         group->iommu_group = iommu_group;
2285
2286         /* Determine bus_type in order to allocate a domain */
2287         ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
2288         if (ret)
2289                 goto out_free;
2290
2291         if (vfio_bus_is_mdev(bus)) {
2292                 struct device *iommu_device = NULL;
2293
2294                 group->mdev_group = true;
2295
2296                 /* Determine the isolation type */
2297                 ret = iommu_group_for_each_dev(iommu_group, &iommu_device,
2298                                                vfio_mdev_iommu_device);
2299                 if (ret || !iommu_device) {
2300                         if (!iommu->external_domain) {
2301                                 INIT_LIST_HEAD(&domain->group_list);
2302                                 iommu->external_domain = domain;
2303                                 vfio_update_pgsize_bitmap(iommu);
2304                         } else {
2305                                 kfree(domain);
2306                         }
2307
2308                         list_add(&group->next,
2309                                  &iommu->external_domain->group_list);
2310                         /*
2311                          * Non-iommu backed group cannot dirty memory directly,
2312                          * it can only use interfaces that provide dirty
2313                          * tracking.
2314                          * The iommu scope can only be promoted with the
2315                          * addition of a dirty tracking group.
2316                          */
2317                         group->pinned_page_dirty_scope = true;
2318                         mutex_unlock(&iommu->lock);
2319
2320                         return 0;
2321                 }
2322
2323                 bus = iommu_device->bus;
2324         }
2325
2326         domain->domain = iommu_domain_alloc(bus);
2327         if (!domain->domain) {
2328                 ret = -EIO;
2329                 goto out_free;
2330         }
2331
2332         if (iommu->nesting) {
2333                 ret = iommu_enable_nesting(domain->domain);
2334                 if (ret)
2335                         goto out_domain;
2336         }
2337
2338         ret = vfio_iommu_attach_group(domain, group);
2339         if (ret)
2340                 goto out_domain;
2341
2342         /* Get aperture info */
2343         geo = &domain->domain->geometry;
2344         if (vfio_iommu_aper_conflict(iommu, geo->aperture_start,
2345                                      geo->aperture_end)) {
2346                 ret = -EINVAL;
2347                 goto out_detach;
2348         }
2349
2350         ret = iommu_get_group_resv_regions(iommu_group, &group_resv_regions);
2351         if (ret)
2352                 goto out_detach;
2353
2354         if (vfio_iommu_resv_conflict(iommu, &group_resv_regions)) {
2355                 ret = -EINVAL;
2356                 goto out_detach;
2357         }
2358
2359         /*
2360          * We don't want to work on the original iova list as the list
2361          * gets modified and in case of failure we have to retain the
2362          * original list. Get a copy here.
2363          */
2364         ret = vfio_iommu_iova_get_copy(iommu, &iova_copy);
2365         if (ret)
2366                 goto out_detach;
2367
2368         ret = vfio_iommu_aper_resize(&iova_copy, geo->aperture_start,
2369                                      geo->aperture_end);
2370         if (ret)
2371                 goto out_detach;
2372
2373         ret = vfio_iommu_resv_exclude(&iova_copy, &group_resv_regions);
2374         if (ret)
2375                 goto out_detach;
2376
2377         resv_msi = vfio_iommu_has_sw_msi(&group_resv_regions, &resv_msi_base);
2378
2379         INIT_LIST_HEAD(&domain->group_list);
2380         list_add(&group->next, &domain->group_list);
2381
2382         msi_remap = irq_domain_check_msi_remap() ||
2383                     iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
2384
2385         if (!allow_unsafe_interrupts && !msi_remap) {
2386                 pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
2387                        __func__);
2388                 ret = -EPERM;
2389                 goto out_detach;
2390         }
2391
2392         if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
2393                 domain->prot |= IOMMU_CACHE;
2394
2395         /*
2396          * Try to match an existing compatible domain.  We don't want to
2397          * preclude an IOMMU driver supporting multiple bus_types and being
2398          * able to include different bus_types in the same IOMMU domain, so
2399          * we test whether the domains use the same iommu_ops rather than
2400          * testing if they're on the same bus_type.
2401          */
2402         list_for_each_entry(d, &iommu->domain_list, next) {
2403                 if (d->domain->ops == domain->domain->ops &&
2404                     d->prot == domain->prot) {
2405                         vfio_iommu_detach_group(domain, group);
2406                         if (!vfio_iommu_attach_group(d, group)) {
2407                                 list_add(&group->next, &d->group_list);
2408                                 iommu_domain_free(domain->domain);
2409                                 kfree(domain);
2410                                 goto done;
2411                         }
2412
2413                         ret = vfio_iommu_attach_group(domain, group);
2414                         if (ret)
2415                                 goto out_domain;
2416                 }
2417         }
2418
2419         vfio_test_domain_fgsp(domain);
2420
2421         /* replay mappings on new domains */
2422         ret = vfio_iommu_replay(iommu, domain);
2423         if (ret)
2424                 goto out_detach;
2425
2426         if (resv_msi) {
2427                 ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
2428                 if (ret && ret != -ENODEV)
2429                         goto out_detach;
2430         }
2431
2432         list_add(&domain->next, &iommu->domain_list);
2433         vfio_update_pgsize_bitmap(iommu);
2434 done:
2435         /* Delete the old one and insert new iova list */
2436         vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2437
2438         /*
2439          * An iommu backed group can dirty memory directly and therefore
2440          * demotes the iommu scope until it declares itself dirty tracking
2441          * capable via the page pinning interface.
2442          */
2443         iommu->num_non_pinned_groups++;
2444         mutex_unlock(&iommu->lock);
2445         vfio_iommu_resv_free(&group_resv_regions);
2446
2447         return 0;
2448
2449 out_detach:
2450         vfio_iommu_detach_group(domain, group);
2451 out_domain:
2452         iommu_domain_free(domain->domain);
2453         vfio_iommu_iova_free(&iova_copy);
2454         vfio_iommu_resv_free(&group_resv_regions);
2455 out_free:
2456         kfree(domain);
2457         kfree(group);
2458         mutex_unlock(&iommu->lock);
2459         return ret;
2460 }
2461
2462 static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
2463 {
2464         struct rb_node *node;
2465
2466         while ((node = rb_first(&iommu->dma_list)))
2467                 vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
2468 }
2469
2470 static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
2471 {
2472         struct rb_node *n, *p;
2473
2474         n = rb_first(&iommu->dma_list);
2475         for (; n; n = rb_next(n)) {
2476                 struct vfio_dma *dma;
2477                 long locked = 0, unlocked = 0;
2478
2479                 dma = rb_entry(n, struct vfio_dma, node);
2480                 unlocked += vfio_unmap_unpin(iommu, dma, false);
2481                 p = rb_first(&dma->pfn_list);
2482                 for (; p; p = rb_next(p)) {
2483                         struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
2484                                                          node);
2485
2486                         if (!is_invalid_reserved_pfn(vpfn->pfn))
2487                                 locked++;
2488                 }
2489                 vfio_lock_acct(dma, locked - unlocked, true);
2490         }
2491 }
2492
2493 /*
2494  * Called when a domain is removed in detach. It is possible that
2495  * the removed domain decided the iova aperture window. Modify the
2496  * iova aperture with the smallest window among existing domains.
2497  */
2498 static void vfio_iommu_aper_expand(struct vfio_iommu *iommu,
2499                                    struct list_head *iova_copy)
2500 {
2501         struct vfio_domain *domain;
2502         struct vfio_iova *node;
2503         dma_addr_t start = 0;
2504         dma_addr_t end = (dma_addr_t)~0;
2505
2506         if (list_empty(iova_copy))
2507                 return;
2508
2509         list_for_each_entry(domain, &iommu->domain_list, next) {
2510                 struct iommu_domain_geometry *geo = &domain->domain->geometry;
2511
2512                 if (geo->aperture_start > start)
2513                         start = geo->aperture_start;
2514                 if (geo->aperture_end < end)
2515                         end = geo->aperture_end;
2516         }
2517
2518         /* Modify aperture limits. The new aper is either same or bigger */
2519         node = list_first_entry(iova_copy, struct vfio_iova, list);
2520         node->start = start;
2521         node = list_last_entry(iova_copy, struct vfio_iova, list);
2522         node->end = end;
2523 }
2524
2525 /*
2526  * Called when a group is detached. The reserved regions for that
2527  * group can be part of valid iova now. But since reserved regions
2528  * may be duplicated among groups, populate the iova valid regions
2529  * list again.
2530  */
2531 static int vfio_iommu_resv_refresh(struct vfio_iommu *iommu,
2532                                    struct list_head *iova_copy)
2533 {
2534         struct vfio_domain *d;
2535         struct vfio_iommu_group *g;
2536         struct vfio_iova *node;
2537         dma_addr_t start, end;
2538         LIST_HEAD(resv_regions);
2539         int ret;
2540
2541         if (list_empty(iova_copy))
2542                 return -EINVAL;
2543
2544         list_for_each_entry(d, &iommu->domain_list, next) {
2545                 list_for_each_entry(g, &d->group_list, next) {
2546                         ret = iommu_get_group_resv_regions(g->iommu_group,
2547                                                            &resv_regions);
2548                         if (ret)
2549                                 goto done;
2550                 }
2551         }
2552
2553         node = list_first_entry(iova_copy, struct vfio_iova, list);
2554         start = node->start;
2555         node = list_last_entry(iova_copy, struct vfio_iova, list);
2556         end = node->end;
2557
2558         /* purge the iova list and create new one */
2559         vfio_iommu_iova_free(iova_copy);
2560
2561         ret = vfio_iommu_aper_resize(iova_copy, start, end);
2562         if (ret)
2563                 goto done;
2564
2565         /* Exclude current reserved regions from iova ranges */
2566         ret = vfio_iommu_resv_exclude(iova_copy, &resv_regions);
2567 done:
2568         vfio_iommu_resv_free(&resv_regions);
2569         return ret;
2570 }
2571
2572 static void vfio_iommu_type1_detach_group(void *iommu_data,
2573                                           struct iommu_group *iommu_group)
2574 {
2575         struct vfio_iommu *iommu = iommu_data;
2576         struct vfio_domain *domain;
2577         struct vfio_iommu_group *group;
2578         bool update_dirty_scope = false;
2579         LIST_HEAD(iova_copy);
2580
2581         mutex_lock(&iommu->lock);
2582
2583         if (iommu->external_domain) {
2584                 group = find_iommu_group(iommu->external_domain, iommu_group);
2585                 if (group) {
2586                         update_dirty_scope = !group->pinned_page_dirty_scope;
2587                         list_del(&group->next);
2588                         kfree(group);
2589
2590                         if (list_empty(&iommu->external_domain->group_list)) {
2591                                 if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu)) {
2592                                         WARN_ON(iommu->notifier.head);
2593                                         vfio_iommu_unmap_unpin_all(iommu);
2594                                 }
2595
2596                                 kfree(iommu->external_domain);
2597                                 iommu->external_domain = NULL;
2598                         }
2599                         goto detach_group_done;
2600                 }
2601         }
2602
2603         /*
2604          * Get a copy of iova list. This will be used to update
2605          * and to replace the current one later. Please note that
2606          * we will leave the original list as it is if update fails.
2607          */
2608         vfio_iommu_iova_get_copy(iommu, &iova_copy);
2609
2610         list_for_each_entry(domain, &iommu->domain_list, next) {
2611                 group = find_iommu_group(domain, iommu_group);
2612                 if (!group)
2613                         continue;
2614
2615                 vfio_iommu_detach_group(domain, group);
2616                 update_dirty_scope = !group->pinned_page_dirty_scope;
2617                 list_del(&group->next);
2618                 kfree(group);
2619                 /*
2620                  * Group ownership provides privilege, if the group list is
2621                  * empty, the domain goes away. If it's the last domain with
2622                  * iommu and external domain doesn't exist, then all the
2623                  * mappings go away too. If it's the last domain with iommu and
2624                  * external domain exist, update accounting
2625                  */
2626                 if (list_empty(&domain->group_list)) {
2627                         if (list_is_singular(&iommu->domain_list)) {
2628                                 if (!iommu->external_domain) {
2629                                         WARN_ON(iommu->notifier.head);
2630                                         vfio_iommu_unmap_unpin_all(iommu);
2631                                 } else {
2632                                         vfio_iommu_unmap_unpin_reaccount(iommu);
2633                                 }
2634                         }
2635                         iommu_domain_free(domain->domain);
2636                         list_del(&domain->next);
2637                         kfree(domain);
2638                         vfio_iommu_aper_expand(iommu, &iova_copy);
2639                         vfio_update_pgsize_bitmap(iommu);
2640                 }
2641                 break;
2642         }
2643
2644         if (!vfio_iommu_resv_refresh(iommu, &iova_copy))
2645                 vfio_iommu_iova_insert_copy(iommu, &iova_copy);
2646         else
2647                 vfio_iommu_iova_free(&iova_copy);
2648
2649 detach_group_done:
2650         /*
2651          * Removal of a group without dirty tracking may allow the iommu scope
2652          * to be promoted.
2653          */
2654         if (update_dirty_scope) {
2655                 iommu->num_non_pinned_groups--;
2656                 if (iommu->dirty_page_tracking)
2657                         vfio_iommu_populate_bitmap_full(iommu);
2658         }
2659         mutex_unlock(&iommu->lock);
2660 }
2661
2662 static void *vfio_iommu_type1_open(unsigned long arg)
2663 {
2664         struct vfio_iommu *iommu;
2665
2666         iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
2667         if (!iommu)
2668                 return ERR_PTR(-ENOMEM);
2669
2670         switch (arg) {
2671         case VFIO_TYPE1_IOMMU:
2672                 break;
2673         case VFIO_TYPE1_NESTING_IOMMU:
2674                 iommu->nesting = true;
2675                 fallthrough;
2676         case VFIO_TYPE1v2_IOMMU:
2677                 iommu->v2 = true;
2678                 break;
2679         default:
2680                 kfree(iommu);
2681                 return ERR_PTR(-EINVAL);
2682         }
2683
2684         INIT_LIST_HEAD(&iommu->domain_list);
2685         INIT_LIST_HEAD(&iommu->iova_list);
2686         iommu->dma_list = RB_ROOT;
2687         iommu->dma_avail = dma_entry_limit;
2688         iommu->container_open = true;
2689         mutex_init(&iommu->lock);
2690         BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
2691         init_waitqueue_head(&iommu->vaddr_wait);
2692
2693         return iommu;
2694 }
2695
2696 static void vfio_release_domain(struct vfio_domain *domain, bool external)
2697 {
2698         struct vfio_iommu_group *group, *group_tmp;
2699
2700         list_for_each_entry_safe(group, group_tmp,
2701                                  &domain->group_list, next) {
2702                 if (!external)
2703                         vfio_iommu_detach_group(domain, group);
2704                 list_del(&group->next);
2705                 kfree(group);
2706         }
2707
2708         if (!external)
2709                 iommu_domain_free(domain->domain);
2710 }
2711
2712 static void vfio_iommu_type1_release(void *iommu_data)
2713 {
2714         struct vfio_iommu *iommu = iommu_data;
2715         struct vfio_domain *domain, *domain_tmp;
2716
2717         if (iommu->external_domain) {
2718                 vfio_release_domain(iommu->external_domain, true);
2719                 kfree(iommu->external_domain);
2720         }
2721
2722         vfio_iommu_unmap_unpin_all(iommu);
2723
2724         list_for_each_entry_safe(domain, domain_tmp,
2725                                  &iommu->domain_list, next) {
2726                 vfio_release_domain(domain, false);
2727                 list_del(&domain->next);
2728                 kfree(domain);
2729         }
2730
2731         vfio_iommu_iova_free(&iommu->iova_list);
2732
2733         kfree(iommu);
2734 }
2735
2736 static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
2737 {
2738         struct vfio_domain *domain;
2739         int ret = 1;
2740
2741         mutex_lock(&iommu->lock);
2742         list_for_each_entry(domain, &iommu->domain_list, next) {
2743                 if (!(domain->prot & IOMMU_CACHE)) {
2744                         ret = 0;
2745                         break;
2746                 }
2747         }
2748         mutex_unlock(&iommu->lock);
2749
2750         return ret;
2751 }
2752
2753 static int vfio_iommu_type1_check_extension(struct vfio_iommu *iommu,
2754                                             unsigned long arg)
2755 {
2756         switch (arg) {
2757         case VFIO_TYPE1_IOMMU:
2758         case VFIO_TYPE1v2_IOMMU:
2759         case VFIO_TYPE1_NESTING_IOMMU:
2760         case VFIO_UNMAP_ALL:
2761         case VFIO_UPDATE_VADDR:
2762                 return 1;
2763         case VFIO_DMA_CC_IOMMU:
2764                 if (!iommu)
2765                         return 0;
2766                 return vfio_domains_have_iommu_cache(iommu);
2767         default:
2768                 return 0;
2769         }
2770 }
2771
2772 static int vfio_iommu_iova_add_cap(struct vfio_info_cap *caps,
2773                  struct vfio_iommu_type1_info_cap_iova_range *cap_iovas,
2774                  size_t size)
2775 {
2776         struct vfio_info_cap_header *header;
2777         struct vfio_iommu_type1_info_cap_iova_range *iova_cap;
2778
2779         header = vfio_info_cap_add(caps, size,
2780                                    VFIO_IOMMU_TYPE1_INFO_CAP_IOVA_RANGE, 1);
2781         if (IS_ERR(header))
2782                 return PTR_ERR(header);
2783
2784         iova_cap = container_of(header,
2785                                 struct vfio_iommu_type1_info_cap_iova_range,
2786                                 header);
2787         iova_cap->nr_iovas = cap_iovas->nr_iovas;
2788         memcpy(iova_cap->iova_ranges, cap_iovas->iova_ranges,
2789                cap_iovas->nr_iovas * sizeof(*cap_iovas->iova_ranges));
2790         return 0;
2791 }
2792
2793 static int vfio_iommu_iova_build_caps(struct vfio_iommu *iommu,
2794                                       struct vfio_info_cap *caps)
2795 {
2796         struct vfio_iommu_type1_info_cap_iova_range *cap_iovas;
2797         struct vfio_iova *iova;
2798         size_t size;
2799         int iovas = 0, i = 0, ret;
2800
2801         list_for_each_entry(iova, &iommu->iova_list, list)
2802                 iovas++;
2803
2804         if (!iovas) {
2805                 /*
2806                  * Return 0 as a container with a single mdev device
2807                  * will have an empty list
2808                  */
2809                 return 0;
2810         }
2811
2812         size = struct_size(cap_iovas, iova_ranges, iovas);
2813
2814         cap_iovas = kzalloc(size, GFP_KERNEL);
2815         if (!cap_iovas)
2816                 return -ENOMEM;
2817
2818         cap_iovas->nr_iovas = iovas;
2819
2820         list_for_each_entry(iova, &iommu->iova_list, list) {
2821                 cap_iovas->iova_ranges[i].start = iova->start;
2822                 cap_iovas->iova_ranges[i].end = iova->end;
2823                 i++;
2824         }
2825
2826         ret = vfio_iommu_iova_add_cap(caps, cap_iovas, size);
2827
2828         kfree(cap_iovas);
2829         return ret;
2830 }
2831
2832 static int vfio_iommu_migration_build_caps(struct vfio_iommu *iommu,
2833                                            struct vfio_info_cap *caps)
2834 {
2835         struct vfio_iommu_type1_info_cap_migration cap_mig;
2836
2837         cap_mig.header.id = VFIO_IOMMU_TYPE1_INFO_CAP_MIGRATION;
2838         cap_mig.header.version = 1;
2839
2840         cap_mig.flags = 0;
2841         /* support minimum pgsize */
2842         cap_mig.pgsize_bitmap = (size_t)1 << __ffs(iommu->pgsize_bitmap);
2843         cap_mig.max_dirty_bitmap_size = DIRTY_BITMAP_SIZE_MAX;
2844
2845         return vfio_info_add_capability(caps, &cap_mig.header, sizeof(cap_mig));
2846 }
2847
2848 static int vfio_iommu_dma_avail_build_caps(struct vfio_iommu *iommu,
2849                                            struct vfio_info_cap *caps)
2850 {
2851         struct vfio_iommu_type1_info_dma_avail cap_dma_avail;
2852
2853         cap_dma_avail.header.id = VFIO_IOMMU_TYPE1_INFO_DMA_AVAIL;
2854         cap_dma_avail.header.version = 1;
2855
2856         cap_dma_avail.avail = iommu->dma_avail;
2857
2858         return vfio_info_add_capability(caps, &cap_dma_avail.header,
2859                                         sizeof(cap_dma_avail));
2860 }
2861
2862 static int vfio_iommu_type1_get_info(struct vfio_iommu *iommu,
2863                                      unsigned long arg)
2864 {
2865         struct vfio_iommu_type1_info info;
2866         unsigned long minsz;
2867         struct vfio_info_cap caps = { .buf = NULL, .size = 0 };
2868         unsigned long capsz;
2869         int ret;
2870
2871         minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
2872
2873         /* For backward compatibility, cannot require this */
2874         capsz = offsetofend(struct vfio_iommu_type1_info, cap_offset);
2875
2876         if (copy_from_user(&info, (void __user *)arg, minsz))
2877                 return -EFAULT;
2878
2879         if (info.argsz < minsz)
2880                 return -EINVAL;
2881
2882         if (info.argsz >= capsz) {
2883                 minsz = capsz;
2884                 info.cap_offset = 0; /* output, no-recopy necessary */
2885         }
2886
2887         mutex_lock(&iommu->lock);
2888         info.flags = VFIO_IOMMU_INFO_PGSIZES;
2889
2890         info.iova_pgsizes = iommu->pgsize_bitmap;
2891
2892         ret = vfio_iommu_migration_build_caps(iommu, &caps);
2893
2894         if (!ret)
2895                 ret = vfio_iommu_dma_avail_build_caps(iommu, &caps);
2896
2897         if (!ret)
2898                 ret = vfio_iommu_iova_build_caps(iommu, &caps);
2899
2900         mutex_unlock(&iommu->lock);
2901
2902         if (ret)
2903                 return ret;
2904
2905         if (caps.size) {
2906                 info.flags |= VFIO_IOMMU_INFO_CAPS;
2907
2908                 if (info.argsz < sizeof(info) + caps.size) {
2909                         info.argsz = sizeof(info) + caps.size;
2910                 } else {
2911                         vfio_info_cap_shift(&caps, sizeof(info));
2912                         if (copy_to_user((void __user *)arg +
2913                                         sizeof(info), caps.buf,
2914                                         caps.size)) {
2915                                 kfree(caps.buf);
2916                                 return -EFAULT;
2917                         }
2918                         info.cap_offset = sizeof(info);
2919                 }
2920
2921                 kfree(caps.buf);
2922         }
2923
2924         return copy_to_user((void __user *)arg, &info, minsz) ?
2925                         -EFAULT : 0;
2926 }
2927
2928 static int vfio_iommu_type1_map_dma(struct vfio_iommu *iommu,
2929                                     unsigned long arg)
2930 {
2931         struct vfio_iommu_type1_dma_map map;
2932         unsigned long minsz;
2933         uint32_t mask = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE |
2934                         VFIO_DMA_MAP_FLAG_VADDR;
2935
2936         minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
2937
2938         if (copy_from_user(&map, (void __user *)arg, minsz))
2939                 return -EFAULT;
2940
2941         if (map.argsz < minsz || map.flags & ~mask)
2942                 return -EINVAL;
2943
2944         return vfio_dma_do_map(iommu, &map);
2945 }
2946
2947 static int vfio_iommu_type1_unmap_dma(struct vfio_iommu *iommu,
2948                                       unsigned long arg)
2949 {
2950         struct vfio_iommu_type1_dma_unmap unmap;
2951         struct vfio_bitmap bitmap = { 0 };
2952         uint32_t mask = VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP |
2953                         VFIO_DMA_UNMAP_FLAG_VADDR |
2954                         VFIO_DMA_UNMAP_FLAG_ALL;
2955         unsigned long minsz;
2956         int ret;
2957
2958         minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
2959
2960         if (copy_from_user(&unmap, (void __user *)arg, minsz))
2961                 return -EFAULT;
2962
2963         if (unmap.argsz < minsz || unmap.flags & ~mask)
2964                 return -EINVAL;
2965
2966         if ((unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) &&
2967             (unmap.flags & (VFIO_DMA_UNMAP_FLAG_ALL |
2968                             VFIO_DMA_UNMAP_FLAG_VADDR)))
2969                 return -EINVAL;
2970
2971         if (unmap.flags & VFIO_DMA_UNMAP_FLAG_GET_DIRTY_BITMAP) {
2972                 unsigned long pgshift;
2973
2974                 if (unmap.argsz < (minsz + sizeof(bitmap)))
2975                         return -EINVAL;
2976
2977                 if (copy_from_user(&bitmap,
2978                                    (void __user *)(arg + minsz),
2979                                    sizeof(bitmap)))
2980                         return -EFAULT;
2981
2982                 if (!access_ok((void __user *)bitmap.data, bitmap.size))
2983                         return -EINVAL;
2984
2985                 pgshift = __ffs(bitmap.pgsize);
2986                 ret = verify_bitmap_size(unmap.size >> pgshift,
2987                                          bitmap.size);
2988                 if (ret)
2989                         return ret;
2990         }
2991
2992         ret = vfio_dma_do_unmap(iommu, &unmap, &bitmap);
2993         if (ret)
2994                 return ret;
2995
2996         return copy_to_user((void __user *)arg, &unmap, minsz) ?
2997                         -EFAULT : 0;
2998 }
2999
3000 static int vfio_iommu_type1_dirty_pages(struct vfio_iommu *iommu,
3001                                         unsigned long arg)
3002 {
3003         struct vfio_iommu_type1_dirty_bitmap dirty;
3004         uint32_t mask = VFIO_IOMMU_DIRTY_PAGES_FLAG_START |
3005                         VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP |
3006                         VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP;
3007         unsigned long minsz;
3008         int ret = 0;
3009
3010         if (!iommu->v2)
3011                 return -EACCES;
3012
3013         minsz = offsetofend(struct vfio_iommu_type1_dirty_bitmap, flags);
3014
3015         if (copy_from_user(&dirty, (void __user *)arg, minsz))
3016                 return -EFAULT;
3017
3018         if (dirty.argsz < minsz || dirty.flags & ~mask)
3019                 return -EINVAL;
3020
3021         /* only one flag should be set at a time */
3022         if (__ffs(dirty.flags) != __fls(dirty.flags))
3023                 return -EINVAL;
3024
3025         if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_START) {
3026                 size_t pgsize;
3027
3028                 mutex_lock(&iommu->lock);
3029                 pgsize = 1 << __ffs(iommu->pgsize_bitmap);
3030                 if (!iommu->dirty_page_tracking) {
3031                         ret = vfio_dma_bitmap_alloc_all(iommu, pgsize);
3032                         if (!ret)
3033                                 iommu->dirty_page_tracking = true;
3034                 }
3035                 mutex_unlock(&iommu->lock);
3036                 return ret;
3037         } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_STOP) {
3038                 mutex_lock(&iommu->lock);
3039                 if (iommu->dirty_page_tracking) {
3040                         iommu->dirty_page_tracking = false;
3041                         vfio_dma_bitmap_free_all(iommu);
3042                 }
3043                 mutex_unlock(&iommu->lock);
3044                 return 0;
3045         } else if (dirty.flags & VFIO_IOMMU_DIRTY_PAGES_FLAG_GET_BITMAP) {
3046                 struct vfio_iommu_type1_dirty_bitmap_get range;
3047                 unsigned long pgshift;
3048                 size_t data_size = dirty.argsz - minsz;
3049                 size_t iommu_pgsize;
3050
3051                 if (!data_size || data_size < sizeof(range))
3052                         return -EINVAL;
3053
3054                 if (copy_from_user(&range, (void __user *)(arg + minsz),
3055                                    sizeof(range)))
3056                         return -EFAULT;
3057
3058                 if (range.iova + range.size < range.iova)
3059                         return -EINVAL;
3060                 if (!access_ok((void __user *)range.bitmap.data,
3061                                range.bitmap.size))
3062                         return -EINVAL;
3063
3064                 pgshift = __ffs(range.bitmap.pgsize);
3065                 ret = verify_bitmap_size(range.size >> pgshift,
3066                                          range.bitmap.size);
3067                 if (ret)
3068                         return ret;
3069
3070                 mutex_lock(&iommu->lock);
3071
3072                 iommu_pgsize = (size_t)1 << __ffs(iommu->pgsize_bitmap);
3073
3074                 /* allow only smallest supported pgsize */
3075                 if (range.bitmap.pgsize != iommu_pgsize) {
3076                         ret = -EINVAL;
3077                         goto out_unlock;
3078                 }
3079                 if (range.iova & (iommu_pgsize - 1)) {
3080                         ret = -EINVAL;
3081                         goto out_unlock;
3082                 }
3083                 if (!range.size || range.size & (iommu_pgsize - 1)) {
3084                         ret = -EINVAL;
3085                         goto out_unlock;
3086                 }
3087
3088                 if (iommu->dirty_page_tracking)
3089                         ret = vfio_iova_dirty_bitmap(range.bitmap.data,
3090                                                      iommu, range.iova,
3091                                                      range.size,
3092                                                      range.bitmap.pgsize);
3093                 else
3094                         ret = -EINVAL;
3095 out_unlock:
3096                 mutex_unlock(&iommu->lock);
3097
3098                 return ret;
3099         }
3100
3101         return -EINVAL;
3102 }
3103
3104 static long vfio_iommu_type1_ioctl(void *iommu_data,
3105                                    unsigned int cmd, unsigned long arg)
3106 {
3107         struct vfio_iommu *iommu = iommu_data;
3108
3109         switch (cmd) {
3110         case VFIO_CHECK_EXTENSION:
3111                 return vfio_iommu_type1_check_extension(iommu, arg);
3112         case VFIO_IOMMU_GET_INFO:
3113                 return vfio_iommu_type1_get_info(iommu, arg);
3114         case VFIO_IOMMU_MAP_DMA:
3115                 return vfio_iommu_type1_map_dma(iommu, arg);
3116         case VFIO_IOMMU_UNMAP_DMA:
3117                 return vfio_iommu_type1_unmap_dma(iommu, arg);
3118         case VFIO_IOMMU_DIRTY_PAGES:
3119                 return vfio_iommu_type1_dirty_pages(iommu, arg);
3120         default:
3121                 return -ENOTTY;
3122         }
3123 }
3124
3125 static int vfio_iommu_type1_register_notifier(void *iommu_data,
3126                                               unsigned long *events,
3127                                               struct notifier_block *nb)
3128 {
3129         struct vfio_iommu *iommu = iommu_data;
3130
3131         /* clear known events */
3132         *events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
3133
3134         /* refuse to register if still events remaining */
3135         if (*events)
3136                 return -EINVAL;
3137
3138         return blocking_notifier_chain_register(&iommu->notifier, nb);
3139 }
3140
3141 static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
3142                                                 struct notifier_block *nb)
3143 {
3144         struct vfio_iommu *iommu = iommu_data;
3145
3146         return blocking_notifier_chain_unregister(&iommu->notifier, nb);
3147 }
3148
3149 static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
3150                                          dma_addr_t user_iova, void *data,
3151                                          size_t count, bool write,
3152                                          size_t *copied)
3153 {
3154         struct mm_struct *mm;
3155         unsigned long vaddr;
3156         struct vfio_dma *dma;
3157         bool kthread = current->mm == NULL;
3158         size_t offset;
3159         int ret;
3160
3161         *copied = 0;
3162
3163         ret = vfio_find_dma_valid(iommu, user_iova, 1, &dma);
3164         if (ret < 0)
3165                 return ret;
3166
3167         if ((write && !(dma->prot & IOMMU_WRITE)) ||
3168                         !(dma->prot & IOMMU_READ))
3169                 return -EPERM;
3170
3171         mm = get_task_mm(dma->task);
3172
3173         if (!mm)
3174                 return -EPERM;
3175
3176         if (kthread)
3177                 kthread_use_mm(mm);
3178         else if (current->mm != mm)
3179                 goto out;
3180
3181         offset = user_iova - dma->iova;
3182
3183         if (count > dma->size - offset)
3184                 count = dma->size - offset;
3185
3186         vaddr = dma->vaddr + offset;
3187
3188         if (write) {
3189                 *copied = copy_to_user((void __user *)vaddr, data,
3190                                          count) ? 0 : count;
3191                 if (*copied && iommu->dirty_page_tracking) {
3192                         unsigned long pgshift = __ffs(iommu->pgsize_bitmap);
3193                         /*
3194                          * Bitmap populated with the smallest supported page
3195                          * size
3196                          */
3197                         bitmap_set(dma->bitmap, offset >> pgshift,
3198                                    ((offset + *copied - 1) >> pgshift) -
3199                                    (offset >> pgshift) + 1);
3200                 }
3201         } else
3202                 *copied = copy_from_user(data, (void __user *)vaddr,
3203                                            count) ? 0 : count;
3204         if (kthread)
3205                 kthread_unuse_mm(mm);
3206 out:
3207         mmput(mm);
3208         return *copied ? 0 : -EFAULT;
3209 }
3210
3211 static int vfio_iommu_type1_dma_rw(void *iommu_data, dma_addr_t user_iova,
3212                                    void *data, size_t count, bool write)
3213 {
3214         struct vfio_iommu *iommu = iommu_data;
3215         int ret = 0;
3216         size_t done;
3217
3218         mutex_lock(&iommu->lock);
3219         while (count > 0) {
3220                 ret = vfio_iommu_type1_dma_rw_chunk(iommu, user_iova, data,
3221                                                     count, write, &done);
3222                 if (ret)
3223                         break;
3224
3225                 count -= done;
3226                 data += done;
3227                 user_iova += done;
3228         }
3229
3230         mutex_unlock(&iommu->lock);
3231         return ret;
3232 }
3233
3234 static struct iommu_domain *
3235 vfio_iommu_type1_group_iommu_domain(void *iommu_data,
3236                                     struct iommu_group *iommu_group)
3237 {
3238         struct iommu_domain *domain = ERR_PTR(-ENODEV);
3239         struct vfio_iommu *iommu = iommu_data;
3240         struct vfio_domain *d;
3241
3242         if (!iommu || !iommu_group)
3243                 return ERR_PTR(-EINVAL);
3244
3245         mutex_lock(&iommu->lock);
3246         list_for_each_entry(d, &iommu->domain_list, next) {
3247                 if (find_iommu_group(d, iommu_group)) {
3248                         domain = d->domain;
3249                         break;
3250                 }
3251         }
3252         mutex_unlock(&iommu->lock);
3253
3254         return domain;
3255 }
3256
3257 static void vfio_iommu_type1_notify(void *iommu_data,
3258                                     enum vfio_iommu_notify_type event)
3259 {
3260         struct vfio_iommu *iommu = iommu_data;
3261
3262         if (event != VFIO_IOMMU_CONTAINER_CLOSE)
3263                 return;
3264         mutex_lock(&iommu->lock);
3265         iommu->container_open = false;
3266         mutex_unlock(&iommu->lock);
3267         wake_up_all(&iommu->vaddr_wait);
3268 }
3269
3270 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
3271         .name                   = "vfio-iommu-type1",
3272         .owner                  = THIS_MODULE,
3273         .open                   = vfio_iommu_type1_open,
3274         .release                = vfio_iommu_type1_release,
3275         .ioctl                  = vfio_iommu_type1_ioctl,
3276         .attach_group           = vfio_iommu_type1_attach_group,
3277         .detach_group           = vfio_iommu_type1_detach_group,
3278         .pin_pages              = vfio_iommu_type1_pin_pages,
3279         .unpin_pages            = vfio_iommu_type1_unpin_pages,
3280         .register_notifier      = vfio_iommu_type1_register_notifier,
3281         .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
3282         .dma_rw                 = vfio_iommu_type1_dma_rw,
3283         .group_iommu_domain     = vfio_iommu_type1_group_iommu_domain,
3284         .notify                 = vfio_iommu_type1_notify,
3285 };
3286
3287 static int __init vfio_iommu_type1_init(void)
3288 {
3289         return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
3290 }
3291
3292 static void __exit vfio_iommu_type1_cleanup(void)
3293 {
3294         vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
3295 }
3296
3297 module_init(vfio_iommu_type1_init);
3298 module_exit(vfio_iommu_type1_cleanup);
3299
3300 MODULE_VERSION(DRIVER_VERSION);
3301 MODULE_LICENSE("GPL v2");
3302 MODULE_AUTHOR(DRIVER_AUTHOR);
3303 MODULE_DESCRIPTION(DRIVER_DESC);