iommu/amd: Convert from atomic_t to refcount_t on pasid_state->count
authorXiyu Yang via iommu <iommu@lists.linux-foundation.org>
Mon, 19 Jul 2021 08:32:58 +0000 (16:32 +0800)
committerJoerg Roedel <jroedel@suse.de>
Mon, 26 Jul 2021 11:46:57 +0000 (13:46 +0200)
refcount_t type and corresponding API can protect refcounters from
accidental underflow and overflow and further use-after-free situations.

Signed-off-by: Xiyu Yang <xiyuyang19@fudan.edu.cn>
Signed-off-by: Xin Tan <tanxin.ctf@gmail.com>
Reviewed-by: Suravee Suthikulpanit <suravee.suthikulpanit@amd.com>
Link: https://lore.kernel.org/r/1626683578-64214-1-git-send-email-xiyuyang19@fudan.edu.cn
Signed-off-by: Joerg Roedel <jroedel@suse.de>
drivers/iommu/amd/iommu_v2.c

index f8d4ad4..a9e5682 100644 (file)
@@ -6,6 +6,7 @@
 
 #define pr_fmt(fmt)     "AMD-Vi: " fmt
 
+#include <linux/refcount.h>
 #include <linux/mmu_notifier.h>
 #include <linux/amd-iommu.h>
 #include <linux/mm_types.h>
@@ -33,7 +34,7 @@ struct pri_queue {
 
 struct pasid_state {
        struct list_head list;                  /* For global state-list */
-       atomic_t count;                         /* Reference count */
+       refcount_t count;                               /* Reference count */
        unsigned mmu_notifier_count;            /* Counting nested mmu_notifier
                                                   calls */
        struct mm_struct *mm;                   /* mm_struct for the faults */
@@ -242,7 +243,7 @@ static struct pasid_state *get_pasid_state(struct device_state *dev_state,
 
        ret = *ptr;
        if (ret)
-               atomic_inc(&ret->count);
+               refcount_inc(&ret->count);
 
 out_unlock:
        spin_unlock_irqrestore(&dev_state->lock, flags);
@@ -257,14 +258,14 @@ static void free_pasid_state(struct pasid_state *pasid_state)
 
 static void put_pasid_state(struct pasid_state *pasid_state)
 {
-       if (atomic_dec_and_test(&pasid_state->count))
+       if (refcount_dec_and_test(&pasid_state->count))
                wake_up(&pasid_state->wq);
 }
 
 static void put_pasid_state_wait(struct pasid_state *pasid_state)
 {
-       atomic_dec(&pasid_state->count);
-       wait_event(pasid_state->wq, !atomic_read(&pasid_state->count));
+       refcount_dec(&pasid_state->count);
+       wait_event(pasid_state->wq, !refcount_read(&pasid_state->count));
        free_pasid_state(pasid_state);
 }
 
@@ -624,7 +625,7 @@ int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
                goto out;
 
 
-       atomic_set(&pasid_state->count, 1);
+       refcount_set(&pasid_state->count, 1);
        init_waitqueue_head(&pasid_state->wq);
        spin_lock_init(&pasid_state->lock);