Merge tag 'v5.15-rc2' into spi-5.15
[platform/kernel/linux-rpi.git] / drivers / vhost / vdpa.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2018-2020 Intel Corporation.
4  * Copyright (C) 2020 Red Hat, Inc.
5  *
6  * Author: Tiwei Bie <tiwei.bie@intel.com>
7  *         Jason Wang <jasowang@redhat.com>
8  *
9  * Thanks Michael S. Tsirkin for the valuable comments and
10  * suggestions.  And thanks to Cunming Liang and Zhihong Wang for all
11  * their supports.
12  */
13
14 #include <linux/kernel.h>
15 #include <linux/module.h>
16 #include <linux/cdev.h>
17 #include <linux/device.h>
18 #include <linux/mm.h>
19 #include <linux/slab.h>
20 #include <linux/iommu.h>
21 #include <linux/uuid.h>
22 #include <linux/vdpa.h>
23 #include <linux/nospec.h>
24 #include <linux/vhost.h>
25
26 #include "vhost.h"
27
28 enum {
29         VHOST_VDPA_BACKEND_FEATURES =
30         (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2) |
31         (1ULL << VHOST_BACKEND_F_IOTLB_BATCH),
32 };
33
34 #define VHOST_VDPA_DEV_MAX (1U << MINORBITS)
35
36 struct vhost_vdpa {
37         struct vhost_dev vdev;
38         struct iommu_domain *domain;
39         struct vhost_virtqueue *vqs;
40         struct completion completion;
41         struct vdpa_device *vdpa;
42         struct device dev;
43         struct cdev cdev;
44         atomic_t opened;
45         int nvqs;
46         int virtio_id;
47         int minor;
48         struct eventfd_ctx *config_ctx;
49         int in_batch;
50         struct vdpa_iova_range range;
51 };
52
53 static DEFINE_IDA(vhost_vdpa_ida);
54
55 static dev_t vhost_vdpa_major;
56
57 static void handle_vq_kick(struct vhost_work *work)
58 {
59         struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
60                                                   poll.work);
61         struct vhost_vdpa *v = container_of(vq->dev, struct vhost_vdpa, vdev);
62         const struct vdpa_config_ops *ops = v->vdpa->config;
63
64         ops->kick_vq(v->vdpa, vq - v->vqs);
65 }
66
67 static irqreturn_t vhost_vdpa_virtqueue_cb(void *private)
68 {
69         struct vhost_virtqueue *vq = private;
70         struct eventfd_ctx *call_ctx = vq->call_ctx.ctx;
71
72         if (call_ctx)
73                 eventfd_signal(call_ctx, 1);
74
75         return IRQ_HANDLED;
76 }
77
78 static irqreturn_t vhost_vdpa_config_cb(void *private)
79 {
80         struct vhost_vdpa *v = private;
81         struct eventfd_ctx *config_ctx = v->config_ctx;
82
83         if (config_ctx)
84                 eventfd_signal(config_ctx, 1);
85
86         return IRQ_HANDLED;
87 }
88
89 static void vhost_vdpa_setup_vq_irq(struct vhost_vdpa *v, u16 qid)
90 {
91         struct vhost_virtqueue *vq = &v->vqs[qid];
92         const struct vdpa_config_ops *ops = v->vdpa->config;
93         struct vdpa_device *vdpa = v->vdpa;
94         int ret, irq;
95
96         if (!ops->get_vq_irq)
97                 return;
98
99         irq = ops->get_vq_irq(vdpa, qid);
100         irq_bypass_unregister_producer(&vq->call_ctx.producer);
101         if (!vq->call_ctx.ctx || irq < 0)
102                 return;
103
104         vq->call_ctx.producer.token = vq->call_ctx.ctx;
105         vq->call_ctx.producer.irq = irq;
106         ret = irq_bypass_register_producer(&vq->call_ctx.producer);
107         if (unlikely(ret))
108                 dev_info(&v->dev, "vq %u, irq bypass producer (token %p) registration fails, ret =  %d\n",
109                          qid, vq->call_ctx.producer.token, ret);
110 }
111
112 static void vhost_vdpa_unsetup_vq_irq(struct vhost_vdpa *v, u16 qid)
113 {
114         struct vhost_virtqueue *vq = &v->vqs[qid];
115
116         irq_bypass_unregister_producer(&vq->call_ctx.producer);
117 }
118
119 static int vhost_vdpa_reset(struct vhost_vdpa *v)
120 {
121         struct vdpa_device *vdpa = v->vdpa;
122
123         v->in_batch = 0;
124
125         return vdpa_reset(vdpa);
126 }
127
128 static long vhost_vdpa_get_device_id(struct vhost_vdpa *v, u8 __user *argp)
129 {
130         struct vdpa_device *vdpa = v->vdpa;
131         const struct vdpa_config_ops *ops = vdpa->config;
132         u32 device_id;
133
134         device_id = ops->get_device_id(vdpa);
135
136         if (copy_to_user(argp, &device_id, sizeof(device_id)))
137                 return -EFAULT;
138
139         return 0;
140 }
141
142 static long vhost_vdpa_get_status(struct vhost_vdpa *v, u8 __user *statusp)
143 {
144         struct vdpa_device *vdpa = v->vdpa;
145         const struct vdpa_config_ops *ops = vdpa->config;
146         u8 status;
147
148         status = ops->get_status(vdpa);
149
150         if (copy_to_user(statusp, &status, sizeof(status)))
151                 return -EFAULT;
152
153         return 0;
154 }
155
156 static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp)
157 {
158         struct vdpa_device *vdpa = v->vdpa;
159         const struct vdpa_config_ops *ops = vdpa->config;
160         u8 status, status_old;
161         int ret, nvqs = v->nvqs;
162         u16 i;
163
164         if (copy_from_user(&status, statusp, sizeof(status)))
165                 return -EFAULT;
166
167         status_old = ops->get_status(vdpa);
168
169         /*
170          * Userspace shouldn't remove status bits unless reset the
171          * status to 0.
172          */
173         if (status != 0 && (ops->get_status(vdpa) & ~status) != 0)
174                 return -EINVAL;
175
176         if (status == 0) {
177                 ret = ops->reset(vdpa);
178                 if (ret)
179                         return ret;
180         } else
181                 ops->set_status(vdpa, status);
182
183         if ((status & VIRTIO_CONFIG_S_DRIVER_OK) && !(status_old & VIRTIO_CONFIG_S_DRIVER_OK))
184                 for (i = 0; i < nvqs; i++)
185                         vhost_vdpa_setup_vq_irq(v, i);
186
187         if ((status_old & VIRTIO_CONFIG_S_DRIVER_OK) && !(status & VIRTIO_CONFIG_S_DRIVER_OK))
188                 for (i = 0; i < nvqs; i++)
189                         vhost_vdpa_unsetup_vq_irq(v, i);
190
191         return 0;
192 }
193
194 static int vhost_vdpa_config_validate(struct vhost_vdpa *v,
195                                       struct vhost_vdpa_config *c)
196 {
197         struct vdpa_device *vdpa = v->vdpa;
198         long size = vdpa->config->get_config_size(vdpa);
199
200         if (c->len == 0)
201                 return -EINVAL;
202
203         if (c->len > size - c->off)
204                 return -E2BIG;
205
206         return 0;
207 }
208
209 static long vhost_vdpa_get_config(struct vhost_vdpa *v,
210                                   struct vhost_vdpa_config __user *c)
211 {
212         struct vdpa_device *vdpa = v->vdpa;
213         struct vhost_vdpa_config config;
214         unsigned long size = offsetof(struct vhost_vdpa_config, buf);
215         u8 *buf;
216
217         if (copy_from_user(&config, c, size))
218                 return -EFAULT;
219         if (vhost_vdpa_config_validate(v, &config))
220                 return -EINVAL;
221         buf = kvzalloc(config.len, GFP_KERNEL);
222         if (!buf)
223                 return -ENOMEM;
224
225         vdpa_get_config(vdpa, config.off, buf, config.len);
226
227         if (copy_to_user(c->buf, buf, config.len)) {
228                 kvfree(buf);
229                 return -EFAULT;
230         }
231
232         kvfree(buf);
233         return 0;
234 }
235
236 static long vhost_vdpa_set_config(struct vhost_vdpa *v,
237                                   struct vhost_vdpa_config __user *c)
238 {
239         struct vdpa_device *vdpa = v->vdpa;
240         const struct vdpa_config_ops *ops = vdpa->config;
241         struct vhost_vdpa_config config;
242         unsigned long size = offsetof(struct vhost_vdpa_config, buf);
243         u8 *buf;
244
245         if (copy_from_user(&config, c, size))
246                 return -EFAULT;
247         if (vhost_vdpa_config_validate(v, &config))
248                 return -EINVAL;
249
250         buf = vmemdup_user(c->buf, config.len);
251         if (IS_ERR(buf))
252                 return PTR_ERR(buf);
253
254         ops->set_config(vdpa, config.off, buf, config.len);
255
256         kvfree(buf);
257         return 0;
258 }
259
260 static long vhost_vdpa_get_features(struct vhost_vdpa *v, u64 __user *featurep)
261 {
262         struct vdpa_device *vdpa = v->vdpa;
263         const struct vdpa_config_ops *ops = vdpa->config;
264         u64 features;
265
266         features = ops->get_features(vdpa);
267
268         if (copy_to_user(featurep, &features, sizeof(features)))
269                 return -EFAULT;
270
271         return 0;
272 }
273
274 static long vhost_vdpa_set_features(struct vhost_vdpa *v, u64 __user *featurep)
275 {
276         struct vdpa_device *vdpa = v->vdpa;
277         const struct vdpa_config_ops *ops = vdpa->config;
278         u64 features;
279
280         /*
281          * It's not allowed to change the features after they have
282          * been negotiated.
283          */
284         if (ops->get_status(vdpa) & VIRTIO_CONFIG_S_FEATURES_OK)
285                 return -EBUSY;
286
287         if (copy_from_user(&features, featurep, sizeof(features)))
288                 return -EFAULT;
289
290         if (vdpa_set_features(vdpa, features))
291                 return -EINVAL;
292
293         return 0;
294 }
295
296 static long vhost_vdpa_get_vring_num(struct vhost_vdpa *v, u16 __user *argp)
297 {
298         struct vdpa_device *vdpa = v->vdpa;
299         const struct vdpa_config_ops *ops = vdpa->config;
300         u16 num;
301
302         num = ops->get_vq_num_max(vdpa);
303
304         if (copy_to_user(argp, &num, sizeof(num)))
305                 return -EFAULT;
306
307         return 0;
308 }
309
310 static void vhost_vdpa_config_put(struct vhost_vdpa *v)
311 {
312         if (v->config_ctx) {
313                 eventfd_ctx_put(v->config_ctx);
314                 v->config_ctx = NULL;
315         }
316 }
317
318 static long vhost_vdpa_set_config_call(struct vhost_vdpa *v, u32 __user *argp)
319 {
320         struct vdpa_callback cb;
321         int fd;
322         struct eventfd_ctx *ctx;
323
324         cb.callback = vhost_vdpa_config_cb;
325         cb.private = v->vdpa;
326         if (copy_from_user(&fd, argp, sizeof(fd)))
327                 return  -EFAULT;
328
329         ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
330         swap(ctx, v->config_ctx);
331
332         if (!IS_ERR_OR_NULL(ctx))
333                 eventfd_ctx_put(ctx);
334
335         if (IS_ERR(v->config_ctx)) {
336                 long ret = PTR_ERR(v->config_ctx);
337
338                 v->config_ctx = NULL;
339                 return ret;
340         }
341
342         v->vdpa->config->set_config_cb(v->vdpa, &cb);
343
344         return 0;
345 }
346
347 static long vhost_vdpa_get_iova_range(struct vhost_vdpa *v, u32 __user *argp)
348 {
349         struct vhost_vdpa_iova_range range = {
350                 .first = v->range.first,
351                 .last = v->range.last,
352         };
353
354         if (copy_to_user(argp, &range, sizeof(range)))
355                 return -EFAULT;
356         return 0;
357 }
358
359 static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd,
360                                    void __user *argp)
361 {
362         struct vdpa_device *vdpa = v->vdpa;
363         const struct vdpa_config_ops *ops = vdpa->config;
364         struct vdpa_vq_state vq_state;
365         struct vdpa_callback cb;
366         struct vhost_virtqueue *vq;
367         struct vhost_vring_state s;
368         u32 idx;
369         long r;
370
371         r = get_user(idx, (u32 __user *)argp);
372         if (r < 0)
373                 return r;
374
375         if (idx >= v->nvqs)
376                 return -ENOBUFS;
377
378         idx = array_index_nospec(idx, v->nvqs);
379         vq = &v->vqs[idx];
380
381         switch (cmd) {
382         case VHOST_VDPA_SET_VRING_ENABLE:
383                 if (copy_from_user(&s, argp, sizeof(s)))
384                         return -EFAULT;
385                 ops->set_vq_ready(vdpa, idx, s.num);
386                 return 0;
387         case VHOST_GET_VRING_BASE:
388                 r = ops->get_vq_state(v->vdpa, idx, &vq_state);
389                 if (r)
390                         return r;
391
392                 vq->last_avail_idx = vq_state.split.avail_index;
393                 break;
394         }
395
396         r = vhost_vring_ioctl(&v->vdev, cmd, argp);
397         if (r)
398                 return r;
399
400         switch (cmd) {
401         case VHOST_SET_VRING_ADDR:
402                 if (ops->set_vq_address(vdpa, idx,
403                                         (u64)(uintptr_t)vq->desc,
404                                         (u64)(uintptr_t)vq->avail,
405                                         (u64)(uintptr_t)vq->used))
406                         r = -EINVAL;
407                 break;
408
409         case VHOST_SET_VRING_BASE:
410                 vq_state.split.avail_index = vq->last_avail_idx;
411                 if (ops->set_vq_state(vdpa, idx, &vq_state))
412                         r = -EINVAL;
413                 break;
414
415         case VHOST_SET_VRING_CALL:
416                 if (vq->call_ctx.ctx) {
417                         cb.callback = vhost_vdpa_virtqueue_cb;
418                         cb.private = vq;
419                 } else {
420                         cb.callback = NULL;
421                         cb.private = NULL;
422                 }
423                 ops->set_vq_cb(vdpa, idx, &cb);
424                 vhost_vdpa_setup_vq_irq(v, idx);
425                 break;
426
427         case VHOST_SET_VRING_NUM:
428                 ops->set_vq_num(vdpa, idx, vq->num);
429                 break;
430         }
431
432         return r;
433 }
434
435 static long vhost_vdpa_unlocked_ioctl(struct file *filep,
436                                       unsigned int cmd, unsigned long arg)
437 {
438         struct vhost_vdpa *v = filep->private_data;
439         struct vhost_dev *d = &v->vdev;
440         void __user *argp = (void __user *)arg;
441         u64 __user *featurep = argp;
442         u64 features;
443         long r = 0;
444
445         if (cmd == VHOST_SET_BACKEND_FEATURES) {
446                 if (copy_from_user(&features, featurep, sizeof(features)))
447                         return -EFAULT;
448                 if (features & ~VHOST_VDPA_BACKEND_FEATURES)
449                         return -EOPNOTSUPP;
450                 vhost_set_backend_features(&v->vdev, features);
451                 return 0;
452         }
453
454         mutex_lock(&d->mutex);
455
456         switch (cmd) {
457         case VHOST_VDPA_GET_DEVICE_ID:
458                 r = vhost_vdpa_get_device_id(v, argp);
459                 break;
460         case VHOST_VDPA_GET_STATUS:
461                 r = vhost_vdpa_get_status(v, argp);
462                 break;
463         case VHOST_VDPA_SET_STATUS:
464                 r = vhost_vdpa_set_status(v, argp);
465                 break;
466         case VHOST_VDPA_GET_CONFIG:
467                 r = vhost_vdpa_get_config(v, argp);
468                 break;
469         case VHOST_VDPA_SET_CONFIG:
470                 r = vhost_vdpa_set_config(v, argp);
471                 break;
472         case VHOST_GET_FEATURES:
473                 r = vhost_vdpa_get_features(v, argp);
474                 break;
475         case VHOST_SET_FEATURES:
476                 r = vhost_vdpa_set_features(v, argp);
477                 break;
478         case VHOST_VDPA_GET_VRING_NUM:
479                 r = vhost_vdpa_get_vring_num(v, argp);
480                 break;
481         case VHOST_SET_LOG_BASE:
482         case VHOST_SET_LOG_FD:
483                 r = -ENOIOCTLCMD;
484                 break;
485         case VHOST_VDPA_SET_CONFIG_CALL:
486                 r = vhost_vdpa_set_config_call(v, argp);
487                 break;
488         case VHOST_GET_BACKEND_FEATURES:
489                 features = VHOST_VDPA_BACKEND_FEATURES;
490                 if (copy_to_user(featurep, &features, sizeof(features)))
491                         r = -EFAULT;
492                 break;
493         case VHOST_VDPA_GET_IOVA_RANGE:
494                 r = vhost_vdpa_get_iova_range(v, argp);
495                 break;
496         default:
497                 r = vhost_dev_ioctl(&v->vdev, cmd, argp);
498                 if (r == -ENOIOCTLCMD)
499                         r = vhost_vdpa_vring_ioctl(v, cmd, argp);
500                 break;
501         }
502
503         mutex_unlock(&d->mutex);
504         return r;
505 }
506
507 static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v, u64 start, u64 last)
508 {
509         struct vhost_dev *dev = &v->vdev;
510         struct vhost_iotlb *iotlb = dev->iotlb;
511         struct vhost_iotlb_map *map;
512         struct page *page;
513         unsigned long pfn, pinned;
514
515         while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
516                 pinned = PFN_DOWN(map->size);
517                 for (pfn = PFN_DOWN(map->addr);
518                      pinned > 0; pfn++, pinned--) {
519                         page = pfn_to_page(pfn);
520                         if (map->perm & VHOST_ACCESS_WO)
521                                 set_page_dirty_lock(page);
522                         unpin_user_page(page);
523                 }
524                 atomic64_sub(PFN_DOWN(map->size), &dev->mm->pinned_vm);
525                 vhost_iotlb_map_free(iotlb, map);
526         }
527 }
528
529 static void vhost_vdpa_va_unmap(struct vhost_vdpa *v, u64 start, u64 last)
530 {
531         struct vhost_dev *dev = &v->vdev;
532         struct vhost_iotlb *iotlb = dev->iotlb;
533         struct vhost_iotlb_map *map;
534         struct vdpa_map_file *map_file;
535
536         while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
537                 map_file = (struct vdpa_map_file *)map->opaque;
538                 fput(map_file->file);
539                 kfree(map_file);
540                 vhost_iotlb_map_free(iotlb, map);
541         }
542 }
543
544 static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v, u64 start, u64 last)
545 {
546         struct vdpa_device *vdpa = v->vdpa;
547
548         if (vdpa->use_va)
549                 return vhost_vdpa_va_unmap(v, start, last);
550
551         return vhost_vdpa_pa_unmap(v, start, last);
552 }
553
554 static void vhost_vdpa_iotlb_free(struct vhost_vdpa *v)
555 {
556         struct vhost_dev *dev = &v->vdev;
557
558         vhost_vdpa_iotlb_unmap(v, 0ULL, 0ULL - 1);
559         kfree(dev->iotlb);
560         dev->iotlb = NULL;
561 }
562
563 static int perm_to_iommu_flags(u32 perm)
564 {
565         int flags = 0;
566
567         switch (perm) {
568         case VHOST_ACCESS_WO:
569                 flags |= IOMMU_WRITE;
570                 break;
571         case VHOST_ACCESS_RO:
572                 flags |= IOMMU_READ;
573                 break;
574         case VHOST_ACCESS_RW:
575                 flags |= (IOMMU_WRITE | IOMMU_READ);
576                 break;
577         default:
578                 WARN(1, "invalidate vhost IOTLB permission\n");
579                 break;
580         }
581
582         return flags | IOMMU_CACHE;
583 }
584
585 static int vhost_vdpa_map(struct vhost_vdpa *v, u64 iova,
586                           u64 size, u64 pa, u32 perm, void *opaque)
587 {
588         struct vhost_dev *dev = &v->vdev;
589         struct vdpa_device *vdpa = v->vdpa;
590         const struct vdpa_config_ops *ops = vdpa->config;
591         int r = 0;
592
593         r = vhost_iotlb_add_range_ctx(dev->iotlb, iova, iova + size - 1,
594                                       pa, perm, opaque);
595         if (r)
596                 return r;
597
598         if (ops->dma_map) {
599                 r = ops->dma_map(vdpa, iova, size, pa, perm, opaque);
600         } else if (ops->set_map) {
601                 if (!v->in_batch)
602                         r = ops->set_map(vdpa, dev->iotlb);
603         } else {
604                 r = iommu_map(v->domain, iova, pa, size,
605                               perm_to_iommu_flags(perm));
606         }
607         if (r) {
608                 vhost_iotlb_del_range(dev->iotlb, iova, iova + size - 1);
609                 return r;
610         }
611
612         if (!vdpa->use_va)
613                 atomic64_add(PFN_DOWN(size), &dev->mm->pinned_vm);
614
615         return 0;
616 }
617
618 static void vhost_vdpa_unmap(struct vhost_vdpa *v, u64 iova, u64 size)
619 {
620         struct vhost_dev *dev = &v->vdev;
621         struct vdpa_device *vdpa = v->vdpa;
622         const struct vdpa_config_ops *ops = vdpa->config;
623
624         vhost_vdpa_iotlb_unmap(v, iova, iova + size - 1);
625
626         if (ops->dma_map) {
627                 ops->dma_unmap(vdpa, iova, size);
628         } else if (ops->set_map) {
629                 if (!v->in_batch)
630                         ops->set_map(vdpa, dev->iotlb);
631         } else {
632                 iommu_unmap(v->domain, iova, size);
633         }
634 }
635
636 static int vhost_vdpa_va_map(struct vhost_vdpa *v,
637                              u64 iova, u64 size, u64 uaddr, u32 perm)
638 {
639         struct vhost_dev *dev = &v->vdev;
640         u64 offset, map_size, map_iova = iova;
641         struct vdpa_map_file *map_file;
642         struct vm_area_struct *vma;
643         int ret;
644
645         mmap_read_lock(dev->mm);
646
647         while (size) {
648                 vma = find_vma(dev->mm, uaddr);
649                 if (!vma) {
650                         ret = -EINVAL;
651                         break;
652                 }
653                 map_size = min(size, vma->vm_end - uaddr);
654                 if (!(vma->vm_file && (vma->vm_flags & VM_SHARED) &&
655                         !(vma->vm_flags & (VM_IO | VM_PFNMAP))))
656                         goto next;
657
658                 map_file = kzalloc(sizeof(*map_file), GFP_KERNEL);
659                 if (!map_file) {
660                         ret = -ENOMEM;
661                         break;
662                 }
663                 offset = (vma->vm_pgoff << PAGE_SHIFT) + uaddr - vma->vm_start;
664                 map_file->offset = offset;
665                 map_file->file = get_file(vma->vm_file);
666                 ret = vhost_vdpa_map(v, map_iova, map_size, uaddr,
667                                      perm, map_file);
668                 if (ret) {
669                         fput(map_file->file);
670                         kfree(map_file);
671                         break;
672                 }
673 next:
674                 size -= map_size;
675                 uaddr += map_size;
676                 map_iova += map_size;
677         }
678         if (ret)
679                 vhost_vdpa_unmap(v, iova, map_iova - iova);
680
681         mmap_read_unlock(dev->mm);
682
683         return ret;
684 }
685
/*
 * Map [iova, iova + size) for a device that works on physical addresses.
 *
 * The user pages starting at uaddr are pinned in batches (at most one
 * page worth of struct page pointers per batch) and every run of
 * physically contiguous pfns is handed to vhost_vdpa_map() as one
 * mapping. The mmap read lock is held for the whole operation.
 *
 * Returns 0 on success. Errors visible here: -ENOMEM if the bookkeeping
 * page cannot be allocated, if the RLIMIT_MEMLOCK check fails or if a
 * batch is only partially pinned; -EINVAL if the page count computes to
 * zero; otherwise the error from pin_user_pages() or vhost_vdpa_map().
 * On failure, pages pinned but not yet mapped are unpinned here and
 * already-established mappings are removed with vhost_vdpa_unmap().
 */
686 static int vhost_vdpa_pa_map(struct vhost_vdpa *v,
687                              u64 iova, u64 size, u64 uaddr, u32 perm)
688 {
689         struct vhost_dev *dev = &v->vdev;
690         struct page **page_list;
691         unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
692         unsigned int gup_flags = FOLL_LONGTERM;
693         unsigned long npages, cur_base, map_pfn, last_pfn = 0;
694         unsigned long lock_limit, sz2pin, nchunks, i;
695         u64 start = iova;
696         long pinned;
697         int ret = 0;
698
699         /* Limit the use of memory for bookkeeping */
700         page_list = (struct page **) __get_free_page(GFP_KERNEL);
701         if (!page_list)
702                 return -ENOMEM;
703
704         if (perm & VHOST_ACCESS_WO)
705                 gup_flags |= FOLL_WRITE;
706
        /* Page count includes the offset of iova within its first page. */
707         npages = PFN_UP(size + (iova & ~PAGE_MASK));
708         if (!npages) {
709                 ret = -EINVAL;
710                 goto free;
711         }
712
713         mmap_read_lock(dev->mm);
714
        /* Refuse to exceed the memlock limit, counting pages already pinned. */
715         lock_limit = PFN_DOWN(rlimit(RLIMIT_MEMLOCK));
716         if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
717                 ret = -ENOMEM;
718                 goto unlock;
719         }
720
721         cur_base = uaddr & PAGE_MASK;
722         iova &= PAGE_MASK;
        /* nchunks: batches pinned since the last successful vhost_vdpa_map(). */
723         nchunks = 0;
724
725         while (npages) {
726                 sz2pin = min_t(unsigned long, npages, list_size);
727                 pinned = pin_user_pages(cur_base, sz2pin,
728                                         gup_flags, page_list, NULL);
729                 if (sz2pin != pinned) {
730                         if (pinned < 0) {
731                                 ret = pinned;
732                         } else {
                                /* Partial pin: release this batch, report -ENOMEM. */
733                                 unpin_user_pages(page_list, pinned);
734                                 ret = -ENOMEM;
735                         }
736                         goto out;
737                 }
738                 nchunks++;
739
                /* First batch: the contiguous run starts at its first page. */
740                 if (!last_pfn)
741                         map_pfn = page_to_pfn(page_list[0]);
742
743                 for (i = 0; i < pinned; i++) {
744                         unsigned long this_pfn = page_to_pfn(page_list[i]);
745                         u64 csize;
746
                        /* Contiguity broken: map the run [map_pfn, last_pfn]. */
747                         if (last_pfn && (this_pfn != last_pfn + 1)) {
748                                 /* Pin a contiguous chunk of memory */
749                                 csize = PFN_PHYS(last_pfn - map_pfn + 1);
750                                 ret = vhost_vdpa_map(v, iova, csize,
751                                                      PFN_PHYS(map_pfn),
752                                                      perm, NULL);
753                                 if (ret) {
754                                         /*
755                                          * Unpin the pages that are left unmapped
756                                          * from this point on in the current
757                                          * page_list. The remaining outstanding
758                                          * ones which may stride across several
759                                          * chunks will be covered in the common
760                                          * error path subsequently.
761                                          */
762                                         unpin_user_pages(&page_list[i],
763                                                          pinned - i);
764                                         goto out;
765                                 }
766
767                                 map_pfn = this_pfn;
768                                 iova += csize;
769                                 nchunks = 0;
770                         }
771
772                         last_pfn = this_pfn;
773                 }
774
775                 cur_base += PFN_PHYS(pinned);
776                 npages -= pinned;
777         }
778
779         /* Pin the rest chunk */
780         ret = vhost_vdpa_map(v, iova, PFN_PHYS(last_pfn - map_pfn + 1),
781                              PFN_PHYS(map_pfn), perm, NULL);
782 out:
783         if (ret) {
784                 if (nchunks) {
785                         unsigned long pfn;
786
787                         /*
788                          * Unpin the outstanding pages which are yet to be
789                          * mapped but haven't due to vdpa_map() or
790                          * pin_user_pages() failure.
791                          *
792                          * Mapped pages are accounted in vdpa_map(), hence
793                          * the corresponding unpinning will be handled by
794                          * vdpa_unmap().
795                          */
796                         WARN_ON(!last_pfn);
797                         for (pfn = map_pfn; pfn <= last_pfn; pfn++)
798                                 unpin_user_page(pfn_to_page(pfn));
799                 }
                /* start is the caller's original, unaligned iova. */
800                 vhost_vdpa_unmap(v, start, size);
801         }
802 unlock:
803         mmap_read_unlock(dev->mm);
804 free:
805         free_page((unsigned long)page_list);
806         return ret;
807
808 }
809
810 static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
811                                            struct vhost_iotlb_msg *msg)
812 {
813         struct vhost_dev *dev = &v->vdev;
814         struct vdpa_device *vdpa = v->vdpa;
815         struct vhost_iotlb *iotlb = dev->iotlb;
816
817         if (msg->iova < v->range.first || !msg->size ||
818             msg->iova > U64_MAX - msg->size + 1 ||
819             msg->iova + msg->size - 1 > v->range.last)
820                 return -EINVAL;
821
822         if (vhost_iotlb_itree_first(iotlb, msg->iova,
823                                     msg->iova + msg->size - 1))
824                 return -EEXIST;
825
826         if (vdpa->use_va)
827                 return vhost_vdpa_va_map(v, msg->iova, msg->size,
828                                          msg->uaddr, msg->perm);
829
830         return vhost_vdpa_pa_map(v, msg->iova, msg->size, msg->uaddr,
831                                  msg->perm);
832 }
833
834 static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev,
835                                         struct vhost_iotlb_msg *msg)
836 {
837         struct vhost_vdpa *v = container_of(dev, struct vhost_vdpa, vdev);
838         struct vdpa_device *vdpa = v->vdpa;
839         const struct vdpa_config_ops *ops = vdpa->config;
840         int r = 0;
841
842         mutex_lock(&dev->mutex);
843
844         r = vhost_dev_check_owner(dev);
845         if (r)
846                 goto unlock;
847
848         switch (msg->type) {
849         case VHOST_IOTLB_UPDATE:
850                 r = vhost_vdpa_process_iotlb_update(v, msg);
851                 break;
852         case VHOST_IOTLB_INVALIDATE:
853                 vhost_vdpa_unmap(v, msg->iova, msg->size);
854                 break;
855         case VHOST_IOTLB_BATCH_BEGIN:
856                 v->in_batch = true;
857                 break;
858         case VHOST_IOTLB_BATCH_END:
859                 if (v->in_batch && ops->set_map)
860                         ops->set_map(vdpa, dev->iotlb);
861                 v->in_batch = false;
862                 break;
863         default:
864                 r = -EINVAL;
865                 break;
866         }
867 unlock:
868         mutex_unlock(&dev->mutex);
869
870         return r;
871 }
872
873 static ssize_t vhost_vdpa_chr_write_iter(struct kiocb *iocb,
874                                          struct iov_iter *from)
875 {
876         struct file *file = iocb->ki_filp;
877         struct vhost_vdpa *v = file->private_data;
878         struct vhost_dev *dev = &v->vdev;
879
880         return vhost_chr_write_iter(dev, from);
881 }
882
883 static int vhost_vdpa_alloc_domain(struct vhost_vdpa *v)
884 {
885         struct vdpa_device *vdpa = v->vdpa;
886         const struct vdpa_config_ops *ops = vdpa->config;
887         struct device *dma_dev = vdpa_get_dma_dev(vdpa);
888         struct bus_type *bus;
889         int ret;
890
891         /* Device want to do DMA by itself */
892         if (ops->set_map || ops->dma_map)
893                 return 0;
894
895         bus = dma_dev->bus;
896         if (!bus)
897                 return -EFAULT;
898
899         if (!iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
900                 return -ENOTSUPP;
901
902         v->domain = iommu_domain_alloc(bus);
903         if (!v->domain)
904                 return -EIO;
905
906         ret = iommu_attach_device(v->domain, dma_dev);
907         if (ret)
908                 goto err_attach;
909
910         return 0;
911
912 err_attach:
913         iommu_domain_free(v->domain);
914         return ret;
915 }
916
917 static void vhost_vdpa_free_domain(struct vhost_vdpa *v)
918 {
919         struct vdpa_device *vdpa = v->vdpa;
920         struct device *dma_dev = vdpa_get_dma_dev(vdpa);
921
922         if (v->domain) {
923                 iommu_detach_device(v->domain, dma_dev);
924                 iommu_domain_free(v->domain);
925         }
926
927         v->domain = NULL;
928 }
929
930 static void vhost_vdpa_set_iova_range(struct vhost_vdpa *v)
931 {
932         struct vdpa_iova_range *range = &v->range;
933         struct vdpa_device *vdpa = v->vdpa;
934         const struct vdpa_config_ops *ops = vdpa->config;
935
936         if (ops->get_iova_range) {
937                 *range = ops->get_iova_range(vdpa);
938         } else if (v->domain && v->domain->geometry.force_aperture) {
939                 range->first = v->domain->geometry.aperture_start;
940                 range->last = v->domain->geometry.aperture_end;
941         } else {
942                 range->first = 0;
943                 range->last = ULLONG_MAX;
944         }
945 }
946
947 static int vhost_vdpa_open(struct inode *inode, struct file *filep)
948 {
949         struct vhost_vdpa *v;
950         struct vhost_dev *dev;
951         struct vhost_virtqueue **vqs;
952         int nvqs, i, r, opened;
953
954         v = container_of(inode->i_cdev, struct vhost_vdpa, cdev);
955
956         opened = atomic_cmpxchg(&v->opened, 0, 1);
957         if (opened)
958                 return -EBUSY;
959
960         nvqs = v->nvqs;
961         r = vhost_vdpa_reset(v);
962         if (r)
963                 goto err;
964
965         vqs = kmalloc_array(nvqs, sizeof(*vqs), GFP_KERNEL);
966         if (!vqs) {
967                 r = -ENOMEM;
968                 goto err;
969         }
970
971         dev = &v->vdev;
972         for (i = 0; i < nvqs; i++) {
973                 vqs[i] = &v->vqs[i];
974                 vqs[i]->handle_kick = handle_vq_kick;
975         }
976         vhost_dev_init(dev, vqs, nvqs, 0, 0, 0, false,
977                        vhost_vdpa_process_iotlb_msg);
978
979         dev->iotlb = vhost_iotlb_alloc(0, 0);
980         if (!dev->iotlb) {
981                 r = -ENOMEM;
982                 goto err_init_iotlb;
983         }
984
985         r = vhost_vdpa_alloc_domain(v);
986         if (r)
987                 goto err_init_iotlb;
988
989         vhost_vdpa_set_iova_range(v);
990
991         filep->private_data = v;
992
993         return 0;
994
995 err_init_iotlb:
996         vhost_dev_cleanup(&v->vdev);
997         kfree(vqs);
998 err:
999         atomic_dec(&v->opened);
1000         return r;
1001 }
1002
1003 static void vhost_vdpa_clean_irq(struct vhost_vdpa *v)
1004 {
1005         int i;
1006
1007         for (i = 0; i < v->nvqs; i++)
1008                 vhost_vdpa_unsetup_vq_irq(v, i);
1009 }
1010
1011 static int vhost_vdpa_release(struct inode *inode, struct file *filep)
1012 {
1013         struct vhost_vdpa *v = filep->private_data;
1014         struct vhost_dev *d = &v->vdev;
1015
1016         mutex_lock(&d->mutex);
1017         filep->private_data = NULL;
1018         vhost_vdpa_reset(v);
1019         vhost_dev_stop(&v->vdev);
1020         vhost_vdpa_iotlb_free(v);
1021         vhost_vdpa_free_domain(v);
1022         vhost_vdpa_config_put(v);
1023         vhost_vdpa_clean_irq(v);
1024         vhost_dev_cleanup(&v->vdev);
1025         kfree(v->vdev.vqs);
1026         mutex_unlock(&d->mutex);
1027
1028         atomic_dec(&v->opened);
1029         complete(&v->completion);
1030
1031         return 0;
1032 }
1033
1034 #ifdef CONFIG_MMU
1035 static vm_fault_t vhost_vdpa_fault(struct vm_fault *vmf)
1036 {
1037         struct vhost_vdpa *v = vmf->vma->vm_file->private_data;
1038         struct vdpa_device *vdpa = v->vdpa;
1039         const struct vdpa_config_ops *ops = vdpa->config;
1040         struct vdpa_notification_area notify;
1041         struct vm_area_struct *vma = vmf->vma;
1042         u16 index = vma->vm_pgoff;
1043
1044         notify = ops->get_vq_notification(vdpa, index);
1045
1046         vma->vm_page_prot = pgprot_noncached(vma->vm_page_prot);
1047         if (remap_pfn_range(vma, vmf->address & PAGE_MASK,
1048                             PFN_DOWN(notify.addr), PAGE_SIZE,
1049                             vma->vm_page_prot))
1050                 return VM_FAULT_SIGBUS;
1051
1052         return VM_FAULT_NOPAGE;
1053 }
1054
1055 static const struct vm_operations_struct vhost_vdpa_vm_ops = {
1056         .fault = vhost_vdpa_fault,
1057 };
1058
1059 static int vhost_vdpa_mmap(struct file *file, struct vm_area_struct *vma)
1060 {
1061         struct vhost_vdpa *v = vma->vm_file->private_data;
1062         struct vdpa_device *vdpa = v->vdpa;
1063         const struct vdpa_config_ops *ops = vdpa->config;
1064         struct vdpa_notification_area notify;
1065         unsigned long index = vma->vm_pgoff;
1066
1067         if (vma->vm_end - vma->vm_start != PAGE_SIZE)
1068                 return -EINVAL;
1069         if ((vma->vm_flags & VM_SHARED) == 0)
1070                 return -EINVAL;
1071         if (vma->vm_flags & VM_READ)
1072                 return -EINVAL;
1073         if (index > 65535)
1074                 return -EINVAL;
1075         if (!ops->get_vq_notification)
1076                 return -ENOTSUPP;
1077
1078         /* To be safe and easily modelled by userspace, We only
1079          * support the doorbell which sits on the page boundary and
1080          * does not share the page with other registers.
1081          */
1082         notify = ops->get_vq_notification(vdpa, index);
1083         if (notify.addr & (PAGE_SIZE - 1))
1084                 return -EINVAL;
1085         if (vma->vm_end - vma->vm_start != notify.size)
1086                 return -ENOTSUPP;
1087
1088         vma->vm_flags |= VM_IO | VM_PFNMAP | VM_DONTEXPAND | VM_DONTDUMP;
1089         vma->vm_ops = &vhost_vdpa_vm_ops;
1090         return 0;
1091 }
1092 #endif /* CONFIG_MMU */
1093
1094 static const struct file_operations vhost_vdpa_fops = {
1095         .owner          = THIS_MODULE,
1096         .open           = vhost_vdpa_open,
1097         .release        = vhost_vdpa_release,
1098         .write_iter     = vhost_vdpa_chr_write_iter,
1099         .unlocked_ioctl = vhost_vdpa_unlocked_ioctl,
1100 #ifdef CONFIG_MMU
1101         .mmap           = vhost_vdpa_mmap,
1102 #endif /* CONFIG_MMU */
1103         .compat_ioctl   = compat_ptr_ioctl,
1104 };
1105
1106 static void vhost_vdpa_release_dev(struct device *device)
1107 {
1108         struct vhost_vdpa *v =
1109                container_of(device, struct vhost_vdpa, dev);
1110
1111         ida_simple_remove(&vhost_vdpa_ida, v->minor);
1112         kfree(v->vqs);
1113         kfree(v);
1114 }
1115
1116 static int vhost_vdpa_probe(struct vdpa_device *vdpa)
1117 {
1118         const struct vdpa_config_ops *ops = vdpa->config;
1119         struct vhost_vdpa *v;
1120         int minor;
1121         int r;
1122
1123         v = kzalloc(sizeof(*v), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
1124         if (!v)
1125                 return -ENOMEM;
1126
1127         minor = ida_simple_get(&vhost_vdpa_ida, 0,
1128                                VHOST_VDPA_DEV_MAX, GFP_KERNEL);
1129         if (minor < 0) {
1130                 kfree(v);
1131                 return minor;
1132         }
1133
1134         atomic_set(&v->opened, 0);
1135         v->minor = minor;
1136         v->vdpa = vdpa;
1137         v->nvqs = vdpa->nvqs;
1138         v->virtio_id = ops->get_device_id(vdpa);
1139
1140         device_initialize(&v->dev);
1141         v->dev.release = vhost_vdpa_release_dev;
1142         v->dev.parent = &vdpa->dev;
1143         v->dev.devt = MKDEV(MAJOR(vhost_vdpa_major), minor);
1144         v->vqs = kmalloc_array(v->nvqs, sizeof(struct vhost_virtqueue),
1145                                GFP_KERNEL);
1146         if (!v->vqs) {
1147                 r = -ENOMEM;
1148                 goto err;
1149         }
1150
1151         r = dev_set_name(&v->dev, "vhost-vdpa-%u", minor);
1152         if (r)
1153                 goto err;
1154
1155         cdev_init(&v->cdev, &vhost_vdpa_fops);
1156         v->cdev.owner = THIS_MODULE;
1157
1158         r = cdev_device_add(&v->cdev, &v->dev);
1159         if (r)
1160                 goto err;
1161
1162         init_completion(&v->completion);
1163         vdpa_set_drvdata(vdpa, v);
1164
1165         return 0;
1166
1167 err:
1168         put_device(&v->dev);
1169         return r;
1170 }
1171
1172 static void vhost_vdpa_remove(struct vdpa_device *vdpa)
1173 {
1174         struct vhost_vdpa *v = vdpa_get_drvdata(vdpa);
1175         int opened;
1176
1177         cdev_device_del(&v->cdev, &v->dev);
1178
1179         do {
1180                 opened = atomic_cmpxchg(&v->opened, 0, 1);
1181                 if (!opened)
1182                         break;
1183                 wait_for_completion(&v->completion);
1184         } while (1);
1185
1186         put_device(&v->dev);
1187 }
1188
1189 static struct vdpa_driver vhost_vdpa_driver = {
1190         .driver = {
1191                 .name   = "vhost_vdpa",
1192         },
1193         .probe  = vhost_vdpa_probe,
1194         .remove = vhost_vdpa_remove,
1195 };
1196
1197 static int __init vhost_vdpa_init(void)
1198 {
1199         int r;
1200
1201         r = alloc_chrdev_region(&vhost_vdpa_major, 0, VHOST_VDPA_DEV_MAX,
1202                                 "vhost-vdpa");
1203         if (r)
1204                 goto err_alloc_chrdev;
1205
1206         r = vdpa_register_driver(&vhost_vdpa_driver);
1207         if (r)
1208                 goto err_vdpa_register_driver;
1209
1210         return 0;
1211
1212 err_vdpa_register_driver:
1213         unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1214 err_alloc_chrdev:
1215         return r;
1216 }
1217 module_init(vhost_vdpa_init);
1218
1219 static void __exit vhost_vdpa_exit(void)
1220 {
1221         vdpa_unregister_driver(&vhost_vdpa_driver);
1222         unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1223 }
1224 module_exit(vhost_vdpa_exit);
1225
1226 MODULE_VERSION("0.0.1");
1227 MODULE_LICENSE("GPL v2");
1228 MODULE_AUTHOR("Intel Corporation");
1229 MODULE_DESCRIPTION("vDPA-based vhost backend for virtio");