Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / test / test_tcp.py
1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
3
4 """
5 Tests for implementations of L{IReactorTCP}.
6 """
7
8 import socket, random, errno
9
10 from zope.interface import implements
11
12 from twisted.trial import unittest
13
14 from twisted.python.log import msg
15 from twisted.internet import protocol, reactor, defer, interfaces
16 from twisted.internet import error
17 from twisted.internet.address import IPv4Address
18 from twisted.internet.interfaces import IHalfCloseableProtocol, IPullProducer
19 from twisted.protocols import policies
20 from twisted.test.proto_helpers import AccumulatingProtocol
21
22
def loopUntil(predicate, interval=0):
    """
    Repeatedly poll C{predicate} until it returns a true value, then fire the
    returned L{Deferred} with that value.

    This is a crude polling substitute for real event notification.  Do not
    use this function.

    @param predicate: A no-argument callable which is polled repeatedly.
    @param interval: Passed to L{task.LoopingCall.start} as the delay between
        polls.
    @return: A L{Deferred} which fires with the first true result of
        C{predicate}, or fails if the polling loop itself fails (for example
        because C{predicate} raised).
    """
    from twisted.internet import task

    result = defer.Deferred()

    def poll():
        value = predicate()
        if value:
            result.callback(value)

    looper = task.LoopingCall(poll)

    def halt(value):
        # Stop polling before passing the value on to the caller's callbacks.
        looper.stop()
        return value

    result.addCallback(halt)
    looper.start(interval).addErrback(result.errback)
    return result
44
45
46
class ClosingProtocol(protocol.Protocol):
    """
    Protocol which disconnects as soon as its connection is made and expects
    the connection to end with a clean close.
    """

    def connectionMade(self):
        msg("ClosingProtocol.connectionMade")
        transport = self.transport
        transport.loseConnection()


    def connectionLost(self, reason):
        msg("ClosingProtocol.connectionLost")
        # Failure.trap re-raises the failure unless it wraps ConnectionDone.
        reason.trap(error.ConnectionDone)
56
57
58
class ClosingFactory(protocol.ServerFactory):
    """
    Server factory which stops listening on its port as soon as it builds a
    protocol, and whose protocols close their connections immediately.

    The test must assign the listening port to the C{port} attribute after
    calling C{listenTCP}; both C{buildProtocol} and C{cleanUp} use it.
    """

    # Result of the stopListening call made by buildProtocol, if one was made.
    _cleanerUpper = None

    def buildProtocol(self, conn):
        self._cleanerUpper = self.port.stopListening()
        return ClosingProtocol()


    def cleanUp(self):
        """
        Clean-up for tests to wait for the port to stop listening.

        @return: The result of the C{stopListening} call already made by
            C{buildProtocol}, or of a new C{stopListening} call if no
            protocol was ever built.
        """
        pending = self._cleanerUpper
        if pending is not None:
            return pending
        return self.port.stopListening()
78
79
80
class MyProtocolFactoryMixin(object):
    """
    Mixin for factories which create L{AccumulatingProtocol} instances and
    keep track of what they built.

    @type protocolFactory: no-argument callable
    @ivar protocolFactory: Used to create each protocol.  It stands in for the
        usual C{protocol} attribute of factories, because this class uses
        that name for the most recently built protocol instead.

    @type protocolConnectionMade: L{NoneType} or L{defer.Deferred}
    @ivar protocolConnectionMade: When an L{AccumulatingProtocol} built by
        this factory is connected and this is not C{None}, the L{Deferred} is
        called back with the protocol instance and the attribute is reset to
        C{None}.  This mixin does not touch it; the protocol does, through
        its C{factory} attribute.

    @type protocolConnectionLost: L{NoneType} or L{defer.Deferred}
    @ivar protocolConnectionLost: Handed to the next protocol built, as its
        C{closedDeferred} attribute, and then reset to C{None} so that no
        other protocol gets the same L{defer.Deferred}.

    @ivar protocol: The L{AccumulatingProtocol} most recently returned from
        C{buildProtocol}.

    @type called: C{int}
    @ivar called: How many times C{buildProtocol} has been called.

    @ivar peerAddresses: A C{list} of every address passed to
        C{buildProtocol}, in call order.
    """
    protocolFactory = AccumulatingProtocol

    protocolConnectionMade = None
    protocolConnectionLost = None
    protocol = None
    called = 0

    def __init__(self):
        self.peerAddresses = []


    def buildProtocol(self, addr):
        """
        Record C{addr}, then create an L{AccumulatingProtocol} wired up to
        this factory and to any pending connection-lost L{Deferred}.
        """
        self.called += 1
        self.peerAddresses.append(addr)
        proto = self.protocolFactory()
        proto.factory = self
        # Give the pending Deferred to this protocol only, then forget it.
        proto.closedDeferred, self.protocolConnectionLost = (
            self.protocolConnectionLost, None)
        self.protocol = proto
        return proto
134
135
136
class MyServerFactory(MyProtocolFactoryMixin, protocol.ServerFactory):
    """
    L{protocol.ServerFactory} which builds L{AccumulatingProtocol} instances,
    with the bookkeeping provided by L{MyProtocolFactoryMixin}.
    """
141
142
143
class MyClientFactory(MyProtocolFactoryMixin, protocol.ClientFactory):
    """
    Client factory which creates L{AccumulatingProtocol} instances and
    records how its connection attempt ended.

    @ivar failed: Set to C{1} when C{clientConnectionFailed} is called.
    @ivar reason: The reason passed to C{clientConnectionFailed}.
    @ivar failDeferred: Called back with C{None} by C{clientConnectionFailed}.
    @ivar lostReason: The reason passed to C{clientConnectionLost}.
    @ivar deferred: Called back with C{None} by C{clientConnectionLost}.
    @ivar stopped: Set to C{1} when C{stopFactory} is called.
    """
    failed = 0
    stopped = 0

    def __init__(self):
        MyProtocolFactoryMixin.__init__(self)
        self.failDeferred = defer.Deferred()
        self.deferred = defer.Deferred()


    def clientConnectionFailed(self, connector, reason):
        self.failed, self.reason = 1, reason
        self.failDeferred.callback(None)


    def clientConnectionLost(self, connector, reason):
        self.lostReason = reason
        self.deferred.callback(None)


    def stopFactory(self):
        self.stopped = 1
167
168
169
class ListeningTestCase(unittest.TestCase):
    """
    Tests for L{IReactorTCP.listenTCP} and the L{IListeningPort} it returns.
    """

    def test_listen(self):
        """
        L{IReactorTCP.listenTCP} returns an object which provides
        L{IListeningPort}.
        """
        f = MyServerFactory()
        p1 = reactor.listenTCP(0, f, interface="127.0.0.1")
        self.addCleanup(p1.stopListening)
        self.assertTrue(interfaces.IListeningPort.providedBy(p1))


    def testStopListening(self):
        """
        The L{IListeningPort} returned by L{IReactorTCP.listenTCP} can be
        stopped with its C{stopListening} method.  After the L{Deferred} it
        (optionally) returns has been called back, the port number can be bound
        to a new server.
        """
        f = MyServerFactory()
        port = reactor.listenTCP(0, f, interface="127.0.0.1")
        n = port.getHost().port

        def cbStopListening(ignored):
            # Make sure we can rebind the port right away
            port = reactor.listenTCP(n, f, interface="127.0.0.1")
            return port.stopListening()

        d = defer.maybeDeferred(port.stopListening)
        d.addCallback(cbStopListening)
        return d


    def testNumberedInterface(self):
        """
        L{IReactorTCP.listenTCP} accepts a numeric IPv4 address as the
        C{interface} to listen on.
        """
        f = MyServerFactory()
        # listen only on the loopback interface
        p1 = reactor.listenTCP(0, f, interface='127.0.0.1')
        return p1.stopListening()


    def testPortRepr(self):
        """
        The repr of a listening port includes its port number while it is
        listening, and no longer includes it once the port has stopped
        listening.
        """
        f = MyServerFactory()
        # Bind to loopback only, like the other tests, so the test never
        # exposes a listener on external interfaces.
        p = reactor.listenTCP(0, f, interface="127.0.0.1")
        portNo = str(p.getHost().port)
        self.assertIn(portNo, repr(p))
        def stoppedListening(ign):
            self.assertNotIn(portNo, repr(p))
        d = defer.maybeDeferred(p.stopListening)
        return d.addCallback(stoppedListening)


    def test_serverRepr(self):
        """
        When the server listens on port 0, the repr of the server transport
        contains the port number that was actually assigned.
        """
        server = MyServerFactory()
        serverConnMade = server.protocolConnectionMade = defer.Deferred()
        port = reactor.listenTCP(0, server, interface="127.0.0.1")
        self.addCleanup(port.stopListening)

        client = MyClientFactory()
        clientConnMade = client.protocolConnectionMade = defer.Deferred()
        connector = reactor.connectTCP("127.0.0.1",
                                       port.getHost().port, client)
        self.addCleanup(connector.disconnect)
        def check(protocols):
            # Explicit unpacking instead of a tuple parameter, which Python 3
            # no longer supports.
            serverProto, clientProto = protocols
            portNumber = port.getHost().port
            self.assertEqual(
                repr(serverProto.transport),
                "<AccumulatingProtocol #0 on %s>" % (portNumber,))
            serverProto.transport.loseConnection()
            clientProto.transport.loseConnection()
        return defer.gatherResults([serverConnMade, clientConnMade]
            ).addCallback(check)


    def test_restartListening(self):
        """
        Stop and then try to restart a L{tcp.Port}: after a restart, the
        server should be able to handle client connections.
        """
        serverFactory = MyServerFactory()
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        self.addCleanup(port.stopListening)

        def cbStopListening(ignored):
            port.startListening()

            client = MyClientFactory()
            serverFactory.protocolConnectionMade = defer.Deferred()
            client.protocolConnectionMade = defer.Deferred()
            connector = reactor.connectTCP("127.0.0.1",
                                           port.getHost().port, client)
            self.addCleanup(connector.disconnect)
            return defer.gatherResults([serverFactory.protocolConnectionMade,
                                        client.protocolConnectionMade]
                ).addCallback(close)

        def close(protocols):
            serverProto, clientProto = protocols
            clientProto.transport.loseConnection()
            serverProto.transport.loseConnection()

        d = defer.maybeDeferred(port.stopListening)
        d.addCallback(cbStopListening)
        return d


    def test_exceptInStop(self):
        """
        If the server factory raises an exception in C{stopFactory}, the
        deferred returned by L{tcp.Port.stopListening} should fail with the
        corresponding error.
        """
        serverFactory = MyServerFactory()
        def raiseException():
            raise RuntimeError("An error")
        serverFactory.stopFactory = raiseException
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")

        return self.assertFailure(port.stopListening(), RuntimeError)


    def test_restartAfterExcept(self):
        """
        Even if the server factory raises an exception in C{stopFactory}, the
        corresponding C{tcp.Port} instance should be left in a sane state and
        can be restarted.
        """
        serverFactory = MyServerFactory()
        def raiseException():
            raise RuntimeError("An error")
        serverFactory.stopFactory = raiseException
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        self.addCleanup(port.stopListening)

        def cbStopListening(ignored):
            # Remove the instance override so the class's stopFactory is used
            # from now on.
            del serverFactory.stopFactory
            port.startListening()

            client = MyClientFactory()
            serverFactory.protocolConnectionMade = defer.Deferred()
            client.protocolConnectionMade = defer.Deferred()
            connector = reactor.connectTCP("127.0.0.1",
                                           port.getHost().port, client)
            self.addCleanup(connector.disconnect)
            return defer.gatherResults([serverFactory.protocolConnectionMade,
                                        client.protocolConnectionMade]
                ).addCallback(close)

        def close(protocols):
            serverProto, clientProto = protocols
            clientProto.transport.loseConnection()
            serverProto.transport.loseConnection()

        return self.assertFailure(port.stopListening(), RuntimeError
            ).addCallback(cbStopListening)


    def test_directConnectionLostCall(self):
        """
        If C{connectionLost} is called directly on a port object, it succeeds
        (and doesn't expect the presence of a C{deferred} attribute).

        C{connectionLost} is called by L{reactor.disconnectAll} at shutdown.
        """
        serverFactory = MyServerFactory()
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        portNumber = port.getHost().port
        port.connectionLost(None)

        # The port has been shut down, so a new connection attempt to it is
        # expected to be refused.
        client = MyClientFactory()
        reactor.connectTCP("127.0.0.1", portNumber, client)
        def check(ign):
            client.reason.trap(error.ConnectionRefusedError)
        return client.failDeferred.addCallback(check)


    def test_exceptInConnectionLostCall(self):
        """
        If C{connectionLost} is called directly on a port object and the
        server factory raises an exception in C{stopFactory}, the exception is
        passed through to the caller.

        C{connectionLost} is called by L{reactor.disconnectAll} at shutdown.
        """
        serverFactory = MyServerFactory()
        def raiseException():
            raise RuntimeError("An error")
        serverFactory.stopFactory = raiseException
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        self.assertRaises(RuntimeError, port.connectionLost, None)
363
364
365
def callWithSpew(f):
    """
    Call C{f} with no arguments while Twisted's line-number-reporting spewer
    is installed as the trace function, and remove the trace function
    afterwards even if C{f} raises.  The return value of C{f} is discarded.

    NOTE(review): this looks like a debugging aid; it is not called anywhere
    in the visible part of this module.
    """
    import sys
    from twisted.python.util import spewerWithLinenums
    sys.settrace(spewerWithLinenums)
    try:
        f()
    finally:
        sys.settrace(None)
374
class LoopbackTestCase(unittest.TestCase):
    """
    Test loopback connections.
    """
    def test_closePortInProtocolFactory(self):
        """
        A port created with L{IReactorTCP.listenTCP} can be connected to with
        L{IReactorTCP.connectTCP}, even when the server factory stops the port
        from inside C{buildProtocol}.
        """
        f = ClosingFactory()
        port = reactor.listenTCP(0, f, interface="127.0.0.1")
        f.port = port
        self.addCleanup(f.cleanUp)
        portNumber = port.getHost().port
        clientF = MyClientFactory()
        reactor.connectTCP("127.0.0.1", portNumber, clientF)
        def check(x):
            self.assertTrue(clientF.protocol.made)
            self.assertTrue(port.disconnected)
            clientF.lostReason.trap(error.ConnectionDone)
        return clientF.deferred.addCallback(check)


    def _trapCnxDone(self, obj):
        """
        Trap L{error.ConnectionDone} on C{obj} if it has a C{trap} method
        (presumably a L{Failure}); do nothing for any other object.
        """
        getattr(obj, 'trap', lambda x: None)(error.ConnectionDone)


    def _connectedClientAndServerTest(self, callback):
        """
        Invoke the given callback with a client protocol and a server protocol
        which have been connected to each other.

        @param callback: Called as C{callback(serverProtocol, clientProtocol)}
            once both sides are connected; both connections are closed after
            it returns.
        @return: A L{Deferred} which fires after C{callback} has run.
        """
        serverFactory = MyServerFactory()
        serverConnMade = defer.Deferred()
        serverFactory.protocolConnectionMade = serverConnMade
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        self.addCleanup(port.stopListening)

        portNumber = port.getHost().port
        clientF = MyClientFactory()
        clientConnMade = defer.Deferred()
        clientF.protocolConnectionMade = clientConnMade
        reactor.connectTCP("127.0.0.1", portNumber, clientF)

        connsMade = defer.gatherResults([serverConnMade, clientConnMade])
        def connected(protocols):
            # Explicit unpacking instead of a tuple parameter, which Python 3
            # no longer supports.
            serverProtocol, clientProtocol = protocols
            callback(serverProtocol, clientProtocol)
            serverProtocol.transport.loseConnection()
            clientProtocol.transport.loseConnection()
        connsMade.addCallback(connected)
        return connsMade


    def test_tcpNoDelay(self):
        """
        The transport of a protocol connected with L{IReactorTCP.connectTCP} or
        L{IReactorTCP.listenTCP} can have its I{TCP_NODELAY} state inspected
        and manipulated with L{ITCPTransport.getTcpNoDelay} and
        L{ITCPTransport.setTcpNoDelay}.
        """
        def check(serverProtocol, clientProtocol):
            for p in [serverProtocol, clientProtocol]:
                transport = p.transport
                self.assertEqual(transport.getTcpNoDelay(), 0)
                transport.setTcpNoDelay(1)
                self.assertEqual(transport.getTcpNoDelay(), 1)
                transport.setTcpNoDelay(0)
                self.assertEqual(transport.getTcpNoDelay(), 0)
        return self._connectedClientAndServerTest(check)


    def test_tcpKeepAlive(self):
        """
        The transport of a protocol connected with L{IReactorTCP.connectTCP} or
        L{IReactorTCP.listenTCP} can have its I{SO_KEEPALIVE} state inspected
        and manipulated with L{ITCPTransport.getTcpKeepAlive} and
        L{ITCPTransport.setTcpKeepAlive}.
        """
        def check(serverProtocol, clientProtocol):
            for p in [serverProtocol, clientProtocol]:
                transport = p.transport
                self.assertEqual(transport.getTcpKeepAlive(), 0)
                transport.setTcpKeepAlive(1)
                self.assertEqual(transport.getTcpKeepAlive(), 1)
                transport.setTcpKeepAlive(0)
                self.assertEqual(transport.getTcpKeepAlive(), 0)
        return self._connectedClientAndServerTest(check)


    def testFailing(self):
        """
        Connecting to a local port with no listener calls
        C{clientConnectionFailed} with L{error.ConnectionRefusedError}.
        """
        clientF = MyClientFactory()
        # XXX we assume no one is listening on TCP port 69
        reactor.connectTCP("127.0.0.1", 69, clientF, timeout=5)
        def check(ignored):
            clientF.reason.trap(error.ConnectionRefusedError)
        return clientF.failDeferred.addCallback(check)


    def test_connectionRefusedErrorNumber(self):
        """
        Assert that the error number of the ConnectionRefusedError is
        ECONNREFUSED, and not some other socket related error.
        """

        # Bind a number of ports in the operating system.  We will attempt
        # to connect to these in turn immediately after closing them, in the
        # hopes that no one else has bound them in the mean time.  Any
        # connection which succeeds is ignored and causes us to move on to
        # the next port.  As soon as a connection attempt fails, we move on
        # to making an assertion about how it failed.  If they all succeed,
        # the test will fail.

        # It would be nice to have a simpler, reliable way to cause a
        # connection failure from the platform.
        #
        # On Linux (2.6.15), connecting to port 0 always fails.  FreeBSD
        # (5.4) rejects the connection attempt with EADDRNOTAVAIL.
        #
        # On FreeBSD (5.4), listening on a port and then repeatedly
        # connecting to it without ever accepting any connections eventually
        # leads to an ECONNREFUSED.  On Linux (2.6.15), a seemingly
        # unbounded number of connections succeed.

        serverSockets = []
        for _ in range(10):
            serverSocket = socket.socket()
            serverSocket.bind(('127.0.0.1', 0))
            serverSocket.listen(1)
            serverSockets.append(serverSocket)
        random.shuffle(serverSockets)

        clientCreator = protocol.ClientCreator(reactor, protocol.Protocol)

        def tryConnectFailure():
            def connected(proto):
                """
                Darn.  Kill it and try again, if there are any tries left.
                """
                proto.transport.loseConnection()
                if serverSockets:
                    return tryConnectFailure()
                self.fail("Could not fail to connect - could not test errno for that case.")

            serverSocket = serverSockets.pop()
            serverHost, serverPort = serverSocket.getsockname()
            serverSocket.close()

            connectDeferred = clientCreator.connectTCP(serverHost, serverPort)
            connectDeferred.addCallback(connected)
            return connectDeferred

        refusedDeferred = tryConnectFailure()
        self.assertFailure(refusedDeferred, error.ConnectionRefusedError)
        def connRefused(exc):
            self.assertEqual(exc.osError, errno.ECONNREFUSED)
        refusedDeferred.addCallback(connRefused)
        def cleanup(passthrough):
            # Close any sockets that were never tried, whatever the outcome.
            while serverSockets:
                serverSockets.pop().close()
            return passthrough
        refusedDeferred.addBoth(cleanup)
        return refusedDeferred


    def test_connectByServiceFail(self):
        """
        Connecting to a named service which does not exist raises
        L{error.ServiceNameUnknownError}.
        """
        self.assertRaises(
            error.ServiceNameUnknownError,
            reactor.connectTCP,
            "127.0.0.1", "thisbetternotexist", MyClientFactory())


    def test_connectByService(self):
        """
        L{IReactorTCP.connectTCP} accepts the name of a service instead of a
        port number and connects to the port number associated with that
        service, as defined by L{socket.getservbyname}.
        """
        serverFactory = MyServerFactory()
        serverConnMade = defer.Deferred()
        serverFactory.protocolConnectionMade = serverConnMade
        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        self.addCleanup(port.stopListening)
        portNumber = port.getHost().port
        clientFactory = MyClientFactory()
        clientConnMade = defer.Deferred()
        clientFactory.protocolConnectionMade = clientConnMade

        def fakeGetServicePortByName(serviceName, protocolName):
            if serviceName == 'http' and protocolName == 'tcp':
                return portNumber
            return 10
        self.patch(socket, 'getservbyname', fakeGetServicePortByName)

        reactor.connectTCP('127.0.0.1', 'http', clientFactory)

        connMade = defer.gatherResults([serverConnMade, clientConnMade])
        def connected(protocols):
            serverProtocol, clientProtocol = protocols
            self.assertTrue(
                serverFactory.called,
                "Server factory was not called upon to build a protocol.")
            serverProtocol.transport.loseConnection()
            clientProtocol.transport.loseConnection()
        connMade.addCallback(connected)
        return connMade
582
583
class StartStopFactory(protocol.Factory):
    """
    Factory which records C{startFactory}/C{stopFactory} calls and raises
    L{RuntimeError} for any call made out of order or repeated.  Start must
    happen once, first; stop must happen once, after start.
    """

    started = 0
    stopped = 0

    def startFactory(self):
        if not (self.started or self.stopped):
            self.started = 1
            return
        raise RuntimeError


    def stopFactory(self):
        if self.started and not self.stopped:
            self.stopped = 1
            return
        raise RuntimeError
596             raise RuntimeError
597         self.stopped = 1
598
599
class ClientStartStopFactory(MyClientFactory):
    """
    L{MyClientFactory} which enforces a single start followed by a single
    stop, raising L{RuntimeError} otherwise.

    @ivar whenStopped: A L{Deferred} called back with C{True} once
        C{stopFactory} has run.
    """

    started = 0
    stopped = 0

    def __init__(self, *a, **kw):
        MyClientFactory.__init__(self, *a, **kw)
        self.whenStopped = defer.Deferred()


    def startFactory(self):
        if not (self.started or self.stopped):
            self.started = 1
            return
        raise RuntimeError


    def stopFactory(self):
        if not (self.started and not self.stopped):
            raise RuntimeError
        self.stopped = 1
        self.whenStopped.callback(True)
619
620
class FactoryTestCase(unittest.TestCase):
    """Tests for factories."""

    def test_serverStartStop(self):
        """
        The factory passed to L{IReactorTCP.listenTCP} should be started only
        when it transitions from being used on no ports to being used on one
        port and should be stopped only when it transitions from being used on
        one port to being used on no ports.
        """
        # Note - this test doesn't need to use listenTCP.  It is exercising
        # logic implemented in Factory.doStart and Factory.doStop, so it could
        # just call that directly.  Some other test can make sure that
        # listenTCP and stopListening correctly call doStart and
        # doStop. -exarkun

        f = StartStopFactory()

        # listen on port
        p1 = reactor.listenTCP(0, f, interface='127.0.0.1')
        self.addCleanup(p1.stopListening)

        self.assertEqual((f.started, f.stopped), (1, 0))

        # Listen on two more ports.  Register cleanups for them as well, so
        # they are still closed if an assertion below fails before they are
        # stopped explicitly.  p1 already uses the same pattern of a cleanup
        # plus an explicit stop.
        p2 = reactor.listenTCP(0, f, interface='127.0.0.1')
        self.addCleanup(p2.stopListening)
        p3 = reactor.listenTCP(0, f, interface='127.0.0.1')
        self.addCleanup(p3.stopListening)

        self.assertEqual((f.started, f.stopped), (1, 0))

        # close two ports
        d1 = defer.maybeDeferred(p1.stopListening)
        d2 = defer.maybeDeferred(p2.stopListening)
        closedDeferred = defer.gatherResults([d1, d2])
        def cbClosed(ignored):
            self.assertEqual((f.started, f.stopped), (1, 0))
            # Close the last port
            return p3.stopListening()
        closedDeferred.addCallback(cbClosed)

        def cbClosedAll(ignored):
            self.assertEqual((f.started, f.stopped), (1, 1))
        closedDeferred.addCallback(cbClosedAll)
        return closedDeferred


    def test_clientStartStop(self):
        """
        The factory passed to L{IReactorTCP.connectTCP} should be started when
        the connection attempt starts and stopped when it is over.
        """
        f = ClosingFactory()
        p = reactor.listenTCP(0, f, interface="127.0.0.1")
        f.port = p
        self.addCleanup(f.cleanUp)
        portNumber = p.getHost().port

        factory = ClientStartStopFactory()
        reactor.connectTCP("127.0.0.1", portNumber, factory)
        # The factory is expected to be started by connectTCP itself.
        self.assertTrue(factory.started)
        return loopUntil(lambda: factory.stopped)
682
683
684
class CannotBindTestCase(unittest.TestCase):
    """
    Tests for correct behavior when a reactor cannot bind to the required TCP
    port.
    """

    def test_cannotBind(self):
        """
        L{IReactorTCP.listenTCP} raises L{error.CannotListenError} if the
        address to listen on is already in use.
        """
        f = MyServerFactory()

        p1 = reactor.listenTCP(0, f, interface='127.0.0.1')
        self.addCleanup(p1.stopListening)
        n = p1.getHost().port
        dest = p1.getHost()
        self.assertEqual(dest.type, "TCP")
        self.assertEqual(dest.host, "127.0.0.1")
        self.assertEqual(dest.port, n)

        # make sure new listen raises error
        self.assertRaises(error.CannotListenError,
                          reactor.listenTCP, n, f, interface='127.0.0.1')



    def _fireWhenDoneFunc(self, d, f):
        """
        Return a wrapper around C{f} which calls C{f} and then calls back C{d}.

        @param d: A L{Deferred} called back with C{''} after C{f} returns.  A
            second call to the wrapper would try to fire C{d} again, so the
            wrapper is meant to be called only once.
        @param f: The callable to wrap.  If it raises, C{d} is not fired and
            the exception propagates.
        @return: The wrapper, carrying C{f}'s metadata via
            L{twisted.python.util.mergeFunctionMetadata}, and returning
            whatever C{f} returned.
        """
        from twisted.python import util as tputil
        def newf(*args, **kw):
            rtn = f(*args, **kw)
            d.callback('')
            return rtn
        return tputil.mergeFunctionMetadata(f, newf)


    def test_clientBind(self):
        """
        L{IReactorTCP.connectTCP} calls C{Factory.clientConnectionFailed} with
        L{error.ConnectBindError} if the bind address specified is already in
        use.
        """
        # theDeferred fires once the server factory's startFactory has run.
        theDeferred = defer.Deferred()
        sf = MyServerFactory()
        sf.startFactory = self._fireWhenDoneFunc(theDeferred, sf.startFactory)
        p = reactor.listenTCP(0, sf, interface="127.0.0.1")
        self.addCleanup(p.stopListening)

        def _connect1(results):
            # Connect a first client from an ephemeral local port and wait
            # for its factory to build a protocol.
            d = defer.Deferred()
            cf1 = MyClientFactory()
            cf1.buildProtocol = self._fireWhenDoneFunc(d, cf1.buildProtocol)
            reactor.connectTCP("127.0.0.1", p.getHost().port, cf1,
                               bindAddress=("127.0.0.1", 0))
            d.addCallback(_conmade, cf1)
            return d

        def _conmade(results, cf1):
            # Runs right after cf1.buildProtocol returns; wrap the new
            # protocol's connectionMade so we can wait for it too.
            d = defer.Deferred()
            cf1.protocol.connectionMade = self._fireWhenDoneFunc(
                d, cf1.protocol.connectionMade)
            d.addCallback(_check1connect2, cf1)
            return d

        def _check1connect2(results, cf1):
            self.assertEqual(cf1.protocol.made, 1)

            # Try a second client bound to the local port the first client is
            # already using.  d1 fires when its connection fails, d2 when its
            # factory is stopped.
            d1 = defer.Deferred()
            d2 = defer.Deferred()
            port = cf1.protocol.transport.getHost().port
            cf2 = MyClientFactory()
            cf2.clientConnectionFailed = self._fireWhenDoneFunc(
                d1, cf2.clientConnectionFailed)
            cf2.stopFactory = self._fireWhenDoneFunc(d2, cf2.stopFactory)
            reactor.connectTCP("127.0.0.1", p.getHost().port, cf2,
                               bindAddress=("127.0.0.1", port))
            d1.addCallback(_check2failed, cf1, cf2)
            d2.addCallback(_check2stopped, cf1, cf2)
            # NOTE(review): DeferredList with its default arguments neither
            # errbacks nor consumes failures from d1/d2, so an assertion
            # failure in the checks above surfaces only as an unhandled error
            # in a Deferred, not through this chain.
            dl = defer.DeferredList([d1, d2])
            dl.addCallback(_stop, cf1, cf2)
            return dl

        def _check2failed(results, cf1, cf2):
            self.assertEqual(cf2.failed, 1)
            # trap() re-raises the failure unless it is a ConnectBindError,
            # so no further check of the reason is needed.
            cf2.reason.trap(error.ConnectBindError)
            return results

        def _check2stopped(results, cf1, cf2):
            self.assertEqual(cf2.stopped, 1)
            return results

        def _stop(results, cf1, cf2):
            # Close the first client and wait for its factory to be stopped.
            d = defer.Deferred()
            d.addCallback(_check1cleanup, cf1)
            cf1.stopFactory = self._fireWhenDoneFunc(d, cf1.stopFactory)
            cf1.protocol.transport.loseConnection()
            return d

        def _check1cleanup(results, cf1):
            self.assertEqual(cf1.stopped, 1)

        theDeferred.addCallback(_connect1)
        return theDeferred
791
792
793
class MyOtherClientFactory(protocol.ClientFactory):
    """
    Client factory which remembers the address passed to C{buildProtocol} and
    the L{AccumulatingProtocol} it built for that address.
    """
    def buildProtocol(self, address):
        built = AccumulatingProtocol()
        self.address, self.protocol = address, built
        return built
799
800
801
class LocalRemoteAddressTestCase(unittest.TestCase):
    """
    Tests that getHost/getPeer report the right addresses and that
    buildProtocol is given the right address.
    """
    def test_hostAddress(self):
        """
        L{IListeningPort.getHost} returns the same address as a client
        connection's L{ITCPTransport.getPeer}.
        """
        serverFactory = MyServerFactory()
        serverConnectionLost = defer.Deferred()
        serverFactory.protocolConnectionLost = serverConnectionLost
        port = reactor.listenTCP(0, serverFactory, interface='127.0.0.1')
        self.addCleanup(port.stopListening)

        clientFactory = MyClientFactory()
        connected = defer.Deferred()
        clientFactory.protocolConnectionMade = connected
        connector = reactor.connectTCP(
            '127.0.0.1', port.getHost().port, clientFactory)

        def verifyAndDisconnect(ignored):
            self.assertEqual([port.getHost()], clientFactory.peerAddresses)
            self.assertEqual(
                port.getHost(), clientFactory.protocol.transport.getPeer())
            # Disconnect the client explicitly so that the server side starts
            # tearing down, then wait until the server side has really gone.
            connector.disconnect()
            return serverConnectionLost

        return connected.addCallback(verifyAndDisconnect)
838
839
840
841 class WriterProtocol(protocol.Protocol):
842     def connectionMade(self):
843         # use everything ITransport claims to provide. If something here
844         # fails, the exception will be written to the log, but it will not
845         # directly flunk the test. The test will fail when maximum number of
846         # iterations have passed and the writer's factory.done has not yet
847         # been set.
848         self.transport.write("Hello Cleveland!\n")
849         seq = ["Goodbye", " cruel", " world", "\n"]
850         self.transport.writeSequence(seq)
851         peer = self.transport.getPeer()
852         if peer.type != "TCP":
853             print "getPeer returned non-TCP socket:", peer
854             self.factory.problem = 1
855         us = self.transport.getHost()
856         if us.type != "TCP":
857             print "getHost returned non-TCP socket:", us
858             self.factory.problem = 1
859         self.factory.done = 1
860
861         self.transport.loseConnection()
862
class ReaderProtocol(protocol.Protocol):
    """
    Client-side protocol which appends every received chunk to
    C{factory.data} and sets C{factory.done} once the connection is lost.
    """
    def dataReceived(self, data):
        factory = self.factory
        factory.data = factory.data + data

    def connectionLost(self, reason):
        self.factory.done = 1
868
class WriterClientFactory(protocol.ClientFactory):
    """
    Client factory whose L{ReaderProtocol} instances accumulate received
    bytes in C{data} and set C{done} when disconnected.

    @ivar protocol: The L{ReaderProtocol} most recently built.
    """
    def __init__(self):
        self.done, self.data = 0, ""

    def buildProtocol(self, addr):
        reader = ReaderProtocol()
        reader.factory = self
        self.protocol = reader
        return reader
876         self.protocol = p
877         return p
878
class WriteDataTestCase(unittest.TestCase):
    """
    Test that connected TCP sockets can actually write data. Try to exercise
    the entire ITransport interface.
    """

    def test_writer(self):
        """
        L{ITCPTransport.write} and L{ITCPTransport.writeSequence} send bytes to
        the other end of the connection.
        """
        f = protocol.Factory()
        f.protocol = WriterProtocol
        f.done = 0
        f.problem = 0
        wrappedF = WiredFactory(f)
        p = reactor.listenTCP(0, wrappedF, interface="127.0.0.1")
        self.addCleanup(p.stopListening)
        n = p.getHost().port
        clientF = WriterClientFactory()
        wrappedClientF = WiredFactory(clientF)
        reactor.connectTCP("127.0.0.1", n, wrappedClientF)

        def check(ignored):
            # Use assertEqual for value comparisons so that a failure reports
            # both the actual and the expected value.
            self.assertTrue(f.done, "writer didn't finish, it probably died")
            self.assertEqual(f.problem, 0, "writer indicated an error")
            self.assertTrue(clientF.done,
                            "client didn't see connection dropped")
            expected = "".join(["Hello Cleveland!\n",
                                "Goodbye", " cruel", " world", "\n"])
            self.assertEqual(clientF.data, expected,
                             "client didn't receive all the data it expected")
        # Only check once both the server and the client side have seen their
        # connection lost.
        d = defer.gatherResults([wrappedF.onDisconnect,
                                 wrappedClientF.onDisconnect])
        return d.addCallback(check)


    def test_writeAfterShutdownWithoutReading(self):
        """
        A TCP transport which is written to after the connection has been shut
        down should notify its protocol that the connection has been lost, even
        if the TCP transport is not actively being monitored for read events
        (ie, pauseProducing was called on it).
        """
        # This is an unpleasant thing.  Generally tests shouldn't skip or
        # run based on the name of the reactor being used (most tests
        # shouldn't care _at all_ what reactor is being used, in fact).  The
        # Gtk reactor cannot pass this test, though, because it fails to
        # implement IReactorTCP entirely correctly.  Gtk is quite old at
        # this point, so it's more likely that gtkreactor will be deprecated
        # and removed rather than fixed to handle this case correctly.
        # Since this is a pre-existing (and very long-standing) issue with
        # the Gtk reactor, there's no reason for it to prevent this test
        # being added to exercise the other reactors, for which the behavior
        # was also untested but at least works correctly (now).  See #2833
        # for information on the status of gtkreactor.
        if reactor.__class__.__name__ == 'IOCPReactor':
            raise unittest.SkipTest(
                "iocpreactor does not, in fact, stop reading immediately after "
                "pauseProducing is called. This results in a bonus disconnection "
                "notification. Under some circumstances, it might be possible to "
                "not receive this notifications (specifically, pauseProducing, "
                "deliver some data, proceed with this test).")
        if reactor.__class__.__name__ == 'GtkReactor':
            raise unittest.SkipTest(
                "gtkreactor does not implement unclean disconnection "
                "notification correctly.  This might more properly be "
                "a todo, but due to technical limitations it cannot be.")

        # Called back after the protocol for the client side of the connection
        # has paused its transport, preventing it from reading, therefore
        # preventing it from noticing the disconnection before the rest of the
        # actions which are necessary to trigger the case this test is for have
        # been taken.
        clientPaused = defer.Deferred()

        # Called back when the protocol for the server side of the connection
        # has received connection lost notification.
        serverLost = defer.Deferred()

        class Disconnecter(protocol.Protocol):
            """
            Protocol for the server side of the connection which disconnects
            itself in a callback on clientPaused and publishes notification
            when its connection is actually lost.
            """
            def connectionMade(self):
                """
                Set up a callback on clientPaused to lose the connection.
                """
                msg('Disconnector.connectionMade')
                def disconnect(ignored):
                    msg('Disconnector.connectionMade disconnect')
                    self.transport.loseConnection()
                    msg('loseConnection called')
                clientPaused.addCallback(disconnect)

            def connectionLost(self, reason):
                """
                Notify observers that the server side of the connection has
                ended.
                """
                msg('Disconnecter.connectionLost')
                serverLost.callback(None)
                msg('serverLost called back')

        # Create the server port to which a connection will be made.
        server = protocol.ServerFactory()
        server.protocol = Disconnecter
        port = reactor.listenTCP(0, server, interface='127.0.0.1')
        self.addCleanup(port.stopListening)
        addr = port.getHost()

        class Infinite(object):
            """
            A producer which will write to its consumer as long as
            resumeProducing is called.

            @ivar consumer: The L{IConsumer} which will be written to.
            """
            implements(IPullProducer)

            def __init__(self, consumer):
                self.consumer = consumer

            def resumeProducing(self):
                msg('Infinite.resumeProducing')
                self.consumer.write('x')
                msg('Infinite.resumeProducing wrote to consumer')

            def stopProducing(self):
                msg('Infinite.stopProducing')


        class UnreadingWriter(protocol.Protocol):
            """
            Trivial protocol which pauses its transport immediately and then
            writes some bytes to it.
            """
            def connectionMade(self):
                msg('UnreadingWriter.connectionMade')
                self.transport.pauseProducing()
                clientPaused.callback(None)
                msg('clientPaused called back')
                def write(ignored):
                    msg('UnreadingWriter.connectionMade write')
                    # This needs to be enough bytes to spill over into the
                    # userspace Twisted send buffer - if it all fits into
                    # the kernel, Twisted won't even poll for OUT events,
                    # which means it won't poll for any events at all, so
                    # the disconnection is never noticed.  This is due to
                    # #1662.  When #1662 is fixed, this test will likely
                    # need to be adjusted, otherwise connection lost
                    # notification will happen too soon and the test will
                    # probably begin to fail with ConnectionDone instead of
                    # ConnectionLost (in any case, it will no longer be
                    # entirely correct).
                    producer = Infinite(self.transport)
                    msg('UnreadingWriter.connectionMade write created producer')
                    self.transport.registerProducer(producer, False)
                    msg('UnreadingWriter.connectionMade write registered producer')
                serverLost.addCallback(write)

        # Create the client and initiate the connection
        client = MyClientFactory()
        client.protocolFactory = UnreadingWriter
        clientConnectionLost = client.deferred
        def cbClientLost(ignored):
            msg('cbClientLost')
            return client.lostReason
        clientConnectionLost.addCallback(cbClientLost)
        msg('Connecting to %s:%s' % (addr.host, addr.port))
        reactor.connectTCP(addr.host, addr.port, client)

        # By the end of the test, the client should have received notification
        # of unclean disconnection.
        msg('Returning Deferred')
        return self.assertFailure(clientConnectionLost, error.ConnectionLost)
1057
1058
1059
class ConnectionLosingProtocol(protocol.Protocol):
    """
    Protocol which writes a single byte, immediately asks for its connection
    to be closed, and then reports to C{self.master}.

    C{master} is not set here.  Whoever builds the protocol must assign it,
    and it must provide a C{_connectionMade} method and a C{ports} list.
    """
    def connectionMade(self):
        transport = self.transport
        transport.write("1")
        transport.loseConnection()
        master = self.master
        master._connectionMade()
        # Record the transport so the master can inspect it later.
        master.ports.append(self.transport)
1066
1067
1068
class NoopProtocol(protocol.Protocol):
    """
    Protocol which ignores its connection apart from reporting its lifetime
    to C{self.master}.  On connect, a fresh Deferred is stored as C{self.d}
    and appended to C{master.serverConns}.  On disconnect, that Deferred is
    fired with C{True}.
    """
    def connectionMade(self):
        lost = defer.Deferred()
        self.d = lost
        self.master.serverConns.append(lost)

    def connectionLost(self, reason):
        self.d.callback(True)
1076
1077
1078
class ConnectionLostNotifyingProtocol(protocol.Protocol):
    """
    Protocol which, when its connection is lost, fires the Deferred given to
    its initializer with the protocol instance itself.

    @ivar onConnectionLost: The L{Deferred} fired from C{connectionLost}.

    @ivar lostConnectionReason: C{None} until the connection is lost; after
        that, the reason that was passed to C{connectionLost}.
    """
    def __init__(self, onConnectionLost):
        self.onConnectionLost = onConnectionLost
        self.lostConnectionReason = None


    def connectionLost(self, reason):
        # Store the reason before notifying, so callbacks can inspect it.
        self.lostConnectionReason = reason
        self.onConnectionLost.callback(self)
1096         self.lostConnectionReason = reason
1097         self.onConnectionLost.callback(self)
1098
1099
1100
class HandleSavingProtocol(ConnectionLostNotifyingProtocol):
    """
    L{ConnectionLostNotifyingProtocol} which records the transport's
    platform-specific socket handle as C{self.handle} when the connection is
    established, so the handle can be inspected after the connection ends.
    """
    def makeConnection(self, transport):
        """
        Store C{transport.getHandle()} on C{self.handle}, then continue with
        the normal L{protocol.Protocol.makeConnection} processing.
        """
        handle = transport.getHandle()
        self.handle = handle
        result = protocol.Protocol.makeConnection(self, transport)
        return result
1114
1115
1116
class ProperlyCloseFilesMixin:
    """
    Tests for platform resources properly being cleaned up.

    Concrete test cases mix this in and must override C{createServer},
    C{connectClient} and C{getHandleExceptionType}; the versions here raise
    L{NotImplementedError}.
    """
    def createServer(self, address, portNumber, factory):
        """
        Bind a server port, using the given protocol factory, to which
        connections will be made.

        @return: The L{IListeningPort} for the server created.
        """
        raise NotImplementedError()


    def connectClient(self, address, portNumber, clientCreator):
        """
        Connect to the given address using the given L{ClientCreator}
        instance.

        @return: A Deferred which fires with the connected protocol instance.
        """
        raise NotImplementedError()


    def getHandleExceptionType(self):
        """
        Return the exception class raised when an operation is attempted on
        a closed platform handle.
        """
        raise NotImplementedError()


    def getHandleErrorCode(self):
        """
        Return the errno expected when writing to a closed platform socket
        handle.
        """
        # These platforms have been seen to give EBADF:
        #
        #  Linux 2.4.26, Linux 2.6.15, OS X 10.4, FreeBSD 5.4
        #  Windows 2000 SP 4, Windows XP SP 2
        return errno.EBADF


    def test_properlyCloseFiles(self):
        """
        Test that lost connections properly have their underlying socket
        resources cleaned up.
        """
        serverLost = defer.Deferred()
        serverFactory = protocol.ServerFactory()
        serverFactory.protocol = lambda: ConnectionLostNotifyingProtocol(
            serverLost)
        serverPort = self.createServer('127.0.0.1', 0, serverFactory)

        clientLost = defer.Deferred()
        serverAddr = serverPort.getHost()
        clientCreator = protocol.ClientCreator(
            reactor, lambda: HandleSavingProtocol(clientLost))
        result = self.connectClient(
            serverAddr.host, serverAddr.port, clientCreator)

        def disconnectClient(client):
            """
            Write a little data, then disconnect the client.  Return a
            Deferred that fires once both the client and the server have been
            told the connection is gone.
            """
            transport = client.transport
            transport.write(
                'some bytes to make sure the connection is set up')
            transport.loseConnection()
            return defer.gatherResults([clientLost, serverLost])

        def verifyHandleClosed(protocols):
            """
            Check that sending on the client's saved platform handle now
            fails with the expected exception type and errno.
            """
            client, server = protocols
            client.lostConnectionReason.trap(error.ConnectionClosed)
            server.lostConnectionReason.trap(error.ConnectionClosed)
            expectedErrorCode = self.getHandleErrorCode()
            err = self.assertRaises(
                self.getHandleExceptionType(), client.handle.send, 'bytes')
            self.assertEqual(err.args[0], expectedErrorCode)

        def stopServer(passthrough):
            """
            Stop the server port whatever the outcome, then pass the previous
            result or failure through unchanged.
            """
            stopped = defer.maybeDeferred(serverPort.stopListening)
            stopped.addCallback(lambda ignored: passthrough)
            return stopped

        result.addCallback(disconnectClient)
        result.addCallback(verifyHandleClosed)
        result.addBoth(stopServer)
        return result
1215
1216
1217
class ProperlyCloseFilesTestCase(unittest.TestCase, ProperlyCloseFilesMixin):
    """
    Check that sockets created by L{IReactorTCP.connectTCP} are released once
    the connection using them has been closed.
    """
    def createServer(self, address, portNumber, factory):
        """
        Listen with L{IReactorTCP.listenTCP} on C{portNumber}, bound to
        interface C{address}, serving connections with C{factory}.
        """
        listeningPort = reactor.listenTCP(portNumber, factory, interface=address)
        return listeningPort


    def connectClient(self, address, portNumber, clientCreator):
        """
        Connect C{clientCreator} to C{address}:C{portNumber} with
        L{IReactorTCP.connectTCP}.
        """
        connecting = clientCreator.connectTCP(address, portNumber)
        return connecting


    def getHandleExceptionType(self):
        """
        A write to the low-level socket object after it has been closed is
        expected to raise L{socket.error}.
        """
        return socket.error
1244
1245
1246
class WiredForDeferreds(policies.ProtocolWrapper):
    """
    Protocol wrapper which reports connection setup and teardown through its
    factory's C{onConnect} and C{onDisconnect} Deferreds.  Each is fired with
    C{None} after the corresponding L{policies.ProtocolWrapper} method has
    run.

    The inherited L{policies.ProtocolWrapper} initializer is used unchanged.
    """
    def connectionMade(self):
        policies.ProtocolWrapper.connectionMade(self)
        self.factory.onConnect.callback(None)

    def connectionLost(self, reason):
        policies.ProtocolWrapper.connectionLost(self, reason)
        self.factory.onDisconnect.callback(None)
1258
1259
1260
class WiredFactory(policies.WrappingFactory):
    """
    Wrapping factory whose L{WiredForDeferreds} protocols fire C{onConnect}
    when a connection is made and C{onDisconnect} when it is lost.

    Each Deferred is created once per factory, so presumably a WiredFactory
    is meant to serve a single connection.
    """
    protocol = WiredForDeferreds

    def __init__(self, wrappedFactory):
        policies.WrappingFactory.__init__(self, wrappedFactory)
        self.onConnect, self.onDisconnect = defer.Deferred(), defer.Deferred()
1268
1269
1270
class AddressTestCase(unittest.TestCase):
    """
    Tests for address-related interactions with client and server protocols.
    """
    def setUp(self):
        """
        Create a port and a connected client/server pair, which tests use to
        check factory behavior related to addresses.

        @return: A L{defer.Deferred} which fires once both the client and the
            server protocols have had their connection-made callback.
        """
        class RememberingWrapper(protocol.ClientFactory):
            """
            Wrapper factory which records every address passed to its
            L{buildProtocol} and delegates protocol creation to another
            factory.

            @ivar addresses: The objects passed to buildProtocol, in order.
            @ivar factory: The wrapped factory that creates the protocols.
            """
            def __init__(self, factory):
                self.addresses = []
                self.factory = factory

            # Only buildProtocol is forwarded; doStart, doStop, etc. are not
            # needed for these tests.
            def buildProtocol(self, addr):
                """
                Record C{addr} in C{self.addresses}, then let C{self.factory}
                build the protocol.
                """
                self.addresses.append(addr)
                return self.factory.buildProtocol(addr)

        # Server side: deliver connect/disconnect notification and record the
        # address passed to buildProtocol.
        self.server = MyServerFactory()
        self.serverConnMade = defer.Deferred()
        self.serverConnLost = defer.Deferred()
        self.server.protocolConnectionMade = self.serverConnMade
        self.server.protocolConnectionLost = self.serverConnLost
        # RememberingWrapper is a ClientFactory, and ClientFactory is-a
        # ServerFactory, so it can be used for listening too.
        self.serverWrapper = RememberingWrapper(self.server)

        # Client side: the same arrangement.
        self.client = MyClientFactory()
        self.clientConnMade = defer.Deferred()
        self.clientConnLost = defer.Deferred()
        self.client.protocolConnectionMade = self.clientConnMade
        self.client.protocolConnectionLost = self.clientConnLost
        self.clientWrapper = RememberingWrapper(self.client)

        self.port = reactor.listenTCP(
            0, self.serverWrapper, interface='127.0.0.1')
        listening = self.port.getHost()
        self.connector = reactor.connectTCP(
            listening.host, listening.port, self.clientWrapper)

        return defer.gatherResults([self.serverConnMade, self.clientConnMade])


    def tearDown(self):
        """
        Disconnect the client/server pair and stop the port created in
        L{setUp}, waiting for all three to finish.
        """
        self.connector.disconnect()
        portStopped = defer.maybeDeferred(self.port.stopListening)
        return defer.gatherResults(
            [self.serverConnLost, self.clientConnLost, portStopped])


    def test_buildProtocolClient(self):
        """
        L{ClientFactory.buildProtocol} should be invoked with the address of
        the server to which a connection has been established, which should
        be the same as the address reported by the C{getHost} method of the
        transport of the server protocol and as the C{getPeer} method of the
        transport of the client protocol.
        """
        serverHost = self.server.protocol.transport.getHost()
        clientPeer = self.client.protocol.transport.getPeer()

        recorded = self.clientWrapper.addresses
        for observed in (serverHost, clientPeer):
            self.assertEqual(
                recorded,
                [IPv4Address('TCP', observed.host, observed.port)])
1360
1361
1362
class LargeBufferWriterProtocol(protocol.Protocol):
    """
    Protocol which writes C{factory.len} bytes of C{'X'} in one call to
    C{transport.write}, sets C{factory.done}, and then closes its connection.
    """

    # Win32 sockets cannot handle single huge chunks of bytes.  Write one
    # massive string to make sure Twisted deals with this fact.

    def connectionMade(self):
        # The size comes from the factory; LargeBufferTestCase sets it to
        # 60MB (its datalen).
        payload = 'X' * self.factory.len
        self.transport.write(payload)
        self.factory.done = 1
        self.transport.loseConnection()
1373
class LargeBufferReaderProtocol(protocol.Protocol):
    """
    Protocol which counts received bytes in C{factory.len} and sets
    C{factory.done} once the connection is lost.
    """
    def dataReceived(self, data):
        factory = self.factory
        factory.len = factory.len + len(data)

    def connectionLost(self, reason):
        self.factory.done = 1
1379
class LargeBufferReaderClientFactory(protocol.ClientFactory):
    """
    Client factory whose L{LargeBufferReaderProtocol} instances add the
    number of received bytes to C{len} and set C{done} when disconnected.

    @ivar protocol: The L{LargeBufferReaderProtocol} most recently built.
    """
    def __init__(self):
        self.done, self.len = 0, 0

    def buildProtocol(self, addr):
        reader = LargeBufferReaderProtocol()
        reader.factory = self
        self.protocol = reader
        return reader
1389
1390
class FireOnClose(policies.ProtocolWrapper):
    """
    Protocol wrapper which, after delegating C{connectionLost} to
    L{policies.ProtocolWrapper}, fires C{factory.deferred} with C{None}.
    """
    def connectionLost(self, reason):
        policies.ProtocolWrapper.connectionLost(self, reason)
        closed = self.factory.deferred
        closed.callback(None)
1398
1399
class FireOnCloseFactory(policies.WrappingFactory):
    """
    Wrapping factory whose L{FireOnClose} protocols fire C{self.deferred}
    when their connection is lost.
    """
    protocol = FireOnClose

    def __init__(self, wrappedFactory):
        policies.WrappingFactory.__init__(self, wrappedFactory)
        # Fired by FireOnClose.connectionLost.
        self.deferred = defer.Deferred()
1406
1407
class LargeBufferTestCase(unittest.TestCase):
    """
    Test that buffering large amounts of data works.
    """

    # Number of bytes the server writes in a single transport.write call.
    datalen = 60*1024*1024

    def testWriter(self):
        """
        A single C{transport.write} of C{datalen} bytes is delivered in full
        to the peer, and afterwards both sides have seen the connection close.
        """
        f = protocol.Factory()
        f.protocol = LargeBufferWriterProtocol
        f.done = 0
        f.problem = 0
        f.len = self.datalen
        wrappedF = FireOnCloseFactory(f)
        p = reactor.listenTCP(0, wrappedF, interface="127.0.0.1")
        self.addCleanup(p.stopListening)
        n = p.getHost().port
        clientF = LargeBufferReaderClientFactory()
        wrappedClientF = FireOnCloseFactory(clientF)
        reactor.connectTCP("127.0.0.1", n, wrappedClientF)

        # Check only after both the server and the client connection closed.
        d = defer.gatherResults([wrappedF.deferred, wrappedClientF.deferred])
        def check(ignored):
            self.assertTrue(f.done, "writer didn't finish, it probably died")
            # assertEqual reports both lengths on failure, which is what the
            # old hand-built "(%d != %d)" message did manually.
            self.assertEqual(clientF.len, self.datalen,
                             "client didn't receive all the data it expected")
            self.assertTrue(clientF.done,
                            "client didn't see connection dropped")
        return d.addCallback(check)
1436
1437
class MyHCProtocol(AccumulatingProtocol):
    """
    L{AccumulatingProtocol} which also accepts half-close notifications.

    C{readConnectionLost} and C{writeConnectionLost} each record which half
    of the connection went away.  Once both halves are gone,
    C{self.connectionLost(None)} is called so the ordinary notification
    logic inherited from L{AccumulatingProtocol} runs.
    """

    implements(IHalfCloseableProtocol)

    readHalfClosed = False
    writeHalfClosed = False

    def readConnectionLost(self):
        self.readHalfClosed = True
        self._connectionLostIfFullyClosed()

    def writeConnectionLost(self):
        self.writeHalfClosed = True
        self._connectionLostIfFullyClosed()

    def _connectionLostIfFullyClosed(self):
        """
        Call C{connectionLost(None)} once both the read half and the write
        half have been reported closed.
        """
        if self.readHalfClosed and self.writeHalfClosed:
            self.connectionLost(None)
1456
1457
class MyHCFactory(protocol.ServerFactory):
    """
    Server factory which builds L{MyHCProtocol} instances and remembers the
    most recent one as C{self.protocol}.

    @ivar called: How many times C{buildProtocol} has been invoked.
    """

    called = 0
    # Not read by this class; kept as a class-level default attribute.
    protocolConnectionMade = None

    def buildProtocol(self, addr):
        self.called = self.called + 1
        built = MyHCProtocol()
        built.factory = self
        self.protocol = built
        return built
1469
1470
class HalfCloseTestCase(unittest.TestCase):
    """
    Test half-closing connections.

    Both ends use L{MyHCProtocol}, which accepts half-close notifications.
    """

    def setUp(self):
        """
        Listen on a loopback port with a L{MyHCFactory} and wait until the
        port reports itself connected.  Then connect a L{MyHCProtocol} client
        and continue in C{_setUp}.
        """
        self.f = f = MyHCFactory()
        self.p = p = reactor.listenTCP(0, f, interface="127.0.0.1")
        self.addCleanup(p.stopListening)
        d = loopUntil(lambda :p.connected)

        self.cf = protocol.ClientCreator(reactor, MyHCProtocol)

        d.addCallback(lambda _: self.cf.connectTCP(p.getHost().host,
                                                   p.getHost().port))
        d.addCallback(self._setUp)
        return d

    def _setUp(self, client):
        """
        Remember the connected client and attach a Deferred to its
        C{closedDeferred} (kept as C{clientProtoConnectionLost}).  Then wait
        until the server factory has built its protocol.
        """
        self.client = client
        self.clientProtoConnectionLost = self.client.closedDeferred = defer.Deferred()
        self.assertEqual(self.client.transport.connected, 1)
        # Wait for the server to notice there is a connection, too.
        return loopUntil(lambda: getattr(self.f, 'protocol', None) is not None)

    def tearDown(self):
        """
        Close the client connection and stop the port, wait for the client's
        C{closedDeferred}, then close the server side in C{_tearDown}.
        """
        self.assertEqual(self.client.closed, 0)
        self.client.transport.loseConnection()
        d = defer.maybeDeferred(self.p.stopListening)
        d.addCallback(lambda ign: self.clientProtoConnectionLost)
        d.addCallback(self._tearDown)
        return d

    def _tearDown(self, ignored):
        """
        Explicitly close the server side of the connection and wait until
        its (temporarily replaced) C{connectionLost} has been called.
        """
        self.assertEqual(self.client.closed, 1)
        # Because we did a half-close, the server side also needs to be
        # closed explicitly.
        self.assertEqual(self.f.protocol.closed, 0)
        d = defer.Deferred()
        def _connectionLost(reason):
            self.f.protocol.closed = 1
            d.callback(None)
        # Replace the server protocol's connectionLost on the instance so
        # that this method can tell when the server side has closed.
        self.f.protocol.connectionLost = _connectionLost
        self.f.protocol.transport.loseConnection()
        d.addCallback(lambda x:self.assertEqual(self.f.protocol.closed, 1))
        return d

    def testCloseWriteCloser(self):
        """
        After the client writes C{"hello"} and closes its write half, the
        client reports only its write half closed and the server reports its
        read half closed.  Data the client writes after that is neither
        buffered nor delivered; the server has received only C{"hello"}.
        """
        client = self.client
        f = self.f
        t = client.transport

        t.write("hello")
        # Wait until the transport's pending-write buffer has drained.
        d = loopUntil(lambda :len(t._tempDataBuffer) == 0)
        def loseWrite(ignored):
            t.loseWriteConnection()
            return loopUntil(lambda :t._writeDisconnected)
        def check(ignored):
            self.assertEqual(client.closed, False)
            self.assertEqual(client.writeHalfClosed, True)
            self.assertEqual(client.readHalfClosed, False)
            return loopUntil(lambda :f.protocol.readHalfClosed)
        def write(ignored):
            # Writes after the write half is closed must not be buffered.
            w = client.transport.write
            w(" world")
            w("lalala fooled you")
            self.assertEqual(0, len(client.transport._tempDataBuffer))
            self.assertEqual(f.protocol.data, "hello")
            self.assertEqual(f.protocol.closed, False)
            self.assertEqual(f.protocol.readHalfClosed, True)
        return d.addCallback(loseWrite).addCallback(check).addCallback(write)

    def testWriteCloseNotification(self):
        """
        When the server closes its write half, the server protocol is told
        its write half is closed and the client is told its read half is
        closed.  The server's read half stays open.
        """
        f = self.f
        f.protocol.transport.loseWriteConnection()

        d = defer.gatherResults([
            loopUntil(lambda :f.protocol.writeHalfClosed),
            loopUntil(lambda :self.client.readHalfClosed)])
        d.addCallback(lambda _: self.assertEqual(
            f.protocol.readHalfClosed, False))
        return d
1551
1552
class HalfClose2TestCase(unittest.TestCase):
    """
    Test half-closing connections where the protocols are not notified of
    write closes.  The client uses a plain L{AccumulatingProtocol}.
    """

    def setUp(self):
        """
        Listen on a loopback port, connect an L{AccumulatingProtocol} client
        to it, and wait for the server-side protocol to connect as well.
        """
        self.f = f = MyServerFactory()
        self.f.protocolConnectionMade = defer.Deferred()
        self.p = p = reactor.listenTCP(0, f, interface="127.0.0.1")
        # Trial skips tearDown when setUp fails but still runs cleanups, so
        # this makes sure the port is closed even if connecting the client
        # fails.  When tearDown has already stopped the port, this second
        # stopListening call is expected to be a no-op.
        self.addCleanup(p.stopListening)

        # XXX we don't test server side yet since we don't do it yet
        d = protocol.ClientCreator(reactor, AccumulatingProtocol).connectTCP(
            p.getHost().host, p.getHost().port)
        d.addCallback(self._gotClient)
        return d

    def _gotClient(self, client):
        """
        Remember the connected client and return the server's
        connection-made Deferred, so setUp also waits for the server side.
        """
        self.client = client
        # Now wait for the server to catch up - it doesn't matter if this
        # Deferred has already fired and gone away, in that case we'll
        # return None and not wait at all, which is precisely correct.
        return self.f.protocolConnectionMade

    def tearDown(self):
        """
        Close the client connection and stop listening.
        """
        self.client.transport.loseConnection()
        return self.p.stopListening()

    def testNoNotification(self):
        """
        TCP protocols support half-close connections, but not all of them
        support being notified of write closes.  In this case, test that
        half-closing the connection causes the peer's connection to be
        closed.
        """
        self.client.transport.write("hello")
        self.client.transport.loseWriteConnection()
        self.f.protocol.closedDeferred = d = defer.Deferred()
        self.client.closedDeferred = d2 = defer.Deferred()
        d.addCallback(lambda x:
                      self.assertEqual(self.f.protocol.data, 'hello'))
        d.addCallback(lambda x: self.assertEqual(self.f.protocol.closed, True))
        return defer.gatherResults([d, d2])

    def testShutdownException(self):
        """
        If the other side has already closed its connection,
        loseWriteConnection should pass silently.
        """
        self.f.protocol.transport.loseConnection()
        self.client.transport.write("X")
        self.client.transport.loseWriteConnection()
        self.f.protocol.closedDeferred = d = defer.Deferred()
        self.client.closedDeferred = d2 = defer.Deferred()
        d.addCallback(lambda x:
                      self.assertEqual(self.f.protocol.closed, True))
        return defer.gatherResults([d, d2])
1606
1607
class HalfCloseBuggyApplicationTests(unittest.TestCase):
    """
    Test half-closing connections where notification code has bugs.
    """

    def setUp(self):
        """
        Start a loopback server using L{MyHCFactory} and connect a
        L{MyHCProtocol} client to it.

        The returned Deferred fires only after both the server-side
        C{protocolConnectionMade} notification and the client connection
        have happened.
        """
        factory = MyHCFactory()
        factory.protocolConnectionMade = defer.Deferred()
        self.serverFactory = factory

        self.port = reactor.listenTCP(0, factory, interface="127.0.0.1")
        self.addCleanup(self.port.stopListening)

        host = self.port.getHost()
        clientConnected = protocol.ClientCreator(
            reactor, MyHCProtocol).connectTCP(host.host, host.port)

        def rememberClient(clientProtocol):
            self.clientProtocol = clientProtocol

        clientConnected.addCallback(rememberClient)
        return defer.gatherResults(
            [factory.protocolConnectionMade, clientConnected])


    def aBug(self, *args):
        """
        Stand-in notification callback that illegally raises.

        @raise RuntimeError: always.
        """
        raise RuntimeError("ONO I AM BUGGY CODE")


    def _notificationRaisesTest(self):
        """
        Half-close the client's write side and, once the client protocol's
        connection is lost, check that exactly one L{RuntimeError} was
        logged.
        """
        closed = defer.Deferred()
        self.clientProtocol.closedDeferred = closed
        self.clientProtocol.transport.loseWriteConnection()

        def exactlyOneErrorLogged(ignored):
            self.assertEqual(len(self.flushLoggedErrors(RuntimeError)), 1)

        return closed.addCallback(exactlyOneErrorLogged)


    def test_readNotificationRaises(self):
        """
        If C{readConnectionLost} raises an exception when the transport
        calls it to notify the protocol of that event, the exception should
        be logged and the protocol should be disconnected completely.
        """
        serverProtocol = self.serverFactory.protocol
        serverProtocol.readConnectionLost = self.aBug
        return self._notificationRaisesTest()


    def test_writeNotificationRaises(self):
        """
        If C{writeConnectionLost} raises an exception when the transport
        calls it to notify the protocol of that event, the exception should
        be logged and the protocol should be disconnected completely.
        """
        clientProtocol = self.clientProtocol
        clientProtocol.writeConnectionLost = self.aBug
        return self._notificationRaisesTest()
1674
1675
1676
class LogTestCase(unittest.TestCase):
    """
    Test logging facility of TCP base classes.
    """

    def test_logstrClientSetup(self):
        """
        Check that the log customization of the client transport happens
        once the client is connected.
        """
        server = MyServerFactory()
        client = MyClientFactory()
        client.protocolConnectionMade = defer.Deferred()

        port = reactor.listenTCP(0, server, interface='127.0.0.1')
        self.addCleanup(port.stopListening)

        address = port.getHost()
        connector = reactor.connectTCP(address.host, address.port, client)
        self.addCleanup(connector.disconnect)

        # Before the connection completes, the default is still in place.
        self.assertEqual(connector.transport.logstr, "Uninitialized")

        def checkCustomized(ignored):
            self.assertEqual(connector.transport.logstr,
                             "AccumulatingProtocol,client")

        connected = client.protocolConnectionMade
        connected.addCallback(checkCustomized)
        return connected
1708
1709
1710
class PauseProducingTestCase(unittest.TestCase):
    """
    Test some behaviors of pausing the production of a transport.
    """

    def test_pauseProducingInConnectionMade(self):
        """
        In C{connectionMade} of a client protocol, C{pauseProducing} used to be
        ignored: this test is here to ensure it's not ignored.
        """
        server = MyServerFactory()
        client = MyClientFactory()
        client.protocolConnectionMade = defer.Deferred()

        port = reactor.listenTCP(0, server, interface='127.0.0.1')
        self.addCleanup(port.stopListening)

        address = port.getHost()
        connector = reactor.connectTCP(address.host, address.port, client)
        self.addCleanup(connector.disconnect)

        def watched():
            # Everything the reactor currently monitors, readers and writers.
            return reactor.getReaders() + reactor.getWriters()

        def checkAfterConnectionMade(proto):
            # One reactor turn later the transport must still be unmonitored.
            self.assertNotIn(proto.transport, watched())

        def checkInConnectionMade(proto):
            transport = proto.transport
            # The freshly connected transport is already being monitored.
            self.assertIn(transport, watched())
            transport.pauseProducing()
            self.assertNotIn(transport, watched())
            later = defer.Deferred()
            later.addCallback(checkAfterConnectionMade)
            reactor.callLater(0, later.callback, proto)
            return later

        client.protocolConnectionMade.addCallback(checkInConnectionMade)
        return client.protocolConnectionMade

    if not interfaces.IReactorFDSet.providedBy(reactor):
        test_pauseProducingInConnectionMade.skip = (
            "Reactor not providing IReactorFDSet")
1755
1756
1757
class CallBackOrderTestCase(unittest.TestCase):
    """
    Test the order of reactor callbacks
    """

    def test_loseOrder(self):
        """
        Check that Protocol.connectionLost is called before factory's
        clientConnectionLost.

        The order is recorded as each notification actually fires.
        C{defer.gatherResults} reports results in the order its input
        Deferreds were passed, not the order in which they fired, so its
        result list alone cannot reveal the callback order.
        """
        server = MyServerFactory()
        server.protocolConnectionMade = (defer.Deferred()
                .addCallback(lambda proto: self.addCleanup(
                             proto.transport.loseConnection)))

        client = MyClientFactory()
        protocolLost = client.protocolConnectionLost = defer.Deferred()
        client.protocolConnectionMade = defer.Deferred()

        def _cbCM(res):
            """
            protocol.connectionMade callback
            """
            reactor.callLater(0, client.protocol.transport.loseConnection)

        client.protocolConnectionMade.addCallback(_cbCM)

        port = reactor.listenTCP(0, server, interface='127.0.0.1')
        self.addCleanup(port.stopListening)

        connector = reactor.connectTCP(
            port.getHost().host, port.getHost().port, client)
        self.addCleanup(connector.disconnect)

        # Names of the notifications, appended in the order they really fire.
        fired = []

        def _cbCCL(res):
            """
            factory.clientConnectionLost callback
            """
            fired.append('CCL')
            return 'CCL'

        def _cbCL(res):
            """
            protocol.connectionLost callback
            """
            fired.append('CL')
            return 'CL'

        def _cbGather(res):
            # Assert on the recorded firing order, not on the gathered
            # results, whose order simply mirrors the input list below.
            self.assertEqual(fired, ['CL', 'CCL'])

        d = defer.gatherResults([
                protocolLost.addCallback(_cbCL),
                client.deferred.addCallback(_cbCCL)])
        return d.addCallback(_cbGather)
1811
1812
1813
try:
    import resource
except ImportError:
    pass
else:
    # Run ProperlyCloseFilesTestCase for more rounds than the soft limit on
    # open file descriptors.  Presumably this is so that a per-round
    # descriptor leak would hit the limit; confirm against that test case.
    softLimit = resource.getrlimit(resource.RLIMIT_NOFILE)[0]
    # An unlimited soft limit (RLIM_INFINITY) gives no meaningful round
    # count: it is -1 on some platforms and enormous on others.  In that
    # case keep the test case's own default.
    if softLimit != resource.RLIM_INFINITY:
        numRounds = softLimit + 10
        ProperlyCloseFilesTestCase.numberRounds = numRounds