Merge tag 'block-6.1-2022-11-11' of git://git.kernel.dk/linux
[platform/kernel/linux-starfive.git] / net / mctp / device.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Management Component Transport Protocol (MCTP) - device implementation.
4  *
5  * Copyright (c) 2021 Code Construct
6  * Copyright (c) 2021 Google
7  */
8
9 #include <linux/if_arp.h>
10 #include <linux/if_link.h>
11 #include <linux/mctp.h>
12 #include <linux/netdevice.h>
13 #include <linux/rcupdate.h>
14 #include <linux/rtnetlink.h>
15
16 #include <net/addrconf.h>
17 #include <net/netlink.h>
18 #include <net/mctp.h>
19 #include <net/mctpdevice.h>
20 #include <net/sock.h>
21
22 struct mctp_dump_cb {
23         int h;
24         int idx;
25         size_t a_idx;
26 };
27
28 /* unlocked: caller must hold rcu_read_lock.
29  * Returned mctp_dev has its refcount incremented, or NULL if unset.
30  */
31 struct mctp_dev *__mctp_dev_get(const struct net_device *dev)
32 {
33         struct mctp_dev *mdev = rcu_dereference(dev->mctp_ptr);
34
35         /* RCU guarantees that any mdev is still live.
36          * Zero refcount implies a pending free, return NULL.
37          */
38         if (mdev)
39                 if (!refcount_inc_not_zero(&mdev->refs))
40                         return NULL;
41         return mdev;
42 }
43
44 /* Returned mctp_dev does not have refcount incremented. The returned pointer
45  * remains live while rtnl_lock is held, as that prevents mctp_unregister()
46  */
47 struct mctp_dev *mctp_dev_get_rtnl(const struct net_device *dev)
48 {
49         return rtnl_dereference(dev->mctp_ptr);
50 }
51
52 static int mctp_addrinfo_size(void)
53 {
54         return NLMSG_ALIGN(sizeof(struct ifaddrmsg))
55                 + nla_total_size(1) // IFA_LOCAL
56                 + nla_total_size(1) // IFA_ADDRESS
57                 ;
58 }
59
60 /* flag should be NLM_F_MULTI for dump calls */
61 static int mctp_fill_addrinfo(struct sk_buff *skb,
62                               struct mctp_dev *mdev, mctp_eid_t eid,
63                               int msg_type, u32 portid, u32 seq, int flag)
64 {
65         struct ifaddrmsg *hdr;
66         struct nlmsghdr *nlh;
67
68         nlh = nlmsg_put(skb, portid, seq,
69                         msg_type, sizeof(*hdr), flag);
70         if (!nlh)
71                 return -EMSGSIZE;
72
73         hdr = nlmsg_data(nlh);
74         hdr->ifa_family = AF_MCTP;
75         hdr->ifa_prefixlen = 0;
76         hdr->ifa_flags = 0;
77         hdr->ifa_scope = 0;
78         hdr->ifa_index = mdev->dev->ifindex;
79
80         if (nla_put_u8(skb, IFA_LOCAL, eid))
81                 goto cancel;
82
83         if (nla_put_u8(skb, IFA_ADDRESS, eid))
84                 goto cancel;
85
86         nlmsg_end(skb, nlh);
87
88         return 0;
89
90 cancel:
91         nlmsg_cancel(skb, nlh);
92         return -EMSGSIZE;
93 }
94
95 static int mctp_dump_dev_addrinfo(struct mctp_dev *mdev, struct sk_buff *skb,
96                                   struct netlink_callback *cb)
97 {
98         struct mctp_dump_cb *mcb = (void *)cb->ctx;
99         u32 portid, seq;
100         int rc = 0;
101
102         portid = NETLINK_CB(cb->skb).portid;
103         seq = cb->nlh->nlmsg_seq;
104         for (; mcb->a_idx < mdev->num_addrs; mcb->a_idx++) {
105                 rc = mctp_fill_addrinfo(skb, mdev, mdev->addrs[mcb->a_idx],
106                                         RTM_NEWADDR, portid, seq, NLM_F_MULTI);
107                 if (rc < 0)
108                         break;
109         }
110
111         return rc;
112 }
113
114 static int mctp_dump_addrinfo(struct sk_buff *skb, struct netlink_callback *cb)
115 {
116         struct mctp_dump_cb *mcb = (void *)cb->ctx;
117         struct net *net = sock_net(skb->sk);
118         struct hlist_head *head;
119         struct net_device *dev;
120         struct ifaddrmsg *hdr;
121         struct mctp_dev *mdev;
122         int ifindex;
123         int idx = 0, rc;
124
125         hdr = nlmsg_data(cb->nlh);
126         // filter by ifindex if requested
127         ifindex = hdr->ifa_index;
128
129         rcu_read_lock();
130         for (; mcb->h < NETDEV_HASHENTRIES; mcb->h++, mcb->idx = 0) {
131                 idx = 0;
132                 head = &net->dev_index_head[mcb->h];
133                 hlist_for_each_entry_rcu(dev, head, index_hlist) {
134                         if (idx >= mcb->idx &&
135                             (ifindex == 0 || ifindex == dev->ifindex)) {
136                                 mdev = __mctp_dev_get(dev);
137                                 if (mdev) {
138                                         rc = mctp_dump_dev_addrinfo(mdev,
139                                                                     skb, cb);
140                                         mctp_dev_put(mdev);
141                                         // Error indicates full buffer, this
142                                         // callback will get retried.
143                                         if (rc < 0)
144                                                 goto out;
145                                 }
146                         }
147                         idx++;
148                         // reset for next iteration
149                         mcb->a_idx = 0;
150                 }
151         }
152 out:
153         rcu_read_unlock();
154         mcb->idx = idx;
155
156         return skb->len;
157 }
158
159 static void mctp_addr_notify(struct mctp_dev *mdev, mctp_eid_t eid, int msg_type,
160                              struct sk_buff *req_skb, struct nlmsghdr *req_nlh)
161 {
162         u32 portid = NETLINK_CB(req_skb).portid;
163         struct net *net = dev_net(mdev->dev);
164         struct sk_buff *skb;
165         int rc = -ENOBUFS;
166
167         skb = nlmsg_new(mctp_addrinfo_size(), GFP_KERNEL);
168         if (!skb)
169                 goto out;
170
171         rc = mctp_fill_addrinfo(skb, mdev, eid, msg_type,
172                                 portid, req_nlh->nlmsg_seq, 0);
173         if (rc < 0) {
174                 WARN_ON_ONCE(rc == -EMSGSIZE);
175                 goto out;
176         }
177
178         rtnl_notify(skb, net, portid, RTNLGRP_MCTP_IFADDR, req_nlh, GFP_KERNEL);
179         return;
180 out:
181         kfree_skb(skb);
182         rtnl_set_sk_err(net, RTNLGRP_MCTP_IFADDR, rc);
183 }
184
185 static const struct nla_policy ifa_mctp_policy[IFA_MAX + 1] = {
186         [IFA_ADDRESS]           = { .type = NLA_U8 },
187         [IFA_LOCAL]             = { .type = NLA_U8 },
188 };
189
190 static int mctp_rtm_newaddr(struct sk_buff *skb, struct nlmsghdr *nlh,
191                             struct netlink_ext_ack *extack)
192 {
193         struct net *net = sock_net(skb->sk);
194         struct nlattr *tb[IFA_MAX + 1];
195         struct net_device *dev;
196         struct mctp_addr *addr;
197         struct mctp_dev *mdev;
198         struct ifaddrmsg *ifm;
199         unsigned long flags;
200         u8 *tmp_addrs;
201         int rc;
202
203         rc = nlmsg_parse(nlh, sizeof(*ifm), tb, IFA_MAX, ifa_mctp_policy,
204                          extack);
205         if (rc < 0)
206                 return rc;
207
208         ifm = nlmsg_data(nlh);
209
210         if (tb[IFA_LOCAL])
211                 addr = nla_data(tb[IFA_LOCAL]);
212         else if (tb[IFA_ADDRESS])
213                 addr = nla_data(tb[IFA_ADDRESS]);
214         else
215                 return -EINVAL;
216
217         /* find device */
218         dev = __dev_get_by_index(net, ifm->ifa_index);
219         if (!dev)
220                 return -ENODEV;
221
222         mdev = mctp_dev_get_rtnl(dev);
223         if (!mdev)
224                 return -ENODEV;
225
226         if (!mctp_address_unicast(addr->s_addr))
227                 return -EINVAL;
228
229         /* Prevent duplicates. Under RTNL so don't need to lock for reading */
230         if (memchr(mdev->addrs, addr->s_addr, mdev->num_addrs))
231                 return -EEXIST;
232
233         tmp_addrs = kmalloc(mdev->num_addrs + 1, GFP_KERNEL);
234         if (!tmp_addrs)
235                 return -ENOMEM;
236         memcpy(tmp_addrs, mdev->addrs, mdev->num_addrs);
237         tmp_addrs[mdev->num_addrs] = addr->s_addr;
238
239         /* Lock to write */
240         spin_lock_irqsave(&mdev->addrs_lock, flags);
241         mdev->num_addrs++;
242         swap(mdev->addrs, tmp_addrs);
243         spin_unlock_irqrestore(&mdev->addrs_lock, flags);
244
245         kfree(tmp_addrs);
246
247         mctp_addr_notify(mdev, addr->s_addr, RTM_NEWADDR, skb, nlh);
248         mctp_route_add_local(mdev, addr->s_addr);
249
250         return 0;
251 }
252
253 static int mctp_rtm_deladdr(struct sk_buff *skb, struct nlmsghdr *nlh,
254                             struct netlink_ext_ack *extack)
255 {
256         struct net *net = sock_net(skb->sk);
257         struct nlattr *tb[IFA_MAX + 1];
258         struct net_device *dev;
259         struct mctp_addr *addr;
260         struct mctp_dev *mdev;
261         struct ifaddrmsg *ifm;
262         unsigned long flags;
263         u8 *pos;
264         int rc;
265
266         rc = nlmsg_parse(nlh, sizeof(*ifm), tb, IFA_MAX, ifa_mctp_policy,
267                          extack);
268         if (rc < 0)
269                 return rc;
270
271         ifm = nlmsg_data(nlh);
272
273         if (tb[IFA_LOCAL])
274                 addr = nla_data(tb[IFA_LOCAL]);
275         else if (tb[IFA_ADDRESS])
276                 addr = nla_data(tb[IFA_ADDRESS]);
277         else
278                 return -EINVAL;
279
280         /* find device */
281         dev = __dev_get_by_index(net, ifm->ifa_index);
282         if (!dev)
283                 return -ENODEV;
284
285         mdev = mctp_dev_get_rtnl(dev);
286         if (!mdev)
287                 return -ENODEV;
288
289         pos = memchr(mdev->addrs, addr->s_addr, mdev->num_addrs);
290         if (!pos)
291                 return -ENOENT;
292
293         rc = mctp_route_remove_local(mdev, addr->s_addr);
294         // we can ignore -ENOENT in the case a route was already removed
295         if (rc < 0 && rc != -ENOENT)
296                 return rc;
297
298         spin_lock_irqsave(&mdev->addrs_lock, flags);
299         memmove(pos, pos + 1, mdev->num_addrs - 1 - (pos - mdev->addrs));
300         mdev->num_addrs--;
301         spin_unlock_irqrestore(&mdev->addrs_lock, flags);
302
303         mctp_addr_notify(mdev, addr->s_addr, RTM_DELADDR, skb, nlh);
304
305         return 0;
306 }
307
308 void mctp_dev_hold(struct mctp_dev *mdev)
309 {
310         refcount_inc(&mdev->refs);
311 }
312
313 void mctp_dev_put(struct mctp_dev *mdev)
314 {
315         if (mdev && refcount_dec_and_test(&mdev->refs)) {
316                 kfree(mdev->addrs);
317                 dev_put(mdev->dev);
318                 kfree_rcu(mdev, rcu);
319         }
320 }
321
322 void mctp_dev_release_key(struct mctp_dev *dev, struct mctp_sk_key *key)
323         __must_hold(&key->lock)
324 {
325         if (!dev)
326                 return;
327         if (dev->ops && dev->ops->release_flow)
328                 dev->ops->release_flow(dev, key);
329         key->dev = NULL;
330         mctp_dev_put(dev);
331 }
332
333 void mctp_dev_set_key(struct mctp_dev *dev, struct mctp_sk_key *key)
334         __must_hold(&key->lock)
335 {
336         mctp_dev_hold(dev);
337         key->dev = dev;
338 }
339
340 static struct mctp_dev *mctp_add_dev(struct net_device *dev)
341 {
342         struct mctp_dev *mdev;
343
344         ASSERT_RTNL();
345
346         mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
347         if (!mdev)
348                 return ERR_PTR(-ENOMEM);
349
350         spin_lock_init(&mdev->addrs_lock);
351
352         mdev->net = mctp_default_net(dev_net(dev));
353
354         /* associate to net_device */
355         refcount_set(&mdev->refs, 1);
356         rcu_assign_pointer(dev->mctp_ptr, mdev);
357
358         dev_hold(dev);
359         mdev->dev = dev;
360
361         return mdev;
362 }
363
364 static int mctp_fill_link_af(struct sk_buff *skb,
365                              const struct net_device *dev, u32 ext_filter_mask)
366 {
367         struct mctp_dev *mdev;
368
369         mdev = mctp_dev_get_rtnl(dev);
370         if (!mdev)
371                 return -ENODATA;
372         if (nla_put_u32(skb, IFLA_MCTP_NET, mdev->net))
373                 return -EMSGSIZE;
374         return 0;
375 }
376
377 static size_t mctp_get_link_af_size(const struct net_device *dev,
378                                     u32 ext_filter_mask)
379 {
380         struct mctp_dev *mdev;
381         unsigned int ret;
382
383         /* caller holds RCU */
384         mdev = __mctp_dev_get(dev);
385         if (!mdev)
386                 return 0;
387         ret = nla_total_size(4); /* IFLA_MCTP_NET */
388         mctp_dev_put(mdev);
389         return ret;
390 }
391
392 static const struct nla_policy ifla_af_mctp_policy[IFLA_MCTP_MAX + 1] = {
393         [IFLA_MCTP_NET]         = { .type = NLA_U32 },
394 };
395
396 static int mctp_set_link_af(struct net_device *dev, const struct nlattr *attr,
397                             struct netlink_ext_ack *extack)
398 {
399         struct nlattr *tb[IFLA_MCTP_MAX + 1];
400         struct mctp_dev *mdev;
401         int rc;
402
403         rc = nla_parse_nested(tb, IFLA_MCTP_MAX, attr, ifla_af_mctp_policy,
404                               NULL);
405         if (rc)
406                 return rc;
407
408         mdev = mctp_dev_get_rtnl(dev);
409         if (!mdev)
410                 return 0;
411
412         if (tb[IFLA_MCTP_NET])
413                 WRITE_ONCE(mdev->net, nla_get_u32(tb[IFLA_MCTP_NET]));
414
415         return 0;
416 }
417
418 /* Matches netdev types that should have MCTP handling */
419 static bool mctp_known(struct net_device *dev)
420 {
421         /* only register specific types (inc. NONE for TUN devices) */
422         return dev->type == ARPHRD_MCTP ||
423                    dev->type == ARPHRD_LOOPBACK ||
424                    dev->type == ARPHRD_NONE;
425 }
426
427 static void mctp_unregister(struct net_device *dev)
428 {
429         struct mctp_dev *mdev;
430
431         mdev = mctp_dev_get_rtnl(dev);
432         if (mdev && !mctp_known(dev)) {
433                 // Sanity check, should match what was set in mctp_register
434                 netdev_warn(dev, "%s: BUG mctp_ptr set for unknown type %d",
435                             __func__, dev->type);
436                 return;
437         }
438         if (!mdev)
439                 return;
440
441         RCU_INIT_POINTER(mdev->dev->mctp_ptr, NULL);
442
443         mctp_route_remove_dev(mdev);
444         mctp_neigh_remove_dev(mdev);
445
446         mctp_dev_put(mdev);
447 }
448
449 static int mctp_register(struct net_device *dev)
450 {
451         struct mctp_dev *mdev;
452
453         /* Already registered? */
454         mdev = rtnl_dereference(dev->mctp_ptr);
455
456         if (mdev) {
457                 if (!mctp_known(dev))
458                         netdev_warn(dev, "%s: BUG mctp_ptr set for unknown type %d",
459                                     __func__, dev->type);
460                 return 0;
461         }
462
463         /* only register specific types */
464         if (!mctp_known(dev))
465                 return 0;
466
467         mdev = mctp_add_dev(dev);
468         if (IS_ERR(mdev))
469                 return PTR_ERR(mdev);
470
471         return 0;
472 }
473
474 static int mctp_dev_notify(struct notifier_block *this, unsigned long event,
475                            void *ptr)
476 {
477         struct net_device *dev = netdev_notifier_info_to_dev(ptr);
478         int rc;
479
480         switch (event) {
481         case NETDEV_REGISTER:
482                 rc = mctp_register(dev);
483                 if (rc)
484                         return notifier_from_errno(rc);
485                 break;
486         case NETDEV_UNREGISTER:
487                 mctp_unregister(dev);
488                 break;
489         }
490
491         return NOTIFY_OK;
492 }
493
494 static int mctp_register_netdevice(struct net_device *dev,
495                                    const struct mctp_netdev_ops *ops)
496 {
497         struct mctp_dev *mdev;
498
499         mdev = mctp_add_dev(dev);
500         if (IS_ERR(mdev))
501                 return PTR_ERR(mdev);
502
503         mdev->ops = ops;
504
505         return register_netdevice(dev);
506 }
507
508 int mctp_register_netdev(struct net_device *dev,
509                          const struct mctp_netdev_ops *ops)
510 {
511         int rc;
512
513         rtnl_lock();
514         rc = mctp_register_netdevice(dev, ops);
515         rtnl_unlock();
516
517         return rc;
518 }
519 EXPORT_SYMBOL_GPL(mctp_register_netdev);
520
521 void mctp_unregister_netdev(struct net_device *dev)
522 {
523         unregister_netdev(dev);
524 }
525 EXPORT_SYMBOL_GPL(mctp_unregister_netdev);
526
527 static struct rtnl_af_ops mctp_af_ops = {
528         .family = AF_MCTP,
529         .fill_link_af = mctp_fill_link_af,
530         .get_link_af_size = mctp_get_link_af_size,
531         .set_link_af = mctp_set_link_af,
532 };
533
534 static struct notifier_block mctp_dev_nb = {
535         .notifier_call = mctp_dev_notify,
536         .priority = ADDRCONF_NOTIFY_PRIORITY,
537 };
538
539 void __init mctp_device_init(void)
540 {
541         register_netdevice_notifier(&mctp_dev_nb);
542
543         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_GETADDR,
544                              NULL, mctp_dump_addrinfo, 0);
545         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_NEWADDR,
546                              mctp_rtm_newaddr, NULL, 0);
547         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_DELADDR,
548                              mctp_rtm_deladdr, NULL, 0);
549         rtnl_af_register(&mctp_af_ops);
550 }
551
552 void __exit mctp_device_exit(void)
553 {
554         rtnl_af_unregister(&mctp_af_ops);
555         rtnl_unregister(PF_MCTP, RTM_DELADDR);
556         rtnl_unregister(PF_MCTP, RTM_NEWADDR);
557         rtnl_unregister(PF_MCTP, RTM_GETADDR);
558
559         unregister_netdevice_notifier(&mctp_dev_nb);
560 }