Merge tag 'v6.3-p2' of git://git.kernel.org/pub/scm/linux/kernel/git/herbert/crypto-2.6
[platform/kernel/linux-starfive.git] / drivers / vfio / group.c
index e166ad7..27d5ba7 100644 (file)
@@ -140,7 +140,7 @@ static int vfio_group_ioctl_set_container(struct vfio_group *group,
                        ret = iommufd_vfio_compat_ioas_create(iommufd);
 
                if (ret) {
-                       iommufd_ctx_put(group->iommufd);
+                       iommufd_ctx_put(iommufd);
                        goto out_unlock;
                }
 
@@ -157,6 +157,18 @@ out_unlock:
        return ret;
 }
 
+static void vfio_device_group_get_kvm_safe(struct vfio_device *device)
+{
+       spin_lock(&device->group->kvm_ref_lock);
+       if (!device->group->kvm)
+               goto unlock;
+
+       _vfio_device_get_kvm_safe(device, device->group->kvm);
+
+unlock:
+       spin_unlock(&device->group->kvm_ref_lock);
+}
+
 static int vfio_device_group_open(struct vfio_device *device)
 {
        int ret;
@@ -167,13 +179,23 @@ static int vfio_device_group_open(struct vfio_device *device)
                goto out_unlock;
        }
 
+       mutex_lock(&device->dev_set->lock);
+
        /*
-        * Here we pass the KVM pointer with the group under the lock.  If the
-        * device driver will use it, it must obtain a reference and release it
-        * during close_device.
+        * Before the first device open, get the KVM pointer currently
+        * associated with the group (if there is one) and obtain a reference
+        * now that will be held until the open_count reaches 0 again.  Save
+        * the pointer in the device for use by drivers.
         */
-       ret = vfio_device_open(device, device->group->iommufd,
-                              device->group->kvm);
+       if (device->open_count == 0)
+               vfio_device_group_get_kvm_safe(device);
+
+       ret = vfio_device_open(device, device->group->iommufd);
+
+       if (device->open_count == 0)
+               vfio_device_put_kvm(device);
+
+       mutex_unlock(&device->dev_set->lock);
 
 out_unlock:
        mutex_unlock(&device->group->group_lock);
@@ -183,7 +205,14 @@ out_unlock:
 void vfio_device_group_close(struct vfio_device *device)
 {
        mutex_lock(&device->group->group_lock);
+       mutex_lock(&device->dev_set->lock);
+
        vfio_device_close(device, device->group->iommufd);
+
+       if (device->open_count == 0)
+               vfio_device_put_kvm(device);
+
+       mutex_unlock(&device->dev_set->lock);
        mutex_unlock(&device->group->group_lock);
 }
 
@@ -453,6 +482,7 @@ static struct vfio_group *vfio_group_alloc(struct iommu_group *iommu_group,
 
        refcount_set(&group->drivers, 1);
        mutex_init(&group->group_lock);
+       spin_lock_init(&group->kvm_ref_lock);
        INIT_LIST_HEAD(&group->device_list);
        mutex_init(&group->device_lock);
        group->iommu_group = iommu_group;
@@ -806,9 +836,9 @@ void vfio_file_set_kvm(struct file *file, struct kvm *kvm)
        if (!vfio_file_is_group(file))
                return;
 
-       mutex_lock(&group->group_lock);
+       spin_lock(&group->kvm_ref_lock);
        group->kvm = kvm;
-       mutex_unlock(&group->group_lock);
+       spin_unlock(&group->kvm_ref_lock);
 }
 EXPORT_SYMBOL_GPL(vfio_file_set_kvm);