1 """Classes representing TLS messages."""
3 from utils.compat import *
4 from utils.cryptomath import *
6 from utils.codec import *
7 from constants import *
9 from X509CertChain import X509CertChain
11 # The sha module is deprecated in Python 2.6
15 from hashlib import sha1 as sha
17 # The md5 module is deprecated in Python 2.6
21 from hashlib import md5
# Fragment of a record-header class; its class line and several body lines
# are not visible in this view.
# create(): stores the protocol version tuple (the other assignments are not
# shown here).
30 def create(self, version, type, length):
32 self.version = version
# Serialization emits version[0] then version[1], one byte each.
39 w.add(self.version[0], 1)
40 w.add(self.version[1], 1)
# Parsing reads the version back as a 2-tuple of single bytes, followed by a
# 2-byte length field.
46 self.version = (p.get(1), p.get(1))
47 self.length = p.get(2)
# Fragment of a header parser: it always tags the record as a handshake and
# reads a single-byte length.
# NOTE(review): per the original comment below, 2-byte-length headers are not
# handled here. Confirm which header form callers feed into this path.
61 self.type = ContentType.handshake
63 #We don't support 2-byte-length-headers; could be a problem
64 self.length = p.get(1)
# Serialization hooks shared by the message classes (bodies only partly
# visible). Subclasses' write() methods call preWrite(...) to obtain a writer
# `w`, append their fields, then return postWrite(self, w, trial).
# preWrite calls self.write(True) to get a length. It appears to be a
# "trial" serialization used for sizing; the surrounding condition is not
# visible here, so confirm.
69 def preWrite(self, trial):
73 length = self.write(True)
77 def postWrite(self, w, trial):
# Alert message fragment: the record content type is alert.
85 self.contentType = ContentType.alert
# create(): the level defaults to AlertLevel.fatal when not given.
89 def create(self, description, level=AlertLevel.fatal):
91 self.description = description
# The description is a single byte on the wire: it is read with p.get(1) and
# written with w.add(..., 1).
97 self.description = p.get(1)
104 w.add(self.description, 1)
# Base class for handshake-layer messages. Subclass write() methods call
# HandshakeMsg.preWrite(self, <HandshakeType>, trial) and finish with
# HandshakeMsg.postWrite(self, w, trial).
108 class HandshakeMsg(Msg):
109 def preWrite(self, handshakeType, trial):
# The handshake type is written as one byte. It appears twice below,
# presumably once on the trial path and once on the real path, where
# self.write(True) supplies the body length. The branching lines are not
# visible; confirm.
112 w.add(handshakeType, 1)
115 length = self.write(True)
117 w.add(handshakeType, 1)
# ClientHello handshake message. parse() handles both an SSLv2-compatible
# layout and the regular layout; the visible write() emits only the regular
# layout.
122 class ClientHello(HandshakeMsg):
123 def __init__(self, ssl2=False):
124 self.contentType = ContentType.handshake
126 self.client_version = (0,0)
127 self.random = createByteArrayZeros(32)
128 self.session_id = createByteArraySequence([])
129 self.cipher_suites = [] # a list of 16-bit values
130 self.certificate_types = [CertificateType.x509]
131 self.compression_methods = [] # a list of 8-bit values
132 self.srp_username = None # a string
# True when the peer's hello carried a channel_id extension (set in parse()).
133 self.channel_id = False
# create(): populate the fields for sending. Compression is always forced to
# [0]. certificate_types may be None, which write() treats as "no extension".
135 def create(self, version, random, session_id, cipher_suites,
136 certificate_types=None, srp_username=None):
137 self.client_version = version
139 self.session_id = session_id
140 self.cipher_suites = cipher_suites
141 self.certificate_types = certificate_types
142 self.compression_methods = [0]
143 self.srp_username = srp_username
# SSLv2-compatible parse path: three 2-byte length fields precede the cipher
# specs (3 bytes each), the session id, and the challenge/random.
148 self.client_version = (p.get(1), p.get(1))
149 cipherSpecsLength = p.get(2)
150 sessionIDLength = p.get(2)
151 randomLength = p.get(2)
152 self.cipher_suites = p.getFixList(3, int(cipherSpecsLength/3))
153 self.session_id = p.getFixBytes(sessionIDLength)
154 self.random = p.getFixBytes(randomLength)
# A random shorter than 32 bytes is left-padded with zero bytes up to 32.
155 if len(self.random) < 32:
156 zeroBytes = 32-len(self.random)
157 self.random = createByteArrayZeros(zeroBytes) + self.random
158 self.compression_methods = [0]#Fake this value
160 #We're not doing a stopLengthCheck() for SSLv2, oh well..
# Regular parse path: fixed 32-byte random, then length-prefixed session id,
# cipher suites (2-byte entries), and compression methods.
162 p.startLengthCheck(3)
163 self.client_version = (p.get(1), p.get(1))
164 self.random = p.getFixBytes(32)
165 self.session_id = p.getVarBytes(1)
166 self.cipher_suites = p.getVarList(2, 2)
167 self.compression_methods = p.getVarList(1, 1)
# Extensions are optional: they are parsed only if bytes remain in the
# message.
168 if not p.atLengthCheck():
169 totalExtLength = p.get(2)
# NOTE(review): the loop ends on exact equality, so an extension length
# that overshoots totalExtLength relies on the parser raising once data
# runs out. Confirm the parser does so.
171 while soFar != totalExtLength:
175 self.srp_username = bytesToString(p.getVarBytes(1))
177 self.certificate_types = p.getVarList(1, 1)
178 elif extType == ExtensionType.channel_id:
179 self.channel_id = True
# Unrecognized extensions are skipped by consuming extLength bytes.
181 p.getFixBytes(extLength)
# The 4 presumably covers each extension's type and length fields
# (not visible here).
182 soFar += 4 + extLength
# write(): serializes the regular layout. Extensions are emitted only for a
# certificate-type list other than [x509] and for an SRP username.
# NOTE(review): channel_id is parsed above but never written here.
186 def write(self, trial=False):
187 w = HandshakeMsg.preWrite(self, HandshakeType.client_hello, trial)
188 w.add(self.client_version[0], 1)
189 w.add(self.client_version[1], 1)
190 w.addFixSeq(self.random, 1)
191 w.addVarSeq(self.session_id, 1, 1)
192 w.addVarSeq(self.cipher_suites, 2, 2)
193 w.addVarSeq(self.compression_methods, 1, 1)
# Size pass: 5 = the 2-byte data length and 1-byte list length written
# below, plus (presumably) a 2-byte extension type on a line not visible.
196 if self.certificate_types and self.certificate_types != \
197 [CertificateType.x509]:
198 extLength += 5 + len(self.certificate_types)
199 if self.srp_username:
200 extLength += 5 + len(self.srp_username)
# Emit pass: must stay in sync with the size pass above.
204 if self.certificate_types and self.certificate_types != \
205 [CertificateType.x509]:
207 w.add(len(self.certificate_types)+1, 2)
208 w.addVarSeq(self.certificate_types, 1, 1)
209 if self.srp_username:
211 w.add(len(self.srp_username)+1, 2)
212 w.addVarSeq(stringToBytes(self.srp_username), 1, 1)
214 return HandshakeMsg.postWrite(self, w, trial)
# ServerHello handshake message: the server's version, random, session id,
# chosen cipher suite, and optional extensions.
217 class ServerHello(HandshakeMsg):
219 self.contentType = ContentType.handshake
220 self.server_version = (0,0)
221 self.random = createByteArrayZeros(32)
222 self.session_id = createByteArraySequence([])
223 self.cipher_suite = 0
224 self.certificate_type = CertificateType.x509
225 self.compression_method = 0
226 self.channel_id = False
# create(): fill the fields for sending. The compression method is always 0.
228 def create(self, version, random, session_id, cipher_suite,
230 self.server_version = version
232 self.session_id = session_id
233 self.cipher_suite = cipher_suite
234 self.certificate_type = certificate_type
235 self.compression_method = 0
# parse(): fixed 32-byte random, 1-byte-length session id, 2-byte cipher
# suite, 1-byte compression method, then optional extensions.
239 p.startLengthCheck(3)
240 self.server_version = (p.get(1), p.get(1))
241 self.random = p.getFixBytes(32)
242 self.session_id = p.getVarBytes(1)
243 self.cipher_suite = p.get(2)
244 self.compression_method = p.get(1)
245 if not p.atLengthCheck():
246 totalExtLength = p.get(2)
# NOTE(review): exact-equality termination, same as ClientHello.parse.
248 while soFar != totalExtLength:
252 self.certificate_type = p.get(1)
# Unrecognized extensions are skipped.
254 p.getFixBytes(extLength)
255 soFar += 4 + extLength
# write(): the certificate-type extension is only emitted for a non-x509
# type. A channel_id extension type is also written; the condition guarding
# it is on lines not visible here.
259 def write(self, trial=False):
260 w = HandshakeMsg.preWrite(self, HandshakeType.server_hello, trial)
261 w.add(self.server_version[0], 1)
262 w.add(self.server_version[1], 1)
263 w.addFixSeq(self.random, 1)
264 w.addVarSeq(self.session_id, 1, 1)
265 w.add(self.cipher_suite, 2)
266 w.add(self.compression_method, 1)
269 if self.certificate_type and self.certificate_type != \
270 CertificateType.x509:
279 if self.certificate_type and self.certificate_type != \
280 CertificateType.x509:
283 w.add(self.certificate_type, 1)
286 w.add(ExtensionType.channel_id, 2)
289 return HandshakeMsg.postWrite(self, w, trial)
# Certificate handshake message. It carries either an X.509 chain or a
# cryptoID chain, selected by the certificateType given at construction.
291 class Certificate(HandshakeMsg):
292 def __init__(self, certificateType):
293 self.certificateType = certificateType
294 self.contentType = ContentType.handshake
295 self.certChain = None
297 def create(self, certChain):
298 self.certChain = certChain
302 p.startLengthCheck(3)
303 if self.certificateType == CertificateType.x509:
# X.509: a 3-byte total length, then certificates each with a 3-byte
# length prefix (hence the +3 per entry).
304 chainLength = p.get(3)
306 certificate_list = []
# NOTE(review): exact-equality termination; inconsistent lengths rely on
# the parser raising when data runs out.
307 while index != chainLength:
308 certBytes = p.getVarBytes(3)
310 x509.parseBinary(certBytes)
311 certificate_list.append(x509)
312 index += len(certBytes)+3
314 self.certChain = X509CertChain(certificate_list)
315 elif self.certificateType == CertificateType.cryptoID:
# cryptoID: a single 2-byte-length blob. cryptoIDlib is imported lazily, and
# its absence is reported with the message below.
316 s = bytesToString(p.getVarBytes(2))
319 import cryptoIDlib.CertChain
322 "cryptoID cert chain received, cryptoIDlib not present")
323 self.certChain = cryptoIDlib.CertChain.CertChain().parse(s)
# Any other certificate type is rejected.
325 raise AssertionError()
330 def write(self, trial=False):
331 w = HandshakeMsg.preWrite(self, HandshakeType.certificate, trial)
332 if self.certificateType == CertificateType.x509:
# An empty list is used when there is no chain (the guarding condition is
# not visible here).
335 certificate_list = self.certChain.x509List
337 certificate_list = []
# First pass computes the total chain length; second pass writes each cert.
# NOTE(review): cert.writeBytes() is called twice per certificate; it could
# be computed once and reused.
339 for cert in certificate_list:
340 bytes = cert.writeBytes()
341 chainLength += len(bytes)+3
343 w.add(chainLength, 3)
344 for cert in certificate_list:
345 bytes = cert.writeBytes()
346 w.addVarSeq(bytes, 1, 3)
347 elif self.certificateType == CertificateType.cryptoID:
# The chain is serialized to a string; an empty byte array appears to be
# used when there is no chain (the condition is not visible).
349 bytes = stringToBytes(self.certChain.write())
351 bytes = createByteArraySequence([])
352 w.addVarSeq(bytes, 1, 2)
354 raise AssertionError()
355 return HandshakeMsg.postWrite(self, w, trial)
# CertificateRequest handshake message: the acceptable client certificate
# types plus a list of CA distinguished names (kept as raw bytes).
357 class CertificateRequest(HandshakeMsg):
359 self.contentType = ContentType.handshake
360 #Apple's Secure Transport library rejects empty certificate_types, so
361 #default to rsa_sign.
362 self.certificate_types = [ClientCertificateType.rsa_sign]
363 self.certificate_authorities = []
365 def create(self, certificate_types, certificate_authorities):
366 self.certificate_types = certificate_types
367 self.certificate_authorities = certificate_authorities
# parse(): a 1-byte-length list of types, then a 2-byte total CA-list
# length, with each DN carrying its own 2-byte length prefix (hence the +2).
371 p.startLengthCheck(3)
372 self.certificate_types = p.getVarList(1, 1)
373 ca_list_length = p.get(2)
375 self.certificate_authorities = []
376 while index != ca_list_length:
377 ca_bytes = p.getVarBytes(2)
378 self.certificate_authorities.append(ca_bytes)
379 index += len(ca_bytes)+2
# write(): mirrors parse(). It sums the DN lengths first, then writes
# each DN.
383 def write(self, trial=False):
384 w = HandshakeMsg.preWrite(self, HandshakeType.certificate_request,
386 w.addVarSeq(self.certificate_types, 1, 1)
389 for ca_dn in self.certificate_authorities:
390 caLength += len(ca_dn)+2
393 for ca_dn in self.certificate_authorities:
394 w.addVarSeq(ca_dn, 1, 2)
395 return HandshakeMsg.postWrite(self, w, trial)
# ServerKeyExchange for SRP suites: N, g, salt s, and B, plus a signature
# when the negotiated suite is in CipherSuite.srpRsaSuites.
397 class ServerKeyExchange(HandshakeMsg):
398 def __init__(self, cipherSuite):
399 self.cipherSuite = cipherSuite
400 self.contentType = ContentType.handshake
403 self.srp_s = createByteArraySequence([])
405 self.signature = createByteArraySequence([])
407 def createSRP(self, srp_N, srp_g, srp_s, srp_B):
# parse(): N, g, and B are big integers with 2-byte length prefixes; the salt
# has a 1-byte length prefix.
415 p.startLengthCheck(3)
416 self.srp_N = bytesToNumber(p.getVarBytes(2))
417 self.srp_g = bytesToNumber(p.getVarBytes(2))
418 self.srp_s = p.getVarBytes(1)
419 self.srp_B = bytesToNumber(p.getVarBytes(2))
# The signature is present only for SRP+RSA suites.
420 if self.cipherSuite in CipherSuite.srpRsaSuites:
421 self.signature = p.getVarBytes(2)
# write(): mirrors parse(). Note that hash() relies on cipherSuite being None
# to make this skip the signature.
425 def write(self, trial=False):
426 w = HandshakeMsg.preWrite(self, HandshakeType.server_key_exchange,
428 w.addVarSeq(numberToBytes(self.srp_N), 1, 2)
429 w.addVarSeq(numberToBytes(self.srp_g), 1, 2)
430 w.addVarSeq(self.srp_s, 1, 1)
431 w.addVarSeq(numberToBytes(self.srp_B), 1, 2)
432 if self.cipherSuite in CipherSuite.srpRsaSuites:
433 w.addVarSeq(self.signature, 1, 2)
434 return HandshakeMsg.postWrite(self, w, trial)
def hash(self, clientRandom, serverRandom):
    """Return MD5 || SHA-1 over the randoms and the unsigned SRP parameters.

    The hashed data is clientRandom + serverRandom + the serialized message
    with its first 4 bytes dropped (presumably the handshake header) and
    without the signature. cipherSuite is temporarily set to None so that
    write() does not append the signature; it is always restored, even if
    write() raises.

    hashlib is used directly rather than the module-level ``md5``/``sha``
    names. Those names may be bound to hashlib's constructors (the fallback
    imports), which have no ``.md5``/``.sha`` attribute.
    """
    import hashlib
    oldCipherSuite = self.cipherSuite
    self.cipherSuite = None
    try:
        signedData = clientRandom + serverRandom + self.write()[4:]
        s = bytesToString(signedData)
        return stringToBytes(hashlib.md5(s).digest() +
                             hashlib.sha1(s).digest())
    finally:
        self.cipherSuite = oldCipherSuite
# ServerHelloDone handshake message. The visible write() adds no body
# fields between preWrite and postWrite.
446 class ServerHelloDone(HandshakeMsg):
448 self.contentType = ContentType.handshake
454 p.startLengthCheck(3)
458 def write(self, trial=False):
459 w = HandshakeMsg.preWrite(self, HandshakeType.server_hello_done, trial)
460 return HandshakeMsg.postWrite(self, w, trial)
# ClientKeyExchange. It carries either the SRP value A (for SRP and SRP+RSA
# suites) or an RSA-encrypted premaster secret (for RSA suites). The
# encoding of the latter depends on the protocol version.
462 class ClientKeyExchange(HandshakeMsg):
463 def __init__(self, cipherSuite, version=None):
464 self.cipherSuite = cipherSuite
465 self.version = version
466 self.contentType = ContentType.handshake
468 self.encryptedPreMasterSecret = createByteArraySequence([])
470 def createSRP(self, srp_A):
474 def createRSA(self, encryptedPreMasterSecret):
475 self.encryptedPreMasterSecret = encryptedPreMasterSecret
479 p.startLengthCheck(3)
480 if self.cipherSuite in CipherSuite.srpSuites + \
481 CipherSuite.srpRsaSuites:
482 self.srp_A = bytesToNumber(p.getVarBytes(2))
483 elif self.cipherSuite in CipherSuite.rsaSuites:
# For (3,1)/(3,2) the ciphertext has a 2-byte length prefix. For (3,0) it
# has no prefix and consumes the rest of the message.
484 if self.version in ((3,1), (3,2)):
485 self.encryptedPreMasterSecret = p.getVarBytes(2)
486 elif self.version == (3,0):
487 self.encryptedPreMasterSecret = \
488 p.getFixBytes(len(p.bytes)-p.index)
# These raises appear to be the fall-through branches (their `else:` lines
# are not visible).
# NOTE(review): any other version, e.g. (3,3), ends up here.
490 raise AssertionError()
492 raise AssertionError()
# write(): mirrors parse(), with the same version handling.
496 def write(self, trial=False):
497 w = HandshakeMsg.preWrite(self, HandshakeType.client_key_exchange,
499 if self.cipherSuite in CipherSuite.srpSuites + \
500 CipherSuite.srpRsaSuites:
501 w.addVarSeq(numberToBytes(self.srp_A), 1, 2)
502 elif self.cipherSuite in CipherSuite.rsaSuites:
503 if self.version in ((3,1), (3,2)):
504 w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
505 elif self.version == (3,0):
506 w.addFixSeq(self.encryptedPreMasterSecret, 1)
508 raise AssertionError()
510 raise AssertionError()
511 return HandshakeMsg.postWrite(self, w, trial)
# CertificateVerify handshake message. Its only field is the signature,
# carried with a 2-byte length prefix in both parse() and write().
513 class CertificateVerify(HandshakeMsg):
515 self.contentType = ContentType.handshake
516 self.signature = createByteArraySequence([])
518 def create(self, signature):
519 self.signature = signature
523 p.startLengthCheck(3)
524 self.signature = p.getVarBytes(2)
528 def write(self, trial=False):
529 w = HandshakeMsg.preWrite(self, HandshakeType.certificate_verify,
531 w.addVarSeq(self.signature, 1, 2)
532 return HandshakeMsg.postWrite(self, w, trial)
# ChangeCipherSpec message. It derives from Msg rather than HandshakeMsg,
# so write() uses Msg.preWrite/postWrite and no handshake-type byte is
# emitted.
534 class ChangeCipherSpec(Msg):
536 self.contentType = ContentType.change_cipher_spec
549 def write(self, trial=False):
550 w = Msg.preWrite(self, trial)
552 return Msg.postWrite(self, w, trial)
# Finished handshake message. The size of verify_data depends on the
# version: 36 bytes for (3,0), 12 bytes for (3,1)/(3,2). Other versions are
# rejected in parse().
555 class Finished(HandshakeMsg):
556 def __init__(self, version):
557 self.contentType = ContentType.handshake
558 self.version = version
559 self.verify_data = createByteArraySequence([])
561 def create(self, verify_data):
562 self.verify_data = verify_data
566 p.startLengthCheck(3)
567 if self.version == (3,0):
568 self.verify_data = p.getFixBytes(36)
569 elif self.version in ((3,1), (3,2)):
570 self.verify_data = p.getFixBytes(12)
572 raise AssertionError()
# write() emits verify_data as-is; its length is not checked against the
# version here.
576 def write(self, trial=False):
577 w = HandshakeMsg.preWrite(self, HandshakeType.finished, trial)
578 w.addFixSeq(self.verify_data, 1)
579 return HandshakeMsg.postWrite(self, w, trial)
# EncryptedExtensions handshake message. Only the channel_id extension is
# interpreted; every other extension is skipped.
581 class EncryptedExtensions(HandshakeMsg):
583 self.channel_id_key = None
584 self.channel_id_proof = None
587 p.startLengthCheck(3)
# Extensions are consumed until soFar equals the parser's lengthCheck value.
# NOTE(review): exact-equality termination, as in the hello parsers.
589 while soFar != p.lengthCheck:
592 if extType == ExtensionType.channel_id:
# channel_id data must be exactly 128 bytes: a 64-byte key followed by a
# 64-byte proof. The action taken on a mismatch is on a line not visible
# here.
593 if extLength != 32*4:
595 self.channel_id_key = p.getFixBytes(64)
596 self.channel_id_proof = p.getFixBytes(64)
598 p.getFixBytes(extLength)
599 soFar += 4 + extLength
603 class ApplicationData(Msg):
605 self.contentType = ContentType.application_data
606 self.bytes = createByteArraySequence([])
608 def create(self, bytes):