Merge branch 'linus' into smp/urgent
[platform/kernel/linux-rpi.git] / net / mctp / neigh.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Management Component Transport Protocol (MCTP) - routing
4  * implementation.
5  *
6  * This is currently based on a simple routing table, with no dst cache. The
7  * number of routes should stay fairly small, so the lookup cost is small.
8  *
9  * Copyright (c) 2021 Code Construct
10  * Copyright (c) 2021 Google
11  */
12
13 #include <linux/idr.h>
14 #include <linux/mctp.h>
15 #include <linux/netdevice.h>
16 #include <linux/rtnetlink.h>
17 #include <linux/skbuff.h>
18
19 #include <net/mctp.h>
20 #include <net/mctpdevice.h>
21 #include <net/netlink.h>
22 #include <net/sock.h>
23
24 static int mctp_neigh_add(struct mctp_dev *mdev, mctp_eid_t eid,
25                           enum mctp_neigh_source source,
26                           size_t lladdr_len, const void *lladdr)
27 {
28         struct net *net = dev_net(mdev->dev);
29         struct mctp_neigh *neigh;
30         int rc;
31
32         mutex_lock(&net->mctp.neigh_lock);
33         if (mctp_neigh_lookup(mdev, eid, NULL) == 0) {
34                 rc = -EEXIST;
35                 goto out;
36         }
37
38         if (lladdr_len > sizeof(neigh->ha)) {
39                 rc = -EINVAL;
40                 goto out;
41         }
42
43         neigh = kzalloc(sizeof(*neigh), GFP_KERNEL);
44         if (!neigh) {
45                 rc = -ENOMEM;
46                 goto out;
47         }
48         INIT_LIST_HEAD(&neigh->list);
49         neigh->dev = mdev;
50         dev_hold(neigh->dev->dev);
51         neigh->eid = eid;
52         neigh->source = source;
53         memcpy(neigh->ha, lladdr, lladdr_len);
54
55         list_add_rcu(&neigh->list, &net->mctp.neighbours);
56         rc = 0;
57 out:
58         mutex_unlock(&net->mctp.neigh_lock);
59         return rc;
60 }
61
62 static void __mctp_neigh_free(struct rcu_head *rcu)
63 {
64         struct mctp_neigh *neigh = container_of(rcu, struct mctp_neigh, rcu);
65
66         dev_put(neigh->dev->dev);
67         kfree(neigh);
68 }
69
70 /* Removes all neighbour entries referring to a device */
71 void mctp_neigh_remove_dev(struct mctp_dev *mdev)
72 {
73         struct net *net = dev_net(mdev->dev);
74         struct mctp_neigh *neigh, *tmp;
75
76         mutex_lock(&net->mctp.neigh_lock);
77         list_for_each_entry_safe(neigh, tmp, &net->mctp.neighbours, list) {
78                 if (neigh->dev == mdev) {
79                         list_del_rcu(&neigh->list);
80                         /* TODO: immediate RTM_DELNEIGH */
81                         call_rcu(&neigh->rcu, __mctp_neigh_free);
82                 }
83         }
84
85         mutex_unlock(&net->mctp.neigh_lock);
86 }
87
88 // TODO: add a "source" flag so netlink can only delete static neighbours?
89 static int mctp_neigh_remove(struct mctp_dev *mdev, mctp_eid_t eid)
90 {
91         struct net *net = dev_net(mdev->dev);
92         struct mctp_neigh *neigh, *tmp;
93         bool dropped = false;
94
95         mutex_lock(&net->mctp.neigh_lock);
96         list_for_each_entry_safe(neigh, tmp, &net->mctp.neighbours, list) {
97                 if (neigh->dev == mdev && neigh->eid == eid) {
98                         list_del_rcu(&neigh->list);
99                         /* TODO: immediate RTM_DELNEIGH */
100                         call_rcu(&neigh->rcu, __mctp_neigh_free);
101                         dropped = true;
102                 }
103         }
104
105         mutex_unlock(&net->mctp.neigh_lock);
106         return dropped ? 0 : -ENOENT;
107 }
108
109 static const struct nla_policy nd_mctp_policy[NDA_MAX + 1] = {
110         [NDA_DST]               = { .type = NLA_U8 },
111         [NDA_LLADDR]            = { .type = NLA_BINARY, .len = MAX_ADDR_LEN },
112 };
113
114 static int mctp_rtm_newneigh(struct sk_buff *skb, struct nlmsghdr *nlh,
115                              struct netlink_ext_ack *extack)
116 {
117         struct net *net = sock_net(skb->sk);
118         struct net_device *dev;
119         struct mctp_dev *mdev;
120         struct ndmsg *ndm;
121         struct nlattr *tb[NDA_MAX + 1];
122         int rc;
123         mctp_eid_t eid;
124         void *lladdr;
125         int lladdr_len;
126
127         rc = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, nd_mctp_policy,
128                          extack);
129         if (rc < 0) {
130                 NL_SET_ERR_MSG(extack, "lladdr too large?");
131                 return rc;
132         }
133
134         if (!tb[NDA_DST]) {
135                 NL_SET_ERR_MSG(extack, "Neighbour EID must be specified");
136                 return -EINVAL;
137         }
138
139         if (!tb[NDA_LLADDR]) {
140                 NL_SET_ERR_MSG(extack, "Neighbour lladdr must be specified");
141                 return -EINVAL;
142         }
143
144         eid = nla_get_u8(tb[NDA_DST]);
145         if (!mctp_address_ok(eid)) {
146                 NL_SET_ERR_MSG(extack, "Invalid neighbour EID");
147                 return -EINVAL;
148         }
149
150         lladdr = nla_data(tb[NDA_LLADDR]);
151         lladdr_len = nla_len(tb[NDA_LLADDR]);
152
153         ndm = nlmsg_data(nlh);
154
155         dev = __dev_get_by_index(net, ndm->ndm_ifindex);
156         if (!dev)
157                 return -ENODEV;
158
159         mdev = mctp_dev_get_rtnl(dev);
160         if (!mdev)
161                 return -ENODEV;
162
163         if (lladdr_len != dev->addr_len) {
164                 NL_SET_ERR_MSG(extack, "Wrong lladdr length");
165                 return -EINVAL;
166         }
167
168         return mctp_neigh_add(mdev, eid, MCTP_NEIGH_STATIC,
169                         lladdr_len, lladdr);
170 }
171
172 static int mctp_rtm_delneigh(struct sk_buff *skb, struct nlmsghdr *nlh,
173                              struct netlink_ext_ack *extack)
174 {
175         struct net *net = sock_net(skb->sk);
176         struct nlattr *tb[NDA_MAX + 1];
177         struct net_device *dev;
178         struct mctp_dev *mdev;
179         struct ndmsg *ndm;
180         int rc;
181         mctp_eid_t eid;
182
183         rc = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, nd_mctp_policy,
184                          extack);
185         if (rc < 0) {
186                 NL_SET_ERR_MSG(extack, "incorrect format");
187                 return rc;
188         }
189
190         if (!tb[NDA_DST]) {
191                 NL_SET_ERR_MSG(extack, "Neighbour EID must be specified");
192                 return -EINVAL;
193         }
194         eid = nla_get_u8(tb[NDA_DST]);
195
196         ndm = nlmsg_data(nlh);
197         dev = __dev_get_by_index(net, ndm->ndm_ifindex);
198         if (!dev)
199                 return -ENODEV;
200
201         mdev = mctp_dev_get_rtnl(dev);
202         if (!mdev)
203                 return -ENODEV;
204
205         return mctp_neigh_remove(mdev, eid);
206 }
207
208 static int mctp_fill_neigh(struct sk_buff *skb, u32 portid, u32 seq, int event,
209                            unsigned int flags, struct mctp_neigh *neigh)
210 {
211         struct net_device *dev = neigh->dev->dev;
212         struct nlmsghdr *nlh;
213         struct ndmsg *hdr;
214
215         nlh = nlmsg_put(skb, portid, seq, event, sizeof(*hdr), flags);
216         if (!nlh)
217                 return -EMSGSIZE;
218
219         hdr = nlmsg_data(nlh);
220         hdr->ndm_family = AF_MCTP;
221         hdr->ndm_ifindex = dev->ifindex;
222         hdr->ndm_state = 0; // TODO other state bits?
223         if (neigh->source == MCTP_NEIGH_STATIC)
224                 hdr->ndm_state |= NUD_PERMANENT;
225         hdr->ndm_flags = 0;
226         hdr->ndm_type = RTN_UNICAST; // TODO: is loopback RTN_LOCAL?
227
228         if (nla_put_u8(skb, NDA_DST, neigh->eid))
229                 goto cancel;
230
231         if (nla_put(skb, NDA_LLADDR, dev->addr_len, neigh->ha))
232                 goto cancel;
233
234         nlmsg_end(skb, nlh);
235
236         return 0;
237 cancel:
238         nlmsg_cancel(skb, nlh);
239         return -EMSGSIZE;
240 }
241
242 static int mctp_rtm_getneigh(struct sk_buff *skb, struct netlink_callback *cb)
243 {
244         struct net *net = sock_net(skb->sk);
245         int rc, idx, req_ifindex;
246         struct mctp_neigh *neigh;
247         struct ndmsg *ndmsg;
248         struct {
249                 int idx;
250         } *cbctx = (void *)cb->ctx;
251
252         ndmsg = nlmsg_data(cb->nlh);
253         req_ifindex = ndmsg->ndm_ifindex;
254
255         idx = 0;
256         rcu_read_lock();
257         list_for_each_entry_rcu(neigh, &net->mctp.neighbours, list) {
258                 if (idx < cbctx->idx)
259                         goto cont;
260
261                 rc = 0;
262                 if (req_ifindex == 0 || req_ifindex == neigh->dev->dev->ifindex)
263                         rc = mctp_fill_neigh(skb, NETLINK_CB(cb->skb).portid,
264                                              cb->nlh->nlmsg_seq,
265                                              RTM_NEWNEIGH, NLM_F_MULTI, neigh);
266
267                 if (rc)
268                         break;
269 cont:
270                 idx++;
271         }
272         rcu_read_unlock();
273
274         cbctx->idx = idx;
275         return skb->len;
276 }
277
278 int mctp_neigh_lookup(struct mctp_dev *mdev, mctp_eid_t eid, void *ret_hwaddr)
279 {
280         struct net *net = dev_net(mdev->dev);
281         struct mctp_neigh *neigh;
282         int rc = -EHOSTUNREACH; // TODO: or ENOENT?
283
284         rcu_read_lock();
285         list_for_each_entry_rcu(neigh, &net->mctp.neighbours, list) {
286                 if (mdev == neigh->dev && eid == neigh->eid) {
287                         if (ret_hwaddr)
288                                 memcpy(ret_hwaddr, neigh->ha,
289                                        sizeof(neigh->ha));
290                         rc = 0;
291                         break;
292                 }
293         }
294         rcu_read_unlock();
295         return rc;
296 }
297
298 /* namespace registration */
299 static int __net_init mctp_neigh_net_init(struct net *net)
300 {
301         struct netns_mctp *ns = &net->mctp;
302
303         INIT_LIST_HEAD(&ns->neighbours);
304         mutex_init(&ns->neigh_lock);
305         return 0;
306 }
307
308 static void __net_exit mctp_neigh_net_exit(struct net *net)
309 {
310         struct netns_mctp *ns = &net->mctp;
311         struct mctp_neigh *neigh;
312
313         list_for_each_entry(neigh, &ns->neighbours, list)
314                 call_rcu(&neigh->rcu, __mctp_neigh_free);
315 }
316
317 /* net namespace implementation */
318
319 static struct pernet_operations mctp_net_ops = {
320         .init = mctp_neigh_net_init,
321         .exit = mctp_neigh_net_exit,
322 };
323
324 int __init mctp_neigh_init(void)
325 {
326         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_NEWNEIGH,
327                              mctp_rtm_newneigh, NULL, 0);
328         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_DELNEIGH,
329                              mctp_rtm_delneigh, NULL, 0);
330         rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_GETNEIGH,
331                              NULL, mctp_rtm_getneigh, 0);
332
333         return register_pernet_subsys(&mctp_net_ops);
334 }
335
336 void __exit mctp_neigh_exit(void)
337 {
338         unregister_pernet_subsys(&mctp_net_ops);
339         rtnl_unregister(PF_MCTP, RTM_GETNEIGH);
340         rtnl_unregister(PF_MCTP, RTM_DELNEIGH);
341         rtnl_unregister(PF_MCTP, RTM_NEWNEIGH);
342 }