Initial import to Tizen
[profile/ivi/python-pyOpenSSL.git] / OpenSSL / test / test_ssl.py
1 # Copyright (C) Jean-Paul Calderone
2 # See LICENSE for details.
3
4 """
5 Unit tests for L{OpenSSL.SSL}.
6 """
7
8 from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK
9 from sys import platform
10 from socket import error, socket
11 from os import makedirs
12 from os.path import join
13 from unittest import main
14
15 from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, FILETYPE_ASN1
16 from OpenSSL.crypto import PKey, X509, X509Extension
17 from OpenSSL.crypto import dump_privatekey, load_privatekey
18 from OpenSSL.crypto import dump_certificate, load_certificate
19
20 from OpenSSL.SSL import SENT_SHUTDOWN, RECEIVED_SHUTDOWN
21 from OpenSSL.SSL import SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD
22 from OpenSSL.SSL import OP_NO_SSLv2, OP_NO_SSLv3, OP_SINGLE_DH_USE
23 from OpenSSL.SSL import VERIFY_PEER, VERIFY_FAIL_IF_NO_PEER_CERT, VERIFY_CLIENT_ONCE
24 from OpenSSL.SSL import Error, SysCallError, WantReadError, ZeroReturnError
25 from OpenSSL.SSL import Context, ContextType, Connection, ConnectionType
26
27 from OpenSSL.test.util import TestCase, bytes, b
28 from OpenSSL.test.test_crypto import cleartextCertificatePEM, cleartextPrivateKeyPEM
29 from OpenSSL.test.test_crypto import client_cert_pem, client_key_pem
30 from OpenSSL.test.test_crypto import server_cert_pem, server_key_pem, root_cert_pem
31
32 try:
33     from OpenSSL.SSL import OP_NO_QUERY_MTU
34 except ImportError:
35     OP_NO_QUERY_MTU = None
36 try:
37     from OpenSSL.SSL import OP_COOKIE_EXCHANGE
38 except ImportError:
39     OP_COOKIE_EXCHANGE = None
40 try:
41     from OpenSSL.SSL import OP_NO_TICKET
42 except ImportError:
43     OP_NO_TICKET = None
44
45 from OpenSSL.SSL import (
46     SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK, SSL_ST_INIT, SSL_ST_BEFORE,
47     SSL_ST_OK, SSL_ST_RENEGOTIATE,
48     SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT,
49     SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP,
50     SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT,
51     SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE)
52
# Diffie-Hellman parameters written out by the load_tmp_dh tests.  They were
# generated with:
#
#   openssl dhparam -out dh-128.pem 128
#
# 128 bits is a deliberately small (insecure) size for a test fixture.
dhparam = "\n".join([
    "-----BEGIN DH PARAMETERS-----",
    "MBYCEQCobsg29c9WZP/54oAPcwiDAgEC",
    "-----END DH PARAMETERS-----",
]) + "\n"
60
61
def verify_cb(conn, cert, errnum, depth, ok):
    """
    Certificate verification callback which simply accepts OpenSSL's own
    pre-verification verdict for the certificate being checked.

    @param ok: The pre-verification result supplied by OpenSSL.
    @return: C{ok}, unchanged.
    """
    verdict = ok
    return verdict
64
def socket_pair():
    """
    Establish and return a pair of network sockets connected to each other.

    The sockets are connected over the loopback interface. They have already
    exchanged one byte in each direction as a sanity check, and they are
    returned in non-blocking mode.

    @return: A two-tuple of C{(server, client)} sockets.
    """
    # Listen only on loopback: the client below always connects to
    # 127.0.0.1, so there is no reason to accept connections from elsewhere.
    port = socket()
    try:
        port.bind(('127.0.0.1', 0))
        port.listen(1)

        # Connect a pair of sockets.  The connect is issued in non-blocking
        # mode so it does not wait for accept(); its status code (typically
        # EINPROGRESS/EWOULDBLOCK for a non-blocking connect) is deliberately
        # ignored.
        client = socket()
        client.setblocking(False)
        client.connect_ex(("127.0.0.1", port.getsockname()[1]))
        client.setblocking(True)
        server = port.accept()[0]
    finally:
        # The accepted connection does not depend on the listening socket,
        # so close it now rather than leaking a descriptor on every call.
        port.close()

    # Let's pass some unencrypted data to make sure our socket connection is
    # fine.  Just one byte, so we don't have to worry about buffers getting
    # filled up or fragmentation.
    server.send(b("x"))
    assert client.recv(1024) == b("x")
    client.send(b("y"))
    assert server.recv(1024) == b("y")

    # Most of our callers want non-blocking sockets, make it easy for them.
    server.setblocking(False)
    client.setblocking(False)

    return (server, client)
92
93
94
def handshake(client, server):
    """
    Drive the TLS handshake on two connected non-blocking L{Connection}
    objects until both sides have completed it.

    Each pass calls C{do_handshake} on every connection that has not finished
    yet. A L{WantReadError} only means that side is waiting for its peer, so
    that connection is retried on the next pass. Any other exception
    propagates to the caller.

    @param client: The client-side L{Connection}.
    @param server: The server-side L{Connection}.
    """
    pending = [client, server]
    while pending:
        # Iterate over a snapshot.  Removing an element from the list being
        # iterated would silently skip the element that follows it.
        for conn in pending[:]:
            try:
                conn.do_handshake()
            except WantReadError:
                pass
            else:
                pending.remove(conn)
105
106
class _LoopbackMixin:
    """
    Helper mixin which defines methods for creating a connected socket pair and
    for forcing two connected SSL sockets to talk to each other via memory BIOs.
    """
    def _loopback(self):
        """
        Create a connected pair of TLSv1 L{Connection} objects over a real
        socket pair and complete the handshake between them.

        The server side uses C{server_key_pem}/C{server_cert_pem}. The client
        side uses a default L{Context}.

        @return: A two-tuple of C{(server, client)} L{Connection} objects,
            both switched to blocking mode.
        """
        (server, client) = socket_pair()

        ctx = Context(TLSv1_METHOD)
        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
        ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
        server = Connection(ctx, server)
        server.set_accept_state()
        client = Connection(Context(TLSv1_METHOD), client)
        client.set_connect_state()

        handshake(client, server)

        server.setblocking(True)
        client.setblocking(True)
        return server, client


    def _interactInMemory(self, client_conn, server_conn):
        """
        Try to read application bytes from each of the two L{Connection}
        objects.  Copy bytes back and forth between their send/receive buffers
        for as long as there is anything to copy.  When there is nothing more
        to copy, return C{None}.  If one of them actually manages to deliver
        some application bytes, return a two-tuple of the connection from which
        the bytes were read and the bytes themselves.
        """
        wrote = True
        while wrote:
            # Loop until neither side has anything to say
            wrote = False

            # Copy stuff from each side's send buffer to the other side's
            # receive buffer.
            for (read, write) in [(client_conn, server_conn),
                                  (server_conn, client_conn)]:

                # Give the side a chance to generate some more bytes, or
                # succeed.  (The local is named "data" so it does not shadow
                # the builtin / module-level "bytes" name.)
                try:
                    data = read.recv(2 ** 16)
                except WantReadError:
                    # It didn't succeed, so we'll hope it generated some
                    # output.
                    pass
                else:
                    # It did succeed, so we'll stop now and let the caller deal
                    # with it.
                    return (read, data)

                while True:
                    # Keep copying as long as there's more stuff there.
                    try:
                        dirty = read.bio_read(4096)
                    except WantReadError:
                        # Okay, nothing more waiting to be sent.  Stop
                        # processing this send buffer.
                        break
                    else:
                        # Keep track of the fact that someone generated some
                        # output.
                        wrote = True
                        write.bio_write(dirty)
175
176
177
class ContextTests(TestCase, _LoopbackMixin):
    """
    Unit tests for L{OpenSSL.SSL.Context}.
    """
    def test_method(self):
        """
        L{Context} can be instantiated with one of L{SSLv2_METHOD},
        L{SSLv3_METHOD}, L{SSLv23_METHOD}, or L{TLSv1_METHOD}.
        """
        for meth in [SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD]:
            Context(meth)
        self.assertRaises(TypeError, Context, "")
        self.assertRaises(ValueError, Context, 10)


    def test_type(self):
        """
        L{Context} and L{ContextType} refer to the same type object and can be
        used to create instances of that type.
        """
        self.assertIdentical(Context, ContextType)
        self.assertConsistentType(Context, 'Context', TLSv1_METHOD)


    def test_use_privatekey(self):
        """
        L{Context.use_privatekey} takes an L{OpenSSL.crypto.PKey} instance.
        """
        key = PKey()
        key.generate_key(TYPE_RSA, 128)
        ctx = Context(TLSv1_METHOD)
        ctx.use_privatekey(key)
        self.assertRaises(TypeError, ctx.use_privatekey, "")


    def test_set_app_data_wrong_args(self):
        """
        L{Context.set_app_data} raises L{TypeError} if called with other than
        one argument.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.set_app_data)
        self.assertRaises(TypeError, context.set_app_data, None, None)


    def test_get_app_data_wrong_args(self):
        """
        L{Context.get_app_data} raises L{TypeError} if called with any
        arguments.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.get_app_data, None)


    def test_app_data(self):
        """
        L{Context.set_app_data} stores an object for later retrieval using
        L{Context.get_app_data}.
        """
        app_data = object()
        context = Context(TLSv1_METHOD)
        context.set_app_data(app_data)
        self.assertIdentical(context.get_app_data(), app_data)


    def test_set_options_wrong_args(self):
        """
        L{Context.set_options} raises L{TypeError} if called with the wrong
        number of arguments or a non-C{int} argument.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.set_options)
        self.assertRaises(TypeError, context.set_options, None)
        self.assertRaises(TypeError, context.set_options, 1, None)


    def test_set_timeout_wrong_args(self):
        """
        L{Context.set_timeout} raises L{TypeError} if called with the wrong
        number of arguments or a non-C{int} argument.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.set_timeout)
        self.assertRaises(TypeError, context.set_timeout, None)
        self.assertRaises(TypeError, context.set_timeout, 1, None)


    def test_get_timeout_wrong_args(self):
        """
        L{Context.get_timeout} raises L{TypeError} if called with any arguments.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.get_timeout, None)


    def test_timeout(self):
        """
        L{Context.set_timeout} sets the session timeout for all connections
        created using the context object.  L{Context.get_timeout} retrieves this
        value.
        """
        context = Context(TLSv1_METHOD)
        context.set_timeout(1234)
        self.assertEquals(context.get_timeout(), 1234)


    def test_set_verify_depth_wrong_args(self):
        """
        L{Context.set_verify_depth} raises L{TypeError} if called with the wrong
        number of arguments or a non-C{int} argument.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.set_verify_depth)
        self.assertRaises(TypeError, context.set_verify_depth, None)
        self.assertRaises(TypeError, context.set_verify_depth, 1, None)


    def test_get_verify_depth_wrong_args(self):
        """
        L{Context.get_verify_depth} raises L{TypeError} if called with any arguments.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.get_verify_depth, None)


    def test_verify_depth(self):
        """
        L{Context.set_verify_depth} sets the number of certificates in a chain
        to follow before giving up.  The value can be retrieved with
        L{Context.get_verify_depth}.
        """
        context = Context(TLSv1_METHOD)
        context.set_verify_depth(11)
        self.assertEquals(context.get_verify_depth(), 11)


    def _write_encrypted_pem(self, passphrase):
        """
        Write a new private key out to a new file, encrypted using the given
        passphrase.  Return the path to the new file.
        """
        key = PKey()
        key.generate_key(TYPE_RSA, 128)
        pemFile = self.mktemp()
        fObj = open(pemFile, 'w')
        pem = dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase)
        fObj.write(pem.decode('ascii'))
        fObj.close()
        return pemFile


    def _write_certificate_pem(self, *certs):
        """
        Write the given certificates, in the given order, to a new temporary
        PEM file and return the path to that file.

        Used instead of fixed names such as C{'ca.pem'} so the tests do not
        write into (and litter) the current working directory.

        @param certs: L{X509} instances to dump.
        @return: The path of the new file.
        """
        pemFile = self.mktemp()
        fObj = open(pemFile, 'w')
        try:
            for cert in certs:
                fObj.write(dump_certificate(FILETYPE_PEM, cert).decode('ascii'))
        finally:
            fObj.close()
        return pemFile


    def test_set_passwd_cb_wrong_args(self):
        """
        L{Context.set_passwd_cb} raises L{TypeError} if called with the
        wrong arguments or with a non-callable first argument.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.set_passwd_cb)
        self.assertRaises(TypeError, context.set_passwd_cb, None)
        self.assertRaises(TypeError, context.set_passwd_cb, lambda: None, None, None)


    def test_set_passwd_cb(self):
        """
        L{Context.set_passwd_cb} accepts a callable which will be invoked when
        a private key is loaded from an encrypted PEM.
        """
        passphrase = b("foobar")
        pemFile = self._write_encrypted_pem(passphrase)
        calledWith = []
        def passphraseCallback(maxlen, verify, extra):
            calledWith.append((maxlen, verify, extra))
            return passphrase
        context = Context(TLSv1_METHOD)
        context.set_passwd_cb(passphraseCallback)
        context.use_privatekey_file(pemFile)
        # The callback must run exactly once.  (assertTrue(len(...), 1) would
        # accept any non-empty list: its second argument is only a message.)
        self.assertEqual(len(calledWith), 1)
        self.assertTrue(isinstance(calledWith[0][0], int))
        self.assertTrue(isinstance(calledWith[0][1], int))
        self.assertEqual(calledWith[0][2], None)


    def test_passwd_callback_exception(self):
        """
        L{Context.use_privatekey_file} propagates any exception raised by the
        passphrase callback.
        """
        pemFile = self._write_encrypted_pem(b("monkeys are nice"))
        def passphraseCallback(maxlen, verify, extra):
            raise RuntimeError("Sorry, I am a fail.")

        context = Context(TLSv1_METHOD)
        context.set_passwd_cb(passphraseCallback)
        self.assertRaises(RuntimeError, context.use_privatekey_file, pemFile)


    def test_passwd_callback_false(self):
        """
        L{Context.use_privatekey_file} raises L{OpenSSL.SSL.Error} if the
        passphrase callback returns a false value.
        """
        pemFile = self._write_encrypted_pem(b("monkeys are nice"))
        def passphraseCallback(maxlen, verify, extra):
            return None

        context = Context(TLSv1_METHOD)
        context.set_passwd_cb(passphraseCallback)
        self.assertRaises(Error, context.use_privatekey_file, pemFile)


    def test_passwd_callback_non_string(self):
        """
        L{Context.use_privatekey_file} raises L{OpenSSL.SSL.Error} if the
        passphrase callback returns a true non-string value.
        """
        pemFile = self._write_encrypted_pem(b("monkeys are nice"))
        def passphraseCallback(maxlen, verify, extra):
            return 10

        context = Context(TLSv1_METHOD)
        context.set_passwd_cb(passphraseCallback)
        self.assertRaises(Error, context.use_privatekey_file, pemFile)


    def test_passwd_callback_too_long(self):
        """
        If the passphrase returned by the passphrase callback returns a string
        longer than the indicated maximum length, it is truncated.
        """
        # A priori knowledge!
        passphrase = b("x") * 1024
        pemFile = self._write_encrypted_pem(passphrase)
        def passphraseCallback(maxlen, verify, extra):
            assert maxlen == 1024
            return passphrase + b("y")

        context = Context(TLSv1_METHOD)
        context.set_passwd_cb(passphraseCallback)
        # This shall succeed because the truncated result is the correct
        # passphrase.
        context.use_privatekey_file(pemFile)


    def test_set_info_callback(self):
        """
        L{Context.set_info_callback} accepts a callable which will be invoked
        when certain information about an SSL connection is available.
        """
        (server, client) = socket_pair()

        clientSSL = Connection(Context(TLSv1_METHOD), client)
        clientSSL.set_connect_state()

        called = []
        def info(conn, where, ret):
            called.append((conn, where, ret))
        context = Context(TLSv1_METHOD)
        context.set_info_callback(info)
        context.use_certificate(
            load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
        context.use_privatekey(
            load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))

        serverSSL = Connection(context, server)
        serverSSL.set_accept_state()

        while not called:
            for ssl in clientSSL, serverSSL:
                try:
                    ssl.do_handshake()
                except WantReadError:
                    pass

        # Kind of lame.  Just make sure it got called somehow.
        self.assertTrue(called)


    def _load_verify_locations_test(self, *args):
        """
        Create a client context which will verify the peer certificate and call
        its C{load_verify_locations} method with C{*args}.  Then connect it to a
        server and ensure that the handshake succeeds.
        """
        (server, client) = socket_pair()

        clientContext = Context(TLSv1_METHOD)
        clientContext.load_verify_locations(*args)
        # Require that the server certificate verify properly or the
        # connection will fail.
        clientContext.set_verify(
            VERIFY_PEER,
            lambda conn, cert, errno, depth, preverify_ok: preverify_ok)

        clientSSL = Connection(clientContext, client)
        clientSSL.set_connect_state()

        serverContext = Context(TLSv1_METHOD)
        serverContext.use_certificate(
            load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
        serverContext.use_privatekey(
            load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))

        serverSSL = Connection(serverContext, server)
        serverSSL.set_accept_state()

        # Without load_verify_locations above, the handshake
        # will fail:
        # Error: [('SSL routines', 'SSL3_GET_SERVER_CERTIFICATE',
        #          'certificate verify failed')]
        handshake(clientSSL, serverSSL)

        cert = clientSSL.get_peer_certificate()
        self.assertEqual(cert.get_subject().CN, 'Testing Root CA')


    def test_load_verify_file(self):
        """
        L{Context.load_verify_locations} accepts a file name and uses the
        certificates within for verification purposes.
        """
        cafile = self.mktemp()
        fObj = open(cafile, 'w')
        fObj.write(cleartextCertificatePEM.decode('ascii'))
        fObj.close()

        self._load_verify_locations_test(cafile)


    def test_load_verify_invalid_file(self):
        """
        L{Context.load_verify_locations} raises L{Error} when passed a
        non-existent cafile.
        """
        clientContext = Context(TLSv1_METHOD)
        self.assertRaises(
            Error, clientContext.load_verify_locations, self.mktemp())


    def test_load_verify_directory(self):
        """
        L{Context.load_verify_locations} accepts a directory name and uses
        the certificates within for verification purposes.
        """
        capath = self.mktemp()
        makedirs(capath)
        # Hash value computed manually with c_rehash to avoid depending on
        # c_rehash in the test suite.
        cafile = join(capath, 'c7adac82.0')
        fObj = open(cafile, 'w')
        fObj.write(cleartextCertificatePEM.decode('ascii'))
        fObj.close()

        self._load_verify_locations_test(None, capath)


    def test_load_verify_locations_wrong_args(self):
        """
        L{Context.load_verify_locations} raises L{TypeError} if called with
        the wrong number of arguments or with non-C{str} arguments.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.load_verify_locations)
        self.assertRaises(TypeError, context.load_verify_locations, object())
        self.assertRaises(TypeError, context.load_verify_locations, object(), object())
        self.assertRaises(TypeError, context.load_verify_locations, None, None, None)


    if platform == "win32":
        "set_default_verify_paths appears not to work on Windows.  "
        "See LP#404343 and LP#404344."
    else:
        def test_set_default_verify_paths(self):
            """
            L{Context.set_default_verify_paths} causes the platform-specific CA
            certificate locations to be used for verification purposes.
            """
            # Testing this requires a server with a certificate signed by one of
            # the CAs in the platform CA location.  Getting one of those costs
            # money.  Fortunately (or unfortunately, depending on your
            # perspective), it's easy to think of a public server on the
            # internet which has such a certificate.  Connecting to the network
            # in a unit test is bad, but it's the only way I can think of to
            # really test this. -exarkun

            # Arg, verisign.com doesn't speak TLSv1
            context = Context(SSLv3_METHOD)
            context.set_default_verify_paths()
            context.set_verify(
                VERIFY_PEER,
                lambda conn, cert, errno, depth, preverify_ok: preverify_ok)

            client = socket()
            client.connect(('verisign.com', 443))
            clientSSL = Connection(context, client)
            clientSSL.set_connect_state()
            clientSSL.do_handshake()
            # Send bytes, not a native str, so this also works on Python 3.
            clientSSL.send(b('GET / HTTP/1.0\r\n\r\n'))
            self.assertTrue(clientSSL.recv(1024))


    def test_set_default_verify_paths_signature(self):
        """
        L{Context.set_default_verify_paths} takes no arguments and raises
        L{TypeError} if given any.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.set_default_verify_paths, None)
        self.assertRaises(TypeError, context.set_default_verify_paths, 1)
        self.assertRaises(TypeError, context.set_default_verify_paths, "")


    def test_add_extra_chain_cert_invalid_cert(self):
        """
        L{Context.add_extra_chain_cert} raises L{TypeError} if called with
        other than one argument or if called with an object which is not an
        instance of L{X509}.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.add_extra_chain_cert)
        self.assertRaises(TypeError, context.add_extra_chain_cert, object())
        self.assertRaises(TypeError, context.add_extra_chain_cert, object(), object())


    def _create_certificate_chain(self):
        """
        Construct and return a chain of certificates.

            1. A new self-signed certificate authority certificate (cacert)
            2. A new intermediate certificate signed by cacert (icert)
            3. A new server certificate signed by icert (scert)

        Every certificate is valid from 2000-01-01 until the end of 2049. Each
        one has a distinct serial number, so that no two certificates from the
        same issuer share a serial.

        @return: A list of C{(key, cert)} pairs in the order above.
        """
        caext = X509Extension(b('basicConstraints'), False, b('CA:true'))
        notBefore = b("20000101000000Z")
        # The original end date (2020-01-01) has passed, which made peer
        # verification fail with "certificate has expired".
        notAfter = b("20491231235959Z")

        # Step 1
        cakey = PKey()
        cakey.generate_key(TYPE_RSA, 512)
        cacert = X509()
        cacert.get_subject().commonName = "Authority Certificate"
        cacert.set_issuer(cacert.get_subject())
        cacert.set_pubkey(cakey)
        cacert.set_notBefore(notBefore)
        cacert.set_notAfter(notAfter)
        cacert.add_extensions([caext])
        cacert.set_serial_number(0)
        cacert.sign(cakey, "sha1")

        # Step 2
        ikey = PKey()
        ikey.generate_key(TYPE_RSA, 512)
        icert = X509()
        icert.get_subject().commonName = "Intermediate Certificate"
        icert.set_issuer(cacert.get_subject())
        icert.set_pubkey(ikey)
        icert.set_notBefore(notBefore)
        icert.set_notAfter(notAfter)
        icert.add_extensions([caext])
        # Same issuer as cacert, so it needs a different serial number.
        icert.set_serial_number(1)
        icert.sign(cakey, "sha1")

        # Step 3
        skey = PKey()
        skey.generate_key(TYPE_RSA, 512)
        scert = X509()
        scert.get_subject().commonName = "Server Certificate"
        scert.set_issuer(icert.get_subject())
        scert.set_pubkey(skey)
        scert.set_notBefore(notBefore)
        scert.set_notAfter(notAfter)
        scert.add_extensions([
                X509Extension(b('basicConstraints'), True, b('CA:false'))])
        scert.set_serial_number(2)
        scert.sign(ikey, "sha1")

        return [(cakey, cacert), (ikey, icert), (skey, scert)]


    def _handshake_test(self, serverContext, clientContext):
        """
        Verify that a client and server created with the given contexts can
        successfully handshake and communicate.

        The handshake is driven to completion with L{handshake}, which only
        tolerates L{WantReadError}. Any other handshake failure, such as a
        peer certificate that does not verify, therefore propagates out of
        this method. After the handshake, one byte is sent in each direction.
        """
        serverSocket, clientSocket = socket_pair()

        server = Connection(serverContext, serverSocket)
        server.set_accept_state()

        client = Connection(clientContext, clientSocket)
        client.set_connect_state()

        # Make them talk to each other.  Unlike a fixed number of rounds, this
        # only returns once both sides have actually finished the handshake.
        handshake(client, server)

        # Switch to blocking mode so the reads below wait for the data rather
        # than raising WantReadError.
        server.setblocking(True)
        client.setblocking(True)

        client.send(b("x"))
        self.assertEqual(server.recv(1), b("x"))
        server.send(b("y"))
        self.assertEqual(client.recv(1), b("y"))


    def test_add_extra_chain_cert(self):
        """
        L{Context.add_extra_chain_cert} accepts an L{X509} instance to add to
        the certificate chain.

        See L{_create_certificate_chain} for the details of the certificate
        chain tested.

        The chain is tested by starting a server with scert and connecting
        to it with a client which trusts cacert and requires verification to
        succeed.
        """
        chain = self._create_certificate_chain()
        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain

        # Dump the CA certificate to a file because that's the only way to load
        # it as a trusted CA in the client context.
        caFile = self._write_certificate_pem(cacert)

        # Create the server context
        serverContext = Context(TLSv1_METHOD)
        serverContext.use_privatekey(skey)
        serverContext.use_certificate(scert)
        # The client already has cacert, we only need to give them icert.
        serverContext.add_extra_chain_cert(icert)

        # Create the client
        clientContext = Context(TLSv1_METHOD)
        clientContext.set_verify(
            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb)
        clientContext.load_verify_locations(caFile)

        # Try it out.
        self._handshake_test(serverContext, clientContext)


    def test_use_certificate_chain_file(self):
        """
        L{Context.use_certificate_chain_file} reads a certificate chain from
        the specified file.

        The chain is tested by starting a server with scert and connecting
        to it with a client which trusts cacert and requires verification to
        succeed.
        """
        chain = self._create_certificate_chain()
        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain

        # Write out the chain file, most specific to least general.
        chainFile = self._write_certificate_pem(scert, icert, cacert)

        serverContext = Context(TLSv1_METHOD)
        serverContext.use_certificate_chain_file(chainFile)
        serverContext.use_privatekey(skey)

        caFile = self._write_certificate_pem(cacert)

        clientContext = Context(TLSv1_METHOD)
        clientContext.set_verify(
            VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb)
        clientContext.load_verify_locations(caFile)

        self._handshake_test(serverContext, clientContext)

    # XXX load_client_ca
    # XXX set_session_id

    def test_get_verify_mode_wrong_args(self):
        """
        L{Context.get_verify_mode} raises L{TypeError} if called with any
        arguments.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.get_verify_mode, None)


    def test_get_verify_mode(self):
        """
        L{Context.get_verify_mode} returns the verify mode flags previously
        passed to L{Context.set_verify}.
        """
        context = Context(TLSv1_METHOD)
        self.assertEquals(context.get_verify_mode(), 0)
        context.set_verify(
            VERIFY_PEER | VERIFY_CLIENT_ONCE, lambda *args: None)
        self.assertEquals(
            context.get_verify_mode(), VERIFY_PEER | VERIFY_CLIENT_ONCE)


    def test_load_tmp_dh_wrong_args(self):
        """
        L{Context.load_tmp_dh} raises L{TypeError} if called with the wrong
        number of arguments or with a non-C{str} argument.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(TypeError, context.load_tmp_dh)
        self.assertRaises(TypeError, context.load_tmp_dh, "foo", None)
        self.assertRaises(TypeError, context.load_tmp_dh, object())


    def test_load_tmp_dh_missing_file(self):
        """
        L{Context.load_tmp_dh} raises L{OpenSSL.SSL.Error} if the specified file
        does not exist.
        """
        context = Context(TLSv1_METHOD)
        self.assertRaises(Error, context.load_tmp_dh, "hello")


    def test_load_tmp_dh(self):
        """
        L{Context.load_tmp_dh} loads Diffie-Hellman parameters from the
        specified file.
        """
        context = Context(TLSv1_METHOD)
        dhfilename = self.mktemp()
        dhfile = open(dhfilename, "w")
        dhfile.write(dhparam)
        dhfile.close()
        context.load_tmp_dh(dhfilename)
        # XXX What should I assert here? -exarkun


    def test_set_cipher_list(self):
        """
        L{Context.set_cipher_list} accepts a C{str} naming the ciphers which
        connections created with the context object will be able to choose from.
        """
        context = Context(TLSv1_METHOD)
        context.set_cipher_list("hello world:EXP-RC4-MD5")
        conn = Connection(context, None)
        self.assertEquals(conn.get_cipher_list(), ["EXP-RC4-MD5"])
825
826
827
class ConnectionTests(TestCase, _LoopbackMixin):
    """
    Unit tests for L{OpenSSL.SSL.Connection}.
    """
    # XXX want_write
    # XXX want_read
    # XXX get_peer_certificate -> None
    # XXX sock_shutdown
    # XXX master_key -> TypeError
    # XXX server_random -> TypeError
    # XXX state_string
    # XXX connect -> TypeError
    # XXX connect_ex -> TypeError
    # XXX set_connect_state -> TypeError
    # XXX set_accept_state -> TypeError
    # XXX renegotiate_pending
    # XXX do_handshake -> TypeError
    # XXX bio_read -> TypeError
    # XXX recv -> TypeError
    # XXX send -> TypeError
    # XXX bio_write -> TypeError

    def test_type(self):
        """
        L{Connection} and L{ConnectionType} refer to the same type object and
        can be used to create instances of that type.
        """
        self.assertIdentical(Connection, ConnectionType)
        ctx = Context(TLSv1_METHOD)
        self.assertConsistentType(Connection, 'Connection', ctx, None)


    def test_get_context(self):
        """
        L{Connection.get_context} returns the L{Context} instance used to
        construct the L{Connection} instance.
        """
        context = Context(TLSv1_METHOD)
        connection = Connection(context, None)
        self.assertIdentical(connection.get_context(), context)


    def test_get_context_wrong_args(self):
        """
        L{Connection.get_context} raises L{TypeError} if called with any
        arguments.
        """
        connection = Connection(Context(TLSv1_METHOD), None)
        self.assertRaises(TypeError, connection.get_context, None)


    def test_pending(self):
        """
        L{Connection.pending} returns the number of bytes available for
        immediate read.
        """
        connection = Connection(Context(TLSv1_METHOD), None)
        self.assertEquals(connection.pending(), 0)


    def test_pending_wrong_args(self):
        """
        L{Connection.pending} raises L{TypeError} if called with any arguments.
        """
        connection = Connection(Context(TLSv1_METHOD), None)
        self.assertRaises(TypeError, connection.pending, None)


    def test_connect_wrong_args(self):
        """
        L{Connection.connect} raises L{TypeError} if called with a non-address
        argument or with the wrong number of arguments.
        """
        connection = Connection(Context(TLSv1_METHOD), socket())
        self.assertRaises(TypeError, connection.connect, None)
        self.assertRaises(TypeError, connection.connect)
        self.assertRaises(TypeError, connection.connect, ("127.0.0.1", 1), None)


    def test_connect_refused(self):
        """
        L{Connection.connect} raises L{socket.error} if the underlying socket
        connect method raises it.
        """
        client = socket()
        context = Context(TLSv1_METHOD)
        clientSSL = Connection(context, client)
        exc = self.assertRaises(error, clientSSL.connect, ("127.0.0.1", 1))
        self.assertEquals(exc.args[0], ECONNREFUSED)


    def test_connect(self):
        """
        L{Connection.connect} establishes a connection to the specified address.
        """
        port = socket()
        port.bind(('', 0))
        port.listen(3)

        clientSSL = Connection(Context(TLSv1_METHOD), socket())
        clientSSL.connect(('127.0.0.1', port.getsockname()[1]))
        # XXX An assertion?  Or something?


    if platform == "darwin":
        "connect_ex sometimes causes a kernel panic on OS X 10.6.4"
    else:
        def test_connect_ex(self):
            """
            If the underlying non-blocking connect cannot complete at once,
            L{Connection.connect_ex} returns the errno (C{EINPROGRESS} or
            C{EWOULDBLOCK}) instead of raising an exception.
            """
            port = socket()
            port.bind(('', 0))
            port.listen(3)

            clientSSL = Connection(Context(TLSv1_METHOD), socket())
            clientSSL.setblocking(False)
            # Connect to the loopback address explicitly, like the other
            # tests in this class.  port.getsockname() reports the wildcard
            # address the socket was bound to, which is not a usable
            # destination on every platform (see the Windows note in
            # test_accept).
            result = clientSSL.connect_ex(('127.0.0.1', port.getsockname()[1]))
            expected = (EINPROGRESS, EWOULDBLOCK)
            self.assertTrue(
                    result in expected, "%r not in %r" % (result, expected))


    def test_accept_wrong_args(self):
        """
        L{Connection.accept} raises L{TypeError} if called with any arguments.
        """
        connection = Connection(Context(TLSv1_METHOD), socket())
        self.assertRaises(TypeError, connection.accept, None)


    def test_accept(self):
        """
        L{Connection.accept} accepts a pending connection attempt and returns a
        tuple of a new L{Connection} (the accepted client) and the address the
        connection originated from.
        """
        ctx = Context(TLSv1_METHOD)
        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
        ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
        port = socket()
        portSSL = Connection(ctx, port)
        portSSL.bind(('', 0))
        portSSL.listen(3)

        clientSSL = Connection(Context(TLSv1_METHOD), socket())

        # Calling portSSL.getsockname() here to get the server IP address sounds
        # great, but frequently fails on Windows.
        clientSSL.connect(('127.0.0.1', portSSL.getsockname()[1]))

        serverSSL, address = portSSL.accept()

        self.assertTrue(isinstance(serverSSL, Connection))
        self.assertIdentical(serverSSL.get_context(), ctx)
        self.assertEquals(address, clientSSL.getsockname())


    def test_shutdown_wrong_args(self):
        """
        L{Connection.shutdown} raises L{TypeError} if called with the wrong
        number of arguments or with arguments other than integers.
        """
        connection = Connection(Context(TLSv1_METHOD), None)
        self.assertRaises(TypeError, connection.shutdown, None)
        self.assertRaises(TypeError, connection.get_shutdown, None)
        self.assertRaises(TypeError, connection.set_shutdown)
        self.assertRaises(TypeError, connection.set_shutdown, None)
        self.assertRaises(TypeError, connection.set_shutdown, 0, 1)


    def test_shutdown(self):
        """
        L{Connection.shutdown} performs an SSL-level connection shutdown.
        """
        server, client = self._loopback()
        self.assertFalse(server.shutdown())
        self.assertEquals(server.get_shutdown(), SENT_SHUTDOWN)
        self.assertRaises(ZeroReturnError, client.recv, 1024)
        self.assertEquals(client.get_shutdown(), RECEIVED_SHUTDOWN)
        client.shutdown()
        self.assertEquals(client.get_shutdown(), SENT_SHUTDOWN|RECEIVED_SHUTDOWN)
        self.assertRaises(ZeroReturnError, server.recv, 1024)
        self.assertEquals(server.get_shutdown(), SENT_SHUTDOWN|RECEIVED_SHUTDOWN)


    def test_set_shutdown(self):
        """
        L{Connection.set_shutdown} sets the state of the SSL connection shutdown
        process.
        """
        connection = Connection(Context(TLSv1_METHOD), socket())
        connection.set_shutdown(RECEIVED_SHUTDOWN)
        self.assertEquals(connection.get_shutdown(), RECEIVED_SHUTDOWN)


    def test_app_data_wrong_args(self):
        """
        L{Connection.set_app_data} raises L{TypeError} if called with other than
        one argument.  L{Connection.get_app_data} raises L{TypeError} if called
        with any arguments.
        """
        conn = Connection(Context(TLSv1_METHOD), None)
        self.assertRaises(TypeError, conn.get_app_data, None)
        self.assertRaises(TypeError, conn.set_app_data)
        self.assertRaises(TypeError, conn.set_app_data, None, None)


    def test_app_data(self):
        """
        Any object can be set as app data by passing it to
        L{Connection.set_app_data} and later retrieved with
        L{Connection.get_app_data}.
        """
        conn = Connection(Context(TLSv1_METHOD), None)
        app_data = object()
        conn.set_app_data(app_data)
        self.assertIdentical(conn.get_app_data(), app_data)


    def test_makefile(self):
        """
        L{Connection.makefile} is not implemented and calling that method raises
        L{NotImplementedError}.
        """
        conn = Connection(Context(TLSv1_METHOD), None)
        self.assertRaises(NotImplementedError, conn.makefile)
1056
1057
1058
class ConnectionGetCipherListTests(TestCase):
    """
    Tests for L{Connection.get_cipher_list}, which reports the names of the
    ciphers a connection may use.
    """
    def test_wrong_args(self):
        """
        L{Connection.get_cipher_list} takes no arguments; passing one raises
        L{TypeError}.
        """
        conn = Connection(Context(TLSv1_METHOD), None)
        self.assertRaises(TypeError, conn.get_cipher_list, None)


    def test_result(self):
        """
        The value returned by L{Connection.get_cipher_list} is a C{list}, and
        every element of it is a C{str} cipher name.
        """
        conn = Connection(Context(TLSv1_METHOD), None)
        names = conn.get_cipher_list()
        self.assertTrue(isinstance(names, list))
        for name in names:
            self.assertTrue(isinstance(name, str))
1082
1083
1084
class ConnectionSendTests(TestCase, _LoopbackMixin):
    """
    Tests for the L{Connection.send} method.
    """
    def test_wrong_args(self):
        """
        L{Connection.send} insists on exactly one string argument.  Calling it
        with nothing, with a non-string, or with two arguments raises
        L{TypeError}.
        """
        conn = Connection(Context(TLSv1_METHOD), None)
        for bad_args in [(), (object(),), ("foo", "bar")]:
            self.assertRaises(TypeError, conn.send, *bad_args)


    def test_short_bytes(self):
        """
        A short byte string given to L{Connection.send} is written in full,
        and the return value is the number of bytes written.
        """
        sender, receiver = self._loopback()
        payload = b('xy')
        self.assertEquals(sender.send(payload), 2)
        self.assertEquals(receiver.recv(2), payload)

    # memoryview only exists on newer Pythons; the test is defined only when
    # the name resolves.
    try:
        memoryview
    except NameError:
        "cannot test sending memoryview without memoryview"
    else:
        def test_short_memoryview(self):
            """
            A memoryview over a few bytes is accepted by L{Connection.send},
            which writes all of them and returns how many it wrote.
            """
            sender, receiver = self._loopback()
            written = sender.send(memoryview(b('xy')))
            self.assertEquals(written, 2)
            self.assertEquals(receiver.recv(2), b('xy'))
1125
1126
1127
class ConnectionSendallTests(TestCase, _LoopbackMixin):
    """
    Tests for L{Connection.sendall}.
    """
    def test_wrong_args(self):
        """
        When called with arguments other than a single string,
        L{Connection.sendall} raises L{TypeError}.
        """
        connection = Connection(Context(TLSv1_METHOD), None)
        self.assertRaises(TypeError, connection.sendall)
        self.assertRaises(TypeError, connection.sendall, object())
        self.assertRaises(TypeError, connection.sendall, "foo", "bar")


    def test_short(self):
        """
        L{Connection.sendall} transmits all of the bytes in the string passed to
        it.
        """
        server, client = self._loopback()
        server.sendall(b('x'))
        self.assertEquals(client.recv(1), b('x'))


    try:
        memoryview
    except NameError:
        "cannot test sending memoryview without memoryview"
    else:
        def test_short_memoryview(self):
            """
            When passed a memoryview onto a small number of bytes,
            L{Connection.sendall} transmits all of them.
            """
            server, client = self._loopback()
            server.sendall(memoryview(b('x')))
            self.assertEquals(client.recv(1), b('x'))


    def test_long(self):
        """
        L{Connection.sendall} transmits all of the bytes in the string passed to
        it even if this requires multiple calls of an underlying write function.
        """
        server, client = self._loopback()
        # Should be enough, underlying SSL_write should only do 16k at a time.
        # On Windows, after 32k of bytes the write will block (forever - because
        # no one is yet reading).
        message = b('x') * (1024 * 32 - 1) + b('y')
        server.sendall(message)
        accum = []
        received = 0
        while received < len(message):
            data = client.recv(1024)
            accum.append(data)
            received += len(data)
        self.assertEquals(message, b('').join(accum))


    def test_closed(self):
        """
        If the underlying socket is closed, L{Connection.sendall} propagates the
        write error from the low level write call.
        """
        server, client = self._loopback()
        server.sock_shutdown(2)
        # Send bytes (via b()), as every other test in this class does, so
        # the only reason for failure is the closed socket and not the
        # argument type.
        self.assertRaises(SysCallError, server.sendall, b("hello, world"))
1196
1197
1198
class ConnectionRenegotiateTests(TestCase, _LoopbackMixin):
    """
    Tests for the renegotiation-related methods of L{Connection}.
    """
    def test_renegotiate_wrong_args(self):
        """
        L{Connection.renegotiate} takes no arguments; passing one raises
        L{TypeError}.
        """
        conn = Connection(Context(TLSv1_METHOD), None)
        self.assertRaises(TypeError, conn.renegotiate, None)


    def test_total_renegotiations_wrong_args(self):
        """
        L{Connection.total_renegotiations} takes no arguments; passing one
        raises L{TypeError}.
        """
        conn = Connection(Context(TLSv1_METHOD), None)
        self.assertRaises(TypeError, conn.total_renegotiations, None)


    def test_total_renegotiations(self):
        """
        A L{Connection} that has never renegotiated reports C{0} from
        L{Connection.total_renegotiations}.
        """
        conn = Connection(Context(TLSv1_METHOD), None)
        count = conn.total_renegotiations()
        self.assertEquals(count, 0)


    # NOTE(review): an end-to-end renegotiation test over a loopback pair is
    # commented out below; presumably it does not pass yet -- TODO confirm.
    #
    # def test_renegotiate(self):
    #     """
    #     """
    #     server, client = self._loopback()
    #
    #     server.send("hello world")
    #     self.assertEquals(client.recv(len("hello world")), "hello world")
    #
    #     self.assertEquals(server.total_renegotiations(), 0)
    #     self.assertTrue(server.renegotiate())
    #
    #     server.setblocking(False)
    #     client.setblocking(False)
    #     while server.renegotiate_pending():
    #         client.do_handshake()
    #         server.do_handshake()
    #
    #     self.assertEquals(server.total_renegotiations(), 1)
1248
1249
1250
1251
class ErrorTests(TestCase):
    """
    Tests for the L{OpenSSL.SSL.Error} exception class.
    """
    def test_type(self):
        """
        L{Error} is a subclass of L{Exception} whose C{__name__} is
        C{'Error'}.
        """
        is_exception = issubclass(Error, Exception)
        self.assertTrue(is_exception)
        self.assertEqual(Error.__name__, 'Error')
1262
1263
1264
class ConstantsTests(TestCase):
    """
    Checks on the numeric values of option constants exported by
    L{OpenSSL.SSL}.

    OpenSSL defines these only as flags for its own APIs, so their exact
    values are about the only thing that can be asserted about them.
    """
    # Each test is defined conditionally because unittest.TestCase offers no
    # skip mechanism.  When a constant is missing, its test does not exist.
    if OP_NO_QUERY_MTU is None:
        "OP_NO_QUERY_MTU unavailable - OpenSSL version may be too old"
    else:
        def test_op_no_query_mtu(self):
            """
            L{OpenSSL.SSL.OP_NO_QUERY_MTU} equals 0x1000, matching
            I{SSL_OP_NO_QUERY_MTU} from I{openssl/ssl.h}.
            """
            self.assertEqual(OP_NO_QUERY_MTU, 0x1000)


    if OP_COOKIE_EXCHANGE is None:
        "OP_COOKIE_EXCHANGE unavailable - OpenSSL version may be too old"
    else:
        def test_op_cookie_exchange(self):
            """
            L{OpenSSL.SSL.OP_COOKIE_EXCHANGE} equals 0x2000, matching
            I{SSL_OP_COOKIE_EXCHANGE} from I{openssl/ssl.h}.
            """
            self.assertEqual(OP_COOKIE_EXCHANGE, 0x2000)


    if OP_NO_TICKET is None:
        "OP_NO_TICKET unavailable - OpenSSL version may be too old"
    else:
        def test_op_no_ticket(self):
            """
            L{OpenSSL.SSL.OP_NO_TICKET} equals 0x4000, matching
            I{SSL_OP_NO_TICKET} from I{openssl/ssl.h}.
            """
            self.assertEqual(OP_NO_TICKET, 0x4000)
1305
1306
1307
1308 class MemoryBIOTests(TestCase, _LoopbackMixin):
1309     """
1310     Tests for L{OpenSSL.SSL.Connection} using a memory BIO.
1311     """
1312     def _server(self, sock):
1313         """
1314         Create a new server-side SSL L{Connection} object wrapped around
1315         C{sock}.
1316         """
1317         # Create the server side Connection.  This is mostly setup boilerplate
1318         # - use TLSv1, use a particular certificate, etc.
1319         server_ctx = Context(TLSv1_METHOD)
1320         server_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE )
1321         server_ctx.set_verify(VERIFY_PEER|VERIFY_FAIL_IF_NO_PEER_CERT|VERIFY_CLIENT_ONCE, verify_cb)
1322         server_store = server_ctx.get_cert_store()
1323         server_ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
1324         server_ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
1325         server_ctx.check_privatekey()
1326         server_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem))
1327         # Here the Connection is actually created.  If None is passed as the 2nd
1328         # parameter, it indicates a memory BIO should be created.
1329         server_conn = Connection(server_ctx, sock)
1330         server_conn.set_accept_state()
1331         return server_conn
1332
1333
1334     def _client(self, sock):
1335         """
1336         Create a new client-side SSL L{Connection} object wrapped around
1337         C{sock}.
1338         """
1339         # Now create the client side Connection.  Similar boilerplate to the
1340         # above.
1341         client_ctx = Context(TLSv1_METHOD)
1342         client_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE )
1343         client_ctx.set_verify(VERIFY_PEER|VERIFY_FAIL_IF_NO_PEER_CERT|VERIFY_CLIENT_ONCE, verify_cb)
1344         client_store = client_ctx.get_cert_store()
1345         client_ctx.use_privatekey(load_privatekey(FILETYPE_PEM, client_key_pem))
1346         client_ctx.use_certificate(load_certificate(FILETYPE_PEM, client_cert_pem))
1347         client_ctx.check_privatekey()
1348         client_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem))
1349         client_conn = Connection(client_ctx, sock)
1350         client_conn.set_connect_state()
1351         return client_conn
1352
1353
1354     def test_memoryConnect(self):
1355         """
1356         Two L{Connection}s which use memory BIOs can be manually connected by
1357         reading from the output of each and writing those bytes to the input of
1358         the other and in this way establish a connection and exchange
1359         application-level bytes with each other.
1360         """
1361         server_conn = self._server(None)
1362         client_conn = self._client(None)
1363
1364         # There should be no key or nonces yet.
1365         self.assertIdentical(server_conn.master_key(), None)
1366         self.assertIdentical(server_conn.client_random(), None)
1367         self.assertIdentical(server_conn.server_random(), None)
1368
1369         # First, the handshake needs to happen.  We'll deliver bytes back and
1370         # forth between the client and server until neither of them feels like
1371         # speaking any more.
1372         self.assertIdentical(
1373             self._interactInMemory(client_conn, server_conn), None)
1374
1375         # Now that the handshake is done, there should be a key and nonces.
1376         self.assertNotIdentical(server_conn.master_key(), None)
1377         self.assertNotIdentical(server_conn.client_random(), None)
1378         self.assertNotIdentical(server_conn.server_random(), None)
1379         self.assertEquals(server_conn.client_random(), client_conn.client_random())
1380         self.assertEquals(server_conn.server_random(), client_conn.server_random())
1381         self.assertNotEquals(server_conn.client_random(), server_conn.server_random())
1382         self.assertNotEquals(client_conn.client_random(), client_conn.server_random())
1383
1384         # Here are the bytes we'll try to send.
1385         important_message = b('One if by land, two if by sea.')
1386
1387         server_conn.write(important_message)
1388         self.assertEquals(
1389             self._interactInMemory(client_conn, server_conn),
1390             (client_conn, important_message))
1391
1392         client_conn.write(important_message[::-1])
1393         self.assertEquals(
1394             self._interactInMemory(client_conn, server_conn),
1395             (server_conn, important_message[::-1]))
1396
1397
1398     def test_socketConnect(self):
1399         """
1400         Just like L{test_memoryConnect} but with an actual socket.
1401
1402         This is primarily to rule out the memory BIO code as the source of
1403         any problems encountered while passing data over a L{Connection} (if
1404         this test fails, there must be a problem outside the memory BIO
1405         code, as no memory BIO is involved here).  Even though this isn't a
1406         memory BIO test, it's convenient to have it here.
1407         """
1408         server_conn, client_conn = self._loopback()
1409
1410         important_message = b("Help me Obi Wan Kenobi, you're my only hope.")
1411         client_conn.send(important_message)
1412         msg = server_conn.recv(1024)
1413         self.assertEqual(msg, important_message)
1414
1415         # Again in the other direction, just for fun.
1416         important_message = important_message[::-1]
1417         server_conn.send(important_message)
1418         msg = client_conn.recv(1024)
1419         self.assertEqual(msg, important_message)
1420
1421
1422     def test_socketOverridesMemory(self):
1423         """
1424         Test that L{OpenSSL.SSL.bio_read} and L{OpenSSL.SSL.bio_write} don't
1425         work on L{OpenSSL.SSL.Connection}() that use sockets.
1426         """
1427         context = Context(SSLv3_METHOD)
1428         client = socket()
1429         clientSSL = Connection(context, client)
1430         self.assertRaises( TypeError, clientSSL.bio_read, 100)
1431         self.assertRaises( TypeError, clientSSL.bio_write, "foo")
1432         self.assertRaises( TypeError, clientSSL.bio_shutdown )
1433
1434
1435     def test_outgoingOverflow(self):
1436         """
1437         If more bytes than can be written to the memory BIO are passed to
1438         L{Connection.send} at once, the number of bytes which were written is
1439         returned and that many bytes from the beginning of the input can be
1440         read from the other end of the connection.
1441         """
1442         server = self._server(None)
1443         client = self._client(None)
1444
1445         self._interactInMemory(client, server)
1446
1447         size = 2 ** 15
1448         sent = client.send("x" * size)
1449         # Sanity check.  We're trying to test what happens when the entire
1450         # input can't be sent.  If the entire input was sent, this test is
1451         # meaningless.
1452         self.assertTrue(sent < size)
1453
1454         receiver, received = self._interactInMemory(client, server)
1455         self.assertIdentical(receiver, server)
1456
1457         # We can rely on all of these bytes being received at once because
1458         # _loopback passes 2 ** 16 to recv - more than 2 ** 15.
1459         self.assertEquals(len(received), sent)
1460
1461
1462     def test_shutdown(self):
1463         """
1464         L{Connection.bio_shutdown} signals the end of the data stream from
1465         which the L{Connection} reads.
1466         """
1467         server = self._server(None)
1468         server.bio_shutdown()
1469         e = self.assertRaises(Error, server.recv, 1024)
1470         # We don't want WantReadError or ZeroReturnError or anything - it's a
1471         # handshake failure.
1472         self.assertEquals(e.__class__, Error)
1473
1474
1475     def _check_client_ca_list(self, func):
1476         """
1477         Verify the return value of the C{get_client_ca_list} method for server and client connections.
1478
1479         @param func: A function which will be called with the server context
1480             before the client and server are connected to each other.  This
1481             function should specify a list of CAs for the server to send to the
1482             client and return that same list.  The list will be used to verify
1483             that C{get_client_ca_list} returns the proper value at various
1484             times.
1485         """
1486         server = self._server(None)
1487         client = self._client(None)
1488         self.assertEqual(client.get_client_ca_list(), [])
1489         self.assertEqual(server.get_client_ca_list(), [])
1490         ctx = server.get_context()
1491         expected = func(ctx)
1492         self.assertEqual(client.get_client_ca_list(), [])
1493         self.assertEqual(server.get_client_ca_list(), expected)
1494         self._interactInMemory(client, server)
1495         self.assertEqual(client.get_client_ca_list(), expected)
1496         self.assertEqual(server.get_client_ca_list(), expected)
1497
1498
1499     def test_set_client_ca_list_errors(self):
1500         """
1501         L{Context.set_client_ca_list} raises a L{TypeError} if called with a
1502         non-list or a list that contains objects other than X509Names.
1503         """
1504         ctx = Context(TLSv1_METHOD)
1505         self.assertRaises(TypeError, ctx.set_client_ca_list, "spam")
1506         self.assertRaises(TypeError, ctx.set_client_ca_list, ["spam"])
1507         self.assertIdentical(ctx.set_client_ca_list([]), None)
1508
1509
1510     def test_set_empty_ca_list(self):
1511         """
1512         If passed an empty list, L{Context.set_client_ca_list} configures the
1513         context to send no CA names to the client and, on both the server and
1514         client sides, L{Connection.get_client_ca_list} returns an empty list
1515         after the connection is set up.
1516         """
1517         def no_ca(ctx):
1518             ctx.set_client_ca_list([])
1519             return []
1520         self._check_client_ca_list(no_ca)
1521
1522
1523     def test_set_one_ca_list(self):
1524         """
1525         If passed a list containing a single X509Name,
1526         L{Context.set_client_ca_list} configures the context to send that CA
1527         name to the client and, on both the server and client sides,
1528         L{Connection.get_client_ca_list} returns a list containing that
1529         X509Name after the connection is set up.
1530         """
1531         cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
1532         cadesc = cacert.get_subject()
1533         def single_ca(ctx):
1534             ctx.set_client_ca_list([cadesc])
1535             return [cadesc]
1536         self._check_client_ca_list(single_ca)
1537
1538
1539     def test_set_multiple_ca_list(self):
1540         """
1541         If passed a list containing multiple X509Name objects,
1542         L{Context.set_client_ca_list} configures the context to send those CA
1543         names to the client and, on both the server and client sides,
1544         L{Connection.get_client_ca_list} returns a list containing those
1545         X509Names after the connection is set up.
1546         """
1547         secert = load_certificate(FILETYPE_PEM, server_cert_pem)
1548         clcert = load_certificate(FILETYPE_PEM, server_cert_pem)
1549
1550         sedesc = secert.get_subject()
1551         cldesc = clcert.get_subject()
1552
1553         def multiple_ca(ctx):
1554             L = [sedesc, cldesc]
1555             ctx.set_client_ca_list(L)
1556             return L
1557         self._check_client_ca_list(multiple_ca)
1558
1559
1560     def test_reset_ca_list(self):
1561         """
1562         If called multiple times, only the X509Names passed to the final call
1563         of L{Context.set_client_ca_list} are used to configure the CA names
1564         sent to the client.
1565         """
1566         cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
1567         secert = load_certificate(FILETYPE_PEM, server_cert_pem)
1568         clcert = load_certificate(FILETYPE_PEM, server_cert_pem)
1569
1570         cadesc = cacert.get_subject()
1571         sedesc = secert.get_subject()
1572         cldesc = clcert.get_subject()
1573
1574         def changed_ca(ctx):
1575             ctx.set_client_ca_list([sedesc, cldesc])
1576             ctx.set_client_ca_list([cadesc])
1577             return [cadesc]
1578         self._check_client_ca_list(changed_ca)
1579
1580
1581     def test_mutated_ca_list(self):
1582         """
1583         If the list passed to L{Context.set_client_ca_list} is mutated
1584         afterwards, this does not affect the list of CA names sent to the
1585         client.
1586         """
1587         cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
1588         secert = load_certificate(FILETYPE_PEM, server_cert_pem)
1589
1590         cadesc = cacert.get_subject()
1591         sedesc = secert.get_subject()
1592
1593         def mutated_ca(ctx):
1594             L = [cadesc]
1595             ctx.set_client_ca_list([cadesc])
1596             L.append(sedesc)
1597             return [cadesc]
1598         self._check_client_ca_list(mutated_ca)
1599
1600
1601     def test_add_client_ca_errors(self):
1602         """
1603         L{Context.add_client_ca} raises L{TypeError} if called with a non-X509
1604         object or with a number of arguments other than one.
1605         """
1606         ctx = Context(TLSv1_METHOD)
1607         cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
1608         self.assertRaises(TypeError, ctx.add_client_ca)
1609         self.assertRaises(TypeError, ctx.add_client_ca, "spam")
1610         self.assertRaises(TypeError, ctx.add_client_ca, cacert, cacert)
1611
1612
1613     def test_one_add_client_ca(self):
1614         """
1615         A certificate's subject can be added as a CA to be sent to the client
1616         with L{Context.add_client_ca}.
1617         """
1618         cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
1619         cadesc = cacert.get_subject()
1620         def single_ca(ctx):
1621             ctx.add_client_ca(cacert)
1622             return [cadesc]
1623         self._check_client_ca_list(single_ca)
1624
1625
1626     def test_multiple_add_client_ca(self):
1627         """
1628         Multiple CA names can be sent to the client by calling
1629         L{Context.add_client_ca} with multiple X509 objects.
1630         """
1631         cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
1632         secert = load_certificate(FILETYPE_PEM, server_cert_pem)
1633
1634         cadesc = cacert.get_subject()
1635         sedesc = secert.get_subject()
1636
1637         def multiple_ca(ctx):
1638             ctx.add_client_ca(cacert)
1639             ctx.add_client_ca(secert)
1640             return [cadesc, sedesc]
1641         self._check_client_ca_list(multiple_ca)
1642
1643
1644     def test_set_and_add_client_ca(self):
1645         """
1646         A call to L{Context.set_client_ca_list} followed by a call to
1647         L{Context.add_client_ca} results in using the CA names from the first
1648         call and the CA name from the second call.
1649         """
1650         cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
1651         secert = load_certificate(FILETYPE_PEM, server_cert_pem)
1652         clcert = load_certificate(FILETYPE_PEM, server_cert_pem)
1653
1654         cadesc = cacert.get_subject()
1655         sedesc = secert.get_subject()
1656         cldesc = clcert.get_subject()
1657
1658         def mixed_set_add_ca(ctx):
1659             ctx.set_client_ca_list([cadesc, sedesc])
1660             ctx.add_client_ca(clcert)
1661             return [cadesc, sedesc, cldesc]
1662         self._check_client_ca_list(mixed_set_add_ca)
1663
1664
1665     def test_set_after_add_client_ca(self):
1666         """
1667         A call to L{Context.set_client_ca_list} after a call to
1668         L{Context.add_client_ca} replaces the CA name specified by the former
1669         call with the names specified by the latter cal.
1670         """
1671         cacert = load_certificate(FILETYPE_PEM, root_cert_pem)
1672         secert = load_certificate(FILETYPE_PEM, server_cert_pem)
1673         clcert = load_certificate(FILETYPE_PEM, server_cert_pem)
1674
1675         cadesc = cacert.get_subject()
1676         sedesc = secert.get_subject()
1677
1678         def set_replaces_add_ca(ctx):
1679             ctx.add_client_ca(clcert)
1680             ctx.set_client_ca_list([cadesc])
1681             ctx.add_client_ca(secert)
1682             return [cadesc, sedesc]
1683         self._check_client_ca_list(set_replaces_add_ca)
1684
1685
class InfoConstantTests(TestCase):
    """
    Tests for the miscellaneous constants that L{OpenSSL.SSL} exposes for
    use in info callbacks.
    """
    def test_integers(self):
        """
        Every info-callback constant is an integer.

        This checks types only.  A stronger test would provoke real info
        events and confirm that the values passed to the info callback agree
        with the constants exposed by OpenSSL.SSL.
        """
        state_constants = (
            SSL_ST_CONNECT, SSL_ST_ACCEPT, SSL_ST_MASK, SSL_ST_INIT,
            SSL_ST_BEFORE, SSL_ST_OK, SSL_ST_RENEGOTIATE)
        callback_constants = (
            SSL_CB_LOOP, SSL_CB_EXIT, SSL_CB_READ, SSL_CB_WRITE, SSL_CB_ALERT,
            SSL_CB_READ_ALERT, SSL_CB_WRITE_ALERT, SSL_CB_ACCEPT_LOOP,
            SSL_CB_ACCEPT_EXIT, SSL_CB_CONNECT_LOOP, SSL_CB_CONNECT_EXIT,
            SSL_CB_HANDSHAKE_START, SSL_CB_HANDSHAKE_DONE)
        for value in state_constants + callback_constants:
            self.assertTrue(isinstance(value, int))
1707
1708
if __name__ == '__main__':
    # When this module is run directly rather than imported, hand control to
    # unittest's main() so the test cases defined here are run.
    main()