Merge tag 'rust-fixes-6.6' of https://github.com/Rust-for-Linux/linux
[platform/kernel/linux-starfive.git] / drivers / iommu / iommufd / selftest.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * Kernel side components to support tools/testing/selftests/iommu
5  */
6 #include <linux/slab.h>
7 #include <linux/iommu.h>
8 #include <linux/xarray.h>
9 #include <linux/file.h>
10 #include <linux/anon_inodes.h>
11 #include <linux/fault-inject.h>
12 #include <linux/platform_device.h>
13 #include <uapi/linux/iommufd.h>
14
15 #include "../iommu-priv.h"
16 #include "io_pagetable.h"
17 #include "iommufd_private.h"
18 #include "iommufd_test.h"
19
/*
 * Fault-injection attribute for the selftest.
 * NOTE(review): presumably consulted by iommufd_should_fail() and exposed
 * under dbgfs_root; neither link is visible in this chunk — confirm.
 */
static DECLARE_FAULT_ATTR(fail_iommufd);
/* NOTE(review): looks like the debugfs directory for selftest knobs — confirm. */
static struct dentry *dbgfs_root;
/* Platform device for the mock IOMMU; it is not assigned in this chunk. */
static struct platform_device *selftest_iommu_dev;

/*
 * Not static, so other iommufd translation units can read it.
 * NOTE(review): presumably a memory limit applied by the core in selftest
 * builds — confirm the units against its users.
 */
size_t iommufd_test_memory_limit = 65536;
25
enum {
        /* The mock IOMMU's page size: half of a CPU page. */
        MOCK_IO_PAGE_SIZE = PAGE_SIZE / 2,

        /*
         * Like a real page table, alignment requires the low bits of the
         * address to be zero. xarray also requires the high bit to be zero, so
         * we store the pfns shifted. The upper bits are used for metadata.
         *
         * Each value stored in mock_iommu_domain::pfns is therefore
         * (paddr / MOCK_IO_PAGE_SIZE) | MOCK_PFN_* flags, and masking with
         * MOCK_PFN_MASK recovers the shifted pfn.
         */
        MOCK_PFN_MASK = ULONG_MAX / MOCK_IO_PAGE_SIZE,

        /* Lowest metadata bit: the first value above MOCK_PFN_MASK. */
        _MOCK_PFN_START = MOCK_PFN_MASK + 1,
        /* Tags the first IO page stored by a mock_domain_map_pages() call. */
        MOCK_PFN_START_IOVA = _MOCK_PFN_START,
        /*
         * Tags the last IO page stored by a mock_domain_map_pages() call.
         * This is the SAME bit as MOCK_PFN_START_IOVA. map_pages() assigns
         * (rather than ORs) this flag, so a one-page map carries only this
         * value, and the unmap-side checks cannot tell "start" from "last".
         * NOTE(review): the aliasing looks intentional — confirm before
         * giving the two flags separate bits.
         */
        MOCK_PFN_LAST_IOVA = _MOCK_PFN_START,
};
40
41 /*
42  * Syzkaller has trouble randomizing the correct iova to use since it is linked
43  * to the map ioctl's output, and it has no ide about that. So, simplify things.
44  * In syzkaller mode the 64 bit IOVA is converted into an nth area and offset
45  * value. This has a much smaller randomization space and syzkaller can hit it.
46  */
47 static unsigned long iommufd_test_syz_conv_iova(struct io_pagetable *iopt,
48                                                 u64 *iova)
49 {
50         struct syz_layout {
51                 __u32 nth_area;
52                 __u32 offset;
53         };
54         struct syz_layout *syz = (void *)iova;
55         unsigned int nth = syz->nth_area;
56         struct iopt_area *area;
57
58         down_read(&iopt->iova_rwsem);
59         for (area = iopt_area_iter_first(iopt, 0, ULONG_MAX); area;
60              area = iopt_area_iter_next(area, 0, ULONG_MAX)) {
61                 if (nth == 0) {
62                         up_read(&iopt->iova_rwsem);
63                         return iopt_area_iova(area) + syz->offset;
64                 }
65                 nth--;
66         }
67         up_read(&iopt->iova_rwsem);
68
69         return 0;
70 }
71
72 void iommufd_test_syz_conv_iova_id(struct iommufd_ucmd *ucmd,
73                                    unsigned int ioas_id, u64 *iova, u32 *flags)
74 {
75         struct iommufd_ioas *ioas;
76
77         if (!(*flags & MOCK_FLAGS_ACCESS_SYZ))
78                 return;
79         *flags &= ~(u32)MOCK_FLAGS_ACCESS_SYZ;
80
81         ioas = iommufd_get_ioas(ucmd->ictx, ioas_id);
82         if (IS_ERR(ioas))
83                 return;
84         *iova = iommufd_test_syz_conv_iova(&ioas->iopt, iova);
85         iommufd_put_object(&ioas->obj);
86 }
87
/* A mock IOMMU domain: the generic domain plus a table of mapped IO pages. */
struct mock_iommu_domain {
        struct iommu_domain domain;
        /*
         * Indexed by IOVA / MOCK_IO_PAGE_SIZE. Each entry is an xa_mk_value()
         * of (paddr / MOCK_IO_PAGE_SIZE) OR'd with MOCK_PFN_* flags, as stored
         * by mock_domain_map_pages().
         */
        struct xarray pfns;
};

/* Discriminator for the union in struct selftest_obj. */
enum selftest_obj_type {
        TYPE_IDEV,
};

/* A fake device that mock_dev_create() registers on the iommufd_mock bus. */
struct mock_dev {
        struct device dev;
};

/*
 * Object of type IOMMUFD_OBJ_SELFTEST created by iommufd_test_mock_domain().
 * @type selects which union member is valid.
 */
struct selftest_obj {
        struct iommufd_object obj;
        enum selftest_obj_type type;

        union {
                /* Valid when type == TYPE_IDEV. */
                struct {
                        /* Binding of mock_dev into the iommufd context. */
                        struct iommufd_device *idev;
                        struct iommufd_ctx *ictx;
                        struct mock_dev *mock_dev;
                } idev;
        };
};
113
114 static void mock_domain_blocking_free(struct iommu_domain *domain)
115 {
116 }
117
118 static int mock_domain_nop_attach(struct iommu_domain *domain,
119                                   struct device *dev)
120 {
121         return 0;
122 }
123
124 static const struct iommu_domain_ops mock_blocking_ops = {
125         .free = mock_domain_blocking_free,
126         .attach_dev = mock_domain_nop_attach,
127 };
128
129 static struct iommu_domain mock_blocking_domain = {
130         .type = IOMMU_DOMAIN_BLOCKED,
131         .ops = &mock_blocking_ops,
132 };
133
134 static void *mock_domain_hw_info(struct device *dev, u32 *length, u32 *type)
135 {
136         struct iommu_test_hw_info *info;
137
138         info = kzalloc(sizeof(*info), GFP_KERNEL);
139         if (!info)
140                 return ERR_PTR(-ENOMEM);
141
142         info->test_reg = IOMMU_HW_INFO_SELFTEST_REGVAL;
143         *length = sizeof(*info);
144         *type = IOMMU_HW_INFO_TYPE_SELFTEST;
145
146         return info;
147 }
148
149 static struct iommu_domain *mock_domain_alloc(unsigned int iommu_domain_type)
150 {
151         struct mock_iommu_domain *mock;
152
153         if (iommu_domain_type == IOMMU_DOMAIN_BLOCKED)
154                 return &mock_blocking_domain;
155
156         if (iommu_domain_type != IOMMU_DOMAIN_UNMANAGED)
157                 return NULL;
158
159         mock = kzalloc(sizeof(*mock), GFP_KERNEL);
160         if (!mock)
161                 return NULL;
162         mock->domain.geometry.aperture_start = MOCK_APERTURE_START;
163         mock->domain.geometry.aperture_end = MOCK_APERTURE_LAST;
164         mock->domain.pgsize_bitmap = MOCK_IO_PAGE_SIZE;
165         xa_init(&mock->pfns);
166         return &mock->domain;
167 }
168
169 static void mock_domain_free(struct iommu_domain *domain)
170 {
171         struct mock_iommu_domain *mock =
172                 container_of(domain, struct mock_iommu_domain, domain);
173
174         WARN_ON(!xa_empty(&mock->pfns));
175         kfree(mock);
176 }
177
178 static int mock_domain_map_pages(struct iommu_domain *domain,
179                                  unsigned long iova, phys_addr_t paddr,
180                                  size_t pgsize, size_t pgcount, int prot,
181                                  gfp_t gfp, size_t *mapped)
182 {
183         struct mock_iommu_domain *mock =
184                 container_of(domain, struct mock_iommu_domain, domain);
185         unsigned long flags = MOCK_PFN_START_IOVA;
186         unsigned long start_iova = iova;
187
188         /*
189          * xarray does not reliably work with fault injection because it does a
190          * retry allocation, so put our own failure point.
191          */
192         if (iommufd_should_fail())
193                 return -ENOENT;
194
195         WARN_ON(iova % MOCK_IO_PAGE_SIZE);
196         WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);
197         for (; pgcount; pgcount--) {
198                 size_t cur;
199
200                 for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
201                         void *old;
202
203                         if (pgcount == 1 && cur + MOCK_IO_PAGE_SIZE == pgsize)
204                                 flags = MOCK_PFN_LAST_IOVA;
205                         old = xa_store(&mock->pfns, iova / MOCK_IO_PAGE_SIZE,
206                                        xa_mk_value((paddr / MOCK_IO_PAGE_SIZE) |
207                                                    flags),
208                                        gfp);
209                         if (xa_is_err(old)) {
210                                 for (; start_iova != iova;
211                                      start_iova += MOCK_IO_PAGE_SIZE)
212                                         xa_erase(&mock->pfns,
213                                                  start_iova /
214                                                          MOCK_IO_PAGE_SIZE);
215                                 return xa_err(old);
216                         }
217                         WARN_ON(old);
218                         iova += MOCK_IO_PAGE_SIZE;
219                         paddr += MOCK_IO_PAGE_SIZE;
220                         *mapped += MOCK_IO_PAGE_SIZE;
221                         flags = 0;
222                 }
223         }
224         return 0;
225 }
226
227 static size_t mock_domain_unmap_pages(struct iommu_domain *domain,
228                                       unsigned long iova, size_t pgsize,
229                                       size_t pgcount,
230                                       struct iommu_iotlb_gather *iotlb_gather)
231 {
232         struct mock_iommu_domain *mock =
233                 container_of(domain, struct mock_iommu_domain, domain);
234         bool first = true;
235         size_t ret = 0;
236         void *ent;
237
238         WARN_ON(iova % MOCK_IO_PAGE_SIZE);
239         WARN_ON(pgsize % MOCK_IO_PAGE_SIZE);
240
241         for (; pgcount; pgcount--) {
242                 size_t cur;
243
244                 for (cur = 0; cur != pgsize; cur += MOCK_IO_PAGE_SIZE) {
245                         ent = xa_erase(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
246                         WARN_ON(!ent);
247                         /*
248                          * iommufd generates unmaps that must be a strict
249                          * superset of the map's performend So every starting
250                          * IOVA should have been an iova passed to map, and the
251                          *
252                          * First IOVA must be present and have been a first IOVA
253                          * passed to map_pages
254                          */
255                         if (first) {
256                                 WARN_ON(!(xa_to_value(ent) &
257                                           MOCK_PFN_START_IOVA));
258                                 first = false;
259                         }
260                         if (pgcount == 1 && cur + MOCK_IO_PAGE_SIZE == pgsize)
261                                 WARN_ON(!(xa_to_value(ent) &
262                                           MOCK_PFN_LAST_IOVA));
263
264                         iova += MOCK_IO_PAGE_SIZE;
265                         ret += MOCK_IO_PAGE_SIZE;
266                 }
267         }
268         return ret;
269 }
270
271 static phys_addr_t mock_domain_iova_to_phys(struct iommu_domain *domain,
272                                             dma_addr_t iova)
273 {
274         struct mock_iommu_domain *mock =
275                 container_of(domain, struct mock_iommu_domain, domain);
276         void *ent;
277
278         WARN_ON(iova % MOCK_IO_PAGE_SIZE);
279         ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
280         WARN_ON(!ent);
281         return (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE;
282 }
283
284 static bool mock_domain_capable(struct device *dev, enum iommu_cap cap)
285 {
286         return cap == IOMMU_CAP_CACHE_COHERENCY;
287 }
288
289 static void mock_domain_set_plaform_dma_ops(struct device *dev)
290 {
291         /*
292          * mock doesn't setup default domains because we can't hook into the
293          * normal probe path
294          */
295 }
296
297 static struct iommu_device mock_iommu_device = {
298 };
299
300 static struct iommu_device *mock_probe_device(struct device *dev)
301 {
302         return &mock_iommu_device;
303 }
304
305 static const struct iommu_ops mock_ops = {
306         .owner = THIS_MODULE,
307         .pgsize_bitmap = MOCK_IO_PAGE_SIZE,
308         .hw_info = mock_domain_hw_info,
309         .domain_alloc = mock_domain_alloc,
310         .capable = mock_domain_capable,
311         .set_platform_dma_ops = mock_domain_set_plaform_dma_ops,
312         .device_group = generic_device_group,
313         .probe_device = mock_probe_device,
314         .default_domain_ops =
315                 &(struct iommu_domain_ops){
316                         .free = mock_domain_free,
317                         .attach_dev = mock_domain_nop_attach,
318                         .map_pages = mock_domain_map_pages,
319                         .unmap_pages = mock_domain_unmap_pages,
320                         .iova_to_phys = mock_domain_iova_to_phys,
321                 },
322 };
323
324 static inline struct iommufd_hw_pagetable *
325 get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id,
326                  struct mock_iommu_domain **mock)
327 {
328         struct iommufd_hw_pagetable *hwpt;
329         struct iommufd_object *obj;
330
331         obj = iommufd_get_object(ucmd->ictx, mockpt_id,
332                                  IOMMUFD_OBJ_HW_PAGETABLE);
333         if (IS_ERR(obj))
334                 return ERR_CAST(obj);
335         hwpt = container_of(obj, struct iommufd_hw_pagetable, obj);
336         if (hwpt->domain->ops != mock_ops.default_domain_ops) {
337                 iommufd_put_object(&hwpt->obj);
338                 return ERR_PTR(-EINVAL);
339         }
340         *mock = container_of(hwpt->domain, struct mock_iommu_domain, domain);
341         return hwpt;
342 }
343
/* Bus type wrapper for the mock devices. */
struct mock_bus_type {
        struct bus_type bus;
        /*
         * NOTE(review): not referenced in this chunk; presumably a bus
         * notifier registered elsewhere — confirm.
         */
        struct notifier_block nb;
};

/* The "iommufd_mock" bus that mock_dev_create() puts devices on. */
static struct mock_bus_type iommufd_mock_bus_type = {
        .bus = {
                .name = "iommufd_mock",
        },
};

/* Counter that mock_dev_create() uses to number "iommufd_mock%u" names. */
static atomic_t mock_dev_num;
356
357 static void mock_dev_release(struct device *dev)
358 {
359         struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);
360
361         atomic_dec(&mock_dev_num);
362         kfree(mdev);
363 }
364
365 static struct mock_dev *mock_dev_create(void)
366 {
367         struct mock_dev *mdev;
368         int rc;
369
370         mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
371         if (!mdev)
372                 return ERR_PTR(-ENOMEM);
373
374         device_initialize(&mdev->dev);
375         mdev->dev.release = mock_dev_release;
376         mdev->dev.bus = &iommufd_mock_bus_type.bus;
377
378         rc = dev_set_name(&mdev->dev, "iommufd_mock%u",
379                           atomic_inc_return(&mock_dev_num));
380         if (rc)
381                 goto err_put;
382
383         rc = device_add(&mdev->dev);
384         if (rc)
385                 goto err_put;
386         return mdev;
387
388 err_put:
389         put_device(&mdev->dev);
390         return ERR_PTR(rc);
391 }
392
393 static void mock_dev_destroy(struct mock_dev *mdev)
394 {
395         device_unregister(&mdev->dev);
396 }
397
398 bool iommufd_selftest_is_mock_dev(struct device *dev)
399 {
400         return dev->release == mock_dev_release;
401 }
402
403 /* Create an hw_pagetable with the mock domain so we can test the domain ops */
404 static int iommufd_test_mock_domain(struct iommufd_ucmd *ucmd,
405                                     struct iommu_test_cmd *cmd)
406 {
407         struct iommufd_device *idev;
408         struct selftest_obj *sobj;
409         u32 pt_id = cmd->id;
410         u32 idev_id;
411         int rc;
412
413         sobj = iommufd_object_alloc(ucmd->ictx, sobj, IOMMUFD_OBJ_SELFTEST);
414         if (IS_ERR(sobj))
415                 return PTR_ERR(sobj);
416
417         sobj->idev.ictx = ucmd->ictx;
418         sobj->type = TYPE_IDEV;
419
420         sobj->idev.mock_dev = mock_dev_create();
421         if (IS_ERR(sobj->idev.mock_dev)) {
422                 rc = PTR_ERR(sobj->idev.mock_dev);
423                 goto out_sobj;
424         }
425
426         idev = iommufd_device_bind(ucmd->ictx, &sobj->idev.mock_dev->dev,
427                                    &idev_id);
428         if (IS_ERR(idev)) {
429                 rc = PTR_ERR(idev);
430                 goto out_mdev;
431         }
432         sobj->idev.idev = idev;
433
434         rc = iommufd_device_attach(idev, &pt_id);
435         if (rc)
436                 goto out_unbind;
437
438         /* Userspace must destroy the device_id to destroy the object */
439         cmd->mock_domain.out_hwpt_id = pt_id;
440         cmd->mock_domain.out_stdev_id = sobj->obj.id;
441         cmd->mock_domain.out_idev_id = idev_id;
442         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
443         if (rc)
444                 goto out_detach;
445         iommufd_object_finalize(ucmd->ictx, &sobj->obj);
446         return 0;
447
448 out_detach:
449         iommufd_device_detach(idev);
450 out_unbind:
451         iommufd_device_unbind(idev);
452 out_mdev:
453         mock_dev_destroy(sobj->idev.mock_dev);
454 out_sobj:
455         iommufd_object_abort(ucmd->ictx, &sobj->obj);
456         return rc;
457 }
458
459 /* Replace the mock domain with a manually allocated hw_pagetable */
460 static int iommufd_test_mock_domain_replace(struct iommufd_ucmd *ucmd,
461                                             unsigned int device_id, u32 pt_id,
462                                             struct iommu_test_cmd *cmd)
463 {
464         struct iommufd_object *dev_obj;
465         struct selftest_obj *sobj;
466         int rc;
467
468         /*
469          * Prefer to use the OBJ_SELFTEST because the destroy_rwsem will ensure
470          * it doesn't race with detach, which is not allowed.
471          */
472         dev_obj =
473                 iommufd_get_object(ucmd->ictx, device_id, IOMMUFD_OBJ_SELFTEST);
474         if (IS_ERR(dev_obj))
475                 return PTR_ERR(dev_obj);
476
477         sobj = container_of(dev_obj, struct selftest_obj, obj);
478         if (sobj->type != TYPE_IDEV) {
479                 rc = -EINVAL;
480                 goto out_dev_obj;
481         }
482
483         rc = iommufd_device_replace(sobj->idev.idev, &pt_id);
484         if (rc)
485                 goto out_dev_obj;
486
487         cmd->mock_domain_replace.pt_id = pt_id;
488         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
489
490 out_dev_obj:
491         iommufd_put_object(dev_obj);
492         return rc;
493 }
494
495 /* Add an additional reserved IOVA to the IOAS */
496 static int iommufd_test_add_reserved(struct iommufd_ucmd *ucmd,
497                                      unsigned int mockpt_id,
498                                      unsigned long start, size_t length)
499 {
500         struct iommufd_ioas *ioas;
501         int rc;
502
503         ioas = iommufd_get_ioas(ucmd->ictx, mockpt_id);
504         if (IS_ERR(ioas))
505                 return PTR_ERR(ioas);
506         down_write(&ioas->iopt.iova_rwsem);
507         rc = iopt_reserve_iova(&ioas->iopt, start, start + length - 1, NULL);
508         up_write(&ioas->iopt.iova_rwsem);
509         iommufd_put_object(&ioas->obj);
510         return rc;
511 }
512
513 /* Check that every pfn under each iova matches the pfn under a user VA */
514 static int iommufd_test_md_check_pa(struct iommufd_ucmd *ucmd,
515                                     unsigned int mockpt_id, unsigned long iova,
516                                     size_t length, void __user *uptr)
517 {
518         struct iommufd_hw_pagetable *hwpt;
519         struct mock_iommu_domain *mock;
520         uintptr_t end;
521         int rc;
522
523         if (iova % MOCK_IO_PAGE_SIZE || length % MOCK_IO_PAGE_SIZE ||
524             (uintptr_t)uptr % MOCK_IO_PAGE_SIZE ||
525             check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
526                 return -EINVAL;
527
528         hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
529         if (IS_ERR(hwpt))
530                 return PTR_ERR(hwpt);
531
532         for (; length; length -= MOCK_IO_PAGE_SIZE) {
533                 struct page *pages[1];
534                 unsigned long pfn;
535                 long npages;
536                 void *ent;
537
538                 npages = get_user_pages_fast((uintptr_t)uptr & PAGE_MASK, 1, 0,
539                                              pages);
540                 if (npages < 0) {
541                         rc = npages;
542                         goto out_put;
543                 }
544                 if (WARN_ON(npages != 1)) {
545                         rc = -EFAULT;
546                         goto out_put;
547                 }
548                 pfn = page_to_pfn(pages[0]);
549                 put_page(pages[0]);
550
551                 ent = xa_load(&mock->pfns, iova / MOCK_IO_PAGE_SIZE);
552                 if (!ent ||
553                     (xa_to_value(ent) & MOCK_PFN_MASK) * MOCK_IO_PAGE_SIZE !=
554                             pfn * PAGE_SIZE + ((uintptr_t)uptr % PAGE_SIZE)) {
555                         rc = -EINVAL;
556                         goto out_put;
557                 }
558                 iova += MOCK_IO_PAGE_SIZE;
559                 uptr += MOCK_IO_PAGE_SIZE;
560         }
561         rc = 0;
562
563 out_put:
564         iommufd_put_object(&hwpt->obj);
565         return rc;
566 }
567
568 /* Check that the page ref count matches, to look for missing pin/unpins */
569 static int iommufd_test_md_check_refs(struct iommufd_ucmd *ucmd,
570                                       void __user *uptr, size_t length,
571                                       unsigned int refs)
572 {
573         uintptr_t end;
574
575         if (length % PAGE_SIZE || (uintptr_t)uptr % PAGE_SIZE ||
576             check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
577                 return -EINVAL;
578
579         for (; length; length -= PAGE_SIZE) {
580                 struct page *pages[1];
581                 long npages;
582
583                 npages = get_user_pages_fast((uintptr_t)uptr, 1, 0, pages);
584                 if (npages < 0)
585                         return npages;
586                 if (WARN_ON(npages != 1))
587                         return -EFAULT;
588                 if (!PageCompound(pages[0])) {
589                         unsigned int count;
590
591                         count = page_ref_count(pages[0]);
592                         if (count / GUP_PIN_COUNTING_BIAS != refs) {
593                                 put_page(pages[0]);
594                                 return -EIO;
595                         }
596                 }
597                 put_page(pages[0]);
598                 uptr += PAGE_SIZE;
599         }
600         return 0;
601 }
602
/*
 * State behind an "[iommufd_test_staccess]" anon file: one iommufd_access and
 * the page ranges currently pinned through it.
 */
struct selftest_access {
        /* NULL until iommufd_test_create_access() has fully succeeded. */
        struct iommufd_access *access;
        struct file *file;
        /* Guards @items and @next_id and serializes pinning against unmap. */
        struct mutex lock;
        /* struct selftest_access_item entries, linked via items_elm. */
        struct list_head items;
        unsigned int next_id;
        /* NOTE(review): not referenced in this chunk — confirm it is still used. */
        bool destroying;
};

/* One pinned range created by iommufd_test_access_pages(). */
struct selftest_access_item {
        struct list_head items_elm;
        unsigned long iova;
        size_t length;
        /* ID reported to userspace as access_pages.out_access_pages_id. */
        unsigned int id;
};

/* Forward declaration: iommufd_access_get() compares f_op against it. */
static const struct file_operations iommfd_test_staccess_fops;
620
621 static struct selftest_access *iommufd_access_get(int fd)
622 {
623         struct file *file;
624
625         file = fget(fd);
626         if (!file)
627                 return ERR_PTR(-EBADFD);
628
629         if (file->f_op != &iommfd_test_staccess_fops) {
630                 fput(file);
631                 return ERR_PTR(-EBADFD);
632         }
633         return file->private_data;
634 }
635
636 static void iommufd_test_access_unmap(void *data, unsigned long iova,
637                                       unsigned long length)
638 {
639         unsigned long iova_last = iova + length - 1;
640         struct selftest_access *staccess = data;
641         struct selftest_access_item *item;
642         struct selftest_access_item *tmp;
643
644         mutex_lock(&staccess->lock);
645         list_for_each_entry_safe(item, tmp, &staccess->items, items_elm) {
646                 if (iova > item->iova + item->length - 1 ||
647                     iova_last < item->iova)
648                         continue;
649                 list_del(&item->items_elm);
650                 iommufd_access_unpin_pages(staccess->access, item->iova,
651                                            item->length);
652                 kfree(item);
653         }
654         mutex_unlock(&staccess->lock);
655 }
656
657 static int iommufd_test_access_item_destroy(struct iommufd_ucmd *ucmd,
658                                             unsigned int access_id,
659                                             unsigned int item_id)
660 {
661         struct selftest_access_item *item;
662         struct selftest_access *staccess;
663
664         staccess = iommufd_access_get(access_id);
665         if (IS_ERR(staccess))
666                 return PTR_ERR(staccess);
667
668         mutex_lock(&staccess->lock);
669         list_for_each_entry(item, &staccess->items, items_elm) {
670                 if (item->id == item_id) {
671                         list_del(&item->items_elm);
672                         iommufd_access_unpin_pages(staccess->access, item->iova,
673                                                    item->length);
674                         mutex_unlock(&staccess->lock);
675                         kfree(item);
676                         fput(staccess->file);
677                         return 0;
678                 }
679         }
680         mutex_unlock(&staccess->lock);
681         fput(staccess->file);
682         return -ENOENT;
683 }
684
685 static int iommufd_test_staccess_release(struct inode *inode,
686                                          struct file *filep)
687 {
688         struct selftest_access *staccess = filep->private_data;
689
690         if (staccess->access) {
691                 iommufd_test_access_unmap(staccess, 0, ULONG_MAX);
692                 iommufd_access_destroy(staccess->access);
693         }
694         mutex_destroy(&staccess->lock);
695         kfree(staccess);
696         return 0;
697 }
698
/* Access ops selected by MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES. */
static const struct iommufd_access_ops selftest_access_ops_pin = {
        .needs_pin_pages = 1,
        .unmap = iommufd_test_access_unmap,
};

/* Access ops for accesses created without the pin-pages flag. */
static const struct iommufd_access_ops selftest_access_ops = {
        .unmap = iommufd_test_access_unmap,
};

/* File operations of the "[iommufd_test_staccess]" anon file. */
static const struct file_operations iommfd_test_staccess_fops = {
        .release = iommufd_test_staccess_release,
};
711
712 static struct selftest_access *iommufd_test_alloc_access(void)
713 {
714         struct selftest_access *staccess;
715         struct file *filep;
716
717         staccess = kzalloc(sizeof(*staccess), GFP_KERNEL_ACCOUNT);
718         if (!staccess)
719                 return ERR_PTR(-ENOMEM);
720         INIT_LIST_HEAD(&staccess->items);
721         mutex_init(&staccess->lock);
722
723         filep = anon_inode_getfile("[iommufd_test_staccess]",
724                                    &iommfd_test_staccess_fops, staccess,
725                                    O_RDWR);
726         if (IS_ERR(filep)) {
727                 kfree(staccess);
728                 return ERR_CAST(filep);
729         }
730         staccess->file = filep;
731         return staccess;
732 }
733
734 static int iommufd_test_create_access(struct iommufd_ucmd *ucmd,
735                                       unsigned int ioas_id, unsigned int flags)
736 {
737         struct iommu_test_cmd *cmd = ucmd->cmd;
738         struct selftest_access *staccess;
739         struct iommufd_access *access;
740         u32 id;
741         int fdno;
742         int rc;
743
744         if (flags & ~MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES)
745                 return -EOPNOTSUPP;
746
747         staccess = iommufd_test_alloc_access();
748         if (IS_ERR(staccess))
749                 return PTR_ERR(staccess);
750
751         fdno = get_unused_fd_flags(O_CLOEXEC);
752         if (fdno < 0) {
753                 rc = -ENOMEM;
754                 goto out_free_staccess;
755         }
756
757         access = iommufd_access_create(
758                 ucmd->ictx,
759                 (flags & MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES) ?
760                         &selftest_access_ops_pin :
761                         &selftest_access_ops,
762                 staccess, &id);
763         if (IS_ERR(access)) {
764                 rc = PTR_ERR(access);
765                 goto out_put_fdno;
766         }
767         rc = iommufd_access_attach(access, ioas_id);
768         if (rc)
769                 goto out_destroy;
770         cmd->create_access.out_access_fd = fdno;
771         rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
772         if (rc)
773                 goto out_destroy;
774
775         staccess->access = access;
776         fd_install(fdno, staccess->file);
777         return 0;
778
779 out_destroy:
780         iommufd_access_destroy(access);
781 out_put_fdno:
782         put_unused_fd(fdno);
783 out_free_staccess:
784         fput(staccess->file);
785         return rc;
786 }
787
788 static int iommufd_test_access_replace_ioas(struct iommufd_ucmd *ucmd,
789                                             unsigned int access_id,
790                                             unsigned int ioas_id)
791 {
792         struct selftest_access *staccess;
793         int rc;
794
795         staccess = iommufd_access_get(access_id);
796         if (IS_ERR(staccess))
797                 return PTR_ERR(staccess);
798
799         rc = iommufd_access_replace(staccess->access, ioas_id);
800         fput(staccess->file);
801         return rc;
802 }
803
804 /* Check that the pages in a page array match the pages in the user VA */
805 static int iommufd_test_check_pages(void __user *uptr, struct page **pages,
806                                     size_t npages)
807 {
808         for (; npages; npages--) {
809                 struct page *tmp_pages[1];
810                 long rc;
811
812                 rc = get_user_pages_fast((uintptr_t)uptr, 1, 0, tmp_pages);
813                 if (rc < 0)
814                         return rc;
815                 if (WARN_ON(rc != 1))
816                         return -EFAULT;
817                 put_page(tmp_pages[0]);
818                 if (tmp_pages[0] != *pages)
819                         return -EBADE;
820                 pages++;
821                 uptr += PAGE_SIZE;
822         }
823         return 0;
824 }
825
/*
 * Pin [iova, iova + length) through the page-pinning access behind fd
 * @access_id and optionally verify the pinned pages against the user buffer
 * @uptr. The range is recorded as a new selftest_access_item whose ID is
 * returned in cmd->access_pages.out_access_pages_id. The pin stays until the
 * item is destroyed or an unmap callback covers it.
 *
 * @flags may contain:
 *  - MOCK_FLAGS_ACCESS_WRITE: passed to the pin;
 *  - MOCK_FLAGS_ACCESS_SYZ: iova is decoded from the syzkaller encoding in
 *    cmd->access_pages.iova instead.
 *
 * Returns 0 or a negative errno. On failure nothing stays pinned.
 */
static int iommufd_test_access_pages(struct iommufd_ucmd *ucmd,
                                     unsigned int access_id, unsigned long iova,
                                     size_t length, void __user *uptr,
                                     u32 flags)
{
        struct iommu_test_cmd *cmd = ucmd->cmd;
        struct selftest_access_item *item;
        struct selftest_access *staccess;
        struct page **pages;
        size_t npages;
        int rc;

        /* Prevent syzkaller from triggering a WARN_ON in kvcalloc() below */
        if (length > 16*1024*1024)
                return -ENOMEM;

        if (flags & ~(MOCK_FLAGS_ACCESS_WRITE | MOCK_FLAGS_ACCESS_SYZ))
                return -EOPNOTSUPP;

        staccess = iommufd_access_get(access_id);
        if (IS_ERR(staccess))
                return PTR_ERR(staccess);

        /* Only accesses created with the pinning ops may pin pages. */
        if (staccess->access->ops != &selftest_access_ops_pin) {
                rc = -EOPNOTSUPP;
                goto out_put;
        }

        /*
         * NOTE(review): access->ioas is read here with no lock held by this
         * function; confirm it cannot change (e.g. through
         * iommufd_test_access_replace_ioas()) or become NULL concurrently.
         */
        if (flags & MOCK_FLAGS_ACCESS_SYZ)
                iova = iommufd_test_syz_conv_iova(&staccess->access->ioas->iopt,
                                        &cmd->access_pages.iova);

        /* Number of CPU pages touched by [iova, iova + length). */
        npages = (ALIGN(iova + length, PAGE_SIZE) -
                  ALIGN_DOWN(iova, PAGE_SIZE)) /
                 PAGE_SIZE;
        pages = kvcalloc(npages, sizeof(*pages), GFP_KERNEL_ACCOUNT);
        if (!pages) {
                rc = -ENOMEM;
                goto out_put;
        }

        /*
         * Drivers will need to think very carefully about this locking. The
         * core code can do multiple unmaps instantaneously after
         * iommufd_access_pin_pages() and *all* the unmaps must not return until
         * the range is unpinned. This simple implementation puts a global lock
         * around the pin, which may not suit drivers that want this to be a
         * performance path. drivers that get this wrong will trigger WARN_ON
         * races and cause EDEADLOCK failures to userspace.
         */
        mutex_lock(&staccess->lock);
        rc = iommufd_access_pin_pages(staccess->access, iova, length, pages,
                                      flags & MOCK_FLAGS_ACCESS_WRITE);
        if (rc)
                goto out_unlock;

        /* For syzkaller allow uptr to be NULL to skip this check */
        if (uptr) {
                /*
                 * Rewind uptr by iova's offset within its CPU page, so the
                 * user address lines up with pages[0].
                 */
                rc = iommufd_test_check_pages(
                        uptr - (iova - ALIGN_DOWN(iova, PAGE_SIZE)), pages,
                        npages);
                if (rc)
                        goto out_unaccess;
        }

        item = kzalloc(sizeof(*item), GFP_KERNEL_ACCOUNT);
        if (!item) {
                rc = -ENOMEM;
                goto out_unaccess;
        }

        /* Track the pin so item_destroy/unmap/release can undo it later. */
        item->iova = iova;
        item->length = length;
        item->id = staccess->next_id++;
        list_add_tail(&item->items_elm, &staccess->items);

        cmd->access_pages.out_access_pages_id = item->id;
        rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
        if (rc)
                goto out_free_item;
        /* Success: only the lock, the page array and the fd ref are dropped. */
        goto out_unlock;

out_free_item:
        list_del(&item->items_elm);
        kfree(item);
out_unaccess:
        iommufd_access_unpin_pages(staccess->access, iova, length);
out_unlock:
        mutex_unlock(&staccess->lock);
        kvfree(pages);
out_put:
        fput(staccess->file);
        return rc;
}
920
921 static int iommufd_test_access_rw(struct iommufd_ucmd *ucmd,
922                                   unsigned int access_id, unsigned long iova,
923                                   size_t length, void __user *ubuf,
924                                   unsigned int flags)
925 {
926         struct iommu_test_cmd *cmd = ucmd->cmd;
927         struct selftest_access *staccess;
928         void *tmp;
929         int rc;
930
931         /* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
932         if (length > 16*1024*1024)
933                 return -ENOMEM;
934
935         if (flags & ~(MOCK_ACCESS_RW_WRITE | MOCK_ACCESS_RW_SLOW_PATH |
936                       MOCK_FLAGS_ACCESS_SYZ))
937                 return -EOPNOTSUPP;
938
939         staccess = iommufd_access_get(access_id);
940         if (IS_ERR(staccess))
941                 return PTR_ERR(staccess);
942
943         tmp = kvzalloc(length, GFP_KERNEL_ACCOUNT);
944         if (!tmp) {
945                 rc = -ENOMEM;
946                 goto out_put;
947         }
948
949         if (flags & MOCK_ACCESS_RW_WRITE) {
950                 if (copy_from_user(tmp, ubuf, length)) {
951                         rc = -EFAULT;
952                         goto out_free;
953                 }
954         }
955
956         if (flags & MOCK_FLAGS_ACCESS_SYZ)
957                 iova = iommufd_test_syz_conv_iova(&staccess->access->ioas->iopt,
958                                         &cmd->access_rw.iova);
959
960         rc = iommufd_access_rw(staccess->access, iova, tmp, length, flags);
961         if (rc)
962                 goto out_free;
963         if (!(flags & MOCK_ACCESS_RW_WRITE)) {
964                 if (copy_to_user(ubuf, tmp, length)) {
965                         rc = -EFAULT;
966                         goto out_free;
967                 }
968         }
969
970 out_free:
971         kvfree(tmp);
972 out_put:
973         fput(staccess->file);
974         return rc;
975 }
976 static_assert((unsigned int)MOCK_ACCESS_RW_WRITE == IOMMUFD_ACCESS_RW_WRITE);
977 static_assert((unsigned int)MOCK_ACCESS_RW_SLOW_PATH ==
978               __IOMMUFD_ACCESS_RW_SLOW_PATH);
979
980 void iommufd_selftest_destroy(struct iommufd_object *obj)
981 {
982         struct selftest_obj *sobj = container_of(obj, struct selftest_obj, obj);
983
984         switch (sobj->type) {
985         case TYPE_IDEV:
986                 iommufd_device_detach(sobj->idev.idev);
987                 iommufd_device_unbind(sobj->idev.idev);
988                 mock_dev_destroy(sobj->idev.mock_dev);
989                 break;
990         }
991 }
992
993 int iommufd_test(struct iommufd_ucmd *ucmd)
994 {
995         struct iommu_test_cmd *cmd = ucmd->cmd;
996
997         switch (cmd->op) {
998         case IOMMU_TEST_OP_ADD_RESERVED:
999                 return iommufd_test_add_reserved(ucmd, cmd->id,
1000                                                  cmd->add_reserved.start,
1001                                                  cmd->add_reserved.length);
1002         case IOMMU_TEST_OP_MOCK_DOMAIN:
1003                 return iommufd_test_mock_domain(ucmd, cmd);
1004         case IOMMU_TEST_OP_MOCK_DOMAIN_REPLACE:
1005                 return iommufd_test_mock_domain_replace(
1006                         ucmd, cmd->id, cmd->mock_domain_replace.pt_id, cmd);
1007         case IOMMU_TEST_OP_MD_CHECK_MAP:
1008                 return iommufd_test_md_check_pa(
1009                         ucmd, cmd->id, cmd->check_map.iova,
1010                         cmd->check_map.length,
1011                         u64_to_user_ptr(cmd->check_map.uptr));
1012         case IOMMU_TEST_OP_MD_CHECK_REFS:
1013                 return iommufd_test_md_check_refs(
1014                         ucmd, u64_to_user_ptr(cmd->check_refs.uptr),
1015                         cmd->check_refs.length, cmd->check_refs.refs);
1016         case IOMMU_TEST_OP_CREATE_ACCESS:
1017                 return iommufd_test_create_access(ucmd, cmd->id,
1018                                                   cmd->create_access.flags);
1019         case IOMMU_TEST_OP_ACCESS_REPLACE_IOAS:
1020                 return iommufd_test_access_replace_ioas(
1021                         ucmd, cmd->id, cmd->access_replace_ioas.ioas_id);
1022         case IOMMU_TEST_OP_ACCESS_PAGES:
1023                 return iommufd_test_access_pages(
1024                         ucmd, cmd->id, cmd->access_pages.iova,
1025                         cmd->access_pages.length,
1026                         u64_to_user_ptr(cmd->access_pages.uptr),
1027                         cmd->access_pages.flags);
1028         case IOMMU_TEST_OP_ACCESS_RW:
1029                 return iommufd_test_access_rw(
1030                         ucmd, cmd->id, cmd->access_rw.iova,
1031                         cmd->access_rw.length,
1032                         u64_to_user_ptr(cmd->access_rw.uptr),
1033                         cmd->access_rw.flags);
1034         case IOMMU_TEST_OP_DESTROY_ACCESS_PAGES:
1035                 return iommufd_test_access_item_destroy(
1036                         ucmd, cmd->id, cmd->destroy_access_pages.access_pages_id);
1037         case IOMMU_TEST_OP_SET_TEMP_MEMORY_LIMIT:
1038                 /* Protect _batch_init(), can not be less than elmsz */
1039                 if (cmd->memory_limit.limit <
1040                     sizeof(unsigned long) + sizeof(u32))
1041                         return -EINVAL;
1042                 iommufd_test_memory_limit = cmd->memory_limit.limit;
1043                 return 0;
1044         default:
1045                 return -EOPNOTSUPP;
1046         }
1047 }
1048
1049 bool iommufd_should_fail(void)
1050 {
1051         return should_fail(&fail_iommufd, 1);
1052 }
1053
1054 int __init iommufd_test_init(void)
1055 {
1056         struct platform_device_info pdevinfo = {
1057                 .name = "iommufd_selftest_iommu",
1058         };
1059         int rc;
1060
1061         dbgfs_root =
1062                 fault_create_debugfs_attr("fail_iommufd", NULL, &fail_iommufd);
1063
1064         selftest_iommu_dev = platform_device_register_full(&pdevinfo);
1065         if (IS_ERR(selftest_iommu_dev)) {
1066                 rc = PTR_ERR(selftest_iommu_dev);
1067                 goto err_dbgfs;
1068         }
1069
1070         rc = bus_register(&iommufd_mock_bus_type.bus);
1071         if (rc)
1072                 goto err_platform;
1073
1074         rc = iommu_device_sysfs_add(&mock_iommu_device,
1075                                     &selftest_iommu_dev->dev, NULL, "%s",
1076                                     dev_name(&selftest_iommu_dev->dev));
1077         if (rc)
1078                 goto err_bus;
1079
1080         rc = iommu_device_register_bus(&mock_iommu_device, &mock_ops,
1081                                   &iommufd_mock_bus_type.bus,
1082                                   &iommufd_mock_bus_type.nb);
1083         if (rc)
1084                 goto err_sysfs;
1085         return 0;
1086
1087 err_sysfs:
1088         iommu_device_sysfs_remove(&mock_iommu_device);
1089 err_bus:
1090         bus_unregister(&iommufd_mock_bus_type.bus);
1091 err_platform:
1092         platform_device_unregister(selftest_iommu_dev);
1093 err_dbgfs:
1094         debugfs_remove_recursive(dbgfs_root);
1095         return rc;
1096 }
1097
1098 void iommufd_test_exit(void)
1099 {
1100         iommu_device_sysfs_remove(&mock_iommu_device);
1101         iommu_device_unregister_bus(&mock_iommu_device,
1102                                     &iommufd_mock_bus_type.bus,
1103                                     &iommufd_mock_bus_type.nb);
1104         bus_unregister(&iommufd_mock_bus_type.bus);
1105         platform_device_unregister(selftest_iommu_dev);
1106         debugfs_remove_recursive(dbgfs_root);
1107 }