1 # Copyright (c) Twisted Matrix Laboratories.
2 # See LICENSE for details.
5 Tests for twisted SSL support.
8 from twisted.trial import unittest
9 from twisted.internet import protocol, reactor, interfaces, defer
10 from twisted.internet.error import ConnectionDone
11 from twisted.protocols import basic
12 from twisted.python import util
13 from twisted.python.runtime import platform
14 from twisted.test.test_tcp import ProperlyCloseFilesMixin
19 from OpenSSL import SSL, crypto
20 from twisted.internet import ssl
21 from twisted.test.ssl_helpers import ClientTLSContext
24 # ugh, make pyflakes happy.
31 from twisted.protocols import tls as newTLS
33 # Assuming SSL exists, we're using old version in reactor (i.e. non-protocol)
# Path of the "server.pem" file that sits next to this module (resolved with
# util.sibpath). The server-side context factories below pass this same path
# both as the private key file and as the certificate file.
36 certPath = util.sibpath(__file__, "server.pem")
# Line protocol used by the startTLS tests. On connection it iterates over
# the class-level `pretext` lines; from lineReceived it upgrades its
# transport to TLS with a ClientTLSContext, iterates over `posttext` and
# then closes the connection. `deferred` fires with None from
# connectionLost.
# NOTE(review): in this view the pretext/posttext list openers, the header of
# the method assigning self.deferred, the loop bodies (presumably sendLine)
# and the condition guarding the startTLS call are not visible; confirm them
# against the full file before relying on this description.
40 class UnintelligentProtocol(basic.LineReceiver):
42 @ivar deferred: a deferred that will fire at connection lost.
43 @type deferred: L{defer.Deferred}
45 @cvar pretext: text sent before TLS is set up.
48 @cvar posttext: text sent after TLS is set up.
49 @type posttext: C{str}
53 "last thing before tls starts",
57 "first thing after tls started",
61 self.deferred = defer.Deferred()
64 def connectionMade(self):
# Walk the lines that must go out before TLS is negotiated.
65 for l in self.pretext:
69 def lineReceived(self, line):
# Switch the already-connected transport to TLS. The second argument is
# read from the factory; presumably it tells startTLS whether this side acts
# as the TLS client -- TODO confirm where the factory attribute is set.
71 self.transport.startTLS(ClientTLSContext(), self.factory.client)
# Walk the lines that must go out after TLS has been started.
72 for l in self.posttext:
74 self.transport.loseConnection()
77 def connectionLost(self, reason):
# Signal completion whatever the reason for the disconnect.
78 self.deferred.callback(None)
# Line protocol that records everything it receives on its factory.
# Every received line is appended to factory.lines. On a 'STARTTLS' line it
# sends 'READY' (possibly preceded by lines of 1000 'X' characters) and
# calls startTLS with a ServerTLSContext built from certPath. Data received
# in raw mode is appended to factory.rawdata, after which the connection is
# closed. `deferred` fires with None from connectionLost.
# NOTE(review): in this view the doTLS argument is never stored, and the
# loop/condition around the 'X' lines (presumably gated on fillBuffer) and
# the branch choosing between startTLS and raw mode (presumably gated on
# doTLS) are not visible; confirm against the full file.
82 class LineCollector(basic.LineReceiver):
84 @ivar deferred: a deferred that will fire at connection lost.
85 @type deferred: L{defer.Deferred}
87 @ivar doTLS: whether the protocol is initiate TLS or not.
90 @ivar fillBuffer: if set to True, it will send lots of data once
91 C{STARTTLS} is received.
92 @type fillBuffer: C{bool}
95 def __init__(self, doTLS, fillBuffer=False):
97 self.fillBuffer = fillBuffer
98 self.deferred = defer.Deferred()
101 def connectionMade(self):
# Per-connection results live on the factory so the test methods can read
# them (e.g. self.serverFactory.lines) after the connection has closed.
102 self.factory.rawdata = ''
103 self.factory.lines = []
106 def lineReceived(self, line):
107 self.factory.lines.append(line)
108 if line == 'STARTTLS':
# Filler data, meant to leave bytes in the outgoing buffer when TLS starts.
111 self.sendLine('X' * 1000)
112 self.sendLine('READY')
# ServerTLSContext is defined further down in this module; the name is
# looked up when this line runs, so the forward reference is fine.
114 ctx = ServerTLSContext(
115 privateKeyFileName=certPath,
116 certificateFileName=certPath,
118 self.transport.startTLS(ctx, self.factory.server)
123 def rawDataReceived(self, data):
# Any raw (non-line) data is recorded and ends the connection.
124 self.factory.rawdata += data
125 self.transport.loseConnection()
128 def connectionLost(self, reason):
129 self.deferred.callback(None)
class SingleLineServerProtocol(protocol.Protocol):
    """
    Server protocol which writes one CRLF-terminated line as soon as the
    connection is made, then asks the transport for the peer certificate.
    """

    def connectionMade(self):
        transport = self.transport
        transport.write("+OK <some crap>\r\n")
        # The returned certificate is discarded; presumably the call is here
        # only to check that querying it right after connecting works.
        transport.getPeerCertificate()
# Client protocol that hands the data it receives to `deferred`.
# NOTE(review): the header of the method assigning self.deferred is not
# visible in this view (presumably __init__).
144 class RecordingClientProtocol(protocol.Protocol):
146 @ivar deferred: a deferred that will fire with first received content.
147 @type deferred: L{defer.Deferred}
151 self.deferred = defer.Deferred()
154 def connectionMade(self):
# Return value is ignored; the call only queries the peer certificate.
155 self.transport.getPeerCertificate()
158 def dataReceived(self, data):
# NOTE(review): this fires self.deferred on every call. If the data arrives
# in more than one chunk, the second callback would raise
# defer.AlreadyCalledError.
159 self.deferred.callback(data)
class ImmediatelyDisconnectingProtocol(protocol.Protocol):
    """
    Protocol which drops the connection as soon as it is made. When the
    connection is lost it fires the C{connectionDisconnected} deferred of its
    factory with C{None}.
    """

    def connectionMade(self):
        # Nothing is ever written: close straight away.
        self.transport.loseConnection()


    def connectionLost(self, reason):
        disconnected = self.factory.connectionDisconnected
        disconnected.callback(None)
def generateCertificateObjects(organization, organizationalUnit,
                               keySize=2048, digest="sha256"):
    """
    Create a self-signed certificate for given C{organization} and
    C{organizationalUnit}.

    @param organization: value stored in the C{O} field of the subject.
    @param organizationalUnit: value stored in the C{OU} field of the
        subject.
    @param keySize: size in bits of the generated RSA key.
    @type keySize: C{int}
    @param digest: name of the digest used to sign both the request and the
        certificate.
    @type digest: C{str}

    @return: a tuple of (key, request, certificate) objects.
    """
    # 512-bit RSA keys and MD5 signatures are refused by current OpenSSL
    # releases at their default security levels, which makes handshakes
    # using such certificates fail. The defaults used here are accepted by
    # default OpenSSL configurations.
    pkey = crypto.PKey()
    pkey.generate_key(crypto.TYPE_RSA, keySize)
    req = crypto.X509Req()
    subject = req.get_subject()
    subject.O = organization
    subject.OU = organizationalUnit
    # The request must carry the public key: the certificate below copies it
    # from the request.
    req.set_pubkey(pkey)
    req.sign(pkey, digest)

    # Here comes the actual certificate: self-signed, with the request's
    # subject used as both issuer and subject.
    cert = crypto.X509()
    cert.set_serial_number(1)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(60) # Testing certificates need not be long lived
    cert.set_issuer(req.get_subject())
    cert.set_subject(req.get_subject())
    cert.set_pubkey(req.get_pubkey())
    cert.sign(pkey, digest)

    return pkey, req, cert
def generateCertificateFiles(basename, organization, organizationalUnit):
    """
    Create certificate files key, req and cert prefixed by C{basename} for
    given C{organization} and C{organizationalUnit}.

    Each file name is C{basename} joined to C{'key'}, C{'req'} or C{'cert'}
    with C{os.extsep}. Each file holds the PEM encoding of the matching
    object returned by L{generateCertificateObjects}.
    """
    pkey, req, cert = generateCertificateObjects(organization, organizationalUnit)

    for ext, obj, dumpFunc in [
        ('key', pkey, crypto.dump_privatekey),
        ('req', req, crypto.dump_certificate_request),
        ('cert', cert, crypto.dump_certificate)]:
        fName = os.extsep.join((basename, ext))
        # The dump functions return byte strings, hence binary mode. Close
        # each file explicitly so its contents are flushed before a context
        # factory reads it back, instead of relying on garbage collection.
        fObj = open(fName, 'wb')
        try:
            fObj.write(dumpFunc(crypto.FILETYPE_PEM, obj))
        finally:
            fObj.close()
# Mixin for test cases that need freshly generated certificate files and
# DefaultOpenSSLContextFactory instances for a client and a server.
# NOTE(review): the assignment of `base` in makeContextFactory, the trailing
# constructor arguments and the last line of the setupServerAndClient
# signature are not visible in this view; confirm against the full file.
225 class ContextGeneratingMixin:
227 Offer methods to create L{ssl.DefaultOpenSSLContextFactory} for both client
230 @ivar clientBase: prefix of client certificate files.
231 @type clientBase: C{str}
233 @ivar serverBase: prefix of server certificate files.
234 @type serverBase: C{str}
236 @ivar clientCtxFactory: a generated context factory to be used in
237 C{reactor.connectSSL}.
238 @type clientCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
240 @ivar serverCtxFactory: a generated context factory to be used in
241 C{reactor.listenSSL}.
242 @type serverCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
# Write key/req/cert files under the prefix `base`, build a context factory
# from the key and cert files, and return (base, factory). The local name
# says "server", but this helper builds the client factory too.
245 def makeContextFactory(self, org, orgUnit, *args, **kwArgs):
247 generateCertificateFiles(base, org, orgUnit)
248 serverCtxFactory = ssl.DefaultOpenSSLContextFactory(
249 os.extsep.join((base, 'key')),
250 os.extsep.join((base, 'cert')),
253 return base, serverCtxFactory
# Build one certificate/context factory pair per side and store the file
# prefixes and factories on self (clientBase/clientCtxFactory and
# serverBase/serverCtxFactory).
256 def setupServerAndClient(self, clientArgs, clientKwArgs, serverArgs,
258 self.clientBase, self.clientCtxFactory = self.makeContextFactory(
259 *clientArgs, **clientKwArgs)
260 self.serverBase, self.serverCtxFactory = self.makeContextFactory(
261 *serverArgs, **serverKwArgs)
class ServerTLSContext(ssl.DefaultOpenSSLContextFactory):
    """
    L{ssl.DefaultOpenSSLContextFactory} which always uses
    L{SSL.TLSv1_METHOD}: any C{sslmethod} keyword given by the caller is
    replaced before the base initializer runs.
    """

    def __init__(self, *args, **kwargs):
        options = dict(kwargs, sslmethod=SSL.TLSv1_METHOD)
        ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **options)
# Reuses the TCP file-closing tests from ProperlyCloseFilesMixin, with the
# server and client created over SSL instead of plain TCP.
# NOTE(review): the return statement of getHandleExceptionType and the
# try/except around the twisted.protocols.tls import are not visible in this
# view; confirm against the full file.
278 class StolenTCPTestCase(ProperlyCloseFilesMixin, unittest.TestCase):
280 For SSL transports, test many of the same things which are tested for
284 def createServer(self, address, portNumber, factory):
286 Create an SSL server with a certificate using L{IReactorSSL.listenSSL}.
# NOTE(review): the file object opened here is never explicitly closed;
# reading it with open()/close() (or a with-block) would avoid relying on
# garbage collection to release the handle.
288 cert = ssl.PrivateCertificate.loadPEM(file(certPath).read())
289 contextFactory = cert.options()
290 return reactor.listenSSL(
291 portNumber, factory, contextFactory, interface=address)
294 def connectClient(self, address, portNumber, clientCreator):
296 Create an SSL client using L{IReactorSSL.connectSSL}.
# The client uses a CertificateOptions built with no arguments.
298 contextFactory = ssl.CertificateOptions()
299 return clientCreator.connectSSL(address, portNumber, contextFactory)
302 def getHandleExceptionType(self):
304 Return L{SSL.Error} as the expected error type which will be raised by
305 a write to the L{OpenSSL.SSL.Connection} object after it has been
311 def getHandleErrorCode(self):
313 Return the argument L{SSL.Error} will be constructed with for this
314 case. This is basically just a random OpenSSL implementation detail.
315 It would be better if this test worked in a way which did not require
318 # Windows 2000 SP 4 and Windows XP SP 2 give back WSAENOTSOCK for
319 # SSL.Connection.write for some reason. The twisted.protocols.tls
320 # implementation of IReactorSSL doesn't suffer from this imprecation,
321 # though, since it is isolated from the Windows I/O layer (I suppose?).
323 # If test_properlyCloseFiles waited for the SSL handshake to complete
324 # and performed an orderly shutdown, then this would probably be a
325 # little less weird: writing to a shutdown SSL connection has a more
326 # well-defined failure mode (or at least it should).
328 # So figure out if twisted.protocols.tls is in use. If it can be
329 # imported, it should be.
331 import twisted.protocols.tls
333 # It isn't available, so we expect WSAENOTSOCK if we're on Windows.
334 if platform.getType() == 'win32':
335 return errno.WSAENOTSOCK
337 # Otherwise, we expect an error about how we tried to write to a
338 # shutdown connection. This is terribly implementation-specific.
339 return [('SSL routines', 'SSL_write', 'protocol is shutdown')]
# startTLS tests run over plain TCP on 127.0.0.1: one side is an
# UnintelligentProtocol and the other a LineCollector.
# NOTE(review): the class attribute fillBuffer, the header of the cleanup
# method (presumably tearDown), part of _runTest, the test_TLS header and the
# nested `check` functions are not visible in this view.
343 class TLSTestCase(unittest.TestCase):
345 Tests for startTLS support.
347 @ivar fillBuffer: forwarded to L{LineCollector.fillBuffer}
348 @type fillBuffer: C{bool}
# Close any transport that is still connected once the test is over.
357 if self.clientProto.transport is not None:
358 self.clientProto.transport.loseConnection()
359 if self.serverProto.transport is not None:
360 self.serverProto.transport.loseConnection()
363 def _runTest(self, clientProto, serverProto, clientIsServer=False):
365 Helper method to run TLS tests.
367 @param clientProto: protocol instance attached to the client
369 @param serverProto: protocol instance attached to the server
371 @param clientIsServer: flag indicated if client should initiate
372 startTLS instead of server.
374 @return: a L{defer.Deferred} that will fire when both connections are
377 self.clientProto = clientProto
# Each factory hands out the single pre-built instance rather than creating
# a new one, so the test keeps a handle on the exact protocol objects used.
378 cf = self.clientFactory = protocol.ClientFactory()
379 cf.protocol = lambda: clientProto
385 self.serverProto = serverProto
386 sf = self.serverFactory = protocol.ServerFactory()
387 sf.protocol = lambda: serverProto
# Port 0 lets the OS pick a free port; the listener is stopped at cleanup.
393 port = reactor.listenTCP(0, sf, interface="127.0.0.1")
394 self.addCleanup(port.stopListening)
396 reactor.connectTCP('127.0.0.1', port.getHost().port, cf)
# Both protocols fire their deferred from connectionLost.
398 return defer.gatherResults([clientProto.deferred, serverProto.deferred])
403 Test for server and client startTLS: client should received data both
404 before and after the startTLS.
408 self.serverFactory.lines,
409 UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
411 d = self._runTest(UnintelligentProtocol(),
412 LineCollector(True, self.fillBuffer))
413 return d.addCallback(check)
416 def test_unTLS(self):
418 Test for server startTLS not followed by a startTLS in client: the data
419 received after server startTLS should be received as raw.
423 self.serverFactory.lines,
424 UnintelligentProtocol.pretext
426 self.failUnless(self.serverFactory.rawdata,
427 "No encrypted bytes received")
428 d = self._runTest(UnintelligentProtocol(),
429 LineCollector(False, self.fillBuffer))
430 return d.addCallback(check)
# Roles swapped: the connecting side is the LineCollector, so the recorded
# lines are read from the client factory.
433 def test_backwardsTLS(self):
435 Test startTLS first initiated by client.
439 self.clientFactory.lines,
440 UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
442 d = self._runTest(LineCollector(True, self.fillBuffer),
443 UnintelligentProtocol(), True)
444 return d.addCallback(check)
# Re-runs every TLSTestCase test, this time (per its docstring) with bytes
# still waiting in the outgoing buffer when TLS starts.
# NOTE(review): the attribute override that enables this (presumably
# fillBuffer = True) is not visible in this view.
448 class SpammyTLSTestCase(TLSTestCase):
450 Test TLS features with bytes sitting in the out buffer.
# Checks that a line written by the server at connection time reaches the
# client intact over a real SSL connection on 127.0.0.1.
# NOTE(review): the header of the cleanup method (presumably tearDown) is not
# visible in this view.
456 class BufferingTestCase(unittest.TestCase):
462 if self.serverProto.transport is not None:
463 self.serverProto.transport.loseConnection()
464 if self.clientProto.transport is not None:
465 self.clientProto.transport.loseConnection()
468 def test_openSSLBuffering(self):
469 serverProto = self.serverProto = SingleLineServerProtocol()
470 clientProto = self.clientProto = RecordingClientProtocol()
472 server = protocol.ServerFactory()
473 client = self.client = protocol.ClientFactory()
475 server.protocol = lambda: serverProto
476 client.protocol = lambda: clientProto
# Server: key and certificate are both read from certPath.
478 sCTX = ssl.DefaultOpenSSLContextFactory(certPath, certPath)
479 cCTX = ssl.ClientContextFactory()
481 port = reactor.listenSSL(0, server, sCTX, interface='127.0.0.1')
482 self.addCleanup(port.stopListening)
484 reactor.connectSSL('127.0.0.1', port.getHost().port, client, cCTX)
# assertEqual receives the data recorded by the client as its first argument
# and the expected line as its second.
486 return clientProto.deferred.addCallback(
487 self.assertEqual, "+OK <some crap>\r\n")
# SSL disconnection tests using freshly generated certificates from
# ContextGeneratingMixin.
# NOTE(review): the header of CloseAfterHandshake's initializer, the
# `verify` callback used by testFailedVerify and one line of the Windows
# comment in _cbLostConns are not visible in this view.
491 class ConnectionLostTestCase(unittest.TestCase, ContextGeneratingMixin):
493 SSL connection closing tests.
496 def testImmediateDisconnect(self):
497 org = "twisted.test.test_ssl"
498 self.setupServerAndClient(
499 (org, org + ", client"), {},
500 (org, org + ", server"), {})
502 # Set up a server, connect to it with a client, which should work since our verifiers
503 # allow anything, then disconnect.
504 serverProtocolFactory = protocol.ServerFactory()
505 serverProtocolFactory.protocol = protocol.Protocol
# NOTE(review): unlike TLSTestCase/BufferingTestCase, this listener passes no
# interface='127.0.0.1', so it binds the default interface.
506 self.serverPort = serverPort = reactor.listenSSL(0,
507 serverProtocolFactory, self.serverCtxFactory)
509 clientProtocolFactory = protocol.ClientFactory()
510 clientProtocolFactory.protocol = ImmediatelyDisconnectingProtocol
# Fired by ImmediatelyDisconnectingProtocol.connectionLost.
511 clientProtocolFactory.connectionDisconnected = defer.Deferred()
# The connector is not used afterwards.
512 clientConnector = reactor.connectSSL('127.0.0.1',
513 serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
515 return clientProtocolFactory.connectionDisconnected.addCallback(
516 lambda ignoredResult: self.serverPort.stopListening())
519 def test_bothSidesLoseConnection(self):
521 Both sides of SSL connection close connection; the connections should
522 close cleanly, and only after the underlying TCP connection has
# Each side writes one byte when connected, closes as soon as it receives
# data, and errbacks `done` with the disconnect reason.
525 class CloseAfterHandshake(protocol.Protocol):
527 self.done = defer.Deferred()
529 def connectionMade(self):
530 self.transport.write("a")
532 def dataReceived(self, data):
533 # If we got data, handshake is over:
534 self.transport.loseConnection()
536 def connectionLost(self2, reason):
537 self2.done.errback(reason)
540 org = "twisted.test.test_ssl"
541 self.setupServerAndClient(
542 (org, org + ", client"), {},
543 (org, org + ", server"), {})
545 serverProtocol = CloseAfterHandshake()
546 serverProtocolFactory = protocol.ServerFactory()
547 serverProtocolFactory.protocol = lambda: serverProtocol
548 serverPort = reactor.listenSSL(0,
549 serverProtocolFactory, self.serverCtxFactory)
550 self.addCleanup(serverPort.stopListening)
552 clientProtocol = CloseAfterHandshake()
553 clientProtocolFactory = protocol.ClientFactory()
554 clientProtocolFactory.protocol = lambda: clientProtocol
555 clientConnector = reactor.connectSSL('127.0.0.1',
556 serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
# A clean close (ConnectionDone) is swallowed; any other failure is
# re-raised by trap() and fails the test.
558 def checkResult(failure):
559 failure.trap(ConnectionDone)
560 return defer.gatherResults(
561 [clientProtocol.done.addErrback(checkResult),
562 serverProtocol.done.addErrback(checkResult)])
# Tells trial to skip this test.
565 test_bothSidesLoseConnection.skip = "Old SSL code doesn't always close cleanly."
568 def testFailedVerify(self):
569 org = "twisted.test.test_ssl"
570 self.setupServerAndClient(
571 (org, org + ", client"), {},
572 (org, org + ", server"), {})
# Require peer verification on the client, using a `verify` callback defined
# in lines not visible here (presumably one that rejects the certificate).
576 self.clientCtxFactory.getContext().set_verify(SSL.VERIFY_PEER, verify)
# connectionLost receives a Failure. Passing it to Deferred.callback makes
# the Deferred's result a Failure, so the DeferredList below reports that
# entry as unsuccessful, which _cbLostConns asserts.
578 serverConnLost = defer.Deferred()
579 serverProtocol = protocol.Protocol()
580 serverProtocol.connectionLost = serverConnLost.callback
581 serverProtocolFactory = protocol.ServerFactory()
582 serverProtocolFactory.protocol = lambda: serverProtocol
583 self.serverPort = serverPort = reactor.listenSSL(0,
584 serverProtocolFactory, self.serverCtxFactory)
586 clientConnLost = defer.Deferred()
587 clientProtocol = protocol.Protocol()
588 clientProtocol.connectionLost = clientConnLost.callback
589 clientProtocolFactory = protocol.ClientFactory()
590 clientProtocolFactory.protocol = lambda: clientProtocol
591 clientConnector = reactor.connectSSL('127.0.0.1',
592 serverPort.getHost().port, clientProtocolFactory, self.clientCtxFactory)
594 dl = defer.DeferredList([serverConnLost, clientConnLost], consumeErrors=True)
595 return dl.addCallback(self._cbLostConns)
# Both sides must have failed with SSL.Error (or, on Windows only, with
# ConnectionLost); the server port is then stopped.
598 def _cbLostConns(self, results):
599 (sSuccess, sResult), (cSuccess, cResult) = results
601 self.failIf(sSuccess)
602 self.failIf(cSuccess)
604 acceptableErrors = [SSL.Error]
606 # Rather than getting a verification failure on Windows, we are getting
607 # a connection failure. Without something like sslverify proxying
608 # in-between we can't fix up the platform's errors, so let's just
609 # specifically say it is only OK in this one case to keep the tests
610 # passing. Normally we'd like to be as strict as possible here, so
611 # we're not going to allow this to report errors incorrectly on any
614 if platform.isWindows():
615 from twisted.internet.error import ConnectionLost
616 acceptableErrors.append(ConnectionLost)
618 sResult.trap(*acceptableErrors)
619 cResult.trap(*acceptableErrors)
621 return self.serverPort.stopListening()
# NOTE(review): the class statement and docstring delimiters of this test
# double are not visible in this view. Its name is presumably FakeContext,
# judging by the `_contextFactory=FakeContext` uses in the test cases below.
627 L{OpenSSL.SSL.Context} double which can more easily be inspected.
# Records the SSL method so tests can compare it with SSL.SSLv23_METHOD.
629 def __init__(self, method):
630 self._method = method
# Options accumulate with bitwise OR so tests can check individual flags.
# self._options must already be set; its initial value is assigned in a line
# not visible here (presumably 0).
634 def set_options(self, options):
635 self._options |= options
638 def use_certificate_file(self, fileName):
642 def use_privatekey_file(self, fileName):
# NOTE(review): the setUp header and the beginning of the two assertRaises
# calls are not visible in this view.
647 class DefaultOpenSSLContextFactoryTests(unittest.TestCase):
649 Tests for L{ssl.DefaultOpenSSLContextFactory}.
652 # pyOpenSSL Context objects aren't introspectable enough. Pass in
653 # an alternate context factory so we can inspect what is done to it.
654 self.contextFactory = ssl.DefaultOpenSSLContextFactory(
655 certPath, certPath, _contextFactory=FakeContext)
656 self.context = self.contextFactory.getContext()
659 def test_method(self):
661 L{ssl.DefaultOpenSSLContextFactory.getContext} returns an SSL context
662 which can use SSLv3 or TLSv1 but not SSLv2.
664 # SSLv23_METHOD allows SSLv2, SSLv3, or TLSv1
665 self.assertEqual(self.context._method, SSL.SSLv23_METHOD)
667 # And OP_NO_SSLv2 disables the SSLv2 support.
668 self.assertTrue(self.context._options & SSL.OP_NO_SSLv2)
670 # Make sure SSLv3 and TLSv1 aren't disabled though.
671 self.assertFalse(self.context._options & SSL.OP_NO_SSLv3)
672 self.assertFalse(self.context._options & SSL.OP_NO_TLSv1)
675 def test_missingCertificateFile(self):
677 Instantiating L{ssl.DefaultOpenSSLContextFactory} with a certificate
678 filename which does not identify an existing file results in the
679 initializer raising L{OpenSSL.SSL.Error}.
# Positional order is (private key file, certificate file); mktemp() yields
# a path that does not exist yet, so here the certificate is missing.
683 ssl.DefaultOpenSSLContextFactory, certPath, self.mktemp())
686 def test_missingPrivateKeyFile(self):
688 Instantiating L{ssl.DefaultOpenSSLContextFactory} with a private key
689 filename which does not identify an existing file results in the
690 initializer raising L{OpenSSL.SSL.Error}.
# Same as above with the roles swapped: the private key file is missing.
694 ssl.DefaultOpenSSLContextFactory, self.mktemp(), certPath)
# NOTE(review): the setUp header is not visible in this view.
698 class ClientContextFactoryTests(unittest.TestCase):
700 Tests for L{ssl.ClientContextFactory}.
703 self.contextFactory = ssl.ClientContextFactory()
# Swap in the inspectable context double before the context is created.
704 self.contextFactory._contextFactory = FakeContext
705 self.context = self.contextFactory.getContext()
708 def test_method(self):
710 L{ssl.ClientContextFactory.getContext} returns a context which can use
711 SSLv3 or TLSv1 but not SSLv2.
# SSLv23 method with only OP_NO_SSLv2 set: SSLv3 and TLSv1 remain enabled.
713 self.assertEqual(self.context._method, SSL.SSLv23_METHOD)
714 self.assertTrue(self.context._options & SSL.OP_NO_SSLv2)
715 self.assertFalse(self.context._options & SSL.OP_NO_SSLv3)
716 self.assertFalse(self.context._options & SSL.OP_NO_TLSv1)
if interfaces.IReactorSSL(reactor, None) is None:
    # The reactor cannot do SSL at all, so every SSL-dependent test case
    # defined above is marked as skipped, all with the same reason.
    for tCase in (StolenTCPTestCase, TLSTestCase, SpammyTLSTestCase,
                  BufferingTestCase, ConnectionLostTestCase,
                  DefaultOpenSSLContextFactoryTests,
                  ClientContextFactoryTests):
        tCase.skip = "Reactor does not support SSL, cannot run SSL tests"