userfaultfd: use vma iterator
authorLiam R. Howlett <Liam.Howlett@Oracle.com>
Fri, 20 Jan 2023 16:26:17 +0000 (11:26 -0500)
committerAndrew Morton <akpm@linux-foundation.org>
Fri, 10 Feb 2023 00:51:33 +0000 (16:51 -0800)
Use the vma iterator so that the iterator can be invalidated or updated to
avoid each caller doing so.

Link: https://lkml.kernel.org/r/20230120162650.984577-17-Liam.Howlett@oracle.com
Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
fs/userfaultfd.c

index 15a5bf7..4334bd3 100644 (file)
@@ -883,7 +883,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
        /* len == 0 means wake all */
        struct userfaultfd_wake_range range = { .len = 0, };
        unsigned long new_flags;
-       MA_STATE(mas, &mm->mm_mt, 0, 0);
+       VMA_ITERATOR(vmi, mm, 0);
 
        WRITE_ONCE(ctx->released, true);
 
@@ -900,7 +900,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
         */
        mmap_write_lock(mm);
        prev = NULL;
-       mas_for_each(&mas, vma, ULONG_MAX) {
+       for_each_vma(vmi, vma) {
                cond_resched();
                BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
                       !!(vma->vm_flags & __VM_UFFD_FLAGS));
@@ -909,13 +909,12 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
                        continue;
                }
                new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
-               prev = vma_merge(mm, prev, vma->vm_start, vma->vm_end,
+               prev = vmi_vma_merge(&vmi, mm, prev, vma->vm_start, vma->vm_end,
                                 new_flags, vma->anon_vma,
                                 vma->vm_file, vma->vm_pgoff,
                                 vma_policy(vma),
                                 NULL_VM_UFFD_CTX, anon_vma_name(vma));
                if (prev) {
-                       mas_pause(&mas);
                        vma = prev;
                } else {
                        prev = vma;
@@ -1302,7 +1301,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
        bool found;
        bool basic_ioctls;
        unsigned long start, end, vma_end;
-       MA_STATE(mas, &mm->mm_mt, 0, 0);
+       struct vma_iterator vmi;
 
        user_uffdio_register = (struct uffdio_register __user *) arg;
 
@@ -1344,17 +1343,13 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
        if (!mmget_not_zero(mm))
                goto out;
 
+       ret = -EINVAL;
        mmap_write_lock(mm);
-       mas_set(&mas, start);
-       vma = mas_find(&mas, ULONG_MAX);
+       vma_iter_init(&vmi, mm, start);
+       vma = vma_find(&vmi, end);
        if (!vma)
                goto out_unlock;
 
-       /* check that there's at least one vma in the range */
-       ret = -EINVAL;
-       if (vma->vm_start >= end)
-               goto out_unlock;
-
        /*
         * If the first vma contains huge pages, make sure start address
         * is aligned to huge page size.
@@ -1371,7 +1366,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
         */
        found = false;
        basic_ioctls = false;
-       for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
+       cur = vma;
+       do {
                cond_resched();
 
                BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@@ -1428,16 +1424,14 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
                        basic_ioctls = true;
 
                found = true;
-       }
+       } for_each_vma_range(vmi, cur, end);
        BUG_ON(!found);
 
-       mas_set(&mas, start);
-       prev = mas_prev(&mas, 0);
-       if (prev != vma)
-               mas_next(&mas, ULONG_MAX);
+       vma_iter_set(&vmi, start);
+       prev = vma_prev(&vmi);
 
        ret = 0;
-       do {
+       for_each_vma_range(vmi, vma, end) {
                cond_resched();
 
                BUG_ON(!vma_can_userfault(vma, vm_flags));
@@ -1458,30 +1452,25 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
                vma_end = min(end, vma->vm_end);
 
                new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
-               prev = vma_merge(mm, prev, start, vma_end, new_flags,
+               prev = vmi_vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
                                 vma->anon_vma, vma->vm_file, vma->vm_pgoff,
                                 vma_policy(vma),
                                 ((struct vm_userfaultfd_ctx){ ctx }),
                                 anon_vma_name(vma));
                if (prev) {
                        /* vma_merge() invalidated the mas */
-                       mas_pause(&mas);
                        vma = prev;
                        goto next;
                }
                if (vma->vm_start < start) {
-                       ret = split_vma(mm, vma, start, 1);
+                       ret = vmi_split_vma(&vmi, mm, vma, start, 1);
                        if (ret)
                                break;
-                       /* split_vma() invalidated the mas */
-                       mas_pause(&mas);
                }
                if (vma->vm_end > end) {
-                       ret = split_vma(mm, vma, end, 0);
+                       ret = vmi_split_vma(&vmi, mm, vma, end, 0);
                        if (ret)
                                break;
-                       /* split_vma() invalidated the mas */
-                       mas_pause(&mas);
                }
        next:
                /*
@@ -1498,8 +1487,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
        skip:
                prev = vma;
                start = vma->vm_end;
-               vma = mas_next(&mas, end - 1);
-       } while (vma);
+       }
+
 out_unlock:
        mmap_write_unlock(mm);
        mmput(mm);
@@ -1543,7 +1532,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
        bool found;
        unsigned long start, end, vma_end;
        const void __user *buf = (void __user *)arg;
-       MA_STATE(mas, &mm->mm_mt, 0, 0);
+       struct vma_iterator vmi;
 
        ret = -EFAULT;
        if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
@@ -1562,14 +1551,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
                goto out;
 
        mmap_write_lock(mm);
-       mas_set(&mas, start);
-       vma = mas_find(&mas, ULONG_MAX);
-       if (!vma)
-               goto out_unlock;
-
-       /* check that there's at least one vma in the range */
        ret = -EINVAL;
-       if (vma->vm_start >= end)
+       vma_iter_init(&vmi, mm, start);
+       vma = vma_find(&vmi, end);
+       if (!vma)
                goto out_unlock;
 
        /*
@@ -1587,8 +1572,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
         * Search for not compatible vmas.
         */
        found = false;
-       ret = -EINVAL;
-       for (cur = vma; cur; cur = mas_next(&mas, end - 1)) {
+       cur = vma;
+       do {
                cond_resched();
 
                BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
@@ -1605,16 +1590,13 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
                        goto out_unlock;
 
                found = true;
-       }
+       } for_each_vma_range(vmi, cur, end);
        BUG_ON(!found);
 
-       mas_set(&mas, start);
-       prev = mas_prev(&mas, 0);
-       if (prev != vma)
-               mas_next(&mas, ULONG_MAX);
-
+       vma_iter_set(&vmi, start);
+       prev = vma_prev(&vmi);
        ret = 0;
-       do {
+       for_each_vma_range(vmi, vma, end) {
                cond_resched();
 
                BUG_ON(!vma_can_userfault(vma, vma->vm_flags));
@@ -1650,26 +1632,23 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
                        uffd_wp_range(mm, vma, start, vma_end - start, false);
 
                new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
-               prev = vma_merge(mm, prev, start, vma_end, new_flags,
+               prev = vmi_vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
                                 vma->anon_vma, vma->vm_file, vma->vm_pgoff,
                                 vma_policy(vma),
                                 NULL_VM_UFFD_CTX, anon_vma_name(vma));
                if (prev) {
                        vma = prev;
-                       mas_pause(&mas);
                        goto next;
                }
                if (vma->vm_start < start) {
-                       ret = split_vma(mm, vma, start, 1);
+                       ret = vmi_split_vma(&vmi, mm, vma, start, 1);
                        if (ret)
                                break;
-                       mas_pause(&mas);
                }
                if (vma->vm_end > end) {
-                       ret = split_vma(mm, vma, end, 0);
+                       ret = vmi_split_vma(&vmi, mm, vma, end, 0);
                        if (ret)
                                break;
-                       mas_pause(&mas);
                }
        next:
                /*
@@ -1683,8 +1662,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
        skip:
                prev = vma;
                start = vma->vm_end;
-               vma = mas_next(&mas, end - 1);
-       } while (vma);
+       }
+
 out_unlock:
        mmap_write_unlock(mm);
        mmput(mm);