vhost: take worker or vq instead of dev for queueing
[platform/kernel/linux-rpi.git] / drivers/vhost/vhost.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (C) 2009 Red Hat, Inc.
3  * Copyright (C) 2006 Rusty Russell IBM Corporation
4  *
5  * Author: Michael S. Tsirkin <mst@redhat.com>
6  *
7  * Inspiration, some code, and most witty comments come from
8  * Documentation/virtual/lguest/lguest.c, by Rusty Russell
9  *
10  * Generic code for virtio server in host kernel.
11  */
12
13 #include <linux/eventfd.h>
14 #include <linux/vhost.h>
15 #include <linux/uio.h>
16 #include <linux/mm.h>
17 #include <linux/miscdevice.h>
18 #include <linux/mutex.h>
19 #include <linux/poll.h>
20 #include <linux/file.h>
21 #include <linux/highmem.h>
22 #include <linux/slab.h>
23 #include <linux/vmalloc.h>
24 #include <linux/kthread.h>
25 #include <linux/module.h>
26 #include <linux/sort.h>
27 #include <linux/sched/mm.h>
28 #include <linux/sched/signal.h>
29 #include <linux/sched/vhost_task.h>
30 #include <linux/interval_tree_generic.h>
31 #include <linux/nospec.h>
32 #include <linux/kcov.h>
33
34 #include "vhost.h"
35
36 static ushort max_mem_regions = 64;
37 module_param(max_mem_regions, ushort, 0444);
38 MODULE_PARM_DESC(max_mem_regions,
39         "Maximum number of memory regions in memory map. (default: 64)");
40 static int max_iotlb_entries = 2048;
41 module_param(max_iotlb_entries, int, 0444);
42 MODULE_PARM_DESC(max_iotlb_entries,
43         "Maximum number of iotlb entries. (default: 2048)");
44
45 enum {
46         VHOST_MEMORY_F_LOG = 0x1,
47 };
48
49 #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
50 #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])
51
52 #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
53 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
54 {
55         vq->user_be = !virtio_legacy_is_little_endian();
56 }
57
58 static void vhost_enable_cross_endian_big(struct vhost_virtqueue *vq)
59 {
60         vq->user_be = true;
61 }
62
63 static void vhost_enable_cross_endian_little(struct vhost_virtqueue *vq)
64 {
65         vq->user_be = false;
66 }
67
68 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
69 {
70         struct vhost_vring_state s;
71
72         if (vq->private_data)
73                 return -EBUSY;
74
75         if (copy_from_user(&s, argp, sizeof(s)))
76                 return -EFAULT;
77
78         if (s.num != VHOST_VRING_LITTLE_ENDIAN &&
79             s.num != VHOST_VRING_BIG_ENDIAN)
80                 return -EINVAL;
81
82         if (s.num == VHOST_VRING_BIG_ENDIAN)
83                 vhost_enable_cross_endian_big(vq);
84         else
85                 vhost_enable_cross_endian_little(vq);
86
87         return 0;
88 }
89
90 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
91                                    int __user *argp)
92 {
93         struct vhost_vring_state s = {
94                 .index = idx,
95                 .num = vq->user_be
96         };
97
98         if (copy_to_user(argp, &s, sizeof(s)))
99                 return -EFAULT;
100
101         return 0;
102 }
103
104 static void vhost_init_is_le(struct vhost_virtqueue *vq)
105 {
106         /* Note for legacy virtio: user_be is initialized at reset time
107          * according to the host endianness. If userspace does not set an
108          * explicit endianness, the default behavior is native endian, as
109          * expected by legacy virtio.
110          */
111         vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) || !vq->user_be;
112 }
113 #else
114 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
115 {
116 }
117
118 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
119 {
120         return -ENOIOCTLCMD;
121 }
122
123 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
124                                    int __user *argp)
125 {
126         return -ENOIOCTLCMD;
127 }
128
129 static void vhost_init_is_le(struct vhost_virtqueue *vq)
130 {
131         vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1)
132                 || virtio_legacy_is_little_endian();
133 }
134 #endif /* CONFIG_VHOST_CROSS_ENDIAN_LEGACY */
135
136 static void vhost_reset_is_le(struct vhost_virtqueue *vq)
137 {
138         vhost_init_is_le(vq);
139 }
140
141 struct vhost_flush_struct {
142         struct vhost_work work;
143         struct completion wait_event;
144 };
145
146 static void vhost_flush_work(struct vhost_work *work)
147 {
148         struct vhost_flush_struct *s;
149
150         s = container_of(work, struct vhost_flush_struct, work);
151         complete(&s->wait_event);
152 }
153
154 static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
155                             poll_table *pt)
156 {
157         struct vhost_poll *poll;
158
159         poll = container_of(pt, struct vhost_poll, table);
160         poll->wqh = wqh;
161         add_wait_queue(wqh, &poll->wait);
162 }
163
164 static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync,
165                              void *key)
166 {
167         struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
168         struct vhost_work *work = &poll->work;
169
170         if (!(key_to_poll(key) & poll->mask))
171                 return 0;
172
173         if (!poll->dev->use_worker)
174                 work->fn(work);
175         else
176                 vhost_poll_queue(poll);
177
178         return 0;
179 }
180
181 void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
182 {
183         clear_bit(VHOST_WORK_QUEUED, &work->flags);
184         work->fn = fn;
185 }
186 EXPORT_SYMBOL_GPL(vhost_work_init);
187
188 /* Init poll structure */
189 void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
190                      __poll_t mask, struct vhost_dev *dev)
191 {
192         init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
193         init_poll_funcptr(&poll->table, vhost_poll_func);
194         poll->mask = mask;
195         poll->dev = dev;
196         poll->wqh = NULL;
197
198         vhost_work_init(&poll->work, fn);
199 }
200 EXPORT_SYMBOL_GPL(vhost_poll_init);
201
202 /* Start polling a file. We add ourselves to the file's wait queue. The caller
203  * must keep a reference to the file until after vhost_poll_stop is called. */
204 int vhost_poll_start(struct vhost_poll *poll, struct file *file)
205 {
206         __poll_t mask;
207
208         if (poll->wqh)
209                 return 0;
210
211         mask = vfs_poll(file, &poll->table);
212         if (mask)
213                 vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask));
214         if (mask & EPOLLERR) {
215                 vhost_poll_stop(poll);
216                 return -EINVAL;
217         }
218
219         return 0;
220 }
221 EXPORT_SYMBOL_GPL(vhost_poll_start);
222
223 /* Stop polling a file. After this function returns, it becomes safe to drop the
224  * file reference. You must also flush afterwards. */
225 void vhost_poll_stop(struct vhost_poll *poll)
226 {
227         if (poll->wqh) {
228                 remove_wait_queue(poll->wqh, &poll->wait);
229                 poll->wqh = NULL;
230         }
231 }
232 EXPORT_SYMBOL_GPL(vhost_poll_stop);
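
/*
 * A minimal usage sketch (not part of this file's API): the start/stop
 * sequence a backend might drive around the poll helpers above, assuming the
 * vq's poll was already initialized via vhost_dev_init()/vhost_poll_init().
 * The example_* names are hypothetical.
 */
static inline int example_vq_start_poll(struct vhost_virtqueue *vq,
					struct file *kick)
{
	/* Watch the kick eventfd; handle_kick will be queued on events. */
	return vhost_poll_start(&vq->poll, kick);
}

static inline void example_vq_stop_poll(struct vhost_virtqueue *vq)
{
	/* Stop watching the file, then flush any work already queued. */
	vhost_poll_stop(&vq->poll);
	vhost_dev_flush(vq->dev);
}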
233
234 static bool vhost_worker_queue(struct vhost_worker *worker,
235                                struct vhost_work *work)
236 {
237         if (!worker)
238                 return false;
239         /*
240          * vsock can queue while we do a VHOST_SET_OWNER, so we have a smp_wmb
241          * when setting up the worker. We don't have a smp_rmb here because
242          * test_and_set_bit gives us a mb already.
243          */
244         if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
245                 /* We can only add the work to the list after we're
246                  * sure it was not in the list.
247                  * test_and_set_bit() implies a memory barrier.
248                  */
249                 llist_add(&work->node, &worker->work_list);
250                 vhost_task_wake(worker->vtsk);
251         }
252
253         return true;
254 }
255
256 bool vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
257 {
258         return vhost_worker_queue(dev->worker, work);
259 }
260 EXPORT_SYMBOL_GPL(vhost_work_queue);
261
262 bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
263 {
264         return vhost_worker_queue(vq->worker, work);
265 }
266 EXPORT_SYMBOL_GPL(vhost_vq_work_queue);
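
/*
 * A minimal sketch contrasting the two queueing entry points above, per the
 * "take worker or vq instead of dev" change: device-scoped work goes to
 * dev->worker, vq-scoped work to vq->worker. The example_* helpers are
 * hypothetical; the work item is assumed to have been set up with
 * vhost_work_init().
 */
static inline bool example_queue_dev_work(struct vhost_dev *dev,
					  struct vhost_work *work)
{
	/* Runs on the device's default worker, if one was created. */
	return vhost_work_queue(dev, work);
}

static inline bool example_queue_vq_work(struct vhost_virtqueue *vq,
					 struct vhost_work *work)
{
	/* Runs on the worker currently bound to this virtqueue. */
	return vhost_vq_work_queue(vq, work);
}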
267
268 void vhost_dev_flush(struct vhost_dev *dev)
269 {
270         struct vhost_flush_struct flush;
271
272         init_completion(&flush.wait_event);
273         vhost_work_init(&flush.work, vhost_flush_work);
274
275         if (vhost_work_queue(dev, &flush.work))
276                 wait_for_completion(&flush.wait_event);
277 }
278 EXPORT_SYMBOL_GPL(vhost_dev_flush);
279
280 /* A lockless hint for busy polling code to exit the loop */
281 bool vhost_vq_has_work(struct vhost_virtqueue *vq)
282 {
283         return !llist_empty(&vq->worker->work_list);
284 }
285 EXPORT_SYMBOL_GPL(vhost_vq_has_work);
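
/*
 * A sketch of how a busy-polling loop might consume the hint above, assuming
 * a driver-specific "has data" predicate and a fixed spin bound. Both are
 * hypothetical; real backends bound the loop with vq->busyloop_timeout
 * instead.
 */
static inline void example_busy_poll(struct vhost_virtqueue *vq,
				     bool (*backend_has_data)(struct vhost_virtqueue *vq))
{
	unsigned int spins = 1000;	/* hypothetical bound */

	while (spins--) {
		/* Stop spinning as soon as vhost work is pending. */
		if (vhost_vq_has_work(vq) || backend_has_data(vq))
			break;
		cpu_relax();
	}
}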
286
287 void vhost_poll_queue(struct vhost_poll *poll)
288 {
289         vhost_work_queue(poll->dev, &poll->work);
290 }
291 EXPORT_SYMBOL_GPL(vhost_poll_queue);
292
293 static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq)
294 {
295         int j;
296
297         for (j = 0; j < VHOST_NUM_ADDRS; j++)
298                 vq->meta_iotlb[j] = NULL;
299 }
300
301 static void vhost_vq_meta_reset(struct vhost_dev *d)
302 {
303         int i;
304
305         for (i = 0; i < d->nvqs; ++i)
306                 __vhost_vq_meta_reset(d->vqs[i]);
307 }
308
309 static void vhost_vring_call_reset(struct vhost_vring_call *call_ctx)
310 {
311         call_ctx->ctx = NULL;
312         memset(&call_ctx->producer, 0x0, sizeof(struct irq_bypass_producer));
313 }
314
315 bool vhost_vq_is_setup(struct vhost_virtqueue *vq)
316 {
317         return vq->avail && vq->desc && vq->used && vhost_vq_access_ok(vq);
318 }
319 EXPORT_SYMBOL_GPL(vhost_vq_is_setup);
320
321 static void vhost_vq_reset(struct vhost_dev *dev,
322                            struct vhost_virtqueue *vq)
323 {
324         vq->num = 1;
325         vq->desc = NULL;
326         vq->avail = NULL;
327         vq->used = NULL;
328         vq->last_avail_idx = 0;
329         vq->avail_idx = 0;
330         vq->last_used_idx = 0;
331         vq->signalled_used = 0;
332         vq->signalled_used_valid = false;
333         vq->used_flags = 0;
334         vq->log_used = false;
335         vq->log_addr = -1ull;
336         vq->private_data = NULL;
337         vq->acked_features = 0;
338         vq->acked_backend_features = 0;
339         vq->log_base = NULL;
340         vq->error_ctx = NULL;
341         vq->kick = NULL;
342         vq->log_ctx = NULL;
343         vhost_disable_cross_endian(vq);
344         vhost_reset_is_le(vq);
345         vq->busyloop_timeout = 0;
346         vq->umem = NULL;
347         vq->iotlb = NULL;
348         vq->worker = NULL;
349         vhost_vring_call_reset(&vq->call_ctx);
350         __vhost_vq_meta_reset(vq);
351 }
352
353 static bool vhost_worker(void *data)
354 {
355         struct vhost_worker *worker = data;
356         struct vhost_work *work, *work_next;
357         struct llist_node *node;
358
359         node = llist_del_all(&worker->work_list);
360         if (node) {
361                 __set_current_state(TASK_RUNNING);
362
363                 node = llist_reverse_order(node);
364                 /* make sure flag is seen after deletion */
365                 smp_wmb();
366                 llist_for_each_entry_safe(work, work_next, node, node) {
367                         clear_bit(VHOST_WORK_QUEUED, &work->flags);
368                         kcov_remote_start_common(worker->kcov_handle);
369                         work->fn(work);
370                         kcov_remote_stop();
371                         cond_resched();
372                 }
373         }
374
375         return !!node;
376 }
377
378 static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
379 {
380         kfree(vq->indirect);
381         vq->indirect = NULL;
382         kfree(vq->log);
383         vq->log = NULL;
384         kfree(vq->heads);
385         vq->heads = NULL;
386 }
387
388 /* Helper to allocate iovec buffers for all vqs. */
389 static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
390 {
391         struct vhost_virtqueue *vq;
392         int i;
393
394         for (i = 0; i < dev->nvqs; ++i) {
395                 vq = dev->vqs[i];
396                 vq->indirect = kmalloc_array(UIO_MAXIOV,
397                                              sizeof(*vq->indirect),
398                                              GFP_KERNEL);
399                 vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log),
400                                         GFP_KERNEL);
401                 vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
402                                           GFP_KERNEL);
403                 if (!vq->indirect || !vq->log || !vq->heads)
404                         goto err_nomem;
405         }
406         return 0;
407
408 err_nomem:
409         for (; i >= 0; --i)
410                 vhost_vq_free_iovecs(dev->vqs[i]);
411         return -ENOMEM;
412 }
413
414 static void vhost_dev_free_iovecs(struct vhost_dev *dev)
415 {
416         int i;
417
418         for (i = 0; i < dev->nvqs; ++i)
419                 vhost_vq_free_iovecs(dev->vqs[i]);
420 }
421
422 bool vhost_exceeds_weight(struct vhost_virtqueue *vq,
423                           int pkts, int total_len)
424 {
425         struct vhost_dev *dev = vq->dev;
426
427         if ((dev->byte_weight && total_len >= dev->byte_weight) ||
428             pkts >= dev->weight) {
429                 vhost_poll_queue(&vq->poll);
430                 return true;
431         }
432
433         return false;
434 }
435 EXPORT_SYMBOL_GPL(vhost_exceeds_weight);
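
/*
 * A sketch of the loop shape that pairs with the weight check above: process
 * buffers until the packet or byte budget passed to vhost_dev_init() is hit,
 * at which point vhost_exceeds_weight() has already re-queued the poll work
 * and the handler should return. process_one_buf() is a hypothetical
 * callback returning the number of bytes handled, 0 when the ring is empty,
 * or a negative error.
 */
static inline void example_handle_kick_loop(struct vhost_virtqueue *vq,
					    int (*process_one_buf)(struct vhost_virtqueue *vq))
{
	int len, pkts = 0, total_len = 0;

	do {
		len = process_one_buf(vq);
		if (len <= 0)
			break;
		total_len += len;
	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
}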
436
437 static size_t vhost_get_avail_size(struct vhost_virtqueue *vq,
438                                    unsigned int num)
439 {
440         size_t event __maybe_unused =
441                vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
442
443         return size_add(struct_size(vq->avail, ring, num), event);
444 }
445
446 static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
447                                   unsigned int num)
448 {
449         size_t event __maybe_unused =
450                vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
451
452         return size_add(struct_size(vq->used, ring, num), event);
453 }
454
455 static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
456                                   unsigned int num)
457 {
458         return sizeof(*vq->desc) * num;
459 }
460
461 void vhost_dev_init(struct vhost_dev *dev,
462                     struct vhost_virtqueue **vqs, int nvqs,
463                     int iov_limit, int weight, int byte_weight,
464                     bool use_worker,
465                     int (*msg_handler)(struct vhost_dev *dev, u32 asid,
466                                        struct vhost_iotlb_msg *msg))
467 {
468         struct vhost_virtqueue *vq;
469         int i;
470
471         dev->vqs = vqs;
472         dev->nvqs = nvqs;
473         mutex_init(&dev->mutex);
474         dev->log_ctx = NULL;
475         dev->umem = NULL;
476         dev->iotlb = NULL;
477         dev->mm = NULL;
478         dev->worker = NULL;
479         dev->iov_limit = iov_limit;
480         dev->weight = weight;
481         dev->byte_weight = byte_weight;
482         dev->use_worker = use_worker;
483         dev->msg_handler = msg_handler;
484         init_waitqueue_head(&dev->wait);
485         INIT_LIST_HEAD(&dev->read_list);
486         INIT_LIST_HEAD(&dev->pending_list);
487         spin_lock_init(&dev->iotlb_lock);
488
489
490         for (i = 0; i < dev->nvqs; ++i) {
491                 vq = dev->vqs[i];
492                 vq->log = NULL;
493                 vq->indirect = NULL;
494                 vq->heads = NULL;
495                 vq->dev = dev;
496                 mutex_init(&vq->mutex);
497                 vhost_vq_reset(dev, vq);
498                 if (vq->handle_kick)
499                         vhost_poll_init(&vq->poll, vq->handle_kick,
500                                         EPOLLIN, dev);
501         }
502 }
503 EXPORT_SYMBOL_GPL(vhost_dev_init);
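
/*
 * A sketch of a backend's open path, with hypothetical weight values and a
 * single kick handler shared by all queues. Real backends (net, scsi, vsock)
 * choose their own iov_limit/weight/byte_weight and per-queue handlers.
 */
static inline void example_backend_init(struct vhost_dev *dev,
					struct vhost_virtqueue **vqs,
					int nvqs,
					vhost_work_fn_t handle_kick)
{
	int i;

	/* handle_kick must be set before vhost_dev_init() wires up polling. */
	for (i = 0; i < nvqs; i++)
		vqs[i]->handle_kick = handle_kick;

	/* use_worker = true: kick handling runs on a vhost_task worker. */
	vhost_dev_init(dev, vqs, nvqs, UIO_MAXIOV,
		       64 /* weight, hypothetical */,
		       64 * 1024 /* byte_weight, hypothetical */,
		       true, NULL);
}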
504
505 /* Caller should have device mutex */
506 long vhost_dev_check_owner(struct vhost_dev *dev)
507 {
508         /* Are you the owner? If not, I don't think you mean to do that */
509         return dev->mm == current->mm ? 0 : -EPERM;
510 }
511 EXPORT_SYMBOL_GPL(vhost_dev_check_owner);
512
513 /* Caller should have device mutex */
514 bool vhost_dev_has_owner(struct vhost_dev *dev)
515 {
516         return dev->mm;
517 }
518 EXPORT_SYMBOL_GPL(vhost_dev_has_owner);
519
520 static void vhost_attach_mm(struct vhost_dev *dev)
521 {
522         /* No owner, become one */
523         if (dev->use_worker) {
524                 dev->mm = get_task_mm(current);
525         } else {
526                 /* vDPA devices do not use a worker thread, so there is
527                  * no need to hold the address space for mm. This helps
528                  * avoid a deadlock in the case of mmap(), which may hold
529                  * the refcnt of the file and depend on the release
530                  * method to remove the vma.
531                  */
532                 dev->mm = current->mm;
533                 mmgrab(dev->mm);
534         }
535 }
536
537 static void vhost_detach_mm(struct vhost_dev *dev)
538 {
539         if (!dev->mm)
540                 return;
541
542         if (dev->use_worker)
543                 mmput(dev->mm);
544         else
545                 mmdrop(dev->mm);
546
547         dev->mm = NULL;
548 }
549
550 static void vhost_worker_free(struct vhost_dev *dev)
551 {
552         if (!dev->worker)
553                 return;
554
555         WARN_ON(!llist_empty(&dev->worker->work_list));
556         vhost_task_stop(dev->worker->vtsk);
557         kfree(dev->worker);
558         dev->worker = NULL;
559 }
560
561 static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
562 {
563         struct vhost_worker *worker;
564         struct vhost_task *vtsk;
565         char name[TASK_COMM_LEN];
566
567         worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
568         if (!worker)
569                 return NULL;
570
571         snprintf(name, sizeof(name), "vhost-%d", current->pid);
572
573         vtsk = vhost_task_create(vhost_worker, worker, name);
574         if (!vtsk)
575                 goto free_worker;
576
577         init_llist_head(&worker->work_list);
578         worker->kcov_handle = kcov_common_handle();
579         worker->vtsk = vtsk;
580         /*
581          * vsock can already try to queue so make sure llist and vtsk are both
582          * set before vhost_work_queue sees dev->worker is set.
583          */
584         smp_wmb();
585         dev->worker = worker;
586
587         vhost_task_start(vtsk);
588         return worker;
589
590 free_worker:
591         kfree(worker);
592         return NULL;
593 }
594
595 /* Caller should have device mutex */
596 long vhost_dev_set_owner(struct vhost_dev *dev)
597 {
598         struct vhost_worker *worker;
599         int err, i;
600
601         /* Is there an owner already? */
602         if (vhost_dev_has_owner(dev)) {
603                 err = -EBUSY;
604                 goto err_mm;
605         }
606
607         vhost_attach_mm(dev);
608
609         err = vhost_dev_alloc_iovecs(dev);
610         if (err)
611                 goto err_iovecs;
612
613         if (dev->use_worker) {
614                 /*
615                  * This should be done last, because vsock can queue work
616                  * before VHOST_SET_OWNER. Creating the worker last
617                  * simplifies the failure path below, since we don't have
618                  * to worry about vsock queueing while we free the worker.
619                  */
620                 worker = vhost_worker_create(dev);
621                 if (!worker) {
622                         err = -ENOMEM;
623                         goto err_worker;
624                 }
625
626                 for (i = 0; i < dev->nvqs; i++)
627                         dev->vqs[i]->worker = worker;
628         }
629
630         return 0;
631
632 err_worker:
633         vhost_dev_free_iovecs(dev);
634 err_iovecs:
635         vhost_detach_mm(dev);
636 err_mm:
637         return err;
638 }
639 EXPORT_SYMBOL_GPL(vhost_dev_set_owner);
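
/*
 * A sketch of the VHOST_SET_OWNER ioctl path: as the comment above notes,
 * ownership is taken under the device mutex. The example_* name is
 * hypothetical.
 */
static inline long example_set_owner(struct vhost_dev *dev)
{
	long r;

	mutex_lock(&dev->mutex);
	r = vhost_dev_set_owner(dev);
	mutex_unlock(&dev->mutex);
	return r;
}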
640
641 static struct vhost_iotlb *iotlb_alloc(void)
642 {
643         return vhost_iotlb_alloc(max_iotlb_entries,
644                                  VHOST_IOTLB_FLAG_RETIRE);
645 }
646
647 struct vhost_iotlb *vhost_dev_reset_owner_prepare(void)
648 {
649         return iotlb_alloc();
650 }
651 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);
652
653 /* Caller should have device mutex */
654 void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem)
655 {
656         int i;
657
658         vhost_dev_cleanup(dev);
659
660         dev->umem = umem;
661         /* We don't need VQ locks below since vhost_dev_cleanup makes sure
662          * VQs aren't running.
663          */
664         for (i = 0; i < dev->nvqs; ++i)
665                 dev->vqs[i]->umem = umem;
666 }
667 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);
668
669 void vhost_dev_stop(struct vhost_dev *dev)
670 {
671         int i;
672
673         for (i = 0; i < dev->nvqs; ++i) {
674                 if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick)
675                         vhost_poll_stop(&dev->vqs[i]->poll);
676         }
677
678         vhost_dev_flush(dev);
679 }
680 EXPORT_SYMBOL_GPL(vhost_dev_stop);
681
682 void vhost_clear_msg(struct vhost_dev *dev)
683 {
684         struct vhost_msg_node *node, *n;
685
686         spin_lock(&dev->iotlb_lock);
687
688         list_for_each_entry_safe(node, n, &dev->read_list, node) {
689                 list_del(&node->node);
690                 kfree(node);
691         }
692
693         list_for_each_entry_safe(node, n, &dev->pending_list, node) {
694                 list_del(&node->node);
695                 kfree(node);
696         }
697
698         spin_unlock(&dev->iotlb_lock);
699 }
700 EXPORT_SYMBOL_GPL(vhost_clear_msg);
701
702 void vhost_dev_cleanup(struct vhost_dev *dev)
703 {
704         int i;
705
706         for (i = 0; i < dev->nvqs; ++i) {
707                 if (dev->vqs[i]->error_ctx)
708                         eventfd_ctx_put(dev->vqs[i]->error_ctx);
709                 if (dev->vqs[i]->kick)
710                         fput(dev->vqs[i]->kick);
711                 if (dev->vqs[i]->call_ctx.ctx)
712                         eventfd_ctx_put(dev->vqs[i]->call_ctx.ctx);
713                 vhost_vq_reset(dev, dev->vqs[i]);
714         }
715         vhost_dev_free_iovecs(dev);
716         if (dev->log_ctx)
717                 eventfd_ctx_put(dev->log_ctx);
718         dev->log_ctx = NULL;
719         /* No one will access memory at this point */
720         vhost_iotlb_free(dev->umem);
721         dev->umem = NULL;
722         vhost_iotlb_free(dev->iotlb);
723         dev->iotlb = NULL;
724         vhost_clear_msg(dev);
725         wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
726         vhost_worker_free(dev);
727         vhost_detach_mm(dev);
728 }
729 EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
730
731 static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
732 {
733         u64 a = addr / VHOST_PAGE_SIZE / 8;
734
735         /* Make sure 64 bit math will not overflow. */
736         if (a > ULONG_MAX - (unsigned long)log_base ||
737             a + (unsigned long)log_base > ULONG_MAX)
738                 return false;
739
740         return access_ok(log_base + a,
741                          (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
742 }
743
744 /* Make sure 64 bit math will not overflow. */
745 static bool vhost_overflow(u64 uaddr, u64 size)
746 {
747         if (uaddr > ULONG_MAX || size > ULONG_MAX)
748                 return true;
749
750         if (!size)
751                 return false;
752
753         return uaddr > ULONG_MAX - size + 1;
754 }
755
756 /* Caller should have vq mutex and device mutex. */
757 static bool vq_memory_access_ok(void __user *log_base, struct vhost_iotlb *umem,
758                                 int log_all)
759 {
760         struct vhost_iotlb_map *map;
761
762         if (!umem)
763                 return false;
764
765         list_for_each_entry(map, &umem->list, link) {
766                 unsigned long a = map->addr;
767
768                 if (vhost_overflow(map->addr, map->size))
769                         return false;
770
771
772                 if (!access_ok((void __user *)a, map->size))
773                         return false;
774                 else if (log_all && !log_access_ok(log_base,
775                                                    map->start,
776                                                    map->size))
777                         return false;
778         }
779         return true;
780 }
781
782 static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
783                                                u64 addr, unsigned int size,
784                                                int type)
785 {
786         const struct vhost_iotlb_map *map = vq->meta_iotlb[type];
787
788         if (!map)
789                 return NULL;
790
791         return (void __user *)(uintptr_t)(map->addr + addr - map->start);
792 }
793
794 /* Can we switch to this memory table? */
795 /* Caller should have device mutex but not vq mutex */
796 static bool memory_access_ok(struct vhost_dev *d, struct vhost_iotlb *umem,
797                              int log_all)
798 {
799         int i;
800
801         for (i = 0; i < d->nvqs; ++i) {
802                 bool ok;
803                 bool log;
804
805                 mutex_lock(&d->vqs[i]->mutex);
806                 log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
807                 /* If the ring is inactive, we will check when it's enabled. */
808                 if (d->vqs[i]->private_data)
809                         ok = vq_memory_access_ok(d->vqs[i]->log_base,
810                                                  umem, log);
811                 else
812                         ok = true;
813                 mutex_unlock(&d->vqs[i]->mutex);
814                 if (!ok)
815                         return false;
816         }
817         return true;
818 }
819
820 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
821                           struct iovec iov[], int iov_size, int access);
822
823 static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
824                               const void *from, unsigned size)
825 {
826         int ret;
827
828         if (!vq->iotlb)
829                 return __copy_to_user(to, from, size);
830         else {
831                 /* This function should be called after iotlb
832                  * prefetch, which means we're sure that all vq
833                  * memory could be accessed through the iotlb. So -EAGAIN
834                  * should not happen in this case.
835                  */
836                 struct iov_iter t;
837                 void __user *uaddr = vhost_vq_meta_fetch(vq,
838                                      (u64)(uintptr_t)to, size,
839                                      VHOST_ADDR_USED);
840
841                 if (uaddr)
842                         return __copy_to_user(uaddr, from, size);
843
844                 ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
845                                      ARRAY_SIZE(vq->iotlb_iov),
846                                      VHOST_ACCESS_WO);
847                 if (ret < 0)
848                         goto out;
849                 iov_iter_init(&t, ITER_DEST, vq->iotlb_iov, ret, size);
850                 ret = copy_to_iter(from, size, &t);
851                 if (ret == size)
852                         ret = 0;
853         }
854 out:
855         return ret;
856 }
857
858 static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
859                                 void __user *from, unsigned size)
860 {
861         int ret;
862
863         if (!vq->iotlb)
864                 return __copy_from_user(to, from, size);
865         else {
866                 /* This function should be called after iotlb
867                  * prefetch, which means we're sure that the vq
868                  * memory could be accessed through the iotlb. So -EAGAIN
869                  * should not happen in this case.
870                  */
871                 void __user *uaddr = vhost_vq_meta_fetch(vq,
872                                      (u64)(uintptr_t)from, size,
873                                      VHOST_ADDR_DESC);
874                 struct iov_iter f;
875
876                 if (uaddr)
877                         return __copy_from_user(to, uaddr, size);
878
879                 ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
880                                      ARRAY_SIZE(vq->iotlb_iov),
881                                      VHOST_ACCESS_RO);
882                 if (ret < 0) {
883                         vq_err(vq, "IOTLB translation failure: uaddr "
884                                "%p size 0x%llx\n", from,
885                                (unsigned long long) size);
886                         goto out;
887                 }
888                 iov_iter_init(&f, ITER_SOURCE, vq->iotlb_iov, ret, size);
889                 ret = copy_from_iter(to, size, &f);
890                 if (ret == size)
891                         ret = 0;
892         }
893
894 out:
895         return ret;
896 }
897
898 static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq,
899                                           void __user *addr, unsigned int size,
900                                           int type)
901 {
902         int ret;
903
904         ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
905                              ARRAY_SIZE(vq->iotlb_iov),
906                              VHOST_ACCESS_RO);
907         if (ret < 0) {
908                 vq_err(vq, "IOTLB translation failure: uaddr "
909                         "%p size 0x%llx\n", addr,
910                         (unsigned long long) size);
911                 return NULL;
912         }
913
914         if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
915                 vq_err(vq, "Non atomic userspace memory access: uaddr "
916                         "%p size 0x%llx\n", addr,
917                         (unsigned long long) size);
918                 return NULL;
919         }
920
921         return vq->iotlb_iov[0].iov_base;
922 }
923
924 /* This function should be called after iotlb
925  * prefetch, which means we're sure that the vq
926  * memory could be accessed through the iotlb. So -EAGAIN
927  * should not happen in this case.
928  */
929 static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
930                                             void __user *addr, unsigned int size,
931                                             int type)
932 {
933         void __user *uaddr = vhost_vq_meta_fetch(vq,
934                              (u64)(uintptr_t)addr, size, type);
935         if (uaddr)
936                 return uaddr;
937
938         return __vhost_get_user_slow(vq, addr, size, type);
939 }
940
941 #define vhost_put_user(vq, x, ptr)              \
942 ({ \
943         int ret; \
944         if (!vq->iotlb) { \
945                 ret = __put_user(x, ptr); \
946         } else { \
947                 __typeof__(ptr) to = \
948                         (__typeof__(ptr)) __vhost_get_user(vq, ptr,     \
949                                           sizeof(*ptr), VHOST_ADDR_USED); \
950                 if (to != NULL) \
951                         ret = __put_user(x, to); \
952                 else \
953                         ret = -EFAULT;  \
954         } \
955         ret; \
956 })
957
958 static inline int vhost_put_avail_event(struct vhost_virtqueue *vq)
959 {
960         return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
961                               vhost_avail_event(vq));
962 }
963
964 static inline int vhost_put_used(struct vhost_virtqueue *vq,
965                                  struct vring_used_elem *head, int idx,
966                                  int count)
967 {
968         return vhost_copy_to_user(vq, vq->used->ring + idx, head,
969                                   count * sizeof(*head));
970 }
971
972 static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
973
974 {
975         return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
976                               &vq->used->flags);
977 }
978
979 static inline int vhost_put_used_idx(struct vhost_virtqueue *vq)
980
981 {
982         return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
983                               &vq->used->idx);
984 }
985
986 #define vhost_get_user(vq, x, ptr, type)                \
987 ({ \
988         int ret; \
989         if (!vq->iotlb) { \
990                 ret = __get_user(x, ptr); \
991         } else { \
992                 __typeof__(ptr) from = \
993                         (__typeof__(ptr)) __vhost_get_user(vq, ptr, \
994                                                            sizeof(*ptr), \
995                                                            type); \
996                 if (from != NULL) \
997                         ret = __get_user(x, from); \
998                 else \
999                         ret = -EFAULT; \
1000         } \
1001         ret; \
1002 })
1003
1004 #define vhost_get_avail(vq, x, ptr) \
1005         vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL)
1006
1007 #define vhost_get_used(vq, x, ptr) \
1008         vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
1009
1010 static void vhost_dev_lock_vqs(struct vhost_dev *d)
1011 {
1012         int i = 0;
1013         for (i = 0; i < d->nvqs; ++i)
1014                 mutex_lock_nested(&d->vqs[i]->mutex, i);
1015 }
1016
1017 static void vhost_dev_unlock_vqs(struct vhost_dev *d)
1018 {
1019         int i = 0;
1020         for (i = 0; i < d->nvqs; ++i)
1021                 mutex_unlock(&d->vqs[i]->mutex);
1022 }
1023
1024 static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
1025                                       __virtio16 *idx)
1026 {
1027         return vhost_get_avail(vq, *idx, &vq->avail->idx);
1028 }
1029
1030 static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
1031                                        __virtio16 *head, int idx)
1032 {
1033         return vhost_get_avail(vq, *head,
1034                                &vq->avail->ring[idx & (vq->num - 1)]);
1035 }
1036
1037 static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq,
1038                                         __virtio16 *flags)
1039 {
1040         return vhost_get_avail(vq, *flags, &vq->avail->flags);
1041 }
1042
1043 static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
1044                                        __virtio16 *event)
1045 {
1046         return vhost_get_avail(vq, *event, vhost_used_event(vq));
1047 }
1048
1049 static inline int vhost_get_used_idx(struct vhost_virtqueue *vq,
1050                                      __virtio16 *idx)
1051 {
1052         return vhost_get_used(vq, *idx, &vq->used->idx);
1053 }
1054
1055 static inline int vhost_get_desc(struct vhost_virtqueue *vq,
1056                                  struct vring_desc *desc, int idx)
1057 {
1058         return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc));
1059 }
1060
1061 static void vhost_iotlb_notify_vq(struct vhost_dev *d,
1062                                   struct vhost_iotlb_msg *msg)
1063 {
1064         struct vhost_msg_node *node, *n;
1065
1066         spin_lock(&d->iotlb_lock);
1067
1068         list_for_each_entry_safe(node, n, &d->pending_list, node) {
1069                 struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
1070                 if (msg->iova <= vq_msg->iova &&
1071                     msg->iova + msg->size - 1 >= vq_msg->iova &&
1072                     vq_msg->type == VHOST_IOTLB_MISS) {
1073                         vhost_poll_queue(&node->vq->poll);
1074                         list_del(&node->node);
1075                         kfree(node);
1076                 }
1077         }
1078
1079         spin_unlock(&d->iotlb_lock);
1080 }
1081
1082 static bool umem_access_ok(u64 uaddr, u64 size, int access)
1083 {
1084         unsigned long a = uaddr;
1085
1086         /* Make sure 64 bit math will not overflow. */
1087         if (vhost_overflow(uaddr, size))
1088                 return false;
1089
1090         if ((access & VHOST_ACCESS_RO) &&
1091             !access_ok((void __user *)a, size))
1092                 return false;
1093         if ((access & VHOST_ACCESS_WO) &&
1094             !access_ok((void __user *)a, size))
1095                 return false;
1096         return true;
1097 }
1098
1099 static int vhost_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
1100                                    struct vhost_iotlb_msg *msg)
1101 {
1102         int ret = 0;
1103
1104         if (asid != 0)
1105                 return -EINVAL;
1106
1107         mutex_lock(&dev->mutex);
1108         vhost_dev_lock_vqs(dev);
1109         switch (msg->type) {
1110         case VHOST_IOTLB_UPDATE:
1111                 if (!dev->iotlb) {
1112                         ret = -EFAULT;
1113                         break;
1114                 }
1115                 if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
1116                         ret = -EFAULT;
1117                         break;
1118                 }
1119                 vhost_vq_meta_reset(dev);
1120                 if (vhost_iotlb_add_range(dev->iotlb, msg->iova,
1121                                           msg->iova + msg->size - 1,
1122                                           msg->uaddr, msg->perm)) {
1123                         ret = -ENOMEM;
1124                         break;
1125                 }
1126                 vhost_iotlb_notify_vq(dev, msg);
1127                 break;
1128         case VHOST_IOTLB_INVALIDATE:
1129                 if (!dev->iotlb) {
1130                         ret = -EFAULT;
1131                         break;
1132                 }
1133                 vhost_vq_meta_reset(dev);
1134                 vhost_iotlb_del_range(dev->iotlb, msg->iova,
1135                                       msg->iova + msg->size - 1);
1136                 break;
1137         default:
1138                 ret = -EINVAL;
1139                 break;
1140         }
1141
1142         vhost_dev_unlock_vqs(dev);
1143         mutex_unlock(&dev->mutex);
1144
1145         return ret;
1146 }
1147 ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
1148                              struct iov_iter *from)
1149 {
1150         struct vhost_iotlb_msg msg;
1151         size_t offset;
1152         int type, ret;
1153         u32 asid = 0;
1154
1155         ret = copy_from_iter(&type, sizeof(type), from);
1156         if (ret != sizeof(type)) {
1157                 ret = -EINVAL;
1158                 goto done;
1159         }
1160
1161         switch (type) {
1162         case VHOST_IOTLB_MSG:
1163                 /* There may be a hole after the type field for the V1
1164                  * message format, so skip it here.
1165                  */
1166                 offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
1167                 break;
1168         case VHOST_IOTLB_MSG_V2:
1169                 if (vhost_backend_has_feature(dev->vqs[0],
1170                                               VHOST_BACKEND_F_IOTLB_ASID)) {
1171                         ret = copy_from_iter(&asid, sizeof(asid), from);
1172                         if (ret != sizeof(asid)) {
1173                                 ret = -EINVAL;
1174                                 goto done;
1175                         }
1176                         offset = 0;
1177                 } else
1178                         offset = sizeof(__u32);
1179                 break;
1180         default:
1181                 ret = -EINVAL;
1182                 goto done;
1183         }
1184
1185         iov_iter_advance(from, offset);
1186         ret = copy_from_iter(&msg, sizeof(msg), from);
1187         if (ret != sizeof(msg)) {
1188                 ret = -EINVAL;
1189                 goto done;
1190         }
1191
1192         if ((msg.type == VHOST_IOTLB_UPDATE ||
1193              msg.type == VHOST_IOTLB_INVALIDATE) &&
1194              msg.size == 0) {
1195                 ret = -EINVAL;
1196                 goto done;
1197         }
1198
1199         if (dev->msg_handler)
1200                 ret = dev->msg_handler(dev, asid, &msg);
1201         else
1202                 ret = vhost_process_iotlb_msg(dev, asid, &msg);
1203         if (ret) {
1204                 ret = -EFAULT;
1205                 goto done;
1206         }
1207
1208         ret = (type == VHOST_IOTLB_MSG) ? sizeof(struct vhost_msg) :
1209               sizeof(struct vhost_msg_v2);
1210 done:
1211         return ret;
1212 }
1213 EXPORT_SYMBOL(vhost_chr_write_iter);
1214
1215 __poll_t vhost_chr_poll(struct file *file, struct vhost_dev *dev,
1216                             poll_table *wait)
1217 {
1218         __poll_t mask = 0;
1219
1220         poll_wait(file, &dev->wait, wait);
1221
1222         if (!list_empty(&dev->read_list))
1223                 mask |= EPOLLIN | EPOLLRDNORM;
1224
1225         return mask;
1226 }
1227 EXPORT_SYMBOL(vhost_chr_poll);
1228
1229 ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
1230                             int noblock)
1231 {
1232         DEFINE_WAIT(wait);
1233         struct vhost_msg_node *node;
1234         ssize_t ret = 0;
1235         unsigned size = sizeof(struct vhost_msg);
1236
1237         if (iov_iter_count(to) < size)
1238                 return 0;
1239
1240         while (1) {
1241                 if (!noblock)
1242                         prepare_to_wait(&dev->wait, &wait,
1243                                         TASK_INTERRUPTIBLE);
1244
1245                 node = vhost_dequeue_msg(dev, &dev->read_list);
1246                 if (node)
1247                         break;
1248                 if (noblock) {
1249                         ret = -EAGAIN;
1250                         break;
1251                 }
1252                 if (signal_pending(current)) {
1253                         ret = -ERESTARTSYS;
1254                         break;
1255                 }
1256                 if (!dev->iotlb) {
1257                         ret = -EBADFD;
1258                         break;
1259                 }
1260
1261                 schedule();
1262         }
1263
1264         if (!noblock)
1265                 finish_wait(&dev->wait, &wait);
1266
1267         if (node) {
1268                 struct vhost_iotlb_msg *msg;
1269                 void *start = &node->msg;
1270
1271                 switch (node->msg.type) {
1272                 case VHOST_IOTLB_MSG:
1273                         size = sizeof(node->msg);
1274                         msg = &node->msg.iotlb;
1275                         break;
1276                 case VHOST_IOTLB_MSG_V2:
1277                         size = sizeof(node->msg_v2);
1278                         msg = &node->msg_v2.iotlb;
1279                         break;
1280                 default:
1281                         BUG();
1282                         break;
1283                 }
1284
1285                 ret = copy_to_iter(start, size, to);
1286                 if (ret != size || msg->type != VHOST_IOTLB_MISS) {
1287                         kfree(node);
1288                         return ret;
1289                 }
1290                 vhost_enqueue_msg(dev, &dev->pending_list, node);
1291         }
1292
1293         return ret;
1294 }
1295 EXPORT_SYMBOL_GPL(vhost_chr_read_iter);
1296
1297 static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
1298 {
1299         struct vhost_dev *dev = vq->dev;
1300         struct vhost_msg_node *node;
1301         struct vhost_iotlb_msg *msg;
1302         bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2);
1303
1304         node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG);
1305         if (!node)
1306                 return -ENOMEM;
1307
1308         if (v2) {
1309                 node->msg_v2.type = VHOST_IOTLB_MSG_V2;
1310                 msg = &node->msg_v2.iotlb;
1311         } else {
1312                 msg = &node->msg.iotlb;
1313         }
1314
1315         msg->type = VHOST_IOTLB_MISS;
1316         msg->iova = iova;
1317         msg->perm = access;
1318
1319         vhost_enqueue_msg(dev, &dev->read_list, node);
1320
1321         return 0;
1322 }
1323
1324 static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
1325                          vring_desc_t __user *desc,
1326                          vring_avail_t __user *avail,
1327                          vring_used_t __user *used)
1328
1329 {
1330         /* If an IOTLB device is present, the vring addresses are
1331          * GIOVAs. Access validation occurs at prefetch time. */
1332         if (vq->iotlb)
1333                 return true;
1334
1335         return access_ok(desc, vhost_get_desc_size(vq, num)) &&
1336                access_ok(avail, vhost_get_avail_size(vq, num)) &&
1337                access_ok(used, vhost_get_used_size(vq, num));
1338 }
1339
1340 static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
1341                                  const struct vhost_iotlb_map *map,
1342                                  int type)
1343 {
1344         int access = (type == VHOST_ADDR_USED) ?
1345                      VHOST_ACCESS_WO : VHOST_ACCESS_RO;
1346
1347         if (likely(map->perm & access))
1348                 vq->meta_iotlb[type] = map;
1349 }
1350
1351 static bool iotlb_access_ok(struct vhost_virtqueue *vq,
1352                             int access, u64 addr, u64 len, int type)
1353 {
1354         const struct vhost_iotlb_map *map;
1355         struct vhost_iotlb *umem = vq->iotlb;
1356         u64 s = 0, size, orig_addr = addr, last = addr + len - 1;
1357
1358         if (vhost_vq_meta_fetch(vq, addr, len, type))
1359                 return true;
1360
1361         while (len > s) {
1362                 map = vhost_iotlb_itree_first(umem, addr, last);
1363                 if (map == NULL || map->start > addr) {
1364                         vhost_iotlb_miss(vq, addr, access);
1365                         return false;
1366                 } else if (!(map->perm & access)) {
1367                         /* Report the possible access violation by
1368                          * requesting another translation from userspace.
1369                          */
1370                         return false;
1371                 }
1372
1373                 size = map->size - addr + map->start;
1374
1375                 if (orig_addr == addr && size >= len)
1376                         vhost_vq_meta_update(vq, map, type);
1377
1378                 s += size;
1379                 addr += size;
1380         }
1381
1382         return true;
1383 }
1384
1385 int vq_meta_prefetch(struct vhost_virtqueue *vq)
1386 {
1387         unsigned int num = vq->num;
1388
1389         if (!vq->iotlb)
1390                 return 1;
1391
1392         return iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->desc,
1393                                vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
1394                iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->avail,
1395                                vhost_get_avail_size(vq, num),
1396                                VHOST_ADDR_AVAIL) &&
1397                iotlb_access_ok(vq, VHOST_MAP_WO, (u64)(uintptr_t)vq->used,
1398                                vhost_get_used_size(vq, num), VHOST_ADDR_USED);
1399 }
1400 EXPORT_SYMBOL_GPL(vq_meta_prefetch);
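
/*
 * A sketch of how a kick handler on an IOTLB-enabled device might use the
 * prefetch above before touching the ring. On failure the handler can simply
 * bail out: once userspace services the reported miss, vhost_iotlb_notify_vq()
 * re-queues the vq's poll work. do_work() is a hypothetical per-backend
 * datapath.
 */
static inline void example_handle_kick_with_iotlb(struct vhost_virtqueue *vq,
						  void (*do_work)(struct vhost_virtqueue *vq))
{
	mutex_lock(&vq->mutex);
	if (likely(vq_meta_prefetch(vq)))
		do_work(vq);	/* ring metadata translations are cached now */
	mutex_unlock(&vq->mutex);
}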
1401
1402 /* Can we log writes? */
1403 /* Caller should have device mutex but not vq mutex */
1404 bool vhost_log_access_ok(struct vhost_dev *dev)
1405 {
1406         return memory_access_ok(dev, dev->umem, 1);
1407 }
1408 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
1409
1410 static bool vq_log_used_access_ok(struct vhost_virtqueue *vq,
1411                                   void __user *log_base,
1412                                   bool log_used,
1413                                   u64 log_addr)
1414 {
1415         /* If an IOTLB device is present, log_addr is a GIOVA that
1416          * will never be logged by log_used(). */
1417         if (vq->iotlb)
1418                 return true;
1419
1420         return !log_used || log_access_ok(log_base, log_addr,
1421                                           vhost_get_used_size(vq, vq->num));
1422 }
1423
1424 /* Verify access for write logging. */
1425 /* Caller should have vq mutex and device mutex */
1426 static bool vq_log_access_ok(struct vhost_virtqueue *vq,
1427                              void __user *log_base)
1428 {
1429         return vq_memory_access_ok(log_base, vq->umem,
1430                                    vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
1431                 vq_log_used_access_ok(vq, log_base, vq->log_used, vq->log_addr);
1432 }
1433
1434 /* Can we start vq? */
1435 /* Caller should have vq mutex and device mutex */
1436 bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
1437 {
1438         if (!vq_log_access_ok(vq, vq->log_base))
1439                 return false;
1440
1441         return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
1442 }
1443 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
1444
1445 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
1446 {
1447         struct vhost_memory mem, *newmem;
1448         struct vhost_memory_region *region;
1449         struct vhost_iotlb *newumem, *oldumem;
1450         unsigned long size = offsetof(struct vhost_memory, regions);
1451         int i;
1452
1453         if (copy_from_user(&mem, m, size))
1454                 return -EFAULT;
1455         if (mem.padding)
1456                 return -EOPNOTSUPP;
1457         if (mem.nregions > max_mem_regions)
1458                 return -E2BIG;
1459         newmem = kvzalloc(struct_size(newmem, regions, mem.nregions),
1460                         GFP_KERNEL);
1461         if (!newmem)
1462                 return -ENOMEM;
1463
1464         memcpy(newmem, &mem, size);
1465         if (copy_from_user(newmem->regions, m->regions,
1466                            flex_array_size(newmem, regions, mem.nregions))) {
1467                 kvfree(newmem);
1468                 return -EFAULT;
1469         }
1470
1471         newumem = iotlb_alloc();
1472         if (!newumem) {
1473                 kvfree(newmem);
1474                 return -ENOMEM;
1475         }
1476
1477         for (region = newmem->regions;
1478              region < newmem->regions + mem.nregions;
1479              region++) {
1480                 if (vhost_iotlb_add_range(newumem,
1481                                           region->guest_phys_addr,
1482                                           region->guest_phys_addr +
1483                                           region->memory_size - 1,
1484                                           region->userspace_addr,
1485                                           VHOST_MAP_RW))
1486                         goto err;
1487         }
1488
1489         if (!memory_access_ok(d, newumem, 0))
1490                 goto err;
1491
1492         oldumem = d->umem;
1493         d->umem = newumem;
1494
1495         /* All memory accesses are done under some VQ mutex. */
1496         for (i = 0; i < d->nvqs; ++i) {
1497                 mutex_lock(&d->vqs[i]->mutex);
1498                 d->vqs[i]->umem = newumem;
1499                 mutex_unlock(&d->vqs[i]->mutex);
1500         }
1501
1502         kvfree(newmem);
1503         vhost_iotlb_free(oldumem);
1504         return 0;
1505
1506 err:
1507         vhost_iotlb_free(newumem);
1508         kvfree(newmem);
1509         return -EFAULT;
1510 }
1511
1512 static long vhost_vring_set_num(struct vhost_dev *d,
1513                                 struct vhost_virtqueue *vq,
1514                                 void __user *argp)
1515 {
1516         struct vhost_vring_state s;
1517
1518         /* Resizing ring with an active backend?
1519          * You don't want to do that. */
1520         if (vq->private_data)
1521                 return -EBUSY;
1522
1523         if (copy_from_user(&s, argp, sizeof s))
1524                 return -EFAULT;
1525
1526         if (!s.num || s.num > 0xffff || (s.num & (s.num - 1)))
1527                 return -EINVAL;
1528         vq->num = s.num;
1529
1530         return 0;
1531 }
1532
1533 static long vhost_vring_set_addr(struct vhost_dev *d,
1534                                  struct vhost_virtqueue *vq,
1535                                  void __user *argp)
1536 {
1537         struct vhost_vring_addr a;
1538
1539         if (copy_from_user(&a, argp, sizeof a))
1540                 return -EFAULT;
1541         if (a.flags & ~(0x1 << VHOST_VRING_F_LOG))
1542                 return -EOPNOTSUPP;
1543
1544         /* For 32bit, verify that the top 32bits of the user
1545            data are set to zero. */
1546         if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
1547             (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
1548             (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr)
1549                 return -EFAULT;
1550
1551         /* Make sure it's safe to cast pointers to vring types. */
1552         BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
1553         BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
1554         if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
1555             (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
1556             (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1)))
1557                 return -EINVAL;
1558
1559         /* We only verify access here if the backend is configured.
1560          * If it is not, we skip the check, as the size might not have
1561          * been set up yet. We will verify when the backend is configured. */
1562         if (vq->private_data) {
1563                 if (!vq_access_ok(vq, vq->num,
1564                         (void __user *)(unsigned long)a.desc_user_addr,
1565                         (void __user *)(unsigned long)a.avail_user_addr,
1566                         (void __user *)(unsigned long)a.used_user_addr))
1567                         return -EINVAL;
1568
1569                 /* Also validate log access for used ring if enabled. */
1570                 if (!vq_log_used_access_ok(vq, vq->log_base,
1571                                 a.flags & (0x1 << VHOST_VRING_F_LOG),
1572                                 a.log_guest_addr))
1573                         return -EINVAL;
1574         }
1575
1576         vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
1577         vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
1578         vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
1579         vq->log_addr = a.log_guest_addr;
1580         vq->used = (void __user *)(unsigned long)a.used_user_addr;
1581
1582         return 0;
1583 }
1584
1585 static long vhost_vring_set_num_addr(struct vhost_dev *d,
1586                                      struct vhost_virtqueue *vq,
1587                                      unsigned int ioctl,
1588                                      void __user *argp)
1589 {
1590         long r;
1591
1592         mutex_lock(&vq->mutex);
1593
1594         switch (ioctl) {
1595         case VHOST_SET_VRING_NUM:
1596                 r = vhost_vring_set_num(d, vq, argp);
1597                 break;
1598         case VHOST_SET_VRING_ADDR:
1599                 r = vhost_vring_set_addr(d, vq, argp);
1600                 break;
1601         default:
1602                 BUG();
1603         }
1604
1605         mutex_unlock(&vq->mutex);
1606
1607         return r;
1608 }
1609 long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
1610 {
1611         struct file *eventfp, *filep = NULL;
1612         bool pollstart = false, pollstop = false;
1613         struct eventfd_ctx *ctx = NULL;
1614         u32 __user *idxp = argp;
1615         struct vhost_virtqueue *vq;
1616         struct vhost_vring_state s;
1617         struct vhost_vring_file f;
1618         u32 idx;
1619         long r;
1620
1621         r = get_user(idx, idxp);
1622         if (r < 0)
1623                 return r;
1624         if (idx >= d->nvqs)
1625                 return -ENOBUFS;
1626
1627         idx = array_index_nospec(idx, d->nvqs);
1628         vq = d->vqs[idx];
1629
1630         if (ioctl == VHOST_SET_VRING_NUM ||
1631             ioctl == VHOST_SET_VRING_ADDR) {
1632                 return vhost_vring_set_num_addr(d, vq, ioctl, argp);
1633         }
1634
1635         mutex_lock(&vq->mutex);
1636
1637         switch (ioctl) {
1638         case VHOST_SET_VRING_BASE:
1639                 /* Moving base with an active backend?
1640                  * You don't want to do that. */
1641                 if (vq->private_data) {
1642                         r = -EBUSY;
1643                         break;
1644                 }
1645                 if (copy_from_user(&s, argp, sizeof s)) {
1646                         r = -EFAULT;
1647                         break;
1648                 }
1649                 if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
1650                         vq->last_avail_idx = s.num & 0xffff;
1651                         vq->last_used_idx = (s.num >> 16) & 0xffff;
1652                 } else {
1653                         if (s.num > 0xffff) {
1654                                 r = -EINVAL;
1655                                 break;
1656                         }
1657                         vq->last_avail_idx = s.num;
1658                 }
1659                 /* Forget the cached index value. */
1660                 vq->avail_idx = vq->last_avail_idx;
1661                 break;
1662         case VHOST_GET_VRING_BASE:
1663                 s.index = idx;
1664                 if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
1665                         s.num = (u32)vq->last_avail_idx | ((u32)vq->last_used_idx << 16);
1666                 else
1667                         s.num = vq->last_avail_idx;
1668                 if (copy_to_user(argp, &s, sizeof s))
1669                         r = -EFAULT;
1670                 break;
1671         case VHOST_SET_VRING_KICK:
1672                 if (copy_from_user(&f, argp, sizeof f)) {
1673                         r = -EFAULT;
1674                         break;
1675                 }
1676                 eventfp = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_fget(f.fd);
1677                 if (IS_ERR(eventfp)) {
1678                         r = PTR_ERR(eventfp);
1679                         break;
1680                 }
1681                 if (eventfp != vq->kick) {
1682                         pollstop = (filep = vq->kick) != NULL;
1683                         pollstart = (vq->kick = eventfp) != NULL;
1684                 } else
1685                         filep = eventfp;
1686                 break;
1687         case VHOST_SET_VRING_CALL:
1688                 if (copy_from_user(&f, argp, sizeof f)) {
1689                         r = -EFAULT;
1690                         break;
1691                 }
1692                 ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
1693                 if (IS_ERR(ctx)) {
1694                         r = PTR_ERR(ctx);
1695                         break;
1696                 }
1697
1698                 swap(ctx, vq->call_ctx.ctx);
1699                 break;
1700         case VHOST_SET_VRING_ERR:
1701                 if (copy_from_user(&f, argp, sizeof f)) {
1702                         r = -EFAULT;
1703                         break;
1704                 }
1705                 ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
1706                 if (IS_ERR(ctx)) {
1707                         r = PTR_ERR(ctx);
1708                         break;
1709                 }
1710                 swap(ctx, vq->error_ctx);
1711                 break;
1712         case VHOST_SET_VRING_ENDIAN:
1713                 r = vhost_set_vring_endian(vq, argp);
1714                 break;
1715         case VHOST_GET_VRING_ENDIAN:
1716                 r = vhost_get_vring_endian(vq, idx, argp);
1717                 break;
1718         case VHOST_SET_VRING_BUSYLOOP_TIMEOUT:
1719                 if (copy_from_user(&s, argp, sizeof(s))) {
1720                         r = -EFAULT;
1721                         break;
1722                 }
1723                 vq->busyloop_timeout = s.num;
1724                 break;
1725         case VHOST_GET_VRING_BUSYLOOP_TIMEOUT:
1726                 s.index = idx;
1727                 s.num = vq->busyloop_timeout;
1728                 if (copy_to_user(argp, &s, sizeof(s)))
1729                         r = -EFAULT;
1730                 break;
1731         default:
1732                 r = -ENOIOCTLCMD;
1733         }
1734
1735         if (pollstop && vq->handle_kick)
1736                 vhost_poll_stop(&vq->poll);
1737
1738         if (!IS_ERR_OR_NULL(ctx))
1739                 eventfd_ctx_put(ctx);
1740         if (filep)
1741                 fput(filep);
1742
1743         if (pollstart && vq->handle_kick)
1744                 r = vhost_poll_start(&vq->poll, vq->kick);
1745
1746         mutex_unlock(&vq->mutex);
1747
1748         if (pollstop && vq->handle_kick)
1749                 vhost_dev_flush(vq->poll.dev);
1750         return r;
1751 }
1752 EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
1753
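     /* Replace the device IOTLB with a freshly allocated one, point every
      * virtqueue at it and free the old one. */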
1754 int vhost_init_device_iotlb(struct vhost_dev *d)
1755 {
1756         struct vhost_iotlb *niotlb, *oiotlb;
1757         int i;
1758
1759         niotlb = iotlb_alloc();
1760         if (!niotlb)
1761                 return -ENOMEM;
1762
1763         oiotlb = d->iotlb;
1764         d->iotlb = niotlb;
1765
1766         for (i = 0; i < d->nvqs; ++i) {
1767                 struct vhost_virtqueue *vq = d->vqs[i];
1768
1769                 mutex_lock(&vq->mutex);
1770                 vq->iotlb = niotlb;
1771                 __vhost_vq_meta_reset(vq);
1772                 mutex_unlock(&vq->mutex);
1773         }
1774
1775         vhost_iotlb_free(oiotlb);
1776
1777         return 0;
1778 }
1779 EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);
1780
1781 /* Caller must have device mutex */
1782 long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
1783 {
1784         struct eventfd_ctx *ctx;
1785         u64 p;
1786         long r;
1787         int i, fd;
1788
1789         /* If you are not the owner, you can become one */
1790         if (ioctl == VHOST_SET_OWNER) {
1791                 r = vhost_dev_set_owner(d);
1792                 goto done;
1793         }
1794
1795         /* You must be the owner to do anything else */
1796         r = vhost_dev_check_owner(d);
1797         if (r)
1798                 goto done;
1799
1800         switch (ioctl) {
1801         case VHOST_SET_MEM_TABLE:
1802                 r = vhost_set_memory(d, argp);
1803                 break;
1804         case VHOST_SET_LOG_BASE:
1805                 if (copy_from_user(&p, argp, sizeof p)) {
1806                         r = -EFAULT;
1807                         break;
1808                 }
1809                 if ((u64)(unsigned long)p != p) {
1810                         r = -EFAULT;
1811                         break;
1812                 }
1813                 for (i = 0; i < d->nvqs; ++i) {
1814                         struct vhost_virtqueue *vq;
1815                         void __user *base = (void __user *)(unsigned long)p;
1816                         vq = d->vqs[i];
1817                         mutex_lock(&vq->mutex);
1818                         /* If ring is inactive, will check when it's enabled. */
1819                         if (vq->private_data && !vq_log_access_ok(vq, base))
1820                                 r = -EFAULT;
1821                         else
1822                                 vq->log_base = base;
1823                         mutex_unlock(&vq->mutex);
1824                 }
1825                 break;
1826         case VHOST_SET_LOG_FD:
1827                 r = get_user(fd, (int __user *)argp);
1828                 if (r < 0)
1829                         break;
1830                 ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
1831                 if (IS_ERR(ctx)) {
1832                         r = PTR_ERR(ctx);
1833                         break;
1834                 }
1835                 swap(ctx, d->log_ctx);
1836                 for (i = 0; i < d->nvqs; ++i) {
1837                         mutex_lock(&d->vqs[i]->mutex);
1838                         d->vqs[i]->log_ctx = d->log_ctx;
1839                         mutex_unlock(&d->vqs[i]->mutex);
1840                 }
1841                 if (ctx)
1842                         eventfd_ctx_put(ctx);
1843                 break;
1844         default:
1845                 r = -ENOIOCTLCMD;
1846                 break;
1847         }
1848 done:
1849         return r;
1850 }
1851 EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
1852
1853 /* TODO: This is really inefficient.  We need something like get_user()
1854  * (instruction directly accesses the data, with an exception table entry
1855  * returning -EFAULT). See Documentation/arch/x86/exception-tables.rst.
1856  */
1857 static int set_bit_to_user(int nr, void __user *addr)
1858 {
1859         unsigned long log = (unsigned long)addr;
1860         struct page *page;
1861         void *base;
1862         int bit = nr + (log % PAGE_SIZE) * 8;
1863         int r;
1864
1865         r = pin_user_pages_fast(log, 1, FOLL_WRITE, &page);
1866         if (r < 0)
1867                 return r;
1868         BUG_ON(r != 1);
1869         base = kmap_atomic(page);
1870         set_bit(bit, base);
1871         kunmap_atomic(base);
1872         unpin_user_pages_dirty_lock(&page, 1, true);
1873         return 0;
1874 }
1875
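     /* Mark every VHOST_PAGE_SIZE page covered by the write as dirty in the
      * userspace log bitmap at log_base. */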
1876 static int log_write(void __user *log_base,
1877                      u64 write_address, u64 write_length)
1878 {
1879         u64 write_page = write_address / VHOST_PAGE_SIZE;
1880         int r;
1881
1882         if (!write_length)
1883                 return 0;
1884         write_length += write_address % VHOST_PAGE_SIZE;
1885         for (;;) {
1886                 u64 base = (u64)(unsigned long)log_base;
1887                 u64 log = base + write_page / 8;
1888                 int bit = write_page % 8;
1889                 if ((u64)(unsigned long)log != log)
1890                         return -EFAULT;
1891                 r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
1892                 if (r < 0)
1893                         return r;
1894                 if (write_length <= VHOST_PAGE_SIZE)
1895                         break;
1896                 write_length -= VHOST_PAGE_SIZE;
1897                 write_page += 1;
1898         }
1899         return r;
1900 }
1901
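     /* Log a write given its host virtual address: walk the memory mappings
      * that overlap the range and log the matching guest addresses. */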
1902 static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
1903 {
1904         struct vhost_iotlb *umem = vq->umem;
1905         struct vhost_iotlb_map *u;
1906         u64 start, end, l, min;
1907         int r;
1908         bool hit = false;
1909
1910         while (len) {
1911                 min = len;
1912                 /* More than one GPA can be mapped into a single HVA, so
1913                  * iterate over all possible mappings here to be safe.
1914                  */
1915                 list_for_each_entry(u, &umem->list, link) {
1916                         if (u->addr > hva - 1 + len ||
1917                             u->addr - 1 + u->size < hva)
1918                                 continue;
1919                         start = max(u->addr, hva);
1920                         end = min(u->addr - 1 + u->size, hva - 1 + len);
1921                         l = end - start + 1;
1922                         r = log_write(vq->log_base,
1923                                       u->start + start - u->addr,
1924                                       l);
1925                         if (r < 0)
1926                                 return r;
1927                         hit = true;
1928                         min = min(l, min);
1929                 }
1930
1931                 if (!hit)
1932                         return -EFAULT;
1933
1934                 len -= min;
1935                 hva += min;
1936         }
1937
1938         return 0;
1939 }
1940
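     /* Log a write to the used ring, either directly via vq->log_addr or,
      * when an IOTLB is in use, by translating the used ring offset first. */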
1941 static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
1942 {
1943         struct iovec *iov = vq->log_iov;
1944         int i, ret;
1945
1946         if (!vq->iotlb)
1947                 return log_write(vq->log_base, vq->log_addr + used_offset, len);
1948
1949         ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
1950                              len, iov, 64, VHOST_ACCESS_WO);
1951         if (ret < 0)
1952                 return ret;
1953
1954         for (i = 0; i < ret; i++) {
1955                 ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
1956                                     iov[i].iov_len);
1957                 if (ret)
1958                         return ret;
1959         }
1960
1961         return 0;
1962 }
1963
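     /* Log guest-visible data writes: through the IOTLB using the iov/count
      * host ranges, or directly from the (log, log_num) guest address ranges
      * otherwise. */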
1964 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
1965                     unsigned int log_num, u64 len, struct iovec *iov, int count)
1966 {
1967         int i, r;
1968
1969         /* Make sure data written is seen before log. */
1970         smp_wmb();
1971
1972         if (vq->iotlb) {
1973                 for (i = 0; i < count; i++) {
1974                         r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
1975                                           iov[i].iov_len);
1976                         if (r < 0)
1977                                 return r;
1978                 }
1979                 return 0;
1980         }
1981
1982         for (i = 0; i < log_num; ++i) {
1983                 u64 l = min(log[i].len, len);
1984                 r = log_write(vq->log_base, log[i].addr, l);
1985                 if (r < 0)
1986                         return r;
1987                 len -= l;
1988                 if (!len) {
1989                         if (vq->log_ctx)
1990                                 eventfd_signal(vq->log_ctx, 1);
1991                         return 0;
1992                 }
1993         }
1994         /* Length written exceeds what we have stored. This is a bug. */
1995         BUG();
1996         return 0;
1997 }
1998 EXPORT_SYMBOL_GPL(vhost_log_write);
1999
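     /* Push vq->used_flags out to the used ring, logging the write when
      * dirty logging is enabled. */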
2000 static int vhost_update_used_flags(struct vhost_virtqueue *vq)
2001 {
2002         void __user *used;
2003         if (vhost_put_used_flags(vq))
2004                 return -EFAULT;
2005         if (unlikely(vq->log_used)) {
2006                 /* Make sure the flag is seen before log. */
2007                 smp_wmb();
2008                 /* Log used flag write. */
2009                 used = &vq->used->flags;
2010                 log_used(vq, (used - (void __user *)vq->used),
2011                          sizeof vq->used->flags);
2012                 if (vq->log_ctx)
2013                         eventfd_signal(vq->log_ctx, 1);
2014         }
2015         return 0;
2016 }
2017
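     /* Publish the avail event index (VIRTIO_RING_F_EVENT_IDX), logging the
      * write when dirty logging is enabled. */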
2018 static int vhost_update_avail_event(struct vhost_virtqueue *vq)
2019 {
2020         if (vhost_put_avail_event(vq))
2021                 return -EFAULT;
2022         if (unlikely(vq->log_used)) {
2023                 void __user *used;
2024                 /* Make sure the event is seen before log. */
2025                 smp_wmb();
2026                 /* Log avail event write */
2027                 used = vhost_avail_event(vq);
2028                 log_used(vq, (used - (void __user *)vq->used),
2029                          sizeof *vhost_avail_event(vq));
2030                 if (vq->log_ctx)
2031                         eventfd_signal(vq->log_ctx, 1);
2032         }
2033         return 0;
2034 }
2035
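     /* Initialize ring access state for an active backend: fix up
      * endianness, write out the used flags and read back the current
      * used index. */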
2036 int vhost_vq_init_access(struct vhost_virtqueue *vq)
2037 {
2038         __virtio16 last_used_idx;
2039         int r;
2040         bool is_le = vq->is_le;
2041
2042         if (!vq->private_data)
2043                 return 0;
2044
2045         vhost_init_is_le(vq);
2046
2047         r = vhost_update_used_flags(vq);
2048         if (r)
2049                 goto err;
2050         vq->signalled_used_valid = false;
2051         if (!vq->iotlb &&
2052             !access_ok(&vq->used->idx, sizeof vq->used->idx)) {
2053                 r = -EFAULT;
2054                 goto err;
2055         }
2056         r = vhost_get_used_idx(vq, &last_used_idx);
2057         if (r) {
2058                 vq_err(vq, "Can't access used idx at %p\n",
2059                        &vq->used->idx);
2060                 goto err;
2061         }
2062         vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
2063         return 0;
2064
2065 err:
2066         vq->is_le = is_le;
2067         return r;
2068 }
2069 EXPORT_SYMBOL_GPL(vhost_vq_init_access);
2070
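     /* Translate a guest address range into userspace iovecs, walking the
      * IOTLB (or the memory table when no IOTLB is set). Returns the number
      * of iovecs filled, -EAGAIN after reporting an IOTLB miss, or another
      * negative errno on failure. */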
2071 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
2072                           struct iovec iov[], int iov_size, int access)
2073 {
2074         const struct vhost_iotlb_map *map;
2075         struct vhost_dev *dev = vq->dev;
2076         struct vhost_iotlb *umem = dev->iotlb ? dev->iotlb : dev->umem;
2077         struct iovec *_iov;
2078         u64 s = 0, last = addr + len - 1;
2079         int ret = 0;
2080
2081         while ((u64)len > s) {
2082                 u64 size;
2083                 if (unlikely(ret >= iov_size)) {
2084                         ret = -ENOBUFS;
2085                         break;
2086                 }
2087
2088                 map = vhost_iotlb_itree_first(umem, addr, last);
2089                 if (map == NULL || map->start > addr) {
2090                         if (umem != dev->iotlb) {
2091                                 ret = -EFAULT;
2092                                 break;
2093                         }
2094                         ret = -EAGAIN;
2095                         break;
2096                 } else if (!(map->perm & access)) {
2097                         ret = -EPERM;
2098                         break;
2099                 }
2100
2101                 _iov = iov + ret;
2102                 size = map->size - addr + map->start;
2103                 _iov->iov_len = min((u64)len - s, size);
2104                 _iov->iov_base = (void __user *)(unsigned long)
2105                                  (map->addr + addr - map->start);
2106                 s += size;
2107                 addr += size;
2108                 ++ret;
2109         }
2110
2111         if (ret == -EAGAIN)
2112                 vhost_iotlb_miss(vq, addr, access);
2113         return ret;
2114 }
2115
2116 /* Each buffer in the virtqueues is actually a chain of descriptors.  This
2117  * function returns the next descriptor in the chain,
2118  * or -1U if we're at the end. */
2119 static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc)
2120 {
2121         unsigned int next;
2122
2123         /* If this descriptor says it doesn't chain, we're done. */
2124         if (!(desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT)))
2125                 return -1U;
2126
2127         /* Check they're not leading us off the end of the descriptors. */
2128         next = vhost16_to_cpu(vq, READ_ONCE(desc->next));
2129         return next;
2130 }
2131
2132 static int get_indirect(struct vhost_virtqueue *vq,
2133                         struct iovec iov[], unsigned int iov_size,
2134                         unsigned int *out_num, unsigned int *in_num,
2135                         struct vhost_log *log, unsigned int *log_num,
2136                         struct vring_desc *indirect)
2137 {
2138         struct vring_desc desc;
2139         unsigned int i = 0, count, found = 0;
2140         u32 len = vhost32_to_cpu(vq, indirect->len);
2141         struct iov_iter from;
2142         int ret, access;
2143
2144         /* Sanity check */
2145         if (unlikely(len % sizeof desc)) {
2146                 vq_err(vq, "Invalid length in indirect descriptor: "
2147                        "len 0x%llx not multiple of 0x%zx\n",
2148                        (unsigned long long)len,
2149                        sizeof desc);
2150                 return -EINVAL;
2151         }
2152
2153         ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
2154                              UIO_MAXIOV, VHOST_ACCESS_RO);
2155         if (unlikely(ret < 0)) {
2156                 if (ret != -EAGAIN)
2157                         vq_err(vq, "Translation failure %d in indirect.\n", ret);
2158                 return ret;
2159         }
2160         iov_iter_init(&from, ITER_SOURCE, vq->indirect, ret, len);
2161         count = len / sizeof desc;
2162         /* Buffers are chained via a 16 bit next field, so
2163          * we can have at most 2^16 of these. */
2164         if (unlikely(count > USHRT_MAX + 1)) {
2165                 vq_err(vq, "Indirect buffer length too big: %d\n",
2166                        indirect->len);
2167                 return -E2BIG;
2168         }
2169
2170         do {
2171                 unsigned iov_count = *in_num + *out_num;
2172                 if (unlikely(++found > count)) {
2173                         vq_err(vq, "Loop detected: last one at %u "
2174                                "indirect size %u\n",
2175                                i, count);
2176                         return -EINVAL;
2177                 }
2178                 if (unlikely(!copy_from_iter_full(&desc, sizeof(desc), &from))) {
2179                         vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
2180                                i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
2181                         return -EINVAL;
2182                 }
2183                 if (unlikely(desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) {
2184                         vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n",
2185                                i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
2186                         return -EINVAL;
2187                 }
2188
2189                 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2190                         access = VHOST_ACCESS_WO;
2191                 else
2192                         access = VHOST_ACCESS_RO;
2193
2194                 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2195                                      vhost32_to_cpu(vq, desc.len), iov + iov_count,
2196                                      iov_size - iov_count, access);
2197                 if (unlikely(ret < 0)) {
2198                         if (ret != -EAGAIN)
2199                                 vq_err(vq, "Translation failure %d indirect idx %d\n",
2200                                         ret, i);
2201                         return ret;
2202                 }
2203                 /* If this is an input descriptor, increment that count. */
2204                 if (access == VHOST_ACCESS_WO) {
2205                         *in_num += ret;
2206                         if (unlikely(log && ret)) {
2207                                 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2208                                 log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2209                                 ++*log_num;
2210                         }
2211                 } else {
2212                         /* Output descriptors are all supposed to come
2213                          * before any input descriptors. */
2214                         if (unlikely(*in_num)) {
2215                                 vq_err(vq, "Indirect descriptor "
2216                                        "has out after in: idx %d\n", i);
2217                                 return -EINVAL;
2218                         }
2219                         *out_num += ret;
2220                 }
2221         } while ((i = next_desc(vq, &desc)) != -1);
2222         return 0;
2223 }
2224
2225 /* This looks in the virtqueue for the first available buffer and converts
2226  * it to an iovec for convenient access.  Since descriptors consist of some
2227  * number of output then some number of input descriptors, it's actually two
2228  * iovecs, but we pack them into one and note how many of each there were.
2229  *
2230  * This function returns the descriptor number found, or vq->num (which is
2231  * never a valid descriptor number) if none was found.  A negative code is
2232  * returned on error. */
2233 int vhost_get_vq_desc(struct vhost_virtqueue *vq,
2234                       struct iovec iov[], unsigned int iov_size,
2235                       unsigned int *out_num, unsigned int *in_num,
2236                       struct vhost_log *log, unsigned int *log_num)
2237 {
2238         struct vring_desc desc;
2239         unsigned int i, head, found = 0;
2240         u16 last_avail_idx;
2241         __virtio16 avail_idx;
2242         __virtio16 ring_head;
2243         int ret, access;
2244
2245         /* Check it isn't doing very strange things with descriptor numbers. */
2246         last_avail_idx = vq->last_avail_idx;
2247
2248         if (vq->avail_idx == vq->last_avail_idx) {
2249                 if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) {
2250                         vq_err(vq, "Failed to access avail idx at %p\n",
2251                                 &vq->avail->idx);
2252                         return -EFAULT;
2253                 }
2254                 vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2255
2256                 if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
2257                         vq_err(vq, "Guest moved avail index from %u to %u",
2258                                 last_avail_idx, vq->avail_idx);
2259                         return -EFAULT;
2260                 }
2261
2262                 /* If there's nothing new since last we looked, return
2263                  * invalid.
2264                  */
2265                 if (vq->avail_idx == last_avail_idx)
2266                         return vq->num;
2267
2268                 /* Only get avail ring entries after they have been
2269                  * exposed by guest.
2270                  */
2271                 smp_rmb();
2272         }
2273
2274         /* Grab the next descriptor number they're advertising, and increment
2275          * the index we've seen. */
2276         if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
2277                 vq_err(vq, "Failed to read head: idx %d address %p\n",
2278                        last_avail_idx,
2279                        &vq->avail->ring[last_avail_idx % vq->num]);
2280                 return -EFAULT;
2281         }
2282
2283         head = vhost16_to_cpu(vq, ring_head);
2284
2285         /* If their number is silly, that's an error. */
2286         if (unlikely(head >= vq->num)) {
2287                 vq_err(vq, "Guest says index %u > %u is available",
2288                        head, vq->num);
2289                 return -EINVAL;
2290         }
2291
2292         /* When we start there are neither input nor output descriptors. */
2293         *out_num = *in_num = 0;
2294         if (unlikely(log))
2295                 *log_num = 0;
2296
2297         i = head;
2298         do {
2299                 unsigned iov_count = *in_num + *out_num;
2300                 if (unlikely(i >= vq->num)) {
2301                         vq_err(vq, "Desc index is %u > %u, head = %u",
2302                                i, vq->num, head);
2303                         return -EINVAL;
2304                 }
2305                 if (unlikely(++found > vq->num)) {
2306                         vq_err(vq, "Loop detected: last one at %u "
2307                                "vq size %u head %u\n",
2308                                i, vq->num, head);
2309                         return -EINVAL;
2310                 }
2311                 ret = vhost_get_desc(vq, &desc, i);
2312                 if (unlikely(ret)) {
2313                         vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
2314                                i, vq->desc + i);
2315                         return -EFAULT;
2316                 }
2317                 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) {
2318                         ret = get_indirect(vq, iov, iov_size,
2319                                            out_num, in_num,
2320                                            log, log_num, &desc);
2321                         if (unlikely(ret < 0)) {
2322                                 if (ret != -EAGAIN)
2323                                         vq_err(vq, "Failure detected "
2324                                                 "in indirect descriptor at idx %d\n", i);
2325                                 return ret;
2326                         }
2327                         continue;
2328                 }
2329
2330                 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
2331                         access = VHOST_ACCESS_WO;
2332                 else
2333                         access = VHOST_ACCESS_RO;
2334                 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
2335                                      vhost32_to_cpu(vq, desc.len), iov + iov_count,
2336                                      iov_size - iov_count, access);
2337                 if (unlikely(ret < 0)) {
2338                         if (ret != -EAGAIN)
2339                                 vq_err(vq, "Translation failure %d descriptor idx %d\n",
2340                                         ret, i);
2341                         return ret;
2342                 }
2343                 if (access == VHOST_ACCESS_WO) {
2344                         /* If this is an input descriptor,
2345                          * increment that count. */
2346                         *in_num += ret;
2347                         if (unlikely(log && ret)) {
2348                                 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
2349                                 log[*log_num].len = vhost32_to_cpu(vq, desc.len);
2350                                 ++*log_num;
2351                         }
2352                 } else {
2353                         /* Output descriptors are all supposed to come
2354                          * before any input descriptors. */
2355                         if (unlikely(*in_num)) {
2356                                 vq_err(vq, "Descriptor has out after in: "
2357                                        "idx %d\n", i);
2358                                 return -EINVAL;
2359                         }
2360                         *out_num += ret;
2361                 }
2362         } while ((i = next_desc(vq, &desc)) != -1);
2363
2364         /* On success, increment avail index. */
2365         vq->last_avail_idx++;
2366
2367         /* Assume notifications from the guest are disabled at this point;
2368          * if they aren't, we would need to update the avail_event index. */
2369         BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
2370         return head;
2371 }
2372 EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
2373
2374 /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
2375 void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
2376 {
2377         vq->last_avail_idx -= n;
2378 }
2379 EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);
2380
2381 /* After we've used one of their buffers, we tell them about it.  We'll then
2382  * want to notify the guest, using eventfd. */
2383 int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
2384 {
2385         struct vring_used_elem heads = {
2386                 cpu_to_vhost32(vq, head),
2387                 cpu_to_vhost32(vq, len)
2388         };
2389
2390         return vhost_add_used_n(vq, &heads, 1);
2391 }
2392 EXPORT_SYMBOL_GPL(vhost_add_used);
2393
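     /* Write a batch of used elements into the used ring and advance the
      * shadow last_used_idx; the used index itself is published by the
      * caller. */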
2394 static int __vhost_add_used_n(struct vhost_virtqueue *vq,
2395                             struct vring_used_elem *heads,
2396                             unsigned count)
2397 {
2398         vring_used_elem_t __user *used;
2399         u16 old, new;
2400         int start;
2401
2402         start = vq->last_used_idx & (vq->num - 1);
2403         used = vq->used->ring + start;
2404         if (vhost_put_used(vq, heads, start, count)) {
2405                 vq_err(vq, "Failed to write used");
2406                 return -EFAULT;
2407         }
2408         if (unlikely(vq->log_used)) {
2409                 /* Make sure data is seen before log. */
2410                 smp_wmb();
2411                 /* Log used ring entry write. */
2412                 log_used(vq, ((void __user *)used - (void __user *)vq->used),
2413                          count * sizeof *used);
2414         }
2415         old = vq->last_used_idx;
2416         new = (vq->last_used_idx += count);
2417         /* If the driver never bothers to signal for a very long while,
2418          * the used index might wrap around. If that happens, invalidate
2419          * the signalled_used index we stored. TODO: make sure the driver
2420          * signals at least once every 2^16 entries and remove this. */
2421         if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
2422                 vq->signalled_used_valid = false;
2423         return 0;
2424 }
2425
2426 /* After we've used one of their buffers, we tell them about it.  We'll then
2427  * want to notify the guest, using eventfd. */
2428 int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
2429                      unsigned count)
2430 {
2431         int start, n, r;
2432
2433         start = vq->last_used_idx & (vq->num - 1);
2434         n = vq->num - start;
2435         if (n < count) {
2436                 r = __vhost_add_used_n(vq, heads, n);
2437                 if (r < 0)
2438                         return r;
2439                 heads += n;
2440                 count -= n;
2441         }
2442         r = __vhost_add_used_n(vq, heads, count);
2443
2444         /* Make sure buffer is written before we update index. */
2445         smp_wmb();
2446         if (vhost_put_used_idx(vq)) {
2447                 vq_err(vq, "Failed to increment used idx");
2448                 return -EFAULT;
2449         }
2450         if (unlikely(vq->log_used)) {
2451                 /* Make sure used idx is seen before log. */
2452                 smp_wmb();
2453                 /* Log used index update. */
2454                 log_used(vq, offsetof(struct vring_used, idx),
2455                          sizeof vq->used->idx);
2456                 if (vq->log_ctx)
2457                         eventfd_signal(vq->log_ctx, 1);
2458         }
2459         return r;
2460 }
2461 EXPORT_SYMBOL_GPL(vhost_add_used_n);
2462
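     /* Work out whether the guest needs to be signalled, taking
      * VIRTIO_F_NOTIFY_ON_EMPTY, event indexes and the NO_INTERRUPT flag
      * into account. */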
2463 static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2464 {
2465         __u16 old, new;
2466         __virtio16 event;
2467         bool v;
2468         /* Flush out used index updates. This is paired
2469          * with the barrier that the Guest executes when enabling
2470          * interrupts. */
2471         smp_mb();
2472
2473         if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2474             unlikely(vq->avail_idx == vq->last_avail_idx))
2475                 return true;
2476
2477         if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2478                 __virtio16 flags;
2479                 if (vhost_get_avail_flags(vq, &flags)) {
2480                         vq_err(vq, "Failed to get flags");
2481                         return true;
2482                 }
2483                 return !(flags & cpu_to_vhost16(vq, VRING_AVAIL_F_NO_INTERRUPT));
2484         }
2485         old = vq->signalled_used;
2486         v = vq->signalled_used_valid;
2487         new = vq->signalled_used = vq->last_used_idx;
2488         vq->signalled_used_valid = true;
2489
2490         if (unlikely(!v))
2491                 return true;
2492
2493         if (vhost_get_used_event(vq, &event)) {
2494                 vq_err(vq, "Failed to get used event idx");
2495                 return true;
2496         }
2497         return vring_need_event(vhost16_to_cpu(vq, event), new, old);
2498 }
2499
2500 /* This actually signals the guest, using eventfd. */
2501 void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2502 {
2503         /* Signal the Guest to tell them we used something up. */
2504         if (vq->call_ctx.ctx && vhost_notify(dev, vq))
2505                 eventfd_signal(vq->call_ctx.ctx, 1);
2506 }
2507 EXPORT_SYMBOL_GPL(vhost_signal);
2508
2509 /* And here's the combo meal deal.  Supersize me! */
2510 void vhost_add_used_and_signal(struct vhost_dev *dev,
2511                                struct vhost_virtqueue *vq,
2512                                unsigned int head, int len)
2513 {
2514         vhost_add_used(vq, head, len);
2515         vhost_signal(dev, vq);
2516 }
2517 EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);
2518
2519 /* multi-buffer version of vhost_add_used_and_signal */
2520 void vhost_add_used_and_signal_n(struct vhost_dev *dev,
2521                                  struct vhost_virtqueue *vq,
2522                                  struct vring_used_elem *heads, unsigned count)
2523 {
2524         vhost_add_used_n(vq, heads, count);
2525         vhost_signal(dev, vq);
2526 }
2527 EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);
2528
2529 /* Return true if we're sure that the available ring is empty. */
2530 bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2531 {
2532         __virtio16 avail_idx;
2533         int r;
2534
2535         if (vq->avail_idx != vq->last_avail_idx)
2536                 return false;
2537
2538         r = vhost_get_avail_idx(vq, &avail_idx);
2539         if (unlikely(r))
2540                 return false;
2541         vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2542
2543         return vq->avail_idx == vq->last_avail_idx;
2544 }
2545 EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);
2546
2547 /* OK, now we need to know about added descriptors. */
2548 bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2549 {
2550         __virtio16 avail_idx;
2551         int r;
2552
2553         if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
2554                 return false;
2555         vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
2556         if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2557                 r = vhost_update_used_flags(vq);
2558                 if (r) {
2559                         vq_err(vq, "Failed to enable notification at %p: %d\n",
2560                                &vq->used->flags, r);
2561                         return false;
2562                 }
2563         } else {
2564                 r = vhost_update_avail_event(vq);
2565                 if (r) {
2566                         vq_err(vq, "Failed to update avail event index at %p: %d\n",
2567                                vhost_avail_event(vq), r);
2568                         return false;
2569                 }
2570         }
2571         /* They could have slipped one in as we were doing that: make
2572          * sure it's written, then check again. */
2573         smp_mb();
2574         r = vhost_get_avail_idx(vq, &avail_idx);
2575         if (r) {
2576                 vq_err(vq, "Failed to check avail idx at %p: %d\n",
2577                        &vq->avail->idx, r);
2578                 return false;
2579         }
2580         vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
2581
2582         return vq->avail_idx != vq->last_avail_idx;
2583 }
2584 EXPORT_SYMBOL_GPL(vhost_enable_notify);
2585
2586 /* We don't need to be notified again. */
2587 void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
2588 {
2589         int r;
2590
2591         if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
2592                 return;
2593         vq->used_flags |= VRING_USED_F_NO_NOTIFY;
2594         if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
2595                 r = vhost_update_used_flags(vq);
2596                 if (r)
2597                         vq_err(vq, "Failed to disable notification at %p: %d\n",
2598                                &vq->used->flags, r);
2599         }
2600 }
2601 EXPORT_SYMBOL_GPL(vhost_disable_notify);
2602
2603 /* Create a new message. */
2604 struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
2605 {
2606         /* Make sure all padding within the structure is initialized. */
2607         struct vhost_msg_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
2608         if (!node)
2609                 return NULL;
2610
2611         node->vq = vq;
2612         node->msg.type = type;
2613         return node;
2614 }
2615 EXPORT_SYMBOL_GPL(vhost_new_msg);
2616
2617 void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
2618                        struct vhost_msg_node *node)
2619 {
2620         spin_lock(&dev->iotlb_lock);
2621         list_add_tail(&node->node, head);
2622         spin_unlock(&dev->iotlb_lock);
2623
2624         wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
2625 }
2626 EXPORT_SYMBOL_GPL(vhost_enqueue_msg);
2627
2628 struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
2629                                          struct list_head *head)
2630 {
2631         struct vhost_msg_node *node = NULL;
2632
2633         spin_lock(&dev->iotlb_lock);
2634         if (!list_empty(head)) {
2635                 node = list_first_entry(head, struct vhost_msg_node,
2636                                         node);
2637                 list_del(&node->node);
2638         }
2639         spin_unlock(&dev->iotlb_lock);
2640
2641         return node;
2642 }
2643 EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
2644
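     /* Record the acked backend features in every virtqueue, under the
      * device and per-vq mutexes. */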
2645 void vhost_set_backend_features(struct vhost_dev *dev, u64 features)
2646 {
2647         struct vhost_virtqueue *vq;
2648         int i;
2649
2650         mutex_lock(&dev->mutex);
2651         for (i = 0; i < dev->nvqs; ++i) {
2652                 vq = dev->vqs[i];
2653                 mutex_lock(&vq->mutex);
2654                 vq->acked_backend_features = features;
2655                 mutex_unlock(&vq->mutex);
2656         }
2657         mutex_unlock(&dev->mutex);
2658 }
2659 EXPORT_SYMBOL_GPL(vhost_set_backend_features);
2660
2661 static int __init vhost_init(void)
2662 {
2663         return 0;
2664 }
2665
2666 static void __exit vhost_exit(void)
2667 {
2668 }
2669
2670 module_init(vhost_init);
2671 module_exit(vhost_exit);
2672
2673 MODULE_VERSION("0.0.1");
2674 MODULE_LICENSE("GPL v2");
2675 MODULE_AUTHOR("Michael S. Tsirkin");
2676 MODULE_DESCRIPTION("Host kernel accelerator for virtio");