tools: ynl: move the enum classes to shared code
[platform/kernel/linux-starfive.git] / tools / net / ynl / ynl-gen-c.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
3
4 import argparse
5 import collections
6 import os
7 import yaml
8
9 from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
10
11
def c_upper(name):
    """Return `name` upper-cased with every '-' replaced by '_'."""
    underscored = name.replace('-', '_')
    return underscored.upper()
14
15
def c_lower(name):
    """Return `name` lower-cased with every '-' replaced by '_'."""
    underscored = name.replace('-', '_')
    return underscored.lower()
18
19
class BaseNlLib:
    """Provider of C expression snippets used by the generated code."""

    def get_family_id(self):
        """Return the C expression used to read the family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return the text of an mnl_cb_run2() call.

        `cb` and `data` are pasted verbatim as the callback and its data
        argument.  For dumps the sequence and port id arguments are 0,
        otherwise ys->seq and ys->portid are used.  `indent` adds that many
        extra tabs to the continuation lines of the call.
        """
        # Line break plus the indentation of every continuation line.
        cont = '\n\t\t' + '\t' * indent + ' '
        tail = f"{cont}ynl_cb_array, NLMSG_MIN_TYPE)"
        if is_dump:
            return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data}," + tail
        return f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{cont}{cb}, {data}," + tail
31
32
class Type(SpecAttr):
    """Code-generation wrapper around one attribute of an attribute set.

    The base class holds the rendering helpers shared by all attribute
    types.  The per-type sub-classes override the hooks (`_attr_typol`,
    `attr_put`, `_attr_get`, `_setter_lines`, `arg_member`, ...) which
    raise an Exception here when a type does not implement them.
    """

    def __init__(self, family, attr_set, attr, value):
        """Record the spec of the attribute and derive its C member name.

        `attr` is the attribute's spec mapping; its 'type' key is required,
        while 'checks', 'len' and 'nested-attributes' are optional.
        """
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # A nested set named like the family renders without a suffix.
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

        self.c_name = c_lower(self.name)
        # Avoid member names which collide with the entries of _C_KW.
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        # (assigned and deleted right away, so the attribute is listed here
        # but reading it before resolve() still raises AttributeError)
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute enum_name from the attribute set prefix and the name."""
        self.enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(self.enum_name)

    def is_multi_val(self):
        """Return whether the attribute holds multiple values (base: no)."""
        return False

    def is_scalar(self):
        """Return True if the type is one of the integer types listed."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """Return how presence is tracked: 'bit' unless overridden."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the C declaration of the presence member, or None.

        None is returned when the presence type differs from `type_filter`
        or is neither 'bit' nor 'len'.  The type gets a '__' prefix when
        `space` is 'user'.
        """
        kind = self.presence_type()
        if kind != type_filter:
            return None

        pfx = '__' if space == 'user' else ''
        if kind == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if kind == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """Return the C type of a non-trivial member (base: None)."""
        return None

    def free_needs_iter(self):
        """Return whether free() emits a loop (base: False)."""
        return False

    def free(self, ri, var, ref):
        """Emit a free() call for multi-value and 'len' presence members."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the list of C argument declarations for the setter.

        Raises Exception when the class has no complex member type and
        does not override this method.
        """
        member = self._complex_member_type(ri)
        if member:
            return [member + ' *' + self.c_name]
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attribute."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        # No complex type: reuse the argument declarations as members.
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the policy initializer for the given NLA_* type name."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the policy array entry of this attribute."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the type policy fields; must be overridden."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the type policy entry (name plus the fields of _attr_typol())."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit `line` as a statement, guarded by the presence check if any."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded mnl_attr_put_<put_type>() call for the member."""
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit the code putting the attribute; must be overridden."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (lines, init_lines, local_vars); must be overridden."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the 'if'/'else if' block parsing this attribute.

        `first` selects whether the block opens with 'if' or 'else if'.
        The lines come from _attr_get(), which may return a single string
        in place of a list for `lines` and `init_lines`.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (mnl_attr_get_type(attr) == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()

    def _setter_lines(self, ri, member, presence):
        """Return the statements storing the value; must be overridden."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit the static inline setter function for this attribute.

        `ref` is the list of enclosing member names leading to this
        attribute (None or empty for a top-level member).
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        # Walk from the outermost struct inwards; after the loop `presence`
        # refers to the presence member of the attribute itself.
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        ri.cw.write_func('static inline void',
                         f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}",
                         body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
194
195
class TypeUnused(Type):
    """Attribute class created for spec entries of type 'unused'."""

    def attr_policy(self, cw):
        """Emit no policy entry."""
        return None

    def _attr_typol(self):
        """Return the type policy fields, which use YNL_PT_REJECT."""
        return '.type = YNL_PT_REJECT, '

    def presence_type(self):
        """Return '' - no presence member is generated."""
        return ''
205
206
class TypePad(Type):
    """Attribute class created for spec entries of type 'pad'."""

    def attr_policy(self, cw):
        """Emit no policy entry."""
        return None

    def _attr_typol(self):
        """Return the type policy fields, which use YNL_PT_REJECT."""
        return '.type = YNL_PT_REJECT, '

    def presence_type(self):
        """Return '' - no presence member is generated."""
        return ''
216
217
class TypeScalar(Type):
    """Attribute class for the integer types listed in `scalars`."""

    def __init__(self, family, attr_set, attr, value):
        """Set up the base state and the optional byte-order comment."""
        super().__init__(family, attr_set, attr, value)

        # Rendered after the member name in arg_member().
        self.byte_order_comment = f" /* {attr['byte-order']} */" if 'byte-order' in attr else ''

        # Added by resolve():
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Work out whether the value is a bitfield and its C type name."""
        self.resolve_up(super())

        has_enum = 'enum' in self.attr
        if self.attr.get('enum-as-flags'):
            bitfield = True
        elif has_enum:
            bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
        else:
            bitfield = False
        self.is_bitfield = bitfield

        # Only non-bitfield enums are rendered with the enum type.
        if has_enum and not bitfield:
            self.type_name = f"enum {self.family.name}_{c_lower(self.attr['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        """Return the type suffix to use for the mnl attr helpers."""
        # mnl does not have a helper for signed types
        return 'u' + self.type[1:] if self.type[0] == 's' else self.type

    def _attr_policy(self, policy):
        """Return the policy initializer, honoring masks, 'min' and enums."""
        if self.is_bitfield:
            mask = self.family.consts[self.attr['enum']].get_mask()
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'flags-mask' in self.checks:
            # One bit per entry of the constant named by the check.
            n_flags = len(self.family.consts[self.checks['flags-mask']]['entries'])
            mask = (1 << n_flags) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
        if 'enum' in self.attr:
            n_entries = len(self.family.consts[self.attr['enum']]['entries'])
            return f"NLA_POLICY_MAX({policy}, {n_entries - 1})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        """Return the type policy fields (YNL_PT_U<rest of the type name>)."""
        return f'.type = YNL_PT_U{self.type[1:]}, '

    def arg_member(self, ri):
        """Return the single C declaration of the value."""
        declaration = f'{self.type_name} {self.c_name}{self.byte_order_comment}'
        return [declaration]

    def attr_put(self, ri, var):
        """Emit the put call using the matching mnl helper."""
        mnl_type = self._mnl_type()
        self._attr_put_simple(ri, var, mnl_type)

    def _attr_get(self, ri, var):
        """Return the assignment reading the value with the mnl helper."""
        assignment = f"{var}->{self.c_name} = mnl_attr_get_{self._mnl_type()}(attr);"
        return assignment, None, None

    def _setter_lines(self, ri, member, presence):
        """Return the statement copying the argument into the member."""
        return [f"{member} = {self.c_name};"]
285
286
class TypeFlag(Type):
    """Attribute class created for spec entries of type 'flag'."""

    def _attr_typol(self):
        """Return the type policy fields, which use YNL_PT_FLAG."""
        return '.type = YNL_PT_FLAG, '

    def arg_member(self, ri):
        """Return no argument declarations."""
        return []

    def _setter_lines(self, ri, member, presence):
        """Return no extra setter statements."""
        return []

    def _attr_get(self, ri, var):
        """Return no parsing lines, init lines or locals."""
        return [], None, None

    def attr_put(self, ri, var):
        """Emit a guarded mnl_attr_put() with a zero length and NULL data."""
        put_call = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, put_call)
302
303
class TypeString(Type):
    """Attribute class created for spec entries of type 'string'."""

    def presence_type(self):
        """Return 'len' - presence is tracked with a length member."""
        return 'len'

    def arg_member(self, ri):
        """Return the setter argument, a const char pointer."""
        return [f"const char *{self.c_name}"]

    def struct_member(self, ri):
        """Emit the struct member, a char pointer."""
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        """Return the type policy fields, which use YNL_PT_NUL_STR."""
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        """Return the policy initializer, with .len when 'max-len' is set."""
        fields = ['.type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + str(self.checks['max-len']))
        return '{ ' + ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        """Emit the policy entry; the NLA type depends on 'unterminated-ok'."""
        unterminated = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated else 'NLA_NUL_STRING'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        """Emit the put call using mnl_attr_put_strz()."""
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        """Return the lines copying the string into a malloc()ed buffer."""
        target = f"{var}->{self.c_name}"
        len_mem = var + '->_present.' + self.c_name + '_len'

        local_vars = ['unsigned int len;']
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        get_lines = [
            f"{len_mem} = len;",
            f"{target} = malloc(len + 1);",
            f"memcpy({target}, mnl_attr_get_str(attr), len);",
            f"{target}[len] = 0;",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Return the statements replacing the member with a copy."""
        len_mem = f"{presence}_len"
        return [
            f"free({member});",
            f"{len_mem} = strlen({self.c_name});",
            f"{member} = malloc({len_mem} + 1);",
            f'memcpy({member}, {self.c_name}, {len_mem});',
            f'{member}[{len_mem}] = 0;',
        ]
351
352
class TypeBinary(Type):
    """Attribute class created for spec entries of type 'binary'."""

    def arg_member(self, ri):
        """Return the setter arguments: the data pointer and its length."""
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        """Return 'len' - presence is tracked with a length member."""
        return 'len'

    def struct_member(self, ri):
        """Emit the struct member, a void pointer."""
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        """Return the type policy fields, which use YNL_PT_BINARY."""
        return f'.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Return the policy initializer.

        Only no checks at all, or a lone 'min-len' check, are supported;
        anything else raises an Exception.
        """
        mem = '{ '
        if len(self.checks) == 1 and 'min-len' in self.checks:
            mem += '.len = ' + str(self.checks['min-len'])
        elif len(self.checks) == 0:
            mem += '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        mem += ', }'
        return mem

    def attr_put(self, ri, var):
        """Emit a guarded mnl_attr_put() of the member and its length."""
        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")

    def _attr_get(self, ri, var):
        """Return the lines copying the payload into a malloc()ed buffer."""
        len_mem = var + '->_present.' + self.c_name + '_len'
        return [f"{len_mem} = len;",
                f"{var}->{self.c_name} = malloc(len);",
                f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \
               ['len = mnl_attr_get_payload_len(attr);'], \
               ['unsigned int len;']

    def _setter_lines(self, ri, member, presence):
        """Return the statements replacing the member with a copy.

        The length member must be set from the 'len' argument declared by
        arg_member() before it is used as the size for malloc()/memcpy().
        """
        return [f"free({member});",
                f"{presence}_len = len;",
                f"{member} = malloc({presence}_len);",
                f'memcpy({member}, {self.c_name}, {presence}_len);']
393
394
class TypeNest(Type):
    """Attribute class created for spec entries of type 'nest'."""

    def _complex_member_type(self, ri):
        """Return the struct type of the nested attribute set."""
        return f"struct {self.nested_render_name}"

    def free(self, ri, var, ref):
        """Emit the call to the nested struct's _free() helper."""
        target = f"&{var}->{ref}{self.c_name}"
        ri.cw.p(f'{self.nested_render_name}_free({target});')

    def _attr_typol(self):
        """Return the type policy fields pointing at the nested policy."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        """Return the NLA_POLICY_NESTED() initializer."""
        return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)'

    def attr_put(self, ri, var):
        """Emit a guarded call to the nested struct's _put() helper."""
        put_call = (f"{self.nested_render_name}_put(nlh, "
                    f"{self.enum_name}, &{var}->{self.c_name})")
        self._attr_put_line(ri, var, put_call)

    def _attr_get(self, ri, var):
        """Return the lines pointing parg at the member and parsing it."""
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [f"{self.nested_render_name}_parse(&parg, attr);"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit the setters of every member of the nested struct."""
        path = (ref if ref else []) + [self.c_name]

        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested_struct.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
423
424
class TypeMultiAttr(Type):
    """Attribute class created for spec entries marked 'multi-attr'."""

    def is_multi_val(self):
        """Return True - the attribute may carry multiple values."""
        return True

    def presence_type(self):
        """Return 'count' - presence is tracked with an element count."""
        return 'count'

    def _complex_member_type(self, ri):
        """Return the C element type: nested struct or scalar type."""
        # A missing 'type' is treated like 'nest'.
        sub = self.attr.get('type', 'nest')
        if sub == 'nest':
            return f"struct {self.nested_render_name}"
        if sub in scalars:
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + sub
        raise Exception(f"Sub-type {sub} not supported yet")

    def free_needs_iter(self):
        """Return True when free() loops over the elements (nests)."""
        return self.attr.get('type', 'nest') == 'nest'

    def free(self, ri, var, ref):
        """Emit a loop freeing each nested element; nothing otherwise."""
        if self.attr.get('type', 'nest') != 'nest':
            return
        ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
        ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')

    def _attr_typol(self):
        """Return the type policy fields for the element type."""
        sub = self.attr.get('type', 'nest')
        if sub == 'nest':
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        if sub in scalars:
            return f".type = YNL_PT_U{sub[1:]}, "
        raise Exception(f"Sub-type {sub} not supported yet")

    def _attr_get(self, ri, var):
        """Return the statement incrementing the element count."""
        counter = f'{var}->n_{self.c_name}++;'
        return counter, None, None
459
460
class TypeArrayNest(Type):
    """Attribute class created for spec entries of type 'array-nest'."""

    def is_multi_val(self):
        """Return True - the attribute carries multiple values."""
        return True

    def presence_type(self):
        """Return 'count' - presence is tracked with an element count."""
        return 'count'

    def _complex_member_type(self, ri):
        """Return the C element type: nested struct or scalar type."""
        # A missing 'sub-type' is treated like 'nest'.
        sub = self.attr.get('sub-type', 'nest')
        if sub == 'nest':
            return f"struct {self.nested_render_name}"
        if sub in scalars:
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + sub
        raise Exception(f"Sub-type {sub} not supported yet")

    def _attr_typol(self):
        """Return the type policy fields pointing at the nested policy."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return the lines remembering the attribute and counting entries."""
        count_lines = [
            f'attr_{self.c_name} = attr;',
            'mnl_attr_for_each_nested(attr2, attr)',
            f'\t{var}->n_{self.c_name}++;',
        ]
        iterator_decl = ['const struct nlattr *attr2;']
        return count_lines, None, iterator_decl
486
487
class TypeNestTypeValue(Type):
    """Attribute class created for spec entries of type 'nest-type-value'."""

    def _complex_member_type(self, ri):
        """Return the struct type of the nested attribute set."""
        return f"struct {self.nested_render_name}"

    def _attr_typol(self):
        """Return the type policy fields pointing at the nested policy."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return the lines unwrapping the 'type-value' levels and parsing.

        For each name listed under 'type-value' the generated code takes
        the payload of the previous attribute and reads its type into a
        __u32 local, which is then passed on to the nested _parse() call.
        """
        init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
                      f"parg.data = &{var}->{self.c_name};"]
        get_lines = []
        local_vars = []

        innermost = 'attr'
        extra_args = ''
        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            ptr_decls = ", *attr_".join(levels)
            val_decls = ", ".join(levels)
            local_vars.append(f'const struct nlattr *attr_{ptr_decls};')
            local_vars.append(f'__u32 {val_decls};')
            for level in levels:
                get_lines.append(f'attr_{level} = mnl_attr_get_payload({innermost});')
                get_lines.append(f'{level} = mnl_attr_get_type(attr_{level});')
                innermost = 'attr_' + level

            extra_args = f", {val_decls}"

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {innermost}{extra_args});")
        return get_lines, init_lines, local_vars
516
517
class Struct:
    """A C struct rendered from (a subset of) an attribute set."""

    def __init__(self, family, space_name, type_list=None, inherited=None):
        """Collect the members of the struct.

        `type_list` names the attributes to include; when it is None or
        empty every attribute of the set `space_name` is included.
        """
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        lower_space = c_lower(space_name)
        # A space named like the family renders without a suffix.
        if family.name == lower_space:
            self.render_name = f"{family.name}"
        else:
            self.render_name = f"{family.name}_{lower_space}"
        self.struct_name = 'struct ' + self.render_name
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        names = type_list if type_list else self.attr_set
        self.attr_list = [(t, self.attr_set[t]) for t in names]

        # Index the members by name and remember the one with the highest
        # value (the last one wins on equal values).
        self.attrs = {}
        top_attr = None
        top_val = 0
        for name, attr in self.attr_list:
            self.attrs[name] = attr
            if attr.value >= top_val:
                top_attr, top_val = attr, attr.value
        self.attr_max_val = top_attr

    def __iter__(self):
        """Iterate over the member names."""
        for key in self.attrs:
            yield key

    def __getitem__(self, key):
        """Return the member attribute called `key`."""
        return self.attrs[key]

    def member_list(self):
        """Return the list of (name, attribute) pairs."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record the inherited members as sorted C names.

        Raises Exception if `new_inherited` differs from the value given
        at construction time.
        """
        if new_inherited != self._inherited:
            raise Exception("Inheriting different members not supported")
        ordered = sorted(self._inherited)
        self.inherited = [c_lower(x) for x in ordered]
568
569
class EnumEntry(SpecEnumEntry):
    """Enum entry with the extra state needed to render it in C."""

    def __init__(self, enum_set, yaml, prev, value_start):
        """Work out whether the entry's value has to be stated explicitly."""
        super().__init__(enum_set, yaml, prev, value_start)

        # Value which would follow from the previous entry (0 for the
        # first one); entries of 'flags' sets are always marked as changed.
        expected = prev.value + 1 if prev else 0
        self.value_change = self.value != expected or self.enum_set['type'] == 'flags'

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute the C name from the set's value prefix and the name."""
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
588
589
class EnumSet(SpecEnumSet):
    """Enum set with the names needed to render it in C."""

    def __init__(self, family, yaml):
        """Derive the C names, then initialize the base class."""
        # These are assigned before the base class is initialized.
        qualified = family.name + '-' + yaml['name']
        self.render_name = c_lower(qualified)
        self.enum_name = 'enum ' + self.render_name

        self.value_pfx = yaml.get('name-prefix', qualified + '-')

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        """Create the EnumEntry for one entry of the set."""
        return EnumEntry(self, entry, prev_entry, value_start)
601
602
class AttrSet(SpecAttrSet):
    """Attribute set with the C naming used by the code generator."""

    def __init__(self, family, yaml):
        """Compute the enum name prefix and the name of the max value."""
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets share the naming of the set they are a subset of.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            default_max = f"{self.name_prefix}max"
            self.max_name = c_upper(self.yaml.get('attr-max-name', default_max))

        # Added by resolve:
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute c_name; it is '' when equal to the family's c_name."""
        c_name = c_lower(self.name)
        if c_name in _C_KW:
            c_name += '_'
        if c_name == self.family.c_name:
            c_name = ''
        self.c_name = c_name

    def new_attr(self, elem, value):
        """Create the Type sub-class instance matching the spec entry.

        Raises Exception for a 'type' with no matching class.
        """
        if elem.get('multi-attr'):
            cls = TypeMultiAttr
        elif elem['type'] in scalars:
            cls = TypeScalar
        else:
            by_type = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }
            cls = by_type.get(elem['type'])
            if cls is None:
                raise Exception(f"No typed class for type {elem['type']}")
        return cls(self.family, self, elem, value)
654
655
class Operation(SpecOperation):
    """Operation with the C naming used by the code generator."""

    def __init__(self, family, yaml, req_value, rsp_value):
        """Set up the render name and the dual policy flag.

        Raises Exception when the request and response values differ.
        """
        super().__init__(family, yaml, req_value, rsp_value)

        if req_value != rsp_value:
            raise Exception("Directional messages not supported by codegen")

        self.render_name = family.name + '_' + c_lower(self.name)

        # True only if both 'do' and 'dump' carry a 'request'.
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        # Added by resolve:
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute enum_name using the family's (async) op prefix."""
        self.resolve_up(super())

        family = self.family
        prefix = family.async_op_prefix if self.is_async else family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def add_notification(self, op):
        """Append `op` to the list of notifications of this operation.

        The 'notify' entry is created on first use, with its 'reply'
        taken from the reply of 'do'.
        """
        if 'notify' not in self.yaml:
            self.yaml['notify'] = {
                'reply': self.yaml['do']['reply'],
                'cmds': [],
            }
        self.yaml['notify']['cmds'].append(op)
686
687
class Family(SpecFamily):
    """Whole family spec, extended with the state needed for C codegen."""

    def __init__(self, file_name):
        """Load the spec from `file_name` and derive the family level names."""
        # Added by resolve:
        # (assigned and deleted right away, so the attributes are listed
        # here but reading them before resolve() still raises AttributeError)
        self.c_name = None
        delattr(self, "c_name")
        self.op_prefix = None
        delattr(self, "op_prefix")
        self.async_op_prefix = None
        delattr(self, "async_op_prefix")
        self.mcgrps = None
        delattr(self, "mcgrps")
        self.consts = None
        delattr(self, "consts")
        self.hooks = None
        delattr(self, "hooks")

        super().__init__(file_name)

        self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
        self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))

        if 'definitions' not in self.yaml:
            self.yaml['definitions'] = []

        if 'uapi-header' in self.yaml:
            self.uapi_header = self.yaml['uapi-header']
        else:
            self.uapi_header = f"linux/{self.name}.h"

    def resolve(self):
        """Compute the derived state: prefixes, hooks, root/nested sets.

        Raises Exception for protocols other than the genetlink variants.
        """
        self.resolve_up(super())

        if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
            raise Exception("Codegen only supported for genetlink")

        self.c_name = c_lower(self.name)
        if 'name-prefix' in self.yaml['operations']:
            self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
        else:
            self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
        if 'async-prefix' in self.yaml['operations']:
            self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
        else:
            self.async_op_prefix = self.op_prefix

        self.mcgrps = self.yaml.get('mcast-groups', {'list': []})

        # hooks[when][op_mode] holds both a set (for de-duplication) and a
        # list (to keep the order in which the hooks were first seen).
        self.hooks = dict()
        for when in ['pre', 'post']:
            self.hooks[when] = dict()
            for op_mode in ['do', 'dump']:
                self.hooks[when][op_mode] = dict()
                self.hooks[when][op_mode]['set'] = set()
                self.hooks[when][op_mode]['list'] = []

        # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
        self.root_sets = dict()
        # dict space-name -> set('request', 'reply')
        self.pure_nested_structs = dict()
        self.all_notify = dict()

        self._mock_up_events()

        self._dictify()
        self._load_root_sets()
        self._load_nested_sets()
        self._load_all_notify()
        self._load_hooks()

        self.kernel_policy = self.yaml.get('kernel-policy', 'split')
        if self.kernel_policy == 'global':
            self._load_global_policy()

    def new_enum(self, elem):
        """Create the EnumSet for one enum definition."""
        return EnumSet(self, elem)

    def new_attr_set(self, elem):
        """Create the AttrSet for one attribute set definition."""
        return AttrSet(self, elem)

    def new_operation(self, elem, req_value, rsp_value):
        """Create the Operation for one operation definition."""
        return Operation(self, elem, req_value, rsp_value)

    # Fake a 'do' equivalent of all events, so that we can render their response parsing
    def _mock_up_events(self):
        for op in self.yaml['operations']['list']:
            if 'event' in op:
                op['do'] = {
                    'reply': {
                        'attributes': op['event']['attributes']
                    }
                }

    def _dictify(self):
        """Register every message with a 'notify' key on the op it names."""
        ntf = []
        for msg in self.msgs.values():
            if 'notify' in msg:
                ntf.append(msg)
        for n in ntf:
            self.ops[n['notify']].add_notification(n)

    def _load_root_sets(self):
        """Collect, per attribute set, the attrs used in requests/replies."""
        for op_name, op in self.ops.items():
            if 'attribute-set' not in op:
                continue

            req_attrs = set()
            rsp_attrs = set()
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    req_attrs.update(set(op[op_mode]['request']['attributes']))
                if op_mode in op and 'reply' in op[op_mode]:
                    rsp_attrs.update(set(op[op_mode]['reply']['attributes']))

            if op['attribute-set'] not in self.root_sets:
                self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
            else:
                self.root_sets[op['attribute-set']]['request'].update(req_attrs)
                self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)

    def _load_nested_sets(self):
        """Create the Struct of every set nested in a root set.

        Also marks the nested structs as used in requests and/or replies
        and records the members they inherit from the enclosing attribute.
        """
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    inherit = set()
                    nested = spec['nested-attributes']
                    if nested not in self.root_sets:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

                    if 'type-value' in spec:
                        if nested in self.root_sets:
                            raise Exception("Inheriting members to a space used as root not supported")
                        inherit.update(set(spec['type-value']))
                    elif spec['type'] == 'array-nest':
                        inherit.add('idx')
                    self.pure_nested_structs[nested].set_inherited(inherit)

    def _load_all_notify(self):
        """Map the name of each op with notifications to their list."""
        for op_name, op in self.ops.items():
            if not op:
                continue

            if 'notify' in op:
                self.all_notify[op_name] = op['notify']['cmds']

    def _load_global_policy(self):
        """Build the list of attributes covered by the global policy.

        The list holds, in attribute set order, every attribute used in
        the request of any 'do' or 'dump'.  Raises Exception if the
        operations do not all use the same attribute set.
        """
        global_set = set()
        attr_set_name = None
        for op_name, op in self.ops.items():
            if not op:
                continue
            if 'attribute-set' not in op:
                continue

            if attr_set_name is None:
                attr_set_name = op['attribute-set']
            if attr_set_name != op['attribute-set']:
                raise Exception('For a global policy all ops must use the same set')

            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    # The names live under 'attributes'; updating the set
                    # with the request mapping itself would add its keys.
                    req = op[op_mode].get('request')
                    if req:
                        global_set.update(req.get('attributes', []))

        self.global_policy = []
        self.global_policy_set = attr_set_name
        for attr in self.attr_sets[attr_set_name]:
            if attr in global_set:
                self.global_policy.append(attr)

    def _load_hooks(self):
        """Collect the unique 'pre'/'post' hooks of all ops, in order."""
        for op in self.ops.values():
            for op_mode in ['do', 'dump']:
                if op_mode not in op:
                    continue
                for when in ['pre', 'post']:
                    if when not in op[op_mode]:
                        continue
                    name = op[op_mode][when]
                    if name in self.hooks[when][op_mode]['set']:
                        continue
                    self.hooks[when][op_mode]['set'].add(name)
                    self.hooks[when][op_mode]['list'].append(name)
873
874
class RenderInfo:
    """Everything the render helpers need to know about one thing to emit.

    Holds the family, the code writer, the operation with its name and
    mode, the attribute set in use and the request/reply Struct objects
    built from the operation's attribute lists.
    """

    def __init__(self, cw, family, ku_space, op, op_name, op_mode, attr_set=None):
        """Build the render context.

        cw is the code writer (its nlib is kept in self.nl); op may be
        falsy, in which case attr_set names the set being rendered and
        also provides the type name.
        """
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op = op
        self.op_name = op_name
        self.op_mode = op_mode

        # 'do' and 'dump' response parsing is identical.  Both replies must
        # exist before they are compared, a 'dump' without a reply must not
        # raise KeyError here.
        if op_mode != 'do' and 'dump' in op and 'do' in op and \
           'reply' in op['do'] and 'reply' in op['dump'] and \
           op["do"]["reply"] == op["dump"]["reply"]:
            self.type_consistent = True
        else:
            self.type_consistent = op_mode == 'event'

        self.attr_set = attr_set
        if not self.attr_set:
            self.attr_set = op['attribute-set']

        if op:
            self.type_name = c_lower(op_name)
        else:
            self.type_name = c_lower(attr_set)

        self.cw = cw

        # Struct per message direction, only for directions the op defines.
        self.struct = {}
        for op_dir in ['request', 'reply']:
            if op and op_dir in op[op_mode]:
                self.struct[op_dir] = Struct(family, self.attr_set,
                                             type_list=op[op_mode][op_dir]['attributes'])
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
909
910
class CodeWriter:
    """Line-oriented writer for the generated C code.

    Keeps track of the indentation depth and of a pending blank line, so
    that callers describe the output as a sequence of lines and blocks.
    """

    def __init__(self, nlib, out_file):
        """Create a writer sending its output to out_file.write().

        nlib is only stored (as self.nlib) for the benefit of the render
        helpers.
        """
        self.nlib = nlib

        self._nl = False            # blank line pending before the next line
        self._silent_block = False  # previous line was a brace-less condition
        self._ind = 0               # indentation depth, in tabs
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        """Return True if line starts with 'if', 'while' or 'for'."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line, indented to the current depth plus add_ind.

        Lines ending in ':' are moved out by one level, and the line
        following a condition which ends in ')' is moved in by one level.
        An empty line is written without indentation.
        """
        if self._nl:
            self._out.write('\n')
            self._nl = False
        ind = self._ind
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        # No trailing whitespace for an empty line.
        prefix = '\t' * ind if line else ''
        self._out.write(prefix + line + '\n')

    def nl(self):
        """Request a blank line before the next line written, if any."""
        self._nl = True

    def block_start(self, line=''):
        """Write line followed by '{' and indent what follows."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block with '}', followed by line if given.

        A space separates '}' from line unless line starts with ';' or ','.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write doc as ' *' comment lines, wrapped below 80 columns.

        With indent set, continuation lines get two extra spaces.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a function prototype, wrapping the arguments if needed.

        qual_ret is the qualified return type, args the list of argument
        declarations (['void'] if empty), doc an optional comment written
        above the prototype and suffix the text after the closing ')'.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if not qual_ret.endswith('*'):
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            # Anything longer than a short type goes on a line of its own.
            self.p(v)
            v = ''
        elif not qual_ret.endswith('*'):
            v += ' '
        v += name + '('
        # Continuation lines align with the first argument, using tabs.
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        for arg in args[1:]:
            next_len = len(v) + len(arg)
            if v.startswith('\t'):
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += arg
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        local_vars is a single declaration or a sequence of them; the
        caller's sequence is left untouched.  Nothing is written if it is
        empty.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, local variables and body."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)

        # Declarations belong inside the braces of the function.
        self.block_start()
        self.write_func_lvar(local_vars=local_vars)
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write '#define' lines for (name, value) pairs, values tab-aligned.

        Integer values are written as is, string values in double quotes.
        """
        longest = max((len(define[0]) for define in defines), default=0)
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write '.member = value,' lines for (member, value) pairs, '=' tab-aligned."""
        longest = max((len(one[0]) for one in members), default=0)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1052
1053
# Attribute types which are handled as plain scalars.
scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

# Suffix added to generated names for each message direction.
direction_to_suffix = {
    'request': '_req',
    'reply': '_rsp',
    '': '',
}

# Suffix of the type wrapping the reply of each operation mode.
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# Names which get a trailing underscore when used as a C identifier.
_C_KW = {'do'}
1072
1073
def rdir(direction):
    """Return the opposite of direction ('request' <-> 'reply').

    Any other value is returned unchanged.
    """
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1080
1081
def op_prefix(ri, direction, deref=False):
    """Return the C name prefix for the type of ri's op in the given direction.

    The result is '<family>_<type>' plus a suffix chosen from the
    operation mode, the direction, whether the reply type is shared
    between 'do' and 'dump' (ri.type_consistent) and whether the wrapped
    object (deref) or its wrapper is wanted.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family['name']}_{ri.type_name}{tail}"
1101
1102
def type_name(ri, direction, deref=False):
    """Return the C struct type ('struct <prefix>') matching op_prefix()."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1105
1106
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the function performing ri's operation.

    The function takes the socket and, if the op mode has a request, a
    pointer to the request struct.  It returns a pointer to the type of
    the opposite direction if the op mode has a reply, 'int' otherwise.
    With terminate set the prototype ends in ';'.
    """
    func_name = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    params = ['struct ynl_sock *ys']
    if 'request' in ri.op[ri.op_mode]:
        params.append(f"{type_name(ri, direction)} *{direction_to_suffix[direction][1:]}")

    ret_type = 'int'
    if 'reply' in ri.op[ri.op_mode]:
        ret_type = f"{type_name(ri, rdir(direction))} *"

    ri.cw.write_func_prot(ret_type, func_name, params, doc=doc,
                          suffix=';' if terminate else '')
1123
1124
def print_req_prototype(ri):
    """Write the terminated request prototype, documented with the op's 'doc'."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1127
1128
def print_dump_prototype(ri):
    """Write the terminated prototype for ri in the request direction, without a doc comment."""
    print_prototype(ri, direction="request")
1131
1132
def put_typol_fwd(cw, struct):
    """Write the extern declaration of struct's policy nest."""
    nest_sym = f'{struct.render_name}_nest'
    cw.p('extern struct ynl_policy_nest ' + nest_sym + ';')
1135
1136
def put_typol(cw, struct):
    """Write struct's policy table and the policy nest pointing at it.

    The table has one slot per attribute up to the set's max_name; each
    member writes its own entry via attr_typol().
    """
    max_attr = struct.attr_set.max_name
    base = struct.render_name

    cw.block_start(line=f'struct ynl_policy_attr {base}_policy[{max_attr} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {base}_nest =')
    for init in (f'.max_attr = {max_attr},', f'.table = {base}_policy,'):
        cw.p(init)
    cw.block_end(line=';')
    cw.nl()
1152
1153
def put_req_nested(ri, struct):
    """Write the '<struct>_put' function adding struct as a nested attribute.

    The generated function opens a nest of type attr_type, lets every
    member emit its own put code for 'obj', closes the nest and returns 0.
    """
    cw = ri.cw
    proto_args = ['struct nlmsghdr *nlh',
                  'unsigned int attr_type',
                  f'{struct.ptr_name}obj']

    cw.write_func_prot('int', f'{struct.render_name}_put', proto_args)
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for _, member in struct.member_list():
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1174
1175
def _multi_parse(ri, struct, init_lines, local_vars):
    """Write the body of a parse function filling 'dst' from netlink attributes.

    The prototype must already have been written.  local_vars and
    init_lines are extended in place with what the members of struct
    need, then the generated body walks the attributes once, and walks
    again for array nests and multi-attrs after allocating their arrays.

    Raises:
        Exception: for a multi-attr which is neither nested nor a scalar.
    """
    cw = ri.cw

    def start_array(c_name):
        # Allocate the array once the first walk has counted the entries.
        cw.block_start(line=f"if (dst->n_{c_name})")
        cw.p(f"dst->{c_name} = calloc(dst->n_{c_name}, sizeof(*dst->{c_name}));")
        cw.p('i = 0;')

    if struct.nested:
        walk_attrs = "mnl_attr_for_each_nested(attr, nested)"
    else:
        walk_attrs = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for name, member in struct.member_list():
        if member['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{member.c_name};')
            array_nests.add(name)
        if 'multi-attr' in member:
            multi_attrs.add(name)
        if 'nested-attributes' in member:
            needs_parg = True
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    cw.block_start()
    cw.write_func_lvar(local_vars)

    for line in init_lines:
        cw.p(line)
    cw.nl()

    for inherited in struct.inherited:
        cw.p(f'dst->{inherited} = {inherited};')

    cw.nl()
    cw.block_start(line=walk_attrs)

    for pos, (_, member) in enumerate(struct.member_list()):
        member.attr_get(ri, 'dst', first=(pos == 0))

    cw.block_end()
    cw.nl()

    for name in sorted(array_nests):
        member = struct[name]

        start_array(member.c_name)
        cw.p(f"parg.rsp_policy = &{member.nested_render_name}_nest;")
        cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{member.c_name})")
        cw.p(f"parg.data = &dst->{member.c_name}[i];")
        cw.p(f"if ({member.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        cw.p('return MNL_CB_ERROR;')
        cw.p('i++;')
        cw.block_end()
        cw.block_end()
    cw.nl()

    for name in sorted(multi_attrs):
        member = struct[name]
        is_nest = 'nested-attributes' in member

        start_array(member.c_name)
        if is_nest:
            cw.p(f"parg.rsp_policy = &{member.nested_render_name}_nest;")
        cw.block_start(line=walk_attrs)
        cw.block_start(line=f"if (mnl_attr_get_type(attr) == {member.enum_name})")
        if is_nest:
            cw.p(f"parg.data = &dst->{member.c_name}[i];")
            cw.p(f"if ({member.nested_render_name}_parse(&parg, attr))")
            cw.p('return MNL_CB_ERROR;')
        elif member['type'] in scalars:
            # Signed types are read with the unsigned getter of the same width.
            getter = member['type']
            if getter.startswith('s'):
                getter = 'u' + getter[1:]
            cw.p(f"dst->{member.c_name}[i] = mnl_attr_get_{getter}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        cw.p('i++;')
        cw.block_end()
        cw.block_end()
        cw.block_end()
    cw.nl()

    cw.p('return 0;' if struct.nested else 'return MNL_CB_OK;')
    cw.block_end()
    cw.nl()
1267
1268
def parse_rsp_nested(ri, struct):
    """Write the '<struct>_parse' function for a nested attribute space.

    Besides the parse argument and the nest, the generated function takes
    one '__u32' argument per member struct inherits.
    """
    func_args = ['struct ynl_parse_arg *yarg',
                 'const struct nlattr *nested']
    func_args.extend('__u32 ' + inherited for inherited in struct.inherited)

    local_vars = ['const struct nlattr *attr;',
                  f'{struct.ptr_name}dst = yarg->data;']

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)

    _multi_parse(ri, struct, [], local_vars)
1282
1283
def parse_rsp_msg(ri, deref=False):
    """Write the callback parsing a reply (or event) message of ri's op.

    Nothing is written if the op mode has no reply and is not 'event'.
    """
    if not ('reply' in ri.op[ri.op_mode] or ri.op_mode == 'event'):
        return

    func_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    func_args = ['const struct nlmsghdr *nlh', 'void *data']

    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
                  'struct ynl_parse_arg *yarg = data;',
                  'const struct nlattr *attr;']

    ri.cw.write_func_prot('int', func_name, func_args)

    _multi_parse(ri, ri.struct["reply"], ['dst = yarg->data;'], local_vars)
1299
1300
def print_req(ri):
    """Write the function sending ri's request and handling the response.

    If the op mode has a reply the generated function allocates the reply,
    parses the received message into it and returns it (NULL on error);
    otherwise it returns 0 on success and -1 on error.
    """
    cw = ri.cw
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]

    local_vars = ['struct nlmsghdr *nlh;',
                  'int len, err;']
    if has_reply:
        ret_ok, ret_err = 'rsp', 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;',
                       'struct ynl_parse_arg yarg = { .ys = ys, };']
    else:
        ret_ok, ret_err = '0', '-1'

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    for line in ('err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);',
                 'if (err < 0)',
                 f"return {ret_err};"):
        cw.p(line)
    cw.nl()
    for line in ('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);',
                 'if (len < 0)',
                 f"return {ret_err};"):
        cw.p(line)
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yarg.data = rsp;')
        cw.nl()
        cw.p(f"err = {ri.nl.parse_cb_run(op_prefix(ri, 'reply') + '_parse', '&yarg', False)};")
        cw.p('if (err < 0)')
        cw.p('goto err_free;')
        cw.nl()

    for line in ('err = ynl_recv_ack(ys, err);',
                 'if (err)',
                 'goto err_free;'):
        cw.p(line)
    cw.nl()
    cw.p(f"return {ret_ok};")
    cw.nl()
    cw.p('err_free:')

    if has_reply:
        cw.p(call_free(ri, rdir(direction), 'rsp'))
    cw.p(f"return {ret_err};")
    cw.block_end()
1358
1359
def print_dump(ri):
    """Write the function performing ri's dump and returning the list of replies.

    The generated function sends the dump request, then keeps receiving
    and parsing messages while the parser returns a positive value.  On
    error the partial list is freed and NULL is returned.
    """
    cw = ri.cw
    direction = "request"

    print_prototype(ri, direction, terminate=False)
    cw.block_start()

    # Written in this exact order, not sorted by length.
    for var in ('struct ynl_dump_state yds = {};',
                'struct nlmsghdr *nlh;',
                'int len, err;'):
        cw.p(var)
    cw.nl()

    cw.p('yds.ys = ys;')
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    for line in ('err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);',
                 'if (err < 0)',
                 'return NULL;'):
        cw.p(line)
    cw.nl()

    cw.block_start(line='do')
    for line in ('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);',
                 'if (len < 0)',
                 'goto free_list;'):
        cw.p(line)
    cw.nl()
    cw.p(f"err = {ri.nl.parse_cb_run('ynl_dump_trampoline', '&yds', False, indent=2)};")
    cw.p('if (err < 0)')
    cw.p('goto free_list;')
    cw.block_end(line='while (err > 0);')
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1408
1409
def call_free(ri, direction, var):
    """Return the C statement calling the free function of ri's type on var."""
    free_func = op_prefix(ri, direction) + '_free'
    return f"{free_func}({var});"
1412
1413
def free_arg_name(direction):
    """Return the argument name used by free functions: 'req', 'rsp' or 'obj'.

    'obj' is used when no direction is given.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
1418
1419
def print_free_prototype(ri, direction, suffix=';'):
    """Write the prototype of the free function for ri's type in direction."""
    struct_name = op_prefix(ri, direction)
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{struct_name}_free", [param], suffix=suffix)
1424
1425
def _print_type(ri, direction, struct):
    """Write the C struct definition for struct.

    Presence members reported by the attributes are grouped in an
    anonymous '_present' struct, followed by the inherited members
    (as '__u32') and the members the attributes define themselves.
    """
    cw = ri.cw

    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if ri.op_mode == 'dump':
        suffix += '_dump'

    cw.block_start(line=f"struct {ri.family['name']}{suffix}")

    # The '_present' struct is opened lazily, on the first presence member.
    in_present = False
    for _, attr in struct.member_list():
        for type_filter in ('len', 'bit'):
            presence = attr.presence_member(ri.ku_space, type_filter)
            if not presence:
                continue
            if not in_present:
                cw.block_start(line="struct")
                in_present = True
            cw.p(presence)
    if in_present:
        cw.block_end(line='_present;')
        cw.nl()

    for inherited in struct.inherited:
        cw.p(f"__u32 {inherited};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1455
1456
def print_type(ri, direction):
    """Write the struct definition of ri's message in the given direction."""
    message_struct = ri.struct[direction]
    _print_type(ri, direction, message_struct)
1459
1460
def print_type_full(ri, struct):
    """Write the definition of struct without a direction suffix."""
    no_direction = ""
    _print_type(ri, no_direction, struct)
1463
1464
def print_type_helpers(ri, direction, deref=False):
    """Write the free prototype and, for user space requests, the setters."""
    print_free_prototype(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        for _, member in ri.struct[direction].member_list():
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1472
1473
def print_req_type_helpers(ri):
    """Write the helpers of ri's request type."""
    print_type_helpers(ri, direction="request")
1476
1477
def print_rsp_type_helpers(ri):
    """Write the helpers of ri's reply type, if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1482
1483
def print_parse_prototype(ri, direction, terminate=True):
    """Write the prototype of the function parsing attributes into ri's struct.

    The '_rsp' flavour is used for the 'reply' direction and '_req' for
    anything else; with terminate set the prototype ends in ';'.
    """
    tag = "_req" if direction != "reply" else "_rsp"
    base = f"{ri.op.render_name}{tag}"
    params = ['const struct nlattr **tb',
              f"struct {base} *req"]

    ri.cw.write_func_prot('void', f"{base}_parse", params,
                          suffix=';' if terminate else '')
1492
1493
def print_req_type(ri):
    """Write the struct definition of ri's request."""
    print_type(ri, direction="request")
1496
1497
def print_rsp_type(ri):
    """Write the struct definition of ri's reply, if there is one.

    A 'do' or 'dump' needs a 'reply' entry; an 'event' always has a reply
    struct.  Nothing is written for any other op mode.
    """
    mode = ri.op_mode
    if mode in ('do', 'dump'):
        has_reply = 'reply' in ri.op[mode]
    else:
        has_reply = mode == 'event'

    if has_reply:
        print_type(ri, 'reply')
1506
1507
def print_wrapped_type(ri):
    """Write the struct wrapping ri's reply object and its free prototype.

    For a dump the wrapper carries a 'next' pointer; for a notification
    or event it carries the family, the command and a free callback.  The
    reply itself is the 8 byte aligned 'obj' member.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=wrapper)
    if ri.op_mode == 'dump':
        cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        cw.p('__u16 family;')
        cw.p('__u8 cmd;')
        cw.p(f"void (*free)({wrapper} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
1521
1522
def _free_type_members_iter(ri, struct):
    """Declare the iterator 'i' if freeing any member of struct needs one."""
    if any(member.free_needs_iter() for _, member in struct.member_list()):
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
1529
1530
def _free_type_members(ri, var, struct, ref=''):
    """Let every member of struct write the code freeing it.

    var and ref are passed through to each member's free() unchanged.
    """
    members = (member for _, member in struct.member_list())
    for member in members:
        member.free(ri, var, ref)
1534
1535
def _free_type(ri, direction, struct):
    """Write the free function for struct.

    The members are freed first; the object itself is only freed when a
    direction is given.
    """
    cw = ri.cw
    var = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, var, struct)
    if direction:
        cw.p(f'free({var});')
    cw.block_end()
    cw.nl()
1547
1548
def free_rsp_nested(ri, struct):
    """Write the free function of a nested struct (no direction)."""
    no_direction = ""
    _free_type(ri, no_direction, struct)
1551
1552
def print_rsp_free(ri):
    """Write the free function of ri's reply, if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
1557
1558
def print_dump_type_free(ri):
    """Write the function freeing a whole list of dump replies.

    The generated code walks the list through 'next', freeing the members
    of each entry's 'obj' and then the entry itself.
    """
    cw = ri.cw
    list_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{list_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while (next)')
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
1577
1578
def print_ntf_type_free(ri):
    """Write the function freeing a notification: the members of its 'obj', then itself."""
    cw = ri.cw

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    reply = ri.struct['reply']
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', ri.struct['reply'], ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1587
1588
def print_ntf_parse_prototype(family, cw, suffix=';'):
    """Write the prototype of the family's notification parsing function."""
    func_name = f"{family['name']}_ntf_parse"
    cw.write_func_prot('struct ynl_ntf_base_type *', func_name,
                       ['struct ynl_sock *ys'], suffix=suffix)
1592
1593
def print_ntf_type_parse(family, cw, ku_mode):
    """Write the function receiving and parsing one notification or event.

    The generated function switches on the received command to pick the
    allocation size, parser, policy and free callback, with one group of
    cases per op in family.all_notify (in sorted order) followed by one
    case per op which has an 'event' entry.
    """
    def emit_case_body(ri, wrapper_dir):
        # Shared by the notification and the event cases.
        cw.p(f"rsp = calloc(1, sizeof({type_name(ri, wrapper_dir)}));")
        cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;")
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
        cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;")
        cw.p('break;')

    print_ntf_parse_prototype(family, cw, suffix='')
    cw.block_start()
    cw.write_func_lvar(['struct genlmsghdr *genlh;',
                        'struct nlmsghdr *nlh;',
                        'struct ynl_parse_arg yarg = { .ys = ys, };',
                        'struct ynl_ntf_base_type *rsp;',
                        'int len, err;',
                        'mnl_cb_t parse;'])
    for line in ('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);',
                 'if (len < (ssize_t)(sizeof(*nlh) + sizeof(*genlh)))',
                 'return NULL;'):
        cw.p(line)
    cw.nl()
    cw.p('nlh = (struct nlmsghdr *)ys->rx_buf;')
    cw.p('genlh = mnl_nlmsg_get_payload(nlh);')
    cw.nl()

    cw.block_start(line='switch (genlh->cmd)')
    for ntf_op in sorted(family.all_notify):
        op = family.ops[ntf_op]
        ri = RenderInfo(cw, family, ku_mode, op, ntf_op, "notify")
        for ntf in op['notify']['cmds']:
            cw.p(f"case {ntf.enum_name}:")
        emit_case_body(ri, 'notify')
    for op_name, op in family.ops.items():
        if 'event' not in op:
            continue
        ri = RenderInfo(cw, family, ku_mode, op, op_name, "event")
        cw.p(f"case {op.enum_name}:")
        emit_case_body(ri, 'event')
    cw.p('default:')
    cw.p('ynl_error_unknown_notification(ys, genlh->cmd);')
    cw.p('return NULL;')
    cw.block_end()
    cw.nl()

    cw.p('yarg.data = rsp->data;')
    cw.nl()
    cw.p(f"err = {cw.nlib.parse_cb_run('parse', '&yarg', True)};")
    cw.p('if (err < 0)')
    cw.p('goto err_free;')
    cw.nl()
    for line in ('rsp->family = nlh->nlmsg_type;',
                 'rsp->cmd = genlh->cmd;',
                 'return rsp;'):
        cw.p(line)
    cw.nl()
    for line in ('err_free:',
                 'free(rsp);',
                 'return NULL;'):
        cw.p(line)
    cw.block_end()
    cw.nl()
1651
1652
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Write the first line of a kernel policy table, or its declaration.

    With terminate set an 'extern' declaration is written, unless the
    policy belongs to an op (ri given) of a family for which
    kernel_can_gen_family_struct() holds, in which case nothing is
    written.  Without terminate the opening line of the definition is
    written, 'static' under that same condition.

    The table is named after the op when ri is given (with the op mode
    appended for ops with a dual policy), after struct otherwise.
    """
    if terminate:
        if ri and kernel_can_gen_family_struct(struct.family):
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if kernel_can_gen_family_struct(struct.family) and ri else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1675
1676
def print_req_policy(cw, struct, ri=None):
    """Write a complete kernel policy table: opening line, one entry per member, '};'."""
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
1682
1683
def kernel_can_gen_family_struct(family):
    """Return True if the family's protocol is 'genetlink'."""
    proto = family.proto
    return proto == 'genetlink'
1686
1687
def print_kernel_op_table_fwd(family, cw, terminate):
    """Write the kernel ops table declaration, or open its definition.

    Without terminate, the block defining '<family>_nl_ops' is opened and
    nothing else is written.  With terminate, the table is declared
    'extern' (only for families for which kernel_can_gen_family_struct()
    does not hold), followed by the prototypes of the pre/post hooks and
    of the doit/dumpit handler of every op which is not async.
    """
    exported = not kernel_can_gen_family_struct(family)

    if not terminate or exported:
        cw.p(f"/* Ops table for {family.name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        if family.kernel_policy == 'split':
            # A split table has one entry per 'do' and one per 'dump'.
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    do_hook_args = ['const struct genl_split_ops *ops',
                    'struct sk_buff *skb', 'struct genl_info *info']
    dump_hook_args = ['struct netlink_callback *cb']
    hook_protos = (('pre', 'do', 'int', do_hook_args),
                   ('post', 'do', 'void', do_hook_args),
                   ('pre', 'dump', 'int', dump_hook_args),
                   ('post', 'dump', 'int', dump_hook_args))

    cw.nl()
    for when, mode, ret, args in hook_protos:
        for name in family.hooks[when][mode]['list']:
            cw.write_func_prot(ret, c_lower(name), list(args), suffix=';')

    cw.nl()

    handlers = (('do', 'struct genl_info *info'),
                ('dump', 'struct netlink_callback *cb'))
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        for mode, second_arg in handlers:
            if mode in op:
                name = c_lower(f"{family.name}-nl-{op_name}-{mode}it")
                cw.write_func_prot('int', name,
                                   ['struct sk_buff *skb', second_arg], suffix=';')
    cw.nl()
1751
1752
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-file portion of the kernel ops table for `family`.

    Thin wrapper around print_kernel_op_table_fwd() with terminate=True,
    the mode in which that helper emits declarations and callback
    prototypes rather than opening the table initializer.
    """
    print_kernel_op_table_fwd(family, cw, terminate=True)
1755
1756
def print_kernel_op_table(family, cw):
    """Write the initializer of the family's kernel ops table to `cw`.

    print_kernel_op_table_fwd() (terminate=False) opens the array. For the
    'global' and 'per-op' policies one initializer block is then written
    per non-async operation; for the 'split' policy one block is written
    per non-async (operation, do/dump) pair. The array is closed at the
    end in every case.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)

    policy_kind = family.kernel_policy
    if policy_kind in ('global', 'per-op'):
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.block_start()
            fields = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                skipped = [c_upper('genl-dont-validate-' + what)
                           for what in op['dont-validate']]
                fields.append(('validate', ' | '.join(skipped)))
            for mode in ('do', 'dump'):
                if mode not in op:
                    continue
                handler = c_lower(f"{family.name}-nl-{op_name}-{mode}it")
                fields.append((mode + 'it', handler))
            if policy_kind == 'per-op':
                # The per-op entry always references the 'do' request policy.
                req_struct = Struct(family, op['attribute-set'],
                                    type_list=op['do']['request']['attributes'])

                policy = c_lower(f"{family.name}-{op_name}-nl-policy")
                fields.append(('policy', policy))
                fields.append(('maxattr', req_struct.attr_max_val.enum_name))
            if 'flags' in op:
                op_flags = [c_upper('genl-' + flag) for flag in op['flags']]
                fields.append(('flags', ' | '.join(op_flags)))
            cw.write_struct_init(fields)
            cw.block_end(line=',')
    elif policy_kind == 'split':
        # Struct member names of the pre/post hooks for each op mode.
        hook_fields = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                       'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for mode in ('do', 'dump'):
                if op.is_async or mode not in op:
                    continue

                cw.block_start()
                mode_spec = op[mode]
                fields = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    skipped = [c_upper('genl-dont-validate-' + what)
                               for what in op['dont-validate']]
                    fields.append(('validate', ' | '.join(skipped)))
                handler = c_lower(f"{family.name}-nl-{op_name}-{mode}it")
                if 'pre' in mode_spec:
                    fields.append((hook_fields[mode]['pre'],
                                   c_lower(mode_spec['pre'])))
                fields.append((mode + 'it', handler))
                if 'post' in mode_spec:
                    fields.append((hook_fields[mode]['post'],
                                   c_lower(mode_spec['post'])))
                if 'request' in mode_spec:
                    req_struct = Struct(family, op['attribute-set'],
                                        type_list=mode_spec['request']['attributes'])

                    # Ops with distinct do/dump policies get a per-mode name.
                    if op.dual_policy:
                        policy = c_lower(f"{family.name}-{op_name}-{mode}-nl-policy")
                    else:
                        policy = c_lower(f"{family.name}-{op_name}-nl-policy")
                    fields.append(('policy', policy))
                    fields.append(('maxattr', req_struct.attr_max_val.enum_name))
                base_flags = op['flags'] if 'flags' in op else []
                all_flags = base_flags + ['cmd-cap-' + mode]
                fields.append(('flags',
                               ' | '.join([c_upper('genl-' + flag) for flag in all_flags])))
                cw.write_struct_init(fields)
                cw.block_end(line=',')

    cw.block_end(line=';')
    cw.nl()
1823
1824
def print_kernel_mcgrp_hdr(family, cw):
    """Write an anonymous enum with one entry per multicast group.

    Nothing is written when the family declares no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for group in groups:
        enumerator = c_upper(f"{family.name}-nlgrp-{group['name']}")
        cw.p(enumerator + ',')
    cw.block_end(';')
    cw.nl()
1835
1836
def print_kernel_mcgrp_src(family, cw):
    """Write the genl_multicast_group array definition for the family.

    Each group becomes a designated initializer indexed by its NLGRP
    enumerator. Nothing is written when the family has no groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f"static const struct genl_multicast_group {family.name}_nl_mcgrps[] =")
    for group in groups:
        group_name = group['name']
        index = c_upper(f"{family.name}-nlgrp-{group_name}")
        cw.p(f'[{index}] = {{ "{group_name}", }},')
    cw.block_end(';')
    cw.nl()
1848
1849
def print_kernel_family_struct_hdr(family, cw):
    """Write the extern declaration of the family's genl_family struct.

    Only done when kernel_can_gen_family_struct() approves the family.
    """
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
1856
1857
def print_kernel_family_struct_src(family, cw):
    """Write the definition of the family's genl_family struct.

    Skipped entirely unless kernel_can_gen_family_struct() approves the
    family. The ops members depend on the kernel policy ('per-op' or
    'split'); multicast group members are added only when groups exist.
    """
    if not kernel_can_gen_family_struct(family):
        return

    fam = family.name
    cw.block_start(f"struct genl_family {fam}_nl_family __ro_after_init =")

    inits = [
        '.name\t\t= ' + family.fam_key + ',',
        '.version\t= ' + family.ver_key + ',',
        '.netnsok\t= true,',
        '.parallel_ops\t= true,',
        '.module\t\t= THIS_MODULE,',
    ]
    policy_kind = family.kernel_policy
    if policy_kind == 'per-op':
        inits.append(f'.ops\t\t= {fam}_nl_ops,')
        inits.append(f'.n_ops\t\t= ARRAY_SIZE({fam}_nl_ops),')
    elif policy_kind == 'split':
        inits.append(f'.split_ops\t= {fam}_nl_ops,')
        inits.append(f'.n_split_ops\t= ARRAY_SIZE({fam}_nl_ops),')
    if family.mcgrps['list']:
        inits.append(f'.mcgrps\t\t= {fam}_nl_mcgrps,')
        inits.append(f'.n_mcgrps\t= ARRAY_SIZE({fam}_nl_mcgrps),')

    for init in inits:
        cw.p(init)
    cw.block_end(';')
1878
1879
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open an enum block on `cw`, choosing the enum's name from `obj`.

    If `obj` has the `enum_name` key, its value (when truthy) names the
    enum; a falsy value yields an anonymous enum. Otherwise, if `ckey` is
    given and present in `obj`, the name is the family name joined to
    that value. With neither key the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        header = 'enum ' + c_lower(explicit) if explicit else 'enum'
    elif ckey and ckey in obj:
        header = 'enum ' + family.name + '_' + c_lower(obj[ckey])
    else:
        header = 'enum'
    cw.block_start(line=header)
1888
1889
def render_uapi(family, cw):
    """Render the uAPI header for `family` through the code writer `cw`.

    Output order: include guard, family name/version defines, the spec's
    'definitions' (plain constants as defines, enums/flags as C enums with
    optional kdoc), one enum per attribute set (subsets are skipped), the
    command enum (plus a separate notification enum when the operations
    declare 'async-prefix'), multicast group name defines, and the closing
    guard.
    """
    # The guard must be a valid C identifier; c_upper() upper-cases and
    # also maps '-' to '_', which a bare .upper() would leave in place.
    hdr_prot = f"_UAPI_LINUX_{c_upper(family.name)}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            # A non-constant definition ends the current run of constants,
            # so hand the collected defines to the writer before moving on.
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] in ('enum', 'flags'):
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            for entry in enum.entries.values():
                suffix = ','
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
                cw.nl()
                max_name = c_upper(name_pfx + 'max')
                # __X_MAX counts the entries, X_MAX is the last valid one.
                cw.p('__' + max_name + ',')
                cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # NOTE(review): the 'c-define-name' override is looked up on the
            # family, not on the constant itself - confirm this is intended.
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        # Prefer an override given on the attribute set itself (mirroring
        # attr_set.max_name); the family-level lookup is kept as fallback
        # so existing output does not change.
        default_cnt = family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX")
        cnt_name = c_upper(attr_set.yaml.get('attr-cnt-name', default_cnt))
        max_value = f"({cnt_name} - 1)"

        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            # Only spell out the value when it breaks the implicit sequence.
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        # Notifications and events go to their own enum below.
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2028
2029
def find_kernel_root(full_path):
    """Locate the kernel tree that contains `full_path`.

    Walks up from `full_path` one directory at a time until a directory
    holding a "MAINTAINERS" file is found.

    Returns:
        A ``(root, sub_path)`` tuple: the directory containing MAINTAINERS
        and the path of `full_path` relative to it.

    Raises:
        FileNotFoundError: no ancestor directory contains MAINTAINERS.
    """
    start_path = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # Joining onto the initial '' left a trailing separator; drop it.
            return parent, sub_path[:-1]
        if parent == full_path:
            # dirname() reached its fixed point ('/' or ''): nothing is left
            # to search, so stop instead of looping forever.
            raise FileNotFoundError(
                f"no MAINTAINERS file found in any parent directory of '{start_path}'")
        full_path = parent
2038
2039
def _render_kernel_header(parsed, cw, mode):
    """Render the kernel-mode header: policy declarations and table externs."""
    if any(struct.request for struct in parsed.pure_nested_structs.values()):
        cw.p('/* Common nested types */')
    for _, struct in sorted(parsed.pure_nested_structs.items()):
        if struct.request:
            print_req_policy_fwd(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy_fwd(cw, struct)
        cw.nl()

    if parsed.kernel_policy in {'per-op', 'split'}:
        for op_name, op in parsed.ops.items():
            if 'do' in op and 'event' not in op:
                ri = RenderInfo(cw, parsed, mode, op, op_name, "do")
                print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                cw.nl()

    print_kernel_op_table_hdr(parsed, cw)
    print_kernel_mcgrp_hdr(parsed, cw)
    print_kernel_family_struct_hdr(parsed, cw)


def _render_kernel_source(parsed, cw, mode):
    """Render the kernel-mode source: policies, ops table, mcgrps, family."""
    if any(struct.request for struct in parsed.pure_nested_structs.values()):
        cw.p('/* Common nested types */')
    for _, struct in sorted(parsed.pure_nested_structs.items()):
        if struct.request:
            print_req_policy(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        print_req_policy(cw, struct)
        cw.nl()

    if parsed.kernel_policy in {'per-op', 'split'}:
        for op_name, op in parsed.ops.items():
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    cw.p(f"/* {op.enum_name} - {op_mode} */")
                    ri = RenderInfo(cw, parsed, mode, op, op_name, op_mode)
                    print_req_policy(cw, ri.struct['request'], ri=ri)
                    cw.nl()

    print_kernel_op_table(parsed, cw)
    print_kernel_mcgrp_src(parsed, cw)
    print_kernel_family_struct_src(parsed, cw)


def _render_user_header(parsed, cw, mode):
    """Render the user-mode header: types, helpers and prototypes per op."""
    has_ntf = False

    cw.p('/* Common nested types */')
    for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
        ri = RenderInfo(cw, parsed, mode, "", "", "", attr_set)
        print_type_full(ri, struct)

    for op_name, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")

        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, op_name, "do")
            print_req_type(ri)
            print_req_type_helpers(ri)
            cw.nl()
            print_rsp_type(ri)
            print_rsp_type_helpers(ri)
            cw.nl()
            print_req_prototype(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, op_name, 'dump')
            if 'request' in op['dump']:
                print_req_type(ri)
                print_req_type_helpers(ri)
            if not ri.type_consistent:
                print_rsp_type(ri)
            print_wrapped_type(ri)
            print_dump_prototype(ri)
            cw.nl()

        if 'notify' in op:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, op_name, 'notify')
            has_ntf = True
            if not ri.type_consistent:
                raise Exception('Only notifications with consistent types supported')
            print_wrapped_type(ri)

        if 'event' in op:
            # NOTE(review): the source renderer sets has_ntf for events too,
            # the header does not - confirm whether that is intentional.
            ri = RenderInfo(cw, parsed, mode, op, op_name, 'event')
            cw.p(f"/* {op.enum_name} - event */")
            print_rsp_type(ri)
            cw.nl()
            print_wrapped_type(ri)

    if has_ntf:
        cw.p('/* --------------- Common notification parsing --------------- */')
        print_ntf_parse_prototype(parsed, cw)
    cw.nl()


def _render_user_source(parsed, cw, mode):
    """Render the user-mode source: policies, parsers and request code."""
    has_ntf = False

    cw.p('/* Policies */')
    for name, _ in parsed.attr_sets.items():
        struct = Struct(parsed, name)
        put_typol_fwd(cw, struct)
    cw.nl()

    for name, _ in parsed.attr_sets.items():
        struct = Struct(parsed, name)
        put_typol(cw, struct)

    cw.p('/* Common nested types */')
    for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
        ri = RenderInfo(cw, parsed, mode, "", "", "", attr_set)

        free_rsp_nested(ri, struct)
        if struct.request:
            put_req_nested(ri, struct)
        if struct.reply:
            parse_rsp_nested(ri, struct)

    for op_name, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")
        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, op_name, "do")
            print_rsp_free(ri)
            parse_rsp_msg(ri)
            print_req(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, op_name, "dump")
            if not ri.type_consistent:
                parse_rsp_msg(ri, deref=True)
            print_dump_type_free(ri)
            print_dump(ri)
            cw.nl()

        if 'notify' in op:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, op_name, 'notify')
            has_ntf = True
            if not ri.type_consistent:
                raise Exception('Only notifications with consistent types supported')
            print_ntf_type_free(ri)

        if 'event' in op:
            cw.p(f"/* {op.enum_name} - event */")
            has_ntf = True

            ri = RenderInfo(cw, parsed, mode, op, op_name, "do")
            parse_rsp_msg(ri)

            ri = RenderInfo(cw, parsed, mode, op, op_name, "event")
            print_ntf_type_free(ri)

    if has_ntf:
        cw.p('/* --------------- Common notification parsing --------------- */')
        print_ntf_type_parse(parsed, cw, mode)


def _render_output(args, parsed, out_file):
    """Write the complete generated file for `parsed` to `out_file`."""
    cw = CodeWriter(BaseNlLib(), out_file)

    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi':
        cw.p('/* SPDX-License-Identifier: (GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause */')
    else:
        if args.header:
            cw.p('/* SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause */')
        else:
            cw.p('// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    # The guard must be a valid C identifier, so map '-' to '_' as well.
    hdr_prot = f"_LINUX_{c_upper(parsed.name)}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # The source includes its own header: same base name, '.h'.
                # splitext() copes with any extension, not just '.c'.
                stem = os.path.splitext(os.path.basename(args.out_file))[0]
                cw.p(f'#include "{stem}.h"')
            cw.nl()
    headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <stdlib.h>")
            cw.p("#include <stdio.h>")
            cw.p("#include <string.h>")
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            _render_kernel_header(parsed, cw, args.mode)
        else:
            _render_kernel_source(parsed, cw, args.mode)

    if args.mode == "user":
        if args.header:
            _render_user_header(parsed, cw, args.mode)
        else:
            _render_user_source(parsed, cw, args.mode)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')


def main():
    """Command-line entry point of the generator.

    Parses the arguments, loads the spec given by --spec and renders the
    file selected by --mode and --header/--source to the -o file, or to
    stdout when -o is not given. A YAML error in the spec is reported on
    stderr and terminates the process with exit status 1.
    """
    import sys

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('-o', dest='out_file', type=str)
    args = parser.parse_args()

    # Validate the arguments and load the spec BEFORE touching the output
    # file, so a usage or spec error does not truncate an existing file.
    if args.header is None:
        parser.error("--header or --source is required")

    try:
        parsed = Family(args.spec)
    except yaml.YAMLError as exc:
        # stdout may be the generated-code stream; report errors on stderr.
        print(exc, file=sys.stderr)
        sys.exit(1)

    if args.out_file:
        # The context manager guarantees the file is flushed and closed.
        with open(args.out_file, 'w+') as out_file:
            _render_output(args, parsed, out_file)
    else:
        _render_output(args, parsed, sys.stdout)
2293
2294
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()