MAINTAINERS: add Vincenzo Frascino to KASAN reviewers
[platform/kernel/linux-starfive.git] / drivers / net / amt.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Copyright (c) 2021 Taehee Yoo <ap420073@gmail.com> */
3
4 #define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
5
6 #include <linux/module.h>
7 #include <linux/skbuff.h>
8 #include <linux/udp.h>
9 #include <linux/jhash.h>
10 #include <linux/if_tunnel.h>
11 #include <linux/net.h>
12 #include <linux/igmp.h>
13 #include <linux/workqueue.h>
14 #include <net/sch_generic.h>
15 #include <net/net_namespace.h>
16 #include <net/ip.h>
17 #include <net/udp.h>
18 #include <net/udp_tunnel.h>
19 #include <net/icmp.h>
20 #include <net/mld.h>
21 #include <net/amt.h>
22 #include <uapi/linux/amt.h>
23 #include <linux/security.h>
24 #include <net/gro_cells.h>
25 #include <net/ipv6.h>
26 #include <net/if_inet6.h>
27 #include <net/ndisc.h>
28 #include <net/addrconf.h>
29 #include <net/ip6_route.h>
30 #include <net/inet_common.h>
31 #include <net/ip6_checksum.h>
32
33 static struct workqueue_struct *amt_wq;
34
35 static HLIST_HEAD(source_gc_list);
36 /* Lock for source_gc_list */
37 static spinlock_t source_gc_lock;
38 static struct delayed_work source_gc_wq;
39 static char *status_str[] = {
40         "AMT_STATUS_INIT",
41         "AMT_STATUS_SENT_DISCOVERY",
42         "AMT_STATUS_RECEIVED_DISCOVERY",
43         "AMT_STATUS_SENT_ADVERTISEMENT",
44         "AMT_STATUS_RECEIVED_ADVERTISEMENT",
45         "AMT_STATUS_SENT_REQUEST",
46         "AMT_STATUS_RECEIVED_REQUEST",
47         "AMT_STATUS_SENT_QUERY",
48         "AMT_STATUS_RECEIVED_QUERY",
49         "AMT_STATUS_SENT_UPDATE",
50         "AMT_STATUS_RECEIVED_UPDATE",
51 };
52
53 static char *type_str[] = {
54         "AMT_MSG_DISCOVERY",
55         "AMT_MSG_ADVERTISEMENT",
56         "AMT_MSG_REQUEST",
57         "AMT_MSG_MEMBERSHIP_QUERY",
58         "AMT_MSG_MEMBERSHIP_UPDATE",
59         "AMT_MSG_MULTICAST_DATA",
60         "AMT_MSG_TEARDOWM",
61 };
62
63 static char *action_str[] = {
64         "AMT_ACT_GMI",
65         "AMT_ACT_GMI_ZERO",
66         "AMT_ACT_GT",
67         "AMT_ACT_STATUS_FWD_NEW",
68         "AMT_ACT_STATUS_D_FWD_NEW",
69         "AMT_ACT_STATUS_NONE_NEW",
70 };
71
72 static struct igmpv3_grec igmpv3_zero_grec;
73
74 #if IS_ENABLED(CONFIG_IPV6)
75 #define MLD2_ALL_NODE_INIT { { { 0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01 } } }
76 static struct in6_addr mld2_all_node = MLD2_ALL_NODE_INIT;
77 static struct mld2_grec mldv2_zero_grec;
78 #endif
79
80 static struct amt_skb_cb *amt_skb_cb(struct sk_buff *skb)
81 {
82         BUILD_BUG_ON(sizeof(struct amt_skb_cb) + sizeof(struct qdisc_skb_cb) >
83                      sizeof_field(struct sk_buff, cb));
84
85         return (struct amt_skb_cb *)((void *)skb->cb +
86                 sizeof(struct qdisc_skb_cb));
87 }
88
89 static void __amt_source_gc_work(void)
90 {
91         struct amt_source_node *snode;
92         struct hlist_head gc_list;
93         struct hlist_node *t;
94
95         spin_lock_bh(&source_gc_lock);
96         hlist_move_list(&source_gc_list, &gc_list);
97         spin_unlock_bh(&source_gc_lock);
98
99         hlist_for_each_entry_safe(snode, t, &gc_list, node) {
100                 hlist_del_rcu(&snode->node);
101                 kfree_rcu(snode, rcu);
102         }
103 }
104
105 static void amt_source_gc_work(struct work_struct *work)
106 {
107         __amt_source_gc_work();
108
109         spin_lock_bh(&source_gc_lock);
110         mod_delayed_work(amt_wq, &source_gc_wq,
111                          msecs_to_jiffies(AMT_GC_INTERVAL));
112         spin_unlock_bh(&source_gc_lock);
113 }
114
115 static bool amt_addr_equal(union amt_addr *a, union amt_addr *b)
116 {
117         return !memcmp(a, b, sizeof(union amt_addr));
118 }
119
120 static u32 amt_source_hash(struct amt_tunnel_list *tunnel, union amt_addr *src)
121 {
122         u32 hash = jhash(src, sizeof(*src), tunnel->amt->hash_seed);
123
124         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
125 }
126
127 static bool amt_status_filter(struct amt_source_node *snode,
128                               enum amt_filter filter)
129 {
130         bool rc = false;
131
132         switch (filter) {
133         case AMT_FILTER_FWD:
134                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
135                     snode->flags == AMT_SOURCE_OLD)
136                         rc = true;
137                 break;
138         case AMT_FILTER_D_FWD:
139                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
140                     snode->flags == AMT_SOURCE_OLD)
141                         rc = true;
142                 break;
143         case AMT_FILTER_FWD_NEW:
144                 if (snode->status == AMT_SOURCE_STATUS_FWD &&
145                     snode->flags == AMT_SOURCE_NEW)
146                         rc = true;
147                 break;
148         case AMT_FILTER_D_FWD_NEW:
149                 if (snode->status == AMT_SOURCE_STATUS_D_FWD &&
150                     snode->flags == AMT_SOURCE_NEW)
151                         rc = true;
152                 break;
153         case AMT_FILTER_ALL:
154                 rc = true;
155                 break;
156         case AMT_FILTER_NONE_NEW:
157                 if (snode->status == AMT_SOURCE_STATUS_NONE &&
158                     snode->flags == AMT_SOURCE_NEW)
159                         rc = true;
160                 break;
161         case AMT_FILTER_BOTH:
162                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
163                      snode->status == AMT_SOURCE_STATUS_FWD) &&
164                     snode->flags == AMT_SOURCE_OLD)
165                         rc = true;
166                 break;
167         case AMT_FILTER_BOTH_NEW:
168                 if ((snode->status == AMT_SOURCE_STATUS_D_FWD ||
169                      snode->status == AMT_SOURCE_STATUS_FWD) &&
170                     snode->flags == AMT_SOURCE_NEW)
171                         rc = true;
172                 break;
173         default:
174                 WARN_ON_ONCE(1);
175                 break;
176         }
177
178         return rc;
179 }
180
181 static struct amt_source_node *amt_lookup_src(struct amt_tunnel_list *tunnel,
182                                               struct amt_group_node *gnode,
183                                               enum amt_filter filter,
184                                               union amt_addr *src)
185 {
186         u32 hash = amt_source_hash(tunnel, src);
187         struct amt_source_node *snode;
188
189         hlist_for_each_entry_rcu(snode, &gnode->sources[hash], node)
190                 if (amt_status_filter(snode, filter) &&
191                     amt_addr_equal(&snode->source_addr, src))
192                         return snode;
193
194         return NULL;
195 }
196
197 static u32 amt_group_hash(struct amt_tunnel_list *tunnel, union amt_addr *group)
198 {
199         u32 hash = jhash(group, sizeof(*group), tunnel->amt->hash_seed);
200
201         return reciprocal_scale(hash, tunnel->amt->hash_buckets);
202 }
203
204 static struct amt_group_node *amt_lookup_group(struct amt_tunnel_list *tunnel,
205                                                union amt_addr *group,
206                                                union amt_addr *host,
207                                                bool v6)
208 {
209         u32 hash = amt_group_hash(tunnel, group);
210         struct amt_group_node *gnode;
211
212         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash], node) {
213                 if (amt_addr_equal(&gnode->group_addr, group) &&
214                     amt_addr_equal(&gnode->host_addr, host) &&
215                     gnode->v6 == v6)
216                         return gnode;
217         }
218
219         return NULL;
220 }
221
222 static void amt_destroy_source(struct amt_source_node *snode)
223 {
224         struct amt_group_node *gnode = snode->gnode;
225         struct amt_tunnel_list *tunnel;
226
227         tunnel = gnode->tunnel_list;
228
229         if (!gnode->v6) {
230                 netdev_dbg(snode->gnode->amt->dev,
231                            "Delete source %pI4 from %pI4\n",
232                            &snode->source_addr.ip4,
233                            &gnode->group_addr.ip4);
234 #if IS_ENABLED(CONFIG_IPV6)
235         } else {
236                 netdev_dbg(snode->gnode->amt->dev,
237                            "Delete source %pI6 from %pI6\n",
238                            &snode->source_addr.ip6,
239                            &gnode->group_addr.ip6);
240 #endif
241         }
242
243         cancel_delayed_work(&snode->source_timer);
244         hlist_del_init_rcu(&snode->node);
245         tunnel->nr_sources--;
246         gnode->nr_sources--;
247         spin_lock_bh(&source_gc_lock);
248         hlist_add_head_rcu(&snode->node, &source_gc_list);
249         spin_unlock_bh(&source_gc_lock);
250 }
251
252 static void amt_del_group(struct amt_dev *amt, struct amt_group_node *gnode)
253 {
254         struct amt_source_node *snode;
255         struct hlist_node *t;
256         int i;
257
258         if (cancel_delayed_work(&gnode->group_timer))
259                 dev_put(amt->dev);
260         hlist_del_rcu(&gnode->node);
261         gnode->tunnel_list->nr_groups--;
262
263         if (!gnode->v6)
264                 netdev_dbg(amt->dev, "Leave group %pI4\n",
265                            &gnode->group_addr.ip4);
266 #if IS_ENABLED(CONFIG_IPV6)
267         else
268                 netdev_dbg(amt->dev, "Leave group %pI6\n",
269                            &gnode->group_addr.ip6);
270 #endif
271         for (i = 0; i < amt->hash_buckets; i++)
272                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node)
273                         amt_destroy_source(snode);
274
275         /* tunnel->lock was acquired outside of amt_del_group()
276          * But rcu_read_lock() was acquired too so It's safe.
277          */
278         kfree_rcu(gnode, rcu);
279 }
280
281 /* If a source timer expires with a router filter-mode for the group of
282  * INCLUDE, the router concludes that traffic from this particular
283  * source is no longer desired on the attached network, and deletes the
284  * associated source record.
285  */
286 static void amt_source_work(struct work_struct *work)
287 {
288         struct amt_source_node *snode = container_of(to_delayed_work(work),
289                                                      struct amt_source_node,
290                                                      source_timer);
291         struct amt_group_node *gnode = snode->gnode;
292         struct amt_dev *amt = gnode->amt;
293         struct amt_tunnel_list *tunnel;
294
295         tunnel = gnode->tunnel_list;
296         spin_lock_bh(&tunnel->lock);
297         rcu_read_lock();
298         if (gnode->filter_mode == MCAST_INCLUDE) {
299                 amt_destroy_source(snode);
300                 if (!gnode->nr_sources)
301                         amt_del_group(amt, gnode);
302         } else {
303                 /* When a router filter-mode for a group is EXCLUDE,
304                  * source records are only deleted when the group timer expires
305                  */
306                 snode->status = AMT_SOURCE_STATUS_D_FWD;
307         }
308         rcu_read_unlock();
309         spin_unlock_bh(&tunnel->lock);
310 }
311
312 static void amt_act_src(struct amt_tunnel_list *tunnel,
313                         struct amt_group_node *gnode,
314                         struct amt_source_node *snode,
315                         enum amt_act act)
316 {
317         struct amt_dev *amt = tunnel->amt;
318
319         switch (act) {
320         case AMT_ACT_GMI:
321                 mod_delayed_work(amt_wq, &snode->source_timer,
322                                  msecs_to_jiffies(amt_gmi(amt)));
323                 break;
324         case AMT_ACT_GMI_ZERO:
325                 cancel_delayed_work(&snode->source_timer);
326                 break;
327         case AMT_ACT_GT:
328                 mod_delayed_work(amt_wq, &snode->source_timer,
329                                  gnode->group_timer.timer.expires);
330                 break;
331         case AMT_ACT_STATUS_FWD_NEW:
332                 snode->status = AMT_SOURCE_STATUS_FWD;
333                 snode->flags = AMT_SOURCE_NEW;
334                 break;
335         case AMT_ACT_STATUS_D_FWD_NEW:
336                 snode->status = AMT_SOURCE_STATUS_D_FWD;
337                 snode->flags = AMT_SOURCE_NEW;
338                 break;
339         case AMT_ACT_STATUS_NONE_NEW:
340                 cancel_delayed_work(&snode->source_timer);
341                 snode->status = AMT_SOURCE_STATUS_NONE;
342                 snode->flags = AMT_SOURCE_NEW;
343                 break;
344         default:
345                 WARN_ON_ONCE(1);
346                 return;
347         }
348
349         if (!gnode->v6)
350                 netdev_dbg(amt->dev, "Source %pI4 from %pI4 Acted %s\n",
351                            &snode->source_addr.ip4,
352                            &gnode->group_addr.ip4,
353                            action_str[act]);
354 #if IS_ENABLED(CONFIG_IPV6)
355         else
356                 netdev_dbg(amt->dev, "Source %pI6 from %pI6 Acted %s\n",
357                            &snode->source_addr.ip6,
358                            &gnode->group_addr.ip6,
359                            action_str[act]);
360 #endif
361 }
362
363 static struct amt_source_node *amt_alloc_snode(struct amt_group_node *gnode,
364                                                union amt_addr *src)
365 {
366         struct amt_source_node *snode;
367
368         snode = kzalloc(sizeof(*snode), GFP_ATOMIC);
369         if (!snode)
370                 return NULL;
371
372         memcpy(&snode->source_addr, src, sizeof(union amt_addr));
373         snode->gnode = gnode;
374         snode->status = AMT_SOURCE_STATUS_NONE;
375         snode->flags = AMT_SOURCE_NEW;
376         INIT_HLIST_NODE(&snode->node);
377         INIT_DELAYED_WORK(&snode->source_timer, amt_source_work);
378
379         return snode;
380 }
381
382 /* RFC 3810 - 7.2.2.  Definition of Filter Timers
383  *
384  *  Router Mode          Filter Timer         Actions/Comments
385  *  -----------       -----------------       ----------------
386  *
387  *    INCLUDE             Not Used            All listeners in
388  *                                            INCLUDE mode.
389  *
390  *    EXCLUDE             Timer > 0           At least one listener
391  *                                            in EXCLUDE mode.
392  *
393  *    EXCLUDE             Timer == 0          No more listeners in
394  *                                            EXCLUDE mode for the
395  *                                            multicast address.
396  *                                            If the Requested List
397  *                                            is empty, delete
398  *                                            Multicast Address
399  *                                            Record.  If not, switch
400  *                                            to INCLUDE filter mode;
401  *                                            the sources in the
402  *                                            Requested List are
403  *                                            moved to the Include
404  *                                            List, and the Exclude
405  *                                            List is deleted.
406  */
407 static void amt_group_work(struct work_struct *work)
408 {
409         struct amt_group_node *gnode = container_of(to_delayed_work(work),
410                                                     struct amt_group_node,
411                                                     group_timer);
412         struct amt_tunnel_list *tunnel = gnode->tunnel_list;
413         struct amt_dev *amt = gnode->amt;
414         struct amt_source_node *snode;
415         bool delete_group = true;
416         struct hlist_node *t;
417         int i, buckets;
418
419         buckets = amt->hash_buckets;
420
421         spin_lock_bh(&tunnel->lock);
422         if (gnode->filter_mode == MCAST_INCLUDE) {
423                 /* Not Used */
424                 spin_unlock_bh(&tunnel->lock);
425                 goto out;
426         }
427
428         rcu_read_lock();
429         for (i = 0; i < buckets; i++) {
430                 hlist_for_each_entry_safe(snode, t,
431                                           &gnode->sources[i], node) {
432                         if (!delayed_work_pending(&snode->source_timer) ||
433                             snode->status == AMT_SOURCE_STATUS_D_FWD) {
434                                 amt_destroy_source(snode);
435                         } else {
436                                 delete_group = false;
437                                 snode->status = AMT_SOURCE_STATUS_FWD;
438                         }
439                 }
440         }
441         if (delete_group)
442                 amt_del_group(amt, gnode);
443         else
444                 gnode->filter_mode = MCAST_INCLUDE;
445         rcu_read_unlock();
446         spin_unlock_bh(&tunnel->lock);
447 out:
448         dev_put(amt->dev);
449 }
450
451 /* Non-existant group is created as INCLUDE {empty}:
452  *
453  * RFC 3376 - 5.1. Action on Change of Interface State
454  *
455  * If no interface state existed for that multicast address before
456  * the change (i.e., the change consisted of creating a new
457  * per-interface record), or if no state exists after the change
458  * (i.e., the change consisted of deleting a per-interface record),
459  * then the "non-existent" state is considered to have a filter mode
460  * of INCLUDE and an empty source list.
461  */
462 static struct amt_group_node *amt_add_group(struct amt_dev *amt,
463                                             struct amt_tunnel_list *tunnel,
464                                             union amt_addr *group,
465                                             union amt_addr *host,
466                                             bool v6)
467 {
468         struct amt_group_node *gnode;
469         u32 hash;
470         int i;
471
472         if (tunnel->nr_groups >= amt->max_groups)
473                 return ERR_PTR(-ENOSPC);
474
475         gnode = kzalloc(sizeof(*gnode) +
476                         (sizeof(struct hlist_head) * amt->hash_buckets),
477                         GFP_ATOMIC);
478         if (unlikely(!gnode))
479                 return ERR_PTR(-ENOMEM);
480
481         gnode->amt = amt;
482         gnode->group_addr = *group;
483         gnode->host_addr = *host;
484         gnode->v6 = v6;
485         gnode->tunnel_list = tunnel;
486         gnode->filter_mode = MCAST_INCLUDE;
487         INIT_HLIST_NODE(&gnode->node);
488         INIT_DELAYED_WORK(&gnode->group_timer, amt_group_work);
489         for (i = 0; i < amt->hash_buckets; i++)
490                 INIT_HLIST_HEAD(&gnode->sources[i]);
491
492         hash = amt_group_hash(tunnel, group);
493         hlist_add_head_rcu(&gnode->node, &tunnel->groups[hash]);
494         tunnel->nr_groups++;
495
496         if (!gnode->v6)
497                 netdev_dbg(amt->dev, "Join group %pI4\n",
498                            &gnode->group_addr.ip4);
499 #if IS_ENABLED(CONFIG_IPV6)
500         else
501                 netdev_dbg(amt->dev, "Join group %pI6\n",
502                            &gnode->group_addr.ip6);
503 #endif
504
505         return gnode;
506 }
507
508 static struct sk_buff *amt_build_igmp_gq(struct amt_dev *amt)
509 {
510         u8 ra[AMT_IPHDR_OPTS] = { IPOPT_RA, 4, 0, 0 };
511         int hlen = LL_RESERVED_SPACE(amt->dev);
512         int tlen = amt->dev->needed_tailroom;
513         struct igmpv3_query *ihv3;
514         void *csum_start = NULL;
515         __sum16 *csum = NULL;
516         struct sk_buff *skb;
517         struct ethhdr *eth;
518         struct iphdr *iph;
519         unsigned int len;
520         int offset;
521
522         len = hlen + tlen + sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3);
523         skb = netdev_alloc_skb_ip_align(amt->dev, len);
524         if (!skb)
525                 return NULL;
526
527         skb_reserve(skb, hlen);
528         skb_push(skb, sizeof(*eth));
529         skb->protocol = htons(ETH_P_IP);
530         skb_reset_mac_header(skb);
531         skb->priority = TC_PRIO_CONTROL;
532         skb_put(skb, sizeof(*iph));
533         skb_put_data(skb, ra, sizeof(ra));
534         skb_put(skb, sizeof(*ihv3));
535         skb_pull(skb, sizeof(*eth));
536         skb_reset_network_header(skb);
537
538         iph             = ip_hdr(skb);
539         iph->version    = 4;
540         iph->ihl        = (sizeof(struct iphdr) + AMT_IPHDR_OPTS) >> 2;
541         iph->tos        = AMT_TOS;
542         iph->tot_len    = htons(sizeof(*iph) + AMT_IPHDR_OPTS + sizeof(*ihv3));
543         iph->frag_off   = htons(IP_DF);
544         iph->ttl        = 1;
545         iph->id         = 0;
546         iph->protocol   = IPPROTO_IGMP;
547         iph->daddr      = htonl(INADDR_ALLHOSTS_GROUP);
548         iph->saddr      = htonl(INADDR_ANY);
549         ip_send_check(iph);
550
551         eth = eth_hdr(skb);
552         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
553         ip_eth_mc_map(htonl(INADDR_ALLHOSTS_GROUP), eth->h_dest);
554         eth->h_proto = htons(ETH_P_IP);
555
556         ihv3            = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
557         skb_reset_transport_header(skb);
558         ihv3->type      = IGMP_HOST_MEMBERSHIP_QUERY;
559         ihv3->code      = 1;
560         ihv3->group     = 0;
561         ihv3->qqic      = amt->qi;
562         ihv3->nsrcs     = 0;
563         ihv3->resv      = 0;
564         ihv3->suppress  = false;
565         ihv3->qrv       = amt->net->ipv4.sysctl_igmp_qrv;
566         ihv3->csum      = 0;
567         csum            = &ihv3->csum;
568         csum_start      = (void *)ihv3;
569         *csum           = ip_compute_csum(csum_start, sizeof(*ihv3));
570         offset          = skb_transport_offset(skb);
571         skb->csum       = skb_checksum(skb, offset, skb->len - offset, 0);
572         skb->ip_summed  = CHECKSUM_NONE;
573
574         skb_push(skb, sizeof(*eth) + sizeof(*iph) + AMT_IPHDR_OPTS);
575
576         return skb;
577 }
578
579 static void __amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
580                                    bool validate)
581 {
582         if (validate && amt->status >= status)
583                 return;
584         netdev_dbg(amt->dev, "Update GW status %s -> %s",
585                    status_str[amt->status], status_str[status]);
586         amt->status = status;
587 }
588
589 static void __amt_update_relay_status(struct amt_tunnel_list *tunnel,
590                                       enum amt_status status,
591                                       bool validate)
592 {
593         if (validate && tunnel->status >= status)
594                 return;
595         netdev_dbg(tunnel->amt->dev,
596                    "Update Tunnel(IP = %pI4, PORT = %u) status %s -> %s",
597                    &tunnel->ip4, ntohs(tunnel->source_port),
598                    status_str[tunnel->status], status_str[status]);
599         tunnel->status = status;
600 }
601
602 static void amt_update_gw_status(struct amt_dev *amt, enum amt_status status,
603                                  bool validate)
604 {
605         spin_lock_bh(&amt->lock);
606         __amt_update_gw_status(amt, status, validate);
607         spin_unlock_bh(&amt->lock);
608 }
609
610 static void amt_update_relay_status(struct amt_tunnel_list *tunnel,
611                                     enum amt_status status, bool validate)
612 {
613         spin_lock_bh(&tunnel->lock);
614         __amt_update_relay_status(tunnel, status, validate);
615         spin_unlock_bh(&tunnel->lock);
616 }
617
618 static void amt_send_discovery(struct amt_dev *amt)
619 {
620         struct amt_header_discovery *amtd;
621         int hlen, tlen, offset;
622         struct socket *sock;
623         struct udphdr *udph;
624         struct sk_buff *skb;
625         struct iphdr *iph;
626         struct rtable *rt;
627         struct flowi4 fl4;
628         u32 len;
629         int err;
630
631         rcu_read_lock();
632         sock = rcu_dereference(amt->sock);
633         if (!sock)
634                 goto out;
635
636         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
637                 goto out;
638
639         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
640                                    amt->discovery_ip, amt->local_ip,
641                                    amt->gw_port, amt->relay_port,
642                                    IPPROTO_UDP, 0,
643                                    amt->stream_dev->ifindex);
644         if (IS_ERR(rt)) {
645                 amt->dev->stats.tx_errors++;
646                 goto out;
647         }
648
649         hlen = LL_RESERVED_SPACE(amt->dev);
650         tlen = amt->dev->needed_tailroom;
651         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
652         skb = netdev_alloc_skb_ip_align(amt->dev, len);
653         if (!skb) {
654                 ip_rt_put(rt);
655                 amt->dev->stats.tx_errors++;
656                 goto out;
657         }
658
659         skb->priority = TC_PRIO_CONTROL;
660         skb_dst_set(skb, &rt->dst);
661
662         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtd);
663         skb_reset_network_header(skb);
664         skb_put(skb, len);
665         amtd = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
666         amtd->version   = 0;
667         amtd->type      = AMT_MSG_DISCOVERY;
668         amtd->reserved  = 0;
669         amtd->nonce     = amt->nonce;
670         skb_push(skb, sizeof(*udph));
671         skb_reset_transport_header(skb);
672         udph            = udp_hdr(skb);
673         udph->source    = amt->gw_port;
674         udph->dest      = amt->relay_port;
675         udph->len       = htons(sizeof(*udph) + sizeof(*amtd));
676         udph->check     = 0;
677         offset = skb_transport_offset(skb);
678         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
679         udph->check = csum_tcpudp_magic(amt->local_ip, amt->discovery_ip,
680                                         sizeof(*udph) + sizeof(*amtd),
681                                         IPPROTO_UDP, skb->csum);
682
683         skb_push(skb, sizeof(*iph));
684         iph             = ip_hdr(skb);
685         iph->version    = 4;
686         iph->ihl        = (sizeof(struct iphdr)) >> 2;
687         iph->tos        = AMT_TOS;
688         iph->frag_off   = 0;
689         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
690         iph->daddr      = amt->discovery_ip;
691         iph->saddr      = amt->local_ip;
692         iph->protocol   = IPPROTO_UDP;
693         iph->tot_len    = htons(len);
694
695         skb->ip_summed = CHECKSUM_NONE;
696         ip_select_ident(amt->net, skb, NULL);
697         ip_send_check(iph);
698         err = ip_local_out(amt->net, sock->sk, skb);
699         if (unlikely(net_xmit_eval(err)))
700                 amt->dev->stats.tx_errors++;
701
702         spin_lock_bh(&amt->lock);
703         __amt_update_gw_status(amt, AMT_STATUS_SENT_DISCOVERY, true);
704         spin_unlock_bh(&amt->lock);
705 out:
706         rcu_read_unlock();
707 }
708
709 static void amt_send_request(struct amt_dev *amt, bool v6)
710 {
711         struct amt_header_request *amtrh;
712         int hlen, tlen, offset;
713         struct socket *sock;
714         struct udphdr *udph;
715         struct sk_buff *skb;
716         struct iphdr *iph;
717         struct rtable *rt;
718         struct flowi4 fl4;
719         u32 len;
720         int err;
721
722         rcu_read_lock();
723         sock = rcu_dereference(amt->sock);
724         if (!sock)
725                 goto out;
726
727         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
728                 goto out;
729
730         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
731                                    amt->remote_ip, amt->local_ip,
732                                    amt->gw_port, amt->relay_port,
733                                    IPPROTO_UDP, 0,
734                                    amt->stream_dev->ifindex);
735         if (IS_ERR(rt)) {
736                 amt->dev->stats.tx_errors++;
737                 goto out;
738         }
739
740         hlen = LL_RESERVED_SPACE(amt->dev);
741         tlen = amt->dev->needed_tailroom;
742         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
743         skb = netdev_alloc_skb_ip_align(amt->dev, len);
744         if (!skb) {
745                 ip_rt_put(rt);
746                 amt->dev->stats.tx_errors++;
747                 goto out;
748         }
749
750         skb->priority = TC_PRIO_CONTROL;
751         skb_dst_set(skb, &rt->dst);
752
753         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amtrh);
754         skb_reset_network_header(skb);
755         skb_put(skb, len);
756         amtrh = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
757         amtrh->version   = 0;
758         amtrh->type      = AMT_MSG_REQUEST;
759         amtrh->reserved1 = 0;
760         amtrh->p         = v6;
761         amtrh->reserved2 = 0;
762         amtrh->nonce     = amt->nonce;
763         skb_push(skb, sizeof(*udph));
764         skb_reset_transport_header(skb);
765         udph            = udp_hdr(skb);
766         udph->source    = amt->gw_port;
767         udph->dest      = amt->relay_port;
768         udph->len       = htons(sizeof(*amtrh) + sizeof(*udph));
769         udph->check     = 0;
770         offset = skb_transport_offset(skb);
771         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
772         udph->check = csum_tcpudp_magic(amt->local_ip, amt->remote_ip,
773                                         sizeof(*udph) + sizeof(*amtrh),
774                                         IPPROTO_UDP, skb->csum);
775
776         skb_push(skb, sizeof(*iph));
777         iph             = ip_hdr(skb);
778         iph->version    = 4;
779         iph->ihl        = (sizeof(struct iphdr)) >> 2;
780         iph->tos        = AMT_TOS;
781         iph->frag_off   = 0;
782         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
783         iph->daddr      = amt->remote_ip;
784         iph->saddr      = amt->local_ip;
785         iph->protocol   = IPPROTO_UDP;
786         iph->tot_len    = htons(len);
787
788         skb->ip_summed = CHECKSUM_NONE;
789         ip_select_ident(amt->net, skb, NULL);
790         ip_send_check(iph);
791         err = ip_local_out(amt->net, sock->sk, skb);
792         if (unlikely(net_xmit_eval(err)))
793                 amt->dev->stats.tx_errors++;
794
795 out:
796         rcu_read_unlock();
797 }
798
799 static void amt_send_igmp_gq(struct amt_dev *amt,
800                              struct amt_tunnel_list *tunnel)
801 {
802         struct sk_buff *skb;
803
804         skb = amt_build_igmp_gq(amt);
805         if (!skb)
806                 return;
807
808         amt_skb_cb(skb)->tunnel = tunnel;
809         dev_queue_xmit(skb);
810 }
811
812 #if IS_ENABLED(CONFIG_IPV6)
813 static struct sk_buff *amt_build_mld_gq(struct amt_dev *amt)
814 {
815         u8 ra[AMT_IP6HDR_OPTS] = { IPPROTO_ICMPV6, 0, IPV6_TLV_ROUTERALERT,
816                                    2, 0, 0, IPV6_TLV_PAD1, IPV6_TLV_PAD1 };
817         int hlen = LL_RESERVED_SPACE(amt->dev);
818         int tlen = amt->dev->needed_tailroom;
819         struct mld2_query *mld2q;
820         void *csum_start = NULL;
821         struct ipv6hdr *ip6h;
822         struct sk_buff *skb;
823         struct ethhdr *eth;
824         u32 len;
825
826         len = hlen + tlen + sizeof(*ip6h) + sizeof(ra) + sizeof(*mld2q);
827         skb = netdev_alloc_skb_ip_align(amt->dev, len);
828         if (!skb)
829                 return NULL;
830
831         skb_reserve(skb, hlen);
832         skb_push(skb, sizeof(*eth));
833         skb_reset_mac_header(skb);
834         eth = eth_hdr(skb);
835         skb->priority = TC_PRIO_CONTROL;
836         skb->protocol = htons(ETH_P_IPV6);
837         skb_put_zero(skb, sizeof(*ip6h));
838         skb_put_data(skb, ra, sizeof(ra));
839         skb_put_zero(skb, sizeof(*mld2q));
840         skb_pull(skb, sizeof(*eth));
841         skb_reset_network_header(skb);
842         ip6h                    = ipv6_hdr(skb);
843         ip6h->payload_len       = htons(sizeof(ra) + sizeof(*mld2q));
844         ip6h->nexthdr           = NEXTHDR_HOP;
845         ip6h->hop_limit         = 1;
846         ip6h->daddr             = mld2_all_node;
847         ip6_flow_hdr(ip6h, 0, 0);
848
849         if (ipv6_dev_get_saddr(amt->net, amt->dev, &ip6h->daddr, 0,
850                                &ip6h->saddr)) {
851                 amt->dev->stats.tx_errors++;
852                 kfree_skb(skb);
853                 return NULL;
854         }
855
856         eth->h_proto = htons(ETH_P_IPV6);
857         ether_addr_copy(eth->h_source, amt->dev->dev_addr);
858         ipv6_eth_mc_map(&mld2_all_node, eth->h_dest);
859
860         skb_pull(skb, sizeof(*ip6h) + sizeof(ra));
861         skb_reset_transport_header(skb);
862         mld2q                   = (struct mld2_query *)icmp6_hdr(skb);
863         mld2q->mld2q_mrc        = htons(1);
864         mld2q->mld2q_type       = ICMPV6_MGM_QUERY;
865         mld2q->mld2q_code       = 0;
866         mld2q->mld2q_cksum      = 0;
867         mld2q->mld2q_resv1      = 0;
868         mld2q->mld2q_resv2      = 0;
869         mld2q->mld2q_suppress   = 0;
870         mld2q->mld2q_qrv        = amt->qrv;
871         mld2q->mld2q_nsrcs      = 0;
872         mld2q->mld2q_qqic       = amt->qi;
873         csum_start              = (void *)mld2q;
874         mld2q->mld2q_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
875                                              sizeof(*mld2q),
876                                              IPPROTO_ICMPV6,
877                                              csum_partial(csum_start,
878                                                           sizeof(*mld2q), 0));
879
880         skb->ip_summed = CHECKSUM_NONE;
881         skb_push(skb, sizeof(*eth) + sizeof(*ip6h) + sizeof(ra));
882         return skb;
883 }
884
885 static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
886 {
887         struct sk_buff *skb;
888
889         skb = amt_build_mld_gq(amt);
890         if (!skb)
891                 return;
892
893         amt_skb_cb(skb)->tunnel = tunnel;
894         dev_queue_xmit(skb);
895 }
896 #else
/* Without CONFIG_IPV6 there is no MLD general query to send: no-op stub. */
static void amt_send_mld_gq(struct amt_dev *amt, struct amt_tunnel_list *tunnel)
{
}
900 #endif
901
902 static void amt_secret_work(struct work_struct *work)
903 {
904         struct amt_dev *amt = container_of(to_delayed_work(work),
905                                            struct amt_dev,
906                                            secret_wq);
907
908         spin_lock_bh(&amt->lock);
909         get_random_bytes(&amt->key, sizeof(siphash_key_t));
910         spin_unlock_bh(&amt->lock);
911         mod_delayed_work(amt_wq, &amt->secret_wq,
912                          msecs_to_jiffies(AMT_SECRET_TIMEOUT));
913 }
914
915 static void amt_discovery_work(struct work_struct *work)
916 {
917         struct amt_dev *amt = container_of(to_delayed_work(work),
918                                            struct amt_dev,
919                                            discovery_wq);
920
921         spin_lock_bh(&amt->lock);
922         if (amt->status > AMT_STATUS_SENT_DISCOVERY)
923                 goto out;
924         get_random_bytes(&amt->nonce, sizeof(__be32));
925         spin_unlock_bh(&amt->lock);
926
927         amt_send_discovery(amt);
928         spin_lock_bh(&amt->lock);
929 out:
930         mod_delayed_work(amt_wq, &amt->discovery_wq,
931                          msecs_to_jiffies(AMT_DISCOVERY_TIMEOUT));
932         spin_unlock_bh(&amt->lock);
933 }
934
935 static void amt_req_work(struct work_struct *work)
936 {
937         struct amt_dev *amt = container_of(to_delayed_work(work),
938                                            struct amt_dev,
939                                            req_wq);
940         u32 exp;
941
942         spin_lock_bh(&amt->lock);
943         if (amt->status < AMT_STATUS_RECEIVED_ADVERTISEMENT)
944                 goto out;
945
946         if (amt->req_cnt++ > AMT_MAX_REQ_COUNT) {
947                 netdev_dbg(amt->dev, "Gateway is not ready");
948                 amt->qi = AMT_INIT_REQ_TIMEOUT;
949                 amt->ready4 = false;
950                 amt->ready6 = false;
951                 amt->remote_ip = 0;
952                 __amt_update_gw_status(amt, AMT_STATUS_INIT, false);
953                 amt->req_cnt = 0;
954         }
955         spin_unlock_bh(&amt->lock);
956
957         amt_send_request(amt, false);
958         amt_send_request(amt, true);
959         amt_update_gw_status(amt, AMT_STATUS_SENT_REQUEST, true);
960         spin_lock_bh(&amt->lock);
961 out:
962         exp = min_t(u32, (1 * (1 << amt->req_cnt)), AMT_MAX_REQ_TIMEOUT);
963         mod_delayed_work(amt_wq, &amt->req_wq, msecs_to_jiffies(exp * 1000));
964         spin_unlock_bh(&amt->lock);
965 }
966
967 static bool amt_send_membership_update(struct amt_dev *amt,
968                                        struct sk_buff *skb,
969                                        bool v6)
970 {
971         struct amt_header_membership_update *amtmu;
972         struct socket *sock;
973         struct iphdr *iph;
974         struct flowi4 fl4;
975         struct rtable *rt;
976         int err;
977
978         sock = rcu_dereference_bh(amt->sock);
979         if (!sock)
980                 return true;
981
982         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmu) +
983                            sizeof(*iph) + sizeof(struct udphdr));
984         if (err)
985                 return true;
986
987         skb_reset_inner_headers(skb);
988         memset(&fl4, 0, sizeof(struct flowi4));
989         fl4.flowi4_oif         = amt->stream_dev->ifindex;
990         fl4.daddr              = amt->remote_ip;
991         fl4.saddr              = amt->local_ip;
992         fl4.flowi4_tos         = AMT_TOS;
993         fl4.flowi4_proto       = IPPROTO_UDP;
994         rt = ip_route_output_key(amt->net, &fl4);
995         if (IS_ERR(rt)) {
996                 netdev_dbg(amt->dev, "no route to %pI4\n", &amt->remote_ip);
997                 return true;
998         }
999
1000         amtmu                   = skb_push(skb, sizeof(*amtmu));
1001         amtmu->version          = 0;
1002         amtmu->type             = AMT_MSG_MEMBERSHIP_UPDATE;
1003         amtmu->reserved         = 0;
1004         amtmu->nonce            = amt->nonce;
1005         amtmu->response_mac     = amt->mac;
1006
1007         if (!v6)
1008                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1009         else
1010                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1011         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1012                             fl4.saddr,
1013                             fl4.daddr,
1014                             AMT_TOS,
1015                             ip4_dst_hoplimit(&rt->dst),
1016                             0,
1017                             amt->gw_port,
1018                             amt->relay_port,
1019                             false,
1020                             false);
1021         amt_update_gw_status(amt, AMT_STATUS_SENT_UPDATE, true);
1022         return false;
1023 }
1024
1025 static void amt_send_multicast_data(struct amt_dev *amt,
1026                                     const struct sk_buff *oskb,
1027                                     struct amt_tunnel_list *tunnel,
1028                                     bool v6)
1029 {
1030         struct amt_header_mcast_data *amtmd;
1031         struct socket *sock;
1032         struct sk_buff *skb;
1033         struct iphdr *iph;
1034         struct flowi4 fl4;
1035         struct rtable *rt;
1036
1037         sock = rcu_dereference_bh(amt->sock);
1038         if (!sock)
1039                 return;
1040
1041         skb = skb_copy_expand(oskb, sizeof(*amtmd) + sizeof(*iph) +
1042                               sizeof(struct udphdr), 0, GFP_ATOMIC);
1043         if (!skb)
1044                 return;
1045
1046         skb_reset_inner_headers(skb);
1047         memset(&fl4, 0, sizeof(struct flowi4));
1048         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1049         fl4.daddr              = tunnel->ip4;
1050         fl4.saddr              = amt->local_ip;
1051         fl4.flowi4_proto       = IPPROTO_UDP;
1052         rt = ip_route_output_key(amt->net, &fl4);
1053         if (IS_ERR(rt)) {
1054                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1055                 kfree_skb(skb);
1056                 return;
1057         }
1058
1059         amtmd = skb_push(skb, sizeof(*amtmd));
1060         amtmd->version = 0;
1061         amtmd->reserved = 0;
1062         amtmd->type = AMT_MSG_MULTICAST_DATA;
1063
1064         if (!v6)
1065                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1066         else
1067                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1068         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1069                             fl4.saddr,
1070                             fl4.daddr,
1071                             AMT_TOS,
1072                             ip4_dst_hoplimit(&rt->dst),
1073                             0,
1074                             amt->relay_port,
1075                             tunnel->source_port,
1076                             false,
1077                             false);
1078 }
1079
1080 static bool amt_send_membership_query(struct amt_dev *amt,
1081                                       struct sk_buff *skb,
1082                                       struct amt_tunnel_list *tunnel,
1083                                       bool v6)
1084 {
1085         struct amt_header_membership_query *amtmq;
1086         struct socket *sock;
1087         struct rtable *rt;
1088         struct flowi4 fl4;
1089         int err;
1090
1091         sock = rcu_dereference_bh(amt->sock);
1092         if (!sock)
1093                 return true;
1094
1095         err = skb_cow_head(skb, LL_RESERVED_SPACE(amt->dev) + sizeof(*amtmq) +
1096                            sizeof(struct iphdr) + sizeof(struct udphdr));
1097         if (err)
1098                 return true;
1099
1100         skb_reset_inner_headers(skb);
1101         memset(&fl4, 0, sizeof(struct flowi4));
1102         fl4.flowi4_oif         = amt->stream_dev->ifindex;
1103         fl4.daddr              = tunnel->ip4;
1104         fl4.saddr              = amt->local_ip;
1105         fl4.flowi4_tos         = AMT_TOS;
1106         fl4.flowi4_proto       = IPPROTO_UDP;
1107         rt = ip_route_output_key(amt->net, &fl4);
1108         if (IS_ERR(rt)) {
1109                 netdev_dbg(amt->dev, "no route to %pI4\n", &tunnel->ip4);
1110                 return true;
1111         }
1112
1113         amtmq           = skb_push(skb, sizeof(*amtmq));
1114         amtmq->version  = 0;
1115         amtmq->type     = AMT_MSG_MEMBERSHIP_QUERY;
1116         amtmq->reserved = 0;
1117         amtmq->l        = 0;
1118         amtmq->g        = 0;
1119         amtmq->nonce    = tunnel->nonce;
1120         amtmq->response_mac = tunnel->mac;
1121
1122         if (!v6)
1123                 skb_set_inner_protocol(skb, htons(ETH_P_IP));
1124         else
1125                 skb_set_inner_protocol(skb, htons(ETH_P_IPV6));
1126         udp_tunnel_xmit_skb(rt, sock->sk, skb,
1127                             fl4.saddr,
1128                             fl4.daddr,
1129                             AMT_TOS,
1130                             ip4_dst_hoplimit(&rt->dst),
1131                             0,
1132                             amt->relay_port,
1133                             tunnel->source_port,
1134                             false,
1135                             false);
1136         amt_update_relay_status(tunnel, AMT_STATUS_SENT_QUERY, true);
1137         return false;
1138 }
1139
1140 static netdev_tx_t amt_dev_xmit(struct sk_buff *skb, struct net_device *dev)
1141 {
1142         struct amt_dev *amt = netdev_priv(dev);
1143         struct amt_tunnel_list *tunnel;
1144         struct amt_group_node *gnode;
1145         union amt_addr group = {0,};
1146 #if IS_ENABLED(CONFIG_IPV6)
1147         struct ipv6hdr *ip6h;
1148         struct mld_msg *mld;
1149 #endif
1150         bool report = false;
1151         struct igmphdr *ih;
1152         bool query = false;
1153         struct iphdr *iph;
1154         bool data = false;
1155         bool v6 = false;
1156         u32 hash;
1157
1158         iph = ip_hdr(skb);
1159         if (iph->version == 4) {
1160                 if (!ipv4_is_multicast(iph->daddr))
1161                         goto free;
1162
1163                 if (!ip_mc_check_igmp(skb)) {
1164                         ih = igmp_hdr(skb);
1165                         switch (ih->type) {
1166                         case IGMPV3_HOST_MEMBERSHIP_REPORT:
1167                         case IGMP_HOST_MEMBERSHIP_REPORT:
1168                                 report = true;
1169                                 break;
1170                         case IGMP_HOST_MEMBERSHIP_QUERY:
1171                                 query = true;
1172                                 break;
1173                         default:
1174                                 goto free;
1175                         }
1176                 } else {
1177                         data = true;
1178                 }
1179                 v6 = false;
1180                 group.ip4 = iph->daddr;
1181 #if IS_ENABLED(CONFIG_IPV6)
1182         } else if (iph->version == 6) {
1183                 ip6h = ipv6_hdr(skb);
1184                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
1185                         goto free;
1186
1187                 if (!ipv6_mc_check_mld(skb)) {
1188                         mld = (struct mld_msg *)skb_transport_header(skb);
1189                         switch (mld->mld_type) {
1190                         case ICMPV6_MGM_REPORT:
1191                         case ICMPV6_MLD2_REPORT:
1192                                 report = true;
1193                                 break;
1194                         case ICMPV6_MGM_QUERY:
1195                                 query = true;
1196                                 break;
1197                         default:
1198                                 goto free;
1199                         }
1200                 } else {
1201                         data = true;
1202                 }
1203                 v6 = true;
1204                 group.ip6 = ip6h->daddr;
1205 #endif
1206         } else {
1207                 dev->stats.tx_errors++;
1208                 goto free;
1209         }
1210
1211         if (!pskb_may_pull(skb, sizeof(struct ethhdr)))
1212                 goto free;
1213
1214         skb_pull(skb, sizeof(struct ethhdr));
1215
1216         if (amt->mode == AMT_MODE_GATEWAY) {
1217                 /* Gateway only passes IGMP/MLD packets */
1218                 if (!report)
1219                         goto free;
1220                 if ((!v6 && !amt->ready4) || (v6 && !amt->ready6))
1221                         goto free;
1222                 if (amt_send_membership_update(amt, skb,  v6))
1223                         goto free;
1224                 goto unlock;
1225         } else if (amt->mode == AMT_MODE_RELAY) {
1226                 if (query) {
1227                         tunnel = amt_skb_cb(skb)->tunnel;
1228                         if (!tunnel) {
1229                                 WARN_ON(1);
1230                                 goto free;
1231                         }
1232
1233                         /* Do not forward unexpected query */
1234                         if (amt_send_membership_query(amt, skb, tunnel, v6))
1235                                 goto free;
1236                         goto unlock;
1237                 }
1238
1239                 if (!data)
1240                         goto free;
1241                 list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
1242                         hash = amt_group_hash(tunnel, &group);
1243                         hlist_for_each_entry_rcu(gnode, &tunnel->groups[hash],
1244                                                  node) {
1245                                 if (!v6) {
1246                                         if (gnode->group_addr.ip4 == iph->daddr)
1247                                                 goto found;
1248 #if IS_ENABLED(CONFIG_IPV6)
1249                                 } else {
1250                                         if (ipv6_addr_equal(&gnode->group_addr.ip6,
1251                                                             &ip6h->daddr))
1252                                                 goto found;
1253 #endif
1254                                 }
1255                         }
1256                         continue;
1257 found:
1258                         amt_send_multicast_data(amt, skb, tunnel, v6);
1259                 }
1260         }
1261
1262         dev_kfree_skb(skb);
1263         return NETDEV_TX_OK;
1264 free:
1265         dev_kfree_skb(skb);
1266 unlock:
1267         dev->stats.tx_dropped++;
1268         return NETDEV_TX_OK;
1269 }
1270
1271 static int amt_parse_type(struct sk_buff *skb)
1272 {
1273         struct amt_header *amth;
1274
1275         if (!pskb_may_pull(skb, sizeof(struct udphdr) +
1276                            sizeof(struct amt_header)))
1277                 return -1;
1278
1279         amth = (struct amt_header *)(udp_hdr(skb) + 1);
1280
1281         if (amth->version != 0)
1282                 return -1;
1283
1284         if (amth->type >= __AMT_MSG_MAX || !amth->type)
1285                 return -1;
1286         return amth->type;
1287 }
1288
1289 static void amt_clear_groups(struct amt_tunnel_list *tunnel)
1290 {
1291         struct amt_dev *amt = tunnel->amt;
1292         struct amt_group_node *gnode;
1293         struct hlist_node *t;
1294         int i;
1295
1296         spin_lock_bh(&tunnel->lock);
1297         rcu_read_lock();
1298         for (i = 0; i < amt->hash_buckets; i++)
1299                 hlist_for_each_entry_safe(gnode, t, &tunnel->groups[i], node)
1300                         amt_del_group(amt, gnode);
1301         rcu_read_unlock();
1302         spin_unlock_bh(&tunnel->lock);
1303 }
1304
1305 static void amt_tunnel_expire(struct work_struct *work)
1306 {
1307         struct amt_tunnel_list *tunnel = container_of(to_delayed_work(work),
1308                                                       struct amt_tunnel_list,
1309                                                       gc_wq);
1310         struct amt_dev *amt = tunnel->amt;
1311
1312         spin_lock_bh(&amt->lock);
1313         rcu_read_lock();
1314         list_del_rcu(&tunnel->list);
1315         amt->nr_tunnels--;
1316         amt_clear_groups(tunnel);
1317         rcu_read_unlock();
1318         spin_unlock_bh(&amt->lock);
1319         kfree_rcu(tunnel, rcu);
1320 }
1321
1322 static void amt_cleanup_srcs(struct amt_dev *amt,
1323                              struct amt_tunnel_list *tunnel,
1324                              struct amt_group_node *gnode)
1325 {
1326         struct amt_source_node *snode;
1327         struct hlist_node *t;
1328         int i;
1329
1330         /* Delete old sources */
1331         for (i = 0; i < amt->hash_buckets; i++) {
1332                 hlist_for_each_entry_safe(snode, t, &gnode->sources[i], node) {
1333                         if (snode->flags == AMT_SOURCE_OLD)
1334                                 amt_destroy_source(snode);
1335                 }
1336         }
1337
1338         /* switch from new to old */
1339         for (i = 0; i < amt->hash_buckets; i++)  {
1340                 hlist_for_each_entry_rcu(snode, &gnode->sources[i], node) {
1341                         snode->flags = AMT_SOURCE_OLD;
1342                         if (!gnode->v6)
1343                                 netdev_dbg(snode->gnode->amt->dev,
1344                                            "Add source as OLD %pI4 from %pI4\n",
1345                                            &snode->source_addr.ip4,
1346                                            &gnode->group_addr.ip4);
1347 #if IS_ENABLED(CONFIG_IPV6)
1348                         else
1349                                 netdev_dbg(snode->gnode->amt->dev,
1350                                            "Add source as OLD %pI6 from %pI6\n",
1351                                            &snode->source_addr.ip6,
1352                                            &gnode->group_addr.ip6);
1353 #endif
1354                 }
1355         }
1356 }
1357
1358 static void amt_add_srcs(struct amt_dev *amt, struct amt_tunnel_list *tunnel,
1359                          struct amt_group_node *gnode, void *grec,
1360                          bool v6)
1361 {
1362         struct igmpv3_grec *igmp_grec;
1363         struct amt_source_node *snode;
1364 #if IS_ENABLED(CONFIG_IPV6)
1365         struct mld2_grec *mld_grec;
1366 #endif
1367         union amt_addr src = {0,};
1368         u16 nsrcs;
1369         u32 hash;
1370         int i;
1371
1372         if (!v6) {
1373                 igmp_grec = (struct igmpv3_grec *)grec;
1374                 nsrcs = ntohs(igmp_grec->grec_nsrcs);
1375         } else {
1376 #if IS_ENABLED(CONFIG_IPV6)
1377                 mld_grec = (struct mld2_grec *)grec;
1378                 nsrcs = ntohs(mld_grec->grec_nsrcs);
1379 #else
1380         return;
1381 #endif
1382         }
1383         for (i = 0; i < nsrcs; i++) {
1384                 if (tunnel->nr_sources >= amt->max_sources)
1385                         return;
1386                 if (!v6)
1387                         src.ip4 = igmp_grec->grec_src[i];
1388 #if IS_ENABLED(CONFIG_IPV6)
1389                 else
1390                         memcpy(&src.ip6, &mld_grec->grec_src[i],
1391                                sizeof(struct in6_addr));
1392 #endif
1393                 if (amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL, &src))
1394                         continue;
1395
1396                 snode = amt_alloc_snode(gnode, &src);
1397                 if (snode) {
1398                         hash = amt_source_hash(tunnel, &snode->source_addr);
1399                         hlist_add_head_rcu(&snode->node, &gnode->sources[hash]);
1400                         tunnel->nr_sources++;
1401                         gnode->nr_sources++;
1402
1403                         if (!gnode->v6)
1404                                 netdev_dbg(snode->gnode->amt->dev,
1405                                            "Add source as NEW %pI4 from %pI4\n",
1406                                            &snode->source_addr.ip4,
1407                                            &gnode->group_addr.ip4);
1408 #if IS_ENABLED(CONFIG_IPV6)
1409                         else
1410                                 netdev_dbg(snode->gnode->amt->dev,
1411                                            "Add source as NEW %pI6 from %pI6\n",
1412                                            &snode->source_addr.ip6,
1413                                            &gnode->group_addr.ip6);
1414 #endif
1415                 }
1416         }
1417 }
1418
1419 /* Router State   Report Rec'd New Router State
1420  * ------------   ------------ ----------------
1421  * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)
1422  *
1423  * -----------+-----------+-----------+
1424  *            |    OLD    |    NEW    |
1425  * -----------+-----------+-----------+
1426  *    FWD     |     X     |    X+A    |
1427  * -----------+-----------+-----------+
1428  *    D_FWD   |     Y     |    Y-A    |
1429  * -----------+-----------+-----------+
1430  *    NONE    |           |     A     |
1431  * -----------+-----------+-----------+
1432  *
1433  * a) Received sources are NONE/NEW
1434  * b) All NONE will be deleted by amt_cleanup_srcs().
1435  * c) All OLD will be deleted by amt_cleanup_srcs().
1436  * d) After delete, NEW source will be switched to OLD.
1437  */
/*
 * amt_lookup_act_srcs - apply @act to a selected set of sources of @gnode.
 * @tunnel: tunnel owning @gnode
 * @gnode:  group whose sources are examined
 * @grec:   group record; struct igmpv3_grec when !@v6, struct mld2_grec
 *          when @v6.  Only its source count and source list are read.
 * @ops:    which set to act on, with A = existing sources of @gnode and
 *          B = sources listed in @grec:
 *          AMT_OPS_INT     (A*B): each source in B that amt_lookup_src()
 *                                 finds under @filter.
 *          AMT_OPS_UNI     (A+B): every existing source passing
 *                                 amt_status_filter(@filter), then each
 *                                 source in B found under @filter.  A
 *                                 source can be visited by both passes.
 *          AMT_OPS_SUB     (A-B): every existing source passing
 *                                 amt_status_filter(@filter) whose address
 *                                 is not listed in B.
 *          AMT_OPS_SUB_REV (B-A): each source in B that is NOT found under
 *                                 AMT_FILTER_ALL but IS found under
 *                                 @filter.
 * @filter: state filter used for the selection above
 * @act:    action handed to amt_act_src() for each selected source
 * @v6:     selects the MLDv2 record layout; without CONFIG_IPV6 the
 *          function returns immediately when set
 *
 * An unknown @ops is reported with netdev_dbg() and otherwise ignored.
 */
static void amt_lookup_act_srcs(struct amt_tunnel_list *tunnel,
				struct amt_group_node *gnode,
				void *grec,
				enum amt_ops ops,
				enum amt_filter filter,
				enum amt_act act,
				bool v6)
{
	struct amt_dev *amt = tunnel->amt;
	struct amt_source_node *snode;
	struct igmpv3_grec *igmp_grec;
#if IS_ENABLED(CONFIG_IPV6)
	struct mld2_grec *mld_grec;
#endif
	union amt_addr src = {0,};
	struct hlist_node *t;
	u16 nsrcs;
	int i, j;

	/* Interpret @grec according to the address family */
	if (!v6) {
		igmp_grec = (struct igmpv3_grec *)grec;
		nsrcs = ntohs(igmp_grec->grec_nsrcs);
	} else {
#if IS_ENABLED(CONFIG_IPV6)
		mld_grec = (struct mld2_grec *)grec;
		nsrcs = ntohs(mld_grec->grec_nsrcs);
#else
	return;
#endif
	}

	memset(&src, 0, sizeof(union amt_addr));
	switch (ops) {
	case AMT_OPS_INT:
		/* A*B */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, filter, &src);
			if (!snode)
				continue;
			amt_act_src(tunnel, gnode, snode, act);
		}
		break;
	case AMT_OPS_UNI:
		/* A+B */
		/* Pass 1: existing sources (A) that pass @filter */
		for (i = 0; i < amt->hash_buckets; i++) {
			hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
						  node) {
				if (amt_status_filter(snode, filter))
					amt_act_src(tunnel, gnode, snode, act);
			}
		}
		/* Pass 2: record sources (B) found under @filter */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, filter, &src);
			if (!snode)
				continue;
			amt_act_src(tunnel, gnode, snode, act);
		}
		break;
	case AMT_OPS_SUB:
		/* A-B */
		for (i = 0; i < amt->hash_buckets; i++) {
			hlist_for_each_entry_safe(snode, t, &gnode->sources[i],
						  node) {
				if (!amt_status_filter(snode, filter))
					continue;
				/* Skip @snode if its address appears in B */
				for (j = 0; j < nsrcs; j++) {
					if (!v6)
						src.ip4 = igmp_grec->grec_src[j];
#if IS_ENABLED(CONFIG_IPV6)
					else
						memcpy(&src.ip6,
						       &mld_grec->grec_src[j],
						       sizeof(struct in6_addr));
#endif
					if (amt_addr_equal(&snode->source_addr,
							   &src))
						goto out_sub;
				}
				amt_act_src(tunnel, gnode, snode, act);
				continue;
out_sub:;
			}
		}
		break;
	case AMT_OPS_SUB_REV:
		/* B-A */
		/* NOTE(review): this selects something only if @filter can
		 * match sources that AMT_FILTER_ALL does not; confirm against
		 * amt_lookup_src().
		 */
		for (i = 0; i < nsrcs; i++) {
			if (!v6)
				src.ip4 = igmp_grec->grec_src[i];
#if IS_ENABLED(CONFIG_IPV6)
			else
				memcpy(&src.ip6, &mld_grec->grec_src[i],
				       sizeof(struct in6_addr));
#endif
			snode = amt_lookup_src(tunnel, gnode, AMT_FILTER_ALL,
					       &src);
			if (!snode) {
				snode = amt_lookup_src(tunnel, gnode,
						       filter, &src);
				if (snode)
					amt_act_src(tunnel, gnode, snode, act);
			}
		}
		break;
	default:
		netdev_dbg(amt->dev, "Invalid type\n");
		return;
	}
}
1561
/*
 * amt_mcast_is_in_handler - apply an IS_IN group record to @gnode.
 * @amt:       AMT device (not referenced by this function)
 * @tunnel:    tunnel the record arrived on
 * @gnode:     group the record refers to
 * @grec:      received group record (IGMPv3 or MLDv2, per @v6)
 * @zero_grec: record used where only the existing sources should be
 *             walked; presumably it lists no sources (TODO confirm)
 * @v6:        the record is an MLDv2 record
 *
 * Implements the IS_IN rows of the router state tables quoted below as a
 * fixed sequence of amt_lookup_act_srcs() passes.  The order of the
 * passes matters: later passes filter on states that earlier passes set
 * (e.g. pass 1 acts with AMT_ACT_STATUS_FWD_NEW and the last INCLUDE pass
 * filters on AMT_FILTER_FWD_NEW).
 */
static void amt_mcast_is_in_handler(struct amt_dev *amt,
				    struct amt_tunnel_list *tunnel,
				    struct amt_group_node *gnode,
				    void *grec, void *zero_grec, bool v6)
{
	if (gnode->filter_mode == MCAST_INCLUDE) {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * INCLUDE (A)    IS_IN (B)    INCLUDE (A+B)           (B)=GMI
 */
		/* Update IS_IN (B) as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_NONE_NEW,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* Keep INCLUDE (A): existing FWD sources become FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* (B)=GMI */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_FWD_NEW,
				    AMT_ACT_GMI,
				    v6);
	} else {
/* Router State   Report Rec'd New Router State        Actions
 * ------------   ------------ ----------------        -------
 * EXCLUDE (X,Y)  IS_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
 */
		/* Update (A) in (X, Y) as NONE/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_BOTH,
				    AMT_ACT_STATUS_NONE_NEW,
				    v6);
		/* Update FWD/OLD as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
				    AMT_FILTER_FWD,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* Update IS_IN (A) as FWD/NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
				    AMT_FILTER_NONE_NEW,
				    AMT_ACT_STATUS_FWD_NEW,
				    v6);
		/* Update EXCLUDE (, Y-A) as D_FWD_NEW */
		amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
				    AMT_FILTER_D_FWD,
				    AMT_ACT_STATUS_D_FWD_NEW,
				    v6);
	}
}
1614
1615 static void amt_mcast_is_ex_handler(struct amt_dev *amt,
1616                                     struct amt_tunnel_list *tunnel,
1617                                     struct amt_group_node *gnode,
1618                                     void *grec, void *zero_grec, bool v6)
1619 {
1620         if (gnode->filter_mode == MCAST_INCLUDE) {
1621 /* Router State   Report Rec'd  New Router State         Actions
1622  * ------------   ------------  ----------------         -------
1623  * INCLUDE (A)    IS_EX (B)     EXCLUDE (A*B,B-A)        (B-A)=0
1624  *                                                       Delete (A-B)
1625  *                                                       Group Timer=GMI
1626  */
1627                 /* EXCLUDE(A*B, ) */
1628                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1629                                     AMT_FILTER_FWD,
1630                                     AMT_ACT_STATUS_FWD_NEW,
1631                                     v6);
1632                 /* EXCLUDE(, B-A) */
1633                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1634                                     AMT_FILTER_FWD,
1635                                     AMT_ACT_STATUS_D_FWD_NEW,
1636                                     v6);
1637                 /* (B-A)=0 */
1638                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1639                                     AMT_FILTER_D_FWD_NEW,
1640                                     AMT_ACT_GMI_ZERO,
1641                                     v6);
1642                 /* Group Timer=GMI */
1643                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1644                                       msecs_to_jiffies(amt_gmi(amt))))
1645                         dev_hold(amt->dev);
1646                 gnode->filter_mode = MCAST_EXCLUDE;
1647                 /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
1648         } else {
1649 /* Router State   Report Rec'd  New Router State        Actions
1650  * ------------   ------------  ----------------        -------
1651  * EXCLUDE (X,Y)  IS_EX (A)     EXCLUDE (A-Y,Y*A)       (A-X-Y)=GMI
1652  *                                                      Delete (X-A)
1653  *                                                      Delete (Y-A)
1654  *                                                      Group Timer=GMI
1655  */
1656                 /* EXCLUDE (A-Y, ) */
1657                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1658                                     AMT_FILTER_D_FWD,
1659                                     AMT_ACT_STATUS_FWD_NEW,
1660                                     v6);
1661                 /* EXCLUDE (, Y*A ) */
1662                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1663                                     AMT_FILTER_D_FWD,
1664                                     AMT_ACT_STATUS_D_FWD_NEW,
1665                                     v6);
1666                 /* (A-X-Y)=GMI */
1667                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1668                                     AMT_FILTER_BOTH_NEW,
1669                                     AMT_ACT_GMI,
1670                                     v6);
1671                 /* Group Timer=GMI */
1672                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1673                                       msecs_to_jiffies(amt_gmi(amt))))
1674                         dev_hold(amt->dev);
1675                 /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
1676         }
1677 }
1678
1679 static void amt_mcast_to_in_handler(struct amt_dev *amt,
1680                                     struct amt_tunnel_list *tunnel,
1681                                     struct amt_group_node *gnode,
1682                                     void *grec, void *zero_grec, bool v6)
1683 {
1684         if (gnode->filter_mode == MCAST_INCLUDE) {
1685 /* Router State   Report Rec'd New Router State        Actions
1686  * ------------   ------------ ----------------        -------
1687  * INCLUDE (A)    TO_IN (B)    INCLUDE (A+B)           (B)=GMI
1688  *                                                     Send Q(G,A-B)
1689  */
1690                 /* Update TO_IN (B) sources as FWD/NEW */
1691                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1692                                     AMT_FILTER_NONE_NEW,
1693                                     AMT_ACT_STATUS_FWD_NEW,
1694                                     v6);
1695                 /* Update INCLUDE (A) sources as NEW */
1696                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1697                                     AMT_FILTER_FWD,
1698                                     AMT_ACT_STATUS_FWD_NEW,
1699                                     v6);
1700                 /* (B)=GMI */
1701                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1702                                     AMT_FILTER_FWD_NEW,
1703                                     AMT_ACT_GMI,
1704                                     v6);
1705         } else {
1706 /* Router State   Report Rec'd New Router State        Actions
1707  * ------------   ------------ ----------------        -------
1708  * EXCLUDE (X,Y)  TO_IN (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1709  *                                                     Send Q(G,X-A)
1710  *                                                     Send Q(G)
1711  */
1712                 /* Update TO_IN (A) sources as FWD/NEW */
1713                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1714                                     AMT_FILTER_NONE_NEW,
1715                                     AMT_ACT_STATUS_FWD_NEW,
1716                                     v6);
1717                 /* Update EXCLUDE(X,) sources as FWD/NEW */
1718                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1719                                     AMT_FILTER_FWD,
1720                                     AMT_ACT_STATUS_FWD_NEW,
1721                                     v6);
1722                 /* EXCLUDE (, Y-A)
1723                  * (A) are already switched to FWD_NEW.
1724                  * So, D_FWD/OLD -> D_FWD/NEW is okay.
1725                  */
1726                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1727                                     AMT_FILTER_D_FWD,
1728                                     AMT_ACT_STATUS_D_FWD_NEW,
1729                                     v6);
1730                 /* (A)=GMI
1731                  * Only FWD_NEW will have (A) sources.
1732                  */
1733                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1734                                     AMT_FILTER_FWD_NEW,
1735                                     AMT_ACT_GMI,
1736                                     v6);
1737         }
1738 }
1739
1740 static void amt_mcast_to_ex_handler(struct amt_dev *amt,
1741                                     struct amt_tunnel_list *tunnel,
1742                                     struct amt_group_node *gnode,
1743                                     void *grec, void *zero_grec, bool v6)
1744 {
1745         if (gnode->filter_mode == MCAST_INCLUDE) {
1746 /* Router State   Report Rec'd New Router State        Actions
1747  * ------------   ------------ ----------------        -------
1748  * INCLUDE (A)    TO_EX (B)    EXCLUDE (A*B,B-A)       (B-A)=0
1749  *                                                     Delete (A-B)
1750  *                                                     Send Q(G,A*B)
1751  *                                                     Group Timer=GMI
1752  */
1753                 /* EXCLUDE (A*B, ) */
1754                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1755                                     AMT_FILTER_FWD,
1756                                     AMT_ACT_STATUS_FWD_NEW,
1757                                     v6);
1758                 /* EXCLUDE (, B-A) */
1759                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1760                                     AMT_FILTER_FWD,
1761                                     AMT_ACT_STATUS_D_FWD_NEW,
1762                                     v6);
1763                 /* (B-A)=0 */
1764                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1765                                     AMT_FILTER_D_FWD_NEW,
1766                                     AMT_ACT_GMI_ZERO,
1767                                     v6);
1768                 /* Group Timer=GMI */
1769                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1770                                       msecs_to_jiffies(amt_gmi(amt))))
1771                         dev_hold(amt->dev);
1772                 gnode->filter_mode = MCAST_EXCLUDE;
1773                 /* Delete (A-B) will be worked by amt_cleanup_srcs(). */
1774         } else {
1775 /* Router State   Report Rec'd New Router State        Actions
1776  * ------------   ------------ ----------------        -------
1777  * EXCLUDE (X,Y)  TO_EX (A)    EXCLUDE (A-Y,Y*A)       (A-X-Y)=Group Timer
1778  *                                                     Delete (X-A)
1779  *                                                     Delete (Y-A)
1780  *                                                     Send Q(G,A-Y)
1781  *                                                     Group Timer=GMI
1782  */
1783                 /* Update (A-X-Y) as NONE/OLD */
1784                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1785                                     AMT_FILTER_BOTH,
1786                                     AMT_ACT_GT,
1787                                     v6);
1788                 /* EXCLUDE (A-Y, ) */
1789                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1790                                     AMT_FILTER_D_FWD,
1791                                     AMT_ACT_STATUS_FWD_NEW,
1792                                     v6);
1793                 /* EXCLUDE (, Y*A) */
1794                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1795                                     AMT_FILTER_D_FWD,
1796                                     AMT_ACT_STATUS_D_FWD_NEW,
1797                                     v6);
1798                 /* Group Timer=GMI */
1799                 if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1800                                       msecs_to_jiffies(amt_gmi(amt))))
1801                         dev_hold(amt->dev);
1802                 /* Delete (X-A), (Y-A) will be worked by amt_cleanup_srcs(). */
1803         }
1804 }
1805
1806 static void amt_mcast_allow_handler(struct amt_dev *amt,
1807                                     struct amt_tunnel_list *tunnel,
1808                                     struct amt_group_node *gnode,
1809                                     void *grec, void *zero_grec, bool v6)
1810 {
1811         if (gnode->filter_mode == MCAST_INCLUDE) {
1812 /* Router State   Report Rec'd New Router State        Actions
1813  * ------------   ------------ ----------------        -------
1814  * INCLUDE (A)    ALLOW (B)    INCLUDE (A+B)           (B)=GMI
1815  */
1816                 /* INCLUDE (A+B) */
1817                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1818                                     AMT_FILTER_FWD,
1819                                     AMT_ACT_STATUS_FWD_NEW,
1820                                     v6);
1821                 /* (B)=GMI */
1822                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1823                                     AMT_FILTER_FWD_NEW,
1824                                     AMT_ACT_GMI,
1825                                     v6);
1826         } else {
1827 /* Router State   Report Rec'd New Router State        Actions
1828  * ------------   ------------ ----------------        -------
1829  * EXCLUDE (X,Y)  ALLOW (A)    EXCLUDE (X+A,Y-A)       (A)=GMI
1830  */
1831                 /* EXCLUDE (X+A, ) */
1832                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1833                                     AMT_FILTER_FWD,
1834                                     AMT_ACT_STATUS_FWD_NEW,
1835                                     v6);
1836                 /* EXCLUDE (, Y-A) */
1837                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB,
1838                                     AMT_FILTER_D_FWD,
1839                                     AMT_ACT_STATUS_D_FWD_NEW,
1840                                     v6);
1841                 /* (A)=GMI
1842                  * All (A) source are now FWD/NEW status.
1843                  */
1844                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_INT,
1845                                     AMT_FILTER_FWD_NEW,
1846                                     AMT_ACT_GMI,
1847                                     v6);
1848         }
1849 }
1850
1851 static void amt_mcast_block_handler(struct amt_dev *amt,
1852                                     struct amt_tunnel_list *tunnel,
1853                                     struct amt_group_node *gnode,
1854                                     void *grec, void *zero_grec, bool v6)
1855 {
1856         if (gnode->filter_mode == MCAST_INCLUDE) {
1857 /* Router State   Report Rec'd New Router State        Actions
1858  * ------------   ------------ ----------------        -------
1859  * INCLUDE (A)    BLOCK (B)    INCLUDE (A)             Send Q(G,A*B)
1860  */
1861                 /* INCLUDE (A) */
1862                 amt_lookup_act_srcs(tunnel, gnode, zero_grec, AMT_OPS_UNI,
1863                                     AMT_FILTER_FWD,
1864                                     AMT_ACT_STATUS_FWD_NEW,
1865                                     v6);
1866         } else {
1867 /* Router State   Report Rec'd New Router State        Actions
1868  * ------------   ------------ ----------------        -------
1869  * EXCLUDE (X,Y)  BLOCK (A)    EXCLUDE (X+(A-Y),Y)     (A-X-Y)=Group Timer
1870  *                                                     Send Q(G,A-Y)
1871  */
1872                 /* (A-X-Y)=Group Timer */
1873                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1874                                     AMT_FILTER_BOTH,
1875                                     AMT_ACT_GT,
1876                                     v6);
1877                 /* EXCLUDE (X, ) */
1878                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1879                                     AMT_FILTER_FWD,
1880                                     AMT_ACT_STATUS_FWD_NEW,
1881                                     v6);
1882                 /* EXCLUDE (X+(A-Y) */
1883                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_SUB_REV,
1884                                     AMT_FILTER_D_FWD,
1885                                     AMT_ACT_STATUS_FWD_NEW,
1886                                     v6);
1887                 /* EXCLUDE (, Y) */
1888                 amt_lookup_act_srcs(tunnel, gnode, grec, AMT_OPS_UNI,
1889                                     AMT_FILTER_D_FWD,
1890                                     AMT_ACT_STATUS_D_FWD_NEW,
1891                                     v6);
1892         }
1893 }
1894
1895 /* RFC 3376
1896  * 7.3.2. In the Presence of Older Version Group Members
1897  *
1898  * When Group Compatibility Mode is IGMPv2, a router internally
1899  * translates the following IGMPv2 messages for that group to their
1900  * IGMPv3 equivalents:
1901  *
1902  * IGMPv2 Message                IGMPv3 Equivalent
1903  * --------------                -----------------
1904  * Report                        IS_EX( {} )
1905  * Leave                         TO_IN( {} )
1906  */
1907 static void amt_igmpv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1908                                       struct amt_tunnel_list *tunnel)
1909 {
1910         struct igmphdr *ih = igmp_hdr(skb);
1911         struct iphdr *iph = ip_hdr(skb);
1912         struct amt_group_node *gnode;
1913         union amt_addr group, host;
1914
1915         memset(&group, 0, sizeof(union amt_addr));
1916         group.ip4 = ih->group;
1917         memset(&host, 0, sizeof(union amt_addr));
1918         host.ip4 = iph->saddr;
1919
1920         gnode = amt_lookup_group(tunnel, &group, &host, false);
1921         if (!gnode) {
1922                 gnode = amt_add_group(amt, tunnel, &group, &host, false);
1923                 if (!IS_ERR(gnode)) {
1924                         gnode->filter_mode = MCAST_EXCLUDE;
1925                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
1926                                               msecs_to_jiffies(amt_gmi(amt))))
1927                                 dev_hold(amt->dev);
1928                 }
1929         }
1930 }
1931
1932 /* RFC 3376
1933  * 7.3.2. In the Presence of Older Version Group Members
1934  *
1935  * When Group Compatibility Mode is IGMPv2, a router internally
1936  * translates the following IGMPv2 messages for that group to their
1937  * IGMPv3 equivalents:
1938  *
1939  * IGMPv2 Message                IGMPv3 Equivalent
1940  * --------------                -----------------
1941  * Report                        IS_EX( {} )
1942  * Leave                         TO_IN( {} )
1943  */
1944 static void amt_igmpv2_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
1945                                      struct amt_tunnel_list *tunnel)
1946 {
1947         struct igmphdr *ih = igmp_hdr(skb);
1948         struct iphdr *iph = ip_hdr(skb);
1949         struct amt_group_node *gnode;
1950         union amt_addr group, host;
1951
1952         memset(&group, 0, sizeof(union amt_addr));
1953         group.ip4 = ih->group;
1954         memset(&host, 0, sizeof(union amt_addr));
1955         host.ip4 = iph->saddr;
1956
1957         gnode = amt_lookup_group(tunnel, &group, &host, false);
1958         if (gnode)
1959                 amt_del_group(amt, gnode);
1960 }
1961
1962 static void amt_igmpv3_report_handler(struct amt_dev *amt, struct sk_buff *skb,
1963                                       struct amt_tunnel_list *tunnel)
1964 {
1965         struct igmpv3_report *ihrv3 = igmpv3_report_hdr(skb);
1966         int len = skb_transport_offset(skb) + sizeof(*ihrv3);
1967         void *zero_grec = (void *)&igmpv3_zero_grec;
1968         struct iphdr *iph = ip_hdr(skb);
1969         struct amt_group_node *gnode;
1970         union amt_addr group, host;
1971         struct igmpv3_grec *grec;
1972         u16 nsrcs;
1973         int i;
1974
1975         for (i = 0; i < ntohs(ihrv3->ngrec); i++) {
1976                 len += sizeof(*grec);
1977                 if (!ip_mc_may_pull(skb, len))
1978                         break;
1979
1980                 grec = (void *)(skb->data + len - sizeof(*grec));
1981                 nsrcs = ntohs(grec->grec_nsrcs);
1982
1983                 len += nsrcs * sizeof(__be32);
1984                 if (!ip_mc_may_pull(skb, len))
1985                         break;
1986
1987                 memset(&group, 0, sizeof(union amt_addr));
1988                 group.ip4 = grec->grec_mca;
1989                 memset(&host, 0, sizeof(union amt_addr));
1990                 host.ip4 = iph->saddr;
1991                 gnode = amt_lookup_group(tunnel, &group, &host, false);
1992                 if (!gnode) {
1993                         gnode = amt_add_group(amt, tunnel, &group, &host,
1994                                               false);
1995                         if (IS_ERR(gnode))
1996                                 continue;
1997                 }
1998
1999                 amt_add_srcs(amt, tunnel, gnode, grec, false);
2000                 switch (grec->grec_type) {
2001                 case IGMPV3_MODE_IS_INCLUDE:
2002                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2003                                                 zero_grec, false);
2004                         break;
2005                 case IGMPV3_MODE_IS_EXCLUDE:
2006                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2007                                                 zero_grec, false);
2008                         break;
2009                 case IGMPV3_CHANGE_TO_INCLUDE:
2010                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2011                                                 zero_grec, false);
2012                         break;
2013                 case IGMPV3_CHANGE_TO_EXCLUDE:
2014                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2015                                                 zero_grec, false);
2016                         break;
2017                 case IGMPV3_ALLOW_NEW_SOURCES:
2018                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2019                                                 zero_grec, false);
2020                         break;
2021                 case IGMPV3_BLOCK_OLD_SOURCES:
2022                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2023                                                 zero_grec, false);
2024                         break;
2025                 default:
2026                         break;
2027                 }
2028                 amt_cleanup_srcs(amt, tunnel, gnode);
2029         }
2030 }
2031
2032 /* caller held tunnel->lock */
2033 static void amt_igmp_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2034                                     struct amt_tunnel_list *tunnel)
2035 {
2036         struct igmphdr *ih = igmp_hdr(skb);
2037
2038         switch (ih->type) {
2039         case IGMPV3_HOST_MEMBERSHIP_REPORT:
2040                 amt_igmpv3_report_handler(amt, skb, tunnel);
2041                 break;
2042         case IGMPV2_HOST_MEMBERSHIP_REPORT:
2043                 amt_igmpv2_report_handler(amt, skb, tunnel);
2044                 break;
2045         case IGMP_HOST_LEAVE_MESSAGE:
2046                 amt_igmpv2_leave_handler(amt, skb, tunnel);
2047                 break;
2048         default:
2049                 break;
2050         }
2051 }
2052
2053 #if IS_ENABLED(CONFIG_IPV6)
2054 /* RFC 3810
2055  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2056  *
2057  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2058  * using the MLDv2 protocol for that multicast address.  When Multicast
2059  * Address Compatibility Mode is MLDv1, a router internally translates
2060  * the following MLDv1 messages for that multicast address to their
2061  * MLDv2 equivalents:
2062  *
2063  * MLDv1 Message                 MLDv2 Equivalent
2064  * --------------                -----------------
2065  * Report                        IS_EX( {} )
2066  * Done                          TO_IN( {} )
2067  */
2068 static void amt_mldv1_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2069                                      struct amt_tunnel_list *tunnel)
2070 {
2071         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2072         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2073         struct amt_group_node *gnode;
2074         union amt_addr group, host;
2075
2076         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2077         memcpy(&host.ip6, &ip6h->saddr, sizeof(struct in6_addr));
2078
2079         gnode = amt_lookup_group(tunnel, &group, &host, true);
2080         if (!gnode) {
2081                 gnode = amt_add_group(amt, tunnel, &group, &host, true);
2082                 if (!IS_ERR(gnode)) {
2083                         gnode->filter_mode = MCAST_EXCLUDE;
2084                         if (!mod_delayed_work(amt_wq, &gnode->group_timer,
2085                                               msecs_to_jiffies(amt_gmi(amt))))
2086                                 dev_hold(amt->dev);
2087                 }
2088         }
2089 }
2090
2091 /* RFC 3810
2092  * 8.3.2. In the Presence of MLDv1 Multicast Address Listeners
2093  *
2094  * When Multicast Address Compatibility Mode is MLDv2, a router acts
2095  * using the MLDv2 protocol for that multicast address.  When Multicast
2096  * Address Compatibility Mode is MLDv1, a router internally translates
2097  * the following MLDv1 messages for that multicast address to their
2098  * MLDv2 equivalents:
2099  *
2100  * MLDv1 Message                 MLDv2 Equivalent
2101  * --------------                -----------------
2102  * Report                        IS_EX( {} )
2103  * Done                          TO_IN( {} )
2104  */
2105 static void amt_mldv1_leave_handler(struct amt_dev *amt, struct sk_buff *skb,
2106                                     struct amt_tunnel_list *tunnel)
2107 {
2108         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2109         struct iphdr *iph = ip_hdr(skb);
2110         struct amt_group_node *gnode;
2111         union amt_addr group, host;
2112
2113         memcpy(&group.ip6, &mld->mld_mca, sizeof(struct in6_addr));
2114         memset(&host, 0, sizeof(union amt_addr));
2115         host.ip4 = iph->saddr;
2116
2117         gnode = amt_lookup_group(tunnel, &group, &host, true);
2118         if (gnode) {
2119                 amt_del_group(amt, gnode);
2120                 return;
2121         }
2122 }
2123
2124 static void amt_mldv2_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2125                                      struct amt_tunnel_list *tunnel)
2126 {
2127         struct mld2_report *mld2r = (struct mld2_report *)icmp6_hdr(skb);
2128         int len = skb_transport_offset(skb) + sizeof(*mld2r);
2129         void *zero_grec = (void *)&mldv2_zero_grec;
2130         struct ipv6hdr *ip6h = ipv6_hdr(skb);
2131         struct amt_group_node *gnode;
2132         union amt_addr group, host;
2133         struct mld2_grec *grec;
2134         u16 nsrcs;
2135         int i;
2136
2137         for (i = 0; i < ntohs(mld2r->mld2r_ngrec); i++) {
2138                 len += sizeof(*grec);
2139                 if (!ipv6_mc_may_pull(skb, len))
2140                         break;
2141
2142                 grec = (void *)(skb->data + len - sizeof(*grec));
2143                 nsrcs = ntohs(grec->grec_nsrcs);
2144
2145                 len += nsrcs * sizeof(struct in6_addr);
2146                 if (!ipv6_mc_may_pull(skb, len))
2147                         break;
2148
2149                 memset(&group, 0, sizeof(union amt_addr));
2150                 group.ip6 = grec->grec_mca;
2151                 memset(&host, 0, sizeof(union amt_addr));
2152                 host.ip6 = ip6h->saddr;
2153                 gnode = amt_lookup_group(tunnel, &group, &host, true);
2154                 if (!gnode) {
2155                         gnode = amt_add_group(amt, tunnel, &group, &host,
2156                                               ETH_P_IPV6);
2157                         if (IS_ERR(gnode))
2158                                 continue;
2159                 }
2160
2161                 amt_add_srcs(amt, tunnel, gnode, grec, true);
2162                 switch (grec->grec_type) {
2163                 case MLD2_MODE_IS_INCLUDE:
2164                         amt_mcast_is_in_handler(amt, tunnel, gnode, grec,
2165                                                 zero_grec, true);
2166                         break;
2167                 case MLD2_MODE_IS_EXCLUDE:
2168                         amt_mcast_is_ex_handler(amt, tunnel, gnode, grec,
2169                                                 zero_grec, true);
2170                         break;
2171                 case MLD2_CHANGE_TO_INCLUDE:
2172                         amt_mcast_to_in_handler(amt, tunnel, gnode, grec,
2173                                                 zero_grec, true);
2174                         break;
2175                 case MLD2_CHANGE_TO_EXCLUDE:
2176                         amt_mcast_to_ex_handler(amt, tunnel, gnode, grec,
2177                                                 zero_grec, true);
2178                         break;
2179                 case MLD2_ALLOW_NEW_SOURCES:
2180                         amt_mcast_allow_handler(amt, tunnel, gnode, grec,
2181                                                 zero_grec, true);
2182                         break;
2183                 case MLD2_BLOCK_OLD_SOURCES:
2184                         amt_mcast_block_handler(amt, tunnel, gnode, grec,
2185                                                 zero_grec, true);
2186                         break;
2187                 default:
2188                         break;
2189                 }
2190                 amt_cleanup_srcs(amt, tunnel, gnode);
2191         }
2192 }
2193
2194 /* caller held tunnel->lock */
2195 static void amt_mld_report_handler(struct amt_dev *amt, struct sk_buff *skb,
2196                                    struct amt_tunnel_list *tunnel)
2197 {
2198         struct mld_msg *mld = (struct mld_msg *)icmp6_hdr(skb);
2199
2200         switch (mld->mld_type) {
2201         case ICMPV6_MGM_REPORT:
2202                 amt_mldv1_report_handler(amt, skb, tunnel);
2203                 break;
2204         case ICMPV6_MLD2_REPORT:
2205                 amt_mldv2_report_handler(amt, skb, tunnel);
2206                 break;
2207         case ICMPV6_MGM_REDUCTION:
2208                 amt_mldv1_leave_handler(amt, skb, tunnel);
2209                 break;
2210         default:
2211                 break;
2212         }
2213 }
2214 #endif
2215
2216 static bool amt_advertisement_handler(struct amt_dev *amt, struct sk_buff *skb)
2217 {
2218         struct amt_header_advertisement *amta;
2219         int hdr_size;
2220
2221         hdr_size = sizeof(*amta) - sizeof(struct amt_header);
2222
2223         if (!pskb_may_pull(skb, hdr_size))
2224                 return true;
2225
2226         amta = (struct amt_header_advertisement *)(udp_hdr(skb) + 1);
2227         if (!amta->ip4)
2228                 return true;
2229
2230         if (amta->reserved || amta->version)
2231                 return true;
2232
2233         if (ipv4_is_loopback(amta->ip4) || ipv4_is_multicast(amta->ip4) ||
2234             ipv4_is_zeronet(amta->ip4))
2235                 return true;
2236
2237         amt->remote_ip = amta->ip4;
2238         netdev_dbg(amt->dev, "advertised remote ip = %pI4\n", &amt->remote_ip);
2239         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2240
2241         amt_update_gw_status(amt, AMT_STATUS_RECEIVED_ADVERTISEMENT, true);
2242         return false;
2243 }
2244
2245 static bool amt_multicast_data_handler(struct amt_dev *amt, struct sk_buff *skb)
2246 {
2247         struct amt_header_mcast_data *amtmd;
2248         int hdr_size, len, err;
2249         struct ethhdr *eth;
2250         struct iphdr *iph;
2251
2252         amtmd = (struct amt_header_mcast_data *)(udp_hdr(skb) + 1);
2253         if (amtmd->reserved || amtmd->version)
2254                 return true;
2255
2256         hdr_size = sizeof(*amtmd) + sizeof(struct udphdr);
2257         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_IP), false))
2258                 return true;
2259         skb_reset_network_header(skb);
2260         skb_push(skb, sizeof(*eth));
2261         skb_reset_mac_header(skb);
2262         skb_pull(skb, sizeof(*eth));
2263         eth = eth_hdr(skb);
2264         iph = ip_hdr(skb);
2265         if (iph->version == 4) {
2266                 if (!ipv4_is_multicast(iph->daddr))
2267                         return true;
2268                 skb->protocol = htons(ETH_P_IP);
2269                 eth->h_proto = htons(ETH_P_IP);
2270                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2271 #if IS_ENABLED(CONFIG_IPV6)
2272         } else if (iph->version == 6) {
2273                 struct ipv6hdr *ip6h;
2274
2275                 ip6h = ipv6_hdr(skb);
2276                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2277                         return true;
2278                 skb->protocol = htons(ETH_P_IPV6);
2279                 eth->h_proto = htons(ETH_P_IPV6);
2280                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2281 #endif
2282         } else {
2283                 return true;
2284         }
2285
2286         skb->pkt_type = PACKET_MULTICAST;
2287         skb->ip_summed = CHECKSUM_NONE;
2288         len = skb->len;
2289         err = gro_cells_receive(&amt->gro_cells, skb);
2290         if (likely(err == NET_RX_SUCCESS))
2291                 dev_sw_netstats_rx_add(amt->dev, len);
2292         else
2293                 amt->dev->stats.rx_dropped++;
2294
2295         return false;
2296 }
2297
2298 static bool amt_membership_query_handler(struct amt_dev *amt,
2299                                          struct sk_buff *skb)
2300 {
2301         struct amt_header_membership_query *amtmq;
2302         struct igmpv3_query *ihv3;
2303         struct ethhdr *eth, *oeth;
2304         struct iphdr *iph;
2305         int hdr_size, len;
2306
2307         hdr_size = sizeof(*amtmq) - sizeof(struct amt_header);
2308
2309         if (!pskb_may_pull(skb, hdr_size))
2310                 return true;
2311
2312         amtmq = (struct amt_header_membership_query *)(udp_hdr(skb) + 1);
2313         if (amtmq->reserved || amtmq->version)
2314                 return true;
2315
2316         hdr_size = sizeof(*amtmq) + sizeof(struct udphdr) - sizeof(*eth);
2317         if (iptunnel_pull_header(skb, hdr_size, htons(ETH_P_TEB), false))
2318                 return true;
2319         oeth = eth_hdr(skb);
2320         skb_reset_mac_header(skb);
2321         skb_pull(skb, sizeof(*eth));
2322         skb_reset_network_header(skb);
2323         eth = eth_hdr(skb);
2324         iph = ip_hdr(skb);
2325         if (iph->version == 4) {
2326                 if (!ipv4_is_multicast(iph->daddr))
2327                         return true;
2328                 if (!pskb_may_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS +
2329                                    sizeof(*ihv3)))
2330                         return true;
2331
2332                 ihv3 = skb_pull(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2333                 skb_reset_transport_header(skb);
2334                 skb_push(skb, sizeof(*iph) + AMT_IPHDR_OPTS);
2335                 spin_lock_bh(&amt->lock);
2336                 amt->ready4 = true;
2337                 amt->mac = amtmq->response_mac;
2338                 amt->req_cnt = 0;
2339                 amt->qi = ihv3->qqic;
2340                 spin_unlock_bh(&amt->lock);
2341                 skb->protocol = htons(ETH_P_IP);
2342                 eth->h_proto = htons(ETH_P_IP);
2343                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2344 #if IS_ENABLED(CONFIG_IPV6)
2345         } else if (iph->version == 6) {
2346                 struct ipv6hdr *ip6h = ipv6_hdr(skb);
2347                 struct mld2_query *mld2q;
2348
2349                 if (!ipv6_addr_is_multicast(&ip6h->daddr))
2350                         return true;
2351                 if (!pskb_may_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS +
2352                                    sizeof(*mld2q)))
2353                         return true;
2354
2355                 mld2q = skb_pull(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2356                 skb_reset_transport_header(skb);
2357                 skb_push(skb, sizeof(*ip6h) + AMT_IP6HDR_OPTS);
2358                 spin_lock_bh(&amt->lock);
2359                 amt->ready6 = true;
2360                 amt->mac = amtmq->response_mac;
2361                 amt->req_cnt = 0;
2362                 amt->qi = mld2q->mld2q_qqic;
2363                 spin_unlock_bh(&amt->lock);
2364                 skb->protocol = htons(ETH_P_IPV6);
2365                 eth->h_proto = htons(ETH_P_IPV6);
2366                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2367 #endif
2368         } else {
2369                 return true;
2370         }
2371
2372         ether_addr_copy(eth->h_source, oeth->h_source);
2373         skb->pkt_type = PACKET_MULTICAST;
2374         skb->ip_summed = CHECKSUM_NONE;
2375         len = skb->len;
2376         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2377                 amt_update_gw_status(amt, AMT_STATUS_RECEIVED_QUERY, true);
2378                 dev_sw_netstats_rx_add(amt->dev, len);
2379         } else {
2380                 amt->dev->stats.rx_dropped++;
2381         }
2382
2383         return false;
2384 }
2385
2386 static bool amt_update_handler(struct amt_dev *amt, struct sk_buff *skb)
2387 {
2388         struct amt_header_membership_update *amtmu;
2389         struct amt_tunnel_list *tunnel;
2390         struct udphdr *udph;
2391         struct ethhdr *eth;
2392         struct iphdr *iph;
2393         int len;
2394
2395         iph = ip_hdr(skb);
2396         udph = udp_hdr(skb);
2397
2398         if (__iptunnel_pull_header(skb, sizeof(*udph), skb->protocol,
2399                                    false, false))
2400                 return true;
2401
2402         amtmu = (struct amt_header_membership_update *)skb->data;
2403         if (amtmu->reserved || amtmu->version)
2404                 return true;
2405
2406         skb_pull(skb, sizeof(*amtmu));
2407         skb_reset_network_header(skb);
2408
2409         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list) {
2410                 if (tunnel->ip4 == iph->saddr) {
2411                         if ((amtmu->nonce == tunnel->nonce &&
2412                              amtmu->response_mac == tunnel->mac)) {
2413                                 mod_delayed_work(amt_wq, &tunnel->gc_wq,
2414                                                  msecs_to_jiffies(amt_gmi(amt))
2415                                                                   * 3);
2416                                 goto report;
2417                         } else {
2418                                 netdev_dbg(amt->dev, "Invalid MAC\n");
2419                                 return true;
2420                         }
2421                 }
2422         }
2423
2424         return false;
2425
2426 report:
2427         iph = ip_hdr(skb);
2428         if (iph->version == 4) {
2429                 if (ip_mc_check_igmp(skb)) {
2430                         netdev_dbg(amt->dev, "Invalid IGMP\n");
2431                         return true;
2432                 }
2433
2434                 spin_lock_bh(&tunnel->lock);
2435                 amt_igmp_report_handler(amt, skb, tunnel);
2436                 spin_unlock_bh(&tunnel->lock);
2437
2438                 skb_push(skb, sizeof(struct ethhdr));
2439                 skb_reset_mac_header(skb);
2440                 eth = eth_hdr(skb);
2441                 skb->protocol = htons(ETH_P_IP);
2442                 eth->h_proto = htons(ETH_P_IP);
2443                 ip_eth_mc_map(iph->daddr, eth->h_dest);
2444 #if IS_ENABLED(CONFIG_IPV6)
2445         } else if (iph->version == 6) {
2446                 struct ipv6hdr *ip6h = ipv6_hdr(skb);
2447
2448                 if (ipv6_mc_check_mld(skb)) {
2449                         netdev_dbg(amt->dev, "Invalid MLD\n");
2450                         return true;
2451                 }
2452
2453                 spin_lock_bh(&tunnel->lock);
2454                 amt_mld_report_handler(amt, skb, tunnel);
2455                 spin_unlock_bh(&tunnel->lock);
2456
2457                 skb_push(skb, sizeof(struct ethhdr));
2458                 skb_reset_mac_header(skb);
2459                 eth = eth_hdr(skb);
2460                 skb->protocol = htons(ETH_P_IPV6);
2461                 eth->h_proto = htons(ETH_P_IPV6);
2462                 ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest);
2463 #endif
2464         } else {
2465                 netdev_dbg(amt->dev, "Unsupported Protocol\n");
2466                 return true;
2467         }
2468
2469         skb_pull(skb, sizeof(struct ethhdr));
2470         skb->pkt_type = PACKET_MULTICAST;
2471         skb->ip_summed = CHECKSUM_NONE;
2472         len = skb->len;
2473         if (__netif_rx(skb) == NET_RX_SUCCESS) {
2474                 amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_UPDATE,
2475                                         true);
2476                 dev_sw_netstats_rx_add(amt->dev, len);
2477         } else {
2478                 amt->dev->stats.rx_dropped++;
2479         }
2480
2481         return false;
2482 }
2483
2484 static void amt_send_advertisement(struct amt_dev *amt, __be32 nonce,
2485                                    __be32 daddr, __be16 dport)
2486 {
2487         struct amt_header_advertisement *amta;
2488         int hlen, tlen, offset;
2489         struct socket *sock;
2490         struct udphdr *udph;
2491         struct sk_buff *skb;
2492         struct iphdr *iph;
2493         struct rtable *rt;
2494         struct flowi4 fl4;
2495         u32 len;
2496         int err;
2497
2498         rcu_read_lock();
2499         sock = rcu_dereference(amt->sock);
2500         if (!sock)
2501                 goto out;
2502
2503         if (!netif_running(amt->stream_dev) || !netif_running(amt->dev))
2504                 goto out;
2505
2506         rt = ip_route_output_ports(amt->net, &fl4, sock->sk,
2507                                    daddr, amt->local_ip,
2508                                    dport, amt->relay_port,
2509                                    IPPROTO_UDP, 0,
2510                                    amt->stream_dev->ifindex);
2511         if (IS_ERR(rt)) {
2512                 amt->dev->stats.tx_errors++;
2513                 goto out;
2514         }
2515
2516         hlen = LL_RESERVED_SPACE(amt->dev);
2517         tlen = amt->dev->needed_tailroom;
2518         len = hlen + tlen + sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2519         skb = netdev_alloc_skb_ip_align(amt->dev, len);
2520         if (!skb) {
2521                 ip_rt_put(rt);
2522                 amt->dev->stats.tx_errors++;
2523                 goto out;
2524         }
2525
2526         skb->priority = TC_PRIO_CONTROL;
2527         skb_dst_set(skb, &rt->dst);
2528
2529         len = sizeof(*iph) + sizeof(*udph) + sizeof(*amta);
2530         skb_reset_network_header(skb);
2531         skb_put(skb, len);
2532         amta = skb_pull(skb, sizeof(*iph) + sizeof(*udph));
2533         amta->version   = 0;
2534         amta->type      = AMT_MSG_ADVERTISEMENT;
2535         amta->reserved  = 0;
2536         amta->nonce     = nonce;
2537         amta->ip4       = amt->local_ip;
2538         skb_push(skb, sizeof(*udph));
2539         skb_reset_transport_header(skb);
2540         udph            = udp_hdr(skb);
2541         udph->source    = amt->relay_port;
2542         udph->dest      = dport;
2543         udph->len       = htons(sizeof(*amta) + sizeof(*udph));
2544         udph->check     = 0;
2545         offset = skb_transport_offset(skb);
2546         skb->csum = skb_checksum(skb, offset, skb->len - offset, 0);
2547         udph->check = csum_tcpudp_magic(amt->local_ip, daddr,
2548                                         sizeof(*udph) + sizeof(*amta),
2549                                         IPPROTO_UDP, skb->csum);
2550
2551         skb_push(skb, sizeof(*iph));
2552         iph             = ip_hdr(skb);
2553         iph->version    = 4;
2554         iph->ihl        = (sizeof(struct iphdr)) >> 2;
2555         iph->tos        = AMT_TOS;
2556         iph->frag_off   = 0;
2557         iph->ttl        = ip4_dst_hoplimit(&rt->dst);
2558         iph->daddr      = daddr;
2559         iph->saddr      = amt->local_ip;
2560         iph->protocol   = IPPROTO_UDP;
2561         iph->tot_len    = htons(len);
2562
2563         skb->ip_summed = CHECKSUM_NONE;
2564         ip_select_ident(amt->net, skb, NULL);
2565         ip_send_check(iph);
2566         err = ip_local_out(amt->net, sock->sk, skb);
2567         if (unlikely(net_xmit_eval(err)))
2568                 amt->dev->stats.tx_errors++;
2569
2570 out:
2571         rcu_read_unlock();
2572 }
2573
2574 static bool amt_discovery_handler(struct amt_dev *amt, struct sk_buff *skb)
2575 {
2576         struct amt_header_discovery *amtd;
2577         struct udphdr *udph;
2578         struct iphdr *iph;
2579
2580         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtd)))
2581                 return true;
2582
2583         iph = ip_hdr(skb);
2584         udph = udp_hdr(skb);
2585         amtd = (struct amt_header_discovery *)(udp_hdr(skb) + 1);
2586
2587         if (amtd->reserved || amtd->version)
2588                 return true;
2589
2590         amt_send_advertisement(amt, amtd->nonce, iph->saddr, udph->source);
2591
2592         return false;
2593 }
2594
2595 static bool amt_request_handler(struct amt_dev *amt, struct sk_buff *skb)
2596 {
2597         struct amt_header_request *amtrh;
2598         struct amt_tunnel_list *tunnel;
2599         unsigned long long key;
2600         struct udphdr *udph;
2601         struct iphdr *iph;
2602         u64 mac;
2603         int i;
2604
2605         if (!pskb_may_pull(skb, sizeof(*udph) + sizeof(*amtrh)))
2606                 return true;
2607
2608         iph = ip_hdr(skb);
2609         udph = udp_hdr(skb);
2610         amtrh = (struct amt_header_request *)(udp_hdr(skb) + 1);
2611
2612         if (amtrh->reserved1 || amtrh->reserved2 || amtrh->version)
2613                 return true;
2614
2615         list_for_each_entry_rcu(tunnel, &amt->tunnel_list, list)
2616                 if (tunnel->ip4 == iph->saddr)
2617                         goto send;
2618
2619         if (amt->nr_tunnels >= amt->max_tunnels) {
2620                 icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
2621                 return true;
2622         }
2623
2624         tunnel = kzalloc(sizeof(*tunnel) +
2625                          (sizeof(struct hlist_head) * amt->hash_buckets),
2626                          GFP_ATOMIC);
2627         if (!tunnel)
2628                 return true;
2629
2630         tunnel->source_port = udph->source;
2631         tunnel->ip4 = iph->saddr;
2632
2633         memcpy(&key, &tunnel->key, sizeof(unsigned long long));
2634         tunnel->amt = amt;
2635         spin_lock_init(&tunnel->lock);
2636         for (i = 0; i < amt->hash_buckets; i++)
2637                 INIT_HLIST_HEAD(&tunnel->groups[i]);
2638
2639         INIT_DELAYED_WORK(&tunnel->gc_wq, amt_tunnel_expire);
2640
2641         spin_lock_bh(&amt->lock);
2642         list_add_tail_rcu(&tunnel->list, &amt->tunnel_list);
2643         tunnel->key = amt->key;
2644         amt_update_relay_status(tunnel, AMT_STATUS_RECEIVED_REQUEST, true);
2645         amt->nr_tunnels++;
2646         mod_delayed_work(amt_wq, &tunnel->gc_wq,
2647                          msecs_to_jiffies(amt_gmi(amt)));
2648         spin_unlock_bh(&amt->lock);
2649
2650 send:
2651         tunnel->nonce = amtrh->nonce;
2652         mac = siphash_3u32((__force u32)tunnel->ip4,
2653                            (__force u32)tunnel->source_port,
2654                            (__force u32)tunnel->nonce,
2655                            &tunnel->key);
2656         tunnel->mac = mac >> 16;
2657
2658         if (!netif_running(amt->dev) || !netif_running(amt->stream_dev))
2659                 return true;
2660
2661         if (!amtrh->p)
2662                 amt_send_igmp_gq(amt, tunnel);
2663         else
2664                 amt_send_mld_gq(amt, tunnel);
2665
2666         return false;
2667 }
2668
2669 static int amt_rcv(struct sock *sk, struct sk_buff *skb)
2670 {
2671         struct amt_dev *amt;
2672         struct iphdr *iph;
2673         int type;
2674         bool err;
2675
2676         rcu_read_lock_bh();
2677         amt = rcu_dereference_sk_user_data(sk);
2678         if (!amt) {
2679                 err = true;
2680                 goto out;
2681         }
2682
2683         skb->dev = amt->dev;
2684         iph = ip_hdr(skb);
2685         type = amt_parse_type(skb);
2686         if (type == -1) {
2687                 err = true;
2688                 goto drop;
2689         }
2690
2691         if (amt->mode == AMT_MODE_GATEWAY) {
2692                 switch (type) {
2693                 case AMT_MSG_ADVERTISEMENT:
2694                         if (iph->saddr != amt->discovery_ip) {
2695                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2696                                 err = true;
2697                                 goto drop;
2698                         }
2699                         if (amt_advertisement_handler(amt, skb))
2700                                 amt->dev->stats.rx_dropped++;
2701                         goto out;
2702                 case AMT_MSG_MULTICAST_DATA:
2703                         if (iph->saddr != amt->remote_ip) {
2704                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2705                                 err = true;
2706                                 goto drop;
2707                         }
2708                         err = amt_multicast_data_handler(amt, skb);
2709                         if (err)
2710                                 goto drop;
2711                         else
2712                                 goto out;
2713                 case AMT_MSG_MEMBERSHIP_QUERY:
2714                         if (iph->saddr != amt->remote_ip) {
2715                                 netdev_dbg(amt->dev, "Invalid Relay IP\n");
2716                                 err = true;
2717                                 goto drop;
2718                         }
2719                         err = amt_membership_query_handler(amt, skb);
2720                         if (err)
2721                                 goto drop;
2722                         else
2723                                 goto out;
2724                 default:
2725                         err = true;
2726                         netdev_dbg(amt->dev, "Invalid type of Gateway\n");
2727                         break;
2728                 }
2729         } else {
2730                 switch (type) {
2731                 case AMT_MSG_DISCOVERY:
2732                         err = amt_discovery_handler(amt, skb);
2733                         break;
2734                 case AMT_MSG_REQUEST:
2735                         err = amt_request_handler(amt, skb);
2736                         break;
2737                 case AMT_MSG_MEMBERSHIP_UPDATE:
2738                         err = amt_update_handler(amt, skb);
2739                         if (err)
2740                                 goto drop;
2741                         else
2742                                 goto out;
2743                 default:
2744                         err = true;
2745                         netdev_dbg(amt->dev, "Invalid type of relay\n");
2746                         break;
2747                 }
2748         }
2749 drop:
2750         if (err) {
2751                 amt->dev->stats.rx_dropped++;
2752                 kfree_skb(skb);
2753         } else {
2754                 consume_skb(skb);
2755         }
2756 out:
2757         rcu_read_unlock_bh();
2758         return 0;
2759 }
2760
2761 static int amt_err_lookup(struct sock *sk, struct sk_buff *skb)
2762 {
2763         struct amt_dev *amt;
2764         int type;
2765
2766         rcu_read_lock_bh();
2767         amt = rcu_dereference_sk_user_data(sk);
2768         if (!amt)
2769                 goto out;
2770
2771         if (amt->mode != AMT_MODE_GATEWAY)
2772                 goto drop;
2773
2774         type = amt_parse_type(skb);
2775         if (type == -1)
2776                 goto drop;
2777
2778         netdev_dbg(amt->dev, "Received IGMP Unreachable of %s\n",
2779                    type_str[type]);
2780         switch (type) {
2781         case AMT_MSG_DISCOVERY:
2782                 break;
2783         case AMT_MSG_REQUEST:
2784         case AMT_MSG_MEMBERSHIP_UPDATE:
2785                 if (amt->status >= AMT_STATUS_RECEIVED_ADVERTISEMENT)
2786                         mod_delayed_work(amt_wq, &amt->req_wq, 0);
2787                 break;
2788         default:
2789                 goto drop;
2790         }
2791 out:
2792         rcu_read_unlock_bh();
2793         return 0;
2794 drop:
2795         rcu_read_unlock_bh();
2796         amt->dev->stats.rx_dropped++;
2797         return 0;
2798 }
2799
2800 static struct socket *amt_create_sock(struct net *net, __be16 port)
2801 {
2802         struct udp_port_cfg udp_conf;
2803         struct socket *sock;
2804         int err;
2805
2806         memset(&udp_conf, 0, sizeof(udp_conf));
2807         udp_conf.family = AF_INET;
2808         udp_conf.local_ip.s_addr = htonl(INADDR_ANY);
2809
2810         udp_conf.local_udp_port = port;
2811
2812         err = udp_sock_create(net, &udp_conf, &sock);
2813         if (err < 0)
2814                 return ERR_PTR(err);
2815
2816         return sock;
2817 }
2818
2819 static int amt_socket_create(struct amt_dev *amt)
2820 {
2821         struct udp_tunnel_sock_cfg tunnel_cfg;
2822         struct socket *sock;
2823
2824         sock = amt_create_sock(amt->net, amt->relay_port);
2825         if (IS_ERR(sock))
2826                 return PTR_ERR(sock);
2827
2828         /* Mark socket as an encapsulation socket */
2829         memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
2830         tunnel_cfg.sk_user_data = amt;
2831         tunnel_cfg.encap_type = 1;
2832         tunnel_cfg.encap_rcv = amt_rcv;
2833         tunnel_cfg.encap_err_lookup = amt_err_lookup;
2834         tunnel_cfg.encap_destroy = NULL;
2835         setup_udp_tunnel_sock(amt->net, sock, &tunnel_cfg);
2836
2837         rcu_assign_pointer(amt->sock, sock);
2838         return 0;
2839 }
2840
2841 static int amt_dev_open(struct net_device *dev)
2842 {
2843         struct amt_dev *amt = netdev_priv(dev);
2844         int err;
2845
2846         amt->ready4 = false;
2847         amt->ready6 = false;
2848
2849         err = amt_socket_create(amt);
2850         if (err)
2851                 return err;
2852
2853         amt->req_cnt = 0;
2854         amt->remote_ip = 0;
2855         get_random_bytes(&amt->key, sizeof(siphash_key_t));
2856
2857         amt->status = AMT_STATUS_INIT;
2858         if (amt->mode == AMT_MODE_GATEWAY) {
2859                 mod_delayed_work(amt_wq, &amt->discovery_wq, 0);
2860                 mod_delayed_work(amt_wq, &amt->req_wq, 0);
2861         } else if (amt->mode == AMT_MODE_RELAY) {
2862                 mod_delayed_work(amt_wq, &amt->secret_wq,
2863                                  msecs_to_jiffies(AMT_SECRET_TIMEOUT));
2864         }
2865         return err;
2866 }
2867
2868 static int amt_dev_stop(struct net_device *dev)
2869 {
2870         struct amt_dev *amt = netdev_priv(dev);
2871         struct amt_tunnel_list *tunnel, *tmp;
2872         struct socket *sock;
2873
2874         cancel_delayed_work_sync(&amt->req_wq);
2875         cancel_delayed_work_sync(&amt->discovery_wq);
2876         cancel_delayed_work_sync(&amt->secret_wq);
2877
2878         /* shutdown */
2879         sock = rtnl_dereference(amt->sock);
2880         RCU_INIT_POINTER(amt->sock, NULL);
2881         synchronize_net();
2882         if (sock)
2883                 udp_tunnel_sock_release(sock);
2884
2885         amt->ready4 = false;
2886         amt->ready6 = false;
2887         amt->req_cnt = 0;
2888         amt->remote_ip = 0;
2889
2890         list_for_each_entry_safe(tunnel, tmp, &amt->tunnel_list, list) {
2891                 list_del_rcu(&tunnel->list);
2892                 amt->nr_tunnels--;
2893                 cancel_delayed_work_sync(&tunnel->gc_wq);
2894                 amt_clear_groups(tunnel);
2895                 kfree_rcu(tunnel, rcu);
2896         }
2897
2898         return 0;
2899 }
2900
2901 static const struct device_type amt_type = {
2902         .name = "amt",
2903 };
2904
2905 static int amt_dev_init(struct net_device *dev)
2906 {
2907         struct amt_dev *amt = netdev_priv(dev);
2908         int err;
2909
2910         amt->dev = dev;
2911         dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
2912         if (!dev->tstats)
2913                 return -ENOMEM;
2914
2915         err = gro_cells_init(&amt->gro_cells, dev);
2916         if (err) {
2917                 free_percpu(dev->tstats);
2918                 return err;
2919         }
2920
2921         return 0;
2922 }
2923
2924 static void amt_dev_uninit(struct net_device *dev)
2925 {
2926         struct amt_dev *amt = netdev_priv(dev);
2927
2928         gro_cells_destroy(&amt->gro_cells);
2929         free_percpu(dev->tstats);
2930 }
2931
2932 static const struct net_device_ops amt_netdev_ops = {
2933         .ndo_init               = amt_dev_init,
2934         .ndo_uninit             = amt_dev_uninit,
2935         .ndo_open               = amt_dev_open,
2936         .ndo_stop               = amt_dev_stop,
2937         .ndo_start_xmit         = amt_dev_xmit,
2938         .ndo_get_stats64        = dev_get_tstats64,
2939 };
2940
/*
 * rtnl setup callback for AMT links: install the AMT netdev ops and
 * device type, set the device fields and feature flags, assign a random
 * MAC, and finally run ether_setup().
 */
static void amt_link_setup(struct net_device *dev)
{
        dev->netdev_ops         = &amt_netdev_ops;
        dev->needs_free_netdev  = true;
        SET_NETDEV_DEVTYPE(dev, &amt_type);
        dev->min_mtu            = ETH_MIN_MTU;
        dev->max_mtu            = ETH_MAX_MTU;
        dev->type               = ARPHRD_NONE;
        dev->flags              = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST;
        dev->hard_header_len    = 0;
        dev->addr_len           = 0;
        dev->priv_flags         |= IFF_NO_QUEUE;
        dev->features           |= NETIF_F_LLTX;
        dev->features           |= NETIF_F_GSO_SOFTWARE;
        dev->features           |= NETIF_F_NETNS_LOCAL;
        dev->hw_features        |= NETIF_F_SG | NETIF_F_HW_CSUM;
        dev->hw_features        |= NETIF_F_FRAGLIST | NETIF_F_RXCSUM;
        dev->hw_features        |= NETIF_F_GSO_SOFTWARE;
        eth_hw_addr_random(dev);
        eth_zero_addr(dev->broadcast);
        /* NOTE(review): ether_setup() runs last.  If it resets type, flags,
         * header length, address length, MTU bounds or the broadcast
         * address, the assignments above to those fields are overridden.
         * Ethernet setup helpers typically reset them - confirm which
         * values are actually intended.
         */
        ether_setup(dev);
}
2963
2964 static const struct nla_policy amt_policy[IFLA_AMT_MAX + 1] = {
2965         [IFLA_AMT_MODE]         = { .type = NLA_U32 },
2966         [IFLA_AMT_RELAY_PORT]   = { .type = NLA_U16 },
2967         [IFLA_AMT_GATEWAY_PORT] = { .type = NLA_U16 },
2968         [IFLA_AMT_LINK]         = { .type = NLA_U32 },
2969         [IFLA_AMT_LOCAL_IP]     = { .len = sizeof_field(struct iphdr, daddr) },
2970         [IFLA_AMT_REMOTE_IP]    = { .len = sizeof_field(struct iphdr, daddr) },
2971         [IFLA_AMT_DISCOVERY_IP] = { .len = sizeof_field(struct iphdr, daddr) },
2972         [IFLA_AMT_MAX_TUNNELS]  = { .type = NLA_U32 },
2973 };
2974
2975 static int amt_validate(struct nlattr *tb[], struct nlattr *data[],
2976                         struct netlink_ext_ack *extack)
2977 {
2978         if (!data)
2979                 return -EINVAL;
2980
2981         if (!data[IFLA_AMT_LINK]) {
2982                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LINK],
2983                                     "Link attribute is required");
2984                 return -EINVAL;
2985         }
2986
2987         if (!data[IFLA_AMT_MODE]) {
2988                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
2989                                     "Mode attribute is required");
2990                 return -EINVAL;
2991         }
2992
2993         if (nla_get_u32(data[IFLA_AMT_MODE]) > AMT_MODE_MAX) {
2994                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_MODE],
2995                                     "Mode attribute is not valid");
2996                 return -EINVAL;
2997         }
2998
2999         if (!data[IFLA_AMT_LOCAL_IP]) {
3000                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_DISCOVERY_IP],
3001                                     "Local attribute is required");
3002                 return -EINVAL;
3003         }
3004
3005         if (!data[IFLA_AMT_DISCOVERY_IP] &&
3006             nla_get_u32(data[IFLA_AMT_MODE]) == AMT_MODE_GATEWAY) {
3007                 NL_SET_ERR_MSG_ATTR(extack, data[IFLA_AMT_LOCAL_IP],
3008                                     "Discovery attribute is required");
3009                 return -EINVAL;
3010         }
3011
3012         return 0;
3013 }
3014
3015 static int amt_newlink(struct net *net, struct net_device *dev,
3016                        struct nlattr *tb[], struct nlattr *data[],
3017                        struct netlink_ext_ack *extack)
3018 {
3019         struct amt_dev *amt = netdev_priv(dev);
3020         int err = -EINVAL;
3021
3022         amt->net = net;
3023         amt->mode = nla_get_u32(data[IFLA_AMT_MODE]);
3024
3025         if (data[IFLA_AMT_MAX_TUNNELS] &&
3026             nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]))
3027                 amt->max_tunnels = nla_get_u32(data[IFLA_AMT_MAX_TUNNELS]);
3028         else
3029                 amt->max_tunnels = AMT_MAX_TUNNELS;
3030
3031         spin_lock_init(&amt->lock);
3032         amt->max_groups = AMT_MAX_GROUP;
3033         amt->max_sources = AMT_MAX_SOURCE;
3034         amt->hash_buckets = AMT_HSIZE;
3035         amt->nr_tunnels = 0;
3036         get_random_bytes(&amt->hash_seed, sizeof(amt->hash_seed));
3037         amt->stream_dev = dev_get_by_index(net,
3038                                            nla_get_u32(data[IFLA_AMT_LINK]));
3039         if (!amt->stream_dev) {
3040                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3041                                     "Can't find stream device");
3042                 return -ENODEV;
3043         }
3044
3045         if (amt->stream_dev->type != ARPHRD_ETHER) {
3046                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LINK],
3047                                     "Invalid stream device type");
3048                 goto err;
3049         }
3050
3051         amt->local_ip = nla_get_in_addr(data[IFLA_AMT_LOCAL_IP]);
3052         if (ipv4_is_loopback(amt->local_ip) ||
3053             ipv4_is_zeronet(amt->local_ip) ||
3054             ipv4_is_multicast(amt->local_ip)) {
3055                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_LOCAL_IP],
3056                                     "Invalid Local address");
3057                 goto err;
3058         }
3059
3060         if (data[IFLA_AMT_RELAY_PORT])
3061                 amt->relay_port = nla_get_be16(data[IFLA_AMT_RELAY_PORT]);
3062         else
3063                 amt->relay_port = htons(IANA_AMT_UDP_PORT);
3064
3065         if (data[IFLA_AMT_GATEWAY_PORT])
3066                 amt->gw_port = nla_get_be16(data[IFLA_AMT_GATEWAY_PORT]);
3067         else
3068                 amt->gw_port = htons(IANA_AMT_UDP_PORT);
3069
3070         if (!amt->relay_port) {
3071                 NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3072                                     "relay port must not be 0");
3073                 goto err;
3074         }
3075         if (amt->mode == AMT_MODE_RELAY) {
3076                 amt->qrv = amt->net->ipv4.sysctl_igmp_qrv;
3077                 amt->qri = 10;
3078                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3079                                        AMT_RELAY_HLEN;
3080                 dev->mtu = amt->stream_dev->mtu - AMT_RELAY_HLEN;
3081                 dev->max_mtu = dev->mtu;
3082                 dev->min_mtu = ETH_MIN_MTU + AMT_RELAY_HLEN;
3083         } else {
3084                 if (!data[IFLA_AMT_DISCOVERY_IP]) {
3085                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3086                                             "discovery must be set in gateway mode");
3087                         goto err;
3088                 }
3089                 if (!amt->gw_port) {
3090                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3091                                             "gateway port must not be 0");
3092                         goto err;
3093                 }
3094                 amt->remote_ip = 0;
3095                 amt->discovery_ip = nla_get_in_addr(data[IFLA_AMT_DISCOVERY_IP]);
3096                 if (ipv4_is_loopback(amt->discovery_ip) ||
3097                     ipv4_is_zeronet(amt->discovery_ip) ||
3098                     ipv4_is_multicast(amt->discovery_ip)) {
3099                         NL_SET_ERR_MSG_ATTR(extack, tb[IFLA_AMT_DISCOVERY_IP],
3100                                             "discovery must be unicast");
3101                         goto err;
3102                 }
3103
3104                 dev->needed_headroom = amt->stream_dev->needed_headroom +
3105                                        AMT_GW_HLEN;
3106                 dev->mtu = amt->stream_dev->mtu - AMT_GW_HLEN;
3107                 dev->max_mtu = dev->mtu;
3108                 dev->min_mtu = ETH_MIN_MTU + AMT_GW_HLEN;
3109         }
3110         amt->qi = AMT_INIT_QUERY_INTERVAL;
3111
3112         err = register_netdevice(dev);
3113         if (err < 0) {
3114                 netdev_dbg(dev, "failed to register new netdev %d\n", err);
3115                 goto err;
3116         }
3117
3118         err = netdev_upper_dev_link(amt->stream_dev, dev, extack);
3119         if (err < 0) {
3120                 unregister_netdevice(dev);
3121                 goto err;
3122         }
3123
3124         INIT_DELAYED_WORK(&amt->discovery_wq, amt_discovery_work);
3125         INIT_DELAYED_WORK(&amt->req_wq, amt_req_work);
3126         INIT_DELAYED_WORK(&amt->secret_wq, amt_secret_work);
3127         INIT_LIST_HEAD(&amt->tunnel_list);
3128
3129         return 0;
3130 err:
3131         dev_put(amt->stream_dev);
3132         return err;
3133 }
3134
3135 static void amt_dellink(struct net_device *dev, struct list_head *head)
3136 {
3137         struct amt_dev *amt = netdev_priv(dev);
3138
3139         unregister_netdevice_queue(dev, head);
3140         netdev_upper_dev_unlink(amt->stream_dev, dev);
3141         dev_put(amt->stream_dev);
3142 }
3143
3144 static size_t amt_get_size(const struct net_device *dev)
3145 {
3146         return nla_total_size(sizeof(__u32)) + /* IFLA_AMT_MODE */
3147                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_RELAY_PORT */
3148                nla_total_size(sizeof(__u16)) + /* IFLA_AMT_GATEWAY_PORT */
3149                nla_total_size(sizeof(__u32)) + /* IFLA_AMT_LINK */
3150                nla_total_size(sizeof(__u32)) + /* IFLA_MAX_TUNNELS */
3151                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_DISCOVERY_IP */
3152                nla_total_size(sizeof(struct iphdr)) + /* IFLA_AMT_REMOTE_IP */
3153                nla_total_size(sizeof(struct iphdr)); /* IFLA_AMT_LOCAL_IP */
3154 }
3155
3156 static int amt_fill_info(struct sk_buff *skb, const struct net_device *dev)
3157 {
3158         struct amt_dev *amt = netdev_priv(dev);
3159
3160         if (nla_put_u32(skb, IFLA_AMT_MODE, amt->mode))
3161                 goto nla_put_failure;
3162         if (nla_put_be16(skb, IFLA_AMT_RELAY_PORT, amt->relay_port))
3163                 goto nla_put_failure;
3164         if (nla_put_be16(skb, IFLA_AMT_GATEWAY_PORT, amt->gw_port))
3165                 goto nla_put_failure;
3166         if (nla_put_u32(skb, IFLA_AMT_LINK, amt->stream_dev->ifindex))
3167                 goto nla_put_failure;
3168         if (nla_put_in_addr(skb, IFLA_AMT_LOCAL_IP, amt->local_ip))
3169                 goto nla_put_failure;
3170         if (nla_put_in_addr(skb, IFLA_AMT_DISCOVERY_IP, amt->discovery_ip))
3171                 goto nla_put_failure;
3172         if (amt->remote_ip)
3173                 if (nla_put_in_addr(skb, IFLA_AMT_REMOTE_IP, amt->remote_ip))
3174                         goto nla_put_failure;
3175         if (nla_put_u32(skb, IFLA_AMT_MAX_TUNNELS, amt->max_tunnels))
3176                 goto nla_put_failure;
3177
3178         return 0;
3179
3180 nla_put_failure:
3181         return -EMSGSIZE;
3182 }
3183
3184 static struct rtnl_link_ops amt_link_ops __read_mostly = {
3185         .kind           = "amt",
3186         .maxtype        = IFLA_AMT_MAX,
3187         .policy         = amt_policy,
3188         .priv_size      = sizeof(struct amt_dev),
3189         .setup          = amt_link_setup,
3190         .validate       = amt_validate,
3191         .newlink        = amt_newlink,
3192         .dellink        = amt_dellink,
3193         .get_size       = amt_get_size,
3194         .fill_info      = amt_fill_info,
3195 };
3196
3197 static struct net_device *amt_lookup_upper_dev(struct net_device *dev)
3198 {
3199         struct net_device *upper_dev;
3200         struct amt_dev *amt;
3201
3202         for_each_netdev(dev_net(dev), upper_dev) {
3203                 if (netif_is_amt(upper_dev)) {
3204                         amt = netdev_priv(upper_dev);
3205                         if (amt->stream_dev == dev)
3206                                 return upper_dev;
3207                 }
3208         }
3209
3210         return NULL;
3211 }
3212
3213 static int amt_device_event(struct notifier_block *unused,
3214                             unsigned long event, void *ptr)
3215 {
3216         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
3217         struct net_device *upper_dev;
3218         struct amt_dev *amt;
3219         LIST_HEAD(list);
3220         int new_mtu;
3221
3222         upper_dev = amt_lookup_upper_dev(dev);
3223         if (!upper_dev)
3224                 return NOTIFY_DONE;
3225         amt = netdev_priv(upper_dev);
3226
3227         switch (event) {
3228         case NETDEV_UNREGISTER:
3229                 amt_dellink(amt->dev, &list);
3230                 unregister_netdevice_many(&list);
3231                 break;
3232         case NETDEV_CHANGEMTU:
3233                 if (amt->mode == AMT_MODE_RELAY)
3234                         new_mtu = dev->mtu - AMT_RELAY_HLEN;
3235                 else
3236                         new_mtu = dev->mtu - AMT_GW_HLEN;
3237
3238                 dev_set_mtu(amt->dev, new_mtu);
3239                 break;
3240         }
3241
3242         return NOTIFY_DONE;
3243 }
3244
3245 static struct notifier_block amt_notifier_block __read_mostly = {
3246         .notifier_call = amt_device_event,
3247 };
3248
3249 static int __init amt_init(void)
3250 {
3251         int err;
3252
3253         err = register_netdevice_notifier(&amt_notifier_block);
3254         if (err < 0)
3255                 goto err;
3256
3257         err = rtnl_link_register(&amt_link_ops);
3258         if (err < 0)
3259                 goto unregister_notifier;
3260
3261         amt_wq = alloc_workqueue("amt", WQ_UNBOUND, 1);
3262         if (!amt_wq) {
3263                 err = -ENOMEM;
3264                 goto rtnl_unregister;
3265         }
3266
3267         spin_lock_init(&source_gc_lock);
3268         spin_lock_bh(&source_gc_lock);
3269         INIT_DELAYED_WORK(&source_gc_wq, amt_source_gc_work);
3270         mod_delayed_work(amt_wq, &source_gc_wq,
3271                          msecs_to_jiffies(AMT_GC_INTERVAL));
3272         spin_unlock_bh(&source_gc_lock);
3273
3274         return 0;
3275
3276 rtnl_unregister:
3277         rtnl_link_unregister(&amt_link_ops);
3278 unregister_notifier:
3279         unregister_netdevice_notifier(&amt_notifier_block);
3280 err:
3281         pr_err("error loading AMT module loaded\n");
3282         return err;
3283 }
3284 late_initcall(amt_init);
3285
3286 static void __exit amt_fini(void)
3287 {
3288         rtnl_link_unregister(&amt_link_ops);
3289         unregister_netdevice_notifier(&amt_notifier_block);
3290         cancel_delayed_work_sync(&source_gc_wq);
3291         __amt_source_gc_work();
3292         destroy_workqueue(amt_wq);
3293 }
3294 module_exit(amt_fini);
3295
3296 MODULE_LICENSE("GPL");
3297 MODULE_AUTHOR("Taehee Yoo <ap420073@gmail.com>");
3298 MODULE_ALIAS_RTNL_LINK("amt");