tools: ynl-gen: fill in support for MultiAttr scalars
[platform/kernel/linux-rpi.git] / tools / net / ynl / ynl-gen-c.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4 import argparse
5 import collections
6 import os
7 import yaml
8
9 from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
10
11
def c_upper(name):
    """Return *name* as an upper-case C identifier, mapping '-' to '_'."""
    return '_'.join(name.upper().split('-'))
14
15
def c_lower(name):
    """Return *name* as a lower-case C identifier, mapping '-' to '_'."""
    return '_'.join(name.lower().split('-'))
18
19
class BaseNlLib:
    """Render the libmnl-specific C snippets used by the generated code."""

    def get_family_id(self):
        """Return the C expression which holds the resolved family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return a ``mnl_cb_run2()`` call expression as C source text.

        Args:
            cb: name of the C callback to pass.
            data: C expression for the callback's data pointer.
            is_dump: dumps pass 0/0 instead of ys->seq/ys->portid.
            indent: extra tab depth used for the continuation lines.
        """
        sep = '\n\t\t' + '\t' * indent + ' '
        if not is_dump:
            return ("mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,"
                    f"{sep}{cb}, {data},{sep}ynl_cb_array, NLMSG_MIN_TYPE)")
        return (f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},"
                f"{sep}ynl_cb_array, NLMSG_MIN_TYPE)")
31
32
class Type(SpecAttr):
    """Base class for one attribute as rendered by the C code generator.

    Holds the naming and nesting information read from the spec and the
    shared rendering logic.  Type-specific behaviour comes from sub-classes
    overriding the ``_attr_get()``, ``_attr_typol()``, ``_setter_lines()``,
    ``attr_put()`` and ``_complex_member_type()`` hooks; the defaults here
    either raise or return None.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            nested = attr['nested-attributes']
            self.nested_attrs = nested
            # A nest of the family-named set is rendered without a suffix
            suffix = '' if nested == family.name else f"_{c_lower(nested)}"
            self.nested_render_name = f"{family.name}{suffix}"

        lowered = c_lower(self.name)
        # Don't let an attribute name collide with a C keyword
        self.c_name = lowered + '_' if lowered in _C_KW else lowered

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute the upper-case C enum name of the attribute."""
        self.enum_name = c_upper(f"{self.attr_set.name_prefix}{self.name}")

    def is_multi_val(self):
        """Return True for attributes rendered as arrays (None by default)."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """Return how presence is tracked: 'bit', 'len', 'count' or ''."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` struct member line, or None.

        Only emitted when the presence type equals *type_filter*; *space*
        'user' selects the ``__u32`` spelling.
        """
        kind = self.presence_type()
        if kind != type_filter:
            return None
        prefix = '__' if space == 'user' else ''
        if kind == 'bit':
            return f"{prefix}u32 {self.c_name}:1;"
        if kind == 'len':
            return f"{prefix}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        return None

    def free_needs_iter(self):
        return False

    def free(self, ri, var, ref):
        """Emit the free() for heap-backed members (arrays and 'len' types)."""
        needs_free = self.is_multi_val() or self.presence_type() == 'len'
        if needs_free:
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C argument declarations used by the setter."""
        member = self._complex_member_type(ri)
        if not member:
            raise Exception(f"Struct member not implemented for class type {self.type}")
        args = [f"{member} *{self.c_name}"]
        if self.presence_type() == 'count':
            args.append(f"unsigned int n_{self.c_name}")
        return args

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attribute."""
        multi = self.is_multi_val()
        if multi:
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if not member:
            # Simple types are declared the same way as setter arguments
            for decl in self.arg_member(ri):
                ri.cw.p(decl + ';')
            return
        ptr = '*' if multi else ''
        ri.cw.p(f"{member} {ptr}{self.c_name};")

    def _attr_policy(self, policy):
        return f"{{ .type = {policy}, }}"

    def attr_policy(self, cw):
        """Emit this attribute's entry of the kernel policy table."""
        spec = self._attr_policy(c_upper('nla-' + self.attr['type']))
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry of the user-space type table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        # Guard the put with the presence tracking, when there is any
        kind = self.presence_type()
        if kind == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif kind == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        call = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, call)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the parse branch for this attribute; returns True.

        *first* selects ``if`` rather than ``else if`` for the branch.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if type(lines) is str:
            lines = [lines]
        if type(init_lines) is str:
            init_lines = [init_lines]

        keyword = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{keyword} (mnl_attr_get_type(attr) == {self.enum_name})")
        if local_vars:
            for decl in local_vars:
                ri.cw.p(decl)
            ri.cw.nl()

        # Multi-value attrs are only counted here, validation happens elsewhere
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for stmt in init_lines:
                ri.cw.p(stmt)

        for stmt in lines:
            ri.cw.p(stmt)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter for this attribute.

        *ref* is the list of enclosing member names for attributes inside
        nests.  Setters which free but never allocate get a ``__`` prefix.
        """
        path = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(path)}"

        code = []
        presence = ''
        for depth, name in enumerate(path):
            parent = '.'.join(path[:depth] + [''])
            presence = f"{var}->{parent}_present.{name}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(path)}"
        frees = any('free(' in stmt for stmt in code)
        allocs = any('alloc(' in stmt for stmt in code)
        if frees and not allocs:
            func_name = '__' + func_name
        args = [f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri)
        ri.cw.write_func('static inline void', func_name, body=code, args=args)
201
202
class TypeUnused(Type):
    """Reserved attribute id; receiving it is treated as a parse error."""

    def _attr_typol(self):
        return '.type = YNL_PT_REJECT, '

    def presence_type(self):
        # No presence tracking for an attribute which is never stored
        return ''

    def arg_member(self, ri):
        return []

    def attr_policy(self, cw):
        """Emit nothing: unused ids get no kernel policy entry."""

    def _attr_get(self, ri, var):
        get_lines = ['return MNL_CB_ERROR;']
        return get_lines, None, None
217         pass
218
219
class TypePad(Type):
    """Padding attribute: no parse branch and no kernel policy entry."""

    def _attr_typol(self):
        return '.type = YNL_PT_IGNORE, '

    def presence_type(self):
        return ''

    def arg_member(self, ri):
        return []

    def attr_get(self, ri, var, first):
        """Emit no parsing code for padding."""

    def attr_policy(self, cw):
        """Emit nothing: padding has no kernel policy entry."""
234         pass
235
236
class TypeScalar(Type):
    """Fixed-width integer attribute (u8/u16/u32/u64/s32/s64)."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.byte_order_comment = f" /* {attr['byte-order']} */" if 'byte-order' in attr else ''

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Work out whether the value is a flag set and its C type name."""
        self.resolve_up(super())

        has_enum = 'enum' in self.attr
        if self.attr.get('enum-as-flags'):
            self.is_bitfield = True
        else:
            self.is_bitfield = has_enum and self.family.consts[self.attr['enum']]['type'] == 'flags'

        # Plain enums get their enum type, everything else the raw __uN/__sN
        if has_enum and not self.is_bitfield:
            self.type_name = f"enum {self.family.name}_{c_lower(self.attr['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        # mnl does not have a helper for signed types, reuse the unsigned one
        return 'u' + self.type[1:] if self.type[0] == 's' else self.type

    def _attr_policy(self, policy):
        if self.is_bitfield:
            mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'flags-mask' in self.checks:
            flag_cnt = len(self.family.consts[self.checks['flags-mask']]['entries'])
            # NOTE(review): assumes the flags occupy the lowest flag_cnt bits
            return f"NLA_POLICY_MASK({policy}, 0x{(1 << flag_cnt) - 1:x})"
        if 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
        if 'enum' in self.attr:
            entry_cnt = len(self.family.consts[self.attr['enum']]['entries'])
            # NOTE(review): assumes enum values are contiguous from 0
            return f"NLA_POLICY_MAX({policy}, {entry_cnt - 1})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        # Signed types share the unsigned type-table entry of the same width
        width = self.type[1:]
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        decl = f'{self.type_name} {self.c_name}{self.byte_order_comment}'
        return [decl]

    def attr_put(self, ri, var):
        put_type = self._mnl_type()
        self._attr_put_simple(ri, var, put_type)

    def _attr_get(self, ri, var):
        getter = f"mnl_attr_get_{self._mnl_type()}(attr)"
        return f"{var}->{self.c_name} = {getter};", None, None

    def _setter_lines(self, ri, member, presence):
        assignment = f"{member} = {self.c_name};"
        return [assignment]
304         return [f"{member} = {self.c_name};"]
305
306
class TypeFlag(Type):
    """Flag attribute: presence only, no payload."""

    def _attr_typol(self):
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        # The setter takes no value, calling it marks the flag present
        return []

    def attr_put(self, ri, var):
        put = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, put)

    def _attr_get(self, ri, var):
        # Nothing to copy, the generic attr_get() sets the presence bit
        no_lines = []
        return no_lines, None, None

    def _setter_lines(self, ri, member, presence):
        return []
321         return []
322
323
class TypeString(Type):
    """String attribute, stored as a heap copy with its length as presence."""

    def presence_type(self):
        return 'len'

    def arg_member(self, ri):
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        fields = ['{ .type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + str(self.checks['max-len']))
        return ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        # 'unterminated-ok' selects NLA_STRING instead of NLA_NUL_STRING
        if self.checks.get('unterminated-ok', False):
            policy = 'NLA_STRING'
        else:
            policy = 'NLA_NUL_STRING'
        cw.p(f"\t[{self.enum_name}] = {self._attr_policy(policy)},")

    def attr_put(self, ri, var):
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        # Copy at most the payload length and always NUL-terminate the copy
        dst = f"{var}->{self.c_name}"
        get_lines = [f"{var}->_present.{self.c_name}_len = len;",
                     f"{dst} = malloc(len + 1);",
                     f"memcpy({dst}, mnl_attr_get_str(attr), len);",
                     f"{dst}[len] = 0;"]
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        local_vars = ['unsigned int len;']
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        length = f"{presence}_len"
        return [f"free({member});",
                f"{length} = strlen({self.c_name});",
                f"{member} = malloc({length} + 1);",
                f'memcpy({member}, {self.c_name}, {length});',
                f'{member}[{length}] = 0;']
370                 f'{member}[{presence}_len] = 0;']
371
372
class TypeBinary(Type):
    """Opaque binary attribute, stored as a heap copy with its length as presence."""

    def arg_member(self, ri):
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        return 'len'

    def struct_member(self, ri):
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        # Trailing space keeps the rendered entry formatted like other types
        return '.type = YNL_PT_BINARY, '

    def _attr_policy(self, policy):
        """Return the kernel policy entry; only a lone 'min-len' check is supported.

        Raises:
            Exception: if any other combination of checks is present.
        """
        mem = '{ '
        if len(self.checks) == 1 and 'min-len' in self.checks:
            mem += '.len = ' + str(self.checks['min-len'])
        elif len(self.checks) == 0:
            mem += '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        mem += ', }'
        return mem

    def attr_put(self, ri, var):
        # The recorded length is both the presence test and the payload size
        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")

    def _attr_get(self, ri, var):
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \
               ['len = mnl_attr_get_payload_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        # Store the caller's length first: the malloc()/memcpy() below and
        # attr_put() all size the blob by <name>_len, which would otherwise
        # still hold the previous (possibly zero) length.
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
412                 f'memcpy({member}, {self.c_name}, {presence}_len);']
413
414
class TypeNest(Type):
    """Single nested attribute, stored as an embedded struct."""

    def _complex_member_type(self, ri):
        return f"struct {self.nested_render_name}"

    def free(self, ri, var, ref):
        # Delegate to the nest's own free helper
        ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name});')

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        # The generic policy type is unused, nests point at their own table
        return f'NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)'

    def attr_put(self, ri, var):
        call = f"{self.nested_render_name}_put(nlh, {self.enum_name}, &{var}->{self.c_name})"
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
                     "return MNL_CB_ERROR;"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit one setter per member of the nest, prefixed with this nest's path."""
        path = (ref if ref else []) + [self.c_name]

        nest = ri.family.pure_nested_structs[self.nested_attrs]
        for _name, member in nest.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
443             attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=ref)
444
445
class TypeMultiAttr(Type):
    """Attribute which may repeat within one message.

    Rendered as a C array plus an ``n_<name>`` element count.  Elements are
    nests (when the spec type is 'nest' or absent) or scalars.
    """

    def _elem_is_nest(self):
        """Return True if the elements are nests (the default with no type)."""
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def _elem_is_scalar(self):
        """Return True if the elements are scalars; safe when 'type' is absent."""
        return 'type' in self.attr and self.attr['type'] in scalars

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _mnl_type(self):
        t = self.type
        # mnl does not have a helper for signed types
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _complex_member_type(self, ri):
        if self._elem_is_nest():
            return f"struct {self.nested_render_name}"
        if self._elem_is_scalar():
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        # Only nests need a loop (and an 'i' variable) to free each element
        return self._elem_is_nest()

    def free(self, ri, var, ref):
        # Check for nests first so a missing 'type' is handled like in the
        # other methods instead of raising KeyError.
        if self._elem_is_nest():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self._elem_is_scalar():
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_typol(self):
        if self._elem_is_nest():
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        if self._elem_is_scalar():
            return f".type = YNL_PT_U{self.attr['type'][1:]}, "
        raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def _attr_get(self, ri, var):
        # Only count occurrences here; the generic attr_get() skips validation
        return f'{var}->n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        # Emit one attribute per array element
        if self._elem_is_nest():
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self._elem_is_scalar():
            put_type = self._mnl_type()
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        # For multi-attr we have a count, not presence, hack up the presence
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
510                 f"{presence} = n_{self.c_name};"]
511
512
class TypeArrayNest(Type):
    """Nest whose sub-attributes form an array of 'sub-type' elements."""

    def is_multi_val(self):
        return True

    def presence_type(self):
        return 'count'

    def _complex_member_type(self, ri):
        # Missing sub-type defaults to nested structs
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return f"struct {self.nested_render_name}"
        if sub_type in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        # Remember the container and count its entries
        get_lines = [f'attr_{self.c_name} = attr;',
                     'mnl_attr_for_each_nested(attr2, attr)',
                     f'\t{var}->n_{self.c_name}++;']
        return get_lines, None, ['const struct nlattr *attr2;']
537         return get_lines, None, local_vars
538
539
class TypeNestTypeValue(Type):
    """Nest whose enclosing attribute types carry values ('type-value')."""

    def _complex_member_type(self, ri):
        return f"struct {self.nested_render_name}"

    def _attr_typol(self):
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []
        prev = 'attr'
        tv_args = ''
        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append(f'const struct nlattr *attr_{", *attr_".join(levels)};')
            local_vars.append(f'__u32 {", ".join(levels)};')
            # Descend one level per entry, capturing each level's attr type
            for level in levels:
                get_lines.append(f'attr_{level} = mnl_attr_get_payload({prev});')
                get_lines.append(f'{level} = mnl_attr_get_type(attr_{level});')
                prev = 'attr_' + level
            tv_args = f", {', '.join(levels)}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {prev}{tv_args});")
        return get_lines, init_lines, local_vars
567         return get_lines, init_lines, local_vars
568
569
class Struct:
    """A C struct rendered for an attribute set, or a selection from it.

    Args:
        family: the owning family; provides ``attr_sets`` and ``name``.
        space_name: name of the attribute set backing the struct.
        type_list: attribute names to include, in order.  ``None`` means the
            whole set (the struct is then a pure nest); an empty list yields
            a struct without members.
        inherited: attribute names expected to be inherited;
            ``set_inherited()`` must later be called with the same collection.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = inherited if inherited is not None else []
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            self.render_name = f"{family.name}"
        else:
            self.render_name = f"{family.name}_{c_lower(space_name)}"
        self.struct_name = 'struct ' + self.render_name
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        # Only None selects the whole set; an explicitly empty attribute list
        # (e.g. an op with 'attributes: []') must not pull in every attribute.
        names = self.attr_set if type_list is None else type_list
        self.attr_list = [(t, self.attr_set[t]) for t in names]
        self.attrs = dict()

        max_val = 0
        self.attr_max_val = None
        for name, attr in self.attr_list:
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        yield from self.attrs

    def __getitem__(self, key):
        return self.attrs[key]

    def member_list(self):
        """Return the (name, attr) pairs in declaration order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record inherited members; they must match those given at construction.

        Raises:
            Exception: if *new_inherited* differs from the constructor value.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
619         self.inherited = [c_lower(x) for x in sorted(self._inherited)]
620
621
class EnumEntry(SpecEnumEntry):
    """Enum entry with its C name and whether it needs an explicit value."""

    def __init__(self, enum_set, yaml, prev, value_start):
        super().__init__(enum_set, yaml, prev, value_start)

        # Explicit value needed when not prev + 1 (or 0 for the first entry);
        # flag sets always render explicit values.
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
639         self.c_name = c_upper(self.enum_set.value_pfx + self.name)
640
641
class EnumSet(SpecEnumSet):
    """Enum/flags definition with its C naming resolved."""

    def __init__(self, family, yaml):
        # Naming must be set before the base class creates the entries
        self.render_name = c_lower(f"{family.name}-{yaml['name']}")
        self.enum_name = 'enum ' + self.render_name

        default_pfx = f"{family.name}-{yaml['name']}-"
        self.value_pfx = yaml.get('name-prefix', default_pfx)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        """Factory hook: wrap each spec entry in an ``EnumEntry``."""
        return EnumEntry(self, entry, prev_entry, value_start)
652         return EnumEntry(self, entry, prev_entry, value_start)
653
654
class AttrSet(SpecAttrSet):
    """Attribute set with C naming (enum prefix, max name) resolved."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets reuse the naming of the set they are carved out of
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        lowered = c_lower(self.name)
        if lowered in _C_KW:
            lowered += '_'
        # The set named after the family gets an empty C name
        self.c_name = '' if lowered == self.family.c_name else lowered

    def new_attr(self, elem, value):
        """Factory hook: pick the ``Type`` sub-class for one attribute spec.

        Raises:
            Exception: for an attribute type with no rendering class.
        """
        if 'multi-attr' in elem and elem['multi-attr']:
            return TypeMultiAttr(self.family, self, elem, value)
        if elem['type'] in scalars:
            return TypeScalar(self.family, self, elem, value)

        by_type = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'nest': TypeNest,
            'array-nest': TypeArrayNest,
            'nest-type-value': TypeNestTypeValue,
        }
        cls = by_type.get(elem['type'])
        if cls is None:
            raise Exception(f"No typed class for type {elem['type']}")
        return cls(self.family, self, elem, value)
705             raise Exception(f"No typed class for type {elem['type']}")
706
707
class Operation(SpecOperation):
    """Operation with its C render and enum names resolved."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        if req_value != rsp_value:
            raise Exception("Directional messages not supported by codegen")

        self.render_name = family.name + '_' + c_lower(self.name)

        # True when both 'do' and 'dump' define a request
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        self.resolve_up(super())

        # Async (notification/event) ops may use a different command prefix
        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def add_notification(self, op):
        """Record *op* as a notification sent with this op's 'do' reply."""
        if 'notify' not in self.yaml:
            notify = self.yaml['notify'] = dict()
            notify['reply'] = self.yaml['do']['reply']
            notify['cmds'] = []
        self.yaml['notify']['cmds'].append(op)
737         self.yaml['notify']['cmds'].append(op)
738
739
740 class Family(SpecFamily):
741     def __init__(self, file_name):
742         # Added by resolve:
743         self.c_name = None
744         delattr(self, "c_name")
745         self.op_prefix = None
746         delattr(self, "op_prefix")
747         self.async_op_prefix = None
748         delattr(self, "async_op_prefix")
749         self.mcgrps = None
750         delattr(self, "mcgrps")
751         self.consts = None
752         delattr(self, "consts")
753         self.hooks = None
754         delattr(self, "hooks")
755
756         super().__init__(file_name)
757
758         self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
759         self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
760
761         if 'definitions' not in self.yaml:
762             self.yaml['definitions'] = []
763
764         if 'uapi-header' in self.yaml:
765             self.uapi_header = self.yaml['uapi-header']
766         else:
767             self.uapi_header = f"linux/{self.name}.h"
768
769     def resolve(self):
770         self.resolve_up(super())
771
772         if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
773             raise Exception("Codegen only supported for genetlink")
774
775         self.c_name = c_lower(self.name)
776         if 'name-prefix' in self.yaml['operations']:
777             self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
778         else:
779             self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
780         if 'async-prefix' in self.yaml['operations']:
781             self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
782         else:
783             self.async_op_prefix = self.op_prefix
784
785         self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
786
787         self.hooks = dict()
788         for when in ['pre', 'post']:
789             self.hooks[when] = dict()
790             for op_mode in ['do', 'dump']:
791                 self.hooks[when][op_mode] = dict()
792                 self.hooks[when][op_mode]['set'] = set()
793                 self.hooks[when][op_mode]['list'] = []
794
795         # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
796         self.root_sets = dict()
797         # dict space-name -> set('request', 'reply')
798         self.pure_nested_structs = dict()
799         self.all_notify = dict()
800
801         self._mock_up_events()
802
803         self._dictify()
804         self._load_root_sets()
805         self._load_nested_sets()
806         self._load_all_notify()
807         self._load_hooks()
808
809         self.kernel_policy = self.yaml.get('kernel-policy', 'split')
810         if self.kernel_policy == 'global':
811             self._load_global_policy()
812
813     def new_enum(self, elem):
814         return EnumSet(self, elem)
815
816     def new_attr_set(self, elem):
817         return AttrSet(self, elem)
818
819     def new_operation(self, elem, req_value, rsp_value):
820         return Operation(self, elem, req_value, rsp_value)
821
822     # Fake a 'do' equivalent of all events, so that we can render their response parsing
823     def _mock_up_events(self):
824         for op in self.yaml['operations']['list']:
825             if 'event' in op:
826                 op['do'] = {
827                     'reply': {
828                         'attributes': op['event']['attributes']
829                     }
830                 }
831
832     def _dictify(self):
833         ntf = []
834         for msg in self.msgs.values():
835             if 'notify' in msg:
836                 ntf.append(msg)
837         for n in ntf:
838             self.ops[n['notify']].add_notification(n)
839
840     def _load_root_sets(self):
841         for op_name, op in self.ops.items():
842             if 'attribute-set' not in op:
843                 continue
844
845             req_attrs = set()
846             rsp_attrs = set()
847             for op_mode in ['do', 'dump']:
848                 if op_mode in op and 'request' in op[op_mode]:
849                     req_attrs.update(set(op[op_mode]['request']['attributes']))
850                 if op_mode in op and 'reply' in op[op_mode]:
851                     rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
852
853             if op['attribute-set'] not in self.root_sets:
854                 self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
855             else:
856                 self.root_sets[op['attribute-set']]['request'].update(req_attrs)
857                 self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
858
859     def _load_nested_sets(self):
860         for root_set, rs_members in self.root_sets.items():
861             for attr, spec in self.attr_sets[root_set].items():
862                 if 'nested-attributes' in spec:
863                     inherit = set()
864                     nested = spec['nested-attributes']
865                     if nested not in self.root_sets:
866                         if nested not in self.pure_nested_structs:
867                             self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
868                     if attr in rs_members['request']:
869                         self.pure_nested_structs[nested].request = True
870                     if attr in rs_members['reply']:
871                         self.pure_nested_structs[nested].reply = True
872
873                     if 'type-value' in spec:
874                         if nested in self.root_sets:
875                             raise Exception("Inheriting members to a space used as root not supported")
876                         inherit.update(set(spec['type-value']))
877                     elif spec['type'] == 'array-nest':
878                         inherit.add('idx')
879                     self.pure_nested_structs[nested].set_inherited(inherit)
880
881     def _load_all_notify(self):
882         for op_name, op in self.ops.items():
883             if not op:
884                 continue
885
886             if 'notify' in op:
887                 self.all_notify[op_name] = op['notify']['cmds']
888
889     def _load_global_policy(self):
890         global_set = set()
891         attr_set_name = None
892         for op_name, op in self.ops.items():
893             if not op:
894                 continue
895             if 'attribute-set' not in op:
896                 continue
897
898             if attr_set_name is None:
899                 attr_set_name = op['attribute-set']
900             if attr_set_name != op['attribute-set']:
901                 raise Exception('For a global policy all ops must use the same set')
902
903             for op_mode in ['do', 'dump']:
904                 if op_mode in op:
905                     global_set.update(op[op_mode].get('request', []))
906
907         self.global_policy = []
908         self.global_policy_set = attr_set_name
909         for attr in self.attr_sets[attr_set_name]:
910             if attr in global_set:
911                 self.global_policy.append(attr)
912
913     def _load_hooks(self):
914         for op in self.ops.values():
915             for op_mode in ['do', 'dump']:
916                 if op_mode not in op:
917                     continue
918                 for when in ['pre', 'post']:
919                     if when not in op[op_mode]:
920                         continue
921                     name = op[op_mode][when]
922                     if name in self.hooks[when][op_mode]['set']:
923                         continue
924                     self.hooks[when][op_mode]['set'].add(name)
925                     self.hooks[when][op_mode]['list'].append(name)
926
927     def has_notifications(self):
928         for op in self.ops.values():
929             if 'notify' in op or 'event' in op:
930                 return True
931         return False
932
933
934 class RenderInfo:
935     def __init__(self, cw, family, ku_space, op, op_name, op_mode, attr_set=None):
936         self.family = family
937         self.nl = cw.nlib
938         self.ku_space = ku_space
939         self.op = op
940         self.op_name = op_name
941         self.op_mode = op_mode
942
943         # 'do' and 'dump' response parsing is identical
944         self.type_consistent = True
945         if op_mode != 'do' and 'dump' in op and 'do' in op:
946             if ('reply' in op['do']) != ('reply' in op["dump"]):
947                 self.type_consistent = False
948             elif 'reply' in op['do'] and op["do"]["reply"] != op["dump"]["reply"]:
949                 self.type_consistent = False
950
951         self.attr_set = attr_set
952         if not self.attr_set:
953             self.attr_set = op['attribute-set']
954
955         if op:
956             self.type_name = c_lower(op_name)
957         else:
958             self.type_name = c_lower(attr_set)
959
960         self.cw = cw
961
962         self.struct = dict()
963         for op_dir in ['request', 'reply']:
964             if op and op_dir in op[op_mode]:
965                 self.struct[op_dir] = Struct(family, self.attr_set,
966                                              type_list=op[op_mode][op_dir]['attributes'])
967         if op_mode == 'event':
968             self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
969
970
971 class CodeWriter:
972     def __init__(self, nlib, out_file):
973         self.nlib = nlib
974
975         self._nl = False
976         self._silent_block = False
977         self._ind = 0
978         self._out = out_file
979
980     @classmethod
981     def _is_cond(cls, line):
982         return line.startswith('if') or line.startswith('while') or line.startswith('for')
983
984     def p(self, line, add_ind=0, eat_nl=False):
985         if self._nl:
986             if not eat_nl:
987                 self._out.write('\n')
988             self._nl = False
989         ind = self._ind
990         if line[-1] == ':':
991             ind -= 1
992         if self._silent_block:
993             ind += 1
994         self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
995         if add_ind:
996             ind += add_ind
997         self._out.write('\t' * ind + line + '\n')
998
999     def nl(self):
1000         self._nl = True
1001
1002     def block_start(self, line=''):
1003         if line:
1004             line = line + ' '
1005         self.p(line + '{')
1006         self._ind += 1
1007
1008     def block_end(self, line=''):
1009         if line and line[0] not in {';', ','}:
1010             line = ' ' + line
1011         self._ind -= 1
1012         self.p('}' + line, eat_nl=True)
1013
1014     def write_doc_line(self, doc, indent=True):
1015         words = doc.split()
1016         line = ' *'
1017         for word in words:
1018             if len(line) + len(word) >= 79:
1019                 self.p(line)
1020                 line = ' *'
1021                 if indent:
1022                     line += '  '
1023             line += ' ' + word
1024         self.p(line)
1025
1026     def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
1027         if not args:
1028             args = ['void']
1029
1030         if doc:
1031             self.p('/*')
1032             self.p(' * ' + doc)
1033             self.p(' */')
1034
1035         oneline = qual_ret
1036         if qual_ret[-1] != '*':
1037             oneline += ' '
1038         oneline += f"{name}({', '.join(args)}){suffix}"
1039
1040         if len(oneline) < 80:
1041             self.p(oneline)
1042             return
1043
1044         v = qual_ret
1045         if len(v) > 3:
1046             self.p(v)
1047             v = ''
1048         elif qual_ret[-1] != '*':
1049             v += ' '
1050         v += name + '('
1051         ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
1052         delta_ind = len(v) - len(ind)
1053         v += args[0]
1054         i = 1
1055         while i < len(args):
1056             next_len = len(v) + len(args[i])
1057             if v[0] == '\t':
1058                 next_len += delta_ind
1059             if next_len > 76:
1060                 self.p(v + ',')
1061                 v = ind
1062             else:
1063                 v += ', '
1064             v += args[i]
1065             i += 1
1066         self.p(v + ')' + suffix)
1067
1068     def write_func_lvar(self, local_vars):
1069         if not local_vars:
1070             return
1071
1072         if type(local_vars) is str:
1073             local_vars = [local_vars]
1074
1075         local_vars.sort(key=len, reverse=True)
1076         for var in local_vars:
1077             self.p(var)
1078         self.nl()
1079
1080     def write_func(self, qual_ret, name, body, args=None, local_vars=None):
1081         self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
1082         self.write_func_lvar(local_vars=local_vars)
1083
1084         self.block_start()
1085         for line in body:
1086             self.p(line)
1087         self.block_end()
1088
1089     def writes_defines(self, defines):
1090         longest = 0
1091         for define in defines:
1092             if len(define[0]) > longest:
1093                 longest = len(define[0])
1094         longest = ((longest + 8) // 8) * 8
1095         for define in defines:
1096             line = '#define ' + define[0]
1097             line += '\t' * ((longest - len(define[0]) + 7) // 8)
1098             if type(define[1]) is int:
1099                 line += str(define[1])
1100             elif type(define[1]) is str:
1101                 line += '"' + define[1] + '"'
1102             self.p(line)
1103
1104     def write_struct_init(self, members):
1105         longest = max([len(x[0]) for x in members])
1106         longest += 1  # because we prepend a .
1107         longest = ((longest + 8) // 8) * 8
1108         for one in members:
1109             line = '.' + one[0]
1110             line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
1111             line += '= ' + one[1] + ','
1112             self.p(line)
1113
1114
1115 scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64'}
1116
1117 direction_to_suffix = {
1118     'reply': '_rsp',
1119     'request': '_req',
1120     '': ''
1121 }
1122
1123 op_mode_to_wrapper = {
1124     'do': '',
1125     'dump': '_list',
1126     'notify': '_ntf',
1127     'event': '',
1128 }
1129
1130 _C_KW = {
1131     'do'
1132 }
1133
1134
1135 def rdir(direction):
1136     if direction == 'reply':
1137         return 'request'
1138     if direction == 'request':
1139         return 'reply'
1140     return direction
1141
1142
1143 def op_prefix(ri, direction, deref=False):
1144     suffix = f"_{ri.type_name}"
1145
1146     if not ri.op_mode or ri.op_mode == 'do':
1147         suffix += f"{direction_to_suffix[direction]}"
1148     else:
1149         if direction == 'request':
1150             suffix += '_req_dump'
1151         else:
1152             if ri.type_consistent:
1153                 if deref:
1154                     suffix += f"{direction_to_suffix[direction]}"
1155                 else:
1156                     suffix += op_mode_to_wrapper[ri.op_mode]
1157             else:
1158                 suffix += '_rsp'
1159                 suffix += '_dump' if deref else '_list'
1160
1161     return f"{ri.family['name']}{suffix}"
1162
1163
1164 def type_name(ri, direction, deref=False):
1165     return f"struct {op_prefix(ri, direction, deref=deref)}"
1166
1167
1168 def print_prototype(ri, direction, terminate=True, doc=None):
1169     suffix = ';' if terminate else ''
1170
1171     fname = ri.op.render_name
1172     if ri.op_mode == 'dump':
1173         fname += '_dump'
1174
1175     args = ['struct ynl_sock *ys']
1176     if 'request' in ri.op[ri.op_mode]:
1177         args.append(f"{type_name(ri, direction)} *" + f"{direction_to_suffix[direction][1:]}")
1178
1179     ret = 'int'
1180     if 'reply' in ri.op[ri.op_mode]:
1181         ret = f"{type_name(ri, rdir(direction))} *"
1182
1183     ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=suffix)
1184
1185
1186 def print_req_prototype(ri):
1187     print_prototype(ri, "request", doc=ri.op['doc'])
1188
1189
1190 def print_dump_prototype(ri):
1191     print_prototype(ri, "request")
1192
1193
1194 def put_typol_fwd(cw, struct):
1195     cw.p(f'extern struct ynl_policy_nest {struct.render_name}_nest;')
1196
1197
1198 def put_typol(cw, struct):
1199     type_max = struct.attr_set.max_name
1200     cw.block_start(line=f'struct ynl_policy_attr {struct.render_name}_policy[{type_max} + 1] =')
1201
1202     for _, arg in struct.member_list():
1203         arg.attr_typol(cw)
1204
1205     cw.block_end(line=';')
1206     cw.nl()
1207
1208     cw.block_start(line=f'struct ynl_policy_nest {struct.render_name}_nest =')
1209     cw.p(f'.max_attr = {type_max},')
1210     cw.p(f'.table = {struct.render_name}_policy,')
1211     cw.block_end(line=';')
1212     cw.nl()
1213
1214
1215 def put_op_name_fwd(family, cw):
1216     cw.write_func_prot('const char *', f'{family.name}_op_str', ['int op'], suffix=';')
1217
1218
1219 def put_op_name(family, cw):
1220     map_name = f'{family.name}_op_strmap'
1221     cw.block_start(line=f"static const char * const {map_name}[] =")
1222     for op_name, op in family.msgs.items():
1223         cw.p(f'[{op.enum_name}] = "{op_name}",')
1224     cw.block_end(line=';')
1225     cw.nl()
1226
1227     cw.write_func_prot('const char *', f'{family.name}_op_str', ['int op'])
1228     cw.block_start()
1229     cw.p(f'if (op < 0 || op >= (int)MNL_ARRAY_SIZE({map_name}))')
1230     cw.p('return NULL;')
1231     cw.p(f'return {map_name}[op];')
1232     cw.block_end()
1233     cw.nl()
1234
1235
1236 def put_enum_to_str_fwd(family, cw, enum):
1237     args = [f'enum {enum.render_name} value']
1238     if 'enum-name' in enum and not enum['enum-name']:
1239         args = ['int value']
1240     cw.write_func_prot('const char *', f'{enum.render_name}_str', args, suffix=';')
1241
1242
1243 def put_enum_to_str(family, cw, enum):
1244     map_name = f'{enum.render_name}_strmap'
1245     cw.block_start(line=f"static const char * const {map_name}[] =")
1246     for entry in enum.entries.values():
1247         cw.p(f'[{entry.value}] = "{entry.name}",')
1248     cw.block_end(line=';')
1249     cw.nl()
1250
1251     args = [f'enum {enum.render_name} value']
1252     if 'enum-name' in enum and not enum['enum-name']:
1253         args = ['int value']
1254     cw.write_func_prot('const char *', f'{enum.render_name}_str', args)
1255     cw.block_start()
1256     if enum.type == 'flags':
1257         cw.p('value = ffs(value) - 1;')
1258     cw.p(f'if (value < 0 || value >= (int)MNL_ARRAY_SIZE({map_name}))')
1259     cw.p('return NULL;')
1260     cw.p(f'return {map_name}[value];')
1261     cw.block_end()
1262     cw.nl()
1263
1264
1265 def put_req_nested(ri, struct):
1266     func_args = ['struct nlmsghdr *nlh',
1267                  'unsigned int attr_type',
1268                  f'{struct.ptr_name}obj']
1269
1270     ri.cw.write_func_prot('int', f'{struct.render_name}_put', func_args)
1271     ri.cw.block_start()
1272     ri.cw.write_func_lvar('struct nlattr *nest;')
1273
1274     ri.cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
1275
1276     for _, arg in struct.member_list():
1277         arg.attr_put(ri, "obj")
1278
1279     ri.cw.p("mnl_attr_nest_end(nlh, nest);")
1280
1281     ri.cw.nl()
1282     ri.cw.p('return 0;')
1283     ri.cw.block_end()
1284     ri.cw.nl()
1285
1286
1287 def _multi_parse(ri, struct, init_lines, local_vars):
1288     if struct.nested:
1289         iter_line = "mnl_attr_for_each_nested(attr, nested)"
1290     else:
1291         iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"
1292
1293     array_nests = set()
1294     multi_attrs = set()
1295     needs_parg = False
1296     for arg, aspec in struct.member_list():
1297         if aspec['type'] == 'array-nest':
1298             local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
1299             array_nests.add(arg)
1300         if 'multi-attr' in aspec:
1301             multi_attrs.add(arg)
1302         needs_parg |= 'nested-attributes' in aspec
1303     if array_nests or multi_attrs:
1304         local_vars.append('int i;')
1305     if needs_parg:
1306         local_vars.append('struct ynl_parse_arg parg;')
1307         init_lines.append('parg.ys = yarg->ys;')
1308
1309     ri.cw.block_start()
1310     ri.cw.write_func_lvar(local_vars)
1311
1312     for line in init_lines:
1313         ri.cw.p(line)
1314     ri.cw.nl()
1315
1316     for arg in struct.inherited:
1317         ri.cw.p(f'dst->{arg} = {arg};')
1318
1319     ri.cw.nl()
1320     ri.cw.block_start(line=iter_line)
1321
1322     first = True
1323     for _, arg in struct.member_list():
1324         good = arg.attr_get(ri, 'dst', first=first)
1325         # First may be 'unused' or 'pad', ignore those
1326         first &= not good
1327
1328     ri.cw.block_end()
1329     ri.cw.nl()
1330
1331     for anest in sorted(array_nests):
1332         aspec = struct[anest]
1333
1334         ri.cw.block_start(line=f"if (dst->n_{aspec.c_name})")
1335         ri.cw.p(f"dst->{aspec.c_name} = calloc(dst->n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1336         ri.cw.p('i = 0;')
1337         ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1338         ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
1339         ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1340         ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
1341         ri.cw.p('return MNL_CB_ERROR;')
1342         ri.cw.p('i++;')
1343         ri.cw.block_end()
1344         ri.cw.block_end()
1345     ri.cw.nl()
1346
1347     for anest in sorted(multi_attrs):
1348         aspec = struct[anest]
1349         ri.cw.block_start(line=f"if (dst->n_{aspec.c_name})")
1350         ri.cw.p(f"dst->{aspec.c_name} = calloc(dst->n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1351         ri.cw.p('i = 0;')
1352         if 'nested-attributes' in aspec:
1353             ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1354         ri.cw.block_start(line=iter_line)
1355         ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
1356         if 'nested-attributes' in aspec:
1357             ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1358             ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
1359             ri.cw.p('return MNL_CB_ERROR;')
1360         elif aspec['type'] in scalars:
1361             t = aspec['type']
1362             if t[0] == 's':
1363                 t = 'u' + t[1:]
1364             ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
1365         else:
1366             raise Exception('Nest parsing type not supported yet')
1367         ri.cw.p('i++;')
1368         ri.cw.block_end()
1369         ri.cw.block_end()
1370         ri.cw.block_end()
1371     ri.cw.nl()
1372
1373     if struct.nested:
1374         ri.cw.p('return 0;')
1375     else:
1376         ri.cw.p('return MNL_CB_OK;')
1377     ri.cw.block_end()
1378     ri.cw.nl()
1379
1380
1381 def parse_rsp_nested(ri, struct):
1382     func_args = ['struct ynl_parse_arg *yarg',
1383                  'const struct nlattr *nested']
1384     for arg in struct.inherited:
1385         func_args.append('__u32 ' + arg)
1386
1387     local_vars = ['const struct nlattr *attr;',
1388                   f'{struct.ptr_name}dst = yarg->data;']
1389     init_lines = []
1390
1391     ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)
1392
1393     _multi_parse(ri, struct, init_lines, local_vars)
1394
1395
1396 def parse_rsp_msg(ri, deref=False):
1397     if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
1398         return
1399
1400     func_args = ['const struct nlmsghdr *nlh',
1401                  'void *data']
1402
1403     local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
1404                   'struct ynl_parse_arg *yarg = data;',
1405                   'const struct nlattr *attr;']
1406     init_lines = ['dst = yarg->data;']
1407
1408     ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse', func_args)
1409
1410     _multi_parse(ri, ri.struct["reply"], init_lines, local_vars)
1411
1412
1413 def print_req(ri):
1414     ret_ok = '0'
1415     ret_err = '-1'
1416     direction = "request"
1417     local_vars = ['struct nlmsghdr *nlh;',
1418                   'int err;']
1419
1420     if 'reply' in ri.op[ri.op_mode]:
1421         ret_ok = 'rsp'
1422         ret_err = 'NULL'
1423         local_vars += [f'{type_name(ri, rdir(direction))} *rsp;',
1424                        'struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };']
1425
1426     print_prototype(ri, direction, terminate=False)
1427     ri.cw.block_start()
1428     ri.cw.write_func_lvar(local_vars)
1429
1430     ri.cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
1431
1432     ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
1433     if 'reply' in ri.op[ri.op_mode]:
1434         ri.cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
1435     ri.cw.nl()
1436     for _, attr in ri.struct["request"].member_list():
1437         attr.attr_put(ri, "req")
1438     ri.cw.nl()
1439
1440     parse_arg = "NULL"
1441     if 'reply' in ri.op[ri.op_mode]:
1442         ri.cw.p('rsp = calloc(1, sizeof(*rsp));')
1443         ri.cw.p('yrs.yarg.data = rsp;')
1444         ri.cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
1445         if ri.op.value is not None:
1446             ri.cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
1447         else:
1448             ri.cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
1449         ri.cw.nl()
1450         parse_arg = '&yrs'
1451     ri.cw.p(f"err = ynl_exec(ys, nlh, {parse_arg});")
1452     ri.cw.p('if (err < 0)')
1453     if 'reply' in ri.op[ri.op_mode]:
1454         ri.cw.p('goto err_free;')
1455     else:
1456         ri.cw.p('return -1;')
1457     ri.cw.nl()
1458
1459     ri.cw.p(f"return {ret_ok};")
1460     ri.cw.nl()
1461
1462     if 'reply' in ri.op[ri.op_mode]:
1463         ri.cw.p('err_free:')
1464         ri.cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
1465         ri.cw.p(f"return {ret_err};")
1466
1467     ri.cw.block_end()
1468
1469
1470 def print_dump(ri):
1471     direction = "request"
1472     print_prototype(ri, direction, terminate=False)
1473     ri.cw.block_start()
1474     local_vars = ['struct ynl_dump_state yds = {};',
1475                   'struct nlmsghdr *nlh;',
1476                   'int err;']
1477
1478     for var in local_vars:
1479         ri.cw.p(f'{var}')
1480     ri.cw.nl()
1481
1482     ri.cw.p('yds.ys = ys;')
1483     ri.cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
1484     ri.cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
1485     if ri.op.value is not None:
1486         ri.cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
1487     else:
1488         ri.cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
1489     ri.cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
1490     ri.cw.nl()
1491     ri.cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")
1492
1493     if "request" in ri.op[ri.op_mode]:
1494         ri.cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
1495         ri.cw.nl()
1496         for _, attr in ri.struct["request"].member_list():
1497             attr.attr_put(ri, "req")
1498     ri.cw.nl()
1499
1500     ri.cw.p('err = ynl_exec_dump(ys, nlh, &yds);')
1501     ri.cw.p('if (err < 0)')
1502     ri.cw.p('goto free_list;')
1503     ri.cw.nl()
1504
1505     ri.cw.p('return yds.first;')
1506     ri.cw.nl()
1507     ri.cw.p('free_list:')
1508     ri.cw.p(call_free(ri, rdir(direction), 'yds.first'))
1509     ri.cw.p('return NULL;')
1510     ri.cw.block_end()
1511
1512
1513 def call_free(ri, direction, var):
1514     return f"{op_prefix(ri, direction)}_free({var});"
1515
1516
1517 def free_arg_name(direction):
1518     if direction:
1519         return direction_to_suffix[direction][1:]
1520     return 'obj'
1521
1522
1523 def print_alloc_wrapper(ri, direction):
1524     name = op_prefix(ri, direction)
1525     ri.cw.write_func_prot(f'static inline struct {name} *', f"{name}_alloc", [f"void"])
1526     ri.cw.block_start()
1527     ri.cw.p(f'return calloc(1, sizeof(struct {name}));')
1528     ri.cw.block_end()
1529
1530
1531 def print_free_prototype(ri, direction, suffix=';'):
1532     name = op_prefix(ri, direction)
1533     arg = free_arg_name(direction)
1534     ri.cw.write_func_prot('void', f"{name}_free", [f"struct {name} *{arg}"], suffix=suffix)
1535
1536
1537 def _print_type(ri, direction, struct):
1538     suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
1539
1540     if ri.op_mode == 'dump':
1541         suffix += '_dump'
1542
1543     ri.cw.block_start(line=f"struct {ri.family['name']}{suffix}")
1544
1545     meta_started = False
1546     for _, attr in struct.member_list():
1547         for type_filter in ['len', 'bit']:
1548             line = attr.presence_member(ri.ku_space, type_filter)
1549             if line:
1550                 if not meta_started:
1551                     ri.cw.block_start(line=f"struct")
1552                     meta_started = True
1553                 ri.cw.p(line)
1554     if meta_started:
1555         ri.cw.block_end(line='_present;')
1556         ri.cw.nl()
1557
1558     for arg in struct.inherited:
1559         ri.cw.p(f"__u32 {arg};")
1560
1561     for _, attr in struct.member_list():
1562         attr.struct_member(ri)
1563
1564     ri.cw.block_end(line=';')
1565     ri.cw.nl()
1566
1567
1568 def print_type(ri, direction):
1569     _print_type(ri, direction, ri.struct[direction])
1570
1571
1572 def print_type_full(ri, struct):
1573     _print_type(ri, "", struct)
1574
1575
1576 def print_type_helpers(ri, direction, deref=False):
1577     print_free_prototype(ri, direction)
1578     ri.cw.nl()
1579
1580     if ri.ku_space == 'user' and direction == 'request':
1581         for _, attr in ri.struct[direction].member_list():
1582             attr.setter(ri, ri.attr_set, direction, deref=deref)
1583     ri.cw.nl()
1584
1585
1586 def print_req_type_helpers(ri):
1587     print_alloc_wrapper(ri, "request")
1588     print_type_helpers(ri, "request")
1589
1590
1591 def print_rsp_type_helpers(ri):
1592     if 'reply' not in ri.op[ri.op_mode]:
1593         return
1594     print_type_helpers(ri, "reply")
1595
1596
1597 def print_parse_prototype(ri, direction, terminate=True):
1598     suffix = "_rsp" if direction == "reply" else "_req"
1599     term = ';' if terminate else ''
1600
1601     ri.cw.write_func_prot('void', f"{ri.op.render_name}{suffix}_parse",
1602                           ['const struct nlattr **tb',
1603                            f"struct {ri.op.render_name}{suffix} *req"],
1604                           suffix=term)
1605
1606
1607 def print_req_type(ri):
1608     print_type(ri, "request")
1609
1610
1611 def print_req_free(ri):
1612     if 'request' not in ri.op[ri.op_mode]:
1613         return
1614     _free_type(ri, 'request', ri.struct['request'])
1615
1616
1617 def print_rsp_type(ri):
1618     if (ri.op_mode == 'do' or ri.op_mode == 'dump') and 'reply' in ri.op[ri.op_mode]:
1619         direction = 'reply'
1620     elif ri.op_mode == 'event':
1621         direction = 'reply'
1622     else:
1623         return
1624     print_type(ri, direction)
1625
1626
1627 def print_wrapped_type(ri):
1628     ri.cw.block_start(line=f"{type_name(ri, 'reply')}")
1629     if ri.op_mode == 'dump':
1630         ri.cw.p(f"{type_name(ri, 'reply')} *next;")
1631     elif ri.op_mode == 'notify' or ri.op_mode == 'event':
1632         ri.cw.p('__u16 family;')
1633         ri.cw.p('__u8 cmd;')
1634         ri.cw.p('struct ynl_ntf_base_type *next;')
1635         ri.cw.p(f"void (*free)({type_name(ri, 'reply')} *ntf);")
1636     ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));")
1637     ri.cw.block_end(line=';')
1638     ri.cw.nl()
1639     print_free_prototype(ri, 'reply')
1640     ri.cw.nl()
1641
1642
1643 def _free_type_members_iter(ri, struct):
1644     for _, attr in struct.member_list():
1645         if attr.free_needs_iter():
1646             ri.cw.p('unsigned int i;')
1647             ri.cw.nl()
1648             break
1649
1650
1651 def _free_type_members(ri, var, struct, ref=''):
1652     for _, attr in struct.member_list():
1653         attr.free(ri, var, ref)
1654
1655
1656 def _free_type(ri, direction, struct):
1657     var = free_arg_name(direction)
1658
1659     print_free_prototype(ri, direction, suffix='')
1660     ri.cw.block_start()
1661     _free_type_members_iter(ri, struct)
1662     _free_type_members(ri, var, struct)
1663     if direction:
1664         ri.cw.p(f'free({var});')
1665     ri.cw.block_end()
1666     ri.cw.nl()
1667
1668
1669 def free_rsp_nested(ri, struct):
1670     _free_type(ri, "", struct)
1671
1672
1673 def print_rsp_free(ri):
1674     if 'reply' not in ri.op[ri.op_mode]:
1675         return
1676     _free_type(ri, 'reply', ri.struct['reply'])
1677
1678
1679 def print_dump_type_free(ri):
1680     sub_type = type_name(ri, 'reply')
1681
1682     print_free_prototype(ri, 'reply', suffix='')
1683     ri.cw.block_start()
1684     ri.cw.p(f"{sub_type} *next = rsp;")
1685     ri.cw.nl()
1686     ri.cw.block_start(line='while ((void *)next != YNL_LIST_END)')
1687     _free_type_members_iter(ri, ri.struct['reply'])
1688     ri.cw.p('rsp = next;')
1689     ri.cw.p('next = rsp->next;')
1690     ri.cw.nl()
1691
1692     _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
1693     ri.cw.p(f'free(rsp);')
1694     ri.cw.block_end()
1695     ri.cw.block_end()
1696     ri.cw.nl()
1697
1698
1699 def print_ntf_type_free(ri):
1700     print_free_prototype(ri, 'reply', suffix='')
1701     ri.cw.block_start()
1702     _free_type_members_iter(ri, ri.struct['reply'])
1703     _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
1704     ri.cw.p(f'free(rsp);')
1705     ri.cw.block_end()
1706     ri.cw.nl()
1707
1708
1709 def print_ntf_parse_prototype(family, cw, suffix=';'):
1710     cw.write_func_prot('struct ynl_ntf_base_type *', f"{family['name']}_ntf_parse",
1711                        ['struct ynl_sock *ys'], suffix=suffix)
1712
1713
def print_ntf_type_parse(family, cw, ku_mode):
    """Emit the C definition of ``<family>_ntf_parse()``.

    The generated function:
    - reads one message from ``ys->sock`` into ``ys->rx_buf``;
    - switches on the genetlink command to choose the notification type,
      its parse callback, policy and free function;
    - allocates the object, runs the parser, and returns the filled object.

    It returns ``NULL`` on a short read, an unknown command or a parse
    failure.

    :param family: spec family; each command in ``all_notify`` and each op
                   carrying ``event`` becomes a ``case`` label.
    :param cw: code writer the C text is emitted to.
    :param ku_mode: render mode forwarded to RenderInfo.
    """
    print_ntf_parse_prototype(family, cw, suffix='')
    cw.block_start()
    cw.write_func_lvar(['struct genlmsghdr *genlh;',
                        'struct nlmsghdr *nlh;',
                        'struct ynl_parse_arg yarg = { .ys = ys, };',
                        'struct ynl_ntf_base_type *rsp;',
                        'int len, err;',
                        'mnl_cb_t parse;'])
    cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);')
    # Too short to hold both the netlink and the genetlink header, so there
    # is no command to dispatch on.
    cw.p('if (len < (ssize_t)(sizeof(*nlh) + sizeof(*genlh)))')
    cw.p('return NULL;')
    cw.nl()
    cw.p('nlh = (struct nlmsghdr *)ys->rx_buf;')
    cw.p('genlh = mnl_nlmsg_get_payload(nlh);')
    cw.nl()
    cw.block_start(line='switch (genlh->cmd)')
    # Notify ops: every command listed in op['notify']['cmds'] shares one
    # case body built from the notifying op's reply type and parser.
    for ntf_op in sorted(family.all_notify.keys()):
        op = family.ops[ntf_op]
        ri = RenderInfo(cw, family, ku_mode, op, ntf_op, "notify")
        for ntf in op['notify']['cmds']:
            cw.p(f"case {ntf.enum_name}:")
        cw.p(f"rsp = calloc(1, sizeof({type_name(ri, 'notify')}));")
        cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;")
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
        cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;")
        cw.p('break;')
    # Event ops: one case per op, keyed by the op's own command value.
    for op_name, op in family.ops.items():
        if 'event' not in op:
            continue
        ri = RenderInfo(cw, family, ku_mode, op, op_name, "event")
        cw.p(f"case {op.enum_name}:")
        cw.p(f"rsp = calloc(1, sizeof({type_name(ri, 'event')}));")
        cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;")
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
        cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;")
        cw.p('break;')
    cw.p('default:')
    cw.p('ynl_error_unknown_notification(ys, genlh->cmd);')
    cw.p('return NULL;')
    cw.block_end()
    cw.nl()
    cw.p('yarg.data = rsp->data;')
    cw.nl()
    # is_dump=True makes parse_cb_run() pass seq/portid 0 to mnl_cb_run2(),
    # which libmnl treats as "don't check". A notification does not answer
    # one of our requests.
    cw.p(f"err = {cw.nlib.parse_cb_run('parse', '&yarg', True)};")
    cw.p('if (err < 0)')
    cw.p('goto err_free;')
    cw.nl()
    cw.p('rsp->family = nlh->nlmsg_type;')
    cw.p('rsp->cmd = genlh->cmd;')
    cw.p('return rsp;')
    cw.nl()
    # NOTE(review): two things to confirm are intended:
    # - the generated code does not check calloc() for NULL before writing
    #   rsp->free;
    # - err_free releases only the top-level allocation, not members the
    #   parser may already have filled.
    cw.p('err_free:')
    cw.p('free(rsp);')
    cw.p('return NULL;')
    cw.block_end()
    cw.nl()
1771
1772
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit the head of a kernel ``nla_policy`` array for *struct*.

    The ``terminate`` flag selects the form:
    - True: an ``extern`` declaration ending in ``;`` (for headers).
    - False: the opening line of the definition, ending in `` = {``.

    When *ri* is given, the policy belongs to an operation and is named after
    the op (plus the op mode for dual-policy ops). Otherwise it is named after
    the attribute-set struct.

    For families whose genl_family struct is generated here, per-op policies
    are ``static``. They therefore get no declaration at all.
    """
    if ri and terminate and kernel_can_gen_family_struct(struct.family):
        # Static per-op policy: nothing to declare in the header.
        return

    if terminate:
        prefix, suffix = 'extern ', ';'
    elif kernel_can_gen_family_struct(struct.family) and ri:
        prefix, suffix = 'static ', ' = {'
    else:
        prefix, suffix = '', ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            # do and dump get separate policies; disambiguate by mode.
            name = f"{name}_{ri.op_mode}"

    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1795
1796
def print_req_policy(cw, struct, ri=None):
    """Emit a complete kernel ``nla_policy`` array definition for *struct*.

    The opening line comes from print_req_policy_fwd(). Each member then
    writes its own policy entry, and the array is closed with ``};``.
    """
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
1802
1803
def kernel_can_gen_family_struct(family):
    """Return whether the generated kernel code owns ``struct genl_family``.

    Only plain ``genetlink`` families qualify. For other protocols the op
    table is exported and the family struct is left to hand-written code.
    """
    proto = family.proto
    return proto == 'genetlink'
1806
1807
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the op-table head and, for headers, the handler prototypes.

    The ``terminate`` flag selects the form:
    - False (source file): open the definition of ``<family>_nl_ops``. The
      caller fills in the entries.
    - True (header): declare the table ``extern``, but only when it is
      exported, i.e. the genl_family struct is not generated here. Then
      write prototypes for every pre/post hook and every doit/dumpit handler
      of the synchronous ops.

    An exported table spells out its length explicitly, so that length must
    equal the number of entries print_kernel_op_table() emits.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        pol_to_struct = {'global': 'genl_small_ops',
                         'per-op': 'genl_ops',
                         'split': 'genl_split_ops'}
        struct_type = pol_to_struct[family.kernel_policy]

        if not exported:
            # Static table: let the compiler size the array.
            cnt = ""
        elif family.kernel_policy == 'split':
            # One entry per do/dump mode of every synchronous op, matching
            # what print_kernel_op_table() emits.
            cnt = sum(('do' in op) + ('dump' in op)
                      for op in family.ops.values() if not op.is_async)
        else:
            # One entry per synchronous op. print_kernel_op_table() skips
            # async (notify/event) ops. Counting them would size the array
            # past its initializers and leave trailing zero-initialised ops.
            cnt = sum(1 for op in family.ops.values() if not op.is_async)

        qual = 'static const' if not exported else 'const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    # Family-wide hook prototypes. do hooks take the split op and the
    # genl_info; dump hooks take the netlink_callback.
    for name in family.hooks['pre']['do']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['post']['do']['list']:
        cw.write_func_prot('void', c_lower(name),
                           ['const struct genl_split_ops *ops',
                            'struct sk_buff *skb', 'struct genl_info *info'], suffix=';')
    for name in family.hooks['pre']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')
    for name in family.hooks['post']['dump']['list']:
        cw.write_func_prot('int', c_lower(name),
                           ['struct netlink_callback *cb'], suffix=';')

    cw.nl()

    # Handler prototypes the kernel code has to implement.
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        if 'do' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-doit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct genl_info *info'], suffix=';')

        if 'dump' in op:
            name = c_lower(f"{family.name}-nl-{op_name}-dumpit")
            cw.write_func_prot('int', name,
                               ['struct sk_buff *skb', 'struct netlink_callback *cb'], suffix=';')
    cw.nl()
1873
1874
def print_kernel_op_table_hdr(family, cw):
    """Header side of the op table.

    Emits the extern table declaration (when the table is exported) plus the
    hook and handler prototypes.
    """
    print_kernel_op_table_fwd(family, cw, True)
1877
1878
def print_kernel_op_table(family, cw):
    """Emit the definition of ``<family>_nl_ops``.

    print_kernel_op_table_fwd() opens the array. This function then writes
    one initializer entry per synchronous op (async notify/event ops are
    skipped) and closes the array with ``};``.

    - 'global' / 'per-op' policy: one entry per op, listing whichever of
      doit/dumpit the op has. Per-op entries also carry the op's policy and
      maxattr.
    - 'split' policy: one entry per (op, do|dump) pair. Each entry has its
      own pre/post hooks, policy and GENL_CMD_CAP_* flag.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): this indexes op['do']['request'] unconditionally.
                # A dump-only op, or a do without a request, would raise
                # KeyError here. Confirm such ops never appear with 'per-op'.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Spec hook positions -> genl_split_ops member names, per mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    members.append(('validate',
                                    ' | '.join([c_upper('genl-dont-validate-' + x)
                                                for x in op['dont-validate']])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                # Member order: pre hook, handler, post hook.
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have separate do/dump policies, so the
                    # mode is part of the policy name.
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Split ops always advertise which mode they implement.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')

    cw.block_end(line=';')
    cw.nl()
1945
1946
def print_kernel_mcgrp_hdr(family, cw):
    """Emit the enum of multicast group IDs.

    The IDs index the ``<family>_nl_mcgrps`` array. Nothing is written when
    the family defines no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for group in groups:
        cw.p(c_upper(f"{family.name}-nlgrp-{group['name']}") + ',')
    cw.block_end(';')
    cw.nl()
1957
1958
def print_kernel_mcgrp_src(family, cw):
    """Emit the ``genl_multicast_group`` array.

    Each group gets one designated entry, indexed by its NLGRP enum value.
    Nothing is written when the family defines no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.name}_nl_mcgrps[] =')
    for group in groups:
        grp_name = group['name']
        index = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{index}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
1970
1971
def print_kernel_family_struct_hdr(family, cw):
    """Declare the generated ``struct genl_family`` in the kernel header.

    This applies only when the generator owns that struct.
    """
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
1978
1979
def print_kernel_family_struct_src(family, cw):
    """Define ``struct genl_family <family>_nl_family``.

    Only applies when the generator owns the struct. The initializer wires
    in the op table according to the kernel policy, plus the multicast
    groups if any exist.
    """
    if not kernel_can_gen_family_struct(family):
        return

    fam = family.name
    cw.block_start(f"struct genl_family {fam}_nl_family __ro_after_init =")

    fields = [
        f'.name\t\t= {family.fam_key},',
        f'.version\t= {family.ver_key},',
        '.netnsok\t= true,',
        '.parallel_ops\t= true,',
        '.module\t\t= THIS_MODULE,',
    ]
    # NOTE(review): no ops are wired in for the 'global' policy. Confirm
    # whether genetlink families may use it.
    if family.kernel_policy == 'per-op':
        fields += [f'.ops\t\t= {fam}_nl_ops,',
                   f'.n_ops\t\t= ARRAY_SIZE({fam}_nl_ops),']
    elif family.kernel_policy == 'split':
        fields += [f'.split_ops\t= {fam}_nl_ops,',
                   f'.n_split_ops\t= ARRAY_SIZE({fam}_nl_ops),']
    if family.mcgrps['list']:
        fields += [f'.mcgrps\t\t= {fam}_nl_mcgrps,',
                   f'.n_mcgrps\t= ARRAY_SIZE({fam}_nl_mcgrps),']

    for field in fields:
        cw.p(field)
    cw.block_end(';')
2000
2001
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a uAPI ``enum`` block for the spec object *obj*.

    The enum name is chosen in priority order:
    1. ``obj[enum_name]`` is present: use it (passed through c_lower()). If
       it is present but empty/false, the enum is deliberately anonymous.
    2. Otherwise ``obj[ckey]`` (when *ckey* is non-empty), passed through
       c_lower() and prefixed with the family name.
    3. Otherwise an anonymous ``enum``.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = f'enum {c_lower(explicit)}' if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = f'enum {family.name}_{c_lower(obj[ckey])}'
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
2010
2011
def render_uapi(family, cw):
    """Render the uAPI header for *family*.

    Output, in order:
    1. the include guard;
    2. the family name/version defines;
    3. the ``definitions`` section (consts as ``#define``s, enums/flags as
       kdoc-commented C enums);
    4. one attribute enum per attribute set that is not a subset;
    5. the command enum, plus a separate async enum when ``async-prefix`` is
       set;
    6. the multicast group name defines.
    """
    # c_upper() also maps '-' to '_', so family names containing dashes
    # still yield a valid C identifier for the include guard.
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive consts are batched into one #define run. Any other
    # definition type flushes the pending batch first, so the output keeps
    # the spec's order.
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # A const may override its macro name with its own
            # 'c-define-name', the same way multicast groups do below.
            # Looking the key up on the family instead would give every
            # const the same macro name.
            define_name = const.get('c-define-name', f"{family.name}-{const['name']}")
            defines.append([c_upper(define_name), const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))
        max_value = f"({cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            # Only spell out the value when it breaks the implicit sequence.
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2155
2156
def _render_user_ntf_entry(ri, op):
    """Emit one initializer entry of the ``ynl_ntf_info`` array.

    The entry is indexed by *op*'s command value and records the object
    size, parse callback, policy and free function.
    """
    cw = ri.cw
    cw.block_start(line=f"[{op.enum_name}] = ")
    for field in (
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    ):
        cw.p(field)
    cw.block_end(line=',')
2164
2165
def render_user_family(family, cw, prototype):
    """Emit the user-space ``ynl_<family>_family`` descriptor.

    If *prototype* is true, only the ``extern`` declaration is written.
    Otherwise the descriptor is defined. When the family has notifications,
    it is preceded by a static ``<family>_ntf_info`` table covering the
    notify commands and the event ops.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    has_ntf = family.has_notifications()
    ntf_table = f"{family['name']}_ntf_info"

    if has_ntf:
        cw.block_start(line=f"static const struct ynl_ntf_info {ntf_table}[] = ")
        for ntf_op in sorted(family.all_notify.keys()):
            op = family.ops[ntf_op]
            ri = RenderInfo(cw, family, "user", op, ntf_op, "notify")
            for ntf in op['notify']['cmds']:
                _render_user_ntf_entry(ri, ntf)
        event_ops = ((name, op) for name, op in family.ops.items() if 'event' in op)
        for op_name, op in event_ops:
            ri = RenderInfo(cw, family, "user", op, op_name, "event")
            _render_user_ntf_entry(ri, op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.name}",')
    if has_ntf:
        cw.p(f".ntf_info\t= {ntf_table},")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({ntf_table}),")
    cw.block_end(line=';')
2194
2195
def find_kernel_root(full_path):
    """Locate the kernel tree containing *full_path*.

    The search walks up the directory components of *full_path* until one
    contains a ``MAINTAINERS`` file.

    :param full_path: path to a file inside the kernel tree (absolute or
                      relative to the current directory).
    :returns: ``(root, sub_path)``. ``root`` is the directory holding
              MAINTAINERS (``''`` when it is the current directory for a
              relative input). ``sub_path`` is *full_path* relative to
              ``root``.
    :raises FileNotFoundError: if no enclosing directory has MAINTAINERS.
    """
    orig_path = full_path
    sub_path = ''
    while True:
        parent = os.path.dirname(full_path)
        # dirname() is a fixed point at the filesystem root ('/') and for
        # an exhausted relative path (''). Stop there instead of looping
        # forever.
        if parent == full_path:
            raise FileNotFoundError(
                f"no MAINTAINERS file found above {orig_path!r}; "
                "is it inside a kernel tree?")
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        full_path = parent
        maintainers = os.path.join(full_path, "MAINTAINERS")
        if os.path.exists(maintainers):
            # sub_path was built with a trailing separator; drop it.
            return full_path, sub_path[:-1]
2204
2205
def main():
    """Command-line entry point: parse a YNL spec and emit C code.

    ``--mode`` selects uapi, kernel or user output. ``--header`` or
    ``--source`` (one is required) selects which half is generated; in uapi
    mode it only affects the banner comments. Output goes to the ``-o`` file,
    or to stdout when ``-o`` is not given.

    Exits with status 1 when the spec's YAML cannot be parsed or the spec's
    license is not the required dual license.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('-o', dest='out_file', type=str)
    args = parser.parse_args()

    # Validate before opening -o: 'w+' truncates the file, so a usage error
    # must not clobber previously generated output.
    if args.header is None:
        parser.error("--header or --source is required")

    out_file = open(args.out_file, 'w+') if args.out_file else os.sys.stdout

    try:
        parsed = Family(args.spec)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            os.sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        os.sys.exit(1)

    cw = CodeWriter(BaseNlLib(), out_file)

    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    # c_upper() maps '-' to '_' so the guard is a valid C identifier.
    hdr_prot = f"_LINUX_{c_upper(parsed.name)}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # Drops the last two characters of -o (presumably the '.c'
                # suffix) to include the matching generated header.
                cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
    else:
        cw.p('#include <stdlib.h>')
        if args.header:
            cw.p('#include <string.h>')
            cw.p('#include <linux/types.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <stdlib.h>")
            cw.p("#include <stdio.h>")
            cw.p("#include <string.h>")
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_name, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                ri = RenderInfo(cw, parsed, args.mode, "", "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'dump')
                    if 'request' in op['dump']:
                        print_req_type(ri)
                        print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if 'notify' in op:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)

            if parsed.has_notifications():
                cw.p('/* --------------- Common notification parsing --------------- */')
                print_ntf_parse_prototype(parsed, cw)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            cw.p('/* Policies */')
            for name, _ in parsed.attr_sets.items():
                struct = Struct(parsed, name)
                put_typol_fwd(cw, struct)
            cw.nl()

            for name, _ in parsed.attr_sets.items():
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                ri = RenderInfo(cw, parsed, args.mode, "", "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if 'notify' in op:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, op_name, "event")
                    print_ntf_type_free(ri)

            if parsed.has_notifications():
                cw.p('/* --------------- Common notification parsing --------------- */')
                print_ntf_type_parse(parsed, cw, args.mode)

            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
2487
2488
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()