[ Upstream commit
7ff94f276f8ea05df82eb115225e9b26f47a3347 ]
Fix the system crash that happens when a task iterator travel through
vma of tasks.
In task iterators, we used to access mm by following the pointer on
the task_struct; however, the death of a task will clear the pointer,
even though we still hold the task_struct. That can cause an
unexpected crash for a null pointer when an iterator is visiting a
task that dies during the visit. Keeping a reference of mm on the
iterator ensures we always have a valid pointer to mm.
Co-developed-by: Song Liu <song@kernel.org>
Signed-off-by: Song Liu <song@kernel.org>
Signed-off-by: Kui-Feng Lee <kuifeng@meta.com>
Reported-by: Nathan Slingerland <slinger@meta.com>
Acked-by: Yonghong Song <yhs@fb.com>
Link: https://lore.kernel.org/r/20221216221855.4122288-2-kuifeng@meta.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Signed-off-by: Sasha Levin <sashal@kernel.org>
*/
struct bpf_iter_seq_task_common common;
struct task_struct *task;
*/
struct bpf_iter_seq_task_common common;
struct task_struct *task;
struct vm_area_struct *vma;
u32 tid;
unsigned long prev_vm_start;
struct vm_area_struct *vma;
u32 tid;
unsigned long prev_vm_start;
enum bpf_task_vma_iter_find_op op;
struct vm_area_struct *curr_vma;
struct task_struct *curr_task;
enum bpf_task_vma_iter_find_op op;
struct vm_area_struct *curr_vma;
struct task_struct *curr_task;
+ struct mm_struct *curr_mm;
u32 saved_tid = info->tid;
/* If this function returns a non-NULL vma, it holds a reference to
u32 saved_tid = info->tid;
/* If this function returns a non-NULL vma, it holds a reference to
- * the task_struct, and holds read lock on vma->mm->mmap_lock.
+ * the task_struct, holds a refcount on mm->mm_users, and holds
+ * read lock on vma->mm->mmap_lock.
* If this function returns NULL, it does not hold any reference or
* lock.
*/
if (info->task) {
curr_task = info->task;
curr_vma = info->vma;
* If this function returns NULL, it does not hold any reference or
* lock.
*/
if (info->task) {
curr_task = info->task;
curr_vma = info->vma;
/* In case of lock contention, drop mmap_lock to unblock
* the writer.
*
/* In case of lock contention, drop mmap_lock to unblock
* the writer.
*
* 4.2) VMA2 and VMA2' covers different ranges, process
* VMA2'.
*/
* 4.2) VMA2 and VMA2' covers different ranges, process
* VMA2'.
*/
- if (mmap_lock_is_contended(curr_task->mm)) {
+ if (mmap_lock_is_contended(curr_mm)) {
info->prev_vm_start = curr_vma->vm_start;
info->prev_vm_end = curr_vma->vm_end;
op = task_vma_iter_find_vma;
info->prev_vm_start = curr_vma->vm_start;
info->prev_vm_end = curr_vma->vm_end;
op = task_vma_iter_find_vma;
- mmap_read_unlock(curr_task->mm);
- if (mmap_read_lock_killable(curr_task->mm))
+ mmap_read_unlock(curr_mm);
+ if (mmap_read_lock_killable(curr_mm)) {
+ mmput(curr_mm);
} else {
op = task_vma_iter_next_vma;
}
} else {
op = task_vma_iter_next_vma;
}
op = task_vma_iter_find_vma;
}
op = task_vma_iter_find_vma;
}
+ curr_mm = get_task_mm(curr_task);
+ if (!curr_mm)
- if (mmap_read_lock_killable(curr_task->mm))
+ if (mmap_read_lock_killable(curr_mm)) {
+ mmput(curr_mm);
}
switch (op) {
case task_vma_iter_first_vma:
}
switch (op) {
case task_vma_iter_first_vma:
- curr_vma = find_vma(curr_task->mm, 0);
+ curr_vma = find_vma(curr_mm, 0);
break;
case task_vma_iter_next_vma:
break;
case task_vma_iter_next_vma:
- curr_vma = find_vma(curr_task->mm, curr_vma->vm_end);
+ curr_vma = find_vma(curr_mm, curr_vma->vm_end);
break;
case task_vma_iter_find_vma:
/* We dropped mmap_lock so it is necessary to use find_vma
* to find the next vma. This is similar to the mechanism
* in show_smaps_rollup().
*/
break;
case task_vma_iter_find_vma:
/* We dropped mmap_lock so it is necessary to use find_vma
* to find the next vma. This is similar to the mechanism
* in show_smaps_rollup().
*/
- curr_vma = find_vma(curr_task->mm, info->prev_vm_end - 1);
+ curr_vma = find_vma(curr_mm, info->prev_vm_end - 1);
/* case 1) and 4.2) above just use curr_vma */
/* check for case 2) or case 4.1) above */
if (curr_vma &&
curr_vma->vm_start == info->prev_vm_start &&
curr_vma->vm_end == info->prev_vm_end)
/* case 1) and 4.2) above just use curr_vma */
/* check for case 2) or case 4.1) above */
if (curr_vma &&
curr_vma->vm_start == info->prev_vm_start &&
curr_vma->vm_end == info->prev_vm_end)
- curr_vma = find_vma(curr_task->mm, curr_vma->vm_end);
+ curr_vma = find_vma(curr_mm, curr_vma->vm_end);
break;
}
if (!curr_vma) {
/* case 3) above, or case 2) 4.1) with vma->next == NULL */
break;
}
if (!curr_vma) {
/* case 3) above, or case 2) 4.1) with vma->next == NULL */
- mmap_read_unlock(curr_task->mm);
+ mmap_read_unlock(curr_mm);
+ mmput(curr_mm);
goto next_task;
}
info->task = curr_task;
info->vma = curr_vma;
goto next_task;
}
info->task = curr_task;
info->vma = curr_vma;
return curr_vma;
next_task:
return curr_vma;
next_task:
put_task_struct(curr_task);
info->task = NULL;
put_task_struct(curr_task);
info->task = NULL;
put_task_struct(curr_task);
info->task = NULL;
info->vma = NULL;
put_task_struct(curr_task);
info->task = NULL;
info->vma = NULL;
*/
info->prev_vm_start = ~0UL;
info->prev_vm_end = info->vma->vm_end;
*/
info->prev_vm_start = ~0UL;
info->prev_vm_end = info->vma->vm_end;
- mmap_read_unlock(info->task->mm);
+ mmap_read_unlock(info->mm);
+ mmput(info->mm);
+ info->mm = NULL;
put_task_struct(info->task);
info->task = NULL;
}
put_task_struct(info->task);
info->task = NULL;
}