2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7 * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
9 * This software is available to you under a choice of one of two
10 * licenses. You may choose to be licensed under the terms of the GNU
11 * General Public License (GPL) Version 2, available from the file
12 * COPYING in the main directory of this source tree, or the
13 * OpenIB.org BSD license below:
15 * Redistribution and use in source and binary forms, with or
16 * without modification, are permitted provided that the following
19 * - Redistributions of source code must retain the above
20 * copyright notice, this list of conditions and the following
23 * - Redistributions in binary form must reproduce the above
24 * copyright notice, this list of conditions and the following
25 * disclaimer in the documentation and/or other materials
26 * provided with the distribution.
28 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
29 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
30 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
31 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
32 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
33 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
34 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38 #include <linux/bug.h>
39 #include <linux/sched/signal.h>
40 #include <linux/module.h>
41 #include <linux/splice.h>
42 #include <crypto/aead.h>
#include <net/strparser.h>
#include <net/tls.h>

#include "tls.h"
/* Note: field layout reconstructed from the uses below (darg->zc, ->async,
 * ->tail, ->skb and the memset of darg.inargs); grouping is approximate.
 */
struct tls_decrypt_arg {
	struct_group(inargs,
	bool zc;
	bool async;
	);
	struct sk_buff *skb;
	char tail;
};

struct tls_decrypt_ctx {
	u8 iv[MAX_IV_SIZE];
	u8 aad[TLS_MAX_AAD_SIZE];
	u8 tail;
	struct scatterlist sg[];
};
66 noinline void tls_err_abort(struct sock *sk, int err)
68 WARN_ON_ONCE(err >= 0);
69 /* sk->sk_err should contain a positive error code. */
70 WRITE_ONCE(sk->sk_err, -err);
71 /* Paired with smp_rmb() in tcp_poll() */
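/* Count how many scatterlist entries are needed to map @len bytes of the
 * skb starting at @offset, walking the linear data, page frags and the
 * frag list. Recursion into the frag list is capped at a depth of 24.
 */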
76 static int __skb_nsg(struct sk_buff *skb, int offset, int len,
77 unsigned int recursion_level)
79 int start = skb_headlen(skb);
80 int i, chunk = start - offset;
81 struct sk_buff *frag_iter;
84 if (unlikely(recursion_level >= 24))
97 for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
100 WARN_ON(start > offset + len);
102 end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
103 chunk = end - offset;
116 if (unlikely(skb_has_frag_list(skb))) {
117 skb_walk_frags(skb, frag_iter) {
120 WARN_ON(start > offset + len);
122 end = start + frag_iter->len;
123 chunk = end - offset;
127 ret = __skb_nsg(frag_iter, offset - start, chunk,
128 recursion_level + 1);
129 if (unlikely(ret < 0))
144 /* Return the number of scatterlist elements required to completely map the
145 * skb, or -EMSGSIZE if the recursion depth is exceeded.
147 static int skb_nsg(struct sk_buff *skb, int offset, int len)
149 return __skb_nsg(skb, offset, len, 0);
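/* TLS 1.3 hides the real content type as the last non-zero byte before the
 * authentication tag, followed by zero padding. Walk backwards from the tag
 * to find that byte, store it in tlm->control and return the number of
 * trailing zero-padding bytes so the caller can trim them.
 */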
152 static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb,
153 struct tls_decrypt_arg *darg)
155 struct strp_msg *rxm = strp_msg(skb);
156 struct tls_msg *tlm = tls_msg(skb);
159 /* Determine zero-padding length */
160 if (prot->version == TLS_1_3_VERSION) {
161 int offset = rxm->full_len - TLS_TAG_SIZE - 1;
162 char content_type = darg->zc ? darg->tail : 0;
165 while (content_type == 0) {
166 if (offset < prot->prepend_size)
168 err = skb_copy_bits(skb, rxm->offset + offset,
177 tlm->control = content_type;
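/* Completion callback for asynchronous decryption: propagate any crypto
 * error to the socket, release the zero-copy destination pages (skipping
 * the AAD entry) and wake the waiter once the last pending decrypt is done.
 */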
182 static void tls_decrypt_done(struct crypto_async_request *req, int err)
184 struct aead_request *aead_req = (struct aead_request *)req;
185 struct scatterlist *sgout = aead_req->dst;
186 struct scatterlist *sgin = aead_req->src;
187 struct tls_sw_context_rx *ctx;
188 struct tls_context *tls_ctx;
189 struct scatterlist *sg;
193 sk = (struct sock *)req->data;
194 tls_ctx = tls_get_ctx(sk);
195 ctx = tls_sw_ctx_rx(tls_ctx);
197 /* Propagate if there was an err */
200 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
201 ctx->async_wait.err = err;
202 tls_err_abort(sk, err);
/* Free the destination pages if the skb was not decrypted in place */
207 /* Skip the first S/G entry as it points to AAD */
208 for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
211 put_page(sg_page(sg));
217 spin_lock_bh(&ctx->decrypt_compl_lock);
218 if (!atomic_dec_return(&ctx->decrypt_pending))
219 complete(&ctx->async_wait.completion);
220 spin_unlock_bh(&ctx->decrypt_compl_lock);
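/* Fill in and submit the AEAD decrypt request. In async mode the result is
 * delivered through tls_decrypt_done(); otherwise wait for the crypto layer
 * with crypto_wait_req() and return its status.
 */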
223 static int tls_do_decryption(struct sock *sk,
224 struct scatterlist *sgin,
225 struct scatterlist *sgout,
228 struct aead_request *aead_req,
229 struct tls_decrypt_arg *darg)
231 struct tls_context *tls_ctx = tls_get_ctx(sk);
232 struct tls_prot_info *prot = &tls_ctx->prot_info;
233 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
236 aead_request_set_tfm(aead_req, ctx->aead_recv);
237 aead_request_set_ad(aead_req, prot->aad_size);
238 aead_request_set_crypt(aead_req, sgin, sgout,
239 data_len + prot->tag_size,
243 aead_request_set_callback(aead_req,
244 CRYPTO_TFM_REQ_MAY_BACKLOG,
245 tls_decrypt_done, sk);
246 atomic_inc(&ctx->decrypt_pending);
248 aead_request_set_callback(aead_req,
249 CRYPTO_TFM_REQ_MAY_BACKLOG,
250 crypto_req_done, &ctx->async_wait);
253 ret = crypto_aead_decrypt(aead_req);
254 if (ret == -EINPROGRESS) {
258 ret = crypto_wait_req(ret, &ctx->async_wait);
265 static void tls_trim_both_msgs(struct sock *sk, int target_size)
267 struct tls_context *tls_ctx = tls_get_ctx(sk);
268 struct tls_prot_info *prot = &tls_ctx->prot_info;
269 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
270 struct tls_rec *rec = ctx->open_rec;
272 sk_msg_trim(sk, &rec->msg_plaintext, target_size);
274 target_size += prot->overhead_size;
275 sk_msg_trim(sk, &rec->msg_encrypted, target_size);
278 static int tls_alloc_encrypted_msg(struct sock *sk, int len)
280 struct tls_context *tls_ctx = tls_get_ctx(sk);
281 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
282 struct tls_rec *rec = ctx->open_rec;
283 struct sk_msg *msg_en = &rec->msg_encrypted;
285 return sk_msg_alloc(sk, msg_en, len, 0);
288 static int tls_clone_plaintext_msg(struct sock *sk, int required)
290 struct tls_context *tls_ctx = tls_get_ctx(sk);
291 struct tls_prot_info *prot = &tls_ctx->prot_info;
292 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
293 struct tls_rec *rec = ctx->open_rec;
294 struct sk_msg *msg_pl = &rec->msg_plaintext;
295 struct sk_msg *msg_en = &rec->msg_encrypted;
/* We add page references worth len bytes from the encrypted sg
 * at the end of the plaintext sg. It is guaranteed that msg_en
 * has the required room (ensured by the caller).
 */
302 len = required - msg_pl->sg.size;
304 /* Skip initial bytes in msg_en's data to be able to use
305 * same offset of both plain and encrypted data.
307 skip = prot->prepend_size + msg_pl->sg.size;
309 return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
312 static struct tls_rec *tls_get_rec(struct sock *sk)
314 struct tls_context *tls_ctx = tls_get_ctx(sk);
315 struct tls_prot_info *prot = &tls_ctx->prot_info;
316 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
317 struct sk_msg *msg_pl, *msg_en;
321 mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
323 rec = kzalloc(mem_size, sk->sk_allocation);
327 msg_pl = &rec->msg_plaintext;
328 msg_en = &rec->msg_encrypted;
333 sg_init_table(rec->sg_aead_in, 2);
334 sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
335 sg_unmark_end(&rec->sg_aead_in[1]);
337 sg_init_table(rec->sg_aead_out, 2);
338 sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
339 sg_unmark_end(&rec->sg_aead_out[1]);
344 static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
346 sk_msg_free(sk, &rec->msg_encrypted);
347 sk_msg_free(sk, &rec->msg_plaintext);
351 static void tls_free_open_rec(struct sock *sk)
353 struct tls_context *tls_ctx = tls_get_ctx(sk);
354 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
355 struct tls_rec *rec = ctx->open_rec;
358 tls_free_rec(sk, rec);
359 ctx->open_rec = NULL;
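/* Transmit encrypted records: finish any partially sent record first, then
 * push every record on tx_list whose encryption has completed (tx_ready),
 * freeing each record once it is fully sent.
 */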
363 int tls_tx_records(struct sock *sk, int flags)
365 struct tls_context *tls_ctx = tls_get_ctx(sk);
366 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
367 struct tls_rec *rec, *tmp;
368 struct sk_msg *msg_en;
369 int tx_flags, rc = 0;
371 if (tls_is_partially_sent_record(tls_ctx)) {
372 rec = list_first_entry(&ctx->tx_list,
373 struct tls_rec, list);
376 tx_flags = rec->tx_flags;
380 rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
384 /* Full record has been transmitted.
385 * Remove the head of tx_list
387 list_del(&rec->list);
388 sk_msg_free(sk, &rec->msg_plaintext);
392 /* Tx all ready records */
393 list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
394 if (READ_ONCE(rec->tx_ready)) {
396 tx_flags = rec->tx_flags;
400 msg_en = &rec->msg_encrypted;
401 rc = tls_push_sg(sk, tls_ctx,
402 &msg_en->sg.data[msg_en->sg.curr],
407 list_del(&rec->list);
408 sk_msg_free(sk, &rec->msg_plaintext);
416 if (rc < 0 && rc != -EAGAIN)
417 tls_err_abort(sk, -EBADMSG);
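/* Completion callback for asynchronous record encryption: undo the prepend
 * offset adjustment on the first sg entry, propagate errors, mark the record
 * tx_ready and schedule the transmit worker if it is at the head of tx_list.
 */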
422 static void tls_encrypt_done(struct crypto_async_request *req, int err)
424 struct aead_request *aead_req = (struct aead_request *)req;
425 struct sock *sk = req->data;
426 struct tls_context *tls_ctx = tls_get_ctx(sk);
427 struct tls_prot_info *prot = &tls_ctx->prot_info;
428 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
429 struct scatterlist *sge;
430 struct sk_msg *msg_en;
435 rec = container_of(aead_req, struct tls_rec, aead_req);
436 msg_en = &rec->msg_encrypted;
438 sge = sk_msg_elem(msg_en, msg_en->sg.curr);
439 sge->offset -= prot->prepend_size;
440 sge->length += prot->prepend_size;
442 /* Check if error is previously set on socket */
443 if (err || sk->sk_err) {
446 /* If err is already set on socket, return the same code */
448 ctx->async_wait.err = -sk->sk_err;
450 ctx->async_wait.err = err;
451 tls_err_abort(sk, err);
456 struct tls_rec *first_rec;
458 /* Mark the record as ready for transmission */
459 smp_store_mb(rec->tx_ready, true);
461 /* If received record is at head of tx_list, schedule tx */
462 first_rec = list_first_entry(&ctx->tx_list,
463 struct tls_rec, list);
464 if (rec == first_rec)
468 spin_lock_bh(&ctx->encrypt_compl_lock);
469 pending = atomic_dec_return(&ctx->encrypt_pending);
471 if (!pending && ctx->async_notify)
472 complete(&ctx->async_wait.completion);
473 spin_unlock_bh(&ctx->encrypt_compl_lock);
478 /* Schedule the transmission */
479 if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
480 schedule_delayed_work(&ctx->tx_work.work, 1);
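/* Build the per-record nonce from the salt, IV and record sequence number,
 * queue the record on tx_list and submit the AEAD encrypt request. Unless
 * the crypto call fails outright, the open record is unhooked from the
 * context and the TX record sequence number is advanced.
 */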
483 static int tls_do_encryption(struct sock *sk,
484 struct tls_context *tls_ctx,
485 struct tls_sw_context_tx *ctx,
486 struct aead_request *aead_req,
487 size_t data_len, u32 start)
489 struct tls_prot_info *prot = &tls_ctx->prot_info;
490 struct tls_rec *rec = ctx->open_rec;
491 struct sk_msg *msg_en = &rec->msg_encrypted;
492 struct scatterlist *sge = sk_msg_elem(msg_en, start);
493 int rc, iv_offset = 0;
495 /* For CCM based ciphers, first byte of IV is a constant */
496 switch (prot->cipher_type) {
497 case TLS_CIPHER_AES_CCM_128:
498 rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
501 case TLS_CIPHER_SM4_CCM:
502 rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE;
507 memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
508 prot->iv_size + prot->salt_size);
510 tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset,
511 tls_ctx->tx.rec_seq);
513 sge->offset += prot->prepend_size;
514 sge->length -= prot->prepend_size;
516 msg_en->sg.curr = start;
518 aead_request_set_tfm(aead_req, ctx->aead_send);
519 aead_request_set_ad(aead_req, prot->aad_size);
520 aead_request_set_crypt(aead_req, rec->sg_aead_in,
522 data_len, rec->iv_data);
524 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
525 tls_encrypt_done, sk);
527 /* Add the record in tx_list */
528 list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
529 atomic_inc(&ctx->encrypt_pending);
531 rc = crypto_aead_encrypt(aead_req);
532 if (!rc || rc != -EINPROGRESS) {
533 atomic_dec(&ctx->encrypt_pending);
534 sge->offset -= prot->prepend_size;
535 sge->length += prot->prepend_size;
539 WRITE_ONCE(rec->tx_ready, true);
540 } else if (rc != -EINPROGRESS) {
541 list_del(&rec->list);
/* Unhook the record from the context if encryption did not fail */
546 ctx->open_rec = NULL;
547 tls_advance_record_sn(sk, prot, &tls_ctx->tx);
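/* Split the open record at @split_point so the first part (the apply_bytes
 * budget) can be pushed on its own; the remaining scatterlist entries are
 * moved into a newly allocated record returned via @to.
 */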
551 static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
552 struct tls_rec **to, struct sk_msg *msg_opl,
553 struct sk_msg *msg_oen, u32 split_point,
554 u32 tx_overhead_size, u32 *orig_end)
556 u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
557 struct scatterlist *sge, *osge, *nsge;
558 u32 orig_size = msg_opl->sg.size;
559 struct scatterlist tmp = { };
560 struct sk_msg *msg_npl;
564 new = tls_get_rec(sk);
567 ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
568 tx_overhead_size, 0);
570 tls_free_rec(sk, new);
574 *orig_end = msg_opl->sg.end;
575 i = msg_opl->sg.start;
576 sge = sk_msg_elem(msg_opl, i);
577 while (apply && sge->length) {
578 if (sge->length > apply) {
579 u32 len = sge->length - apply;
581 get_page(sg_page(sge));
582 sg_set_page(&tmp, sg_page(sge), len,
583 sge->offset + apply);
588 apply -= sge->length;
589 bytes += sge->length;
592 sk_msg_iter_var_next(i);
593 if (i == msg_opl->sg.end)
595 sge = sk_msg_elem(msg_opl, i);
599 msg_opl->sg.curr = i;
600 msg_opl->sg.copybreak = 0;
601 msg_opl->apply_bytes = 0;
602 msg_opl->sg.size = bytes;
604 msg_npl = &new->msg_plaintext;
605 msg_npl->apply_bytes = apply;
606 msg_npl->sg.size = orig_size - bytes;
608 j = msg_npl->sg.start;
609 nsge = sk_msg_elem(msg_npl, j);
611 memcpy(nsge, &tmp, sizeof(*nsge));
612 sk_msg_iter_var_next(j);
613 nsge = sk_msg_elem(msg_npl, j);
616 osge = sk_msg_elem(msg_opl, i);
617 while (osge->length) {
618 memcpy(nsge, osge, sizeof(*nsge));
620 sk_msg_iter_var_next(i);
621 sk_msg_iter_var_next(j);
624 osge = sk_msg_elem(msg_opl, i);
625 nsge = sk_msg_elem(msg_npl, j);
629 msg_npl->sg.curr = j;
630 msg_npl->sg.copybreak = 0;
636 static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
637 struct tls_rec *from, u32 orig_end)
639 struct sk_msg *msg_npl = &from->msg_plaintext;
640 struct sk_msg *msg_opl = &to->msg_plaintext;
641 struct scatterlist *osge, *nsge;
645 sk_msg_iter_var_prev(i);
646 j = msg_npl->sg.start;
648 osge = sk_msg_elem(msg_opl, i);
649 nsge = sk_msg_elem(msg_npl, j);
651 if (sg_page(osge) == sg_page(nsge) &&
652 osge->offset + osge->length == nsge->offset) {
653 osge->length += nsge->length;
654 put_page(sg_page(nsge));
657 msg_opl->sg.end = orig_end;
658 msg_opl->sg.curr = orig_end;
659 msg_opl->sg.copybreak = 0;
660 msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
661 msg_opl->sg.size += msg_npl->sg.size;
663 sk_msg_free(sk, &to->msg_encrypted);
664 sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
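/* Finalize and push the open record: split it at the apply_bytes boundary
 * if needed, append the TLS 1.3 content type, build the AAD and record
 * header, start encryption and then transmit whatever records are ready.
 */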
669 static int tls_push_record(struct sock *sk, int flags,
670 unsigned char record_type)
672 struct tls_context *tls_ctx = tls_get_ctx(sk);
673 struct tls_prot_info *prot = &tls_ctx->prot_info;
674 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
675 struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
676 u32 i, split_point, orig_end;
677 struct sk_msg *msg_pl, *msg_en;
678 struct aead_request *req;
685 msg_pl = &rec->msg_plaintext;
686 msg_en = &rec->msg_encrypted;
688 split_point = msg_pl->apply_bytes;
689 split = split_point && split_point < msg_pl->sg.size;
690 if (unlikely((!split &&
692 prot->overhead_size > msg_en->sg.size) ||
695 prot->overhead_size > msg_en->sg.size))) {
697 split_point = msg_en->sg.size;
700 rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
701 split_point, prot->overhead_size,
/* This can happen if above tls_split_open_record allocates
 * a single large encryption buffer instead of two smaller
 * ones. In this case adjust pointers and continue without
 * split.
 */
710 if (!msg_pl->sg.size) {
711 tls_merge_open_record(sk, rec, tmp, orig_end);
712 msg_pl = &rec->msg_plaintext;
713 msg_en = &rec->msg_encrypted;
716 sk_msg_trim(sk, msg_en, msg_pl->sg.size +
717 prot->overhead_size);
720 rec->tx_flags = flags;
721 req = &rec->aead_req;
724 sk_msg_iter_var_prev(i);
726 rec->content_type = record_type;
727 if (prot->version == TLS_1_3_VERSION) {
728 /* Add content type to end of message. No padding added */
729 sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
730 sg_mark_end(&rec->sg_content_type);
731 sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
732 &rec->sg_content_type);
734 sg_mark_end(sk_msg_elem(msg_pl, i));
737 if (msg_pl->sg.end < msg_pl->sg.start) {
738 sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
739 MAX_SKB_FRAGS - msg_pl->sg.start + 1,
743 i = msg_pl->sg.start;
744 sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
747 sk_msg_iter_var_prev(i);
748 sg_mark_end(sk_msg_elem(msg_en, i));
750 i = msg_en->sg.start;
751 sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
753 tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
754 tls_ctx->tx.rec_seq, record_type, prot);
756 tls_fill_prepend(tls_ctx,
757 page_address(sg_page(&msg_en->sg.data[i])) +
758 msg_en->sg.data[i].offset,
759 msg_pl->sg.size + prot->tail_size,
762 tls_ctx->pending_open_record_frags = false;
764 rc = tls_do_encryption(sk, tls_ctx, ctx, req,
765 msg_pl->sg.size + prot->tail_size, i);
767 if (rc != -EINPROGRESS) {
768 tls_err_abort(sk, -EBADMSG);
770 tls_ctx->pending_open_record_frags = true;
771 tls_merge_open_record(sk, rec, tmp, orig_end);
774 ctx->async_capable = 1;
777 msg_pl = &tmp->msg_plaintext;
778 msg_en = &tmp->msg_encrypted;
779 sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
780 tls_ctx->pending_open_record_frags = true;
784 return tls_tx_records(sk, flags);
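/* Run the psock/BPF TX policy over the plaintext message and act on the
 * verdict: __SK_PASS pushes the record out over TLS, __SK_REDIRECT hands
 * the data to another socket via tcp_bpf_sendmsg_redir(), __SK_DROP frees
 * it and returns -EACCES.
 */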
787 static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
788 bool full_record, u8 record_type,
789 ssize_t *copied, int flags)
791 struct tls_context *tls_ctx = tls_get_ctx(sk);
792 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
793 struct sk_msg msg_redir = { };
794 struct sk_psock *psock;
795 struct sock *sk_redir;
797 bool enospc, policy, redir_ingress;
801 policy = !(flags & MSG_SENDPAGE_NOPOLICY);
802 psock = sk_psock_get(sk);
803 if (!psock || !policy) {
804 err = tls_push_record(sk, flags, record_type);
805 if (err && err != -EINPROGRESS && sk->sk_err == EBADMSG) {
806 *copied -= sk_msg_free(sk, msg);
807 tls_free_open_rec(sk);
811 sk_psock_put(sk, psock);
815 enospc = sk_msg_full(msg);
816 if (psock->eval == __SK_NONE) {
817 delta = msg->sg.size;
818 psock->eval = sk_psock_msg_verdict(sk, psock, msg);
819 delta -= msg->sg.size;
821 if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
822 !enospc && !full_record) {
828 if (msg->apply_bytes && msg->apply_bytes < send)
829 send = msg->apply_bytes;
831 switch (psock->eval) {
833 err = tls_push_record(sk, flags, record_type);
834 if (err && err != -EINPROGRESS && sk->sk_err == EBADMSG) {
835 *copied -= sk_msg_free(sk, msg);
836 tls_free_open_rec(sk);
842 redir_ingress = psock->redir_ingress;
843 sk_redir = psock->sk_redir;
844 memcpy(&msg_redir, msg, sizeof(*msg));
845 if (msg->apply_bytes < send)
846 msg->apply_bytes = 0;
848 msg->apply_bytes -= send;
849 sk_msg_return_zero(sk, msg, send);
850 msg->sg.size -= send;
852 err = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress,
853 &msg_redir, send, flags);
856 *copied -= sk_msg_free_nocharge(sk, &msg_redir);
859 if (msg->sg.size == 0)
860 tls_free_open_rec(sk);
864 sk_msg_free_partial(sk, msg, send);
865 if (msg->apply_bytes < send)
866 msg->apply_bytes = 0;
868 msg->apply_bytes -= send;
869 if (msg->sg.size == 0)
870 tls_free_open_rec(sk);
871 *copied -= (send + delta);
876 bool reset_eval = !ctx->open_rec;
880 msg = &rec->msg_plaintext;
881 if (!msg->apply_bytes)
885 psock->eval = __SK_NONE;
886 if (psock->sk_redir) {
887 sock_put(psock->sk_redir);
888 psock->sk_redir = NULL;
895 sk_psock_put(sk, psock);
899 static int tls_sw_push_pending_record(struct sock *sk, int flags)
901 struct tls_context *tls_ctx = tls_get_ctx(sk);
902 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
903 struct tls_rec *rec = ctx->open_rec;
904 struct sk_msg *msg_pl;
910 msg_pl = &rec->msg_plaintext;
911 copied = msg_pl->sg.size;
915 return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
919 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
921 long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
922 struct tls_context *tls_ctx = tls_get_ctx(sk);
923 struct tls_prot_info *prot = &tls_ctx->prot_info;
924 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
925 bool async_capable = ctx->async_capable;
926 unsigned char record_type = TLS_RECORD_TYPE_DATA;
927 bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
928 bool eor = !(msg->msg_flags & MSG_MORE);
931 struct sk_msg *msg_pl, *msg_en;
942 if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
946 ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
951 if (unlikely(msg->msg_controllen)) {
952 ret = tls_process_cmsg(sk, msg, &record_type);
954 if (ret == -EINPROGRESS)
956 else if (ret != -EAGAIN)
961 while (msg_data_left(msg)) {
970 rec = ctx->open_rec = tls_get_rec(sk);
976 msg_pl = &rec->msg_plaintext;
977 msg_en = &rec->msg_encrypted;
979 orig_size = msg_pl->sg.size;
981 try_to_copy = msg_data_left(msg);
982 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
983 if (try_to_copy >= record_room) {
984 try_to_copy = record_room;
988 required_size = msg_pl->sg.size + try_to_copy +
991 if (!sk_stream_memory_free(sk))
992 goto wait_for_sndbuf;
995 ret = tls_alloc_encrypted_msg(sk, required_size);
998 goto wait_for_memory;
1000 /* Adjust try_to_copy according to the amount that was
1001 * actually allocated. The difference is due
1002 * to max sg elements limit
1004 try_to_copy -= required_size - msg_en->sg.size;
1008 if (!is_kvec && (full_record || eor) && !async_capable) {
1009 u32 first = msg_pl->sg.end;
1011 ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
1012 msg_pl, try_to_copy);
1014 goto fallback_to_reg_send;
1017 copied += try_to_copy;
1019 sk_msg_sg_copy_set(msg_pl, first);
1020 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1021 record_type, &copied,
1024 if (ret == -EINPROGRESS)
1026 else if (ret == -ENOMEM)
1027 goto wait_for_memory;
1028 else if (ctx->open_rec && ret == -ENOSPC)
1030 else if (ret != -EAGAIN)
1035 copied -= try_to_copy;
1036 sk_msg_sg_copy_clear(msg_pl, first);
1037 iov_iter_revert(&msg->msg_iter,
1038 msg_pl->sg.size - orig_size);
1039 fallback_to_reg_send:
1040 sk_msg_trim(sk, msg_pl, orig_size);
1043 required_size = msg_pl->sg.size + try_to_copy;
1045 ret = tls_clone_plaintext_msg(sk, required_size);
1050 /* Adjust try_to_copy according to the amount that was
1051 * actually allocated. The difference is due
1052 * to max sg elements limit
1054 try_to_copy -= required_size - msg_pl->sg.size;
1056 sk_msg_trim(sk, msg_en,
1057 msg_pl->sg.size + prot->overhead_size);
1061 ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
1062 msg_pl, try_to_copy);
1067 /* Open records defined only if successfully copied, otherwise
1068 * we would trim the sg but not reset the open record frags.
1070 tls_ctx->pending_open_record_frags = true;
1071 copied += try_to_copy;
1072 if (full_record || eor) {
1073 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1074 record_type, &copied,
1077 if (ret == -EINPROGRESS)
1079 else if (ret == -ENOMEM)
1080 goto wait_for_memory;
1081 else if (ret != -EAGAIN) {
1092 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1094 ret = sk_stream_wait_memory(sk, &timeo);
1098 tls_trim_both_msgs(sk, orig_size);
1102 if (ctx->open_rec && msg_en->sg.size < required_size)
1103 goto alloc_encrypted;
1108 } else if (num_zc) {
1109 /* Wait for pending encryptions to get completed */
1110 spin_lock_bh(&ctx->encrypt_compl_lock);
1111 ctx->async_notify = true;
1113 pending = atomic_read(&ctx->encrypt_pending);
1114 spin_unlock_bh(&ctx->encrypt_compl_lock);
1116 crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1118 reinit_completion(&ctx->async_wait.completion);
1120 /* There can be no concurrent accesses, since we have no
1121 * pending encrypt operations
1123 WRITE_ONCE(ctx->async_notify, false);
1125 if (ctx->async_wait.err) {
1126 ret = ctx->async_wait.err;
1131 /* Transmit if any encryptions have completed */
1132 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1133 cancel_delayed_work(&ctx->tx_work.work);
1134 tls_tx_records(sk, msg->msg_flags);
1138 ret = sk_stream_error(sk, msg->msg_flags, ret);
1141 mutex_unlock(&tls_ctx->tx_lock);
1142 return copied > 0 ? copied : ret;
1145 static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
1146 int offset, size_t size, int flags)
1148 long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
1149 struct tls_context *tls_ctx = tls_get_ctx(sk);
1150 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1151 struct tls_prot_info *prot = &tls_ctx->prot_info;
1152 unsigned char record_type = TLS_RECORD_TYPE_DATA;
1153 struct sk_msg *msg_pl;
1154 struct tls_rec *rec;
1162 eor = !(flags & MSG_SENDPAGE_NOTLAST);
1163 sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
1165 /* Call the sk_stream functions to manage the sndbuf mem. */
1167 size_t copy, required_size;
1175 rec = ctx->open_rec;
1177 rec = ctx->open_rec = tls_get_rec(sk);
1183 msg_pl = &rec->msg_plaintext;
1185 full_record = false;
1186 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
1188 if (copy >= record_room) {
1193 required_size = msg_pl->sg.size + copy + prot->overhead_size;
1195 if (!sk_stream_memory_free(sk))
1196 goto wait_for_sndbuf;
1198 ret = tls_alloc_encrypted_msg(sk, required_size);
1201 goto wait_for_memory;
1203 /* Adjust copy according to the amount that was
1204 * actually allocated. The difference is due
1205 * to max sg elements limit
1207 copy -= required_size - msg_pl->sg.size;
1211 sk_msg_page_add(msg_pl, page, copy, offset);
1212 sk_mem_charge(sk, copy);
1218 tls_ctx->pending_open_record_frags = true;
1219 if (full_record || eor || sk_msg_full(msg_pl)) {
1220 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1221 record_type, &copied, flags);
1223 if (ret == -EINPROGRESS)
1225 else if (ret == -ENOMEM)
1226 goto wait_for_memory;
1227 else if (ret != -EAGAIN) {
1236 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1238 ret = sk_stream_wait_memory(sk, &timeo);
1241 tls_trim_both_msgs(sk, msg_pl->sg.size);
1250 /* Transmit if any encryptions have completed */
1251 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1252 cancel_delayed_work(&ctx->tx_work.work);
1253 tls_tx_records(sk, flags);
1257 ret = sk_stream_error(sk, flags, ret);
1258 return copied > 0 ? copied : ret;
1261 int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
1262 int offset, size_t size, int flags)
1264 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1265 MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
1266 MSG_NO_SHARED_FRAGS))
1269 return tls_sw_do_sendpage(sk, page, offset, size, flags);
1272 int tls_sw_sendpage(struct sock *sk, struct page *page,
1273 int offset, size_t size, int flags)
1275 struct tls_context *tls_ctx = tls_get_ctx(sk);
1278 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1279 MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
1282 ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
1286 ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
1288 mutex_unlock(&tls_ctx->tx_lock);
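/* Wait until the TLS strparser has a full record ready (or the psock
 * ingress queue has data), honouring the receive timeout, pending signals,
 * socket errors and shutdown, then load the record for the caller.
 */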
1293 tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
1296 struct tls_context *tls_ctx = tls_get_ctx(sk);
1297 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1298 DEFINE_WAIT_FUNC(wait, woken_wake_function);
1301 timeo = sock_rcvtimeo(sk, nonblock);
1303 while (!tls_strp_msg_ready(ctx)) {
1304 if (!sk_psock_queue_empty(psock))
1308 return sock_error(sk);
1310 if (!skb_queue_empty(&sk->sk_receive_queue)) {
1311 tls_strp_check_rcv(&ctx->strp);
1312 if (tls_strp_msg_ready(ctx))
1316 if (sk->sk_shutdown & RCV_SHUTDOWN)
1319 if (sock_flag(sk, SOCK_DONE))
1326 add_wait_queue(sk_sleep(sk), &wait);
1327 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1328 sk_wait_event(sk, &timeo,
1329 tls_strp_msg_ready(ctx) ||
1330 !sk_psock_queue_empty(psock),
1332 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1333 remove_wait_queue(sk_sleep(sk), &wait);
1335 /* Handle signals */
1336 if (signal_pending(current))
1337 return sock_intr_errno(timeo);
1340 tls_strp_msg_load(&ctx->strp, released);
1345 static int tls_setup_from_iter(struct iov_iter *from,
1346 int length, int *pages_used,
1347 struct scatterlist *to,
1350 int rc = 0, i = 0, num_elem = *pages_used, maxpages;
1351 struct page *pages[MAX_SKB_FRAGS];
1352 unsigned int size = 0;
1353 ssize_t copied, use;
1356 while (length > 0) {
1358 maxpages = to_max_pages - num_elem;
1359 if (maxpages == 0) {
1363 copied = iov_iter_get_pages2(from, pages,
1374 use = min_t(int, copied, PAGE_SIZE - offset);
1376 sg_set_page(&to[num_elem],
1377 pages[i], use, offset);
1378 sg_unmark_end(&to[num_elem]);
1379 /* We do not uncharge memory from this API */
1388 /* Mark the end in the last sg entry if newly added */
1389 if (num_elem > *pages_used)
1390 sg_mark_end(&to[num_elem - 1]);
1393 iov_iter_revert(from, size);
1394 *pages_used = num_elem;
1399 static struct sk_buff *
1400 tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb,
1401 unsigned int full_len)
1403 struct strp_msg *clr_rxm;
1404 struct sk_buff *clr_skb;
1407 clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER,
1408 &err, sk->sk_allocation);
1412 skb_copy_header(clr_skb, skb);
1413 clr_skb->len = full_len;
1414 clr_skb->data_len = full_len;
1416 clr_rxm = strp_msg(clr_skb);
1417 clr_rxm->offset = 0;
/* Decrypt handlers
 *
 * tls_decrypt_sw() and tls_decrypt_device() are decrypt handlers.
 * They must transform the darg in/out arguments as follows:
 *       |          Input            |         Output
 * -------------------------------------------------------------------
 *    zc | Zero-copy decrypt allowed | Zero-copy performed
 * async | Async decrypt allowed     | Async crypto used / in progress
 *   skb |            *              | Output skb
 *
 * If ZC decryption was performed darg.skb will point to the input skb.
 */
/* This function decrypts the input skb into either out_iov or out_sg, or
 * into the skb's own buffers. The input parameter 'darg->zc' indicates
 * whether zero-copy mode should be tried. With zero-copy mode, either
 * out_iov or out_sg must be non-NULL. If both out_iov and out_sg are
 * NULL, the decryption happens inside the skb buffers themselves, i.e.
 * zero-copy gets disabled and 'darg->zc' is updated accordingly.
 */
1442 static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
1443 struct scatterlist *out_sg,
1444 struct tls_decrypt_arg *darg)
1446 struct tls_context *tls_ctx = tls_get_ctx(sk);
1447 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1448 struct tls_prot_info *prot = &tls_ctx->prot_info;
1449 int n_sgin, n_sgout, aead_size, err, pages = 0;
1450 struct sk_buff *skb = tls_strp_msg(ctx);
1451 const struct strp_msg *rxm = strp_msg(skb);
1452 const struct tls_msg *tlm = tls_msg(skb);
1453 struct aead_request *aead_req;
1454 struct scatterlist *sgin = NULL;
1455 struct scatterlist *sgout = NULL;
1456 const int data_len = rxm->full_len - prot->overhead_size;
1457 int tail_pages = !!prot->tail_size;
1458 struct tls_decrypt_ctx *dctx;
1459 struct sk_buff *clear_skb;
1463 n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
1464 rxm->full_len - prot->prepend_size);
1466 return n_sgin ?: -EBADMSG;
1468 if (darg->zc && (out_iov || out_sg)) {
1472 n_sgout = 1 + tail_pages +
1473 iov_iter_npages_cap(out_iov, INT_MAX, data_len);
1475 n_sgout = sg_nents(out_sg);
1479 clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len);
1483 n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags;
1486 /* Increment to accommodate AAD */
1487 n_sgin = n_sgin + 1;
1489 /* Allocate a single block of memory which contains
1490 * aead_req || tls_decrypt_ctx.
1491 * Both structs are variable length.
1493 aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
1494 mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
1501 /* Segment the allocated memory */
1502 aead_req = (struct aead_request *)mem;
1503 dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
1504 sgin = &dctx->sg[0];
1505 sgout = &dctx->sg[n_sgin];
1507 /* For CCM based ciphers, first byte of nonce+iv is a constant */
1508 switch (prot->cipher_type) {
1509 case TLS_CIPHER_AES_CCM_128:
1510 dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE;
1513 case TLS_CIPHER_SM4_CCM:
1514 dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
1520 if (prot->version == TLS_1_3_VERSION ||
1521 prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
1522 memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv,
1523 prot->iv_size + prot->salt_size);
1525 err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
1526 &dctx->iv[iv_offset] + prot->salt_size,
1530 memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size);
1532 tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq);
1535 tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size +
1537 tls_ctx->rx.rec_seq, tlm->control, prot);
1540 sg_init_table(sgin, n_sgin);
1541 sg_set_buf(&sgin[0], dctx->aad, prot->aad_size);
1542 err = skb_to_sgvec(skb, &sgin[1],
1543 rxm->offset + prot->prepend_size,
1544 rxm->full_len - prot->prepend_size);
1549 sg_init_table(sgout, n_sgout);
1550 sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
1552 err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size,
1553 data_len + prot->tail_size);
1556 } else if (out_iov) {
1557 sg_init_table(sgout, n_sgout);
1558 sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
1560 err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1],
1561 (n_sgout - 1 - tail_pages));
1563 goto exit_free_pages;
1565 if (prot->tail_size) {
1566 sg_unmark_end(&sgout[pages]);
1567 sg_set_buf(&sgout[pages + 1], &dctx->tail,
1569 sg_mark_end(&sgout[pages + 1]);
1571 } else if (out_sg) {
1572 memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
1575 /* Prepare and submit AEAD request */
1576 err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
1577 data_len + prot->tail_size, aead_req, darg);
1579 goto exit_free_pages;
1581 darg->skb = clear_skb ?: tls_strp_msg(ctx);
1584 if (unlikely(darg->async)) {
1585 err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold);
1587 __skb_queue_tail(&ctx->async_hold, darg->skb);
1591 if (prot->tail_size)
1592 darg->tail = dctx->tail;
1595 /* Release the pages in case iov was mapped to pages */
1596 for (; pages > 0; pages--)
1597 put_page(sg_page(&sgout[pages]));
1601 consume_skb(clear_skb);
1606 tls_decrypt_sw(struct sock *sk, struct tls_context *tls_ctx,
1607 struct msghdr *msg, struct tls_decrypt_arg *darg)
1609 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1610 struct tls_prot_info *prot = &tls_ctx->prot_info;
1611 struct strp_msg *rxm;
1614 err = tls_decrypt_sg(sk, &msg->msg_iter, NULL, darg);
1616 if (err == -EBADMSG)
1617 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
1620 /* keep going even for ->async, the code below is TLS 1.3 */
1622 /* If opportunistic TLS 1.3 ZC failed retry without ZC */
1623 if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
1624 darg->tail != TLS_RECORD_TYPE_DATA)) {
1627 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL);
1628 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
1629 return tls_decrypt_sw(sk, tls_ctx, msg, darg);
1632 pad = tls_padding_length(prot, darg->skb, darg);
1634 if (darg->skb != tls_strp_msg(ctx))
1635 consume_skb(darg->skb);
1639 rxm = strp_msg(darg->skb);
1640 rxm->full_len -= pad;
1646 tls_decrypt_device(struct sock *sk, struct msghdr *msg,
1647 struct tls_context *tls_ctx, struct tls_decrypt_arg *darg)
1649 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1650 struct tls_prot_info *prot = &tls_ctx->prot_info;
1651 struct strp_msg *rxm;
1654 if (tls_ctx->rx_conf != TLS_HW)
1657 err = tls_device_decrypted(sk, tls_ctx);
1661 pad = tls_padding_length(prot, tls_strp_msg(ctx), darg);
1665 darg->async = false;
1666 darg->skb = tls_strp_msg(ctx);
1667 /* ->zc downgrade check, in case TLS 1.3 gets here */
1668 darg->zc &= !(prot->version == TLS_1_3_VERSION &&
1669 tls_msg(darg->skb)->control != TLS_RECORD_TYPE_DATA);
1671 rxm = strp_msg(darg->skb);
1672 rxm->full_len -= pad;
1675 /* Non-ZC case needs a real skb */
1676 darg->skb = tls_strp_msg_detach(ctx);
1680 unsigned int off, len;
1682 /* In ZC case nobody cares about the output skb.
1683 * Just copy the data here. Note the skb is not fully trimmed.
1685 off = rxm->offset + prot->prepend_size;
1686 len = rxm->full_len - prot->overhead_size;
1688 err = skb_copy_datagram_msg(darg->skb, off, msg, len);
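/* Decrypt a single record: try the device offload path first and fall back
 * to software decryption, then strip the TLS header and overhead and
 * advance the RX record sequence number.
 */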
1695 static int tls_rx_one_record(struct sock *sk, struct msghdr *msg,
1696 struct tls_decrypt_arg *darg)
1698 struct tls_context *tls_ctx = tls_get_ctx(sk);
1699 struct tls_prot_info *prot = &tls_ctx->prot_info;
1700 struct strp_msg *rxm;
1703 err = tls_decrypt_device(sk, msg, tls_ctx, darg);
1705 err = tls_decrypt_sw(sk, tls_ctx, msg, darg);
1709 rxm = strp_msg(darg->skb);
1710 rxm->offset += prot->prepend_size;
1711 rxm->full_len -= prot->overhead_size;
1712 tls_advance_record_sn(sk, prot, &tls_ctx->rx);
1717 int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
1719 struct tls_decrypt_arg darg = { .zc = true, };
1721 return tls_decrypt_sg(sk, NULL, sgout, &darg);
1724 static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
1730 *control = tlm->control;
1734 err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1735 sizeof(*control), control);
1736 if (*control != TLS_RECORD_TYPE_DATA) {
1737 if (err || msg->msg_flags & MSG_CTRUNC)
1740 } else if (*control != tlm->control) {
1747 static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
1749 tls_strp_msg_done(&ctx->strp);
/* This function traverses the rx_list in the TLS receive context and copies
 * the decrypted records into the buffer provided by the caller when zero
 * copy is not true. Further, records are removed from the rx_list if this is
 * not a peek case and the record has been consumed completely.
 */
1757 static int process_rx_list(struct tls_sw_context_rx *ctx,
1764 struct sk_buff *skb = skb_peek(&ctx->rx_list);
1765 struct tls_msg *tlm;
1769 while (skip && skb) {
1770 struct strp_msg *rxm = strp_msg(skb);
1773 err = tls_record_content_type(msg, tlm, control);
1777 if (skip < rxm->full_len)
1780 skip = skip - rxm->full_len;
1781 skb = skb_peek_next(skb, &ctx->rx_list);
1784 while (len && skb) {
1785 struct sk_buff *next_skb;
1786 struct strp_msg *rxm = strp_msg(skb);
1787 int chunk = min_t(unsigned int, rxm->full_len - skip, len);
1791 err = tls_record_content_type(msg, tlm, control);
1795 err = skb_copy_datagram_msg(skb, rxm->offset + skip,
1801 copied = copied + chunk;
/* Consume the data from the record if this is not a peek case */
1805 rxm->offset = rxm->offset + chunk;
1806 rxm->full_len = rxm->full_len - chunk;
1808 /* Return if there is unconsumed data in the record */
1809 if (rxm->full_len - skip)
1813 /* The remaining skip-bytes must lie in 1st record in rx_list.
1814 * So from the 2nd record, 'skip' should be 0.
1819 msg->msg_flags |= MSG_EOR;
1821 next_skb = skb_peek_next(skb, &ctx->rx_list);
1824 __skb_unlink(skb, &ctx->rx_list);
1833 return copied ? : err;
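/* recvmsg can run for a long time; roughly every 128K of copied data, or
 * once no more than one full record's worth of data remains queued in TCP,
 * release the socket backlog so packet processing is not starved.
 */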
1837 tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
1838 size_t len_left, size_t decrypted, ssize_t done,
1843 if (len_left <= decrypted)
1846 max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE;
1847 if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec)
1851 return sk_flush_backlog(sk);
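/* Only one reader may own the RX path at a time. Take the socket lock and,
 * if another reader is present, wait on ctx->wq until it leaves or the
 * timeout expires / a signal arrives.
 */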
1854 static int tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
1862 timeo = sock_rcvtimeo(sk, nonblock);
1864 while (unlikely(ctx->reader_present)) {
1865 DEFINE_WAIT_FUNC(wait, woken_wake_function);
1867 ctx->reader_contended = 1;
1869 add_wait_queue(&ctx->wq, &wait);
1870 sk_wait_event(sk, &timeo,
1871 !READ_ONCE(ctx->reader_present), &wait);
1872 remove_wait_queue(&ctx->wq, &wait);
1878 if (signal_pending(current)) {
1879 err = sock_intr_errno(timeo);
1884 WRITE_ONCE(ctx->reader_present, 1);
1893 static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx)
1895 if (unlikely(ctx->reader_contended)) {
1896 if (wq_has_sleeper(&ctx->wq))
1899 ctx->reader_contended = 0;
1901 WARN_ON_ONCE(!ctx->reader_present);
1904 WRITE_ONCE(ctx->reader_present, 0);
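/* recvmsg() for SW TLS: first drain records already decrypted onto rx_list,
 * then pull records from the strparser and decrypt them either straight
 * into the user iovec (zero-copy) or into freshly allocated skbs, possibly
 * asynchronously, until the request or the rcvlowat target is satisfied.
 */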
1908 int tls_sw_recvmsg(struct sock *sk,
1914 struct tls_context *tls_ctx = tls_get_ctx(sk);
1915 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1916 struct tls_prot_info *prot = &tls_ctx->prot_info;
1917 ssize_t decrypted = 0, async_copy_bytes = 0;
1918 struct sk_psock *psock;
1919 unsigned char control = 0;
1920 size_t flushed_at = 0;
1921 struct strp_msg *rxm;
1922 struct tls_msg *tlm;
1926 bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
1927 bool is_peek = flags & MSG_PEEK;
1928 bool released = true;
1929 bool bpf_strp_enabled;
1932 if (unlikely(flags & MSG_ERRQUEUE))
1933 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
1935 psock = sk_psock_get(sk);
1936 err = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT);
1939 bpf_strp_enabled = sk_psock_strp_enabled(psock);
1941 /* If crypto failed the connection is broken */
1942 err = ctx->async_wait.err;
1946 /* Process pending decrypted records. It must be non-zero-copy */
1947 err = process_rx_list(ctx, msg, &control, 0, len, is_peek);
1955 target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1958 zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
1961 while (len && (decrypted + copied < target || tls_strp_msg_ready(ctx))) {
1962 struct tls_decrypt_arg darg;
1963 int to_decrypt, chunk;
1965 err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT,
1969 chunk = sk_msg_recvmsg(sk, psock, msg, len,
1980 memset(&darg.inargs, 0, sizeof(darg.inargs));
1982 rxm = strp_msg(tls_strp_msg(ctx));
1983 tlm = tls_msg(tls_strp_msg(ctx));
1985 to_decrypt = rxm->full_len - prot->overhead_size;
1987 if (zc_capable && to_decrypt <= len &&
1988 tlm->control == TLS_RECORD_TYPE_DATA)
1991 /* Do not use async mode if record is non-data */
1992 if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
1993 darg.async = ctx->async_capable;
1997 err = tls_rx_one_record(sk, msg, &darg);
1999 tls_err_abort(sk, -EBADMSG);
2003 async |= darg.async;
/* If the type of records being processed is not yet known,
 * set it to the record type just dequeued. If it is already
 * known but does not match the record type just dequeued,
 * go to end.
 * We always get the record type here since for TLS 1.2 the
 * record type is known as soon as the record is dequeued from
 * the stream parser. For TLS 1.3, we disable async.
 */
2012 err = tls_record_content_type(msg, tls_msg(darg.skb), &control);
2014 DEBUG_NET_WARN_ON_ONCE(darg.zc);
2015 tls_rx_rec_done(ctx);
2017 __skb_queue_tail(&ctx->rx_list, darg.skb);
2021 /* periodically flush backlog, and feed strparser */
2022 released = tls_read_flush_backlog(sk, prot, len, to_decrypt,
2026 /* TLS 1.3 may have updated the length by more than overhead */
2027 rxm = strp_msg(darg.skb);
2028 chunk = rxm->full_len;
2029 tls_rx_rec_done(ctx);
2032 bool partially_consumed = chunk > len;
2033 struct sk_buff *skb = darg.skb;
2035 DEBUG_NET_WARN_ON_ONCE(darg.skb == ctx->strp.anchor);
2038 /* TLS 1.2-only, to_decrypt must be text len */
2039 chunk = min_t(int, to_decrypt, len);
2040 async_copy_bytes += chunk;
2044 __skb_queue_tail(&ctx->rx_list, skb);
2048 if (bpf_strp_enabled) {
2050 err = sk_psock_tls_strp_read(psock, skb);
2051 if (err != __SK_PASS) {
2052 rxm->offset = rxm->offset + rxm->full_len;
2054 if (err == __SK_DROP)
2060 if (partially_consumed)
2063 err = skb_copy_datagram_msg(skb, rxm->offset,
2066 goto put_on_rx_list_err;
2069 goto put_on_rx_list;
2071 if (partially_consumed) {
2072 rxm->offset += chunk;
2073 rxm->full_len -= chunk;
2074 goto put_on_rx_list;
2083 /* Return full control message to userspace before trying
2084 * to parse another message type
2086 msg->msg_flags |= MSG_EOR;
2087 if (control != TLS_RECORD_TYPE_DATA)
2095 /* Wait for all previously submitted records to be decrypted */
2096 spin_lock_bh(&ctx->decrypt_compl_lock);
2097 reinit_completion(&ctx->async_wait.completion);
2098 pending = atomic_read(&ctx->decrypt_pending);
2099 spin_unlock_bh(&ctx->decrypt_compl_lock);
2102 ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2103 __skb_queue_purge(&ctx->async_hold);
2106 if (err >= 0 || err == -EINPROGRESS)
2112 /* Drain records from the rx_list & copy if required */
2113 if (is_peek || is_kvec)
2114 err = process_rx_list(ctx, msg, &control, copied,
2115 decrypted, is_peek);
2117 err = process_rx_list(ctx, msg, &control, 0,
2118 async_copy_bytes, is_peek);
2119 decrypted += max(err, 0);
2122 copied += decrypted;
2125 tls_rx_reader_unlock(sk, ctx);
2127 sk_psock_put(sk, psock);
2128 return copied ? : err;
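/* splice_read() for SW TLS: decrypt one data record at a time into its own
 * skb and splice the payload into the pipe; control records cannot be
 * spliced and fail with -EINVAL.
 */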
2131 ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
2132 struct pipe_inode_info *pipe,
2133 size_t len, unsigned int flags)
2135 struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
2136 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2137 struct strp_msg *rxm = NULL;
2138 struct sock *sk = sock->sk;
2139 struct tls_msg *tlm;
2140 struct sk_buff *skb;
2145 err = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK);
2149 if (!skb_queue_empty(&ctx->rx_list)) {
2150 skb = __skb_dequeue(&ctx->rx_list);
2152 struct tls_decrypt_arg darg;
2154 err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
2157 goto splice_read_end;
2159 memset(&darg.inargs, 0, sizeof(darg.inargs));
2161 err = tls_rx_one_record(sk, NULL, &darg);
2163 tls_err_abort(sk, -EBADMSG);
2164 goto splice_read_end;
2167 tls_rx_rec_done(ctx);
2171 rxm = strp_msg(skb);
2174 /* splice does not support reading control messages */
2175 if (tlm->control != TLS_RECORD_TYPE_DATA) {
2177 goto splice_requeue;
2180 chunk = min_t(unsigned int, rxm->full_len, len);
2181 copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
2183 goto splice_requeue;
2185 if (chunk < rxm->full_len) {
2187 rxm->full_len -= len;
2188 goto splice_requeue;
2194 tls_rx_reader_unlock(sk, ctx);
2195 return copied ? : err;
2198 __skb_queue_head(&ctx->rx_list, skb);
2199 goto splice_read_end;
2202 bool tls_sw_sock_is_readable(struct sock *sk)
2204 struct tls_context *tls_ctx = tls_get_ctx(sk);
2205 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2206 bool ingress_empty = true;
2207 struct sk_psock *psock;
2210 psock = sk_psock(sk);
2212 ingress_empty = list_empty(&psock->ingress_msg);
2215 return !ingress_empty || tls_strp_msg_ready(ctx) ||
2216 !skb_queue_empty(&ctx->rx_list);
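/* strparser callback: parse and validate the 5-byte TLS record header and
 * return the full record length, 0 to wait for more data, or a negative
 * error (which also aborts the connection).
 */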
2219 int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb)
2221 struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
2222 struct tls_prot_info *prot = &tls_ctx->prot_info;
2223 char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
2224 size_t cipher_overhead;
2225 size_t data_len = 0;
2228 /* Verify that we have a full TLS header, or wait for more data */
2229 if (strp->stm.offset + prot->prepend_size > skb->len)
2232 /* Sanity-check size of on-stack buffer. */
2233 if (WARN_ON(prot->prepend_size > sizeof(header))) {
2238 /* Linearize header to local buffer */
2239 ret = skb_copy_bits(skb, strp->stm.offset, header, prot->prepend_size);
2243 strp->mark = header[0];
2245 data_len = ((header[4] & 0xFF) | (header[3] << 8));
2247 cipher_overhead = prot->tag_size;
2248 if (prot->version != TLS_1_3_VERSION &&
2249 prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
2250 cipher_overhead += prot->iv_size;
2252 if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
2257 if (data_len < cipher_overhead) {
2262 /* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */
2263 if (header[1] != TLS_1_2_VERSION_MINOR ||
2264 header[2] != TLS_1_2_VERSION_MAJOR) {
2269 tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
2270 TCP_SKB_CB(skb)->seq + strp->stm.offset);
2271 return data_len + TLS_HEADER_SIZE;
2274 tls_err_abort(strp->sk, ret);
2279 void tls_rx_msg_ready(struct tls_strparser *strp)
2281 struct tls_sw_context_rx *ctx;
2283 ctx = container_of(strp, struct tls_sw_context_rx, strp);
2284 ctx->saved_data_ready(strp->sk);
2287 static void tls_data_ready(struct sock *sk)
2289 struct tls_context *tls_ctx = tls_get_ctx(sk);
2290 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2291 struct sk_psock *psock;
2294 alloc_save = sk->sk_allocation;
2295 sk->sk_allocation = GFP_ATOMIC;
2296 tls_strp_data_ready(&ctx->strp);
2297 sk->sk_allocation = alloc_save;
2299 psock = sk_psock_get(sk);
2301 if (!list_empty(&psock->ingress_msg))
2302 ctx->saved_data_ready(sk);
2303 sk_psock_put(sk, psock);
2307 void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
2309 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2311 set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
2312 set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
2313 cancel_delayed_work_sync(&ctx->tx_work.work);
2316 void tls_sw_release_resources_tx(struct sock *sk)
2318 struct tls_context *tls_ctx = tls_get_ctx(sk);
2319 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2320 struct tls_rec *rec, *tmp;
2323 /* Wait for any pending async encryptions to complete */
2324 spin_lock_bh(&ctx->encrypt_compl_lock);
2325 ctx->async_notify = true;
2326 pending = atomic_read(&ctx->encrypt_pending);
2327 spin_unlock_bh(&ctx->encrypt_compl_lock);
2330 crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2332 tls_tx_records(sk, -1);
/* Free up unsent records in tx_list. First, free the partially sent
 * record, if any, at the head of tx_list.
 */
2337 if (tls_ctx->partially_sent_record) {
2338 tls_free_partial_record(sk, tls_ctx);
2339 rec = list_first_entry(&ctx->tx_list,
2340 struct tls_rec, list);
2341 list_del(&rec->list);
2342 sk_msg_free(sk, &rec->msg_plaintext);
2346 list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
2347 list_del(&rec->list);
2348 sk_msg_free(sk, &rec->msg_encrypted);
2349 sk_msg_free(sk, &rec->msg_plaintext);
2353 crypto_free_aead(ctx->aead_send);
2354 tls_free_open_rec(sk);
2357 void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
2359 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2364 void tls_sw_release_resources_rx(struct sock *sk)
2366 struct tls_context *tls_ctx = tls_get_ctx(sk);
2367 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2369 kfree(tls_ctx->rx.rec_seq);
2370 kfree(tls_ctx->rx.iv);
2372 if (ctx->aead_recv) {
2373 __skb_queue_purge(&ctx->rx_list);
2374 crypto_free_aead(ctx->aead_recv);
2375 tls_strp_stop(&ctx->strp);
2376 /* If tls_sw_strparser_arm() was not called (cleanup paths)
2377 * we still want to tls_strp_stop(), but sk->sk_data_ready was
2380 if (ctx->saved_data_ready) {
2381 write_lock_bh(&sk->sk_callback_lock);
2382 sk->sk_data_ready = ctx->saved_data_ready;
2383 write_unlock_bh(&sk->sk_callback_lock);
2388 void tls_sw_strparser_done(struct tls_context *tls_ctx)
2390 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2392 tls_strp_done(&ctx->strp);
2395 void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
2397 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2402 void tls_sw_free_resources_rx(struct sock *sk)
2404 struct tls_context *tls_ctx = tls_get_ctx(sk);
2406 tls_sw_release_resources_rx(sk);
2407 tls_sw_free_ctx_rx(tls_ctx);
/* The work handler to transmit the encrypted records in tx_list */
2411 static void tx_work_handler(struct work_struct *work)
2413 struct delayed_work *delayed_work = to_delayed_work(work);
2414 struct tx_work *tx_work = container_of(delayed_work,
2415 struct tx_work, work);
2416 struct sock *sk = tx_work->sk;
2417 struct tls_context *tls_ctx = tls_get_ctx(sk);
2418 struct tls_sw_context_tx *ctx;
2420 if (unlikely(!tls_ctx))
2423 ctx = tls_sw_ctx_tx(tls_ctx);
2424 if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
2427 if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
2430 if (mutex_trylock(&tls_ctx->tx_lock)) {
2432 tls_tx_records(sk, -1);
2434 mutex_unlock(&tls_ctx->tx_lock);
2435 } else if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
2436 /* Someone is holding the tx_lock, they will likely run Tx
2437 * and cancel the work on their way out of the lock section.
2438 * Schedule a long delay just in case.
2440 schedule_delayed_work(&ctx->tx_work.work, msecs_to_jiffies(10));
2444 static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx)
2446 struct tls_rec *rec;
2448 rec = list_first_entry_or_null(&ctx->tx_list, struct tls_rec, list);
2452 return READ_ONCE(rec->tx_ready);
2455 void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
2457 struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
2459 /* Schedule the transmission if tx list is ready */
2460 if (tls_is_tx_ready(tx_ctx) &&
2461 !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
2462 schedule_delayed_work(&tx_ctx->tx_work.work, 0);
2465 void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
2467 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2469 write_lock_bh(&sk->sk_callback_lock);
2470 rx_ctx->saved_data_ready = sk->sk_data_ready;
2471 sk->sk_data_ready = tls_data_ready;
2472 write_unlock_bh(&sk->sk_callback_lock);
2475 void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
2477 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2479 rx_ctx->zc_capable = tls_ctx->rx_no_pad ||
2480 tls_ctx->prot_info.version != TLS_1_3_VERSION;
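/* Set up the software crypto state for one direction (@tx selects TX vs
 * RX): allocate the per-direction context, derive sizes from the chosen
 * cipher, copy the salt/IV/record sequence, allocate and key the AEAD
 * transform, and for RX initialize the TLS strparser.
 */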
2483 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
2485 struct tls_context *tls_ctx = tls_get_ctx(sk);
2486 struct tls_prot_info *prot = &tls_ctx->prot_info;
2487 struct tls_crypto_info *crypto_info;
2488 struct tls_sw_context_tx *sw_ctx_tx = NULL;
2489 struct tls_sw_context_rx *sw_ctx_rx = NULL;
2490 struct cipher_context *cctx;
2491 struct crypto_aead **aead;
2492 u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
2493 struct crypto_tfm *tfm;
2494 char *iv, *rec_seq, *key, *salt, *cipher_name;
2504 if (!ctx->priv_ctx_tx) {
2505 sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
2510 ctx->priv_ctx_tx = sw_ctx_tx;
2513 (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
2516 if (!ctx->priv_ctx_rx) {
2517 sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
2522 ctx->priv_ctx_rx = sw_ctx_rx;
2525 (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
2530 crypto_init_wait(&sw_ctx_tx->async_wait);
2531 spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
2532 crypto_info = &ctx->crypto_send.info;
2534 aead = &sw_ctx_tx->aead_send;
2535 INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
2536 INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
2537 sw_ctx_tx->tx_work.sk = sk;
2539 crypto_init_wait(&sw_ctx_rx->async_wait);
2540 spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
2541 init_waitqueue_head(&sw_ctx_rx->wq);
2542 crypto_info = &ctx->crypto_recv.info;
2544 skb_queue_head_init(&sw_ctx_rx->rx_list);
2545 skb_queue_head_init(&sw_ctx_rx->async_hold);
2546 aead = &sw_ctx_rx->aead_recv;
2549 switch (crypto_info->cipher_type) {
2550 case TLS_CIPHER_AES_GCM_128: {
2551 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
2553 gcm_128_info = (void *)crypto_info;
2554 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2555 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
2556 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
2557 iv = gcm_128_info->iv;
2558 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
2559 rec_seq = gcm_128_info->rec_seq;
2560 keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
2561 key = gcm_128_info->key;
2562 salt = gcm_128_info->salt;
2563 salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
2564 cipher_name = "gcm(aes)";
2567 case TLS_CIPHER_AES_GCM_256: {
2568 struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
2570 gcm_256_info = (void *)crypto_info;
2571 nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2572 tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
2573 iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2574 iv = gcm_256_info->iv;
2575 rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
2576 rec_seq = gcm_256_info->rec_seq;
2577 keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
2578 key = gcm_256_info->key;
2579 salt = gcm_256_info->salt;
2580 salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
2581 cipher_name = "gcm(aes)";
2584 case TLS_CIPHER_AES_CCM_128: {
2585 struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
2587 ccm_128_info = (void *)crypto_info;
2588 nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2589 tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
2590 iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2591 iv = ccm_128_info->iv;
2592 rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
2593 rec_seq = ccm_128_info->rec_seq;
2594 keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
2595 key = ccm_128_info->key;
2596 salt = ccm_128_info->salt;
2597 salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
2598 cipher_name = "ccm(aes)";
2601 case TLS_CIPHER_CHACHA20_POLY1305: {
2602 struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info;
2604 chacha20_poly1305_info = (void *)crypto_info;
2606 tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE;
2607 iv_size = TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE;
2608 iv = chacha20_poly1305_info->iv;
2609 rec_seq_size = TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE;
2610 rec_seq = chacha20_poly1305_info->rec_seq;
2611 keysize = TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE;
2612 key = chacha20_poly1305_info->key;
2613 salt = chacha20_poly1305_info->salt;
2614 salt_size = TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE;
2615 cipher_name = "rfc7539(chacha20,poly1305)";
2618 case TLS_CIPHER_SM4_GCM: {
2619 struct tls12_crypto_info_sm4_gcm *sm4_gcm_info;
2621 sm4_gcm_info = (void *)crypto_info;
2622 nonce_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
2623 tag_size = TLS_CIPHER_SM4_GCM_TAG_SIZE;
2624 iv_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
2625 iv = sm4_gcm_info->iv;
2626 rec_seq_size = TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE;
2627 rec_seq = sm4_gcm_info->rec_seq;
2628 keysize = TLS_CIPHER_SM4_GCM_KEY_SIZE;
2629 key = sm4_gcm_info->key;
2630 salt = sm4_gcm_info->salt;
2631 salt_size = TLS_CIPHER_SM4_GCM_SALT_SIZE;
2632 cipher_name = "gcm(sm4)";
2635 case TLS_CIPHER_SM4_CCM: {
2636 struct tls12_crypto_info_sm4_ccm *sm4_ccm_info;
2638 sm4_ccm_info = (void *)crypto_info;
2639 nonce_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
2640 tag_size = TLS_CIPHER_SM4_CCM_TAG_SIZE;
2641 iv_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
2642 iv = sm4_ccm_info->iv;
2643 rec_seq_size = TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE;
2644 rec_seq = sm4_ccm_info->rec_seq;
2645 keysize = TLS_CIPHER_SM4_CCM_KEY_SIZE;
2646 key = sm4_ccm_info->key;
2647 salt = sm4_ccm_info->salt;
2648 salt_size = TLS_CIPHER_SM4_CCM_SALT_SIZE;
2649 cipher_name = "ccm(sm4)";
2652 case TLS_CIPHER_ARIA_GCM_128: {
2653 struct tls12_crypto_info_aria_gcm_128 *aria_gcm_128_info;
2655 aria_gcm_128_info = (void *)crypto_info;
2656 nonce_size = TLS_CIPHER_ARIA_GCM_128_IV_SIZE;
2657 tag_size = TLS_CIPHER_ARIA_GCM_128_TAG_SIZE;
2658 iv_size = TLS_CIPHER_ARIA_GCM_128_IV_SIZE;
2659 iv = aria_gcm_128_info->iv;
2660 rec_seq_size = TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE;
2661 rec_seq = aria_gcm_128_info->rec_seq;
2662 keysize = TLS_CIPHER_ARIA_GCM_128_KEY_SIZE;
2663 key = aria_gcm_128_info->key;
2664 salt = aria_gcm_128_info->salt;
2665 salt_size = TLS_CIPHER_ARIA_GCM_128_SALT_SIZE;
2666 cipher_name = "gcm(aria)";
2669 case TLS_CIPHER_ARIA_GCM_256: {
2670 struct tls12_crypto_info_aria_gcm_256 *gcm_256_info;
2672 gcm_256_info = (void *)crypto_info;
2673 nonce_size = TLS_CIPHER_ARIA_GCM_256_IV_SIZE;
2674 tag_size = TLS_CIPHER_ARIA_GCM_256_TAG_SIZE;
2675 iv_size = TLS_CIPHER_ARIA_GCM_256_IV_SIZE;
2676 iv = gcm_256_info->iv;
2677 rec_seq_size = TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE;
2678 rec_seq = gcm_256_info->rec_seq;
2679 keysize = TLS_CIPHER_ARIA_GCM_256_KEY_SIZE;
2680 key = gcm_256_info->key;
2681 salt = gcm_256_info->salt;
2682 salt_size = TLS_CIPHER_ARIA_GCM_256_SALT_SIZE;
2683 cipher_name = "gcm(aria)";
2691 if (crypto_info->version == TLS_1_3_VERSION) {
2693 prot->aad_size = TLS_HEADER_SIZE;
2694 prot->tail_size = 1;
2696 prot->aad_size = TLS_AAD_SPACE_SIZE;
2697 prot->tail_size = 0;
2700 /* Sanity-check the sizes for stack allocations. */
2701 if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
2702 rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE ||
2703 prot->aad_size > TLS_MAX_AAD_SIZE) {
2708 prot->version = crypto_info->version;
2709 prot->cipher_type = crypto_info->cipher_type;
2710 prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
2711 prot->tag_size = tag_size;
2712 prot->overhead_size = prot->prepend_size +
2713 prot->tag_size + prot->tail_size;
2714 prot->iv_size = iv_size;
2715 prot->salt_size = salt_size;
2716 cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
2721 /* Note: 128 & 256 bit salt are the same size */
2722 prot->rec_seq_size = rec_seq_size;
2723 memcpy(cctx->iv, salt, salt_size);
2724 memcpy(cctx->iv + salt_size, iv, iv_size);
2725 cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
2726 if (!cctx->rec_seq) {
2732 *aead = crypto_alloc_aead(cipher_name, 0, 0);
2733 if (IS_ERR(*aead)) {
2734 rc = PTR_ERR(*aead);
2740 ctx->push_pending_record = tls_sw_push_pending_record;
2742 rc = crypto_aead_setkey(*aead, key, keysize);
2747 rc = crypto_aead_setauthsize(*aead, prot->tag_size);
2752 tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
2754 tls_update_rx_zc_capable(ctx);
2755 sw_ctx_rx->async_capable =
2756 crypto_info->version != TLS_1_3_VERSION &&
2757 !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
2759 rc = tls_strp_init(&sw_ctx_rx->strp, sk);
2767 crypto_free_aead(*aead);
2770 kfree(cctx->rec_seq);
2771 cctx->rec_seq = NULL;
2777 kfree(ctx->priv_ctx_tx);
2778 ctx->priv_ctx_tx = NULL;
2780 kfree(ctx->priv_ctx_rx);
2781 ctx->priv_ctx_rx = NULL;