Bluetooth: Add RSSI Monitor feature
[platform/kernel/linux-rpi.git] / net / mctp / route.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Management Component Transport Protocol (MCTP) - routing
4  * implementation.
5  *
6  * This is currently based on a simple routing table, with no dst cache. The
7  * number of routes should stay fairly small, so the lookup cost is small.
8  *
9  * Copyright (c) 2021 Code Construct
10  * Copyright (c) 2021 Google
11  */
12
13 #include <linux/idr.h>
14 #include <linux/mctp.h>
15 #include <linux/netdevice.h>
16 #include <linux/rtnetlink.h>
17 #include <linux/skbuff.h>
18
19 #include <uapi/linux/if_arp.h>
20
21 #include <net/mctp.h>
22 #include <net/mctpdevice.h>
23 #include <net/netlink.h>
24 #include <net/sock.h>
25
/* Upper bound, in bytes, on a reassembled incoming message; a fragment that
 * would grow a reassembly past this is rejected by mctp_frag_queue().
 */
static const unsigned int mctp_message_maxlen = 65536;
27
28 /* route output callbacks */
/* Default output handler for a freshly allocated route: consume and drop
 * the packet, reporting success.
 */
static int mctp_route_discard(struct mctp_route *route, struct sk_buff *skb)
{
	kfree_skb(skb);

	return 0;
}
34
35 static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
36 {
37         struct mctp_skb_cb *cb = mctp_cb(skb);
38         struct mctp_hdr *mh;
39         struct sock *sk;
40         u8 type;
41
42         WARN_ON(!rcu_read_lock_held());
43
44         /* TODO: look up in skb->cb? */
45         mh = mctp_hdr(skb);
46
47         if (!skb_headlen(skb))
48                 return NULL;
49
50         type = (*(u8 *)skb->data) & 0x7f;
51
52         sk_for_each_rcu(sk, &net->mctp.binds) {
53                 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
54
55                 if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net)
56                         continue;
57
58                 if (msk->bind_type != type)
59                         continue;
60
61                 if (msk->bind_addr != MCTP_ADDR_ANY &&
62                     msk->bind_addr != mh->dest)
63                         continue;
64
65                 return msk;
66         }
67
68         return NULL;
69 }
70
71 static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
72                            mctp_eid_t peer, u8 tag)
73 {
74         if (key->local_addr != local)
75                 return false;
76
77         if (key->peer_addr != peer)
78                 return false;
79
80         if (key->tag != tag)
81                 return false;
82
83         return true;
84 }
85
86 static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
87                                            mctp_eid_t peer)
88 {
89         struct mctp_sk_key *key, *ret;
90         struct mctp_hdr *mh;
91         u8 tag;
92
93         WARN_ON(!rcu_read_lock_held());
94
95         mh = mctp_hdr(skb);
96         tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
97
98         ret = NULL;
99
100         hlist_for_each_entry_rcu(key, &net->mctp.keys, hlist) {
101                 if (mctp_key_match(key, mh->dest, peer, tag)) {
102                         ret = key;
103                         break;
104                 }
105         }
106
107         return ret;
108 }
109
110 static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
111                                           mctp_eid_t local, mctp_eid_t peer,
112                                           u8 tag, gfp_t gfp)
113 {
114         struct mctp_sk_key *key;
115
116         key = kzalloc(sizeof(*key), gfp);
117         if (!key)
118                 return NULL;
119
120         key->peer_addr = peer;
121         key->local_addr = local;
122         key->tag = tag;
123         key->sk = &msk->sk;
124         spin_lock_init(&key->reasm_lock);
125
126         return key;
127 }
128
129 static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
130 {
131         struct net *net = sock_net(&msk->sk);
132         struct mctp_sk_key *tmp;
133         unsigned long flags;
134         int rc = 0;
135
136         spin_lock_irqsave(&net->mctp.keys_lock, flags);
137
138         if (sock_flag(&msk->sk, SOCK_DEAD)) {
139                 rc = -EINVAL;
140                 goto out_unlock;
141         }
142
143         hlist_for_each_entry(tmp, &net->mctp.keys, hlist) {
144                 if (mctp_key_match(tmp, key->local_addr, key->peer_addr,
145                                    key->tag)) {
146                         rc = -EEXIST;
147                         break;
148                 }
149         }
150
151         if (!rc) {
152                 hlist_add_head(&key->hlist, &net->mctp.keys);
153                 hlist_add_head(&key->sklist, &msk->keys);
154         }
155
156 out_unlock:
157         spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
158
159         return rc;
160 }
161
162 /* Must be called with key->reasm_lock, which it will release. Will schedule
163  * the key for an RCU free.
164  */
165 static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net,
166                                    unsigned long flags)
167         __releases(&key->reasm_lock)
168 {
169         struct sk_buff *skb;
170
171         skb = key->reasm_head;
172         key->reasm_head = NULL;
173         key->reasm_dead = true;
174         spin_unlock_irqrestore(&key->reasm_lock, flags);
175
176         spin_lock_irqsave(&net->mctp.keys_lock, flags);
177         hlist_del_rcu(&key->hlist);
178         hlist_del_rcu(&key->sklist);
179         spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
180         kfree_rcu(key, rcu);
181
182         if (skb)
183                 kfree_skb(skb);
184 }
185
186 static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb)
187 {
188         struct mctp_hdr *hdr = mctp_hdr(skb);
189         u8 exp_seq, this_seq;
190
191         this_seq = (hdr->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT)
192                 & MCTP_HDR_SEQ_MASK;
193
194         if (!key->reasm_head) {
195                 key->reasm_head = skb;
196                 key->reasm_tailp = &(skb_shinfo(skb)->frag_list);
197                 key->last_seq = this_seq;
198                 return 0;
199         }
200
201         exp_seq = (key->last_seq + 1) & MCTP_HDR_SEQ_MASK;
202
203         if (this_seq != exp_seq)
204                 return -EINVAL;
205
206         if (key->reasm_head->len + skb->len > mctp_message_maxlen)
207                 return -EINVAL;
208
209         skb->next = NULL;
210         skb->sk = NULL;
211         *key->reasm_tailp = skb;
212         key->reasm_tailp = &skb->next;
213
214         key->last_seq = this_seq;
215
216         key->reasm_head->data_len += skb->len;
217         key->reasm_head->len += skb->len;
218         key->reasm_head->truesize += skb->truesize;
219
220         return 0;
221 }
222
223 static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
224 {
225         struct net *net = dev_net(skb->dev);
226         struct mctp_sk_key *key;
227         struct mctp_sock *msk;
228         struct mctp_hdr *mh;
229         unsigned long f;
230         u8 tag, flags;
231         int rc;
232
233         msk = NULL;
234         rc = -EINVAL;
235
236         /* we may be receiving a locally-routed packet; drop source sk
237          * accounting
238          */
239         skb_orphan(skb);
240
241         /* ensure we have enough data for a header and a type */
242         if (skb->len < sizeof(struct mctp_hdr) + 1)
243                 goto out;
244
245         /* grab header, advance data ptr */
246         mh = mctp_hdr(skb);
247         skb_pull(skb, sizeof(struct mctp_hdr));
248
249         if (mh->ver != 1)
250                 goto out;
251
252         flags = mh->flags_seq_tag & (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM);
253         tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
254
255         rcu_read_lock();
256
257         /* lookup socket / reasm context, exactly matching (src,dest,tag) */
258         key = mctp_lookup_key(net, skb, mh->src);
259
260         if (flags & MCTP_HDR_FLAG_SOM) {
261                 if (key) {
262                         msk = container_of(key->sk, struct mctp_sock, sk);
263                 } else {
264                         /* first response to a broadcast? do a more general
265                          * key lookup to find the socket, but don't use this
266                          * key for reassembly - we'll create a more specific
267                          * one for future packets if required (ie, !EOM).
268                          */
269                         key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY);
270                         if (key) {
271                                 msk = container_of(key->sk,
272                                                    struct mctp_sock, sk);
273                                 key = NULL;
274                         }
275                 }
276
277                 if (!key && !msk && (tag & MCTP_HDR_FLAG_TO))
278                         msk = mctp_lookup_bind(net, skb);
279
280                 if (!msk) {
281                         rc = -ENOENT;
282                         goto out_unlock;
283                 }
284
285                 /* single-packet message? deliver to socket, clean up any
286                  * pending key.
287                  */
288                 if (flags & MCTP_HDR_FLAG_EOM) {
289                         sock_queue_rcv_skb(&msk->sk, skb);
290                         if (key) {
291                                 spin_lock_irqsave(&key->reasm_lock, f);
292                                 /* we've hit a pending reassembly; not much we
293                                  * can do but drop it
294                                  */
295                                 __mctp_key_unlock_drop(key, net, f);
296                         }
297                         rc = 0;
298                         goto out_unlock;
299                 }
300
301                 /* broadcast response or a bind() - create a key for further
302                  * packets for this message
303                  */
304                 if (!key) {
305                         key = mctp_key_alloc(msk, mh->dest, mh->src,
306                                              tag, GFP_ATOMIC);
307                         if (!key) {
308                                 rc = -ENOMEM;
309                                 goto out_unlock;
310                         }
311
312                         /* we can queue without the reasm lock here, as the
313                          * key isn't observable yet
314                          */
315                         mctp_frag_queue(key, skb);
316
317                         /* if the key_add fails, we've raced with another
318                          * SOM packet with the same src, dest and tag. There's
319                          * no way to distinguish future packets, so all we
320                          * can do is drop; we'll free the skb on exit from
321                          * this function.
322                          */
323                         rc = mctp_key_add(key, msk);
324                         if (rc)
325                                 kfree(key);
326
327                 } else {
328                         /* existing key: start reassembly */
329                         spin_lock_irqsave(&key->reasm_lock, f);
330
331                         if (key->reasm_head || key->reasm_dead) {
332                                 /* duplicate start? drop everything */
333                                 __mctp_key_unlock_drop(key, net, f);
334                                 rc = -EEXIST;
335                         } else {
336                                 rc = mctp_frag_queue(key, skb);
337                                 spin_unlock_irqrestore(&key->reasm_lock, f);
338                         }
339                 }
340
341         } else if (key) {
342                 /* this packet continues a previous message; reassemble
343                  * using the message-specific key
344                  */
345
346                 spin_lock_irqsave(&key->reasm_lock, f);
347
348                 /* we need to be continuing an existing reassembly... */
349                 if (!key->reasm_head)
350                         rc = -EINVAL;
351                 else
352                         rc = mctp_frag_queue(key, skb);
353
354                 /* end of message? deliver to socket, and we're done with
355                  * the reassembly/response key
356                  */
357                 if (!rc && flags & MCTP_HDR_FLAG_EOM) {
358                         sock_queue_rcv_skb(key->sk, key->reasm_head);
359                         key->reasm_head = NULL;
360                         __mctp_key_unlock_drop(key, net, f);
361                 } else {
362                         spin_unlock_irqrestore(&key->reasm_lock, f);
363                 }
364
365         } else {
366                 /* not a start, no matching key */
367                 rc = -ENOENT;
368         }
369
370 out_unlock:
371         rcu_read_unlock();
372 out:
373         if (rc)
374                 kfree_skb(skb);
375         return rc;
376 }
377
378 static unsigned int mctp_route_mtu(struct mctp_route *rt)
379 {
380         return rt->mtu ?: READ_ONCE(rt->dev->dev->mtu);
381 }
382
383 static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb)
384 {
385         struct mctp_hdr *hdr = mctp_hdr(skb);
386         char daddr_buf[MAX_ADDR_LEN];
387         char *daddr = NULL;
388         unsigned int mtu;
389         int rc;
390
391         skb->protocol = htons(ETH_P_MCTP);
392
393         mtu = READ_ONCE(skb->dev->mtu);
394         if (skb->len > mtu) {
395                 kfree_skb(skb);
396                 return -EMSGSIZE;
397         }
398
399         /* If lookup fails let the device handle daddr==NULL */
400         if (mctp_neigh_lookup(route->dev, hdr->dest, daddr_buf) == 0)
401                 daddr = daddr_buf;
402
403         rc = dev_hard_header(skb, skb->dev, ntohs(skb->protocol),
404                              daddr, skb->dev->dev_addr, skb->len);
405         if (rc < 0) {
406                 kfree_skb(skb);
407                 return -EHOSTUNREACH;
408         }
409
410         rc = dev_queue_xmit(skb);
411         if (rc)
412                 rc = net_xmit_errno(rc);
413
414         return rc;
415 }
416
417 /* route alloc/release */
418 static void mctp_route_release(struct mctp_route *rt)
419 {
420         if (refcount_dec_and_test(&rt->refs)) {
421                 dev_put(rt->dev->dev);
422                 kfree_rcu(rt, rcu);
423         }
424 }
425
426 /* returns a route with the refcount at 1 */
427 static struct mctp_route *mctp_route_alloc(void)
428 {
429         struct mctp_route *rt;
430
431         rt = kzalloc(sizeof(*rt), GFP_KERNEL);
432         if (!rt)
433                 return NULL;
434
435         INIT_LIST_HEAD(&rt->list);
436         refcount_set(&rt->refs, 1);
437         rt->output = mctp_route_discard;
438
439         return rt;
440 }
441
442 unsigned int mctp_default_net(struct net *net)
443 {
444         return READ_ONCE(net->mctp.default_net);
445 }
446
447 int mctp_default_net_set(struct net *net, unsigned int index)
448 {
449         if (index == 0)
450                 return -EINVAL;
451         WRITE_ONCE(net->mctp.default_net, index);
452         return 0;
453 }
454
455 /* tag management */
456 static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
457                              struct mctp_sock *msk)
458 {
459         struct netns_mctp *mns = &net->mctp;
460
461         lockdep_assert_held(&mns->keys_lock);
462
463         /* we hold the net->key_lock here, allowing updates to both
464          * then net and sk
465          */
466         hlist_add_head_rcu(&key->hlist, &mns->keys);
467         hlist_add_head_rcu(&key->sklist, &msk->keys);
468 }
469
470 /* Allocate a locally-owned tag value for (saddr, daddr), and reserve
471  * it for the socket msk
472  */
473 static int mctp_alloc_local_tag(struct mctp_sock *msk,
474                                 mctp_eid_t saddr, mctp_eid_t daddr, u8 *tagp)
475 {
476         struct net *net = sock_net(&msk->sk);
477         struct netns_mctp *mns = &net->mctp;
478         struct mctp_sk_key *key, *tmp;
479         unsigned long flags;
480         int rc = -EAGAIN;
481         u8 tagbits;
482
483         /* be optimistic, alloc now */
484         key = mctp_key_alloc(msk, saddr, daddr, 0, GFP_KERNEL);
485         if (!key)
486                 return -ENOMEM;
487
488         /* 8 possible tag values */
489         tagbits = 0xff;
490
491         spin_lock_irqsave(&mns->keys_lock, flags);
492
493         /* Walk through the existing keys, looking for potential conflicting
494          * tags. If we find a conflict, clear that bit from tagbits
495          */
496         hlist_for_each_entry(tmp, &mns->keys, hlist) {
497                 /* if we don't own the tag, it can't conflict */
498                 if (tmp->tag & MCTP_HDR_FLAG_TO)
499                         continue;
500
501                 if ((tmp->peer_addr == daddr ||
502                      tmp->peer_addr == MCTP_ADDR_ANY) &&
503                     tmp->local_addr == saddr)
504                         tagbits &= ~(1 << tmp->tag);
505
506                 if (!tagbits)
507                         break;
508         }
509
510         if (tagbits) {
511                 key->tag = __ffs(tagbits);
512                 mctp_reserve_tag(net, key, msk);
513                 *tagp = key->tag;
514                 rc = 0;
515         }
516
517         spin_unlock_irqrestore(&mns->keys_lock, flags);
518
519         if (!tagbits)
520                 kfree(key);
521
522         return rc;
523 }
524
525 /* routing lookups */
526 static bool mctp_rt_match_eid(struct mctp_route *rt,
527                               unsigned int net, mctp_eid_t eid)
528 {
529         return READ_ONCE(rt->dev->net) == net &&
530                 rt->min <= eid && rt->max >= eid;
531 }
532
533 /* compares match, used for duplicate prevention */
534 static bool mctp_rt_compare_exact(struct mctp_route *rt1,
535                                   struct mctp_route *rt2)
536 {
537         ASSERT_RTNL();
538         return rt1->dev->net == rt2->dev->net &&
539                 rt1->min == rt2->min &&
540                 rt1->max == rt2->max;
541 }
542
543 struct mctp_route *mctp_route_lookup(struct net *net, unsigned int dnet,
544                                      mctp_eid_t daddr)
545 {
546         struct mctp_route *tmp, *rt = NULL;
547
548         list_for_each_entry_rcu(tmp, &net->mctp.routes, list) {
549                 /* TODO: add metrics */
550                 if (mctp_rt_match_eid(tmp, dnet, daddr)) {
551                         if (refcount_inc_not_zero(&tmp->refs)) {
552                                 rt = tmp;
553                                 break;
554                         }
555                 }
556         }
557
558         return rt;
559 }
560
561 /* sends a skb to rt and releases the route. */
562 int mctp_do_route(struct mctp_route *rt, struct sk_buff *skb)
563 {
564         int rc;
565
566         rc = rt->output(rt, skb);
567         mctp_route_release(rt);
568         return rc;
569 }
570
571 static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
572                                   unsigned int mtu, u8 tag)
573 {
574         const unsigned int hlen = sizeof(struct mctp_hdr);
575         struct mctp_hdr *hdr, *hdr2;
576         unsigned int pos, size;
577         struct sk_buff *skb2;
578         int rc;
579         u8 seq;
580
581         hdr = mctp_hdr(skb);
582         seq = 0;
583         rc = 0;
584
585         if (mtu < hlen + 1) {
586                 kfree_skb(skb);
587                 return -EMSGSIZE;
588         }
589
590         /* we've got the header */
591         skb_pull(skb, hlen);
592
593         for (pos = 0; pos < skb->len;) {
594                 /* size of message payload */
595                 size = min(mtu - hlen, skb->len - pos);
596
597                 skb2 = alloc_skb(MCTP_HEADER_MAXLEN + hlen + size, GFP_KERNEL);
598                 if (!skb2) {
599                         rc = -ENOMEM;
600                         break;
601                 }
602
603                 /* generic skb copy */
604                 skb2->protocol = skb->protocol;
605                 skb2->priority = skb->priority;
606                 skb2->dev = skb->dev;
607                 memcpy(skb2->cb, skb->cb, sizeof(skb2->cb));
608
609                 if (skb->sk)
610                         skb_set_owner_w(skb2, skb->sk);
611
612                 /* establish packet */
613                 skb_reserve(skb2, MCTP_HEADER_MAXLEN);
614                 skb_reset_network_header(skb2);
615                 skb_put(skb2, hlen + size);
616                 skb2->transport_header = skb2->network_header + hlen;
617
618                 /* copy header fields, calculate SOM/EOM flags & seq */
619                 hdr2 = mctp_hdr(skb2);
620                 hdr2->ver = hdr->ver;
621                 hdr2->dest = hdr->dest;
622                 hdr2->src = hdr->src;
623                 hdr2->flags_seq_tag = tag &
624                         (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
625
626                 if (pos == 0)
627                         hdr2->flags_seq_tag |= MCTP_HDR_FLAG_SOM;
628
629                 if (pos + size == skb->len)
630                         hdr2->flags_seq_tag |= MCTP_HDR_FLAG_EOM;
631
632                 hdr2->flags_seq_tag |= seq << MCTP_HDR_SEQ_SHIFT;
633
634                 /* copy message payload */
635                 skb_copy_bits(skb, pos, skb_transport_header(skb2), size);
636
637                 /* do route, but don't drop the rt reference */
638                 rc = rt->output(rt, skb2);
639                 if (rc)
640                         break;
641
642                 seq = (seq + 1) & MCTP_HDR_SEQ_MASK;
643                 pos += size;
644         }
645
646         mctp_route_release(rt);
647         consume_skb(skb);
648         return rc;
649 }
650
651 int mctp_local_output(struct sock *sk, struct mctp_route *rt,
652                       struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag)
653 {
654         struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
655         struct mctp_skb_cb *cb = mctp_cb(skb);
656         struct mctp_hdr *hdr;
657         unsigned long flags;
658         unsigned int mtu;
659         mctp_eid_t saddr;
660         int rc;
661         u8 tag;
662
663         if (WARN_ON(!rt->dev))
664                 return -EINVAL;
665
666         spin_lock_irqsave(&rt->dev->addrs_lock, flags);
667         if (rt->dev->num_addrs == 0) {
668                 rc = -EHOSTUNREACH;
669         } else {
670                 /* use the outbound interface's first address as our source */
671                 saddr = rt->dev->addrs[0];
672                 rc = 0;
673         }
674         spin_unlock_irqrestore(&rt->dev->addrs_lock, flags);
675
676         if (rc)
677                 return rc;
678
679         if (req_tag & MCTP_HDR_FLAG_TO) {
680                 rc = mctp_alloc_local_tag(msk, saddr, daddr, &tag);
681                 if (rc)
682                         return rc;
683                 tag |= MCTP_HDR_FLAG_TO;
684         } else {
685                 tag = req_tag;
686         }
687
688
689         skb->protocol = htons(ETH_P_MCTP);
690         skb->priority = 0;
691         skb_reset_transport_header(skb);
692         skb_push(skb, sizeof(struct mctp_hdr));
693         skb_reset_network_header(skb);
694         skb->dev = rt->dev->dev;
695
696         /* cb->net will have been set on initial ingress */
697         cb->src = saddr;
698
699         /* set up common header fields */
700         hdr = mctp_hdr(skb);
701         hdr->ver = 1;
702         hdr->dest = daddr;
703         hdr->src = saddr;
704
705         mtu = mctp_route_mtu(rt);
706
707         if (skb->len + sizeof(struct mctp_hdr) <= mtu) {
708                 hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM |
709                         tag;
710                 return mctp_do_route(rt, skb);
711         } else {
712                 return mctp_do_fragment_route(rt, skb, mtu, tag);
713         }
714 }
715
716 /* route management */
717 static int mctp_route_add(struct mctp_dev *mdev, mctp_eid_t daddr_start,
718                           unsigned int daddr_extent, unsigned int mtu,
719                           unsigned char type)
720 {
721         int (*rtfn)(struct mctp_route *rt, struct sk_buff *skb);
722         struct net *net = dev_net(mdev->dev);
723         struct mctp_route *rt, *ert;
724
725         if (!mctp_address_ok(daddr_start))
726                 return -EINVAL;
727
728         if (daddr_extent > 0xff || daddr_start + daddr_extent >= 255)
729                 return -EINVAL;
730
731         switch (type) {
732         case RTN_LOCAL:
733                 rtfn = mctp_route_input;
734                 break;
735         case RTN_UNICAST:
736                 rtfn = mctp_route_output;
737                 break;
738         default:
739                 return -EINVAL;
740         }
741
742         rt = mctp_route_alloc();
743         if (!rt)
744                 return -ENOMEM;
745
746         rt->min = daddr_start;
747         rt->max = daddr_start + daddr_extent;
748         rt->mtu = mtu;
749         rt->dev = mdev;
750         dev_hold(rt->dev->dev);
751         rt->type = type;
752         rt->output = rtfn;
753
754         ASSERT_RTNL();
755         /* Prevent duplicate identical routes. */
756         list_for_each_entry(ert, &net->mctp.routes, list) {
757                 if (mctp_rt_compare_exact(rt, ert)) {
758                         mctp_route_release(rt);
759                         return -EEXIST;
760                 }
761         }
762
763         list_add_rcu(&rt->list, &net->mctp.routes);
764
765         return 0;
766 }
767
768 static int mctp_route_remove(struct mctp_dev *mdev, mctp_eid_t daddr_start,
769                              unsigned int daddr_extent, unsigned char type)
770 {
771         struct net *net = dev_net(mdev->dev);
772         struct mctp_route *rt, *tmp;
773         mctp_eid_t daddr_end;
774         bool dropped;
775
776         if (daddr_extent > 0xff || daddr_start + daddr_extent >= 255)
777                 return -EINVAL;
778
779         daddr_end = daddr_start + daddr_extent;
780         dropped = false;
781
782         ASSERT_RTNL();
783
784         list_for_each_entry_safe(rt, tmp, &net->mctp.routes, list) {
785                 if (rt->dev == mdev &&
786                     rt->min == daddr_start && rt->max == daddr_end &&
787                     rt->type == type) {
788                         list_del_rcu(&rt->list);
789                         /* TODO: immediate RTM_DELROUTE */
790                         mctp_route_release(rt);
791                         dropped = true;
792                 }
793         }
794
795         return dropped ? 0 : -ENOENT;
796 }
797
798 int mctp_route_add_local(struct mctp_dev *mdev, mctp_eid_t addr)
799 {
800         return mctp_route_add(mdev, addr, 0, 0, RTN_LOCAL);
801 }
802
803 int mctp_route_remove_local(struct mctp_dev *mdev, mctp_eid_t addr)
804 {
805         return mctp_route_remove(mdev, addr, 0, RTN_LOCAL);
806 }
807
808 /* removes all entries for a given device */
809 void mctp_route_remove_dev(struct mctp_dev *mdev)
810 {
811         struct net *net = dev_net(mdev->dev);
812         struct mctp_route *rt, *tmp;
813
814         ASSERT_RTNL();
815         list_for_each_entry_safe(rt, tmp, &net->mctp.routes, list) {
816                 if (rt->dev == mdev) {
817                         list_del_rcu(&rt->list);
818                         /* TODO: immediate RTM_DELROUTE */
819                         mctp_route_release(rt);
820                 }
821         }
822 }
823
824 /* Incoming packet-handling */
825
826 static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev,
827                                 struct packet_type *pt,
828                                 struct net_device *orig_dev)
829 {
830         struct net *net = dev_net(dev);
831         struct mctp_skb_cb *cb;
832         struct mctp_route *rt;
833         struct mctp_hdr *mh;
834
835         /* basic non-data sanity checks */
836         if (dev->type != ARPHRD_MCTP)
837                 goto err_drop;
838
839         if (!pskb_may_pull(skb, sizeof(struct mctp_hdr)))
840                 goto err_drop;
841
842         skb_reset_transport_header(skb);
843         skb_reset_network_header(skb);
844
845         /* We have enough for a header; decode and route */
846         mh = mctp_hdr(skb);
847         if (mh->ver < MCTP_VER_MIN || mh->ver > MCTP_VER_MAX)
848                 goto err_drop;
849
850         cb = __mctp_cb(skb);
851         rcu_read_lock();
852         cb->net = READ_ONCE(__mctp_dev_get(dev)->net);
853         rcu_read_unlock();
854
855         rt = mctp_route_lookup(net, cb->net, mh->dest);
856         if (!rt)
857                 goto err_drop;
858
859         mctp_do_route(rt, skb);
860
861         return NET_RX_SUCCESS;
862
863 err_drop:
864         kfree_skb(skb);
865         return NET_RX_DROP;
866 }
867
868 static struct packet_type mctp_packet_type = {
869         .type = cpu_to_be16(ETH_P_MCTP),
870         .func = mctp_pkttype_receive,
871 };
872
873 /* netlink interface */
874
875 static const struct nla_policy rta_mctp_policy[RTA_MAX + 1] = {
876         [RTA_DST]               = { .type = NLA_U8 },
877         [RTA_METRICS]           = { .type = NLA_NESTED },
878         [RTA_OIF]               = { .type = NLA_U32 },
879 };
880
881 /* Common part for RTM_NEWROUTE and RTM_DELROUTE parsing.
882  * tb must hold RTA_MAX+1 elements.
883  */
884 static int mctp_route_nlparse(struct sk_buff *skb, struct nlmsghdr *nlh,
885                               struct netlink_ext_ack *extack,
886                               struct nlattr **tb, struct rtmsg **rtm,
887                               struct mctp_dev **mdev, mctp_eid_t *daddr_start)
888 {
889         struct net *net = sock_net(skb->sk);
890         struct net_device *dev;
891         unsigned int ifindex;
892         int rc;
893
894         rc = nlmsg_parse(nlh, sizeof(struct rtmsg), tb, RTA_MAX,
895                          rta_mctp_policy, extack);
896         if (rc < 0) {
897                 NL_SET_ERR_MSG(extack, "incorrect format");
898                 return rc;
899         }
900
901         if (!tb[RTA_DST]) {
902                 NL_SET_ERR_MSG(extack, "dst EID missing");
903                 return -EINVAL;
904         }
905         *daddr_start = nla_get_u8(tb[RTA_DST]);
906
907         if (!tb[RTA_OIF]) {
908                 NL_SET_ERR_MSG(extack, "ifindex missing");
909                 return -EINVAL;
910         }
911         ifindex = nla_get_u32(tb[RTA_OIF]);
912
913         *rtm = nlmsg_data(nlh);
914         if ((*rtm)->rtm_family != AF_MCTP) {
915                 NL_SET_ERR_MSG(extack, "route family must be AF_MCTP");
916                 return -EINVAL;
917         }
918
919         dev = __dev_get_by_index(net, ifindex);
920         if (!dev) {
921                 NL_SET_ERR_MSG(extack, "bad ifindex");
922                 return -ENODEV;
923         }
924         *mdev = mctp_dev_get_rtnl(dev);
925         if (!*mdev)
926                 return -ENODEV;
927
928         if (dev->flags & IFF_LOOPBACK) {
929                 NL_SET_ERR_MSG(extack, "no routes to loopback");
930                 return -EINVAL;
931         }
932
933         return 0;
934 }
935
936 static int mctp_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
937                          struct netlink_ext_ack *extack)
938 {
939         struct nlattr *tb[RTA_MAX + 1];
940         mctp_eid_t daddr_start;
941         struct mctp_dev *mdev;
942         struct rtmsg *rtm;
943         unsigned int mtu;
944         int rc;
945
946         rc = mctp_route_nlparse(skb, nlh, extack, tb,
947                                 &rtm, &mdev, &daddr_start);
948         if (rc < 0)
949                 return rc;
950
951         if (rtm->rtm_type != RTN_UNICAST) {
952                 NL_SET_ERR_MSG(extack, "rtm_type must be RTN_UNICAST");
953                 return -EINVAL;
954         }
955
956         /* TODO: parse mtu from nlparse */
957         mtu = 0;
958
959         if (rtm->rtm_type != RTN_UNICAST)
960                 return -EINVAL;
961
962         rc = mctp_route_add(mdev, daddr_start, rtm->rtm_dst_len, mtu,
963                             rtm->rtm_type);
964         return rc;
965 }
966
967 static int mctp_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
968                          struct netlink_ext_ack *extack)
969 {
970         struct nlattr *tb[RTA_MAX + 1];
971         mctp_eid_t daddr_start;
972         struct mctp_dev *mdev;
973         struct rtmsg *rtm;
974         int rc;
975
976         rc = mctp_route_nlparse(skb, nlh, extack, tb,
977                                 &rtm, &mdev, &daddr_start);
978         if (rc < 0)
979                 return rc;
980
981         /* we only have unicast routes */
982         if (rtm->rtm_type != RTN_UNICAST)
983                 return -EINVAL;
984
985         rc = mctp_route_remove(mdev, daddr_start, rtm->rtm_dst_len, RTN_UNICAST);
986         return rc;
987 }
988
989 static int mctp_fill_rtinfo(struct sk_buff *skb, struct mctp_route *rt,
990                             u32 portid, u32 seq, int event, unsigned int flags)
991 {
992         struct nlmsghdr *nlh;
993         struct rtmsg *hdr;
994         void *metrics;
995
996         nlh = nlmsg_put(skb, portid, seq, event, sizeof(*hdr), flags);
997         if (!nlh)
998                 return -EMSGSIZE;
999
1000         hdr = nlmsg_data(nlh);
1001         hdr->rtm_family = AF_MCTP;
1002
1003         /* we use the _len fields as a number of EIDs, rather than
1004          * a number of bits in the address
1005          */
1006         hdr->rtm_dst_len = rt->max - rt->min;
1007         hdr->rtm_src_len = 0;
1008         hdr->rtm_tos = 0;
1009         hdr->rtm_table = RT_TABLE_DEFAULT;
1010         hdr->rtm_protocol = RTPROT_STATIC; /* everything is user-defined */
1011         hdr->rtm_scope = RT_SCOPE_LINK; /* TODO: scope in mctp_route? */
1012         hdr->rtm_type = rt->type;
1013
1014         if (nla_put_u8(skb, RTA_DST, rt->min))
1015                 goto cancel;
1016
1017         metrics = nla_nest_start_noflag(skb, RTA_METRICS);
1018         if (!metrics)
1019                 goto cancel;
1020
1021         if (rt->mtu) {
1022                 if (nla_put_u32(skb, RTAX_MTU, rt->mtu))
1023                         goto cancel;
1024         }
1025
1026         nla_nest_end(skb, metrics);
1027
1028         if (rt->dev) {
1029                 if (nla_put_u32(skb, RTA_OIF, rt->dev->dev->ifindex))
1030                         goto cancel;
1031         }
1032
1033         /* TODO: conditional neighbour physaddr? */
1034
1035         nlmsg_end(skb, nlh);
1036
1037         return 0;
1038
1039 cancel:
1040         nlmsg_cancel(skb, nlh);
1041         return -EMSGSIZE;
1042 }
1043
1044 static int mctp_dump_rtinfo(struct sk_buff *skb, struct netlink_callback *cb)
1045 {
1046         struct net *net = sock_net(skb->sk);
1047         struct mctp_route *rt;
1048         int s_idx, idx;
1049
1050         /* TODO: allow filtering on route data, possibly under
1051          * cb->strict_check
1052          */
1053
1054         /* TODO: change to struct overlay */
1055         s_idx = cb->args[0];
1056         idx = 0;
1057
1058         rcu_read_lock();
1059         list_for_each_entry_rcu(rt, &net->mctp.routes, list) {
1060                 if (idx++ < s_idx)
1061                         continue;
1062                 if (mctp_fill_rtinfo(skb, rt,
1063                                      NETLINK_CB(cb->skb).portid,
1064                                      cb->nlh->nlmsg_seq,
1065                                      RTM_NEWROUTE, NLM_F_MULTI) < 0)
1066                         break;
1067         }
1068
1069         rcu_read_unlock();
1070         cb->args[0] = idx;
1071
1072         return skb->len;
1073 }
1074
1075 /* net namespace implementation */
1076 static int __net_init mctp_routes_net_init(struct net *net)
1077 {
1078         struct netns_mctp *ns = &net->mctp;
1079
1080         INIT_LIST_HEAD(&ns->routes);
1081         INIT_HLIST_HEAD(&ns->binds);
1082         mutex_init(&ns->bind_lock);
1083         INIT_HLIST_HEAD(&ns->keys);
1084         spin_lock_init(&ns->keys_lock);
1085         WARN_ON(mctp_default_net_set(net, MCTP_INITIAL_DEFAULT_NET));
1086         return 0;
1087 }
1088
1089 static void __net_exit mctp_routes_net_exit(struct net *net)
1090 {
1091         struct mctp_route *rt;
1092
1093         rcu_read_lock();
1094         list_for_each_entry_rcu(rt, &net->mctp.routes, list)
1095                 mctp_route_release(rt);
1096         rcu_read_unlock();
1097 }
1098
1099 static struct pernet_operations mctp_net_ops = {
1100         .init = mctp_routes_net_init,
1101         .exit = mctp_routes_net_exit,
1102 };
1103
1104 int __init mctp_routes_init(void)
1105 {
1106         dev_add_pack(&mctp_packet_type);
1107
1108         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_GETROUTE,
1109                              NULL, mctp_dump_rtinfo, 0);
1110         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_NEWROUTE,
1111                              mctp_newroute, NULL, 0);
1112         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_DELROUTE,
1113                              mctp_delroute, NULL, 0);
1114
1115         return register_pernet_subsys(&mctp_net_ops);
1116 }
1117
1118 void mctp_routes_exit(void)
1119 {
1120         unregister_pernet_subsys(&mctp_net_ops);
1121         rtnl_unregister(PF_MCTP, RTM_DELROUTE);
1122         rtnl_unregister(PF_MCTP, RTM_NEWROUTE);
1123         rtnl_unregister(PF_MCTP, RTM_GETROUTE);
1124         dev_remove_pack(&mctp_packet_type);
1125 }