3 This module has basic math/crypto code."""
10 # The sha module is deprecated in Python 2.6
14 from hashlib import sha1 as sha
16 # The md5 module is deprecated in Python 2.6
20 from hashlib import md5
25 # **************************************************************************
26 # Load Optional Modules
27 # **************************************************************************
29 # Try to load M2Crypto/OpenSSL
31 from M2Crypto import m2
35 m2cryptoLoaded = False
38 # Try to load cryptlib
42 cryptlib_py.cryptInit()
43 except cryptlib_py.CryptException, e:
44 #If tlslite and cryptoIDlib are both present,
45 #they might each try to re-initialize this,
46 #so we're tolerant of that.
47 if e[0] != cryptlib_py.CRYPT_ERROR_INITED:
49 cryptlibpyLoaded = True
52 cryptlibpyLoaded = False
63 import Crypto.Cipher.AES
66 pycryptoLoaded = False
69 # **************************************************************************
71 # **************************************************************************
76 def getRandomBytes(howMany):
77 return stringToBytes(os.urandom(howMany))
78 prngName = "os.urandom"
81 # Else get cryptlib PRNG
83 def getRandomBytes(howMany):
84 randomKey = cryptlib_py.cryptCreateContext(cryptlib_py.CRYPT_UNUSED,
85 cryptlib_py.CRYPT_ALGO_AES)
86 cryptlib_py.cryptSetAttribute(randomKey,
87 cryptlib_py.CRYPT_CTXINFO_MODE,
88 cryptlib_py.CRYPT_MODE_OFB)
89 cryptlib_py.cryptGenerateKey(randomKey)
90 bytes = createByteArrayZeros(howMany)
91 cryptlib_py.cryptEncrypt(randomKey, bytes)
96 #Else get UNIX /dev/urandom PRNG
98 devRandomFile = open("/dev/urandom", "rb")
99 def getRandomBytes(howMany):
100 return stringToBytes(devRandomFile.read(howMany))
101 prngName = "/dev/urandom"
103 #Else get Win32 CryptoAPI PRNG
106 def getRandomBytes(howMany):
107 s = win32prng.getRandomBytes(howMany)
108 if len(s) != howMany:
109 raise AssertionError()
110 return stringToBytes(s)
111 prngName ="CryptoAPI"
114 def getRandomBytes(howMany):
115 raise NotImplementedError("No Random Number Generator "\
119 # **************************************************************************
120 # Converter Functions
121 # **************************************************************************
123 def bytesToNumber(bytes):
126 for count in range(len(bytes)-1, -1, -1):
128 total += multiplier * byte
132 def numberToBytes(n):
133 howManyBytes = numBytes(n)
134 bytes = createByteArrayZeros(howManyBytes)
135 for count in range(howManyBytes-1, -1, -1):
136 bytes[count] = int(n % 256)
def bytesToBase64(bytes):
    """Encode a byte array as base64 text.

    The array is flattened to a string with bytesToString() and the result
    is handed to stringToBase64() for the actual encoding.
    """
    return stringToBase64(bytesToString(bytes))
def base64ToBytes(s):
    """Decode base64 text *s* into a byte array.

    Decoding (including any error handling) is delegated to
    base64ToString(); the decoded string is then converted with
    stringToBytes().
    """
    return stringToBytes(base64ToString(s))
def numberToBase64(n):
    """Encode the integer *n* as base64 text.

    The number is serialized with numberToBytes() (most significant byte
    first) and the resulting byte array is encoded with bytesToBase64().
    """
    return bytesToBase64(numberToBytes(n))
def base64ToNumber(s):
    """Decode base64 text *s* and interpret the bytes as an integer.

    Decoding is done by base64ToBytes(); the byte array is converted to a
    number by bytesToNumber().
    """
    return bytesToNumber(base64ToBytes(s))
def stringToNumber(s):
    """Interpret the raw byte string *s* as an integer.

    The string is first turned into a byte array by stringToBytes(), which
    bytesToNumber() then folds into a single number.
    """
    return bytesToNumber(stringToBytes(s))
def numberToString(s):
    """Serialize an integer into a raw byte string.

    Note: despite its name, the parameter *s* is the NUMBER to convert; the
    name is kept for compatibility with keyword callers. The integer is
    laid out by numberToBytes() and flattened with bytesToString().
    """
    return bytesToString(numberToBytes(s))
164 def base64ToString(s):
166 return base64.decodestring(s)
167 except binascii.Error, e:
169 except binascii.Incomplete, e:
def stringToBase64(s):
    """Return the base64 encoding of the byte string *s* on a single line.

    base64.b64encode() never inserts line breaks. It therefore produces
    exactly what encodestring() produced once its newlines were stripped.
    It also avoids base64.encodestring(), which is deprecated and was
    removed in Python 3.9.
    """
    return base64.b64encode(s)
def mpiToNumber(mpi): #mpi is an openssl-format bignum string
    """Decode an OpenSSL MPI string into a non-negative integer.

    An MPI is a 4-byte big-endian length header followed by the magnitude,
    most significant byte first. The top bit of the first magnitude byte
    marks a negative number, which this function rejects. A header with an
    empty magnitude is the encoding of zero.

    Raises AssertionError if *mpi* is shorter than the 4-byte header or
    encodes a negative number.
    """
    if len(mpi) < 4:
        # Not even a complete length header: malformed input.
        raise AssertionError()
    if len(mpi) == 4:
        # Zero-length magnitude encodes 0. Indexing mpi[4] below would
        # otherwise raise IndexError for this valid input.
        return 0
    if (ord(mpi[4]) & 0x80) != 0: #Make sure this is a positive number
        raise AssertionError()
    bytes = stringToBytes(mpi[4:])
    return bytesToNumber(bytes)
182 bytes = numberToBytes(n)
184 #If the high-order bit is going to be set,
185 #add an extra byte of zeros
186 if (numBits(n) & 0x7)==0:
188 length = numBytes(n) + ext
189 bytes = concatArrays(createByteArrayZeros(4+ext), bytes)
190 bytes[0] = (length >> 24) & 0xFF
191 bytes[1] = (length >> 16) & 0xFF
192 bytes[2] = (length >> 8) & 0xFF
193 bytes[3] = length & 0xFF
194 return bytesToString(bytes)
198 # **************************************************************************
199 # Misc. Utility Functions
200 # **************************************************************************
206 return int(math.ceil(bits / 8.0))
def hashAndBase64(s):
    """Return the SHA-1 digest of *s*, base64-encoded via stringToBase64().

    The module-level name ``sha`` may be the legacy ``sha`` module, which
    exposes the ``sha.sha`` constructor. It may instead be ``hashlib.sha1``,
    as bound by the ``from hashlib import sha1 as sha`` import at the top of
    the file. ``hashlib.sha1`` is itself the constructor and has no ``.sha``
    attribute, so calling ``sha.sha`` unconditionally fails on that path.
    Use whichever callable is available.
    """
    newSha1 = getattr(sha, "sha", sha)
    return stringToBase64(newSha1(s).digest())
def getBase64Nonce(numChars=22): #defaults to a 132-bit nonce
    """Return a random nonce made of *numChars* base64 characters.

    Each base64 character carries 6 bits, so the default of 22 characters
    yields a 132-bit nonce. Only ceil(6 * numChars / 8) random bytes are
    needed to fill that many characters. The base64 text of those bytes
    always has at least numChars non-padding characters, and it is
    truncated to exactly numChars.
    """
    # Draw just enough random bytes to cover numChars * 6 bits.
    howMany = int(math.ceil(numChars * 6 / 8.0))
    randomBytes = getRandomBytes(howMany)
    bytesStr = "".join([chr(b) for b in randomBytes])
    return stringToBase64(bytesStr)[:numChars]
217 # **************************************************************************
219 # **************************************************************************
221 def getRandomNumber(low, high):
223 raise AssertionError()
224 howManyBits = numBits(high)
225 howManyBytes = numBytes(high)
226 lastBits = howManyBits % 8
228 bytes = getRandomBytes(howManyBytes)
230 bytes[0] = bytes[0] % (1 << lastBits)
231 n = bytesToNumber(bytes)
232 if n >= low and n < high:
236 a, b = max(a,b), min(a,b)
242 #This will break when python division changes, but we can't use // cause
244 return (a * b) / gcd(a, b)
246 #Returns inverse of a mod b, zero if none
247 #Uses Extended Euclidean Algorithm
252 #This will break when python division changes, but we can't use //
256 uc, ud = ud - (q * uc), uc
263 def powMod(base, power, modulus):
264 base = gmpy.mpz(base)
265 power = gmpy.mpz(power)
266 modulus = gmpy.mpz(modulus)
267 result = pow(base, power, modulus)
271 #Copied from Bryan G. Olson's post to comp.lang.python
272 #Does left-to-right instead of pow()'s right-to-left,
273 #thus about 30% faster than the python built-in with small bases
274 def powMod(base, power, modulus):
277 """ Return base**power mod modulus, using multi bit scanning
278 with nBitScan bits at a time."""
280 #TREV - Added support for negative exponents
281 negativeResult = False
284 negativeResult = True
289 # Break power into a list of digits of nBitScan bits.
290 # The list is recursive so easy to read in reverse direction.
293 nibbles = int(power & mask), nibbles
294 power = power >> nBitScan
296 # Make a table of powers of base up to 2**nBitScan - 1
298 for i in xrange(1, exp2):
299 lowPowers.append((lowPowers[i-1] * base) % modulus)
301 # To exponentiate by the first nibble, look it up in the table
302 nib, nibbles = nibbles
303 prod = lowPowers[nib]
305 # For the rest, square nBitScan times, then multiply by
308 nib, nibbles = nibbles
309 for i in xrange(nBitScan):
310 prod = (prod * prod) % modulus
311 if nib: prod = (prod * lowPowers[nib]) % modulus
313 #TREV - Added support for negative exponents
315 prodInv = invMod(prod, modulus)
316 #Check to make sure the inverse is correct
317 if (prod * prodInv) % modulus != 1:
318 raise AssertionError()
323 #Pre-calculate a sieve of the ~100 primes < 1000:
326 for count in range(2, int(math.sqrt(n))):
327 if sieve[count] == 0:
330 while x < len(sieve):
333 sieve = [x for x in sieve[2:] if x]
336 sieve = makeSieve(1000)
338 def isPrime(n, iterations=5, display=False):
339 #Trial division with sieve
341 if x >= n: return True
342 if n % x == 0: return False
343 #Passed trial division, proceed to Rabin-Miller
344 #Rabin-Miller implemented per Ferguson & Schneier
345 #Compute s, t for Rabin-Miller
346 if display: print "*",
350 #Repeat Rabin-Miller x times
351 a = 2 #Use 2 as a base for first iteration speedup, per HAC
352 for count in range(iterations):
361 v, i = powMod(v, 2, n), i+1
362 a = getRandomNumber(2, n)
365 def getRandomPrime(bits, display=False):
367 raise AssertionError()
368 #The 1.5 ensures the 2 MSBs are set
369 #Thus, when used for p,q in RSA, n will have its MSB set
371 #Since 30 is lcm(2,3,5), we'll set our test numbers to
372 #29 % 30 and keep them there
373 low = (2L ** (bits-1)) * 3/2
374 high = 2L ** bits - 30
375 p = getRandomNumber(low, high)
378 if display: print ".",
381 p = getRandomNumber(low, high)
383 if isPrime(p, display=display):
386 #Unused at the moment...
387 def getRandomSafePrime(bits, display=False):
389 raise AssertionError()
390 #The 1.5 ensures the 2 MSBs are set
391 #Thus, when used for p,q in RSA, n will have its MSB set
393 #Since 30 is lcm(2,3,5), we'll set our test numbers to
394 #29 % 30 and keep them there
395 low = (2 ** (bits-2)) * 3/2
396 high = (2 ** (bits-1)) - 30
397 q = getRandomNumber(low, high)
400 if display: print ".",
403 q = getRandomNumber(low, high)
405 #Ideas from Tom Wu's SRP code
406 #Do trial division on p and q before Rabin-Miller
407 if isPrime(q, 0, display=display):
409 if isPrime(p, display=display):
410 if isPrime(q, display=display):