#include <libecc/nn/nn_mul_redc1.h>
#include <libecc/nn/nn_mul_public.h>
#include <libecc/nn/nn_add.h>
#include <libecc/nn/nn_logical.h>
#include <libecc/nn/nn_div_public.h>
#include <libecc/nn/nn_modinv.h>
#include <libecc/nn/nn.h>
/*
 * Derive the constants needed by nn_mul_redc1() for modulus 'p_in':
 *
 *  - r        = (1 << rounded_bits) mod p, where rounded_bits is
 *               WORD_BITS times the word length of (the local copy of) p;
 *  - r_square = r * r mod p;
 *  - *mpinv   = low word of (2^WORD_BITS - x), where x is the value
 *               produced by nn_modinv_2exp(p, WORD_BITS) -- presumably
 *               the inverse of p modulo 2^WORD_BITS (helper defined
 *               elsewhere, confirm there).
 *
 * 'p_in' is copied into a local before 'r' and 'r_square' are
 * (re)initialized. '*mpinv' is written only after every step succeeded.
 *
 * Returns 0 on success, a non-zero status otherwise.
 */
int nn_compute_redc1_coefs(nn_t r, nn_t r_square, nn_src_t p_in, word_t *mpinv)
{
	nn modulus, inv, two_pow_w;
	bitcnt_t rounded_bits;
	word_t neg_inv;
	int ret, isodd;

	/*
	 * Give the locals a known magic before anything can jump to 'err',
	 * where all three are handed to nn_uninit().
	 */
	modulus.magic = WORD(0);
	inv.magic = WORD(0);
	two_pow_w.magic = WORD(0);

	/* Validate the input and take a private copy of the modulus. */
	ret = nn_check_initialized(p_in); EG(ret, err);
	ret = nn_init(&modulus, 0); EG(ret, err);
	ret = nn_copy(&modulus, p_in); EG(ret, err);
	MUST_HAVE((mpinv != NULL), ret, err);

	/* Work on a modulus that is at least two words long. */
	if (modulus.wlen < 2) {
		ret = nn_set_wlen(&modulus, 2); EG(ret, err);
	}

	/* Outputs and scratch values all start from a fresh state. */
	ret = nn_init(r, 0); EG(ret, err);
	ret = nn_init(r_square, 0); EG(ret, err);
	ret = nn_init(&inv, 0); EG(ret, err);
	ret = nn_init(&two_pow_w, 0); EG(ret, err);

	/* Bit length of the modulus rounded up to a whole number of words. */
	rounded_bits = (bitcnt_t)(WORD_BITS * modulus.wlen);

	/*
	 * Build a two-word value with word 1 set to 1 (i.e. 2^WORD_BITS,
	 * assuming word 0 is still zero after nn_init/nn_set_wlen), keep a
	 * copy of it, then compute neg_inv = low word of (2^WORD_BITS - inv).
	 */
	ret = nn_set_wlen(&inv, 2); EG(ret, err);
	inv.val[1] = WORD(1);
	ret = nn_copy(&two_pow_w, &inv); EG(ret, err);
	ret = nn_modinv_2exp(&inv, &modulus, WORD_BITS, &isodd); EG(ret, err);
	ret = nn_sub(&inv, &two_pow_w, &inv); EG(ret, err);
	neg_inv = inv.val[0];

	/* r = (1 << rounded_bits) mod p */
	ret = nn_one(r); EG(ret, err);
	ret = nn_lshift(r, r, rounded_bits); EG(ret, err);
	ret = nn_mod(r, r, &modulus); EG(ret, err);

	/*
	 * r_square = r^2 mod p, computed with a plain square followed by a
	 * reduction. The square needs up to 2 * rounded_bits bits, so make
	 * sure that fits in an nn first.
	 */
	MUST_HAVE(((2 * rounded_bits) <= NN_MAX_BIT_LEN), ret, err);
	ret = nn_sqr(r_square, r); EG(ret, err);
	ret = nn_mod(r_square, r_square, &modulus); EG(ret, err);

	/* Everything succeeded: publish the word-sized coefficient. */
	(*mpinv) = neg_inv;

err:
	nn_uninit(&modulus);
	nn_uninit(&inv);
	nn_uninit(&two_pow_w);

	return ret;
}
/*
 * Word-serial Montgomery multiplication: interleaves, word by word, the
 * accumulation of in1 * in2 with the addition of a multiple of p chosen
 * (through 'mpinv') so that the low word cancels and can be shifted out.
 * After p->wlen rounds and one conditional subtraction, 'out' holds a
 * value < p with the same word length as p.
 *
 * 'in1' and 'in2' are assumed to be < p (only verified through
 * SHOULD_HAVE). 'mpinv' is expected to be the coefficient produced by
 * nn_compute_redc1_coefs() for the same p.
 *
 * 'out' is (re)initialized before the inputs are read, so it must not
 * alias 'in1', 'in2' or 'p'; nn_mul_redc1() routes aliased calls
 * through _nn_mul_redc1_aliased().
 *
 * Returns 0 on success, non-zero on error.
 */
ATTRIBUTE_WARN_UNUSED_RET static int _nn_mul_redc1(nn_t out, nn_src_t in1, nn_src_t in2, nn_src_t p,
			  word_t mpinv)
{
	word_t prod_high, prod_low, carry, acc, m;
	unsigned int i, j, len, len_mul;
	/* a is the operand with the larger word count, b the other one. */
	nn_src_t a, b;
	int ret, cmp;
	u8 old_wlen;

	/* Input hypothesis: both operands are already reduced modulo p. */
	SHOULD_HAVE((!nn_cmp(in1, p, &cmp)) && (cmp < 0), ret, err);
	SHOULD_HAVE((!nn_cmp(in2, p, &cmp)) && (cmp < 0), ret, err);

	ret = nn_init(out, 0); EG(ret, err);

	a = (in1->wlen <= in2->wlen) ? in2 : in1;
	b = (in1->wlen <= in2->wlen) ? in1 : in2;

	/* The whole computation is sized on the word length of p. */
	ret = nn_set_wlen(out, p->wlen); EG(ret, err);

	len = out->wlen;
	/*
	 * Never walk more than 'len' words of b. Operands are assumed to be
	 * < p, so any word of b at index >= len is zero and skipping it does
	 * not change the product. Without this bound, two operands that both
	 * carry more words than p would push the accumulation (and the final
	 * carry write at index len_mul) past the len + 1 words that 'out'
	 * owns below and that the reduction step consumes.
	 */
	len_mul = b->wlen;
	if (len_mul > len) {
		len_mul = len;
	}

	/*
	 * One extra word is temporarily appended to 'out' to hold carries;
	 * check first that this still fits in an nn.
	 */
	MUST_HAVE(((WORD_BITS * (out->wlen + 1)) <= NN_MAX_BIT_LEN), ret, err);
	old_wlen = out->wlen;
	out->wlen = (u8)(out->wlen + 1);

	/* Clear the len + 1 working words. */
	for (i = 0; i < out->wlen; i++) {
		out->val[i] = 0;
	}
	for (i = 0; i < len; i++) {
		/* out += a[i] * b, over the low len_mul words ... */
		carry = WORD(0);
		for (j = 0; j < len_mul; j++) {
			WORD_MUL(prod_high, prod_low, a->val[i], b->val[j]);
			prod_low = (word_t)(prod_low + carry);
			prod_high = (word_t)(prod_high + (prod_low < carry));
			out->val[j] = (word_t)(out->val[j] + prod_low);
			carry = (word_t)(prod_high + (out->val[j] < prod_low));
		}
		/* ... then ripple the carry through the remaining words ... */
		for (; j < len; j++) {
			out->val[j] = (word_t)(out->val[j] + carry);
			carry = (word_t)(out->val[j] < carry);
		}
		/* ... and into the extra top word; acc keeps its overflow bit. */
		out->val[j] = (word_t)(out->val[j] + carry);
		acc = (word_t)(out->val[j] < carry);

		/*
		 * Pick m so that the low word of (out + m * p) is zero, then
		 * add m * p while shifting the result down by one word
		 * (word j is stored at index j - 1).
		 */
		m = (word_t)(out->val[0] * mpinv);
		WORD_MUL(prod_high, prod_low, m, p->val[0]);
		prod_low = (word_t)(prod_low + out->val[0]);
		carry = (word_t)(prod_high + (prod_low < out->val[0]));
		for (j = 1; j < len; j++) {
			WORD_MUL(prod_high, prod_low, m, p->val[j]);
			prod_low = (word_t)(prod_low + carry);
			prod_high = (word_t)(prod_high + (prod_low < carry));
			out->val[j - 1] = (word_t)(prod_low + out->val[j]);
			carry = (word_t)(prod_high + (out->val[j - 1] < prod_low));
		}
		/* Fold the last carry and the saved overflow into the top words. */
		out->val[j - 1] = (word_t)(carry + out->val[j]);
		carry = (word_t)(out->val[j - 1] < out->val[j]);
		out->val[j] = (word_t)(acc + carry);
	}

	/* Bring the result below p with one conditional subtraction. */
	ret = nn_cmp(out, p, &cmp); EG(ret, err);
	ret = nn_cnd_sub(cmp >= 0, out, out, p); EG(ret, err);
	MUST_HAVE((!nn_cmp(out, p, &cmp)) && (cmp < 0), ret, err);

	/* Drop the temporary carry word. */
	out->wlen = old_wlen;

err:
	return ret;
}
/*
 * Variant of _nn_mul_redc1() for the case where 'out' is also one of
 * the inputs: the product is computed into a local nn and only then
 * copied into 'out'.
 *
 * Returns 0 on success, non-zero on error.
 */
ATTRIBUTE_WARN_UNUSED_RET static int _nn_mul_redc1_aliased(nn_t out, nn_src_t in1, nn_src_t in2,
				 nn_src_t p, word_t mpinv)
{
	int ret;
	nn scratch;

	/* Known magic for the nn_uninit() at 'err', whatever path leads there. */
	scratch.magic = WORD(0);

	/* Compute into the scratch value; the inputs are left untouched. */
	ret = _nn_mul_redc1(&scratch, in1, in2, p, mpinv); EG(ret, err);

	/* Inputs are no longer needed: overwrite 'out' with the result. */
	ret = nn_init(out, scratch.wlen); EG(ret, err);
	ret = nn_copy(out, &scratch);

err:
	nn_uninit(&scratch);

	return ret;
}
/*
 * Public entry point for the redc1 multiplication of 'in1' by 'in2'
 * modulo 'p', using the word coefficient 'mpinv'.
 *
 * The three inputs must be initialized. 'out' may be the same object as
 * any of them: that case is detected by pointer comparison and handled
 * through a temporary.
 *
 * Returns 0 on success, non-zero on error.
 */
int nn_mul_redc1(nn_t out, nn_src_t in1, nn_src_t in2, nn_src_t p,
		 word_t mpinv)
{
	int ret;

	ret = nn_check_initialized(in1); EG(ret, err);
	ret = nn_check_initialized(in2); EG(ret, err);
	ret = nn_check_initialized(p); EG(ret, err);

	if ((out != in1) && (out != in2) && (out != p)) {
		/* Distinct output: compute in place. */
		ret = _nn_mul_redc1(out, in1, in2, p, mpinv);
	} else {
		/* Output overlaps an input: go through a scratch copy. */
		ret = _nn_mul_redc1_aliased(out, in1, in2, p, mpinv);
	}

err:
	return ret;
}
/*
 * Compute out = in1 * in2 mod p_in.
 *
 *  - Odd modulus: the redc1 coefficients are computed on the fly, both
 *    operands are multiplied by r_square through nn_mul_redc1(), the two
 *    results are multiplied together the same way, and a last
 *    nn_mul_redc1() by 1 produces 'out'. The operands go straight into
 *    nn_mul_redc1(), so they are expected to be < p_in on this path.
 *  - Even modulus: plain nn_mul() followed by nn_mod().
 *
 * Returns 0 on success, non-zero on error.
 */
int nn_mod_mul(nn_t out, nn_src_t in1, nn_src_t in2, nn_src_t p_in)
{
	nn modulus, rsq;
	nn mont1, mont2;
	word_t mpinv;
	int ret, isodd;

	/* Known magic for every local handed to nn_uninit() at 'err'. */
	rsq.magic = WORD(0);
	mont1.magic = WORD(0);
	mont2.magic = WORD(0);
	modulus.magic = WORD(0);

	ret = nn_isodd(p_in, &isodd); EG(ret, err);

	if (isodd) {
		/* Private copy of the modulus, at least two words long. */
		ret = nn_copy(&modulus, p_in); EG(ret, err);
		if (modulus.wlen < 2) {
			ret = nn_set_wlen(&modulus, 2); EG(ret, err);
		}

		/*
		 * Only r_square and mpinv are needed here; the 'r' output is
		 * dropped into mont1, which is overwritten just below.
		 */
		ret = nn_compute_redc1_coefs(&mont1, &rsq, &modulus, &mpinv); EG(ret, err);

		/* Multiply each operand by r_square through redc1. */
		ret = nn_mul_redc1(&mont1, in1, &rsq, &modulus, mpinv); EG(ret, err);
		ret = nn_mul_redc1(&mont2, in2, &rsq, &modulus, mpinv); EG(ret, err);

		/* Product of the two converted operands; rsq is reused as storage. */
		ret = nn_mul_redc1(&rsq, &mont1, &mont2, &modulus, mpinv); EG(ret, err);

		/* A final redc1 multiplication by 1 yields the result in 'out'. */
		ret = nn_init(&mont1, 0); EG(ret, err);
		ret = nn_one(&mont1); EG(ret, err);
		ret = nn_mul_redc1(out, &rsq, &mont1, &modulus, mpinv); EG(ret, err);
	} else {
		/* Even modulus: generic multiply, then reduce. */
		ret = nn_mul(out, in1, in2); EG(ret, err);
		ret = nn_mod(out, out, p_in); EG(ret, err);
	}

err:
	nn_uninit(&modulus);
	nn_uninit(&rsq);
	nn_uninit(&mont1);
	nn_uninit(&mont2);

	return ret;
}