Merge tag 'vfio-v6.3-rc1' of https://github.com/awilliam/linux-vfio
[platform/kernel/linux-rpi.git] / drivers / vfio / vfio_iommu_type1.c
index a44ac3f..493c31d 100644 (file)
@@ -71,11 +71,9 @@ struct vfio_iommu {
        unsigned int            vaddr_invalid_count;
        uint64_t                pgsize_bitmap;
        uint64_t                num_non_pinned_groups;
-       wait_queue_head_t       vaddr_wait;
        bool                    v2;
        bool                    nesting;
        bool                    dirty_page_tracking;
-       bool                    container_open;
        struct list_head        emulated_iommu_groups;
 };
 
@@ -99,6 +97,8 @@ struct vfio_dma {
        struct task_struct      *task;
        struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
        unsigned long           *bitmap;
+       struct mm_struct        *mm;
+       size_t                  locked_vm;
 };
 
 struct vfio_batch {
@@ -151,8 +151,6 @@ struct vfio_regions {
 #define DIRTY_BITMAP_PAGES_MAX  ((u64)INT_MAX)
 #define DIRTY_BITMAP_SIZE_MAX   DIRTY_BITMAP_BYTES(DIRTY_BITMAP_PAGES_MAX)
 
-#define WAITED 1
-
 static int put_pfn(unsigned long pfn, int prot);
 
 static struct vfio_iommu_group*
@@ -411,6 +409,19 @@ static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
        return ret;
 }
 
+static int mm_lock_acct(struct task_struct *task, struct mm_struct *mm,
+                       bool lock_cap, long npage)
+{
+       int ret = mmap_write_lock_killable(mm);
+
+       if (ret)
+               return ret;
+
+       ret = __account_locked_vm(mm, abs(npage), npage > 0, task, lock_cap);
+       mmap_write_unlock(mm);
+       return ret;
+}
+
 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
 {
        struct mm_struct *mm;
@@ -419,16 +430,13 @@ static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
        if (!npage)
                return 0;
 
-       mm = async ? get_task_mm(dma->task) : dma->task->mm;
-       if (!mm)
+       mm = dma->mm;
+       if (async && !mmget_not_zero(mm))
                return -ESRCH; /* process exited */
 
-       ret = mmap_write_lock_killable(mm);
-       if (!ret) {
-               ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
-                                         dma->lock_cap);
-               mmap_write_unlock(mm);
-       }
+       ret = mm_lock_acct(dma->task, mm, dma->lock_cap, npage);
+       if (!ret)
+               dma->locked_vm += npage;
 
        if (async)
                mmput(mm);
@@ -594,61 +602,6 @@ done:
        return ret;
 }
 
-static int vfio_wait(struct vfio_iommu *iommu)
-{
-       DEFINE_WAIT(wait);
-
-       prepare_to_wait(&iommu->vaddr_wait, &wait, TASK_KILLABLE);
-       mutex_unlock(&iommu->lock);
-       schedule();
-       mutex_lock(&iommu->lock);
-       finish_wait(&iommu->vaddr_wait, &wait);
-       if (kthread_should_stop() || !iommu->container_open ||
-           fatal_signal_pending(current)) {
-               return -EFAULT;
-       }
-       return WAITED;
-}
-
-/*
- * Find dma struct and wait for its vaddr to be valid.  iommu lock is dropped
- * if the task waits, but is re-locked on return.  Return result in *dma_p.
- * Return 0 on success with no waiting, WAITED on success if waited, and -errno
- * on error.
- */
-static int vfio_find_dma_valid(struct vfio_iommu *iommu, dma_addr_t start,
-                              size_t size, struct vfio_dma **dma_p)
-{
-       int ret = 0;
-
-       do {
-               *dma_p = vfio_find_dma(iommu, start, size);
-               if (!*dma_p)
-                       return -EINVAL;
-               else if (!(*dma_p)->vaddr_invalid)
-                       return ret;
-               else
-                       ret = vfio_wait(iommu);
-       } while (ret == WAITED);
-
-       return ret;
-}
-
-/*
- * Wait for all vaddr in the dma_list to become valid.  iommu lock is dropped
- * if the task waits, but is re-locked on return.  Return 0 on success with no
- * waiting, WAITED on success if waited, and -errno on error.
- */
-static int vfio_wait_all_valid(struct vfio_iommu *iommu)
-{
-       int ret = 0;
-
-       while (iommu->vaddr_invalid_count && ret >= 0)
-               ret = vfio_wait(iommu);
-
-       return ret;
-}
-
 /*
  * Attempt to pin pages.  We really don't want to track all the pfns and
  * the iommu can only map chunks of consecutive pfns anyway, so get the
@@ -793,8 +746,8 @@ static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
        struct mm_struct *mm;
        int ret;
 
-       mm = get_task_mm(dma->task);
-       if (!mm)
+       mm = dma->mm;
+       if (!mmget_not_zero(mm))
                return -ENODEV;
 
        ret = vaddr_get_pfns(mm, vaddr, 1, dma->prot, pfn_base, pages);
@@ -804,7 +757,7 @@ static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
        ret = 0;
 
        if (do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
-               ret = vfio_lock_acct(dma, 1, true);
+               ret = vfio_lock_acct(dma, 1, false);
                if (ret) {
                        put_pfn(*pfn_base, dma->prot);
                        if (ret == -ENOMEM)
@@ -849,7 +802,6 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
        unsigned long remote_vaddr;
        struct vfio_dma *dma;
        bool do_accounting;
-       dma_addr_t iova;
 
        if (!iommu || !pages)
                return -EINVAL;
@@ -860,20 +812,10 @@ static int vfio_iommu_type1_pin_pages(void *iommu_data,
 
        mutex_lock(&iommu->lock);
 
-       /*
-        * Wait for all necessary vaddr's to be valid so they can be used in
-        * the main loop without dropping the lock, to avoid racing vs unmap.
-        */
-again:
-       if (iommu->vaddr_invalid_count) {
-               for (i = 0; i < npage; i++) {
-                       iova = user_iova + PAGE_SIZE * i;
-                       ret = vfio_find_dma_valid(iommu, iova, PAGE_SIZE, &dma);
-                       if (ret < 0)
-                               goto pin_done;
-                       if (ret == WAITED)
-                               goto again;
-               }
+       if (WARN_ONCE(iommu->vaddr_invalid_count,
+                     "vfio_pin_pages not allowed with VFIO_UPDATE_VADDR\n")) {
+               ret = -EBUSY;
+               goto pin_done;
        }
 
        /* Fail if no dma_umap notifier is registered */
@@ -891,6 +833,7 @@ again:
 
        for (i = 0; i < npage; i++) {
                unsigned long phys_pfn;
+               dma_addr_t iova;
                struct vfio_pfn *vpfn;
 
                iova = user_iova + PAGE_SIZE * i;
@@ -1173,11 +1116,10 @@ static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
        vfio_unmap_unpin(iommu, dma, true);
        vfio_unlink_dma(iommu, dma);
        put_task_struct(dma->task);
+       mmdrop(dma->mm);
        vfio_dma_bitmap_free(dma);
-       if (dma->vaddr_invalid) {
+       if (dma->vaddr_invalid)
                iommu->vaddr_invalid_count--;
-               wake_up_all(&iommu->vaddr_wait);
-       }
        kfree(dma);
        iommu->dma_avail++;
 }
@@ -1342,6 +1284,12 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
 
        mutex_lock(&iommu->lock);
 
+       /* Cannot update vaddr if mdev is present. */
+       if (invalidate_vaddr && !list_empty(&iommu->emulated_iommu_groups)) {
+               ret = -EBUSY;
+               goto unlock;
+       }
+
        pgshift = __ffs(iommu->pgsize_bitmap);
        pgsize = (size_t)1 << pgshift;
 
@@ -1566,6 +1514,38 @@ static bool vfio_iommu_iova_dma_valid(struct vfio_iommu *iommu,
        return list_empty(iova);
 }
 
+static int vfio_change_dma_owner(struct vfio_dma *dma)
+{
+       struct task_struct *task = current->group_leader;
+       struct mm_struct *mm = current->mm;
+       long npage = dma->locked_vm;
+       bool lock_cap;
+       int ret;
+
+       if (mm == dma->mm)
+               return 0;
+
+       lock_cap = capable(CAP_IPC_LOCK);
+       ret = mm_lock_acct(task, mm, lock_cap, npage);
+       if (ret)
+               return ret;
+
+       if (mmget_not_zero(dma->mm)) {
+               mm_lock_acct(dma->task, dma->mm, dma->lock_cap, -npage);
+               mmput(dma->mm);
+       }
+
+       if (dma->task != task) {
+               put_task_struct(dma->task);
+               dma->task = get_task_struct(task);
+       }
+       mmdrop(dma->mm);
+       dma->mm = mm;
+       mmgrab(dma->mm);
+       dma->lock_cap = lock_cap;
+       return 0;
+}
+
 static int vfio_dma_do_map(struct vfio_iommu *iommu,
                           struct vfio_iommu_type1_dma_map *map)
 {
@@ -1615,10 +1595,12 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu,
                           dma->size != size) {
                        ret = -EINVAL;
                } else {
+                       ret = vfio_change_dma_owner(dma);
+                       if (ret)
+                               goto out_unlock;
                        dma->vaddr = vaddr;
                        dma->vaddr_invalid = false;
                        iommu->vaddr_invalid_count--;
-                       wake_up_all(&iommu->vaddr_wait);
                }
                goto out_unlock;
        } else if (dma) {
@@ -1652,29 +1634,15 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu,
         * against the locked memory limit and we need to be able to do both
         * outside of this call path as pinning can be asynchronous via the
         * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
-        * task_struct and VM locked pages requires an mm_struct, however
-        * holding an indefinite mm reference is not recommended, therefore we
-        * only hold a reference to a task.  We could hold a reference to
-        * current, however QEMU uses this call path through vCPU threads,
-        * which can be killed resulting in a NULL mm and failure in the unmap
-        * path when called via a different thread.  Avoid this problem by
-        * using the group_leader as threads within the same group require
-        * both CLONE_THREAD and CLONE_VM and will therefore use the same
-        * mm_struct.
-        *
-        * Previously we also used the task for testing CAP_IPC_LOCK at the
-        * time of pinning and accounting, however has_capability() makes use
-        * of real_cred, a copy-on-write field, so we can't guarantee that it
-        * matches group_leader, or in fact that it might not change by the
-        * time it's evaluated.  If a process were to call MAP_DMA with
-        * CAP_IPC_LOCK but later drop it, it doesn't make sense that they
-        * possibly see different results for an iommu_mapped vfio_dma vs
-        * externally mapped.  Therefore track CAP_IPC_LOCK in vfio_dma at the
-        * time of calling MAP_DMA.
+        * task_struct. Save the group_leader so that all DMA tracking uses
+        * the same task, to make debugging easier.  VM locked pages requires
+        * an mm_struct, so grab the mm in case the task dies.
         */
        get_task_struct(current->group_leader);
        dma->task = current->group_leader;
        dma->lock_cap = capable(CAP_IPC_LOCK);
+       dma->mm = current->mm;
+       mmgrab(dma->mm);
 
        dma->pfn_list = RB_ROOT;
 
@@ -1707,10 +1675,6 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
        unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
        int ret;
 
-       ret = vfio_wait_all_valid(iommu);
-       if (ret < 0)
-               return ret;
-
        /* Arbitrarily pick the first domain in the list for lookups */
        if (!list_empty(&iommu->domain_list))
                d = list_first_entry(&iommu->domain_list,
@@ -2188,11 +2152,16 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
        struct iommu_domain_geometry *geo;
        LIST_HEAD(iova_copy);
        LIST_HEAD(group_resv_regions);
-       int ret = -EINVAL;
+       int ret = -EBUSY;
 
        mutex_lock(&iommu->lock);
 
+       /* Attach could require pinning, so disallow while vaddr is invalid. */
+       if (iommu->vaddr_invalid_count)
+               goto out_unlock;
+
        /* Check for duplicates */
+       ret = -EINVAL;
        if (vfio_iommu_find_iommu_group(iommu, iommu_group))
                goto out_unlock;
 
@@ -2592,11 +2561,9 @@ static void *vfio_iommu_type1_open(unsigned long arg)
        INIT_LIST_HEAD(&iommu->iova_list);
        iommu->dma_list = RB_ROOT;
        iommu->dma_avail = dma_entry_limit;
-       iommu->container_open = true;
        mutex_init(&iommu->lock);
        mutex_init(&iommu->device_list_lock);
        INIT_LIST_HEAD(&iommu->device_list);
-       init_waitqueue_head(&iommu->vaddr_wait);
        iommu->pgsize_bitmap = PAGE_MASK;
        INIT_LIST_HEAD(&iommu->emulated_iommu_groups);
 
@@ -2660,6 +2627,16 @@ static int vfio_domains_have_enforce_cache_coherency(struct vfio_iommu *iommu)
        return ret;
 }
 
+static bool vfio_iommu_has_emulated(struct vfio_iommu *iommu)
+{
+       bool ret;
+
+       mutex_lock(&iommu->lock);
+       ret = !list_empty(&iommu->emulated_iommu_groups);
+       mutex_unlock(&iommu->lock);
+       return ret;
+}
+
 static int vfio_iommu_type1_check_extension(struct vfio_iommu *iommu,
                                            unsigned long arg)
 {
@@ -2668,8 +2645,13 @@ static int vfio_iommu_type1_check_extension(struct vfio_iommu *iommu,
        case VFIO_TYPE1v2_IOMMU:
        case VFIO_TYPE1_NESTING_IOMMU:
        case VFIO_UNMAP_ALL:
-       case VFIO_UPDATE_VADDR:
                return 1;
+       case VFIO_UPDATE_VADDR:
+               /*
+                * Disable this feature if mdevs are present.  They cannot
+                * safely pin/unpin/rw while vaddrs are being updated.
+                */
+               return iommu && !vfio_iommu_has_emulated(iommu);
        case VFIO_DMA_CC_IOMMU:
                if (!iommu)
                        return 0;
@@ -3078,21 +3060,19 @@ static int vfio_iommu_type1_dma_rw_chunk(struct vfio_iommu *iommu,
        struct vfio_dma *dma;
        bool kthread = current->mm == NULL;
        size_t offset;
-       int ret;
 
        *copied = 0;
 
-       ret = vfio_find_dma_valid(iommu, user_iova, 1, &dma);
-       if (ret < 0)
-               return ret;
+       dma = vfio_find_dma(iommu, user_iova, 1);
+       if (!dma)
+               return -EINVAL;
 
        if ((write && !(dma->prot & IOMMU_WRITE)) ||
                        !(dma->prot & IOMMU_READ))
                return -EPERM;
 
-       mm = get_task_mm(dma->task);
-
-       if (!mm)
+       mm = dma->mm;
+       if (!mmget_not_zero(mm))
                return -EPERM;
 
        if (kthread)
@@ -3138,6 +3118,13 @@ static int vfio_iommu_type1_dma_rw(void *iommu_data, dma_addr_t user_iova,
        size_t done;
 
        mutex_lock(&iommu->lock);
+
+       if (WARN_ONCE(iommu->vaddr_invalid_count,
+                     "vfio_dma_rw not allowed with VFIO_UPDATE_VADDR\n")) {
+               ret = -EBUSY;
+               goto out;
+       }
+
        while (count > 0) {
                ret = vfio_iommu_type1_dma_rw_chunk(iommu, user_iova, data,
                                                    count, write, &done);
@@ -3149,6 +3136,7 @@ static int vfio_iommu_type1_dma_rw(void *iommu_data, dma_addr_t user_iova,
                user_iova += done;
        }
 
+out:
        mutex_unlock(&iommu->lock);
        return ret;
 }
@@ -3176,19 +3164,6 @@ vfio_iommu_type1_group_iommu_domain(void *iommu_data,
        return domain;
 }
 
-static void vfio_iommu_type1_notify(void *iommu_data,
-                                   enum vfio_iommu_notify_type event)
-{
-       struct vfio_iommu *iommu = iommu_data;
-
-       if (event != VFIO_IOMMU_CONTAINER_CLOSE)
-               return;
-       mutex_lock(&iommu->lock);
-       iommu->container_open = false;
-       mutex_unlock(&iommu->lock);
-       wake_up_all(&iommu->vaddr_wait);
-}
-
 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
        .name                   = "vfio-iommu-type1",
        .owner                  = THIS_MODULE,
@@ -3203,7 +3178,6 @@ static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
        .unregister_device      = vfio_iommu_type1_unregister_device,
        .dma_rw                 = vfio_iommu_type1_dma_rw,
        .group_iommu_domain     = vfio_iommu_type1_group_iommu_domain,
-       .notify                 = vfio_iommu_type1_notify,
 };
 
 static int __init vfio_iommu_type1_init(void)