#include <libecc/nn/nn_modinv.h>
#include <libecc/nn/nn_div_public.h>
#include <libecc/nn/nn_logical.h>
#include <libecc/nn/nn_add.h>
#include <libecc/nn/nn_mod_pow.h>
#include <libecc/nn/nn.h>
#include "../nn/nn_mul.h"
/*
 * Inverse of x modulo an odd m, for 0 < x < m (enforced by the MUST_HAVE
 * checks below).
 *
 * Binary extended Euclid with a fixed iteration count. The loop maintains
 *     a == u  * x (mod m)    and    b == uu * x (mod m),
 * starting from (a, u) = (x, 1) and (b, uu) = (m, 0). Each step makes a even
 * (conditionally swapping a with b, then subtracting b when a is odd) and
 * halves both a and u, the latter modulo m. When a reaches 0, b holds
 * gcd(x, m); if it is 1, uu (which is 'out') holds x^-1 mod m.
 *
 * Decisions inside the loop go through nn_cnd_swap/nn_cnd_add/nn_cnd_sub
 * rather than branches, and the iteration count only depends on the word
 * lengths of a and b (both set from m). NOTE(review): presumably this is to
 * avoid secret-dependent timing -- confirm the nn_cnd_* helpers are
 * constant time.
 *
 * 'out' is initialized before x and m are read, so it must not alias them
 * (nn_modinv() passes a local).
 *
 * Returns 0 when x is invertible. Otherwise 'out' is set to 0 and -1 is
 * returned. Precondition violations and internal errors return non-zero.
 */
ATTRIBUTE_WARN_UNUSED_RET static int _nn_modinv_odd(nn_t out, nn_src_t x, nn_src_t m)
{
int isodd, swap, smaller, ret, cmp, iszero, tmp_isodd;
nn a, b, u, tmp, mp1d2;
/* uu is the cofactor paired with b; it lives directly in 'out'. */
nn_t uu = out;
bitcnt_t cnt;
/* Zero the magics so the cleanup at err: copes with locals never initialized. */
a.magic = b.magic = u.magic = tmp.magic = mp1d2.magic = WORD(0);
ret = nn_init(out, 0); EG(ret, err);
ret = nn_init(&a, (u16)(m->wlen * WORD_BYTES)); EG(ret, err);
ret = nn_init(&b, (u16)(m->wlen * WORD_BYTES)); EG(ret, err);
ret = nn_init(&u, (u16)(m->wlen * WORD_BYTES)); EG(ret, err);
ret = nn_init(&mp1d2, (u16)(m->wlen * WORD_BYTES)); EG(ret, err);
ret = nn_init(&tmp, (u16)(m->wlen * WORD_BYTES)); EG(ret, err);
/* Preconditions: m odd, x < m, x != 0. */
MUST_HAVE((!nn_isodd(m, &isodd)) && isodd, ret, err);
MUST_HAVE((!nn_cmp(x, m, &cmp)) && (cmp < 0), ret, err);
MUST_HAVE((!nn_iszero(x, &iszero)) && (!iszero), ret, err);
/* a = x, u = 1 (a == u*x); b = m, uu = 0 (b == uu*x mod m). All on m->wlen words. */
ret = nn_copy(&a, x); EG(ret, err);
ret = nn_set_wlen(&a, m->wlen); EG(ret, err);
ret = nn_copy(&b, m); EG(ret, err);
ret = nn_one(&u); EG(ret, err);
ret = nn_zero(uu); EG(ret, err);
ret = nn_set_wlen(&u, m->wlen); EG(ret, err);
ret = nn_set_wlen(uu, m->wlen); EG(ret, err);
/* mp1d2 = (m >> 1) + 1 = (m + 1) / 2 since m is odd, i.e. 2^-1 mod m. */
ret = nn_rshift_fixedlen(&mp1d2, m, 1); EG(ret, err);
ret = nn_inc(&mp1d2, &mp1d2); EG(ret, err);
/* Fixed iteration bound, independent of the value of x. */
cnt = (bitcnt_t)((a.wlen + b.wlen) * WORD_BITS);
while (cnt > 0) {
cnt = (bitcnt_t)(cnt - 1);
/* Invariant: b is always odd. */
MUST_HAVE((!nn_isodd(&b, &tmp_isodd)) && tmp_isodd, ret, err);
ret = nn_isodd(&a, &isodd); EG(ret, err);
ret = nn_cmp(&a, &b, &cmp); EG(ret, err);
/* When a is odd and a < b, swap them so that a - b below cannot underflow. */
swap = isodd & (cmp == -1);
ret = nn_cnd_swap(swap, &a, &b); EG(ret, err);
/* a odd: a <- a - b (odd - odd), so a is even in every case. */
ret = nn_cnd_sub(isodd, &a, &a, &b); EG(ret, err);
MUST_HAVE((!nn_isodd(&a, &tmp_isodd)) && (!tmp_isodd), ret, err);
ret = nn_rshift_fixedlen(&a, &a, 1); EG(ret, err);
/* Mirror the swap on the cofactors. */
ret = nn_cnd_swap(swap, &u, uu); EG(ret, err);
ret = nn_cmp(&u, uu, &cmp); EG(ret, err);
smaller = (cmp == -1);
/* Mirror the subtraction modulo m: u <- u - uu, computed as u + (m - uu) when u < uu. */
ret = nn_sub(&tmp, m, uu); EG(ret, err);
ret = nn_cnd_add(isodd & smaller, &u, &u, &tmp); EG(ret, err);
ret = nn_cnd_sub(isodd & (!smaller), &u, &u, uu); EG(ret, err);
/* Mirror the halving modulo m: for odd u, u / 2 mod m = (u >> 1) + (m + 1) / 2. */
ret = nn_isodd(&u, &isodd); EG(ret, err);
ret = nn_rshift_fixedlen(&u, &u, 1); EG(ret, err);
ret = nn_cnd_add(isodd, &u, &u, &mp1d2); EG(ret, err);
/* Both cofactors stay reduced modulo m. */
MUST_HAVE((!nn_cmp(&u, m, &cmp)) && (cmp < 0), ret, err);
MUST_HAVE((!nn_cmp(uu, m, &cmp)) && (cmp < 0), ret, err);
}
/* The bound must have driven a to 0; b then holds gcd(x, m). */
MUST_HAVE((!nn_iszero(&a, &iszero)) && iszero, ret, err);
ret = nn_cmp_word(&b, WORD(1), &cmp); EG(ret, err);
/* gcd != 1: x is not invertible, so zero the output (uu - uu) via a conditional subtraction. */
ret = nn_cnd_sub(cmp != 0, uu, uu, uu); EG(ret, err);
ret = cmp ? -1 : 0;
err:
nn_uninit(&a);
nn_uninit(&b);
nn_uninit(&u);
nn_uninit(&mp1d2);
nn_uninit(&tmp);
PTR_NULLIFY(uu);
return ret;
}
/*
 * Compute _out = x^-1 mod m.
 *
 * Odd m: if x >= m it is first reduced modulo m. The local 'u' is used as
 * scratch for that reduction, since it is otherwise unused on this path.
 * The result is then inverted with _nn_modinv_odd().
 *
 * Even m: x must be odd. The inverse is taken from the first Bezout
 * coefficient produced by nn_xgcd(), whose gcd output must be 1. When
 * nn_xgcd() reports sign == -1, the coefficient is negated modulo m.
 *
 * The result is built in a local and only written to _out by the final
 * copy, so _out may alias x or m.
 *
 * Returns 0 on success. Returns non-zero on error, including when x is not
 * invertible modulo m.
 */
int nn_modinv(nn_t _out, nn_src_t x, nn_src_t m)
{
	int sign, ret, cmp, isodd, isone;
	/*
	 * Must start as NULL: PTR_NULLIFY() at err: runs on every exit path,
	 * but x_mod_m is only assigned on the odd-m, x >= m path. Reading it
	 * uninitialized would be undefined behavior.
	 */
	nn_t x_mod_m = NULL;
	nn u, v, out; /* 'out' is a local so that _out may alias the inputs */
	out.magic = u.magic = v.magic = WORD(0);

	ret = nn_check_initialized(x); EG(ret, err);
	ret = nn_check_initialized(m); EG(ret, err);
	ret = nn_init(&out, 0); EG(ret, err);

	ret = nn_isodd(m, &isodd); EG(ret, err);
	if (isodd) {
		ret = nn_cmp(x, m, &cmp); EG(ret, err);
		if (cmp >= 0) {
			/* x^-1 mod m == (x mod m)^-1 mod m; 'u' serves as scratch. */
			x_mod_m = &u;
			ret = nn_mod(x_mod_m, x, m); EG(ret, err);
			ret = _nn_modinv_odd(&out, x_mod_m, m); EG(ret, err);
		} else {
			ret = _nn_modinv_odd(&out, x, m); EG(ret, err);
		}
		ret = nn_copy(_out, &out);
		goto err;
	}

	/* Even m: an even x shares the factor 2 with m and has no inverse. */
	ret = nn_isodd(x, &isodd); EG(ret, err);
	MUST_HAVE(isodd, ret, err);

	ret = nn_init(&u, 0); EG(ret, err);
	ret = nn_init(&v, 0); EG(ret, err);
	ret = nn_xgcd(&out, &u, &v, x, m, &sign); EG(ret, err);
	/* 'out' is expected to hold gcd(x, m): x is invertible only if it is 1. */
	ret = nn_isone(&out, &isone); EG(ret, err);
	MUST_HAVE(isone, ret, err);

	ret = nn_mod(&out, &u, m); EG(ret, err);
	if (sign == -1) {
		ret = nn_sub(&out, m, &out); EG(ret, err);
	}
	ret = nn_copy(_out, &out);

err:
	nn_uninit(&out);
	nn_uninit(&u);
	nn_uninit(&v);
	PTR_NULLIFY(x_mod_m);
	return ret;
}
/*
 * In place: A <- (A - B) mod 2^(A->wlen * WORD_BITS).
 *
 * A is temporarily widened by one word set to 1, which adds
 * 2^(wlen * WORD_BITS). The subtraction is performed on the widened value,
 * and A is then truncated back to its original word length, dropping that
 * extra word. NOTE(review): this assumes B fits in A's original word length
 * (true for the nn_mul_low() outputs it is fed) -- confirm for new callers.
 *
 * Returns 0 on success, or the error from the underlying nn_* call.
 */
ATTRIBUTE_WARN_UNUSED_RET static inline int _nn_sub_mod_2exp(nn_t A, nn_src_t B)
{
	const u8 saved_wlen = A->wlen;
	int ret;

	ret = nn_set_wlen(A, (u8)(saved_wlen + 1));
	if (ret == 0) {
		/* Top word = 1: A now holds A + 2^(saved_wlen * WORD_BITS). */
		A->val[A->wlen - 1] = WORD(1);
		ret = nn_sub(A, A, B);
	}
	if (ret == 0) {
		/* Drop the extra word: reduction modulo 2^(saved_wlen * WORD_BITS). */
		ret = nn_set_wlen(A, saved_wlen);
	}

	return ret;
}
int nn_modinv_2exp(nn_t _out, nn_src_t x, bitcnt_t exp, int *x_isodd)
{
bitcnt_t cnt;
u8 exp_wlen = (u8)BIT_LEN_WORDS(exp);
bitcnt_t exp_cnt = exp % WORD_BITS;
word_t mask = (word_t)((exp_cnt == 0) ? WORD_MASK : (word_t)((WORD(1) << exp_cnt) - WORD(1)));
nn tmp_sqr, tmp_mul;
int isodd, ret;
nn out;
out.magic = tmp_sqr.magic = tmp_mul.magic = WORD(0);
MUST_HAVE((x_isodd != NULL), ret, err);
ret = nn_check_initialized(x); EG(ret, err);
ret = nn_check_initialized(_out); EG(ret, err);
ret = nn_init(&out, 0); EG(ret, err);
ret = nn_init(&tmp_sqr, 0); EG(ret, err);
ret = nn_init(&tmp_mul, 0); EG(ret, err);
ret = nn_isodd(x, &isodd); EG(ret, err);
if (exp == (bitcnt_t)0){
(*x_isodd) = isodd;
goto err;
}
if (!isodd) {
ret = nn_zero(_out); EG(ret, err);
(*x_isodd) = 0;
goto err;
}
cnt = 1;
ret = nn_one(&out); EG(ret, err);
for (; cnt < WORD_MIN(WORD_BITS, exp); cnt = (bitcnt_t)(cnt << 1)) {
ret = nn_sqr_low(&tmp_sqr, &out, out.wlen); EG(ret, err);
ret = nn_mul_low(&tmp_mul, &tmp_sqr, x, out.wlen); EG(ret, err);
ret = nn_lshift_fixedlen(&out, &out, 1); EG(ret, err);
ret = _nn_sub_mod_2exp(&out, &tmp_mul); EG(ret, err);
}
for (; cnt < ((exp + 1) >> 1); cnt = (bitcnt_t)(cnt << 1)) {
ret = nn_set_wlen(&out, (u8)(2 * out.wlen)); EG(ret, err);
ret = nn_sqr_low(&tmp_sqr, &out, out.wlen); EG(ret, err);
ret = nn_mul_low(&tmp_mul, &tmp_sqr, x, out.wlen); EG(ret, err);
ret = nn_lshift_fixedlen(&out, &out, 1); EG(ret, err);
ret = _nn_sub_mod_2exp(&out, &tmp_mul); EG(ret, err);
}
if (exp > WORD_BITS) {
ret = nn_set_wlen(&out, exp_wlen); EG(ret, err);
ret = nn_sqr_low(&tmp_sqr, &out, out.wlen); EG(ret, err);
ret = nn_mul_low(&tmp_mul, &tmp_sqr, x, out.wlen); EG(ret, err);
ret = nn_lshift_fixedlen(&out, &out, 1); EG(ret, err);
ret = _nn_sub_mod_2exp(&out, &tmp_mul); EG(ret, err);
}
out.val[exp_wlen - 1] &= mask;
ret = nn_copy(_out, &out); EG(ret, err);
(*x_isodd) = 1;
err:
nn_uninit(&out);
nn_uninit(&tmp_sqr);
nn_uninit(&tmp_mul);
return ret;
}
/*
 * Compute out = w^-1 mod m for a single machine word w.
 *
 * w is loaded into a temporary nn and the work is delegated to nn_modinv().
 * Returns nn_modinv()'s result, or the error raised while setting up the
 * temporary.
 */
int nn_modinv_word(nn_t out, word_t w, nn_src_t m)
{
	nn w_as_nn;
	int ret;

	w_as_nn.magic = WORD(0);

	ret = nn_init(&w_as_nn, 0);
	if (ret == 0) {
		ret = nn_set_word_value(&w_as_nn, w);
	}
	if (ret == 0) {
		ret = nn_modinv(out, &w_as_nn, m);
	}

	nn_uninit(&w_as_nn);
	return ret;
}
/*
 * Shared setup for the Fermat-based inversions, x^-1 == x^(p-2) mod p.
 * NOTE(review): this identity assumes p is prime; primality is not checked
 * here.
 *
 * Cases:
 * - x == 0:    out = 0, returns -1 (*lesstwo = 0).
 * - p < 2:     out = 0, *lesstwo = 1, returns -1.
 * - p == 2:    *lesstwo = 1; odd x gives out = 1 and returns 0, even x gives
 *              out = 0 and returns -1.
 * - otherwise: *lesstwo = 0, p_minus_two = p - 2, returns 0. The caller then
 *              computes out = x^p_minus_two mod p.
 *
 * lesstwo must be non-NULL. Once that check passes, *lesstwo is always
 * written, so callers never see an indeterminate flag.
 */
ATTRIBUTE_WARN_UNUSED_RET static int _nn_modinv_fermat_common(nn_t out, nn_src_t x, nn_src_t p, nn_t p_minus_two, int *lesstwo)
{
	int ret, cmp, isodd, iszero;
	nn two;
	two.magic = WORD(0);

	MUST_HAVE((lesstwo != NULL), ret, err);
	(*lesstwo) = 0;

	ret = nn_check_initialized(x); EG(ret, err);
	ret = nn_check_initialized(p); EG(ret, err);

	/* Zero has no inverse. */
	ret = nn_iszero(x, &iszero); EG(ret, err);
	if (iszero) {
		ret = nn_init(out, 0); EG(ret, err);
		ret = nn_zero(out); EG(ret, err);
		ret = -1;
		goto err;
	}

	ret = nn_cmp_word(p, WORD(2), &cmp); EG(ret, err);
	if (cmp == 0) {
		/* Modulo 2: an odd x is its own inverse (1); an even x has none. */
		ret = nn_isodd(x, &isodd); EG(ret, err);
		ret = nn_init(out, 0); EG(ret, err);
		if (isodd) {
			ret = nn_one(out); EG(ret, err);
			ret = 0;
		} else {
			ret = nn_zero(out); EG(ret, err);
			ret = -1;
		}
		(*lesstwo) = 1;
		goto err;
	} else if (cmp < 0) {
		/* p is 0 or 1: no meaningful inverse. */
		ret = nn_init(out, 0); EG(ret, err);
		ret = nn_zero(out); EG(ret, err);
		ret = -1;
		(*lesstwo) = 1;
		goto err;
	}

	/* p >= 3: compute the Fermat exponent p - 2. */
	if (p != p_minus_two) {
		/* Presumably lets p_minus_two alias p (in-place p -= 2). */
		ret = nn_init(p_minus_two, 0); EG(ret, err);
	}
	ret = nn_init(&two, 0); EG(ret, err);
	ret = nn_set_word_value(&two, WORD(2)); EG(ret, err);
	ret = nn_sub(p_minus_two, p, &two);

err:
	nn_uninit(&two);
	return ret;
}
/*
 * Compute out = x^-1 mod p as x^(p-2) mod p (Fermat's little theorem).
 * NOTE(review): this assumes p is prime.
 *
 * The degenerate inputs x == 0 and p <= 2 are fully handled by
 * _nn_modinv_fermat_common(). The exponentiation with nn_mod_pow() only
 * runs when that setup succeeds and p > 2.
 * Returns 0 on success, non-zero on error or when no inverse exists.
 */
int nn_modinv_fermat(nn_t out, nn_src_t x, nn_src_t p)
{
	nn exponent;
	int p_at_most_two, ret;

	exponent.magic = WORD(0);

	ret = _nn_modinv_fermat_common(out, x, p, &exponent, &p_at_most_two);
	if ((ret == 0) && (!p_at_most_two)) {
		ret = nn_mod_pow(out, x, &exponent, p);
	}

	nn_uninit(&exponent);
	return ret;
}
/*
 * Same contract as nn_modinv_fermat(): out = x^(p-2) mod p.
 *
 * The exponentiation goes through nn_mod_pow_redc(), with r, r_square and
 * mpinv forwarded unchanged. These are presumably the Montgomery constants
 * for p -- confirm against nn_mod_pow_redc()'s documentation.
 * Returns 0 on success, non-zero on error or when no inverse exists.
 */
int nn_modinv_fermat_redc(nn_t out, nn_src_t x, nn_src_t p, nn_src_t r, nn_src_t r_square, word_t mpinv)
{
	nn exponent;
	int p_at_most_two, ret;

	exponent.magic = WORD(0);

	ret = _nn_modinv_fermat_common(out, x, p, &exponent, &p_at_most_two);
	if ((ret == 0) && (!p_at_most_two)) {
		ret = nn_mod_pow_redc(out, x, &exponent, p, r, r_square, mpinv);
	}

	nn_uninit(&exponent);
	return ret;
}