RDMA/cm: Allow ib_send_cm_rej() to be done under lock
author Jason Gunthorpe <jgg@mellanox.com>
Tue, 10 Mar 2020 09:25:43 +0000 (11:25 +0200)
committer Jason Gunthorpe <jgg@mellanox.com>
Tue, 17 Mar 2020 20:05:54 +0000 (17:05 -0300)
The first thing ib_send_cm_rej() does is obtain the lock, so use the usual
unlocked-wrapper/locked-actor pattern here.

This avoids a sketchy lock/unlock sequence (which could allow state to
change) during cm_destroy_id().

While here, simplify some of the logic in the implementation.

Link: https://lore.kernel.org/r/20200310092545.251365-14-leon@kernel.org
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
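
For reference, a minimal sketch of the unlocked-wrapper/locked-actor shape this
patch applies (the foo_* names below are illustrative only; the real pair in
cm.c is ib_send_cm_rej() and cm_send_rej_locked()). Internal callers that
already hold the lock, such as cm_destroy_id(), call the locked actor directly,
so the state they checked cannot change before the REJ is built and sent.

/*
 * Illustrative sketch, not part of the patch: an "unlocked wrapper" public
 * entry point that only takes the lock, and a "locked actor" that does the
 * work and documents its locking contract with lockdep_assert_held().
 */
#include <linux/spinlock.h>
#include <linux/lockdep.h>
#include <linux/errno.h>

struct foo_priv {
	spinlock_t lock;
	int state;		/* protected by ->lock */
};

/* Locked actor: caller must already hold priv->lock. */
static int foo_do_thing_locked(struct foo_priv *priv, int new_state)
{
	lockdep_assert_held(&priv->lock);

	if (priv->state < 0)
		return -EINVAL;

	/* State is examined and updated without ever dropping the lock. */
	priv->state = new_state;
	return 0;
}

/* Unlocked wrapper: the exported entry point just takes the lock. */
int foo_do_thing(struct foo_priv *priv, int new_state)
{
	unsigned long flags;
	int ret;

	spin_lock_irqsave(&priv->lock, flags);
	ret = foo_do_thing_locked(priv, new_state);
	spin_unlock_irqrestore(&priv->lock, flags);
	return ret;
}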
drivers/infiniband/core/cm.c

index 8ed2534..651e7c3 100644
--- a/drivers/infiniband/core/cm.c
+++ b/drivers/infiniband/core/cm.c
@@ -87,6 +87,10 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv,
                               const void *private_data, u8 private_data_len);
 static int cm_send_drep_locked(struct cm_id_private *cm_id_priv,
                               void *private_data, u8 private_data_len);
+static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
+                             enum ib_cm_rej_reason reason, void *ari,
+                             u8 ari_length, const void *private_data,
+                             u8 private_data_len);
 
 static struct ib_client cm_client = {
        .name   = "cm",
@@ -1060,11 +1064,11 @@ retest:
        case IB_CM_REQ_SENT:
        case IB_CM_MRA_REQ_RCVD:
                ib_cancel_mad(cm_id_priv->av.port->mad_agent, cm_id_priv->msg);
+               cm_send_rej_locked(cm_id_priv, IB_CM_REJ_TIMEOUT,
+                                  &cm_id_priv->id.device->node_guid,
+                                  sizeof(cm_id_priv->id.device->node_guid),
+                                  NULL, 0);
                spin_unlock_irq(&cm_id_priv->lock);
-               ib_send_cm_rej(cm_id, IB_CM_REJ_TIMEOUT,
-                              &cm_id_priv->id.device->node_guid,
-                              sizeof cm_id_priv->id.device->node_guid,
-                              NULL, 0);
                break;
        case IB_CM_REQ_RCVD:
                if (err == -ENOMEM) {
@@ -1072,9 +1076,10 @@ retest:
                        cm_reset_to_idle(cm_id_priv);
                        spin_unlock_irq(&cm_id_priv->lock);
                } else {
+                       cm_send_rej_locked(cm_id_priv,
+                                          IB_CM_REJ_CONSUMER_DEFINED, NULL, 0,
+                                          NULL, 0);
                        spin_unlock_irq(&cm_id_priv->lock);
-                       ib_send_cm_rej(cm_id, IB_CM_REJ_CONSUMER_DEFINED,
-                                      NULL, 0, NULL, 0);
                }
                break;
        case IB_CM_REP_SENT:
@@ -1084,9 +1089,9 @@ retest:
        case IB_CM_MRA_REQ_SENT:
        case IB_CM_REP_RCVD:
        case IB_CM_MRA_REP_SENT:
+               cm_send_rej_locked(cm_id_priv, IB_CM_REJ_CONSUMER_DEFINED, NULL,
+                                  0, NULL, 0);
                spin_unlock_irq(&cm_id_priv->lock);
-               ib_send_cm_rej(cm_id, IB_CM_REJ_CONSUMER_DEFINED,
-                              NULL, 0, NULL, 0);
                break;
        case IB_CM_ESTABLISHED:
                if (cm_id_priv->qp_type == IB_QPT_XRC_TGT) {
@@ -2899,65 +2904,72 @@ out:
        return -EINVAL;
 }
 
-int ib_send_cm_rej(struct ib_cm_id *cm_id,
-                  enum ib_cm_rej_reason reason,
-                  void *ari,
-                  u8 ari_length,
-                  const void *private_data,
-                  u8 private_data_len)
+static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
+                             enum ib_cm_rej_reason reason, void *ari,
+                             u8 ari_length, const void *private_data,
+                             u8 private_data_len)
 {
-       struct cm_id_private *cm_id_priv;
        struct ib_mad_send_buf *msg;
-       unsigned long flags;
        int ret;
 
+       lockdep_assert_held(&cm_id_priv->lock);
+
        if ((private_data && private_data_len > IB_CM_REJ_PRIVATE_DATA_SIZE) ||
            (ari && ari_length > IB_CM_REJ_ARI_LENGTH))
                return -EINVAL;
 
-       cm_id_priv = container_of(cm_id, struct cm_id_private, id);
-
-       spin_lock_irqsave(&cm_id_priv->lock, flags);
-       switch (cm_id->state) {
+       switch (cm_id_priv->id.state) {
        case IB_CM_REQ_SENT:
        case IB_CM_MRA_REQ_RCVD:
        case IB_CM_REQ_RCVD:
        case IB_CM_MRA_REQ_SENT:
        case IB_CM_REP_RCVD:
        case IB_CM_MRA_REP_SENT:
-               ret = cm_alloc_msg(cm_id_priv, &msg);
-               if (!ret)
-                       cm_format_rej((struct cm_rej_msg *) msg->mad,
-                                     cm_id_priv, reason, ari, ari_length,
-                                     private_data, private_data_len);
-
                cm_reset_to_idle(cm_id_priv);
+               ret = cm_alloc_msg(cm_id_priv, &msg);
+               if (ret)
+                       return ret;
+               cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
+                             ari, ari_length, private_data, private_data_len);
                break;
        case IB_CM_REP_SENT:
        case IB_CM_MRA_REP_RCVD:
-               ret = cm_alloc_msg(cm_id_priv, &msg);
-               if (!ret)
-                       cm_format_rej((struct cm_rej_msg *) msg->mad,
-                                     cm_id_priv, reason, ari, ari_length,
-                                     private_data, private_data_len);
-
                cm_enter_timewait(cm_id_priv);
+               ret = cm_alloc_msg(cm_id_priv, &msg);
+               if (ret)
+                       return ret;
+               cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
+                             ari, ari_length, private_data, private_data_len);
                break;
        default:
                pr_debug("%s: local_id %d, cm_id->state: %d\n", __func__,
-                        be32_to_cpu(cm_id_priv->id.local_id), cm_id->state);
-               ret = -EINVAL;
-               goto out;
+                        be32_to_cpu(cm_id_priv->id.local_id),
+                        cm_id_priv->id.state);
+               return -EINVAL;
        }
 
-       if (ret)
-               goto out;
-
        ret = ib_post_send_mad(msg, NULL);
-       if (ret)
+       if (ret) {
                cm_free_msg(msg);
+               return ret;
+       }
 
-out:   spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+       return 0;
+}
+
+int ib_send_cm_rej(struct ib_cm_id *cm_id, enum ib_cm_rej_reason reason,
+                  void *ari, u8 ari_length, const void *private_data,
+                  u8 private_data_len)
+{
+       struct cm_id_private *cm_id_priv =
+               container_of(cm_id, struct cm_id_private, id);
+       unsigned long flags;
+       int ret;
+
+       spin_lock_irqsave(&cm_id_priv->lock, flags);
+       ret = cm_send_rej_locked(cm_id_priv, reason, ari, ari_length,
+                                private_data, private_data_len);
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_rej);