Upstream version 7.35.144.0
[platform/framework/web/crosswalk.git] / src / third_party / tlslite / tlslite / TLSRecordLayer.py
1 """Helper class for TLSConnection."""
2 from __future__ import generators
3
4 from utils.compat import *
5 from utils.cryptomath import *
6 from utils.cipherfactory import createAES, createRC4, createTripleDES
7 from utils.codec import *
8 from errors import *
9 from messages import *
10 from mathtls import *
11 from constants import *
12 from utils.cryptomath import getRandomBytes
13 from utils import hmac
14 from FileObject import FileObject
15
16 # The sha module is deprecated in Python 2.6 
17 try:
18     import sha
19 except ImportError:
20     from hashlib import sha1 as sha
21
22 # The md5 module is deprecated in Python 2.6
23 try:
24     import md5
25 except ImportError:
26     from hashlib import md5
27
28 import socket
29 import errno
30 import traceback
31
class _ConnectionState:
    """Record-protection state for one direction of a TLS connection.

    TLSRecordLayer keeps one instance for reading and one for writing.
    Both cipher-related contexts start out as None (presumably they are
    filled in once keys have been calculated), and the sequence number
    starts at zero.
    """

    def __init__(self):
        self.macContext = None
        self.encContext = None
        self.seqnum = 0

    def getSeqNumStr(self):
        """Return the current sequence number encoded as an 8-byte string.

        The counter is advanced by one after it has been encoded, so
        consecutive calls return consecutive sequence numbers.
        """
        writer = Writer(8)
        writer.add(self.seqnum, 8)
        encoded = bytesToString(writer.bytes)
        self.seqnum = self.seqnum + 1
        return encoded
44
45
46 class TLSRecordLayer:
47     """
48     This class handles data transmission for a TLS connection.
49
50     Its only subclass is L{tlslite.TLSConnection.TLSConnection}.  We've
51     separated the code in this class from TLSConnection to make things
52     more readable.
53
54
55     @type sock: socket.socket
56     @ivar sock: The underlying socket object.
57
58     @type session: L{tlslite.Session.Session}
59     @ivar session: The session corresponding to this connection.
60
61     Due to TLS session resumption, multiple connections can correspond
62     to the same underlying session.
63
64     @type version: tuple
65     @ivar version: The TLS version being used for this connection.
66
67     (3,0) means SSL 3.0, and (3,1) means TLS 1.0.
68
69     @type closed: bool
70     @ivar closed: If this connection is closed.
71
72     @type resumed: bool
73     @ivar resumed: If this connection is based on a resumed session.
74
75     @type allegedSharedKeyUsername: str or None
76     @ivar allegedSharedKeyUsername:  This is set to the shared-key
77     username asserted by the client, whether the handshake succeeded or
78     not.  If the handshake fails, this can be inspected to
79     determine if a guessing attack is in progress against a particular
80     user account.
81
82     @type allegedSrpUsername: str or None
83     @ivar allegedSrpUsername:  This is set to the SRP username
84     asserted by the client, whether the handshake succeeded or not.
85     If the handshake fails, this can be inspected to determine
86     if a guessing attack is in progress against a particular user
87     account.
88
89     @type closeSocket: bool
90     @ivar closeSocket: If the socket should be closed when the
91     connection is closed (writable).
92
93     If you set this to True, TLS Lite will assume the responsibility of
94     closing the socket when the TLS Connection is shutdown (either
95     through an error or through the user calling close()).  The default
96     is False.
97
98     @type ignoreAbruptClose: bool
99     @ivar ignoreAbruptClose: If an abrupt close of the socket should
100     raise an error (writable).
101
102     If you set this to True, TLS Lite will not raise a
103     L{tlslite.errors.TLSAbruptCloseError} exception if the underlying
104     socket is unexpectedly closed.  Such an unexpected closure could be
105     caused by an attacker.  However, it also occurs with some incorrect
106     TLS implementations.
107
108     You should set this to True only if you're not worried about an
109     attacker truncating the connection, and only if necessary to avoid
110     spurious errors.  The default is False.
111
112     @sort: __init__, read, readAsync, write, writeAsync, close, closeAsync,
113     getCipherImplementation, getCipherName
114     """
115
116     def __init__(self, sock):
117         self.sock = sock
118
119         #My session object (Session instance; read-only)
120         self.session = None
121
122         #Am I a client or server?
123         self._client = None
124
125         #Buffers for processing messages
126         self._handshakeBuffer = []
127         self._readBuffer = ""
128
129         #Handshake digests
130         self._handshake_md5 = md5.md5()
131         self._handshake_sha = sha.sha()
132
133         #TLS Protocol Version
134         self.version = (0,0) #read-only
135         self._versionCheck = False #Once we choose a version, this is True
136
137         #Current and Pending connection states
138         self._writeState = _ConnectionState()
139         self._readState = _ConnectionState()
140         self._pendingWriteState = _ConnectionState()
141         self._pendingReadState = _ConnectionState()
142
143         #Is the connection open?
144         self.closed = True #read-only
145         self._refCount = 0 #Used to trigger closure
146
147         #Is this a resumed (or shared-key) session?
148         self.resumed = False #read-only
149
150         #What username did the client claim in his handshake?
151         self.allegedSharedKeyUsername = None
152         self.allegedSrpUsername = None
153
154         #On a call to close(), do we close the socket? (writeable)
155         self.closeSocket = False
156
157         #If the socket is abruptly closed, do we ignore it
158         #and pretend the connection was shut down properly? (writeable)
159         self.ignoreAbruptClose = False
160
161         #Fault we will induce, for testing purposes
162         self.fault = None
163
164     #*********************************************************
165     # Public Functions START
166     #*********************************************************
167
168     def read(self, max=None, min=1):
169         """Read some data from the TLS connection.
170
171         This function will block until at least 'min' bytes are
172         available (or the connection is closed).
173
174         If an exception is raised, the connection will have been
175         automatically closed.
176
177         @type max: int
178         @param max: The maximum number of bytes to return.
179
180         @type min: int
181         @param min: The minimum number of bytes to return
182
183         @rtype: str
184         @return: A string of no more than 'max' bytes, and no fewer
185         than 'min' (unless the connection has been closed, in which
186         case fewer than 'min' bytes may be returned).
187
188         @raise socket.error: If a socket error occurs.
189         @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
190         without a preceding alert.
191         @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
192         """
193         for result in self.readAsync(max, min):
194             pass
195         return result
196
197     def readAsync(self, max=None, min=1):
198         """Start a read operation on the TLS connection.
199
200         This function returns a generator which behaves similarly to
201         read().  Successive invocations of the generator will return 0
202         if it is waiting to read from the socket, 1 if it is waiting
203         to write to the socket, or a string if the read operation has
204         completed.
205
206         @rtype: iterable
207         @return: A generator; see above for details.
208         """
209         try:
210             while len(self._readBuffer)<min and not self.closed:
211                 try:
212                     for result in self._getMsg(ContentType.application_data):
213                         if result in (0,1):
214                             yield result
215                     applicationData = result
216                     self._readBuffer += bytesToString(applicationData.write())
217                 except TLSRemoteAlert, alert:
218                     if alert.description != AlertDescription.close_notify:
219                         raise
220                 except TLSAbruptCloseError:
221                     if not self.ignoreAbruptClose:
222                         raise
223                     else:
224                         self._shutdown(True)
225
226             if max == None:
227                 max = len(self._readBuffer)
228
229             returnStr = self._readBuffer[:max]
230             self._readBuffer = self._readBuffer[max:]
231             yield returnStr
232         except:
233             self._shutdown(False)
234             raise
235
236     def write(self, s):
237         """Write some data to the TLS connection.
238
239         This function will block until all the data has been sent.
240
241         If an exception is raised, the connection will have been
242         automatically closed.
243
244         @type s: str
245         @param s: The data to transmit to the other party.
246
247         @raise socket.error: If a socket error occurs.
248         """
249         for result in self.writeAsync(s):
250             pass
251
252     def writeAsync(self, s):
253         """Start a write operation on the TLS connection.
254
255         This function returns a generator which behaves similarly to
256         write().  Successive invocations of the generator will return
257         1 if it is waiting to write to the socket, or will raise
258         StopIteration if the write operation has completed.
259
260         @rtype: iterable
261         @return: A generator; see above for details.
262         """
263         try:
264             if self.closed:
265                 raise ValueError()
266
267             index = 0
268             blockSize = 16384
269             skipEmptyFrag = False
270             while 1:
271                 startIndex = index * blockSize
272                 endIndex = startIndex + blockSize
273                 if startIndex >= len(s):
274                     break
275                 if endIndex > len(s):
276                     endIndex = len(s)
277                 block = stringToBytes(s[startIndex : endIndex])
278                 applicationData = ApplicationData().create(block)
279                 for result in self._sendMsg(applicationData, skipEmptyFrag):
280                     yield result
281                 skipEmptyFrag = True #only send an empy fragment on 1st message
282                 index += 1
283         except:
284             self._shutdown(False)
285             raise
286
287     def close(self):
288         """Close the TLS connection.
289
290         This function will block until it has exchanged close_notify
291         alerts with the other party.  After doing so, it will shut down the
292         TLS connection.  Further attempts to read through this connection
293         will return "".  Further attempts to write through this connection
294         will raise ValueError.
295
296         If makefile() has been called on this connection, the connection
297         will be not be closed until the connection object and all file
298         objects have been closed.
299
300         Even if an exception is raised, the connection will have been
301         closed.
302
303         @raise socket.error: If a socket error occurs.
304         @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
305         without a preceding alert.
306         @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
307         """
308         if not self.closed:
309             for result in self._decrefAsync():
310                 pass
311
312     def closeAsync(self):
313         """Start a close operation on the TLS connection.
314
315         This function returns a generator which behaves similarly to
316         close().  Successive invocations of the generator will return 0
317         if it is waiting to read from the socket, 1 if it is waiting
318         to write to the socket, or will raise StopIteration if the
319         close operation has completed.
320
321         @rtype: iterable
322         @return: A generator; see above for details.
323         """
324         if not self.closed:
325             for result in self._decrefAsync():
326                 yield result
327
328     def _decrefAsync(self):
329         self._refCount -= 1
330         if self._refCount == 0 and not self.closed:
331             try:
332                 for result in self._sendMsg(Alert().create(\
333                         AlertDescription.close_notify, AlertLevel.warning)):
334                     yield result
335                 alert = None
336                 # Forcing a shutdown as WinHTTP does not seem to be
337                 # responsive to the close notify.
338                 prevCloseSocket = self.closeSocket
339                 self.closeSocket = True
340                 self._shutdown(True)
341                 self.closeSocket = prevCloseSocket
342                 while not alert:
343                     for result in self._getMsg((ContentType.alert, \
344                                               ContentType.application_data)):
345                         if result in (0,1):
346                             yield result
347                     if result.contentType == ContentType.alert:
348                         alert = result
349                 if alert.description == AlertDescription.close_notify:
350                     self._shutdown(True)
351                 else:
352                     raise TLSRemoteAlert(alert)
353             except (socket.error, TLSAbruptCloseError):
354                 #If the other side closes the socket, that's okay
355                 self._shutdown(True)
356             except:
357                 self._shutdown(False)
358                 raise
359
360     def getCipherName(self):
361         """Get the name of the cipher used with this connection.
362
363         @rtype: str
364         @return: The name of the cipher used with this connection.
365         Either 'aes128', 'aes256', 'rc4', or '3des'.
366         """
367         if not self._writeState.encContext:
368             return None
369         return self._writeState.encContext.name
370
371     def getCipherImplementation(self):
372         """Get the name of the cipher implementation used with
373         this connection.
374
375         @rtype: str
376         @return: The name of the cipher implementation used with
377         this connection.  Either 'python', 'cryptlib', 'openssl',
378         or 'pycrypto'.
379         """
380         if not self._writeState.encContext:
381             return None
382         return self._writeState.encContext.implementation
383
384
385
386     #Emulate a socket, somewhat -
387     def send(self, s):
388         """Send data to the TLS connection (socket emulation).
389
390         @raise socket.error: If a socket error occurs.
391         """
392         self.write(s)
393         return len(s)
394
395     def sendall(self, s):
396         """Send data to the TLS connection (socket emulation).
397
398         @raise socket.error: If a socket error occurs.
399         """
400         self.write(s)
401
402     def recv(self, bufsize):
403         """Get some data from the TLS connection (socket emulation).
404
405         @raise socket.error: If a socket error occurs.
406         @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
407         without a preceding alert.
408         @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
409         """
410         return self.read(bufsize)
411
412     def makefile(self, mode='r', bufsize=-1):
413         """Create a file object for the TLS connection (socket emulation).
414
415         @rtype: L{tlslite.FileObject.FileObject}
416         """
417         self._refCount += 1
418         return FileObject(self, mode, bufsize)
419
420     def getsockname(self):
421         """Return the socket's own address (socket emulation)."""
422         return self.sock.getsockname()
423
424     def getpeername(self):
425         """Return the remote address to which the socket is connected
426         (socket emulation)."""
427         return self.sock.getpeername()
428
429     def settimeout(self, value):
430         """Set a timeout on blocking socket operations (socket emulation)."""
431         return self.sock.settimeout(value)
432
433     def gettimeout(self):
434         """Return the timeout associated with socket operations (socket
435         emulation)."""
436         return self.sock.gettimeout()
437
438     def setsockopt(self, level, optname, value):
439         """Set the value of the given socket option (socket emulation)."""
440         return self.sock.setsockopt(level, optname, value)
441
442
443      #*********************************************************
444      # Public Functions END
445      #*********************************************************
446
447     def _shutdown(self, resumable):
448         self._writeState = _ConnectionState()
449         self._readState = _ConnectionState()
450         #Don't do this: self._readBuffer = ""
451         self.version = (0,0)
452         self._versionCheck = False
453         self.closed = True
454         if self.closeSocket:
455             self.sock.close()
456
457         #Even if resumable is False, we'll never toggle this on
458         if not resumable and self.session:
459             self.session.resumable = False
460
461
462     def _sendError(self, alertDescription, errorStr=None):
463         alert = Alert().create(alertDescription, AlertLevel.fatal)
464         for result in self._sendMsg(alert):
465             yield result
466         self._shutdown(False)
467         raise TLSLocalAlert(alert, errorStr)
468
469     def _sendMsgs(self, msgs):
470         skipEmptyFrag = False
471         for msg in msgs:
472             for result in self._sendMsg(msg, skipEmptyFrag):
473                 yield result
474             skipEmptyFrag = True
475
    def _sendMsg(self, msg, skipEmptyFrag=False):
        """Protect one message as a TLS record and write it to the socket.

        This is a generator.  It yields 1 whenever it is waiting for the
        socket to become writable, and it finishes once the whole record
        has been sent.

        Steps:
          - optionally send an empty application-data record first;
          - fold handshake messages into the handshake hashes;
          - compute the MAC with the current write state;
          - pad (block ciphers only) and encrypt;
          - prepend a RecordHeader3 and send.

        @param msg: An object providing write() and contentType (e.g. an
        Alert or ApplicationData instance, as used by the callers in this
        file).
        @type skipEmptyFrag: bool
        @param skipEmptyFrag: When True, do not send the leading empty
        application-data fragment.
        """
        bytes = msg.write()
        contentType = msg.contentType

        #Whenever we're connected and asked to send a message,
        #we first send an empty Application Data message.  This prevents
        #an attacker from launching a chosen-plaintext attack based on
        #knowing the next IV.
        #(Only for version (3,1) with an active block cipher; the recursive
        #call passes skipEmptyFrag=True so it does not recurse again.)
        if not self.closed and not skipEmptyFrag and self.version == (3,1):
            if self._writeState.encContext:
                if self._writeState.encContext.isBlockCipher:
                    for result in self._sendMsg(ApplicationData(),
                                               skipEmptyFrag=True):
                        yield result

        #Update handshake hashes
        if contentType == ContentType.handshake:
            bytesStr = bytesToString(bytes)
            self._handshake_md5.update(bytesStr)
            self._handshake_sha.update(bytesStr)

        #Calculate MAC
        #MAC input: sequence number (advanced by getSeqNumStr), content
        #type, the protocol version bytes (only for (3,1)/(3,2)), the
        #2-byte payload length, then the payload itself.
        if self._writeState.macContext:
            seqnumStr = self._writeState.getSeqNumStr()
            bytesStr = bytesToString(bytes)
            mac = self._writeState.macContext.copy()
            mac.update(seqnumStr)
            mac.update(chr(contentType))
            if self.version == (3,0):
                mac.update( chr( int(len(bytes)/256) ) )
                mac.update( chr( int(len(bytes)%256) ) )
            elif self.version in ((3,1), (3,2)):
                mac.update(chr(self.version[0]))
                mac.update(chr(self.version[1]))
                mac.update( chr( int(len(bytes)/256) ) )
                mac.update( chr( int(len(bytes)%256) ) )
            else:
                raise AssertionError()
            mac.update(bytesStr)
            macString = mac.digest()
            macBytes = stringToBytes(macString)
            #Test hook: deliberately corrupt the MAC.
            if self.fault == Fault.badMAC:
                macBytes[0] = (macBytes[0]+1) % 256

        #Encrypt for Block or Stream Cipher
        #NOTE(review): macBytes is only bound when macContext is set, so
        #this branch assumes a write state with an encContext always has a
        #macContext as well; otherwise it fails with NameError.
        if self._writeState.encContext:
            #Add padding and encrypt (for Block Cipher):
            if self._writeState.encContext.isBlockCipher:

                #Add TLS 1.1 fixed block
                #(self.fixedIVBlock is set outside this method; presumably
                #an explicit IV block prepended to each record. Confirm.)
                if self.version == (3,2):
                    bytes = self.fixedIVBlock + bytes

                #Add padding: bytes = bytes + (macBytes + paddingBytes)
                #paddingLength+1 bytes, each equal to paddingLength, make
                #the total a multiple of blockLength.  When currentLength
                #is already a multiple, a whole extra block of padding is
                #added (paddingLength == blockLength).
                currentLength = len(bytes) + len(macBytes) + 1
                blockLength = self._writeState.encContext.block_size
                paddingLength = blockLength-(currentLength % blockLength)

                paddingBytes = createByteArraySequence([paddingLength] * \
                                                       (paddingLength+1))
                #Test hook: deliberately corrupt the padding.
                if self.fault == Fault.badPadding:
                    paddingBytes[0] = (paddingBytes[0]+1) % 256
                endBytes = concatArrays(macBytes, paddingBytes)
                bytes = concatArrays(bytes, endBytes)
                #Encrypt
                plaintext = stringToBytes(bytes)
                ciphertext = self._writeState.encContext.encrypt(plaintext)
                bytes = stringToBytes(ciphertext)

            #Encrypt (for Stream Cipher)
            else:
                bytes = concatArrays(bytes, macBytes)
                plaintext = bytesToString(bytes)
                ciphertext = self._writeState.encContext.encrypt(plaintext)
                bytes = stringToBytes(ciphertext)

        #Add record header and send
        #Loop until the whole record is written: on EWOULDBLOCK, or after a
        #partial send, yield 1 and retry with the unsent remainder.
        r = RecordHeader3().create(self.version, contentType, len(bytes))
        s = bytesToString(concatArrays(r.write(), bytes))
        while 1:
            try:
                bytesSent = self.sock.send(s) #Might raise socket.error
            except socket.error, why:
                if why[0] == errno.EWOULDBLOCK:
                    yield 1
                    continue
                else:
                    raise
            if bytesSent == len(s):
                return
            s = s[bytesSent:]
            yield 1
568
569
    def _getMsg(self, expectedType, secondaryType=None, constructorType=None):
        """Read the next message of an expected type and yield it parsed.

        This is a generator.  While waiting on the socket it yields 0
        (waiting to read) or 1 (waiting to write); its final yield is the
        parsed message object.

        @param expectedType: A ContentType value, or a tuple of them.
        @param secondaryType: For handshake records, the acceptable
        HandshakeType value(s), as a single value or a tuple.
        @param constructorType: Passed through to the Certificate,
        ServerKeyExchange and ClientKeyExchange constructors.

        Error handling:
          - An unexpected alert from the peer shuts the connection down
            and is raised as TLSRemoteAlert.
          - A renegotiation attempt is answered with a no_renegotiation
            warning and skipped.
          - Any other unexpected record, an unexpected handshake type, or
            a SyntaxError while parsing sends a fatal alert via
            _sendError(), which raises TLSLocalAlert.
        """
        try:
            if not isinstance(expectedType, tuple):
                expectedType = (expectedType,)

            #Spin in a loop, until we've got a non-empty record of a type we
            #expect.  The loop will be repeated if:
            #  - we receive a renegotiation attempt; we send no_renegotiation,
            #    then try again
            #  - we receive an empty application-data fragment; we try again
            while 1:
                for result in self._getNextRecord():
                    if result in (0,1):
                        yield result
                recordHeader, p = result

                #If this is an empty application-data fragment, try again
                if recordHeader.type == ContentType.application_data:
                    if p.index == len(p.bytes):
                        continue

                #If we received an unexpected record type...
                if recordHeader.type not in expectedType:

                    #If we received an alert...
                    if recordHeader.type == ContentType.alert:
                        alert = Alert().parse(p)

                        #We either received a fatal error, a warning, or a
                        #close_notify.  In any case, we're going to close the
                        #connection.  In the latter two cases we respond with
                        #a close_notify, but ignore any socket errors, since
                        #the other side might have already closed the socket.
                        if alert.level == AlertLevel.warning or \
                           alert.description == AlertDescription.close_notify:

                            #If the sendMsg() call fails because the socket has
                            #already been closed, we will be forgiving and not
                            #report the error nor invalidate the "resumability"
                            #of the session.
                            try:
                                alertMsg = Alert()
                                alertMsg.create(AlertDescription.close_notify,
                                                AlertLevel.warning)
                                for result in self._sendMsg(alertMsg):
                                    yield result
                            except socket.error:
                                pass

                            #close_notify keeps the session resumable; any
                            #other warning alert invalidates it.
                            if alert.description == \
                                   AlertDescription.close_notify:
                                self._shutdown(True)
                            elif alert.level == AlertLevel.warning:
                                self._shutdown(False)

                        else: #Fatal alert:
                            self._shutdown(False)

                        #Raise the alert as an exception
                        raise TLSRemoteAlert(alert)

                    #If we received a renegotiation attempt...
                    #(a client sees hello_request, a server sees client_hello)
                    if recordHeader.type == ContentType.handshake:
                        subType = p.get(1)
                        reneg = False
                        if self._client:
                            if subType == HandshakeType.hello_request:
                                reneg = True
                        else:
                            if subType == HandshakeType.client_hello:
                                reneg = True
                        #Send no_renegotiation, then try again
                        if reneg:
                            alertMsg = Alert()
                            alertMsg.create(AlertDescription.no_renegotiation,
                                            AlertLevel.warning)
                            for result in self._sendMsg(alertMsg):
                                yield result
                            continue

                    #Otherwise: this is an unexpected record, but neither an
                    #alert nor renegotiation
                    #(_sendError raises, so this never falls through to the
                    #break below)
                    for result in self._sendError(\
                            AlertDescription.unexpected_message,
                            "received type=%d" % recordHeader.type):
                        yield result

                break

            #Parse based on content_type
            if recordHeader.type == ContentType.change_cipher_spec:
                yield ChangeCipherSpec().parse(p)
            elif recordHeader.type == ContentType.alert:
                yield Alert().parse(p)
            elif recordHeader.type == ContentType.application_data:
                yield ApplicationData().parse(p)
            elif recordHeader.type == ContentType.handshake:
                #Convert secondaryType to tuple, if it isn't already
                if not isinstance(secondaryType, tuple):
                    secondaryType = (secondaryType,)

                #If it's a handshake message, check handshake header
                #(an SSLv2-format record is accepted only as a ClientHello)
                if recordHeader.ssl2:
                    subType = p.get(1)
                    if subType != HandshakeType.client_hello:
                        for result in self._sendError(\
                                AlertDescription.unexpected_message,
                                "Can only handle SSLv2 ClientHello messages"):
                            yield result
                    if HandshakeType.client_hello not in secondaryType:
                        for result in self._sendError(\
                                AlertDescription.unexpected_message):
                            yield result
                    subType = HandshakeType.client_hello
                else:
                    subType = p.get(1)
                    if subType not in secondaryType:
                        for result in self._sendError(\
                                AlertDescription.unexpected_message,
                                "Expecting %s, got %s" % (str(secondaryType), subType)):
                            yield result

                #Update handshake hashes
                #(over the whole record payload in p.bytes)
                sToHash = bytesToString(p.bytes)
                self._handshake_md5.update(sToHash)
                self._handshake_sha.update(sToHash)

                #Parse based on handshake type
                #(the parser has already consumed the 1-byte subtype above)
                if subType == HandshakeType.client_hello:
                    yield ClientHello(recordHeader.ssl2).parse(p)
                elif subType == HandshakeType.server_hello:
                    yield ServerHello().parse(p)
                elif subType == HandshakeType.certificate:
                    yield Certificate(constructorType).parse(p)
                elif subType == HandshakeType.certificate_request:
                    yield CertificateRequest().parse(p)
                elif subType == HandshakeType.certificate_verify:
                    yield CertificateVerify().parse(p)
                elif subType == HandshakeType.server_key_exchange:
                    yield ServerKeyExchange(constructorType).parse(p)
                elif subType == HandshakeType.server_hello_done:
                    yield ServerHelloDone().parse(p)
                elif subType == HandshakeType.client_key_exchange:
                    yield ClientKeyExchange(constructorType, \
                                            self.version).parse(p)
                elif subType == HandshakeType.finished:
                    yield Finished(self.version).parse(p)
                elif subType == HandshakeType.encrypted_extensions:
                    yield EncryptedExtensions().parse(p)
                else:
                    raise AssertionError()

        #If an exception was raised by a Parser or Message instance:
        except SyntaxError, e:
            for result in self._sendError(AlertDescription.decode_error,
                                         formatExceptionTrace(e)):
                yield result
727
728
729     #Returns next record or next handshake message
    def _getNextRecord(self):
        """Return the next record, or the next buffered handshake message.

        This is a generator.  While the (non-blocking) socket has no data it
        yields 0; values yielded by _decryptRecord that are 0 or 1, and
        whatever _sendError yields, are passed through.  The final value
        yielded is a (recordHeader, Parser) pair:

          - if self._handshakeBuffer is non-empty, its first entry;
          - otherwise the next record read from the socket, decrypted.  If
            that record is a (non-SSLv2) handshake record, it is split into
            individual handshake messages which are appended to
            self._handshakeBuffer, and the first of them is returned.

        Raises TLSAbruptCloseError if the socket returns no data while a
        header or record body is being read, and SyntaxError if the first
        header byte is neither in ContentType.all nor 128.
        """

        #If there's a handshake message waiting, return it
        if self._handshakeBuffer:
            recordHeader, bytes = self._handshakeBuffer[0]
            self._handshakeBuffer = self._handshakeBuffer[1:]
            yield (recordHeader, Parser(bytes))
            return

        #Otherwise...
        #Read the next record header.  The first byte decides the header
        #format: a known content type means a 5-byte SSLv3/TLS header, 128
        #means a 2-byte SSLv2-style header.
        bytes = createByteArraySequence([])
        recordHeaderLength = 1
        ssl2 = False
        while 1:
            try:
                s = self.sock.recv(recordHeaderLength-len(bytes))
            except socket.error, why:
                #Non-blocking socket with nothing to read yet: yield 0 and
                #retry when resumed.
                if why[0] == errno.EWOULDBLOCK:
                    yield 0
                    continue
                else:
                    raise

            #If the connection was abruptly closed, raise an error
            if len(s)==0:
                raise TLSAbruptCloseError()

            bytes += stringToBytes(s)
            if len(bytes)==1:
                if bytes[0] in ContentType.all:
                    ssl2 = False
                    recordHeaderLength = 5
                elif bytes[0] == 128:
                    ssl2 = True
                    recordHeaderLength = 2
                else:
                    raise SyntaxError()
            if len(bytes) == recordHeaderLength:
                break

        #Parse the record header
        if ssl2:
            r = RecordHeader2().parse(Parser(bytes))
        else:
            r = RecordHeader3().parse(Parser(bytes))

        #Check the record header fields.  18432 == 2**14 + 2048, the
        #TLSCiphertext length limit in the TLS spec.
        #NOTE(review): execution continues after this loop, so this relies
        #on _sendError raising once it is done -- not visible here.
        if r.length > 18432:
            for result in self._sendError(AlertDescription.record_overflow):
                yield result

        #Read the record contents
        bytes = createByteArraySequence([])
        while 1:
            try:
                s = self.sock.recv(r.length - len(bytes))
            except socket.error, why:
                if why[0] == errno.EWOULDBLOCK:
                    yield 0
                    continue
                else:
                    raise

            #If the connection is closed mid-record, raise TLSAbruptCloseError
            if len(s)==0:
                    raise TLSAbruptCloseError()

            bytes += stringToBytes(s)
            if len(bytes) == r.length:
                break

        #Check the record header fields (2)
        #We do this after reading the contents from the socket, so that
        #if there's an error, we at least don't leave extra bytes in the
        #socket.
        #
        # THIS CHECK HAS NO SECURITY RELEVANCE (?), BUT COULD HURT INTEROP.
        # SO WE LEAVE IT OUT FOR NOW.
        #
        #if self._versionCheck and r.version != self.version:
        #    for result in self._sendError(AlertDescription.protocol_version,
        #            "Version in header field: %s, should be %s" % (str(r.version),
        #                                                       str(self.version))):
        #        yield result

        #Decrypt the record.  _decryptRecord yields 0/1 while waiting (passed
        #through here); its first other value is the plaintext.
        for result in self._decryptRecord(r.type, bytes):
            if result in (0,1):
                yield result
            else:
                break
        bytes = result
        p = Parser(bytes)

        #If it doesn't contain handshake messages, we can just return it
        if r.type != ContentType.handshake:
            yield (r, p)
        #If it's an SSLv2 ClientHello, we can return it as well
        elif r.ssl2:
            yield (r, p)
        else:
            #Otherwise, we loop through and add the handshake messages to the
            #handshake buffer.  Each handshake message has a 4-byte header:
            #1 byte of type, then a 3-byte body length.
            while 1:
                if p.index == len(bytes): #If we're at the end
                    #The buffer was empty on entry (see the top of this
                    #method), so an empty buffer here means the record held
                    #no handshake messages at all.
                    if not self._handshakeBuffer:
                        for result in self._sendError(\
                                AlertDescription.decode_error, \
                                "Received empty handshake record"):
                            yield result
                    break
                #There needs to be at least 4 bytes to get a header
                if p.index+4 > len(bytes):
                    for result in self._sendError(\
                            AlertDescription.decode_error,
                            "A record has a partial handshake message (1)"):
                        yield result
                p.get(1) # skip handshake type
                msgLength = p.get(3)
                if p.index+msgLength > len(bytes):
                    for result in self._sendError(\
                            AlertDescription.decode_error,
                            "A record has a partial handshake message (2)"):
                        yield result

                #Buffer the whole message (including its 4-byte header),
                #tagged with this record's header.
                handshakePair = (r, bytes[p.index-4 : p.index+msgLength])
                self._handshakeBuffer.append(handshakePair)
                p.index += msgLength

            #We've moved at least one handshake message into the
            #handshakeBuffer, return the first one
            recordHeader, bytes = self._handshakeBuffer[0]
            self._handshakeBuffer = self._handshakeBuffer[1:]
            yield (recordHeader, Parser(bytes))
865
866
867     def _decryptRecord(self, recordType, bytes):
868         if self._readState.encContext:
869
870             #Decrypt if it's a block cipher
871             if self._readState.encContext.isBlockCipher:
872                 blockLength = self._readState.encContext.block_size
873                 if len(bytes) % blockLength != 0:
874                     for result in self._sendError(\
875                             AlertDescription.decryption_failed,
876                             "Encrypted data not a multiple of blocksize"):
877                         yield result
878                 ciphertext = bytesToString(bytes)
879                 plaintext = self._readState.encContext.decrypt(ciphertext)
880                 if self.version == (3,2): #For TLS 1.1, remove explicit IV
881                     plaintext = plaintext[self._readState.encContext.block_size : ]
882                 bytes = stringToBytes(plaintext)
883
884                 #Check padding
885                 paddingGood = True
886                 paddingLength = bytes[-1]
887                 if (paddingLength+1) > len(bytes):
888                     paddingGood=False
889                     totalPaddingLength = 0
890                 else:
891                     if self.version == (3,0):
892                         totalPaddingLength = paddingLength+1
893                     elif self.version in ((3,1), (3,2)):
894                         totalPaddingLength = paddingLength+1
895                         paddingBytes = bytes[-totalPaddingLength:-1]
896                         for byte in paddingBytes:
897                             if byte != paddingLength:
898                                 paddingGood = False
899                                 totalPaddingLength = 0
900                     else:
901                         raise AssertionError()
902
903             #Decrypt if it's a stream cipher
904             else:
905                 paddingGood = True
906                 ciphertext = bytesToString(bytes)
907                 plaintext = self._readState.encContext.decrypt(ciphertext)
908                 bytes = stringToBytes(plaintext)
909                 totalPaddingLength = 0
910
911             #Check MAC
912             macGood = True
913             macLength = self._readState.macContext.digest_size
914             endLength = macLength + totalPaddingLength
915             if endLength > len(bytes):
916                 macGood = False
917             else:
918                 #Read MAC
919                 startIndex = len(bytes) - endLength
920                 endIndex = startIndex + macLength
921                 checkBytes = bytes[startIndex : endIndex]
922
923                 #Calculate MAC
924                 seqnumStr = self._readState.getSeqNumStr()
925                 bytes = bytes[:-endLength]
926                 bytesStr = bytesToString(bytes)
927                 mac = self._readState.macContext.copy()
928                 mac.update(seqnumStr)
929                 mac.update(chr(recordType))
930                 if self.version == (3,0):
931                     mac.update( chr( int(len(bytes)/256) ) )
932                     mac.update( chr( int(len(bytes)%256) ) )
933                 elif self.version in ((3,1), (3,2)):
934                     mac.update(chr(self.version[0]))
935                     mac.update(chr(self.version[1]))
936                     mac.update( chr( int(len(bytes)/256) ) )
937                     mac.update( chr( int(len(bytes)%256) ) )
938                 else:
939                     raise AssertionError()
940                 mac.update(bytesStr)
941                 macString = mac.digest()
942                 macBytes = stringToBytes(macString)
943
944                 #Compare MACs
945                 if macBytes != checkBytes:
946                     macGood = False
947
948             if not (paddingGood and macGood):
949                 for result in self._sendError(AlertDescription.bad_record_mac,
950                                           "MAC failure (or padding failure)"):
951                     yield result
952
953         yield bytes
954
955     def _handshakeStart(self, client):
956         self._client = client
957         self._handshake_md5 = md5.md5()
958         self._handshake_sha = sha.sha()
959         self._handshakeBuffer = []
960         self.allegedSharedKeyUsername = None
961         self.allegedSrpUsername = None
962         self._refCount = 1
963
964     def _handshakeDone(self, resumed):
965         self.resumed = resumed
966         self.closed = False
967
968     def _calcPendingStates(self, clientRandom, serverRandom, implementations):
969         if self.session.cipherSuite in CipherSuite.aes128Suites:
970             macLength = 20
971             keyLength = 16
972             ivLength = 16
973             createCipherFunc = createAES
974         elif self.session.cipherSuite in CipherSuite.aes256Suites:
975             macLength = 20
976             keyLength = 32
977             ivLength = 16
978             createCipherFunc = createAES
979         elif self.session.cipherSuite in CipherSuite.rc4Suites:
980             macLength = 20
981             keyLength = 16
982             ivLength = 0
983             createCipherFunc = createRC4
984         elif self.session.cipherSuite in CipherSuite.tripleDESSuites:
985             macLength = 20
986             keyLength = 24
987             ivLength = 8
988             createCipherFunc = createTripleDES
989         else:
990             raise AssertionError()
991
992         if self.version == (3,0):
993             createMACFunc = MAC_SSL
994         elif self.version in ((3,1), (3,2)):
995             createMACFunc = hmac.HMAC
996
997         outputLength = (macLength*2) + (keyLength*2) + (ivLength*2)
998
999         #Calculate Keying Material from Master Secret
1000         if self.version == (3,0):
1001             keyBlock = PRF_SSL(self.session.masterSecret,
1002                                concatArrays(serverRandom, clientRandom),
1003                                outputLength)
1004         elif self.version in ((3,1), (3,2)):
1005             keyBlock = PRF(self.session.masterSecret,
1006                            "key expansion",
1007                            concatArrays(serverRandom,clientRandom),
1008                            outputLength)
1009         else:
1010             raise AssertionError()
1011
1012         #Slice up Keying Material
1013         clientPendingState = _ConnectionState()
1014         serverPendingState = _ConnectionState()
1015         p = Parser(keyBlock)
1016         clientMACBlock = bytesToString(p.getFixBytes(macLength))
1017         serverMACBlock = bytesToString(p.getFixBytes(macLength))
1018         clientKeyBlock = bytesToString(p.getFixBytes(keyLength))
1019         serverKeyBlock = bytesToString(p.getFixBytes(keyLength))
1020         clientIVBlock  = bytesToString(p.getFixBytes(ivLength))
1021         serverIVBlock  = bytesToString(p.getFixBytes(ivLength))
1022         clientPendingState.macContext = createMACFunc(clientMACBlock,
1023                                                       digestmod=sha)
1024         serverPendingState.macContext = createMACFunc(serverMACBlock,
1025                                                       digestmod=sha)
1026         clientPendingState.encContext = createCipherFunc(clientKeyBlock,
1027                                                          clientIVBlock,
1028                                                          implementations)
1029         serverPendingState.encContext = createCipherFunc(serverKeyBlock,
1030                                                          serverIVBlock,
1031                                                          implementations)
1032
1033         #Assign new connection states to pending states
1034         if self._client:
1035             self._pendingWriteState = clientPendingState
1036             self._pendingReadState = serverPendingState
1037         else:
1038             self._pendingWriteState = serverPendingState
1039             self._pendingReadState = clientPendingState
1040
1041         if self.version == (3,2) and ivLength:
1042             #Choose fixedIVBlock for TLS 1.1 (this is encrypted with the CBC
1043             #residue to create the IV for each sent block)
1044             self.fixedIVBlock = getRandomBytes(ivLength)
1045
1046     def _changeWriteState(self):
1047         self._writeState = self._pendingWriteState
1048         self._pendingWriteState = _ConnectionState()
1049
1050     def _changeReadState(self):
1051         self._readState = self._pendingReadState
1052         self._pendingReadState = _ConnectionState()
1053
1054     def _sendFinished(self):
1055         #Send ChangeCipherSpec
1056         for result in self._sendMsg(ChangeCipherSpec()):
1057             yield result
1058
1059         #Switch to pending write state
1060         self._changeWriteState()
1061
1062         #Calculate verification data
1063         verifyData = self._calcFinished(True)
1064         if self.fault == Fault.badFinished:
1065             verifyData[0] = (verifyData[0]+1)%256
1066
1067         #Send Finished message under new state
1068         finished = Finished(self.version).create(verifyData)
1069         for result in self._sendMsg(finished):
1070             yield result
1071
1072     def _getChangeCipherSpec(self):
1073         #Get and check ChangeCipherSpec
1074         for result in self._getMsg(ContentType.change_cipher_spec):
1075             if result in (0,1):
1076                 yield result
1077         changeCipherSpec = result
1078
1079         if changeCipherSpec.type != 1:
1080             for result in self._sendError(AlertDescription.illegal_parameter,
1081                                          "ChangeCipherSpec type incorrect"):
1082                 yield result
1083
1084         #Switch to pending read state
1085         self._changeReadState()
1086
1087     def _getEncryptedExtensions(self):
1088         for result in self._getMsg(ContentType.handshake,
1089                                    HandshakeType.encrypted_extensions):
1090             if result in (0,1):
1091                 yield result
1092         encrypted_extensions = result
1093         self.channel_id = encrypted_extensions.channel_id_key
1094
1095     def _getFinished(self):
1096         #Calculate verification data
1097         verifyData = self._calcFinished(False)
1098
1099         #Get and check Finished message under new state
1100         for result in self._getMsg(ContentType.handshake,
1101                                   HandshakeType.finished):
1102             if result in (0,1):
1103                 yield result
1104         finished = result
1105         if finished.verify_data != verifyData:
1106             for result in self._sendError(AlertDescription.decrypt_error,
1107                                          "Finished message is incorrect"):
1108                 yield result
1109
1110     def _calcFinished(self, send=True):
1111         if self.version == (3,0):
1112             if (self._client and send) or (not self._client and not send):
1113                 senderStr = "\x43\x4C\x4E\x54"
1114             else:
1115                 senderStr = "\x53\x52\x56\x52"
1116
1117             verifyData = self._calcSSLHandshakeHash(self.session.masterSecret,
1118                                                    senderStr)
1119             return verifyData
1120
1121         elif self.version in ((3,1), (3,2)):
1122             if (self._client and send) or (not self._client and not send):
1123                 label = "client finished"
1124             else:
1125                 label = "server finished"
1126
1127             handshakeHashes = stringToBytes(self._handshake_md5.digest() + \
1128                                             self._handshake_sha.digest())
1129             verifyData = PRF(self.session.masterSecret, label, handshakeHashes,
1130                              12)
1131             return verifyData
1132         else:
1133             raise AssertionError()
1134
1135     #Used for Finished messages and CertificateVerify messages in SSL v3
1136     def _calcSSLHandshakeHash(self, masterSecret, label):
1137         masterSecretStr = bytesToString(masterSecret)
1138
1139         imac_md5 = self._handshake_md5.copy()
1140         imac_sha = self._handshake_sha.copy()
1141
1142         imac_md5.update(label + masterSecretStr + '\x36'*48)
1143         imac_sha.update(label + masterSecretStr + '\x36'*40)
1144
1145         md5Str = md5.md5(masterSecretStr + ('\x5c'*48) + \
1146                          imac_md5.digest()).digest()
1147         shaStr = sha.sha(masterSecretStr + ('\x5c'*40) + \
1148                          imac_sha.digest()).digest()
1149
1150         return stringToBytes(md5Str + shaStr)