bpf: Factor out socket lookup functions for the TC hookpoint.
[platform/kernel/linux-starfive.git] / net / vmw_vsock / virtio_transport.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * virtio transport for vsock
4  *
5  * Copyright (C) 2013-2015 Red Hat, Inc.
6  * Author: Asias He <asias@redhat.com>
7  *         Stefan Hajnoczi <stefanha@redhat.com>
8  *
9  * Some of the code is take from Gerd Hoffmann <kraxel@redhat.com>'s
10  * early virtio-vsock proof-of-concept bits.
11  */
12 #include <linux/spinlock.h>
13 #include <linux/module.h>
14 #include <linux/list.h>
15 #include <linux/atomic.h>
16 #include <linux/virtio.h>
17 #include <linux/virtio_ids.h>
18 #include <linux/virtio_config.h>
19 #include <linux/virtio_vsock.h>
20 #include <net/sock.h>
21 #include <linux/mutex.h>
22 #include <net/af_vsock.h>
23
24 static struct workqueue_struct *virtio_vsock_workqueue;
25 static struct virtio_vsock __rcu *the_virtio_vsock;
26 static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */
27 static struct virtio_transport virtio_transport; /* forward declaration */
28
29 struct virtio_vsock {
30         struct virtio_device *vdev;
31         struct virtqueue *vqs[VSOCK_VQ_MAX];
32
33         /* Virtqueue processing is deferred to a workqueue */
34         struct work_struct tx_work;
35         struct work_struct rx_work;
36         struct work_struct event_work;
37
38         /* The following fields are protected by tx_lock.  vqs[VSOCK_VQ_TX]
39          * must be accessed with tx_lock held.
40          */
41         struct mutex tx_lock;
42         bool tx_run;
43
44         struct work_struct send_pkt_work;
45         spinlock_t send_pkt_list_lock;
46         struct list_head send_pkt_list;
47
48         atomic_t queued_replies;
49
50         /* The following fields are protected by rx_lock.  vqs[VSOCK_VQ_RX]
51          * must be accessed with rx_lock held.
52          */
53         struct mutex rx_lock;
54         bool rx_run;
55         int rx_buf_nr;
56         int rx_buf_max_nr;
57
58         /* The following fields are protected by event_lock.
59          * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held.
60          */
61         struct mutex event_lock;
62         bool event_run;
63         struct virtio_vsock_event event_list[8];
64
65         u32 guest_cid;
66         bool seqpacket_allow;
67 };
68
69 static u32 virtio_transport_get_local_cid(void)
70 {
71         struct virtio_vsock *vsock;
72         u32 ret;
73
74         rcu_read_lock();
75         vsock = rcu_dereference(the_virtio_vsock);
76         if (!vsock) {
77                 ret = VMADDR_CID_ANY;
78                 goto out_rcu;
79         }
80
81         ret = vsock->guest_cid;
82 out_rcu:
83         rcu_read_unlock();
84         return ret;
85 }
86
87 static void
88 virtio_transport_send_pkt_work(struct work_struct *work)
89 {
90         struct virtio_vsock *vsock =
91                 container_of(work, struct virtio_vsock, send_pkt_work);
92         struct virtqueue *vq;
93         bool added = false;
94         bool restart_rx = false;
95
96         mutex_lock(&vsock->tx_lock);
97
98         if (!vsock->tx_run)
99                 goto out;
100
101         vq = vsock->vqs[VSOCK_VQ_TX];
102
103         for (;;) {
104                 struct virtio_vsock_pkt *pkt;
105                 struct scatterlist hdr, buf, *sgs[2];
106                 int ret, in_sg = 0, out_sg = 0;
107                 bool reply;
108
109                 spin_lock_bh(&vsock->send_pkt_list_lock);
110                 if (list_empty(&vsock->send_pkt_list)) {
111                         spin_unlock_bh(&vsock->send_pkt_list_lock);
112                         break;
113                 }
114
115                 pkt = list_first_entry(&vsock->send_pkt_list,
116                                        struct virtio_vsock_pkt, list);
117                 list_del_init(&pkt->list);
118                 spin_unlock_bh(&vsock->send_pkt_list_lock);
119
120                 virtio_transport_deliver_tap_pkt(pkt);
121
122                 reply = pkt->reply;
123
124                 sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
125                 sgs[out_sg++] = &hdr;
126                 if (pkt->buf) {
127                         sg_init_one(&buf, pkt->buf, pkt->len);
128                         sgs[out_sg++] = &buf;
129                 }
130
131                 ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, pkt, GFP_KERNEL);
132                 /* Usually this means that there is no more space available in
133                  * the vq
134                  */
135                 if (ret < 0) {
136                         spin_lock_bh(&vsock->send_pkt_list_lock);
137                         list_add(&pkt->list, &vsock->send_pkt_list);
138                         spin_unlock_bh(&vsock->send_pkt_list_lock);
139                         break;
140                 }
141
142                 if (reply) {
143                         struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
144                         int val;
145
146                         val = atomic_dec_return(&vsock->queued_replies);
147
148                         /* Do we now have resources to resume rx processing? */
149                         if (val + 1 == virtqueue_get_vring_size(rx_vq))
150                                 restart_rx = true;
151                 }
152
153                 added = true;
154         }
155
156         if (added)
157                 virtqueue_kick(vq);
158
159 out:
160         mutex_unlock(&vsock->tx_lock);
161
162         if (restart_rx)
163                 queue_work(virtio_vsock_workqueue, &vsock->rx_work);
164 }
165
166 static int
167 virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
168 {
169         struct virtio_vsock *vsock;
170         int len = pkt->len;
171
172         rcu_read_lock();
173         vsock = rcu_dereference(the_virtio_vsock);
174         if (!vsock) {
175                 virtio_transport_free_pkt(pkt);
176                 len = -ENODEV;
177                 goto out_rcu;
178         }
179
180         if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) {
181                 virtio_transport_free_pkt(pkt);
182                 len = -ENODEV;
183                 goto out_rcu;
184         }
185
186         if (pkt->reply)
187                 atomic_inc(&vsock->queued_replies);
188
189         spin_lock_bh(&vsock->send_pkt_list_lock);
190         list_add_tail(&pkt->list, &vsock->send_pkt_list);
191         spin_unlock_bh(&vsock->send_pkt_list_lock);
192
193         queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
194
195 out_rcu:
196         rcu_read_unlock();
197         return len;
198 }
199
200 static int
201 virtio_transport_cancel_pkt(struct vsock_sock *vsk)
202 {
203         struct virtio_vsock *vsock;
204         struct virtio_vsock_pkt *pkt, *n;
205         int cnt = 0, ret;
206         LIST_HEAD(freeme);
207
208         rcu_read_lock();
209         vsock = rcu_dereference(the_virtio_vsock);
210         if (!vsock) {
211                 ret = -ENODEV;
212                 goto out_rcu;
213         }
214
215         spin_lock_bh(&vsock->send_pkt_list_lock);
216         list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
217                 if (pkt->vsk != vsk)
218                         continue;
219                 list_move(&pkt->list, &freeme);
220         }
221         spin_unlock_bh(&vsock->send_pkt_list_lock);
222
223         list_for_each_entry_safe(pkt, n, &freeme, list) {
224                 if (pkt->reply)
225                         cnt++;
226                 list_del(&pkt->list);
227                 virtio_transport_free_pkt(pkt);
228         }
229
230         if (cnt) {
231                 struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
232                 int new_cnt;
233
234                 new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
235                 if (new_cnt + cnt >= virtqueue_get_vring_size(rx_vq) &&
236                     new_cnt < virtqueue_get_vring_size(rx_vq))
237                         queue_work(virtio_vsock_workqueue, &vsock->rx_work);
238         }
239
240         ret = 0;
241
242 out_rcu:
243         rcu_read_unlock();
244         return ret;
245 }
246
247 static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
248 {
249         int buf_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
250         struct virtio_vsock_pkt *pkt;
251         struct scatterlist hdr, buf, *sgs[2];
252         struct virtqueue *vq;
253         int ret;
254
255         vq = vsock->vqs[VSOCK_VQ_RX];
256
257         do {
258                 pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
259                 if (!pkt)
260                         break;
261
262                 pkt->buf = kmalloc(buf_len, GFP_KERNEL);
263                 if (!pkt->buf) {
264                         virtio_transport_free_pkt(pkt);
265                         break;
266                 }
267
268                 pkt->buf_len = buf_len;
269                 pkt->len = buf_len;
270
271                 sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
272                 sgs[0] = &hdr;
273
274                 sg_init_one(&buf, pkt->buf, buf_len);
275                 sgs[1] = &buf;
276                 ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL);
277                 if (ret) {
278                         virtio_transport_free_pkt(pkt);
279                         break;
280                 }
281                 vsock->rx_buf_nr++;
282         } while (vq->num_free);
283         if (vsock->rx_buf_nr > vsock->rx_buf_max_nr)
284                 vsock->rx_buf_max_nr = vsock->rx_buf_nr;
285         virtqueue_kick(vq);
286 }
287
288 static void virtio_transport_tx_work(struct work_struct *work)
289 {
290         struct virtio_vsock *vsock =
291                 container_of(work, struct virtio_vsock, tx_work);
292         struct virtqueue *vq;
293         bool added = false;
294
295         vq = vsock->vqs[VSOCK_VQ_TX];
296         mutex_lock(&vsock->tx_lock);
297
298         if (!vsock->tx_run)
299                 goto out;
300
301         do {
302                 struct virtio_vsock_pkt *pkt;
303                 unsigned int len;
304
305                 virtqueue_disable_cb(vq);
306                 while ((pkt = virtqueue_get_buf(vq, &len)) != NULL) {
307                         virtio_transport_free_pkt(pkt);
308                         added = true;
309                 }
310         } while (!virtqueue_enable_cb(vq));
311
312 out:
313         mutex_unlock(&vsock->tx_lock);
314
315         if (added)
316                 queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
317 }
318
319 /* Is there space left for replies to rx packets? */
320 static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
321 {
322         struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX];
323         int val;
324
325         smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
326         val = atomic_read(&vsock->queued_replies);
327
328         return val < virtqueue_get_vring_size(vq);
329 }
330
331 /* event_lock must be held */
332 static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock,
333                                        struct virtio_vsock_event *event)
334 {
335         struct scatterlist sg;
336         struct virtqueue *vq;
337
338         vq = vsock->vqs[VSOCK_VQ_EVENT];
339
340         sg_init_one(&sg, event, sizeof(*event));
341
342         return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL);
343 }
344
345 /* event_lock must be held */
346 static void virtio_vsock_event_fill(struct virtio_vsock *vsock)
347 {
348         size_t i;
349
350         for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) {
351                 struct virtio_vsock_event *event = &vsock->event_list[i];
352
353                 virtio_vsock_event_fill_one(vsock, event);
354         }
355
356         virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
357 }
358
359 static void virtio_vsock_reset_sock(struct sock *sk)
360 {
361         /* vmci_transport.c doesn't take sk_lock here either.  At least we're
362          * under vsock_table_lock so the sock cannot disappear while we're
363          * executing.
364          */
365
366         sk->sk_state = TCP_CLOSE;
367         sk->sk_err = ECONNRESET;
368         sk_error_report(sk);
369 }
370
371 static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock)
372 {
373         struct virtio_device *vdev = vsock->vdev;
374         __le64 guest_cid;
375
376         vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid),
377                           &guest_cid, sizeof(guest_cid));
378         vsock->guest_cid = le64_to_cpu(guest_cid);
379 }
380
381 /* event_lock must be held */
382 static void virtio_vsock_event_handle(struct virtio_vsock *vsock,
383                                       struct virtio_vsock_event *event)
384 {
385         switch (le32_to_cpu(event->id)) {
386         case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET:
387                 virtio_vsock_update_guest_cid(vsock);
388                 vsock_for_each_connected_socket(&virtio_transport.transport,
389                                                 virtio_vsock_reset_sock);
390                 break;
391         }
392 }
393
394 static void virtio_transport_event_work(struct work_struct *work)
395 {
396         struct virtio_vsock *vsock =
397                 container_of(work, struct virtio_vsock, event_work);
398         struct virtqueue *vq;
399
400         vq = vsock->vqs[VSOCK_VQ_EVENT];
401
402         mutex_lock(&vsock->event_lock);
403
404         if (!vsock->event_run)
405                 goto out;
406
407         do {
408                 struct virtio_vsock_event *event;
409                 unsigned int len;
410
411                 virtqueue_disable_cb(vq);
412                 while ((event = virtqueue_get_buf(vq, &len)) != NULL) {
413                         if (len == sizeof(*event))
414                                 virtio_vsock_event_handle(vsock, event);
415
416                         virtio_vsock_event_fill_one(vsock, event);
417                 }
418         } while (!virtqueue_enable_cb(vq));
419
420         virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
421 out:
422         mutex_unlock(&vsock->event_lock);
423 }
424
425 static void virtio_vsock_event_done(struct virtqueue *vq)
426 {
427         struct virtio_vsock *vsock = vq->vdev->priv;
428
429         if (!vsock)
430                 return;
431         queue_work(virtio_vsock_workqueue, &vsock->event_work);
432 }
433
434 static void virtio_vsock_tx_done(struct virtqueue *vq)
435 {
436         struct virtio_vsock *vsock = vq->vdev->priv;
437
438         if (!vsock)
439                 return;
440         queue_work(virtio_vsock_workqueue, &vsock->tx_work);
441 }
442
443 static void virtio_vsock_rx_done(struct virtqueue *vq)
444 {
445         struct virtio_vsock *vsock = vq->vdev->priv;
446
447         if (!vsock)
448                 return;
449         queue_work(virtio_vsock_workqueue, &vsock->rx_work);
450 }
451
452 static bool virtio_transport_seqpacket_allow(u32 remote_cid);
453
454 static struct virtio_transport virtio_transport = {
455         .transport = {
456                 .module                   = THIS_MODULE,
457
458                 .get_local_cid            = virtio_transport_get_local_cid,
459
460                 .init                     = virtio_transport_do_socket_init,
461                 .destruct                 = virtio_transport_destruct,
462                 .release                  = virtio_transport_release,
463                 .connect                  = virtio_transport_connect,
464                 .shutdown                 = virtio_transport_shutdown,
465                 .cancel_pkt               = virtio_transport_cancel_pkt,
466
467                 .dgram_bind               = virtio_transport_dgram_bind,
468                 .dgram_dequeue            = virtio_transport_dgram_dequeue,
469                 .dgram_enqueue            = virtio_transport_dgram_enqueue,
470                 .dgram_allow              = virtio_transport_dgram_allow,
471
472                 .stream_dequeue           = virtio_transport_stream_dequeue,
473                 .stream_enqueue           = virtio_transport_stream_enqueue,
474                 .stream_has_data          = virtio_transport_stream_has_data,
475                 .stream_has_space         = virtio_transport_stream_has_space,
476                 .stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
477                 .stream_is_active         = virtio_transport_stream_is_active,
478                 .stream_allow             = virtio_transport_stream_allow,
479
480                 .seqpacket_dequeue        = virtio_transport_seqpacket_dequeue,
481                 .seqpacket_enqueue        = virtio_transport_seqpacket_enqueue,
482                 .seqpacket_allow          = virtio_transport_seqpacket_allow,
483                 .seqpacket_has_data       = virtio_transport_seqpacket_has_data,
484
485                 .notify_poll_in           = virtio_transport_notify_poll_in,
486                 .notify_poll_out          = virtio_transport_notify_poll_out,
487                 .notify_recv_init         = virtio_transport_notify_recv_init,
488                 .notify_recv_pre_block    = virtio_transport_notify_recv_pre_block,
489                 .notify_recv_pre_dequeue  = virtio_transport_notify_recv_pre_dequeue,
490                 .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
491                 .notify_send_init         = virtio_transport_notify_send_init,
492                 .notify_send_pre_block    = virtio_transport_notify_send_pre_block,
493                 .notify_send_pre_enqueue  = virtio_transport_notify_send_pre_enqueue,
494                 .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
495                 .notify_buffer_size       = virtio_transport_notify_buffer_size,
496         },
497
498         .send_pkt = virtio_transport_send_pkt,
499 };
500
501 static bool virtio_transport_seqpacket_allow(u32 remote_cid)
502 {
503         struct virtio_vsock *vsock;
504         bool seqpacket_allow;
505
506         seqpacket_allow = false;
507         rcu_read_lock();
508         vsock = rcu_dereference(the_virtio_vsock);
509         if (vsock)
510                 seqpacket_allow = vsock->seqpacket_allow;
511         rcu_read_unlock();
512
513         return seqpacket_allow;
514 }
515
516 static void virtio_transport_rx_work(struct work_struct *work)
517 {
518         struct virtio_vsock *vsock =
519                 container_of(work, struct virtio_vsock, rx_work);
520         struct virtqueue *vq;
521
522         vq = vsock->vqs[VSOCK_VQ_RX];
523
524         mutex_lock(&vsock->rx_lock);
525
526         if (!vsock->rx_run)
527                 goto out;
528
529         do {
530                 virtqueue_disable_cb(vq);
531                 for (;;) {
532                         struct virtio_vsock_pkt *pkt;
533                         unsigned int len;
534
535                         if (!virtio_transport_more_replies(vsock)) {
536                                 /* Stop rx until the device processes already
537                                  * pending replies.  Leave rx virtqueue
538                                  * callbacks disabled.
539                                  */
540                                 goto out;
541                         }
542
543                         pkt = virtqueue_get_buf(vq, &len);
544                         if (!pkt) {
545                                 break;
546                         }
547
548                         vsock->rx_buf_nr--;
549
550                         /* Drop short/long packets */
551                         if (unlikely(len < sizeof(pkt->hdr) ||
552                                      len > sizeof(pkt->hdr) + pkt->len)) {
553                                 virtio_transport_free_pkt(pkt);
554                                 continue;
555                         }
556
557                         pkt->len = len - sizeof(pkt->hdr);
558                         virtio_transport_deliver_tap_pkt(pkt);
559                         virtio_transport_recv_pkt(&virtio_transport, pkt);
560                 }
561         } while (!virtqueue_enable_cb(vq));
562
563 out:
564         if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
565                 virtio_vsock_rx_fill(vsock);
566         mutex_unlock(&vsock->rx_lock);
567 }
568
569 static int virtio_vsock_vqs_init(struct virtio_vsock *vsock)
570 {
571         struct virtio_device *vdev = vsock->vdev;
572         static const char * const names[] = {
573                 "rx",
574                 "tx",
575                 "event",
576         };
577         vq_callback_t *callbacks[] = {
578                 virtio_vsock_rx_done,
579                 virtio_vsock_tx_done,
580                 virtio_vsock_event_done,
581         };
582         int ret;
583
584         ret = virtio_find_vqs(vdev, VSOCK_VQ_MAX, vsock->vqs, callbacks, names,
585                               NULL);
586         if (ret < 0)
587                 return ret;
588
589         virtio_vsock_update_guest_cid(vsock);
590
591         virtio_device_ready(vdev);
592
593         mutex_lock(&vsock->tx_lock);
594         vsock->tx_run = true;
595         mutex_unlock(&vsock->tx_lock);
596
597         mutex_lock(&vsock->rx_lock);
598         virtio_vsock_rx_fill(vsock);
599         vsock->rx_run = true;
600         mutex_unlock(&vsock->rx_lock);
601
602         mutex_lock(&vsock->event_lock);
603         virtio_vsock_event_fill(vsock);
604         vsock->event_run = true;
605         mutex_unlock(&vsock->event_lock);
606
607         return 0;
608 }
609
610 static void virtio_vsock_vqs_del(struct virtio_vsock *vsock)
611 {
612         struct virtio_device *vdev = vsock->vdev;
613         struct virtio_vsock_pkt *pkt;
614
615         /* Reset all connected sockets when the VQs disappear */
616         vsock_for_each_connected_socket(&virtio_transport.transport,
617                                         virtio_vsock_reset_sock);
618
619         /* Stop all work handlers to make sure no one is accessing the device,
620          * so we can safely call virtio_reset_device().
621          */
622         mutex_lock(&vsock->rx_lock);
623         vsock->rx_run = false;
624         mutex_unlock(&vsock->rx_lock);
625
626         mutex_lock(&vsock->tx_lock);
627         vsock->tx_run = false;
628         mutex_unlock(&vsock->tx_lock);
629
630         mutex_lock(&vsock->event_lock);
631         vsock->event_run = false;
632         mutex_unlock(&vsock->event_lock);
633
634         /* Flush all device writes and interrupts, device will not use any
635          * more buffers.
636          */
637         virtio_reset_device(vdev);
638
639         mutex_lock(&vsock->rx_lock);
640         while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX])))
641                 virtio_transport_free_pkt(pkt);
642         mutex_unlock(&vsock->rx_lock);
643
644         mutex_lock(&vsock->tx_lock);
645         while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX])))
646                 virtio_transport_free_pkt(pkt);
647         mutex_unlock(&vsock->tx_lock);
648
649         spin_lock_bh(&vsock->send_pkt_list_lock);
650         while (!list_empty(&vsock->send_pkt_list)) {
651                 pkt = list_first_entry(&vsock->send_pkt_list,
652                                        struct virtio_vsock_pkt, list);
653                 list_del(&pkt->list);
654                 virtio_transport_free_pkt(pkt);
655         }
656         spin_unlock_bh(&vsock->send_pkt_list_lock);
657
658         /* Delete virtqueues and flush outstanding callbacks if any */
659         vdev->config->del_vqs(vdev);
660 }
661
662 static int virtio_vsock_probe(struct virtio_device *vdev)
663 {
664         struct virtio_vsock *vsock = NULL;
665         int ret;
666
667         ret = mutex_lock_interruptible(&the_virtio_vsock_mutex);
668         if (ret)
669                 return ret;
670
671         /* Only one virtio-vsock device per guest is supported */
672         if (rcu_dereference_protected(the_virtio_vsock,
673                                 lockdep_is_held(&the_virtio_vsock_mutex))) {
674                 ret = -EBUSY;
675                 goto out;
676         }
677
678         vsock = kzalloc(sizeof(*vsock), GFP_KERNEL);
679         if (!vsock) {
680                 ret = -ENOMEM;
681                 goto out;
682         }
683
684         vsock->vdev = vdev;
685
686         vsock->rx_buf_nr = 0;
687         vsock->rx_buf_max_nr = 0;
688         atomic_set(&vsock->queued_replies, 0);
689
690         mutex_init(&vsock->tx_lock);
691         mutex_init(&vsock->rx_lock);
692         mutex_init(&vsock->event_lock);
693         spin_lock_init(&vsock->send_pkt_list_lock);
694         INIT_LIST_HEAD(&vsock->send_pkt_list);
695         INIT_WORK(&vsock->rx_work, virtio_transport_rx_work);
696         INIT_WORK(&vsock->tx_work, virtio_transport_tx_work);
697         INIT_WORK(&vsock->event_work, virtio_transport_event_work);
698         INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work);
699
700         if (virtio_has_feature(vdev, VIRTIO_VSOCK_F_SEQPACKET))
701                 vsock->seqpacket_allow = true;
702
703         vdev->priv = vsock;
704
705         ret = virtio_vsock_vqs_init(vsock);
706         if (ret < 0)
707                 goto out;
708
709         rcu_assign_pointer(the_virtio_vsock, vsock);
710
711         mutex_unlock(&the_virtio_vsock_mutex);
712
713         return 0;
714
715 out:
716         kfree(vsock);
717         mutex_unlock(&the_virtio_vsock_mutex);
718         return ret;
719 }
720
721 static void virtio_vsock_remove(struct virtio_device *vdev)
722 {
723         struct virtio_vsock *vsock = vdev->priv;
724
725         mutex_lock(&the_virtio_vsock_mutex);
726
727         vdev->priv = NULL;
728         rcu_assign_pointer(the_virtio_vsock, NULL);
729         synchronize_rcu();
730
731         virtio_vsock_vqs_del(vsock);
732
733         /* Other works can be queued before 'config->del_vqs()', so we flush
734          * all works before to free the vsock object to avoid use after free.
735          */
736         flush_work(&vsock->rx_work);
737         flush_work(&vsock->tx_work);
738         flush_work(&vsock->event_work);
739         flush_work(&vsock->send_pkt_work);
740
741         mutex_unlock(&the_virtio_vsock_mutex);
742
743         kfree(vsock);
744 }
745
746 #ifdef CONFIG_PM_SLEEP
747 static int virtio_vsock_freeze(struct virtio_device *vdev)
748 {
749         struct virtio_vsock *vsock = vdev->priv;
750
751         mutex_lock(&the_virtio_vsock_mutex);
752
753         rcu_assign_pointer(the_virtio_vsock, NULL);
754         synchronize_rcu();
755
756         virtio_vsock_vqs_del(vsock);
757
758         mutex_unlock(&the_virtio_vsock_mutex);
759
760         return 0;
761 }
762
763 static int virtio_vsock_restore(struct virtio_device *vdev)
764 {
765         struct virtio_vsock *vsock = vdev->priv;
766         int ret;
767
768         mutex_lock(&the_virtio_vsock_mutex);
769
770         /* Only one virtio-vsock device per guest is supported */
771         if (rcu_dereference_protected(the_virtio_vsock,
772                                 lockdep_is_held(&the_virtio_vsock_mutex))) {
773                 ret = -EBUSY;
774                 goto out;
775         }
776
777         ret = virtio_vsock_vqs_init(vsock);
778         if (ret < 0)
779                 goto out;
780
781         rcu_assign_pointer(the_virtio_vsock, vsock);
782
783 out:
784         mutex_unlock(&the_virtio_vsock_mutex);
785         return ret;
786 }
787 #endif /* CONFIG_PM_SLEEP */
788
789 static struct virtio_device_id id_table[] = {
790         { VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID },
791         { 0 },
792 };
793
794 static unsigned int features[] = {
795         VIRTIO_VSOCK_F_SEQPACKET
796 };
797
798 static struct virtio_driver virtio_vsock_driver = {
799         .feature_table = features,
800         .feature_table_size = ARRAY_SIZE(features),
801         .driver.name = KBUILD_MODNAME,
802         .driver.owner = THIS_MODULE,
803         .id_table = id_table,
804         .probe = virtio_vsock_probe,
805         .remove = virtio_vsock_remove,
806 #ifdef CONFIG_PM_SLEEP
807         .freeze = virtio_vsock_freeze,
808         .restore = virtio_vsock_restore,
809 #endif
810 };
811
812 static int __init virtio_vsock_init(void)
813 {
814         int ret;
815
816         virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0);
817         if (!virtio_vsock_workqueue)
818                 return -ENOMEM;
819
820         ret = vsock_core_register(&virtio_transport.transport,
821                                   VSOCK_TRANSPORT_F_G2H);
822         if (ret)
823                 goto out_wq;
824
825         ret = register_virtio_driver(&virtio_vsock_driver);
826         if (ret)
827                 goto out_vci;
828
829         return 0;
830
831 out_vci:
832         vsock_core_unregister(&virtio_transport.transport);
833 out_wq:
834         destroy_workqueue(virtio_vsock_workqueue);
835         return ret;
836 }
837
838 static void __exit virtio_vsock_exit(void)
839 {
840         unregister_virtio_driver(&virtio_vsock_driver);
841         vsock_core_unregister(&virtio_transport.transport);
842         destroy_workqueue(virtio_vsock_workqueue);
843 }
844
845 module_init(virtio_vsock_init);
846 module_exit(virtio_vsock_exit);
847 MODULE_LICENSE("GPL v2");
848 MODULE_AUTHOR("Asias He");
849 MODULE_DESCRIPTION("virtio transport for vsock");
850 MODULE_DEVICE_TABLE(virtio, id_table);