Merge tag 'cxl-for-5.15' of git://git.kernel.org/pub/scm/linux/kernel/git/cxl/cxl
[platform/kernel/linux-rpi.git] / drivers / dax / device.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright(c) 2016-2018 Intel Corporation. All rights reserved. */
3 #include <linux/memremap.h>
4 #include <linux/pagemap.h>
5 #include <linux/module.h>
6 #include <linux/device.h>
7 #include <linux/pfn_t.h>
8 #include <linux/cdev.h>
9 #include <linux/slab.h>
10 #include <linux/dax.h>
11 #include <linux/fs.h>
12 #include <linux/mm.h>
13 #include <linux/mman.h>
14 #include "dax-private.h"
15 #include "bus.h"
16
17 static int check_vma(struct dev_dax *dev_dax, struct vm_area_struct *vma,
18                 const char *func)
19 {
20         struct device *dev = &dev_dax->dev;
21         unsigned long mask;
22
23         if (!dax_alive(dev_dax->dax_dev))
24                 return -ENXIO;
25
26         /* prevent private mappings from being established */
27         if ((vma->vm_flags & VM_MAYSHARE) != VM_MAYSHARE) {
28                 dev_info_ratelimited(dev,
29                                 "%s: %s: fail, attempted private mapping\n",
30                                 current->comm, func);
31                 return -EINVAL;
32         }
33
34         mask = dev_dax->align - 1;
35         if (vma->vm_start & mask || vma->vm_end & mask) {
36                 dev_info_ratelimited(dev,
37                                 "%s: %s: fail, unaligned vma (%#lx - %#lx, %#lx)\n",
38                                 current->comm, func, vma->vm_start, vma->vm_end,
39                                 mask);
40                 return -EINVAL;
41         }
42
43         if (!vma_is_dax(vma)) {
44                 dev_info_ratelimited(dev,
45                                 "%s: %s: fail, vma is not DAX capable\n",
46                                 current->comm, func);
47                 return -EINVAL;
48         }
49
50         return 0;
51 }
52
53 /* see "strong" declaration in tools/testing/nvdimm/dax-dev.c */
54 __weak phys_addr_t dax_pgoff_to_phys(struct dev_dax *dev_dax, pgoff_t pgoff,
55                 unsigned long size)
56 {
57         int i;
58
59         for (i = 0; i < dev_dax->nr_range; i++) {
60                 struct dev_dax_range *dax_range = &dev_dax->ranges[i];
61                 struct range *range = &dax_range->range;
62                 unsigned long long pgoff_end;
63                 phys_addr_t phys;
64
65                 pgoff_end = dax_range->pgoff + PHYS_PFN(range_len(range)) - 1;
66                 if (pgoff < dax_range->pgoff || pgoff > pgoff_end)
67                         continue;
68                 phys = PFN_PHYS(pgoff - dax_range->pgoff) + range->start;
69                 if (phys + size - 1 <= range->end)
70                         return phys;
71                 break;
72         }
73         return -1;
74 }
75
76 static vm_fault_t __dev_dax_pte_fault(struct dev_dax *dev_dax,
77                                 struct vm_fault *vmf, pfn_t *pfn)
78 {
79         struct device *dev = &dev_dax->dev;
80         phys_addr_t phys;
81         unsigned int fault_size = PAGE_SIZE;
82
83         if (check_vma(dev_dax, vmf->vma, __func__))
84                 return VM_FAULT_SIGBUS;
85
86         if (dev_dax->align > PAGE_SIZE) {
87                 dev_dbg(dev, "alignment (%#x) > fault size (%#x)\n",
88                         dev_dax->align, fault_size);
89                 return VM_FAULT_SIGBUS;
90         }
91
92         if (fault_size != dev_dax->align)
93                 return VM_FAULT_SIGBUS;
94
95         phys = dax_pgoff_to_phys(dev_dax, vmf->pgoff, PAGE_SIZE);
96         if (phys == -1) {
97                 dev_dbg(dev, "pgoff_to_phys(%#lx) failed\n", vmf->pgoff);
98                 return VM_FAULT_SIGBUS;
99         }
100
101         *pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
102
103         return vmf_insert_mixed(vmf->vma, vmf->address, *pfn);
104 }
105
106 static vm_fault_t __dev_dax_pmd_fault(struct dev_dax *dev_dax,
107                                 struct vm_fault *vmf, pfn_t *pfn)
108 {
109         unsigned long pmd_addr = vmf->address & PMD_MASK;
110         struct device *dev = &dev_dax->dev;
111         phys_addr_t phys;
112         pgoff_t pgoff;
113         unsigned int fault_size = PMD_SIZE;
114
115         if (check_vma(dev_dax, vmf->vma, __func__))
116                 return VM_FAULT_SIGBUS;
117
118         if (dev_dax->align > PMD_SIZE) {
119                 dev_dbg(dev, "alignment (%#x) > fault size (%#x)\n",
120                         dev_dax->align, fault_size);
121                 return VM_FAULT_SIGBUS;
122         }
123
124         if (fault_size < dev_dax->align)
125                 return VM_FAULT_SIGBUS;
126         else if (fault_size > dev_dax->align)
127                 return VM_FAULT_FALLBACK;
128
129         /* if we are outside of the VMA */
130         if (pmd_addr < vmf->vma->vm_start ||
131                         (pmd_addr + PMD_SIZE) > vmf->vma->vm_end)
132                 return VM_FAULT_SIGBUS;
133
134         pgoff = linear_page_index(vmf->vma, pmd_addr);
135         phys = dax_pgoff_to_phys(dev_dax, pgoff, PMD_SIZE);
136         if (phys == -1) {
137                 dev_dbg(dev, "pgoff_to_phys(%#lx) failed\n", pgoff);
138                 return VM_FAULT_SIGBUS;
139         }
140
141         *pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
142
143         return vmf_insert_pfn_pmd(vmf, *pfn, vmf->flags & FAULT_FLAG_WRITE);
144 }
145
146 #ifdef CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD
147 static vm_fault_t __dev_dax_pud_fault(struct dev_dax *dev_dax,
148                                 struct vm_fault *vmf, pfn_t *pfn)
149 {
150         unsigned long pud_addr = vmf->address & PUD_MASK;
151         struct device *dev = &dev_dax->dev;
152         phys_addr_t phys;
153         pgoff_t pgoff;
154         unsigned int fault_size = PUD_SIZE;
155
156
157         if (check_vma(dev_dax, vmf->vma, __func__))
158                 return VM_FAULT_SIGBUS;
159
160         if (dev_dax->align > PUD_SIZE) {
161                 dev_dbg(dev, "alignment (%#x) > fault size (%#x)\n",
162                         dev_dax->align, fault_size);
163                 return VM_FAULT_SIGBUS;
164         }
165
166         if (fault_size < dev_dax->align)
167                 return VM_FAULT_SIGBUS;
168         else if (fault_size > dev_dax->align)
169                 return VM_FAULT_FALLBACK;
170
171         /* if we are outside of the VMA */
172         if (pud_addr < vmf->vma->vm_start ||
173                         (pud_addr + PUD_SIZE) > vmf->vma->vm_end)
174                 return VM_FAULT_SIGBUS;
175
176         pgoff = linear_page_index(vmf->vma, pud_addr);
177         phys = dax_pgoff_to_phys(dev_dax, pgoff, PUD_SIZE);
178         if (phys == -1) {
179                 dev_dbg(dev, "pgoff_to_phys(%#lx) failed\n", pgoff);
180                 return VM_FAULT_SIGBUS;
181         }
182
183         *pfn = phys_to_pfn_t(phys, PFN_DEV|PFN_MAP);
184
185         return vmf_insert_pfn_pud(vmf, *pfn, vmf->flags & FAULT_FLAG_WRITE);
186 }
187 #else
188 static vm_fault_t __dev_dax_pud_fault(struct dev_dax *dev_dax,
189                                 struct vm_fault *vmf, pfn_t *pfn)
190 {
191         return VM_FAULT_FALLBACK;
192 }
193 #endif /* !CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD */
194
195 static vm_fault_t dev_dax_huge_fault(struct vm_fault *vmf,
196                 enum page_entry_size pe_size)
197 {
198         struct file *filp = vmf->vma->vm_file;
199         unsigned long fault_size;
200         vm_fault_t rc = VM_FAULT_SIGBUS;
201         int id;
202         pfn_t pfn;
203         struct dev_dax *dev_dax = filp->private_data;
204
205         dev_dbg(&dev_dax->dev, "%s: %s (%#lx - %#lx) size = %d\n", current->comm,
206                         (vmf->flags & FAULT_FLAG_WRITE) ? "write" : "read",
207                         vmf->vma->vm_start, vmf->vma->vm_end, pe_size);
208
209         id = dax_read_lock();
210         switch (pe_size) {
211         case PE_SIZE_PTE:
212                 fault_size = PAGE_SIZE;
213                 rc = __dev_dax_pte_fault(dev_dax, vmf, &pfn);
214                 break;
215         case PE_SIZE_PMD:
216                 fault_size = PMD_SIZE;
217                 rc = __dev_dax_pmd_fault(dev_dax, vmf, &pfn);
218                 break;
219         case PE_SIZE_PUD:
220                 fault_size = PUD_SIZE;
221                 rc = __dev_dax_pud_fault(dev_dax, vmf, &pfn);
222                 break;
223         default:
224                 rc = VM_FAULT_SIGBUS;
225         }
226
227         if (rc == VM_FAULT_NOPAGE) {
228                 unsigned long i;
229                 pgoff_t pgoff;
230
231                 /*
232                  * In the device-dax case the only possibility for a
233                  * VM_FAULT_NOPAGE result is when device-dax capacity is
234                  * mapped. No need to consider the zero page, or racing
235                  * conflicting mappings.
236                  */
237                 pgoff = linear_page_index(vmf->vma, vmf->address
238                                 & ~(fault_size - 1));
239                 for (i = 0; i < fault_size / PAGE_SIZE; i++) {
240                         struct page *page;
241
242                         page = pfn_to_page(pfn_t_to_pfn(pfn) + i);
243                         if (page->mapping)
244                                 continue;
245                         page->mapping = filp->f_mapping;
246                         page->index = pgoff + i;
247                 }
248         }
249         dax_read_unlock(id);
250
251         return rc;
252 }
253
254 static vm_fault_t dev_dax_fault(struct vm_fault *vmf)
255 {
256         return dev_dax_huge_fault(vmf, PE_SIZE_PTE);
257 }
258
259 static int dev_dax_may_split(struct vm_area_struct *vma, unsigned long addr)
260 {
261         struct file *filp = vma->vm_file;
262         struct dev_dax *dev_dax = filp->private_data;
263
264         if (!IS_ALIGNED(addr, dev_dax->align))
265                 return -EINVAL;
266         return 0;
267 }
268
269 static unsigned long dev_dax_pagesize(struct vm_area_struct *vma)
270 {
271         struct file *filp = vma->vm_file;
272         struct dev_dax *dev_dax = filp->private_data;
273
274         return dev_dax->align;
275 }
276
277 static const struct vm_operations_struct dax_vm_ops = {
278         .fault = dev_dax_fault,
279         .huge_fault = dev_dax_huge_fault,
280         .may_split = dev_dax_may_split,
281         .pagesize = dev_dax_pagesize,
282 };
283
284 static int dax_mmap(struct file *filp, struct vm_area_struct *vma)
285 {
286         struct dev_dax *dev_dax = filp->private_data;
287         int rc, id;
288
289         dev_dbg(&dev_dax->dev, "trace\n");
290
291         /*
292          * We lock to check dax_dev liveness and will re-check at
293          * fault time.
294          */
295         id = dax_read_lock();
296         rc = check_vma(dev_dax, vma, __func__);
297         dax_read_unlock(id);
298         if (rc)
299                 return rc;
300
301         vma->vm_ops = &dax_vm_ops;
302         vma->vm_flags |= VM_HUGEPAGE;
303         return 0;
304 }
305
306 /* return an unmapped area aligned to the dax region specified alignment */
307 static unsigned long dax_get_unmapped_area(struct file *filp,
308                 unsigned long addr, unsigned long len, unsigned long pgoff,
309                 unsigned long flags)
310 {
311         unsigned long off, off_end, off_align, len_align, addr_align, align;
312         struct dev_dax *dev_dax = filp ? filp->private_data : NULL;
313
314         if (!dev_dax || addr)
315                 goto out;
316
317         align = dev_dax->align;
318         off = pgoff << PAGE_SHIFT;
319         off_end = off + len;
320         off_align = round_up(off, align);
321
322         if ((off_end <= off_align) || ((off_end - off_align) < align))
323                 goto out;
324
325         len_align = len + align;
326         if ((off + len_align) < off)
327                 goto out;
328
329         addr_align = current->mm->get_unmapped_area(filp, addr, len_align,
330                         pgoff, flags);
331         if (!IS_ERR_VALUE(addr_align)) {
332                 addr_align += (off - addr_align) & (align - 1);
333                 return addr_align;
334         }
335  out:
336         return current->mm->get_unmapped_area(filp, addr, len, pgoff, flags);
337 }
338
339 static const struct address_space_operations dev_dax_aops = {
340         .set_page_dirty         = __set_page_dirty_no_writeback,
341         .invalidatepage         = noop_invalidatepage,
342 };
343
344 static int dax_open(struct inode *inode, struct file *filp)
345 {
346         struct dax_device *dax_dev = inode_dax(inode);
347         struct inode *__dax_inode = dax_inode(dax_dev);
348         struct dev_dax *dev_dax = dax_get_private(dax_dev);
349
350         dev_dbg(&dev_dax->dev, "trace\n");
351         inode->i_mapping = __dax_inode->i_mapping;
352         inode->i_mapping->host = __dax_inode;
353         inode->i_mapping->a_ops = &dev_dax_aops;
354         filp->f_mapping = inode->i_mapping;
355         filp->f_wb_err = filemap_sample_wb_err(filp->f_mapping);
356         filp->f_sb_err = file_sample_sb_err(filp);
357         filp->private_data = dev_dax;
358         inode->i_flags = S_DAX;
359
360         return 0;
361 }
362
363 static int dax_release(struct inode *inode, struct file *filp)
364 {
365         struct dev_dax *dev_dax = filp->private_data;
366
367         dev_dbg(&dev_dax->dev, "trace\n");
368         return 0;
369 }
370
371 static const struct file_operations dax_fops = {
372         .llseek = noop_llseek,
373         .owner = THIS_MODULE,
374         .open = dax_open,
375         .release = dax_release,
376         .get_unmapped_area = dax_get_unmapped_area,
377         .mmap = dax_mmap,
378         .mmap_supported_flags = MAP_SYNC,
379 };
380
/* devm action registered by dev_dax_probe(): delete the character device. */
static void dev_dax_cdev_del(void *cdev)
{
	struct cdev *chrdev = cdev;

	cdev_del(chrdev);
}
385
/* devm action registered by dev_dax_probe(): forward to kill_dev_dax(). */
static void dev_dax_kill(void *dev_dax)
{
	struct dev_dax *victim = dev_dax;

	kill_dev_dax(victim);
}
390
391 int dev_dax_probe(struct dev_dax *dev_dax)
392 {
393         struct dax_device *dax_dev = dev_dax->dax_dev;
394         struct device *dev = &dev_dax->dev;
395         struct dev_pagemap *pgmap;
396         struct inode *inode;
397         struct cdev *cdev;
398         void *addr;
399         int rc, i;
400
401         pgmap = dev_dax->pgmap;
402         if (dev_WARN_ONCE(dev, pgmap && dev_dax->nr_range > 1,
403                         "static pgmap / multi-range device conflict\n"))
404                 return -EINVAL;
405
406         if (!pgmap) {
407                 pgmap = devm_kzalloc(dev, sizeof(*pgmap) + sizeof(struct range)
408                                 * (dev_dax->nr_range - 1), GFP_KERNEL);
409                 if (!pgmap)
410                         return -ENOMEM;
411                 pgmap->nr_range = dev_dax->nr_range;
412         }
413
414         for (i = 0; i < dev_dax->nr_range; i++) {
415                 struct range *range = &dev_dax->ranges[i].range;
416
417                 if (!devm_request_mem_region(dev, range->start,
418                                         range_len(range), dev_name(dev))) {
419                         dev_warn(dev, "mapping%d: %#llx-%#llx could not reserve range\n",
420                                         i, range->start, range->end);
421                         return -EBUSY;
422                 }
423                 /* don't update the range for static pgmap */
424                 if (!dev_dax->pgmap)
425                         pgmap->ranges[i] = *range;
426         }
427
428         pgmap->type = MEMORY_DEVICE_GENERIC;
429         addr = devm_memremap_pages(dev, pgmap);
430         if (IS_ERR(addr))
431                 return PTR_ERR(addr);
432
433         inode = dax_inode(dax_dev);
434         cdev = inode->i_cdev;
435         cdev_init(cdev, &dax_fops);
436         if (dev->class) {
437                 /* for the CONFIG_DEV_DAX_PMEM_COMPAT case */
438                 cdev->owner = dev->parent->driver->owner;
439         } else
440                 cdev->owner = dev->driver->owner;
441         cdev_set_parent(cdev, &dev->kobj);
442         rc = cdev_add(cdev, dev->devt, 1);
443         if (rc)
444                 return rc;
445
446         rc = devm_add_action_or_reset(dev, dev_dax_cdev_del, cdev);
447         if (rc)
448                 return rc;
449
450         run_dax(dax_dev);
451         return devm_add_action_or_reset(dev, dev_dax_kill, dev_dax);
452 }
453 EXPORT_SYMBOL_GPL(dev_dax_probe);
454
455 static struct dax_device_driver device_dax_driver = {
456         .probe = dev_dax_probe,
457         /* all probe actions are unwound by devm, so .remove isn't necessary */
458         .match_always = 1,
459 };
460
461 static int __init dax_init(void)
462 {
463         return dax_driver_register(&device_dax_driver);
464 }
465
466 static void __exit dax_exit(void)
467 {
468         dax_driver_unregister(&device_dax_driver);
469 }
470
471 MODULE_AUTHOR("Intel Corporation");
472 MODULE_LICENSE("GPL v2");
473 module_init(dax_init);
474 module_exit(dax_exit);
475 MODULE_ALIAS_DAX_DEVICE(0);