tls/sw: Convert tls_sw_sendpage() to use MSG_SPLICE_PAGES
authorDavid Howells <dhowells@redhat.com>
Wed, 7 Jun 2023 18:19:18 +0000 (19:19 +0100)
committerJakub Kicinski <kuba@kernel.org>
Fri, 9 Jun 2023 02:40:31 +0000 (19:40 -0700)
Convert tls_sw_sendpage() and tls_sw_sendpage_locked() to use sendmsg()
with MSG_SPLICE_PAGES rather than directly splicing in the pages itself.

[!] Note that tls_sw_sendpage_locked() appears to have the wrong locking
    upstream.  I think the caller will only hold the socket lock, but it
    should hold tls_ctx->tx_lock too.

This allows ->sendpage() to be replaced by something that can handle
multiple multipage folios in a single transaction.

Signed-off-by: David Howells <dhowells@redhat.com>
Reviewed-by: Jakub Kicinski <kuba@kernel.org>
cc: Chuck Lever <chuck.lever@oracle.com>
cc: Boris Pismenny <borisp@nvidia.com>
cc: John Fastabend <john.fastabend@gmail.com>
cc: Jens Axboe <axboe@kernel.dk>
cc: Matthew Wilcox <willy@infradead.org>
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
net/tls/tls_sw.c

index 2d2bb93..319f615 100644 (file)
@@ -960,7 +960,8 @@ static int tls_sw_sendmsg_splice(struct sock *sk, struct msghdr *msg,
        return 0;
 }
 
-int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
+                                size_t size)
 {
        long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
        struct tls_context *tls_ctx = tls_get_ctx(sk);
@@ -983,15 +984,6 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
        int ret = 0;
        int pending;
 
-       if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
-                              MSG_CMSG_COMPAT | MSG_SPLICE_PAGES))
-               return -EOPNOTSUPP;
-
-       ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
-       if (ret)
-               return ret;
-       lock_sock(sk);
-
        if (unlikely(msg->msg_controllen)) {
                ret = tls_process_cmsg(sk, msg, &record_type);
                if (ret) {
@@ -1192,10 +1184,27 @@ trim_sgl:
 
 send_end:
        ret = sk_stream_error(sk, msg->msg_flags, ret);
+       return copied > 0 ? copied : ret;
+}
 
+int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       int ret;
+
+       if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
+                              MSG_CMSG_COMPAT | MSG_SPLICE_PAGES |
+                              MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
+               return -EOPNOTSUPP;
+
+       ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
+       if (ret)
+               return ret;
+       lock_sock(sk);
+       ret = tls_sw_sendmsg_locked(sk, msg, size);
        release_sock(sk);
        mutex_unlock(&tls_ctx->tx_lock);
-       return copied > 0 ? copied : ret;
+       return ret;
 }
 
 /*
@@ -1272,151 +1281,39 @@ unlock:
        mutex_unlock(&tls_ctx->tx_lock);
 }
 
-static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
-                             int offset, size_t size, int flags)
-{
-       long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
-       struct tls_context *tls_ctx = tls_get_ctx(sk);
-       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-       struct tls_prot_info *prot = &tls_ctx->prot_info;
-       unsigned char record_type = TLS_RECORD_TYPE_DATA;
-       struct sk_msg *msg_pl;
-       struct tls_rec *rec;
-       int num_async = 0;
-       ssize_t copied = 0;
-       bool full_record;
-       int record_room;
-       int ret = 0;
-       bool eor;
-
-       eor = !(flags & MSG_SENDPAGE_NOTLAST);
-       sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
-
-       /* Call the sk_stream functions to manage the sndbuf mem. */
-       while (size > 0) {
-               size_t copy, required_size;
-
-               if (sk->sk_err) {
-                       ret = -sk->sk_err;
-                       goto sendpage_end;
-               }
-
-               if (ctx->open_rec)
-                       rec = ctx->open_rec;
-               else
-                       rec = ctx->open_rec = tls_get_rec(sk);
-               if (!rec) {
-                       ret = -ENOMEM;
-                       goto sendpage_end;
-               }
-
-               msg_pl = &rec->msg_plaintext;
-
-               full_record = false;
-               record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
-               copy = size;
-               if (copy >= record_room) {
-                       copy = record_room;
-                       full_record = true;
-               }
-
-               required_size = msg_pl->sg.size + copy + prot->overhead_size;
-
-               if (!sk_stream_memory_free(sk))
-                       goto wait_for_sndbuf;
-alloc_payload:
-               ret = tls_alloc_encrypted_msg(sk, required_size);
-               if (ret) {
-                       if (ret != -ENOSPC)
-                               goto wait_for_memory;
-
-                       /* Adjust copy according to the amount that was
-                        * actually allocated. The difference is due
-                        * to max sg elements limit
-                        */
-                       copy -= required_size - msg_pl->sg.size;
-                       full_record = true;
-               }
-
-               sk_msg_page_add(msg_pl, page, copy, offset);
-               sk_mem_charge(sk, copy);
-
-               offset += copy;
-               size -= copy;
-               copied += copy;
-
-               tls_ctx->pending_open_record_frags = true;
-               if (full_record || eor || sk_msg_full(msg_pl)) {
-                       ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
-                                                 record_type, &copied, flags);
-                       if (ret) {
-                               if (ret == -EINPROGRESS)
-                                       num_async++;
-                               else if (ret == -ENOMEM)
-                                       goto wait_for_memory;
-                               else if (ret != -EAGAIN) {
-                                       if (ret == -ENOSPC)
-                                               ret = 0;
-                                       goto sendpage_end;
-                               }
-                       }
-               }
-               continue;
-wait_for_sndbuf:
-               set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
-wait_for_memory:
-               ret = sk_stream_wait_memory(sk, &timeo);
-               if (ret) {
-                       if (ctx->open_rec)
-                               tls_trim_both_msgs(sk, msg_pl->sg.size);
-                       goto sendpage_end;
-               }
-
-               if (ctx->open_rec)
-                       goto alloc_payload;
-       }
-
-       if (num_async) {
-               /* Transmit if any encryptions have completed */
-               if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
-                       cancel_delayed_work(&ctx->tx_work.work);
-                       tls_tx_records(sk, flags);
-               }
-       }
-sendpage_end:
-       ret = sk_stream_error(sk, flags, ret);
-       return copied > 0 ? copied : ret;
-}
-
 int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
                           int offset, size_t size, int flags)
 {
+       struct bio_vec bvec;
+       struct msghdr msg = { .msg_flags = flags | MSG_SPLICE_PAGES, };
+
        if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
                      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
                      MSG_NO_SHARED_FRAGS))
                return -EOPNOTSUPP;
+       if (flags & MSG_SENDPAGE_NOTLAST)
+               msg.msg_flags |= MSG_MORE;
 
-       return tls_sw_do_sendpage(sk, page, offset, size, flags);
+       bvec_set_page(&bvec, page, size, offset);
+       iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size);
+       return tls_sw_sendmsg_locked(sk, &msg, size);
 }
 
 int tls_sw_sendpage(struct sock *sk, struct page *page,
                    int offset, size_t size, int flags)
 {
-       struct tls_context *tls_ctx = tls_get_ctx(sk);
-       int ret;
+       struct bio_vec bvec;
+       struct msghdr msg = { .msg_flags = flags | MSG_SPLICE_PAGES, };
 
        if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
                      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
                return -EOPNOTSUPP;
+       if (flags & MSG_SENDPAGE_NOTLAST)
+               msg.msg_flags |= MSG_MORE;
 
-       ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
-       if (ret)
-               return ret;
-       lock_sock(sk);
-       ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
-       release_sock(sk);
-       mutex_unlock(&tls_ctx->tx_lock);
-       return ret;
+       bvec_set_page(&bvec, page, size, offset);
+       iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size);
+       return tls_sw_sendmsg(sk, &msg, size);
 }
 
 static int