1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
"""
Test cases for L{twisted.names.client}.
"""
8 from twisted.names import client, dns
9 from twisted.names.error import DNSQueryTimeoutError
10 from twisted.trial import unittest
11 from twisted.names.common import ResolverBase
12 from twisted.internet import defer, error
13 from twisted.python import failure
14 from twisted.python.deprecate import getWarningMethod, setWarningMethod
15 from twisted.python.compat import set
class FakeResolver(ResolverBase):
    """
    A resolver which answers every lookup immediately, without any network
    I/O, with a single synthesized answer record and empty authority and
    additional sections.
    """

    def _lookup(self, name, cls, qtype, timeout):
        """
        The getHostByNameTest does a different type of query that requires it
        to return an A record from an ALL_RECORDS lookup, so we accommodate
        that here.

        @return: A L{Deferred} which has already fired with a three-tuple of
            (answer records, authority records, additional records).
        """
        if name == 'getHostByNameTest':
            # getHostByName needs an address payload it can extract.
            rr = dns.RRHeader(name=name, type=dns.A, cls=cls, ttl=60,
                    payload=dns.Record_A(address='127.0.0.1', ttl=60))
        else:
            # Echo the requested type back so tests can check which query
            # type was issued.
            rr = dns.RRHeader(name=name, type=qtype, cls=cls, ttl=60)

        results = [rr]
        authority = []
        additional = []
        return defer.succeed((results, authority, additional))
class StubPort(object):
    """
    A partial implementation of L{IListeningPort} which only keeps track of
    whether it has been stopped.

    @ivar disconnected: A C{bool} which is C{False} until C{stopListening} is
        called, C{True} afterwards.
    """
    # Class-level default so the attribute can be read before stopListening
    # has ever been called; stopListening shadows it per instance.
    disconnected = False

    def stopListening(self):
        """
        Record that this port has been stopped.
        """
        self.disconnected = True
class StubDNSDatagramProtocol(object):
    """
    L{dns.DNSDatagramProtocol}-alike.

    @ivar queries: A C{list} of tuples giving the arguments passed to
        C{query} along with the L{defer.Deferred} which was returned from
        the call.

    @ivar transport: A L{StubPort} standing in for the listening UDP port.
    """
    def __init__(self):
        self.queries = []
        self.transport = StubPort()


    def query(self, address, queries, timeout=10, id=None):
        """
        Record the given arguments and return a Deferred which will not be
        called back by this code.

        @return: The L{defer.Deferred} which is also stored as the last
            element of the tuple appended to C{self.queries}; tests fire it
            to simulate a response or a failure.
        """
        result = defer.Deferred()
        self.queries.append((address, queries, timeout, id, result))
        return result
class ResolverTests(unittest.TestCase):
    """
    Tests for L{client.Resolver}.
    """
    def test_resolverProtocol(self):
        """
        Reading L{client.Resolver.protocol} causes a deprecation warning to be
        emitted and evaluates to an instance of L{DNSDatagramProtocol}.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        # Put the global warning method back once this test is done.
        self.addCleanup(setWarningMethod, getWarningMethod())
        warnings = []
        setWarningMethod(
            lambda message, category, stacklevel:
                warnings.append((message, category, stacklevel)))
        protocol = resolver.protocol
        self.assertIsInstance(protocol, dns.DNSDatagramProtocol)
        self.assertEqual(
            warnings, [("Resolver.protocol is deprecated; use "
                        "Resolver.queryUDP instead.",
                        PendingDeprecationWarning, 0)])
        # Repeated reads must hand back the very same instance.
        self.assertIdentical(protocol, resolver.protocol)


    def test_datagramQueryServerOrder(self):
        """
        L{client.Resolver.queryUDP} should issue queries to its
        L{dns.DNSDatagramProtocol} with server addresses taken from its own
        C{servers} and C{dynServers} lists, proceeding through them in order
        as L{DNSQueryTimeoutError}s occur.
        """
        protocol = StubDNSDatagramProtocol()

        servers = [object(), object()]
        dynServers = [object(), object()]
        resolver = client.Resolver(servers=servers)
        resolver.dynServers = dynServers
        resolver.protocol = protocol

        expectedResult = object()
        queryResult = resolver.queryUDP(None)
        queryResult.addCallback(self.assertEqual, expectedResult)

        # Each timeout moves on to the next address: the static servers
        # first, then the dynamic ones.
        self.assertEqual(len(protocol.queries), 1)
        self.assertIdentical(protocol.queries[0][0], servers[0])
        protocol.queries[0][-1].errback(DNSQueryTimeoutError(0))
        self.assertEqual(len(protocol.queries), 2)
        self.assertIdentical(protocol.queries[1][0], servers[1])
        protocol.queries[1][-1].errback(DNSQueryTimeoutError(1))
        self.assertEqual(len(protocol.queries), 3)
        self.assertIdentical(protocol.queries[2][0], dynServers[0])
        protocol.queries[2][-1].errback(DNSQueryTimeoutError(2))
        self.assertEqual(len(protocol.queries), 4)
        self.assertIdentical(protocol.queries[3][0], dynServers[1])
        protocol.queries[3][-1].callback(expectedResult)

        # Return the Deferred so trial reports the callback's assertion.
        return queryResult


    def test_singleConcurrentRequest(self):
        """
        L{client.Resolver.query} only issues one request at a time per query.
        Subsequent requests made before responses to prior ones are received
        are queued and given the same response as is given to the first one.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        resolver.protocol = StubDNSDatagramProtocol()
        queries = resolver.protocol.queries

        query = dns.Query('foo.example.com', dns.A, dns.IN)
        # The first query should be passed to the underlying protocol.
        firstResult = resolver.query(query)
        self.assertEqual(len(queries), 1)

        # The same query again should not be passed to the underlying protocol.
        secondResult = resolver.query(query)
        self.assertEqual(len(queries), 1)

        # The response to the first query should be sent in response to both
        # queries.
        answer = object()
        response = dns.Message()
        response.answers.append(answer)
        queries.pop()[-1].callback(response)

        d = defer.gatherResults([firstResult, secondResult])
        def cbFinished(responses):
            # Unpacked here rather than in the signature: tuple parameters
            # are a syntax error on Python 3 (PEP 3113).
            firstResponse, secondResponse = responses
            self.assertEqual(firstResponse, ([answer], [], []))
            self.assertEqual(secondResponse, ([answer], [], []))
        d.addCallback(cbFinished)
        return d


    def test_multipleConcurrentRequests(self):
        """
        L{client.Resolver.query} issues a request for each different concurrent
        query.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        resolver.protocol = StubDNSDatagramProtocol()
        queries = resolver.protocol.queries

        # The first query should be passed to the underlying protocol.
        firstQuery = dns.Query('foo.example.com', dns.A)
        resolver.query(firstQuery)
        self.assertEqual(len(queries), 1)

        # A query for a different name is also passed to the underlying
        # protocol.
        secondQuery = dns.Query('bar.example.com', dns.A)
        resolver.query(secondQuery)
        self.assertEqual(len(queries), 2)

        # A query for a different type is also passed to the underlying
        # protocol.
        thirdQuery = dns.Query('foo.example.com', dns.A6)
        resolver.query(thirdQuery)
        self.assertEqual(len(queries), 3)


    def test_multipleSequentialRequests(self):
        """
        After a response is received to a query issued with
        L{client.Resolver.query}, another query with the same parameters
        results in a new network request.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        resolver.protocol = StubDNSDatagramProtocol()
        queries = resolver.protocol.queries

        query = dns.Query('foo.example.com', dns.A)

        # The first query should be passed to the underlying protocol.
        resolver.query(query)
        self.assertEqual(len(queries), 1)

        # Deliver the response.
        queries.pop()[-1].callback(dns.Message())

        # Repeating the first query should touch the protocol again.
        resolver.query(query)
        self.assertEqual(len(queries), 1)


    def test_multipleConcurrentFailure(self):
        """
        If the result of a request is an error response, the Deferreds for all
        concurrently issued requests associated with that result fire with the
        L{Failure}.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        resolver.protocol = StubDNSDatagramProtocol()
        queries = resolver.protocol.queries

        query = dns.Query('foo.example.com', dns.A)
        firstResult = resolver.query(query)
        secondResult = resolver.query(query)

        class ExpectedException(Exception):
            pass

        queries.pop()[-1].errback(failure.Failure(ExpectedException()))

        return defer.gatherResults([
                self.assertFailure(firstResult, ExpectedException),
                self.assertFailure(secondResult, ExpectedException)])


    def test_connectedProtocol(self):
        """
        L{client.Resolver._connectedProtocol} returns a new
        L{DNSDatagramProtocol} connected to a new address with a
        cryptographically secure random port number.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        firstProto = resolver._connectedProtocol()
        secondProto = resolver._connectedProtocol()

        # Register each port's shutdown as soon as it is known to exist, so
        # real UDP ports are released even if a later assertion fails.
        self.assertNotIdentical(firstProto.transport, None)
        self.addCleanup(firstProto.transport.stopListening)
        self.assertNotIdentical(secondProto.transport, None)
        self.addCleanup(secondProto.transport.stopListening)

        self.assertNotEqual(
            firstProto.transport.getHost().port,
            secondProto.transport.getHost().port)


    def test_differentProtocol(self):
        """
        L{client.Resolver._connectedProtocol} is called once each time a UDP
        request needs to be issued and the resulting protocol instance is used
        for that request.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        protocols = []

        class FakeProtocol(object):
            def __init__(self):
                self.transport = StubPort()

            def query(self, address, query, timeout=10, id=None):
                protocols.append(self)
                return defer.succeed(dns.Message())

        resolver._connectedProtocol = FakeProtocol
        resolver.query(dns.Query('foo.example.com'))
        resolver.query(dns.Query('bar.example.com'))
        self.assertEqual(len(set(protocols)), 2)


    def test_disallowedPort(self):
        """
        If a port number is initially selected which cannot be bound, the
        L{CannotListenError} is handled and another port number is attempted.
        """
        ports = []

        class FakeReactor(object):
            def listenUDP(self, port, *args):
                ports.append(port)
                # Refuse only the first port requested.
                if len(ports) == 1:
                    raise error.CannotListenError(None, port, None)

        resolver = client.Resolver(servers=[('example.com', 53)])
        resolver._reactor = FakeReactor()

        resolver._connectedProtocol()
        self.assertEqual(len(set(ports)), 2)


    def test_differentProtocolAfterTimeout(self):
        """
        When a query issued by L{client.Resolver.query} times out, the retry
        uses a new protocol instance.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        protocols = []
        results = [defer.fail(failure.Failure(DNSQueryTimeoutError(None))),
                   defer.succeed(dns.Message())]

        class FakeProtocol(object):
            def __init__(self):
                self.transport = StubPort()

            def query(self, address, query, timeout=10, id=None):
                protocols.append(self)
                return results.pop(0)

        resolver._connectedProtocol = FakeProtocol
        resolver.query(dns.Query('foo.example.com'))
        self.assertEqual(len(set(protocols)), 2)


    def test_protocolShutDown(self):
        """
        After the L{Deferred} returned by L{DNSDatagramProtocol.query} is
        called back, the L{DNSDatagramProtocol} is disconnected from its
        transport.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        protocols = []
        result = defer.Deferred()

        class FakeProtocol(object):
            def __init__(self):
                self.transport = StubPort()

            def query(self, address, query, timeout=10, id=None):
                protocols.append(self)
                return result

        resolver._connectedProtocol = FakeProtocol
        resolver.query(dns.Query('foo.example.com'))

        self.assertFalse(protocols[0].transport.disconnected)
        result.callback(dns.Message())
        self.assertTrue(protocols[0].transport.disconnected)


    def test_protocolShutDownAfterTimeout(self):
        """
        The L{DNSDatagramProtocol} created when an interim timeout occurs is
        also disconnected from its transport after the Deferred returned by its
        query method completes.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        protocols = []
        result = defer.Deferred()
        results = [defer.fail(failure.Failure(DNSQueryTimeoutError(None))),
                   result]

        class FakeProtocol(object):
            def __init__(self):
                self.transport = StubPort()

            def query(self, address, query, timeout=10, id=None):
                protocols.append(self)
                return results.pop(0)

        resolver._connectedProtocol = FakeProtocol
        resolver.query(dns.Query('foo.example.com'))

        # protocols[1] is the instance created for the retry.
        self.assertFalse(protocols[1].transport.disconnected)
        result.callback(dns.Message())
        self.assertTrue(protocols[1].transport.disconnected)


    def test_protocolShutDownAfterFailure(self):
        """
        If the L{Deferred} returned by L{DNSDatagramProtocol.query} fires with
        a failure, the L{DNSDatagramProtocol} is still disconnected from its
        transport.
        """
        class ExpectedException(Exception):
            pass

        resolver = client.Resolver(servers=[('example.com', 53)])
        protocols = []
        result = defer.Deferred()

        class FakeProtocol(object):
            def __init__(self):
                self.transport = StubPort()

            def query(self, address, query, timeout=10, id=None):
                protocols.append(self)
                return result

        resolver._connectedProtocol = FakeProtocol
        queryResult = resolver.query(dns.Query('foo.example.com'))

        self.assertFalse(protocols[0].transport.disconnected)
        result.errback(failure.Failure(ExpectedException()))
        self.assertTrue(protocols[0].transport.disconnected)

        return self.assertFailure(queryResult, ExpectedException)


    def test_tcpDisconnectRemovesFromConnections(self):
        """
        When a TCP DNS protocol associated with a Resolver disconnects, it is
        removed from the Resolver's connection list.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        protocol = resolver.factory.buildProtocol(None)
        protocol.makeConnection(None)
        self.assertIn(protocol, resolver.connections)

        # Disconnecting should remove the protocol from the connection list:
        protocol.connectionLost(None)
        self.assertNotIn(protocol, resolver.connections)
class ClientTestCase(unittest.TestCase):
    """
    Tests for the module-level lookup functions of L{twisted.names.client},
    run against a L{FakeResolver} installed as C{client.theResolver}.
    """

    def setUp(self):
        """
        Replace the resolver with a FakeResolver.
        """
        client.theResolver = FakeResolver()
        self.hostname = 'example.com'
        self.ghbntest = 'getHostByNameTest'


    def tearDown(self):
        """
        By setting the resolver to None, it will be recreated next time a name
        lookup is done.
        """
        client.theResolver = None


    def checkResult(self, response, qtype):
        """
        Verify that the result is the same query type as what is expected.

        @param response: A three-tuple of (answer, authority, additional)
            record lists.
        @param qtype: The record type the first answer is expected to have.
        """
        # Unpacked here rather than in the signature: tuple parameters are a
        # syntax error on Python 3 (PEP 3113).
        results, authority, additional = response
        result = results[0]
        self.assertEqual(str(result.name), self.hostname)
        self.assertEqual(result.type, qtype)


    def checkGetHostByName(self, result):
        """
        Test that the getHostByName query returns the 127.0.0.1 address.
        """
        self.assertEqual(result, '127.0.0.1')


    def _lookupAndCheck(self, lookupFunction, qtype):
        """
        Call C{lookupFunction} for C{self.hostname} and verify, via
        L{checkResult}, that the answer record has type C{qtype}.

        @return: The L{Deferred} from the lookup, so trial waits on it and
            reports any assertion failure raised in the callback.
        """
        d = lookupFunction(self.hostname)
        d.addCallback(self.checkResult, qtype)
        return d


    def test_getHostByName(self):
        """
        Do a getHostByName of a value that should return 127.0.0.1.
        """
        d = client.getHostByName(self.ghbntest)
        d.addCallback(self.checkGetHostByName)
        return d


    def test_lookupAddress(self):
        """
        Do a lookup and test that the resolver will issue the correct type of
        query type. We do this by checking that FakeResolver returns a result
        record with the same query type as what we issued.
        """
        return self._lookupAndCheck(client.lookupAddress, dns.A)


    def test_lookupIPV6Address(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupIPV6Address, dns.AAAA)


    def test_lookupAddress6(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupAddress6, dns.A6)


    def test_lookupNameservers(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupNameservers, dns.NS)


    def test_lookupCanonicalName(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupCanonicalName, dns.CNAME)


    def test_lookupAuthority(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupAuthority, dns.SOA)


    def test_lookupMailBox(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupMailBox, dns.MB)


    def test_lookupMailGroup(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupMailGroup, dns.MG)


    def test_lookupMailRename(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupMailRename, dns.MR)


    def test_lookupNull(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupNull, dns.NULL)


    def test_lookupWellKnownServices(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupWellKnownServices, dns.WKS)


    def test_lookupPointer(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupPointer, dns.PTR)


    def test_lookupHostInfo(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupHostInfo, dns.HINFO)


    def test_lookupMailboxInfo(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupMailboxInfo, dns.MINFO)


    def test_lookupMailExchange(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupMailExchange, dns.MX)


    def test_lookupText(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupText, dns.TXT)


    def test_lookupSenderPolicy(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupSenderPolicy, dns.SPF)


    def test_lookupResponsibility(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupResponsibility, dns.RP)


    def test_lookupAFSDatabase(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupAFSDatabase, dns.AFSDB)


    def test_lookupService(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupService, dns.SRV)


    def test_lookupZone(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupZone, dns.AXFR)


    def test_lookupAllRecords(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(client.lookupAllRecords, dns.ALL_RECORDS)


    def test_lookupNamingAuthorityPointer(self):
        """
        See L{test_lookupAddress}
        """
        return self._lookupAndCheck(
            client.lookupNamingAuthorityPointer, dns.NAPTR)
class ThreadedResolverTests(unittest.TestCase):
    """
    Tests for L{client.ThreadedResolver}.
    """
    def test_deprecated(self):
        """
        L{client.ThreadedResolver} is deprecated. Instantiating it emits a
        deprecation warning pointing at the code that does the instantiation.
        """
        client.ThreadedResolver()
        warnings = self.flushWarnings(offendingFunctions=[self.test_deprecated])
        # Check the count first so a missing warning is reported as an
        # assertion failure rather than an IndexError from warnings[0].
        self.assertEqual(len(warnings), 1)
        self.assertEqual(
            warnings[0]['message'],
            "twisted.names.client.ThreadedResolver is deprecated since "
            "Twisted 9.0, use twisted.internet.base.ThreadedResolver "
            "instead.")
        self.assertEqual(warnings[0]['category'], DeprecationWarning)