Merge tag 'nfsd-5.15-2' of git://git.kernel.org/pub/scm/linux/kernel/git/cel/linux
index 3ded6a5..886d614 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -10,6 +10,7 @@
 #include <linux/rmap.h>
 #include <linux/swap.h>
 #include <linux/swapops.h>
+#include <linux/secretmem.h>
 
 #include <linux/sched/signal.h>
 #include <linux/rwsem.h>
@@ -44,6 +45,23 @@ static void hpage_pincount_sub(struct page *page, int refs)
        atomic_sub(refs, compound_pincount_ptr(page));
 }
 
+/* Equivalent to calling put_page() @refs times. */
+static void put_page_refs(struct page *page, int refs)
+{
+#ifdef CONFIG_DEBUG_VM
+       if (VM_WARN_ON_ONCE_PAGE(page_ref_count(page) < refs, page))
+               return;
+#endif
+
+       /*
+        * Calling put_page() for each ref is unnecessarily slow. Only the last
+        * ref needs a put_page().
+        */
+       if (refs > 1)
+               page_ref_sub(page, refs - 1);
+       put_page(page);
+}
+
 /*
  * Return the compound head page with ref appropriately incremented,
  * or NULL if that failed.
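
The batching in put_page_refs() can be modelled in userspace. A minimal
sketch (plain counters standing in for page_ref_count() and the page-free
path; not kernel code) showing that one page_ref_sub() plus a single final
put_page() has the same net effect as @refs individual puts:

	#include <assert.h>

	static int refcount = 3;		/* models page_ref_count(page) */
	static int freed;

	static void put_one(void)		/* models put_page() */
	{
		if (--refcount == 0)
			freed = 1;		/* only the last ref frees the page */
	}

	static void put_refs(int refs)		/* models put_page_refs() */
	{
		if (refs > 1)
			refcount -= refs - 1;	/* models page_ref_sub() */
		put_one();
	}

	int main(void)
	{
		put_refs(3);
		assert(refcount == 0 && freed == 1);
		return 0;
	}
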
@@ -56,13 +74,35 @@ static inline struct page *try_get_compound_head(struct page *page, int refs)
                return NULL;
        if (unlikely(!page_cache_add_speculative(head, refs)))
                return NULL;
+
+       /*
+        * At this point we have a stable reference to the head page; but it
+        * could be that between the compound_head() lookup and the refcount
+        * increment, the compound page was split, in which case we'd end up
+        * holding a reference on a page that has nothing to do with the page
+        * we were given anymore.
+        * So now that the head page is stable, recheck that the pages still
+        * belong together.
+        */
+       if (unlikely(compound_head(page) != head)) {
+               put_page_refs(head, refs);
+               return NULL;
+       }
+
        return head;
 }
 
-/*
+/**
  * try_grab_compound_head() - attempt to elevate a page's refcount, by a
  * flags-dependent amount.
  *
+ * Even though the name includes "compound_head", this function is still
+ * appropriate for callers that have a non-compound @page to get.
+ *
+ * @page:  pointer to page to be grabbed
+ * @refs:  the value to (effectively) add to the page's refcount
+ * @flags: gup flags: these are the FOLL_* flag values.
+ *
  * "grab" names in this file mean, "look at flags to decide whether to use
  * FOLL_PIN or FOLL_GET behavior, when incrementing the page's refcount.
  *
@@ -70,22 +110,26 @@ static inline struct page *try_get_compound_head(struct page *page, int refs)
  * same time. (That's true throughout the get_user_pages*() and
  * pin_user_pages*() APIs.) Cases:
  *
- *    FOLL_GET: page's refcount will be incremented by 1.
- *    FOLL_PIN: page's refcount will be incremented by GUP_PIN_COUNTING_BIAS.
+ *    FOLL_GET: page's refcount will be incremented by @refs.
+ *
+ *    FOLL_PIN on compound pages that are > two pages long: page's refcount will
+ *    be incremented by @refs, and page[2].hpage_pinned_refcount will be
+ *    incremented by @refs.
+ *
+ *    FOLL_PIN on normal pages, or compound pages that are two pages long:
+ *    page's refcount will be incremented by @refs * GUP_PIN_COUNTING_BIAS.
  *
  * Return: head page (with refcount appropriately incremented) for success, or
  * NULL upon failure. If neither FOLL_GET nor FOLL_PIN was set, that's
  * considered failure, and furthermore, a likely bug in the caller, so a warning
  * is also emitted.
  */
-__maybe_unused struct page *try_grab_compound_head(struct page *page,
-                                                  int refs, unsigned int flags)
+struct page *try_grab_compound_head(struct page *page,
+                                   int refs, unsigned int flags)
 {
        if (flags & FOLL_GET)
                return try_get_compound_head(page, refs);
        else if (flags & FOLL_PIN) {
-               int orig_refs = refs;
-
                /*
                 * Can't do FOLL_LONGTERM + FOLL_PIN gup fast path if not in a
                 * right zone, so fail and let the caller fall back to the slow
@@ -96,25 +140,30 @@ __maybe_unused struct page *try_grab_compound_head(struct page *page,
                        return NULL;
 
                /*
+                * CAUTION: Don't use compound_head() on the page before this
+                * point; the result won't be stable.
+                */
+               page = try_get_compound_head(page, refs);
+               if (!page)
+                       return NULL;
+
+               /*
                 * When pinning a compound page of order > 1 (which is what
                 * hpage_pincount_available() checks for), use an exact count to
                 * track it, via hpage_pincount_add/_sub().
                 *
                 * However, be sure to *also* increment the normal page refcount
                 * field at least once, so that the page really is pinned.
+                * That's why the refcount from the earlier
+                * try_get_compound_head() is left intact.
                 */
-               if (!hpage_pincount_available(page))
-                       refs *= GUP_PIN_COUNTING_BIAS;
-
-               page = try_get_compound_head(page, refs);
-               if (!page)
-                       return NULL;
-
                if (hpage_pincount_available(page))
                        hpage_pincount_add(page, refs);
+               else
+                       page_ref_add(page, refs * (GUP_PIN_COUNTING_BIAS - 1));
 
                mod_node_page_state(page_pgdat(page), NR_FOLL_PIN_ACQUIRED,
-                                   orig_refs);
+                                   refs);
 
                return page;
        }
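
To make the FOLL_PIN bookkeeping above concrete, here is a small userspace
sketch of the net counter updates (plain ints stand in for the page counters;
GUP_PIN_COUNTING_BIAS is 1024, i.e. (1U << 10), in include/linux/mm.h):

	#include <assert.h>

	#define GUP_PIN_COUNTING_BIAS	1024

	int main(void)
	{
		int refs = 2;
		int refcount = 1;	/* models page_ref_count(head) */
		int pincount = 0;	/* models page[2].hpage_pinned_refcount */

		/* Normal page (or order-1 compound): the whole pin is folded
		 * into the refcount. */
		refcount += refs;				/* try_get_compound_head() */
		refcount += refs * (GUP_PIN_COUNTING_BIAS - 1);	/* page_ref_add() */
		assert(refcount == 1 + refs * GUP_PIN_COUNTING_BIAS);

		/* Compound page of order > 1: the exact pin count is tracked
		 * separately; the refcount grows only by @refs. */
		refcount = 1;
		refcount += refs;	/* try_get_compound_head() */
		pincount += refs;	/* hpage_pincount_add() */
		assert(refcount == 1 + refs && pincount == refs);
		return 0;
	}
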
@@ -135,14 +184,7 @@ static void put_compound_head(struct page *page, int refs, unsigned int flags)
                        refs *= GUP_PIN_COUNTING_BIAS;
        }
 
-       VM_BUG_ON_PAGE(page_ref_count(page) < refs, page);
-       /*
-        * Calling put_page() for each ref is unnecessarily slow. Only the last
-        * ref needs a put_page().
-        */
-       if (refs > 1)
-               page_ref_sub(page, refs - 1);
-       put_page(page);
+       put_page_refs(page, refs);
 }
 
 /**
@@ -157,10 +199,8 @@ static void put_compound_head(struct page *page, int refs, unsigned int flags)
  * @flags:   gup flags: these are the FOLL_* flag values.
  *
  * Either FOLL_PIN or FOLL_GET (or neither) may be set, but not both at the same
- * time. Cases:
- *
- *    FOLL_GET: page's refcount will be incremented by 1.
- *    FOLL_PIN: page's refcount will be incremented by GUP_PIN_COUNTING_BIAS.
+ * time. Cases: please see the try_grab_compound_head() documentation, with
+ * "refs=1".
  *
  * Return: true for success, or if no action was required (if neither FOLL_PIN
  * nor FOLL_GET was set, nothing is done). False for failure: FOLL_GET or
@@ -168,35 +208,10 @@ static void put_compound_head(struct page *page, int refs, unsigned int flags)
  */
 bool __must_check try_grab_page(struct page *page, unsigned int flags)
 {
-       WARN_ON_ONCE((flags & (FOLL_GET | FOLL_PIN)) == (FOLL_GET | FOLL_PIN));
-
-       if (flags & FOLL_GET)
-               return try_get_page(page);
-       else if (flags & FOLL_PIN) {
-               int refs = 1;
-
-               page = compound_head(page);
-
-               if (WARN_ON_ONCE(page_ref_count(page) <= 0))
-                       return false;
-
-               if (hpage_pincount_available(page))
-                       hpage_pincount_add(page, 1);
-               else
-                       refs = GUP_PIN_COUNTING_BIAS;
-
-               /*
-                * Similar to try_grab_compound_head(): even if using the
-                * hpage_pincount_add/_sub() routines, be sure to
-                * *also* increment the normal page refcount field at least
-                * once, so that the page really is pinned.
-                */
-               page_ref_add(page, refs);
-
-               mod_node_page_state(page_pgdat(page), NR_FOLL_PIN_ACQUIRED, 1);
-       }
+       if (!(flags & (FOLL_GET | FOLL_PIN)))
+               return true;
 
-       return true;
+       return try_grab_compound_head(page, 1, flags);
 }
 
 /**
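
One subtlety in the simplified try_grab_page() above: the function is declared
bool, yet it returns the struct page * from try_grab_compound_head() directly.
That is well-defined C: the pointer is implicitly converted to bool, so a
non-NULL head page becomes true and NULL (failure) becomes false, preserving
the old return contract.
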
@@ -392,6 +407,17 @@ void unpin_user_pages(struct page **pages, unsigned long npages)
 }
 EXPORT_SYMBOL(unpin_user_pages);
 
+/*
+ * Set MMF_HAS_PINNED if it is not set yet; once set, the flag stays for the
+ * mm's lifetime.  Avoid setting the bit unless necessary, since doing so
+ * repeatedly could cause write cache-line bouncing on large SMP machines
+ * during concurrent pinned gups.
+ */
+static inline void mm_set_has_pinned_flag(unsigned long *mm_flags)
+{
+       if (!test_bit(MMF_HAS_PINNED, mm_flags))
+               set_bit(MMF_HAS_PINNED, mm_flags);
+}
+
 #ifdef CONFIG_MMU
 static struct page *no_page_table(struct vm_area_struct *vma,
                unsigned int flags)
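
The check-before-set in mm_set_has_pinned_flag() keeps the flag read-mostly:
in the common already-set case the cache line is only read, never dirtied. A
minimal userspace sketch of the same pattern using C11 atomics (the kernel
code uses test_bit()/set_bit() instead):

	#include <stdatomic.h>
	#include <stdbool.h>

	static void set_flag_once(atomic_bool *flag)
	{
		/* Load first: if the flag is already set, the cache line
		 * stays shared and concurrent callers do not bounce it
		 * between CPUs. */
		if (!atomic_load_explicit(flag, memory_order_relaxed))
			atomic_store_explicit(flag, true, memory_order_relaxed);
	}
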
@@ -816,6 +842,9 @@ struct page *follow_page(struct vm_area_struct *vma, unsigned long address,
        struct follow_page_context ctx = { NULL };
        struct page *page;
 
+       if (vma_is_secretmem(vma))
+               return NULL;
+
        page = follow_page_mask(vma, address, foll_flags, &ctx);
        if (ctx.pgmap)
                put_dev_pagemap(ctx.pgmap);
@@ -949,6 +978,9 @@ static int check_vma_flags(struct vm_area_struct *vma, unsigned long gup_flags)
        if ((gup_flags & FOLL_LONGTERM) && vma_is_fsdax(vma))
                return -EOPNOTSUPP;
 
+       if (vma_is_secretmem(vma))
+               return -EFAULT;
+
        if (write) {
                if (!(vm_flags & VM_WRITE)) {
                        if (!(gup_flags & FOLL_FORCE))
@@ -1105,7 +1137,6 @@ static long __get_user_pages(struct mm_struct *mm,
                                         * We must stop here.
                                         */
                                        BUG_ON(gup_flags & FOLL_NOWAIT);
-                                       BUG_ON(ret != 0);
                                        goto out;
                                }
                                continue;
@@ -1230,7 +1261,7 @@ int fixup_user_fault(struct mm_struct *mm,
                     bool *unlocked)
 {
        struct vm_area_struct *vma;
-       vm_fault_t ret, major = 0;
+       vm_fault_t ret;
 
        address = untagged_addr(address);
 
@@ -1250,7 +1281,6 @@ retry:
                return -EINTR;
 
        ret = handle_mm_fault(vma, address, fault_flags, NULL);
-       major |= ret & VM_FAULT_MAJOR;
        if (ret & VM_FAULT_ERROR) {
                int err = vm_fault_to_errno(ret, 0);
 
@@ -1293,7 +1323,7 @@ static __always_inline long __get_user_pages_locked(struct mm_struct *mm,
        }
 
        if (flags & FOLL_PIN)
-               atomic_set(&mm->has_pinned, 1);
+               mm_set_has_pinned_flag(&mm->flags);
 
        /*
         * FOLL_PIN and FOLL_GET are mutually exclusive. Traditional behavior
@@ -1429,8 +1459,8 @@ long populate_vma_page_range(struct vm_area_struct *vma,
        unsigned long nr_pages = (end - start) / PAGE_SIZE;
        int gup_flags;
 
-       VM_BUG_ON(start & ~PAGE_MASK);
-       VM_BUG_ON(end   & ~PAGE_MASK);
+       VM_BUG_ON(!PAGE_ALIGNED(start));
+       VM_BUG_ON(!PAGE_ALIGNED(end));
        VM_BUG_ON_VMA(start < vma->vm_start, vma);
        VM_BUG_ON_VMA(end   > vma->vm_end, vma);
        mmap_assert_locked(mm);
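
The PAGE_ALIGNED() form is exactly the mask test it replaces, since PAGE_MASK
is ~(PAGE_SIZE - 1). A quick userspace check (a sketch assuming a 4 KiB page
size for the demo; in the kernel PAGE_ALIGNED() expands via IS_ALIGNED()):

	#include <assert.h>

	#define PAGE_SIZE	4096UL
	#define PAGE_MASK	(~(PAGE_SIZE - 1))
	#define PAGE_ALIGNED(a)	(((a) & (PAGE_SIZE - 1)) == 0)

	int main(void)
	{
		for (unsigned long a = 0; a < 4 * PAGE_SIZE; a += 123)
			assert(PAGE_ALIGNED(a) == !(a & ~PAGE_MASK));
		return 0;
	}
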
@@ -1462,6 +1492,67 @@ long populate_vma_page_range(struct vm_area_struct *vma,
 }
 
 /*
+ * faultin_vma_page_range() - populate (prefault) page tables inside the
+ *                           given VMA range, readable or writable
+ *
+ * This takes care of mlocking the pages, too, if VM_LOCKED is set.
+ *
+ * @vma: target vma
+ * @start: start address
+ * @end: end address
+ * @write: whether to prefault for writing (otherwise for reading)
+ * @locked: whether the mmap_lock is still held
+ *
+ * Returns either the number of processed pages in the vma, or a negative error
+ * code on error (see __get_user_pages()).
+ *
+ * vma->vm_mm->mmap_lock must be held. The range must be page-aligned and
+ * covered by the VMA.
+ *
+ * If @locked is NULL, it may be held for read or write and will be unperturbed.
+ *
+ * If @locked is non-NULL, it must be held for read only and may be released.
+ * If it is released, *@locked will be set to 0.
+ */
+long faultin_vma_page_range(struct vm_area_struct *vma, unsigned long start,
+                           unsigned long end, bool write, int *locked)
+{
+       struct mm_struct *mm = vma->vm_mm;
+       unsigned long nr_pages = (end - start) / PAGE_SIZE;
+       int gup_flags;
+
+       VM_BUG_ON(!PAGE_ALIGNED(start));
+       VM_BUG_ON(!PAGE_ALIGNED(end));
+       VM_BUG_ON_VMA(start < vma->vm_start, vma);
+       VM_BUG_ON_VMA(end > vma->vm_end, vma);
+       mmap_assert_locked(mm);
+
+       /*
+        * FOLL_TOUCH: Mark page accessed and thereby young; will also mark
+        *             the page dirty with FOLL_WRITE -- which doesn't make a
+        *             difference with !FOLL_FORCE, because the page is writable
+        *             in the page table.
+        * FOLL_HWPOISON: Return -EHWPOISON instead of -EFAULT when we hit
+        *                a poisoned page.
+        * FOLL_POPULATE: Always populate memory with VM_LOCKONFAULT.
+        * !FOLL_FORCE: Require proper access permissions.
+        */
+       gup_flags = FOLL_TOUCH | FOLL_POPULATE | FOLL_MLOCK | FOLL_HWPOISON;
+       if (write)
+               gup_flags |= FOLL_WRITE;
+
+       /*
+        * We want to report -EINVAL instead of -EFAULT for any permission
+        * problems or incompatible mappings.
+        */
+       if (check_vma_flags(vma, gup_flags))
+               return -EINVAL;
+
+       return __get_user_pages(mm, start, nr_pages, gup_flags,
+                               NULL, NULL, locked);
+}
+
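
A hypothetical caller of the new helper (a sketch, not part of this patch; in
this cycle the real user is the MADV_POPULATE_(READ|WRITE) implementation)
might look roughly like:

	int locked = 1;
	long pages;

	mmap_read_lock(mm);
	pages = faultin_vma_page_range(vma, start, end, true, &locked);
	if (locked)
		mmap_read_unlock(mm);
	if (pages < 0)
		return pages;	/* e.g. -EINVAL for an incompatible mapping */
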
+/*
  * __mm_populate - populate and/or mlock pages within a range of address space.
  *
  * This is used to implement mlock() and the MAP_POPULATE / MAP_LOCKED mmap
@@ -1668,7 +1759,7 @@ static long check_and_migrate_movable_pages(unsigned long nr_pages,
        if (!list_empty(&movable_page_list)) {
                ret = migrate_pages(&movable_page_list, alloc_migration_target,
                                    NULL, (unsigned long)&mtc, MIGRATE_SYNC,
-                                   MR_LONGTERM_PIN);
+                                   MR_LONGTERM_PIN, NULL);
                if (ret && !list_empty(&movable_page_list))
                        putback_movable_pages(&movable_page_list);
        }
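
The extra NULL matches the new trailing out-parameter that migrate_pages()
grew this cycle for reporting the number of successfully migrated pages;
passing NULL opts out of that reporting and keeps the previous behavior.
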
@@ -2073,6 +2164,11 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
                if (!head)
                        goto pte_unmap;
 
+               if (unlikely(page_is_secretmem(page))) {
+                       put_compound_head(head, 1, flags);
+                       goto pte_unmap;
+               }
+
                if (unlikely(pte_val(pte) != pte_val(*ptep))) {
                        put_compound_head(head, 1, flags);
                        goto pte_unmap;
@@ -2132,6 +2228,7 @@ static int __gup_device_huge(unsigned long pfn, unsigned long addr,
 {
        int nr_start = *nr;
        struct dev_pagemap *pgmap = NULL;
+       int ret = 1;
 
        do {
                struct page *page = pfn_to_page(pfn);
@@ -2139,21 +2236,22 @@ static int __gup_device_huge(unsigned long pfn, unsigned long addr,
                pgmap = get_dev_pagemap(pfn, pgmap);
                if (unlikely(!pgmap)) {
                        undo_dev_pagemap(nr, nr_start, flags, pages);
-                       return 0;
+                       ret = 0;
+                       break;
                }
                SetPageReferenced(page);
                pages[*nr] = page;
                if (unlikely(!try_grab_page(page, flags))) {
                        undo_dev_pagemap(nr, nr_start, flags, pages);
-                       return 0;
+                       ret = 0;
+                       break;
                }
                (*nr)++;
                pfn++;
        } while (addr += PAGE_SIZE, addr != end);
 
-       if (pgmap)
-               put_dev_pagemap(pgmap);
-       return 1;
+       put_dev_pagemap(pgmap);
+       return ret;
 }
 
 static int __gup_device_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
@@ -2614,7 +2712,7 @@ static int internal_get_user_pages_fast(unsigned long start,
                return -EINVAL;
 
        if (gup_flags & FOLL_PIN)
-               atomic_set(&current->mm->has_pinned, 1);
+               mm_set_has_pinned_flag(&current->mm->flags);
 
        if (!(gup_flags & FOLL_FAST_ONLY))
                might_lock_read(&current->mm->mmap_lock);