Merge tag 'net-5.18-rc4' of git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net
[platform/kernel/linux-starfive.git] / net / mctp / neigh.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
 * Management Component Transport Protocol (MCTP) - neighbour table
 * implementation.
 *
 * Neighbour entries are kept in a simple per-namespace list, with no lookup
 * cache. The number of neighbours should stay fairly small, so the linear
 * lookup cost is small.
9  * Copyright (c) 2021 Code Construct
10  * Copyright (c) 2021 Google
11  */
12
13 #include <linux/idr.h>
14 #include <linux/mctp.h>
15 #include <linux/netdevice.h>
16 #include <linux/rtnetlink.h>
17 #include <linux/skbuff.h>
18
19 #include <net/mctp.h>
20 #include <net/mctpdevice.h>
21 #include <net/netlink.h>
22 #include <net/sock.h>
23
24 static int mctp_neigh_add(struct mctp_dev *mdev, mctp_eid_t eid,
25                           enum mctp_neigh_source source,
26                           size_t lladdr_len, const void *lladdr)
27 {
28         struct net *net = dev_net(mdev->dev);
29         struct mctp_neigh *neigh;
30         int rc;
31
32         mutex_lock(&net->mctp.neigh_lock);
33         if (mctp_neigh_lookup(mdev, eid, NULL) == 0) {
34                 rc = -EEXIST;
35                 goto out;
36         }
37
38         if (lladdr_len > sizeof(neigh->ha)) {
39                 rc = -EINVAL;
40                 goto out;
41         }
42
43         neigh = kzalloc(sizeof(*neigh), GFP_KERNEL);
44         if (!neigh) {
45                 rc = -ENOMEM;
46                 goto out;
47         }
48         INIT_LIST_HEAD(&neigh->list);
49         neigh->dev = mdev;
50         mctp_dev_hold(neigh->dev);
51         neigh->eid = eid;
52         neigh->source = source;
53         memcpy(neigh->ha, lladdr, lladdr_len);
54
55         list_add_rcu(&neigh->list, &net->mctp.neighbours);
56         rc = 0;
57 out:
58         mutex_unlock(&net->mctp.neigh_lock);
59         return rc;
60 }
61
62 static void __mctp_neigh_free(struct rcu_head *rcu)
63 {
64         struct mctp_neigh *neigh = container_of(rcu, struct mctp_neigh, rcu);
65
66         mctp_dev_put(neigh->dev);
67         kfree(neigh);
68 }
69
70 /* Removes all neighbour entries referring to a device */
71 void mctp_neigh_remove_dev(struct mctp_dev *mdev)
72 {
73         struct net *net = dev_net(mdev->dev);
74         struct mctp_neigh *neigh, *tmp;
75
76         mutex_lock(&net->mctp.neigh_lock);
77         list_for_each_entry_safe(neigh, tmp, &net->mctp.neighbours, list) {
78                 if (neigh->dev == mdev) {
79                         list_del_rcu(&neigh->list);
80                         /* TODO: immediate RTM_DELNEIGH */
81                         call_rcu(&neigh->rcu, __mctp_neigh_free);
82                 }
83         }
84
85         mutex_unlock(&net->mctp.neigh_lock);
86 }
87
88 static int mctp_neigh_remove(struct mctp_dev *mdev, mctp_eid_t eid,
89                              enum mctp_neigh_source source)
90 {
91         struct net *net = dev_net(mdev->dev);
92         struct mctp_neigh *neigh, *tmp;
93         bool dropped = false;
94
95         mutex_lock(&net->mctp.neigh_lock);
96         list_for_each_entry_safe(neigh, tmp, &net->mctp.neighbours, list) {
97                 if (neigh->dev == mdev && neigh->eid == eid &&
98                     neigh->source == source) {
99                         list_del_rcu(&neigh->list);
100                         /* TODO: immediate RTM_DELNEIGH */
101                         call_rcu(&neigh->rcu, __mctp_neigh_free);
102                         dropped = true;
103                 }
104         }
105
106         mutex_unlock(&net->mctp.neigh_lock);
107         return dropped ? 0 : -ENOENT;
108 }
109
110 static const struct nla_policy nd_mctp_policy[NDA_MAX + 1] = {
111         [NDA_DST]               = { .type = NLA_U8 },
112         [NDA_LLADDR]            = { .type = NLA_BINARY, .len = MAX_ADDR_LEN },
113 };
114
115 static int mctp_rtm_newneigh(struct sk_buff *skb, struct nlmsghdr *nlh,
116                              struct netlink_ext_ack *extack)
117 {
118         struct net *net = sock_net(skb->sk);
119         struct net_device *dev;
120         struct mctp_dev *mdev;
121         struct ndmsg *ndm;
122         struct nlattr *tb[NDA_MAX + 1];
123         int rc;
124         mctp_eid_t eid;
125         void *lladdr;
126         int lladdr_len;
127
128         rc = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, nd_mctp_policy,
129                          extack);
130         if (rc < 0) {
131                 NL_SET_ERR_MSG(extack, "lladdr too large?");
132                 return rc;
133         }
134
135         if (!tb[NDA_DST]) {
136                 NL_SET_ERR_MSG(extack, "Neighbour EID must be specified");
137                 return -EINVAL;
138         }
139
140         if (!tb[NDA_LLADDR]) {
141                 NL_SET_ERR_MSG(extack, "Neighbour lladdr must be specified");
142                 return -EINVAL;
143         }
144
145         eid = nla_get_u8(tb[NDA_DST]);
146         if (!mctp_address_unicast(eid)) {
147                 NL_SET_ERR_MSG(extack, "Invalid neighbour EID");
148                 return -EINVAL;
149         }
150
151         lladdr = nla_data(tb[NDA_LLADDR]);
152         lladdr_len = nla_len(tb[NDA_LLADDR]);
153
154         ndm = nlmsg_data(nlh);
155
156         dev = __dev_get_by_index(net, ndm->ndm_ifindex);
157         if (!dev)
158                 return -ENODEV;
159
160         mdev = mctp_dev_get_rtnl(dev);
161         if (!mdev)
162                 return -ENODEV;
163
164         if (lladdr_len != dev->addr_len) {
165                 NL_SET_ERR_MSG(extack, "Wrong lladdr length");
166                 return -EINVAL;
167         }
168
169         return mctp_neigh_add(mdev, eid, MCTP_NEIGH_STATIC,
170                         lladdr_len, lladdr);
171 }
172
173 static int mctp_rtm_delneigh(struct sk_buff *skb, struct nlmsghdr *nlh,
174                              struct netlink_ext_ack *extack)
175 {
176         struct net *net = sock_net(skb->sk);
177         struct nlattr *tb[NDA_MAX + 1];
178         struct net_device *dev;
179         struct mctp_dev *mdev;
180         struct ndmsg *ndm;
181         int rc;
182         mctp_eid_t eid;
183
184         rc = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, nd_mctp_policy,
185                          extack);
186         if (rc < 0) {
187                 NL_SET_ERR_MSG(extack, "incorrect format");
188                 return rc;
189         }
190
191         if (!tb[NDA_DST]) {
192                 NL_SET_ERR_MSG(extack, "Neighbour EID must be specified");
193                 return -EINVAL;
194         }
195         eid = nla_get_u8(tb[NDA_DST]);
196
197         ndm = nlmsg_data(nlh);
198         dev = __dev_get_by_index(net, ndm->ndm_ifindex);
199         if (!dev)
200                 return -ENODEV;
201
202         mdev = mctp_dev_get_rtnl(dev);
203         if (!mdev)
204                 return -ENODEV;
205
206         return mctp_neigh_remove(mdev, eid, MCTP_NEIGH_STATIC);
207 }
208
209 static int mctp_fill_neigh(struct sk_buff *skb, u32 portid, u32 seq, int event,
210                            unsigned int flags, struct mctp_neigh *neigh)
211 {
212         struct net_device *dev = neigh->dev->dev;
213         struct nlmsghdr *nlh;
214         struct ndmsg *hdr;
215
216         nlh = nlmsg_put(skb, portid, seq, event, sizeof(*hdr), flags);
217         if (!nlh)
218                 return -EMSGSIZE;
219
220         hdr = nlmsg_data(nlh);
221         hdr->ndm_family = AF_MCTP;
222         hdr->ndm_ifindex = dev->ifindex;
223         hdr->ndm_state = 0; // TODO other state bits?
224         if (neigh->source == MCTP_NEIGH_STATIC)
225                 hdr->ndm_state |= NUD_PERMANENT;
226         hdr->ndm_flags = 0;
227         hdr->ndm_type = RTN_UNICAST; // TODO: is loopback RTN_LOCAL?
228
229         if (nla_put_u8(skb, NDA_DST, neigh->eid))
230                 goto cancel;
231
232         if (nla_put(skb, NDA_LLADDR, dev->addr_len, neigh->ha))
233                 goto cancel;
234
235         nlmsg_end(skb, nlh);
236
237         return 0;
238 cancel:
239         nlmsg_cancel(skb, nlh);
240         return -EMSGSIZE;
241 }
242
243 static int mctp_rtm_getneigh(struct sk_buff *skb, struct netlink_callback *cb)
244 {
245         struct net *net = sock_net(skb->sk);
246         int rc, idx, req_ifindex;
247         struct mctp_neigh *neigh;
248         struct ndmsg *ndmsg;
249         struct {
250                 int idx;
251         } *cbctx = (void *)cb->ctx;
252
253         ndmsg = nlmsg_data(cb->nlh);
254         req_ifindex = ndmsg->ndm_ifindex;
255
256         idx = 0;
257         rcu_read_lock();
258         list_for_each_entry_rcu(neigh, &net->mctp.neighbours, list) {
259                 if (idx < cbctx->idx)
260                         goto cont;
261
262                 rc = 0;
263                 if (req_ifindex == 0 || req_ifindex == neigh->dev->dev->ifindex)
264                         rc = mctp_fill_neigh(skb, NETLINK_CB(cb->skb).portid,
265                                              cb->nlh->nlmsg_seq,
266                                              RTM_NEWNEIGH, NLM_F_MULTI, neigh);
267
268                 if (rc)
269                         break;
270 cont:
271                 idx++;
272         }
273         rcu_read_unlock();
274
275         cbctx->idx = idx;
276         return skb->len;
277 }
278
279 int mctp_neigh_lookup(struct mctp_dev *mdev, mctp_eid_t eid, void *ret_hwaddr)
280 {
281         struct net *net = dev_net(mdev->dev);
282         struct mctp_neigh *neigh;
283         int rc = -EHOSTUNREACH; // TODO: or ENOENT?
284
285         rcu_read_lock();
286         list_for_each_entry_rcu(neigh, &net->mctp.neighbours, list) {
287                 if (mdev == neigh->dev && eid == neigh->eid) {
288                         if (ret_hwaddr)
289                                 memcpy(ret_hwaddr, neigh->ha,
290                                        sizeof(neigh->ha));
291                         rc = 0;
292                         break;
293                 }
294         }
295         rcu_read_unlock();
296         return rc;
297 }
298
299 /* namespace registration */
300 static int __net_init mctp_neigh_net_init(struct net *net)
301 {
302         struct netns_mctp *ns = &net->mctp;
303
304         INIT_LIST_HEAD(&ns->neighbours);
305         mutex_init(&ns->neigh_lock);
306         return 0;
307 }
308
309 static void __net_exit mctp_neigh_net_exit(struct net *net)
310 {
311         struct netns_mctp *ns = &net->mctp;
312         struct mctp_neigh *neigh;
313
314         list_for_each_entry(neigh, &ns->neighbours, list)
315                 call_rcu(&neigh->rcu, __mctp_neigh_free);
316 }
317
318 /* net namespace implementation */
319
320 static struct pernet_operations mctp_net_ops = {
321         .init = mctp_neigh_net_init,
322         .exit = mctp_neigh_net_exit,
323 };
324
325 int __init mctp_neigh_init(void)
326 {
327         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_NEWNEIGH,
328                              mctp_rtm_newneigh, NULL, 0);
329         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_DELNEIGH,
330                              mctp_rtm_delneigh, NULL, 0);
331         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_GETNEIGH,
332                              NULL, mctp_rtm_getneigh, 0);
333
334         return register_pernet_subsys(&mctp_net_ops);
335 }
336
337 void __exit mctp_neigh_exit(void)
338 {
339         unregister_pernet_subsys(&mctp_net_ops);
340         rtnl_unregister(PF_MCTP, RTM_GETNEIGH);
341         rtnl_unregister(PF_MCTP, RTM_DELNEIGH);
342         rtnl_unregister(PF_MCTP, RTM_NEWNEIGH);
343 }