e27deb199a707f45fd16cf2cc397817c4becd0ae
[platform/kernel/linux-rpi.git] / tools / net / ynl / ynl-gen-c.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)
3
4 import argparse
5 import collections
6 import os
7 import re
8 import shutil
9 import tempfile
10 import yaml
11
12 from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
13
14
def c_upper(name):
    """Return ``name`` upper-cased with every dash turned into an underscore.

    Example: ``'rx-packets'`` becomes ``'RX_PACKETS'``.
    """
    pieces = name.upper().split('-')
    return '_'.join(pieces)
17
18
def c_lower(name):
    """Return ``name`` lower-cased with every dash turned into an underscore.

    Example: ``'Rx-Packets'`` becomes ``'rx_packets'``.
    """
    pieces = name.lower().split('-')
    return '_'.join(pieces)
21
22
class BaseNlLib:
    """Provider of the C expressions the generated code uses for netlink I/O."""

    def get_family_id(self):
        """Return the C expression that reads the family id."""
        return 'ys->family_id'

    def parse_cb_run(self, cb, data, is_dump=False, indent=1):
        """Return the C ``mnl_cb_run2()`` call that parses ``ys->rx_buf``.

        Args:
            cb: C expression for the data callback.
            data: C expression for the callback's data argument.
            is_dump: when true, 0 is passed for both the sequence number and
                the port id instead of ``ys->seq`` / ``ys->portid``.
            indent: number of extra tabs used on continuation lines.
        """
        # Line break plus alignment used wherever the call is wrapped.
        cont = '\n\t\t' + '\t' * indent + ' '
        if is_dump:
            head = f"mnl_cb_run2(ys->rx_buf, len, 0, 0, {cb}, {data},"
        else:
            head = f"mnl_cb_run2(ys->rx_buf, len, ys->seq, ys->portid,{cont}{cb}, {data},"
        return head + cont + "ynl_cb_array, NLMSG_MIN_TYPE)"
34
35
class Type(SpecAttr):
    """Base class for rendering the C code of a single attribute.

    Subclasses model one attribute type each and override the hooks that
    differ for it.  Methods here either provide the shared default or raise
    to signal that the concrete type has to supply an implementation.

    Throughout the class ``ri`` is an object exposing a code writer as
    ``ri.cw`` (with ``p``, ``nl``, ``block_start``, ``block_end`` and
    ``write_func``), and ``var`` is the name of the C struct pointer the
    generated code operates on.
    """

    def __init__(self, family, attr_set, attr, value):
        """Store the spec data of the attribute and derive its C names.

        Args:
            family: owning family; ``family.name`` and ``family.consts`` are
                used to name nested structures.
            attr_set: attribute set the attribute is a member of.
            attr: spec mapping of the attribute; ``type`` is mandatory,
                ``checks``, ``len`` and ``nested-attributes`` are optional.
            value: forwarded to ``SpecAttr`` together with the other
                arguments.
        """
        super().__init__(family, attr_set, attr, value)

        self.attr = attr
        self.attr_set = attr_set
        self.type = attr['type']
        self.checks = attr.get('checks', {})

        if 'len' in attr:
            self.len = attr['len']
        if 'nested-attributes' in attr:
            self.nested_attrs = attr['nested-attributes']
            if self.nested_attrs == family.name:
                self.nested_render_name = f"{family.name}"
            else:
                self.nested_render_name = f"{family.name}_{c_lower(self.nested_attrs)}"

            # The struct name gets a trailing underscore when the nested set
            # name is also a key of family.consts.
            # NOTE(review): presumably avoids a C name clash - confirm.
            if self.nested_attrs in self.family.consts:
                self.nested_struct_type = 'struct ' + self.nested_render_name + '_'
            else:
                self.nested_struct_type = 'struct ' + self.nested_render_name

        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name += '_'

        # Added by resolve():
        # (assigned and deleted right away, so the name is listed here while
        # reading it before resolve() still raises AttributeError)
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute ``enum_name``, the C enum constant of the attribute.

        The attribute's own ``name-prefix`` wins over the prefix of the
        attribute set.
        """
        if 'name-prefix' in self.attr:
            enum_name = f"{self.attr['name-prefix']}{self.name}"
        else:
            enum_name = f"{self.attr_set.name_prefix}{self.name}"
        self.enum_name = c_upper(enum_name)

    def is_multi_val(self):
        """Tell whether the attribute holds several values; falsy by default."""
        return None

    def is_scalar(self):
        """Tell whether ``type`` is one of the listed integer types."""
        return self.type in {'u8', 'u16', 'u32', 'u64', 's32', 's64'}

    def presence_type(self):
        """Return how presence is tracked; ``'bit'`` unless overridden.

        Other values used in this file are ``'len'``, ``'count'`` and ``''``.
        """
        return 'bit'

    def presence_member(self, space, type_filter):
        """Return the C declaration of the presence member, if any.

        Args:
            space: ``'user'`` selects the ``__u32`` spelling, anything else
                plain ``u32``.
            type_filter: only produce output when it equals
                ``presence_type()``.

        Returns:
            The declaration string for ``'bit'`` and ``'len'`` presence,
            otherwise ``None``.
        """
        if self.presence_type() != type_filter:
            return

        if self.presence_type() == 'bit':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}:1;"

        if self.presence_type() == 'len':
            pfx = '__' if space == 'user' else ''
            return f"{pfx}u32 {self.c_name}_len;"

    def _complex_member_type(self, ri):
        """Return the C type of a non-trivial member, ``None`` by default."""
        return None

    def free_needs_iter(self):
        """Tell whether the free code needs a loop variable."""
        return False

    def free(self, ri, var, ref):
        """Emit ``free()`` of the member for multi-value and 'len' types.

        ``ref`` is inserted verbatim between ``var->`` and the member name.
        """
        if self.is_multi_val() or self.presence_type() == 'len':
            ri.cw.p(f'free({var}->{ref}{self.c_name});')

    def arg_member(self, ri):
        """Return the C parameter declarations used to pass the value.

        The default handles types with a complex member type: a pointer,
        plus an element count when presence is tracked as ``'count'``.

        Raises:
            Exception: if the type has no complex member type and did not
                override this method.
        """
        member = self._complex_member_type(ri)
        if member:
            arg = [member + ' *' + self.c_name]
            if self.presence_type() == 'count':
                arg += ['unsigned int n_' + self.c_name]
            return arg
        raise Exception(f"Struct member not implemented for class type {self.type}")

    def struct_member(self, ri):
        """Emit the struct member declaration(s) of the attribute."""
        if self.is_multi_val():
            ri.cw.p(f"unsigned int n_{self.c_name};")
        member = self._complex_member_type(ri)
        if member:
            # Multi-value members are stored as a pointer to an array.
            ptr = '*' if self.is_multi_val() else ''
            ri.cw.p(f"{member} {ptr}{self.c_name};")
            return
        members = self.arg_member(ri)
        for one in members:
            ri.cw.p(one + ';')

    def _attr_policy(self, policy):
        """Return the policy initializer for the given NLA type name."""
        return '{ .type = ' + policy + ', }'

    def attr_policy(self, cw):
        """Emit the policy table entry; the NLA type is derived from ``type``."""
        policy = c_upper('nla-' + self.attr['type'])

        spec = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {spec},")

    def _attr_typol(self):
        """Return the type-policy fields; must be provided by subclasses."""
        raise Exception(f"Type policy not implemented for class type {self.type}")

    def attr_typol(self, cw):
        """Emit the type-policy table entry (name plus subclass fields)."""
        typol = self._attr_typol()
        cw.p(f'[{self.enum_name}] = {{ .name = "{self.name}", {typol}}},')

    def _attr_put_line(self, ri, var, line):
        """Emit ``line`` as a statement, guarded by the presence check.

        No guard is emitted for presence types other than 'bit' and 'len'.
        """
        if self.presence_type() == 'bit':
            ri.cw.p(f"if ({var}->_present.{self.c_name})")
        elif self.presence_type() == 'len':
            ri.cw.p(f"if ({var}->_present.{self.c_name}_len)")
        ri.cw.p(f"{line};")

    def _attr_put_simple(self, ri, var, put_type):
        """Emit a guarded ``mnl_attr_put_<put_type>()`` of the member."""
        line = f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name})"
        self._attr_put_line(ri, var, line)

    def attr_put(self, ri, var):
        """Emit the code adding the attribute to a message (subclass hook)."""
        raise Exception(f"Put not implemented for class type {self.type}")

    def _attr_get(self, ri, var):
        """Return ``(lines, init_lines, local_vars)`` for parsing (subclass hook).

        ``lines`` and ``init_lines`` may each be a single string or a list of
        strings; ``init_lines`` and ``local_vars`` may be ``None``.
        """
        raise Exception(f"Attr get not implemented for class type {self.type}")

    def attr_get(self, ri, var, first):
        """Emit the ``if (type == ...)`` block that parses the attribute.

        Args:
            ri: render object providing ``ri.cw``.
            var: name of the C struct pointer receiving the parsed value.
            first: true for the first attribute of the chain, which opens
                with ``if``; the others use ``else if``.

        Returns:
            Always ``True``.
        """
        lines, init_lines, local_vars = self._attr_get(ri, var)
        # Subclasses may hand back a lone string in place of a list.
        if isinstance(lines, str):
            lines = [lines]
        if isinstance(init_lines, str):
            init_lines = [init_lines]

        kw = 'if' if first else 'else if'
        ri.cw.block_start(line=f"{kw} (type == {self.enum_name})")
        if local_vars:
            for local in local_vars:
                ri.cw.p(local)
            ri.cw.nl()

        if not self.is_multi_val():
            ri.cw.p("if (ynl_attr_validate(yarg, attr))")
            ri.cw.p("return MNL_CB_ERROR;")
            if self.presence_type() == 'bit':
                ri.cw.p(f"{var}->_present.{self.c_name} = 1;")

        if init_lines:
            ri.cw.nl()
            for line in init_lines:
                ri.cw.p(line)

        for line in lines:
            ri.cw.p(line)
        ri.cw.block_end()
        return True

    def _setter_lines(self, ri, member, presence):
        """Return the C statements storing the value (subclass hook).

        Args:
            ri: render object.
            member: C expression of the destination member.
            presence: C expression of the innermost presence member.
        """
        raise Exception(f"Setter not implemented for class type {self.type}")

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit the ``static inline`` setter function of the attribute.

        Args:
            ri: render object providing ``ri.cw``.
            space: not used by this implementation.
            direction: forwarded to ``op_prefix()`` and ``type_name()``.
            deref: forwarded to ``op_prefix()`` and ``type_name()``.
            ref: names of the enclosing members when the attribute sits
                inside nested structures; ``None`` for a top-level member.
        """
        ref = (ref if ref else []) + [self.c_name]
        var = "req"
        member = f"{var}->{'.'.join(ref)}"

        code = []
        presence = ''
        uses_bit = self.presence_type() == 'bit'
        for i, name in enumerate(ref):
            # Presence of level i lives in the _present struct of its parent.
            presence = f"{var}->{'.'.join(ref[:i] + [''])}_present.{name}"
            if uses_bit:
                code.append(presence + ' = 1;')
        code += self._setter_lines(ri, member, presence)

        func_name = f"{op_prefix(ri, direction, deref=deref)}_set_{'_'.join(ref)}"
        free = any('free(' in x for x in code)
        alloc = any('alloc(' in x for x in code)
        # Setters that free the old value but do not allocate get the '__'
        # name prefix.
        if free and not alloc:
            func_name = '__' + func_name
        ri.cw.write_func('static inline void', func_name, body=code,
                         args=[f'{type_name(ri, direction, deref=deref)} *{var}'] + self.arg_member(ri))
212
213
class TypeUnused(Type):
    """Attribute of type ``unused``: no member, no policy entry."""

    def presence_type(self):
        """No presence information is kept."""
        return ''

    def arg_member(self, ri):
        """No member or argument is rendered."""
        return list()

    def _attr_get(self, ri, var):
        """Parsing code consists of a single error return."""
        get_lines = ['return MNL_CB_ERROR;']
        return get_lines, None, None

    def _attr_typol(self):
        """Type-policy fields of the attribute."""
        return '.type = YNL_PT_REJECT, '

    def attr_policy(self, cw):
        """Nothing is written to the policy table."""
        return None
229
230
class TypePad(Type):
    """Attribute of type ``pad``.

    It renders no member and no put, get, policy or setter code; only the
    type-policy entry is produced.
    """

    def presence_type(self):
        """No presence information is kept."""
        return ''

    def arg_member(self, ri):
        """No member or argument is rendered."""
        return list()

    def _attr_typol(self):
        """Type-policy fields of the attribute."""
        return '.type = YNL_PT_IGNORE, '

    def attr_put(self, ri, var):
        """Nothing is emitted."""
        return None

    def attr_get(self, ri, var, first):
        """Nothing is emitted; unlike the base class this returns ``None``."""
        return None

    def attr_policy(self, cw):
        """Nothing is written to the policy table."""
        return None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """No setter is generated."""
        return None
252
253
class TypeScalar(Type):
    """Integer attribute, optionally typed by an enum or used as flags."""

    def __init__(self, family, attr_set, attr, value):
        """Set up the base type and the optional byte-order comment."""
        super().__init__(family, attr_set, attr, value)

        if 'byte-order' in attr:
            self.byte_order_comment = f" /* {attr['byte-order']} */"
        else:
            self.byte_order_comment = ''

        # Added by resolve():
        # (created and deleted at once: named here, unreadable until resolve())
        self.is_bitfield = None
        delattr(self, "is_bitfield")
        self.type_name = None
        delattr(self, "type_name")

    def resolve(self):
        """Determine ``is_bitfield`` and the C ``type_name`` of the member."""
        self.resolve_up(super())

        spec = self.attr
        if spec.get('enum-as-flags'):
            self.is_bitfield = True
        elif 'enum' not in spec:
            self.is_bitfield = False
        else:
            self.is_bitfield = self.family.consts[spec['enum']]['type'] == 'flags'

        # Use the enum type only for non-flag attributes whose enum has a name.
        plain_enum = 'enum' in spec and not self.is_bitfield
        if plain_enum and self.family.consts[spec['enum']].enum_name:
            self.type_name = f"enum {self.family.name}_{c_lower(spec['enum'])}"
        else:
            self.type_name = '__' + self.type

    def _mnl_type(self):
        """Return the type suffix of the mnl put/get helper to use."""
        kind = self.type
        # mnl does not have a helper for signed types
        return 'u' + kind[1:] if kind[0] == 's' else kind

    def _attr_policy(self, policy):
        """Return the policy initializer, honouring masks, minimum and enums."""
        checks = self.checks
        if self.is_bitfield or 'flags-mask' in checks:
            if self.is_bitfield:
                mask = self.family.consts[self.attr['enum']].get_mask(as_flags=True)
            else:
                # One bit per entry of the referenced flags definition.
                flag_entries = self.family.consts[checks['flags-mask']]['entries']
                mask = (1 << len(flag_entries)) - 1
            return f"NLA_POLICY_MASK({policy}, 0x{mask:x})"

        if 'min' in checks:
            return f"NLA_POLICY_MIN({policy}, {checks['min']})"

        if 'enum' in self.attr:
            low, high = self.family.consts[self.attr['enum']].value_range()
            if low == 0:
                return f"NLA_POLICY_MAX({policy}, {high})"
            return f"NLA_POLICY_RANGE({policy}, {low}, {high})"

        return super()._attr_policy(policy)

    def _attr_typol(self):
        """Type-policy fields; built from the type name minus its first letter."""
        width = self.type[1:]
        return f'.type = YNL_PT_U{width}, '

    def arg_member(self, ri):
        """Single value parameter, followed by the byte-order comment."""
        decl = f'{self.type_name} {self.c_name}{self.byte_order_comment}'
        return [decl]

    def attr_put(self, ri, var):
        """Emit a guarded ``mnl_attr_put_<type>()``."""
        self._attr_put_simple(ri, var, self._mnl_type())

    def _attr_get(self, ri, var):
        """Parsing code: read the value with ``mnl_attr_get_<type>()``."""
        getter = f"mnl_attr_get_{self._mnl_type()}(attr)"
        return f"{var}->{self.c_name} = {getter};", None, None

    def _setter_lines(self, ri, member, presence):
        """Setter body: plain assignment of the argument."""
        assignment = f"{member} = {self.c_name};"
        return [assignment]
325
326
class TypeFlag(Type):
    """Attribute without payload; only its presence is recorded."""

    def arg_member(self, ri):
        """A flag needs no value member or argument."""
        return list()

    def _attr_typol(self):
        """Type-policy fields of the attribute."""
        return '.type = YNL_PT_FLAG, '

    def attr_put(self, ri, var):
        """Emit a guarded put of an empty attribute."""
        call = f"mnl_attr_put(nlh, {self.enum_name}, 0, NULL)"
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        """No parsing statements beyond what ``attr_get`` emits itself."""
        return list(), None, None

    def _setter_lines(self, ri, member, presence):
        """No statements beyond the presence marking done by ``setter``."""
        return list()
342
343
class TypeString(Type):
    """String attribute, held as a heap copy with its length in ``_present``."""

    def arg_member(self, ri):
        """Setter parameter: a constant C string."""
        return ["const char *" + self.c_name]

    def presence_type(self):
        """Presence is tracked through a length member."""
        return 'len'

    def struct_member(self, ri):
        """Emit the ``char *`` member."""
        ri.cw.p("char *" + self.c_name + ";")

    def _attr_typol(self):
        """Type-policy fields of the attribute."""
        return '.type = YNL_PT_NUL_STR, '

    def _attr_policy(self, policy):
        """Return the policy initializer, with ``.len`` when ``max-len`` is set."""
        fields = ['.type = ' + policy]
        if 'max-len' in self.checks:
            fields.append('.len = ' + str(self.checks['max-len']))
        return '{ ' + ', '.join(fields) + ', }'

    def attr_policy(self, cw):
        """Emit the policy entry; ``unterminated-ok`` selects ``NLA_STRING``."""
        unterminated_ok = self.checks.get('unterminated-ok', False)
        policy = 'NLA_STRING' if unterminated_ok else 'NLA_NUL_STRING'

        entry = self._attr_policy(policy)
        cw.p(f"\t[{self.enum_name}] = {entry},")

    def attr_put(self, ri, var):
        """Emit a guarded ``mnl_attr_put_strz()``."""
        self._attr_put_simple(ri, var, 'strz')

    def _attr_get(self, ri, var):
        """Parsing code: measure the string, then copy and terminate it."""
        dst = f"{var}->{self.c_name}"
        local_vars = ['unsigned int len;']
        init_lines = ['len = strnlen(mnl_attr_get_str(attr), mnl_attr_get_payload_len(attr));']
        get_lines = [
            f"{var}->_present.{self.c_name}_len = len;",
            f"{dst} = malloc(len + 1);",
            f"memcpy({dst}, mnl_attr_get_str(attr), len);",
            f"{dst}[len] = 0;",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Setter body: drop the old value and store a terminated copy."""
        length = f"{presence}_len"
        return [
            f"free({member});",
            f"{length} = strlen({self.c_name});",
            f"{member} = malloc({length} + 1);",
            f"memcpy({member}, {self.c_name}, {length});",
            f"{member}[{length}] = 0;",
        ]
391
392
class TypeBinary(Type):
    """Opaque byte-blob attribute, held as a heap copy plus length."""

    def arg_member(self, ri):
        """Setter parameters: data pointer and length."""
        return ["const void *" + self.c_name, 'size_t len']

    def presence_type(self):
        """Presence is tracked through a length member."""
        return 'len'

    def struct_member(self, ri):
        """Emit the ``void *`` member."""
        ri.cw.p("void *" + self.c_name + ";")

    def _attr_typol(self):
        """Type-policy fields of the attribute."""
        return '.type = YNL_PT_BINARY,'

    def _attr_policy(self, policy):
        """Return the policy initializer.

        Only "no checks" and "exactly ``min-len``" are supported.

        Raises:
            Exception: for any other combination of checks.
        """
        n_checks = len(self.checks)
        if n_checks == 0:
            field = '.type = NLA_BINARY'
        elif n_checks == 1 and 'min-len' in self.checks:
            field = '.len = ' + str(self.checks['min-len'])
        else:
            raise Exception('One or more of binary type checks not implemented, yet')
        return '{ ' + field + ', }'

    def attr_put(self, ri, var):
        """Emit a guarded ``mnl_attr_put()`` of the stored bytes."""
        call = (f"mnl_attr_put(nlh, {self.enum_name}, "
                f"{var}->_present.{self.c_name}_len, {var}->{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        """Parsing code: copy the payload and record its length."""
        dst = f"{var}->{self.c_name}"
        local_vars = ['unsigned int len;']
        init_lines = ['len = mnl_attr_get_payload_len(attr);']
        get_lines = [
            f"{var}->_present.{self.c_name}_len = len;",
            f"{dst} = malloc(len);",
            f"memcpy({dst}, mnl_attr_get_payload(attr), len);",
        ]
        return get_lines, init_lines, local_vars

    def _setter_lines(self, ri, member, presence):
        """Setter body: drop the old value and store a copy of ``len`` bytes."""
        length = f"{presence}_len"
        return [
            f"free({member});",
            f"{length} = len;",
            f"{member} = malloc({length});",
            f"memcpy({member}, {self.c_name}, {length});",
        ]
434
435
class TypeNest(Type):
    """Attribute carrying one nested attribute set, embedded as a struct."""

    def _complex_member_type(self, ri):
        """The member is the nested struct itself."""
        return self.nested_struct_type

    def free(self, ri, var, ref):
        """Emit the call to the nested struct's ``_free`` helper."""
        target = f"&{var}->{ref}{self.c_name}"
        ri.cw.p(f'{self.nested_render_name}_free({target});')

    def _attr_typol(self):
        """Type-policy fields, pointing at the nested set's ``_nest`` object."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_policy(self, policy):
        """Policy initializer referencing the nested set's policy table."""
        return f"NLA_POLICY_NESTED({self.nested_render_name}_nl_policy)"

    def attr_put(self, ri, var):
        """Emit a guarded call to the nested struct's ``_put`` helper."""
        call = (f"{self.nested_render_name}_put(nlh, "
                f"{self.enum_name}, &{var}->{self.c_name})")
        self._attr_put_line(ri, var, call)

    def _attr_get(self, ri, var):
        """Parsing code: aim ``parg`` at the member and run the nested parser."""
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = [
            f"if ({self.nested_render_name}_parse(&parg, attr))",
            "return MNL_CB_ERROR;",
        ]
        return get_lines, init_lines, None

    def setter(self, ri, space, direction, deref=False, ref=None):
        """Emit the setters of every member of the nested struct.

        The nested members receive this attribute's name appended to ``ref``.
        """
        path = (ref or []) + [self.c_name]

        nested_struct = ri.family.pure_nested_structs[self.nested_attrs]
        for _, member in nested_struct.member_list():
            member.setter(ri, self.nested_attrs, direction, deref=deref, ref=path)
465
466
class TypeMultiAttr(Type):
    """Attribute that may occur several times; stored as an array plus count.

    Wraps the ``Type`` instance created for the same spec element
    (``base_type``) and delegates the policy and type-policy to it.
    """

    def __init__(self, family, attr_set, attr, value, base_type):
        """Set up the base type and keep the wrapped single-value type.

        Args:
            family, attr_set, attr, value: forwarded to ``Type``.
            base_type: ``Type`` object describing one occurrence.
        """
        super().__init__(family, attr_set, attr, value)

        self.base_type = base_type

    def is_multi_val(self):
        """The attribute holds several values."""
        return True

    def presence_type(self):
        """Presence is tracked through an element count."""
        return 'count'

    def _mnl_type(self):
        """Return the type suffix of the mnl put helper to use."""
        t = self.type
        # mnl does not have a helper for signed types
        if t[0] == 's':
            t = 'u' + t[1:]
        return t

    def _is_nest(self):
        """Tell whether the elements are nests (also when no type is given)."""
        return 'type' not in self.attr or self.attr['type'] == 'nest'

    def _complex_member_type(self, ri):
        """Return the C element type of the array.

        Raises:
            Exception: for element types other than nest and scalars.
        """
        if self._is_nest():
            return self.nested_struct_type
        elif self.attr['type'] in scalars:
            scalar_pfx = '__' if ri.ku_space == 'user' else ''
            return scalar_pfx + self.attr['type']
        else:
            raise Exception(f"Sub-type {self.attr['type']} not supported yet")

    def free_needs_iter(self):
        """A loop variable is needed when the elements are nests."""
        return self._is_nest()

    def free(self, ri, var, ref):
        """Emit the code releasing the array (and each nested element).

        Raises:
            Exception: for element types other than nest and scalars.
        """
        # The nest test goes first, matching _complex_member_type(); it is
        # also the only branch that tolerates a missing 'type' key.
        if self._is_nest():
            ri.cw.p(f"for (i = 0; i < {var}->{ref}n_{self.c_name}; i++)")
            ri.cw.p(f'{self.nested_render_name}_free(&{var}->{ref}{self.c_name}[i]);')
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        elif self.attr['type'] in scalars:
            ri.cw.p(f"free({var}->{ref}{self.c_name});")
        else:
            raise Exception(f"Free of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _attr_policy(self, policy):
        """Policy initializer of the wrapped single-value type."""
        return self.base_type._attr_policy(policy)

    def _attr_typol(self):
        """Type-policy fields of the wrapped single-value type."""
        return self.base_type._attr_typol()

    def _attr_get(self, ri, var):
        """Parsing code: only bump the ``n_<name>`` counter.

        The emitted statement names the counter without a ``var->`` prefix.
        """
        return f'n_{self.c_name}++;', None, None

    def attr_put(self, ri, var):
        """Emit a loop putting every element of the array.

        Raises:
            Exception: for element types other than nest and scalars.
        """
        # Same branch order as free(): nest first, then scalars.
        if self._is_nest():
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            self._attr_put_line(ri, var, f"{self.nested_render_name}_put(nlh, " +
                                f"{self.enum_name}, &{var}->{self.c_name}[i])")
        elif self.attr['type'] in scalars:
            put_type = self._mnl_type()
            ri.cw.p(f"for (unsigned int i = 0; i < {var}->n_{self.c_name}; i++)")
            ri.cw.p(f"mnl_attr_put_{put_type}(nlh, {self.enum_name}, {var}->{self.c_name}[i]);")
        else:
            raise Exception(f"Put of MultiAttr sub-type {self.attr['type']} not supported yet")

    def _setter_lines(self, ri, member, presence):
        """Setter body: free the old array, then store pointer and count."""
        # For multi-attr we have a count, not presence, hack up the presence
        presence = presence[:-(len('_present.') + len(self.c_name))] + "n_" + self.c_name
        return [f"free({member});",
                f"{member} = {self.c_name};",
                f"{presence} = n_{self.c_name};"]
535
536
class TypeArrayNest(Type):
    """Attribute whose nest holds an array of entries; stored as array + count."""

    def is_multi_val(self):
        """The attribute holds several values."""
        return True

    def presence_type(self):
        """Presence is tracked through an element count."""
        return 'count'

    def _complex_member_type(self, ri):
        """Return the C element type chosen by ``sub-type`` (nest if absent).

        Raises:
            Exception: for sub-types other than nest and scalars.
        """
        sub_type = self.attr.get('sub-type', 'nest')
        if sub_type == 'nest':
            return self.nested_struct_type
        if sub_type in scalars:
            prefix = '__' if ri.ku_space == 'user' else ''
            return prefix + sub_type
        raise Exception(f"Sub-type {sub_type} not supported yet")

    def _attr_typol(self):
        """Type-policy fields, pointing at the nested set's ``_nest`` object."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Parsing code: remember the attribute and count its nested entries."""
        name = self.c_name
        get_lines = [
            f'attr_{name} = attr;',
            'mnl_attr_for_each_nested(attr2, attr)',
            f'\t{var}->n_{name}++;',
        ]
        return get_lines, None, ['const struct nlattr *attr2;']
562
563
class TypeNestTypeValue(Type):
    """Nest reached through extra nesting levels listed in ``type-value``."""

    def _complex_member_type(self, ri):
        """The member is the nested struct itself."""
        return self.nested_struct_type

    def _attr_typol(self):
        """Type-policy fields, pointing at the nested set's ``_nest`` object."""
        return f'.type = YNL_PT_NEST, .nest = &{self.nested_render_name}_nest, '

    def _attr_get(self, ri, var):
        """Parsing code: walk the ``type-value`` levels, then parse the nest.

        For every level the payload of the previous attribute is taken as the
        next attribute and its type is stored in a variable named after the
        level; those variables are passed on to the nested parser.
        """
        init_lines = [
            f"parg.rsp_policy = &{self.nested_render_name}_nest;",
            f"parg.data = &{var}->{self.c_name};",
        ]
        get_lines = []
        local_vars = []
        cursor = 'attr'
        extra_args = ''

        if 'type-value' in self.attr:
            levels = [c_lower(x) for x in self.attr["type-value"]]
            local_vars.append('const struct nlattr *attr_' + ', *attr_'.join(levels) + ';')
            local_vars.append('__u32 ' + ', '.join(levels) + ';')
            for level in levels:
                get_lines.append(f'attr_{level} = mnl_attr_get_payload({cursor});')
                get_lines.append(f'{level} = mnl_attr_get_type(attr_{level});')
                cursor = 'attr_' + level

            extra_args = ', ' + ', '.join(levels)

        get_lines.append(f"{self.nested_render_name}_parse(&parg, {cursor}{extra_args});")
        return get_lines, init_lines, local_vars
592
593
class Struct:
    """C struct rendered from (a subset of) an attribute set."""

    def __init__(self, family, space_name, type_list=None, inherited=None):
        """Collect the attributes and names of the struct.

        Args:
            family: owning family; supplies ``attr_sets``, ``consts`` and
                ``name``.
            space_name: name of the attribute set the struct is built from.
            type_list: attribute names to include; when empty or ``None`` all
                attributes of the set are included.  ``None`` additionally
                marks the struct as nested.
            inherited: member names expected by ``set_inherited()``.
        """
        self.family = family
        self.space_name = space_name
        self.attr_set = family.attr_sets[space_name]
        # Use list to catch comparisons with empty sets
        self._inherited = [] if inherited is None else inherited
        self.inherited = []

        self.nested = type_list is None
        if family.name != c_lower(space_name):
            self.render_name = f"{family.name}_{c_lower(space_name)}"
        else:
            self.render_name = f"{family.name}"
        self.struct_name = 'struct ' + self.render_name
        if self.nested and space_name in family.consts:
            self.struct_name += '_'
        self.ptr_name = self.struct_name + ' *'

        self.request = False
        self.reply = False

        selected = type_list if type_list else self.attr_set
        self.attr_list = [(t, self.attr_set[t]) for t in selected]
        self.attrs = dict()

        # Track the attribute with the highest value; on ties the later wins.
        highest = 0
        self.attr_max_val = None
        for attr_name, attr in self.attr_list:
            if attr.value >= highest:
                highest = attr.value
                self.attr_max_val = attr
            self.attrs[attr_name] = attr

    def __iter__(self):
        """Iterate over the attribute names."""
        for attr_name in self.attrs:
            yield attr_name

    def __getitem__(self, key):
        """Return the attribute called ``key``."""
        return self.attrs[key]

    def member_list(self):
        """Return the list of ``(name, attribute)`` pairs."""
        return self.attr_list

    def set_inherited(self, new_inherited):
        """Confirm the inherited members and store their C names, sorted.

        Raises:
            Exception: if ``new_inherited`` differs from what was given to
                the constructor.
        """
        if self._inherited != new_inherited:
            raise Exception("Inheriting different members not supported")
        self.inherited = [c_lower(member) for member in sorted(self._inherited)]
646
647
class EnumEntry(SpecEnumEntry):
    """Enum entry with the extra data needed to render it in C."""

    def __init__(self, enum_set, yaml, prev, value_start):
        """Work out whether the entry needs an explicit value.

        ``value_change`` is true when the value does not follow the previous
        entry's value (or is not 0 for the first entry), and always for enum
        sets of type ``flags``.
        """
        super().__init__(enum_set, yaml, prev, value_start)

        expected = prev.value + 1 if prev else 0
        differs = self.value != expected
        self.value_change = differs or self.enum_set['type'] == 'flags'

        # Added by resolve:
        # (created and deleted at once: named here, unreadable until resolve())
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute ``c_name`` from the set's value prefix and the entry name."""
        self.resolve_up(super())

        full_name = self.enum_set.value_pfx + self.name
        self.c_name = c_upper(full_name)
666
667
class EnumSet(SpecEnumSet):
    """Enum definition with the names used when rendering it in C."""

    def __init__(self, family, yaml):
        """Derive the C names, then initialise the base class.

        ``enum_name`` is ``None`` when the spec provides an empty
        ``enum-name``.  The names are assigned ahead of the base constructor.
        NOTE(review): presumably because the base constructor creates the
        entries, which read them - confirm.
        """
        self.render_name = c_lower(family.name + '-' + yaml['name'])

        if 'enum-name' not in yaml:
            self.enum_name = 'enum ' + self.render_name
        elif yaml['enum-name']:
            self.enum_name = 'enum ' + c_lower(yaml['enum-name'])
        else:
            self.enum_name = None

        default_pfx = f"{family.name}-{yaml['name']}-"
        self.value_pfx = yaml.get('name-prefix', default_pfx)

        super().__init__(family, yaml)

    def new_entry(self, entry, prev_entry, value_start):
        """Create the ``EnumEntry`` object for one entry of the set."""
        return EnumEntry(self, entry, prev_entry, value_start)

    def value_range(self):
        """Return ``(lowest, highest)`` entry value.

        Raises:
            Exception: if the values do not form a contiguous range.
        """
        values = [entry.value for entry in self.entries.values()]
        low = min(values)
        high = max(values)

        if len(self.entries) != high - low + 1:
            raise Exception("Can't get value range for a noncontiguous enum")

        return low, high
695
696
class AttrSet(SpecAttrSet):
    """Attribute set with the C naming used by the generator."""

    def __init__(self, family, yaml):
        """Determine ``name_prefix`` and ``max_name`` of the set.

        A set that is a subset of another one takes both from that set;
        otherwise they come from the spec or from the family and set names.
        """
        super().__init__(family, yaml)

        if self.subset_of is not None:
            parent = family.attr_sets[self.subset_of]
            self.name_prefix = parent.name_prefix
            self.max_name = parent.max_name
        else:
            if 'name-prefix' in yaml:
                prefix = yaml['name-prefix']
            elif self.name != family.name:
                prefix = f"{family.name}-a-{self.name}-"
            else:
                prefix = family.name + '-a-'
            self.name_prefix = c_upper(prefix)
            default_max = f"{self.name_prefix}max"
            self.max_name = c_upper(self.yaml.get('attr-max-name', default_max))

        # Added by resolve:
        # (created and deleted at once: named here, unreadable until resolve())
        self.c_name = None
        delattr(self, "c_name")

    def resolve(self):
        """Compute ``c_name``; it is empty when equal to the family's C name."""
        # c_name is assigned step by step on purpose: the attribute exists as
        # soon as the first statement has run.
        self.c_name = c_lower(self.name)
        if self.c_name in _C_KW:
            self.c_name = self.c_name + '_'
        if self.family.c_name == self.c_name:
            self.c_name = ''

    def new_attr(self, elem, value):
        """Create the ``Type`` object matching the spec element's type.

        Elements marked ``multi-attr`` are wrapped in ``TypeMultiAttr``.

        Raises:
            Exception: if the type has no class.
        """
        kind = elem['type']
        if kind in scalars:
            type_class = TypeScalar
        else:
            type_class = {
                'unused': TypeUnused,
                'pad': TypePad,
                'flag': TypeFlag,
                'string': TypeString,
                'binary': TypeBinary,
                'nest': TypeNest,
                'array-nest': TypeArrayNest,
                'nest-type-value': TypeNestTypeValue,
            }.get(kind)
            if type_class is None:
                raise Exception(f"No typed class for type {elem['type']}")

        attr = type_class(self.family, self, elem, value)

        if elem.get('multi-attr'):
            attr = TypeMultiAttr(self.family, self, elem, value, attr)

        return attr
751
752
class Operation(SpecOperation):
    """Operation with the names and flags needed for code generation."""

    def __init__(self, family, yaml, req_value, rsp_value):
        """Derive the render name and whether do and dump both take a request."""
        super().__init__(family, yaml, req_value, rsp_value)

        self.render_name = family.name + '_' + c_lower(self.name)

        def has_request(mode):
            return mode in yaml and 'request' in yaml[mode]

        self.dual_policy = has_request('do') and has_request('dump')

        self.has_ntf = False

        # Added by resolve:
        # (created and deleted at once: named here, unreadable until resolve())
        self.enum_name = None
        delattr(self, "enum_name")

    def resolve(self):
        """Compute ``enum_name``; async operations use the async prefix."""
        self.resolve_up(super())

        prefix = self.family.async_op_prefix if self.is_async else self.family.op_prefix
        self.enum_name = prefix + c_upper(self.name)

    def mark_has_ntf(self):
        """Record that a notification refers to this operation."""
        self.has_ntf = True
778
779
780 class Family(SpecFamily):
781     def __init__(self, file_name, exclude_ops):
782         # Added by resolve:
783         self.c_name = None
784         delattr(self, "c_name")
785         self.op_prefix = None
786         delattr(self, "op_prefix")
787         self.async_op_prefix = None
788         delattr(self, "async_op_prefix")
789         self.mcgrps = None
790         delattr(self, "mcgrps")
791         self.consts = None
792         delattr(self, "consts")
793         self.hooks = None
794         delattr(self, "hooks")
795
796         super().__init__(file_name, exclude_ops=exclude_ops)
797
798         self.fam_key = c_upper(self.yaml.get('c-family-name', self.yaml["name"] + '_FAMILY_NAME'))
799         self.ver_key = c_upper(self.yaml.get('c-version-name', self.yaml["name"] + '_FAMILY_VERSION'))
800
801         if 'definitions' not in self.yaml:
802             self.yaml['definitions'] = []
803
804         if 'uapi-header' in self.yaml:
805             self.uapi_header = self.yaml['uapi-header']
806         else:
807             self.uapi_header = f"linux/{self.name}.h"
808
809     def resolve(self):
810         self.resolve_up(super())
811
812         if self.yaml.get('protocol', 'genetlink') not in {'genetlink', 'genetlink-c', 'genetlink-legacy'}:
813             raise Exception("Codegen only supported for genetlink")
814
815         self.c_name = c_lower(self.name)
816         if 'name-prefix' in self.yaml['operations']:
817             self.op_prefix = c_upper(self.yaml['operations']['name-prefix'])
818         else:
819             self.op_prefix = c_upper(self.yaml['name'] + '-cmd-')
820         if 'async-prefix' in self.yaml['operations']:
821             self.async_op_prefix = c_upper(self.yaml['operations']['async-prefix'])
822         else:
823             self.async_op_prefix = self.op_prefix
824
825         self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
826
827         self.hooks = dict()
828         for when in ['pre', 'post']:
829             self.hooks[when] = dict()
830             for op_mode in ['do', 'dump']:
831                 self.hooks[when][op_mode] = dict()
832                 self.hooks[when][op_mode]['set'] = set()
833                 self.hooks[when][op_mode]['list'] = []
834
835         # dict space-name -> 'request': set(attrs), 'reply': set(attrs)
836         self.root_sets = dict()
837         # dict space-name -> set('request', 'reply')
838         self.pure_nested_structs = dict()
839
840         self._mark_notify()
841         self._mock_up_events()
842
843         self._load_root_sets()
844         self._load_nested_sets()
845         self._load_hooks()
846
847         self.kernel_policy = self.yaml.get('kernel-policy', 'split')
848         if self.kernel_policy == 'global':
849             self._load_global_policy()
850
    def new_enum(self, elem):
        """Wrap the enum definition `elem` in this generator's EnumSet."""
        # NOTE(review): presumably called by the SpecFamily base class while
        # loading the spec - the caller is not visible in this section.
        return EnumSet(self, elem)
853
    def new_attr_set(self, elem):
        """Wrap the attribute-set definition `elem` in this generator's AttrSet."""
        # NOTE(review): presumably called by the SpecFamily base class while
        # loading the spec - the caller is not visible in this section.
        return AttrSet(self, elem)
856
    def new_operation(self, elem, req_value, rsp_value):
        """Wrap the operation definition `elem` in this generator's Operation.

        `req_value` and `rsp_value` are passed through unchanged.
        """
        # NOTE(review): presumably called by the SpecFamily base class while
        # loading the spec - the caller is not visible in this section.
        return Operation(self, elem, req_value, rsp_value)
859
860     def _mark_notify(self):
861         for op in self.msgs.values():
862             if 'notify' in op:
863                 self.ops[op['notify']].mark_has_ntf()
864
865     # Fake a 'do' equivalent of all events, so that we can render their response parsing
866     def _mock_up_events(self):
867         for op in self.yaml['operations']['list']:
868             if 'event' in op:
869                 op['do'] = {
870                     'reply': {
871                         'attributes': op['event']['attributes']
872                     }
873                 }
874
875     def _load_root_sets(self):
876         for op_name, op in self.msgs.items():
877             if 'attribute-set' not in op:
878                 continue
879
880             req_attrs = set()
881             rsp_attrs = set()
882             for op_mode in ['do', 'dump']:
883                 if op_mode in op and 'request' in op[op_mode]:
884                     req_attrs.update(set(op[op_mode]['request']['attributes']))
885                 if op_mode in op and 'reply' in op[op_mode]:
886                     rsp_attrs.update(set(op[op_mode]['reply']['attributes']))
887             if 'event' in op:
888                 rsp_attrs.update(set(op['event']['attributes']))
889
890             if op['attribute-set'] not in self.root_sets:
891                 self.root_sets[op['attribute-set']] = {'request': req_attrs, 'reply': rsp_attrs}
892             else:
893                 self.root_sets[op['attribute-set']]['request'].update(req_attrs)
894                 self.root_sets[op['attribute-set']]['reply'].update(rsp_attrs)
895
    def _load_nested_sets(self):
        """Discover attribute sets that are only used nested and build their Structs.

        Fills ``self.pure_nested_structs`` (set name -> Struct), records the
        members each struct inherits from its parent attribute, marks whether
        each struct is needed on the request and/or reply side, and reorders
        the dict so that structs come after the structs they nest.

        Raises:
            Exception: if an attribute set is used both as a root set and as
                a nested set.
        """
        # Breadth-first walk over the attribute sets, starting at the roots
        attr_set_queue = list(self.root_sets.keys())
        attr_set_seen = set(self.root_sets.keys())

        while len(attr_set_queue):
            a_set = attr_set_queue.pop(0)
            for attr, spec in self.attr_sets[a_set].items():
                if 'nested-attributes' not in spec:
                    continue

                nested = spec['nested-attributes']
                if nested not in attr_set_seen:
                    attr_set_queue.append(nested)
                    attr_set_seen.add(nested)

                inherit = set()
                if nested not in self.root_sets:
                    if nested not in self.pure_nested_structs:
                        self.pure_nested_structs[nested] = Struct(self, nested, inherited=inherit)
                else:
                    raise Exception(f'Using attr set as root and nested not supported - {nested}')

                # Members the nested struct receives from the parent attribute
                # rather than from its own attributes
                if 'type-value' in spec:
                    if nested in self.root_sets:
                        raise Exception("Inheriting members to a space used as root not supported")
                    inherit.update(set(spec['type-value']))
                elif spec['type'] == 'array-nest':
                    inherit.add('idx')
                # Note: this replaces whatever an earlier reference to the
                # same nested set had stored
                self.pure_nested_structs[nested].set_inherited(inherit)

        # Seed request / reply from the root sets that directly use the nest
        for root_set, rs_members in self.root_sets.items():
            for attr, spec in self.attr_sets[root_set].items():
                if 'nested-attributes' in spec:
                    nested = spec['nested-attributes']
                    if attr in rs_members['request']:
                        self.pure_nested_structs[nested].request = True
                    if attr in rs_members['reply']:
                        self.pure_nested_structs[nested].reply = True

        # Try to reorder according to dependencies
        pns_key_list = list(self.pure_nested_structs.keys())
        pns_key_seen = set()
        rounds = len(pns_key_list)**2  # it's basically bubble sort
        for _ in range(rounds):
            if len(pns_key_list) == 0:
                break
            name = pns_key_list.pop(0)
            finished = True
            for _, spec in self.attr_sets[name].items():
                if 'nested-attributes' in spec:
                    if spec['nested-attributes'] not in pns_key_seen:
                        # Dicts keep insertion order, so popping and
                        # re-inserting moves this struct to the end
                        struct = self.pure_nested_structs.pop(name)
                        self.pure_nested_structs[name] = struct
                        finished = False
                        break
            if finished:
                pns_key_seen.add(name)
            else:
                # Retry later, once the struct it depends on has been placed
                pns_key_list.append(name)
        # Propagate the request / reply
        # Walking in reverse visits the re-inserted (dependent) structs first,
        # so their flags reach the structs they nest.
        for attr_set, struct in reversed(self.pure_nested_structs.items()):
            for _, spec in self.attr_sets[attr_set].items():
                if 'nested-attributes' in spec:
                    child = self.pure_nested_structs.get(spec['nested-attributes'])
                    if child:
                        child.request |= struct.request
                        child.reply |= struct.reply
964
965     def _load_global_policy(self):
966         global_set = set()
967         attr_set_name = None
968         for op_name, op in self.ops.items():
969             if not op:
970                 continue
971             if 'attribute-set' not in op:
972                 continue
973
974             if attr_set_name is None:
975                 attr_set_name = op['attribute-set']
976             if attr_set_name != op['attribute-set']:
977                 raise Exception('For a global policy all ops must use the same set')
978
979             for op_mode in ['do', 'dump']:
980                 if op_mode in op:
981                     global_set.update(op[op_mode].get('request', []))
982
983         self.global_policy = []
984         self.global_policy_set = attr_set_name
985         for attr in self.attr_sets[attr_set_name]:
986             if attr in global_set:
987                 self.global_policy.append(attr)
988
989     def _load_hooks(self):
990         for op in self.ops.values():
991             for op_mode in ['do', 'dump']:
992                 if op_mode not in op:
993                     continue
994                 for when in ['pre', 'post']:
995                     if when not in op[op_mode]:
996                         continue
997                     name = op[op_mode][when]
998                     if name in self.hooks[when][op_mode]['set']:
999                         continue
1000                     self.hooks[when][op_mode]['set'].add(name)
1001                     self.hooks[when][op_mode]['list'].append(name)
1002
1003
class RenderInfo:
    """Context handed to the render helpers for one operation or attribute set.

    Attributes set here: ``family``, ``nl`` (the writer's netlink library
    object), ``ku_space``, ``op``, ``op_mode``, ``type_consistent``,
    ``attr_set``, ``type_name``, ``type_name_conflict``, ``cw`` and
    ``struct`` (direction -> Struct).
    """

    def __init__(self, cw, family, ku_space, op, op_mode, attr_set=None):
        self.cw = cw
        self.nl = cw.nlib
        self.family = family
        self.ku_space = ku_space
        self.op = op
        self.op_mode = op_mode

        # 'do' and 'dump' response parsing is identical unless only one of
        # them has a reply, or the two replies differ
        consistent = True
        if op_mode != 'do' and 'dump' in op and 'do' in op:
            do_has_reply = 'reply' in op['do']
            dump_has_reply = 'reply' in op['dump']
            if do_has_reply != dump_has_reply:
                consistent = False
            elif do_has_reply and op['do']['reply'] != op['dump']['reply']:
                consistent = False
        self.type_consistent = consistent

        self.attr_set = attr_set if attr_set else op['attribute-set']

        # With a falsy op the type is named after the attribute set, which
        # may collide with a name in the family's constants
        if op:
            self.type_name = c_lower(op.name)
            self.type_name_conflict = False
        else:
            self.type_name = c_lower(attr_set)
            self.type_name_conflict = attr_set in family.consts

        self.struct = dict()
        # Notifications take their layout from the 'do' of the operation
        struct_mode = 'do' if op_mode == 'notify' else op_mode
        for op_dir in ['request', 'reply']:
            if not op or op_dir not in op[struct_mode]:
                continue
            attrs = op[struct_mode][op_dir]['attributes']
            self.struct[op_dir] = Struct(family, self.attr_set, type_list=attrs)
        if struct_mode == 'event':
            self.struct['reply'] = Struct(family, self.attr_set,
                                          type_list=op['event']['attributes'])
1043
1044
class CodeWriter:
    """Line-oriented writer for generated C code.

    Tracks the current indentation, delays closing braces so that a
    following ``else`` can be joined onto them, and indents the single
    statement following a brace-less ``if`` / ``while`` / ``for``.

    Args:
        nlib: netlink library helper, stored as ``self.nlib`` for renderers.
        out_file: object with a ``write(str)`` method receiving the output.
    """

    def __init__(self, nlib, out_file):
        self.nlib = nlib

        self._nl = False            # blank line requested before next line
        self._block_end = False     # a closing '}' is pending
        self._silent_block = False  # previous line was a brace-less condition
        self._ind = 0               # current indentation, in tabs
        self._out = out_file

    @classmethod
    def _is_cond(cls, line):
        """Tell whether `line` starts with 'if', 'while' or 'for'."""
        return line.startswith(('if', 'while', 'for'))

    def p(self, line, add_ind=0):
        """Write one line at the current indentation.

        Lines ending in ':' (labels, cases) are outdented by one level; the
        line following a brace-less condition is indented by one extra
        level. `add_ind` adds further levels. An empty `line` produces a
        bare newline.
        """
        if self._block_end:
            self._block_end = False
            if line.startswith('else'):
                line = '} ' + line
            else:
                self._out.write('\t' * self._ind + '}\n')

        if self._nl:
            self._out.write('\n')
            self._nl = False

        ind = self._ind
        # endswith() instead of line[-1] so that an empty line cannot raise
        if line.endswith(':'):
            ind -= 1
        if self._silent_block:
            ind += 1
        self._silent_block = line.endswith(')') and CodeWriter._is_cond(line)
        if add_ind:
            ind += add_ind
        if line:
            self._out.write('\t' * ind + line + '\n')
        else:
            # Do not emit indentation-only lines
            self._out.write('\n')

    def nl(self):
        """Request a blank line before the next printed line."""
        self._nl = True

    def block_start(self, line=''):
        """Print `line` followed by '{' and indent what follows."""
        if line:
            line = line + ' '
        self.p(line + '{')
        self._ind += 1

    def block_end(self, line=''):
        """Close the current block, optionally followed by `line`.

        Without `line` the closing brace is only recorded and printed by the
        next p() call. Any pending blank line is dropped.
        """
        if line and line[0] not in {';', ','}:
            line = ' ' + line
        self._ind -= 1
        self._nl = False
        if not line:
            # Delay printing closing bracket in case "else" comes next
            if self._block_end:
                # An inner block is still pending, close it first
                self._out.write('\t' * (self._ind + 1) + '}\n')
            self._block_end = True
        else:
            self.p('}' + line)

    def write_doc_line(self, doc, indent=True):
        """Print `doc` as ' *'-prefixed comment lines wrapped below 80 columns.

        With `indent`, continuation lines get two extra spaces.
        """
        line = ' *'
        for word in doc.split():
            if len(line) + len(word) >= 79:
                self.p(line)
                line = ' *'
                if indent:
                    line += '  '
            line += ' ' + word
        self.p(line)

    def write_func_prot(self, qual_ret, name, args=None, doc=None, suffix=''):
        """Print a C function prototype, wrapping arguments if it is too long.

        Args:
            qual_ret: qualifiers and return type, e.g. 'static int'.
            name: function name.
            args: list of argument declarations; empty / None means 'void'.
            doc: optional one-line comment printed above the prototype.
            suffix: text appended after the closing parenthesis, e.g. ';'.
        """
        if not args:
            args = ['void']

        if doc:
            self.p('/*')
            self.p(' * ' + doc)
            self.p(' */')

        # Pointer return types hug the function name: "char *foo(...)"
        ret_sep = '' if qual_ret.endswith('*') else ' '
        oneline = f"{qual_ret}{ret_sep}{name}({', '.join(args)}){suffix}"
        if len(oneline) < 80:
            self.p(oneline)
            return

        v = qual_ret
        if len(v) > 3:
            # Longer return types go on a line of their own
            self.p(v)
            v = ''
        else:
            v += ret_sep
        v += name + '('
        # Continuation lines are aligned with the first argument
        ind = '\t' * (len(v) // 8) + ' ' * (len(v) % 8)
        delta_ind = len(v) - len(ind)
        v += args[0]
        for arg in args[1:]:
            next_len = len(v) + len(arg)
            if v[0] == '\t':
                # Account for the display width of the leading tabs
                next_len += delta_ind
            if next_len > 76:
                self.p(v + ',')
                v = ind
            else:
                v += ', '
            v += arg
        self.p(v + ')' + suffix)

    def write_func_lvar(self, local_vars):
        """Print local variable declarations, longest first, then a blank line.

        `local_vars` may be a single string or an iterable of strings; the
        caller's sequence is not modified.
        """
        if not local_vars:
            return

        if isinstance(local_vars, str):
            local_vars = [local_vars]

        # sorted() keeps the caller's list intact, unlike list.sort()
        for var in sorted(local_vars, key=len, reverse=True):
            self.p(var)
        self.nl()

    def write_func(self, qual_ret, name, body, args=None, local_vars=None):
        """Print a prototype, the local variables and the lines of `body`."""
        self.write_func_prot(qual_ret=qual_ret, name=name, args=args)
        self.write_func_lvar(local_vars=local_vars)

        self.block_start()
        for line in body:
            self.p(line)
        self.block_end()

    def writes_defines(self, defines):
        """Print '#define NAME value' lines with tab-aligned values.

        `defines` is a sequence of (name, value) pairs; int values are
        printed as-is, str values are double-quoted, other types print no
        value.
        """
        longest = max((len(define[0]) for define in defines), default=0)
        longest = ((longest + 8) // 8) * 8
        for define in defines:
            line = '#define ' + define[0]
            line += '\t' * ((longest - len(define[0]) + 7) // 8)
            if type(define[1]) is int:
                line += str(define[1])
            elif type(define[1]) is str:
                line += '"' + define[1] + '"'
            self.p(line)

    def write_struct_init(self, members):
        """Print '.name = value,' initializer lines with aligned '='.

        `members` is a sequence of (name, value) string pairs; an empty
        sequence prints nothing.
        """
        if not members:
            return
        longest = max(len(one[0]) for one in members)
        longest += 1  # because we prepend a .
        longest = ((longest + 8) // 8) * 8
        for one in members:
            line = '.' + one[0]
            line += '\t' * ((longest - len(one[0]) - 1 + 7) // 8)
            line += '= ' + one[1] + ','
            self.p(line)
1202
1203
# Attribute type names handled as plain integers by the parsers below
scalars = {
    'u8', 'u16', 'u32', 'u64',
    's32', 's64',
}

# Message direction -> suffix used in generated identifiers
direction_to_suffix = {'reply': '_rsp', 'request': '_req', '': ''}

# Operation mode -> suffix of the type wrapping its reply
op_mode_to_wrapper = {'do': '', 'dump': '_list', 'notify': '_ntf', 'event': ''}

# C keywords
# NOTE(review): presumably used to avoid emitting identifiers that clash with
# the language - the user of this set is outside this section.
_C_KW = {
    'auto', 'bool', 'break', 'case', 'char', 'const', 'continue',
    'default', 'do', 'double', 'else', 'enum', 'extern', 'float',
    'for', 'goto', 'if', 'inline', 'int', 'long', 'register',
    'return', 'short', 'signed', 'sizeof', 'static', 'struct',
    'switch', 'typedef', 'union', 'unsigned', 'void', 'volatile',
    'while'
}
1255
1256
def rdir(direction):
    """Return the opposite message direction.

    'request' and 'reply' map to each other; any other value is returned
    unchanged.
    """
    if direction == 'request':
        return 'reply'
    return 'request' if direction == 'reply' else direction
1263
1264
def op_prefix(ri, direction, deref=False):
    """Build the C identifier prefix for an operation's request / reply type.

    The result is ``<family>_<type name><tail>`` where the tail depends on
    the operation mode, the direction, whether 'do' and 'dump' replies share
    a type, and whether the element type (`deref`) or its wrapper is wanted.
    """
    if not ri.op_mode or ri.op_mode == 'do':
        tail = direction_to_suffix[direction]
    elif direction == 'request':
        tail = '_req_dump'
    elif not ri.type_consistent:
        # Dump replies differ from 'do' replies, they need names of their own
        tail = '_rsp_dump' if deref else '_rsp_list'
    elif deref:
        tail = direction_to_suffix[direction]
    else:
        tail = op_mode_to_wrapper[ri.op_mode]

    return f"{ri.family['name']}_{ri.type_name}{tail}"
1284
1285
def type_name(ri, direction, deref=False):
    """Return the C struct type ('struct <prefix>') for `ri` and `direction`.

    `deref` is forwarded to op_prefix().
    """
    return f"struct {op_prefix(ri, direction, deref=deref)}"
1288
1289
def print_prototype(ri, direction, terminate=True, doc=None):
    """Print the prototype of the function performing the operation.

    The function takes the socket plus, when the operation mode has a
    request, a pointer to the request struct. It returns a pointer to the
    opposite direction's struct when the mode has a reply, 'int' otherwise.
    `terminate` appends ';', `doc` is printed as a comment above.
    """
    mode_spec = ri.op[ri.op_mode]

    fname = ri.op.render_name + ('_dump' if ri.op_mode == 'dump' else '')

    args = ['struct ynl_sock *ys']
    if 'request' in mode_spec:
        # Argument is named after the direction suffix minus its underscore
        arg_name = direction_to_suffix[direction][1:]
        args.append(f"{type_name(ri, direction)} *{arg_name}")

    if 'reply' in mode_spec:
        ret = f"{type_name(ri, rdir(direction))} *"
    else:
        ret = 'int'

    ri.cw.write_func_prot(ret, fname, args, doc=doc,
                          suffix=';' if terminate else '')
1306
1307
def print_req_prototype(ri):
    """Print the terminated request prototype, with the op's 'doc' as comment."""
    # Indexing (not .get) means an operation without a 'doc' key fails here
    print_prototype(ri, "request", doc=ri.op['doc'])
1310
1311
def print_dump_prototype(ri):
    """Print the terminated prototype for `ri`, without a doc comment."""
    # The '_dump' name suffix is added by print_prototype() based on ri.op_mode
    print_prototype(ri, "request")
1314
1315
def put_typol(cw, struct):
    """Print the policy table of `struct` and the nest descriptor pointing at it.

    The table is sized by the attribute set's max name and filled in by each
    member's attr_typol().
    """
    name = struct.render_name
    type_max = struct.attr_set.max_name

    cw.block_start(line=f'struct ynl_policy_attr {name}_policy[{type_max} + 1] =')
    for _, member in struct.member_list():
        member.attr_typol(cw)
    cw.block_end(line=';')
    cw.nl()

    cw.block_start(line=f'struct ynl_policy_nest {name}_nest =')
    for init in (f'.max_attr = {type_max},', f'.table = {name}_policy,'):
        cw.p(init)
    cw.block_end(line=';')
    cw.nl()
1331
1332
def _put_enum_to_str_helper(cw, render_name, map_name, arg_name, enum=None):
    """Print a C function mapping a value to its name through `map_name`.

    The argument is typed 'enum <render_name>' when `enum` is given and does
    not carry an empty 'enum-name', plain 'int' otherwise. For 'flags' enums
    the value is first converted to a bit index.
    """
    unnamed = not enum or ('enum-name' in enum and not enum['enum-name'])
    arg_type = 'int' if unnamed else f'enum {render_name}'

    cw.write_func_prot('const char *', f'{render_name}_str',
                       [f'{arg_type} {arg_name}'])
    cw.block_start()
    if enum and enum.type == 'flags':
        cw.p(f'{arg_name} = ffs({arg_name}) - 1;')
    body = (
        f'if ({arg_name} < 0 || {arg_name} >= (int)MNL_ARRAY_SIZE({map_name}))',
        'return NULL;',
        f'return {map_name}[{arg_name}];',
    )
    for stmt in body:
        cw.p(stmt)
    cw.block_end()
    cw.nl()
1346
1347
def put_op_name_fwd(family, cw):
    """Print the forward declaration of '<family>_op_str(int op)'."""
    cw.write_func_prot('const char *', f'{family.name}_op_str', ['int op'], suffix=';')
1350
1351
def put_op_name(family, cw):
    """Print the operation name table and its '<family>_op_str' lookup.

    Only messages with a response value get an entry. The entry is indexed
    by the enum name when request and response values match, by the raw
    response value otherwise.
    """
    map_name = f'{family.name}_op_strmap'

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for op_name, op in family.msgs.items():
        if not op.rsp_value:
            continue
        index = op.enum_name if op.req_value == op.rsp_value else op.rsp_value
        cw.p(f'[{index}] = "{op_name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, family.name + '_op', map_name, 'op')
1365
1366
def put_enum_to_str_fwd(family, cw, enum):
    """Print the forward declaration of '<enum>_str'.

    The argument is 'int value' when the enum has an empty 'enum-name',
    'enum <render_name> value' otherwise.
    """
    if 'enum-name' in enum and not enum['enum-name']:
        arg = 'int value'
    else:
        arg = f'enum {enum.render_name} value'
    cw.write_func_prot('const char *', f'{enum.render_name}_str', [arg], suffix=';')
1372
1373
def put_enum_to_str(family, cw, enum):
    """Print the value -> name table of `enum` and its '<enum>_str' lookup."""
    render_name = enum.render_name
    map_name = f'{render_name}_strmap'

    cw.block_start(line=f"static const char * const {map_name}[] =")
    for item in enum.entries.values():
        cw.p(f'[{item.value}] = "{item.name}",')
    cw.block_end(line=';')
    cw.nl()

    _put_enum_to_str_helper(cw, render_name, map_name, 'value', enum=enum)
1383
1384
def put_req_nested(ri, struct):
    """Print the '<struct>_put' function writing `struct` as a nested attribute.

    The generated function opens a nest of type 'attr_type', lets every
    member emit its own put code for 'obj', closes the nest and returns 0.
    """
    cw = ri.cw

    cw.write_func_prot('int', f'{struct.render_name}_put',
                       ['struct nlmsghdr *nlh',
                        'unsigned int attr_type',
                        f'{struct.ptr_name}obj'])
    cw.block_start()
    cw.write_func_lvar('struct nlattr *nest;')

    cw.p("nest = mnl_attr_nest_start(nlh, attr_type);")

    for _, member in struct.member_list():
        member.attr_put(ri, "obj")

    cw.p("mnl_attr_nest_end(nlh, nest);")

    cw.nl()
    cw.p('return 0;')
    cw.block_end()
    cw.nl()
1405
1406
def _multi_parse(ri, struct, init_lines, local_vars):
    """Print the body of a C attribute-parsing function for `struct`.

    The caller has already printed the prototype. This prints the opening
    brace, the local variables, the attribute walk and the allocation and
    filling of array-nest / multi-attr members.

    Args:
        ri: render info; `ri.cw` receives the output and `ri` is passed on
            to each member's attr_get().
        struct: struct being parsed; `struct.nested` selects between walking
            a nested attribute and walking a whole message.
        init_lines: C statements printed right after the local variables;
            extended in place when a 'parg' variable is needed.
        local_vars: C declarations for the function; extended in place.

    Raises:
        Exception: for a multi-attr member that is neither nested nor scalar.
    """
    if struct.nested:
        iter_line = "mnl_attr_for_each_nested(attr, nested)"
    else:
        iter_line = "mnl_attr_for_each(attr, nlh, sizeof(struct genlmsghdr))"

    array_nests = set()
    multi_attrs = set()
    needs_parg = False
    for arg, aspec in struct.member_list():
        if aspec['type'] == 'array-nest':
            local_vars.append(f'const struct nlattr *attr_{aspec.c_name};')
            array_nests.add(arg)
        if 'multi-attr' in aspec:
            multi_attrs.add(arg)
        needs_parg |= 'nested-attributes' in aspec
    if array_nests or multi_attrs:
        local_vars.append('int i;')
    if needs_parg:
        local_vars.append('struct ynl_parse_arg parg;')
        init_lines.append('parg.ys = yarg->ys;')

    all_multi = array_nests | multi_attrs

    # One element counter per member that may hold several entries
    for anest in sorted(all_multi):
        local_vars.append(f"unsigned int n_{struct[anest].c_name} = 0;")

    ri.cw.block_start()
    ri.cw.write_func_lvar(local_vars)

    for line in init_lines:
        ri.cw.p(line)
    ri.cw.nl()

    for arg in struct.inherited:
        ri.cw.p(f'dst->{arg} = {arg};')

    for anest in sorted(all_multi):
        aspec = struct[anest]
        ri.cw.p(f"if (dst->{aspec.c_name})")
        ri.cw.p(f'return ynl_error_parse(yarg, "attribute already present ({struct.attr_set.name}.{aspec.name})");')

    ri.cw.nl()
    ri.cw.block_start(line=iter_line)
    ri.cw.p('unsigned int type = mnl_attr_get_type(attr);')
    ri.cw.nl()

    first = True
    for _, arg in struct.member_list():
        good = arg.attr_get(ri, 'dst', first=first)
        # First may be 'unused' or 'pad', ignore those
        first &= not good

    ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(array_nests):
        aspec = struct[anest]

        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        # Size the array by the n_<name> counter declared above; the bare
        # member name is not an identifier in the generated function.
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        ri.cw.block_start(line=f"mnl_attr_for_each_nested(attr, attr_{aspec.c_name})")
        ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
        # The attribute type of each entry is handed over as the extra
        # argument of the nested parser
        ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr, mnl_attr_get_type(attr)))")
        ri.cw.p('return MNL_CB_ERROR;')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    for anest in sorted(multi_attrs):
        aspec = struct[anest]
        ri.cw.block_start(line=f"if (n_{aspec.c_name})")
        ri.cw.p(f"dst->{aspec.c_name} = calloc(n_{aspec.c_name}, sizeof(*dst->{aspec.c_name}));")
        ri.cw.p(f"dst->n_{aspec.c_name} = n_{aspec.c_name};")
        ri.cw.p('i = 0;')
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.rsp_policy = &{aspec.nested_render_name}_nest;")
        # Walk the attributes again, this time storing the entries
        ri.cw.block_start(line=iter_line)
        ri.cw.block_start(line=f"if (mnl_attr_get_type(attr) == {aspec.enum_name})")
        if 'nested-attributes' in aspec:
            ri.cw.p(f"parg.data = &dst->{aspec.c_name}[i];")
            ri.cw.p(f"if ({aspec.nested_render_name}_parse(&parg, attr))")
            ri.cw.p('return MNL_CB_ERROR;')
        elif aspec['type'] in scalars:
            t = aspec['type']
            # Signed types are read with the unsigned getter of the same width
            if t[0] == 's':
                t = 'u' + t[1:]
            ri.cw.p(f"dst->{aspec.c_name}[i] = mnl_attr_get_{t}(attr);")
        else:
            raise Exception('Nest parsing type not supported yet')
        ri.cw.p('i++;')
        ri.cw.block_end()
        ri.cw.block_end()
        ri.cw.block_end()
    ri.cw.nl()

    if struct.nested:
        ri.cw.p('return 0;')
    else:
        ri.cw.p('return MNL_CB_OK;')
    ri.cw.block_end()
    ri.cw.nl()
1513
1514
def parse_rsp_nested(ri, struct):
    """Print the '<struct>_parse' function for a nested attribute.

    Besides the parse argument and the nested attribute, the generated
    function takes one '__u32' argument per inherited member.
    """
    func_args = ['struct ynl_parse_arg *yarg', 'const struct nlattr *nested']
    func_args += ['__u32 ' + inherited for inherited in struct.inherited]

    ri.cw.write_func_prot('int', f'{struct.render_name}_parse', func_args)

    _multi_parse(ri, struct, [],
                 ['const struct nlattr *attr;',
                  f'{struct.ptr_name}dst = yarg->data;'])
1528
1529
def parse_rsp_msg(ri, deref=False):
    """Print the message-level reply parser of the operation.

    Nothing is printed unless the operation mode has a reply or the mode is
    'event'. `deref` selects the element type rather than its wrapper.
    """
    has_payload = 'reply' in ri.op[ri.op_mode] or ri.op_mode == 'event'
    if not has_payload:
        return

    ri.cw.write_func_prot('int', f'{op_prefix(ri, "reply", deref=deref)}_parse',
                          ['const struct nlmsghdr *nlh', 'void *data'])

    declarations = [
        f'{type_name(ri, "reply", deref=deref)} *dst;',
        'struct ynl_parse_arg *yarg = data;',
        'const struct nlattr *attr;',
    ]
    _multi_parse(ri, ri.struct["reply"], ['dst = yarg->data;'], declarations)
1545
1546
def print_req(ri):
    """Print the C function that sends the request and, if any, parses the reply.

    With a reply the generated function allocates the response, returns it
    on success and frees it and returns NULL on error; without one it
    returns 0 / -1.
    """
    cw = ri.cw
    direction = "request"
    has_reply = 'reply' in ri.op[ri.op_mode]

    local_vars = ['struct nlmsghdr *nlh;',
                  'int err;']
    if has_reply:
        ret_ok, ret_err = 'rsp', 'NULL'
        local_vars += [f'{type_name(ri, rdir(direction))} *rsp;',
                       'struct ynl_req_state yrs = { .yarg = { .ys = ys, }, };']
    else:
        ret_ok, ret_err = '0', '-1'

    print_prototype(ri, direction, terminate=False)
    cw.block_start()
    cw.write_func_lvar(local_vars)

    cw.p(f"nlh = ynl_gemsg_start_req(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    cw.p(f"ys->req_policy = &{ri.struct['request'].render_name}_nest;")
    if has_reply:
        cw.p(f"yrs.yarg.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    for _, member in ri.struct["request"].member_list():
        member.attr_put(ri, "req")
    cw.nl()

    if has_reply:
        cw.p('rsp = calloc(1, sizeof(*rsp));')
        cw.p('yrs.yarg.data = rsp;')
        cw.p(f"yrs.cb = {op_prefix(ri, 'reply')}_parse;")
        # Use the enum name when the op has a single value, else the raw
        # response value
        if ri.op.value is not None:
            cw.p(f'yrs.rsp_cmd = {ri.op.enum_name};')
        else:
            cw.p(f'yrs.rsp_cmd = {ri.op.rsp_value};')
        cw.nl()
        parse_arg = '&yrs'
    else:
        parse_arg = "NULL"
    cw.p(f"err = ynl_exec(ys, nlh, {parse_arg});")
    cw.p('if (err < 0)')
    cw.p('goto err_free;' if has_reply else 'return -1;')
    cw.nl()

    cw.p(f"return {ret_ok};")
    cw.nl()

    if has_reply:
        cw.p('err_free:')
        cw.p(f"{call_free(ri, rdir(direction), 'rsp')}")
        cw.p(f"return {ret_err};")

    cw.block_end()
1602
1603
def print_dump(ri):
    """Print the C function that performs the dump and returns the reply list.

    The generated function sets up the dump state, sends the (optional)
    request attributes, and returns 'yds.first' on success or frees the
    partial list and returns NULL on error.
    """
    cw = ri.cw
    direction = "request"

    print_prototype(ri, direction, terminate=False)
    cw.block_start()

    # Printed in this fixed order, not sorted by length
    for declaration in ('struct ynl_dump_state yds = {};',
                        'struct nlmsghdr *nlh;',
                        'int err;'):
        cw.p(f'{declaration}')
    cw.nl()

    cw.p('yds.ys = ys;')
    cw.p(f"yds.alloc_sz = sizeof({type_name(ri, rdir(direction))});")
    cw.p(f"yds.cb = {op_prefix(ri, 'reply', deref=True)}_parse;")
    if ri.op.value is None:
        cw.p(f'yds.rsp_cmd = {ri.op.rsp_value};')
    else:
        cw.p(f'yds.rsp_cmd = {ri.op.enum_name};')
    cw.p(f"yds.rsp_policy = &{ri.struct['reply'].render_name}_nest;")
    cw.nl()
    cw.p(f"nlh = ynl_gemsg_start_dump(ys, {ri.nl.get_family_id()}, {ri.op.enum_name}, 1);")

    if "request" in ri.op[ri.op_mode]:
        request = ri.struct['request']
        cw.p(f"ys->req_policy = &{request.render_name}_nest;")
        cw.nl()
        for _, member in request.member_list():
            member.attr_put(ri, "req")
    cw.nl()

    for stmt in ('err = ynl_exec_dump(ys, nlh, &yds);',
                 'if (err < 0)',
                 'goto free_list;'):
        cw.p(stmt)
    cw.nl()

    cw.p('return yds.first;')
    cw.nl()
    cw.p('free_list:')
    cw.p(call_free(ri, rdir(direction), 'yds.first'))
    cw.p('return NULL;')
    cw.block_end()
1645
1646
def call_free(ri, direction, var):
    """Return the C statement calling the '<prefix>_free' function on `var`."""
    return f"{op_prefix(ri, direction)}_free({var});"
1649
1650
def free_arg_name(direction):
    """Return the argument name used by free functions for `direction`.

    That is the direction's suffix without its leading underscore, or 'obj'
    when no direction is given.
    """
    return direction_to_suffix[direction][1:] if direction else 'obj'
1655
1656
def print_alloc_wrapper(ri, direction):
    """Print a static inline '<prefix>_alloc' returning a zeroed struct."""
    cw = ri.cw
    name = op_prefix(ri, direction)
    struct_type = f'struct {name}'

    cw.write_func_prot(f'static inline {struct_type} *', f"{name}_alloc", ["void"])
    cw.block_start()
    cw.p(f'return calloc(1, sizeof({struct_type}));')
    cw.block_end()
1663
1664
def print_free_prototype(ri, direction, suffix=';'):
    """Print the prototype of the '<prefix>_free' function.

    The struct name gets a trailing '_' when `ri` flags a type name
    conflict.
    """
    name = op_prefix(ri, direction)
    struct_name = name + ('_' if ri.type_name_conflict else '')
    arg_decl = f"struct {struct_name} *{free_arg_name(direction)}"
    ri.cw.write_func_prot('void', f"{name}_free", [arg_decl], suffix=suffix)
1672
1673
def _print_type(ri, direction, struct):
    """Print the C struct definition for `struct`.

    The definition starts with an anonymous '_present' struct holding the
    members' presence fields (when there are any), followed by the
    inherited '__u32' members and the members themselves.
    """
    cw = ri.cw

    suffix = f'_{ri.type_name}{direction_to_suffix[direction]}'
    if not direction and ri.type_name_conflict:
        suffix += '_'
    if ri.op_mode == 'dump':
        suffix += '_dump'

    cw.block_start(line=f"struct {ri.family['name']}{suffix}")

    # The presence block is only opened once a member asks for a field
    has_presence = False
    for _, member in struct.member_list():
        for type_filter in ['len', 'bit']:
            line = member.presence_member(ri.ku_space, type_filter)
            if not line:
                continue
            if not has_presence:
                cw.block_start(line="struct")
                has_presence = True
            cw.p(line)
    if has_presence:
        cw.block_end(line='_present;')
        cw.nl()

    for inherited in struct.inherited:
        cw.p(f"__u32 {inherited};")

    for _, member in struct.member_list():
        member.struct_member(ri)

    cw.block_end(line=';')
    cw.nl()
1705
1706
def print_type(ri, direction):
    """Print the struct definition `ri` holds for `direction`."""
    _print_type(ri, direction, ri.struct[direction])
1709
1710
def print_type_full(ri, struct):
    """Print the definition of `struct` with no direction suffix in its name."""
    _print_type(ri, "", struct)
1713
1714
def print_type_helpers(ri, direction, deref=False):
    """Print the free prototype and, for user-space requests, the setters."""
    cw = ri.cw

    print_free_prototype(ri, direction)
    cw.nl()

    wants_setters = ri.ku_space == 'user' and direction == 'request'
    members = ri.struct[direction].member_list() if wants_setters else []
    for _, member in members:
        member.setter(ri, ri.attr_set, direction, deref=deref)
    cw.nl()
1723
1724
def print_req_type_helpers(ri):
    """Print the alloc wrapper followed by the other request type helpers."""
    print_alloc_wrapper(ri, "request")
    print_type_helpers(ri, "request")
1728
1729
def print_rsp_type_helpers(ri):
    """Write the reply type helpers, if the current op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        print_type_helpers(ri, "reply")
1734
1735
def print_parse_prototype(ri, direction, terminate=True):
    """Write the prototype of the ``<op>_{rsp,req}_parse()`` function.

    ``direction == "reply"`` selects the ``_rsp`` flavour, anything else
    ``_req``.  The prototype ends with ``;`` when ``terminate`` is true and
    with nothing otherwise.
    """
    if direction == "reply":
        flavour = "_rsp"
    else:
        flavour = "_req"
    struct_name = f"{ri.op.render_name}{flavour}"

    params = ['const struct nlattr **tb', f"struct {struct_name} *req"]
    ri.cw.write_func_prot('void', f"{struct_name}_parse", params,
                          suffix=';' if terminate else '')
1744
1745
def print_req_type(ri):
    """Render the request struct of the operation described by ``ri``."""
    direction = "request"
    print_type(ri, direction)
1748
1749
def print_req_free(ri):
    """Write the request free function, if the op mode has a request."""
    if 'request' in ri.op[ri.op_mode]:
        _free_type(ri, 'request', ri.struct['request'])
1754
1755
def print_rsp_type(ri):
    """Render the reply struct when the current op mode produces one.

    ``do`` and ``dump`` render it only if the mode declares a ``reply``;
    ``event`` always renders it; every other mode renders nothing.
    """
    mode = ri.op_mode
    if mode == 'do' or mode == 'dump':
        has_reply = 'reply' in ri.op[mode]
    else:
        has_reply = mode == 'event'

    if has_reply:
        print_type(ri, 'reply')
1764
1765
def print_wrapped_type(ri):
    """Render the wrapper struct around a reply object plus its free prototype.

    In dump mode the wrapper starts with a ``next`` pointer of its own type;
    in notify/event mode it starts with ``family``, ``cmd``, a
    ``struct ynl_ntf_base_type *next`` and a ``free`` callback.  The wrapped
    object itself is always the last member, named ``obj`` and aligned to 8.
    """
    cw = ri.cw
    wrapper = type_name(ri, 'reply')

    cw.block_start(line=f"{wrapper}")
    if ri.op_mode == 'dump':
        header = [f"{wrapper} *next;"]
    elif ri.op_mode in ('notify', 'event'):
        header = ['__u16 family;',
                  '__u8 cmd;',
                  'struct ynl_ntf_base_type *next;',
                  f"void (*free)({wrapper} *ntf);"]
    else:
        header = []
    for member in header:
        cw.p(member)
    cw.p(f"{type_name(ri, 'reply', deref=True)} obj __attribute__ ((aligned (8)));")
    cw.block_end(line=';')
    cw.nl()

    print_free_prototype(ri, 'reply')
    cw.nl()
1780
1781
def _free_type_members_iter(ri, struct):
    """Declare the loop counter ``i`` if any member's free code needs one.

    Members are queried in order and the scan stops at the first one whose
    ``free_needs_iter()`` is true; the declaration is followed by a blank
    line.  Nothing is written when no member needs the counter.
    """
    if any(attr.free_needs_iter() for _, attr in struct.member_list()):
        ri.cw.p('unsigned int i;')
        ri.cw.nl()
1788
1789
def _free_type_members(ri, var, struct, ref=''):
    """Ask every member of ``struct`` to emit its free code.

    ``var`` and ``ref`` are passed straight to each attribute's ``free()``;
    callers in this file use ``ref='obj.'`` when the object is embedded in a
    wrapper struct.
    """
    for _name, member in struct.member_list():
        member.free(ri, var, ref)
1793
1794
def _free_type(ri, direction, struct):
    """Write the complete ``*_free()`` function for ``struct``.

    The body frees each member in turn; the final ``free()`` of the object
    itself is only emitted when ``direction`` is non-empty (direction-less
    types, as passed by ``free_rsp_nested()``, skip it).
    """
    cw = ri.cw
    arg = free_arg_name(direction)

    print_free_prototype(ri, direction, suffix='')
    cw.block_start()
    _free_type_members_iter(ri, struct)
    _free_type_members(ri, arg, struct)
    if direction:
        cw.p(f'free({arg});')
    cw.block_end()
    cw.nl()
1806
1807
def free_rsp_nested(ri, struct):
    """Write the free function of a nested type (no direction)."""
    no_direction = ""
    _free_type(ri, no_direction, struct)
1810
1811
def print_rsp_free(ri):
    """Write the reply free function, if the op mode has a reply."""
    if 'reply' in ri.op[ri.op_mode]:
        _free_type(ri, 'reply', ri.struct['reply'])
1816
1817
def print_dump_type_free(ri):
    """Write the free function for a dump reply list.

    The generated C walks the list with a ``next`` cursor until it hits
    ``YNL_LIST_END``, freeing the members of each entry's ``obj`` and then
    the entry itself.
    """
    cw = ri.cw
    reply = ri.struct['reply']
    entry_type = type_name(ri, 'reply')

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    cw.p(f"{entry_type} *next = rsp;")
    cw.nl()

    cw.block_start(line='while ((void *)next != YNL_LIST_END)')
    # The optional loop counter is declared inside the while body.
    _free_type_members_iter(ri, reply)
    cw.p('rsp = next;')
    cw.p('next = rsp->next;')
    cw.nl()

    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()

    cw.block_end()
    cw.nl()
1836
1837
def print_ntf_type_free(ri):
    """Write the free function for a notification wrapper.

    Members of the embedded ``obj`` are freed first, then ``rsp`` itself.
    """
    cw = ri.cw
    reply = ri.struct['reply']

    print_free_prototype(ri, 'reply', suffix='')
    cw.block_start()
    _free_type_members_iter(ri, reply)
    _free_type_members(ri, 'rsp', reply, ref='obj.')
    cw.p('free(rsp);')
    cw.block_end()
    cw.nl()
1846
1847
def print_req_policy_fwd(cw, struct, ri=None, terminate=True):
    """Write the first line of an ``nla_policy`` array: declaration or opener.

    With ``terminate`` true an ``extern`` declaration ending in ``;`` is
    written, unless the policy belongs to an operation (``ri`` given) and
    ``policy_should_be_static()`` holds, in which case nothing is written.
    With ``terminate`` false the line opens the definition (``= {``) and is
    prefixed with ``static`` under that same condition.

    The array is named after the operation when ``ri`` is given (with the
    op mode appended for dual-policy ops) and after ``struct`` otherwise.
    """
    is_static = bool(ri) and policy_should_be_static(struct.family)

    if terminate:
        if is_static:
            # Static policies are not visible outside the source file.
            return
        prefix, suffix = 'extern ', ';'
    else:
        prefix = 'static ' if is_static else ''
        suffix = ' = {'

    max_attr = struct.attr_max_val
    if not ri:
        name = struct.render_name
    else:
        name = ri.op.render_name
        if ri.op.dual_policy:
            name += '_' + ri.op_mode
    cw.p(f"{prefix}const struct nla_policy {name}_nl_policy[{max_attr.enum_name} + 1]{suffix}")
1870
1871
def print_req_policy(cw, struct, ri=None):
    """Write a complete ``nla_policy`` array definition for ``struct``.

    The opening line comes from ``print_req_policy_fwd()``; each member then
    contributes its own entry via ``attr_policy()`` before the array is
    closed and a blank line written.
    """
    print_req_policy_fwd(cw, struct, ri=ri, terminate=False)
    for _name, member in struct.member_list():
        member.attr_policy(cw)
    cw.p("};")
    cw.nl()
1878
1879
def kernel_can_gen_family_struct(family):
    """Return True when the family's protocol is ``'genetlink'``."""
    protocol = family.proto
    return protocol == 'genetlink'
1882
1883
def policy_should_be_static(family):
    """Return True for split-policy families or when the family struct is generated."""
    if family.kernel_policy == 'split':
        return True
    return kernel_can_gen_family_struct(family)
1886
1887
def print_kernel_op_table_fwd(family, cw, terminate):
    """Emit the opening of the kernel ops table, or its header declarations.

    With ``terminate`` false only the ``<family>_nl_ops[...] =`` line that
    opens the table initializer is written.  With ``terminate`` true the
    table is declared ``extern`` (only when
    ``kernel_can_gen_family_struct()`` is false for the family), followed by
    prototypes for the pre/post hooks and for the doit/dumpit handler of
    every non-async operation.
    """
    exported = not kernel_can_gen_family_struct(family)

    if exported or not terminate:
        cw.p(f"/* Ops table for {family.name} */")

        struct_type = {'global': 'genl_small_ops',
                       'per-op': 'genl_ops',
                       'split': 'genl_split_ops'}[family.kernel_policy]

        # An exported table needs an explicit element count in its type;
        # a static one leaves the brackets empty.
        if not exported:
            qual, cnt = 'static const', ""
        elif family.kernel_policy == 'split':
            qual = 'const'
            # Split ops get one table entry per do and per dump.
            cnt = sum(('do' in op) + ('dump' in op) for op in family.ops.values())
        else:
            qual = 'const'
            cnt = len(family.ops)

        decl = f"{qual} struct {struct_type} {family.name}_nl_ops[{cnt}]"
        if terminate:
            cw.p(f"extern {decl};")
        else:
            cw.block_start(line=decl + ' =')

    if not terminate:
        return

    cw.nl()
    doit_hook_args = ['const struct genl_split_ops *ops',
                      'struct sk_buff *skb', 'struct genl_info *info']
    dump_hook_args = ['struct netlink_callback *cb']
    hook_protos = (('pre', 'do', 'int', doit_hook_args),
                   ('post', 'do', 'void', doit_hook_args),
                   ('pre', 'dump', 'int', dump_hook_args),
                   ('post', 'dump', 'int', dump_hook_args))
    for stage, mode, ret_type, args in hook_protos:
        for hook in family.hooks[stage][mode]['list']:
            # Hand each call its own list so the templates stay untouched.
            cw.write_func_prot(ret_type, c_lower(hook), list(args), suffix=';')

    cw.nl()

    handler_protos = (('do', ['struct sk_buff *skb', 'struct genl_info *info']),
                      ('dump', ['struct sk_buff *skb', 'struct netlink_callback *cb']))
    for op_name, op in family.ops.items():
        if op.is_async:
            continue

        for mode, args in handler_protos:
            if mode in op:
                handler = c_lower(f"{family.name}-nl-{op_name}-{mode}it")
                cw.write_func_prot('int', handler, list(args), suffix=';')
    cw.nl()
1953
1954
def print_kernel_op_table_hdr(family, cw):
    """Write the header-side declarations for the kernel ops table."""
    in_header = True
    print_kernel_op_table_fwd(family, cw, terminate=in_header)
1957
1958
def print_kernel_op_table(family, cw):
    """Write the initializer of the kernel ops table for ``family``.

    The table is opened by ``print_kernel_op_table_fwd()``.  For the
    ``global`` and ``per-op`` policies one entry is written per non-async
    operation; for the ``split`` policy one entry is written per non-async
    operation and per mode (``do`` / ``dump``) it implements.  Each entry is
    a list of ``(member, value)`` pairs handed to ``cw.write_struct_init()``.
    """
    print_kernel_op_table_fwd(family, cw, terminate=False)
    if family.kernel_policy == 'global' or family.kernel_policy == 'per-op':
        for op_name, op in family.ops.items():
            if op.is_async:
                continue

            cw.block_start()
            members = [('cmd', op.enum_name)]
            if 'dont-validate' in op:
                # OR together GENL_DONT_VALIDATE_* flags named in the spec.
                members.append(('validate',
                                ' | '.join([c_upper('genl-dont-validate-' + x)
                                            for x in op['dont-validate']])), )
            for op_mode in ['do', 'dump']:
                if op_mode in op:
                    name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                    members.append((op_mode + 'it', name))
            if family.kernel_policy == 'per-op':
                # NOTE(review): reads op['do']['request'] unconditionally, so
                # this presumably expects every per-op operation to have a
                # 'do' request -- confirm against the spec schema.
                struct = Struct(family, op['attribute-set'],
                                type_list=op['do']['request']['attributes'])

                name = c_lower(f"{family.name}-{op_name}-nl-policy")
                members.append(('policy', name))
                members.append(('maxattr', struct.attr_max_val.enum_name))
            if 'flags' in op:
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in op['flags']])))
            cw.write_struct_init(members)
            cw.block_end(line=',')
    elif family.kernel_policy == 'split':
        # Names of the split-ops members that hold the pre/post hooks.
        cb_names = {'do':   {'pre': 'pre_doit', 'post': 'post_doit'},
                    'dump': {'pre': 'start', 'post': 'done'}}

        for op_name, op in family.ops.items():
            for op_mode in ['do', 'dump']:
                if op.is_async or op_mode not in op:
                    continue

                cw.block_start()
                members = [('cmd', op.enum_name)]
                if 'dont-validate' in op:
                    # Keep only the flags that apply to this mode: the dump
                    # flags are dropped for 'do', 'strict' is dropped for
                    # 'dump'.
                    dont_validate = []
                    for x in op['dont-validate']:
                        if op_mode == 'do' and x in ['dump', 'dump-strict']:
                            continue
                        if op_mode == "dump" and x == 'strict':
                            continue
                        dont_validate.append(x)

                    if dont_validate:
                        members.append(('validate',
                                        ' | '.join([c_upper('genl-dont-validate-' + x)
                                                    for x in dont_validate])), )
                name = c_lower(f"{family.name}-nl-{op_name}-{op_mode}it")
                # Members are appended in the order pre hook, handler, post
                # hook, policy, flags.
                if 'pre' in op[op_mode]:
                    members.append((cb_names[op_mode]['pre'], c_lower(op[op_mode]['pre'])))
                members.append((op_mode + 'it', name))
                if 'post' in op[op_mode]:
                    members.append((cb_names[op_mode]['post'], c_lower(op[op_mode]['post'])))
                if 'request' in op[op_mode]:
                    struct = Struct(family, op['attribute-set'],
                                    type_list=op[op_mode]['request']['attributes'])

                    # Dual-policy ops have a separate policy per mode.
                    if op.dual_policy:
                        name = c_lower(f"{family.name}-{op_name}-{op_mode}-nl-policy")
                    else:
                        name = c_lower(f"{family.name}-{op_name}-nl-policy")
                    members.append(('policy', name))
                    members.append(('maxattr', struct.attr_max_val.enum_name))
                # Every split entry is tagged with GENL_CMD_CAP_DO or
                # GENL_CMD_CAP_DUMP on top of the flags from the spec.
                flags = (op['flags'] if 'flags' in op else []) + ['cmd-cap-' + op_mode]
                members.append(('flags', ' | '.join([c_upper('genl-' + x) for x in flags])))
                cw.write_struct_init(members)
                cw.block_end(line=',')

    cw.block_end(line=';')
    cw.nl()
2034
2035
def print_kernel_mcgrp_hdr(family, cw):
    """Write the enum of multicast group ids; nothing if there are no groups.

    Each enumerator is ``<FAMILY>_NLGRP_<GROUP>`` (upper-cased, dashes
    replaced by underscores).
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start('enum')
    for grp in groups:
        cw.p(c_upper(f"{family.name}-nlgrp-{grp['name']}") + ',')
    cw.block_end(';')
    cw.nl()
2046
2047
def print_kernel_mcgrp_src(family, cw):
    """Write the ``genl_multicast_group`` array; nothing if there are no groups.

    Entries use designated initializers indexed by the
    ``<FAMILY>_NLGRP_<GROUP>`` id and carry the group name as a string.
    """
    groups = family.mcgrps['list']
    if not groups:
        return

    cw.block_start(f'static const struct genl_multicast_group {family.name}_nl_mcgrps[] =')
    for grp in groups:
        grp_name = grp['name']
        index = c_upper(f"{family.name}-nlgrp-{grp_name}")
        cw.p('[' + index + '] = { "' + grp_name + '", },')
    cw.block_end(';')
    cw.nl()
2059
2060
def print_kernel_family_struct_hdr(family, cw):
    """Declare ``<family>_nl_family`` as extern when it can be generated."""
    if kernel_can_gen_family_struct(family):
        cw.p(f"extern struct genl_family {family.name}_nl_family;")
        cw.nl()
2067
2068
def print_kernel_family_struct_src(family, cw):
    """Write the ``struct genl_family`` definition when it can be generated.

    The ops members are only filled in for the ``per-op`` and ``split``
    policies, and the multicast members only when the family has groups.
    """
    if not kernel_can_gen_family_struct(family):
        return

    name = family.name
    fields = [
        '.name\t\t= ' + family.fam_key + ',',
        '.version\t= ' + family.ver_key + ',',
        '.netnsok\t= true,',
        '.parallel_ops\t= true,',
        '.module\t\t= THIS_MODULE,',
    ]
    if family.kernel_policy == 'per-op':
        fields.append(f'.ops\t\t= {name}_nl_ops,')
        fields.append(f'.n_ops\t\t= ARRAY_SIZE({name}_nl_ops),')
    elif family.kernel_policy == 'split':
        fields.append(f'.split_ops\t= {name}_nl_ops,')
        fields.append(f'.n_split_ops\t= ARRAY_SIZE({name}_nl_ops),')
    if family.mcgrps['list']:
        fields.append(f'.mcgrps\t\t= {name}_nl_mcgrps,')
        fields.append(f'.n_mcgrps\t= ARRAY_SIZE({name}_nl_mcgrps),')

    cw.block_start(f"struct genl_family {name}_nl_family __ro_after_init =")
    for field in fields:
        cw.p(field)
    cw.block_end(';')
2089
2090
def uapi_enum_start(family, cw, obj, ckey='', enum_name='enum-name'):
    """Open a C enum block, choosing its tag from ``obj``.

    If ``obj`` has the ``enum_name`` key, its value is used as the tag; an
    empty or false value yields an anonymous enum.  Otherwise, if ``ckey`` is
    given and present in ``obj``, the tag is the family name joined to that
    value.  With neither key the enum is anonymous.
    """
    if enum_name not in obj:
        if ckey and ckey in obj:
            header = 'enum ' + family.name + '_' + c_lower(obj[ckey])
        else:
            header = 'enum'
    elif obj[enum_name]:
        header = 'enum ' + c_lower(obj[enum_name])
    else:
        # Key present but empty: explicitly anonymous, ckey is not consulted.
        header = 'enum'
    cw.block_start(line=header)
2099
2100
def render_uapi(family, cw):
    """Write the uAPI header for ``family`` through the code writer ``cw``.

    The output consists of, in order: the include guard, the family name and
    version defines, the spec's definitions (consts as ``#define``s, enums
    and flags as C enums with optional kdoc), one enum per attribute set
    that is not a subset of another, the command enum(s), the multicast
    group defines, and the closing ``#endif``.
    """
    hdr_prot = f"_UAPI_LINUX_{family.name.upper()}_H"
    cw.p('#ifndef ' + hdr_prot)
    cw.p('#define ' + hdr_prot)
    cw.nl()

    defines = [(family.fam_key, family["name"]),
               (family.ver_key, family.get('version', 1))]
    cw.writes_defines(defines)
    cw.nl()

    # Consecutive 'const' definitions are batched and written together; the
    # batch is flushed whenever a definition of another type comes up.
    defines = []
    for const in family['definitions']:
        if const['type'] != 'const':
            cw.writes_defines(defines)
            defines = []
            cw.nl()

        # Write kdoc for enum and flags (one day maybe also structs)
        if const['type'] == 'enum' or const['type'] == 'flags':
            enum = family.consts[const['name']]

            if enum.has_doc():
                cw.p('/**')
                doc = ''
                if 'doc' in enum:
                    doc = ' - ' + enum['doc']
                cw.write_doc_line(enum.enum_name + doc)
                for entry in enum.entries.values():
                    if entry.has_doc():
                        doc = '@' + entry.c_name + ': ' + entry['doc']
                        cw.write_doc_line(doc)
                cw.p(' */')

            uapi_enum_start(family, cw, const, 'name')
            name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
            for entry in enum.entries.values():
                suffix = ','
                # An explicit value is only written for entries flagged with
                # value_change.
                if entry.value_change:
                    suffix = f" = {entry.user_value()}" + suffix
                cw.p(entry.c_name + suffix)

            if const.get('render-max', False):
                cw.nl()
                cw.p('/* private: */')
                if const['type'] == 'flags':
                    # Flags get a <PREFIX>MASK entry holding enum.get_mask().
                    max_name = c_upper(name_pfx + 'mask')
                    max_val = f' = {enum.get_mask()},'
                    cw.p(max_name + max_val)
                else:
                    # Enums get a __<PREFIX>MAX counter entry and a
                    # <PREFIX>MAX entry defined as counter - 1.
                    max_name = c_upper(name_pfx + 'max')
                    cw.p('__' + max_name + ',')
                    cw.p(max_name + ' = (__' + max_name + ' - 1)')
            cw.block_end(line=';')
            cw.nl()
        elif const['type'] == 'const':
            # NOTE(review): 'c-define-name' is looked up on the family, not
            # on this definition -- confirm that is the intended source.
            defines.append([c_upper(family.get('c-define-name',
                                               f"{family.name}-{const['name']}")),
                            const['value']])

    # Flush consts that were not followed by another kind of definition.
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    # When set, the *_MAX names are emitted as #defines after the enum
    # instead of as enumerators inside it.
    max_by_define = family.get('max-by-define', False)

    for _, attr_set in family.attr_sets.items():
        if attr_set.subset_of:
            continue

        cnt_name = c_upper(family.get('attr-cnt-name', f"__{attr_set.name_prefix}MAX"))
        max_value = f"({cnt_name} - 1)"

        # val is the value the next enumerator would get without an explicit
        # initializer; "= N" is written only where the attribute differs.
        val = 0
        uapi_enum_start(family, cw, attr_set.yaml, 'enum-name')
        for _, attr in attr_set.items():
            suffix = ','
            if attr.value != val:
                suffix = f" = {attr.value},"
                val = attr.value
            val += 1
            cw.p(attr.enum_name + suffix)
        cw.nl()
        cw.p(cnt_name + ('' if max_by_define else ','))
        if not max_by_define:
            cw.p(f"{attr_set.max_name} = {max_value}")
        cw.block_end(line=';')
        if max_by_define:
            cw.p(f"#define {attr_set.max_name} {max_value}")
        cw.nl()

    # Commands
    # With an 'async-prefix' the notifications and events are kept out of
    # the main command enum and rendered in a second enum below.
    separate_ntf = 'async-prefix' in family['operations']

    max_name = c_upper(family.get('cmd-max-name', f"{family.op_prefix}MAX"))
    cnt_name = c_upper(family.get('cmd-cnt-name', f"__{family.op_prefix}MAX"))
    max_value = f"({cnt_name} - 1)"

    uapi_enum_start(family, cw, family['operations'], 'enum-name')
    # Same implicit-value tracking as for the attribute enums above.
    val = 0
    for op in family.msgs.values():
        if separate_ntf and ('notify' in op or 'event' in op):
            continue

        suffix = ','
        if op.value != val:
            suffix = f" = {op.value},"
            val = op.value
        cw.p(op.enum_name + suffix)
        val += 1
    cw.nl()
    cw.p(cnt_name + ('' if max_by_define else ','))
    if not max_by_define:
        cw.p(f"{max_name} = {max_value}")
    cw.block_end(line=';')
    if max_by_define:
        cw.p(f"#define {max_name} {max_value}")
    cw.nl()

    if separate_ntf:
        uapi_enum_start(family, cw, family['operations'], enum_name='async-enum')
        for op in family.msgs.values():
            if separate_ntf and not ('notify' in op or 'event' in op):
                continue

            suffix = ','
            # Here an explicit value is written only if the spec gives one.
            if 'value' in op:
                suffix = f" = {op['value']},"
            cw.p(op.enum_name + suffix)
        cw.block_end(line=';')
        cw.nl()

    # Multicast
    defines = []
    for grp in family.mcgrps['list']:
        name = grp['name']
        defines.append([c_upper(grp.get('c-define-name', f"{family.name}-mcgrp-{name}")),
                        f'{name}'])
    cw.nl()
    if defines:
        cw.writes_defines(defines)
        cw.nl()

    cw.p(f'#endif /* {hdr_prot} */')
2245
2246
def _render_user_ntf_entry(ri, op):
    """Write one designated entry of the notification info table.

    The entry is indexed by ``op.enum_name`` and fills in the allocation
    size, the parse callback, the policy nest and the free callback.
    """
    cw = ri.cw

    cw.block_start(line=f"[{op.enum_name}] = ")
    fields = (
        f".alloc_sz\t= sizeof({type_name(ri, 'event')}),",
        f".cb\t\t= {op_prefix(ri, 'reply', deref=True)}_parse,",
        f".policy\t\t= &{ri.struct['reply'].render_name}_nest,",
        f".free\t\t= (void *){op_prefix(ri, 'notify')}_free,",
    )
    for field in fields:
        cw.p(field)
    cw.block_end(line=',')
2254
2255
def render_user_family(family, cw, prototype):
    """Write the user-space ``ynl_family`` object, or just its declaration.

    With ``prototype`` true only the ``extern`` declaration is written.
    Otherwise the notification info table is written first (when the family
    has notifications) and then the family definition that refers to it.

    Raises:
        Exception: a notification entry has neither ``notify`` nor ``event``.
    """
    symbol = f'const struct ynl_family ynl_{family.c_name}_family'
    if prototype:
        cw.p(f'extern {symbol};')
        return

    if family.ntfs:
        cw.block_start(line=f"static const struct ynl_ntf_info {family['name']}_ntf_info[] = ")
        for ntf_name, ntf in family.ntfs.items():
            # A 'notify' entry is rendered from the operation it refers to;
            # an 'event' entry is rendered from itself.
            if 'notify' in ntf:
                source_op, mode = family.ops[ntf['notify']], "notify"
            elif 'event' in ntf:
                source_op, mode = ntf, "event"
            else:
                raise Exception('Invalid notification ' + ntf_name)
            _render_user_ntf_entry(RenderInfo(cw, family, "user", source_op, mode), ntf)
        for op in family.ops.values():
            if 'event' in op:
                _render_user_ntf_entry(RenderInfo(cw, family, "user", op, "event"), op)
        cw.block_end(line=";")
        cw.nl()

    cw.block_start(f'{symbol} = ')
    cw.p(f'.name\t\t= "{family.name}",')
    if family.ntfs:
        ntf_table = f"{family['name']}_ntf_info"
        cw.p(f".ntf_info\t= {ntf_table},")
        cw.p(f".ntf_info_size\t= MNL_ARRAY_SIZE({ntf_table}),")
    cw.block_end(line=';')
2287
2288
def find_kernel_root(full_path):
    """Locate the kernel tree that contains ``full_path``.

    Walks up from ``full_path`` one directory at a time until a directory
    holding a ``MAINTAINERS`` file is found.

    Returns:
        A ``(root, sub_path)`` tuple: the directory containing
        ``MAINTAINERS`` and the path of ``full_path`` relative to it.

    Raises:
        FileNotFoundError: no ancestor directory contains ``MAINTAINERS``
            (previously this case looped forever).
    """
    start_path = full_path
    sub_path = ''
    while True:
        # sub_path accumulates the components stripped so far; joining with
        # the (initially empty) tail leaves a trailing separator behind.
        sub_path = os.path.join(os.path.basename(full_path), sub_path)
        parent = os.path.dirname(full_path)
        if os.path.exists(os.path.join(parent, "MAINTAINERS")):
            # Drop the trailing separator left by the first join.
            return parent, sub_path[:-1]
        if parent == full_path:
            # dirname() has hit its fixed point ('/' or ''), so there is
            # nothing further up to search.
            raise FileNotFoundError(
                f"no MAINTAINERS file found in any directory above {start_path!r}")
        full_path = parent
2297
2298
2299 def main():
2300     parser = argparse.ArgumentParser(description='Netlink simple parsing generator')
2301     parser.add_argument('--mode', dest='mode', type=str, required=True)
2302     parser.add_argument('--spec', dest='spec', type=str, required=True)
2303     parser.add_argument('--header', dest='header', action='store_true', default=None)
2304     parser.add_argument('--source', dest='header', action='store_false')
2305     parser.add_argument('--user-header', nargs='+', default=[])
2306     parser.add_argument('--exclude-op', action='append', default=[])
2307     parser.add_argument('-o', dest='out_file', type=str)
2308     args = parser.parse_args()
2309
2310     tmp_file = tempfile.TemporaryFile('w+') if args.out_file else os.sys.stdout
2311
2312     if args.header is None:
2313         parser.error("--header or --source is required")
2314
2315     exclude_ops = [re.compile(expr) for expr in args.exclude_op]
2316
2317     try:
2318         parsed = Family(args.spec, exclude_ops)
2319         if parsed.license != '((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)':
2320             print('Spec license:', parsed.license)
2321             print('License must be: ((GPL-2.0 WITH Linux-syscall-note) OR BSD-3-Clause)')
2322             os.sys.exit(1)
2323     except yaml.YAMLError as exc:
2324         print(exc)
2325         os.sys.exit(1)
2326         return
2327
2328     supported_models = ['unified']
2329     if args.mode in ['user', 'kernel']:
2330         supported_models += ['directional']
2331     if parsed.msg_id_model not in supported_models:
2332         print(f'Message enum-model {parsed.msg_id_model} not supported for {args.mode} generation')
2333         os.sys.exit(1)
2334
2335     cw = CodeWriter(BaseNlLib(), tmp_file)
2336
2337     _, spec_kernel = find_kernel_root(args.spec)
2338     if args.mode == 'uapi' or args.header:
2339         cw.p(f'/* SPDX-License-Identifier: {parsed.license} */')
2340     else:
2341         cw.p(f'// SPDX-License-Identifier: {parsed.license}')
2342     cw.p("/* Do not edit directly, auto-generated from: */")
2343     cw.p(f"/*\t{spec_kernel} */")
2344     cw.p(f"/* YNL-GEN {args.mode} {'header' if args.header else 'source'} */")
2345     if args.exclude_op or args.user_header:
2346         line = ''
2347         line += ' --user-header '.join([''] + args.user_header)
2348         line += ' --exclude-op '.join([''] + args.exclude_op)
2349         cw.p(f'/* YNL-ARG{line} */')
2350     cw.nl()
2351
2352     if args.mode == 'uapi':
2353         render_uapi(parsed, cw)
2354         return
2355
2356     hdr_prot = f"_LINUX_{parsed.name.upper()}_GEN_H"
2357     if args.header:
2358         cw.p('#ifndef ' + hdr_prot)
2359         cw.p('#define ' + hdr_prot)
2360         cw.nl()
2361
2362     if args.mode == 'kernel':
2363         cw.p('#include <net/netlink.h>')
2364         cw.p('#include <net/genetlink.h>')
2365         cw.nl()
2366         if not args.header:
2367             if args.out_file:
2368                 cw.p(f'#include "{os.path.basename(args.out_file[:-2])}.h"')
2369             cw.nl()
2370         headers = ['uapi/' + parsed.uapi_header]
2371     else:
2372         cw.p('#include <stdlib.h>')
2373         cw.p('#include <string.h>')
2374         if args.header:
2375             cw.p('#include <linux/types.h>')
2376         else:
2377             cw.p(f'#include "{parsed.name}-user.h"')
2378             cw.p('#include "ynl.h"')
2379         headers = [parsed.uapi_header]
2380     for definition in parsed['definitions']:
2381         if 'header' in definition:
2382             headers.append(definition['header'])
2383     for one in headers:
2384         cw.p(f"#include <{one}>")
2385     cw.nl()
2386
2387     if args.mode == "user":
2388         if not args.header:
2389             cw.p("#include <libmnl/libmnl.h>")
2390             cw.p("#include <linux/genetlink.h>")
2391             cw.nl()
2392             for one in args.user_header:
2393                 cw.p(f'#include "{one}"')
2394         else:
2395             cw.p('struct ynl_sock;')
2396             cw.nl()
2397             render_user_family(parsed, cw, True)
2398         cw.nl()
2399
2400     if args.mode == "kernel":
2401         if args.header:
2402             for _, struct in sorted(parsed.pure_nested_structs.items()):
2403                 if struct.request:
2404                     cw.p('/* Common nested types */')
2405                     break
2406             for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2407                 if struct.request:
2408                     print_req_policy_fwd(cw, struct)
2409             cw.nl()
2410
2411             if parsed.kernel_policy == 'global':
2412                 cw.p(f"/* Global operation policy for {parsed.name} */")
2413
2414                 struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2415                 print_req_policy_fwd(cw, struct)
2416                 cw.nl()
2417
2418             if parsed.kernel_policy in {'per-op', 'split'}:
2419                 for op_name, op in parsed.ops.items():
2420                     if 'do' in op and 'event' not in op:
2421                         ri = RenderInfo(cw, parsed, args.mode, op, "do")
2422                         print_req_policy_fwd(cw, ri.struct['request'], ri=ri)
2423                         cw.nl()
2424
2425             print_kernel_op_table_hdr(parsed, cw)
2426             print_kernel_mcgrp_hdr(parsed, cw)
2427             print_kernel_family_struct_hdr(parsed, cw)
2428         else:
2429             for _, struct in sorted(parsed.pure_nested_structs.items()):
2430                 if struct.request:
2431                     cw.p('/* Common nested types */')
2432                     break
2433             for attr_set, struct in sorted(parsed.pure_nested_structs.items()):
2434                 if struct.request:
2435                     print_req_policy(cw, struct)
2436             cw.nl()
2437
2438             if parsed.kernel_policy == 'global':
2439                 cw.p(f"/* Global operation policy for {parsed.name} */")
2440
2441                 struct = Struct(parsed, parsed.global_policy_set, type_list=parsed.global_policy)
2442                 print_req_policy(cw, struct)
2443                 cw.nl()
2444
2445             for op_name, op in parsed.ops.items():
2446                 if parsed.kernel_policy in {'per-op', 'split'}:
2447                     for op_mode in ['do', 'dump']:
2448                         if op_mode in op and 'request' in op[op_mode]:
2449                             cw.p(f"/* {op.enum_name} - {op_mode} */")
2450                             ri = RenderInfo(cw, parsed, args.mode, op, op_mode)
2451                             print_req_policy(cw, ri.struct['request'], ri=ri)
2452                             cw.nl()
2453
2454             print_kernel_op_table(parsed, cw)
2455             print_kernel_mcgrp_src(parsed, cw)
2456             print_kernel_family_struct_src(parsed, cw)
2457
2458     if args.mode == "user":
2459         if args.header:
2460             cw.p('/* Enums */')
2461             put_op_name_fwd(parsed, cw)
2462
2463             for name, const in parsed.consts.items():
2464                 if isinstance(const, EnumSet):
2465                     put_enum_to_str_fwd(parsed, cw, const)
2466             cw.nl()
2467
2468             cw.p('/* Common nested types */')
2469             for attr_set, struct in parsed.pure_nested_structs.items():
2470                 ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
2471                 print_type_full(ri, struct)
2472
2473             for op_name, op in parsed.ops.items():
2474                 cw.p(f"/* ============== {op.enum_name} ============== */")
2475
2476                 if 'do' in op and 'event' not in op:
2477                     cw.p(f"/* {op.enum_name} - do */")
2478                     ri = RenderInfo(cw, parsed, args.mode, op, "do")
2479                     print_req_type(ri)
2480                     print_req_type_helpers(ri)
2481                     cw.nl()
2482                     print_rsp_type(ri)
2483                     print_rsp_type_helpers(ri)
2484                     cw.nl()
2485                     print_req_prototype(ri)
2486                     cw.nl()
2487
2488                 if 'dump' in op:
2489                     cw.p(f"/* {op.enum_name} - dump */")
2490                     ri = RenderInfo(cw, parsed, args.mode, op, 'dump')
2491                     if 'request' in op['dump']:
2492                         print_req_type(ri)
2493                         print_req_type_helpers(ri)
2494                     if not ri.type_consistent:
2495                         print_rsp_type(ri)
2496                     print_wrapped_type(ri)
2497                     print_dump_prototype(ri)
2498                     cw.nl()
2499
2500                 if op.has_ntf:
2501                     cw.p(f"/* {op.enum_name} - notify */")
2502                     ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
2503                     if not ri.type_consistent:
2504                         raise Exception(f'Only notifications with consistent types supported ({op.name})')
2505                     print_wrapped_type(ri)
2506
2507             for op_name, op in parsed.ntfs.items():
2508                 if 'event' in op:
2509                     ri = RenderInfo(cw, parsed, args.mode, op, 'event')
2510                     cw.p(f"/* {op.enum_name} - event */")
2511                     print_rsp_type(ri)
2512                     cw.nl()
2513                     print_wrapped_type(ri)
2514             cw.nl()
2515         else:
2516             cw.p('/* Enums */')
2517             put_op_name(parsed, cw)
2518
2519             for name, const in parsed.consts.items():
2520                 if isinstance(const, EnumSet):
2521                     put_enum_to_str(parsed, cw, const)
2522             cw.nl()
2523
2524             cw.p('/* Policies */')
2525             for name in parsed.pure_nested_structs:
2526                 struct = Struct(parsed, name)
2527                 put_typol(cw, struct)
2528             for name in parsed.root_sets:
2529                 struct = Struct(parsed, name)
2530                 put_typol(cw, struct)
2531
2532             cw.p('/* Common nested types */')
2533             for attr_set, struct in parsed.pure_nested_structs.items():
2534                 ri = RenderInfo(cw, parsed, args.mode, "", "", attr_set)
2535
2536                 free_rsp_nested(ri, struct)
2537                 if struct.request:
2538                     put_req_nested(ri, struct)
2539                 if struct.reply:
2540                     parse_rsp_nested(ri, struct)
2541
2542             for op_name, op in parsed.ops.items():
2543                 cw.p(f"/* ============== {op.enum_name} ============== */")
2544                 if 'do' in op and 'event' not in op:
2545                     cw.p(f"/* {op.enum_name} - do */")
2546                     ri = RenderInfo(cw, parsed, args.mode, op, "do")
2547                     print_req_free(ri)
2548                     print_rsp_free(ri)
2549                     parse_rsp_msg(ri)
2550                     print_req(ri)
2551                     cw.nl()
2552
2553                 if 'dump' in op:
2554                     cw.p(f"/* {op.enum_name} - dump */")
2555                     ri = RenderInfo(cw, parsed, args.mode, op, "dump")
2556                     if not ri.type_consistent:
2557                         parse_rsp_msg(ri, deref=True)
2558                     print_dump_type_free(ri)
2559                     print_dump(ri)
2560                     cw.nl()
2561
2562                 if op.has_ntf:
2563                     cw.p(f"/* {op.enum_name} - notify */")
2564                     ri = RenderInfo(cw, parsed, args.mode, op, 'notify')
2565                     if not ri.type_consistent:
2566                         raise Exception(f'Only notifications with consistent types supported ({op.name})')
2567                     print_ntf_type_free(ri)
2568
2569             for op_name, op in parsed.ntfs.items():
2570                 if 'event' in op:
2571                     cw.p(f"/* {op.enum_name} - event */")
2572
2573                     ri = RenderInfo(cw, parsed, args.mode, op, "do")
2574                     parse_rsp_msg(ri)
2575
2576                     ri = RenderInfo(cw, parsed, args.mode, op, "event")
2577                     print_ntf_type_free(ri)
2578             cw.nl()
2579             render_user_family(parsed, cw, False)
2580
2581     if args.header:
2582         cw.p(f'#endif /* {hdr_prot} */')
2583
2584     if args.out_file:
2585         out_file = open(args.out_file, 'w+')
2586         tmp_file.seek(0)
2587         shutil.copyfileobj(tmp_file, out_file)
2588
# Script entry point: invoke main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()