net/tls: handle errors from padding_length()
authorJakub Kicinski <jakub.kicinski@netronome.com>
Thu, 9 May 2019 23:14:07 +0000 (16:14 -0700)
committerDavid S. Miller <davem@davemloft.net>
Thu, 9 May 2019 23:37:39 +0000 (16:37 -0700)
At the time padding_length() is called the record header
is still part of the message.  If malicious TLS 1.3 peer
sends an all-zero record padding_length() will stop at
the record header, and return full length of the data
including the tail_size.

Subsequent subtraction of prot->overhead_size from rxm->full_len
will cause rxm->full_len to turn negative.  skb accessors,
however, will always catch resulting out-of-bounds operation,
so in practice this fix comes down to returning the correct
error code.  It also fixes a set but not used warning.

This code was added by commit 130b392c6cd6 ("net: tls: Add tls 1.3 support").

CC: Dave Watson <davejwatson@fb.com>
Signed-off-by: Jakub Kicinski <jakub.kicinski@netronome.com>
Reviewed-by: Dirk van der Merwe <dirk.vandermerwe@netronome.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/tls/tls_sw.c

index c02293f..d93f83f 100644 (file)
@@ -119,23 +119,25 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len)
 }
 
 static int padding_length(struct tls_sw_context_rx *ctx,
-                         struct tls_context *tls_ctx, struct sk_buff *skb)
+                         struct tls_prot_info *prot, struct sk_buff *skb)
 {
        struct strp_msg *rxm = strp_msg(skb);
        int sub = 0;
 
        /* Determine zero-padding length */
-       if (tls_ctx->prot_info.version == TLS_1_3_VERSION) {
+       if (prot->version == TLS_1_3_VERSION) {
                char content_type = 0;
                int err;
                int back = 17;
 
                while (content_type == 0) {
-                       if (back > rxm->full_len)
+                       if (back > rxm->full_len - prot->prepend_size)
                                return -EBADMSG;
                        err = skb_copy_bits(skb,
                                            rxm->offset + rxm->full_len - back,
                                            &content_type, 1);
+                       if (err)
+                               return err;
                        if (content_type)
                                break;
                        sub++;
@@ -170,9 +172,17 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
                tls_err_abort(skb->sk, err);
        } else {
                struct strp_msg *rxm = strp_msg(skb);
-               rxm->full_len -= padding_length(ctx, tls_ctx, skb);
-               rxm->offset += prot->prepend_size;
-               rxm->full_len -= prot->overhead_size;
+               int pad;
+
+               pad = padding_length(ctx, prot, skb);
+               if (pad < 0) {
+                       ctx->async_wait.err = pad;
+                       tls_err_abort(skb->sk, pad);
+               } else {
+                       rxm->full_len -= pad;
+                       rxm->offset += prot->prepend_size;
+                       rxm->full_len -= prot->overhead_size;
+               }
        }
 
        /* After using skb->sk to propagate sk through crypto async callback
@@ -1478,7 +1488,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
        struct tls_prot_info *prot = &tls_ctx->prot_info;
        int version = prot->version;
        struct strp_msg *rxm = strp_msg(skb);
-       int err = 0;
+       int pad, err = 0;
 
        if (!ctx->decrypted) {
 #ifdef CONFIG_TLS_DEVICE
@@ -1501,7 +1511,11 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
                        *zc = false;
                }
 
-               rxm->full_len -= padding_length(ctx, tls_ctx, skb);
+               pad = padding_length(ctx, prot, skb);
+               if (pad < 0)
+                       return pad;
+
+               rxm->full_len -= pad;
                rxm->offset += prot->prepend_size;
                rxm->full_len -= prot->overhead_size;
                tls_advance_record_sn(sk, &tls_ctx->rx, version);