IB/cm: Remove cma_multicast->igmp_joined
authorJason Gunthorpe <jgg@mellanox.com>
Wed, 11 Jul 2018 08:20:29 +0000 (11:20 +0300)
committerJason Gunthorpe <jgg@mellanox.com>
Fri, 13 Jul 2018 18:18:55 +0000 (12:18 -0600)
This variable isn't read and written to with proper locking, so it is
racy. Instead of using an unlocked bool use presence in the mc->list

The caller could race rdma_join_multicast with rdma_leave_multicast which
would leak a mc join and cause a use after free of mc.

Instead, do not add the mc to the list until it has completed
initialization, all mcs on the list require leaving.

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
drivers/infiniband/core/cma.c

index a735ab4..f2bf997 100644 (file)
@@ -366,7 +366,6 @@ struct cma_multicast {
        void                    *context;
        struct sockaddr_storage addr;
        struct kref             mcref;
-       bool                    igmp_joined;
        u8                      join_state;
 };
 
@@ -1643,21 +1642,14 @@ static void cma_release_port(struct rdma_id_private *id_priv)
 static void cma_leave_roce_mc_group(struct rdma_id_private *id_priv,
                                    struct cma_multicast *mc)
 {
-       if (mc->igmp_joined) {
-               struct rdma_dev_addr *dev_addr =
-                       &id_priv->id.route.addr.dev_addr;
-               struct net_device *ndev = NULL;
-
-               if (dev_addr->bound_dev_if)
-                       ndev = dev_get_by_index(dev_addr->net,
-                                               dev_addr->bound_dev_if);
-               if (ndev) {
-                       cma_igmp_send(ndev,
-                                     &mc->multicast.ib->rec.mgid,
-                                     false);
-                       dev_put(ndev);
-               }
-               mc->igmp_joined = false;
+       struct rdma_dev_addr *dev_addr = &id_priv->id.route.addr.dev_addr;
+       struct net_device *ndev = NULL;
+
+       if (dev_addr->bound_dev_if)
+               ndev = dev_get_by_index(dev_addr->net, dev_addr->bound_dev_if);
+       if (ndev) {
+               cma_igmp_send(ndev, &mc->multicast.ib->rec.mgid, false);
+               dev_put(ndev);
        }
        kref_put(&mc->mcref, release_mc);
 }
@@ -4196,8 +4188,6 @@ static int cma_iboe_join_multicast(struct rdma_id_private *id_priv,
                        if (!send_only) {
                                err = cma_igmp_send(ndev, &mc->multicast.ib->rec.mgid,
                                                    true);
-                               if (!err)
-                                       mc->igmp_joined = true;
                        }
                }
        } else {
@@ -4249,26 +4239,29 @@ int rdma_join_multicast(struct rdma_cm_id *id, struct sockaddr *addr,
        memcpy(&mc->addr, addr, rdma_addr_size(addr));
        mc->context = context;
        mc->id_priv = id_priv;
-       mc->igmp_joined = false;
        mc->join_state = join_state;
-       spin_lock(&id_priv->lock);
-       list_add(&mc->list, &id_priv->mc_list);
-       spin_unlock(&id_priv->lock);
 
        if (rdma_protocol_roce(id->device, id->port_num)) {
                kref_init(&mc->mcref);
                ret = cma_iboe_join_multicast(id_priv, mc);
-       } else if (rdma_cap_ib_mcast(id->device, id->port_num))
+               if (ret)
+                       goto out_err;
+       } else if (rdma_cap_ib_mcast(id->device, id->port_num)) {
                ret = cma_join_ib_multicast(id_priv, mc);
-       else
+               if (ret)
+                       goto out_err;
+       } else {
                ret = -ENOSYS;
-
-       if (ret) {
-               spin_lock_irq(&id_priv->lock);
-               list_del(&mc->list);
-               spin_unlock_irq(&id_priv->lock);
-               kfree(mc);
+               goto out_err;
        }
+
+       spin_lock(&id_priv->lock);
+       list_add(&mc->list, &id_priv->mc_list);
+       spin_unlock(&id_priv->lock);
+
+       return 0;
+out_err:
+       kfree(mc);
        return ret;
 }
 EXPORT_SYMBOL(rdma_join_multicast);