Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / names / test / test_client.py
1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
3
4 """
5 Test cases for L{twisted.names.client}.
6 """
7
8 from twisted.names import client, dns
9 from twisted.names.error import DNSQueryTimeoutError
10 from twisted.trial import unittest
11 from twisted.names.common import ResolverBase
12 from twisted.internet import defer, error
13 from twisted.python import failure
14 from twisted.python.deprecate import getWarningMethod, setWarningMethod
15 from twisted.python.compat import set
16
17
class FakeResolver(ResolverBase):
    """
    A L{ResolverBase} whose lookups succeed immediately with a single canned
    record, without touching the network.
    """

    def _lookup(self, name, cls, qtype, timeout):
        """
        Answer every lookup with one record of the requested type for C{name}.

        The getHostByNameTest does a different type of query: it needs an A
        record (with an address payload) back from an ALL_RECORDS lookup, so
        that name is special-cased to answer with C{127.0.0.1}.

        @return: An already-fired L{Deferred} whose result is a three-tuple
            of (answers, authority, additional) lists; the authority and
            additional lists are always empty.
        """
        if name != 'getHostByNameTest':
            answer = dns.RRHeader(name=name, type=qtype, cls=cls, ttl=60)
        else:
            answer = dns.RRHeader(
                name=name, type=dns.A, cls=cls, ttl=60,
                payload=dns.Record_A(address='127.0.0.1', ttl=60))
        return defer.succeed(([answer], [], []))
36
37
38
class StubPort(object):
    """
    Minimal stand-in for an L{IListeningPort}: all it does is remember
    whether C{stopListening} has been called.

    @ivar disconnected: C{False} until C{stopListening} is called, C{True}
        from then on.
    """
    disconnected = False

    def stopListening(self):
        """
        Record that this port has been stopped.
        """
        self.disconnected = True
51
52
53
class StubDNSDatagramProtocol(object):
    """
    Stand-in for L{dns.DNSDatagramProtocol} that records queries instead of
    sending them.

    @ivar queries: A C{list} with one tuple per C{query} call, holding that
        call's arguments followed by the L{defer.Deferred} it returned.
    @ivar transport: A L{StubPort} recording whether it has been stopped.
    """
    def __init__(self):
        self.transport = StubPort()
        self.queries = []


    def query(self, address, queries, timeout=10, id=None):
        """
        Remember the arguments and return a Deferred that this stub never
        fires itself; tests fire it through the last element of the matching
        entry in C{self.queries}.
        """
        pending = defer.Deferred()
        record = (address, queries, timeout, id, pending)
        self.queries.append(record)
        return pending
75
76
77
class ResolverTests(unittest.TestCase):
    """
    Tests for L{client.Resolver}.
    """
    def test_resolverProtocol(self):
        """
        Reading L{client.Resolver.protocol} causes a deprecation warning to be
        emitted and evaluates to an instance of L{DNSDatagramProtocol}.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        self.addCleanup(setWarningMethod, getWarningMethod())
        warnings = []
        setWarningMethod(
            lambda message, category, stacklevel:
                warnings.append((message, category, stacklevel)))
        protocol = resolver.protocol
        self.assertIsInstance(protocol, dns.DNSDatagramProtocol)
        self.assertEqual(
            warnings, [("Resolver.protocol is deprecated; use "
                        "Resolver.queryUDP instead.",
                        PendingDeprecationWarning, 0)])
        # A second read must hand back the same cached instance.
        self.assertIdentical(protocol, resolver.protocol)


    def test_datagramQueryServerOrder(self):
        """
        L{client.Resolver.queryUDP} should issue queries to its
        L{dns.DNSDatagramProtocol} with server addresses taken from its own
        C{servers} and C{dynServers} lists, proceeding through them in order
        as L{DNSQueryTimeoutError}s occur.
        """
        protocol = StubDNSDatagramProtocol()

        servers = [object(), object()]
        dynServers = [object(), object()]
        resolver = client.Resolver(servers=servers)
        resolver.dynServers = dynServers
        resolver.protocol = protocol

        expectedResult = object()
        queryResult = resolver.queryUDP(None)
        queryResult.addCallback(self.assertEqual, expectedResult)

        self.assertEqual(len(protocol.queries), 1)
        self.assertIdentical(protocol.queries[0][0], servers[0])
        protocol.queries[0][-1].errback(DNSQueryTimeoutError(0))
        self.assertEqual(len(protocol.queries), 2)
        self.assertIdentical(protocol.queries[1][0], servers[1])
        protocol.queries[1][-1].errback(DNSQueryTimeoutError(1))
        self.assertEqual(len(protocol.queries), 3)
        self.assertIdentical(protocol.queries[2][0], dynServers[0])
        protocol.queries[2][-1].errback(DNSQueryTimeoutError(2))
        self.assertEqual(len(protocol.queries), 4)
        self.assertIdentical(protocol.queries[3][0], dynServers[1])
        protocol.queries[3][-1].callback(expectedResult)

        return queryResult


    def test_singleConcurrentRequest(self):
        """
        L{client.Resolver.query} only issues one request at a time per query.
        Subsequent requests made before responses to prior ones are received
        are queued and given the same response as is given to the first one.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        resolver.protocol = StubDNSDatagramProtocol()
        queries = resolver.protocol.queries

        query = dns.Query('foo.example.com', dns.A, dns.IN)
        # The first query should be passed to the underlying protocol.
        firstResult = resolver.query(query)
        self.assertEqual(len(queries), 1)

        # The same query again should not be passed to the underlying protocol.
        secondResult = resolver.query(query)
        self.assertEqual(len(queries), 1)

        # The response to the first query should be sent in response to both
        # queries.
        answer = object()
        response = dns.Message()
        response.answers.append(answer)
        queries.pop()[-1].callback(response)

        d = defer.gatherResults([firstResult, secondResult])
        def cbFinished(responses):
            # Unpack in the body: tuple parameters are Python 2-only syntax.
            firstResponse, secondResponse = responses
            self.assertEqual(firstResponse, ([answer], [], []))
            self.assertEqual(secondResponse, ([answer], [], []))
        d.addCallback(cbFinished)
        return d


    def test_multipleConcurrentRequests(self):
        """
        L{client.Resolver.query} issues a request for each different concurrent
        query.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        resolver.protocol = StubDNSDatagramProtocol()
        queries = resolver.protocol.queries

        # The first query should be passed to the underlying protocol.
        firstQuery = dns.Query('foo.example.com', dns.A)
        resolver.query(firstQuery)
        self.assertEqual(len(queries), 1)

        # A query for a different name is also passed to the underlying
        # protocol.
        secondQuery = dns.Query('bar.example.com', dns.A)
        resolver.query(secondQuery)
        self.assertEqual(len(queries), 2)

        # A query for a different type is also passed to the underlying
        # protocol.
        thirdQuery = dns.Query('foo.example.com', dns.A6)
        resolver.query(thirdQuery)
        self.assertEqual(len(queries), 3)


    def test_multipleSequentialRequests(self):
        """
        After a response is received to a query issued with
        L{client.Resolver.query}, another query with the same parameters
        results in a new network request.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        resolver.protocol = StubDNSDatagramProtocol()
        queries = resolver.protocol.queries

        query = dns.Query('foo.example.com', dns.A)

        # The first query should be passed to the underlying protocol.
        resolver.query(query)
        self.assertEqual(len(queries), 1)

        # Deliver the response (pop() also empties the recorded list).
        queries.pop()[-1].callback(dns.Message())

        # Repeating the first query should touch the protocol again.
        resolver.query(query)
        self.assertEqual(len(queries), 1)


    def test_multipleConcurrentFailure(self):
        """
        If the result of a request is an error response, the Deferreds for all
        concurrently issued requests associated with that result fire with the
        L{Failure}.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        resolver.protocol = StubDNSDatagramProtocol()
        queries = resolver.protocol.queries

        query = dns.Query('foo.example.com', dns.A)
        firstResult = resolver.query(query)
        secondResult = resolver.query(query)

        class ExpectedException(Exception):
            pass

        queries.pop()[-1].errback(failure.Failure(ExpectedException()))

        return defer.gatherResults([
                self.assertFailure(firstResult, ExpectedException),
                self.assertFailure(secondResult, ExpectedException)])


    def test_connectedProtocol(self):
        """
        L{client.Resolver._connectedProtocol} returns a new
        L{DNSDatagramProtocol} connected to a new address with a
        cryptographically secure random port number.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        firstProto = resolver._connectedProtocol()
        secondProto = resolver._connectedProtocol()

        self.assertNotIdentical(firstProto.transport, None)
        self.assertNotIdentical(secondProto.transport, None)
        self.assertNotEqual(
            firstProto.transport.getHost().port,
            secondProto.transport.getHost().port)

        return defer.gatherResults([
                defer.maybeDeferred(firstProto.transport.stopListening),
                defer.maybeDeferred(secondProto.transport.stopListening)])


    def test_differentProtocol(self):
        """
        L{client.Resolver._connectedProtocol} is called once each time a UDP
        request needs to be issued and the resulting protocol instance is used
        for that request.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        protocols = []

        class FakeProtocol(object):
            def __init__(self):
                self.transport = StubPort()

            def query(self, address, query, timeout=10, id=None):
                protocols.append(self)
                return defer.succeed(dns.Message())

        resolver._connectedProtocol = FakeProtocol
        resolver.query(dns.Query('foo.example.com'))
        resolver.query(dns.Query('bar.example.com'))
        self.assertEqual(len(set(protocols)), 2)


    def test_disallowedPort(self):
        """
        If a port number is initially selected which cannot be bound, the
        L{CannotListenError} is handled and another port number is attempted.
        """
        ports = []

        class FakeReactor(object):
            # Accept keyword arguments too, so the fake keeps working if the
            # resolver passes extra listenUDP options by name.
            def listenUDP(self, port, *args, **kwargs):
                ports.append(port)
                if len(ports) == 1:
                    raise error.CannotListenError(None, port, None)

        resolver = client.Resolver(servers=[('example.com', 53)])
        resolver._reactor = FakeReactor()

        proto = resolver._connectedProtocol()
        self.assertIsInstance(proto, dns.DNSDatagramProtocol)
        # Count binding attempts (the rejected one plus a single retry)
        # rather than distinct port values: two randomly chosen port numbers
        # can coincide, which made a distinctness check nondeterministic.
        self.assertEqual(len(ports), 2)


    def test_differentProtocolAfterTimeout(self):
        """
        When a query issued by L{client.Resolver.query} times out, the retry
        uses a new protocol instance.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        protocols = []
        results = [defer.fail(failure.Failure(DNSQueryTimeoutError(None))),
                   defer.succeed(dns.Message())]

        class FakeProtocol(object):
            def __init__(self):
                self.transport = StubPort()

            def query(self, address, query, timeout=10, id=None):
                protocols.append(self)
                return results.pop(0)

        resolver._connectedProtocol = FakeProtocol
        resolver.query(dns.Query('foo.example.com'))
        self.assertEqual(len(set(protocols)), 2)


    def test_protocolShutDown(self):
        """
        After the L{Deferred} returned by L{DNSDatagramProtocol.query} is
        called back, the L{DNSDatagramProtocol} is disconnected from its
        transport.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        protocols = []
        result = defer.Deferred()

        class FakeProtocol(object):
            def __init__(self):
                self.transport = StubPort()

            def query(self, address, query, timeout=10, id=None):
                protocols.append(self)
                return result

        resolver._connectedProtocol = FakeProtocol
        resolver.query(dns.Query('foo.example.com'))

        self.assertFalse(protocols[0].transport.disconnected)
        result.callback(dns.Message())
        self.assertTrue(protocols[0].transport.disconnected)


    def test_protocolShutDownAfterTimeout(self):
        """
        The L{DNSDatagramProtocol} created when an interim timeout occurs is
        also disconnected from its transport after the Deferred returned by its
        query method completes.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        protocols = []
        result = defer.Deferred()
        results = [defer.fail(failure.Failure(DNSQueryTimeoutError(None))),
                   result]

        class FakeProtocol(object):
            def __init__(self):
                self.transport = StubPort()

            def query(self, address, query, timeout=10, id=None):
                protocols.append(self)
                return results.pop(0)

        resolver._connectedProtocol = FakeProtocol
        resolver.query(dns.Query('foo.example.com'))

        self.assertFalse(protocols[1].transport.disconnected)
        result.callback(dns.Message())
        self.assertTrue(protocols[1].transport.disconnected)


    def test_protocolShutDownAfterFailure(self):
        """
        If the L{Deferred} returned by L{DNSDatagramProtocol.query} fires with
        a failure, the L{DNSDatagramProtocol} is still disconnected from its
        transport.
        """
        class ExpectedException(Exception):
            pass

        resolver = client.Resolver(servers=[('example.com', 53)])
        protocols = []
        result = defer.Deferred()

        class FakeProtocol(object):
            def __init__(self):
                self.transport = StubPort()

            def query(self, address, query, timeout=10, id=None):
                protocols.append(self)
                return result

        resolver._connectedProtocol = FakeProtocol
        queryResult = resolver.query(dns.Query('foo.example.com'))

        self.assertFalse(protocols[0].transport.disconnected)
        result.errback(failure.Failure(ExpectedException()))
        self.assertTrue(protocols[0].transport.disconnected)

        return self.assertFailure(queryResult, ExpectedException)


    def test_tcpDisconnectRemovesFromConnections(self):
        """
        When a TCP DNS protocol associated with a Resolver disconnects, it is
        removed from the Resolver's connection list.
        """
        resolver = client.Resolver(servers=[('example.com', 53)])
        protocol = resolver.factory.buildProtocol(None)
        protocol.makeConnection(None)
        self.assertIn(protocol, resolver.connections)

        # Disconnecting should remove the protocol from the connection list:
        protocol.connectionLost(None)
        self.assertNotIn(protocol, resolver.connections)
431
432
433
class ClientTestCase(unittest.TestCase):
    """
    Tests for the module-level lookup functions in L{twisted.names.client},
    run against a L{FakeResolver} installed as C{client.theResolver}.
    """

    def setUp(self):
        """
        Replace the module-level resolver with a FakeResolver, remembering the
        previous one so that tearDown can put it back.
        """
        self._originalResolver = client.theResolver
        client.theResolver = FakeResolver()
        self.hostname = 'example.com'
        self.ghbntest = 'getHostByNameTest'

    def tearDown(self):
        """
        Restore whatever resolver was installed before setUp ran.  If that
        was C{None}, a resolver will be recreated the next time a name lookup
        is done.
        """
        client.theResolver = self._originalResolver

    def checkResult(self, result, qtype):
        """
        Verify that the first answer record is for C{self.hostname} and has
        the expected query type.

        @param result: A three-tuple of (answers, authority, additional)
            record lists, as produced by a lookup.
        @param qtype: The record type the first answer is expected to have.
        """
        # Unpack in the body: tuple parameters are Python 2-only syntax.
        results, authority, additional = result
        answer = results[0]
        self.assertEqual(str(answer.name), self.hostname)
        self.assertEqual(answer.type, qtype)

    def checkGetHostByName(self, result):
        """
        Test that the getHostByName query returns the 127.0.0.1 address.
        """
        self.assertEqual(result, '127.0.0.1')

    def _checkLookup(self, lookup, qtype):
        """
        Look up C{self.hostname} with C{lookup} and check the answer's type.

        @param lookup: One of the module-level C{client.lookup*} functions.
        @param qtype: The record type the answer is expected to have.
        @return: A L{Deferred} which fires when the check has completed.
        """
        d = lookup(self.hostname)
        d.addCallback(self.checkResult, qtype)
        return d

    def test_getHostByName(self):
        """
        Do a getHostByName of a value that should return 127.0.0.1.
        """
        d = client.getHostByName(self.ghbntest)
        d.addCallback(self.checkGetHostByName)
        return d

    def test_lookupAddress(self):
        """
        Do a lookup and test that the resolver will issue the correct type of
        query type. We do this by checking that FakeResolver returns a result
        record with the same query type as what we issued.
        """
        return self._checkLookup(client.lookupAddress, dns.A)

    def test_lookupIPV6Address(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupIPV6Address, dns.AAAA)

    def test_lookupAddress6(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupAddress6, dns.A6)

    def test_lookupNameservers(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupNameservers, dns.NS)

    def test_lookupCanonicalName(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupCanonicalName, dns.CNAME)

    def test_lookupAuthority(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupAuthority, dns.SOA)

    def test_lookupMailBox(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupMailBox, dns.MB)

    def test_lookupMailGroup(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupMailGroup, dns.MG)

    def test_lookupMailRename(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupMailRename, dns.MR)

    def test_lookupNull(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupNull, dns.NULL)

    def test_lookupWellKnownServices(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupWellKnownServices, dns.WKS)

    def test_lookupPointer(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupPointer, dns.PTR)

    def test_lookupHostInfo(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupHostInfo, dns.HINFO)

    def test_lookupMailboxInfo(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupMailboxInfo, dns.MINFO)

    def test_lookupMailExchange(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupMailExchange, dns.MX)

    def test_lookupText(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupText, dns.TXT)

    def test_lookupSenderPolicy(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupSenderPolicy, dns.SPF)

    def test_lookupResponsibility(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupResponsibility, dns.RP)

    def test_lookupAFSDatabase(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupAFSDatabase, dns.AFSDB)

    def test_lookupService(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupService, dns.SRV)

    def test_lookupZone(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupZone, dns.AXFR)

    def test_lookupAllRecords(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(client.lookupAllRecords, dns.ALL_RECORDS)

    def test_lookupNamingAuthorityPointer(self):
        """
        See L{test_lookupAddress}
        """
        return self._checkLookup(
            client.lookupNamingAuthorityPointer, dns.NAPTR)
659
660
class ThreadedResolverTests(unittest.TestCase):
    """
    Tests for L{client.ThreadedResolver}.
    """
    def test_deprecated(self):
        """
        L{client.ThreadedResolver} is deprecated.  Instantiating it emits a
        deprecation warning pointing at the code that does the instantiation.
        """
        client.ThreadedResolver()
        warnings = self.flushWarnings(offendingFunctions=[self.test_deprecated])
        # Check the count first so that a missing warning produces a clear
        # assertion failure instead of an IndexError on warnings[0].
        self.assertEqual(len(warnings), 1)
        self.assertEqual(warnings[0]['category'], DeprecationWarning)
        self.assertEqual(
            warnings[0]['message'],
            "twisted.names.client.ThreadedResolver is deprecated since "
            "Twisted 9.0, use twisted.internet.base.ThreadedResolver "
            "instead.")