mlock: convert mlock to vma iterator
authorLiam R. Howlett <Liam.Howlett@Oracle.com>
Fri, 20 Jan 2023 16:26:19 +0000 (11:26 -0500)
committerAndrew Morton <akpm@linux-foundation.org>
Fri, 10 Feb 2023 00:51:33 +0000 (16:51 -0800)
Use the vma iterator so that the iterator can be invalidated or updated to
avoid each caller doing so.

Link: https://lkml.kernel.org/r/20230120162650.984577-19-Liam.Howlett@oracle.com
Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
mm/mlock.c

index b680f11..0d09b90 100644 (file)
@@ -401,8 +401,9 @@ static void mlock_vma_pages_range(struct vm_area_struct *vma,
  *
  * For vmas that pass the filters, merge/split as appropriate.
  */
-static int mlock_fixup(struct vm_area_struct *vma, struct vm_area_struct **prev,
-       unsigned long start, unsigned long end, vm_flags_t newflags)
+static int mlock_fixup(struct vma_iterator *vmi, struct vm_area_struct *vma,
+              struct vm_area_struct **prev, unsigned long start,
+              unsigned long end, vm_flags_t newflags)
 {
        struct mm_struct *mm = vma->vm_mm;
        pgoff_t pgoff;
@@ -417,22 +418,22 @@ static int mlock_fixup(struct vm_area_struct *vma, struct vm_area_struct **prev,
                goto out;
 
        pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
-       *prev = vma_merge(mm, *prev, start, end, newflags, vma->anon_vma,
-                         vma->vm_file, pgoff, vma_policy(vma),
-                         vma->vm_userfaultfd_ctx, anon_vma_name(vma));
+       *prev = vmi_vma_merge(vmi, mm, *prev, start, end, newflags,
+                       vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
+                       vma->vm_userfaultfd_ctx, anon_vma_name(vma));
        if (*prev) {
                vma = *prev;
                goto success;
        }
 
        if (start != vma->vm_start) {
-               ret = split_vma(mm, vma, start, 1);
+               ret = vmi_split_vma(vmi, mm, vma, start, 1);
                if (ret)
                        goto out;
        }
 
        if (end != vma->vm_end) {
-               ret = split_vma(mm, vma, end, 0);
+               ret = vmi_split_vma(vmi, mm, vma, end, 0);
                if (ret)
                        goto out;
        }
@@ -471,7 +472,7 @@ static int apply_vma_lock_flags(unsigned long start, size_t len,
        unsigned long nstart, end, tmp;
        struct vm_area_struct *vma, *prev;
        int error;
-       MA_STATE(mas, &current->mm->mm_mt, start, start);
+       VMA_ITERATOR(vmi, current->mm, start);
 
        VM_BUG_ON(offset_in_page(start));
        VM_BUG_ON(len != PAGE_ALIGN(len));
@@ -480,39 +481,37 @@ static int apply_vma_lock_flags(unsigned long start, size_t len,
                return -EINVAL;
        if (end == start)
                return 0;
-       vma = mas_walk(&mas);
+       vma = vma_iter_load(&vmi);
        if (!vma)
                return -ENOMEM;
 
+       prev = vma_prev(&vmi);
        if (start > vma->vm_start)
                prev = vma;
-       else
-               prev = mas_prev(&mas, 0);
 
-       for (nstart = start ; ; ) {
-               vm_flags_t newflags = vma->vm_flags & VM_LOCKED_CLEAR_MASK;
+       nstart = start;
+       tmp = vma->vm_start;
+       for_each_vma_range(vmi, vma, end) {
+               vm_flags_t newflags;
 
-               newflags |= flags;
+               if (vma->vm_start != tmp)
+                       return -ENOMEM;
 
+               newflags = vma->vm_flags & VM_LOCKED_CLEAR_MASK;
+               newflags |= flags;
                /* Here we know that  vma->vm_start <= nstart < vma->vm_end. */
                tmp = vma->vm_end;
                if (tmp > end)
                        tmp = end;
-               error = mlock_fixup(vma, &prev, nstart, tmp, newflags);
+               error = mlock_fixup(&vmi, vma, &prev, nstart, tmp, newflags);
                if (error)
                        break;
                nstart = tmp;
-               if (nstart < prev->vm_end)
-                       nstart = prev->vm_end;
-               if (nstart >= end)
-                       break;
-
-               vma = find_vma(prev->vm_mm, prev->vm_end);
-               if (!vma || vma->vm_start != nstart) {
-                       error = -ENOMEM;
-                       break;
-               }
        }
+
+       if (vma_iter_end(&vmi) < end)
+               return -ENOMEM;
+
        return error;
 }
 
@@ -658,7 +657,7 @@ SYSCALL_DEFINE2(munlock, unsigned long, start, size_t, len)
  */
 static int apply_mlockall_flags(int flags)
 {
-       MA_STATE(mas, &current->mm->mm_mt, 0, 0);
+       VMA_ITERATOR(vmi, current->mm, 0);
        struct vm_area_struct *vma, *prev = NULL;
        vm_flags_t to_add = 0;
 
@@ -679,15 +678,15 @@ static int apply_mlockall_flags(int flags)
                        to_add |= VM_LOCKONFAULT;
        }
 
-       mas_for_each(&mas, vma, ULONG_MAX) {
+       for_each_vma(vmi, vma) {
                vm_flags_t newflags;
 
                newflags = vma->vm_flags & VM_LOCKED_CLEAR_MASK;
                newflags |= to_add;
 
                /* Ignore errors */
-               mlock_fixup(vma, &prev, vma->vm_start, vma->vm_end, newflags);
-               mas_pause(&mas);
+               mlock_fixup(&vmi, vma, &prev, vma->vm_start, vma->vm_end,
+                           newflags);
                cond_resched();
        }
 out: