tools: ynl-gen: resolve enum vs struct name conflicts
[platform/kernel/linux-rpi.git] / tools / net / ynl / ynl-gen-c.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4 import argparse
5 import collections
6 import os
7 import re
8 import yaml
9
10 from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
11
12
def c_upper(name):
    """Turn a spec name like ``foo-bar`` into a C macro-style ``FOO_BAR``."""
    return '_'.join(name.upper().split('-'))
15
16
def c_lower(name):
    """Turn a spec name like ``Foo-Bar`` into a C identifier ``foo_bar``."""
    return '_'.join(name.lower().split('-'))
19
20
class BaseNlLib:
    """Renders small C snippets tied to the libmnl-based ynl runtime."""

    def get_family_id(self):
        """Return the C expression that holds the resolved family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return a C ``mnl_cb_run2()`` call expression.

        :param cb: C name of the message callback
        :param data: C expression passed as the callback's data argument
        :param is_dump: dumps do not check sequence number / port id (0, 0)
        :param indent: extra tab depth for the wrapped continuation lines
        """
        wrap = '\n\t\t' + '\t' * indent + ' '
        if not is_dump:
            return f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{wrap}{cb}, {data},{wrap}" + \
                   "ynl_cb_array, NLMSG_MIN_TYPE)"
        return f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},{wrap}ynl_cb_array, NLMSG_MIN_TYPE)"
32
33
class Type(SpecAttr):
    """Base class for rendering one attribute of an attribute set as C code.

    AttrSet.new_attr() picks a subclass per attribute ``type``. Subclasses
    override the underscore-prefixed hooks (``_attr_get``, ``_attr_typol``,
    ``_attr_policy``, ``_setter_lines``, ``_complex_member_type``). The public
    methods here wrap those hooks with presence tracking and boilerplate.

    Presence types used across the subclasses: ``'bit'`` (a 1-bit
    ``_present`` flag), ``'len'`` (a ``<name>_len`` field), ``'count'``
    (an ``n_<name>`` element count) and ``''`` (nothing tracked).
    """

    def __init__(self, family, attr_set, attr, value):
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

            # A nested set that shares its name with a definition gets a
            # trailing '_' so its struct tag does not clash with the enum's.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        self.c_name = c_lower(self.name)
        # Suffix names found in _C_KW (presumably the C keyword set, defined
        # elsewhere in this file) so they are usable as C identifiers.
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute the upper-case C enumerator naming this attribute."""
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Return truthy if the member is an array of values (falsy here)."""
        return None

    def is_scalar(self):
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """Return how presence of this attribute is tracked (see class doc)."""
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the ``_present`` field declaration for this attribute.

        Returns None when the presence type differs from ``type_filter`` and
        for presence types other than 'bit' and 'len'. ``space == 'user'``
        selects the ``__u32`` spelling.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'bit':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}_len;"

    def _complex_member_type(self, ri):
        """Return the C type of a non-trivial member, or None."""
        return None

    def free_needs_iter(self):
        """Return True if freeing needs a loop index ``i`` in scope."""
        return False

    def free(self, ri, var, ref):
        """Emit a free() for multi-value and length-tracked members."""
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declarations a setter takes for this attr.

        :raises Exception: if the subclass provides no member type
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit this attribute's member declaration(s) inside a C struct."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit this attribute's entry of the kernel ``nla_policy`` array."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit this attribute's entry of the user-space YNL type policy."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {"{"} .name = "{self.name}", {typol}{"}"},')

    def _attr_put_line(self, ri, var, line):
        """Emit ``line`` guarded by this attribute's presence check, if any."""
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return ``(lines, init_lines, local_vars)`` for parsing this attr.

        ``lines`` and ``init_lines`` may be a str or a list of str; any of
        the three may be None.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the ``if``/``else if (type == ...)`` branch parsing this attr.

        :param first: True for the first branch of the chain (plain ``if``)
        :returns: True
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        # Multi-value attributes skip validation/presence at this point.
        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit a ``static inline`` setter for this attribute.

        :param ref: member path of enclosing nests (for nested setters)
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        # Mark presence at every level of the nest path.
        for i, name in enumerate(ref):
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            if self.presence_type() == 'bit':
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        # op_prefix() / type_name() are module-level helpers defined
        # elsewhere in this file.
        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        # Setters that free but do not allocate get a '__' prefix
        # (presumably to mark them as internal helpers).
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
209                          args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
210
211
212 class TypeUnused(Type):
213     def presence_type(self):
214         return ''
215
216     def arg_member(self, ri):
217         return []
218
219     def _attr_get(self, ri, var):
220         return ['return MNL_CB_ERROR;'], None, None
221
222     def _attr_typol(self):
223         return '.type = YNL_PT_REJECT, '
224
225     def attr_policy(self, cw):
226         pass
227
228
229 class TypePad(Type):
230     def presence_type(self):
231         return ''
232
233     def arg_member(self, ri):
234         return []
235
236     def _attr_typol(self):
237         return '.type = YNL_PT_IGNORE, '
238
239     def attr_put(self, ri, var):
240         pass
241
242     def attr_get(self, ri, var, first):
243         pass
244
245     def attr_policy(self, cw):
246         pass
247
248     def setter(self, ri, space, direction, deref=False, ref=None):
249         pass
250
251
252 class TypeScalar(Type):
253     def __init__(self, family, attr_set, attr, value):
254         super().__init__(family, attr_set, attr, value)
255
256         self.byte_order_comment = ''
257         if 'byte-order' in attr:
258             self.byte_order_comment = f" /* {attr['byte-order']} */"
259
260         # Added by resolve():
261         self.is_bitfield = None
262         delattr(self, "is_bitfield")
263         self.type_name = None
264         delattr(self, "type_name")
265
266     def resolve(self):
267         self.resolve_up(super())
268
269         if 'enum-as-flags' in self.attr and self.attr['enum-as-flags']:
270             self.is_bitfield = True
271         elif 'enum' in self.attr:
272             self.is_bitfield = self.family.consts[self.attr['enum']]['type'] == 'flags'
273         else:
274             self.is_bitfield = False
275
276         maybe_enum = not self.is_bitfield and 'enum' in self.attr
277         if maybe_enum and self.family.consts[self.attr['enum']].enum_name:
278             self.type_name = f"enum {self.family.name}_{c_lower(self.attr['enum'])}"
279         else:
280             self.type_name = '__' + self.type
281
282     def _mnl_type(self):
283         t = self.type
284         # mnl does not have a helper for signed types
285         if t[0] == 's':
286             t = 'u' + t[1:]
287         return t
288
289     def _attr_policy(self, policy):
290         if 'flags-mask' in self.checks or self.is_bitfield:
291             if self.is_bitfield:
292                 enum = self.family.consts[self.attr['enum']]
293                 mask = enum.get_mask(as_flags=True)
294             else:
295                 flags = self.family.consts[self.checks['flags-mask']]
296                 flag_cnt = len(flags['entries'])
297                 mask = (1 << flag_cnt) - 1
298             return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"
299         elif 'min' in self.checks:
300             return f"NLA_POLICY_MIN({policy}, {self.checks['min']})"
301         elif 'enum' in self.attr:
302             enum = self.family.consts[self.attr['enum']]
303             cnt = len(enum['entries'])
304             return f"NLA_POLICY_MAX({policy}, {cnt - 1})"
305         return super()._attr_policy(policy)
306
307     def _attr_typol(self):
308         return f'.type = YNL_PT_U{self.type[1:]}, '
309
310     def arg_member(self, ri):
311         return [f'{self.type_name} {self.c_name}{self.byte_order_comment}']
312
313     def attr_put(self, ri, var):
314         self._attr_put_simple(ri, var, self._mnl_type())
315
316     def _attr_get(self, ri, var):
317         return f"{var}->{self.c_name} = mnl_attr_get_{self._mnl_type()}(attr);", None, None
318
319     def _setter_lines(self, ri, member, presence):
320         return [f"{member} = {self.c_name};"]
321
322
323 class TypeFlag(Type):
324     def arg_member(self, ri):
325         return []
326
327     def _attr_typol(self):
328         return '.type = YNL_PT_FLAG, '
329
330     def attr_put(self, ri, var):
331         self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)")
332
333     def _attr_get(self, ri, var):
334         return [], None, None
335
336     def _setter_lines(self, ri, member, presence):
337         return []
338
339
340 class TypeString(Type):
341     def arg_member(self, ri):
342         return [f"const char *{self.c_name}"]
343
344     def presence_type(self):
345         return 'len'
346
347     def struct_member(self, ri):
348         ri.cw.p(f"char *{self.c_name};")
349
350     def _attr_typol(self):
351         return f'.type = YNL_PT_NUL_STR, '
352
353     def _attr_policy(self, policy):
354         mem = '{ .type = ' + policy
355         if 'max-len' in self.checks:
356             mem += ', .len = ' + str(self.checks['max-len'])
357         mem += ', }'
358         return mem
359
360     def attr_policy(self, cw):
361         if self.checks.get('unterminated-ok', False):
362             policy = 'NLA_STRING'
363         else:
364             policy = 'NLA_NUL_STRING'
365
366         spec = self._attr_policy(policy)
367         cw.p(f"\t[{self.enum_name}] = {spec},")
368
369     def attr_put(self, ri, var):
370         self._attr_put_simple(ri, var, 'strz')
371
372     def _attr_get(self, ri, var):
373         len_mem = var + '->_present.' + self.c_name + '_len'
374         return [f"{len_mem} = len;",
375                 f"{var}->{self.c_name} = malloc(len + 1);",
376                 f"memcpy({var}->{self.c_name}, mnl_attr_get_str(attr), len);",
377                 f"{var}->{self.c_name}[len] = 0;"], \
378                ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));'], \
379                ['unsigned int len;']
380
381     def _setter_lines(self, ri, member, presence):
382         return [f"free({member});",
383                 f"{presence}_len = strlen({self.c_name});",
384                 f"{member} = malloc({presence}_len + 1);",
385                 f'memcpy({member}, {self.c_name}, {presence}_len);',
386                 f'{member}[{presence}_len] = 0;']
387
388
389 class TypeBinary(Type):
390     def arg_member(self, ri):
391         return [f"const void *{self.c_name}", 'size_t len']
392
393     def presence_type(self):
394         return 'len'
395
396     def struct_member(self, ri):
397         ri.cw.p(f"void *{self.c_name};")
398
399     def _attr_typol(self):
400         return f'.type = YNL_PT_BINARY,'
401
402     def _attr_policy(self, policy):
403         mem = '{ '
404         if len(self.checks) == 1 and 'min-len' in self.checks:
405             mem += '.len = ' + str(self.checks['min-len'])
406         elif len(self.checks) == 0:
407             mem += '.type = NLA_BINARY'
408         else:
409             raise Exception('One or more of binary type checks not implemented, yet')
410         mem += ', }'
411         return mem
412
413     def attr_put(self, ri, var):
414         self._attr_put_line(ri, var, f"mnl_attr_put(nlh, {self.enum_name}, " +
415                             f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")
416
417     def _attr_get(self, ri, var):
418         len_mem = var + '->_present.' + self.c_name + '_len'
419         return [f"{len_mem} = len;",
420                 f"{var}->{self.c_name} = malloc(len);",
421                 f"memcpy({var}->{self.c_name}, mnl_attr_get_payload(attr), len);"], \
422                ['len = mnl_attr_get_payload_len(attr);'], \
423                ['unsigned int len;']
424
425     def _setter_lines(self, ri, member, presence):
426         return [f"free({member});",
427                 f"{member} = malloc({presence}_len);",
428                 f'memcpy({member}, {self.c_name}, {presence}_len);']
429
430
431 class TypeNest(Type):
432     def _complex_member_type(self, ri):
433         return self.nested_struct_type
434
435     def free(self, ri, var, ref):
436         ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name});')
437
438     def _attr_typol(self):
439         return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
440
441     def _attr_policy(self, policy):
442         return 'NLA_POLICY_NESTED(' + self.nested_render_name + '_nl_policy)'
443
444     def attr_put(self, ri, var):
445         self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
446                             f"{self.enum_name}, &{var}->{self.c_name})")
447
448     def _attr_get(self, ri, var):
449         get_lines = [f"if ({self.nested_render_name}_parse(&parg, attr))",
450                      "return MNL_CB_ERROR;"]
451         init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
452                       f"parg.data = &{var}->{self.c_name};"]
453         return get_lines, init_lines, None
454
455     def setter(self, ri, space, direction, deref=False, ref=None):
456         ref = (ref if ref else []) + [self.c_name]
457
458         for _, attr in ri.family.pure_nested_structs[self.nested_attrs].member_list():
459             attr.setter(ri, self.nested_attrs, direction, deref=deref, ref=ref)
460
461
462 class TypeMultiAttr(Type):
463     def is_multi_val(self):
464         return True
465
466     def presence_type(self):
467         return 'count'
468
469     def _mnl_type(self):
470         t = self.type
471         # mnl does not have a helper for signed types
472         if t[0] == 's':
473             t = 'u' + t[1:]
474         return t
475
476     def _complex_member_type(self, ri):
477         if 'type' not in self.attr or self.attr['type'] == 'nest':
478             return self.nested_struct_type
479         elif self.attr['type'] in scalars:
480             scalar_pfx = '__' if ri.ku_space == 'user' else ''
481             return scalar_pfx + self.attr['type']
482         else:
483             raise Exception(f"Sub-type {self.attr['type']} not supported yet")
484
485     def free_needs_iter(self):
486         return 'type' not in self.attr or self.attr['type'] == 'nest'
487
488     def free(self, ri, var, ref):
489         if self.attr['type'] in scalars:
490             ri.cw.p(f"free({var}->{ref}{self.c_name});")
491         elif 'type' not in self.attr or self.attr['type'] == 'nest':
492             ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
493             ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
494             ri.cw.p(f"free({var}->{ref}{self.c_name});")
495         else:
496             raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")
497
498     def _attr_typol(self):
499         if 'type' not in self.attr or self.attr['type'] == 'nest':
500             return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
501         elif self.attr['type'] in scalars:
502             return f".type = YNL_PT_U{self.attr['type'][1:]}, "
503         else:
504             raise Exception(f"Sub-type {self.attr['type']} not supported yet")
505
506     def _attr_get(self, ri, var):
507         return f'n_{self.c_name}++;', None, None
508
509     def attr_put(self, ri, var):
510         if self.attr['type'] in scalars:
511             put_type = self._mnl_type()
512             ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
513             ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
514         elif 'type' not in self.attr or self.attr['type'] == 'nest':
515             ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
516             self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
517                                 f"{self.enum_name}, &{var}->{self.c_name}[i])")
518         else:
519             raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")
520
521     def _setter_lines(self, ri, member, presence):
522         # For multi-attr we have a count, not presence, hack up the presence
523         presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
524         return [f"free({member});",
525                 f"{member} = {self.c_name};",
526                 f"{presence} = n_{self.c_name};"]
527
528
529 class TypeArrayNest(Type):
530     def is_multi_val(self):
531         return True
532
533     def presence_type(self):
534         return 'count'
535
536     def _complex_member_type(self, ri):
537         if 'sub-type' not in self.attr or self.attr['sub-type'] == 'nest':
538             return self.nested_struct_type
539         elif self.attr['sub-type'] in scalars:
540             scalar_pfx = '__' if ri.ku_space == 'user' else ''
541             return scalar_pfx + self.attr['sub-type']
542         else:
543             raise Exception(f"Sub-type {self.attr['sub-type']} not supported yet")
544
545     def _attr_typol(self):
546         return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
547
548     def _attr_get(self, ri, var):
549         local_vars = ['const struct nlattr *attr2;']
550         get_lines = [f'attr_{self.c_name} = attr;',
551                      'mnl_attr_for_each_nested(attr2, attr)',
552                      f'\t{var}->n_{self.c_name}++;']
553         return get_lines, None, local_vars
554
555
556 class TypeNestTypeValue(Type):
557     def _complex_member_type(self, ri):
558         return self.nested_struct_type
559
560     def _attr_typol(self):
561         return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '
562
563     def _attr_get(self, ri, var):
564         prev = 'attr'
565         tv_args = ''
566         get_lines = []
567         local_vars = []
568         init_lines = [f"parg.rsp_policy = &{self.nested_render_name}_nest;",
569                       f"parg.data = &{var}->{self.c_name};"]
570         if 'type-value' in self.attr:
571             tv_names = [c_lower(x) for x in self.attr["type-value"]]
572             local_vars += [f'const struct nlattr *attr_{", *attr_".join(tv_names)};']
573             local_vars += [f'__u32 {", ".join(tv_names)};']
574             for level in self.attr["type-value"]:
575                 level = c_lower(level)
576                 get_lines += [f'attr_{level} = mnl_attr_get_payload({prev});']
577                 get_lines += [f'{level} = mnl_attr_get_type(attr_{level});']
578                 prev = 'attr_' + level
579
580             tv_args = f", {', '.join(tv_names)}"
581
582         get_lines += [f"{self.nested_render_name}_parse(&parg, {prev}{tv_args});"]
583         return get_lines, init_lines, local_vars
584
585
586 class Struct:
587     def __init__(self, family, space_name, type_list=None, inherited=None):
588         self.family = family
589         self.space_name = space_name
590         self.attr_set = family.attr_sets[space_name]
591         # Use list to catch comparisons with empty sets
592         self._inherited = inherited if inherited is not None else []
593         self.inherited = []
594
595         self.nested = type_list is None
596         if family.name == c_lower(space_name):
597             self.render_name = f"{family.name}"
598         else:
599             self.render_name = f"{family.name}_{c_lower(space_name)}"
600         self.struct_name = 'struct ' + self.render_name
601         if self.nested and space_name in family.consts:
602             self.struct_name += '_'
603         self.ptr_name = self.struct_name + ' *'
604
605         self.request = False
606         self.reply = False
607
608         self.attr_list = []
609         self.attrs = dict()
610         if type_list:
611             for t in type_list:
612                 self.attr_list.append((t, self.attr_set[t]),)
613         else:
614             for t in self.attr_set:
615                 self.attr_list.append((t, self.attr_set[t]),)
616
617         max_val = 0
618         self.attr_max_val = None
619         for name, attr in self.attr_list:
620             if attr.value >= max_val:
621                 max_val = attr.value
622                 self.attr_max_val = attr
623             self.attrs[name] = attr
624
625     def __iter__(self):
626         yield from self.attrs
627
628     def __getitem__(self, key):
629         return self.attrs[key]
630
631     def member_list(self):
632         return self.attr_list
633
634     def set_inherited(self, new_inherited):
635         if self._inherited != new_inherited:
636             raise Exception("Inheriting different members not supported")
637         self.inherited = [c_lower(x) for x in sorted(self._inherited)]
638
639
640 class EnumEntry(SpecEnumEntry):
641     def __init__(self, enum_set, yaml, prev, value_start):
642         super().__init__(enum_set, yaml, prev, value_start)
643
644         if prev:
645             self.value_change = (self.value != prev.value + 1)
646         else:
647             self.value_change = (self.value != 0)
648         self.value_change = self.value_change or self.enum_set['type'] == 'flags'
649
650         # Added by resolve:
651         self.c_name = None
652         delattr(self, "c_name")
653
654     def resolve(self):
655         self.resolve_up(super())
656
657         self.c_name = c_upper(self.enum_set.value_pfx + self.name)
658
659
660 class EnumSet(SpecEnumSet):
661     def __init__(self, family, yaml):
662         self.render_name = c_lower(family.name + '-' + yaml['name'])
663
664         if 'enum-name' in yaml:
665             if yaml['enum-name']:
666                 self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
667             else:
668                 self.enum_name = None
669         else:
670             self.enum_name = 'enum ' + self.render_name
671
672         self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")
673
674         super().__init__(family, yaml)
675
676     def new_entry(self, entry, prev_entry, value_start):
677         return EnumEntry(self, entry, prev_entry, value_start)
678
679
680 class AttrSet(SpecAttrSet):
681     def __init__(self, family, yaml):
682         super().__init__(family, yaml)
683
684         if self.subset_of is None:
685             if 'name-prefix' in yaml:
686                 pfx = yaml['name-prefix']
687             elif self.name == family.name:
688                 pfx = family.name + '-a-'
689             else:
690                 pfx = f"{family.name}-a-{self.name}-"
691             self.name_prefix = c_upper(pfx)
692             self.max_name = c_upper(self.yaml.get('attr-max-name', f"{self.name_prefix}max"))
693         else:
694             self.name_prefix = family.attr_sets[self.subset_of].name_prefix
695             self.max_name = family.attr_sets[self.subset_of].max_name
696
697         # Added by resolve:
698         self.c_name = None
699         delattr(self, "c_name")
700
701     def resolve(self):
702         self.c_name = c_lower(self.name)
703         if self.c_name in _C_KW:
704             self.c_name += '_'
705         if self.c_name == self.family.c_name:
706             self.c_name = ''
707
708     def new_attr(self, elem, value):
709         if 'multi-attr' in elem and elem['multi-attr']:
710             return TypeMultiAttr(self.family, self, elem, value)
711         elif elem['type'] in scalars:
712             return TypeScalar(self.family, self, elem, value)
713         elif elem['type'] == 'unused':
714             return TypeUnused(self.family, self, elem, value)
715         elif elem['type'] == 'pad':
716             return TypePad(self.family, self, elem, value)
717         elif elem['type'] == 'flag':
718             return TypeFlag(self.family, self, elem, value)
719         elif elem['type'] == 'string':
720             return TypeString(self.family, self, elem, value)
721         elif elem['type'] == 'binary':
722             return TypeBinary(self.family, self, elem, value)
723         elif elem['type'] == 'nest':
724             return TypeNest(self.family, self, elem, value)
725         elif elem['type'] == 'array-nest':
726             return TypeArrayNest(self.family, self, elem, value)
727         elif elem['type'] == 'nest-type-value':
728             return TypeNestTypeValue(self.family, self, elem, value)
729         else:
730             raise Exception(f"No typed class for type {elem['type']}")
731
732
733 class Operation(SpecOperation):
734     def __init__(self, family, yaml, req_value, rsp_value):
735         super().__init__(family, yaml, req_value, rsp_value)
736
737         self.render_name = family.name + '_' + c_lower(self.name)
738
739         self.dual_policy = ('do' in yaml and 'request' in yaml['do']) and \
740                          ('dump' in yaml and 'request' in yaml['dump'])
741
742         self.has_ntf = False
743
744         # Added by resolve:
745         self.enum_name = None
746         delattr(self, "enum_name")
747
748     def resolve(self):
749         self.resolve_up(super())
750
751         if not self.is_async:
752             self.enum_name = self.family.op_prefix + c_upper(self.name)
753         else:
754             self.enum_name = self.family.async_op_prefix + c_upper(self.name)
755
756     def mark_has_ntf(self):
757         self.has_ntf = True
758
759
760 class Family(SpecFamily):
761     def __init__(self, file_name, exclude_ops):
762         # Added by resolve:
763         self.c_name = None
764         delattr(self, "c_name")
765         self.op_prefix = None
766         delattr(self, "op_prefix")
767         self.async_op_prefix = None
768         delattr(self, "async_op_prefix")
769         self.mcgrps = None
770         delattr(self, "mcgrps")
771         self.consts = None
772         delattr(self, "consts")
773         self.hooks = None
774         delattr(self, "hooks")
775
776         super().__init__(file_name, exclude_ops=exclude_ops)
777
778         self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
779         self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
780
781         if 'definitions' not in self.yaml:
782             self.yaml['definitions'] = []
783
784         if 'uapi-header' in self.yaml:
785             self.uapi_header = self.yaml['uapi-header']
786         else:
787             self.uapi_header = f"linux/{self.name}.h"
788
789     def resolve(self):
790         self.resolve_up(super())
791
792         if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
793             raise Exception("Codegen only supported for genetlink")
794
795         self.c_name = c_lower(self.name)
796         if 'name-prefix' in self.yaml['operations']:
797             self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
798         else:
799             self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
800         if 'async-prefix' in self.yaml['operations']:
801             self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
802         else:
803             self.async_op_prefix = self.op_prefix
804
805         self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
806
807         self.hooks = dict()
808         for when in ['pre', 'post']:
809             self.hooks[when] = dict()
810             for op_mode in ['do', 'dump']:
811                 self.hooks[when][op_mode] = dict()
812                 self.hooks[when][op_mode]['set'] = set()
813                 self.hooks[when][op_mode]['list'] = []
814
815         # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
816         self.root_sets = dict()
817         # dict space-name -> set('request', 'reply')
818         self.pure_nested_structs = dict()
819
820         self._mark_notify()
821         self._mock_up_events()
822
823         self._load_root_sets()
824         self._load_nested_sets()
825         self._load_hooks()
826
827         self.kernel_policy = self.yaml.get('kernel-policy', 'split')
828         if self.kernel_policy == 'global':
829             self._load_global_policy()
830
831     def new_enum(self, elem):
832         return EnumSet(self, elem)
833
834     def new_attr_set(self, elem):
835         return AttrSet(self, elem)
836
837     def new_operation(self, elem, req_value, rsp_value):
838         return Operation(self, elem, req_value, rsp_value)
839
840     def _mark_notify(self):
841         for op in self.msgs.values():
842             if 'notify' in op:
843                 self.ops[op['notify']].mark_has_ntf()
844
845     # Fake a 'do' equivalent of all events, so that we can render their response parsing
846     def _mock_up_events(self):
847         for op in self.yaml['operations']['list']:
848             if 'event' in op:
849                 op['do'] = {
850                     'reply': {
851                         'attributes': op['event']['attributes']
852                     }
853                 }
854
855     def _load_root_sets(self):
856         for op_name, op in self.msgs.items():
857             if 'attribute-set' not in op:
858                 continue
859
860             req_attrs = set()
861             rsp_attrs = set()
862             for op_mode in ['do', 'dump']:
863                 if op_mode in op and 'request' in op[op_mode]:
864                     req_attrs.update(set(op[op_mode]['request']['attributes']))
865                 if op_mode in op and 'reply' in op[op_mode]:
866                     rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
867             if 'event' in op:
868                 rsp_attrs.update(set(op['event']['attributes']))
869
870             if op['attribute-set'] not in self.root_sets:
871                 self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
872             else:
873                 self.root_sets[op['attribute-set']]['request'].update(req_attrs)
874                 self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
875
    def _load_nested_sets(self):
        """Discover attribute sets used only as nests and prepare their Structs.

        Starting from the op-level (root) attribute sets, walks
        ``nested-attributes`` references breadth-first and creates one Struct
        per nested set in ``self.pure_nested_structs``. It then:
          * records members inherited from the parent (``type-value`` names,
            or ``idx`` for ``array-nest``) via ``set_inherited()``;
          * marks nests reached from root request/reply attributes;
          * reorders ``pure_nested_structs`` so a struct comes after the
            structs it nests;
          * propagates the request/reply flags from parents to children.

        Raises:
            Exception: if a root attribute set is also used as a nest.
        """
        # BFS over nested attribute sets, seeded with the root sets
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                if 'type-value' in spec:
                    # NOTE(review): appears unreachable, since the branch above
                    # already raises whenever `nested` is a root set.
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    # array-nest entries also carry their index ('idx')
                    inherit.add('idx')
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Mark nests directly referenced by attributes that root ops use in
        # their request and/or reply.
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        # Try to reorder according to dependencies
        # A struct counts as "finished" once every set it nests has finished.
        # Otherwise it is popped and re-inserted so it moves to the end of the
        # dict. The loop runs at most len**2 rounds, so a set that never
        # finishes (e.g. one that nests itself) cannot loop forever.
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list)**2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    if spec['nested-attributes'] not in pns_key_seen:
                        # Dicts are sorted, this will make struct last
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                pns_key_list.append(name)
        # Propagate the request / reply
        # After the sort, children precede their parents, so walking in reverse
        # visits parents first and the flags flow down through multiple levels.
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child = self.pure_nested_structs.get(spec['nested-attributes'])
                    if child:
                        child.request |= struct.request
                        child.reply |= struct.reply
944
945     def _load_global_policy(self):
946         global_set = set()
947         attr_set_name = None
948         for op_name, op in self.ops.items():
949             if not op:
950                 continue
951             if 'attribute-set' not in op:
952                 continue
953
954             if attr_set_name is None:
955                 attr_set_name = op['attribute-set']
956             if attr_set_name != op['attribute-set']:
957                 raise Exception('For a global policy all ops must use the same set')
958
959             for op_mode in ['do', 'dump']:
960                 if op_mode in op:
961                     global_set.update(op[op_mode].get('request', []))
962
963         self.global_policy = []
964         self.global_policy_set = attr_set_name
965         for attr in self.attr_sets[attr_set_name]:
966             if attr in global_set:
967                 self.global_policy.append(attr)
968
969     def _load_hooks(self):
970         for op in self.ops.values():
971             for op_mode in ['do', 'dump']:
972                 if op_mode not in op:
973                     continue
974                 for when in ['pre', 'post']:
975                     if when not in op[op_mode]:
976                         continue
977                     name = op[op_mode][when]
978                     if name in self.hooks[when][op_mode]['set']:
979                         continue
980                     self.hooks[when][op_mode]['set'].add(name)
981                     self.hooks[when][op_mode]['list'].append(name)
982
983
class RenderInfo:
    """Rendering context for one (op, op_mode) pair, or for a nested-only type.

    Holds the family, the code writer and the op being rendered, plus derived
    naming data (``type_name``, ``type_name_conflict``, ``type_consistent``)
    and the request/reply Struct descriptions. For nested-only types ``op``
    is falsy and ``attr_set`` names the attribute set to render.
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.family = family
        self.nl = cw.nlib
        self.ku_space = ku_space
        self.op_mode = op_mode
        self.op = op

        # Outside 'do' mode, the 'do' and 'dump' replies can share one type
        # only if neither has a reply or both declare the same one.
        self.type_consistent = True
        if op_mode != 'do' and 'dump' in op and 'do' in op:
            do_spec = op['do']
            dump_spec = op["dump"]
            if ('reply' in do_spec) != ('reply' in dump_spec):
                self.type_consistent = False
            elif 'reply' in do_spec and do_spec["reply"] != dump_spec["reply"]:
                self.type_consistent = False

        self.attr_set = attr_set if attr_set else op['attribute-set']

        self.type_name_conflict = False
        if not op:
            self.type_name = c_lower(attr_set)
            # A nested struct named like a constant needs a distinct C tag
            if attr_set in family.consts:
                self.type_name_conflict = True
        else:
            self.type_name = c_lower(op.name)

        self.cw = cw

        self.struct = {}
        # Notifications are rendered from the op's 'do' section
        spec_mode = 'do' if op_mode == 'notify' else op_mode
        for op_dir in ['request', 'reply']:
            if op and op_dir in op[spec_mode]:
                attrs = op[spec_mode][op_dir]['attributes']
                self.struct[op_dir] = Struct(family, self.attr_set, type_list=attrs)
        if spec_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set, type_list=op['event']['attributes'])
1023
1024
class CodeWriter:
    """Emit kernel-style C source text to a file-like object.

    Tracks the indentation level and two kinds of deferred output: a blank
    line requested with nl(), written only if more text follows, and a bare
    closing brace from block_end(), held back so that a following ``else``
    can be joined as ``} else``.
    """

    def __init__(self, nlib, out_file):
        self.nlib = nlib

        self._nl = False            # blank line pending before the next line
        self._block_end = False     # bare '}' pending, see block_end()
        self._silent_block = False  # last line was an if/while/for without '{'
        self._ind = 0
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        return line.startswith('if') or line.startswith('while') or line.startswith('for')

    def p(self, line, add_ind=0):
        """Write *line* at the current indentation plus *add_ind* levels.

        Pending output is flushed first. A held-back ``}`` is merged into a
        line starting with ``else``, or otherwise written on its own line. A
        pending blank line is then emitted. Lines ending in ``:`` (e.g. goto
        labels) are outdented one level. The statement after a brace-less
        ``if``/``while``/``for`` line is indented one extra level.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith() rather than line[-1] so that p('') does not raise
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        self._out.write('\t' * ind + line + '\n')

    def nl(self):
        """Request a blank line before the next printed line."""
        self._nl = True

    def block_start(self, line=''):
        """Open a ``{`` block, optionally preceded by *line*, and indent."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the innermost block, optionally followed by *line* (e.g. ``;``).

        A bare close is deferred so that a subsequent ``else`` can join it.
        If an earlier bare close is still pending, it is written first at the
        indentation of the block it closes.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if self._block_end:
            # The pending brace closes the block nested inside the one being
            # closed now, so it belongs one level deeper than the new _ind.
            self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Write *doc* as `` *``-prefixed comment lines wrapped before column 79."""
        words = doc.split()
        line = ' *'
        for word in words:
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Write a C function prototype, wrapping arguments past ~80 columns.

        Args:
            qual_ret: qualifiers and return type, e.g. ``'static int'``.
            name: function name.
            args: argument declarations; defaults to ``['void']``.
            doc: optional one-line comment placed above the prototype.
            suffix: text appended after ``)``, e.g. ``';'`` for declarations.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        oneline = qual_ret
        if qual_ret[-1] != '*':
            oneline += ' '
        oneline += f"{name}({', '.join(args)}){suffix}"

        if len(oneline) < 80:
            self.p(oneline)
            return

        # Long return types get their own line; continuation lines are aligned
        # under the opening parenthesis using tabs (8 columns) then spaces.
        v = qual_ret
        if len(v) > 3:
            self.p(v)
            v = ''
        elif qual_ret[-1] != '*':
            v += ' '
        v += name + '('
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        i = 1
        while i < len(args):
            next_len = len(v) + len(args[i])
            if v[0] == '\t':
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += args[i]
            i += 1
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Write local variable declarations, longest first, then a blank line.

        Accepts a single declaration string or a list of them. The caller's
        list is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() is stable like list.sort() but leaves the caller's list intact
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Write a complete function: prototype, locals and *body* lines in a block."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Write ``#define NAME value`` lines with values tab-aligned.

        *defines* is a sequence of ``(name, value)`` pairs. ``int`` values are
        written bare and ``str`` values are quoted.
        """
        longest = 0
        for define in defines:
            if len(define[0]) > longest:
                longest = len(define[0])
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Write designated-initializer lines ``.name = value,`` tab-aligned.

        *members* is a sequence of ``(name, value_text)`` pairs. An empty
        sequence writes nothing.
        """
        if not members:
            return
        longest = max([len(x[0]) for x in members])
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1182
1183
1184 scalars = {'u8', 'u16', 'u32', 'u64', 's32', 's64'}
1185
1186 direction_to_suffix = {
1187     'reply': '_rsp',
1188     'request': '_req',
1189     '': ''
1190 }
1191
1192 op_mode_to_wrapper = {
1193     'do': '',
1194     'dump': '_list',
1195     'notify': '_ntf',
1196     'event': '',
1197 }
1198
1199 _C_KW = {
1200     'auto',
1201     'bool',
1202     'break',
1203     'case',
1204     'char',
1205     'const',
1206     'continue',
1207     'default',
1208     'do',
1209     'double',
1210     'else',
1211     'enum',
1212     'extern',
1213     'float',
1214     'for',
1215     'goto',
1216     'if',
1217     'inline',
1218     'int',
1219     'long',
1220     'register',
1221     'return',
1222     'short',
1223     'signed',
1224     'sizeof',
1225     'static',
1226     'struct',
1227     'switch',
1228     'typedef',
1229     'union',
1230     'unsigned',
1231     'void',
1232     'volatile',
1233     'while'
1234 }
1235
1236
def rdir(direction):
    """Return the opposite message direction. Other values pass through unchanged."""
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1243
1244
def op_prefix(ri, direction, deref=False):
    """Return the C identifier prefix for the type of *direction* in ``ri``.

    The result is ``<family>_<type_name><tail>``. In ``do`` mode (or no
    mode) the tail is the plain direction suffix. In other modes requests
    use ``_req_dump``. Replies use the wrapper suffix of the mode, or the
    direction suffix when *deref* asks for the inner object. If the do/dump
    reply types differ, ``_rsp_dump``/``_rsp_list`` is used instead.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = f"{direction_to_suffix[direction]}"
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = f"{direction_to_suffix[direction]}"
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family['name']}_{ri.type_name}{tail}"
1264
1265
def type_name(ri, direction, deref=False):
    """Return the full C type (``struct <op_prefix>``) for *direction*."""
    prefix = op_prefix(ri, direction, deref=deref)
    return 'struct ' + prefix
1268
1269
def print_prototype(ri, direction, terminate=True, doc=None):
    """Emit the prototype of the do/dump call for ``ri.op``.

    The function is named after the op, with ``_dump`` appended for dumps.
    It takes the socket plus the request struct when the op has a request.
    It returns a reply struct pointer when the op has a reply, and ``int``
    otherwise. *terminate* appends ``;`` to make it a declaration.
    """
    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')
    mode_spec = ri.op[ri.op_mode]

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        req_arg = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{req_arg}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc, suffix=';' if terminate else '')
1286
1287
def print_req_prototype(ri):
    """Emit the ``do`` call declaration, commented with the op's doc string."""
    op_doc = ri.op['doc']
    print_prototype(ri, "request", doc=op_doc)
1290
1291
def print_dump_prototype(ri):
    """Emit the dump call declaration (no doc comment)."""
    direction = "request"
    print_prototype(ri, direction)
1294
1295
def put_typol(cw, struct):
    """Emit the ynl policy table for *struct* and the nest descriptor using it.

    Writes ``<name>_policy[MAX + 1]``, with one entry per member via
    ``attr_typol()``, followed by ``<name>_nest``, which points at it.
    """
    name = struct.render_name
    max_attr = struct.attr_set.max_name

    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{max_attr} + 1] =')
    for member in (arg for _, arg in struct.member_list()):
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    for init in (f'.max_attr = {max_attr},', f'.table = {name}_policy,'):
        cw.p(init)
    cw.block_end(line=';')
    cw.nl()
1311
1312
1313 def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
1314     args = [f'int {arg_name}']
1315     if enum and not ('enum-name' in enum and not enum['enum-name']):
1316         args = [f'enum {render_name} {arg_name}']
1317     cw.write_func_prot('const char *', f'{render_name}_str', args)
1318     cw.block_start()
1319     if enum and enum.type == 'flags':
1320         cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
1321     cw.p(f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))')
1322     cw.p('return NULL;')
1323     cw.p(f'return {map_name}[{arg_name}];')
1324     cw.block_end()
1325     cw.nl()
1326
1327
def put_op_name_fwd(family, cw):
    """Emit the forward declaration of ``<family>_op_str(int op)``."""
    fname = f'{family.name}_op_str'
    cw.write_func_prot('const char *', fname, ['int op'], suffix=';')
1330
1331
def put_op_name(family, cw):
    """Emit the op id -> name string table and the ``<family>_op_str()`` lookup.

    Only messages with a reply id are listed. When the request and reply ids
    match, the op's enum name is used as the index; otherwise the reply id is.
    """
    map_name = f'{family.name}_op_strmap'
    cw.block_start(line=f"static const char * const {map_name}[] =")
    for msg_name, msg in family.msgs.items():
        if not msg.rsp_value:
            continue
        index = msg.enum_name if msg.req_value == msg.rsp_value else msg.rsp_value
        cw.p(f'[{index}] = "{msg_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.name + '_op', map_name, 'op')
1345
1346
def put_enum_to_str_fwd(family, cw, enum):
    """Emit the declaration of ``<enum>_str()``.

    The argument is typed ``enum`` unless the spec sets an empty ``enum-name``.
    """
    untyped = 'enum-name' in enum and not enum['enum-name']
    arg = 'int value' if untyped else f'enum {enum.render_name} value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str', [arg], suffix=';')
1352
1353
def put_enum_to_str(family, cw, enum):
    """Emit the value -> name string table for *enum* and its ``_str()`` lookup."""
    map_name = f'{enum.render_name}_strmap'
    entry_lines = [f'[{entry.value}] = "{entry.name}",' for entry in enum.entries.values()]

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for line in entry_lines:
        cw.p(line)
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, enum.render_name, map_name, 'value', enum=enum)
1363
1364
def put_req_nested(ri, struct):
    """Emit ``<struct>_put()``, which writes *struct* as a nested attribute.

    The generated function opens a nest of ``attr_type``, lets each member
    emit its own put code against ``obj``, closes the nest and returns 0.
    """
    cw = ri.cw
    cw.write_func_prot('int', f'{struct.render_name}_put',
                       ['struct nlmsghdr *nlh',
                        'unsigned int attr_type',
                        f'{struct.ptr_name}obj'])
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")
    for member in (arg for _, arg in struct.member_list()):
        member.attr_put(ri, "obj")
    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1385
1386
1387 def _multi_parse(ri, struct, init_lines, local_vars):
1388     if struct.nested:
1389         iter_line = "mnl_attr_for_each_nested(attr, nested)"
1390     else:
1391         iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"
1392
1393     array_nests = set()
1394     multi_attrs = set()
1395     needs_parg = False
1396     for arg, aspec in struct.member_list():
1397         if aspec['type'] == 'array-nest':
1398             local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
1399             array_nests.add(arg)
1400         if 'multi-attr' in aspec:
1401             multi_attrs.add(arg)
1402         needs_parg |= 'nested-attributes' in aspec
1403     if array_nests or multi_attrs:
1404         local_vars.append('int i;')
1405     if needs_parg:
1406         local_vars.append('struct ynl_parse_arg parg;')
1407         init_lines.append('parg.ys = yarg->ys;')
1408
1409     all_multi = array_nests | multi_attrs
1410
1411     for anest in sorted(all_multi):
1412         local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")
1413
1414     ri.cw.block_start()
1415     ri.cw.write_func_lvar(local_vars)
1416
1417     for line in init_lines:
1418         ri.cw.p(line)
1419     ri.cw.nl()
1420
1421     for arg in struct.inherited:
1422         ri.cw.p(f'dst->{arg} = {arg};')
1423
1424     for anest in sorted(all_multi):
1425         aspec = struct[anest]
1426         ri.cw.p(f"if (dst->{aspec.c_name})")
1427         ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')
1428
1429     ri.cw.nl()
1430     ri.cw.block_start(line=iter_line)
1431     ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
1432     ri.cw.nl()
1433
1434     first = True
1435     for _, arg in struct.member_list():
1436         good = arg.attr_get(ri, 'dst', first=first)
1437         # First may be 'unused' or 'pad', ignore those
1438         first &= not good
1439
1440     ri.cw.block_end()
1441     ri.cw.nl()
1442
1443     for anest in sorted(array_nests):
1444         aspec = struct[anest]
1445
1446         ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1447         ri.cw.p(f"dst->{aspec.c_name} = calloc({aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1448         ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1449         ri.cw.p('i = 0;')
1450         ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1451         ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
1452         ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1453         ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
1454         ri.cw.p('return MNL_CB_ERROR;')
1455         ri.cw.p('i++;')
1456         ri.cw.block_end()
1457         ri.cw.block_end()
1458     ri.cw.nl()
1459
1460     for anest in sorted(multi_attrs):
1461         aspec = struct[anest]
1462         ri.cw.block_start(line=f"if (n_{aspec.c_name})")
1463         ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
1464         ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
1465         ri.cw.p('i = 0;')
1466         if 'nested-attributes' in aspec:
1467             ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
1468         ri.cw.block_start(line=iter_line)
1469         ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
1470         if 'nested-attributes' in aspec:
1471             ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
1472             ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
1473             ri.cw.p('return MNL_CB_ERROR;')
1474         elif aspec['type'] in scalars:
1475             t = aspec['type']
1476             if t[0] == 's':
1477                 t = 'u' + t[1:]
1478             ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
1479         else:
1480             raise Exception('Nest parsing type not supported yet')
1481         ri.cw.p('i++;')
1482         ri.cw.block_end()
1483         ri.cw.block_end()
1484         ri.cw.block_end()
1485     ri.cw.nl()
1486
1487     if struct.nested:
1488         ri.cw.p('return 0;')
1489     else:
1490         ri.cw.p('return MNL_CB_OK;')
1491     ri.cw.block_end()
1492     ri.cw.nl()
1493
1494
def parse_rsp_nested(ri, struct):
    """Emit ``<struct>_parse()`` for a nested attribute set.

    The function receives the parse arg and the nest attribute, plus one
    ``__u32`` argument per inherited member, and stores into ``yarg->data``.
    """
    local_vars = [
        'const struct nlattr *attr;',
        f'{struct.ptr_name}dst = yarg->data;',
    ]
    func_args = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    func_args.extend('__u32 ' + arg for arg in struct.inherited)

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)

    _multi_parse(ri, struct, [], local_vars)
1508
1509
def parse_rsp_msg(ri, deref=False):
    """Emit the top-level reply parse callback for the op in *ri*.

    Nothing is emitted for ops without a reply, except events.
    """
    if 'reply' not in ri.op[ri.op_mode] and ri.op_mode != 'event':
        return

    reply_type = type_name(ri, "reply", deref=deref)
    local_vars = [
        f'{reply_type} *dst;',
        'struct ynl_parse_arg *yarg = data;',
        'const struct nlattr *attr;',
    ]

    parse_fn = f'{op_prefix(ri, "reply", deref=deref)}_parse'
    ri.cw.write_func_prot('int', parse_fn, ['const struct nlmsghdr *nlh', 'void *data'])

    _multi_parse(ri, ri.struct["reply"], ['dst = yarg->data;'], local_vars)
1525
1526
def print_req(ri):
    """Emit the body of the blocking ``do`` request call for the op in *ri*.

    The generated function builds the request, executes it, and either
    returns 0/-1 (no reply) or returns an allocated reply struct, freeing it
    and returning NULL on error.
    """
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]
    cw = ri.cw

    local_vars = ['struct nlmsghdr *nlh;', 'int err;']
    if has_reply:
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;',
                       'struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };']

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    for _, attr in ri.struct["request"].member_list():
        attr.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
        cw.p(f'yrs.rsp_cmd = {rsp_cmd};')
        cw.nl()
        cw.p("err = ynl_exec(ys, nlh, &yrs);")
        cw.p('if (err < 0)')
        cw.p('goto err_free;')
    else:
        cw.p("err = ynl_exec(ys, nlh, NULL);")
        cw.p('if (err < 0)')
        cw.p('return -1;')
    cw.nl()

    cw.p(f"return {'rsp' if has_reply else '0'};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        cw.p("return NULL;")

    cw.block_end()
1582
1583
def print_dump(ri):
    """Emit the body of the dump call for the op in *ri*.

    The generated function configures a ``ynl_dump_state``, sends the dump
    request and returns the list head, freeing the list on error.
    """
    direction = "request"
    cw = ri.cw

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    for decl in ('struct ynl_dump_state yds = {};', 'struct nlmsghdr *nlh;', 'int err;'):
        cw.p(decl)
    cw.nl()

    rsp_cmd = ri.op.enum_name if ri.op.value is not None else ri.op.rsp_value
    setup = [
        'yds.ys = ys;',
        f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});",
        f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;",
        f'yds.rsp_cmd = {rsp_cmd};',
        f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;",
    ]
    for line in setup:
        cw.p(line)
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
        cw.nl()
        for _, attr in ri.struct["request"].member_list():
            attr.attr_put(ri, "req")
    cw.nl()

    for line in ('err = ynl_exec_dump(ys, nlh, &yds);', 'if (err < 0)', 'goto free_list;'):
        cw.p(line)
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1625
1626
def call_free(ri, direction, var):
    """Return the C statement that frees *var* with the matching ``_free()``."""
    free_fn = op_prefix(ri, direction) + '_free'
    return f"{free_fn}({var});"
1629
1630
def free_arg_name(direction):
    """Name of the pointer argument of a generated ``*_free()`` helper.

    ``req``/``rsp`` for a direction, ``obj`` for plain (nested) types.
    """
    if not direction:
        return 'obj'
    return direction_to_suffix[direction][1:]
1635
1636
def print_alloc_wrapper(ri, direction):
    """Emit an inline ``<prefix>_alloc(void)`` returning a calloc'ed struct."""
    name = op_prefix(ri, direction)
    struct_type = f'struct {name}'
    ri.cw.write_func_prot(f'static inline {struct_type} *', f"{name}_alloc", ["void"])
    ri.cw.block_start()
    ri.cw.p(f'return calloc(1, sizeof({struct_type}));')
    ri.cw.block_end()
1643
1644
def print_free_prototype(ri, direction, suffix=';'):
    """Emit the prototype of ``<prefix>_free()``.

    When the type name collides with a constant, the struct tag carries a
    trailing underscore, matching the tag used by the struct definition.
    """
    name = op_prefix(ri, direction)
    struct_name = f"{name}_" if ri.type_name_conflict else name
    param = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [param], suffix=suffix)
1652
1653
def _print_type(ri, direction, struct):
    """Emit the C struct definition for *struct*.

    Presence fields (lengths/bits) returned by each member's
    ``presence_member()`` are grouped into an anonymous sub-struct named
    ``_present``. Inherited members come next as ``__u32`` fields, then the
    regular members.
    """
    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'
    if ri.op_mode == 'dump':
        suffix += '_dump'

    cw = ri.cw
    cw.block_start(line=f"struct {ri.family['name']}{suffix}")

    meta_started = False
    for _, attr in struct.member_list():
        for type_filter in ['len', 'bit']:
            line = attr.presence_member(ri.ku_space, type_filter)
            if not line:
                continue
            if not meta_started:
                cw.block_start(line="struct")
                meta_started = True
            cw.p(line)
    if meta_started:
        cw.block_end(line='_present;')
        cw.nl()

    for arg in struct.inherited:
        cw.p(f"__u32 {arg};")

    for member in (attr for _, attr in struct.member_list()):
        member.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1685
1686
def print_type(ri, direction):
    """Emit the struct for the request or reply half of the current op."""
    struct = ri.struct[direction]
    _print_type(ri, direction, struct)
1689
1690
def print_type_full(ri, struct):
    """Emit *struct* under its plain type name (no request/reply suffix)."""
    _print_type(ri, direction="", struct=struct)
1693
1694
def print_type_helpers(ri, direction, deref=False):
    """Emit the free prototype and, for user-space requests, member setters."""
    print_free_prototype(ri, direction)
    ri.cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    if wants_setters:
        members = ri.struct[direction].member_list()
        for member in (attr for _, attr in members):
            member.setter(ri, ri.attr_set, direction, deref=deref)
    ri.cw.nl()
1703
1704
def print_req_type_helpers(ri):
    """Emit the allocator, free prototype and setters for the request type."""
    direction = "request"
    print_alloc_wrapper(ri, direction)
    print_type_helpers(ri, direction)
1708
1709
def print_rsp_type_helpers(ri):
    """Emit reply-type helpers, but only when the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1714
1715
def print_parse_prototype(ri, direction, terminate=True):
    """Emit the prototype of ``<op>_{req,rsp}_parse()``.

    The function takes an attribute table ``tb`` and the struct to fill.
    """
    kind = "_req"
    if direction == "reply":
        kind = "_rsp"
    base = f"{ri.op.render_name}{kind}"
    ri.cw.write_func_prot('void', f"{base}_parse",
                          ['const struct nlattr **tb', f"struct {base} *req"],
                          suffix=';' if terminate else '')
1724
1725
def print_req_type(ri):
    """Emit the C struct for the request side of ri's op."""
    direction = "request"
    print_type(ri, direction)
1728
1729
def print_req_free(ri):
    """Emit the request free() function if the current op mode has a request."""
    mode_spec = ri.op[ri.op_mode]
    if 'request' in mode_spec:
        _free_type(ri, 'request', ri.struct['request'])
1734
1735
def print_rsp_type(ri):
    """Emit the reply struct for do/dump ops that declare a reply, and for events."""
    has_reply = ri.op_mode in ('do', 'dump') and 'reply' in ri.op[ri.op_mode]
    if has_reply or ri.op_mode == 'event':
        print_type(ri, 'reply')
1744
1745
def print_wrapped_type(ri):
    """Emit the list/notification wrapper around the reply struct.

    Dumps get a `next` pointer; notify/event types get family, cmd, a
    generic `next` pointer and a `free` callback. The reply object itself is
    embedded as `obj`, aligned to 8 bytes. Finishes with the free() prototype.
    """
    cw = ri.cw
    cw.block_start(line=f"{type_name(ri, 'reply')}")
    if ri.op_mode == 'dump':
        cw.p(f"{type_name(ri, 'reply')} *next;")
    elif ri.op_mode in ('notify', 'event'):
        for member in ('__u16 family;',
                       '__u8 cmd;',
                       'struct ynl_ntf_base_type *next;'):
            cw.p(member)
        cw.p(f"void (*free)({type_name(ri, 'reply')} *ntf);")
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));")
    cw.block_end(line=';')
    cw.nl()
    print_free_prototype(ri, 'reply')
    cw.nl()
1760
1761
1762 def _free_type_members_iter(ri, struct):
1763     for _, attr in struct.member_list():
1764         if attr.free_needs_iter():
1765             ri.cw.p('unsigned int i;')
1766             ri.cw.nl()
1767             break
1768
1769
1770 def _free_type_members(ri, var, struct, ref=''):
1771     for _, attr in struct.member_list():
1772         attr.free(ri, var, ref)
1773
1774
def _free_type(ri, direction, struct):
    """Emit a complete C free() function for `struct`.

    Members are released first. The object pointer itself is passed to
    free() only when `direction` is non-empty (an op request/reply).
    """
    cw = ri.cw
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
1786
1787
def free_rsp_nested(ri, struct):
    """Emit the free() function for a nested (direction-less) struct."""
    _free_type(ri, direction="", struct=struct)
1790
1791
def print_rsp_free(ri):
    """Emit the reply free() function if the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
1796
1797
def print_dump_type_free(ri):
    """Emit the free() function for a dump reply list.

    The generated C walks the list until YNL_LIST_END, releasing each node's
    `obj.` members and then the node itself.
    """
    cw = ri.cw
    reply = ri.struct['reply']
    node_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{node_type} *next = rsp;")
    cw.nl()
    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    _free_type_members_iter(ri, reply)
    for stmt in ('rsp = next;', 'next = rsp->next;'):
        cw.p(stmt)
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.block_end()
    cw.nl()
1816
1817
def print_ntf_type_free(ri):
    """Emit the free() function for a notification wrapper (`rsp->obj.`)."""
    cw = ri.cw
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1826
1827
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Emit an nla_policy array declaration, or the opening of its definition.

    Args:
        cw: code writer.
        struct: Struct whose attr_max_val sizes the array.
        ri: op render info; when given the array is named after the op
            (plus the op mode for dual-policy ops), else after `struct`.
        terminate: True -> `extern ...;` declaration (skipped for per-op
            policies of genetlink families); False -> `... = {` opener.
    """
    family_struct = kernel_can_gen_family_struct(struct.family)
    if terminate and ri and family_struct:
        return

    if terminate:
        prefix = 'extern '
    elif family_struct and ri:
        prefix = 'static '
    else:
        prefix = ''
    suffix = ';' if terminate else ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1850
1851
def print_req_policy(cw, struct, ri=None):
    """Emit a full nla_policy array definition: opener, one entry per attr, close."""
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
1857
1858
def kernel_can_gen_family_struct(family):
    """Return True if the generator may emit the genl_family struct itself.

    Only plain 'genetlink' families qualify; any other proto value returns False.
    """
    return family.proto in ('genetlink',)
1861
1862
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the ops-table declaration/opener and, in headers, handler prototypes.

    With terminate=False this opens the `<family>_nl_ops` array definition
    (the caller writes entries and closes it). With terminate=True it emits
    an `extern` declaration only when the table is exported (non-genetlink
    families), then prototypes for the pre/post do/dump hooks and for each
    non-async op's doit/dumpit handler.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        if not exported:
            cnt = ""
        elif family.kernel_policy == 'split':
            # split tables carry one entry per do and per dump handler
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            cnt = len(family.ops)

        qual = 'const' if exported else 'static const'
        line = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {line};")
        else:
            cw.block_start(line=line + ' =')

    if not terminate:
        return

    cw.nl()
    doit_hook_args = ['const struct genl_split_ops *ops',
                      'struct sk_buff *skb', 'struct genl_info *info']
    dump_hook_args = ['struct netlink_callback *cb']
    hook_kinds = (('pre', 'do', 'int', doit_hook_args),
                  ('post', 'do', 'void', doit_hook_args),
                  ('pre', 'dump', 'int', dump_hook_args),
                  ('post', 'dump', 'int', dump_hook_args))
    for when, mode, ret_type, hook_args in hook_kinds:
        for hook_name in family.hooks[when][mode]['list']:
            cw.write_func_prot(ret_type, c_lower(hook_name), list(hook_args), suffix=';')

    cw.nl()

    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        handler_kinds = (('do', ['struct sk_buff *skb', 'struct genl_info *info']),
                         ('dump', ['struct sk_buff *skb', 'struct netlink_callback *cb']))
        for mode, handler_args in handler_kinds:
            if mode in op:
                handler = c_lower(f"{family.name}-nl-{op_name}-{mode}it")
                cw.write_func_prot('int', handler, handler_args, suffix=';')
    cw.nl()
1928
1929
def print_kernel_op_table_hdr(family, cw):
    """Emit the header-side declarations for the kernel ops table."""
    header_mode = True
    print_kernel_op_table_fwd(family, cw, terminate=header_mode)
1932
1933
def print_kernel_op_table(family, cw):
    """Emit the kernel ops table definition `<family>_nl_ops`.

    The array opener comes from print_kernel_op_table_fwd(terminate=False).
    For 'global' and 'per-op' kernel policies one initializer is written per
    non-async op; for 'split' one initializer is written per (op, do/dump)
    pair. The array is closed with '};'.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                # e.g. ['strict', 'dump'] -> GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): reads op['do']['request'] unconditionally; an op
                # without a 'do' request would raise KeyError here — confirm the
                # spec schema guarantees one for per-op families.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Field names used for the spec's 'pre'/'post' hooks, per op mode.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    members.append(('validate',
                                    ' | '.join([c_upper('genl-dont-validate-' + x)
                                                for x in op['dont-validate']])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have separate do/dump policies, so the
                    # policy name carries the op mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Each split entry is tagged with the capability it implements
                # (GENL_CMD_CAP_DO / GENL_CMD_CAP_DUMP) plus any spec flags.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')

    cw.block_end(line=';')
    cw.nl()
2000
2001
def print_kernel_mcgrp_hdr(family, cw):
    """Emit an anonymous enum of multicast group IDs; nothing if there are none."""
    groups = family.mcgrps['list']
    if groups:
        cw.block_start('enum')
        for grp in groups:
            cw.p(c_upper(f"{family.name}-nlgrp-{grp['name']},"))
        cw.block_end(';')
        cw.nl()
2012
2013
def print_kernel_mcgrp_src(family, cw):
    """Emit the genl_multicast_group array indexed by the group-ID enum."""
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f"static const struct genl_multicast_group {family.name}_nl_mcgrps[] =")
    for grp in groups:
        grp_name = grp['name']
        grp_id = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p(f'[{grp_id}] = {{ "{grp_name}", }},')
    cw.block_end(';')
    cw.nl()
2025
2026
def print_kernel_family_struct_hdr(family, cw):
    """Emit the extern declaration of the genl_family, for genetlink families only."""
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
2033
2034
def print_kernel_family_struct_src(family, cw):
    """Emit the genl_family definition, for genetlink families only.

    The ops fields depend on the kernel policy: 'per-op' uses .ops/.n_ops,
    'split' uses .split_ops/.n_split_ops, and other policies emit neither.
    Multicast group fields are emitted only when groups exist.
    """
    if not kernel_can_gen_family_struct(family):
        return

    fam = family.name
    cw.block_start(f"struct genl_family {fam}_nl_family __ro_after_init =")
    for field in ('.name\t\t= ' + family.fam_key + ',',
                  '.version\t= ' + family.ver_key + ',',
                  '.netnsok\t= true,',
                  '.parallel_ops\t= true,',
                  '.module\t\t= THIS_MODULE,'):
        cw.p(field)

    ops_fields = {'per-op': ('.ops\t\t', '.n_ops\t\t'),
                  'split': ('.split_ops\t', '.n_split_ops\t')}.get(family.kernel_policy)
    if ops_fields:
        ops_field, count_field = ops_fields
        cw.p(f'{ops_field}= {fam}_nl_ops,')
        cw.p(f'{count_field}= ARRAY_SIZE({fam}_nl_ops),')

    if family.mcgrps['list']:
        cw.p(f'.mcgrps\t\t= {fam}_nl_mcgrps,')
        cw.p(f'.n_mcgrps\t= ARRAY_SIZE({fam}_nl_mcgrps),')
    cw.block_end(';')
2055
2056
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a C `enum` block for `obj`.

    If obj has the `enum_name` key, its value names the enum (an empty value
    gives an anonymous enum). Otherwise, if `ckey` is a key of obj, the enum
    is named `<family>_<obj[ckey]>`. Failing both, the enum is anonymous.
    """
    if enum_name in obj:
        explicit = obj[enum_name]
        start_line = 'enum ' + c_lower(explicit) if explicit else 'enum'
    elif ckey and ckey in obj:
        start_line = 'enum ' + family.name + '_' + c_lower(obj[ckey])
    else:
        start_line = 'enum'
    cw.block_start(line=start_line)
2065
2066
def render_uapi(family, cw):
    """Emit the UAPI header for `family` through the code writer `cw`.

    Output order: include guard; family name/version defines; the spec's
    'definitions' (consts as #defines, enums/flags as C enums with kdoc);
    one enum per attribute set that is not a subset; the command enum; an
    optional separate enum for async messages; multicast group name
    defines; closing include guard.
    """
    hdr_prot = f"_UAPI_LINUX_{family.name.upper()}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are batched in `defines` and flushed
    # whenever a non-const definition is reached (and once more after the loop).
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                # An explicit "= value" is written only when entry.value_change
                # is set (presumably where implicit C numbering would differ).
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                # flags: a <PREFIX>MASK enumerator set to enum.get_mask();
                # enums: __<PREFIX>MAX count plus <PREFIX>MAX = count - 1.
                if const['type'] == 'flags':
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.name}-{const['name']}")),
                            const['value']])

    if defines:
        cw.writes_defines(defines)
        cw.nl()

    # With max-by-define, the MAX symbol becomes a #define after the enum
    # instead of a final enumerator.
    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))
        max_value = f"({cnt_name} - 1)"

        # `val` tracks the value C would assign implicitly, so "= N" is
        # emitted only for attributes whose value differs from it.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    # With 'async-prefix', notify/event messages go into their own enum below
    # and are left out of the main command enum.
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            # Only notify/event messages belong here (the separate_ntf test is
            # always true inside this branch).
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2210
2211
def _render_user_ntf_entry(ri, op):
    """Emit one `[ENUM] = { ... },` entry of the user-space ynl_ntf_info table."""
    cw = ri.cw
    cw.block_start(line=f"[{op.enum_name}] = ")
    fields = (
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    )
    for field in fields:
        cw.p(field)
    cw.block_end(line=',')
2219
2220
def render_user_family(family, cw, prototype):
    """Emit the user-space `ynl_<family>_family` descriptor.

    With prototype=True only the extern declaration is written. Otherwise,
    if the family has notifications, the `<name>_ntf_info` table is emitted
    first (from family.ntfs, then any op carrying 'event'), followed by the
    descriptor definition.

    Raises:
        Exception: a notification has neither 'notify' nor 'event'.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family['name']}_ntf_info[] = ")
        for ntf_op_name, ntf_op in family.ntfs.items():
            if 'notify' in ntf_op:
                ri = RenderInfo(cw, family, "user", family.ops[ntf_op['notify']], "notify")
            elif 'event' in ntf_op:
                ri = RenderInfo(cw, family, "user", ntf_op, "event")
            else:
                raise Exception('Invalid notification ' + ntf_op_name)
            _render_user_ntf_entry(ri, ntf_op)
        for event_op in family.ops.values():
            if 'event' in event_op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", event_op, "event"),
                                       event_op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.name}",')
    if family.ntfs:
        cw.p(f".ntf_info\t= {family['name']}_ntf_info,")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({family['name']}_ntf_info),")
    cw.block_end(line=';')
2252
2253
def find_kernel_root(full_path):
    """Locate the kernel source root containing `full_path`.

    Walks up the directory chain of `full_path` until a directory holding a
    MAINTAINERS file is found.

    Returns:
        (root, sub_path): the directory containing MAINTAINERS, and the
        path of `full_path` relative to it.

    Raises:
        FileNotFoundError: no ancestor contains MAINTAINERS. Previously this
            case looped forever, because os.path.dirname() of '/' (or of '')
            returns its argument unchanged.
    """
    start = full_path
    sub_path = ''
    while True:
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if parent == full_path:
            # Reached the filesystem root (or the start of a relative path)
            # without finding MAINTAINERS; stop instead of spinning forever.
            raise FileNotFoundError(f"no MAINTAINERS file found above {start!r}")
        full_path = parent
        maintainers = os.path.join(full_path, "MAINTAINERS")
        if os.path.exists(maintainers):
            # Drop the trailing separator left by the last os.path.join().
            return full_path, sub_path[:-1]
2262
2263
def main():
    """Command-line entry point: parse arguments, load the spec, emit C code.

    The -o output file is opened only after the arguments, the spec license
    and the message-id model have been validated. A failed run therefore no
    longer truncates a previously generated file. The file is closed when
    generation finishes; stdout is never closed.
    """
    parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
    parser.add_argument('--mode', dest='mode', type=str, required=True)
    parser.add_argument('--spec', dest='spec', type=str, required=True)
    parser.add_argument('--header', dest='header', action='store_true', default=None)
    parser.add_argument('--source', dest='header', action='store_false')
    parser.add_argument('--user-header', nargs='+', default=[])
    parser.add_argument('--exclude-op', action='append', default=[])
    parser.add_argument('-o', dest='out_file', type=str)
    args = parser.parse_args()

    # parser.error() exits; do it before any output file is created/truncated.
    if args.header is None:
        parser.error("--header or --source is required")

    exclude_ops = [re.compile(expr) for expr in args.exclude_op]

    try:
        parsed = Family(args.spec, exclude_ops)
        if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
            print('Spec license:', parsed.license)
            print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
            os.sys.exit(1)
    except yaml.YAMLError as exc:
        print(exc)
        os.sys.exit(1)

    supported_models = ['unified']
    if args.mode == 'user':
        supported_models += ['directional']
    if parsed.msg_id_model not in supported_models:
        print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
        os.sys.exit(1)

    out_file = open(args.out_file, 'w+') if args.out_file else os.sys.stdout
    try:
        cw = CodeWriter(BaseNlLib(), out_file)
        _render_output(args, parsed, cw)
    finally:
        if out_file is not os.sys.stdout:
            out_file.close()


def _render_output(args, parsed, cw):
    """Write the generated uapi/kernel/user file for `parsed` to `cw` per `args`."""
    _, spec_kernel = find_kernel_root(args.spec)
    if args.mode == 'uapi' or args.header:
        cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
    else:
        cw.p(f'// SPDX-License-Identifier: {parsed.license}')
    cw.p("/* Do not edit directly, auto-generated from: */")
    cw.p(f"/*\t{spec_kernel} */")
    cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
    if args.exclude_op or args.user_header:
        line = ''
        line += ' --user-header '.join([''] + args.user_header)
        line += ' --exclude-op '.join([''] + args.exclude_op)
        cw.p(f'/* YNL-ARG{line} */')
    cw.nl()

    if args.mode == 'uapi':
        render_uapi(parsed, cw)
        return

    hdr_prot = f"_LINUX_{parsed.name.upper()}_GEN_H"
    if args.header:
        cw.p('#ifndef ' + hdr_prot)
        cw.p('#define ' + hdr_prot)
        cw.nl()

    if args.mode == 'kernel':
        cw.p('#include <net/netlink.h>')
        cw.p('#include <net/genetlink.h>')
        cw.nl()
        if not args.header:
            if args.out_file:
                cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
            cw.nl()
        headers = ['uapi/' + parsed.uapi_header]
    else:
        cw.p('#include <stdlib.h>')
        cw.p('#include <string.h>')
        if args.header:
            cw.p('#include <linux/types.h>')
        else:
            cw.p(f'#include "{parsed.name}-user.h"')
            cw.p('#include "ynl.h"')
        headers = [parsed.uapi_header]
    for definition in parsed['definitions']:
        if 'header' in definition:
            headers.append(definition['header'])
    for one in headers:
        cw.p(f"#include <{one}>")
    cw.nl()

    if args.mode == "user":
        if not args.header:
            cw.p("#include <libmnl/libmnl.h>")
            cw.p("#include <linux/genetlink.h>")
            cw.nl()
            for one in args.user_header:
                cw.p(f'#include "{one}"')
        else:
            cw.p('struct ynl_sock;')
            cw.nl()
            render_user_family(parsed, cw, True)
        cw.nl()

    if args.mode == "kernel":
        if args.header:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy_fwd(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy_fwd(cw, struct)
                cw.nl()

            if parsed.kernel_policy in {'per-op', 'split'}:
                for op_name, op in parsed.ops.items():
                    if 'do' in op and 'event' not in op:
                        ri = RenderInfo(cw, parsed, args.mode, op, "do")
                        print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
                        cw.nl()

            print_kernel_op_table_hdr(parsed, cw)
            print_kernel_mcgrp_hdr(parsed, cw)
            print_kernel_family_struct_hdr(parsed, cw)
        else:
            for _, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    cw.p('/* Common nested types */')
                    break
            for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
                if struct.request:
                    print_req_policy(cw, struct)
            cw.nl()

            if parsed.kernel_policy == 'global':
                cw.p(f"/* Global operation policy for {parsed.name} */")

                struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
                print_req_policy(cw, struct)
                cw.nl()

            for op_name, op in parsed.ops.items():
                if parsed.kernel_policy in {'per-op', 'split'}:
                    for op_mode in ['do', 'dump']:
                        if op_mode in op and 'request' in op[op_mode]:
                            cw.p(f"/* {op.enum_name} - {op_mode} */")
                            ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
                            print_req_policy(cw, ri.struct['request'], ri=ri)
                            cw.nl()

            print_kernel_op_table(parsed, cw)
            print_kernel_mcgrp_src(parsed, cw)
            print_kernel_family_struct_src(parsed, cw)

    if args.mode == "user":
        if args.header:
            cw.p('/* Enums */')
            put_op_name_fwd(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str_fwd(parsed, cw, const)
            cw.nl()

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
                print_type_full(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")

                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_type(ri)
                    print_req_type_helpers(ri)
                    cw.nl()
                    print_rsp_type(ri)
                    print_rsp_type_helpers(ri)
                    cw.nl()
                    print_req_prototype(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
                    if 'request' in op['dump']:
                        print_req_type(ri)
                        print_req_type_helpers(ri)
                    if not ri.type_consistent:
                        print_rsp_type(ri)
                    print_wrapped_type(ri)
                    print_dump_prototype(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_wrapped_type(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    ri = RenderInfo(cw, parsed, args.mode, op, 'event')
                    cw.p(f"/* {op.enum_name} - event */")
                    print_rsp_type(ri)
                    cw.nl()
                    print_wrapped_type(ri)
            cw.nl()
        else:
            cw.p('/* Enums */')
            put_op_name(parsed, cw)

            for name, const in parsed.consts.items():
                if isinstance(const, EnumSet):
                    put_enum_to_str(parsed, cw, const)
            cw.nl()

            cw.p('/* Policies */')
            for name in parsed.pure_nested_structs:
                struct = Struct(parsed, name)
                put_typol(cw, struct)
            for name in parsed.root_sets:
                struct = Struct(parsed, name)
                put_typol(cw, struct)

            cw.p('/* Common nested types */')
            for attr_set, struct in parsed.pure_nested_structs.items():
                ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)

                free_rsp_nested(ri, struct)
                if struct.request:
                    put_req_nested(ri, struct)
                if struct.reply:
                    parse_rsp_nested(ri, struct)

            for op_name, op in parsed.ops.items():
                cw.p(f"/* ============== {op.enum_name} ============== */")
                if 'do' in op and 'event' not in op:
                    cw.p(f"/* {op.enum_name} - do */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    print_req_free(ri)
                    print_rsp_free(ri)
                    parse_rsp_msg(ri)
                    print_req(ri)
                    cw.nl()

                if 'dump' in op:
                    cw.p(f"/* {op.enum_name} - dump */")
                    ri = RenderInfo(cw, parsed, args.mode, op, "dump")
                    if not ri.type_consistent:
                        parse_rsp_msg(ri, deref=True)
                    print_dump_type_free(ri)
                    print_dump(ri)
                    cw.nl()

                if op.has_ntf:
                    cw.p(f"/* {op.enum_name} - notify */")
                    ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
                    if not ri.type_consistent:
                        raise Exception(f'Only notifications with consistent types supported ({op.name})')
                    print_ntf_type_free(ri)

            for op_name, op in parsed.ntfs.items():
                if 'event' in op:
                    cw.p(f"/* {op.enum_name} - event */")

                    ri = RenderInfo(cw, parsed, args.mode, op, "do")
                    parse_rsp_msg(ri)

                    ri = RenderInfo(cw, parsed, args.mode, op, "event")
                    print_ntf_type_free(ri)
            cw.nl()
            render_user_family(parsed, cw, False)

    if args.header:
        cw.p(f'#endif /* {hdr_prot} */')
2548
2549
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()