vhost_vdpa: fix the crash in unmap a large memory
[platform/kernel/linux-starfive.git] / drivers / vhost / vdpa.c
index b08e07f..ec32f78 100644 (file)
@@ -66,8 +66,8 @@ static DEFINE_IDA(vhost_vdpa_ida);
 static dev_t vhost_vdpa_major;
 
 static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
-                                  struct vhost_iotlb *iotlb,
-                                  u64 start, u64 last);
+                                  struct vhost_iotlb *iotlb, u64 start,
+                                  u64 last, u32 asid);
 
 static inline u32 iotlb_to_asid(struct vhost_iotlb *iotlb)
 {
@@ -139,7 +139,7 @@ static int vhost_vdpa_remove_as(struct vhost_vdpa *v, u32 asid)
                return -EINVAL;
 
        hlist_del(&as->hash_link);
-       vhost_vdpa_iotlb_unmap(v, &as->iotlb, 0ULL, 0ULL - 1);
+       vhost_vdpa_iotlb_unmap(v, &as->iotlb, 0ULL, 0ULL - 1, asid);
        kfree(as);
 
        return 0;
@@ -687,10 +687,20 @@ static long vhost_vdpa_unlocked_ioctl(struct file *filep,
        mutex_unlock(&d->mutex);
        return r;
 }
+static void vhost_vdpa_general_unmap(struct vhost_vdpa *v,
+                                    struct vhost_iotlb_map *map, u32 asid)
+{
+       struct vdpa_device *vdpa = v->vdpa;
+       const struct vdpa_config_ops *ops = vdpa->config;
+       if (ops->dma_map) {
+               ops->dma_unmap(vdpa, asid, map->start, map->size);
+       } else if (ops->set_map == NULL) {
+               iommu_unmap(v->domain, map->start, map->size);
+       }
+}
 
-static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v,
-                               struct vhost_iotlb *iotlb,
-                               u64 start, u64 last)
+static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
+                               u64 start, u64 last, u32 asid)
 {
        struct vhost_dev *dev = &v->vdev;
        struct vhost_iotlb_map *map;
@@ -707,13 +717,13 @@ static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v,
                        unpin_user_page(page);
                }
                atomic64_sub(PFN_DOWN(map->size), &dev->mm->pinned_vm);
+               vhost_vdpa_general_unmap(v, map, asid);
                vhost_iotlb_map_free(iotlb, map);
        }
 }
 
-static void vhost_vdpa_va_unmap(struct vhost_vdpa *v,
-                               struct vhost_iotlb *iotlb,
-                               u64 start, u64 last)
+static void vhost_vdpa_va_unmap(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
+                               u64 start, u64 last, u32 asid)
 {
        struct vhost_iotlb_map *map;
        struct vdpa_map_file *map_file;
@@ -722,20 +732,21 @@ static void vhost_vdpa_va_unmap(struct vhost_vdpa *v,
                map_file = (struct vdpa_map_file *)map->opaque;
                fput(map_file->file);
                kfree(map_file);
+               vhost_vdpa_general_unmap(v, map, asid);
                vhost_iotlb_map_free(iotlb, map);
        }
 }
 
 static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
-                                  struct vhost_iotlb *iotlb,
-                                  u64 start, u64 last)
+                                  struct vhost_iotlb *iotlb, u64 start,
+                                  u64 last, u32 asid)
 {
        struct vdpa_device *vdpa = v->vdpa;
 
        if (vdpa->use_va)
-               return vhost_vdpa_va_unmap(v, iotlb, start, last);
+               return vhost_vdpa_va_unmap(v, iotlb, start, last, asid);
 
-       return vhost_vdpa_pa_unmap(v, iotlb, start, last);
+       return vhost_vdpa_pa_unmap(v, iotlb, start, last, asid);
 }
 
 static int perm_to_iommu_flags(u32 perm)
@@ -802,17 +813,12 @@ static void vhost_vdpa_unmap(struct vhost_vdpa *v,
        const struct vdpa_config_ops *ops = vdpa->config;
        u32 asid = iotlb_to_asid(iotlb);
 
-       vhost_vdpa_iotlb_unmap(v, iotlb, iova, iova + size - 1);
+       vhost_vdpa_iotlb_unmap(v, iotlb, iova, iova + size - 1, asid);
 
-       if (ops->dma_map) {
-               ops->dma_unmap(vdpa, asid, iova, size);
-       } else if (ops->set_map) {
+       if (ops->set_map) {
                if (!v->in_batch)
                        ops->set_map(vdpa, asid, iotlb);
-       } else {
-               iommu_unmap(v->domain, iova, size);
        }
-
        /* If we are in the middle of batch processing, delay the free
         * of AS until BATCH_END.
         */