#include <libecc/libarith.h>
ATTRIBUTE_WARN_UNUSED_RET int miller_rabin(nn_src_t n, const unsigned int t, int *res);
int miller_rabin(nn_src_t n, const unsigned int t, int *res)
{
int ret, iszero, cmp, isodd, cmp1, cmp2;
unsigned int i;
bitcnt_t k;
nn s, q, r, d, a, y, j, one, two, tmp;
s.magic = q.magic = r.magic = d.magic = a.magic = y.magic = j.magic = WORD(0);
one.magic = two.magic = tmp.magic = WORD(0);
ret = nn_check_initialized(n); EG(ret, err);
MUST_HAVE((res != NULL), ret, err);
(*res) = 0;
ret = nn_init(&s, 0); EG(ret, err);
ret = nn_init(&q, 0); EG(ret, err);
ret = nn_init(&r, 0); EG(ret, err);
ret = nn_init(&d, 0); EG(ret, err);
ret = nn_init(&a, 0); EG(ret, err);
ret = nn_init(&y, 0); EG(ret, err);
ret = nn_init(&j, 0); EG(ret, err);
ret = nn_init(&one, 0); EG(ret, err);
ret = nn_init(&two, 0); EG(ret, err);
ret = nn_init(&tmp, 0); EG(ret, err);
MUST_HAVE((t >= 1), ret, err);
ret = nn_one(&one); EG(ret, err);
ret = nn_set_word_value(&two, WORD(2)); EG(ret, err);
ret = nn_iszero(n, &iszero); EG(ret, err);
if (iszero) {
ret = 0;
(*res) = 0;
goto err;
}
ret = nn_cmp(n, &one, &cmp); EG(ret, err);
if (cmp == 0) {
ret = 0;
(*res) = 0;
goto err;
}
ret = nn_cmp(n, &two, &cmp); EG(ret, err);
if (cmp == 0) {
ret = 0;
(*res) = 1;
goto err;
}
ret = nn_copy(&tmp, n); EG(ret, err);
ret = nn_dec(&tmp, &tmp); EG(ret, err);
ret = nn_cmp(&tmp, &two, &cmp); EG(ret, err);
if (cmp == 0) {
ret = 0;
(*res) = 1;
goto err;
}
ret = nn_isodd(n, &isodd); EG(ret, err);
if (!isodd) {
ret = 0;
(*res) = 0;
goto err;
}
ret = nn_zero(&s); EG(ret, err);
ret = nn_copy(&r, n); EG(ret, err);
ret = nn_dec(&r, &r); EG(ret, err);
while (1) {
ret = nn_divrem(&q, &d, &r, &two); EG(ret, err);
ret = nn_inc(&s, &s); EG(ret, err);
ret = nn_copy(&r, &q); EG(ret, err);
ret = nn_isodd(&r, &isodd); EG(ret, err);
if (isodd) {
break;
}
}
for (i = 1; i <= t; i++) {
bitcnt_t blen;
ret = nn_copy(&tmp, n); EG(ret, err);
ret = nn_dec(&tmp, &tmp); EG(ret, err);
ret = nn_zero(&a); EG(ret, err);
ret = nn_cmp(&a, &two, &cmp); EG(ret, err);
while (cmp < 0) {
ret = nn_get_random_mod(&a, &tmp); EG(ret, err);
ret = nn_cmp(&a, &two, &cmp); EG(ret, err);
}
ret = nn_one(&y); EG(ret, err);
ret = nn_bitlen(&r, &blen); EG(ret, err);
for (k = 0; k < blen; k++) {
u8 bit;
ret = nn_getbit(&r, k, &bit); EG(ret, err);
if (bit) {
MUST_HAVE((NN_MAX_BIT_LEN >=
(WORD_BITS * (y.wlen + a.wlen))), ret, err);
ret = nn_mul(&y, &y, &a); EG(ret, err);
ret = nn_mod(&y, &y, n); EG(ret, err);
}
MUST_HAVE((NN_MAX_BIT_LEN >= (2 * WORD_BITS * a.wlen)), ret, err);
ret = nn_sqr(&a, &a); EG(ret, err);
ret = nn_mod(&a, &a, n); EG(ret, err);
}
ret = nn_cmp(&y, &one, &cmp1); EG(ret, err);
ret = nn_cmp(&y, &tmp, &cmp2); EG(ret, err);
if ((cmp1 != 0) && (cmp2 != 0)) {
ret = nn_one(&j); EG(ret, err);
ret = nn_cmp(&j, &s, &cmp1); EG(ret, err);
ret = nn_cmp(&y, &tmp, &cmp2); EG(ret, err);
while ((cmp1 < 0) && (cmp2 != 0)) {
MUST_HAVE((NN_MAX_BIT_LEN >=
(2 * WORD_BITS * y.wlen)), ret, err);
ret = nn_sqr(&y, &y); EG(ret, err);
ret = nn_mod(&y, &y, n); EG(ret, err);
ret = nn_cmp(&y, &one, &cmp); EG(ret, err);
if (cmp == 0) {
ret = 0;
(*res) = 0;
goto err;
}
ret = nn_inc(&j, &j); EG(ret, err);
ret = nn_cmp(&j, &s, &cmp1); EG(ret, err);
ret = nn_cmp(&y, &tmp, &cmp2); EG(ret, err);
}
ret = nn_cmp(&y, &tmp, &cmp); EG(ret, err);
if (cmp != 0) {
ret = 0;
(*res) = 0;
goto err;
}
}
ret = 0;
(*res) = 1;
}
err:
nn_uninit(&s);
nn_uninit(&q);
nn_uninit(&r);
nn_uninit(&d);
nn_uninit(&a);
nn_uninit(&y);
nn_uninit(&j);
nn_uninit(&one);
nn_uninit(&two);
nn_uninit(&tmp);
return ret;
}