0e725f18af240ed8da024e18cc0d140ac2648b8d
[platform/kernel/linux-rpi.git] / net / mptcp / protocol.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2017 - 2019, Intel Corporation.
5  */
6
7 #define pr_fmt(fmt) "MPTCP: " fmt
8
9 #include <linux/kernel.h>
10 #include <linux/module.h>
11 #include <linux/netdevice.h>
12 #include <linux/sched/signal.h>
13 #include <linux/atomic.h>
14 #include <net/sock.h>
15 #include <net/inet_common.h>
16 #include <net/inet_hashtables.h>
17 #include <net/protocol.h>
18 #include <net/tcp.h>
19 #include <net/tcp_states.h>
20 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
21 #include <net/transp_v6.h>
22 #endif
23 #include <net/mptcp.h>
24 #include "protocol.h"
25 #include "mib.h"
26
27 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
28 struct mptcp6_sock {
29         struct mptcp_sock msk;
30         struct ipv6_pinfo np;
31 };
32 #endif
33
34 struct mptcp_skb_cb {
35         u64 map_seq;
36         u64 end_seq;
37         u32 offset;
38 };
39
40 #define MPTCP_SKB_CB(__skb)     ((struct mptcp_skb_cb *)&((__skb)->cb[0]))
41
42 static struct percpu_counter mptcp_sockets_allocated;
43
44 /* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not
45  * completed yet or has failed, return the subflow socket.
46  * Otherwise return NULL.
47  */
48 static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
49 {
50         if (!msk->subflow || READ_ONCE(msk->can_ack))
51                 return NULL;
52
53         return msk->subflow;
54 }
55
56 static bool mptcp_is_tcpsk(struct sock *sk)
57 {
58         struct socket *sock = sk->sk_socket;
59
60         if (unlikely(sk->sk_prot == &tcp_prot)) {
61                 /* we are being invoked after mptcp_accept() has
62                  * accepted a non-mp-capable flow: sk is a tcp_sk,
63                  * not an mptcp one.
64                  *
65                  * Hand the socket over to tcp so all further socket ops
66                  * bypass mptcp.
67                  */
68                 sock->ops = &inet_stream_ops;
69                 return true;
70 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
71         } else if (unlikely(sk->sk_prot == &tcpv6_prot)) {
72                 sock->ops = &inet6_stream_ops;
73                 return true;
74 #endif
75         }
76
77         return false;
78 }
79
80 static struct sock *__mptcp_tcp_fallback(struct mptcp_sock *msk)
81 {
82         sock_owned_by_me((const struct sock *)msk);
83
84         if (likely(!__mptcp_check_fallback(msk)))
85                 return NULL;
86
87         return msk->first;
88 }
89
90 static int __mptcp_socket_create(struct mptcp_sock *msk)
91 {
92         struct mptcp_subflow_context *subflow;
93         struct sock *sk = (struct sock *)msk;
94         struct socket *ssock;
95         int err;
96
97         err = mptcp_subflow_create_socket(sk, &ssock);
98         if (err)
99                 return err;
100
101         msk->first = ssock->sk;
102         msk->subflow = ssock;
103         subflow = mptcp_subflow_ctx(ssock->sk);
104         list_add(&subflow->node, &msk->conn_list);
105         subflow->request_mptcp = 1;
106
107         /* accept() will wait on first subflow sk_wq, and we always wakes up
108          * via msk->sk_socket
109          */
110         RCU_INIT_POINTER(msk->first->sk_wq, &sk->sk_socket->wq);
111
112         return 0;
113 }
114
115 static void mptcp_drop(struct sock *sk, struct sk_buff *skb)
116 {
117         sk_drops_add(sk, skb);
118         __kfree_skb(skb);
119 }
120
121 static bool mptcp_try_coalesce(struct sock *sk, struct sk_buff *to,
122                                struct sk_buff *from)
123 {
124         bool fragstolen;
125         int delta;
126
127         if (MPTCP_SKB_CB(from)->offset ||
128             !skb_try_coalesce(to, from, &fragstolen, &delta))
129                 return false;
130
131         pr_debug("colesced seq %llx into %llx new len %d new end seq %llx",
132                  MPTCP_SKB_CB(from)->map_seq, MPTCP_SKB_CB(to)->map_seq,
133                  to->len, MPTCP_SKB_CB(from)->end_seq);
134         MPTCP_SKB_CB(to)->end_seq = MPTCP_SKB_CB(from)->end_seq;
135         kfree_skb_partial(from, fragstolen);
136         atomic_add(delta, &sk->sk_rmem_alloc);
137         sk_mem_charge(sk, delta);
138         return true;
139 }
140
141 static bool mptcp_ooo_try_coalesce(struct mptcp_sock *msk, struct sk_buff *to,
142                                    struct sk_buff *from)
143 {
144         if (MPTCP_SKB_CB(from)->map_seq != MPTCP_SKB_CB(to)->end_seq)
145                 return false;
146
147         return mptcp_try_coalesce((struct sock *)msk, to, from);
148 }
149
150 /* "inspired" by tcp_data_queue_ofo(), main differences:
151  * - use mptcp seqs
152  * - don't cope with sacks
153  */
154 static void mptcp_data_queue_ofo(struct mptcp_sock *msk, struct sk_buff *skb)
155 {
156         struct sock *sk = (struct sock *)msk;
157         struct rb_node **p, *parent;
158         u64 seq, end_seq, max_seq;
159         struct sk_buff *skb1;
160         int space;
161
162         seq = MPTCP_SKB_CB(skb)->map_seq;
163         end_seq = MPTCP_SKB_CB(skb)->end_seq;
164         space = tcp_space(sk);
165         max_seq = space > 0 ? space + msk->ack_seq : msk->ack_seq;
166
167         pr_debug("msk=%p seq=%llx limit=%llx empty=%d", msk, seq, max_seq,
168                  RB_EMPTY_ROOT(&msk->out_of_order_queue));
169         if (after64(seq, max_seq)) {
170                 /* out of window */
171                 mptcp_drop(sk, skb);
172                 MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_NODSSWINDOW);
173                 return;
174         }
175
176         p = &msk->out_of_order_queue.rb_node;
177         MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_OFOQUEUE);
178         if (RB_EMPTY_ROOT(&msk->out_of_order_queue)) {
179                 rb_link_node(&skb->rbnode, NULL, p);
180                 rb_insert_color(&skb->rbnode, &msk->out_of_order_queue);
181                 msk->ooo_last_skb = skb;
182                 goto end;
183         }
184
185         /* with 2 subflows, adding at end of ooo queue is quite likely
186          * Use of ooo_last_skb avoids the O(Log(N)) rbtree lookup.
187          */
188         if (mptcp_ooo_try_coalesce(msk, msk->ooo_last_skb, skb)) {
189                 MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_OFOMERGE);
190                 MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_OFOQUEUETAIL);
191                 return;
192         }
193
194         /* Can avoid an rbtree lookup if we are adding skb after ooo_last_skb */
195         if (!before64(seq, MPTCP_SKB_CB(msk->ooo_last_skb)->end_seq)) {
196                 MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_OFOQUEUETAIL);
197                 parent = &msk->ooo_last_skb->rbnode;
198                 p = &parent->rb_right;
199                 goto insert;
200         }
201
202         /* Find place to insert this segment. Handle overlaps on the way. */
203         parent = NULL;
204         while (*p) {
205                 parent = *p;
206                 skb1 = rb_to_skb(parent);
207                 if (before64(seq, MPTCP_SKB_CB(skb1)->map_seq)) {
208                         p = &parent->rb_left;
209                         continue;
210                 }
211                 if (before64(seq, MPTCP_SKB_CB(skb1)->end_seq)) {
212                         if (!after64(end_seq, MPTCP_SKB_CB(skb1)->end_seq)) {
213                                 /* All the bits are present. Drop. */
214                                 mptcp_drop(sk, skb);
215                                 MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_DUPDATA);
216                                 return;
217                         }
218                         if (after64(seq, MPTCP_SKB_CB(skb1)->map_seq)) {
219                                 /* partial overlap:
220                                  *     |     skb      |
221                                  *  |     skb1    |
222                                  * continue traversing
223                                  */
224                         } else {
225                                 /* skb's seq == skb1's seq and skb covers skb1.
226                                  * Replace skb1 with skb.
227                                  */
228                                 rb_replace_node(&skb1->rbnode, &skb->rbnode,
229                                                 &msk->out_of_order_queue);
230                                 mptcp_drop(sk, skb1);
231                                 MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_DUPDATA);
232                                 goto merge_right;
233                         }
234                 } else if (mptcp_ooo_try_coalesce(msk, skb1, skb)) {
235                         MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_OFOMERGE);
236                         return;
237                 }
238                 p = &parent->rb_right;
239         }
240
241 insert:
242         /* Insert segment into RB tree. */
243         rb_link_node(&skb->rbnode, parent, p);
244         rb_insert_color(&skb->rbnode, &msk->out_of_order_queue);
245
246 merge_right:
247         /* Remove other segments covered by skb. */
248         while ((skb1 = skb_rb_next(skb)) != NULL) {
249                 if (before64(end_seq, MPTCP_SKB_CB(skb1)->end_seq))
250                         break;
251                 rb_erase(&skb1->rbnode, &msk->out_of_order_queue);
252                 mptcp_drop(sk, skb1);
253                 MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_DUPDATA);
254         }
255         /* If there is no skb after us, we are the last_skb ! */
256         if (!skb1)
257                 msk->ooo_last_skb = skb;
258
259 end:
260         skb_condense(skb);
261         skb_set_owner_r(skb, sk);
262 }
263
264 static bool __mptcp_move_skb(struct mptcp_sock *msk, struct sock *ssk,
265                              struct sk_buff *skb, unsigned int offset,
266                              size_t copy_len)
267 {
268         struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
269         struct sock *sk = (struct sock *)msk;
270         struct sk_buff *tail;
271
272         __skb_unlink(skb, &ssk->sk_receive_queue);
273
274         skb_ext_reset(skb);
275         skb_orphan(skb);
276
277         /* try to fetch required memory from subflow */
278         if (!sk_rmem_schedule(sk, skb, skb->truesize)) {
279                 if (ssk->sk_forward_alloc < skb->truesize)
280                         goto drop;
281                 __sk_mem_reclaim(ssk, skb->truesize);
282                 if (!sk_rmem_schedule(sk, skb, skb->truesize))
283                         goto drop;
284         }
285
286         /* the skb map_seq accounts for the skb offset:
287          * mptcp_subflow_get_mapped_dsn() is based on the current tp->copied_seq
288          * value
289          */
290         MPTCP_SKB_CB(skb)->map_seq = mptcp_subflow_get_mapped_dsn(subflow);
291         MPTCP_SKB_CB(skb)->end_seq = MPTCP_SKB_CB(skb)->map_seq + copy_len;
292         MPTCP_SKB_CB(skb)->offset = offset;
293
294         if (MPTCP_SKB_CB(skb)->map_seq == msk->ack_seq) {
295                 /* in sequence */
296                 WRITE_ONCE(msk->ack_seq, msk->ack_seq + copy_len);
297                 tail = skb_peek_tail(&sk->sk_receive_queue);
298                 if (tail && mptcp_try_coalesce(sk, tail, skb))
299                         return true;
300
301                 skb_set_owner_r(skb, sk);
302                 __skb_queue_tail(&sk->sk_receive_queue, skb);
303                 return true;
304         } else if (after64(MPTCP_SKB_CB(skb)->map_seq, msk->ack_seq)) {
305                 mptcp_data_queue_ofo(msk, skb);
306                 return false;
307         }
308
309         /* old data, keep it simple and drop the whole pkt, sender
310          * will retransmit as needed, if needed.
311          */
312         MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_DUPDATA);
313 drop:
314         mptcp_drop(sk, skb);
315         return false;
316 }
317
318 static void mptcp_stop_timer(struct sock *sk)
319 {
320         struct inet_connection_sock *icsk = inet_csk(sk);
321
322         sk_stop_timer(sk, &icsk->icsk_retransmit_timer);
323         mptcp_sk(sk)->timer_ival = 0;
324 }
325
326 static void mptcp_check_data_fin_ack(struct sock *sk)
327 {
328         struct mptcp_sock *msk = mptcp_sk(sk);
329
330         if (__mptcp_check_fallback(msk))
331                 return;
332
333         /* Look for an acknowledged DATA_FIN */
334         if (((1 << sk->sk_state) &
335              (TCPF_FIN_WAIT1 | TCPF_CLOSING | TCPF_LAST_ACK)) &&
336             msk->write_seq == atomic64_read(&msk->snd_una)) {
337                 mptcp_stop_timer(sk);
338
339                 WRITE_ONCE(msk->snd_data_fin_enable, 0);
340
341                 switch (sk->sk_state) {
342                 case TCP_FIN_WAIT1:
343                         inet_sk_state_store(sk, TCP_FIN_WAIT2);
344                         sk->sk_state_change(sk);
345                         break;
346                 case TCP_CLOSING:
347                 case TCP_LAST_ACK:
348                         inet_sk_state_store(sk, TCP_CLOSE);
349                         sk->sk_state_change(sk);
350                         break;
351                 }
352
353                 if (sk->sk_shutdown == SHUTDOWN_MASK ||
354                     sk->sk_state == TCP_CLOSE)
355                         sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_HUP);
356                 else
357                         sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
358         }
359 }
360
361 static bool mptcp_pending_data_fin(struct sock *sk, u64 *seq)
362 {
363         struct mptcp_sock *msk = mptcp_sk(sk);
364
365         if (READ_ONCE(msk->rcv_data_fin) &&
366             ((1 << sk->sk_state) &
367              (TCPF_ESTABLISHED | TCPF_FIN_WAIT1 | TCPF_FIN_WAIT2))) {
368                 u64 rcv_data_fin_seq = READ_ONCE(msk->rcv_data_fin_seq);
369
370                 if (msk->ack_seq == rcv_data_fin_seq) {
371                         if (seq)
372                                 *seq = rcv_data_fin_seq;
373
374                         return true;
375                 }
376         }
377
378         return false;
379 }
380
381 static void mptcp_set_timeout(const struct sock *sk, const struct sock *ssk)
382 {
383         long tout = ssk && inet_csk(ssk)->icsk_pending ?
384                                       inet_csk(ssk)->icsk_timeout - jiffies : 0;
385
386         if (tout <= 0)
387                 tout = mptcp_sk(sk)->timer_ival;
388         mptcp_sk(sk)->timer_ival = tout > 0 ? tout : TCP_RTO_MIN;
389 }
390
391 static void mptcp_check_data_fin(struct sock *sk)
392 {
393         struct mptcp_sock *msk = mptcp_sk(sk);
394         u64 rcv_data_fin_seq;
395
396         if (__mptcp_check_fallback(msk) || !msk->first)
397                 return;
398
399         /* Need to ack a DATA_FIN received from a peer while this side
400          * of the connection is in ESTABLISHED, FIN_WAIT1, or FIN_WAIT2.
401          * msk->rcv_data_fin was set when parsing the incoming options
402          * at the subflow level and the msk lock was not held, so this
403          * is the first opportunity to act on the DATA_FIN and change
404          * the msk state.
405          *
406          * If we are caught up to the sequence number of the incoming
407          * DATA_FIN, send the DATA_ACK now and do state transition.  If
408          * not caught up, do nothing and let the recv code send DATA_ACK
409          * when catching up.
410          */
411
412         if (mptcp_pending_data_fin(sk, &rcv_data_fin_seq)) {
413                 struct mptcp_subflow_context *subflow;
414
415                 WRITE_ONCE(msk->ack_seq, msk->ack_seq + 1);
416                 WRITE_ONCE(msk->rcv_data_fin, 0);
417
418                 sk->sk_shutdown |= RCV_SHUTDOWN;
419                 smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
420                 set_bit(MPTCP_DATA_READY, &msk->flags);
421
422                 switch (sk->sk_state) {
423                 case TCP_ESTABLISHED:
424                         inet_sk_state_store(sk, TCP_CLOSE_WAIT);
425                         break;
426                 case TCP_FIN_WAIT1:
427                         inet_sk_state_store(sk, TCP_CLOSING);
428                         break;
429                 case TCP_FIN_WAIT2:
430                         inet_sk_state_store(sk, TCP_CLOSE);
431                         // @@ Close subflows now?
432                         break;
433                 default:
434                         /* Other states not expected */
435                         WARN_ON_ONCE(1);
436                         break;
437                 }
438
439                 mptcp_set_timeout(sk, NULL);
440                 mptcp_for_each_subflow(msk, subflow) {
441                         struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
442
443                         lock_sock(ssk);
444                         tcp_send_ack(ssk);
445                         release_sock(ssk);
446                 }
447
448                 sk->sk_state_change(sk);
449
450                 if (sk->sk_shutdown == SHUTDOWN_MASK ||
451                     sk->sk_state == TCP_CLOSE)
452                         sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_HUP);
453                 else
454                         sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
455         }
456 }
457
458 static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk,
459                                            struct sock *ssk,
460                                            unsigned int *bytes)
461 {
462         struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
463         struct sock *sk = (struct sock *)msk;
464         unsigned int moved = 0;
465         bool more_data_avail;
466         struct tcp_sock *tp;
467         u32 old_copied_seq;
468         bool done = false;
469         int sk_rbuf;
470
471         sk_rbuf = READ_ONCE(sk->sk_rcvbuf);
472
473         if (!(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) {
474                 int ssk_rbuf = READ_ONCE(ssk->sk_rcvbuf);
475
476                 if (unlikely(ssk_rbuf > sk_rbuf)) {
477                         WRITE_ONCE(sk->sk_rcvbuf, ssk_rbuf);
478                         sk_rbuf = ssk_rbuf;
479                 }
480         }
481
482         pr_debug("msk=%p ssk=%p", msk, ssk);
483         tp = tcp_sk(ssk);
484         old_copied_seq = tp->copied_seq;
485         do {
486                 u32 map_remaining, offset;
487                 u32 seq = tp->copied_seq;
488                 struct sk_buff *skb;
489                 bool fin;
490
491                 /* try to move as much data as available */
492                 map_remaining = subflow->map_data_len -
493                                 mptcp_subflow_get_map_offset(subflow);
494
495                 skb = skb_peek(&ssk->sk_receive_queue);
496                 if (!skb) {
497                         /* if no data is found, a racing workqueue/recvmsg
498                          * already processed the new data, stop here or we
499                          * can enter an infinite loop
500                          */
501                         if (!moved)
502                                 done = true;
503                         break;
504                 }
505
506                 if (__mptcp_check_fallback(msk)) {
507                         /* if we are running under the workqueue, TCP could have
508                          * collapsed skbs between dummy map creation and now
509                          * be sure to adjust the size
510                          */
511                         map_remaining = skb->len;
512                         subflow->map_data_len = skb->len;
513                 }
514
515                 offset = seq - TCP_SKB_CB(skb)->seq;
516                 fin = TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN;
517                 if (fin) {
518                         done = true;
519                         seq++;
520                 }
521
522                 if (offset < skb->len) {
523                         size_t len = skb->len - offset;
524
525                         if (tp->urg_data)
526                                 done = true;
527
528                         if (__mptcp_move_skb(msk, ssk, skb, offset, len))
529                                 moved += len;
530                         seq += len;
531
532                         if (WARN_ON_ONCE(map_remaining < len))
533                                 break;
534                 } else {
535                         WARN_ON_ONCE(!fin);
536                         sk_eat_skb(ssk, skb);
537                         done = true;
538                 }
539
540                 WRITE_ONCE(tp->copied_seq, seq);
541                 more_data_avail = mptcp_subflow_data_available(ssk);
542
543                 if (atomic_read(&sk->sk_rmem_alloc) > sk_rbuf) {
544                         done = true;
545                         break;
546                 }
547         } while (more_data_avail);
548
549         *bytes += moved;
550         if (tp->copied_seq != old_copied_seq)
551                 tcp_cleanup_rbuf(ssk, 1);
552
553         return done;
554 }
555
556 static bool mptcp_ofo_queue(struct mptcp_sock *msk)
557 {
558         struct sock *sk = (struct sock *)msk;
559         struct sk_buff *skb, *tail;
560         bool moved = false;
561         struct rb_node *p;
562         u64 end_seq;
563
564         p = rb_first(&msk->out_of_order_queue);
565         pr_debug("msk=%p empty=%d", msk, RB_EMPTY_ROOT(&msk->out_of_order_queue));
566         while (p) {
567                 skb = rb_to_skb(p);
568                 if (after64(MPTCP_SKB_CB(skb)->map_seq, msk->ack_seq))
569                         break;
570
571                 p = rb_next(p);
572                 rb_erase(&skb->rbnode, &msk->out_of_order_queue);
573
574                 if (unlikely(!after64(MPTCP_SKB_CB(skb)->end_seq,
575                                       msk->ack_seq))) {
576                         mptcp_drop(sk, skb);
577                         MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_DUPDATA);
578                         continue;
579                 }
580
581                 end_seq = MPTCP_SKB_CB(skb)->end_seq;
582                 tail = skb_peek_tail(&sk->sk_receive_queue);
583                 if (!tail || !mptcp_ooo_try_coalesce(msk, tail, skb)) {
584                         int delta = msk->ack_seq - MPTCP_SKB_CB(skb)->map_seq;
585
586                         /* skip overlapping data, if any */
587                         pr_debug("uncoalesced seq=%llx ack seq=%llx delta=%d",
588                                  MPTCP_SKB_CB(skb)->map_seq, msk->ack_seq,
589                                  delta);
590                         MPTCP_SKB_CB(skb)->offset += delta;
591                         __skb_queue_tail(&sk->sk_receive_queue, skb);
592                 }
593                 msk->ack_seq = end_seq;
594                 moved = true;
595         }
596         return moved;
597 }
598
599 /* In most cases we will be able to lock the mptcp socket.  If its already
600  * owned, we need to defer to the work queue to avoid ABBA deadlock.
601  */
602 static bool move_skbs_to_msk(struct mptcp_sock *msk, struct sock *ssk)
603 {
604         struct sock *sk = (struct sock *)msk;
605         unsigned int moved = 0;
606
607         if (READ_ONCE(sk->sk_lock.owned))
608                 return false;
609
610         if (unlikely(!spin_trylock_bh(&sk->sk_lock.slock)))
611                 return false;
612
613         /* must re-check after taking the lock */
614         if (!READ_ONCE(sk->sk_lock.owned)) {
615                 __mptcp_move_skbs_from_subflow(msk, ssk, &moved);
616                 mptcp_ofo_queue(msk);
617
618                 /* If the moves have caught up with the DATA_FIN sequence number
619                  * it's time to ack the DATA_FIN and change socket state, but
620                  * this is not a good place to change state. Let the workqueue
621                  * do it.
622                  */
623                 if (mptcp_pending_data_fin(sk, NULL))
624                         mptcp_schedule_work(sk);
625         }
626
627         spin_unlock_bh(&sk->sk_lock.slock);
628
629         return moved > 0;
630 }
631
632 void mptcp_data_ready(struct sock *sk, struct sock *ssk)
633 {
634         struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
635         struct mptcp_sock *msk = mptcp_sk(sk);
636         int sk_rbuf, ssk_rbuf;
637         bool wake;
638
639         /* move_skbs_to_msk below can legitly clear the data_avail flag,
640          * but we will need later to properly woke the reader, cache its
641          * value
642          */
643         wake = subflow->data_avail == MPTCP_SUBFLOW_DATA_AVAIL;
644         if (wake)
645                 set_bit(MPTCP_DATA_READY, &msk->flags);
646
647         ssk_rbuf = READ_ONCE(ssk->sk_rcvbuf);
648         sk_rbuf = READ_ONCE(sk->sk_rcvbuf);
649         if (unlikely(ssk_rbuf > sk_rbuf))
650                 sk_rbuf = ssk_rbuf;
651
652         /* over limit? can't append more skbs to msk */
653         if (atomic_read(&sk->sk_rmem_alloc) > sk_rbuf)
654                 goto wake;
655
656         if (move_skbs_to_msk(msk, ssk))
657                 goto wake;
658
659         /* mptcp socket is owned, release_cb should retry */
660         if (!test_and_set_bit(TCP_DELACK_TIMER_DEFERRED,
661                               &sk->sk_tsq_flags)) {
662                 sock_hold(sk);
663
664                 /* need to try again, its possible release_cb() has already
665                  * been called after the test_and_set_bit() above.
666                  */
667                 move_skbs_to_msk(msk, ssk);
668         }
669 wake:
670         if (wake)
671                 sk->sk_data_ready(sk);
672 }
673
674 static void __mptcp_flush_join_list(struct mptcp_sock *msk)
675 {
676         if (likely(list_empty(&msk->join_list)))
677                 return;
678
679         spin_lock_bh(&msk->join_list_lock);
680         list_splice_tail_init(&msk->join_list, &msk->conn_list);
681         spin_unlock_bh(&msk->join_list_lock);
682 }
683
684 static bool mptcp_timer_pending(struct sock *sk)
685 {
686         return timer_pending(&inet_csk(sk)->icsk_retransmit_timer);
687 }
688
689 static void mptcp_reset_timer(struct sock *sk)
690 {
691         struct inet_connection_sock *icsk = inet_csk(sk);
692         unsigned long tout;
693
694         /* should never be called with mptcp level timer cleared */
695         tout = READ_ONCE(mptcp_sk(sk)->timer_ival);
696         if (WARN_ON_ONCE(!tout))
697                 tout = TCP_RTO_MIN;
698         sk_reset_timer(sk, &icsk->icsk_retransmit_timer, jiffies + tout);
699 }
700
701 bool mptcp_schedule_work(struct sock *sk)
702 {
703         if (inet_sk_state_load(sk) != TCP_CLOSE &&
704             schedule_work(&mptcp_sk(sk)->work)) {
705                 /* each subflow already holds a reference to the sk, and the
706                  * workqueue is invoked by a subflow, so sk can't go away here.
707                  */
708                 sock_hold(sk);
709                 return true;
710         }
711         return false;
712 }
713
714 void mptcp_data_acked(struct sock *sk)
715 {
716         mptcp_reset_timer(sk);
717
718         if ((!test_bit(MPTCP_SEND_SPACE, &mptcp_sk(sk)->flags) ||
719              (inet_sk_state_load(sk) != TCP_ESTABLISHED)))
720                 mptcp_schedule_work(sk);
721 }
722
723 void mptcp_subflow_eof(struct sock *sk)
724 {
725         if (!test_and_set_bit(MPTCP_WORK_EOF, &mptcp_sk(sk)->flags))
726                 mptcp_schedule_work(sk);
727 }
728
729 static void mptcp_check_for_eof(struct mptcp_sock *msk)
730 {
731         struct mptcp_subflow_context *subflow;
732         struct sock *sk = (struct sock *)msk;
733         int receivers = 0;
734
735         mptcp_for_each_subflow(msk, subflow)
736                 receivers += !subflow->rx_eof;
737
738         if (!receivers && !(sk->sk_shutdown & RCV_SHUTDOWN)) {
739                 /* hopefully temporary hack: propagate shutdown status
740                  * to msk, when all subflows agree on it
741                  */
742                 sk->sk_shutdown |= RCV_SHUTDOWN;
743
744                 smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
745                 set_bit(MPTCP_DATA_READY, &msk->flags);
746                 sk->sk_data_ready(sk);
747         }
748 }
749
750 static bool mptcp_ext_cache_refill(struct mptcp_sock *msk)
751 {
752         const struct sock *sk = (const struct sock *)msk;
753
754         if (!msk->cached_ext)
755                 msk->cached_ext = __skb_ext_alloc(sk->sk_allocation);
756
757         return !!msk->cached_ext;
758 }
759
760 static struct sock *mptcp_subflow_recv_lookup(const struct mptcp_sock *msk)
761 {
762         struct mptcp_subflow_context *subflow;
763         struct sock *sk = (struct sock *)msk;
764
765         sock_owned_by_me(sk);
766
767         mptcp_for_each_subflow(msk, subflow) {
768                 if (subflow->data_avail)
769                         return mptcp_subflow_tcp_sock(subflow);
770         }
771
772         return NULL;
773 }
774
775 static bool mptcp_skb_can_collapse_to(u64 write_seq,
776                                       const struct sk_buff *skb,
777                                       const struct mptcp_ext *mpext)
778 {
779         if (!tcp_skb_can_collapse_to(skb))
780                 return false;
781
782         /* can collapse only if MPTCP level sequence is in order and this
783          * mapping has not been xmitted yet
784          */
785         return mpext && mpext->data_seq + mpext->data_len == write_seq &&
786                !mpext->frozen;
787 }
788
789 static bool mptcp_frag_can_collapse_to(const struct mptcp_sock *msk,
790                                        const struct page_frag *pfrag,
791                                        const struct mptcp_data_frag *df)
792 {
793         return df && pfrag->page == df->page &&
794                 df->data_seq + df->data_len == msk->write_seq;
795 }
796
797 static void dfrag_uncharge(struct sock *sk, int len)
798 {
799         sk_mem_uncharge(sk, len);
800         sk_wmem_queued_add(sk, -len);
801 }
802
803 static void dfrag_clear(struct sock *sk, struct mptcp_data_frag *dfrag)
804 {
805         int len = dfrag->data_len + dfrag->overhead;
806
807         list_del(&dfrag->list);
808         dfrag_uncharge(sk, len);
809         put_page(dfrag->page);
810 }
811
812 static bool mptcp_is_writeable(struct mptcp_sock *msk)
813 {
814         struct mptcp_subflow_context *subflow;
815
816         if (!sk_stream_is_writeable((struct sock *)msk))
817                 return false;
818
819         mptcp_for_each_subflow(msk, subflow) {
820                 if (sk_stream_is_writeable(subflow->tcp_sock))
821                         return true;
822         }
823         return false;
824 }
825
826 static void mptcp_clean_una(struct sock *sk)
827 {
828         struct mptcp_sock *msk = mptcp_sk(sk);
829         struct mptcp_data_frag *dtmp, *dfrag;
830         bool cleaned = false;
831         u64 snd_una;
832
833         /* on fallback we just need to ignore snd_una, as this is really
834          * plain TCP
835          */
836         if (__mptcp_check_fallback(msk))
837                 atomic64_set(&msk->snd_una, msk->write_seq);
838         snd_una = atomic64_read(&msk->snd_una);
839
840         list_for_each_entry_safe(dfrag, dtmp, &msk->rtx_queue, list) {
841                 if (after64(dfrag->data_seq + dfrag->data_len, snd_una))
842                         break;
843
844                 dfrag_clear(sk, dfrag);
845                 cleaned = true;
846         }
847
848         dfrag = mptcp_rtx_head(sk);
849         if (dfrag && after64(snd_una, dfrag->data_seq)) {
850                 u64 delta = snd_una - dfrag->data_seq;
851
852                 if (WARN_ON_ONCE(delta > dfrag->data_len))
853                         goto out;
854
855                 dfrag->data_seq += delta;
856                 dfrag->offset += delta;
857                 dfrag->data_len -= delta;
858
859                 dfrag_uncharge(sk, delta);
860                 cleaned = true;
861         }
862
863 out:
864         if (cleaned)
865                 sk_mem_reclaim_partial(sk);
866 }
867
868 static void mptcp_clean_una_wakeup(struct sock *sk)
869 {
870         struct mptcp_sock *msk = mptcp_sk(sk);
871
872         mptcp_clean_una(sk);
873
874         /* Only wake up writers if a subflow is ready */
875         if (mptcp_is_writeable(msk)) {
876                 set_bit(MPTCP_SEND_SPACE, &msk->flags);
877                 smp_mb__after_atomic();
878
879                 /* set SEND_SPACE before sk_stream_write_space clears
880                  * NOSPACE
881                  */
882                 sk_stream_write_space(sk);
883         }
884 }
885
886 /* ensure we get enough memory for the frag hdr, beyond some minimal amount of
887  * data
888  */
889 static bool mptcp_page_frag_refill(struct sock *sk, struct page_frag *pfrag)
890 {
891         if (likely(skb_page_frag_refill(32U + sizeof(struct mptcp_data_frag),
892                                         pfrag, sk->sk_allocation)))
893                 return true;
894
895         sk->sk_prot->enter_memory_pressure(sk);
896         sk_stream_moderate_sndbuf(sk);
897         return false;
898 }
899
900 static struct mptcp_data_frag *
901 mptcp_carve_data_frag(const struct mptcp_sock *msk, struct page_frag *pfrag,
902                       int orig_offset)
903 {
904         int offset = ALIGN(orig_offset, sizeof(long));
905         struct mptcp_data_frag *dfrag;
906
907         dfrag = (struct mptcp_data_frag *)(page_to_virt(pfrag->page) + offset);
908         dfrag->data_len = 0;
909         dfrag->data_seq = msk->write_seq;
910         dfrag->overhead = offset - orig_offset + sizeof(struct mptcp_data_frag);
911         dfrag->offset = offset + sizeof(struct mptcp_data_frag);
912         dfrag->page = pfrag->page;
913
914         return dfrag;
915 }
916
917 static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk,
918                               struct msghdr *msg, struct mptcp_data_frag *dfrag,
919                               long *timeo, int *pmss_now,
920                               int *ps_goal)
921 {
922         int mss_now, avail_size, size_goal, offset, ret, frag_truesize = 0;
923         bool dfrag_collapsed, can_collapse = false;
924         struct mptcp_sock *msk = mptcp_sk(sk);
925         struct mptcp_ext *mpext = NULL;
926         bool retransmission = !!dfrag;
927         struct sk_buff *skb, *tail;
928         struct page_frag *pfrag;
929         struct page *page;
930         u64 *write_seq;
931         size_t psize;
932
933         /* use the mptcp page cache so that we can easily move the data
934          * from one substream to another, but do per subflow memory accounting
935          * Note: pfrag is used only !retransmission, but the compiler if
936          * fooled into a warning if we don't init here
937          */
938         pfrag = sk_page_frag(sk);
939         if (!retransmission) {
940                 write_seq = &msk->write_seq;
941                 page = pfrag->page;
942         } else {
943                 write_seq = &dfrag->data_seq;
944                 page = dfrag->page;
945         }
946
947         /* compute copy limit */
948         mss_now = tcp_send_mss(ssk, &size_goal, msg->msg_flags);
949         *pmss_now = mss_now;
950         *ps_goal = size_goal;
951         avail_size = size_goal;
952         skb = tcp_write_queue_tail(ssk);
953         if (skb) {
954                 mpext = skb_ext_find(skb, SKB_EXT_MPTCP);
955
956                 /* Limit the write to the size available in the
957                  * current skb, if any, so that we create at most a new skb.
958                  * Explicitly tells TCP internals to avoid collapsing on later
959                  * queue management operation, to avoid breaking the ext <->
960                  * SSN association set here
961                  */
962                 can_collapse = (size_goal - skb->len > 0) &&
963                               mptcp_skb_can_collapse_to(*write_seq, skb, mpext);
964                 if (!can_collapse)
965                         TCP_SKB_CB(skb)->eor = 1;
966                 else
967                         avail_size = size_goal - skb->len;
968         }
969
970         if (!retransmission) {
971                 /* reuse tail pfrag, if possible, or carve a new one from the
972                  * page allocator
973                  */
974                 dfrag = mptcp_rtx_tail(sk);
975                 offset = pfrag->offset;
976                 dfrag_collapsed = mptcp_frag_can_collapse_to(msk, pfrag, dfrag);
977                 if (!dfrag_collapsed) {
978                         dfrag = mptcp_carve_data_frag(msk, pfrag, offset);
979                         offset = dfrag->offset;
980                         frag_truesize = dfrag->overhead;
981                 }
982                 psize = min_t(size_t, pfrag->size - offset, avail_size);
983
984                 /* Copy to page */
985                 pr_debug("left=%zu", msg_data_left(msg));
986                 psize = copy_page_from_iter(pfrag->page, offset,
987                                             min_t(size_t, msg_data_left(msg),
988                                                   psize),
989                                             &msg->msg_iter);
990                 pr_debug("left=%zu", msg_data_left(msg));
991                 if (!psize)
992                         return -EINVAL;
993
994                 if (!sk_wmem_schedule(sk, psize + dfrag->overhead)) {
995                         iov_iter_revert(&msg->msg_iter, psize);
996                         return -ENOMEM;
997                 }
998         } else {
999                 offset = dfrag->offset;
1000                 psize = min_t(size_t, dfrag->data_len, avail_size);
1001         }
1002
1003         tail = tcp_build_frag(ssk, psize, msg->msg_flags, page, offset, &psize);
1004         if (!tail) {
1005                 tcp_remove_empty_skb(sk, tcp_write_queue_tail(ssk));
1006                 return -ENOMEM;
1007         }
1008
1009         ret = psize;
1010         frag_truesize += ret;
1011         if (!retransmission) {
1012                 if (unlikely(ret < psize))
1013                         iov_iter_revert(&msg->msg_iter, psize - ret);
1014
1015                 /* send successful, keep track of sent data for mptcp-level
1016                  * retransmission
1017                  */
1018                 dfrag->data_len += ret;
1019                 if (!dfrag_collapsed) {
1020                         get_page(dfrag->page);
1021                         list_add_tail(&dfrag->list, &msk->rtx_queue);
1022                         sk_wmem_queued_add(sk, frag_truesize);
1023                 } else {
1024                         sk_wmem_queued_add(sk, ret);
1025                 }
1026
1027                 /* charge data on mptcp rtx queue to the master socket
1028                  * Note: we charge such data both to sk and ssk
1029                  */
1030                 sk->sk_forward_alloc -= frag_truesize;
1031         }
1032
1033         /* if the tail skb is still the cached one, collapsing really happened.
1034          */
1035         if (skb == tail) {
1036                 WARN_ON_ONCE(!can_collapse);
1037                 mpext->data_len += ret;
1038                 goto out;
1039         }
1040
1041         mpext = __skb_ext_set(tail, SKB_EXT_MPTCP, msk->cached_ext);
1042         msk->cached_ext = NULL;
1043
1044         memset(mpext, 0, sizeof(*mpext));
1045         mpext->data_seq = *write_seq;
1046         mpext->subflow_seq = mptcp_subflow_ctx(ssk)->rel_write_seq;
1047         mpext->data_len = ret;
1048         mpext->use_map = 1;
1049         mpext->dsn64 = 1;
1050
1051         pr_debug("data_seq=%llu subflow_seq=%u data_len=%u dsn64=%d",
1052                  mpext->data_seq, mpext->subflow_seq, mpext->data_len,
1053                  mpext->dsn64);
1054
1055 out:
1056         if (!retransmission)
1057                 pfrag->offset += frag_truesize;
1058         WRITE_ONCE(*write_seq, *write_seq + ret);
1059         mptcp_subflow_ctx(ssk)->rel_write_seq += ret;
1060
1061         return ret;
1062 }
1063
1064 static void mptcp_nospace(struct mptcp_sock *msk)
1065 {
1066         struct mptcp_subflow_context *subflow;
1067
1068         clear_bit(MPTCP_SEND_SPACE, &msk->flags);
1069         smp_mb__after_atomic(); /* msk->flags is changed by write_space cb */
1070
1071         mptcp_for_each_subflow(msk, subflow) {
1072                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
1073                 struct socket *sock = READ_ONCE(ssk->sk_socket);
1074
1075                 /* enables ssk->write_space() callbacks */
1076                 if (sock)
1077                         set_bit(SOCK_NOSPACE, &sock->flags);
1078         }
1079 }
1080
1081 static bool mptcp_subflow_active(struct mptcp_subflow_context *subflow)
1082 {
1083         struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
1084
1085         /* can't send if JOIN hasn't completed yet (i.e. is usable for mptcp) */
1086         if (subflow->request_join && !subflow->fully_established)
1087                 return false;
1088
1089         /* only send if our side has not closed yet */
1090         return ((1 << ssk->sk_state) & (TCPF_ESTABLISHED | TCPF_CLOSE_WAIT));
1091 }
1092
1093 #define MPTCP_SEND_BURST_SIZE           ((1 << 16) - \
1094                                          sizeof(struct tcphdr) - \
1095                                          MAX_TCP_OPTION_SPACE - \
1096                                          sizeof(struct ipv6hdr) - \
1097                                          sizeof(struct frag_hdr))
1098
1099 struct subflow_send_info {
1100         struct sock *ssk;
1101         u64 ratio;
1102 };
1103
1104 static struct sock *mptcp_subflow_get_send(struct mptcp_sock *msk,
1105                                            u32 *sndbuf)
1106 {
1107         struct subflow_send_info send_info[2];
1108         struct mptcp_subflow_context *subflow;
1109         int i, nr_active = 0;
1110         struct sock *ssk;
1111         u64 ratio;
1112         u32 pace;
1113
1114         sock_owned_by_me((struct sock *)msk);
1115
1116         *sndbuf = 0;
1117         if (!mptcp_ext_cache_refill(msk))
1118                 return NULL;
1119
1120         if (__mptcp_check_fallback(msk)) {
1121                 if (!msk->first)
1122                         return NULL;
1123                 *sndbuf = msk->first->sk_sndbuf;
1124                 return sk_stream_memory_free(msk->first) ? msk->first : NULL;
1125         }
1126
1127         /* re-use last subflow, if the burst allow that */
1128         if (msk->last_snd && msk->snd_burst > 0 &&
1129             sk_stream_memory_free(msk->last_snd) &&
1130             mptcp_subflow_active(mptcp_subflow_ctx(msk->last_snd))) {
1131                 mptcp_for_each_subflow(msk, subflow) {
1132                         ssk =  mptcp_subflow_tcp_sock(subflow);
1133                         *sndbuf = max(tcp_sk(ssk)->snd_wnd, *sndbuf);
1134                 }
1135                 return msk->last_snd;
1136         }
1137
1138         /* pick the subflow with the lower wmem/wspace ratio */
1139         for (i = 0; i < 2; ++i) {
1140                 send_info[i].ssk = NULL;
1141                 send_info[i].ratio = -1;
1142         }
1143         mptcp_for_each_subflow(msk, subflow) {
1144                 ssk =  mptcp_subflow_tcp_sock(subflow);
1145                 if (!mptcp_subflow_active(subflow))
1146                         continue;
1147
1148                 nr_active += !subflow->backup;
1149                 *sndbuf = max(tcp_sk(ssk)->snd_wnd, *sndbuf);
1150                 if (!sk_stream_memory_free(subflow->tcp_sock))
1151                         continue;
1152
1153                 pace = READ_ONCE(ssk->sk_pacing_rate);
1154                 if (!pace)
1155                         continue;
1156
1157                 ratio = div_u64((u64)READ_ONCE(ssk->sk_wmem_queued) << 32,
1158                                 pace);
1159                 if (ratio < send_info[subflow->backup].ratio) {
1160                         send_info[subflow->backup].ssk = ssk;
1161                         send_info[subflow->backup].ratio = ratio;
1162                 }
1163         }
1164
1165         pr_debug("msk=%p nr_active=%d ssk=%p:%lld backup=%p:%lld",
1166                  msk, nr_active, send_info[0].ssk, send_info[0].ratio,
1167                  send_info[1].ssk, send_info[1].ratio);
1168
1169         /* pick the best backup if no other subflow is active */
1170         if (!nr_active)
1171                 send_info[0].ssk = send_info[1].ssk;
1172
1173         if (send_info[0].ssk) {
1174                 msk->last_snd = send_info[0].ssk;
1175                 msk->snd_burst = min_t(int, MPTCP_SEND_BURST_SIZE,
1176                                        sk_stream_wspace(msk->last_snd));
1177                 return msk->last_snd;
1178         }
1179         return NULL;
1180 }
1181
1182 static void ssk_check_wmem(struct mptcp_sock *msk)
1183 {
1184         if (unlikely(!mptcp_is_writeable(msk)))
1185                 mptcp_nospace(msk);
1186 }
1187
1188 static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
1189 {
1190         int mss_now = 0, size_goal = 0, ret = 0;
1191         struct mptcp_sock *msk = mptcp_sk(sk);
1192         struct page_frag *pfrag;
1193         size_t copied = 0;
1194         struct sock *ssk;
1195         u32 sndbuf;
1196         bool tx_ok;
1197         long timeo;
1198
1199         if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
1200                 return -EOPNOTSUPP;
1201
1202         lock_sock(sk);
1203
1204         timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
1205
1206         if ((1 << sk->sk_state) & ~(TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)) {
1207                 ret = sk_stream_wait_connect(sk, &timeo);
1208                 if (ret)
1209                         goto out;
1210         }
1211
1212         pfrag = sk_page_frag(sk);
1213 restart:
1214         mptcp_clean_una(sk);
1215
1216         if (sk->sk_err || (sk->sk_shutdown & SEND_SHUTDOWN)) {
1217                 ret = -EPIPE;
1218                 goto out;
1219         }
1220
1221         __mptcp_flush_join_list(msk);
1222         ssk = mptcp_subflow_get_send(msk, &sndbuf);
1223         while (!sk_stream_memory_free(sk) ||
1224                !ssk ||
1225                !mptcp_page_frag_refill(ssk, pfrag)) {
1226                 if (ssk) {
1227                         /* make sure retransmit timer is
1228                          * running before we wait for memory.
1229                          *
1230                          * The retransmit timer might be needed
1231                          * to make the peer send an up-to-date
1232                          * MPTCP Ack.
1233                          */
1234                         mptcp_set_timeout(sk, ssk);
1235                         if (!mptcp_timer_pending(sk))
1236                                 mptcp_reset_timer(sk);
1237                 }
1238
1239                 mptcp_nospace(msk);
1240                 ret = sk_stream_wait_memory(sk, &timeo);
1241                 if (ret)
1242                         goto out;
1243
1244                 mptcp_clean_una(sk);
1245
1246                 ssk = mptcp_subflow_get_send(msk, &sndbuf);
1247                 if (list_empty(&msk->conn_list)) {
1248                         ret = -ENOTCONN;
1249                         goto out;
1250                 }
1251         }
1252
1253         /* do auto tuning */
1254         if (!(sk->sk_userlocks & SOCK_SNDBUF_LOCK) &&
1255             sndbuf > READ_ONCE(sk->sk_sndbuf))
1256                 WRITE_ONCE(sk->sk_sndbuf, sndbuf);
1257
1258         pr_debug("conn_list->subflow=%p", ssk);
1259
1260         lock_sock(ssk);
1261         tx_ok = msg_data_left(msg);
1262         while (tx_ok) {
1263                 ret = mptcp_sendmsg_frag(sk, ssk, msg, NULL, &timeo, &mss_now,
1264                                          &size_goal);
1265                 if (ret < 0) {
1266                         if (ret == -EAGAIN && timeo > 0) {
1267                                 mptcp_set_timeout(sk, ssk);
1268                                 release_sock(ssk);
1269                                 goto restart;
1270                         }
1271                         break;
1272                 }
1273
1274                 /* burst can be negative, we will try move to the next subflow
1275                  * at selection time, if possible.
1276                  */
1277                 msk->snd_burst -= ret;
1278                 copied += ret;
1279
1280                 tx_ok = msg_data_left(msg);
1281                 if (!tx_ok)
1282                         break;
1283
1284                 if (!sk_stream_memory_free(ssk) ||
1285                     !mptcp_page_frag_refill(ssk, pfrag) ||
1286                     !mptcp_ext_cache_refill(msk)) {
1287                         tcp_push(ssk, msg->msg_flags, mss_now,
1288                                  tcp_sk(ssk)->nonagle, size_goal);
1289                         mptcp_set_timeout(sk, ssk);
1290                         release_sock(ssk);
1291                         goto restart;
1292                 }
1293
1294                 /* memory is charged to mptcp level socket as well, i.e.
1295                  * if msg is very large, mptcp socket may run out of buffer
1296                  * space.  mptcp_clean_una() will release data that has
1297                  * been acked at mptcp level in the mean time, so there is
1298                  * a good chance we can continue sending data right away.
1299                  *
1300                  * Normally, when the tcp subflow can accept more data, then
1301                  * so can the MPTCP socket.  However, we need to cope with
1302                  * peers that might lag behind in their MPTCP-level
1303                  * acknowledgements, i.e.  data might have been acked at
1304                  * tcp level only.  So, we must also check the MPTCP socket
1305                  * limits before we send more data.
1306                  */
1307                 if (unlikely(!sk_stream_memory_free(sk))) {
1308                         tcp_push(ssk, msg->msg_flags, mss_now,
1309                                  tcp_sk(ssk)->nonagle, size_goal);
1310                         mptcp_clean_una(sk);
1311                         if (!sk_stream_memory_free(sk)) {
1312                                 /* can't send more for now, need to wait for
1313                                  * MPTCP-level ACKs from peer.
1314                                  *
1315                                  * Wakeup will happen via mptcp_clean_una().
1316                                  */
1317                                 mptcp_set_timeout(sk, ssk);
1318                                 release_sock(ssk);
1319                                 goto restart;
1320                         }
1321                 }
1322         }
1323
1324         mptcp_set_timeout(sk, ssk);
1325         if (copied) {
1326                 tcp_push(ssk, msg->msg_flags, mss_now, tcp_sk(ssk)->nonagle,
1327                          size_goal);
1328
1329                 /* start the timer, if it's not pending */
1330                 if (!mptcp_timer_pending(sk))
1331                         mptcp_reset_timer(sk);
1332         }
1333
1334         release_sock(ssk);
1335 out:
1336         ssk_check_wmem(msk);
1337         release_sock(sk);
1338         return copied ? : ret;
1339 }
1340
1341 static void mptcp_wait_data(struct sock *sk, long *timeo)
1342 {
1343         DEFINE_WAIT_FUNC(wait, woken_wake_function);
1344         struct mptcp_sock *msk = mptcp_sk(sk);
1345
1346         add_wait_queue(sk_sleep(sk), &wait);
1347         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1348
1349         sk_wait_event(sk, timeo,
1350                       test_and_clear_bit(MPTCP_DATA_READY, &msk->flags), &wait);
1351
1352         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1353         remove_wait_queue(sk_sleep(sk), &wait);
1354 }
1355
1356 static int __mptcp_recvmsg_mskq(struct mptcp_sock *msk,
1357                                 struct msghdr *msg,
1358                                 size_t len)
1359 {
1360         struct sock *sk = (struct sock *)msk;
1361         struct sk_buff *skb;
1362         int copied = 0;
1363
1364         while ((skb = skb_peek(&sk->sk_receive_queue)) != NULL) {
1365                 u32 offset = MPTCP_SKB_CB(skb)->offset;
1366                 u32 data_len = skb->len - offset;
1367                 u32 count = min_t(size_t, len - copied, data_len);
1368                 int err;
1369
1370                 err = skb_copy_datagram_msg(skb, offset, msg, count);
1371                 if (unlikely(err < 0)) {
1372                         if (!copied)
1373                                 return err;
1374                         break;
1375                 }
1376
1377                 copied += count;
1378
1379                 if (count < data_len) {
1380                         MPTCP_SKB_CB(skb)->offset += count;
1381                         break;
1382                 }
1383
1384                 __skb_unlink(skb, &sk->sk_receive_queue);
1385                 __kfree_skb(skb);
1386
1387                 if (copied >= len)
1388                         break;
1389         }
1390
1391         return copied;
1392 }
1393
1394 /* receive buffer autotuning.  See tcp_rcv_space_adjust for more information.
1395  *
1396  * Only difference: Use highest rtt estimate of the subflows in use.
1397  */
1398 static void mptcp_rcv_space_adjust(struct mptcp_sock *msk, int copied)
1399 {
1400         struct mptcp_subflow_context *subflow;
1401         struct sock *sk = (struct sock *)msk;
1402         u32 time, advmss = 1;
1403         u64 rtt_us, mstamp;
1404
1405         sock_owned_by_me(sk);
1406
1407         if (copied <= 0)
1408                 return;
1409
1410         msk->rcvq_space.copied += copied;
1411
1412         mstamp = div_u64(tcp_clock_ns(), NSEC_PER_USEC);
1413         time = tcp_stamp_us_delta(mstamp, msk->rcvq_space.time);
1414
1415         rtt_us = msk->rcvq_space.rtt_us;
1416         if (rtt_us && time < (rtt_us >> 3))
1417                 return;
1418
1419         rtt_us = 0;
1420         mptcp_for_each_subflow(msk, subflow) {
1421                 const struct tcp_sock *tp;
1422                 u64 sf_rtt_us;
1423                 u32 sf_advmss;
1424
1425                 tp = tcp_sk(mptcp_subflow_tcp_sock(subflow));
1426
1427                 sf_rtt_us = READ_ONCE(tp->rcv_rtt_est.rtt_us);
1428                 sf_advmss = READ_ONCE(tp->advmss);
1429
1430                 rtt_us = max(sf_rtt_us, rtt_us);
1431                 advmss = max(sf_advmss, advmss);
1432         }
1433
1434         msk->rcvq_space.rtt_us = rtt_us;
1435         if (time < (rtt_us >> 3) || rtt_us == 0)
1436                 return;
1437
1438         if (msk->rcvq_space.copied <= msk->rcvq_space.space)
1439                 goto new_measure;
1440
1441         if (sock_net(sk)->ipv4.sysctl_tcp_moderate_rcvbuf &&
1442             !(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) {
1443                 int rcvmem, rcvbuf;
1444                 u64 rcvwin, grow;
1445
1446                 rcvwin = ((u64)msk->rcvq_space.copied << 1) + 16 * advmss;
1447
1448                 grow = rcvwin * (msk->rcvq_space.copied - msk->rcvq_space.space);
1449
1450                 do_div(grow, msk->rcvq_space.space);
1451                 rcvwin += (grow << 1);
1452
1453                 rcvmem = SKB_TRUESIZE(advmss + MAX_TCP_HEADER);
1454                 while (tcp_win_from_space(sk, rcvmem) < advmss)
1455                         rcvmem += 128;
1456
1457                 do_div(rcvwin, advmss);
1458                 rcvbuf = min_t(u64, rcvwin * rcvmem,
1459                                sock_net(sk)->ipv4.sysctl_tcp_rmem[2]);
1460
1461                 if (rcvbuf > sk->sk_rcvbuf) {
1462                         u32 window_clamp;
1463
1464                         window_clamp = tcp_win_from_space(sk, rcvbuf);
1465                         WRITE_ONCE(sk->sk_rcvbuf, rcvbuf);
1466
1467                         /* Make subflows follow along.  If we do not do this, we
1468                          * get drops at subflow level if skbs can't be moved to
1469                          * the mptcp rx queue fast enough (announced rcv_win can
1470                          * exceed ssk->sk_rcvbuf).
1471                          */
1472                         mptcp_for_each_subflow(msk, subflow) {
1473                                 struct sock *ssk;
1474                                 bool slow;
1475
1476                                 ssk = mptcp_subflow_tcp_sock(subflow);
1477                                 slow = lock_sock_fast(ssk);
1478                                 WRITE_ONCE(ssk->sk_rcvbuf, rcvbuf);
1479                                 tcp_sk(ssk)->window_clamp = window_clamp;
1480                                 tcp_cleanup_rbuf(ssk, 1);
1481                                 unlock_sock_fast(ssk, slow);
1482                         }
1483                 }
1484         }
1485
1486         msk->rcvq_space.space = msk->rcvq_space.copied;
1487 new_measure:
1488         msk->rcvq_space.copied = 0;
1489         msk->rcvq_space.time = mstamp;
1490 }
1491
1492 static bool __mptcp_move_skbs(struct mptcp_sock *msk)
1493 {
1494         unsigned int moved = 0;
1495         bool done;
1496
1497         /* avoid looping forever below on racing close */
1498         if (((struct sock *)msk)->sk_state == TCP_CLOSE)
1499                 return false;
1500
1501         __mptcp_flush_join_list(msk);
1502         do {
1503                 struct sock *ssk = mptcp_subflow_recv_lookup(msk);
1504                 bool slowpath;
1505
1506                 if (!ssk)
1507                         break;
1508
1509                 slowpath = lock_sock_fast(ssk);
1510                 done = __mptcp_move_skbs_from_subflow(msk, ssk, &moved);
1511                 unlock_sock_fast(ssk, slowpath);
1512         } while (!done);
1513
1514         if (mptcp_ofo_queue(msk) || moved > 0) {
1515                 mptcp_check_data_fin((struct sock *)msk);
1516                 return true;
1517         }
1518         return false;
1519 }
1520
1521 static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
1522                          int nonblock, int flags, int *addr_len)
1523 {
1524         struct mptcp_sock *msk = mptcp_sk(sk);
1525         int copied = 0;
1526         int target;
1527         long timeo;
1528
1529         if (msg->msg_flags & ~(MSG_WAITALL | MSG_DONTWAIT))
1530                 return -EOPNOTSUPP;
1531
1532         lock_sock(sk);
1533         timeo = sock_rcvtimeo(sk, nonblock);
1534
1535         len = min_t(size_t, len, INT_MAX);
1536         target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1537         __mptcp_flush_join_list(msk);
1538
1539         while (len > (size_t)copied) {
1540                 int bytes_read;
1541
1542                 bytes_read = __mptcp_recvmsg_mskq(msk, msg, len - copied);
1543                 if (unlikely(bytes_read < 0)) {
1544                         if (!copied)
1545                                 copied = bytes_read;
1546                         goto out_err;
1547                 }
1548
1549                 copied += bytes_read;
1550
1551                 if (skb_queue_empty(&sk->sk_receive_queue) &&
1552                     __mptcp_move_skbs(msk))
1553                         continue;
1554
1555                 /* only the master socket status is relevant here. The exit
1556                  * conditions mirror closely tcp_recvmsg()
1557                  */
1558                 if (copied >= target)
1559                         break;
1560
1561                 if (copied) {
1562                         if (sk->sk_err ||
1563                             sk->sk_state == TCP_CLOSE ||
1564                             (sk->sk_shutdown & RCV_SHUTDOWN) ||
1565                             !timeo ||
1566                             signal_pending(current))
1567                                 break;
1568                 } else {
1569                         if (sk->sk_err) {
1570                                 copied = sock_error(sk);
1571                                 break;
1572                         }
1573
1574                         if (test_and_clear_bit(MPTCP_WORK_EOF, &msk->flags))
1575                                 mptcp_check_for_eof(msk);
1576
1577                         if (sk->sk_shutdown & RCV_SHUTDOWN)
1578                                 break;
1579
1580                         if (sk->sk_state == TCP_CLOSE) {
1581                                 copied = -ENOTCONN;
1582                                 break;
1583                         }
1584
1585                         if (!timeo) {
1586                                 copied = -EAGAIN;
1587                                 break;
1588                         }
1589
1590                         if (signal_pending(current)) {
1591                                 copied = sock_intr_errno(timeo);
1592                                 break;
1593                         }
1594                 }
1595
1596                 pr_debug("block timeout %ld", timeo);
1597                 mptcp_wait_data(sk, &timeo);
1598         }
1599
1600         if (skb_queue_empty(&sk->sk_receive_queue)) {
1601                 /* entire backlog drained, clear DATA_READY. */
1602                 clear_bit(MPTCP_DATA_READY, &msk->flags);
1603
1604                 /* .. race-breaker: ssk might have gotten new data
1605                  * after last __mptcp_move_skbs() returned false.
1606                  */
1607                 if (unlikely(__mptcp_move_skbs(msk)))
1608                         set_bit(MPTCP_DATA_READY, &msk->flags);
1609         } else if (unlikely(!test_bit(MPTCP_DATA_READY, &msk->flags))) {
1610                 /* data to read but mptcp_wait_data() cleared DATA_READY */
1611                 set_bit(MPTCP_DATA_READY, &msk->flags);
1612         }
1613 out_err:
1614         pr_debug("msk=%p data_ready=%d rx queue empty=%d copied=%d",
1615                  msk, test_bit(MPTCP_DATA_READY, &msk->flags),
1616                  skb_queue_empty(&sk->sk_receive_queue), copied);
1617         mptcp_rcv_space_adjust(msk, copied);
1618
1619         release_sock(sk);
1620         return copied;
1621 }
1622
1623 static void mptcp_retransmit_handler(struct sock *sk)
1624 {
1625         struct mptcp_sock *msk = mptcp_sk(sk);
1626
1627         if (atomic64_read(&msk->snd_una) == READ_ONCE(msk->write_seq)) {
1628                 mptcp_stop_timer(sk);
1629         } else {
1630                 set_bit(MPTCP_WORK_RTX, &msk->flags);
1631                 mptcp_schedule_work(sk);
1632         }
1633 }
1634
1635 static void mptcp_retransmit_timer(struct timer_list *t)
1636 {
1637         struct inet_connection_sock *icsk = from_timer(icsk, t,
1638                                                        icsk_retransmit_timer);
1639         struct sock *sk = &icsk->icsk_inet.sk;
1640
1641         bh_lock_sock(sk);
1642         if (!sock_owned_by_user(sk)) {
1643                 mptcp_retransmit_handler(sk);
1644         } else {
1645                 /* delegate our work to tcp_release_cb() */
1646                 if (!test_and_set_bit(TCP_WRITE_TIMER_DEFERRED,
1647                                       &sk->sk_tsq_flags))
1648                         sock_hold(sk);
1649         }
1650         bh_unlock_sock(sk);
1651         sock_put(sk);
1652 }
1653
1654 /* Find an idle subflow.  Return NULL if there is unacked data at tcp
1655  * level.
1656  *
1657  * A backup subflow is returned only if that is the only kind available.
1658  */
1659 static struct sock *mptcp_subflow_get_retrans(const struct mptcp_sock *msk)
1660 {
1661         struct mptcp_subflow_context *subflow;
1662         struct sock *backup = NULL;
1663
1664         sock_owned_by_me((const struct sock *)msk);
1665
1666         if (__mptcp_check_fallback(msk))
1667                 return msk->first;
1668
1669         mptcp_for_each_subflow(msk, subflow) {
1670                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
1671
1672                 if (!mptcp_subflow_active(subflow))
1673                         continue;
1674
1675                 /* still data outstanding at TCP level?  Don't retransmit. */
1676                 if (!tcp_write_queue_empty(ssk))
1677                         return NULL;
1678
1679                 if (subflow->backup) {
1680                         if (!backup)
1681                                 backup = ssk;
1682                         continue;
1683                 }
1684
1685                 return ssk;
1686         }
1687
1688         return backup;
1689 }
1690
1691 /* subflow sockets can be either outgoing (connect) or incoming
1692  * (accept).
1693  *
1694  * Outgoing subflows use in-kernel sockets.
1695  * Incoming subflows do not have their own 'struct socket' allocated,
1696  * so we need to use tcp_close() after detaching them from the mptcp
1697  * parent socket.
1698  */
1699 void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
1700                        struct mptcp_subflow_context *subflow,
1701                        long timeout)
1702 {
1703         struct socket *sock = READ_ONCE(ssk->sk_socket);
1704
1705         list_del(&subflow->node);
1706
1707         if (sock && sock != sk->sk_socket) {
1708                 /* outgoing subflow */
1709                 sock_release(sock);
1710         } else {
1711                 /* incoming subflow */
1712                 tcp_close(ssk, timeout);
1713         }
1714 }
1715
1716 static unsigned int mptcp_sync_mss(struct sock *sk, u32 pmtu)
1717 {
1718         return 0;
1719 }
1720
1721 static void pm_work(struct mptcp_sock *msk)
1722 {
1723         struct mptcp_pm_data *pm = &msk->pm;
1724
1725         spin_lock_bh(&msk->pm.lock);
1726
1727         pr_debug("msk=%p status=%x", msk, pm->status);
1728         if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
1729                 pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
1730                 mptcp_pm_nl_add_addr_received(msk);
1731         }
1732         if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) {
1733                 pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED);
1734                 mptcp_pm_nl_rm_addr_received(msk);
1735         }
1736         if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
1737                 pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
1738                 mptcp_pm_nl_fully_established(msk);
1739         }
1740         if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
1741                 pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
1742                 mptcp_pm_nl_subflow_established(msk);
1743         }
1744
1745         spin_unlock_bh(&msk->pm.lock);
1746 }
1747
1748 static void __mptcp_close_subflow(struct mptcp_sock *msk)
1749 {
1750         struct mptcp_subflow_context *subflow, *tmp;
1751
1752         list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
1753                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
1754
1755                 if (inet_sk_state_load(ssk) != TCP_CLOSE)
1756                         continue;
1757
1758                 __mptcp_close_ssk((struct sock *)msk, ssk, subflow, 0);
1759         }
1760 }
1761
1762 static void mptcp_worker(struct work_struct *work)
1763 {
1764         struct mptcp_sock *msk = container_of(work, struct mptcp_sock, work);
1765         struct sock *ssk, *sk = &msk->sk.icsk_inet.sk;
1766         int orig_len, orig_offset, mss_now = 0, size_goal = 0;
1767         struct mptcp_data_frag *dfrag;
1768         u64 orig_write_seq;
1769         size_t copied = 0;
1770         struct msghdr msg = {
1771                 .msg_flags = MSG_DONTWAIT,
1772         };
1773         long timeo = 0;
1774
1775         lock_sock(sk);
1776         mptcp_clean_una_wakeup(sk);
1777         mptcp_check_data_fin_ack(sk);
1778         __mptcp_flush_join_list(msk);
1779         if (test_and_clear_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags))
1780                 __mptcp_close_subflow(msk);
1781
1782         __mptcp_move_skbs(msk);
1783
1784         if (msk->pm.status)
1785                 pm_work(msk);
1786
1787         if (test_and_clear_bit(MPTCP_WORK_EOF, &msk->flags))
1788                 mptcp_check_for_eof(msk);
1789
1790         mptcp_check_data_fin(sk);
1791
1792         if (!test_and_clear_bit(MPTCP_WORK_RTX, &msk->flags))
1793                 goto unlock;
1794
1795         dfrag = mptcp_rtx_head(sk);
1796         if (!dfrag)
1797                 goto unlock;
1798
1799         if (!mptcp_ext_cache_refill(msk))
1800                 goto reset_unlock;
1801
1802         ssk = mptcp_subflow_get_retrans(msk);
1803         if (!ssk)
1804                 goto reset_unlock;
1805
1806         lock_sock(ssk);
1807
1808         orig_len = dfrag->data_len;
1809         orig_offset = dfrag->offset;
1810         orig_write_seq = dfrag->data_seq;
1811         while (dfrag->data_len > 0) {
1812                 int ret = mptcp_sendmsg_frag(sk, ssk, &msg, dfrag, &timeo,
1813                                              &mss_now, &size_goal);
1814                 if (ret < 0)
1815                         break;
1816
1817                 MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RETRANSSEGS);
1818                 copied += ret;
1819                 dfrag->data_len -= ret;
1820                 dfrag->offset += ret;
1821
1822                 if (!mptcp_ext_cache_refill(msk))
1823                         break;
1824         }
1825         if (copied)
1826                 tcp_push(ssk, msg.msg_flags, mss_now, tcp_sk(ssk)->nonagle,
1827                          size_goal);
1828
1829         dfrag->data_seq = orig_write_seq;
1830         dfrag->offset = orig_offset;
1831         dfrag->data_len = orig_len;
1832
1833         mptcp_set_timeout(sk, ssk);
1834         release_sock(ssk);
1835
1836 reset_unlock:
1837         if (!mptcp_timer_pending(sk))
1838                 mptcp_reset_timer(sk);
1839
1840 unlock:
1841         release_sock(sk);
1842         sock_put(sk);
1843 }
1844
1845 static int __mptcp_init_sock(struct sock *sk)
1846 {
1847         struct mptcp_sock *msk = mptcp_sk(sk);
1848
1849         spin_lock_init(&msk->join_list_lock);
1850
1851         INIT_LIST_HEAD(&msk->conn_list);
1852         INIT_LIST_HEAD(&msk->join_list);
1853         INIT_LIST_HEAD(&msk->rtx_queue);
1854         __set_bit(MPTCP_SEND_SPACE, &msk->flags);
1855         INIT_WORK(&msk->work, mptcp_worker);
1856         msk->out_of_order_queue = RB_ROOT;
1857
1858         msk->first = NULL;
1859         inet_csk(sk)->icsk_sync_mss = mptcp_sync_mss;
1860
1861         mptcp_pm_data_init(msk);
1862
1863         /* re-use the csk retrans timer for MPTCP-level retrans */
1864         timer_setup(&msk->sk.icsk_retransmit_timer, mptcp_retransmit_timer, 0);
1865
1866         return 0;
1867 }
1868
1869 static int mptcp_init_sock(struct sock *sk)
1870 {
1871         struct net *net = sock_net(sk);
1872         int ret;
1873
1874         ret = __mptcp_init_sock(sk);
1875         if (ret)
1876                 return ret;
1877
1878         if (!mptcp_is_enabled(net))
1879                 return -ENOPROTOOPT;
1880
1881         if (unlikely(!net->mib.mptcp_statistics) && !mptcp_mib_alloc(net))
1882                 return -ENOMEM;
1883
1884         ret = __mptcp_socket_create(mptcp_sk(sk));
1885         if (ret)
1886                 return ret;
1887
1888         sk_sockets_allocated_inc(sk);
1889         sk->sk_rcvbuf = sock_net(sk)->ipv4.sysctl_tcp_rmem[1];
1890         sk->sk_sndbuf = sock_net(sk)->ipv4.sysctl_tcp_wmem[1];
1891
1892         return 0;
1893 }
1894
1895 static void __mptcp_clear_xmit(struct sock *sk)
1896 {
1897         struct mptcp_sock *msk = mptcp_sk(sk);
1898         struct mptcp_data_frag *dtmp, *dfrag;
1899
1900         sk_stop_timer(sk, &msk->sk.icsk_retransmit_timer);
1901
1902         list_for_each_entry_safe(dfrag, dtmp, &msk->rtx_queue, list)
1903                 dfrag_clear(sk, dfrag);
1904 }
1905
1906 static void mptcp_cancel_work(struct sock *sk)
1907 {
1908         struct mptcp_sock *msk = mptcp_sk(sk);
1909
1910         if (cancel_work_sync(&msk->work))
1911                 sock_put(sk);
1912 }
1913
1914 void mptcp_subflow_shutdown(struct sock *sk, struct sock *ssk, int how)
1915 {
1916         lock_sock(ssk);
1917
1918         switch (ssk->sk_state) {
1919         case TCP_LISTEN:
1920                 if (!(how & RCV_SHUTDOWN))
1921                         break;
1922                 fallthrough;
1923         case TCP_SYN_SENT:
1924                 tcp_disconnect(ssk, O_NONBLOCK);
1925                 break;
1926         default:
1927                 if (__mptcp_check_fallback(mptcp_sk(sk))) {
1928                         pr_debug("Fallback");
1929                         ssk->sk_shutdown |= how;
1930                         tcp_shutdown(ssk, how);
1931                 } else {
1932                         pr_debug("Sending DATA_FIN on subflow %p", ssk);
1933                         mptcp_set_timeout(sk, ssk);
1934                         tcp_send_ack(ssk);
1935                 }
1936                 break;
1937         }
1938
1939         release_sock(ssk);
1940 }
1941
1942 static const unsigned char new_state[16] = {
1943         /* current state:     new state:      action:   */
1944         [0 /* (Invalid) */] = TCP_CLOSE,
1945         [TCP_ESTABLISHED]   = TCP_FIN_WAIT1 | TCP_ACTION_FIN,
1946         [TCP_SYN_SENT]      = TCP_CLOSE,
1947         [TCP_SYN_RECV]      = TCP_FIN_WAIT1 | TCP_ACTION_FIN,
1948         [TCP_FIN_WAIT1]     = TCP_FIN_WAIT1,
1949         [TCP_FIN_WAIT2]     = TCP_FIN_WAIT2,
1950         [TCP_TIME_WAIT]     = TCP_CLOSE,        /* should not happen ! */
1951         [TCP_CLOSE]         = TCP_CLOSE,
1952         [TCP_CLOSE_WAIT]    = TCP_LAST_ACK  | TCP_ACTION_FIN,
1953         [TCP_LAST_ACK]      = TCP_LAST_ACK,
1954         [TCP_LISTEN]        = TCP_CLOSE,
1955         [TCP_CLOSING]       = TCP_CLOSING,
1956         [TCP_NEW_SYN_RECV]  = TCP_CLOSE,        /* should not happen ! */
1957 };
1958
1959 static int mptcp_close_state(struct sock *sk)
1960 {
1961         int next = (int)new_state[sk->sk_state];
1962         int ns = next & TCP_STATE_MASK;
1963
1964         inet_sk_state_store(sk, ns);
1965
1966         return next & TCP_ACTION_FIN;
1967 }
1968
1969 static void mptcp_close(struct sock *sk, long timeout)
1970 {
1971         struct mptcp_subflow_context *subflow, *tmp;
1972         struct mptcp_sock *msk = mptcp_sk(sk);
1973         LIST_HEAD(conn_list);
1974
1975         lock_sock(sk);
1976         sk->sk_shutdown = SHUTDOWN_MASK;
1977
1978         if (sk->sk_state == TCP_LISTEN) {
1979                 inet_sk_state_store(sk, TCP_CLOSE);
1980                 goto cleanup;
1981         } else if (sk->sk_state == TCP_CLOSE) {
1982                 goto cleanup;
1983         }
1984
1985         if (__mptcp_check_fallback(msk)) {
1986                 goto update_state;
1987         } else if (mptcp_close_state(sk)) {
1988                 pr_debug("Sending DATA_FIN sk=%p", sk);
1989                 WRITE_ONCE(msk->write_seq, msk->write_seq + 1);
1990                 WRITE_ONCE(msk->snd_data_fin_enable, 1);
1991
1992                 mptcp_for_each_subflow(msk, subflow) {
1993                         struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
1994
1995                         mptcp_subflow_shutdown(sk, tcp_sk, SHUTDOWN_MASK);
1996                 }
1997         }
1998
1999         sk_stream_wait_close(sk, timeout);
2000
2001 update_state:
2002         inet_sk_state_store(sk, TCP_CLOSE);
2003
2004 cleanup:
2005         /* be sure to always acquire the join list lock, to sync vs
2006          * mptcp_finish_join().
2007          */
2008         spin_lock_bh(&msk->join_list_lock);
2009         list_splice_tail_init(&msk->join_list, &msk->conn_list);
2010         spin_unlock_bh(&msk->join_list_lock);
2011         list_splice_init(&msk->conn_list, &conn_list);
2012
2013         __mptcp_clear_xmit(sk);
2014
2015         release_sock(sk);
2016
2017         list_for_each_entry_safe(subflow, tmp, &conn_list, node) {
2018                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
2019                 __mptcp_close_ssk(sk, ssk, subflow, timeout);
2020         }
2021
2022         mptcp_cancel_work(sk);
2023
2024         __skb_queue_purge(&sk->sk_receive_queue);
2025
2026         sk_common_release(sk);
2027 }
2028
2029 static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
2030 {
2031 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
2032         const struct ipv6_pinfo *ssk6 = inet6_sk(ssk);
2033         struct ipv6_pinfo *msk6 = inet6_sk(msk);
2034
2035         msk->sk_v6_daddr = ssk->sk_v6_daddr;
2036         msk->sk_v6_rcv_saddr = ssk->sk_v6_rcv_saddr;
2037
2038         if (msk6 && ssk6) {
2039                 msk6->saddr = ssk6->saddr;
2040                 msk6->flow_label = ssk6->flow_label;
2041         }
2042 #endif
2043
2044         inet_sk(msk)->inet_num = inet_sk(ssk)->inet_num;
2045         inet_sk(msk)->inet_dport = inet_sk(ssk)->inet_dport;
2046         inet_sk(msk)->inet_sport = inet_sk(ssk)->inet_sport;
2047         inet_sk(msk)->inet_daddr = inet_sk(ssk)->inet_daddr;
2048         inet_sk(msk)->inet_saddr = inet_sk(ssk)->inet_saddr;
2049         inet_sk(msk)->inet_rcv_saddr = inet_sk(ssk)->inet_rcv_saddr;
2050 }
2051
2052 static int mptcp_disconnect(struct sock *sk, int flags)
2053 {
2054         /* Should never be called.
2055          * inet_stream_connect() calls ->disconnect, but that
2056          * refers to the subflow socket, not the mptcp one.
2057          */
2058         WARN_ON_ONCE(1);
2059         return 0;
2060 }
2061
2062 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
2063 static struct ipv6_pinfo *mptcp_inet6_sk(const struct sock *sk)
2064 {
2065         unsigned int offset = sizeof(struct mptcp6_sock) - sizeof(struct ipv6_pinfo);
2066
2067         return (struct ipv6_pinfo *)(((u8 *)sk) + offset);
2068 }
2069 #endif
2070
2071 struct sock *mptcp_sk_clone(const struct sock *sk,
2072                             const struct mptcp_options_received *mp_opt,
2073                             struct request_sock *req)
2074 {
2075         struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
2076         struct sock *nsk = sk_clone_lock(sk, GFP_ATOMIC);
2077         struct mptcp_sock *msk;
2078         u64 ack_seq;
2079
2080         if (!nsk)
2081                 return NULL;
2082
2083 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
2084         if (nsk->sk_family == AF_INET6)
2085                 inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk);
2086 #endif
2087
2088         __mptcp_init_sock(nsk);
2089
2090         msk = mptcp_sk(nsk);
2091         msk->local_key = subflow_req->local_key;
2092         msk->token = subflow_req->token;
2093         msk->subflow = NULL;
2094         WRITE_ONCE(msk->fully_established, false);
2095
2096         msk->write_seq = subflow_req->idsn + 1;
2097         atomic64_set(&msk->snd_una, msk->write_seq);
2098         if (mp_opt->mp_capable) {
2099                 msk->can_ack = true;
2100                 msk->remote_key = mp_opt->sndr_key;
2101                 mptcp_crypto_key_sha(msk->remote_key, NULL, &ack_seq);
2102                 ack_seq++;
2103                 WRITE_ONCE(msk->ack_seq, ack_seq);
2104         }
2105
2106         sock_reset_flag(nsk, SOCK_RCU_FREE);
2107         /* will be fully established after successful MPC subflow creation */
2108         inet_sk_state_store(nsk, TCP_SYN_RECV);
2109         bh_unlock_sock(nsk);
2110
2111         /* keep a single reference */
2112         __sock_put(nsk);
2113         return nsk;
2114 }
2115
2116 void mptcp_rcv_space_init(struct mptcp_sock *msk, const struct sock *ssk)
2117 {
2118         const struct tcp_sock *tp = tcp_sk(ssk);
2119
2120         msk->rcvq_space.copied = 0;
2121         msk->rcvq_space.rtt_us = 0;
2122
2123         msk->rcvq_space.time = tp->tcp_mstamp;
2124
2125         /* initial rcv_space offering made to peer */
2126         msk->rcvq_space.space = min_t(u32, tp->rcv_wnd,
2127                                       TCP_INIT_CWND * tp->advmss);
2128         if (msk->rcvq_space.space == 0)
2129                 msk->rcvq_space.space = TCP_INIT_CWND * TCP_MSS_DEFAULT;
2130 }
2131
2132 static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
2133                                  bool kern)
2134 {
2135         struct mptcp_sock *msk = mptcp_sk(sk);
2136         struct socket *listener;
2137         struct sock *newsk;
2138
2139         listener = __mptcp_nmpc_socket(msk);
2140         if (WARN_ON_ONCE(!listener)) {
2141                 *err = -EINVAL;
2142                 return NULL;
2143         }
2144
2145         pr_debug("msk=%p, listener=%p", msk, mptcp_subflow_ctx(listener->sk));
2146         newsk = inet_csk_accept(listener->sk, flags, err, kern);
2147         if (!newsk)
2148                 return NULL;
2149
2150         pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
2151         if (sk_is_mptcp(newsk)) {
2152                 struct mptcp_subflow_context *subflow;
2153                 struct sock *new_mptcp_sock;
2154                 struct sock *ssk = newsk;
2155
2156                 subflow = mptcp_subflow_ctx(newsk);
2157                 new_mptcp_sock = subflow->conn;
2158
2159                 /* is_mptcp should be false if subflow->conn is missing, see
2160                  * subflow_syn_recv_sock()
2161                  */
2162                 if (WARN_ON_ONCE(!new_mptcp_sock)) {
2163                         tcp_sk(newsk)->is_mptcp = 0;
2164                         return newsk;
2165                 }
2166
2167                 /* acquire the 2nd reference for the owning socket */
2168                 sock_hold(new_mptcp_sock);
2169
2170                 local_bh_disable();
2171                 bh_lock_sock(new_mptcp_sock);
2172                 msk = mptcp_sk(new_mptcp_sock);
2173                 msk->first = newsk;
2174
2175                 newsk = new_mptcp_sock;
2176                 mptcp_copy_inaddrs(newsk, ssk);
2177                 list_add(&subflow->node, &msk->conn_list);
2178
2179                 mptcp_rcv_space_init(msk, ssk);
2180                 bh_unlock_sock(new_mptcp_sock);
2181
2182                 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEPASSIVEACK);
2183                 local_bh_enable();
2184         } else {
2185                 MPTCP_INC_STATS(sock_net(sk),
2186                                 MPTCP_MIB_MPCAPABLEPASSIVEFALLBACK);
2187         }
2188
2189         return newsk;
2190 }
2191
2192 void mptcp_destroy_common(struct mptcp_sock *msk)
2193 {
2194         skb_rbtree_purge(&msk->out_of_order_queue);
2195         mptcp_token_destroy(msk);
2196         mptcp_pm_free_anno_list(msk);
2197 }
2198
2199 static void mptcp_destroy(struct sock *sk)
2200 {
2201         struct mptcp_sock *msk = mptcp_sk(sk);
2202
2203         if (msk->cached_ext)
2204                 __skb_ext_put(msk->cached_ext);
2205
2206         mptcp_destroy_common(msk);
2207         sk_sockets_allocated_dec(sk);
2208 }
2209
2210 static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname,
2211                                        sockptr_t optval, unsigned int optlen)
2212 {
2213         struct sock *sk = (struct sock *)msk;
2214         struct socket *ssock;
2215         int ret;
2216
2217         switch (optname) {
2218         case SO_REUSEPORT:
2219         case SO_REUSEADDR:
2220                 lock_sock(sk);
2221                 ssock = __mptcp_nmpc_socket(msk);
2222                 if (!ssock) {
2223                         release_sock(sk);
2224                         return -EINVAL;
2225                 }
2226
2227                 ret = sock_setsockopt(ssock, SOL_SOCKET, optname, optval, optlen);
2228                 if (ret == 0) {
2229                         if (optname == SO_REUSEPORT)
2230                                 sk->sk_reuseport = ssock->sk->sk_reuseport;
2231                         else if (optname == SO_REUSEADDR)
2232                                 sk->sk_reuse = ssock->sk->sk_reuse;
2233                 }
2234                 release_sock(sk);
2235                 return ret;
2236         }
2237
2238         return sock_setsockopt(sk->sk_socket, SOL_SOCKET, optname, optval, optlen);
2239 }
2240
2241 static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname,
2242                                sockptr_t optval, unsigned int optlen)
2243 {
2244         struct sock *sk = (struct sock *)msk;
2245         int ret = -EOPNOTSUPP;
2246         struct socket *ssock;
2247
2248         switch (optname) {
2249         case IPV6_V6ONLY:
2250                 lock_sock(sk);
2251                 ssock = __mptcp_nmpc_socket(msk);
2252                 if (!ssock) {
2253                         release_sock(sk);
2254                         return -EINVAL;
2255                 }
2256
2257                 ret = tcp_setsockopt(ssock->sk, SOL_IPV6, optname, optval, optlen);
2258                 if (ret == 0)
2259                         sk->sk_ipv6only = ssock->sk->sk_ipv6only;
2260
2261                 release_sock(sk);
2262                 break;
2263         }
2264
2265         return ret;
2266 }
2267
2268 static int mptcp_setsockopt(struct sock *sk, int level, int optname,
2269                             sockptr_t optval, unsigned int optlen)
2270 {
2271         struct mptcp_sock *msk = mptcp_sk(sk);
2272         struct sock *ssk;
2273
2274         pr_debug("msk=%p", msk);
2275
2276         if (level == SOL_SOCKET)
2277                 return mptcp_setsockopt_sol_socket(msk, optname, optval, optlen);
2278
2279         /* @@ the meaning of setsockopt() when the socket is connected and
2280          * there are multiple subflows is not yet defined. It is up to the
2281          * MPTCP-level socket to configure the subflows until the subflow
2282          * is in TCP fallback, when TCP socket options are passed through
2283          * to the one remaining subflow.
2284          */
2285         lock_sock(sk);
2286         ssk = __mptcp_tcp_fallback(msk);
2287         release_sock(sk);
2288         if (ssk)
2289                 return tcp_setsockopt(ssk, level, optname, optval, optlen);
2290
2291         if (level == SOL_IPV6)
2292                 return mptcp_setsockopt_v6(msk, optname, optval, optlen);
2293
2294         return -EOPNOTSUPP;
2295 }
2296
2297 static int mptcp_getsockopt(struct sock *sk, int level, int optname,
2298                             char __user *optval, int __user *option)
2299 {
2300         struct mptcp_sock *msk = mptcp_sk(sk);
2301         struct sock *ssk;
2302
2303         pr_debug("msk=%p", msk);
2304
2305         /* @@ the meaning of setsockopt() when the socket is connected and
2306          * there are multiple subflows is not yet defined. It is up to the
2307          * MPTCP-level socket to configure the subflows until the subflow
2308          * is in TCP fallback, when socket options are passed through
2309          * to the one remaining subflow.
2310          */
2311         lock_sock(sk);
2312         ssk = __mptcp_tcp_fallback(msk);
2313         release_sock(sk);
2314         if (ssk)
2315                 return tcp_getsockopt(ssk, level, optname, optval, option);
2316
2317         return -EOPNOTSUPP;
2318 }
2319
2320 #define MPTCP_DEFERRED_ALL (TCPF_DELACK_TIMER_DEFERRED | \
2321                             TCPF_WRITE_TIMER_DEFERRED)
2322
2323 /* this is very alike tcp_release_cb() but we must handle differently a
2324  * different set of events
2325  */
2326 static void mptcp_release_cb(struct sock *sk)
2327 {
2328         unsigned long flags, nflags;
2329
2330         do {
2331                 flags = sk->sk_tsq_flags;
2332                 if (!(flags & MPTCP_DEFERRED_ALL))
2333                         return;
2334                 nflags = flags & ~MPTCP_DEFERRED_ALL;
2335         } while (cmpxchg(&sk->sk_tsq_flags, flags, nflags) != flags);
2336
2337         sock_release_ownership(sk);
2338
2339         if (flags & TCPF_DELACK_TIMER_DEFERRED) {
2340                 struct mptcp_sock *msk = mptcp_sk(sk);
2341                 struct sock *ssk;
2342
2343                 ssk = mptcp_subflow_recv_lookup(msk);
2344                 if (!ssk || sk->sk_state == TCP_CLOSE ||
2345                     !schedule_work(&msk->work))
2346                         __sock_put(sk);
2347         }
2348
2349         if (flags & TCPF_WRITE_TIMER_DEFERRED) {
2350                 mptcp_retransmit_handler(sk);
2351                 __sock_put(sk);
2352         }
2353 }
2354
2355 static int mptcp_hash(struct sock *sk)
2356 {
2357         /* should never be called,
2358          * we hash the TCP subflows not the master socket
2359          */
2360         WARN_ON_ONCE(1);
2361         return 0;
2362 }
2363
2364 static void mptcp_unhash(struct sock *sk)
2365 {
2366         /* called from sk_common_release(), but nothing to do here */
2367 }
2368
2369 static int mptcp_get_port(struct sock *sk, unsigned short snum)
2370 {
2371         struct mptcp_sock *msk = mptcp_sk(sk);
2372         struct socket *ssock;
2373
2374         ssock = __mptcp_nmpc_socket(msk);
2375         pr_debug("msk=%p, subflow=%p", msk, ssock);
2376         if (WARN_ON_ONCE(!ssock))
2377                 return -EINVAL;
2378
2379         return inet_csk_get_port(ssock->sk, snum);
2380 }
2381
2382 void mptcp_finish_connect(struct sock *ssk)
2383 {
2384         struct mptcp_subflow_context *subflow;
2385         struct mptcp_sock *msk;
2386         struct sock *sk;
2387         u64 ack_seq;
2388
2389         subflow = mptcp_subflow_ctx(ssk);
2390         sk = subflow->conn;
2391         msk = mptcp_sk(sk);
2392
2393         pr_debug("msk=%p, token=%u", sk, subflow->token);
2394
2395         mptcp_crypto_key_sha(subflow->remote_key, NULL, &ack_seq);
2396         ack_seq++;
2397         subflow->map_seq = ack_seq;
2398         subflow->map_subflow_seq = 1;
2399
2400         /* the socket is not connected yet, no msk/subflow ops can access/race
2401          * accessing the field below
2402          */
2403         WRITE_ONCE(msk->remote_key, subflow->remote_key);
2404         WRITE_ONCE(msk->local_key, subflow->local_key);
2405         WRITE_ONCE(msk->write_seq, subflow->idsn + 1);
2406         WRITE_ONCE(msk->ack_seq, ack_seq);
2407         WRITE_ONCE(msk->can_ack, 1);
2408         atomic64_set(&msk->snd_una, msk->write_seq);
2409
2410         mptcp_pm_new_connection(msk, 0);
2411
2412         mptcp_rcv_space_init(msk, ssk);
2413 }
2414
2415 static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
2416 {
2417         write_lock_bh(&sk->sk_callback_lock);
2418         rcu_assign_pointer(sk->sk_wq, &parent->wq);
2419         sk_set_socket(sk, parent);
2420         sk->sk_uid = SOCK_INODE(parent)->i_uid;
2421         write_unlock_bh(&sk->sk_callback_lock);
2422 }
2423
2424 bool mptcp_finish_join(struct sock *sk)
2425 {
2426         struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
2427         struct mptcp_sock *msk = mptcp_sk(subflow->conn);
2428         struct sock *parent = (void *)msk;
2429         struct socket *parent_sock;
2430         bool ret;
2431
2432         pr_debug("msk=%p, subflow=%p", msk, subflow);
2433
2434         /* mptcp socket already closing? */
2435         if (!mptcp_is_fully_established(parent))
2436                 return false;
2437
2438         if (!msk->pm.server_side)
2439                 return true;
2440
2441         if (!mptcp_pm_allow_new_subflow(msk))
2442                 return false;
2443
2444         /* active connections are already on conn_list, and we can't acquire
2445          * msk lock here.
2446          * use the join list lock as synchronization point and double-check
2447          * msk status to avoid racing with mptcp_close()
2448          */
2449         spin_lock_bh(&msk->join_list_lock);
2450         ret = inet_sk_state_load(parent) == TCP_ESTABLISHED;
2451         if (ret && !WARN_ON_ONCE(!list_empty(&subflow->node)))
2452                 list_add_tail(&subflow->node, &msk->join_list);
2453         spin_unlock_bh(&msk->join_list_lock);
2454         if (!ret)
2455                 return false;
2456
2457         /* attach to msk socket only after we are sure he will deal with us
2458          * at close time
2459          */
2460         parent_sock = READ_ONCE(parent->sk_socket);
2461         if (parent_sock && !sk->sk_socket)
2462                 mptcp_sock_graft(sk, parent_sock);
2463         subflow->map_seq = READ_ONCE(msk->ack_seq);
2464         return true;
2465 }
2466
2467 static bool mptcp_memory_free(const struct sock *sk, int wake)
2468 {
2469         struct mptcp_sock *msk = mptcp_sk(sk);
2470
2471         return wake ? test_bit(MPTCP_SEND_SPACE, &msk->flags) : true;
2472 }
2473
2474 static struct proto mptcp_prot = {
2475         .name           = "MPTCP",
2476         .owner          = THIS_MODULE,
2477         .init           = mptcp_init_sock,
2478         .disconnect     = mptcp_disconnect,
2479         .close          = mptcp_close,
2480         .accept         = mptcp_accept,
2481         .setsockopt     = mptcp_setsockopt,
2482         .getsockopt     = mptcp_getsockopt,
2483         .shutdown       = tcp_shutdown,
2484         .destroy        = mptcp_destroy,
2485         .sendmsg        = mptcp_sendmsg,
2486         .recvmsg        = mptcp_recvmsg,
2487         .release_cb     = mptcp_release_cb,
2488         .hash           = mptcp_hash,
2489         .unhash         = mptcp_unhash,
2490         .get_port       = mptcp_get_port,
2491         .sockets_allocated      = &mptcp_sockets_allocated,
2492         .memory_allocated       = &tcp_memory_allocated,
2493         .memory_pressure        = &tcp_memory_pressure,
2494         .stream_memory_free     = mptcp_memory_free,
2495         .sysctl_wmem_offset     = offsetof(struct net, ipv4.sysctl_tcp_wmem),
2496         .sysctl_rmem_offset     = offsetof(struct net, ipv4.sysctl_tcp_rmem),
2497         .sysctl_mem     = sysctl_tcp_mem,
2498         .obj_size       = sizeof(struct mptcp_sock),
2499         .slab_flags     = SLAB_TYPESAFE_BY_RCU,
2500         .no_autobind    = true,
2501 };
2502
2503 static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
2504 {
2505         struct mptcp_sock *msk = mptcp_sk(sock->sk);
2506         struct socket *ssock;
2507         int err;
2508
2509         lock_sock(sock->sk);
2510         ssock = __mptcp_nmpc_socket(msk);
2511         if (!ssock) {
2512                 err = -EINVAL;
2513                 goto unlock;
2514         }
2515
2516         err = ssock->ops->bind(ssock, uaddr, addr_len);
2517         if (!err)
2518                 mptcp_copy_inaddrs(sock->sk, ssock->sk);
2519
2520 unlock:
2521         release_sock(sock->sk);
2522         return err;
2523 }
2524
2525 static void mptcp_subflow_early_fallback(struct mptcp_sock *msk,
2526                                          struct mptcp_subflow_context *subflow)
2527 {
2528         subflow->request_mptcp = 0;
2529         __mptcp_do_fallback(msk);
2530 }
2531
2532 static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
2533                                 int addr_len, int flags)
2534 {
2535         struct mptcp_sock *msk = mptcp_sk(sock->sk);
2536         struct mptcp_subflow_context *subflow;
2537         struct socket *ssock;
2538         int err;
2539
2540         lock_sock(sock->sk);
2541         if (sock->state != SS_UNCONNECTED && msk->subflow) {
2542                 /* pending connection or invalid state, let existing subflow
2543                  * cope with that
2544                  */
2545                 ssock = msk->subflow;
2546                 goto do_connect;
2547         }
2548
2549         ssock = __mptcp_nmpc_socket(msk);
2550         if (!ssock) {
2551                 err = -EINVAL;
2552                 goto unlock;
2553         }
2554
2555         mptcp_token_destroy(msk);
2556         inet_sk_state_store(sock->sk, TCP_SYN_SENT);
2557         subflow = mptcp_subflow_ctx(ssock->sk);
2558 #ifdef CONFIG_TCP_MD5SIG
2559         /* no MPTCP if MD5SIG is enabled on this socket or we may run out of
2560          * TCP option space.
2561          */
2562         if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
2563                 mptcp_subflow_early_fallback(msk, subflow);
2564 #endif
2565         if (subflow->request_mptcp && mptcp_token_new_connect(ssock->sk))
2566                 mptcp_subflow_early_fallback(msk, subflow);
2567
2568 do_connect:
2569         err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
2570         sock->state = ssock->state;
2571
2572         /* on successful connect, the msk state will be moved to established by
2573          * subflow_finish_connect()
2574          */
2575         if (!err || err == -EINPROGRESS)
2576                 mptcp_copy_inaddrs(sock->sk, ssock->sk);
2577         else
2578                 inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
2579
2580 unlock:
2581         release_sock(sock->sk);
2582         return err;
2583 }
2584
2585 static int mptcp_listen(struct socket *sock, int backlog)
2586 {
2587         struct mptcp_sock *msk = mptcp_sk(sock->sk);
2588         struct socket *ssock;
2589         int err;
2590
2591         pr_debug("msk=%p", msk);
2592
2593         lock_sock(sock->sk);
2594         ssock = __mptcp_nmpc_socket(msk);
2595         if (!ssock) {
2596                 err = -EINVAL;
2597                 goto unlock;
2598         }
2599
2600         mptcp_token_destroy(msk);
2601         inet_sk_state_store(sock->sk, TCP_LISTEN);
2602         sock_set_flag(sock->sk, SOCK_RCU_FREE);
2603
2604         err = ssock->ops->listen(ssock, backlog);
2605         inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
2606         if (!err)
2607                 mptcp_copy_inaddrs(sock->sk, ssock->sk);
2608
2609 unlock:
2610         release_sock(sock->sk);
2611         return err;
2612 }
2613
2614 static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
2615                                int flags, bool kern)
2616 {
2617         struct mptcp_sock *msk = mptcp_sk(sock->sk);
2618         struct socket *ssock;
2619         int err;
2620
2621         pr_debug("msk=%p", msk);
2622
2623         lock_sock(sock->sk);
2624         if (sock->sk->sk_state != TCP_LISTEN)
2625                 goto unlock_fail;
2626
2627         ssock = __mptcp_nmpc_socket(msk);
2628         if (!ssock)
2629                 goto unlock_fail;
2630
2631         clear_bit(MPTCP_DATA_READY, &msk->flags);
2632         sock_hold(ssock->sk);
2633         release_sock(sock->sk);
2634
2635         err = ssock->ops->accept(sock, newsock, flags, kern);
2636         if (err == 0 && !mptcp_is_tcpsk(newsock->sk)) {
2637                 struct mptcp_sock *msk = mptcp_sk(newsock->sk);
2638                 struct mptcp_subflow_context *subflow;
2639
2640                 /* set ssk->sk_socket of accept()ed flows to mptcp socket.
2641                  * This is needed so NOSPACE flag can be set from tcp stack.
2642                  */
2643                 __mptcp_flush_join_list(msk);
2644                 mptcp_for_each_subflow(msk, subflow) {
2645                         struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
2646
2647                         if (!ssk->sk_socket)
2648                                 mptcp_sock_graft(ssk, newsock);
2649                 }
2650         }
2651
2652         if (inet_csk_listen_poll(ssock->sk))
2653                 set_bit(MPTCP_DATA_READY, &msk->flags);
2654         sock_put(ssock->sk);
2655         return err;
2656
2657 unlock_fail:
2658         release_sock(sock->sk);
2659         return -EINVAL;
2660 }
2661
2662 static __poll_t mptcp_check_readable(struct mptcp_sock *msk)
2663 {
2664         return test_bit(MPTCP_DATA_READY, &msk->flags) ? EPOLLIN | EPOLLRDNORM :
2665                0;
2666 }
2667
2668 static __poll_t mptcp_poll(struct file *file, struct socket *sock,
2669                            struct poll_table_struct *wait)
2670 {
2671         struct sock *sk = sock->sk;
2672         struct mptcp_sock *msk;
2673         __poll_t mask = 0;
2674         int state;
2675
2676         msk = mptcp_sk(sk);
2677         sock_poll_wait(file, sock, wait);
2678
2679         state = inet_sk_state_load(sk);
2680         pr_debug("msk=%p state=%d flags=%lx", msk, state, msk->flags);
2681         if (state == TCP_LISTEN)
2682                 return mptcp_check_readable(msk);
2683
2684         if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) {
2685                 mask |= mptcp_check_readable(msk);
2686                 if (test_bit(MPTCP_SEND_SPACE, &msk->flags))
2687                         mask |= EPOLLOUT | EPOLLWRNORM;
2688         }
2689         if (sk->sk_shutdown & RCV_SHUTDOWN)
2690                 mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
2691
2692         return mask;
2693 }
2694
2695 static int mptcp_shutdown(struct socket *sock, int how)
2696 {
2697         struct mptcp_sock *msk = mptcp_sk(sock->sk);
2698         struct mptcp_subflow_context *subflow;
2699         int ret = 0;
2700
2701         pr_debug("sk=%p, how=%d", msk, how);
2702
2703         lock_sock(sock->sk);
2704
2705         how++;
2706         if ((how & ~SHUTDOWN_MASK) || !how) {
2707                 ret = -EINVAL;
2708                 goto out_unlock;
2709         }
2710
2711         if (sock->state == SS_CONNECTING) {
2712                 if ((1 << sock->sk->sk_state) &
2713                     (TCPF_SYN_SENT | TCPF_SYN_RECV | TCPF_CLOSE))
2714                         sock->state = SS_DISCONNECTING;
2715                 else
2716                         sock->state = SS_CONNECTED;
2717         }
2718
2719         /* If we've already sent a FIN, or it's a closed state, skip this. */
2720         if (__mptcp_check_fallback(msk)) {
2721                 if (how == SHUT_WR || how == SHUT_RDWR)
2722                         inet_sk_state_store(sock->sk, TCP_FIN_WAIT1);
2723
2724                 mptcp_for_each_subflow(msk, subflow) {
2725                         struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
2726
2727                         mptcp_subflow_shutdown(sock->sk, tcp_sk, how);
2728                 }
2729         } else if ((how & SEND_SHUTDOWN) &&
2730                    ((1 << sock->sk->sk_state) &
2731                     (TCPF_ESTABLISHED | TCPF_SYN_SENT |
2732                      TCPF_SYN_RECV | TCPF_CLOSE_WAIT)) &&
2733                    mptcp_close_state(sock->sk)) {
2734                 __mptcp_flush_join_list(msk);
2735
2736                 WRITE_ONCE(msk->write_seq, msk->write_seq + 1);
2737                 WRITE_ONCE(msk->snd_data_fin_enable, 1);
2738
2739                 mptcp_for_each_subflow(msk, subflow) {
2740                         struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
2741
2742                         mptcp_subflow_shutdown(sock->sk, tcp_sk, how);
2743                 }
2744         }
2745
2746         /* Wake up anyone sleeping in poll. */
2747         sock->sk->sk_state_change(sock->sk);
2748
2749 out_unlock:
2750         release_sock(sock->sk);
2751
2752         return ret;
2753 }
2754
2755 static const struct proto_ops mptcp_stream_ops = {
2756         .family            = PF_INET,
2757         .owner             = THIS_MODULE,
2758         .release           = inet_release,
2759         .bind              = mptcp_bind,
2760         .connect           = mptcp_stream_connect,
2761         .socketpair        = sock_no_socketpair,
2762         .accept            = mptcp_stream_accept,
2763         .getname           = inet_getname,
2764         .poll              = mptcp_poll,
2765         .ioctl             = inet_ioctl,
2766         .gettstamp         = sock_gettstamp,
2767         .listen            = mptcp_listen,
2768         .shutdown          = mptcp_shutdown,
2769         .setsockopt        = sock_common_setsockopt,
2770         .getsockopt        = sock_common_getsockopt,
2771         .sendmsg           = inet_sendmsg,
2772         .recvmsg           = inet_recvmsg,
2773         .mmap              = sock_no_mmap,
2774         .sendpage          = inet_sendpage,
2775 };
2776
2777 static struct inet_protosw mptcp_protosw = {
2778         .type           = SOCK_STREAM,
2779         .protocol       = IPPROTO_MPTCP,
2780         .prot           = &mptcp_prot,
2781         .ops            = &mptcp_stream_ops,
2782         .flags          = INET_PROTOSW_ICSK,
2783 };
2784
2785 void __init mptcp_proto_init(void)
2786 {
2787         mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;
2788
2789         if (percpu_counter_init(&mptcp_sockets_allocated, 0, GFP_KERNEL))
2790                 panic("Failed to allocate MPTCP pcpu counter\n");
2791
2792         mptcp_subflow_init();
2793         mptcp_pm_init();
2794         mptcp_token_init();
2795
2796         if (proto_register(&mptcp_prot, 1) != 0)
2797                 panic("Failed to register MPTCP proto.\n");
2798
2799         inet_register_protosw(&mptcp_protosw);
2800
2801         BUILD_BUG_ON(sizeof(struct mptcp_skb_cb) > sizeof_field(struct sk_buff, cb));
2802 }
2803
2804 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
2805 static const struct proto_ops mptcp_v6_stream_ops = {
2806         .family            = PF_INET6,
2807         .owner             = THIS_MODULE,
2808         .release           = inet6_release,
2809         .bind              = mptcp_bind,
2810         .connect           = mptcp_stream_connect,
2811         .socketpair        = sock_no_socketpair,
2812         .accept            = mptcp_stream_accept,
2813         .getname           = inet6_getname,
2814         .poll              = mptcp_poll,
2815         .ioctl             = inet6_ioctl,
2816         .gettstamp         = sock_gettstamp,
2817         .listen            = mptcp_listen,
2818         .shutdown          = mptcp_shutdown,
2819         .setsockopt        = sock_common_setsockopt,
2820         .getsockopt        = sock_common_getsockopt,
2821         .sendmsg           = inet6_sendmsg,
2822         .recvmsg           = inet6_recvmsg,
2823         .mmap              = sock_no_mmap,
2824         .sendpage          = inet_sendpage,
2825 #ifdef CONFIG_COMPAT
2826         .compat_ioctl      = inet6_compat_ioctl,
2827 #endif
2828 };
2829
2830 static struct proto mptcp_v6_prot;
2831
2832 static void mptcp_v6_destroy(struct sock *sk)
2833 {
2834         mptcp_destroy(sk);
2835         inet6_destroy_sock(sk);
2836 }
2837
2838 static struct inet_protosw mptcp_v6_protosw = {
2839         .type           = SOCK_STREAM,
2840         .protocol       = IPPROTO_MPTCP,
2841         .prot           = &mptcp_v6_prot,
2842         .ops            = &mptcp_v6_stream_ops,
2843         .flags          = INET_PROTOSW_ICSK,
2844 };
2845
2846 int __init mptcp_proto_v6_init(void)
2847 {
2848         int err;
2849
2850         mptcp_v6_prot = mptcp_prot;
2851         strcpy(mptcp_v6_prot.name, "MPTCPv6");
2852         mptcp_v6_prot.slab = NULL;
2853         mptcp_v6_prot.destroy = mptcp_v6_destroy;
2854         mptcp_v6_prot.obj_size = sizeof(struct mptcp6_sock);
2855
2856         err = proto_register(&mptcp_v6_prot, 1);
2857         if (err)
2858                 return err;
2859
2860         err = inet6_register_protosw(&mptcp_v6_protosw);
2861         if (err)
2862                 proto_unregister(&mptcp_v6_prot);
2863
2864         return err;
2865 }
2866 #endif