Merge tag '6.5-rc1-smb3-fixes' of git://git.samba.org/sfrench/cifs-2.6
[platform/kernel/linux-starfive.git] / net / ipv6 / udp.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *      UDP over IPv6
4  *      Linux INET6 implementation
5  *
6  *      Authors:
7  *      Pedro Roque             <roque@di.fc.ul.pt>
8  *
9  *      Based on linux/ipv4/udp.c
10  *
11  *      Fixes:
12  *      Hideaki YOSHIFUJI       :       sin6_scope_id support
13  *      YOSHIFUJI Hideaki @USAGI and:   Support IPV6_V6ONLY socket option, which
14  *      Alexey Kuznetsov                allow both IPv4 and IPv6 sockets to bind
15  *                                      a single port at the same time.
16  *      Kazunori MIYAZAWA @USAGI:       change process style to use ip6_append_data
17  *      YOSHIFUJI Hideaki @USAGI:       convert /proc/net/udp6 to seq_file.
18  */
19
20 #include <linux/bpf-cgroup.h>
21 #include <linux/errno.h>
22 #include <linux/types.h>
23 #include <linux/socket.h>
24 #include <linux/sockios.h>
25 #include <linux/net.h>
26 #include <linux/in6.h>
27 #include <linux/netdevice.h>
28 #include <linux/if_arp.h>
29 #include <linux/ipv6.h>
30 #include <linux/icmpv6.h>
31 #include <linux/init.h>
32 #include <linux/module.h>
33 #include <linux/skbuff.h>
34 #include <linux/slab.h>
35 #include <linux/uaccess.h>
36 #include <linux/indirect_call_wrapper.h>
37
38 #include <net/addrconf.h>
39 #include <net/ndisc.h>
40 #include <net/protocol.h>
41 #include <net/transp_v6.h>
42 #include <net/ip6_route.h>
43 #include <net/raw.h>
44 #include <net/seg6.h>
45 #include <net/tcp_states.h>
46 #include <net/ip6_checksum.h>
47 #include <net/ip6_tunnel.h>
48 #include <trace/events/udp.h>
49 #include <net/xfrm.h>
50 #include <net/inet_hashtables.h>
51 #include <net/inet6_hashtables.h>
52 #include <net/busy_poll.h>
53 #include <net/sock_reuseport.h>
54
55 #include <linux/proc_fs.h>
56 #include <linux/seq_file.h>
57 #include <trace/events/skb.h>
58 #include "udp_impl.h"
59
60 static void udpv6_destruct_sock(struct sock *sk)
61 {
62         udp_destruct_common(sk);
63         inet6_sock_destruct(sk);
64 }
65
66 int udpv6_init_sock(struct sock *sk)
67 {
68         udp_lib_init_sock(sk);
69         sk->sk_destruct = udpv6_destruct_sock;
70         set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags);
71         return 0;
72 }
73
74 static u32 udp6_ehashfn(const struct net *net,
75                         const struct in6_addr *laddr,
76                         const u16 lport,
77                         const struct in6_addr *faddr,
78                         const __be16 fport)
79 {
80         static u32 udp6_ehash_secret __read_mostly;
81         static u32 udp_ipv6_hash_secret __read_mostly;
82
83         u32 lhash, fhash;
84
85         net_get_random_once(&udp6_ehash_secret,
86                             sizeof(udp6_ehash_secret));
87         net_get_random_once(&udp_ipv6_hash_secret,
88                             sizeof(udp_ipv6_hash_secret));
89
90         lhash = (__force u32)laddr->s6_addr32[3];
91         fhash = __ipv6_addr_jhash(faddr, udp_ipv6_hash_secret);
92
93         return __inet6_ehashfn(lhash, lport, fhash, fport,
94                                udp6_ehash_secret + net_hash_mix(net));
95 }
96
97 int udp_v6_get_port(struct sock *sk, unsigned short snum)
98 {
99         unsigned int hash2_nulladdr =
100                 ipv6_portaddr_hash(sock_net(sk), &in6addr_any, snum);
101         unsigned int hash2_partial =
102                 ipv6_portaddr_hash(sock_net(sk), &sk->sk_v6_rcv_saddr, 0);
103
104         /* precompute partial secondary hash */
105         udp_sk(sk)->udp_portaddr_hash = hash2_partial;
106         return udp_lib_get_port(sk, snum, hash2_nulladdr);
107 }
108
109 void udp_v6_rehash(struct sock *sk)
110 {
111         u16 new_hash = ipv6_portaddr_hash(sock_net(sk),
112                                           &sk->sk_v6_rcv_saddr,
113                                           inet_sk(sk)->inet_num);
114
115         udp_lib_rehash(sk, new_hash);
116 }
117
118 static int compute_score(struct sock *sk, struct net *net,
119                          const struct in6_addr *saddr, __be16 sport,
120                          const struct in6_addr *daddr, unsigned short hnum,
121                          int dif, int sdif)
122 {
123         int bound_dev_if, score;
124         struct inet_sock *inet;
125         bool dev_match;
126
127         if (!net_eq(sock_net(sk), net) ||
128             udp_sk(sk)->udp_port_hash != hnum ||
129             sk->sk_family != PF_INET6)
130                 return -1;
131
132         if (!ipv6_addr_equal(&sk->sk_v6_rcv_saddr, daddr))
133                 return -1;
134
135         score = 0;
136         inet = inet_sk(sk);
137
138         if (inet->inet_dport) {
139                 if (inet->inet_dport != sport)
140                         return -1;
141                 score++;
142         }
143
144         if (!ipv6_addr_any(&sk->sk_v6_daddr)) {
145                 if (!ipv6_addr_equal(&sk->sk_v6_daddr, saddr))
146                         return -1;
147                 score++;
148         }
149
150         bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
151         dev_match = udp_sk_bound_dev_eq(net, bound_dev_if, dif, sdif);
152         if (!dev_match)
153                 return -1;
154         if (bound_dev_if)
155                 score++;
156
157         if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
158                 score++;
159
160         return score;
161 }
162
163 static struct sock *lookup_reuseport(struct net *net, struct sock *sk,
164                                      struct sk_buff *skb,
165                                      const struct in6_addr *saddr,
166                                      __be16 sport,
167                                      const struct in6_addr *daddr,
168                                      unsigned int hnum)
169 {
170         struct sock *reuse_sk = NULL;
171         u32 hash;
172
173         if (sk->sk_reuseport && sk->sk_state != TCP_ESTABLISHED) {
174                 hash = udp6_ehashfn(net, daddr, hnum, saddr, sport);
175                 reuse_sk = reuseport_select_sock(sk, hash, skb,
176                                                  sizeof(struct udphdr));
177         }
178         return reuse_sk;
179 }
180
181 /* called with rcu_read_lock() */
182 static struct sock *udp6_lib_lookup2(struct net *net,
183                 const struct in6_addr *saddr, __be16 sport,
184                 const struct in6_addr *daddr, unsigned int hnum,
185                 int dif, int sdif, struct udp_hslot *hslot2,
186                 struct sk_buff *skb)
187 {
188         struct sock *sk, *result;
189         int score, badness;
190
191         result = NULL;
192         badness = -1;
193         udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
194                 score = compute_score(sk, net, saddr, sport,
195                                       daddr, hnum, dif, sdif);
196                 if (score > badness) {
197                         result = lookup_reuseport(net, sk, skb,
198                                                   saddr, sport, daddr, hnum);
199                         /* Fall back to scoring if group has connections */
200                         if (result && !reuseport_has_conns(sk))
201                                 return result;
202
203                         result = result ? : sk;
204                         badness = score;
205                 }
206         }
207         return result;
208 }
209
210 static inline struct sock *udp6_lookup_run_bpf(struct net *net,
211                                                struct udp_table *udptable,
212                                                struct sk_buff *skb,
213                                                const struct in6_addr *saddr,
214                                                __be16 sport,
215                                                const struct in6_addr *daddr,
216                                                u16 hnum, const int dif)
217 {
218         struct sock *sk, *reuse_sk;
219         bool no_reuseport;
220
221         if (udptable != net->ipv4.udp_table)
222                 return NULL; /* only UDP is supported */
223
224         no_reuseport = bpf_sk_lookup_run_v6(net, IPPROTO_UDP, saddr, sport,
225                                             daddr, hnum, dif, &sk);
226         if (no_reuseport || IS_ERR_OR_NULL(sk))
227                 return sk;
228
229         reuse_sk = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum);
230         if (reuse_sk)
231                 sk = reuse_sk;
232         return sk;
233 }
234
235 /* rcu_read_lock() must be held */
236 struct sock *__udp6_lib_lookup(struct net *net,
237                                const struct in6_addr *saddr, __be16 sport,
238                                const struct in6_addr *daddr, __be16 dport,
239                                int dif, int sdif, struct udp_table *udptable,
240                                struct sk_buff *skb)
241 {
242         unsigned short hnum = ntohs(dport);
243         unsigned int hash2, slot2;
244         struct udp_hslot *hslot2;
245         struct sock *result, *sk;
246
247         hash2 = ipv6_portaddr_hash(net, daddr, hnum);
248         slot2 = hash2 & udptable->mask;
249         hslot2 = &udptable->hash2[slot2];
250
251         /* Lookup connected or non-wildcard sockets */
252         result = udp6_lib_lookup2(net, saddr, sport,
253                                   daddr, hnum, dif, sdif,
254                                   hslot2, skb);
255         if (!IS_ERR_OR_NULL(result) && result->sk_state == TCP_ESTABLISHED)
256                 goto done;
257
258         /* Lookup redirect from BPF */
259         if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
260                 sk = udp6_lookup_run_bpf(net, udptable, skb,
261                                          saddr, sport, daddr, hnum, dif);
262                 if (sk) {
263                         result = sk;
264                         goto done;
265                 }
266         }
267
268         /* Got non-wildcard socket or error on first lookup */
269         if (result)
270                 goto done;
271
272         /* Lookup wildcard sockets */
273         hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum);
274         slot2 = hash2 & udptable->mask;
275         hslot2 = &udptable->hash2[slot2];
276
277         result = udp6_lib_lookup2(net, saddr, sport,
278                                   &in6addr_any, hnum, dif, sdif,
279                                   hslot2, skb);
280 done:
281         if (IS_ERR(result))
282                 return NULL;
283         return result;
284 }
285 EXPORT_SYMBOL_GPL(__udp6_lib_lookup);
286
287 static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb,
288                                           __be16 sport, __be16 dport,
289                                           struct udp_table *udptable)
290 {
291         const struct ipv6hdr *iph = ipv6_hdr(skb);
292
293         return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport,
294                                  &iph->daddr, dport, inet6_iif(skb),
295                                  inet6_sdif(skb), udptable, skb);
296 }
297
298 struct sock *udp6_lib_lookup_skb(const struct sk_buff *skb,
299                                  __be16 sport, __be16 dport)
300 {
301         const struct ipv6hdr *iph = ipv6_hdr(skb);
302         struct net *net = dev_net(skb->dev);
303
304         return __udp6_lib_lookup(net, &iph->saddr, sport,
305                                  &iph->daddr, dport, inet6_iif(skb),
306                                  inet6_sdif(skb), net->ipv4.udp_table, NULL);
307 }
308
309 /* Must be called under rcu_read_lock().
310  * Does increment socket refcount.
311  */
312 #if IS_ENABLED(CONFIG_NF_TPROXY_IPV6) || IS_ENABLED(CONFIG_NF_SOCKET_IPV6)
313 struct sock *udp6_lib_lookup(struct net *net, const struct in6_addr *saddr, __be16 sport,
314                              const struct in6_addr *daddr, __be16 dport, int dif)
315 {
316         struct sock *sk;
317
318         sk =  __udp6_lib_lookup(net, saddr, sport, daddr, dport,
319                                 dif, 0, net->ipv4.udp_table, NULL);
320         if (sk && !refcount_inc_not_zero(&sk->sk_refcnt))
321                 sk = NULL;
322         return sk;
323 }
324 EXPORT_SYMBOL_GPL(udp6_lib_lookup);
325 #endif
326
327 /* do not use the scratch area len for jumbogram: their length execeeds the
328  * scratch area space; note that the IP6CB flags is still in the first
329  * cacheline, so checking for jumbograms is cheap
330  */
331 static int udp6_skb_len(struct sk_buff *skb)
332 {
333         return unlikely(inet6_is_jumbogram(skb)) ? skb->len : udp_skb_len(skb);
334 }
335
336 /*
337  *      This should be easy, if there is something there we
338  *      return it, otherwise we block.
339  */
340
341 int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
342                   int flags, int *addr_len)
343 {
344         struct ipv6_pinfo *np = inet6_sk(sk);
345         struct inet_sock *inet = inet_sk(sk);
346         struct sk_buff *skb;
347         unsigned int ulen, copied;
348         int off, err, peeking = flags & MSG_PEEK;
349         int is_udplite = IS_UDPLITE(sk);
350         struct udp_mib __percpu *mib;
351         bool checksum_valid = false;
352         int is_udp4;
353
354         if (flags & MSG_ERRQUEUE)
355                 return ipv6_recv_error(sk, msg, len, addr_len);
356
357         if (np->rxpmtu && np->rxopt.bits.rxpmtu)
358                 return ipv6_recv_rxpmtu(sk, msg, len, addr_len);
359
360 try_again:
361         off = sk_peek_offset(sk, flags);
362         skb = __skb_recv_udp(sk, flags, &off, &err);
363         if (!skb)
364                 return err;
365
366         ulen = udp6_skb_len(skb);
367         copied = len;
368         if (copied > ulen - off)
369                 copied = ulen - off;
370         else if (copied < ulen)
371                 msg->msg_flags |= MSG_TRUNC;
372
373         is_udp4 = (skb->protocol == htons(ETH_P_IP));
374         mib = __UDPX_MIB(sk, is_udp4);
375
376         /*
377          * If checksum is needed at all, try to do it while copying the
378          * data.  If the data is truncated, or if we only want a partial
379          * coverage checksum (UDP-Lite), do it before the copy.
380          */
381
382         if (copied < ulen || peeking ||
383             (is_udplite && UDP_SKB_CB(skb)->partial_cov)) {
384                 checksum_valid = udp_skb_csum_unnecessary(skb) ||
385                                 !__udp_lib_checksum_complete(skb);
386                 if (!checksum_valid)
387                         goto csum_copy_err;
388         }
389
390         if (checksum_valid || udp_skb_csum_unnecessary(skb)) {
391                 if (udp_skb_is_linear(skb))
392                         err = copy_linear_skb(skb, copied, off, &msg->msg_iter);
393                 else
394                         err = skb_copy_datagram_msg(skb, off, msg, copied);
395         } else {
396                 err = skb_copy_and_csum_datagram_msg(skb, off, msg);
397                 if (err == -EINVAL)
398                         goto csum_copy_err;
399         }
400         if (unlikely(err)) {
401                 if (!peeking) {
402                         atomic_inc(&sk->sk_drops);
403                         SNMP_INC_STATS(mib, UDP_MIB_INERRORS);
404                 }
405                 kfree_skb(skb);
406                 return err;
407         }
408         if (!peeking)
409                 SNMP_INC_STATS(mib, UDP_MIB_INDATAGRAMS);
410
411         sock_recv_cmsgs(msg, sk, skb);
412
413         /* Copy the address. */
414         if (msg->msg_name) {
415                 DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, msg->msg_name);
416                 sin6->sin6_family = AF_INET6;
417                 sin6->sin6_port = udp_hdr(skb)->source;
418                 sin6->sin6_flowinfo = 0;
419
420                 if (is_udp4) {
421                         ipv6_addr_set_v4mapped(ip_hdr(skb)->saddr,
422                                                &sin6->sin6_addr);
423                         sin6->sin6_scope_id = 0;
424                 } else {
425                         sin6->sin6_addr = ipv6_hdr(skb)->saddr;
426                         sin6->sin6_scope_id =
427                                 ipv6_iface_scope_id(&sin6->sin6_addr,
428                                                     inet6_iif(skb));
429                 }
430                 *addr_len = sizeof(*sin6);
431
432                 BPF_CGROUP_RUN_PROG_UDP6_RECVMSG_LOCK(sk,
433                                                       (struct sockaddr *)sin6);
434         }
435
436         if (udp_sk(sk)->gro_enabled)
437                 udp_cmsg_recv(msg, sk, skb);
438
439         if (np->rxopt.all)
440                 ip6_datagram_recv_common_ctl(sk, msg, skb);
441
442         if (is_udp4) {
443                 if (inet->cmsg_flags)
444                         ip_cmsg_recv_offset(msg, sk, skb,
445                                             sizeof(struct udphdr), off);
446         } else {
447                 if (np->rxopt.all)
448                         ip6_datagram_recv_specific_ctl(sk, msg, skb);
449         }
450
451         err = copied;
452         if (flags & MSG_TRUNC)
453                 err = ulen;
454
455         skb_consume_udp(sk, skb, peeking ? -err : err);
456         return err;
457
458 csum_copy_err:
459         if (!__sk_queue_drop_skb(sk, &udp_sk(sk)->reader_queue, skb, flags,
460                                  udp_skb_destructor)) {
461                 SNMP_INC_STATS(mib, UDP_MIB_CSUMERRORS);
462                 SNMP_INC_STATS(mib, UDP_MIB_INERRORS);
463         }
464         kfree_skb(skb);
465
466         /* starting over for a new packet, but check if we need to yield */
467         cond_resched();
468         msg->msg_flags &= ~MSG_TRUNC;
469         goto try_again;
470 }
471
472 DEFINE_STATIC_KEY_FALSE(udpv6_encap_needed_key);
473 void udpv6_encap_enable(void)
474 {
475         static_branch_inc(&udpv6_encap_needed_key);
476 }
477 EXPORT_SYMBOL(udpv6_encap_enable);
478
479 /* Handler for tunnels with arbitrary destination ports: no socket lookup, go
480  * through error handlers in encapsulations looking for a match.
481  */
482 static int __udp6_lib_err_encap_no_sk(struct sk_buff *skb,
483                                       struct inet6_skb_parm *opt,
484                                       u8 type, u8 code, int offset, __be32 info)
485 {
486         int i;
487
488         for (i = 0; i < MAX_IPTUN_ENCAP_OPS; i++) {
489                 int (*handler)(struct sk_buff *skb, struct inet6_skb_parm *opt,
490                                u8 type, u8 code, int offset, __be32 info);
491                 const struct ip6_tnl_encap_ops *encap;
492
493                 encap = rcu_dereference(ip6tun_encaps[i]);
494                 if (!encap)
495                         continue;
496                 handler = encap->err_handler;
497                 if (handler && !handler(skb, opt, type, code, offset, info))
498                         return 0;
499         }
500
501         return -ENOENT;
502 }
503
504 /* Try to match ICMP errors to UDP tunnels by looking up a socket without
505  * reversing source and destination port: this will match tunnels that force the
506  * same destination port on both endpoints (e.g. VXLAN, GENEVE). Note that
507  * lwtunnels might actually break this assumption by being configured with
508  * different destination ports on endpoints, in this case we won't be able to
509  * trace ICMP messages back to them.
510  *
511  * If this doesn't match any socket, probe tunnels with arbitrary destination
512  * ports (e.g. FoU, GUE): there, the receiving socket is useless, as the port
513  * we've sent packets to won't necessarily match the local destination port.
514  *
515  * Then ask the tunnel implementation to match the error against a valid
516  * association.
517  *
518  * Return an error if we can't find a match, the socket if we need further
519  * processing, zero otherwise.
520  */
521 static struct sock *__udp6_lib_err_encap(struct net *net,
522                                          const struct ipv6hdr *hdr, int offset,
523                                          struct udphdr *uh,
524                                          struct udp_table *udptable,
525                                          struct sock *sk,
526                                          struct sk_buff *skb,
527                                          struct inet6_skb_parm *opt,
528                                          u8 type, u8 code, __be32 info)
529 {
530         int (*lookup)(struct sock *sk, struct sk_buff *skb);
531         int network_offset, transport_offset;
532         struct udp_sock *up;
533
534         network_offset = skb_network_offset(skb);
535         transport_offset = skb_transport_offset(skb);
536
537         /* Network header needs to point to the outer IPv6 header inside ICMP */
538         skb_reset_network_header(skb);
539
540         /* Transport header needs to point to the UDP header */
541         skb_set_transport_header(skb, offset);
542
543         if (sk) {
544                 up = udp_sk(sk);
545
546                 lookup = READ_ONCE(up->encap_err_lookup);
547                 if (lookup && lookup(sk, skb))
548                         sk = NULL;
549
550                 goto out;
551         }
552
553         sk = __udp6_lib_lookup(net, &hdr->daddr, uh->source,
554                                &hdr->saddr, uh->dest,
555                                inet6_iif(skb), 0, udptable, skb);
556         if (sk) {
557                 up = udp_sk(sk);
558
559                 lookup = READ_ONCE(up->encap_err_lookup);
560                 if (!lookup || lookup(sk, skb))
561                         sk = NULL;
562         }
563
564 out:
565         if (!sk) {
566                 sk = ERR_PTR(__udp6_lib_err_encap_no_sk(skb, opt, type, code,
567                                                         offset, info));
568         }
569
570         skb_set_transport_header(skb, transport_offset);
571         skb_set_network_header(skb, network_offset);
572
573         return sk;
574 }
575
576 int __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
577                    u8 type, u8 code, int offset, __be32 info,
578                    struct udp_table *udptable)
579 {
580         struct ipv6_pinfo *np;
581         const struct ipv6hdr *hdr = (const struct ipv6hdr *)skb->data;
582         const struct in6_addr *saddr = &hdr->saddr;
583         const struct in6_addr *daddr = seg6_get_daddr(skb, opt) ? : &hdr->daddr;
584         struct udphdr *uh = (struct udphdr *)(skb->data+offset);
585         bool tunnel = false;
586         struct sock *sk;
587         int harderr;
588         int err;
589         struct net *net = dev_net(skb->dev);
590
591         sk = __udp6_lib_lookup(net, daddr, uh->dest, saddr, uh->source,
592                                inet6_iif(skb), inet6_sdif(skb), udptable, NULL);
593
594         if (!sk || udp_sk(sk)->encap_type) {
595                 /* No socket for error: try tunnels before discarding */
596                 if (static_branch_unlikely(&udpv6_encap_needed_key)) {
597                         sk = __udp6_lib_err_encap(net, hdr, offset, uh,
598                                                   udptable, sk, skb,
599                                                   opt, type, code, info);
600                         if (!sk)
601                                 return 0;
602                 } else
603                         sk = ERR_PTR(-ENOENT);
604
605                 if (IS_ERR(sk)) {
606                         __ICMP6_INC_STATS(net, __in6_dev_get(skb->dev),
607                                           ICMP6_MIB_INERRORS);
608                         return PTR_ERR(sk);
609                 }
610
611                 tunnel = true;
612         }
613
614         harderr = icmpv6_err_convert(type, code, &err);
615         np = inet6_sk(sk);
616
617         if (type == ICMPV6_PKT_TOOBIG) {
618                 if (!ip6_sk_accept_pmtu(sk))
619                         goto out;
620                 ip6_sk_update_pmtu(skb, sk, info);
621                 if (np->pmtudisc != IPV6_PMTUDISC_DONT)
622                         harderr = 1;
623         }
624         if (type == NDISC_REDIRECT) {
625                 if (tunnel) {
626                         ip6_redirect(skb, sock_net(sk), inet6_iif(skb),
627                                      sk->sk_mark, sk->sk_uid);
628                 } else {
629                         ip6_sk_redirect(skb, sk);
630                 }
631                 goto out;
632         }
633
634         /* Tunnels don't have an application socket: don't pass errors back */
635         if (tunnel) {
636                 if (udp_sk(sk)->encap_err_rcv)
637                         udp_sk(sk)->encap_err_rcv(sk, skb, err, uh->dest,
638                                                   ntohl(info), (u8 *)(uh+1));
639                 goto out;
640         }
641
642         if (!np->recverr) {
643                 if (!harderr || sk->sk_state != TCP_ESTABLISHED)
644                         goto out;
645         } else {
646                 ipv6_icmp_error(sk, skb, err, uh->dest, ntohl(info), (u8 *)(uh+1));
647         }
648
649         sk->sk_err = err;
650         sk_error_report(sk);
651 out:
652         return 0;
653 }
654
655 static int __udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
656 {
657         int rc;
658
659         if (!ipv6_addr_any(&sk->sk_v6_daddr)) {
660                 sock_rps_save_rxhash(sk, skb);
661                 sk_mark_napi_id(sk, skb);
662                 sk_incoming_cpu_update(sk);
663         } else {
664                 sk_mark_napi_id_once(sk, skb);
665         }
666
667         rc = __udp_enqueue_schedule_skb(sk, skb);
668         if (rc < 0) {
669                 int is_udplite = IS_UDPLITE(sk);
670                 enum skb_drop_reason drop_reason;
671
672                 /* Note that an ENOMEM error is charged twice */
673                 if (rc == -ENOMEM) {
674                         UDP6_INC_STATS(sock_net(sk),
675                                          UDP_MIB_RCVBUFERRORS, is_udplite);
676                         drop_reason = SKB_DROP_REASON_SOCKET_RCVBUFF;
677                 } else {
678                         UDP6_INC_STATS(sock_net(sk),
679                                        UDP_MIB_MEMERRORS, is_udplite);
680                         drop_reason = SKB_DROP_REASON_PROTO_MEM;
681                 }
682                 UDP6_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, is_udplite);
683                 kfree_skb_reason(skb, drop_reason);
684                 trace_udp_fail_queue_rcv_skb(rc, sk);
685                 return -1;
686         }
687
688         return 0;
689 }
690
691 static __inline__ int udpv6_err(struct sk_buff *skb,
692                                 struct inet6_skb_parm *opt, u8 type,
693                                 u8 code, int offset, __be32 info)
694 {
695         return __udp6_lib_err(skb, opt, type, code, offset, info,
696                               dev_net(skb->dev)->ipv4.udp_table);
697 }
698
699 static int udpv6_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb)
700 {
701         enum skb_drop_reason drop_reason = SKB_DROP_REASON_NOT_SPECIFIED;
702         struct udp_sock *up = udp_sk(sk);
703         int is_udplite = IS_UDPLITE(sk);
704
705         if (!xfrm6_policy_check(sk, XFRM_POLICY_IN, skb)) {
706                 drop_reason = SKB_DROP_REASON_XFRM_POLICY;
707                 goto drop;
708         }
709         nf_reset_ct(skb);
710
711         if (static_branch_unlikely(&udpv6_encap_needed_key) && up->encap_type) {
712                 int (*encap_rcv)(struct sock *sk, struct sk_buff *skb);
713
714                 /*
715                  * This is an encapsulation socket so pass the skb to
716                  * the socket's udp_encap_rcv() hook. Otherwise, just
717                  * fall through and pass this up the UDP socket.
718                  * up->encap_rcv() returns the following value:
719                  * =0 if skb was successfully passed to the encap
720                  *    handler or was discarded by it.
721                  * >0 if skb should be passed on to UDP.
722                  * <0 if skb should be resubmitted as proto -N
723                  */
724
725                 /* if we're overly short, let UDP handle it */
726                 encap_rcv = READ_ONCE(up->encap_rcv);
727                 if (encap_rcv) {
728                         int ret;
729
730                         /* Verify checksum before giving to encap */
731                         if (udp_lib_checksum_complete(skb))
732                                 goto csum_error;
733
734                         ret = encap_rcv(sk, skb);
735                         if (ret <= 0) {
736                                 __UDP6_INC_STATS(sock_net(sk),
737                                                  UDP_MIB_INDATAGRAMS,
738                                                  is_udplite);
739                                 return -ret;
740                         }
741                 }
742
743                 /* FALLTHROUGH -- it's a UDP Packet */
744         }
745
746         /*
747          * UDP-Lite specific tests, ignored on UDP sockets (see net/ipv4/udp.c).
748          */
749         if ((up->pcflag & UDPLITE_RECV_CC)  &&  UDP_SKB_CB(skb)->partial_cov) {
750
751                 if (up->pcrlen == 0) {          /* full coverage was set  */
752                         net_dbg_ratelimited("UDPLITE6: partial coverage %d while full coverage %d requested\n",
753                                             UDP_SKB_CB(skb)->cscov, skb->len);
754                         goto drop;
755                 }
756                 if (UDP_SKB_CB(skb)->cscov  <  up->pcrlen) {
757                         net_dbg_ratelimited("UDPLITE6: coverage %d too small, need min %d\n",
758                                             UDP_SKB_CB(skb)->cscov, up->pcrlen);
759                         goto drop;
760                 }
761         }
762
763         prefetch(&sk->sk_rmem_alloc);
764         if (rcu_access_pointer(sk->sk_filter) &&
765             udp_lib_checksum_complete(skb))
766                 goto csum_error;
767
768         if (sk_filter_trim_cap(sk, skb, sizeof(struct udphdr))) {
769                 drop_reason = SKB_DROP_REASON_SOCKET_FILTER;
770                 goto drop;
771         }
772
773         udp_csum_pull_header(skb);
774
775         skb_dst_drop(skb);
776
777         return __udpv6_queue_rcv_skb(sk, skb);
778
779 csum_error:
780         drop_reason = SKB_DROP_REASON_UDP_CSUM;
781         __UDP6_INC_STATS(sock_net(sk), UDP_MIB_CSUMERRORS, is_udplite);
782 drop:
783         __UDP6_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, is_udplite);
784         atomic_inc(&sk->sk_drops);
785         kfree_skb_reason(skb, drop_reason);
786         return -1;
787 }
788
789 static int udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
790 {
791         struct sk_buff *next, *segs;
792         int ret;
793
794         if (likely(!udp_unexpected_gso(sk, skb)))
795                 return udpv6_queue_rcv_one_skb(sk, skb);
796
797         __skb_push(skb, -skb_mac_offset(skb));
798         segs = udp_rcv_segment(sk, skb, false);
799         skb_list_walk_safe(segs, skb, next) {
800                 __skb_pull(skb, skb_transport_offset(skb));
801
802                 udp_post_segment_fix_csum(skb);
803                 ret = udpv6_queue_rcv_one_skb(sk, skb);
804                 if (ret > 0)
805                         ip6_protocol_deliver_rcu(dev_net(skb->dev), skb, ret,
806                                                  true);
807         }
808         return 0;
809 }
810
811 static bool __udp_v6_is_mcast_sock(struct net *net, const struct sock *sk,
812                                    __be16 loc_port, const struct in6_addr *loc_addr,
813                                    __be16 rmt_port, const struct in6_addr *rmt_addr,
814                                    int dif, int sdif, unsigned short hnum)
815 {
816         const struct inet_sock *inet = inet_sk(sk);
817
818         if (!net_eq(sock_net(sk), net))
819                 return false;
820
821         if (udp_sk(sk)->udp_port_hash != hnum ||
822             sk->sk_family != PF_INET6 ||
823             (inet->inet_dport && inet->inet_dport != rmt_port) ||
824             (!ipv6_addr_any(&sk->sk_v6_daddr) &&
825                     !ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr)) ||
826             !udp_sk_bound_dev_eq(net, READ_ONCE(sk->sk_bound_dev_if), dif, sdif) ||
827             (!ipv6_addr_any(&sk->sk_v6_rcv_saddr) &&
828                     !ipv6_addr_equal(&sk->sk_v6_rcv_saddr, loc_addr)))
829                 return false;
830         if (!inet6_mc_check(sk, loc_addr, rmt_addr))
831                 return false;
832         return true;
833 }
834
835 static void udp6_csum_zero_error(struct sk_buff *skb)
836 {
837         /* RFC 2460 section 8.1 says that we SHOULD log
838          * this error. Well, it is reasonable.
839          */
840         net_dbg_ratelimited("IPv6: udp checksum is 0 for [%pI6c]:%u->[%pI6c]:%u\n",
841                             &ipv6_hdr(skb)->saddr, ntohs(udp_hdr(skb)->source),
842                             &ipv6_hdr(skb)->daddr, ntohs(udp_hdr(skb)->dest));
843 }
844
845 /*
846  * Note: called only from the BH handler context,
847  * so we don't need to lock the hashes.
848  */
849 static int __udp6_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
850                 const struct in6_addr *saddr, const struct in6_addr *daddr,
851                 struct udp_table *udptable, int proto)
852 {
853         struct sock *sk, *first = NULL;
854         const struct udphdr *uh = udp_hdr(skb);
855         unsigned short hnum = ntohs(uh->dest);
856         struct udp_hslot *hslot = udp_hashslot(udptable, net, hnum);
857         unsigned int offset = offsetof(typeof(*sk), sk_node);
858         unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10);
859         int dif = inet6_iif(skb);
860         int sdif = inet6_sdif(skb);
861         struct hlist_node *node;
862         struct sk_buff *nskb;
863
864         if (use_hash2) {
865                 hash2_any = ipv6_portaddr_hash(net, &in6addr_any, hnum) &
866                             udptable->mask;
867                 hash2 = ipv6_portaddr_hash(net, daddr, hnum) & udptable->mask;
868 start_lookup:
869                 hslot = &udptable->hash2[hash2];
870                 offset = offsetof(typeof(*sk), __sk_common.skc_portaddr_node);
871         }
872
873         sk_for_each_entry_offset_rcu(sk, node, &hslot->head, offset) {
874                 if (!__udp_v6_is_mcast_sock(net, sk, uh->dest, daddr,
875                                             uh->source, saddr, dif, sdif,
876                                             hnum))
877                         continue;
878                 /* If zero checksum and no_check is not on for
879                  * the socket then skip it.
880                  */
881                 if (!uh->check && !udp_sk(sk)->no_check6_rx)
882                         continue;
883                 if (!first) {
884                         first = sk;
885                         continue;
886                 }
887                 nskb = skb_clone(skb, GFP_ATOMIC);
888                 if (unlikely(!nskb)) {
889                         atomic_inc(&sk->sk_drops);
890                         __UDP6_INC_STATS(net, UDP_MIB_RCVBUFERRORS,
891                                          IS_UDPLITE(sk));
892                         __UDP6_INC_STATS(net, UDP_MIB_INERRORS,
893                                          IS_UDPLITE(sk));
894                         continue;
895                 }
896
897                 if (udpv6_queue_rcv_skb(sk, nskb) > 0)
898                         consume_skb(nskb);
899         }
900
901         /* Also lookup *:port if we are using hash2 and haven't done so yet. */
902         if (use_hash2 && hash2 != hash2_any) {
903                 hash2 = hash2_any;
904                 goto start_lookup;
905         }
906
907         if (first) {
908                 if (udpv6_queue_rcv_skb(first, skb) > 0)
909                         consume_skb(skb);
910         } else {
911                 kfree_skb(skb);
912                 __UDP6_INC_STATS(net, UDP_MIB_IGNOREDMULTI,
913                                  proto == IPPROTO_UDPLITE);
914         }
915         return 0;
916 }
917
918 static void udp6_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst)
919 {
920         if (udp_sk_rx_dst_set(sk, dst)) {
921                 const struct rt6_info *rt = (const struct rt6_info *)dst;
922
923                 sk->sk_rx_dst_cookie = rt6_get_cookie(rt);
924         }
925 }
926
927 /* wrapper for udp_queue_rcv_skb tacking care of csum conversion and
928  * return code conversion for ip layer consumption
929  */
930 static int udp6_unicast_rcv_skb(struct sock *sk, struct sk_buff *skb,
931                                 struct udphdr *uh)
932 {
933         int ret;
934
935         if (inet_get_convert_csum(sk) && uh->check && !IS_UDPLITE(sk))
936                 skb_checksum_try_convert(skb, IPPROTO_UDP, ip6_compute_pseudo);
937
938         ret = udpv6_queue_rcv_skb(sk, skb);
939
940         /* a return value > 0 means to resubmit the input */
941         if (ret > 0)
942                 return ret;
943         return 0;
944 }
945
946 int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
947                    int proto)
948 {
949         enum skb_drop_reason reason = SKB_DROP_REASON_NOT_SPECIFIED;
950         const struct in6_addr *saddr, *daddr;
951         struct net *net = dev_net(skb->dev);
952         struct udphdr *uh;
953         struct sock *sk;
954         bool refcounted;
955         u32 ulen = 0;
956
957         if (!pskb_may_pull(skb, sizeof(struct udphdr)))
958                 goto discard;
959
960         saddr = &ipv6_hdr(skb)->saddr;
961         daddr = &ipv6_hdr(skb)->daddr;
962         uh = udp_hdr(skb);
963
964         ulen = ntohs(uh->len);
965         if (ulen > skb->len)
966                 goto short_packet;
967
968         if (proto == IPPROTO_UDP) {
969                 /* UDP validates ulen. */
970
971                 /* Check for jumbo payload */
972                 if (ulen == 0)
973                         ulen = skb->len;
974
975                 if (ulen < sizeof(*uh))
976                         goto short_packet;
977
978                 if (ulen < skb->len) {
979                         if (pskb_trim_rcsum(skb, ulen))
980                                 goto short_packet;
981                         saddr = &ipv6_hdr(skb)->saddr;
982                         daddr = &ipv6_hdr(skb)->daddr;
983                         uh = udp_hdr(skb);
984                 }
985         }
986
987         if (udp6_csum_init(skb, uh, proto))
988                 goto csum_error;
989
990         /* Check if the socket is already available, e.g. due to early demux */
991         sk = skb_steal_sock(skb, &refcounted);
992         if (sk) {
993                 struct dst_entry *dst = skb_dst(skb);
994                 int ret;
995
996                 if (unlikely(rcu_dereference(sk->sk_rx_dst) != dst))
997                         udp6_sk_rx_dst_set(sk, dst);
998
999                 if (!uh->check && !udp_sk(sk)->no_check6_rx) {
1000                         if (refcounted)
1001                                 sock_put(sk);
1002                         goto report_csum_error;
1003                 }
1004
1005                 ret = udp6_unicast_rcv_skb(sk, skb, uh);
1006                 if (refcounted)
1007                         sock_put(sk);
1008                 return ret;
1009         }
1010
1011         /*
1012          *      Multicast receive code
1013          */
1014         if (ipv6_addr_is_multicast(daddr))
1015                 return __udp6_lib_mcast_deliver(net, skb,
1016                                 saddr, daddr, udptable, proto);
1017
1018         /* Unicast */
1019         sk = __udp6_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
1020         if (sk) {
1021                 if (!uh->check && !udp_sk(sk)->no_check6_rx)
1022                         goto report_csum_error;
1023                 return udp6_unicast_rcv_skb(sk, skb, uh);
1024         }
1025
1026         reason = SKB_DROP_REASON_NO_SOCKET;
1027
1028         if (!uh->check)
1029                 goto report_csum_error;
1030
1031         if (!xfrm6_policy_check(NULL, XFRM_POLICY_IN, skb))
1032                 goto discard;
1033         nf_reset_ct(skb);
1034
1035         if (udp_lib_checksum_complete(skb))
1036                 goto csum_error;
1037
1038         __UDP6_INC_STATS(net, UDP_MIB_NOPORTS, proto == IPPROTO_UDPLITE);
1039         icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_PORT_UNREACH, 0);
1040
1041         kfree_skb_reason(skb, reason);
1042         return 0;
1043
1044 short_packet:
1045         if (reason == SKB_DROP_REASON_NOT_SPECIFIED)
1046                 reason = SKB_DROP_REASON_PKT_TOO_SMALL;
1047         net_dbg_ratelimited("UDP%sv6: short packet: From [%pI6c]:%u %d/%d to [%pI6c]:%u\n",
1048                             proto == IPPROTO_UDPLITE ? "-Lite" : "",
1049                             saddr, ntohs(uh->source),
1050                             ulen, skb->len,
1051                             daddr, ntohs(uh->dest));
1052         goto discard;
1053
1054 report_csum_error:
1055         udp6_csum_zero_error(skb);
1056 csum_error:
1057         if (reason == SKB_DROP_REASON_NOT_SPECIFIED)
1058                 reason = SKB_DROP_REASON_UDP_CSUM;
1059         __UDP6_INC_STATS(net, UDP_MIB_CSUMERRORS, proto == IPPROTO_UDPLITE);
1060 discard:
1061         __UDP6_INC_STATS(net, UDP_MIB_INERRORS, proto == IPPROTO_UDPLITE);
1062         kfree_skb_reason(skb, reason);
1063         return 0;
1064 }
1065
1066
1067 static struct sock *__udp6_lib_demux_lookup(struct net *net,
1068                         __be16 loc_port, const struct in6_addr *loc_addr,
1069                         __be16 rmt_port, const struct in6_addr *rmt_addr,
1070                         int dif, int sdif)
1071 {
1072         struct udp_table *udptable = net->ipv4.udp_table;
1073         unsigned short hnum = ntohs(loc_port);
1074         unsigned int hash2, slot2;
1075         struct udp_hslot *hslot2;
1076         __portpair ports;
1077         struct sock *sk;
1078
1079         hash2 = ipv6_portaddr_hash(net, loc_addr, hnum);
1080         slot2 = hash2 & udptable->mask;
1081         hslot2 = &udptable->hash2[slot2];
1082         ports = INET_COMBINED_PORTS(rmt_port, hnum);
1083
1084         udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
1085                 if (sk->sk_state == TCP_ESTABLISHED &&
1086                     inet6_match(net, sk, rmt_addr, loc_addr, ports, dif, sdif))
1087                         return sk;
1088                 /* Only check first socket in chain */
1089                 break;
1090         }
1091         return NULL;
1092 }
1093
1094 void udp_v6_early_demux(struct sk_buff *skb)
1095 {
1096         struct net *net = dev_net(skb->dev);
1097         const struct udphdr *uh;
1098         struct sock *sk;
1099         struct dst_entry *dst;
1100         int dif = skb->dev->ifindex;
1101         int sdif = inet6_sdif(skb);
1102
1103         if (!pskb_may_pull(skb, skb_transport_offset(skb) +
1104             sizeof(struct udphdr)))
1105                 return;
1106
1107         uh = udp_hdr(skb);
1108
1109         if (skb->pkt_type == PACKET_HOST)
1110                 sk = __udp6_lib_demux_lookup(net, uh->dest,
1111                                              &ipv6_hdr(skb)->daddr,
1112                                              uh->source, &ipv6_hdr(skb)->saddr,
1113                                              dif, sdif);
1114         else
1115                 return;
1116
1117         if (!sk || !refcount_inc_not_zero(&sk->sk_refcnt))
1118                 return;
1119
1120         skb->sk = sk;
1121         skb->destructor = sock_efree;
1122         dst = rcu_dereference(sk->sk_rx_dst);
1123
1124         if (dst)
1125                 dst = dst_check(dst, sk->sk_rx_dst_cookie);
1126         if (dst) {
1127                 /* set noref for now.
1128                  * any place which wants to hold dst has to call
1129                  * dst_hold_safe()
1130                  */
1131                 skb_dst_set_noref(skb, dst);
1132         }
1133 }
1134
1135 INDIRECT_CALLABLE_SCOPE int udpv6_rcv(struct sk_buff *skb)
1136 {
1137         return __udp6_lib_rcv(skb, dev_net(skb->dev)->ipv4.udp_table, IPPROTO_UDP);
1138 }
1139
1140 /*
1141  * Throw away all pending data and cancel the corking. Socket is locked.
1142  */
1143 static void udp_v6_flush_pending_frames(struct sock *sk)
1144 {
1145         struct udp_sock *up = udp_sk(sk);
1146
1147         if (up->pending == AF_INET)
1148                 udp_flush_pending_frames(sk);
1149         else if (up->pending) {
1150                 up->len = 0;
1151                 up->pending = 0;
1152                 ip6_flush_pending_frames(sk);
1153         }
1154 }
1155
1156 static int udpv6_pre_connect(struct sock *sk, struct sockaddr *uaddr,
1157                              int addr_len)
1158 {
1159         if (addr_len < offsetofend(struct sockaddr, sa_family))
1160                 return -EINVAL;
1161         /* The following checks are replicated from __ip6_datagram_connect()
1162          * and intended to prevent BPF program called below from accessing
1163          * bytes that are out of the bound specified by user in addr_len.
1164          */
1165         if (uaddr->sa_family == AF_INET) {
1166                 if (ipv6_only_sock(sk))
1167                         return -EAFNOSUPPORT;
1168                 return udp_pre_connect(sk, uaddr, addr_len);
1169         }
1170
1171         if (addr_len < SIN6_LEN_RFC2133)
1172                 return -EINVAL;
1173
1174         return BPF_CGROUP_RUN_PROG_INET6_CONNECT_LOCK(sk, uaddr);
1175 }
1176
1177 /**
1178  *      udp6_hwcsum_outgoing  -  handle outgoing HW checksumming
1179  *      @sk:    socket we are sending on
1180  *      @skb:   sk_buff containing the filled-in UDP header
1181  *              (checksum field must be zeroed out)
1182  *      @saddr: source address
1183  *      @daddr: destination address
1184  *      @len:   length of packet
1185  */
1186 static void udp6_hwcsum_outgoing(struct sock *sk, struct sk_buff *skb,
1187                                  const struct in6_addr *saddr,
1188                                  const struct in6_addr *daddr, int len)
1189 {
1190         unsigned int offset;
1191         struct udphdr *uh = udp_hdr(skb);
1192         struct sk_buff *frags = skb_shinfo(skb)->frag_list;
1193         __wsum csum = 0;
1194
1195         if (!frags) {
1196                 /* Only one fragment on the socket.  */
1197                 skb->csum_start = skb_transport_header(skb) - skb->head;
1198                 skb->csum_offset = offsetof(struct udphdr, check);
1199                 uh->check = ~csum_ipv6_magic(saddr, daddr, len, IPPROTO_UDP, 0);
1200         } else {
1201                 /*
1202                  * HW-checksum won't work as there are two or more
1203                  * fragments on the socket so that all csums of sk_buffs
1204                  * should be together
1205                  */
1206                 offset = skb_transport_offset(skb);
1207                 skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
1208                 csum = skb->csum;
1209
1210                 skb->ip_summed = CHECKSUM_NONE;
1211
1212                 do {
1213                         csum = csum_add(csum, frags->csum);
1214                 } while ((frags = frags->next));
1215
1216                 uh->check = csum_ipv6_magic(saddr, daddr, len, IPPROTO_UDP,
1217                                             csum);
1218                 if (uh->check == 0)
1219                         uh->check = CSUM_MANGLED_0;
1220         }
1221 }
1222
1223 /*
1224  *      Sending
1225  */
1226
1227 static int udp_v6_send_skb(struct sk_buff *skb, struct flowi6 *fl6,
1228                            struct inet_cork *cork)
1229 {
1230         struct sock *sk = skb->sk;
1231         struct udphdr *uh;
1232         int err = 0;
1233         int is_udplite = IS_UDPLITE(sk);
1234         __wsum csum = 0;
1235         int offset = skb_transport_offset(skb);
1236         int len = skb->len - offset;
1237         int datalen = len - sizeof(*uh);
1238
1239         /*
1240          * Create a UDP header
1241          */
1242         uh = udp_hdr(skb);
1243         uh->source = fl6->fl6_sport;
1244         uh->dest = fl6->fl6_dport;
1245         uh->len = htons(len);
1246         uh->check = 0;
1247
1248         if (cork->gso_size) {
1249                 const int hlen = skb_network_header_len(skb) +
1250                                  sizeof(struct udphdr);
1251
1252                 if (hlen + cork->gso_size > cork->fragsize) {
1253                         kfree_skb(skb);
1254                         return -EINVAL;
1255                 }
1256                 if (datalen > cork->gso_size * UDP_MAX_SEGMENTS) {
1257                         kfree_skb(skb);
1258                         return -EINVAL;
1259                 }
1260                 if (udp_sk(sk)->no_check6_tx) {
1261                         kfree_skb(skb);
1262                         return -EINVAL;
1263                 }
1264                 if (skb->ip_summed != CHECKSUM_PARTIAL || is_udplite ||
1265                     dst_xfrm(skb_dst(skb))) {
1266                         kfree_skb(skb);
1267                         return -EIO;
1268                 }
1269
1270                 if (datalen > cork->gso_size) {
1271                         skb_shinfo(skb)->gso_size = cork->gso_size;
1272                         skb_shinfo(skb)->gso_type = SKB_GSO_UDP_L4;
1273                         skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(datalen,
1274                                                                  cork->gso_size);
1275                 }
1276                 goto csum_partial;
1277         }
1278
1279         if (is_udplite)
1280                 csum = udplite_csum(skb);
1281         else if (udp_sk(sk)->no_check6_tx) {   /* UDP csum disabled */
1282                 skb->ip_summed = CHECKSUM_NONE;
1283                 goto send;
1284         } else if (skb->ip_summed == CHECKSUM_PARTIAL) { /* UDP hardware csum */
1285 csum_partial:
1286                 udp6_hwcsum_outgoing(sk, skb, &fl6->saddr, &fl6->daddr, len);
1287                 goto send;
1288         } else
1289                 csum = udp_csum(skb);
1290
1291         /* add protocol-dependent pseudo-header */
1292         uh->check = csum_ipv6_magic(&fl6->saddr, &fl6->daddr,
1293                                     len, fl6->flowi6_proto, csum);
1294         if (uh->check == 0)
1295                 uh->check = CSUM_MANGLED_0;
1296
1297 send:
1298         err = ip6_send_skb(skb);
1299         if (err) {
1300                 if (err == -ENOBUFS && !inet6_sk(sk)->recverr) {
1301                         UDP6_INC_STATS(sock_net(sk),
1302                                        UDP_MIB_SNDBUFERRORS, is_udplite);
1303                         err = 0;
1304                 }
1305         } else {
1306                 UDP6_INC_STATS(sock_net(sk),
1307                                UDP_MIB_OUTDATAGRAMS, is_udplite);
1308         }
1309         return err;
1310 }
1311
1312 static int udp_v6_push_pending_frames(struct sock *sk)
1313 {
1314         struct sk_buff *skb;
1315         struct udp_sock  *up = udp_sk(sk);
1316         int err = 0;
1317
1318         if (up->pending == AF_INET)
1319                 return udp_push_pending_frames(sk);
1320
1321         skb = ip6_finish_skb(sk);
1322         if (!skb)
1323                 goto out;
1324
1325         err = udp_v6_send_skb(skb, &inet_sk(sk)->cork.fl.u.ip6,
1326                               &inet_sk(sk)->cork.base);
1327 out:
1328         up->len = 0;
1329         up->pending = 0;
1330         return err;
1331 }
1332
1333 int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
1334 {
1335         struct ipv6_txoptions opt_space;
1336         struct udp_sock *up = udp_sk(sk);
1337         struct inet_sock *inet = inet_sk(sk);
1338         struct ipv6_pinfo *np = inet6_sk(sk);
1339         DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, msg->msg_name);
1340         struct in6_addr *daddr, *final_p, final;
1341         struct ipv6_txoptions *opt = NULL;
1342         struct ipv6_txoptions *opt_to_free = NULL;
1343         struct ip6_flowlabel *flowlabel = NULL;
1344         struct inet_cork_full cork;
1345         struct flowi6 *fl6 = &cork.fl.u.ip6;
1346         struct dst_entry *dst;
1347         struct ipcm6_cookie ipc6;
1348         int addr_len = msg->msg_namelen;
1349         bool connected = false;
1350         int ulen = len;
1351         int corkreq = READ_ONCE(up->corkflag) || msg->msg_flags&MSG_MORE;
1352         int err;
1353         int is_udplite = IS_UDPLITE(sk);
1354         int (*getfrag)(void *, char *, int, int, int, struct sk_buff *);
1355
1356         ipcm6_init(&ipc6);
1357         ipc6.gso_size = READ_ONCE(up->gso_size);
1358         ipc6.sockc.tsflags = sk->sk_tsflags;
1359         ipc6.sockc.mark = sk->sk_mark;
1360
1361         /* destination address check */
1362         if (sin6) {
1363                 if (addr_len < offsetof(struct sockaddr, sa_data))
1364                         return -EINVAL;
1365
1366                 switch (sin6->sin6_family) {
1367                 case AF_INET6:
1368                         if (addr_len < SIN6_LEN_RFC2133)
1369                                 return -EINVAL;
1370                         daddr = &sin6->sin6_addr;
1371                         if (ipv6_addr_any(daddr) &&
1372                             ipv6_addr_v4mapped(&np->saddr))
1373                                 ipv6_addr_set_v4mapped(htonl(INADDR_LOOPBACK),
1374                                                        daddr);
1375                         break;
1376                 case AF_INET:
1377                         goto do_udp_sendmsg;
1378                 case AF_UNSPEC:
1379                         msg->msg_name = sin6 = NULL;
1380                         msg->msg_namelen = addr_len = 0;
1381                         daddr = NULL;
1382                         break;
1383                 default:
1384                         return -EINVAL;
1385                 }
1386         } else if (!up->pending) {
1387                 if (sk->sk_state != TCP_ESTABLISHED)
1388                         return -EDESTADDRREQ;
1389                 daddr = &sk->sk_v6_daddr;
1390         } else
1391                 daddr = NULL;
1392
1393         if (daddr) {
1394                 if (ipv6_addr_v4mapped(daddr)) {
1395                         struct sockaddr_in sin;
1396                         sin.sin_family = AF_INET;
1397                         sin.sin_port = sin6 ? sin6->sin6_port : inet->inet_dport;
1398                         sin.sin_addr.s_addr = daddr->s6_addr32[3];
1399                         msg->msg_name = &sin;
1400                         msg->msg_namelen = sizeof(sin);
1401 do_udp_sendmsg:
1402                         err = ipv6_only_sock(sk) ?
1403                                 -ENETUNREACH : udp_sendmsg(sk, msg, len);
1404                         msg->msg_name = sin6;
1405                         msg->msg_namelen = addr_len;
1406                         return err;
1407                 }
1408         }
1409
1410         /* Rough check on arithmetic overflow,
1411            better check is made in ip6_append_data().
1412            */
1413         if (len > INT_MAX - sizeof(struct udphdr))
1414                 return -EMSGSIZE;
1415
1416         getfrag  =  is_udplite ?  udplite_getfrag : ip_generic_getfrag;
1417         if (up->pending) {
1418                 if (up->pending == AF_INET)
1419                         return udp_sendmsg(sk, msg, len);
1420                 /*
1421                  * There are pending frames.
1422                  * The socket lock must be held while it's corked.
1423                  */
1424                 lock_sock(sk);
1425                 if (likely(up->pending)) {
1426                         if (unlikely(up->pending != AF_INET6)) {
1427                                 release_sock(sk);
1428                                 return -EAFNOSUPPORT;
1429                         }
1430                         dst = NULL;
1431                         goto do_append_data;
1432                 }
1433                 release_sock(sk);
1434         }
1435         ulen += sizeof(struct udphdr);
1436
1437         memset(fl6, 0, sizeof(*fl6));
1438
1439         if (sin6) {
1440                 if (sin6->sin6_port == 0)
1441                         return -EINVAL;
1442
1443                 fl6->fl6_dport = sin6->sin6_port;
1444                 daddr = &sin6->sin6_addr;
1445
1446                 if (np->sndflow) {
1447                         fl6->flowlabel = sin6->sin6_flowinfo&IPV6_FLOWINFO_MASK;
1448                         if (fl6->flowlabel & IPV6_FLOWLABEL_MASK) {
1449                                 flowlabel = fl6_sock_lookup(sk, fl6->flowlabel);
1450                                 if (IS_ERR(flowlabel))
1451                                         return -EINVAL;
1452                         }
1453                 }
1454
1455                 /*
1456                  * Otherwise it will be difficult to maintain
1457                  * sk->sk_dst_cache.
1458                  */
1459                 if (sk->sk_state == TCP_ESTABLISHED &&
1460                     ipv6_addr_equal(daddr, &sk->sk_v6_daddr))
1461                         daddr = &sk->sk_v6_daddr;
1462
1463                 if (addr_len >= sizeof(struct sockaddr_in6) &&
1464                     sin6->sin6_scope_id &&
1465                     __ipv6_addr_needs_scope_id(__ipv6_addr_type(daddr)))
1466                         fl6->flowi6_oif = sin6->sin6_scope_id;
1467         } else {
1468                 if (sk->sk_state != TCP_ESTABLISHED)
1469                         return -EDESTADDRREQ;
1470
1471                 fl6->fl6_dport = inet->inet_dport;
1472                 daddr = &sk->sk_v6_daddr;
1473                 fl6->flowlabel = np->flow_label;
1474                 connected = true;
1475         }
1476
1477         if (!fl6->flowi6_oif)
1478                 fl6->flowi6_oif = READ_ONCE(sk->sk_bound_dev_if);
1479
1480         if (!fl6->flowi6_oif)
1481                 fl6->flowi6_oif = np->sticky_pktinfo.ipi6_ifindex;
1482
1483         fl6->flowi6_uid = sk->sk_uid;
1484
1485         if (msg->msg_controllen) {
1486                 opt = &opt_space;
1487                 memset(opt, 0, sizeof(struct ipv6_txoptions));
1488                 opt->tot_len = sizeof(*opt);
1489                 ipc6.opt = opt;
1490
1491                 err = udp_cmsg_send(sk, msg, &ipc6.gso_size);
1492                 if (err > 0)
1493                         err = ip6_datagram_send_ctl(sock_net(sk), sk, msg, fl6,
1494                                                     &ipc6);
1495                 if (err < 0) {
1496                         fl6_sock_release(flowlabel);
1497                         return err;
1498                 }
1499                 if ((fl6->flowlabel&IPV6_FLOWLABEL_MASK) && !flowlabel) {
1500                         flowlabel = fl6_sock_lookup(sk, fl6->flowlabel);
1501                         if (IS_ERR(flowlabel))
1502                                 return -EINVAL;
1503                 }
1504                 if (!(opt->opt_nflen|opt->opt_flen))
1505                         opt = NULL;
1506                 connected = false;
1507         }
1508         if (!opt) {
1509                 opt = txopt_get(np);
1510                 opt_to_free = opt;
1511         }
1512         if (flowlabel)
1513                 opt = fl6_merge_options(&opt_space, flowlabel, opt);
1514         opt = ipv6_fixup_options(&opt_space, opt);
1515         ipc6.opt = opt;
1516
1517         fl6->flowi6_proto = sk->sk_protocol;
1518         fl6->flowi6_mark = ipc6.sockc.mark;
1519         fl6->daddr = *daddr;
1520         if (ipv6_addr_any(&fl6->saddr) && !ipv6_addr_any(&np->saddr))
1521                 fl6->saddr = np->saddr;
1522         fl6->fl6_sport = inet->inet_sport;
1523
1524         if (cgroup_bpf_enabled(CGROUP_UDP6_SENDMSG) && !connected) {
1525                 err = BPF_CGROUP_RUN_PROG_UDP6_SENDMSG_LOCK(sk,
1526                                            (struct sockaddr *)sin6,
1527                                            &fl6->saddr);
1528                 if (err)
1529                         goto out_no_dst;
1530                 if (sin6) {
1531                         if (ipv6_addr_v4mapped(&sin6->sin6_addr)) {
1532                                 /* BPF program rewrote IPv6-only by IPv4-mapped
1533                                  * IPv6. It's currently unsupported.
1534                                  */
1535                                 err = -ENOTSUPP;
1536                                 goto out_no_dst;
1537                         }
1538                         if (sin6->sin6_port == 0) {
1539                                 /* BPF program set invalid port. Reject it. */
1540                                 err = -EINVAL;
1541                                 goto out_no_dst;
1542                         }
1543                         fl6->fl6_dport = sin6->sin6_port;
1544                         fl6->daddr = sin6->sin6_addr;
1545                 }
1546         }
1547
1548         if (ipv6_addr_any(&fl6->daddr))
1549                 fl6->daddr.s6_addr[15] = 0x1; /* :: means loopback (BSD'ism) */
1550
1551         final_p = fl6_update_dst(fl6, opt, &final);
1552         if (final_p)
1553                 connected = false;
1554
1555         if (!fl6->flowi6_oif && ipv6_addr_is_multicast(&fl6->daddr)) {
1556                 fl6->flowi6_oif = np->mcast_oif;
1557                 connected = false;
1558         } else if (!fl6->flowi6_oif)
1559                 fl6->flowi6_oif = np->ucast_oif;
1560
1561         security_sk_classify_flow(sk, flowi6_to_flowi_common(fl6));
1562
1563         if (ipc6.tclass < 0)
1564                 ipc6.tclass = np->tclass;
1565
1566         fl6->flowlabel = ip6_make_flowinfo(ipc6.tclass, fl6->flowlabel);
1567
1568         dst = ip6_sk_dst_lookup_flow(sk, fl6, final_p, connected);
1569         if (IS_ERR(dst)) {
1570                 err = PTR_ERR(dst);
1571                 dst = NULL;
1572                 goto out;
1573         }
1574
1575         if (ipc6.hlimit < 0)
1576                 ipc6.hlimit = ip6_sk_dst_hoplimit(np, fl6, dst);
1577
1578         if (msg->msg_flags&MSG_CONFIRM)
1579                 goto do_confirm;
1580 back_from_confirm:
1581
1582         /* Lockless fast path for the non-corking case */
1583         if (!corkreq) {
1584                 struct sk_buff *skb;
1585
1586                 skb = ip6_make_skb(sk, getfrag, msg, ulen,
1587                                    sizeof(struct udphdr), &ipc6,
1588                                    (struct rt6_info *)dst,
1589                                    msg->msg_flags, &cork);
1590                 err = PTR_ERR(skb);
1591                 if (!IS_ERR_OR_NULL(skb))
1592                         err = udp_v6_send_skb(skb, fl6, &cork.base);
1593                 /* ip6_make_skb steals dst reference */
1594                 goto out_no_dst;
1595         }
1596
1597         lock_sock(sk);
1598         if (unlikely(up->pending)) {
1599                 /* The socket is already corked while preparing it. */
1600                 /* ... which is an evident application bug. --ANK */
1601                 release_sock(sk);
1602
1603                 net_dbg_ratelimited("udp cork app bug 2\n");
1604                 err = -EINVAL;
1605                 goto out;
1606         }
1607
1608         up->pending = AF_INET6;
1609
1610 do_append_data:
1611         if (ipc6.dontfrag < 0)
1612                 ipc6.dontfrag = np->dontfrag;
1613         up->len += ulen;
1614         err = ip6_append_data(sk, getfrag, msg, ulen, sizeof(struct udphdr),
1615                               &ipc6, fl6, (struct rt6_info *)dst,
1616                               corkreq ? msg->msg_flags|MSG_MORE : msg->msg_flags);
1617         if (err)
1618                 udp_v6_flush_pending_frames(sk);
1619         else if (!corkreq)
1620                 err = udp_v6_push_pending_frames(sk);
1621         else if (unlikely(skb_queue_empty(&sk->sk_write_queue)))
1622                 up->pending = 0;
1623
1624         if (err > 0)
1625                 err = np->recverr ? net_xmit_errno(err) : 0;
1626         release_sock(sk);
1627
1628 out:
1629         dst_release(dst);
1630 out_no_dst:
1631         fl6_sock_release(flowlabel);
1632         txopt_put(opt_to_free);
1633         if (!err)
1634                 return len;
1635         /*
1636          * ENOBUFS = no kernel mem, SOCK_NOSPACE = no sndbuf space.  Reporting
1637          * ENOBUFS might not be good (it's not tunable per se), but otherwise
1638          * we don't have a good statistic (IpOutDiscards but it can be too many
1639          * things).  We could add another new stat but at least for now that
1640          * seems like overkill.
1641          */
1642         if (err == -ENOBUFS || test_bit(SOCK_NOSPACE, &sk->sk_socket->flags)) {
1643                 UDP6_INC_STATS(sock_net(sk),
1644                                UDP_MIB_SNDBUFERRORS, is_udplite);
1645         }
1646         return err;
1647
1648 do_confirm:
1649         if (msg->msg_flags & MSG_PROBE)
1650                 dst_confirm_neigh(dst, &fl6->daddr);
1651         if (!(msg->msg_flags&MSG_PROBE) || len)
1652                 goto back_from_confirm;
1653         err = 0;
1654         goto out;
1655 }
1656 EXPORT_SYMBOL(udpv6_sendmsg);
1657
1658 static void udpv6_splice_eof(struct socket *sock)
1659 {
1660         struct sock *sk = sock->sk;
1661         struct udp_sock *up = udp_sk(sk);
1662
1663         if (!up->pending || READ_ONCE(up->corkflag))
1664                 return;
1665
1666         lock_sock(sk);
1667         if (up->pending && !READ_ONCE(up->corkflag))
1668                 udp_v6_push_pending_frames(sk);
1669         release_sock(sk);
1670 }
1671
1672 void udpv6_destroy_sock(struct sock *sk)
1673 {
1674         struct udp_sock *up = udp_sk(sk);
1675         lock_sock(sk);
1676
1677         /* protects from races with udp_abort() */
1678         sock_set_flag(sk, SOCK_DEAD);
1679         udp_v6_flush_pending_frames(sk);
1680         release_sock(sk);
1681
1682         if (static_branch_unlikely(&udpv6_encap_needed_key)) {
1683                 if (up->encap_type) {
1684                         void (*encap_destroy)(struct sock *sk);
1685                         encap_destroy = READ_ONCE(up->encap_destroy);
1686                         if (encap_destroy)
1687                                 encap_destroy(sk);
1688                 }
1689                 if (up->encap_enabled) {
1690                         static_branch_dec(&udpv6_encap_needed_key);
1691                         udp_encap_disable();
1692                 }
1693         }
1694 }
1695
1696 /*
1697  *      Socket option code for UDP
1698  */
1699 int udpv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval,
1700                      unsigned int optlen)
1701 {
1702         if (level == SOL_UDP  ||  level == SOL_UDPLITE || level == SOL_SOCKET)
1703                 return udp_lib_setsockopt(sk, level, optname,
1704                                           optval, optlen,
1705                                           udp_v6_push_pending_frames);
1706         return ipv6_setsockopt(sk, level, optname, optval, optlen);
1707 }
1708
1709 int udpv6_getsockopt(struct sock *sk, int level, int optname,
1710                      char __user *optval, int __user *optlen)
1711 {
1712         if (level == SOL_UDP  ||  level == SOL_UDPLITE)
1713                 return udp_lib_getsockopt(sk, level, optname, optval, optlen);
1714         return ipv6_getsockopt(sk, level, optname, optval, optlen);
1715 }
1716
1717 static const struct inet6_protocol udpv6_protocol = {
1718         .handler        =       udpv6_rcv,
1719         .err_handler    =       udpv6_err,
1720         .flags          =       INET6_PROTO_NOPOLICY|INET6_PROTO_FINAL,
1721 };
1722
1723 /* ------------------------------------------------------------------------ */
1724 #ifdef CONFIG_PROC_FS
1725 int udp6_seq_show(struct seq_file *seq, void *v)
1726 {
1727         if (v == SEQ_START_TOKEN) {
1728                 seq_puts(seq, IPV6_SEQ_DGRAM_HEADER);
1729         } else {
1730                 int bucket = ((struct udp_iter_state *)seq->private)->bucket;
1731                 const struct inet_sock *inet = inet_sk((const struct sock *)v);
1732                 __u16 srcp = ntohs(inet->inet_sport);
1733                 __u16 destp = ntohs(inet->inet_dport);
1734                 __ip6_dgram_sock_seq_show(seq, v, srcp, destp,
1735                                           udp_rqueue_get(v), bucket);
1736         }
1737         return 0;
1738 }
1739
1740 const struct seq_operations udp6_seq_ops = {
1741         .start          = udp_seq_start,
1742         .next           = udp_seq_next,
1743         .stop           = udp_seq_stop,
1744         .show           = udp6_seq_show,
1745 };
1746 EXPORT_SYMBOL(udp6_seq_ops);
1747
1748 static struct udp_seq_afinfo udp6_seq_afinfo = {
1749         .family         = AF_INET6,
1750         .udp_table      = NULL,
1751 };
1752
1753 int __net_init udp6_proc_init(struct net *net)
1754 {
1755         if (!proc_create_net_data("udp6", 0444, net->proc_net, &udp6_seq_ops,
1756                         sizeof(struct udp_iter_state), &udp6_seq_afinfo))
1757                 return -ENOMEM;
1758         return 0;
1759 }
1760
1761 void udp6_proc_exit(struct net *net)
1762 {
1763         remove_proc_entry("udp6", net->proc_net);
1764 }
1765 #endif /* CONFIG_PROC_FS */
1766
1767 /* ------------------------------------------------------------------------ */
1768
1769 struct proto udpv6_prot = {
1770         .name                   = "UDPv6",
1771         .owner                  = THIS_MODULE,
1772         .close                  = udp_lib_close,
1773         .pre_connect            = udpv6_pre_connect,
1774         .connect                = ip6_datagram_connect,
1775         .disconnect             = udp_disconnect,
1776         .ioctl                  = udp_ioctl,
1777         .init                   = udpv6_init_sock,
1778         .destroy                = udpv6_destroy_sock,
1779         .setsockopt             = udpv6_setsockopt,
1780         .getsockopt             = udpv6_getsockopt,
1781         .sendmsg                = udpv6_sendmsg,
1782         .recvmsg                = udpv6_recvmsg,
1783         .splice_eof             = udpv6_splice_eof,
1784         .release_cb             = ip6_datagram_release_cb,
1785         .hash                   = udp_lib_hash,
1786         .unhash                 = udp_lib_unhash,
1787         .rehash                 = udp_v6_rehash,
1788         .get_port               = udp_v6_get_port,
1789         .put_port               = udp_lib_unhash,
1790 #ifdef CONFIG_BPF_SYSCALL
1791         .psock_update_sk_prot   = udp_bpf_update_proto,
1792 #endif
1793
1794         .memory_allocated       = &udp_memory_allocated,
1795         .per_cpu_fw_alloc       = &udp_memory_per_cpu_fw_alloc,
1796
1797         .sysctl_mem             = sysctl_udp_mem,
1798         .sysctl_wmem_offset     = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
1799         .sysctl_rmem_offset     = offsetof(struct net, ipv4.sysctl_udp_rmem_min),
1800         .obj_size               = sizeof(struct udp6_sock),
1801         .h.udp_table            = NULL,
1802         .diag_destroy           = udp_abort,
1803 };
1804
1805 static struct inet_protosw udpv6_protosw = {
1806         .type =      SOCK_DGRAM,
1807         .protocol =  IPPROTO_UDP,
1808         .prot =      &udpv6_prot,
1809         .ops =       &inet6_dgram_ops,
1810         .flags =     INET_PROTOSW_PERMANENT,
1811 };
1812
1813 int __init udpv6_init(void)
1814 {
1815         int ret;
1816
1817         ret = inet6_add_protocol(&udpv6_protocol, IPPROTO_UDP);
1818         if (ret)
1819                 goto out;
1820
1821         ret = inet6_register_protosw(&udpv6_protosw);
1822         if (ret)
1823                 goto out_udpv6_protocol;
1824 out:
1825         return ret;
1826
1827 out_udpv6_protocol:
1828         inet6_del_protocol(&udpv6_protocol, IPPROTO_UDP);
1829         goto out;
1830 }
1831
1832 void udpv6_exit(void)
1833 {
1834         inet6_unregister_protosw(&udpv6_protosw);
1835         inet6_del_protocol(&udpv6_protocol, IPPROTO_UDP);
1836 }