7 from dataclasses import dataclass
9 from construct import *
# Local endpoint: all IPv4 interfaces on a fixed port. NOTE(review): the port
# is presumably not 5353 so that responders answer by unicast (RFC 6762
# "legacy unicast") -- confirm. For IPv6, bind to an address with a zone/scope
# id instead, e.g. ("::%178", 5388); the scope id is host-specific.
LOCAL_ADDR = ("0.0.0.0", 5388)

# Destination: the IPv4 mDNS multicast group and port. The IPv6 equivalent is
# ("ff02::fb", 5353).
DEST_ADDR = ("224.0.0.251", 5353)

# Service name placed in the outgoing question.
QUERY_QNAME = "_googlecast._tcp.local"
20 def EndsWithEmpty(x, lst, ctx):
24 class QNameValidator(Validator):
26 def _validate(self, obj, context, path):
class QNameArrayAdapter(Adapter):
    """Adapt a list of name labels to and from a dotted string.

    The wrapped construct yields the labels of a name followed by a final
    empty label (the one ``RepeatUntil(EndsWithEmpty, ...)`` stops on); in
    Python the name is a single string such as ``"_googlecast._tcp.local"``.
    """

    def _decode(self, obj, context, path):
        """Join all labels except the trailing empty terminator with dots."""
        return ".".join(map(str, obj[:-1]))

    def _encode(self, obj, context, path):
        """Split a dotted name into labels and append the empty terminator.

        A single trailing dot (fully-qualified form, e.g. ``"host.local."``)
        is accepted and ignored, and the root name (``""`` or ``"."``)
        encodes to just the terminator. This avoids emitting an extra empty
        label, which would otherwise end the name early.
        """
        if obj.endswith("."):
            obj = obj[:-1]
        labels = obj.split(".") if obj else []
        return labels + [""]
class AnswerPart(Subconstruct):
    """Parse one label (or compression pointer) of a name in a resource record.

    Each parse step consumes one length byte:

    * if its top two bits are set (``0xC0``), it starts a two-byte pointer
      and an ``AnswerPtr`` holding the 14-bit value is returned;
    * otherwise it is a label length and that many raw bytes are returned
      (``b""`` for the zero-length terminating label).

    Building is not implemented and the size is variable.
    """

    def __init__(self):
        # Run Construct's own initialiser so the base attributes it defines
        # exist (construct's _parsereport, used by RepeatUntil, reads
        # ``self.parsed``). Subconstruct.__init__ is bypassed because it
        # reads ``subcon.flagbuildnone`` and PascalString here is the factory
        # function, not a construct instance.
        Construct.__init__(self)
        self.name = "AnswerPart"
        self.subcon = PascalString
        self.flagbuildnone = False

    def _parse(self, stream, context, path):
        """Return the next label as ``bytes`` or a pointer as ``AnswerPtr``.

        Raises:
            StreamError: if the stream ends before the label is complete.
        """
        length = self._read_exact(stream, 1)[0]
        if (length & 0xC0) == 0xC0:
            # Pointer: low 6 bits of this byte plus all 8 bits of the next
            # form the 14-bit offset (RFC 1035 section 4.1.4).
            low = self._read_exact(stream, 1)[0]
            return AnswerPtr(((length & 0x3F) << 8) | low)
        return self._read_exact(stream, length)

    @staticmethod
    def _read_exact(stream, count):
        """Read exactly *count* bytes from *stream* or raise StreamError."""
        data = stream.read(count)
        if len(data) != count:
            raise StreamError(
                "Answer part truncated: wanted %d bytes, got %d"
                % (count, len(data))
            )
        return data

    def _build(self, obj, stream, context, path):
        """Building answer names is not supported yet."""
        # construct's ``Error`` is a field instance, not an exception class,
        # so a real exception type is raised here.
        raise NotImplementedError("Answer part build not yet implemented")

    def _sizeof(self, context, path):
        """The encoded length depends on the data, so no static size exists."""
        raise SizeofError("Answer part has a variable size")
def EndsWithEmptyOrPointer(x, lst, ctx):
    """RepeatUntil predicate for names made of ``AnswerPart`` items.

    ``x`` is the item just parsed; ``lst`` and ``ctx`` (the items so far
    and the parsing context) are passed by RepeatUntil but unused here.
    Stop on a compression pointer or on an empty (falsy) label.
    """
    if isinstance(x, AnswerPtr):
        return True
    return not x
class IpAddressAdapter(Adapter):
    """Present a sequence of byte values as a dotted-quad string, and back."""

    def _decode(self, obj, context, path):
        """Render each parsed octet in decimal and join them with dots."""
        return ".".join(str(octet) for octet in obj)

    def _encode(self, obj, context, path):
        """Split a dotted string and convert every part to an int."""
        return [int(part) for part in obj.split(".")]
# Four bytes exposed as a dotted-quad string.
IpAddress = IpAddressAdapter(Array(4, Byte))

# Wraps arbitrary bytes so they print as a hex dump; used for logging.
HEX = HexDump(GreedyBytes)

# A name as length-prefixed UTF-8 labels, read until EndsWithEmpty fires
# (presumably on the empty terminator label), validated, and presented as a
# single dotted string.
QNAME = QNameArrayAdapter(
    QNameValidator(
        RepeatUntil(EndsWithEmpty, PascalString(Byte, "utf8")),
    ),
)
96 "NAME" / RepeatUntil(EndsWithEmptyOrPointer, AnswerPart()),
112 "CLASS" / Enum(BitsInteger(15), IN=1),
119 "TXT": GreedyRange(PascalString(Byte, "utf8")),
121 "AAAA": Array(16, Byte),
122 "PTR": RepeatUntil(EndsWithEmptyOrPointer, AnswerPart()),
124 default=GreedyBytes)),
129 "Control" / BitStruct(
130 "QR" / Default(Flag, False),
132 Default(Enum(BitsInteger(4), QUERY=0, IQUERY=1, STATUS=2), "QUERY"),
133 "AA" / Default(Flag, False),
134 "TC" / Default(Flag, False),
135 "RD" / Default(Flag, False),
136 "RA" / Default(Flag, False),
138 "AD" / Default(Flag, False),
139 "CD" / Default(Flag, False),
153 "QuestionCount" / Rebuild(Int16ub, len_(this.Questions)),
154 "AnswerCount" / Rebuild(Int16ub, len_(this.Answers)),
155 "AuthorityCount" / Rebuild(Int16ub, len_(this.Authorities)),
156 "AdditionalCount" / Rebuild(Int16ub, len_(this.Additionals)),
175 "QCLASS" / BitStruct(
176 "Unicast" / Default(Flag, False),
177 "Class" / Default(Enum(BitsInteger(15), IN=1, ANY=255), "IN"),
181 "Answers" / Default(Array(this.AnswerCount, DNSAnswer), []),
182 "Authorities" / Default(Array(this.AuthorityCount, DNSAnswer), []),
183 "Additionals" / Default(Array(this.AdditionalCount, DNSAnswer), []),
187 class EchoClientProtocol:
189 def __init__(self, on_con_lost):
190 self.on_con_lost = on_con_lost
191 self.transport = None
193 def connection_made(self, transport):
194 self.transport = transport
196 query = DNSQuery.build({
200 "QNAME": QUERY_QNAME,
211 logging.info("Connection made")
212 logging.info("Sending:\n%s", DNSQuery.parse(query))
213 logging.info("BINARY:\n%s", HEX.parse(query))
215 self.transport.sendto(query, DEST_ADDR)
217 logging.info("Query sent")
219 def datagram_received(self, data, addr):
220 logging.info("Received reply from: %r", addr)
221 logging.debug(HEX.parse(data))
222 logging.info(DNSQuery.parse(data))
224 def error_received(self, exc):
225 logging.error("Error")
227 def connection_lost(self, exc):
228 logging.error("Lost connection")
229 self.on_con_lost.set_result(True)
233 loop = asyncio.get_running_loop()
235 client_done = loop.create_future()
237 transport, protocol = await loop.create_datagram_endpoint(
238 lambda: EchoClientProtocol(client_done),
239 local_addr=LOCAL_ADDR
248 if __name__ == "__main__":
249 parser = argparse.ArgumentParser(description="mDNS test app")
252 default=logging.INFO,
253 type=lambda x: getattr(logging, x),
254 help="Configure the logging level.",
256 args = parser.parse_args()
259 level=args.log_level,
260 format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
262 coloredlogs.install(level=args.log_level)