mm: change mprotect_fixup to vma iterator
authorLiam R. Howlett <Liam.Howlett@Oracle.com>
Fri, 20 Jan 2023 16:26:18 +0000 (11:26 -0500)
committerAndrew Morton <akpm@linux-foundation.org>
Fri, 10 Feb 2023 00:51:33 +0000 (16:51 -0800)
Use the vma iterator so that the iterator can be invalidated or updated to
avoid each caller doing so.

Link: https://lkml.kernel.org/r/20230120162650.984577-18-Liam.Howlett@oracle.com
Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
fs/exec.c
include/linux/mm.h
mm/mprotect.c

index ab91324..b98647e 100644 (file)
--- a/fs/exec.c
+++ b/fs/exec.c
@@ -758,6 +758,7 @@ int setup_arg_pages(struct linux_binprm *bprm,
        unsigned long stack_expand;
        unsigned long rlim_stack;
        struct mmu_gather tlb;
+       struct vma_iterator vmi;
 
 #ifdef CONFIG_STACK_GROWSUP
        /* Limit stack size */
@@ -812,8 +813,10 @@ int setup_arg_pages(struct linux_binprm *bprm,
        vm_flags |= mm->def_flags;
        vm_flags |= VM_STACK_INCOMPLETE_SETUP;
 
+       vma_iter_init(&vmi, mm, vma->vm_start);
+
        tlb_gather_mmu(&tlb, mm);
-       ret = mprotect_fixup(&tlb, vma, &prev, vma->vm_start, vma->vm_end,
+       ret = mprotect_fixup(&vmi, &tlb, vma, &prev, vma->vm_start, vma->vm_end,
                        vm_flags);
        tlb_finish_mmu(&tlb);
 
index 5b5f26d..144ddfd 100644 (file)
@@ -2197,9 +2197,9 @@ bool can_change_pte_writable(struct vm_area_struct *vma, unsigned long addr,
 extern long change_protection(struct mmu_gather *tlb,
                              struct vm_area_struct *vma, unsigned long start,
                              unsigned long end, unsigned long cp_flags);
-extern int mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma,
-                         struct vm_area_struct **pprev, unsigned long start,
-                         unsigned long end, unsigned long newflags);
+extern int mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
+         struct vm_area_struct *vma, struct vm_area_struct **pprev,
+         unsigned long start, unsigned long end, unsigned long newflags);
 
 /*
  * doesn't attempt to fault and will return short.
index 6a22f3a..39b6335 100644 (file)
@@ -585,9 +585,9 @@ static const struct mm_walk_ops prot_none_walk_ops = {
 };
 
 int
-mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma,
-              struct vm_area_struct **pprev, unsigned long start,
-              unsigned long end, unsigned long newflags)
+mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
+              struct vm_area_struct *vma, struct vm_area_struct **pprev,
+              unsigned long start, unsigned long end, unsigned long newflags)
 {
        struct mm_struct *mm = vma->vm_mm;
        unsigned long oldflags = vma->vm_flags;
@@ -642,7 +642,7 @@ mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma,
         * First try to merge with previous and/or next vma.
         */
        pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
-       *pprev = vma_merge(mm, *pprev, start, end, newflags,
+       *pprev = vmi_vma_merge(vmi, mm, *pprev, start, end, newflags,
                           vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
                           vma->vm_userfaultfd_ctx, anon_vma_name(vma));
        if (*pprev) {
@@ -654,13 +654,13 @@ mprotect_fixup(struct mmu_gather *tlb, struct vm_area_struct *vma,
        *pprev = vma;
 
        if (start != vma->vm_start) {
-               error = split_vma(mm, vma, start, 1);
+               error = vmi_split_vma(vmi, mm, vma, start, 1);
                if (error)
                        goto fail;
        }
 
        if (end != vma->vm_end) {
-               error = split_vma(mm, vma, end, 0);
+               error = vmi_split_vma(vmi, mm, vma, end, 0);
                if (error)
                        goto fail;
        }
@@ -709,7 +709,7 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
        const bool rier = (current->personality & READ_IMPLIES_EXEC) &&
                                (prot & PROT_READ);
        struct mmu_gather tlb;
-       MA_STATE(mas, &current->mm->mm_mt, 0, 0);
+       struct vma_iterator vmi;
 
        start = untagged_addr(start);
 
@@ -741,8 +741,8 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
        if ((pkey != -1) && !mm_pkey_is_allocated(current->mm, pkey))
                goto out;
 
-       mas_set(&mas, start);
-       vma = mas_find(&mas, ULONG_MAX);
+       vma_iter_init(&vmi, current->mm, start);
+       vma = vma_find(&vmi, end);
        error = -ENOMEM;
        if (!vma)
                goto out;
@@ -765,18 +765,22 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
                }
        }
 
+       prev = vma_prev(&vmi);
        if (start > vma->vm_start)
                prev = vma;
-       else
-               prev = mas_prev(&mas, 0);
 
        tlb_gather_mmu(&tlb, current->mm);
-       for (nstart = start ; ; ) {
+       nstart = start;
+       tmp = vma->vm_start;
+       for_each_vma_range(vmi, vma, end) {
                unsigned long mask_off_old_flags;
                unsigned long newflags;
                int new_vma_pkey;
 
-               /* Here we know that vma->vm_start <= nstart < vma->vm_end. */
+               if (vma->vm_start != tmp) {
+                       error = -ENOMEM;
+                       break;
+               }
 
                /* Does the application expect PROT_READ to imply PROT_EXEC */
                if (rier && (vma->vm_flags & VM_MAYEXEC))
@@ -824,25 +828,18 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
                                break;
                }
 
-               error = mprotect_fixup(&tlb, vma, &prev, nstart, tmp, newflags);
+               error = mprotect_fixup(&vmi, &tlb, vma, &prev, nstart, tmp, newflags);
                if (error)
                        break;
 
                nstart = tmp;
-
-               if (nstart < prev->vm_end)
-                       nstart = prev->vm_end;
-               if (nstart >= end)
-                       break;
-
-               vma = find_vma(current->mm, prev->vm_end);
-               if (!vma || vma->vm_start != nstart) {
-                       error = -ENOMEM;
-                       break;
-               }
                prot = reqprot;
        }
        tlb_finish_mmu(&tlb);
+
+       if (vma_iter_end(&vmi) < end)
+               error = -ENOMEM;
+
 out:
        mmap_write_unlock(current->mm);
        return error;