Merge tag 'for_linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mst/vhost
[platform/kernel/linux-rpi.git] / drivers / vhost / vhost.c
index 43fa626..10bf35a 100644 (file)
 #include <linux/slab.h>
 #include <linux/vmalloc.h>
 #include <linux/kthread.h>
-#include <linux/cgroup.h>
 #include <linux/module.h>
 #include <linux/sort.h>
 #include <linux/sched/mm.h>
 #include <linux/sched/signal.h>
+#include <linux/sched/vhost_task.h>
 #include <linux/interval_tree_generic.h>
 #include <linux/nospec.h>
 #include <linux/kcov.h>
@@ -255,8 +255,8 @@ void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
                 * sure it was not in the list.
                 * test_and_set_bit() implies a memory barrier.
                 */
-               llist_add(&work->node, &dev->work_list);
-               wake_up_process(dev->worker);
+               llist_add(&work->node, &dev->worker->work_list);
+               wake_up_process(dev->worker->vtsk->task);
        }
 }
 EXPORT_SYMBOL_GPL(vhost_work_queue);
@@ -264,7 +264,7 @@ EXPORT_SYMBOL_GPL(vhost_work_queue);
 /* A lockless hint for busy polling code to exit the loop */
 bool vhost_has_work(struct vhost_dev *dev)
 {
-       return !llist_empty(&dev->work_list);
+       return dev->worker && !llist_empty(&dev->worker->work_list); /* false while no worker is attached */
 }
 EXPORT_SYMBOL_GPL(vhost_has_work);
 
@@ -335,22 +335,20 @@ static void vhost_vq_reset(struct vhost_dev *dev,
 
 static int vhost_worker(void *data)
 {
-       struct vhost_dev *dev = data;
+       struct vhost_worker *worker = data;
        struct vhost_work *work, *work_next;
        struct llist_node *node;
 
-       kthread_use_mm(dev->mm);
-
        for (;;) {
                /* mb paired w/ kthread_stop */
                set_current_state(TASK_INTERRUPTIBLE);
 
-               if (kthread_should_stop()) {
+               if (vhost_task_should_stop(worker->vtsk)) {
                        __set_current_state(TASK_RUNNING);
                        break;
                }
 
-               node = llist_del_all(&dev->work_list);
+               node = llist_del_all(&worker->work_list);
                if (!node)
                        schedule();
 
@@ -360,14 +358,14 @@ static int vhost_worker(void *data)
                llist_for_each_entry_safe(work, work_next, node, node) {
                        clear_bit(VHOST_WORK_QUEUED, &work->flags);
                        __set_current_state(TASK_RUNNING);
-                       kcov_remote_start_common(dev->kcov_handle);
+                       kcov_remote_start_common(worker->kcov_handle);
                        work->fn(work);
                        kcov_remote_stop();
                        if (need_resched())
                                schedule();
                }
        }
-       kthread_unuse_mm(dev->mm);
+
        return 0;
 }
 
@@ -477,7 +475,6 @@ void vhost_dev_init(struct vhost_dev *dev,
        dev->byte_weight = byte_weight;
        dev->use_worker = use_worker;
        dev->msg_handler = msg_handler;
-       init_llist_head(&dev->work_list);
        init_waitqueue_head(&dev->wait);
        INIT_LIST_HEAD(&dev->read_list);
        INIT_LIST_HEAD(&dev->pending_list);
@@ -507,31 +504,6 @@ long vhost_dev_check_owner(struct vhost_dev *dev)
 }
 EXPORT_SYMBOL_GPL(vhost_dev_check_owner);
 
-struct vhost_attach_cgroups_struct {
-       struct vhost_work work;
-       struct task_struct *owner;
-       int ret;
-};
-
-static void vhost_attach_cgroups_work(struct vhost_work *work)
-{
-       struct vhost_attach_cgroups_struct *s;
-
-       s = container_of(work, struct vhost_attach_cgroups_struct, work);
-       s->ret = cgroup_attach_task_all(s->owner, current);
-}
-
-static int vhost_attach_cgroups(struct vhost_dev *dev)
-{
-       struct vhost_attach_cgroups_struct attach;
-
-       attach.owner = current;
-       vhost_work_init(&attach.work, vhost_attach_cgroups_work);
-       vhost_work_queue(dev, &attach.work);
-       vhost_dev_flush(dev);
-       return attach.ret;
-}
-
 /* Caller should have device mutex */
 bool vhost_dev_has_owner(struct vhost_dev *dev)
 {
@@ -569,10 +541,54 @@ static void vhost_detach_mm(struct vhost_dev *dev)
        dev->mm = NULL;
 }
 
+static void vhost_worker_free(struct vhost_dev *dev) /* stop and free dev->worker; no-op when it is NULL */
+{
+       struct vhost_worker *worker = dev->worker;
+
+       if (!worker)
+               return;
+
+       dev->worker = NULL; /* detach first; vhost_has_work() treats a NULL worker as "no work" */
+       WARN_ON(!llist_empty(&worker->work_list)); /* work still queued at teardown is unexpected */
+       vhost_task_stop(worker->vtsk); /* presumably waits for the task to exit before the kfree() below -- confirm */
+       kfree(worker);
+}
+
+static int vhost_worker_create(struct vhost_dev *dev) /* allocate dev->worker and start its vhost task; returns 0 or -ENOMEM */
+{
+       struct vhost_worker *worker;
+       struct vhost_task *vtsk;
+       char name[TASK_COMM_LEN];
+       int ret;
+
+       worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
+       if (!worker)
+               return -ENOMEM;
+
+       dev->worker = worker; /* published before vtsk exists; reset to NULL on the error path */
+       worker->kcov_handle = kcov_common_handle();
+       init_llist_head(&worker->work_list);
+       snprintf(name, sizeof(name), "vhost-%d", current->pid); /* task name carries the caller's pid */
+
+       vtsk = vhost_task_create(vhost_worker, worker, name);
+       if (!vtsk) {
+               ret = -ENOMEM; /* a NULL return is reported to the caller as -ENOMEM */
+               goto free_worker;
+       }
+
+       worker->vtsk = vtsk;
+       vhost_task_start(vtsk); /* worker->vtsk is set before the start; vhost_worker() reads it */
+       return 0;
+
+free_worker:
+       kfree(worker);
+       dev->worker = NULL;
+       return ret;
+}
+
 /* Caller should have device mutex */
 long vhost_dev_set_owner(struct vhost_dev *dev)
 {
-       struct task_struct *worker;
        int err;
 
        /* Is there an owner already? */
@@ -583,36 +599,21 @@ long vhost_dev_set_owner(struct vhost_dev *dev)
 
        vhost_attach_mm(dev);
 
-       dev->kcov_handle = kcov_common_handle();
        if (dev->use_worker) {
-               worker = kthread_create(vhost_worker, dev,
-                                       "vhost-%d", current->pid);
-               if (IS_ERR(worker)) {
-                       err = PTR_ERR(worker);
-                       goto err_worker;
-               }
-
-               dev->worker = worker;
-               wake_up_process(worker); /* avoid contributing to loadavg */
-
-               err = vhost_attach_cgroups(dev);
+               err = vhost_worker_create(dev);
                if (err)
-                       goto err_cgroup;
+                       goto err_worker;
        }
 
        err = vhost_dev_alloc_iovecs(dev);
        if (err)
-               goto err_cgroup;
+               goto err_iovecs;
 
        return 0;
-err_cgroup:
-       if (dev->worker) {
-               kthread_stop(dev->worker);
-               dev->worker = NULL;
-       }
+err_iovecs:
+       vhost_worker_free(dev);
 err_worker:
        vhost_detach_mm(dev);
-       dev->kcov_handle = 0;
 err_mm:
        return err;
 }
@@ -703,12 +704,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
        dev->iotlb = NULL;
        vhost_clear_msg(dev);
        wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
-       WARN_ON(!llist_empty(&dev->work_list));
-       if (dev->worker) {
-               kthread_stop(dev->worker);
-               dev->worker = NULL;
-               dev->kcov_handle = 0;
-       }
+       vhost_worker_free(dev);
        vhost_detach_mm(dev);
 }
 EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
@@ -1829,7 +1825,7 @@ EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
 
 /* TODO: This is really inefficient.  We need something like get_user()
  * (instruction directly accesses the data, with an exception table entry
- * returning -EFAULT). See Documentation/x86/exception-tables.rst.
+ * returning -EFAULT). See Documentation/arch/x86/exception-tables.rst.
  */
 static int set_bit_to_user(int nr, void __user *addr)
 {