tls: rx: factor out writing ContentType to cmsg
authorJakub Kicinski <kuba@kernel.org>
Fri, 8 Apr 2022 18:31:28 +0000 (11:31 -0700)
committerDavid S. Miller <davem@davemloft.net>
Sun, 10 Apr 2022 16:32:11 +0000 (17:32 +0100)
cmsg can be filled in during rx_list processing or normal
receive. Consolidate the code.

We don't need to keep the boolean to track if the cmsg was
created. 0 is an invalid content type.

Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/tls/tls_sw.c

index 003f7c1..103a1aa 100644 (file)
@@ -1635,6 +1635,29 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
        return true;
 }
 
+static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
+                                  u8 *control)
+{
+       int err;
+
+       if (!*control) {
+               *control = tlm->control;
+               if (!*control)
+                       return -EBADMSG;
+
+               err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
+                              sizeof(*control), control);
+               if (*control != TLS_RECORD_TYPE_DATA) {
+                       if (err || msg->msg_flags & MSG_CTRUNC)
+                               return -EIO;
+               }
+       } else if (*control != tlm->control) {
+               return 0;
+       }
+
+       return 1;
+}
+
 /* This function traverses the rx_list in tls receive context to copies the
  * decrypted records into the buffer provided by caller zero copy is not
  * true. Further, the records are removed from the rx_list if it is not a peek
@@ -1643,31 +1666,23 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
 static int process_rx_list(struct tls_sw_context_rx *ctx,
                           struct msghdr *msg,
                           u8 *control,
-                          bool *cmsg,
                           size_t skip,
                           size_t len,
                           bool zc,
                           bool is_peek)
 {
        struct sk_buff *skb = skb_peek(&ctx->rx_list);
-       u8 ctrl = *control;
-       u8 msgc = *cmsg;
        struct tls_msg *tlm;
        ssize_t copied = 0;
-
-       /* Set the record type in 'control' if caller didn't pass it */
-       if (!ctrl && skb) {
-               tlm = tls_msg(skb);
-               ctrl = tlm->control;
-       }
+       int err;
 
        while (skip && skb) {
                struct strp_msg *rxm = strp_msg(skb);
                tlm = tls_msg(skb);
 
-               /* Cannot process a record of different type */
-               if (ctrl != tlm->control)
-                       return 0;
+               err = tls_record_content_type(msg, tlm, control);
+               if (err <= 0)
+                       return err;
 
                if (skip < rxm->full_len)
                        break;
@@ -1683,27 +1698,12 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
 
                tlm = tls_msg(skb);
 
-               /* Cannot process a record of different type */
-               if (ctrl != tlm->control)
-                       return 0;
-
-               /* Set record type if not already done. For a non-data record,
-                * do not proceed if record type could not be copied.
-                */
-               if (!msgc) {
-                       int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
-                                           sizeof(ctrl), &ctrl);
-                       msgc = true;
-                       if (ctrl != TLS_RECORD_TYPE_DATA) {
-                               if (cerr || msg->msg_flags & MSG_CTRUNC)
-                                       return -EIO;
-
-                               *cmsg = msgc;
-                       }
-               }
+               err = tls_record_content_type(msg, tlm, control);
+               if (err <= 0)
+                       return err;
 
                if (!zc || (rxm->full_len - skip) > len) {
-                       int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
+                       err = skb_copy_datagram_msg(skb, rxm->offset + skip,
                                                    msg, chunk);
                        if (err < 0)
                                return err;
@@ -1740,7 +1740,6 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
                skb = next_skb;
        }
 
-       *control = ctrl;
        return copied;
 }
 
@@ -1762,7 +1761,6 @@ int tls_sw_recvmsg(struct sock *sk,
        struct tls_msg *tlm;
        struct sk_buff *skb;
        ssize_t copied = 0;
-       bool cmsg = false;
        int target, err = 0;
        long timeo;
        bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
@@ -1779,8 +1777,7 @@ int tls_sw_recvmsg(struct sock *sk,
        bpf_strp_enabled = sk_psock_strp_enabled(psock);
 
        /* Process pending decrypted records. It must be non-zero-copy */
-       err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
-                             is_peek);
+       err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
        if (err < 0) {
                tls_err_abort(sk, err);
                goto end;
@@ -1852,26 +1849,10 @@ int tls_sw_recvmsg(struct sock *sk,
                 * is known just after record is dequeued from stream parser.
                 * For tls1.3, we disable async.
                 */
-
-               if (!control)
-                       control = tlm->control;
-               else if (control != tlm->control)
+               err = tls_record_content_type(msg, tlm, &control);
+               if (err <= 0)
                        goto recv_end;
 
-               if (!cmsg) {
-                       int cerr;
-
-                       cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
-                                       sizeof(control), &control);
-                       cmsg = true;
-                       if (control != TLS_RECORD_TYPE_DATA) {
-                               if (cerr || msg->msg_flags & MSG_CTRUNC) {
-                                       err = -EIO;
-                                       goto recv_end;
-                               }
-                       }
-               }
-
                if (async) {
                        /* TLS 1.2-only, to_decrypt must be text length */
                        chunk = min_t(int, to_decrypt, len);
@@ -1953,10 +1934,10 @@ recv_end:
 
                /* Drain records from the rx_list & copy if required */
                if (is_peek || is_kvec)
-                       err = process_rx_list(ctx, msg, &control, &cmsg, copied,
+                       err = process_rx_list(ctx, msg, &control, copied,
                                              decrypted, false, is_peek);
                else
-                       err = process_rx_list(ctx, msg, &control, &cmsg, 0,
+                       err = process_rx_list(ctx, msg, &control, 0,
                                              decrypted, true, is_peek);
                if (err < 0) {
                        tls_err_abort(sk, err);