Merge git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net
[platform/kernel/linux-starfive.git] / net / vmw_vsock / af_vsock.c
index ae11311..3e02cc3 100644 (file)
@@ -415,8 +415,8 @@ static void vsock_deassign_transport(struct vsock_sock *vsk)
 
 /* Assign a transport to a socket and call the .init transport callback.
  *
- * Note: for stream socket this must be called when vsk->remote_addr is set
- * (e.g. during the connect() or when a connection request on a listener
+ * Note: for connection oriented socket this must be called when vsk->remote_addr
+ * is set (e.g. during the connect() or when a connection request on a listener
  * socket is received).
  * The vsk->remote_addr is used to decide which transport to use:
  *  - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if
@@ -452,6 +452,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
                new_transport = transport_dgram;
                break;
        case SOCK_STREAM:
+       case SOCK_SEQPACKET:
                if (vsock_use_local_transport(remote_cid))
                        new_transport = transport_local;
                else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g ||
@@ -469,10 +470,10 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
                        return 0;
 
                /* transport->release() must be called with sock lock acquired.
-                * This path can only be taken during vsock_stream_connect(),
-                * where we have already held the sock lock.
-                * In the other cases, this function is called on a new socket
-                * which is not assigned to any transport.
+                * This path can only be taken during vsock_connect(), where we
+                * have already held the sock lock. In the other cases, this
+                * function is called on a new socket which is not assigned to
+                * any transport.
                 */
                vsk->transport->release(vsk);
                vsock_deassign_transport(vsk);
@@ -484,6 +485,14 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
        if (!new_transport || !try_module_get(new_transport->module))
                return -ENODEV;
 
+       if (sk->sk_type == SOCK_SEQPACKET) {
+               if (!new_transport->seqpacket_allow ||
+                   !new_transport->seqpacket_allow(remote_cid)) {
+                       module_put(new_transport->module);
+                       return -ESOCKTNOSUPPORT;
+               }
+       }
+
        ret = new_transport->init(vsk, psk);
        if (ret) {
                module_put(new_transport->module);
@@ -604,8 +613,8 @@ out:
 
 /**** SOCKET OPERATIONS ****/
 
-static int __vsock_bind_stream(struct vsock_sock *vsk,
-                              struct sockaddr_vm *addr)
+static int __vsock_bind_connectible(struct vsock_sock *vsk,
+                                   struct sockaddr_vm *addr)
 {
        static u32 port;
        struct sockaddr_vm new_addr;
@@ -649,9 +658,10 @@ static int __vsock_bind_stream(struct vsock_sock *vsk,
 
        vsock_addr_init(&vsk->local_addr, new_addr.svm_cid, new_addr.svm_port);
 
-       /* Remove stream sockets from the unbound list and add them to the hash
-        * table for easy lookup by its address.  The unbound list is simply an
-        * extra entry at the end of the hash table, a trick used by AF_UNIX.
+       /* Remove connection oriented sockets from the unbound list and add them
+        * to the hash table for easy lookup by its address.  The unbound list
+        * is simply an extra entry at the end of the hash table, a trick used
+        * by AF_UNIX.
         */
        __vsock_remove_bound(vsk);
        __vsock_insert_bound(vsock_bound_sockets(&vsk->local_addr), vsk);
@@ -684,8 +694,9 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
 
        switch (sk->sk_socket->type) {
        case SOCK_STREAM:
+       case SOCK_SEQPACKET:
                spin_lock_bh(&vsock_table_lock);
-               retval = __vsock_bind_stream(vsk, addr);
+               retval = __vsock_bind_connectible(vsk, addr);
                spin_unlock_bh(&vsock_table_lock);
                break;
 
@@ -768,6 +779,11 @@ static struct sock *__vsock_create(struct net *net,
        return sk;
 }
 
+static bool sock_type_connectible(u16 type)
+{
+       return (type == SOCK_STREAM) || (type == SOCK_SEQPACKET);
+}
+
 static void __vsock_release(struct sock *sk, int level)
 {
        if (sk) {
@@ -786,7 +802,7 @@ static void __vsock_release(struct sock *sk, int level)
 
                if (vsk->transport)
                        vsk->transport->release(vsk);
-               else if (sk->sk_type == SOCK_STREAM)
+               else if (sock_type_connectible(sk->sk_type))
                        vsock_remove_sock(vsk);
 
                sock_orphan(sk);
@@ -844,6 +860,16 @@ s64 vsock_stream_has_data(struct vsock_sock *vsk)
 }
 EXPORT_SYMBOL_GPL(vsock_stream_has_data);
 
+static s64 vsock_connectible_has_data(struct vsock_sock *vsk)
+{
+       struct sock *sk = sk_vsock(vsk);
+
+       if (sk->sk_type == SOCK_SEQPACKET)
+               return vsk->transport->seqpacket_has_data(vsk);
+       else
+               return vsock_stream_has_data(vsk);
+}
+
 s64 vsock_stream_has_space(struct vsock_sock *vsk)
 {
        return vsk->transport->stream_has_space(vsk);
@@ -937,10 +963,10 @@ static int vsock_shutdown(struct socket *sock, int mode)
        if ((mode & ~SHUTDOWN_MASK) || !mode)
                return -EINVAL;
 
-       /* If this is a STREAM socket and it is not connected then bail out
-        * immediately.  If it is a DGRAM socket then we must first kick the
-        * socket so that it wakes up from any sleeping calls, for example
-        * recv(), and then afterwards return the error.
+       /* If this is a connection oriented socket and it is not connected then
+        * bail out immediately.  If it is a DGRAM socket then we must first
+        * kick the socket so that it wakes up from any sleeping calls, for
+        * example recv(), and then afterwards return the error.
         */
 
        sk = sock->sk;
@@ -948,7 +974,7 @@ static int vsock_shutdown(struct socket *sock, int mode)
        lock_sock(sk);
        if (sock->state == SS_UNCONNECTED) {
                err = -ENOTCONN;
-               if (sk->sk_type == SOCK_STREAM)
+               if (sock_type_connectible(sk->sk_type))
                        goto out;
        } else {
                sock->state = SS_DISCONNECTING;
@@ -961,7 +987,7 @@ static int vsock_shutdown(struct socket *sock, int mode)
                sk->sk_shutdown |= mode;
                sk->sk_state_change(sk);
 
-               if (sk->sk_type == SOCK_STREAM) {
+               if (sock_type_connectible(sk->sk_type)) {
                        sock_reset_flag(sk, SOCK_DONE);
                        vsock_send_shutdown(sk, mode);
                }
@@ -1016,7 +1042,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
                if (!(sk->sk_shutdown & SEND_SHUTDOWN))
                        mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
 
-       } else if (sock->type == SOCK_STREAM) {
+       } else if (sock_type_connectible(sk->sk_type)) {
                const struct vsock_transport *transport;
 
                lock_sock(sk);
@@ -1255,7 +1281,7 @@ static void vsock_connect_timeout(struct work_struct *work)
            (sk->sk_shutdown != SHUTDOWN_MASK)) {
                sk->sk_state = TCP_CLOSE;
                sk->sk_err = ETIMEDOUT;
-               sk->sk_error_report(sk);
+               sk_error_report(sk);
                vsock_transport_cancel_pkt(vsk);
        }
        release_sock(sk);
@@ -1263,8 +1289,8 @@ static void vsock_connect_timeout(struct work_struct *work)
        sock_put(sk);
 }
 
-static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
-                               int addr_len, int flags)
+static int vsock_connect(struct socket *sock, struct sockaddr *addr,
+                        int addr_len, int flags)
 {
        int err;
        struct sock *sk;
@@ -1414,7 +1440,7 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags,
 
        lock_sock(listener);
 
-       if (sock->type != SOCK_STREAM) {
+       if (!sock_type_connectible(sock->type)) {
                err = -EOPNOTSUPP;
                goto out;
        }
@@ -1491,7 +1517,7 @@ static int vsock_listen(struct socket *sock, int backlog)
 
        lock_sock(sk);
 
-       if (sock->type != SOCK_STREAM) {
+       if (!sock_type_connectible(sk->sk_type)) {
                err = -EOPNOTSUPP;
                goto out;
        }
@@ -1535,11 +1561,11 @@ static void vsock_update_buffer_size(struct vsock_sock *vsk,
        vsk->buffer_size = val;
 }
 
-static int vsock_stream_setsockopt(struct socket *sock,
-                                  int level,
-                                  int optname,
-                                  sockptr_t optval,
-                                  unsigned int optlen)
+static int vsock_connectible_setsockopt(struct socket *sock,
+                                       int level,
+                                       int optname,
+                                       sockptr_t optval,
+                                       unsigned int optlen)
 {
        int err;
        struct sock *sk;
@@ -1617,10 +1643,10 @@ exit:
        return err;
 }
 
-static int vsock_stream_getsockopt(struct socket *sock,
-                                  int level, int optname,
-                                  char __user *optval,
-                                  int __user *optlen)
+static int vsock_connectible_getsockopt(struct socket *sock,
+                                       int level, int optname,
+                                       char __user *optval,
+                                       int __user *optlen)
 {
        int err;
        int len;
@@ -1688,8 +1714,8 @@ static int vsock_stream_getsockopt(struct socket *sock,
        return 0;
 }
 
-static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
-                               size_t len)
+static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
+                                    size_t len)
 {
        struct sock *sk;
        struct vsock_sock *vsk;
@@ -1712,7 +1738,9 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 
        transport = vsk->transport;
 
-       /* Callers should not provide a destination with stream sockets. */
+       /* Callers should not provide a destination with connection oriented
+        * sockets.
+        */
        if (msg->msg_namelen) {
                err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP;
                goto out;
@@ -1803,9 +1831,13 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
                 * responsibility to check how many bytes we were able to send.
                 */
 
-               written = transport->stream_enqueue(
-                               vsk, msg,
-                               len - total_written);
+               if (sk->sk_type == SOCK_SEQPACKET) {
+                       written = transport->seqpacket_enqueue(vsk,
+                                               msg, len - total_written);
+               } else {
+                       written = transport->stream_enqueue(vsk,
+                                       msg, len - total_written);
+               }
                if (written < 0) {
                        err = -ENOMEM;
                        goto out_err;
@@ -1821,72 +1853,98 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
        }
 
 out_err:
-       if (total_written > 0)
-               err = total_written;
+       if (total_written > 0) {
+               /* Return number of written bytes only if:
+                * 1) SOCK_STREAM socket.
+                * 2) SOCK_SEQPACKET socket when whole buffer is sent.
+                */
+               if (sk->sk_type == SOCK_STREAM || total_written == len)
+                       err = total_written;
+       }
 out:
        release_sock(sk);
        return err;
 }
 
-
-static int
-vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
-                    int flags)
+static int vsock_connectible_wait_data(struct sock *sk,
+                                      struct wait_queue_entry *wait,
+                                      long timeout,
+                                      struct vsock_transport_recv_notify_data *recv_data,
+                                      size_t target)
 {
-       struct sock *sk;
-       struct vsock_sock *vsk;
        const struct vsock_transport *transport;
+       struct vsock_sock *vsk;
+       s64 data;
        int err;
-       size_t target;
-       ssize_t copied;
-       long timeout;
-       struct vsock_transport_recv_notify_data recv_data;
-
-       DEFINE_WAIT(wait);
 
-       sk = sock->sk;
        vsk = vsock_sk(sk);
        err = 0;
+       transport = vsk->transport;
 
-       lock_sock(sk);
+       while ((data = vsock_connectible_has_data(vsk)) == 0) {
+               prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE);
 
-       transport = vsk->transport;
+               if (sk->sk_err != 0 ||
+                   (sk->sk_shutdown & RCV_SHUTDOWN) ||
+                   (vsk->peer_shutdown & SEND_SHUTDOWN)) {
+                       break;
+               }
 
-       if (!transport || sk->sk_state != TCP_ESTABLISHED) {
-               /* Recvmsg is supposed to return 0 if a peer performs an
-                * orderly shutdown. Differentiate between that case and when a
-                * peer has not connected or a local shutdown occurred with the
-                * SOCK_DONE flag.
-                */
-               if (sock_flag(sk, SOCK_DONE))
-                       err = 0;
-               else
-                       err = -ENOTCONN;
+               /* Don't wait for non-blocking sockets. */
+               if (timeout == 0) {
+                       err = -EAGAIN;
+                       break;
+               }
 
-               goto out;
-       }
+               if (recv_data) {
+                       err = transport->notify_recv_pre_block(vsk, target, recv_data);
+                       if (err < 0)
+                               break;
+               }
 
-       if (flags & MSG_OOB) {
-               err = -EOPNOTSUPP;
-               goto out;
-       }
+               release_sock(sk);
+               timeout = schedule_timeout(timeout);
+               lock_sock(sk);
 
-       /* We don't check peer_shutdown flag here since peer may actually shut
-        * down, but there can be data in the queue that a local socket can
-        * receive.
-        */
-       if (sk->sk_shutdown & RCV_SHUTDOWN) {
-               err = 0;
-               goto out;
+               if (signal_pending(current)) {
+                       err = sock_intr_errno(timeout);
+                       break;
+               } else if (timeout == 0) {
+                       err = -EAGAIN;
+                       break;
+               }
        }
 
-       /* It is valid on Linux to pass in a zero-length receive buffer.  This
-        * is not an error.  We may as well bail out now.
+       finish_wait(sk_sleep(sk), wait);
+
+       if (err)
+               return err;
+
+       /* Internal transport error when checking for available
+        * data. XXX This should be changed to a connection
+        * reset in a later change.
         */
-       if (!len) {
-               err = 0;
-               goto out;
-       }
+       if (data < 0)
+               return -ENOMEM;
+
+       return data;
+}
+
+static int __vsock_stream_recvmsg(struct sock *sk, struct msghdr *msg,
+                                 size_t len, int flags)
+{
+       struct vsock_transport_recv_notify_data recv_data;
+       const struct vsock_transport *transport;
+       struct vsock_sock *vsk;
+       ssize_t copied;
+       size_t target;
+       long timeout;
+       int err;
+
+       DEFINE_WAIT(wait);
+
+       vsk = vsock_sk(sk);
+       transport = vsk->transport;
 
        /* We must not copy less than target bytes into the user's buffer
         * before returning successfully, so we wait for the consume queue to
@@ -1908,94 +1966,158 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 
 
        while (1) {
-               s64 ready;
+               ssize_t read;
 
-               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
-               ready = vsock_stream_has_data(vsk);
+               err = vsock_connectible_wait_data(sk, &wait, timeout,
+                                                 &recv_data, target);
+               if (err <= 0)
+                       break;
 
-               if (ready == 0) {
-                       if (sk->sk_err != 0 ||
-                           (sk->sk_shutdown & RCV_SHUTDOWN) ||
-                           (vsk->peer_shutdown & SEND_SHUTDOWN)) {
-                               finish_wait(sk_sleep(sk), &wait);
-                               break;
-                       }
-                       /* Don't wait for non-blocking sockets. */
-                       if (timeout == 0) {
-                               err = -EAGAIN;
-                               finish_wait(sk_sleep(sk), &wait);
-                               break;
-                       }
+               err = transport->notify_recv_pre_dequeue(vsk, target,
+                                                        &recv_data);
+               if (err < 0)
+                       break;
 
-                       err = transport->notify_recv_pre_block(
-                                       vsk, target, &recv_data);
-                       if (err < 0) {
-                               finish_wait(sk_sleep(sk), &wait);
-                               break;
-                       }
-                       release_sock(sk);
-                       timeout = schedule_timeout(timeout);
-                       lock_sock(sk);
+               read = transport->stream_dequeue(vsk, msg, len - copied, flags);
+               if (read < 0) {
+                       err = -ENOMEM;
+                       break;
+               }
 
-                       if (signal_pending(current)) {
-                               err = sock_intr_errno(timeout);
-                               finish_wait(sk_sleep(sk), &wait);
-                               break;
-                       } else if (timeout == 0) {
-                               err = -EAGAIN;
-                               finish_wait(sk_sleep(sk), &wait);
-                               break;
-                       }
-               } else {
-                       ssize_t read;
+               copied += read;
 
-                       finish_wait(sk_sleep(sk), &wait);
+               err = transport->notify_recv_post_dequeue(vsk, target, read,
+                                               !(flags & MSG_PEEK), &recv_data);
+               if (err < 0)
+                       goto out;
 
-                       if (ready < 0) {
-                               /* Invalid queue pair content. XXX This should
-                               * be changed to a connection reset in a later
-                               * change.
-                               */
+               if (read >= target || flags & MSG_PEEK)
+                       break;
 
-                               err = -ENOMEM;
-                               goto out;
-                       }
+               target -= read;
+       }
 
-                       err = transport->notify_recv_pre_dequeue(
-                                       vsk, target, &recv_data);
-                       if (err < 0)
-                               break;
+       if (sk->sk_err)
+               err = -sk->sk_err;
+       else if (sk->sk_shutdown & RCV_SHUTDOWN)
+               err = 0;
 
-                       read = transport->stream_dequeue(
-                                       vsk, msg,
-                                       len - copied, flags);
-                       if (read < 0) {
-                               err = -ENOMEM;
-                               break;
-                       }
+       if (copied > 0)
+               err = copied;
+
+out:
+       return err;
+}
 
-                       copied += read;
+static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
+                                    size_t len, int flags)
+{
+       const struct vsock_transport *transport;
+       struct vsock_sock *vsk;
+       ssize_t record_len;
+       long timeout;
+       int err = 0;
+       DEFINE_WAIT(wait);
 
-                       err = transport->notify_recv_post_dequeue(
-                                       vsk, target, read,
-                                       !(flags & MSG_PEEK), &recv_data);
-                       if (err < 0)
-                               goto out;
+       vsk = vsock_sk(sk);
+       transport = vsk->transport;
 
-                       if (read >= target || flags & MSG_PEEK)
-                               break;
+       timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 
-                       target -= read;
-               }
+       err = vsock_connectible_wait_data(sk, &wait, timeout, NULL, 0);
+       if (err <= 0)
+               goto out;
+
+       record_len = transport->seqpacket_dequeue(vsk, msg, flags);
+
+       if (record_len < 0) {
+               err = -ENOMEM;
+               goto out;
        }
 
-       if (sk->sk_err)
+       if (sk->sk_err) {
                err = -sk->sk_err;
-       else if (sk->sk_shutdown & RCV_SHUTDOWN)
+       } else if (sk->sk_shutdown & RCV_SHUTDOWN) {
                err = 0;
+       } else {
+               /* User sets MSG_TRUNC, so return real length of
+                * packet.
+                */
+               if (flags & MSG_TRUNC)
+                       err = record_len;
+               else
+                       err = len - msg_data_left(msg);
 
-       if (copied > 0)
-               err = copied;
+               /* Always set MSG_TRUNC if real length of packet is
+                * bigger than user's buffer.
+                */
+               if (record_len > len)
+                       msg->msg_flags |= MSG_TRUNC;
+       }
+
+out:
+       return err;
+}
+
+static int
+vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
+                         int flags)
+{
+       struct sock *sk;
+       struct vsock_sock *vsk;
+       const struct vsock_transport *transport;
+       int err;
+
+       DEFINE_WAIT(wait);
+
+       sk = sock->sk;
+       vsk = vsock_sk(sk);
+       err = 0;
+
+       lock_sock(sk);
+
+       transport = vsk->transport;
+
+       if (!transport || sk->sk_state != TCP_ESTABLISHED) {
+               /* Recvmsg is supposed to return 0 if a peer performs an
+                * orderly shutdown. Differentiate between that case and when a
+                * peer has not connected or a local shutdown occurred with the
+                * SOCK_DONE flag.
+                */
+               if (sock_flag(sk, SOCK_DONE))
+                       err = 0;
+               else
+                       err = -ENOTCONN;
+
+               goto out;
+       }
+
+       if (flags & MSG_OOB) {
+               err = -EOPNOTSUPP;
+               goto out;
+       }
+
+       /* We don't check peer_shutdown flag here since peer may actually shut
+        * down, but there can be data in the queue that a local socket can
+        * receive.
+        */
+       if (sk->sk_shutdown & RCV_SHUTDOWN) {
+               err = 0;
+               goto out;
+       }
+
+       /* It is valid on Linux to pass in a zero-length receive buffer.  This
+        * is not an error.  We may as well bail out now.
+        */
+       if (!len) {
+               err = 0;
+               goto out;
+       }
+
+       if (sk->sk_type == SOCK_STREAM)
+               err = __vsock_stream_recvmsg(sk, msg, len, flags);
+       else
+               err = __vsock_seqpacket_recvmsg(sk, msg, len, flags);
 
 out:
        release_sock(sk);
@@ -2007,7 +2129,7 @@ static const struct proto_ops vsock_stream_ops = {
        .owner = THIS_MODULE,
        .release = vsock_release,
        .bind = vsock_bind,
-       .connect = vsock_stream_connect,
+       .connect = vsock_connect,
        .socketpair = sock_no_socketpair,
        .accept = vsock_accept,
        .getname = vsock_getname,
@@ -2015,10 +2137,31 @@ static const struct proto_ops vsock_stream_ops = {
        .ioctl = sock_no_ioctl,
        .listen = vsock_listen,
        .shutdown = vsock_shutdown,
-       .setsockopt = vsock_stream_setsockopt,
-       .getsockopt = vsock_stream_getsockopt,
-       .sendmsg = vsock_stream_sendmsg,
-       .recvmsg = vsock_stream_recvmsg,
+       .setsockopt = vsock_connectible_setsockopt,
+       .getsockopt = vsock_connectible_getsockopt,
+       .sendmsg = vsock_connectible_sendmsg,
+       .recvmsg = vsock_connectible_recvmsg,
+       .mmap = sock_no_mmap,
+       .sendpage = sock_no_sendpage,
+};
+
+static const struct proto_ops vsock_seqpacket_ops = {
+       .family = PF_VSOCK,
+       .owner = THIS_MODULE,
+       .release = vsock_release,
+       .bind = vsock_bind,
+       .connect = vsock_connect,
+       .socketpair = sock_no_socketpair,
+       .accept = vsock_accept,
+       .getname = vsock_getname,
+       .poll = vsock_poll,
+       .ioctl = sock_no_ioctl,
+       .listen = vsock_listen,
+       .shutdown = vsock_shutdown,
+       .setsockopt = vsock_connectible_setsockopt,
+       .getsockopt = vsock_connectible_getsockopt,
+       .sendmsg = vsock_connectible_sendmsg,
+       .recvmsg = vsock_connectible_recvmsg,
        .mmap = sock_no_mmap,
        .sendpage = sock_no_sendpage,
 };
@@ -2043,6 +2186,9 @@ static int vsock_create(struct net *net, struct socket *sock,
        case SOCK_STREAM:
                sock->ops = &vsock_stream_ops;
                break;
+       case SOCK_SEQPACKET:
+               sock->ops = &vsock_seqpacket_ops;
+               break;
        default:
                return -ESOCKTNOSUPPORT;
        }