Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / internet / test / test_endpoints.py
1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
3 """
4 Test the C{I...Endpoint} implementations that wrap the L{IReactorTCP},
5 L{IReactorSSL}, and L{IReactorUNIX} interfaces found in
6 L{twisted.internet.endpoints}.
7 """
8
9 from errno import EPERM
10 from socket import AF_INET, AF_INET6
11 from zope.interface import implements
12 from zope.interface.verify import verifyObject
13
14 from twisted.trial import unittest
15 from twisted.internet import error, interfaces, defer
16 from twisted.internet import endpoints
17 from twisted.internet.address import IPv4Address, UNIXAddress
18 from twisted.internet.protocol import ClientFactory, Protocol
19 from twisted.test.proto_helpers import (
20     MemoryReactor, RaisingMemoryReactor, StringTransport)
21 from twisted.python.failure import Failure
22 from twisted.python.systemd import ListenFDs
23 from twisted.plugin import getPlugins
24
25 from twisted import plugins
26 from twisted.python.modules import getModule
27 from twisted.python.filepath import FilePath
28
pemPath = getModule("twisted.test").filePath.sibling("server.pem")
casPath = getModule(__name__).filePath.sibling("fake_CAs")

# Escaped spellings of the two paths above, produced by
# endpoints.quoteStringArgument (presumably for embedding in endpoint
# description strings).
escapedPEMPathName, escapedCAsPathName = (
    endpoints.quoteStringArgument(pemPath.path),
    endpoints.quoteStringArgument(casPath.path))

# The SSL tests are optional: if any of these imports is unavailable, record
# a skip reason instead of failing at import time.
try:
    from twisted.test.test_sslverify import makeCertificate
    from twisted.internet.ssl import (
        CertificateOptions, Certificate, KeyPair, PrivateCertificate)
    from OpenSSL.SSL import ContextType
    testCertificate = Certificate.loadPEM(pemPath.getContent())
    testPrivateCertificate = PrivateCertificate.loadPEM(pemPath.getContent())
except ImportError:
    skipSSL = "OpenSSL is required to construct SSL Endpoints"
else:
    skipSSL = False
45
46
class TestProtocol(Protocol):
    """
    Protocol which records everything that happens to it, so that tests can
    inspect it afterwards.

    @ivar data: A C{list} of every string passed to C{dataReceived}, in the
        order received.

    @ivar connectionsLost: A C{list} of every C{reason} passed to
        C{connectionLost}, in the order received.

    @ivar connectionMadeCalls: The number of times C{connectionMade} has been
        called.
    """

    def __init__(self):
        self.data = []
        self.connectionsLost = []
        self.connectionMadeCalls = 0


    def logPrefix(self):
        """
        Return a fixed, recognizable prefix so tests can check that wrappers
        pass it through.
        """
        return "A Test Protocol"


    def connectionMade(self):
        """
        Count this call.
        """
        self.connectionMadeCalls += 1


    def dataReceived(self, data):
        """
        Record C{data} in L{data}.
        """
        self.data.append(data)


    def connectionLost(self, reason):
        """
        Record C{reason} in L{connectionsLost}.
        """
        self.connectionsLost.append(reason)
73
74
75
class TestHalfCloseableProtocol(TestProtocol):
    """
    A Protocol that implements L{IHalfCloseableProtocol} and records whether its
    C{readConnectionLost} and C{writeConnectionLost} methods are called.

    @ivar readLost: A C{bool} indicating whether C{readConnectionLost} has been
        called.

    @ivar writeLost: A C{bool} indicating whether C{writeConnectionLost} has
        been called.
    """
    implements(interfaces.IHalfCloseableProtocol)

    def __init__(self):
        TestProtocol.__init__(self)
        self.readLost = False
        self.writeLost = False


    def readConnectionLost(self):
        """
        Record that C{readConnectionLost} was called.
        """
        self.readLost = True


    def writeConnectionLost(self):
        """
        Record that C{writeConnectionLost} was called.
        """
        self.writeLost = True
101
102
103
class TestFileDescriptorReceiverProtocol(TestProtocol):
    """
    A L{TestProtocol} which also provides L{IFileDescriptorReceiver} and keeps
    a log of the descriptors delivered to it.

    @ivar receivedDescriptors: A C{list} of the file descriptors handed to
        C{fileDescriptorReceived}, in arrival order.  It is created (and
        reset) by C{connectionMade}, so it only exists once the protocol has
        been connected.
    """
    implements(interfaces.IFileDescriptorReceiver)

    def connectionMade(self):
        # Begin an empty log for this connection, then let the base class
        # count the call.
        self.receivedDescriptors = []
        TestProtocol.connectionMade(self)


    def fileDescriptorReceived(self, descriptor):
        log = self.receivedDescriptors
        log.append(descriptor)
121
122
123
class TestFactory(ClientFactory):
    """
    Simple factory to be used both when connecting and listening.  The
    protocols it builds are L{TestProtocol} instances, which record the data,
    connection-made calls and connection-lost reasons they see.
    """

    protocol = TestProtocol
132
133
134
class WrappingFactoryTests(unittest.TestCase):
    """
    Test the behaviour of our ugly implementation detail C{_WrappingFactory}.
    """
    def test_doStart(self):
        """
        L{_WrappingFactory.doStart} passes through to the wrapped factory's
        C{doStart} method, allowing application-specific setup and logging.
        """
        factory = ClientFactory()
        wf = endpoints._WrappingFactory(factory)
        wf.doStart()
        # The wrapped factory's doStart bumps numPorts, so a pass-through is
        # observable here.
        self.assertEqual(1, factory.numPorts)


    def test_doStop(self):
        """
        L{_WrappingFactory.doStop} passes through to the wrapped factory's
        C{doStop} method, allowing application-specific cleanup and logging.
        """
        factory = ClientFactory()
        factory.numPorts = 3
        wf = endpoints._WrappingFactory(factory)
        wf.doStop()
        self.assertEqual(2, factory.numPorts)


    def test_failedBuildProtocol(self):
        """
        An exception raised in C{buildProtocol} of our wrappedFactory
        results in our C{onConnection} errback being fired.
        """

        class BogusFactory(ClientFactory):
            """
            A one off factory whose C{buildProtocol} raises an C{Exception}.
            """

            def buildProtocol(self, addr):
                raise ValueError("My protocol is poorly defined.")


        wf = endpoints._WrappingFactory(BogusFactory())

        wf.buildProtocol(None)

        d = self.assertFailure(wf._onConnection, ValueError)
        d.addCallback(lambda e: self.assertEqual(
                e.args,
                ("My protocol is poorly defined.",)))

        return d


    def test_logPrefixPassthrough(self):
        """
        If the wrapped protocol provides L{ILoggingContext}, whatever is
        returned from the wrapped C{logPrefix} method is returned from
        L{_WrappingProtocol.logPrefix}.
        """
        wf = endpoints._WrappingFactory(TestFactory())
        wp = wf.buildProtocol(None)
        self.assertEqual(wp.logPrefix(), "A Test Protocol")


    def test_logPrefixDefault(self):
        """
        If the wrapped protocol does not provide L{ILoggingContext}, the wrapped
        protocol's class name is returned from L{_WrappingProtocol.logPrefix}.
        """
        class NoProtocol(object):
            pass
        factory = TestFactory()
        factory.protocol = NoProtocol
        wf = endpoints._WrappingFactory(factory)
        wp = wf.buildProtocol(None)
        self.assertEqual(wp.logPrefix(), "NoProtocol")


    def test_wrappedProtocolDataReceived(self):
        """
        The wrapped C{Protocol}'s C{dataReceived} will get called when our
        C{_WrappingProtocol}'s C{dataReceived} gets called.
        """
        wf = endpoints._WrappingFactory(TestFactory())
        p = wf.buildProtocol(None)
        p.makeConnection(None)

        p.dataReceived('foo')
        self.assertEqual(p._wrappedProtocol.data, ['foo'])

        p.dataReceived('bar')
        self.assertEqual(p._wrappedProtocol.data, ['foo', 'bar'])


    def test_wrappedProtocolTransport(self):
        """
        Our transport is properly hooked up to the wrappedProtocol when a
        connection is made.
        """
        wf = endpoints._WrappingFactory(TestFactory())
        p = wf.buildProtocol(None)

        dummyTransport = object()

        p.makeConnection(dummyTransport)

        self.assertEqual(p.transport, dummyTransport)

        self.assertEqual(p._wrappedProtocol.transport, dummyTransport)


    def test_wrappedProtocolConnectionLost(self):
        """
        Our wrappedProtocol's connectionLost method is called when
        L{_WrappingProtocol.connectionLost} is called.
        """
        tf = TestFactory()
        wf = endpoints._WrappingFactory(tf)
        p = wf.buildProtocol(None)

        p.connectionLost("fail")

        self.assertEqual(p._wrappedProtocol.connectionsLost, ["fail"])


    def test_clientConnectionFailed(self):
        """
        Calls to L{_WrappingFactory.clientConnectionFailed} should errback the
        L{_WrappingFactory._onConnection} L{Deferred} with the failure that
        was passed in.
        """
        wf = endpoints._WrappingFactory(TestFactory())
        expectedFailure = Failure(error.ConnectError(string="fail"))

        wf.clientConnectionFailed(
            None,
            expectedFailure)

        errors = []
        def gotError(f):
            errors.append(f)

        wf._onConnection.addErrback(gotError)

        self.assertEqual(errors, [expectedFailure])


    def test_wrappingProtocolFileDescriptorReceiver(self):
        """
        Our L{_WrappingProtocol} should be an L{IFileDescriptorReceiver} if the
        wrapped protocol is.
        """
        connectedDeferred = None
        applicationProtocol = TestFileDescriptorReceiverProtocol()
        wrapper = endpoints._WrappingProtocol(
            connectedDeferred, applicationProtocol)
        self.assertTrue(interfaces.IFileDescriptorReceiver.providedBy(wrapper))
        self.assertTrue(verifyObject(interfaces.IFileDescriptorReceiver, wrapper))


    def test_wrappingProtocolNotFileDescriptorReceiver(self):
        """
        Our L{_WrappingProtocol} does not provide L{IFileDescriptorReceiver} if
        the wrapped protocol doesn't.
        """
        tp = TestProtocol()
        p = endpoints._WrappingProtocol(None, tp)
        self.assertFalse(interfaces.IFileDescriptorReceiver.providedBy(p))


    def test_wrappedProtocolFileDescriptorReceived(self):
        """
        L{_WrappingProtocol.fileDescriptorReceived} calls the wrapped protocol's
        C{fileDescriptorReceived} method.
        """
        wrappedProtocol = TestFileDescriptorReceiverProtocol()
        wrapper = endpoints._WrappingProtocol(
            defer.Deferred(), wrappedProtocol)
        # makeConnection is needed first: the wrapped protocol only creates
        # its receivedDescriptors list in connectionMade.
        wrapper.makeConnection(StringTransport())
        wrapper.fileDescriptorReceived(42)
        self.assertEqual(wrappedProtocol.receivedDescriptors, [42])


    def test_wrappingProtocolHalfCloseable(self):
        """
        Our L{_WrappingProtocol} should be an L{IHalfCloseableProtocol} if the
        wrapped protocol is.
        """
        cd = object()
        hcp = TestHalfCloseableProtocol()
        p = endpoints._WrappingProtocol(cd, hcp)
        self.assertTrue(interfaces.IHalfCloseableProtocol.providedBy(p))


    def test_wrappingProtocolNotHalfCloseable(self):
        """
        Our L{_WrappingProtocol} should not provide L{IHalfCloseableProtocol}
        if the wrapped protocol doesn't.
        """
        tp = TestProtocol()
        p = endpoints._WrappingProtocol(None, tp)
        self.assertFalse(interfaces.IHalfCloseableProtocol.providedBy(p))


    def test_wrappedProtocolReadConnectionLost(self):
        """
        L{_WrappingProtocol.readConnectionLost} should proxy to the wrapped
        protocol's C{readConnectionLost}
        """
        hcp = TestHalfCloseableProtocol()
        p = endpoints._WrappingProtocol(None, hcp)
        p.readConnectionLost()
        self.assertEqual(hcp.readLost, True)


    def test_wrappedProtocolWriteConnectionLost(self):
        """
        L{_WrappingProtocol.writeConnectionLost} should proxy to the wrapped
        protocol's C{writeConnectionLost}
        """
        hcp = TestHalfCloseableProtocol()
        p = endpoints._WrappingProtocol(None, hcp)
        p.writeConnectionLost()
        self.assertEqual(hcp.writeLost, True)
361
362
363
class ClientEndpointTestCaseMixin(object):
    """
    Generic test methods to be mixed into all client endpoint test classes.

    Classes using this mixin must provide:

      - C{createClientEndpoint(reactor, clientFactory, **connectArgs)},
        returning a 3-tuple of the endpoint under test, the arguments
        expected to reach the reactor's connect method, and the expected
        destination address;
      - C{expectedClients(reactor)}, returning the connection attempts the
        reactor recorded;
      - C{assertConnectArgs(receivedArgs, expectedArgs)};
      - C{connectArgs()}, returning non-default keyword arguments for
        C{createClientEndpoint}.
    """
    def retrieveConnectedFactory(self, reactor):
        """
        Retrieve a single factory that has connected using the given reactor.
        (This behavior is valid for TCP and SSL but needs to be overridden for
        UNIX.)

        @param reactor: a L{MemoryReactor}

        @return: The factory from the first recorded connection attempt,
            taken from the third element of its argument tuple.
        """
        return self.expectedClients(reactor)[0][2]


    def test_endpointConnectSuccess(self):
        """
        A client endpoint can connect and returns a deferred that gets called
        back with a protocol instance.
        """
        proto = object()
        mreactor = MemoryReactor()

        clientFactory = object()

        ep, expectedArgs, ignoredDest = self.createClientEndpoint(
            mreactor, clientFactory)

        d = ep.connect(clientFactory)

        receivedProtos = []

        def checkProto(p):
            receivedProtos.append(p)

        d.addCallback(checkProto)

        # Simulate the connection succeeding by firing the wrapping factory's
        # connection Deferred by hand.
        factory = self.retrieveConnectedFactory(mreactor)
        factory._onConnection.callback(proto)
        self.assertEqual(receivedProtos, [proto])

        expectedClients = self.expectedClients(mreactor)

        self.assertEqual(len(expectedClients), 1)
        self.assertConnectArgs(expectedClients[0], expectedArgs)


    def test_endpointConnectFailure(self):
        """
        If an endpoint tries to connect to a non-listening port it gets
        a C{ConnectError} failure.
        """
        expectedError = error.ConnectError(string="Connection Failed")

        mreactor = RaisingMemoryReactor(connectException=expectedError)

        clientFactory = object()

        ep, ignoredArgs, ignoredDest = self.createClientEndpoint(
            mreactor, clientFactory)

        d = ep.connect(clientFactory)

        receivedExceptions = []

        def checkFailure(f):
            receivedExceptions.append(f.value)

        d.addErrback(checkFailure)

        self.assertEqual(receivedExceptions, [expectedError])


    def test_endpointConnectingCancelled(self):
        """
        Calling L{Deferred.cancel} on the L{Deferred} returned from
        L{IStreamClientEndpoint.connect} is errbacked with an expected
        L{ConnectingCancelledError} exception.
        """
        mreactor = MemoryReactor()

        clientFactory = object()

        ep, ignoredArgs, address = self.createClientEndpoint(
            mreactor, clientFactory)

        d = ep.connect(clientFactory)

        receivedFailures = []

        def checkFailure(f):
            receivedFailures.append(f)

        d.addErrback(checkFailure)

        d.cancel()
        # When canceled, the connector will immediately notify its factory that
        # the connection attempt has failed due to a UserError.
        attemptFactory = self.retrieveConnectedFactory(mreactor)
        attemptFactory.clientConnectionFailed(None, Failure(error.UserError()))
        # This should be a feature of MemoryReactor: <http://tm.tl/5630>.

        self.assertEqual(len(receivedFailures), 1)

        failure = receivedFailures[0]

        self.assertIsInstance(failure.value, error.ConnectingCancelledError)
        self.assertEqual(failure.value.address, address)


    def test_endpointConnectNonDefaultArgs(self):
        """
        The endpoint should pass its connectArgs parameter to the reactor's
        connect methods.
        """
        factory = object()

        mreactor = MemoryReactor()

        ep, expectedArgs, ignoredHost = self.createClientEndpoint(
            mreactor, factory,
            **self.connectArgs())

        ep.connect(factory)

        expectedClients = self.expectedClients(mreactor)

        self.assertEqual(len(expectedClients), 1)
        self.assertConnectArgs(expectedClients[0], expectedArgs)
492         self.assertConnectArgs(expectedClients[0], expectedArgs)
493
494
495
class ServerEndpointTestCaseMixin(object):
    """
    Generic test methods to be mixed into all server endpoint test classes.

    Classes using this mixin must provide:

      - C{createServerEndpoint(reactor, factory, **listenArgs)}, returning a
        3-tuple of the endpoint under test, the arguments expected to reach
        the reactor's listen method, and the expected host address;
      - C{expectedServers(reactor)}, returning the listen calls the reactor
        recorded;
      - C{listenArgs()}, returning non-default keyword arguments for
        C{createServerEndpoint}.
    """
    def test_endpointListenSuccess(self):
        """
        An endpoint can listen and returns a deferred that gets called back
        with a port instance.
        """
        mreactor = MemoryReactor()

        factory = object()

        ep, expectedArgs, expectedHost = self.createServerEndpoint(
            mreactor, factory)

        d = ep.listen(factory)

        receivedHosts = []

        def checkPortAndServer(port):
            receivedHosts.append(port.getHost())

        d.addCallback(checkPortAndServer)

        self.assertEqual(receivedHosts, [expectedHost])
        self.assertEqual(self.expectedServers(mreactor), [expectedArgs])


    def test_endpointListenFailure(self):
        """
        When an endpoint tries to listen on an already listening port, a
        C{CannotListenError} failure is errbacked.
        """
        factory = object()
        exception = error.CannotListenError('', 80, factory)
        mreactor = RaisingMemoryReactor(listenException=exception)

        ep, ignoredArgs, ignoredDest = self.createServerEndpoint(
            mreactor, factory)

        d = ep.listen(object())

        receivedExceptions = []

        def checkFailure(f):
            receivedExceptions.append(f.value)

        d.addErrback(checkFailure)

        self.assertEqual(receivedExceptions, [exception])


    def test_endpointListenNonDefaultArgs(self):
        """
        The endpoint should pass its listenArgs parameter to the reactor's
        listen methods.
        """
        factory = object()

        mreactor = MemoryReactor()

        ep, expectedArgs, ignoredHost = self.createServerEndpoint(
            mreactor, factory,
            **self.listenArgs())

        ep.listen(factory)

        expectedServers = self.expectedServers(mreactor)

        self.assertEqual(expectedServers, [expectedArgs])
566         self.assertEqual(expectedServers, [expectedArgs])
567
568
569
class EndpointTestCaseMixin(ServerEndpointTestCaseMixin,
                            ClientEndpointTestCaseMixin):
    """
    Both the server and the client generic endpoint tests, bundled together
    for endpoint families that can listen as well as connect.
    """
575
576
577
class TCP4EndpointsTestCase(EndpointTestCaseMixin, unittest.TestCase):
    """
    Tests for TCP Endpoints.
    """

    def expectedServers(self, reactor):
        """
        @return: List of calls to L{IReactorTCP.listenTCP}
        """
        return reactor.tcpServers


    def expectedClients(self, reactor):
        """
        @return: List of calls to L{IReactorTCP.connectTCP}
        """
        return reactor.tcpClients


    def assertConnectArgs(self, receivedArgs, expectedArgs):
        """
        Compare host, port, timeout, and bindAddress in C{receivedArgs}
        to C{expectedArgs}.  We ignore the factory because we only care
        about what protocol comes out of the
        C{IStreamClientEndpoint.connect} call.

        @param receivedArgs: C{tuple} of (C{host}, C{port}, C{factory},
            C{timeout}, C{bindAddress}) that was passed to
            L{IReactorTCP.connectTCP}.
        @param expectedArgs: C{tuple} of (C{host}, C{port}, C{factory},
            C{timeout}, C{bindAddress}) that we expect to have been passed
            to L{IReactorTCP.connectTCP}.
        """
        (host, port, ignoredFactory, timeout, bindAddress) = receivedArgs
        (expectedHost, expectedPort, _ignoredFactory,
         expectedTimeout, expectedBindAddress) = expectedArgs

        self.assertEqual(host, expectedHost)
        self.assertEqual(port, expectedPort)
        self.assertEqual(timeout, expectedTimeout)
        self.assertEqual(bindAddress, expectedBindAddress)


    def connectArgs(self):
        """
        @return: C{dict} of keyword arguments to pass to connect.
        """
        return {'timeout': 10, 'bindAddress': ('localhost', 49595)}


    def listenArgs(self):
        """
        @return: C{dict} of keyword arguments to pass to listen
        """
        return {'backlog': 100, 'interface': '127.0.0.1'}


    def createServerEndpoint(self, reactor, factory, **listenArgs):
        """
        Create an L{TCP4ServerEndpoint} and return the values needed to verify
        its behaviour.

        @param reactor: A fake L{IReactorTCP} that L{TCP4ServerEndpoint} can
            call L{IReactorTCP.listenTCP} on.
        @param factory: The thing that we expect to be passed to our
            L{IStreamServerEndpoint.listen} implementation.
        @param listenArgs: Optional dictionary of arguments to
            L{IReactorTCP.listenTCP}.

        @return: A 3-tuple of the endpoint, the (C{port}, C{factory},
            C{backlog}, C{interface}) tuple expected to reach
            L{IReactorTCP.listenTCP}, and the expected host address.
        """
        # listenArgs comes from **kwargs, so it is always a dict (possibly
        # empty); no None check is needed.
        address = IPv4Address("TCP", "0.0.0.0", 0)

        return (endpoints.TCP4ServerEndpoint(reactor,
                                             address.port,
                                             **listenArgs),
                (address.port, factory,
                 listenArgs.get('backlog', 50),
                 listenArgs.get('interface', '')),
                address)


    def createClientEndpoint(self, reactor, clientFactory, **connectArgs):
        """
        Create an L{TCP4ClientEndpoint} and return the values needed to verify
        its behavior.

        @param reactor: A fake L{IReactorTCP} that L{TCP4ClientEndpoint} can
            call L{IReactorTCP.connectTCP} on.
        @param clientFactory: The thing that we expect to be passed to our
            L{IStreamClientEndpoint.connect} implementation.
        @param connectArgs: Optional dictionary of arguments to
            L{IReactorTCP.connectTCP}

        @return: A 3-tuple of the endpoint, the (C{host}, C{port},
            C{factory}, C{timeout}, C{bindAddress}) tuple expected to reach
            L{IReactorTCP.connectTCP}, and the destination address.
        """
        address = IPv4Address("TCP", "localhost", 80)

        return (endpoints.TCP4ClientEndpoint(reactor,
                                             address.host,
                                             address.port,
                                             **connectArgs),
                (address.host, address.port, clientFactory,
                 connectArgs.get('timeout', 30),
                 connectArgs.get('bindAddress', None)),
                address)
683
684
685
class SSL4EndpointsTestCase(EndpointTestCaseMixin,
                            unittest.TestCase):
    """
    Tests for SSL Endpoints.
    """
    if skipSSL:
        skip = skipSSL

    def expectedServers(self, reactor):
        """
        @return: List of calls to L{IReactorSSL.listenSSL}
        """
        return reactor.sslServers


    def expectedClients(self, reactor):
        """
        @return: List of calls to L{IReactorSSL.connectSSL}
        """
        return reactor.sslClients


    def assertConnectArgs(self, receivedArgs, expectedArgs):
        """
        Compare host, port, contextFactory, timeout, and bindAddress in
        C{receivedArgs} to C{expectedArgs}.  We ignore the factory because we
        only care about what protocol comes out of the
        C{IStreamClientEndpoint.connect} call.

        @param receivedArgs: C{tuple} of (C{host}, C{port}, C{factory},
            C{contextFactory}, C{timeout}, C{bindAddress}) that was passed to
            L{IReactorSSL.connectSSL}.
        @param expectedArgs: C{tuple} of (C{host}, C{port}, C{factory},
            C{contextFactory}, C{timeout}, C{bindAddress}) that we expect to
            have been passed to L{IReactorSSL.connectSSL}.
        """
        (host, port, ignoredFactory, contextFactory, timeout,
         bindAddress) = receivedArgs

        (expectedHost, expectedPort, _ignoredFactory, expectedContextFactory,
         expectedTimeout, expectedBindAddress) = expectedArgs

        self.assertEqual(host, expectedHost)
        self.assertEqual(port, expectedPort)
        self.assertEqual(contextFactory, expectedContextFactory)
        self.assertEqual(timeout, expectedTimeout)
        self.assertEqual(bindAddress, expectedBindAddress)


    def connectArgs(self):
        """
        @return: C{dict} of keyword arguments to pass to connect.
        """
        return {'timeout': 10, 'bindAddress': ('localhost', 49595)}


    def listenArgs(self):
        """
        @return: C{dict} of keyword arguments to pass to listen
        """
        return {'backlog': 100, 'interface': '127.0.0.1'}


    def setUp(self):
        """
        Set up client and server SSL contexts for use later.
        """
        self.sKey, self.sCert = makeCertificate(
            O="Server Test Certificate",
            CN="server")
        self.cKey, self.cCert = makeCertificate(
            O="Client Test Certificate",
            CN="client")
        self.serverSSLContext = CertificateOptions(
            privateKey=self.sKey,
            certificate=self.sCert,
            requireCertificate=False)
        self.clientSSLContext = CertificateOptions(
            requireCertificate=False)


    def createServerEndpoint(self, reactor, factory, **listenArgs):
        """
        Create an L{SSL4ServerEndpoint} and return the tools to verify its
        behaviour.

        @param factory: The thing that we expect to be passed to our
            L{IStreamServerEndpoint.listen} implementation.
        @param reactor: A fake L{IReactorSSL} that L{SSL4ServerEndpoint} can
            call L{IReactorSSL.listenSSL} on.
        @param listenArgs: Optional dictionary of arguments to
            L{IReactorSSL.listenSSL}.

        @return: A 3-tuple of the endpoint, the (C{port}, C{factory},
            C{contextFactory}, C{backlog}, C{interface}) tuple expected to
            reach L{IReactorSSL.listenSSL}, and the expected host address.
        """
        address = IPv4Address("TCP", "0.0.0.0", 0)

        return (endpoints.SSL4ServerEndpoint(reactor,
                                             address.port,
                                             self.serverSSLContext,
                                             **listenArgs),
                (address.port, factory, self.serverSSLContext,
                 listenArgs.get('backlog', 50),
                 listenArgs.get('interface', '')),
                address)


    def createClientEndpoint(self, reactor, clientFactory, **connectArgs):
        """
        Create an L{SSL4ClientEndpoint} and return the values needed to verify
        its behaviour.

        @param reactor: A fake L{IReactorSSL} that L{SSL4ClientEndpoint} can
            call L{IReactorSSL.connectSSL} on.
        @param clientFactory: The thing that we expect to be passed to our
            L{IStreamClientEndpoint.connect} implementation.
        @param connectArgs: Optional dictionary of arguments to
            L{IReactorSSL.connectSSL}

        @return: A 3-tuple of the endpoint, the (C{host}, C{port},
            C{factory}, C{contextFactory}, C{timeout}, C{bindAddress}) tuple
            expected to reach L{IReactorSSL.connectSSL}, and the destination
            address.
        """
        # connectArgs comes from **kwargs, so it is always a dict (possibly
        # empty); no None check is needed.
        address = IPv4Address("TCP", "localhost", 80)

        return (endpoints.SSL4ClientEndpoint(reactor,
                                             address.host,
                                             address.port,
                                             self.clientSSLContext,
                                             **connectArgs),
                (address.host, address.port, clientFactory,
                 self.clientSSLContext,
                 connectArgs.get('timeout', 30),
                 connectArgs.get('bindAddress', None)),
                address)
818
819
820
821 class UNIXEndpointsTestCase(EndpointTestCaseMixin,
822                             unittest.TestCase):
823     """
824     Tests for UnixSocket Endpoints.
825     """
826
827     def retrieveConnectedFactory(self, reactor):
828         """
829         Override L{EndpointTestCaseMixin.retrieveConnectedFactory} to account
830         for different index of 'factory' in C{connectUNIX} args.
831         """
832         return self.expectedClients(reactor)[0][1]
833
834     def expectedServers(self, reactor):
835         """
836         @return: List of calls to L{IReactorUNIX.listenUNIX}
837         """
838         return reactor.unixServers
839
840
841     def expectedClients(self, reactor):
842         """
843         @return: List of calls to L{IReactorUNIX.connectUNIX}
844         """
845         return reactor.unixClients
846
847
848     def assertConnectArgs(self, receivedArgs, expectedArgs):
849         """
850         Compare path, timeout, checkPID in C{receivedArgs} to C{expectedArgs}.
851         We ignore the factory because we don't only care what protocol comes
852         out of the C{IStreamClientEndpoint.connect} call.
853
854         @param receivedArgs: C{tuple} of (C{path}, C{timeout}, C{checkPID})
855             that was passed to L{IReactorUNIX.connectUNIX}.
856         @param expectedArgs: C{tuple} of (C{path}, C{timeout}, C{checkPID})
857             that we expect to have been passed to L{IReactorUNIX.connectUNIX}.
858         """
859
860         (path, ignoredFactory, timeout, checkPID) = receivedArgs
861
862         (expectedPath, _ignoredFactory, expectedTimeout,
863          expectedCheckPID) = expectedArgs
864
865         self.assertEqual(path, expectedPath)
866         self.assertEqual(timeout, expectedTimeout)
867         self.assertEqual(checkPID, expectedCheckPID)
868
869
870     def connectArgs(self):
871         """
872         @return: C{dict} of keyword arguments to pass to connect.
873         """
874         return {'timeout': 10, 'checkPID': 1}
875
876
877     def listenArgs(self):
878         """
879         @return: C{dict} of keyword arguments to pass to listen
880         """
881         return {'backlog': 100, 'mode': 0600, 'wantPID': 1}
882
883
884     def createServerEndpoint(self, reactor, factory, **listenArgs):
885         """
886         Create an L{UNIXServerEndpoint} and return the tools to verify its
887         behaviour.
888
889         @param reactor: A fake L{IReactorUNIX} that L{UNIXServerEndpoint} can
890             call L{IReactorUNIX.listenUNIX} on.
891         @param factory: The thing that we expect to be passed to our
892             L{IStreamServerEndpoint.listen} implementation.
893         @param listenArgs: Optional dictionary of arguments to
894             L{IReactorUNIX.listenUNIX}.
895         """
896         address = UNIXAddress(self.mktemp())
897
898         return (endpoints.UNIXServerEndpoint(reactor, address.name,
899                                              **listenArgs),
900                 (address.name, factory,
901                  listenArgs.get('backlog', 50),
902                  listenArgs.get('mode', 0666),
903                  listenArgs.get('wantPID', 0)),
904                 address)
905
906
907     def createClientEndpoint(self, reactor, clientFactory, **connectArgs):
908         """
909         Create an L{UNIXClientEndpoint} and return the values needed to verify
910         its behaviour.
911
912         @param reactor: A fake L{IReactorUNIX} that L{UNIXClientEndpoint} can
913             call L{IReactorUNIX.connectUNIX} on.
914         @param clientFactory: The thing that we expect to be passed to our
915             L{IStreamClientEndpoint.connect} implementation.
916         @param connectArgs: Optional dictionary of arguments to
917             L{IReactorUNIX.connectUNIX}
918         """
919         address = UNIXAddress(self.mktemp())
920
921         return (endpoints.UNIXClientEndpoint(reactor, address.name,
922                                              **connectArgs),
923                 (address.name, clientFactory,
924                  connectArgs.get('timeout', 30),
925                  connectArgs.get('checkPID', 0)),
926                 address)
927
928
929
930 class ParserTestCase(unittest.TestCase):
931     """
932     Tests for L{endpoints._parseServer}, the low-level parsing logic.
933     """
934
935     f = "Factory"
936
937     def parse(self, *a, **kw):
938         """
939         Provide a hook for test_strports to substitute the deprecated API.
940         """
941         return endpoints._parseServer(*a, **kw)
942
943
944     def test_simpleTCP(self):
945         """
946         Simple strings with a 'tcp:' prefix should be parsed as TCP.
947         """
948         self.assertEqual(self.parse('tcp:80', self.f),
949                          ('TCP', (80, self.f), {'interface':'', 'backlog':50}))
950
951
952     def test_interfaceTCP(self):
953         """
954         TCP port descriptions parse their 'interface' argument as a string.
955         """
956         self.assertEqual(
957              self.parse('tcp:80:interface=127.0.0.1', self.f),
958              ('TCP', (80, self.f), {'interface':'127.0.0.1', 'backlog':50}))
959
960
961     def test_backlogTCP(self):
962         """
963         TCP port descriptions parse their 'backlog' argument as an integer.
964         """
965         self.assertEqual(self.parse('tcp:80:backlog=6', self.f),
966                          ('TCP', (80, self.f),
967                                  {'interface':'', 'backlog':6}))
968
969
970     def test_simpleUNIX(self):
971         """
972         L{endpoints._parseServer} returns a C{'UNIX'} port description with
973         defaults for C{'mode'}, C{'backlog'}, and C{'wantPID'} when passed a
974         string with the C{'unix:'} prefix and no other parameter values.
975         """
976         self.assertEqual(
977             self.parse('unix:/var/run/finger', self.f),
978             ('UNIX', ('/var/run/finger', self.f),
979              {'mode': 0666, 'backlog': 50, 'wantPID': True}))
980
981
982     def test_modeUNIX(self):
983         """
984         C{mode} can be set by including C{"mode=<some integer>"}.
985         """
986         self.assertEqual(
987             self.parse('unix:/var/run/finger:mode=0660', self.f),
988             ('UNIX', ('/var/run/finger', self.f),
989              {'mode': 0660, 'backlog': 50, 'wantPID': True}))
990
991
992     def test_wantPIDUNIX(self):
993         """
994         C{wantPID} can be set to false by included C{"lockfile=0"}.
995         """
996         self.assertEqual(
997             self.parse('unix:/var/run/finger:lockfile=0', self.f),
998             ('UNIX', ('/var/run/finger', self.f),
999              {'mode': 0666, 'backlog': 50, 'wantPID': False}))
1000
1001
1002     def test_escape(self):
1003         """
1004         Backslash can be used to escape colons and backslashes in port
1005         descriptions.
1006         """
1007         self.assertEqual(
1008             self.parse(r'unix:foo\:bar\=baz\:qux\\', self.f),
1009             ('UNIX', ('foo:bar=baz:qux\\', self.f),
1010              {'mode': 0666, 'backlog': 50, 'wantPID': True}))
1011
1012
1013     def test_quoteStringArgument(self):
1014         """
1015         L{endpoints.quoteStringArgument} should quote backslashes and colons
1016         for interpolation into L{endpoints.serverFromString} and
1017         L{endpoints.clientFactory} arguments.
1018         """
1019         self.assertEqual(endpoints.quoteStringArgument("some : stuff \\"),
1020                          "some \\: stuff \\\\")
1021
1022
1023     def test_impliedEscape(self):
1024         """
1025         In strports descriptions, '=' in a parameter value does not need to be
1026         quoted; it will simply be parsed as part of the value.
1027         """
1028         self.assertEqual(
1029             self.parse(r'unix:address=foo=bar', self.f),
1030             ('UNIX', ('foo=bar', self.f),
1031              {'mode': 0666, 'backlog': 50, 'wantPID': True}))
1032
1033
1034     def test_nonstandardDefault(self):
1035         """
1036         For compatibility with the old L{twisted.application.strports.parse},
1037         the third 'mode' argument may be specified to L{endpoints.parse} to
1038         indicate a default other than TCP.
1039         """
1040         self.assertEqual(
1041             self.parse('filename', self.f, 'unix'),
1042             ('UNIX', ('filename', self.f),
1043              {'mode': 0666, 'backlog': 50, 'wantPID': True}))
1044
1045
1046     def test_unknownType(self):
1047         """
1048         L{strports.parse} raises C{ValueError} when given an unknown endpoint
1049         type.
1050         """
1051         self.assertRaises(ValueError, self.parse, "bogus-type:nothing", self.f)
1052
1053
1054
1055 class ServerStringTests(unittest.TestCase):
1056     """
1057     Tests for L{twisted.internet.endpoints.serverFromString}.
1058     """
1059
1060     def test_tcp(self):
1061         """
1062         When passed a TCP strports description, L{endpoints.serverFromString}
1063         returns a L{TCP4ServerEndpoint} instance initialized with the values
1064         from the string.
1065         """
1066         reactor = object()
1067         server = endpoints.serverFromString(
1068             reactor, "tcp:1234:backlog=12:interface=10.0.0.1")
1069         self.assertIsInstance(server, endpoints.TCP4ServerEndpoint)
1070         self.assertIdentical(server._reactor, reactor)
1071         self.assertEqual(server._port, 1234)
1072         self.assertEqual(server._backlog, 12)
1073         self.assertEqual(server._interface, "10.0.0.1")
1074
1075
1076     def test_ssl(self):
1077         """
1078         When passed an SSL strports description, L{endpoints.serverFromString}
1079         returns a L{SSL4ServerEndpoint} instance initialized with the values
1080         from the string.
1081         """
1082         reactor = object()
1083         server = endpoints.serverFromString(
1084             reactor,
1085             "ssl:1234:backlog=12:privateKey=%s:"
1086             "certKey=%s:interface=10.0.0.1" % (escapedPEMPathName,
1087                                                escapedPEMPathName))
1088         self.assertIsInstance(server, endpoints.SSL4ServerEndpoint)
1089         self.assertIdentical(server._reactor, reactor)
1090         self.assertEqual(server._port, 1234)
1091         self.assertEqual(server._backlog, 12)
1092         self.assertEqual(server._interface, "10.0.0.1")
1093         ctx = server._sslContextFactory.getContext()
1094         self.assertIsInstance(ctx, ContextType)
1095
1096     if skipSSL:
1097         test_ssl.skip = skipSSL
1098
1099
1100     def test_unix(self):
1101         """
1102         When passed a UNIX strports description, L{endpoint.serverFromString}
1103         returns a L{UNIXServerEndpoint} instance initialized with the values
1104         from the string.
1105         """
1106         reactor = object()
1107         endpoint = endpoints.serverFromString(
1108             reactor,
1109             "unix:/var/foo/bar:backlog=7:mode=0123:lockfile=1")
1110         self.assertIsInstance(endpoint, endpoints.UNIXServerEndpoint)
1111         self.assertIdentical(endpoint._reactor, reactor)
1112         self.assertEqual(endpoint._address, "/var/foo/bar")
1113         self.assertEqual(endpoint._backlog, 7)
1114         self.assertEqual(endpoint._mode, 0123)
1115         self.assertEqual(endpoint._wantPID, True)
1116
1117
1118     def test_implicitDefaultNotAllowed(self):
1119         """
1120         The older service-based API (L{twisted.internet.strports.service})
1121         allowed an implicit default of 'tcp' so that TCP ports could be
1122         specified as a simple integer, but we've since decided that's a bad
1123         idea, and the new API does not accept an implicit default argument; you
1124         have to say 'tcp:' now.  If you try passing an old implicit port number
1125         to the new API, you'll get a C{ValueError}.
1126         """
1127         value = self.assertRaises(
1128             ValueError, endpoints.serverFromString, None, "4321")
1129         self.assertEqual(
1130             str(value),
1131             "Unqualified strport description passed to 'service'."
1132             "Use qualified endpoint descriptions; for example, 'tcp:4321'.")
1133
1134
1135     def test_unknownType(self):
1136         """
1137         L{endpoints.serverFromString} raises C{ValueError} when given an
1138         unknown endpoint type.
1139         """
1140         value = self.assertRaises(
1141             # faster-than-light communication not supported
1142             ValueError, endpoints.serverFromString, None,
1143             "ftl:andromeda/carcosa/hali/2387")
1144         self.assertEqual(
1145             str(value),
1146             "Unknown endpoint type: 'ftl'")
1147
1148
1149     def test_typeFromPlugin(self):
1150         """
1151         L{endpoints.serverFromString} looks up plugins of type
1152         L{IStreamServerEndpoint} and constructs endpoints from them.
1153         """
1154         # Set up a plugin which will only be accessible for the duration of
1155         # this test.
1156         addFakePlugin(self)
1157         # Plugin is set up: now actually test.
1158         notAReactor = object()
1159         fakeEndpoint = endpoints.serverFromString(
1160             notAReactor, "fake:hello:world:yes=no:up=down")
1161         from twisted.plugins.fakeendpoint import fake
1162         self.assertIdentical(fakeEndpoint.parser, fake)
1163         self.assertEqual(fakeEndpoint.args, (notAReactor, 'hello', 'world'))
1164         self.assertEqual(fakeEndpoint.kwargs, dict(yes='no', up='down'))
1165
1166
1167
def addFakePlugin(testCase, dropinSource="fakeendpoint.py"):
    """
    For the duration of C{testCase}, add a fake plugin to twisted.plugins which
    contains some sample endpoint parsers.

    @param testCase: The test case whose cleanup restores C{sys.modules} and
        C{twisted.plugins.__path__} to their state before this call.
    @param dropinSource: The basename of the plugin source file, located next
        to this test module, that is copied into a fresh temporary directory
        which is then appended to C{twisted.plugins.__path__}.
    """
    import sys
    savedModules = sys.modules.copy()
    # Snapshot the path list.  The append at the end of this function mutates
    # plugins.__path__ in place, so keeping a reference to the same list would
    # make the restore in cleanup() a no-op and leak the temporary plugin
    # directory into the plugin path for every later test.
    savedPluginPath = list(plugins.__path__)
    def cleanup():
        sys.modules.clear()
        sys.modules.update(savedModules)
        plugins.__path__[:] = savedPluginPath
    testCase.addCleanup(cleanup)
    fp = FilePath(testCase.mktemp())
    fp.createDirectory()
    getModule(__name__).filePath.sibling(dropinSource).copyTo(
        fp.child(dropinSource))
    plugins.__path__.append(fp.path)
1186
1187
1188
class ClientStringTests(unittest.TestCase):
    """
    Tests for L{twisted.internet.endpoints.clientFromString}.
    """

    def test_tcp(self):
        """
        When passed a TCP strports description, L{endpoints.clientFromString}
        returns a L{TCP4ClientEndpoint} instance initialized with the values
        from the string.
        """
        reactor = object()
        client = endpoints.clientFromString(
            reactor,
            "tcp:host=example.com:port=1234:timeout=7:bindAddress=10.0.0.2")
        self.assertIsInstance(client, endpoints.TCP4ClientEndpoint)
        self.assertIdentical(client._reactor, reactor)
        self.assertEqual(client._host, "example.com")
        self.assertEqual(client._port, 1234)
        self.assertEqual(client._timeout, 7)
        self.assertEqual(client._bindAddress, "10.0.0.2")


    def test_tcpPositionalArgs(self):
        """
        When passed a TCP strports description using positional arguments,
        L{endpoints.clientFromString} returns a L{TCP4ClientEndpoint} instance
        initialized with the values from the string.
        """
        reactor = object()
        client = endpoints.clientFromString(
            reactor,
            "tcp:example.com:1234:timeout=7:bindAddress=10.0.0.2")
        self.assertIsInstance(client, endpoints.TCP4ClientEndpoint)
        self.assertIdentical(client._reactor, reactor)
        self.assertEqual(client._host, "example.com")
        self.assertEqual(client._port, 1234)
        self.assertEqual(client._timeout, 7)
        self.assertEqual(client._bindAddress, "10.0.0.2")


    def test_tcpHostPositionalArg(self):
        """
        When passed a TCP strports description specifying host as a positional
        argument, L{endpoints.clientFromString} returns a L{TCP4ClientEndpoint}
        instance initialized with the values from the string.
        """
        reactor = object()
        client = endpoints.clientFromString(
            reactor,
            "tcp:example.com:port=1234:timeout=7:bindAddress=10.0.0.2")
        self.assertEqual(client._host, "example.com")
        self.assertEqual(client._port, 1234)


    def test_tcpPortPositionalArg(self):
        """
        When passed a TCP strports description specifying port as a positional
        argument, L{endpoints.clientFromString} returns a L{TCP4ClientEndpoint}
        instance initialized with the values from the string.
        """
        reactor = object()
        client = endpoints.clientFromString(
            reactor,
            "tcp:host=example.com:1234:timeout=7:bindAddress=10.0.0.2")
        self.assertEqual(client._host, "example.com")
        self.assertEqual(client._port, 1234)


    def test_tcpDefaults(self):
        """
        A TCP strports description may omit I{timeout} or I{bindAddress} to
        allow the default to be used.
        """
        reactor = object()
        client = endpoints.clientFromString(
            reactor,
            "tcp:host=example.com:port=1234")
        self.assertEqual(client._timeout, 30)
        self.assertEqual(client._bindAddress, None)


    def test_unix(self):
        """
        When passed a UNIX strports description, L{endpoints.clientFromString}
        returns a L{UNIXClientEndpoint} instance initialized with the values
        from the string.
        """
        reactor = object()
        client = endpoints.clientFromString(
            reactor,
            "unix:path=/var/foo/bar:lockfile=1:timeout=9")
        self.assertIsInstance(client, endpoints.UNIXClientEndpoint)
        self.assertIdentical(client._reactor, reactor)
        self.assertEqual(client._path, "/var/foo/bar")
        self.assertEqual(client._timeout, 9)
        self.assertEqual(client._checkPID, True)


    def test_unixDefaults(self):
        """
        A UNIX strports description may omit I{lockfile} or I{timeout} to allow
        the defaults to be used.
        """
        client = endpoints.clientFromString(object(), "unix:path=/var/foo/bar")
        self.assertEqual(client._timeout, 30)
        self.assertEqual(client._checkPID, False)


    def test_unixPathPositionalArg(self):
        """
        When passed a UNIX strports description specifying path as a positional
        argument, L{endpoints.clientFromString} returns a L{UNIXClientEndpoint}
        instance initialized with the values from the string.
        """
        reactor = object()
        client = endpoints.clientFromString(
            reactor,
            "unix:/var/foo/bar:lockfile=1:timeout=9")
        self.assertIsInstance(client, endpoints.UNIXClientEndpoint)
        self.assertIdentical(client._reactor, reactor)
        self.assertEqual(client._path, "/var/foo/bar")
        self.assertEqual(client._timeout, 9)
        self.assertEqual(client._checkPID, True)


    def test_typeFromPlugin(self):
        """
        L{endpoints.clientFromString} looks up plugins of type
        L{IStreamClientEndpoint} and constructs endpoints from them.
        """
        addFakePlugin(self)
        notAReactor = object()
        clientEndpoint = endpoints.clientFromString(
            notAReactor, "cfake:alpha:beta:cee=dee:num=1")
        from twisted.plugins.fakeendpoint import fakeClient
        self.assertIdentical(clientEndpoint.parser, fakeClient)
        self.assertEqual(clientEndpoint.args, ('alpha', 'beta'))
        self.assertEqual(clientEndpoint.kwargs, dict(cee='dee', num='1'))


    def test_unknownType(self):
        """
        L{endpoints.clientFromString} raises C{ValueError} when given an
        unknown endpoint type.
        """
        value = self.assertRaises(
            # faster-than-light communication not supported
            ValueError, endpoints.clientFromString, None,
            "ftl:andromeda/carcosa/hali/2387")
        self.assertEqual(
            str(value),
            "Unknown endpoint type: 'ftl'")
1343
1344
1345
class SSLClientStringTests(unittest.TestCase):
    """
    Tests for L{twisted.internet.endpoints.clientFromString} which require SSL.
    """

    if skipSSL:
        skip = skipSSL

    def test_ssl(self):
        """
        When passed an SSL strports description, L{clientFromString} returns a
        L{SSL4ClientEndpoint} instance initialized with the values from the
        string.
        """
        reactor = object()
        client = endpoints.clientFromString(
            reactor,
            "ssl:host=example.net:port=4321:privateKey=%s:"
            "certKey=%s:bindAddress=10.0.0.3:timeout=3:caCertsDir=%s" %
             (escapedPEMPathName,
              escapedPEMPathName,
              escapedCAsPathName))
        self.assertIsInstance(client, endpoints.SSL4ClientEndpoint)
        self.assertIdentical(client._reactor, reactor)
        self.assertEqual(client._host, "example.net")
        self.assertEqual(client._port, 4321)
        self.assertEqual(client._timeout, 3)
        self.assertEqual(client._bindAddress, "10.0.0.3")
        certOptions = client._sslContextFactory
        self.assertIsInstance(certOptions, CertificateOptions)
        ctx = certOptions.getContext()
        self.assertIsInstance(ctx, ContextType)
        self.assertEqual(Certificate(certOptions.certificate),
                          testCertificate)
        privateCert = PrivateCertificate(certOptions.certificate)
        privateCert._setPrivateKey(KeyPair(certOptions.privateKey))
        self.assertEqual(privateCert, testPrivateCertificate)
        # The CA certificates loaded from caCertsDir, in the order expected.
        expectedCerts = [
            Certificate.loadPEM(casPath.child(name).getContent())
            for name in ("thing1.pem", "thing2.pem")]
        self.assertEqual([Certificate(x) for x in certOptions.caCerts],
                          expectedCerts)


    def test_sslPositionalArgs(self):
        """
        When passed an SSL strports description specifying host and port as
        positional arguments, L{clientFromString} returns a
        L{SSL4ClientEndpoint} instance initialized with the values from the
        string.
        """
        reactor = object()
        client = endpoints.clientFromString(
            reactor,
            "ssl:example.net:4321:privateKey=%s:"
            "certKey=%s:bindAddress=10.0.0.3:timeout=3:caCertsDir=%s" %
             (escapedPEMPathName,
              escapedPEMPathName,
              escapedCAsPathName))
        self.assertIsInstance(client, endpoints.SSL4ClientEndpoint)
        self.assertIdentical(client._reactor, reactor)
        self.assertEqual(client._host, "example.net")
        self.assertEqual(client._port, 4321)
        self.assertEqual(client._timeout, 3)
        self.assertEqual(client._bindAddress, "10.0.0.3")


    def test_unreadableCertificate(self):
        """
        If a certificate in the directory is unreadable,
        L{endpoints._loadCAsFromDir} will ignore that certificate.
        """
        class UnreadableFilePath(FilePath):
            def getContent(self):
                data = FilePath.getContent(self)
                # There is a duplicate of thing2.pem, so ignore anything that
                # looks like it.
                if data == casPath.child("thing2.pem").getContent():
                    raise IOError(EPERM)
                else:
                    return data
        # A distinct FilePath object for the same directory, so overriding its
        # clonePath does not affect the shared module-level casPath.
        # NOTE(review): presumably clonePath is what FilePath uses to build
        # child paths, making the directory's entries UnreadableFilePaths.
        casPathClone = casPath.child("ignored").parent()
        casPathClone.clonePath = UnreadableFilePath
        self.assertEqual(
            [Certificate(x) for x in endpoints._loadCAsFromDir(casPathClone)],
            [Certificate.loadPEM(casPath.child("thing1.pem").getContent())])


    def test_sslSimple(self):
        """
        When passed an SSL strports description without any extra parameters,
        L{clientFromString} returns a simple non-verifying endpoint that will
        speak SSL.
        """
        reactor = object()
        client = endpoints.clientFromString(
            reactor, "ssl:host=simple.example.org:port=4321")
        certOptions = client._sslContextFactory
        self.assertIsInstance(certOptions, CertificateOptions)
        self.assertEqual(certOptions.verify, False)
        ctx = certOptions.getContext()
        self.assertIsInstance(ctx, ContextType)
1449
1450
1451
class AdoptedStreamServerEndpointTestCase(ServerEndpointTestCaseMixin,
                                          unittest.TestCase):
    """
    Tests for adopted socket-based stream server endpoints.
    """
    def _createStubbedAdoptedEndpoint(self, reactor, fileno, addressFamily):
        """
        Create an L{AdoptedStreamServerEndpoint} which may safely be used with
        an invalid file descriptor.  This is convenient for a number of unit
        tests.
        """
        e = endpoints.AdoptedStreamServerEndpoint(
            reactor, fileno, addressFamily)
        # Stub out some syscalls which would fail, given our invalid file
        # descriptor.
        e._close = lambda fd: None
        e._setNonBlocking = lambda fd: None
        return e


    def createServerEndpoint(self, reactor, factory):
        """
        Create a new L{AdoptedStreamServerEndpoint} for use by a test.

        @return: A three-tuple:
            - The endpoint
            - A tuple of the arguments expected to be passed to the underlying
              reactor method
            - An IAddress object which will match the result of
              L{IListeningPort.getHost} on the port returned by the endpoint.
        """
        fileno = 12
        addressFamily = AF_INET
        endpoint = self._createStubbedAdoptedEndpoint(
            reactor, fileno, addressFamily)
        # Magic numbers come from the implementation of MemoryReactor
        address = IPv4Address("TCP", "0.0.0.0", 1234)
        return (endpoint, (fileno, addressFamily, factory), address)


    def expectedServers(self, reactor):
        """
        @return: The ports which were actually adopted by C{reactor} via calls
            to its L{IReactorSocket.adoptStreamPort} implementation.
        """
        return reactor.adoptedPorts


    def listenArgs(self):
        """
        @return: A C{dict} of additional keyword arguments to pass to the
            C{createServerEndpoint}.
        """
        return {}


    def test_singleUse(self):
        """
        L{AdoptedStreamServerEndpoint.listen} can only be used once.  The file
        descriptor given is closed after the first use, and subsequent calls to
        C{listen} return a L{Deferred} that fails with L{AlreadyListened}.
        """
        reactor = MemoryReactor()
        endpoint = self._createStubbedAdoptedEndpoint(reactor, 13, AF_INET)
        endpoint.listen(object())
        d = self.assertFailure(endpoint.listen(object()), error.AlreadyListened)
        def listenFailed(ignored):
            # Only the first listen call reached the reactor.
            self.assertEqual(1, len(reactor.adoptedPorts))
        d.addCallback(listenFailed)
        return d


    def test_descriptionNonBlocking(self):
        """
        L{AdoptedStreamServerEndpoint.listen} sets the file descriptor given to
        it to non-blocking.
        """
        reactor = MemoryReactor()
        endpoint = self._createStubbedAdoptedEndpoint(reactor, 13, AF_INET)
        events = []
        def setNonBlocking(fileno):
            events.append(("setNonBlocking", fileno))
        endpoint._setNonBlocking = setNonBlocking

        d = endpoint.listen(object())
        def listened(ignored):
            self.assertEqual([("setNonBlocking", 13)], events)
        d.addCallback(listened)
        return d


    def test_descriptorClosed(self):
        """
        L{AdoptedStreamServerEndpoint.listen} closes its file descriptor after
        adding it to the reactor with L{IReactorSocket.adoptStreamPort}.
        """
        reactor = MemoryReactor()
        endpoint = self._createStubbedAdoptedEndpoint(reactor, 13, AF_INET)
        events = []
        def close(fileno):
            # Record how many ports were adopted at the moment of closing, so
            # the assertion below can check the ordering.
            events.append(("close", fileno, len(reactor.adoptedPorts)))
        endpoint._close = close

        d = endpoint.listen(object())
        def listened(ignored):
            self.assertEqual([("close", 13, 1)], events)
        d.addCallback(listened)
        return d
1559
1560
1561
class SystemdEndpointPluginTests(unittest.TestCase):
    """
    Unit tests for the systemd stream server endpoint and endpoint string
    description parser.

    @see: U{systemd<http://www.freedesktop.org/wiki/Software/systemd>}
    """

    _parserClass = endpoints._SystemdParser

    def test_pluginDiscovery(self):
        """
        L{endpoints._SystemdParser} is found as a plugin for
        L{interfaces.IStreamServerEndpointStringParser} interface.
        """
        parsers = list(getPlugins(
            interfaces.IStreamServerEndpointStringParser))
        for p in parsers:
            if isinstance(p, self._parserClass):
                break
        else:
            self.fail("Did not find systemd parser in %r" % (parsers,))


    def test_interface(self):
        """
        L{endpoints._SystemdParser} instances provide
        L{interfaces.IStreamServerEndpointStringParser}.
        """
        parser = self._parserClass()
        self.assertTrue(verifyObject(
            interfaces.IStreamServerEndpointStringParser, parser))


    def _parseStreamServerTest(self, addressFamily, addressFamilyString):
        """
        Helper for unit tests for L{endpoints._SystemdParser.parseStreamServer}
        for different address families.

        Verify that the parser handles the given address family correctly.  If
        there is a problem, a test-failing exception is raised.

        @param addressFamily: An address family constant, like L{socket.AF_INET}.

        @param addressFamilyString: A string which should be recognized by the
            parser as representing C{addressFamily}.
        """
        reactor = object()
        descriptors = [5, 6, 7, 8, 9]
        index = 3

        parser = self._parserClass()
        # Substitute a fixed set of inherited descriptors for the real
        # environment-derived ones.
        parser._sddaemon = ListenFDs(descriptors)

        server = parser.parseStreamServer(
            reactor, domain=addressFamilyString, index=str(index))
        self.assertIdentical(server.reactor, reactor)
        self.assertEqual(server.addressFamily, addressFamily)
        self.assertEqual(server.fileno, descriptors[index])


    def test_parseStreamServerINET(self):
        """
        IPv4 can be specified using the string C{"INET"}.
        """
        self._parseStreamServerTest(AF_INET, "INET")


    def test_parseStreamServerINET6(self):
        """
        IPv6 can be specified using the string C{"INET6"}.
        """
        self._parseStreamServerTest(AF_INET6, "INET6")


    def test_parseStreamServerUNIX(self):
        """
        A UNIX domain socket can be specified using the string C{"UNIX"}.
        """
        try:
            from socket import AF_UNIX
        except ImportError:
            raise unittest.SkipTest("Platform lacks AF_UNIX support")
        else:
            self._parseStreamServerTest(AF_UNIX, "UNIX")