#include <libecc/nn/nn_mul_redc1.h>
#include <libecc/nn/nn_div_public.h>
#include <libecc/nn/nn_logical.h>
#include <libecc/nn/nn_mod_pow.h>
#include <libecc/nn/nn_rand.h>
#include <libecc/nn/nn.h>
/*
 * Left-to-right Montgomery ladder: out = (base ^ exp) mod mod.
 *
 * Storage: T[0] and T[1] hold the two ladder accumulators, and T[2] is scratch
 * for the squaring of each step.
 *
 * Randomization: a random mask of NN_MAX_BYTE_LEN bytes is drawn. Mask bit i
 * (rbit) decides which of T[0]/T[1] holds which accumulator at step i, so the
 * slot that is squared is not a direct function of the exponent bit.
 *
 * Arithmetic:
 *   - r != NULL: every product uses nn_mul_redc1 with (mod, mpinv). base is
 *     brought in by a redc1 multiplication with r_square, and the result is
 *     brought out by a redc1 multiplication with 1 (presumably the usual
 *     Montgomery-form conversion; r/r_square/mpinv are assumed to come from
 *     nn_compute_redc1_coefs on the same modulus, confirm with callers).
 *   - r == NULL: nn_mod_mul is used and r_square/mpinv are ignored.
 *
 * Short exponents: the ladder always runs on at least 2 bits. Afterwards the
 * correct value is selected with three conditional swaps into out:
 *   - 1 mod mod when exp has bit length 0;
 *   - base mod mod when it has bit length 1;
 *   - the ladder output otherwise.
 * Exactly one of the three conditions is true.
 *
 * out is (re)initialized to 0 first. The callers in this file route aliased
 * outputs through a temporary.
 *
 * Returns 0 on success, non-zero on error (propagated from the nn_* calls).
 */
ATTRIBUTE_WARN_UNUSED_RET static int _nn_exp_monty_ladder_ltr(nn_t out, nn_src_t base, nn_src_t exp, nn_src_t mod, nn_src_t r, nn_src_t r_square, word_t mpinv)
{
	nn T[3];
	nn mask;
	bitcnt_t explen, oldexplen;
	u8 expbit, rbit;
	int ret, cmp;
	/* Mark as uninitialized so the cleanup at err is valid on every path. */
	T[0].magic = T[1].magic = T[2].magic = mask.magic = WORD(0);

	ret = nn_init(out, 0); EG(ret, err);
	ret = nn_init(&T[0], 0); EG(ret, err);
	ret = nn_init(&T[1], 0); EG(ret, err);
	ret = nn_init(&T[2], 0); EG(ret, err);
	/* Random mask: one bit per ladder step selects the accumulator layout. */
	ret = nn_get_random_len(&mask, NN_MAX_BYTE_LEN); EG(ret, err);

	ret = nn_bitlen(exp, &explen); EG(ret, err);
	/* Keep the real length for the final selection; run on >= 2 bits. */
	oldexplen = explen;
	explen = (explen < 2) ? 2 : explen;

	ret = nn_getbit(&mask, (bitcnt_t)(explen - 1), &rbit); EG(ret, err);

	/* First accumulator: T[rbit] = base (reduced, and converted if r != NULL). */
	ret = nn_cmp(base, mod, &cmp); EG(ret, err);
	if(cmp >= 0){
		ret = nn_mod(&T[rbit], base, mod); EG(ret, err);
		if(r != NULL){
			ret = nn_mul_redc1(&T[rbit], &T[rbit], r_square, mod, mpinv); EG(ret, err);
		}
	}
	else{
		if(r != NULL){
			ret = nn_mul_redc1(&T[rbit], base, r_square, mod, mpinv); EG(ret, err);
		}
		else{
			ret = nn_copy(&T[rbit], base); EG(ret, err);
		}
	}

	/* Second accumulator: T[1 - rbit] = base^2 (the top bit is consumed). */
	if(r != NULL){
		ret = nn_mul_redc1(&T[1-rbit], &T[rbit], &T[rbit], mod, mpinv); EG(ret, err);
	}
	else{
		ret = nn_mod_mul(&T[1-rbit], &T[rbit], &T[rbit], mod); EG(ret, err);
	}

	explen = (bitcnt_t)(explen - 1);
	while (explen > 0) {
		u8 rbit_next;
		explen = (bitcnt_t)(explen - 1);
		ret = nn_getbit(&mask, explen, &rbit_next); EG(ret, err);
		ret = nn_getbit(exp, explen, &expbit); EG(ret, err);

		/* T[2] = square of the accumulator selected by expbit (through the mask). */
		if(r != NULL){
			ret = nn_mul_redc1(&T[2], &T[expbit ^ rbit], &T[expbit ^ rbit], mod, mpinv); EG(ret, err);
		}
		else{
			ret = nn_mod_mul(&T[2], &T[expbit ^ rbit], &T[expbit ^ rbit], mod); EG(ret, err);
		}
		/* T[1] = product of the two accumulators. */
		if(r != NULL){
			ret = nn_mul_redc1(&T[1], &T[0], &T[1], mod, mpinv); EG(ret, err);
		}
		else{
			ret = nn_mod_mul(&T[1], &T[0], &T[1], mod); EG(ret, err);
		}
		/*
		 * Place the new accumulators in T[0]/T[1] following the next mask
		 * bit. The copy order is significant: T[0] is written before T[1]
		 * is read, and T[0] is never a source.
		 */
		ret = nn_copy(&T[0], &T[2 - (expbit ^ rbit_next)]); EG(ret, err);
		ret = nn_copy(&T[1], &T[1 + (expbit ^ rbit_next)]); EG(ret, err);
		rbit = rbit_next;
	}

	/* Candidate for exp == 0: T[1 - rbit] = 1 (also used to leave redc form). */
	ret = nn_one(&T[1 - rbit]); EG(ret, err);
	if(r != NULL){
		/* Multiply the ladder result by 1 through redc1 to bring it out. */
		ret = nn_mul_redc1(&T[rbit], &T[rbit], &T[1 - rbit], mod, mpinv); EG(ret, err);
	}
	/* 1 mod mod: yields 0 when mod == 1. */
	ret = nn_mod(&T[1 - rbit], &T[1 - rbit], mod); EG(ret, err);
	/* Candidate for exp == 1: base mod mod. */
	ret = nn_mod(&T[2], base, mod); EG(ret, err);

	/*
	 * Select the final value into out (which is 0 here). Exactly one
	 * condition holds, and all three swaps are always performed. EG only
	 * branches on the (non-secret) error status.
	 */
	ret = nn_cnd_swap((oldexplen == 0), out, &T[1 - rbit]); EG(ret, err);
	ret = nn_cnd_swap((oldexplen == 1), out, &T[2]); EG(ret, err);
	ret = nn_cnd_swap(((oldexplen != 0) && (oldexplen != 1)), out, &T[rbit]); EG(ret, err);

err:
	nn_uninit(&T[0]);
	nn_uninit(&T[1]);
	nn_uninit(&T[2]);
	nn_uninit(&mask);
	return ret;
}
/*
 * out = base^exp mod mod using redc1 multiplications with (r, r_square,
 * mpinv). This is a direct delegation to the left-to-right ladder.
 *
 * In this file, callers whose out aliases an input go through
 * _nn_mod_pow_redc_aliased instead of calling this directly.
 *
 * Returns 0 on success, non-zero on error.
 */
ATTRIBUTE_WARN_UNUSED_RET static int _nn_mod_pow_redc(nn_t out, nn_src_t base, nn_src_t exp, nn_src_t mod, nn_src_t r, nn_src_t r_square, word_t mpinv)
{
	int ret;

	ret = _nn_exp_monty_ladder_ltr(out, base, exp, mod, r, r_square, mpinv);

	return ret;
}
/*
 * out = base^exp mod mod without Montgomery (redc1) arithmetic.
 *
 * The ladder is run with r == NULL, so it uses nn_mod_mul. nn_mod_pow uses
 * this path for even moduli.
 *
 * If out aliases base, exp or mod, the result is first computed into a local
 * temporary and then copied to out. The temporary is always passed to
 * nn_uninit before returning, as _nn_mod_pow_redc_aliased does. This avoids
 * leaving a copy of the result in the stack frame. Its magic is cleared up
 * front, so the cleanup is valid on every path, including the non-aliased
 * one where it is never initialized.
 *
 * Returns 0 on success, non-zero on error.
 */
ATTRIBUTE_WARN_UNUSED_RET static int _nn_mod_pow(nn_t out, nn_src_t base, nn_src_t exp, nn_src_t mod)
{
	nn _out;
	int ret;
	_out.magic = WORD(0);

	if ((out == base) || (out == exp) || (out == mod)) {
		/* Aliased output: compute into the temporary, then copy. */
		ret = nn_init(&_out, 0); EG(ret, err);
		ret = _nn_exp_monty_ladder_ltr(&_out, base, exp, mod, NULL, NULL, WORD(0)); EG(ret, err);
		ret = nn_copy(out, &_out);
	}
	else{
		ret = _nn_exp_monty_ladder_ltr(out, base, exp, mod, NULL, NULL, WORD(0));
	}

err:
	nn_uninit(&_out);
	return ret;
}
/*
 * Aliasing-safe variant of _nn_mod_pow_redc. The result is built in a local
 * temporary and copied into out only after the exponentiation succeeded, so
 * out may be the same object as any input.
 *
 * The temporary is always passed to nn_uninit before returning.
 *
 * Returns 0 on success, non-zero on error.
 */
ATTRIBUTE_WARN_UNUSED_RET static int _nn_mod_pow_redc_aliased(nn_t out, nn_src_t base, nn_src_t exp, nn_src_t mod, nn_src_t r, nn_src_t r_square, word_t mpinv)
{
	int ret;
	nn tmp;

	tmp.magic = WORD(0);

	ret = nn_init(&tmp, 0);
	if (ret == 0) {
		ret = _nn_mod_pow_redc(&tmp, base, exp, mod, r, r_square, mpinv);
	}
	if (ret == 0) {
		ret = nn_copy(out, &tmp);
	}

	nn_uninit(&tmp);
	return ret;
}
/*
 * out = base^exp mod mod using redc1 (Montgomery) multiplications with the
 * caller-supplied r, r_square and mpinv. These are presumably produced by
 * nn_compute_redc1_coefs for the same modulus, as nn_mod_pow does.
 *
 * Requirements:
 *   - base, exp, mod, r and r_square must be initialized;
 *   - mod must be odd.
 * Otherwise a non-zero error is returned.
 *
 * A modulus shorter than 2 words is copied into a local and widened to
 * 2 words with nn_set_wlen before use. out may alias any of the inputs.
 *
 * Returns 0 on success, non-zero on error.
 */
int nn_mod_pow_redc(nn_t out, nn_src_t base, nn_src_t exp, nn_src_t mod, nn_src_t r, nn_src_t r_square, word_t mpinv)
{
	nn wide_mod;
	nn_src_t m;
	int ret, isodd, aliased;

	/* Cleared so the nn_uninit at the single exit is valid on every path. */
	wide_mod.magic = WORD(0);

	ret = nn_check_initialized(base); EG(ret, err);
	ret = nn_check_initialized(exp); EG(ret, err);
	ret = nn_check_initialized(mod); EG(ret, err);
	ret = nn_check_initialized(r); EG(ret, err);
	ret = nn_check_initialized(r_square); EG(ret, err);

	ret = nn_isodd(mod, &isodd); EG(ret, err);
	MUST_HAVE(isodd, ret, err);

	/* Select the modulus actually handed to the ladder. */
	m = mod;
	if (mod->wlen < 2) {
		ret = nn_copy(&wide_mod, mod); EG(ret, err);
		ret = nn_set_wlen(&wide_mod, 2); EG(ret, err);
		m = &wide_mod;
	}

	/* Aliasing is decided against the caller's objects, not the widened copy. */
	aliased = ((out == base) || (out == exp) || (out == mod) ||
		   (out == r) || (out == r_square));
	if (aliased) {
		ret = _nn_mod_pow_redc_aliased(out, base, exp, m, r, r_square, mpinv);
	} else {
		ret = _nn_mod_pow_redc(out, base, exp, m, r, r_square, mpinv);
	}

err:
	nn_uninit(&wide_mod);
	return ret;
}
/*
 * out = base^exp mod mod.
 *
 * Odd modulus: the redc1 coefficients (r, r_square, mpinv) are computed here,
 * and nn_mod_pow_redc does the work.
 *
 * Even modulus: nn_mod_pow_redc rejects it, so the generic ladder based on
 * nn_mod_mul (_nn_mod_pow) is used.
 *
 * The local coefficient numbers are passed to nn_uninit on every path.
 *
 * Returns 0 on success, non-zero on error.
 */
int nn_mod_pow(nn_t out, nn_src_t base, nn_src_t exp, nn_src_t mod)
{
	nn r, r_square;
	word_t mpinv;
	int ret, isodd;

	r.magic = WORD(0);
	r_square.magic = WORD(0);

	ret = nn_isodd(mod, &isodd); EG(ret, err);

	if (isodd) {
		ret = nn_compute_redc1_coefs(&r, &r_square, mod, &mpinv); EG(ret, err);
		ret = nn_mod_pow_redc(out, base, exp, mod, &r, &r_square, mpinv);
	} else {
		ret = _nn_mod_pow(out, base, exp, mod);
	}

err:
	nn_uninit(&r);
	nn_uninit(&r_square);
	return ret;
}