ec5ae38c327f6792e62de635727a91ecc947d989
[platform/upstream/connectedhomeip.git] / src / lib / mdns / minimal / format_test.py
1 #!/usr/bin/env python
2
3 import argparse
4 import asyncio
5 import coloredlogs
6 import logging
7 from dataclasses import dataclass
8
9 from construct import *
10
# Local (host, port) the query socket binds to. This is deliberately not the
# mDNS port 5353.
# For IPv6, ("::%178", 5388) was used instead; the "%178" scope suffix is
# presumably an interface index specific to the original machine.
LOCAL_ADDR = "0.0.0.0", 5388

# Destination of the query: the IPv4 mDNS multicast group and port.
# IPv6 equivalent: ("ff02::fb", 5353).
DEST_ADDR = "224.0.0.251", 5353

# Service name placed in the single question that gets sent.
QUERY_QNAME = "_googlecast._tcp.local"
19
def EndsWithEmpty(x, lst, ctx):
  """RepeatUntil predicate that stops at the first empty (falsy) item.

  Args:
    x: the item that was just parsed or built.
    lst: the items collected so far (unused).
    ctx: the construct context (unused).

  Returns:
    True when ``x`` is falsy, e.g. the zero-length label ending a QNAME.
  """
  return False if x else True
22
23
class QNameValidator(Validator):
  """Accepts a label list only if its final label is the empty string."""

  def _validate(self, obj, context, path):
    # A wire-format name is closed by a zero-length label.
    final_label = obj[-1]
    return final_label == ""
28
29
class QNameArrayAdapter(Adapter):
  """Converts between a list of DNS labels and a dotted name string.

  Decoding drops the trailing empty label and joins the rest with dots.
  Encoding splits the string on dots and appends the empty terminator label.
  """

  def _decode(self, obj, context, path):
    labels = [str(label) for label in obj[:-1]]
    return ".".join(labels)

  def _encode(self, obj, context, path):
    labels = [str(part) for part in obj.split(".")]
    labels.append("")
    return labels
37
38
@dataclass
class AnswerPtr:
  """A name-compression pointer encountered while parsing a DNS name.

  The pointer is only recorded; nothing in this module follows it.
  """

  # 14-bit target offset taken from the two pointer bytes.
  offset: int
42
43
class AnswerPart(Subconstruct):
  """Parses one element of a possibly compressed DNS name.

  Each parse consumes one of the following:

  * a two-byte compression pointer, when the first byte has both top bits set
    (0xC0). The remaining 14 bits are returned wrapped in ``AnswerPtr``.
  * a label, otherwise. The first byte is the label length and that many raw
    bytes are returned. A zero length yields ``b""``, which ends a name.

  Pointers are not followed, and building is not supported.
  """

  def __init__(self):
    # Subconstruct.__init__ insists on a real Construct subcon, and this class
    # never delegates to one. Initialise the Construct base directly so the
    # default attributes (name, docs, flagbuildnone, parsed) exist, then
    # override the ones this class cares about.
    Construct.__init__(self)
    self.name = "AnswerPart"
    self.subcon = PascalString
    self.flagbuildnone = False
    self.parsed = None

  @staticmethod
  def _read_exact(stream, count):
    """Reads exactly ``count`` bytes, raising StreamError on a short read."""
    data = stream.read(count)
    if len(data) != count:
      raise StreamError(
          "AnswerPart: expected %d bytes, got %d" % (count, len(data)))
    return data

  def _parse(self, stream, context, path):
    length = self._read_exact(stream, 1)[0]

    if (length & 0xC0) == 0xC0:
      # Compression pointer: 6 low bits of this byte + the whole next byte.
      low = self._read_exact(stream, 1)[0]
      return AnswerPtr(((length & 0x3F) << 8) | low)

    # NOTE(review): prefixes 0x40/0x80 are reserved label types; they are
    # still treated as plain lengths here, as before.
    return self._read_exact(stream, length)

  def _build(self, obj, stream, context, path):
    # Encoding (compressed) names is not implemented. Raise a real exception
    # type so callers see this message rather than a TypeError.
    raise NotImplementedError("Answer part build not yet implemented")

  def _sizeof(self, context, path):
    # Labels and pointers differ in length, so there is no static size.
    raise SizeofError("Answer part has a variable size")
73
74
def EndsWithEmptyOrPointer(x, lst, ctx):
  """RepeatUntil predicate for names that may use compression.

  A name ends either at the empty label or at a compression pointer
  (``AnswerPtr``). Per RFC 1035 a pointer stands in for the rest of the name.
  ``lst`` and ``ctx`` are unused.
  """
  if not x:
    return True
  return isinstance(x, AnswerPtr)
77
78
class IpAddressAdapter(Adapter):
  """Converts between a sequence of byte values and a dotted-quad string."""

  def _decode(self, obj, context, path):
    octets = [str(octet) for octet in obj]
    return ".".join(octets)

  def _encode(self, obj, context, path):
    return [int(octet) for octet in obj.split(".")]
86
87
# Four raw bytes exposed as a dotted-quad IPv4 string.
IpAddress = IpAddressAdapter(Array(4, Byte))

# Hex dump of an entire buffer; used only for logging.
HEX = HexDump(GreedyBytes)

# Uncompressed name: UTF-8 length-prefixed labels up to the empty label,
# validated and then exposed as a dotted string.
QNAME = QNameArrayAdapter(
    QNameValidator(
        RepeatUntil(EndsWithEmpty, PascalString(Byte, "utf8")),
    ),
)
94
# One resource record, as used for the answer, authority and additional
# sections. NAME may end in a compression pointer (AnswerPtr). RDATA is
# decoded according to TYPE, and unhandled types are kept as raw bytes.
DNSAnswer = Struct(
    "NAME" / RepeatUntil(EndsWithEmptyOrPointer, AnswerPart()),
    "TYPE" / Enum(
        Int16ub,
        A=1,
        NS=2,
        CNAME=5,
        SOA=6,
        WKS=11,
        PTR=12,
        MX=15,
        TXT=16,
        # Must be spelled "AAAA": the RDATA Switch below is keyed on this
        # enum's string value.
        AAAA=28,
        SRV=33,
    ),
    "CLASS" / BitStruct(
        # Top bit of the class field; mDNS uses it as the cache-flush bit.
        "FlushCache" / Flag,
        "CLASS" / Enum(BitsInteger(15), IN=1),
    ),
    "TTL" / Int32ub,
    # RDATA carries a 16-bit length prefix. Prefixed limits the Switch below
    # to exactly that many bytes.
    "RDATA" / Prefixed(
        Int16ub,
        Switch(
            this.TYPE, {
                "TXT": GreedyRange(PascalString(Byte, "utf8")),
                "A": IpAddress,
                "AAAA": Array(16, Byte),
                "PTR": RepeatUntil(EndsWithEmptyOrPointer, AnswerPart()),
            },
            default=GreedyBytes)),
)
126
# A complete DNS message: the header followed by the four record sections.
# Despite its name, it is used both to build the outgoing query and to parse
# the replies.
DNSQuery = Struct(
    "ID" / Int16ub,
    # The 16 header flag bits, most significant first.
    "Control" / BitStruct(
        "QR" / Default(Flag, False),
        "OpCode" /
        Default(Enum(BitsInteger(4), QUERY=0, IQUERY=1, STATUS=2), "QUERY"),
        "AA" / Default(Flag, False),
        "TC" / Default(Flag, False),
        "RD" / Default(Flag, False),
        "RA" / Default(Flag, False),
        # Reserved bit; always written as zero.
        "Z" / Padding(1),
        "AD" / Default(Flag, False),
        "CD" / Default(Flag, False),
        "Rcode" / Default(
            Enum(
                BitsInteger(4),
                OK=0,
                FORMAT_ERROR=1,
                SERVER_FAILURE=2,
                NAME_ERROR=3,
                NOT_IMPLEMENTED=4,
                REFUSED=5,
            ),
            "OK",
        ),
    ),
    # When building, section counts are recomputed from the list lengths.
    "QuestionCount" / Rebuild(Int16ub, len_(this.Questions)),
    "AnswerCount" / Rebuild(Int16ub, len_(this.Answers)),
    "AuthorityCount" / Rebuild(Int16ub, len_(this.Authorities)),
    "AdditionalCount" / Rebuild(Int16ub, len_(this.Additionals)),
    "Questions" / Array(
        this.QuestionCount,
        Struct(
            "QNAME" / QNAME,
            "QTYPE" / Default(
                Enum(
                    Int16ub,
                    A=1,
                    NS=2,
                    CNAME=5,
                    SOA=6,
                    WKS=11,
                    PTR=12,
                    MX=15,
                    # Matches DNSAnswer's TYPE enum so TXT can be queried.
                    TXT=16,
                    SRV=33,
                    AAAA=28,
                    ANY=255,
                ), "ANY"),
            "QCLASS" / BitStruct(
                # Top bit of QCLASS: the mDNS "unicast response" request.
                "Unicast" / Default(Flag, False),
                "Class" / Default(Enum(BitsInteger(15), IN=1, ANY=255), "IN"),
            ),
        ),
    ),
    "Answers" / Default(Array(this.AnswerCount, DNSAnswer), []),
    "Authorities" / Default(Array(this.AuthorityCount, DNSAnswer), []),
    "Additionals" / Default(Array(this.AdditionalCount, DNSAnswer), []),
)
185
186
class EchoClientProtocol(asyncio.DatagramProtocol):
  """Datagram protocol that sends one mDNS query and logs every reply.

  Args:
    on_con_lost: future that ``connection_lost`` resolves with ``True``,
      unless the future is already done (for example, cancelled).
  """

  def __init__(self, on_con_lost):
    self.on_con_lost = on_con_lost
    self.transport = None

  def connection_made(self, transport):
    """Builds a question for QUERY_QNAME and sends it to DEST_ADDR."""
    self.transport = transport

    query = DNSQuery.build({
        "ID": 0x1234,
        "Questions": [
            {
                "QNAME": QUERY_QNAME,
                "QCLASS": {
                    "Unicast": True
                }
            },
        ],
        "Answers": [],
        "Authorities": [],
        "Additionals": [],
    })

    logging.info("Connection made")
    # Round-trip parse so the log shows the decoded form of what is sent.
    logging.info("Sending:\n%s", DNSQuery.parse(query))
    logging.info("BINARY:\n%s", HEX.parse(query))

    self.transport.sendto(query, DEST_ADDR)

    logging.info("Query sent")

  def datagram_received(self, data, addr):
    """Logs a hex dump (debug level) and the decoded form of a packet."""
    logging.info("Received reply from: %r", addr)
    logging.debug(HEX.parse(data))
    try:
      reply = DNSQuery.parse(data)
    except Exception:  # pylint: disable=broad-except
      # Protocol callback boundary: report the undecodable packet with its
      # traceback and keep listening for further replies.
      logging.exception("Failed to parse reply from %r", addr)
      return
    logging.info(reply)

  def error_received(self, exc):
    """Logs an error reported by the transport, including its details."""
    logging.error("Error: %s", exc)

  def connection_lost(self, exc):
    """Logs the closure and resolves ``on_con_lost`` if it is still pending."""
    logging.error("Lost connection")
    # The future may already be done. For example, it is cancelled when
    # main() is cancelled while awaiting it. set_result would then raise
    # InvalidStateError.
    if not self.on_con_lost.done():
      self.on_con_lost.set_result(True)
230
231
async def main():
  """Opens the UDP endpoint and waits until the protocol's future completes.

  The transport is closed on the way out, including on cancellation.
  """
  loop = asyncio.get_running_loop()
  finished = loop.create_future()

  def make_protocol():
    return EchoClientProtocol(finished)

  transport, _ = await loop.create_datagram_endpoint(
      make_protocol, local_addr=LOCAL_ADDR)

  try:
    await finished
  finally:
    transport.close()
246
247
def parse_log_level(name):
  """argparse ``type=`` converter for ``--log-level``.

  Accepts a logging level name in any letter case (e.g. ``debug`` or
  ``WARNING``) and returns its numeric value.

  Raises:
    argparse.ArgumentTypeError: if ``name`` is not a logging level name, so
      argparse prints a usage error instead of a traceback.
  """
  # Upper-casing avoids resolving e.g. "info" to the logging.info function.
  level = getattr(logging, name.upper(), None)
  if not isinstance(level, int):
    raise argparse.ArgumentTypeError("unknown log level: %r" % name)
  return level


if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="mDNS test app")
  parser.add_argument(
      "--log-level",
      default=logging.INFO,
      type=parse_log_level,
      help="Configure the logging level.",
  )
  args = parser.parse_args()

  logging.basicConfig(
      level=args.log_level,
      format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
  )
  coloredlogs.install(level=args.log_level)

  asyncio.run(main())