usb: typec: mux: fix static inline syntax error
[platform/kernel/linux-starfive.git] / net / tls / tls_strp.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /* Copyright (c) 2016 Tom Herbert <tom@herbertland.com> */
3
4 #include <linux/skbuff.h>
5 #include <linux/workqueue.h>
6 #include <net/strparser.h>
7 #include <net/tcp.h>
8 #include <net/sock.h>
9 #include <net/tls.h>
10
11 #include "tls.h"
12
13 static struct workqueue_struct *tls_strp_wq;
14
15 static void tls_strp_abort_strp(struct tls_strparser *strp, int err)
16 {
17         if (strp->stopped)
18                 return;
19
20         strp->stopped = 1;
21
22         /* Report an error on the lower socket */
23         strp->sk->sk_err = -err;
24         sk_error_report(strp->sk);
25 }
26
27 static void tls_strp_anchor_free(struct tls_strparser *strp)
28 {
29         struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
30
31         DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
32         if (!strp->copy_mode)
33                 shinfo->frag_list = NULL;
34         consume_skb(strp->anchor);
35         strp->anchor = NULL;
36 }
37
38 static struct sk_buff *
39 tls_strp_skb_copy(struct tls_strparser *strp, struct sk_buff *in_skb,
40                   int offset, int len)
41 {
42         struct sk_buff *skb;
43         int i, err;
44
45         skb = alloc_skb_with_frags(0, len, TLS_PAGE_ORDER,
46                                    &err, strp->sk->sk_allocation);
47         if (!skb)
48                 return NULL;
49
50         for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
51                 skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
52
53                 WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
54                                            skb_frag_address(frag),
55                                            skb_frag_size(frag)));
56                 offset += skb_frag_size(frag);
57         }
58
59         skb->len = len;
60         skb->data_len = len;
61         skb_copy_header(skb, in_skb);
62         return skb;
63 }
64
65 /* Create a new skb with the contents of input copied to its page frags */
66 static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
67 {
68         struct strp_msg *rxm;
69         struct sk_buff *skb;
70
71         skb = tls_strp_skb_copy(strp, strp->anchor, strp->stm.offset,
72                                 strp->stm.full_len);
73         if (!skb)
74                 return NULL;
75
76         rxm = strp_msg(skb);
77         rxm->offset = 0;
78         return skb;
79 }
80
81 /* Steal the input skb, input msg is invalid after calling this function */
82 struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx)
83 {
84         struct tls_strparser *strp = &ctx->strp;
85
86 #ifdef CONFIG_TLS_DEVICE
87         DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted);
88 #else
89         /* This function turns an input into an output,
90          * that can only happen if we have offload.
91          */
92         WARN_ON(1);
93 #endif
94
95         if (strp->copy_mode) {
96                 struct sk_buff *skb;
97
98                 /* Replace anchor with an empty skb, this is a little
99                  * dangerous but __tls_cur_msg() warns on empty skbs
100                  * so hopefully we'll catch abuses.
101                  */
102                 skb = alloc_skb(0, strp->sk->sk_allocation);
103                 if (!skb)
104                         return NULL;
105
106                 swap(strp->anchor, skb);
107                 return skb;
108         }
109
110         return tls_strp_msg_make_copy(strp);
111 }
112
113 /* Force the input skb to be in copy mode. The data ownership remains
114  * with the input skb itself (meaning unpause will wipe it) but it can
115  * be modified.
116  */
117 int tls_strp_msg_cow(struct tls_sw_context_rx *ctx)
118 {
119         struct tls_strparser *strp = &ctx->strp;
120         struct sk_buff *skb;
121
122         if (strp->copy_mode)
123                 return 0;
124
125         skb = tls_strp_msg_make_copy(strp);
126         if (!skb)
127                 return -ENOMEM;
128
129         tls_strp_anchor_free(strp);
130         strp->anchor = skb;
131
132         tcp_read_done(strp->sk, strp->stm.full_len);
133         strp->copy_mode = 1;
134
135         return 0;
136 }
137
138 /* Make a clone (in the skb sense) of the input msg to keep a reference
139  * to the underlying data. The reference-holding skbs get placed on
140  * @dst.
141  */
142 int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst)
143 {
144         struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
145
146         if (strp->copy_mode) {
147                 struct sk_buff *skb;
148
149                 WARN_ON_ONCE(!shinfo->nr_frags);
150
151                 /* We can't skb_clone() the anchor, it gets wiped by unpause */
152                 skb = alloc_skb(0, strp->sk->sk_allocation);
153                 if (!skb)
154                         return -ENOMEM;
155
156                 __skb_queue_tail(dst, strp->anchor);
157                 strp->anchor = skb;
158         } else {
159                 struct sk_buff *iter, *clone;
160                 int chunk, len, offset;
161
162                 offset = strp->stm.offset;
163                 len = strp->stm.full_len;
164                 iter = shinfo->frag_list;
165
166                 while (len > 0) {
167                         if (iter->len <= offset) {
168                                 offset -= iter->len;
169                                 goto next;
170                         }
171
172                         chunk = iter->len - offset;
173                         offset = 0;
174
175                         clone = skb_clone(iter, strp->sk->sk_allocation);
176                         if (!clone)
177                                 return -ENOMEM;
178                         __skb_queue_tail(dst, clone);
179
180                         len -= chunk;
181 next:
182                         iter = iter->next;
183                 }
184         }
185
186         return 0;
187 }
188
189 static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
190 {
191         struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
192         int i;
193
194         DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
195
196         for (i = 0; i < shinfo->nr_frags; i++)
197                 __skb_frag_unref(&shinfo->frags[i], false);
198         shinfo->nr_frags = 0;
199         if (strp->copy_mode) {
200                 kfree_skb_list(shinfo->frag_list);
201                 shinfo->frag_list = NULL;
202         }
203         strp->copy_mode = 0;
204         strp->mixed_decrypted = 0;
205 }
206
207 static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb,
208                                 struct sk_buff *in_skb, unsigned int offset,
209                                 size_t in_len)
210 {
211         size_t len, chunk;
212         skb_frag_t *frag;
213         int sz;
214
215         frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];
216
217         len = in_len;
218         /* First make sure we got the header */
219         if (!strp->stm.full_len) {
220                 /* Assume one page is more than enough for headers */
221                 chunk = min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag));
222                 WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
223                                            skb_frag_address(frag) +
224                                            skb_frag_size(frag),
225                                            chunk));
226
227                 skb->len += chunk;
228                 skb->data_len += chunk;
229                 skb_frag_size_add(frag, chunk);
230
231                 sz = tls_rx_msg_size(strp, skb);
232                 if (sz < 0)
233                         return sz;
234
235                 /* We may have over-read, sz == 0 is guaranteed under-read */
236                 if (unlikely(sz && sz < skb->len)) {
237                         int over = skb->len - sz;
238
239                         WARN_ON_ONCE(over > chunk);
240                         skb->len -= over;
241                         skb->data_len -= over;
242                         skb_frag_size_add(frag, -over);
243
244                         chunk -= over;
245                 }
246
247                 frag++;
248                 len -= chunk;
249                 offset += chunk;
250
251                 strp->stm.full_len = sz;
252                 if (!strp->stm.full_len)
253                         goto read_done;
254         }
255
256         /* Load up more data */
257         while (len && strp->stm.full_len > skb->len) {
258                 chunk = min_t(size_t, len, strp->stm.full_len - skb->len);
259                 chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag));
260                 WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
261                                            skb_frag_address(frag) +
262                                            skb_frag_size(frag),
263                                            chunk));
264
265                 skb->len += chunk;
266                 skb->data_len += chunk;
267                 skb_frag_size_add(frag, chunk);
268                 frag++;
269                 len -= chunk;
270                 offset += chunk;
271         }
272
273 read_done:
274         return in_len - len;
275 }
276
277 static int tls_strp_copyin_skb(struct tls_strparser *strp, struct sk_buff *skb,
278                                struct sk_buff *in_skb, unsigned int offset,
279                                size_t in_len)
280 {
281         struct sk_buff *nskb, *first, *last;
282         struct skb_shared_info *shinfo;
283         size_t chunk;
284         int sz;
285
286         if (strp->stm.full_len)
287                 chunk = strp->stm.full_len - skb->len;
288         else
289                 chunk = TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
290         chunk = min(chunk, in_len);
291
292         nskb = tls_strp_skb_copy(strp, in_skb, offset, chunk);
293         if (!nskb)
294                 return -ENOMEM;
295
296         shinfo = skb_shinfo(skb);
297         if (!shinfo->frag_list) {
298                 shinfo->frag_list = nskb;
299                 nskb->prev = nskb;
300         } else {
301                 first = shinfo->frag_list;
302                 last = first->prev;
303                 last->next = nskb;
304                 first->prev = nskb;
305         }
306
307         skb->len += chunk;
308         skb->data_len += chunk;
309
310         if (!strp->stm.full_len) {
311                 sz = tls_rx_msg_size(strp, skb);
312                 if (sz < 0)
313                         return sz;
314
315                 /* We may have over-read, sz == 0 is guaranteed under-read */
316                 if (unlikely(sz && sz < skb->len)) {
317                         int over = skb->len - sz;
318
319                         WARN_ON_ONCE(over > chunk);
320                         skb->len -= over;
321                         skb->data_len -= over;
322                         __pskb_trim(nskb, nskb->len - over);
323
324                         chunk -= over;
325                 }
326
327                 strp->stm.full_len = sz;
328         }
329
330         return chunk;
331 }
332
333 static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
334                            unsigned int offset, size_t in_len)
335 {
336         struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
337         struct sk_buff *skb;
338         int ret;
339
340         if (strp->msg_ready)
341                 return 0;
342
343         skb = strp->anchor;
344         if (!skb->len)
345                 skb_copy_decrypted(skb, in_skb);
346         else
347                 strp->mixed_decrypted |= !!skb_cmp_decrypted(skb, in_skb);
348
349         if (IS_ENABLED(CONFIG_TLS_DEVICE) && strp->mixed_decrypted)
350                 ret = tls_strp_copyin_skb(strp, skb, in_skb, offset, in_len);
351         else
352                 ret = tls_strp_copyin_frag(strp, skb, in_skb, offset, in_len);
353         if (ret < 0) {
354                 desc->error = ret;
355                 ret = 0;
356         }
357
358         if (strp->stm.full_len && strp->stm.full_len == skb->len) {
359                 desc->count = 0;
360
361                 strp->msg_ready = 1;
362                 tls_rx_msg_ready(strp);
363         }
364
365         return ret;
366 }
367
368 static int tls_strp_read_copyin(struct tls_strparser *strp)
369 {
370         struct socket *sock = strp->sk->sk_socket;
371         read_descriptor_t desc;
372
373         desc.arg.data = strp;
374         desc.error = 0;
375         desc.count = 1; /* give more than one skb per call */
376
377         /* sk should be locked here, so okay to do read_sock */
378         sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin);
379
380         return desc.error;
381 }
382
383 static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
384 {
385         struct skb_shared_info *shinfo;
386         struct page *page;
387         int need_spc, len;
388
389         /* If the rbuf is small or rcv window has collapsed to 0 we need
390          * to read the data out. Otherwise the connection will stall.
391          * Without pressure threshold of INT_MAX will never be ready.
392          */
393         if (likely(qshort && !tcp_epollin_ready(strp->sk, INT_MAX)))
394                 return 0;
395
396         shinfo = skb_shinfo(strp->anchor);
397         shinfo->frag_list = NULL;
398
399         /* If we don't know the length go max plus page for cipher overhead */
400         need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
401
402         for (len = need_spc; len > 0; len -= PAGE_SIZE) {
403                 page = alloc_page(strp->sk->sk_allocation);
404                 if (!page) {
405                         tls_strp_flush_anchor_copy(strp);
406                         return -ENOMEM;
407                 }
408
409                 skb_fill_page_desc(strp->anchor, shinfo->nr_frags++,
410                                    page, 0, 0);
411         }
412
413         strp->copy_mode = 1;
414         strp->stm.offset = 0;
415
416         strp->anchor->len = 0;
417         strp->anchor->data_len = 0;
418         strp->anchor->truesize = round_up(need_spc, PAGE_SIZE);
419
420         tls_strp_read_copyin(strp);
421
422         return 0;
423 }
424
425 static bool tls_strp_check_queue_ok(struct tls_strparser *strp)
426 {
427         unsigned int len = strp->stm.offset + strp->stm.full_len;
428         struct sk_buff *first, *skb;
429         u32 seq;
430
431         first = skb_shinfo(strp->anchor)->frag_list;
432         skb = first;
433         seq = TCP_SKB_CB(first)->seq;
434
435         /* Make sure there's no duplicate data in the queue,
436          * and the decrypted status matches.
437          */
438         while (skb->len < len) {
439                 seq += skb->len;
440                 len -= skb->len;
441                 skb = skb->next;
442
443                 if (TCP_SKB_CB(skb)->seq != seq)
444                         return false;
445                 if (skb_cmp_decrypted(first, skb))
446                         return false;
447         }
448
449         return true;
450 }
451
452 static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
453 {
454         struct tcp_sock *tp = tcp_sk(strp->sk);
455         struct sk_buff *first;
456         u32 offset;
457
458         first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset);
459         if (WARN_ON_ONCE(!first))
460                 return;
461
462         /* Bestow the state onto the anchor */
463         strp->anchor->len = offset + len;
464         strp->anchor->data_len = offset + len;
465         strp->anchor->truesize = offset + len;
466
467         skb_shinfo(strp->anchor)->frag_list = first;
468
469         skb_copy_header(strp->anchor, first);
470         strp->anchor->destructor = NULL;
471
472         strp->stm.offset = offset;
473 }
474
475 void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
476 {
477         struct strp_msg *rxm;
478         struct tls_msg *tlm;
479
480         DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready);
481         DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);
482
483         if (!strp->copy_mode && force_refresh) {
484                 if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len))
485                         return;
486
487                 tls_strp_load_anchor_with_queue(strp, strp->stm.full_len);
488         }
489
490         rxm = strp_msg(strp->anchor);
491         rxm->full_len   = strp->stm.full_len;
492         rxm->offset     = strp->stm.offset;
493         tlm = tls_msg(strp->anchor);
494         tlm->control    = strp->mark;
495 }
496
497 /* Called with lock held on lower socket */
498 static int tls_strp_read_sock(struct tls_strparser *strp)
499 {
500         int sz, inq;
501
502         inq = tcp_inq(strp->sk);
503         if (inq < 1)
504                 return 0;
505
506         if (unlikely(strp->copy_mode))
507                 return tls_strp_read_copyin(strp);
508
509         if (inq < strp->stm.full_len)
510                 return tls_strp_read_copy(strp, true);
511
512         if (!strp->stm.full_len) {
513                 tls_strp_load_anchor_with_queue(strp, inq);
514
515                 sz = tls_rx_msg_size(strp, strp->anchor);
516                 if (sz < 0) {
517                         tls_strp_abort_strp(strp, sz);
518                         return sz;
519                 }
520
521                 strp->stm.full_len = sz;
522
523                 if (!strp->stm.full_len || inq < strp->stm.full_len)
524                         return tls_strp_read_copy(strp, true);
525         }
526
527         if (!tls_strp_check_queue_ok(strp))
528                 return tls_strp_read_copy(strp, false);
529
530         strp->msg_ready = 1;
531         tls_rx_msg_ready(strp);
532
533         return 0;
534 }
535
536 void tls_strp_check_rcv(struct tls_strparser *strp)
537 {
538         if (unlikely(strp->stopped) || strp->msg_ready)
539                 return;
540
541         if (tls_strp_read_sock(strp) == -ENOMEM)
542                 queue_work(tls_strp_wq, &strp->work);
543 }
544
545 /* Lower sock lock held */
546 void tls_strp_data_ready(struct tls_strparser *strp)
547 {
548         /* This check is needed to synchronize with do_tls_strp_work.
549          * do_tls_strp_work acquires a process lock (lock_sock) whereas
550          * the lock held here is bh_lock_sock. The two locks can be
551          * held by different threads at the same time, but bh_lock_sock
552          * allows a thread in BH context to safely check if the process
553          * lock is held. In this case, if the lock is held, queue work.
554          */
555         if (sock_owned_by_user_nocheck(strp->sk)) {
556                 queue_work(tls_strp_wq, &strp->work);
557                 return;
558         }
559
560         tls_strp_check_rcv(strp);
561 }
562
563 static void tls_strp_work(struct work_struct *w)
564 {
565         struct tls_strparser *strp =
566                 container_of(w, struct tls_strparser, work);
567
568         lock_sock(strp->sk);
569         tls_strp_check_rcv(strp);
570         release_sock(strp->sk);
571 }
572
573 void tls_strp_msg_done(struct tls_strparser *strp)
574 {
575         WARN_ON(!strp->stm.full_len);
576
577         if (likely(!strp->copy_mode))
578                 tcp_read_done(strp->sk, strp->stm.full_len);
579         else
580                 tls_strp_flush_anchor_copy(strp);
581
582         strp->msg_ready = 0;
583         memset(&strp->stm, 0, sizeof(strp->stm));
584
585         tls_strp_check_rcv(strp);
586 }
587
588 void tls_strp_stop(struct tls_strparser *strp)
589 {
590         strp->stopped = 1;
591 }
592
593 int tls_strp_init(struct tls_strparser *strp, struct sock *sk)
594 {
595         memset(strp, 0, sizeof(*strp));
596
597         strp->sk = sk;
598
599         strp->anchor = alloc_skb(0, GFP_KERNEL);
600         if (!strp->anchor)
601                 return -ENOMEM;
602
603         INIT_WORK(&strp->work, tls_strp_work);
604
605         return 0;
606 }
607
608 /* strp must already be stopped so that tls_strp_recv will no longer be called.
609  * Note that tls_strp_done is not called with the lower socket held.
610  */
611 void tls_strp_done(struct tls_strparser *strp)
612 {
613         WARN_ON(!strp->stopped);
614
615         cancel_work_sync(&strp->work);
616         tls_strp_anchor_free(strp);
617 }
618
619 int __init tls_strp_dev_init(void)
620 {
621         tls_strp_wq = create_workqueue("tls-strp");
622         if (unlikely(!tls_strp_wq))
623                 return -ENOMEM;
624
625         return 0;
626 }
627
628 void tls_strp_dev_exit(void)
629 {
630         destroy_workqueue(tls_strp_wq);
631 }