mm/gup: take mmap_lock in get_dump_page()
[platform/kernel/linux-starfive.git] / mm / gup.c
index ae096ea..102877e 100644 (file)
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -329,6 +329,13 @@ void unpin_user_pages(struct page **pages, unsigned long npages)
        unsigned long index;
 
        /*
+        * If this WARN_ON() fires, then the system *might* be leaking pages (by
+        * leaving them pinned), but probably not. More likely, gup/pup returned
+        * a hard -ERRNO error to the caller, who erroneously passed it here.
+        */
+       if (WARN_ON(IS_ERR_VALUE(npages)))
+               return;
+       /*
         * TODO: this can be optimized for huge pages: if a series of pages is
         * physically contiguous and part of the same compound page, then a
         * single operation to the head page should suffice.
@@ -381,22 +388,13 @@ static int follow_pfn_pte(struct vm_area_struct *vma, unsigned long address,
 }
 
 /*
- * FOLL_FORCE or a forced COW break can write even to unwritable pte's,
- * but only after we've gone through a COW cycle and they are dirty.
+ * FOLL_FORCE can write to even unwritable pte's, but only
+ * after we've gone through a COW cycle and they are dirty.
  */
 static inline bool can_follow_write_pte(pte_t pte, unsigned int flags)
 {
-       return pte_write(pte) || ((flags & FOLL_COW) && pte_dirty(pte));
-}
-
-/*
- * A (separate) COW fault might break the page the other way and
- * get_user_pages() would return the page from what is now the wrong
- * VM. So we need to force a COW break at GUP time even for reads.
- */
-static inline bool should_force_cow_break(struct vm_area_struct *vma, unsigned int flags)
-{
-       return is_cow_mapping(vma->vm_flags) && (flags & (FOLL_GET | FOLL_PIN));
+       return pte_write(pte) ||
+               ((flags & FOLL_FORCE) && (flags & FOLL_COW) && pte_dirty(pte));
 }
 
 static struct page *follow_page_pte(struct vm_area_struct *vma,
@@ -843,7 +841,7 @@ static int get_gate_page(struct mm_struct *mm, unsigned long address,
                        goto unmap;
                *page = pte_page(*pte);
        }
-       if (unlikely(!try_get_page(*page))) {
+       if (unlikely(!try_grab_page(*page, gup_flags))) {
                ret = -ENOMEM;
                goto unmap;
        }
@@ -1067,11 +1065,9 @@ static long __get_user_pages(struct mm_struct *mm,
                                goto out;
                        }
                        if (is_vm_hugetlb_page(vma)) {
-                               if (should_force_cow_break(vma, foll_flags))
-                                       foll_flags |= FOLL_WRITE;
                                i = follow_hugetlb_page(mm, vma, pages, vmas,
                                                &start, &nr_pages, i,
-                                               foll_flags, locked);
+                                               gup_flags, locked);
                                if (locked && *locked == 0) {
                                        /*
                                         * We've got a VM_FAULT_RETRY
@@ -1085,10 +1081,6 @@ static long __get_user_pages(struct mm_struct *mm,
                                continue;
                        }
                }
-
-               if (should_force_cow_break(vma, foll_flags))
-                       foll_flags |= FOLL_WRITE;
-
 retry:
                /*
                 * If we have a pending SIGKILL, don't keep faulting pages and
@@ -1270,6 +1262,9 @@ static __always_inline long __get_user_pages_locked(struct mm_struct *mm,
                BUG_ON(*locked != 1);
        }
 
+       if (flags & FOLL_PIN)
+               atomic_set(&mm->has_pinned, 1);
+
        /*
         * FOLL_PIN and FOLL_GET are mutually exclusive. Traditional behavior
         * is to set FOLL_GET if the caller wants pages[] filled in (but has
@@ -1495,35 +1490,6 @@ int __mm_populate(unsigned long start, unsigned long len, int ignore_errors)
                mmap_read_unlock(mm);
        return ret;     /* 0 or negative error code */
 }
-
-/**
- * get_dump_page() - pin user page in memory while writing it to core dump
- * @addr: user address
- *
- * Returns struct page pointer of user page pinned for dump,
- * to be freed afterwards by put_page().
- *
- * Returns NULL on any kind of failure - a hole must then be inserted into
- * the corefile, to preserve alignment with its headers; and also returns
- * NULL wherever the ZERO_PAGE, or an anonymous pte_none, has been found -
- * allowing a hole to be left in the corefile to save diskspace.
- *
- * Called without mmap_lock, but after all other threads have been killed.
- */
-#ifdef CONFIG_ELF_CORE
-struct page *get_dump_page(unsigned long addr)
-{
-       struct vm_area_struct *vma;
-       struct page *page;
-
-       if (__get_user_pages(current->mm, addr, 1,
-                            FOLL_FORCE | FOLL_DUMP | FOLL_GET, &page, &vma,
-                            NULL) < 1)
-               return NULL;
-       flush_cache_page(vma, addr, page_to_pfn(page));
-       return page;
-}
-#endif /* CONFIG_ELF_CORE */
 #else /* CONFIG_MMU */
 static long __get_user_pages_locked(struct mm_struct *mm, unsigned long start,
                unsigned long nr_pages, struct page **pages,
@@ -1569,6 +1535,38 @@ finish_or_fault:
 }
 #endif /* !CONFIG_MMU */
 
+/**
+ * get_dump_page() - pin user page in memory while writing it to core dump
+ * @addr: user address
+ *
+ * Returns struct page pointer of user page pinned for dump,
+ * to be freed afterwards by put_page().
+ *
+ * Returns NULL on any kind of failure - a hole must then be inserted into
+ * the corefile, to preserve alignment with its headers; and also returns
+ * NULL wherever the ZERO_PAGE, or an anonymous pte_none, has been found -
+ * allowing a hole to be left in the corefile to save diskspace.
+ *
+ * Called without mmap_lock (takes and releases the mmap_lock by itself).
+ */
+#ifdef CONFIG_ELF_CORE
+struct page *get_dump_page(unsigned long addr)
+{
+       struct mm_struct *mm = current->mm;
+       struct page *page;
+       int locked = 1;
+       int ret;
+
+       if (mmap_read_lock_killable(mm))
+               return NULL;
+       ret = __get_user_pages_locked(mm, addr, 1, &page, NULL, &locked,
+                                     FOLL_FORCE | FOLL_DUMP | FOLL_GET);
+       if (locked)
+               mmap_read_unlock(mm);
+       return (ret == 1) ? page : NULL;
+}
+#endif /* CONFIG_ELF_CORE */
+
 #if defined(CONFIG_FS_DAX) || defined (CONFIG_CMA)
 static bool check_dax_vmas(struct vm_area_struct **vmas, long nr_pages)
 {
@@ -1759,6 +1757,25 @@ static __always_inline long __gup_longterm_locked(struct mm_struct *mm,
 }
 #endif /* CONFIG_FS_DAX || CONFIG_CMA */
 
+static bool is_valid_gup_flags(unsigned int gup_flags)
+{
+       /*
+        * FOLL_PIN must only be set internally by the pin_user_pages*() APIs,
+        * never directly by the caller, so enforce that with an assertion:
+        */
+       if (WARN_ON_ONCE(gup_flags & FOLL_PIN))
+               return false;
+       /*
+        * FOLL_PIN is a prerequisite to FOLL_LONGTERM. Another way of saying
+        * that is, FOLL_LONGTERM is a specific case, more restrictive case of
+        * FOLL_PIN.
+        */
+       if (WARN_ON_ONCE(gup_flags & FOLL_LONGTERM))
+               return false;
+
+       return true;
+}
+
 #ifdef CONFIG_MMU
 static long __get_user_pages_remote(struct mm_struct *mm,
                                    unsigned long start, unsigned long nr_pages,
@@ -1854,11 +1871,7 @@ long get_user_pages_remote(struct mm_struct *mm,
                unsigned int gup_flags, struct page **pages,
                struct vm_area_struct **vmas, int *locked)
 {
-       /*
-        * FOLL_PIN must only be set internally by the pin_user_pages*() APIs,
-        * never directly by the caller, so enforce that with an assertion:
-        */
-       if (WARN_ON_ONCE(gup_flags & FOLL_PIN))
+       if (!is_valid_gup_flags(gup_flags))
                return -EINVAL;
 
        return __get_user_pages_remote(mm, start, nr_pages, gup_flags,
@@ -1904,11 +1917,7 @@ long get_user_pages(unsigned long start, unsigned long nr_pages,
                unsigned int gup_flags, struct page **pages,
                struct vm_area_struct **vmas)
 {
-       /*
-        * FOLL_PIN must only be set internally by the pin_user_pages*() APIs,
-        * never directly by the caller, so enforce that with an assertion:
-        */
-       if (WARN_ON_ONCE(gup_flags & FOLL_PIN))
+       if (!is_valid_gup_flags(gup_flags))
                return -EINVAL;
 
        return __gup_longterm_locked(current->mm, start, nr_pages,
@@ -2500,13 +2509,13 @@ static int gup_huge_pgd(pgd_t orig, pgd_t *pgdp, unsigned long addr,
        return 1;
 }
 
-static int gup_pmd_range(pud_t pud, unsigned long addr, unsigned long end,
+static int gup_pmd_range(pud_t *pudp, pud_t pud, unsigned long addr, unsigned long end,
                unsigned int flags, struct page **pages, int *nr)
 {
        unsigned long next;
        pmd_t *pmdp;
 
-       pmdp = pmd_offset(&pud, addr);
+       pmdp = pmd_offset_lockless(pudp, pud, addr);
        do {
                pmd_t pmd = READ_ONCE(*pmdp);
 
@@ -2543,13 +2552,13 @@ static int gup_pmd_range(pud_t pud, unsigned long addr, unsigned long end,
        return 1;
 }
 
-static int gup_pud_range(p4d_t p4d, unsigned long addr, unsigned long end,
+static int gup_pud_range(p4d_t *p4dp, p4d_t p4d, unsigned long addr, unsigned long end,
                         unsigned int flags, struct page **pages, int *nr)
 {
        unsigned long next;
        pud_t *pudp;
 
-       pudp = pud_offset(&p4d, addr);
+       pudp = pud_offset_lockless(p4dp, p4d, addr);
        do {
                pud_t pud = READ_ONCE(*pudp);
 
@@ -2564,20 +2573,20 @@ static int gup_pud_range(p4d_t p4d, unsigned long addr, unsigned long end,
                        if (!gup_huge_pd(__hugepd(pud_val(pud)), addr,
                                         PUD_SHIFT, next, flags, pages, nr))
                                return 0;
-               } else if (!gup_pmd_range(pud, addr, next, flags, pages, nr))
+               } else if (!gup_pmd_range(pudp, pud, addr, next, flags, pages, nr))
                        return 0;
        } while (pudp++, addr = next, addr != end);
 
        return 1;
 }
 
-static int gup_p4d_range(pgd_t pgd, unsigned long addr, unsigned long end,
+static int gup_p4d_range(pgd_t *pgdp, pgd_t pgd, unsigned long addr, unsigned long end,
                         unsigned int flags, struct page **pages, int *nr)
 {
        unsigned long next;
        p4d_t *p4dp;
 
-       p4dp = p4d_offset(&pgd, addr);
+       p4dp = p4d_offset_lockless(pgdp, pgd, addr);
        do {
                p4d_t p4d = READ_ONCE(*p4dp);
 
@@ -2589,7 +2598,7 @@ static int gup_p4d_range(pgd_t pgd, unsigned long addr, unsigned long end,
                        if (!gup_huge_pd(__hugepd(p4d_val(p4d)), addr,
                                         P4D_SHIFT, next, flags, pages, nr))
                                return 0;
-               } else if (!gup_pud_range(p4d, addr, next, flags, pages, nr))
+               } else if (!gup_pud_range(p4dp, p4d, addr, next, flags, pages, nr))
                        return 0;
        } while (p4dp++, addr = next, addr != end);
 
@@ -2617,7 +2626,7 @@ static void gup_pgd_range(unsigned long addr, unsigned long end,
                        if (!gup_huge_pd(__hugepd(pgd_val(pgd)), addr,
                                         PGDIR_SHIFT, next, flags, pages, nr))
                                return;
-               } else if (!gup_p4d_range(pgd, addr, next, flags, pages, nr))
+               } else if (!gup_p4d_range(pgdp, pgd, addr, next, flags, pages, nr))
                        return;
        } while (pgdp++, addr = next, addr != end);
 }
@@ -2675,6 +2684,9 @@ static int internal_get_user_pages_fast(unsigned long start, int nr_pages,
                                       FOLL_FAST_ONLY)))
                return -EINVAL;
 
+       if (gup_flags & FOLL_PIN)
+               atomic_set(&current->mm->has_pinned, 1);
+
        if (!(gup_flags & FOLL_FAST_ONLY))
                might_lock_read(&current->mm->mmap_lock);
 
@@ -2689,19 +2701,6 @@ static int internal_get_user_pages_fast(unsigned long start, int nr_pages,
                return -EFAULT;
 
        /*
-        * The FAST_GUP case requires FOLL_WRITE even for pure reads,
-        * because get_user_pages() may need to cause an early COW in
-        * order to avoid confusing the normal COW routines. So only
-        * targets that are already writable are safe to do by just
-        * looking at the page tables.
-        *
-        * NOTE! With FOLL_FAST_ONLY we allow read-only gup_fast() here,
-        * because there is no slow path to fall back on. But you'd
-        * better be careful about possible COW pages - you'll get _a_
-        * COW page, but not necessarily the one you intended to get
-        * depending on what COW event happens after this. COW may break
-        * the page copy in a random direction.
-        *
         * Disable interrupts. The nested form is used, in order to allow
         * full, general purpose use of this routine.
         *
@@ -2714,8 +2713,6 @@ static int internal_get_user_pages_fast(unsigned long start, int nr_pages,
         */
        if (IS_ENABLED(CONFIG_HAVE_FAST_GUP) && gup_fast_permitted(start, end)) {
                unsigned long fast_flags = gup_flags;
-               if (!(gup_flags & FOLL_FAST_ONLY))
-                       fast_flags |= FOLL_WRITE;
 
                local_irq_save(flags);
                gup_pgd_range(addr, end, fast_flags, pages, &nr_pinned);
@@ -2810,11 +2807,7 @@ EXPORT_SYMBOL_GPL(get_user_pages_fast_only);
 int get_user_pages_fast(unsigned long start, int nr_pages,
                        unsigned int gup_flags, struct page **pages)
 {
-       /*
-        * FOLL_PIN must only be set internally by the pin_user_pages*() APIs,
-        * never directly by the caller, so enforce that:
-        */
-       if (WARN_ON_ONCE(gup_flags & FOLL_PIN))
+       if (!is_valid_gup_flags(gup_flags))
                return -EINVAL;
 
        /*