cbf184a71ed75039bd9f72cd504b5e932a8fd515
[platform/kernel/linux-starfive.git] / net / mptcp / protocol.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2017 - 2019, Intel Corporation.
5  */
6
7 #define pr_fmt(fmt) "MPTCP: " fmt
8
9 #include <linux/kernel.h>
10 #include <linux/module.h>
11 #include <linux/netdevice.h>
12 #include <linux/sched/signal.h>
13 #include <linux/atomic.h>
14 #include <net/sock.h>
15 #include <net/inet_common.h>
16 #include <net/inet_hashtables.h>
17 #include <net/protocol.h>
18 #include <net/tcp.h>
19 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
20 #include <net/transp_v6.h>
21 #endif
22 #include <net/mptcp.h>
23 #include "protocol.h"
24
25 #define MPTCP_SAME_STATE TCP_MAX_STATES
26
27 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
28 struct mptcp6_sock {
29         struct mptcp_sock msk;
30         struct ipv6_pinfo np;
31 };
32 #endif
33
34 /* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not
35  * completed yet or has failed, return the subflow socket.
36  * Otherwise return NULL.
37  */
38 static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
39 {
40         if (!msk->subflow || READ_ONCE(msk->can_ack))
41                 return NULL;
42
43         return msk->subflow;
44 }
45
46 static bool __mptcp_needs_tcp_fallback(const struct mptcp_sock *msk)
47 {
48         return msk->first && !sk_is_mptcp(msk->first);
49 }
50
51 static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk)
52 {
53         sock_owned_by_me((const struct sock *)msk);
54
55         if (likely(!__mptcp_needs_tcp_fallback(msk)))
56                 return NULL;
57
58         if (msk->subflow) {
59                 release_sock((struct sock *)msk);
60                 return msk->subflow;
61         }
62
63         return NULL;
64 }
65
66 static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk)
67 {
68         return !msk->first;
69 }
70
71 static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
72 {
73         struct mptcp_subflow_context *subflow;
74         struct sock *sk = (struct sock *)msk;
75         struct socket *ssock;
76         int err;
77
78         ssock = __mptcp_nmpc_socket(msk);
79         if (ssock)
80                 goto set_state;
81
82         if (!__mptcp_can_create_subflow(msk))
83                 return ERR_PTR(-EINVAL);
84
85         err = mptcp_subflow_create_socket(sk, &ssock);
86         if (err)
87                 return ERR_PTR(err);
88
89         msk->first = ssock->sk;
90         msk->subflow = ssock;
91         subflow = mptcp_subflow_ctx(ssock->sk);
92         list_add(&subflow->node, &msk->conn_list);
93         subflow->request_mptcp = 1;
94
95 set_state:
96         if (state != MPTCP_SAME_STATE)
97                 inet_sk_state_store(sk, state);
98         return ssock;
99 }
100
101 static struct sock *mptcp_subflow_get(const struct mptcp_sock *msk)
102 {
103         struct mptcp_subflow_context *subflow;
104
105         sock_owned_by_me((const struct sock *)msk);
106
107         mptcp_for_each_subflow(msk, subflow) {
108                 return mptcp_subflow_tcp_sock(subflow);
109         }
110
111         return NULL;
112 }
113
114 void mptcp_data_ready(struct sock *sk)
115 {
116         struct mptcp_sock *msk = mptcp_sk(sk);
117
118         set_bit(MPTCP_DATA_READY, &msk->flags);
119         sk->sk_data_ready(sk);
120 }
121
122 static bool mptcp_ext_cache_refill(struct mptcp_sock *msk)
123 {
124         if (!msk->cached_ext)
125                 msk->cached_ext = __skb_ext_alloc();
126
127         return !!msk->cached_ext;
128 }
129
130 static struct sock *mptcp_subflow_recv_lookup(const struct mptcp_sock *msk)
131 {
132         struct mptcp_subflow_context *subflow;
133         struct sock *sk = (struct sock *)msk;
134
135         sock_owned_by_me(sk);
136
137         mptcp_for_each_subflow(msk, subflow) {
138                 if (subflow->data_avail)
139                         return mptcp_subflow_tcp_sock(subflow);
140         }
141
142         return NULL;
143 }
144
145 static inline bool mptcp_skb_can_collapse_to(const struct mptcp_sock *msk,
146                                              const struct sk_buff *skb,
147                                              const struct mptcp_ext *mpext)
148 {
149         if (!tcp_skb_can_collapse_to(skb))
150                 return false;
151
152         /* can collapse only if MPTCP level sequence is in order */
153         return mpext && mpext->data_seq + mpext->data_len == msk->write_seq;
154 }
155
156 static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk,
157                               struct msghdr *msg, long *timeo, int *pmss_now,
158                               int *ps_goal)
159 {
160         int mss_now, avail_size, size_goal, ret;
161         struct mptcp_sock *msk = mptcp_sk(sk);
162         struct mptcp_ext *mpext = NULL;
163         struct sk_buff *skb, *tail;
164         bool can_collapse = false;
165         struct page_frag *pfrag;
166         size_t psize;
167
168         /* use the mptcp page cache so that we can easily move the data
169          * from one substream to another, but do per subflow memory accounting
170          */
171         pfrag = sk_page_frag(sk);
172         while (!sk_page_frag_refill(ssk, pfrag) ||
173                !mptcp_ext_cache_refill(msk)) {
174                 ret = sk_stream_wait_memory(ssk, timeo);
175                 if (ret)
176                         return ret;
177                 if (unlikely(__mptcp_needs_tcp_fallback(msk)))
178                         return 0;
179         }
180
181         /* compute copy limit */
182         mss_now = tcp_send_mss(ssk, &size_goal, msg->msg_flags);
183         *pmss_now = mss_now;
184         *ps_goal = size_goal;
185         avail_size = size_goal;
186         skb = tcp_write_queue_tail(ssk);
187         if (skb) {
188                 mpext = skb_ext_find(skb, SKB_EXT_MPTCP);
189
190                 /* Limit the write to the size available in the
191                  * current skb, if any, so that we create at most a new skb.
192                  * Explicitly tells TCP internals to avoid collapsing on later
193                  * queue management operation, to avoid breaking the ext <->
194                  * SSN association set here
195                  */
196                 can_collapse = (size_goal - skb->len > 0) &&
197                               mptcp_skb_can_collapse_to(msk, skb, mpext);
198                 if (!can_collapse)
199                         TCP_SKB_CB(skb)->eor = 1;
200                 else
201                         avail_size = size_goal - skb->len;
202         }
203         psize = min_t(size_t, pfrag->size - pfrag->offset, avail_size);
204
205         /* Copy to page */
206         pr_debug("left=%zu", msg_data_left(msg));
207         psize = copy_page_from_iter(pfrag->page, pfrag->offset,
208                                     min_t(size_t, msg_data_left(msg), psize),
209                                     &msg->msg_iter);
210         pr_debug("left=%zu", msg_data_left(msg));
211         if (!psize)
212                 return -EINVAL;
213
214         /* tell the TCP stack to delay the push so that we can safely
215          * access the skb after the sendpages call
216          */
217         ret = do_tcp_sendpages(ssk, pfrag->page, pfrag->offset, psize,
218                                msg->msg_flags | MSG_SENDPAGE_NOTLAST);
219         if (ret <= 0)
220                 return ret;
221         if (unlikely(ret < psize))
222                 iov_iter_revert(&msg->msg_iter, psize - ret);
223
224         /* if the tail skb extension is still the cached one, collapsing
225          * really happened. Note: we can't check for 'same skb' as the sk_buff
226          * hdr on tail can be transmitted, freed and re-allocated by the
227          * do_tcp_sendpages() call
228          */
229         tail = tcp_write_queue_tail(ssk);
230         if (mpext && tail && mpext == skb_ext_find(tail, SKB_EXT_MPTCP)) {
231                 WARN_ON_ONCE(!can_collapse);
232                 mpext->data_len += ret;
233                 goto out;
234         }
235
236         skb = tcp_write_queue_tail(ssk);
237         mpext = __skb_ext_set(skb, SKB_EXT_MPTCP, msk->cached_ext);
238         msk->cached_ext = NULL;
239
240         memset(mpext, 0, sizeof(*mpext));
241         mpext->data_seq = msk->write_seq;
242         mpext->subflow_seq = mptcp_subflow_ctx(ssk)->rel_write_seq;
243         mpext->data_len = ret;
244         mpext->use_map = 1;
245         mpext->dsn64 = 1;
246
247         pr_debug("data_seq=%llu subflow_seq=%u data_len=%u dsn64=%d",
248                  mpext->data_seq, mpext->subflow_seq, mpext->data_len,
249                  mpext->dsn64);
250
251 out:
252         pfrag->offset += ret;
253         msk->write_seq += ret;
254         mptcp_subflow_ctx(ssk)->rel_write_seq += ret;
255
256         return ret;
257 }
258
259 static void ssk_check_wmem(struct mptcp_sock *msk, struct sock *ssk)
260 {
261         struct socket *sock;
262
263         if (likely(sk_stream_is_writeable(ssk)))
264                 return;
265
266         sock = READ_ONCE(ssk->sk_socket);
267
268         if (sock) {
269                 clear_bit(MPTCP_SEND_SPACE, &msk->flags);
270                 smp_mb__after_atomic();
271                 /* set NOSPACE only after clearing SEND_SPACE flag */
272                 set_bit(SOCK_NOSPACE, &sock->flags);
273         }
274 }
275
276 static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
277 {
278         int mss_now = 0, size_goal = 0, ret = 0;
279         struct mptcp_sock *msk = mptcp_sk(sk);
280         struct socket *ssock;
281         size_t copied = 0;
282         struct sock *ssk;
283         long timeo;
284
285         if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
286                 return -EOPNOTSUPP;
287
288         lock_sock(sk);
289         ssock = __mptcp_tcp_fallback(msk);
290         if (unlikely(ssock)) {
291 fallback:
292                 pr_debug("fallback passthrough");
293                 ret = sock_sendmsg(ssock, msg);
294                 return ret >= 0 ? ret + copied : (copied ? copied : ret);
295         }
296
297         timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
298
299         ssk = mptcp_subflow_get(msk);
300         if (!ssk) {
301                 release_sock(sk);
302                 return -ENOTCONN;
303         }
304
305         pr_debug("conn_list->subflow=%p", ssk);
306
307         lock_sock(ssk);
308         while (msg_data_left(msg)) {
309                 ret = mptcp_sendmsg_frag(sk, ssk, msg, &timeo, &mss_now,
310                                          &size_goal);
311                 if (ret < 0)
312                         break;
313                 if (ret == 0 && unlikely(__mptcp_needs_tcp_fallback(msk))) {
314                         release_sock(ssk);
315                         ssock = __mptcp_tcp_fallback(msk);
316                         goto fallback;
317                 }
318
319                 copied += ret;
320         }
321
322         if (copied) {
323                 ret = copied;
324                 tcp_push(ssk, msg->msg_flags, mss_now, tcp_sk(ssk)->nonagle,
325                          size_goal);
326         }
327
328         ssk_check_wmem(msk, ssk);
329         release_sock(ssk);
330         release_sock(sk);
331         return ret;
332 }
333
334 int mptcp_read_actor(read_descriptor_t *desc, struct sk_buff *skb,
335                      unsigned int offset, size_t len)
336 {
337         struct mptcp_read_arg *arg = desc->arg.data;
338         size_t copy_len;
339
340         copy_len = min(desc->count, len);
341
342         if (likely(arg->msg)) {
343                 int err;
344
345                 err = skb_copy_datagram_msg(skb, offset, arg->msg, copy_len);
346                 if (err) {
347                         pr_debug("error path");
348                         desc->error = err;
349                         return err;
350                 }
351         } else {
352                 pr_debug("Flushing skb payload");
353         }
354
355         desc->count -= copy_len;
356
357         pr_debug("consumed %zu bytes, %zu left", copy_len, desc->count);
358         return copy_len;
359 }
360
361 static void mptcp_wait_data(struct sock *sk, long *timeo)
362 {
363         DEFINE_WAIT_FUNC(wait, woken_wake_function);
364         struct mptcp_sock *msk = mptcp_sk(sk);
365
366         add_wait_queue(sk_sleep(sk), &wait);
367         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
368
369         sk_wait_event(sk, timeo,
370                       test_and_clear_bit(MPTCP_DATA_READY, &msk->flags), &wait);
371
372         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
373         remove_wait_queue(sk_sleep(sk), &wait);
374 }
375
376 static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
377                          int nonblock, int flags, int *addr_len)
378 {
379         struct mptcp_sock *msk = mptcp_sk(sk);
380         struct mptcp_subflow_context *subflow;
381         bool more_data_avail = false;
382         struct mptcp_read_arg arg;
383         read_descriptor_t desc;
384         bool wait_data = false;
385         struct socket *ssock;
386         struct tcp_sock *tp;
387         bool done = false;
388         struct sock *ssk;
389         int copied = 0;
390         int target;
391         long timeo;
392
393         if (msg->msg_flags & ~(MSG_WAITALL | MSG_DONTWAIT))
394                 return -EOPNOTSUPP;
395
396         lock_sock(sk);
397         ssock = __mptcp_tcp_fallback(msk);
398         if (unlikely(ssock)) {
399 fallback:
400                 pr_debug("fallback-read subflow=%p",
401                          mptcp_subflow_ctx(ssock->sk));
402                 copied = sock_recvmsg(ssock, msg, flags);
403                 return copied;
404         }
405
406         arg.msg = msg;
407         desc.arg.data = &arg;
408         desc.error = 0;
409
410         timeo = sock_rcvtimeo(sk, nonblock);
411
412         len = min_t(size_t, len, INT_MAX);
413         target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
414
415         while (!done) {
416                 u32 map_remaining;
417                 int bytes_read;
418
419                 ssk = mptcp_subflow_recv_lookup(msk);
420                 pr_debug("msk=%p ssk=%p", msk, ssk);
421                 if (!ssk)
422                         goto wait_for_data;
423
424                 subflow = mptcp_subflow_ctx(ssk);
425                 tp = tcp_sk(ssk);
426
427                 lock_sock(ssk);
428                 do {
429                         /* try to read as much data as available */
430                         map_remaining = subflow->map_data_len -
431                                         mptcp_subflow_get_map_offset(subflow);
432                         desc.count = min_t(size_t, len - copied, map_remaining);
433                         pr_debug("reading %zu bytes, copied %d", desc.count,
434                                  copied);
435                         bytes_read = tcp_read_sock(ssk, &desc,
436                                                    mptcp_read_actor);
437                         if (bytes_read < 0) {
438                                 if (!copied)
439                                         copied = bytes_read;
440                                 done = true;
441                                 goto next;
442                         }
443
444                         pr_debug("msk ack_seq=%llx -> %llx", msk->ack_seq,
445                                  msk->ack_seq + bytes_read);
446                         msk->ack_seq += bytes_read;
447                         copied += bytes_read;
448                         if (copied >= len) {
449                                 done = true;
450                                 goto next;
451                         }
452                         if (tp->urg_data && tp->urg_seq == tp->copied_seq) {
453                                 pr_err("Urgent data present, cannot proceed");
454                                 done = true;
455                                 goto next;
456                         }
457 next:
458                         more_data_avail = mptcp_subflow_data_available(ssk);
459                 } while (more_data_avail && !done);
460                 release_sock(ssk);
461                 continue;
462
463 wait_for_data:
464                 more_data_avail = false;
465
466                 /* only the master socket status is relevant here. The exit
467                  * conditions mirror closely tcp_recvmsg()
468                  */
469                 if (copied >= target)
470                         break;
471
472                 if (copied) {
473                         if (sk->sk_err ||
474                             sk->sk_state == TCP_CLOSE ||
475                             (sk->sk_shutdown & RCV_SHUTDOWN) ||
476                             !timeo ||
477                             signal_pending(current))
478                                 break;
479                 } else {
480                         if (sk->sk_err) {
481                                 copied = sock_error(sk);
482                                 break;
483                         }
484
485                         if (sk->sk_shutdown & RCV_SHUTDOWN)
486                                 break;
487
488                         if (sk->sk_state == TCP_CLOSE) {
489                                 copied = -ENOTCONN;
490                                 break;
491                         }
492
493                         if (!timeo) {
494                                 copied = -EAGAIN;
495                                 break;
496                         }
497
498                         if (signal_pending(current)) {
499                                 copied = sock_intr_errno(timeo);
500                                 break;
501                         }
502                 }
503
504                 pr_debug("block timeout %ld", timeo);
505                 wait_data = true;
506                 mptcp_wait_data(sk, &timeo);
507                 if (unlikely(__mptcp_tcp_fallback(msk)))
508                         goto fallback;
509         }
510
511         if (more_data_avail) {
512                 if (!test_bit(MPTCP_DATA_READY, &msk->flags))
513                         set_bit(MPTCP_DATA_READY, &msk->flags);
514         } else if (!wait_data) {
515                 clear_bit(MPTCP_DATA_READY, &msk->flags);
516
517                 /* .. race-breaker: ssk might get new data after last
518                  * data_available() returns false.
519                  */
520                 ssk = mptcp_subflow_recv_lookup(msk);
521                 if (unlikely(ssk))
522                         set_bit(MPTCP_DATA_READY, &msk->flags);
523         }
524
525         release_sock(sk);
526         return copied;
527 }
528
529 /* subflow sockets can be either outgoing (connect) or incoming
530  * (accept).
531  *
532  * Outgoing subflows use in-kernel sockets.
533  * Incoming subflows do not have their own 'struct socket' allocated,
534  * so we need to use tcp_close() after detaching them from the mptcp
535  * parent socket.
536  */
537 static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
538                               struct mptcp_subflow_context *subflow,
539                               long timeout)
540 {
541         struct socket *sock = READ_ONCE(ssk->sk_socket);
542
543         list_del(&subflow->node);
544
545         if (sock && sock != sk->sk_socket) {
546                 /* outgoing subflow */
547                 sock_release(sock);
548         } else {
549                 /* incoming subflow */
550                 tcp_close(ssk, timeout);
551         }
552 }
553
554 static void mptcp_worker(struct work_struct *work)
555 {
556         struct mptcp_sock *msk = container_of(work, struct mptcp_sock, work);
557         struct sock *sk = &msk->sk.icsk_inet.sk;
558
559         lock_sock(sk);
560
561         release_sock(sk);
562         sock_put(sk);
563 }
564
565 static int __mptcp_init_sock(struct sock *sk)
566 {
567         struct mptcp_sock *msk = mptcp_sk(sk);
568
569         INIT_LIST_HEAD(&msk->conn_list);
570         __set_bit(MPTCP_SEND_SPACE, &msk->flags);
571         INIT_WORK(&msk->work, mptcp_worker);
572
573         msk->first = NULL;
574
575         return 0;
576 }
577
578 static int mptcp_init_sock(struct sock *sk)
579 {
580         if (!mptcp_is_enabled(sock_net(sk)))
581                 return -ENOPROTOOPT;
582
583         return __mptcp_init_sock(sk);
584 }
585
586 static void mptcp_cancel_work(struct sock *sk)
587 {
588         struct mptcp_sock *msk = mptcp_sk(sk);
589
590         if (cancel_work_sync(&msk->work))
591                 sock_put(sk);
592 }
593
594 static void mptcp_subflow_shutdown(struct sock *ssk, int how)
595 {
596         lock_sock(ssk);
597
598         switch (ssk->sk_state) {
599         case TCP_LISTEN:
600                 if (!(how & RCV_SHUTDOWN))
601                         break;
602                 /* fall through */
603         case TCP_SYN_SENT:
604                 tcp_disconnect(ssk, O_NONBLOCK);
605                 break;
606         default:
607                 ssk->sk_shutdown |= how;
608                 tcp_shutdown(ssk, how);
609                 break;
610         }
611
612         /* Wake up anyone sleeping in poll. */
613         ssk->sk_state_change(ssk);
614         release_sock(ssk);
615 }
616
617 /* Called with msk lock held, releases such lock before returning */
618 static void mptcp_close(struct sock *sk, long timeout)
619 {
620         struct mptcp_subflow_context *subflow, *tmp;
621         struct mptcp_sock *msk = mptcp_sk(sk);
622         LIST_HEAD(conn_list);
623
624         lock_sock(sk);
625
626         mptcp_token_destroy(msk->token);
627         inet_sk_state_store(sk, TCP_CLOSE);
628
629         list_splice_init(&msk->conn_list, &conn_list);
630
631         release_sock(sk);
632
633         list_for_each_entry_safe(subflow, tmp, &conn_list, node) {
634                 struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
635
636                 __mptcp_close_ssk(sk, ssk, subflow, timeout);
637         }
638
639         mptcp_cancel_work(sk);
640
641         sk_common_release(sk);
642 }
643
644 static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
645 {
646 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
647         const struct ipv6_pinfo *ssk6 = inet6_sk(ssk);
648         struct ipv6_pinfo *msk6 = inet6_sk(msk);
649
650         msk->sk_v6_daddr = ssk->sk_v6_daddr;
651         msk->sk_v6_rcv_saddr = ssk->sk_v6_rcv_saddr;
652
653         if (msk6 && ssk6) {
654                 msk6->saddr = ssk6->saddr;
655                 msk6->flow_label = ssk6->flow_label;
656         }
657 #endif
658
659         inet_sk(msk)->inet_num = inet_sk(ssk)->inet_num;
660         inet_sk(msk)->inet_dport = inet_sk(ssk)->inet_dport;
661         inet_sk(msk)->inet_sport = inet_sk(ssk)->inet_sport;
662         inet_sk(msk)->inet_daddr = inet_sk(ssk)->inet_daddr;
663         inet_sk(msk)->inet_saddr = inet_sk(ssk)->inet_saddr;
664         inet_sk(msk)->inet_rcv_saddr = inet_sk(ssk)->inet_rcv_saddr;
665 }
666
667 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
668 static struct ipv6_pinfo *mptcp_inet6_sk(const struct sock *sk)
669 {
670         unsigned int offset = sizeof(struct mptcp6_sock) - sizeof(struct ipv6_pinfo);
671
672         return (struct ipv6_pinfo *)(((u8 *)sk) + offset);
673 }
674 #endif
675
676 static struct sock *mptcp_sk_clone_lock(const struct sock *sk)
677 {
678         struct sock *nsk = sk_clone_lock(sk, GFP_ATOMIC);
679
680         if (!nsk)
681                 return NULL;
682
683 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
684         if (nsk->sk_family == AF_INET6)
685                 inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk);
686 #endif
687
688         return nsk;
689 }
690
691 static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
692                                  bool kern)
693 {
694         struct mptcp_sock *msk = mptcp_sk(sk);
695         struct socket *listener;
696         struct sock *newsk;
697
698         listener = __mptcp_nmpc_socket(msk);
699         if (WARN_ON_ONCE(!listener)) {
700                 *err = -EINVAL;
701                 return NULL;
702         }
703
704         pr_debug("msk=%p, listener=%p", msk, mptcp_subflow_ctx(listener->sk));
705         newsk = inet_csk_accept(listener->sk, flags, err, kern);
706         if (!newsk)
707                 return NULL;
708
709         pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
710
711         if (sk_is_mptcp(newsk)) {
712                 struct mptcp_subflow_context *subflow;
713                 struct sock *new_mptcp_sock;
714                 struct sock *ssk = newsk;
715                 u64 ack_seq;
716
717                 subflow = mptcp_subflow_ctx(newsk);
718                 lock_sock(sk);
719
720                 local_bh_disable();
721                 new_mptcp_sock = mptcp_sk_clone_lock(sk);
722                 if (!new_mptcp_sock) {
723                         *err = -ENOBUFS;
724                         local_bh_enable();
725                         release_sock(sk);
726                         mptcp_subflow_shutdown(newsk, SHUT_RDWR + 1);
727                         tcp_close(newsk, 0);
728                         return NULL;
729                 }
730
731                 __mptcp_init_sock(new_mptcp_sock);
732
733                 msk = mptcp_sk(new_mptcp_sock);
734                 msk->local_key = subflow->local_key;
735                 msk->token = subflow->token;
736                 msk->subflow = NULL;
737                 msk->first = newsk;
738
739                 mptcp_token_update_accept(newsk, new_mptcp_sock);
740
741                 msk->write_seq = subflow->idsn + 1;
742                 if (subflow->can_ack) {
743                         msk->can_ack = true;
744                         msk->remote_key = subflow->remote_key;
745                         mptcp_crypto_key_sha(msk->remote_key, NULL, &ack_seq);
746                         ack_seq++;
747                         msk->ack_seq = ack_seq;
748                 }
749                 newsk = new_mptcp_sock;
750                 mptcp_copy_inaddrs(newsk, ssk);
751                 list_add(&subflow->node, &msk->conn_list);
752
753                 /* will be fully established at mptcp_stream_accept()
754                  * completion.
755                  */
756                 inet_sk_state_store(new_mptcp_sock, TCP_SYN_RECV);
757                 bh_unlock_sock(new_mptcp_sock);
758                 local_bh_enable();
759                 release_sock(sk);
760
761                 /* the subflow can already receive packet, avoid racing with
762                  * the receive path and process the pending ones
763                  */
764                 lock_sock(ssk);
765                 subflow->rel_write_seq = 1;
766                 subflow->tcp_sock = ssk;
767                 subflow->conn = new_mptcp_sock;
768                 if (unlikely(!skb_queue_empty(&ssk->sk_receive_queue)))
769                         mptcp_subflow_data_available(ssk);
770                 release_sock(ssk);
771         }
772
773         return newsk;
774 }
775
776 static void mptcp_destroy(struct sock *sk)
777 {
778         struct mptcp_sock *msk = mptcp_sk(sk);
779
780         if (msk->cached_ext)
781                 __skb_ext_put(msk->cached_ext);
782 }
783
784 static int mptcp_setsockopt(struct sock *sk, int level, int optname,
785                             char __user *optval, unsigned int optlen)
786 {
787         struct mptcp_sock *msk = mptcp_sk(sk);
788         struct socket *ssock;
789
790         pr_debug("msk=%p", msk);
791
792         /* @@ the meaning of setsockopt() when the socket is connected and
793          * there are multiple subflows is not yet defined. It is up to the
794          * MPTCP-level socket to configure the subflows until the subflow
795          * is in TCP fallback, when TCP socket options are passed through
796          * to the one remaining subflow.
797          */
798         lock_sock(sk);
799         ssock = __mptcp_tcp_fallback(msk);
800         if (ssock)
801                 return tcp_setsockopt(ssock->sk, level, optname, optval,
802                                       optlen);
803
804         release_sock(sk);
805
806         return -EOPNOTSUPP;
807 }
808
809 static int mptcp_getsockopt(struct sock *sk, int level, int optname,
810                             char __user *optval, int __user *option)
811 {
812         struct mptcp_sock *msk = mptcp_sk(sk);
813         struct socket *ssock;
814
815         pr_debug("msk=%p", msk);
816
817         /* @@ the meaning of setsockopt() when the socket is connected and
818          * there are multiple subflows is not yet defined. It is up to the
819          * MPTCP-level socket to configure the subflows until the subflow
820          * is in TCP fallback, when socket options are passed through
821          * to the one remaining subflow.
822          */
823         lock_sock(sk);
824         ssock = __mptcp_tcp_fallback(msk);
825         if (ssock)
826                 return tcp_getsockopt(ssock->sk, level, optname, optval,
827                                       option);
828
829         release_sock(sk);
830
831         return -EOPNOTSUPP;
832 }
833
834 static int mptcp_get_port(struct sock *sk, unsigned short snum)
835 {
836         struct mptcp_sock *msk = mptcp_sk(sk);
837         struct socket *ssock;
838
839         ssock = __mptcp_nmpc_socket(msk);
840         pr_debug("msk=%p, subflow=%p", msk, ssock);
841         if (WARN_ON_ONCE(!ssock))
842                 return -EINVAL;
843
844         return inet_csk_get_port(ssock->sk, snum);
845 }
846
847 void mptcp_finish_connect(struct sock *ssk)
848 {
849         struct mptcp_subflow_context *subflow;
850         struct mptcp_sock *msk;
851         struct sock *sk;
852         u64 ack_seq;
853
854         subflow = mptcp_subflow_ctx(ssk);
855
856         if (!subflow->mp_capable)
857                 return;
858
859         sk = subflow->conn;
860         msk = mptcp_sk(sk);
861
862         pr_debug("msk=%p, token=%u", sk, subflow->token);
863
864         mptcp_crypto_key_sha(subflow->remote_key, NULL, &ack_seq);
865         ack_seq++;
866         subflow->map_seq = ack_seq;
867         subflow->map_subflow_seq = 1;
868         subflow->rel_write_seq = 1;
869
870         /* the socket is not connected yet, no msk/subflow ops can access/race
871          * accessing the field below
872          */
873         WRITE_ONCE(msk->remote_key, subflow->remote_key);
874         WRITE_ONCE(msk->local_key, subflow->local_key);
875         WRITE_ONCE(msk->token, subflow->token);
876         WRITE_ONCE(msk->write_seq, subflow->idsn + 1);
877         WRITE_ONCE(msk->ack_seq, ack_seq);
878         WRITE_ONCE(msk->can_ack, 1);
879 }
880
881 static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
882 {
883         write_lock_bh(&sk->sk_callback_lock);
884         rcu_assign_pointer(sk->sk_wq, &parent->wq);
885         sk_set_socket(sk, parent);
886         sk->sk_uid = SOCK_INODE(parent)->i_uid;
887         write_unlock_bh(&sk->sk_callback_lock);
888 }
889
890 static bool mptcp_memory_free(const struct sock *sk, int wake)
891 {
892         struct mptcp_sock *msk = mptcp_sk(sk);
893
894         return wake ? test_bit(MPTCP_SEND_SPACE, &msk->flags) : true;
895 }
896
897 static struct proto mptcp_prot = {
898         .name           = "MPTCP",
899         .owner          = THIS_MODULE,
900         .init           = mptcp_init_sock,
901         .close          = mptcp_close,
902         .accept         = mptcp_accept,
903         .setsockopt     = mptcp_setsockopt,
904         .getsockopt     = mptcp_getsockopt,
905         .shutdown       = tcp_shutdown,
906         .destroy        = mptcp_destroy,
907         .sendmsg        = mptcp_sendmsg,
908         .recvmsg        = mptcp_recvmsg,
909         .hash           = inet_hash,
910         .unhash         = inet_unhash,
911         .get_port       = mptcp_get_port,
912         .stream_memory_free     = mptcp_memory_free,
913         .obj_size       = sizeof(struct mptcp_sock),
914         .no_autobind    = true,
915 };
916
917 static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
918 {
919         struct mptcp_sock *msk = mptcp_sk(sock->sk);
920         struct socket *ssock;
921         int err;
922
923         lock_sock(sock->sk);
924         ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
925         if (IS_ERR(ssock)) {
926                 err = PTR_ERR(ssock);
927                 goto unlock;
928         }
929
930         err = ssock->ops->bind(ssock, uaddr, addr_len);
931         if (!err)
932                 mptcp_copy_inaddrs(sock->sk, ssock->sk);
933
934 unlock:
935         release_sock(sock->sk);
936         return err;
937 }
938
939 static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
940                                 int addr_len, int flags)
941 {
942         struct mptcp_sock *msk = mptcp_sk(sock->sk);
943         struct socket *ssock;
944         int err;
945
946         lock_sock(sock->sk);
947         ssock = __mptcp_socket_create(msk, TCP_SYN_SENT);
948         if (IS_ERR(ssock)) {
949                 err = PTR_ERR(ssock);
950                 goto unlock;
951         }
952
953 #ifdef CONFIG_TCP_MD5SIG
954         /* no MPTCP if MD5SIG is enabled on this socket or we may run out of
955          * TCP option space.
956          */
957         if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
958                 mptcp_subflow_ctx(ssock->sk)->request_mptcp = 0;
959 #endif
960
961         err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
962         inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
963         mptcp_copy_inaddrs(sock->sk, ssock->sk);
964
965 unlock:
966         release_sock(sock->sk);
967         return err;
968 }
969
970 static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
971                             int peer)
972 {
973         if (sock->sk->sk_prot == &tcp_prot) {
974                 /* we are being invoked from __sys_accept4, after
975                  * mptcp_accept() has just accepted a non-mp-capable
976                  * flow: sk is a tcp_sk, not an mptcp one.
977                  *
978                  * Hand the socket over to tcp so all further socket ops
979                  * bypass mptcp.
980                  */
981                 sock->ops = &inet_stream_ops;
982         }
983
984         return inet_getname(sock, uaddr, peer);
985 }
986
987 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
988 static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
989                             int peer)
990 {
991         if (sock->sk->sk_prot == &tcpv6_prot) {
992                 /* we are being invoked from __sys_accept4 after
993                  * mptcp_accept() has accepted a non-mp-capable
994                  * subflow: sk is a tcp_sk, not mptcp.
995                  *
996                  * Hand the socket over to tcp so all further
997                  * socket ops bypass mptcp.
998                  */
999                 sock->ops = &inet6_stream_ops;
1000         }
1001
1002         return inet6_getname(sock, uaddr, peer);
1003 }
1004 #endif
1005
1006 static int mptcp_listen(struct socket *sock, int backlog)
1007 {
1008         struct mptcp_sock *msk = mptcp_sk(sock->sk);
1009         struct socket *ssock;
1010         int err;
1011
1012         pr_debug("msk=%p", msk);
1013
1014         lock_sock(sock->sk);
1015         ssock = __mptcp_socket_create(msk, TCP_LISTEN);
1016         if (IS_ERR(ssock)) {
1017                 err = PTR_ERR(ssock);
1018                 goto unlock;
1019         }
1020
1021         err = ssock->ops->listen(ssock, backlog);
1022         inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
1023         if (!err)
1024                 mptcp_copy_inaddrs(sock->sk, ssock->sk);
1025
1026 unlock:
1027         release_sock(sock->sk);
1028         return err;
1029 }
1030
1031 static bool is_tcp_proto(const struct proto *p)
1032 {
1033 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1034         return p == &tcp_prot || p == &tcpv6_prot;
1035 #else
1036         return p == &tcp_prot;
1037 #endif
1038 }
1039
1040 static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
1041                                int flags, bool kern)
1042 {
1043         struct mptcp_sock *msk = mptcp_sk(sock->sk);
1044         struct socket *ssock;
1045         int err;
1046
1047         pr_debug("msk=%p", msk);
1048
1049         lock_sock(sock->sk);
1050         if (sock->sk->sk_state != TCP_LISTEN)
1051                 goto unlock_fail;
1052
1053         ssock = __mptcp_nmpc_socket(msk);
1054         if (!ssock)
1055                 goto unlock_fail;
1056
1057         sock_hold(ssock->sk);
1058         release_sock(sock->sk);
1059
1060         err = ssock->ops->accept(sock, newsock, flags, kern);
1061         if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
1062                 struct mptcp_sock *msk = mptcp_sk(newsock->sk);
1063                 struct mptcp_subflow_context *subflow;
1064
1065                 /* set ssk->sk_socket of accept()ed flows to mptcp socket.
1066                  * This is needed so NOSPACE flag can be set from tcp stack.
1067                  */
1068                 list_for_each_entry(subflow, &msk->conn_list, node) {
1069                         struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
1070
1071                         if (!ssk->sk_socket)
1072                                 mptcp_sock_graft(ssk, newsock);
1073                 }
1074
1075                 inet_sk_state_store(newsock->sk, TCP_ESTABLISHED);
1076         }
1077
1078         sock_put(ssock->sk);
1079         return err;
1080
1081 unlock_fail:
1082         release_sock(sock->sk);
1083         return -EINVAL;
1084 }
1085
1086 static __poll_t mptcp_poll(struct file *file, struct socket *sock,
1087                            struct poll_table_struct *wait)
1088 {
1089         struct sock *sk = sock->sk;
1090         struct mptcp_sock *msk;
1091         struct socket *ssock;
1092         __poll_t mask = 0;
1093
1094         msk = mptcp_sk(sk);
1095         lock_sock(sk);
1096         ssock = __mptcp_nmpc_socket(msk);
1097         if (ssock) {
1098                 mask = ssock->ops->poll(file, ssock, wait);
1099                 release_sock(sk);
1100                 return mask;
1101         }
1102
1103         release_sock(sk);
1104         sock_poll_wait(file, sock, wait);
1105         lock_sock(sk);
1106         ssock = __mptcp_tcp_fallback(msk);
1107         if (unlikely(ssock))
1108                 return ssock->ops->poll(file, ssock, NULL);
1109
1110         if (test_bit(MPTCP_DATA_READY, &msk->flags))
1111                 mask = EPOLLIN | EPOLLRDNORM;
1112         if (sk_stream_is_writeable(sk) &&
1113             test_bit(MPTCP_SEND_SPACE, &msk->flags))
1114                 mask |= EPOLLOUT | EPOLLWRNORM;
1115         if (sk->sk_shutdown & RCV_SHUTDOWN)
1116                 mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
1117
1118         release_sock(sk);
1119
1120         return mask;
1121 }
1122
1123 static int mptcp_shutdown(struct socket *sock, int how)
1124 {
1125         struct mptcp_sock *msk = mptcp_sk(sock->sk);
1126         struct mptcp_subflow_context *subflow;
1127         int ret = 0;
1128
1129         pr_debug("sk=%p, how=%d", msk, how);
1130
1131         lock_sock(sock->sk);
1132
1133         if (how == SHUT_WR || how == SHUT_RDWR)
1134                 inet_sk_state_store(sock->sk, TCP_FIN_WAIT1);
1135
1136         how++;
1137
1138         if ((how & ~SHUTDOWN_MASK) || !how) {
1139                 ret = -EINVAL;
1140                 goto out_unlock;
1141         }
1142
1143         if (sock->state == SS_CONNECTING) {
1144                 if ((1 << sock->sk->sk_state) &
1145                     (TCPF_SYN_SENT | TCPF_SYN_RECV | TCPF_CLOSE))
1146                         sock->state = SS_DISCONNECTING;
1147                 else
1148                         sock->state = SS_CONNECTED;
1149         }
1150
1151         mptcp_for_each_subflow(msk, subflow) {
1152                 struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
1153
1154                 mptcp_subflow_shutdown(tcp_sk, how);
1155         }
1156
1157 out_unlock:
1158         release_sock(sock->sk);
1159
1160         return ret;
1161 }
1162
1163 static const struct proto_ops mptcp_stream_ops = {
1164         .family            = PF_INET,
1165         .owner             = THIS_MODULE,
1166         .release           = inet_release,
1167         .bind              = mptcp_bind,
1168         .connect           = mptcp_stream_connect,
1169         .socketpair        = sock_no_socketpair,
1170         .accept            = mptcp_stream_accept,
1171         .getname           = mptcp_v4_getname,
1172         .poll              = mptcp_poll,
1173         .ioctl             = inet_ioctl,
1174         .gettstamp         = sock_gettstamp,
1175         .listen            = mptcp_listen,
1176         .shutdown          = mptcp_shutdown,
1177         .setsockopt        = sock_common_setsockopt,
1178         .getsockopt        = sock_common_getsockopt,
1179         .sendmsg           = inet_sendmsg,
1180         .recvmsg           = inet_recvmsg,
1181         .mmap              = sock_no_mmap,
1182         .sendpage          = inet_sendpage,
1183 #ifdef CONFIG_COMPAT
1184         .compat_setsockopt = compat_sock_common_setsockopt,
1185         .compat_getsockopt = compat_sock_common_getsockopt,
1186 #endif
1187 };
1188
1189 static struct inet_protosw mptcp_protosw = {
1190         .type           = SOCK_STREAM,
1191         .protocol       = IPPROTO_MPTCP,
1192         .prot           = &mptcp_prot,
1193         .ops            = &mptcp_stream_ops,
1194         .flags          = INET_PROTOSW_ICSK,
1195 };
1196
1197 void mptcp_proto_init(void)
1198 {
1199         mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;
1200
1201         mptcp_subflow_init();
1202
1203         if (proto_register(&mptcp_prot, 1) != 0)
1204                 panic("Failed to register MPTCP proto.\n");
1205
1206         inet_register_protosw(&mptcp_protosw);
1207 }
1208
1209 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
1210 static const struct proto_ops mptcp_v6_stream_ops = {
1211         .family            = PF_INET6,
1212         .owner             = THIS_MODULE,
1213         .release           = inet6_release,
1214         .bind              = mptcp_bind,
1215         .connect           = mptcp_stream_connect,
1216         .socketpair        = sock_no_socketpair,
1217         .accept            = mptcp_stream_accept,
1218         .getname           = mptcp_v6_getname,
1219         .poll              = mptcp_poll,
1220         .ioctl             = inet6_ioctl,
1221         .gettstamp         = sock_gettstamp,
1222         .listen            = mptcp_listen,
1223         .shutdown          = mptcp_shutdown,
1224         .setsockopt        = sock_common_setsockopt,
1225         .getsockopt        = sock_common_getsockopt,
1226         .sendmsg           = inet6_sendmsg,
1227         .recvmsg           = inet6_recvmsg,
1228         .mmap              = sock_no_mmap,
1229         .sendpage          = inet_sendpage,
1230 #ifdef CONFIG_COMPAT
1231         .compat_setsockopt = compat_sock_common_setsockopt,
1232         .compat_getsockopt = compat_sock_common_getsockopt,
1233 #endif
1234 };
1235
1236 static struct proto mptcp_v6_prot;
1237
1238 static void mptcp_v6_destroy(struct sock *sk)
1239 {
1240         mptcp_destroy(sk);
1241         inet6_destroy_sock(sk);
1242 }
1243
1244 static struct inet_protosw mptcp_v6_protosw = {
1245         .type           = SOCK_STREAM,
1246         .protocol       = IPPROTO_MPTCP,
1247         .prot           = &mptcp_v6_prot,
1248         .ops            = &mptcp_v6_stream_ops,
1249         .flags          = INET_PROTOSW_ICSK,
1250 };
1251
1252 int mptcp_proto_v6_init(void)
1253 {
1254         int err;
1255
1256         mptcp_v6_prot = mptcp_prot;
1257         strcpy(mptcp_v6_prot.name, "MPTCPv6");
1258         mptcp_v6_prot.slab = NULL;
1259         mptcp_v6_prot.destroy = mptcp_v6_destroy;
1260         mptcp_v6_prot.obj_size = sizeof(struct mptcp6_sock);
1261
1262         err = proto_register(&mptcp_v6_prot, 1);
1263         if (err)
1264                 return err;
1265
1266         err = inet6_register_protosw(&mptcp_v6_protosw);
1267         if (err)
1268                 proto_unregister(&mptcp_v6_prot);
1269
1270         return err;
1271 }
1272 #endif