Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / test / test_loopback.py
1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
3
4 """
5 Test case for L{twisted.protocols.loopback}.
6 """
7
8 from zope.interface import implements
9
10 from twisted.trial import unittest
11 from twisted.trial.util import suppress as SUPPRESS
12 from twisted.protocols import basic, loopback
13 from twisted.internet import defer
14 from twisted.internet.protocol import Protocol
15 from twisted.internet.defer import Deferred
16 from twisted.internet.interfaces import IAddress, IPushProducer, IPullProducer
17 from twisted.internet import reactor, interfaces
18
19
class SimpleProtocol(basic.LineReceiver):
    """
    Line-based protocol which records everything that happens to it so the
    tests can inspect it afterwards.

    @ivar conn: L{Deferred} fired with C{None} from C{connectionMade}.
    @ivar lines: every line passed to C{lineReceived}, in arrival order.
    @ivar connLost: every C{reason} passed to C{connectionLost}.
    """
    def __init__(self):
        self.connLost = []
        self.lines = []
        self.conn = defer.Deferred()

    def connectionMade(self):
        # Let the test know the transport is now attached.
        connected = self.conn
        connected.callback(None)

    def lineReceived(self, line):
        received = self.lines
        received.append(line)

    def connectionLost(self, reason):
        lost = self.connLost
        lost.append(reason)
34
35
class DoomProtocol(SimpleProtocol):
    """
    A L{SimpleProtocol} which answers each of the first three lines it
    receives with C{"Hello <n>"}. It closes its transport once the most
    recently recorded line is C{"Hello 3"}.
    """
    i = 0
    def lineReceived(self, line):
        self.i = self.i + 1
        count = self.i
        # By the time a fourth line could arrive the connection should
        # already be closing; the guard makes sure "Hello 4" is never sent
        # regardless.
        if count <= 3:
            self.sendLine("Hello %d" % (count,))
        SimpleProtocol.lineReceived(self, line)
        latest = self.lines[-1]
        if latest == "Hello 3":
            self.transport.loseConnection()
47
48
class LoopbackTestCaseMixin:
    """
    Tests shared by every loopback implementation.

    Concrete subclasses provide C{loopbackFunc}, a callable which connects
    a C{(server, client)} pair of protocols.
    """
    def _runAndCheck(self, server, client, check):
        """
        Connect C{server} to C{client} with C{self.loopbackFunc} and run
        C{check} on the result once the loopback call has completed.

        @return: the L{Deferred} from C{maybeDeferred} with C{check} added.
        """
        finished = defer.maybeDeferred(self.loopbackFunc, server, client)
        finished.addCallback(check)
        return finished

    def testRegularFunction(self):
        server = SimpleProtocol()
        client = SimpleProtocol()

        def sendOneLineAndHangUp(ignored):
            server.sendLine("THIS IS LINE ONE!")
            server.transport.loseConnection()
        server.conn.addCallback(sendOneLineAndHangUp)

        def verify(ignored):
            self.assertEqual(client.lines, ["THIS IS LINE ONE!"])
            self.assertEqual(len(server.connLost), 1)
            self.assertEqual(len(client.connLost), 1)
        return self._runAndCheck(server, client, verify)

    def testSneakyHiddenDoom(self):
        server = DoomProtocol()
        client = DoomProtocol()

        def sendDoomLine(ignored):
            server.sendLine("DOOM LINE")
        server.conn.addCallback(sendDoomLine)

        def verify(ignored):
            self.assertEqual(server.lines, ['Hello 1', 'Hello 2', 'Hello 3'])
            self.assertEqual(
                client.lines, ['DOOM LINE', 'Hello 1', 'Hello 2', 'Hello 3'])
            self.assertEqual(len(server.connLost), 1)
            self.assertEqual(len(client.connLost), 1)
        return self._runAndCheck(server, client, verify)
83
84
85
class LoopbackAsyncTestCase(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Tests for L{loopback.loopbackAsync}. In addition to the shared
    L{LoopbackTestCaseMixin} tests, these check:
      - transport addresses,
      - greetings in both directions,
      - producer registration,
      - write reentrancy,
      - pump policies.
    """
    loopbackFunc = staticmethod(loopback.loopbackAsync)


    def test_makeConnection(self):
        """
        Test that the client and server protocol both have makeConnection
        invoked on them by loopbackAsync.
        """
        class TestProtocol(Protocol):
            transport = None
            def makeConnection(self, transport):
                self.transport = transport

        server = TestProtocol()
        client = TestProtocol()
        loopback.loopbackAsync(server, client)
        self.assertNotEqual(client.transport, None)
        self.assertNotEqual(server.transport, None)


    def _hostpeertest(self, get, testServer):
        """
        Test one of the permutations of client/server host/peer.

        @param get: name of the transport method to call, C{"getHost"} or
            C{"getPeer"}.
        @param testServer: if true, inspect the server's transport;
            otherwise inspect the client's.
        @return: a L{Deferred} which fires once the chosen transport's
            address has been checked.
        """
        class TestProtocol(Protocol):
            def makeConnection(self, transport):
                Protocol.makeConnection(self, transport)
                # Hand the transport to the test as soon as it is attached.
                self.onConnection.callback(transport)

        if testServer:
            server = TestProtocol()
            d = server.onConnection = Deferred()
            client = Protocol()
        else:
            server = Protocol()
            client = TestProtocol()
            d = client.onConnection = Deferred()

        loopback.loopbackAsync(server, client)

        def connected(transport):
            host = getattr(transport, get)()
            self.assertTrue(
                IAddress.providedBy(host),
                "%r does not provide IAddress" % (host,))

        return d.addCallback(connected)


    def test_serverHost(self):
        """
        Test that the server gets a transport with a properly functioning
        implementation of L{ITransport.getHost}.
        """
        return self._hostpeertest("getHost", True)


    def test_serverPeer(self):
        """
        Like C{test_serverHost} but for L{ITransport.getPeer}.
        """
        return self._hostpeertest("getPeer", True)


    def test_clientHost(self, get="getHost"):
        """
        Test that the client gets a transport with a properly functioning
        implementation of L{ITransport.getHost}.

        @param get: name of the transport method to check. When called with
            no arguments, as a test runner does, this is C{"getHost"}.
        """
        return self._hostpeertest(get, False)


    def test_clientPeer(self):
        """
        Like C{test_clientHost} but for L{ITransport.getPeer}.
        """
        return self._hostpeertest("getPeer", False)


    def _greetingtest(self, write, testServer):
        """
        Test one of the permutations of write/writeSequence client/server.

        @param write: name of the transport method the greeting side uses,
            C{"write"} or C{"writeSequence"}.
        @param testServer: if true, the server sends the greeting and the
            client receives it; otherwise the roles are reversed.
        @return: a L{Deferred} which fires once the receiving side has
            accumulated exactly C{"bytes"}.
        """
        class GreeteeProtocol(Protocol):
            data = ""
            def dataReceived(self, chunk):
                self.data += chunk
                if self.data == "bytes":
                    self.received.callback(None)

        class GreeterProtocol(Protocol):
            def connectionMade(self):
                getattr(self.transport, write)("bytes")

        if testServer:
            server = GreeterProtocol()
            client = GreeteeProtocol()
            d = client.received = Deferred()
        else:
            server = GreeteeProtocol()
            d = server.received = Deferred()
            client = GreeterProtocol()

        loopback.loopbackAsync(server, client)
        return d


    def test_clientGreeting(self):
        """
        Test that on a connection where the client speaks first, the server
        receives the bytes sent by the client.
        """
        return self._greetingtest("write", False)


    def test_clientGreetingSequence(self):
        """
        Like C{test_clientGreeting}, but use C{writeSequence} instead of
        C{write} to issue the greeting.
        """
        return self._greetingtest("writeSequence", False)


    def test_serverGreeting(self, write="write"):
        """
        Test that on a connection where the server speaks first, the client
        receives the bytes sent by the server.

        @param write: name of the transport method used for the greeting.
            When called with no arguments, as a test runner does, this is
            C{"write"}.
        """
        return self._greetingtest(write, True)


    def test_serverGreetingSequence(self):
        """
        Like C{test_serverGreeting}, but use C{writeSequence} instead of
        C{write} to issue the greeting.
        """
        return self._greetingtest("writeSequence", True)


    def _producertest(self, producerClass):
        """
        Send the strings C{"0"} through C{"9"} from a server whose transport
        has an instance of C{producerClass} driving it.

        @param producerClass: callable taking a list of strings to produce.
            It returns an object with a C{start(consumer)} method.
        @return: a L{Deferred} which fires with C{(client, server)} once the
            client's accumulated data equals those strings concatenated in
            order.
        """
        # Build a real list rather than relying on map() returning one.
        # It is copied for the producer and also joined for the expected
        # value.
        toProduce = [str(i) for i in range(10)]
        expected = ''.join(toProduce)

        class ProducingProtocol(Protocol):
            def connectionMade(self):
                self.producer = producerClass(list(toProduce))
                self.producer.start(self.transport)

        class ReceivingProtocol(Protocol):
            data = ""
            def dataReceived(self, chunk):
                self.data += chunk
                if self.data == expected:
                    self.received.callback((client, server))

        server = ProducingProtocol()
        client = ReceivingProtocol()
        client.received = Deferred()

        loopback.loopbackAsync(server, client)
        return client.received


    def test_pushProducer(self):
        """
        Test a push producer registered against a loopback transport.
        """
        class PushProducer(object):
            implements(IPushProducer)
            resumed = False

            def __init__(self, toProduce):
                self.toProduce = toProduce

            def resumeProducing(self):
                self.resumed = True

            def start(self, consumer):
                self.consumer = consumer
                consumer.registerProducer(self, True)
                self._produceAndSchedule()

            def _produceAndSchedule(self):
                # Write one chunk per reactor iteration. Detach from the
                # consumer once everything has been written.
                if self.toProduce:
                    self.consumer.write(self.toProduce.pop(0))
                    reactor.callLater(0, self._produceAndSchedule)
                else:
                    self.consumer.unregisterProducer()
        d = self._producertest(PushProducer)

        def finished(result):
            # Explicit unpacking; tuple parameters were removed by PEP 3113.
            ignoredClient, server = result
            self.assertFalse(
                server.producer.resumed,
                "Streaming producer should not have been resumed.")
        d.addCallback(finished)
        return d


    def test_pullProducer(self):
        """
        Test a pull producer registered against a loopback transport.
        """
        class PullProducer(object):
            implements(IPullProducer)

            def __init__(self, toProduce):
                self.toProduce = toProduce

            def start(self, consumer):
                self.consumer = consumer
                self.consumer.registerProducer(self, False)

            def resumeProducing(self):
                # Produce one chunk per request; unregister after the last.
                self.consumer.write(self.toProduce.pop(0))
                if not self.toProduce:
                    self.consumer.unregisterProducer()
        return self._producertest(PullProducer)


    def test_writeNotReentrant(self):
        """
        L{loopback.loopbackAsync} does not call a protocol's C{dataReceived}
        method while that protocol's transport's C{write} method is higher up
        on the stack.
        """
        class Server(Protocol):
            def dataReceived(self, data):
                self.transport.write("bytes")

        class Client(Protocol):
            ready = False

            def connectionMade(self):
                reactor.callLater(0, self.go)

            def go(self):
                self.transport.write("foo")
                # Suppose write() re-entered dataReceived synchronously.
                # Then wasReady would be recorded before this line runs,
                # i.e. as False.
                self.ready = True

            def dataReceived(self, data):
                self.wasReady = self.ready
                self.transport.loseConnection()


        server = Server()
        client = Client()
        d = loopback.loopbackAsync(client, server)
        def cbFinished(ignored):
            self.assertTrue(client.wasReady)
        d.addCallback(cbFinished)
        return d


    def test_pumpPolicy(self):
        """
        The callable passed as the value for the C{pumpPolicy} parameter to
        L{loopbackAsync} is called with a L{_LoopbackQueue} of pending bytes
        and a protocol to which they should be delivered.
        """
        pumpCalls = []
        def dummyPolicy(queue, target):
            chunks = []
            while queue:
                chunks.append(queue.get())
            pumpCalls.append((target, chunks))

        client = Protocol()
        server = Protocol()

        finished = loopback.loopbackAsync(server, client, dummyPolicy)
        self.assertEqual(pumpCalls, [])

        client.transport.write("foo")
        client.transport.write("bar")
        server.transport.write("baz")
        server.transport.write("quux")
        server.transport.loseConnection()

        def cbComplete(ignored):
            self.assertEqual(
                pumpCalls,
                # The order here is somewhat arbitrary.  The implementation
                # happens to always deliver data to the client first.
                [(client, ["baz", "quux", None]),
                 (server, ["foo", "bar"])])
        finished.addCallback(cbComplete)
        return finished


    def test_identityPumpPolicy(self):
        """
        L{identityPumpPolicy} is a pump policy which calls the target's
        C{dataReceived} method once for each string in the queue passed to it.
        """
        delivered = []
        client = Protocol()
        client.dataReceived = delivered.append
        queue = loopback._LoopbackQueue()
        queue.put("foo")
        queue.put("bar")
        queue.put(None)

        loopback.identityPumpPolicy(queue, client)

        self.assertEqual(delivered, ["foo", "bar"])


    def test_collapsingPumpPolicy(self):
        """
        L{collapsingPumpPolicy} is a pump policy which calls the target's
        C{dataReceived} only once with all of the strings in the queue passed
        to it joined together.
        """
        delivered = []
        client = Protocol()
        client.dataReceived = delivered.append
        queue = loopback._LoopbackQueue()
        queue.put("foo")
        queue.put("bar")
        queue.put(None)

        loopback.collapsingPumpPolicy(queue, client)

        self.assertEqual(delivered, ["foobar"])
408
409
410
class LoopbackTCPTestCase(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Run the shared L{LoopbackTestCaseMixin} tests using
    L{loopback.loopbackTCP} as the loopback function.
    """
    loopbackFunc = staticmethod(loopback.loopbackTCP)
413
414
class LoopbackUNIXTestCase(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Run the shared L{LoopbackTestCaseMixin} tests using
    L{loopback.loopbackUNIX} as the loopback function.
    """
    loopbackFunc = staticmethod(loopback.loopbackUNIX)

    # Evaluated once, when the class body runs. If the installed reactor
    # cannot be adapted to IReactorUNIX, a class-level C{skip} reason is
    # set so these tests are skipped.
    if interfaces.IReactorUNIX(reactor, None) is None:
        skip = "Current reactor does not support UNIX sockets"