/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <crypto/aead.h>
#include <linux/highmem.h>
#include <linux/module.h>
#include <linux/netdevice.h>
#include <net/dst.h>
#include <net/inet_connection_sock.h>
#include <net/tcp.h>
#include <net/tls.h>

#include "tls.h"
#include "trace.h"

/* device_offload_lock is used to synchronize tls_dev_add
 * against NETDEV_DOWN notifications.
 */
static DECLARE_RWSEM(device_offload_lock);

static struct workqueue_struct *destruct_wq __read_mostly;

static LIST_HEAD(tls_device_list);
static LIST_HEAD(tls_device_down_list);
static DEFINE_SPINLOCK(tls_device_lock);

static struct page *dummy_page;

static void tls_device_free_ctx(struct tls_context *ctx)
{
        if (ctx->tx_conf == TLS_HW) {
                kfree(tls_offload_ctx_tx(ctx));
                kfree(ctx->tx.rec_seq);
                kfree(ctx->tx.iv);
        }

        if (ctx->rx_conf == TLS_HW)
                kfree(tls_offload_ctx_rx(ctx));

        tls_ctx_free(NULL, ctx);
}

static void tls_device_tx_del_task(struct work_struct *work)
{
        struct tls_offload_context_tx *offload_ctx =
                container_of(work, struct tls_offload_context_tx, destruct_work);
        struct tls_context *ctx = offload_ctx->ctx;
        struct net_device *netdev;

        /* Safe, because this is the destroy flow, refcount is 0, so
         * tls_device_down can't store this field in parallel.
         */
        netdev = rcu_dereference_protected(ctx->netdev,
                                           !refcount_read(&ctx->refcount));

        netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX);
        dev_put(netdev);
        ctx->netdev = NULL;
        tls_device_free_ctx(ctx);
}

static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
{
        struct net_device *netdev;
        unsigned long flags;
        bool async_cleanup;

        spin_lock_irqsave(&tls_device_lock, flags);
        if (unlikely(!refcount_dec_and_test(&ctx->refcount))) {
                spin_unlock_irqrestore(&tls_device_lock, flags);
                return;
        }

        list_del(&ctx->list); /* Remove from tls_device_list / tls_device_down_list */

        /* Safe, because this is the destroy flow, refcount is 0, so
         * tls_device_down can't store this field in parallel.
         */
        netdev = rcu_dereference_protected(ctx->netdev,
                                           !refcount_read(&ctx->refcount));

        async_cleanup = netdev && ctx->tx_conf == TLS_HW;
        if (async_cleanup) {
                struct tls_offload_context_tx *offload_ctx = tls_offload_ctx_tx(ctx);

                /* queue_work inside the spinlock
                 * to make sure tls_device_down waits for that work.
                 */
                queue_work(destruct_wq, &offload_ctx->destruct_work);
        }
        spin_unlock_irqrestore(&tls_device_lock, flags);

        if (!async_cleanup)
                tls_device_free_ctx(ctx);
}

/* We assume that the socket is already connected */
static struct net_device *get_netdev_for_sock(struct sock *sk)
{
        struct dst_entry *dst = sk_dst_get(sk);
        struct net_device *netdev = NULL;

        if (likely(dst)) {
                netdev = netdev_sk_get_lowest_dev(dst->dev, sk);
                dev_hold(netdev);
        }

        dst_release(dst);

        return netdev;
}

static void destroy_record(struct tls_record_info *record)
{
        int i;

        for (i = 0; i < record->num_frags; i++)
                __skb_frag_unref(&record->frags[i], false);
        kfree(record);
}

static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
{
        struct tls_record_info *info, *temp;

        list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
                list_del(&info->list);
                destroy_record(info);
        }

        offload_ctx->retransmit_hint = NULL;
}

static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_record_info *info, *temp;
        struct tls_offload_context_tx *ctx;
        u64 deleted_records = 0;
        unsigned long flags;

        if (!tls_ctx)
                return;

        ctx = tls_offload_ctx_tx(tls_ctx);

        spin_lock_irqsave(&ctx->lock, flags);
        info = ctx->retransmit_hint;
        if (info && !before(acked_seq, info->end_seq))
                ctx->retransmit_hint = NULL;

        list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
                if (before(acked_seq, info->end_seq))
                        break;
                list_del(&info->list);

                destroy_record(info);
                deleted_records++;
        }

        ctx->unacked_record_sn += deleted_records;
        spin_unlock_irqrestore(&ctx->lock, flags);
}

/* At this point, there should be no references on this
 * socket and no in-flight SKBs associated with this
 * socket, so it is safe to free all the resources.
 */
void tls_device_sk_destruct(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);

        tls_ctx->sk_destruct(sk);

        if (tls_ctx->tx_conf == TLS_HW) {
                if (ctx->open_record)
                        destroy_record(ctx->open_record);
                delete_all_records(ctx);
                crypto_free_aead(ctx->aead_send);
                clean_acked_data_disable(inet_csk(sk));
        }

        tls_device_queue_ctx_destruction(tls_ctx);
}
EXPORT_SYMBOL_GPL(tls_device_sk_destruct);

void tls_device_free_resources_tx(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);

        tls_free_partial_record(sk, tls_ctx);
}

void tls_offload_tx_resync_request(struct sock *sk, u32 got_seq, u32 exp_seq)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);

        trace_tls_device_tx_resync_req(sk, got_seq, exp_seq);
        WARN_ON(test_and_set_bit(TLS_TX_SYNC_SCHED, &tls_ctx->flags));
}
EXPORT_SYMBOL_GPL(tls_offload_tx_resync_request);

static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx,
                                 u32 seq)
{
        struct net_device *netdev;
        struct sk_buff *skb;
        int err = 0;
        u8 *rcd_sn;

        skb = tcp_write_queue_tail(sk);
        if (skb)
                TCP_SKB_CB(skb)->eor = 1;

        rcd_sn = tls_ctx->tx.rec_seq;

        trace_tls_device_tx_resync_send(sk, seq, rcd_sn);
        down_read(&device_offload_lock);
        netdev = rcu_dereference_protected(tls_ctx->netdev,
                                           lockdep_is_held(&device_offload_lock));
        if (netdev)
                err = netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq,
                                                         rcd_sn,
                                                         TLS_OFFLOAD_CTX_DIR_TX);
        up_read(&device_offload_lock);
        if (err)
                return;

        clear_bit_unlock(TLS_TX_SYNC_SCHED, &tls_ctx->flags);
}

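/* Append @size bytes taken from @pfrag to the open @record. If the new data
 * is contiguous with the record's last fragment (same page, adjacent offset)
 * that fragment is simply grown; otherwise a new fragment is started and an
 * extra page reference is taken.
 */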
static void tls_append_frag(struct tls_record_info *record,
                            struct page_frag *pfrag,
                            int size)
{
        skb_frag_t *frag;

        frag = &record->frags[record->num_frags - 1];
        if (skb_frag_page(frag) == pfrag->page &&
            skb_frag_off(frag) + skb_frag_size(frag) == pfrag->offset) {
                skb_frag_size_add(frag, size);
        } else {
                ++frag;
                skb_frag_fill_page_desc(frag, pfrag->page, pfrag->offset,
                                        size);
                ++record->num_frags;
                get_page(pfrag->page);
        }

        pfrag->offset += size;
        record->len += size;
}

static int tls_push_record(struct sock *sk,
                           struct tls_context *ctx,
                           struct tls_offload_context_tx *offload_ctx,
                           struct tls_record_info *record,
                           int flags)
{
        struct tls_prot_info *prot = &ctx->prot_info;
        struct tcp_sock *tp = tcp_sk(sk);
        skb_frag_t *frag;
        int i;

        record->end_seq = tp->write_seq + record->len;
        list_add_tail_rcu(&record->list, &offload_ctx->records_list);
        offload_ctx->open_record = NULL;

        if (test_bit(TLS_TX_SYNC_SCHED, &ctx->flags))
                tls_device_resync_tx(sk, ctx, tp->write_seq);

        tls_advance_record_sn(sk, prot, &ctx->tx);

        for (i = 0; i < record->num_frags; i++) {
                frag = &record->frags[i];
                sg_unmark_end(&offload_ctx->sg_tx_data[i]);
                sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
                            skb_frag_size(frag), skb_frag_off(frag));
                sk_mem_charge(sk, skb_frag_size(frag));
                get_page(skb_frag_page(frag));
        }
        sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);

        /* all ready, send */
        return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
}

static void tls_device_record_close(struct sock *sk,
                                    struct tls_context *ctx,
                                    struct tls_record_info *record,
                                    struct page_frag *pfrag,
                                    unsigned char record_type)
{
        struct tls_prot_info *prot = &ctx->prot_info;
        struct page_frag dummy_tag_frag;

        /* append tag
         * device will fill in the tag, we just need to append a placeholder
         * use socket memory to improve coalescing (re-using a single buffer
         * increases frag count)
         * if we can't allocate memory now use the dummy page
         */
        if (unlikely(pfrag->size - pfrag->offset < prot->tag_size) &&
            !skb_page_frag_refill(prot->tag_size, pfrag, sk->sk_allocation)) {
                dummy_tag_frag.page = dummy_page;
                dummy_tag_frag.offset = 0;
                pfrag = &dummy_tag_frag;
        }
        tls_append_frag(record, pfrag, prot->tag_size);

        /* fill prepend */
        tls_fill_prepend(ctx, skb_frag_address(&record->frags[0]),
                         record->len - prot->overhead_size,
                         record_type);
}

static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
                                 struct page_frag *pfrag,
                                 size_t prepend_size)
{
        struct tls_record_info *record;
        skb_frag_t *frag;

        record = kmalloc(sizeof(*record), GFP_KERNEL);
        if (!record)
                return -ENOMEM;

        frag = &record->frags[0];
        skb_frag_fill_page_desc(frag, pfrag->page, pfrag->offset,
                                prepend_size);

        get_page(pfrag->page);
        pfrag->offset += prepend_size;

        record->num_frags = 1;
        record->len = prepend_size;
        offload_ctx->open_record = record;
        return 0;
}

static int tls_do_allocation(struct sock *sk,
                             struct tls_offload_context_tx *offload_ctx,
                             struct page_frag *pfrag,
                             size_t prepend_size)
{
        int ret;

        if (!offload_ctx->open_record) {
                if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
                                                   sk->sk_allocation))) {
                        READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
                        sk_stream_moderate_sndbuf(sk);
                        return -ENOMEM;
                }

                ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
                if (ret)
                        return ret;

                if (pfrag->size > pfrag->offset)
                        return 0;
        }

        if (!sk_page_frag_refill(sk, pfrag))
                return -ENOMEM;

        return 0;
}

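/* Copy @bytes from the iterator to @addr. The cache-line-aligned bulk of the
 * buffer is copied with copy_from_iter_nocache() so the payload does not
 * displace useful CPU cache contents; the unaligned head and tail are copied
 * with a regular copy_from_iter().
 */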
static int tls_device_copy_data(void *addr, size_t bytes, struct iov_iter *i)
{
        size_t pre_copy, nocache;

        pre_copy = ~((unsigned long)addr - 1) & (SMP_CACHE_BYTES - 1);
        if (pre_copy) {
                pre_copy = min(pre_copy, bytes);
                if (copy_from_iter(addr, pre_copy, i) != pre_copy)
                        return -EFAULT;
                bytes -= pre_copy;
                addr += pre_copy;
        }

        nocache = round_down(bytes, SMP_CACHE_BYTES);
        if (copy_from_iter_nocache(addr, nocache, i) != nocache)
                return -EFAULT;
        bytes -= nocache;
        addr += nocache;

        if (bytes && copy_from_iter(addr, bytes, i) != bytes)
                return -EFAULT;

        return 0;
}

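/* Core TX path for device offload: append data from @iter to the open
 * record (extracting pages directly when MSG_SPLICE_PAGES is set), and once
 * a record is full or the message ends, close it with a placeholder tag and
 * TLS header and pass the plaintext fragments to tls_push_sg() for
 * transmission.
 */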
static int tls_push_data(struct sock *sk,
                         struct iov_iter *iter,
                         size_t size, int flags,
                         unsigned char record_type)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_prot_info *prot = &tls_ctx->prot_info;
        struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
        struct tls_record_info *record;
        int tls_push_record_flags;
        struct page_frag *pfrag;
        size_t orig_size = size;
        u32 max_open_record_len;
        bool more = false;
        bool done = false;
        int copy, rc = 0;
        long timeo;

        if (flags &
            ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
              MSG_SPLICE_PAGES | MSG_EOR))
                return -EOPNOTSUPP;

        if ((flags & (MSG_MORE | MSG_EOR)) == (MSG_MORE | MSG_EOR))
                return -EINVAL;

        if (unlikely(sk->sk_err))
                return -sk->sk_err;

        flags |= MSG_SENDPAGE_DECRYPTED;
        tls_push_record_flags = flags | MSG_MORE;

        timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
        if (tls_is_partially_sent_record(tls_ctx)) {
                rc = tls_push_partial_record(sk, tls_ctx, flags);
                if (rc < 0)
                        return rc;
        }

        pfrag = sk_page_frag(sk);

        /* TLS_HEADER_SIZE is not counted as part of the TLS record, and
         * we need to leave room for an authentication tag.
         */
        max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
                              prot->prepend_size;
        do {
                rc = tls_do_allocation(sk, ctx, pfrag, prot->prepend_size);
                if (unlikely(rc)) {
                        rc = sk_stream_wait_memory(sk, &timeo);
                        if (!rc)
                                continue;

                        record = ctx->open_record;
                        if (!record)
                                break;
handle_error:
                        if (record_type != TLS_RECORD_TYPE_DATA) {
                                /* avoid sending partial
                                 * record with type !=
                                 * application_data
                                 */
                                size = orig_size;
                                destroy_record(record);
                                ctx->open_record = NULL;
                        } else if (record->len > prot->prepend_size) {
                                goto last_record;
                        }

                        break;
                }

                record = ctx->open_record;

                copy = min_t(size_t, size, max_open_record_len - record->len);
                if (copy && (flags & MSG_SPLICE_PAGES)) {
                        struct page_frag zc_pfrag;
                        struct page **pages = &zc_pfrag.page;
                        size_t off;

                        rc = iov_iter_extract_pages(iter, &pages,
                                                    copy, 1, 0, &off);
                        if (rc <= 0) {
                                if (rc == 0)
                                        rc = -EIO;
                                goto handle_error;
                        }
                        copy = rc;

                        if (WARN_ON_ONCE(!sendpage_ok(zc_pfrag.page))) {
                                iov_iter_revert(iter, copy);
                                rc = -EIO;
                                goto handle_error;
                        }

                        zc_pfrag.offset = off;
                        zc_pfrag.size = copy;
                        tls_append_frag(record, &zc_pfrag, copy);
                } else if (copy) {
                        copy = min_t(size_t, copy, pfrag->size - pfrag->offset);

                        rc = tls_device_copy_data(page_address(pfrag->page) +
                                                  pfrag->offset, copy,
                                                  iter);
                        if (rc)
                                goto handle_error;
                        tls_append_frag(record, pfrag, copy);
                }

                size -= copy;
                if (!size) {
last_record:
                        tls_push_record_flags = flags;
                        if (flags & MSG_MORE) {
                                more = true;
                                break;
                        }

                        done = true;
                }

                if (done || record->len >= max_open_record_len ||
                    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
                        tls_device_record_close(sk, tls_ctx, record,
                                                pfrag, record_type);

                        rc = tls_push_record(sk,
                                             tls_ctx,
                                             ctx,
                                             record,
                                             tls_push_record_flags);
                        if (rc < 0)
                                break;
                }
        } while (!done);

        tls_ctx->pending_open_record_frags = more;

        if (orig_size - size > 0)
                rc = orig_size - size;

        return rc;
}

int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
        unsigned char record_type = TLS_RECORD_TYPE_DATA;
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        int rc;

        if (!tls_ctx->zerocopy_sendfile)
                msg->msg_flags &= ~MSG_SPLICE_PAGES;

        mutex_lock(&tls_ctx->tx_lock);
        lock_sock(sk);

        if (unlikely(msg->msg_controllen)) {
                rc = tls_process_cmsg(sk, msg, &record_type);
                if (rc)
                        goto out;
        }

        rc = tls_push_data(sk, &msg->msg_iter, size, msg->msg_flags,
                           record_type);

out:
        release_sock(sk);
        mutex_unlock(&tls_ctx->tx_lock);
        return rc;
}

void tls_device_splice_eof(struct socket *sock)
{
        struct sock *sk = sock->sk;
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct iov_iter iter = {};

        if (!tls_is_partially_sent_record(tls_ctx))
                return;

        mutex_lock(&tls_ctx->tx_lock);
        lock_sock(sk);

        if (tls_is_partially_sent_record(tls_ctx)) {
                iov_iter_bvec(&iter, ITER_SOURCE, NULL, 0, 0);
                tls_push_data(sk, &iter, 0, 0, TLS_RECORD_TYPE_DATA);
        }

        release_sock(sk);
        mutex_unlock(&tls_ctx->tx_lock);
}

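/* Find the TLS record containing TCP sequence number @seq and return its
 * record sequence number via @p_record_sn. The lookup starts from
 * retransmit_hint when the hint is still relevant, otherwise from the head
 * of the records list; NULL is returned when @seq is not covered by any
 * record.
 */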
struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
                                       u32 seq, u64 *p_record_sn)
{
        u64 record_sn = context->hint_record_sn;
        struct tls_record_info *info, *last;

        info = context->retransmit_hint;
        if (!info ||
            before(seq, info->end_seq - info->len)) {
                /* if retransmit_hint is irrelevant start
                 * from the beginning of the list
                 */
                info = list_first_entry_or_null(&context->records_list,
                                                struct tls_record_info, list);
                if (!info)
                        return NULL;
                /* send the start_marker record if the seq number is before the
                 * tls offload start marker sequence number. This record is
                 * required to handle TCP packets which are before TLS offload
                 * started.
                 * And if it's not the start marker, check whether this seq
                 * number belongs to the list.
                 */
                if (likely(!tls_record_is_start_marker(info))) {
                        /* we have the first record, get the last record to see
                         * if this seq number belongs to the list.
                         */
                        last = list_last_entry(&context->records_list,
                                               struct tls_record_info, list);

                        if (!between(seq, tls_record_start_seq(info),
                                     last->end_seq))
                                return NULL;
                }
                record_sn = context->unacked_record_sn;
        }

        /* We just need the _rcu for the READ_ONCE() */
        rcu_read_lock();
        list_for_each_entry_from_rcu(info, &context->records_list, list) {
                if (before(seq, info->end_seq)) {
                        if (!context->retransmit_hint ||
                            after(info->end_seq,
                                  context->retransmit_hint->end_seq)) {
                                context->hint_record_sn = record_sn;
                                context->retransmit_hint = info;
                        }
                        *p_record_sn = record_sn;
                        goto exit_rcu_unlock;
                }
                record_sn++;
        }
        info = NULL;

exit_rcu_unlock:
        rcu_read_unlock();
        return info;
}
EXPORT_SYMBOL(tls_get_record);

static int tls_device_push_pending_record(struct sock *sk, int flags)
{
        struct iov_iter iter;

        iov_iter_kvec(&iter, ITER_SOURCE, NULL, 0, 0);
        return tls_push_data(sk, &iter, 0, flags, TLS_RECORD_TYPE_DATA);
}

void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
{
        if (tls_is_partially_sent_record(ctx)) {
                gfp_t sk_allocation = sk->sk_allocation;

                WARN_ON_ONCE(sk->sk_write_pending);

                sk->sk_allocation = GFP_ATOMIC;
                tls_push_partial_record(sk, ctx,
                                        MSG_DONTWAIT | MSG_NOSIGNAL |
                                        MSG_SENDPAGE_DECRYPTED);
                sk->sk_allocation = sk_allocation;
        }
}

static void tls_device_resync_rx(struct tls_context *tls_ctx,
                                 struct sock *sk, u32 seq, u8 *rcd_sn)
{
        struct tls_offload_context_rx *rx_ctx = tls_offload_ctx_rx(tls_ctx);
        struct net_device *netdev;

        trace_tls_device_rx_resync_send(sk, seq, rcd_sn, rx_ctx->resync_type);
        rcu_read_lock();
        netdev = rcu_dereference(tls_ctx->netdev);
        if (netdev)
                netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
                                                   TLS_OFFLOAD_CTX_DIR_RX);
        rcu_read_unlock();
        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICERESYNC);
}

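/* Driver-requested asynchronous RX resync. While the request is flagged
 * async, log the sequence numbers of record headers falling in the requested
 * range and count the records seen in rcd_delta. When the driver later sends
 * the synchronous request, match it against the log and return the corrected
 * sequence number and record-sequence delta to apply.
 */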
static bool
tls_device_rx_resync_async(struct tls_offload_resync_async *resync_async,
                           s64 resync_req, u32 *seq, u16 *rcd_delta)
{
        u32 is_async = resync_req & RESYNC_REQ_ASYNC;
        u32 req_seq = resync_req >> 32;
        u32 req_end = req_seq + ((resync_req >> 16) & 0xffff);
        u16 i;

        *rcd_delta = 0;

        if (is_async) {
                /* shouldn't get to wraparound:
                 * too long in async stage, something bad happened
                 */
                if (WARN_ON_ONCE(resync_async->rcd_delta == USHRT_MAX))
                        return false;

                /* asynchronous stage: log all headers seq such that
                 * req_seq <= seq <= end_seq, and wait for real resync request
                 */
                if (before(*seq, req_seq))
                        return false;
                if (!after(*seq, req_end) &&
                    resync_async->loglen < TLS_DEVICE_RESYNC_ASYNC_LOGMAX)
                        resync_async->log[resync_async->loglen++] = *seq;

                resync_async->rcd_delta++;

                return false;
        }

        /* synchronous stage: check against the logged entries and
         * proceed to check the next entries if no match was found
         */
        for (i = 0; i < resync_async->loglen; i++)
                if (req_seq == resync_async->log[i] &&
                    atomic64_try_cmpxchg(&resync_async->req, &resync_req, 0)) {
                        *rcd_delta = resync_async->rcd_delta - i;
                        *seq = req_seq;
                        resync_async->loglen = 0;
                        resync_async->rcd_delta = 0;
                        return true;
                }

        resync_async->loglen = 0;
        resync_async->rcd_delta = 0;

        if (req_seq == *seq &&
            atomic64_try_cmpxchg(&resync_async->req,
                                 &resync_req, 0))
                return true;

        return false;
}

void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_offload_context_rx *rx_ctx;
        u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
        u32 sock_data, is_req_pending;
        struct tls_prot_info *prot;
        s64 resync_req;
        u16 rcd_delta;
        u32 req_seq;

        if (tls_ctx->rx_conf != TLS_HW)
                return;
        if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags)))
                return;

        prot = &tls_ctx->prot_info;
        rx_ctx = tls_offload_ctx_rx(tls_ctx);
        memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);

        switch (rx_ctx->resync_type) {
        case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ:
                resync_req = atomic64_read(&rx_ctx->resync_req);
                req_seq = resync_req >> 32;
                seq += TLS_HEADER_SIZE - 1;
                is_req_pending = resync_req;

                if (likely(!is_req_pending) || req_seq != seq ||
                    !atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
                        return;
                break;
        case TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT:
                if (likely(!rx_ctx->resync_nh_do_now))
                        return;

                /* head of next rec is already in, note that the sock_inq will
                 * include the currently parsed message when called from parser
                 */
                sock_data = tcp_inq(sk);
                if (sock_data > rcd_len) {
                        trace_tls_device_rx_resync_nh_delay(sk, sock_data,
                                                            rcd_len);
                        return;
                }

                rx_ctx->resync_nh_do_now = 0;
                seq += rcd_len;
                tls_bigint_increment(rcd_sn, prot->rec_seq_size);
                break;
        case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ_ASYNC:
                resync_req = atomic64_read(&rx_ctx->resync_async->req);
                is_req_pending = resync_req;
                if (likely(!is_req_pending))
                        return;

                if (!tls_device_rx_resync_async(rx_ctx->resync_async,
                                                resync_req, &seq, &rcd_delta))
                        return;
                tls_bigint_subtract(rcd_sn, rcd_delta);
                break;
        }

        tls_device_resync_rx(tls_ctx, sk, seq, rcd_sn);
}

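/* Core-driven (CORE_NEXT_HINT) RX resync: after enough consecutive records
 * have failed device decryption, either arm the parser to resync at the head
 * of the next record (if it is already queued in the socket) or ask the
 * driver to resync at the current copied_seq. The failure threshold is
 * raised after each attempt so resyncs are not requested too frequently.
 */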
static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
                                           struct tls_offload_context_rx *ctx,
                                           struct sock *sk, struct sk_buff *skb)
{
        struct strp_msg *rxm;

        /* device will request resyncs by itself based on stream scan */
        if (ctx->resync_type != TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT)
                return;
        /* already scheduled */
        if (ctx->resync_nh_do_now)
                return;
        /* seen decrypted fragments since last fully-failed record */
        if (ctx->resync_nh_reset) {
                ctx->resync_nh_reset = 0;
                ctx->resync_nh.decrypted_failed = 1;
                ctx->resync_nh.decrypted_tgt = TLS_DEVICE_RESYNC_NH_START_IVAL;
                return;
        }

        if (++ctx->resync_nh.decrypted_failed <= ctx->resync_nh.decrypted_tgt)
                return;

        /* doing resync, bump the next target in case it fails */
        if (ctx->resync_nh.decrypted_tgt < TLS_DEVICE_RESYNC_NH_MAX_IVAL)
                ctx->resync_nh.decrypted_tgt *= 2;
        else
                ctx->resync_nh.decrypted_tgt += TLS_DEVICE_RESYNC_NH_MAX_IVAL;

        rxm = strp_msg(skb);

        /* head of next rec is already in, parser will sync for us */
        if (tcp_inq(sk) > rxm->full_len) {
                trace_tls_device_rx_resync_nh_schedule(sk);
                ctx->resync_nh_do_now = 1;
        } else {
                struct tls_prot_info *prot = &tls_ctx->prot_info;
                u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];

                memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
                tls_bigint_increment(rcd_sn, prot->rec_seq_size);

                tls_device_resync_rx(tls_ctx, sk, tcp_sk(sk)->copied_seq,
                                     rcd_sn);
        }
}

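/* Handle a record the device decrypted only partially in place: run the
 * software decryption transform over the whole record into a bounce buffer
 * and write the output back over the skb parts already marked as decrypted,
 * so the record ends up fully encrypted again and the regular software RX
 * path can decrypt it. Only AES-GCM is supported here, where applying the
 * decrypt transform to already-plaintext data re-encrypts it.
 */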
static int
tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx)
{
        struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
        const struct tls_cipher_desc *cipher_desc;
        int err, offset, copy, data_len, pos;
        struct sk_buff *skb, *skb_iter;
        struct scatterlist sg[1];
        struct strp_msg *rxm;
        char *orig_buf, *buf;

        switch (tls_ctx->crypto_recv.info.cipher_type) {
        case TLS_CIPHER_AES_GCM_128:
        case TLS_CIPHER_AES_GCM_256:
                break;
        default:
                return -EINVAL;
        }
        cipher_desc = get_cipher_desc(tls_ctx->crypto_recv.info.cipher_type);

        rxm = strp_msg(tls_strp_msg(sw_ctx));
        orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + cipher_desc->iv,
                           sk->sk_allocation);
        if (!orig_buf)
                return -ENOMEM;
        buf = orig_buf;

        err = tls_strp_msg_cow(sw_ctx);
        if (unlikely(err))
                goto free_buf;

        skb = tls_strp_msg(sw_ctx);
        rxm = strp_msg(skb);
        offset = rxm->offset;

        sg_init_table(sg, 1);
        sg_set_buf(&sg[0], buf,
                   rxm->full_len + TLS_HEADER_SIZE + cipher_desc->iv);
        err = skb_copy_bits(skb, offset, buf, TLS_HEADER_SIZE + cipher_desc->iv);
        if (err)
                goto free_buf;

        /* We are interested only in the decrypted data not the auth */
        err = decrypt_skb(sk, sg);
        if (err != -EBADMSG)
                goto free_buf;
        else
                err = 0;

        data_len = rxm->full_len - cipher_desc->tag;

        if (skb_pagelen(skb) > offset) {
                copy = min_t(int, skb_pagelen(skb) - offset, data_len);

                if (skb->decrypted) {
                        err = skb_store_bits(skb, offset, buf, copy);
                        if (err)
                                goto free_buf;
                }

                offset += copy;
                buf += copy;
        }

        pos = skb_pagelen(skb);
        skb_walk_frags(skb, skb_iter) {
                int frag_pos;

                /* Practically all frags must belong to msg if reencrypt
                 * is needed with current strparser and coalescing logic,
                 * but strparser may "get optimized", so let's be safe.
                 */
                if (pos + skb_iter->len <= offset)
                        goto done_with_frag;
                if (pos >= data_len + rxm->offset)
                        break;

                frag_pos = offset - pos;
                copy = min_t(int, skb_iter->len - frag_pos,
                             data_len + rxm->offset - offset);

                if (skb_iter->decrypted) {
                        err = skb_store_bits(skb_iter, frag_pos, buf, copy);
                        if (err)
                                goto free_buf;
                }

                offset += copy;
                buf += copy;
done_with_frag:
                pos += skb_iter->len;
        }

free_buf:
        kfree(orig_buf);
        return err;
}

int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
{
        struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
        struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
        struct sk_buff *skb = tls_strp_msg(sw_ctx);
        struct strp_msg *rxm = strp_msg(skb);
        int is_decrypted, is_encrypted;

        if (!tls_strp_msg_mixed_decrypted(sw_ctx)) {
                is_decrypted = skb->decrypted;
                is_encrypted = !is_decrypted;
        } else {
                is_decrypted = 0;
                is_encrypted = 0;
        }

        trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
                                   tls_ctx->rx.rec_seq, rxm->full_len,
                                   is_encrypted, is_decrypted);

        if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) {
                if (likely(is_encrypted || is_decrypted))
                        return is_decrypted;

                /* After tls_device_down disables the offload, the next SKB will
                 * likely have initial fragments decrypted, and final ones not
                 * decrypted. We need to reencrypt that single SKB.
                 */
                return tls_device_reencrypt(sk, tls_ctx);
        }

        /* Return immediately if the record is either entirely plaintext or
         * entirely ciphertext. Otherwise reencrypt the partially decrypted
         * record.
         */
        if (is_decrypted) {
                ctx->resync_nh_reset = 1;
                return is_decrypted;
        }
        if (is_encrypted) {
                tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb);
                return 0;
        }

        ctx->resync_nh_reset = 1;
        return tls_device_reencrypt(sk, tls_ctx);
}

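/* On first use, bind the offload context to the socket: initialize the
 * context refcount, take a reference on @netdev, add the context to
 * tls_device_list and install tls_device_sk_destruct as the socket
 * destructor (preserving the original one).
 */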
static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
                              struct net_device *netdev)
{
        if (sk->sk_destruct != tls_device_sk_destruct) {
                refcount_set(&ctx->refcount, 1);
                dev_hold(netdev);
                RCU_INIT_POINTER(ctx->netdev, netdev);
                spin_lock_irq(&tls_device_lock);
                list_add_tail(&ctx->list, &tls_device_list);
                spin_unlock_irq(&tls_device_lock);

                ctx->sk_destruct = sk->sk_destruct;
                smp_store_release(&sk->sk_destruct, tls_device_sk_destruct);
        }
}

int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_prot_info *prot = &tls_ctx->prot_info;
        const struct tls_cipher_desc *cipher_desc;
        struct tls_record_info *start_marker_record;
        struct tls_offload_context_tx *offload_ctx;
        struct tls_crypto_info *crypto_info;
        struct net_device *netdev;
        char *iv, *rec_seq;
        struct sk_buff *skb;
        __be64 rcd_sn;
        int rc;

        if (!ctx)
                return -EINVAL;

        if (ctx->priv_ctx_tx)
                return -EEXIST;

        netdev = get_netdev_for_sock(sk);
        if (!netdev) {
                pr_err_ratelimited("%s: netdev not found\n", __func__);
                return -EINVAL;
        }

        if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
                rc = -EOPNOTSUPP;
                goto release_netdev;
        }

        crypto_info = &ctx->crypto_send.info;
        if (crypto_info->version != TLS_1_2_VERSION) {
                rc = -EOPNOTSUPP;
                goto release_netdev;
        }

        cipher_desc = get_cipher_desc(crypto_info->cipher_type);
        if (!cipher_desc || !cipher_desc->offloadable) {
                rc = -EINVAL;
                goto release_netdev;
        }

        iv = crypto_info_iv(crypto_info, cipher_desc);
        rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc);

        prot->version = crypto_info->version;
        prot->cipher_type = crypto_info->cipher_type;
        prot->prepend_size = TLS_HEADER_SIZE + cipher_desc->iv;
        prot->tag_size = cipher_desc->tag;
        prot->overhead_size = prot->prepend_size + prot->tag_size;
        prot->iv_size = cipher_desc->iv;
        prot->salt_size = cipher_desc->salt;
        ctx->tx.iv = kmalloc(cipher_desc->iv + cipher_desc->salt, GFP_KERNEL);
        if (!ctx->tx.iv) {
                rc = -ENOMEM;
                goto release_netdev;
        }

        memcpy(ctx->tx.iv + cipher_desc->salt, iv, cipher_desc->iv);

        prot->rec_seq_size = cipher_desc->rec_seq;
        ctx->tx.rec_seq = kmemdup(rec_seq, cipher_desc->rec_seq, GFP_KERNEL);
        if (!ctx->tx.rec_seq) {
                rc = -ENOMEM;
                goto free_iv;
        }

        start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
        if (!start_marker_record) {
                rc = -ENOMEM;
                goto free_rec_seq;
        }

        offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
        if (!offload_ctx) {
                rc = -ENOMEM;
                goto free_marker_record;
        }

        rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
        if (rc)
                goto free_offload_ctx;

        /* start at rec_seq - 1 to account for the start marker record */
        memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
        offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;

        start_marker_record->end_seq = tcp_sk(sk)->write_seq;
        start_marker_record->len = 0;
        start_marker_record->num_frags = 0;

        INIT_WORK(&offload_ctx->destruct_work, tls_device_tx_del_task);
        offload_ctx->ctx = ctx;

        INIT_LIST_HEAD(&offload_ctx->records_list);
        list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
        spin_lock_init(&offload_ctx->lock);
        sg_init_table(offload_ctx->sg_tx_data,
                      ARRAY_SIZE(offload_ctx->sg_tx_data));

        clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
        ctx->push_pending_record = tls_device_push_pending_record;

        /* TLS offload is greatly simplified if we don't send
         * SKBs where only part of the payload needs to be encrypted.
         * So mark the last skb in the write queue as end of record.
         */
        skb = tcp_write_queue_tail(sk);
        if (skb)
                TCP_SKB_CB(skb)->eor = 1;

        /* Avoid offloading if the device is down
         * We don't want to offload new flows after
         * the NETDEV_DOWN event
         *
         * device_offload_lock is taken in tls_device's NETDEV_DOWN
         * handler thus protecting from the device going down before
         * ctx was added to tls_device_list.
         */
        down_read(&device_offload_lock);
        if (!(netdev->flags & IFF_UP)) {
                rc = -EINVAL;
                goto release_lock;
        }

        ctx->priv_ctx_tx = offload_ctx;
        rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
                                             &ctx->crypto_send.info,
                                             tcp_sk(sk)->write_seq);
        trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_TX,
                                     tcp_sk(sk)->write_seq, rec_seq, rc);
        if (rc)
                goto release_lock;

        tls_device_attach(ctx, sk, netdev);
        up_read(&device_offload_lock);

        /* following this assignment tls_is_skb_tx_device_offloaded
         * will return true and the context might be accessed
         * by the netdev's xmit function.
         */
        smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
        dev_put(netdev);

        return 0;

release_lock:
        up_read(&device_offload_lock);
        clean_acked_data_disable(inet_csk(sk));
        crypto_free_aead(offload_ctx->aead_send);
free_offload_ctx:
        kfree(offload_ctx);
        ctx->priv_ctx_tx = NULL;
free_marker_record:
        kfree(start_marker_record);
free_rec_seq:
        kfree(ctx->tx.rec_seq);
free_iv:
        kfree(ctx->tx.iv);
release_netdev:
        dev_put(netdev);
        return rc;
}

int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
{
        struct tls12_crypto_info_aes_gcm_128 *info;
        struct tls_offload_context_rx *context;
        struct net_device *netdev;
        int rc = 0;

        if (ctx->crypto_recv.info.version != TLS_1_2_VERSION)
                return -EOPNOTSUPP;

        netdev = get_netdev_for_sock(sk);
        if (!netdev) {
                pr_err_ratelimited("%s: netdev not found\n", __func__);
                return -EINVAL;
        }

        if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
                rc = -EOPNOTSUPP;
                goto release_netdev;
        }

        /* Avoid offloading if the device is down
         * We don't want to offload new flows after
         * the NETDEV_DOWN event
         *
         * device_offload_lock is taken in tls_device's NETDEV_DOWN
         * handler thus protecting from the device going down before
         * ctx was added to tls_device_list.
         */
        down_read(&device_offload_lock);
        if (!(netdev->flags & IFF_UP)) {
                rc = -EINVAL;
                goto release_lock;
        }

        context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
        if (!context) {
                rc = -ENOMEM;
                goto release_lock;
        }
        context->resync_nh_reset = 1;

        ctx->priv_ctx_rx = context;
        rc = tls_set_sw_offload(sk, ctx, 0);
        if (rc)
                goto release_ctx;

        rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
                                             &ctx->crypto_recv.info,
                                             tcp_sk(sk)->copied_seq);
        info = (void *)&ctx->crypto_recv.info;
        trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_RX,
                                     tcp_sk(sk)->copied_seq, info->rec_seq, rc);
        if (rc)
                goto free_sw_resources;

        tls_device_attach(ctx, sk, netdev);
        up_read(&device_offload_lock);

        dev_put(netdev);

        return 0;

free_sw_resources:
        up_read(&device_offload_lock);
        tls_sw_free_resources_rx(sk);
        down_read(&device_offload_lock);
release_ctx:
        ctx->priv_ctx_rx = NULL;
release_lock:
        up_read(&device_offload_lock);
release_netdev:
        dev_put(netdev);
        return rc;
}

void tls_device_offload_cleanup_rx(struct sock *sk)
{
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct net_device *netdev;

        down_read(&device_offload_lock);
        netdev = rcu_dereference_protected(tls_ctx->netdev,
                                           lockdep_is_held(&device_offload_lock));
        if (!netdev)
                goto out;

        netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
                                        TLS_OFFLOAD_CTX_DIR_RX);

        if (tls_ctx->tx_conf != TLS_HW) {
                dev_put(netdev);
                rcu_assign_pointer(tls_ctx->netdev, NULL);
        } else {
                set_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags);
        }
out:
        up_read(&device_offload_lock);
        tls_sw_release_resources_rx(sk);
}

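/* NETDEV_DOWN handler: detach every offloaded context bound to @netdev,
 * switch the affected sockets to the software fallback, release the
 * driver-side state and move the contexts to tls_device_down_list so the
 * sockets keep working (unoffloaded) until they are destroyed.
 */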
static int tls_device_down(struct net_device *netdev)
{
        struct tls_context *ctx, *tmp;
        unsigned long flags;
        LIST_HEAD(list);

        /* Request a write lock to block new offload attempts */
        down_write(&device_offload_lock);

        spin_lock_irqsave(&tls_device_lock, flags);
        list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
                struct net_device *ctx_netdev =
                        rcu_dereference_protected(ctx->netdev,
                                                  lockdep_is_held(&device_offload_lock));

                if (ctx_netdev != netdev ||
                    !refcount_inc_not_zero(&ctx->refcount))
                        continue;

                list_move(&ctx->list, &list);
        }
        spin_unlock_irqrestore(&tls_device_lock, flags);

        list_for_each_entry_safe(ctx, tmp, &list, list) {
                /* Stop offloaded TX and switch to the fallback.
                 * tls_is_skb_tx_device_offloaded will return false.
                 */
                WRITE_ONCE(ctx->sk->sk_validate_xmit_skb, tls_validate_xmit_skb_sw);

                /* Stop the RX and TX resync.
                 * tls_dev_resync must not be called after tls_dev_del.
                 */
                rcu_assign_pointer(ctx->netdev, NULL);

                /* Start skipping the RX resync logic completely. */
                set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags);

                /* Sync with inflight packets. After this point:
                 * TX: no non-encrypted packets will be passed to the driver.
                 * RX: resync requests from the driver will be ignored.
                 */
                synchronize_net();

                /* Release the offload context on the driver side. */
                if (ctx->tx_conf == TLS_HW)
                        netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
                                                        TLS_OFFLOAD_CTX_DIR_TX);
                if (ctx->rx_conf == TLS_HW &&
                    !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags))
                        netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
                                                        TLS_OFFLOAD_CTX_DIR_RX);

                dev_put(netdev);

                /* Move the context to a separate list for two reasons:
                 * 1. When the context is deallocated, list_del is called.
                 * 2. It's no longer an offloaded context, so we don't want to
                 *    run offload-specific code on this context.
                 */
                spin_lock_irqsave(&tls_device_lock, flags);
                list_move_tail(&ctx->list, &tls_device_down_list);
                spin_unlock_irqrestore(&tls_device_lock, flags);

                /* Device contexts for RX and TX will be freed on sk_destruct
                 * by tls_device_free_ctx. rx_conf and tx_conf stay in TLS_HW.
                 * Now release the ref taken above.
                 */
                if (refcount_dec_and_test(&ctx->refcount)) {
                        /* sk_destruct ran after tls_device_down took a ref, and
                         * it returned early. Complete the destruction here.
                         */
                        list_del(&ctx->list);
                        tls_device_free_ctx(ctx);
                }
        }

        up_write(&device_offload_lock);

        flush_workqueue(destruct_wq);

        return NOTIFY_DONE;
}

static int tls_dev_event(struct notifier_block *this, unsigned long event,
                         void *ptr)
{
        struct net_device *dev = netdev_notifier_info_to_dev(ptr);

        if (!dev->tlsdev_ops &&
            !(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
                return NOTIFY_DONE;

        switch (event) {
        case NETDEV_REGISTER:
        case NETDEV_FEAT_CHANGE:
                if (netif_is_bond_master(dev))
                        return NOTIFY_DONE;
                if ((dev->features & NETIF_F_HW_TLS_RX) &&
                    !dev->tlsdev_ops->tls_dev_resync)
                        return NOTIFY_BAD;

                if (dev->tlsdev_ops &&
                    dev->tlsdev_ops->tls_dev_add &&
                    dev->tlsdev_ops->tls_dev_del)
                        return NOTIFY_DONE;
                else
                        return NOTIFY_BAD;
        case NETDEV_DOWN:
                return tls_device_down(dev);
        }
        return NOTIFY_DONE;
}

static struct notifier_block tls_dev_notifier = {
        .notifier_call = tls_dev_event,
};

int __init tls_device_init(void)
{
        int err;

        dummy_page = alloc_page(GFP_KERNEL);
        if (!dummy_page)
                return -ENOMEM;

        destruct_wq = alloc_workqueue("ktls_device_destruct", 0, 0);
        if (!destruct_wq) {
                err = -ENOMEM;
                goto err_free_dummy;
        }

        err = register_netdevice_notifier(&tls_dev_notifier);
        if (err)
                goto err_destroy_wq;

        return 0;

err_destroy_wq:
        destroy_workqueue(destruct_wq);
err_free_dummy:
        put_page(dummy_page);
        return err;
}

void __exit tls_device_cleanup(void)
{
        unregister_netdevice_notifier(&tls_dev_notifier);
        destroy_workqueue(destruct_wq);
        clean_acked_data_flush();
        put_page(dummy_page);
}