net/tls: Fix race in TLS device down flow
[platform/kernel/linux-rpi.git] / net/tls/tls_device.c
1 /* Copyright (c) 2018, Mellanox Technologies All rights reserved.
2  *
3  * This software is available to you under a choice of one of two
4  * licenses.  You may choose to be licensed under the terms of the GNU
5  * General Public License (GPL) Version 2, available from the file
6  * COPYING in the main directory of this source tree, or the
7  * OpenIB.org BSD license below:
8  *
9  *     Redistribution and use in source and binary forms, with or
10  *     without modification, are permitted provided that the following
11  *     conditions are met:
12  *
13  *      - Redistributions of source code must retain the above
14  *        copyright notice, this list of conditions and the following
15  *        disclaimer.
16  *
17  *      - Redistributions in binary form must reproduce the above
18  *        copyright notice, this list of conditions and the following
19  *        disclaimer in the documentation and/or other materials
20  *        provided with the distribution.
21  *
22  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
23  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
24  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
25  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
26  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
27  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
28  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29  * SOFTWARE.
30  */
31
32 #include <crypto/aead.h>
33 #include <linux/highmem.h>
34 #include <linux/module.h>
35 #include <linux/netdevice.h>
36 #include <net/dst.h>
37 #include <net/inet_connection_sock.h>
38 #include <net/tcp.h>
39 #include <net/tls.h>
40
41 #include "trace.h"
42
43 /* device_offload_lock is used to synchronize tls_dev_add
44  * against NETDEV_DOWN notifications.
45  */
46 static DECLARE_RWSEM(device_offload_lock);
47
48 static void tls_device_gc_task(struct work_struct *work);
49
50 static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
51 static LIST_HEAD(tls_device_gc_list);
52 static LIST_HEAD(tls_device_list);
53 static LIST_HEAD(tls_device_down_list);
54 static DEFINE_SPINLOCK(tls_device_lock);
55
56 static void tls_device_free_ctx(struct tls_context *ctx)
57 {
58         if (ctx->tx_conf == TLS_HW) {
59                 kfree(tls_offload_ctx_tx(ctx));
60                 kfree(ctx->tx.rec_seq);
61                 kfree(ctx->tx.iv);
62         }
63
64         if (ctx->rx_conf == TLS_HW)
65                 kfree(tls_offload_ctx_rx(ctx));
66
67         tls_ctx_free(NULL, ctx);
68 }
69
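/* Context teardown is deferred to this work item rather than performed
 * directly in the socket destructor; the assumption (not spelled out in
 * this file) is that the destructor may run in atomic context, while
 * tls_dev_del() and the frees below are allowed to sleep.
 */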
70 static void tls_device_gc_task(struct work_struct *work)
71 {
72         struct tls_context *ctx, *tmp;
73         unsigned long flags;
74         LIST_HEAD(gc_list);
75
76         spin_lock_irqsave(&tls_device_lock, flags);
77         list_splice_init(&tls_device_gc_list, &gc_list);
78         spin_unlock_irqrestore(&tls_device_lock, flags);
79
80         list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
81                 struct net_device *netdev = ctx->netdev;
82
83                 if (netdev && ctx->tx_conf == TLS_HW) {
84                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
85                                                         TLS_OFFLOAD_CTX_DIR_TX);
86                         dev_put(netdev);
87                         ctx->netdev = NULL;
88                 }
89
90                 list_del(&ctx->list);
91                 tls_device_free_ctx(ctx);
92         }
93 }
94
95 static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
96 {
97         unsigned long flags;
98
99         spin_lock_irqsave(&tls_device_lock, flags);
100         if (unlikely(!refcount_dec_and_test(&ctx->refcount)))
101                 goto unlock;
102
103         list_move_tail(&ctx->list, &tls_device_gc_list);
104
105         /* schedule_work inside the spinlock
106          * to make sure tls_device_down waits for that work.
107          */
108         schedule_work(&tls_device_gc_work);
109 unlock:
110         spin_unlock_irqrestore(&tls_device_lock, flags);
111 }
112
113 /* We assume that the socket is already connected */
114 static struct net_device *get_netdev_for_sock(struct sock *sk)
115 {
116         struct dst_entry *dst = sk_dst_get(sk);
117         struct net_device *netdev = NULL;
118
119         if (likely(dst)) {
120                 netdev = netdev_sk_get_lowest_dev(dst->dev, sk);
121                 dev_hold(netdev);
122         }
123
124         dst_release(dst);
125
126         return netdev;
127 }
128
129 static void destroy_record(struct tls_record_info *record)
130 {
131         int i;
132
133         for (i = 0; i < record->num_frags; i++)
134                 __skb_frag_unref(&record->frags[i], false);
135         kfree(record);
136 }
137
138 static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
139 {
140         struct tls_record_info *info, *temp;
141
142         list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
143                 list_del(&info->list);
144                 destroy_record(info);
145         }
146
147         offload_ctx->retransmit_hint = NULL;
148 }
149
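/* Registered via clean_acked_data_enable() in tls_set_device_offload()
 * below; TCP then invokes it (as icsk_clean_acked) with the cumulative
 * ACK sequence, so offload records that are fully acknowledged can be
 * released here instead of lingering until socket teardown.
 */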
150 static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
151 {
152         struct tls_context *tls_ctx = tls_get_ctx(sk);
153         struct tls_record_info *info, *temp;
154         struct tls_offload_context_tx *ctx;
155         u64 deleted_records = 0;
156         unsigned long flags;
157
158         if (!tls_ctx)
159                 return;
160
161         ctx = tls_offload_ctx_tx(tls_ctx);
162
163         spin_lock_irqsave(&ctx->lock, flags);
164         info = ctx->retransmit_hint;
165         if (info && !before(acked_seq, info->end_seq))
166                 ctx->retransmit_hint = NULL;
167
168         list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
169                 if (before(acked_seq, info->end_seq))
170                         break;
171                 list_del(&info->list);
172
173                 destroy_record(info);
174                 deleted_records++;
175         }
176
177         ctx->unacked_record_sn += deleted_records;
178         spin_unlock_irqrestore(&ctx->lock, flags);
179 }
180
181 /* At this point, there should be no references on this
182  * socket and no in-flight SKBs associated with this
183  * socket, so it is safe to free all the resources.
184  */
185 void tls_device_sk_destruct(struct sock *sk)
186 {
187         struct tls_context *tls_ctx = tls_get_ctx(sk);
188         struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
189
190         tls_ctx->sk_destruct(sk);
191
192         if (tls_ctx->tx_conf == TLS_HW) {
193                 if (ctx->open_record)
194                         destroy_record(ctx->open_record);
195                 delete_all_records(ctx);
196                 crypto_free_aead(ctx->aead_send);
197                 clean_acked_data_disable(inet_csk(sk));
198         }
199
200         tls_device_queue_ctx_destruction(tls_ctx);
201 }
202 EXPORT_SYMBOL_GPL(tls_device_sk_destruct);
203
204 void tls_device_free_resources_tx(struct sock *sk)
205 {
206         struct tls_context *tls_ctx = tls_get_ctx(sk);
207
208         tls_free_partial_record(sk, tls_ctx);
209 }
210
211 void tls_offload_tx_resync_request(struct sock *sk, u32 got_seq, u32 exp_seq)
212 {
213         struct tls_context *tls_ctx = tls_get_ctx(sk);
214
215         trace_tls_device_tx_resync_req(sk, got_seq, exp_seq);
216         WARN_ON(test_and_set_bit(TLS_TX_SYNC_SCHED, &tls_ctx->flags));
217 }
218 EXPORT_SYMBOL_GPL(tls_offload_tx_resync_request);
219
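/* Hypothetical driver-side sketch (not part of this file; expected_seq is
 * a driver-local value): on seeing an egress skb whose TCP sequence does
 * not match the start of the next offloaded record, a driver would fall
 * back to software encryption for that skb and ask the stack to resync:
 *
 *	tls_offload_tx_resync_request(skb->sk, ntohl(tcp_hdr(skb)->seq),
 *				      expected_seq);
 *
 * The TLS_TX_SYNC_SCHED bit set above is later observed in
 * tls_push_record(), which calls tls_device_resync_tx() below to hand the
 * current TCP seq and record serial number back to the driver.
 */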
220 static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx,
221                                  u32 seq)
222 {
223         struct net_device *netdev;
224         struct sk_buff *skb;
225         int err = 0;
226         u8 *rcd_sn;
227
228         skb = tcp_write_queue_tail(sk);
229         if (skb)
230                 TCP_SKB_CB(skb)->eor = 1;
231
232         rcd_sn = tls_ctx->tx.rec_seq;
233
234         trace_tls_device_tx_resync_send(sk, seq, rcd_sn);
235         down_read(&device_offload_lock);
236         netdev = tls_ctx->netdev;
237         if (netdev)
238                 err = netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq,
239                                                          rcd_sn,
240                                                          TLS_OFFLOAD_CTX_DIR_TX);
241         up_read(&device_offload_lock);
242         if (err)
243                 return;
244
245         clear_bit_unlock(TLS_TX_SYNC_SCHED, &tls_ctx->flags);
246 }
247
248 static void tls_append_frag(struct tls_record_info *record,
249                             struct page_frag *pfrag,
250                             int size)
251 {
252         skb_frag_t *frag;
253
254         frag = &record->frags[record->num_frags - 1];
255         if (skb_frag_page(frag) == pfrag->page &&
256             skb_frag_off(frag) + skb_frag_size(frag) == pfrag->offset) {
257                 skb_frag_size_add(frag, size);
258         } else {
259                 ++frag;
260                 __skb_frag_set_page(frag, pfrag->page);
261                 skb_frag_off_set(frag, pfrag->offset);
262                 skb_frag_size_set(frag, size);
263                 ++record->num_frags;
264                 get_page(pfrag->page);
265         }
266
267         pfrag->offset += size;
268         record->len += size;
269 }
270
271 static int tls_push_record(struct sock *sk,
272                            struct tls_context *ctx,
273                            struct tls_offload_context_tx *offload_ctx,
274                            struct tls_record_info *record,
275                            int flags)
276 {
277         struct tls_prot_info *prot = &ctx->prot_info;
278         struct tcp_sock *tp = tcp_sk(sk);
279         skb_frag_t *frag;
280         int i;
281
282         record->end_seq = tp->write_seq + record->len;
283         list_add_tail_rcu(&record->list, &offload_ctx->records_list);
284         offload_ctx->open_record = NULL;
285
286         if (test_bit(TLS_TX_SYNC_SCHED, &ctx->flags))
287                 tls_device_resync_tx(sk, ctx, tp->write_seq);
288
289         tls_advance_record_sn(sk, prot, &ctx->tx);
290
291         for (i = 0; i < record->num_frags; i++) {
292                 frag = &record->frags[i];
293                 sg_unmark_end(&offload_ctx->sg_tx_data[i]);
294                 sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
295                             skb_frag_size(frag), skb_frag_off(frag));
296                 sk_mem_charge(sk, skb_frag_size(frag));
297                 get_page(skb_frag_page(frag));
298         }
299         sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
300
301         /* all ready, send */
302         return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
303 }
304
305 static int tls_device_record_close(struct sock *sk,
306                                    struct tls_context *ctx,
307                                    struct tls_record_info *record,
308                                    struct page_frag *pfrag,
309                                    unsigned char record_type)
310 {
311         struct tls_prot_info *prot = &ctx->prot_info;
312         int ret;
313
314         /* append tag
315          * device will fill in the tag, we just need to append a placeholder
316          * use socket memory to improve coalescing (re-using a single static
317          * buffer would increase the frag count)
318          * if we can't allocate memory now, steal some back from data
319          */
320         if (likely(skb_page_frag_refill(prot->tag_size, pfrag,
321                                         sk->sk_allocation))) {
322                 ret = 0;
323                 tls_append_frag(record, pfrag, prot->tag_size);
324         } else {
325                 ret = prot->tag_size;
326                 if (record->len <= prot->overhead_size)
327                         return -ENOMEM;
328         }
329
330         /* fill prepend */
331         tls_fill_prepend(ctx, skb_frag_address(&record->frags[0]),
332                          record->len - prot->overhead_size,
333                          record_type);
334         return ret;
335 }
336
337 static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
338                                  struct page_frag *pfrag,
339                                  size_t prepend_size)
340 {
341         struct tls_record_info *record;
342         skb_frag_t *frag;
343
344         record = kmalloc(sizeof(*record), GFP_KERNEL);
345         if (!record)
346                 return -ENOMEM;
347
348         frag = &record->frags[0];
349         __skb_frag_set_page(frag, pfrag->page);
350         skb_frag_off_set(frag, pfrag->offset);
351         skb_frag_size_set(frag, prepend_size);
352
353         get_page(pfrag->page);
354         pfrag->offset += prepend_size;
355
356         record->num_frags = 1;
357         record->len = prepend_size;
358         offload_ctx->open_record = record;
359         return 0;
360 }
361
362 static int tls_do_allocation(struct sock *sk,
363                              struct tls_offload_context_tx *offload_ctx,
364                              struct page_frag *pfrag,
365                              size_t prepend_size)
366 {
367         int ret;
368
369         if (!offload_ctx->open_record) {
370                 if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
371                                                    sk->sk_allocation))) {
372                         READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
373                         sk_stream_moderate_sndbuf(sk);
374                         return -ENOMEM;
375                 }
376
377                 ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
378                 if (ret)
379                         return ret;
380
381                 if (pfrag->size > pfrag->offset)
382                         return 0;
383         }
384
385         if (!sk_page_frag_refill(sk, pfrag))
386                 return -ENOMEM;
387
388         return 0;
389 }
390
391 static int tls_device_copy_data(void *addr, size_t bytes, struct iov_iter *i)
392 {
393         size_t pre_copy, nocache;
394
395         pre_copy = ~((unsigned long)addr - 1) & (SMP_CACHE_BYTES - 1);
396         if (pre_copy) {
397                 pre_copy = min(pre_copy, bytes);
398                 if (copy_from_iter(addr, pre_copy, i) != pre_copy)
399                         return -EFAULT;
400                 bytes -= pre_copy;
401                 addr += pre_copy;
402         }
403
404         nocache = round_down(bytes, SMP_CACHE_BYTES);
405         if (copy_from_iter_nocache(addr, nocache, i) != nocache)
406                 return -EFAULT;
407         bytes -= nocache;
408         addr += nocache;
409
410         if (bytes && copy_from_iter(addr, bytes, i) != bytes)
411                 return -EFAULT;
412
413         return 0;
414 }
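
/* Worked example of the split above (a sketch, assuming SMP_CACHE_BYTES
 * is 64): for an addr sitting 40 bytes into a cache line and bytes == 200,
 * pre_copy is the 24 bytes needed to reach the next 64-byte boundary,
 * nocache then covers round_down(176, 64) == 128 bytes copied with the
 * non-temporal variant, and the remaining 48 bytes go through a plain
 * copy_from_iter().
 */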
415
416 static int tls_push_data(struct sock *sk,
417                          struct iov_iter *msg_iter,
418                          size_t size, int flags,
419                          unsigned char record_type)
420 {
421         struct tls_context *tls_ctx = tls_get_ctx(sk);
422         struct tls_prot_info *prot = &tls_ctx->prot_info;
423         struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
424         struct tls_record_info *record;
425         int tls_push_record_flags;
426         struct page_frag *pfrag;
427         size_t orig_size = size;
428         u32 max_open_record_len;
429         bool more = false;
430         bool done = false;
431         int copy, rc = 0;
432         long timeo;
433
434         if (flags &
435             ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
436                 return -EOPNOTSUPP;
437
438         if (unlikely(sk->sk_err))
439                 return -sk->sk_err;
440
441         flags |= MSG_SENDPAGE_DECRYPTED;
442         tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
443
444         timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
445         if (tls_is_partially_sent_record(tls_ctx)) {
446                 rc = tls_push_partial_record(sk, tls_ctx, flags);
447                 if (rc < 0)
448                         return rc;
449         }
450
451         pfrag = sk_page_frag(sk);
452
453         /* TLS_HEADER_SIZE is not counted as part of the TLS record, and
454          * we need to leave room for an authentication tag.
455          */
456         max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
457                               prot->prepend_size;
458         do {
459                 rc = tls_do_allocation(sk, ctx, pfrag, prot->prepend_size);
460                 if (unlikely(rc)) {
461                         rc = sk_stream_wait_memory(sk, &timeo);
462                         if (!rc)
463                                 continue;
464
465                         record = ctx->open_record;
466                         if (!record)
467                                 break;
468 handle_error:
469                         if (record_type != TLS_RECORD_TYPE_DATA) {
470                                 /* avoid sending partial
471                                  * record with type !=
472                                  * application_data
473                                  */
474                                 size = orig_size;
475                                 destroy_record(record);
476                                 ctx->open_record = NULL;
477                         } else if (record->len > prot->prepend_size) {
478                                 goto last_record;
479                         }
480
481                         break;
482                 }
483
484                 record = ctx->open_record;
485                 copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
486                 copy = min_t(size_t, copy, (max_open_record_len - record->len));
487
488                 if (copy) {
489                         rc = tls_device_copy_data(page_address(pfrag->page) +
490                                                   pfrag->offset, copy, msg_iter);
491                         if (rc)
492                                 goto handle_error;
493                         tls_append_frag(record, pfrag, copy);
494                 }
495
496                 size -= copy;
497                 if (!size) {
498 last_record:
499                         tls_push_record_flags = flags;
500                         if (flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE)) {
501                                 more = true;
502                                 break;
503                         }
504
505                         done = true;
506                 }
507
508                 if (done || record->len >= max_open_record_len ||
509                     (record->num_frags >= MAX_SKB_FRAGS - 1)) {
510                         rc = tls_device_record_close(sk, tls_ctx, record,
511                                                      pfrag, record_type);
512                         if (rc) {
513                                 if (rc > 0) {
514                                         size += rc;
515                                 } else {
516                                         size = orig_size;
517                                         destroy_record(record);
518                                         ctx->open_record = NULL;
519                                         break;
520                                 }
521                         }
522
523                         rc = tls_push_record(sk,
524                                              tls_ctx,
525                                              ctx,
526                                              record,
527                                              tls_push_record_flags);
528                         if (rc < 0)
529                                 break;
530                 }
531         } while (!done);
532
533         tls_ctx->pending_open_record_frags = more;
534
535         if (orig_size - size > 0)
536                 rc = orig_size - size;
537
538         return rc;
539 }
540
541 int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
542 {
543         unsigned char record_type = TLS_RECORD_TYPE_DATA;
544         struct tls_context *tls_ctx = tls_get_ctx(sk);
545         int rc;
546
547         mutex_lock(&tls_ctx->tx_lock);
548         lock_sock(sk);
549
550         if (unlikely(msg->msg_controllen)) {
551                 rc = tls_proccess_cmsg(sk, msg, &record_type);
552                 if (rc)
553                         goto out;
554         }
555
556         rc = tls_push_data(sk, &msg->msg_iter, size,
557                            msg->msg_flags, record_type);
558
559 out:
560         release_sock(sk);
561         mutex_unlock(&tls_ctx->tx_lock);
562         return rc;
563 }
564
565 int tls_device_sendpage(struct sock *sk, struct page *page,
566                         int offset, size_t size, int flags)
567 {
568         struct tls_context *tls_ctx = tls_get_ctx(sk);
569         struct iov_iter msg_iter;
570         char *kaddr;
571         struct kvec iov;
572         int rc;
573
574         if (flags & MSG_SENDPAGE_NOTLAST)
575                 flags |= MSG_MORE;
576
577         mutex_lock(&tls_ctx->tx_lock);
578         lock_sock(sk);
579
580         if (flags & MSG_OOB) {
581                 rc = -EOPNOTSUPP;
582                 goto out;
583         }
584
585         kaddr = kmap(page);
586         iov.iov_base = kaddr + offset;
587         iov.iov_len = size;
588         iov_iter_kvec(&msg_iter, WRITE, &iov, 1, size);
589         rc = tls_push_data(sk, &msg_iter, size,
590                            flags, TLS_RECORD_TYPE_DATA);
591         kunmap(page);
592
593 out:
594         release_sock(sk);
595         mutex_unlock(&tls_ctx->tx_lock);
596         return rc;
597 }
598
599 struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
600                                        u32 seq, u64 *p_record_sn)
601 {
602         u64 record_sn = context->hint_record_sn;
603         struct tls_record_info *info, *last;
604
605         info = context->retransmit_hint;
606         if (!info ||
607             before(seq, info->end_seq - info->len)) {
608                 /* if retransmit_hint is irrelevant start
609                  * from the beginning of the list
610                  */
611                 info = list_first_entry_or_null(&context->records_list,
612                                                 struct tls_record_info, list);
613                 if (!info)
614                         return NULL;
615                 /* send the start_marker record if the seq number is before the
616                  * tls offload start marker sequence number. This record is
617                  * required to handle TCP packets which were sent before TLS
618                  * offload started.
619                  * If it's not the start marker, check whether this seq number
620                  * belongs to the list.
621                  */
622                 if (likely(!tls_record_is_start_marker(info))) {
623                         /* we have the first record, get the last record to see
624                          * if this seq number belongs to the list.
625                          */
626                         last = list_last_entry(&context->records_list,
627                                                struct tls_record_info, list);
628
629                         if (!between(seq, tls_record_start_seq(info),
630                                      last->end_seq))
631                                 return NULL;
632                 }
633                 record_sn = context->unacked_record_sn;
634         }
635
636         /* We just need the _rcu for the READ_ONCE() */
637         rcu_read_lock();
638         list_for_each_entry_from_rcu(info, &context->records_list, list) {
639                 if (before(seq, info->end_seq)) {
640                         if (!context->retransmit_hint ||
641                             after(info->end_seq,
642                                   context->retransmit_hint->end_seq)) {
643                                 context->hint_record_sn = record_sn;
644                                 context->retransmit_hint = info;
645                         }
646                         *p_record_sn = record_sn;
647                         goto exit_rcu_unlock;
648                 }
649                 record_sn++;
650         }
651         info = NULL;
652
653 exit_rcu_unlock:
654         rcu_read_unlock();
655         return info;
656 }
657 EXPORT_SYMBOL(tls_get_record);
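
/* Hypothetical driver-side sketch (not part of this file): a driver
 * handling a retransmitted or out-of-order TX segment would look up the
 * offload record covering it roughly like this, where tx_ctx is its
 * struct tls_offload_context_tx and seq the segment's TCP sequence:
 *
 *	struct tls_record_info *rec;
 *	unsigned long flags;
 *	u64 rcd_sn;
 *
 *	spin_lock_irqsave(&tx_ctx->lock, flags);
 *	rec = tls_get_record(tx_ctx, seq, &rcd_sn);
 *	if (rec)
 *		... re-encrypt or resync using rec and rcd_sn ...
 *	spin_unlock_irqrestore(&tx_ctx->lock, flags);
 */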
658
659 static int tls_device_push_pending_record(struct sock *sk, int flags)
660 {
661         struct iov_iter msg_iter;
662
663         iov_iter_kvec(&msg_iter, WRITE, NULL, 0, 0);
664         return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
665 }
666
667 void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
668 {
669         if (tls_is_partially_sent_record(ctx)) {
670                 gfp_t sk_allocation = sk->sk_allocation;
671
672                 WARN_ON_ONCE(sk->sk_write_pending);
673
674                 sk->sk_allocation = GFP_ATOMIC;
675                 tls_push_partial_record(sk, ctx,
676                                         MSG_DONTWAIT | MSG_NOSIGNAL |
677                                         MSG_SENDPAGE_DECRYPTED);
678                 sk->sk_allocation = sk_allocation;
679         }
680 }
681
682 static void tls_device_resync_rx(struct tls_context *tls_ctx,
683                                  struct sock *sk, u32 seq, u8 *rcd_sn)
684 {
685         struct tls_offload_context_rx *rx_ctx = tls_offload_ctx_rx(tls_ctx);
686         struct net_device *netdev;
687
688         trace_tls_device_rx_resync_send(sk, seq, rcd_sn, rx_ctx->resync_type);
689         rcu_read_lock();
690         netdev = READ_ONCE(tls_ctx->netdev);
691         if (netdev)
692                 netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
693                                                    TLS_OFFLOAD_CTX_DIR_RX);
694         rcu_read_unlock();
695         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICERESYNC);
696 }
697
698 static bool
699 tls_device_rx_resync_async(struct tls_offload_resync_async *resync_async,
700                            s64 resync_req, u32 *seq, u16 *rcd_delta)
701 {
702         u32 is_async = resync_req & RESYNC_REQ_ASYNC;
703         u32 req_seq = resync_req >> 32;
704         u32 req_end = req_seq + ((resync_req >> 16) & 0xffff);
705         u16 i;
706
707         *rcd_delta = 0;
708
709         if (is_async) {
710                 /* shouldn't get to wraparound:
711                  * too long in async stage, something bad happened
712                  */
713                 if (WARN_ON_ONCE(resync_async->rcd_delta == USHRT_MAX))
714                         return false;
715
716                 /* asynchronous stage: log every header seq such that
717                  * req_seq <= seq <= req_end, and wait for the real resync request
718                  */
719                 if (before(*seq, req_seq))
720                         return false;
721                 if (!after(*seq, req_end) &&
722                     resync_async->loglen < TLS_DEVICE_RESYNC_ASYNC_LOGMAX)
723                         resync_async->log[resync_async->loglen++] = *seq;
724
725                 resync_async->rcd_delta++;
726
727                 return false;
728         }
729
730         /* synchronous stage: check against the logged entries and
731          * proceed to check the next entries if no match was found
732          */
733         for (i = 0; i < resync_async->loglen; i++)
734                 if (req_seq == resync_async->log[i] &&
735                     atomic64_try_cmpxchg(&resync_async->req, &resync_req, 0)) {
736                         *rcd_delta = resync_async->rcd_delta - i;
737                         *seq = req_seq;
738                         resync_async->loglen = 0;
739                         resync_async->rcd_delta = 0;
740                         return true;
741                 }
742
743         resync_async->loglen = 0;
744         resync_async->rcd_delta = 0;
745
746         if (req_seq == *seq &&
747             atomic64_try_cmpxchg(&resync_async->req,
748                                  &resync_req, 0))
749                 return true;
750
751         return false;
752 }
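
/* For reference, the 64-bit async resync request word unpacked above is
 * laid out as follows (the flag bit values live in include/net/tls.h):
 *
 *	bits 63..32	TCP seq of the record header the driver saw (req_seq)
 *	bits 31..16	length of the seq range to log while in async mode
 *	low bits	request flags, including RESYNC_REQ_ASYNC
 *
 * e.g. a driver requesting async resync at seq 0x1000 over a 16-byte range
 * would have built roughly ((u64)0x1000 << 32) | (16 << 16) | flags, which
 * is what the shifts at the top of tls_device_rx_resync_async() undo.
 */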
753
754 void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq)
755 {
756         struct tls_context *tls_ctx = tls_get_ctx(sk);
757         struct tls_offload_context_rx *rx_ctx;
758         u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
759         u32 sock_data, is_req_pending;
760         struct tls_prot_info *prot;
761         s64 resync_req;
762         u16 rcd_delta;
763         u32 req_seq;
764
765         if (tls_ctx->rx_conf != TLS_HW)
766                 return;
767         if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags)))
768                 return;
769
770         prot = &tls_ctx->prot_info;
771         rx_ctx = tls_offload_ctx_rx(tls_ctx);
772         memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
773
774         switch (rx_ctx->resync_type) {
775         case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ:
776                 resync_req = atomic64_read(&rx_ctx->resync_req);
777                 req_seq = resync_req >> 32;
778                 seq += TLS_HEADER_SIZE - 1;
779                 is_req_pending = resync_req;
780
781                 if (likely(!is_req_pending) || req_seq != seq ||
782                     !atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
783                         return;
784                 break;
785         case TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT:
786                 if (likely(!rx_ctx->resync_nh_do_now))
787                         return;
788
789                 /* head of next rec is already in, note that the sock_inq will
790                  * include the currently parsed message when called from parser
791                  */
792                 sock_data = tcp_inq(sk);
793                 if (sock_data > rcd_len) {
794                         trace_tls_device_rx_resync_nh_delay(sk, sock_data,
795                                                             rcd_len);
796                         return;
797                 }
798
799                 rx_ctx->resync_nh_do_now = 0;
800                 seq += rcd_len;
801                 tls_bigint_increment(rcd_sn, prot->rec_seq_size);
802                 break;
803         case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ_ASYNC:
804                 resync_req = atomic64_read(&rx_ctx->resync_async->req);
805                 is_req_pending = resync_req;
806                 if (likely(!is_req_pending))
807                         return;
808
809                 if (!tls_device_rx_resync_async(rx_ctx->resync_async,
810                                                 resync_req, &seq, &rcd_delta))
811                         return;
812                 tls_bigint_subtract(rcd_sn, rcd_delta);
813                 break;
814         }
815
816         tls_device_resync_rx(tls_ctx, sk, seq, rcd_sn);
817 }
818
819 static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
820                                            struct tls_offload_context_rx *ctx,
821                                            struct sock *sk, struct sk_buff *skb)
822 {
823         struct strp_msg *rxm;
824
825         /* device will request resyncs by itself based on stream scan */
826         if (ctx->resync_type != TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT)
827                 return;
828         /* already scheduled */
829         if (ctx->resync_nh_do_now)
830                 return;
831         /* seen decrypted fragments since last fully-failed record */
832         if (ctx->resync_nh_reset) {
833                 ctx->resync_nh_reset = 0;
834                 ctx->resync_nh.decrypted_failed = 1;
835                 ctx->resync_nh.decrypted_tgt = TLS_DEVICE_RESYNC_NH_START_IVAL;
836                 return;
837         }
838
839         if (++ctx->resync_nh.decrypted_failed <= ctx->resync_nh.decrypted_tgt)
840                 return;
841
842         /* doing resync, bump the next target in case it fails */
843         if (ctx->resync_nh.decrypted_tgt < TLS_DEVICE_RESYNC_NH_MAX_IVAL)
844                 ctx->resync_nh.decrypted_tgt *= 2;
845         else
846                 ctx->resync_nh.decrypted_tgt += TLS_DEVICE_RESYNC_NH_MAX_IVAL;
847
848         rxm = strp_msg(skb);
849
850         /* head of next rec is already in, parser will sync for us */
851         if (tcp_inq(sk) > rxm->full_len) {
852                 trace_tls_device_rx_resync_nh_schedule(sk);
853                 ctx->resync_nh_do_now = 1;
854         } else {
855                 struct tls_prot_info *prot = &tls_ctx->prot_info;
856                 u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
857
858                 memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
859                 tls_bigint_increment(rcd_sn, prot->rec_seq_size);
860
861                 tls_device_resync_rx(tls_ctx, sk, tcp_sk(sk)->copied_seq,
862                                      rcd_sn);
863         }
864 }
865
866 static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
867 {
868         struct strp_msg *rxm = strp_msg(skb);
869         int err = 0, offset = rxm->offset, copy, nsg, data_len, pos;
870         struct sk_buff *skb_iter, *unused;
871         struct scatterlist sg[1];
872         char *orig_buf, *buf;
873
874         orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
875                            TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
876         if (!orig_buf)
877                 return -ENOMEM;
878         buf = orig_buf;
879
880         nsg = skb_cow_data(skb, 0, &unused);
881         if (unlikely(nsg < 0)) {
882                 err = nsg;
883                 goto free_buf;
884         }
885
886         sg_init_table(sg, 1);
887         sg_set_buf(&sg[0], buf,
888                    rxm->full_len + TLS_HEADER_SIZE +
889                    TLS_CIPHER_AES_GCM_128_IV_SIZE);
890         err = skb_copy_bits(skb, offset, buf,
891                             TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
892         if (err)
893                 goto free_buf;
894
895         /* We are interested only in the decrypted data, not the auth tag */
896         err = decrypt_skb(sk, skb, sg);
897         if (err != -EBADMSG)
898                 goto free_buf;
899         else
900                 err = 0;
901
902         data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;
903
904         if (skb_pagelen(skb) > offset) {
905                 copy = min_t(int, skb_pagelen(skb) - offset, data_len);
906
907                 if (skb->decrypted) {
908                         err = skb_store_bits(skb, offset, buf, copy);
909                         if (err)
910                                 goto free_buf;
911                 }
912
913                 offset += copy;
914                 buf += copy;
915         }
916
917         pos = skb_pagelen(skb);
918         skb_walk_frags(skb, skb_iter) {
919                 int frag_pos;
920
921                 /* Practically all frags must belong to msg if reencrypt
922                  * is needed with current strparser and coalescing logic,
923                  * but strparser may "get optimized", so let's be safe.
924                  */
925                 if (pos + skb_iter->len <= offset)
926                         goto done_with_frag;
927                 if (pos >= data_len + rxm->offset)
928                         break;
929
930                 frag_pos = offset - pos;
931                 copy = min_t(int, skb_iter->len - frag_pos,
932                              data_len + rxm->offset - offset);
933
934                 if (skb_iter->decrypted) {
935                         err = skb_store_bits(skb_iter, frag_pos, buf, copy);
936                         if (err)
937                                 goto free_buf;
938                 }
939
940                 offset += copy;
941                 buf += copy;
942 done_with_frag:
943                 pos += skb_iter->len;
944         }
945
946 free_buf:
947         kfree(orig_buf);
948         return err;
949 }
950
951 int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
952                          struct sk_buff *skb, struct strp_msg *rxm)
953 {
954         struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
955         int is_decrypted = skb->decrypted;
956         int is_encrypted = !is_decrypted;
957         struct sk_buff *skb_iter;
958
959         /* Check if all the data is decrypted already */
960         skb_walk_frags(skb, skb_iter) {
961                 is_decrypted &= skb_iter->decrypted;
962                 is_encrypted &= !skb_iter->decrypted;
963         }
964
965         trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
966                                    tls_ctx->rx.rec_seq, rxm->full_len,
967                                    is_encrypted, is_decrypted);
968
969         ctx->sw.decrypted |= is_decrypted;
970
971         if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) {
972                 if (likely(is_encrypted || is_decrypted))
973                         return 0;
974
975                 /* After tls_device_down disables the offload, the next SKB will
976                  * likely have initial fragments decrypted, and final ones not
977                  * decrypted. We need to reencrypt that single SKB.
978                  */
979                 return tls_device_reencrypt(sk, skb);
980         }
981
982         /* Return immediately if the record is either entirely plaintext or
983          * entirely ciphertext. Otherwise, reencrypt the partially decrypted
984          * record.
985          */
986         if (is_decrypted) {
987                 ctx->resync_nh_reset = 1;
988                 return 0;
989         }
990         if (is_encrypted) {
991                 tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb);
992                 return 0;
993         }
994
995         ctx->resync_nh_reset = 1;
996         return tls_device_reencrypt(sk, skb);
997 }
998
999 static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
1000                               struct net_device *netdev)
1001 {
1002         if (sk->sk_destruct != tls_device_sk_destruct) {
1003                 refcount_set(&ctx->refcount, 1);
1004                 dev_hold(netdev);
1005                 ctx->netdev = netdev;
1006                 spin_lock_irq(&tls_device_lock);
1007                 list_add_tail(&ctx->list, &tls_device_list);
1008                 spin_unlock_irq(&tls_device_lock);
1009
1010                 ctx->sk_destruct = sk->sk_destruct;
1011                 smp_store_release(&sk->sk_destruct, tls_device_sk_destruct);
1012         }
1013 }
1014
1015 int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
1016 {
1017         u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
1018         struct tls_context *tls_ctx = tls_get_ctx(sk);
1019         struct tls_prot_info *prot = &tls_ctx->prot_info;
1020         struct tls_record_info *start_marker_record;
1021         struct tls_offload_context_tx *offload_ctx;
1022         struct tls_crypto_info *crypto_info;
1023         struct net_device *netdev;
1024         char *iv, *rec_seq;
1025         struct sk_buff *skb;
1026         __be64 rcd_sn;
1027         int rc;
1028
1029         if (!ctx)
1030                 return -EINVAL;
1031
1032         if (ctx->priv_ctx_tx)
1033                 return -EEXIST;
1034
1035         start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
1036         if (!start_marker_record)
1037                 return -ENOMEM;
1038
1039         offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
1040         if (!offload_ctx) {
1041                 rc = -ENOMEM;
1042                 goto free_marker_record;
1043         }
1044
1045         crypto_info = &ctx->crypto_send.info;
1046         if (crypto_info->version != TLS_1_2_VERSION) {
1047                 rc = -EOPNOTSUPP;
1048                 goto free_offload_ctx;
1049         }
1050
1051         switch (crypto_info->cipher_type) {
1052         case TLS_CIPHER_AES_GCM_128:
1053                 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1054                 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
1055                 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1056                 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
1057                 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
1058                 salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
1059                 rec_seq =
1060                  ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
1061                 break;
1062         default:
1063                 rc = -EINVAL;
1064                 goto free_offload_ctx;
1065         }
1066
1067         /* Sanity-check the rec_seq_size for stack allocations */
1068         if (rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
1069                 rc = -EINVAL;
1070                 goto free_offload_ctx;
1071         }
1072
1073         prot->version = crypto_info->version;
1074         prot->cipher_type = crypto_info->cipher_type;
1075         prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
1076         prot->tag_size = tag_size;
1077         prot->overhead_size = prot->prepend_size + prot->tag_size;
1078         prot->iv_size = iv_size;
1079         prot->salt_size = salt_size;
1080         ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1081                              GFP_KERNEL);
1082         if (!ctx->tx.iv) {
1083                 rc = -ENOMEM;
1084                 goto free_offload_ctx;
1085         }
1086
1087         memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1088
1089         prot->rec_seq_size = rec_seq_size;
1090         ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
1091         if (!ctx->tx.rec_seq) {
1092                 rc = -ENOMEM;
1093                 goto free_iv;
1094         }
1095
1096         rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
1097         if (rc)
1098                 goto free_rec_seq;
1099
1100         /* start at rec_seq - 1 to account for the start marker record */
1101         memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
1102         offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
1103
1104         start_marker_record->end_seq = tcp_sk(sk)->write_seq;
1105         start_marker_record->len = 0;
1106         start_marker_record->num_frags = 0;
1107
1108         INIT_LIST_HEAD(&offload_ctx->records_list);
1109         list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
1110         spin_lock_init(&offload_ctx->lock);
1111         sg_init_table(offload_ctx->sg_tx_data,
1112                       ARRAY_SIZE(offload_ctx->sg_tx_data));
1113
1114         clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
1115         ctx->push_pending_record = tls_device_push_pending_record;
1116
1117         /* TLS offload is greatly simplified if we don't send
1118          * SKBs where only part of the payload needs to be encrypted.
1119          * So mark the last skb in the write queue as end of record.
1120          */
1121         skb = tcp_write_queue_tail(sk);
1122         if (skb)
1123                 TCP_SKB_CB(skb)->eor = 1;
1124
1125         netdev = get_netdev_for_sock(sk);
1126         if (!netdev) {
1127                 pr_err_ratelimited("%s: netdev not found\n", __func__);
1128                 rc = -EINVAL;
1129                 goto disable_cad;
1130         }
1131
1132         if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
1133                 rc = -EOPNOTSUPP;
1134                 goto release_netdev;
1135         }
1136
1137         /* Avoid offloading if the device is down
1138          * We don't want to offload new flows after
1139          * the NETDEV_DOWN event
1140          *
1141          * device_offload_lock is taken in tls_device's NETDEV_DOWN
1142          * handler, thus protecting against the device going down before
1143          * ctx was added to tls_device_list.
1144          */
1145         down_read(&device_offload_lock);
1146         if (!(netdev->flags & IFF_UP)) {
1147                 rc = -EINVAL;
1148                 goto release_lock;
1149         }
1150
1151         ctx->priv_ctx_tx = offload_ctx;
1152         rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
1153                                              &ctx->crypto_send.info,
1154                                              tcp_sk(sk)->write_seq);
1155         trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_TX,
1156                                      tcp_sk(sk)->write_seq, rec_seq, rc);
1157         if (rc)
1158                 goto release_lock;
1159
1160         tls_device_attach(ctx, sk, netdev);
1161         up_read(&device_offload_lock);
1162
1163         /* following this assignment tls_is_sk_tx_device_offloaded
1164          * will return true and the context might be accessed
1165          * by the netdev's xmit function.
1166          */
1167         smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
1168         dev_put(netdev);
1169
1170         return 0;
1171
1172 release_lock:
1173         up_read(&device_offload_lock);
1174 release_netdev:
1175         dev_put(netdev);
1176 disable_cad:
1177         clean_acked_data_disable(inet_csk(sk));
1178         crypto_free_aead(offload_ctx->aead_send);
1179 free_rec_seq:
1180         kfree(ctx->tx.rec_seq);
1181 free_iv:
1182         kfree(ctx->tx.iv);
1183 free_offload_ctx:
1184         kfree(offload_ctx);
1185         ctx->priv_ctx_tx = NULL;
1186 free_marker_record:
1187         kfree(start_marker_record);
1188         return rc;
1189 }
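
/* Minimal user-space sketch of the path that reaches tls_set_device_offload()
 * (a sketch, assuming an established TCP socket routed over a netdev with
 * NETIF_F_HW_TLS_TX; error handling omitted):
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = {
 *		.info.version	  = TLS_1_2_VERSION,
 *		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *		... key, iv, salt and rec_seq filled in from the handshake ...
 *	};
 *
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *	setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 *
 * The TLS_TX setsockopt tries device offload first and silently falls back
 * to the software TX path when this function returns an error.
 */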
1190
1191 int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
1192 {
1193         struct tls12_crypto_info_aes_gcm_128 *info;
1194         struct tls_offload_context_rx *context;
1195         struct net_device *netdev;
1196         int rc = 0;
1197
1198         if (ctx->crypto_recv.info.version != TLS_1_2_VERSION)
1199                 return -EOPNOTSUPP;
1200
1201         netdev = get_netdev_for_sock(sk);
1202         if (!netdev) {
1203                 pr_err_ratelimited("%s: netdev not found\n", __func__);
1204                 return -EINVAL;
1205         }
1206
1207         if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
1208                 rc = -EOPNOTSUPP;
1209                 goto release_netdev;
1210         }
1211
1212         /* Avoid offloading if the device is down
1213          * We don't want to offload new flows after
1214          * the NETDEV_DOWN event
1215          *
1216          * device_offload_lock is taken in tls_device's NETDEV_DOWN
1217          * handler, thus protecting against the device going down before
1218          * ctx was added to tls_device_list.
1219          */
1220         down_read(&device_offload_lock);
1221         if (!(netdev->flags & IFF_UP)) {
1222                 rc = -EINVAL;
1223                 goto release_lock;
1224         }
1225
1226         context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
1227         if (!context) {
1228                 rc = -ENOMEM;
1229                 goto release_lock;
1230         }
1231         context->resync_nh_reset = 1;
1232
1233         ctx->priv_ctx_rx = context;
1234         rc = tls_set_sw_offload(sk, ctx, 0);
1235         if (rc)
1236                 goto release_ctx;
1237
1238         rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
1239                                              &ctx->crypto_recv.info,
1240                                              tcp_sk(sk)->copied_seq);
1241         info = (void *)&ctx->crypto_recv.info;
1242         trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_RX,
1243                                      tcp_sk(sk)->copied_seq, info->rec_seq, rc);
1244         if (rc)
1245                 goto free_sw_resources;
1246
1247         tls_device_attach(ctx, sk, netdev);
1248         up_read(&device_offload_lock);
1249
1250         dev_put(netdev);
1251
1252         return 0;
1253
1254 free_sw_resources:
1255         up_read(&device_offload_lock);
1256         tls_sw_free_resources_rx(sk);
1257         down_read(&device_offload_lock);
1258 release_ctx:
1259         ctx->priv_ctx_rx = NULL;
1260 release_lock:
1261         up_read(&device_offload_lock);
1262 release_netdev:
1263         dev_put(netdev);
1264         return rc;
1265 }
1266
1267 void tls_device_offload_cleanup_rx(struct sock *sk)
1268 {
1269         struct tls_context *tls_ctx = tls_get_ctx(sk);
1270         struct net_device *netdev;
1271
1272         down_read(&device_offload_lock);
1273         netdev = tls_ctx->netdev;
1274         if (!netdev)
1275                 goto out;
1276
1277         netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
1278                                         TLS_OFFLOAD_CTX_DIR_RX);
1279
1280         if (tls_ctx->tx_conf != TLS_HW) {
1281                 dev_put(netdev);
1282                 tls_ctx->netdev = NULL;
1283         } else {
1284                 set_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags);
1285         }
1286 out:
1287         up_read(&device_offload_lock);
1288         tls_sw_release_resources_rx(sk);
1289 }
1290
1291 static int tls_device_down(struct net_device *netdev)
1292 {
1293         struct tls_context *ctx, *tmp;
1294         unsigned long flags;
1295         LIST_HEAD(list);
1296
1297         /* Request a write lock to block new offload attempts */
1298         down_write(&device_offload_lock);
1299
1300         spin_lock_irqsave(&tls_device_lock, flags);
1301         list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
1302                 if (ctx->netdev != netdev ||
1303                     !refcount_inc_not_zero(&ctx->refcount))
1304                         continue;
1305
1306                 list_move(&ctx->list, &list);
1307         }
1308         spin_unlock_irqrestore(&tls_device_lock, flags);
1309
1310         list_for_each_entry_safe(ctx, tmp, &list, list) {
1311                 /* Stop offloaded TX and switch to the fallback.
1312                  * tls_is_sk_tx_device_offloaded will return false.
1313                  */
1314                 WRITE_ONCE(ctx->sk->sk_validate_xmit_skb, tls_validate_xmit_skb_sw);
1315
1316                 /* Stop the RX and TX resync.
1317                  * tls_dev_resync must not be called after tls_dev_del.
1318                  */
1319                 WRITE_ONCE(ctx->netdev, NULL);
1320
1321                 /* Start skipping the RX resync logic completely. */
1322                 set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags);
1323
1324                 /* Sync with inflight packets. After this point:
1325                  * TX: no non-encrypted packets will be passed to the driver.
1326                  * RX: resync requests from the driver will be ignored.
1327                  */
1328                 synchronize_net();
1329
1330                 /* Release the offload context on the driver side. */
1331                 if (ctx->tx_conf == TLS_HW)
1332                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
1333                                                         TLS_OFFLOAD_CTX_DIR_TX);
1334                 if (ctx->rx_conf == TLS_HW &&
1335                     !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags))
1336                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
1337                                                         TLS_OFFLOAD_CTX_DIR_RX);
1338
1339                 dev_put(netdev);
1340
1341                 /* Move the context to a separate list for two reasons:
1342                  * 1. When the context is deallocated, list_del is called.
1343                  * 2. It's no longer an offloaded context, so we don't want to
1344                  *    run offload-specific code on this context.
1345                  */
1346                 spin_lock_irqsave(&tls_device_lock, flags);
1347                 list_move_tail(&ctx->list, &tls_device_down_list);
1348                 spin_unlock_irqrestore(&tls_device_lock, flags);
1349
1350                 /* Device contexts for RX and TX will be freed on sk_destruct
1351                  * by tls_device_free_ctx. rx_conf and tx_conf stay in TLS_HW.
1352                  * Now release the ref taken above.
1353                  */
1354                 if (refcount_dec_and_test(&ctx->refcount))
1355                         tls_device_free_ctx(ctx);
1356         }
1357
1358         up_write(&device_offload_lock);
1359
1360         flush_work(&tls_device_gc_work);
1361
1362         return NOTIFY_DONE;
1363 }
1364
1365 static int tls_dev_event(struct notifier_block *this, unsigned long event,
1366                          void *ptr)
1367 {
1368         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
1369
1370         if (!dev->tlsdev_ops &&
1371             !(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
1372                 return NOTIFY_DONE;
1373
1374         switch (event) {
1375         case NETDEV_REGISTER:
1376         case NETDEV_FEAT_CHANGE:
1377                 if (netif_is_bond_master(dev))
1378                         return NOTIFY_DONE;
1379                 if ((dev->features & NETIF_F_HW_TLS_RX) &&
1380                     !dev->tlsdev_ops->tls_dev_resync)
1381                         return NOTIFY_BAD;
1382
1383                 if  (dev->tlsdev_ops &&
1384                      dev->tlsdev_ops->tls_dev_add &&
1385                      dev->tlsdev_ops->tls_dev_del)
1386                         return NOTIFY_DONE;
1387                 else
1388                         return NOTIFY_BAD;
1389         case NETDEV_DOWN:
1390                 return tls_device_down(dev);
1391         }
1392         return NOTIFY_DONE;
1393 }
1394
1395 static struct notifier_block tls_dev_notifier = {
1396         .notifier_call  = tls_dev_event,
1397 };
1398
1399 int __init tls_device_init(void)
1400 {
1401         return register_netdevice_notifier(&tls_dev_notifier);
1402 }
1403
1404 void __exit tls_device_cleanup(void)
1405 {
1406         unregister_netdevice_notifier(&tls_dev_notifier);
1407         flush_work(&tls_device_gc_work);
1408         clean_acked_data_flush();
1409 }