vsock: each transport cycles only on its own sockets
authorJiyong Park <jiyong@google.com>
Fri, 11 Mar 2022 02:00:16 +0000 (11:00 +0900)
committerJakub Kicinski <kuba@kernel.org>
Sat, 12 Mar 2022 07:14:19 +0000 (23:14 -0800)
When iterating over sockets using vsock_for_each_connected_socket, make
sure that a transport filters out sockets that don't belong to the
transport.

There actually was an issue caused by this; in a nested VM
configuration, destroying the nested VM (which often involves the
closing of /dev/vhost-vsock if there was h2g connections to the nested
VM) kills not only the h2g connections, but also all existing g2h
connections to the (outmost) host which are totally unrelated.

Tested: Executed the following steps on Cuttlefish (Android running on a
VM) [1]: (1) Enter into an `adb shell` session - to have a g2h
connection inside the VM, (2) open and then close /dev/vhost-vsock by
`exec 3< /dev/vhost-vsock && exec 3<&-`, (3) observe that the adb
session is not reset.

[1] https://android.googlesource.com/device/google/cuttlefish/

Fixes: c0cfa2d8a788 ("vsock: add multi-transports support")
Reviewed-by: Stefano Garzarella <sgarzare@redhat.com>
Acked-by: Michael S. Tsirkin <mst@redhat.com>
Signed-off-by: Jiyong Park <jiyong@google.com>
Link: https://lore.kernel.org/r/20220311020017.1509316-1-jiyong@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
drivers/vhost/vsock.c
include/net/af_vsock.h
net/vmw_vsock/af_vsock.c
net/vmw_vsock/virtio_transport.c
net/vmw_vsock/vmci_transport.c

index 37f0b4274113c6dd8aca659a1e102fa294b48560..e6c9d41db1de77449240fcaa03941377b18a5087 100644 (file)
@@ -753,7 +753,8 @@ static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
 
        /* Iterating over all connections for all CIDs to find orphans is
         * inefficient.  Room for improvement here. */
-       vsock_for_each_connected_socket(vhost_vsock_reset_orphans);
+       vsock_for_each_connected_socket(&vhost_transport.transport,
+                                       vhost_vsock_reset_orphans);
 
        /* Don't check the owner, because we are in the release path, so we
         * need to stop the vsock device in any case.
index ab207677e0a8bd68a8dbd31e57ee8838c396101c..f742e50207fbde4516b5743502d7875573a02fd4 100644 (file)
@@ -205,7 +205,8 @@ struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
 struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
                                         struct sockaddr_vm *dst);
 void vsock_remove_sock(struct vsock_sock *vsk);
-void vsock_for_each_connected_socket(void (*fn)(struct sock *sk));
+void vsock_for_each_connected_socket(struct vsock_transport *transport,
+                                    void (*fn)(struct sock *sk));
 int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk);
 bool vsock_find_cid(unsigned int cid);
 
index 38baeb189d4e657b1c05fb16f8af620feff9dbe6..f04abf662ec6cb9d3b693f9dd3023a8fb4617c5d 100644 (file)
@@ -334,7 +334,8 @@ void vsock_remove_sock(struct vsock_sock *vsk)
 }
 EXPORT_SYMBOL_GPL(vsock_remove_sock);
 
-void vsock_for_each_connected_socket(void (*fn)(struct sock *sk))
+void vsock_for_each_connected_socket(struct vsock_transport *transport,
+                                    void (*fn)(struct sock *sk))
 {
        int i;
 
@@ -343,8 +344,12 @@ void vsock_for_each_connected_socket(void (*fn)(struct sock *sk))
        for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) {
                struct vsock_sock *vsk;
                list_for_each_entry(vsk, &vsock_connected_table[i],
-                                   connected_table)
+                                   connected_table) {
+                       if (vsk->transport != transport)
+                               continue;
+
                        fn(sk_vsock(vsk));
+               }
        }
 
        spin_unlock_bh(&vsock_table_lock);
index fb3302fff62795b1c3723ed0b7982772d7f38d55..5afc194a58bbdf1ac2030614da60434351b6f1ff 100644 (file)
@@ -24,6 +24,7 @@
 static struct workqueue_struct *virtio_vsock_workqueue;
 static struct virtio_vsock __rcu *the_virtio_vsock;
 static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */
+static struct virtio_transport virtio_transport; /* forward declaration */
 
 struct virtio_vsock {
        struct virtio_device *vdev;
@@ -384,7 +385,8 @@ static void virtio_vsock_event_handle(struct virtio_vsock *vsock,
        switch (le32_to_cpu(event->id)) {
        case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET:
                virtio_vsock_update_guest_cid(vsock);
-               vsock_for_each_connected_socket(virtio_vsock_reset_sock);
+               vsock_for_each_connected_socket(&virtio_transport.transport,
+                                               virtio_vsock_reset_sock);
                break;
        }
 }
@@ -662,7 +664,8 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
        synchronize_rcu();
 
        /* Reset all connected sockets when the device disappear */
-       vsock_for_each_connected_socket(virtio_vsock_reset_sock);
+       vsock_for_each_connected_socket(&virtio_transport.transport,
+                                       virtio_vsock_reset_sock);
 
        /* Stop all work handlers to make sure no one is accessing the device,
         * so we can safely call virtio_reset_device().
index 7aef34e32bdf8fc9bfb3344002d06e356726c1ca..b17dc9745188e4f1840881209017a23f00c65af9 100644 (file)
@@ -75,6 +75,8 @@ static u32 vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
 
 static int PROTOCOL_OVERRIDE = -1;
 
+static struct vsock_transport vmci_transport; /* forward declaration */
+
 /* Helper function to convert from a VMCI error code to a VSock error code. */
 
 static s32 vmci_transport_error_to_vsock_error(s32 vmci_error)
@@ -882,7 +884,8 @@ static void vmci_transport_qp_resumed_cb(u32 sub_id,
                                         const struct vmci_event_data *e_data,
                                         void *client_data)
 {
-       vsock_for_each_connected_socket(vmci_transport_handle_detach);
+       vsock_for_each_connected_socket(&vmci_transport,
+                                       vmci_transport_handle_detach);
 }
 
 static void vmci_transport_recv_pkt_work(struct work_struct *work)