ipv6: init the accept_queue's spinlocks in inet6_create
[platform/kernel/linux-starfive.git] / net / tls / tls_sw.c
1 /*
2  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7  * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
8  *
9  * This software is available to you under a choice of one of two
10  * licenses.  You may choose to be licensed under the terms of the GNU
11  * General Public License (GPL) Version 2, available from the file
12  * COPYING in the main directory of this source tree, or the
13  * OpenIB.org BSD license below:
14  *
15  *     Redistribution and use in source and binary forms, with or
16  *     without modification, are permitted provided that the following
17  *     conditions are met:
18  *
19  *      - Redistributions of source code must retain the above
20  *        copyright notice, this list of conditions and the following
21  *        disclaimer.
22  *
23  *      - Redistributions in binary form must reproduce the above
24  *        copyright notice, this list of conditions and the following
25  *        disclaimer in the documentation and/or other materials
26  *        provided with the distribution.
27  *
28  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
29  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
30  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
31  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
32  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
33  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
34  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35  * SOFTWARE.
36  */
37
38 #include <linux/bug.h>
39 #include <linux/sched/signal.h>
40 #include <linux/module.h>
41 #include <linux/kernel.h>
42 #include <linux/splice.h>
43 #include <crypto/aead.h>
44
45 #include <net/strparser.h>
46 #include <net/tls.h>
47 #include <trace/events/sock.h>
48
49 #include "tls.h"
50
/* Arguments describing one record decryption.
 * zc, async and tail are declared through struct_group() under the name
 * "inargs"; skb is declared outside of that group.
 */
51 struct tls_decrypt_arg {
52         struct_group(inargs,
/* When set, tls_padding_length() seeds the content type from @tail */
53         bool zc;
/* Selects the async callback in tls_do_decryption(), which clears the flag
 * again whenever it does not return with the request still in flight.
 */
54         bool async;
/* Content-type byte consumed by tls_padding_length() when @zc is set */
55         u8 tail;
56         );
57
/* NOTE(review): presumably the skb carrying the record; not referenced in
 * the part of the file visible here - confirm against the callers.
 */
58         struct sk_buff *skb;
59 };
60
/* Per-request decryption context.
 * tls_decrypt_done() locates this structure directly behind the
 * aead_request (sizeof(request) + crypto_aead_reqsize(), rounded up to this
 * struct's alignment), so it must be allocated in that position.
 */
61 struct tls_decrypt_ctx {
/* Socket the request belongs to; read back by tls_decrypt_done() */
62         struct sock *sk;
/* NOTE(review): iv/aad/tail look like per-request scratch buffers; their
 * users are outside the visible part of the file - confirm.
 */
63         u8 iv[MAX_IV_SIZE];
64         u8 aad[TLS_MAX_AAD_SIZE];
65         u8 tail;
/* Flexible array member; sized by whoever allocates the context */
66         struct scatterlist sg[];
67 };
68
69 noinline void tls_err_abort(struct sock *sk, int err)
70 {
71         WARN_ON_ONCE(err >= 0);
72         /* sk->sk_err should contain a positive error code. */
73         WRITE_ONCE(sk->sk_err, -err);
74         /* Paired with smp_rmb() in tcp_poll() */
75         smp_wmb();
76         sk_error_report(sk);
77 }
78
79 static int __skb_nsg(struct sk_buff *skb, int offset, int len,
80                      unsigned int recursion_level)
81 {
82         int start = skb_headlen(skb);
83         int i, chunk = start - offset;
84         struct sk_buff *frag_iter;
85         int elt = 0;
86
87         if (unlikely(recursion_level >= 24))
88                 return -EMSGSIZE;
89
90         if (chunk > 0) {
91                 if (chunk > len)
92                         chunk = len;
93                 elt++;
94                 len -= chunk;
95                 if (len == 0)
96                         return elt;
97                 offset += chunk;
98         }
99
100         for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
101                 int end;
102
103                 WARN_ON(start > offset + len);
104
105                 end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
106                 chunk = end - offset;
107                 if (chunk > 0) {
108                         if (chunk > len)
109                                 chunk = len;
110                         elt++;
111                         len -= chunk;
112                         if (len == 0)
113                                 return elt;
114                         offset += chunk;
115                 }
116                 start = end;
117         }
118
119         if (unlikely(skb_has_frag_list(skb))) {
120                 skb_walk_frags(skb, frag_iter) {
121                         int end, ret;
122
123                         WARN_ON(start > offset + len);
124
125                         end = start + frag_iter->len;
126                         chunk = end - offset;
127                         if (chunk > 0) {
128                                 if (chunk > len)
129                                         chunk = len;
130                                 ret = __skb_nsg(frag_iter, offset - start, chunk,
131                                                 recursion_level + 1);
132                                 if (unlikely(ret < 0))
133                                         return ret;
134                                 elt += ret;
135                                 len -= chunk;
136                                 if (len == 0)
137                                         return elt;
138                                 offset += chunk;
139                         }
140                         start = end;
141                 }
142         }
143         BUG_ON(len);
144         return elt;
145 }
146
/* Number of scatterlist entries needed to map @len bytes of @skb starting
 * at @offset. Returns -EMSGSIZE when the frag_list nesting exceeds the
 * depth limit enforced by __skb_nsg().
 */
static int skb_nsg(struct sk_buff *skb, int offset, int len)
{
	const unsigned int top_level = 0;

	return __skb_nsg(skb, offset, len, top_level);
}
154
155 static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb,
156                               struct tls_decrypt_arg *darg)
157 {
158         struct strp_msg *rxm = strp_msg(skb);
159         struct tls_msg *tlm = tls_msg(skb);
160         int sub = 0;
161
162         /* Determine zero-padding length */
163         if (prot->version == TLS_1_3_VERSION) {
164                 int offset = rxm->full_len - TLS_TAG_SIZE - 1;
165                 char content_type = darg->zc ? darg->tail : 0;
166                 int err;
167
168                 while (content_type == 0) {
169                         if (offset < prot->prepend_size)
170                                 return -EBADMSG;
171                         err = skb_copy_bits(skb, rxm->offset + offset,
172                                             &content_type, 1);
173                         if (err)
174                                 return err;
175                         if (content_type)
176                                 break;
177                         sub++;
178                         offset--;
179                 }
180                 tlm->control = content_type;
181         }
182         return sub;
183 }
184
185 static void tls_decrypt_done(void *data, int err)
186 {
187         struct aead_request *aead_req = data;
188         struct crypto_aead *aead = crypto_aead_reqtfm(aead_req);
189         struct scatterlist *sgout = aead_req->dst;
190         struct scatterlist *sgin = aead_req->src;
191         struct tls_sw_context_rx *ctx;
192         struct tls_decrypt_ctx *dctx;
193         struct tls_context *tls_ctx;
194         struct scatterlist *sg;
195         unsigned int pages;
196         struct sock *sk;
197         int aead_size;
198
199         aead_size = sizeof(*aead_req) + crypto_aead_reqsize(aead);
200         aead_size = ALIGN(aead_size, __alignof__(*dctx));
201         dctx = (void *)((u8 *)aead_req + aead_size);
202
203         sk = dctx->sk;
204         tls_ctx = tls_get_ctx(sk);
205         ctx = tls_sw_ctx_rx(tls_ctx);
206
207         /* Propagate if there was an err */
208         if (err) {
209                 if (err == -EBADMSG)
210                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
211                 ctx->async_wait.err = err;
212                 tls_err_abort(sk, err);
213         }
214
215         /* Free the destination pages if skb was not decrypted inplace */
216         if (sgout != sgin) {
217                 /* Skip the first S/G entry as it points to AAD */
218                 for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
219                         if (!sg)
220                                 break;
221                         put_page(sg_page(sg));
222                 }
223         }
224
225         kfree(aead_req);
226
227         spin_lock_bh(&ctx->decrypt_compl_lock);
228         if (!atomic_dec_return(&ctx->decrypt_pending))
229                 complete(&ctx->async_wait.completion);
230         spin_unlock_bh(&ctx->decrypt_compl_lock);
231 }
232
233 static int tls_do_decryption(struct sock *sk,
234                              struct scatterlist *sgin,
235                              struct scatterlist *sgout,
236                              char *iv_recv,
237                              size_t data_len,
238                              struct aead_request *aead_req,
239                              struct tls_decrypt_arg *darg)
240 {
241         struct tls_context *tls_ctx = tls_get_ctx(sk);
242         struct tls_prot_info *prot = &tls_ctx->prot_info;
243         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
244         int ret;
245
246         aead_request_set_tfm(aead_req, ctx->aead_recv);
247         aead_request_set_ad(aead_req, prot->aad_size);
248         aead_request_set_crypt(aead_req, sgin, sgout,
249                                data_len + prot->tag_size,
250                                (u8 *)iv_recv);
251
252         if (darg->async) {
253                 aead_request_set_callback(aead_req,
254                                           CRYPTO_TFM_REQ_MAY_BACKLOG,
255                                           tls_decrypt_done, aead_req);
256                 atomic_inc(&ctx->decrypt_pending);
257         } else {
258                 aead_request_set_callback(aead_req,
259                                           CRYPTO_TFM_REQ_MAY_BACKLOG,
260                                           crypto_req_done, &ctx->async_wait);
261         }
262
263         ret = crypto_aead_decrypt(aead_req);
264         if (ret == -EINPROGRESS) {
265                 if (darg->async)
266                         return 0;
267
268                 ret = crypto_wait_req(ret, &ctx->async_wait);
269         }
270         darg->async = false;
271
272         return ret;
273 }
274
275 static void tls_trim_both_msgs(struct sock *sk, int target_size)
276 {
277         struct tls_context *tls_ctx = tls_get_ctx(sk);
278         struct tls_prot_info *prot = &tls_ctx->prot_info;
279         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
280         struct tls_rec *rec = ctx->open_rec;
281
282         sk_msg_trim(sk, &rec->msg_plaintext, target_size);
283         if (target_size > 0)
284                 target_size += prot->overhead_size;
285         sk_msg_trim(sk, &rec->msg_encrypted, target_size);
286 }
287
288 static int tls_alloc_encrypted_msg(struct sock *sk, int len)
289 {
290         struct tls_context *tls_ctx = tls_get_ctx(sk);
291         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
292         struct tls_rec *rec = ctx->open_rec;
293         struct sk_msg *msg_en = &rec->msg_encrypted;
294
295         return sk_msg_alloc(sk, msg_en, len, 0);
296 }
297
298 static int tls_clone_plaintext_msg(struct sock *sk, int required)
299 {
300         struct tls_context *tls_ctx = tls_get_ctx(sk);
301         struct tls_prot_info *prot = &tls_ctx->prot_info;
302         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
303         struct tls_rec *rec = ctx->open_rec;
304         struct sk_msg *msg_pl = &rec->msg_plaintext;
305         struct sk_msg *msg_en = &rec->msg_encrypted;
306         int skip, len;
307
308         /* We add page references worth len bytes from encrypted sg
309          * at the end of plaintext sg. It is guaranteed that msg_en
310          * has enough required room (ensured by caller).
311          */
312         len = required - msg_pl->sg.size;
313
314         /* Skip initial bytes in msg_en's data to be able to use
315          * same offset of both plain and encrypted data.
316          */
317         skip = prot->prepend_size + msg_pl->sg.size;
318
319         return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
320 }
321
322 static struct tls_rec *tls_get_rec(struct sock *sk)
323 {
324         struct tls_context *tls_ctx = tls_get_ctx(sk);
325         struct tls_prot_info *prot = &tls_ctx->prot_info;
326         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
327         struct sk_msg *msg_pl, *msg_en;
328         struct tls_rec *rec;
329         int mem_size;
330
331         mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
332
333         rec = kzalloc(mem_size, sk->sk_allocation);
334         if (!rec)
335                 return NULL;
336
337         msg_pl = &rec->msg_plaintext;
338         msg_en = &rec->msg_encrypted;
339
340         sk_msg_init(msg_pl);
341         sk_msg_init(msg_en);
342
343         sg_init_table(rec->sg_aead_in, 2);
344         sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
345         sg_unmark_end(&rec->sg_aead_in[1]);
346
347         sg_init_table(rec->sg_aead_out, 2);
348         sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
349         sg_unmark_end(&rec->sg_aead_out[1]);
350
351         rec->sk = sk;
352
353         return rec;
354 }
355
356 static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
357 {
358         sk_msg_free(sk, &rec->msg_encrypted);
359         sk_msg_free(sk, &rec->msg_plaintext);
360         kfree(rec);
361 }
362
363 static void tls_free_open_rec(struct sock *sk)
364 {
365         struct tls_context *tls_ctx = tls_get_ctx(sk);
366         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
367         struct tls_rec *rec = ctx->open_rec;
368
369         if (rec) {
370                 tls_free_rec(sk, rec);
371                 ctx->open_rec = NULL;
372         }
373 }
374
375 int tls_tx_records(struct sock *sk, int flags)
376 {
377         struct tls_context *tls_ctx = tls_get_ctx(sk);
378         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
379         struct tls_rec *rec, *tmp;
380         struct sk_msg *msg_en;
381         int tx_flags, rc = 0;
382
383         if (tls_is_partially_sent_record(tls_ctx)) {
384                 rec = list_first_entry(&ctx->tx_list,
385                                        struct tls_rec, list);
386
387                 if (flags == -1)
388                         tx_flags = rec->tx_flags;
389                 else
390                         tx_flags = flags;
391
392                 rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
393                 if (rc)
394                         goto tx_err;
395
396                 /* Full record has been transmitted.
397                  * Remove the head of tx_list
398                  */
399                 list_del(&rec->list);
400                 sk_msg_free(sk, &rec->msg_plaintext);
401                 kfree(rec);
402         }
403
404         /* Tx all ready records */
405         list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
406                 if (READ_ONCE(rec->tx_ready)) {
407                         if (flags == -1)
408                                 tx_flags = rec->tx_flags;
409                         else
410                                 tx_flags = flags;
411
412                         msg_en = &rec->msg_encrypted;
413                         rc = tls_push_sg(sk, tls_ctx,
414                                          &msg_en->sg.data[msg_en->sg.curr],
415                                          0, tx_flags);
416                         if (rc)
417                                 goto tx_err;
418
419                         list_del(&rec->list);
420                         sk_msg_free(sk, &rec->msg_plaintext);
421                         kfree(rec);
422                 } else {
423                         break;
424                 }
425         }
426
427 tx_err:
428         if (rc < 0 && rc != -EAGAIN)
429                 tls_err_abort(sk, -EBADMSG);
430
431         return rc;
432 }
433
434 static void tls_encrypt_done(void *data, int err)
435 {
436         struct tls_sw_context_tx *ctx;
437         struct tls_context *tls_ctx;
438         struct tls_prot_info *prot;
439         struct tls_rec *rec = data;
440         struct scatterlist *sge;
441         struct sk_msg *msg_en;
442         bool ready = false;
443         struct sock *sk;
444         int pending;
445
446         msg_en = &rec->msg_encrypted;
447
448         sk = rec->sk;
449         tls_ctx = tls_get_ctx(sk);
450         prot = &tls_ctx->prot_info;
451         ctx = tls_sw_ctx_tx(tls_ctx);
452
453         sge = sk_msg_elem(msg_en, msg_en->sg.curr);
454         sge->offset -= prot->prepend_size;
455         sge->length += prot->prepend_size;
456
457         /* Check if error is previously set on socket */
458         if (err || sk->sk_err) {
459                 rec = NULL;
460
461                 /* If err is already set on socket, return the same code */
462                 if (sk->sk_err) {
463                         ctx->async_wait.err = -sk->sk_err;
464                 } else {
465                         ctx->async_wait.err = err;
466                         tls_err_abort(sk, err);
467                 }
468         }
469
470         if (rec) {
471                 struct tls_rec *first_rec;
472
473                 /* Mark the record as ready for transmission */
474                 smp_store_mb(rec->tx_ready, true);
475
476                 /* If received record is at head of tx_list, schedule tx */
477                 first_rec = list_first_entry(&ctx->tx_list,
478                                              struct tls_rec, list);
479                 if (rec == first_rec)
480                         ready = true;
481         }
482
483         spin_lock_bh(&ctx->encrypt_compl_lock);
484         pending = atomic_dec_return(&ctx->encrypt_pending);
485
486         if (!pending && ctx->async_notify)
487                 complete(&ctx->async_wait.completion);
488         spin_unlock_bh(&ctx->encrypt_compl_lock);
489
490         if (!ready)
491                 return;
492
493         /* Schedule the transmission */
494         if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
495                 schedule_delayed_work(&ctx->tx_work.work, 1);
496 }
497
498 static int tls_do_encryption(struct sock *sk,
499                              struct tls_context *tls_ctx,
500                              struct tls_sw_context_tx *ctx,
501                              struct aead_request *aead_req,
502                              size_t data_len, u32 start)
503 {
504         struct tls_prot_info *prot = &tls_ctx->prot_info;
505         struct tls_rec *rec = ctx->open_rec;
506         struct sk_msg *msg_en = &rec->msg_encrypted;
507         struct scatterlist *sge = sk_msg_elem(msg_en, start);
508         int rc, iv_offset = 0;
509
510         /* For CCM based ciphers, first byte of IV is a constant */
511         switch (prot->cipher_type) {
512         case TLS_CIPHER_AES_CCM_128:
513                 rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
514                 iv_offset = 1;
515                 break;
516         case TLS_CIPHER_SM4_CCM:
517                 rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE;
518                 iv_offset = 1;
519                 break;
520         }
521
522         memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
523                prot->iv_size + prot->salt_size);
524
525         tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset,
526                             tls_ctx->tx.rec_seq);
527
528         sge->offset += prot->prepend_size;
529         sge->length -= prot->prepend_size;
530
531         msg_en->sg.curr = start;
532
533         aead_request_set_tfm(aead_req, ctx->aead_send);
534         aead_request_set_ad(aead_req, prot->aad_size);
535         aead_request_set_crypt(aead_req, rec->sg_aead_in,
536                                rec->sg_aead_out,
537                                data_len, rec->iv_data);
538
539         aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
540                                   tls_encrypt_done, rec);
541
542         /* Add the record in tx_list */
543         list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
544         atomic_inc(&ctx->encrypt_pending);
545
546         rc = crypto_aead_encrypt(aead_req);
547         if (!rc || rc != -EINPROGRESS) {
548                 atomic_dec(&ctx->encrypt_pending);
549                 sge->offset -= prot->prepend_size;
550                 sge->length += prot->prepend_size;
551         }
552
553         if (!rc) {
554                 WRITE_ONCE(rec->tx_ready, true);
555         } else if (rc != -EINPROGRESS) {
556                 list_del(&rec->list);
557                 return rc;
558         }
559
560         /* Unhook the record from context if encryption is not failure */
561         ctx->open_rec = NULL;
562         tls_advance_record_sn(sk, prot, &tls_ctx->tx);
563         return rc;
564 }
565
566 static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
567                                  struct tls_rec **to, struct sk_msg *msg_opl,
568                                  struct sk_msg *msg_oen, u32 split_point,
569                                  u32 tx_overhead_size, u32 *orig_end)
570 {
571         u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
572         struct scatterlist *sge, *osge, *nsge;
573         u32 orig_size = msg_opl->sg.size;
574         struct scatterlist tmp = { };
575         struct sk_msg *msg_npl;
576         struct tls_rec *new;
577         int ret;
578
579         new = tls_get_rec(sk);
580         if (!new)
581                 return -ENOMEM;
582         ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
583                            tx_overhead_size, 0);
584         if (ret < 0) {
585                 tls_free_rec(sk, new);
586                 return ret;
587         }
588
589         *orig_end = msg_opl->sg.end;
590         i = msg_opl->sg.start;
591         sge = sk_msg_elem(msg_opl, i);
592         while (apply && sge->length) {
593                 if (sge->length > apply) {
594                         u32 len = sge->length - apply;
595
596                         get_page(sg_page(sge));
597                         sg_set_page(&tmp, sg_page(sge), len,
598                                     sge->offset + apply);
599                         sge->length = apply;
600                         bytes += apply;
601                         apply = 0;
602                 } else {
603                         apply -= sge->length;
604                         bytes += sge->length;
605                 }
606
607                 sk_msg_iter_var_next(i);
608                 if (i == msg_opl->sg.end)
609                         break;
610                 sge = sk_msg_elem(msg_opl, i);
611         }
612
613         msg_opl->sg.end = i;
614         msg_opl->sg.curr = i;
615         msg_opl->sg.copybreak = 0;
616         msg_opl->apply_bytes = 0;
617         msg_opl->sg.size = bytes;
618
619         msg_npl = &new->msg_plaintext;
620         msg_npl->apply_bytes = apply;
621         msg_npl->sg.size = orig_size - bytes;
622
623         j = msg_npl->sg.start;
624         nsge = sk_msg_elem(msg_npl, j);
625         if (tmp.length) {
626                 memcpy(nsge, &tmp, sizeof(*nsge));
627                 sk_msg_iter_var_next(j);
628                 nsge = sk_msg_elem(msg_npl, j);
629         }
630
631         osge = sk_msg_elem(msg_opl, i);
632         while (osge->length) {
633                 memcpy(nsge, osge, sizeof(*nsge));
634                 sg_unmark_end(nsge);
635                 sk_msg_iter_var_next(i);
636                 sk_msg_iter_var_next(j);
637                 if (i == *orig_end)
638                         break;
639                 osge = sk_msg_elem(msg_opl, i);
640                 nsge = sk_msg_elem(msg_npl, j);
641         }
642
643         msg_npl->sg.end = j;
644         msg_npl->sg.curr = j;
645         msg_npl->sg.copybreak = 0;
646
647         *to = new;
648         return 0;
649 }
650
651 static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
652                                   struct tls_rec *from, u32 orig_end)
653 {
654         struct sk_msg *msg_npl = &from->msg_plaintext;
655         struct sk_msg *msg_opl = &to->msg_plaintext;
656         struct scatterlist *osge, *nsge;
657         u32 i, j;
658
659         i = msg_opl->sg.end;
660         sk_msg_iter_var_prev(i);
661         j = msg_npl->sg.start;
662
663         osge = sk_msg_elem(msg_opl, i);
664         nsge = sk_msg_elem(msg_npl, j);
665
666         if (sg_page(osge) == sg_page(nsge) &&
667             osge->offset + osge->length == nsge->offset) {
668                 osge->length += nsge->length;
669                 put_page(sg_page(nsge));
670         }
671
672         msg_opl->sg.end = orig_end;
673         msg_opl->sg.curr = orig_end;
674         msg_opl->sg.copybreak = 0;
675         msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
676         msg_opl->sg.size += msg_npl->sg.size;
677
678         sk_msg_free(sk, &to->msg_encrypted);
679         sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
680
681         kfree(from);
682 }
683
684 static int tls_push_record(struct sock *sk, int flags,
685                            unsigned char record_type)
686 {
687         struct tls_context *tls_ctx = tls_get_ctx(sk);
688         struct tls_prot_info *prot = &tls_ctx->prot_info;
689         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
690         struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
691         u32 i, split_point, orig_end;
692         struct sk_msg *msg_pl, *msg_en;
693         struct aead_request *req;
694         bool split;
695         int rc;
696
697         if (!rec)
698                 return 0;
699
700         msg_pl = &rec->msg_plaintext;
701         msg_en = &rec->msg_encrypted;
702
703         split_point = msg_pl->apply_bytes;
704         split = split_point && split_point < msg_pl->sg.size;
705         if (unlikely((!split &&
706                       msg_pl->sg.size +
707                       prot->overhead_size > msg_en->sg.size) ||
708                      (split &&
709                       split_point +
710                       prot->overhead_size > msg_en->sg.size))) {
711                 split = true;
712                 split_point = msg_en->sg.size;
713         }
714         if (split) {
715                 rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
716                                            split_point, prot->overhead_size,
717                                            &orig_end);
718                 if (rc < 0)
719                         return rc;
720                 /* This can happen if above tls_split_open_record allocates
721                  * a single large encryption buffer instead of two smaller
722                  * ones. In this case adjust pointers and continue without
723                  * split.
724                  */
725                 if (!msg_pl->sg.size) {
726                         tls_merge_open_record(sk, rec, tmp, orig_end);
727                         msg_pl = &rec->msg_plaintext;
728                         msg_en = &rec->msg_encrypted;
729                         split = false;
730                 }
731                 sk_msg_trim(sk, msg_en, msg_pl->sg.size +
732                             prot->overhead_size);
733         }
734
735         rec->tx_flags = flags;
736         req = &rec->aead_req;
737
738         i = msg_pl->sg.end;
739         sk_msg_iter_var_prev(i);
740
741         rec->content_type = record_type;
742         if (prot->version == TLS_1_3_VERSION) {
743                 /* Add content type to end of message.  No padding added */
744                 sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
745                 sg_mark_end(&rec->sg_content_type);
746                 sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
747                          &rec->sg_content_type);
748         } else {
749                 sg_mark_end(sk_msg_elem(msg_pl, i));
750         }
751
752         if (msg_pl->sg.end < msg_pl->sg.start) {
753                 sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
754                          MAX_SKB_FRAGS - msg_pl->sg.start + 1,
755                          msg_pl->sg.data);
756         }
757
758         i = msg_pl->sg.start;
759         sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
760
761         i = msg_en->sg.end;
762         sk_msg_iter_var_prev(i);
763         sg_mark_end(sk_msg_elem(msg_en, i));
764
765         i = msg_en->sg.start;
766         sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
767
768         tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
769                      tls_ctx->tx.rec_seq, record_type, prot);
770
771         tls_fill_prepend(tls_ctx,
772                          page_address(sg_page(&msg_en->sg.data[i])) +
773                          msg_en->sg.data[i].offset,
774                          msg_pl->sg.size + prot->tail_size,
775                          record_type);
776
777         tls_ctx->pending_open_record_frags = false;
778
779         rc = tls_do_encryption(sk, tls_ctx, ctx, req,
780                                msg_pl->sg.size + prot->tail_size, i);
781         if (rc < 0) {
782                 if (rc != -EINPROGRESS) {
783                         tls_err_abort(sk, -EBADMSG);
784                         if (split) {
785                                 tls_ctx->pending_open_record_frags = true;
786                                 tls_merge_open_record(sk, rec, tmp, orig_end);
787                         }
788                 }
789                 ctx->async_capable = 1;
790                 return rc;
791         } else if (split) {
792                 msg_pl = &tmp->msg_plaintext;
793                 msg_en = &tmp->msg_encrypted;
794                 sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
795                 tls_ctx->pending_open_record_frags = true;
796                 ctx->open_rec = tmp;
797         }
798
799         return tls_tx_records(sk, flags);
800 }
801
/* Apply the socket's sk_psock verdict to the plaintext message @msg.
 *
 * Without a psock, or when @flags carries MSG_SENDPAGE_NOPOLICY, the record
 * is handed straight to tls_push_record().  Otherwise psock->eval decides:
 * __SK_PASS pushes the record, __SK_REDIRECT passes a copy of the message
 * to tcp_bpf_sendmsg_redir(), and any other verdict frees the data and
 * fails with -EACCES.
 *
 * @full_record is consulted only by the cork check; @record_type is passed
 * through to tls_push_record().  *@copied is reduced by the number of bytes
 * released on the failure paths.
 *
 * Returns 0, or the error of the step that failed: -ENOSPC while corking,
 * -EACCES on drop, -sk->sk_err when a push fails with sk_err == EBADMSG,
 * otherwise the value returned by tls_push_record() or
 * tcp_bpf_sendmsg_redir().
 */
802 static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
803                                bool full_record, u8 record_type,
804                                ssize_t *copied, int flags)
805 {
806         struct tls_context *tls_ctx = tls_get_ctx(sk);
807         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
808         struct sk_msg msg_redir = { };
809         struct sk_psock *psock;
810         struct sock *sk_redir;
811         struct tls_rec *rec;
812         bool enospc, policy, redir_ingress;
813         int err = 0, send;
814         u32 delta = 0;
815
            /* policy is false when the caller asked to bypass the verdict */
816         policy = !(flags & MSG_SENDPAGE_NOPOLICY);
817         psock = sk_psock_get(sk);
818         if (!psock || !policy) {
819                 err = tls_push_record(sk, flags, record_type);
                    /* Only a failure accompanied by sk_err == EBADMSG discards
                     * the message and the open record; -EINPROGRESS is not
                     * treated as a failure here.
                     */
820                 if (err && err != -EINPROGRESS && sk->sk_err == EBADMSG) {
821                         *copied -= sk_msg_free(sk, msg);
822                         tls_free_open_rec(sk);
823                         err = -sk->sk_err;
824                 }
825                 if (psock)
826                         sk_psock_put(sk, psock);
827                 return err;
828         }
829 more_data:
830         enospc = sk_msg_full(msg);
            /* Run the verdict only when none is cached in psock->eval.
             * delta ends up as the number of bytes by which the verdict
             * call changed msg->sg.size (size before minus size after).
             */
831         if (psock->eval == __SK_NONE) {
832                 delta = msg->sg.size;
833                 psock->eval = sk_psock_msg_verdict(sk, psock, msg);
834                 delta -= msg->sg.size;
835         }
            /* Corking: more bytes were requested than are queued, the sg
             * list still has room and the record is not full, so report
             * -ENOSPC instead of acting on the verdict now.
             */
836         if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
837             !enospc && !full_record) {
838                 err = -ENOSPC;
839                 goto out_err;
840         }
841         msg->cork_bytes = 0;
            /* Act on the whole message unless apply_bytes limits it */
842         send = msg->sg.size;
843         if (msg->apply_bytes && msg->apply_bytes < send)
844                 send = msg->apply_bytes;
845
846         switch (psock->eval) {
847         case __SK_PASS:
848                 err = tls_push_record(sk, flags, record_type);
849                 if (err && err != -EINPROGRESS && sk->sk_err == EBADMSG) {
850                         *copied -= sk_msg_free(sk, msg);
851                         tls_free_open_rec(sk);
852                         err = -sk->sk_err;
853                         goto out_err;
854                 }
855                 break;
856         case __SK_REDIRECT:
                    /* Snapshot the redirect target and the message itself
                     * before the original is adjusted and the socket lock
                     * is dropped.
                     */
857                 redir_ingress = psock->redir_ingress;
858                 sk_redir = psock->sk_redir;
859                 memcpy(&msg_redir, msg, sizeof(*msg));
860                 if (msg->apply_bytes < send)
861                         msg->apply_bytes = 0;
862                 else
863                         msg->apply_bytes -= send;
864                 sk_msg_return_zero(sk, msg, send);
865                 msg->sg.size -= send;
                    /* The redirect runs without this socket's lock held */
866                 release_sock(sk);
867                 err = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress,
868                                             &msg_redir, send, flags);
869                 lock_sock(sk);
                    /* On failure free the copy (without uncharging) and
                     * mark the original message as empty.
                     */
870                 if (err < 0) {
871                         *copied -= sk_msg_free_nocharge(sk, &msg_redir);
872                         msg->sg.size = 0;
873                 }
874                 if (msg->sg.size == 0)
875                         tls_free_open_rec(sk);
876                 break;
877         case __SK_DROP:
878         default:
879                 sk_msg_free_partial(sk, msg, send);
880                 if (msg->apply_bytes < send)
881                         msg->apply_bytes = 0;
882                 else
883                         msg->apply_bytes -= send;
884                 if (msg->sg.size == 0)
885                         tls_free_open_rec(sk);
                    /* Dropped bytes plus whatever the verdict call removed */
886                 *copied -= (send + delta);
887                 err = -EACCES;
888         }
889
890         if (likely(!err)) {
                    /* Forget the cached verdict when no record is open or
                     * when the open record has no apply_bytes left.
                     */
891                 bool reset_eval = !ctx->open_rec;
892
893                 rec = ctx->open_rec;
894                 if (rec) {
895                         msg = &rec->msg_plaintext;
896                         if (!msg->apply_bytes)
897                                 reset_eval = true;
898                 }
899                 if (reset_eval) {
900                         psock->eval = __SK_NONE;
901                         if (psock->sk_redir) {
902                                 sock_put(psock->sk_redir);
903                                 psock->sk_redir = NULL;
904                         }
905                 }
                    /* A record is still open: process its plaintext next */
906                 if (rec)
907                         goto more_data;
908         }
909  out_err:
910         sk_psock_put(sk, psock);
911         return err;
912 }
913
/* Push the currently open TX record, if it holds any plaintext.
 *
 * Returns 0 when there is no open record or the open record is empty;
 * otherwise returns the result of bpf_exec_tx_verdict(), which is called
 * with full_record set to true and record type TLS_RECORD_TYPE_DATA.
 */
914 static int tls_sw_push_pending_record(struct sock *sk, int flags)
915 {
916         struct tls_context *tls_ctx = tls_get_ctx(sk);
917         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
918         struct tls_rec *rec = ctx->open_rec;
919         struct sk_msg *msg_pl;
            /* Signed, to match the ssize_t *copied parameter of
             * bpf_exec_tx_verdict(), which may subtract from it.
             */
920         ssize_t copied;
921
922         if (!rec)
923                 return 0;
924
925         msg_pl = &rec->msg_plaintext;
926         copied = msg_pl->sg.size;
927         if (!copied)
928                 return 0;
929
930         return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
931                                    &copied, flags);
932 }
933
/* Attach pages from msg->msg_iter to the plaintext message @msg_pl without
 * copying their contents.
 *
 * Pages are extracted one at a time until @try_to_copy bytes have been
 * added or @msg_pl is full.  Each added part is charged to @sk and added
 * to *@copied.
 *
 * Returns 0 on success.  On failure returns the negative value from
 * iov_iter_extract_pages(), or -EIO when that call yields 0 bytes or the
 * extracted page fails sendpage_ok().
 */
934 static int tls_sw_sendmsg_splice(struct sock *sk, struct msghdr *msg,
935                                  struct sk_msg *msg_pl, size_t try_to_copy,
936                                  ssize_t *copied)
937 {
            /* pages points at the single local slot, so at most one page is
             * extracted per call (maxpages argument below is 1).
             */
938         struct page *page = NULL, **pages = &page;
939
940         do {
941                 ssize_t part;
942                 size_t off;
943
944                 part = iov_iter_extract_pages(&msg->msg_iter, &pages,
945                                               try_to_copy, 1, 0, &off);
946                 if (part <= 0)
947                         return part ?: -EIO;
948
                    /* Give the bytes back to the iterator before failing */
949                 if (WARN_ON_ONCE(!sendpage_ok(page))) {
950                         iov_iter_revert(&msg->msg_iter, part);
951                         return -EIO;
952                 }
953
954                 sk_msg_page_add(msg_pl, page, part, off);
                    /* Reset the copy cursor to the end of the sg list */
955                 msg_pl->sg.copybreak = 0;
956                 msg_pl->sg.curr = msg_pl->sg.end;
957                 sk_mem_charge(sk, part);
958                 *copied += part;
959                 try_to_copy -= part;
960         } while (try_to_copy && !sk_msg_full(msg_pl));
961
962         return 0;
963 }
964
/* Body of the software TLS sendmsg path.  The visible caller,
 * tls_sw_sendmsg(), invokes it with tx_lock and the socket lock held.
 *
 * User data from @msg is gathered into the open record (created on demand)
 * through one of three paths: page splicing (MSG_SPLICE_PAGES), zero-copy
 * mapping of the user iterator, or a regular copy.  A record is handed to
 * bpf_exec_tx_verdict() once it is full or when MSG_MORE is not set.
 *
 * @size is not referenced in this function; the amount to send is taken
 * from msg_data_left(@msg).
 *
 * Returns the number of bytes accepted when that is positive, otherwise
 * the value produced by sk_stream_error() for the last error.
 */
965 static int tls_sw_sendmsg_locked(struct sock *sk, struct msghdr *msg,
966                                  size_t size)
967 {
968         long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
969         struct tls_context *tls_ctx = tls_get_ctx(sk);
970         struct tls_prot_info *prot = &tls_ctx->prot_info;
971         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
972         bool async_capable = ctx->async_capable;
973         unsigned char record_type = TLS_RECORD_TYPE_DATA;
974         bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
            /* End of record unless the caller announced more data */
975         bool eor = !(msg->msg_flags & MSG_MORE);
976         size_t try_to_copy;
977         ssize_t copied = 0;
978         struct sk_msg *msg_pl, *msg_en;
979         struct tls_rec *rec;
980         int required_size;
981         int num_async = 0;
982         bool full_record;
983         int record_room;
984         int num_zc = 0;
985         int orig_size;
986         int ret = 0;
987         int pending;
988
            /* MSG_MORE and MSG_EOR contradict each other */
989         if (!eor && (msg->msg_flags & MSG_EOR))
990                 return -EINVAL;
991
            /* Control messages may change record_type */
992         if (unlikely(msg->msg_controllen)) {
993                 ret = tls_process_cmsg(sk, msg, &record_type);
994                 if (ret) {
995                         if (ret == -EINPROGRESS)
996                                 num_async++;
997                         else if (ret != -EAGAIN)
998                                 goto send_end;
999                 }
1000         }
1001
1002         while (msg_data_left(msg)) {
1003                 if (sk->sk_err) {
1004                         ret = -sk->sk_err;
1005                         goto send_end;
1006                 }
1007
                     /* Continue the open record or start a new one */
1008                 if (ctx->open_rec)
1009                         rec = ctx->open_rec;
1010                 else
1011                         rec = ctx->open_rec = tls_get_rec(sk);
1012                 if (!rec) {
1013                         ret = -ENOMEM;
1014                         goto send_end;
1015                 }
1016
1017                 msg_pl = &rec->msg_plaintext;
1018                 msg_en = &rec->msg_encrypted;
1019
                     /* orig_size is the rollback point for this iteration.
                      * try_to_copy is capped by the room left in the record.
                      */
1020                 orig_size = msg_pl->sg.size;
1021                 full_record = false;
1022                 try_to_copy = msg_data_left(msg);
1023                 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
1024                 if (try_to_copy >= record_room) {
1025                         try_to_copy = record_room;
1026                         full_record = true;
1027                 }
1028
1029                 required_size = msg_pl->sg.size + try_to_copy +
1030                                 prot->overhead_size;
1031
1032                 if (!sk_stream_memory_free(sk))
1033                         goto wait_for_sndbuf;
1034
1035 alloc_encrypted:
1036                 ret = tls_alloc_encrypted_msg(sk, required_size);
1037                 if (ret) {
1038                         if (ret != -ENOSPC)
1039                                 goto wait_for_memory;
1040
1041                         /* Adjust try_to_copy according to the amount that was
1042                          * actually allocated. The difference is due
1043                          * to max sg elements limit
1044                          */
1045                         try_to_copy -= required_size - msg_en->sg.size;
1046                         full_record = true;
1047                 }
1048
                     /* Splice path: reference the caller's pages directly */
1049                 if (try_to_copy && (msg->msg_flags & MSG_SPLICE_PAGES)) {
1050                         ret = tls_sw_sendmsg_splice(sk, msg, msg_pl,
1051                                                     try_to_copy, &copied);
1052                         if (ret < 0)
1053                                 goto send_end;
1054                         tls_ctx->pending_open_record_frags = true;
1055
1056                         if (sk_msg_full(msg_pl))
1057                                 full_record = true;
1058
1059                         if (full_record || eor)
1060                                 goto copied;
1061                         continue;
1062                 }
1063
                     /* Zero-copy path: taken only for non-kvec iterators, when
                      * the record is about to be closed, and when async_capable
                      * was not set on entry.
                      */
1064                 if (!is_kvec && (full_record || eor) && !async_capable) {
1065                         u32 first = msg_pl->sg.end;
1066
1067                         ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
1068                                                         msg_pl, try_to_copy);
1069                         if (ret)
1070                                 goto fallback_to_reg_send;
1071
1072                         num_zc++;
1073                         copied += try_to_copy;
1074
1075                         sk_msg_sg_copy_set(msg_pl, first);
1076                         ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1077                                                   record_type, &copied,
1078                                                   msg->msg_flags);
1079                         if (ret) {
1080                                 if (ret == -EINPROGRESS)
1081                                         num_async++;
1082                                 else if (ret == -ENOMEM)
1083                                         goto wait_for_memory;
1084                                 else if (ctx->open_rec && ret == -ENOSPC)
1085                                         goto rollback_iter;
1086                                 else if (ret != -EAGAIN)
1087                                         goto send_end;
1088                         }
1089                         continue;
                             /* Undo the zero-copy mapping: restore the byte
                              * count and the iterator, then fall through to
                              * the regular copy path.
                              */
1090 rollback_iter:
1091                         copied -= try_to_copy;
1092                         sk_msg_sg_copy_clear(msg_pl, first);
1093                         iov_iter_revert(&msg->msg_iter,
1094                                         msg_pl->sg.size - orig_size);
1095 fallback_to_reg_send:
1096                         sk_msg_trim(sk, msg_pl, orig_size);
1097                 }
1098
                     /* Regular path: allocate plaintext space and copy into it */
1099                 required_size = msg_pl->sg.size + try_to_copy;
1100
1101                 ret = tls_clone_plaintext_msg(sk, required_size);
1102                 if (ret) {
1103                         if (ret != -ENOSPC)
1104                                 goto send_end;
1105
1106                         /* Adjust try_to_copy according to the amount that was
1107                          * actually allocated. The difference is due
1108                          * to max sg elements limit
1109                          */
1110                         try_to_copy -= required_size - msg_pl->sg.size;
1111                         full_record = true;
1112                         sk_msg_trim(sk, msg_en,
1113                                     msg_pl->sg.size + prot->overhead_size);
1114                 }
1115
1116                 if (try_to_copy) {
1117                         ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
1118                                                        msg_pl, try_to_copy);
1119                         if (ret < 0)
1120                                 goto trim_sgl;
1121                 }
1122
1123                 /* Open records defined only if successfully copied, otherwise
1124                  * we would trim the sg but not reset the open record frags.
1125                  */
1126                 tls_ctx->pending_open_record_frags = true;
1127                 copied += try_to_copy;
1128 copied:
1129                 if (full_record || eor) {
1130                         ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1131                                                   record_type, &copied,
1132                                                   msg->msg_flags);
1133                         if (ret) {
1134                                 if (ret == -EINPROGRESS)
1135                                         num_async++;
1136                                 else if (ret == -ENOMEM)
1137                                         goto wait_for_memory;
1138                                 else if (ret != -EAGAIN) {
                                             /* -ENOSPC is not reported as an
                                              * error; sending simply stops here.
                                              */
1139                                         if (ret == -ENOSPC)
1140                                                 ret = 0;
1141                                         goto send_end;
1142                                 }
1143                         }
1144                 }
1145
1146                 continue;
1147
1148 wait_for_sndbuf:
1149                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1150 wait_for_memory:
1151                 ret = sk_stream_wait_memory(sk, &timeo);
1152                 if (ret) {
                             /* Also the landing point for a failed copy: drop
                              * what this iteration added to the open record.
                              */
1153 trim_sgl:
1154                         if (ctx->open_rec)
1155                                 tls_trim_both_msgs(sk, orig_size);
1156                         goto send_end;
1157                 }
1158
                     /* Memory became available: retry the allocation if the
                      * encrypted message is still too small.
                      */
1159                 if (ctx->open_rec && msg_en->sg.size < required_size)
1160                         goto alloc_encrypted;
1161         }
1162
             /* With async requests outstanding, wait for them only when the
              * zero-copy path was used at least once; otherwise go straight
              * to the transmit step below.
              */
1163         if (!num_async) {
1164                 goto send_end;
1165         } else if (num_zc) {
1166                 /* Wait for pending encryptions to get completed */
1167                 spin_lock_bh(&ctx->encrypt_compl_lock);
1168                 ctx->async_notify = true;
1169
1170                 pending = atomic_read(&ctx->encrypt_pending);
1171                 spin_unlock_bh(&ctx->encrypt_compl_lock);
1172                 if (pending)
1173                         crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1174                 else
1175                         reinit_completion(&ctx->async_wait.completion);
1176
1177                 /* There can be no concurrent accesses, since we have no
1178                  * pending encrypt operations
1179                  */
1180                 WRITE_ONCE(ctx->async_notify, false);
1181
                     /* An async failure voids the byte count reported so far */
1182                 if (ctx->async_wait.err) {
1183                         ret = ctx->async_wait.err;
1184                         copied = 0;
1185                 }
1186         }
1187
1188         /* Transmit if any encryptions have completed */
1189         if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1190                 cancel_delayed_work(&ctx->tx_work.work);
1191                 tls_tx_records(sk, msg->msg_flags);
1192         }
1193
1194 send_end:
1195         ret = sk_stream_error(sk, msg->msg_flags, ret);
             /* Progress takes precedence over the error code */
1196         return copied > 0 ? copied : ret;
1197 }
1198
/* sendmsg entry point for software TLS.
 *
 * Rejects any flag outside the supported set with -EOPNOTSUPP, then runs
 * tls_sw_sendmsg_locked() with tls_ctx->tx_lock and the socket lock held.
 *
 * Returns the result of tls_sw_sendmsg_locked(), or the error from
 * mutex_lock_interruptible() if taking tx_lock failed.
 */
1199 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
1200 {
1201         struct tls_context *tls_ctx = tls_get_ctx(sk);
1202         int ret;
1203
1204         if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1205                                MSG_CMSG_COMPAT | MSG_SPLICE_PAGES | MSG_EOR |
1206                                MSG_SENDPAGE_NOPOLICY))
1207                 return -EOPNOTSUPP;
1208
             /* tx_lock is taken before the socket lock and released after it */
1209         ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
1210         if (ret)
1211                 return ret;
1212         lock_sock(sk);
1213         ret = tls_sw_sendmsg_locked(sk, msg, size);
1214         release_sock(sk);
1215         mutex_unlock(&tls_ctx->tx_lock);
1216         return ret;
1217 }
1218
1219 /*
1220  * Handle unexpected EOF during splice without SPLICE_F_MORE set.
      *
      * If a record is open and holds plaintext, push it through
      * bpf_exec_tx_verdict() (with full_record == false and flags == 0),
      * wait for any asynchronous encryption that results, and transmit
      * completed records.  Errors are not reported to the caller.
1221  */
1222 void tls_sw_splice_eof(struct socket *sock)
1223 {
1224         struct sock *sk = sock->sk;
1225         struct tls_context *tls_ctx = tls_get_ctx(sk);
1226         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1227         struct tls_rec *rec;
1228         struct sk_msg *msg_pl;
1229         ssize_t copied = 0;
1230         bool retrying = false;
1231         int ret = 0;
1232         int pending;
1233
             /* Check made before taking any lock; open_rec is read again
              * under the locks below.
              */
1234         if (!ctx->open_rec)
1235                 return;
1236
1237         mutex_lock(&tls_ctx->tx_lock);
1238         lock_sock(sk);
1239
1240 retry:
1241         /* same checks as in tls_sw_push_pending_record() */
1242         rec = ctx->open_rec;
1243         if (!rec)
1244                 goto unlock;
1245
1246         msg_pl = &rec->msg_plaintext;
1247         if (msg_pl->sg.size == 0)
1248                 goto unlock;
1249
1250         /* Check the BPF advisor and perform transmission. */
1251         ret = bpf_exec_tx_verdict(msg_pl, sk, false, TLS_RECORD_TYPE_DATA,
1252                                   &copied, 0);
1253         switch (ret) {
1254         case 0:
1255         case -EAGAIN:
                     /* Go around once more; a second 0/-EAGAIN ends the attempt */
1256                 if (retrying)
1257                         goto unlock;
1258                 retrying = true;
1259                 goto retry;
1260         case -EINPROGRESS:
                     /* Encryption was queued asynchronously: wait for it below */
1261                 break;
1262         default:
1263                 goto unlock;
1264         }
1265
1266         /* Wait for pending encryptions to get completed */
1267         spin_lock_bh(&ctx->encrypt_compl_lock);
1268         ctx->async_notify = true;
1269
1270         pending = atomic_read(&ctx->encrypt_pending);
1271         spin_unlock_bh(&ctx->encrypt_compl_lock);
1272         if (pending)
1273                 crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1274         else
1275                 reinit_completion(&ctx->async_wait.completion);
1276
1277         /* There can be no concurrent accesses, since we have no pending
1278          * encrypt operations
1279          */
1280         WRITE_ONCE(ctx->async_notify, false);
1281
1282         if (ctx->async_wait.err)
1283                 goto unlock;
1284
1285         /* Transmit if any encryptions have completed */
1286         if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1287                 cancel_delayed_work(&ctx->tx_work.work);
1288                 tls_tx_records(sk, 0);
1289         }
1290
1291 unlock:
1292         release_sock(sk);
1293         mutex_unlock(&tls_ctx->tx_lock);
1294 }
1295
/* Wait until the strparser has a complete TLS record available.
 *
 * @nonblock selects the receive timeout via sock_rcvtimeo().  @released
 * is forwarded to tls_strp_msg_load() and is forced to true once this
 * function has had to sleep.
 *
 * Returns 1 when a record is ready (after loading it with
 * tls_strp_msg_load()); 0 when the psock queue is non-empty, the receive
 * side is shut down, or SOCK_DONE is set; a negative errno on socket
 * error, expired timeout (-EAGAIN), failed wait, or pending signal.
 */
1296 static int
1297 tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
1298                 bool released)
1299 {
1300         struct tls_context *tls_ctx = tls_get_ctx(sk);
1301         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1302         DEFINE_WAIT_FUNC(wait, woken_wake_function);
1303         int ret = 0;
1304         long timeo;
1305
1306         timeo = sock_rcvtimeo(sk, nonblock);
1307
1308         while (!tls_strp_msg_ready(ctx)) {
                     /* Data queued on the psock is the caller's to handle */
1309                 if (!sk_psock_queue_empty(psock))
1310                         return 0;
1311
1312                 if (sk->sk_err)
1313                         return sock_error(sk);
1314
                     /* ret holds the result of the previous sk_wait_event() */
1315                 if (ret < 0)
1316                         return ret;
1317
                     /* Let the strparser look at queued skbs before sleeping */
1318                 if (!skb_queue_empty(&sk->sk_receive_queue)) {
1319                         tls_strp_check_rcv(&ctx->strp);
1320                         if (tls_strp_msg_ready(ctx))
1321                                 break;
1322                 }
1323
1324                 if (sk->sk_shutdown & RCV_SHUTDOWN)
1325                         return 0;
1326
1327                 if (sock_flag(sk, SOCK_DONE))
1328                         return 0;
1329
1330                 if (!timeo)
1331                         return -EAGAIN;
1332
                     /* NOTE(review): presumably set because sk_wait_event() may
                      * drop the socket lock while sleeping - confirm.
                      */
1333                 released = true;
1334                 add_wait_queue(sk_sleep(sk), &wait);
1335                 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1336                 ret = sk_wait_event(sk, &timeo,
1337                                     tls_strp_msg_ready(ctx) ||
1338                                     !sk_psock_queue_empty(psock),
1339                                     &wait);
1340                 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1341                 remove_wait_queue(sk_sleep(sk), &wait);
1342
1343                 /* Handle signals */
1344                 if (signal_pending(current))
1345                         return sock_intr_errno(timeo);
1346         }
1347
1348         tls_strp_msg_load(&ctx->strp, released);
1349
1350         return 1;
1351 }
1352
/* Map @length bytes of the iterator @from onto scatterlist entries.
 *
 * Entries are written to @to starting at index *@pages_used, one entry per
 * page, and never beyond @to_max_pages entries.  The last entry written is
 * marked as the end of the list.
 *
 * *@pages_used is updated to the new entry count on every exit path,
 * including failure.  On failure the iterator is reverted by the number of
 * bytes consumed so far; the pages already obtained are not released here.
 *
 * Returns 0 on success, -EFAULT when no entries are left or
 * iov_iter_get_pages2() yields no data.
 */
1353 static int tls_setup_from_iter(struct iov_iter *from,
1354                                int length, int *pages_used,
1355                                struct scatterlist *to,
1356                                int to_max_pages)
1357 {
1358         int rc = 0, i = 0, num_elem = *pages_used, maxpages;
1359         struct page *pages[MAX_SKB_FRAGS];
1360         unsigned int size = 0;
1361         ssize_t copied, use;
1362         size_t offset;
1363
1364         while (length > 0) {
1365                 i = 0;
                     /* Remaining scatterlist capacity */
1366                 maxpages = to_max_pages - num_elem;
1367                 if (maxpages == 0) {
1368                         rc = -EFAULT;
1369                         goto out;
1370                 }
1371                 copied = iov_iter_get_pages2(from, pages,
1372                                             length,
1373                                             maxpages, &offset);
1374                 if (copied <= 0) {
1375                         rc = -EFAULT;
1376                         goto out;
1377                 }
1378
1379                 length -= copied;
1380                 size += copied;
                     /* Only the first page of a batch starts at a non-zero
                      * offset; every later one starts at 0.
                      */
1381                 while (copied) {
1382                         use = min_t(int, copied, PAGE_SIZE - offset);
1383
1384                         sg_set_page(&to[num_elem],
1385                                     pages[i], use, offset);
1386                         sg_unmark_end(&to[num_elem]);
1387                         /* We do not uncharge memory from this API */
1388
1389                         offset = 0;
1390                         copied -= use;
1391
1392                         i++;
1393                         num_elem++;
1394                 }
1395         }
1396         /* Mark the end in the last sg entry if newly added */
1397         if (num_elem > *pages_used)
1398                 sg_mark_end(&to[num_elem - 1]);
1399 out:
             /* size counts every byte taken from the iterator so far */
1400         if (rc)
1401                 iov_iter_revert(from, size);
1402         *pages_used = num_elem;
1403
1404         return rc;
1405 }
1406
/* Allocate an skb to receive the decrypted form of @skb.
 *
 * The new skb is obtained from alloc_skb_with_frags(0, @full_len, ...),
 * gets its header copied from @skb, has len and data_len set to @full_len
 * and its strp_msg offset reset to 0.
 *
 * Returns the new skb, or NULL if the allocation failed (the error code
 * written to the local err is not propagated).
 */
1407 static struct sk_buff *
1408 tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb,
1409                      unsigned int full_len)
1410 {
1411         struct strp_msg *clr_rxm;
1412         struct sk_buff *clr_skb;
1413         int err;
1414
1415         clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER,
1416                                        &err, sk->sk_allocation);
1417         if (!clr_skb)
1418                 return NULL;
1419
1420         skb_copy_header(clr_skb, skb);
             /* len == data_len: all of the payload is accounted as non-linear */
1421         clr_skb->len = full_len;
1422         clr_skb->data_len = full_len;
1423
1424         clr_rxm = strp_msg(clr_skb);
1425         clr_rxm->offset = 0;
1426
1427         return clr_skb;
1428 }
1429
1430 /* Decrypt handlers
1431  *
1432  * tls_decrypt_sw() and tls_decrypt_device() are decrypt handlers.
1433  * They must transform the darg in/out argument as follows:
1434  *       |          Input            |         Output
1435  * -------------------------------------------------------------------
1436  *    zc | Zero-copy decrypt allowed | Zero-copy performed
1437  * async | Async decrypt allowed     | Async crypto used / in progress
1438  *   skb |            *              | Output skb
1439  *
1440  * If ZC decryption was performed darg.skb will point to the input skb.
1441  */
1442
1443 /* This function decrypts the input skb into either out_iov or out_sg,
1444  * or into a separately allocated clear-text skb. The input parameter
1445  * 'darg->zc' indicates if zero-copy mode needs to be tried or not. With
1446  * zero-copy mode, either out_iov or out_sg must be non-NULL. In case
1447  * zero-copy was not requested, or both out_iov and out_sg are NULL, the
1448  * output goes to an skb from tls_alloc_clrtxt_skb(), i.e. zero-copy gets
      * disabled and 'darg->zc' is updated.
      *
      * On success darg->skb points to the clear-text skb if one was
      * allocated, otherwise to the input skb.  Returns 0 on success or a
      * negative errno.  When darg->async is set after submission, the
      * function returns without freeing the request memory.
      * NOTE(review): presumably the async completion path frees it - confirm.
1449  */
1450 static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
1451                           struct scatterlist *out_sg,
1452                           struct tls_decrypt_arg *darg)
1453 {
1454         struct tls_context *tls_ctx = tls_get_ctx(sk);
1455         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1456         struct tls_prot_info *prot = &tls_ctx->prot_info;
1457         int n_sgin, n_sgout, aead_size, err, pages = 0;
1458         struct sk_buff *skb = tls_strp_msg(ctx);
1459         const struct strp_msg *rxm = strp_msg(skb);
1460         const struct tls_msg *tlm = tls_msg(skb);
1461         struct aead_request *aead_req;
1462         struct scatterlist *sgin = NULL;
1463         struct scatterlist *sgout = NULL;
1464         const int data_len = rxm->full_len - prot->overhead_size;
1465         int tail_pages = !!prot->tail_size;
1466         struct tls_decrypt_ctx *dctx;
1467         struct sk_buff *clear_skb;
1468         int iv_offset = 0;
1469         u8 *mem;
1470
             /* Count the sg entries needed for the record past its prepend */
1471         n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
1472                          rxm->full_len - prot->prepend_size);
1473         if (n_sgin < 1)
1474                 return n_sgin ?: -EBADMSG;
1475
1476         if (darg->zc && (out_iov || out_sg)) {
1477                 clear_skb = NULL;
1478
                     /* One entry for the AAD, one for the tail if the protocol
                      * has one, plus the pages backing the user buffer.
                      */
1479                 if (out_iov)
1480                         n_sgout = 1 + tail_pages +
1481                                 iov_iter_npages_cap(out_iov, INT_MAX, data_len);
1482                 else
1483                         n_sgout = sg_nents(out_sg);
1484         } else {
1485                 darg->zc = false;
1486
1487                 clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len);
1488                 if (!clear_skb)
1489                         return -ENOMEM;
1490
1491                 n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags;
1492         }
1493
1494         /* Increment to accommodate AAD */
1495         n_sgin = n_sgin + 1;
1496
1497         /* Allocate a single block of memory which contains
1498          *   aead_req || tls_decrypt_ctx.
1499          * Both structs are variable length.
1500          */
1501         aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
1502         aead_size = ALIGN(aead_size, __alignof__(*dctx));
1503         mem = kmalloc(aead_size + struct_size(dctx, sg, size_add(n_sgin, n_sgout)),
1504                       sk->sk_allocation);
1505         if (!mem) {
1506                 err = -ENOMEM;
1507                 goto exit_free_skb;
1508         }
1509
1510         /* Segment the allocated memory */
1511         aead_req = (struct aead_request *)mem;
1512         dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
1513         dctx->sk = sk;
             /* dctx->sg[] holds the input entries followed by the output ones */
1514         sgin = &dctx->sg[0];
1515         sgout = &dctx->sg[n_sgin];
1516
1517         /* For CCM based ciphers, first byte of nonce+iv is a constant */
1518         switch (prot->cipher_type) {
1519         case TLS_CIPHER_AES_CCM_128:
1520                 dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE;
1521                 iv_offset = 1;
1522                 break;
1523         case TLS_CIPHER_SM4_CCM:
1524                 dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
1525                 iv_offset = 1;
1526                 break;
1527         }
1528
1529         /* Prepare IV */
             /* TLS 1.3 and ChaCha20-Poly1305 take the whole IV from the
              * context; otherwise the per-record part is read from the
              * record right after the TLS header and the salt comes from
              * the context.
              */
1530         if (prot->version == TLS_1_3_VERSION ||
1531             prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
1532                 memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv,
1533                        prot->iv_size + prot->salt_size);
1534         } else {
1535                 err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
1536                                     &dctx->iv[iv_offset] + prot->salt_size,
1537                                     prot->iv_size);
1538                 if (err < 0)
1539                         goto exit_free;
1540                 memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size);
1541         }
1542         tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq);
1543
1544         /* Prepare AAD */
1545         tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size +
1546                      prot->tail_size,
1547                      tls_ctx->rx.rec_seq, tlm->control, prot);
1548
1549         /* Prepare sgin */
1550         sg_init_table(sgin, n_sgin);
1551         sg_set_buf(&sgin[0], dctx->aad, prot->aad_size);
1552         err = skb_to_sgvec(skb, &sgin[1],
1553                            rxm->offset + prot->prepend_size,
1554                            rxm->full_len - prot->prepend_size);
1555         if (err < 0)
1556                 goto exit_free;
1557
             /* Prepare sgout: slot 0 is the AAD, the data follows from slot 1
              * (when out_sg is supplied it is copied verbatim instead).
              */
1558         if (clear_skb) {
1559                 sg_init_table(sgout, n_sgout);
1560                 sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
1561
1562                 err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size,
1563                                    data_len + prot->tail_size);
1564                 if (err < 0)
1565                         goto exit_free;
1566         } else if (out_iov) {
1567                 sg_init_table(sgout, n_sgout);
1568                 sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
1569
1570                 err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1],
1571                                           (n_sgout - 1 - tail_pages));
1572                 if (err < 0)
1573                         goto exit_free_pages;
1574
                     /* Append the tail buffer after the last user page */
1575                 if (prot->tail_size) {
1576                         sg_unmark_end(&sgout[pages]);
1577                         sg_set_buf(&sgout[pages + 1], &dctx->tail,
1578                                    prot->tail_size);
1579                         sg_mark_end(&sgout[pages + 1]);
1580                 }
1581         } else if (out_sg) {
1582                 memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
1583         }
1584
1585         /* Prepare and submit AEAD request */
1586         err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
1587                                 data_len + prot->tail_size, aead_req, darg);
1588         if (err)
1589                 goto exit_free_pages;
1590
             /* Ownership of clear_skb moves to darg; clearing the local makes
              * the consume_skb() at exit_free_skb a call on NULL.
              */
1591         darg->skb = clear_skb ?: tls_strp_msg(ctx);
1592         clear_skb = NULL;
1593
1594         if (unlikely(darg->async)) {
1595                 err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold);
1596                 if (err)
1597                         __skb_queue_tail(&ctx->async_hold, darg->skb);
1598                 return err;
1599         }
1600
1601         if (prot->tail_size)
1602                 darg->tail = dctx->tail;
1603
             /* The synchronous success path falls through the cleanup labels */
1604 exit_free_pages:
1605         /* Release the pages in case iov was mapped to pages */
             /* User pages occupy sgout[1..pages]; sgout[0] is the AAD */
1606         for (; pages > 0; pages--)
1607                 put_page(sg_page(&sgout[pages]));
1608 exit_free:
1609         kfree(mem);
1610 exit_free_skb:
1611         consume_skb(clear_skb);
1612         return err;
1613 }
1614
/* Software decrypt handler: decrypt the current record via tls_decrypt_sg()
 * using msg->msg_iter as the zero-copy destination, then strip padding.
 *
 * A -EBADMSG failure increments LINUX_MIB_TLSDECRYPTERROR.  For TLS 1.3,
 * if zero-copy was used and the recovered darg->tail is not
 * TLS_RECORD_TYPE_DATA, the record is decrypted again with zero-copy
 * disabled.
 *
 * Returns 0 on success or a negative errno from tls_decrypt_sg() or
 * tls_padding_length().
 */
1615 static int
1616 tls_decrypt_sw(struct sock *sk, struct tls_context *tls_ctx,
1617                struct msghdr *msg, struct tls_decrypt_arg *darg)
1618 {
1619         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1620         struct tls_prot_info *prot = &tls_ctx->prot_info;
1621         struct strp_msg *rxm;
1622         int pad, err;
1623
1624         err = tls_decrypt_sg(sk, &msg->msg_iter, NULL, darg);
1625         if (err < 0) {
1626                 if (err == -EBADMSG)
1627                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
1628                 return err;
1629         }
1630         /* keep going even for ->async, the code below is TLS 1.3 */
1631
1632         /* If opportunistic TLS 1.3 ZC failed retry without ZC */
1633         if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
1634                      darg->tail != TLS_RECORD_TYPE_DATA)) {
                     /* With zc cleared the recursive call cannot take this
                      * branch again.
                      */
1635                 darg->zc = false;
1636                 if (!darg->tail)
1637                         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL);
1638                 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
1639                 return tls_decrypt_sw(sk, tls_ctx, msg, darg);
1640         }
1641
1642         pad = tls_padding_length(prot, darg->skb, darg);
1643         if (pad < 0) {
                     /* Free the output skb only if it is not the strparser's */
1644                 if (darg->skb != tls_strp_msg(ctx))
1645                         consume_skb(darg->skb);
1646                 return pad;
1647         }
1648
1649         rxm = strp_msg(darg->skb);
1650         rxm->full_len -= pad;
1651
1652         return 0;
1653 }
1654
1655 static int
1656 tls_decrypt_device(struct sock *sk, struct msghdr *msg,
1657                    struct tls_context *tls_ctx, struct tls_decrypt_arg *darg)
1658 {
1659         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1660         struct tls_prot_info *prot = &tls_ctx->prot_info;
1661         struct strp_msg *rxm;
1662         int pad, err;
1663
1664         if (tls_ctx->rx_conf != TLS_HW)
1665                 return 0;
1666
1667         err = tls_device_decrypted(sk, tls_ctx);
1668         if (err <= 0)
1669                 return err;
1670
1671         pad = tls_padding_length(prot, tls_strp_msg(ctx), darg);
1672         if (pad < 0)
1673                 return pad;
1674
1675         darg->async = false;
1676         darg->skb = tls_strp_msg(ctx);
1677         /* ->zc downgrade check, in case TLS 1.3 gets here */
1678         darg->zc &= !(prot->version == TLS_1_3_VERSION &&
1679                       tls_msg(darg->skb)->control != TLS_RECORD_TYPE_DATA);
1680
1681         rxm = strp_msg(darg->skb);
1682         rxm->full_len -= pad;
1683
1684         if (!darg->zc) {
1685                 /* Non-ZC case needs a real skb */
1686                 darg->skb = tls_strp_msg_detach(ctx);
1687                 if (!darg->skb)
1688                         return -ENOMEM;
1689         } else {
1690                 unsigned int off, len;
1691
1692                 /* In ZC case nobody cares about the output skb.
1693                  * Just copy the data here. Note the skb is not fully trimmed.
1694                  */
1695                 off = rxm->offset + prot->prepend_size;
1696                 len = rxm->full_len - prot->overhead_size;
1697
1698                 err = skb_copy_datagram_msg(darg->skb, off, msg, len);
1699                 if (err)
1700                         return err;
1701         }
1702         return 1;
1703 }
1704
1705 static int tls_rx_one_record(struct sock *sk, struct msghdr *msg,
1706                              struct tls_decrypt_arg *darg)
1707 {
1708         struct tls_context *tls_ctx = tls_get_ctx(sk);
1709         struct tls_prot_info *prot = &tls_ctx->prot_info;
1710         struct strp_msg *rxm;
1711         int err;
1712
1713         err = tls_decrypt_device(sk, msg, tls_ctx, darg);
1714         if (!err)
1715                 err = tls_decrypt_sw(sk, tls_ctx, msg, darg);
1716         if (err < 0)
1717                 return err;
1718
1719         rxm = strp_msg(darg->skb);
1720         rxm->offset += prot->prepend_size;
1721         rxm->full_len -= prot->overhead_size;
1722         tls_advance_record_sn(sk, prot, &tls_ctx->rx);
1723
1724         return 0;
1725 }
1726
1727 int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
1728 {
1729         struct tls_decrypt_arg darg = { .zc = true, };
1730
1731         return tls_decrypt_sg(sk, NULL, sgout, &darg);
1732 }
1733
1734 static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
1735                                    u8 *control)
1736 {
1737         int err;
1738
1739         if (!*control) {
1740                 *control = tlm->control;
1741                 if (!*control)
1742                         return -EBADMSG;
1743
1744                 err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1745                                sizeof(*control), control);
1746                 if (*control != TLS_RECORD_TYPE_DATA) {
1747                         if (err || msg->msg_flags & MSG_CTRUNC)
1748                                 return -EIO;
1749                 }
1750         } else if (*control != tlm->control) {
1751                 return 0;
1752         }
1753
1754         return 1;
1755 }
1756
1757 static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
1758 {
1759         tls_strp_msg_done(&ctx->strp);
1760 }
1761
1762 /* This function traverses the rx_list in tls receive context to copies the
1763  * decrypted records into the buffer provided by caller zero copy is not
1764  * true. Further, the records are removed from the rx_list if it is not a peek
1765  * case and the record has been consumed completely.
1766  */
1767 static int process_rx_list(struct tls_sw_context_rx *ctx,
1768                            struct msghdr *msg,
1769                            u8 *control,
1770                            size_t skip,
1771                            size_t len,
1772                            bool is_peek)
1773 {
1774         struct sk_buff *skb = skb_peek(&ctx->rx_list);
1775         struct tls_msg *tlm;
1776         ssize_t copied = 0;
1777         int err;
1778
1779         while (skip && skb) {
1780                 struct strp_msg *rxm = strp_msg(skb);
1781                 tlm = tls_msg(skb);
1782
1783                 err = tls_record_content_type(msg, tlm, control);
1784                 if (err <= 0)
1785                         goto out;
1786
1787                 if (skip < rxm->full_len)
1788                         break;
1789
1790                 skip = skip - rxm->full_len;
1791                 skb = skb_peek_next(skb, &ctx->rx_list);
1792         }
1793
1794         while (len && skb) {
1795                 struct sk_buff *next_skb;
1796                 struct strp_msg *rxm = strp_msg(skb);
1797                 int chunk = min_t(unsigned int, rxm->full_len - skip, len);
1798
1799                 tlm = tls_msg(skb);
1800
1801                 err = tls_record_content_type(msg, tlm, control);
1802                 if (err <= 0)
1803                         goto out;
1804
1805                 err = skb_copy_datagram_msg(skb, rxm->offset + skip,
1806                                             msg, chunk);
1807                 if (err < 0)
1808                         goto out;
1809
1810                 len = len - chunk;
1811                 copied = copied + chunk;
1812
1813                 /* Consume the data from record if it is non-peek case*/
1814                 if (!is_peek) {
1815                         rxm->offset = rxm->offset + chunk;
1816                         rxm->full_len = rxm->full_len - chunk;
1817
1818                         /* Return if there is unconsumed data in the record */
1819                         if (rxm->full_len - skip)
1820                                 break;
1821                 }
1822
1823                 /* The remaining skip-bytes must lie in 1st record in rx_list.
1824                  * So from the 2nd record, 'skip' should be 0.
1825                  */
1826                 skip = 0;
1827
1828                 if (msg)
1829                         msg->msg_flags |= MSG_EOR;
1830
1831                 next_skb = skb_peek_next(skb, &ctx->rx_list);
1832
1833                 if (!is_peek) {
1834                         __skb_unlink(skb, &ctx->rx_list);
1835                         consume_skb(skb);
1836                 }
1837
1838                 skb = next_skb;
1839         }
1840         err = 0;
1841
1842 out:
1843         return copied ? : err;
1844 }
1845
1846 static bool
1847 tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
1848                        size_t len_left, size_t decrypted, ssize_t done,
1849                        size_t *flushed_at)
1850 {
1851         size_t max_rec;
1852
1853         if (len_left <= decrypted)
1854                 return false;
1855
1856         max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE;
1857         if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec)
1858                 return false;
1859
1860         *flushed_at = done;
1861         return sk_flush_backlog(sk);
1862 }
1863
1864 static int tls_rx_reader_acquire(struct sock *sk, struct tls_sw_context_rx *ctx,
1865                                  bool nonblock)
1866 {
1867         long timeo;
1868         int ret;
1869
1870         timeo = sock_rcvtimeo(sk, nonblock);
1871
1872         while (unlikely(ctx->reader_present)) {
1873                 DEFINE_WAIT_FUNC(wait, woken_wake_function);
1874
1875                 ctx->reader_contended = 1;
1876
1877                 add_wait_queue(&ctx->wq, &wait);
1878                 ret = sk_wait_event(sk, &timeo,
1879                                     !READ_ONCE(ctx->reader_present), &wait);
1880                 remove_wait_queue(&ctx->wq, &wait);
1881
1882                 if (timeo <= 0)
1883                         return -EAGAIN;
1884                 if (signal_pending(current))
1885                         return sock_intr_errno(timeo);
1886                 if (ret < 0)
1887                         return ret;
1888         }
1889
1890         WRITE_ONCE(ctx->reader_present, 1);
1891
1892         return 0;
1893 }
1894
1895 static int tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
1896                               bool nonblock)
1897 {
1898         int err;
1899
1900         lock_sock(sk);
1901         err = tls_rx_reader_acquire(sk, ctx, nonblock);
1902         if (err)
1903                 release_sock(sk);
1904         return err;
1905 }
1906
1907 static void tls_rx_reader_release(struct sock *sk, struct tls_sw_context_rx *ctx)
1908 {
1909         if (unlikely(ctx->reader_contended)) {
1910                 if (wq_has_sleeper(&ctx->wq))
1911                         wake_up(&ctx->wq);
1912                 else
1913                         ctx->reader_contended = 0;
1914
1915                 WARN_ON_ONCE(!ctx->reader_present);
1916         }
1917
1918         WRITE_ONCE(ctx->reader_present, 0);
1919 }
1920
/* Counterpart of tls_rx_reader_lock(): release the reader slot first,
 * then the socket lock.
 */
static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx)
{
	tls_rx_reader_release(sk, ctx);

	release_sock(sk);
}
1926
1927 int tls_sw_recvmsg(struct sock *sk,
1928                    struct msghdr *msg,
1929                    size_t len,
1930                    int flags,
1931                    int *addr_len)
1932 {
1933         struct tls_context *tls_ctx = tls_get_ctx(sk);
1934         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1935         struct tls_prot_info *prot = &tls_ctx->prot_info;
1936         ssize_t decrypted = 0, async_copy_bytes = 0;
1937         struct sk_psock *psock;
1938         unsigned char control = 0;
1939         size_t flushed_at = 0;
1940         struct strp_msg *rxm;
1941         struct tls_msg *tlm;
1942         ssize_t copied = 0;
1943         bool async = false;
1944         int target, err;
1945         bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
1946         bool is_peek = flags & MSG_PEEK;
1947         bool released = true;
1948         bool bpf_strp_enabled;
1949         bool zc_capable;
1950
1951         if (unlikely(flags & MSG_ERRQUEUE))
1952                 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
1953
1954         psock = sk_psock_get(sk);
1955         err = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT);
1956         if (err < 0)
1957                 return err;
1958         bpf_strp_enabled = sk_psock_strp_enabled(psock);
1959
1960         /* If crypto failed the connection is broken */
1961         err = ctx->async_wait.err;
1962         if (err)
1963                 goto end;
1964
1965         /* Process pending decrypted records. It must be non-zero-copy */
1966         err = process_rx_list(ctx, msg, &control, 0, len, is_peek);
1967         if (err < 0)
1968                 goto end;
1969
1970         copied = err;
1971         if (len <= copied)
1972                 goto end;
1973
1974         target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1975         len = len - copied;
1976
1977         zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
1978                 ctx->zc_capable;
1979         decrypted = 0;
1980         while (len && (decrypted + copied < target || tls_strp_msg_ready(ctx))) {
1981                 struct tls_decrypt_arg darg;
1982                 int to_decrypt, chunk;
1983
1984                 err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT,
1985                                       released);
1986                 if (err <= 0) {
1987                         if (psock) {
1988                                 chunk = sk_msg_recvmsg(sk, psock, msg, len,
1989                                                        flags);
1990                                 if (chunk > 0) {
1991                                         decrypted += chunk;
1992                                         len -= chunk;
1993                                         continue;
1994                                 }
1995                         }
1996                         goto recv_end;
1997                 }
1998
1999                 memset(&darg.inargs, 0, sizeof(darg.inargs));
2000
2001                 rxm = strp_msg(tls_strp_msg(ctx));
2002                 tlm = tls_msg(tls_strp_msg(ctx));
2003
2004                 to_decrypt = rxm->full_len - prot->overhead_size;
2005
2006                 if (zc_capable && to_decrypt <= len &&
2007                     tlm->control == TLS_RECORD_TYPE_DATA)
2008                         darg.zc = true;
2009
2010                 /* Do not use async mode if record is non-data */
2011                 if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
2012                         darg.async = ctx->async_capable;
2013                 else
2014                         darg.async = false;
2015
2016                 err = tls_rx_one_record(sk, msg, &darg);
2017                 if (err < 0) {
2018                         tls_err_abort(sk, -EBADMSG);
2019                         goto recv_end;
2020                 }
2021
2022                 async |= darg.async;
2023
2024                 /* If the type of records being processed is not known yet,
2025                  * set it to record type just dequeued. If it is already known,
2026                  * but does not match the record type just dequeued, go to end.
2027                  * We always get record type here since for tls1.2, record type
2028                  * is known just after record is dequeued from stream parser.
2029                  * For tls1.3, we disable async.
2030                  */
2031                 err = tls_record_content_type(msg, tls_msg(darg.skb), &control);
2032                 if (err <= 0) {
2033                         DEBUG_NET_WARN_ON_ONCE(darg.zc);
2034                         tls_rx_rec_done(ctx);
2035 put_on_rx_list_err:
2036                         __skb_queue_tail(&ctx->rx_list, darg.skb);
2037                         goto recv_end;
2038                 }
2039
2040                 /* periodically flush backlog, and feed strparser */
2041                 released = tls_read_flush_backlog(sk, prot, len, to_decrypt,
2042                                                   decrypted + copied,
2043                                                   &flushed_at);
2044
2045                 /* TLS 1.3 may have updated the length by more than overhead */
2046                 rxm = strp_msg(darg.skb);
2047                 chunk = rxm->full_len;
2048                 tls_rx_rec_done(ctx);
2049
2050                 if (!darg.zc) {
2051                         bool partially_consumed = chunk > len;
2052                         struct sk_buff *skb = darg.skb;
2053
2054                         DEBUG_NET_WARN_ON_ONCE(darg.skb == ctx->strp.anchor);
2055
2056                         if (async) {
2057                                 /* TLS 1.2-only, to_decrypt must be text len */
2058                                 chunk = min_t(int, to_decrypt, len);
2059                                 async_copy_bytes += chunk;
2060 put_on_rx_list:
2061                                 decrypted += chunk;
2062                                 len -= chunk;
2063                                 __skb_queue_tail(&ctx->rx_list, skb);
2064                                 continue;
2065                         }
2066
2067                         if (bpf_strp_enabled) {
2068                                 released = true;
2069                                 err = sk_psock_tls_strp_read(psock, skb);
2070                                 if (err != __SK_PASS) {
2071                                         rxm->offset = rxm->offset + rxm->full_len;
2072                                         rxm->full_len = 0;
2073                                         if (err == __SK_DROP)
2074                                                 consume_skb(skb);
2075                                         continue;
2076                                 }
2077                         }
2078
2079                         if (partially_consumed)
2080                                 chunk = len;
2081
2082                         err = skb_copy_datagram_msg(skb, rxm->offset,
2083                                                     msg, chunk);
2084                         if (err < 0)
2085                                 goto put_on_rx_list_err;
2086
2087                         if (is_peek)
2088                                 goto put_on_rx_list;
2089
2090                         if (partially_consumed) {
2091                                 rxm->offset += chunk;
2092                                 rxm->full_len -= chunk;
2093                                 goto put_on_rx_list;
2094                         }
2095
2096                         consume_skb(skb);
2097                 }
2098
2099                 decrypted += chunk;
2100                 len -= chunk;
2101
2102                 /* Return full control message to userspace before trying
2103                  * to parse another message type
2104                  */
2105                 msg->msg_flags |= MSG_EOR;
2106                 if (control != TLS_RECORD_TYPE_DATA)
2107                         break;
2108         }
2109
2110 recv_end:
2111         if (async) {
2112                 int ret, pending;
2113
2114                 /* Wait for all previously submitted records to be decrypted */
2115                 spin_lock_bh(&ctx->decrypt_compl_lock);
2116                 reinit_completion(&ctx->async_wait.completion);
2117                 pending = atomic_read(&ctx->decrypt_pending);
2118                 spin_unlock_bh(&ctx->decrypt_compl_lock);
2119                 ret = 0;
2120                 if (pending)
2121                         ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2122                 __skb_queue_purge(&ctx->async_hold);
2123
2124                 if (ret) {
2125                         if (err >= 0 || err == -EINPROGRESS)
2126                                 err = ret;
2127                         decrypted = 0;
2128                         goto end;
2129                 }
2130
2131                 /* Drain records from the rx_list & copy if required */
2132                 if (is_peek || is_kvec)
2133                         err = process_rx_list(ctx, msg, &control, copied,
2134                                               decrypted, is_peek);
2135                 else
2136                         err = process_rx_list(ctx, msg, &control, 0,
2137                                               async_copy_bytes, is_peek);
2138                 decrypted += max(err, 0);
2139         }
2140
2141         copied += decrypted;
2142
2143 end:
2144         tls_rx_reader_unlock(sk, ctx);
2145         if (psock)
2146                 sk_psock_put(sk, psock);
2147         return copied ? : err;
2148 }
2149
2150 ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
2151                            struct pipe_inode_info *pipe,
2152                            size_t len, unsigned int flags)
2153 {
2154         struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
2155         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2156         struct strp_msg *rxm = NULL;
2157         struct sock *sk = sock->sk;
2158         struct tls_msg *tlm;
2159         struct sk_buff *skb;
2160         ssize_t copied = 0;
2161         int chunk;
2162         int err;
2163
2164         err = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK);
2165         if (err < 0)
2166                 return err;
2167
2168         if (!skb_queue_empty(&ctx->rx_list)) {
2169                 skb = __skb_dequeue(&ctx->rx_list);
2170         } else {
2171                 struct tls_decrypt_arg darg;
2172
2173                 err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
2174                                       true);
2175                 if (err <= 0)
2176                         goto splice_read_end;
2177
2178                 memset(&darg.inargs, 0, sizeof(darg.inargs));
2179
2180                 err = tls_rx_one_record(sk, NULL, &darg);
2181                 if (err < 0) {
2182                         tls_err_abort(sk, -EBADMSG);
2183                         goto splice_read_end;
2184                 }
2185
2186                 tls_rx_rec_done(ctx);
2187                 skb = darg.skb;
2188         }
2189
2190         rxm = strp_msg(skb);
2191         tlm = tls_msg(skb);
2192
2193         /* splice does not support reading control messages */
2194         if (tlm->control != TLS_RECORD_TYPE_DATA) {
2195                 err = -EINVAL;
2196                 goto splice_requeue;
2197         }
2198
2199         chunk = min_t(unsigned int, rxm->full_len, len);
2200         copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
2201         if (copied < 0)
2202                 goto splice_requeue;
2203
2204         if (chunk < rxm->full_len) {
2205                 rxm->offset += len;
2206                 rxm->full_len -= len;
2207                 goto splice_requeue;
2208         }
2209
2210         consume_skb(skb);
2211
2212 splice_read_end:
2213         tls_rx_reader_unlock(sk, ctx);
2214         return copied ? : err;
2215
2216 splice_requeue:
2217         __skb_queue_head(&ctx->rx_list, skb);
2218         goto splice_read_end;
2219 }
2220
2221 int tls_sw_read_sock(struct sock *sk, read_descriptor_t *desc,
2222                      sk_read_actor_t read_actor)
2223 {
2224         struct tls_context *tls_ctx = tls_get_ctx(sk);
2225         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2226         struct tls_prot_info *prot = &tls_ctx->prot_info;
2227         struct strp_msg *rxm = NULL;
2228         struct sk_buff *skb = NULL;
2229         struct sk_psock *psock;
2230         size_t flushed_at = 0;
2231         bool released = true;
2232         struct tls_msg *tlm;
2233         ssize_t copied = 0;
2234         ssize_t decrypted;
2235         int err, used;
2236
2237         psock = sk_psock_get(sk);
2238         if (psock) {
2239                 sk_psock_put(sk, psock);
2240                 return -EINVAL;
2241         }
2242         err = tls_rx_reader_acquire(sk, ctx, true);
2243         if (err < 0)
2244                 return err;
2245
2246         /* If crypto failed the connection is broken */
2247         err = ctx->async_wait.err;
2248         if (err)
2249                 goto read_sock_end;
2250
2251         decrypted = 0;
2252         do {
2253                 if (!skb_queue_empty(&ctx->rx_list)) {
2254                         skb = __skb_dequeue(&ctx->rx_list);
2255                         rxm = strp_msg(skb);
2256                         tlm = tls_msg(skb);
2257                 } else {
2258                         struct tls_decrypt_arg darg;
2259
2260                         err = tls_rx_rec_wait(sk, NULL, true, released);
2261                         if (err <= 0)
2262                                 goto read_sock_end;
2263
2264                         memset(&darg.inargs, 0, sizeof(darg.inargs));
2265
2266                         err = tls_rx_one_record(sk, NULL, &darg);
2267                         if (err < 0) {
2268                                 tls_err_abort(sk, -EBADMSG);
2269                                 goto read_sock_end;
2270                         }
2271
2272                         released = tls_read_flush_backlog(sk, prot, INT_MAX,
2273                                                           0, decrypted,
2274                                                           &flushed_at);
2275                         skb = darg.skb;
2276                         rxm = strp_msg(skb);
2277                         tlm = tls_msg(skb);
2278                         decrypted += rxm->full_len;
2279
2280                         tls_rx_rec_done(ctx);
2281                 }
2282
2283                 /* read_sock does not support reading control messages */
2284                 if (tlm->control != TLS_RECORD_TYPE_DATA) {
2285                         err = -EINVAL;
2286                         goto read_sock_requeue;
2287                 }
2288
2289                 used = read_actor(desc, skb, rxm->offset, rxm->full_len);
2290                 if (used <= 0) {
2291                         if (!copied)
2292                                 err = used;
2293                         goto read_sock_requeue;
2294                 }
2295                 copied += used;
2296                 if (used < rxm->full_len) {
2297                         rxm->offset += used;
2298                         rxm->full_len -= used;
2299                         if (!desc->count)
2300                                 goto read_sock_requeue;
2301                 } else {
2302                         consume_skb(skb);
2303                         if (!desc->count)
2304                                 skb = NULL;
2305                 }
2306         } while (skb);
2307
2308 read_sock_end:
2309         tls_rx_reader_release(sk, ctx);
2310         return copied ? : err;
2311
2312 read_sock_requeue:
2313         __skb_queue_head(&ctx->rx_list, skb);
2314         goto read_sock_end;
2315 }
2316
2317 bool tls_sw_sock_is_readable(struct sock *sk)
2318 {
2319         struct tls_context *tls_ctx = tls_get_ctx(sk);
2320         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2321         bool ingress_empty = true;
2322         struct sk_psock *psock;
2323
2324         rcu_read_lock();
2325         psock = sk_psock(sk);
2326         if (psock)
2327                 ingress_empty = list_empty(&psock->ingress_msg);
2328         rcu_read_unlock();
2329
2330         return !ingress_empty || tls_strp_msg_ready(ctx) ||
2331                 !skb_queue_empty(&ctx->rx_list);
2332 }
2333
2334 int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb)
2335 {
2336         struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
2337         struct tls_prot_info *prot = &tls_ctx->prot_info;
2338         char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
2339         size_t cipher_overhead;
2340         size_t data_len = 0;
2341         int ret;
2342
2343         /* Verify that we have a full TLS header, or wait for more data */
2344         if (strp->stm.offset + prot->prepend_size > skb->len)
2345                 return 0;
2346
2347         /* Sanity-check size of on-stack buffer. */
2348         if (WARN_ON(prot->prepend_size > sizeof(header))) {
2349                 ret = -EINVAL;
2350                 goto read_failure;
2351         }
2352
2353         /* Linearize header to local buffer */
2354         ret = skb_copy_bits(skb, strp->stm.offset, header, prot->prepend_size);
2355         if (ret < 0)
2356                 goto read_failure;
2357
2358         strp->mark = header[0];
2359
2360         data_len = ((header[4] & 0xFF) | (header[3] << 8));
2361
2362         cipher_overhead = prot->tag_size;
2363         if (prot->version != TLS_1_3_VERSION &&
2364             prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
2365                 cipher_overhead += prot->iv_size;
2366
2367         if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
2368             prot->tail_size) {
2369                 ret = -EMSGSIZE;
2370                 goto read_failure;
2371         }
2372         if (data_len < cipher_overhead) {
2373                 ret = -EBADMSG;
2374                 goto read_failure;
2375         }
2376
2377         /* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */
2378         if (header[1] != TLS_1_2_VERSION_MINOR ||
2379             header[2] != TLS_1_2_VERSION_MAJOR) {
2380                 ret = -EINVAL;
2381                 goto read_failure;
2382         }
2383
2384         tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
2385                                      TCP_SKB_CB(skb)->seq + strp->stm.offset);
2386         return data_len + TLS_HEADER_SIZE;
2387
2388 read_failure:
2389         tls_err_abort(strp->sk, ret);
2390
2391         return ret;
2392 }
2393
2394 void tls_rx_msg_ready(struct tls_strparser *strp)
2395 {
2396         struct tls_sw_context_rx *ctx;
2397
2398         ctx = container_of(strp, struct tls_sw_context_rx, strp);
2399         ctx->saved_data_ready(strp->sk);
2400 }
2401
2402 static void tls_data_ready(struct sock *sk)
2403 {
2404         struct tls_context *tls_ctx = tls_get_ctx(sk);
2405         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2406         struct sk_psock *psock;
2407         gfp_t alloc_save;
2408
2409         trace_sk_data_ready(sk);
2410
2411         alloc_save = sk->sk_allocation;
2412         sk->sk_allocation = GFP_ATOMIC;
2413         tls_strp_data_ready(&ctx->strp);
2414         sk->sk_allocation = alloc_save;
2415
2416         psock = sk_psock_get(sk);
2417         if (psock) {
2418                 if (!list_empty(&psock->ingress_msg))
2419                         ctx->saved_data_ready(sk);
2420                 sk_psock_put(sk, psock);
2421         }
2422 }
2423
2424 void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
2425 {
2426         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2427
2428         set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
2429         set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
2430         cancel_delayed_work_sync(&ctx->tx_work.work);
2431 }
2432
2433 void tls_sw_release_resources_tx(struct sock *sk)
2434 {
2435         struct tls_context *tls_ctx = tls_get_ctx(sk);
2436         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2437         struct tls_rec *rec, *tmp;
2438         int pending;
2439
2440         /* Wait for any pending async encryptions to complete */
2441         spin_lock_bh(&ctx->encrypt_compl_lock);
2442         ctx->async_notify = true;
2443         pending = atomic_read(&ctx->encrypt_pending);
2444         spin_unlock_bh(&ctx->encrypt_compl_lock);
2445
2446         if (pending)
2447                 crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2448
2449         tls_tx_records(sk, -1);
2450
2451         /* Free up un-sent records in tx_list. First, free
2452          * the partially sent record if any at head of tx_list.
2453          */
2454         if (tls_ctx->partially_sent_record) {
2455                 tls_free_partial_record(sk, tls_ctx);
2456                 rec = list_first_entry(&ctx->tx_list,
2457                                        struct tls_rec, list);
2458                 list_del(&rec->list);
2459                 sk_msg_free(sk, &rec->msg_plaintext);
2460                 kfree(rec);
2461         }
2462
2463         list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
2464                 list_del(&rec->list);
2465                 sk_msg_free(sk, &rec->msg_encrypted);
2466                 sk_msg_free(sk, &rec->msg_plaintext);
2467                 kfree(rec);
2468         }
2469
2470         crypto_free_aead(ctx->aead_send);
2471         tls_free_open_rec(sk);
2472 }
2473
/* Free the software TX context attached to @tls_ctx. */
void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
{
	kfree(tls_sw_ctx_tx(tls_ctx));
}
2480
2481 void tls_sw_release_resources_rx(struct sock *sk)
2482 {
2483         struct tls_context *tls_ctx = tls_get_ctx(sk);
2484         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2485
2486         kfree(tls_ctx->rx.rec_seq);
2487         kfree(tls_ctx->rx.iv);
2488
2489         if (ctx->aead_recv) {
2490                 __skb_queue_purge(&ctx->rx_list);
2491                 crypto_free_aead(ctx->aead_recv);
2492                 tls_strp_stop(&ctx->strp);
2493                 /* If tls_sw_strparser_arm() was not called (cleanup paths)
2494                  * we still want to tls_strp_stop(), but sk->sk_data_ready was
2495                  * never swapped.
2496                  */
2497                 if (ctx->saved_data_ready) {
2498                         write_lock_bh(&sk->sk_callback_lock);
2499                         sk->sk_data_ready = ctx->saved_data_ready;
2500                         write_unlock_bh(&sk->sk_callback_lock);
2501                 }
2502         }
2503 }
2504
2505 void tls_sw_strparser_done(struct tls_context *tls_ctx)
2506 {
2507         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2508
2509         tls_strp_done(&ctx->strp);
2510 }
2511
/* Free the software RX context attached to @tls_ctx. */
void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
{
	kfree(tls_sw_ctx_rx(tls_ctx));
}
2518
/* Release the software RX resources of @sk and then free the RX context
 * itself.
 */
void tls_sw_free_resources_rx(struct sock *sk)
{
	/* Look the context up before anything is torn down. */
	struct tls_context *ctx = tls_get_ctx(sk);

	tls_sw_release_resources_rx(sk);
	tls_sw_free_ctx_rx(ctx);
}
2526
2527 /* The work handler to transmitt the encrypted records in tx_list */
2528 static void tx_work_handler(struct work_struct *work)
2529 {
2530         struct delayed_work *delayed_work = to_delayed_work(work);
2531         struct tx_work *tx_work = container_of(delayed_work,
2532                                                struct tx_work, work);
2533         struct sock *sk = tx_work->sk;
2534         struct tls_context *tls_ctx = tls_get_ctx(sk);
2535         struct tls_sw_context_tx *ctx;
2536
2537         if (unlikely(!tls_ctx))
2538                 return;
2539
2540         ctx = tls_sw_ctx_tx(tls_ctx);
2541         if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
2542                 return;
2543
2544         if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
2545                 return;
2546
2547         if (mutex_trylock(&tls_ctx->tx_lock)) {
2548                 lock_sock(sk);
2549                 tls_tx_records(sk, -1);
2550                 release_sock(sk);
2551                 mutex_unlock(&tls_ctx->tx_lock);
2552         } else if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
2553                 /* Someone is holding the tx_lock, they will likely run Tx
2554                  * and cancel the work on their way out of the lock section.
2555                  * Schedule a long delay just in case.
2556                  */
2557                 schedule_delayed_work(&ctx->tx_work.work, msecs_to_jiffies(10));
2558         }
2559 }
2560
2561 static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx)
2562 {
2563         struct tls_rec *rec;
2564
2565         rec = list_first_entry_or_null(&ctx->tx_list, struct tls_rec, list);
2566         if (!rec)
2567                 return false;
2568
2569         return READ_ONCE(rec->tx_ready);
2570 }
2571
2572 void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
2573 {
2574         struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
2575
2576         /* Schedule the transmission if tx list is ready */
2577         if (tls_is_tx_ready(tx_ctx) &&
2578             !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
2579                 schedule_delayed_work(&tx_ctx->tx_work.work, 0);
2580 }
2581
2582 void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
2583 {
2584         struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2585
2586         write_lock_bh(&sk->sk_callback_lock);
2587         rx_ctx->saved_data_ready = sk->sk_data_ready;
2588         sk->sk_data_ready = tls_data_ready;
2589         write_unlock_bh(&sk->sk_callback_lock);
2590 }
2591
2592 void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
2593 {
2594         struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2595
2596         rx_ctx->zc_capable = tls_ctx->rx_no_pad ||
2597                 tls_ctx->prot_info.version != TLS_1_3_VERSION;
2598 }
2599
2600 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
2601 {
2602         struct tls_context *tls_ctx = tls_get_ctx(sk);
2603         struct tls_prot_info *prot = &tls_ctx->prot_info;
2604         struct tls_crypto_info *crypto_info;
2605         struct tls_sw_context_tx *sw_ctx_tx = NULL;
2606         struct tls_sw_context_rx *sw_ctx_rx = NULL;
2607         struct cipher_context *cctx;
2608         struct crypto_aead **aead;
2609         struct crypto_tfm *tfm;
2610         char *iv, *rec_seq, *key, *salt;
2611         const struct tls_cipher_desc *cipher_desc;
2612         u16 nonce_size;
2613         int rc = 0;
2614
2615         if (!ctx) {
2616                 rc = -EINVAL;
2617                 goto out;
2618         }
2619
2620         if (tx) {
2621                 if (!ctx->priv_ctx_tx) {
2622                         sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
2623                         if (!sw_ctx_tx) {
2624                                 rc = -ENOMEM;
2625                                 goto out;
2626                         }
2627                         ctx->priv_ctx_tx = sw_ctx_tx;
2628                 } else {
2629                         sw_ctx_tx =
2630                                 (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
2631                 }
2632         } else {
2633                 if (!ctx->priv_ctx_rx) {
2634                         sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
2635                         if (!sw_ctx_rx) {
2636                                 rc = -ENOMEM;
2637                                 goto out;
2638                         }
2639                         ctx->priv_ctx_rx = sw_ctx_rx;
2640                 } else {
2641                         sw_ctx_rx =
2642                                 (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
2643                 }
2644         }
2645
2646         if (tx) {
2647                 crypto_init_wait(&sw_ctx_tx->async_wait);
2648                 spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
2649                 crypto_info = &ctx->crypto_send.info;
2650                 cctx = &ctx->tx;
2651                 aead = &sw_ctx_tx->aead_send;
2652                 INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
2653                 INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
2654                 sw_ctx_tx->tx_work.sk = sk;
2655         } else {
2656                 crypto_init_wait(&sw_ctx_rx->async_wait);
2657                 spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
2658                 init_waitqueue_head(&sw_ctx_rx->wq);
2659                 crypto_info = &ctx->crypto_recv.info;
2660                 cctx = &ctx->rx;
2661                 skb_queue_head_init(&sw_ctx_rx->rx_list);
2662                 skb_queue_head_init(&sw_ctx_rx->async_hold);
2663                 aead = &sw_ctx_rx->aead_recv;
2664         }
2665
2666         cipher_desc = get_cipher_desc(crypto_info->cipher_type);
2667         if (!cipher_desc) {
2668                 rc = -EINVAL;
2669                 goto free_priv;
2670         }
2671
2672         nonce_size = cipher_desc->nonce;
2673
2674         iv = crypto_info_iv(crypto_info, cipher_desc);
2675         key = crypto_info_key(crypto_info, cipher_desc);
2676         salt = crypto_info_salt(crypto_info, cipher_desc);
2677         rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc);
2678
2679         if (crypto_info->version == TLS_1_3_VERSION) {
2680                 nonce_size = 0;
2681                 prot->aad_size = TLS_HEADER_SIZE;
2682                 prot->tail_size = 1;
2683         } else {
2684                 prot->aad_size = TLS_AAD_SPACE_SIZE;
2685                 prot->tail_size = 0;
2686         }
2687
2688         /* Sanity-check the sizes for stack allocations. */
2689         if (nonce_size > MAX_IV_SIZE || prot->aad_size > TLS_MAX_AAD_SIZE) {
2690                 rc = -EINVAL;
2691                 goto free_priv;
2692         }
2693
2694         prot->version = crypto_info->version;
2695         prot->cipher_type = crypto_info->cipher_type;
2696         prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
2697         prot->tag_size = cipher_desc->tag;
2698         prot->overhead_size = prot->prepend_size +
2699                               prot->tag_size + prot->tail_size;
2700         prot->iv_size = cipher_desc->iv;
2701         prot->salt_size = cipher_desc->salt;
2702         cctx->iv = kmalloc(cipher_desc->iv + cipher_desc->salt, GFP_KERNEL);
2703         if (!cctx->iv) {
2704                 rc = -ENOMEM;
2705                 goto free_priv;
2706         }
2707         /* Note: 128 & 256 bit salt are the same size */
2708         prot->rec_seq_size = cipher_desc->rec_seq;
2709         memcpy(cctx->iv, salt, cipher_desc->salt);
2710         memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv);
2711
2712         cctx->rec_seq = kmemdup(rec_seq, cipher_desc->rec_seq, GFP_KERNEL);
2713         if (!cctx->rec_seq) {
2714                 rc = -ENOMEM;
2715                 goto free_iv;
2716         }
2717
2718         if (!*aead) {
2719                 *aead = crypto_alloc_aead(cipher_desc->cipher_name, 0, 0);
2720                 if (IS_ERR(*aead)) {
2721                         rc = PTR_ERR(*aead);
2722                         *aead = NULL;
2723                         goto free_rec_seq;
2724                 }
2725         }
2726
2727         ctx->push_pending_record = tls_sw_push_pending_record;
2728
2729         rc = crypto_aead_setkey(*aead, key, cipher_desc->key);
2730         if (rc)
2731                 goto free_aead;
2732
2733         rc = crypto_aead_setauthsize(*aead, prot->tag_size);
2734         if (rc)
2735                 goto free_aead;
2736
2737         if (sw_ctx_rx) {
2738                 tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
2739
2740                 tls_update_rx_zc_capable(ctx);
2741                 sw_ctx_rx->async_capable =
2742                         crypto_info->version != TLS_1_3_VERSION &&
2743                         !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
2744
2745                 rc = tls_strp_init(&sw_ctx_rx->strp, sk);
2746                 if (rc)
2747                         goto free_aead;
2748         }
2749
2750         goto out;
2751
2752 free_aead:
2753         crypto_free_aead(*aead);
2754         *aead = NULL;
2755 free_rec_seq:
2756         kfree(cctx->rec_seq);
2757         cctx->rec_seq = NULL;
2758 free_iv:
2759         kfree(cctx->iv);
2760         cctx->iv = NULL;
2761 free_priv:
2762         if (tx) {
2763                 kfree(ctx->priv_ctx_tx);
2764                 ctx->priv_ctx_tx = NULL;
2765         } else {
2766                 kfree(ctx->priv_ctx_rx);
2767                 ctx->priv_ctx_rx = NULL;
2768         }
2769 out:
2770         return rc;
2771 }