+# Authors:
+# Trevor Perrin
+# Martin von Loewis - python 3 port
+#
+# See the LICENSE file for legal information regarding use of this file.
+
"""cryptomath module
This module has basic math/crypto code."""
-
+from __future__ import print_function
import os
import math
import base64
import binascii
-# The sha module is deprecated in Python 2.6
-try:
- import sha
-except ImportError:
- from hashlib import sha1 as sha
-
-# The md5 module is deprecated in Python 2.6
-try:
- import md5
-except ImportError:
- from hashlib import md5
-
-from compat import *
+from .compat import *
# **************************************************************************
except ImportError:
m2cryptoLoaded = False
-
-# Try to load cryptlib
-try:
- import cryptlib_py
- try:
- cryptlib_py.cryptInit()
- except cryptlib_py.CryptException, e:
- #If tlslite and cryptoIDlib are both present,
- #they might each try to re-initialize this,
- #so we're tolerant of that.
- if e[0] != cryptlib_py.CRYPT_ERROR_INITED:
- raise
- cryptlibpyLoaded = True
-
-except ImportError:
- cryptlibpyLoaded = False
-
#Try to load GMPY
try:
import gmpy
# PRNG Functions
# **************************************************************************
-# Get os.urandom PRNG
-try:
- os.urandom(1)
- def getRandomBytes(howMany):
- return stringToBytes(os.urandom(howMany))
- prngName = "os.urandom"
-
-except:
- # Else get cryptlib PRNG
- if cryptlibpyLoaded:
- def getRandomBytes(howMany):
- randomKey = cryptlib_py.cryptCreateContext(cryptlib_py.CRYPT_UNUSED,
- cryptlib_py.CRYPT_ALGO_AES)
- cryptlib_py.cryptSetAttribute(randomKey,
- cryptlib_py.CRYPT_CTXINFO_MODE,
- cryptlib_py.CRYPT_MODE_OFB)
- cryptlib_py.cryptGenerateKey(randomKey)
- bytes = createByteArrayZeros(howMany)
- cryptlib_py.cryptEncrypt(randomKey, bytes)
- return bytes
- prngName = "cryptlib"
-
- else:
- #Else get UNIX /dev/urandom PRNG
- try:
- devRandomFile = open("/dev/urandom", "rb")
- def getRandomBytes(howMany):
- return stringToBytes(devRandomFile.read(howMany))
- prngName = "/dev/urandom"
- except IOError:
- #Else get Win32 CryptoAPI PRNG
- try:
- import win32prng
- def getRandomBytes(howMany):
- s = win32prng.getRandomBytes(howMany)
- if len(s) != howMany:
- raise AssertionError()
- return stringToBytes(s)
- prngName ="CryptoAPI"
- except ImportError:
- #Else no PRNG :-(
- def getRandomBytes(howMany):
- raise NotImplementedError("No Random Number Generator "\
- "available.")
- prngName = "None"
+# Check that os.urandom works
+import zlib
+length = len(zlib.compress(os.urandom(1000)))
+assert(length > 900)
+
+def getRandomBytes(howMany):
+ b = bytearray(os.urandom(howMany))
+ assert(len(b) == howMany)
+ return b
+
+prngName = "os.urandom"
# **************************************************************************
-# Converter Functions
+# Simple hash functions
# **************************************************************************
-def bytesToNumber(bytes):
- total = 0L
- multiplier = 1L
- for count in range(len(bytes)-1, -1, -1):
- byte = bytes[count]
- total += multiplier * byte
- multiplier *= 256
- return total
+import hmac
+import hashlib
-def numberToBytes(n, howManyBytes=None):
- if howManyBytes == None:
- howManyBytes = numBytes(n)
- bytes = createByteArrayZeros(howManyBytes)
- for count in range(howManyBytes-1, -1, -1):
- bytes[count] = int(n % 256)
- n >>= 8
- return bytes
+def MD5(b):
+ return bytearray(hashlib.md5(compat26Str(b)).digest())
-def bytesToBase64(bytes):
- s = bytesToString(bytes)
- return stringToBase64(s)
+def SHA1(b):
+ return bytearray(hashlib.sha1(compat26Str(b)).digest())
-def base64ToBytes(s):
- s = base64ToString(s)
- return stringToBytes(s)
+def HMAC_MD5(k, b):
+ k = compatHMAC(k)
+ b = compatHMAC(b)
+ return bytearray(hmac.new(k, b, hashlib.md5).digest())
-def numberToBase64(n):
- bytes = numberToBytes(n)
- return bytesToBase64(bytes)
+def HMAC_SHA1(k, b):
+ k = compatHMAC(k)
+ b = compatHMAC(b)
+ return bytearray(hmac.new(k, b, hashlib.sha1).digest())
-def base64ToNumber(s):
- bytes = base64ToBytes(s)
- return bytesToNumber(bytes)
-def stringToNumber(s):
- bytes = stringToBytes(s)
- return bytesToNumber(bytes)
+# **************************************************************************
+# Converter Functions
+# **************************************************************************
-def numberToString(s):
- bytes = numberToBytes(s)
- return bytesToString(bytes)
+def bytesToNumber(b):
+ total = 0
+ multiplier = 1
+ for count in range(len(b)-1, -1, -1):
+ byte = b[count]
+ total += multiplier * byte
+ multiplier *= 256
+ # Force-cast to long to appease PyCrypto.
+ # https://github.com/trevp/tlslite/issues/15
+ return long(total)
-def base64ToString(s):
- try:
- return base64.decodestring(s)
- except binascii.Error, e:
- raise SyntaxError(e)
- except binascii.Incomplete, e:
- raise SyntaxError(e)
+def numberToByteArray(n, howManyBytes=None):
+ """Convert an integer into a bytearray, zero-pad to howManyBytes.
-def stringToBase64(s):
- return base64.encodestring(s).replace("\n", "")
+ The returned bytearray may be smaller than howManyBytes, but will
+ not be larger. The returned bytearray will contain a big-endian
+ encoding of the input integer (n).
+ """
+ if howManyBytes == None:
+ howManyBytes = numBytes(n)
+ b = bytearray(howManyBytes)
+ for count in range(howManyBytes-1, -1, -1):
+ b[count] = int(n % 256)
+ n >>= 8
+ return b
def mpiToNumber(mpi): #mpi is an openssl-format bignum string
if (ord(mpi[4]) & 0x80) !=0: #Make sure this is a positive number
raise AssertionError()
- bytes = stringToBytes(mpi[4:])
- return bytesToNumber(bytes)
+ b = bytearray(mpi[4:])
+ return bytesToNumber(b)
def numberToMPI(n):
- bytes = numberToBytes(n)
+ b = numberToByteArray(n)
ext = 0
#If the high-order bit is going to be set,
#add an extra byte of zeros
if (numBits(n) & 0x7)==0:
ext = 1
length = numBytes(n) + ext
- bytes = concatArrays(createByteArrayZeros(4+ext), bytes)
- bytes[0] = (length >> 24) & 0xFF
- bytes[1] = (length >> 16) & 0xFF
- bytes[2] = (length >> 8) & 0xFF
- bytes[3] = length & 0xFF
- return bytesToString(bytes)
-
+ b = bytearray(4+ext) + b
+ b[0] = (length >> 24) & 0xFF
+ b[1] = (length >> 16) & 0xFF
+ b[2] = (length >> 8) & 0xFF
+ b[3] = length & 0xFF
+ return bytes(b)
# **************************************************************************
# Misc. Utility Functions
# **************************************************************************
+def numBits(n):
+ if n==0:
+ return 0
+ s = "%x" % n
+ return ((len(s)-1)*4) + \
+ {'0':0, '1':1, '2':2, '3':2,
+ '4':3, '5':3, '6':3, '7':3,
+ '8':4, '9':4, 'a':4, 'b':4,
+ 'c':4, 'd':4, 'e':4, 'f':4,
+ }[s[0]]
+ return int(math.floor(math.log(n, 2))+1)
+
def numBytes(n):
if n==0:
return 0
bits = numBits(n)
return int(math.ceil(bits / 8.0))
-def hashAndBase64(s):
- return stringToBase64(sha.sha(s).digest())
-
-def getBase64Nonce(numChars=22): #defaults to an 132 bit nonce
- bytes = getRandomBytes(numChars)
- bytesStr = "".join([chr(b) for b in bytes])
- return stringToBase64(bytesStr)[:numChars]
-
-
# **************************************************************************
# Big Number Math
# **************************************************************************
return a
def lcm(a, b):
- #This will break when python division changes, but we can't use // cause
- #of Jython
- return (a * b) / gcd(a, b)
+ return (a * b) // gcd(a, b)
#Returns inverse of a mod b, zero if none
#Uses Extended Euclidean Algorithm
c, d = a, b
uc, ud = 1, 0
while c != 0:
- #This will break when python division changes, but we can't use //
- #cause of Jython
- q = d / c
+ q = d // c
c, d = d-(q*c), c
uc, ud = ud - (q * uc), uc
if d == 1:
return long(result)
else:
- #Copied from Bryan G. Olson's post to comp.lang.python
- #Does left-to-right instead of pow()'s right-to-left,
- #thus about 30% faster than the python built-in with small bases
def powMod(base, power, modulus):
- nBitScan = 5
-
- """ Return base**power mod modulus, using multi bit scanning
- with nBitScan bits at a time."""
-
- #TREV - Added support for negative exponents
- negativeResult = False
- if (power < 0):
- power *= -1
- negativeResult = True
-
- exp2 = 2**nBitScan
- mask = exp2 - 1
-
- # Break power into a list of digits of nBitScan bits.
- # The list is recursive so easy to read in reverse direction.
- nibbles = None
- while power:
- nibbles = int(power & mask), nibbles
- power = power >> nBitScan
-
- # Make a table of powers of base up to 2**nBitScan - 1
- lowPowers = [1]
- for i in xrange(1, exp2):
- lowPowers.append((lowPowers[i-1] * base) % modulus)
-
- # To exponentiate by the first nibble, look it up in the table
- nib, nibbles = nibbles
- prod = lowPowers[nib]
-
- # For the rest, square nBitScan times, then multiply by
- # base^nibble
- while nibbles:
- nib, nibbles = nibbles
- for i in xrange(nBitScan):
- prod = (prod * prod) % modulus
- if nib: prod = (prod * lowPowers[nib]) % modulus
-
- #TREV - Added support for negative exponents
- if negativeResult:
- prodInv = invMod(prod, modulus)
- #Check to make sure the inverse is correct
- if (prod * prodInv) % modulus != 1:
- raise AssertionError()
- return prodInv
- return prod
-
+ if power < 0:
+ result = pow(base, power*-1, modulus)
+ result = invMod(result, modulus)
+ return result
+ else:
+ return pow(base, power, modulus)
#Pre-calculate a sieve of the ~100 primes < 1000:
def makeSieve(n):
- sieve = range(n)
+ sieve = list(range(n))
for count in range(2, int(math.sqrt(n))):
if sieve[count] == 0:
continue
#Passed trial division, proceed to Rabin-Miller
#Rabin-Miller implemented per Ferguson & Schneier
#Compute s, t for Rabin-Miller
- if display: print "*",
+ if display: print("*", end=' ')
s, t = n-1, 0
while s % 2 == 0:
- s, t = s/2, t+1
+ s, t = s//2, t+1
#Repeat Rabin-Miller x times
a = 2 #Use 2 as a base for first iteration speedup, per HAC
for count in range(iterations):
#
#Since 30 is lcm(2,3,5), we'll set our test numbers to
#29 % 30 and keep them there
- low = (2L ** (bits-1)) * 3/2
- high = 2L ** bits - 30
+ low = ((2 ** (bits-1)) * 3) // 2
+ high = 2 ** bits - 30
p = getRandomNumber(low, high)
p += 29 - (p % 30)
while 1:
- if display: print ".",
+ if display: print(".", end=' ')
p += 30
if p >= high:
p = getRandomNumber(low, high)
#
#Since 30 is lcm(2,3,5), we'll set our test numbers to
#29 % 30 and keep them there
- low = (2 ** (bits-2)) * 3/2
+ low = (2 ** (bits-2)) * 3//2
high = (2 ** (bits-1)) - 30
q = getRandomNumber(low, high)
q += 29 - (q % 30)
while 1:
- if display: print ".",
+ if display: print(".", end=' ')
q += 30
if (q >= high):
q = getRandomNumber(low, high)