Upstream version 5.34.104.0
[platform/framework/web/crosswalk.git] / src / third_party / tlslite / tlslite / utils / cryptomath.py
1 """cryptomath module
2
3 This module has basic math/crypto code."""
4
5 import os
6 import math
7 import base64
8 import binascii
9
10 # The sha module is deprecated in Python 2.6 
11 try:
12     import sha
13 except ImportError:
14     from hashlib import sha1 as sha
15
16 # The md5 module is deprecated in Python 2.6
17 try:
18     import md5
19 except ImportError:
20     from hashlib import md5
21
22 from compat import *
23
24
25 # **************************************************************************
26 # Load Optional Modules
27 # **************************************************************************
28
29 # Try to load M2Crypto/OpenSSL
30 try:
31     from M2Crypto import m2
32     m2cryptoLoaded = True
33
34 except ImportError:
35     m2cryptoLoaded = False
36
37
38 # Try to load cryptlib
39 try:
40     import cryptlib_py
41     try:
42         cryptlib_py.cryptInit()
43     except cryptlib_py.CryptException, e:
44         #If tlslite and cryptoIDlib are both present,
45         #they might each try to re-initialize this,
46         #so we're tolerant of that.
47         if e[0] != cryptlib_py.CRYPT_ERROR_INITED:
48             raise
49     cryptlibpyLoaded = True
50
51 except ImportError:
52     cryptlibpyLoaded = False
53
54 #Try to load GMPY
55 try:
56     import gmpy
57     gmpyLoaded = True
58 except ImportError:
59     gmpyLoaded = False
60
61 #Try to load pycrypto
62 try:
63     import Crypto.Cipher.AES
64     pycryptoLoaded = True
65 except ImportError:
66     pycryptoLoaded = False
67
68
69 # **************************************************************************
70 # PRNG Functions
71 # **************************************************************************
72
73 # Get os.urandom PRNG
74 try:
75     os.urandom(1)
76     def getRandomBytes(howMany):
77         return stringToBytes(os.urandom(howMany))
78     prngName = "os.urandom"
79
80 except:
81     # Else get cryptlib PRNG
82     if cryptlibpyLoaded:
83         def getRandomBytes(howMany):
84             randomKey = cryptlib_py.cryptCreateContext(cryptlib_py.CRYPT_UNUSED,
85                                                        cryptlib_py.CRYPT_ALGO_AES)
86             cryptlib_py.cryptSetAttribute(randomKey,
87                                           cryptlib_py.CRYPT_CTXINFO_MODE,
88                                           cryptlib_py.CRYPT_MODE_OFB)
89             cryptlib_py.cryptGenerateKey(randomKey)
90             bytes = createByteArrayZeros(howMany)
91             cryptlib_py.cryptEncrypt(randomKey, bytes)
92             return bytes
93         prngName = "cryptlib"
94
95     else:
96         #Else get UNIX /dev/urandom PRNG
97         try:
98             devRandomFile = open("/dev/urandom", "rb")
99             def getRandomBytes(howMany):
100                 return stringToBytes(devRandomFile.read(howMany))
101             prngName = "/dev/urandom"
102         except IOError:
103             #Else get Win32 CryptoAPI PRNG
104             try:
105                 import win32prng
106                 def getRandomBytes(howMany):
107                     s = win32prng.getRandomBytes(howMany)
108                     if len(s) != howMany:
109                         raise AssertionError()
110                     return stringToBytes(s)
111                 prngName ="CryptoAPI"
112             except ImportError:
113                 #Else no PRNG :-(
114                 def getRandomBytes(howMany):
115                     raise NotImplementedError("No Random Number Generator "\
116                                               "available.")
117             prngName = "None"
118
119 # **************************************************************************
120 # Converter Functions
121 # **************************************************************************
122
123 def bytesToNumber(bytes):
124     total = 0L
125     multiplier = 1L
126     for count in range(len(bytes)-1, -1, -1):
127         byte = bytes[count]
128         total += multiplier * byte
129         multiplier *= 256
130     return total
131
132 def numberToBytes(n, howManyBytes=None):
133     if howManyBytes == None:
134       howManyBytes = numBytes(n)
135     bytes = createByteArrayZeros(howManyBytes)
136     for count in range(howManyBytes-1, -1, -1):
137         bytes[count] = int(n % 256)
138         n >>= 8
139     return bytes
140
141 def bytesToBase64(bytes):
142     s = bytesToString(bytes)
143     return stringToBase64(s)
144
145 def base64ToBytes(s):
146     s = base64ToString(s)
147     return stringToBytes(s)
148
149 def numberToBase64(n):
150     bytes = numberToBytes(n)
151     return bytesToBase64(bytes)
152
153 def base64ToNumber(s):
154     bytes = base64ToBytes(s)
155     return bytesToNumber(bytes)
156
157 def stringToNumber(s):
158     bytes = stringToBytes(s)
159     return bytesToNumber(bytes)
160
161 def numberToString(s):
162     bytes = numberToBytes(s)
163     return bytesToString(bytes)
164
165 def base64ToString(s):
166     try:
167         return base64.decodestring(s)
168     except binascii.Error, e:
169         raise SyntaxError(e)
170     except binascii.Incomplete, e:
171         raise SyntaxError(e)
172
173 def stringToBase64(s):
174     return base64.encodestring(s).replace("\n", "")
175
176 def mpiToNumber(mpi): #mpi is an openssl-format bignum string
177     if (ord(mpi[4]) & 0x80) !=0: #Make sure this is a positive number
178         raise AssertionError()
179     bytes = stringToBytes(mpi[4:])
180     return bytesToNumber(bytes)
181
182 def numberToMPI(n):
183     bytes = numberToBytes(n)
184     ext = 0
185     #If the high-order bit is going to be set,
186     #add an extra byte of zeros
187     if (numBits(n) & 0x7)==0:
188         ext = 1
189     length = numBytes(n) + ext
190     bytes = concatArrays(createByteArrayZeros(4+ext), bytes)
191     bytes[0] = (length >> 24) & 0xFF
192     bytes[1] = (length >> 16) & 0xFF
193     bytes[2] = (length >> 8) & 0xFF
194     bytes[3] = length & 0xFF
195     return bytesToString(bytes)
196
197
198
199 # **************************************************************************
200 # Misc. Utility Functions
201 # **************************************************************************
202
203 def numBytes(n):
204     if n==0:
205         return 0
206     bits = numBits(n)
207     return int(math.ceil(bits / 8.0))
208
209 def hashAndBase64(s):
210     return stringToBase64(sha.sha(s).digest())
211
212 def getBase64Nonce(numChars=22): #defaults to an 132 bit nonce
213     bytes = getRandomBytes(numChars)
214     bytesStr = "".join([chr(b) for b in bytes])
215     return stringToBase64(bytesStr)[:numChars]
216
217
218 # **************************************************************************
219 # Big Number Math
220 # **************************************************************************
221
222 def getRandomNumber(low, high):
223     if low >= high:
224         raise AssertionError()
225     howManyBits = numBits(high)
226     howManyBytes = numBytes(high)
227     lastBits = howManyBits % 8
228     while 1:
229         bytes = getRandomBytes(howManyBytes)
230         if lastBits:
231             bytes[0] = bytes[0] % (1 << lastBits)
232         n = bytesToNumber(bytes)
233         if n >= low and n < high:
234             return n
235
236 def gcd(a,b):
237     a, b = max(a,b), min(a,b)
238     while b:
239         a, b = b, a % b
240     return a
241
242 def lcm(a, b):
243     #This will break when python division changes, but we can't use // cause
244     #of Jython
245     return (a * b) / gcd(a, b)
246
247 #Returns inverse of a mod b, zero if none
248 #Uses Extended Euclidean Algorithm
249 def invMod(a, b):
250     c, d = a, b
251     uc, ud = 1, 0
252     while c != 0:
253         #This will break when python division changes, but we can't use //
254         #cause of Jython
255         q = d / c
256         c, d = d-(q*c), c
257         uc, ud = ud - (q * uc), uc
258     if d == 1:
259         return ud % b
260     return 0
261
262
263 if gmpyLoaded:
264     def powMod(base, power, modulus):
265         base = gmpy.mpz(base)
266         power = gmpy.mpz(power)
267         modulus = gmpy.mpz(modulus)
268         result = pow(base, power, modulus)
269         return long(result)
270
271 else:
272     #Copied from Bryan G. Olson's post to comp.lang.python
273     #Does left-to-right instead of pow()'s right-to-left,
274     #thus about 30% faster than the python built-in with small bases
275     def powMod(base, power, modulus):
276         nBitScan = 5
277
278         """ Return base**power mod modulus, using multi bit scanning
279         with nBitScan bits at a time."""
280
281         #TREV - Added support for negative exponents
282         negativeResult = False
283         if (power < 0):
284             power *= -1
285             negativeResult = True
286
287         exp2 = 2**nBitScan
288         mask = exp2 - 1
289
290         # Break power into a list of digits of nBitScan bits.
291         # The list is recursive so easy to read in reverse direction.
292         nibbles = None
293         while power:
294             nibbles = int(power & mask), nibbles
295             power = power >> nBitScan
296
297         # Make a table of powers of base up to 2**nBitScan - 1
298         lowPowers = [1]
299         for i in xrange(1, exp2):
300             lowPowers.append((lowPowers[i-1] * base) % modulus)
301
302         # To exponentiate by the first nibble, look it up in the table
303         nib, nibbles = nibbles
304         prod = lowPowers[nib]
305
306         # For the rest, square nBitScan times, then multiply by
307         # base^nibble
308         while nibbles:
309             nib, nibbles = nibbles
310             for i in xrange(nBitScan):
311                 prod = (prod * prod) % modulus
312             if nib: prod = (prod * lowPowers[nib]) % modulus
313
314         #TREV - Added support for negative exponents
315         if negativeResult:
316             prodInv = invMod(prod, modulus)
317             #Check to make sure the inverse is correct
318             if (prod * prodInv) % modulus != 1:
319                 raise AssertionError()
320             return prodInv
321         return prod
322
323
324 #Pre-calculate a sieve of the ~100 primes < 1000:
325 def makeSieve(n):
326     sieve = range(n)
327     for count in range(2, int(math.sqrt(n))):
328         if sieve[count] == 0:
329             continue
330         x = sieve[count] * 2
331         while x < len(sieve):
332             sieve[x] = 0
333             x += sieve[count]
334     sieve = [x for x in sieve[2:] if x]
335     return sieve
336
337 sieve = makeSieve(1000)
338
339 def isPrime(n, iterations=5, display=False):
340     #Trial division with sieve
341     for x in sieve:
342         if x >= n: return True
343         if n % x == 0: return False
344     #Passed trial division, proceed to Rabin-Miller
345     #Rabin-Miller implemented per Ferguson & Schneier
346     #Compute s, t for Rabin-Miller
347     if display: print "*",
348     s, t = n-1, 0
349     while s % 2 == 0:
350         s, t = s/2, t+1
351     #Repeat Rabin-Miller x times
352     a = 2 #Use 2 as a base for first iteration speedup, per HAC
353     for count in range(iterations):
354         v = powMod(a, s, n)
355         if v==1:
356             continue
357         i = 0
358         while v != n-1:
359             if i == t-1:
360                 return False
361             else:
362                 v, i = powMod(v, 2, n), i+1
363         a = getRandomNumber(2, n)
364     return True
365
366 def getRandomPrime(bits, display=False):
367     if bits < 10:
368         raise AssertionError()
369     #The 1.5 ensures the 2 MSBs are set
370     #Thus, when used for p,q in RSA, n will have its MSB set
371     #
372     #Since 30 is lcm(2,3,5), we'll set our test numbers to
373     #29 % 30 and keep them there
374     low = (2L ** (bits-1)) * 3/2
375     high = 2L ** bits - 30
376     p = getRandomNumber(low, high)
377     p += 29 - (p % 30)
378     while 1:
379         if display: print ".",
380         p += 30
381         if p >= high:
382             p = getRandomNumber(low, high)
383             p += 29 - (p % 30)
384         if isPrime(p, display=display):
385             return p
386
387 #Unused at the moment...
388 def getRandomSafePrime(bits, display=False):
389     if bits < 10:
390         raise AssertionError()
391     #The 1.5 ensures the 2 MSBs are set
392     #Thus, when used for p,q in RSA, n will have its MSB set
393     #
394     #Since 30 is lcm(2,3,5), we'll set our test numbers to
395     #29 % 30 and keep them there
396     low = (2 ** (bits-2)) * 3/2
397     high = (2 ** (bits-1)) - 30
398     q = getRandomNumber(low, high)
399     q += 29 - (q % 30)
400     while 1:
401         if display: print ".",
402         q += 30
403         if (q >= high):
404             q = getRandomNumber(low, high)
405             q += 29 - (q % 30)
406         #Ideas from Tom Wu's SRP code
407         #Do trial division on p and q before Rabin-Miller
408         if isPrime(q, 0, display=display):
409             p = (2 * q) + 1
410             if isPrime(p, display=display):
411                 if isPrime(q, display=display):
412                     return p