Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / internet / test / test_tcp.py
1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
3
4 """
5 Tests for implementations of L{IReactorTCP} and the TCP parts of
6 L{IReactorSocket}.
7 """
8
# Make every class in this module that is declared without explicit bases a
# new-style class (this module-level hook only has an effect on Python 2).
__metaclass__ = type
10
11 import socket, errno
12
13 from zope.interface import implements
14
15 from twisted.python.runtime import platform
16 from twisted.python.failure import Failure
17 from twisted.python import log
18
19 from twisted.trial.unittest import SkipTest, TestCase
20 from twisted.internet.test.reactormixins import ReactorBuilder, EndpointCreator
21 from twisted.internet.test.reactormixins import ConnectableProtocol
22 from twisted.internet.test.reactormixins import runProtocolsWithReactor
23 from twisted.internet.error import ConnectionLost, UserError, ConnectionRefusedError
24 from twisted.internet.error import ConnectionDone, ConnectionAborted
25 from twisted.internet.interfaces import (
26     ILoggingContext, IConnector, IReactorFDSet, IReactorSocket)
27 from twisted.internet.address import IPv4Address, IPv6Address
28 from twisted.internet.defer import (
29     Deferred, DeferredList, maybeDeferred, gatherResults)
30 from twisted.internet.endpoints import (
31     TCP4ServerEndpoint, TCP4ClientEndpoint)
32 from twisted.internet.protocol import ServerFactory, ClientFactory, Protocol
33 from twisted.internet.interfaces import (
34     IPushProducer, IPullProducer, IHalfCloseableProtocol)
35 from twisted.internet.tcp import Connection, Server, _resolveIPv6
36
37 from twisted.internet.test.connectionmixins import (
38     LogObserverMixin, ConnectionTestsMixin, TCPClientTestsMixin, findFreePort)
39 from twisted.internet.test.test_core import ObjectModelIntegrationMixin
40 from twisted.test.test_tcp import MyClientFactory, MyServerFactory
41 from twisted.test.test_tcp import ClosingFactory, ClientStartStopFactory
42
43 try:
44     from OpenSSL import SSL
45 except ImportError:
46     useSSL = False
47 else:
48     from twisted.internet.ssl import ClientContextFactory
49     useSSL = True
50
# ipv6Skip is None when this host can create an IPv6 TCP socket; otherwise it
# is a string explaining why not, suitable for use as a test skip reason.
try:
    socket.socket(socket.AF_INET6, socket.SOCK_STREAM).close()
except (socket.error, AttributeError) as e:
    # socket.error: the platform refused to create the socket.
    # AttributeError: this socket module does not define AF_INET6 at all,
    # which should also skip the IPv6 tests rather than break the import.
    ipv6Skip = str(e)
else:
    ipv6Skip = None
57
58
59
# Choose the platform-specific function that lists this host's link local
# IPv6 addresses.  When the POSIX helper module is unavailable, fall back to
# reporting that there are none.
if not platform.isWindows():
    try:
        from twisted.internet.test import _posixifaces
    except ImportError:
        def getLinkLocalIPv6Addresses():
            return []
    else:
        getLinkLocalIPv6Addresses = _posixifaces.posixGetLinkLocalIPv6Addresses
else:
    from twisted.internet.test import _win32ifaces
    getLinkLocalIPv6Addresses = _win32ifaces.win32GetLinkLocalIPv6Addresses
70
71
72
def getLinkLocalIPv6Address():
    """
    Find and return a configured link local IPv6 address including a scope
    identifier using the C{%} separation syntax (for example, an address of
    the form C{"fe80::...%<interface>"}).  If the system has no link local
    IPv6 addresses, raise L{SkipTest} instead.

    Candidate addresses come from L{getLinkLocalIPv6Addresses}.  That is
    backed by C{twisted.internet.test._win32ifaces} on Windows and by
    C{twisted.internet.test._posixifaces} elsewhere.  If the POSIX helper
    cannot be imported, it reports no addresses.

    @raise SkipTest: if no link local address can be found, including when
        the platform helper module is unavailable.

    @return: a C{str} giving the first address reported.
    """
    addresses = getLinkLocalIPv6Addresses()
    if addresses:
        return addresses[0]
    raise SkipTest("Link local IPv6 address unavailable")
88
89
90
def connect(client, address):
    """
    Connect C{client} to C{address}, resolving IPv6 literals first.

    A host containing C{':'} (an IPv6 literal) or C{'%'} (a scoped link
    local address) is passed through L{socket.getaddrinfo}.  The full socket
    address it returns carries the flow info and scope ID fields that an
    C{AF_INET6} socket needs to reach a scoped address.  Any other host is
    used as-is.

    @param client: a L{socket.socket} (or an object with a compatible
        C{connect} method) to connect.

    @param address: a C{(host, port)} pair.
    """
    # Unpack explicitly: tuple parameters in a def are Python 2 only.
    host, port = address
    if '%' in host or ':' in host:
        sockaddr = socket.getaddrinfo(host, port)[0][4]
    else:
        sockaddr = (host, port)
    client.connect(sockaddr)
97
98
99
class FakeSocket(object):
    """
    A stand-in for L{socket.socket} objects which never touches the network.

    @ivar data: The C{str} handed back, in full, by every call to
        L{FakeSocket.recv}.

    @ivar sendBuffer: A C{list} holding, in call order, every object given to
        L{FakeSocket.send}.
    """
    def __init__(self, data):
        self.sendBuffer = []
        self.data = data


    def setblocking(self, blocking):
        """
        Remember the requested mode on the C{blocking} attribute.
        """
        self.blocking = blocking


    def recv(self, size):
        """
        Hand back C{self.data}; C{size} is ignored.
        """
        return self.data


    def send(self, bytes):
        """
        Pretend to transmit C{bytes} by appending it to C{self.sendBuffer}.

        @return: C{len(bytes)}, signalling that nothing was left unsent.
        """
        accepted = len(bytes)
        self.sendBuffer.append(bytes)
        return accepted


    def shutdown(self, how):
        """
        Accept and ignore a shutdown request.  Real sockets offer this method
        and some callers invoke it, but it has no effect on a L{FakeSocket}.
        """
        return None


    def close(self):
        """
        Accept and ignore a close request.  Real sockets offer this method
        and some callers invoke it, but it has no effect on a L{FakeSocket}.
        """
        return None


    def setsockopt(self, *args):
        """
        Accept and ignore any socket option.  Real sockets offer this method
        and some callers invoke it, but it has no effect on a L{FakeSocket}.
        """
        return None


    def fileno(self):
        """
        Report a made-up descriptor number.  It is unrelated to this
        L{FakeSocket}, so actually using it may have surprising results.
        """
        fakeDescriptor = 1
        return fakeDescriptor
161
162
163
class TestFakeSocket(TestCase):
    """
    Sanity checks for L{FakeSocket}, which stands in for a real socket when
    exercising the C{doRead} method of L{Connection}.
    """

    def test_blocking(self):
        """
        L{FakeSocket.setblocking} stores the requested mode on the
        C{blocking} attribute.
        """
        sock = FakeSocket("someData")
        sock.setblocking(0)
        self.assertEqual(sock.blocking, 0)


    def test_recv(self):
        """
        L{FakeSocket.recv} returns the data the socket was created with.
        """
        sock = FakeSocket("someData")
        received = sock.recv(10)
        self.assertEqual(received, "someData")


    def test_send(self):
        """
        L{FakeSocket.send} takes the whole string it is given, records it in
        the send buffer, and reports its length as the number of bytes sent.
        """
        sock = FakeSocket("")
        written = sock.send("foo")
        self.assertEqual(written, 3)
        self.assertEqual(sock.sendBuffer, ["foo"])
189
190
191
class FakeProtocol(Protocol):
    """
    A protocol whose C{dataReceived} hands back a non-C{None} value.  It is
    used to provoke the deprecation warning L{Connection} emits for that
    behavior.
    """
    def dataReceived(self, data):
        """
        Ignore C{data} and answer with an empty tuple, which is deliberately
        not C{None}.
        """
        notNone = ()
        return notNone
202
203
204
class _FakeFDSetReactor(object):
    """
    A stand-in for an L{IReactorFDSet} provider.  Its reader/writer
    registration methods accept a descriptor and do nothing with it.  Only
    those four methods are defined.
    """
    implements(IReactorFDSet)

    def addReader(self, desc):
        return None

    # The remaining registration methods share the same no-op behaviour.
    addWriter = removeReader = removeWriter = addReader
214
215
216
class TCPServerTests(TestCase):
    """
    Whitebox tests for L{twisted.internet.tcp.Server}, run against a
    L{FakeSocket} so that nothing touches the network.
    """
    def setUp(self):
        """
        Build a L{Server} around a fresh L{FakeSocket}, a plain L{Protocol}
        and a reactor which ignores descriptor registration.
        """
        class FakePort(object):
            _realPortNumber = 3

        self.reactor = _FakeFDSetReactor()
        self.skt = FakeSocket("")
        self.protocol = Protocol()
        self.server = Server(
            self.skt, self.protocol, ("", 0), FakePort(), None, self.reactor)


    def _simulateConnectionLost(self):
        """
        Tell the server that its connection is gone, passing a L{Failure}
        that wraps a generic exception.
        """
        reason = Failure(Exception("Simulated lost connection"))
        self.server.connectionLost(reason)


    def test_writeAfterDisconnect(self):
        """
        Once its connection is lost, L{Server.write} throws away whatever
        bytes it is given instead of sending them.
        """
        self._simulateConnectionLost()
        self.server.write("hello world")
        self.assertEqual(self.skt.sendBuffer, [])


    # NOTE: the "Afte" spelling in the next two test names is kept so that
    # existing test identifiers stay stable.
    def test_writeAfteDisconnectAfterTLS(self):
        """
        Like L{test_writeAfterDisconnect}, but for a connection which had
        started TLS before it was lost.
        """
        self.server.TLS = True
        self.test_writeAfterDisconnect()


    def test_writeSequenceAfterDisconnect(self):
        """
        Once its connection is lost, L{Server.writeSequence} throws away
        whatever bytes it is given instead of sending them.
        """
        self._simulateConnectionLost()
        self.server.writeSequence(["hello world"])
        self.assertEqual(self.skt.sendBuffer, [])


    def test_writeSequenceAfteDisconnectAfterTLS(self):
        """
        Like L{test_writeSequenceAfterDisconnect}, but for a connection which
        had started TLS before it was lost.
        """
        self.server.TLS = True
        self.test_writeSequenceAfterDisconnect()
269
270
271
class TCPConnectionTests(TestCase):
    """
    Whitebox tests for L{twisted.internet.tcp.Connection}.
    """
    def test_doReadWarningIsRaised(self):
        """
        When an L{IProtocol} implementation returns a value from its
        C{dataReceived} method, a deprecation warning is emitted.
        """
        skt = FakeSocket("someData")
        protocol = FakeProtocol()
        conn = Connection(skt, protocol)
        conn.doRead()
        warnings = self.flushWarnings([FakeProtocol.dataReceived])
        # Check the count first, so that a missing warning shows up as a
        # clear assertion failure rather than an IndexError below.
        self.assertEqual(len(warnings), 1)
        self.assertEqual(warnings[0]['category'], DeprecationWarning)
        self.assertEqual(
            warnings[0]["message"],
            "Returning a value other than None from "
            "twisted.internet.test.test_tcp.FakeProtocol.dataReceived "
            "is deprecated since Twisted 11.0.0.")


    def test_noTLSBeforeStartTLS(self):
        """
        The C{TLS} attribute of a L{Connection} instance is C{False} before
        L{Connection.startTLS} is called.
        """
        skt = FakeSocket("")
        protocol = FakeProtocol()
        conn = Connection(skt, protocol)
        self.assertFalse(conn.TLS)


    def test_tlsAfterStartTLS(self):
        """
        The C{TLS} attribute of a L{Connection} instance is C{True} after
        L{Connection.startTLS} is called.
        """
        skt = FakeSocket("")
        protocol = FakeProtocol()
        conn = Connection(skt, protocol, reactor=_FakeFDSetReactor())
        conn._tlsClientDefault = True
        conn.startTLS(ClientContextFactory(), True)
        self.assertTrue(conn.TLS)
    if not useSSL:
        # ClientContextFactory is only imported when pyOpenSSL is available.
        test_tlsAfterStartTLS.skip = "No SSL support available"
319
320
321
class TCPCreator(EndpointCreator):
    """
    Build IPv4 TCP endpoints on the loopback address for tests driven by
    L{runProtocolsWithReactor}.
    """

    interface = "127.0.0.1"

    def server(self, reactor):
        """
        Return a L{TCP4ServerEndpoint} for port C{0} on C{self.interface},
        letting the operating system pick the actual port number.
        """
        anyPort = 0
        return TCP4ServerEndpoint(reactor, anyPort, interface=self.interface)


    def client(self, reactor, serverAddress):
        """
        Return a L{TCP4ClientEndpoint} aimed at the port of C{serverAddress}
        on C{self.interface}.

        @type serverAddress: L{IPv4Address}
        """
        portNumber = serverAddress.port
        return TCP4ClientEndpoint(reactor, self.interface, portNumber)
343
344
345
class TCP6Creator(TCPCreator):
    """
    Build IPv6 TCP endpoints for
    C{ReactorBuilder.runProtocolsWithReactor}-based tests.

    The endpoints are still the TCP4 classes.  Those simply pass IPv6 address
    literals through to the reactor.  Only address literals are being tested
    here, not name resolution, which is not implemented yet; see
    http://twistedmatrix.com/trac/ticket/4470 for the planned endpoint
    classes.  The names are a little misleading, but anyone passing an IPv6
    literal presumably knows what they are asking for.
    """
    def __init__(self):
        """
        Use a link local IPv6 address of this host as C{self.interface}.

        @raise SkipTest: if no link local IPv6 address is configured.
        """
        linkLocal = getLinkLocalIPv6Address()
        self.interface = linkLocal
361
362
363
class TCPClientTestsBase(ReactorBuilder, ConnectionTestsMixin,
                         TCPClientTestsMixin):
    """
    Common base for builders whose tests exercise L{IReactorTCP.connectTCP}.

    Subclasses provide an C{endpoints} object; the C{interface} used by the
    tests is read from it.
    """
    port = 1234

    @property
    def interface(self):
        """
        The address taken from C{self.endpoints.interface}.
        """
        endpoints = self.endpoints
        return endpoints.interface
377
378
379
class TCP4ClientTestsBuilder(TCPClientTestsBase):
    """
    Builder for L{IReactorTCP.connectTCP} tests, configured to use IPv4.
    """
    family = socket.AF_INET
    addressClass = IPv4Address
    fakeDomainName = 'some-fake.domain.example.com'

    endpoints = TCPCreator()
389
390
391
class TCP6ClientTestsBuilder(TCPClientTestsBase):
    """
    Builder for L{IReactorTCP.connectTCP} tests, configured to use a link
    local IPv6 address.
    """
    family = socket.AF_INET6
    addressClass = IPv6Address

    if ipv6Skip:
        skip = "Platform does not support ipv6"


    def setUp(self):
        """
        Create the IPv6 endpoint helper.  This is done here rather than at
        class level, so that skipped runs never look up a link local address.
        """
        endpoints = TCP6Creator()
        self.endpoints = endpoints
        # test_addresses uses fakeDomainName to tell the resolved name apart
        # from the name on the socket.  The reactor cannot take an IPv6
        # address back from a resolver, so for IPv6 the "domain name" is just
        # the address literal, which connect uses directly.
        self.fakeDomainName = endpoints.interface
415
416
417
class TCPConnectorTestsBuilder(ReactorBuilder):
    """
    Builder for tests of the L{IConnector} returned by
    L{IReactorTCP.connectTCP}.

    Subclasses must supply an C{interface} attribute giving the address to
    listen on and connect to.  It is not defined here.
    """

    def test_connectorIdentity(self):
        """
        L{IReactorTCP.connectTCP} returns an object which provides
        L{IConnector}.  The destination of the connector is the address which
        was passed to C{connectTCP}.  The same connector object is passed to
        the factory's C{startedConnecting} method as to the factory's
        C{clientConnectionLost} method.
        """
        serverFactory = ClosingFactory()
        reactor = self.buildReactor()
        tcpPort = reactor.listenTCP(0, serverFactory, interface=self.interface)
        # NOTE(review): presumably lets ClosingFactory close the listening
        # port once a connection arrives; confirm in twisted.test.test_tcp.
        serverFactory.port = tcpPort
        portNumber = tcpPort.getHost().port

        seenConnectors = []
        seenFailures = []

        clientFactory = ClientStartStopFactory()
        # Record both the connector and the reason in one lambda; the tuple
        # it builds is only a way to run two appends and is discarded.
        clientFactory.clientConnectionLost = (
            lambda connector, reason: (seenConnectors.append(connector),
                                       seenFailures.append(reason)))
        clientFactory.startedConnecting = seenConnectors.append

        connector = reactor.connectTCP(self.interface, portNumber,
                                       clientFactory)
        self.assertTrue(IConnector.providedBy(connector))
        dest = connector.getDestination()
        self.assertEqual(dest.type, "TCP")
        self.assertEqual(dest.host, self.interface)
        self.assertEqual(dest.port, portNumber)

        clientFactory.whenStopped.addBoth(lambda _: reactor.stop())

        self.runReactor(reactor)

        seenFailures[0].trap(ConnectionDone)
        # One entry from startedConnecting, one from clientConnectionLost.
        self.assertEqual(seenConnectors, [connector, connector])


    def test_userFail(self):
        """
        Calling L{IConnector.stopConnecting} in C{Factory.startedConnecting}
        results in C{Factory.clientConnectionFailed} being called with
        L{UserError} as the reason.
        """
        serverFactory = MyServerFactory()
        reactor = self.buildReactor()
        tcpPort = reactor.listenTCP(0, serverFactory, interface=self.interface)
        portNumber = tcpPort.getHost().port

        # Failures raised by stopConnecting itself; reported after the
        # reactor stops, since raising inside a reactor callback would not
        # fail the test directly.
        fatalErrors = []

        def startedConnecting(connector):
            try:
                connector.stopConnecting()
            except Exception:
                fatalErrors.append(Failure())
                reactor.stop()

        clientFactory = ClientStartStopFactory()
        clientFactory.startedConnecting = startedConnecting

        clientFactory.whenStopped.addBoth(lambda _: reactor.stop())

        reactor.callWhenRunning(lambda: reactor.connectTCP(self.interface,
                                                           portNumber,
                                                           clientFactory))

        self.runReactor(reactor)

        if fatalErrors:
            self.fail(fatalErrors[0].getTraceback())
        clientFactory.reason.trap(UserError)
        self.assertEqual(clientFactory.failed, 1)


    def test_reconnect(self):
        """
        Calling L{IConnector.connect} in C{Factory.clientConnectionLost} causes
        a new connection attempt to be made.
        """
        serverFactory = ClosingFactory()
        reactor = self.buildReactor()
        tcpPort = reactor.listenTCP(0, serverFactory, interface=self.interface)
        serverFactory.port = tcpPort
        portNumber = tcpPort.getHost().port

        clientFactory = MyClientFactory()

        # Reconnect as soon as the first connection is lost.  The second
        # attempt is expected to be refused (asserted below via
        # ConnectionRefusedError), presumably because ClosingFactory has
        # closed the listening port by then.
        def clientConnectionLost(connector, reason):
            connector.connect()
        clientFactory.clientConnectionLost = clientConnectionLost
        reactor.connectTCP(self.interface, portNumber, clientFactory)

        # Snapshot the (made, closed) flags of the first protocol once the
        # reconnect attempt has failed.
        protocolMadeAndClosed = []
        def reconnectFailed(ignored):
            p = clientFactory.protocol
            protocolMadeAndClosed.append((p.made, p.closed))
            reactor.stop()

        clientFactory.failDeferred.addCallback(reconnectFailed)

        self.runReactor(reactor)

        clientFactory.reason.trap(ConnectionRefusedError)
        self.assertEqual(protocolMadeAndClosed, [(1, 1)])
526
527
528
class TCP4ConnectorTestsBuilder(TCPConnectorTestsBuilder):
    """
    Run the L{TCPConnectorTestsBuilder} tests over the IPv4 loopback address.
    """
    addressClass = IPv4Address
    family = socket.AF_INET
    interface = '127.0.0.1'
533
534
535
class TCP6ConnectorTestsBuilder(TCPConnectorTestsBuilder):
    """
    Run the L{TCPConnectorTestsBuilder} tests over a link local IPv6 address.
    """
    addressClass = IPv6Address
    family = socket.AF_INET6

    if ipv6Skip:
        skip = "Platform does not support ipv6"

    def setUp(self):
        """
        Pick the link local IPv6 address the tests listen on and connect to.

        @raise SkipTest: if no such address is configured.
        """
        linkLocal = getLinkLocalIPv6Address()
        self.interface = linkLocal
545
546
547
def createTestSocket(test, addressFamily, socketType):
    """
    Open a socket which is closed automatically when C{test} finishes.

    @param test: the test case; the socket's C{close} is registered with its
        C{addCleanup}.

    @param addressFamily: an C{AF_*} constant.

    @param socketType: a C{SOCK_*} constant.

    @return: the new socket object.
    """
    newSocket = socket.socket(addressFamily, socketType)
    test.addCleanup(newSocket.close)
    return newSocket
563
564
565
class StreamTransportTestsMixin(LogObserverMixin):
    """
    Mixin defining tests which apply to any port/connection based transport.

    The class this is mixed into must supply C{buildReactor}, C{runReactor},
    C{getListeningPort}, C{getExpectedStartListeningLogMessage},
    C{getExpectedConnectionLostLogMsg} and C{assertFullyNewStyle}.
    """
    def test_startedListeningLogMessage(self):
        """
        When a port starts, a message including a description of the associated
        factory is logged.
        """
        loggedMessages = self.observe()
        reactor = self.buildReactor()
        class SomeFactory(ServerFactory):
            implements(ILoggingContext)
            def logPrefix(self):
                return "Crazy Factory"
        factory = SomeFactory()
        p = self.getListeningPort(reactor, factory)
        expectedMessage = self.getExpectedStartListeningLogMessage(
            p, "Crazy Factory")
        self.assertEqual((expectedMessage,), loggedMessages[0]['message'])


    def test_connectionLostLogMsg(self):
        """
        When a connection is lost, an informative message should be logged
        (see L{getExpectedConnectionLostLogMsg}): an address identifying
        the port and the fact that it was closed.
        """
        loggedMessages = []
        def logConnectionLostMsg(eventDict):
            loggedMessages.append(log.textFromEventDict(eventDict))

        reactor = self.buildReactor()
        p = self.getListeningPort(reactor, ServerFactory())
        expectedMessage = self.getExpectedConnectionLostLogMsg(p)

        def stopReactor(ignored):
            log.removeObserver(logConnectionLostMsg)
            reactor.stop()

        def doStopListening():
            # Register the observer exactly once, right before the port is
            # closed, and remove it in stopReactor.  It used to be added a
            # second time earlier in the test as well.  The single removal
            # then left one copy attached to the global log publisher after
            # the test finished.
            log.addObserver(logConnectionLostMsg)
            maybeDeferred(p.stopListening).addCallback(stopReactor)

        reactor.callWhenRunning(doStopListening)
        # Use runReactor, as the other tests in this module do, so a port
        # which never finishes closing fails the test instead of hanging it.
        self.runReactor(reactor)

        self.assertIn(expectedMessage, loggedMessages)


    def test_allNewStyle(self):
        """
        The L{IListeningPort} object is an instance of a class with no
        classic classes in its hierarchy.
        """
        reactor = self.buildReactor()
        port = self.getListeningPort(reactor, ServerFactory())
        self.assertFullyNewStyle(port)
626
627
class ListenTCPMixin(object):
    """
    Mixin handing out listening TCP ports created with
    L{IReactorTCP.listenTCP}.
    """
    def getListeningPort(self, reactor, factory, port=0, interface=''):
        """
        Ask C{reactor} to listen with C{factory} on C{port} of C{interface},
        and return the resulting listening port.
        """
        listen = reactor.listenTCP
        return listen(port, factory, interface=interface)
637
638
639
class SocketTCPMixin(object):
    """
    Mixin which uses L{IReactorSocket.adoptStreamPort} to hand out listening TCP
    ports.
    """
    def getListeningPort(self, reactor, factory, port=0, interface=''):
        """
        Get a TCP port from a reactor, wrapping an already-initialized file
        descriptor.

        A listening, non-blocking socket is created here and its descriptor
        is given to C{reactor.adoptStreamPort}.  The local socket object is
        always closed before returning, because the rest of the test does not
        need it.

        @param reactor: the reactor which should adopt the socket.
        @param factory: the server factory for the adopted port.
        @param port: the port number to bind (C{0} lets the OS choose).
        @param interface: the address to bind.  One containing C{':'} is
            treated as IPv6 and resolved with L{socket.getaddrinfo}, so that
            any scope identifier is kept.

        @raise SkipTest: if C{reactor} does not provide L{IReactorSocket}.

        @return: whatever C{reactor.adoptStreamPort} returns.
        """
        if not IReactorSocket.providedBy(reactor):
            raise SkipTest("Reactor does not provide IReactorSocket")

        if ':' in interface:
            domain = socket.AF_INET6
            address = socket.getaddrinfo(interface, port)[0][4]
        else:
            domain = socket.AF_INET
            address = (interface, port)

        portSock = socket.socket(domain)
        try:
            # bind/listen/setblocking can fail too (for example, the address
            # is in use); the outer finally closes the socket in every case.
            portSock.bind(address)
            portSock.listen(3)
            portSock.setblocking(False)
            try:
                return reactor.adoptStreamPort(
                    portSock.fileno(), portSock.family, factory)
            finally:
                # The socket should still be open; fileno will raise if it is
                # not.
                portSock.fileno()
        finally:
            portSock.close()
673
674
675
676 class TCPPortTestsMixin(object):
677     """
678     Tests for L{IReactorTCP.listenTCP}
679     """
680     def getExpectedStartListeningLogMessage(self, port, factory):
681         """
682         Get the message expected to be logged when a TCP port starts listening.
683         """
684         return "%s starting on %d" % (
685             factory, port.getHost().port)
686
687
688     def getExpectedConnectionLostLogMsg(self, port):
689         """
690         Get the expected connection lost message for a TCP port.
691         """
692         return "(TCP Port %s Closed)" % (port.getHost().port,)
693
694
695     def test_portGetHostOnIPv4(self):
696         """
697         When no interface is passed to L{IReactorTCP.listenTCP}, the returned
698         listening port listens on an IPv4 address.
699         """
700         reactor = self.buildReactor()
701         port = self.getListeningPort(reactor, ServerFactory())
702         address = port.getHost()
703         self.assertIsInstance(address, IPv4Address)
704
705
706     def test_portGetHostOnIPv6(self):
707         """
708         When listening on an IPv6 address, L{IListeningPort.getHost} returns
709         an L{IPv6Address} with C{host} and C{port} attributes reflecting the
710         address the port is bound to.
711         """
712         reactor = self.buildReactor()
713         host, portNumber = findFreePort(
714             family=socket.AF_INET6, interface='::1')[:2]
715         port = self.getListeningPort(
716             reactor, ServerFactory(), portNumber, host)
717         address = port.getHost()
718         self.assertIsInstance(address, IPv6Address)
719         self.assertEqual('::1', address.host)
720         self.assertEqual(portNumber, address.port)
721     if ipv6Skip:
722         test_portGetHostOnIPv6.skip = ipv6Skip
723
724
725     def test_portGetHostOnIPv6ScopeID(self):
726         """
727         When a link-local IPv6 address including a scope identifier is passed as
728         the C{interface} argument to L{IReactorTCP.listenTCP}, the resulting
729         L{IListeningPort} reports its address as an L{IPv6Address} with a host
730         value that includes the scope identifier.
731         """
732         linkLocal = getLinkLocalIPv6Address()
733         reactor = self.buildReactor()
734         port = self.getListeningPort(reactor, ServerFactory(), 0, linkLocal)
735         address = port.getHost()
736         self.assertIsInstance(address, IPv6Address)
737         self.assertEqual(linkLocal, address.host)
738     if ipv6Skip:
739         test_portGetHostOnIPv6ScopeID.skip = ipv6Skip
740
741
742     def _buildProtocolAddressTest(self, client, interface):
743         """
744         Connect C{client} to a server listening on C{interface} started with
745         L{IReactorTCP.listenTCP} and return the address passed to the factory's
746         C{buildProtocol} method.
747
748         @param client: A C{SOCK_STREAM} L{socket.socket} created with an address
749             family such that it will be able to connect to a server listening on
750             C{interface}.
751
752         @param interface: A C{str} giving an address for a server to listen on.
753             This should almost certainly be the loopback address for some
754             address family supported by L{IReactorTCP.listenTCP}.
755
756         @return: Whatever object, probably an L{IAddress} provider, is passed to
757             a server factory's C{buildProtocol} method when C{client}
758             establishes a connection.
759         """
760         class ObserveAddress(ServerFactory):
761             def buildProtocol(self, address):
762                 reactor.stop()
763                 self.observedAddress = address
764                 return Protocol()
765
766         factory = ObserveAddress()
767         reactor = self.buildReactor()
768         port = self.getListeningPort(reactor, factory, 0, interface)
769         client.setblocking(False)
770         try:
771             connect(client, (port.getHost().host, port.getHost().port))
772         except socket.error, (errnum, message):
773             self.assertIn(errnum, (errno.EINPROGRESS, errno.EWOULDBLOCK))
774
775         self.runReactor(reactor)
776
777         return factory.observedAddress
778
779
780     def test_buildProtocolIPv4Address(self):
781         """
782         When a connection is accepted over IPv4, an L{IPv4Address} is passed
783         to the factory's C{buildProtocol} method giving the peer's address.
784         """
785         interface = '127.0.0.1'
786         client = createTestSocket(self, socket.AF_INET, socket.SOCK_STREAM)
787         observedAddress = self._buildProtocolAddressTest(client, interface)
788         self.assertEqual(
789             IPv4Address('TCP', *client.getsockname()), observedAddress)
790
791
792     def test_buildProtocolIPv6Address(self):
793         """
794         When a connection is accepted to an IPv6 address, an L{IPv6Address} is
795         passed to the factory's C{buildProtocol} method giving the peer's
796         address.
797         """
798         interface = '::1'
799         client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM)
800         observedAddress = self._buildProtocolAddressTest(client, interface)
801         self.assertEqual(
802             IPv6Address('TCP', *client.getsockname()[:2]), observedAddress)
803     if ipv6Skip:
804         test_buildProtocolIPv6Address.skip = ipv6Skip
805
806
807     def test_buildProtocolIPv6AddressScopeID(self):
808         """
809         When a connection is accepted to a link-local IPv6 address, an
810         L{IPv6Address} is passed to the factory's C{buildProtocol} method
811         giving the peer's address, including a scope identifier.
812         """
813         interface = getLinkLocalIPv6Address()
814         client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM)
815         observedAddress = self._buildProtocolAddressTest(client, interface)
816         self.assertEqual(
817             IPv6Address('TCP', *client.getsockname()[:2]), observedAddress)
818     if ipv6Skip:
819         test_buildProtocolIPv6AddressScopeID.skip = ipv6Skip
820
821
822     def _serverGetConnectionAddressTest(self, client, interface, which):
823         """
824         Connect C{client} to a server listening on C{interface} started with
825         L{IReactorTCP.listenTCP} and return the address returned by one of the
826         server transport's address lookup methods, C{getHost} or C{getPeer}.
827
828         @param client: A C{SOCK_STREAM} L{socket.socket} created with an address
829             family such that it will be able to connect to a server listening on
830             C{interface}.
831
832         @param interface: A C{str} giving an address for a server to listen on.
833             This should almost certainly be the loopback address for some
834             address family supported by L{IReactorTCP.listenTCP}.
835
836         @param which: A C{str} equal to either C{"getHost"} or C{"getPeer"}
837             determining which address will be returned.
838
839         @return: Whatever object, probably an L{IAddress} provider, is returned
840             from the method indicated by C{which}.
841         """
842         class ObserveAddress(Protocol):
843             def makeConnection(self, transport):
844                 reactor.stop()
845                 self.factory.address = getattr(transport, which)()
846
847         reactor = self.buildReactor()
848         factory = ServerFactory()
849         factory.protocol = ObserveAddress
850         port = self.getListeningPort(reactor, factory, 0, interface)
851         client.setblocking(False)
852         try:
853             connect(client, (port.getHost().host, port.getHost().port))
854         except socket.error, (errnum, message):
855             self.assertIn(errnum, (errno.EINPROGRESS, errno.EWOULDBLOCK))
856         self.runReactor(reactor)
857         return factory.address
858
859
860     def test_serverGetHostOnIPv4(self):
861         """
862         When a connection is accepted over IPv4, the server
863         L{ITransport.getHost} method returns an L{IPv4Address} giving the
864         address on which the server accepted the connection.
865         """
866         interface = '127.0.0.1'
867         client = createTestSocket(self, socket.AF_INET, socket.SOCK_STREAM)
868         hostAddress = self._serverGetConnectionAddressTest(
869             client, interface, 'getHost')
870         self.assertEqual(
871             IPv4Address('TCP', *client.getpeername()), hostAddress)
872
873
874     def test_serverGetHostOnIPv6(self):
875         """
876         When a connection is accepted over IPv6, the server
877         L{ITransport.getHost} method returns an L{IPv6Address} giving the
878         address on which the server accepted the connection.
879         """
880         interface = '::1'
881         client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM)
882         hostAddress = self._serverGetConnectionAddressTest(
883             client, interface, 'getHost')
884         self.assertEqual(
885             IPv6Address('TCP', *client.getpeername()[:2]), hostAddress)
886     if ipv6Skip:
887         test_serverGetHostOnIPv6.skip = ipv6Skip
888
889
890     def test_serverGetHostOnIPv6ScopeID(self):
891         """
892         When a connection is accepted over IPv6, the server
893         L{ITransport.getHost} method returns an L{IPv6Address} giving the
894         address on which the server accepted the connection, including the scope
895         identifier.
896         """
897         interface = getLinkLocalIPv6Address()
898         client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM)
899         hostAddress = self._serverGetConnectionAddressTest(
900             client, interface, 'getHost')
901         self.assertEqual(
902             IPv6Address('TCP', *client.getpeername()[:2]), hostAddress)
903     if ipv6Skip:
904         test_serverGetHostOnIPv6ScopeID.skip = ipv6Skip
905
906
907     def test_serverGetPeerOnIPv4(self):
908         """
909         When a connection is accepted over IPv4, the server
910         L{ITransport.getPeer} method returns an L{IPv4Address} giving the
911         address of the remote end of the connection.
912         """
913         interface = '127.0.0.1'
914         client = createTestSocket(self, socket.AF_INET, socket.SOCK_STREAM)
915         peerAddress = self._serverGetConnectionAddressTest(
916             client, interface, 'getPeer')
917         self.assertEqual(
918             IPv4Address('TCP', *client.getsockname()), peerAddress)
919
920
921     def test_serverGetPeerOnIPv6(self):
922         """
923         When a connection is accepted over IPv6, the server
924         L{ITransport.getPeer} method returns an L{IPv6Address} giving the
925         address on the remote end of the connection.
926         """
927         interface = '::1'
928         client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM)
929         peerAddress = self._serverGetConnectionAddressTest(
930             client, interface, 'getPeer')
931         self.assertEqual(
932             IPv6Address('TCP', *client.getsockname()[:2]), peerAddress)
933     if ipv6Skip:
934         test_serverGetPeerOnIPv6.skip = ipv6Skip
935
936
937     def test_serverGetPeerOnIPv6ScopeID(self):
938         """
939         When a connection is accepted over IPv6, the server
940         L{ITransport.getPeer} method returns an L{IPv6Address} giving the
941         address on the remote end of the connection, including the scope
942         identifier.
943         """
944         interface = getLinkLocalIPv6Address()
945         client = createTestSocket(self, socket.AF_INET6, socket.SOCK_STREAM)
946         peerAddress = self._serverGetConnectionAddressTest(
947             client, interface, 'getPeer')
948         self.assertEqual(
949             IPv6Address('TCP', *client.getsockname()[:2]), peerAddress)
950     if ipv6Skip:
951         test_serverGetPeerOnIPv6ScopeID.skip = ipv6Skip
952
953
954
class TCPPortTestsBuilder(ReactorBuilder,
                          ListenTCPMixin,
                          TCPPortTestsMixin,
                          ObjectModelIntegrationMixin,
                          StreamTransportTestsMixin):
    """
    Reactor-parametrised port and stream-transport tests, combining
    L{TCPPortTestsMixin}, L{ObjectModelIntegrationMixin} and
    L{StreamTransportTestsMixin} with the port setup of L{ListenTCPMixin}.
    """
959
960
961
class TCPFDPortTestsBuilder(ReactorBuilder,
                            SocketTCPMixin,
                            TCPPortTestsMixin,
                            ObjectModelIntegrationMixin,
                            StreamTransportTestsMixin):
    """
    The same port and stream-transport tests as L{TCPPortTestsBuilder}, but
    with the port setup supplied by L{SocketTCPMixin}.
    """
966
967
968
class StopStartReadingProtocol(Protocol):
    """
    Server protocol that toggles its transport's read state (pause, then
    resume) several times, one reactor iteration apart, before reporting
    readiness through C{factory.ready}.  It fires C{factory.stop} with the
    accumulated data once exactly 4 * 4096 bytes have arrived.
    """

    def connectionMade(self):
        self.data = ''
        self.pauseResumeProducing(3)


    def pauseResumeProducing(self, counter):
        """
        Pause and immediately resume the transport.  While C{counter} is
        non-zero, schedule another round with C{counter - 1}; once it reaches
        zero, schedule C{factory.ready} to fire with this protocol.
        """
        self.transport.pauseProducing()
        self.transport.resumeProducing()
        callLater = self.factory.reactor.callLater
        if not counter:
            callLater(0, self.factory.ready.callback, self)
            return
        callLater(0, self.pauseResumeProducing, counter - 1)


    def dataReceived(self, data):
        log.msg('got data', len(data))
        self.data += data
        expectedLength = 4 * 4096
        if len(self.data) == expectedLength:
            self.factory.stop.callback(self.data)
998
999
1000
class TCPConnectionTestsBuilder(ReactorBuilder):
    """
    Builder defining tests relating to L{twisted.internet.tcp.Connection}.
    """

    def test_stopStartReading(self):
        """
        This test verifies transport socket read state after multiple
        pause/resumeProducing calls.
        """
        sf = ServerFactory()
        reactor = sf.reactor = self.buildReactor()

        skippedReactors = ["Glib2Reactor", "Gtk2Reactor"]
        reactorClassName = reactor.__class__.__name__
        if reactorClassName in skippedReactors and platform.isWindows():
            raise SkipTest(
                "This test is broken on gtk/glib under Windows.")

        sf.protocol = StopStartReadingProtocol
        sf.ready = Deferred()
        sf.stop = Deferred()
        p = reactor.listenTCP(0, sf)
        port = p.getHost().port
        def proceed(protos, port):
            """
            Send several IOCPReactor's buffers' worth of data.

            @param protos: The result of the L{DeferredList} built below: a
                list of C{(success, result)} pairs, one for the client
                connection attempt and one for C{sf.ready}.
            @param port: The listening port, handed on to C{cleanup}.
            """
            # Check each pair's success flag.  A non-empty tuple is always
            # true, so asserting on the pair itself would check nothing.
            self.assertTrue(protos[0][0])
            self.assertTrue(protos[1][0])
            protos = protos[0][1], protos[1][1]
            protos[0].transport.write('x' * (2 * 4096) + 'y' * (2 * 4096))
            return (sf.stop.addCallback(cleanup, protos, port)
                           .addCallback(lambda ign: reactor.stop()))

        def cleanup(data, protos, port):
            """
            Make sure IOCPReactor didn't start several WSARecv operations
            that clobbered each other's results.

            @param data: Everything the server protocol received.
            @param protos: The C{(client, server)} protocol pair.
            @param port: The listening port to shut down.
            """
            self.assertEqual(data, 'x'*(2*4096) + 'y'*(2*4096),
                                 'did not get the right data')
            return DeferredList([
                    maybeDeferred(protos[0].transport.loseConnection),
                    maybeDeferred(protos[1].transport.loseConnection),
                    maybeDeferred(port.stopListening)])

        cc = TCP4ClientEndpoint(reactor, '127.0.0.1', port)
        cf = ClientFactory()
        cf.protocol = Protocol
        d = DeferredList([cc.connect(cf), sf.ready]).addCallback(proceed, p)
        self.runReactor(reactor)
        return d


    def test_connectionLostAfterPausedTransport(self):
        """
        Alice connects to Bob.  Alice writes some bytes and then shuts down the
        connection.  Bob receives the bytes from the connection and then pauses
        the transport object.  Shortly afterwards Bob resumes the transport
        object.  At that point, Bob is notified that the connection has been
        closed.

        This is no problem for most reactors.  The underlying event notification
        API will probably just remind them that the connection has been closed.
        It is a little tricky for win32eventreactor (MsgWaitForMultipleObjects).
        MsgWaitForMultipleObjects will only deliver the close notification once.
        The reactor needs to remember that notification until Bob resumes the
        transport.
        """
        class Pauser(ConnectableProtocol):
            def __init__(self):
                self.events = []

            def dataReceived(self, bytes):
                self.events.append("paused")
                self.transport.pauseProducing()
                self.reactor.callLater(0, self.resume)

            def resume(self):
                self.events.append("resumed")
                self.transport.resumeProducing()

            def connectionLost(self, reason):
                # This is the event you have been waiting for.
                self.events.append("lost")
                ConnectableProtocol.connectionLost(self, reason)

        class Client(ConnectableProtocol):
            def connectionMade(self):
                self.transport.write("some bytes for you")
                self.transport.loseConnection()

        pauser = Pauser()
        runProtocolsWithReactor(self, pauser, Client(), TCPCreator())
        self.assertEqual(pauser.events, ["paused", "resumed", "lost"])


    def test_doubleHalfClose(self):
        """
        If one side half-closes its connection, and then the other side of the
        connection calls C{loseWriteConnection}, and then C{loseConnection} in
        C{writeConnectionLost}, the connection is closed correctly.

        This rather obscure case used to fail (see ticket #3037).
        """
        class ListenerProtocol(ConnectableProtocol):
            implements(IHalfCloseableProtocol)

            def readConnectionLost(self):
                self.transport.loseWriteConnection()

            def writeConnectionLost(self):
                self.transport.loseConnection()

        class Client(ConnectableProtocol):
            def connectionMade(self):
                self.transport.loseConnection()

        # If test fails, reactor won't stop and we'll hit timeout:
        runProtocolsWithReactor(
            self, ListenerProtocol(), Client(), TCPCreator())
1123
1124
1125
class WriteSequenceTests(ReactorBuilder):
    """
    Test for L{twisted.internet.abstract.FileDescriptor.writeSequence}.

    @ivar client: the connected client factory to be used in tests.
    @type client: L{MyClientFactory}

    @ivar server: the listening server factory to be used in tests.
    @type server: L{MyServerFactory}
    """
    def setUp(self):
        """
        Create fresh client and server factories, each carrying
        C{protocolConnectionMade} and C{protocolConnectionLost} Deferreds
        for the tests to wait on.
        """
        server = MyServerFactory()
        server.protocolConnectionMade = Deferred()
        server.protocolConnectionLost = Deferred()
        self.server = server

        client = MyClientFactory()
        client.protocolConnectionMade = Deferred()
        client.protocolConnectionLost = Deferred()
        self.client = client


    def setWriteBufferSize(self, transport, value):
        """
        Set the write buffer size for the given transport, managing possible
        differences (ie, IOCP). Bug #4322 should remove the need of that hack.

        @param transport: The transport whose write buffer size to change.
            C{writeBufferSize} is set if the transport has a non-C{None}
            value for it, otherwise C{bufferSize} is set.
        @param value: The new buffer size.
        """
        if getattr(transport, "writeBufferSize", None) is not None:
            transport.writeBufferSize = value
        else:
            transport.bufferSize = value


    def test_withoutWrite(self):
        """
        C{writeSequence} sends the data even if C{write} hasn't been called.
        """
        client, server = self.client, self.server
        reactor = self.buildReactor()

        port = reactor.listenTCP(0, server)

        def dataReceived(data):
            log.msg("data received: %r" % data)
            self.assertEqual(data, "Some sequence splitted")
            client.protocol.transport.loseConnection()

        def clientConnected(proto):
            log.msg("client connected %s" % proto)
            proto.transport.writeSequence(["Some ", "sequence ", "splitted"])

        def serverConnected(proto):
            log.msg("server connected %s" % proto)
            proto.dataReceived = dataReceived

        d1 = client.protocolConnectionMade.addCallback(clientConnected)
        d2 = server.protocolConnectionMade.addCallback(serverConnected)
        d3 = server.protocolConnectionLost
        d4 = client.protocolConnectionLost
        d = gatherResults([d1, d2, d3, d4])
        def stop(result):
            reactor.stop()
            return result
        d.addBoth(stop)

        reactor.connectTCP("127.0.0.1", port.getHost().port, client)
        self.runReactor(reactor)


    def test_writeSequenceWithUnicodeRaisesException(self):
        """
        C{writeSequence} with an element in the sequence of type unicode raises
        C{TypeError}.
        """
        client, server = self.client, self.server
        reactor = self.buildReactor()

        port = reactor.listenTCP(0, server)

        reactor.connectTCP("127.0.0.1", port.getHost().port, client)

        def serverConnected(proto):
            log.msg("server connected %s" % proto)
            exc = self.assertRaises(
                TypeError,
                proto.transport.writeSequence, [u"Unicode is not kosher"])
            self.assertEqual(str(exc), "Data must not be unicode")

        d = server.protocolConnectionMade.addCallback(serverConnected)
        # Log (and so report) any assertion failure, but stop the reactor
        # either way.
        d.addErrback(log.err)
        d.addCallback(lambda ignored: reactor.stop())

        self.runReactor(reactor)


    def _producerTest(self, clientConnected):
        """
        Helper for testing producers which call C{writeSequence}.  This will set
        up a connection which a producer can use.  It returns after the
        connection is closed.

        @param clientConnected: A callback which will be invoked with a client
            protocol after a connection is setup.  This is responsible for
            setting up some sort of producer.
        """
        reactor = self.buildReactor()

        port = reactor.listenTCP(0, self.server)

        # The following could probably all be much simpler, but for #5285.

        # First let the server notice the connection
        d1 = self.server.protocolConnectionMade

        # Grab the client connection Deferred now though, so we don't lose it if
        # the client connects before the server.
        d2 = self.client.protocolConnectionMade

        def serverConnected(proto):
            # Now take action as soon as the client is connected
            d2.addCallback(clientConnected)
            return d2
        d1.addCallback(serverConnected)

        d3 = self.server.protocolConnectionLost
        d4 = self.client.protocolConnectionLost

        # After the client is connected and does its producer stuff, wait for
        # the disconnection events.
        def didProducerActions(ignored):
            return gatherResults([d3, d4])
        d1.addCallback(didProducerActions)

        def stop(result):
            reactor.stop()
            return result
        d1.addBoth(stop)

        reactor.connectTCP("127.0.0.1", port.getHost().port, self.client)

        self.runReactor(reactor)


    def test_streamingProducer(self):
        """
        C{writeSequence} pauses its streaming producer if too much data is
        buffered, and then resumes it.
        """
        client, server = self.client, self.server

        class SaveActionProducer(object):
            implements(IPushProducer)
            def __init__(self):
                self.actions = []

            def pauseProducing(self):
                self.actions.append("pause")

            def resumeProducing(self):
                self.actions.append("resume")
                # Unregister the producer so the connection can close
                client.protocol.transport.unregisterProducer()
                # This is why the code below waits for the server connection
                # first - so we have it to close here.  We close the server side
                # because win32evenreactor cannot reliably observe us closing
                # the client side (#5285).
                server.protocol.transport.loseConnection()

            def stopProducing(self):
                self.actions.append("stop")

        producer = SaveActionProducer()

        def clientConnected(proto):
            # Register a streaming producer and verify that it gets paused after
            # it writes more than the local send buffer can hold.
            proto.transport.registerProducer(producer, True)
            self.assertEqual(producer.actions, [])
            self.setWriteBufferSize(proto.transport, 500)
            proto.transport.writeSequence(["x" * 50] * 20)
            self.assertEqual(producer.actions, ["pause"])

        self._producerTest(clientConnected)
        # After the send buffer gets a chance to empty out a bit, the producer
        # should be resumed.
        self.assertEqual(producer.actions, ["pause", "resume"])


    def test_nonStreamingProducer(self):
        """
        C{writeSequence} pauses its producer if too much data is buffered only
        if this is a streaming producer.
        """
        client = self.client
        test = self

        class SaveActionProducer(object):
            implements(IPullProducer)
            def __init__(self):
                self.actions = []

            def resumeProducing(self):
                self.actions.append("resume")
                if self.actions.count("resume") == 2:
                    client.protocol.transport.stopConsuming()
                else:
                    test.setWriteBufferSize(client.protocol.transport, 500)
                    client.protocol.transport.writeSequence(["x" * 50] * 20)

            def stopProducing(self):
                self.actions.append("stop")

        producer = SaveActionProducer()

        def clientConnected(proto):
            # Register a non-streaming producer and verify that it is resumed
            # immediately.
            proto.transport.registerProducer(producer, False)
            self.assertEqual(producer.actions, ["resume"])

        self._producerTest(clientConnected)
        # After the local send buffer empties out, the producer should be
        # resumed again.
        self.assertEqual(producer.actions, ["resume", "resume"])
1350
1351
# Publish each builder's generated per-reactor TestCase classes in this
# module's namespace so the test runner can find them.
for _builder in (TCP4ClientTestsBuilder, TCP6ClientTestsBuilder,
                 TCPPortTestsBuilder, TCPFDPortTestsBuilder,
                 TCPConnectionTestsBuilder, TCP4ConnectorTestsBuilder,
                 TCP6ConnectorTestsBuilder, WriteSequenceTests):
    globals().update(_builder.makeTestCaseClasses())
# Do not leave the loop variable behind as a module attribute.
del _builder
1360
1361
1362
class ServerAbortsTwice(ConnectableProtocol):
    """
    Server protocol that, whenever data arrives, calls abortConnection() on
    its transport two times in a row.
    """

    def dataReceived(self, data):
        # The repeated call is the behavior being exercised.
        for _ in range(2):
            self.transport.abortConnection()
1371
1372
1373
class ServerAbortsThenLoses(ConnectableProtocol):
    """
    Call abortConnection() followed by loseConnection().

    Whenever data arrives, the server aborts its transport and then, on the
    already-aborted transport, also requests a clean close.
    """

    def dataReceived(self, data):
        # Order matters: abort first, then loseConnection() on the aborted
        # transport.
        self.transport.abortConnection()
        self.transport.loseConnection()
1382
1383
1384
class AbortServerWritingProtocol(ConnectableProtocol):
    """
    Server protocol that sends C{"ready"} as soon as the connection is made.
    """

    def connectionMade(self):
        """
        Signal to the client, by writing C{"ready"}, that the connection is
        established and it may now go ahead and abort.
        """
        self.transport.write("ready")
1395
1396
1397
class ReadAbortServerProtocol(AbortServerWritingProtocol):
    """
    Server that tolerates only C{'X'} bytes.  The client writes those before
    calling abortConnection(), so they may legitimately arrive.  Any other
    byte is treated as an error.
    """

    def dataReceived(self, data):
        # Stripping 'X' from both ends leaves something iff the chunk
        # contains at least one byte that is not 'X'.
        if data.strip('X'):
            raise Exception("Unexpectedly received data.")
1408
1409
1410
class NoReadServer(ConnectableProtocol):
    """
    Server protocol that stops reading the moment the connection is made.

    To the peer this looks like a connection that has gone dead, the kind of
    situation in which a timeout would lead it to call abortConnection().
    """

    def connectionMade(self):
        # Stop reading immediately; nothing is consumed from here on.
        self.transport.stopReading()
1421
1422
1423
class EventualNoReadServer(ConnectableProtocol):
    """
    Variant of L{NoReadServer} that waits until some bytes have been
    delivered before it stops reading.  By then any TLS handshake has
    finished, where applicable.
    """

    # Set on the first dataReceived() call.
    gotData = False
    # Set the first time resumeProducing() stops the transport reading.
    stoppedReading = False


    def dataReceived(self, data):
        if self.gotData:
            return
        self.gotData = True
        # Register as a non-streaming producer; resumeProducing() is what
        # eventually stops the transport reading.
        self.transport.registerProducer(self, False)
        self.transport.write("hello")


    def resumeProducing(self):
        # Only the first call has any effect.
        if not self.stoppedReading:
            self.stoppedReading = True
            self.transport.stopReading()


    def pauseProducing(self):
        """
        No-op; present to satisfy the producer interface.
        """


    def stopProducing(self):
        """
        No-op; present to satisfy the producer interface.
        """
1456
1457
1458
class BaseAbortingClient(ConnectableProtocol):
    """
    Common base for the abort-testing clients.  It complains if
    connectionLost() is delivered while C{inReactorMethod} is set, i.e.
    while one of our own methods that triggers the abort is still running.
    """
    # Subclasses (e.g. AbortingClient) set this around abortConnection().
    inReactorMethod = False

    def connectionLost(self, reason):
        if not self.inReactorMethod:
            ConnectableProtocol.connectionLost(self, reason)
            return
        raise RuntimeError("BUG: connectionLost was called re-entrantly!")
1469
1470
1471
class WritingButNotAbortingClient(BaseAbortingClient):
    """
    Client that sends C{"hello"} when connected and never aborts.
    """

    def connectionMade(self):
        """
        Send a short greeting; no abortConnection() call follows.
        """
        self.transport.write("hello")
1479
1480
1481
class AbortingClient(BaseAbortingClient):
    """
    Call abortConnection() after writing some data.
    """

    def dataReceived(self, data):
        """
        Some data was received, so the connection is set up.

        C{inReactorMethod} is True while C{writeAndAbort} runs, so
        L{BaseAbortingClient.connectionLost} raises if it is delivered from
        inside the abort sequence.
        """
        self.inReactorMethod = True
        self.writeAndAbort()
        self.inReactorMethod = False


    def writeAndAbort(self):
        """
        Write 10000 C{'X'} bytes, abort, then write 10000 C{'Y'} bytes.
        Subclasses extend this sequence.
        """
        # X is written before abortConnection, and so there is a chance it
        # might arrive. Y is written after, and so no Ys should ever be
        # delivered:
        self.transport.write("X" * 10000)
        self.transport.abortConnection()
        self.transport.write("Y" * 10000)
1503
1504
1505
class AbortingTwiceClient(AbortingClient):
    """
    Call abortConnection() twice, after writing some data.
    """

    def writeAndAbort(self):
        """
        Run the base write/abort/write sequence, then call abortConnection()
        a second time.
        """
        AbortingClient.writeAndAbort(self)
        self.transport.abortConnection()
1514
1515
1516
class AbortingThenLosingClient(AbortingClient):
    """
    Call abortConnection() and then loseConnection().
    """

    def writeAndAbort(self):
        """
        Run the base write/abort/write sequence, then also call
        loseConnection() on the already-aborted transport.
        """
        AbortingClient.writeAndAbort(self)
        self.transport.loseConnection()
1525
1526
1527
class ProducerAbortingClient(ConnectableProtocol):
    """
    Call abortConnection from doWrite, via resumeProducing.

    On connection this writes a large amount of data and registers itself as
    a non-streaming producer.  A C{resumeProducing} call made anywhere other
    than inside that C{registerProducer} call aborts the connection.
    """

    # True while resumeProducing() runs; connectionLost() treats a True value
    # as a re-entrant delivery.
    # NOTE(review): BaseAbortingClient starts this at False.  Starting at
    # True means a connectionLost() that follows stopProducing() but comes
    # before any resumeProducing() call is reported as re-entrant. Confirm
    # this is intended.
    inReactorMethod = True
    # Set by stopProducing(); connectionLost() requires it to be True.
    producerStopped = False

    def write(self):
        """
        Write 762000 bytes, then register C{self} as a non-streaming producer.

        C{inRegisterProducer} is True only for the duration of the
        C{registerProducer} call, so C{resumeProducing} can tell whether it
        is being invoked from inside it and skip aborting in that case.
        """
        self.transport.write("lalala" * 127000)
        self.inRegisterProducer = True
        self.transport.registerProducer(self, False)
        self.inRegisterProducer = False


    def connectionMade(self):
        self.write()


    def resumeProducing(self):
        """
        Abort the connection, unless called from within C{registerProducer}.
        """
        self.inReactorMethod = True
        if not self.inRegisterProducer:
            self.transport.abortConnection()
        self.inReactorMethod = False


    def stopProducing(self):
        self.producerStopped = True


    def connectionLost(self, reason):
        """
        Raise if C{stopProducing} was never called, or if this is delivered
        while C{resumeProducing} is still running.
        """
        if not self.producerStopped:
            raise RuntimeError("BUG: stopProducing() was never called.")
        if self.inReactorMethod:
            raise RuntimeError("BUG: connectionLost called re-entrantly!")
        ConnectableProtocol.connectionLost(self, reason)
1564
1565
1566
class StreamingProducerClient(ConnectableProtocol):
    """
    Call abortConnection() when the other side has stopped reading.

    In particular, we want to call abortConnection() only once our local
    socket hits a state where it is no longer writeable. This helps emulate
    the most common use case for abortConnection(), closing a connection after
    a timeout, with write buffers being full.

    Since it's very difficult to know when this actually happens, we just
    write a lot of data, and assume at that point no more writes will happen.
    """
    # True from the first pauseProducing() until the next resumeProducing().
    paused = False
    # NOTE(review): not referenced by any method of this class.
    extraWrites = 0
    # True while doAbort() is calling abortConnection(); this class's own
    # connectionLost() does not check it.
    inReactorMethod = False

    def connectionMade(self):
        self.write()


    def write(self):
        """
        Write large amount to transport, then wait for a while for buffers to
        fill up.

        Registers C{self} as a streaming producer first, then writes 100
        chunks of 320000 bytes each.
        """
        self.transport.registerProducer(self, True)
        for _ in range(100):
            self.transport.write("1234567890" * 32000)


    def resumeProducing(self):
        self.paused = False


    def stopProducing(self):
        pass


    def pauseProducing(self):
        """
        Called when local buffer fills up.

        The goal is to hit the point where the local file descriptor is not
        writeable (or the moral equivalent). The fact that pauseProducing has
        been called is not sufficient, since that can happen when Twisted's
        buffers fill up but OS hasn't gotten any writes yet. We want to be as
        close as possible to every buffer (including OS buffers) being full.

        So, we wait a bit more after this for Twisted to write out a few
        chunks, then abortConnection.
        """
        if self.paused:
            return
        self.paused = True
        # The amount we wait is arbitrary, we just want to make sure some
        # writes have happened and outgoing OS buffers filled up -- see
        # http://twistedmatrix.com/trac/ticket/5303 for details:
        self.reactor.callLater(0.01, self.doAbort)


    def doAbort(self):
        """
        Abort the connection, logging an error if the producer was resumed
        in the meantime.  The abort happens either way.
        """
        if not self.paused:
            log.err(RuntimeError("BUG: We should be paused at this point."))
        self.inReactorMethod = True
        self.transport.abortConnection()
        self.inReactorMethod = False


    def connectionLost(self, reason):
        # Tell server to start reading again so it knows to go away:
        self.otherProtocol.transport.startReading()
        ConnectableProtocol.connectionLost(self, reason)
1639
1640
1641
class StreamingProducerClientLater(StreamingProducerClient):
    """
    Variant of L{StreamingProducerClient} that sends a short greeting on
    connection and starts its large write, and hence the eventual
    abortConnection(), only once the first bytes from the peer arrive.
    """

    def connectionMade(self):
        self.transport.write("hello")
        self.gotData = False


    def dataReceived(self, data):
        # React only to the first delivery.
        if self.gotData:
            return
        self.gotData = True
        self.write()
1657
1658
class ProducerAbortingClientLater(ProducerAbortingClient):
    """
    Like L{ProducerAbortingClient}, but the write and producer registration
    are postponed from connectionMade until data has arrived.  Bytes have
    then already been exchanged, so an SSL handshake is not interrupted.
    """

    def connectionMade(self):
        """
        Deliberately do nothing; the base class would call write() here.
        """


    def dataReceived(self, data):
        # Every chunk of incoming data triggers another write().
        self.write()
1674
1675
1676
class DataReceivedRaisingClient(AbortingClient):
    """
    Call abortConnection(), and then throw exception, from dataReceived.

    This overrides AbortingClient.dataReceived without setting
    C{inReactorMethod}, so the re-entrancy check inherited from
    L{BaseAbortingClient} never fires for this class.
    """

    def dataReceived(self, data):
        # Abort first, then raise from the same dataReceived call.
        self.transport.abortConnection()
        raise ZeroDivisionError("ONO")
1685
1686
1687
class ResumeThrowsClient(ProducerAbortingClient):
    """
    Producer client whose resumeProducing() aborts the connection and then
    raises, except while registerProducer() is still in progress.
    """

    def resumeProducing(self):
        if self.inRegisterProducer:
            return
        self.transport.abortConnection()
        raise ZeroDivisionError("ono!")


    def connectionLost(self, reason):
        # Bypass ProducerAbortingClient's stopProducing() check: once
        # resumeProducing() has blown up, consumers are justified in giving
        # up on the producer without ever calling stopProducing().
        ConnectableProtocol.connectionLost(self, reason)
1704
1705
1706
1707 class AbortConnectionMixin(object):
1708     """
1709     Unit tests for L{ITransport.abortConnection}.
1710     """
1711     # Override in subclasses, should be a EndpointCreator instance:
1712     endpoints = None
1713
1714     def runAbortTest(self, clientClass, serverClass,
1715                      clientConnectionLostReason=None):
1716         """
1717         A test runner utility function, which hooks up a matched pair of client
1718         and server protocols.
1719
1720         We then run the reactor until both sides have disconnected, and then
1721         verify that the right exception resulted.
1722         """
1723         clientExpectedExceptions = (ConnectionAborted, ConnectionLost)
1724         serverExpectedExceptions = (ConnectionLost, ConnectionDone)
1725         # In TLS tests we may get SSL.Error instead of ConnectionLost,
1726         # since we're trashing the TLS protocol layer.
1727         if useSSL:
1728             clientExpectedExceptions = clientExpectedExceptions + (SSL.Error,)
1729             serverExpectedExceptions = serverExpectedExceptions + (SSL.Error,)
1730
1731         client = clientClass()
1732         server = serverClass()
1733         client.otherProtocol = server
1734         server.otherProtocol = client
1735         reactor = runProtocolsWithReactor(self, server, client, self.endpoints)
1736
1737         # Make sure everything was shutdown correctly:
1738         self.assertEqual(reactor.removeAll(), [])
1739         # The reactor always has a timeout added in runReactor():
1740         delayedCalls = reactor.getDelayedCalls()
1741         self.assertEqual(len(delayedCalls), 1, map(str, delayedCalls))
1742
1743         if clientConnectionLostReason is not None:
1744             self.assertIsInstance(
1745                 client.disconnectReason.value,
1746                 (clientConnectionLostReason,) + clientExpectedExceptions)
1747         else:
1748             self.assertIsInstance(client.disconnectReason.value,
1749                                   clientExpectedExceptions)
1750         self.assertIsInstance(server.disconnectReason.value, serverExpectedExceptions)
1751
1752
1753     def test_dataReceivedAbort(self):
1754         """
1755         abortConnection() is called in dataReceived. The protocol should be
1756         disconnected, but connectionLost should not be called re-entrantly.
1757         """
1758         return self.runAbortTest(AbortingClient, ReadAbortServerProtocol)
1759
1760
1761     def test_clientAbortsConnectionTwice(self):
1762         """
1763         abortConnection() is called twice by client.
1764
1765         No exception should be thrown, and the connection will be closed.
1766         """
1767         return self.runAbortTest(AbortingTwiceClient, ReadAbortServerProtocol)
1768
1769
1770     def test_clientAbortsConnectionThenLosesConnection(self):
1771         """
1772         Client calls abortConnection(), followed by loseConnection().
1773
1774         No exception should be thrown, and the connection will be closed.
1775         """
1776         return self.runAbortTest(AbortingThenLosingClient,
1777                                  ReadAbortServerProtocol)
1778
1779
1780     def test_serverAbortsConnectionTwice(self):
1781         """
1782         abortConnection() is called twice by server.
1783
1784         No exception should be thrown, and the connection will be closed.
1785         """
1786         return self.runAbortTest(WritingButNotAbortingClient, ServerAbortsTwice,
1787                                  clientConnectionLostReason=ConnectionLost)
1788
1789
1790     def test_serverAbortsConnectionThenLosesConnection(self):
1791         """
1792         Server calls abortConnection(), followed by loseConnection().
1793
1794         No exception should be thrown, and the connection will be closed.
1795         """
1796         return self.runAbortTest(WritingButNotAbortingClient,
1797                                  ServerAbortsThenLoses,
1798                                  clientConnectionLostReason=ConnectionLost)
1799
1800
1801     def test_resumeProducingAbort(self):
1802         """
1803         abortConnection() is called in resumeProducing, before any bytes have
1804         been exchanged. The protocol should be disconnected, but
1805         connectionLost should not be called re-entrantly.
1806         """
1807         self.runAbortTest(ProducerAbortingClient,
1808                           ConnectableProtocol)
1809
1810
1811     def test_resumeProducingAbortLater(self):
1812         """
1813         abortConnection() is called in resumeProducing, after some
1814         bytes have been exchanged. The protocol should be disconnected.
1815         """
1816         return self.runAbortTest(ProducerAbortingClientLater,
1817                                  AbortServerWritingProtocol)
1818
1819
1820     def test_fullWriteBuffer(self):
1821         """
1822         abortConnection() triggered by the write buffer being full.
1823
1824         In particular, the server side stops reading. This is supposed
1825         to simulate a realistic timeout scenario where the client
1826         notices the server is no longer accepting data.
1827
1828         The protocol should be disconnected, but connectionLost should not be
1829         called re-entrantly.
1830         """
1831         self.runAbortTest(StreamingProducerClient,
1832                           NoReadServer)
1833
1834
1835     def test_fullWriteBufferAfterByteExchange(self):
1836         """
1837         abortConnection() is triggered by a write buffer being full.
1838
1839         However, this buffer is filled after some bytes have been exchanged,
1840         allowing a TLS handshake if we're testing TLS. The connection will
1841         then be lost.
1842         """
1843         return self.runAbortTest(StreamingProducerClientLater,
1844                                  EventualNoReadServer)
1845
1846
1847     def test_dataReceivedThrows(self):
1848         """
1849         dataReceived calls abortConnection(), and then raises an exception.
1850
1851         The connection will be lost, with the thrown exception
1852         (C{ZeroDivisionError}) as the reason on the client. The idea here is
1853         that bugs should not be masked by abortConnection, in particular
1854         unexpected exceptions.
1855         """
1856         self.runAbortTest(DataReceivedRaisingClient,
1857                           AbortServerWritingProtocol,
1858                           clientConnectionLostReason=ZeroDivisionError)
1859         errors = self.flushLoggedErrors(ZeroDivisionError)
1860         self.assertEquals(len(errors), 1)
1861
1862
1863     def test_resumeProducingThrows(self):
1864         """
1865         resumeProducing calls abortConnection(), and then raises an exception.
1866
1867         The connection will be lost, with the thrown exception
1868         (C{ZeroDivisionError}) as the reason on the client. The idea here is
1869         that bugs should not be masked by abortConnection, in particular
1870         unexpected exceptions.
1871         """
1872         self.runAbortTest(ResumeThrowsClient,
1873                           ConnectableProtocol,
1874                           clientConnectionLostReason=ZeroDivisionError)
1875         errors = self.flushLoggedErrors(ZeroDivisionError)
1876         self.assertEquals(len(errors), 1)
1877
1878
1879
1880 class AbortConnectionTestCase(ReactorBuilder, AbortConnectionMixin):
1881     """
1882     TCP-specific L{AbortConnectionMixin} tests.
1883     """
1884
1885     endpoints = TCPCreator()
1886
1887 globals().update(AbortConnectionTestCase.makeTestCaseClasses())
1888
1889
1890
1891 class SimpleUtilityTestCase(TestCase):
1892     """
1893     Simple, direct tests for helpers within L{twisted.internet.tcp}.
1894     """
1895
1896     skip = ipv6Skip
1897
1898     def test_resolveNumericHost(self):
1899         """
1900         L{_resolveIPv6} raises a L{socket.gaierror} (L{socket.EAI_NONAME}) when
1901         invoked with a non-numeric host.  (In other words, it is passing
1902         L{socket.AI_NUMERICHOST} to L{socket.getaddrinfo} and will not
1903         accidentally block if it receives bad input.)
1904         """
1905         err = self.assertRaises(socket.gaierror, _resolveIPv6, "localhost", 1)
1906         self.assertEqual(err.args[0], socket.EAI_NONAME)
1907
1908
1909     def test_resolveNumericService(self):
1910         """
1911         L{_resolveIPv6} raises a L{socket.gaierror} (L{socket.EAI_NONAME}) when
1912         invoked with a non-numeric port.  (In other words, it is passing
1913         L{socket.AI_NUMERICSERV} to L{socket.getaddrinfo} and will not
1914         accidentally block if it receives bad input.)
1915         """
1916         err = self.assertRaises(socket.gaierror, _resolveIPv6, "::1", "http")
1917         self.assertEqual(err.args[0], socket.EAI_NONAME)
1918
1919     if platform.isWindows():
1920         test_resolveNumericService.skip = ("The AI_NUMERICSERV flag is not "
1921                                            "supported by Microsoft providers.")
1922         # http://msdn.microsoft.com/en-us/library/windows/desktop/ms738520.aspx
1923
1924
1925     def test_resolveIPv6(self):
1926         """
1927         L{_resolveIPv6} discovers the flow info and scope ID of an IPv6
1928         address.
1929         """
1930         result = _resolveIPv6("::1", 2)
1931         self.assertEqual(len(result), 4)
1932         # We can't say anything more useful about these than that they're
1933         # integers, because the whole point of getaddrinfo is that you can never
1934         # know a-priori know _anything_ about the network interfaces of the
1935         # computer that you're on and you have to ask it.
1936         self.assertIsInstance(result[2], int) # flow info
1937         self.assertIsInstance(result[3], int) # scope id
1938         # but, luckily, IP presentation format and what it means to be a port
1939         # number are a little better specified.
1940         self.assertEqual(result[:2], ("::1", 2))
1941
1942
1943