Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / web / test / test_webclient.py
1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
3
4 """
5 Tests for L{twisted.web.client}.
6 """
7
8 import cookielib
9 import os
10 from errno import ENOSPC
11 import zlib
12 from StringIO import StringIO
13
14 from urlparse import urlparse, urljoin
15
16 from zope.interface.verify import verifyObject
17
18 from twisted.trial import unittest
19 from twisted.web import server, static, client, error, util, resource, http_headers
20 from twisted.web._newclient import RequestNotSent, RequestTransmissionFailed
21 from twisted.web._newclient import ResponseNeverReceived, ResponseFailed
22 from twisted.internet import reactor, defer, interfaces, task
23 from twisted.python.failure import Failure
24 from twisted.python.filepath import FilePath
25 from twisted.python.log import msg
26 from twisted.python.components import proxyForInterface
27 from twisted.protocols.policies import WrappingFactory
28 from twisted.test.proto_helpers import StringTransport
29 from twisted.test.proto_helpers import MemoryReactor
30 from twisted.internet.task import Clock
31 from twisted.internet.error import ConnectionRefusedError, ConnectionDone
32 from twisted.internet.protocol import Protocol, Factory
33 from twisted.internet.defer import Deferred, succeed
34 from twisted.internet.endpoints import TCP4ClientEndpoint, SSL4ClientEndpoint
35 from twisted.web.client import FileBodyProducer, Request, HTTPConnectionPool
36 from twisted.web.client import _WebToNormalContextFactory
37 from twisted.web.client import WebClientContextFactory, _HTTP11ClientFactory
38 from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer, IResponse
39 from twisted.web._newclient import HTTP11ClientProtocol, Response
40 from twisted.web.error import SchemeNotSupported
41
try:
    from twisted.internet import ssl
except ImportError:
    # TLS support is optional (typically unavailable when pyOpenSSL is not
    # installed).  Only a missing module is tolerated here; any other error
    # raised while importing the module is a real bug and must propagate.
    ssl = None
46
47
48
class ExtendedRedirect(resource.Resource):
    """
    A resource which redirects the first request it sees and answers every
    later request directly.

    The status code of the redirect comes from the C{code} query argument of
    that first request.

    @type url: C{str}
    @ivar url: The I{Location} the first request is redirected to.

    @type lastMethod: C{str}
    @ivar lastMethod: The HTTP method of the most recently rendered request,
        or C{None} before any request has been rendered.
    """
    isLeaf = 1
    lastMethod = None


    def __init__(self, url):
        resource.Resource.__init__(self)
        self.url = url


    def render(self, request):
        alreadyRedirected = bool(self.lastMethod)
        # Record the method in both cases so tests can see which method the
        # client used when it followed the redirect.
        self.lastMethod = request.method
        if alreadyRedirected:
            return "OK Thnx!"
        code = int(request.args['code'][0])
        return self.redirectTo(self.url, request, code)


    def getChild(self, name, request):
        # Every child path is handled by this same resource.
        return self


    def redirectTo(self, url, request, code):
        """
        Respond with status C{code} and a I{Location} header of C{url}.
        """
        request.setResponseCode(code)
        request.setHeader("location", url)
        return "OK Bye!"
85
86
87
class ForeverTakingResource(resource.Resource):
    """
    A resource whose responses are never finished.

    @ivar _write: When true, C{'some bytes'} is written to the response
        before it is left hanging, so that headers and some body data reach
        the client.
    """
    def __init__(self, write=False):
        resource.Resource.__init__(self)
        self._write = write

    def render(self, request):
        shouldWrite = self._write
        if shouldWrite:
            request.write('some bytes')
        # The request is never finished, so the client keeps waiting.
        return server.NOT_DONE_YET
101
102
class CookieMirrorResource(resource.Resource):
    """
    A resource which renders itself as the C{repr} of the sorted list of
    C{(name, value)} pairs from the request's C{received_cookies}.
    """
    def render(self, request):
        # Sorting makes the output independent of dict ordering, so tests
        # can compare it against a fixed string.
        return repr(sorted(request.received_cookies.items()))
110
class RawCookieMirrorResource(resource.Resource):
    """
    A resource which renders itself as the C{repr} of the request's raw
    I{Cookie} header value, as returned by C{request.getHeader}.
    """
    def render(self, request):
        rawCookie = request.getHeader('cookie')
        return repr(rawCookie)
114
class ErrorResource(resource.Resource):
    """
    A resource which always responds with status 401 and an empty body.

    If the request carries a C{showlength} query argument, an explicit
    I{Content-Length: 0} header is set as well.
    """

    def render(self, request):
        request.setResponseCode(401)
        wantsLength = request.args.get("showlength")
        if wantsLength:
            request.setHeader("content-length", "0")
        return ""
122
class NoLengthResource(resource.Resource):
    """
    A resource which renders a fixed body without setting a
    I{Content-Length} header itself.
    """

    def render(self, request):
        body = "nolength"
        return body
127
128
129
class HostHeaderResource(resource.Resource):
    """
    A resource which renders itself as the I{Host} header value received
    with the request.
    """
    def render(self, request):
        receivedHeaders = request.received_headers
        return receivedHeaders['host']
137
138
139
class PayloadResource(resource.Resource):
    """
    A resource which echoes the request body back, provided that both the
    body read and the declared I{Content-Length} are exactly 100 bytes.
    Otherwise it renders C{"ERROR"}.
    """
    def render(self, request):
        body = request.content.read()
        declaredLength = request.received_headers['content-length']
        # The declared length is only converted to int when the body length
        # already matches.
        if len(body) == 100 and int(declaredLength) == 100:
            return body
        return "ERROR"
152
153
class DelayResource(resource.Resource):
    """
    A resource which writes C{'some bytes'} and finishes the request only
    after C{seconds} seconds have elapsed on the global reactor.

    @ivar seconds: The delay, in seconds, before the response is written.
    """

    def __init__(self, seconds):
        # Initialize the base class, as every other resource in this module
        # that defines __init__ does.  Presumably it sets up per-instance
        # state such as the children mapping; skipping it leaves that state
        # missing.
        resource.Resource.__init__(self)
        self.seconds = seconds

    def render(self, request):
        def response():
            request.write('some bytes')
            request.finish()
        reactor.callLater(self.seconds, response)
        # The response is produced later by the delayed call above.
        return server.NOT_DONE_YET
165
166
class BrokenDownloadResource(resource.Resource):
    """
    A resource whose response is shorter than its declared
    I{Content-Length}: it announces 5 bytes but sends only C{'abc'}.
    """

    def render(self, request):
        declaredLength = "5"
        request.setHeader("content-length", declaredLength)
        # Only three of the five promised bytes are ever sent.
        request.write('abc')
        return ''
174
class CountingRedirect(util.Redirect):
    """
    A L{util.Redirect} which counts how many times it has been rendered.

    @ivar count: Number of calls to C{render} so far.
    """
    def __init__(self, *a, **kw):
        util.Redirect.__init__(self, *a, **kw)
        self.count = 0

    def render(self, request):
        self.count = self.count + 1
        result = util.Redirect.render(self, request)
        return result
187
188
class CountingResource(resource.Resource):
    """
    A resource which renders C{"Success"} and counts how many times it has
    been rendered.

    @ivar count: Number of calls to C{render} so far.
    """
    def __init__(self):
        resource.Resource.__init__(self)
        self.count = 0

    def render(self, request):
        self.count = self.count + 1
        body = "Success"
        return body
200
201
class ParseUrlTestCase(unittest.TestCase):
    """
    Tests for L{client._parse} and the default values it fills in.
    """

    def test_parse(self):
        """
        L{client._parse} splits a URL into a C{(scheme, host, port, path)}
        tuple, supplying the scheme's default port and a default path.
        """
        cases = [
            # The default port for HTTP is 80.
            ('http://127.0.0.1/', ('http', '127.0.0.1', 80, '/')),
            # The default port for HTTPS is 443.
            ('https://127.0.0.1/', ('https', '127.0.0.1', 443, '/')),
            # Specifying a port.
            ('http://spam:12345/', ('http', 'spam', 12345, '/')),
            # Weird (but commonly accepted) structure uses default port.
            ('http://spam:/', ('http', 'spam', 80, '/')),
            # Spaces in the hostname are trimmed, the default path is /.
            ('http://foo ', ('http', 'foo', 80, '/')),
            ]
        for url, expected in cases:
            self.assertEqual(client._parse(url), expected)


    def test_externalUnicodeInterference(self):
        """
        The scheme, host and path returned by L{client._parse} are C{str},
        even if the same URL was previously given to L{urlparse} as a
        C{unicode} string.
        """
        badInput = u'http://example.com/path'
        goodInput = badInput.encode('ascii')
        # Parse the unicode form first so that any state urlparse retains
        # from that call is in place when client._parse runs.
        urlparse(badInput)
        scheme, host, port, path = client._parse(goodInput)
        for component in (scheme, host, path):
            self.assertIsInstance(component, str)
250
251
252
class HTTPPageGetterTests(unittest.TestCase):
    """
    Tests for L{HTTPPageGetter}, the HTTP client protocol implementation
    used to implement L{getPage}.
    """
    def test_earlyHeaders(self):
        """
        When a connection is made, L{HTTPPageGetter} writes the request line
        and headers built from its factory.

          - The I{Host} value in the factory's C{headers} dict replaces the
            one derived from the URL.
          - I{User-Agent} comes from the factory's C{agent} attribute; the
            dict's I{User-Agent} value is not sent.
          - I{Content-Length} is computed from C{postdata}; the dict's value
            is not sent.
          - The dict's I{Cookie} value is sent first in the I{Cookie} header,
            followed by the pairs from the factory's C{cookies} attribute.
          - Other headers from the dict are sent unchanged, along with
            I{connection: close}.
          - The request body follows the blank line that ends the headers.
        """
        headers = {
            'Host': 'example.net',
            'User-Agent': 'fooble',
            'Cookie': 'blah blah',
            'Content-Length': '12981',
            'Useful': 'value'}
        factory = client.HTTPClientFactory(
            'http://foo/bar',
            agent="foobar",
            cookies={'baz': 'quux'},
            postdata="some data",
            headers=headers)
        transport = StringTransport()
        protocol = client.HTTPPageGetter()
        protocol.factory = factory
        protocol.makeConnection(transport)

        expected = "\r\n".join([
            "GET /bar HTTP/1.0",
            "Host: example.net",
            "User-Agent: foobar",
            "Content-Length: 9",
            "Useful: value",
            "connection: close",
            "Cookie: blah blah; baz=quux",
            "",
            "some data"])
        self.assertEqual(transport.value(), expected)
295
296
class GetBodyProtocol(Protocol):
    """
    A protocol which accumulates all data delivered to it and fires
    C{deferred} with the accumulated string when the connection is lost.

    @ivar deferred: The L{Deferred} fired with the collected body.
    @ivar buf: The data received so far.
    """

    def __init__(self, deferred):
        self.deferred = deferred
        self.buf = ''

    def dataReceived(self, data):
        # The parameter is named "data" rather than "bytes" so that it does
        # not shadow the builtin.
        self.buf += data

    def connectionLost(self, reason):
        # The body collected so far is delivered whatever C{reason} is; it is
        # not inspected here.
        self.deferred.callback(self.buf)
308
309
def getBody(response):
    """
    Deliver the body of C{response} to a L{GetBodyProtocol}.

    @return: A L{Deferred} which fires with the collected body once the
        protocol's connection is lost.
    """
    finished = defer.Deferred()
    collector = GetBodyProtocol(finished)
    response.deliverBody(collector)
    return finished
314
315
class WebClientTestCase(unittest.TestCase):
    """
    Tests for L{client.getPage}, L{client.downloadPage} and the
    L{client.HTTPClientFactory}/L{client.HTTPDownloader} factories behind
    them.  The tests run against a real L{server.Site} listening on
    127.0.0.1.
    """

    def _listen(self, site):
        """
        Listen for plain TCP connections to C{site} on an ephemeral port of
        the loopback interface.

        @return: The listening port.
        """
        return reactor.listenTCP(0, site, interface="127.0.0.1")


    def setUp(self):
        """
        Serve a resource tree whose root is a static directory containing a
        single file, C{file}, plus the test resources defined in this module.
        """
        self.agent = None # for twisted.web.client.Agent test
        self.cleanupServerConnections = 0
        name = self.mktemp()
        os.mkdir(name)
        FilePath(name).child("file").setContent("0123456789")
        r = static.File(name)
        r.putChild("redirect", util.Redirect("/file"))
        self.infiniteRedirectResource = CountingRedirect("/infiniteRedirect")
        r.putChild("infiniteRedirect", self.infiniteRedirectResource)
        r.putChild("wait", ForeverTakingResource())
        r.putChild("write-then-wait", ForeverTakingResource(write=True))
        r.putChild("error", ErrorResource())
        r.putChild("nolength", NoLengthResource())
        r.putChild("host", HostHeaderResource())
        r.putChild("payload", PayloadResource())
        r.putChild("broken", BrokenDownloadResource())
        r.putChild("cookiemirror", CookieMirrorResource())
        r.putChild('delay1', DelayResource(1))
        r.putChild('delay2', DelayResource(2))

        self.afterFoundGetCounter = CountingResource()
        r.putChild("afterFoundGetCounter", self.afterFoundGetCounter)
        r.putChild("afterFoundGetRedirect", util.Redirect("/afterFoundGetCounter"))

        miscasedHead = static.Data("miscased-head GET response content", "major/minor")
        # Requests using the method "Head" (not "HEAD") are rendered by this;
        # see test_getPageNotQuiteHEAD.
        miscasedHead.render_Head = lambda request: "miscased-head content"
        r.putChild("miscased-head", miscasedHead)

        self.extendedRedirect = ExtendedRedirect('/extendedRedirect')
        r.putChild("extendedRedirect", self.extendedRedirect)
        self.site = server.Site(r, timeout=None)
        self.wrapper = WrappingFactory(self.site)
        self.port = self._listen(self.wrapper)
        self.portno = self.port.getHost().port


    def tearDown(self):
        """
        Close cached L{Agent} connections, close up to
        C{self.cleanupServerConnections} server-side connections the test
        said it might leave open, and stop listening.
        """
        if self.agent:
            # clean up connections for twisted.web.client.Agent test.
            self.agent.closeCachedConnections()
            self.agent = None

        # If the test indicated it might leave some server-side connections
        # around, clean them up.  Take a list copy so popping from it is
        # independent of the live mapping.
        connections = list(self.wrapper.protocols.keys())
        # If there are fewer server-side connections than requested,
        # that's okay.  Some might have noticed that the client closed
        # the connection and cleaned up after themselves.
        for _ in range(min(len(connections), self.cleanupServerConnections)):
            proto = connections.pop()
            msg("Closing %r" % (proto,))
            proto.transport.loseConnection()
        if connections:
            msg("Some left-over connections; this test is probably buggy.")
        return self.port.stopListening()


    def getURL(self, path):
        """
        @return: The absolute I{http} URL of C{path} on the test server.
        """
        host = "http://127.0.0.1:%d/" % self.portno
        return urljoin(host, path)


    def testPayload(self):
        """
        The C{postdata} passed to L{client.getPage} is sent as the request
        body; L{PayloadResource} echoes it back when it is 100 bytes long.
        """
        s = "0123456789" * 10
        d = client.getPage(self.getURL("payload"), postdata=s)
        d.addCallback(self.assertEqual, s)
        return d


    def test_getPageBrokenDownload(self):
        """
        If the connection is closed before the number of bytes indicated by
        I{Content-Length} have been received, the L{Deferred} returned by
        L{getPage} fails with L{PartialDownloadError}.
        """
        d = client.getPage(self.getURL("broken"))
        d = self.assertFailure(d, client.PartialDownloadError)
        d.addCallback(lambda exc: self.assertEqual(exc.response, "abc"))
        return d


    def test_downloadPageBrokenDownload(self):
        """
        If the connection is closed before the number of bytes indicated by
        I{Content-Length} have been received, the L{Deferred} returned by
        L{downloadPage} fails with L{PartialDownloadError}.
        """
        # test what happens when download gets disconnected in the middle
        path = FilePath(self.mktemp())
        d = client.downloadPage(self.getURL("broken"), path.path)
        d = self.assertFailure(d, client.PartialDownloadError)

        def checkResponse(response):
            """
            The HTTP status code from the server is propagated through the
            C{PartialDownloadError}.
            """
            self.assertEqual(response.status, "200")
            self.assertEqual(response.message, "OK")
            return response
        d.addCallback(checkResponse)

        def cbFailed(ignored):
            self.assertEqual(path.getContent(), "abc")
        d.addCallback(cbFailed)
        return d


    def test_downloadPageLogsFileCloseError(self):
        """
        If there is an exception closing the file being written to after the
        connection is prematurely closed, that exception is logged.
        """
        class BrokenFile:
            def write(self, data):
                pass

            def close(self):
                raise IOError(ENOSPC, "No file left on device")

        d = client.downloadPage(self.getURL("broken"), BrokenFile())
        d = self.assertFailure(d, client.PartialDownloadError)
        def cbFailed(ignored):
            self.assertEqual(len(self.flushLoggedErrors(IOError)), 1)
        d.addCallback(cbFailed)
        return d


    def testHostHeader(self):
        """
        An explicit I{Host} header passed to L{client.getPage} is sent as-is.
        Otherwise the I{Host} header is derived from the URL's host and port.
        """
        implicit = client.getPage(self.getURL("host"))
        implicit.addCallback(
            self.assertEqual, "127.0.0.1:%s" % (self.portno,))
        explicit = client.getPage(
            self.getURL("host"), headers={"Host": "www.example.com"})
        explicit.addCallback(self.assertEqual, "www.example.com")
        return defer.gatherResults([implicit, explicit])


    def test_getPage(self):
        """
        L{client.getPage} returns a L{Deferred} which is called back with
        the body of the response if the default method B{GET} is used.
        """
        d = client.getPage(self.getURL("file"))
        d.addCallback(self.assertEqual, "0123456789")
        return d


    def test_getPageHEAD(self):
        """
        L{client.getPage} returns a L{Deferred} which is called back with
        the empty string if the method is I{HEAD} and there is a successful
        response code.
        """
        d = client.getPage(self.getURL("file"), method="HEAD")
        d.addCallback(self.assertEqual, "")
        return d


    def test_getPageNotQuiteHEAD(self):
        """
        If the request method is a different casing of I{HEAD} (ie, not all
        capitalized) then it is not a I{HEAD} request and the response body
        is returned.
        """
        d = client.getPage(self.getURL("miscased-head"), method='Head')
        d.addCallback(self.assertEqual, "miscased-head content")
        return d


    def test_timeoutNotTriggering(self):
        """
        When a non-zero timeout is passed to L{getPage} and the page is
        retrieved before the timeout period elapses, the L{Deferred} is
        called back with the contents of the page.
        """
        d = client.getPage(self.getURL("host"), timeout=100)
        d.addCallback(self.assertEqual, "127.0.0.1:%s" % (self.portno,))
        return d


    def test_timeoutTriggering(self):
        """
        When a non-zero timeout is passed to L{getPage} and that many
        seconds elapse before the server responds to the request, the
        L{Deferred} is errbacked with a L{error.TimeoutError}.
        """
        # This will probably leave some connections around.
        self.cleanupServerConnections = 1
        return self.assertFailure(
            client.getPage(self.getURL("wait"), timeout=0.000001),
            defer.TimeoutError)


    def testDownloadPage(self):
        """
        L{client.downloadPage} writes the response body to the named file,
        both with and without a I{Content-Length} in the response.
        """
        downloads = []
        downloadData = [("file", self.mktemp(), "0123456789"),
                        ("nolength", self.mktemp(), "nolength")]

        for (url, name, data) in downloadData:
            d = client.downloadPage(self.getURL(url), name)
            d.addCallback(self._cbDownloadPageTest, data, name)
            downloads.append(d)
        return defer.gatherResults(downloads)


    def _cbDownloadPageTest(self, ignored, data, name):
        """
        Assert that the file C{name} contains exactly C{data}.
        """
        # FilePath.getContent closes the file it opens, unlike the previous
        # file(...).read().
        self.assertEqual(FilePath(name).getContent(), data)


    def testDownloadPageError1(self):
        """
        An exception raised by the target file's C{write} fails the
        L{Deferred} returned by L{client.downloadPage}.
        """
        class errorfile:
            def write(self, data):
                raise IOError("badness happened during write")
            def close(self):
                pass
        ef = errorfile()
        return self.assertFailure(
            client.downloadPage(self.getURL("file"), ef),
            IOError)


    def testDownloadPageError2(self):
        """
        An exception raised by the target file's C{close} fails the
        L{Deferred} returned by L{client.downloadPage}.
        """
        class errorfile:
            def write(self, data):
                pass
            def close(self):
                raise IOError("badness happened during close")
        ef = errorfile()
        return self.assertFailure(
            client.downloadPage(self.getURL("file"), ef),
            IOError)


    def testDownloadPageError3(self):
        """
        A failure to open the target file fails the L{Deferred} returned by
        L{client.downloadPage}.
        """
        # make sure failures in open() are caught too. This is tricky.
        # Might only work on posix.
        tmpfile = open("unwritable", "wb")
        tmpfile.close()
        os.chmod("unwritable", 0) # make it unwritable (to us)
        d = self.assertFailure(
            client.downloadPage(self.getURL("file"), "unwritable"),
            IOError)
        d.addBoth(self._cleanupDownloadPageError3)
        return d


    def _cleanupDownloadPageError3(self, ignored):
        """
        Restore permissions on and remove the file created by
        L{testDownloadPageError3}, passing C{ignored} through.
        """
        import stat
        # Owner read/write/execute (octal 700), spelled portably instead of
        # the Python-2-only literal 0700.
        os.chmod("unwritable", stat.S_IRWXU)
        os.unlink("unwritable")
        return ignored


    def _downloadTest(self, method):
        """
        Call C{method} with paths whose responses are HTTP errors and check
        that each result fails with L{error.Error} carrying the expected
        status code.

        @param method: A one-argument callable taking a path and returning a
            L{Deferred}.
        """
        dl = []
        for (url, code) in [("nosuchfile", "404"), ("error", "401"),
                            ("error?showlength=1", "401")]:
            d = method(url)
            d = self.assertFailure(d, error.Error)
            # Bind code as a default argument so each callback checks its
            # own iteration's value.
            d.addCallback(lambda exc, code=code: self.assertEqual(exc.args[0], code))
            dl.append(d)
        return defer.DeferredList(dl, fireOnOneErrback=True)


    def testServerError(self):
        """
        L{client.getPage} fails with L{error.Error} for HTTP error responses.
        """
        return self._downloadTest(lambda url: client.getPage(self.getURL(url)))


    def testDownloadServerError(self):
        """
        L{client.downloadPage} fails with L{error.Error} for HTTP error
        responses.
        """
        return self._downloadTest(lambda url: client.downloadPage(self.getURL(url), url.split('?')[0]))


    def testFactoryInfo(self):
        """
        After a successful request, the L{client.HTTPClientFactory} exposes
        the response status, version, message and headers.
        """
        url = self.getURL('file')
        scheme, host, port, path = client._parse(url)
        factory = client.HTTPClientFactory(url)
        reactor.connectTCP(host, port, factory)
        return factory.deferred.addCallback(self._cbFactoryInfo, factory)


    def _cbFactoryInfo(self, ignoredResult, factory):
        """
        Check the response information recorded on C{factory} for a request
        for C{file}.
        """
        self.assertEqual(factory.status, '200')
        self.assertTrue(factory.version.startswith('HTTP/'))
        self.assertEqual(factory.message, 'OK')
        self.assertEqual(factory.response_headers['content-length'][0], '10')


    def test_followRedirect(self):
        """
        By default, L{client.getPage} follows redirects and returns the content
        of the target resource.
        """
        d = client.getPage(self.getURL("redirect"))
        d.addCallback(self.assertEqual, "0123456789")
        return d


    def test_noFollowRedirect(self):
        """
        If C{followRedirect} is passed a false value, L{client.getPage} does not
        follow redirects and returns a L{Deferred} which fails with
        L{error.PageRedirect} when it encounters one.
        """
        d = self.assertFailure(
            client.getPage(self.getURL("redirect"), followRedirect=False),
            error.PageRedirect)
        d.addCallback(self._cbCheckLocation)
        return d


    def _cbCheckLocation(self, exc):
        """
        The redirect failure carries the I{Location} of the redirect.
        """
        self.assertEqual(exc.location, "/file")


    def test_infiniteRedirection(self):
        """
        When more than C{redirectLimit} HTTP redirects are encountered, the
        page request fails with L{InfiniteRedirection}.
        """
        def checkRedirectCount(*a):
            self.assertEqual(f._redirectCount, 13)
            self.assertEqual(self.infiniteRedirectResource.count, 13)

        f = client._makeGetterFactory(
            self.getURL('infiniteRedirect'),
            client.HTTPClientFactory,
            redirectLimit=13)
        d = self.assertFailure(f.deferred, error.InfiniteRedirection)
        d.addCallback(checkRedirectCount)
        return d


    def test_isolatedFollowRedirect(self):
        """
        C{client.HTTPPageGetter} instances each obey the C{followRedirect}
        value passed to the L{client.getPage} call which created them.
        """
        d1 = client.getPage(self.getURL('redirect'), followRedirect=True)
        d2 = client.getPage(self.getURL('redirect'), followRedirect=False)

        d = self.assertFailure(d2, error.PageRedirect
            ).addCallback(lambda dummy: d1)
        return d


    def test_afterFoundGet(self):
        """
        Enabling unsafe redirection behaviour overwrites the method of
        redirected C{POST} requests with C{GET}.
        """
        url = self.getURL('extendedRedirect?code=302')
        f = client.HTTPClientFactory(url, followRedirect=True, method="POST")
        self.assertFalse(
            f.afterFoundGet,
            "By default, afterFoundGet must be disabled")

        def gotPage(page):
            self.assertEqual(
                self.extendedRedirect.lastMethod,
                "GET",
                "With afterFoundGet, the HTTP method must change to GET")

        d = client.getPage(
            url, followRedirect=True, afterFoundGet=True, method="POST")
        d.addCallback(gotPage)
        return d


    def test_downloadAfterFoundGet(self):
        """
        Passing C{True} for C{afterFoundGet} to L{client.downloadPage} invokes
        the same kind of redirect handling as passing that argument to
        L{client.getPage} invokes.
        """
        url = self.getURL('extendedRedirect?code=302')

        def gotPage(page):
            self.assertEqual(
                self.extendedRedirect.lastMethod,
                "GET",
                "With afterFoundGet, the HTTP method must change to GET")

        # Download into a per-test temporary path rather than a fixed name
        # in the working directory.
        d = client.downloadPage(url, self.mktemp(),
            followRedirect=True, afterFoundGet=True, method="POST")
        d.addCallback(gotPage)
        return d


    def test_afterFoundGetMakesOneRequest(self):
        """
        When C{afterFoundGet} is C{True}, L{client.getPage} only issues one
        request to the server when following the redirect.  This is a regression
        test, see #4760.
        """
        def checkRedirectCount(*a):
            self.assertEqual(self.afterFoundGetCounter.count, 1)

        url = self.getURL('afterFoundGetRedirect')
        d = client.getPage(
            url, followRedirect=True, afterFoundGet=True, method="POST")
        d.addCallback(checkRedirectCount)
        return d


    def testPartial(self):
        """
        L{client.downloadPage} with C{supportPartial} resumes a download into
        an existing partial file.
        """
        name = self.mktemp()
        f = open(name, "wb")
        f.write("abcd")
        f.close()

        partialDownload = [(True, "abcd456789"),
                           (True, "abcd456789"),
                           (False, "0123456789")]

        d = defer.succeed(None)
        for (partial, expectedData) in partialDownload:
            d.addCallback(self._cbRunPartial, name, partial)
            d.addCallback(self._cbPartialTest, expectedData, name)

        return d

    testPartial.skip = "Cannot test until webserver can serve partial data properly"


    def _cbRunPartial(self, ignored, name, partial):
        """
        Download C{file} into C{name}, passing C{partial} as
        C{supportPartial}.
        """
        return client.downloadPage(self.getURL("file"), name, supportPartial=partial)


    def _cbPartialTest(self, ignored, expectedData, filename):
        """
        Assert that C{filename} contains exactly C{expectedData}.
        """
        self.assertEqual(FilePath(filename).getContent(), expectedData)


    def test_downloadTimeout(self):
        """
        If the timeout indicated by the C{timeout} parameter to
        L{client.HTTPDownloader.__init__} elapses without the complete response
        being received, the L{defer.Deferred} returned by
        L{client.downloadPage} fires with a L{Failure} wrapping a
        L{defer.TimeoutError}.
        """
        self.cleanupServerConnections = 2
        # Verify the behavior if no bytes are ever written.
        first = client.downloadPage(
            self.getURL("wait"),
            self.mktemp(), timeout=0.01)

        # Verify the behavior if some bytes are written but then the request
        # never completes.
        second = client.downloadPage(
            self.getURL("write-then-wait"),
            self.mktemp(), timeout=0.01)

        return defer.gatherResults([
            self.assertFailure(first, defer.TimeoutError),
            self.assertFailure(second, defer.TimeoutError)])


    def test_downloadHeaders(self):
        """
        After L{client.HTTPDownloader.deferred} fires, the
        L{client.HTTPDownloader} instance's C{status} and C{response_headers}
        attributes are populated with the values from the response.
        """
        def checkHeaders(factory):
            self.assertEqual(factory.status, '200')
            self.assertEqual(factory.response_headers['content-type'][0], 'text/html')
            self.assertEqual(factory.response_headers['content-length'][0], '10')
            os.unlink(factory.fileName)
        factory = client._makeGetterFactory(
            self.getURL('file'),
            client.HTTPDownloader,
            fileOrName=self.mktemp())
        return factory.deferred.addCallback(lambda _: checkHeaders(factory))


    def test_downloadCookies(self):
        """
        The C{cookies} dict passed to the L{client.HTTPDownloader}
        initializer is used to populate the I{Cookie} header included in the
        request sent to the server.
        """
        output = self.mktemp()
        factory = client._makeGetterFactory(
            self.getURL('cookiemirror'),
            client.HTTPDownloader,
            fileOrName=output,
            cookies={'foo': 'bar'})
        def cbFinished(ignored):
            self.assertEqual(
                FilePath(output).getContent(),
                "[('foo', 'bar')]")
        factory.deferred.addCallback(cbFinished)
        return factory.deferred


    def test_downloadRedirectLimit(self):
        """
        When more than C{redirectLimit} HTTP redirects are encountered, the
        page request fails with L{InfiniteRedirection}.
        """
        def checkRedirectCount(*a):
            self.assertEqual(f._redirectCount, 7)
            self.assertEqual(self.infiniteRedirectResource.count, 7)

        f = client._makeGetterFactory(
            self.getURL('infiniteRedirect'),
            client.HTTPDownloader,
            fileOrName=self.mktemp(),
            redirectLimit=7)
        d = self.assertFailure(f.deferred, error.InfiniteRedirection)
        d.addCallback(checkRedirectCount)
        return d
818
819
820
class WebClientSSLTestCase(WebClientTestCase):
    """
    Runs every L{WebClientTestCase} test against a server listening with TLS.
    """
    def _listen(self, site):
        from twisted import test
        # The same PEM file supplies both the private key and the certificate.
        pemPath = FilePath(test.__file__).sibling('server.pem').path
        contextFactory = ssl.DefaultOpenSSLContextFactory(pemPath, pemPath)
        return reactor.listenSSL(0, site,
                                 contextFactory=contextFactory,
                                 interface="127.0.0.1")

    def getURL(self, path):
        """
        @return: The absolute I{https} URL of C{path} on the test server.
        """
        return "https://127.0.0.1:%d/%s" % (self.portno, path)

    def testFactoryInfo(self):
        url = self.getURL('file')
        host, port = client._parse(url)[1:3]
        factory = client.HTTPClientFactory(url)
        clientContext = ssl.ClientContextFactory()
        reactor.connectSSL(host, port, factory, clientContext)
        d = factory.deferred
        # The inherited _cbFactoryInfo performs the assertions.
        d.addCallback(self._cbFactoryInfo, factory)
        return d
841
842
843
class WebClientRedirectBetweenSSLandPlainText(unittest.TestCase):
    """
    L{client.getPage} follows a chain of redirects which alternates between
    a plain HTTP server and an HTTPS server.
    """
    def getHTTPS(self, path):
        return "https://127.0.0.1:%d/%s" % (self.tlsPortno, path)

    def getHTTP(self, path):
        return "http://127.0.0.1:%d/%s" % (self.plainPortno, path)

    def setUp(self):
        plainRoot = static.Data('not me', 'text/plain')
        tlsRoot = static.Data('me neither', 'text/plain')

        plainSite = server.Site(plainRoot, timeout=None)
        tlsSite = server.Site(tlsRoot, timeout=None)

        from twisted import test
        pemPath = FilePath(test.__file__).sibling('server.pem').path
        self.tlsPort = reactor.listenSSL(
            0, tlsSite,
            contextFactory=ssl.DefaultOpenSSLContextFactory(pemPath, pemPath),
            interface="127.0.0.1")
        self.plainPort = reactor.listenTCP(0, plainSite, interface="127.0.0.1")

        self.plainPortno = self.plainPort.getHost().port
        self.tlsPortno = self.tlsPort.getHost().port

        # Redirect chain: one (HTTP) -> two (HTTPS) -> three (HTTP) -> four
        # (HTTPS), where the final resource holds the expected body.
        plainRoot.putChild('one', util.Redirect(self.getHTTPS('two')))
        tlsRoot.putChild('two', util.Redirect(self.getHTTP('three')))
        plainRoot.putChild('three', util.Redirect(self.getHTTPS('four')))
        tlsRoot.putChild('four', static.Data('FOUND IT!', 'text/plain'))

    def tearDown(self):
        stoppers = [self.plainPort.stopListening, self.tlsPort.stopListening]
        return defer.gatherResults(
            [defer.maybeDeferred(stop) for stop in stoppers])

    def testHoppingAround(self):
        d = client.getPage(self.getHTTP("one"))
        d.addCallback(self.assertEqual, "FOUND IT!")
        return d
884
class FakeTransport:
    """
    Minimal stand-in for a transport.  Every chunk handed to C{write} is
    recorded, in order, in the C{data} list; C{disconnecting} starts out
    C{False}.
    """
    disconnecting = False

    def __init__(self):
        self.data = list()

    def write(self, stuff):
        self.data += [stuff]
891
class CookieTestCase(unittest.TestCase):
    """
    Tests for cookie handling by L{client.getPage} and L{HTTPClientFactory}:
    sending cookies to a server and parsing I{Set-Cookie} response headers.
    """
    def _listen(self, site):
        return reactor.listenTCP(0, site, interface="127.0.0.1")

    def setUp(self):
        root = static.Data('El toro!', 'text/plain')
        root.putChild("cookiemirror", CookieMirrorResource())
        root.putChild("rawcookiemirror", RawCookieMirrorResource())
        site = server.Site(root, timeout=None)
        self.port = self._listen(site)
        self.portno = self.port.getHost().port

    def tearDown(self):
        return self.port.stopListening()

    def getHTTP(self, path):
        return "http://127.0.0.1:%d/%s" % (self.portno, path)

    def testNoCookies(self):
        """
        Without a C{cookies} argument the I{cookiemirror} resource reports an
        empty list.
        """
        return client.getPage(self.getHTTP("cookiemirror")
            ).addCallback(self.assertEqual, "[]"
            )

    def testSomeCookies(self):
        """
        Cookies passed to L{client.getPage} are received by the server.
        """
        cookies = {'foo': 'bar', 'baz': 'quux'}
        return client.getPage(self.getHTTP("cookiemirror"), cookies=cookies
            ).addCallback(self.assertEqual, "[('baz', 'quux'), ('foo', 'bar')]"
            )

    def testRawNoCookies(self):
        """
        Without a C{cookies} argument the I{rawcookiemirror} resource reports
        C{None}.
        """
        return client.getPage(self.getHTTP("rawcookiemirror")
            ).addCallback(self.assertEqual, "None"
            )

    def testRawSomeCookies(self):
        """
        Cookies passed to L{client.getPage} reach the server as quoted
        C{name=value} pairs joined by C{'; '}.  The pairs are compared without
        regard to order, because the cookies come from a C{dict}, whose
        iteration order is not guaranteed.
        """
        cookies = {'foo': 'bar', 'baz': 'quux'}

        def cbRawCookies(rawCookies):
            # Expected shape is e.g. "'foo=bar; baz=quux'" (in some order).
            self.assertTrue(
                rawCookies.startswith("'") and rawCookies.endswith("'"),
                rawCookies)
            self.assertEqual(
                sorted(rawCookies[1:-1].split('; ')),
                ['baz=quux', 'foo=bar'])

        return client.getPage(self.getHTTP("rawcookiemirror"), cookies=cookies
            ).addCallback(cbRawCookies)

    def testCookieHeaderParsing(self):
        """
        I{Set-Cookie} headers in a response are parsed into the factory's
        C{cookies} dict, and the request written before them is a plain
        HTTP/1.0 GET.
        """
        factory = client.HTTPClientFactory('http://foo.example.com/')
        proto = factory.buildProtocol('127.42.42.42')
        proto.transport = FakeTransport()
        proto.connectionMade()
        for line in [
            '200 Ok',
            'Squash: yes',
            'Hands: stolen',
            'Set-Cookie: CUSTOMER=WILE_E_COYOTE; path=/; expires=Wednesday, 09-Nov-99 23:12:40 GMT',
            'Set-Cookie: PART_NUMBER=ROCKET_LAUNCHER_0001; path=/',
            'Set-Cookie: SHIPPING=FEDEX; path=/foo',
            '',
            'body',
            'more body',
            ]:
            proto.dataReceived(line + '\r\n')
        self.assertEqual(proto.transport.data,
                          ['GET / HTTP/1.0\r\n',
                           'Host: foo.example.com\r\n',
                           'User-Agent: Twisted PageGetter\r\n',
                           '\r\n'])
        self.assertEqual(factory.cookies,
                          {
            'CUSTOMER': 'WILE_E_COYOTE',
            'PART_NUMBER': 'ROCKET_LAUNCHER_0001',
            'SHIPPING': 'FEDEX',
            })
960
961
962
class TestHostHeader(unittest.TestCase):
    """
    Test that L{HTTPClientFactory} includes the port in the host header
    if needed.
    """

    def _getHost(self, bytes):
        """
        Retrieve the value of the I{Host} header from the serialized
        request given by C{bytes}.

        @return: The stripped header value, or C{None} if no line of
            C{bytes} is a I{Host} header.
        """
        for line in bytes.splitlines():
            try:
                # Lines without a colon (e.g. the request line) make the
                # unpacking raise ValueError and are skipped.
                name, value = line.split(':', 1)
                if name.strip().lower() == 'host':
                    return value.strip()
            except ValueError:
                pass


    def _getHostForURL(self, url):
        """
        Connect the protocol of a new L{HTTPClientFactory} for C{url} to a
        L{StringTransport} and return the I{Host} header it writes.
        """
        factory = client.HTTPClientFactory(url)
        proto = factory.buildProtocol('127.42.42.42')
        proto.makeConnection(StringTransport())
        return self._getHost(proto.transport.value())


    def test_HTTPDefaultPort(self):
        """
        No port should be included in the host header when connecting to the
        default HTTP port.
        """
        self.assertEqual(self._getHostForURL('http://foo.example.com/'),
                          'foo.example.com')


    def test_HTTPPort80(self):
        """
        No port should be included in the host header when connecting to the
        default HTTP port even if it is in the URL.
        """
        self.assertEqual(self._getHostForURL('http://foo.example.com:80/'),
                          'foo.example.com')


    def test_HTTPNotPort80(self):
        """
        The port should be included in the host header when connecting to
        a non-default HTTP port.
        """
        self.assertEqual(self._getHostForURL('http://foo.example.com:8080/'),
                          'foo.example.com:8080')


    def test_HTTPSDefaultPort(self):
        """
        No port should be included in the host header when connecting to the
        default HTTPS port.
        """
        self.assertEqual(self._getHostForURL('https://foo.example.com/'),
                          'foo.example.com')


    def test_HTTPSPort443(self):
        """
        No port should be included in the host header when connecting to the
        default HTTPS port even if it is in the URL.
        """
        self.assertEqual(self._getHostForURL('https://foo.example.com:443/'),
                          'foo.example.com')


    def test_HTTPSNotPort443(self):
        """
        The port should be included in the host header when connecting to
        a non-default HTTPS port.
        """
        # This must be an https URL; with http it would only repeat
        # test_HTTPNotPort80.
        self.assertEqual(self._getHostForURL('https://foo.example.com:8080/'),
                          'foo.example.com:8080')
1053
1054
1055
class StubHTTPProtocol(Protocol):
    """
    A protocol like L{HTTP11ClientProtocol} but which does not actually know
    HTTP/1.1 and only collects requests in a list.

    @ivar requests: A C{list} of two-tuples.  Each time a request is made, a
        tuple consisting of the request and the L{Deferred} returned from the
        request method is appended to this list.

    @ivar state: A C{str} naming the protocol's state.  It starts out as
        C{'QUIESCENT'}; this class never changes it, but tests may assign it
        directly.
    """
    def __init__(self):
        self.state = 'QUIESCENT'
        self.requests = []


    def request(self, request):
        """
        Record C{request} instead of sending it anywhere.

        @return: A new L{Deferred} which this stub never fires.  The same
            L{Deferred} is stored next to C{request} in C{self.requests}.
        """
        pending = Deferred()
        self.requests.append((request, pending))
        return pending
1079
1080
1081
class FileConsumer(object):
    """
    Minimal consumer which forwards every chunk it is given to a file-like
    object.

    @ivar outputFile: The file-like object which receives the written data.
    """
    def __init__(self, outputFile):
        self.outputFile = outputFile


    def write(self, bytes):
        """
        Pass C{bytes} straight through to C{self.outputFile.write}.
        """
        sink = self.outputFile.write
        sink(bytes)
1089
1090
1091
class FileBodyProducerTests(unittest.TestCase):
    """
    Tests for the L{FileBodyProducer} which reads bytes from a file and writes
    them to an L{IConsumer}.
    """
    # Private sentinel returned by _resultNow/_failureNow when the Deferred
    # has not fired yet.
    _NO_RESULT = object()

    def _resultNow(self, deferred):
        """
        Return the current result of C{deferred} if it is not a failure.  If it
        has no result, return C{self._NO_RESULT}.  If it is a failure, raise an
        exception.

        The callbacks attached here return C{None}, so any result (or failure)
        is consumed: later callbacks on C{deferred} see C{None}.
        """
        result = []
        failure = []
        deferred.addCallbacks(result.append, failure.append)
        if len(result) == 1:
            return result[0]
        elif len(failure) == 1:
            raise Exception(
                "Deferred had failure instead of success: %r" % (failure[0],))
        return self._NO_RESULT


    def _failureNow(self, deferred):
        """
        Return the current result of C{deferred} if it is a failure.  If it has
        no result, return C{self._NO_RESULT}.  If it is not a failure, raise an
        exception.

        Like L{_resultNow}, this consumes the result of C{deferred}.
        """
        result = []
        failure = []
        deferred.addCallbacks(result.append, failure.append)
        if len(result) == 1:
            raise Exception(
                "Deferred had success instead of failure: %r" % (result[0],))
        elif len(failure) == 1:
            return failure[0]
        return self._NO_RESULT


    def _termination(self):
        """
        This method can be used as the C{terminationPredicateFactory} for a
        L{Cooperator}.  It returns a predicate which immediately returns
        C{True}, indicating that no more work should be done this iteration.
        This has the result of only allowing one iteration of a cooperative
        task to be run per L{Cooperator} iteration.
        """
        return lambda: True


    def setUp(self):
        """
        Create a L{Cooperator} hooked up to an easily controlled, deterministic
        scheduler to use with L{FileBodyProducer}.
        """
        # Scheduled calls are collected here instead of being run; tests pop
        # and invoke them one at a time to drive the cooperator by hand.
        self._scheduled = []
        self.cooperator = task.Cooperator(
            self._termination, self._scheduled.append)


    def test_interface(self):
        """
        L{FileBodyProducer} instances provide L{IBodyProducer}.
        """
        self.assertTrue(verifyObject(
                IBodyProducer, FileBodyProducer(StringIO(""))))


    def test_unknownLength(self):
        """
        If the L{FileBodyProducer} is constructed with a file-like object
        which lacks either a C{seek} or a C{tell} method, its C{length}
        attribute is set to C{UNKNOWN_LENGTH}.
        """
        class HasSeek(object):
            def seek(self, offset, whence):
                pass

        class HasTell(object):
            def tell(self):
                pass

        producer = FileBodyProducer(HasSeek())
        self.assertEqual(UNKNOWN_LENGTH, producer.length)
        producer = FileBodyProducer(HasTell())
        self.assertEqual(UNKNOWN_LENGTH, producer.length)


    def test_knownLength(self):
        """
        If the L{FileBodyProducer} is constructed with a file-like object with
        both C{seek} and C{tell} methods, its C{length} attribute is set to the
        size of the file as determined by those methods.
        """
        inputBytes = "here are some bytes"
        inputFile = StringIO(inputBytes)
        inputFile.seek(5)
        producer = FileBodyProducer(inputFile)
        self.assertEqual(len(inputBytes) - 5, producer.length)
        # Measuring the length must not move the file position.
        self.assertEqual(inputFile.tell(), 5)


    def test_defaultCooperator(self):
        """
        If no L{Cooperator} instance is passed to L{FileBodyProducer}, the
        global cooperator is used.
        """
        producer = FileBodyProducer(StringIO(""))
        self.assertEqual(task.cooperate, producer._cooperate)


    def test_startProducing(self):
        """
        L{FileBodyProducer.startProducing} starts writing bytes from the input
        file to the given L{IConsumer} and returns a L{Deferred} which fires
        when they have all been written.
        """
        expectedResult = "hello, world"
        readSize = 3
        output = StringIO()
        consumer = FileConsumer(output)
        producer = FileBodyProducer(
            StringIO(expectedResult), self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        # Presumably one iteration per full read plus one which finds EOF.
        # Integer division keeps range() happy on Python 3 as well.
        for _ in range(len(expectedResult) // readSize + 1):
            self._scheduled.pop(0)()
        self.assertEqual([], self._scheduled)
        self.assertEqual(expectedResult, output.getvalue())
        self.assertEqual(None, self._resultNow(complete))


    def test_inputClosedAtEOF(self):
        """
        When L{FileBodyProducer} reaches end-of-file on the input file given to
        it, the input file is closed.
        """
        readSize = 4
        inputBytes = "some friendly bytes"
        inputFile = StringIO(inputBytes)
        producer = FileBodyProducer(inputFile, self.cooperator, readSize)
        consumer = FileConsumer(StringIO())
        producer.startProducing(consumer)
        # Full reads, one partial read, and one read which finds EOF.
        for _ in range(len(inputBytes) // readSize + 2):
            self._scheduled.pop(0)()
        self.assertTrue(inputFile.closed)


    def test_failedReadWhileProducing(self):
        """
        If a read from the input file fails while producing bytes to the
        consumer, the L{Deferred} returned by
        L{FileBodyProducer.startProducing} fires with a L{Failure} wrapping
        that exception.
        """
        class BrokenFile(object):
            def read(self, count):
                raise IOError("Simulated bad thing")
        producer = FileBodyProducer(BrokenFile(), self.cooperator)
        complete = producer.startProducing(FileConsumer(StringIO()))
        self._scheduled.pop(0)()
        self._failureNow(complete).trap(IOError)


    def test_stopProducing(self):
        """
        L{FileBodyProducer.stopProducing} stops the underlying L{IPullProducer}
        and the cooperative task responsible for calling C{resumeProducing} and
        closes the input file but does not cause the L{Deferred} returned by
        C{startProducing} to fire.
        """
        expectedResult = "hello, world"
        readSize = 3
        output = StringIO()
        consumer = FileConsumer(output)
        inputFile = StringIO(expectedResult)
        producer = FileBodyProducer(
            inputFile, self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        producer.stopProducing()
        self.assertTrue(inputFile.closed)
        self._scheduled.pop(0)()
        self.assertEqual("", output.getvalue())
        self.assertIdentical(self._NO_RESULT, self._resultNow(complete))


    def test_pauseProducing(self):
        """
        L{FileBodyProducer.pauseProducing} temporarily suspends writing bytes
        from the input file to the given L{IConsumer}.
        """
        expectedResult = "hello, world"
        readSize = 5
        output = StringIO()
        consumer = FileConsumer(output)
        producer = FileBodyProducer(
            StringIO(expectedResult), self.cooperator, readSize)
        complete = producer.startProducing(consumer)
        self._scheduled.pop(0)()
        self.assertEqual(output.getvalue(), expectedResult[:5])
        producer.pauseProducing()

        # Sort of depends on an implementation detail of Cooperator: even
        # though the only task is paused, there's still a scheduled call.  If
        # this were to go away because Cooperator became smart enough to cancel
        # this call in this case, that would be fine.
        self._scheduled.pop(0)()

        # Since the producer is paused, no new data should be here.
        self.assertEqual(output.getvalue(), expectedResult[:5])
        self.assertEqual([], self._scheduled)
        self.assertIdentical(self._NO_RESULT, self._resultNow(complete))


    def test_resumeProducing(self):
        """
        L{FileBodyProducer.resumeProducing} re-commences writing bytes from the
        input file to the given L{IConsumer} after it was previously paused
        with L{FileBodyProducer.pauseProducing}.
        """
        expectedResult = "hello, world"
        readSize = 5
        output = StringIO()
        consumer = FileConsumer(output)
        producer = FileBodyProducer(
            StringIO(expectedResult), self.cooperator, readSize)
        producer.startProducing(consumer)
        self._scheduled.pop(0)()
        self.assertEqual(expectedResult[:readSize], output.getvalue())
        producer.pauseProducing()
        producer.resumeProducing()
        self._scheduled.pop(0)()
        self.assertEqual(expectedResult[:readSize * 2], output.getvalue())
1326
1327
1328
class FakeReactorAndConnectMixin:
    """
    A test mixin providing a testable C{Reactor} class and a dummy C{connect}
    method which allows instances to pretend to be endpoints.
    """

    class Reactor(MemoryReactor, Clock):
        """
        A fake reactor which is both a L{MemoryReactor} and a L{Clock}.
        """
        def __init__(self):
            MemoryReactor.__init__(self)
            Clock.__init__(self)


    class StubEndpoint(object):
        """
        Endpoint that wraps existing endpoint, substitutes StubHTTPProtocol, and
        resulting protocol instances are attached to the given test case.
        """

        def __init__(self, endpoint, testCase):
            self.endpoint = endpoint
            self.testCase = testCase

            def ignoreQuiescence(p):
                return None
            self.factory = _HTTP11ClientFactory(ignoreQuiescence)
            self.protocol = StubHTTPProtocol()

            # Look up self.protocol at call time so the factory always hands
            # out whatever stub this endpoint currently holds.
            def buildProtocol(addr):
                return self.protocol
            self.factory.buildProtocol = buildProtocol

        def connect(self, ignoredFactory):
            """
            Expose the stub protocol on the test case, connect the wrapped
            endpoint with the stub factory (ignoring its result), and return
            an already-fired L{Deferred} with the stub protocol.
            """
            stub = self.protocol
            self.testCase.protocol = stub
            self.endpoint.connect(self.factory)
            return succeed(stub)


    def buildAgentForWrapperTest(self, reactor):
        """
        Return an Agent suitable for use in tests that wrap the Agent and want
        both a fake reactor and StubHTTPProtocol.
        """
        agent = client.Agent(reactor)
        realGetEndpoint = agent._getEndpoint

        def getStubEndpoint(*args):
            return self.StubEndpoint(realGetEndpoint(*args), self)

        agent._getEndpoint = getStubEndpoint
        return agent


    def connect(self, factory):
        """
        Fake implementation of an endpoint which synchronously
        succeeds with an instance of L{StubHTTPProtocol} for ease of
        testing.
        """
        stub = StubHTTPProtocol()
        stub.makeConnection(None)
        self.protocol = stub
        return succeed(stub)
1382
1383
1384
class DummyEndpoint(object):
    """
    An endpoint that uses a fake transport.
    """

    def connect(self, factory):
        """
        Build a protocol from C{factory}, connect it to a new
        L{StringTransport}, and return an already-fired L{Deferred} with it.
        """
        proto = factory.buildProtocol(None)
        fakeTransport = StringTransport()
        proto.makeConnection(fakeTransport)
        return succeed(proto)
1394
1395
1396
class BadEndpoint(object):
    """
    An endpoint that shouldn't be called: C{connect} always raises.
    """

    def connect(self, factory):
        """
        @raise RuntimeError: Always; reaching this means a test created a
            connection it should have taken from a cache instead.
        """
        complaint = "This endpoint should not have been used."
        raise RuntimeError(complaint)
1404
1405
class DummyFactory(Factory):
    """
    Create C{StubHTTPProtocol} instances.

    The constructor accepts a quiescent callback, so this class can be used
    where a factory class taking one is expected, but the callback is
    discarded.
    """
    protocol = StubHTTPProtocol

    def __init__(self, quiescentCallback):
        # Deliberately ignore the callback (and do not chain to Factory).
        pass
1414
1415
1416
1417 class HTTPConnectionPoolTests(unittest.TestCase, FakeReactorAndConnectMixin):
1418     """
1419     Tests for the L{HTTPConnectionPool} class.
1420     """
1421
1422     def setUp(self):
1423         self.fakeReactor = self.Reactor()
1424         self.pool = HTTPConnectionPool(self.fakeReactor)
1425         self.pool._factory = DummyFactory
1426         # The retry code path is tested in HTTPConnectionPoolRetryTests:
1427         self.pool.retryAutomatically = False
1428
1429
1430     def test_getReturnsNewIfCacheEmpty(self):
1431         """
1432         If there are no cached connections,
1433         L{HTTPConnectionPool.getConnection} returns a new connection.
1434         """
1435         self.assertEqual(self.pool._connections, {})
1436
1437         def gotConnection(conn):
1438             self.assertIsInstance(conn, StubHTTPProtocol)
1439             # The new connection is not stored in the pool:
1440             self.assertNotIn(conn, self.pool._connections.values())
1441
1442         unknownKey = 12245
1443         d = self.pool.getConnection(unknownKey, DummyEndpoint())
1444         return d.addCallback(gotConnection)
1445
1446
1447     def test_putStartsTimeout(self):
1448         """
1449         If a connection is put back to the pool, a 240-sec timeout is started.
1450
1451         When the timeout hits, the connection is closed and removed from the
1452         pool.
1453         """
1454         # We start out with one cached connection:
1455         protocol = StubHTTPProtocol()
1456         protocol.makeConnection(StringTransport())
1457         self.pool._putConnection(("http", "example.com", 80), protocol)
1458
1459         # Connection is in pool, still not closed:
1460         self.assertEqual(protocol.transport.disconnecting, False)
1461         self.assertIn(protocol,
1462                       self.pool._connections[("http", "example.com", 80)])
1463
1464         # Advance 239 seconds, still not closed:
1465         self.fakeReactor.advance(239)
1466         self.assertEqual(protocol.transport.disconnecting, False)
1467         self.assertIn(protocol,
1468                       self.pool._connections[("http", "example.com", 80)])
1469         self.assertIn(protocol, self.pool._timeouts)
1470
1471         # Advance past 240 seconds, connection will be closed:
1472         self.fakeReactor.advance(1.1)
1473         self.assertEqual(protocol.transport.disconnecting, True)
1474         self.assertNotIn(protocol,
1475                          self.pool._connections[("http", "example.com", 80)])
1476         self.assertNotIn(protocol, self.pool._timeouts)
1477
1478
1479     def test_putExceedsMaxPersistent(self):
1480         """
1481         If an idle connection is put back in the cache and the max number of
1482         persistent connections has been exceeded, one of the connections is
1483         closed and removed from the cache.
1484         """
1485         pool = self.pool
1486
1487         # We start out with two cached connection, the max:
1488         origCached = [StubHTTPProtocol(), StubHTTPProtocol()]
1489         for p in origCached:
1490             p.makeConnection(StringTransport())
1491             pool._putConnection(("http", "example.com", 80), p)
1492         self.assertEqual(pool._connections[("http", "example.com", 80)],
1493                          origCached)
1494         timeouts = pool._timeouts.copy()
1495
1496         # Now we add another one:
1497         newProtocol = StubHTTPProtocol()
1498         newProtocol.makeConnection(StringTransport())
1499         pool._putConnection(("http", "example.com", 80), newProtocol)
1500
1501         # The oldest cached connections will be removed and disconnected:
1502         newCached = pool._connections[("http", "example.com", 80)]
1503         self.assertEqual(len(newCached), 2)
1504         self.assertEqual(newCached, [origCached[1], newProtocol])
1505         self.assertEqual([p.transport.disconnecting for p in newCached],
1506                          [False, False])
1507         self.assertEqual(origCached[0].transport.disconnecting, True)
1508         self.assertTrue(timeouts[origCached[0]].cancelled)
1509         self.assertNotIn(origCached[0], pool._timeouts)
1510
1511
1512     def test_maxPersistentPerHost(self):
1513         """
1514         C{maxPersistentPerHost} is enforced per C{(scheme, host, port)}:
1515         different keys have different max connections.
1516         """
1517         def addProtocol(scheme, host, port):
1518             p = StubHTTPProtocol()
1519             p.makeConnection(StringTransport())
1520             self.pool._putConnection((scheme, host, port), p)
1521             return p
1522         persistent = []
1523         persistent.append(addProtocol("http", "example.com", 80))
1524         persistent.append(addProtocol("http", "example.com", 80))
1525         addProtocol("https", "example.com", 443)
1526         addProtocol("http", "www2.example.com", 80)
1527
1528         self.assertEqual(
1529             self.pool._connections[("http", "example.com", 80)], persistent)
1530         self.assertEqual(
1531             len(self.pool._connections[("https", "example.com", 443)]), 1)
1532         self.assertEqual(
1533             len(self.pool._connections[("http", "www2.example.com", 80)]), 1)
1534
1535
1536     def test_getCachedConnection(self):
1537         """
1538         Getting an address which has a cached connection returns the cached
1539         connection, removes it from the cache and cancels its timeout.
1540         """
1541         # We start out with one cached connection:
1542         protocol = StubHTTPProtocol()
1543         protocol.makeConnection(StringTransport())
1544         self.pool._putConnection(("http", "example.com", 80), protocol)
1545
1546         def gotConnection(conn):
1547             # We got the cached connection:
1548             self.assertIdentical(protocol, conn)
1549             self.assertNotIn(
1550                 conn, self.pool._connections[("http", "example.com", 80)])
1551             # And the timeout was cancelled:
1552             self.fakeReactor.advance(241)
1553             self.assertEqual(conn.transport.disconnecting, False)
1554             self.assertNotIn(conn, self.pool._timeouts)
1555
1556         return self.pool.getConnection(("http", "example.com", 80),
1557                                        BadEndpoint(),
1558                                        ).addCallback(gotConnection)
1559
1560
1561     def test_newConnection(self):
1562         """
1563         The pool's C{_newConnection} method constructs a new connection.
1564         """
1565         # We start out with one cached connection:
1566         protocol = StubHTTPProtocol()
1567         protocol.makeConnection(StringTransport())
1568         key = 12245
1569         self.pool._putConnection(key, protocol)
1570
1571         def gotConnection(newConnection):
1572             # We got a new connection:
1573             self.assertNotIdentical(protocol, newConnection)
1574             # And the old connection is still there:
1575             self.assertIn(protocol, self.pool._connections[key])
1576             # While the new connection is not:
1577             self.assertNotIn(newConnection, self.pool._connections.values())
1578
1579         d = self.pool._newConnection(key, DummyEndpoint())
1580         return d.addCallback(gotConnection)
1581
1582
1583     def test_getSkipsDisconnected(self):
1584         """
1585         When getting connections out of the cache, disconnected connections
1586         are removed and not returned.
1587         """
1588         pool = self.pool
1589         key = ("http", "example.com", 80)
1590
1591         # We start out with two cached connection, the max:
1592         origCached = [StubHTTPProtocol(), StubHTTPProtocol()]
1593         for p in origCached:
1594             p.makeConnection(StringTransport())
1595             pool._putConnection(key, p)
1596         self.assertEqual(pool._connections[key], origCached)
1597
1598         # We close the first one:
1599         origCached[0].state = "DISCONNECTED"
1600
1601         # Now, when we retrive connections we should get the *second* one:
1602         result = []
1603         self.pool.getConnection(key,
1604                                 BadEndpoint()).addCallback(result.append)
1605         self.assertIdentical(result[0], origCached[1])
1606
1607         # And both the disconnected and removed connections should be out of
1608         # the cache:
1609         self.assertEqual(pool._connections[key], [])
1610         self.assertEqual(pool._timeouts, {})
1611
1612
1613     def test_putNotQuiescent(self):
1614         """
1615         If a non-quiescent connection is put back in the cache, an error is
1616         logged.
1617         """
1618         protocol = StubHTTPProtocol()
1619         # By default state is QUIESCENT
1620         self.assertEqual(protocol.state, "QUIESCENT")
1621
1622         protocol.state = "NOTQUIESCENT"
1623         self.pool._putConnection(("http", "example.com", 80), protocol)
1624         error, = self.flushLoggedErrors(RuntimeError)
1625         self.assertEqual(
1626             error.value.args[0],
1627             "BUG: Non-quiescent protocol added to connection pool.")
1628         self.assertIdentical(None, self.pool._connections.get(
1629                 ("http", "example.com", 80)))
1630
1631
1632     def test_getUsesQuiescentCallback(self):
1633         """
1634         When L{HTTPConnectionPool.getConnection} connects, it returns a
1635         C{Deferred} that fires with an instance of L{HTTP11ClientProtocol}
1636         that has the correct quiescent callback attached. When this callback
1637         is called the protocol is returned to the cache correctly, using the
1638         right key.
1639         """
1640         class StringEndpoint(object):
1641             def connect(self, factory):
1642                 p = factory.buildProtocol(None)
1643                 p.makeConnection(StringTransport())
1644                 return succeed(p)
1645
1646         pool = HTTPConnectionPool(self.fakeReactor, True)
1647         pool.retryAutomatically = False
1648         result = []
1649         key = "a key"
1650         pool.getConnection(
1651             key, StringEndpoint()).addCallback(
1652             result.append)
1653         protocol = result[0]
1654         self.assertIsInstance(protocol, HTTP11ClientProtocol)
1655
1656         # Now that we have protocol instance, lets try to put it back in the
1657         # pool:
1658         protocol._state = "QUIESCENT"
1659         protocol._quiescentCallback(protocol)
1660
1661         # If we try to retrive a connection to same destination again, we
1662         # should get the same protocol, because it should've been added back
1663         # to the pool:
1664         result2 = []
1665         pool.getConnection(
1666             key, StringEndpoint()).addCallback(
1667             result2.append)
1668         self.assertIdentical(result2[0], protocol)
1669
1670
1671     def test_closeCachedConnections(self):
1672         """
1673         L{HTTPConnectionPool.closeCachedConnections} closes all cached
1674         connections and removes them from the cache. It returns a Deferred
1675         that fires when they have all lost their connections.
1676         """
1677         persistent = []
1678         def addProtocol(scheme, host, port):
1679             p = HTTP11ClientProtocol()
1680             p.makeConnection(StringTransport())
1681             self.pool._putConnection((scheme, host, port), p)
1682             persistent.append(p)
1683         addProtocol("http", "example.com", 80)
1684         addProtocol("http", "www2.example.com", 80)
1685         doneDeferred = self.pool.closeCachedConnections()
1686
1687         # Connections have begun disconnecting:
1688         for p in persistent:
1689             self.assertEqual(p.transport.disconnecting, True)
1690         self.assertEqual(self.pool._connections, {})
1691         # All timeouts were cancelled and removed:
1692         for dc in self.fakeReactor.getDelayedCalls():
1693             self.assertEqual(dc.cancelled, True)
1694         self.assertEqual(self.pool._timeouts, {})
1695
1696         # Returned Deferred fires when all connections have been closed:
1697         result = []
1698         doneDeferred.addCallback(result.append)
1699         self.assertEqual(result, [])
1700         persistent[0].connectionLost(Failure(ConnectionDone()))
1701         self.assertEqual(result, [])
1702         persistent[1].connectionLost(Failure(ConnectionDone()))
1703         self.assertEqual(result, [None])
1704
1705
1706
class AgentTests(unittest.TestCase, FakeReactorAndConnectMixin):
    """
    Tests for the new HTTP client API provided by L{Agent}.
    """
    def setUp(self):
        """
        Create an L{Agent} wrapped around a fake reactor.
        """
        self.reactor = self.Reactor()
        self.agent = client.Agent(self.reactor)


    def completeConnection(self):
        """
        Do whitebox stuff to finish any outstanding connection attempts the
        agent may have initiated.

        L{Agent} connects through an endpoint obtained from C{_getEndpoint}
        and its connection pool (see C{test_connectUsesConnectionPool}).
        Advancing the fake reactor's clock by zero seconds lets any calls
        that are already due run, so Deferreds waiting on them can fire.
        """
        self.reactor.advance(0)


    def test_defaultPool(self):
        """
        If no pool is passed in, the L{Agent} creates a non-persistent pool.
        """
        agent = client.Agent(self.reactor)
        self.assertIsInstance(agent._pool, HTTPConnectionPool)
        self.assertEqual(agent._pool.persistent, False)
        self.assertIdentical(agent._reactor, agent._pool._reactor)


    def test_persistent(self):
        """
        If C{persistent} is set to C{True} on the L{HTTPConnectionPool} (the
        default), C{Request}s are created with their C{persistent} flag set to
        C{True}.
        """
        pool = HTTPConnectionPool(self.reactor)
        agent = client.Agent(self.reactor, pool=pool)
        agent._getEndpoint = lambda *args: self
        agent.request("GET", "http://127.0.0.1")
        self.assertEqual(self.protocol.requests[0][0].persistent, True)


    def test_nonPersistent(self):
        """
        If C{persistent} is set to C{False} when creating the
        L{HTTPConnectionPool}, C{Request}s are created with their
        C{persistent} flag set to C{False}.

        Elsewhere in the tests for the underlying HTTP code we ensure that
        this will result in the disconnection of the HTTP protocol once the
        request is done, so that the connection will not be returned to the
        pool.
        """
        pool = HTTPConnectionPool(self.reactor, persistent=False)
        agent = client.Agent(self.reactor, pool=pool)
        agent._getEndpoint = lambda *args: self
        agent.request("GET", "http://127.0.0.1")
        self.assertEqual(self.protocol.requests[0][0].persistent, False)


    def test_connectUsesConnectionPool(self):
        """
        When a connection is made by the Agent, it uses its pool's
        C{getConnection} method to do so, with the endpoint returned by
        C{self._getEndpoint}. The key used is C{(scheme, host, port)}.
        """
        endpoint = DummyEndpoint()
        class MyAgent(client.Agent):
            def _getEndpoint(this, scheme, host, port):
                self.assertEqual((scheme, host, port),
                                 ("http", "foo", 80))
                return endpoint

        class DummyPool(object):
            connected = False
            persistent = False
            def getConnection(this, key, ep):
                this.connected = True
                self.assertEqual(ep, endpoint)
                # This is the key the default Agent uses, others will have
                # different keys:
                self.assertEqual(key, ("http", "foo", 80))
                return defer.succeed(StubHTTPProtocol())

        pool = DummyPool()
        agent = MyAgent(self.reactor, pool=pool)
        self.assertIdentical(pool, agent._pool)

        headers = http_headers.Headers()
        headers.addRawHeader("host", "foo")
        bodyProducer = object()
        agent.request('GET', 'http://foo/',
                      bodyProducer=bodyProducer, headers=headers)
        self.assertEqual(agent._pool.connected, True)


    def test_unsupportedScheme(self):
        """
        L{Agent.request} returns a L{Deferred} which fails with
        L{SchemeNotSupported} if the scheme of the URI passed to it is
        neither C{'http'} nor C{'https'}.
        """
        return self.assertFailure(
            self.agent.request('GET', 'mailto:alice@example.com'),
            SchemeNotSupported)


    def test_connectionFailed(self):
        """
        The L{Deferred} returned by L{Agent.request} fires with a L{Failure} if
        the TCP connection attempt fails.
        """
        result = self.agent.request('GET', 'http://foo/')
        # Cause the connection to be refused. Only the factory (the third
        # element of the recorded connectTCP arguments) is needed:
        factory = self.reactor.tcpClients.pop()[2]
        factory.clientConnectionFailed(None, Failure(ConnectionRefusedError()))
        self.completeConnection()
        return self.assertFailure(result, ConnectionRefusedError)


    def test_connectHTTP(self):
        """
        L{Agent._getEndpoint} returns a C{TCP4ClientEndpoint} when passed a
        scheme of C{'http'}.
        """
        expectedHost = 'example.com'
        expectedPort = 1234
        endpoint = self.agent._getEndpoint('http', expectedHost, expectedPort)
        self.assertEqual(endpoint._host, expectedHost)
        self.assertEqual(endpoint._port, expectedPort)
        self.assertIsInstance(endpoint, TCP4ClientEndpoint)


    def test_connectHTTPS(self):
        """
        L{Agent._getEndpoint} returns a C{SSL4ClientEndpoint} when passed a
        scheme of C{'https'}.
        """
        expectedHost = 'example.com'
        expectedPort = 4321
        endpoint = self.agent._getEndpoint('https', expectedHost, expectedPort)
        self.assertIsInstance(endpoint, SSL4ClientEndpoint)
        self.assertEqual(endpoint._host, expectedHost)
        self.assertEqual(endpoint._port, expectedPort)
        self.assertIsInstance(endpoint._sslContextFactory,
                              _WebToNormalContextFactory)
        # Default context factory was used:
        self.assertIsInstance(endpoint._sslContextFactory._webContext,
                              WebClientContextFactory)
    if ssl is None:
        test_connectHTTPS.skip = "OpenSSL not present"


    def test_connectHTTPSCustomContextFactory(self):
        """
        If a context factory is passed to L{Agent.__init__} it will be used to
        determine the SSL parameters for HTTPS requests.  When an HTTPS request
        is made, the hostname and port number of the request URL will be passed
        to the context factory's C{getContext} method.  The resulting context
        object will be used to establish the SSL connection.
        """
        expectedHost = 'example.org'
        expectedPort = 20443
        expectedContext = object()

        contextArgs = []
        class StubWebContextFactory(object):
            def getContext(self, hostname, port):
                contextArgs.append((hostname, port))
                return expectedContext

        agent = client.Agent(self.reactor, StubWebContextFactory())
        endpoint = agent._getEndpoint('https', expectedHost, expectedPort)
        contextFactory = endpoint._sslContextFactory
        context = contextFactory.getContext()
        self.assertEqual(context, expectedContext)
        self.assertEqual(contextArgs, [(expectedHost, expectedPort)])


    def test_hostProvided(self):
        """
        If C{None} is passed to L{Agent.request} for the C{headers} parameter,
        a L{Headers} instance is created for the request and a I{Host} header
        added to it.
        """
        self.agent._getEndpoint = lambda *args: self
        self.agent.request(
            'GET', 'http://example.com/foo?bar')

        req, res = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders('host'), ['example.com'])


    def test_hostOverride(self):
        """
        If the headers passed to L{Agent.request} includes a value for the
        I{Host} header, that value takes precedence over the one which would
        otherwise be automatically provided.
        """
        headers = http_headers.Headers({'foo': ['bar'], 'host': ['quux']})
        self.agent._getEndpoint = lambda *args: self
        self.agent.request(
            'GET', 'http://example.com/foo?bar', headers)

        req, res = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders('host'), ['quux'])


    def test_headersUnmodified(self):
        """
        If a I{Host} header must be added to the request, the L{Headers}
        instance passed to L{Agent.request} is not modified.
        """
        headers = http_headers.Headers()
        self.agent._getEndpoint = lambda *args: self
        self.agent.request(
            'GET', 'http://example.com/foo', headers)

        protocol = self.protocol

        # The request should have been issued.
        self.assertEqual(len(protocol.requests), 1)
        # And the headers object passed in should not have changed.
        self.assertEqual(headers, http_headers.Headers())


    def test_hostValueStandardHTTP(self):
        """
        When passed a scheme of C{'http'} and a port of C{80},
        L{Agent._computeHostValue} returns a string giving just
        the host name passed to it.
        """
        self.assertEqual(
            self.agent._computeHostValue('http', 'example.com', 80),
            'example.com')


    def test_hostValueNonStandardHTTP(self):
        """
        When passed a scheme of C{'http'} and a port other than C{80},
        L{Agent._computeHostValue} returns a string giving the
        host passed to it joined together with the port number by C{":"}.
        """
        self.assertEqual(
            self.agent._computeHostValue('http', 'example.com', 54321),
            'example.com:54321')


    def test_hostValueStandardHTTPS(self):
        """
        When passed a scheme of C{'https'} and a port of C{443},
        L{Agent._computeHostValue} returns a string giving just
        the host name passed to it.
        """
        self.assertEqual(
            self.agent._computeHostValue('https', 'example.com', 443),
            'example.com')


    def test_hostValueNonStandardHTTPS(self):
        """
        When passed a scheme of C{'https'} and a port other than C{443},
        L{Agent._computeHostValue} returns a string giving the
        host passed to it joined together with the port number by C{":"}.
        """
        self.assertEqual(
            self.agent._computeHostValue('https', 'example.com', 54321),
            'example.com:54321')


    def test_request(self):
        """
        L{Agent.request} establishes a new connection to the host indicated by
        the host part of the URI passed to it and issues a request using the
        method, the path portion of the URI, the headers, and the body producer
        passed to it.  It returns a L{Deferred} which fires with an
        L{IResponse} from the server.
        """
        self.agent._getEndpoint = lambda *args: self

        headers = http_headers.Headers({'foo': ['bar']})
        # Just going to check the body for identity, so it doesn't need to be
        # real.
        body = object()
        self.agent.request(
            'GET', 'http://example.com:1234/foo?bar', headers, body)

        protocol = self.protocol

        # The request should be issued.
        self.assertEqual(len(protocol.requests), 1)
        req, res = protocol.requests.pop()
        self.assertIsInstance(req, Request)
        self.assertEqual(req.method, 'GET')
        self.assertEqual(req.uri, '/foo?bar')
        self.assertEqual(
            req.headers,
            http_headers.Headers({'foo': ['bar'],
                                  'host': ['example.com:1234']}))
        self.assertIdentical(req.bodyProducer, body)


    def test_connectTimeout(self):
        """
        L{Agent} takes a C{connectTimeout} argument which is forwarded to the
        following C{connectTCP} call.
        """
        agent = client.Agent(self.reactor, connectTimeout=5)
        agent.request('GET', 'http://foo/')
        # The timeout is the fourth recorded connectTCP argument:
        timeout = self.reactor.tcpClients.pop()[3]
        self.assertEqual(5, timeout)


    def test_connectSSLTimeout(self):
        """
        L{Agent} takes a C{connectTimeout} argument which is forwarded to the
        following C{connectSSL} call.
        """
        agent = client.Agent(self.reactor, connectTimeout=5)
        agent.request('GET', 'https://foo/')
        # connectSSL records an extra context factory argument, so the
        # timeout is the fifth element here:
        timeout = self.reactor.sslClients.pop()[4]
        self.assertEqual(5, timeout)


    def test_bindAddress(self):
        """
        L{Agent} takes a C{bindAddress} argument which is forwarded to the
        following C{connectTCP} call.
        """
        agent = client.Agent(self.reactor, bindAddress='192.168.0.1')
        agent.request('GET', 'http://foo/')
        address = self.reactor.tcpClients.pop()[4]
        self.assertEqual('192.168.0.1', address)


    def test_bindAddressSSL(self):
        """
        L{Agent} takes a C{bindAddress} argument which is forwarded to the
        following C{connectSSL} call.
        """
        agent = client.Agent(self.reactor, bindAddress='192.168.0.1')
        agent.request('GET', 'https://foo/')
        address = self.reactor.sslClients.pop()[5]
        self.assertEqual('192.168.0.1', address)
2055
2056
2057
class HTTPConnectionPoolRetryTests(unittest.TestCase, FakeReactorAndConnectMixin):
    """
    L{client.HTTPConnectionPool}, by using
    L{client._RetryingHTTP11ClientProtocol}, supports retrying requests done
    against previously cached connections.
    """

    def test_onlyRetryIdempotentMethods(self):
        """
        Only GET, HEAD, OPTIONS, TRACE, DELETE methods should cause a retry.
        """
        pool = client.HTTPConnectionPool(None)
        connection = client._RetryingHTTP11ClientProtocol(None, pool)
        self.assertTrue(connection._shouldRetry("GET", RequestNotSent(), None))
        self.assertTrue(connection._shouldRetry("HEAD", RequestNotSent(), None))
        self.assertTrue(connection._shouldRetry(
                "OPTIONS", RequestNotSent(), None))
        self.assertTrue(connection._shouldRetry(
                "TRACE", RequestNotSent(), None))
        self.assertTrue(connection._shouldRetry(
                "DELETE", RequestNotSent(), None))
        self.assertFalse(connection._shouldRetry(
                "POST", RequestNotSent(), None))
        self.assertFalse(connection._shouldRetry(
                "MYMETHOD", RequestNotSent(), None))
        # This will be covered by a different ticket, since we need support
        # for resettable body producers:
        # self.assertTrue(connection._shouldRetry("PUT", RequestNotSent(), None))


    def test_onlyRetryIfNoResponseReceived(self):
        """
        Only L{RequestNotSent}, L{RequestTransmissionFailed} and
        L{ResponseNeverReceived} exceptions should be a cause for retrying.
        """
        pool = client.HTTPConnectionPool(None)
        connection = client._RetryingHTTP11ClientProtocol(None, pool)
        self.assertTrue(connection._shouldRetry("GET", RequestNotSent(), None))
        self.assertTrue(connection._shouldRetry(
                "GET", RequestTransmissionFailed([]), None))
        self.assertTrue(connection._shouldRetry(
                "GET", ResponseNeverReceived([]), None))
        self.assertFalse(connection._shouldRetry(
                "GET", ResponseFailed([]), None))
        self.assertFalse(connection._shouldRetry(
                "GET", ConnectionRefusedError(), None))


    def test_wrappedOnPersistentReturned(self):
        """
        If L{client.HTTPConnectionPool.getConnection} returns a previously
        cached connection, it will get wrapped in a
        L{client._RetryingHTTP11ClientProtocol}.
        """
        pool = client.HTTPConnectionPool(Clock())

        # Add a connection to the cache:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        pool._putConnection(123, protocol)

        # Retrieve it, it should come back wrapped in a
        # _RetryingHTTP11ClientProtocol:
        d = pool.getConnection(123, DummyEndpoint())

        def gotConnection(connection):
            self.assertIsInstance(connection,
                                  client._RetryingHTTP11ClientProtocol)
            self.assertIdentical(connection._clientProtocol, protocol)
        return d.addCallback(gotConnection)


    def test_notWrappedOnNewReturned(self):
        """
        If L{client.HTTPConnectionPool.getConnection} returns a new
        connection, it will be returned as is.
        """
        pool = client.HTTPConnectionPool(None)
        d = pool.getConnection(123, DummyEndpoint())

        def gotConnection(connection):
            # Don't want to use isinstance since potentially the wrapper might
            # subclass it at some point:
            self.assertIdentical(connection.__class__, HTTP11ClientProtocol)
        return d.addCallback(gotConnection)


    def retryAttempt(self, willWeRetry):
        """
        Fail a first request, possibly retrying depending on argument.

        @param willWeRetry: The value the patched C{_shouldRetry} returns,
            i.e. whether the failed request should be retried.

        @return: A two-tuple of the L{Deferred} returned by the retrying
            protocol's C{request} method and the C{list} of
            L{StubHTTPProtocol} instances created so far (the original
            connection first, followed by any created for a retry).
        """
        protocols = []
        def newProtocol():
            protocol = StubHTTPProtocol()
            protocols.append(protocol)
            return defer.succeed(protocol)

        bodyProducer = object()
        request = client.Request("FOO", "/", client.Headers(), bodyProducer,
                                 persistent=True)
        newProtocol()
        protocol = protocols[0]
        # newProtocol doubles as the factory the retrier uses to obtain a
        # fresh connection for a retry:
        retrier = client._RetryingHTTP11ClientProtocol(protocol, newProtocol)

        def _shouldRetry(m, e, bp):
            self.assertEqual(m, "FOO")
            self.assertIdentical(bp, bodyProducer)
            self.assertIsInstance(e, (RequestNotSent, ResponseNeverReceived))
            return willWeRetry
        retrier._shouldRetry = _shouldRetry

        d = retrier.request(request)

        # So far, one request made:
        self.assertEqual(len(protocols), 1)
        self.assertEqual(len(protocols[0].requests), 1)

        # Fail the first request:
        protocol.requests[0][1].errback(RequestNotSent())
        return d, protocols


    def test_retryIfShouldRetryReturnsTrue(self):
        """
        L{client._RetryingHTTP11ClientProtocol} retries when
        L{client._RetryingHTTP11ClientProtocol._shouldRetry} returns C{True}.
        """
        d, protocols = self.retryAttempt(True)
        # We retried!
        self.assertEqual(len(protocols), 2)
        response = object()
        protocols[1].requests[0][1].callback(response)
        return d.addCallback(self.assertIdentical, response)


    def test_dontRetryIfShouldRetryReturnsFalse(self):
        """
        L{client._RetryingHTTP11ClientProtocol} does not retry when
        L{client._RetryingHTTP11ClientProtocol._shouldRetry} returns C{False}.
        """
        d, protocols = self.retryAttempt(False)
        # We did not retry:
        self.assertEqual(len(protocols), 1)
        return self.assertFailure(d, RequestNotSent)


    def test_onlyRetryWithoutBody(self):
        """
        L{_RetryingHTTP11ClientProtocol} only retries queries that don't have
        a body.

        This is an implementation restriction; if the restriction is fixed,
        this test should be removed and PUT added to list of methods that
        support retries.
        """
        pool = client.HTTPConnectionPool(None)
        connection = client._RetryingHTTP11ClientProtocol(None, pool)
        self.assertTrue(connection._shouldRetry("GET", RequestNotSent(), None))
        self.assertFalse(
            connection._shouldRetry("GET", RequestNotSent(), object()))


    def test_onlyRetryOnce(self):
        """
        If a L{client._RetryingHTTP11ClientProtocol} fails more than once on
        an idempotent query before a response is received, it will not retry.
        """
        d, protocols = self.retryAttempt(True)
        self.assertEqual(len(protocols), 2)
        # Fail the second request too:
        protocols[1].requests[0][1].errback(ResponseNeverReceived([]))
        # We didn't retry again:
        self.assertEqual(len(protocols), 2)
        return self.assertFailure(d, ResponseNeverReceived)


    def test_dontRetryIfRetryAutomaticallyFalse(self):
        """
        If L{HTTPConnectionPool.retryAutomatically} is set to C{False}, don't
        wrap connections with retrying logic.
        """
        pool = client.HTTPConnectionPool(Clock())
        pool.retryAutomatically = False

        # Add a connection to the cache:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        pool._putConnection(123, protocol)

        # Retrieve it, it should come back unwrapped:
        d = pool.getConnection(123, DummyEndpoint())

        def gotConnection(connection):
            self.assertIdentical(connection, protocol)
        return d.addCallback(gotConnection)


    def test_retryWithNewConnection(self):
        """
        L{client.HTTPConnectionPool} creates
        L{client._RetryingHTTP11ClientProtocol} with a new connection factory
        method that creates a new connection using the same key and endpoint
        as the wrapped connection.
        """
        pool = client.HTTPConnectionPool(Clock())
        key = 123
        endpoint = DummyEndpoint()
        newConnections = []

        # Override the pool's _newConnection:
        def newConnection(k, e):
            newConnections.append((k, e))
        pool._newConnection = newConnection

        # Add a connection to the cache:
        protocol = StubHTTPProtocol()
        protocol.makeConnection(StringTransport())
        pool._putConnection(key, protocol)

        # Retrieve it, it should come back wrapped in a
        # _RetryingHTTP11ClientProtocol:
        d = pool.getConnection(key, endpoint)

        def gotConnection(connection):
            self.assertIsInstance(connection,
                                  client._RetryingHTTP11ClientProtocol)
            self.assertIdentical(connection._clientProtocol, protocol)
            # Verify that the _newConnection method on retrying connection
            # calls _newConnection on the pool:
            self.assertEqual(newConnections, [])
            connection._newConnection()
            self.assertEqual(len(newConnections), 1)
            self.assertEqual(newConnections[0][0], key)
            self.assertIdentical(newConnections[0][1], endpoint)
        return d.addCallback(gotConnection)
2292
2293
2294
2295
class CookieTestsMixin(object):
    """
    Mixin for unit tests dealing with cookies.
    """
    def addCookies(self, cookieJar, uri, cookies):
        """
        Add cookies to a cookie jar by having the jar extract them from a fake
        C{200 OK} response to a fake request for C{uri}.

        @param cookieJar: A C{cookielib.CookieJar} (or any object with a
            compatible C{extract_cookies} method) to add the cookies to.
        @param uri: The URI of the fake request the cookies are set for.
        @param cookies: A C{list} of I{Set-Cookie} header values.

        @return: A two-tuple of the L{client._FakeUrllib2Request} and the
            L{client._FakeUrllib2Response} used to set the cookies.
        """
        response = client._FakeUrllib2Response(
            client.Response(
                ('HTTP', 1, 1),
                200,
                'OK',
                client.Headers({'Set-Cookie': cookies}),
                None))
        request = client._FakeUrllib2Request(uri)
        cookieJar.extract_cookies(response, request)
        return request, response
2314
2315
2316
class CookieJarTests(unittest.TestCase, CookieTestsMixin):
    """
    Tests for how L{twisted.web.client._FakeUrllib2Response} and
    L{twisted.web.client._FakeUrllib2Request} interact with
    C{cookielib.CookieJar} instances.
    """
    def makeCookieJar(self):
        """
        Build a C{cookielib.CookieJar} holding two sample cookies, C{foo} and
        C{bar}, set by a response for C{'http://example.com:1234/foo?bar'}.

        @return: A two-tuple of the jar and the C{(request, response)} pair
            returned by C{addCookies}.
        """
        jar = cookielib.CookieJar()
        setCookieValues = [
            'foo=1; cow=moo; Path=/foo; Comment=hello',
            'bar=2; Comment=goodbye']
        requestAndResponse = self.addCookies(
            jar, 'http://example.com:1234/foo?bar', setCookieValues)
        return jar, requestAndResponse


    def test_extractCookies(self):
        """
        Extracting cookies from a fake urllib2 response with
        L{cookielib.CookieJar.extract_cookies} records every attribute given
        in the I{Set-Cookie} headers.
        """
        jar, _ = self.makeCookieJar()
        byName = dict((c.name, c) for c in jar)

        fooCookie = byName['foo']
        for attribute, expected in [('version', 0),
                                    ('name', 'foo'),
                                    ('value', '1'),
                                    ('path', '/foo'),
                                    ('comment', 'hello')]:
            self.assertEqual(getattr(fooCookie, attribute), expected)
        self.assertEqual(fooCookie.get_nonstandard_attr('cow'), 'moo')

        barCookie = byName['bar']
        for attribute, expected in [('version', 0),
                                    ('name', 'bar'),
                                    ('value', '2'),
                                    ('path', '/'),
                                    ('comment', 'goodbye')]:
            self.assertEqual(getattr(barCookie, attribute), expected)
        self.assertIdentical(barCookie.get_nonstandard_attr('cow'), None)


    def test_sendCookie(self):
        """
        L{cookielib.CookieJar.add_cookie_header} gives a fake urllib2 request
        a I{Cookie} header it did not previously have.
        """
        jar, (request, _) = self.makeCookieJar()

        # No Cookie header before the jar is consulted:
        self.assertIdentical(request.get_header('Cookie', None), None)

        jar.add_cookie_header(request)
        self.assertEqual(request.get_header('Cookie', None), 'foo=1; bar=2')
2376
2377
2378
class CookieAgentTests(unittest.TestCase, CookieTestsMixin,
                       FakeReactorAndConnectMixin):
    """
    Tests for L{twisted.web.client.CookieAgent}.
    """
    def setUp(self):
        """
        Create the fake reactor used by the agents built in each test.
        """
        self.reactor = self.Reactor()


    def test_emptyCookieJarRequest(self):
        """
        L{CookieAgent.request} does not insert any C{'Cookie'} header into the
        L{Request} object if there is no cookie in the cookie jar for the URI
        being requested. Cookies are extracted from the response and stored in
        the cookie jar.
        """
        cookieJar = cookielib.CookieJar()
        self.assertEqual(list(cookieJar), [])

        agent = self.buildAgentForWrapperTest(self.reactor)
        cookieAgent = client.CookieAgent(agent, cookieJar)
        d = cookieAgent.request(
            'GET', 'http://example.com:1234/foo?bar')

        def _checkCookie(ignored):
            # The Set-Cookie header of the response delivered below must
            # have been stored in the jar.
            cookies = list(cookieJar)
            self.assertEqual(len(cookies), 1)
            self.assertEqual(cookies[0].name, 'foo')
            self.assertEqual(cookies[0].value, '1')

        d.addCallback(_checkCookie)

        req, res = self.protocol.requests.pop()
        self.assertIdentical(req.headers.getRawHeaders('cookie'), None)

        resp = client.Response(
            ('HTTP', 1, 1),
            200,
            'OK',
            client.Headers({'Set-Cookie': ['foo=1',]}),
            None)
        res.callback(resp)

        return d


    def test_requestWithCookie(self):
        """
        L{CookieAgent.request} inserts a C{'Cookie'} header into the L{Request}
        object when there is a cookie matching the request URI in the cookie
        jar.
        """
        uri = 'http://example.com:1234/foo?bar'
        cookie = 'foo=1'

        cookieJar = cookielib.CookieJar()
        self.addCookies(cookieJar, uri, [cookie])
        self.assertEqual(len(list(cookieJar)), 1)

        agent = self.buildAgentForWrapperTest(self.reactor)
        cookieAgent = client.CookieAgent(agent, cookieJar)
        cookieAgent.request('GET', uri)

        req, res = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders('cookie'), [cookie])


    def test_secureCookie(self):
        """
        L{CookieAgent} is able to handle secure cookies, i.e. cookies which
        should only be sent over HTTPS.
        """
        uri = 'https://example.com:1234/foo?bar'
        cookie = 'foo=1;secure'

        cookieJar = cookielib.CookieJar()
        self.addCookies(cookieJar, uri, [cookie])
        self.assertEqual(len(list(cookieJar)), 1)

        agent = self.buildAgentForWrapperTest(self.reactor)
        cookieAgent = client.CookieAgent(agent, cookieJar)
        cookieAgent.request('GET', uri)

        req, res = self.protocol.requests.pop()
        # Only the name/value pair is sent back, not the secure flag.
        self.assertEqual(req.headers.getRawHeaders('cookie'), ['foo=1'])


    def test_secureCookieOnInsecureConnection(self):
        """
        If a cookie is set up as secure, it is not sent with the request if
        the request is not made over HTTPS.
        """
        uri = 'http://example.com/foo?bar'
        cookie = 'foo=1;secure'

        cookieJar = cookielib.CookieJar()
        self.addCookies(cookieJar, uri, [cookie])
        self.assertEqual(len(list(cookieJar)), 1)

        agent = self.buildAgentForWrapperTest(self.reactor)
        cookieAgent = client.CookieAgent(agent, cookieJar)
        cookieAgent.request('GET', uri)

        req, res = self.protocol.requests.pop()
        self.assertIdentical(None, req.headers.getRawHeaders('cookie'))


    def test_portCookie(self):
        """
        L{CookieAgent} supports cookies with a C{port} directive: such a
        cookie is sent with requests made to the port it names.
        """
        uri = 'https://example.com:1234/foo?bar'
        cookie = 'foo=1;port=1234'

        cookieJar = cookielib.CookieJar()
        self.addCookies(cookieJar, uri, [cookie])
        self.assertEqual(len(list(cookieJar)), 1)

        agent = self.buildAgentForWrapperTest(self.reactor)
        cookieAgent = client.CookieAgent(agent, cookieJar)
        cookieAgent.request('GET', uri)

        req, res = self.protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders('cookie'), ['foo=1'])


    def test_portCookieOnWrongPort(self):
        """
        When a cookie with a port directive is received, it is not added to
        the L{cookielib.CookieJar} if the URI is on a different port.
        """
        uri = 'https://example.com:4567/foo?bar'
        cookie = 'foo=1;port=1234'

        cookieJar = cookielib.CookieJar()
        self.addCookies(cookieJar, uri, [cookie])
        self.assertEqual(len(list(cookieJar)), 0)
2517
2518
2519
class Decoder1(proxyForInterface(IResponse)):
    """
    A test decoder to be used by L{client.ContentDecoderAgent} tests.

    It adds nothing to its C{proxyForInterface(IResponse)} base and performs
    no decoding of its own. Tests only use its type to see which registered
    decoder wrapped a response.
    """
2524
2525
2526
class Decoder2(Decoder1):
    """
    A second test decoder to be used by L{client.ContentDecoderAgent} tests.

    It is a subclass of L{Decoder1} with no additions of its own. It exists
    as a distinct type so tests can assert which of two registered decoders
    was applied.
    """
2531
2532
2533
class ContentDecoderAgentTests(unittest.TestCase, FakeReactorAndConnectMixin):
    """
    Tests for L{client.ContentDecoderAgent}.
    """

    def setUp(self):
        """
        Create an L{Agent} wrapped around a fake reactor.
        """
        self.reactor = self.Reactor()
        self.agent = self.buildAgentForWrapperTest(self.reactor)


    def test_acceptHeaders(self):
        """
        L{client.ContentDecoderAgent} sets the I{Accept-Encoding} header to the
        names of the available decoder objects.
        """
        agent = client.ContentDecoderAgent(
            self.agent, [('decoder1', Decoder1), ('decoder2', Decoder2)])

        agent.request('GET', 'http://example.com/foo')

        protocol = self.protocol

        self.assertEqual(len(protocol.requests), 1)
        req, res = protocol.requests.pop()
        self.assertEqual(req.headers.getRawHeaders('accept-encoding'),
                          ['decoder1,decoder2'])


    def test_existingHeaders(self):
        """
        If there are existing I{Accept-Encoding} fields,
        L{client.ContentDecoderAgent} creates a new field for the decoders it
        knows about.
        """
        headers = http_headers.Headers({'foo': ['bar'],
                                        'accept-encoding': ['fizz']})
        agent = client.ContentDecoderAgent(
            self.agent, [('decoder1', Decoder1), ('decoder2', Decoder2)])
        agent.request('GET', 'http://example.com/foo', headers=headers)

        protocol = self.protocol

        self.assertEqual(len(protocol.requests), 1)
        req, res = protocol.requests.pop()
        self.assertEqual(
            list(req.headers.getAllRawHeaders()),
            [('Host', ['example.com']),
             ('Foo', ['bar']),
             ('Accept-Encoding', ['fizz', 'decoder1,decoder2'])])


    def test_plainEncodingResponse(self):
        """
        If the response is not encoded despite the request's
        I{Accept-Encoding} headers, L{client.ContentDecoderAgent} simply
        forwards the response.
        """
        agent = client.ContentDecoderAgent(
            self.agent, [('decoder1', Decoder1), ('decoder2', Decoder2)])
        deferred = agent.request('GET', 'http://example.com/foo')

        req, res = self.protocol.requests.pop()

        response = Response(('HTTP', 1, 1), 200, 'OK', http_headers.Headers(),
                            None)
        res.callback(response)

        return deferred.addCallback(self.assertIdentical, response)


    def test_unsupportedEncoding(self):
        """
        If an encoding unknown to the L{client.ContentDecoderAgent} is found,
        the response is returned unchanged.
        """
        agent = client.ContentDecoderAgent(
            self.agent, [('decoder1', Decoder1), ('decoder2', Decoder2)])
        deferred = agent.request('GET', 'http://example.com/foo')

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({'foo': ['bar'],
                                        'content-encoding': ['fizz']})
        response = Response(('HTTP', 1, 1), 200, 'OK', headers, None)
        res.callback(response)

        return deferred.addCallback(self.assertIdentical, response)


    def test_unknownEncoding(self):
        """
        When L{client.ContentDecoderAgent} encounters an encoding it doesn't
        know about, it stops decoding, even if an encoding listed before the
        unknown one is known.

        Encodings are undone from last to first. For C{decoder1,fizz,decoder2}
        only C{decoder2} is applied, and C{decoder1,fizz} remains in the
        I{Content-Encoding} header of the result.
        """
        agent = client.ContentDecoderAgent(
            self.agent, [('decoder1', Decoder1), ('decoder2', Decoder2)])
        deferred = agent.request('GET', 'http://example.com/foo')

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({'foo': ['bar'],
                                        'content-encoding':
                                        ['decoder1,fizz,decoder2']})
        response = Response(('HTTP', 1, 1), 200, 'OK', headers, None)
        res.callback(response)

        def check(result):
            self.assertNotIdentical(response, result)
            self.assertIsInstance(result, Decoder2)
            self.assertEqual(['decoder1,fizz'],
                              result.headers.getRawHeaders('content-encoding'))

        return deferred.addCallback(check)
2649
2650
2651
class SimpleAgentProtocol(Protocol):
    """
    A minimal L{Protocol} to hand to C{deliverBody}, so tests can observe what
    a response produced by an L{client.Agent} delivers.

    @ivar made: L{Deferred} fired with C{None} from C{connectionMade}.

    @ivar finished: L{Deferred} fired with C{None} from C{connectionLost}. The
        disconnect reason is not passed along.

    @ivar received: C{list} of every chunk passed to C{dataReceived}, in
        arrival order.
    """

    def __init__(self):
        self.made, self.finished = Deferred(), Deferred()
        self.received = []


    def dataReceived(self, data):
        """
        Remember C{data} by appending it to C{received}.
        """
        self.received.append(data)


    def connectionMade(self):
        """
        Fire C{made} to signal that the connection is established.
        """
        self.made.callback(None)


    def connectionLost(self, reason):
        """
        Fire C{finished} to signal that the connection is gone. C{reason} is
        ignored.
        """
        self.finished.callback(None)
2679
2680
2681
class ContentDecoderAgentWithGzipTests(unittest.TestCase,
                                       FakeReactorAndConnectMixin):
    """
    Tests for L{client.ContentDecoderAgent} configured with
    L{client.GzipDecoder} for the C{gzip} content encoding.
    """

    def setUp(self):
        """
        Create an L{Agent} wrapped around a fake reactor.
        """
        self.reactor = self.Reactor()
        agent = self.buildAgentForWrapperTest(self.reactor)
        self.agent = client.ContentDecoderAgent(
            agent, [("gzip", client.GzipDecoder)])


    def test_gzipEncodingResponse(self):
        """
        If the response has a C{gzip} I{Content-Encoding} header,
        L{GzipDecoder} wraps the response to return uncompressed data to the
        user.
        """
        deferred = self.agent.request('GET', 'http://example.com/foo')

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({'foo': ['bar'],
                                        'content-encoding': ['gzip']})
        transport = StringTransport()
        response = Response(('HTTP', 1, 1), 200, 'OK', headers, transport)
        response.length = 12
        res.callback(response)

        # A wbits value of 16 + MAX_WBITS makes zlib emit the gzip format.
        compressor = zlib.compressobj(2, zlib.DEFLATED, 16 + zlib.MAX_WBITS)
        data = (compressor.compress('x' * 6) + compressor.compress('y' * 4) +
                compressor.flush())

        def checkResponse(result):
            self.assertNotIdentical(result, response)
            self.assertEqual(result.version, ('HTTP', 1, 1))
            self.assertEqual(result.code, 200)
            self.assertEqual(result.phrase, 'OK')
            self.assertEqual(list(result.headers.getAllRawHeaders()),
                              [('Foo', ['bar'])])
            # The decoded length cannot be known in advance.
            self.assertEqual(result.length, UNKNOWN_LENGTH)
            self.assertRaises(AttributeError, getattr, result, 'unknown')

            # Split the compressed data across two chunks.
            response._bodyDataReceived(data[:5])
            response._bodyDataReceived(data[5:])
            response._bodyDataFinished()

            protocol = SimpleAgentProtocol()
            result.deliverBody(protocol)

            self.assertEqual(protocol.received, ['x' * 6 + 'y' * 4])
            return defer.gatherResults([protocol.made, protocol.finished])

        deferred.addCallback(checkResponse)

        return deferred


    def test_brokenContent(self):
        """
        If the data received by the L{GzipDecoder} isn't valid gzip-compressed
        data, the call to C{deliverBody} fails with a L{client.ResponseFailed}
        wrapping a C{zlib.error}.
        """
        deferred = self.agent.request('GET', 'http://example.com/foo')

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({'foo': ['bar'],
                                        'content-encoding': ['gzip']})
        transport = StringTransport()
        response = Response(('HTTP', 1, 1), 200, 'OK', headers, transport)
        response.length = 12
        res.callback(response)

        data = "not gzipped content"

        def checkResponse(result):
            response._bodyDataReceived(data)

            result.deliverBody(Protocol())

        deferred.addCallback(checkResponse)
        self.assertFailure(deferred, client.ResponseFailed)

        def checkFailure(error):
            error.reasons[0].trap(zlib.error)
            self.assertIsInstance(error.response, Response)

        return deferred.addCallback(checkFailure)


    def test_flushData(self):
        """
        When the connection with the server is lost, the gzip protocol calls
        C{flush} on the zlib decompressor object to get uncompressed data which
        may have been buffered.
        """
        # Fake decompressor: every decompress call yields 'x' and the final
        # flush yields 'y', so the test can see that flush output is delivered.
        class decompressobj(object):

            def __init__(self, wbits):
                pass

            def decompress(self, data):
                return 'x'

            def flush(self):
                return 'y'


        oldDecompressObj = zlib.decompressobj
        zlib.decompressobj = decompressobj
        self.addCleanup(setattr, zlib, 'decompressobj', oldDecompressObj)

        deferred = self.agent.request('GET', 'http://example.com/foo')

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({'content-encoding': ['gzip']})
        transport = StringTransport()
        response = Response(('HTTP', 1, 1), 200, 'OK', headers, transport)
        res.callback(response)

        def checkResponse(result):
            response._bodyDataReceived('data')
            response._bodyDataFinished()

            protocol = SimpleAgentProtocol()
            result.deliverBody(protocol)

            self.assertEqual(protocol.received, ['x', 'y'])
            return defer.gatherResults([protocol.made, protocol.finished])

        deferred.addCallback(checkResponse)

        return deferred


    def test_flushError(self):
        """
        If the C{flush} call in C{connectionLost} fails, the C{zlib.error}
        exception is caught and turned into a L{ResponseFailed}. Data
        decompressed before the failure has already been delivered.
        """
        # Fake decompressor whose flush always fails.
        class decompressobj(object):

            def __init__(self, wbits):
                pass

            def decompress(self, data):
                return 'x'

            def flush(self):
                raise zlib.error()


        oldDecompressObj = zlib.decompressobj
        zlib.decompressobj = decompressobj
        self.addCleanup(setattr, zlib, 'decompressobj', oldDecompressObj)

        deferred = self.agent.request('GET', 'http://example.com/foo')

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers({'content-encoding': ['gzip']})
        transport = StringTransport()
        response = Response(('HTTP', 1, 1), 200, 'OK', headers, transport)
        res.callback(response)

        protocol = SimpleAgentProtocol()

        def checkResponse(result):
            response._bodyDataReceived('data')
            response._bodyDataFinished()

            # The failing flush makes the decoding protocol raise
            # ResponseFailed out of deliverBody, which fails this Deferred.
            # Nothing after this call would run, so the remaining checks are
            # made in checkFailure.
            result.deliverBody(protocol)

        deferred.addCallback(checkResponse)

        self.assertFailure(deferred, client.ResponseFailed)

        def checkFailure(error):
            error.reasons[1].trap(zlib.error)
            self.assertIsInstance(error.response, Response)
            # Only the output of decompress() reached the application. The
            # failed flush() produced nothing.
            self.assertEqual(protocol.received, ['x'])

        return deferred.addCallback(checkFailure)
2869
2870
2871
class ProxyAgentTests(unittest.TestCase, FakeReactorAndConnectMixin):
    """
    Tests for L{client.ProxyAgent}.
    """

    def setUp(self):
        """
        Build a L{client.ProxyAgent} pointed at a proxy on C{bar:5678} through
        a fake reactor. Its proxy endpoint is replaced by a C{StubEndpoint}
        built around the original endpoint.
        """
        self.reactor = self.Reactor()
        proxyEndpoint = TCP4ClientEndpoint(self.reactor, "bar", 5678)
        self.agent = client.ProxyAgent(proxyEndpoint, self.reactor)
        self.agent._proxyEndpoint = self.StubEndpoint(
            self.agent._proxyEndpoint, self)


    def test_proxyRequest(self):
        """
        L{client.ProxyAgent} issues an HTTP request against the proxy, with the
        full URI as path, when C{request} is called.
        """
        uri = 'http://example.com:1234/foo?bar'
        headers = http_headers.Headers({'foo': ['bar']})
        # The body is only compared by identity, so any object will do.
        body = object()
        self.agent.request('GET', uri, headers, body)

        host, port, factory = self.reactor.tcpClients.pop()[:3]
        self.assertEqual(host, "bar")
        self.assertEqual(port, 5678)
        self.assertIsInstance(factory._wrappedFactory,
                              client._HTTP11ClientFactory)

        # Exactly one request must have been issued.
        pending = self.protocol.requests
        self.assertEqual(len(pending), 1)
        req, res = pending.pop()

        self.assertIsInstance(req, Request)
        self.assertEqual(req.method, 'GET')
        self.assertEqual(req.uri, uri)
        expectedHeaders = http_headers.Headers({'foo': ['bar'],
                                                'host': ['example.com:1234']})
        self.assertEqual(req.headers, expectedHeaders)
        self.assertIdentical(req.bodyProducer, body)


    def test_nonPersistent(self):
        """
        C{ProxyAgent} connections are not persistent by default.
        """
        pool = self.agent._pool
        self.assertEqual(pool.persistent, False)


    def test_connectUsesConnectionPool(self):
        """
        When a connection is made by the C{ProxyAgent}, it uses its pool's
        C{getConnection} method to do so, with the endpoint it was constructed
        with and a key of C{("http-proxy", endpoint)}.
        """
        endpoint = DummyEndpoint()
        testCase = self

        class DummyPool(object):
            connected = False
            persistent = False

            def getConnection(self, key, ep):
                self.connected = True
                testCase.assertIdentical(ep, endpoint)
                # The key identifies the proxy being connected to, not the
                # request's final destination.
                testCase.assertEqual(key, ("http-proxy", endpoint))
                return defer.succeed(StubHTTPProtocol())

        pool = DummyPool()
        agent = client.ProxyAgent(endpoint, self.reactor, pool=pool)
        self.assertIdentical(pool, agent._pool)

        agent.request('GET', 'http://foo/')
        self.assertEqual(agent._pool.connected, True)
2951
2952
2953
class RedirectAgentTests(unittest.TestCase, FakeReactorAndConnectMixin):
    """
    Tests for L{client.RedirectAgent}.
    """

    def setUp(self):
        """
        Create a L{client.RedirectAgent} wrapped around an agent which uses a
        fake reactor.
        """
        self.reactor = self.Reactor()
        self.agent = client.RedirectAgent(
            self.buildAgentForWrapperTest(self.reactor))


    def test_noRedirect(self):
        """
        L{client.RedirectAgent} behaves like L{client.Agent} if the response
        doesn't contain a redirect.
        """
        deferred = self.agent.request('GET', 'http://example.com/foo')

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers()
        response = Response(('HTTP', 1, 1), 200, 'OK', headers, None)
        res.callback(response)

        self.assertEqual(0, len(self.protocol.requests))

        def checkResponse(result):
            self.assertIdentical(result, response)

        return deferred.addCallback(checkResponse)


    def _testRedirectDefault(self, code):
        """
        When getting a redirect, L{RedirectAgent} follows the URL specified in
        the C{Location} header field and makes a new request. The target here
        is an C{https} URL, so the new connection is made over SSL to port 443.

        @param code: The HTTP status code of the redirect response.
        """
        self.agent.request('GET', 'http://example.com/foo')

        host, port = self.reactor.tcpClients.pop()[:2]
        self.assertEqual("example.com", host)
        self.assertEqual(80, port)

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {'location': ['https://example.com/bar']})
        response = Response(('HTTP', 1, 1), code, 'OK', headers, None)
        res.callback(response)

        req2, res2 = self.protocol.requests.pop()
        self.assertEqual('GET', req2.method)
        self.assertEqual('/bar', req2.uri)

        host, port = self.reactor.sslClients.pop()[:2]
        self.assertEqual("example.com", host)
        self.assertEqual(443, port)


    def test_redirect301(self):
        """
        L{RedirectAgent} follows redirects on status code 301.
        """
        self._testRedirectDefault(301)


    def test_redirect302(self):
        """
        L{RedirectAgent} follows redirects on status code 302.
        """
        self._testRedirectDefault(302)


    def test_redirect307(self):
        """
        L{RedirectAgent} follows redirects on status code 307.
        """
        self._testRedirectDefault(307)


    def test_redirect303(self):
        """
        L{RedirectAgent} changes the method to C{GET} when getting a 303
        redirect on a C{POST} request.
        """
        self.agent.request('POST', 'http://example.com/foo')

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {'location': ['http://example.com/bar']})
        response = Response(('HTTP', 1, 1), 303, 'OK', headers, None)
        res.callback(response)

        req2, res2 = self.protocol.requests.pop()
        self.assertEqual('GET', req2.method)
        self.assertEqual('/bar', req2.uri)


    def test_noLocationField(self):
        """
        If no C{Location} header field is found when getting a redirect,
        L{RedirectAgent} fails with a L{ResponseFailed} error wrapping a
        L{error.RedirectWithNoLocation} exception.
        """
        deferred = self.agent.request('GET', 'http://example.com/foo')

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers()
        response = Response(('HTTP', 1, 1), 301, 'OK', headers, None)
        res.callback(response)

        self.assertFailure(deferred, client.ResponseFailed)

        def checkFailure(fail):
            fail.reasons[0].trap(error.RedirectWithNoLocation)
            self.assertEqual('http://example.com/foo',
                             fail.reasons[0].value.uri)
            self.assertEqual(301, fail.response.code)

        return deferred.addCallback(checkFailure)


    def test_307OnPost(self):
        """
        When getting a 307 redirect on a C{POST} request, L{RedirectAgent}
        fails with a L{ResponseFailed} error wrapping a L{error.PageRedirect}
        exception.
        """
        deferred = self.agent.request('POST', 'http://example.com/foo')

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers()
        response = Response(('HTTP', 1, 1), 307, 'OK', headers, None)
        res.callback(response)

        self.assertFailure(deferred, client.ResponseFailed)

        def checkFailure(fail):
            fail.reasons[0].trap(error.PageRedirect)
            self.assertEqual('http://example.com/foo',
                             fail.reasons[0].value.location)
            self.assertEqual(307, fail.response.code)

        return deferred.addCallback(checkFailure)


    def test_redirectLimit(self):
        """
        If the limit of redirects specified to L{RedirectAgent} is reached, the
        deferred fires with L{ResponseFailed} error wrapping a
        L{InfiniteRedirection} exception.
        """
        agent = self.buildAgentForWrapperTest(self.reactor)
        # Allow a single redirect; the second one exceeds the limit.
        redirectAgent = client.RedirectAgent(agent, 1)

        deferred = redirectAgent.request('GET', 'http://example.com/foo')

        req, res = self.protocol.requests.pop()

        headers = http_headers.Headers(
            {'location': ['http://example.com/bar']})
        response = Response(('HTTP', 1, 1), 302, 'OK', headers, None)
        res.callback(response)

        req2, res2 = self.protocol.requests.pop()

        response2 = Response(('HTTP', 1, 1), 302, 'OK', headers, None)
        res2.callback(response2)

        self.assertFailure(deferred, client.ResponseFailed)

        def checkFailure(fail):
            fail.reasons[0].trap(error.InfiniteRedirection)
            self.assertEqual('http://example.com/foo',
                             fail.reasons[0].value.location)
            self.assertEqual(302, fail.response.code)

        return deferred.addCallback(checkFailure)
3135
3136
3137
# Skip the SSL-dependent test cases when SSL cannot be used. The OpenSSL
# check takes precedence, so its more specific reason is reported instead of
# being overwritten by the generic reactor message.
if ssl is None or not hasattr(ssl, 'DefaultOpenSSLContextFactory'):
    _sslSkipReason = "OpenSSL not present"
elif not interfaces.IReactorSSL(reactor, None):
    _sslSkipReason = "Reactor doesn't support SSL"
else:
    _sslSkipReason = None

if _sslSkipReason is not None:
    for case in [WebClientSSLTestCase, WebClientRedirectBetweenSSLandPlainText]:
        case.skip = _sslSkipReason