drivers/vhost/vdpa.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2018-2020 Intel Corporation.
4  * Copyright (C) 2020 Red Hat, Inc.
5  *
6  * Author: Tiwei Bie <tiwei.bie@intel.com>
7  *         Jason Wang <jasowang@redhat.com>
8  *
9  * Thanks to Michael S. Tsirkin for the valuable comments and
10  * suggestions, and to Cunming Liang and Zhihong Wang for all
11  * their support.
12  */
13
14 #include <linux/kernel.h>
15 #include <linux/module.h>
16 #include <linux/cdev.h>
17 #include <linux/device.h>
18 #include <linux/mm.h>
19 #include <linux/slab.h>
20 #include <linux/iommu.h>
21 #include <linux/uuid.h>
22 #include <linux/vdpa.h>
23 #include <linux/nospec.h>
24 #include <linux/vhost.h>
25
26 #include "vhost.h"
27
28 enum {
29         VHOST_VDPA_BACKEND_FEATURES =
30         (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2) |
31         (1ULL << VHOST_BACKEND_F_IOTLB_BATCH) |
32         (1ULL << VHOST_BACKEND_F_IOTLB_ASID),
33 };
34
35 #define VHOST_VDPA_DEV_MAX (1U << MINORBITS)
36
37 #define VHOST_VDPA_IOTLB_BUCKETS 16
38
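/*
 * Each address space (ASID) used by the device gets its own vhost IOTLB.
 * Address spaces are kept in a small hash table on the vhost_vdpa
 * instance, keyed by ASID modulo VHOST_VDPA_IOTLB_BUCKETS.
 */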
39 struct vhost_vdpa_as {
40         struct hlist_node hash_link;
41         struct vhost_iotlb iotlb;
42         u32 id;
43 };
44
45 struct vhost_vdpa {
46         struct vhost_dev vdev;
47         struct iommu_domain *domain;
48         struct vhost_virtqueue *vqs;
49         struct completion completion;
50         struct vdpa_device *vdpa;
51         struct hlist_head as[VHOST_VDPA_IOTLB_BUCKETS];
52         struct device dev;
53         struct cdev cdev;
54         atomic_t opened;
55         u32 nvqs;
56         int virtio_id;
57         int minor;
58         struct eventfd_ctx *config_ctx;
59         int in_batch;
60         struct vdpa_iova_range range;
61         u32 batch_asid;
62 };
63
64 static DEFINE_IDA(vhost_vdpa_ida);
65
66 static dev_t vhost_vdpa_major;
67
68 static inline u32 iotlb_to_asid(struct vhost_iotlb *iotlb)
69 {
70         struct vhost_vdpa_as *as = container_of(iotlb, struct
71                                                 vhost_vdpa_as, iotlb);
72         return as->id;
73 }
74
75 static struct vhost_vdpa_as *asid_to_as(struct vhost_vdpa *v, u32 asid)
76 {
77         struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
78         struct vhost_vdpa_as *as;
79
80         hlist_for_each_entry(as, head, hash_link)
81                 if (as->id == asid)
82                         return as;
83
84         return NULL;
85 }
86
87 static struct vhost_iotlb *asid_to_iotlb(struct vhost_vdpa *v, u32 asid)
88 {
89         struct vhost_vdpa_as *as = asid_to_as(v, asid);
90
91         if (!as)
92                 return NULL;
93
94         return &as->iotlb;
95 }
96
97 static struct vhost_vdpa_as *vhost_vdpa_alloc_as(struct vhost_vdpa *v, u32 asid)
98 {
99         struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
100         struct vhost_vdpa_as *as;
101
102         if (asid_to_as(v, asid))
103                 return NULL;
104
105         if (asid >= v->vdpa->nas)
106                 return NULL;
107
108         as = kmalloc(sizeof(*as), GFP_KERNEL);
109         if (!as)
110                 return NULL;
111
112         vhost_iotlb_init(&as->iotlb, 0, 0);
113         as->id = asid;
114         hlist_add_head(&as->hash_link, head);
115
116         return as;
117 }
118
119 static struct vhost_vdpa_as *vhost_vdpa_find_alloc_as(struct vhost_vdpa *v,
120                                                       u32 asid)
121 {
122         struct vhost_vdpa_as *as = asid_to_as(v, asid);
123
124         if (as)
125                 return as;
126
127         return vhost_vdpa_alloc_as(v, asid);
128 }
129
130 static int vhost_vdpa_remove_as(struct vhost_vdpa *v, u32 asid)
131 {
132         struct vhost_vdpa_as *as = asid_to_as(v, asid);
133
134         if (!as)
135                 return -EINVAL;
136
137         hlist_del(&as->hash_link);
138         vhost_iotlb_reset(&as->iotlb);
139         kfree(as);
140
141         return 0;
142 }
143
144 static void handle_vq_kick(struct vhost_work *work)
145 {
146         struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
147                                                   poll.work);
148         struct vhost_vdpa *v = container_of(vq->dev, struct vhost_vdpa, vdev);
149         const struct vdpa_config_ops *ops = v->vdpa->config;
150
151         ops->kick_vq(v->vdpa, vq - v->vqs);
152 }
153
154 static irqreturn_t vhost_vdpa_virtqueue_cb(void *private)
155 {
156         struct vhost_virtqueue *vq = private;
157         struct eventfd_ctx *call_ctx = vq->call_ctx.ctx;
158
159         if (call_ctx)
160                 eventfd_signal(call_ctx, 1);
161
162         return IRQ_HANDLED;
163 }
164
165 static irqreturn_t vhost_vdpa_config_cb(void *private)
166 {
167         struct vhost_vdpa *v = private;
168         struct eventfd_ctx *config_ctx = v->config_ctx;
169
170         if (config_ctx)
171                 eventfd_signal(config_ctx, 1);
172
173         return IRQ_HANDLED;
174 }
175
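/*
 * Try to route the virtqueue interrupt to the guest through the irq
 * bypass framework. This only works when the parent device exposes a
 * per-vq irq via get_vq_irq() and the vq has a call eventfd; otherwise
 * interrupts keep going through vhost_vdpa_virtqueue_cb() and the
 * eventfd.
 */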
176 static void vhost_vdpa_setup_vq_irq(struct vhost_vdpa *v, u16 qid)
177 {
178         struct vhost_virtqueue *vq = &v->vqs[qid];
179         const struct vdpa_config_ops *ops = v->vdpa->config;
180         struct vdpa_device *vdpa = v->vdpa;
181         int ret, irq;
182
183         if (!ops->get_vq_irq)
184                 return;
185
186         irq = ops->get_vq_irq(vdpa, qid);
187         if (irq < 0)
188                 return;
189
190         irq_bypass_unregister_producer(&vq->call_ctx.producer);
191         if (!vq->call_ctx.ctx)
192                 return;
193
194         vq->call_ctx.producer.token = vq->call_ctx.ctx;
195         vq->call_ctx.producer.irq = irq;
196         ret = irq_bypass_register_producer(&vq->call_ctx.producer);
197         if (unlikely(ret))
198                 dev_info(&v->dev, "vq %u, irq bypass producer (token %p) registration fails, ret = %d\n",
199                          qid, vq->call_ctx.producer.token, ret);
200 }
201
202 static void vhost_vdpa_unsetup_vq_irq(struct vhost_vdpa *v, u16 qid)
203 {
204         struct vhost_virtqueue *vq = &v->vqs[qid];
205
206         irq_bypass_unregister_producer(&vq->call_ctx.producer);
207 }
208
209 static int vhost_vdpa_reset(struct vhost_vdpa *v)
210 {
211         struct vdpa_device *vdpa = v->vdpa;
212
213         v->in_batch = 0;
214
215         return vdpa_reset(vdpa);
216 }
217
218 static long vhost_vdpa_get_device_id(struct vhost_vdpa *v, u8 __user *argp)
219 {
220         struct vdpa_device *vdpa = v->vdpa;
221         const struct vdpa_config_ops *ops = vdpa->config;
222         u32 device_id;
223
224         device_id = ops->get_device_id(vdpa);
225
226         if (copy_to_user(argp, &device_id, sizeof(device_id)))
227                 return -EFAULT;
228
229         return 0;
230 }
231
232 static long vhost_vdpa_get_status(struct vhost_vdpa *v, u8 __user *statusp)
233 {
234         struct vdpa_device *vdpa = v->vdpa;
235         const struct vdpa_config_ops *ops = vdpa->config;
236         u8 status;
237
238         status = ops->get_status(vdpa);
239
240         if (copy_to_user(statusp, &status, sizeof(status)))
241                 return -EFAULT;
242
243         return 0;
244 }
245
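/*
 * Handle a status write from userspace. Status bits may only be cleared
 * by writing 0, which performs a full device reset. Virtqueue irq bypass
 * producers are torn down when DRIVER_OK is cleared and set up again
 * once DRIVER_OK becomes set.
 */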
246 static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp)
247 {
248         struct vdpa_device *vdpa = v->vdpa;
249         const struct vdpa_config_ops *ops = vdpa->config;
250         u8 status, status_old;
251         u32 nvqs = v->nvqs;
252         int ret;
253         u16 i;
254
255         if (copy_from_user(&status, statusp, sizeof(status)))
256                 return -EFAULT;
257
258         status_old = ops->get_status(vdpa);
259
260         /*
261          * Userspace shouldn't remove status bits unless it resets the
262          * whole status back to 0.
263          */
264         if (status != 0 && (status_old & ~status) != 0)
265                 return -EINVAL;
266
267         if ((status_old & VIRTIO_CONFIG_S_DRIVER_OK) && !(status & VIRTIO_CONFIG_S_DRIVER_OK))
268                 for (i = 0; i < nvqs; i++)
269                         vhost_vdpa_unsetup_vq_irq(v, i);
270
271         if (status == 0) {
272                 ret = vdpa_reset(vdpa);
273                 if (ret)
274                         return ret;
275         } else
276                 vdpa_set_status(vdpa, status);
277
278         if ((status & VIRTIO_CONFIG_S_DRIVER_OK) && !(status_old & VIRTIO_CONFIG_S_DRIVER_OK))
279                 for (i = 0; i < nvqs; i++)
280                         vhost_vdpa_setup_vq_irq(v, i);
281
282         return 0;
283 }
284
285 static int vhost_vdpa_config_validate(struct vhost_vdpa *v,
286                                       struct vhost_vdpa_config *c)
287 {
288         struct vdpa_device *vdpa = v->vdpa;
289         size_t size = vdpa->config->get_config_size(vdpa);
290
291         if (c->len == 0 || c->off > size)
292                 return -EINVAL;
293
294         if (c->len > size - c->off)
295                 return -E2BIG;
296
297         return 0;
298 }
299
300 static long vhost_vdpa_get_config(struct vhost_vdpa *v,
301                                   struct vhost_vdpa_config __user *c)
302 {
303         struct vdpa_device *vdpa = v->vdpa;
304         struct vhost_vdpa_config config;
305         unsigned long size = offsetof(struct vhost_vdpa_config, buf);
306         u8 *buf;
307
308         if (copy_from_user(&config, c, size))
309                 return -EFAULT;
310         if (vhost_vdpa_config_validate(v, &config))
311                 return -EINVAL;
312         buf = kvzalloc(config.len, GFP_KERNEL);
313         if (!buf)
314                 return -ENOMEM;
315
316         vdpa_get_config(vdpa, config.off, buf, config.len);
317
318         if (copy_to_user(c->buf, buf, config.len)) {
319                 kvfree(buf);
320                 return -EFAULT;
321         }
322
323         kvfree(buf);
324         return 0;
325 }
326
327 static long vhost_vdpa_set_config(struct vhost_vdpa *v,
328                                   struct vhost_vdpa_config __user *c)
329 {
330         struct vdpa_device *vdpa = v->vdpa;
331         struct vhost_vdpa_config config;
332         unsigned long size = offsetof(struct vhost_vdpa_config, buf);
333         u8 *buf;
334
335         if (copy_from_user(&config, c, size))
336                 return -EFAULT;
337         if (vhost_vdpa_config_validate(v, &config))
338                 return -EINVAL;
339
340         buf = vmemdup_user(c->buf, config.len);
341         if (IS_ERR(buf))
342                 return PTR_ERR(buf);
343
344         vdpa_set_config(vdpa, config.off, buf, config.len);
345
346         kvfree(buf);
347         return 0;
348 }
349
350 static bool vhost_vdpa_can_suspend(const struct vhost_vdpa *v)
351 {
352         struct vdpa_device *vdpa = v->vdpa;
353         const struct vdpa_config_ops *ops = vdpa->config;
354
355         return ops->suspend;
356 }
357
358 static long vhost_vdpa_get_features(struct vhost_vdpa *v, u64 __user *featurep)
359 {
360         struct vdpa_device *vdpa = v->vdpa;
361         const struct vdpa_config_ops *ops = vdpa->config;
362         u64 features;
363
364         features = ops->get_device_features(vdpa);
365
366         if (copy_to_user(featurep, &features, sizeof(features)))
367                 return -EFAULT;
368
369         return 0;
370 }
371
372 static long vhost_vdpa_set_features(struct vhost_vdpa *v, u64 __user *featurep)
373 {
374         struct vdpa_device *vdpa = v->vdpa;
375         const struct vdpa_config_ops *ops = vdpa->config;
376         u64 features;
377
378         /*
379          * It's not allowed to change the features after they have
380          * been negotiated.
381          */
382         if (ops->get_status(vdpa) & VIRTIO_CONFIG_S_FEATURES_OK)
383                 return -EBUSY;
384
385         if (copy_from_user(&features, featurep, sizeof(features)))
386                 return -EFAULT;
387
388         if (vdpa_set_features(vdpa, features))
389                 return -EINVAL;
390
391         return 0;
392 }
393
394 static long vhost_vdpa_get_vring_num(struct vhost_vdpa *v, u16 __user *argp)
395 {
396         struct vdpa_device *vdpa = v->vdpa;
397         const struct vdpa_config_ops *ops = vdpa->config;
398         u16 num;
399
400         num = ops->get_vq_num_max(vdpa);
401
402         if (copy_to_user(argp, &num, sizeof(num)))
403                 return -EFAULT;
404
405         return 0;
406 }
407
408 static void vhost_vdpa_config_put(struct vhost_vdpa *v)
409 {
410         if (v->config_ctx) {
411                 eventfd_ctx_put(v->config_ctx);
412                 v->config_ctx = NULL;
413         }
414 }
415
416 static long vhost_vdpa_set_config_call(struct vhost_vdpa *v, u32 __user *argp)
417 {
418         struct vdpa_callback cb;
419         int fd;
420         struct eventfd_ctx *ctx;
421
422         cb.callback = vhost_vdpa_config_cb;
423         cb.private = v;
424         if (copy_from_user(&fd, argp, sizeof(fd)))
425                 return  -EFAULT;
426
427         ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
428         swap(ctx, v->config_ctx);
429
430         if (!IS_ERR_OR_NULL(ctx))
431                 eventfd_ctx_put(ctx);
432
433         if (IS_ERR(v->config_ctx)) {
434                 long ret = PTR_ERR(v->config_ctx);
435
436                 v->config_ctx = NULL;
437                 return ret;
438         }
439
440         v->vdpa->config->set_config_cb(v->vdpa, &cb);
441
442         return 0;
443 }
444
445 static long vhost_vdpa_get_iova_range(struct vhost_vdpa *v, u32 __user *argp)
446 {
447         struct vhost_vdpa_iova_range range = {
448                 .first = v->range.first,
449                 .last = v->range.last,
450         };
451
452         if (copy_to_user(argp, &range, sizeof(range)))
453                 return -EFAULT;
454         return 0;
455 }
456
457 static long vhost_vdpa_get_config_size(struct vhost_vdpa *v, u32 __user *argp)
458 {
459         struct vdpa_device *vdpa = v->vdpa;
460         const struct vdpa_config_ops *ops = vdpa->config;
461         u32 size;
462
463         size = ops->get_config_size(vdpa);
464
465         if (copy_to_user(argp, &size, sizeof(size)))
466                 return -EFAULT;
467
468         return 0;
469 }
470
471 static long vhost_vdpa_get_vqs_count(struct vhost_vdpa *v, u32 __user *argp)
472 {
473         struct vdpa_device *vdpa = v->vdpa;
474
475         if (copy_to_user(argp, &vdpa->nvqs, sizeof(vdpa->nvqs)))
476                 return -EFAULT;
477
478         return 0;
479 }
480
481 /* After a successful return of this ioctl the device must not process more
482  * virtqueue descriptors. The device can answer reads or writes of config
483  * fields as if it were not suspended. In particular, writing to "queue_enable"
484  * with a value of 1 will not make the device start processing buffers.
485  */
486 static long vhost_vdpa_suspend(struct vhost_vdpa *v)
487 {
488         struct vdpa_device *vdpa = v->vdpa;
489         const struct vdpa_config_ops *ops = vdpa->config;
490
491         if (!ops->suspend)
492                 return -EOPNOTSUPP;
493
494         return ops->suspend(vdpa);
495 }
496
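/*
 * Per-virtqueue ioctls. vDPA-specific commands (vring enable, vq group,
 * group ASID, reading the vq state from the device) are handled before
 * calling the generic vhost_vring_ioctl(); afterwards the updated vhost
 * state (ring addresses, base, call eventfd, ring size) is pushed down
 * to the parent device through the vdpa config ops.
 */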
497 static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd,
498                                    void __user *argp)
499 {
500         struct vdpa_device *vdpa = v->vdpa;
501         const struct vdpa_config_ops *ops = vdpa->config;
502         struct vdpa_vq_state vq_state;
503         struct vdpa_callback cb;
504         struct vhost_virtqueue *vq;
505         struct vhost_vring_state s;
506         u32 idx;
507         long r;
508
509         r = get_user(idx, (u32 __user *)argp);
510         if (r < 0)
511                 return r;
512
513         if (idx >= v->nvqs)
514                 return -ENOBUFS;
515
516         idx = array_index_nospec(idx, v->nvqs);
517         vq = &v->vqs[idx];
518
519         switch (cmd) {
520         case VHOST_VDPA_SET_VRING_ENABLE:
521                 if (copy_from_user(&s, argp, sizeof(s)))
522                         return -EFAULT;
523                 ops->set_vq_ready(vdpa, idx, s.num);
524                 return 0;
525         case VHOST_VDPA_GET_VRING_GROUP:
526                 if (!ops->get_vq_group)
527                         return -EOPNOTSUPP;
528                 s.index = idx;
529                 s.num = ops->get_vq_group(vdpa, idx);
530                 if (s.num >= vdpa->ngroups)
531                         return -EIO;
532                 else if (copy_to_user(argp, &s, sizeof(s)))
533                         return -EFAULT;
534                 return 0;
535         case VHOST_VDPA_SET_GROUP_ASID:
536                 if (copy_from_user(&s, argp, sizeof(s)))
537                         return -EFAULT;
538                 if (s.num >= vdpa->nas)
539                         return -EINVAL;
540                 if (!ops->set_group_asid)
541                         return -EOPNOTSUPP;
542                 return ops->set_group_asid(vdpa, idx, s.num);
543         case VHOST_GET_VRING_BASE:
544                 r = ops->get_vq_state(v->vdpa, idx, &vq_state);
545                 if (r)
546                         return r;
547
548                 vq->last_avail_idx = vq_state.split.avail_index;
549                 break;
550         }
551
552         r = vhost_vring_ioctl(&v->vdev, cmd, argp);
553         if (r)
554                 return r;
555
556         switch (cmd) {
557         case VHOST_SET_VRING_ADDR:
558                 if (ops->set_vq_address(vdpa, idx,
559                                         (u64)(uintptr_t)vq->desc,
560                                         (u64)(uintptr_t)vq->avail,
561                                         (u64)(uintptr_t)vq->used))
562                         r = -EINVAL;
563                 break;
564
565         case VHOST_SET_VRING_BASE:
566                 vq_state.split.avail_index = vq->last_avail_idx;
567                 if (ops->set_vq_state(vdpa, idx, &vq_state))
568                         r = -EINVAL;
569                 break;
570
571         case VHOST_SET_VRING_CALL:
572                 if (vq->call_ctx.ctx) {
573                         cb.callback = vhost_vdpa_virtqueue_cb;
574                         cb.private = vq;
575                 } else {
576                         cb.callback = NULL;
577                         cb.private = NULL;
578                 }
579                 ops->set_vq_cb(vdpa, idx, &cb);
580                 vhost_vdpa_setup_vq_irq(v, idx);
581                 break;
582
583         case VHOST_SET_VRING_NUM:
584                 ops->set_vq_num(vdpa, idx, vq->num);
585                 break;
586         }
587
588         return r;
589 }
590
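/*
 * Top-level ioctl dispatcher. VHOST_SET_BACKEND_FEATURES is handled
 * without taking the vhost device mutex; everything else runs under it.
 * Commands not known here fall through to vhost_dev_ioctl() and finally
 * to the per-vring handler above.
 */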
591 static long vhost_vdpa_unlocked_ioctl(struct file *filep,
592                                       unsigned int cmd, unsigned long arg)
593 {
594         struct vhost_vdpa *v = filep->private_data;
595         struct vhost_dev *d = &v->vdev;
596         void __user *argp = (void __user *)arg;
597         u64 __user *featurep = argp;
598         u64 features;
599         long r = 0;
600
601         if (cmd == VHOST_SET_BACKEND_FEATURES) {
602                 if (copy_from_user(&features, featurep, sizeof(features)))
603                         return -EFAULT;
604                 if (features & ~(VHOST_VDPA_BACKEND_FEATURES |
605                                  BIT_ULL(VHOST_BACKEND_F_SUSPEND)))
606                         return -EOPNOTSUPP;
607                 if ((features & BIT_ULL(VHOST_BACKEND_F_SUSPEND)) &&
608                      !vhost_vdpa_can_suspend(v))
609                         return -EOPNOTSUPP;
610                 vhost_set_backend_features(&v->vdev, features);
611                 return 0;
612         }
613
614         mutex_lock(&d->mutex);
615
616         switch (cmd) {
617         case VHOST_VDPA_GET_DEVICE_ID:
618                 r = vhost_vdpa_get_device_id(v, argp);
619                 break;
620         case VHOST_VDPA_GET_STATUS:
621                 r = vhost_vdpa_get_status(v, argp);
622                 break;
623         case VHOST_VDPA_SET_STATUS:
624                 r = vhost_vdpa_set_status(v, argp);
625                 break;
626         case VHOST_VDPA_GET_CONFIG:
627                 r = vhost_vdpa_get_config(v, argp);
628                 break;
629         case VHOST_VDPA_SET_CONFIG:
630                 r = vhost_vdpa_set_config(v, argp);
631                 break;
632         case VHOST_GET_FEATURES:
633                 r = vhost_vdpa_get_features(v, argp);
634                 break;
635         case VHOST_SET_FEATURES:
636                 r = vhost_vdpa_set_features(v, argp);
637                 break;
638         case VHOST_VDPA_GET_VRING_NUM:
639                 r = vhost_vdpa_get_vring_num(v, argp);
640                 break;
641         case VHOST_VDPA_GET_GROUP_NUM:
642                 if (copy_to_user(argp, &v->vdpa->ngroups,
643                                  sizeof(v->vdpa->ngroups)))
644                         r = -EFAULT;
645                 break;
646         case VHOST_VDPA_GET_AS_NUM:
647                 if (copy_to_user(argp, &v->vdpa->nas, sizeof(v->vdpa->nas)))
648                         r = -EFAULT;
649                 break;
650         case VHOST_SET_LOG_BASE:
651         case VHOST_SET_LOG_FD:
652                 r = -ENOIOCTLCMD;
653                 break;
654         case VHOST_VDPA_SET_CONFIG_CALL:
655                 r = vhost_vdpa_set_config_call(v, argp);
656                 break;
657         case VHOST_GET_BACKEND_FEATURES:
658                 features = VHOST_VDPA_BACKEND_FEATURES;
659                 if (vhost_vdpa_can_suspend(v))
660                         features |= BIT_ULL(VHOST_BACKEND_F_SUSPEND);
661                 if (copy_to_user(featurep, &features, sizeof(features)))
662                         r = -EFAULT;
663                 break;
664         case VHOST_VDPA_GET_IOVA_RANGE:
665                 r = vhost_vdpa_get_iova_range(v, argp);
666                 break;
667         case VHOST_VDPA_GET_CONFIG_SIZE:
668                 r = vhost_vdpa_get_config_size(v, argp);
669                 break;
670         case VHOST_VDPA_GET_VQS_COUNT:
671                 r = vhost_vdpa_get_vqs_count(v, argp);
672                 break;
673         case VHOST_VDPA_SUSPEND:
674                 r = vhost_vdpa_suspend(v);
675                 break;
676         default:
677                 r = vhost_dev_ioctl(&v->vdev, cmd, argp);
678                 if (r == -ENOIOCTLCMD)
679                         r = vhost_vdpa_vring_ioctl(v, cmd, argp);
680                 break;
681         }
682
683         mutex_unlock(&d->mutex);
684         return r;
685 }
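
/*
 * A rough, non-normative sketch of how userspace typically drives this
 * interface. The device node name and exact ordering below are
 * application-specific assumptions, not defined in this file:
 *
 *   fd = open("/dev/vhost-vdpa-0", O_RDWR);
 *   ioctl(fd, VHOST_SET_OWNER);
 *   ioctl(fd, VHOST_GET_FEATURES, &features);
 *   ioctl(fd, VHOST_SET_FEATURES, &features);   // before FEATURES_OK
 *   ...VHOST_SET_VRING_{NUM,ADDR,BASE,KICK,CALL} for each queue...
 *   ioctl(fd, VHOST_VDPA_SET_VRING_ENABLE, &state);
 *   write(fd, &msg, sizeof(msg));   // struct vhost_msg_v2, VHOST_IOTLB_UPDATE
 *   ioctl(fd, VHOST_VDPA_SET_STATUS, &status);  // stepping up to DRIVER_OK
 */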
686
687 static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v,
688                                 struct vhost_iotlb *iotlb,
689                                 u64 start, u64 last)
690 {
691         struct vhost_dev *dev = &v->vdev;
692         struct vhost_iotlb_map *map;
693         struct page *page;
694         unsigned long pfn, pinned;
695
696         while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
697                 pinned = PFN_DOWN(map->size);
698                 for (pfn = PFN_DOWN(map->addr);
699                      pinned > 0; pfn++, pinned--) {
700                         page = pfn_to_page(pfn);
701                         if (map->perm & VHOST_ACCESS_WO)
702                                 set_page_dirty_lock(page);
703                         unpin_user_page(page);
704                 }
705                 atomic64_sub(PFN_DOWN(map->size), &dev->mm->pinned_vm);
706                 vhost_iotlb_map_free(iotlb, map);
707         }
708 }
709
710 static void vhost_vdpa_va_unmap(struct vhost_vdpa *v,
711                                 struct vhost_iotlb *iotlb,
712                                 u64 start, u64 last)
713 {
714         struct vhost_iotlb_map *map;
715         struct vdpa_map_file *map_file;
716
717         while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
718                 map_file = (struct vdpa_map_file *)map->opaque;
719                 fput(map_file->file);
720                 kfree(map_file);
721                 vhost_iotlb_map_free(iotlb, map);
722         }
723 }
724
725 static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
726                                    struct vhost_iotlb *iotlb,
727                                    u64 start, u64 last)
728 {
729         struct vdpa_device *vdpa = v->vdpa;
730
731         if (vdpa->use_va)
732                 return vhost_vdpa_va_unmap(v, iotlb, start, last);
733
734         return vhost_vdpa_pa_unmap(v, iotlb, start, last);
735 }
736
737 static int perm_to_iommu_flags(u32 perm)
738 {
739         int flags = 0;
740
741         switch (perm) {
742         case VHOST_ACCESS_WO:
743                 flags |= IOMMU_WRITE;
744                 break;
745         case VHOST_ACCESS_RO:
746                 flags |= IOMMU_READ;
747                 break;
748         case VHOST_ACCESS_RW:
749                 flags |= (IOMMU_WRITE | IOMMU_READ);
750                 break;
751         default:
752                 WARN(1, "invalid vhost IOTLB permission\n");
753                 break;
754         }
755
756         return flags | IOMMU_CACHE;
757 }
758
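/*
 * Install a single mapping. Depending on what the parent device
 * provides, this goes through ops->dma_map(), through ops->set_map()
 * (deferred until BATCH_END while a batch is in flight), or directly
 * into the platform IOMMU domain owned by this instance.
 */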
759 static int vhost_vdpa_map(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
760                           u64 iova, u64 size, u64 pa, u32 perm, void *opaque)
761 {
762         struct vhost_dev *dev = &v->vdev;
763         struct vdpa_device *vdpa = v->vdpa;
764         const struct vdpa_config_ops *ops = vdpa->config;
765         u32 asid = iotlb_to_asid(iotlb);
766         int r = 0;
767
768         r = vhost_iotlb_add_range_ctx(iotlb, iova, iova + size - 1,
769                                       pa, perm, opaque);
770         if (r)
771                 return r;
772
773         if (ops->dma_map) {
774                 r = ops->dma_map(vdpa, asid, iova, size, pa, perm, opaque);
775         } else if (ops->set_map) {
776                 if (!v->in_batch)
777                         r = ops->set_map(vdpa, asid, iotlb);
778         } else {
779                 r = iommu_map(v->domain, iova, pa, size,
780                               perm_to_iommu_flags(perm));
781         }
782         if (r) {
783                 vhost_iotlb_del_range(iotlb, iova, iova + size - 1);
784                 return r;
785         }
786
787         if (!vdpa->use_va)
788                 atomic64_add(PFN_DOWN(size), &dev->mm->pinned_vm);
789
790         return 0;
791 }
792
793 static void vhost_vdpa_unmap(struct vhost_vdpa *v,
794                              struct vhost_iotlb *iotlb,
795                              u64 iova, u64 size)
796 {
797         struct vdpa_device *vdpa = v->vdpa;
798         const struct vdpa_config_ops *ops = vdpa->config;
799         u32 asid = iotlb_to_asid(iotlb);
800
801         vhost_vdpa_iotlb_unmap(v, iotlb, iova, iova + size - 1);
802
803         if (ops->dma_map) {
804                 ops->dma_unmap(vdpa, asid, iova, size);
805         } else if (ops->set_map) {
806                 if (!v->in_batch)
807                         ops->set_map(vdpa, asid, iotlb);
808         } else {
809                 iommu_unmap(v->domain, iova, size);
810         }
811
812         /* If we are in the middle of batch processing, delay freeing
813          * the AS until BATCH_END.
814          */
815         if (!v->in_batch && !iotlb->nmaps)
816                 vhost_vdpa_remove_as(v, asid);
817 }
818
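/*
 * VA mode (vdpa->use_va): instead of pinning pages, record the backing
 * file and offset of each file-backed, shared VMA covering the range
 * (VMAs without a file or with VM_IO/VM_PFNMAP are skipped) so the
 * parent driver can resolve the userspace addresses itself.
 */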
819 static int vhost_vdpa_va_map(struct vhost_vdpa *v,
820                              struct vhost_iotlb *iotlb,
821                              u64 iova, u64 size, u64 uaddr, u32 perm)
822 {
823         struct vhost_dev *dev = &v->vdev;
824         u64 offset, map_size, map_iova = iova;
825         struct vdpa_map_file *map_file;
826         struct vm_area_struct *vma;
827         int ret = 0;
828
829         mmap_read_lock(dev->mm);
830
831         while (size) {
832                 vma = find_vma(dev->mm, uaddr);
833                 if (!vma) {
834                         ret = -EINVAL;
835                         break;
836                 }
837                 map_size = min(size, vma->vm_end - uaddr);
838                 if (!(vma->vm_file && (vma->vm_flags & VM_SHARED) &&
839                         !(vma->vm_flags & (VM_IO | VM_PFNMAP))))
840                         goto next;
841
842                 map_file = kzalloc(sizeof(*map_file), GFP_KERNEL);
843                 if (!map_file) {
844                         ret = -ENOMEM;
845                         break;
846                 }
847                 offset = (vma->vm_pgoff << PAGE_SHIFT) + uaddr - vma->vm_start;
848                 map_file->offset = offset;
849                 map_file->file = get_file(vma->vm_file);
850                 ret = vhost_vdpa_map(v, iotlb, map_iova, map_size, uaddr,
851                                      perm, map_file);
852                 if (ret) {
853                         fput(map_file->file);
854                         kfree(map_file);
855                         break;
856                 }
857 next:
858                 size -= map_size;
859                 uaddr += map_size;
860                 map_iova += map_size;
861         }
862         if (ret)
863                 vhost_vdpa_unmap(v, iotlb, iova, map_iova - iova);
864
865         mmap_read_unlock(dev->mm);
866
867         return ret;
868 }
869
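/*
 * PA mode: pin the userspace pages (accounted against RLIMIT_MEMLOCK),
 * merge physically contiguous runs into as few mappings as possible and
 * install each run with vhost_vdpa_map(). A single free page is used as
 * the pin_user_pages() batch buffer to bound bookkeeping memory.
 */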
870 static int vhost_vdpa_pa_map(struct vhost_vdpa *v,
871                              struct vhost_iotlb *iotlb,
872                              u64 iova, u64 size, u64 uaddr, u32 perm)
873 {
874         struct vhost_dev *dev = &v->vdev;
875         struct page **page_list;
876         unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
877         unsigned int gup_flags = FOLL_LONGTERM;
878         unsigned long npages, cur_base, map_pfn, last_pfn = 0;
879         unsigned long lock_limit, sz2pin, nchunks, i;
880         u64 start = iova;
881         long pinned;
882         int ret = 0;
883
884         /* Limit the use of memory for bookkeeping */
885         page_list = (struct page **) __get_free_page(GFP_KERNEL);
886         if (!page_list)
887                 return -ENOMEM;
888
889         if (perm & VHOST_ACCESS_WO)
890                 gup_flags |= FOLL_WRITE;
891
892         npages = PFN_UP(size + (iova & ~PAGE_MASK));
893         if (!npages) {
894                 ret = -EINVAL;
895                 goto free;
896         }
897
898         mmap_read_lock(dev->mm);
899
900         lock_limit = PFN_DOWN(rlimit(RLIMIT_MEMLOCK));
901         if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
902                 ret = -ENOMEM;
903                 goto unlock;
904         }
905
906         cur_base = uaddr & PAGE_MASK;
907         iova &= PAGE_MASK;
908         nchunks = 0;
909
910         while (npages) {
911                 sz2pin = min_t(unsigned long, npages, list_size);
912                 pinned = pin_user_pages(cur_base, sz2pin,
913                                         gup_flags, page_list, NULL);
914                 if (sz2pin != pinned) {
915                         if (pinned < 0) {
916                                 ret = pinned;
917                         } else {
918                                 unpin_user_pages(page_list, pinned);
919                                 ret = -ENOMEM;
920                         }
921                         goto out;
922                 }
923                 nchunks++;
924
925                 if (!last_pfn)
926                         map_pfn = page_to_pfn(page_list[0]);
927
928                 for (i = 0; i < pinned; i++) {
929                         unsigned long this_pfn = page_to_pfn(page_list[i]);
930                         u64 csize;
931
932                         if (last_pfn && (this_pfn != last_pfn + 1)) {
933                                 /* Map the contiguous chunk of memory pinned so far */
934                                 csize = PFN_PHYS(last_pfn - map_pfn + 1);
935                                 ret = vhost_vdpa_map(v, iotlb, iova, csize,
936                                                      PFN_PHYS(map_pfn),
937                                                      perm, NULL);
938                                 if (ret) {
939                                         /*
940                                          * Unpin the pages that are left unmapped
941                                          * from this point on in the current
942                                          * page_list. The remaining outstanding
943                                          * ones which may stride across several
944                                          * chunks will be covered in the common
945                                          * error path subsequently.
946                                          */
947                                         unpin_user_pages(&page_list[i],
948                                                          pinned - i);
949                                         goto out;
950                                 }
951
952                                 map_pfn = this_pfn;
953                                 iova += csize;
954                                 nchunks = 0;
955                         }
956
957                         last_pfn = this_pfn;
958                 }
959
960                 cur_base += PFN_PHYS(pinned);
961                 npages -= pinned;
962         }
963
964         /* Map the remaining chunk */
965         ret = vhost_vdpa_map(v, iotlb, iova, PFN_PHYS(last_pfn - map_pfn + 1),
966                              PFN_PHYS(map_pfn), perm, NULL);
967 out:
968         if (ret) {
969                 if (nchunks) {
970                         unsigned long pfn;
971
972                         /*
973                          * Unpin the outstanding pages that were pinned but
974                          * never mapped, because vdpa_map() or
975                          * pin_user_pages() failed.
976                          *
977                          * Mapped pages are accounted in vdpa_map(), hence
978                          * the corresponding unpinning will be handled by
979                          * vdpa_unmap().
980                          */
981                         WARN_ON(!last_pfn);
982                         for (pfn = map_pfn; pfn <= last_pfn; pfn++)
983                                 unpin_user_page(pfn_to_page(pfn));
984                 }
985                 vhost_vdpa_unmap(v, iotlb, start, size);
986         }
987 unlock:
988         mmap_read_unlock(dev->mm);
989 free:
990         free_page((unsigned long)page_list);
991         return ret;
992
993 }
994
995 static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
996                                            struct vhost_iotlb *iotlb,
997                                            struct vhost_iotlb_msg *msg)
998 {
999         struct vdpa_device *vdpa = v->vdpa;
1000
1001         if (msg->iova < v->range.first || !msg->size ||
1002             msg->iova > U64_MAX - msg->size + 1 ||
1003             msg->iova + msg->size - 1 > v->range.last)
1004                 return -EINVAL;
1005
1006         if (vhost_iotlb_itree_first(iotlb, msg->iova,
1007                                     msg->iova + msg->size - 1))
1008                 return -EEXIST;
1009
1010         if (vdpa->use_va)
1011                 return vhost_vdpa_va_map(v, iotlb, msg->iova, msg->size,
1012                                          msg->uaddr, msg->perm);
1013
1014         return vhost_vdpa_pa_map(v, iotlb, msg->iova, msg->size, msg->uaddr,
1015                                  msg->perm);
1016 }
1017
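/*
 * Entry point for IOTLB messages written to the char device:
 * VHOST_IOTLB_UPDATE and VHOST_IOTLB_INVALIDATE add and remove mappings,
 * while BATCH_BEGIN/BATCH_END bracket a batch for devices using
 * ops->set_map() so the full table is pushed only once per batch. The
 * address space (ASID) is created on demand for UPDATE/BATCH_BEGIN.
 */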
1018 static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
1019                                         struct vhost_iotlb_msg *msg)
1020 {
1021         struct vhost_vdpa *v = container_of(dev, struct vhost_vdpa, vdev);
1022         struct vdpa_device *vdpa = v->vdpa;
1023         const struct vdpa_config_ops *ops = vdpa->config;
1024         struct vhost_iotlb *iotlb = NULL;
1025         struct vhost_vdpa_as *as = NULL;
1026         int r = 0;
1027
1028         mutex_lock(&dev->mutex);
1029
1030         r = vhost_dev_check_owner(dev);
1031         if (r)
1032                 goto unlock;
1033
1034         if (msg->type == VHOST_IOTLB_UPDATE ||
1035             msg->type == VHOST_IOTLB_BATCH_BEGIN) {
1036                 as = vhost_vdpa_find_alloc_as(v, asid);
1037                 if (!as) {
1038                         dev_err(&v->dev, "can't find and alloc asid %d\n",
1039                                 asid);
1040                         r = -EINVAL;
1041                         goto unlock;
1042                 }
1043                 iotlb = &as->iotlb;
1044         } else
1045                 iotlb = asid_to_iotlb(v, asid);
1046
1047         if ((v->in_batch && v->batch_asid != asid) || !iotlb) {
1048                 if (v->in_batch && v->batch_asid != asid) {
1049                         dev_info(&v->dev, "batch id %d asid %d\n",
1050                                  v->batch_asid, asid);
1051                 }
1052                 if (!iotlb)
1053                         dev_err(&v->dev, "no iotlb for asid %d\n", asid);
1054                 r = -EINVAL;
1055                 goto unlock;
1056         }
1057
1058         switch (msg->type) {
1059         case VHOST_IOTLB_UPDATE:
1060                 r = vhost_vdpa_process_iotlb_update(v, iotlb, msg);
1061                 break;
1062         case VHOST_IOTLB_INVALIDATE:
1063                 vhost_vdpa_unmap(v, iotlb, msg->iova, msg->size);
1064                 break;
1065         case VHOST_IOTLB_BATCH_BEGIN:
1066                 v->batch_asid = asid;
1067                 v->in_batch = true;
1068                 break;
1069         case VHOST_IOTLB_BATCH_END:
1070                 if (v->in_batch && ops->set_map)
1071                         ops->set_map(vdpa, asid, iotlb);
1072                 v->in_batch = false;
1073                 if (!iotlb->nmaps)
1074                         vhost_vdpa_remove_as(v, asid);
1075                 break;
1076         default:
1077                 r = -EINVAL;
1078                 break;
1079         }
1080 unlock:
1081         mutex_unlock(&dev->mutex);
1082
1083         return r;
1084 }
1085
1086 static ssize_t vhost_vdpa_chr_write_iter(struct kiocb *iocb,
1087                                          struct iov_iter *from)
1088 {
1089         struct file *file = iocb->ki_filp;
1090         struct vhost_vdpa *v = file->private_data;
1091         struct vhost_dev *dev = &v->vdev;
1092
1093         return vhost_chr_write_iter(dev, from);
1094 }
1095
1096 static int vhost_vdpa_alloc_domain(struct vhost_vdpa *v)
1097 {
1098         struct vdpa_device *vdpa = v->vdpa;
1099         const struct vdpa_config_ops *ops = vdpa->config;
1100         struct device *dma_dev = vdpa_get_dma_dev(vdpa);
1101         struct bus_type *bus;
1102         int ret;
1103
1104         /* Device wants to do DMA by itself */
1105         if (ops->set_map || ops->dma_map)
1106                 return 0;
1107
1108         bus = dma_dev->bus;
1109         if (!bus)
1110                 return -EFAULT;
1111
1112         if (!device_iommu_capable(dma_dev, IOMMU_CAP_CACHE_COHERENCY))
1113                 return -ENOTSUPP;
1114
1115         v->domain = iommu_domain_alloc(bus);
1116         if (!v->domain)
1117                 return -EIO;
1118
1119         ret = iommu_attach_device(v->domain, dma_dev);
1120         if (ret)
1121                 goto err_attach;
1122
1123         return 0;
1124
1125 err_attach:
1126         iommu_domain_free(v->domain);
1127         return ret;
1128 }
1129
1130 static void vhost_vdpa_free_domain(struct vhost_vdpa *v)
1131 {
1132         struct vdpa_device *vdpa = v->vdpa;
1133         struct device *dma_dev = vdpa_get_dma_dev(vdpa);
1134
1135         if (v->domain) {
1136                 iommu_detach_device(v->domain, dma_dev);
1137                 iommu_domain_free(v->domain);
1138         }
1139
1140         v->domain = NULL;
1141 }
1142
1143 static void vhost_vdpa_set_iova_range(struct vhost_vdpa *v)
1144 {
1145         struct vdpa_iova_range *range = &v->range;
1146         struct vdpa_device *vdpa = v->vdpa;
1147         const struct vdpa_config_ops *ops = vdpa->config;
1148
1149         if (ops->get_iova_range) {
1150                 *range = ops->get_iova_range(vdpa);
1151         } else if (v->domain && v->domain->geometry.force_aperture) {
1152                 range->first = v->domain->geometry.aperture_start;
1153                 range->last = v->domain->geometry.aperture_end;
1154         } else {
1155                 range->first = 0;
1156                 range->last = ULLONG_MAX;
1157         }
1158 }
1159
1160 static void vhost_vdpa_cleanup(struct vhost_vdpa *v)
1161 {
1162         struct vhost_vdpa_as *as;
1163         u32 asid;
1164
1165         vhost_dev_cleanup(&v->vdev);
1166         kfree(v->vdev.vqs);
1167
1168         for (asid = 0; asid < v->vdpa->nas; asid++) {
1169                 as = asid_to_as(v, asid);
1170                 if (as)
1171                         vhost_vdpa_remove_as(v, asid);
1172         }
1173 }
1174
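/*
 * Char device open: only one opener is allowed at a time (the "opened"
 * atomic acts as an exclusion flag). The device is reset, the vhost
 * device is initialized with one vhost_virtqueue per vDPA vq, and an
 * IOMMU domain is allocated unless the parent does its own DMA mapping.
 */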
1175 static int vhost_vdpa_open(struct inode *inode, struct file *filep)
1176 {
1177         struct vhost_vdpa *v;
1178         struct vhost_dev *dev;
1179         struct vhost_virtqueue **vqs;
1180         int r, opened;
1181         u32 i, nvqs;
1182
1183         v = container_of(inode->i_cdev, struct vhost_vdpa, cdev);
1184
1185         opened = atomic_cmpxchg(&v->opened, 0, 1);
1186         if (opened)
1187                 return -EBUSY;
1188
1189         nvqs = v->nvqs;
1190         r = vhost_vdpa_reset(v);
1191         if (r)
1192                 goto err;
1193
1194         vqs = kmalloc_array(nvqs, sizeof(*vqs), GFP_KERNEL);
1195         if (!vqs) {
1196                 r = -ENOMEM;
1197                 goto err;
1198         }
1199
1200         dev = &v->vdev;
1201         for (i = 0; i < nvqs; i++) {
1202                 vqs[i] = &v->vqs[i];
1203                 vqs[i]->handle_kick = handle_vq_kick;
1204         }
1205         vhost_dev_init(dev, vqs, nvqs, 0, 0, 0, false,
1206                        vhost_vdpa_process_iotlb_msg);
1207
1208         r = vhost_vdpa_alloc_domain(v);
1209         if (r)
1210                 goto err_alloc_domain;
1211
1212         vhost_vdpa_set_iova_range(v);
1213
1214         filep->private_data = v;
1215
1216         return 0;
1217
1218 err_alloc_domain:
1219         vhost_vdpa_cleanup(v);
1220 err:
1221         atomic_dec(&v->opened);
1222         return r;
1223 }
1224
1225 static void vhost_vdpa_clean_irq(struct vhost_vdpa *v)
1226 {
1227         u32 i;
1228
1229         for (i = 0; i < v->nvqs; i++)
1230                 vhost_vdpa_unsetup_vq_irq(v, i);
1231 }
1232
1233 static int vhost_vdpa_release(struct inode *inode, struct file *filep)
1234 {
1235         struct vhost_vdpa *v = filep->private_data;
1236         struct vhost_dev *d = &v->vdev;
1237
1238         mutex_lock(&d->mutex);
1239         filep->private_data = NULL;
1240         vhost_vdpa_clean_irq(v);
1241         vhost_vdpa_reset(v);
1242         vhost_dev_stop(&v->vdev);
1243         vhost_vdpa_free_domain(v);
1244         vhost_vdpa_config_put(v);
1245         vhost_vdpa_cleanup(v);
1246         mutex_unlock(&d->mutex);
1247
1248         atomic_dec(&v->opened);
1249         complete(&v->completion);
1250
1251         return 0;
1252 }
1253
1254 #ifdef CONFIG_MMU
1255 static vm_fault_t vhost_vdpa_fault(struct vm_fault *vmf)
1256 {
1257         struct vhost_vdpa *v = vmf->vma->vm_file->private_data;
1258         struct vdpa_device *vdpa = v->vdpa;
1259         const struct vdpa_config_ops *ops = vdpa->config;
1260         struct vdpa_notification_area notify;
1261         struct vm_area_struct *vma = vmf->vma;
1262         u16 index = vma->vm_pgoff;
1263
1264         notify = ops->get_vq_notification(vdpa, index);
1265
1266         vma->vm_page_prot = pgprot_noncached(vma->vm_page_prot);
1267         if (remap_pfn_range(vma, vmf->address & PAGE_MASK,
1268                             PFN_DOWN(notify.addr), PAGE_SIZE,
1269                             vma->vm_page_prot))
1270                 return VM_FAULT_SIGBUS;
1271
1272         return VM_FAULT_NOPAGE;
1273 }
1274
1275 static const struct vm_operations_struct vhost_vdpa_vm_ops = {
1276         .fault = vhost_vdpa_fault,
1277 };
1278
1279 static int vhost_vdpa_mmap(struct file *file, struct vm_area_struct *vma)
1280 {
1281         struct vhost_vdpa *v = vma->vm_file->private_data;
1282         struct vdpa_device *vdpa = v->vdpa;
1283         const struct vdpa_config_ops *ops = vdpa->config;
1284         struct vdpa_notification_area notify;
1285         unsigned long index = vma->vm_pgoff;
1286
1287         if (vma->vm_end - vma->vm_start != PAGE_SIZE)
1288                 return -EINVAL;
1289         if ((vma->vm_flags & VM_SHARED) == 0)
1290                 return -EINVAL;
1291         if (vma->vm_flags & VM_READ)
1292                 return -EINVAL;
1293         if (index > 65535)
1294                 return -EINVAL;
1295         if (!ops->get_vq_notification)
1296                 return -ENOTSUPP;
1297
1298         /* To be safe and easily modelled by userspace, we only
1299          * support a doorbell that sits on a page boundary and
1300          * does not share its page with other registers.
1301          */
1302         notify = ops->get_vq_notification(vdpa, index);
1303         if (notify.addr & (PAGE_SIZE - 1))
1304                 return -EINVAL;
1305         if (vma->vm_end - vma->vm_start != notify.size)
1306                 return -ENOTSUPP;
1307
1308         vma->vm_flags |= VM_IO | VM_PFNMAP | VM_DONTEXPAND | VM_DONTDUMP;
1309         vma->vm_ops = &vhost_vdpa_vm_ops;
1310         return 0;
1311 }
1312 #endif /* CONFIG_MMU */
1313
1314 static const struct file_operations vhost_vdpa_fops = {
1315         .owner          = THIS_MODULE,
1316         .open           = vhost_vdpa_open,
1317         .release        = vhost_vdpa_release,
1318         .write_iter     = vhost_vdpa_chr_write_iter,
1319         .unlocked_ioctl = vhost_vdpa_unlocked_ioctl,
1320 #ifdef CONFIG_MMU
1321         .mmap           = vhost_vdpa_mmap,
1322 #endif /* CONFIG_MMU */
1323         .compat_ioctl   = compat_ptr_ioctl,
1324 };
1325
1326 static void vhost_vdpa_release_dev(struct device *device)
1327 {
1328         struct vhost_vdpa *v =
1329                container_of(device, struct vhost_vdpa, dev);
1330
1331         ida_simple_remove(&vhost_vdpa_ida, v->minor);
1332         kfree(v->vqs);
1333         kfree(v);
1334 }
1335
1336 static int vhost_vdpa_probe(struct vdpa_device *vdpa)
1337 {
1338         const struct vdpa_config_ops *ops = vdpa->config;
1339         struct vhost_vdpa *v;
1340         int minor;
1341         int i, r;
1342
1343         /* We can't support a platform IOMMU device with more than one
1344          * group or address space (AS).
1345          */
1346         if (!ops->set_map && !ops->dma_map &&
1347             (vdpa->ngroups > 1 || vdpa->nas > 1))
1348                 return -EOPNOTSUPP;
1349
1350         v = kzalloc(sizeof(*v), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
1351         if (!v)
1352                 return -ENOMEM;
1353
1354         minor = ida_simple_get(&vhost_vdpa_ida, 0,
1355                                VHOST_VDPA_DEV_MAX, GFP_KERNEL);
1356         if (minor < 0) {
1357                 kfree(v);
1358                 return minor;
1359         }
1360
1361         atomic_set(&v->opened, 0);
1362         v->minor = minor;
1363         v->vdpa = vdpa;
1364         v->nvqs = vdpa->nvqs;
1365         v->virtio_id = ops->get_device_id(vdpa);
1366
1367         device_initialize(&v->dev);
1368         v->dev.release = vhost_vdpa_release_dev;
1369         v->dev.parent = &vdpa->dev;
1370         v->dev.devt = MKDEV(MAJOR(vhost_vdpa_major), minor);
1371         v->vqs = kmalloc_array(v->nvqs, sizeof(struct vhost_virtqueue),
1372                                GFP_KERNEL);
1373         if (!v->vqs) {
1374                 r = -ENOMEM;
1375                 goto err;
1376         }
1377
1378         r = dev_set_name(&v->dev, "vhost-vdpa-%u", minor);
1379         if (r)
1380                 goto err;
1381
1382         cdev_init(&v->cdev, &vhost_vdpa_fops);
1383         v->cdev.owner = THIS_MODULE;
1384
1385         r = cdev_device_add(&v->cdev, &v->dev);
1386         if (r)
1387                 goto err;
1388
1389         init_completion(&v->completion);
1390         vdpa_set_drvdata(vdpa, v);
1391
1392         for (i = 0; i < VHOST_VDPA_IOTLB_BUCKETS; i++)
1393                 INIT_HLIST_HEAD(&v->as[i]);
1394
1395         return 0;
1396
1397 err:
1398         put_device(&v->dev);
1399         ida_simple_remove(&vhost_vdpa_ida, v->minor);
1400         return r;
1401 }
1402
1403 static void vhost_vdpa_remove(struct vdpa_device *vdpa)
1404 {
1405         struct vhost_vdpa *v = vdpa_get_drvdata(vdpa);
1406         int opened;
1407
1408         cdev_device_del(&v->cdev, &v->dev);
1409
1410         do {
1411                 opened = atomic_cmpxchg(&v->opened, 0, 1);
1412                 if (!opened)
1413                         break;
1414                 wait_for_completion(&v->completion);
1415         } while (1);
1416
1417         put_device(&v->dev);
1418 }
1419
1420 static struct vdpa_driver vhost_vdpa_driver = {
1421         .driver = {
1422                 .name   = "vhost_vdpa",
1423         },
1424         .probe  = vhost_vdpa_probe,
1425         .remove = vhost_vdpa_remove,
1426 };
1427
1428 static int __init vhost_vdpa_init(void)
1429 {
1430         int r;
1431
1432         r = alloc_chrdev_region(&vhost_vdpa_major, 0, VHOST_VDPA_DEV_MAX,
1433                                 "vhost-vdpa");
1434         if (r)
1435                 goto err_alloc_chrdev;
1436
1437         r = vdpa_register_driver(&vhost_vdpa_driver);
1438         if (r)
1439                 goto err_vdpa_register_driver;
1440
1441         return 0;
1442
1443 err_vdpa_register_driver:
1444         unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1445 err_alloc_chrdev:
1446         return r;
1447 }
1448 module_init(vhost_vdpa_init);
1449
1450 static void __exit vhost_vdpa_exit(void)
1451 {
1452         vdpa_unregister_driver(&vhost_vdpa_driver);
1453         unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1454 }
1455 module_exit(vhost_vdpa_exit);
1456
1457 MODULE_VERSION("0.0.1");
1458 MODULE_LICENSE("GPL v2");
1459 MODULE_AUTHOR("Intel Corporation");
1460 MODULE_DESCRIPTION("vDPA-based vhost backend for virtio");