Merge tag 'for-linus-2023030901' of git://git.kernel.org/pub/scm/linux/kernel/git...
[platform/kernel/linux-rpi.git] / tools / net / ynl / ynl-gen-c.py
1 #!/usr/bin/env python3
2
3 import argparse
4 import collections
5 import os
6 import yaml
7
8 from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation
9
10
def c_upper(name):
    """Return *name* upper-cased with every '-' replaced by '_'."""
    upper = name.upper()
    return upper.replace('-', '_')
13
14
def c_lower(name):
    """Return *name* lower-cased with every '-' replaced by '_'."""
    lower = name.lower()
    return lower.replace('-', '_')
17
18
class BaseNlLib:
    """Snippets of C source used by the generated code to talk netlink."""

    def get_family_id(self):
        """Return the C expression that yields the family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return the C call to mnl_cb_run2() that parses the receive buffer.

        Args:
            cb: C expression for the data callback.
            data: C expression for the callback's data argument.
            is_dump: when true, 0 is passed for both the sequence number and
                the port id instead of ys->seq / ys->portid.
            indent: number of extra tabs used on continuation lines.
        """
        # Continuation lines: newline, two tabs, `indent` more tabs, one space.
        cont = '\n\t\t' + '\t' * indent + ' '
        tail = f"{cont}ynl_cb_array, NLMSG_MIN_TYPE)"
        if is_dump:
            return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data}," + tail
        return f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{cont}{cb}, {data}," + tail
30
31
class Type(SpecAttr):
    """Base class describing how one attribute is rendered into C.

    Sub-classes override the hooks (``presence_type``, ``arg_member``,
    ``_attr_typol``, ``attr_put``, ``_attr_get``, ``_setter_lines`` ...);
    the defaults here raise for anything a type does not implement.
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            # A nested set named like the family renders without a suffix.
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

        # Names listed in _C_KW (defined elsewhere in this file) get a
        # trailing underscore - presumably to avoid C keyword clashes.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute ``enum_name``: the set's name prefix plus the attr name, C upper-cased."""
        self.enum_name = c_upper(f"{self.attr_set.name_prefix}{self.name}")

    def is_multi_val(self):
        """Return True when the attribute holds multiple values (False by default)."""
        return False

    def is_scalar(self):
        """Return True when the attribute type is one of the integer types."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """Return how presence is tracked; 'bit' unless a sub-class overrides it."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the C declaration of the presence member, or None.

        None is returned when this attribute's presence type differs from
        *type_filter* or is neither 'bit' nor 'len'.
        """
        kind = self.presence_type()
        if kind != type_filter:
            return None

        # User space code uses the double-underscore integer types.
        pfx = '__' if space == 'user' else ''
        if kind == 'bit':
            return f"{pfx}u32 {self.c_name}:1;"
        if kind == 'len':
            return f"{pfx}u32 {self.c_name}_len;"
        return None

    def _complex_member_type(self, ri):
        """Return the C type for struct/array members, or None for simple types."""
        return None

    def free_needs_iter(self):
        """Return True when free() emits a loop over an index variable."""
        return False

    def free(self, ri, var, ref):
        """Emit a free() call for members that own allocated memory."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the list of C argument declarations used by the setter.

        Raises:
            Exception: when the type has no complex member type and does not
                override this method.
        """
        member = self._complex_member_type(ri)
        if member:
            return [member + ' *' + self.c_name]
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) for this attribute."""
        multi = self.is_multi_val()
        if multi:
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if multi else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        for one in self.arg_member(ri):
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the C initializer of the kernel policy entry."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry of the kernel policy array."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the type part of the user space policy entry."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry of the user space policy table."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Emit *line*, guarded by a presence check when presence is tracked."""
        kind = self.presence_type()
        if kind == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif kind == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a presence-guarded mnl_attr_put_<put_type>() call."""
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit the code adding the attribute to a message (type specific)."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return (lines, init_lines, local_vars) for parsing the attribute."""
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the ``if``/``else if`` block parsing this attribute.

        Args:
            ri: render info; only its code writer ``cw`` is used here.
            var: name of the C variable receiving the parsed value.
            first: True for the first branch of the if / else-if chain.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # Hooks may hand back a single line as a bare string.
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (mnl_attr_get_type(attr) == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()

    def _setter_lines(self, ri, member, presence):
        """Return the C statements storing the setter argument in *member*."""
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit the ``static inline`` request setter for this attribute.

        Args:
            ri: render info, provides the code writer.
            space: not used by this base implementation.
            direction: passed through to op_prefix() / type_name().
            deref: passed through to op_prefix() / type_name().
            ref: C names of the enclosing nested members, outermost first.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        tracks_bit = self.presence_type() == 'bit'
        # Walk from the outermost member inwards; each level's presence
        # member lives in the _present struct of its parent.
        for depth, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:depth] + [''])}_present.{name}"
            if tracks_bit:
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        args = [f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri)
        ri.cw.write_func('static inline void', func_name, body=code, args=args)
193
194
class TypeUnused(Type):
    """Attribute of type 'unused': no presence tracking, no kernel policy entry."""

    def presence_type(self):
        """No presence member is generated for this type."""
        return ''

    def _attr_typol(self):
        """User space policy marks the attribute as rejected."""
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        """Emit nothing into the kernel policy array."""
        return None
204
205
class TypePad(Type):
    """Attribute of type 'pad': no presence tracking, no kernel policy entry."""

    def presence_type(self):
        """No presence member is generated for this type."""
        return ''

    def _attr_typol(self):
        """User space policy marks the attribute as rejected."""
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        """Emit nothing into the kernel policy array."""
        return None
215
216
class TypeScalar(Type):
    """Integer attribute, optionally typed by an enum or used as a bitfield."""

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        # Trailing C comment naming the byte order, when the spec gives one.
        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"
        else:
            self.byte_order_comment = ''

        # Added by resolve():
        self.is_bitfield = None
        del self.is_bitfield
        self.type_name = None
        del self.type_name

    def resolve(self):
        """Work out ``is_bitfield`` and the C ``type_name`` of the value."""
        self.resolve_up(super())

        spec = self.attr
        if spec.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' in spec:
            self.is_bitfield = self.family.consts[spec['enum']]['type'] == 'flags'
        else:
            self.is_bitfield = False

        # Bitfields keep the plain integer type; only true enums get the
        # C enum type.
        if 'enum' in spec and not self.is_bitfield:
            self.type_name = f"enum {self.family.name}_{c_lower(spec['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        """Return the type suffix to use with the mnl_attr_{put,get}_* helpers."""
        kind = self.type
        # mnl does not have a helper for signed types
        return 'u' + kind[1:] if kind[0] == 's' else kind

    def _attr_policy(self, policy):
        """Return the kernel policy initializer, honouring masks/min/enum range."""
        consts = self.family.consts
        if self.is_bitfield:
            mask = consts[self.attr['enum']].get_mask()
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'flags-mask' in self.checks:
            flag_cnt = len(consts[self.checks['flags-mask']]['entries'])
            mask = (1 << flag_cnt) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
        if 'min' in self.checks:
            return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
        if 'enum' in self.attr:
            cnt = len(consts[self.attr['enum']]['entries'])
            return f"NLA_POLICY_MAX({policy}, {cnt - 1})"
        return super()._attr_policy(policy)

    def _attr_typol(self):
        """User space policy type; built from the type name minus its first character."""
        width = self.type[1:]
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        """Return the single C declaration used for struct member and setter arg."""
        decl = f'{self.type_name} {self.c_name}{self.byte_order_comment}'
        return [decl]

    def attr_put(self, ri, var):
        """Emit the presence-guarded mnl_attr_put_<type>() call."""
        self._attr_put_simple(ri, var, self._mnl_type())

    def _attr_get(self, ri, var):
        """Return the line reading the value with mnl_attr_get_<type>()."""
        line = f"{var}->{self.c_name} = mnl_attr_get_{self._mnl_type()}(attr);"
        return line, None, None

    def _setter_lines(self, ri, member, presence):
        """Return the plain assignment storing the setter argument."""
        return [f"{member} = {self.c_name};"]
284
285
class TypeFlag(Type):
    """Flag attribute: carries no payload, only its presence matters."""

    def arg_member(self, ri):
        """A flag has no value, hence no struct member or setter argument."""
        return list()

    def _attr_typol(self):
        """Return the user space policy type for flags."""
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        """Emit a presence-guarded, zero-length mnl_attr_put()."""
        call = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        """Nothing to parse beyond what attr_get() already emits."""
        return list(), None, None

    def _setter_lines(self, ri, member, presence):
        """The setter body consists only of the presence assignments."""
        return list()
301
302
class TypeString(Type):
    """String attribute, stored as a malloc()ed, NUL-terminated char buffer."""

    def arg_member(self, ri):
        """Return the setter argument declaration."""
        return [f"const char *{self.c_name}"]

    def presence_type(self):
        """Presence is tracked through a ``<name>_len`` member."""
        return 'len'

    def struct_member(self, ri):
        """Emit the (non-const) struct member declaration."""
        ri.cw.p(f"char *{self.c_name};")

    def _attr_typol(self):
        """Return the user space policy type for strings."""
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        """Return the kernel policy initializer, with ``.len`` if 'max-len' is set."""
        fields = ['.type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + str(self.checks['max-len']))
        return '{ ' + ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        """Emit the kernel policy entry; the 'unterminated-ok' check selects NLA_STRING."""
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def attr_put(self, ri, var):
        """Emit the presence-guarded mnl_attr_put_strz() call."""
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        """Return the lines copying the string into a freshly malloc()ed buffer."""
        dst = f"{var}->{self.c_name}"
        len_mem = var + '->_present.' + self.c_name + '_len'

        local_vars = ['unsigned int len;']
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        get_lines = [
            f"{len_mem} = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, mnl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Return the lines replacing *member* with a copy of the argument."""
        len_mem = presence + '_len'
        return [
            f"free({member});",
            f"{len_mem} = strlen({self.c_name});",
            f"{member} = malloc({len_mem} + 1);",
            f'memcpy({member}, {self.c_name}, {len_mem});',
            f'{member}[{len_mem}] = 0;',
        ]
350
351
class TypeBinary(Type):
    """Binary blob attribute, stored as a malloc()ed buffer plus its length."""

    def arg_member(self, ri):
        """Return the setter arguments: the data pointer and its length."""
        return [f"const void *{self.c_name}", 'size_t len']

    def presence_type(self):
        """Presence is tracked through a ``<name>_len`` member."""
        return 'len'

    def struct_member(self, ri):
        """Emit the (non-const) struct member declaration."""
        ri.cw.p(f"void *{self.c_name};")

    def _attr_typol(self):
        """Return the user space policy type for binary data."""
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Return the kernel policy initializer.

        Raises:
            Exception: for any combination of checks other than none at all
                or 'min-len' alone.
        """
        if len(self.checks) == 1 and 'min-len' in self.checks:
            field = '.len = ' + str(self.checks['min-len'])
        elif len(self.checks) == 0:
            field = '.type = NLA_BINARY'
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        return '{ ' + field + ', }'

    def attr_put(self, ri, var):
        """Emit a presence-guarded mnl_attr_put() of the stored length and data."""
        self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
                            f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")

    def _attr_get(self, ri, var):
        """Return the lines copying the payload into a freshly malloc()ed buffer."""
        dst = f"{var}->{self.c_name}"
        len_mem = var + '->_present.' + self.c_name + '_len'

        local_vars = ['unsigned int len;']
        init_lines = ['len = mnl_attr_get_payload_len(attr);']
        get_lines = [
            f"{len_mem} = len;",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, mnl_attr_get_payload(attr), len);",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Return the lines replacing *member* with a copy of the argument.

        The setter receives the length in its ``len`` argument (see
        arg_member()); it must be stored in the presence member before it is
        used to size the allocation and the copy.
        """
        len_mem = presence + '_len'
        return [
            f"free({member});",
            f"{len_mem} = len;",
            f"{member} = malloc({len_mem});",
            f'memcpy({member}, {self.c_name}, {len_mem});',
        ]
392
393
class TypeNest(Type):
    """Attribute containing a nested attribute set, rendered as an embedded struct."""

    def _complex_member_type(self, ri):
        """Return the C struct type of the nested set."""
        return 'struct ' + self.nested_render_name

    def free(self, ri, var, ref):
        """Emit the call to the nested struct's own free helper."""
        target = f"&{var}->{ref}{self.c_name}"
        ri.cw.p(f'{self.nested_render_name}_free({target});')

    def _attr_typol(self):
        """Return the user space policy entry pointing at the nested policy."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        """Return the kernel policy initializer referencing the nested policy."""
        return f"NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)"

    def attr_put(self, ri, var):
        """Emit the presence-guarded call to the nested struct's put helper."""
        call = (f"{self.nested_render_name}_put(nlh, "
                f"{self.enum_name}, &{var}->{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        """Return the lines pointing ``parg`` at the member and parsing into it."""
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [f"{self.nested_render_name}_parse(&parg, attr);"]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit setters for every member of the nested struct.

        This attribute's C name is appended to *ref* so that the members'
        setters address them through the enclosing struct.
        """
        path = (ref if ref else []) + [self.c_name]

        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested_struct.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
422
423
class TypeMultiAttr(Type):
    """Attribute that may appear multiple times; rendered as a counted array."""

    def is_multi_val(self):
        """Always multi-valued."""
        return True

    def presence_type(self):
        """Presence is tracked through the ``n_<name>`` element count."""
        return 'count'

    def _complex_member_type(self, ri):
        """Return the C element type.

        Raises:
            Exception: when the element type is neither a nest nor a scalar.
        """
        # A missing 'type' is treated the same as 'nest'.
        sub = self.attr.get('type', 'nest')
        if sub == 'nest':
            return f"struct {self.nested_render_name}"
        if sub in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub
        raise Exception(f"Sub-type {sub} not supported yet")

    def free_needs_iter(self):
        """Only nested elements need a per-element free loop."""
        return self.attr.get('type', 'nest') == 'nest'

    def free(self, ri, var, ref):
        """Emit the loop freeing each nested element; nothing for other types."""
        if self.attr.get('type', 'nest') != 'nest':
            return
        ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
        ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')

    def _attr_typol(self):
        """Return the user space policy type of the elements.

        Raises:
            Exception: when the element type is neither a nest nor a scalar.
        """
        sub = self.attr.get('type', 'nest')
        if sub == 'nest':
            return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
        if sub in scalars:
            return f".type = YNL_PT_U{sub[1:]}, "
        raise Exception(f"Sub-type {sub} not supported yet")

    def _attr_get(self, ri, var):
        """Return the line that counts one more occurrence of the attribute."""
        counter = f'{var}->n_{self.c_name}++;'
        return counter, None, None
458
459
class TypeArrayNest(Type):
    """Attribute of type 'array-nest'; rendered as a counted array."""

    def is_multi_val(self):
        """Always multi-valued."""
        return True

    def presence_type(self):
        """Presence is tracked through the ``n_<name>`` element count."""
        return 'count'

    def _complex_member_type(self, ri):
        """Return the C element type.

        Raises:
            Exception: when 'sub-type' is neither a nest nor a scalar.
        """
        # A missing 'sub-type' is treated the same as 'nest'.
        sub = self.attr.get('sub-type', 'nest')
        if sub == 'nest':
            return f"struct {self.nested_render_name}"
        if sub in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + sub
        raise Exception(f"Sub-type {sub} not supported yet")

    def _attr_typol(self):
        """Return the user space policy entry pointing at the nested policy."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return the lines remembering the attribute and counting its entries."""
        local_vars = ['const struct nlattr *attr2;']
        get_lines = [
            f'attr_{self.c_name} = attr;',
            'mnl_attr_for_each_nested(attr2, attr)',
            f'\t{var}->n_{self.c_name}++;',
        ]
        return get_lines, None, local_vars
485
486
class TypeNestTypeValue(Type):
    """Attribute of type 'nest-type-value': a nest reached through 'type-value' levels."""

    def _complex_member_type(self, ri):
        """Return the C struct type of the nested set."""
        return 'struct ' + self.nested_render_name

    def _attr_typol(self):
        """Return the user space policy entry pointing at the nested policy."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Return the lines unwrapping each 'type-value' level, then parsing the nest.

        For every level listed under 'type-value' the generated code steps
        into the payload of the previous attribute and records its type; the
        recorded values are passed to the nested parse function as extra
        arguments.
        """
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        local_vars = []
        get_lines = []
        innermost = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append('const struct nlattr *attr_' + ', *attr_'.join(levels) + ';')
            local_vars.append('__u32 ' + ', '.join(levels) + ';')
            for level in levels:
                get_lines.append(f'attr_{level} = mnl_attr_get_payload({innermost});')
                get_lines.append(f'{level} = mnl_attr_get_type(attr_{level});')
                innermost = 'attr_' + level

            extra_args = ', ' + ', '.join(levels)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {innermost}{extra_args});")
        return get_lines, init_lines, local_vars
515
516
class Struct:
    """A C struct rendered from (a subset of) an attribute set.

    Attributes:
        nested: True when no explicit member list was given (the struct then
            covers the whole attribute set).
        attr_list: list of ``(name, attr)`` pairs, in member order.
        attrs: the same members as a name -> attr dict.
        attr_max_val: the member with the highest ``value`` (None if empty).
        request / reply: direction flags, set by the code building the structs.
    """

    def __init__(self, family, space_name, type_list=None, inherited=None):
        """
        Args:
            family: the family; ``family.attr_sets`` and ``family.name`` are used.
            space_name: name of the attribute set backing this struct.
            type_list: names of the attributes to include; None means all
                attributes of the set, an empty list means no members.
            inherited: names of members inherited from the parent, if any.
        """
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = inherited if inherited is not None else []
        self.inherited = []

        self.nested = type_list is None
        if family.name == c_lower(space_name):
            self.render_name = f"{family.name}"
        else:
            self.render_name = f"{family.name}_{c_lower(space_name)}"
        self.struct_name = 'struct ' + self.render_name
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        # Compare against None rather than truthiness: an explicitly empty
        # member list must yield an empty struct, not the full attribute set.
        names = type_list if type_list is not None else self.attr_set
        self.attr_list = [(t, self.attr_set[t]) for t in names]
        self.attrs = dict()

        max_val = 0
        self.attr_max_val = None
        for name, attr in self.attr_list:
            if attr.value >= max_val:
                max_val = attr.value
                self.attr_max_val = attr
            self.attrs[name] = attr

    def __iter__(self):
        """Iterate over the member names."""
        yield from self.attrs

    def __getitem__(self, key):
        """Return the member attribute called *key*."""
        return self.attrs[key]

    def member_list(self):
        """Return the ``(name, attr)`` pairs in member order."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Record the inherited member names, C lower-cased and sorted.

        Raises:
            Exception: when *new_inherited* differs from the members given at
                construction time.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(x) for x in sorted(self._inherited)]
567
568
class EnumEntry:
    """One entry of an enum / flags definition.

    Attributes:
        name: entry name from the spec.
        doc: documentation string ('' when absent).
        c_name: C identifier: the set's value prefix plus the name, upper-cased.
        value: raw value (the id in the enum; the bit number for flags).
        value_change: True when the rendered enum needs an explicit value for
            this entry (non-sequential value, or the set is of type 'flags').
    """

    def __init__(self, enum_set, yaml, prev, value_start):
        """
        Args:
            enum_set: owning set; ``value_pfx`` and ``['type']`` are used.
            yaml: either the entry name as a string, or a dict with 'name'
                and optionally 'doc' and 'value'.
            prev: the preceding EnumEntry, or None for the first entry.
            value_start: value of the first entry when it has none of its own.
        """
        if isinstance(yaml, str):
            self.name = yaml
            yaml = {}
            self.doc = ''
        else:
            self.name = yaml['name']
            self.doc = yaml.get('doc', '')

        self.yaml = yaml
        self.enum_set = enum_set
        self.c_name = c_upper(enum_set.value_pfx + self.name)

        if 'value' in yaml:
            self.value = yaml['value']
        elif prev is not None:
            self.value = prev.value + 1
        else:
            self.value = value_start

        # Decide this separately from the value itself so that a first entry
        # with an explicit 'value' is covered too (C enums start at 0).
        if prev is not None:
            self.value_change = (self.value != prev.value + 1)
        else:
            self.value_change = (self.value != 0)

        self.value_change = self.value_change or self.enum_set['type'] == 'flags'

    def __getitem__(self, key):
        """Return the raw spec field *key* of this entry."""
        return self.yaml[key]

    def __contains__(self, key):
        """Return True when the raw spec of this entry has field *key*."""
        return key in self.yaml

    def has_doc(self):
        """Return True when the entry carries a non-empty doc string."""
        return bool(self.doc)

    # raw value, i.e. the id in the enum, unlike user value which is a mask for flags
    def raw_value(self):
        return self.value

    # user value, same as raw value for enums, for flags it's the mask
    def user_value(self):
        if self.enum_set['type'] == 'flags':
            return 1 << self.value
        return self.value
615
616
class EnumSet:
    """An enum or flags definition from the spec, with its entries resolved.

    Attributes:
        render_name: C name of the enum (family name + definition name).
        enum_name: ``'enum ' + render_name``.
        value_pfx: prefix of the entries' C names.
        type: the definition's 'type' field.
        entries: entry name -> EnumEntry.
        entry_list: the EnumEntry objects in spec order.
    """

    def __init__(self, family, yaml):
        self.yaml = yaml
        self.family = family

        self.render_name = c_lower(family.name + '-' + yaml['name'])
        self.enum_name = 'enum ' + self.render_name

        self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")

        self.type = yaml['type']

        # Each entry needs its predecessor to continue the value sequence.
        prev_entry = None
        value_start = self.yaml.get('value-start', 0)
        self.entries = {}
        self.entry_list = []
        for entry in self.yaml['entries']:
            e = EnumEntry(self, entry, prev_entry, value_start)
            self.entries[e.name] = e
            self.entry_list.append(e)
            prev_entry = e

    def __getitem__(self, key):
        """Return the raw spec field *key* of the definition."""
        return self.yaml[key]

    def __contains__(self, key):
        """Return True when the raw spec of the definition has field *key*."""
        return key in self.yaml

    def has_doc(self):
        """Return True when the definition or any of its entries is documented."""
        if 'doc' in self.yaml:
            return True
        return any(entry.has_doc() for entry in self.entry_list)

    def get_mask(self):
        """Return the bitmask with one bit set per entry.

        The bit position is each entry's raw value, so entries with explicit
        (non-sequential) values contribute the right bit.
        """
        mask = 0
        for entry in self.entry_list:
            mask |= 1 << entry.raw_value()
        return mask
660
661
class AttrSet(SpecAttrSet):
    """Attribute set with the C naming needed by the code generator."""

    def __init__(self, family, yaml):
        super().__init__(family, yaml)

        if self.subset_of is not None:
            # Subsets share the naming of the set they are carved out of.
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                pfx = yaml['name-prefix']
            elif self.name == family.name:
                pfx = family.name + '-a-'
            else:
                pfx = f"{family.name}-a-{self.name}-"
            self.name_prefix = c_upper(pfx)
            self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))

        # Added by resolve:
        self.c_name = None
        del self.c_name

    def resolve(self):
        """Compute ``c_name``; it is '' when the set is named like the family."""
        name = c_lower(self.name)
        if name in _C_KW:
            name += '_'
        if name == self.family.c_name:
            name = ''
        self.c_name = name

    def new_attr(self, elem, value):
        """Create the Type sub-class instance matching the attribute spec *elem*.

        'multi-attr' takes precedence over the attribute's own type.

        Raises:
            Exception: when no class handles the attribute's type.
        """
        if elem.get('multi-attr'):
            return TypeMultiAttr(self.family, self, elem, value)

        kind = elem['type']
        if kind in scalars:
            return TypeScalar(self.family, self, elem, value)

        by_kind = {
            'unused': TypeUnused,
            'pad': TypePad,
            'flag': TypeFlag,
            'string': TypeString,
            'binary': TypeBinary,
            'nest': TypeNest,
            'array-nest': TypeArrayNest,
            'nest-type-value': TypeNestTypeValue,
        }
        if kind not in by_kind:
            raise Exception(f"No typed class for type {kind}")
        return by_kind[kind](self.family, self, elem, value)
713
714
class Operation(SpecOperation):
    """Operation with the C naming needed by the code generator."""

    def __init__(self, family, yaml, req_value, rsp_value):
        super().__init__(family, yaml, req_value, rsp_value)

        if req_value != rsp_value:
            raise Exception("Directional messages not supported by codegen")

        self.render_name = family.name + '_' + c_lower(self.name)

        # True only when both 'do' and 'dump' define a request.
        self.dual_policy = all(mode in yaml and 'request' in yaml[mode]
                               for mode in ('do', 'dump'))

        # Added by resolve:
        self.enum_name = None
        del self.enum_name

    def resolve(self):
        """Compute ``enum_name`` from the family's (async) operation prefix."""
        self.resolve_up(super())

        fam = self.family
        prefix = fam.async_op_prefix if self.is_async else fam.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def add_notification(self, op):
        """Register *op* as a notification that carries this operation's reply.

        The first call creates the 'notify' entry, reusing the reply
        definition of 'do'.
        """
        spec = self.yaml
        if 'notify' not in spec:
            spec['notify'] = dict()
            spec['notify']['reply'] = spec['do']['reply']
            spec['notify']['cmds'] = []
        spec['notify']['cmds'].append(op)
745
746
747 class Family(SpecFamily):
748     def __init__(self, file_name):
749         # Added by resolve:
750         self.c_name = None
751         delattr(self, "c_name")
752         self.op_prefix = None
753         delattr(self, "op_prefix")
754         self.async_op_prefix = None
755         delattr(self, "async_op_prefix")
756         self.mcgrps = None
757         delattr(self, "mcgrps")
758         self.consts = None
759         delattr(self, "consts")
760         self.hooks = None
761         delattr(self, "hooks")
762
763         super().__init__(file_name)
764
765         self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
766         self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
767
768         if 'definitions' not in self.yaml:
769             self.yaml['definitions'] = []
770
771         if 'uapi-header' in self.yaml:
772             self.uapi_header = self.yaml['uapi-header']
773         else:
774             self.uapi_header = f"linux/{self.name}.h"
775
776     def resolve(self):
777         self.resolve_up(super())
778
779         if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
780             raise Exception("Codegen only supported for genetlink")
781
782         self.c_name = c_lower(self.name)
783         if 'name-prefix' in self.yaml['operations']:
784             self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
785         else:
786             self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
787         if 'async-prefix' in self.yaml['operations']:
788             self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
789         else:
790             self.async_op_prefix = self.op_prefix
791
792         self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
793
794         self.consts = dict()
795
796         self.hooks = dict()
797         for when in ['pre', 'post']:
798             self.hooks[when] = dict()
799             for op_mode in ['do', 'dump']:
800                 self.hooks[when][op_mode] = dict()
801                 self.hooks[when][op_mode]['set'] = set()
802                 self.hooks[when][op_mode]['list'] = []
803
804         # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
805         self.root_sets = dict()
806         # dict space-name -> set('request', 'reply')
807         self.pure_nested_structs = dict()
808         self.all_notify = dict()
809
810         self._mock_up_events()
811
812         self._dictify()
813         self._load_root_sets()
814         self._load_nested_sets()
815         self._load_all_notify()
816         self._load_hooks()
817
818         self.kernel_policy = self.yaml.get('kernel-policy', 'split')
819         if self.kernel_policy == 'global':
820             self._load_global_policy()
821
822     def new_attr_set(self, elem):
823         return AttrSet(self, elem)
824
825     def new_operation(self, elem, req_value, rsp_value):
826         return Operation(self, elem, req_value, rsp_value)
827
828     # Fake a 'do' equivalent of all events, so that we can render their response parsing
829     def _mock_up_events(self):
830         for op in self.yaml['operations']['list']:
831             if 'event' in op:
832                 op['do'] = {
833                     'reply': {
834                         'attributes': op['event']['attributes']
835                     }
836                 }
837
838     def _dictify(self):
839         for elem in self.yaml['definitions']:
840             if elem['type'] == 'enum' or elem['type'] == 'flags':
841                 self.consts[elem['name']] = EnumSet(self, elem)
842             else:
843                 self.consts[elem['name']] = elem
844
845         ntf = []
846         for msg in self.msgs.values():
847             if 'notify' in msg:
848                 ntf.append(msg)
849         for n in ntf:
850             self.ops[n['notify']].add_notification(n)
851
852     def _load_root_sets(self):
853         for op_name, op in self.ops.items():
854             if 'attribute-set' not in op:
855                 continue
856
857             req_attrs = set()
858             rsp_attrs = set()
859             for op_mode in ['do', 'dump']:
860                 if op_mode in op and 'request' in op[op_mode]:
861                     req_attrs.update(set(op[op_mode]['request']['attributes']))
862                 if op_mode in op and 'reply' in op[op_mode]:
863                     rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
864
865             if op['attribute-set'] not in self.root_sets:
866                 self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
867             else:
868                 self.root_sets[op['attribute-set']]['request'].update(req_attrs)
869                 self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
870
871     def _load_nested_sets(self):
872         for root_set, rs_members in self.root_sets.items():
873             for attr, spec in self.attr_sets[root_set].items():
874                 if 'nested-attributes' in spec:
875                     inherit = set()
876                     nested = spec['nested-attributes']
877                     if nested not in self.root_sets:
878                         self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
879                     if attr in rs_members['request']:
880                         self.pure_nested_structs[nested].request = True
881                     if attr in rs_members['reply']:
882                         self.pure_nested_structs[nested].reply = True
883
884                     if 'type-value' in spec:
885                         if nested in self.root_sets:
886                             raise Exception("Inheriting members to a space used as root not supported")
887                         inherit.update(set(spec['type-value']))
888                     elif spec['type'] == 'array-nest':
889                         inherit.add('idx')
890                     self.pure_nested_structs[nested].set_inherited(inherit)
891
892     def _load_all_notify(self):
893         for op_name, op in self.ops.items():
894             if not op:
895                 continue
896
897             if 'notify' in op:
898                 self.all_notify[op_name] = op['notify']['cmds']
899
900     def _load_global_policy(self):
901         global_set = set()
902         attr_set_name = None
903         for op_name, op in self.ops.items():
904             if not op:
905                 continue
906             if 'attribute-set' not in op:
907                 continue
908
909             if attr_set_name is None:
910                 attr_set_name = op['attribute-set']
911             if attr_set_name != op['attribute-set']:
912                 raise Exception('For a global policy all ops must use the same set')
913
914             for op_mode in ['do', 'dump']:
915                 if op_mode in op:
916                     global_set.update(op[op_mode].get('request', []))
917
918         self.global_policy = []
919         self.global_policy_set = attr_set_name
920         for attr in self.attr_sets[attr_set_name]:
921             if attr in global_set:
922                 self.global_policy.append(attr)
923
924     def _load_hooks(self):
925         for op in self.ops.values():
926             for op_mode in ['do', 'dump']:
927                 if op_mode not in op:
928                     continue
929                 for when in ['pre', 'post']:
930                     if when not in op[op_mode]:
931                         continue
932                     name = op[op_mode][when]
933                     if name in self.hooks[when][op_mode]['set']:
934                         continue
935                     self.hooks[when][op_mode]['set'].add(name)
936                     self.hooks[when][op_mode]['list'].append(name)
937
938
class RenderInfo:
    """Rendering context for one operation mode, or for a bare attribute set.

    Bundles the code writer, the family, the operation (may be empty when
    rendering a plain attribute set) and the request / reply Structs that
    the print_* helpers work on.
    """

    def __init__(self, cw, family, ku_space, op, op_name, op_mode, attr_set=None):
        self.family = family
        self.cw = cw
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op = op
        self.op_name = op_name
        self.op_mode = op_mode

        # 'do' and 'dump' response parsing is identical
        same_reply = op_mode != 'do' and 'dump' in op and 'do' in op and \
            'reply' in op['do'] and op["do"]["reply"] == op["dump"]["reply"]
        self.type_consistent = True if same_reply else op_mode == 'event'

        self.attr_set = attr_set if attr_set else op['attribute-set']

        # Types are named after the op when there is one, else after the set
        self.type_name = c_lower(op_name) if op else c_lower(attr_set)

        self.struct = {}
        for op_dir in ('request', 'reply'):
            if not op or op_dir not in op[op_mode]:
                continue
            self.struct[op_dir] = Struct(family, self.attr_set,
                                         type_list=op[op_mode][op_dir]['attributes'])
        if op_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
973
974
class CodeWriter:
    """Indentation-aware line writer for the generated C code.

    Lines are written to `out_file` through p().  The nesting depth is
    tracked by block_start() / block_end(), and a blank line requested with
    nl() is deferred until the next line is actually written.
    """

    def __init__(self, nlib, out_file):
        self.nlib = nlib

        # A blank line is pending, to be written before the next line
        self._nl = False
        # Previous line was a brace-less if/while/for: indent the next line
        self._silent_block = False
        self._ind = 0
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        """Return True if `line` starts with 'if', 'while' or 'for'."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write `line` at the current indentation, followed by a newline.

        Lines ending in ':' (labels, case statements) are written one level
        shallower.  The line after a brace-less conditional ending in ')' is
        written one level deeper.  `add_ind` adds extra levels.  An empty
        `line` produces a bare newline.
        """
        if self._nl:
            self._out.write('\n')
            self._nl = False
        ind = self._ind
        # endswith() rather than line[-1] so an empty line is not an error
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        # No indentation for an empty line, it would be trailing whitespace
        prefix = '\t' * ind if line else ''
        self._out.write(prefix + line + '\n')

    def nl(self):
        """Request a blank line before the next written line."""
        self._nl = True

    def block_start(self, line=''):
        """Write `line` followed by '{' and indent what follows by one level."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Drop one indentation level and write '}' followed by `line`.

        A space separates the brace from `line` unless `line` starts with
        ';' or ','.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write `doc` as ' * ' comment lines, wrapped to stay under 80 columns.

        With `indent` set, continuation lines get two extra spaces.
        """
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a C function prototype, wrapping the arguments if too long.

        `qual_ret` is the return type with qualifiers, `args` the list of
        argument declarations ('void' when empty), `doc` an optional comment
        written above the prototype and `suffix` is appended after the
        closing parenthesis (e.g. ';').
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        # Too long: a return type longer than 3 chars goes on its own line,
        # then arguments are wrapped and aligned under the first one.
        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        for arg in args[1:]:
            next_len = len(v) + len(arg)
            if v[0] == '\t':
                # Continuation line: account for the columns the tabs occupy
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += arg
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        `local_vars` is a single declaration string or a list of them; the
        caller's list is left untouched.  Nothing is written when it is empty.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() instead of list.sort() so the caller's list is not reordered
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a prototype, the local variables and a braced body of lines."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write '#define NAME value' lines with the values tab-aligned.

        `defines` is a sequence of (name, value) pairs; int values are
        written as is, str values are wrapped in double quotes.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write '.member = value,' initializer lines with '=' tab-aligned.

        `members` is a sequence of (name, value) string pairs; an empty
        sequence writes nothing.
        """
        if not members:
            return

        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1116
1117
# Attribute types treated as plain scalars (checked when rendering the
# parsing of multi-attr members).
scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

# Suffix added to generated names for each message direction; the empty
# direction adds nothing.
direction_to_suffix = {
    'reply': '_rsp',
    'request': '_req',
    '': ''
}

# Suffix of the wrapper type name for each operation mode (see op_prefix()).
op_mode_to_wrapper = {
    'do': '',
    'dump': '_list',
    'notify': '_ntf',
    'event': '',
}

# Attribute names which collide with these get a trailing '_' in their C name.
_C_KW = {
    'do'
}
1136
1137
def rdir(direction):
    """Return the opposite of 'request' / 'reply'; anything else is returned as is."""
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1144
1145
def op_prefix(ri, direction, deref=False):
    """Build '<family>_<type>' plus the suffix matching mode and direction.

    With no op mode or in 'do' mode the suffix follows the direction.  In
    other modes requests get '_req_dump'; replies get the wrapper suffix of
    the mode (or the direction suffix when `deref`) if the types are
    consistent, and '_rsp_list' / '_rsp_dump' (`deref`) otherwise.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = f"{direction_to_suffix[direction]}"
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = f"{direction_to_suffix[direction]}"
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    suffix = f"_{ri.type_name}" + tail
    return f"{ri.family['name']}{suffix}"
1165
1166
def type_name(ri, direction, deref=False):
    """Return the C struct type ('struct <prefix>') for the given direction."""
    prefix = op_prefix(ri, direction, deref=deref)
    return f"struct {prefix}"
1169
1170
def print_prototype(ri, direction, terminate=True, doc=None):
    """Write the prototype of the function performing the operation.

    The function takes the socket plus, when the op mode has a request, a
    pointer to the request struct.  It returns a pointer to the type of the
    opposite direction when the op mode has a reply, 'int' otherwise.
    """
    fname = ri.op.render_name
    if ri.op_mode == 'dump':
        fname += '_dump'

    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        args.append(f"{type_name(ri, direction)} *" + f"{direction_to_suffix[direction][1:]}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1187
1188
def print_req_prototype(ri):
    """Write the request function prototype, documented with the op's 'doc'."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1191
1192
def print_dump_prototype(ri):
    """Write the (undocumented) prototype of the dump function."""
    direction = "request"
    print_prototype(ri, direction)
1195
1196
def put_typol_fwd(cw, struct):
    """Write the extern declaration of the struct's policy nest."""
    nest = f'{struct.render_name}_nest'
    cw.p('extern struct ynl_policy_nest ' + nest + ';')
1199
1200
def put_typol(cw, struct):
    """Write the policy attr table of `struct` and the nest pointing at it."""
    type_max = struct.attr_set.max_name
    name = struct.render_name

    # Per-attribute policy table, one entry rendered by each member
    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{type_max} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    # Nest descriptor tying the table to its max attribute
    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    for init in (f'.max_attr = {type_max},', f'.table = {name}_policy,'):
        cw.p(init)
    cw.block_end(line=';')
    cw.nl()
1216
1217
def put_req_nested(ri, struct):
    """Write the C '<struct>_put' function which emits `struct` as a nest.

    The generated function opens a nest of the given attribute type, lets
    every member put itself from 'obj', closes the nest and returns 0.
    """
    prototype_args = ['struct nlmsghdr *nlh',
                      'unsigned int attr_type',
                      f'{struct.ptr_name}obj']
    put_name = f'{struct.render_name}_put'

    ri.cw.write_func_prot('int', put_name, prototype_args)
    ri.cw.block_start()
    ri.cw.write_func_lvar('struct nlattr *nest;')

    ri.cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    ri.cw.p("mnl_attr_nest_end(nlh, nest);")

    ri.cw.nl()
    ri.cw.p('return 0;')
    ri.cw.block_end()
    ri.cw.nl()
1238
1239
def _multi_parse(ri, struct, init_lines, local_vars):
    """Write the body of a C parse function filling 'dst' from attributes.

    The caller has already written the prototype.  This writes the opening
    brace, the local variables, `init_lines`, the copies of inherited
    members, a loop over the attributes in which every member of `struct`
    renders its own getter, then one more pass per array-nest and per
    multi-attr member which allocates and fills the array, and finally the
    return statement and closing brace.

    Note that `local_vars` and `init_lines` are appended to in place.
    """
    # Nested structs walk the attrs of the nest, messages walk the payload
    if struct.nested:
        iter_line = "mnl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            # Each array-nest gets a local pointing at its nest attribute
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        # Parsing nested attributes goes through a ynl_parse_arg
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    # Inherited members arrive as function arguments of the same name
    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)

    first = True
    for _, arg in struct.member_list():
        arg.attr_get(ri, 'dst', first=first)
        first = False

    ri.cw.block_end()
    ri.cw.nl()

    # Second pass for array-nests: allocate n_<member> entries and parse
    # each nested attribute into its slot, passing the attr type along.
    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (dst->n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(dst->n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        ri.cw.p('return MNL_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Second pass for multi-attrs: allocate n_<member> entries, then walk
    # the attributes again picking out the ones of the member's type.
    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (dst->n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(dst->n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return MNL_CB_ERROR;')
        elif aspec['type'] in scalars:
            t = aspec['type']
            # Signed types are read with the unsigned getter of the same width
            if t[0] == 's':
                t = 'u' + t[1:]
            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    # Nested parsers return 0, message callbacks return MNL_CB_OK
    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return MNL_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1331
1332
def parse_rsp_nested(ri, struct):
    """Write the C '<struct>_parse' function for a nested struct.

    Besides the parse argument and the nest attribute, the function takes
    one '__u32' argument per inherited member.
    """
    func_args = ['struct ynl_parse_arg *yarg',
                 'const struct nlattr *nested']
    func_args.extend('__u32 ' + arg for arg in struct.inherited)

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)

    local_vars = ['const struct nlattr *attr;',
                  f'{struct.ptr_name}dst = yarg->data;']
    _multi_parse(ri, struct, [], local_vars)
1346
1347
def parse_rsp_msg(ri, deref=False):
    """Write the C parse callback for the reply message of the operation.

    Nothing is written when the op mode has no reply, unless it is an event.
    """
    has_reply = 'reply' in ri.op[ri.op_mode]
    if not has_reply and ri.op_mode != 'event':
        return

    local_vars = [f'{type_name(ri, "reply", deref=deref)} *dst;',
                  'struct ynl_parse_arg *yarg = data;',
                  'const struct nlattr *attr;']
    init_lines = ['dst = yarg->data;']

    parse_name = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', parse_name,
                          ['const struct nlmsghdr *nlh',
                           'void *data'])

    _multi_parse(ri, ri.struct["reply"], init_lines, local_vars)
1363
1364
def print_req(ri):
    """Write the C function which sends the request and handles the answer.

    The generated function builds the message from the request struct, sends
    it and receives the answer.  When the op mode has a reply, a response
    struct is allocated, parsed into and returned (NULL on error); otherwise
    the function returns 0 or -1.
    """
    cw = ri.cw
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]

    local_vars = ['struct nlmsghdr *nlh;',
                  'int len, err;']
    if has_reply:
        ret_ok, ret_err = 'rsp', 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;',
                       'struct ynl_parse_arg yarg = { .ys = ys, };']
    else:
        ret_ok, ret_err = '0', '-1'

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    # Send, then receive; either failing returns the error value directly
    io_steps = (
        ('err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);', 'if (err < 0)'),
        ('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);', 'if (len < 0)'),
    )
    for call, check in io_steps:
        cw.p(call)
        cw.p(check)
        cw.p(f"return {ret_err};")
        cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yarg.data = rsp;')
        cw.nl()
        cw.p(f"err = {ri.nl.parse_cb_run(op_prefix(ri, 'reply') + '_parse', '&yarg', False)};")
        cw.p('if (err < 0)')
        cw.p('goto err_free;')
        cw.nl()

    cw.p('err = ynl_recv_ack(ys, err);')
    cw.p('if (err)')
    cw.p('goto err_free;')
    cw.nl()
    cw.p(f"return {ret_ok};")
    cw.nl()
    cw.p('err_free:')

    if has_reply:
        cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
    cw.p(f"return {ret_err};")
    cw.block_end()
1422
1423
def print_dump(ri):
    """Write the C function which performs the dump and returns the list.

    The generated function sets up the dump state, sends the (optional)
    request and keeps receiving and parsing until done.  It returns the
    first list element, or NULL after freeing the list on error.
    """
    cw = ri.cw
    direction = "request"
    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    for var in ('struct ynl_dump_state yds = {};',
                'struct nlmsghdr *nlh;',
                'int len, err;'):
        cw.p(f'{var}')
    cw.nl()

    cw.p('yds.ys = ys;')
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    cw.p('err = mnl_socket_sendto(ys->sock, nlh, nlh->nlmsg_len);')
    cw.p('if (err < 0)')
    cw.p('return NULL;')
    cw.nl()

    # Receive loop, repeated while the callback run reports more to come
    cw.block_start(line='do')
    cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);')
    cw.p('if (len < 0)')
    cw.p('goto free_list;')
    cw.nl()
    cw.p(f"err = {ri.nl.parse_cb_run('ynl_dump_trampoline', '&yds', False, indent=2)};")
    cw.p('if (err < 0)')
    cw.p('goto free_list;')
    cw.block_end(line='while (err > 0);')
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1472
1473
def call_free(ri, direction, var):
    """Return the C statement calling the free function of the type on `var`."""
    prefix = op_prefix(ri, direction)
    return f"{prefix}_free({var});"
1476
1477
def free_arg_name(direction):
    """Return the argument name of a free function: 'req', 'rsp', or 'obj' without direction."""
    if not direction:
        return 'obj'
    # Direction suffix without its leading underscore
    return direction_to_suffix[direction][1:]
1482
1483
def print_free_prototype(ri, direction, suffix=';'):
    """Write the prototype of the function freeing the type of `direction`."""
    struct_name = op_prefix(ri, direction)
    arg_name = free_arg_name(direction)
    free_args = [f"struct {struct_name} *{arg_name}"]
    ri.cw.write_func_prot('void', f"{struct_name}_free", free_args, suffix=suffix)
1488
1489
def _print_type(ri, direction, struct):
    """Write the C struct definition for `struct`.

    Presence members, if any, are grouped in an anonymous '_present'
    struct at the top, followed by the inherited '__u32' members and then
    the members rendered by each attribute.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if ri.op_mode == 'dump':
        suffix += '_dump'

    ri.cw.block_start(line=f"struct {ri.family['name']}{suffix}")

    meta_started = False
    for _, attr in struct.member_list():
        # Lazily ask for the 'len' then the 'bit' presence member
        presence = (attr.presence_member(ri.ku_space, type_filter)
                    for type_filter in ['len', 'bit'])
        for line in filter(None, presence):
            if not meta_started:
                ri.cw.block_start(line="struct")
                meta_started = True
            ri.cw.p(line)
    if meta_started:
        ri.cw.block_end(line='_present;')
        ri.cw.nl()

    for inherited in struct.inherited:
        ri.cw.p(f"__u32 {inherited};")

    for _, attr in struct.member_list():
        attr.struct_member(ri)

    ri.cw.block_end(line=';')
    ri.cw.nl()
1519
1520
def print_type(ri, direction):
    """Write the struct definition of the op's type for `direction`."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1523
1524
def print_type_full(ri, struct):
    """Write the struct definition of `struct` with no direction suffix."""
    direction = ""
    _print_type(ri, direction, struct)
1527
1528
def print_type_helpers(ri, direction, deref=False):
    """Write the free prototype and, for user-space requests, the setters."""
    print_free_prototype(ri, direction)

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    members = ri.struct[direction].member_list() if wants_setters else []
    for _, attr in members:
        attr.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1536
1537
def print_req_type_helpers(ri):
    """Write the helpers of the request type."""
    direction = "request"
    print_type_helpers(ri, direction)
1540
1541
def print_rsp_type_helpers(ri):
    """Write the helpers of the reply type, if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1546
1547
def print_parse_prototype(ri, direction, terminate=True):
    """Write the prototype of the '<op>_rsp_parse' / '<op>_req_parse' function."""
    suffix = "_req" if direction != "reply" else "_rsp"
    struct_name = f"{ri.op.render_name}{suffix}"

    ri.cw.write_func_prot('void', f"{struct_name}_parse",
                          ['const struct nlattr **tb',
                           f"struct {struct_name} *req"],
                          suffix=';' if terminate else '')
1556
1557
def print_req_type(ri):
    """Write the struct definition of the request type."""
    direction = "request"
    print_type(ri, direction)
1560
1561
def print_rsp_type(ri):
    """Write the reply struct definition.

    Done for 'do' / 'dump' modes which have a reply and for events;
    nothing is written in any other case.
    """
    if ri.op_mode == 'do' or ri.op_mode == 'dump':
        if 'reply' not in ri.op[ri.op_mode]:
            return
    elif ri.op_mode != 'event':
        return
    print_type(ri, 'reply')
1570
1571
def print_wrapped_type(ri):
    """Write the wrapper struct around the reply object, plus its free prototype.

    Dump wrappers carry a 'next' pointer, notification and event wrappers
    carry family, cmd and a free callback; all embed the object as 'obj'.
    """
    wrapper = type_name(ri, 'reply')
    ri.cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        ri.cw.p(f"{wrapper} *next;")
    elif ri.op_mode in ('notify', 'event'):
        ri.cw.p('__u16 family;')
        ri.cw.p('__u8 cmd;')
        ri.cw.p(f"void (*free)({wrapper} *ntf);")
    ri.cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));")
    ri.cw.block_end(line=';')
    ri.cw.nl()
    print_free_prototype(ri, 'reply')
    ri.cw.nl()
1585
1586
def _free_type_members_iter(ri, struct):
    """Declare the 'i' iterator if freeing any member of `struct` needs one."""
    if any(attr.free_needs_iter() for _, attr in struct.member_list()):
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
1593
1594
def _free_type_members(ri, var, struct, ref=''):
    """Let every member of `struct` render the code freeing itself from `var`."""
    for member in struct.member_list():
        _, attr = member
        attr.free(ri, var, ref)
1598
1599
def _free_type(ri, direction, struct):
    """Write the C function freeing `struct`.

    Members are always freed; the object itself is only freed when a
    direction is given.
    """
    arg_name = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg_name, struct)
    if direction:
        ri.cw.p(f'free({arg_name});')
    ri.cw.block_end()
    ri.cw.nl()
1611
1612
def free_rsp_nested(ri, struct):
    """Write the free function of a nested struct (no direction)."""
    direction = ""
    _free_type(ri, direction, struct)
1615
1616
def print_rsp_free(ri):
    """Write the free function of the reply type, if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
1621
1622
def print_dump_type_free(ri):
    """Write the C function freeing a whole dump list.

    The generated function walks the 'next' pointers, freeing the members
    of each element's 'obj' and then the element itself.
    """
    cw = ri.cw
    reply = ri.struct['reply']
    sub_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{sub_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while (next)')
    _free_type_members_iter(ri, reply)
    # Step first, so the element can be freed at the end of the iteration
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
1641
1642
def print_ntf_type_free(ri):
    """Write the C function freeing a notification wrapper and its 'obj' members."""
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    ri.cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    ri.cw.p('free(rsp);')
    ri.cw.block_end()
    ri.cw.nl()
1651
1652
def print_ntf_parse_prototype(family, cw, suffix=';'):
    """Write the prototype of the family's '<family>_ntf_parse' function."""
    func_name = f"{family['name']}_ntf_parse"
    func_args = ['struct ynl_sock *ys']
    cw.write_func_prot('struct ynl_ntf_base_type *', func_name, func_args, suffix=suffix)
1656
1657
def print_ntf_type_parse(family, cw, ku_mode):
    """Write the C function which receives and parses one notification.

    The generated function receives a message, switches on the genetlink
    command to pick the response type, parse callback, policy and free
    callback, runs the parser and returns the filled notification (NULL on
    error or for an unknown command).
    """
    print_ntf_parse_prototype(family, cw, suffix='')
    cw.block_start()
    cw.write_func_lvar(['struct genlmsghdr *genlh;',
                        'struct nlmsghdr *nlh;',
                        'struct ynl_parse_arg yarg = { .ys = ys, };',
                        'struct ynl_ntf_base_type *rsp;',
                        'int len, err;',
                        'mnl_cb_t parse;'])
    cw.p('len = mnl_socket_recvfrom(ys->sock, ys->rx_buf, MNL_SOCKET_BUFFER_SIZE);')
    # Anything shorter than both headers cannot be a genetlink message
    cw.p('if (len < (ssize_t)(sizeof(*nlh) + sizeof(*genlh)))')
    cw.p('return NULL;')
    cw.nl()
    cw.p('nlh = (struct nlmsghdr *)ys->rx_buf;')
    cw.p('genlh = mnl_nlmsg_get_payload(nlh);')
    cw.nl()
    cw.block_start(line='switch (genlh->cmd)')
    # Notifications: all commands notifying about the same op share one case body
    for ntf_op in sorted(family.all_notify.keys()):
        op = family.ops[ntf_op]
        ri = RenderInfo(cw, family, ku_mode, op, ntf_op, "notify")
        for ntf in op['notify']['cmds']:
            cw.p(f"case {ntf.enum_name}:")
        cw.p(f"rsp = calloc(1, sizeof({type_name(ri, 'notify')}));")
        cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;")
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
        cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;")
        cw.p('break;')
    # Events: one case per op which has an 'event'
    for op_name, op in family.ops.items():
        if 'event' not in op:
            continue
        ri = RenderInfo(cw, family, ku_mode, op, op_name, "event")
        cw.p(f"case {op.enum_name}:")
        cw.p(f"rsp = calloc(1, sizeof({type_name(ri, 'event')}));")
        cw.p(f"parse = {op_prefix(ri, 'reply', deref=True)}_parse;")
        cw.p(f"yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
        cw.p(f"rsp->free = (void *){op_prefix(ri, 'notify')}_free;")
        cw.p('break;')
    cw.p('default:')
    cw.p('ynl_error_unknown_notification(ys, genlh->cmd);')
    cw.p('return NULL;')
    cw.block_end()
    cw.nl()
    cw.p('yarg.data = rsp->data;')
    cw.nl()
    cw.p(f"err = {cw.nlib.parse_cb_run('parse', '&yarg', True)};")
    cw.p('if (err < 0)')
    cw.p('goto err_free;')
    cw.nl()
    cw.p('rsp->family = nlh->nlmsg_type;')
    cw.p('rsp->cmd = genlh->cmd;')
    cw.p('return rsp;')
    cw.nl()
    cw.p('err_free:')
    cw.p('free(rsp);')
    cw.p('return NULL;')
    cw.block_end()
    cw.nl()
1715
1716
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Write the declaration line of a kernel nla_policy array.

    With `terminate` the line is an 'extern' forward declaration, skipped
    entirely for an op's policy when the family struct can be generated.
    Without it the line opens the definition ('= {'), made 'static' under
    that same condition.  The array is named after the op (plus the op
    mode for dual policies) when `ri` is given, after `struct` otherwise.
    """
    if terminate and ri and kernel_can_gen_family_struct(struct.family):
        return

    if terminate:
        prefix, suffix = 'extern ', ';'
    else:
        is_static = kernel_can_gen_family_struct(struct.family) and ri
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1739
1740
def print_req_policy(cw, struct, ri=None):
    """Write the full nla_policy array: opening line, one entry per member, '};'."""
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for member in struct.member_list():
        _, arg = member
        arg.attr_policy(cw)
    cw.p("};")
1746
1747
def kernel_can_gen_family_struct(family):
    """Return True if the family's protocol is plain 'genetlink'."""
    proto = family.proto
    return proto == 'genetlink'
1750
1751
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops table declaration and, for headers, handler prototypes.

    With ``terminate`` false only the opening line of the table
    definition is written (as a block start) and the function returns.
    With ``terminate`` true the table is declared ``extern`` - but only
    when the family struct cannot be generated, i.e. the table has to be
    exported - and prototypes follow for the pre/post hooks and for the
    doit/dumpit handler of every non-async operation.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.name} */")

        struct_type = {
            'global': 'genl_small_ops',
            'per-op': 'genl_ops',
            'split': 'genl_split_ops',
        }[family.kernel_policy]

        if family.kernel_policy == 'split':
            # Split tables carry one entry per do and per dump callback.
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    doit_hook_args = ['const struct genl_split_ops *ops',
                      'struct sk_buff *skb', 'struct genl_info *info']
    dump_hook_args = ['struct netlink_callback *cb']
    # (hook stage, op mode, C return type, C arguments), in output order.
    hook_protos = (
        ('pre', 'do', 'int', doit_hook_args),
        ('post', 'do', 'void', doit_hook_args),
        ('pre', 'dump', 'int', dump_hook_args),
        ('post', 'dump', 'int', dump_hook_args),
    )

    cw.nl()
    for stage, mode, ret_type, hook_args in hook_protos:
        for hook in family.hooks[stage][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook), list(hook_args), suffix=';')

    cw.nl()

    # The last argument of a handler depends on whether it is a doit or a dumpit.
    handler_ctx = (('do', 'struct genl_info *info'),
                   ('dump', 'struct netlink_callback *cb'))
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        for mode, ctx_arg in handler_ctx:
            if mode not in op:
                continue
            handler = c_lower(f"{family.name}-nl-{op_name}-{mode}it")
            cw.write_func_prot('int', handler,
                               ['struct sk_buff *skb', ctx_arg], suffix=';')
    cw.nl()
1815
1816
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side view of the ops table for ``family``.

    Thin wrapper that runs print_kernel_op_table_fwd() in terminating
    (forward declaration) mode.
    """
    print_kernel_op_table_fwd(family, cw, True)
1819
1820
def print_kernel_op_table(family, cw):
    """Emit the definition of the kernel ops table for ``family``.

    For the 'global' and 'per-op' policies each non-async operation
    yields one initializer holding both of its handlers.  For the
    'split' policy each do/dump of a non-async operation yields its own
    initializer, including pre/post callbacks and a capability flag for
    the mode.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    policy_kind = family.kernel_policy

    def validate_member(op):
        """Build the 'validate' member from the op's dont-validate list."""
        relaxed = [c_upper('genl-dont-validate-' + x) for x in op['dont-validate']]
        return ('validate', ' | '.join(relaxed))

    if policy_kind in ('global', 'per-op'):
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                members.append(validate_member(op))
            for mode in ('do', 'dump'):
                if mode in op:
                    handler = c_lower(f"{family.name}-nl-{op_name}-{mode}it")
                    members.append((mode + 'it', handler))
            if policy_kind == 'per-op':
                # The policy array is sized from the attributes of the do request.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                policy = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', policy))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                op_flags = [c_upper('genl-' + x) for x in op['flags']]
                members.append(('flags', ' | '.join(op_flags)))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif policy_kind == 'split':
        # Names of the struct members holding the pre/post callbacks per mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            for mode in ('do', 'dump'):
                if mode not in op:
                    continue
                mode_spec = op[mode]

                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    members.append(validate_member(op))
                handler = c_lower(f"{family.name}-nl-{op_name}-{mode}it")
                if 'pre' in mode_spec:
                    members.append((cb_names[mode]['pre'], c_lower(mode_spec['pre'])))
                members.append((mode + 'it', handler))
                if 'post' in mode_spec:
                    members.append((cb_names[mode]['post'], c_lower(mode_spec['post'])))
                if 'request' in mode_spec:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=mode_spec['request']['attributes'])

                    if op.dual_policy:
                        policy = c_lower(f"{family.name}-{op_name}-{mode}-nl-policy")
                    else:
                        policy = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', policy))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry is tagged with the mode it serves.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')

    cw.block_end(line=';')
    cw.nl()
1887
1888
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum with one entry per multicast group.

    Does nothing when the family declares no multicast groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        # The trailing comma is part of the emitted enumerator line.
        cw.p(c_upper(f"{family.name}-nlgrp-{grp['name']},"))
    cw.block_end(';')
    cw.nl()
1899
1900
def print_kernel_mcgrp_src(family, cw):
    """Emit the ``genl_multicast_group`` array of the family.

    Each group becomes a designated initializer indexed by its group
    enumerator.  Does nothing when the family declares no multicast
    groups.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f"static const struct genl_multicast_group {family.name}_nl_mcgrps[] =")
    for grp in groups:
        grp_name = grp['name']
        index = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{index}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
1912
1913
def print_kernel_family_struct_hdr(family, cw):
    """Emit the ``extern`` declaration of the family's ``genl_family``.

    Skipped when the family struct cannot be generated.
    """
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
1920
1921
def print_kernel_family_struct_src(family, cw):
    """Emit the definition of the family's ``genl_family`` struct.

    Skipped when the family struct cannot be generated.  The ops table
    is wired in for the 'per-op' and 'split' policies only, and the
    multicast group table only when the family declares groups.
    """
    if not kernel_can_gen_family_struct(family):
        return

    fam = family.name
    cw.block_start(f"struct genl_family {fam}_nl_family __ro_after_init =")
    cw.p(f".name\t\t= {family.fam_key},")
    cw.p(f".version\t= {family.ver_key},")
    for fixed_line in ('.netnsok\t= true,',
                       '.parallel_ops\t= true,',
                       '.module\t\t= THIS_MODULE,'):
        cw.p(fixed_line)

    policy_kind = family.kernel_policy
    if policy_kind == 'split':
        cw.p(f'.split_ops\t= {fam}_nl_ops,')
        cw.p(f'.n_split_ops\t= ARRAY_SIZE({fam}_nl_ops),')
    elif policy_kind == 'per-op':
        cw.p(f'.ops\t\t= {fam}_nl_ops,')
        cw.p(f'.n_ops\t\t= ARRAY_SIZE({fam}_nl_ops),')

    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {fam}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({fam}_nl_mcgrps),')
    cw.block_end(';')
1942
1943
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open an enum block, picking the enum's name from ``obj``.

    If ``obj`` has the ``enum_name`` key, its value names the enum; an
    empty value yields an anonymous enum.  Failing that, if ``ckey`` is
    given and present in ``obj``, the name is the family name joined
    with that value.  In every other case the enum is anonymous.
    """
    tag = None
    if enum_name in obj:
        explicit = obj[enum_name]
        if explicit:
            tag = c_lower(explicit)
    elif ckey and ckey in obj:
        tag = family.name + '_' + c_lower(obj[ckey])

    cw.block_start(line='enum' if tag is None else 'enum ' + tag)
1952
1953
def render_uapi(family, cw):
    """Render the whole uAPI header of ``family`` through ``cw``.

    Output order: include guard, family name/version defines, the
    spec's definitions (enums, flags and constants), one enum per
    attribute set that is not a subset of another set, the command
    enum (plus a separate notification enum when the operations carry
    an 'async-prefix'), the multicast group defines and the closing
    ``#endif``.
    """
    guard = f"_UAPI_LINUX_{family.name.upper()}_H"
    max_by_define = family.get('max-by-define', False)

    def is_async(op):
        """True for messages that are notifications or events."""
        return 'notify' in op or 'event' in op

    def emit_op(op):
        """Print one command enumerator, with its value if the spec sets one."""
        if 'value' in op:
            cw.p(f"{op.enum_name} = {op['value']},")
        else:
            cw.p(op.enum_name + ',')

    def end_counted_enum(cnt_name, max_name):
        """Close an enum ending in a count entry and emit its max value.

        The max is either the last enumerator or, with 'max-by-define',
        a #define placed right after the enum.
        """
        max_value = f"({cnt_name} - 1)"
        cw.nl()
        if max_by_define:
            cw.p(cnt_name)
            cw.block_end(line=';')
            cw.p(f"#define {max_name} {max_value}")
        else:
            cw.p(cnt_name + ',')
            cw.p(f"{max_name} = {max_value}")
            cw.block_end(line=';')
        cw.nl()

    cw.p('#ifndef ' + guard)
    cw.p('#define ' + guard)
    cw.nl()

    cw.writes_defines([(family.fam_key, family["name"]),
                       (family.ver_key, family.get('version', 1))])
    cw.nl()

    # Definitions: consecutive constants are batched so they can be
    # written as one group of defines.
    pending = []
    for const in family['definitions']:
        kind = const['type']
        if kind != 'const':
            # Flush the constants collected so far before anything else.
            cw.writes_defines(pending)
            pending = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if kind in ('enum', 'flags'):
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                title = enum.enum_name
                if 'doc' in enum:
                    title += ' - ' + enum['doc']
                cw.write_doc_line(title)
                for entry in enum.entry_list:
                    if entry.has_doc():
                        cw.write_doc_line('@' + entry.c_name + ': ' + entry['doc'])
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            for entry in enum.entry_list:
                line = entry.c_name
                if entry.value_change:
                    line += f" = {entry.user_value()}"
                cw.p(line + ',')

            if const.get('render-max', False):
                name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
                max_name = c_upper(name_pfx + 'max')
                cw.nl()
                cw.p(f"__{max_name},")
                cw.p(f"{max_name} = (__{max_name} - 1)")
            cw.block_end(line=';')
            cw.nl()
        elif kind == 'const':
            # NOTE(review): the override is looked up on the family, not on
            # the constant, so a family-level 'c-define-name' would rename
            # every constant identically - confirm this is intended.
            define_name = c_upper(family.get('c-define-name',
                                             f"{family.name}-{const['name']}"))
            pending.append([define_name, const['value']])

    if pending:
        cw.writes_defines(pending)
        cw.nl()

    # Attribute sets
    for attr_set in family.attr_sets.values():
        if attr_set.subset_of:
            continue

        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))

        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        # Only spell out a value when it breaks the implicit C numbering.
        expected = 0
        for _, attr in attr_set.items():
            if attr.value == expected:
                cw.p(attr.enum_name + ',')
            else:
                cw.p(f"{attr.enum_name} = {attr.value},")
            expected = attr.value + 1
        end_counted_enum(cnt_name, attr_set.max_name)

    # Commands
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    for op in family.msgs.values():
        if separate_ntf and is_async(op):
            continue
        emit_op(op)
    end_counted_enum(cnt_name, max_name)

    if separate_ntf:
        # Notifications and events live in an enum of their own.
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if is_async(op):
                emit_op(op)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    group_defines = []
    for grp in family.mcgrps['list']:
        grp_name = grp['name']
        define_name = c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{grp_name}"))
        group_defines.append([define_name, f'{grp_name}'])
    cw.nl()
    if group_defines:
        cw.writes_defines(group_defines)
        cw.nl()

    cw.p(f'#endif /* {guard} */')
2089
2090
def find_kernel_root(full_path):
    """Locate the kernel tree that contains ``full_path``.

    Walks up from ``full_path`` one directory at a time until a
    directory holding a ``MAINTAINERS`` file is found.

    Returns a ``(root, sub_path)`` tuple: the directory containing
    ``MAINTAINERS`` and the path of ``full_path`` relative to it.

    Raises FileNotFoundError when no ancestor directory contains a
    ``MAINTAINERS`` file.
    """
    start_path = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # sub_path always ends with a separator at this point; drop it.
            return parent, sub_path[:-1]
        if parent == full_path:
            # dirname() has reached its fixed point (filesystem root, or ''
            # for a relative path); going on would loop forever.
            raise FileNotFoundError(
                f"no MAINTAINERS file found in any parent directory of '{start_path}'")
        full_path = parent
2099
2100
def _render_kernel_shared_policies(cw, parsed, emit_policy):
    """Emit nested-type policies and the global policy, if the family has one.

    ``emit_policy`` is either the forward-declaration or the definition
    printer, so header and source share the same traversal.
    """
    nested = [struct for _, struct in sorted(parsed.pure_nested_structs.items())
              if struct.request]
    if nested:
        cw.p('/* Common nested types */')
    for struct in nested:
        emit_policy(cw, struct)
    cw.nl()

    if parsed.kernel_policy == 'global':
        cw.p(f"/* Global operation policy for {parsed.name} */")

        struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
        emit_policy(cw, struct)
        cw.nl()


def _render_kernel_header(cw, parsed, mode):
    """Emit the body of the kernel-side header."""
    _render_kernel_shared_policies(cw, parsed, print_req_policy_fwd)

    if parsed.kernel_policy in {'per-op', 'split'}:
        for op_name, op in parsed.ops.items():
            if 'do' in op and 'event' not in op:
                ri = RenderInfo(cw, parsed, mode, op, op_name, "do")
                print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                cw.nl()

    print_kernel_op_table_hdr(parsed, cw)
    print_kernel_mcgrp_hdr(parsed, cw)
    print_kernel_family_struct_hdr(parsed, cw)


def _render_kernel_source(cw, parsed, mode):
    """Emit the body of the kernel-side source file."""
    _render_kernel_shared_policies(cw, parsed, print_req_policy)

    if parsed.kernel_policy in {'per-op', 'split'}:
        for op_name, op in parsed.ops.items():
            for op_mode in ['do', 'dump']:
                if op_mode in op and 'request' in op[op_mode]:
                    cw.p(f"/* {op.enum_name} - {op_mode} */")
                    ri = RenderInfo(cw, parsed, mode, op, op_name, op_mode)
                    print_req_policy(cw, ri.struct['request'], ri=ri)
                    cw.nl()

    print_kernel_op_table(parsed, cw)
    print_kernel_mcgrp_src(parsed, cw)
    print_kernel_family_struct_src(parsed, cw)


def _render_user_header(cw, parsed, mode):
    """Emit the body of the user-space header: types and prototypes."""
    has_ntf = False

    cw.p('/* Common nested types */')
    for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
        ri = RenderInfo(cw, parsed, mode, "", "", "", attr_set)
        print_type_full(ri, struct)

    for op_name, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")

        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, op_name, "do")
            print_req_type(ri)
            print_req_type_helpers(ri)
            cw.nl()
            print_rsp_type(ri)
            print_rsp_type_helpers(ri)
            cw.nl()
            print_req_prototype(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, op_name, 'dump')
            if 'request' in op['dump']:
                print_req_type(ri)
                print_req_type_helpers(ri)
            if not ri.type_consistent:
                print_rsp_type(ri)
            print_wrapped_type(ri)
            print_dump_prototype(ri)
            cw.nl()

        if 'notify' in op:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, op_name, 'notify')
            has_ntf = True
            if not ri.type_consistent:
                raise Exception('Only notifications with consistent types supported')
            print_wrapped_type(ri)

        if 'event' in op:
            ri = RenderInfo(cw, parsed, mode, op, op_name, 'event')
            cw.p(f"/* {op.enum_name} - event */")
            # The source emits the common notification parser for events
            # too, so the header has to declare its prototype for them.
            has_ntf = True
            print_rsp_type(ri)
            cw.nl()
            print_wrapped_type(ri)

    if has_ntf:
        cw.p('/* --------------- Common notification parsing --------------- */')
        print_ntf_parse_prototype(parsed, cw)
    cw.nl()


def _render_user_source(cw, parsed, mode):
    """Emit the body of the user-space source: policies, parsers, requests."""
    has_ntf = False

    cw.p('/* Policies */')
    for name in parsed.attr_sets:
        put_typol_fwd(cw, Struct(parsed, name))
    cw.nl()

    for name in parsed.attr_sets:
        put_typol(cw, Struct(parsed, name))

    cw.p('/* Common nested types */')
    for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
        ri = RenderInfo(cw, parsed, mode, "", "", "", attr_set)

        free_rsp_nested(ri, struct)
        if struct.request:
            put_req_nested(ri, struct)
        if struct.reply:
            parse_rsp_nested(ri, struct)

    for op_name, op in parsed.ops.items():
        cw.p(f"/* ============== {op.enum_name} ============== */")
        if 'do' in op and 'event' not in op:
            cw.p(f"/* {op.enum_name} - do */")
            ri = RenderInfo(cw, parsed, mode, op, op_name, "do")
            print_rsp_free(ri)
            parse_rsp_msg(ri)
            print_req(ri)
            cw.nl()

        if 'dump' in op:
            cw.p(f"/* {op.enum_name} - dump */")
            ri = RenderInfo(cw, parsed, mode, op, op_name, "dump")
            if not ri.type_consistent:
                parse_rsp_msg(ri, deref=True)
            print_dump_type_free(ri)
            print_dump(ri)
            cw.nl()

        if 'notify' in op:
            cw.p(f"/* {op.enum_name} - notify */")
            ri = RenderInfo(cw, parsed, mode, op, op_name, 'notify')
            has_ntf = True
            if not ri.type_consistent:
                raise Exception('Only notifications with consistent types supported')
            print_ntf_type_free(ri)

        if 'event' in op:
            cw.p(f"/* {op.enum_name} - event */")
            has_ntf = True

            ri = RenderInfo(cw, parsed, mode, op, op_name, "do")
            parse_rsp_msg(ri)

            ri = RenderInfo(cw, parsed, mode, op, op_name, "event")
            print_ntf_type_free(ri)

    if has_ntf:
        cw.p('/* --------------- Common notification parsing --------------- */')
        print_ntf_type_parse(parsed, cw, mode)


def _render_output(cw, args, parsed, spec_kernel):
    """Write the complete generated file for the parsed spec to ``cw``."""
    if args.mode == 'uapi':
        cw.p('/* SPDX-License-Identifier: GPL-2.0 WITH Linux-syscall-note */')
    elif args.header:
        cw.p('/* SPDX-License-Identifier: BSD-3-Clause */')
    else:
        cw.p('// SPDX-License-Identifier: BSD-3-Clause')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    cw.nl()

    if args.mode == 'uapi':
        # The uAPI header is rendered separately and brings its own guard.
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                # The generated source includes the header with the same
                # base name, whatever the extension of the output file is.
                own_header = os.path.basename(os.path.splitext(args.out_file)[0])
                cw.p(f'#include "{own_header}.h"')
            cw.nl()

    headers = [parsed.uapi_header]
    headers.extend(definition['header'] for definition in parsed['definitions']
                   if 'header' in definition)
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <stdlib.h>")
            cw.p("#include <stdio.h>")
            cw.p("#include <string.h>")
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            _render_kernel_header(cw, parsed, args.mode)
        else:
            _render_kernel_source(cw, parsed, args.mode)

    if args.mode == "user":
        if args.header:
            _render_user_header(cw, parsed, args.mode)
        else:
            _render_user_source(cw, parsed, args.mode)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')


def main():
    """Command-line entry point of the generator.

    Parses the arguments, loads the spec and renders the requested
    file ('uapi', 'kernel' or 'user' mode; header or source) to the
    file given with ``-o``, or to standard output without it.

    Exits with status 1 when the spec is not valid YAML; the error is
    reported on standard error.
    """
    import sys

    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('-o', dest='out_file', type=str)
    args = parser.parse_args()

    # Everything that can fail on bad input runs before the output file
    # is opened, so that an error never truncates an existing file.
    if args.header is None:
        parser.error("--header or --source is required")

    try:
        parsed = Family(args.spec)
    except yaml.YAMLError as exc:
        # Standard output may be the generated file; keep diagnostics off it.
        print(exc, file=sys.stderr)
        sys.exit(1)

    _, spec_kernel = find_kernel_root(args.spec)

    out_file = open(args.out_file, 'w+') if args.out_file else sys.stdout
    try:
        cw = CodeWriter(BaseNlLib(), out_file)
        _render_output(cw, args, parsed, spec_kernel)
    finally:
        if out_file is not sys.stdout:
            out_file.close()
2354
2355
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()