RDMA/cma: Fix use after free race in roce multicast join
authorJason Gunthorpe <jgg@nvidia.com>
Wed, 2 Sep 2020 08:11:22 +0000 (11:11 +0300)
committerJason Gunthorpe <jgg@nvidia.com>
Thu, 17 Sep 2020 12:09:25 +0000 (09:09 -0300)
The roce path triggers a work queue that continues to touch the id_priv
but doesn't hold any reference on it. Futher, unlike in the IB case, the
work queue is not fenced during rdma_destroy_id().

This can trigger a use after free if a destroy is triggered in the
incredibly narrow window after the queue_work and the work starting and
obtaining the handler_mutex.

The only purpose of this work queue is to run the ULP event callback from
the standard context, so switch the design to use the existing
cma_work_handler() scheme. This simplifies quite a lot of the flow:

- Use the cma_work_handler() callback to launch the work for roce. This
  requires generating the event synchronously inside the
  rdma_join_multicast(), which in turn means the dummy struct
  ib_sa_multicast can become a simple stack variable.

- cm_work_handler() used the id_priv kref, so we can entirely eliminate
  the kref inside struct cma_multicast. Since the cma_multicast never
  leaks into an unprotected work queue the kfree can be done at the same
  time as for IB.

- Eliminating the general multicast.ib requires using cma_set_mgid() in a
  few places to recompute the mgid.

Fixes: 3c86aa70bf67 ("RDMA/cm: Add RDMA CM support for IBoE devices")
Link: https://lore.kernel.org/r/20200902081122.745412-9-leon@kernel.org
Signed-off-by: Leon Romanovsky <leonro@nvidia.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
drivers/infiniband/core/cma.c

index 906717c..22bd892 100644 (file)
@@ -68,6 +68,9 @@ static const char * const cma_events[] = {
        [RDMA_CM_EVENT_TIMEWAIT_EXIT]    = "timewait exit",
 };
 
+static void cma_set_mgid(struct rdma_id_private *id_priv, struct sockaddr *addr,
+                        union ib_gid *mgid);
+
 const char *__attribute_const__ rdma_event_msg(enum rdma_cm_event_type event)
 {
        size_t index = event;
@@ -345,13 +348,10 @@ struct ib_device *cma_get_ib_dev(struct cma_device *cma_dev)
 
 struct cma_multicast {
        struct rdma_id_private *id_priv;
-       union {
-               struct ib_sa_multicast *ib;
-       } multicast;
+       struct ib_sa_multicast *sa_mc;
        struct list_head        list;
        void                    *context;
        struct sockaddr_storage addr;
-       struct kref             mcref;
        u8                      join_state;
 };
 
@@ -363,12 +363,6 @@ struct cma_work {
        struct rdma_cm_event    event;
 };
 
-struct iboe_mcast_work {
-       struct work_struct       work;
-       struct rdma_id_private  *id;
-       struct cma_multicast    *mc;
-};
-
 union cma_ip_addr {
        struct in6_addr ip6;
        struct {
@@ -475,14 +469,6 @@ static void cma_attach_to_dev(struct rdma_id_private *id_priv,
                                          rdma_start_port(cma_dev->device)];
 }
 
-static inline void release_mc(struct kref *kref)
-{
-       struct cma_multicast *mc = container_of(kref, struct cma_multicast, mcref);
-
-       kfree(mc->multicast.ib);
-       kfree(mc);
-}
-
 static void cma_release_dev(struct rdma_id_private *id_priv)
 {
        mutex_lock(&lock);
@@ -1778,14 +1764,10 @@ static void cma_release_port(struct rdma_id_private *id_priv)
 static void destroy_mc(struct rdma_id_private *id_priv,
                       struct cma_multicast *mc)
 {
-       if (rdma_cap_ib_mcast(id_priv->id.device, id_priv->id.port_num)) {
-               ib_sa_free_multicast(mc->multicast.ib);
-               kfree(mc);
-               return;
-       }
+       if (rdma_cap_ib_mcast(id_priv->id.device, id_priv->id.port_num))
+               ib_sa_free_multicast(mc->sa_mc);
 
-       if (rdma_protocol_roce(id_priv->id.device,
-                                     id_priv->id.port_num)) {
+       if (rdma_protocol_roce(id_priv->id.device, id_priv->id.port_num)) {
                struct rdma_dev_addr *dev_addr =
                        &id_priv->id.route.addr.dev_addr;
                struct net_device *ndev = NULL;
@@ -1794,11 +1776,15 @@ static void destroy_mc(struct rdma_id_private *id_priv,
                        ndev = dev_get_by_index(dev_addr->net,
                                                dev_addr->bound_dev_if);
                if (ndev) {
-                       cma_igmp_send(ndev, &mc->multicast.ib->rec.mgid, false);
+                       union ib_gid mgid;
+
+                       cma_set_mgid(id_priv, (struct sockaddr *)&mc->addr,
+                                    &mgid);
+                       cma_igmp_send(ndev, &mgid, false);
                        dev_put(ndev);
                }
-               kref_put(&mc->mcref, release_mc);
        }
+       kfree(mc);
 }
 
 static void cma_leave_mc_groups(struct rdma_id_private *id_priv)
@@ -2664,6 +2650,8 @@ out_unlock:
        mutex_unlock(&id_priv->handler_mutex);
        cma_id_put(id_priv);
 out_free:
+       if (work->event.event == RDMA_CM_EVENT_MULTICAST_JOIN)
+               rdma_destroy_ah_attr(&work->event.param.ud.ah_attr);
        kfree(work);
 }
 
@@ -4324,53 +4312,66 @@ out:
 }
 EXPORT_SYMBOL(rdma_disconnect);
 
+static void cma_make_mc_event(int status, struct rdma_id_private *id_priv,
+                             struct ib_sa_multicast *multicast,
+                             struct rdma_cm_event *event,
+                             struct cma_multicast *mc)
+{
+       struct rdma_dev_addr *dev_addr;
+       enum ib_gid_type gid_type;
+       struct net_device *ndev;
+
+       if (!status)
+               status = cma_set_qkey(id_priv, be32_to_cpu(multicast->rec.qkey));
+       else
+               pr_debug_ratelimited("RDMA CM: MULTICAST_ERROR: failed to join multicast. status %d\n",
+                                    status);
+
+       event->status = status;
+       event->param.ud.private_data = mc->context;
+       if (status) {
+               event->event = RDMA_CM_EVENT_MULTICAST_ERROR;
+               return;
+       }
+
+       dev_addr = &id_priv->id.route.addr.dev_addr;
+       ndev = dev_get_by_index(dev_addr->net, dev_addr->bound_dev_if);
+       gid_type =
+               id_priv->cma_dev
+                       ->default_gid_type[id_priv->id.port_num -
+                                          rdma_start_port(
+                                                  id_priv->cma_dev->device)];
+
+       event->event = RDMA_CM_EVENT_MULTICAST_JOIN;
+       if (ib_init_ah_from_mcmember(id_priv->id.device, id_priv->id.port_num,
+                                    &multicast->rec, ndev, gid_type,
+                                    &event->param.ud.ah_attr)) {
+               event->event = RDMA_CM_EVENT_MULTICAST_ERROR;
+               goto out;
+       }
+
+       event->param.ud.qp_num = 0xFFFFFF;
+       event->param.ud.qkey = be32_to_cpu(multicast->rec.qkey);
+
+out:
+       if (ndev)
+               dev_put(ndev);
+}
+
 static int cma_ib_mc_handler(int status, struct ib_sa_multicast *multicast)
 {
-       struct rdma_id_private *id_priv;
        struct cma_multicast *mc = multicast->context;
+       struct rdma_id_private *id_priv = mc->id_priv;
        struct rdma_cm_event event = {};
        int ret = 0;
 
-       id_priv = mc->id_priv;
        mutex_lock(&id_priv->handler_mutex);
        if (READ_ONCE(id_priv->state) == RDMA_CM_DEVICE_REMOVAL ||
            READ_ONCE(id_priv->state) == RDMA_CM_DESTROYING)
                goto out;
 
-       if (!status)
-               status = cma_set_qkey(id_priv, be32_to_cpu(multicast->rec.qkey));
-       else
-               pr_debug_ratelimited("RDMA CM: MULTICAST_ERROR: failed to join multicast. status %d\n",
-                                    status);
-       event.status = status;
-       event.param.ud.private_data = mc->context;
-       if (!status) {
-               struct rdma_dev_addr *dev_addr =
-                       &id_priv->id.route.addr.dev_addr;
-               struct net_device *ndev =
-                       dev_get_by_index(dev_addr->net, dev_addr->bound_dev_if);
-               enum ib_gid_type gid_type =
-                       id_priv->cma_dev->default_gid_type[id_priv->id.port_num -
-                       rdma_start_port(id_priv->cma_dev->device)];
-
-               event.event = RDMA_CM_EVENT_MULTICAST_JOIN;
-               ret = ib_init_ah_from_mcmember(id_priv->id.device,
-                                              id_priv->id.port_num,
-                                              &multicast->rec,
-                                              ndev, gid_type,
-                                              &event.param.ud.ah_attr);
-               if (ret)
-                       event.event = RDMA_CM_EVENT_MULTICAST_ERROR;
-
-               event.param.ud.qp_num = 0xFFFFFF;
-               event.param.ud.qkey = be32_to_cpu(multicast->rec.qkey);
-               if (ndev)
-                       dev_put(ndev);
-       } else
-               event.event = RDMA_CM_EVENT_MULTICAST_ERROR;
-
+       cma_make_mc_event(status, id_priv, multicast, &event, mc);
        ret = cma_cm_event_handler(id_priv, &event);
-
        rdma_destroy_ah_attr(&event.param.ud.ah_attr);
        if (ret) {
                destroy_id_handler_unlock(id_priv);
@@ -4460,23 +4461,10 @@ static int cma_join_ib_multicast(struct rdma_id_private *id_priv,
                             IB_SA_MCMEMBER_REC_MTU |
                             IB_SA_MCMEMBER_REC_HOP_LIMIT;
 
-       mc->multicast.ib = ib_sa_join_multicast(&sa_client, id_priv->id.device,
-                                               id_priv->id.port_num, &rec,
-                                               comp_mask, GFP_KERNEL,
-                                               cma_ib_mc_handler, mc);
-       return PTR_ERR_OR_ZERO(mc->multicast.ib);
-}
-
-static void iboe_mcast_work_handler(struct work_struct *work)
-{
-       struct iboe_mcast_work *mw = container_of(work, struct iboe_mcast_work, work);
-       struct cma_multicast *mc = mw->mc;
-       struct ib_sa_multicast *m = mc->multicast.ib;
-
-       mc->multicast.ib->context = mc;
-       cma_ib_mc_handler(0, m);
-       kref_put(&mc->mcref, release_mc);
-       kfree(mw);
+       mc->sa_mc = ib_sa_join_multicast(&sa_client, id_priv->id.device,
+                                        id_priv->id.port_num, &rec, comp_mask,
+                                        GFP_KERNEL, cma_ib_mc_handler, mc);
+       return PTR_ERR_OR_ZERO(mc->sa_mc);
 }
 
 static void cma_iboe_set_mgid(struct sockaddr *addr, union ib_gid *mgid,
@@ -4511,52 +4499,47 @@ static void cma_iboe_set_mgid(struct sockaddr *addr, union ib_gid *mgid,
 static int cma_iboe_join_multicast(struct rdma_id_private *id_priv,
                                   struct cma_multicast *mc)
 {
-       struct iboe_mcast_work *work;
+       struct cma_work *work;
        struct rdma_dev_addr *dev_addr = &id_priv->id.route.addr.dev_addr;
        int err = 0;
        struct sockaddr *addr = (struct sockaddr *)&mc->addr;
        struct net_device *ndev = NULL;
+       struct ib_sa_multicast ib;
        enum ib_gid_type gid_type;
        bool send_only;
 
        send_only = mc->join_state == BIT(SENDONLY_FULLMEMBER_JOIN);
 
-       if (cma_zero_addr((struct sockaddr *)&mc->addr))
+       if (cma_zero_addr(addr))
                return -EINVAL;
 
        work = kzalloc(sizeof *work, GFP_KERNEL);
        if (!work)
                return -ENOMEM;
 
-       mc->multicast.ib = kzalloc(sizeof(struct ib_sa_multicast), GFP_KERNEL);
-       if (!mc->multicast.ib) {
-               err = -ENOMEM;
-               goto out1;
-       }
-
        gid_type = id_priv->cma_dev->default_gid_type[id_priv->id.port_num -
                   rdma_start_port(id_priv->cma_dev->device)];
-       cma_iboe_set_mgid(addr, &mc->multicast.ib->rec.mgid, gid_type);
+       cma_iboe_set_mgid(addr, &ib.rec.mgid, gid_type);
 
-       mc->multicast.ib->rec.pkey = cpu_to_be16(0xffff);
+       ib.rec.pkey = cpu_to_be16(0xffff);
        if (id_priv->id.ps == RDMA_PS_UDP)
-               mc->multicast.ib->rec.qkey = cpu_to_be32(RDMA_UDP_QKEY);
+               ib.rec.qkey = cpu_to_be32(RDMA_UDP_QKEY);
 
        if (dev_addr->bound_dev_if)
                ndev = dev_get_by_index(dev_addr->net, dev_addr->bound_dev_if);
        if (!ndev) {
                err = -ENODEV;
-               goto out2;
+               goto err_free;
        }
-       mc->multicast.ib->rec.rate = iboe_get_rate(ndev);
-       mc->multicast.ib->rec.hop_limit = 1;
-       mc->multicast.ib->rec.mtu = iboe_get_mtu(ndev->mtu);
+       ib.rec.rate = iboe_get_rate(ndev);
+       ib.rec.hop_limit = 1;
+       ib.rec.mtu = iboe_get_mtu(ndev->mtu);
 
        if (addr->sa_family == AF_INET) {
                if (gid_type == IB_GID_TYPE_ROCE_UDP_ENCAP) {
-                       mc->multicast.ib->rec.hop_limit = IPV6_DEFAULT_HOPLIMIT;
+                       ib.rec.hop_limit = IPV6_DEFAULT_HOPLIMIT;
                        if (!send_only) {
-                               err = cma_igmp_send(ndev, &mc->multicast.ib->rec.mgid,
+                               err = cma_igmp_send(ndev, &ib.rec.mgid,
                                                    true);
                        }
                }
@@ -4565,24 +4548,22 @@ static int cma_iboe_join_multicast(struct rdma_id_private *id_priv,
                        err = -ENOTSUPP;
        }
        dev_put(ndev);
-       if (err || !mc->multicast.ib->rec.mtu) {
+       if (err || !ib.rec.mtu) {
                if (!err)
                        err = -EINVAL;
-               goto out2;
+               goto err_free;
        }
        rdma_ip2gid((struct sockaddr *)&id_priv->id.route.addr.src_addr,
-                   &mc->multicast.ib->rec.port_gid);
+                   &ib.rec.port_gid);
        work->id = id_priv;
-       work->mc = mc;
-       INIT_WORK(&work->work, iboe_mcast_work_handler);
-       kref_get(&mc->mcref);
+       INIT_WORK(&work->work, cma_work_handler);
+       cma_make_mc_event(0, id_priv, &ib, &work->event, mc);
+       /* Balances with cma_id_put() in cma_work_handler */
+       cma_id_get(id_priv);
        queue_work(cma_wq, &work->work);
-
        return 0;
 
-out2:
-       kfree(mc->multicast.ib);
-out1:
+err_free:
        kfree(work);
        return err;
 }
@@ -4604,7 +4585,7 @@ int rdma_join_multicast(struct rdma_cm_id *id, struct sockaddr *addr,
                            READ_ONCE(id_priv->state) != RDMA_CM_ADDR_RESOLVED))
                return -EINVAL;
 
-       mc = kmalloc(sizeof *mc, GFP_KERNEL);
+       mc = kzalloc(sizeof(*mc), GFP_KERNEL);
        if (!mc)
                return -ENOMEM;
 
@@ -4614,7 +4595,6 @@ int rdma_join_multicast(struct rdma_cm_id *id, struct sockaddr *addr,
        mc->join_state = join_state;
 
        if (rdma_protocol_roce(id->device, id->port_num)) {
-               kref_init(&mc->mcref);
                ret = cma_iboe_join_multicast(id_priv, mc);
                if (ret)
                        goto out_err;