io_uring: allow sparse fixed file sets
authorJens Axboe <axboe@kernel.dk>
Thu, 3 Oct 2019 14:11:03 +0000 (08:11 -0600)
committerJens Axboe <axboe@kernel.dk>
Tue, 29 Oct 2019 16:22:43 +0000 (10:22 -0600)
This is in preparation for allowing updates to fixed file sets without
requiring a full unregister+register.

Reviewed-by: Jeff Moyer <jmoyer@redhat.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
fs/io_uring.c

index 5db0854..b85e5fe 100644 (file)
@@ -2379,6 +2379,8 @@ static int io_req_set_file(struct io_ring_ctx *ctx, const struct sqe_submit *s,
                if (unlikely(!ctx->user_files ||
                    (unsigned) fd >= ctx->nr_user_files))
                        return -EBADF;
+               if (!ctx->user_files[fd])
+                       return -EBADF;
                req->file = ctx->user_files[fd];
                req->flags |= REQ_F_FIXED_FILE;
        } else {
@@ -2999,7 +3001,8 @@ static void __io_sqe_files_unregister(struct io_ring_ctx *ctx)
        int i;
 
        for (i = 0; i < ctx->nr_user_files; i++)
-               fput(ctx->user_files[i]);
+               if (ctx->user_files[i])
+                       fput(ctx->user_files[i]);
 #endif
 }
 
@@ -3067,7 +3070,7 @@ static int __io_sqe_files_scm(struct io_ring_ctx *ctx, int nr, int offset)
        struct sock *sk = ctx->ring_sock->sk;
        struct scm_fp_list *fpl;
        struct sk_buff *skb;
-       int i;
+       int i, nr_files;
 
        if (!capable(CAP_SYS_RESOURCE) && !capable(CAP_SYS_ADMIN)) {
                unsigned long inflight = ctx->user->unix_inflight + nr;
@@ -3087,21 +3090,31 @@ static int __io_sqe_files_scm(struct io_ring_ctx *ctx, int nr, int offset)
        }
 
        skb->sk = sk;
-       skb->destructor = io_destruct_skb;
 
+       nr_files = 0;
        fpl->user = get_uid(ctx->user);
        for (i = 0; i < nr; i++) {
-               fpl->fp[i] = get_file(ctx->user_files[i + offset]);
-               unix_inflight(fpl->user, fpl->fp[i]);
+               if (!ctx->user_files[i + offset])
+                       continue;
+               fpl->fp[nr_files] = get_file(ctx->user_files[i + offset]);
+               unix_inflight(fpl->user, fpl->fp[nr_files]);
+               nr_files++;
        }
 
-       fpl->max = fpl->count = nr;
-       UNIXCB(skb).fp = fpl;
-       refcount_add(skb->truesize, &sk->sk_wmem_alloc);
-       skb_queue_head(&sk->sk_receive_queue, skb);
+       if (nr_files) {
+               fpl->max = SCM_MAX_FD;
+               fpl->count = nr_files;
+               UNIXCB(skb).fp = fpl;
+               skb->destructor = io_destruct_skb;
+               refcount_add(skb->truesize, &sk->sk_wmem_alloc);
+               skb_queue_head(&sk->sk_receive_queue, skb);
 
-       for (i = 0; i < nr; i++)
-               fput(fpl->fp[i]);
+               for (i = 0; i < nr_files; i++)
+                       fput(fpl->fp[i]);
+       } else {
+               kfree_skb(skb);
+               kfree(fpl);
+       }
 
        return 0;
 }
@@ -3132,7 +3145,8 @@ static int io_sqe_files_scm(struct io_ring_ctx *ctx)
                return 0;
 
        while (total < ctx->nr_user_files) {
-               fput(ctx->user_files[total]);
+               if (ctx->user_files[total])
+                       fput(ctx->user_files[total]);
                total++;
        }
 
@@ -3163,10 +3177,15 @@ static int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
        if (!ctx->user_files)
                return -ENOMEM;
 
-       for (i = 0; i < nr_args; i++) {
+       for (i = 0; i < nr_args; i++, ctx->nr_user_files++) {
                ret = -EFAULT;
                if (copy_from_user(&fd, &fds[i], sizeof(fd)))
                        break;
+               /* allow sparse sets */
+               if (fd == -1) {
+                       ret = 0;
+                       continue;
+               }
 
                ctx->user_files[i] = fget(fd);
 
@@ -3184,13 +3203,13 @@ static int io_sqe_files_register(struct io_ring_ctx *ctx, void __user *arg,
                        fput(ctx->user_files[i]);
                        break;
                }
-               ctx->nr_user_files++;
                ret = 0;
        }
 
        if (ret) {
                for (i = 0; i < ctx->nr_user_files; i++)
-                       fput(ctx->user_files[i]);
+                       if (ctx->user_files[i])
+                               fput(ctx->user_files[i]);
 
                kfree(ctx->user_files);
                ctx->user_files = NULL;