vt_ioctl: fix GIO_UNIMAP regression
[platform/kernel/linux-rpi.git] / drivers / vhost / vdpa.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2018-2020 Intel Corporation.
4  * Copyright (C) 2020 Red Hat, Inc.
5  *
6  * Author: Tiwei Bie <tiwei.bie@intel.com>
7  *         Jason Wang <jasowang@redhat.com>
8  *
9  * Thanks Michael S. Tsirkin for the valuable comments and
10  * suggestions.  And thanks to Cunming Liang and Zhihong Wang for all
11  * their support.
12  */
13
14 #include <linux/kernel.h>
15 #include <linux/module.h>
16 #include <linux/cdev.h>
17 #include <linux/device.h>
18 #include <linux/mm.h>
19 #include <linux/iommu.h>
20 #include <linux/uuid.h>
21 #include <linux/vdpa.h>
22 #include <linux/nospec.h>
23 #include <linux/vhost.h>
24 #include <linux/virtio_net.h>
25
26 #include "vhost.h"
27
/*
 * Backend feature bits supported by vhost-vdpa. This mask is reported by
 * VHOST_GET_BACKEND_FEATURES and is the set of bits userspace is allowed to
 * pass to VHOST_SET_BACKEND_FEATURES.
 */
enum {
        VHOST_VDPA_BACKEND_FEATURES =
        (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2) |
        (1ULL << VHOST_BACKEND_F_IOTLB_BATCH),
};
33
/* Exclusive upper bound for the minor numbers handed out by vhost_vdpa_ida. */
#define VHOST_VDPA_DEV_MAX (1U << MINORBITS)
35
/* Per-device state tying one vDPA device to one vhost character device. */
struct vhost_vdpa {
        /* Embedded vhost device; container_of() on it recovers this struct. */
        struct vhost_dev vdev;
        /* IOMMU domain, only set up when the device has no set_map/dma_map op. */
        struct iommu_domain *domain;
        /* Array of nvqs virtqueues, allocated at probe time. */
        struct vhost_virtqueue *vqs;
        /* Completed at the end of vhost_vdpa_release(). */
        struct completion completion;
        /* The vDPA device this instance drives. */
        struct vdpa_device *vdpa;
        struct device dev;
        struct cdev cdev;
        /* 0 or 1; flipped with cmpxchg in open() so only one opener succeeds. */
        atomic_t opened;
        /* Number of virtqueues, copied from vdpa->nvqs. */
        int nvqs;
        /* Virtio device ID as reported by ops->get_device_id(). */
        int virtio_id;
        /* Minor number allocated from vhost_vdpa_ida. */
        int minor;
        /* Eventfd signalled by vhost_vdpa_config_cb(); NULL when unbound. */
        struct eventfd_ctx *config_ctx;
        /*
         * Non-zero between VHOST_IOTLB_BATCH_BEGIN and VHOST_IOTLB_BATCH_END;
         * while set, ops->set_map() calls are deferred until the batch ends.
         */
        int in_batch;
};
51
/* Allocator for the character device minor numbers. */
static DEFINE_IDA(vhost_vdpa_ida);

/* Device number whose major is combined with the IDA minor at probe time. */
static dev_t vhost_vdpa_major;
55
56 static void handle_vq_kick(struct vhost_work *work)
57 {
58         struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
59                                                   poll.work);
60         struct vhost_vdpa *v = container_of(vq->dev, struct vhost_vdpa, vdev);
61         const struct vdpa_config_ops *ops = v->vdpa->config;
62
63         ops->kick_vq(v->vdpa, vq - v->vqs);
64 }
65
66 static irqreturn_t vhost_vdpa_virtqueue_cb(void *private)
67 {
68         struct vhost_virtqueue *vq = private;
69         struct eventfd_ctx *call_ctx = vq->call_ctx.ctx;
70
71         if (call_ctx)
72                 eventfd_signal(call_ctx, 1);
73
74         return IRQ_HANDLED;
75 }
76
77 static irqreturn_t vhost_vdpa_config_cb(void *private)
78 {
79         struct vhost_vdpa *v = private;
80         struct eventfd_ctx *config_ctx = v->config_ctx;
81
82         if (config_ctx)
83                 eventfd_signal(config_ctx, 1);
84
85         return IRQ_HANDLED;
86 }
87
88 static void vhost_vdpa_setup_vq_irq(struct vhost_vdpa *v, u16 qid)
89 {
90         struct vhost_virtqueue *vq = &v->vqs[qid];
91         const struct vdpa_config_ops *ops = v->vdpa->config;
92         struct vdpa_device *vdpa = v->vdpa;
93         int ret, irq;
94
95         if (!ops->get_vq_irq)
96                 return;
97
98         irq = ops->get_vq_irq(vdpa, qid);
99         irq_bypass_unregister_producer(&vq->call_ctx.producer);
100         if (!vq->call_ctx.ctx || irq < 0)
101                 return;
102
103         vq->call_ctx.producer.token = vq->call_ctx.ctx;
104         vq->call_ctx.producer.irq = irq;
105         ret = irq_bypass_register_producer(&vq->call_ctx.producer);
106 }
107
108 static void vhost_vdpa_unsetup_vq_irq(struct vhost_vdpa *v, u16 qid)
109 {
110         struct vhost_virtqueue *vq = &v->vqs[qid];
111
112         irq_bypass_unregister_producer(&vq->call_ctx.producer);
113 }
114
115 static void vhost_vdpa_reset(struct vhost_vdpa *v)
116 {
117         struct vdpa_device *vdpa = v->vdpa;
118
119         vdpa_reset(vdpa);
120         v->in_batch = 0;
121 }
122
123 static long vhost_vdpa_get_device_id(struct vhost_vdpa *v, u8 __user *argp)
124 {
125         struct vdpa_device *vdpa = v->vdpa;
126         const struct vdpa_config_ops *ops = vdpa->config;
127         u32 device_id;
128
129         device_id = ops->get_device_id(vdpa);
130
131         if (copy_to_user(argp, &device_id, sizeof(device_id)))
132                 return -EFAULT;
133
134         return 0;
135 }
136
137 static long vhost_vdpa_get_status(struct vhost_vdpa *v, u8 __user *statusp)
138 {
139         struct vdpa_device *vdpa = v->vdpa;
140         const struct vdpa_config_ops *ops = vdpa->config;
141         u8 status;
142
143         status = ops->get_status(vdpa);
144
145         if (copy_to_user(statusp, &status, sizeof(status)))
146                 return -EFAULT;
147
148         return 0;
149 }
150
151 static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp)
152 {
153         struct vdpa_device *vdpa = v->vdpa;
154         const struct vdpa_config_ops *ops = vdpa->config;
155         u8 status, status_old;
156         int nvqs = v->nvqs;
157         u16 i;
158
159         if (copy_from_user(&status, statusp, sizeof(status)))
160                 return -EFAULT;
161
162         status_old = ops->get_status(vdpa);
163
164         /*
165          * Userspace shouldn't remove status bits unless reset the
166          * status to 0.
167          */
168         if (status != 0 && (ops->get_status(vdpa) & ~status) != 0)
169                 return -EINVAL;
170
171         ops->set_status(vdpa, status);
172
173         if ((status & VIRTIO_CONFIG_S_DRIVER_OK) && !(status_old & VIRTIO_CONFIG_S_DRIVER_OK))
174                 for (i = 0; i < nvqs; i++)
175                         vhost_vdpa_setup_vq_irq(v, i);
176
177         if ((status_old & VIRTIO_CONFIG_S_DRIVER_OK) && !(status & VIRTIO_CONFIG_S_DRIVER_OK))
178                 for (i = 0; i < nvqs; i++)
179                         vhost_vdpa_unsetup_vq_irq(v, i);
180
181         return 0;
182 }
183
184 static int vhost_vdpa_config_validate(struct vhost_vdpa *v,
185                                       struct vhost_vdpa_config *c)
186 {
187         long size = 0;
188
189         switch (v->virtio_id) {
190         case VIRTIO_ID_NET:
191                 size = sizeof(struct virtio_net_config);
192                 break;
193         }
194
195         if (c->len == 0)
196                 return -EINVAL;
197
198         if (c->len > size - c->off)
199                 return -E2BIG;
200
201         return 0;
202 }
203
204 static long vhost_vdpa_get_config(struct vhost_vdpa *v,
205                                   struct vhost_vdpa_config __user *c)
206 {
207         struct vdpa_device *vdpa = v->vdpa;
208         struct vhost_vdpa_config config;
209         unsigned long size = offsetof(struct vhost_vdpa_config, buf);
210         u8 *buf;
211
212         if (copy_from_user(&config, c, size))
213                 return -EFAULT;
214         if (vhost_vdpa_config_validate(v, &config))
215                 return -EINVAL;
216         buf = kvzalloc(config.len, GFP_KERNEL);
217         if (!buf)
218                 return -ENOMEM;
219
220         vdpa_get_config(vdpa, config.off, buf, config.len);
221
222         if (copy_to_user(c->buf, buf, config.len)) {
223                 kvfree(buf);
224                 return -EFAULT;
225         }
226
227         kvfree(buf);
228         return 0;
229 }
230
231 static long vhost_vdpa_set_config(struct vhost_vdpa *v,
232                                   struct vhost_vdpa_config __user *c)
233 {
234         struct vdpa_device *vdpa = v->vdpa;
235         const struct vdpa_config_ops *ops = vdpa->config;
236         struct vhost_vdpa_config config;
237         unsigned long size = offsetof(struct vhost_vdpa_config, buf);
238         u8 *buf;
239
240         if (copy_from_user(&config, c, size))
241                 return -EFAULT;
242         if (vhost_vdpa_config_validate(v, &config))
243                 return -EINVAL;
244         buf = kvzalloc(config.len, GFP_KERNEL);
245         if (!buf)
246                 return -ENOMEM;
247
248         if (copy_from_user(buf, c->buf, config.len)) {
249                 kvfree(buf);
250                 return -EFAULT;
251         }
252
253         ops->set_config(vdpa, config.off, buf, config.len);
254
255         kvfree(buf);
256         return 0;
257 }
258
259 static long vhost_vdpa_get_features(struct vhost_vdpa *v, u64 __user *featurep)
260 {
261         struct vdpa_device *vdpa = v->vdpa;
262         const struct vdpa_config_ops *ops = vdpa->config;
263         u64 features;
264
265         features = ops->get_features(vdpa);
266
267         if (copy_to_user(featurep, &features, sizeof(features)))
268                 return -EFAULT;
269
270         return 0;
271 }
272
273 static long vhost_vdpa_set_features(struct vhost_vdpa *v, u64 __user *featurep)
274 {
275         struct vdpa_device *vdpa = v->vdpa;
276         const struct vdpa_config_ops *ops = vdpa->config;
277         u64 features;
278
279         /*
280          * It's not allowed to change the features after they have
281          * been negotiated.
282          */
283         if (ops->get_status(vdpa) & VIRTIO_CONFIG_S_FEATURES_OK)
284                 return -EBUSY;
285
286         if (copy_from_user(&features, featurep, sizeof(features)))
287                 return -EFAULT;
288
289         if (vdpa_set_features(vdpa, features))
290                 return -EINVAL;
291
292         return 0;
293 }
294
295 static long vhost_vdpa_get_vring_num(struct vhost_vdpa *v, u16 __user *argp)
296 {
297         struct vdpa_device *vdpa = v->vdpa;
298         const struct vdpa_config_ops *ops = vdpa->config;
299         u16 num;
300
301         num = ops->get_vq_num_max(vdpa);
302
303         if (copy_to_user(argp, &num, sizeof(num)))
304                 return -EFAULT;
305
306         return 0;
307 }
308
309 static void vhost_vdpa_config_put(struct vhost_vdpa *v)
310 {
311         if (v->config_ctx)
312                 eventfd_ctx_put(v->config_ctx);
313 }
314
315 static long vhost_vdpa_set_config_call(struct vhost_vdpa *v, u32 __user *argp)
316 {
317         struct vdpa_callback cb;
318         int fd;
319         struct eventfd_ctx *ctx;
320
321         cb.callback = vhost_vdpa_config_cb;
322         cb.private = v->vdpa;
323         if (copy_from_user(&fd, argp, sizeof(fd)))
324                 return  -EFAULT;
325
326         ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
327         swap(ctx, v->config_ctx);
328
329         if (!IS_ERR_OR_NULL(ctx))
330                 eventfd_ctx_put(ctx);
331
332         if (IS_ERR(v->config_ctx))
333                 return PTR_ERR(v->config_ctx);
334
335         v->vdpa->config->set_config_cb(v->vdpa, &cb);
336
337         return 0;
338 }
339
340 static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd,
341                                    void __user *argp)
342 {
343         struct vdpa_device *vdpa = v->vdpa;
344         const struct vdpa_config_ops *ops = vdpa->config;
345         struct vdpa_vq_state vq_state;
346         struct vdpa_callback cb;
347         struct vhost_virtqueue *vq;
348         struct vhost_vring_state s;
349         u32 idx;
350         long r;
351
352         r = get_user(idx, (u32 __user *)argp);
353         if (r < 0)
354                 return r;
355
356         if (idx >= v->nvqs)
357                 return -ENOBUFS;
358
359         idx = array_index_nospec(idx, v->nvqs);
360         vq = &v->vqs[idx];
361
362         switch (cmd) {
363         case VHOST_VDPA_SET_VRING_ENABLE:
364                 if (copy_from_user(&s, argp, sizeof(s)))
365                         return -EFAULT;
366                 ops->set_vq_ready(vdpa, idx, s.num);
367                 return 0;
368         case VHOST_GET_VRING_BASE:
369                 r = ops->get_vq_state(v->vdpa, idx, &vq_state);
370                 if (r)
371                         return r;
372
373                 vq->last_avail_idx = vq_state.avail_index;
374                 break;
375         }
376
377         r = vhost_vring_ioctl(&v->vdev, cmd, argp);
378         if (r)
379                 return r;
380
381         switch (cmd) {
382         case VHOST_SET_VRING_ADDR:
383                 if (ops->set_vq_address(vdpa, idx,
384                                         (u64)(uintptr_t)vq->desc,
385                                         (u64)(uintptr_t)vq->avail,
386                                         (u64)(uintptr_t)vq->used))
387                         r = -EINVAL;
388                 break;
389
390         case VHOST_SET_VRING_BASE:
391                 vq_state.avail_index = vq->last_avail_idx;
392                 if (ops->set_vq_state(vdpa, idx, &vq_state))
393                         r = -EINVAL;
394                 break;
395
396         case VHOST_SET_VRING_CALL:
397                 if (vq->call_ctx.ctx) {
398                         cb.callback = vhost_vdpa_virtqueue_cb;
399                         cb.private = vq;
400                 } else {
401                         cb.callback = NULL;
402                         cb.private = NULL;
403                 }
404                 ops->set_vq_cb(vdpa, idx, &cb);
405                 vhost_vdpa_setup_vq_irq(v, idx);
406                 break;
407
408         case VHOST_SET_VRING_NUM:
409                 ops->set_vq_num(vdpa, idx, vq->num);
410                 break;
411         }
412
413         return r;
414 }
415
416 static long vhost_vdpa_unlocked_ioctl(struct file *filep,
417                                       unsigned int cmd, unsigned long arg)
418 {
419         struct vhost_vdpa *v = filep->private_data;
420         struct vhost_dev *d = &v->vdev;
421         void __user *argp = (void __user *)arg;
422         u64 __user *featurep = argp;
423         u64 features;
424         long r;
425
426         if (cmd == VHOST_SET_BACKEND_FEATURES) {
427                 r = copy_from_user(&features, featurep, sizeof(features));
428                 if (r)
429                         return r;
430                 if (features & ~VHOST_VDPA_BACKEND_FEATURES)
431                         return -EOPNOTSUPP;
432                 vhost_set_backend_features(&v->vdev, features);
433                 return 0;
434         }
435
436         mutex_lock(&d->mutex);
437
438         switch (cmd) {
439         case VHOST_VDPA_GET_DEVICE_ID:
440                 r = vhost_vdpa_get_device_id(v, argp);
441                 break;
442         case VHOST_VDPA_GET_STATUS:
443                 r = vhost_vdpa_get_status(v, argp);
444                 break;
445         case VHOST_VDPA_SET_STATUS:
446                 r = vhost_vdpa_set_status(v, argp);
447                 break;
448         case VHOST_VDPA_GET_CONFIG:
449                 r = vhost_vdpa_get_config(v, argp);
450                 break;
451         case VHOST_VDPA_SET_CONFIG:
452                 r = vhost_vdpa_set_config(v, argp);
453                 break;
454         case VHOST_GET_FEATURES:
455                 r = vhost_vdpa_get_features(v, argp);
456                 break;
457         case VHOST_SET_FEATURES:
458                 r = vhost_vdpa_set_features(v, argp);
459                 break;
460         case VHOST_VDPA_GET_VRING_NUM:
461                 r = vhost_vdpa_get_vring_num(v, argp);
462                 break;
463         case VHOST_SET_LOG_BASE:
464         case VHOST_SET_LOG_FD:
465                 r = -ENOIOCTLCMD;
466                 break;
467         case VHOST_VDPA_SET_CONFIG_CALL:
468                 r = vhost_vdpa_set_config_call(v, argp);
469                 break;
470         case VHOST_GET_BACKEND_FEATURES:
471                 features = VHOST_VDPA_BACKEND_FEATURES;
472                 r = copy_to_user(featurep, &features, sizeof(features));
473                 break;
474         default:
475                 r = vhost_dev_ioctl(&v->vdev, cmd, argp);
476                 if (r == -ENOIOCTLCMD)
477                         r = vhost_vdpa_vring_ioctl(v, cmd, argp);
478                 break;
479         }
480
481         mutex_unlock(&d->mutex);
482         return r;
483 }
484
485 static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v, u64 start, u64 last)
486 {
487         struct vhost_dev *dev = &v->vdev;
488         struct vhost_iotlb *iotlb = dev->iotlb;
489         struct vhost_iotlb_map *map;
490         struct page *page;
491         unsigned long pfn, pinned;
492
493         while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
494                 pinned = map->size >> PAGE_SHIFT;
495                 for (pfn = map->addr >> PAGE_SHIFT;
496                      pinned > 0; pfn++, pinned--) {
497                         page = pfn_to_page(pfn);
498                         if (map->perm & VHOST_ACCESS_WO)
499                                 set_page_dirty_lock(page);
500                         unpin_user_page(page);
501                 }
502                 atomic64_sub(map->size >> PAGE_SHIFT, &dev->mm->pinned_vm);
503                 vhost_iotlb_map_free(iotlb, map);
504         }
505 }
506
507 static void vhost_vdpa_iotlb_free(struct vhost_vdpa *v)
508 {
509         struct vhost_dev *dev = &v->vdev;
510
511         vhost_vdpa_iotlb_unmap(v, 0ULL, 0ULL - 1);
512         kfree(dev->iotlb);
513         dev->iotlb = NULL;
514 }
515
516 static int perm_to_iommu_flags(u32 perm)
517 {
518         int flags = 0;
519
520         switch (perm) {
521         case VHOST_ACCESS_WO:
522                 flags |= IOMMU_WRITE;
523                 break;
524         case VHOST_ACCESS_RO:
525                 flags |= IOMMU_READ;
526                 break;
527         case VHOST_ACCESS_RW:
528                 flags |= (IOMMU_WRITE | IOMMU_READ);
529                 break;
530         default:
531                 WARN(1, "invalidate vhost IOTLB permission\n");
532                 break;
533         }
534
535         return flags | IOMMU_CACHE;
536 }
537
538 static int vhost_vdpa_map(struct vhost_vdpa *v,
539                           u64 iova, u64 size, u64 pa, u32 perm)
540 {
541         struct vhost_dev *dev = &v->vdev;
542         struct vdpa_device *vdpa = v->vdpa;
543         const struct vdpa_config_ops *ops = vdpa->config;
544         int r = 0;
545
546         r = vhost_iotlb_add_range(dev->iotlb, iova, iova + size - 1,
547                                   pa, perm);
548         if (r)
549                 return r;
550
551         if (ops->dma_map) {
552                 r = ops->dma_map(vdpa, iova, size, pa, perm);
553         } else if (ops->set_map) {
554                 if (!v->in_batch)
555                         r = ops->set_map(vdpa, dev->iotlb);
556         } else {
557                 r = iommu_map(v->domain, iova, pa, size,
558                               perm_to_iommu_flags(perm));
559         }
560
561         if (r)
562                 vhost_iotlb_del_range(dev->iotlb, iova, iova + size - 1);
563
564         return r;
565 }
566
567 static void vhost_vdpa_unmap(struct vhost_vdpa *v, u64 iova, u64 size)
568 {
569         struct vhost_dev *dev = &v->vdev;
570         struct vdpa_device *vdpa = v->vdpa;
571         const struct vdpa_config_ops *ops = vdpa->config;
572
573         vhost_vdpa_iotlb_unmap(v, iova, iova + size - 1);
574
575         if (ops->dma_map) {
576                 ops->dma_unmap(vdpa, iova, size);
577         } else if (ops->set_map) {
578                 if (!v->in_batch)
579                         ops->set_map(vdpa, dev->iotlb);
580         } else {
581                 iommu_unmap(v->domain, iova, size);
582         }
583 }
584
585 static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
586                                            struct vhost_iotlb_msg *msg)
587 {
588         struct vhost_dev *dev = &v->vdev;
589         struct vhost_iotlb *iotlb = dev->iotlb;
590         struct page **page_list;
591         struct vm_area_struct **vmas;
592         unsigned int gup_flags = FOLL_LONGTERM;
593         unsigned long map_pfn, last_pfn = 0;
594         unsigned long npages, lock_limit;
595         unsigned long i, nmap = 0;
596         u64 iova = msg->iova;
597         long pinned;
598         int ret = 0;
599
600         if (vhost_iotlb_itree_first(iotlb, msg->iova,
601                                     msg->iova + msg->size - 1))
602                 return -EEXIST;
603
604         if (msg->perm & VHOST_ACCESS_WO)
605                 gup_flags |= FOLL_WRITE;
606
607         npages = PAGE_ALIGN(msg->size + (iova & ~PAGE_MASK)) >> PAGE_SHIFT;
608         if (!npages)
609                 return -EINVAL;
610
611         page_list = kvmalloc_array(npages, sizeof(struct page *), GFP_KERNEL);
612         vmas = kvmalloc_array(npages, sizeof(struct vm_area_struct *),
613                               GFP_KERNEL);
614         if (!page_list || !vmas) {
615                 ret = -ENOMEM;
616                 goto free;
617         }
618
619         mmap_read_lock(dev->mm);
620
621         lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
622         if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
623                 ret = -ENOMEM;
624                 goto unlock;
625         }
626
627         pinned = pin_user_pages(msg->uaddr & PAGE_MASK, npages, gup_flags,
628                                 page_list, vmas);
629         if (npages != pinned) {
630                 if (pinned < 0) {
631                         ret = pinned;
632                 } else {
633                         unpin_user_pages(page_list, pinned);
634                         ret = -ENOMEM;
635                 }
636                 goto unlock;
637         }
638
639         iova &= PAGE_MASK;
640         map_pfn = page_to_pfn(page_list[0]);
641
642         /* One more iteration to avoid extra vdpa_map() call out of loop. */
643         for (i = 0; i <= npages; i++) {
644                 unsigned long this_pfn;
645                 u64 csize;
646
647                 /* The last chunk may have no valid PFN next to it */
648                 this_pfn = i < npages ? page_to_pfn(page_list[i]) : -1UL;
649
650                 if (last_pfn && (this_pfn == -1UL ||
651                                  this_pfn != last_pfn + 1)) {
652                         /* Pin a contiguous chunk of memory */
653                         csize = last_pfn - map_pfn + 1;
654                         ret = vhost_vdpa_map(v, iova, csize << PAGE_SHIFT,
655                                              map_pfn << PAGE_SHIFT,
656                                              msg->perm);
657                         if (ret) {
658                                 /*
659                                  * Unpin the rest chunks of memory on the
660                                  * flight with no corresponding vdpa_map()
661                                  * calls having been made yet. On the other
662                                  * hand, vdpa_unmap() in the failure path
663                                  * is in charge of accounting the number of
664                                  * pinned pages for its own.
665                                  * This asymmetrical pattern of accounting
666                                  * is for efficiency to pin all pages at
667                                  * once, while there is no other callsite
668                                  * of vdpa_map() than here above.
669                                  */
670                                 unpin_user_pages(&page_list[nmap],
671                                                  npages - nmap);
672                                 goto out;
673                         }
674                         atomic64_add(csize, &dev->mm->pinned_vm);
675                         nmap += csize;
676                         iova += csize << PAGE_SHIFT;
677                         map_pfn = this_pfn;
678                 }
679                 last_pfn = this_pfn;
680         }
681
682         WARN_ON(nmap != npages);
683 out:
684         if (ret)
685                 vhost_vdpa_unmap(v, msg->iova, msg->size);
686 unlock:
687         mmap_read_unlock(dev->mm);
688 free:
689         kvfree(vmas);
690         kvfree(page_list);
691         return ret;
692 }
693
694 static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev,
695                                         struct vhost_iotlb_msg *msg)
696 {
697         struct vhost_vdpa *v = container_of(dev, struct vhost_vdpa, vdev);
698         struct vdpa_device *vdpa = v->vdpa;
699         const struct vdpa_config_ops *ops = vdpa->config;
700         int r = 0;
701
702         r = vhost_dev_check_owner(dev);
703         if (r)
704                 return r;
705
706         switch (msg->type) {
707         case VHOST_IOTLB_UPDATE:
708                 r = vhost_vdpa_process_iotlb_update(v, msg);
709                 break;
710         case VHOST_IOTLB_INVALIDATE:
711                 vhost_vdpa_unmap(v, msg->iova, msg->size);
712                 break;
713         case VHOST_IOTLB_BATCH_BEGIN:
714                 v->in_batch = true;
715                 break;
716         case VHOST_IOTLB_BATCH_END:
717                 if (v->in_batch && ops->set_map)
718                         ops->set_map(vdpa, dev->iotlb);
719                 v->in_batch = false;
720                 break;
721         default:
722                 r = -EINVAL;
723                 break;
724         }
725
726         return r;
727 }
728
729 static ssize_t vhost_vdpa_chr_write_iter(struct kiocb *iocb,
730                                          struct iov_iter *from)
731 {
732         struct file *file = iocb->ki_filp;
733         struct vhost_vdpa *v = file->private_data;
734         struct vhost_dev *dev = &v->vdev;
735
736         return vhost_chr_write_iter(dev, from);
737 }
738
739 static int vhost_vdpa_alloc_domain(struct vhost_vdpa *v)
740 {
741         struct vdpa_device *vdpa = v->vdpa;
742         const struct vdpa_config_ops *ops = vdpa->config;
743         struct device *dma_dev = vdpa_get_dma_dev(vdpa);
744         struct bus_type *bus;
745         int ret;
746
747         /* Device want to do DMA by itself */
748         if (ops->set_map || ops->dma_map)
749                 return 0;
750
751         bus = dma_dev->bus;
752         if (!bus)
753                 return -EFAULT;
754
755         if (!iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
756                 return -ENOTSUPP;
757
758         v->domain = iommu_domain_alloc(bus);
759         if (!v->domain)
760                 return -EIO;
761
762         ret = iommu_attach_device(v->domain, dma_dev);
763         if (ret)
764                 goto err_attach;
765
766         return 0;
767
768 err_attach:
769         iommu_domain_free(v->domain);
770         return ret;
771 }
772
773 static void vhost_vdpa_free_domain(struct vhost_vdpa *v)
774 {
775         struct vdpa_device *vdpa = v->vdpa;
776         struct device *dma_dev = vdpa_get_dma_dev(vdpa);
777
778         if (v->domain) {
779                 iommu_detach_device(v->domain, dma_dev);
780                 iommu_domain_free(v->domain);
781         }
782
783         v->domain = NULL;
784 }
785
786 static int vhost_vdpa_open(struct inode *inode, struct file *filep)
787 {
788         struct vhost_vdpa *v;
789         struct vhost_dev *dev;
790         struct vhost_virtqueue **vqs;
791         int nvqs, i, r, opened;
792
793         v = container_of(inode->i_cdev, struct vhost_vdpa, cdev);
794
795         opened = atomic_cmpxchg(&v->opened, 0, 1);
796         if (opened)
797                 return -EBUSY;
798
799         nvqs = v->nvqs;
800         vhost_vdpa_reset(v);
801
802         vqs = kmalloc_array(nvqs, sizeof(*vqs), GFP_KERNEL);
803         if (!vqs) {
804                 r = -ENOMEM;
805                 goto err;
806         }
807
808         dev = &v->vdev;
809         for (i = 0; i < nvqs; i++) {
810                 vqs[i] = &v->vqs[i];
811                 vqs[i]->handle_kick = handle_vq_kick;
812         }
813         vhost_dev_init(dev, vqs, nvqs, 0, 0, 0, false,
814                        vhost_vdpa_process_iotlb_msg);
815
816         dev->iotlb = vhost_iotlb_alloc(0, 0);
817         if (!dev->iotlb) {
818                 r = -ENOMEM;
819                 goto err_init_iotlb;
820         }
821
822         r = vhost_vdpa_alloc_domain(v);
823         if (r)
824                 goto err_init_iotlb;
825
826         filep->private_data = v;
827
828         return 0;
829
830 err_init_iotlb:
831         vhost_dev_cleanup(&v->vdev);
832         kfree(vqs);
833 err:
834         atomic_dec(&v->opened);
835         return r;
836 }
837
838 static void vhost_vdpa_clean_irq(struct vhost_vdpa *v)
839 {
840         struct vhost_virtqueue *vq;
841         int i;
842
843         for (i = 0; i < v->nvqs; i++) {
844                 vq = &v->vqs[i];
845                 if (vq->call_ctx.producer.irq)
846                         irq_bypass_unregister_producer(&vq->call_ctx.producer);
847         }
848 }
849
850 static int vhost_vdpa_release(struct inode *inode, struct file *filep)
851 {
852         struct vhost_vdpa *v = filep->private_data;
853         struct vhost_dev *d = &v->vdev;
854
855         mutex_lock(&d->mutex);
856         filep->private_data = NULL;
857         vhost_vdpa_reset(v);
858         vhost_dev_stop(&v->vdev);
859         vhost_vdpa_iotlb_free(v);
860         vhost_vdpa_free_domain(v);
861         vhost_vdpa_config_put(v);
862         vhost_vdpa_clean_irq(v);
863         vhost_dev_cleanup(&v->vdev);
864         kfree(v->vdev.vqs);
865         mutex_unlock(&d->mutex);
866
867         atomic_dec(&v->opened);
868         complete(&v->completion);
869
870         return 0;
871 }
872
873 #ifdef CONFIG_MMU
874 static vm_fault_t vhost_vdpa_fault(struct vm_fault *vmf)
875 {
876         struct vhost_vdpa *v = vmf->vma->vm_file->private_data;
877         struct vdpa_device *vdpa = v->vdpa;
878         const struct vdpa_config_ops *ops = vdpa->config;
879         struct vdpa_notification_area notify;
880         struct vm_area_struct *vma = vmf->vma;
881         u16 index = vma->vm_pgoff;
882
883         notify = ops->get_vq_notification(vdpa, index);
884
885         vma->vm_page_prot = pgprot_noncached(vma->vm_page_prot);
886         if (remap_pfn_range(vma, vmf->address & PAGE_MASK,
887                             notify.addr >> PAGE_SHIFT, PAGE_SIZE,
888                             vma->vm_page_prot))
889                 return VM_FAULT_SIGBUS;
890
891         return VM_FAULT_NOPAGE;
892 }
893
/* VMA operations for doorbell mappings: pages are populated on fault. */
static const struct vm_operations_struct vhost_vdpa_vm_ops = {
        .fault = vhost_vdpa_fault,
};
897
898 static int vhost_vdpa_mmap(struct file *file, struct vm_area_struct *vma)
899 {
900         struct vhost_vdpa *v = vma->vm_file->private_data;
901         struct vdpa_device *vdpa = v->vdpa;
902         const struct vdpa_config_ops *ops = vdpa->config;
903         struct vdpa_notification_area notify;
904         unsigned long index = vma->vm_pgoff;
905
906         if (vma->vm_end - vma->vm_start != PAGE_SIZE)
907                 return -EINVAL;
908         if ((vma->vm_flags & VM_SHARED) == 0)
909                 return -EINVAL;
910         if (vma->vm_flags & VM_READ)
911                 return -EINVAL;
912         if (index > 65535)
913                 return -EINVAL;
914         if (!ops->get_vq_notification)
915                 return -ENOTSUPP;
916
917         /* To be safe and easily modelled by userspace, We only
918          * support the doorbell which sits on the page boundary and
919          * does not share the page with other registers.
920          */
921         notify = ops->get_vq_notification(vdpa, index);
922         if (notify.addr & (PAGE_SIZE - 1))
923                 return -EINVAL;
924         if (vma->vm_end - vma->vm_start != notify.size)
925                 return -ENOTSUPP;
926
927         vma->vm_ops = &vhost_vdpa_vm_ops;
928         return 0;
929 }
930 #endif /* CONFIG_MMU */
931
/*
 * File operations of the vhost-vdpa character device. Doorbell mmap is
 * only available when the kernel is built with CONFIG_MMU.
 */
static const struct file_operations vhost_vdpa_fops = {
        .owner          = THIS_MODULE,
        .open           = vhost_vdpa_open,
        .release        = vhost_vdpa_release,
        .write_iter     = vhost_vdpa_chr_write_iter,
        .unlocked_ioctl = vhost_vdpa_unlocked_ioctl,
#ifdef CONFIG_MMU
        .mmap           = vhost_vdpa_mmap,
#endif /* CONFIG_MMU */
        .compat_ioctl   = compat_ptr_ioctl,
};
943
944 static void vhost_vdpa_release_dev(struct device *device)
945 {
946         struct vhost_vdpa *v =
947                container_of(device, struct vhost_vdpa, dev);
948
949         ida_simple_remove(&vhost_vdpa_ida, v->minor);
950         kfree(v->vqs);
951         kfree(v);
952 }
953
954 static int vhost_vdpa_probe(struct vdpa_device *vdpa)
955 {
956         const struct vdpa_config_ops *ops = vdpa->config;
957         struct vhost_vdpa *v;
958         int minor;
959         int r;
960
961         /* Currently, we only accept the network devices. */
962         if (ops->get_device_id(vdpa) != VIRTIO_ID_NET)
963                 return -ENOTSUPP;
964
965         v = kzalloc(sizeof(*v), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
966         if (!v)
967                 return -ENOMEM;
968
969         minor = ida_simple_get(&vhost_vdpa_ida, 0,
970                                VHOST_VDPA_DEV_MAX, GFP_KERNEL);
971         if (minor < 0) {
972                 kfree(v);
973                 return minor;
974         }
975
976         atomic_set(&v->opened, 0);
977         v->minor = minor;
978         v->vdpa = vdpa;
979         v->nvqs = vdpa->nvqs;
980         v->virtio_id = ops->get_device_id(vdpa);
981
982         device_initialize(&v->dev);
983         v->dev.release = vhost_vdpa_release_dev;
984         v->dev.parent = &vdpa->dev;
985         v->dev.devt = MKDEV(MAJOR(vhost_vdpa_major), minor);
986         v->vqs = kmalloc_array(v->nvqs, sizeof(struct vhost_virtqueue),
987                                GFP_KERNEL);
988         if (!v->vqs) {
989                 r = -ENOMEM;
990                 goto err;
991         }
992
993         r = dev_set_name(&v->dev, "vhost-vdpa-%u", minor);
994         if (r)
995                 goto err;
996
997         cdev_init(&v->cdev, &vhost_vdpa_fops);
998         v->cdev.owner = THIS_MODULE;
999
1000         r = cdev_device_add(&v->cdev, &v->dev);
1001         if (r)
1002                 goto err;
1003
1004         init_completion(&v->completion);
1005         vdpa_set_drvdata(vdpa, v);
1006
1007         return 0;
1008
1009 err:
1010         put_device(&v->dev);
1011         return r;
1012 }
1013
1014 static void vhost_vdpa_remove(struct vdpa_device *vdpa)
1015 {
1016         struct vhost_vdpa *v = vdpa_get_drvdata(vdpa);
1017         int opened;
1018
1019         cdev_device_del(&v->cdev, &v->dev);
1020
1021         do {
1022                 opened = atomic_cmpxchg(&v->opened, 0, 1);
1023                 if (!opened)
1024                         break;
1025                 wait_for_completion(&v->completion);
1026         } while (1);
1027
1028         put_device(&v->dev);
1029 }
1030
1031 static struct vdpa_driver vhost_vdpa_driver = {
1032         .driver = {
1033                 .name   = "vhost_vdpa",
1034         },
1035         .probe  = vhost_vdpa_probe,
1036         .remove = vhost_vdpa_remove,
1037 };
1038
1039 static int __init vhost_vdpa_init(void)
1040 {
1041         int r;
1042
1043         r = alloc_chrdev_region(&vhost_vdpa_major, 0, VHOST_VDPA_DEV_MAX,
1044                                 "vhost-vdpa");
1045         if (r)
1046                 goto err_alloc_chrdev;
1047
1048         r = vdpa_register_driver(&vhost_vdpa_driver);
1049         if (r)
1050                 goto err_vdpa_register_driver;
1051
1052         return 0;
1053
1054 err_vdpa_register_driver:
1055         unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1056 err_alloc_chrdev:
1057         return r;
1058 }
1059 module_init(vhost_vdpa_init);
1060
1061 static void __exit vhost_vdpa_exit(void)
1062 {
1063         vdpa_unregister_driver(&vhost_vdpa_driver);
1064         unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1065 }
1066 module_exit(vhost_vdpa_exit);
1067
1068 MODULE_VERSION("0.0.1");
1069 MODULE_LICENSE("GPL v2");
1070 MODULE_AUTHOR("Intel Corporation");
1071 MODULE_DESCRIPTION("vDPA-based vhost backend for virtio");