seg6: export get_srh() for ICMP handling
[platform/kernel/linux-rpi.git] / net / ipv6 / seg6.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *  SR-IPv6 implementation
4  *
5  *  Author:
6  *  David Lebrun <david.lebrun@uclouvain.be>
7  */
8
9 #include <linux/errno.h>
10 #include <linux/types.h>
11 #include <linux/socket.h>
12 #include <linux/net.h>
13 #include <linux/in6.h>
14 #include <linux/slab.h>
15 #include <linux/rhashtable.h>
16
17 #include <net/ipv6.h>
18 #include <net/protocol.h>
19
20 #include <net/seg6.h>
21 #include <net/genetlink.h>
22 #include <linux/seg6.h>
23 #include <linux/seg6_genl.h>
24 #ifdef CONFIG_IPV6_SEG6_HMAC
25 #include <net/seg6_hmac.h>
26 #endif
27
28 bool seg6_validate_srh(struct ipv6_sr_hdr *srh, int len, bool reduced)
29 {
30         unsigned int tlv_offset;
31         int max_last_entry;
32         int trailing;
33
34         if (srh->type != IPV6_SRCRT_TYPE_4)
35                 return false;
36
37         if (((srh->hdrlen + 1) << 3) != len)
38                 return false;
39
40         if (!reduced && srh->segments_left > srh->first_segment) {
41                 return false;
42         } else {
43                 max_last_entry = (srh->hdrlen / 2) - 1;
44
45                 if (srh->first_segment > max_last_entry)
46                         return false;
47
48                 if (srh->segments_left > srh->first_segment + 1)
49                         return false;
50         }
51
52         tlv_offset = sizeof(*srh) + ((srh->first_segment + 1) << 4);
53
54         trailing = len - tlv_offset;
55         if (trailing < 0)
56                 return false;
57
58         while (trailing) {
59                 struct sr6_tlv *tlv;
60                 unsigned int tlv_len;
61
62                 if (trailing < sizeof(*tlv))
63                         return false;
64
65                 tlv = (struct sr6_tlv *)((unsigned char *)srh + tlv_offset);
66                 tlv_len = sizeof(*tlv) + tlv->len;
67
68                 trailing -= tlv_len;
69                 if (trailing < 0)
70                         return false;
71
72                 tlv_offset += tlv_len;
73         }
74
75         return true;
76 }
77
78 struct ipv6_sr_hdr *seg6_get_srh(struct sk_buff *skb, int flags)
79 {
80         struct ipv6_sr_hdr *srh;
81         int len, srhoff = 0;
82
83         if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, &flags) < 0)
84                 return NULL;
85
86         if (!pskb_may_pull(skb, srhoff + sizeof(*srh)))
87                 return NULL;
88
89         srh = (struct ipv6_sr_hdr *)(skb->data + srhoff);
90
91         len = (srh->hdrlen + 1) << 3;
92
93         if (!pskb_may_pull(skb, srhoff + len))
94                 return NULL;
95
96         /* note that pskb_may_pull may change pointers in header;
97          * for this reason it is necessary to reload them when needed.
98          */
99         srh = (struct ipv6_sr_hdr *)(skb->data + srhoff);
100
101         if (!seg6_validate_srh(srh, len, true))
102                 return NULL;
103
104         return srh;
105 }
106
107 static struct genl_family seg6_genl_family;
108
109 static const struct nla_policy seg6_genl_policy[SEG6_ATTR_MAX + 1] = {
110         [SEG6_ATTR_DST]                         = { .type = NLA_BINARY,
111                 .len = sizeof(struct in6_addr) },
112         [SEG6_ATTR_DSTLEN]                      = { .type = NLA_S32, },
113         [SEG6_ATTR_HMACKEYID]           = { .type = NLA_U32, },
114         [SEG6_ATTR_SECRET]                      = { .type = NLA_BINARY, },
115         [SEG6_ATTR_SECRETLEN]           = { .type = NLA_U8, },
116         [SEG6_ATTR_ALGID]                       = { .type = NLA_U8, },
117         [SEG6_ATTR_HMACINFO]            = { .type = NLA_NESTED, },
118 };
119
120 #ifdef CONFIG_IPV6_SEG6_HMAC
121
122 static int seg6_genl_sethmac(struct sk_buff *skb, struct genl_info *info)
123 {
124         struct net *net = genl_info_net(info);
125         struct seg6_pernet_data *sdata;
126         struct seg6_hmac_info *hinfo;
127         u32 hmackeyid;
128         char *secret;
129         int err = 0;
130         u8 algid;
131         u8 slen;
132
133         sdata = seg6_pernet(net);
134
135         if (!info->attrs[SEG6_ATTR_HMACKEYID] ||
136             !info->attrs[SEG6_ATTR_SECRETLEN] ||
137             !info->attrs[SEG6_ATTR_ALGID])
138                 return -EINVAL;
139
140         hmackeyid = nla_get_u32(info->attrs[SEG6_ATTR_HMACKEYID]);
141         slen = nla_get_u8(info->attrs[SEG6_ATTR_SECRETLEN]);
142         algid = nla_get_u8(info->attrs[SEG6_ATTR_ALGID]);
143
144         if (hmackeyid == 0)
145                 return -EINVAL;
146
147         if (slen > SEG6_HMAC_SECRET_LEN)
148                 return -EINVAL;
149
150         mutex_lock(&sdata->lock);
151         hinfo = seg6_hmac_info_lookup(net, hmackeyid);
152
153         if (!slen) {
154                 err = seg6_hmac_info_del(net, hmackeyid);
155
156                 goto out_unlock;
157         }
158
159         if (!info->attrs[SEG6_ATTR_SECRET]) {
160                 err = -EINVAL;
161                 goto out_unlock;
162         }
163
164         if (hinfo) {
165                 err = seg6_hmac_info_del(net, hmackeyid);
166                 if (err)
167                         goto out_unlock;
168         }
169
170         secret = (char *)nla_data(info->attrs[SEG6_ATTR_SECRET]);
171
172         hinfo = kzalloc(sizeof(*hinfo), GFP_KERNEL);
173         if (!hinfo) {
174                 err = -ENOMEM;
175                 goto out_unlock;
176         }
177
178         memcpy(hinfo->secret, secret, slen);
179         hinfo->slen = slen;
180         hinfo->alg_id = algid;
181         hinfo->hmackeyid = hmackeyid;
182
183         err = seg6_hmac_info_add(net, hmackeyid, hinfo);
184         if (err)
185                 kfree(hinfo);
186
187 out_unlock:
188         mutex_unlock(&sdata->lock);
189         return err;
190 }
191
192 #else
193
194 static int seg6_genl_sethmac(struct sk_buff *skb, struct genl_info *info)
195 {
196         return -ENOTSUPP;
197 }
198
199 #endif
200
201 static int seg6_genl_set_tunsrc(struct sk_buff *skb, struct genl_info *info)
202 {
203         struct net *net = genl_info_net(info);
204         struct in6_addr *val, *t_old, *t_new;
205         struct seg6_pernet_data *sdata;
206
207         sdata = seg6_pernet(net);
208
209         if (!info->attrs[SEG6_ATTR_DST])
210                 return -EINVAL;
211
212         val = nla_data(info->attrs[SEG6_ATTR_DST]);
213         t_new = kmemdup(val, sizeof(*val), GFP_KERNEL);
214         if (!t_new)
215                 return -ENOMEM;
216
217         mutex_lock(&sdata->lock);
218
219         t_old = sdata->tun_src;
220         rcu_assign_pointer(sdata->tun_src, t_new);
221
222         mutex_unlock(&sdata->lock);
223
224         synchronize_net();
225         kfree(t_old);
226
227         return 0;
228 }
229
230 static int seg6_genl_get_tunsrc(struct sk_buff *skb, struct genl_info *info)
231 {
232         struct net *net = genl_info_net(info);
233         struct in6_addr *tun_src;
234         struct sk_buff *msg;
235         void *hdr;
236
237         msg = genlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
238         if (!msg)
239                 return -ENOMEM;
240
241         hdr = genlmsg_put(msg, info->snd_portid, info->snd_seq,
242                           &seg6_genl_family, 0, SEG6_CMD_GET_TUNSRC);
243         if (!hdr)
244                 goto free_msg;
245
246         rcu_read_lock();
247         tun_src = rcu_dereference(seg6_pernet(net)->tun_src);
248
249         if (nla_put(msg, SEG6_ATTR_DST, sizeof(struct in6_addr), tun_src))
250                 goto nla_put_failure;
251
252         rcu_read_unlock();
253
254         genlmsg_end(msg, hdr);
255         return genlmsg_reply(msg, info);
256
257 nla_put_failure:
258         rcu_read_unlock();
259 free_msg:
260         nlmsg_free(msg);
261         return -ENOMEM;
262 }
263
264 #ifdef CONFIG_IPV6_SEG6_HMAC
265
266 static int __seg6_hmac_fill_info(struct seg6_hmac_info *hinfo,
267                                  struct sk_buff *msg)
268 {
269         if (nla_put_u32(msg, SEG6_ATTR_HMACKEYID, hinfo->hmackeyid) ||
270             nla_put_u8(msg, SEG6_ATTR_SECRETLEN, hinfo->slen) ||
271             nla_put(msg, SEG6_ATTR_SECRET, hinfo->slen, hinfo->secret) ||
272             nla_put_u8(msg, SEG6_ATTR_ALGID, hinfo->alg_id))
273                 return -1;
274
275         return 0;
276 }
277
278 static int __seg6_genl_dumphmac_element(struct seg6_hmac_info *hinfo,
279                                         u32 portid, u32 seq, u32 flags,
280                                         struct sk_buff *skb, u8 cmd)
281 {
282         void *hdr;
283
284         hdr = genlmsg_put(skb, portid, seq, &seg6_genl_family, flags, cmd);
285         if (!hdr)
286                 return -ENOMEM;
287
288         if (__seg6_hmac_fill_info(hinfo, skb) < 0)
289                 goto nla_put_failure;
290
291         genlmsg_end(skb, hdr);
292         return 0;
293
294 nla_put_failure:
295         genlmsg_cancel(skb, hdr);
296         return -EMSGSIZE;
297 }
298
299 static int seg6_genl_dumphmac_start(struct netlink_callback *cb)
300 {
301         struct net *net = sock_net(cb->skb->sk);
302         struct seg6_pernet_data *sdata;
303         struct rhashtable_iter *iter;
304
305         sdata = seg6_pernet(net);
306         iter = (struct rhashtable_iter *)cb->args[0];
307
308         if (!iter) {
309                 iter = kmalloc(sizeof(*iter), GFP_KERNEL);
310                 if (!iter)
311                         return -ENOMEM;
312
313                 cb->args[0] = (long)iter;
314         }
315
316         rhashtable_walk_enter(&sdata->hmac_infos, iter);
317
318         return 0;
319 }
320
321 static int seg6_genl_dumphmac_done(struct netlink_callback *cb)
322 {
323         struct rhashtable_iter *iter = (struct rhashtable_iter *)cb->args[0];
324
325         rhashtable_walk_exit(iter);
326
327         kfree(iter);
328
329         return 0;
330 }
331
332 static int seg6_genl_dumphmac(struct sk_buff *skb, struct netlink_callback *cb)
333 {
334         struct rhashtable_iter *iter = (struct rhashtable_iter *)cb->args[0];
335         struct seg6_hmac_info *hinfo;
336         int ret;
337
338         rhashtable_walk_start(iter);
339
340         for (;;) {
341                 hinfo = rhashtable_walk_next(iter);
342
343                 if (IS_ERR(hinfo)) {
344                         if (PTR_ERR(hinfo) == -EAGAIN)
345                                 continue;
346                         ret = PTR_ERR(hinfo);
347                         goto done;
348                 } else if (!hinfo) {
349                         break;
350                 }
351
352                 ret = __seg6_genl_dumphmac_element(hinfo,
353                                                    NETLINK_CB(cb->skb).portid,
354                                                    cb->nlh->nlmsg_seq,
355                                                    NLM_F_MULTI,
356                                                    skb, SEG6_CMD_DUMPHMAC);
357                 if (ret)
358                         goto done;
359         }
360
361         ret = skb->len;
362
363 done:
364         rhashtable_walk_stop(iter);
365         return ret;
366 }
367
368 #else
369
370 static int seg6_genl_dumphmac_start(struct netlink_callback *cb)
371 {
372         return 0;
373 }
374
375 static int seg6_genl_dumphmac_done(struct netlink_callback *cb)
376 {
377         return 0;
378 }
379
380 static int seg6_genl_dumphmac(struct sk_buff *skb, struct netlink_callback *cb)
381 {
382         return -ENOTSUPP;
383 }
384
385 #endif
386
387 static int __net_init seg6_net_init(struct net *net)
388 {
389         struct seg6_pernet_data *sdata;
390
391         sdata = kzalloc(sizeof(*sdata), GFP_KERNEL);
392         if (!sdata)
393                 return -ENOMEM;
394
395         mutex_init(&sdata->lock);
396
397         sdata->tun_src = kzalloc(sizeof(*sdata->tun_src), GFP_KERNEL);
398         if (!sdata->tun_src) {
399                 kfree(sdata);
400                 return -ENOMEM;
401         }
402
403         net->ipv6.seg6_data = sdata;
404
405 #ifdef CONFIG_IPV6_SEG6_HMAC
406         seg6_hmac_net_init(net);
407 #endif
408
409         return 0;
410 }
411
412 static void __net_exit seg6_net_exit(struct net *net)
413 {
414         struct seg6_pernet_data *sdata = seg6_pernet(net);
415
416 #ifdef CONFIG_IPV6_SEG6_HMAC
417         seg6_hmac_net_exit(net);
418 #endif
419
420         kfree(sdata->tun_src);
421         kfree(sdata);
422 }
423
424 static struct pernet_operations ip6_segments_ops = {
425         .init = seg6_net_init,
426         .exit = seg6_net_exit,
427 };
428
429 static const struct genl_ops seg6_genl_ops[] = {
430         {
431                 .cmd    = SEG6_CMD_SETHMAC,
432                 .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
433                 .doit   = seg6_genl_sethmac,
434                 .flags  = GENL_ADMIN_PERM,
435         },
436         {
437                 .cmd    = SEG6_CMD_DUMPHMAC,
438                 .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
439                 .start  = seg6_genl_dumphmac_start,
440                 .dumpit = seg6_genl_dumphmac,
441                 .done   = seg6_genl_dumphmac_done,
442                 .flags  = GENL_ADMIN_PERM,
443         },
444         {
445                 .cmd    = SEG6_CMD_SET_TUNSRC,
446                 .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
447                 .doit   = seg6_genl_set_tunsrc,
448                 .flags  = GENL_ADMIN_PERM,
449         },
450         {
451                 .cmd    = SEG6_CMD_GET_TUNSRC,
452                 .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
453                 .doit   = seg6_genl_get_tunsrc,
454                 .flags  = GENL_ADMIN_PERM,
455         },
456 };
457
458 static struct genl_family seg6_genl_family __ro_after_init = {
459         .hdrsize        = 0,
460         .name           = SEG6_GENL_NAME,
461         .version        = SEG6_GENL_VERSION,
462         .maxattr        = SEG6_ATTR_MAX,
463         .policy = seg6_genl_policy,
464         .netnsok        = true,
465         .parallel_ops   = true,
466         .ops            = seg6_genl_ops,
467         .n_ops          = ARRAY_SIZE(seg6_genl_ops),
468         .module         = THIS_MODULE,
469 };
470
471 int __init seg6_init(void)
472 {
473         int err;
474
475         err = genl_register_family(&seg6_genl_family);
476         if (err)
477                 goto out;
478
479         err = register_pernet_subsys(&ip6_segments_ops);
480         if (err)
481                 goto out_unregister_genl;
482
483 #ifdef CONFIG_IPV6_SEG6_LWTUNNEL
484         err = seg6_iptunnel_init();
485         if (err)
486                 goto out_unregister_pernet;
487
488         err = seg6_local_init();
489         if (err)
490                 goto out_unregister_pernet;
491 #endif
492
493 #ifdef CONFIG_IPV6_SEG6_HMAC
494         err = seg6_hmac_init();
495         if (err)
496                 goto out_unregister_iptun;
497 #endif
498
499         pr_info("Segment Routing with IPv6\n");
500
501 out:
502         return err;
503 #ifdef CONFIG_IPV6_SEG6_HMAC
504 out_unregister_iptun:
505 #ifdef CONFIG_IPV6_SEG6_LWTUNNEL
506         seg6_local_exit();
507         seg6_iptunnel_exit();
508 #endif
509 #endif
510 #ifdef CONFIG_IPV6_SEG6_LWTUNNEL
511 out_unregister_pernet:
512         unregister_pernet_subsys(&ip6_segments_ops);
513 #endif
514 out_unregister_genl:
515         genl_unregister_family(&seg6_genl_family);
516         goto out;
517 }
518
519 void seg6_exit(void)
520 {
521 #ifdef CONFIG_IPV6_SEG6_HMAC
522         seg6_hmac_exit();
523 #endif
524 #ifdef CONFIG_IPV6_SEG6_LWTUNNEL
525         seg6_iptunnel_exit();
526 #endif
527         unregister_pernet_subsys(&ip6_segments_ops);
528         genl_unregister_family(&seg6_genl_family);
529 }