RDMA/odp: Split creating a umem_odp from ib_umem_get
authorJason Gunthorpe <jgg@mellanox.com>
Mon, 19 Aug 2019 11:17:04 +0000 (14:17 +0300)
committerJason Gunthorpe <jgg@mellanox.com>
Wed, 21 Aug 2019 17:08:42 +0000 (14:08 -0300)
This is the last creation API that is overloaded for both, there is very
little code sharing and a driver has to be specifically ready for a
umem_odp to be created to use the odp version.

Link: https://lore.kernel.org/r/20190819111710.18440-7-leon@kernel.org
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
drivers/infiniband/core/umem.c
drivers/infiniband/core/umem_odp.c
drivers/infiniband/hw/mlx5/mem.c
drivers/infiniband/hw/mlx5/mr.c
include/rdma/ib_umem_odp.h

index 5655366..9a39c45 100644 (file)
@@ -184,9 +184,6 @@ EXPORT_SYMBOL(ib_umem_find_best_pgsz);
 /**
  * ib_umem_get - Pin and DMA map userspace memory.
  *
- * If access flags indicate ODP memory, avoid pinning. Instead, stores
- * the mm for future page fault handling in conjunction with MMU notifiers.
- *
  * @udata: userspace context to pin memory for
  * @addr: userspace virtual address to start at
  * @size: length of region to pin
@@ -231,17 +228,12 @@ struct ib_umem *ib_umem_get(struct ib_udata *udata, unsigned long addr,
        if (!can_do_mlock())
                return ERR_PTR(-EPERM);
 
-       if (access & IB_ACCESS_ON_DEMAND) {
-               umem = kzalloc(sizeof(struct ib_umem_odp), GFP_KERNEL);
-               if (!umem)
-                       return ERR_PTR(-ENOMEM);
-               umem->is_odp = 1;
-       } else {
-               umem = kzalloc(sizeof(*umem), GFP_KERNEL);
-               if (!umem)
-                       return ERR_PTR(-ENOMEM);
-       }
+       if (access & IB_ACCESS_ON_DEMAND)
+               return ERR_PTR(-EOPNOTSUPP);
 
+       umem = kzalloc(sizeof(*umem), GFP_KERNEL);
+       if (!umem)
+               return ERR_PTR(-ENOMEM);
        umem->context    = context;
        umem->length     = size;
        umem->address    = addr;
@@ -249,18 +241,6 @@ struct ib_umem *ib_umem_get(struct ib_udata *udata, unsigned long addr,
        umem->owning_mm = mm = current->mm;
        mmgrab(mm);
 
-       if (access & IB_ACCESS_ON_DEMAND) {
-               if (WARN_ON_ONCE(!context->invalidate_range)) {
-                       ret = -EINVAL;
-                       goto umem_kfree;
-               }
-
-               ret = ib_umem_odp_get(to_ib_umem_odp(umem), access);
-               if (ret)
-                       goto umem_kfree;
-               return umem;
-       }
-
        page_list = (struct page **) __get_free_page(GFP_KERNEL);
        if (!page_list) {
                ret = -ENOMEM;
index 198c0ce..6a88bd0 100644 (file)
@@ -335,6 +335,7 @@ static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
                                     &per_mm->umem_tree);
                up_write(&per_mm->umem_rwsem);
        }
+       mmgrab(umem_odp->umem.owning_mm);
 
        return 0;
 
@@ -389,9 +390,6 @@ struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
                kfree(umem_odp);
                return ERR_PTR(ret);
        }
-
-       mmgrab(umem->owning_mm);
-
        return umem_odp;
 }
 EXPORT_SYMBOL(ib_umem_odp_alloc_implicit);
@@ -435,27 +433,51 @@ struct ib_umem_odp *ib_umem_odp_alloc_child(struct ib_umem_odp *root,
                kfree(odp_data);
                return ERR_PTR(ret);
        }
-
-       mmgrab(umem->owning_mm);
-
        return odp_data;
 }
 EXPORT_SYMBOL(ib_umem_odp_alloc_child);
 
 /**
- * ib_umem_odp_get - Complete ib_umem_get()
+ * ib_umem_odp_get - Create a umem_odp for a userspace va
  *
- * @umem_odp: The partially configured umem from ib_umem_get()
- * @addr: The starting userspace VA
- * @access: ib_reg_mr access flags
+ * @udata: userspace context to pin memory for
+ * @addr: userspace virtual address to start at
+ * @size: length of region to pin
+ * @access: IB_ACCESS_xxx flags for memory being pinned
+ *
+ * The driver should use when the access flags indicate ODP memory. It avoids
+ * pinning, instead, stores the mm for future page fault handling in
+ * conjunction with MMU notifiers.
  */
-int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
+struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
+                                   size_t size, int access)
 {
-       /*
-        * NOTE: This must called in a process context where umem->owning_mm
-        * == current->mm
-        */
-       struct mm_struct *mm = umem_odp->umem.owning_mm;
+       struct ib_umem_odp *umem_odp;
+       struct ib_ucontext *context;
+       struct mm_struct *mm;
+       int ret;
+
+       if (!udata)
+               return ERR_PTR(-EIO);
+
+       context = container_of(udata, struct uverbs_attr_bundle, driver_udata)
+                         ->context;
+       if (!context)
+               return ERR_PTR(-EIO);
+
+       if (WARN_ON_ONCE(!(access & IB_ACCESS_ON_DEMAND)) ||
+           WARN_ON_ONCE(!context->invalidate_range))
+               return ERR_PTR(-EINVAL);
+
+       umem_odp = kzalloc(sizeof(struct ib_umem_odp), GFP_KERNEL);
+       if (!umem_odp)
+               return ERR_PTR(-ENOMEM);
+
+       umem_odp->umem.context = context;
+       umem_odp->umem.length = size;
+       umem_odp->umem.address = addr;
+       umem_odp->umem.writable = ib_access_writable(access);
+       umem_odp->umem.owning_mm = mm = current->mm;
 
        umem_odp->page_shift = PAGE_SHIFT;
        if (access & IB_ACCESS_HUGETLB) {
@@ -466,15 +488,24 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
                vma = find_vma(mm, ib_umem_start(umem_odp));
                if (!vma || !is_vm_hugetlb_page(vma)) {
                        up_read(&mm->mmap_sem);
-                       return -EINVAL;
+                       ret = -EINVAL;
+                       goto err_free;
                }
                h = hstate_vma(vma);
                umem_odp->page_shift = huge_page_shift(h);
                up_read(&mm->mmap_sem);
        }
 
-       return ib_init_umem_odp(umem_odp, NULL);
+       ret = ib_init_umem_odp(umem_odp, NULL);
+       if (ret)
+               goto err_free;
+       return umem_odp;
+
+err_free:
+       kfree(umem_odp);
+       return ERR_PTR(ret);
 }
+EXPORT_SYMBOL(ib_umem_odp_get);
 
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 {
index a40e0ab..b5aece7 100644 (file)
@@ -56,19 +56,6 @@ void mlx5_ib_cont_pages(struct ib_umem *umem, u64 addr,
        struct scatterlist *sg;
        int entry;
 
-       if (umem->is_odp) {
-               struct ib_umem_odp *odp = to_ib_umem_odp(umem);
-               unsigned int page_shift = odp->page_shift;
-
-               *ncont = ib_umem_odp_num_pages(odp);
-               *count = *ncont << (page_shift - PAGE_SHIFT);
-               *shift = page_shift;
-               if (order)
-                       *order = ilog2(roundup_pow_of_two(*ncont));
-
-               return;
-       }
-
        addr = addr >> PAGE_SHIFT;
        tmp = (unsigned long)addr;
        m = find_first_bit(&tmp, BITS_PER_LONG);
index ba2ec49..fc1106d 100644 (file)
@@ -784,19 +784,37 @@ static int mr_umem_get(struct mlx5_ib_dev *dev, struct ib_udata *udata,
                       int *ncont, int *order)
 {
        struct ib_umem *u;
-       int err;
 
        *umem = NULL;
 
-       u = ib_umem_get(udata, start, length, access_flags, 0);
-       err = PTR_ERR_OR_ZERO(u);
-       if (err) {
-               mlx5_ib_dbg(dev, "umem get failed (%d)\n", err);
-               return err;
+       if (access_flags & IB_ACCESS_ON_DEMAND) {
+               struct ib_umem_odp *odp;
+
+               odp = ib_umem_odp_get(udata, start, length, access_flags);
+               if (IS_ERR(odp)) {
+                       mlx5_ib_dbg(dev, "umem get failed (%ld)\n",
+                                   PTR_ERR(odp));
+                       return PTR_ERR(odp);
+               }
+
+               u = &odp->umem;
+
+               *page_shift = odp->page_shift;
+               *ncont = ib_umem_odp_num_pages(odp);
+               *npages = *ncont << (*page_shift - PAGE_SHIFT);
+               if (order)
+                       *order = ilog2(roundup_pow_of_two(*ncont));
+       } else {
+               u = ib_umem_get(udata, start, length, access_flags, 0);
+               if (IS_ERR(u)) {
+                       mlx5_ib_dbg(dev, "umem get failed (%ld)\n", PTR_ERR(u));
+                       return PTR_ERR(u);
+               }
+
+               mlx5_ib_cont_pages(u, start, MLX5_MKEY_PAGE_SHIFT_MASK, npages,
+                                  page_shift, ncont, order);
        }
 
-       mlx5_ib_cont_pages(u, start, MLX5_MKEY_PAGE_SHIFT_MASK, npages,
-                          page_shift, ncont, order);
        if (!*npages) {
                mlx5_ib_warn(dev, "avoid zero region\n");
                ib_umem_release(u);
index 219fe70..5efb67f 100644 (file)
@@ -139,7 +139,8 @@ struct ib_ucontext_per_mm {
        struct rcu_head rcu;
 };
 
-int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access);
+struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
+                                   size_t size, int access);
 struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
                                               int access);
 struct ib_umem_odp *ib_umem_odp_alloc_child(struct ib_umem_odp *root_umem,
@@ -199,9 +200,11 @@ static inline int ib_umem_mmu_notifier_retry(struct ib_umem_odp *umem_odp,
 
 #else /* CONFIG_INFINIBAND_ON_DEMAND_PAGING */
 
-static inline int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
+static inline struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata,
+                                                 unsigned long addr,
+                                                 size_t size, int access)
 {
-       return -EINVAL;
+       return ERR_PTR(-EINVAL);
 }
 
 static inline void ib_umem_odp_release(struct ib_umem_odp *umem_odp) {}