73c8af4827fe87ca521e4b9486e53df6ca0981a2
[platform/kernel/linux-rpi.git] / mm / hmm.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Copyright 2013 Red Hat Inc.
4  *
5  * Authors: Jérôme Glisse <jglisse@redhat.com>
6  */
7 /*
8  * Refer to include/linux/hmm.h for information about heterogeneous memory
9  * management or HMM for short.
10  */
11 #include <linux/mm.h>
12 #include <linux/hmm.h>
13 #include <linux/init.h>
14 #include <linux/rmap.h>
15 #include <linux/swap.h>
16 #include <linux/slab.h>
17 #include <linux/sched.h>
18 #include <linux/mmzone.h>
19 #include <linux/pagemap.h>
20 #include <linux/swapops.h>
21 #include <linux/hugetlb.h>
22 #include <linux/memremap.h>
23 #include <linux/sched/mm.h>
24 #include <linux/jump_label.h>
25 #include <linux/dma-mapping.h>
26 #include <linux/mmu_notifier.h>
27 #include <linux/memory_hotplug.h>
28
/* Size of one section: 1UL shifted left by PA_SECTION_SHIFT (defined elsewhere). */
#define PA_SECTION_SIZE (1UL << PA_SECTION_SHIFT)
30
31 #if IS_ENABLED(CONFIG_HMM_MIRROR)
32 static const struct mmu_notifier_ops hmm_mmu_notifier_ops;
33
34 /**
35  * hmm_get_or_create - register HMM against an mm (HMM internal)
36  *
37  * @mm: mm struct to attach to
38  * Returns: returns an HMM object, either by referencing the existing
39  *          (per-process) object, or by creating a new one.
40  *
41  * This is not intended to be used directly by device drivers. If mm already
42  * has an HMM struct then it get a reference on it and returns it. Otherwise
43  * it allocates an HMM struct, initializes it, associate it with the mm and
44  * returns it.
45  */
46 static struct hmm *hmm_get_or_create(struct mm_struct *mm)
47 {
48         struct hmm *hmm;
49
50         lockdep_assert_held_exclusive(&mm->mmap_sem);
51
52         /* Abuse the page_table_lock to also protect mm->hmm. */
53         spin_lock(&mm->page_table_lock);
54         hmm = mm->hmm;
55         if (mm->hmm && kref_get_unless_zero(&mm->hmm->kref))
56                 goto out_unlock;
57         spin_unlock(&mm->page_table_lock);
58
59         hmm = kmalloc(sizeof(*hmm), GFP_KERNEL);
60         if (!hmm)
61                 return NULL;
62         init_waitqueue_head(&hmm->wq);
63         INIT_LIST_HEAD(&hmm->mirrors);
64         init_rwsem(&hmm->mirrors_sem);
65         hmm->mmu_notifier.ops = NULL;
66         INIT_LIST_HEAD(&hmm->ranges);
67         mutex_init(&hmm->lock);
68         kref_init(&hmm->kref);
69         hmm->notifiers = 0;
70         hmm->dead = false;
71         hmm->mm = mm;
72
73         hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops;
74         if (__mmu_notifier_register(&hmm->mmu_notifier, mm)) {
75                 kfree(hmm);
76                 return NULL;
77         }
78
79         mmgrab(hmm->mm);
80
81         /*
82          * We hold the exclusive mmap_sem here so we know that mm->hmm is
83          * still NULL or 0 kref, and is safe to update.
84          */
85         spin_lock(&mm->page_table_lock);
86         mm->hmm = hmm;
87
88 out_unlock:
89         spin_unlock(&mm->page_table_lock);
90         return hmm;
91 }
92
93 static void hmm_free_rcu(struct rcu_head *rcu)
94 {
95         struct hmm *hmm = container_of(rcu, struct hmm, rcu);
96
97         mmdrop(hmm->mm);
98         kfree(hmm);
99 }
100
101 static void hmm_free(struct kref *kref)
102 {
103         struct hmm *hmm = container_of(kref, struct hmm, kref);
104
105         spin_lock(&hmm->mm->page_table_lock);
106         if (hmm->mm->hmm == hmm)
107                 hmm->mm->hmm = NULL;
108         spin_unlock(&hmm->mm->page_table_lock);
109
110         mmu_notifier_unregister_no_release(&hmm->mmu_notifier, hmm->mm);
111         mmu_notifier_call_srcu(&hmm->rcu, hmm_free_rcu);
112 }
113
114 static inline void hmm_put(struct hmm *hmm)
115 {
116         kref_put(&hmm->kref, hmm_free);
117 }
118
119 static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
120 {
121         struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
122         struct hmm_mirror *mirror;
123         struct hmm_range *range;
124
125         /* Bail out if hmm is in the process of being freed */
126         if (!kref_get_unless_zero(&hmm->kref))
127                 return;
128
129         /* Report this HMM as dying. */
130         hmm->dead = true;
131
132         /* Wake-up everyone waiting on any range. */
133         mutex_lock(&hmm->lock);
134         list_for_each_entry(range, &hmm->ranges, list)
135                 range->valid = false;
136         wake_up_all(&hmm->wq);
137         mutex_unlock(&hmm->lock);
138
139         down_write(&hmm->mirrors_sem);
140         mirror = list_first_entry_or_null(&hmm->mirrors, struct hmm_mirror,
141                                           list);
142         while (mirror) {
143                 list_del_init(&mirror->list);
144                 if (mirror->ops->release) {
145                         /*
146                          * Drop mirrors_sem so the release callback can wait
147                          * on any pending work that might itself trigger a
148                          * mmu_notifier callback and thus would deadlock with
149                          * us.
150                          */
151                         up_write(&hmm->mirrors_sem);
152                         mirror->ops->release(mirror);
153                         down_write(&hmm->mirrors_sem);
154                 }
155                 mirror = list_first_entry_or_null(&hmm->mirrors,
156                                                   struct hmm_mirror, list);
157         }
158         up_write(&hmm->mirrors_sem);
159
160         hmm_put(hmm);
161 }
162
163 static int hmm_invalidate_range_start(struct mmu_notifier *mn,
164                         const struct mmu_notifier_range *nrange)
165 {
166         struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
167         struct hmm_mirror *mirror;
168         struct hmm_update update;
169         struct hmm_range *range;
170         int ret = 0;
171
172         if (!kref_get_unless_zero(&hmm->kref))
173                 return 0;
174
175         update.start = nrange->start;
176         update.end = nrange->end;
177         update.event = HMM_UPDATE_INVALIDATE;
178         update.blockable = mmu_notifier_range_blockable(nrange);
179
180         if (mmu_notifier_range_blockable(nrange))
181                 mutex_lock(&hmm->lock);
182         else if (!mutex_trylock(&hmm->lock)) {
183                 ret = -EAGAIN;
184                 goto out;
185         }
186         hmm->notifiers++;
187         list_for_each_entry(range, &hmm->ranges, list) {
188                 if (update.end < range->start || update.start >= range->end)
189                         continue;
190
191                 range->valid = false;
192         }
193         mutex_unlock(&hmm->lock);
194
195         if (mmu_notifier_range_blockable(nrange))
196                 down_read(&hmm->mirrors_sem);
197         else if (!down_read_trylock(&hmm->mirrors_sem)) {
198                 ret = -EAGAIN;
199                 goto out;
200         }
201         list_for_each_entry(mirror, &hmm->mirrors, list) {
202                 int ret;
203
204                 ret = mirror->ops->sync_cpu_device_pagetables(mirror, &update);
205                 if (!update.blockable && ret == -EAGAIN)
206                         break;
207         }
208         up_read(&hmm->mirrors_sem);
209
210 out:
211         hmm_put(hmm);
212         return ret;
213 }
214
215 static void hmm_invalidate_range_end(struct mmu_notifier *mn,
216                         const struct mmu_notifier_range *nrange)
217 {
218         struct hmm *hmm = container_of(mn, struct hmm, mmu_notifier);
219
220         if (!kref_get_unless_zero(&hmm->kref))
221                 return;
222
223         mutex_lock(&hmm->lock);
224         hmm->notifiers--;
225         if (!hmm->notifiers) {
226                 struct hmm_range *range;
227
228                 list_for_each_entry(range, &hmm->ranges, list) {
229                         if (range->valid)
230                                 continue;
231                         range->valid = true;
232                 }
233                 wake_up_all(&hmm->wq);
234         }
235         mutex_unlock(&hmm->lock);
236
237         hmm_put(hmm);
238 }
239
240 static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
241         .release                = hmm_release,
242         .invalidate_range_start = hmm_invalidate_range_start,
243         .invalidate_range_end   = hmm_invalidate_range_end,
244 };
245
246 /*
247  * hmm_mirror_register() - register a mirror against an mm
248  *
249  * @mirror: new mirror struct to register
250  * @mm: mm to register against
251  * Return: 0 on success, -ENOMEM if no memory, -EINVAL if invalid arguments
252  *
253  * To start mirroring a process address space, the device driver must register
254  * an HMM mirror struct.
255  *
256  * THE mm->mmap_sem MUST BE HELD IN WRITE MODE !
257  */
258 int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
259 {
260         /* Sanity check */
261         if (!mm || !mirror || !mirror->ops)
262                 return -EINVAL;
263
264         mirror->hmm = hmm_get_or_create(mm);
265         if (!mirror->hmm)
266                 return -ENOMEM;
267
268         down_write(&mirror->hmm->mirrors_sem);
269         list_add(&mirror->list, &mirror->hmm->mirrors);
270         up_write(&mirror->hmm->mirrors_sem);
271
272         return 0;
273 }
274 EXPORT_SYMBOL(hmm_mirror_register);
275
276 /*
277  * hmm_mirror_unregister() - unregister a mirror
278  *
279  * @mirror: mirror struct to unregister
280  *
281  * Stop mirroring a process address space, and cleanup.
282  */
283 void hmm_mirror_unregister(struct hmm_mirror *mirror)
284 {
285         struct hmm *hmm = READ_ONCE(mirror->hmm);
286
287         if (hmm == NULL)
288                 return;
289
290         down_write(&hmm->mirrors_sem);
291         list_del_init(&mirror->list);
292         /* To protect us against double unregister ... */
293         mirror->hmm = NULL;
294         up_write(&hmm->mirrors_sem);
295
296         hmm_put(hmm);
297 }
298 EXPORT_SYMBOL(hmm_mirror_unregister);
299
300 struct hmm_vma_walk {
301         struct hmm_range        *range;
302         struct dev_pagemap      *pgmap;
303         unsigned long           last;
304         bool                    fault;
305         bool                    block;
306 };
307
308 static int hmm_vma_do_fault(struct mm_walk *walk, unsigned long addr,
309                             bool write_fault, uint64_t *pfn)
310 {
311         unsigned int flags = FAULT_FLAG_REMOTE;
312         struct hmm_vma_walk *hmm_vma_walk = walk->private;
313         struct hmm_range *range = hmm_vma_walk->range;
314         struct vm_area_struct *vma = walk->vma;
315         vm_fault_t ret;
316
317         flags |= hmm_vma_walk->block ? 0 : FAULT_FLAG_ALLOW_RETRY;
318         flags |= write_fault ? FAULT_FLAG_WRITE : 0;
319         ret = handle_mm_fault(vma, addr, flags);
320         if (ret & VM_FAULT_RETRY)
321                 return -EAGAIN;
322         if (ret & VM_FAULT_ERROR) {
323                 *pfn = range->values[HMM_PFN_ERROR];
324                 return -EFAULT;
325         }
326
327         return -EBUSY;
328 }
329
330 static int hmm_pfns_bad(unsigned long addr,
331                         unsigned long end,
332                         struct mm_walk *walk)
333 {
334         struct hmm_vma_walk *hmm_vma_walk = walk->private;
335         struct hmm_range *range = hmm_vma_walk->range;
336         uint64_t *pfns = range->pfns;
337         unsigned long i;
338
339         i = (addr - range->start) >> PAGE_SHIFT;
340         for (; addr < end; addr += PAGE_SIZE, i++)
341                 pfns[i] = range->values[HMM_PFN_ERROR];
342
343         return 0;
344 }
345
346 /*
347  * hmm_vma_walk_hole() - handle a range lacking valid pmd or pte(s)
348  * @start: range virtual start address (inclusive)
349  * @end: range virtual end address (exclusive)
350  * @fault: should we fault or not ?
351  * @write_fault: write fault ?
352  * @walk: mm_walk structure
353  * Return: 0 on success, -EBUSY after page fault, or page fault error
354  *
355  * This function will be called whenever pmd_none() or pte_none() returns true,
356  * or whenever there is no page directory covering the virtual address range.
357  */
358 static int hmm_vma_walk_hole_(unsigned long addr, unsigned long end,
359                               bool fault, bool write_fault,
360                               struct mm_walk *walk)
361 {
362         struct hmm_vma_walk *hmm_vma_walk = walk->private;
363         struct hmm_range *range = hmm_vma_walk->range;
364         uint64_t *pfns = range->pfns;
365         unsigned long i, page_size;
366
367         hmm_vma_walk->last = addr;
368         page_size = hmm_range_page_size(range);
369         i = (addr - range->start) >> range->page_shift;
370
371         for (; addr < end; addr += page_size, i++) {
372                 pfns[i] = range->values[HMM_PFN_NONE];
373                 if (fault || write_fault) {
374                         int ret;
375
376                         ret = hmm_vma_do_fault(walk, addr, write_fault,
377                                                &pfns[i]);
378                         if (ret != -EBUSY)
379                                 return ret;
380                 }
381         }
382
383         return (fault || write_fault) ? -EBUSY : 0;
384 }
385
386 static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
387                                       uint64_t pfns, uint64_t cpu_flags,
388                                       bool *fault, bool *write_fault)
389 {
390         struct hmm_range *range = hmm_vma_walk->range;
391
392         if (!hmm_vma_walk->fault)
393                 return;
394
395         /*
396          * So we not only consider the individual per page request we also
397          * consider the default flags requested for the range. The API can
398          * be use in 2 fashions. The first one where the HMM user coalesce
399          * multiple page fault into one request and set flags per pfns for
400          * of those faults. The second one where the HMM user want to pre-
401          * fault a range with specific flags. For the latter one it is a
402          * waste to have the user pre-fill the pfn arrays with a default
403          * flags value.
404          */
405         pfns = (pfns & range->pfn_flags_mask) | range->default_flags;
406
407         /* We aren't ask to do anything ... */
408         if (!(pfns & range->flags[HMM_PFN_VALID]))
409                 return;
410         /* If this is device memory than only fault if explicitly requested */
411         if ((cpu_flags & range->flags[HMM_PFN_DEVICE_PRIVATE])) {
412                 /* Do we fault on device memory ? */
413                 if (pfns & range->flags[HMM_PFN_DEVICE_PRIVATE]) {
414                         *write_fault = pfns & range->flags[HMM_PFN_WRITE];
415                         *fault = true;
416                 }
417                 return;
418         }
419
420         /* If CPU page table is not valid then we need to fault */
421         *fault = !(cpu_flags & range->flags[HMM_PFN_VALID]);
422         /* Need to write fault ? */
423         if ((pfns & range->flags[HMM_PFN_WRITE]) &&
424             !(cpu_flags & range->flags[HMM_PFN_WRITE])) {
425                 *write_fault = true;
426                 *fault = true;
427         }
428 }
429
430 static void hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
431                                  const uint64_t *pfns, unsigned long npages,
432                                  uint64_t cpu_flags, bool *fault,
433                                  bool *write_fault)
434 {
435         unsigned long i;
436
437         if (!hmm_vma_walk->fault) {
438                 *fault = *write_fault = false;
439                 return;
440         }
441
442         *fault = *write_fault = false;
443         for (i = 0; i < npages; ++i) {
444                 hmm_pte_need_fault(hmm_vma_walk, pfns[i], cpu_flags,
445                                    fault, write_fault);
446                 if ((*write_fault))
447                         return;
448         }
449 }
450
451 static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
452                              struct mm_walk *walk)
453 {
454         struct hmm_vma_walk *hmm_vma_walk = walk->private;
455         struct hmm_range *range = hmm_vma_walk->range;
456         bool fault, write_fault;
457         unsigned long i, npages;
458         uint64_t *pfns;
459
460         i = (addr - range->start) >> PAGE_SHIFT;
461         npages = (end - addr) >> PAGE_SHIFT;
462         pfns = &range->pfns[i];
463         hmm_range_need_fault(hmm_vma_walk, pfns, npages,
464                              0, &fault, &write_fault);
465         return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
466 }
467
468 static inline uint64_t pmd_to_hmm_pfn_flags(struct hmm_range *range, pmd_t pmd)
469 {
470         if (pmd_protnone(pmd))
471                 return 0;
472         return pmd_write(pmd) ? range->flags[HMM_PFN_VALID] |
473                                 range->flags[HMM_PFN_WRITE] :
474                                 range->flags[HMM_PFN_VALID];
475 }
476
477 static inline uint64_t pud_to_hmm_pfn_flags(struct hmm_range *range, pud_t pud)
478 {
479         if (!pud_present(pud))
480                 return 0;
481         return pud_write(pud) ? range->flags[HMM_PFN_VALID] |
482                                 range->flags[HMM_PFN_WRITE] :
483                                 range->flags[HMM_PFN_VALID];
484 }
485
486 static int hmm_vma_handle_pmd(struct mm_walk *walk,
487                               unsigned long addr,
488                               unsigned long end,
489                               uint64_t *pfns,
490                               pmd_t pmd)
491 {
492 #ifdef CONFIG_TRANSPARENT_HUGEPAGE
493         struct hmm_vma_walk *hmm_vma_walk = walk->private;
494         struct hmm_range *range = hmm_vma_walk->range;
495         unsigned long pfn, npages, i;
496         bool fault, write_fault;
497         uint64_t cpu_flags;
498
499         npages = (end - addr) >> PAGE_SHIFT;
500         cpu_flags = pmd_to_hmm_pfn_flags(range, pmd);
501         hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags,
502                              &fault, &write_fault);
503
504         if (pmd_protnone(pmd) || fault || write_fault)
505                 return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
506
507         pfn = pmd_pfn(pmd) + pte_index(addr);
508         for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++) {
509                 if (pmd_devmap(pmd)) {
510                         hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
511                                               hmm_vma_walk->pgmap);
512                         if (unlikely(!hmm_vma_walk->pgmap))
513                                 return -EBUSY;
514                 }
515                 pfns[i] = hmm_device_entry_from_pfn(range, pfn) | cpu_flags;
516         }
517         if (hmm_vma_walk->pgmap) {
518                 put_dev_pagemap(hmm_vma_walk->pgmap);
519                 hmm_vma_walk->pgmap = NULL;
520         }
521         hmm_vma_walk->last = end;
522         return 0;
523 #else
524         /* If THP is not enabled then we should never reach that code ! */
525         return -EINVAL;
526 #endif
527 }
528
529 static inline uint64_t pte_to_hmm_pfn_flags(struct hmm_range *range, pte_t pte)
530 {
531         if (pte_none(pte) || !pte_present(pte) || pte_protnone(pte))
532                 return 0;
533         return pte_write(pte) ? range->flags[HMM_PFN_VALID] |
534                                 range->flags[HMM_PFN_WRITE] :
535                                 range->flags[HMM_PFN_VALID];
536 }
537
538 static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
539                               unsigned long end, pmd_t *pmdp, pte_t *ptep,
540                               uint64_t *pfn)
541 {
542         struct hmm_vma_walk *hmm_vma_walk = walk->private;
543         struct hmm_range *range = hmm_vma_walk->range;
544         struct vm_area_struct *vma = walk->vma;
545         bool fault, write_fault;
546         uint64_t cpu_flags;
547         pte_t pte = *ptep;
548         uint64_t orig_pfn = *pfn;
549
550         *pfn = range->values[HMM_PFN_NONE];
551         fault = write_fault = false;
552
553         if (pte_none(pte)) {
554                 hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0,
555                                    &fault, &write_fault);
556                 if (fault || write_fault)
557                         goto fault;
558                 return 0;
559         }
560
561         if (!pte_present(pte)) {
562                 swp_entry_t entry = pte_to_swp_entry(pte);
563
564                 if (!non_swap_entry(entry)) {
565                         if (fault || write_fault)
566                                 goto fault;
567                         return 0;
568                 }
569
570                 /*
571                  * This is a special swap entry, ignore migration, use
572                  * device and report anything else as error.
573                  */
574                 if (is_device_private_entry(entry)) {
575                         cpu_flags = range->flags[HMM_PFN_VALID] |
576                                 range->flags[HMM_PFN_DEVICE_PRIVATE];
577                         cpu_flags |= is_write_device_private_entry(entry) ?
578                                 range->flags[HMM_PFN_WRITE] : 0;
579                         hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
580                                            &fault, &write_fault);
581                         if (fault || write_fault)
582                                 goto fault;
583                         *pfn = hmm_device_entry_from_pfn(range,
584                                             swp_offset(entry));
585                         *pfn |= cpu_flags;
586                         return 0;
587                 }
588
589                 if (is_migration_entry(entry)) {
590                         if (fault || write_fault) {
591                                 pte_unmap(ptep);
592                                 hmm_vma_walk->last = addr;
593                                 migration_entry_wait(vma->vm_mm,
594                                                      pmdp, addr);
595                                 return -EBUSY;
596                         }
597                         return 0;
598                 }
599
600                 /* Report error for everything else */
601                 *pfn = range->values[HMM_PFN_ERROR];
602                 return -EFAULT;
603         } else {
604                 cpu_flags = pte_to_hmm_pfn_flags(range, pte);
605                 hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
606                                    &fault, &write_fault);
607         }
608
609         if (fault || write_fault)
610                 goto fault;
611
612         if (pte_devmap(pte)) {
613                 hmm_vma_walk->pgmap = get_dev_pagemap(pte_pfn(pte),
614                                               hmm_vma_walk->pgmap);
615                 if (unlikely(!hmm_vma_walk->pgmap))
616                         return -EBUSY;
617         } else if (IS_ENABLED(CONFIG_ARCH_HAS_PTE_SPECIAL) && pte_special(pte)) {
618                 *pfn = range->values[HMM_PFN_SPECIAL];
619                 return -EFAULT;
620         }
621
622         *pfn = hmm_device_entry_from_pfn(range, pte_pfn(pte)) | cpu_flags;
623         return 0;
624
625 fault:
626         if (hmm_vma_walk->pgmap) {
627                 put_dev_pagemap(hmm_vma_walk->pgmap);
628                 hmm_vma_walk->pgmap = NULL;
629         }
630         pte_unmap(ptep);
631         /* Fault any virtual address we were asked to fault */
632         return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
633 }
634
635 static int hmm_vma_walk_pmd(pmd_t *pmdp,
636                             unsigned long start,
637                             unsigned long end,
638                             struct mm_walk *walk)
639 {
640         struct hmm_vma_walk *hmm_vma_walk = walk->private;
641         struct hmm_range *range = hmm_vma_walk->range;
642         struct vm_area_struct *vma = walk->vma;
643         uint64_t *pfns = range->pfns;
644         unsigned long addr = start, i;
645         pte_t *ptep;
646         pmd_t pmd;
647
648
649 again:
650         pmd = READ_ONCE(*pmdp);
651         if (pmd_none(pmd))
652                 return hmm_vma_walk_hole(start, end, walk);
653
654         if (pmd_huge(pmd) && (range->vma->vm_flags & VM_HUGETLB))
655                 return hmm_pfns_bad(start, end, walk);
656
657         if (thp_migration_supported() && is_pmd_migration_entry(pmd)) {
658                 bool fault, write_fault;
659                 unsigned long npages;
660                 uint64_t *pfns;
661
662                 i = (addr - range->start) >> PAGE_SHIFT;
663                 npages = (end - addr) >> PAGE_SHIFT;
664                 pfns = &range->pfns[i];
665
666                 hmm_range_need_fault(hmm_vma_walk, pfns, npages,
667                                      0, &fault, &write_fault);
668                 if (fault || write_fault) {
669                         hmm_vma_walk->last = addr;
670                         pmd_migration_entry_wait(vma->vm_mm, pmdp);
671                         return -EBUSY;
672                 }
673                 return 0;
674         } else if (!pmd_present(pmd))
675                 return hmm_pfns_bad(start, end, walk);
676
677         if (pmd_devmap(pmd) || pmd_trans_huge(pmd)) {
678                 /*
679                  * No need to take pmd_lock here, even if some other threads
680                  * is splitting the huge pmd we will get that event through
681                  * mmu_notifier callback.
682                  *
683                  * So just read pmd value and check again its a transparent
684                  * huge or device mapping one and compute corresponding pfn
685                  * values.
686                  */
687                 pmd = pmd_read_atomic(pmdp);
688                 barrier();
689                 if (!pmd_devmap(pmd) && !pmd_trans_huge(pmd))
690                         goto again;
691
692                 i = (addr - range->start) >> PAGE_SHIFT;
693                 return hmm_vma_handle_pmd(walk, addr, end, &pfns[i], pmd);
694         }
695
696         /*
697          * We have handled all the valid case above ie either none, migration,
698          * huge or transparent huge. At this point either it is a valid pmd
699          * entry pointing to pte directory or it is a bad pmd that will not
700          * recover.
701          */
702         if (pmd_bad(pmd))
703                 return hmm_pfns_bad(start, end, walk);
704
705         ptep = pte_offset_map(pmdp, addr);
706         i = (addr - range->start) >> PAGE_SHIFT;
707         for (; addr < end; addr += PAGE_SIZE, ptep++, i++) {
708                 int r;
709
710                 r = hmm_vma_handle_pte(walk, addr, end, pmdp, ptep, &pfns[i]);
711                 if (r) {
712                         /* hmm_vma_handle_pte() did unmap pte directory */
713                         hmm_vma_walk->last = addr;
714                         return r;
715                 }
716         }
717         if (hmm_vma_walk->pgmap) {
718                 /*
719                  * We do put_dev_pagemap() here and not in hmm_vma_handle_pte()
720                  * so that we can leverage get_dev_pagemap() optimization which
721                  * will not re-take a reference on a pgmap if we already have
722                  * one.
723                  */
724                 put_dev_pagemap(hmm_vma_walk->pgmap);
725                 hmm_vma_walk->pgmap = NULL;
726         }
727         pte_unmap(ptep - 1);
728
729         hmm_vma_walk->last = addr;
730         return 0;
731 }
732
733 static int hmm_vma_walk_pud(pud_t *pudp,
734                             unsigned long start,
735                             unsigned long end,
736                             struct mm_walk *walk)
737 {
738         struct hmm_vma_walk *hmm_vma_walk = walk->private;
739         struct hmm_range *range = hmm_vma_walk->range;
740         unsigned long addr = start, next;
741         pmd_t *pmdp;
742         pud_t pud;
743         int ret;
744
745 again:
746         pud = READ_ONCE(*pudp);
747         if (pud_none(pud))
748                 return hmm_vma_walk_hole(start, end, walk);
749
750         if (pud_huge(pud) && pud_devmap(pud)) {
751                 unsigned long i, npages, pfn;
752                 uint64_t *pfns, cpu_flags;
753                 bool fault, write_fault;
754
755                 if (!pud_present(pud))
756                         return hmm_vma_walk_hole(start, end, walk);
757
758                 i = (addr - range->start) >> PAGE_SHIFT;
759                 npages = (end - addr) >> PAGE_SHIFT;
760                 pfns = &range->pfns[i];
761
762                 cpu_flags = pud_to_hmm_pfn_flags(range, pud);
763                 hmm_range_need_fault(hmm_vma_walk, pfns, npages,
764                                      cpu_flags, &fault, &write_fault);
765                 if (fault || write_fault)
766                         return hmm_vma_walk_hole_(addr, end, fault,
767                                                 write_fault, walk);
768
769                 pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
770                 for (i = 0; i < npages; ++i, ++pfn) {
771                         hmm_vma_walk->pgmap = get_dev_pagemap(pfn,
772                                               hmm_vma_walk->pgmap);
773                         if (unlikely(!hmm_vma_walk->pgmap))
774                                 return -EBUSY;
775                         pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
776                                   cpu_flags;
777                 }
778                 if (hmm_vma_walk->pgmap) {
779                         put_dev_pagemap(hmm_vma_walk->pgmap);
780                         hmm_vma_walk->pgmap = NULL;
781                 }
782                 hmm_vma_walk->last = end;
783                 return 0;
784         }
785
786         split_huge_pud(walk->vma, pudp, addr);
787         if (pud_none(*pudp))
788                 goto again;
789
790         pmdp = pmd_offset(pudp, addr);
791         do {
792                 next = pmd_addr_end(addr, end);
793                 ret = hmm_vma_walk_pmd(pmdp, addr, next, walk);
794                 if (ret)
795                         return ret;
796         } while (pmdp++, addr = next, addr != end);
797
798         return 0;
799 }
800
801 static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
802                                       unsigned long start, unsigned long end,
803                                       struct mm_walk *walk)
804 {
805 #ifdef CONFIG_HUGETLB_PAGE
806         unsigned long addr = start, i, pfn, mask, size, pfn_inc;
807         struct hmm_vma_walk *hmm_vma_walk = walk->private;
808         struct hmm_range *range = hmm_vma_walk->range;
809         struct vm_area_struct *vma = walk->vma;
810         struct hstate *h = hstate_vma(vma);
811         uint64_t orig_pfn, cpu_flags;
812         bool fault, write_fault;
813         spinlock_t *ptl;
814         pte_t entry;
815         int ret = 0;
816
817         size = 1UL << huge_page_shift(h);
818         mask = size - 1;
819         if (range->page_shift != PAGE_SHIFT) {
820                 /* Make sure we are looking at full page. */
821                 if (start & mask)
822                         return -EINVAL;
823                 if (end < (start + size))
824                         return -EINVAL;
825                 pfn_inc = size >> PAGE_SHIFT;
826         } else {
827                 pfn_inc = 1;
828                 size = PAGE_SIZE;
829         }
830
831
832         ptl = huge_pte_lock(hstate_vma(walk->vma), walk->mm, pte);
833         entry = huge_ptep_get(pte);
834
835         i = (start - range->start) >> range->page_shift;
836         orig_pfn = range->pfns[i];
837         range->pfns[i] = range->values[HMM_PFN_NONE];
838         cpu_flags = pte_to_hmm_pfn_flags(range, entry);
839         fault = write_fault = false;
840         hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
841                            &fault, &write_fault);
842         if (fault || write_fault) {
843                 ret = -ENOENT;
844                 goto unlock;
845         }
846
847         pfn = pte_pfn(entry) + ((start & mask) >> range->page_shift);
848         for (; addr < end; addr += size, i++, pfn += pfn_inc)
849                 range->pfns[i] = hmm_device_entry_from_pfn(range, pfn) |
850                                  cpu_flags;
851         hmm_vma_walk->last = end;
852
853 unlock:
854         spin_unlock(ptl);
855
856         if (ret == -ENOENT)
857                 return hmm_vma_walk_hole_(addr, end, fault, write_fault, walk);
858
859         return ret;
860 #else /* CONFIG_HUGETLB_PAGE */
861         return -EINVAL;
862 #endif
863 }
864
865 static void hmm_pfns_clear(struct hmm_range *range,
866                            uint64_t *pfns,
867                            unsigned long addr,
868                            unsigned long end)
869 {
870         for (; addr < end; addr += PAGE_SIZE, pfns++)
871                 *pfns = range->values[HMM_PFN_NONE];
872 }
873
874 /*
875  * hmm_range_register() - start tracking change to CPU page table over a range
876  * @range: range
877  * @mm: the mm struct for the range of virtual address
878  * @start: start virtual address (inclusive)
879  * @end: end virtual address (exclusive)
880  * @page_shift: expect page shift for the range
881  * Returns 0 on success, -EFAULT if the address space is no longer valid
882  *
883  * Track updates to the CPU page table see include/linux/hmm.h
884  */
885 int hmm_range_register(struct hmm_range *range,
886                        struct hmm_mirror *mirror,
887                        unsigned long start,
888                        unsigned long end,
889                        unsigned page_shift)
890 {
891         unsigned long mask = ((1UL << page_shift) - 1UL);
892         struct hmm *hmm = mirror->hmm;
893
894         range->valid = false;
895         range->hmm = NULL;
896
897         if ((start & mask) || (end & mask))
898                 return -EINVAL;
899         if (start >= end)
900                 return -EINVAL;
901
902         range->page_shift = page_shift;
903         range->start = start;
904         range->end = end;
905
906         /* Check if hmm_mm_destroy() was call. */
907         if (hmm->mm == NULL || hmm->dead)
908                 return -EFAULT;
909
910         /* Initialize range to track CPU page table updates. */
911         mutex_lock(&hmm->lock);
912
913         range->hmm = hmm;
914         kref_get(&hmm->kref);
915         list_add(&range->list, &hmm->ranges);
916
917         /*
918          * If there are any concurrent notifiers we have to wait for them for
919          * the range to be valid (see hmm_range_wait_until_valid()).
920          */
921         if (!hmm->notifiers)
922                 range->valid = true;
923         mutex_unlock(&hmm->lock);
924
925         return 0;
926 }
927 EXPORT_SYMBOL(hmm_range_register);
928
929 /*
930  * hmm_range_unregister() - stop tracking change to CPU page table over a range
931  * @range: range
932  *
933  * Range struct is used to track updates to the CPU page table after a call to
934  * hmm_range_register(). See include/linux/hmm.h for how to use it.
935  */
936 void hmm_range_unregister(struct hmm_range *range)
937 {
938         struct hmm *hmm = range->hmm;
939
940         /* Sanity check this really should not happen. */
941         if (hmm == NULL || range->end <= range->start)
942                 return;
943
944         mutex_lock(&hmm->lock);
945         list_del(&range->list);
946         mutex_unlock(&hmm->lock);
947
948         /* Drop reference taken by hmm_range_register() */
949         range->valid = false;
950         hmm_put(hmm);
951         range->hmm = NULL;
952 }
953 EXPORT_SYMBOL(hmm_range_unregister);
954
955 /*
956  * hmm_range_snapshot() - snapshot CPU page table for a range
957  * @range: range
958  * Return: -EINVAL if invalid argument, -ENOMEM out of memory, -EPERM invalid
959  *          permission (for instance asking for write and range is read only),
960  *          -EAGAIN if you need to retry, -EFAULT invalid (ie either no valid
961  *          vma or it is illegal to access that range), number of valid pages
962  *          in range->pfns[] (from range start address).
963  *
964  * This snapshots the CPU page table for a range of virtual addresses. Snapshot
965  * validity is tracked by range struct. See in include/linux/hmm.h for example
966  * on how to use.
967  */
968 long hmm_range_snapshot(struct hmm_range *range)
969 {
970         const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
971         unsigned long start = range->start, end;
972         struct hmm_vma_walk hmm_vma_walk;
973         struct hmm *hmm = range->hmm;
974         struct vm_area_struct *vma;
975         struct mm_walk mm_walk;
976
977         /* Check if hmm_mm_destroy() was call. */
978         if (hmm->mm == NULL || hmm->dead)
979                 return -EFAULT;
980
981         do {
982                 /* If range is no longer valid force retry. */
983                 if (!range->valid)
984                         return -EAGAIN;
985
986                 vma = find_vma(hmm->mm, start);
987                 if (vma == NULL || (vma->vm_flags & device_vma))
988                         return -EFAULT;
989
990                 if (is_vm_hugetlb_page(vma)) {
991                         if (huge_page_shift(hstate_vma(vma)) !=
992                                     range->page_shift &&
993                             range->page_shift != PAGE_SHIFT)
994                                 return -EINVAL;
995                 } else {
996                         if (range->page_shift != PAGE_SHIFT)
997                                 return -EINVAL;
998                 }
999
1000                 if (!(vma->vm_flags & VM_READ)) {
1001                         /*
1002                          * If vma do not allow read access, then assume that it
1003                          * does not allow write access, either. HMM does not
1004                          * support architecture that allow write without read.
1005                          */
1006                         hmm_pfns_clear(range, range->pfns,
1007                                 range->start, range->end);
1008                         return -EPERM;
1009                 }
1010
1011                 range->vma = vma;
1012                 hmm_vma_walk.pgmap = NULL;
1013                 hmm_vma_walk.last = start;
1014                 hmm_vma_walk.fault = false;
1015                 hmm_vma_walk.range = range;
1016                 mm_walk.private = &hmm_vma_walk;
1017                 end = min(range->end, vma->vm_end);
1018
1019                 mm_walk.vma = vma;
1020                 mm_walk.mm = vma->vm_mm;
1021                 mm_walk.pte_entry = NULL;
1022                 mm_walk.test_walk = NULL;
1023                 mm_walk.hugetlb_entry = NULL;
1024                 mm_walk.pud_entry = hmm_vma_walk_pud;
1025                 mm_walk.pmd_entry = hmm_vma_walk_pmd;
1026                 mm_walk.pte_hole = hmm_vma_walk_hole;
1027                 mm_walk.hugetlb_entry = hmm_vma_walk_hugetlb_entry;
1028
1029                 walk_page_range(start, end, &mm_walk);
1030                 start = end;
1031         } while (start < range->end);
1032
1033         return (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
1034 }
1035 EXPORT_SYMBOL(hmm_range_snapshot);
1036
1037 /*
1038  * hmm_range_fault() - try to fault some address in a virtual address range
1039  * @range: range being faulted
1040  * @block: allow blocking on fault (if true it sleeps and do not drop mmap_sem)
1041  * Return: number of valid pages in range->pfns[] (from range start
1042  *          address). This may be zero. If the return value is negative,
1043  *          then one of the following values may be returned:
1044  *
1045  *           -EINVAL  invalid arguments or mm or virtual address are in an
1046  *                    invalid vma (for instance device file vma).
1047  *           -ENOMEM: Out of memory.
1048  *           -EPERM:  Invalid permission (for instance asking for write and
1049  *                    range is read only).
1050  *           -EAGAIN: If you need to retry and mmap_sem was drop. This can only
1051  *                    happens if block argument is false.
1052  *           -EBUSY:  If the the range is being invalidated and you should wait
1053  *                    for invalidation to finish.
1054  *           -EFAULT: Invalid (ie either no valid vma or it is illegal to access
1055  *                    that range), number of valid pages in range->pfns[] (from
1056  *                    range start address).
1057  *
1058  * This is similar to a regular CPU page fault except that it will not trigger
1059  * any memory migration if the memory being faulted is not accessible by CPUs
1060  * and caller does not ask for migration.
1061  *
1062  * On error, for one virtual address in the range, the function will mark the
1063  * corresponding HMM pfn entry with an error flag.
1064  */
1065 long hmm_range_fault(struct hmm_range *range, bool block)
1066 {
1067         const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
1068         unsigned long start = range->start, end;
1069         struct hmm_vma_walk hmm_vma_walk;
1070         struct hmm *hmm = range->hmm;
1071         struct vm_area_struct *vma;
1072         struct mm_walk mm_walk;
1073         int ret;
1074
1075         /* Check if hmm_mm_destroy() was call. */
1076         if (hmm->mm == NULL || hmm->dead)
1077                 return -EFAULT;
1078
1079         do {
1080                 /* If range is no longer valid force retry. */
1081                 if (!range->valid) {
1082                         up_read(&hmm->mm->mmap_sem);
1083                         return -EAGAIN;
1084                 }
1085
1086                 vma = find_vma(hmm->mm, start);
1087                 if (vma == NULL || (vma->vm_flags & device_vma))
1088                         return -EFAULT;
1089
1090                 if (is_vm_hugetlb_page(vma)) {
1091                         if (huge_page_shift(hstate_vma(vma)) !=
1092                             range->page_shift &&
1093                             range->page_shift != PAGE_SHIFT)
1094                                 return -EINVAL;
1095                 } else {
1096                         if (range->page_shift != PAGE_SHIFT)
1097                                 return -EINVAL;
1098                 }
1099
1100                 if (!(vma->vm_flags & VM_READ)) {
1101                         /*
1102                          * If vma do not allow read access, then assume that it
1103                          * does not allow write access, either. HMM does not
1104                          * support architecture that allow write without read.
1105                          */
1106                         hmm_pfns_clear(range, range->pfns,
1107                                 range->start, range->end);
1108                         return -EPERM;
1109                 }
1110
1111                 range->vma = vma;
1112                 hmm_vma_walk.pgmap = NULL;
1113                 hmm_vma_walk.last = start;
1114                 hmm_vma_walk.fault = true;
1115                 hmm_vma_walk.block = block;
1116                 hmm_vma_walk.range = range;
1117                 mm_walk.private = &hmm_vma_walk;
1118                 end = min(range->end, vma->vm_end);
1119
1120                 mm_walk.vma = vma;
1121                 mm_walk.mm = vma->vm_mm;
1122                 mm_walk.pte_entry = NULL;
1123                 mm_walk.test_walk = NULL;
1124                 mm_walk.hugetlb_entry = NULL;
1125                 mm_walk.pud_entry = hmm_vma_walk_pud;
1126                 mm_walk.pmd_entry = hmm_vma_walk_pmd;
1127                 mm_walk.pte_hole = hmm_vma_walk_hole;
1128                 mm_walk.hugetlb_entry = hmm_vma_walk_hugetlb_entry;
1129
1130                 do {
1131                         ret = walk_page_range(start, end, &mm_walk);
1132                         start = hmm_vma_walk.last;
1133
1134                         /* Keep trying while the range is valid. */
1135                 } while (ret == -EBUSY && range->valid);
1136
1137                 if (ret) {
1138                         unsigned long i;
1139
1140                         i = (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
1141                         hmm_pfns_clear(range, &range->pfns[i],
1142                                 hmm_vma_walk.last, range->end);
1143                         return ret;
1144                 }
1145                 start = end;
1146
1147         } while (start < range->end);
1148
1149         return (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
1150 }
1151 EXPORT_SYMBOL(hmm_range_fault);
1152
1153 /**
1154  * hmm_range_dma_map() - hmm_range_fault() and dma map page all in one.
1155  * @range: range being faulted
1156  * @device: device against to dma map page to
1157  * @daddrs: dma address of mapped pages
1158  * @block: allow blocking on fault (if true it sleeps and do not drop mmap_sem)
1159  * Return: number of pages mapped on success, -EAGAIN if mmap_sem have been
1160  *          drop and you need to try again, some other error value otherwise
1161  *
1162  * Note same usage pattern as hmm_range_fault().
1163  */
1164 long hmm_range_dma_map(struct hmm_range *range,
1165                        struct device *device,
1166                        dma_addr_t *daddrs,
1167                        bool block)
1168 {
1169         unsigned long i, npages, mapped;
1170         long ret;
1171
1172         ret = hmm_range_fault(range, block);
1173         if (ret <= 0)
1174                 return ret ? ret : -EBUSY;
1175
1176         npages = (range->end - range->start) >> PAGE_SHIFT;
1177         for (i = 0, mapped = 0; i < npages; ++i) {
1178                 enum dma_data_direction dir = DMA_TO_DEVICE;
1179                 struct page *page;
1180
1181                 /*
1182                  * FIXME need to update DMA API to provide invalid DMA address
1183                  * value instead of a function to test dma address value. This
1184                  * would remove lot of dumb code duplicated accross many arch.
1185                  *
1186                  * For now setting it to 0 here is good enough as the pfns[]
1187                  * value is what is use to check what is valid and what isn't.
1188                  */
1189                 daddrs[i] = 0;
1190
1191                 page = hmm_device_entry_to_page(range, range->pfns[i]);
1192                 if (page == NULL)
1193                         continue;
1194
1195                 /* Check if range is being invalidated */
1196                 if (!range->valid) {
1197                         ret = -EBUSY;
1198                         goto unmap;
1199                 }
1200
1201                 /* If it is read and write than map bi-directional. */
1202                 if (range->pfns[i] & range->flags[HMM_PFN_WRITE])
1203                         dir = DMA_BIDIRECTIONAL;
1204
1205                 daddrs[i] = dma_map_page(device, page, 0, PAGE_SIZE, dir);
1206                 if (dma_mapping_error(device, daddrs[i])) {
1207                         ret = -EFAULT;
1208                         goto unmap;
1209                 }
1210
1211                 mapped++;
1212         }
1213
1214         return mapped;
1215
1216 unmap:
1217         for (npages = i, i = 0; (i < npages) && mapped; ++i) {
1218                 enum dma_data_direction dir = DMA_TO_DEVICE;
1219                 struct page *page;
1220
1221                 page = hmm_device_entry_to_page(range, range->pfns[i]);
1222                 if (page == NULL)
1223                         continue;
1224
1225                 if (dma_mapping_error(device, daddrs[i]))
1226                         continue;
1227
1228                 /* If it is read and write than map bi-directional. */
1229                 if (range->pfns[i] & range->flags[HMM_PFN_WRITE])
1230                         dir = DMA_BIDIRECTIONAL;
1231
1232                 dma_unmap_page(device, daddrs[i], PAGE_SIZE, dir);
1233                 mapped--;
1234         }
1235
1236         return ret;
1237 }
1238 EXPORT_SYMBOL(hmm_range_dma_map);
1239
1240 /**
1241  * hmm_range_dma_unmap() - unmap range of that was map with hmm_range_dma_map()
1242  * @range: range being unmapped
1243  * @vma: the vma against which the range (optional)
1244  * @device: device against which dma map was done
1245  * @daddrs: dma address of mapped pages
1246  * @dirty: dirty page if it had the write flag set
1247  * Return: number of page unmapped on success, -EINVAL otherwise
1248  *
1249  * Note that caller MUST abide by mmu notifier or use HMM mirror and abide
1250  * to the sync_cpu_device_pagetables() callback so that it is safe here to
1251  * call set_page_dirty(). Caller must also take appropriate locks to avoid
1252  * concurrent mmu notifier or sync_cpu_device_pagetables() to make progress.
1253  */
1254 long hmm_range_dma_unmap(struct hmm_range *range,
1255                          struct vm_area_struct *vma,
1256                          struct device *device,
1257                          dma_addr_t *daddrs,
1258                          bool dirty)
1259 {
1260         unsigned long i, npages;
1261         long cpages = 0;
1262
1263         /* Sanity check. */
1264         if (range->end <= range->start)
1265                 return -EINVAL;
1266         if (!daddrs)
1267                 return -EINVAL;
1268         if (!range->pfns)
1269                 return -EINVAL;
1270
1271         npages = (range->end - range->start) >> PAGE_SHIFT;
1272         for (i = 0; i < npages; ++i) {
1273                 enum dma_data_direction dir = DMA_TO_DEVICE;
1274                 struct page *page;
1275
1276                 page = hmm_device_entry_to_page(range, range->pfns[i]);
1277                 if (page == NULL)
1278                         continue;
1279
1280                 /* If it is read and write than map bi-directional. */
1281                 if (range->pfns[i] & range->flags[HMM_PFN_WRITE]) {
1282                         dir = DMA_BIDIRECTIONAL;
1283
1284                         /*
1285                          * See comments in function description on why it is
1286                          * safe here to call set_page_dirty()
1287                          */
1288                         if (dirty)
1289                                 set_page_dirty(page);
1290                 }
1291
1292                 /* Unmap and clear pfns/dma address */
1293                 dma_unmap_page(device, daddrs[i], PAGE_SIZE, dir);
1294                 range->pfns[i] = range->values[HMM_PFN_NONE];
1295                 /* FIXME see comments in hmm_vma_dma_map() */
1296                 daddrs[i] = 0;
1297                 cpages++;
1298         }
1299
1300         return cpages;
1301 }
1302 EXPORT_SYMBOL(hmm_range_dma_unmap);
1303 #endif /* IS_ENABLED(CONFIG_HMM_MIRROR) */
1304
1305
1306 #if IS_ENABLED(CONFIG_DEVICE_PRIVATE) ||  IS_ENABLED(CONFIG_DEVICE_PUBLIC)
1307 struct page *hmm_vma_alloc_locked_page(struct vm_area_struct *vma,
1308                                        unsigned long addr)
1309 {
1310         struct page *page;
1311
1312         page = alloc_page_vma(GFP_HIGHUSER, vma, addr);
1313         if (!page)
1314                 return NULL;
1315         lock_page(page);
1316         return page;
1317 }
1318 EXPORT_SYMBOL(hmm_vma_alloc_locked_page);
1319
1320
1321 static void hmm_devmem_ref_release(struct percpu_ref *ref)
1322 {
1323         struct hmm_devmem *devmem;
1324
1325         devmem = container_of(ref, struct hmm_devmem, ref);
1326         complete(&devmem->completion);
1327 }
1328
1329 static void hmm_devmem_ref_exit(void *data)
1330 {
1331         struct percpu_ref *ref = data;
1332         struct hmm_devmem *devmem;
1333
1334         devmem = container_of(ref, struct hmm_devmem, ref);
1335         wait_for_completion(&devmem->completion);
1336         percpu_ref_exit(ref);
1337 }
1338
/* Installed as the pagemap ->kill callback: kill the percpu reference. */
static void hmm_devmem_ref_kill(struct percpu_ref *ref)
{
	percpu_ref_kill(ref);
}
1343
1344 static vm_fault_t hmm_devmem_fault(struct vm_area_struct *vma,
1345                             unsigned long addr,
1346                             const struct page *page,
1347                             unsigned int flags,
1348                             pmd_t *pmdp)
1349 {
1350         struct hmm_devmem *devmem = page->pgmap->data;
1351
1352         return devmem->ops->fault(devmem, vma, addr, page, flags, pmdp);
1353 }
1354
1355 static void hmm_devmem_free(struct page *page, void *data)
1356 {
1357         struct hmm_devmem *devmem = data;
1358
1359         page->mapping = NULL;
1360
1361         devmem->ops->free(devmem, page);
1362 }
1363
1364 /*
1365  * hmm_devmem_add() - hotplug ZONE_DEVICE memory for device memory
1366  *
1367  * @ops: memory event device driver callback (see struct hmm_devmem_ops)
1368  * @device: device struct to bind the resource too
1369  * @size: size in bytes of the device memory to add
1370  * Return: pointer to new hmm_devmem struct ERR_PTR otherwise
1371  *
1372  * This function first finds an empty range of physical address big enough to
1373  * contain the new resource, and then hotplugs it as ZONE_DEVICE memory, which
1374  * in turn allocates struct pages. It does not do anything beyond that; all
1375  * events affecting the memory will go through the various callbacks provided
1376  * by hmm_devmem_ops struct.
1377  *
1378  * Device driver should call this function during device initialization and
1379  * is then responsible of memory management. HMM only provides helpers.
1380  */
1381 struct hmm_devmem *hmm_devmem_add(const struct hmm_devmem_ops *ops,
1382                                   struct device *device,
1383                                   unsigned long size)
1384 {
1385         struct hmm_devmem *devmem;
1386         resource_size_t addr;
1387         void *result;
1388         int ret;
1389
1390         dev_pagemap_get_ops();
1391
1392         devmem = devm_kzalloc(device, sizeof(*devmem), GFP_KERNEL);
1393         if (!devmem)
1394                 return ERR_PTR(-ENOMEM);
1395
1396         init_completion(&devmem->completion);
1397         devmem->pfn_first = -1UL;
1398         devmem->pfn_last = -1UL;
1399         devmem->resource = NULL;
1400         devmem->device = device;
1401         devmem->ops = ops;
1402
1403         ret = percpu_ref_init(&devmem->ref, &hmm_devmem_ref_release,
1404                               0, GFP_KERNEL);
1405         if (ret)
1406                 return ERR_PTR(ret);
1407
1408         ret = devm_add_action_or_reset(device, hmm_devmem_ref_exit, &devmem->ref);
1409         if (ret)
1410                 return ERR_PTR(ret);
1411
1412         size = ALIGN(size, PA_SECTION_SIZE);
1413         addr = min((unsigned long)iomem_resource.end,
1414                    (1UL << MAX_PHYSMEM_BITS) - 1);
1415         addr = addr - size + 1UL;
1416
1417         /*
1418          * FIXME add a new helper to quickly walk resource tree and find free
1419          * range
1420          *
1421          * FIXME what about ioport_resource resource ?
1422          */
1423         for (; addr > size && addr >= iomem_resource.start; addr -= size) {
1424                 ret = region_intersects(addr, size, 0, IORES_DESC_NONE);
1425                 if (ret != REGION_DISJOINT)
1426                         continue;
1427
1428                 devmem->resource = devm_request_mem_region(device, addr, size,
1429                                                            dev_name(device));
1430                 if (!devmem->resource)
1431                         return ERR_PTR(-ENOMEM);
1432                 break;
1433         }
1434         if (!devmem->resource)
1435                 return ERR_PTR(-ERANGE);
1436
1437         devmem->resource->desc = IORES_DESC_DEVICE_PRIVATE_MEMORY;
1438         devmem->pfn_first = devmem->resource->start >> PAGE_SHIFT;
1439         devmem->pfn_last = devmem->pfn_first +
1440                            (resource_size(devmem->resource) >> PAGE_SHIFT);
1441         devmem->page_fault = hmm_devmem_fault;
1442
1443         devmem->pagemap.type = MEMORY_DEVICE_PRIVATE;
1444         devmem->pagemap.res = *devmem->resource;
1445         devmem->pagemap.page_free = hmm_devmem_free;
1446         devmem->pagemap.altmap_valid = false;
1447         devmem->pagemap.ref = &devmem->ref;
1448         devmem->pagemap.data = devmem;
1449         devmem->pagemap.kill = hmm_devmem_ref_kill;
1450
1451         result = devm_memremap_pages(devmem->device, &devmem->pagemap);
1452         if (IS_ERR(result))
1453                 return result;
1454         return devmem;
1455 }
1456 EXPORT_SYMBOL_GPL(hmm_devmem_add);
1457
1458 struct hmm_devmem *hmm_devmem_add_resource(const struct hmm_devmem_ops *ops,
1459                                            struct device *device,
1460                                            struct resource *res)
1461 {
1462         struct hmm_devmem *devmem;
1463         void *result;
1464         int ret;
1465
1466         if (res->desc != IORES_DESC_DEVICE_PUBLIC_MEMORY)
1467                 return ERR_PTR(-EINVAL);
1468
1469         dev_pagemap_get_ops();
1470
1471         devmem = devm_kzalloc(device, sizeof(*devmem), GFP_KERNEL);
1472         if (!devmem)
1473                 return ERR_PTR(-ENOMEM);
1474
1475         init_completion(&devmem->completion);
1476         devmem->pfn_first = -1UL;
1477         devmem->pfn_last = -1UL;
1478         devmem->resource = res;
1479         devmem->device = device;
1480         devmem->ops = ops;
1481
1482         ret = percpu_ref_init(&devmem->ref, &hmm_devmem_ref_release,
1483                               0, GFP_KERNEL);
1484         if (ret)
1485                 return ERR_PTR(ret);
1486
1487         ret = devm_add_action_or_reset(device, hmm_devmem_ref_exit,
1488                         &devmem->ref);
1489         if (ret)
1490                 return ERR_PTR(ret);
1491
1492         devmem->pfn_first = devmem->resource->start >> PAGE_SHIFT;
1493         devmem->pfn_last = devmem->pfn_first +
1494                            (resource_size(devmem->resource) >> PAGE_SHIFT);
1495         devmem->page_fault = hmm_devmem_fault;
1496
1497         devmem->pagemap.type = MEMORY_DEVICE_PUBLIC;
1498         devmem->pagemap.res = *devmem->resource;
1499         devmem->pagemap.page_free = hmm_devmem_free;
1500         devmem->pagemap.altmap_valid = false;
1501         devmem->pagemap.ref = &devmem->ref;
1502         devmem->pagemap.data = devmem;
1503         devmem->pagemap.kill = hmm_devmem_ref_kill;
1504
1505         result = devm_memremap_pages(devmem->device, &devmem->pagemap);
1506         if (IS_ERR(result))
1507                 return result;
1508         return devmem;
1509 }
1510 EXPORT_SYMBOL_GPL(hmm_devmem_add_resource);
1511
1512 /*
1513  * A device driver that wants to handle multiple devices memory through a
1514  * single fake device can use hmm_device to do so. This is purely a helper
1515  * and it is not needed to make use of any HMM functionality.
1516  */
1517 #define HMM_DEVICE_MAX 256
1518
1519 static DECLARE_BITMAP(hmm_device_mask, HMM_DEVICE_MAX);
1520 static DEFINE_SPINLOCK(hmm_device_lock);
1521 static struct class *hmm_device_class;
1522 static dev_t hmm_device_devt;
1523
1524 static void hmm_device_release(struct device *device)
1525 {
1526         struct hmm_device *hmm_device;
1527
1528         hmm_device = container_of(device, struct hmm_device, device);
1529         spin_lock(&hmm_device_lock);
1530         clear_bit(hmm_device->minor, hmm_device_mask);
1531         spin_unlock(&hmm_device_lock);
1532
1533         kfree(hmm_device);
1534 }
1535
1536 struct hmm_device *hmm_device_new(void *drvdata)
1537 {
1538         struct hmm_device *hmm_device;
1539
1540         hmm_device = kzalloc(sizeof(*hmm_device), GFP_KERNEL);
1541         if (!hmm_device)
1542                 return ERR_PTR(-ENOMEM);
1543
1544         spin_lock(&hmm_device_lock);
1545         hmm_device->minor = find_first_zero_bit(hmm_device_mask, HMM_DEVICE_MAX);
1546         if (hmm_device->minor >= HMM_DEVICE_MAX) {
1547                 spin_unlock(&hmm_device_lock);
1548                 kfree(hmm_device);
1549                 return ERR_PTR(-EBUSY);
1550         }
1551         set_bit(hmm_device->minor, hmm_device_mask);
1552         spin_unlock(&hmm_device_lock);
1553
1554         dev_set_name(&hmm_device->device, "hmm_device%d", hmm_device->minor);
1555         hmm_device->device.devt = MKDEV(MAJOR(hmm_device_devt),
1556                                         hmm_device->minor);
1557         hmm_device->device.release = hmm_device_release;
1558         dev_set_drvdata(&hmm_device->device, drvdata);
1559         hmm_device->device.class = hmm_device_class;
1560         device_initialize(&hmm_device->device);
1561
1562         return hmm_device;
1563 }
1564 EXPORT_SYMBOL(hmm_device_new);
1565
1566 void hmm_device_put(struct hmm_device *hmm_device)
1567 {
1568         put_device(&hmm_device->device);
1569 }
1570 EXPORT_SYMBOL(hmm_device_put);
1571
1572 static int __init hmm_init(void)
1573 {
1574         int ret;
1575
1576         ret = alloc_chrdev_region(&hmm_device_devt, 0,
1577                                   HMM_DEVICE_MAX,
1578                                   "hmm_device");
1579         if (ret)
1580                 return ret;
1581
1582         hmm_device_class = class_create(THIS_MODULE, "hmm_device");
1583         if (IS_ERR(hmm_device_class)) {
1584                 unregister_chrdev_region(hmm_device_devt, HMM_DEVICE_MAX);
1585                 return PTR_ERR(hmm_device_class);
1586         }
1587         return 0;
1588 }
1589
1590 device_initcall(hmm_init);
1591 #endif /* CONFIG_DEVICE_PRIVATE || CONFIG_DEVICE_PUBLIC */