RDMA/umem: Make ib_umem_odp into a sub structure of ib_umem
authorJason Gunthorpe <jgg@mellanox.com>
Sun, 16 Sep 2018 17:48:05 +0000 (20:48 +0300)
committerDoug Ledford <dledford@redhat.com>
Fri, 21 Sep 2018 15:54:46 +0000 (11:54 -0400)
These two structures are linked together, use the container_of pattern
instead of a double allocation to make the code simpler and easier to
follow.

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
drivers/infiniband/core/umem.c
drivers/infiniband/core/umem_odp.c
drivers/infiniband/hw/mlx5/odp.c
include/rdma/ib_umem_odp.h

index 971d92d..88b9b88 100644 (file)
@@ -108,34 +108,39 @@ struct ib_umem *ib_umem_get(struct ib_ucontext *context, unsigned long addr,
        if (!can_do_mlock())
                return ERR_PTR(-EPERM);
 
-       umem = kzalloc(sizeof *umem, GFP_KERNEL);
-       if (!umem)
-               return ERR_PTR(-ENOMEM);
+       if (access & IB_ACCESS_ON_DEMAND) {
+               umem = kzalloc(sizeof(struct ib_umem_odp), GFP_KERNEL);
+               if (!umem)
+                       return ERR_PTR(-ENOMEM);
+               umem->odp_data = to_ib_umem_odp(umem);
+       } else {
+               umem = kzalloc(sizeof(*umem), GFP_KERNEL);
+               if (!umem)
+                       return ERR_PTR(-ENOMEM);
+       }
 
        umem->context    = context;
        umem->length     = size;
        umem->address    = addr;
        umem->page_shift = PAGE_SHIFT;
        umem->writable   = ib_access_writable(access);
+       umem->owning_mm = mm = current->mm;
+       mmgrab(mm);
 
        if (access & IB_ACCESS_ON_DEMAND) {
-               ret = ib_umem_odp_get(context, umem, access);
+               ret = ib_umem_odp_get(to_ib_umem_odp(umem), access);
                if (ret)
                        goto umem_kfree;
                return umem;
        }
 
-       umem->owning_mm = mm = current->mm;
-       mmgrab(mm);
-       umem->odp_data = NULL;
-
        /* We assume the memory is from hugetlb until proved otherwise */
        umem->hugetlb   = 1;
 
        page_list = (struct page **) __get_free_page(GFP_KERNEL);
        if (!page_list) {
                ret = -ENOMEM;
-               goto umem_kfree_drop;
+               goto umem_kfree;
        }
 
        /*
@@ -226,12 +231,11 @@ out:
        if (vma_list)
                free_page((unsigned long) vma_list);
        free_page((unsigned long) page_list);
-umem_kfree_drop:
-       if (ret)
-               mmdrop(umem->owning_mm);
 umem_kfree:
-       if (ret)
+       if (ret) {
+               mmdrop(umem->owning_mm);
                kfree(umem);
+       }
        return ret ? ERR_PTR(ret) : umem;
 }
 EXPORT_SYMBOL(ib_umem_get);
@@ -239,7 +243,10 @@ EXPORT_SYMBOL(ib_umem_get);
 static void __ib_umem_release_tail(struct ib_umem *umem)
 {
        mmdrop(umem->owning_mm);
-       kfree(umem);
+       if (umem->odp_data)
+               kfree(to_ib_umem_odp(umem));
+       else
+               kfree(umem);
 }
 
 static void ib_umem_release_defer(struct work_struct *work)
@@ -263,6 +270,7 @@ void ib_umem_release(struct ib_umem *umem)
 
        if (umem->odp_data) {
                ib_umem_odp_release(to_ib_umem_odp(umem));
+               __ib_umem_release_tail(umem);
                return;
        }
 
index 8405e9a..900fded 100644 (file)
@@ -58,7 +58,7 @@ static u64 node_start(struct umem_odp_node *n)
        struct ib_umem_odp *umem_odp =
                        container_of(n, struct ib_umem_odp, interval_tree);
 
-       return ib_umem_start(umem_odp->umem);
+       return ib_umem_start(&umem_odp->umem);
 }
 
 /* Note that the representation of the intervals in the interval tree
@@ -71,7 +71,7 @@ static u64 node_last(struct umem_odp_node *n)
        struct ib_umem_odp *umem_odp =
                        container_of(n, struct ib_umem_odp, interval_tree);
 
-       return ib_umem_end(umem_odp->umem) - 1;
+       return ib_umem_end(&umem_odp->umem) - 1;
 }
 
 INTERVAL_TREE_DEFINE(struct umem_odp_node, rb, u64, __subtree_last,
@@ -159,7 +159,7 @@ static void ib_ucontext_notifier_end_account(struct ib_ucontext *context)
 static int ib_umem_notifier_release_trampoline(struct ib_umem_odp *umem_odp,
                                               u64 start, u64 end, void *cookie)
 {
-       struct ib_umem *umem = umem_odp->umem;
+       struct ib_umem *umem = &umem_odp->umem;
 
        /*
         * Increase the number of notifiers running, to
@@ -198,7 +198,7 @@ static int invalidate_page_trampoline(struct ib_umem_odp *item, u64 start,
                                      u64 end, void *cookie)
 {
        ib_umem_notifier_start_account(item);
-       item->umem->context->invalidate_range(item, start, start + PAGE_SIZE);
+       item->umem.context->invalidate_range(item, start, start + PAGE_SIZE);
        ib_umem_notifier_end_account(item);
        return 0;
 }
@@ -207,7 +207,7 @@ static int invalidate_range_start_trampoline(struct ib_umem_odp *item,
                                             u64 start, u64 end, void *cookie)
 {
        ib_umem_notifier_start_account(item);
-       item->umem->context->invalidate_range(item, start, end);
+       item->umem.context->invalidate_range(item, start, end);
        return 0;
 }
 
@@ -277,28 +277,21 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
 struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
                                      unsigned long addr, size_t size)
 {
-       struct ib_umem *umem;
        struct ib_umem_odp *odp_data;
+       struct ib_umem *umem;
        int pages = size >> PAGE_SHIFT;
        int ret;
 
-       umem = kzalloc(sizeof(*umem), GFP_KERNEL);
-       if (!umem)
+       odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
+       if (!odp_data)
                return ERR_PTR(-ENOMEM);
-
+       umem = &odp_data->umem;
        umem->context    = context;
        umem->length     = size;
        umem->address    = addr;
        umem->page_shift = PAGE_SHIFT;
        umem->writable   = 1;
 
-       odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
-       if (!odp_data) {
-               ret = -ENOMEM;
-               goto out_umem;
-       }
-       odp_data->umem = umem;
-
        mutex_init(&odp_data->umem_mutex);
        init_completion(&odp_data->notifier_completion);
 
@@ -334,15 +327,14 @@ out_page_list:
        vfree(odp_data->page_list);
 out_odp_data:
        kfree(odp_data);
-out_umem:
-       kfree(umem);
        return ERR_PTR(ret);
 }
 EXPORT_SYMBOL(ib_alloc_odp_umem);
 
-int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
-                   int access)
+int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
+       struct ib_ucontext *context = umem_odp->umem.context;
+       struct ib_umem *umem = &umem_odp->umem;
        int ret_val;
        struct pid *our_pid;
        struct mm_struct *mm = get_task_mm(current);
@@ -378,30 +370,23 @@ int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
                goto out_mm;
        }
 
-       umem->odp_data = kzalloc(sizeof(*umem->odp_data), GFP_KERNEL);
-       if (!umem->odp_data) {
-               ret_val = -ENOMEM;
-               goto out_mm;
-       }
-       umem->odp_data->umem = umem;
-
-       mutex_init(&umem->odp_data->umem_mutex);
+       mutex_init(&umem_odp->umem_mutex);
 
-       init_completion(&umem->odp_data->notifier_completion);
+       init_completion(&umem_odp->notifier_completion);
 
        if (ib_umem_num_pages(umem)) {
-               umem->odp_data->page_list =
-                       vzalloc(array_size(sizeof(*umem->odp_data->page_list),
+               umem_odp->page_list =
+                       vzalloc(array_size(sizeof(*umem_odp->page_list),
                                           ib_umem_num_pages(umem)));
-               if (!umem->odp_data->page_list) {
+               if (!umem_odp->page_list) {
                        ret_val = -ENOMEM;
-                       goto out_odp_data;
+                       goto out_mm;
                }
 
-               umem->odp_data->dma_list =
-                       vzalloc(array_size(sizeof(*umem->odp_data->dma_list),
+               umem_odp->dma_list =
+                       vzalloc(array_size(sizeof(*umem_odp->dma_list),
                                           ib_umem_num_pages(umem)));
-               if (!umem->odp_data->dma_list) {
+               if (!umem_odp->dma_list) {
                        ret_val = -ENOMEM;
                        goto out_page_list;
                }
@@ -415,13 +400,13 @@ int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
        down_write(&context->umem_rwsem);
        context->odp_mrs_count++;
        if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
-               rbt_ib_umem_insert(&umem->odp_data->interval_tree,
+               rbt_ib_umem_insert(&umem_odp->interval_tree,
                                   &context->umem_tree);
        if (likely(!atomic_read(&context->notifier_count)) ||
            context->odp_mrs_count == 1)
-               umem->odp_data->mn_counters_active = true;
+               umem_odp->mn_counters_active = true;
        else
-               list_add(&umem->odp_data->no_private_counters,
+               list_add(&umem_odp->no_private_counters,
                         &context->no_private_counters);
        downgrade_write(&context->umem_rwsem);
 
@@ -454,11 +439,9 @@ int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
 
 out_mutex:
        up_read(&context->umem_rwsem);
-       vfree(umem->odp_data->dma_list);
+       vfree(umem_odp->dma_list);
 out_page_list:
-       vfree(umem->odp_data->page_list);
-out_odp_data:
-       kfree(umem->odp_data);
+       vfree(umem_odp->page_list);
 out_mm:
        mmput(mm);
        return ret_val;
@@ -466,7 +449,7 @@ out_mm:
 
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 {
-       struct ib_umem *umem = umem_odp->umem;
+       struct ib_umem *umem = &umem_odp->umem;
        struct ib_ucontext *context = umem->context;
 
        /*
@@ -528,8 +511,6 @@ out:
 
        vfree(umem_odp->dma_list);
        vfree(umem_odp->page_list);
-       kfree(umem_odp);
-       kfree(umem);
 }
 
 /*
@@ -557,7 +538,7 @@ static int ib_umem_odp_map_dma_single_page(
                u64 access_mask,
                unsigned long current_seq)
 {
-       struct ib_umem *umem = umem_odp->umem;
+       struct ib_umem *umem = &umem_odp->umem;
        struct ib_device *dev = umem->context->device;
        dma_addr_t dma_addr;
        int stored_page = 0;
@@ -643,7 +624,7 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
                              u64 bcnt, u64 access_mask,
                              unsigned long current_seq)
 {
-       struct ib_umem *umem = umem_odp->umem;
+       struct ib_umem *umem = &umem_odp->umem;
        struct task_struct *owning_process  = NULL;
        struct mm_struct   *owning_mm       = NULL;
        struct page       **local_page_list = NULL;
@@ -759,7 +740,7 @@ EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);
 void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
                                 u64 bound)
 {
-       struct ib_umem *umem = umem_odp->umem;
+       struct ib_umem *umem = &umem_odp->umem;
        int idx;
        u64 addr;
        struct ib_device *dev = umem->context->device;
index 8f4a4a8..5b9fd56 100644 (file)
@@ -64,7 +64,7 @@ static int check_parent(struct ib_umem_odp *odp,
 static struct ib_umem_odp *odp_next(struct ib_umem_odp *odp)
 {
        struct mlx5_ib_mr *mr = odp->private, *parent = mr->parent;
-       struct ib_ucontext *ctx = odp->umem->context;
+       struct ib_ucontext *ctx = odp->umem.context;
        struct rb_node *rb;
 
        down_read(&ctx->umem_rwsem);
@@ -102,7 +102,7 @@ static struct ib_umem_odp *odp_lookup(struct ib_ucontext *ctx,
                if (!rb)
                        goto not_found;
                odp = rb_entry(rb, struct ib_umem_odp, interval_tree.rb);
-               if (ib_umem_start(odp->umem) > start + length)
+               if (ib_umem_start(&odp->umem) > start + length)
                        goto not_found;
        }
 not_found:
@@ -137,7 +137,7 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset,
        for (i = 0; i < nentries; i++, pklm++) {
                pklm->bcount = cpu_to_be32(MLX5_IMR_MTT_SIZE);
                va = (offset + i) * MLX5_IMR_MTT_SIZE;
-               if (odp && odp->umem->address == va) {
+               if (odp && odp->umem.address == va) {
                        struct mlx5_ib_mr *mtt = odp->private;
 
                        pklm->key = cpu_to_be32(mtt->ibmr.lkey);
@@ -153,13 +153,13 @@ void mlx5_odp_populate_klm(struct mlx5_klm *pklm, size_t offset,
 static void mr_leaf_free_action(struct work_struct *work)
 {
        struct ib_umem_odp *odp = container_of(work, struct ib_umem_odp, work);
-       int idx = ib_umem_start(odp->umem) >> MLX5_IMR_MTT_SHIFT;
+       int idx = ib_umem_start(&odp->umem) >> MLX5_IMR_MTT_SHIFT;
        struct mlx5_ib_mr *mr = odp->private, *imr = mr->parent;
 
        mr->parent = NULL;
        synchronize_srcu(&mr->dev->mr_srcu);
 
-       ib_umem_release(odp->umem);
+       ib_umem_release(&odp->umem);
        if (imr->live)
                mlx5_ib_update_xlt(imr, idx, 1, 0,
                                   MLX5_IB_UPD_XLT_INDIRECT |
@@ -185,7 +185,7 @@ void mlx5_ib_invalidate_range(struct ib_umem_odp *umem_odp, unsigned long start,
                pr_err("invalidation called on NULL umem or non-ODP umem\n");
                return;
        }
-       umem = umem_odp->umem;
+       umem = &umem_odp->umem;
 
        mr = umem_odp->private;
 
@@ -392,16 +392,16 @@ next_mr:
                        return ERR_CAST(odp);
                }
 
-               mtt = implicit_mr_alloc(mr->ibmr.pd, odp->umem, 0,
+               mtt = implicit_mr_alloc(mr->ibmr.pd, &odp->umem, 0,
                                        mr->access_flags);
                if (IS_ERR(mtt)) {
                        mutex_unlock(&mr->umem->odp_data->umem_mutex);
-                       ib_umem_release(odp->umem);
+                       ib_umem_release(&odp->umem);
                        return ERR_CAST(mtt);
                }
 
                odp->private = mtt;
-               mtt->umem = odp->umem;
+               mtt->umem = &odp->umem;
                mtt->mmkey.iova = addr;
                mtt->parent = mr;
                INIT_WORK(&odp->work, mr_leaf_free_action);
@@ -418,7 +418,7 @@ next_mr:
        addr += MLX5_IMR_MTT_SIZE;
        if (unlikely(addr < io_virt + bcnt)) {
                odp = odp_next(odp);
-               if (odp && odp->umem->address != addr)
+               if (odp && odp->umem.address != addr)
                        odp = NULL;
                goto next_mr;
        }
@@ -465,7 +465,7 @@ static int mr_leaf_free(struct ib_umem_odp *umem_odp, u64 start, u64 end,
                        void *cookie)
 {
        struct mlx5_ib_mr *mr = umem_odp->private, *imr = cookie;
-       struct ib_umem *umem = umem_odp->umem;
+       struct ib_umem *umem = &umem_odp->umem;
 
        if (mr->parent != imr)
                return 0;
@@ -518,7 +518,7 @@ static int pagefault_mr(struct mlx5_ib_dev *dev, struct mlx5_ib_mr *mr,
        }
 
 next_mr:
-       size = min_t(size_t, bcnt, ib_umem_end(odp->umem) - io_virt);
+       size = min_t(size_t, bcnt, ib_umem_end(&odp->umem) - io_virt);
 
        page_shift = mr->umem->page_shift;
        page_mask = ~(BIT(page_shift) - 1);
@@ -577,7 +577,7 @@ next_mr:
 
                io_virt += size;
                next = odp_next(odp);
-               if (unlikely(!next || next->umem->address != io_virt)) {
+               if (unlikely(!next || next->umem.address != io_virt)) {
                        mlx5_ib_dbg(dev, "next implicit leaf removed at 0x%llx. got %p\n",
                                    io_virt, next);
                        return -EAGAIN;
index 3ef2975..4519ea6 100644 (file)
@@ -43,6 +43,7 @@ struct umem_odp_node {
 };
 
 struct ib_umem_odp {
+       struct ib_umem umem;
        /*
         * An array of the pages included in the on-demand paging umem.
         * Indices of pages that are currently not mapped into the device will
@@ -72,7 +73,6 @@ struct ib_umem_odp {
        /* A linked list of umems that don't have private mmu notifier
         * counters yet. */
        struct list_head no_private_counters;
-       struct ib_umem          *umem;
 
        /* Tree tracking */
        struct umem_odp_node    interval_tree;
@@ -84,13 +84,12 @@ struct ib_umem_odp {
 
 static inline struct ib_umem_odp *to_ib_umem_odp(struct ib_umem *umem)
 {
-       return umem->odp_data;
+       return container_of(umem, struct ib_umem_odp, umem);
 }
 
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
 
-int ib_umem_odp_get(struct ib_ucontext *context, struct ib_umem *umem,
-                   int access);
+int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access);
 struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
                                      unsigned long addr, size_t size);
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp);
@@ -158,9 +157,7 @@ static inline int ib_umem_mmu_notifier_retry(struct ib_umem_odp *umem_odp,
 
 #else /* CONFIG_INFINIBAND_ON_DEMAND_PAGING */
 
-static inline int ib_umem_odp_get(struct ib_ucontext *context,
-                                 struct ib_umem *umem,
-                                 int access)
+static inline int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
        return -EINVAL;
 }