RDMA/uverbs: Use the iterator for ib_uverbs_unmarshall_recv()
authorJason Gunthorpe <jgg@mellanox.com>
Sun, 25 Nov 2018 18:58:43 +0000 (20:58 +0200)
committerDoug Ledford <dledford@redhat.com>
Mon, 3 Dec 2018 17:01:58 +0000 (12:01 -0500)
This has a very complicated memory layout, with two flex arrays. Use
the iterator API to make reading it clearer.

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Doug Ledford <dledford@redhat.com>
drivers/infiniband/core/uverbs_cmd.c

index a89b844..15b9db4 100644 (file)
@@ -150,6 +150,17 @@ static int uverbs_request_next(struct uverbs_req_iter *iter, void *val,
        return 0;
 }
 
+static const void __user *uverbs_request_next_ptr(struct uverbs_req_iter *iter,
+                                                 size_t len)
+{
+       const void __user *res = iter->cur;
+
+       if (iter->cur + len > iter->end)
+               return ERR_PTR(-ENOSPC);
+       iter->cur += len;
+       return res;
+}
+
 static int uverbs_request_finish(struct uverbs_req_iter *iter)
 {
        if (!ib_is_buffer_cleared(iter->cur, iter->end - iter->cur))
@@ -2073,16 +2084,23 @@ static int ib_uverbs_post_send(struct uverbs_attr_bundle *attrs,
        int                             is_ud;
        int ret, ret2;
        size_t                          next_size;
+       const struct ib_sge __user *sgls;
+       const void __user *wqes;
+       struct uverbs_req_iter iter;
 
-       if (copy_from_user(&cmd, buf, sizeof cmd))
-               return -EFAULT;
-
-       if (in_len < sizeof cmd + cmd.wqe_size * cmd.wr_count +
-           cmd.sge_count * sizeof (struct ib_uverbs_sge))
-               return -EINVAL;
-
-       if (cmd.wqe_size < sizeof (struct ib_uverbs_send_wr))
-               return -EINVAL;
+       ret = uverbs_request_start(attrs, &iter, &cmd, sizeof(cmd));
+       if (ret)
+               return ret;
+       wqes = uverbs_request_next_ptr(&iter, cmd.wqe_size * cmd.wr_count);
+       if (IS_ERR(wqes))
+               return PTR_ERR(wqes);
+       sgls = uverbs_request_next_ptr(
+               &iter, cmd.sge_count * sizeof(struct ib_uverbs_sge));
+       if (IS_ERR(sgls))
+               return PTR_ERR(sgls);
+       ret = uverbs_request_finish(&iter);
+       if (ret)
+               return ret;
 
        user_wr = kmalloc(cmd.wqe_size, GFP_KERNEL);
        if (!user_wr)
@@ -2096,8 +2114,7 @@ static int ib_uverbs_post_send(struct uverbs_attr_bundle *attrs,
        sg_ind = 0;
        last = NULL;
        for (i = 0; i < cmd.wr_count; ++i) {
-               if (copy_from_user(user_wr,
-                                  buf + sizeof cmd + i * cmd.wqe_size,
+               if (copy_from_user(user_wr, wqes + i * cmd.wqe_size,
                                   cmd.wqe_size)) {
                        ret = -EFAULT;
                        goto out_put;
@@ -2205,11 +2222,9 @@ static int ib_uverbs_post_send(struct uverbs_attr_bundle *attrs,
                if (next->num_sge) {
                        next->sg_list = (void *) next +
                                ALIGN(next_size, sizeof(struct ib_sge));
-                       if (copy_from_user(next->sg_list,
-                                          buf + sizeof cmd +
-                                          cmd.wr_count * cmd.wqe_size +
-                                          sg_ind * sizeof (struct ib_sge),
-                                          next->num_sge * sizeof (struct ib_sge))) {
+                       if (copy_from_user(next->sg_list, sgls + sg_ind,
+                                          next->num_sge *
+                                                  sizeof(struct ib_sge))) {
                                ret = -EFAULT;
                                goto out_put;
                        }
@@ -2248,25 +2263,32 @@ out:
        return ret;
 }
 
-static struct ib_recv_wr *ib_uverbs_unmarshall_recv(const char __user *buf,
-                                                   int in_len,
-                                                   u32 wr_count,
-                                                   u32 sge_count,
-                                                   u32 wqe_size)
+static struct ib_recv_wr *
+ib_uverbs_unmarshall_recv(struct uverbs_req_iter *iter, u32 wr_count,
+                         u32 wqe_size, u32 sge_count)
 {
        struct ib_uverbs_recv_wr *user_wr;
        struct ib_recv_wr        *wr = NULL, *last, *next;
        int                       sg_ind;
        int                       i;
        int                       ret;
-
-       if (in_len < wqe_size * wr_count +
-           sge_count * sizeof (struct ib_uverbs_sge))
-               return ERR_PTR(-EINVAL);
+       const struct ib_sge __user *sgls;
+       const void __user *wqes;
 
        if (wqe_size < sizeof (struct ib_uverbs_recv_wr))
                return ERR_PTR(-EINVAL);
 
+       wqes = uverbs_request_next_ptr(iter, wqe_size * wr_count);
+       if (IS_ERR(wqes))
+               return ERR_CAST(wqes);
+       sgls = uverbs_request_next_ptr(
+               iter, sge_count * sizeof(struct ib_uverbs_sge));
+       if (IS_ERR(sgls))
+               return ERR_CAST(sgls);
+       ret = uverbs_request_finish(iter);
+       if (ret)
+               return ERR_PTR(ret);
+
        user_wr = kmalloc(wqe_size, GFP_KERNEL);
        if (!user_wr)
                return ERR_PTR(-ENOMEM);
@@ -2274,7 +2296,7 @@ static struct ib_recv_wr *ib_uverbs_unmarshall_recv(const char __user *buf,
        sg_ind = 0;
        last = NULL;
        for (i = 0; i < wr_count; ++i) {
-               if (copy_from_user(user_wr, buf + i * wqe_size,
+               if (copy_from_user(user_wr, wqes + i * wqe_size,
                                   wqe_size)) {
                        ret = -EFAULT;
                        goto err;
@@ -2313,10 +2335,9 @@ static struct ib_recv_wr *ib_uverbs_unmarshall_recv(const char __user *buf,
                if (next->num_sge) {
                        next->sg_list = (void *) next +
                                ALIGN(sizeof *next, sizeof (struct ib_sge));
-                       if (copy_from_user(next->sg_list,
-                                          buf + wr_count * wqe_size +
-                                          sg_ind * sizeof (struct ib_sge),
-                                          next->num_sge * sizeof (struct ib_sge))) {
+                       if (copy_from_user(next->sg_list, sgls + sg_ind,
+                                          next->num_sge *
+                                                  sizeof(struct ib_sge))) {
                                ret = -EFAULT;
                                goto err;
                        }
@@ -2349,13 +2370,14 @@ static int ib_uverbs_post_recv(struct uverbs_attr_bundle *attrs,
        const struct ib_recv_wr        *bad_wr;
        struct ib_qp                   *qp;
        int ret, ret2;
+       struct uverbs_req_iter iter;
 
-       if (copy_from_user(&cmd, buf, sizeof cmd))
-               return -EFAULT;
+       ret = uverbs_request_start(attrs, &iter, &cmd, sizeof(cmd));
+       if (ret)
+               return ret;
 
-       wr = ib_uverbs_unmarshall_recv(buf + sizeof cmd,
-                                      in_len - sizeof cmd, cmd.wr_count,
-                                      cmd.sge_count, cmd.wqe_size);
+       wr = ib_uverbs_unmarshall_recv(&iter, cmd.wr_count, cmd.wqe_size,
+                                      cmd.sge_count);
        if (IS_ERR(wr))
                return PTR_ERR(wr);
 
@@ -2400,13 +2422,14 @@ static int ib_uverbs_post_srq_recv(struct uverbs_attr_bundle *attrs,
        const struct ib_recv_wr            *bad_wr;
        struct ib_srq                      *srq;
        int ret, ret2;
+       struct uverbs_req_iter iter;
 
-       if (copy_from_user(&cmd, buf, sizeof cmd))
-               return -EFAULT;
+       ret = uverbs_request_start(attrs, &iter, &cmd, sizeof(cmd));
+       if (ret)
+               return ret;
 
-       wr = ib_uverbs_unmarshall_recv(buf + sizeof cmd,
-                                      in_len - sizeof cmd, cmd.wr_count,
-                                      cmd.sge_count, cmd.wqe_size);
+       wr = ib_uverbs_unmarshall_recv(&iter, cmd.wr_count, cmd.wqe_size,
+                                      cmd.sge_count);
        if (IS_ERR(wr))
                return PTR_ERR(wr);