virtio/vsock: rest of SOCK_SEQPACKET support
authorArseny Krasnov <arseny.krasnov@kaspersky.com>
Fri, 11 Jun 2021 11:13:06 +0000 (14:13 +0300)
committerDavid S. Miller <davem@davemloft.net>
Fri, 11 Jun 2021 20:32:47 +0000 (13:32 -0700)
Small updates to make SOCK_SEQPACKET work:
1) Send SHUTDOWN on socket close for SEQPACKET type.
2) Set SEQPACKET packet type during send.
3) Set 'VIRTIO_VSOCK_SEQ_EOR' bit in flags for last
   packet of message.
4) Implement data check function for SEQPACKET.
5) Check for max datagram size.

Signed-off-by: Arseny Krasnov <arseny.krasnov@kaspersky.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/linux/virtio_vsock.h
net/vmw_vsock/virtio_transport_common.c

index 1d9a302..35d7eed 100644 (file)
@@ -81,12 +81,17 @@ virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
                               struct msghdr *msg,
                               size_t len, int flags);
 
+int
+virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
+                                  struct msghdr *msg,
+                                  size_t len);
 ssize_t
 virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
                                   struct msghdr *msg,
                                   int flags);
 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk);
 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk);
+u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk);
 
 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
                                 struct vsock_sock *psk);
index 3a658ff..23704a6 100644 (file)
@@ -74,6 +74,10 @@ virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
                err = memcpy_from_msg(pkt->buf, info->msg, len);
                if (err)
                        goto out;
+
+               if (msg_data_left(info->msg) == 0 &&
+                   info->type == VIRTIO_VSOCK_TYPE_SEQPACKET)
+                       pkt->hdr.flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR);
        }
 
        trace_virtio_transport_alloc_pkt(src_cid, src_port,
@@ -187,7 +191,7 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
        struct virtio_vsock_pkt *pkt;
        u32 pkt_len = info->pkt_len;
 
-       info->type = VIRTIO_VSOCK_TYPE_STREAM;
+       info->type = virtio_transport_get_type(sk_vsock(vsk));
 
        t_ops = virtio_transport_get_ops(vsk);
        if (unlikely(!t_ops))
@@ -498,6 +502,26 @@ virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);
 
 int
+virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
+                                  struct msghdr *msg,
+                                  size_t len)
+{
+       struct virtio_vsock_sock *vvs = vsk->trans;
+
+       spin_lock_bh(&vvs->tx_lock);
+
+       if (len > vvs->peer_buf_alloc) {
+               spin_unlock_bh(&vvs->tx_lock);
+               return -EMSGSIZE;
+       }
+
+       spin_unlock_bh(&vvs->tx_lock);
+
+       return virtio_transport_stream_enqueue(vsk, msg, len);
+}
+EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue);
+
+int
 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
                               struct msghdr *msg,
                               size_t len, int flags)
@@ -519,6 +543,19 @@ s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
 }
 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
 
+u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk)
+{
+       struct virtio_vsock_sock *vvs = vsk->trans;
+       u32 msg_count;
+
+       spin_lock_bh(&vvs->rx_lock);
+       msg_count = vvs->msg_count;
+       spin_unlock_bh(&vvs->rx_lock);
+
+       return msg_count;
+}
+EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data);
+
 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
 {
        struct virtio_vsock_sock *vvs = vsk->trans;
@@ -931,7 +968,7 @@ void virtio_transport_release(struct vsock_sock *vsk)
        struct sock *sk = &vsk->sk;
        bool remove_sock = true;
 
-       if (sk->sk_type == SOCK_STREAM)
+       if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)
                remove_sock = virtio_transport_close(vsk);
 
        if (remove_sock) {