drm/amdkfd: Ensure mm remains valid in svm deferred_list work
author Philip Yang <Philip.Yang@amd.com>
Tue, 18 Jan 2022 17:15:24 +0000 (12:15 -0500)
committer Alex Deucher <alexander.deucher@amd.com>
Thu, 27 Jan 2022 20:47:51 +0000 (15:47 -0500)
svm_deferred_list work should continue to handle the deferred_range_list
entries, which may have been split into child ranges, to avoid leaking
child ranges, and it should remove the ranges' mmu interval notifiers to
avoid leaking the mm's mm_count. So take an mm reference when adding a
range to the deferred list, to ensure the mm is still valid when the
scheduled deferred_list_work runs, and drop the mm reference after the
range is handled.
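
For illustration only, a minimal sketch (not the actual kfd_svm.c code) of
the mmget()/mmput() pairing this change relies on; queue_item() and
handle_item() are hypothetical stand-ins for svm_range_add_list_work() and
svm_range_deferred_list_work():

#include <linux/list.h>
#include <linux/mm.h>
#include <linux/sched/mm.h>

/* hypothetical work item carrying the mm it was queued under */
struct queued_item {
	struct list_head list;
	struct mm_struct *mm;
};

static void queue_item(struct queued_item *item, struct mm_struct *mm)
{
	/* Pairs with mmput() in handle_item(); keeps mm alive while queued */
	mmget(mm);
	item->mm = mm;
	/* list_add_tail(&item->list, ...) under the list lock */
}

static void handle_item(struct queued_item *item)
{
	struct mm_struct *mm = item->mm;

	mmap_write_lock(mm);	/* mm cannot be freed here */
	/* ... process the item ... */
	mmap_write_unlock(mm);

	/* Pairs with mmget() in queue_item() */
	mmput(mm);
}

The producer side pins the mm before publishing the item, so the worker can
take mmap_write_lock() without racing against the mm being freed.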

Signed-off-by: Philip Yang <Philip.Yang@amd.com>
Reported-by: Ruili Ji <ruili.ji@amd.com>
Reviewed-by: Felix Kuehling <Felix.Kuehling@amd.com>
Signed-off-by: Alex Deucher <alexander.deucher@amd.com>
drivers/gpu/drm/amd/amdkfd/kfd_svm.c

index f2805ba..225affc 100644
@@ -1985,10 +1985,9 @@ svm_range_update_notifier_and_interval_tree(struct mm_struct *mm,
 }
 
 static void
-svm_range_handle_list_op(struct svm_range_list *svms, struct svm_range *prange)
+svm_range_handle_list_op(struct svm_range_list *svms, struct svm_range *prange,
+                        struct mm_struct *mm)
 {
-       struct mm_struct *mm = prange->work_item.mm;
-
        switch (prange->work_item.op) {
        case SVM_OP_NULL:
                pr_debug("NULL OP 0x%p prange 0x%p [0x%lx 0x%lx]\n",
@@ -2071,34 +2070,41 @@ static void svm_range_deferred_list_work(struct work_struct *work)
        pr_debug("enter svms 0x%p\n", svms);
 
        p = container_of(svms, struct kfd_process, svms);
-       /* Avoid mm is gone when inserting mmu notifier */
-       mm = get_task_mm(p->lead_thread);
-       if (!mm) {
-               pr_debug("svms 0x%p process mm gone\n", svms);
-               return;
-       }
-retry:
-       mmap_write_lock(mm);
-
-       /* Checking for the need to drain retry faults must be inside
-        * mmap write lock to serialize with munmap notifiers.
-        */
-       if (unlikely(atomic_read(&svms->drain_pagefaults))) {
-               mmap_write_unlock(mm);
-               svm_range_drain_retry_fault(svms);
-               goto retry;
-       }
 
        spin_lock(&svms->deferred_list_lock);
        while (!list_empty(&svms->deferred_range_list)) {
                prange = list_first_entry(&svms->deferred_range_list,
                                          struct svm_range, deferred_list);
-               list_del_init(&prange->deferred_list);
                spin_unlock(&svms->deferred_list_lock);
 
                pr_debug("prange 0x%p [0x%lx 0x%lx] op %d\n", prange,
                         prange->start, prange->last, prange->work_item.op);
 
+               mm = prange->work_item.mm;
+retry:
+               mmap_write_lock(mm);
+
+               /* Checking for the need to drain retry faults must be inside
+                * mmap write lock to serialize with munmap notifiers.
+                */
+               if (unlikely(atomic_read(&svms->drain_pagefaults))) {
+                       mmap_write_unlock(mm);
+                       svm_range_drain_retry_fault(svms);
+                       goto retry;
+               }
+
+               /* Remove from deferred_list must be inside mmap write lock, for
+                * two race cases:
+                * 1. unmap_from_cpu may change work_item.op and add the range
+                *    to deferred_list again, causing a use-after-free bug.
+                * 2. svm_range_list_lock_and_flush_work may hold mmap write
+                *    lock and continue because deferred_list is empty, but
+                *    deferred_list work is actually waiting for mmap lock.
+                */
+               spin_lock(&svms->deferred_list_lock);
+               list_del_init(&prange->deferred_list);
+               spin_unlock(&svms->deferred_list_lock);
+
                mutex_lock(&svms->lock);
                mutex_lock(&prange->migrate_mutex);
                while (!list_empty(&prange->child_list)) {
@@ -2109,19 +2115,20 @@ retry:
                        pr_debug("child prange 0x%p op %d\n", pchild,
                                 pchild->work_item.op);
                        list_del_init(&pchild->child_list);
-                       svm_range_handle_list_op(svms, pchild);
+                       svm_range_handle_list_op(svms, pchild, mm);
                }
                mutex_unlock(&prange->migrate_mutex);
 
-               svm_range_handle_list_op(svms, prange);
+               svm_range_handle_list_op(svms, prange, mm);
                mutex_unlock(&svms->lock);
+               mmap_write_unlock(mm);
+
+               /* Pairs with mmget in svm_range_add_list_work */
+               mmput(mm);
 
                spin_lock(&svms->deferred_list_lock);
        }
        spin_unlock(&svms->deferred_list_lock);
-
-       mmap_write_unlock(mm);
-       mmput(mm);
        pr_debug("exit svms 0x%p\n", svms);
 }
 
@@ -2139,6 +2146,9 @@ svm_range_add_list_work(struct svm_range_list *svms, struct svm_range *prange,
                        prange->work_item.op = op;
        } else {
                prange->work_item.op = op;
+
+               /* Pairs with mmput in deferred_list_work */
+               mmget(mm);
                prange->work_item.mm = mm;
                list_add_tail(&prange->deferred_list,
                              &prange->svms->deferred_range_list);