#include <openssl/err.h>
#include <openssl/bn.h>
#include "crypto/bn.h"
#include "rsa_local.h"
/*
 * Validate the optional CRT parameters (dP = dmp1, dQ = dmq1, qInv = iqmp)
 * of an RSA private key.
 *
 * The three components must be either all present or all absent:
 *   - all absent  -> nothing to check, returns 1;
 *   - partial set -> returns 0.
 * When all are present they must satisfy:
 *   1 < dP   < p - 1   and  (dP * e)   mod (p - 1) == 1
 *   1 < dQ   < q - 1   and  (dQ * e)   mod (q - 1) == 1
 *   1 < qInv < p       and  (qInv * q) mod p       == 1
 *
 * Returns 1 on success, 0 on any failed condition or BN allocation/arithmetic
 * error. The scratch values are cleared before the BN_CTX frame is released.
 */
int ossl_rsa_check_crt_components(const RSA *rsa, BN_CTX *ctx)
{
    int ok = 0;
    int present = (rsa->dmp1 != NULL) + (rsa->dmq1 != NULL)
                  + (rsa->iqmp != NULL);
    BIGNUM *tmp, *pm1, *qm1;

    if (present == 0)
        return 1;               /* no CRT components at all: accept */
    if (present != 3)
        return 0;               /* a partial CRT set is rejected */

    BN_CTX_start(ctx);
    tmp = BN_CTX_get(ctx);
    pm1 = BN_CTX_get(ctx);
    qm1 = BN_CTX_get(ctx);
    if (qm1 == NULL)
        goto end;
    BN_set_flags(tmp, BN_FLG_CONSTTIME);
    BN_set_flags(pm1, BN_FLG_CONSTTIME);
    BN_set_flags(qm1, BN_FLG_CONSTTIME);

    /* pm1 = p - 1 */
    if (BN_copy(pm1, rsa->p) == NULL || !BN_sub_word(pm1, 1))
        goto end;
    /* qm1 = q - 1 */
    if (BN_copy(qm1, rsa->q) == NULL || !BN_sub_word(qm1, 1))
        goto end;

    /* 1 < dP < p - 1 */
    if (BN_cmp(rsa->dmp1, BN_value_one()) <= 0 || BN_cmp(rsa->dmp1, pm1) >= 0)
        goto end;
    /* 1 < dQ < q - 1 */
    if (BN_cmp(rsa->dmq1, BN_value_one()) <= 0 || BN_cmp(rsa->dmq1, qm1) >= 0)
        goto end;
    /* 1 < qInv < p */
    if (BN_cmp(rsa->iqmp, BN_value_one()) <= 0 || BN_cmp(rsa->iqmp, rsa->p) >= 0)
        goto end;

    /* (dP * e) mod (p - 1) == 1 */
    if (!BN_mod_mul(tmp, rsa->dmp1, rsa->e, pm1, ctx) || !BN_is_one(tmp))
        goto end;
    /* (dQ * e) mod (q - 1) == 1 */
    if (!BN_mod_mul(tmp, rsa->dmq1, rsa->e, qm1, ctx) || !BN_is_one(tmp))
        goto end;
    /* (qInv * q) mod p == 1 */
    if (!BN_mod_mul(tmp, rsa->iqmp, rsa->q, rsa->p, ctx) || !BN_is_one(tmp))
        goto end;

    ok = 1;
end:
    BN_clear(tmp);
    BN_clear(pm1);
    BN_clear(qm1);
    BN_CTX_end(ctx);
    return ok;
}
/*
 * Check that a prime factor p has the size expected for an nbits modulus.
 *
 * Upper bound: p must have exactly nbits/2 significant bits.
 * Lower bound: p must be strictly greater than ossl_bn_inv_sqrt_2 rescaled
 * (by a left or right shift) to nbits/2 bits. NOTE(review): the constant
 * presumably approximates 1/sqrt(2), making the bound
 * sqrt(2) * 2^(nbits/2 - 1) -- confirm against its definition.
 *
 * Returns 1 if p is in range, 0 if it is not or a BN operation failed.
 */
int ossl_rsa_check_prime_factor_range(const BIGNUM *p, int nbits, BN_CTX *ctx)
{
    int half_bits = nbits >> 1;
    int shift = half_bits - BN_num_bits(&ossl_bn_inv_sqrt_2);
    int in_range = 0;
    BIGNUM *bound;

    /* Upper bound: p must be exactly half the modulus size in bits. */
    if (BN_num_bits(p) != half_bits)
        return 0;

    BN_CTX_start(ctx);
    bound = BN_CTX_get(ctx);
    if (bound != NULL && BN_copy(bound, &ossl_bn_inv_sqrt_2) != NULL) {
        /* Scale the constant so that it has half_bits significant bits. */
        int scaled = (shift >= 0) ? BN_lshift(bound, bound, shift)
                                  : BN_rshift(bound, bound, -shift);

        /* Lower bound: p must be strictly above the scaled constant. */
        in_range = scaled && BN_cmp(p, bound) > 0;
    }
    BN_CTX_end(ctx);
    return in_range;
}
/*
 * Validate a single RSA prime factor p for an nbits modulus and exponent e.
 *
 * p must pass BN_check_prime(), must pass
 * ossl_rsa_check_prime_factor_range(p, nbits, ctx), and must satisfy
 * gcd(p - 1, e) == 1.
 *
 * Returns 1 if all checks pass, 0 otherwise (including BN allocation or
 * arithmetic failures). The p - 1 scratch value is cleared before the
 * BN_CTX frame is released.
 */
int ossl_rsa_check_prime_factor(BIGNUM *p, BIGNUM *e, int nbits, BN_CTX *ctx)
{
    int ok = 0;
    BIGNUM *pm1, *g;

    /* p must be prime ... */
    if (BN_check_prime(p, ctx, NULL) != 1)
        return 0;
    /* ... and of the right size for an nbits modulus. */
    if (ossl_rsa_check_prime_factor_range(p, nbits, ctx) != 1)
        return 0;

    BN_CTX_start(ctx);
    pm1 = BN_CTX_get(ctx);
    g = BN_CTX_get(ctx);
    if (g != NULL) {
        BN_set_flags(pm1, BN_FLG_CONSTTIME);
        BN_set_flags(g, BN_FLG_CONSTTIME);
        /* e must be coprime to p - 1 so that it is invertible mod p - 1. */
        ok = BN_copy(pm1, p) != NULL
             && BN_sub_word(pm1, 1)
             && BN_gcd(g, pm1, e, ctx)
             && BN_is_one(g);
    }
    BN_clear(pm1);
    BN_CTX_end(ctx);
    return ok;
}
/*
 * Validate the private exponent d of an RSA key with an nbits modulus.
 *
 * Conditions checked:
 *   - d has more than nbits/2 significant bits;
 *   - d < lcm(p - 1, q - 1);
 *   - (e * d) mod lcm(p - 1, q - 1) == 1.
 *
 * Returns 1 if all conditions hold, 0 otherwise (including BN allocation or
 * arithmetic failures).
 *
 * Every secret-dependent temporary is cleared before the BN_CTX frame is
 * released. This includes p1q1 = (p - 1)(q - 1): together with n it reveals
 * the factorization, so it must not be left behind in the reusable BN_CTX
 * pool.
 */
int ossl_rsa_check_private_exponent(const RSA *rsa, int nbits, BN_CTX *ctx)
{
    int ret;
    BIGNUM *r, *p1, *q1, *lcm, *p1q1, *gcd;

    /* d must be longer than half the modulus size in bits. */
    if (BN_num_bits(rsa->d) <= (nbits >> 1))
        return 0;

    BN_CTX_start(ctx);
    r = BN_CTX_get(ctx);
    p1 = BN_CTX_get(ctx);
    q1 = BN_CTX_get(ctx);
    lcm = BN_CTX_get(ctx);
    p1q1 = BN_CTX_get(ctx);
    gcd = BN_CTX_get(ctx);
    if (gcd != NULL) {
        BN_set_flags(r, BN_FLG_CONSTTIME);
        BN_set_flags(p1, BN_FLG_CONSTTIME);
        BN_set_flags(q1, BN_FLG_CONSTTIME);
        BN_set_flags(lcm, BN_FLG_CONSTTIME);
        BN_set_flags(p1q1, BN_FLG_CONSTTIME);
        BN_set_flags(gcd, BN_FLG_CONSTTIME);
        ret = 1;
    } else {
        ret = 0;
    }
    ret = (ret
           /* lcm = LCM(p - 1, q - 1) */
           && (ossl_rsa_get_lcm(ctx, rsa->p, rsa->q, lcm, gcd, p1, q1,
                                p1q1) == 1)
           /* d < LCM(p - 1, q - 1) */
           && (BN_cmp(rsa->d, lcm) < 0)
           /* (e * d) mod LCM(p - 1, q - 1) == 1 */
           && BN_mod_mul(r, rsa->e, rsa->d, lcm, ctx)
           && BN_is_one(r));

    BN_clear(r);
    BN_clear(p1);
    BN_clear(q1);
    BN_clear(lcm);
    BN_clear(p1q1);     /* (p - 1)(q - 1) is as sensitive as p and q */
    BN_clear(gcd);
    BN_CTX_end(ctx);
    return ret;
}
/*
 * Check that the RSA public exponent e is acceptable for this build.
 *
 * FIPS_MODULE builds: e must be odd with 17..256 significant bits,
 * i.e. 2^16 < e < 2^256.
 * Other builds: any odd e greater than 1 is accepted.
 *
 * Returns 1 if e is acceptable, 0 otherwise.
 */
int ossl_rsa_check_public_exponent(const BIGNUM *e)
{
#ifdef FIPS_MODULE
    int nbits = BN_num_bits(e);

    if (!BN_is_odd(e))
        return 0;
    return nbits >= 17 && nbits <= 256;
#else
    if (!BN_is_odd(e))
        return 0;
    return BN_cmp(e, BN_value_one()) > 0;
#endif
}
/*
 * Check that the primes p and q are far enough apart for an nbits modulus:
 * |p - q| > 2^(nbits/2 - 100).
 *
 * diff is caller-provided scratch space; on return it holds |p - q| - 1
 * (or |p - q| when that is zero or when BN_sub_word failed).
 *
 * Returns 1 if the distance is sufficient, 0 if it is not (including p == q),
 * and -1 if a BN operation failed.
 */
int ossl_rsa_check_pminusq_diff(BIGNUM *diff, const BIGNUM *p, const BIGNUM *q,
                                int nbits)
{
    int min_bits = (nbits >> 1) - 100;

    /* diff = |p - q| */
    if (!BN_sub(diff, p, q))
        return -1;
    BN_set_negative(diff, 0);

    /* Identical primes are always rejected. */
    if (BN_is_zero(diff))
        return 0;

    /*
     * |p - q| > 2^min_bits  <=>  |p - q| - 1 >= 2^min_bits
     *                       <=>  BN_num_bits(|p - q| - 1) > min_bits
     */
    return BN_sub_word(diff, 1) ? (BN_num_bits(diff) > min_bits) : -1;
}
/*
 * Compute lcm = LCM(p - 1, q - 1) = (p - 1)(q - 1) / gcd(p - 1, q - 1).
 *
 * All BIGNUM outputs are caller-provided. On success they hold:
 *   p1 = p - 1, q1 = q - 1, p1q1 = (p - 1)(q - 1),
 *   gcd = gcd(p - 1, q - 1), lcm = LCM(p - 1, q - 1).
 *
 * Returns 1 on success, 0 if any BN operation failed.
 */
int ossl_rsa_get_lcm(BN_CTX *ctx, const BIGNUM *p, const BIGNUM *q,
                     BIGNUM *lcm, BIGNUM *gcd, BIGNUM *p1, BIGNUM *q1,
                     BIGNUM *p1q1)
{
    if (!BN_sub(p1, p, BN_value_one()))
        return 0;
    if (!BN_sub(q1, q, BN_value_one()))
        return 0;
    if (!BN_mul(p1q1, p1, q1, ctx))
        return 0;
    if (!BN_gcd(gcd, p1, q1, ctx))
        return 0;
    return BN_div(lcm, NULL, p1q1, gcd, ctx) != 0;
}
/*
 * Validate the public part (n, e) of an RSA key.
 *
 * Returns 0 (reject) when:
 *   - n or e is missing (no error raised);
 *   - n has more than OPENSSL_RSA_MAX_MODULUS_BITS bits;
 *   - (FIPS_MODULE builds only) ossl_rsa_sp800_56b_validate_strength(nbits, -1)
 *     fails;
 *   - n is even;
 *   - e fails ossl_rsa_check_public_exponent();
 *   - the BN_CTX or gcd temporary cannot be allocated (no error raised);
 *   - n shares a factor with ossl_bn_get0_small_factors();
 *   - the Miller-Rabin test does not classify n as an acceptable composite
 *     (see the comment above that test).
 * Returns 1 if every check passes.
 */
int ossl_rsa_sp800_56b_check_public(const RSA *rsa)
{
int ret = 0, status;
int nbits;
BN_CTX *ctx = NULL;
BIGNUM *gcd = NULL;
/* Both the modulus and the public exponent are required. */
if (rsa->n == NULL || rsa->e == NULL)
return 0;
nbits = BN_num_bits(rsa->n);
/* Reject oversized moduli up front, before the gcd and primality work below. */
if (nbits > OPENSSL_RSA_MAX_MODULUS_BITS) {
ERR_raise(ERR_LIB_RSA, RSA_R_MODULUS_TOO_LARGE);
return 0;
}
#ifdef FIPS_MODULE
/* FIPS builds additionally require the modulus size to pass the strength check. */
if (!ossl_rsa_sp800_56b_validate_strength(nbits, -1)) {
ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_KEY_LENGTH);
return 0;
}
#endif
/* A modulus built from odd primes must itself be odd. */
if (!BN_is_odd(rsa->n)) {
ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_MODULUS);
return 0;
}
/* e must satisfy this build's public-exponent policy. */
if (!ossl_rsa_check_public_exponent(rsa->e)) {
ERR_raise(ERR_LIB_RSA, RSA_R_PUB_EXPONENT_OUT_OF_RANGE);
return 0;
}
ctx = BN_CTX_new_ex(rsa->libctx);
gcd = BN_new();
/* Allocation failure: fall through to cleanup with ret == 0, no error raised. */
if (ctx == NULL || gcd == NULL)
goto err;
/*
 * n must have no common factor with the value returned by
 * ossl_bn_get0_small_factors(). NOTE(review): that value is presumably a
 * product of small primes -- confirm against its definition.
 * A gcd failure is also reported as RSA_R_INVALID_MODULUS.
 */
if (!BN_gcd(gcd, rsa->n, ossl_bn_get0_small_factors(), ctx)
|| !BN_is_one(gcd)) {
ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_MODULUS);
goto err;
}
/*
 * Run a 5-round Miller-Rabin test on n. NOTE(review): ret != 1 presumably
 * means the test itself could not be completed; status carries the
 * classification of n.
 * FIPS builds accept n only when status is
 * BN_PRIMETEST_COMPOSITE_NOT_POWER_OF_PRIME. Non-FIPS builds also accept
 * BN_PRIMETEST_COMPOSITE_WITH_FACTOR, but only for moduli shorter than
 * RSA_MIN_MODULUS_BITS.
 */
ret = ossl_bn_miller_rabin_is_prime(rsa->n, 5, ctx, NULL, 1, &status);
#ifdef FIPS_MODULE
if (ret != 1 || status != BN_PRIMETEST_COMPOSITE_NOT_POWER_OF_PRIME) {
#else
if (ret != 1 || (status != BN_PRIMETEST_COMPOSITE_NOT_POWER_OF_PRIME && (nbits >= RSA_MIN_MODULUS_BITS || status != BN_PRIMETEST_COMPOSITE_WITH_FACTOR))) {
#endif
ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_MODULUS);
ret = 0;
goto err;
}
ret = 1;
/* Common exit: release the gcd temporary and the BN context. */
err:
BN_free(gcd);
BN_CTX_free(ctx);
return ret;
}
/*
 * Minimal range check of the RSA private exponent: 1 <= d < n.
 *
 * Returns 1 if d lies in range, 0 if it does not or if d or n is missing.
 */
int ossl_rsa_sp800_56b_check_private(const RSA *rsa)
{
    if (rsa->n == NULL || rsa->d == NULL)
        return 0;
    /* Lower bound: d >= 1. */
    if (BN_cmp(rsa->d, BN_value_one()) < 0)
        return 0;
    /* Upper bound: d < n. */
    return BN_cmp(rsa->d, rsa->n) < 0;
}
/*
 * Full consistency check of an RSA key pair with an nbits modulus.
 *
 * Checks, in order:
 *   - p, q, e, d and n are all present;
 *   - ossl_rsa_sp800_56b_validate_strength(nbits, strength) passes;
 *   - e equals efixed when efixed is non-NULL;
 *   - e passes ossl_rsa_check_public_exponent();
 *   - nbits equals the bit length of n and is positive and even;
 *   - n == p * q;
 *   - p and q each pass ossl_rsa_check_prime_factor(), |p - q| is large
 *     enough, d passes ossl_rsa_check_private_exponent(), and the CRT
 *     components pass ossl_rsa_check_crt_components().
 *
 * Returns 1 if the key pair is consistent and 0 otherwise. Most rejections
 * raise an RSA error; a strength failure and BN allocation/arithmetic
 * failures before the final checks return 0 without raising one here.
 */
int ossl_rsa_sp800_56b_check_keypair(const RSA *rsa, const BIGNUM *efixed,
                                     int strength, int nbits)
{
    int ok = 0;
    BN_CTX *bnctx;
    BIGNUM *tmp = NULL;

    /* Every key component used below must be present. */
    if (rsa->p == NULL || rsa->q == NULL || rsa->e == NULL
            || rsa->d == NULL || rsa->n == NULL) {
        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_REQUEST);
        return 0;
    }
    /* The requested modulus size / security strength must be acceptable. */
    if (!ossl_rsa_sp800_56b_validate_strength(nbits, strength))
        return 0;
    /* If the caller pins the public exponent, the key's e must match it. */
    if (efixed != NULL && BN_cmp(efixed, rsa->e) != 0) {
        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_REQUEST);
        return 0;
    }
    if (!ossl_rsa_check_public_exponent(rsa->e)) {
        ERR_raise(ERR_LIB_RSA, RSA_R_PUB_EXPONENT_OUT_OF_RANGE);
        return 0;
    }
    /* nbits must match the size of n and be a positive even number. */
    if (nbits != BN_num_bits(rsa->n) || nbits <= 0 || (nbits & 0x1) != 0) {
        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_KEYPAIR);
        return 0;
    }

    bnctx = BN_CTX_new_ex(rsa->libctx);
    if (bnctx == NULL)
        return 0;
    BN_CTX_start(bnctx);

    /* n must equal p * q. */
    tmp = BN_CTX_get(bnctx);
    if (tmp == NULL || !BN_mul(tmp, rsa->p, rsa->q, bnctx))
        goto done;
    if (BN_cmp(rsa->n, tmp) != 0) {
        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_REQUEST);
        goto done;
    }

    /*
     * Prime factors, |p - q| distance (tmp is reused as scratch space),
     * private exponent, then the optional CRT components.
     */
    ok = ossl_rsa_check_prime_factor(rsa->p, rsa->e, nbits, bnctx)
         && ossl_rsa_check_prime_factor(rsa->q, rsa->e, nbits, bnctx)
         && ossl_rsa_check_pminusq_diff(tmp, rsa->p, rsa->q, nbits) > 0
         && ossl_rsa_check_private_exponent(rsa, nbits, bnctx)
         && ossl_rsa_check_crt_components(rsa, bnctx);
    if (!ok)
        ERR_raise(ERR_LIB_RSA, RSA_R_INVALID_KEYPAIR);

done:
    /* tmp may hold p * q or |p - q| - 1: scrub it before releasing the frame. */
    BN_clear(tmp);
    BN_CTX_end(bnctx);
    BN_CTX_free(bnctx);
    return ok;
}