iommufd: Add iommufd_group
author Jason Gunthorpe <jgg@nvidia.com>
Mon, 17 Jul 2023 18:11:58 +0000 (15:11 -0300)
committer Jason Gunthorpe <jgg@nvidia.com>
Wed, 26 Jul 2023 13:19:17 +0000 (10:19 -0300)
While the hwpt-to-device attachment was fairly static we could get away
with the simple approach of keeping track of the groups via a device
list, but with replace this becomes infeasible.

Add an automatically managed struct that is 1:1 with the iommu_group per
ictx so the necessary tracking information can be stored there; the
lookup-or-create idiom this relies on is sketched below, ahead of the
diffs.

Link: https://lore.kernel.org/r/2-v8-6659224517ea+532-iommufd_alloc_jgg@nvidia.com
Reviewed-by: Lu Baolu <baolu.lu@linux.intel.com>
Reviewed-by: Kevin Tian <kevin.tian@intel.com>
Tested-by: Nicolin Chen <nicolinc@nvidia.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
drivers/iommu/iommufd/device.c
drivers/iommu/iommufd/iommufd_private.h
drivers/iommu/iommufd/main.c

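The whole change hangs off a race-free lookup-or-create idiom: load the
slot under the xa_lock and take a reference only if the kref is still
non-zero, otherwise allocate a new entry and publish it with
__xa_cmpxchg(), retrying against whatever value the slot actually held.
A minimal standalone sketch of that idiom follows; the names here
(struct entry, cache_get_entry) are illustrative toys, not iommufd's
types.

#include <linux/err.h>
#include <linux/kref.h>
#include <linux/slab.h>
#include <linux/xarray.h>

struct entry {
        struct kref ref;
        struct xarray *xa;      /* backpointer for the release path */
        unsigned long id;
};

static struct entry *cache_get_entry(struct xarray *xa, unsigned long id)
{
        struct entry *cur = NULL;
        struct entry *new;
        struct entry *old;

        /* Fast path: reuse a live entry already in the slot */
        xa_lock(xa);
        old = xa_load(xa, id);
        if (old && kref_get_unless_zero(&old->ref)) {
                xa_unlock(xa);
                return old;
        }
        xa_unlock(xa);

        new = kzalloc(sizeof(*new), GFP_KERNEL);
        if (!new)
                return ERR_PTR(-ENOMEM);
        kref_init(&new->ref);
        new->xa = xa;
        new->id = id;

        xa_lock(xa);
        while (true) {
                /* Publish new only if the slot still holds what we last saw */
                old = __xa_cmpxchg(xa, id, cur, new, GFP_KERNEL);
                if (xa_is_err(old)) {
                        xa_unlock(xa);
                        kfree(new);     /* never published, no refs exist */
                        return ERR_PTR(xa_err(old));
                }
                if (old == cur) {
                        /* new was installed */
                        xa_unlock(xa);
                        return new;
                }
                if (kref_get_unless_zero(&old->ref)) {
                        /* A racing thread published a live entry first */
                        xa_unlock(xa);
                        kfree(new);
                        return old;
                }
                /* The slot holds a dying entry; retry to displace it */
                cur = old;
        }
}

The retry loop exists because an entry can sit in the slot with a zero
refcount after its last put but before its release function clears the
slot; such a dying entry must be displaced, not reused.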
diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c
index d4cc394..dbd517d 100644
@@ -15,13 +15,121 @@ MODULE_PARM_DESC(
        "Allow IOMMUFD to bind to devices even if the platform cannot isolate "
        "the MSI interrupt window. Enabling this is a security weakness.");
 
+static void iommufd_group_release(struct kref *kref)
+{
+       struct iommufd_group *igroup =
+               container_of(kref, struct iommufd_group, ref);
+
+       xa_cmpxchg(&igroup->ictx->groups, iommu_group_id(igroup->group), igroup,
+                  NULL, GFP_KERNEL);
+       iommu_group_put(igroup->group);
+       kfree(igroup);
+}
+
+static void iommufd_put_group(struct iommufd_group *group)
+{
+       kref_put(&group->ref, iommufd_group_release);
+}
+
+static bool iommufd_group_try_get(struct iommufd_group *igroup,
+                                 struct iommu_group *group)
+{
+       if (!igroup)
+               return false;
+       /*
+        * Group IDs cannot be reused until the group is put back, which
+        * cannot have happened if we found an igroup pointer under the xa_lock.
+        */
+       if (WARN_ON(igroup->group != group))
+               return false;
+       return kref_get_unless_zero(&igroup->ref);
+}
+
+/*
+ * iommufd needs to store some more data for each iommu_group; we keep a
+ * parallel xarray indexed by iommu_group id to hold this instead of putting
+ * it in the core structure. To keep things simple the iommufd_group memory
+ * is unique within the iommufd_ctx, which also makes it easy to check that
+ * there are no memory leaks.
+ */
+static struct iommufd_group *iommufd_get_group(struct iommufd_ctx *ictx,
+                                              struct device *dev)
+{
+       struct iommufd_group *new_igroup;
+       struct iommufd_group *cur_igroup;
+       struct iommufd_group *igroup;
+       struct iommu_group *group;
+       unsigned int id;
+
+       group = iommu_group_get(dev);
+       if (!group)
+               return ERR_PTR(-ENODEV);
+
+       id = iommu_group_id(group);
+
+       xa_lock(&ictx->groups);
+       igroup = xa_load(&ictx->groups, id);
+       if (iommufd_group_try_get(igroup, group)) {
+               xa_unlock(&ictx->groups);
+               iommu_group_put(group);
+               return igroup;
+       }
+       xa_unlock(&ictx->groups);
+
+       new_igroup = kzalloc(sizeof(*new_igroup), GFP_KERNEL);
+       if (!new_igroup) {
+               iommu_group_put(group);
+               return ERR_PTR(-ENOMEM);
+       }
+
+       kref_init(&new_igroup->ref);
+       /* group reference moves into new_igroup */
+       new_igroup->group = group;
+
+       /*
+        * The ictx is not additionally refcounted here because all objects using
+        * an igroup must put it before their destroy completes.
+        */
+       new_igroup->ictx = ictx;
+
+       /*
+        * We dropped the lock so igroup is stale. NULL is a safe and likely
+        * starting value to assume for the xa_cmpxchg loop below.
+        */
+       cur_igroup = NULL;
+       xa_lock(&ictx->groups);
+       while (true) {
+               igroup = __xa_cmpxchg(&ictx->groups, id, cur_igroup, new_igroup,
+                                     GFP_KERNEL);
+               if (xa_is_err(igroup)) {
+                       xa_unlock(&ictx->groups);
+                       iommufd_put_group(new_igroup);
+                       return ERR_PTR(xa_err(igroup));
+               }
+
+               /* new_igroup was successfully installed */
+               if (cur_igroup == igroup) {
+                       xa_unlock(&ictx->groups);
+                       return new_igroup;
+               }
+
+               /* Check again if the current group is any good */
+               if (iommufd_group_try_get(igroup, group)) {
+                       xa_unlock(&ictx->groups);
+                       iommufd_put_group(new_igroup);
+                       return igroup;
+               }
+               cur_igroup = igroup;
+       }
+}
+
 void iommufd_device_destroy(struct iommufd_object *obj)
 {
        struct iommufd_device *idev =
                container_of(obj, struct iommufd_device, obj);
 
        iommu_device_release_dma_owner(idev->dev);
-       iommu_group_put(idev->group);
+       iommufd_put_group(idev->igroup);
        if (!iommufd_selftest_is_mock_dev(idev->dev))
                iommufd_ctx_put(idev->ictx);
 }
@@ -46,7 +154,7 @@ struct iommufd_device *iommufd_device_bind(struct iommufd_ctx *ictx,
                                           struct device *dev, u32 *id)
 {
        struct iommufd_device *idev;
-       struct iommu_group *group;
+       struct iommufd_group *igroup;
        int rc;
 
        /*
@@ -56,9 +164,9 @@ struct iommufd_device *iommufd_device_bind(struct iommufd_ctx *ictx,
        if (!device_iommu_capable(dev, IOMMU_CAP_CACHE_COHERENCY))
                return ERR_PTR(-EINVAL);
 
-       group = iommu_group_get(dev);
-       if (!group)
-               return ERR_PTR(-ENODEV);
+       igroup = iommufd_get_group(ictx, dev);
+       if (IS_ERR(igroup))
+               return ERR_CAST(igroup);
 
        /*
         * For historical compat with VFIO the insecure interrupt path is
@@ -67,7 +175,7 @@ struct iommufd_device *iommufd_device_bind(struct iommufd_ctx *ictx,
         * interrupt outside this iommufd context.
         */
        if (!iommufd_selftest_is_mock_dev(dev) &&
-           !iommu_group_has_isolated_msi(group)) {
+           !iommu_group_has_isolated_msi(igroup->group)) {
                if (!allow_unsafe_interrupts) {
                        rc = -EPERM;
                        goto out_group_put;
@@ -97,8 +205,8 @@ struct iommufd_device *iommufd_device_bind(struct iommufd_ctx *ictx,
                device_iommu_capable(dev, IOMMU_CAP_ENFORCE_CACHE_COHERENCY);
        /* The calling driver is a user until iommufd_device_unbind() */
        refcount_inc(&idev->obj.users);
-       /* group refcount moves into iommufd_device */
-       idev->group = group;
+       /* igroup refcount moves into iommufd_device */
+       idev->igroup = igroup;
 
        /*
         * If the caller fails after this success it must call
@@ -113,7 +221,7 @@ struct iommufd_device *iommufd_device_bind(struct iommufd_ctx *ictx,
 out_release_owner:
        iommu_device_release_dma_owner(dev);
 out_group_put:
-       iommu_group_put(group);
+       iommufd_put_group(igroup);
        return ERR_PTR(rc);
 }
 EXPORT_SYMBOL_NS_GPL(iommufd_device_bind, IOMMUFD);
@@ -138,7 +246,8 @@ bool iommufd_ctx_has_group(struct iommufd_ctx *ictx, struct iommu_group *group)
        xa_lock(&ictx->objects);
        xa_for_each(&ictx->objects, index, obj) {
                if (obj->type == IOMMUFD_OBJ_DEVICE &&
-                   container_of(obj, struct iommufd_device, obj)->group == group) {
+                   container_of(obj, struct iommufd_device, obj)
+                                   ->igroup->group == group) {
                        xa_unlock(&ictx->objects);
                        return true;
                }
@@ -212,14 +321,14 @@ static int iommufd_device_setup_msi(struct iommufd_device *idev,
 }
 
 static bool iommufd_hw_pagetable_has_group(struct iommufd_hw_pagetable *hwpt,
-                                          struct iommu_group *group)
+                                          struct iommufd_group *igroup)
 {
        struct iommufd_device *cur_dev;
 
        lockdep_assert_held(&hwpt->devices_lock);
 
        list_for_each_entry(cur_dev, &hwpt->devices, devices_item)
-               if (cur_dev->group == group)
+               if (cur_dev->igroup->group == igroup->group)
                        return true;
        return false;
 }
@@ -253,7 +362,8 @@ int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
        }
 
        rc = iopt_table_enforce_group_resv_regions(&hwpt->ioas->iopt, idev->dev,
-                                                  idev->group, &sw_msi_start);
+                                                  idev->igroup->group,
+                                                  &sw_msi_start);
        if (rc)
                return rc;
 
@@ -265,8 +375,8 @@ int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
         * FIXME: Hack around missing a device-centric iommu api, only attach to
         * the group once for the first device that is in the group.
         */
-       if (!iommufd_hw_pagetable_has_group(hwpt, idev->group)) {
-               rc = iommu_attach_group(hwpt->domain, idev->group);
+       if (!iommufd_hw_pagetable_has_group(hwpt, idev->igroup)) {
+               rc = iommu_attach_group(hwpt->domain, idev->igroup->group);
                if (rc)
                        goto err_unresv;
        }
@@ -279,8 +389,8 @@ err_unresv:
 void iommufd_hw_pagetable_detach(struct iommufd_hw_pagetable *hwpt,
                                 struct iommufd_device *idev)
 {
-       if (!iommufd_hw_pagetable_has_group(hwpt, idev->group))
-               iommu_detach_group(hwpt->domain, idev->group);
+       if (!iommufd_hw_pagetable_has_group(hwpt, idev->igroup))
+               iommu_detach_group(hwpt->domain, idev->igroup->group);
        iopt_remove_reserved_iova(&hwpt->ioas->iopt, idev->dev);
 }
 
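One subtlety in the release path added above: by the time the final
kref_put() runs, a concurrent iommufd_get_group() may already have seen
the zero refcount and installed a replacement igroup at the same index,
which is why iommufd_group_release() uses xa_cmpxchg() rather than
xa_erase(). Continuing the toy sketch from above, the matching release
side would look roughly like this:

static void entry_release(struct kref *kref)
{
        struct entry *entry = container_of(kref, struct entry, ref);

        /*
         * A racing lookup may have displaced this dying entry already;
         * only clear the slot if it still points at us, otherwise we
         * would erase the replacement out from under its users.
         */
        xa_cmpxchg(entry->xa, entry->id, entry, NULL, GFP_KERNEL);
        kfree(entry);
}

static void cache_put_entry(struct entry *entry)
{
        kref_put(&entry->ref, entry_release);
}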
diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h
index 3dcaf86..285ac4d 100644
@@ -17,6 +17,7 @@ struct iommufd_device;
 struct iommufd_ctx {
        struct file *file;
        struct xarray objects;
+       struct xarray groups;
 
        u8 account_mode;
        /* Compatibility with VFIO no iommu */
@@ -262,6 +263,12 @@ void iommufd_hw_pagetable_detach(struct iommufd_hw_pagetable *hwpt,
                                 struct iommufd_device *idev);
 void iommufd_hw_pagetable_destroy(struct iommufd_object *obj);
 
+struct iommufd_group {
+       struct kref ref;
+       struct iommufd_ctx *ictx;
+       struct iommu_group *group;
+};
+
 /*
 * An iommufd_device object represents the binding relationship between a
  * consuming driver and the iommufd. These objects are created/destroyed by
@@ -270,12 +277,12 @@ void iommufd_hw_pagetable_destroy(struct iommufd_object *obj);
 struct iommufd_device {
        struct iommufd_object obj;
        struct iommufd_ctx *ictx;
+       struct iommufd_group *igroup;
        struct iommufd_hw_pagetable *hwpt;
        /* Head at iommufd_hw_pagetable::devices */
        struct list_head devices_item;
        /* always the physical device */
        struct device *dev;
-       struct iommu_group *group;
        bool enforce_cache_coherency;
 };
 
diff --git a/drivers/iommu/iommufd/main.c b/drivers/iommu/iommufd/main.c
index 4bbb20d..34fefc0 100644
@@ -183,6 +183,7 @@ static int iommufd_fops_open(struct inode *inode, struct file *filp)
        }
 
        xa_init_flags(&ictx->objects, XA_FLAGS_ALLOC1 | XA_FLAGS_ACCOUNT);
+       xa_init(&ictx->groups);
        ictx->file = filp;
        filp->private_data = ictx;
        return 0;
@@ -218,6 +219,7 @@ static int iommufd_fops_release(struct inode *inode, struct file *filp)
                if (WARN_ON(!destroyed))
                        break;
        }
+       WARN_ON(!xa_empty(&ictx->groups));
        kfree(ictx);
        return 0;
 }
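The WARN_ON(!xa_empty(&ictx->groups)) above is what the "easy to check
that there are no memory leaks" comment in device.c refers to: the
igroup does not hold a reference on the ictx, so every object must put
its igroup before its destroy completes, and each final put erases the
corresponding xarray slot. A compressed, hypothetical sketch of the
ordering this enforces (destroy_object() stands in for the destruction
loop in iommufd_fops_release()):

static void teardown(struct iommufd_ctx *ictx)
{
        struct iommufd_object *obj;
        unsigned long index;

        /* Destroying each object drops any igroup reference it holds */
        xa_for_each(&ictx->objects, index, obj)
                destroy_object(obj);    /* hypothetical stand-in */

        /* Each final put cleared its slot; anything left is a leak */
        WARN_ON(!xa_empty(&ictx->groups));
        kfree(ictx);
}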