Merge remote-tracking branch 'stable/linux-5.15.y' into rpi-5.15.y
[platform/kernel/linux-rpi.git] / mm / mempolicy.c
index 1592b08..4472be6 100644 (file)
@@ -347,7 +347,7 @@ static void mpol_rebind_preferred(struct mempolicy *pol,
  */
 static void mpol_rebind_policy(struct mempolicy *pol, const nodemask_t *newmask)
 {
-       if (!pol)
+       if (!pol || pol->mode == MPOL_LOCAL)
                return;
        if (!mpol_store_user_nodemask(pol) &&
            nodes_equal(pol->w.cpuset_mems_allowed, *newmask))
@@ -783,7 +783,6 @@ static int vma_replace_policy(struct vm_area_struct *vma,
 static int mbind_range(struct mm_struct *mm, unsigned long start,
                       unsigned long end, struct mempolicy *new_pol)
 {
-       struct vm_area_struct *next;
        struct vm_area_struct *prev;
        struct vm_area_struct *vma;
        int err = 0;
@@ -798,8 +797,7 @@ static int mbind_range(struct mm_struct *mm, unsigned long start,
        if (start > vma->vm_start)
                prev = vma;
 
-       for (; vma && vma->vm_start < end; prev = vma, vma = next) {
-               next = vma->vm_next;
+       for (; vma && vma->vm_start < end; prev = vma, vma = vma->vm_next) {
                vmstart = max(start, vma->vm_start);
                vmend   = min(end, vma->vm_end);
 
@@ -813,10 +811,6 @@ static int mbind_range(struct mm_struct *mm, unsigned long start,
                                 new_pol, vma->vm_userfaultfd_ctx);
                if (prev) {
                        vma = prev;
-                       next = vma->vm_next;
-                       if (mpol_equal(vma_policy(vma), new_pol))
-                               continue;
-                       /* vma_merge() joined vma && vma->next, case 8 */
                        goto replace;
                }
                if (vma->vm_start != vmstart) {
@@ -856,16 +850,6 @@ static long do_set_mempolicy(unsigned short mode, unsigned short flags,
                goto out;
        }
 
-       if (flags & MPOL_F_NUMA_BALANCING) {
-               if (new && new->mode == MPOL_BIND) {
-                       new->flags |= (MPOL_F_MOF | MPOL_F_MORON);
-               } else {
-                       ret = -EINVAL;
-                       mpol_put(new);
-                       goto out;
-               }
-       }
-
        ret = mpol_set_nodemask(new, nodes, scratch);
        if (ret) {
                mpol_put(new);
@@ -1405,7 +1389,7 @@ static int get_nodes(nodemask_t *nodes, const unsigned long __user *nmask,
                unsigned long bits = min_t(unsigned long, maxnode, BITS_PER_LONG);
                unsigned long t;
 
-               if (get_bitmap(&t, &nmask[maxnode / BITS_PER_LONG], bits))
+               if (get_bitmap(&t, &nmask[(maxnode - 1) / BITS_PER_LONG], bits))
                        return -EFAULT;
 
                if (maxnode - bits >= MAX_NUMNODES) {
@@ -1458,7 +1442,11 @@ static inline int sanitize_mpol_flags(int *mode, unsigned short *flags)
                return -EINVAL;
        if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES))
                return -EINVAL;
-
+       if (*flags & MPOL_F_NUMA_BALANCING) {
+               if (*mode != MPOL_BIND)
+                       return -EINVAL;
+               *flags |= (MPOL_F_MOF | MPOL_F_MORON);
+       }
        return 0;
 }
 
@@ -2146,8 +2134,7 @@ struct page *alloc_pages_vma(gfp_t gfp, int order, struct vm_area_struct *vma,
                         * memory with both reclaim and compact as well.
                         */
                        if (!page && (gfp & __GFP_DIRECT_RECLAIM))
-                               page = __alloc_pages_node(hpage_node,
-                                                               gfp, order);
+                               page = __alloc_pages(gfp, order, hpage_node, nmask);
 
                        goto out;
                }
@@ -2574,6 +2561,7 @@ alloc_new:
        mpol_new = kmem_cache_alloc(policy_cache, GFP_KERNEL);
        if (!mpol_new)
                goto err_out;
+       atomic_set(&mpol_new->refcnt, 1);
        goto restart;
 }