bpf: keep a reference to the mm, in case the task is dead.
authorKui-Feng Lee <kuifeng@meta.com>
Fri, 16 Dec 2022 22:18:54 +0000 (14:18 -0800)
committerAlexei Starovoitov <ast@kernel.org>
Wed, 28 Dec 2022 22:11:48 +0000 (14:11 -0800)
Fix the system crash that happens when a task iterator travel through
vma of tasks.

In task iterators, we used to access mm by following the pointer on
the task_struct; however, the death of a task will clear the pointer,
even though we still hold the task_struct.  That can cause an
unexpected crash for a null pointer when an iterator is visiting a
task that dies during the visit.  Keeping a reference of mm on the
iterator ensures we always have a valid pointer to mm.

Co-developed-by: Song Liu <song@kernel.org>
Signed-off-by: Song Liu <song@kernel.org>
Signed-off-by: Kui-Feng Lee <kuifeng@meta.com>
Reported-by: Nathan Slingerland <slinger@meta.com>
Acked-by: Yonghong Song <yhs@fb.com>
Link: https://lore.kernel.org/r/20221216221855.4122288-2-kuifeng@meta.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
kernel/bpf/task_iter.c

index c2a2182..c4ab9d6 100644 (file)
@@ -438,6 +438,7 @@ struct bpf_iter_seq_task_vma_info {
         */
        struct bpf_iter_seq_task_common common;
        struct task_struct *task;
+       struct mm_struct *mm;
        struct vm_area_struct *vma;
        u32 tid;
        unsigned long prev_vm_start;
@@ -456,16 +457,19 @@ task_vma_seq_get_next(struct bpf_iter_seq_task_vma_info *info)
        enum bpf_task_vma_iter_find_op op;
        struct vm_area_struct *curr_vma;
        struct task_struct *curr_task;
+       struct mm_struct *curr_mm;
        u32 saved_tid = info->tid;
 
        /* If this function returns a non-NULL vma, it holds a reference to
-        * the task_struct, and holds read lock on vma->mm->mmap_lock.
+        * the task_struct, holds a refcount on mm->mm_users, and holds
+        * read lock on vma->mm->mmap_lock.
         * If this function returns NULL, it does not hold any reference or
         * lock.
         */
        if (info->task) {
                curr_task = info->task;
                curr_vma = info->vma;
+               curr_mm = info->mm;
                /* In case of lock contention, drop mmap_lock to unblock
                 * the writer.
                 *
@@ -504,13 +508,15 @@ task_vma_seq_get_next(struct bpf_iter_seq_task_vma_info *info)
                 *    4.2) VMA2 and VMA2' covers different ranges, process
                 *         VMA2'.
                 */
-               if (mmap_lock_is_contended(curr_task->mm)) {
+               if (mmap_lock_is_contended(curr_mm)) {
                        info->prev_vm_start = curr_vma->vm_start;
                        info->prev_vm_end = curr_vma->vm_end;
                        op = task_vma_iter_find_vma;
-                       mmap_read_unlock(curr_task->mm);
-                       if (mmap_read_lock_killable(curr_task->mm))
+                       mmap_read_unlock(curr_mm);
+                       if (mmap_read_lock_killable(curr_mm)) {
+                               mmput(curr_mm);
                                goto finish;
+                       }
                } else {
                        op = task_vma_iter_next_vma;
                }
@@ -535,42 +541,47 @@ again:
                        op = task_vma_iter_find_vma;
                }
 
-               if (!curr_task->mm)
+               curr_mm = get_task_mm(curr_task);
+               if (!curr_mm)
                        goto next_task;
 
-               if (mmap_read_lock_killable(curr_task->mm))
+               if (mmap_read_lock_killable(curr_mm)) {
+                       mmput(curr_mm);
                        goto finish;
+               }
        }
 
        switch (op) {
        case task_vma_iter_first_vma:
-               curr_vma = find_vma(curr_task->mm, 0);
+               curr_vma = find_vma(curr_mm, 0);
                break;
        case task_vma_iter_next_vma:
-               curr_vma = find_vma(curr_task->mm, curr_vma->vm_end);
+               curr_vma = find_vma(curr_mm, curr_vma->vm_end);
                break;
        case task_vma_iter_find_vma:
                /* We dropped mmap_lock so it is necessary to use find_vma
                 * to find the next vma. This is similar to the  mechanism
                 * in show_smaps_rollup().
                 */
-               curr_vma = find_vma(curr_task->mm, info->prev_vm_end - 1);
+               curr_vma = find_vma(curr_mm, info->prev_vm_end - 1);
                /* case 1) and 4.2) above just use curr_vma */
 
                /* check for case 2) or case 4.1) above */
                if (curr_vma &&
                    curr_vma->vm_start == info->prev_vm_start &&
                    curr_vma->vm_end == info->prev_vm_end)
-                       curr_vma = find_vma(curr_task->mm, curr_vma->vm_end);
+                       curr_vma = find_vma(curr_mm, curr_vma->vm_end);
                break;
        }
        if (!curr_vma) {
                /* case 3) above, or case 2) 4.1) with vma->next == NULL */
-               mmap_read_unlock(curr_task->mm);
+               mmap_read_unlock(curr_mm);
+               mmput(curr_mm);
                goto next_task;
        }
        info->task = curr_task;
        info->vma = curr_vma;
+       info->mm = curr_mm;
        return curr_vma;
 
 next_task:
@@ -579,6 +590,7 @@ next_task:
 
        put_task_struct(curr_task);
        info->task = NULL;
+       info->mm = NULL;
        info->tid++;
        goto again;
 
@@ -587,6 +599,7 @@ finish:
                put_task_struct(curr_task);
        info->task = NULL;
        info->vma = NULL;
+       info->mm = NULL;
        return NULL;
 }
 
@@ -658,7 +671,9 @@ static void task_vma_seq_stop(struct seq_file *seq, void *v)
                 */
                info->prev_vm_start = ~0UL;
                info->prev_vm_end = info->vma->vm_end;
-               mmap_read_unlock(info->task->mm);
+               mmap_read_unlock(info->mm);
+               mmput(info->mm);
+               info->mm = NULL;
                put_task_struct(info->task);
                info->task = NULL;
        }