mctp: Fix check for dev_hard_header() result
net/mctp/route.c (platform/kernel/linux-rpi.git)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Management Component Transport Protocol (MCTP) - routing
4  * implementation.
5  *
6  * This is currently based on a simple routing table, with no dst cache. The
7  * number of routes should stay fairly small, so the lookup cost is small.
8  *
9  * Copyright (c) 2021 Code Construct
10  * Copyright (c) 2021 Google
11  */
12
13 #include <linux/idr.h>
14 #include <linux/mctp.h>
15 #include <linux/netdevice.h>
16 #include <linux/rtnetlink.h>
17 #include <linux/skbuff.h>
18
19 #include <uapi/linux/if_arp.h>
20
21 #include <net/mctp.h>
22 #include <net/mctpdevice.h>
23 #include <net/netlink.h>
24 #include <net/sock.h>
25
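/* Upper bound on the size of a reassembled message; enforced when
 * fragments are queued in mctp_frag_queue().
 */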
26 static const unsigned int mctp_message_maxlen = 64 * 1024;
27
28 /* route output callbacks */
29 static int mctp_route_discard(struct mctp_route *route, struct sk_buff *skb)
30 {
31         kfree_skb(skb);
32         return 0;
33 }
34
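/* Find a socket bound to an incoming message: match the bind net
 * (MCTP_NET_ANY is a wildcard), the message type taken from the low
 * seven bits of the first body byte, and the bound local EID
 * (MCTP_ADDR_ANY is a wildcard). Caller must hold the RCU read lock.
 */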
35 static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
36 {
37         struct mctp_skb_cb *cb = mctp_cb(skb);
38         struct mctp_hdr *mh;
39         struct sock *sk;
40         u8 type;
41
42         WARN_ON(!rcu_read_lock_held());
43
44         /* TODO: look up in skb->cb? */
45         mh = mctp_hdr(skb);
46
47         if (!skb_headlen(skb))
48                 return NULL;
49
50         type = (*(u8 *)skb->data) & 0x7f;
51
52         sk_for_each_rcu(sk, &net->mctp.binds) {
53                 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
54
55                 if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net)
56                         continue;
57
58                 if (msk->bind_type != type)
59                         continue;
60
61                 if (msk->bind_addr != MCTP_ADDR_ANY &&
62                     msk->bind_addr != mh->dest)
63                         continue;
64
65                 return msk;
66         }
67
68         return NULL;
69 }
70
71 static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
72                            mctp_eid_t peer, u8 tag)
73 {
74         if (key->local_addr != local)
75                 return false;
76
77         if (key->peer_addr != peer)
78                 return false;
79
80         if (key->tag != tag)
81                 return false;
82
83         return true;
84 }
85
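/* Find the key for an incoming packet: an exact match on local EID (the
 * header dest), peer EID and tag, where the tag value includes the
 * TO (tag-owner) bit. Caller must hold the RCU read lock.
 */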
86 static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
87                                            mctp_eid_t peer)
88 {
89         struct mctp_sk_key *key, *ret;
90         struct mctp_hdr *mh;
91         u8 tag;
92
93         WARN_ON(!rcu_read_lock_held());
94
95         mh = mctp_hdr(skb);
96         tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
97
98         ret = NULL;
99
100         hlist_for_each_entry_rcu(key, &net->mctp.keys, hlist) {
101                 if (mctp_key_match(key, mh->dest, peer, tag)) {
102                         ret = key;
103                         break;
104                 }
105         }
106
107         return ret;
108 }
109
110 static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
111                                           mctp_eid_t local, mctp_eid_t peer,
112                                           u8 tag, gfp_t gfp)
113 {
114         struct mctp_sk_key *key;
115
116         key = kzalloc(sizeof(*key), gfp);
117         if (!key)
118                 return NULL;
119
120         key->peer_addr = peer;
121         key->local_addr = local;
122         key->tag = tag;
123         key->sk = &msk->sk;
124         spin_lock_init(&key->reasm_lock);
125
126         return key;
127 }
128
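/* Publish a key on the net-wide and per-socket lists, under keys_lock.
 * Fails with -EEXIST if an equivalent (local, peer, tag) key is already
 * present.
 */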
129 static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
130 {
131         struct net *net = sock_net(&msk->sk);
132         struct mctp_sk_key *tmp;
133         unsigned long flags;
134         int rc = 0;
135
136         spin_lock_irqsave(&net->mctp.keys_lock, flags);
137
138         hlist_for_each_entry(tmp, &net->mctp.keys, hlist) {
139                 if (mctp_key_match(tmp, key->local_addr, key->peer_addr,
140                                    key->tag)) {
141                         rc = -EEXIST;
142                         break;
143                 }
144         }
145
146         if (!rc) {
147                 hlist_add_head(&key->hlist, &net->mctp.keys);
148                 hlist_add_head(&key->sklist, &msk->keys);
149         }
150
151         spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
152
153         return rc;
154 }
155
156 /* Must be called with key->reasm_lock held; the lock is released here. Will
157  * schedule the key for an RCU free.
158  */
159 static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net,
160                                    unsigned long flags)
161         __releases(&key->reasm_lock)
162 {
163         struct sk_buff *skb;
164
165         skb = key->reasm_head;
166         key->reasm_head = NULL;
167         key->reasm_dead = true;
168         spin_unlock_irqrestore(&key->reasm_lock, flags);
169
170         spin_lock_irqsave(&net->mctp.keys_lock, flags);
171         hlist_del_rcu(&key->hlist);
172         hlist_del_rcu(&key->sklist);
173         spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
174         kfree_rcu(key, rcu);
175
176         if (skb)
177                 kfree_skb(skb);
178 }
179
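/* Queue one fragment for reassembly. The first fragment becomes
 * key->reasm_head and records its sequence number; each later fragment
 * must carry the next sequence value (modulo MCTP_HDR_SEQ_MASK + 1) and
 * is chained onto the head's frag_list. The reassembled length is
 * bounded by mctp_message_maxlen.
 */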
180 static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb)
181 {
182         struct mctp_hdr *hdr = mctp_hdr(skb);
183         u8 exp_seq, this_seq;
184
185         this_seq = (hdr->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT)
186                 & MCTP_HDR_SEQ_MASK;
187
188         if (!key->reasm_head) {
189                 key->reasm_head = skb;
190                 key->reasm_tailp = &(skb_shinfo(skb)->frag_list);
191                 key->last_seq = this_seq;
192                 return 0;
193         }
194
195         exp_seq = (key->last_seq + 1) & MCTP_HDR_SEQ_MASK;
196
197         if (this_seq != exp_seq)
198                 return -EINVAL;
199
200         if (key->reasm_head->len + skb->len > mctp_message_maxlen)
201                 return -EINVAL;
202
203         skb->next = NULL;
204         skb->sk = NULL;
205         *key->reasm_tailp = skb;
206         key->reasm_tailp = &skb->next;
207
208         key->last_seq = this_seq;
209
210         key->reasm_head->data_len += skb->len;
211         key->reasm_head->len += skb->len;
212         key->reasm_head->truesize += skb->truesize;
213
214         return 0;
215 }
216
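/* Local delivery. SOM packets locate (or create) a reassembly key, or
 * fall back to a bound socket for tag-owner requests; a packet carrying
 * both SOM and EOM is delivered to the socket immediately. Non-SOM
 * packets are appended to an existing reassembly, and EOM delivers the
 * assembled message and drops the key. On error the skb is freed here.
 */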
217 static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
218 {
219         struct net *net = dev_net(skb->dev);
220         struct mctp_sk_key *key;
221         struct mctp_sock *msk;
222         struct mctp_hdr *mh;
223         unsigned long f;
224         u8 tag, flags;
225         int rc;
226
227         msk = NULL;
228         rc = -EINVAL;
229
230         /* we may be receiving a locally-routed packet; drop source sk
231          * accounting
232          */
233         skb_orphan(skb);
234
235         /* ensure we have enough data for a header and a type */
236         if (skb->len < sizeof(struct mctp_hdr) + 1)
237                 goto out;
238
239         /* grab header, advance data ptr */
240         mh = mctp_hdr(skb);
241         skb_pull(skb, sizeof(struct mctp_hdr));
242
243         if (mh->ver != 1)
244                 goto out;
245
246         flags = mh->flags_seq_tag & (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM);
247         tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
248
249         rcu_read_lock();
250
251         /* lookup socket / reasm context, exactly matching (src,dest,tag) */
252         key = mctp_lookup_key(net, skb, mh->src);
253
254         if (flags & MCTP_HDR_FLAG_SOM) {
255                 if (key) {
256                         msk = container_of(key->sk, struct mctp_sock, sk);
257                 } else {
258                         /* first response to a broadcast? do a more general
259                          * key lookup to find the socket, but don't use this
260                          * key for reassembly - we'll create a more specific
261                          * one for future packets if required (ie, !EOM).
262                          */
263                         key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY);
264                         if (key) {
265                                 msk = container_of(key->sk,
266                                                    struct mctp_sock, sk);
267                                 key = NULL;
268                         }
269                 }
270
271                 if (!key && !msk && (tag & MCTP_HDR_FLAG_TO))
272                         msk = mctp_lookup_bind(net, skb);
273
274                 if (!msk) {
275                         rc = -ENOENT;
276                         goto out_unlock;
277                 }
278
279                 /* single-packet message? deliver to socket, clean up any
280                  * pending key.
281                  */
282                 if (flags & MCTP_HDR_FLAG_EOM) {
283                         sock_queue_rcv_skb(&msk->sk, skb);
284                         if (key) {
285                                 spin_lock_irqsave(&key->reasm_lock, f);
286                                 /* we've hit a pending reassembly; not much we
287                                  * can do but drop it
288                                  */
289                                 __mctp_key_unlock_drop(key, net, f);
290                         }
291                         rc = 0;
292                         goto out_unlock;
293                 }
294
295                 /* broadcast response or a bind() - create a key for further
296                  * packets for this message
297                  */
298                 if (!key) {
299                         key = mctp_key_alloc(msk, mh->dest, mh->src,
300                                              tag, GFP_ATOMIC);
301                         if (!key) {
302                                 rc = -ENOMEM;
303                                 goto out_unlock;
304                         }
305
306                         /* we can queue without the reasm lock here, as the
307                          * key isn't observable yet
308                          */
309                         mctp_frag_queue(key, skb);
310
311                         /* if the key_add fails, we've raced with another
312                          * SOM packet with the same src, dest and tag. There's
313                          * no way to distinguish future packets, so all we
314                          * can do is drop; we'll free the skb on exit from
315                          * this function.
316                          */
317                         rc = mctp_key_add(key, msk);
318                         if (rc)
319                                 kfree(key);
320
321                 } else {
322                         /* existing key: start reassembly */
323                         spin_lock_irqsave(&key->reasm_lock, f);
324
325                         if (key->reasm_head || key->reasm_dead) {
326                                 /* duplicate start? drop everything */
327                                 __mctp_key_unlock_drop(key, net, f);
328                                 rc = -EEXIST;
329                         } else {
330                                 rc = mctp_frag_queue(key, skb);
331                                 spin_unlock_irqrestore(&key->reasm_lock, f);
332                         }
333                 }
334
335         } else if (key) {
336                 /* this packet continues a previous message; reassemble
337                  * using the message-specific key
338                  */
339
340                 spin_lock_irqsave(&key->reasm_lock, f);
341
342                 /* we need to be continuing an existing reassembly... */
343                 if (!key->reasm_head)
344                         rc = -EINVAL;
345                 else
346                         rc = mctp_frag_queue(key, skb);
347
348                 /* end of message? deliver to socket, and we're done with
349                  * the reassembly/response key
350                  */
351                 if (!rc && flags & MCTP_HDR_FLAG_EOM) {
352                         sock_queue_rcv_skb(key->sk, key->reasm_head);
353                         key->reasm_head = NULL;
354                         __mctp_key_unlock_drop(key, net, f);
355                 } else {
356                         spin_unlock_irqrestore(&key->reasm_lock, f);
357                 }
358
359         } else {
360                 /* not a start, no matching key */
361                 rc = -ENOENT;
362         }
363
364 out_unlock:
365         rcu_read_unlock();
366 out:
367         if (rc)
368                 kfree_skb(skb);
369         return rc;
370 }
371
372 static unsigned int mctp_route_mtu(struct mctp_route *rt)
373 {
374         return rt->mtu ?: READ_ONCE(rt->dev->dev->mtu);
375 }
376
377 static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb)
378 {
379         struct mctp_hdr *hdr = mctp_hdr(skb);
380         char daddr_buf[MAX_ADDR_LEN];
381         char *daddr = NULL;
382         unsigned int mtu;
383         int rc;
384
385         skb->protocol = htons(ETH_P_MCTP);
386
387         mtu = READ_ONCE(skb->dev->mtu);
388         if (skb->len > mtu) {
389                 kfree_skb(skb);
390                 return -EMSGSIZE;
391         }
392
393         /* If the lookup fails, let the device handle daddr == NULL */
394         if (mctp_neigh_lookup(route->dev, hdr->dest, daddr_buf) == 0)
395                 daddr = daddr_buf;
396
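        /* dev_hard_header() returns the length of the pushed link-layer
         * header on success (or a negative errno), so only rc < 0
         * indicates a failure here.
         */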
397         rc = dev_hard_header(skb, skb->dev, ntohs(skb->protocol),
398                              daddr, skb->dev->dev_addr, skb->len);
399         if (rc < 0) {
400                 kfree_skb(skb);
401                 return -EHOSTUNREACH;
402         }
403
404         rc = dev_queue_xmit(skb);
405         if (rc)
406                 rc = net_xmit_errno(rc);
407
408         return rc;
409 }
410
411 /* route alloc/release */
412 static void mctp_route_release(struct mctp_route *rt)
413 {
414         if (refcount_dec_and_test(&rt->refs)) {
415                 dev_put(rt->dev->dev);
416                 kfree_rcu(rt, rcu);
417         }
418 }
419
420 /* returns a route with the refcount at 1 */
421 static struct mctp_route *mctp_route_alloc(void)
422 {
423         struct mctp_route *rt;
424
425         rt = kzalloc(sizeof(*rt), GFP_KERNEL);
426         if (!rt)
427                 return NULL;
428
429         INIT_LIST_HEAD(&rt->list);
430         refcount_set(&rt->refs, 1);
431         rt->output = mctp_route_discard;
432
433         return rt;
434 }
435
436 unsigned int mctp_default_net(struct net *net)
437 {
438         return READ_ONCE(net->mctp.default_net);
439 }
440
441 int mctp_default_net_set(struct net *net, unsigned int index)
442 {
443         if (index == 0)
444                 return -EINVAL;
445         WRITE_ONCE(net->mctp.default_net, index);
446         return 0;
447 }
448
449 /* tag management */
450 static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
451                              struct mctp_sock *msk)
452 {
453         struct netns_mctp *mns = &net->mctp;
454
455         lockdep_assert_held(&mns->keys_lock);
456
457         /* we hold the net's keys_lock here, allowing updates to both
458          * the net and sk lists
459          */
460         hlist_add_head_rcu(&key->hlist, &mns->keys);
461         hlist_add_head_rcu(&key->sklist, &msk->keys);
462 }
463
464 /* Allocate a locally-owned tag value for (saddr, daddr), and reserve
465  * it for the socket msk
466  */
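/* Eight tag values are available (the tag field is three bits). Keys
 * carrying MCTP_HDR_FLAG_TO were allocated by the remote endpoint and so
 * cannot conflict; if every value is in use for this (saddr, daddr)
 * pair, the allocation fails with -EAGAIN.
 */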
467 static int mctp_alloc_local_tag(struct mctp_sock *msk,
468                                 mctp_eid_t saddr, mctp_eid_t daddr, u8 *tagp)
469 {
470         struct net *net = sock_net(&msk->sk);
471         struct netns_mctp *mns = &net->mctp;
472         struct mctp_sk_key *key, *tmp;
473         unsigned long flags;
474         int rc = -EAGAIN;
475         u8 tagbits;
476
477         /* be optimistic, alloc now */
478         key = mctp_key_alloc(msk, saddr, daddr, 0, GFP_KERNEL);
479         if (!key)
480                 return -ENOMEM;
481
482         /* 8 possible tag values */
483         tagbits = 0xff;
484
485         spin_lock_irqsave(&mns->keys_lock, flags);
486
487         /* Walk through the existing keys, looking for potential conflicting
488          * tags. If we find a conflict, clear that bit from tagbits
489          */
490         hlist_for_each_entry(tmp, &mns->keys, hlist) {
491                 /* if we don't own the tag, it can't conflict */
492                 if (tmp->tag & MCTP_HDR_FLAG_TO)
493                         continue;
494
495                 if ((tmp->peer_addr == daddr ||
496                      tmp->peer_addr == MCTP_ADDR_ANY) &&
497                     tmp->local_addr == saddr)
498                         tagbits &= ~(1 << tmp->tag);
499
500                 if (!tagbits)
501                         break;
502         }
503
504         if (tagbits) {
505                 key->tag = __ffs(tagbits);
506                 mctp_reserve_tag(net, key, msk);
507                 *tagp = key->tag;
508                 rc = 0;
509         }
510
511         spin_unlock_irqrestore(&mns->keys_lock, flags);
512
513         if (!tagbits)
514                 kfree(key);
515
516         return rc;
517 }
518
519 /* routing lookups */
520 static bool mctp_rt_match_eid(struct mctp_route *rt,
521                               unsigned int net, mctp_eid_t eid)
522 {
523         return READ_ONCE(rt->dev->net) == net &&
524                 rt->min <= eid && rt->max >= eid;
525 }
526
527 /* compares routes for an exact match; used for duplicate prevention */
528 static bool mctp_rt_compare_exact(struct mctp_route *rt1,
529                                   struct mctp_route *rt2)
530 {
531         ASSERT_RTNL();
532         return rt1->dev->net == rt2->dev->net &&
533                 rt1->min == rt2->min &&
534                 rt1->max == rt2->max;
535 }
536
537 struct mctp_route *mctp_route_lookup(struct net *net, unsigned int dnet,
538                                      mctp_eid_t daddr)
539 {
540         struct mctp_route *tmp, *rt = NULL;
541
542         list_for_each_entry_rcu(tmp, &net->mctp.routes, list) {
543                 /* TODO: add metrics */
544                 if (mctp_rt_match_eid(tmp, dnet, daddr)) {
545                         if (refcount_inc_not_zero(&tmp->refs)) {
546                                 rt = tmp;
547                                 break;
548                         }
549                 }
550         }
551
552         return rt;
553 }
554
555 /* sends a skb to rt and releases the route. */
556 int mctp_do_route(struct mctp_route *rt, struct sk_buff *skb)
557 {
558         int rc;
559
560         rc = rt->output(rt, skb);
561         mctp_route_release(rt);
562         return rc;
563 }
564
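/* Fragment a message to fit the route MTU: each packet carries a copy of
 * the MCTP header plus up to (mtu - hlen) bytes of payload, with SOM set
 * on the first fragment, EOM on the last, and an incrementing sequence
 * number. Fragments are sent through rt->output(); the route reference
 * is released and the original skb consumed before returning.
 */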
565 static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
566                                   unsigned int mtu, u8 tag)
567 {
568         const unsigned int hlen = sizeof(struct mctp_hdr);
569         struct mctp_hdr *hdr, *hdr2;
570         unsigned int pos, size;
571         struct sk_buff *skb2;
572         int rc;
573         u8 seq;
574
575         hdr = mctp_hdr(skb);
576         seq = 0;
577         rc = 0;
578
579         if (mtu < hlen + 1) {
580                 kfree_skb(skb);
581                 return -EMSGSIZE;
582         }
583
584         /* the header has been captured in hdr above; strip it from the data */
585         skb_pull(skb, hlen);
586
587         for (pos = 0; pos < skb->len;) {
588                 /* size of message payload */
589                 size = min(mtu - hlen, skb->len - pos);
590
591                 skb2 = alloc_skb(MCTP_HEADER_MAXLEN + hlen + size, GFP_KERNEL);
592                 if (!skb2) {
593                         rc = -ENOMEM;
594                         break;
595                 }
596
597                 /* generic skb copy */
598                 skb2->protocol = skb->protocol;
599                 skb2->priority = skb->priority;
600                 skb2->dev = skb->dev;
601                 memcpy(skb2->cb, skb->cb, sizeof(skb2->cb));
602
603                 if (skb->sk)
604                         skb_set_owner_w(skb2, skb->sk);
605
606                 /* establish packet */
607                 skb_reserve(skb2, MCTP_HEADER_MAXLEN);
608                 skb_reset_network_header(skb2);
609                 skb_put(skb2, hlen + size);
610                 skb2->transport_header = skb2->network_header + hlen;
611
612                 /* copy header fields, calculate SOM/EOM flags & seq */
613                 hdr2 = mctp_hdr(skb2);
614                 hdr2->ver = hdr->ver;
615                 hdr2->dest = hdr->dest;
616                 hdr2->src = hdr->src;
617                 hdr2->flags_seq_tag = tag &
618                         (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
619
620                 if (pos == 0)
621                         hdr2->flags_seq_tag |= MCTP_HDR_FLAG_SOM;
622
623                 if (pos + size == skb->len)
624                         hdr2->flags_seq_tag |= MCTP_HDR_FLAG_EOM;
625
626                 hdr2->flags_seq_tag |= seq << MCTP_HDR_SEQ_SHIFT;
627
628                 /* copy message payload */
629                 skb_copy_bits(skb, pos, skb_transport_header(skb2), size);
630
631                 /* do route, but don't drop the rt reference */
632                 rc = rt->output(rt, skb2);
633                 if (rc)
634                         break;
635
636                 seq = (seq + 1) & MCTP_HDR_SEQ_MASK;
637                 pos += size;
638         }
639
640         mctp_route_release(rt);
641         consume_skb(skb);
642         return rc;
643 }
644
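/* Transmit entry point for locally-generated messages: the source EID is
 * the first address on the outbound interface, an owned tag is allocated
 * when MCTP_HDR_FLAG_TO is requested, and the message is sent whole if
 * it fits within the route MTU, or handed to mctp_do_fragment_route()
 * otherwise.
 */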
645 int mctp_local_output(struct sock *sk, struct mctp_route *rt,
646                       struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag)
647 {
648         struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
649         struct mctp_skb_cb *cb = mctp_cb(skb);
650         struct mctp_hdr *hdr;
651         unsigned long flags;
652         unsigned int mtu;
653         mctp_eid_t saddr;
654         int rc;
655         u8 tag;
656
657         if (WARN_ON(!rt->dev))
658                 return -EINVAL;
659
660         spin_lock_irqsave(&rt->dev->addrs_lock, flags);
661         if (rt->dev->num_addrs == 0) {
662                 rc = -EHOSTUNREACH;
663         } else {
664                 /* use the outbound interface's first address as our source */
665                 saddr = rt->dev->addrs[0];
666                 rc = 0;
667         }
668         spin_unlock_irqrestore(&rt->dev->addrs_lock, flags);
669
670         if (rc)
671                 return rc;
672
673         if (req_tag & MCTP_HDR_FLAG_TO) {
674                 rc = mctp_alloc_local_tag(msk, saddr, daddr, &tag);
675                 if (rc)
676                         return rc;
677                 tag |= MCTP_HDR_FLAG_TO;
678         } else {
679                 tag = req_tag;
680         }
681
682
683         skb->protocol = htons(ETH_P_MCTP);
684         skb->priority = 0;
685         skb_reset_transport_header(skb);
686         skb_push(skb, sizeof(struct mctp_hdr));
687         skb_reset_network_header(skb);
688         skb->dev = rt->dev->dev;
689
690         /* cb->net will have been set on initial ingress */
691         cb->src = saddr;
692
693         /* set up common header fields */
694         hdr = mctp_hdr(skb);
695         hdr->ver = 1;
696         hdr->dest = daddr;
697         hdr->src = saddr;
698
699         mtu = mctp_route_mtu(rt);
700
701         if (skb->len + sizeof(struct mctp_hdr) <= mtu) {
702                 hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM |
703                         tag;
704                 return mctp_do_route(rt, skb);
705         } else {
706                 return mctp_do_fragment_route(rt, skb, mtu, tag);
707         }
708 }
709
710 /* route management */
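/* A route covers the EID range [daddr_start, daddr_start + daddr_extent]
 * on the given device. RTN_LOCAL routes deliver through
 * mctp_route_input(); RTN_UNICAST routes transmit through
 * mctp_route_output(). Adding a route identical to an existing one (same
 * net, min and max) fails with -EEXIST.
 */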
711 static int mctp_route_add(struct mctp_dev *mdev, mctp_eid_t daddr_start,
712                           unsigned int daddr_extent, unsigned int mtu,
713                           unsigned char type)
714 {
715         int (*rtfn)(struct mctp_route *rt, struct sk_buff *skb);
716         struct net *net = dev_net(mdev->dev);
717         struct mctp_route *rt, *ert;
718
719         if (!mctp_address_ok(daddr_start))
720                 return -EINVAL;
721
722         if (daddr_extent > 0xff || daddr_start + daddr_extent >= 255)
723                 return -EINVAL;
724
725         switch (type) {
726         case RTN_LOCAL:
727                 rtfn = mctp_route_input;
728                 break;
729         case RTN_UNICAST:
730                 rtfn = mctp_route_output;
731                 break;
732         default:
733                 return -EINVAL;
734         }
735
736         rt = mctp_route_alloc();
737         if (!rt)
738                 return -ENOMEM;
739
740         rt->min = daddr_start;
741         rt->max = daddr_start + daddr_extent;
742         rt->mtu = mtu;
743         rt->dev = mdev;
744         dev_hold(rt->dev->dev);
745         rt->type = type;
746         rt->output = rtfn;
747
748         ASSERT_RTNL();
749         /* Prevent duplicate identical routes. */
750         list_for_each_entry(ert, &net->mctp.routes, list) {
751                 if (mctp_rt_compare_exact(rt, ert)) {
752                         mctp_route_release(rt);
753                         return -EEXIST;
754                 }
755         }
756
757         list_add_rcu(&rt->list, &net->mctp.routes);
758
759         return 0;
760 }
761
762 static int mctp_route_remove(struct mctp_dev *mdev, mctp_eid_t daddr_start,
763                              unsigned int daddr_extent, unsigned char type)
764 {
765         struct net *net = dev_net(mdev->dev);
766         struct mctp_route *rt, *tmp;
767         mctp_eid_t daddr_end;
768         bool dropped;
769
770         if (daddr_extent > 0xff || daddr_start + daddr_extent >= 255)
771                 return -EINVAL;
772
773         daddr_end = daddr_start + daddr_extent;
774         dropped = false;
775
776         ASSERT_RTNL();
777
778         list_for_each_entry_safe(rt, tmp, &net->mctp.routes, list) {
779                 if (rt->dev == mdev &&
780                     rt->min == daddr_start && rt->max == daddr_end &&
781                     rt->type == type) {
782                         list_del_rcu(&rt->list);
783                         /* TODO: immediate RTM_DELROUTE */
784                         mctp_route_release(rt);
785                         dropped = true;
786                 }
787         }
788
789         return dropped ? 0 : -ENOENT;
790 }
791
792 int mctp_route_add_local(struct mctp_dev *mdev, mctp_eid_t addr)
793 {
794         return mctp_route_add(mdev, addr, 0, 0, RTN_LOCAL);
795 }
796
797 int mctp_route_remove_local(struct mctp_dev *mdev, mctp_eid_t addr)
798 {
799         return mctp_route_remove(mdev, addr, 0, RTN_LOCAL);
800 }
801
802 /* removes all entries for a given device */
803 void mctp_route_remove_dev(struct mctp_dev *mdev)
804 {
805         struct net *net = dev_net(mdev->dev);
806         struct mctp_route *rt, *tmp;
807
808         ASSERT_RTNL();
809         list_for_each_entry_safe(rt, tmp, &net->mctp.routes, list) {
810                 if (rt->dev == mdev) {
811                         list_del_rcu(&rt->list);
812                         /* TODO: immediate RTM_DELROUTE */
813                         mctp_route_release(rt);
814                 }
815         }
816 }
817
818 /* Incoming packet-handling */
819
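/* Receive handler for ETH_P_MCTP frames, registered through
 * mctp_packet_type: checks that the device is ARPHRD_MCTP and that the
 * header version is supported, records the device's MCTP net in the skb
 * control block, then routes on the destination EID.
 */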
820 static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev,
821                                 struct packet_type *pt,
822                                 struct net_device *orig_dev)
823 {
824         struct net *net = dev_net(dev);
825         struct mctp_skb_cb *cb;
826         struct mctp_route *rt;
827         struct mctp_hdr *mh;
828
829         /* basic non-data sanity checks */
830         if (dev->type != ARPHRD_MCTP)
831                 goto err_drop;
832
833         if (!pskb_may_pull(skb, sizeof(struct mctp_hdr)))
834                 goto err_drop;
835
836         skb_reset_transport_header(skb);
837         skb_reset_network_header(skb);
838
839         /* We have enough for a header; decode and route */
840         mh = mctp_hdr(skb);
841         if (mh->ver < MCTP_VER_MIN || mh->ver > MCTP_VER_MAX)
842                 goto err_drop;
843
844         cb = __mctp_cb(skb);
845         rcu_read_lock();
846         cb->net = READ_ONCE(__mctp_dev_get(dev)->net);
847         rcu_read_unlock();
848
849         rt = mctp_route_lookup(net, cb->net, mh->dest);
850         if (!rt)
851                 goto err_drop;
852
853         mctp_do_route(rt, skb);
854
855         return NET_RX_SUCCESS;
856
857 err_drop:
858         kfree_skb(skb);
859         return NET_RX_DROP;
860 }
861
862 static struct packet_type mctp_packet_type = {
863         .type = cpu_to_be16(ETH_P_MCTP),
864         .func = mctp_pkttype_receive,
865 };
866
867 /* netlink interface */
868
869 static const struct nla_policy rta_mctp_policy[RTA_MAX + 1] = {
870         [RTA_DST]               = { .type = NLA_U8 },
871         [RTA_METRICS]           = { .type = NLA_NESTED },
872         [RTA_OIF]               = { .type = NLA_U32 },
873 };
874
875 /* Common part for RTM_NEWROUTE and RTM_DELROUTE parsing.
876  * tb must hold RTA_MAX+1 elements.
877  */
878 static int mctp_route_nlparse(struct sk_buff *skb, struct nlmsghdr *nlh,
879                               struct netlink_ext_ack *extack,
880                               struct nlattr **tb, struct rtmsg **rtm,
881                               struct mctp_dev **mdev, mctp_eid_t *daddr_start)
882 {
883         struct net *net = sock_net(skb->sk);
884         struct net_device *dev;
885         unsigned int ifindex;
886         int rc;
887
888         rc = nlmsg_parse(nlh, sizeof(struct rtmsg), tb, RTA_MAX,
889                          rta_mctp_policy, extack);
890         if (rc < 0) {
891                 NL_SET_ERR_MSG(extack, "incorrect format");
892                 return rc;
893         }
894
895         if (!tb[RTA_DST]) {
896                 NL_SET_ERR_MSG(extack, "dst EID missing");
897                 return -EINVAL;
898         }
899         *daddr_start = nla_get_u8(tb[RTA_DST]);
900
901         if (!tb[RTA_OIF]) {
902                 NL_SET_ERR_MSG(extack, "ifindex missing");
903                 return -EINVAL;
904         }
905         ifindex = nla_get_u32(tb[RTA_OIF]);
906
907         *rtm = nlmsg_data(nlh);
908         if ((*rtm)->rtm_family != AF_MCTP) {
909                 NL_SET_ERR_MSG(extack, "route family must be AF_MCTP");
910                 return -EINVAL;
911         }
912
913         dev = __dev_get_by_index(net, ifindex);
914         if (!dev) {
915                 NL_SET_ERR_MSG(extack, "bad ifindex");
916                 return -ENODEV;
917         }
918         *mdev = mctp_dev_get_rtnl(dev);
919         if (!*mdev)
920                 return -ENODEV;
921
922         if (dev->flags & IFF_LOOPBACK) {
923                 NL_SET_ERR_MSG(extack, "no routes to loopback");
924                 return -EINVAL;
925         }
926
927         return 0;
928 }
929
930 static int mctp_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
931                          struct netlink_ext_ack *extack)
932 {
933         struct nlattr *tb[RTA_MAX + 1];
934         mctp_eid_t daddr_start;
935         struct mctp_dev *mdev;
936         struct rtmsg *rtm;
937         unsigned int mtu;
938         int rc;
939
940         rc = mctp_route_nlparse(skb, nlh, extack, tb,
941                                 &rtm, &mdev, &daddr_start);
942         if (rc < 0)
943                 return rc;
944
945         if (rtm->rtm_type != RTN_UNICAST) {
946                 NL_SET_ERR_MSG(extack, "rtm_type must be RTN_UNICAST");
947                 return -EINVAL;
948         }
949
950         /* TODO: parse mtu from nlparse */
951         mtu = 0;
952
953         if (rtm->rtm_type != RTN_UNICAST)
954                 return -EINVAL;
955
956         rc = mctp_route_add(mdev, daddr_start, rtm->rtm_dst_len, mtu,
957                             rtm->rtm_type);
958         return rc;
959 }
960
961 static int mctp_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
962                          struct netlink_ext_ack *extack)
963 {
964         struct nlattr *tb[RTA_MAX + 1];
965         mctp_eid_t daddr_start;
966         struct mctp_dev *mdev;
967         struct rtmsg *rtm;
968         int rc;
969
970         rc = mctp_route_nlparse(skb, nlh, extack, tb,
971                                 &rtm, &mdev, &daddr_start);
972         if (rc < 0)
973                 return rc;
974
975         /* we only have unicast routes */
976         if (rtm->rtm_type != RTN_UNICAST)
977                 return -EINVAL;
978
979         rc = mctp_route_remove(mdev, daddr_start, rtm->rtm_dst_len, RTN_UNICAST);
980         return rc;
981 }
982
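/* Fill an RTM_NEWROUTE message for a route: RTA_DST carries the first
 * EID of the range and rtm_dst_len the extent (a count of EIDs, not
 * bits), RTAX_MTU is included when the route has an explicit MTU, and
 * RTA_OIF names the output interface.
 */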
983 static int mctp_fill_rtinfo(struct sk_buff *skb, struct mctp_route *rt,
984                             u32 portid, u32 seq, int event, unsigned int flags)
985 {
986         struct nlmsghdr *nlh;
987         struct rtmsg *hdr;
988         void *metrics;
989
990         nlh = nlmsg_put(skb, portid, seq, event, sizeof(*hdr), flags);
991         if (!nlh)
992                 return -EMSGSIZE;
993
994         hdr = nlmsg_data(nlh);
995         hdr->rtm_family = AF_MCTP;
996
997         /* we use the _len fields as a number of EIDs, rather than
998          * a number of bits in the address
999          */
1000         hdr->rtm_dst_len = rt->max - rt->min;
1001         hdr->rtm_src_len = 0;
1002         hdr->rtm_tos = 0;
1003         hdr->rtm_table = RT_TABLE_DEFAULT;
1004         hdr->rtm_protocol = RTPROT_STATIC; /* everything is user-defined */
1005         hdr->rtm_scope = RT_SCOPE_LINK; /* TODO: scope in mctp_route? */
1006         hdr->rtm_type = rt->type;
1007
1008         if (nla_put_u8(skb, RTA_DST, rt->min))
1009                 goto cancel;
1010
1011         metrics = nla_nest_start_noflag(skb, RTA_METRICS);
1012         if (!metrics)
1013                 goto cancel;
1014
1015         if (rt->mtu) {
1016                 if (nla_put_u32(skb, RTAX_MTU, rt->mtu))
1017                         goto cancel;
1018         }
1019
1020         nla_nest_end(skb, metrics);
1021
1022         if (rt->dev) {
1023                 if (nla_put_u32(skb, RTA_OIF, rt->dev->dev->ifindex))
1024                         goto cancel;
1025         }
1026
1027         /* TODO: conditional neighbour physaddr? */
1028
1029         nlmsg_end(skb, nlh);
1030
1031         return 0;
1032
1033 cancel:
1034         nlmsg_cancel(skb, nlh);
1035         return -EMSGSIZE;
1036 }
1037
1038 static int mctp_dump_rtinfo(struct sk_buff *skb, struct netlink_callback *cb)
1039 {
1040         struct net *net = sock_net(skb->sk);
1041         struct mctp_route *rt;
1042         int s_idx, idx;
1043
1044         /* TODO: allow filtering on route data, possibly under
1045          * cb->strict_check
1046          */
1047
1048         /* TODO: change to struct overlay */
1049         s_idx = cb->args[0];
1050         idx = 0;
1051
1052         rcu_read_lock();
1053         list_for_each_entry_rcu(rt, &net->mctp.routes, list) {
1054                 if (idx++ < s_idx)
1055                         continue;
1056                 if (mctp_fill_rtinfo(skb, rt,
1057                                      NETLINK_CB(cb->skb).portid,
1058                                      cb->nlh->nlmsg_seq,
1059                                      RTM_NEWROUTE, NLM_F_MULTI) < 0)
1060                         break;
1061         }
1062
1063         rcu_read_unlock();
1064         cb->args[0] = idx;
1065
1066         return skb->len;
1067 }
1068
1069 /* net namespace implementation */
1070 static int __net_init mctp_routes_net_init(struct net *net)
1071 {
1072         struct netns_mctp *ns = &net->mctp;
1073
1074         INIT_LIST_HEAD(&ns->routes);
1075         INIT_HLIST_HEAD(&ns->binds);
1076         mutex_init(&ns->bind_lock);
1077         INIT_HLIST_HEAD(&ns->keys);
1078         spin_lock_init(&ns->keys_lock);
1079         WARN_ON(mctp_default_net_set(net, MCTP_INITIAL_DEFAULT_NET));
1080         return 0;
1081 }
1082
1083 static void __net_exit mctp_routes_net_exit(struct net *net)
1084 {
1085         struct mctp_route *rt;
1086
1087         rcu_read_lock();
1088         list_for_each_entry_rcu(rt, &net->mctp.routes, list)
1089                 mctp_route_release(rt);
1090         rcu_read_unlock();
1091 }
1092
1093 static struct pernet_operations mctp_net_ops = {
1094         .init = mctp_routes_net_init,
1095         .exit = mctp_routes_net_exit,
1096 };
1097
1098 int __init mctp_routes_init(void)
1099 {
1100         dev_add_pack(&mctp_packet_type);
1101
1102         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_GETROUTE,
1103                              NULL, mctp_dump_rtinfo, 0);
1104         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_NEWROUTE,
1105                              mctp_newroute, NULL, 0);
1106         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_DELROUTE,
1107                              mctp_delroute, NULL, 0);
1108
1109         return register_pernet_subsys(&mctp_net_ops);
1110 }
1111
1112 void __exit mctp_routes_exit(void)
1113 {
1114         unregister_pernet_subsys(&mctp_net_ops);
1115         rtnl_unregister(PF_MCTP, RTM_DELROUTE);
1116         rtnl_unregister(PF_MCTP, RTM_NEWROUTE);
1117         rtnl_unregister(PF_MCTP, RTM_GETROUTE);
1118         dev_remove_pack(&mctp_packet_type);
1119 }