drivers/iommu: Take a ref to the IOMMU driver prior to ->add_device()
authorWill Deacon <will@kernel.org>
Thu, 19 Dec 2019 12:03:41 +0000 (12:03 +0000)
committerJoerg Roedel <jroedel@suse.de>
Mon, 23 Dec 2019 13:06:05 +0000 (14:06 +0100)
To avoid accidental removal of an active IOMMU driver module, take a
reference to the driver module in 'iommu_probe_device()' immediately
prior to invoking the '->add_device()' callback and hold it until the
after the device has been removed by '->remove_device()'.

Suggested-by: Joerg Roedel <joro@8bytes.org>
Signed-off-by: Will Deacon <will@kernel.org>
Tested-by: John Garry <john.garry@huawei.com> # smmu v3
Reviewed-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
Signed-off-by: Joerg Roedel <jroedel@suse.de>
drivers/iommu/iommu.c
include/linux/iommu.h

index 3abe19e..32ceda1 100644 (file)
@@ -22,6 +22,7 @@
 #include <linux/bitops.h>
 #include <linux/property.h>
 #include <linux/fsl/mc.h>
+#include <linux/module.h>
 #include <trace/events/iommu.h>
 
 static struct kset *iommu_group_kset;
@@ -185,10 +186,21 @@ int iommu_probe_device(struct device *dev)
        if (!iommu_get_dev_param(dev))
                return -ENOMEM;
 
+       if (!try_module_get(ops->owner)) {
+               ret = -EINVAL;
+               goto err_free_dev_param;
+       }
+
        ret = ops->add_device(dev);
        if (ret)
-               iommu_free_dev_param(dev);
+               goto err_module_put;
+
+       return 0;
 
+err_module_put:
+       module_put(ops->owner);
+err_free_dev_param:
+       iommu_free_dev_param(dev);
        return ret;
 }
 
@@ -199,7 +211,10 @@ void iommu_release_device(struct device *dev)
        if (dev->iommu_group)
                ops->remove_device(dev);
 
-       iommu_free_dev_param(dev);
+       if (dev->iommu_param) {
+               module_put(ops->owner);
+               iommu_free_dev_param(dev);
+       }
 }
 
 static struct iommu_domain *__iommu_domain_alloc(struct bus_type *bus,
index f2223cb..e9f94d3 100644 (file)
@@ -246,9 +246,10 @@ struct iommu_iotlb_gather {
  * @sva_get_pasid: Get PASID associated to a SVA handle
  * @page_response: handle page request response
  * @cache_invalidate: invalidate translation caches
- * @pgsize_bitmap: bitmap of all possible supported page sizes
  * @sva_bind_gpasid: bind guest pasid and mm
  * @sva_unbind_gpasid: unbind guest pasid and mm
+ * @pgsize_bitmap: bitmap of all possible supported page sizes
+ * @owner: Driver module providing these ops
  */
 struct iommu_ops {
        bool (*capable)(enum iommu_cap);
@@ -318,6 +319,7 @@ struct iommu_ops {
        int (*sva_unbind_gpasid)(struct device *dev, int pasid);
 
        unsigned long pgsize_bitmap;
+       struct module *owner;
 };
 
 /**