af_vsock: implement send logic for SEQPACKET
authorArseny Krasnov <arseny.krasnov@kaspersky.com>
Fri, 11 Jun 2021 11:10:49 +0000 (14:10 +0300)
committerDavid S. Miller <davem@davemloft.net>
Fri, 11 Jun 2021 20:32:46 +0000 (13:32 -0700)
Update current stream enqueue function for SEQPACKET
support:
1) Call transport's seqpacket enqueue callback.
2) Return value from enqueue function is whole record length or error
   for SOCK_SEQPACKET.

Signed-off-by: Arseny Krasnov <arseny.krasnov@kaspersky.com>
Reviewed-by: Stefano Garzarella <sgarzare@redhat.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/af_vsock.h
net/vmw_vsock/af_vsock.c

index 4d7cf6b..d6745d8 100644 (file)
@@ -138,6 +138,8 @@ struct vsock_transport {
        /* SEQ_PACKET. */
        ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg,
                                     int flags);
+       int (*seqpacket_enqueue)(struct vsock_sock *vsk, struct msghdr *msg,
+                                size_t len);
 
        /* Notification. */
        int (*notify_poll_in)(struct vsock_sock *, size_t, bool *);
index 87ae26b..9e0cc07 100644 (file)
@@ -1808,9 +1808,13 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
                 * responsibility to check how many bytes we were able to send.
                 */
 
-               written = transport->stream_enqueue(
-                               vsk, msg,
-                               len - total_written);
+               if (sk->sk_type == SOCK_SEQPACKET) {
+                       written = transport->seqpacket_enqueue(vsk,
+                                               msg, len - total_written);
+               } else {
+                       written = transport->stream_enqueue(vsk,
+                                       msg, len - total_written);
+               }
                if (written < 0) {
                        err = -ENOMEM;
                        goto out_err;
@@ -1826,8 +1830,14 @@ static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
        }
 
 out_err:
-       if (total_written > 0)
-               err = total_written;
+       if (total_written > 0) {
+               /* Return number of written bytes only if:
+                * 1) SOCK_STREAM socket.
+                * 2) SOCK_SEQPACKET socket when whole buffer is sent.
+                */
+               if (sk->sk_type == SOCK_STREAM || total_written == len)
+                       err = total_written;
+       }
 out:
        release_sock(sk);
        return err;