1 # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
3 from collections import namedtuple
9 from struct import Struct
14 from .nlspec import SpecFamily
17 # Generic Netlink code which should really be in some library, but I can't quickly find one.
25 NETLINK_ADD_MEMBERSHIP = 1
40 NLM_F_ACK_TLVS = 0x200
42 NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH
45 NLA_F_NET_BYTEORDER = 0x4000
47 NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER
55 CTRL_CMD_GETFAMILY = 3
57 CTRL_ATTR_FAMILY_ID = 1
58 CTRL_ATTR_FAMILY_NAME = 2
60 CTRL_ATTR_MCAST_GROUPS = 7
62 CTRL_ATTR_MCAST_GRP_NAME = 1
63 CTRL_ATTR_MCAST_GRP_ID = 2
67 NLMSGERR_ATTR_OFFS = 2
68 NLMSGERR_ATTR_COOKIE = 3
69 NLMSGERR_ATTR_POLICY = 4
70 NLMSGERR_ATTR_MISS_TYPE = 5
71 NLMSGERR_ATTR_MISS_NEST = 6
74 class NlError(Exception):
75 def __init__(self, nl_msg):
79 return f"Netlink error: {os.strerror(-self.nl_msg.error)}\n{self.nl_msg}"
83 ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
85 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
86 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
87 'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
88 's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
89 'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
90 's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
91 'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
92 's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
def __init__(self, raw, offset):
    """Parse the netlink attribute that starts at `offset` within `raw`.

    The first 4 bytes are the attribute header: length and type, both
    unpacked as native-endian unsigned 16-bit values. The length counts
    the header itself but not the trailing alignment padding.

    Args:
        raw: buffer holding one or more back-to-back attributes.
        offset: position of this attribute's header within `raw`.

    Raises:
        struct.error: fewer than 4 bytes are available at `offset`.
        Exception: the length field is smaller than the header, so the
            attribute cannot be valid.
    """
    hdr_len = 4
    self._len, self._type = struct.unpack("HH", raw[offset:offset + hdr_len])
    if self._len < hdr_len:
        # A zero length would give full_len == 0, and a parser stepping
        # through the buffer by full_len would never get past this
        # attribute. Reject it instead of spinning forever.
        raise Exception(f"Malformed attribute at offset {offset}: "
                        f"length {self._len} is shorter than the {hdr_len} byte header")

    # Drop the flag bits to leave just the attribute type id
    self.type = self._type & ~Netlink.NLA_TYPE_MASK
    # NOTE: despite the name this still includes the 4 byte header
    self.payload_len = self._len
    # Distance to the next attribute: length rounded up to 4 bytes
    self.full_len = (self.payload_len + 3) & ~3
    self.raw = raw[offset + hdr_len:offset + self.payload_len]
103 def get_format(cls, attr_type, byte_order=None):
104 format = cls.type_formats[attr_type]
106 return format.big if byte_order == "big-endian" \
111 def formatted_string(cls, raw, display_hint):
112 if display_hint == 'mac':
113 formatted = ':'.join('%02x' % b for b in raw)
114 elif display_hint == 'hex':
115 formatted = bytes.hex(raw, ' ')
116 elif display_hint in [ 'ipv4', 'ipv6' ]:
117 formatted = format(ipaddress.ip_address(raw))
118 elif display_hint == 'uuid':
119 formatted = str(uuid.UUID(bytes=raw))
def as_scalar(self, attr_type, byte_order=None):
    """Decode the whole payload as one scalar value.

    `attr_type` and `byte_order` are handed to get_format() to pick the
    struct format; the first unpacked field is returned.
    """
    return self.get_format(attr_type, byte_order).unpack(self.raw)[0]
129 return self.raw.decode('ascii')[:-1]
def as_c_array(self, type):
    """Decode the payload as a packed array of scalars of one type.

    The format for `type` comes from get_format() with its default byte
    order. Returns a list with one value per array element.
    """
    unpacker = self.get_format(type)
    # Every scalar format has exactly one field per element
    return [value for (value,) in unpacker.iter_unpack(self.raw)]
138 def as_struct(self, members):
142 # TODO: handle non-scalar members
143 if m.type == 'binary':
144 decoded = self.raw[offset:offset+m['len']]
146 elif m.type in NlAttr.type_formats:
147 format = self.get_format(m.type, m.byte_order)
148 [ decoded ] = format.unpack_from(self.raw, offset)
149 offset += format.size
151 decoded = self.formatted_string(decoded, m.display_hint)
152 value[m.name] = decoded
156 return f"[type:{self.type} len:{self._len}] {self.raw}"
160 def __init__(self, msg):
164 while offset < len(msg):
165 attr = NlAttr(msg, offset)
166 offset += attr.full_len
167 self.attrs.append(attr)
170 yield from self.attrs
182 def __init__(self, msg, offset, attr_space=None):
183 self.hdr = msg[offset:offset + 16]
185 self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
186 struct.unpack("IHHII", self.hdr)
188 self.raw = msg[offset + 16:offset + self.nl_len]
194 if self.nl_type == Netlink.NLMSG_ERROR:
195 self.error = struct.unpack("i", self.raw[0:4])[0]
198 elif self.nl_type == Netlink.NLMSG_DONE:
203 if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
205 extack_attrs = NlAttrs(self.raw[extack_off:])
206 for extack in extack_attrs:
207 if extack.type == Netlink.NLMSGERR_ATTR_MSG:
208 self.extack['msg'] = extack.as_strz()
209 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
210 self.extack['miss-type'] = extack.as_scalar('u32')
211 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
212 self.extack['miss-nest'] = extack.as_scalar('u32')
213 elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
214 self.extack['bad-attr-offs'] = extack.as_scalar('u32')
216 if 'unknown' not in self.extack:
217 self.extack['unknown'] = []
218 self.extack['unknown'].append(extack)
221 # We don't have the ability to parse nests yet, so only do global
222 if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
223 miss_type = self.extack['miss-type']
224 if miss_type in attr_space.attrs_by_val:
225 spec = attr_space.attrs_by_val[miss_type]
228 desc += f" ({spec['doc']})"
229 self.extack['miss-type'] = desc
232 msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}\n"
234 msg += '\terror: ' + str(self.error)
236 msg += '\textack: ' + repr(self.extack)
241 def __init__(self, data, attr_space=None):
245 while offset < len(data):
246 msg = NlMsg(data, offset, attr_space=attr_space)
248 self.msgs.append(msg)
254 genl_family_name_to_id = None
257 def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
258 # we prepend length in _genl_msg_finalize()
260 seq = random.randint(1, 1024)
261 nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
262 genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
263 return nlmsg + genlmsg
def _genl_msg_finalize(msg):
    """Return `msg` with its total length prepended.

    The length is packed as a native-endian unsigned 32-bit value and
    counts those 4 bytes as well as `msg` itself.
    """
    total_len = 4 + len(msg)
    length_field = struct.pack("I", total_len)
    return length_field + msg
270 def _genl_load_families():
271 with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
272 sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
274 msg = _genl_msg(Netlink.GENL_ID_CTRL,
275 Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
276 Netlink.CTRL_CMD_GETFAMILY, 1)
277 msg = _genl_msg_finalize(msg)
281 global genl_family_name_to_id
282 genl_family_name_to_id = dict()
285 reply = sock.recv(128 * 1024)
289 print("Netlink error:", nl_msg.error)
296 for attr in gm.raw_attrs:
297 if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
298 fam['id'] = attr.as_scalar('u16')
299 elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
300 fam['name'] = attr.as_strz()
301 elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
302 fam['maxattr'] = attr.as_scalar('u32')
303 elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
304 fam['mcast'] = dict()
305 for entry in NlAttrs(attr.raw):
308 for entry_attr in NlAttrs(entry.raw):
309 if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
310 mcast_name = entry_attr.as_strz()
311 elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
312 mcast_id = entry_attr.as_scalar('u32')
313 if mcast_name and mcast_id is not None:
314 fam['mcast'][mcast_name] = mcast_id
315 if 'name' in fam and 'id' in fam:
316 genl_family_name_to_id[fam['name']] = fam
320 def __init__(self, nl_msg, fixed_header_members=[]):
323 self.hdr = nl_msg.raw[0:4]
326 self.genl_cmd, self.genl_version, _ = struct.unpack("BBH", self.hdr)
328 self.fixed_header_attrs = dict()
329 for m in fixed_header_members:
330 format = NlAttr.get_format(m.type, m.byte_order)
331 decoded = format.unpack_from(nl_msg.raw, offset)
332 offset += format.size
333 self.fixed_header_attrs[m.name] = decoded[0]
335 self.raw = nl_msg.raw[offset:]
336 self.raw_attrs = NlAttrs(self.raw)
340 msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
341 for a in self.raw_attrs:
342 msg += '\t\t' + repr(a) + '\n'
def __init__(self, family_name):
    """Resolve `family_name` through the module-level family cache.

    The cache is filled by _genl_load_families() the first time it is
    found unset. The lookup raises KeyError when the cache has no entry
    for `family_name`.
    """
    self.family_name = family_name

    # The cache name is only read here, so no global declaration is needed
    if genl_family_name_to_id is None:
        _genl_load_families()

    family = genl_family_name_to_id[family_name]
    self.genl_family = family
    self.family_id = family['id']
359 # YNL implementation details.
363 class YnlFamily(SpecFamily):
364 def __init__(self, def_path, schema=None):
365 super().__init__(def_path, schema)
367 self.include_raw = False
369 self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC)
370 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
371 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
373 self.async_msg_ids = set()
374 self.async_msg_queue = []
376 for msg in self.msgs.values():
378 self.async_msg_ids.add(msg.rsp_value)
380 for op_name, op in self.ops.items():
381 bound_f = functools.partial(self._op, op_name)
382 setattr(self, op.ident_name, bound_f)
385 self.family = GenlFamily(self.yaml['name'])
387 raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
def ntf_subscribe(self, mcast_name):
    """Join the multicast group `mcast_name` of this family.

    Binds the socket and adds a membership for the group id recorded in
    the family's 'mcast' table.

    Raises:
        Exception: the family has no multicast group called `mcast_name`.
    """
    groups = self.family.genl_family['mcast']
    if mcast_name not in groups:
        raise Exception(f'Multicast group "{mcast_name}" not present in the family')

    self.sock.bind((0, 0))
    self.sock.setsockopt(Netlink.SOL_NETLINK,
                         Netlink.NETLINK_ADD_MEMBERSHIP,
                         groups[mcast_name])
397 def _add_attr(self, space, name, value):
398 attr = self.attr_sets[space][name]
400 if attr["type"] == 'nest':
401 nl_type |= Netlink.NLA_F_NESTED
403 for subname, subvalue in value.items():
404 attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
405 elif attr["type"] == 'flag':
407 elif attr["type"] == 'string':
408 attr_payload = str(value).encode('ascii') + b'\x00'
409 elif attr["type"] == 'binary':
410 attr_payload = bytes.fromhex(value)
411 elif attr['type'] in NlAttr.type_formats:
412 format = NlAttr.get_format(attr['type'], attr.byte_order)
413 attr_payload = format.pack(int(value))
415 raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
417 pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
418 return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
420 def _decode_enum(self, rsp, attr_spec):
421 raw = rsp[attr_spec['name']]
422 enum = self.consts[attr_spec['enum']]
423 i = attr_spec.get('value-start', 0)
424 if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
428 value.add(enum.entries_by_val[i].name)
432 value = enum.entries_by_val[raw - i].name
433 rsp[attr_spec['name']] = value
435 def _decode_binary(self, attr, attr_spec):
436 if attr_spec.struct_name:
437 members = self.consts[attr_spec.struct_name]
438 decoded = attr.as_struct(members)
441 self._decode_enum(decoded, m)
442 elif attr_spec.sub_type:
443 decoded = attr.as_c_array(attr_spec.sub_type)
445 decoded = attr.as_bin()
446 if attr_spec.display_hint:
447 decoded = NlAttr.formatted_string(decoded, attr_spec.display_hint)
450 def _decode(self, attrs, space):
451 attr_space = self.attr_sets[space]
454 attr_spec = attr_space.attrs_by_val[attr.type]
455 if attr_spec["type"] == 'nest':
456 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
458 elif attr_spec["type"] == 'string':
459 decoded = attr.as_strz()
460 elif attr_spec["type"] == 'binary':
461 decoded = self._decode_binary(attr, attr_spec)
462 elif attr_spec["type"] == 'flag':
464 elif attr_spec["type"] in NlAttr.type_formats:
465 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
467 raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
469 if not attr_spec.is_multi:
470 rsp[attr_spec['name']] = decoded
471 elif attr_spec.name in rsp:
472 rsp[attr_spec.name].append(decoded)
474 rsp[attr_spec.name] = [decoded]
476 if 'enum' in attr_spec:
477 self._decode_enum(rsp, attr_spec)
480 def _decode_extack_path(self, attrs, attr_set, offset, target):
482 attr_spec = attr_set.attrs_by_val[attr.type]
486 return '.' + attr_spec.name
488 if offset + attr.full_len <= target:
489 offset += attr.full_len
491 if attr_spec['type'] != 'nest':
492 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
494 subpath = self._decode_extack_path(NlAttrs(attr.raw),
495 self.attr_sets[attr_spec['nested-attributes']],
499 return '.' + attr_spec.name + subpath
503 def _decode_extack(self, request, attr_space, extack):
504 if 'bad-attr-offs' not in extack:
507 genl_req = GenlMsg(NlMsg(request, 0, attr_space=attr_space))
508 path = self._decode_extack_path(genl_req.raw_attrs, attr_space,
509 20, extack['bad-attr-offs'])
511 del extack['bad-attr-offs']
512 extack['bad-attr'] = path
514 def handle_ntf(self, nl_msg, genl_msg):
517 msg['nlmsg'] = nl_msg
518 msg['genlmsg'] = genl_msg
519 op = self.rsp_by_value[genl_msg.genl_cmd]
520 msg['name'] = op['name']
521 msg['msg'] = self._decode(genl_msg.raw_attrs, op.attr_set.name)
522 self.async_msg_queue.append(msg)
527 reply = self.sock.recv(128 * 1024, socket.MSG_DONTWAIT)
528 except BlockingIOError:
534 print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
538 print("Netlink done while checking for ntf!?")
542 if gm.genl_cmd not in self.async_msg_ids:
543 print("Unexpected msg id done while checking for ntf", gm)
546 self.handle_ntf(nl_msg, gm)
548 def operation_do_attributes(self, name):
550 For a given operation name, find and return a supported
551 set of attributes (as a dict).
553 op = self.find_operation(name)
557 return op['do']['request']['attributes'].copy()
559 def _op(self, method, vals, dump=False):
560 op = self.ops[method]
562 nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
564 nl_flags |= Netlink.NLM_F_DUMP
566 req_seq = random.randint(1024, 65535)
567 msg = _genl_msg(self.family.family_id, nl_flags, op.req_value, 1, req_seq)
568 fixed_header_members = []
570 fixed_header_members = self.consts[op.fixed_header].members
571 for m in fixed_header_members:
572 value = vals.pop(m.name) if m.name in vals else 0
573 format = NlAttr.get_format(m.type, m.byte_order)
574 msg += format.pack(value)
575 for name, value in vals.items():
576 msg += self._add_attr(op.attr_set.name, name, value)
577 msg = _genl_msg_finalize(msg)
579 self.sock.send(msg, 0)
584 reply = self.sock.recv(128 * 1024)
585 nms = NlMsgs(reply, attr_space=op.attr_set)
588 self._decode_extack(msg, op.attr_set, nl_msg.extack)
591 raise NlError(nl_msg)
594 print("Netlink warning:")
599 gm = GenlMsg(nl_msg, fixed_header_members)
600 # Check if this is a reply to our request
601 if nl_msg.nl_seq != req_seq or gm.genl_cmd != op.rsp_value:
602 if gm.genl_cmd in self.async_msg_ids:
603 self.handle_ntf(nl_msg, gm)
606 print('Unexpected message: ' + repr(gm))
609 rsp_msg = self._decode(gm.raw_attrs, op.attr_set.name)
610 rsp_msg.update(gm.fixed_header_attrs)
615 if not dump and len(rsp) == 1:
def do(self, method, vals):
    """Run operation `method` with attributes `vals` as a non-dump request.

    Returns whatever _op() returns.
    """
    return self._op(method, vals, dump=False)
def dump(self, method, vals):
    """Run operation `method` with attributes `vals` as a dump request.

    Returns whatever _op() returns.
    """
    dump_reply = self._op(method, vals, dump=True)
    return dump_reply