coresight: ultrasoc-smb: Fix sleep while close preempt in enable_smb
[platform/kernel/linux-starfive.git] / drivers / net / amt.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Copyright (c) 2021 Taehee Yoo <ap420073@gmail.com> */
3
4 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5
6 #include <linux/module.h>
7 #include <linux/skbuff.h>
8 #include <linux/udp.h>
9 #include <linux/jhash.h>
10 #include <linux/if_tunnel.h>
11 #include <linux/net.h>
12 #include <linux/igmp.h>
13 #include <linux/workqueue.h>
14 #include <net/sch_generic.h>
15 #include <net/net_namespace.h>
16 #include <net/ip.h>
17 #include <net/udp.h>
18 #include <net/udp_tunnel.h>
19 #include <net/icmp.h>
20 #include <net/mld.h>
21 #include <net/amt.h>
22 #include <uapi/linux/amt.h>
23 #include <linux/security.h>
24 #include <net/gro_cells.h>
25 #include <net/ipv6.h>
26 #include <net/if_inet6.h>
27 #include <net/ndisc.h>
28 #include <net/addrconf.h>
29 #include <net/ip6_route.h>
30 #include <net/inet_common.h>
31 #include <net/ip6_checksum.h>
32
/* Workqueue used by this file for delayed work: the periodic source GC
 * work and the per-source timers are queued on it in the code visible here.
 */
static struct workqueue_struct *amt_wq;

/* Source nodes unlinked by amt_destroy_source() wait on this list until
 * __amt_source_gc_work() hands them to kfree_rcu().
 */
static HLIST_HEAD(source_gc_list);
/* Lock for source_gc_list. It is not statically initialised here;
 * presumably spin_lock_init() runs at module init — TODO confirm.
 */
static spinlock_t source_gc_lock;
/* Periodic GC work; amt_source_gc_work() re-arms it every
 * AMT_GC_INTERVAL milliseconds.
 */
static struct delayed_work source_gc_wq;
/* Printable names for enum amt_status, indexed by the status value
 * (see amt_update_gw_status()). The table is only ever read, so both the
 * pointers and the strings are const and can live in read-only data.
 */
static const char * const status_str[] = {
        "AMT_STATUS_INIT",
        "AMT_STATUS_SENT_DISCOVERY",
        "AMT_STATUS_RECEIVED_DISCOVERY",
        "AMT_STATUS_SENT_ADVERTISEMENT",
        "AMT_STATUS_RECEIVED_ADVERTISEMENT",
        "AMT_STATUS_SENT_REQUEST",
        "AMT_STATUS_RECEIVED_REQUEST",
        "AMT_STATUS_SENT_QUERY",
        "AMT_STATUS_RECEIVED_QUERY",
        "AMT_STATUS_SENT_UPDATE",
        "AMT_STATUS_RECEIVED_UPDATE",
};
52
/* Printable names for AMT message types, indexed by the type number.
 * Index 0 is a placeholder because type 0 is not defined. The table is
 * read-only, hence const pointers to const strings.
 */
static const char * const type_str[] = {
        "", /* Type 0 is not defined */
        "AMT_MSG_DISCOVERY",
        "AMT_MSG_ADVERTISEMENT",
        "AMT_MSG_REQUEST",
        "AMT_MSG_MEMBERSHIP_QUERY",
        "AMT_MSG_MEMBERSHIP_UPDATE",
        "AMT_MSG_MULTICAST_DATA",
        "AMT_MSG_TEARDOWN",
};
63
/* Printable names for enum amt_act, indexed by the action value
 * (see amt_act_src()). The table is read-only, hence fully const.
 */
static const char * const action_str[] = {
        "AMT_ACT_GMI",
        "AMT_ACT_GMI_ZERO",
        "AMT_ACT_GT",
        "AMT_ACT_STATUS_FWD_NEW",
        "AMT_ACT_STATUS_D_FWD_NEW",
        "AMT_ACT_STATUS_NONE_NEW",
};
72
/* Zero-initialised IGMPv3 group record. It is used outside this view;
 * presumably as an all-zero reference/template record — TODO confirm.
 */
static struct igmpv3_grec igmpv3_zero_grec;

#if IS_ENABLED(CONFIG_IPV6)
/* ff02::1, the link-local all-nodes multicast address (used as the MLD
 * General Query destination).
 */
#define MLD2_ALL_NODE_INIT { { { 0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01 } } }
static struct in6_addr mld2_all_node = MLD2_ALL_NODE_INIT;
/* Zero-initialised MLDv2 group record; IPv6 counterpart of igmpv3_zero_grec. */
static struct mld2_grec mldv2_zero_grec;
#endif
80
81 static struct amt_skb_cb *amt_skb_cb(struct sk_buff *skb)
82 {
83         BUILD_BUG_ON(sizeof(struct amt_skb_cb) + sizeof(struct qdisc_skb_cb) >
84                      sizeof_field(struct sk_buff, cb));
85
86         return (struct amt_skb_cb *)((void *)skb->cb +
87                 sizeof(struct qdisc_skb_cb));
88 }
89
90 static void __amt_source_gc_work(void)
91 {
92         struct amt_source_node *snode;
93         struct hlist_head gc_list;
94         struct hlist_node *t;
95
96         spin_lock_bh(&source_gc_lock);
97         hlist_move_list(&source_gc_list, &gc_list);
98         spin_unlock_bh(&source_gc_lock);
99
100         hlist_for_each_entry_safe(snode, t, &gc_list, node) {
101                 hlist_del_rcu(&snode->node);
102                 kfree_rcu(snode, rcu);
103         }
104 }
105
106 static void amt_source_gc_work(struct work_struct *work)
107 {
108         __amt_source_gc_work();
109
110         spin_lock_bh(&source_gc_lock);
111         mod_delayed_work(amt_wq, &source_gc_wq,
112                          msecs_to_jiffies(AMT_GC_INTERVAL));
113         spin_unlock_bh(&source_gc_lock);
114 }
115
116 static bool amt_addr_equal(union amt_addr *a, union amt_addr *b)
117 {
118         return !memcmp(a, b, sizeof(union amt_addr));
119 }
120
121 static u32 amt_source_hash(struct amt_tunnel_list *tunnel, union amt_addr *src)
122 {
123         u32 hash = jhash(src, sizeof(*src), tunnel->amt->hash_seed);
124
125         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
126 }
127
128 static bool amt_status_filter(struct amt_source_node *snode,
129                               enum amt_filter filter)
130 {
131         bool rc = false;
132
133         switch (filter) {
134         case AMT_FILTER_FWD:
135                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
136                     snode->flags == AMT_SOURCE_OLD)
137                         rc = true;
138                 break;
139         case AMT_FILTER_D_FWD:
140                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
141                     snode->flags == AMT_SOURCE_OLD)
142                         rc = true;
143                 break;
144         case AMT_FILTER_FWD_NEW:
145                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
146                     snode->flags == AMT_SOURCE_NEW)
147                         rc = true;
148                 break;
149         case AMT_FILTER_D_FWD_NEW:
150                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
151                     snode->flags == AMT_SOURCE_NEW)
152                         rc = true;
153                 break;
154         case AMT_FILTER_ALL:
155                 rc = true;
156                 break;
157         case AMT_FILTER_NONE_NEW:
158                 if (snode->status == AMT_SOURCE_STATUS_NONE &&
159                     snode->flags == AMT_SOURCE_NEW)
160                         rc = true;
161                 break;
162         case AMT_FILTER_BOTH:
163                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
164                      snode->status == AMT_SOURCE_STATUS_FWD) &&
165                     snode->flags == AMT_SOURCE_OLD)
166                         rc = true;
167                 break;
168         case AMT_FILTER_BOTH_NEW:
169                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
170                      snode->status == AMT_SOURCE_STATUS_FWD) &&
171                     snode->flags == AMT_SOURCE_NEW)
172                         rc = true;
173                 break;
174         default:
175                 WARN_ON_ONCE(1);
176                 break;
177         }
178
179         return rc;
180 }
181
182 static struct amt_source_node *amt_lookup_src(struct amt_tunnel_list *tunnel,
183                                               struct amt_group_node *gnode,
184                                               enum amt_filter filter,
185                                               union amt_addr *src)
186 {
187         u32 hash = amt_source_hash(tunnel, src);
188         struct amt_source_node *snode;
189
190         hlist_for_each_entry_rcu(snode, &gnode->sources[hash], node)
191                 if (amt_status_filter(snode, filter) &&
192                     amt_addr_equal(&snode->source_addr, src))
193                         return snode;
194
195         return NULL;
196 }
197
198 static u32 amt_group_hash(struct amt_tunnel_list *tunnel, union amt_addr *group)
199 {
200         u32 hash = jhash(group, sizeof(*group), tunnel->amt->hash_seed);
201
202         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
203 }
204
205 static struct amt_group_node *amt_lookup_group(struct amt_tunnel_list *tunnel,
206                                                union amt_addr *group,
207                                                union amt_addr *host,
208                                                bool v6)
209 {
210         u32 hash = amt_group_hash(tunnel, group);
211         struct amt_group_node *gnode;
212
213         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash], node) {
214                 if (amt_addr_equal(&gnode->group_addr, group) &&
215                     amt_addr_equal(&gnode->host_addr, host) &&
216                     gnode->v6 == v6)
217                         return gnode;
218         }
219
220         return NULL;
221 }
222
223 static void amt_destroy_source(struct amt_source_node *snode)
224 {
225         struct amt_group_node *gnode = snode->gnode;
226         struct amt_tunnel_list *tunnel;
227
228         tunnel = gnode->tunnel_list;
229
230         if (!gnode->v6) {
231                 netdev_dbg(snode->gnode->amt->dev,
232                            "Delete source %pI4 from %pI4\n",
233                            &snode->source_addr.ip4,
234                            &gnode->group_addr.ip4);
235 #if IS_ENABLED(CONFIG_IPV6)
236         } else {
237                 netdev_dbg(snode->gnode->amt->dev,
238                            "Delete source %pI6 from %pI6\n",
239                            &snode->source_addr.ip6,
240                            &gnode->group_addr.ip6);
241 #endif
242         }
243
244         cancel_delayed_work(&snode->source_timer);
245         hlist_del_init_rcu(&snode->node);
246         tunnel->nr_sources--;
247         gnode->nr_sources--;
248         spin_lock_bh(&source_gc_lock);
249         hlist_add_head_rcu(&snode->node, &source_gc_list);
250         spin_unlock_bh(&source_gc_lock);
251 }
252
253 static void amt_del_group(struct amt_dev *amt, struct amt_group_node *gnode)
254 {
255         struct amt_source_node *snode;
256         struct hlist_node *t;
257         int i;
258
259         if (cancel_delayed_work(&gnode->group_timer))
260                 dev_put(amt->dev);
261         hlist_del_rcu(&gnode->node);
262         gnode->tunnel_list->nr_groups--;
263
264         if (!gnode->v6)
265                 netdev_dbg(amt->dev, "Leave group %pI4\n",
266                            &gnode->group_addr.ip4);
267 #if IS_ENABLED(CONFIG_IPV6)
268         else
269                 netdev_dbg(amt->dev, "Leave group %pI6\n",
270                            &gnode->group_addr.ip6);
271 #endif
272         for (i = 0; i < amt->hash_buckets; i++)
273                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node)
274                         amt_destroy_source(snode);
275
276         /* tunnel->lock was acquired outside of amt_del_group()
277          * But rcu_read_lock() was acquired too so It's safe.
278          */
279         kfree_rcu(gnode, rcu);
280 }
281
282 /* If a source timer expires with a router filter-mode for the group of
283  * INCLUDE, the router concludes that traffic from this particular
284  * source is no longer desired on the attached network, and deletes the
285  * associated source record.
286  */
287 static void amt_source_work(struct work_struct *work)
288 {
289         struct amt_source_node *snode = container_of(to_delayed_work(work),
290                                                      struct amt_source_node,
291                                                      source_timer);
292         struct amt_group_node *gnode = snode->gnode;
293         struct amt_dev *amt = gnode->amt;
294         struct amt_tunnel_list *tunnel;
295
296         tunnel = gnode->tunnel_list;
297         spin_lock_bh(&tunnel->lock);
298         rcu_read_lock();
299         if (gnode->filter_mode == MCAST_INCLUDE) {
300                 amt_destroy_source(snode);
301                 if (!gnode->nr_sources)
302                         amt_del_group(amt, gnode);
303         } else {
304                 /* When a router filter-mode for a group is EXCLUDE,
305                  * source records are only deleted when the group timer expires
306                  */
307                 snode->status = AMT_SOURCE_STATUS_D_FWD;
308         }
309         rcu_read_unlock();
310         spin_unlock_bh(&tunnel->lock);
311 }
312
313 static void amt_act_src(struct amt_tunnel_list *tunnel,
314                         struct amt_group_node *gnode,
315                         struct amt_source_node *snode,
316                         enum amt_act act)
317 {
318         struct amt_dev *amt = tunnel->amt;
319
320         switch (act) {
321         case AMT_ACT_GMI:
322                 mod_delayed_work(amt_wq, &snode->source_timer,
323                                  msecs_to_jiffies(amt_gmi(amt)));
324                 break;
325         case AMT_ACT_GMI_ZERO:
326                 cancel_delayed_work(&snode->source_timer);
327                 break;
328         case AMT_ACT_GT:
329                 mod_delayed_work(amt_wq, &snode->source_timer,
330                                  gnode->group_timer.timer.expires);
331                 break;
332         case AMT_ACT_STATUS_FWD_NEW:
333                 snode->status = AMT_SOURCE_STATUS_FWD;
334                 snode->flags = AMT_SOURCE_NEW;
335                 break;
336         case AMT_ACT_STATUS_D_FWD_NEW:
337                 snode->status = AMT_SOURCE_STATUS_D_FWD;
338                 snode->flags = AMT_SOURCE_NEW;
339                 break;
340         case AMT_ACT_STATUS_NONE_NEW:
341                 cancel_delayed_work(&snode->source_timer);
342                 snode->status = AMT_SOURCE_STATUS_NONE;
343                 snode->flags = AMT_SOURCE_NEW;
344                 break;
345         default:
346                 WARN_ON_ONCE(1);
347                 return;
348         }
349
350         if (!gnode->v6)
351                 netdev_dbg(amt->dev, "Source %pI4 from %pI4 Acted %s\n",
352                            &snode->source_addr.ip4,
353                            &gnode->group_addr.ip4,
354                            action_str[act]);
355 #if IS_ENABLED(CONFIG_IPV6)
356         else
357                 netdev_dbg(amt->dev, "Source %pI6 from %pI6 Acted %s\n",
358                            &snode->source_addr.ip6,
359                            &gnode->group_addr.ip6,
360                            action_str[act]);
361 #endif
362 }
363
364 static struct amt_source_node *amt_alloc_snode(struct amt_group_node *gnode,
365                                                union amt_addr *src)
366 {
367         struct amt_source_node *snode;
368
369         snode = kzalloc(sizeof(*snode), GFP_ATOMIC);
370         if (!snode)
371                 return NULL;
372
373         memcpy(&snode->source_addr, src, sizeof(union amt_addr));
374         snode->gnode = gnode;
375         snode->status = AMT_SOURCE_STATUS_NONE;
376         snode->flags = AMT_SOURCE_NEW;
377         INIT_HLIST_NODE(&snode->node);
378         INIT_DELAYED_WORK(&snode->source_timer, amt_source_work);
379
380         return snode;
381 }
382
383 /* RFC 3810 - 7.2.2.  Definition of Filter Timers
384  *
385  *  Router Mode          Filter Timer         Actions/Comments
386  *  -----------       -----------------       ----------------
387  *
388  *    INCLUDE             Not Used            All listeners in
389  *                                            INCLUDE mode.
390  *
391  *    EXCLUDE             Timer > 0           At least one listener
392  *                                            in EXCLUDE mode.
393  *
394  *    EXCLUDE             Timer == 0          No more listeners in
395  *                                            EXCLUDE mode for the
396  *                                            multicast address.
397  *                                            If the Requested List
398  *                                            is empty, delete
399  *                                            Multicast Address
400  *                                            Record.  If not, switch
401  *                                            to INCLUDE filter mode;
402  *                                            the sources in the
403  *                                            Requested List are
404  *                                            moved to the Include
405  *                                            List, and the Exclude
406  *                                            List is deleted.
407  */
408 static void amt_group_work(struct work_struct *work)
409 {
410         struct amt_group_node *gnode = container_of(to_delayed_work(work),
411                                                     struct amt_group_node,
412                                                     group_timer);
413         struct amt_tunnel_list *tunnel = gnode->tunnel_list;
414         struct amt_dev *amt = gnode->amt;
415         struct amt_source_node *snode;
416         bool delete_group = true;
417         struct hlist_node *t;
418         int i, buckets;
419
420         buckets = amt->hash_buckets;
421
422         spin_lock_bh(&tunnel->lock);
423         if (gnode->filter_mode == MCAST_INCLUDE) {
424                 /* Not Used */
425                 spin_unlock_bh(&tunnel->lock);
426                 goto out;
427         }
428
429         rcu_read_lock();
430         for (i = 0; i < buckets; i++) {
431                 hlist_for_each_entry_safe(snode, t,
432                                           &gnode->sources[i], node) {
433                         if (!delayed_work_pending(&snode->source_timer) ||
434                             snode->status == AMT_SOURCE_STATUS_D_FWD) {
435                                 amt_destroy_source(snode);
436                         } else {
437                                 delete_group = false;
438                                 snode->status = AMT_SOURCE_STATUS_FWD;
439                         }
440                 }
441         }
442         if (delete_group)
443                 amt_del_group(amt, gnode);
444         else
445                 gnode->filter_mode = MCAST_INCLUDE;
446         rcu_read_unlock();
447         spin_unlock_bh(&tunnel->lock);
448 out:
449         dev_put(amt->dev);
450 }
451
452 /* Non-existent group is created as INCLUDE {empty}:
453  *
454  * RFC 3376 - 5.1. Action on Change of Interface State
455  *
456  * If no interface state existed for that multicast address before
457  * the change (i.e., the change consisted of creating a new
458  * per-interface record), or if no state exists after the change
459  * (i.e., the change consisted of deleting a per-interface record),
460  * then the "non-existent" state is considered to have a filter mode
461  * of INCLUDE and an empty source list.
462  */
463 static struct amt_group_node *amt_add_group(struct amt_dev *amt,
464                                             struct amt_tunnel_list *tunnel,
465                                             union amt_addr *group,
466                                             union amt_addr *host,
467                                             bool v6)
468 {
469         struct amt_group_node *gnode;
470         u32 hash;
471         int i;
472
473         if (tunnel->nr_groups >= amt->max_groups)
474                 return ERR_PTR(-ENOSPC);
475
476         gnode = kzalloc(sizeof(*gnode) +
477                         (sizeof(struct hlist_head) * amt->hash_buckets),
478                         GFP_ATOMIC);
479         if (unlikely(!gnode))
480                 return ERR_PTR(-ENOMEM);
481
482         gnode->amt = amt;
483         gnode->group_addr = *group;
484         gnode->host_addr = *host;
485         gnode->v6 = v6;
486         gnode->tunnel_list = tunnel;
487         gnode->filter_mode = MCAST_INCLUDE;
488         INIT_HLIST_NODE(&gnode->node);
489         INIT_DELAYED_WORK(&gnode->group_timer, amt_group_work);
490         for (i = 0; i < amt->hash_buckets; i++)
491                 INIT_HLIST_HEAD(&gnode->sources[i]);
492
493         hash = amt_group_hash(tunnel, group);
494         hlist_add_head_rcu(&gnode->node, &tunnel->groups[hash]);
495         tunnel->nr_groups++;
496
497         if (!gnode->v6)
498                 netdev_dbg(amt->dev, "Join group %pI4\n",
499                            &gnode->group_addr.ip4);
500 #if IS_ENABLED(CONFIG_IPV6)
501         else
502                 netdev_dbg(amt->dev, "Join group %pI6\n",
503                            &gnode->group_addr.ip6);
504 #endif
505
506         return gnode;
507 }
508
/* Build an IGMPv3 General Query skb for amt->dev.
 *
 * Layout of the returned skb, with skb->data at the Ethernet header:
 *   ethhdr | iphdr | AMT_IPHDR_OPTS bytes of IP options (Router Alert) |
 *   igmpv3_query
 * The destination is INADDR_ALLHOSTS_GROUP (224.0.0.1), with the matching
 * multicast MAC, and the source address is INADDR_ANY. The mac, network
 * and transport header offsets are all set.
 *
 * Returns NULL if the skb cannot be allocated.
 */
static struct sk_buff *amt_build_igmp_gq(struct amt_dev *amt)
{
        /* IP Router Alert option: type IPOPT_RA, length 4, value 0. */
        u8 ra[AMT_IPHDR_OPTS] = { IPOPT_RA, 4, 0, 0 };
        int hlen = LL_RESERVED_SPACE(amt->dev);
        int tlen = amt->dev->needed_tailroom;
        struct igmpv3_query *ihv3;
        void *csum_start = NULL;
        __sum16 *csum = NULL;
        struct sk_buff *skb;
        struct ethhdr *eth;
        struct iphdr *iph;
        unsigned int len;
        int offset;

        len = hlen + tlen + sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3);
        skb = netdev_alloc_skb_ip_align(amt->dev, len);
        if (!skb)
                return NULL;

        /* Push the Ethernet header into the reserved headroom and append
         * IP header + options + query after it, then leave data at the IP
         * header.
         */
        skb_reserve(skb, hlen);
        skb_push(skb, sizeof(*eth));
        skb->protocol = htons(ETH_P_IP);
        skb_reset_mac_header(skb);
        skb->priority = TC_PRIO_CONTROL;
        skb_put(skb, sizeof(*iph));
        skb_put_data(skb, ra, sizeof(ra));
        skb_put(skb, sizeof(*ihv3));
        skb_pull(skb, sizeof(*eth));
        skb_reset_network_header(skb);

        iph             = ip_hdr(skb);
        iph->version    = 4;
        iph->ihl        = (sizeof(struct iphdr) + AMT_IPHDR_OPTS) >> 2;
        iph->tos        = AMT_TOS;
        iph->tot_len    = htons(sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3));
        iph->frag_off   = htons(IP_DF);
        iph->ttl        = 1;
        iph->id         = 0;
        iph->protocol   = IPPROTO_IGMP;
        iph->daddr      = htonl(INADDR_ALLHOSTS_GROUP);
        iph->saddr      = htonl(INADDR_ANY);
        ip_send_check(iph);

        eth = eth_hdr(skb);
        ether_addr_copy(eth->h_source, amt->dev->dev_addr);
        ip_eth_mc_map(htonl(INADDR_ALLHOSTS_GROUP), eth->h_dest);
        eth->h_proto = htons(ETH_P_IP);

        /* Advance past IP header + options to fill in the IGMPv3 query. */
        ihv3            = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
        skb_reset_transport_header(skb);
        ihv3->type      = IGMP_HOST_MEMBERSHIP_QUERY;
        ihv3->code      = 1;
        ihv3->group     = 0;
        ihv3->qqic      = amt->qi;
        ihv3->nsrcs     = 0;
        ihv3->resv      = 0;
        ihv3->suppress  = false;
        ihv3->qrv       = READ_ONCE(amt->net->ipv4.sysctl_igmp_qrv);
        ihv3->csum      = 0;
        csum            = &ihv3->csum;
        csum_start      = (void *)ihv3;
        *csum           = ip_compute_csum(csum_start, sizeof(*ihv3));
        offset          = skb_transport_offset(skb);
        skb->csum       = skb_checksum(skb, offset, skb->len - offset, 0);
        skb->ip_summed  = CHECKSUM_NONE;

        /* Rewind so skb->data points at the Ethernet header again. */
        skb_push(skb, sizeof(*eth) + sizeof(*iph) + AMT_IPHDR_OPTS);

        return skb;
}
579
580 static void amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
581                                  bool validate)
582 {
583         if (validate && amt->status >= status)
584                 return;
585         netdev_dbg(amt->dev, "Update GW status %s -> %s",
586                    status_str[amt->status], status_str[status]);
587         WRITE_ONCE(amt->status, status);
588 }
589
590 static void __amt_update_relay_status(struct amt_tunnel_list *tunnel,
591                                       enum amt_status status,
592                                       bool validate)
593 {
594         if (validate && tunnel->status >= status)
595                 return;
596         netdev_dbg(tunnel->amt->dev,
597                    "Update Tunnel(IP = %pI4, PORT = %u) status %s -> %s",
598                    &tunnel->ip4, ntohs(tunnel->source_port),
599                    status_str[tunnel->status], status_str[status]);
600         tunnel->status = status;
601 }
602
603 static void amt_update_relay_status(struct amt_tunnel_list *tunnel,
604                                     enum amt_status status, bool validate)
605 {
606         spin_lock_bh(&tunnel->lock);
607         __amt_update_relay_status(tunnel, status, validate);
608         spin_unlock_bh(&tunnel->lock);
609 }
610
611 static void amt_send_discovery(struct amt_dev *amt)
612 {
613         struct amt_header_discovery *amtd;
614         int hlen, tlen, offset;
615         struct socket *sock;
616         struct udphdr *udph;
617         struct sk_buff *skb;
618         struct iphdr *iph;
619         struct rtable *rt;
620         struct flowi4 fl4;
621         u32 len;
622         int err;
623
624         rcu_read_lock();
625         sock = rcu_dereference(amt->sock);
626         if (!sock)
627                 goto out;
628
629         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
630                 goto out;
631
632         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
633                                    amt->discovery_ip, amt->local_ip,
634                                    amt->gw_port, amt->relay_port,
635                                    IPPROTO_UDP, 0,
636                                    amt->stream_dev->ifindex);
637         if (IS_ERR(rt)) {
638                 amt->dev->stats.tx_errors++;
639                 goto out;
640         }
641
642         hlen = LL_RESERVED_SPACE(amt->dev);
643         tlen = amt->dev->needed_tailroom;
644         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
645         skb = netdev_alloc_skb_ip_align(amt->dev, len);
646         if (!skb) {
647                 ip_rt_put(rt);
648                 amt->dev->stats.tx_errors++;
649                 goto out;
650         }
651
652         skb->priority = TC_PRIO_CONTROL;
653         skb_dst_set(skb, &rt->dst);
654
655         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
656         skb_reset_network_header(skb);
657         skb_put(skb, len);
658         amtd = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
659         amtd->version   = 0;
660         amtd->type      = AMT_MSG_DISCOVERY;
661         amtd->reserved  = 0;
662         amtd->nonce     = amt->nonce;
663         skb_push(skb, sizeof(*udph));
664         skb_reset_transport_header(skb);
665         udph            = udp_hdr(skb);
666         udph->source    = amt->gw_port;
667         udph->dest      = amt->relay_port;
668         udph->len       = htons(sizeof(*udph) + sizeof(*amtd));
669         udph->check     = 0;
670         offset = skb_transport_offset(skb);
671         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
672         udph->check = csum_tcpudp_magic(amt->local_ip, amt->discovery_ip,
673                                         sizeof(*udph) + sizeof(*amtd),
674                                         IPPROTO_UDP, skb->csum);
675
676         skb_push(skb, sizeof(*iph));
677         iph             = ip_hdr(skb);
678         iph->version    = 4;
679         iph->ihl        = (sizeof(struct iphdr)) >> 2;
680         iph->tos        = AMT_TOS;
681         iph->frag_off   = 0;
682         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
683         iph->daddr      = amt->discovery_ip;
684         iph->saddr      = amt->local_ip;
685         iph->protocol   = IPPROTO_UDP;
686         iph->tot_len    = htons(len);
687
688         skb->ip_summed = CHECKSUM_NONE;
689         ip_select_ident(amt->net, skb, NULL);
690         ip_send_check(iph);
691         err = ip_local_out(amt->net, sock->sk, skb);
692         if (unlikely(net_xmit_eval(err)))
693                 amt->dev->stats.tx_errors++;
694
695         amt_update_gw_status(amt, AMT_STATUS_SENT_DISCOVERY, true);
696 out:
697         rcu_read_unlock();
698 }
699
700 static void amt_send_request(struct amt_dev *amt, bool v6)
701 {
702         struct amt_header_request *amtrh;
703         int hlen, tlen, offset;
704         struct socket *sock;
705         struct udphdr *udph;
706         struct sk_buff *skb;
707         struct iphdr *iph;
708         struct rtable *rt;
709         struct flowi4 fl4;
710         u32 len;
711         int err;
712
713         rcu_read_lock();
714         sock = rcu_dereference(amt->sock);
715         if (!sock)
716                 goto out;
717
718         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
719                 goto out;
720
721         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
722                                    amt->remote_ip, amt->local_ip,
723                                    amt->gw_port, amt->relay_port,
724                                    IPPROTO_UDP, 0,
725                                    amt->stream_dev->ifindex);
726         if (IS_ERR(rt)) {
727                 amt->dev->stats.tx_errors++;
728                 goto out;
729         }
730
731         hlen = LL_RESERVED_SPACE(amt->dev);
732         tlen = amt->dev->needed_tailroom;
733         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
734         skb = netdev_alloc_skb_ip_align(amt->dev, len);
735         if (!skb) {
736                 ip_rt_put(rt);
737                 amt->dev->stats.tx_errors++;
738                 goto out;
739         }
740
741         skb->priority = TC_PRIO_CONTROL;
742         skb_dst_set(skb, &rt->dst);
743
744         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
745         skb_reset_network_header(skb);
746         skb_put(skb, len);
747         amtrh = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
748         amtrh->version   = 0;
749         amtrh->type      = AMT_MSG_REQUEST;
750         amtrh->reserved1 = 0;
751         amtrh->p         = v6;
752         amtrh->reserved2 = 0;
753         amtrh->nonce     = amt->nonce;
754         skb_push(skb, sizeof(*udph));
755         skb_reset_transport_header(skb);
756         udph            = udp_hdr(skb);
757         udph->source    = amt->gw_port;
758         udph->dest      = amt->relay_port;
759         udph->len       = htons(sizeof(*amtrh) + sizeof(*udph));
760         udph->check     = 0;
761         offset = skb_transport_offset(skb);
762         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
763         udph->check = csum_tcpudp_magic(amt->local_ip, amt->remote_ip,
764                                         sizeof(*udph) + sizeof(*amtrh),
765                                         IPPROTO_UDP, skb->csum);
766
767         skb_push(skb, sizeof(*iph));
768         iph             = ip_hdr(skb);
769         iph->version    = 4;
770         iph->ihl        = (sizeof(struct iphdr)) >> 2;
771         iph->tos        = AMT_TOS;
772         iph->frag_off   = 0;
773         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
774         iph->daddr      = amt->remote_ip;
775         iph->saddr      = amt->local_ip;
776         iph->protocol   = IPPROTO_UDP;
777         iph->tot_len    = htons(len);
778
779         skb->ip_summed = CHECKSUM_NONE;
780         ip_select_ident(amt->net, skb, NULL);
781         ip_send_check(iph);
782         err = ip_local_out(amt->net, sock->sk, skb);
783         if (unlikely(net_xmit_eval(err)))
784                 amt->dev->stats.tx_errors++;
785
786 out:
787         rcu_read_unlock();
788 }
789
790 static void amt_send_igmp_gq(struct amt_dev *amt,
791                              struct amt_tunnel_list *tunnel)
792 {
793         struct sk_buff *skb;
794
795         skb = amt_build_igmp_gq(amt);
796         if (!skb)
797                 return;
798
799         amt_skb_cb(skb)->tunnel = tunnel;
800         dev_queue_xmit(skb);
801 }
802
803 #if IS_ENABLED(CONFIG_IPV6)
804 static struct sk_buff *amt_build_mld_gq(struct amt_dev *amt)
805 {
806         u8 ra[AMT_IP6HDR_OPTS] = { IPPROTO_ICMPV6, 0, IPV6_TLV_ROUTERALERT,
807                                    2, 0, 0, IPV6_TLV_PAD1, IPV6_TLV_PAD1 };
808         int hlen = LL_RESERVED_SPACE(amt->dev);
809         int tlen = amt->dev->needed_tailroom;
810         struct mld2_query *mld2q;
811         void *csum_start = NULL;
812         struct ipv6hdr *ip6h;
813         struct sk_buff *skb;
814         struct ethhdr *eth;
815         u32 len;
816
817         len = hlen + tlen + sizeof(*ip6h) + sizeof(ra) + sizeof(*mld2q);
818         skb = netdev_alloc_skb_ip_align(amt->dev, len);
819         if (!skb)
820                 return NULL;
821
822         skb_reserve(skb, hlen);
823         skb_push(skb, sizeof(*eth));
824         skb_reset_mac_header(skb);
825         eth = eth_hdr(skb);
826         skb->priority = TC_PRIO_CONTROL;
827         skb->protocol = htons(ETH_P_IPV6);
828         skb_put_zero(skb, sizeof(*ip6h));
829         skb_put_data(skb, ra, sizeof(ra));
830         skb_put_zero(skb, sizeof(*mld2q));
831         skb_pull(skb, sizeof(*eth));
832         skb_reset_network_header(skb);
833         ip6h                    = ipv6_hdr(skb);
834         ip6h->payload_len       = htons(sizeof(ra) + sizeof(*mld2q));
835         ip6h->nexthdr           = NEXTHDR_HOP;
836         ip6h->hop_limit         = 1;
837         ip6h->daddr             = mld2_all_node;
838         ip6_flow_hdr(ip6h, 0, 0);
839
840         if (ipv6_dev_get_saddr(amt->net, amt->dev, &ip6h->daddr, 0,
841                                &ip6h->saddr)) {
842                 amt->dev->stats.tx_errors++;
843                 kfree_skb(skb);
844                 return NULL;
845         }
846
847         eth->h_proto = htons(ETH_P_IPV6);
848         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
849         ipv6_eth_mc_map(&mld2_all_node, eth->h_dest);
850
851         skb_pull(skb, sizeof(*ip6h) + sizeof(ra));
852         skb_reset_transport_header(skb);
853         mld2q                   = (struct mld2_query *)icmp6_hdr(skb);
854         mld2q->mld2q_mrc        = htons(1);
855         mld2q->mld2q_type       = ICMPV6_MGM_QUERY;
856         mld2q->mld2q_code       = 0;
857         mld2q->mld2q_cksum      = 0;
858         mld2q->mld2q_resv1      = 0;
859         mld2q->mld2q_resv2      = 0;
860         mld2q->mld2q_suppress   = 0;
861         mld2q->mld2q_qrv        = amt->qrv;
862         mld2q->mld2q_nsrcs      = 0;
863         mld2q->mld2q_qqic       = amt->qi;
864         csum_start              = (void *)mld2q;
865         mld2q->mld2q_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
866                                              sizeof(*mld2q),
867                                              IPPROTO_ICMPV6,
868                                              csum_partial(csum_start,
869                                                           sizeof(*mld2q), 0));
870
871         skb->ip_summed = CHECKSUM_NONE;
872         skb_push(skb, sizeof(*eth) + sizeof(*ip6h) + sizeof(ra));
873         return skb;
874 }
875
876 static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
877 {
878         struct sk_buff *skb;
879
880         skb = amt_build_mld_gq(amt);
881         if (!skb)
882                 return;
883
884         amt_skb_cb(skb)->tunnel = tunnel;
885         dev_queue_xmit(skb);
886 }
887 #else
/* Without CONFIG_IPV6 no MLD query can be built, so this is a no-op. */
static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
{
}
891 #endif
892
893 static bool amt_queue_event(struct amt_dev *amt, enum amt_event event,
894                             struct sk_buff *skb)
895 {
896         int index;
897
898         spin_lock_bh(&amt->lock);
899         if (amt->nr_events >= AMT_MAX_EVENTS) {
900                 spin_unlock_bh(&amt->lock);
901                 return 1;
902         }
903
904         index = (amt->event_idx + amt->nr_events) % AMT_MAX_EVENTS;
905         amt->events[index].event = event;
906         amt->events[index].skb = skb;
907         amt->nr_events++;
908         amt->event_idx %= AMT_MAX_EVENTS;
909         queue_work(amt_wq, &amt->event_wq);
910         spin_unlock_bh(&amt->lock);
911
912         return 0;
913 }
914
915 static void amt_secret_work(struct work_struct *work)
916 {
917         struct amt_dev *amt = container_of(to_delayed_work(work),
918                                            struct amt_dev,
919                                            secret_wq);
920
921         spin_lock_bh(&amt->lock);
922         get_random_bytes(&amt->key, sizeof(siphash_key_t));
923         spin_unlock_bh(&amt->lock);
924         mod_delayed_work(amt_wq, &amt->secret_wq,
925                          msecs_to_jiffies(AMT_SECRET_TIMEOUT));
926 }
927
928 static void amt_event_send_discovery(struct amt_dev *amt)
929 {
930         if (amt->status > AMT_STATUS_SENT_DISCOVERY)
931                 goto out;
932         get_random_bytes(&amt->nonce, sizeof(__be32));
933
934         amt_send_discovery(amt);
935 out:
936         mod_delayed_work(amt_wq, &amt->discovery_wq,
937                          msecs_to_jiffies(AMT_DISCOVERY_TIMEOUT));
938 }
939
940 static void amt_discovery_work(struct work_struct *work)
941 {
942         struct amt_dev *amt = container_of(to_delayed_work(work),
943                                            struct amt_dev,
944                                            discovery_wq);
945
946         if (amt_queue_event(amt, AMT_EVENT_SEND_DISCOVERY, NULL))
947                 mod_delayed_work(amt_wq, &amt->discovery_wq,
948                                  msecs_to_jiffies(AMT_DISCOVERY_TIMEOUT));
949 }
950
951 static void amt_event_send_request(struct amt_dev *amt)
952 {
953         u32 exp;
954
955         if (amt->status < AMT_STATUS_RECEIVED_ADVERTISEMENT)
956                 goto out;
957
958         if (amt->req_cnt > AMT_MAX_REQ_COUNT) {
959                 netdev_dbg(amt->dev, "Gateway is not ready");
960                 amt->qi = AMT_INIT_REQ_TIMEOUT;
961                 WRITE_ONCE(amt->ready4, false);
962                 WRITE_ONCE(amt->ready6, false);
963                 amt->remote_ip = 0;
964                 amt_update_gw_status(amt, AMT_STATUS_INIT, false);
965                 amt->req_cnt = 0;
966                 amt->nonce = 0;
967                 goto out;
968         }
969
970         if (!amt->req_cnt) {
971                 WRITE_ONCE(amt->ready4, false);
972                 WRITE_ONCE(amt->ready6, false);
973                 get_random_bytes(&amt->nonce, sizeof(__be32));
974         }
975
976         amt_send_request(amt, false);
977         amt_send_request(amt, true);
978         amt_update_gw_status(amt, AMT_STATUS_SENT_REQUEST, true);
979         amt->req_cnt++;
980 out:
981         exp = min_t(u32, (1 * (1 << amt->req_cnt)), AMT_MAX_REQ_TIMEOUT);
982         mod_delayed_work(amt_wq, &amt->req_wq, msecs_to_jiffies(exp * 1000));
983 }
984
985 static void amt_req_work(struct work_struct *work)
986 {
987         struct amt_dev *amt = container_of(to_delayed_work(work),
988                                            struct amt_dev,
989                                            req_wq);
990
991         if (amt_queue_event(amt, AMT_EVENT_SEND_REQUEST, NULL))
992                 mod_delayed_work(amt_wq, &amt->req_wq,
993                                  msecs_to_jiffies(100));
994 }
995
996 static bool amt_send_membership_update(struct amt_dev *amt,
997                                        struct sk_buff *skb,
998                                        bool v6)
999 {
1000         struct amt_header_membership_update *amtmu;
1001         struct socket *sock;
1002         struct iphdr *iph;
1003         struct flowi4 fl4;
1004         struct rtable *rt;
1005         int err;
1006
1007         sock = rcu_dereference_bh(amt->sock);
1008         if (!sock)
1009                 return true;
1010
1011         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmu) +
1012                            sizeof(*iph) + sizeof(struct udphdr));
1013         if (err)
1014                 return true;
1015
1016         skb_reset_inner_headers(skb);
1017         memset(&fl4, 0, sizeof(struct flowi4));
1018         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1019         fl4.daddr              = amt->remote_ip;
1020         fl4.saddr              = amt->local_ip;
1021         fl4.flowi4_tos         = AMT_TOS;
1022         fl4.flowi4_proto       = IPPROTO_UDP;
1023         rt = ip_route_output_key(amt->net, &fl4);
1024         if (IS_ERR(rt)) {
1025                 netdev_dbg(amt->dev, "no route to %pI4\n", &amt->remote_ip);
1026                 return true;
1027         }
1028
1029         amtmu                   = skb_push(skb, sizeof(*amtmu));
1030         amtmu->version          = 0;
1031         amtmu->type             = AMT_MSG_MEMBERSHIP_UPDATE;
1032         amtmu->reserved         = 0;
1033         amtmu->nonce            = amt->nonce;
1034         amtmu->response_mac     = amt->mac;
1035
1036         if (!v6)
1037                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1038         else
1039                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1040         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1041                             fl4.saddr,
1042                             fl4.daddr,
1043                             AMT_TOS,
1044                             ip4_dst_hoplimit(&rt->dst),
1045                             0,
1046                             amt->gw_port,
1047                             amt->relay_port,
1048                             false,
1049                             false);
1050         amt_update_gw_status(amt, AMT_STATUS_SENT_UPDATE, true);
1051         return false;
1052 }
1053
1054 static void amt_send_multicast_data(struct amt_dev *amt,
1055                                     const struct sk_buff *oskb,
1056                                     struct amt_tunnel_list *tunnel,
1057                                     bool v6)
1058 {
1059         struct amt_header_mcast_data *amtmd;
1060         struct socket *sock;
1061         struct sk_buff *skb;
1062         struct iphdr *iph;
1063         struct flowi4 fl4;
1064         struct rtable *rt;
1065
1066         sock = rcu_dereference_bh(amt->sock);
1067         if (!sock)
1068                 return;
1069
1070         skb = skb_copy_expand(oskb, sizeof(*amtmd) + sizeof(*iph) +
1071                               sizeof(struct udphdr), 0, GFP_ATOMIC);
1072         if (!skb)
1073                 return;
1074
1075         skb_reset_inner_headers(skb);
1076         memset(&fl4, 0, sizeof(struct flowi4));
1077         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1078         fl4.daddr              = tunnel->ip4;
1079         fl4.saddr              = amt->local_ip;
1080         fl4.flowi4_proto       = IPPROTO_UDP;
1081         rt = ip_route_output_key(amt->net, &fl4);
1082         if (IS_ERR(rt)) {
1083                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1084                 kfree_skb(skb);
1085                 return;
1086         }
1087
1088         amtmd = skb_push(skb, sizeof(*amtmd));
1089         amtmd->version = 0;
1090         amtmd->reserved = 0;
1091         amtmd->type = AMT_MSG_MULTICAST_DATA;
1092
1093         if (!v6)
1094                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1095         else
1096                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1097         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1098                             fl4.saddr,
1099                             fl4.daddr,
1100                             AMT_TOS,
1101                             ip4_dst_hoplimit(&rt->dst),
1102                             0,
1103                             amt->relay_port,
1104                             tunnel->source_port,
1105                             false,
1106                             false);
1107 }
1108
1109 static bool amt_send_membership_query(struct amt_dev *amt,
1110                                       struct sk_buff *skb,
1111                                       struct amt_tunnel_list *tunnel,
1112                                       bool v6)
1113 {
1114         struct amt_header_membership_query *amtmq;
1115         struct socket *sock;
1116         struct rtable *rt;
1117         struct flowi4 fl4;
1118         int err;
1119
1120         sock = rcu_dereference_bh(amt->sock);
1121         if (!sock)
1122                 return true;
1123
1124         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmq) +
1125                            sizeof(struct iphdr) + sizeof(struct udphdr));
1126         if (err)
1127                 return true;
1128
1129         skb_reset_inner_headers(skb);
1130         memset(&fl4, 0, sizeof(struct flowi4));
1131         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1132         fl4.daddr              = tunnel->ip4;
1133         fl4.saddr              = amt->local_ip;
1134         fl4.flowi4_tos         = AMT_TOS;
1135         fl4.flowi4_proto       = IPPROTO_UDP;
1136         rt = ip_route_output_key(amt->net, &fl4);
1137         if (IS_ERR(rt)) {
1138                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1139                 return true;
1140         }
1141
1142         amtmq           = skb_push(skb, sizeof(*amtmq));
1143         amtmq->version  = 0;
1144         amtmq->type     = AMT_MSG_MEMBERSHIP_QUERY;
1145         amtmq->reserved = 0;
1146         amtmq->l        = 0;
1147         amtmq->g        = 0;
1148         amtmq->nonce    = tunnel->nonce;
1149         amtmq->response_mac = tunnel->mac;
1150
1151         if (!v6)
1152                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1153         else
1154                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1155         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1156                             fl4.saddr,
1157                             fl4.daddr,
1158                             AMT_TOS,
1159                             ip4_dst_hoplimit(&rt->dst),
1160                             0,
1161                             amt->relay_port,
1162                             tunnel->source_port,
1163                             false,
1164                             false);
1165         amt_update_relay_status(tunnel, AMT_STATUS_SENT_QUERY, true);
1166         return false;
1167 }
1168
1169 static netdev_tx_t amt_dev_xmit(struct sk_buff *skb, struct net_device *dev)
1170 {
1171         struct amt_dev *amt = netdev_priv(dev);
1172         struct amt_tunnel_list *tunnel;
1173         struct amt_group_node *gnode;
1174         union amt_addr group = {0,};
1175 #if IS_ENABLED(CONFIG_IPV6)
1176         struct ipv6hdr *ip6h;
1177         struct mld_msg *mld;
1178 #endif
1179         bool report = false;
1180         struct igmphdr *ih;
1181         bool query = false;
1182         struct iphdr *iph;
1183         bool data = false;
1184         bool v6 = false;
1185         u32 hash;
1186
1187         iph = ip_hdr(skb);
1188         if (iph->version == 4) {
1189                 if (!ipv4_is_multicast(iph->daddr))
1190                         goto free;
1191
1192                 if (!ip_mc_check_igmp(skb)) {
1193                         ih = igmp_hdr(skb);
1194                         switch (ih->type) {
1195                         case IGMPV3_HOST_MEMBERSHIP_REPORT:
1196                         case IGMP_HOST_MEMBERSHIP_REPORT:
1197                                 report = true;
1198                                 break;
1199                         case IGMP_HOST_MEMBERSHIP_QUERY:
1200                                 query = true;
1201                                 break;
1202                         default:
1203                                 goto free;
1204                         }
1205                 } else {
1206                         data = true;
1207                 }
1208                 v6 = false;
1209                 group.ip4 = iph->daddr;
1210 #if IS_ENABLED(CONFIG_IPV6)
1211         } else if (iph->version == 6) {
1212                 ip6h = ipv6_hdr(skb);
1213                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
1214                         goto free;
1215
1216                 if (!ipv6_mc_check_mld(skb)) {
1217                         mld = (struct mld_msg *)skb_transport_header(skb);
1218                         switch (mld->mld_type) {
1219                         case ICMPV6_MGM_REPORT:
1220                         case ICMPV6_MLD2_REPORT:
1221                                 report = true;
1222                                 break;
1223                         case ICMPV6_MGM_QUERY:
1224                                 query = true;
1225                                 break;
1226                         default:
1227                                 goto free;
1228                         }
1229                 } else {
1230                         data = true;
1231                 }
1232                 v6 = true;
1233                 group.ip6 = ip6h->daddr;
1234 #endif
1235         } else {
1236                 dev->stats.tx_errors++;
1237                 goto free;
1238         }
1239
1240         if (!pskb_may_pull(skb, sizeof(struct ethhdr)))
1241                 goto free;
1242
1243         skb_pull(skb, sizeof(struct ethhdr));
1244
1245         if (amt->mode == AMT_MODE_GATEWAY) {
1246                 /* Gateway only passes IGMP/MLD packets */
1247                 if (!report)
1248                         goto free;
1249                 if ((!v6 && !READ_ONCE(amt->ready4)) ||
1250                     (v6 && !READ_ONCE(amt->ready6)))
1251                         goto free;
1252                 if (amt_send_membership_update(amt, skb,  v6))
1253                         goto free;
1254                 goto unlock;
1255         } else if (amt->mode == AMT_MODE_RELAY) {
1256                 if (query) {
1257                         tunnel = amt_skb_cb(skb)->tunnel;
1258                         if (!tunnel) {
1259                                 WARN_ON(1);
1260                                 goto free;
1261                         }
1262
1263                         /* Do not forward unexpected query */
1264                         if (amt_send_membership_query(amt, skb, tunnel, v6))
1265                                 goto free;
1266                         goto unlock;
1267                 }
1268
1269                 if (!data)
1270                         goto free;
1271                 list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
1272                         hash = amt_group_hash(tunnel, &group);
1273                         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash],
1274                                                  node) {
1275                                 if (!v6) {
1276                                         if (gnode->group_addr.ip4 == iph->daddr)
1277                                                 goto found;
1278 #if IS_ENABLED(CONFIG_IPV6)
1279                                 } else {
1280                                         if (ipv6_addr_equal(&gnode->group_addr.ip6,
1281                                                             &ip6h->daddr))
1282                                                 goto found;
1283 #endif
1284                                 }
1285                         }
1286                         continue;
1287 found:
1288                         amt_send_multicast_data(amt, skb, tunnel, v6);
1289                 }
1290         }
1291
1292         dev_kfree_skb(skb);
1293         return NETDEV_TX_OK;
1294 free:
1295         dev_kfree_skb(skb);
1296 unlock:
1297         dev->stats.tx_dropped++;
1298         return NETDEV_TX_OK;
1299 }
1300
1301 static int amt_parse_type(struct sk_buff *skb)
1302 {
1303         struct amt_header *amth;
1304
1305         if (!pskb_may_pull(skb, sizeof(struct udphdr) +
1306                            sizeof(struct amt_header)))
1307                 return -1;
1308
1309         amth = (struct amt_header *)(udp_hdr(skb) + 1);
1310
1311         if (amth->version != 0)
1312                 return -1;
1313
1314         if (amth->type >= __AMT_MSG_MAX || !amth->type)
1315                 return -1;
1316         return amth->type;
1317 }
1318
1319 static void amt_clear_groups(struct amt_tunnel_list *tunnel)
1320 {
1321         struct amt_dev *amt = tunnel->amt;
1322         struct amt_group_node *gnode;
1323         struct hlist_node *t;
1324         int i;
1325
1326         spin_lock_bh(&tunnel->lock);
1327         rcu_read_lock();
1328         for (i = 0; i < amt->hash_buckets; i++)
1329                 hlist_for_each_entry_safe(gnode, t, &tunnel->groups[i], node)
1330                         amt_del_group(amt, gnode);
1331         rcu_read_unlock();
1332         spin_unlock_bh(&tunnel->lock);
1333 }
1334
/*
 * Delayed-work handler for a tunnel's gc_wq: tear down the tunnel.
 *
 * Under amt->lock the tunnel is unlinked from amt->tunnel_list,
 * amt->nr_tunnels is decremented and all of the tunnel's groups are
 * deleted.  The tunnel itself is freed with kfree_rcu(), so RCU readers
 * still walking tunnel_list (e.g. the relay data path in amt_dev_xmit())
 * can finish before the memory goes away.
 */
static void amt_tunnel_expire(struct work_struct *work)
{
        struct amt_tunnel_list *tunnel = container_of(to_delayed_work(work),
                                                      struct amt_tunnel_list,
                                                      gc_wq);
        struct amt_dev *amt = tunnel->amt;

        spin_lock_bh(&amt->lock);
        rcu_read_lock();
        list_del_rcu(&tunnel->list);
        amt->nr_tunnels--;
        /* Takes tunnel->lock: lock order is amt->lock -> tunnel->lock. */
        amt_clear_groups(tunnel);
        rcu_read_unlock();
        spin_unlock_bh(&amt->lock);
        /* Freed only after an RCU grace period, for concurrent readers. */
        kfree_rcu(tunnel, rcu);
}
1351
1352 static void amt_cleanup_srcs(struct amt_dev *amt,
1353                              struct amt_tunnel_list *tunnel,
1354                              struct amt_group_node *gnode)
1355 {
1356         struct amt_source_node *snode;
1357         struct hlist_node *t;
1358         int i;
1359
1360         /* Delete old sources */
1361         for (i = 0; i < amt->hash_buckets; i++) {
1362                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node) {
1363                         if (snode->flags == AMT_SOURCE_OLD)
1364                                 amt_destroy_source(snode);
1365                 }
1366         }
1367
1368         /* switch from new to old */
1369         for (i = 0; i < amt->hash_buckets; i++)  {
1370                 hlist_for_each_entry_rcu(snode, &gnode->sources[i], node) {
1371                         snode->flags = AMT_SOURCE_OLD;
1372                         if (!gnode->v6)
1373                                 netdev_dbg(snode->gnode->amt->dev,
1374                                            "Add source as OLD %pI4 from %pI4\n",
1375                                            &snode->source_addr.ip4,
1376                                            &gnode->group_addr.ip4);
1377 #if IS_ENABLED(CONFIG_IPV6)
1378                         else
1379                                 netdev_dbg(snode->gnode->amt->dev,
1380                                            "Add source as OLD %pI6 from %pI6\n",
1381                                            &snode->source_addr.ip6,
1382                                            &gnode->group_addr.ip6);
1383 #endif
1384                 }
1385         }
1386 }
1387
1388 static void amt_add_srcs(struct amt_dev *amt, struct amt_tunnel_list *tunnel,
1389                          struct amt_group_node *gnode, void *grec,
1390                          bool v6)
1391 {
1392         struct igmpv3_grec *igmp_grec;
1393         struct amt_source_node *snode;
1394 #if IS_ENABLED(CONFIG_IPV6)
1395         struct mld2_grec *mld_grec;
1396 #endif
1397         union amt_addr src = {0,};
1398         u16 nsrcs;
1399         u32 hash;
1400         int i;
1401
1402         if (!v6) {
1403                 igmp_grec = grec;
1404                 nsrcs = ntohs(igmp_grec->grec_nsrcs);
1405         } else {
1406 #if IS_ENABLED(CONFIG_IPV6)
1407                 mld_grec = grec;
1408                 nsrcs = ntohs(mld_grec->grec_nsrcs);
1409 #else
1410         return;
1411 #endif
1412         }
1413         for (i = 0; i < nsrcs; i++) {
1414                 if (tunnel->nr_sources >= amt->max_sources)
1415                         return;
1416                 if (!v6)
1417                         src.ip4 = igmp_grec->grec_src[i];
1418 #if IS_ENABLED(CONFIG_IPV6)
1419                 else
1420                         memcpy(&src.ip6, &mld_grec->grec_src[i],
1421                                sizeof(struct in6_addr));
1422 #endif
1423                 if (amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL, &src))
1424                         continue;
1425
1426                 snode = amt_alloc_snode(gnode, &src);
1427                 if (snode) {
1428                         hash = amt_source_hash(tunnel, &snode->source_addr);
1429                         hlist_add_head_rcu(&snode->node, &gnode->sources[hash]);
1430                         tunnel->nr_sources++;
1431                         gnode->nr_sources++;
1432
1433                         if (!gnode->v6)
1434                                 netdev_dbg(snode->gnode->amt->dev,
1435                                            "Add source as NEW %pI4 from %pI4\n",
1436                                            &snode->source_addr.ip4,
1437                                            &gnode->group_addr.ip4);
1438 #if IS_ENABLED(CONFIG_IPV6)
1439                         else
1440                                 netdev_dbg(snode->gnode->amt->dev,
1441                                            "Add source as NEW %pI6 from %pI6\n",
1442                                            &snode->source_addr.ip6,
1443                                            &gnode->group_addr.ip6);
1444 #endif
1445                 }
1446         }
1447 }
1448
/* Router State   Report Rec'd New Router State
 * ------------   ------------ ----------------
 * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)
 *
 * -----------+-----------+-----------+
 *            |    OLD    |    NEW    |
 * -----------+-----------+-----------+
 *    FWD     |     X     |    X+A    |
 * -----------+-----------+-----------+
 *    D_FWD   |     Y     |    Y-A    |
 * -----------+-----------+-----------+
 *    NONE    |           |     A     |
 * -----------+-----------+-----------+
 *
 * a) Received sources are NONE/NEW
 * b) All NONE will be deleted by amt_cleanup_srcs().
 * c) All OLD will be deleted by amt_cleanup_srcs().
 * d) After delete, NEW source will be switched to OLD.
 *
 * NOTE(review): amt_cleanup_srcs() destroys sources only when
 * snode->flags == AMT_SOURCE_OLD; confirm how the NONE sources in (b)
 * reach that state.
 *
 * amt_lookup_act_srcs() - apply @act to a set of @gnode's source nodes.
 * @tunnel: tunnel owning @gnode; passed to amt_lookup_src()/amt_act_src().
 * @gnode:  group whose source table (gnode->sources[]) is examined.
 * @grec:   received group record: struct igmpv3_grec, or struct mld2_grec
 *          when @v6 (without CONFIG_IPV6 a v6 record is ignored).
 * @ops:    set operation.  A is the set of @gnode sources matching @filter;
 *          B is the set of source addresses listed in @grec.
 *          AMT_OPS_INT (A*B), AMT_OPS_UNI (A+B), AMT_OPS_SUB (A-B),
 *          AMT_OPS_SUB_REV (B-A).
 * @filter: status filter used to select candidate source nodes.
 * @act:    action handed to amt_act_src() for every selected node.
 * @v6:     true for an MLDv2 record, false for an IGMPv3 record.
 *
 * Only source nodes that already exist in @gnode are acted on; this
 * function never allocates nodes (new sources are added by amt_add_srcs()).
 */
static void amt_lookup_act_srcs(struct amt_tunnel_list *tunnel,
                                struct amt_group_node *gnode,
                                void *grec,
                                enum amt_ops ops,
                                enum amt_filter filter,
                                enum amt_act act,
                                bool v6)
{
        struct amt_dev *amt = tunnel->amt;
        struct amt_source_node *snode;
        struct igmpv3_grec *igmp_grec;
#if IS_ENABLED(CONFIG_IPV6)
        struct mld2_grec *mld_grec;
#endif
        union amt_addr src = {0,};
        struct hlist_node *t;
        u16 nsrcs;
        int i, j;

        /* Number of source addresses carried in the record. */
        if (!v6) {
                igmp_grec = grec;
                nsrcs = ntohs(igmp_grec->grec_nsrcs);
        } else {
#if IS_ENABLED(CONFIG_IPV6)
                mld_grec = grec;
                nsrcs = ntohs(mld_grec->grec_nsrcs);
#else
        /* No IPv6 support: an MLD record cannot be processed. */
        return;
#endif
        }

        /* Redundant with the {0,} initializer of src above. */
        memset(&src, 0, sizeof(union amt_addr));
        switch (ops) {
        case AMT_OPS_INT:
                /* A*B */
                /* For each address in B, act on it if it is in A. */
                for (i = 0; i < nsrcs; i++) {
                        if (!v6)
                                src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
                        else
                                memcpy(&src.ip6, &mld_grec->grec_src[i],
                                       sizeof(struct in6_addr));
#endif
                        snode = amt_lookup_src(tunnel, gnode, filter, &src);
                        if (!snode)
                                continue;
                        amt_act_src(tunnel, gnode, snode, act);
                }
                break;
        case AMT_OPS_UNI:
                /* A+B */
                /* First act on every node in A ... */
                for (i = 0; i < amt->hash_buckets; i++) {
                        hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
                                                  node) {
                                if (amt_status_filter(snode, filter))
                                        amt_act_src(tunnel, gnode, snode, act);
                        }
                }
                /* ... then on addresses of B found via amt_lookup_src().
                 * NOTE(review): this can only reach nodes that already
                 * exist in @gnode and still match @filter after the loop
                 * above; addresses of B with no node are not created here.
                 */
                for (i = 0; i < nsrcs; i++) {
                        if (!v6)
                                src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
                        else
                                memcpy(&src.ip6, &mld_grec->grec_src[i],
                                       sizeof(struct in6_addr));
#endif
                        snode = amt_lookup_src(tunnel, gnode, filter, &src);
                        if (!snode)
                                continue;
                        amt_act_src(tunnel, gnode, snode, act);
                }
                break;
        case AMT_OPS_SUB:
                /* A-B */
                /* Act on each node in A whose address is not listed in B. */
                for (i = 0; i < amt->hash_buckets; i++) {
                        hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
                                                  node) {
                                if (!amt_status_filter(snode, filter))
                                        continue;
                                for (j = 0; j < nsrcs; j++) {
                                        if (!v6)
                                                src.ip4 = igmp_grec->grec_src[j];
#if IS_ENABLED(CONFIG_IPV6)
                                        else
                                                memcpy(&src.ip6,
                                                       &mld_grec->grec_src[j],
                                                       sizeof(struct in6_addr));
#endif
                                        if (amt_addr_equal(&snode->source_addr,
                                                           &src))
                                                goto out_sub;
                                }
                                amt_act_src(tunnel, gnode, snode, act);
                                continue;
/* Address is in B: skip this node.  ';' gives the label a statement. */
out_sub:;
                        }
                }
                break;
        case AMT_OPS_SUB_REV:
                /* B-A */
                /* NOTE(review): a node is acted on only when the
                 * AMT_FILTER_ALL lookup misses but the @filter lookup
                 * hits; confirm against amt_lookup_src()/amt_status_filter()
                 * that this combination can actually occur.
                 */
                for (i = 0; i < nsrcs; i++) {
                        if (!v6)
                                src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
                        else
                                memcpy(&src.ip6, &mld_grec->grec_src[i],
                                       sizeof(struct in6_addr));
#endif
                        snode = amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL,
                                               &src);
                        if (!snode) {
                                snode = amt_lookup_src(tunnel, gnode,
                                                       filter, &src);
                                if (snode)
                                        amt_act_src(tunnel, gnode, snode, act);
                        }
                }
                break;
        default:
                netdev_dbg(amt->dev, "Invalid type\n");
                return;
        }
}
1591
1592 static void amt_mcast_is_in_handler(struct amt_dev *amt,
1593                                     struct amt_tunnel_list *tunnel,
1594                                     struct amt_group_node *gnode,
1595                                     void *grec, void *zero_grec, bool v6)
1596 {
1597         if (gnode->filter_mode == MCAST_INCLUDE) {
1598 /* Router State   Report Rec'd New Router State        Actions
1599  * ------------   ------------ ----------------        -------
1600  * INCLUDE (A)    IS_IN (B)    INCLUDE (A+B)           (B)=GMI
1601  */
1602                 /* Update IS_IN (B) as FWD/NEW */
1603                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1604                                     AMT_FILTER_NONE_NEW,
1605                                     AMT_ACT_STATUS_FWD_NEW,
1606                                     v6);
1607                 /* Update INCLUDE (A) as NEW */
1608                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1609                                     AMT_FILTER_FWD,
1610                                     AMT_ACT_STATUS_FWD_NEW,
1611                                     v6);
1612                 /* (B)=GMI */
1613                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1614                                     AMT_FILTER_FWD_NEW,
1615                                     AMT_ACT_GMI,
1616                                     v6);
1617         } else {
1618 /* State        Actions
1619  * ------------   ------------ ----------------        -------
1620  * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1621  */
1622                 /* Update (A) in (X, Y) as NONE/NEW */
1623                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1624                                     AMT_FILTER_BOTH,
1625                                     AMT_ACT_STATUS_NONE_NEW,
1626                                     v6);
1627                 /* Update FWD/OLD as FWD/NEW */
1628                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1629                                     AMT_FILTER_FWD,
1630                                     AMT_ACT_STATUS_FWD_NEW,
1631                                     v6);
1632                 /* Update IS_IN (A) as FWD/NEW */
1633                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1634                                     AMT_FILTER_NONE_NEW,
1635                                     AMT_ACT_STATUS_FWD_NEW,
1636                                     v6);
1637                 /* Update EXCLUDE (, Y-A) as D_FWD_NEW */
1638                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
1639                                     AMT_FILTER_D_FWD,
1640                                     AMT_ACT_STATUS_D_FWD_NEW,
1641                                     v6);
1642         }
1643 }
1644
1645 static void amt_mcast_is_ex_handler(struct amt_dev *amt,
1646                                     struct amt_tunnel_list *tunnel,
1647                                     struct amt_group_node *gnode,
1648                                     void *grec, void *zero_grec, bool v6)
1649 {
1650         if (gnode->filter_mode == MCAST_INCLUDE) {
1651 /* Router State   Report Rec'd  New Router State         Actions
1652  * ------------   ------------  ----------------         -------
1653  * INCLUDE (A)    IS_EX (B)     EXCLUDE (A*B,B-A)        (B-A)=0
1654  *                                                       Delete (A-B)
1655  *                                                       Group Timer=GMI
1656  */
1657                 /* EXCLUDE(A*B, ) */
1658                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1659                                     AMT_FILTER_FWD,
1660                                     AMT_ACT_STATUS_FWD_NEW,
1661                                     v6);
1662                 /* EXCLUDE(, B-A) */
1663                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1664                                     AMT_FILTER_FWD,
1665                                     AMT_ACT_STATUS_D_FWD_NEW,
1666                                     v6);
1667                 /* (B-A)=0 */
1668                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1669                                     AMT_FILTER_D_FWD_NEW,
1670                                     AMT_ACT_GMI_ZERO,
1671                                     v6);
1672                 /* Group Timer=GMI */
1673                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1674                                       msecs_to_jiffies(amt_gmi(amt))))
1675                         dev_hold(amt->dev);
1676                 gnode->filter_mode = MCAST_EXCLUDE;
1677                 /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
1678         } else {
1679 /* Router State   Report Rec'd  New Router State        Actions
1680  * ------------   ------------  ----------------        -------
1681  * EXCLUDE (X,Y)  IS_EX (A)     EXCLUDE (A-Y,Y*A)       (A-X-Y)=GMI
1682  *                                                      Delete (X-A)
1683  *                                                      Delete (Y-A)
1684  *                                                      Group Timer=GMI
1685  */
1686                 /* EXCLUDE (A-Y, ) */
1687                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1688                                     AMT_FILTER_D_FWD,
1689                                     AMT_ACT_STATUS_FWD_NEW,
1690                                     v6);
1691                 /* EXCLUDE (, Y*A ) */
1692                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1693                                     AMT_FILTER_D_FWD,
1694                                     AMT_ACT_STATUS_D_FWD_NEW,
1695                                     v6);
1696                 /* (A-X-Y)=GMI */
1697                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1698                                     AMT_FILTER_BOTH_NEW,
1699                                     AMT_ACT_GMI,
1700                                     v6);
1701                 /* Group Timer=GMI */
1702                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1703                                       msecs_to_jiffies(amt_gmi(amt))))
1704                         dev_hold(amt->dev);
1705                 /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
1706         }
1707 }
1708
1709 static void amt_mcast_to_in_handler(struct amt_dev *amt,
1710                                     struct amt_tunnel_list *tunnel,
1711                                     struct amt_group_node *gnode,
1712                                     void *grec, void *zero_grec, bool v6)
1713 {
1714         if (gnode->filter_mode == MCAST_INCLUDE) {
1715 /* Router State   Report Rec'd New Router State        Actions
1716  * ------------   ------------ ----------------        -------
1717  * INCLUDE (A)    TO_IN (B)    INCLUDE (A+B)           (B)=GMI
1718  *                                                     Send Q(G,A-B)
1719  */
1720                 /* Update TO_IN (B) sources as FWD/NEW */
1721                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1722                                     AMT_FILTER_NONE_NEW,
1723                                     AMT_ACT_STATUS_FWD_NEW,
1724                                     v6);
1725                 /* Update INCLUDE (A) sources as NEW */
1726                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1727                                     AMT_FILTER_FWD,
1728                                     AMT_ACT_STATUS_FWD_NEW,
1729                                     v6);
1730                 /* (B)=GMI */
1731                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1732                                     AMT_FILTER_FWD_NEW,
1733                                     AMT_ACT_GMI,
1734                                     v6);
1735         } else {
1736 /* Router State   Report Rec'd New Router State        Actions
1737  * ------------   ------------ ----------------        -------
1738  * EXCLUDE (X,Y)  TO_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1739  *                                                     Send Q(G,X-A)
1740  *                                                     Send Q(G)
1741  */
1742                 /* Update TO_IN (A) sources as FWD/NEW */
1743                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1744                                     AMT_FILTER_NONE_NEW,
1745                                     AMT_ACT_STATUS_FWD_NEW,
1746                                     v6);
1747                 /* Update EXCLUDE(X,) sources as FWD/NEW */
1748                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1749                                     AMT_FILTER_FWD,
1750                                     AMT_ACT_STATUS_FWD_NEW,
1751                                     v6);
1752                 /* EXCLUDE (, Y-A)
1753                  * (A) are already switched to FWD_NEW.
1754                  * So, D_FWD/OLD -> D_FWD/NEW is okay.
1755                  */
1756                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1757                                     AMT_FILTER_D_FWD,
1758                                     AMT_ACT_STATUS_D_FWD_NEW,
1759                                     v6);
1760                 /* (A)=GMI
1761                  * Only FWD_NEW will have (A) sources.
1762                  */
1763                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1764                                     AMT_FILTER_FWD_NEW,
1765                                     AMT_ACT_GMI,
1766                                     v6);
1767         }
1768 }
1769
1770 static void amt_mcast_to_ex_handler(struct amt_dev *amt,
1771                                     struct amt_tunnel_list *tunnel,
1772                                     struct amt_group_node *gnode,
1773                                     void *grec, void *zero_grec, bool v6)
1774 {
1775         if (gnode->filter_mode == MCAST_INCLUDE) {
1776 /* Router State   Report Rec'd New Router State        Actions
1777  * ------------   ------------ ----------------        -------
1778  * INCLUDE (A)    TO_EX (B)    EXCLUDE (A*B,B-A)       (B-A)=0
1779  *                                                     Delete (A-B)
1780  *                                                     Send Q(G,A*B)
1781  *                                                     Group Timer=GMI
1782  */
1783                 /* EXCLUDE (A*B, ) */
1784                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1785                                     AMT_FILTER_FWD,
1786                                     AMT_ACT_STATUS_FWD_NEW,
1787                                     v6);
1788                 /* EXCLUDE (, B-A) */
1789                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1790                                     AMT_FILTER_FWD,
1791                                     AMT_ACT_STATUS_D_FWD_NEW,
1792                                     v6);
1793                 /* (B-A)=0 */
1794                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1795                                     AMT_FILTER_D_FWD_NEW,
1796                                     AMT_ACT_GMI_ZERO,
1797                                     v6);
1798                 /* Group Timer=GMI */
1799                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1800                                       msecs_to_jiffies(amt_gmi(amt))))
1801                         dev_hold(amt->dev);
1802                 gnode->filter_mode = MCAST_EXCLUDE;
1803                 /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
1804         } else {
1805 /* Router State   Report Rec'd New Router State        Actions
1806  * ------------   ------------ ----------------        -------
1807  * EXCLUDE (X,Y)  TO_EX (A)    EXCLUDE (A-Y,Y*A)       (A-X-Y)=Group Timer
1808  *                                                     Delete (X-A)
1809  *                                                     Delete (Y-A)
1810  *                                                     Send Q(G,A-Y)
1811  *                                                     Group Timer=GMI
1812  */
1813                 /* Update (A-X-Y) as NONE/OLD */
1814                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1815                                     AMT_FILTER_BOTH,
1816                                     AMT_ACT_GT,
1817                                     v6);
1818                 /* EXCLUDE (A-Y, ) */
1819                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1820                                     AMT_FILTER_D_FWD,
1821                                     AMT_ACT_STATUS_FWD_NEW,
1822                                     v6);
1823                 /* EXCLUDE (, Y*A) */
1824                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1825                                     AMT_FILTER_D_FWD,
1826                                     AMT_ACT_STATUS_D_FWD_NEW,
1827                                     v6);
1828                 /* Group Timer=GMI */
1829                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1830                                       msecs_to_jiffies(amt_gmi(amt))))
1831                         dev_hold(amt->dev);
1832                 /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
1833         }
1834 }
1835
1836 static void amt_mcast_allow_handler(struct amt_dev *amt,
1837                                     struct amt_tunnel_list *tunnel,
1838                                     struct amt_group_node *gnode,
1839                                     void *grec, void *zero_grec, bool v6)
1840 {
1841         if (gnode->filter_mode == MCAST_INCLUDE) {
1842 /* Router State   Report Rec'd New Router State        Actions
1843  * ------------   ------------ ----------------        -------
1844  * INCLUDE (A)    ALLOW (B)    INCLUDE (A+B)           (B)=GMI
1845  */
1846                 /* INCLUDE (A+B) */
1847                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1848                                     AMT_FILTER_FWD,
1849                                     AMT_ACT_STATUS_FWD_NEW,
1850                                     v6);
1851                 /* (B)=GMI */
1852                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1853                                     AMT_FILTER_FWD_NEW,
1854                                     AMT_ACT_GMI,
1855                                     v6);
1856         } else {
1857 /* Router State   Report Rec'd New Router State        Actions
1858  * ------------   ------------ ----------------        -------
1859  * EXCLUDE (X,Y)  ALLOW (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1860  */
1861                 /* EXCLUDE (X+A, ) */
1862                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1863                                     AMT_FILTER_FWD,
1864                                     AMT_ACT_STATUS_FWD_NEW,
1865                                     v6);
1866                 /* EXCLUDE (, Y-A) */
1867                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
1868                                     AMT_FILTER_D_FWD,
1869                                     AMT_ACT_STATUS_D_FWD_NEW,
1870                                     v6);
1871                 /* (A)=GMI
1872                  * All (A) source are now FWD/NEW status.
1873                  */
1874                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1875                                     AMT_FILTER_FWD_NEW,
1876                                     AMT_ACT_GMI,
1877                                     v6);
1878         }
1879 }
1880
1881 static void amt_mcast_block_handler(struct amt_dev *amt,
1882                                     struct amt_tunnel_list *tunnel,
1883                                     struct amt_group_node *gnode,
1884                                     void *grec, void *zero_grec, bool v6)
1885 {
1886         if (gnode->filter_mode == MCAST_INCLUDE) {
1887 /* Router State   Report Rec'd New Router State        Actions
1888  * ------------   ------------ ----------------        -------
1889  * INCLUDE (A)    BLOCK (B)    INCLUDE (A)             Send Q(G,A*B)
1890  */
1891                 /* INCLUDE (A) */
1892                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1893                                     AMT_FILTER_FWD,
1894                                     AMT_ACT_STATUS_FWD_NEW,
1895                                     v6);
1896         } else {
1897 /* Router State   Report Rec'd New Router State        Actions
1898  * ------------   ------------ ----------------        -------
1899  * EXCLUDE (X,Y)  BLOCK (A)    EXCLUDE (X+(A-Y),Y)     (A-X-Y)=Group Timer
1900  *                                                     Send Q(G,A-Y)
1901  */
1902                 /* (A-X-Y)=Group Timer */
1903                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1904                                     AMT_FILTER_BOTH,
1905                                     AMT_ACT_GT,
1906                                     v6);
1907                 /* EXCLUDE (X, ) */
1908                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1909                                     AMT_FILTER_FWD,
1910                                     AMT_ACT_STATUS_FWD_NEW,
1911                                     v6);
1912                 /* EXCLUDE (X+(A-Y) */
1913                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1914                                     AMT_FILTER_D_FWD,
1915                                     AMT_ACT_STATUS_FWD_NEW,
1916                                     v6);
1917                 /* EXCLUDE (, Y) */
1918                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1919                                     AMT_FILTER_D_FWD,
1920                                     AMT_ACT_STATUS_D_FWD_NEW,
1921                                     v6);
1922         }
1923 }
1924
1925 /* RFC 3376
1926  * 7.3.2. In the Presence of Older Version Group Members
1927  *
1928  * When Group Compatibility Mode is IGMPv2, a router internally
1929  * translates the following IGMPv2 messages for that group to their
1930  * IGMPv3 equivalents:
1931  *
1932  * IGMPv2 Message                IGMPv3 Equivalent
1933  * --------------                -----------------
1934  * Report                        IS_EX( {} )
1935  * Leave                         TO_IN( {} )
1936  */
1937 static void amt_igmpv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1938                                       struct amt_tunnel_list *tunnel)
1939 {
1940         struct igmphdr *ih = igmp_hdr(skb);
1941         struct iphdr *iph = ip_hdr(skb);
1942         struct amt_group_node *gnode;
1943         union amt_addr group, host;
1944
1945         memset(&group, 0, sizeof(union amt_addr));
1946         group.ip4 = ih->group;
1947         memset(&host, 0, sizeof(union amt_addr));
1948         host.ip4 = iph->saddr;
1949
1950         gnode = amt_lookup_group(tunnel, &group, &host, false);
1951         if (!gnode) {
1952                 gnode = amt_add_group(amt, tunnel, &group, &host, false);
1953                 if (!IS_ERR(gnode)) {
1954                         gnode->filter_mode = MCAST_EXCLUDE;
1955                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1956                                               msecs_to_jiffies(amt_gmi(amt))))
1957                                 dev_hold(amt->dev);
1958                 }
1959         }
1960 }
1961
1962 /* RFC 3376
1963  * 7.3.2. In the Presence of Older Version Group Members
1964  *
1965  * When Group Compatibility Mode is IGMPv2, a router internally
1966  * translates the following IGMPv2 messages for that group to their
1967  * IGMPv3 equivalents:
1968  *
1969  * IGMPv2 Message                IGMPv3 Equivalent
1970  * --------------                -----------------
1971  * Report                        IS_EX( {} )
1972  * Leave                         TO_IN( {} )
1973  */
1974 static void amt_igmpv2_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
1975                                      struct amt_tunnel_list *tunnel)
1976 {
1977         struct igmphdr *ih = igmp_hdr(skb);
1978         struct iphdr *iph = ip_hdr(skb);
1979         struct amt_group_node *gnode;
1980         union amt_addr group, host;
1981
1982         memset(&group, 0, sizeof(union amt_addr));
1983         group.ip4 = ih->group;
1984         memset(&host, 0, sizeof(union amt_addr));
1985         host.ip4 = iph->saddr;
1986
1987         gnode = amt_lookup_group(tunnel, &group, &host, false);
1988         if (gnode)
1989                 amt_del_group(amt, gnode);
1990 }
1991
1992 static void amt_igmpv3_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1993                                       struct amt_tunnel_list *tunnel)
1994 {
1995         struct igmpv3_report *ihrv3 = igmpv3_report_hdr(skb);
1996         int len = skb_transport_offset(skb) + sizeof(*ihrv3);
1997         void *zero_grec = (void *)&igmpv3_zero_grec;
1998         struct iphdr *iph = ip_hdr(skb);
1999         struct amt_group_node *gnode;
2000         union amt_addr group, host;
2001         struct igmpv3_grec *grec;
2002         u16 nsrcs;
2003         int i;
2004
2005         for (i = 0; i < ntohs(ihrv3->ngrec); i++) {
2006                 len += sizeof(*grec);
2007                 if (!ip_mc_may_pull(skb, len))
2008                         break;
2009
2010                 grec = (void *)(skb->data + len - sizeof(*grec));
2011                 nsrcs = ntohs(grec->grec_nsrcs);
2012
2013                 len += nsrcs * sizeof(__be32);
2014                 if (!ip_mc_may_pull(skb, len))
2015                         break;
2016
2017                 memset(&group, 0, sizeof(union amt_addr));
2018                 group.ip4 = grec->grec_mca;
2019                 memset(&host, 0, sizeof(union amt_addr));
2020                 host.ip4 = iph->saddr;
2021                 gnode = amt_lookup_group(tunnel, &group, &host, false);
2022                 if (!gnode) {
2023                         gnode = amt_add_group(amt, tunnel, &group, &host,
2024                                               false);
2025                         if (IS_ERR(gnode))
2026                                 continue;
2027                 }
2028
2029                 amt_add_srcs(amt, tunnel, gnode, grec, false);
2030                 switch (grec->grec_type) {
2031                 case IGMPV3_MODE_IS_INCLUDE:
2032                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2033                                                 zero_grec, false);
2034                         break;
2035                 case IGMPV3_MODE_IS_EXCLUDE:
2036                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2037                                                 zero_grec, false);
2038                         break;
2039                 case IGMPV3_CHANGE_TO_INCLUDE:
2040                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2041                                                 zero_grec, false);
2042                         break;
2043                 case IGMPV3_CHANGE_TO_EXCLUDE:
2044                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2045                                                 zero_grec, false);
2046                         break;
2047                 case IGMPV3_ALLOW_NEW_SOURCES:
2048                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2049                                                 zero_grec, false);
2050                         break;
2051                 case IGMPV3_BLOCK_OLD_SOURCES:
2052                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2053                                                 zero_grec, false);
2054                         break;
2055                 default:
2056                         break;
2057                 }
2058                 amt_cleanup_srcs(amt, tunnel, gnode);
2059         }
2060 }
2061
2062 /* caller held tunnel->lock */
2063 static void amt_igmp_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2064                                     struct amt_tunnel_list *tunnel)
2065 {
2066         struct igmphdr *ih = igmp_hdr(skb);
2067
2068         switch (ih->type) {
2069         case IGMPV3_HOST_MEMBERSHIP_REPORT:
2070                 amt_igmpv3_report_handler(amt, skb, tunnel);
2071                 break;
2072         case IGMPV2_HOST_MEMBERSHIP_REPORT:
2073                 amt_igmpv2_report_handler(amt, skb, tunnel);
2074                 break;
2075         case IGMP_HOST_LEAVE_MESSAGE:
2076                 amt_igmpv2_leave_handler(amt, skb, tunnel);
2077                 break;
2078         default:
2079                 break;
2080         }
2081 }
2082
2083 #if IS_ENABLED(CONFIG_IPV6)
2084 /* RFC 3810
2085  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2086  *
2087  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2088  * using the MLDv2 protocol for that multicast address.  When Multicast
2089  * Address Compatibility Mode is MLDv1, a router internally translates
2090  * the following MLDv1 messages for that multicast address to their
2091  * MLDv2 equivalents:
2092  *
2093  * MLDv1 Message                 MLDv2 Equivalent
2094  * --------------                -----------------
2095  * Report                        IS_EX( {} )
2096  * Done                          TO_IN( {} )
2097  */
2098 static void amt_mldv1_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2099                                      struct amt_tunnel_list *tunnel)
2100 {
2101         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2102         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2103         struct amt_group_node *gnode;
2104         union amt_addr group, host;
2105
2106         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2107         memcpy(&host.ip6, &ip6h->saddr, sizeof(struct in6_addr));
2108
2109         gnode = amt_lookup_group(tunnel, &group, &host, true);
2110         if (!gnode) {
2111                 gnode = amt_add_group(amt, tunnel, &group, &host, true);
2112                 if (!IS_ERR(gnode)) {
2113                         gnode->filter_mode = MCAST_EXCLUDE;
2114                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
2115                                               msecs_to_jiffies(amt_gmi(amt))))
2116                                 dev_hold(amt->dev);
2117                 }
2118         }
2119 }
2120
2121 /* RFC 3810
2122  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2123  *
2124  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2125  * using the MLDv2 protocol for that multicast address.  When Multicast
2126  * Address Compatibility Mode is MLDv1, a router internally translates
2127  * the following MLDv1 messages for that multicast address to their
2128  * MLDv2 equivalents:
2129  *
2130  * MLDv1 Message                 MLDv2 Equivalent
2131  * --------------                -----------------
2132  * Report                        IS_EX( {} )
2133  * Done                          TO_IN( {} )
2134  */
2135 static void amt_mldv1_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
2136                                     struct amt_tunnel_list *tunnel)
2137 {
2138         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2139         struct iphdr *iph = ip_hdr(skb);
2140         struct amt_group_node *gnode;
2141         union amt_addr group, host;
2142
2143         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2144         memset(&host, 0, sizeof(union amt_addr));
2145         host.ip4 = iph->saddr;
2146
2147         gnode = amt_lookup_group(tunnel, &group, &host, true);
2148         if (gnode) {
2149                 amt_del_group(amt, gnode);
2150                 return;
2151         }
2152 }
2153
2154 static void amt_mldv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2155                                      struct amt_tunnel_list *tunnel)
2156 {
2157         struct mld2_report *mld2r = (struct mld2_report *)icmp6_hdr(skb);
2158         int len = skb_transport_offset(skb) + sizeof(*mld2r);
2159         void *zero_grec = (void *)&mldv2_zero_grec;
2160         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2161         struct amt_group_node *gnode;
2162         union amt_addr group, host;
2163         struct mld2_grec *grec;
2164         u16 nsrcs;
2165         int i;
2166
2167         for (i = 0; i < ntohs(mld2r->mld2r_ngrec); i++) {
2168                 len += sizeof(*grec);
2169                 if (!ipv6_mc_may_pull(skb, len))
2170                         break;
2171
2172                 grec = (void *)(skb->data + len - sizeof(*grec));
2173                 nsrcs = ntohs(grec->grec_nsrcs);
2174
2175                 len += nsrcs * sizeof(struct in6_addr);
2176                 if (!ipv6_mc_may_pull(skb, len))
2177                         break;
2178
2179                 memset(&group, 0, sizeof(union amt_addr));
2180                 group.ip6 = grec->grec_mca;
2181                 memset(&host, 0, sizeof(union amt_addr));
2182                 host.ip6 = ip6h->saddr;
2183                 gnode = amt_lookup_group(tunnel, &group, &host, true);
2184                 if (!gnode) {
2185                         gnode = amt_add_group(amt, tunnel, &group, &host,
2186                                               ETH_P_IPV6);
2187                         if (IS_ERR(gnode))
2188                                 continue;
2189                 }
2190
2191                 amt_add_srcs(amt, tunnel, gnode, grec, true);
2192                 switch (grec->grec_type) {
2193                 case MLD2_MODE_IS_INCLUDE:
2194                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2195                                                 zero_grec, true);
2196                         break;
2197                 case MLD2_MODE_IS_EXCLUDE:
2198                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2199                                                 zero_grec, true);
2200                         break;
2201                 case MLD2_CHANGE_TO_INCLUDE:
2202                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2203                                                 zero_grec, true);
2204                         break;
2205                 case MLD2_CHANGE_TO_EXCLUDE:
2206                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2207                                                 zero_grec, true);
2208                         break;
2209                 case MLD2_ALLOW_NEW_SOURCES:
2210                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2211                                                 zero_grec, true);
2212                         break;
2213                 case MLD2_BLOCK_OLD_SOURCES:
2214                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2215                                                 zero_grec, true);
2216                         break;
2217                 default:
2218                         break;
2219                 }
2220                 amt_cleanup_srcs(amt, tunnel, gnode);
2221         }
2222 }
2223
2224 /* caller held tunnel->lock */
2225 static void amt_mld_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2226                                    struct amt_tunnel_list *tunnel)
2227 {
2228         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2229
2230         switch (mld->mld_type) {
2231         case ICMPV6_MGM_REPORT:
2232                 amt_mldv1_report_handler(amt, skb, tunnel);
2233                 break;
2234         case ICMPV6_MLD2_REPORT:
2235                 amt_mldv2_report_handler(amt, skb, tunnel);
2236                 break;
2237         case ICMPV6_MGM_REDUCTION:
2238                 amt_mldv1_leave_handler(amt, skb, tunnel);
2239                 break;
2240         default:
2241                 break;
2242         }
2243 }
2244 #endif
2245
2246 static bool amt_advertisement_handler(struct amt_dev *amt, struct sk_buff *skb)
2247 {
2248         struct amt_header_advertisement *amta;
2249         int hdr_size;
2250
2251         hdr_size = sizeof(*amta) + sizeof(struct udphdr);
2252         if (!pskb_may_pull(skb, hdr_size))
2253                 return true;
2254
2255         amta = (struct amt_header_advertisement *)(udp_hdr(skb) + 1);
2256         if (!amta->ip4)
2257                 return true;
2258
2259         if (amta->reserved || amta->version)
2260                 return true;
2261
2262         if (ipv4_is_loopback(amta->ip4) || ipv4_is_multicast(amta->ip4) ||
2263             ipv4_is_zeronet(amta->ip4))
2264                 return true;
2265
2266         if (amt->status != AMT_STATUS_SENT_DISCOVERY ||
2267             amt->nonce != amta->nonce)
2268                 return true;
2269
2270         amt->remote_ip = amta->ip4;
2271         netdev_dbg(amt->dev, "advertised remote ip = %pI4\n", &amt->remote_ip);
2272         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2273
2274         amt_update_gw_status(amt, AMT_STATUS_RECEIVED_ADVERTISEMENT, true);
2275         return false;
2276 }
2277
2278 static bool amt_multicast_data_handler(struct amt_dev *amt, struct sk_buff *skb)
2279 {
2280         struct amt_header_mcast_data *amtmd;
2281         int hdr_size, len, err;
2282         struct ethhdr *eth;
2283         struct iphdr *iph;
2284
2285         if (READ_ONCE(amt->status) != AMT_STATUS_SENT_UPDATE)
2286                 return true;
2287
2288         hdr_size = sizeof(*amtmd) + sizeof(struct udphdr);
2289         if (!pskb_may_pull(skb, hdr_size))
2290                 return true;
2291
2292         amtmd = (struct amt_header_mcast_data *)(udp_hdr(skb) + 1);
2293         if (amtmd->reserved || amtmd->version)
2294                 return true;
2295
2296         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_IP), false))
2297                 return true;
2298
2299         skb_reset_network_header(skb);
2300         skb_push(skb, sizeof(*eth));
2301         skb_reset_mac_header(skb);
2302         skb_pull(skb, sizeof(*eth));
2303         eth = eth_hdr(skb);
2304
2305         if (!pskb_may_pull(skb, sizeof(*iph)))
2306                 return true;
2307         iph = ip_hdr(skb);
2308
2309         if (iph->version == 4) {
2310                 if (!ipv4_is_multicast(iph->daddr))
2311                         return true;
2312                 skb->protocol = htons(ETH_P_IP);
2313                 eth->h_proto = htons(ETH_P_IP);
2314                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2315 #if IS_ENABLED(CONFIG_IPV6)
2316         } else if (iph->version == 6) {
2317                 struct ipv6hdr *ip6h;
2318
2319                 if (!pskb_may_pull(skb, sizeof(*ip6h)))
2320                         return true;
2321
2322                 ip6h = ipv6_hdr(skb);
2323                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2324                         return true;
2325                 skb->protocol = htons(ETH_P_IPV6);
2326                 eth->h_proto = htons(ETH_P_IPV6);
2327                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2328 #endif
2329         } else {
2330                 return true;
2331         }
2332
2333         skb->pkt_type = PACKET_MULTICAST;
2334         skb->ip_summed = CHECKSUM_NONE;
2335         len = skb->len;
2336         err = gro_cells_receive(&amt->gro_cells, skb);
2337         if (likely(err == NET_RX_SUCCESS))
2338                 dev_sw_netstats_rx_add(amt->dev, len);
2339         else
2340                 amt->dev->stats.rx_dropped++;
2341
2342         return false;
2343 }
2344
2345 static bool amt_membership_query_handler(struct amt_dev *amt,
2346                                          struct sk_buff *skb)
2347 {
2348         struct amt_header_membership_query *amtmq;
2349         struct igmpv3_query *ihv3;
2350         struct ethhdr *eth, *oeth;
2351         struct iphdr *iph;
2352         int hdr_size, len;
2353
2354         hdr_size = sizeof(*amtmq) + sizeof(struct udphdr);
2355         if (!pskb_may_pull(skb, hdr_size))
2356                 return true;
2357
2358         amtmq = (struct amt_header_membership_query *)(udp_hdr(skb) + 1);
2359         if (amtmq->reserved || amtmq->version)
2360                 return true;
2361
2362         if (amtmq->nonce != amt->nonce)
2363                 return true;
2364
2365         hdr_size -= sizeof(*eth);
2366         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_TEB), false))
2367                 return true;
2368
2369         oeth = eth_hdr(skb);
2370         skb_reset_mac_header(skb);
2371         skb_pull(skb, sizeof(*eth));
2372         skb_reset_network_header(skb);
2373         eth = eth_hdr(skb);
2374         if (!pskb_may_pull(skb, sizeof(*iph)))
2375                 return true;
2376
2377         iph = ip_hdr(skb);
2378         if (iph->version == 4) {
2379                 if (READ_ONCE(amt->ready4))
2380                         return true;
2381
2382                 if (!pskb_may_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS +
2383                                    sizeof(*ihv3)))
2384                         return true;
2385
2386                 if (!ipv4_is_multicast(iph->daddr))
2387                         return true;
2388
2389                 ihv3 = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2390                 skb_reset_transport_header(skb);
2391                 skb_push(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2392                 WRITE_ONCE(amt->ready4, true);
2393                 amt->mac = amtmq->response_mac;
2394                 amt->req_cnt = 0;
2395                 amt->qi = ihv3->qqic;
2396                 skb->protocol = htons(ETH_P_IP);
2397                 eth->h_proto = htons(ETH_P_IP);
2398                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2399 #if IS_ENABLED(CONFIG_IPV6)
2400         } else if (iph->version == 6) {
2401                 struct mld2_query *mld2q;
2402                 struct ipv6hdr *ip6h;
2403
2404                 if (READ_ONCE(amt->ready6))
2405                         return true;
2406
2407                 if (!pskb_may_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS +
2408                                    sizeof(*mld2q)))
2409                         return true;
2410
2411                 ip6h = ipv6_hdr(skb);
2412                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2413                         return true;
2414
2415                 mld2q = skb_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2416                 skb_reset_transport_header(skb);
2417                 skb_push(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2418                 WRITE_ONCE(amt->ready6, true);
2419                 amt->mac = amtmq->response_mac;
2420                 amt->req_cnt = 0;
2421                 amt->qi = mld2q->mld2q_qqic;
2422                 skb->protocol = htons(ETH_P_IPV6);
2423                 eth->h_proto = htons(ETH_P_IPV6);
2424                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2425 #endif
2426         } else {
2427                 return true;
2428         }
2429
2430         ether_addr_copy(eth->h_source, oeth->h_source);
2431         skb->pkt_type = PACKET_MULTICAST;
2432         skb->ip_summed = CHECKSUM_NONE;
2433         len = skb->len;
2434         local_bh_disable();
2435         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2436                 amt_update_gw_status(amt, AMT_STATUS_RECEIVED_QUERY, true);
2437                 dev_sw_netstats_rx_add(amt->dev, len);
2438         } else {
2439                 amt->dev->stats.rx_dropped++;
2440         }
2441         local_bh_enable();
2442
2443         return false;
2444 }
2445
2446 static bool amt_update_handler(struct amt_dev *amt, struct sk_buff *skb)
2447 {
2448         struct amt_header_membership_update *amtmu;
2449         struct amt_tunnel_list *tunnel;
2450         struct ethhdr *eth;
2451         struct iphdr *iph;
2452         int len, hdr_size;
2453
2454         iph = ip_hdr(skb);
2455
2456         hdr_size = sizeof(*amtmu) + sizeof(struct udphdr);
2457         if (!pskb_may_pull(skb, hdr_size))
2458                 return true;
2459
2460         amtmu = (struct amt_header_membership_update *)(udp_hdr(skb) + 1);
2461         if (amtmu->reserved || amtmu->version)
2462                 return true;
2463
2464         if (iptunnel_pull_header(skb, hdr_size, skb->protocol, false))
2465                 return true;
2466
2467         skb_reset_network_header(skb);
2468
2469         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
2470                 if (tunnel->ip4 == iph->saddr) {
2471                         if ((amtmu->nonce == tunnel->nonce &&
2472                              amtmu->response_mac == tunnel->mac)) {
2473                                 mod_delayed_work(amt_wq, &tunnel->gc_wq,
2474                                                  msecs_to_jiffies(amt_gmi(amt))
2475                                                                   * 3);
2476                                 goto report;
2477                         } else {
2478                                 netdev_dbg(amt->dev, "Invalid MAC\n");
2479                                 return true;
2480                         }
2481                 }
2482         }
2483
2484         return true;
2485
2486 report:
2487         if (!pskb_may_pull(skb, sizeof(*iph)))
2488                 return true;
2489
2490         iph = ip_hdr(skb);
2491         if (iph->version == 4) {
2492                 if (ip_mc_check_igmp(skb)) {
2493                         netdev_dbg(amt->dev, "Invalid IGMP\n");
2494                         return true;
2495                 }
2496
2497                 spin_lock_bh(&tunnel->lock);
2498                 amt_igmp_report_handler(amt, skb, tunnel);
2499                 spin_unlock_bh(&tunnel->lock);
2500
2501                 skb_push(skb, sizeof(struct ethhdr));
2502                 skb_reset_mac_header(skb);
2503                 eth = eth_hdr(skb);
2504                 skb->protocol = htons(ETH_P_IP);
2505                 eth->h_proto = htons(ETH_P_IP);
2506                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2507 #if IS_ENABLED(CONFIG_IPV6)
2508         } else if (iph->version == 6) {
2509                 struct ipv6hdr *ip6h = ipv6_hdr(skb);
2510
2511                 if (ipv6_mc_check_mld(skb)) {
2512                         netdev_dbg(amt->dev, "Invalid MLD\n");
2513                         return true;
2514                 }
2515
2516                 spin_lock_bh(&tunnel->lock);
2517                 amt_mld_report_handler(amt, skb, tunnel);
2518                 spin_unlock_bh(&tunnel->lock);
2519
2520                 skb_push(skb, sizeof(struct ethhdr));
2521                 skb_reset_mac_header(skb);
2522                 eth = eth_hdr(skb);
2523                 skb->protocol = htons(ETH_P_IPV6);
2524                 eth->h_proto = htons(ETH_P_IPV6);
2525                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2526 #endif
2527         } else {
2528                 netdev_dbg(amt->dev, "Unsupported Protocol\n");
2529                 return true;
2530         }
2531
2532         skb_pull(skb, sizeof(struct ethhdr));
2533         skb->pkt_type = PACKET_MULTICAST;
2534         skb->ip_summed = CHECKSUM_NONE;
2535         len = skb->len;
2536         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2537                 amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_UPDATE,
2538                                         true);
2539                 dev_sw_netstats_rx_add(amt->dev, len);
2540         } else {
2541                 amt->dev->stats.rx_dropped++;
2542         }
2543
2544         return false;
2545 }
2546
2547 static void amt_send_advertisement(struct amt_dev *amt, __be32 nonce,
2548                                    __be32 daddr, __be16 dport)
2549 {
2550         struct amt_header_advertisement *amta;
2551         int hlen, tlen, offset;
2552         struct socket *sock;
2553         struct udphdr *udph;
2554         struct sk_buff *skb;
2555         struct iphdr *iph;
2556         struct rtable *rt;
2557         struct flowi4 fl4;
2558         u32 len;
2559         int err;
2560
2561         rcu_read_lock();
2562         sock = rcu_dereference(amt->sock);
2563         if (!sock)
2564                 goto out;
2565
2566         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
2567                 goto out;
2568
2569         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
2570                                    daddr, amt->local_ip,
2571                                    dport, amt->relay_port,
2572                                    IPPROTO_UDP, 0,
2573                                    amt->stream_dev->ifindex);
2574         if (IS_ERR(rt)) {
2575                 amt->dev->stats.tx_errors++;
2576                 goto out;
2577         }
2578
2579         hlen = LL_RESERVED_SPACE(amt->dev);
2580         tlen = amt->dev->needed_tailroom;
2581         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2582         skb = netdev_alloc_skb_ip_align(amt->dev, len);
2583         if (!skb) {
2584                 ip_rt_put(rt);
2585                 amt->dev->stats.tx_errors++;
2586                 goto out;
2587         }
2588
2589         skb->priority = TC_PRIO_CONTROL;
2590         skb_dst_set(skb, &rt->dst);
2591
2592         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2593         skb_reset_network_header(skb);
2594         skb_put(skb, len);
2595         amta = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
2596         amta->version   = 0;
2597         amta->type      = AMT_MSG_ADVERTISEMENT;
2598         amta->reserved  = 0;
2599         amta->nonce     = nonce;
2600         amta->ip4       = amt->local_ip;
2601         skb_push(skb, sizeof(*udph));
2602         skb_reset_transport_header(skb);
2603         udph            = udp_hdr(skb);
2604         udph->source    = amt->relay_port;
2605         udph->dest      = dport;
2606         udph->len       = htons(sizeof(*amta) + sizeof(*udph));
2607         udph->check     = 0;
2608         offset = skb_transport_offset(skb);
2609         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
2610         udph->check = csum_tcpudp_magic(amt->local_ip, daddr,
2611                                         sizeof(*udph) + sizeof(*amta),
2612                                         IPPROTO_UDP, skb->csum);
2613
2614         skb_push(skb, sizeof(*iph));
2615         iph             = ip_hdr(skb);
2616         iph->version    = 4;
2617         iph->ihl        = (sizeof(struct iphdr)) >> 2;
2618         iph->tos        = AMT_TOS;
2619         iph->frag_off   = 0;
2620         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
2621         iph->daddr      = daddr;
2622         iph->saddr      = amt->local_ip;
2623         iph->protocol   = IPPROTO_UDP;
2624         iph->tot_len    = htons(len);
2625
2626         skb->ip_summed = CHECKSUM_NONE;
2627         ip_select_ident(amt->net, skb, NULL);
2628         ip_send_check(iph);
2629         err = ip_local_out(amt->net, sock->sk, skb);
2630         if (unlikely(net_xmit_eval(err)))
2631                 amt->dev->stats.tx_errors++;
2632
2633 out:
2634         rcu_read_unlock();
2635 }
2636
2637 static bool amt_discovery_handler(struct amt_dev *amt, struct sk_buff *skb)
2638 {
2639         struct amt_header_discovery *amtd;
2640         struct udphdr *udph;
2641         struct iphdr *iph;
2642
2643         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtd)))
2644                 return true;
2645
2646         iph = ip_hdr(skb);
2647         udph = udp_hdr(skb);
2648         amtd = (struct amt_header_discovery *)(udp_hdr(skb) + 1);
2649
2650         if (amtd->reserved || amtd->version)
2651                 return true;
2652
2653         amt_send_advertisement(amt, amtd->nonce, iph->saddr, udph->source);
2654
2655         return false;
2656 }
2657
2658 static bool amt_request_handler(struct amt_dev *amt, struct sk_buff *skb)
2659 {
2660         struct amt_header_request *amtrh;
2661         struct amt_tunnel_list *tunnel;
2662         unsigned long long key;
2663         struct udphdr *udph;
2664         struct iphdr *iph;
2665         u64 mac;
2666         int i;
2667
2668         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtrh)))
2669                 return true;
2670
2671         iph = ip_hdr(skb);
2672         udph = udp_hdr(skb);
2673         amtrh = (struct amt_header_request *)(udp_hdr(skb) + 1);
2674
2675         if (amtrh->reserved1 || amtrh->reserved2 || amtrh->version)
2676                 return true;
2677
2678         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list)
2679                 if (tunnel->ip4 == iph->saddr)
2680                         goto send;
2681
2682         spin_lock_bh(&amt->lock);
2683         if (amt->nr_tunnels >= amt->max_tunnels) {
2684                 spin_unlock_bh(&amt->lock);
2685                 icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
2686                 return true;
2687         }
2688
2689         tunnel = kzalloc(sizeof(*tunnel) +
2690                          (sizeof(struct hlist_head) * amt->hash_buckets),
2691                          GFP_ATOMIC);
2692         if (!tunnel) {
2693                 spin_unlock_bh(&amt->lock);
2694                 return true;
2695         }
2696
2697         tunnel->source_port = udph->source;
2698         tunnel->ip4 = iph->saddr;
2699
2700         memcpy(&key, &tunnel->key, sizeof(unsigned long long));
2701         tunnel->amt = amt;
2702         spin_lock_init(&tunnel->lock);
2703         for (i = 0; i < amt->hash_buckets; i++)
2704                 INIT_HLIST_HEAD(&tunnel->groups[i]);
2705
2706         INIT_DELAYED_WORK(&tunnel->gc_wq, amt_tunnel_expire);
2707
2708         list_add_tail_rcu(&tunnel->list, &amt->tunnel_list);
2709         tunnel->key = amt->key;
2710         __amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_REQUEST, true);
2711         amt->nr_tunnels++;
2712         mod_delayed_work(amt_wq, &tunnel->gc_wq,
2713                          msecs_to_jiffies(amt_gmi(amt)));
2714         spin_unlock_bh(&amt->lock);
2715
2716 send:
2717         tunnel->nonce = amtrh->nonce;
2718         mac = siphash_3u32((__force u32)tunnel->ip4,
2719                            (__force u32)tunnel->source_port,
2720                            (__force u32)tunnel->nonce,
2721                            &tunnel->key);
2722         tunnel->mac = mac >> 16;
2723
2724         if (!netif_running(amt->dev) || !netif_running(amt->stream_dev))
2725                 return true;
2726
2727         if (!amtrh->p)
2728                 amt_send_igmp_gq(amt, tunnel);
2729         else
2730                 amt_send_mld_gq(amt, tunnel);
2731
2732         return false;
2733 }
2734
2735 static void amt_gw_rcv(struct amt_dev *amt, struct sk_buff *skb)
2736 {
2737         int type = amt_parse_type(skb);
2738         int err = 1;
2739
2740         if (type == -1)
2741                 goto drop;
2742
2743         if (amt->mode == AMT_MODE_GATEWAY) {
2744                 switch (type) {
2745                 case AMT_MSG_ADVERTISEMENT:
2746                         err = amt_advertisement_handler(amt, skb);
2747                         break;
2748                 case AMT_MSG_MEMBERSHIP_QUERY:
2749                         err = amt_membership_query_handler(amt, skb);
2750                         if (!err)
2751                                 return;
2752                         break;
2753                 default:
2754                         netdev_dbg(amt->dev, "Invalid type of Gateway\n");
2755                         break;
2756                 }
2757         }
2758 drop:
2759         if (err) {
2760                 amt->dev->stats.rx_dropped++;
2761                 kfree_skb(skb);
2762         } else {
2763                 consume_skb(skb);
2764         }
2765 }
2766
2767 static int amt_rcv(struct sock *sk, struct sk_buff *skb)
2768 {
2769         struct amt_dev *amt;
2770         struct iphdr *iph;
2771         int type;
2772         bool err;
2773
2774         rcu_read_lock_bh();
2775         amt = rcu_dereference_sk_user_data(sk);
2776         if (!amt) {
2777                 err = true;
2778                 kfree_skb(skb);
2779                 goto out;
2780         }
2781
2782         skb->dev = amt->dev;
2783         iph = ip_hdr(skb);
2784         type = amt_parse_type(skb);
2785         if (type == -1) {
2786                 err = true;
2787                 goto drop;
2788         }
2789
2790         if (amt->mode == AMT_MODE_GATEWAY) {
2791                 switch (type) {
2792                 case AMT_MSG_ADVERTISEMENT:
2793                         if (iph->saddr != amt->discovery_ip) {
2794                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2795                                 err = true;
2796                                 goto drop;
2797                         }
2798                         if (amt_queue_event(amt, AMT_EVENT_RECEIVE, skb)) {
2799                                 netdev_dbg(amt->dev, "AMT Event queue full\n");
2800                                 err = true;
2801                                 goto drop;
2802                         }
2803                         goto out;
2804                 case AMT_MSG_MULTICAST_DATA:
2805                         if (iph->saddr != amt->remote_ip) {
2806                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2807                                 err = true;
2808                                 goto drop;
2809                         }
2810                         err = amt_multicast_data_handler(amt, skb);
2811                         if (err)
2812                                 goto drop;
2813                         else
2814                                 goto out;
2815                 case AMT_MSG_MEMBERSHIP_QUERY:
2816                         if (iph->saddr != amt->remote_ip) {
2817                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2818                                 err = true;
2819                                 goto drop;
2820                         }
2821                         if (amt_queue_event(amt, AMT_EVENT_RECEIVE, skb)) {
2822                                 netdev_dbg(amt->dev, "AMT Event queue full\n");
2823                                 err = true;
2824                                 goto drop;
2825                         }
2826                         goto out;
2827                 default:
2828                         err = true;
2829                         netdev_dbg(amt->dev, "Invalid type of Gateway\n");
2830                         break;
2831                 }
2832         } else {
2833                 switch (type) {
2834                 case AMT_MSG_DISCOVERY:
2835                         err = amt_discovery_handler(amt, skb);
2836                         break;
2837                 case AMT_MSG_REQUEST:
2838                         err = amt_request_handler(amt, skb);
2839                         break;
2840                 case AMT_MSG_MEMBERSHIP_UPDATE:
2841                         err = amt_update_handler(amt, skb);
2842                         if (err)
2843                                 goto drop;
2844                         else
2845                                 goto out;
2846                 default:
2847                         err = true;
2848                         netdev_dbg(amt->dev, "Invalid type of relay\n");
2849                         break;
2850                 }
2851         }
2852 drop:
2853         if (err) {
2854                 amt->dev->stats.rx_dropped++;
2855                 kfree_skb(skb);
2856         } else {
2857                 consume_skb(skb);
2858         }
2859 out:
2860         rcu_read_unlock_bh();
2861         return 0;
2862 }
2863
2864 static void amt_event_work(struct work_struct *work)
2865 {
2866         struct amt_dev *amt = container_of(work, struct amt_dev, event_wq);
2867         struct sk_buff *skb;
2868         u8 event;
2869         int i;
2870
2871         for (i = 0; i < AMT_MAX_EVENTS; i++) {
2872                 spin_lock_bh(&amt->lock);
2873                 if (amt->nr_events == 0) {
2874                         spin_unlock_bh(&amt->lock);
2875                         return;
2876                 }
2877                 event = amt->events[amt->event_idx].event;
2878                 skb = amt->events[amt->event_idx].skb;
2879                 amt->events[amt->event_idx].event = AMT_EVENT_NONE;
2880                 amt->events[amt->event_idx].skb = NULL;
2881                 amt->nr_events--;
2882                 amt->event_idx++;
2883                 amt->event_idx %= AMT_MAX_EVENTS;
2884                 spin_unlock_bh(&amt->lock);
2885
2886                 switch (event) {
2887                 case AMT_EVENT_RECEIVE:
2888                         amt_gw_rcv(amt, skb);
2889                         break;
2890                 case AMT_EVENT_SEND_DISCOVERY:
2891                         amt_event_send_discovery(amt);
2892                         break;
2893                 case AMT_EVENT_SEND_REQUEST:
2894                         amt_event_send_request(amt);
2895                         break;
2896                 default:
2897                         kfree_skb(skb);
2898                         break;
2899                 }
2900         }
2901 }
2902
2903 static int amt_err_lookup(struct sock *sk, struct sk_buff *skb)
2904 {
2905         struct amt_dev *amt;
2906         int type;
2907
2908         rcu_read_lock_bh();
2909         amt = rcu_dereference_sk_user_data(sk);
2910         if (!amt)
2911                 goto out;
2912
2913         if (amt->mode != AMT_MODE_GATEWAY)
2914                 goto drop;
2915
2916         type = amt_parse_type(skb);
2917         if (type == -1)
2918                 goto drop;
2919
2920         netdev_dbg(amt->dev, "Received IGMP Unreachable of %s\n",
2921                    type_str[type]);
2922         switch (type) {
2923         case AMT_MSG_DISCOVERY:
2924                 break;
2925         case AMT_MSG_REQUEST:
2926         case AMT_MSG_MEMBERSHIP_UPDATE:
2927                 if (READ_ONCE(amt->status) >= AMT_STATUS_RECEIVED_ADVERTISEMENT)
2928                         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2929                 break;
2930         default:
2931                 goto drop;
2932         }
2933 out:
2934         rcu_read_unlock_bh();
2935         return 0;
2936 drop:
2937         rcu_read_unlock_bh();
2938         amt->dev->stats.rx_dropped++;
2939         return 0;
2940 }
2941
2942 static struct socket *amt_create_sock(struct net *net, __be16 port)
2943 {
2944         struct udp_port_cfg udp_conf;
2945         struct socket *sock;
2946         int err;
2947
2948         memset(&udp_conf, 0, sizeof(udp_conf));
2949         udp_conf.family = AF_INET;
2950         udp_conf.local_ip.s_addr = htonl(INADDR_ANY);
2951
2952         udp_conf.local_udp_port = port;
2953
2954         err = udp_sock_create(net, &udp_conf, &sock);
2955         if (err < 0)
2956                 return ERR_PTR(err);
2957
2958         return sock;
2959 }
2960
2961 static int amt_socket_create(struct amt_dev *amt)
2962 {
2963         struct udp_tunnel_sock_cfg tunnel_cfg;
2964         struct socket *sock;
2965
2966         sock = amt_create_sock(amt->net, amt->relay_port);
2967         if (IS_ERR(sock))
2968                 return PTR_ERR(sock);
2969
2970         /* Mark socket as an encapsulation socket */
2971         memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
2972         tunnel_cfg.sk_user_data = amt;
2973         tunnel_cfg.encap_type = 1;
2974         tunnel_cfg.encap_rcv = amt_rcv;
2975         tunnel_cfg.encap_err_lookup = amt_err_lookup;
2976         tunnel_cfg.encap_destroy = NULL;
2977         setup_udp_tunnel_sock(amt->net, sock, &tunnel_cfg);
2978
2979         rcu_assign_pointer(amt->sock, sock);
2980         return 0;
2981 }
2982
2983 static int amt_dev_open(struct net_device *dev)
2984 {
2985         struct amt_dev *amt = netdev_priv(dev);
2986         int err;
2987
2988         amt->ready4 = false;
2989         amt->ready6 = false;
2990         amt->event_idx = 0;
2991         amt->nr_events = 0;
2992
2993         err = amt_socket_create(amt);
2994         if (err)
2995                 return err;
2996
2997         amt->req_cnt = 0;
2998         amt->remote_ip = 0;
2999         amt->nonce = 0;
3000         get_random_bytes(&amt->key, sizeof(siphash_key_t));
3001
3002         amt->status = AMT_STATUS_INIT;
3003         if (amt->mode == AMT_MODE_GATEWAY) {
3004                 mod_delayed_work(amt_wq, &amt->discovery_wq, 0);
3005                 mod_delayed_work(amt_wq, &amt->req_wq, 0);
3006         } else if (amt->mode == AMT_MODE_RELAY) {
3007                 mod_delayed_work(amt_wq, &amt->secret_wq,
3008                                  msecs_to_jiffies(AMT_SECRET_TIMEOUT));
3009         }
3010         return err;
3011 }
3012
/*
 * ndo_stop handler: stop deferred work, close the encapsulation socket, free
 * any event skbs still queued and release every relay tunnel.
 *
 * The order matters:
 * 1. The delayed works are cancelled first.
 * 2. The socket is unpublished from amt->sock, and synchronize_net() waits
 *    for readers of that pointer before the socket is released.
 * 3. Only then is the event work cancelled and its queue drained.
 *
 * Return: always 0.
 */
static int amt_dev_stop(struct net_device *dev)
{
	struct amt_dev *amt = netdev_priv(dev);
	struct amt_tunnel_list *tunnel, *tmp;
	struct socket *sock;
	struct sk_buff *skb;
	int i;

	cancel_delayed_work_sync(&amt->req_wq);
	cancel_delayed_work_sync(&amt->discovery_wq);
	cancel_delayed_work_sync(&amt->secret_wq);

	/* shutdown */
	/* Clear amt->sock so no new RCU reader (e.g. amt_send_advertisement())
	 * can pick the socket up, then wait for existing readers before
	 * releasing it.
	 */
	sock = rtnl_dereference(amt->sock);
	RCU_INIT_POINTER(amt->sock, NULL);
	synchronize_net();
	if (sock)
		udp_tunnel_sock_release(sock);

	/* Free skbs of events that were queued but never processed, and reset
	 * every slot.  nr_events and event_idx are not reset here;
	 * amt_dev_open() resets them.
	 */
	cancel_work_sync(&amt->event_wq);
	for (i = 0; i < AMT_MAX_EVENTS; i++) {
		skb = amt->events[i].skb;
		kfree_skb(skb);
		amt->events[i].event = AMT_EVENT_NONE;
		amt->events[i].skb = NULL;
	}

	amt->ready4 = false;
	amt->ready6 = false;
	amt->req_cnt = 0;
	amt->remote_ip = 0;

	/* Unlink and free every relay tunnel.  kfree_rcu() defers the free so
	 * concurrent list_for_each_entry_rcu() walkers stay safe.
	 * NOTE(review): gc_wq runs amt_tunnel_expire (set up in
	 * amt_request_handler()).  If that handler also unlinks the tunnel,
	 * confirm it cannot run concurrently with this loop.
	 */
	list_for_each_entry_safe(tunnel, tmp, &amt->tunnel_list, list) {
		list_del_rcu(&tunnel->list);
		amt->nr_tunnels--;
		cancel_delayed_work_sync(&tunnel->gc_wq);
		amt_clear_groups(tunnel);
		kfree_rcu(tunnel, rcu);
	}

	return 0;
}
3055
3056 static const struct device_type amt_type = {
3057         .name = "amt",
3058 };
3059
3060 static int amt_dev_init(struct net_device *dev)
3061 {
3062         struct amt_dev *amt = netdev_priv(dev);
3063         int err;
3064
3065         amt->dev = dev;
3066         dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
3067         if (!dev->tstats)
3068                 return -ENOMEM;
3069
3070         err = gro_cells_init(&amt->gro_cells, dev);
3071         if (err) {
3072                 free_percpu(dev->tstats);
3073                 return err;
3074         }
3075
3076         return 0;
3077 }
3078
3079 static void amt_dev_uninit(struct net_device *dev)
3080 {
3081         struct amt_dev *amt = netdev_priv(dev);
3082
3083         gro_cells_destroy(&amt->gro_cells);
3084         free_percpu(dev->tstats);
3085 }
3086
3087 static const struct net_device_ops amt_netdev_ops = {
3088         .ndo_init               = amt_dev_init,
3089         .ndo_uninit             = amt_dev_uninit,
3090         .ndo_open               = amt_dev_open,
3091         .ndo_stop               = amt_dev_stop,
3092         .ndo_start_xmit         = amt_dev_xmit,
3093         .ndo_get_stats64        = dev_get_tstats64,
3094 };
3095
3096 static void amt_link_setup(struct net_device *dev)
3097 {
3098         dev->netdev_ops         = &amt_netdev_ops;
3099         dev->needs_free_netdev  = true;
3100         SET_NETDEV_DEVTYPE(dev, &amt_type);
3101         dev->min_mtu            = ETH_MIN_MTU;
3102         dev->max_mtu            = ETH_MAX_MTU;
3103         dev->type               = ARPHRD_NONE;
3104         dev->flags              = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST;
3105         dev->hard_header_len    = 0;
3106         dev->addr_len           = 0;
3107         dev->priv_flags         |= IFF_NO_QUEUE;
3108         dev->features           |= NETIF_F_LLTX;
3109         dev->features           |= NETIF_F_GSO_SOFTWARE;
3110         dev->features           |= NETIF_F_NETNS_LOCAL;
3111         dev->hw_features        |= NETIF_F_SG | NETIF_F_HW_CSUM;
3112         dev->hw_features        |= NETIF_F_FRAGLIST | NETIF_F_RXCSUM;
3113         dev->hw_features        |= NETIF_F_GSO_SOFTWARE;
3114         eth_hw_addr_random(dev);
3115         eth_zero_addr(dev->broadcast);
3116         ether_setup(dev);
3117 }
3118
3119 static const struct nla_policy amt_policy[IFLA_AMT_MAX + 1] = {
3120         [IFLA_AMT_MODE]         = { .type = NLA_U32 },
3121         [IFLA_AMT_RELAY_PORT]   = { .type = NLA_U16 },
3122         [IFLA_AMT_GATEWAY_PORT] = { .type = NLA_U16 },
3123         [IFLA_AMT_LINK]         = { .type = NLA_U32 },
3124         [IFLA_AMT_LOCAL_IP]     = { .len = sizeof_field(struct iphdr, daddr) },
3125         [IFLA_AMT_REMOTE_IP]    = { .len = sizeof_field(struct iphdr, daddr) },
3126         [IFLA_AMT_DISCOVERY_IP] = { .len = sizeof_field(struct iphdr, daddr) },
3127         [IFLA_AMT_MAX_TUNNELS]  = { .type = NLA_U32 },
3128 };
3129
3130 static int amt_validate(struct nlattr *tb[], struct nlattr *data[],
3131                         struct netlink_ext_ack *extack)
3132 {
3133         if (!data)
3134                 return -EINVAL;
3135
3136         if (!data[IFLA_AMT_LINK]) {
3137                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LINK],
3138                                     "Link attribute is required");
3139                 return -EINVAL;
3140         }
3141
3142         if (!data[IFLA_AMT_MODE]) {
3143                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
3144                                     "Mode attribute is required");
3145                 return -EINVAL;
3146         }
3147
3148         if (nla_get_u32(data[IFLA_AMT_MODE]) > AMT_MODE_MAX) {
3149                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
3150                                     "Mode attribute is not valid");
3151                 return -EINVAL;
3152         }
3153
3154         if (!data[IFLA_AMT_LOCAL_IP]) {
3155                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_DISCOVERY_IP],
3156                                     "Local attribute is required");
3157                 return -EINVAL;
3158         }
3159
3160         if (!data[IFLA_AMT_DISCOVERY_IP] &&
3161             nla_get_u32(data[IFLA_AMT_MODE]) == AMT_MODE_GATEWAY) {
3162                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LOCAL_IP],
3163                                     "Discovery attribute is required");
3164                 return -EINVAL;
3165         }
3166
3167         return 0;
3168 }
3169
3170 static int amt_newlink(struct net *net, struct net_device *dev,
3171                        struct nlattr *tb[], struct nlattr *data[],
3172                        struct netlink_ext_ack *extack)
3173 {
3174         struct amt_dev *amt = netdev_priv(dev);
3175         int err = -EINVAL;
3176
3177         amt->net = net;
3178         amt->mode = nla_get_u32(data[IFLA_AMT_MODE]);
3179
3180         if (data[IFLA_AMT_MAX_TUNNELS] &&
3181             nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]))
3182                 amt->max_tunnels = nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]);
3183         else
3184                 amt->max_tunnels = AMT_MAX_TUNNELS;
3185
3186         spin_lock_init(&amt->lock);
3187         amt->max_groups = AMT_MAX_GROUP;
3188         amt->max_sources = AMT_MAX_SOURCE;
3189         amt->hash_buckets = AMT_HSIZE;
3190         amt->nr_tunnels = 0;
3191         get_random_bytes(&amt->hash_seed, sizeof(amt->hash_seed));
3192         amt->stream_dev = dev_get_by_index(net,
3193                                            nla_get_u32(data[IFLA_AMT_LINK]));
3194         if (!amt->stream_dev) {
3195                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3196                                     "Can't find stream device");
3197                 return -ENODEV;
3198         }
3199
3200         if (amt->stream_dev->type != ARPHRD_ETHER) {
3201                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3202                                     "Invalid stream device type");
3203                 goto err;
3204         }
3205
3206         amt->local_ip = nla_get_in_addr(data[IFLA_AMT_LOCAL_IP]);
3207         if (ipv4_is_loopback(amt->local_ip) ||
3208             ipv4_is_zeronet(amt->local_ip) ||
3209             ipv4_is_multicast(amt->local_ip)) {
3210                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LOCAL_IP],
3211                                     "Invalid Local address");
3212                 goto err;
3213         }
3214
3215         if (data[IFLA_AMT_RELAY_PORT])
3216                 amt->relay_port = nla_get_be16(data[IFLA_AMT_RELAY_PORT]);
3217         else
3218                 amt->relay_port = htons(IANA_AMT_UDP_PORT);
3219
3220         if (data[IFLA_AMT_GATEWAY_PORT])
3221                 amt->gw_port = nla_get_be16(data[IFLA_AMT_GATEWAY_PORT]);
3222         else
3223                 amt->gw_port = htons(IANA_AMT_UDP_PORT);
3224
3225         if (!amt->relay_port) {
3226                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3227                                     "relay port must not be 0");
3228                 goto err;
3229         }
3230         if (amt->mode == AMT_MODE_RELAY) {
3231                 amt->qrv = READ_ONCE(amt->net->ipv4.sysctl_igmp_qrv);
3232                 amt->qri = 10;
3233                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3234                                        AMT_RELAY_HLEN;
3235                 dev->mtu = amt->stream_dev->mtu - AMT_RELAY_HLEN;
3236                 dev->max_mtu = dev->mtu;
3237                 dev->min_mtu = ETH_MIN_MTU + AMT_RELAY_HLEN;
3238         } else {
3239                 if (!data[IFLA_AMT_DISCOVERY_IP]) {
3240                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3241                                             "discovery must be set in gateway mode");
3242                         goto err;
3243                 }
3244                 if (!amt->gw_port) {
3245                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3246                                             "gateway port must not be 0");
3247                         goto err;
3248                 }
3249                 amt->remote_ip = 0;
3250                 amt->discovery_ip = nla_get_in_addr(data[IFLA_AMT_DISCOVERY_IP]);
3251                 if (ipv4_is_loopback(amt->discovery_ip) ||
3252                     ipv4_is_zeronet(amt->discovery_ip) ||
3253                     ipv4_is_multicast(amt->discovery_ip)) {
3254                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3255                                             "discovery must be unicast");
3256                         goto err;
3257                 }
3258
3259                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3260                                        AMT_GW_HLEN;
3261                 dev->mtu = amt->stream_dev->mtu - AMT_GW_HLEN;
3262                 dev->max_mtu = dev->mtu;
3263                 dev->min_mtu = ETH_MIN_MTU + AMT_GW_HLEN;
3264         }
3265         amt->qi = AMT_INIT_QUERY_INTERVAL;
3266
3267         err = register_netdevice(dev);
3268         if (err < 0) {
3269                 netdev_dbg(dev, "failed to register new netdev %d\n", err);
3270                 goto err;
3271         }
3272
3273         err = netdev_upper_dev_link(amt->stream_dev, dev, extack);
3274         if (err < 0) {
3275                 unregister_netdevice(dev);
3276                 goto err;
3277         }
3278
3279         INIT_DELAYED_WORK(&amt->discovery_wq, amt_discovery_work);
3280         INIT_DELAYED_WORK(&amt->req_wq, amt_req_work);
3281         INIT_DELAYED_WORK(&amt->secret_wq, amt_secret_work);
3282         INIT_WORK(&amt->event_wq, amt_event_work);
3283         INIT_LIST_HEAD(&amt->tunnel_list);
3284         return 0;
3285 err:
3286         dev_put(amt->stream_dev);
3287         return err;
3288 }
3289
3290 static void amt_dellink(struct net_device *dev, struct list_head *head)
3291 {
3292         struct amt_dev *amt = netdev_priv(dev);
3293
3294         unregister_netdevice_queue(dev, head);
3295         netdev_upper_dev_unlink(amt->stream_dev, dev);
3296         dev_put(amt->stream_dev);
3297 }
3298
3299 static size_t amt_get_size(const struct net_device *dev)
3300 {
3301         return nla_total_size(sizeof(__u32)) + /* IFLA_AMT_MODE */
3302                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_RELAY_PORT */
3303                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_GATEWAY_PORT */
3304                nla_total_size(sizeof(__u32)) + /* IFLA_AMT_LINK */
3305                nla_total_size(sizeof(__u32)) + /* IFLA_MAX_TUNNELS */
3306                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_DISCOVERY_IP */
3307                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_REMOTE_IP */
3308                nla_total_size(sizeof(struct iphdr)); /* IFLA_AMT_LOCAL_IP */
3309 }
3310
3311 static int amt_fill_info(struct sk_buff *skb, const struct net_device *dev)
3312 {
3313         struct amt_dev *amt = netdev_priv(dev);
3314
3315         if (nla_put_u32(skb, IFLA_AMT_MODE, amt->mode))
3316                 goto nla_put_failure;
3317         if (nla_put_be16(skb, IFLA_AMT_RELAY_PORT, amt->relay_port))
3318                 goto nla_put_failure;
3319         if (nla_put_be16(skb, IFLA_AMT_GATEWAY_PORT, amt->gw_port))
3320                 goto nla_put_failure;
3321         if (nla_put_u32(skb, IFLA_AMT_LINK, amt->stream_dev->ifindex))
3322                 goto nla_put_failure;
3323         if (nla_put_in_addr(skb, IFLA_AMT_LOCAL_IP, amt->local_ip))
3324                 goto nla_put_failure;
3325         if (nla_put_in_addr(skb, IFLA_AMT_DISCOVERY_IP, amt->discovery_ip))
3326                 goto nla_put_failure;
3327         if (amt->remote_ip)
3328                 if (nla_put_in_addr(skb, IFLA_AMT_REMOTE_IP, amt->remote_ip))
3329                         goto nla_put_failure;
3330         if (nla_put_u32(skb, IFLA_AMT_MAX_TUNNELS, amt->max_tunnels))
3331                 goto nla_put_failure;
3332
3333         return 0;
3334
3335 nla_put_failure:
3336         return -EMSGSIZE;
3337 }
3338
3339 static struct rtnl_link_ops amt_link_ops __read_mostly = {
3340         .kind           = "amt",
3341         .maxtype        = IFLA_AMT_MAX,
3342         .policy         = amt_policy,
3343         .priv_size      = sizeof(struct amt_dev),
3344         .setup          = amt_link_setup,
3345         .validate       = amt_validate,
3346         .newlink        = amt_newlink,
3347         .dellink        = amt_dellink,
3348         .get_size       = amt_get_size,
3349         .fill_info      = amt_fill_info,
3350 };
3351
3352 static struct net_device *amt_lookup_upper_dev(struct net_device *dev)
3353 {
3354         struct net_device *upper_dev;
3355         struct amt_dev *amt;
3356
3357         for_each_netdev(dev_net(dev), upper_dev) {
3358                 if (netif_is_amt(upper_dev)) {
3359                         amt = netdev_priv(upper_dev);
3360                         if (amt->stream_dev == dev)
3361                                 return upper_dev;
3362                 }
3363         }
3364
3365         return NULL;
3366 }
3367
3368 static int amt_device_event(struct notifier_block *unused,
3369                             unsigned long event, void *ptr)
3370 {
3371         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
3372         struct net_device *upper_dev;
3373         struct amt_dev *amt;
3374         LIST_HEAD(list);
3375         int new_mtu;
3376
3377         upper_dev = amt_lookup_upper_dev(dev);
3378         if (!upper_dev)
3379                 return NOTIFY_DONE;
3380         amt = netdev_priv(upper_dev);
3381
3382         switch (event) {
3383         case NETDEV_UNREGISTER:
3384                 amt_dellink(amt->dev, &list);
3385                 unregister_netdevice_many(&list);
3386                 break;
3387         case NETDEV_CHANGEMTU:
3388                 if (amt->mode == AMT_MODE_RELAY)
3389                         new_mtu = dev->mtu - AMT_RELAY_HLEN;
3390                 else
3391                         new_mtu = dev->mtu - AMT_GW_HLEN;
3392
3393                 dev_set_mtu(amt->dev, new_mtu);
3394                 break;
3395         }
3396
3397         return NOTIFY_DONE;
3398 }
3399
3400 static struct notifier_block amt_notifier_block __read_mostly = {
3401         .notifier_call = amt_device_event,
3402 };
3403
3404 static int __init amt_init(void)
3405 {
3406         int err;
3407
3408         err = register_netdevice_notifier(&amt_notifier_block);
3409         if (err < 0)
3410                 goto err;
3411
3412         err = rtnl_link_register(&amt_link_ops);
3413         if (err < 0)
3414                 goto unregister_notifier;
3415
3416         amt_wq = alloc_workqueue("amt", WQ_UNBOUND, 0);
3417         if (!amt_wq) {
3418                 err = -ENOMEM;
3419                 goto rtnl_unregister;
3420         }
3421
3422         spin_lock_init(&source_gc_lock);
3423         spin_lock_bh(&source_gc_lock);
3424         INIT_DELAYED_WORK(&source_gc_wq, amt_source_gc_work);
3425         mod_delayed_work(amt_wq, &source_gc_wq,
3426                          msecs_to_jiffies(AMT_GC_INTERVAL));
3427         spin_unlock_bh(&source_gc_lock);
3428
3429         return 0;
3430
3431 rtnl_unregister:
3432         rtnl_link_unregister(&amt_link_ops);
3433 unregister_notifier:
3434         unregister_netdevice_notifier(&amt_notifier_block);
3435 err:
3436         pr_err("error loading AMT module loaded\n");
3437         return err;
3438 }
3439 late_initcall(amt_init);
3440
3441 static void __exit amt_fini(void)
3442 {
3443         rtnl_link_unregister(&amt_link_ops);
3444         unregister_netdevice_notifier(&amt_notifier_block);
3445         cancel_delayed_work_sync(&source_gc_wq);
3446         __amt_source_gc_work();
3447         destroy_workqueue(amt_wq);
3448 }
3449 module_exit(amt_fini);
3450
3451 MODULE_LICENSE("GPL");
3452 MODULE_AUTHOR("Taehee Yoo <ap420073@gmail.com>");
3453 MODULE_ALIAS_RTNL_LINK("amt");