vfio: Use down_reads to protect iommu disconnects
authorAlex Williamson <alex.williamson@redhat.com>
Mon, 29 Apr 2013 14:41:36 +0000 (08:41 -0600)
committerAlex Williamson <alex.williamson@redhat.com>
Mon, 29 Apr 2013 14:41:36 +0000 (08:41 -0600)
If a group or device is released or a container is unset from a group
it can race against file ops on the container.  Protect these with
down_reads to allow concurrent users.

Signed-off-by: Alex Williamson <alex.williamson@redhat.com>
Reported-by: Michael S. Tsirkin <mst@redhat.com>
drivers/vfio/vfio.c

index 073788e..ac7423b 100644 (file)
@@ -704,9 +704,13 @@ EXPORT_SYMBOL_GPL(vfio_del_group_dev);
 static long vfio_ioctl_check_extension(struct vfio_container *container,
                                       unsigned long arg)
 {
-       struct vfio_iommu_driver *driver = container->iommu_driver;
+       struct vfio_iommu_driver *driver;
        long ret = 0;
 
+       down_read(&container->group_lock);
+
+       driver = container->iommu_driver;
+
        switch (arg) {
                /* No base extensions yet */
        default:
@@ -736,6 +740,8 @@ static long vfio_ioctl_check_extension(struct vfio_container *container,
                                                 VFIO_CHECK_EXTENSION, arg);
        }
 
+       up_read(&container->group_lock);
+
        return ret;
 }
 
@@ -844,9 +850,6 @@ static long vfio_fops_unl_ioctl(struct file *filep,
        if (!container)
                return ret;
 
-       driver = container->iommu_driver;
-       data = container->iommu_data;
-
        switch (cmd) {
        case VFIO_GET_API_VERSION:
                ret = VFIO_API_VERSION;
@@ -858,8 +861,15 @@ static long vfio_fops_unl_ioctl(struct file *filep,
                ret = vfio_ioctl_set_iommu(container, arg);
                break;
        default:
+               down_read(&container->group_lock);
+
+               driver = container->iommu_driver;
+               data = container->iommu_data;
+
                if (driver) /* passthrough all unrecognized ioctls */
                        ret = driver->ops->ioctl(data, cmd, arg);
+
+               up_read(&container->group_lock);
        }
 
        return ret;
@@ -910,35 +920,55 @@ static ssize_t vfio_fops_read(struct file *filep, char __user *buf,
                              size_t count, loff_t *ppos)
 {
        struct vfio_container *container = filep->private_data;
-       struct vfio_iommu_driver *driver = container->iommu_driver;
+       struct vfio_iommu_driver *driver;
+       ssize_t ret = -EINVAL;
 
-       if (unlikely(!driver || !driver->ops->read))
-               return -EINVAL;
+       down_read(&container->group_lock);
+
+       driver = container->iommu_driver;
+       if (likely(driver && driver->ops->read))
+               ret = driver->ops->read(container->iommu_data,
+                                       buf, count, ppos);
 
-       return driver->ops->read(container->iommu_data, buf, count, ppos);
+       up_read(&container->group_lock);
+
+       return ret;
 }
 
 static ssize_t vfio_fops_write(struct file *filep, const char __user *buf,
                               size_t count, loff_t *ppos)
 {
        struct vfio_container *container = filep->private_data;
-       struct vfio_iommu_driver *driver = container->iommu_driver;
+       struct vfio_iommu_driver *driver;
+       ssize_t ret = -EINVAL;
 
-       if (unlikely(!driver || !driver->ops->write))
-               return -EINVAL;
+       down_read(&container->group_lock);
+
+       driver = container->iommu_driver;
+       if (likely(driver && driver->ops->write))
+               ret = driver->ops->write(container->iommu_data,
+                                        buf, count, ppos);
+
+       up_read(&container->group_lock);
 
-       return driver->ops->write(container->iommu_data, buf, count, ppos);
+       return ret;
 }
 
 static int vfio_fops_mmap(struct file *filep, struct vm_area_struct *vma)
 {
        struct vfio_container *container = filep->private_data;
-       struct vfio_iommu_driver *driver = container->iommu_driver;
+       struct vfio_iommu_driver *driver;
+       int ret = -EINVAL;
 
-       if (unlikely(!driver || !driver->ops->mmap))
-               return -EINVAL;
+       down_read(&container->group_lock);
 
-       return driver->ops->mmap(container->iommu_data, vma);
+       driver = container->iommu_driver;
+       if (likely(driver && driver->ops->mmap))
+               ret = driver->ops->mmap(container->iommu_data, vma);
+
+       up_read(&container->group_lock);
+
+       return ret;
 }
 
 static const struct file_operations vfio_fops = {