3 # Google - handling CertificateRequest.certificate_types
4 # Google (adapted by Sam Rushing and Marcelo Fernandez) - NPN support
5 # Dimitris Moraitis - Anon ciphersuites
7 # See the LICENSE file for legal information regarding use of this file.
9 """Classes representing TLS messages."""
11 from .utils.compat import *
12 from .utils.cryptomath import *
14 from .utils.codec import *
15 from .constants import *
16 from .x509 import X509
17 from .x509certchain import X509CertChain
18 from .utils.tackwrapper import *
20 class RecordHeader3(object):
# Record-layer header, presumably in the SSLv3/TLS layout judging by the
# class name (contrast RecordHeader2, which reads a 1-byte length).
# Only part of the class is visible in this excerpt.
27 def create(self, version, type, length):
# NOTE(review): the `type` parameter shadows the builtin of the same name
# inside create().
29 self.version = version
# Serialization: the (major, minor) version tuple is emitted as two
# 1-byte fields.
36 w.add(self.version[0], 1)
37 w.add(self.version[1], 1)
# Parsing: version is read back as a (major, minor) tuple of ints,
# followed by a 2-byte record length.
43 self.version = (p.get(1), p.get(1))
44 self.length = p.get(2)
48 class RecordHeader2(object):
# Record header in the SSLv2 layout, judging by the class name. The
# visible code unconditionally sets the record type to
# ContentType.handshake.
58 self.type = ContentType.handshake
60 #We don't support 2-byte-length-headers; could be a problem
# NOTE(review): only a single length byte is read on the next line, per
# the comment above; other header-length variants are not handled by the
# visible code.
61 self.length = p.get(1)
67 self.contentType = ContentType.alert
# Presumably the Alert message class (its header is outside the visible
# lines). Its records always use ContentType.alert.
71 def create(self, description, level=AlertLevel.fatal):
# create() defaults the alert level to AlertLevel.fatal.
73 self.description = description
# The alert description is parsed and written as a single byte.
79 self.description = p.get(1)
86 w.add(self.description, 1)
class HandshakeMsg(object):
    """Common base for every handshake message.

    Subclasses serialize their own body into a Writer and finish with
    postWrite(), which prepends the handshake header.
    """

    def __init__(self, handshakeType):
        # All handshake messages travel inside handshake records.
        self.contentType = ContentType.handshake
        self.handshakeType = handshakeType

    def postWrite(self, w):
        """Return header + body for the body bytes accumulated in *w*.

        The header is the handshake type (1 byte) followed by the body
        length (3 bytes).
        """
        body = w.bytes
        header = Writer()
        header.add(self.handshakeType, 1)
        header.add(len(body), 3)
        return header.bytes + body
101 class ClientHello(HandshakeMsg):
# ClientHello handshake message. Holds the offered protocol version,
# client random, session id, cipher suites and compression methods, plus
# state for the extensions handled below (cert_type, srp, tack,
# supports_npn, server_name, channel_id, signed_cert_timestamps,
# status_request).
102 def __init__(self, ssl2=False):
# ssl2 presumably selects the SSLv2-compatible parse branch below (the
# one reading 3-byte cipher specs); its assignment is not visible here.
103 HandshakeMsg.__init__(self, HandshakeType.client_hello)
105 self.client_version = (0,0)
106 self.random = bytearray(32)
107 self.session_id = bytearray(0)
108 self.cipher_suites = [] # a list of 16-bit values
109 self.certificate_types = [CertificateType.x509]
110 self.compression_methods = [] # a list of 8-bit values
111 self.srp_username = None # a string
113 self.supports_npn = False
114 self.server_name = bytearray(0)
115 self.channel_id = False
116 self.support_signed_cert_timestamps = False
117 self.status_request = False
# Populate the fields of an outgoing ClientHello. Compression is always
# offered as [0].
119 def create(self, version, random, session_id, cipher_suites,
120 certificate_types=None, srpUsername=None,
121 tack=False, supports_npn=False, serverName=None):
122 self.client_version = version
124 self.session_id = session_id
125 self.cipher_suites = cipher_suites
# NOTE(review): certificate_types may be None here. write() only emits
# the cert_type extension when this value is truthy, so None is tolerated
# there.
126 self.certificate_types = certificate_types
127 self.compression_methods = [0]
# NOTE(review): bytearray(None, "utf-8") raises TypeError. This line
# relies on a None-check for srpUsername that is not visible here.
129 self.srp_username = bytearray(srpUsername, "utf-8")
131 self.supports_npn = supports_npn
# NOTE(review): same concern as above; serverName must not be None here.
# Presumably it is guarded on a line not shown.
133 self.server_name = bytearray(serverName, "utf-8")
# SSLv2-compatible ClientHello: version, three 2-byte lengths, then the
# cipher specs (3 bytes each), the session id and the challenge (stored
# in self.random).
138 self.client_version = (p.get(1), p.get(1))
139 cipherSpecsLength = p.get(2)
140 sessionIDLength = p.get(2)
141 randomLength = p.get(2)
142 self.cipher_suites = p.getFixList(3, cipherSpecsLength//3)
143 self.session_id = p.getFixBytes(sessionIDLength)
144 self.random = p.getFixBytes(randomLength)
# Left-pad a short challenge with zero bytes so self.random is always
# 32 bytes long.
145 if len(self.random) < 32:
146 zeroBytes = 32-len(self.random)
147 self.random = bytearray(zeroBytes) + self.random
# No compression list is read on this path; [0] is assumed.
148 self.compression_methods = [0]#Fake this value
150 #We're not doing a stopLengthCheck() for SSLv2, oh well..
# Regular ClientHello: the body is bounded by the 3-byte handshake length.
152 p.startLengthCheck(3)
153 self.client_version = (p.get(1), p.get(1))
154 self.random = p.getFixBytes(32)
155 self.session_id = p.getVarBytes(1)
156 self.cipher_suites = p.getVarList(2, 2)
157 self.compression_methods = p.getVarList(1, 1)
# Extensions are optional; they are present only if bytes remain inside
# the handshake body.
158 if not p.atLengthCheck():
159 totalExtLength = p.get(2)
# Each extension accounts for 4 header bytes plus extLength data bytes
# (see the soFar update at the bottom of the loop).
161 while soFar != totalExtLength:
165 if extType == ExtensionType.srp:
166 self.srp_username = p.getVarBytes(1)
167 elif extType == ExtensionType.cert_type:
168 self.certificate_types = p.getVarList(1, 1)
169 elif extType == ExtensionType.tack:
171 elif extType == ExtensionType.supports_npn:
172 self.supports_npn = True
173 elif extType == ExtensionType.server_name:
# server_name: re-parse the extension body as a list of (name_type, name)
# entries and keep the host_name entry; other name types are ignored.
174 serverNameListBytes = p.getFixBytes(extLength)
175 p2 = Parser(serverNameListBytes)
176 p2.startLengthCheck(2)
# Stop when the list is exhausted without finding a host_name.
178 if p2.atLengthCheck():
179 break # no host_name, oh well
180 name_type = p2.get(1)
181 hostNameBytes = p2.getVarBytes(2)
182 if name_type == NameType.host_name:
183 self.server_name = hostNameBytes
185 elif extType == ExtensionType.channel_id:
186 self.channel_id = True
187 elif extType == ExtensionType.signed_cert_timestamps:
190 self.support_signed_cert_timestamps = True
191 elif extType == ExtensionType.status_request:
192 # Extension contents are currently ignored.
193 # According to RFC 6066, this is not strictly forbidden
194 # (although it is suboptimal):
195 # Servers that receive a client hello containing the
196 # "status_request" extension MAY return a suitable
197 # certificate status response to the client along with
198 # their certificate. If OCSP is requested, they
199 # SHOULD use the information contained in the extension
200 # when selecting an OCSP responder and SHOULD include
201 # request_extensions in the OCSP request.
202 p.getFixBytes(extLength)
203 self.status_request = True
# Presumably the fallback for unrecognized extensions: skip their data.
205 _ = p.getFixBytes(extLength)
# index1/index2 are presumably parser positions captured before and after
# the branch above. Reject an extension whose handler did not consume
# exactly extLength bytes.
207 if index2 - index1 != extLength:
208 raise SyntaxError("Bad length for extension_data")
209 soFar += 4 + extLength
# Serialization of the regular (non-SSLv2) ClientHello body.
215 w.add(self.client_version[0], 1)
216 w.add(self.client_version[1], 1)
217 w.addFixSeq(self.random, 1)
218 w.addVarSeq(self.session_id, 1, 1)
219 w.addVarSeq(self.cipher_suites, 2, 2)
220 w.addVarSeq(self.compression_methods, 1, 1)
222 w2 = Writer() # For Extensions
# cert_type is only sent when something other than plain X.509 is
# offered. Extension length = number of types + 1-byte list-length prefix.
223 if self.certificate_types and self.certificate_types != \
224 [CertificateType.x509]:
225 w2.add(ExtensionType.cert_type, 2)
226 w2.add(len(self.certificate_types)+1, 2)
227 w2.addVarSeq(self.certificate_types, 1, 1)
# srp: the username with a 1-byte length prefix (hence the +1).
228 if self.srp_username:
229 w2.add(ExtensionType.srp, 2)
230 w2.add(len(self.srp_username)+1, 2)
231 w2.addVarSeq(self.srp_username, 1, 1)
232 if self.supports_npn:
233 w2.add(ExtensionType.supports_npn, 2)
# server_name: extension length = name + 5 (2-byte list length, 1-byte
# name type, 2-byte name length); list length = name + 3.
236 w2.add(ExtensionType.server_name, 2)
237 w2.add(len(self.server_name)+5, 2)
238 w2.add(len(self.server_name)+3, 2)
239 w2.add(NameType.host_name, 1)
240 w2.addVarSeq(self.server_name, 1, 2)
242 w2.add(ExtensionType.tack, 2)
# Total length of the extension block, as a 2-byte prefix.
245 w.add(len(w2.bytes), 2)
247 return self.postWrite(w)
249 class BadNextProtos(Exception):
# Raised when a next-protocol (NPN) list entry has an unusable length.
# The constructor argument is the offending length; the message below
# reports it as self.length.
250 def __init__(self, l):
254 return 'Cannot encode a list of next protocols because it contains an element with invalid length %d. Element lengths must be 0 < x < 256' % self.length
256 class ServerHello(HandshakeMsg):
# ServerHello handshake message: selected version, server random, session
# id, cipher suite, compression method, and the extensions written below
# (cert_type, tack, NPN, channel_id, signed_cert_timestamps,
# status_request).
258 HandshakeMsg.__init__(self, HandshakeType.server_hello)
259 self.server_version = (0,0)
260 self.random = bytearray(32)
261 self.session_id = bytearray(0)
262 self.cipher_suite = 0
263 self.certificate_type = CertificateType.x509
264 self.compression_method = 0
# NPN state. next_protos_advertised is what write() encodes;
# next_protos is what parse() decodes from a received message.
266 self.next_protos_advertised = None
267 self.next_protos = None
268 self.channel_id = False
# When set, write() emits these bytes with a 2-byte length prefix.
269 self.signed_cert_timestamps = None
270 self.status_request = False
272 def create(self, version, random, session_id, cipher_suite,
273 certificate_type, tackExt, next_protos_advertised):
274 self.server_version = version
276 self.session_id = session_id
277 self.cipher_suite = cipher_suite
278 self.certificate_type = certificate_type
# create() always selects compression method 0.
279 self.compression_method = 0
280 self.tackExt = tackExt
281 self.next_protos_advertised = next_protos_advertised
# Parsing: fixed fields first, then optional extensions.
285 p.startLengthCheck(3)
286 self.server_version = (p.get(1), p.get(1))
287 self.random = p.getFixBytes(32)
288 self.session_id = p.getVarBytes(1)
289 self.cipher_suite = p.get(2)
290 self.compression_method = p.get(1)
291 if not p.atLengthCheck():
292 totalExtLength = p.get(2)
294 while soFar != totalExtLength:
297 if extType == ExtensionType.cert_type:
# In ServerHello, cert_type carries the single selected type byte.
300 self.certificate_type = p.get(1)
# TACK data is only decoded when tackpyLoaded is true.
301 elif extType == ExtensionType.tack and tackpyLoaded:
302 self.tackExt = TackExtension(p.getFixBytes(extLength))
303 elif extType == ExtensionType.supports_npn:
304 self.next_protos = self.__parse_next_protos(p.getFixBytes(extLength))
# Presumably the fallback branch: other extensions are skipped.
306 p.getFixBytes(extLength)
307 soFar += 4 + extLength
# Decode the NPN protocol list from an extension body. The visible error
# path raises BadNextProtos(len(b)).
311 def __parse_next_protos(self, b):
319 raise BadNextProtos(len(b))
# Encode next_protos_advertised as a sequence of 1-byte-length-prefixed
# entries. Each entry must be 1..255 bytes long, or BadNextProtos is
# raised.
324 def __next_protos_encoded(self):
326 for e in self.next_protos_advertised:
327 if len(e) > 255 or len(e) == 0:
328 raise BadNextProtos(len(e))
329 b += bytearray( [len(e)] ) + bytearray(e)
# Serialization: fixed fields, then the extension block.
334 w.add(self.server_version[0], 1)
335 w.add(self.server_version[1], 1)
336 w.addFixSeq(self.random, 1)
337 w.addVarSeq(self.session_id, 1, 1)
338 w.add(self.cipher_suite, 2)
339 w.add(self.compression_method, 1)
341 w2 = Writer() # For Extensions
# cert_type is only echoed when a non-X.509 type was selected.
342 if self.certificate_type and self.certificate_type != \
343 CertificateType.x509:
344 w2.add(ExtensionType.cert_type, 2)
346 w2.add(self.certificate_type, 1)
348 b = self.tackExt.serialize()
349 w2.add(ExtensionType.tack, 2)
352 if self.next_protos_advertised is not None:
353 encoded_next_protos_advertised = self.__next_protos_encoded()
354 w2.add(ExtensionType.supports_npn, 2)
355 w2.add(len(encoded_next_protos_advertised), 2)
356 w2.addFixSeq(encoded_next_protos_advertised, 1)
358 w2.add(ExtensionType.channel_id, 2)
# The SCT data is written with a 2-byte length prefix.
360 if self.signed_cert_timestamps:
361 w2.add(ExtensionType.signed_cert_timestamps, 2)
362 w2.addVarSeq(bytearray(self.signed_cert_timestamps), 1, 2)
363 if self.status_request:
364 w2.add(ExtensionType.status_request, 2)
# Total length of the extension block, as a 2-byte prefix.
367 w.add(len(w2.bytes), 2)
369 return self.postWrite(w)
372 class Certificate(HandshakeMsg):
# Certificate handshake message. Only CertificateType.x509 is handled;
# any other certificateType reaches the AssertionError branches below.
373 def __init__(self, certificateType):
374 HandshakeMsg.__init__(self, HandshakeType.certificate)
375 self.certificateType = certificateType
376 self.certChain = None
378 def create(self, certChain):
379 self.certChain = certChain
383 p.startLengthCheck(3)
384 if self.certificateType == CertificateType.x509:
# Parsing: a 3-byte total chain length, then each certificate as a
# 3-byte-length-prefixed binary blob parsed into an X509 object.
385 chainLength = p.get(3)
387 certificate_list = []
# index counts the bytes consumed so far (each entry is its data plus a
# 3-byte length prefix).
388 while index != chainLength:
389 certBytes = p.getVarBytes(3)
391 x509.parseBinary(certBytes)
392 certificate_list.append(x509)
393 index += len(certBytes)+3
395 self.certChain = X509CertChain(certificate_list)
397 raise AssertionError()
404 if self.certificateType == CertificateType.x509:
407 certificate_list = self.certChain.x509List
409 certificate_list = []
# First pass: compute the total chain length.
# NOTE(review): writeBytes() is called twice per certificate (here and in
# the second loop). Caching the encodings would avoid the repeat. `bytes`
# also shadows the builtin.
411 for cert in certificate_list:
412 bytes = cert.writeBytes()
413 chainLength += len(bytes)+3
415 w.add(chainLength, 3)
# Second pass: emit each certificate with its 3-byte length prefix.
416 for cert in certificate_list:
417 bytes = cert.writeBytes()
418 w.addVarSeq(bytes, 1, 3)
420 raise AssertionError()
421 return self.postWrite(w)
423 class CertificateStatus(HandshakeMsg):
# CertificateStatus handshake message carrying an OCSP response (the
# reply to a ClientHello status_request extension).
425 HandshakeMsg.__init__(self, HandshakeType.certificate_status)
427 def create(self, ocsp_response):
428 self.ocsp_response = ocsp_response
431 # Defined for the sake of completeness, even though we currently only
432 # support sending the status message (server-side), not requesting
433 # or receiving it (client-side).
435 p.startLengthCheck(3)
436 status_type = p.get(1)
437 # Only one type is specified, so hardwire it.
438 if status_type != CertificateStatusType.ocsp:
# The OCSP response is read with a 3-byte length prefix. How a non-OCSP
# type or an empty response is handled is on lines not shown here.
440 ocsp_response = p.getVarBytes(3)
441 if not ocsp_response:
444 self.ocsp_response = ocsp_response
# Serialization: status type byte (always OCSP), then the response with a
# 3-byte length prefix.
450 w.add(CertificateStatusType.ocsp, 1)
451 w.addVarSeq(bytearray(self.ocsp_response), 1, 3)
452 return self.postWrite(w)
454 class CertificateRequest(HandshakeMsg):
# CertificateRequest handshake message: the acceptable client certificate
# types and a list of CA distinguished names (kept as raw bytes).
456 HandshakeMsg.__init__(self, HandshakeType.certificate_request)
457 self.certificate_types = []
458 self.certificate_authorities = []
460 def create(self, certificate_types, certificate_authorities):
461 self.certificate_types = certificate_types
462 self.certificate_authorities = certificate_authorities
# Parsing: a 1-byte-length list of 1-byte certificate types, then a
# 2-byte total CA list length followed by 2-byte-length-prefixed DNs.
466 p.startLengthCheck(3)
467 self.certificate_types = p.getVarList(1, 1)
468 ca_list_length = p.get(2)
470 self.certificate_authorities = []
# index counts consumed bytes (each DN is its data plus a 2-byte prefix).
471 while index != ca_list_length:
472 ca_bytes = p.getVarBytes(2)
473 self.certificate_authorities.append(ca_bytes)
474 index += len(ca_bytes)+2
# Serialization mirrors parsing: compute the total CA list length first,
# then emit each DN with its 2-byte length prefix.
480 w.addVarSeq(self.certificate_types, 1, 1)
483 for ca_dn in self.certificate_authorities:
484 caLength += len(ca_dn)+2
487 for ca_dn in self.certificate_authorities:
488 w.addVarSeq(ca_dn, 1, 2)
489 return self.postWrite(w)
491 class ServerKeyExchange(HandshakeMsg):
# ServerKeyExchange handshake message. The cipher suite decides which
# parameters are present: SRP (N, g, salt s, B) or DH (p, g, Ys),
# optionally followed by a signature.
492 def __init__(self, cipherSuite):
493 HandshakeMsg.__init__(self, HandshakeType.server_key_exchange)
494 self.cipherSuite = cipherSuite
497 self.srp_s = bytearray(0)
503 self.signature = bytearray(0)
505 def createSRP(self, srp_N, srp_g, srp_s, srp_B):
512 def createDH(self, dh_p, dh_g, dh_Ys):
519 p.startLengthCheck(3)
520 if self.cipherSuite in CipherSuite.srpAllSuites:
# N, g and B are integers with 2-byte length prefixes; the salt has a
# 1-byte length prefix and stays as bytes.
521 self.srp_N = bytesToNumber(p.getVarBytes(2))
522 self.srp_g = bytesToNumber(p.getVarBytes(2))
523 self.srp_s = p.getVarBytes(1)
524 self.srp_B = bytesToNumber(p.getVarBytes(2))
# Only SRP suites in srpCertSuites carry a signature here.
525 if self.cipherSuite in CipherSuite.srpCertSuites:
526 self.signature = p.getVarBytes(2)
527 elif self.cipherSuite in CipherSuite.anonSuites:
528 self.dh_p = bytesToNumber(p.getVarBytes(2))
529 self.dh_g = bytesToNumber(p.getVarBytes(2))
530 self.dh_Ys = bytesToNumber(p.getVarBytes(2))
# Serialize only the key-exchange parameters (no signature). The output
# also feeds hash().
534 def write_params(self):
# NOTE(review): parse() selects DH via CipherSuite.anonSuites, while this
# method uses CipherSuite.dhAllSuites. Confirm the two sets agree for
# every suite this class is used with.
536 if self.cipherSuite in CipherSuite.srpAllSuites:
537 w.addVarSeq(numberToByteArray(self.srp_N), 1, 2)
538 w.addVarSeq(numberToByteArray(self.srp_g), 1, 2)
539 w.addVarSeq(self.srp_s, 1, 1)
540 w.addVarSeq(numberToByteArray(self.srp_B), 1, 2)
541 elif self.cipherSuite in CipherSuite.dhAllSuites:
542 w.addVarSeq(numberToByteArray(self.dh_p), 1, 2)
543 w.addVarSeq(numberToByteArray(self.dh_g), 1, 2)
544 w.addVarSeq(numberToByteArray(self.dh_Ys), 1, 2)
551 w.bytes += self.write_params()
# NOTE(review): write() appends a signature for any suite in
# CipherSuite.certAllSuites, while parse() only reads one for
# CipherSuite.srpCertSuites. Confirm these agree.
552 if self.cipherSuite in CipherSuite.certAllSuites:
553 w.addVarSeq(self.signature, 1, 2)
554 return self.postWrite(w)
# Return MD5(data) followed by SHA1(data), where data is
# clientRandom + serverRandom + the serialized parameters. This is
# presumably the digest the signature covers. `bytes` shadows the builtin.
556 def hash(self, clientRandom, serverRandom):
557 bytes = clientRandom + serverRandom + self.write_params()
558 return MD5(bytes) + SHA1(bytes)
560 class ServerHelloDone(HandshakeMsg):
# ServerHelloDone handshake message. In the visible lines, parse() only
# starts the 3-byte length check and write() returns the output of
# postWrite(); no body fields are read or written.
562 HandshakeMsg.__init__(self, HandshakeType.server_hello_done)
568 p.startLengthCheck(3)
574 return self.postWrite(w)
576 class ClientKeyExchange(HandshakeMsg):
# ClientKeyExchange handshake message. The cipher suite selects SRP
# (client value A), RSA (encrypted premaster secret, see createRSA) or DH
# (client value Yc). The protocol version selects the RSA encoding.
577 def __init__(self, cipherSuite, version=None):
578 HandshakeMsg.__init__(self, HandshakeType.client_key_exchange)
579 self.cipherSuite = cipherSuite
580 self.version = version
582 self.encryptedPreMasterSecret = bytearray(0)
584 def createSRP(self, srp_A):
588 def createRSA(self, encryptedPreMasterSecret):
589 self.encryptedPreMasterSecret = encryptedPreMasterSecret
592 def createDH(self, dh_Yc):
597 p.startLengthCheck(3)
598 if self.cipherSuite in CipherSuite.srpAllSuites:
599 self.srp_A = bytesToNumber(p.getVarBytes(2))
600 elif self.cipherSuite in CipherSuite.certSuites:
# TLS 1.0/1.1 prefix the encrypted premaster secret with a 2-byte length.
# SSLv3 sends it bare, so it runs to the end of the message.
# NOTE(review): any other version, e.g. (3,3), raises AssertionError.
601 if self.version in ((3,1), (3,2)):
602 self.encryptedPreMasterSecret = p.getVarBytes(2)
603 elif self.version == (3,0):
604 self.encryptedPreMasterSecret = \
605 p.getFixBytes(len(p.bytes)-p.index)
607 raise AssertionError()
608 elif self.cipherSuite in CipherSuite.dhAllSuites:
609 self.dh_Yc = bytesToNumber(p.getVarBytes(2))
611 raise AssertionError()
# Serialization mirrors parse() for the SRP and RSA cases.
617 if self.cipherSuite in CipherSuite.srpAllSuites:
618 w.addVarSeq(numberToByteArray(self.srp_A), 1, 2)
619 elif self.cipherSuite in CipherSuite.certSuites:
620 if self.version in ((3,1), (3,2)):
621 w.addVarSeq(self.encryptedPreMasterSecret, 1, 2)
622 elif self.version == (3,0):
623 w.addFixSeq(self.encryptedPreMasterSecret, 1)
625 raise AssertionError()
# NOTE(review): parse() recognizes DH via CipherSuite.dhAllSuites, but
# write() uses CipherSuite.anonSuites. Confirm the sets agree.
626 elif self.cipherSuite in CipherSuite.anonSuites:
627 w.addVarSeq(numberToByteArray(self.dh_Yc), 1, 2)
629 raise AssertionError()
630 return self.postWrite(w)
632 class CertificateVerify(HandshakeMsg):
# CertificateVerify handshake message: a single signature, parsed and
# written with a 2-byte length prefix.
634 HandshakeMsg.__init__(self, HandshakeType.certificate_verify)
635 self.signature = bytearray(0)
637 def create(self, signature):
638 self.signature = signature
642 p.startLengthCheck(3)
643 self.signature = p.getVarBytes(2)
649 w.addVarSeq(self.signature, 1, 2)
650 return self.postWrite(w)
652 class ChangeCipherSpec(object):
# ChangeCipherSpec derives from object rather than HandshakeMsg, so it
# sets its own record content type. The rest of the class is not visible
# in this excerpt.
654 self.contentType = ContentType.change_cipher_spec
673 class NextProtocol(HandshakeMsg):
# NPN NextProtocol handshake message: the selected protocol followed by
# padding.
675 HandshakeMsg.__init__(self, HandshakeType.next_protocol)
676 self.next_proto = None
678 def create(self, next_proto):
679 self.next_proto = next_proto
683 p.startLengthCheck(3)
684 self.next_proto = p.getVarBytes(1)
# NOTE(review): `trial` is not used in the visible lines of write().
689 def write(self, trial=False):
691 w.addVarSeq(self.next_proto, 1, 1)
# Assuming w starts empty, the body is 1-byte length + protocol +
# 1-byte length + padding. This choice of paddingLen makes that total a
# multiple of 32. paddingLen is always 1..32, so an already-aligned
# protocol still gets 32 bytes of padding.
692 paddingLen = 32 - ((len(self.next_proto) + 2) % 32)
693 w.addVarSeq(bytearray(paddingLen), 1, 1)
694 return self.postWrite(w)
696 class Finished(HandshakeMsg):
# Finished handshake message. The verify_data length depends on the
# protocol version: 36 bytes for SSLv3 (3,0), 12 bytes for TLS 1.0/1.1
# ((3,1), (3,2)).
697 def __init__(self, version):
698 HandshakeMsg.__init__(self, HandshakeType.finished)
699 self.version = version
700 self.verify_data = bytearray(0)
702 def create(self, verify_data):
703 self.verify_data = verify_data
707 p.startLengthCheck(3)
708 if self.version == (3,0):
709 self.verify_data = p.getFixBytes(36)
710 elif self.version in ((3,1), (3,2)):
711 self.verify_data = p.getFixBytes(12)
# NOTE(review): any other version, e.g. (3,3), raises AssertionError.
713 raise AssertionError()
# verify_data is written as-is, without a length prefix.
719 w.addFixSeq(self.verify_data, 1)
720 return self.postWrite(w)
722 class EncryptedExtensions(HandshakeMsg):
# EncryptedExtensions handshake message. In the visible code only the
# channel_id extension is interpreted; other extensions appear to be
# skipped.
724 self.channel_id_key = None
725 self.channel_id_proof = None
728 p.startLengthCheck(3)
# Walk the extensions until the length-checked body is consumed. Each
# extension accounts for 4 header bytes plus its data.
730 while soFar != p.lengthCheck:
733 if extType == ExtensionType.channel_id:
# channel_id data is expected to be exactly 128 bytes (32*4): a 64-byte
# key followed by a 64-byte proof. The mismatch handling is on a line not
# shown.
734 if extLength != 32*4:
736 self.channel_id_key = p.getFixBytes(64)
737 self.channel_id_proof = p.getFixBytes(64)
739 p.getFixBytes(extLength)
740 soFar += 4 + extLength
744 class ApplicationData(object):
746 self.contentType = ContentType.application_data
747 self.bytes = bytearray(0)
749 def create(self, bytes):
753 def splitFirstByte(self):
754 newMsg = ApplicationData().create(self.bytes[:1])
755 self.bytes = self.bytes[1:]