1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
"""
Tests for L{twisted.protocols.loopback}.
"""
8 from zope.interface import implements
10 from twisted.trial import unittest
11 from twisted.trial.util import suppress as SUPPRESS
12 from twisted.protocols import basic, loopback
13 from twisted.internet import defer
14 from twisted.internet.protocol import Protocol
15 from twisted.internet.defer import Deferred
16 from twisted.internet.interfaces import IAddress, IPushProducer, IPullProducer
17 from twisted.internet import reactor, interfaces
class SimpleProtocol(basic.LineReceiver):
    """
    Line-based protocol that records what happens to it so tests can inspect
    the connection after the loopback has finished.

    @ivar conn: L{defer.Deferred} fired with C{None} from C{connectionMade}.
    @ivar lines: Every line passed to C{lineReceived}, in arrival order.
    @ivar connLost: Every C{reason} passed to C{connectionLost}.
    """

    def __init__(self):
        self.conn = defer.Deferred()
        self.lines, self.connLost = [], []

    def connectionLost(self, reason):
        self.connLost.append(reason)

    def lineReceived(self, line):
        self.lines.append(line)

    def connectionMade(self):
        self.conn.callback(None)
class DoomProtocol(SimpleProtocol):
    """
    A L{SimpleProtocol} that answers each of the first three lines it
    receives with C{"Hello <n>"} and drops the connection as soon as the line
    it has just recorded is C{"Hello 3"}.
    """

    # Count of lines received so far. Incrementing it creates a per-instance
    # value that shadows this class-level default.
    i = 0

    def lineReceived(self, line):
        self.i = self.i + 1
        if self.i <= 3:
            # The connection should already be closed before a fourth line
            # arrives; this guard guarantees "Hello 4" is never sent anyway.
            self.sendLine("Hello %d" % (self.i,))
        SimpleProtocol.lineReceived(self, line)
        if self.lines[-1] == "Hello 3":
            self.transport.loseConnection()
class LoopbackTestCaseMixin:
    """
    Tests shared by every loopback implementation.

    Concrete subclasses set C{loopbackFunc}. It is called as
    C{loopbackFunc(server, client)}, and its result is wrapped with
    L{defer.maybeDeferred}, so it may or may not return a L{defer.Deferred}.
    """

    def testRegularFunction(self):
        server = SimpleProtocol()
        client = SimpleProtocol()

        def sendALine(result):
            server.sendLine("THIS IS LINE ONE!")
            server.transport.loseConnection()
        server.conn.addCallback(sendALine)

        def verify(ignored):
            self.assertEqual(client.lines, ["THIS IS LINE ONE!"])
            self.assertEqual(len(server.connLost), 1)
            self.assertEqual(len(client.connLost), 1)

        # addCallback returns the same Deferred, so it can be chained.
        return defer.maybeDeferred(
            self.loopbackFunc, server, client).addCallback(verify)

    def testSneakyHiddenDoom(self):
        server = DoomProtocol()
        client = DoomProtocol()

        def sendALine(result):
            server.sendLine("DOOM LINE")
        server.conn.addCallback(sendALine)

        def verify(ignored):
            self.assertEqual(server.lines, ['Hello 1', 'Hello 2', 'Hello 3'])
            self.assertEqual(
                client.lines,
                ['DOOM LINE', 'Hello 1', 'Hello 2', 'Hello 3'])
            self.assertEqual(len(server.connLost), 1)
            self.assertEqual(len(client.connLost), 1)

        return defer.maybeDeferred(
            self.loopbackFunc, server, client).addCallback(verify)
class LoopbackAsyncTestCase(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Tests for L{loopback.loopbackAsync}: the shared L{LoopbackTestCaseMixin}
    tests plus transport, producer and pump-policy behaviour of the
    in-memory loopback.
    """
    loopbackFunc = staticmethod(loopback.loopbackAsync)


    def test_makeConnection(self):
        """
        Test that the client and server protocol both have makeConnection
        invoked on them by loopbackAsync.
        """
        class TestProtocol(Protocol):
            transport = None
            def makeConnection(self, transport):
                self.transport = transport

        server = TestProtocol()
        client = TestProtocol()
        loopback.loopbackAsync(server, client)
        self.assertNotEqual(client.transport, None)
        self.assertNotEqual(server.transport, None)


    def _hostpeertest(self, get, testServer):
        """
        Test one of the permutations of client/server host/peer.

        @param get: Name of the transport method to call (C{"getHost"} or
            C{"getPeer"}). Its result must provide L{IAddress}.
        @param testServer: If true, inspect the server's transport;
            otherwise inspect the client's.
        @return: A L{Deferred} that fires once the inspected protocol has
            been connected and the address has been checked.
        """
        class TestProtocol(Protocol):
            def makeConnection(self, transport):
                Protocol.makeConnection(self, transport)
                self.onConnection.callback(transport)

        if testServer:
            server = TestProtocol()
            d = server.onConnection = Deferred()
            client = Protocol()
        else:
            server = Protocol()
            client = TestProtocol()
            d = client.onConnection = Deferred()

        loopback.loopbackAsync(server, client)

        def connected(transport):
            host = getattr(transport, get)()
            self.assertTrue(IAddress.providedBy(host))

        return d.addCallback(connected)


    def test_serverHost(self):
        """
        Test that the server gets a transport with a properly functioning
        implementation of L{ITransport.getHost}.
        """
        return self._hostpeertest("getHost", True)


    def test_serverPeer(self):
        """
        Like C{test_serverHost} but for L{ITransport.getPeer}
        """
        return self._hostpeertest("getPeer", True)


    def test_clientHost(self, get="getHost"):
        """
        Test that the client gets a transport with a properly functioning
        implementation of L{ITransport.getHost}.

        @param get: Name of the transport method to check. Defaults to
            C{"getHost"}, which is the value used when trial runs the test.
        """
        return self._hostpeertest(get, False)


    def test_clientPeer(self):
        """
        Like C{test_clientHost} but for L{ITransport.getPeer}.
        """
        return self._hostpeertest("getPeer", False)


    def _greetingtest(self, write, testServer):
        """
        Test one of the permutations of write/writeSequence client/server.

        @param write: Name of the transport method the greeting side uses to
            send C{"bytes"}: C{"write"} or C{"writeSequence"}.
        @param testServer: If true, the server greets and the client checks
            what it receives; otherwise the roles are reversed.
        @return: A L{Deferred} that fires when the receiving side has
            accumulated exactly C{"bytes"}.
        """
        class GreeteeProtocol(Protocol):
            buffered = ""
            def dataReceived(self, data):
                self.buffered += data
                if self.buffered == "bytes":
                    self.received.callback(None)

        class GreeterProtocol(Protocol):
            def connectionMade(self):
                # writeSequence takes an iterable of strings. A bare string
                # would only "work" by being iterated character by character.
                if write == "writeSequence":
                    greeting = ["bytes"]
                else:
                    greeting = "bytes"
                getattr(self.transport, write)(greeting)

        if testServer:
            server = GreeterProtocol()
            client = GreeteeProtocol()
            d = client.received = Deferred()
        else:
            server = GreeteeProtocol()
            d = server.received = Deferred()
            client = GreeterProtocol()

        loopback.loopbackAsync(server, client)
        return d


    def test_clientGreeting(self):
        """
        Test that on a connection where the client speaks first, the server
        receives the bytes sent by the client.
        """
        return self._greetingtest("write", False)


    def test_clientGreetingSequence(self):
        """
        Like C{test_clientGreeting}, but use C{writeSequence} instead of
        C{write} to issue the greeting.
        """
        return self._greetingtest("writeSequence", False)


    def test_serverGreeting(self, write="write"):
        """
        Test that on a connection where the server speaks first, the client
        receives the bytes sent by the server.

        @param write: Name of the transport method used for the greeting.
            Defaults to C{"write"}, which is the value used when trial runs
            the test.
        """
        return self._greetingtest(write, True)


    def test_serverGreetingSequence(self):
        """
        Like C{test_serverGreeting}, but use C{writeSequence} instead of
        C{write} to issue the greeting.
        """
        return self._greetingtest("writeSequence", True)


    def _producertest(self, producerClass):
        """
        Connect a protocol that drives a C{producerClass} instance over a
        loopback transport, and wait for the peer to receive everything.

        @param producerClass: Callable taking a list of strings to write and
            returning an object with a C{start(consumer)} method. The list is
            a private copy, so the producer may consume it.
        @return: A L{Deferred} firing with C{(client, server)} once the
            client has received C{"0"} through C{"9"} concatenated.
        """
        # A real list (not a lazy map): it is copied for the producer and
        # also joined to build the expected payload.
        toProduce = [str(i) for i in range(10)]
        expected = "".join(toProduce)

        class ProducingProtocol(Protocol):
            def connectionMade(self):
                self.producer = producerClass(list(toProduce))
                self.producer.start(self.transport)

        class ReceivingProtocol(Protocol):
            buffered = ""
            def dataReceived(self, data):
                self.buffered += data
                if self.buffered == expected:
                    self.received.callback((client, server))

        server = ProducingProtocol()
        client = ReceivingProtocol()
        client.received = Deferred()

        loopback.loopbackAsync(server, client)
        return client.received


    def test_pushProducer(self):
        """
        Test a push producer registered against a loopback transport.
        """
        class PushProducer(object):
            implements(IPushProducer)
            resumed = False

            def __init__(self, toProduce):
                self.toProduce = toProduce

            def resumeProducing(self):
                self.resumed = True

            def start(self, consumer):
                self.consumer = consumer
                consumer.registerProducer(self, True)
                self._produceAndSchedule()

            def _produceAndSchedule(self):
                # Write one chunk and reschedule with callLater(0, ...), then
                # unregister once the list is exhausted.
                if self.toProduce:
                    self.consumer.write(self.toProduce.pop(0))
                    reactor.callLater(0, self._produceAndSchedule)
                else:
                    self.consumer.unregisterProducer()
        d = self._producertest(PushProducer)

        def finished(results):
            # Unpack explicitly: tuple parameters are Python 2-only syntax.
            client, server = results
            self.assertFalse(
                server.producer.resumed,
                "Streaming producer should not have been resumed.")
        d.addCallback(finished)
        return d


    def test_pullProducer(self):
        """
        Test a pull producer registered against a loopback transport.
        """
        class PullProducer(object):
            implements(IPullProducer)

            def __init__(self, toProduce):
                self.toProduce = toProduce

            def start(self, consumer):
                self.consumer = consumer
                self.consumer.registerProducer(self, False)

            def resumeProducing(self):
                self.consumer.write(self.toProduce.pop(0))
                if not self.toProduce:
                    self.consumer.unregisterProducer()
        return self._producertest(PullProducer)


    def test_writeNotReentrant(self):
        """
        L{loopback.loopbackAsync} does not call a protocol's C{dataReceived}
        method while that protocol's transport's C{write} method is higher up
        on the stack.
        """
        class Server(Protocol):
            def dataReceived(self, data):
                self.transport.write("bytes")

        class Client(Protocol):
            ready = False

            def connectionMade(self):
                reactor.callLater(0, self.go)

            def go(self):
                self.transport.write("foo")
                # Set only after write() returns. A re-entrant dataReceived
                # call made from inside write() would still see False.
                self.ready = True

            def dataReceived(self, data):
                self.wasReady = self.ready
                self.transport.loseConnection()

        server = Server()
        client = Client()
        # loopbackAsync takes its arguments as (server, client).
        d = loopback.loopbackAsync(server, client)
        def cbFinished(ignored):
            self.assertTrue(client.wasReady)
        d.addCallback(cbFinished)
        return d


    def test_pumpPolicy(self):
        """
        The callable passed as the value for the C{pumpPolicy} parameter to
        L{loopbackAsync} is called with a L{_LoopbackQueue} of pending bytes
        and a protocol to which they should be delivered.
        """
        pumpCalls = []
        def dummyPolicy(queue, target):
            chunks = []
            while queue:
                chunks.append(queue.get())
            pumpCalls.append((target, chunks))

        client = Protocol()
        server = Protocol()

        finished = loopback.loopbackAsync(server, client, dummyPolicy)
        self.assertEqual(pumpCalls, [])

        client.transport.write("foo")
        client.transport.write("bar")
        server.transport.write("baz")
        server.transport.write("quux")
        server.transport.loseConnection()

        def cbComplete(ignored):
            self.assertEqual(
                pumpCalls,
                # The order here is somewhat arbitrary.  The implementation
                # happens to always deliver data to the client first.
                [(client, ["baz", "quux", None]),
                 (server, ["foo", "bar"])])
        finished.addCallback(cbComplete)
        return finished


    def test_identityPumpPolicy(self):
        """
        L{identityPumpPolicy} is a pump policy which calls the target's
        C{dataReceived} method once for each string in the queue passed to it.
        """
        received = []
        client = Protocol()
        client.dataReceived = received.append
        queue = loopback._LoopbackQueue()
        queue.put("foo")
        queue.put("bar")
        queue.put(None)

        loopback.identityPumpPolicy(queue, client)

        self.assertEqual(received, ["foo", "bar"])


    def test_collapsingPumpPolicy(self):
        """
        L{collapsingPumpPolicy} is a pump policy which calls the target's
        C{dataReceived} only once with all of the strings in the queue passed
        to it joined together.
        """
        received = []
        client = Protocol()
        client.dataReceived = received.append
        queue = loopback._LoopbackQueue()
        queue.put("foo")
        queue.put("bar")
        queue.put(None)

        loopback.collapsingPumpPolicy(queue, client)

        self.assertEqual(received, ["foobar"])
class LoopbackTCPTestCase(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Run the L{LoopbackTestCaseMixin} tests through L{loopback.loopbackTCP}.
    """

    loopbackFunc = staticmethod(loopback.loopbackTCP)
class LoopbackUNIXTestCase(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Run the L{LoopbackTestCaseMixin} tests through L{loopback.loopbackUNIX}.
    """

    loopbackFunc = staticmethod(loopback.loopbackUNIX)

    # Mark the tests as skipped when the running reactor cannot be adapted
    # to IReactorUNIX.
    if interfaces.IReactorUNIX(reactor, None) is None:
        skip = "Current reactor does not support UNIX sockets"