// SPDX-License-Identifier: GPL-2.0-only
/*
 * common code for virtio vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/sched/signal.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio_vsock.h>
#include <uapi/linux/vsockmon.h>

#include <net/sock.h>
#include <net/af_vsock.h>

#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>

/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)

/* Threshold for detecting small packets to copy */
#define GOOD_COPY_LEN  128

static const struct virtio_transport *
virtio_transport_get_ops(struct vsock_sock *vsk)
{
	const struct vsock_transport *t = vsock_core_get_transport(vsk);

	if (WARN_ON(!t))
		return NULL;

	return container_of(t, struct virtio_transport, transport);
}

/* Returns a new packet on success, otherwise returns NULL.
 *
 * NULL is returned if allocation fails, if the payload cannot be copied
 * from the message, or if the skb cannot be assigned to the socket.
 */
static struct sk_buff *
virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info,
			   size_t len,
			   u32 src_cid,
			   u32 src_port,
			   u32 dst_cid,
			   u32 dst_port)
{
	const size_t skb_len = VIRTIO_VSOCK_SKB_HEADROOM + len;
	struct virtio_vsock_hdr *hdr;
	struct sk_buff *skb;
	void *payload;
	int err;

	skb = virtio_vsock_alloc_skb(skb_len, GFP_KERNEL);
	if (!skb)
		return NULL;

	hdr = virtio_vsock_hdr(skb);
	hdr->type	= cpu_to_le16(info->type);
	hdr->op		= cpu_to_le16(info->op);
	hdr->src_cid	= cpu_to_le64(src_cid);
	hdr->dst_cid	= cpu_to_le64(dst_cid);
	hdr->src_port	= cpu_to_le32(src_port);
	hdr->dst_port	= cpu_to_le32(dst_port);
	hdr->flags	= cpu_to_le32(info->flags);
	hdr->len	= cpu_to_le32(len);

	if (info->msg && len > 0) {
		payload = skb_put(skb, len);
		err = memcpy_from_msg(payload, info->msg, len);
		if (err)
			goto out;

		if (msg_data_left(info->msg) == 0 &&
		    info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) {
			hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM);

			if (info->msg->msg_flags & MSG_EOR)
				hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
		}
	}

	if (info->reply)
		virtio_vsock_skb_set_reply(skb);

	trace_virtio_transport_alloc_pkt(src_cid, src_port,
					 dst_cid, dst_port,
					 len,
					 info->type,
					 info->op,
					 info->flags);

	if (info->vsk && !skb_set_owner_sk_safe(skb, sk_vsock(info->vsk))) {
		WARN_ONCE(1, "failed to allocate skb on vsock socket with sk_refcnt == 0\n");
		goto out;
	}

	return skb;

out:
	kfree_skb(skb);
	return NULL;
}
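
/* Sketch of the buffer built above, assuming virtio_vsock_alloc_skb()
 * reserves the VIRTIO_VSOCK_SKB_HEADROOM for the packet header (this is
 * a reading of the code above, not authoritative documentation):
 *
 *   +-------------------------+------------------------+
 *   | struct virtio_vsock_hdr | 'len' bytes of payload |
 *   +-------------------------+------------------------+
 *     virtio_vsock_hdr(skb)     skb->data, skb_put()
 */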

/* Packet capture */
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
	struct virtio_vsock_hdr *pkt_hdr;
	struct sk_buff *pkt = opaque;
	struct af_vsockmon_hdr *hdr;
	struct sk_buff *skb;
	size_t payload_len;
	void *payload_buf;

	/* A packet could be split to fit the RX buffer, so we can retrieve
	 * the payload length from the header and the buffer pointer taking
	 * care of the offset in the original packet.
	 */
	pkt_hdr = virtio_vsock_hdr(pkt);
	payload_len = pkt->len;
	payload_buf = pkt->data;

	skb = alloc_skb(sizeof(*hdr) + sizeof(*pkt_hdr) + payload_len,
			GFP_ATOMIC);
	if (!skb)
		return NULL;

	hdr = skb_put(skb, sizeof(*hdr));

	/* pkt->hdr is little-endian so no need to byteswap here */
	hdr->src_cid = pkt_hdr->src_cid;
	hdr->src_port = pkt_hdr->src_port;
	hdr->dst_cid = pkt_hdr->dst_cid;
	hdr->dst_port = pkt_hdr->dst_port;

	hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
	hdr->len = cpu_to_le16(sizeof(*pkt_hdr));
	memset(hdr->reserved, 0, sizeof(hdr->reserved));

	switch (le16_to_cpu(pkt_hdr->op)) {
	case VIRTIO_VSOCK_OP_REQUEST:
	case VIRTIO_VSOCK_OP_RESPONSE:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
		break;
	case VIRTIO_VSOCK_OP_RST:
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
		break;
	case VIRTIO_VSOCK_OP_RW:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
		break;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
		break;
	default:
		hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
		break;
	}

	skb_put_data(skb, pkt_hdr, sizeof(*pkt_hdr));

	if (payload_len)
		skb_put_data(skb, payload_buf, payload_len);

	return skb;
}

void virtio_transport_deliver_tap_pkt(struct sk_buff *skb)
{
	if (virtio_vsock_skb_tap_delivered(skb))
		return;

	vsock_deliver_tap(virtio_transport_build_skb, skb);
	virtio_vsock_skb_set_tap_delivered(skb);
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);

static u16 virtio_transport_get_type(struct sock *sk)
{
	if (sk->sk_type == SOCK_STREAM)
		return VIRTIO_VSOCK_TYPE_STREAM;
	else
		return VIRTIO_VSOCK_TYPE_SEQPACKET;
}

/* This function can only be used on connecting/connected sockets,
 * since a socket assigned to a transport is required.
 *
 * Do not use on listener sockets!
 */
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
					  struct virtio_vsock_pkt_info *info)
{
	u32 src_cid, src_port, dst_cid, dst_port;
	const struct virtio_transport *t_ops;
	struct virtio_vsock_sock *vvs;
	u32 pkt_len = info->pkt_len;
	u32 rest_len;
	int ret;

	info->type = virtio_transport_get_type(sk_vsock(vsk));

	t_ops = virtio_transport_get_ops(vsk);
	if (unlikely(!t_ops))
		return -EFAULT;

	src_cid = t_ops->transport.get_local_cid();
	src_port = vsk->local_addr.svm_port;
	if (!info->remote_cid) {
		dst_cid = vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	vvs = vsk->trans;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(vvs, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	rest_len = pkt_len;

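	/* The loop below splits the remaining bytes into packets no larger
	 * than VIRTIO_VSOCK_MAX_PKT_BUF_SIZE each. Illustrative numbers:
	 * with a 64 KiB limit, a 150000 byte send that obtained full credit
	 * goes out as two 65536 byte packets followed by an 18928 byte one.
	 */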
	do {
		struct sk_buff *skb;
		size_t skb_len;

		skb_len = min_t(u32, VIRTIO_VSOCK_MAX_PKT_BUF_SIZE, rest_len);

		skb = virtio_transport_alloc_skb(info, skb_len,
						 src_cid, src_port,
						 dst_cid, dst_port);
		if (!skb) {
			ret = -ENOMEM;
			break;
		}

		virtio_transport_inc_tx_pkt(vvs, skb);

		ret = t_ops->send_pkt(skb);
		if (ret < 0)
			break;

		/* Both the virtio and vhost 'send_pkt()' implementations
		 * return 'skb_len' on success, but for reliability use 'ret'
		 * instead of 'skb_len'. Also, if a partial send somehow
		 * happens (i.e. 'ret' != 'skb_len'), we break out of this
		 * loop, but still account the returned value via
		 * 'virtio_transport_put_credit()'.
		 */
		rest_len -= ret;

		if (WARN_ONCE(ret != skb_len,
			      "'send_pkt()' returns %i, but %zu expected\n",
			      ret, skb_len))
			break;
	} while (rest_len);

	virtio_transport_put_credit(vvs, rest_len);

	/* Return the number of bytes sent, if any data has been sent. */
	if (rest_len != pkt_len)
		ret = pkt_len - rest_len;

	return ret;
}

static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
					u32 len)
{
	if (vvs->rx_bytes + len > vvs->buf_alloc)
		return false;

	vvs->rx_bytes += len;
	return true;
}

static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
					u32 len)
{
	vvs->rx_bytes -= len;
	vvs->fwd_cnt += len;
}
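
/* Note on the counters above (descriptive, matching the code): rx_bytes
 * tracks bytes sitting in rx_queue that the application has not consumed
 * yet, bounded by our advertised buf_alloc, while fwd_cnt accumulates
 * bytes already forwarded to the application. fwd_cnt is echoed back to
 * the peer in every transmitted header (see virtio_transport_inc_tx_pkt()
 * below), which is how the peer learns that credit has been freed.
 */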

void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);

	spin_lock_bh(&vvs->rx_lock);
	vvs->last_fwd_cnt = vvs->fwd_cnt;
	hdr->fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
	hdr->buf_alloc = cpu_to_le32(vvs->buf_alloc);
	spin_unlock_bh(&vvs->rx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);

u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	u32 ret;

	if (!credit)
		return 0;

	spin_lock_bh(&vvs->tx_lock);
	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (ret > credit)
		ret = credit;
	vvs->tx_cnt += ret;
	spin_unlock_bh(&vvs->tx_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
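
/* Worked example of the credit arithmetic above (illustrative numbers):
 * if the peer advertised peer_buf_alloc = 65536, we have sent
 * tx_cnt = 70000 bytes in total, and the peer reports having consumed
 * peer_fwd_cnt = 20000 of them, then 70000 - 20000 = 50000 bytes are
 * still queued on the peer's side, so at most 65536 - 50000 = 15536
 * bytes of credit remain; a larger request is clamped to that value.
 */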

void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	if (!credit)
		return;

	spin_lock_bh(&vvs->tx_lock);
	vvs->tx_cnt -= credit;
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);

static int virtio_transport_send_credit_update(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

static ssize_t
virtio_transport_stream_do_peek(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sk_buff *skb;
	size_t total = 0;
	int err;

	spin_lock_bh(&vvs->rx_lock);

	skb_queue_walk(&vvs->rx_queue, skb) {
		size_t bytes;

		bytes = len - total;
		if (bytes > skb->len)
			bytes = skb->len;

		spin_unlock_bh(&vvs->rx_lock);

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		err = memcpy_to_msg(msg, skb->data, bytes);
		if (err)
			goto out;

		total += bytes;

		spin_lock_bh(&vvs->rx_lock);

		if (total == len)
			break;
	}

	spin_unlock_bh(&vvs->rx_lock);

	return total;

out:
	if (total)
		err = total;
	return err;
}

static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	size_t bytes, total = 0;
	struct sk_buff *skb;
	int err = -EFAULT;
	u32 free_space;

	spin_lock_bh(&vvs->rx_lock);

	if (WARN_ONCE(skb_queue_empty(&vvs->rx_queue) && vvs->rx_bytes,
		      "rx_queue is empty, but rx_bytes is non-zero\n")) {
		spin_unlock_bh(&vvs->rx_lock);
		return err;
	}

	while (total < len && !skb_queue_empty(&vvs->rx_queue)) {
		skb = skb_peek(&vvs->rx_queue);

		bytes = len - total;
		if (bytes > skb->len)
			bytes = skb->len;

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);

		err = memcpy_to_msg(msg, skb->data, bytes);
		if (err)
			goto out;

		spin_lock_bh(&vvs->rx_lock);

		total += bytes;
		skb_pull(skb, bytes);

		if (skb->len == 0) {
			u32 pkt_len = le32_to_cpu(virtio_vsock_hdr(skb)->len);

			virtio_transport_dec_rx_pkt(vvs, pkt_len);
			__skb_unlink(skb, &vvs->rx_queue);
			consume_skb(skb);
		}
	}

	free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt);

	spin_unlock_bh(&vvs->rx_lock);

	/* To reduce the number of credit update messages,
	 * don't update credits as long as lots of space is available.
	 * Note: the limit chosen here is arbitrary. Setting the limit
	 * too high causes extra messages. Too low causes transmitter
	 * stalls. As stalls are in theory more expensive than extra
	 * messages, we set the limit to a high value. TODO: experiment
	 * with different values.
	 */
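	/* Numeric illustration (made-up values): with buf_alloc = 256 KiB
	 * and VIRTIO_VSOCK_MAX_PKT_BUF_SIZE = 64 KiB, an update is sent
	 * once more than 192 KiB of the receive buffer has been consumed
	 * since the last header we transmitted (last_fwd_cnt).
	 */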
	if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
		virtio_transport_send_credit_update(vsk);

	return total;

out:
	if (total)
		err = total;
	return err;
}

static ssize_t
virtio_transport_seqpacket_do_peek(struct vsock_sock *vsk,
				   struct msghdr *msg)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sk_buff *skb;
	size_t total, len;

	spin_lock_bh(&vvs->rx_lock);

	if (!vvs->msg_count) {
		spin_unlock_bh(&vvs->rx_lock);
		return 0;
	}

	total = 0;
	len = msg_data_left(msg);

	skb_queue_walk(&vvs->rx_queue, skb) {
		struct virtio_vsock_hdr *hdr;

		if (total < len) {
			size_t bytes;
			int err;

			bytes = len - total;
			if (bytes > skb->len)
				bytes = skb->len;

			spin_unlock_bh(&vvs->rx_lock);

			/* sk_lock is held by caller so no one else can dequeue.
			 * Unlock rx_lock since memcpy_to_msg() may sleep.
			 */
			err = memcpy_to_msg(msg, skb->data, bytes);
			if (err)
				return err;

			spin_lock_bh(&vvs->rx_lock);
		}

		total += skb->len;
		hdr = virtio_vsock_hdr(skb);

		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
			if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR)
				msg->msg_flags |= MSG_EOR;

			break;
		}
	}

	spin_unlock_bh(&vvs->rx_lock);

	return total;
}

static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
						 struct msghdr *msg,
						 int flags)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	int dequeued_len = 0;
	size_t user_buf_len = msg_data_left(msg);
	bool msg_ready = false;
	struct sk_buff *skb;

	spin_lock_bh(&vvs->rx_lock);

	if (vvs->msg_count == 0) {
		spin_unlock_bh(&vvs->rx_lock);
		return 0;
	}

	while (!msg_ready) {
		struct virtio_vsock_hdr *hdr;
		size_t pkt_len;

		skb = __skb_dequeue(&vvs->rx_queue);
		if (!skb)
			break;
		hdr = virtio_vsock_hdr(skb);
		pkt_len = (size_t)le32_to_cpu(hdr->len);

		if (dequeued_len >= 0) {
			size_t bytes_to_copy;

			bytes_to_copy = min(user_buf_len, pkt_len);

			if (bytes_to_copy) {
				int err;

				/* sk_lock is held by caller so no one else can dequeue.
				 * Unlock rx_lock since memcpy_to_msg() may sleep.
				 */
				spin_unlock_bh(&vvs->rx_lock);

				err = memcpy_to_msg(msg, skb->data, bytes_to_copy);
				if (err) {
					/* Copy of message failed. Rest of
					 * fragments will be freed without copy.
					 */
					dequeued_len = err;
				} else {
					user_buf_len -= bytes_to_copy;
				}

				spin_lock_bh(&vvs->rx_lock);
			}

			if (dequeued_len >= 0)
				dequeued_len += pkt_len;
		}

		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) {
			msg_ready = true;
			vvs->msg_count--;

			if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR)
				msg->msg_flags |= MSG_EOR;
		}

		virtio_transport_dec_rx_pkt(vvs, pkt_len);
		kfree_skb(skb);
	}

	spin_unlock_bh(&vvs->rx_lock);

	virtio_transport_send_credit_update(vsk);

	return dequeued_len;
}

ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len, int flags)
{
	if (flags & MSG_PEEK)
		return virtio_transport_stream_do_peek(vsk, msg, len);
	else
		return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);

ssize_t
virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   int flags)
{
	if (flags & MSG_PEEK)
		return virtio_transport_seqpacket_do_peek(vsk, msg);
	else
		return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);

int
virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	spin_lock_bh(&vvs->tx_lock);

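	/* A message larger than the peer's entire receive buffer can never
	 * complete: the receiver frees credit only as it consumes data, so
	 * the transfer would stall waiting for credit that cannot arrive.
	 * Reject it up front instead.
	 */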
	if (len > vvs->peer_buf_alloc) {
		spin_unlock_bh(&vvs->tx_lock);
		return -EMSGSIZE;
	}

	spin_unlock_bh(&vvs->tx_lock);

	return virtio_transport_stream_enqueue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue);

int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);

s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->rx_lock);
	bytes = vvs->rx_bytes;
	spin_unlock_bh(&vvs->rx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);

u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	u32 msg_count;

	spin_lock_bh(&vvs->rx_lock);
	msg_count = vvs->msg_count;
	spin_unlock_bh(&vvs->rx_lock);

	return msg_count;
}
EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data);

static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

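	/* The (s64) cast below is an added fix: it keeps the subtraction
	 * signed, so a transient overshoot (in-flight bytes exceeding
	 * peer_buf_alloc, e.g. after the peer shrank its buffer) clamps to
	 * 0 instead of the u32 arithmetic wrapping to a huge positive value.
	 */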
	bytes = (s64)vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (bytes < 0)
		bytes = 0;

	return bytes;
}

s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->tx_lock);
	bytes = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);

int virtio_transport_do_socket_init(struct vsock_sock *vsk,
				    struct vsock_sock *psk)
{
	struct virtio_vsock_sock *vvs;

	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
	if (!vvs)
		return -ENOMEM;

	vsk->trans = vvs;
	vvs->vsk = vsk;
	if (psk && psk->trans) {
		struct virtio_vsock_sock *ptrans = psk->trans;

		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
	}

	if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
		vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_alloc = vsk->buffer_size;

	spin_lock_init(&vvs->rx_lock);
	spin_lock_init(&vvs->tx_lock);
	skb_queue_head_init(&vvs->rx_queue);

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);

/* sk_lock held by the caller */
void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		*val = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_alloc = *val;

	virtio_transport_send_credit_update(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);

int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
				size_t target,
				bool *data_ready_now)
{
	*data_ready_now = vsock_stream_has_data(vsk) >= target;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);

int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
				 size_t target,
				 bool *space_avail_now)
{
	s64 free_space;

	free_space = vsock_stream_has_space(vsk);
	if (free_space > 0)
		*space_avail_now = true;
	else if (free_space == 0)
		*space_avail_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);

int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);

u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
	return vsk->buffer_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);

int virtio_transport_connect(struct vsock_sock *vsk)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_REQUEST,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);

int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_SHUTDOWN,
		.flags = (mode & RCV_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
			 (mode & SEND_SHUTDOWN ?
			  VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);

ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
				struct msghdr *msg,
				size_t len)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RW,
		.msg = msg,
		.pkt_len = len,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);

void virtio_transport_destruct(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);

static int virtio_transport_reset(struct vsock_sock *vsk,
				  struct sk_buff *skb)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.reply = !!skb,
		.vsk = vsk,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (skb && le16_to_cpu(virtio_vsock_hdr(skb)->op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return virtio_transport_send_pkt_info(vsk, &info);
}

/* Normally packets are associated with a socket.  There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
					  struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = le16_to_cpu(hdr->type),
		.reply = true,
	};
	struct sk_buff *reply;

	/* Send RST only if the original pkt is not a RST pkt */
	if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	if (!t)
		return -ENOTCONN;

	reply = virtio_transport_alloc_skb(&info, 0,
					   le64_to_cpu(hdr->dst_cid),
					   le32_to_cpu(hdr->dst_port),
					   le64_to_cpu(hdr->src_cid),
					   le32_to_cpu(hdr->src_port));
	if (!reply)
		return -ENOMEM;

	return t->send_pkt(reply);
}

/* This function should be called with sk_lock held and SOCK_DONE set */
static void virtio_transport_remove_sock(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	/* We don't need to take rx_lock, as the socket is closing and we are
	 * removing it.
	 */
	__skb_queue_purge(&vvs->rx_queue);
	vsock_remove_sock(vsk);
}

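/* Block the closing task until the peer acknowledges the shutdown
 * (SOCK_DONE gets set by the RST/close path), a signal arrives, or the
 * lingering timeout expires. Only reached when SO_LINGER is in effect;
 * see virtio_transport_close() below.
 */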
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT_FUNC(wait, woken_wake_function);

		add_wait_queue(sk_sleep(sk), &wait);

		do {
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE), &wait))
				break;
		} while (!signal_pending(current) && timeout);

		remove_wait_queue(sk_sleep(sk), &wait);
	}
}

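/* Finish closing: mark the socket done, wake up any waiters, and if the
 * delayed close_work was pending (scheduled by virtio_transport_close()
 * below), complete the removal and drop the reference taken when the
 * work was scheduled.
 */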
static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		virtio_transport_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}

static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}

/* User context, vsk->sk is locked */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	if (!(sk->sk_state == TCP_ESTABLISHED ||
	      sk->sk_state == TCP_CLOSING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE))
		return true;

	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}

void virtio_transport_release(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
		remove_sock = virtio_transport_close(vsk);

	if (remove_sock) {
		sock_set_flag(sk, SOCK_DONE);
		virtio_transport_remove_sock(vsk);
	}
}
EXPORT_SYMBOL_GPL(virtio_transport_release);

static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	int skerr;
	int err;

	switch (le16_to_cpu(hdr->op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = TCP_ESTABLISHED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	virtio_transport_reset(vsk, skb);
	sk->sk_state = TCP_CLOSE;
	sk->sk_err = skerr;
	sk_error_report(sk);
	return err;
}

static void
virtio_transport_recv_enqueue(struct vsock_sock *vsk,
			      struct sk_buff *skb)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool can_enqueue, free_pkt = false;
	struct virtio_vsock_hdr *hdr;
	u32 len;

	hdr = virtio_vsock_hdr(skb);
	len = le32_to_cpu(hdr->len);

	spin_lock_bh(&vvs->rx_lock);

	can_enqueue = virtio_transport_inc_rx_pkt(vvs, len);
	if (!can_enqueue) {
		free_pkt = true;
		goto out;
	}

	if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)
		vvs->msg_count++;

	/* Try to copy small packets into the buffer of the last packet
	 * queued, to avoid wasting memory queueing an entire buffer for a
	 * small payload.
	 */
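	/* Illustration (made-up sizes): if a 4 KiB skb queued earlier holds
	 * only 100 bytes of payload, a following 64-byte packet (below
	 * GOOD_COPY_LEN) is memcpy'd into that skb's tailroom instead of
	 * pinning another full receive buffer in the queue.
	 */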
	if (len <= GOOD_COPY_LEN && !skb_queue_empty(&vvs->rx_queue)) {
		struct virtio_vsock_hdr *last_hdr;
		struct sk_buff *last_skb;

		last_skb = skb_peek_tail(&vvs->rx_queue);
		last_hdr = virtio_vsock_hdr(last_skb);

		/* If there is space in the last packet queued, we copy the
		 * new packet into its buffer. We avoid this if the last
		 * packet queued has VIRTIO_VSOCK_SEQ_EOM set, because it
		 * delimits a SEQPACKET message, so 'skb' would be the first
		 * packet of a new message.
		 */
		if (skb->len < skb_tailroom(last_skb) &&
		    !(le32_to_cpu(last_hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)) {
			memcpy(skb_put(last_skb, skb->len), skb->data, skb->len);
			free_pkt = true;
			last_hdr->flags |= hdr->flags;
			le32_add_cpu(&last_hdr->len, len);
			goto out;
		}
	}

	__skb_queue_tail(&vvs->rx_queue, skb);

out:
	spin_unlock_bh(&vvs->rx_lock);
	if (free_pkt)
		kfree_skb(skb);
}

static int
virtio_transport_recv_connected(struct sock *sk,
				struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	int err = 0;

	switch (le16_to_cpu(hdr->op)) {
	case VIRTIO_VSOCK_OP_RW:
		virtio_transport_recv_enqueue(vsk, skb);
		vsock_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
		virtio_transport_send_credit_update(vsk);
		break;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
		    vsock_stream_has_data(vsk) <= 0 &&
		    !sock_flag(sk, SOCK_DONE)) {
			(void)virtio_transport_reset(vsk, NULL);
			virtio_transport_do_close(vsk, true);
		}
		if (le32_to_cpu(hdr->flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	kfree_skb(skb);
	return err;
}

static void
virtio_transport_recv_disconnecting(struct sock *sk,
				    struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);

	if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST)
		virtio_transport_do_close(vsk, true);
}

static int
virtio_transport_send_response(struct vsock_sock *vsk,
			       struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RESPONSE,
		.remote_cid = le64_to_cpu(hdr->src_cid),
		.remote_port = le32_to_cpu(hdr->src_port),
		.reply = true,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}

static bool virtio_transport_space_update(struct sock *sk,
					  struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool space_available;

	/* Listener sockets are not associated with any transport, so we are
	 * not able to take the state to see if there is space available in
	 * the remote peer, but since they are only used to receive requests,
	 * we can assume that there is always space available in the other
	 * peer.
	 */
	if (!vvs)
		return true;

	/* buf_alloc and fwd_cnt are always included in the hdr */
	spin_lock_bh(&vvs->tx_lock);
	vvs->peer_buf_alloc = le32_to_cpu(hdr->buf_alloc);
	vvs->peer_fwd_cnt = le32_to_cpu(hdr->fwd_cnt);
	space_available = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);
	return space_available;
}

/* Handle server socket */
static int
virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb,
			     struct virtio_transport *t)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;
	int ret;

	if (le16_to_cpu(hdr->op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset_no_sock(t, skb);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset_no_sock(t, skb);
		return -ENOMEM;
	}

	child = vsock_create_connected(sk);
	if (!child) {
		virtio_transport_reset_no_sock(t, skb);
		return -ENOMEM;
	}

	sk_acceptq_added(sk);

	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = TCP_ESTABLISHED;

	vchild = vsock_sk(child);
	vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid),
			le32_to_cpu(hdr->dst_port));
	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid),
			le32_to_cpu(hdr->src_port));

	ret = vsock_assign_transport(vchild, vsk);
	/* The transport assigned (based on remote_addr) must be the same one
	 * on which we received the request.
	 */
	if (ret || vchild->transport != &t->transport) {
		release_sock(child);
		virtio_transport_reset_no_sock(t, skb);
		sock_put(child);
		return ret;
	}

	if (virtio_transport_space_update(child, skb))
		child->sk_write_space(child);

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, skb);

	release_sock(child);

	sk->sk_data_ready(sk);
	return 0;
}

static bool virtio_transport_valid_type(u16 type)
{
	return (type == VIRTIO_VSOCK_TYPE_STREAM) ||
	       (type == VIRTIO_VSOCK_TYPE_SEQPACKET);
}

/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
void virtio_transport_recv_pkt(struct virtio_transport *t,
			       struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
	struct sockaddr_vm src, dst;
	struct vsock_sock *vsk;
	struct sock *sk;
	bool space_available;

	vsock_addr_init(&src, le64_to_cpu(hdr->src_cid),
			le32_to_cpu(hdr->src_port));
	vsock_addr_init(&dst, le64_to_cpu(hdr->dst_cid),
			le32_to_cpu(hdr->dst_port));

	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
					dst.svm_cid, dst.svm_port,
					le32_to_cpu(hdr->len),
					le16_to_cpu(hdr->type),
					le16_to_cpu(hdr->op),
					le32_to_cpu(hdr->flags),
					le32_to_cpu(hdr->buf_alloc),
					le32_to_cpu(hdr->fwd_cnt));

	if (!virtio_transport_valid_type(le16_to_cpu(hdr->type))) {
		(void)virtio_transport_reset_no_sock(t, skb);
		goto free_pkt;
	}

	/* The socket must be in the connected or bound table;
	 * otherwise send a reset back.
	 */
	sk = vsock_find_connected_socket(&src, &dst);
	if (!sk) {
		sk = vsock_find_bound_socket(&dst);
		if (!sk) {
			(void)virtio_transport_reset_no_sock(t, skb);
			goto free_pkt;
		}
	}

	if (virtio_transport_get_type(sk) != le16_to_cpu(hdr->type)) {
		(void)virtio_transport_reset_no_sock(t, skb);
		sock_put(sk);
		goto free_pkt;
	}

	if (!skb_set_owner_sk_safe(skb, sk)) {
		WARN_ONCE(1, "receiving vsock socket has sk_refcnt == 0\n");
		goto free_pkt;
	}

	vsk = vsock_sk(sk);

	lock_sock(sk);

	/* Check if sk has been closed before lock_sock */
	if (sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset_no_sock(t, skb);
		release_sock(sk);
		sock_put(sk);
		goto free_pkt;
	}

	space_available = virtio_transport_space_update(sk, skb);

	/* Update CID in case it has changed after a transport reset event */
	if (vsk->local_addr.svm_cid != VMADDR_CID_ANY)
		vsk->local_addr.svm_cid = dst.svm_cid;

	if (space_available)
		sk->sk_write_space(sk);

	switch (sk->sk_state) {
	case TCP_LISTEN:
		virtio_transport_recv_listen(sk, skb, t);
		kfree_skb(skb);
		break;
	case TCP_SYN_SENT:
		virtio_transport_recv_connecting(sk, skb);
		kfree_skb(skb);
		break;
	case TCP_ESTABLISHED:
		virtio_transport_recv_connected(sk, skb);
		break;
	case TCP_CLOSING:
		virtio_transport_recv_disconnecting(sk, skb);
		kfree_skb(skb);
		break;
	default:
		(void)virtio_transport_reset_no_sock(t, skb);
		kfree_skb(skb);
		break;
	}

	release_sock(sk);

	/* Release refcnt obtained when we fetched this socket out of the
	 * bound or connected list.
	 */
	sock_put(sk);
	return;

free_pkt:
	kfree_skb(skb);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);

/* Remove skbs found in a queue that have a vsk that matches.
 *
 * Each skb is freed.
 *
 * Returns the count of skbs that were reply packets.
 */
int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *queue)
{
	struct sk_buff_head freeme;
	struct sk_buff *skb, *tmp;
	int cnt = 0;

	skb_queue_head_init(&freeme);

	spin_lock_bh(&queue->lock);
	skb_queue_walk_safe(queue, skb, tmp) {
		if (vsock_sk(skb->sk) != vsk)
			continue;

		__skb_unlink(skb, queue);
		__skb_queue_tail(&freeme, skb);

		if (virtio_vsock_skb_reply(skb))
			cnt++;
	}
	spin_unlock_bh(&queue->lock);

	__skb_queue_purge(&freeme);

	return cnt;
}
EXPORT_SYMBOL_GPL(virtio_transport_purge_skbs);

int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_actor)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct sock *sk = sk_vsock(vsk);
	struct sk_buff *skb;
	int off = 0;
	int err;

	spin_lock_bh(&vvs->rx_lock);
	/* Use __skb_recv_datagram() for race-free handling of the receive. It
	 * works for types other than dgrams.
	 */
	skb = __skb_recv_datagram(sk, &vvs->rx_queue, MSG_DONTWAIT, &off, &err);
	spin_unlock_bh(&vvs->rx_lock);

	if (!skb)
		return err;

	return recv_actor(sk, skb);
}
EXPORT_SYMBOL_GPL(virtio_transport_read_skb);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");