Initial import to Tizen
[profile/ivi/python-twisted.git] / twisted / test / test_ssl.py
1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
3
4 """
5 Tests for twisted SSL support.
6 """
7
8 from twisted.trial import unittest
9 from twisted.internet import protocol, reactor, interfaces, defer
10 from twisted.internet.error import ConnectionDone
11 from twisted.protocols import basic
12 from twisted.python import util
13 from twisted.python.runtime import platform
14 from twisted.test.test_tcp import ProperlyCloseFilesMixin
15
16 import os, errno
17
18 try:
19     from OpenSSL import SSL, crypto
20     from twisted.internet import ssl
21     from twisted.test.ssl_helpers import ClientTLSContext
22 except ImportError:
23     def _noSSL():
24         # ugh, make pyflakes happy.
25         global SSL
26         global ssl
27         SSL = ssl = None
28     _noSSL()
29
30 try:
31     from twisted.protocols import tls as newTLS
32 except ImportError:
33     # Assuming SSL exists, we're using old version in reactor (i.e. non-protocol)
34     newTLS = None
35
36 certPath = util.sibpath(__file__, "server.pem")
37
38
39
40 class UnintelligentProtocol(basic.LineReceiver):
41     """
42     @ivar deferred: a deferred that will fire at connection lost.
43     @type deferred: L{defer.Deferred}
44
45     @cvar pretext: text sent before TLS is set up.
46     @type pretext: C{str}
47
48     @cvar posttext: text sent after TLS is set up.
49     @type posttext: C{str}
50     """
51     pretext = [
52         "first line",
53         "last thing before tls starts",
54         "STARTTLS"]
55
56     posttext = [
57         "first thing after tls started",
58         "last thing ever"]
59
60     def __init__(self):
61         self.deferred = defer.Deferred()
62
63
64     def connectionMade(self):
65         for l in self.pretext:
66             self.sendLine(l)
67
68
69     def lineReceived(self, line):
70         if line == "READY":
71             self.transport.startTLS(ClientTLSContext(), self.factory.client)
72             for l in self.posttext:
73                 self.sendLine(l)
74             self.transport.loseConnection()
75
76
77     def connectionLost(self, reason):
78         self.deferred.callback(None)
79
80
81
82 class LineCollector(basic.LineReceiver):
83     """
84     @ivar deferred: a deferred that will fire at connection lost.
85     @type deferred: L{defer.Deferred}
86
87     @ivar doTLS: whether the protocol is initiate TLS or not.
88     @type doTLS: C{bool}
89
90     @ivar fillBuffer: if set to True, it will send lots of data once
91         C{STARTTLS} is received.
92     @type fillBuffer: C{bool}
93     """
94
95     def __init__(self, doTLS, fillBuffer=False):
96         self.doTLS = doTLS
97         self.fillBuffer = fillBuffer
98         self.deferred = defer.Deferred()
99
100
101     def connectionMade(self):
102         self.factory.rawdata = ''
103         self.factory.lines = []
104
105
106     def lineReceived(self, line):
107         self.factory.lines.append(line)
108         if line == 'STARTTLS':
109             if self.fillBuffer:
110                 for x in range(500):
111                     self.sendLine('X' * 1000)
112             self.sendLine('READY')
113             if self.doTLS:
114                 ctx = ServerTLSContext(
115                     privateKeyFileName=certPath,
116                     certificateFileName=certPath,
117                 )
118                 self.transport.startTLS(ctx, self.factory.server)
119             else:
120                 self.setRawMode()
121
122
123     def rawDataReceived(self, data):
124         self.factory.rawdata += data
125         self.transport.loseConnection()
126
127
128     def connectionLost(self, reason):
129         self.deferred.callback(None)
130
131
132
133 class SingleLineServerProtocol(protocol.Protocol):
134     """
135     A protocol that sends a single line of data at C{connectionMade}.
136     """
137
138     def connectionMade(self):
139         self.transport.write("+OK <some crap>\r\n")
140         self.transport.getPeerCertificate()
141
142
143
144 class RecordingClientProtocol(protocol.Protocol):
145     """
146     @ivar deferred: a deferred that will fire with first received content.
147     @type deferred: L{defer.Deferred}
148     """
149
150     def __init__(self):
151         self.deferred = defer.Deferred()
152
153
154     def connectionMade(self):
155         self.transport.getPeerCertificate()
156
157
158     def dataReceived(self, data):
159         self.deferred.callback(data)
160
161
162
163 class ImmediatelyDisconnectingProtocol(protocol.Protocol):
164     """
165     A protocol that disconnect immediately on connection. It fires the
166     C{connectionDisconnected} deferred of its factory on connetion lost.
167     """
168
169     def connectionMade(self):
170         self.transport.loseConnection()
171
172
173     def connectionLost(self, reason):
174         self.factory.connectionDisconnected.callback(None)
175
176
177
178 def generateCertificateObjects(organization, organizationalUnit):
179     """
180     Create a certificate for given C{organization} and C{organizationalUnit}.
181
182     @return: a tuple of (key, request, certificate) objects.
183     """
184     pkey = crypto.PKey()
185     pkey.generate_key(crypto.TYPE_RSA, 512)
186     req = crypto.X509Req()
187     subject = req.get_subject()
188     subject.O = organization
189     subject.OU = organizationalUnit
190     req.set_pubkey(pkey)
191     req.sign(pkey, "md5")
192
193     # Here comes the actual certificate
194     cert = crypto.X509()
195     cert.set_serial_number(1)
196     cert.gmtime_adj_notBefore(0)
197     cert.gmtime_adj_notAfter(60) # Testing certificates need not be long lived
198     cert.set_issuer(req.get_subject())
199     cert.set_subject(req.get_subject())
200     cert.set_pubkey(req.get_pubkey())
201     cert.sign(pkey, "md5")
202
203     return pkey, req, cert
204
205
206
207 def generateCertificateFiles(basename, organization, organizationalUnit):
208     """
209     Create certificate files key, req and cert prefixed by C{basename} for
210     given C{organization} and C{organizationalUnit}.
211     """
212     pkey, req, cert = generateCertificateObjects(organization, organizationalUnit)
213
214     for ext, obj, dumpFunc in [
215         ('key', pkey, crypto.dump_privatekey),
216         ('req', req, crypto.dump_certificate_request),
217         ('cert', cert, crypto.dump_certificate)]:
218         fName = os.extsep.join((basename, ext))
219         fObj = file(fName, 'w')
220         fObj.write(dumpFunc(crypto.FILETYPE_PEM, obj))
221         fObj.close()
222
223
224
225 class ContextGeneratingMixin:
226     """
227     Offer methods to create L{ssl.DefaultOpenSSLContextFactory} for both client
228     and server.
229
230     @ivar clientBase: prefix of client certificate files.
231     @type clientBase: C{str}
232
233     @ivar serverBase: prefix of server certificate files.
234     @type serverBase: C{str}
235
236     @ivar clientCtxFactory: a generated context factory to be used in
237         C{reactor.connectSSL}.
238     @type clientCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
239
240     @ivar serverCtxFactory: a generated context factory to be used in
241         C{reactor.listenSSL}.
242     @type serverCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
243     """
244
245     def makeContextFactory(self, org, orgUnit, *args, **kwArgs):
246         base = self.mktemp()
247         generateCertificateFiles(base, org, orgUnit)
248         serverCtxFactory = ssl.DefaultOpenSSLContextFactory(
249             os.extsep.join((base, 'key')),
250             os.extsep.join((base, 'cert')),
251             *args, **kwArgs)
252
253         return base, serverCtxFactory
254
255
256     def setupServerAndClient(self, clientArgs, clientKwArgs, serverArgs,
257                              serverKwArgs):
258         self.clientBase, self.clientCtxFactory = self.makeContextFactory(
259             *clientArgs, **clientKwArgs)
260         self.serverBase, self.serverCtxFactory = self.makeContextFactory(
261             *serverArgs, **serverKwArgs)
262
263
264
265 if SSL is not None:
266     class ServerTLSContext(ssl.DefaultOpenSSLContextFactory):
267         """
268         A context factory with a default method set to L{SSL.TLSv1_METHOD}.
269         """
270         isClient = False
271
272         def __init__(self, *args, **kw):
273             kw['sslmethod'] = SSL.TLSv1_METHOD
274             ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **kw)
275
276
277
278 class StolenTCPTestCase(ProperlyCloseFilesMixin, unittest.TestCase):
279     """
280     For SSL transports, test many of the same things which are tested for
281     TCP transports.
282     """
283
284     def createServer(self, address, portNumber, factory):
285         """
286         Create an SSL server with a certificate using L{IReactorSSL.listenSSL}.
287         """
288         cert = ssl.PrivateCertificate.loadPEM(file(certPath).read())
289         contextFactory = cert.options()
290         return reactor.listenSSL(
291             portNumber, factory, contextFactory, interface=address)
292
293
294     def connectClient(self, address, portNumber, clientCreator):
295         """
296         Create an SSL client using L{IReactorSSL.connectSSL}.
297         """
298         contextFactory = ssl.CertificateOptions()
299         return clientCreator.connectSSL(address, portNumber, contextFactory)
300
301
302     def getHandleExceptionType(self):
303         """
304         Return L{SSL.Error} as the expected error type which will be raised by
305         a write to the L{OpenSSL.SSL.Connection} object after it has been
306         closed.
307         """
308         return SSL.Error
309
310
311     def getHandleErrorCode(self):
312         """
313         Return the argument L{SSL.Error} will be constructed with for this
314         case.  This is basically just a random OpenSSL implementation detail.
315         It would be better if this test worked in a way which did not require
316         this.
317         """
318         # Windows 2000 SP 4 and Windows XP SP 2 give back WSAENOTSOCK for
319         # SSL.Connection.write for some reason.  The twisted.protocols.tls
320         # implementation of IReactorSSL doesn't suffer from this imprecation,
321         # though, since it is isolated from the Windows I/O layer (I suppose?).
322
323         # If test_properlyCloseFiles waited for the SSL handshake to complete
324         # and performed an orderly shutdown, then this would probably be a
325         # little less weird: writing to a shutdown SSL connection has a more
326         # well-defined failure mode (or at least it should).
327
328         # So figure out if twisted.protocols.tls is in use.  If it can be
329         # imported, it should be.
330         try:
331             import twisted.protocols.tls
332         except ImportError:
333             # It isn't available, so we expect WSAENOTSOCK if we're on Windows.
334             if platform.getType() == 'win32':
335                 return errno.WSAENOTSOCK
336
337         # Otherwise, we expect an error about how we tried to write to a
338         # shutdown connection.  This is terribly implementation-specific.
339         return [('SSL routines', 'SSL_write', 'protocol is shutdown')]
340
341
342
343 class TLSTestCase(unittest.TestCase):
344     """
345     Tests for startTLS support.
346
347     @ivar fillBuffer: forwarded to L{LineCollector.fillBuffer}
348     @type fillBuffer: C{bool}
349     """
350     fillBuffer = False
351
352     clientProto = None
353     serverProto = None
354
355
356     def tearDown(self):
357         if self.clientProto.transport is not None:
358             self.clientProto.transport.loseConnection()
359         if self.serverProto.transport is not None:
360             self.serverProto.transport.loseConnection()
361
362
363     def _runTest(self, clientProto, serverProto, clientIsServer=False):
364         """
365         Helper method to run TLS tests.
366
367         @param clientProto: protocol instance attached to the client
368             connection.
369         @param serverProto: protocol instance attached to the server
370             connection.
371         @param clientIsServer: flag indicated if client should initiate
372             startTLS instead of server.
373
374         @return: a L{defer.Deferred} that will fire when both connections are
375             lost.
376         """
377         self.clientProto = clientProto
378         cf = self.clientFactory = protocol.ClientFactory()
379         cf.protocol = lambda: clientProto
380         if clientIsServer:
381             cf.server = False
382         else:
383             cf.client = True
384
385         self.serverProto = serverProto
386         sf = self.serverFactory = protocol.ServerFactory()
387         sf.protocol = lambda: serverProto
388         if clientIsServer:
389             sf.client = False
390         else:
391             sf.server = True
392
393         port = reactor.listenTCP(0, sf, interface="127.0.0.1")
394         self.addCleanup(port.stopListening)
395
396         reactor.connectTCP('127.0.0.1', port.getHost().port, cf)
397
398         return defer.gatherResults([clientProto.deferred, serverProto.deferred])
399
400
401     def test_TLS(self):
402         """
403         Test for server and client startTLS: client should received data both
404         before and after the startTLS.
405         """
406         def check(ignore):
407             self.assertEqual(
408                 self.serverFactory.lines,
409                 UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
410             )
411         d = self._runTest(UnintelligentProtocol(),
412                           LineCollector(True, self.fillBuffer))
413         return d.addCallback(check)
414
415
416     def test_unTLS(self):
417         """
418         Test for server startTLS not followed by a startTLS in client: the data
419         received after server startTLS should be received as raw.
420         """
421         def check(ignored):
422             self.assertEqual(
423                 self.serverFactory.lines,
424                 UnintelligentProtocol.pretext
425             )
426             self.failUnless(self.serverFactory.rawdata,
427                             "No encrypted bytes received")
428         d = self._runTest(UnintelligentProtocol(),
429                           LineCollector(False, self.fillBuffer))
430         return d.addCallback(check)
431
432
433     def test_backwardsTLS(self):
434         """
435         Test startTLS first initiated by client.
436         """
437         def check(ignored):
438             self.assertEqual(
439                 self.clientFactory.lines,
440                 UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
441             )
442         d = self._runTest(LineCollector(True, self.fillBuffer),
443                           UnintelligentProtocol(), True)
444         return d.addCallback(check)
445
446
447
448 class SpammyTLSTestCase(TLSTestCase):
449     """
450     Test TLS features with bytes sitting in the out buffer.
451     """
452     fillBuffer = True
453
454
455
456 class BufferingTestCase(unittest.TestCase):
457     serverProto = None
458     clientProto = None
459
460
461     def tearDown(self):
462         if self.serverProto.transport is not None:
463             self.serverProto.transport.loseConnection()
464         if self.clientProto.transport is not None:
465             self.clientProto.transport.loseConnection()
466
467
468     def test_openSSLBuffering(self):
469         serverProto = self.serverProto = SingleLineServerProtocol()
470         clientProto = self.clientProto = RecordingClientProtocol()
471
472         server = protocol.ServerFactory()
473         client = self.client = protocol.ClientFactory()
474
475         server.protocol = lambda: serverProto
476         client.protocol = lambda: clientProto
477
478         sCTX = ssl.DefaultOpenSSLContextFactory(certPath, certPath)
479         cCTX = ssl.ClientContextFactory()
480
481         port = reactor.listenSSL(0, server, sCTX, interface='127.0.0.1')
482         self.addCleanup(port.stopListening)
483
484         reactor.connectSSL('127.0.0.1', port.getHost().port, client, cCTX)
485
486         return clientProto.deferred.addCallback(
487             self.assertEqual, "+OK <some crap>\r\n")
488
489
490
491 class ConnectionLostTestCase(unittest.TestCase, ContextGeneratingMixin):
492     """
493     SSL connection closing tests.
494     """
495
496     def testImmediateDisconnect(self):
497         org = "twisted.test.test_ssl"
498         self.setupServerAndClient(
499             (org, org + ", client"), {},
500             (org, org + ", server"), {})
501
502         # Set up a server, connect to it with a client, which should work since our verifiers
503         # allow anything, then disconnect.
504         serverProtocolFactory = protocol.ServerFactory()
505         serverProtocolFactory.protocol = protocol.Protocol
506         self.serverPort = serverPort = reactor.listenSSL(0,
507             serverProtocolFactory, self.serverCtxFactory)
508
509         clientProtocolFactory = protocol.ClientFactory()
510         clientProtocolFactory.protocol = ImmediatelyDisconnectingProtocol
511         clientProtocolFactory.connectionDisconnected = defer.Deferred()
512         clientConnector = reactor.connectSSL('127.0.0.1',
513             serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
514
515         return clientProtocolFactory.connectionDisconnected.addCallback(
516             lambda ignoredResult: self.serverPort.stopListening())
517
518
519     def test_bothSidesLoseConnection(self):
520         """
521         Both sides of SSL connection close connection; the connections should
522         close cleanly, and only after the underlying TCP connection has
523         disconnected.
524         """
525         class CloseAfterHandshake(protocol.Protocol):
526             def __init__(self):
527                 self.done = defer.Deferred()
528
529             def connectionMade(self):
530                 self.transport.write("a")
531
532             def dataReceived(self, data):
533                 # If we got data, handshake is over:
534                 self.transport.loseConnection()
535
536             def connectionLost(self2, reason):
537                 self2.done.errback(reason)
538                 del self2.done
539
540         org = "twisted.test.test_ssl"
541         self.setupServerAndClient(
542             (org, org + ", client"), {},
543             (org, org + ", server"), {})
544
545         serverProtocol = CloseAfterHandshake()
546         serverProtocolFactory = protocol.ServerFactory()
547         serverProtocolFactory.protocol = lambda: serverProtocol
548         serverPort = reactor.listenSSL(0,
549             serverProtocolFactory, self.serverCtxFactory)
550         self.addCleanup(serverPort.stopListening)
551
552         clientProtocol = CloseAfterHandshake()
553         clientProtocolFactory = protocol.ClientFactory()
554         clientProtocolFactory.protocol = lambda: clientProtocol
555         clientConnector = reactor.connectSSL('127.0.0.1',
556             serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
557
558         def checkResult(failure):
559             failure.trap(ConnectionDone)
560         return defer.gatherResults(
561             [clientProtocol.done.addErrback(checkResult),
562              serverProtocol.done.addErrback(checkResult)])
563
564     if newTLS is None:
565         test_bothSidesLoseConnection.skip = "Old SSL code doesn't always close cleanly."
566
567
568     def testFailedVerify(self):
569         org = "twisted.test.test_ssl"
570         self.setupServerAndClient(
571             (org, org + ", client"), {},
572             (org, org + ", server"), {})
573
574         def verify(*a):
575             return False
576         self.clientCtxFactory.getContext().set_verify(SSL.VERIFY_PEER, verify)
577
578         serverConnLost = defer.Deferred()
579         serverProtocol = protocol.Protocol()
580         serverProtocol.connectionLost = serverConnLost.callback
581         serverProtocolFactory = protocol.ServerFactory()
582         serverProtocolFactory.protocol = lambda: serverProtocol
583         self.serverPort = serverPort = reactor.listenSSL(0,
584             serverProtocolFactory, self.serverCtxFactory)
585
586         clientConnLost = defer.Deferred()
587         clientProtocol = protocol.Protocol()
588         clientProtocol.connectionLost = clientConnLost.callback
589         clientProtocolFactory = protocol.ClientFactory()
590         clientProtocolFactory.protocol = lambda: clientProtocol
591         clientConnector = reactor.connectSSL('127.0.0.1',
592             serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
593
594         dl = defer.DeferredList([serverConnLost, clientConnLost], consumeErrors=True)
595         return dl.addCallback(self._cbLostConns)
596
597
598     def _cbLostConns(self, results):
599         (sSuccess, sResult), (cSuccess, cResult) = results
600
601         self.failIf(sSuccess)
602         self.failIf(cSuccess)
603
604         acceptableErrors = [SSL.Error]
605
606         # Rather than getting a verification failure on Windows, we are getting
607         # a connection failure.  Without something like sslverify proxying
608         # in-between we can't fix up the platform's errors, so let's just
609         # specifically say it is only OK in this one case to keep the tests
610         # passing.  Normally we'd like to be as strict as possible here, so
611         # we're not going to allow this to report errors incorrectly on any
612         # other platforms.
613
614         if platform.isWindows():
615             from twisted.internet.error import ConnectionLost
616             acceptableErrors.append(ConnectionLost)
617
618         sResult.trap(*acceptableErrors)
619         cResult.trap(*acceptableErrors)
620
621         return self.serverPort.stopListening()
622
623
624
625 class FakeContext:
626     """
627     L{OpenSSL.SSL.Context} double which can more easily be inspected.
628     """
629     def __init__(self, method):
630         self._method = method
631         self._options = 0
632
633
634     def set_options(self, options):
635         self._options |= options
636
637
638     def use_certificate_file(self, fileName):
639         pass
640
641
642     def use_privatekey_file(self, fileName):
643         pass
644
645
646
647 class DefaultOpenSSLContextFactoryTests(unittest.TestCase):
648     """
649     Tests for L{ssl.DefaultOpenSSLContextFactory}.
650     """
651     def setUp(self):
652         # pyOpenSSL Context objects aren't introspectable enough.  Pass in
653         # an alternate context factory so we can inspect what is done to it.
654         self.contextFactory = ssl.DefaultOpenSSLContextFactory(
655             certPath, certPath, _contextFactory=FakeContext)
656         self.context = self.contextFactory.getContext()
657
658
659     def test_method(self):
660         """
661         L{ssl.DefaultOpenSSLContextFactory.getContext} returns an SSL context
662         which can use SSLv3 or TLSv1 but not SSLv2.
663         """
664         # SSLv23_METHOD allows SSLv2, SSLv3, or TLSv1
665         self.assertEqual(self.context._method, SSL.SSLv23_METHOD)
666
667         # And OP_NO_SSLv2 disables the SSLv2 support.
668         self.assertTrue(self.context._options & SSL.OP_NO_SSLv2)
669
670         # Make sure SSLv3 and TLSv1 aren't disabled though.
671         self.assertFalse(self.context._options & SSL.OP_NO_SSLv3)
672         self.assertFalse(self.context._options & SSL.OP_NO_TLSv1)
673
674
675     def test_missingCertificateFile(self):
676         """
677         Instantiating L{ssl.DefaultOpenSSLContextFactory} with a certificate
678         filename which does not identify an existing file results in the
679         initializer raising L{OpenSSL.SSL.Error}.
680         """
681         self.assertRaises(
682             SSL.Error,
683             ssl.DefaultOpenSSLContextFactory, certPath, self.mktemp())
684
685
686     def test_missingPrivateKeyFile(self):
687         """
688         Instantiating L{ssl.DefaultOpenSSLContextFactory} with a private key
689         filename which does not identify an existing file results in the
690         initializer raising L{OpenSSL.SSL.Error}.
691         """
692         self.assertRaises(
693             SSL.Error,
694             ssl.DefaultOpenSSLContextFactory, self.mktemp(), certPath)
695
696
697
698 class ClientContextFactoryTests(unittest.TestCase):
699     """
700     Tests for L{ssl.ClientContextFactory}.
701     """
702     def setUp(self):
703         self.contextFactory = ssl.ClientContextFactory()
704         self.contextFactory._contextFactory = FakeContext
705         self.context = self.contextFactory.getContext()
706
707
708     def test_method(self):
709         """
710         L{ssl.ClientContextFactory.getContext} returns a context which can use
711         SSLv3 or TLSv1 but not SSLv2.
712         """
713         self.assertEqual(self.context._method, SSL.SSLv23_METHOD)
714         self.assertTrue(self.context._options & SSL.OP_NO_SSLv2)
715         self.assertFalse(self.context._options & SSL.OP_NO_SSLv3)
716         self.assertFalse(self.context._options & SSL.OP_NO_TLSv1)
717
718
719
720 if interfaces.IReactorSSL(reactor, None) is None:
721     for tCase in [StolenTCPTestCase, TLSTestCase, SpammyTLSTestCase,
722                   BufferingTestCase, ConnectionLostTestCase,
723                   DefaultOpenSSLContextFactoryTests,
724                   ClientContextFactoryTests]:
725         tCase.skip = "Reactor does not support SSL, cannot run SSL tests"
726