IB/mlx5: Validate correct PD before prefetch MR
authorMoni Shoua <monis@mellanox.com>
Sun, 17 Feb 2019 14:08:23 +0000 (16:08 +0200)
committerJason Gunthorpe <jgg@mellanox.com>
Thu, 21 Feb 2019 23:32:45 +0000 (16:32 -0700)
When prefetching an ODP MR it is required to verify that the PD of the MR is
identical to the PD with which the advise_mr request arrived.

This check was missing from the synchronous flow and is added now.

Fixes: 813e90b1aeaa ("IB/mlx5: Add advise_mr() support")
Reported-by: Parav Pandit <parav@mellanox.com>
Signed-off-by: Moni Shoua <monis@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
drivers/infiniband/hw/mlx5/odp.c

index a134697..d828c20 100644 (file)
@@ -736,7 +736,8 @@ static int get_indirect_num_descs(struct mlx5_core_mkey *mmkey)
  * -EFAULT when there's an error mapping the requested pages. The caller will
  *  abort the page fault handling.
  */
-static int pagefault_single_data_segment(struct mlx5_ib_dev *dev, u32 key,
+static int pagefault_single_data_segment(struct mlx5_ib_dev *dev,
+                                        struct ib_pd *pd, u32 key,
                                         u64 io_virt, size_t bcnt,
                                         u32 *bytes_committed,
                                         u32 *bytes_mapped, u32 flags)
@@ -779,9 +780,15 @@ next_mr:
                        goto srcu_unlock;
                }
 
-               if (prefetch && !is_odp_mr(mr)) {
-                       ret = -EINVAL;
-                       goto srcu_unlock;
+               if (prefetch) {
+                       if (!is_odp_mr(mr) ||
+                           mr->ibmr.pd != pd) {
+                               mlx5_ib_dbg(dev, "Invalid prefetch request: %s\n",
+                                           is_odp_mr(mr) ?  "MR is not ODP" :
+                                           "PD is not of the MR");
+                               ret = -EINVAL;
+                               goto srcu_unlock;
+                       }
                }
 
                if (!is_odp_mr(mr)) {
@@ -964,7 +971,8 @@ static int pagefault_data_segments(struct mlx5_ib_dev *dev,
                        continue;
                }
 
-               ret = pagefault_single_data_segment(dev, key, io_virt, bcnt,
+               ret = pagefault_single_data_segment(dev, NULL, key,
+                                                   io_virt, bcnt,
                                                    &pfault->bytes_committed,
                                                    bytes_mapped, 0);
                if (ret < 0)
@@ -1331,7 +1339,7 @@ static void mlx5_ib_mr_rdma_pfault_handler(struct mlx5_ib_dev *dev,
                prefetch_len = min(MAX_PREFETCH_LEN, prefetch_len);
        }
 
-       ret = pagefault_single_data_segment(dev, rkey, address, length,
+       ret = pagefault_single_data_segment(dev, NULL, rkey, address, length,
                                            &pfault->bytes_committed, NULL,
                                            0);
        if (ret == -EAGAIN) {
@@ -1358,7 +1366,7 @@ static void mlx5_ib_mr_rdma_pfault_handler(struct mlx5_ib_dev *dev,
        if (prefetch_activated) {
                u32 bytes_committed = 0;
 
-               ret = pagefault_single_data_segment(dev, rkey, address,
+               ret = pagefault_single_data_segment(dev, NULL, rkey, address,
                                                    prefetch_len,
                                                    &bytes_committed, NULL,
                                                    0);
@@ -1655,7 +1663,7 @@ int mlx5_ib_odp_init(void)
 
 struct prefetch_mr_work {
        struct work_struct work;
-       struct mlx5_ib_dev *dev;
+       struct ib_pd *pd;
        u32 pf_flags;
        u32 num_sge;
        struct ib_sge sg_list[0];
@@ -1727,17 +1735,18 @@ static bool num_pending_prefetch_inc(struct ib_pd *pd,
        return ret;
 }
 
-static int mlx5_ib_prefetch_sg_list(struct mlx5_ib_dev *dev, u32 pf_flags,
+static int mlx5_ib_prefetch_sg_list(struct ib_pd *pd, u32 pf_flags,
                                    struct ib_sge *sg_list, u32 num_sge)
 {
        u32 i;
        int ret = 0;
+       struct mlx5_ib_dev *dev = to_mdev(pd->device);
 
        for (i = 0; i < num_sge; ++i) {
                struct ib_sge *sg = &sg_list[i];
                int bytes_committed = 0;
 
-               ret = pagefault_single_data_segment(dev, sg->lkey, sg->addr,
+               ret = pagefault_single_data_segment(dev, pd, sg->lkey, sg->addr,
                                                    sg->length,
                                                    &bytes_committed, NULL,
                                                    pf_flags);
@@ -1753,13 +1762,14 @@ static void mlx5_ib_prefetch_mr_work(struct work_struct *work)
        struct prefetch_mr_work *w =
                container_of(work, struct prefetch_mr_work, work);
 
-       if (ib_device_try_get(&w->dev->ib_dev)) {
-               mlx5_ib_prefetch_sg_list(w->dev, w->pf_flags, w->sg_list,
+       if (ib_device_try_get(w->pd->device)) {
+               mlx5_ib_prefetch_sg_list(w->pd, w->pf_flags, w->sg_list,
                                         w->num_sge);
-               ib_device_put(&w->dev->ib_dev);
+               ib_device_put(w->pd->device);
        }
 
-       num_pending_prefetch_dec(w->dev, w->sg_list, w->num_sge, 0);
+       num_pending_prefetch_dec(to_mdev(w->pd->device), w->sg_list,
+                                w->num_sge, 0);
        kfree(w);
 }
 
@@ -1777,7 +1787,7 @@ int mlx5_ib_advise_mr_prefetch(struct ib_pd *pd,
                pf_flags |= MLX5_PF_FLAGS_DOWNGRADE;
 
        if (flags & IB_UVERBS_ADVISE_MR_FLAG_FLUSH)
-               return mlx5_ib_prefetch_sg_list(dev, pf_flags, sg_list,
+               return mlx5_ib_prefetch_sg_list(pd, pf_flags, sg_list,
                                                num_sge);
 
        work = kvzalloc(struct_size(work, sg_list, num_sge), GFP_KERNEL);
@@ -1786,7 +1796,11 @@ int mlx5_ib_advise_mr_prefetch(struct ib_pd *pd,
 
        memcpy(work->sg_list, sg_list, num_sge * sizeof(struct ib_sge));
 
-       work->dev = dev;
+       /* It is guaranteed that the pd when work is executed is the pd when
+        * work was queued since pd can't be destroyed while it holds MRs and
+        * destroying an MR leads to flushing the workqueue
+        */
+       work->pd = pd;
        work->pf_flags = pf_flags;
        work->num_sge = num_sge;