static int crypto_rsa_common(const BYTE* input, int length, UINT32 key_length, const BYTE* modulus,
const BYTE* exponent, int exponent_size, BYTE* output)
{
- BN_CTX* ctx;
+ BN_CTX* ctx = NULL;
int output_length = -1;
- BYTE* input_reverse;
- BYTE* modulus_reverse;
- BYTE* exponent_reverse;
- BIGNUM *mod, *exp, *x, *y;
- input_reverse = (BYTE*)malloc(2 * key_length + exponent_size);
+ BYTE* input_reverse = NULL;
+ BYTE* modulus_reverse = NULL;
+ BYTE* exponent_reverse = NULL;
+ BIGNUM* mod = NULL;
+ BIGNUM* exp = NULL;
+ BIGNUM* x = NULL;
+ BIGNUM* y = NULL;
+ size_t bufferSize = 2 * key_length + exponent_size;
+
+ if (!input || (length < 0) || (exponent_size < 0) || !modulus || !exponent || !output)
+ return -1;
+
+ if (length > bufferSize)
+ bufferSize = length;
+
+ input_reverse = (BYTE*)calloc(bufferSize, 1);
if (!input_reverse)
return -1;
if (!(y = BN_new()))
goto fail_bn_y;
- BN_bin2bn(modulus_reverse, key_length, mod);
- BN_bin2bn(exponent_reverse, exponent_size, exp);
- BN_bin2bn(input_reverse, length, x);
- BN_mod_exp(y, x, exp, mod, ctx);
+ if (!BN_bin2bn(modulus_reverse, key_length, mod))
+ goto fail;
+
+ if (!BN_bin2bn(exponent_reverse, exponent_size, exp))
+ goto fail;
+ if (!BN_bin2bn(input_reverse, length, x))
+ goto fail;
+ if (BN_mod_exp(y, x, exp, mod, ctx) != 1)
+ goto fail;
output_length = BN_bn2bin(y, output);
+ if (output_length < 0)
+ goto fail;
crypto_reverse(output, output_length);
- if (output_length < (int)key_length)
+ if (output_length < key_length)
memset(output + output_length, 0, key_length - output_length);
+fail:
BN_free(y);
fail_bn_y:
BN_clear_free(x);