drm/nouveau: fence: fix undefined fence state after emit
[platform/kernel/linux-rpi.git] / tools / net / ynl / lib / ynl.py
1 # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3 from collections import namedtuple
4 import functools
5 import os
6 import random
7 import socket
8 import struct
9 from struct import Struct
10 import yaml
11 import ipaddress
12 import uuid
13
14 from .nlspec import SpecFamily
15
16 #
17 # Generic Netlink code which should really be in some library, but I can't quickly find one.
18 #
19
20
class Netlink:
    """Namespace for the numeric netlink / generic netlink constants used in this file."""

    # Socket level and socket options passed to setsockopt()
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11

    # Message types that NlMsg treats as terminating a reply
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    # Request flags
    NLM_F_REQUEST = 1 << 0
    NLM_F_ACK = 1 << 2
    NLM_F_ROOT = 1 << 8
    NLM_F_MATCH = 1 << 9
    NLM_F_APPEND = 1 << 11

    # Flags with the same bit values as above but a different meaning;
    # NlMsg checks NLM_F_ACK_TLVS before looking for extack attributes
    NLM_F_CAPPED = 1 << 8
    NLM_F_ACK_TLVS = 1 << 9

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    # Flag bits carried in the attribute type field
    NLA_F_NESTED = 1 << 15
    NLA_F_NET_BYTEORDER = 1 << 14

    # The bits NlAttr clears to obtain the plain attribute type
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Generic netlink protocol number and the id of its control family
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 16

    # Control family (nlctrl) command and attribute ids
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extended ack (extack) attribute types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6
72
73
class NlError(Exception):
    """Exception wrapping a netlink message that carries a non-zero error code."""

    def __init__(self, nl_msg):
        # Keep the whole message so callers can inspect the error and extack
        self.nl_msg = nl_msg

    def __str__(self):
        # The error is stored as a negative errno value
        reason = os.strerror(-self.nl_msg.error)
        return f"Netlink error: {reason}\n{self.nl_msg}"
80
81
class NlAttr:
    """A single netlink attribute parsed out of a byte buffer.

    The constructor reads the 4 byte attribute header (length, type) and
    keeps the payload in `raw`; the `as_*` helpers interpret that payload.
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats for every scalar type, in native, big
    # endian and little endian flavour
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute starting at `offset` inside the buffer `raw`."""
        self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
        # Strip the nested / byte order flag bits from the type
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        # The length field is measured from the start of the header (see
        # the payload slice below), so despite its name it includes it
        self.payload_len = self._len
        # Length rounded up to a multiple of 4, used to find the next attribute
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4:offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for scalar `attr_type`.

        With no `byte_order` the native format is returned; "big-endian"
        selects the big endian one and any other truthy value the little
        endian one.  Raises KeyError for an unknown `attr_type`.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    @classmethod
    def formatted_string(cls, raw, display_hint):
        """Render `raw` according to `display_hint`.

        Known hints are 'mac', 'hex', 'ipv4', 'ipv6' and 'uuid'; for any
        other hint `raw` is returned unchanged.
        """
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as one scalar of type `attr_type`."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII and drop its last character
        (presumably the NUL terminator of a C string)."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes untouched."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a list of native-endian scalars of `type`."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def as_struct(self, members):
        """Decode the payload as a packed struct described by `members`.

        Each member must provide `name`, `type`, `byte_order` and
        `display_hint`; 'binary' members additionally provide m['len'].
        Returns a dict mapping member names to decoded values.

        Raises Exception for a member whose type is neither 'binary' nor
        a known scalar, instead of silently reusing the previous member's
        value.
        """
        value = dict()
        offset = 0
        for m in members:
            # TODO: handle non-scalar members
            if m.type == 'binary':
                decoded = self.raw[offset:offset + m['len']]
                offset += m['len']
            elif m.type in NlAttr.type_formats:
                fmt = self.get_format(m.type, m.byte_order)
                [ decoded ] = fmt.unpack_from(self.raw, offset)
                offset += fmt.size
            else:
                # Without this branch `decoded` would be stale or unbound
                raise Exception(f'Unknown struct member type {m.type} with name {m.name}')
            if m.display_hint:
                decoded = self.formatted_string(decoded, m.display_hint)
            value[m.name] = decoded
        return value

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
157
158
class NlAttrs:
    """The sequence of NlAttr objects found back to back in a byte buffer."""

    def __init__(self, msg):
        """Parse attributes from `msg` until its end is reached.

        Raises Exception when an attribute reports a length of zero: such
        an attribute would never advance the offset and the loop would
        run forever.
        """
        self.attrs = []

        offset = 0
        while offset < len(msg):
            attr = NlAttr(msg, offset)
            if not attr.full_len:
                raise Exception(f"Malformed netlink attribute at offset {offset}: zero length")
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        # One attribute per line
        return '\n'.join(repr(a) for a in self.attrs)
179
180
class NlMsg:
    """One netlink message: unpacked header fields, payload bytes and,
    for error / done messages that carry them, the decoded extack."""

    def __init__(self, msg, offset, attr_space=None):
        """Parse the message that starts at `offset` inside buffer `msg`.

        `attr_space`, when given, is used to replace a top level missing
        attribute id in the extack with the attribute's name.
        """
        hdr_end = offset + 16
        self.hdr = msg[offset:hdr_end]

        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)

        self.raw = msg[hdr_end:offset + self.nl_len]

        self.error = 0
        self.done = 0

        # Where extack attributes start within the payload, by message type
        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            (self.error,) = struct.unpack("i", self.raw[0:4])
            self.done = 1
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.done = 1
            extack_off = 4

        self.extack = None
        if not (self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off):
            return

        # Extack attributes that hold a plain u32, and the key each is stored under
        u32_keys = {
            Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
            Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
            Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
        }

        self.extack = dict()
        for tlv in NlAttrs(self.raw[extack_off:]):
            if tlv.type == Netlink.NLMSGERR_ATTR_MSG:
                self.extack['msg'] = tlv.as_strz()
            elif tlv.type in u32_keys:
                self.extack[u32_keys[tlv.type]] = tlv.as_scalar('u32')
            else:
                self.extack.setdefault('unknown', []).append(tlv)

        if not attr_space:
            return
        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return
        miss_type = self.extack['miss-type']
        if miss_type not in attr_space.attrs_by_val:
            return
        spec = attr_space.attrs_by_val[miss_type]
        desc = spec['name']
        if 'doc' in spec:
            desc += f" ({spec['doc']})"
        self.extack['miss-type'] = desc

    def __repr__(self):
        parts = [f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"]
        if self.error:
            parts.append('\terror: ' + str(self.error))
        if self.extack:
            parts.append('\textack: ' + repr(self.extack))
        return ''.join(parts)
238
239
class NlMsgs:
    """All netlink messages contained in one receive buffer."""

    def __init__(self, data, attr_space=None):
        """Split `data` into NlMsg objects; `attr_space` is passed to each.

        Raises Exception when a message reports a length shorter than the
        16 byte header, since such a length would stall or derail the walk
        over the buffer.
        """
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            if msg.nl_len < 16:
                raise Exception(f"Malformed netlink message at offset {offset}: nl_len {msg.nl_len}")
            # Consecutive messages start on 4 byte boundaries (NLMSG_ALIGN)
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
252
253
# Cache of the generic netlink families reported by the kernel, keyed by
# family name. Each value is a dict with 'id' and 'name' and, when reported,
# 'maxattr' and 'mcast'. Stays None until _genl_load_families() has run.
genl_family_name_to_id = None
255
256
def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build the netlink and genetlink headers of a request.

    The leading length field is left out, we prepend length in
    _genl_msg_finalize().  Without an explicit `seq` a random sequence
    number between 1 and 1024 is used.
    """
    seq = random.randint(1, 1024) if seq is None else seq
    headers = (
        struct.pack("HHII", nl_type, nl_flags, seq, 0),
        struct.pack("BBH", genl_cmd, genl_version, 0),
    )
    return b''.join(headers)
264
265
def _genl_msg_finalize(msg):
    """Return `msg` with its total length prepended as a native u32.

    The length counts the 4 bytes of the length field itself.
    """
    total_len = 4 + len(msg)
    length_field = struct.pack("I", total_len)
    return length_field + msg
268
269
def _genl_load_families():
    """Ask the kernel for all generic netlink families and cache them.

    Sends a CTRL_CMD_GETFAMILY dump on a temporary socket and stores one
    dict per family in the module level genl_family_name_to_id, keyed by
    family name.  Returns once a done message arrives; on a netlink error
    the error code is printed and whatever was collected so far is kept.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL,
                      Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                      Netlink.CTRL_CMD_GETFAMILY, 1))

        sock.send(request, 0)

        genl_family_name_to_id = dict()

        while True:
            for nl_msg in NlMsgs(sock.recv(128 * 1024)):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                fam = dict()
                for attr in GenlMsg(nl_msg).raw_attrs:
                    kind = attr.type
                    if kind == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif kind == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif kind == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif kind == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        groups = fam['mcast'] = dict()
                        # Each entry is a nest holding a group's name and id
                        for entry in NlAttrs(attr.raw):
                            grp_name = None
                            grp_id = None
                            for field in NlAttrs(entry.raw):
                                if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    grp_name = field.as_strz()
                                elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    grp_id = field.as_scalar('u32')
                            if grp_name and grp_id is not None:
                                groups[grp_name] = grp_id
                # Only families with both a name and an id are recorded
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
317
318
class GenlMsg:
    """A generic netlink message layered on top of an NlMsg.

    Holds the genetlink command and version, any decoded fixed header
    members and the attributes that follow them.
    """

    def __init__(self, nl_msg, fixed_header_members=()):
        """Parse the genetlink part of `nl_msg`.

        `fixed_header_members` lists the scalar members (each with `name`,
        `type` and `byte_order`) of a family specific header located
        between the genetlink header and the attributes; it defaults to
        none.  The default is an immutable tuple so it cannot be shared
        and modified across calls.
        """
        self.nl = nl_msg

        self.hdr = nl_msg.raw[0:4]
        offset = 4

        self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)

        self.fixed_header_attrs = dict()
        for m in fixed_header_members:
            fmt = NlAttr.get_format(m.type, m.byte_order)
            decoded = fmt.unpack_from(nl_msg.raw, offset)
            offset += fmt.size
            self.fixed_header_attrs[m.name] = decoded[0]

        # Everything after the headers is attributes
        self.raw = nl_msg.raw[offset:]
        self.raw_attrs = NlAttrs(self.raw)

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        for a in self.raw_attrs:
            msg += '\t\t' + repr(a) + '\n'
        return msg
344
345
class GenlFamily:
    """The kernel's view of one generic netlink family, looked up by name."""

    def __init__(self, family_name):
        """Resolve `family_name` through the family cache, loading it first if needed.

        Raises KeyError when the cache has no family of that name.
        """
        self.family_name = family_name

        # The cache is populated on first use
        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']
356
357
358 #
359 # YNL implementation details.
360 #
361
362
363 class YnlFamily(SpecFamily):
364     def __init__(self, def_path, schema=None):
365         super().__init__(def_path, schema)
366
367         self.include_raw = False
368
369         self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
370         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
371         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
372
373         self.async_msg_ids = set()
374         self.async_msg_queue = []
375
376         for msg in self.msgs.values():
377             if msg.is_async:
378                 self.async_msg_ids.add(msg.rsp_value)
379
380         for op_name, op in self.ops.items():
381             bound_f = functools.partial(self._op, op_name)
382             setattr(self, op.ident_name, bound_f)
383
384         try:
385             self.family = GenlFamily(self.yaml['name'])
386         except KeyError:
387             raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
388
389     def ntf_subscribe(self, mcast_name):
390         if mcast_name not in self.family.genl_family['mcast']:
391             raise Exception(f'Multicast group "{mcast_name}" not present in the family')
392
393         self.sock.bind((0, 0))
394         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
395                              self.family.genl_family['mcast'][mcast_name])
396
397     def _add_attr(self, space, name, value):
398         attr = self.attr_sets[space][name]
399         nl_type = attr.value
400         if attr["type"] == 'nest':
401             nl_type |= Netlink.NLA_F_NESTED
402             attr_payload = b''
403             for subname, subvalue in value.items():
404                 attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
405         elif attr["type"] == 'flag':
406             attr_payload = b''
407         elif attr["type"] == 'string':
408             attr_payload = str(value).encode('ascii') + b'\x00'
409         elif attr["type"] == 'binary':
410             attr_payload = bytes.fromhex(value)
411         elif attr['type'] in NlAttr.type_formats:
412             format = NlAttr.get_format(attr['type'], attr.byte_order)
413             attr_payload = format.pack(int(value))
414         else:
415             raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
416
417         pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
418         return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
419
420     def _decode_enum(self, rsp, attr_spec):
421         raw = rsp[attr_spec['name']]
422         enum = self.consts[attr_spec['enum']]
423         i = attr_spec.get('value-start', 0)
424         if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
425             value = set()
426             while raw:
427                 if raw & 1:
428                     value.add(enum.entries_by_val[i].name)
429                 raw >>= 1
430                 i += 1
431         else:
432             value = enum.entries_by_val[raw - i].name
433         rsp[attr_spec['name']] = value
434
435     def _decode_binary(self, attr, attr_spec):
436         if attr_spec.struct_name:
437             members = self.consts[attr_spec.struct_name]
438             decoded = attr.as_struct(members)
439             for m in members:
440                 if m.enum:
441                     self._decode_enum(decoded, m)
442         elif attr_spec.sub_type:
443             decoded = attr.as_c_array(attr_spec.sub_type)
444         else:
445             decoded = attr.as_bin()
446             if attr_spec.display_hint:
447                 decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint)
448         return decoded
449
450     def _decode(self, attrs, space):
451         attr_space = self.attr_sets[space]
452         rsp = dict()
453         for attr in attrs:
454             attr_spec = attr_space.attrs_by_val[attr.type]
455             if attr_spec["type"] == 'nest':
456                 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
457                 decoded = subdict
458             elif attr_spec["type"] == 'string':
459                 decoded = attr.as_strz()
460             elif attr_spec["type"] == 'binary':
461                 decoded = self._decode_binary(attr, attr_spec)
462             elif attr_spec["type"] == 'flag':
463                 decoded = True
464             elif attr_spec["type"] in NlAttr.type_formats:
465                 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
466             else:
467                 raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
468
469             if not attr_spec.is_multi:
470                 rsp[attr_spec['name']] = decoded
471             elif attr_spec.name in rsp:
472                 rsp[attr_spec.name].append(decoded)
473             else:
474                 rsp[attr_spec.name] = [decoded]
475
476             if 'enum' in attr_spec:
477                 self._decode_enum(rsp, attr_spec)
478         return rsp
479
480     def _decode_extack_path(self, attrs, attr_set, offset, target):
481         for attr in attrs:
482             attr_spec = attr_set.attrs_by_val[attr.type]
483             if offset > target:
484                 break
485             if offset == target:
486                 return '.' + attr_spec.name
487
488             if offset + attr.full_len <= target:
489                 offset += attr.full_len
490                 continue
491             if attr_spec['type'] != 'nest':
492                 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
493             offset += 4
494             subpath = self._decode_extack_path(NlAttrs(attr.raw),
495                                                self.attr_sets[attr_spec['nested-attributes']],
496                                                offset, target)
497             if subpath is None:
498                 return None
499             return '.' + attr_spec.name + subpath
500
501         return None
502
503     def _decode_extack(self, request, attr_space, extack):
504         if 'bad-attr-offs' not in extack:
505             return
506
507         genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space))
508         path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
509                                         20, extack['bad-attr-offs'])
510         if path:
511             del extack['bad-attr-offs']
512             extack['bad-attr'] = path
513
514     def handle_ntf(self, nl_msg, genl_msg):
515         msg = dict()
516         if self.include_raw:
517             msg['nlmsg'] = nl_msg
518             msg['genlmsg'] = genl_msg
519         op = self.rsp_by_value[genl_msg.genl_cmd]
520         msg['name'] = op['name']
521         msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
522         self.async_msg_queue.append(msg)
523
524     def check_ntf(self):
525         while True:
526             try:
527                 reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
528             except BlockingIOError:
529                 return
530
531             nms = NlMsgs(reply)
532             for nl_msg in nms:
533                 if nl_msg.error:
534                     print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
535                     print(nl_msg)
536                     continue
537                 if nl_msg.done:
538                     print("Netlink done while checking for ntf!?")
539                     continue
540
541                 gm = GenlMsg(nl_msg)
542                 if gm.genl_cmd not in self.async_msg_ids:
543                     print("Unexpected msg id done while checking for ntf", gm)
544                     continue
545
546                 self.handle_ntf(nl_msg, gm)
547
548     def operation_do_attributes(self, name):
549       """
550       For a given operation name, find and return a supported
551       set of attributes (as a dict).
552       """
553       op = self.find_operation(name)
554       if not op:
555         return None
556
557       return op['do']['request']['attributes'].copy()
558
559     def _op(self, method, vals, dump=False):
560         op = self.ops[method]
561
562         nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
563         if dump:
564             nl_flags |= Netlink.NLM_F_DUMP
565
566         req_seq = random.randint(1024, 65535)
567         msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
568         fixed_header_members = []
569         if op.fixed_header:
570             fixed_header_members = self.consts[op.fixed_header].members
571             for m in fixed_header_members:
572                 value = vals.pop(m.name) if m.name in vals else 0
573                 format = NlAttr.get_format(m.type, m.byte_order)
574                 msg += format.pack(value)
575         for name, value in vals.items():
576             msg += self._add_attr(op.attr_set.name, name, value)
577         msg = _genl_msg_finalize(msg)
578
579         self.sock.send(msg, 0)
580
581         done = False
582         rsp = []
583         while not done:
584             reply = self.sock.recv(128 * 1024)
585             nms = NlMsgs(reply, attr_space=op.attr_set)
586             for nl_msg in nms:
587                 if nl_msg.extack:
588                     self._decode_extack(msg, op.attr_set, nl_msg.extack)
589
590                 if nl_msg.error:
591                     raise NlError(nl_msg)
592                 if nl_msg.done:
593                     if nl_msg.extack:
594                         print("Netlink warning:")
595                         print(nl_msg)
596                     done = True
597                     break
598
599                 gm = GenlMsg(nl_msg, fixed_header_members)
600                 # Check if this is a reply to our request
601                 if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
602                     if gm.genl_cmd in self.async_msg_ids:
603                         self.handle_ntf(nl_msg, gm)
604                         continue
605                     else:
606                         print('Unexpected message: ' + repr(gm))
607                         continue
608
609                 rsp_msg = self._decode(gm.raw_attrs, op.attr_set.name)
610                 rsp_msg.update(gm.fixed_header_attrs)
611                 rsp.append(rsp_msg)
612
613         if not rsp:
614             return None
615         if not dump and len(rsp) == 1:
616             return rsp[0]
617         return rsp
618
619     def do(self, method, vals):
620         return self._op(method, vals)
621
622     def dump(self, method, vals):
623         return self._op(method, vals, dump=True)