Merge 6.4-rc5 into usb-next
[platform/kernel/linux-starfive.git] / tools / net / ynl / lib / ynl.py
1 # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3 import functools
4 import os
5 import random
6 import socket
7 import struct
8 import yaml
9
10 from .nlspec import SpecFamily
11
12 #
13 # Generic Netlink code which should really be in some library, but I can't quickly find one.
14 #
15
16
class Netlink:
    """Namespace for the netlink / generic netlink numeric constants used by this module."""

    # Socket level and socket options, as passed to setsockopt()
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # Message types carried in the netlink message header
    NLMSG_ERROR = 0x2
    NLMSG_DONE = 0x3

    # Message header flags set on requests
    NLM_F_REQUEST = 0x001
    NLM_F_ACK = 0x004
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200
    NLM_F_APPEND = 0x800

    # Message header flags checked on replies (they share bit values with
    # the request flags above)
    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_MATCH | NLM_F_ROOT

    # Flag bits stored in the upper bits of an attribute's type field
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14

    # NOTE: this is the set of *flag* bits; users clear it from the raw
    # type value (type & ~NLA_TYPE_MASK) to obtain the attribute type.
    NLA_TYPE_MASK = NLA_F_NET_BYTEORDER | NLA_F_NESTED

    # Generic netlink protocol number and the id of its control family
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl: command and attribute ids
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extended ack (extack) attribute types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
68
69
class NlError(Exception):
    """Exception carrying the netlink message that reported a failure."""

    def __init__(self, nl_msg):
        # Keep the whole message so callers can inspect .error / .extack
        self.nl_msg = nl_msg

    def __str__(self):
        # The message stores the error negated, strerror() wants it positive
        reason = os.strerror(-self.nl_msg.error)
        return f"Netlink error: {reason}\n{self.nl_msg}"
76
77
class NlAttr:
    """
    A single netlink attribute decoded from a raw buffer.

    The 4 byte attribute header (length, type) is unpacked in native byte
    order. The payload is kept as bytes in `raw` and interpreted on demand
    by the as_*() helpers.
    """

    # Scalar type name -> (struct format character, size in bytes)
    type_formats = { 'u8' : ('B', 1), 's8' : ('b', 1),
                     'u16': ('H', 2), 's16': ('h', 2),
                     'u32': ('I', 4), 's32': ('i', 4),
                     'u64': ('Q', 8), 's64': ('q', 8) }

    def __init__(self, raw, offset):
        """Parse the attribute whose header starts at `offset` within `raw`."""
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        # Drop the flag bits (nested / network byte order) from the type
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        # NOTE: despite the name, this is the length taken from the header,
        # which counts the 4 byte header as well as the payload.
        self.payload_len = self._len
        # Space taken in the buffer: the length rounded up to 4 bytes
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @staticmethod
    def format_byte_order(byte_order):
        """
        Translate a byte order name into a struct format prefix.

        Returns ">" for "big-endian", "<" for any other truthy value and
        "" (native order) when no byte order is given.
        """
        if byte_order:
            return ">" if byte_order == "big-endian" else "<"
        return ""

    def as_u8(self):
        """Return the payload decoded as an unsigned 8 bit integer."""
        return struct.unpack("B", self.raw)[0]

    def as_u16(self, byte_order=None):
        """Return the payload decoded as an unsigned 16 bit integer."""
        endian = NlAttr.format_byte_order(byte_order)
        return struct.unpack(f"{endian}H", self.raw)[0]

    def as_u32(self, byte_order=None):
        """Return the payload decoded as an unsigned 32 bit integer."""
        endian = NlAttr.format_byte_order(byte_order)
        return struct.unpack(f"{endian}I", self.raw)[0]

    def as_u64(self, byte_order=None):
        """Return the payload decoded as an unsigned 64 bit integer."""
        endian = NlAttr.format_byte_order(byte_order)
        return struct.unpack(f"{endian}Q", self.raw)[0]

    def as_strz(self):
        """Return the payload as an ASCII string with its last byte (the terminator) dropped."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes unmodified."""
        return self.raw

    def as_c_array(self, type):
        """
        Decode the payload as a packed array of scalars of the given type.

        Returns a list with one entry per element, in payload order and
        with repeated values preserved.
        """
        fmt, _ = self.type_formats[type]
        # A list (not a set) is required here: arrays may legitimately hold
        # repeated values and the element order is meaningful.
        return [x[0] for x in struct.iter_unpack(fmt, self.raw)]

    def as_struct(self, members):
        """
        Decode the payload as a sequence of scalar members.

        Each entry of `members` must provide `.type` (a key of type_formats)
        and `.name`. Returns a dict mapping member name to decoded value.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            fmt, size = self.type_formats[m.type]
            decoded = struct.unpack_from(fmt, self.raw, offset)
            offset += size
            value[m.name] = decoded[0]
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
134
135
class NlAttrs:
    """An ordered collection of the attributes packed back to back in a buffer."""

    def __init__(self, msg):
        """
        Parse every attribute found in `msg`.

        Raises Exception when an attribute header carries a length smaller
        than the 4 byte header itself.
        """
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            # A length below the header size is malformed; a length of 0
            # would additionally never advance the offset and loop forever.
            if attr.payload_len < 4:
                raise Exception(f"Malformed netlink attribute at offset {offset}: "
                                f"length {attr.payload_len}")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)
156
157
class NlMsg:
    """
    One netlink message: the 16 byte header plus its payload.

    For error and done messages the extended ack attributes, when present,
    are decoded into the `extack` dict.
    """

    def __init__(self, msg, offset, attr_space=None):
        """
        Parse the message starting at `offset` within `msg`.

        `attr_space`, when given, is used to replace a numeric 'miss-type'
        extack with the name (and doc) of the missing attribute.
        """
        self.hdr = msg[offset:offset + 16]

        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16:offset + self.nl_len]

        self.error = 0
        self.done = 0
        self.extack = None

        # Offset of the extack attributes within the payload, if this
        # message type can carry any
        tlv_start = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            tlv_start = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            tlv_start = 4

        if tlv_start and self.nl_flags & Netlink.NLM_F_ACK_TLVS:
            self._parse_extack(tlv_start)
            if attr_space:
                self._describe_miss_type(attr_space)

    def _parse_extack(self, tlv_start):
        # Decode the extack attributes found from tlv_start onwards
        known = {
            Netlink.NLMSGERR_ATTR_MSG: ('msg', 'as_strz'),
            Netlink.NLMSGERR_ATTR_MISS_TYPE: ('miss-type', 'as_u32'),
            Netlink.NLMSGERR_ATTR_MISS_NEST: ('miss-nest', 'as_u32'),
            Netlink.NLMSGERR_ATTR_OFFS: ('bad-attr-offs', 'as_u32'),
        }

        self.extack = dict()
        for tlv in NlAttrs(self.raw[tlv_start:]):
            entry = known.get(tlv.type)
            if entry is None:
                # Keep attributes we can't interpret in their raw form
                self.extack.setdefault('unknown', []).append(tlv)
                continue
            key, decoder = entry
            self.extack[key] = getattr(tlv, decoder)()

    def _describe_miss_type(self, attr_space):
        # Swap the numeric 'miss-type' for the attribute's name and doc
        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return
        miss_type = self.extack['miss-type']
        if miss_type not in attr_space.attrs_by_val:
            return

        spec = attr_space.attrs_by_val[miss_type]
        desc = spec['name']
        if 'doc' in spec:
            desc += f" ({spec['doc']})"
        self.extack['miss-type'] = desc

    def __repr__(self):
        parts = [f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"]
        if self.error:
            parts.append('\terror: ' + str(self.error))
        if self.extack:
            parts.append('\textack: ' + repr(self.extack))
        return ''.join(parts)
215
216
class NlMsgs:
    """The sequence of netlink messages contained in one receive buffer."""

    def __init__(self, data, attr_space=None):
        """
        Split `data` into NlMsg objects.

        `attr_space` is handed to every NlMsg for extack decoding.
        Raises Exception when a message header reports a length smaller
        than the 16 byte header itself.
        """
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # The length covers the header, so anything below 16 is
            # malformed; a length of 0 would never advance the offset and
            # loop forever.
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message at offset {offset}: "
                                f"length {msg.nl_len}")
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
229
230
# Cache of the generic netlink families reported by the kernel, keyed by
# family name. Stays None until _genl_load_families() fills it in.
genl_family_name_to_id = None
232
233
def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """
    Build the headers of a generic netlink request, minus the length field.

    Returns the netlink header fields after the length (type, flags,
    sequence number, port id 0) followed by the 4 byte generic netlink
    header. When `seq` is None a random sequence number between 1 and 1024
    is used.
    """
    # we prepend length in _genl_msg_finalize()
    sequence = random.randint(1, 1024) if seq is None else seq
    # One pack call for both headers; every field lands on its natural
    # alignment so no padding is inserted between them.
    return struct.pack("HHIIBBH",
                       nl_type, nl_flags, sequence, 0,
                       genl_cmd, genl_version, 0)
241
242
def _genl_msg_finalize(msg):
    """Prepend the 32 bit total length (which counts itself) to a message built by _genl_msg()."""
    total_len = 4 + len(msg)
    length_field = struct.pack("I", total_len)
    return length_field + msg
245
246
def _genl_load_families():
    """
    Ask the generic netlink control family for all registered families.

    Fills the module level genl_family_name_to_id cache with one dict per
    family ('id', 'name', and when reported 'maxattr' and 'mcast').
    On a netlink error the error code is printed and loading stops, leaving
    whatever was collected so far in the cache.
    """
    global genl_family_name_to_id

    def parse_mcast_groups(raw):
        # Decode the nested list of multicast groups into {name: id}
        groups = dict()
        for entry in NlAttrs(raw):
            grp_name = None
            grp_id = None
            for field in NlAttrs(entry.raw):
                if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                    grp_name = field.as_strz()
                elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                    grp_id = field.as_u32()
            if grp_name and grp_id is not None:
                groups[grp_name] = grp_id
        return groups

    request_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        request = _genl_msg_finalize(_genl_msg(Netlink.GENL_ID_CTRL, request_flags,
                                               Netlink.CTRL_CMD_GETFAMILY, 1))
        sock.send(request, 0)

        # Only reset the cache once the request has actually been sent
        genl_family_name_to_id = dict()

        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                fam = dict()
                for attr in GenlMsg(nl_msg).raw_attrs:
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_u16()
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_u32()
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = parse_mcast_groups(attr.raw)

                # Families missing either identifier can't be looked up, skip them
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
294
295
class GenlMsg:
    """
    A generic netlink message layered on top of an already parsed netlink message.

    Splits the payload into the 4 byte generic netlink header, the optional
    fixed header members and the trailing attributes.
    """

    def __init__(self, nl_msg, fixed_header_members=[]):
        """
        Decode `nl_msg.raw`.

        Each entry of `fixed_header_members` must provide `.type` (a key of
        NlAttr.type_formats) and `.name`. The default list is only iterated,
        never modified.
        """
        self.nl = nl_msg

        payload = nl_msg.raw
        self.hdr = payload[0:4]
        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        # Fixed header members follow the generic netlink header directly
        cursor = 4
        self.fixed_header_attrs = dict()
        for member in fixed_header_members:
            fmt, width = NlAttr.type_formats[member.type]
            self.fixed_header_attrs[member.name] = struct.unpack_from(fmt, payload, cursor)[0]
            cursor += width

        # Whatever is left is the attribute stream
        self.raw = payload[cursor:]
        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        chunks = [repr(self.nl),
                  f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"]
        chunks.extend('\t\t' + repr(a) + '\n' for a in self.raw_attrs)
        return ''.join(chunks)
321
322
class GenlFamily:
    """Kernel-side details of one generic netlink family, looked up by name."""

    def __init__(self, family_name):
        """
        Look `family_name` up in the family cache, loading the cache first
        if that has not happened yet.

        Raises KeyError when the family is not in the cache.
        """
        self.family_name = family_name

        # The cache is filled once, on first use
        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']
333
334
335 #
336 # YNL implementation details.
337 #
338
339
340 class YnlFamily(SpecFamily):
341     def __init__(self, def_path, schema=None):
342         super().__init__(def_path, schema)
343
344         self.include_raw = False
345
346         self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
347         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
348         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
349
350         self.async_msg_ids = set()
351         self.async_msg_queue = []
352
353         for msg in self.msgs.values():
354             if msg.is_async:
355                 self.async_msg_ids.add(msg.rsp_value)
356
357         for op_name, op in self.ops.items():
358             bound_f = functools.partial(self._op, op_name)
359             setattr(self, op.ident_name, bound_f)
360
361         try:
362             self.family = GenlFamily(self.yaml['name'])
363         except KeyError:
364             raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
365
366     def ntf_subscribe(self, mcast_name):
367         if mcast_name not in self.family.genl_family['mcast']:
368             raise Exception(f'Multicast group "{mcast_name}" not present in the family')
369
370         self.sock.bind((0, 0))
371         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
372                              self.family.genl_family['mcast'][mcast_name])
373
374     def _add_attr(self, space, name, value):
375         attr = self.attr_sets[space][name]
376         nl_type = attr.value
377         if attr["type"] == 'nest':
378             nl_type |= Netlink.NLA_F_NESTED
379             attr_payload = b''
380             for subname, subvalue in value.items():
381                 attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
382         elif attr["type"] == 'flag':
383             attr_payload = b''
384         elif attr["type"] == 'u8':
385             attr_payload = struct.pack("B", int(value))
386         elif attr["type"] == 'u16':
387             endian = NlAttr.format_byte_order(attr.byte_order)
388             attr_payload = struct.pack(f"{endian}H", int(value))
389         elif attr["type"] == 'u32':
390             endian = NlAttr.format_byte_order(attr.byte_order)
391             attr_payload = struct.pack(f"{endian}I", int(value))
392         elif attr["type"] == 'u64':
393             endian = NlAttr.format_byte_order(attr.byte_order)
394             attr_payload = struct.pack(f"{endian}Q", int(value))
395         elif attr["type"] == 'string':
396             attr_payload = str(value).encode('ascii') + b'\x00'
397         elif attr["type"] == 'binary':
398             attr_payload = value
399         else:
400             raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
401
402         pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
403         return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
404
405     def _decode_enum(self, rsp, attr_spec):
406         raw = rsp[attr_spec['name']]
407         enum = self.consts[attr_spec['enum']]
408         i = attr_spec.get('value-start', 0)
409         if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
410             value = set()
411             while raw:
412                 if raw & 1:
413                     value.add(enum.entries_by_val[i].name)
414                 raw >>= 1
415                 i += 1
416         else:
417             value = enum.entries_by_val[raw - i].name
418         rsp[attr_spec['name']] = value
419
420     def _decode_binary(self, attr, attr_spec):
421         if attr_spec.struct_name:
422             decoded = attr.as_struct(self.consts[attr_spec.struct_name])
423         elif attr_spec.sub_type:
424             decoded = attr.as_c_array(attr_spec.sub_type)
425         else:
426             decoded = attr.as_bin()
427         return decoded
428
429     def _decode(self, attrs, space):
430         attr_space = self.attr_sets[space]
431         rsp = dict()
432         for attr in attrs:
433             attr_spec = attr_space.attrs_by_val[attr.type]
434             if attr_spec["type"] == 'nest':
435                 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
436                 decoded = subdict
437             elif attr_spec['type'] == 'u8':
438                 decoded = attr.as_u8()
439             elif attr_spec['type'] == 'u16':
440                 decoded = attr.as_u16(attr_spec.byte_order)
441             elif attr_spec['type'] == 'u32':
442                 decoded = attr.as_u32(attr_spec.byte_order)
443             elif attr_spec['type'] == 'u64':
444                 decoded = attr.as_u64(attr_spec.byte_order)
445             elif attr_spec["type"] == 'string':
446                 decoded = attr.as_strz()
447             elif attr_spec["type"] == 'binary':
448                 decoded = self._decode_binary(attr, attr_spec)
449             elif attr_spec["type"] == 'flag':
450                 decoded = True
451             else:
452                 raise Exception(f'Unknown {attr.type} {attr_spec["name"]} {attr_spec["type"]}')
453
454             if not attr_spec.is_multi:
455                 rsp[attr_spec['name']] = decoded
456             elif attr_spec.name in rsp:
457                 rsp[attr_spec.name].append(decoded)
458             else:
459                 rsp[attr_spec.name] = [decoded]
460
461             if 'enum' in attr_spec:
462                 self._decode_enum(rsp, attr_spec)
463         return rsp
464
465     def _decode_extack_path(self, attrs, attr_set, offset, target):
466         for attr in attrs:
467             attr_spec = attr_set.attrs_by_val[attr.type]
468             if offset > target:
469                 break
470             if offset == target:
471                 return '.' + attr_spec.name
472
473             if offset + attr.full_len <= target:
474                 offset += attr.full_len
475                 continue
476             if attr_spec['type'] != 'nest':
477                 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
478             offset += 4
479             subpath = self._decode_extack_path(NlAttrs(attr.raw),
480                                                self.attr_sets[attr_spec['nested-attributes']],
481                                                offset, target)
482             if subpath is None:
483                 return None
484             return '.' + attr_spec.name + subpath
485
486         return None
487
488     def _decode_extack(self, request, attr_space, extack):
489         if 'bad-attr-offs' not in extack:
490             return
491
492         genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space))
493         path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
494                                         20, extack['bad-attr-offs'])
495         if path:
496             del extack['bad-attr-offs']
497             extack['bad-attr'] = path
498
499     def handle_ntf(self, nl_msg, genl_msg):
500         msg = dict()
501         if self.include_raw:
502             msg['nlmsg'] = nl_msg
503             msg['genlmsg'] = genl_msg
504         op = self.rsp_by_value[genl_msg.genl_cmd]
505         msg['name'] = op['name']
506         msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
507         self.async_msg_queue.append(msg)
508
509     def check_ntf(self):
510         while True:
511             try:
512                 reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
513             except BlockingIOError:
514                 return
515
516             nms = NlMsgs(reply)
517             for nl_msg in nms:
518                 if nl_msg.error:
519                     print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
520                     print(nl_msg)
521                     continue
522                 if nl_msg.done:
523                     print("Netlink done while checking for ntf!?")
524                     continue
525
526                 gm = GenlMsg(nl_msg)
527                 if gm.genl_cmd not in self.async_msg_ids:
528                     print("Unexpected msg id done while checking for ntf", gm)
529                     continue
530
531                 self.handle_ntf(nl_msg, gm)
532
533     def operation_do_attributes(self, name):
534       """
535       For a given operation name, find and return a supported
536       set of attributes (as a dict).
537       """
538       op = self.find_operation(name)
539       if not op:
540         return None
541
542       return op['do']['request']['attributes'].copy()
543
544     def _op(self, method, vals, dump=False):
545         op = self.ops[method]
546
547         nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
548         if dump:
549             nl_flags |= Netlink.NLM_F_DUMP
550
551         req_seq = random.randint(1024, 65535)
552         msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
553         fixed_header_members = []
554         if op.fixed_header:
555             fixed_header_members = self.consts[op.fixed_header].members
556             for m in fixed_header_members:
557                 value = vals.pop(m.name)
558                 format, _ = NlAttr.type_formats[m.type]
559                 msg += struct.pack(format, value)
560         for name, value in vals.items():
561             msg += self._add_attr(op.attr_set.name, name, value)
562         msg = _genl_msg_finalize(msg)
563
564         self.sock.send(msg, 0)
565
566         done = False
567         rsp = []
568         while not done:
569             reply = self.sock.recv(128 * 1024)
570             nms = NlMsgs(reply, attr_space=op.attr_set)
571             for nl_msg in nms:
572                 if nl_msg.extack:
573                     self._decode_extack(msg, op.attr_set, nl_msg.extack)
574
575                 if nl_msg.error:
576                     raise NlError(nl_msg)
577                 if nl_msg.done:
578                     if nl_msg.extack:
579                         print("Netlink warning:")
580                         print(nl_msg)
581                     done = True
582                     break
583
584                 gm = GenlMsg(nl_msg, fixed_header_members)
585                 # Check if this is a reply to our request
586                 if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
587                     if gm.genl_cmd in self.async_msg_ids:
588                         self.handle_ntf(nl_msg, gm)
589                         continue
590                     else:
591                         print('Unexpected message: ' + repr(gm))
592                         continue
593
594                 rsp_msg = self._decode(gm.raw_attrs, op.attr_set.name)
595                 rsp_msg.update(gm.fixed_header_attrs)
596                 rsp.append(rsp_msg)
597
598         if not rsp:
599             return None
600         if not dump and len(rsp) == 1:
601             return rsp[0]
602         return rsp
603
604     def do(self, method, vals):
605         return self._op(method, vals)
606
607     def dump(self, method, vals):
608         return self._op(method, vals, dump=True)