- add sources.
[platform/framework/web/crosswalk.git] / src / third_party / tlslite / tlslite / messages.py
1 """Classes representing TLS messages."""
2
3 from utils.compat import *
4 from utils.cryptomath import *
5 from errors import *
6 from utils.codec import *
7 from constants import *
8 from X509 import X509
9 from X509CertChain import X509CertChain
10
11 # The sha module is deprecated in Python 2.6 
12 try:
13     import sha
14 except ImportError:
15     from hashlib import sha1 as sha
16
17 # The md5 module is deprecated in Python 2.6
18 try:
19     import md5
20 except ImportError:
21     from hashlib import md5
22
class RecordHeader3:
    """Header of an SSLv3/TLS record: type (1 byte), version (2 bytes),
    length (2 bytes)."""

    def __init__(self):
        self.type = 0
        self.version = (0, 0)
        self.length = 0
        self.ssl2 = False

    def create(self, version, type, length):
        """Set the header fields and return self so calls can be chained."""
        self.version = version
        self.type = type
        self.length = length
        return self

    def write(self):
        """Serialize the header into its 5-byte wire form."""
        writer = Writer(5)
        fields = (
            (self.type, 1),
            (self.version[0], 1),
            (self.version[1], 1),
            (self.length, 2),
        )
        for value, width in fields:
            writer.add(value, width)
        return writer.bytes

    def parse(self, p):
        """Read type, version and length from parser *p*; return self."""
        self.type = p.get(1)
        major = p.get(1)
        minor = p.get(1)
        self.version = (major, minor)
        self.length = p.get(2)
        self.ssl2 = False
        return self
50
class RecordHeader2:
    """Header of an SSLv2-style record; parse() always reports a handshake
    record with version (2, 0)."""

    def __init__(self):
        self.type = 0
        self.version = (0, 0)
        self.length = 0
        self.ssl2 = True

    def parse(self, p):
        """Read the header from parser *p* and return self.

        Raises SyntaxError unless the first byte is exactly 128.
        """
        marker = p.get(1)
        if marker != 128:
            raise SyntaxError()
        self.type = ContentType.handshake
        self.version = (2, 0)
        # Only the second byte is read as the length, and the first byte must
        # be exactly 128. Headers whose length needs more than that byte are
        # therefore not supported, which could be a problem.
        self.length = p.get(1)
        return self
66
67
class Msg:
    """Base class supplying two-pass serialization helpers.

    A subclass's write(trial) is first run with trial=True to measure the
    output. It is then run for real into a Writer sized to that measurement.
    """

    def preWrite(self, trial):
        """Return the Writer to serialize into.

        On a trial pass the Writer is unsized. Otherwise the message is first
        written in trial mode and the result sizes the Writer.
        """
        if not trial:
            return Writer(self.write(True))
        return Writer()

    def postWrite(self, w, trial):
        """Finish a write.

        On a trial pass this returns the Writer's index, which preWrite uses
        as the length. Otherwise it returns the Writer's bytes.
        """
        return w.index if trial else w.bytes
82
class Alert(Msg):
    """Alert message: a 1-byte level followed by a 1-byte description."""

    def __init__(self):
        self.contentType = ContentType.alert
        self.level = 0
        self.description = 0

    def create(self, description, level=AlertLevel.fatal):
        """Set description and level (AlertLevel.fatal by default); return self."""
        self.description = description
        self.level = level
        return self

    def parse(self, p):
        """Read level then description from *p* under a 2-byte length check."""
        p.setLengthCheck(2)
        self.level = p.get(1)
        self.description = p.get(1)
        p.stopLengthCheck()
        return self

    def write(self):
        """Return the two-byte wire form: level, then description."""
        out = Writer(2)
        for value in (self.level, self.description):
            out.add(value, 1)
        return out.bytes
106
107
class HandshakeMsg(Msg):
    """Base class for handshake messages.

    Each message starts with a 1-byte handshake type followed by a 3-byte
    body length.
    """

    def preWrite(self, handshakeType, trial):
        """Return a Writer with the 4-byte handshake header already written.

        On a trial pass the length field is a zero placeholder. Otherwise the
        message is measured with a trial write, and the body length is that
        total minus the 4 header bytes.
        """
        if trial:
            w = Writer()
            bodyLength = 0
        else:
            totalLength = self.write(True)
            w = Writer(totalLength)
            bodyLength = totalLength - 4
        w.add(handshakeType, 1)
        w.add(bodyLength, 3)
        return w
120
121
class ClientHello(HandshakeMsg):
    """ClientHello handshake message.

    Parsing handles both the SSLv2-compatible layout (when constructed with
    ssl2=True) and the SSLv3/TLS layout with its optional extension block.
    """

    def __init__(self, ssl2=False):
        self.contentType = ContentType.handshake
        self.ssl2 = ssl2
        self.client_version = (0,0)
        self.random = createByteArrayZeros(32)
        self.session_id = createByteArraySequence([])
        self.cipher_suites = []         # a list of 16-bit values
        self.certificate_types = [CertificateType.x509]
        self.compression_methods = []   # a list of 8-bit values
        self.srp_username = None        # a string
        self.channel_id = False

    def create(self, version, random, session_id, cipher_suites,
               certificate_types=None, srp_username=None):
        """Populate the fields for sending (null compression only); return self."""
        self.client_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suites = cipher_suites
        self.certificate_types = certificate_types
        self.compression_methods = [0]
        self.srp_username = srp_username
        return self

    def parse(self, p):
        """Parse a ClientHello body from parser *p* and return self."""
        if self.ssl2:
            self.client_version = (p.get(1), p.get(1))
            cipherSpecsLength = p.get(2)
            sessionIDLength = p.get(2)
            randomLength = p.get(2)
            # Cipher specs are 3 bytes each in this layout. Explicit floor
            # division gives the same result as int(x/3) for these
            # non-negative counts.
            self.cipher_suites = p.getFixList(3, cipherSpecsLength // 3)
            self.session_id = p.getFixBytes(sessionIDLength)
            self.random = p.getFixBytes(randomLength)
            if len(self.random) < 32:
                # Left-pad a short random with zeros up to 32 bytes.
                zeroBytes = 32 - len(self.random)
                self.random = createByteArrayZeros(zeroBytes) + self.random
            self.compression_methods = [0]  # Fake this value
            # No startLengthCheck() was issued on this path, so there is no
            # matching stopLengthCheck() either.
        else:
            p.startLengthCheck(3)
            self.client_version = (p.get(1), p.get(1))
            self.random = p.getFixBytes(32)
            self.session_id = p.getVarBytes(1)
            self.cipher_suites = p.getVarList(2, 2)
            self.compression_methods = p.getVarList(1, 1)
            if not p.atLengthCheck():
                totalExtLength = p.get(2)
                soFar = 0
                while soFar != totalExtLength:
                    extType = p.get(2)
                    extLength = p.get(2)
                    if extType == 6:
                        # SRP: 1-byte-length-prefixed username.
                        self.srp_username = bytesToString(p.getVarBytes(1))
                    elif extType == 7:
                        # Certificate types: 1-byte-length-prefixed list.
                        self.certificate_types = p.getVarList(1, 1)
                    elif extType == ExtensionType.channel_id:
                        self.channel_id = True
                        # Consume the payload so the parser position stays in
                        # step with soFar. This is a no-op when extLength is 0.
                        p.getFixBytes(extLength)
                    else:
                        p.getFixBytes(extLength)
                    soFar += 4 + extLength
            p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize in SSLv3/TLS layout (see Msg.preWrite for *trial*).

        Only the certificate-type (7) and SRP (6) extensions are emitted;
        channel_id is never written here.
        """
        w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial)
        w.add(self.client_version[0], 1)
        w.add(self.client_version[1], 1)
        w.addFixSeq(self.random, 1)
        w.addVarSeq(self.session_id, 1, 1)
        w.addVarSeq(self.cipher_suites, 2, 2)
        w.addVarSeq(self.compression_methods, 1, 1)

        # The certificate-type extension is sent only when something other
        # than plain X.509 is offered.
        sendCertTypes = bool(self.certificate_types) and \
            self.certificate_types != [CertificateType.x509]

        # Each extension costs 2 (type) + 2 (length) + 1 (inner length
        # prefix) bytes of overhead plus its payload.
        extLength = 0
        if sendCertTypes:
            extLength += 5 + len(self.certificate_types)
        if self.srp_username:
            extLength += 5 + len(self.srp_username)
        if extLength > 0:
            w.add(extLength, 2)

        if sendCertTypes:
            w.add(7, 2)
            w.add(len(self.certificate_types) + 1, 2)
            w.addVarSeq(self.certificate_types, 1, 1)
        if self.srp_username:
            w.add(6, 2)
            w.add(len(self.srp_username) + 1, 2)
            w.addVarSeq(stringToBytes(self.srp_username), 1, 1)

        return HandshakeMsg.postWrite(self, w, trial)
215
216
class ServerHello(HandshakeMsg):
    """ServerHello handshake message.

    On write it can carry a certificate-type (7) extension and an empty
    channel_id extension.
    """

    def __init__(self):
        self.contentType = ContentType.handshake
        self.server_version = (0, 0)
        self.random = createByteArrayZeros(32)
        self.session_id = createByteArraySequence([])
        self.cipher_suite = 0
        self.certificate_type = CertificateType.x509
        self.compression_method = 0
        self.channel_id = False

    def create(self, version, random, session_id, cipher_suite,
               certificate_type):
        """Set the negotiated parameters (compression forced to 0); return self."""
        self.server_version = version
        self.random = random
        self.session_id = session_id
        self.cipher_suite = cipher_suite
        self.certificate_type = certificate_type
        self.compression_method = 0
        return self

    def parse(self, p):
        """Parse a ServerHello body from *p*; return self.

        Only the certificate-type extension is interpreted; any other
        extension is skipped.
        """
        p.startLengthCheck(3)
        major = p.get(1)
        minor = p.get(1)
        self.server_version = (major, minor)
        self.random = p.getFixBytes(32)
        self.session_id = p.getVarBytes(1)
        self.cipher_suite = p.get(2)
        self.compression_method = p.get(1)
        if not p.atLengthCheck():
            remaining = p.get(2)
            while remaining != 0:
                extType = p.get(2)
                extLength = p.get(2)
                if extType == 7:
                    self.certificate_type = p.get(1)
                else:
                    p.getFixBytes(extLength)
                remaining -= 4 + extLength
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize this message (see Msg.preWrite for *trial*)."""
        w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial)
        w.add(self.server_version[0], 1)
        w.add(self.server_version[1], 1)
        w.addFixSeq(self.random, 1)
        w.addVarSeq(self.session_id, 1, 1)
        w.add(self.cipher_suite, 2)
        w.add(self.compression_method, 1)

        sendCertType = bool(self.certificate_type) and \
            self.certificate_type != CertificateType.x509

        # Certificate type: 2 (type) + 2 (length) + 1 (value).
        # channel_id: 2 (type) + 2 (length), empty body.
        extLength = (5 if sendCertType else 0) + (4 if self.channel_id else 0)
        if extLength:
            w.add(extLength, 2)

        if sendCertType:
            w.add(7, 2)
            w.add(1, 2)
            w.add(self.certificate_type, 1)
        if self.channel_id:
            w.add(ExtensionType.channel_id, 2)
            w.add(0, 2)

        return HandshakeMsg.postWrite(self, w, trial)
290
class Certificate(HandshakeMsg):
    """Certificate handshake message carrying an X.509 or cryptoID chain.

    The chain format is selected by the certificateType given to the
    constructor.
    """

    def __init__(self, certificateType):
        self.certificateType = certificateType
        self.contentType = ContentType.handshake
        self.certChain = None

    def create(self, certChain):
        """Attach the chain to send; return self."""
        self.certChain = certChain
        return self

    def parse(self, p):
        """Parse the certificate chain from *p* into self.certChain; return self.

        An empty X.509 list or an empty cryptoID string leaves certChain
        unchanged. Raises SyntaxError if a cryptoID chain arrives but
        cryptoIDlib cannot be imported, and AssertionError for an unknown
        certificateType.
        """
        p.startLengthCheck(3)
        if self.certificateType == CertificateType.x509:
            chainLength = p.get(3)
            index = 0
            certificate_list = []
            while index != chainLength:
                certBytes = p.getVarBytes(3)
                x509 = X509()
                x509.parseBinary(certBytes)
                certificate_list.append(x509)
                # Each entry consumed its payload plus a 3-byte length prefix.
                index += len(certBytes)+3
            if certificate_list:
                self.certChain = X509CertChain(certificate_list)
        elif self.certificateType == CertificateType.cryptoID:
            s = bytesToString(p.getVarBytes(2))
            if s:
                try:
                    import cryptoIDlib.CertChain
                except ImportError:
                    raise SyntaxError(\
                    "cryptoID cert chain received, cryptoIDlib not present")
                self.certChain = cryptoIDlib.CertChain.CertChain().parse(s)
        else:
            raise AssertionError()

        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize the chain (see Msg.preWrite for *trial*).

        Raises AssertionError for an unknown certificateType.
        """
        w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial)
        if self.certificateType == CertificateType.x509:
            if self.certChain:
                certificate_list = self.certChain.x509List
            else:
                certificate_list = []
            # Encode each certificate once and reuse the result for both the
            # total length and the body. Previously every cert was encoded twice.
            encodedCerts = [cert.writeBytes() for cert in certificate_list]
            # Every entry is preceded by a 3-byte length prefix.
            chainLength = sum(len(certBytes) + 3 for certBytes in encodedCerts)
            w.add(chainLength, 3)
            for certBytes in encodedCerts:
                w.addVarSeq(certBytes, 1, 3)
        elif self.certificateType == CertificateType.cryptoID:
            if self.certChain:
                chainBytes = stringToBytes(self.certChain.write())
            else:
                chainBytes = createByteArraySequence([])
            w.addVarSeq(chainBytes, 1, 2)
        else:
            raise AssertionError()
        return HandshakeMsg.postWrite(self, w, trial)
356
class CertificateRequest(HandshakeMsg):
    """CertificateRequest handshake message.

    It carries the acceptable certificate types and a list of CA
    distinguished names, each a 2-byte-length-prefixed byte string.
    """

    def __init__(self):
        self.contentType = ContentType.handshake
        # Apple's Secure Transport library rejects empty certificate_types,
        # so default to rsa_sign.
        self.certificate_types = [ClientCertificateType.rsa_sign]
        self.certificate_authorities = []

    def create(self, certificate_types, certificate_authorities):
        """Set the requested types and CA names; return self."""
        self.certificate_types = certificate_types
        self.certificate_authorities = certificate_authorities
        return self

    def parse(self, p):
        """Parse the types list and CA name list from *p*; return self."""
        p.startLengthCheck(3)
        self.certificate_types = p.getVarList(1, 1)
        bytesLeft = p.get(2)
        self.certificate_authorities = []
        while bytesLeft != 0:
            distinguishedName = p.getVarBytes(2)
            self.certificate_authorities.append(distinguishedName)
            bytesLeft -= len(distinguishedName) + 2
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize this message (see Msg.preWrite for *trial*)."""
        w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request,
                                  trial)
        w.addVarSeq(self.certificate_types, 1, 1)
        # Total size of the CA list: each name plus its 2-byte length prefix.
        w.add(sum(len(dn) + 2 for dn in self.certificate_authorities), 2)
        for dn in self.certificate_authorities:
            w.addVarSeq(dn, 1, 2)
        return HandshakeMsg.postWrite(self, w, trial)
396
class ServerKeyExchange(HandshakeMsg):
    """ServerKeyExchange for SRP suites.

    It carries N, g, the salt s and B. A trailing signature is included when
    cipherSuite is in CipherSuite.srpRsaSuites.
    """

    def __init__(self, cipherSuite):
        self.cipherSuite = cipherSuite
        self.contentType = ContentType.handshake
        # Plain 0 compares equal to the Python 2-only literal 0L it replaces,
        # and it keeps this class valid syntax on Python 3.
        self.srp_N = 0
        self.srp_g = 0
        self.srp_s = createByteArraySequence([])
        self.srp_B = 0
        self.signature = createByteArraySequence([])

    def createSRP(self, srp_N, srp_g, srp_s, srp_B):
        """Set the SRP parameters; return self."""
        self.srp_N = srp_N
        self.srp_g = srp_g
        self.srp_s = srp_s
        self.srp_B = srp_B
        return self

    def parse(self, p):
        """Parse the SRP parameters (and signature, for SRP-RSA suites); return self."""
        p.startLengthCheck(3)
        self.srp_N = bytesToNumber(p.getVarBytes(2))
        self.srp_g = bytesToNumber(p.getVarBytes(2))
        self.srp_s = p.getVarBytes(1)
        self.srp_B = bytesToNumber(p.getVarBytes(2))
        if self.cipherSuite in CipherSuite.srpRsaSuites:
            self.signature = p.getVarBytes(2)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize this message (see Msg.preWrite for *trial*)."""
        w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange,
                                  trial)
        w.addVarSeq(numberToBytes(self.srp_N), 1, 2)
        w.addVarSeq(numberToBytes(self.srp_g), 1, 2)
        w.addVarSeq(self.srp_s, 1, 1)
        w.addVarSeq(numberToBytes(self.srp_B), 1, 2)
        if self.cipherSuite in CipherSuite.srpRsaSuites:
            w.addVarSeq(self.signature, 1, 2)
        return HandshakeMsg.postWrite(self, w, trial)

    def hash(self, clientRandom, serverRandom):
        """Return MD5(data) + SHA-1(data).

        *data* is clientRandom + serverRandom + the serialized body. The body
        is written with cipherSuite temporarily set to None, so write() omits
        the signature, and the 4-byte handshake header is stripped.
        cipherSuite is restored even if writing fails.
        """
        # Use hashlib directly. When the module-level fallback imports run,
        # the names md5/sha are bound to hashlib *functions*, which have no
        # .md5/.sha attribute, so md5.md5()/sha.sha() would raise.
        import hashlib
        oldCipherSuite = self.cipherSuite
        self.cipherSuite = None
        try:
            signedBytes = clientRandom + serverRandom + self.write()[4:]
            s = bytesToString(signedBytes)
            return stringToBytes(hashlib.md5(s).digest() +
                                 hashlib.sha1(s).digest())
        finally:
            self.cipherSuite = oldCipherSuite
445
class ServerHelloDone(HandshakeMsg):
    """ServerHelloDone handshake message; it has no body fields."""

    def __init__(self):
        self.contentType = ContentType.handshake

    def create(self):
        """Nothing to configure; return self."""
        return self

    def parse(self, p):
        """Open and immediately close the length check, since there is nothing to read."""
        p.startLengthCheck(3)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize the header-only message (see Msg.preWrite for *trial*)."""
        writer = HandshakeMsg.preWrite(self, HandshakeType.server_hello_done,
                                       trial)
        return HandshakeMsg.postWrite(self, writer, trial)
461
class ClientKeyExchange(HandshakeMsg):
    """ClientKeyExchange handshake message.

    It carries either the SRP public value A (SRP / SRP-RSA suites) or an
    RSA-encrypted premaster secret (RSA suites).
    """

    def __init__(self, cipherSuite, version=None):
        self.cipherSuite = cipherSuite
        self.version = version
        self.contentType = ContentType.handshake
        self.srp_A = 0
        self.encryptedPreMasterSecret = createByteArraySequence([])

    def createSRP(self, srp_A):
        """Set the SRP public value A; return self."""
        self.srp_A = srp_A
        return self

    def createRSA(self, encryptedPreMasterSecret):
        """Set the encrypted premaster secret; return self."""
        self.encryptedPreMasterSecret = encryptedPreMasterSecret
        return self

    def parse(self, p):
        """Parse the key-exchange payload from *p*; return self.

        Raises AssertionError for an unsupported suite or version.
        """
        p.startLengthCheck(3)
        suite = self.cipherSuite
        if suite in CipherSuite.srpSuites + CipherSuite.srpRsaSuites:
            self.srp_A = bytesToNumber(p.getVarBytes(2))
        elif suite not in CipherSuite.rsaSuites:
            raise AssertionError()
        elif self.version in ((3, 1), (3, 2)):
            # Version (3,1)/(3,2): 2-byte length prefix.
            self.encryptedPreMasterSecret = p.getVarBytes(2)
        elif self.version == (3, 0):
            # Version (3,0): no prefix; the secret fills the rest of the buffer.
            self.encryptedPreMasterSecret = \
                p.getFixBytes(len(p.bytes) - p.index)
        else:
            raise AssertionError()
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize this message (see Msg.preWrite for *trial*).

        Raises AssertionError for an unsupported suite or version.
        """
        w = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange,
                                  trial)
        suite = self.cipherSuite
        if suite in CipherSuite.srpSuites + CipherSuite.srpRsaSuites:
            w.addVarSeq(numberToBytes(self.srp_A), 1, 2)
        elif suite not in CipherSuite.rsaSuites:
            raise AssertionError()
        elif self.version in ((3, 1), (3, 2)):
            w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
        elif self.version == (3, 0):
            w.addFixSeq(self.encryptedPreMasterSecret, 1)
        else:
            raise AssertionError()
        return HandshakeMsg.postWrite(self, w, trial)
512
class CertificateVerify(HandshakeMsg):
    """CertificateVerify handshake message: a single 2-byte-length-prefixed
    signature."""

    def __init__(self):
        self.contentType = ContentType.handshake
        self.signature = createByteArraySequence([])

    def create(self, signature):
        """Set the signature bytes; return self."""
        self.signature = signature
        return self

    def parse(self, p):
        """Read the signature from *p*; return self."""
        p.startLengthCheck(3)
        self.signature = p.getVarBytes(2)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize this message (see Msg.preWrite for *trial*)."""
        writer = HandshakeMsg.preWrite(self, HandshakeType.certificate_verify,
                                       trial)
        writer.addVarSeq(self.signature, 1, 2)
        return HandshakeMsg.postWrite(self, writer, trial)
533
class ChangeCipherSpec(Msg):
    """ChangeCipherSpec message: a single type byte (set to 1 by default)."""

    def __init__(self):
        self.contentType = ContentType.change_cipher_spec
        self.type = 1

    def create(self):
        """Reset the type byte to 1; return self."""
        self.type = 1
        return self

    def parse(self, p):
        """Read the one-byte body from *p* under a 1-byte length check."""
        p.setLengthCheck(1)
        self.type = p.get(1)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize the single type byte (see Msg.preWrite for *trial*)."""
        writer = Msg.preWrite(self, trial)
        writer.add(self.type, 1)
        return Msg.postWrite(self, writer, trial)
553
554
class Finished(HandshakeMsg):
    """Finished handshake message holding verify_data.

    verify_data is 36 bytes for version (3,0) and 12 bytes for (3,1)/(3,2).
    """

    def __init__(self, version):
        self.contentType = ContentType.handshake
        self.version = version
        self.verify_data = createByteArraySequence([])

    def create(self, verify_data):
        """Set verify_data; return self."""
        self.verify_data = verify_data
        return self

    def parse(self, p):
        """Read fixed-size verify_data from *p*; return self.

        Raises AssertionError for an unsupported version.
        """
        p.startLengthCheck(3)
        if self.version == (3, 0):
            verifyLength = 36
        elif self.version in ((3, 1), (3, 2)):
            verifyLength = 12
        else:
            raise AssertionError()
        self.verify_data = p.getFixBytes(verifyLength)
        p.stopLengthCheck()
        return self

    def write(self, trial=False):
        """Serialize verify_data as-is (see Msg.preWrite for *trial*)."""
        writer = HandshakeMsg.preWrite(self, HandshakeType.finished, trial)
        writer.addFixSeq(self.verify_data, 1)
        return HandshakeMsg.postWrite(self, writer, trial)
580
class EncryptedExtensions(HandshakeMsg):
    """EncryptedExtensions handshake message (parse only).

    Only the channel_id extension is interpreted; every other extension is
    skipped.
    """

    def __init__(self):
        # Set contentType like every other handshake message class.
        self.contentType = ContentType.handshake
        self.channel_id_key = None
        self.channel_id_proof = None

    def parse(self, p):
        """Parse the extension list from *p*; return self.

        A channel_id extension must be exactly 128 bytes long. It is stored
        as a 64-byte channel_id_key followed by a 64-byte channel_id_proof.
        Any other length raises SyntaxError.
        """
        p.startLengthCheck(3)
        soFar = 0
        # Extensions are read until the consumed byte count equals
        # p.lengthCheck (presumably the body length recorded by
        # startLengthCheck).
        while soFar != p.lengthCheck:
            extType = p.get(2)
            extLength = p.get(2)
            if extType == ExtensionType.channel_id:
                if extLength != 32*4:
                    raise SyntaxError()
                self.channel_id_key = p.getFixBytes(64)
                self.channel_id_proof = p.getFixBytes(64)
            else:
                p.getFixBytes(extLength)
            soFar += 4 + extLength
        p.stopLengthCheck()
        return self
602
class ApplicationData(Msg):
    """Opaque application payload; parse() and write() pass the bytes through."""

    def __init__(self):
        self.contentType = ContentType.application_data
        self.bytes = createByteArraySequence([])

    def create(self, bytes):
        """Store the payload; return self."""
        self.bytes = bytes
        return self

    def parse(self, p):
        """Adopt the parser's entire underlying buffer, whatever its read position."""
        self.bytes = p.bytes
        return self

    def write(self):
        """Return the stored payload unchanged."""
        return self.bytes