f88417947e60e5f749a7c4efeec1075ffe1b3ec5
[platform/kernel/linux-rpi.git] / tools / net / ynl / ynl-gen-c.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4 import argparse
5 import collections
6 import os
7 import yaml
8
9 from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
10
11
def c_upper(name):
    """Return *name* as an upper-case C identifier, e.g. 'foo-bar' -> 'FOO_BAR'."""
    return name.upper().translate(str.maketrans('-', '_'))
14
15
def c_lower(name):
    """Return *name* as a lower-case C identifier, e.g. 'Foo-Bar' -> 'foo_bar'."""
    return name.lower().translate(str.maketrans('-', '_'))
18
19
class BaseNlLib:
    """Renders the libmnl-specific snippets used by the generated C code."""

    def get_family_id(self):
        """Return the C expression holding the resolved family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return an ``mnl_cb_run2()`` call expression.

        Dumps pass 0 for seq/portid; other requests pass the socket's
        ``ys->seq``/``ys->portid``.  Continuation lines are indented with
        two tabs, *indent* extra tabs and one space.
        """
        sep = '\n\t\t' + '\t' * indent + ' '
        if is_dump:
            head = f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},"
        else:
            head = f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{sep}{cb}, {data},"
        return f"{head}{sep}ynl_cb_array, NLMSG_MIN_TYPE)"
31
32
class Type(SpecAttr):
    """Base C code generator for a single attribute of an attribute set.

    Subclasses specialise per netlink type by overriding the hooks
    (``presence_type``, ``_complex_member_type``, ``_attr_policy``,
    ``_attr_typol``, ``_attr_get``, ``_setter_lines``, ``attr_put`` ...).
    Methods taking ``ri`` print C source through ``ri.cw``; ``var`` is the C
    expression naming the struct pointer the generated code operates on.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # A nest referring to the set named like the family renders
            # without a set-name suffix.
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

        self.c_name = c_lower(self.name)
        # Names colliding with C keywords (_C_KW, defined elsewhere in the
        # module) get an underscore appended.
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute the C enumerator for this attribute (set prefix + name)."""
        self.enum_name = c_upper(f"{self.attr_set.name_prefix}{self.name}")

    def is_multi_val(self):
        """Return whether the value is stored as an array (falsy by default)."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """Kind of presence tracking: 'bit', 'len', 'count' or '' (none)."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` struct member declaration, or None.

        A line is produced only when this attribute's presence type equals
        *type_filter*.  ``space == 'user'`` selects the ``__u32`` spelling.
        """
        ptype = self.presence_type()
        if ptype != type_filter:
            return None
        pfx = '__' if space == 'user' else ''
        if ptype == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if ptype == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """C type of a non-trivial member; None means "use arg_member()"."""
        return None

    def free_needs_iter(self):
        """Whether free() emits a loop that needs an ``i`` iterator declared."""
        return False

    def free(self, ri, var, ref):
        """Emit code releasing heap memory owned by this member, if any."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declaration(s) used by the setter.

        Raises:
            Exception: the subclass provides no member type.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attribute."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry in the kernel ``nla_policy`` table."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry in the user-space type-policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line*, guarded by the presence check when there is one."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return ``(lines, init_lines, local_vars)`` for the parse branch.

        ``lines``/``init_lines`` may be a single string, a list or None.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parse branch for this attribute; return True.

        *first* selects ``if`` vs ``else if``.  Non-multi attributes are
        validated with ``ynl_attr_validate()`` and, for 'bit' presence,
        marked present before the type-specific lines are emitted.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter for this attribute.

        *ref* is the path of member names leading here when the attribute
        lives inside nested structs.  For 'bit' presence every level of the
        path is marked present before the type-specific ``_setter_lines``.
        A setter that frees but never allocates gets a ``__`` prefix.
        *space* is accepted for interface compatibility and is not used.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"
        mark_present = self.presence_type() == 'bit'

        code = []
        presence = ''
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            if mark_present:
                code.append(presence + ' = 1;')
        # 'presence' now refers to the innermost level (this attribute).
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        frees = any('free(' in line for line in code)
        allocs = any('alloc(' in line for line in code)
        if frees and not allocs:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
201
202
class TypeUnused(Type):
    """Reserved attribute slot: rejected on receive, never stored or sent."""

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        # No kernel policy entry is emitted for unused slots.
        return None

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def _attr_get(self, ri, var):
        # Receiving this attribute aborts parsing.
        get_lines = ['return MNL_CB_ERROR;']
        return get_lines, None, None
218
219
class TypePad(Type):
    """Padding attribute: ignored by the parser, no member, no policy entry."""

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def attr_get(self, ri, var, first):
        # Emit no parse branch at all (note: returns None, not True).
        return None

    def attr_policy(self, cw):
        return None
235
236
class TypeScalar(Type):
    """Fixed-width integer attribute (u8/u16/u32/u64/s32/s64)."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Appended to the C argument declaration as a comment.
        self.byte_order_comment = ''
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Decide whether the value is a bitfield and pick its C type name."""
        self.resolve_up(super())

        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in self.attr:
            self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        # Plain enums get their enum type; bitfields/raw ints use __uNN/__sNN.
        if 'enum' in self.attr and not self.is_bitfield:
            self.type_name = f"enum {self.family.name}_{c_lower(self.attr['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        """Suffix of the libmnl accessor to use for this type."""
        t = self.type
        # mnl does not have a helper for signed types
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _attr_policy(self, policy):
        """Return the kernel policy initializer for the first matching check.

        Precedence: mask (explicit ``flags-mask`` check or bitfield enum),
        then ``min``, then an enum-derived maximum.
        """
        if 'flags-mask' in self.checks or self.is_bitfield:
            if self.is_bitfield:
                flags = self.family.consts[self.attr['enum']]
            else:
                flags = self.family.consts[self.checks['flags-mask']]
            # Previously the flags-mask path used (1 << len(entries)) - 1,
            # which assumes bits 0..n-1.  Use the same helper as the
            # bitfield path so the mask follows the entries' actual values
            # (e.g. a non-zero value-start).
            mask = flags.get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        elif 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
        elif 'enum' in self.attr:
            enum = self.family.consts[self.attr['enum']]
            # NOTE(review): assumes the enum values are 0..cnt-1.
            cnt = len(enum['entries'])
            return f"NLA_POLICY_MAX({policy}, {cnt - 1})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types map onto the unsigned typol of the same width.
        return f'.type = YNL_PT_U{self.type[1:]}, '

    def arg_member(self, ri):
        return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, self._mnl_type())

    def _attr_get(self, ri, var):
        return f"{var}->{self.c_name} = mnl_attr_get_{self._mnl_type()}(attr);", None, None

    def _setter_lines(self, ri, member, presence):
        return [f"{member} = {self.c_name};"]
305
306
class TypeFlag(Type):
    """Payload-less attribute whose presence alone carries the information."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # The setter only flips the presence bit, so it takes no value.
        return []

    def attr_put(self, ri, var):
        put_call = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        # Nothing to decode beyond what the generic attr_get() emits.
        return [], None, None

    def _setter_lines(self, ri, member, presence):
        return []
322
323
class TypeString(Type):
    """NUL-terminated string attribute; length kept in ``_present.<name>_len``."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        pieces = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            pieces.append(', .len = ' + str(self.checks['max-len']))
        pieces.append(', }')
        return ''.join(pieces)

    def attr_policy(self, cw):
        # 'unterminated-ok' relaxes the kernel check to NLA_STRING.
        unterminated = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated else 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        # Copy at most the payload length and always NUL-terminate the copy.
        dst = f"{var}->{self.c_name}"
        len_mem = f"{var}->_present.{self.c_name}_len"
        get_lines = [
            f"{len_mem} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, mnl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [
            f"free({member});",
            f"{length} = strlen({self.c_name});",
            f"{member} = malloc({length} + 1);",
            f'memcpy({member}, {self.c_name}, {length});',
            f'{member}[{length}] = 0;',
        ]
371
372
class TypeBinary(Type):
    """Opaque byte-array attribute; length kept in ``_present.<name>_len``."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # Trailing space matches the other types' fragments so the rendered
        # table entry ends in ", }" consistently.
        return '.type = YNL_PT_BINARY, '

    def _attr_policy(self, policy):
        """Return the kernel policy initializer.

        Only a lone ``min-len`` check is supported.  It is rendered as a bare
        ``.len`` without ``.type``, so the kernel's default (untyped) policy
        applies its minimum-length semantics.

        Raises:
            Exception: any other combination of checks.
        """
        mem = '{ '
        if len(self.checks) == 1 and 'min-len' in self.checks:
            mem += '.len = ' + str(self.checks['min-len'])
        elif len(self.checks) == 0:
            mem += '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        mem += ', }'
        return mem

    def attr_put(self, ri, var):
        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \
               ['len = mnl_attr_get_payload_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # The setter receives the data length as 'len' (see arg_member()).
        # Record it first: malloc/memcpy and attr_put() all read the stored
        # length, and previously it was never updated.
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
413
414
class TypeNest(Type):
    """Attribute containing a nested attribute set, rendered as an embedded struct."""

    def _complex_member_type(self, ri):
        return f"struct {self.nested_render_name}"

    def free(self, ri, var, ref):
        # Delegate to the nested struct's generated free helper.
        ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, &{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [
            f"if ({self.nested_render_name}_parse(&parg, attr))",
            "return MNL_CB_ERROR;",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        # A nest has no setter of its own: emit setters for each member of
        # the nested struct, extending the member path with this nest.
        path = (ref or []) + [self.c_name]
        nested = ri.family.pure_nested_structs[self.nested_attrs]
        for _name, member in nested.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
444
445
class TypeMultiAttr(Type):
    """Attribute that may repeat; stored as an array plus an ``n_`` count.

    ``self.type`` is the per-element type, either 'nest' or one of the
    module-level ``scalars``.  Type.__init__ reads ``attr['type']``
    unconditionally, so it is always set.  The former
    ``'type' not in self.attr`` guards were unreachable and are dropped.
    """

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _mnl_type(self):
        """Suffix of the libmnl put helper for scalar elements."""
        t = self.type
        # mnl does not have a helper for signed types
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _complex_member_type(self, ri):
        if self.type == 'nest':
            return f"struct {self.nested_render_name}"
        elif self.type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.type
        else:
            raise Exception(f"Sub-type {self.type} not supported yet")

    def free_needs_iter(self):
        # Nested elements are freed one by one in a loop over 'i'.
        return self.type == 'nest'

    def free(self, ri, var, ref):
        if self.type in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.type == 'nest':
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.type} not supported yet")

    def _attr_typol(self):
        if self.type == 'nest':
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        elif self.type in scalars:
            return f".type = YNL_PT_U{self.type[1:]}, "
        else:
            raise Exception(f"Sub-type {self.type} not supported yet")

    def _attr_get(self, ri, var):
        # Only count occurrences here; presumably the array is allocated and
        # filled by code generated elsewhere (not visible in this chunk).
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        if self.type in scalars:
            put_type = self._mnl_type()
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        elif self.type == 'nest':
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.type} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence: rewrite the trailing
        # "_present.<name>" of the presence path into "n_<name>".
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
511
512
class TypeArrayNest(Type):
    """Nest whose children form an array of a single sub-type."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        # A missing sub-type or 'nest' means an array of nested structs.
        if 'sub-type' in self.attr and self.attr['sub-type'] != 'nest':
            sub_type = self.attr['sub-type']
            if sub_type not in scalars:
                raise Exception(f"Sub-type {sub_type} not supported yet")
            return ('__' if ri.ku_space == 'user' else '') + sub_type
        return f"struct {self.nested_render_name}"

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the outer attribute and count its nested children.
        counter = f'{var}->n_{self.c_name}'
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'mnl_attr_for_each_nested(attr2, attr)',
            f'\t{counter}++;',
        ]
        return get_lines, None, ['const struct nlattr *attr2;']
538
539
class TypeNestTypeValue(Type):
    """Nest whose enclosing attribute types (``type-value``) carry data."""

    def _complex_member_type(self, ri):
        return f"struct {self.nested_render_name}"

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append('const struct nlattr *attr_' + ', *attr_'.join(levels) + ';')
            local_vars.append('__u32 ' + ', '.join(levels) + ';')
            # Descend one nest per level, capturing each level's type id.
            for level in levels:
                get_lines.append(f'attr_{level} = mnl_attr_get_payload({cursor});')
                get_lines.append(f'{level} = mnl_attr_get_type(attr_{level});')
                cursor = f'attr_{level}'
            extra_args = ', ' + ', '.join(levels)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
568
569
class Struct:
    """C struct rendered for one attribute set (or a chosen subset of it).

    When *type_list* is falsy, every attribute of the set becomes a member.
    ``nested`` is True only when *type_list* was omitted (None).
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Default to a list, not a set: set_inherited() compares with !=,
        # and a list never equals a set (even an empty one).
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        lowered = c_lower(space_name)
        if family.name == lowered:
            self.render_name = f"{family.name}"
        else:
            self.render_name = f"{family.name}_{lowered}"
        self.struct_name = 'struct ' + self.render_name
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        # NOTE(review): an empty (non-None) type_list falls back to *all*
        # attributes while nested stays False -- confirm this is intended.
        selected = type_list if type_list else self.attr_set
        self.attr_list = [(name, self.attr_set[name]) for name in selected]
        self.attrs = dict(self.attr_list)

        # Highest-valued member; on ties the later one wins (>=).
        highest = 0
        self.attr_max_val = None
        for _name, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the ``(name, attr)`` pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited member names (lower-cased, sorted).

        Raises:
            Exception: *new_inherited* differs from the constructor value.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
620
621
class EnumEntry(SpecEnumEntry):
    """Enum entry that also knows its C enumerator name."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # True when the value differs from what C would assign implicitly
        # (previous + 1, or 0 for the first entry), or the set is 'flags'.
        implicit = prev.value + 1 if prev else 0
        self.value_change = self.value != implicit or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
640
641
class EnumSet(SpecEnumSet):
    """Enum/flags set with its C rendering names."""

    def __init__(self, family, yaml):
        # Set before the base constructor, which builds the entries.
        self.render_name = c_lower(family.name + '-' + yaml['name'])
        self.enum_name = 'enum ' + self.render_name

        default_pfx = f"{family.name}-{yaml['name']}-"
        self.value_pfx = yaml.get('name-prefix', default_pfx)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        """Factory hook: build entries as EnumEntry objects."""
        return EnumEntry(self, entry, prev_entry, value_start)
653
654
class AttrSet(SpecAttrSet):
    """Attribute set with C naming (enumerator prefix, max name, c_name)."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are carved from.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'
        # The set named after the family gets an empty C name.
        if self.c_name == self.family.c_name:
            self.c_name = ''

    def new_attr(self, elem, value):
        """Factory hook: pick the Type subclass matching the spec element.

        Raises:
            Exception: the element's type has no generator class.
        """
        if elem.get('multi-attr'):
            cls = TypeMultiAttr
        elif elem['type'] in scalars:
            cls = TypeScalar
        else:
            by_type = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }
            cls = by_type.get(elem['type'])
            if cls is None:
                raise Exception(f"No typed class for type {elem['type']}")
        return cls(self.family, self, elem, value)
706
707
class Operation(SpecOperation):
    """Operation with C naming and notification bookkeeping."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = family.name + '_' + c_lower(self.name)

        # True only when both 'do' and 'dump' declare a request.
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def add_notification(self, op):
        """Register *op* as a notification of this operation.

        The first call creates ``yaml['notify']``, reusing the 'do' reply.
        """
        if 'notify' not in self.yaml:
            notify = self.yaml['notify'] = dict()
            notify['reply'] = self.yaml['do']['reply']
            notify['cmds'] = []
        self.yaml['notify']['cmds'].append(op)
735
736
737 class Family(SpecFamily):
738     def __init__(self, file_name):
739         # Added by resolve:
740         self.c_name = None
741         delattr(self, "c_name")
742         self.op_prefix = None
743         delattr(self, "op_prefix")
744         self.async_op_prefix = None
745         delattr(self, "async_op_prefix")
746         self.mcgrps = None
747         delattr(self, "mcgrps")
748         self.consts = None
749         delattr(self, "consts")
750         self.hooks = None
751         delattr(self, "hooks")
752
753         super().__init__(file_name)
754
755         self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
756         self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
757
758         if 'definitions' not in self.yaml:
759             self.yaml['definitions'] = []
760
761         if 'uapi-header' in self.yaml:
762             self.uapi_header = self.yaml['uapi-header']
763         else:
764             self.uapi_header = f"linux/{self.name}.h"
765
766     def resolve(self):
767         self.resolve_up(super())
768
769         if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
770             raise Exception("Codegen only supported for genetlink")
771
772         self.c_name = c_lower(self.name)
773         if 'name-prefix' in self.yaml['operations']:
774             self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
775         else:
776             self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
777         if 'async-prefix' in self.yaml['operations']:
778             self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
779         else:
780             self.async_op_prefix = self.op_prefix
781
782         self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
783
784         self.hooks = dict()
785         for when in ['pre', 'post']:
786             self.hooks[when] = dict()
787             for op_mode in ['do', 'dump']:
788                 self.hooks[when][op_mode] = dict()
789                 self.hooks[when][op_mode]['set'] = set()
790                 self.hooks[when][op_mode]['list'] = []
791
792         # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
793         self.root_sets = dict()
794         # dict space-name -> set('request', 'reply')
795         self.pure_nested_structs = dict()
796         self.all_notify = dict()
797
798         self._mock_up_events()
799
800         self._dictify()
801         self._load_root_sets()
802         self._load_nested_sets()
803         self._load_all_notify()
804         self._load_hooks()
805
806         self.kernel_policy = self.yaml.get('kernel-policy', 'split')
807         if self.kernel_policy == 'global':
808             self._load_global_policy()
809
810     def new_enum(self, elem):
811         return EnumSet(self, elem)
812
813     def new_attr_set(self, elem):
814         return AttrSet(self, elem)
815
816     def new_operation(self, elem, req_value, rsp_value):
817         return Operation(self, elem, req_value, rsp_value)
818
819     # Fake a 'do' equivalent of all events, so that we can render their response parsing
820     def _mock_up_events(self):
821         for op in self.yaml['operations']['list']:
822             if 'event' in op:
823                 op['do'] = {
824                     'reply': {
825                         'attributes': op['event']['attributes']
826                     }
827                 }
828
829     def _dictify(self):
830         ntf = []
831         for msg in self.msgs.values():
832             if 'notify' in msg:
833                 ntf.append(msg)
834         for n in ntf:
835             self.ops[n['notify']].add_notification(n)
836
837     def _load_root_sets(self):
838         for op_name, op in self.ops.items():
839             if 'attribute-set' not in op:
840                 continue
841
842             req_attrs = set()
843             rsp_attrs = set()
844             for op_mode in ['do', 'dump']:
845                 if op_mode in op and 'request' in op[op_mode]:
846                     req_attrs.update(set(op[op_mode]['request']['attributes']))
847                 if op_mode in op and 'reply' in op[op_mode]:
848                     rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
849
850             if op['attribute-set'] not in self.root_sets:
851                 self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
852             else:
853                 self.root_sets[op['attribute-set']]['request'].update(req_attrs)
854                 self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
855
    def _load_nested_sets(self):
        """Create Structs for all attribute sets reachable only via nesting.

        Steps:
          1. Breadth-first walk over 'nested-attributes' references,
             starting from self.root_sets.  Each nested set gets a Struct in
             self.pure_nested_structs, and its inherited members are set:
             the 'type-value' list, or 'idx' for 'array-nest' attributes.
          2. Mark structs directly referenced by a root set's request / reply
             attributes as used in requests / replies.
          3. Best-effort reordering of self.pure_nested_structs so that each
             struct comes after the structs it nests.
          4. Propagate the request / reply flags from parents to children.

        Raises:
            Exception: if an attribute set is used both as a root set and as
                a nested set.
        """
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    # NOTE(review): unreachable - a nested set that is also a
                    # root set already raised in the branch above.
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    # Array elements get their index passed in as 'idx'.
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Only direct children of root sets are marked here; deeper levels
        # are handled by the propagation pass at the end.
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list)**2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    # NOTE(review): a set that nests itself can never become
                    # 'finished'; it just cycles until the rounds run out.
                    if spec['nested-attributes'] not in pns_key_seen:
                        # Dicts keep insertion order, so popping and
                        # re-inserting moves this struct to the end
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)
        # Propagate the request / reply.  Iterate in reverse so that parents,
        # which the reordering placed after their children, are visited first.
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child = self.pure_nested_structs.get(spec['nested-attributes'])
                    if child:
                        child.request |= struct.request
                        child.reply |= struct.reply
924
925     def _load_all_notify(self):
926         for op_name, op in self.ops.items():
927             if not op:
928                 continue
929
930             if 'notify' in op:
931                 self.all_notify[op_name] = op['notify']['cmds']
932
933     def _load_global_policy(self):
934         global_set = set()
935         attr_set_name = None
936         for op_name, op in self.ops.items():
937             if not op:
938                 continue
939             if 'attribute-set' not in op:
940                 continue
941
942             if attr_set_name is None:
943                 attr_set_name = op['attribute-set']
944             if attr_set_name != op['attribute-set']:
945                 raise Exception('For a global policy all ops must use the same set')
946
947             for op_mode in ['do', 'dump']:
948                 if op_mode in op:
949                     global_set.update(op[op_mode].get('request', []))
950
951         self.global_policy = []
952         self.global_policy_set = attr_set_name
953         for attr in self.attr_sets[attr_set_name]:
954             if attr in global_set:
955                 self.global_policy.append(attr)
956
957     def _load_hooks(self):
958         for op in self.ops.values():
959             for op_mode in ['do', 'dump']:
960                 if op_mode not in op:
961                     continue
962                 for when in ['pre', 'post']:
963                     if when not in op[op_mode]:
964                         continue
965                     name = op[op_mode][when]
966                     if name in self.hooks[when][op_mode]['set']:
967                         continue
968                     self.hooks[when][op_mode]['set'].add(name)
969                     self.hooks[when][op_mode]['list'].append(name)
970
971     def has_notifications(self):
972         for op in self.ops.values():
973             if 'notify' in op or 'event' in op:
974                 return True
975         return False
976
977
class RenderInfo:
    """Everything needed to render one op in one mode ('do', 'dump', ...).

    Attributes:
        family: the family being rendered.
        nl: the netlink library helper taken from ``cw.nlib`` (provides
            get_family_id(), see BaseNlLib).
        ku_space: passed through from the caller (compared against 'user'
            when deciding whether to emit setters).
        op, op_name, op_mode: the op, its name and the mode being rendered.
            *op* may be falsy when rendering a bare attribute set.
        type_consistent: False when the op's 'do' and 'dump' declare
            different replies; op_prefix() then gives dump replies distinct
            type names.
        attr_set: attribute set name (explicit, or the op's 'attribute-set').
        type_name: C-ified op name, or the attribute-set name without an op.
        cw: the CodeWriter.
        struct: maps 'request' / 'reply' to the Struct for each direction
            present in the op's mode (events always get a 'reply').
    """

    def __init__(self, cw, family, ku_space, op, op_name, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op = op
        self.op_name = op_name
        self.op_mode = op_mode

        # 'do' and 'dump' response parsing is identical unless the two modes
        # declare different replies.  Check *op* first: it may be falsy (e.g.
        # None) when rendering a bare attribute set, and a membership test on
        # None would raise TypeError.
        self.type_consistent = True
        if op and op_mode != 'do' and 'dump' in op and 'do' in op:
            if ('reply' in op['do']) != ('reply' in op["dump"]):
                self.type_consistent = False
            elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
                self.type_consistent = False

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        if op:
            self.type_name = c_lower(op_name)
        else:
            self.type_name = c_lower(attr_set)

        self.cw = cw

        self.struct = dict()
        for op_dir in ['request', 'reply']:
            if op and op_dir in op[op_mode]:
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             type_list=op[op_mode][op_dir]['attributes'])
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1013
1014
class CodeWriter:
    """Write C source text to a file object, tracking indentation and braces.

    Lines go out through p(); block_start() / block_end() open and close
    brace-delimited blocks and adjust the tab indentation.  A bare closing
    brace is held back until the next line is written, so that a following
    ``else`` can be joined as ``} else``.
    """

    def __init__(self, nlib, out_file):
        # Netlink library helper; RenderInfo reads it as ``cw.nlib``.
        self.nlib = nlib

        self._nl = False            # emit a blank line before the next line
        self._block_end = False     # a '}' is pending (see block_end())
        self._silent_block = False  # last line was a brace-less if/while/for
        self._ind = 0               # current indentation depth, in tabs
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Any pending '}' is flushed first, merged with the line if it starts
        with ``else``.  A pending blank line is also flushed.  Lines ending in
        ':' (labels such as ``err_free:``) are outdented by one level.  The
        line after a brace-less ``if``/``while``/``for`` header is indented by
        one extra level.  *add_ind* adds further tabs.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith() instead of line[-1] so an empty line does not raise.
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write *line* followed by '{' and increase the indentation."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block.

        With *line* (e.g. ';' or a trailing name) the closing brace is
        written immediately as ``}<line>``.  Without it, the '}' is deferred
        so that an ``else`` written next lands on the same line.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write *doc* as ' *'-prefixed comment lines wrapped before 79 columns.

        Continuation lines get two extra spaces of indentation if *indent*.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping the arguments if too long.

        Args:
            qual_ret: return type with qualifiers, e.g. 'static inline int'.
            name: function name.
            args: parameter declarations; ``['void']`` if empty.
            doc: optional one-line comment written above the prototype.
            suffix: text appended after ')', e.g. ';' for a declaration.

        If the one-line form is 80 columns or longer, a return type longer
        than 3 characters goes on its own line.  Arguments are then wrapped
        and aligned under the opening parenthesis.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        *local_vars* may be a single string or a list of strings.  The
        caller's list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # Sort a copy, so callers that keep using their list are unaffected.
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, '{', locals, *body* lines, '}'."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        # Local variables belong inside the function body, after the '{'.
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)

        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write ``#define NAME value`` lines with tab-aligned values.

        *defines* is a sequence of (name, value) pairs.  int values are
        written as-is and str values in double quotes.  For any other type
        no value is written.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write designated initializers ``.name = value,`` with aligned '='.

        *members* is a sequence of (name, value-string) pairs; nothing is
        written when it is empty.
        """
        if not members:
            return
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1172
1173
# Integer attribute types; _multi_parse() reads multi-attr members of these
# types with mnl_attr_get_u<N>().
scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

# Struct / function name suffix for each message direction ('' is used for
# nested structs, which have no direction).
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Suffix of the reply wrapper type for each op mode (see op_prefix()).
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# C keywords.  Attribute names colliding with one of these get a trailing
# '_' in their C name (see Type.__init__).  Only lowercase spellings matter,
# since c_name is derived via c_lower().
_C_KW = {
    'auto',
    'bool',
    'break',
    'case',
    'char',
    'const',
    'continue',
    'default',
    'do',
    'double',
    'else',
    'enum',
    'extern',
    'float',
    'for',
    'goto',
    'if',
    'inline',
    'int',
    'long',
    'register',
    'restrict',  # C99 keyword, reserved in the C dialects the output targets
    'return',
    'short',
    'signed',
    'sizeof',
    'static',
    'struct',
    'switch',
    'typedef',
    'union',
    'unsigned',
    'void',
    'volatile',
    'while'
}
1225
1226
def rdir(direction):
    """Return the opposite message direction.

    'reply' and 'request' are swapped.  Any other value (e.g. '' for nested
    structs) is returned unchanged.
    """
    if direction not in ('reply', 'request'):
        return direction
    return 'request' if direction == 'reply' else 'reply'
1233
1234
def op_prefix(ri, direction, deref=False):
    """Return the C name prefix for the *direction* type of the op in *ri*.

    The prefix has the form ``<family>_<type_name>`` plus a suffix:
      * 'do' (or no mode): the direction suffix ('_req' / '_rsp' / '').
      * other modes, request: '_req_dump'.
      * other modes, reply, type-consistent: the wrapper suffix for the mode
        ('_list', '_ntf', ...), or the plain direction suffix with *deref*.
      * other modes, reply, inconsistent do/dump replies: '_rsp_dump' with
        *deref*, '_rsp_list' otherwise.
    """
    base = f"{ri.family['name']}_{ri.type_name}"

    if not ri.op_mode or ri.op_mode == 'do':
        return base + direction_to_suffix[direction]
    if direction == 'request':
        return base + '_req_dump'
    if not ri.type_consistent:
        return base + ('_rsp_dump' if deref else '_rsp_list')
    if deref:
        return base + direction_to_suffix[direction]
    return base + op_mode_to_wrapper[ri.op_mode]
1254
1255
def type_name(ri, direction, deref=False):
    """Return the C struct type (``struct <prefix>``) for *direction* of *ri*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1258
1259
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the user-facing request / dump function.

    The function takes the ynl socket and, if the op has request
    attributes, a pointer to the request struct.  It returns a pointer to
    the reply struct if the op has a reply, ``int`` otherwise.  Dumps get a
    '_dump' name suffix.  With *terminate* the prototype ends in ';'
    (a declaration); otherwise a body is expected to follow.
    """
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')
    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        param_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{param_name}")

    ret = f"{type_name(ri, rdir(direction))} *" if 'reply' in mode_spec else 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1276
1277
def print_req_prototype(ri):
    """Declare the 'do' request function, with the op's doc as comment."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1280
1281
def print_dump_prototype(ri):
    """Declare the dump request function (no doc comment)."""
    print_prototype(ri, direction="request")
1284
1285
def put_typol(cw, struct):
    """Write the ynl_policy_attr table and ynl_policy_nest for *struct*.

    The table has ``max_name + 1`` entries, each member writing its own
    entry via attr_typol().  The nest object points at the table.
    """
    name = struct.render_name
    max_attr = struct.attr_set.max_name

    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    for entry in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(entry)
    cw.block_end(line=';')
    cw.nl()
1301
1302
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Write ``<render_name>_str()``, which looks a value up in *map_name*.

    The parameter is typed ``enum <render_name>`` when *enum* is given and
    its 'enum-name' is not explicitly empty; otherwise it is ``int``.  For
    'flags' enums the value is first turned into a bit index with ffs().
    Out-of-range values return NULL.
    """
    typed = enum and not ('enum-name' in enum and not enum['enum-name'])
    param_type = f'enum {render_name}' if typed else 'int'
    cw.write_func_prot('const char *', f'{render_name}_str', [f'{param_type} {arg_name}'])

    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = [f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))',
            'return NULL;',
            f'return {map_name}[{arg_name}];']
    for line in body:
        cw.p(line)
    cw.block_end()
    cw.nl()
1316
1317
def put_op_name_fwd(family, cw):
    """Declare ``<family>_op_str()``, the op value to name lookup."""
    func_name = f'{family.name}_op_str'
    cw.write_func_prot('const char *', func_name, ['int op'], suffix=';')
1320
1321
def put_op_name(family, cw):
    """Write the op-name string table and ``<family>_op_str()``.

    Only messages with a (truthy) rsp_value are listed.  An entry is indexed
    by the op's enum name when request and reply values match, and by the
    reply value otherwise.
    """
    map_name = f'{family.name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.name + '_op', map_name, 'op')
1335
1336
def put_enum_to_str_fwd(family, cw, enum):
    """Declare ``<enum>_str()``; the parameter is ``int`` if 'enum-name' is empty."""
    unnamed = 'enum-name' in enum and not enum['enum-name']
    param = 'int value' if unnamed else f'enum {enum.render_name} value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str', [param], suffix=';')
1342
1343
def put_enum_to_str(family, cw, enum):
    """Write the string table for *enum* and its ``<enum>_str()`` helper.

    Entries are indexed by each entry's ``value``.
    """
    map_name = f'{enum.render_name}_strmap'
    table = (f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values())

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for row in table:
        cw.p(row)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1353
1354
def put_req_nested(ri, struct):
    """Write ``<struct>_put()``, which puts *struct* as a nested attribute.

    The generated function opens a nest of type ``attr_type``, puts each
    member of ``obj`` and closes the nest.  It always returns 0.
    """
    cw = ri.cw
    cw.write_func_prot('int', f'{struct.render_name}_put',
                       ['struct nlmsghdr *nlh',
                        'unsigned int attr_type',
                        f'{struct.ptr_name}obj'])
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1375
1376
def _multi_parse(ri, struct, init_lines, local_vars):
    """Write the body of a parse function for *struct*.

    The caller has already written the prototype.  The generated body:
      1. declares *local_vars* (plus the extras added here) and runs
         *init_lines*;
      2. copies inherited arguments into ``dst``;
      3. fails if ``dst`` already holds one of the arrays allocated below;
      4. walks all attributes once, calling each member's attr_get()
         (presumably this also counts array elements into the n_<member>
         counters declared here);
      5. allocates ``calloc(n_<member>, ...)`` arrays for 'array-nest' and
         'multi-attr' members and fills them in a second walk.

    Args:
        ri: RenderInfo of the op being rendered.
        struct: the Struct whose members are parsed; ``struct.nested``
            selects nested-attribute iteration and a ``0`` return,
            otherwise the genl message is walked and MNL_CB_OK returned.
        init_lines: C statements written right after the locals; extended
            in place.
        local_vars: C local declarations; extended in place.
    """
    if struct.nested:
        iter_line = "mnl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # Element counters for every array member.
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    # Refuse to parse into a dst that already holds an array allocated below.
    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        # Size the array by the element counter, not by the array pointer.
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        # The nested type of each element is passed as the inherited 'idx'.
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        ri.cw.p('return MNL_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return MNL_CB_ERROR;')
        elif aspec['type'] in scalars:
            t = aspec['type']
            # Signed values are read with the same-width unsigned getter.
            if t[0] == 's':
                t = 'u' + t[1:]
            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return MNL_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1483
1484
def parse_rsp_nested(ri, struct):
    """Write ``<struct>_parse()`` for a nested attribute set.

    Inherited members (such as the 'idx' of array-nest elements) are
    passed in as extra ``__u32`` parameters.
    """
    func_args = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    func_args.extend(f'__u32 {arg}' for arg in struct.inherited)

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)

    _multi_parse(ri, struct, [],
                 ['const struct nlattr *attr;', f'{struct.ptr_name}dst = yarg->data;'])
1498
1499
def parse_rsp_msg(ri, deref=False):
    """Write the top-level reply parse callback for the op in *ri*.

    Nothing is written for modes without a reply, except events (RenderInfo
    always builds a reply struct for them).
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    reply_type = type_name(ri, "reply", deref=deref)
    local_vars = [f'{reply_type} *dst;',
                  'struct ynl_parse_arg *yarg = data;',
                  'const struct nlattr *attr;']

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse',
                          ['const struct nlmsghdr *nlh', 'void *data'])

    _multi_parse(ri, ri.struct["reply"], ['dst = yarg->data;'], local_vars)
1515
1516
def print_req(ri):
    """Write the user-space implementation of a 'do' request.

    The generated function:
      1. starts the message and installs the request policy;
      2. puts all request attributes;
      3. runs the message with ynl_exec().
    For ops with a reply it also allocates the response and installs the
    parse callback.  It then returns the response, or NULL after freeing it
    on error.  Ops without a reply return 0 / -1.
    """
    direction = "request"
    cw = ri.cw
    has_reply = 'reply' in ri.op[ri.op_mode]

    local_vars = ['struct nlmsghdr *nlh;', 'int err;']
    if has_reply:
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;',
                       'struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };']

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()
    parse_arg = '&yrs' if has_reply else 'NULL'
    cw.p(f"err = ynl_exec(ys, nlh, {parse_arg});")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {'rsp' if has_reply else '0'};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(call_free(ri, rdir(direction), 'rsp'))
        cw.p("return NULL;")

    cw.block_end()
1572
1573
def print_dump(ri):
    """Write the user-space implementation of a dump request.

    The generated function fills a ynl_dump_state (reply size, parse
    callback, expected cmd and policy), starts a dump message, puts any
    request attributes and runs ynl_exec_dump().  It returns the head of
    the reply list, or NULL after freeing the partial list on error.
    """
    direction = "request"
    cw = ri.cw

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    # Written in this fixed order (write_func_lvar() would sort by length).
    for decl in ('struct ynl_dump_state yds = {};',
                 'struct nlmsghdr *nlh;',
                 'int err;'):
        cw.p(decl)
    cw.nl()

    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    setup = ['yds.ys = ys;',
             f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});",
             f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;",
             f'yds.rsp_cmd = {rsp_cmd};',
             f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;"]
    for stmt in setup:
        cw.p(stmt)
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    for stmt in ('err = ynl_exec_dump(ys, nlh, &yds);', 'if (err < 0)', 'goto free_list;'):
        cw.p(stmt)
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1615
1616
def call_free(ri, direction, var):
    """Return the C statement freeing *var* with the *direction* free function."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
1619
1620
def free_arg_name(direction):
    """Return the free-function parameter name: 'req' / 'rsp', or 'obj' if none."""
    return direction_to_suffix[direction][1:] if direction else 'obj'
1625
1626
def print_alloc_wrapper(ri, direction):
    """Write a static inline ``<prefix>_alloc()`` returning a zeroed struct."""
    struct_name = op_prefix(ri, direction)
    cw = ri.cw
    cw.write_func_prot(f'static inline struct {struct_name} *', f"{struct_name}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof(struct {struct_name}));')
    cw.block_end()
1633
1634
def print_free_prototype(ri, direction, suffix=';'):
    """Write the ``<prefix>_free()`` prototype (a declaration unless suffix='')."""
    struct_name = op_prefix(ri, direction)
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{struct_name}_free", [param], suffix=suffix)
1639
1640
def _print_type(ri, direction, struct):
    """Write the C struct definition for *struct*.

    Members appear in this order:
      * the members' presence fields ('len' and 'bit' kinds), grouped into
        a ``struct { ... } _present;`` member if there are any;
      * inherited ``__u32`` members;
      * the regular members.
    Dumps get a '_dump' name suffix.
    """
    cw = ri.cw
    name = f"{ri.family['name']}_{ri.type_name}{direction_to_suffix[direction]}"
    if ri.op_mode == 'dump':
        name += '_dump'
    cw.block_start(line=f"struct {name}")

    presence_open = False
    for _, attr in struct.member_list():
        for type_filter in ('len', 'bit'):
            field = attr.presence_member(ri.ku_space, type_filter)
            if not field:
                continue
            if not presence_open:
                cw.block_start(line="struct")
                presence_open = True
            cw.p(field)
    if presence_open:
        cw.block_end(line='_present;')
        cw.nl()

    for arg in struct.inherited:
        cw.p(f"__u32 {arg};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1670
1671
def print_type(ri, direction):
    """Write the struct for *direction* of the op being rendered."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1674
1675
def print_type_full(ri, struct):
    """Write *struct* with an empty direction suffix."""
    _print_type(ri, direction="", struct=struct)
1678
1679
def print_type_helpers(ri, direction, deref=False):
    """Write the free prototype and, for user-space requests, member setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, attr in ri.struct[direction].member_list():
            attr.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1688
1689
def print_req_type_helpers(ri):
    """Write the request alloc wrapper, free prototype and setters."""
    for emit in (print_alloc_wrapper, print_type_helpers):
        emit(ri, "request")
1693
1694
def print_rsp_type_helpers(ri):
    """Write the reply type helpers, if the op's mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1699
1700
def print_parse_prototype(ri, direction, terminate=True):
    """Write the ``<op>_{req,rsp}_parse()`` prototype taking a tb[] array.

    'reply' selects the '_rsp' struct; any other direction selects '_req'.
    """
    struct_name = ri.op.render_name + ("_rsp" if direction == "reply" else "_req")
    ri.cw.write_func_prot('void', f"{struct_name}_parse",
                          ['const struct nlattr **tb', f"struct {struct_name} *req"],
                          suffix=';' if terminate else '')
1709
1710
def print_req_type(ri):
    """Write the request struct definition."""
    print_type(ri, direction="request")
1713
1714
def print_req_free(ri):
    """Write the request free function, if the op's mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
1719
1720
def print_rsp_type(ri):
    """Write the reply struct for events and for do/dump ops with a reply."""
    if ri.op_mode == 'event' or (ri.op_mode in ('do', 'dump') and 'reply' in ri.op[ri.op_mode]):
        print_type(ri, 'reply')
1729
1730
def print_wrapped_type(ri):
    """Emit the wrapper struct around the reply object, plus its free() prototype.

    Dumps get a ``next`` pointer to the following wrapper. Notifications and
    events get ``family``/``cmd`` fields, a ``struct ynl_ntf_base_type *next``
    and a ``free`` callback. The reply itself is stored in ``obj``, aligned
    to 8 bytes.
    """
    cw = ri.cw
    wrapped = type_name(ri, 'reply')

    cw.block_start(line=f"{wrapped}")
    if ri.op_mode == 'dump':
        cw.p(f"{wrapped} *next;")
    elif ri.op_mode in ('notify', 'event'):
        header_fields = ('__u16 family;',
                         '__u8 cmd;',
                         'struct ynl_ntf_base_type *next;',
                         f"void (*free)({wrapped} *ntf);")
        for field in header_fields:
            cw.p(field)
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
1745
1746
1747 def _free_type_members_iter(ri, struct):
1748     for _, attr in struct.member_list():
1749         if attr.free_needs_iter():
1750             ri.cw.p('unsigned int i;')
1751             ri.cw.nl()
1752             break
1753
1754
1755 def _free_type_members(ri, var, struct, ref=''):
1756     for _, attr in struct.member_list():
1757         attr.free(ri, var, ref)
1758
1759
def _free_type(ri, direction, struct):
    """Emit a complete free() function for *struct*.

    The object itself is passed to free() only when *direction* is non-empty.
    With an empty direction (nested types), only the members are released.
    """
    cw = ri.cw
    var = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, var, struct)
    if direction:
        cw.p(f'free({var});')
    cw.block_end()
    cw.nl()
1771
1772
def free_rsp_nested(ri, struct):
    """Emit the member-only free helper for a nested struct."""
    no_direction = ""
    _free_type(ri, no_direction, struct)
1775
1776
def print_rsp_free(ri):
    """Emit the reply free() function when the current op mode has a reply."""
    mode_spec = ri.op[ri.op_mode]
    if 'reply' in mode_spec:
        _free_type(ri, 'reply', ri.struct['reply'])
1781
1782
def print_dump_type_free(ri):
    """Emit the free() function for a dump reply, which is a linked list.

    The generated C walks ``next`` until ``YNL_LIST_END``. For each node it
    releases the members (reached through ``obj.``) and then the node itself.
    """
    cw = ri.cw
    node_type = type_name(ri, 'reply')
    reply_struct = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply_struct)
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
1801
1802
def print_ntf_type_free(ri):
    """Emit the free() function for a notification/event wrapper object."""
    cw = ri.cw
    reply_struct = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply_struct)
    _free_type_members(ri, 'rsp', reply_struct, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1811
1812
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Write an nla_policy array declaration, or the opening of its definition.

    Args:
        cw: CodeWriter to emit into.
        struct: Struct whose attribute set sizes the array.
        ri: RenderInfo when the policy belongs to an operation; None for a
            policy named after *struct* itself.
        terminate: True emits an ``extern`` declaration ending in ';'.
            False emits the start of the definition (`` = {``).

    When the family struct is generated here, operation policies are
    ``static``, so no extern declaration is written for them.
    """
    local_family = kernel_can_gen_family_struct(struct.family)

    if terminate:
        if ri and local_family:
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if (local_family and ri) else ''
        suffix = ' = {'

    if ri:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name = f"{name}_{ri.op_mode}"
    else:
        name = struct.render_name

    max_attr = struct.attr_max_val
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1835
1836
def print_req_policy(cw, struct, ri=None):
    """Write a complete nla_policy array definition for *struct*."""
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for member in (attr for _, attr in struct.member_list()):
        member.attr_policy(cw)
    cw.p("};")
1842
1843
def kernel_can_gen_family_struct(family):
    """Return True when the genl_family struct is generated for *family*.

    Only plain 'genetlink' families qualify. For other protocols the ops
    table is exported (non-static) and no genl_family struct is emitted.
    """
    proto = family.proto
    return proto == 'genetlink'
1846
1847
def print_kernel_op_table_fwd(family, cw, terminate):
    """Write the ops-table declaration (header) or the opening of its definition (source).

    Args:
        family: parsed Family.
        cw: CodeWriter to emit into.
        terminate: True for the header. This emits an ``extern`` declaration
            (exported tables only), followed by prototypes of the pre/post
            hooks and of the doit/dumpit handlers of non-async ops.
            False for the source. This emits only the table's opening line,
            which print_kernel_op_table() then fills in.

    The table is exported (non-static, with an explicit size) when the
    family struct is not generated here; see kernel_can_gen_family_struct().
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        if not exported:
            cnt = ""
        else:
            # The explicit size must match the entries that
            # print_kernel_op_table() emits. That function skips async ops
            # (notifications/events), so they must not be counted here;
            # otherwise the array ends in zero-filled entries.
            sync_ops = [op for op in family.ops.values() if not op.is_async]
            if family.kernel_policy == 'split':
                # Split ops get one entry per do and one per dump.
                cnt = sum(('do' in op) + ('dump' in op) for op in sync_ops)
            else:
                cnt = len(sync_ops)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    # Prototypes of the user-supplied pre/post hooks.
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    # Prototypes of the doit/dumpit handlers the ops table points at.
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
1913
1914
def print_kernel_op_table_hdr(family, cw):
    """Emit the header side of the ops table: declarations and handler prototypes."""
    header_mode = True
    print_kernel_op_table_fwd(family, cw, terminate=header_mode)
1917
1918
def print_kernel_op_table(family, cw):
    """Emit the kernel ops table definition, one initializer per entry.

    'global' and 'per-op' policies produce one entry per non-async op.
    'split' produces one entry per (non-async op, do/dump mode) pair, with the
    mode's pre/post hooks and a 'cmd-cap-<mode>' flag.
    """
    # Opens the table (e.g. "... <family>_nl_ops[] =").
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            # Notifications/events have no handlers; they get no entry.
            if op.is_async:
                continue

            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): assumes every per-op op has a 'do' with a
                # 'request'; a dump-only op would raise KeyError here.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Field names used for the pre/post hooks, per mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    members.append(('validate',
                                    ' | '.join([c_upper('genl-dont-validate-' + x)
                                                for x in op['dont-validate']])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                # Member order is pre hook, handler, post hook.
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have separate do/dump policies.
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')

    # Closes the table opened by print_kernel_op_table_fwd().
    cw.block_end(line=';')
    cw.nl()
1985
1986
def print_kernel_mcgrp_hdr(family, cw):
    """Emit the enum of ``<FAMILY>_NLGRP_<NAME>`` multicast group ids, if any exist."""
    groups = family.mcgrps['list']
    if groups:
        cw.block_start('enum')
        for grp in groups:
            cw.p(c_upper(f"{family.name}-nlgrp-{grp['name']},"))
        cw.block_end(';')
        cw.nl()
1997
1998
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array, indexed by the NLGRP enum, if groups exist."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        grp_id = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{grp_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2010
2011
def print_kernel_family_struct_hdr(family, cw):
    """Declare the generated genl_family struct, for families that get one."""
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
2018
2019
def print_kernel_family_struct_src(family, cw):
    """Define the genl_family struct, for families that get one generated.

    The ops table is wired up as ``.ops`` for the 'per-op' policy and as
    ``.split_ops`` for 'split'. Multicast groups are added when the family
    has any.
    NOTE(review): no ops field is emitted for the 'global' policy; confirm
    whether that is intended.
    """
    if not kernel_can_gen_family_struct(family):
        return

    fam = family.name
    ops = f'{fam}_nl_ops'
    # (field name with its tab alignment, initializer value)
    fields = [
        ('.name\t\t', family.fam_key),
        ('.version\t', family.ver_key),
        ('.netnsok\t', 'true'),
        ('.parallel_ops\t', 'true'),
        ('.module\t\t', 'THIS_MODULE'),
    ]
    if family.kernel_policy == 'per-op':
        fields += [('.ops\t\t', ops), ('.n_ops\t\t', f'ARRAY_SIZE({ops})')]
    elif family.kernel_policy == 'split':
        fields += [('.split_ops\t', ops), ('.n_split_ops\t', f'ARRAY_SIZE({ops})')]
    if family.mcgrps['list']:
        grps = f'{fam}_nl_mcgrps'
        fields += [('.mcgrps\t\t', grps), ('.n_mcgrps\t', f'ARRAY_SIZE({grps})')]

    cw.block_start(f"struct genl_family {fam}_nl_family __ro_after_init =")
    for field, value in fields:
        cw.p(f'{field}= {value},')
    cw.block_end(';')
2040
2041
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI ``enum`` block, naming it when the spec provides a name.

    If *obj* has the *enum_name* key, that key decides the name; an empty
    value keeps the enum anonymous. Otherwise, if *ckey* is given and present
    in *obj*, the enum is named ``<family>_<value>``.
    """
    start_line = 'enum'
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            start_line = f'enum {c_lower(explicit)}'
    elif ckey and ckey in obj:
        start_line = f'enum {family.name}_{c_lower(obj[ckey])}'
    cw.block_start(line=start_line)
2050
2051
def render_uapi(family, cw):
    """Render the complete uAPI header for *family*.

    Emits, in order:
      - the header guard;
      - the family name/version defines;
      - the spec's ``definitions``: consts as ``#define``s, enums/flags as C
        enums with kdoc;
      - one attribute enum per attribute set that is not a subset;
      - the command enum, plus a separate async enum when
        ``async-prefix`` is set;
      - the multicast group name defines.
    """
    # c_upper() also maps '-' to '_'; a '-' is not valid in a macro name.
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive consts are batched into one block of #defines; any other
    # definition type flushes the pending batch first.
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # Each definition may name its own macro, just as each multicast
            # group can below. A family-level key is still honoured as a
            # fallback for backward compatibility.
            default_name = family.get('c-define-name', f"{family.name}-{const['name']}")
            defines.append([c_upper(const.get('c-define-name', default_name)),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        # Subsets reuse the enum of the set they are taken from.
        if attr_set.subset_of:
            continue

        # The count name is per attribute set; the family-level key is kept
        # as a fallback for backward compatibility.
        default_cnt = family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX")
        cnt_name = c_upper(attr_set.yaml.get('attr-cnt-name', default_cnt))
        max_value = f"({cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            # Spell out a value only where the numbering is not contiguous.
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        # With a separate async enum, notifications/events are listed there.
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2195
2196
def _render_user_ntf_entry(ri, op):
    """Emit one designated-initializer entry of the ynl_ntf_info table for *op*."""
    cw = ri.cw
    cw.block_start(line=f"[{op.enum_name}] = ")
    entry_fields = (
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    )
    for field in entry_fields:
        cw.p(field)
    cw.block_end(line=',')
2204
2205
def render_user_family(family, cw, prototype):
    """Emit the ``ynl_<family>_family`` descriptor, or only its extern declaration.

    Args:
        family: parsed Family.
        cw: CodeWriter to emit into.
        prototype: True to emit just the ``extern`` declaration (header).

    When the family has notifications, a ``<name>_ntf_info[]`` table is
    emitted first. It has an entry for every notify command and every event,
    and the descriptor references it.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    # Keep this flag distinct from the loop variables below. It decides,
    # at the end, whether the descriptor references the ntf_info table.
    has_ntf = family.has_notifications()
    if has_ntf:
        cw.block_start(line=f"static const struct ynl_ntf_info {family['name']}_ntf_info[] = ")
        for ntf_op in sorted(family.all_notify.keys()):
            op = family.ops[ntf_op]
            ri = RenderInfo(cw, family, "user", op, ntf_op, "notify")
            for ntf_cmd in op['notify']['cmds']:
                _render_user_ntf_entry(ri, ntf_cmd)
        for op_name, op in family.ops.items():
            if 'event' not in op:
                continue
            ri = RenderInfo(cw, family, "user", op, op_name, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.name}",')
    if has_ntf:
        cw.p(f".ntf_info\t= {family['name']}_ntf_info,")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({family['name']}_ntf_info),")
    cw.block_end(line=';')
2234
2235
def find_kernel_root(full_path):
    """Locate the kernel source root above *full_path*.

    Walks up the directory tree until it finds a directory containing a
    ``MAINTAINERS`` file.

    Args:
        full_path: path of a file inside the kernel tree (absolute or relative).

    Returns:
        ``(root, rel)``: the directory holding MAINTAINERS, and *full_path*
        relative to it.

    Raises:
        FileNotFoundError: when no ancestor contains MAINTAINERS. Previously
            this case looped forever, because ``dirname('/')`` and
            ``dirname('')`` return their argument unchanged.
    """
    start = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        maintainers = os.path.join(parent, "MAINTAINERS")
        if os.path.exists(maintainers):
            # Drop the trailing separator added by the join above.
            return parent, sub_path[:-1]
        if parent == full_path:
            raise FileNotFoundError(
                f"No MAINTAINERS file found in any parent directory of {start}")
        full_path = parent
2244
2245
def main():
    """Command-line entry point: parse arguments, load the spec, render the output.

    Arguments are validated and the spec is loaded before the output file is
    opened, so a failed run does not truncate an existing output file. The
    output file (if any) is always closed. Diagnostics go to stderr, because
    stdout may be the generated-code stream.
    """
    import sys

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('-o', dest='out_file', type=str)
    args = parser.parse_args()

    if args.header is None:
        parser.error("--header or --source is required")

    try:
        parsed = Family(args.spec)
    except yaml.YAMLError as exc:
        print(exc, file=sys.stderr)
        sys.exit(1)

    expected_license = '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)'
    if parsed.license != expected_license:
        print('Spec license:', parsed.license, file=sys.stderr)
        print('License must be: ' + expected_license, file=sys.stderr)
        sys.exit(1)

    supported_models = ['unified']
    if args.mode == 'user':
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation',
              file=sys.stderr)
        sys.exit(1)

    _, spec_kernel = find_kernel_root(args.spec)

    out_file = open(args.out_file, 'w+') if args.out_file else sys.stdout
    try:
        cw = CodeWriter(BaseNlLib(), out_file)
        _emit_file(args, parsed, spec_kernel, cw)
    finally:
        if out_file is not sys.stdout:
            out_file.close()


def _emit_file(args, parsed, spec_kernel, cw):
    """Write the whole generated file: banner, guard, includes and body."""
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    # c_upper() also maps '-' to '_'; a '-' is not valid in a macro name.
    hdr_prot = f"_LINUX_{c_upper(parsed.name)}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    _emit_includes(args, parsed, cw)

    if args.mode == "kernel":
        _emit_kernel_body(args, parsed, cw)
    if args.mode == "user":
        _emit_user_body(args, parsed, cw)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')


def _emit_includes(args, parsed, cw):
    """Emit the common #include lines for kernel or user-space output."""
    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # The source includes its own header: "<output stem>.h".
                stem = os.path.splitext(os.path.basename(args.out_file))[0]
                cw.p(f'#include "{stem}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()


def _emit_kernel_body(args, parsed, cw):
    """Emit kernel policies, ops table, multicast groups and the family struct.

    The header gets forward declarations; the source gets the definitions.
    """
    if args.header:
        for _, struct in sorted(parsed.pure_nested_structs.items()):
            if struct.request:
                cw.p('/* Common nested types */')
                break
        for _, struct in sorted(parsed.pure_nested_structs.items()):
            if struct.request:
                print_req_policy_fwd(cw, struct)
        cw.nl()

        if parsed.kernel_policy == 'global':
            cw.p(f"/* Global operation policy for {parsed.name} */")

            struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
            print_req_policy_fwd(cw, struct)
            cw.nl()

        if parsed.kernel_policy in {'per-op', 'split'}:
            for op_name, op in parsed.ops.items():
                if 'do' in op and 'event' not in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
                    print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                    cw.nl()

        print_kernel_op_table_hdr(parsed, cw)
        print_kernel_mcgrp_hdr(parsed, cw)
        print_kernel_family_struct_hdr(parsed, cw)
    else:
        for _, struct in sorted(parsed.pure_nested_structs.items()):
            if struct.request:
                cw.p('/* Common nested types */')
                break
        for _, struct in sorted(parsed.pure_nested_structs.items()):
            if struct.request:
                print_req_policy(cw, struct)
        cw.nl()

        if parsed.kernel_policy == 'global':
            cw.p(f"/* Global operation policy for {parsed.name} */")

            struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
            print_req_policy(cw, struct)
            cw.nl()

        for op_name, op in parsed.ops.items():
            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_mode in ['do', 'dump']:
                    if op_mode in op and 'request' in op[op_mode]:
                        cw.p(f"/* {op.enum_name} - {op_mode} */")
                        ri = RenderInfo(cw, parsed, args.mode, op, op_name, op_mode)
                        print_req_policy(cw, ri.struct['request'], ri=ri)
                        cw.nl()

        print_kernel_op_table(parsed, cw)
        print_kernel_mcgrp_src(parsed, cw)
        print_kernel_family_struct_src(parsed, cw)


def _emit_user_body(args, parsed, cw):
    """Emit the user-space (ynl library) header or source for *parsed*."""
    if not args.header:
        cw.p("#include <libmnl/libmnl.h>")
        cw.p("#include <linux/genetlink.h>")
        cw.nl()
        for one in args.user_header:
            cw.p(f'#include "{one}"')
    else:
        cw.p('struct ynl_sock;')
        cw.nl()
        render_user_family(parsed, cw, True)
    cw.nl()

    if args.header:
        cw.p('/* Enums */')
        put_op_name_fwd(parsed, cw)

        for _, const in parsed.consts.items():
            if isinstance(const, EnumSet):
                put_enum_to_str_fwd(parsed, cw, const)
        cw.nl()

        cw.p('/* Common nested types */')
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, args.mode, "", "", "", attr_set)
            print_type_full(ri, struct)

        for op_name, op in parsed.ops.items():
            cw.p(f"/* ============== {op.enum_name} ============== */")

            if 'do' in op and 'event' not in op:
                cw.p(f"/* {op.enum_name} - do */")
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
                print_req_type(ri)
                print_req_type_helpers(ri)
                cw.nl()
                print_rsp_type(ri)
                print_rsp_type_helpers(ri)
                cw.nl()
                print_req_prototype(ri)
                cw.nl()

            if 'dump' in op:
                cw.p(f"/* {op.enum_name} - dump */")
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'dump')
                if 'request' in op['dump']:
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                if not ri.type_consistent:
                    print_rsp_type(ri)
                print_wrapped_type(ri)
                print_dump_prototype(ri)
                cw.nl()

            if 'notify' in op:
                cw.p(f"/* {op.enum_name} - notify */")
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'notify')
                if not ri.type_consistent:
                    raise Exception(f'Only notifications with consistent types supported ({op.name})')
                print_wrapped_type(ri)

            if 'event' in op:
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'event')
                cw.p(f"/* {op.enum_name} - event */")
                print_rsp_type(ri)
                cw.nl()
                print_wrapped_type(ri)
        cw.nl()
    else:
        cw.p('/* Enums */')
        put_op_name(parsed, cw)

        for _, const in parsed.consts.items():
            if isinstance(const, EnumSet):
                put_enum_to_str(parsed, cw, const)
        cw.nl()

        cw.p('/* Policies */')
        for name in parsed.pure_nested_structs:
            struct = Struct(parsed, name)
            put_typol(cw, struct)
        for name in parsed.root_sets:
            struct = Struct(parsed, name)
            put_typol(cw, struct)

        cw.p('/* Common nested types */')
        for attr_set, struct in parsed.pure_nested_structs.items():
            ri = RenderInfo(cw, parsed, args.mode, "", "", "", attr_set)

            free_rsp_nested(ri, struct)
            if struct.request:
                put_req_nested(ri, struct)
            if struct.reply:
                parse_rsp_nested(ri, struct)

        for op_name, op in parsed.ops.items():
            cw.p(f"/* ============== {op.enum_name} ============== */")
            if 'do' in op and 'event' not in op:
                cw.p(f"/* {op.enum_name} - do */")
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
                print_req_free(ri)
                print_rsp_free(ri)
                parse_rsp_msg(ri)
                print_req(ri)
                cw.nl()

            if 'dump' in op:
                cw.p(f"/* {op.enum_name} - dump */")
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, "dump")
                if not ri.type_consistent:
                    parse_rsp_msg(ri, deref=True)
                print_dump_type_free(ri)
                print_dump(ri)
                cw.nl()

            if 'notify' in op:
                cw.p(f"/* {op.enum_name} - notify */")
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'notify')
                if not ri.type_consistent:
                    raise Exception(f'Only notifications with consistent types supported ({op.name})')
                print_ntf_type_free(ri)

            if 'event' in op:
                cw.p(f"/* {op.enum_name} - event */")

                # Events are parsed with the "do" reply parser.
                ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
                parse_rsp_msg(ri)

                ri = RenderInfo(cw, parsed, args.mode, op, op_name, "event")
                print_ntf_type_free(ri)
        cw.nl()
        render_user_family(parsed, cw, False)
2520
2521
if __name__ == '__main__':
    # Run the generator only when executed as a script, not when imported.
    main()