Upstream version 7.36.149.0
[platform/framework/web/crosswalk.git] / src / third_party / tlslite / tlslite / utils / cryptomath.py
index 86da25e..30354b2 100644 (file)
@@ -1,25 +1,19 @@
+# Authors: 
+#   Trevor Perrin
+#   Martin von Loewis - python 3 port
+#
+# See the LICENSE file for legal information regarding use of this file.
+
 """cryptomath module
 
 This module has basic math/crypto code."""
-
+from __future__ import print_function
 import os
 import math
 import base64
 import binascii
 
-# The sha module is deprecated in Python 2.6 
-try:
-    import sha
-except ImportError:
-    from hashlib import sha1 as sha
-
-# The md5 module is deprecated in Python 2.6
-try:
-    import md5
-except ImportError:
-    from hashlib import md5
-
-from compat import *
+from .compat import *
 
 
 # **************************************************************************
@@ -34,23 +28,6 @@ try:
 except ImportError:
     m2cryptoLoaded = False
 
-
-# Try to load cryptlib
-try:
-    import cryptlib_py
-    try:
-        cryptlib_py.cryptInit()
-    except cryptlib_py.CryptException, e:
-        #If tlslite and cryptoIDlib are both present,
-        #they might each try to re-initialize this,
-        #so we're tolerant of that.
-        if e[0] != cryptlib_py.CRYPT_ERROR_INITED:
-            raise
-    cryptlibpyLoaded = True
-
-except ImportError:
-    cryptlibpyLoaded = False
-
 #Try to load GMPY
 try:
     import gmpy
@@ -70,151 +47,116 @@ except ImportError:
 # PRNG Functions
 # **************************************************************************
 
-# Get os.urandom PRNG
-try:
-    os.urandom(1)
-    def getRandomBytes(howMany):
-        return stringToBytes(os.urandom(howMany))
-    prngName = "os.urandom"
-
-except:
-    # Else get cryptlib PRNG
-    if cryptlibpyLoaded:
-        def getRandomBytes(howMany):
-            randomKey = cryptlib_py.cryptCreateContext(cryptlib_py.CRYPT_UNUSED,
-                                                       cryptlib_py.CRYPT_ALGO_AES)
-            cryptlib_py.cryptSetAttribute(randomKey,
-                                          cryptlib_py.CRYPT_CTXINFO_MODE,
-                                          cryptlib_py.CRYPT_MODE_OFB)
-            cryptlib_py.cryptGenerateKey(randomKey)
-            bytes = createByteArrayZeros(howMany)
-            cryptlib_py.cryptEncrypt(randomKey, bytes)
-            return bytes
-        prngName = "cryptlib"
-
-    else:
-        #Else get UNIX /dev/urandom PRNG
-        try:
-            devRandomFile = open("/dev/urandom", "rb")
-            def getRandomBytes(howMany):
-                return stringToBytes(devRandomFile.read(howMany))
-            prngName = "/dev/urandom"
-        except IOError:
-            #Else get Win32 CryptoAPI PRNG
-            try:
-                import win32prng
-                def getRandomBytes(howMany):
-                    s = win32prng.getRandomBytes(howMany)
-                    if len(s) != howMany:
-                        raise AssertionError()
-                    return stringToBytes(s)
-                prngName ="CryptoAPI"
-            except ImportError:
-                #Else no PRNG :-(
-                def getRandomBytes(howMany):
-                    raise NotImplementedError("No Random Number Generator "\
-                                              "available.")
-            prngName = "None"
+# Check that os.urandom works
+import zlib
+length = len(zlib.compress(os.urandom(1000)))
+assert(length > 900)
+
+def getRandomBytes(howMany):
+    b = bytearray(os.urandom(howMany))
+    assert(len(b) == howMany)
+    return b
+
+prngName = "os.urandom"
 
 # **************************************************************************
-# Converter Functions
+# Simple hash functions
 # **************************************************************************
 
-def bytesToNumber(bytes):
-    total = 0L
-    multiplier = 1L
-    for count in range(len(bytes)-1, -1, -1):
-        byte = bytes[count]
-        total += multiplier * byte
-        multiplier *= 256
-    return total
+import hmac
+import hashlib
 
-def numberToBytes(n, howManyBytes=None):
-    if howManyBytes == None:
-      howManyBytes = numBytes(n)
-    bytes = createByteArrayZeros(howManyBytes)
-    for count in range(howManyBytes-1, -1, -1):
-        bytes[count] = int(n % 256)
-        n >>= 8
-    return bytes
+def MD5(b):
+    return bytearray(hashlib.md5(compat26Str(b)).digest())
 
-def bytesToBase64(bytes):
-    s = bytesToString(bytes)
-    return stringToBase64(s)
+def SHA1(b):
+    return bytearray(hashlib.sha1(compat26Str(b)).digest())
 
-def base64ToBytes(s):
-    s = base64ToString(s)
-    return stringToBytes(s)
+def HMAC_MD5(k, b):
+    k = compatHMAC(k)
+    b = compatHMAC(b)
+    return bytearray(hmac.new(k, b, hashlib.md5).digest())
 
-def numberToBase64(n):
-    bytes = numberToBytes(n)
-    return bytesToBase64(bytes)
+def HMAC_SHA1(k, b):
+    k = compatHMAC(k)
+    b = compatHMAC(b)
+    return bytearray(hmac.new(k, b, hashlib.sha1).digest())
 
-def base64ToNumber(s):
-    bytes = base64ToBytes(s)
-    return bytesToNumber(bytes)
 
-def stringToNumber(s):
-    bytes = stringToBytes(s)
-    return bytesToNumber(bytes)
+# **************************************************************************
+# Converter Functions
+# **************************************************************************
 
-def numberToString(s):
-    bytes = numberToBytes(s)
-    return bytesToString(bytes)
+def bytesToNumber(b):
+    total = 0
+    multiplier = 1
+    for count in range(len(b)-1, -1, -1):
+        byte = b[count]
+        total += multiplier * byte
+        multiplier *= 256
+    # Force-cast to long to appease PyCrypto.
+    # https://github.com/trevp/tlslite/issues/15
+    return long(total)
 
-def base64ToString(s):
-    try:
-        return base64.decodestring(s)
-    except binascii.Error, e:
-        raise SyntaxError(e)
-    except binascii.Incomplete, e:
-        raise SyntaxError(e)
+def numberToByteArray(n, howManyBytes=None):
+    """Convert an integer into a bytearray, zero-pad to howManyBytes.
 
-def stringToBase64(s):
-    return base64.encodestring(s).replace("\n", "")
+    The returned bytearray may be smaller than howManyBytes, but will
+    not be larger.  The returned bytearray will contain a big-endian
+    encoding of the input integer (n).
+    """    
+    if howManyBytes == None:
+        howManyBytes = numBytes(n)
+    b = bytearray(howManyBytes)
+    for count in range(howManyBytes-1, -1, -1):
+        b[count] = int(n % 256)
+        n >>= 8
+    return b
 
 def mpiToNumber(mpi): #mpi is an openssl-format bignum string
     if (ord(mpi[4]) & 0x80) !=0: #Make sure this is a positive number
         raise AssertionError()
-    bytes = stringToBytes(mpi[4:])
-    return bytesToNumber(bytes)
+    b = bytearray(mpi[4:])
+    return bytesToNumber(b)
 
 def numberToMPI(n):
-    bytes = numberToBytes(n)
+    b = numberToByteArray(n)
     ext = 0
     #If the high-order bit is going to be set,
     #add an extra byte of zeros
     if (numBits(n) & 0x7)==0:
         ext = 1
     length = numBytes(n) + ext
-    bytes = concatArrays(createByteArrayZeros(4+ext), bytes)
-    bytes[0] = (length >> 24) & 0xFF
-    bytes[1] = (length >> 16) & 0xFF
-    bytes[2] = (length >> 8) & 0xFF
-    bytes[3] = length & 0xFF
-    return bytesToString(bytes)
-
+    b = bytearray(4+ext) + b
+    b[0] = (length >> 24) & 0xFF
+    b[1] = (length >> 16) & 0xFF
+    b[2] = (length >> 8) & 0xFF
+    b[3] = length & 0xFF
+    return bytes(b)
 
 
 # **************************************************************************
 # Misc. Utility Functions
 # **************************************************************************
 
+def numBits(n):
+    if n==0:
+        return 0
+    s = "%x" % n
+    return ((len(s)-1)*4) + \
+    {'0':0, '1':1, '2':2, '3':2,
+     '4':3, '5':3, '6':3, '7':3,
+     '8':4, '9':4, 'a':4, 'b':4,
+     'c':4, 'd':4, 'e':4, 'f':4,
+     }[s[0]]
+    return int(math.floor(math.log(n, 2))+1)
+
 def numBytes(n):
     if n==0:
         return 0
     bits = numBits(n)
     return int(math.ceil(bits / 8.0))
 
-def hashAndBase64(s):
-    return stringToBase64(sha.sha(s).digest())
-
-def getBase64Nonce(numChars=22): #defaults to an 132 bit nonce
-    bytes = getRandomBytes(numChars)
-    bytesStr = "".join([chr(b) for b in bytes])
-    return stringToBase64(bytesStr)[:numChars]
-
-
 # **************************************************************************
 # Big Number Math
 # **************************************************************************
@@ -240,9 +182,7 @@ def gcd(a,b):
     return a
 
 def lcm(a, b):
-    #This will break when python division changes, but we can't use // cause
-    #of Jython
-    return (a * b) / gcd(a, b)
+    return (a * b) // gcd(a, b)
 
 #Returns inverse of a mod b, zero if none
 #Uses Extended Euclidean Algorithm
@@ -250,9 +190,7 @@ def invMod(a, b):
     c, d = a, b
     uc, ud = 1, 0
     while c != 0:
-        #This will break when python division changes, but we can't use //
-        #cause of Jython
-        q = d / c
+        q = d // c
         c, d = d-(q*c), c
         uc, ud = ud - (q * uc), uc
     if d == 1:
@@ -269,61 +207,17 @@ if gmpyLoaded:
         return long(result)
 
 else:
-    #Copied from Bryan G. Olson's post to comp.lang.python
-    #Does left-to-right instead of pow()'s right-to-left,
-    #thus about 30% faster than the python built-in with small bases
     def powMod(base, power, modulus):
-        nBitScan = 5
-
-        """ Return base**power mod modulus, using multi bit scanning
-        with nBitScan bits at a time."""
-
-        #TREV - Added support for negative exponents
-        negativeResult = False
-        if (power < 0):
-            power *= -1
-            negativeResult = True
-
-        exp2 = 2**nBitScan
-        mask = exp2 - 1
-
-        # Break power into a list of digits of nBitScan bits.
-        # The list is recursive so easy to read in reverse direction.
-        nibbles = None
-        while power:
-            nibbles = int(power & mask), nibbles
-            power = power >> nBitScan
-
-        # Make a table of powers of base up to 2**nBitScan - 1
-        lowPowers = [1]
-        for i in xrange(1, exp2):
-            lowPowers.append((lowPowers[i-1] * base) % modulus)
-
-        # To exponentiate by the first nibble, look it up in the table
-        nib, nibbles = nibbles
-        prod = lowPowers[nib]
-
-        # For the rest, square nBitScan times, then multiply by
-        # base^nibble
-        while nibbles:
-            nib, nibbles = nibbles
-            for i in xrange(nBitScan):
-                prod = (prod * prod) % modulus
-            if nib: prod = (prod * lowPowers[nib]) % modulus
-
-        #TREV - Added support for negative exponents
-        if negativeResult:
-            prodInv = invMod(prod, modulus)
-            #Check to make sure the inverse is correct
-            if (prod * prodInv) % modulus != 1:
-                raise AssertionError()
-            return prodInv
-        return prod
-
+        if power < 0:
+            result = pow(base, power*-1, modulus)
+            result = invMod(result, modulus)
+            return result
+        else:
+            return pow(base, power, modulus)
 
 #Pre-calculate a sieve of the ~100 primes < 1000:
 def makeSieve(n):
-    sieve = range(n)
+    sieve = list(range(n))
     for count in range(2, int(math.sqrt(n))):
         if sieve[count] == 0:
             continue
@@ -344,10 +238,10 @@ def isPrime(n, iterations=5, display=False):
     #Passed trial division, proceed to Rabin-Miller
     #Rabin-Miller implemented per Ferguson & Schneier
     #Compute s, t for Rabin-Miller
-    if display: print "*",
+    if display: print("*", end=' ')
     s, t = n-1, 0
     while s % 2 == 0:
-        s, t = s/2, t+1
+        s, t = s//2, t+1
     #Repeat Rabin-Miller x times
     a = 2 #Use 2 as a base for first iteration speedup, per HAC
     for count in range(iterations):
@@ -371,12 +265,12 @@ def getRandomPrime(bits, display=False):
     #
     #Since 30 is lcm(2,3,5), we'll set our test numbers to
     #29 % 30 and keep them there
-    low = (2L ** (bits-1)) * 3/2
-    high = 2L ** bits - 30
+    low = ((2 ** (bits-1)) * 3) // 2
+    high = 2 ** bits - 30
     p = getRandomNumber(low, high)
     p += 29 - (p % 30)
     while 1:
-        if display: print ".",
+        if display: print(".", end=' ')
         p += 30
         if p >= high:
             p = getRandomNumber(low, high)
@@ -393,12 +287,12 @@ def getRandomSafePrime(bits, display=False):
     #
     #Since 30 is lcm(2,3,5), we'll set our test numbers to
     #29 % 30 and keep them there
-    low = (2 ** (bits-2)) * 3/2
+    low = (2 ** (bits-2)) * 3//2
     high = (2 ** (bits-1)) - 30
     q = getRandomNumber(low, high)
     q += 29 - (q % 30)
     while 1:
-        if display: print ".",
+        if display: print(".", end=' ')
         q += 30
         if (q >= high):
             q = getRandomNumber(low, high)