amt: fix wrong return type of amt_send_membership_update()
[platform/kernel/linux-starfive.git] / drivers / net / amt.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Copyright (c) 2021 Taehee Yoo <ap420073@gmail.com> */
3
4 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5
6 #include <linux/module.h>
7 #include <linux/skbuff.h>
8 #include <linux/udp.h>
9 #include <linux/jhash.h>
10 #include <linux/if_tunnel.h>
11 #include <linux/net.h>
12 #include <linux/igmp.h>
13 #include <linux/workqueue.h>
14 #include <net/net_namespace.h>
15 #include <net/ip.h>
16 #include <net/udp.h>
17 #include <net/udp_tunnel.h>
18 #include <net/icmp.h>
19 #include <net/mld.h>
20 #include <net/amt.h>
21 #include <uapi/linux/amt.h>
22 #include <linux/security.h>
23 #include <net/gro_cells.h>
24 #include <net/ipv6.h>
25 #include <net/if_inet6.h>
26 #include <net/ndisc.h>
27 #include <net/addrconf.h>
28 #include <net/ip6_route.h>
29 #include <net/inet_common.h>
30 #include <net/ip6_checksum.h>
31
32 static struct workqueue_struct *amt_wq;
33
34 static HLIST_HEAD(source_gc_list);
35 /* Lock for source_gc_list */
36 static spinlock_t source_gc_lock;
37 static struct delayed_work source_gc_wq;
/* Printable names of the AMT states.  The table is indexed by a status
 * value in the debug messages of __amt_update_gw_status() and
 * __amt_update_relay_status().
 */
static char *status_str[] = {
	"AMT_STATUS_INIT",
	"AMT_STATUS_SENT_DISCOVERY",
	"AMT_STATUS_RECEIVED_DISCOVERY",
	"AMT_STATUS_SENT_ADVERTISEMENT",
	"AMT_STATUS_RECEIVED_ADVERTISEMENT",
	"AMT_STATUS_SENT_REQUEST",
	"AMT_STATUS_RECEIVED_REQUEST",
	"AMT_STATUS_SENT_QUERY",
	"AMT_STATUS_RECEIVED_QUERY",
	"AMT_STATUS_SENT_UPDATE",
	"AMT_STATUS_RECEIVED_UPDATE",
};
51
/* Printable names of the AMT message types, for debug output.
 *
 * RFC 7450 numbers the message types from 1 (Relay Discovery) to
 * 7 (Teardown), so slot 0 is an empty placeholder that keeps every name at
 * the index equal to its type value.
 * NOTE(review): this assumes enum amt_msg_type follows the RFC numbering
 * (AMT_MSG_DISCOVERY == 1) - confirm against net/amt.h.
 */
static char *type_str[] = {
	"", /* Type 0 is not defined */
	"AMT_MSG_DISCOVERY",
	"AMT_MSG_ADVERTISEMENT",
	"AMT_MSG_REQUEST",
	"AMT_MSG_MEMBERSHIP_QUERY",
	"AMT_MSG_MEMBERSHIP_UPDATE",
	"AMT_MSG_MULTICAST_DATA",
	"AMT_MSG_TEARDOWN",
};
61
/* Printable names of the source actions; amt_act_src() indexes this table
 * with its 'act' argument when logging.
 */
static char *action_str[] = {
	"AMT_ACT_GMI",
	"AMT_ACT_GMI_ZERO",
	"AMT_ACT_GT",
	"AMT_ACT_STATUS_FWD_NEW",
	"AMT_ACT_STATUS_D_FWD_NEW",
	"AMT_ACT_STATUS_NONE_NEW",
};
70
71 static struct igmpv3_grec igmpv3_zero_grec;
72
73 #if IS_ENABLED(CONFIG_IPV6)
74 #define MLD2_ALL_NODE_INIT { { { 0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01 } } }
75 static struct in6_addr mld2_all_node = MLD2_ALL_NODE_INIT;
76 static struct mld2_grec mldv2_zero_grec;
77 #endif
78
79 static struct amt_skb_cb *amt_skb_cb(struct sk_buff *skb)
80 {
81         BUILD_BUG_ON(sizeof(struct amt_skb_cb) + sizeof(struct qdisc_skb_cb) >
82                      sizeof_field(struct sk_buff, cb));
83
84         return (struct amt_skb_cb *)((void *)skb->cb +
85                 sizeof(struct qdisc_skb_cb));
86 }
87
88 static void __amt_source_gc_work(void)
89 {
90         struct amt_source_node *snode;
91         struct hlist_head gc_list;
92         struct hlist_node *t;
93
94         spin_lock_bh(&source_gc_lock);
95         hlist_move_list(&source_gc_list, &gc_list);
96         spin_unlock_bh(&source_gc_lock);
97
98         hlist_for_each_entry_safe(snode, t, &gc_list, node) {
99                 hlist_del_rcu(&snode->node);
100                 kfree_rcu(snode, rcu);
101         }
102 }
103
104 static void amt_source_gc_work(struct work_struct *work)
105 {
106         __amt_source_gc_work();
107
108         spin_lock_bh(&source_gc_lock);
109         mod_delayed_work(amt_wq, &source_gc_wq,
110                          msecs_to_jiffies(AMT_GC_INTERVAL));
111         spin_unlock_bh(&source_gc_lock);
112 }
113
114 static bool amt_addr_equal(union amt_addr *a, union amt_addr *b)
115 {
116         return !memcmp(a, b, sizeof(union amt_addr));
117 }
118
119 static u32 amt_source_hash(struct amt_tunnel_list *tunnel, union amt_addr *src)
120 {
121         u32 hash = jhash(src, sizeof(*src), tunnel->amt->hash_seed);
122
123         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
124 }
125
126 static bool amt_status_filter(struct amt_source_node *snode,
127                               enum amt_filter filter)
128 {
129         bool rc = false;
130
131         switch (filter) {
132         case AMT_FILTER_FWD:
133                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
134                     snode->flags == AMT_SOURCE_OLD)
135                         rc = true;
136                 break;
137         case AMT_FILTER_D_FWD:
138                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
139                     snode->flags == AMT_SOURCE_OLD)
140                         rc = true;
141                 break;
142         case AMT_FILTER_FWD_NEW:
143                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
144                     snode->flags == AMT_SOURCE_NEW)
145                         rc = true;
146                 break;
147         case AMT_FILTER_D_FWD_NEW:
148                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
149                     snode->flags == AMT_SOURCE_NEW)
150                         rc = true;
151                 break;
152         case AMT_FILTER_ALL:
153                 rc = true;
154                 break;
155         case AMT_FILTER_NONE_NEW:
156                 if (snode->status == AMT_SOURCE_STATUS_NONE &&
157                     snode->flags == AMT_SOURCE_NEW)
158                         rc = true;
159                 break;
160         case AMT_FILTER_BOTH:
161                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
162                      snode->status == AMT_SOURCE_STATUS_FWD) &&
163                     snode->flags == AMT_SOURCE_OLD)
164                         rc = true;
165                 break;
166         case AMT_FILTER_BOTH_NEW:
167                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
168                      snode->status == AMT_SOURCE_STATUS_FWD) &&
169                     snode->flags == AMT_SOURCE_NEW)
170                         rc = true;
171                 break;
172         default:
173                 WARN_ON_ONCE(1);
174                 break;
175         }
176
177         return rc;
178 }
179
180 static struct amt_source_node *amt_lookup_src(struct amt_tunnel_list *tunnel,
181                                               struct amt_group_node *gnode,
182                                               enum amt_filter filter,
183                                               union amt_addr *src)
184 {
185         u32 hash = amt_source_hash(tunnel, src);
186         struct amt_source_node *snode;
187
188         hlist_for_each_entry_rcu(snode, &gnode->sources[hash], node)
189                 if (amt_status_filter(snode, filter) &&
190                     amt_addr_equal(&snode->source_addr, src))
191                         return snode;
192
193         return NULL;
194 }
195
196 static u32 amt_group_hash(struct amt_tunnel_list *tunnel, union amt_addr *group)
197 {
198         u32 hash = jhash(group, sizeof(*group), tunnel->amt->hash_seed);
199
200         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
201 }
202
203 static struct amt_group_node *amt_lookup_group(struct amt_tunnel_list *tunnel,
204                                                union amt_addr *group,
205                                                union amt_addr *host,
206                                                bool v6)
207 {
208         u32 hash = amt_group_hash(tunnel, group);
209         struct amt_group_node *gnode;
210
211         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash], node) {
212                 if (amt_addr_equal(&gnode->group_addr, group) &&
213                     amt_addr_equal(&gnode->host_addr, host) &&
214                     gnode->v6 == v6)
215                         return gnode;
216         }
217
218         return NULL;
219 }
220
221 static void amt_destroy_source(struct amt_source_node *snode)
222 {
223         struct amt_group_node *gnode = snode->gnode;
224         struct amt_tunnel_list *tunnel;
225
226         tunnel = gnode->tunnel_list;
227
228         if (!gnode->v6) {
229                 netdev_dbg(snode->gnode->amt->dev,
230                            "Delete source %pI4 from %pI4\n",
231                            &snode->source_addr.ip4,
232                            &gnode->group_addr.ip4);
233 #if IS_ENABLED(CONFIG_IPV6)
234         } else {
235                 netdev_dbg(snode->gnode->amt->dev,
236                            "Delete source %pI6 from %pI6\n",
237                            &snode->source_addr.ip6,
238                            &gnode->group_addr.ip6);
239 #endif
240         }
241
242         cancel_delayed_work(&snode->source_timer);
243         hlist_del_init_rcu(&snode->node);
244         tunnel->nr_sources--;
245         gnode->nr_sources--;
246         spin_lock_bh(&source_gc_lock);
247         hlist_add_head_rcu(&snode->node, &source_gc_list);
248         spin_unlock_bh(&source_gc_lock);
249 }
250
251 static void amt_del_group(struct amt_dev *amt, struct amt_group_node *gnode)
252 {
253         struct amt_source_node *snode;
254         struct hlist_node *t;
255         int i;
256
257         if (cancel_delayed_work(&gnode->group_timer))
258                 dev_put(amt->dev);
259         hlist_del_rcu(&gnode->node);
260         gnode->tunnel_list->nr_groups--;
261
262         if (!gnode->v6)
263                 netdev_dbg(amt->dev, "Leave group %pI4\n",
264                            &gnode->group_addr.ip4);
265 #if IS_ENABLED(CONFIG_IPV6)
266         else
267                 netdev_dbg(amt->dev, "Leave group %pI6\n",
268                            &gnode->group_addr.ip6);
269 #endif
270         for (i = 0; i < amt->hash_buckets; i++)
271                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node)
272                         amt_destroy_source(snode);
273
274         /* tunnel->lock was acquired outside of amt_del_group()
275          * But rcu_read_lock() was acquired too so It's safe.
276          */
277         kfree_rcu(gnode, rcu);
278 }
279
280 /* If a source timer expires with a router filter-mode for the group of
281  * INCLUDE, the router concludes that traffic from this particular
282  * source is no longer desired on the attached network, and deletes the
283  * associated source record.
284  */
285 static void amt_source_work(struct work_struct *work)
286 {
287         struct amt_source_node *snode = container_of(to_delayed_work(work),
288                                                      struct amt_source_node,
289                                                      source_timer);
290         struct amt_group_node *gnode = snode->gnode;
291         struct amt_dev *amt = gnode->amt;
292         struct amt_tunnel_list *tunnel;
293
294         tunnel = gnode->tunnel_list;
295         spin_lock_bh(&tunnel->lock);
296         rcu_read_lock();
297         if (gnode->filter_mode == MCAST_INCLUDE) {
298                 amt_destroy_source(snode);
299                 if (!gnode->nr_sources)
300                         amt_del_group(amt, gnode);
301         } else {
302                 /* When a router filter-mode for a group is EXCLUDE,
303                  * source records are only deleted when the group timer expires
304                  */
305                 snode->status = AMT_SOURCE_STATUS_D_FWD;
306         }
307         rcu_read_unlock();
308         spin_unlock_bh(&tunnel->lock);
309 }
310
311 static void amt_act_src(struct amt_tunnel_list *tunnel,
312                         struct amt_group_node *gnode,
313                         struct amt_source_node *snode,
314                         enum amt_act act)
315 {
316         struct amt_dev *amt = tunnel->amt;
317
318         switch (act) {
319         case AMT_ACT_GMI:
320                 mod_delayed_work(amt_wq, &snode->source_timer,
321                                  msecs_to_jiffies(amt_gmi(amt)));
322                 break;
323         case AMT_ACT_GMI_ZERO:
324                 cancel_delayed_work(&snode->source_timer);
325                 break;
326         case AMT_ACT_GT:
327                 mod_delayed_work(amt_wq, &snode->source_timer,
328                                  gnode->group_timer.timer.expires);
329                 break;
330         case AMT_ACT_STATUS_FWD_NEW:
331                 snode->status = AMT_SOURCE_STATUS_FWD;
332                 snode->flags = AMT_SOURCE_NEW;
333                 break;
334         case AMT_ACT_STATUS_D_FWD_NEW:
335                 snode->status = AMT_SOURCE_STATUS_D_FWD;
336                 snode->flags = AMT_SOURCE_NEW;
337                 break;
338         case AMT_ACT_STATUS_NONE_NEW:
339                 cancel_delayed_work(&snode->source_timer);
340                 snode->status = AMT_SOURCE_STATUS_NONE;
341                 snode->flags = AMT_SOURCE_NEW;
342                 break;
343         default:
344                 WARN_ON_ONCE(1);
345                 return;
346         }
347
348         if (!gnode->v6)
349                 netdev_dbg(amt->dev, "Source %pI4 from %pI4 Acted %s\n",
350                            &snode->source_addr.ip4,
351                            &gnode->group_addr.ip4,
352                            action_str[act]);
353 #if IS_ENABLED(CONFIG_IPV6)
354         else
355                 netdev_dbg(amt->dev, "Source %pI6 from %pI6 Acted %s\n",
356                            &snode->source_addr.ip6,
357                            &gnode->group_addr.ip6,
358                            action_str[act]);
359 #endif
360 }
361
362 static struct amt_source_node *amt_alloc_snode(struct amt_group_node *gnode,
363                                                union amt_addr *src)
364 {
365         struct amt_source_node *snode;
366
367         snode = kzalloc(sizeof(*snode), GFP_ATOMIC);
368         if (!snode)
369                 return NULL;
370
371         memcpy(&snode->source_addr, src, sizeof(union amt_addr));
372         snode->gnode = gnode;
373         snode->status = AMT_SOURCE_STATUS_NONE;
374         snode->flags = AMT_SOURCE_NEW;
375         INIT_HLIST_NODE(&snode->node);
376         INIT_DELAYED_WORK(&snode->source_timer, amt_source_work);
377
378         return snode;
379 }
380
381 /* RFC 3810 - 7.2.2.  Definition of Filter Timers
382  *
383  *  Router Mode          Filter Timer         Actions/Comments
384  *  -----------       -----------------       ----------------
385  *
386  *    INCLUDE             Not Used            All listeners in
387  *                                            INCLUDE mode.
388  *
389  *    EXCLUDE             Timer > 0           At least one listener
390  *                                            in EXCLUDE mode.
391  *
392  *    EXCLUDE             Timer == 0          No more listeners in
393  *                                            EXCLUDE mode for the
394  *                                            multicast address.
395  *                                            If the Requested List
396  *                                            is empty, delete
397  *                                            Multicast Address
398  *                                            Record.  If not, switch
399  *                                            to INCLUDE filter mode;
400  *                                            the sources in the
401  *                                            Requested List are
402  *                                            moved to the Include
403  *                                            List, and the Exclude
404  *                                            List is deleted.
405  */
406 static void amt_group_work(struct work_struct *work)
407 {
408         struct amt_group_node *gnode = container_of(to_delayed_work(work),
409                                                     struct amt_group_node,
410                                                     group_timer);
411         struct amt_tunnel_list *tunnel = gnode->tunnel_list;
412         struct amt_dev *amt = gnode->amt;
413         struct amt_source_node *snode;
414         bool delete_group = true;
415         struct hlist_node *t;
416         int i, buckets;
417
418         buckets = amt->hash_buckets;
419
420         spin_lock_bh(&tunnel->lock);
421         if (gnode->filter_mode == MCAST_INCLUDE) {
422                 /* Not Used */
423                 spin_unlock_bh(&tunnel->lock);
424                 goto out;
425         }
426
427         rcu_read_lock();
428         for (i = 0; i < buckets; i++) {
429                 hlist_for_each_entry_safe(snode, t,
430                                           &gnode->sources[i], node) {
431                         if (!delayed_work_pending(&snode->source_timer) ||
432                             snode->status == AMT_SOURCE_STATUS_D_FWD) {
433                                 amt_destroy_source(snode);
434                         } else {
435                                 delete_group = false;
436                                 snode->status = AMT_SOURCE_STATUS_FWD;
437                         }
438                 }
439         }
440         if (delete_group)
441                 amt_del_group(amt, gnode);
442         else
443                 gnode->filter_mode = MCAST_INCLUDE;
444         rcu_read_unlock();
445         spin_unlock_bh(&tunnel->lock);
446 out:
447         dev_put(amt->dev);
448 }
449
450 /* Non-existant group is created as INCLUDE {empty}:
451  *
452  * RFC 3376 - 5.1. Action on Change of Interface State
453  *
454  * If no interface state existed for that multicast address before
455  * the change (i.e., the change consisted of creating a new
456  * per-interface record), or if no state exists after the change
457  * (i.e., the change consisted of deleting a per-interface record),
458  * then the "non-existent" state is considered to have a filter mode
459  * of INCLUDE and an empty source list.
460  */
461 static struct amt_group_node *amt_add_group(struct amt_dev *amt,
462                                             struct amt_tunnel_list *tunnel,
463                                             union amt_addr *group,
464                                             union amt_addr *host,
465                                             bool v6)
466 {
467         struct amt_group_node *gnode;
468         u32 hash;
469         int i;
470
471         if (tunnel->nr_groups >= amt->max_groups)
472                 return ERR_PTR(-ENOSPC);
473
474         gnode = kzalloc(sizeof(*gnode) +
475                         (sizeof(struct hlist_head) * amt->hash_buckets),
476                         GFP_ATOMIC);
477         if (unlikely(!gnode))
478                 return ERR_PTR(-ENOMEM);
479
480         gnode->amt = amt;
481         gnode->group_addr = *group;
482         gnode->host_addr = *host;
483         gnode->v6 = v6;
484         gnode->tunnel_list = tunnel;
485         gnode->filter_mode = MCAST_INCLUDE;
486         INIT_HLIST_NODE(&gnode->node);
487         INIT_DELAYED_WORK(&gnode->group_timer, amt_group_work);
488         for (i = 0; i < amt->hash_buckets; i++)
489                 INIT_HLIST_HEAD(&gnode->sources[i]);
490
491         hash = amt_group_hash(tunnel, group);
492         hlist_add_head_rcu(&gnode->node, &tunnel->groups[hash]);
493         tunnel->nr_groups++;
494
495         if (!gnode->v6)
496                 netdev_dbg(amt->dev, "Join group %pI4\n",
497                            &gnode->group_addr.ip4);
498 #if IS_ENABLED(CONFIG_IPV6)
499         else
500                 netdev_dbg(amt->dev, "Join group %pI6\n",
501                            &gnode->group_addr.ip6);
502 #endif
503
504         return gnode;
505 }
506
507 static struct sk_buff *amt_build_igmp_gq(struct amt_dev *amt)
508 {
509         u8 ra[AMT_IPHDR_OPTS] = { IPOPT_RA, 4, 0, 0 };
510         int hlen = LL_RESERVED_SPACE(amt->dev);
511         int tlen = amt->dev->needed_tailroom;
512         struct igmpv3_query *ihv3;
513         void *csum_start = NULL;
514         __sum16 *csum = NULL;
515         struct sk_buff *skb;
516         struct ethhdr *eth;
517         struct iphdr *iph;
518         unsigned int len;
519         int offset;
520
521         len = hlen + tlen + sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3);
522         skb = netdev_alloc_skb_ip_align(amt->dev, len);
523         if (!skb)
524                 return NULL;
525
526         skb_reserve(skb, hlen);
527         skb_push(skb, sizeof(*eth));
528         skb->protocol = htons(ETH_P_IP);
529         skb_reset_mac_header(skb);
530         skb->priority = TC_PRIO_CONTROL;
531         skb_put(skb, sizeof(*iph));
532         skb_put_data(skb, ra, sizeof(ra));
533         skb_put(skb, sizeof(*ihv3));
534         skb_pull(skb, sizeof(*eth));
535         skb_reset_network_header(skb);
536
537         iph             = ip_hdr(skb);
538         iph->version    = 4;
539         iph->ihl        = (sizeof(struct iphdr) + AMT_IPHDR_OPTS) >> 2;
540         iph->tos        = AMT_TOS;
541         iph->tot_len    = htons(sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3));
542         iph->frag_off   = htons(IP_DF);
543         iph->ttl        = 1;
544         iph->id         = 0;
545         iph->protocol   = IPPROTO_IGMP;
546         iph->daddr      = htonl(INADDR_ALLHOSTS_GROUP);
547         iph->saddr      = htonl(INADDR_ANY);
548         ip_send_check(iph);
549
550         eth = eth_hdr(skb);
551         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
552         ip_eth_mc_map(htonl(INADDR_ALLHOSTS_GROUP), eth->h_dest);
553         eth->h_proto = htons(ETH_P_IP);
554
555         ihv3            = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
556         skb_reset_transport_header(skb);
557         ihv3->type      = IGMP_HOST_MEMBERSHIP_QUERY;
558         ihv3->code      = 1;
559         ihv3->group     = 0;
560         ihv3->qqic      = amt->qi;
561         ihv3->nsrcs     = 0;
562         ihv3->resv      = 0;
563         ihv3->suppress  = false;
564         ihv3->qrv       = amt->net->ipv4.sysctl_igmp_qrv;
565         ihv3->csum      = 0;
566         csum            = &ihv3->csum;
567         csum_start      = (void *)ihv3;
568         *csum           = ip_compute_csum(csum_start, sizeof(*ihv3));
569         offset          = skb_transport_offset(skb);
570         skb->csum       = skb_checksum(skb, offset, skb->len - offset, 0);
571         skb->ip_summed  = CHECKSUM_NONE;
572
573         skb_push(skb, sizeof(*eth) + sizeof(*iph) + AMT_IPHDR_OPTS);
574
575         return skb;
576 }
577
578 static void __amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
579                                    bool validate)
580 {
581         if (validate && amt->status >= status)
582                 return;
583         netdev_dbg(amt->dev, "Update GW status %s -> %s",
584                    status_str[amt->status], status_str[status]);
585         amt->status = status;
586 }
587
588 static void __amt_update_relay_status(struct amt_tunnel_list *tunnel,
589                                       enum amt_status status,
590                                       bool validate)
591 {
592         if (validate && tunnel->status >= status)
593                 return;
594         netdev_dbg(tunnel->amt->dev,
595                    "Update Tunnel(IP = %pI4, PORT = %u) status %s -> %s",
596                    &tunnel->ip4, ntohs(tunnel->source_port),
597                    status_str[tunnel->status], status_str[status]);
598         tunnel->status = status;
599 }
600
601 static void amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
602                                  bool validate)
603 {
604         spin_lock_bh(&amt->lock);
605         __amt_update_gw_status(amt, status, validate);
606         spin_unlock_bh(&amt->lock);
607 }
608
609 static void amt_update_relay_status(struct amt_tunnel_list *tunnel,
610                                     enum amt_status status, bool validate)
611 {
612         spin_lock_bh(&tunnel->lock);
613         __amt_update_relay_status(tunnel, status, validate);
614         spin_unlock_bh(&tunnel->lock);
615 }
616
617 static void amt_send_discovery(struct amt_dev *amt)
618 {
619         struct amt_header_discovery *amtd;
620         int hlen, tlen, offset;
621         struct socket *sock;
622         struct udphdr *udph;
623         struct sk_buff *skb;
624         struct iphdr *iph;
625         struct rtable *rt;
626         struct flowi4 fl4;
627         u32 len;
628         int err;
629
630         rcu_read_lock();
631         sock = rcu_dereference(amt->sock);
632         if (!sock)
633                 goto out;
634
635         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
636                 goto out;
637
638         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
639                                    amt->discovery_ip, amt->local_ip,
640                                    amt->gw_port, amt->relay_port,
641                                    IPPROTO_UDP, 0,
642                                    amt->stream_dev->ifindex);
643         if (IS_ERR(rt)) {
644                 amt->dev->stats.tx_errors++;
645                 goto out;
646         }
647
648         hlen = LL_RESERVED_SPACE(amt->dev);
649         tlen = amt->dev->needed_tailroom;
650         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
651         skb = netdev_alloc_skb_ip_align(amt->dev, len);
652         if (!skb) {
653                 ip_rt_put(rt);
654                 amt->dev->stats.tx_errors++;
655                 goto out;
656         }
657
658         skb->priority = TC_PRIO_CONTROL;
659         skb_dst_set(skb, &rt->dst);
660
661         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
662         skb_reset_network_header(skb);
663         skb_put(skb, len);
664         amtd = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
665         amtd->version   = 0;
666         amtd->type      = AMT_MSG_DISCOVERY;
667         amtd->reserved  = 0;
668         amtd->nonce     = amt->nonce;
669         skb_push(skb, sizeof(*udph));
670         skb_reset_transport_header(skb);
671         udph            = udp_hdr(skb);
672         udph->source    = amt->gw_port;
673         udph->dest      = amt->relay_port;
674         udph->len       = htons(sizeof(*udph) + sizeof(*amtd));
675         udph->check     = 0;
676         offset = skb_transport_offset(skb);
677         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
678         udph->check = csum_tcpudp_magic(amt->local_ip, amt->discovery_ip,
679                                         sizeof(*udph) + sizeof(*amtd),
680                                         IPPROTO_UDP, skb->csum);
681
682         skb_push(skb, sizeof(*iph));
683         iph             = ip_hdr(skb);
684         iph->version    = 4;
685         iph->ihl        = (sizeof(struct iphdr)) >> 2;
686         iph->tos        = AMT_TOS;
687         iph->frag_off   = 0;
688         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
689         iph->daddr      = amt->discovery_ip;
690         iph->saddr      = amt->local_ip;
691         iph->protocol   = IPPROTO_UDP;
692         iph->tot_len    = htons(len);
693
694         skb->ip_summed = CHECKSUM_NONE;
695         ip_select_ident(amt->net, skb, NULL);
696         ip_send_check(iph);
697         err = ip_local_out(amt->net, sock->sk, skb);
698         if (unlikely(net_xmit_eval(err)))
699                 amt->dev->stats.tx_errors++;
700
701         spin_lock_bh(&amt->lock);
702         __amt_update_gw_status(amt, AMT_STATUS_SENT_DISCOVERY, true);
703         spin_unlock_bh(&amt->lock);
704 out:
705         rcu_read_unlock();
706 }
707
708 static void amt_send_request(struct amt_dev *amt, bool v6)
709 {
710         struct amt_header_request *amtrh;
711         int hlen, tlen, offset;
712         struct socket *sock;
713         struct udphdr *udph;
714         struct sk_buff *skb;
715         struct iphdr *iph;
716         struct rtable *rt;
717         struct flowi4 fl4;
718         u32 len;
719         int err;
720
721         rcu_read_lock();
722         sock = rcu_dereference(amt->sock);
723         if (!sock)
724                 goto out;
725
726         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
727                 goto out;
728
729         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
730                                    amt->remote_ip, amt->local_ip,
731                                    amt->gw_port, amt->relay_port,
732                                    IPPROTO_UDP, 0,
733                                    amt->stream_dev->ifindex);
734         if (IS_ERR(rt)) {
735                 amt->dev->stats.tx_errors++;
736                 goto out;
737         }
738
739         hlen = LL_RESERVED_SPACE(amt->dev);
740         tlen = amt->dev->needed_tailroom;
741         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
742         skb = netdev_alloc_skb_ip_align(amt->dev, len);
743         if (!skb) {
744                 ip_rt_put(rt);
745                 amt->dev->stats.tx_errors++;
746                 goto out;
747         }
748
749         skb->priority = TC_PRIO_CONTROL;
750         skb_dst_set(skb, &rt->dst);
751
752         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
753         skb_reset_network_header(skb);
754         skb_put(skb, len);
755         amtrh = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
756         amtrh->version   = 0;
757         amtrh->type      = AMT_MSG_REQUEST;
758         amtrh->reserved1 = 0;
759         amtrh->p         = v6;
760         amtrh->reserved2 = 0;
761         amtrh->nonce     = amt->nonce;
762         skb_push(skb, sizeof(*udph));
763         skb_reset_transport_header(skb);
764         udph            = udp_hdr(skb);
765         udph->source    = amt->gw_port;
766         udph->dest      = amt->relay_port;
767         udph->len       = htons(sizeof(*amtrh) + sizeof(*udph));
768         udph->check     = 0;
769         offset = skb_transport_offset(skb);
770         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
771         udph->check = csum_tcpudp_magic(amt->local_ip, amt->remote_ip,
772                                         sizeof(*udph) + sizeof(*amtrh),
773                                         IPPROTO_UDP, skb->csum);
774
775         skb_push(skb, sizeof(*iph));
776         iph             = ip_hdr(skb);
777         iph->version    = 4;
778         iph->ihl        = (sizeof(struct iphdr)) >> 2;
779         iph->tos        = AMT_TOS;
780         iph->frag_off   = 0;
781         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
782         iph->daddr      = amt->remote_ip;
783         iph->saddr      = amt->local_ip;
784         iph->protocol   = IPPROTO_UDP;
785         iph->tot_len    = htons(len);
786
787         skb->ip_summed = CHECKSUM_NONE;
788         ip_select_ident(amt->net, skb, NULL);
789         ip_send_check(iph);
790         err = ip_local_out(amt->net, sock->sk, skb);
791         if (unlikely(net_xmit_eval(err)))
792                 amt->dev->stats.tx_errors++;
793
794 out:
795         rcu_read_unlock();
796 }
797
798 static void amt_send_igmp_gq(struct amt_dev *amt,
799                              struct amt_tunnel_list *tunnel)
800 {
801         struct sk_buff *skb;
802
803         skb = amt_build_igmp_gq(amt);
804         if (!skb)
805                 return;
806
807         amt_skb_cb(skb)->tunnel = tunnel;
808         dev_queue_xmit(skb);
809 }
810
811 #if IS_ENABLED(CONFIG_IPV6)
812 static struct sk_buff *amt_build_mld_gq(struct amt_dev *amt)
813 {
814         u8 ra[AMT_IP6HDR_OPTS] = { IPPROTO_ICMPV6, 0, IPV6_TLV_ROUTERALERT,
815                                    2, 0, 0, IPV6_TLV_PAD1, IPV6_TLV_PAD1 };
816         int hlen = LL_RESERVED_SPACE(amt->dev);
817         int tlen = amt->dev->needed_tailroom;
818         struct mld2_query *mld2q;
819         void *csum_start = NULL;
820         struct ipv6hdr *ip6h;
821         struct sk_buff *skb;
822         struct ethhdr *eth;
823         u32 len;
824
825         len = hlen + tlen + sizeof(*ip6h) + sizeof(ra) + sizeof(*mld2q);
826         skb = netdev_alloc_skb_ip_align(amt->dev, len);
827         if (!skb)
828                 return NULL;
829
830         skb_reserve(skb, hlen);
831         skb_push(skb, sizeof(*eth));
832         skb_reset_mac_header(skb);
833         eth = eth_hdr(skb);
834         skb->priority = TC_PRIO_CONTROL;
835         skb->protocol = htons(ETH_P_IPV6);
836         skb_put_zero(skb, sizeof(*ip6h));
837         skb_put_data(skb, ra, sizeof(ra));
838         skb_put_zero(skb, sizeof(*mld2q));
839         skb_pull(skb, sizeof(*eth));
840         skb_reset_network_header(skb);
841         ip6h                    = ipv6_hdr(skb);
842         ip6h->payload_len       = htons(sizeof(ra) + sizeof(*mld2q));
843         ip6h->nexthdr           = NEXTHDR_HOP;
844         ip6h->hop_limit         = 1;
845         ip6h->daddr             = mld2_all_node;
846         ip6_flow_hdr(ip6h, 0, 0);
847
848         if (ipv6_dev_get_saddr(amt->net, amt->dev, &ip6h->daddr, 0,
849                                &ip6h->saddr)) {
850                 amt->dev->stats.tx_errors++;
851                 kfree_skb(skb);
852                 return NULL;
853         }
854
855         eth->h_proto = htons(ETH_P_IPV6);
856         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
857         ipv6_eth_mc_map(&mld2_all_node, eth->h_dest);
858
859         skb_pull(skb, sizeof(*ip6h) + sizeof(ra));
860         skb_reset_transport_header(skb);
861         mld2q                   = (struct mld2_query *)icmp6_hdr(skb);
862         mld2q->mld2q_mrc        = htons(1);
863         mld2q->mld2q_type       = ICMPV6_MGM_QUERY;
864         mld2q->mld2q_code       = 0;
865         mld2q->mld2q_cksum      = 0;
866         mld2q->mld2q_resv1      = 0;
867         mld2q->mld2q_resv2      = 0;
868         mld2q->mld2q_suppress   = 0;
869         mld2q->mld2q_qrv        = amt->qrv;
870         mld2q->mld2q_nsrcs      = 0;
871         mld2q->mld2q_qqic       = amt->qi;
872         csum_start              = (void *)mld2q;
873         mld2q->mld2q_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
874                                              sizeof(*mld2q),
875                                              IPPROTO_ICMPV6,
876                                              csum_partial(csum_start,
877                                                           sizeof(*mld2q), 0));
878
879         skb->ip_summed = CHECKSUM_NONE;
880         skb_push(skb, sizeof(*eth) + sizeof(*ip6h) + sizeof(ra));
881         return skb;
882 }
883
884 static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
885 {
886         struct sk_buff *skb;
887
888         skb = amt_build_mld_gq(amt);
889         if (!skb)
890                 return;
891
892         amt_skb_cb(skb)->tunnel = tunnel;
893         dev_queue_xmit(skb);
894 }
895 #else
/* Variant built when CONFIG_IPV6 is disabled: intentionally does nothing. */
static void amt_send_mld_gq(struct amt_dev *amt,
                            struct amt_tunnel_list *tunnel)
{
}
899 #endif
900
901 static void amt_secret_work(struct work_struct *work)
902 {
903         struct amt_dev *amt = container_of(to_delayed_work(work),
904                                            struct amt_dev,
905                                            secret_wq);
906
907         spin_lock_bh(&amt->lock);
908         get_random_bytes(&amt->key, sizeof(siphash_key_t));
909         spin_unlock_bh(&amt->lock);
910         mod_delayed_work(amt_wq, &amt->secret_wq,
911                          msecs_to_jiffies(AMT_SECRET_TIMEOUT));
912 }
913
914 static void amt_discovery_work(struct work_struct *work)
915 {
916         struct amt_dev *amt = container_of(to_delayed_work(work),
917                                            struct amt_dev,
918                                            discovery_wq);
919
920         spin_lock_bh(&amt->lock);
921         if (amt->status > AMT_STATUS_SENT_DISCOVERY)
922                 goto out;
923         get_random_bytes(&amt->nonce, sizeof(__be32));
924         spin_unlock_bh(&amt->lock);
925
926         amt_send_discovery(amt);
927         spin_lock_bh(&amt->lock);
928 out:
929         mod_delayed_work(amt_wq, &amt->discovery_wq,
930                          msecs_to_jiffies(AMT_DISCOVERY_TIMEOUT));
931         spin_unlock_bh(&amt->lock);
932 }
933
934 static void amt_req_work(struct work_struct *work)
935 {
936         struct amt_dev *amt = container_of(to_delayed_work(work),
937                                            struct amt_dev,
938                                            req_wq);
939         u32 exp;
940
941         spin_lock_bh(&amt->lock);
942         if (amt->status < AMT_STATUS_RECEIVED_ADVERTISEMENT)
943                 goto out;
944
945         if (amt->req_cnt++ > AMT_MAX_REQ_COUNT) {
946                 netdev_dbg(amt->dev, "Gateway is not ready");
947                 amt->qi = AMT_INIT_REQ_TIMEOUT;
948                 amt->ready4 = false;
949                 amt->ready6 = false;
950                 amt->remote_ip = 0;
951                 __amt_update_gw_status(amt, AMT_STATUS_INIT, false);
952                 amt->req_cnt = 0;
953         }
954         spin_unlock_bh(&amt->lock);
955
956         amt_send_request(amt, false);
957         amt_send_request(amt, true);
958         amt_update_gw_status(amt, AMT_STATUS_SENT_REQUEST, true);
959         spin_lock_bh(&amt->lock);
960 out:
961         exp = min_t(u32, (1 * (1 << amt->req_cnt)), AMT_MAX_REQ_TIMEOUT);
962         mod_delayed_work(amt_wq, &amt->req_wq, msecs_to_jiffies(exp * 1000));
963         spin_unlock_bh(&amt->lock);
964 }
965
966 static bool amt_send_membership_update(struct amt_dev *amt,
967                                        struct sk_buff *skb,
968                                        bool v6)
969 {
970         struct amt_header_membership_update *amtmu;
971         struct socket *sock;
972         struct iphdr *iph;
973         struct flowi4 fl4;
974         struct rtable *rt;
975         int err;
976
977         sock = rcu_dereference_bh(amt->sock);
978         if (!sock)
979                 return true;
980
981         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmu) +
982                            sizeof(*iph) + sizeof(struct udphdr));
983         if (err)
984                 return true;
985
986         skb_reset_inner_headers(skb);
987         memset(&fl4, 0, sizeof(struct flowi4));
988         fl4.flowi4_oif         = amt->stream_dev->ifindex;
989         fl4.daddr              = amt->remote_ip;
990         fl4.saddr              = amt->local_ip;
991         fl4.flowi4_tos         = AMT_TOS;
992         fl4.flowi4_proto       = IPPROTO_UDP;
993         rt = ip_route_output_key(amt->net, &fl4);
994         if (IS_ERR(rt)) {
995                 netdev_dbg(amt->dev, "no route to %pI4\n", &amt->remote_ip);
996                 return true;
997         }
998
999         amtmu                   = skb_push(skb, sizeof(*amtmu));
1000         amtmu->version          = 0;
1001         amtmu->type             = AMT_MSG_MEMBERSHIP_UPDATE;
1002         amtmu->reserved         = 0;
1003         amtmu->nonce            = amt->nonce;
1004         amtmu->response_mac     = amt->mac;
1005
1006         if (!v6)
1007                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1008         else
1009                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1010         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1011                             fl4.saddr,
1012                             fl4.daddr,
1013                             AMT_TOS,
1014                             ip4_dst_hoplimit(&rt->dst),
1015                             0,
1016                             amt->gw_port,
1017                             amt->relay_port,
1018                             false,
1019                             false);
1020         amt_update_gw_status(amt, AMT_STATUS_SENT_UPDATE, true);
1021         return false;
1022 }
1023
1024 static void amt_send_multicast_data(struct amt_dev *amt,
1025                                     const struct sk_buff *oskb,
1026                                     struct amt_tunnel_list *tunnel,
1027                                     bool v6)
1028 {
1029         struct amt_header_mcast_data *amtmd;
1030         struct socket *sock;
1031         struct sk_buff *skb;
1032         struct iphdr *iph;
1033         struct flowi4 fl4;
1034         struct rtable *rt;
1035
1036         sock = rcu_dereference_bh(amt->sock);
1037         if (!sock)
1038                 return;
1039
1040         skb = skb_copy_expand(oskb, sizeof(*amtmd) + sizeof(*iph) +
1041                               sizeof(struct udphdr), 0, GFP_ATOMIC);
1042         if (!skb)
1043                 return;
1044
1045         skb_reset_inner_headers(skb);
1046         memset(&fl4, 0, sizeof(struct flowi4));
1047         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1048         fl4.daddr              = tunnel->ip4;
1049         fl4.saddr              = amt->local_ip;
1050         fl4.flowi4_proto       = IPPROTO_UDP;
1051         rt = ip_route_output_key(amt->net, &fl4);
1052         if (IS_ERR(rt)) {
1053                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1054                 kfree_skb(skb);
1055                 return;
1056         }
1057
1058         amtmd = skb_push(skb, sizeof(*amtmd));
1059         amtmd->version = 0;
1060         amtmd->reserved = 0;
1061         amtmd->type = AMT_MSG_MULTICAST_DATA;
1062
1063         if (!v6)
1064                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1065         else
1066                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1067         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1068                             fl4.saddr,
1069                             fl4.daddr,
1070                             AMT_TOS,
1071                             ip4_dst_hoplimit(&rt->dst),
1072                             0,
1073                             amt->relay_port,
1074                             tunnel->source_port,
1075                             false,
1076                             false);
1077 }
1078
1079 static bool amt_send_membership_query(struct amt_dev *amt,
1080                                       struct sk_buff *skb,
1081                                       struct amt_tunnel_list *tunnel,
1082                                       bool v6)
1083 {
1084         struct amt_header_membership_query *amtmq;
1085         struct socket *sock;
1086         struct rtable *rt;
1087         struct flowi4 fl4;
1088         int err;
1089
1090         sock = rcu_dereference_bh(amt->sock);
1091         if (!sock)
1092                 return true;
1093
1094         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmq) +
1095                            sizeof(struct iphdr) + sizeof(struct udphdr));
1096         if (err)
1097                 return true;
1098
1099         skb_reset_inner_headers(skb);
1100         memset(&fl4, 0, sizeof(struct flowi4));
1101         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1102         fl4.daddr              = tunnel->ip4;
1103         fl4.saddr              = amt->local_ip;
1104         fl4.flowi4_tos         = AMT_TOS;
1105         fl4.flowi4_proto       = IPPROTO_UDP;
1106         rt = ip_route_output_key(amt->net, &fl4);
1107         if (IS_ERR(rt)) {
1108                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1109                 return true;
1110         }
1111
1112         amtmq           = skb_push(skb, sizeof(*amtmq));
1113         amtmq->version  = 0;
1114         amtmq->type     = AMT_MSG_MEMBERSHIP_QUERY;
1115         amtmq->reserved = 0;
1116         amtmq->l        = 0;
1117         amtmq->g        = 0;
1118         amtmq->nonce    = tunnel->nonce;
1119         amtmq->response_mac = tunnel->mac;
1120
1121         if (!v6)
1122                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1123         else
1124                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1125         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1126                             fl4.saddr,
1127                             fl4.daddr,
1128                             AMT_TOS,
1129                             ip4_dst_hoplimit(&rt->dst),
1130                             0,
1131                             amt->relay_port,
1132                             tunnel->source_port,
1133                             false,
1134                             false);
1135         amt_update_relay_status(tunnel, AMT_STATUS_SENT_QUERY, true);
1136         return false;
1137 }
1138
1139 static netdev_tx_t amt_dev_xmit(struct sk_buff *skb, struct net_device *dev)
1140 {
1141         struct amt_dev *amt = netdev_priv(dev);
1142         struct amt_tunnel_list *tunnel;
1143         struct amt_group_node *gnode;
1144         union amt_addr group = {0,};
1145 #if IS_ENABLED(CONFIG_IPV6)
1146         struct ipv6hdr *ip6h;
1147         struct mld_msg *mld;
1148 #endif
1149         bool report = false;
1150         struct igmphdr *ih;
1151         bool query = false;
1152         struct iphdr *iph;
1153         bool data = false;
1154         bool v6 = false;
1155         u32 hash;
1156
1157         iph = ip_hdr(skb);
1158         if (iph->version == 4) {
1159                 if (!ipv4_is_multicast(iph->daddr))
1160                         goto free;
1161
1162                 if (!ip_mc_check_igmp(skb)) {
1163                         ih = igmp_hdr(skb);
1164                         switch (ih->type) {
1165                         case IGMPV3_HOST_MEMBERSHIP_REPORT:
1166                         case IGMP_HOST_MEMBERSHIP_REPORT:
1167                                 report = true;
1168                                 break;
1169                         case IGMP_HOST_MEMBERSHIP_QUERY:
1170                                 query = true;
1171                                 break;
1172                         default:
1173                                 goto free;
1174                         }
1175                 } else {
1176                         data = true;
1177                 }
1178                 v6 = false;
1179                 group.ip4 = iph->daddr;
1180 #if IS_ENABLED(CONFIG_IPV6)
1181         } else if (iph->version == 6) {
1182                 ip6h = ipv6_hdr(skb);
1183                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
1184                         goto free;
1185
1186                 if (!ipv6_mc_check_mld(skb)) {
1187                         mld = (struct mld_msg *)skb_transport_header(skb);
1188                         switch (mld->mld_type) {
1189                         case ICMPV6_MGM_REPORT:
1190                         case ICMPV6_MLD2_REPORT:
1191                                 report = true;
1192                                 break;
1193                         case ICMPV6_MGM_QUERY:
1194                                 query = true;
1195                                 break;
1196                         default:
1197                                 goto free;
1198                         }
1199                 } else {
1200                         data = true;
1201                 }
1202                 v6 = true;
1203                 group.ip6 = ip6h->daddr;
1204 #endif
1205         } else {
1206                 dev->stats.tx_errors++;
1207                 goto free;
1208         }
1209
1210         if (!pskb_may_pull(skb, sizeof(struct ethhdr)))
1211                 goto free;
1212
1213         skb_pull(skb, sizeof(struct ethhdr));
1214
1215         if (amt->mode == AMT_MODE_GATEWAY) {
1216                 /* Gateway only passes IGMP/MLD packets */
1217                 if (!report)
1218                         goto free;
1219                 if ((!v6 && !amt->ready4) || (v6 && !amt->ready6))
1220                         goto free;
1221                 if (amt_send_membership_update(amt, skb,  v6))
1222                         goto free;
1223                 goto unlock;
1224         } else if (amt->mode == AMT_MODE_RELAY) {
1225                 if (query) {
1226                         tunnel = amt_skb_cb(skb)->tunnel;
1227                         if (!tunnel) {
1228                                 WARN_ON(1);
1229                                 goto free;
1230                         }
1231
1232                         /* Do not forward unexpected query */
1233                         if (amt_send_membership_query(amt, skb, tunnel, v6))
1234                                 goto free;
1235                         goto unlock;
1236                 }
1237
1238                 if (!data)
1239                         goto free;
1240                 list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
1241                         hash = amt_group_hash(tunnel, &group);
1242                         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash],
1243                                                  node) {
1244                                 if (!v6) {
1245                                         if (gnode->group_addr.ip4 == iph->daddr)
1246                                                 goto found;
1247 #if IS_ENABLED(CONFIG_IPV6)
1248                                 } else {
1249                                         if (ipv6_addr_equal(&gnode->group_addr.ip6,
1250                                                             &ip6h->daddr))
1251                                                 goto found;
1252 #endif
1253                                 }
1254                         }
1255                         continue;
1256 found:
1257                         amt_send_multicast_data(amt, skb, tunnel, v6);
1258                 }
1259         }
1260
1261         dev_kfree_skb(skb);
1262         return NETDEV_TX_OK;
1263 free:
1264         dev_kfree_skb(skb);
1265 unlock:
1266         dev->stats.tx_dropped++;
1267         return NETDEV_TX_OK;
1268 }
1269
1270 static int amt_parse_type(struct sk_buff *skb)
1271 {
1272         struct amt_header *amth;
1273
1274         if (!pskb_may_pull(skb, sizeof(struct udphdr) +
1275                            sizeof(struct amt_header)))
1276                 return -1;
1277
1278         amth = (struct amt_header *)(udp_hdr(skb) + 1);
1279
1280         if (amth->version != 0)
1281                 return -1;
1282
1283         if (amth->type >= __AMT_MSG_MAX || !amth->type)
1284                 return -1;
1285         return amth->type;
1286 }
1287
1288 static void amt_clear_groups(struct amt_tunnel_list *tunnel)
1289 {
1290         struct amt_dev *amt = tunnel->amt;
1291         struct amt_group_node *gnode;
1292         struct hlist_node *t;
1293         int i;
1294
1295         spin_lock_bh(&tunnel->lock);
1296         rcu_read_lock();
1297         for (i = 0; i < amt->hash_buckets; i++)
1298                 hlist_for_each_entry_safe(gnode, t, &tunnel->groups[i], node)
1299                         amt_del_group(amt, gnode);
1300         rcu_read_unlock();
1301         spin_unlock_bh(&tunnel->lock);
1302 }
1303
1304 static void amt_tunnel_expire(struct work_struct *work)
1305 {
1306         struct amt_tunnel_list *tunnel = container_of(to_delayed_work(work),
1307                                                       struct amt_tunnel_list,
1308                                                       gc_wq);
1309         struct amt_dev *amt = tunnel->amt;
1310
1311         spin_lock_bh(&amt->lock);
1312         rcu_read_lock();
1313         list_del_rcu(&tunnel->list);
1314         amt->nr_tunnels--;
1315         amt_clear_groups(tunnel);
1316         rcu_read_unlock();
1317         spin_unlock_bh(&amt->lock);
1318         kfree_rcu(tunnel, rcu);
1319 }
1320
1321 static void amt_cleanup_srcs(struct amt_dev *amt,
1322                              struct amt_tunnel_list *tunnel,
1323                              struct amt_group_node *gnode)
1324 {
1325         struct amt_source_node *snode;
1326         struct hlist_node *t;
1327         int i;
1328
1329         /* Delete old sources */
1330         for (i = 0; i < amt->hash_buckets; i++) {
1331                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node) {
1332                         if (snode->flags == AMT_SOURCE_OLD)
1333                                 amt_destroy_source(snode);
1334                 }
1335         }
1336
1337         /* switch from new to old */
1338         for (i = 0; i < amt->hash_buckets; i++)  {
1339                 hlist_for_each_entry_rcu(snode, &gnode->sources[i], node) {
1340                         snode->flags = AMT_SOURCE_OLD;
1341                         if (!gnode->v6)
1342                                 netdev_dbg(snode->gnode->amt->dev,
1343                                            "Add source as OLD %pI4 from %pI4\n",
1344                                            &snode->source_addr.ip4,
1345                                            &gnode->group_addr.ip4);
1346 #if IS_ENABLED(CONFIG_IPV6)
1347                         else
1348                                 netdev_dbg(snode->gnode->amt->dev,
1349                                            "Add source as OLD %pI6 from %pI6\n",
1350                                            &snode->source_addr.ip6,
1351                                            &gnode->group_addr.ip6);
1352 #endif
1353                 }
1354         }
1355 }
1356
1357 static void amt_add_srcs(struct amt_dev *amt, struct amt_tunnel_list *tunnel,
1358                          struct amt_group_node *gnode, void *grec,
1359                          bool v6)
1360 {
1361         struct igmpv3_grec *igmp_grec;
1362         struct amt_source_node *snode;
1363 #if IS_ENABLED(CONFIG_IPV6)
1364         struct mld2_grec *mld_grec;
1365 #endif
1366         union amt_addr src = {0,};
1367         u16 nsrcs;
1368         u32 hash;
1369         int i;
1370
1371         if (!v6) {
1372                 igmp_grec = (struct igmpv3_grec *)grec;
1373                 nsrcs = ntohs(igmp_grec->grec_nsrcs);
1374         } else {
1375 #if IS_ENABLED(CONFIG_IPV6)
1376                 mld_grec = (struct mld2_grec *)grec;
1377                 nsrcs = ntohs(mld_grec->grec_nsrcs);
1378 #else
1379         return;
1380 #endif
1381         }
1382         for (i = 0; i < nsrcs; i++) {
1383                 if (tunnel->nr_sources >= amt->max_sources)
1384                         return;
1385                 if (!v6)
1386                         src.ip4 = igmp_grec->grec_src[i];
1387 #if IS_ENABLED(CONFIG_IPV6)
1388                 else
1389                         memcpy(&src.ip6, &mld_grec->grec_src[i],
1390                                sizeof(struct in6_addr));
1391 #endif
1392                 if (amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL, &src))
1393                         continue;
1394
1395                 snode = amt_alloc_snode(gnode, &src);
1396                 if (snode) {
1397                         hash = amt_source_hash(tunnel, &snode->source_addr);
1398                         hlist_add_head_rcu(&snode->node, &gnode->sources[hash]);
1399                         tunnel->nr_sources++;
1400                         gnode->nr_sources++;
1401
1402                         if (!gnode->v6)
1403                                 netdev_dbg(snode->gnode->amt->dev,
1404                                            "Add source as NEW %pI4 from %pI4\n",
1405                                            &snode->source_addr.ip4,
1406                                            &gnode->group_addr.ip4);
1407 #if IS_ENABLED(CONFIG_IPV6)
1408                         else
1409                                 netdev_dbg(snode->gnode->amt->dev,
1410                                            "Add source as NEW %pI6 from %pI6\n",
1411                                            &snode->source_addr.ip6,
1412                                            &gnode->group_addr.ip6);
1413 #endif
1414                 }
1415         }
1416 }
1417
1418 /* Router State   Report Rec'd New Router State
1419  * ------------   ------------ ----------------
1420  * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)
1421  *
1422  * -----------+-----------+-----------+
1423  *            |    OLD    |    NEW    |
1424  * -----------+-----------+-----------+
1425  *    FWD     |     X     |    X+A    |
1426  * -----------+-----------+-----------+
1427  *    D_FWD   |     Y     |    Y-A    |
1428  * -----------+-----------+-----------+
1429  *    NONE    |           |     A     |
1430  * -----------+-----------+-----------+
1431  *
1432  * a) Received sources are NONE/NEW
1433  * b) All NONE will be deleted by amt_cleanup_srcs().
1434  * c) All OLD will be deleted by amt_cleanup_srcs().
1435  * d) After delete, NEW source will be switched to OLD.
1436  */
1437 static void amt_lookup_act_srcs(struct amt_tunnel_list *tunnel,
1438                                 struct amt_group_node *gnode,
1439                                 void *grec,
1440                                 enum amt_ops ops,
1441                                 enum amt_filter filter,
1442                                 enum amt_act act,
1443                                 bool v6)
1444 {
1445         struct amt_dev *amt = tunnel->amt;
1446         struct amt_source_node *snode;
1447         struct igmpv3_grec *igmp_grec;
1448 #if IS_ENABLED(CONFIG_IPV6)
1449         struct mld2_grec *mld_grec;
1450 #endif
1451         union amt_addr src = {0,};
1452         struct hlist_node *t;
1453         u16 nsrcs;
1454         int i, j;
1455
1456         if (!v6) {
1457                 igmp_grec = (struct igmpv3_grec *)grec;
1458                 nsrcs = ntohs(igmp_grec->grec_nsrcs);
1459         } else {
1460 #if IS_ENABLED(CONFIG_IPV6)
1461                 mld_grec = (struct mld2_grec *)grec;
1462                 nsrcs = ntohs(mld_grec->grec_nsrcs);
1463 #else
1464         return;
1465 #endif
1466         }
1467
1468         memset(&src, 0, sizeof(union amt_addr));
1469         switch (ops) {
1470         case AMT_OPS_INT:
1471                 /* A*B */
1472                 for (i = 0; i < nsrcs; i++) {
1473                         if (!v6)
1474                                 src.ip4 = igmp_grec->grec_src[i];
1475 #if IS_ENABLED(CONFIG_IPV6)
1476                         else
1477                                 memcpy(&src.ip6, &mld_grec->grec_src[i],
1478                                        sizeof(struct in6_addr));
1479 #endif
1480                         snode = amt_lookup_src(tunnel, gnode, filter, &src);
1481                         if (!snode)
1482                                 continue;
1483                         amt_act_src(tunnel, gnode, snode, act);
1484                 }
1485                 break;
1486         case AMT_OPS_UNI:
1487                 /* A+B */
1488                 for (i = 0; i < amt->hash_buckets; i++) {
1489                         hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
1490                                                   node) {
1491                                 if (amt_status_filter(snode, filter))
1492                                         amt_act_src(tunnel, gnode, snode, act);
1493                         }
1494                 }
1495                 for (i = 0; i < nsrcs; i++) {
1496                         if (!v6)
1497                                 src.ip4 = igmp_grec->grec_src[i];
1498 #if IS_ENABLED(CONFIG_IPV6)
1499                         else
1500                                 memcpy(&src.ip6, &mld_grec->grec_src[i],
1501                                        sizeof(struct in6_addr));
1502 #endif
1503                         snode = amt_lookup_src(tunnel, gnode, filter, &src);
1504                         if (!snode)
1505                                 continue;
1506                         amt_act_src(tunnel, gnode, snode, act);
1507                 }
1508                 break;
1509         case AMT_OPS_SUB:
1510                 /* A-B */
1511                 for (i = 0; i < amt->hash_buckets; i++) {
1512                         hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
1513                                                   node) {
1514                                 if (!amt_status_filter(snode, filter))
1515                                         continue;
1516                                 for (j = 0; j < nsrcs; j++) {
1517                                         if (!v6)
1518                                                 src.ip4 = igmp_grec->grec_src[j];
1519 #if IS_ENABLED(CONFIG_IPV6)
1520                                         else
1521                                                 memcpy(&src.ip6,
1522                                                        &mld_grec->grec_src[j],
1523                                                        sizeof(struct in6_addr));
1524 #endif
1525                                         if (amt_addr_equal(&snode->source_addr,
1526                                                            &src))
1527                                                 goto out_sub;
1528                                 }
1529                                 amt_act_src(tunnel, gnode, snode, act);
1530                                 continue;
1531 out_sub:;
1532                         }
1533                 }
1534                 break;
1535         case AMT_OPS_SUB_REV:
1536                 /* B-A */
1537                 for (i = 0; i < nsrcs; i++) {
1538                         if (!v6)
1539                                 src.ip4 = igmp_grec->grec_src[i];
1540 #if IS_ENABLED(CONFIG_IPV6)
1541                         else
1542                                 memcpy(&src.ip6, &mld_grec->grec_src[i],
1543                                        sizeof(struct in6_addr));
1544 #endif
1545                         snode = amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL,
1546                                                &src);
1547                         if (!snode) {
1548                                 snode = amt_lookup_src(tunnel, gnode,
1549                                                        filter, &src);
1550                                 if (snode)
1551                                         amt_act_src(tunnel, gnode, snode, act);
1552                         }
1553                 }
1554                 break;
1555         default:
1556                 netdev_dbg(amt->dev, "Invalid type\n");
1557                 return;
1558         }
1559 }
1560
1561 static void amt_mcast_is_in_handler(struct amt_dev *amt,
1562                                     struct amt_tunnel_list *tunnel,
1563                                     struct amt_group_node *gnode,
1564                                     void *grec, void *zero_grec, bool v6)
1565 {
1566         if (gnode->filter_mode == MCAST_INCLUDE) {
1567 /* Router State   Report Rec'd New Router State        Actions
1568  * ------------   ------------ ----------------        -------
1569  * INCLUDE (A)    IS_IN (B)    INCLUDE (A+B)           (B)=GMI
1570  */
1571                 /* Update IS_IN (B) as FWD/NEW */
1572                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1573                                     AMT_FILTER_NONE_NEW,
1574                                     AMT_ACT_STATUS_FWD_NEW,
1575                                     v6);
1576                 /* Update INCLUDE (A) as NEW */
1577                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1578                                     AMT_FILTER_FWD,
1579                                     AMT_ACT_STATUS_FWD_NEW,
1580                                     v6);
1581                 /* (B)=GMI */
1582                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1583                                     AMT_FILTER_FWD_NEW,
1584                                     AMT_ACT_GMI,
1585                                     v6);
1586         } else {
1587 /* State        Actions
1588  * ------------   ------------ ----------------        -------
1589  * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1590  */
1591                 /* Update (A) in (X, Y) as NONE/NEW */
1592                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1593                                     AMT_FILTER_BOTH,
1594                                     AMT_ACT_STATUS_NONE_NEW,
1595                                     v6);
1596                 /* Update FWD/OLD as FWD/NEW */
1597                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1598                                     AMT_FILTER_FWD,
1599                                     AMT_ACT_STATUS_FWD_NEW,
1600                                     v6);
1601                 /* Update IS_IN (A) as FWD/NEW */
1602                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1603                                     AMT_FILTER_NONE_NEW,
1604                                     AMT_ACT_STATUS_FWD_NEW,
1605                                     v6);
1606                 /* Update EXCLUDE (, Y-A) as D_FWD_NEW */
1607                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
1608                                     AMT_FILTER_D_FWD,
1609                                     AMT_ACT_STATUS_D_FWD_NEW,
1610                                     v6);
1611         }
1612 }
1613
1614 static void amt_mcast_is_ex_handler(struct amt_dev *amt,
1615                                     struct amt_tunnel_list *tunnel,
1616                                     struct amt_group_node *gnode,
1617                                     void *grec, void *zero_grec, bool v6)
1618 {
1619         if (gnode->filter_mode == MCAST_INCLUDE) {
1620 /* Router State   Report Rec'd  New Router State         Actions
1621  * ------------   ------------  ----------------         -------
1622  * INCLUDE (A)    IS_EX (B)     EXCLUDE (A*B,B-A)        (B-A)=0
1623  *                                                       Delete (A-B)
1624  *                                                       Group Timer=GMI
1625  */
1626                 /* EXCLUDE(A*B, ) */
1627                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1628                                     AMT_FILTER_FWD,
1629                                     AMT_ACT_STATUS_FWD_NEW,
1630                                     v6);
1631                 /* EXCLUDE(, B-A) */
1632                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1633                                     AMT_FILTER_FWD,
1634                                     AMT_ACT_STATUS_D_FWD_NEW,
1635                                     v6);
1636                 /* (B-A)=0 */
1637                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1638                                     AMT_FILTER_D_FWD_NEW,
1639                                     AMT_ACT_GMI_ZERO,
1640                                     v6);
1641                 /* Group Timer=GMI */
1642                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1643                                       msecs_to_jiffies(amt_gmi(amt))))
1644                         dev_hold(amt->dev);
1645                 gnode->filter_mode = MCAST_EXCLUDE;
1646                 /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
1647         } else {
1648 /* Router State   Report Rec'd  New Router State        Actions
1649  * ------------   ------------  ----------------        -------
1650  * EXCLUDE (X,Y)  IS_EX (A)     EXCLUDE (A-Y,Y*A)       (A-X-Y)=GMI
1651  *                                                      Delete (X-A)
1652  *                                                      Delete (Y-A)
1653  *                                                      Group Timer=GMI
1654  */
1655                 /* EXCLUDE (A-Y, ) */
1656                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1657                                     AMT_FILTER_D_FWD,
1658                                     AMT_ACT_STATUS_FWD_NEW,
1659                                     v6);
1660                 /* EXCLUDE (, Y*A ) */
1661                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1662                                     AMT_FILTER_D_FWD,
1663                                     AMT_ACT_STATUS_D_FWD_NEW,
1664                                     v6);
1665                 /* (A-X-Y)=GMI */
1666                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1667                                     AMT_FILTER_BOTH_NEW,
1668                                     AMT_ACT_GMI,
1669                                     v6);
1670                 /* Group Timer=GMI */
1671                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1672                                       msecs_to_jiffies(amt_gmi(amt))))
1673                         dev_hold(amt->dev);
1674                 /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
1675         }
1676 }
1677
1678 static void amt_mcast_to_in_handler(struct amt_dev *amt,
1679                                     struct amt_tunnel_list *tunnel,
1680                                     struct amt_group_node *gnode,
1681                                     void *grec, void *zero_grec, bool v6)
1682 {
1683         if (gnode->filter_mode == MCAST_INCLUDE) {
1684 /* Router State   Report Rec'd New Router State        Actions
1685  * ------------   ------------ ----------------        -------
1686  * INCLUDE (A)    TO_IN (B)    INCLUDE (A+B)           (B)=GMI
1687  *                                                     Send Q(G,A-B)
1688  */
1689                 /* Update TO_IN (B) sources as FWD/NEW */
1690                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1691                                     AMT_FILTER_NONE_NEW,
1692                                     AMT_ACT_STATUS_FWD_NEW,
1693                                     v6);
1694                 /* Update INCLUDE (A) sources as NEW */
1695                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1696                                     AMT_FILTER_FWD,
1697                                     AMT_ACT_STATUS_FWD_NEW,
1698                                     v6);
1699                 /* (B)=GMI */
1700                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1701                                     AMT_FILTER_FWD_NEW,
1702                                     AMT_ACT_GMI,
1703                                     v6);
1704         } else {
1705 /* Router State   Report Rec'd New Router State        Actions
1706  * ------------   ------------ ----------------        -------
1707  * EXCLUDE (X,Y)  TO_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1708  *                                                     Send Q(G,X-A)
1709  *                                                     Send Q(G)
1710  */
1711                 /* Update TO_IN (A) sources as FWD/NEW */
1712                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1713                                     AMT_FILTER_NONE_NEW,
1714                                     AMT_ACT_STATUS_FWD_NEW,
1715                                     v6);
1716                 /* Update EXCLUDE(X,) sources as FWD/NEW */
1717                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1718                                     AMT_FILTER_FWD,
1719                                     AMT_ACT_STATUS_FWD_NEW,
1720                                     v6);
1721                 /* EXCLUDE (, Y-A)
1722                  * (A) are already switched to FWD_NEW.
1723                  * So, D_FWD/OLD -> D_FWD/NEW is okay.
1724                  */
1725                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1726                                     AMT_FILTER_D_FWD,
1727                                     AMT_ACT_STATUS_D_FWD_NEW,
1728                                     v6);
1729                 /* (A)=GMI
1730                  * Only FWD_NEW will have (A) sources.
1731                  */
1732                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1733                                     AMT_FILTER_FWD_NEW,
1734                                     AMT_ACT_GMI,
1735                                     v6);
1736         }
1737 }
1738
1739 static void amt_mcast_to_ex_handler(struct amt_dev *amt,
1740                                     struct amt_tunnel_list *tunnel,
1741                                     struct amt_group_node *gnode,
1742                                     void *grec, void *zero_grec, bool v6)
1743 {
1744         if (gnode->filter_mode == MCAST_INCLUDE) {
1745 /* Router State   Report Rec'd New Router State        Actions
1746  * ------------   ------------ ----------------        -------
1747  * INCLUDE (A)    TO_EX (B)    EXCLUDE (A*B,B-A)       (B-A)=0
1748  *                                                     Delete (A-B)
1749  *                                                     Send Q(G,A*B)
1750  *                                                     Group Timer=GMI
1751  */
1752                 /* EXCLUDE (A*B, ) */
1753                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1754                                     AMT_FILTER_FWD,
1755                                     AMT_ACT_STATUS_FWD_NEW,
1756                                     v6);
1757                 /* EXCLUDE (, B-A) */
1758                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1759                                     AMT_FILTER_FWD,
1760                                     AMT_ACT_STATUS_D_FWD_NEW,
1761                                     v6);
1762                 /* (B-A)=0 */
1763                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1764                                     AMT_FILTER_D_FWD_NEW,
1765                                     AMT_ACT_GMI_ZERO,
1766                                     v6);
1767                 /* Group Timer=GMI */
1768                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1769                                       msecs_to_jiffies(amt_gmi(amt))))
1770                         dev_hold(amt->dev);
1771                 gnode->filter_mode = MCAST_EXCLUDE;
1772                 /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
1773         } else {
1774 /* Router State   Report Rec'd New Router State        Actions
1775  * ------------   ------------ ----------------        -------
1776  * EXCLUDE (X,Y)  TO_EX (A)    EXCLUDE (A-Y,Y*A)       (A-X-Y)=Group Timer
1777  *                                                     Delete (X-A)
1778  *                                                     Delete (Y-A)
1779  *                                                     Send Q(G,A-Y)
1780  *                                                     Group Timer=GMI
1781  */
1782                 /* Update (A-X-Y) as NONE/OLD */
1783                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1784                                     AMT_FILTER_BOTH,
1785                                     AMT_ACT_GT,
1786                                     v6);
1787                 /* EXCLUDE (A-Y, ) */
1788                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1789                                     AMT_FILTER_D_FWD,
1790                                     AMT_ACT_STATUS_FWD_NEW,
1791                                     v6);
1792                 /* EXCLUDE (, Y*A) */
1793                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1794                                     AMT_FILTER_D_FWD,
1795                                     AMT_ACT_STATUS_D_FWD_NEW,
1796                                     v6);
1797                 /* Group Timer=GMI */
1798                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1799                                       msecs_to_jiffies(amt_gmi(amt))))
1800                         dev_hold(amt->dev);
1801                 /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
1802         }
1803 }
1804
1805 static void amt_mcast_allow_handler(struct amt_dev *amt,
1806                                     struct amt_tunnel_list *tunnel,
1807                                     struct amt_group_node *gnode,
1808                                     void *grec, void *zero_grec, bool v6)
1809 {
1810         if (gnode->filter_mode == MCAST_INCLUDE) {
1811 /* Router State   Report Rec'd New Router State        Actions
1812  * ------------   ------------ ----------------        -------
1813  * INCLUDE (A)    ALLOW (B)    INCLUDE (A+B)           (B)=GMI
1814  */
1815                 /* INCLUDE (A+B) */
1816                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1817                                     AMT_FILTER_FWD,
1818                                     AMT_ACT_STATUS_FWD_NEW,
1819                                     v6);
1820                 /* (B)=GMI */
1821                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1822                                     AMT_FILTER_FWD_NEW,
1823                                     AMT_ACT_GMI,
1824                                     v6);
1825         } else {
1826 /* Router State   Report Rec'd New Router State        Actions
1827  * ------------   ------------ ----------------        -------
1828  * EXCLUDE (X,Y)  ALLOW (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1829  */
1830                 /* EXCLUDE (X+A, ) */
1831                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1832                                     AMT_FILTER_FWD,
1833                                     AMT_ACT_STATUS_FWD_NEW,
1834                                     v6);
1835                 /* EXCLUDE (, Y-A) */
1836                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
1837                                     AMT_FILTER_D_FWD,
1838                                     AMT_ACT_STATUS_D_FWD_NEW,
1839                                     v6);
1840                 /* (A)=GMI
1841                  * All (A) source are now FWD/NEW status.
1842                  */
1843                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1844                                     AMT_FILTER_FWD_NEW,
1845                                     AMT_ACT_GMI,
1846                                     v6);
1847         }
1848 }
1849
1850 static void amt_mcast_block_handler(struct amt_dev *amt,
1851                                     struct amt_tunnel_list *tunnel,
1852                                     struct amt_group_node *gnode,
1853                                     void *grec, void *zero_grec, bool v6)
1854 {
1855         if (gnode->filter_mode == MCAST_INCLUDE) {
1856 /* Router State   Report Rec'd New Router State        Actions
1857  * ------------   ------------ ----------------        -------
1858  * INCLUDE (A)    BLOCK (B)    INCLUDE (A)             Send Q(G,A*B)
1859  */
1860                 /* INCLUDE (A) */
1861                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1862                                     AMT_FILTER_FWD,
1863                                     AMT_ACT_STATUS_FWD_NEW,
1864                                     v6);
1865         } else {
1866 /* Router State   Report Rec'd New Router State        Actions
1867  * ------------   ------------ ----------------        -------
1868  * EXCLUDE (X,Y)  BLOCK (A)    EXCLUDE (X+(A-Y),Y)     (A-X-Y)=Group Timer
1869  *                                                     Send Q(G,A-Y)
1870  */
1871                 /* (A-X-Y)=Group Timer */
1872                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1873                                     AMT_FILTER_BOTH,
1874                                     AMT_ACT_GT,
1875                                     v6);
1876                 /* EXCLUDE (X, ) */
1877                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1878                                     AMT_FILTER_FWD,
1879                                     AMT_ACT_STATUS_FWD_NEW,
1880                                     v6);
1881                 /* EXCLUDE (X+(A-Y) */
1882                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1883                                     AMT_FILTER_D_FWD,
1884                                     AMT_ACT_STATUS_FWD_NEW,
1885                                     v6);
1886                 /* EXCLUDE (, Y) */
1887                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1888                                     AMT_FILTER_D_FWD,
1889                                     AMT_ACT_STATUS_D_FWD_NEW,
1890                                     v6);
1891         }
1892 }
1893
1894 /* RFC 3376
1895  * 7.3.2. In the Presence of Older Version Group Members
1896  *
1897  * When Group Compatibility Mode is IGMPv2, a router internally
1898  * translates the following IGMPv2 messages for that group to their
1899  * IGMPv3 equivalents:
1900  *
1901  * IGMPv2 Message                IGMPv3 Equivalent
1902  * --------------                -----------------
1903  * Report                        IS_EX( {} )
1904  * Leave                         TO_IN( {} )
1905  */
1906 static void amt_igmpv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1907                                       struct amt_tunnel_list *tunnel)
1908 {
1909         struct igmphdr *ih = igmp_hdr(skb);
1910         struct iphdr *iph = ip_hdr(skb);
1911         struct amt_group_node *gnode;
1912         union amt_addr group, host;
1913
1914         memset(&group, 0, sizeof(union amt_addr));
1915         group.ip4 = ih->group;
1916         memset(&host, 0, sizeof(union amt_addr));
1917         host.ip4 = iph->saddr;
1918
1919         gnode = amt_lookup_group(tunnel, &group, &host, false);
1920         if (!gnode) {
1921                 gnode = amt_add_group(amt, tunnel, &group, &host, false);
1922                 if (!IS_ERR(gnode)) {
1923                         gnode->filter_mode = MCAST_EXCLUDE;
1924                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1925                                               msecs_to_jiffies(amt_gmi(amt))))
1926                                 dev_hold(amt->dev);
1927                 }
1928         }
1929 }
1930
1931 /* RFC 3376
1932  * 7.3.2. In the Presence of Older Version Group Members
1933  *
1934  * When Group Compatibility Mode is IGMPv2, a router internally
1935  * translates the following IGMPv2 messages for that group to their
1936  * IGMPv3 equivalents:
1937  *
1938  * IGMPv2 Message                IGMPv3 Equivalent
1939  * --------------                -----------------
1940  * Report                        IS_EX( {} )
1941  * Leave                         TO_IN( {} )
1942  */
1943 static void amt_igmpv2_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
1944                                      struct amt_tunnel_list *tunnel)
1945 {
1946         struct igmphdr *ih = igmp_hdr(skb);
1947         struct iphdr *iph = ip_hdr(skb);
1948         struct amt_group_node *gnode;
1949         union amt_addr group, host;
1950
1951         memset(&group, 0, sizeof(union amt_addr));
1952         group.ip4 = ih->group;
1953         memset(&host, 0, sizeof(union amt_addr));
1954         host.ip4 = iph->saddr;
1955
1956         gnode = amt_lookup_group(tunnel, &group, &host, false);
1957         if (gnode)
1958                 amt_del_group(amt, gnode);
1959 }
1960
1961 static void amt_igmpv3_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1962                                       struct amt_tunnel_list *tunnel)
1963 {
1964         struct igmpv3_report *ihrv3 = igmpv3_report_hdr(skb);
1965         int len = skb_transport_offset(skb) + sizeof(*ihrv3);
1966         void *zero_grec = (void *)&igmpv3_zero_grec;
1967         struct iphdr *iph = ip_hdr(skb);
1968         struct amt_group_node *gnode;
1969         union amt_addr group, host;
1970         struct igmpv3_grec *grec;
1971         u16 nsrcs;
1972         int i;
1973
1974         for (i = 0; i < ntohs(ihrv3->ngrec); i++) {
1975                 len += sizeof(*grec);
1976                 if (!ip_mc_may_pull(skb, len))
1977                         break;
1978
1979                 grec = (void *)(skb->data + len - sizeof(*grec));
1980                 nsrcs = ntohs(grec->grec_nsrcs);
1981
1982                 len += nsrcs * sizeof(__be32);
1983                 if (!ip_mc_may_pull(skb, len))
1984                         break;
1985
1986                 memset(&group, 0, sizeof(union amt_addr));
1987                 group.ip4 = grec->grec_mca;
1988                 memset(&host, 0, sizeof(union amt_addr));
1989                 host.ip4 = iph->saddr;
1990                 gnode = amt_lookup_group(tunnel, &group, &host, false);
1991                 if (!gnode) {
1992                         gnode = amt_add_group(amt, tunnel, &group, &host,
1993                                               false);
1994                         if (IS_ERR(gnode))
1995                                 continue;
1996                 }
1997
1998                 amt_add_srcs(amt, tunnel, gnode, grec, false);
1999                 switch (grec->grec_type) {
2000                 case IGMPV3_MODE_IS_INCLUDE:
2001                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2002                                                 zero_grec, false);
2003                         break;
2004                 case IGMPV3_MODE_IS_EXCLUDE:
2005                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2006                                                 zero_grec, false);
2007                         break;
2008                 case IGMPV3_CHANGE_TO_INCLUDE:
2009                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2010                                                 zero_grec, false);
2011                         break;
2012                 case IGMPV3_CHANGE_TO_EXCLUDE:
2013                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2014                                                 zero_grec, false);
2015                         break;
2016                 case IGMPV3_ALLOW_NEW_SOURCES:
2017                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2018                                                 zero_grec, false);
2019                         break;
2020                 case IGMPV3_BLOCK_OLD_SOURCES:
2021                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2022                                                 zero_grec, false);
2023                         break;
2024                 default:
2025                         break;
2026                 }
2027                 amt_cleanup_srcs(amt, tunnel, gnode);
2028         }
2029 }
2030
2031 /* caller held tunnel->lock */
2032 static void amt_igmp_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2033                                     struct amt_tunnel_list *tunnel)
2034 {
2035         struct igmphdr *ih = igmp_hdr(skb);
2036
2037         switch (ih->type) {
2038         case IGMPV3_HOST_MEMBERSHIP_REPORT:
2039                 amt_igmpv3_report_handler(amt, skb, tunnel);
2040                 break;
2041         case IGMPV2_HOST_MEMBERSHIP_REPORT:
2042                 amt_igmpv2_report_handler(amt, skb, tunnel);
2043                 break;
2044         case IGMP_HOST_LEAVE_MESSAGE:
2045                 amt_igmpv2_leave_handler(amt, skb, tunnel);
2046                 break;
2047         default:
2048                 break;
2049         }
2050 }
2051
2052 #if IS_ENABLED(CONFIG_IPV6)
2053 /* RFC 3810
2054  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2055  *
2056  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2057  * using the MLDv2 protocol for that multicast address.  When Multicast
2058  * Address Compatibility Mode is MLDv1, a router internally translates
2059  * the following MLDv1 messages for that multicast address to their
2060  * MLDv2 equivalents:
2061  *
2062  * MLDv1 Message                 MLDv2 Equivalent
2063  * --------------                -----------------
2064  * Report                        IS_EX( {} )
2065  * Done                          TO_IN( {} )
2066  */
2067 static void amt_mldv1_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2068                                      struct amt_tunnel_list *tunnel)
2069 {
2070         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2071         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2072         struct amt_group_node *gnode;
2073         union amt_addr group, host;
2074
2075         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2076         memcpy(&host.ip6, &ip6h->saddr, sizeof(struct in6_addr));
2077
2078         gnode = amt_lookup_group(tunnel, &group, &host, true);
2079         if (!gnode) {
2080                 gnode = amt_add_group(amt, tunnel, &group, &host, true);
2081                 if (!IS_ERR(gnode)) {
2082                         gnode->filter_mode = MCAST_EXCLUDE;
2083                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
2084                                               msecs_to_jiffies(amt_gmi(amt))))
2085                                 dev_hold(amt->dev);
2086                 }
2087         }
2088 }
2089
2090 /* RFC 3810
2091  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2092  *
2093  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2094  * using the MLDv2 protocol for that multicast address.  When Multicast
2095  * Address Compatibility Mode is MLDv1, a router internally translates
2096  * the following MLDv1 messages for that multicast address to their
2097  * MLDv2 equivalents:
2098  *
2099  * MLDv1 Message                 MLDv2 Equivalent
2100  * --------------                -----------------
2101  * Report                        IS_EX( {} )
2102  * Done                          TO_IN( {} )
2103  */
2104 static void amt_mldv1_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
2105                                     struct amt_tunnel_list *tunnel)
2106 {
2107         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2108         struct iphdr *iph = ip_hdr(skb);
2109         struct amt_group_node *gnode;
2110         union amt_addr group, host;
2111
2112         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2113         memset(&host, 0, sizeof(union amt_addr));
2114         host.ip4 = iph->saddr;
2115
2116         gnode = amt_lookup_group(tunnel, &group, &host, true);
2117         if (gnode) {
2118                 amt_del_group(amt, gnode);
2119                 return;
2120         }
2121 }
2122
2123 static void amt_mldv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2124                                      struct amt_tunnel_list *tunnel)
2125 {
2126         struct mld2_report *mld2r = (struct mld2_report *)icmp6_hdr(skb);
2127         int len = skb_transport_offset(skb) + sizeof(*mld2r);
2128         void *zero_grec = (void *)&mldv2_zero_grec;
2129         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2130         struct amt_group_node *gnode;
2131         union amt_addr group, host;
2132         struct mld2_grec *grec;
2133         u16 nsrcs;
2134         int i;
2135
2136         for (i = 0; i < ntohs(mld2r->mld2r_ngrec); i++) {
2137                 len += sizeof(*grec);
2138                 if (!ipv6_mc_may_pull(skb, len))
2139                         break;
2140
2141                 grec = (void *)(skb->data + len - sizeof(*grec));
2142                 nsrcs = ntohs(grec->grec_nsrcs);
2143
2144                 len += nsrcs * sizeof(struct in6_addr);
2145                 if (!ipv6_mc_may_pull(skb, len))
2146                         break;
2147
2148                 memset(&group, 0, sizeof(union amt_addr));
2149                 group.ip6 = grec->grec_mca;
2150                 memset(&host, 0, sizeof(union amt_addr));
2151                 host.ip6 = ip6h->saddr;
2152                 gnode = amt_lookup_group(tunnel, &group, &host, true);
2153                 if (!gnode) {
2154                         gnode = amt_add_group(amt, tunnel, &group, &host,
2155                                               ETH_P_IPV6);
2156                         if (IS_ERR(gnode))
2157                                 continue;
2158                 }
2159
2160                 amt_add_srcs(amt, tunnel, gnode, grec, true);
2161                 switch (grec->grec_type) {
2162                 case MLD2_MODE_IS_INCLUDE:
2163                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2164                                                 zero_grec, true);
2165                         break;
2166                 case MLD2_MODE_IS_EXCLUDE:
2167                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2168                                                 zero_grec, true);
2169                         break;
2170                 case MLD2_CHANGE_TO_INCLUDE:
2171                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2172                                                 zero_grec, true);
2173                         break;
2174                 case MLD2_CHANGE_TO_EXCLUDE:
2175                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2176                                                 zero_grec, true);
2177                         break;
2178                 case MLD2_ALLOW_NEW_SOURCES:
2179                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2180                                                 zero_grec, true);
2181                         break;
2182                 case MLD2_BLOCK_OLD_SOURCES:
2183                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2184                                                 zero_grec, true);
2185                         break;
2186                 default:
2187                         break;
2188                 }
2189                 amt_cleanup_srcs(amt, tunnel, gnode);
2190         }
2191 }
2192
2193 /* caller held tunnel->lock */
2194 static void amt_mld_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2195                                    struct amt_tunnel_list *tunnel)
2196 {
2197         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2198
2199         switch (mld->mld_type) {
2200         case ICMPV6_MGM_REPORT:
2201                 amt_mldv1_report_handler(amt, skb, tunnel);
2202                 break;
2203         case ICMPV6_MLD2_REPORT:
2204                 amt_mldv2_report_handler(amt, skb, tunnel);
2205                 break;
2206         case ICMPV6_MGM_REDUCTION:
2207                 amt_mldv1_leave_handler(amt, skb, tunnel);
2208                 break;
2209         default:
2210                 break;
2211         }
2212 }
2213 #endif
2214
2215 static bool amt_advertisement_handler(struct amt_dev *amt, struct sk_buff *skb)
2216 {
2217         struct amt_header_advertisement *amta;
2218         int hdr_size;
2219
2220         hdr_size = sizeof(*amta) - sizeof(struct amt_header);
2221
2222         if (!pskb_may_pull(skb, hdr_size))
2223                 return true;
2224
2225         amta = (struct amt_header_advertisement *)(udp_hdr(skb) + 1);
2226         if (!amta->ip4)
2227                 return true;
2228
2229         if (amta->reserved || amta->version)
2230                 return true;
2231
2232         if (ipv4_is_loopback(amta->ip4) || ipv4_is_multicast(amta->ip4) ||
2233             ipv4_is_zeronet(amta->ip4))
2234                 return true;
2235
2236         amt->remote_ip = amta->ip4;
2237         netdev_dbg(amt->dev, "advertised remote ip = %pI4\n", &amt->remote_ip);
2238         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2239
2240         amt_update_gw_status(amt, AMT_STATUS_RECEIVED_ADVERTISEMENT, true);
2241         return false;
2242 }
2243
2244 static bool amt_multicast_data_handler(struct amt_dev *amt, struct sk_buff *skb)
2245 {
2246         struct amt_header_mcast_data *amtmd;
2247         int hdr_size, len, err;
2248         struct ethhdr *eth;
2249         struct iphdr *iph;
2250
2251         amtmd = (struct amt_header_mcast_data *)(udp_hdr(skb) + 1);
2252         if (amtmd->reserved || amtmd->version)
2253                 return true;
2254
2255         hdr_size = sizeof(*amtmd) + sizeof(struct udphdr);
2256         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_IP), false))
2257                 return true;
2258         skb_reset_network_header(skb);
2259         skb_push(skb, sizeof(*eth));
2260         skb_reset_mac_header(skb);
2261         skb_pull(skb, sizeof(*eth));
2262         eth = eth_hdr(skb);
2263         iph = ip_hdr(skb);
2264         if (iph->version == 4) {
2265                 if (!ipv4_is_multicast(iph->daddr))
2266                         return true;
2267                 skb->protocol = htons(ETH_P_IP);
2268                 eth->h_proto = htons(ETH_P_IP);
2269                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2270 #if IS_ENABLED(CONFIG_IPV6)
2271         } else if (iph->version == 6) {
2272                 struct ipv6hdr *ip6h;
2273
2274                 ip6h = ipv6_hdr(skb);
2275                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2276                         return true;
2277                 skb->protocol = htons(ETH_P_IPV6);
2278                 eth->h_proto = htons(ETH_P_IPV6);
2279                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2280 #endif
2281         } else {
2282                 return true;
2283         }
2284
2285         skb->pkt_type = PACKET_MULTICAST;
2286         skb->ip_summed = CHECKSUM_NONE;
2287         len = skb->len;
2288         err = gro_cells_receive(&amt->gro_cells, skb);
2289         if (likely(err == NET_RX_SUCCESS))
2290                 dev_sw_netstats_rx_add(amt->dev, len);
2291         else
2292                 amt->dev->stats.rx_dropped++;
2293
2294         return false;
2295 }
2296
2297 static bool amt_membership_query_handler(struct amt_dev *amt,
2298                                          struct sk_buff *skb)
2299 {
2300         struct amt_header_membership_query *amtmq;
2301         struct igmpv3_query *ihv3;
2302         struct ethhdr *eth, *oeth;
2303         struct iphdr *iph;
2304         int hdr_size, len;
2305
2306         hdr_size = sizeof(*amtmq) - sizeof(struct amt_header);
2307
2308         if (!pskb_may_pull(skb, hdr_size))
2309                 return true;
2310
2311         amtmq = (struct amt_header_membership_query *)(udp_hdr(skb) + 1);
2312         if (amtmq->reserved || amtmq->version)
2313                 return true;
2314
2315         hdr_size = sizeof(*amtmq) + sizeof(struct udphdr) - sizeof(*eth);
2316         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_TEB), false))
2317                 return true;
2318         oeth = eth_hdr(skb);
2319         skb_reset_mac_header(skb);
2320         skb_pull(skb, sizeof(*eth));
2321         skb_reset_network_header(skb);
2322         eth = eth_hdr(skb);
2323         iph = ip_hdr(skb);
2324         if (iph->version == 4) {
2325                 if (!ipv4_is_multicast(iph->daddr))
2326                         return true;
2327                 if (!pskb_may_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS +
2328                                    sizeof(*ihv3)))
2329                         return true;
2330
2331                 ihv3 = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2332                 skb_reset_transport_header(skb);
2333                 skb_push(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2334                 spin_lock_bh(&amt->lock);
2335                 amt->ready4 = true;
2336                 amt->mac = amtmq->response_mac;
2337                 amt->req_cnt = 0;
2338                 amt->qi = ihv3->qqic;
2339                 spin_unlock_bh(&amt->lock);
2340                 skb->protocol = htons(ETH_P_IP);
2341                 eth->h_proto = htons(ETH_P_IP);
2342                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2343 #if IS_ENABLED(CONFIG_IPV6)
2344         } else if (iph->version == 6) {
2345                 struct ipv6hdr *ip6h = ipv6_hdr(skb);
2346                 struct mld2_query *mld2q;
2347
2348                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2349                         return true;
2350                 if (!pskb_may_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS +
2351                                    sizeof(*mld2q)))
2352                         return true;
2353
2354                 mld2q = skb_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2355                 skb_reset_transport_header(skb);
2356                 skb_push(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2357                 spin_lock_bh(&amt->lock);
2358                 amt->ready6 = true;
2359                 amt->mac = amtmq->response_mac;
2360                 amt->req_cnt = 0;
2361                 amt->qi = mld2q->mld2q_qqic;
2362                 spin_unlock_bh(&amt->lock);
2363                 skb->protocol = htons(ETH_P_IPV6);
2364                 eth->h_proto = htons(ETH_P_IPV6);
2365                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2366 #endif
2367         } else {
2368                 return true;
2369         }
2370
2371         ether_addr_copy(eth->h_source, oeth->h_source);
2372         skb->pkt_type = PACKET_MULTICAST;
2373         skb->ip_summed = CHECKSUM_NONE;
2374         len = skb->len;
2375         if (netif_rx(skb) == NET_RX_SUCCESS) {
2376                 amt_update_gw_status(amt, AMT_STATUS_RECEIVED_QUERY, true);
2377                 dev_sw_netstats_rx_add(amt->dev, len);
2378         } else {
2379                 amt->dev->stats.rx_dropped++;
2380         }
2381
2382         return false;
2383 }
2384
2385 static bool amt_update_handler(struct amt_dev *amt, struct sk_buff *skb)
2386 {
2387         struct amt_header_membership_update *amtmu;
2388         struct amt_tunnel_list *tunnel;
2389         struct udphdr *udph;
2390         struct ethhdr *eth;
2391         struct iphdr *iph;
2392         int len;
2393
2394         iph = ip_hdr(skb);
2395         udph = udp_hdr(skb);
2396
2397         if (__iptunnel_pull_header(skb, sizeof(*udph), skb->protocol,
2398                                    false, false))
2399                 return true;
2400
2401         amtmu = (struct amt_header_membership_update *)skb->data;
2402         if (amtmu->reserved || amtmu->version)
2403                 return true;
2404
2405         skb_pull(skb, sizeof(*amtmu));
2406         skb_reset_network_header(skb);
2407
2408         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
2409                 if (tunnel->ip4 == iph->saddr) {
2410                         if ((amtmu->nonce == tunnel->nonce &&
2411                              amtmu->response_mac == tunnel->mac)) {
2412                                 mod_delayed_work(amt_wq, &tunnel->gc_wq,
2413                                                  msecs_to_jiffies(amt_gmi(amt))
2414                                                                   * 3);
2415                                 goto report;
2416                         } else {
2417                                 netdev_dbg(amt->dev, "Invalid MAC\n");
2418                                 return true;
2419                         }
2420                 }
2421         }
2422
2423         return false;
2424
2425 report:
2426         iph = ip_hdr(skb);
2427         if (iph->version == 4) {
2428                 if (ip_mc_check_igmp(skb)) {
2429                         netdev_dbg(amt->dev, "Invalid IGMP\n");
2430                         return true;
2431                 }
2432
2433                 spin_lock_bh(&tunnel->lock);
2434                 amt_igmp_report_handler(amt, skb, tunnel);
2435                 spin_unlock_bh(&tunnel->lock);
2436
2437                 skb_push(skb, sizeof(struct ethhdr));
2438                 skb_reset_mac_header(skb);
2439                 eth = eth_hdr(skb);
2440                 skb->protocol = htons(ETH_P_IP);
2441                 eth->h_proto = htons(ETH_P_IP);
2442                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2443 #if IS_ENABLED(CONFIG_IPV6)
2444         } else if (iph->version == 6) {
2445                 struct ipv6hdr *ip6h = ipv6_hdr(skb);
2446
2447                 if (ipv6_mc_check_mld(skb)) {
2448                         netdev_dbg(amt->dev, "Invalid MLD\n");
2449                         return true;
2450                 }
2451
2452                 spin_lock_bh(&tunnel->lock);
2453                 amt_mld_report_handler(amt, skb, tunnel);
2454                 spin_unlock_bh(&tunnel->lock);
2455
2456                 skb_push(skb, sizeof(struct ethhdr));
2457                 skb_reset_mac_header(skb);
2458                 eth = eth_hdr(skb);
2459                 skb->protocol = htons(ETH_P_IPV6);
2460                 eth->h_proto = htons(ETH_P_IPV6);
2461                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2462 #endif
2463         } else {
2464                 netdev_dbg(amt->dev, "Unsupported Protocol\n");
2465                 return true;
2466         }
2467
2468         skb_pull(skb, sizeof(struct ethhdr));
2469         skb->pkt_type = PACKET_MULTICAST;
2470         skb->ip_summed = CHECKSUM_NONE;
2471         len = skb->len;
2472         if (netif_rx(skb) == NET_RX_SUCCESS) {
2473                 amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_UPDATE,
2474                                         true);
2475                 dev_sw_netstats_rx_add(amt->dev, len);
2476         } else {
2477                 amt->dev->stats.rx_dropped++;
2478         }
2479
2480         return false;
2481 }
2482
/* Build and transmit an AMT Relay Advertisement.
 *
 * @amt:   device sending the advertisement
 * @nonce: nonce copied verbatim into the advertisement
 * @daddr: IPv4 destination address
 * @dport: UDP destination port
 *
 * The packet is sent from amt->local_ip / amt->relay_port and advertises
 * amt->local_ip as relay address. Nothing is sent when the socket is gone
 * or when either amt->dev or amt->stream_dev is not running. Routing,
 * allocation and transmit failures are counted in tx_errors.
 * The whole function runs inside an RCU read-side critical section.
 */
2483 static void amt_send_advertisement(struct amt_dev *amt, __be32 nonce,
2484                                    __be32 daddr, __be16 dport)
2485 {
2486         struct amt_header_advertisement *amta;
2487         int hlen, tlen, offset;
2488         struct socket *sock;
2489         struct udphdr *udph;
2490         struct sk_buff *skb;
2491         struct iphdr *iph;
2492         struct rtable *rt;
2493         struct flowi4 fl4;
2494         u32 len;
2495         int err;
2496
2497         rcu_read_lock();
2498         sock = rcu_dereference(amt->sock);
2499         if (!sock)
2500                 goto out;
2501
2502         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
2503                 goto out;
2504
2505         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
2506                                    daddr, amt->local_ip,
2507                                    dport, amt->relay_port,
2508                                    IPPROTO_UDP, 0,
2509                                    amt->stream_dev->ifindex);
2510         if (IS_ERR(rt)) {
2511                 amt->dev->stats.tx_errors++;
2512                 goto out;
2513         }
2514
        /* Here len is the allocation size: link-layer room + payload. */
2515         hlen = LL_RESERVED_SPACE(amt->dev);
2516         tlen = amt->dev->needed_tailroom;
2517         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2518         skb = netdev_alloc_skb_ip_align(amt->dev, len);
2519         if (!skb) {
2520                 ip_rt_put(rt);
2521                 amt->dev->stats.tx_errors++;
2522                 goto out;
2523         }
2524
2525         skb->priority = TC_PRIO_CONTROL;
2526         skb_dst_set(skb, &rt->dst);
2527
        /* From here on len is the IP datagram length (used for tot_len).
         * The full datagram is put first, then data is advanced to the AMT
         * header; the UDP and IP headers are pushed back in turn below.
         */
2528         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2529         skb_reset_network_header(skb);
2530         skb_put(skb, len);
2531         amta = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
2532         amta->version   = 0;
2533         amta->type      = AMT_MSG_ADVERTISEMENT;
2534         amta->reserved  = 0;
2535         amta->nonce     = nonce;
2536         amta->ip4       = amt->local_ip;
2537         skb_push(skb, sizeof(*udph));
2538         skb_reset_transport_header(skb);
2539         udph            = udp_hdr(skb);
2540         udph->source    = amt->relay_port;
2541         udph->dest      = dport;
2542         udph->len       = htons(sizeof(*amta) + sizeof(*udph));
        /* Checksum is computed in software over UDP header + AMT payload. */
2543         udph->check     = 0;
2544         offset = skb_transport_offset(skb);
2545         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
2546         udph->check = csum_tcpudp_magic(amt->local_ip, daddr,
2547                                         sizeof(*udph) + sizeof(*amta),
2548                                         IPPROTO_UDP, skb->csum);
2549
2550         skb_push(skb, sizeof(*iph));
2551         iph             = ip_hdr(skb);
2552         iph->version    = 4;
2553         iph->ihl        = (sizeof(struct iphdr)) >> 2;
2554         iph->tos        = AMT_TOS;
2555         iph->frag_off   = 0;
2556         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
2557         iph->daddr      = daddr;
2558         iph->saddr      = amt->local_ip;
2559         iph->protocol   = IPPROTO_UDP;
2560         iph->tot_len    = htons(len);
2561
2562         skb->ip_summed = CHECKSUM_NONE;
2563         ip_select_ident(amt->net, skb, NULL);
2564         ip_send_check(iph);
2565         err = ip_local_out(amt->net, sock->sk, skb);
2566         if (unlikely(net_xmit_eval(err)))
2567                 amt->dev->stats.tx_errors++;
2568
2569 out:
2570         rcu_read_unlock();
2571 }
2572
2573 static bool amt_discovery_handler(struct amt_dev *amt, struct sk_buff *skb)
2574 {
2575         struct amt_header_discovery *amtd;
2576         struct udphdr *udph;
2577         struct iphdr *iph;
2578
2579         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtd)))
2580                 return true;
2581
2582         iph = ip_hdr(skb);
2583         udph = udp_hdr(skb);
2584         amtd = (struct amt_header_discovery *)(udp_hdr(skb) + 1);
2585
2586         if (amtd->reserved || amtd->version)
2587                 return true;
2588
2589         amt_send_advertisement(amt, amtd->nonce, iph->saddr, udph->source);
2590
2591         return false;
2592 }
2593
2594 static bool amt_request_handler(struct amt_dev *amt, struct sk_buff *skb)
2595 {
2596         struct amt_header_request *amtrh;
2597         struct amt_tunnel_list *tunnel;
2598         unsigned long long key;
2599         struct udphdr *udph;
2600         struct iphdr *iph;
2601         u64 mac;
2602         int i;
2603
2604         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtrh)))
2605                 return true;
2606
2607         iph = ip_hdr(skb);
2608         udph = udp_hdr(skb);
2609         amtrh = (struct amt_header_request *)(udp_hdr(skb) + 1);
2610
2611         if (amtrh->reserved1 || amtrh->reserved2 || amtrh->version)
2612                 return true;
2613
2614         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list)
2615                 if (tunnel->ip4 == iph->saddr)
2616                         goto send;
2617
2618         if (amt->nr_tunnels >= amt->max_tunnels) {
2619                 icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
2620                 return true;
2621         }
2622
2623         tunnel = kzalloc(sizeof(*tunnel) +
2624                          (sizeof(struct hlist_head) * amt->hash_buckets),
2625                          GFP_ATOMIC);
2626         if (!tunnel)
2627                 return true;
2628
2629         tunnel->source_port = udph->source;
2630         tunnel->ip4 = iph->saddr;
2631
2632         memcpy(&key, &tunnel->key, sizeof(unsigned long long));
2633         tunnel->amt = amt;
2634         spin_lock_init(&tunnel->lock);
2635         for (i = 0; i < amt->hash_buckets; i++)
2636                 INIT_HLIST_HEAD(&tunnel->groups[i]);
2637
2638         INIT_DELAYED_WORK(&tunnel->gc_wq, amt_tunnel_expire);
2639
2640         spin_lock_bh(&amt->lock);
2641         list_add_tail_rcu(&tunnel->list, &amt->tunnel_list);
2642         tunnel->key = amt->key;
2643         amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_REQUEST, true);
2644         amt->nr_tunnels++;
2645         mod_delayed_work(amt_wq, &tunnel->gc_wq,
2646                          msecs_to_jiffies(amt_gmi(amt)));
2647         spin_unlock_bh(&amt->lock);
2648
2649 send:
2650         tunnel->nonce = amtrh->nonce;
2651         mac = siphash_3u32((__force u32)tunnel->ip4,
2652                            (__force u32)tunnel->source_port,
2653                            (__force u32)tunnel->nonce,
2654                            &tunnel->key);
2655         tunnel->mac = mac >> 16;
2656
2657         if (!netif_running(amt->dev) || !netif_running(amt->stream_dev))
2658                 return true;
2659
2660         if (!amtrh->p)
2661                 amt_send_igmp_gq(amt, tunnel);
2662         else
2663                 amt_send_mld_gq(amt, tunnel);
2664
2665         return false;
2666 }
2667
2668 static int amt_rcv(struct sock *sk, struct sk_buff *skb)
2669 {
2670         struct amt_dev *amt;
2671         struct iphdr *iph;
2672         int type;
2673         bool err;
2674
2675         rcu_read_lock_bh();
2676         amt = rcu_dereference_sk_user_data(sk);
2677         if (!amt) {
2678                 err = true;
2679                 goto out;
2680         }
2681
2682         skb->dev = amt->dev;
2683         iph = ip_hdr(skb);
2684         type = amt_parse_type(skb);
2685         if (type == -1) {
2686                 err = true;
2687                 goto drop;
2688         }
2689
2690         if (amt->mode == AMT_MODE_GATEWAY) {
2691                 switch (type) {
2692                 case AMT_MSG_ADVERTISEMENT:
2693                         if (iph->saddr != amt->discovery_ip) {
2694                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2695                                 err = true;
2696                                 goto drop;
2697                         }
2698                         if (amt_advertisement_handler(amt, skb))
2699                                 amt->dev->stats.rx_dropped++;
2700                         goto out;
2701                 case AMT_MSG_MULTICAST_DATA:
2702                         if (iph->saddr != amt->remote_ip) {
2703                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2704                                 err = true;
2705                                 goto drop;
2706                         }
2707                         err = amt_multicast_data_handler(amt, skb);
2708                         if (err)
2709                                 goto drop;
2710                         else
2711                                 goto out;
2712                 case AMT_MSG_MEMBERSHIP_QUERY:
2713                         if (iph->saddr != amt->remote_ip) {
2714                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2715                                 err = true;
2716                                 goto drop;
2717                         }
2718                         err = amt_membership_query_handler(amt, skb);
2719                         if (err)
2720                                 goto drop;
2721                         else
2722                                 goto out;
2723                 default:
2724                         err = true;
2725                         netdev_dbg(amt->dev, "Invalid type of Gateway\n");
2726                         break;
2727                 }
2728         } else {
2729                 switch (type) {
2730                 case AMT_MSG_DISCOVERY:
2731                         err = amt_discovery_handler(amt, skb);
2732                         break;
2733                 case AMT_MSG_REQUEST:
2734                         err = amt_request_handler(amt, skb);
2735                         break;
2736                 case AMT_MSG_MEMBERSHIP_UPDATE:
2737                         err = amt_update_handler(amt, skb);
2738                         if (err)
2739                                 goto drop;
2740                         else
2741                                 goto out;
2742                 default:
2743                         err = true;
2744                         netdev_dbg(amt->dev, "Invalid type of relay\n");
2745                         break;
2746                 }
2747         }
2748 drop:
2749         if (err) {
2750                 amt->dev->stats.rx_dropped++;
2751                 kfree_skb(skb);
2752         } else {
2753                 consume_skb(skb);
2754         }
2755 out:
2756         rcu_read_unlock_bh();
2757         return 0;
2758 }
2759
2760 static int amt_err_lookup(struct sock *sk, struct sk_buff *skb)
2761 {
2762         struct amt_dev *amt;
2763         int type;
2764
2765         rcu_read_lock_bh();
2766         amt = rcu_dereference_sk_user_data(sk);
2767         if (!amt)
2768                 goto out;
2769
2770         if (amt->mode != AMT_MODE_GATEWAY)
2771                 goto drop;
2772
2773         type = amt_parse_type(skb);
2774         if (type == -1)
2775                 goto drop;
2776
2777         netdev_dbg(amt->dev, "Received IGMP Unreachable of %s\n",
2778                    type_str[type]);
2779         switch (type) {
2780         case AMT_MSG_DISCOVERY:
2781                 break;
2782         case AMT_MSG_REQUEST:
2783         case AMT_MSG_MEMBERSHIP_UPDATE:
2784                 if (amt->status >= AMT_STATUS_RECEIVED_ADVERTISEMENT)
2785                         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2786                 break;
2787         default:
2788                 goto drop;
2789         }
2790 out:
2791         rcu_read_unlock_bh();
2792         return 0;
2793 drop:
2794         rcu_read_unlock_bh();
2795         amt->dev->stats.rx_dropped++;
2796         return 0;
2797 }
2798
2799 static struct socket *amt_create_sock(struct net *net, __be16 port)
2800 {
2801         struct udp_port_cfg udp_conf;
2802         struct socket *sock;
2803         int err;
2804
2805         memset(&udp_conf, 0, sizeof(udp_conf));
2806         udp_conf.family = AF_INET;
2807         udp_conf.local_ip.s_addr = htonl(INADDR_ANY);
2808
2809         udp_conf.local_udp_port = port;
2810
2811         err = udp_sock_create(net, &udp_conf, &sock);
2812         if (err < 0)
2813                 return ERR_PTR(err);
2814
2815         return sock;
2816 }
2817
2818 static int amt_socket_create(struct amt_dev *amt)
2819 {
2820         struct udp_tunnel_sock_cfg tunnel_cfg;
2821         struct socket *sock;
2822
2823         sock = amt_create_sock(amt->net, amt->relay_port);
2824         if (IS_ERR(sock))
2825                 return PTR_ERR(sock);
2826
2827         /* Mark socket as an encapsulation socket */
2828         memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
2829         tunnel_cfg.sk_user_data = amt;
2830         tunnel_cfg.encap_type = 1;
2831         tunnel_cfg.encap_rcv = amt_rcv;
2832         tunnel_cfg.encap_err_lookup = amt_err_lookup;
2833         tunnel_cfg.encap_destroy = NULL;
2834         setup_udp_tunnel_sock(amt->net, sock, &tunnel_cfg);
2835
2836         rcu_assign_pointer(amt->sock, sock);
2837         return 0;
2838 }
2839
2840 static int amt_dev_open(struct net_device *dev)
2841 {
2842         struct amt_dev *amt = netdev_priv(dev);
2843         int err;
2844
2845         amt->ready4 = false;
2846         amt->ready6 = false;
2847
2848         err = amt_socket_create(amt);
2849         if (err)
2850                 return err;
2851
2852         amt->req_cnt = 0;
2853         amt->remote_ip = 0;
2854         get_random_bytes(&amt->key, sizeof(siphash_key_t));
2855
2856         amt->status = AMT_STATUS_INIT;
2857         if (amt->mode == AMT_MODE_GATEWAY) {
2858                 mod_delayed_work(amt_wq, &amt->discovery_wq, 0);
2859                 mod_delayed_work(amt_wq, &amt->req_wq, 0);
2860         } else if (amt->mode == AMT_MODE_RELAY) {
2861                 mod_delayed_work(amt_wq, &amt->secret_wq,
2862                                  msecs_to_jiffies(AMT_SECRET_TIMEOUT));
2863         }
2864         return err;
2865 }
2866
2867 static int amt_dev_stop(struct net_device *dev)
2868 {
2869         struct amt_dev *amt = netdev_priv(dev);
2870         struct amt_tunnel_list *tunnel, *tmp;
2871         struct socket *sock;
2872
2873         cancel_delayed_work_sync(&amt->req_wq);
2874         cancel_delayed_work_sync(&amt->discovery_wq);
2875         cancel_delayed_work_sync(&amt->secret_wq);
2876
2877         /* shutdown */
2878         sock = rtnl_dereference(amt->sock);
2879         RCU_INIT_POINTER(amt->sock, NULL);
2880         synchronize_net();
2881         if (sock)
2882                 udp_tunnel_sock_release(sock);
2883
2884         amt->ready4 = false;
2885         amt->ready6 = false;
2886         amt->req_cnt = 0;
2887         amt->remote_ip = 0;
2888
2889         list_for_each_entry_safe(tunnel, tmp, &amt->tunnel_list, list) {
2890                 list_del_rcu(&tunnel->list);
2891                 amt->nr_tunnels--;
2892                 cancel_delayed_work_sync(&tunnel->gc_wq);
2893                 amt_clear_groups(tunnel);
2894                 kfree_rcu(tunnel, rcu);
2895         }
2896
2897         return 0;
2898 }
2899
/* Device type attached to AMT net_devices by SET_NETDEV_DEVTYPE() in
 * amt_link_setup().
 */
2900 static const struct device_type amt_type = {
2901         .name = "amt",
2902 };
2903
2904 static int amt_dev_init(struct net_device *dev)
2905 {
2906         struct amt_dev *amt = netdev_priv(dev);
2907         int err;
2908
2909         amt->dev = dev;
2910         dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
2911         if (!dev->tstats)
2912                 return -ENOMEM;
2913
2914         err = gro_cells_init(&amt->gro_cells, dev);
2915         if (err) {
2916                 free_percpu(dev->tstats);
2917                 return err;
2918         }
2919
2920         return 0;
2921 }
2922
2923 static void amt_dev_uninit(struct net_device *dev)
2924 {
2925         struct amt_dev *amt = netdev_priv(dev);
2926
2927         gro_cells_destroy(&amt->gro_cells);
2928         free_percpu(dev->tstats);
2929 }
2930
/* net_device callbacks of an AMT interface; installed on the device by
 * amt_link_setup().
 */
2931 static const struct net_device_ops amt_netdev_ops = {
2932         .ndo_init               = amt_dev_init,
2933         .ndo_uninit             = amt_dev_uninit,
2934         .ndo_open               = amt_dev_open,
2935         .ndo_stop               = amt_dev_stop,
2936         .ndo_start_xmit         = amt_dev_xmit,
2937         .ndo_get_stats64        = dev_get_tstats64,
2938 };
2939
2940 static void amt_link_setup(struct net_device *dev)
2941 {
2942         dev->netdev_ops         = &amt_netdev_ops;
2943         dev->needs_free_netdev  = true;
2944         SET_NETDEV_DEVTYPE(dev, &amt_type);
2945         dev->min_mtu            = ETH_MIN_MTU;
2946         dev->max_mtu            = ETH_MAX_MTU;
2947         dev->type               = ARPHRD_NONE;
2948         dev->flags              = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST;
2949         dev->hard_header_len    = 0;
2950         dev->addr_len           = 0;
2951         dev->priv_flags         |= IFF_NO_QUEUE;
2952         dev->features           |= NETIF_F_LLTX;
2953         dev->features           |= NETIF_F_GSO_SOFTWARE;
2954         dev->features           |= NETIF_F_NETNS_LOCAL;
2955         dev->hw_features        |= NETIF_F_SG | NETIF_F_HW_CSUM;
2956         dev->hw_features        |= NETIF_F_FRAGLIST | NETIF_F_RXCSUM;
2957         dev->hw_features        |= NETIF_F_GSO_SOFTWARE;
2958         eth_hw_addr_random(dev);
2959         eth_zero_addr(dev->broadcast);
2960         ether_setup(dev);
2961 }
2962
/* Netlink policy for the IFLA_AMT_* link attributes. Mode, link and
 * max-tunnels are u32, the two ports are u16, and the three address
 * attributes carry a .len equal to the size of an IPv4 address field.
 */
2963 static const struct nla_policy amt_policy[IFLA_AMT_MAX + 1] = {
2964         [IFLA_AMT_MODE]         = { .type = NLA_U32 },
2965         [IFLA_AMT_RELAY_PORT]   = { .type = NLA_U16 },
2966         [IFLA_AMT_GATEWAY_PORT] = { .type = NLA_U16 },
2967         [IFLA_AMT_LINK]         = { .type = NLA_U32 },
2968         [IFLA_AMT_LOCAL_IP]     = { .len = sizeof_field(struct iphdr, daddr) },
2969         [IFLA_AMT_REMOTE_IP]    = { .len = sizeof_field(struct iphdr, daddr) },
2970         [IFLA_AMT_DISCOVERY_IP] = { .len = sizeof_field(struct iphdr, daddr) },
2971         [IFLA_AMT_MAX_TUNNELS]  = { .type = NLA_U32 },
2972 };
2973
2974 static int amt_validate(struct nlattr *tb[], struct nlattr *data[],
2975                         struct netlink_ext_ack *extack)
2976 {
2977         if (!data)
2978                 return -EINVAL;
2979
2980         if (!data[IFLA_AMT_LINK]) {
2981                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LINK],
2982                                     "Link attribute is required");
2983                 return -EINVAL;
2984         }
2985
2986         if (!data[IFLA_AMT_MODE]) {
2987                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
2988                                     "Mode attribute is required");
2989                 return -EINVAL;
2990         }
2991
2992         if (nla_get_u32(data[IFLA_AMT_MODE]) > AMT_MODE_MAX) {
2993                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
2994                                     "Mode attribute is not valid");
2995                 return -EINVAL;
2996         }
2997
2998         if (!data[IFLA_AMT_LOCAL_IP]) {
2999                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_DISCOVERY_IP],
3000                                     "Local attribute is required");
3001                 return -EINVAL;
3002         }
3003
3004         if (!data[IFLA_AMT_DISCOVERY_IP] &&
3005             nla_get_u32(data[IFLA_AMT_MODE]) == AMT_MODE_GATEWAY) {
3006                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LOCAL_IP],
3007                                     "Discovery attribute is required");
3008                 return -EINVAL;
3009         }
3010
3011         return 0;
3012 }
3013
3014 static int amt_newlink(struct net *net, struct net_device *dev,
3015                        struct nlattr *tb[], struct nlattr *data[],
3016                        struct netlink_ext_ack *extack)
3017 {
3018         struct amt_dev *amt = netdev_priv(dev);
3019         int err = -EINVAL;
3020
3021         amt->net = net;
3022         amt->mode = nla_get_u32(data[IFLA_AMT_MODE]);
3023
3024         if (data[IFLA_AMT_MAX_TUNNELS] &&
3025             nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]))
3026                 amt->max_tunnels = nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]);
3027         else
3028                 amt->max_tunnels = AMT_MAX_TUNNELS;
3029
3030         spin_lock_init(&amt->lock);
3031         amt->max_groups = AMT_MAX_GROUP;
3032         amt->max_sources = AMT_MAX_SOURCE;
3033         amt->hash_buckets = AMT_HSIZE;
3034         amt->nr_tunnels = 0;
3035         get_random_bytes(&amt->hash_seed, sizeof(amt->hash_seed));
3036         amt->stream_dev = dev_get_by_index(net,
3037                                            nla_get_u32(data[IFLA_AMT_LINK]));
3038         if (!amt->stream_dev) {
3039                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3040                                     "Can't find stream device");
3041                 return -ENODEV;
3042         }
3043
3044         if (amt->stream_dev->type != ARPHRD_ETHER) {
3045                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3046                                     "Invalid stream device type");
3047                 goto err;
3048         }
3049
3050         amt->local_ip = nla_get_in_addr(data[IFLA_AMT_LOCAL_IP]);
3051         if (ipv4_is_loopback(amt->local_ip) ||
3052             ipv4_is_zeronet(amt->local_ip) ||
3053             ipv4_is_multicast(amt->local_ip)) {
3054                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LOCAL_IP],
3055                                     "Invalid Local address");
3056                 goto err;
3057         }
3058
3059         if (data[IFLA_AMT_RELAY_PORT])
3060                 amt->relay_port = nla_get_be16(data[IFLA_AMT_RELAY_PORT]);
3061         else
3062                 amt->relay_port = htons(IANA_AMT_UDP_PORT);
3063
3064         if (data[IFLA_AMT_GATEWAY_PORT])
3065                 amt->gw_port = nla_get_be16(data[IFLA_AMT_GATEWAY_PORT]);
3066         else
3067                 amt->gw_port = htons(IANA_AMT_UDP_PORT);
3068
3069         if (!amt->relay_port) {
3070                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3071                                     "relay port must not be 0");
3072                 goto err;
3073         }
3074         if (amt->mode == AMT_MODE_RELAY) {
3075                 amt->qrv = amt->net->ipv4.sysctl_igmp_qrv;
3076                 amt->qri = 10;
3077                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3078                                        AMT_RELAY_HLEN;
3079                 dev->mtu = amt->stream_dev->mtu - AMT_RELAY_HLEN;
3080                 dev->max_mtu = dev->mtu;
3081                 dev->min_mtu = ETH_MIN_MTU + AMT_RELAY_HLEN;
3082         } else {
3083                 if (!data[IFLA_AMT_DISCOVERY_IP]) {
3084                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3085                                             "discovery must be set in gateway mode");
3086                         goto err;
3087                 }
3088                 if (!amt->gw_port) {
3089                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3090                                             "gateway port must not be 0");
3091                         goto err;
3092                 }
3093                 amt->remote_ip = 0;
3094                 amt->discovery_ip = nla_get_in_addr(data[IFLA_AMT_DISCOVERY_IP]);
3095                 if (ipv4_is_loopback(amt->discovery_ip) ||
3096                     ipv4_is_zeronet(amt->discovery_ip) ||
3097                     ipv4_is_multicast(amt->discovery_ip)) {
3098                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3099                                             "discovery must be unicast");
3100                         goto err;
3101                 }
3102
3103                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3104                                        AMT_GW_HLEN;
3105                 dev->mtu = amt->stream_dev->mtu - AMT_GW_HLEN;
3106                 dev->max_mtu = dev->mtu;
3107                 dev->min_mtu = ETH_MIN_MTU + AMT_GW_HLEN;
3108         }
3109         amt->qi = AMT_INIT_QUERY_INTERVAL;
3110
3111         err = register_netdevice(dev);
3112         if (err < 0) {
3113                 netdev_dbg(dev, "failed to register new netdev %d\n", err);
3114                 goto err;
3115         }
3116
3117         err = netdev_upper_dev_link(amt->stream_dev, dev, extack);
3118         if (err < 0) {
3119                 unregister_netdevice(dev);
3120                 goto err;
3121         }
3122
3123         INIT_DELAYED_WORK(&amt->discovery_wq, amt_discovery_work);
3124         INIT_DELAYED_WORK(&amt->req_wq, amt_req_work);
3125         INIT_DELAYED_WORK(&amt->secret_wq, amt_secret_work);
3126         INIT_LIST_HEAD(&amt->tunnel_list);
3127
3128         return 0;
3129 err:
3130         dev_put(amt->stream_dev);
3131         return err;
3132 }
3133
3134 static void amt_dellink(struct net_device *dev, struct list_head *head)
3135 {
3136         struct amt_dev *amt = netdev_priv(dev);
3137
3138         unregister_netdevice_queue(dev, head);
3139         netdev_upper_dev_unlink(amt->stream_dev, dev);
3140         dev_put(amt->stream_dev);
3141 }
3142
3143 static size_t amt_get_size(const struct net_device *dev)
3144 {
3145         return nla_total_size(sizeof(__u32)) + /* IFLA_AMT_MODE */
3146                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_RELAY_PORT */
3147                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_GATEWAY_PORT */
3148                nla_total_size(sizeof(__u32)) + /* IFLA_AMT_LINK */
3149                nla_total_size(sizeof(__u32)) + /* IFLA_MAX_TUNNELS */
3150                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_DISCOVERY_IP */
3151                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_REMOTE_IP */
3152                nla_total_size(sizeof(struct iphdr)); /* IFLA_AMT_LOCAL_IP */
3153 }
3154
3155 static int amt_fill_info(struct sk_buff *skb, const struct net_device *dev)
3156 {
3157         struct amt_dev *amt = netdev_priv(dev);
3158
3159         if (nla_put_u32(skb, IFLA_AMT_MODE, amt->mode))
3160                 goto nla_put_failure;
3161         if (nla_put_be16(skb, IFLA_AMT_RELAY_PORT, amt->relay_port))
3162                 goto nla_put_failure;
3163         if (nla_put_be16(skb, IFLA_AMT_GATEWAY_PORT, amt->gw_port))
3164                 goto nla_put_failure;
3165         if (nla_put_u32(skb, IFLA_AMT_LINK, amt->stream_dev->ifindex))
3166                 goto nla_put_failure;
3167         if (nla_put_in_addr(skb, IFLA_AMT_LOCAL_IP, amt->local_ip))
3168                 goto nla_put_failure;
3169         if (nla_put_in_addr(skb, IFLA_AMT_DISCOVERY_IP, amt->discovery_ip))
3170                 goto nla_put_failure;
3171         if (amt->remote_ip)
3172                 if (nla_put_in_addr(skb, IFLA_AMT_REMOTE_IP, amt->remote_ip))
3173                         goto nla_put_failure;
3174         if (nla_put_u32(skb, IFLA_AMT_MAX_TUNNELS, amt->max_tunnels))
3175                 goto nla_put_failure;
3176
3177         return 0;
3178
3179 nla_put_failure:
3180         return -EMSGSIZE;
3181 }
3182
3183 static struct rtnl_link_ops amt_link_ops __read_mostly = {
3184         .kind           = "amt",
3185         .maxtype        = IFLA_AMT_MAX,
3186         .policy         = amt_policy,
3187         .priv_size      = sizeof(struct amt_dev),
3188         .setup          = amt_link_setup,
3189         .validate       = amt_validate,
3190         .newlink        = amt_newlink,
3191         .dellink        = amt_dellink,
3192         .get_size       = amt_get_size,
3193         .fill_info      = amt_fill_info,
3194 };
3195
3196 static struct net_device *amt_lookup_upper_dev(struct net_device *dev)
3197 {
3198         struct net_device *upper_dev;
3199         struct amt_dev *amt;
3200
3201         for_each_netdev(dev_net(dev), upper_dev) {
3202                 if (netif_is_amt(upper_dev)) {
3203                         amt = netdev_priv(upper_dev);
3204                         if (amt->stream_dev == dev)
3205                                 return upper_dev;
3206                 }
3207         }
3208
3209         return NULL;
3210 }
3211
3212 static int amt_device_event(struct notifier_block *unused,
3213                             unsigned long event, void *ptr)
3214 {
3215         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
3216         struct net_device *upper_dev;
3217         struct amt_dev *amt;
3218         LIST_HEAD(list);
3219         int new_mtu;
3220
3221         upper_dev = amt_lookup_upper_dev(dev);
3222         if (!upper_dev)
3223                 return NOTIFY_DONE;
3224         amt = netdev_priv(upper_dev);
3225
3226         switch (event) {
3227         case NETDEV_UNREGISTER:
3228                 amt_dellink(amt->dev, &list);
3229                 unregister_netdevice_many(&list);
3230                 break;
3231         case NETDEV_CHANGEMTU:
3232                 if (amt->mode == AMT_MODE_RELAY)
3233                         new_mtu = dev->mtu - AMT_RELAY_HLEN;
3234                 else
3235                         new_mtu = dev->mtu - AMT_GW_HLEN;
3236
3237                 dev_set_mtu(amt->dev, new_mtu);
3238                 break;
3239         }
3240
3241         return NOTIFY_DONE;
3242 }
3243
3244 static struct notifier_block amt_notifier_block __read_mostly = {
3245         .notifier_call = amt_device_event,
3246 };
3247
3248 static int __init amt_init(void)
3249 {
3250         int err;
3251
3252         err = register_netdevice_notifier(&amt_notifier_block);
3253         if (err < 0)
3254                 goto err;
3255
3256         err = rtnl_link_register(&amt_link_ops);
3257         if (err < 0)
3258                 goto unregister_notifier;
3259
3260         amt_wq = alloc_workqueue("amt", WQ_UNBOUND, 1);
3261         if (!amt_wq) {
3262                 err = -ENOMEM;
3263                 goto rtnl_unregister;
3264         }
3265
3266         spin_lock_init(&source_gc_lock);
3267         spin_lock_bh(&source_gc_lock);
3268         INIT_DELAYED_WORK(&source_gc_wq, amt_source_gc_work);
3269         mod_delayed_work(amt_wq, &source_gc_wq,
3270                          msecs_to_jiffies(AMT_GC_INTERVAL));
3271         spin_unlock_bh(&source_gc_lock);
3272
3273         return 0;
3274
3275 rtnl_unregister:
3276         rtnl_link_unregister(&amt_link_ops);
3277 unregister_notifier:
3278         unregister_netdevice_notifier(&amt_notifier_block);
3279 err:
3280         pr_err("error loading AMT module loaded\n");
3281         return err;
3282 }
3283 late_initcall(amt_init);
3284
3285 static void __exit amt_fini(void)
3286 {
3287         rtnl_link_unregister(&amt_link_ops);
3288         unregister_netdevice_notifier(&amt_notifier_block);
3289         cancel_delayed_work_sync(&source_gc_wq);
3290         __amt_source_gc_work();
3291         destroy_workqueue(amt_wq);
3292 }
3293 module_exit(amt_fini);
3294
3295 MODULE_LICENSE("GPL");
3296 MODULE_AUTHOR("Taehee Yoo <ap420073@gmail.com>");
3297 MODULE_ALIAS_RTNL_LINK("amt");