#include <sys/cdefs.h>
#include <ctype.h>
#include <err.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <openssl/crypto.h>
#include <openssl/err.h>
#include "mp.h"
/*
 * Fatal-error helpers.  The argument s is a parenthesised printf-style
 * argument list, e.g. MPERR(("%s", msg)).  MPERR reports via warn(3)
 * (which also appends the errno message); MPERRX reports via warnx(3).
 * Both then abort(), so neither ever returns.
 */
#define MPERR(s) do { warn s; abort(); } while (0)
#define MPERRX(s) do { warnx s; abort(); } while (0)
/*
 * Evaluate an OpenSSL BN_* call and, if it returns 0/NULL (failure),
 * report the OpenSSL error prefixed by msg and abort via _bnerr().
 */
#define BN_ERRCHECK(msg, expr) do { \
if (!(expr)) _bnerr(msg); \
} while (0)
/*
 * Internal helpers.  The leading const char * argument of each is the
 * name of the public operation, used only as a prefix in fatal error
 * messages.
 */
static void _bnerr(const char *);
static MINT *_dtom(const char *, const char *);
static MINT *_itom(const char *, short);
static void _madd(const char *, const MINT *, const MINT *, MINT *);
static int _mcmpa(const char *, const MINT *, const MINT *);
static void _mdiv(const char *, const MINT *, const MINT *, MINT *, MINT *,
BN_CTX *);
static void _mfree(const char *, MINT *);
static void _moveb(const char *, const BIGNUM *, MINT *);
static void _movem(const char *, const MINT *, MINT *);
static void _msub(const char *, const MINT *, const MINT *, MINT *);
static char *_mtod(const char *, const MINT *);
static char *_mtox(const char *, const MINT *);
static void _mult(const char *, const MINT *, const MINT *, MINT *, BN_CTX *);
static void _sdiv(const char *, const MINT *, short, MINT *, short *, BN_CTX *);
static MINT *_xtom(const char *, const char *);
/*
 * Report the most recent OpenSSL error, prefixed with msg (the name of
 * the failing operation), and abort.  Never returns.
 */
static void
_bnerr(const char *msg)
{
	const char *reason;

	ERR_load_crypto_strings();
	reason = ERR_reason_error_string(ERR_get_error());
	/*
	 * ERR_reason_error_string() returns NULL when no reason string is
	 * available, e.g. if the failing BN_* call queued no error at all.
	 * Passing NULL to a "%s" conversion is undefined behaviour.
	 */
	if (reason == NULL)
		reason = "unknown error";
	MPERRX(("%s: %s", msg, reason));
}
/*
 * Allocate a fresh MINT whose value is parsed from the decimal string s.
 * Any allocation or BN_dec2bn() failure is fatal (reported with msg).
 */
static MINT *
_dtom(const char *msg, const char *s)
{
	MINT *res;

	if ((res = malloc(sizeof(*res))) == NULL)
		MPERR(("%s", msg));
	if ((res->bn = BN_new()) == NULL)
		_bnerr(msg);
	BN_ERRCHECK(msg, BN_dec2bn(&res->bn, s));
	return (res);
}
/*
 * Store gcd(mp1, mp2), as computed by BN_gcd(), into rmp.
 * Failures are fatal.
 */
void
mp_gcd(const MINT *mp1, const MINT *mp2, MINT *rmp)
{
	BN_CTX *ctx;
	BIGNUM *g = NULL;

	ctx = BN_CTX_new();
	/* Only try to allocate the result once the context exists. */
	if (ctx == NULL || (g = BN_new()) == NULL)
		_bnerr("gcd");
	BN_ERRCHECK("gcd", BN_gcd(g, mp1->bn, mp2->bn, ctx));
	_moveb("gcd", g, rmp);
	BN_free(g);
	BN_CTX_free(ctx);
}
/*
 * Allocate a new MINT holding the value of the short integer n.
 *
 * n is rendered in signed decimal and parsed by _dtom(), so negative
 * values keep their sign.  A "%x" conversion of a negative short would
 * reinterpret it as a large unsigned value (-1 -> 0xffffffff), which
 * also corrupted negative divisors in _sdiv() and exponents in
 * mp_rpow().  A fixed stack buffer avoids an unchecked heap allocation.
 */
static MINT *
_itom(const char *msg, short n)
{
	/* Room for any int in decimal: up to ~3 digits/byte, sign, NUL. */
	char buf[3 * sizeof(int) + 2];

	(void)snprintf(buf, sizeof(buf), "%d", (int)n);
	return (_dtom(msg, buf));
}
/* Public wrapper: return a newly allocated MINT with value n. */
MINT *
mp_itom(short n)
{
	MINT *result;

	result = _itom("itom", n);
	return (result);
}
/*
 * rmp = mp1 + mp2.  The sum is built in a temporary BIGNUM and then
 * copied, so rmp may be the same object as mp1 or mp2.
 */
static void
_madd(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
{
	BIGNUM *sum;

	if ((sum = BN_new()) == NULL)
		_bnerr(msg);
	BN_ERRCHECK(msg, BN_add(sum, mp1->bn, mp2->bn));
	_moveb(msg, sum, rmp);
	BN_free(sum);
}
/* Public wrapper: rmp = mp1 + mp2. */
void
mp_madd(const MINT *mp1, const MINT *mp2, MINT *rmp)
{
	const char *op = "madd";

	_madd(op, mp1, mp2, rmp);
}
int
mp_mcmp(const MINT *mp1, const MINT *mp2)
{
return (BN_cmp(mp1->bn, mp2->bn));
}
/*
 * Compare the absolute values of mp1 and mp2 (BN_ucmp()).
 * msg is accepted for interface uniformity only.
 */
static int
_mcmpa(const char *msg __unused, const MINT *mp1, const MINT *mp2)
{
	int order;

	order = BN_ucmp(mp1->bn, mp2->bn);
	return (order);
}
/*
 * Divide nmp by dmp with BN_div(), storing the quotient in qmp and the
 * remainder in rmp.  Results go through temporaries first, so the
 * outputs may alias the inputs.  Any failure, including division by
 * zero, is fatal.
 */
static void
_mdiv(const char *msg, const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp,
    BN_CTX *c)
{
	BIGNUM *quot = NULL, *rem;

	if ((rem = BN_new()) == NULL || (quot = BN_new()) == NULL)
		_bnerr(msg);
	BN_ERRCHECK(msg, BN_div(quot, rem, nmp->bn, dmp->bn, c));
	_moveb(msg, quot, qmp);
	_moveb(msg, rem, rmp);
	BN_free(quot);
	BN_free(rem);
}
/* Public wrapper: qmp = nmp / dmp, rmp = remainder (per BN_div()). */
void
mp_mdiv(const MINT *nmp, const MINT *dmp, MINT *qmp, MINT *rmp)
{
	BN_CTX *ctx;

	if ((ctx = BN_CTX_new()) == NULL)
		_bnerr("mdiv");
	_mdiv("mdiv", nmp, dmp, qmp, rmp, ctx);
	BN_CTX_free(ctx);
}
/*
 * Release a MINT: zero its BIGNUM's contents, free the BIGNUM, then
 * free the MINT itself.  msg is unused.
 */
static void
_mfree(const char *msg __unused, MINT *mp)
{
	BIGNUM *num = mp->bn;

	BN_clear(num);
	BN_free(num);
	free(mp);
}
/* Public wrapper: release mp and its storage. */
void
mp_mfree(MINT *mp)
{
	const char *op = "mfree";

	_mfree(op, mp);
}
/*
 * Read one line from stdin, parse it as a decimal number, and store the
 * value in mp.  EOF/read error, allocation failure, and parse failure
 * are all fatal.
 */
void
mp_min(MINT *mp)
{
	MINT *parsed;
	char *raw, *text;
	size_t len;

	if ((raw = fgetln(stdin, &len)) == NULL)
		MPERR(("min"));
	/* fgetln()'s buffer is not NUL-terminated; make a C string copy. */
	if ((text = malloc(len + 1)) == NULL)
		MPERR(("min"));
	memcpy(text, raw, len);
	text[len] = '\0';

	parsed = _dtom("min", text);
	_movem("min", parsed, mp);
	_mfree("min", parsed);
	free(text);
}
/* Print mp in decimal to stdout (no trailing newline). */
void
mp_mout(const MINT *mp)
{
	char *dec = _mtod("mout", mp);

	printf("%s", dec);
	free(dec);
}
/* Public wrapper: copy the value of smp into tmp. */
void
mp_move(const MINT *smp, MINT *tmp)
{
	const char *op = "move";

	_movem(op, smp, tmp);
}
/* Copy the raw BIGNUM sbp into tmp; failure is fatal. */
static void
_moveb(const char *msg, const BIGNUM *sbp, MINT *tmp)
{
	BIGNUM *dst = tmp->bn;

	BN_ERRCHECK(msg, BN_copy(dst, sbp));
}
/* Copy the value of MINT smp into MINT tmp; failure is fatal. */
static void
_movem(const char *msg, const MINT *smp, MINT *tmp)
{
	const BIGNUM *src = smp->bn;

	BN_ERRCHECK(msg, BN_copy(tmp->bn, src));
}
/*
 * Integer square root: set xmp to floor(sqrt(nmp)) and rmp to
 * nmp - xmp * xmp, using Newton's iteration x' = (x + n / x) / 2
 * starting from x = 1.
 *
 * NOTE(review): negative nmp is not meaningful here; the iteration
 * presumably ends up dividing by zero and aborting via _bnerr().
 */
void
mp_msqrt(const MINT *nmp, MINT *xmp, MINT *rmp)
{
	BN_CTX *c;
	MINT *tolerance;
	MINT *ox, *x;
	MINT *z1, *z2, *z3;
	short i;

	c = BN_CTX_new();
	if (c == NULL)
		_bnerr("msqrt");
	tolerance = _itom("msqrt", 1);
	x = _itom("msqrt", 1);
	ox = _itom("msqrt", 0);
	z1 = _itom("msqrt", 0);
	z2 = _itom("msqrt", 0);
	z3 = _itom("msqrt", 0);
	do {
		_movem("msqrt", x, ox);
		_mdiv("msqrt", nmp, x, z1, z2, c);	/* z1 = n / x */
		_madd("msqrt", x, z1, z2);		/* z2 = x + n / x */
		_sdiv("msqrt", z2, 2, x, &i, c);	/* x = z2 / 2 */
		_msub("msqrt", ox, x, z3);		/* z3 = old x - new x */
	} while (_mcmpa("msqrt", z3, tolerance) == 1);

	/*
	 * The "|ox - x| <= 1" stopping rule can halt with x one too large.
	 * For n = 3 the loop ends at x = 2, and for n = 8 at x = 3, which
	 * would yield a remainder of -1.  The floored Newton iterate never
	 * drops below floor(sqrt(n)), so stepping down until x * x <= n
	 * gives the exact integer root.  Restricted to n >= 0 so the loop
	 * always terminates (at the latest at x = 0).
	 */
	_mult("msqrt", x, x, z1, c);			/* z1 = x * x */
	if (!BN_is_negative(nmp->bn)) {
		while (mp_mcmp(z1, nmp) > 0) {
			_msub("msqrt", x, tolerance, x);	/* x -= 1 */
			_mult("msqrt", x, x, z1, c);
		}
	}
	_movem("msqrt", x, xmp);
	_msub("msqrt", nmp, z1, z2);			/* z2 = n - x * x */
	_movem("msqrt", z2, rmp);
	_mfree("msqrt", tolerance);
	_mfree("msqrt", ox);
	_mfree("msqrt", x);
	_mfree("msqrt", z1);
	_mfree("msqrt", z2);
	_mfree("msqrt", z3);
	BN_CTX_free(c);
}
/*
 * rmp = mp1 - mp2.  The difference is built in a temporary and then
 * copied, so rmp may alias either operand.
 */
static void
_msub(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp)
{
	BIGNUM *diff;

	if ((diff = BN_new()) == NULL)
		_bnerr(msg);
	BN_ERRCHECK(msg, BN_sub(diff, mp1->bn, mp2->bn));
	_moveb(msg, diff, rmp);
	BN_free(diff);
}
/* Public wrapper: rmp = mp1 - mp2. */
void
mp_msub(const MINT *mp1, const MINT *mp2, MINT *rmp)
{
	const char *op = "msub";

	_msub(op, mp1, mp2, rmp);
}
/*
 * Return a newly malloc()ed decimal string for mp.  The caller releases
 * it with free().  The OpenSSL-allocated string from BN_bn2dec() is
 * copied so that it never escapes to callers that use free().
 *
 * The copy is done with malloc + memcpy and an explicit NULL check.  An
 * "s2 == NULL" check after asprintf() is not portable, because some
 * implementations leave the output pointer undefined on failure.
 */
static char *
_mtod(const char *msg, const MINT *mp)
{
	char *s, *s2;
	size_t len;

	s = BN_bn2dec(mp->bn);
	if (s == NULL)
		_bnerr(msg);
	len = strlen(s) + 1;			/* include the NUL */
	s2 = malloc(len);
	if (s2 == NULL)
		MPERR(("%s", msg));
	memcpy(s2, s, len);
	OPENSSL_free(s);
	return (s2);
}
/*
 * Return a newly malloc()ed lower-case hexadecimal string for mp.  The
 * caller releases it with free().
 *
 * The BN_bn2hex() result is copied with an explicit allocation check;
 * asprintf()'s output pointer is not portably NULL on failure.
 * Characters are passed to tolower() as unsigned char, as <ctype.h>
 * requires.
 */
static char *
_mtox(const char *msg, const MINT *mp)
{
	char *p, *s, *s2;
	size_t len;

	s = BN_bn2hex(mp->bn);
	if (s == NULL)
		_bnerr(msg);
	len = strlen(s) + 1;			/* include the NUL */
	s2 = malloc(len);
	if (s2 == NULL)
		MPERR(("%s", msg));
	memcpy(s2, s, len);
	OPENSSL_free(s);
	for (p = s2; *p != '\0'; p++)
		*p = (char)tolower((unsigned char)*p);
	return (s2);
}
/* Public wrapper: malloc()ed lower-case hex string for mp; free() it. */
char *
mp_mtox(const MINT *mp)
{
	char *hex;

	hex = _mtox("mtox", mp);
	return (hex);
}
/*
 * rmp = mp1 * mp2 using the scratch context c.  The product is built in
 * a temporary and then copied, so rmp may alias either operand.
 */
static void
_mult(const char *msg, const MINT *mp1, const MINT *mp2, MINT *rmp, BN_CTX *c)
{
	BIGNUM *prod;

	if ((prod = BN_new()) == NULL)
		_bnerr(msg);
	BN_ERRCHECK(msg, BN_mul(prod, mp1->bn, mp2->bn, c));
	_moveb(msg, prod, rmp);
	BN_free(prod);
}
/* Public wrapper: rmp = mp1 * mp2. */
void
mp_mult(const MINT *mp1, const MINT *mp2, MINT *rmp)
{
	BN_CTX *ctx;

	if ((ctx = BN_CTX_new()) == NULL)
		_bnerr("mult");
	_mult("mult", mp1, mp2, rmp, ctx);
	BN_CTX_free(ctx);
}
/* Modular exponentiation: rmp = bmp ^ emp mod mmp (BN_mod_exp()). */
void
mp_pow(const MINT *bmp, const MINT *emp, const MINT *mmp, MINT *rmp)
{
	BN_CTX *ctx;
	BIGNUM *res = NULL;

	ctx = BN_CTX_new();
	if (ctx == NULL || (res = BN_new()) == NULL)
		_bnerr("pow");
	BN_ERRCHECK("pow", BN_mod_exp(res, bmp->bn, emp->bn, mmp->bn, ctx));
	_moveb("pow", res, rmp);
	BN_free(res);
	BN_CTX_free(ctx);
}
/* Plain exponentiation by a short: rmp = bmp ^ e (BN_exp()). */
void
mp_rpow(const MINT *bmp, short e, MINT *rmp)
{
	BN_CTX *ctx;
	BIGNUM *power = NULL;
	MINT *expo;

	ctx = BN_CTX_new();
	if (ctx == NULL || (power = BN_new()) == NULL)
		_bnerr("rpow");
	expo = _itom("rpow", e);
	BN_ERRCHECK("rpow", BN_exp(power, bmp->bn, expo->bn, ctx));
	_moveb("rpow", power, rmp);
	_mfree("rpow", expo);
	BN_free(power);
	BN_CTX_free(ctx);
}
/*
 * Divide nmp by the short d: the quotient goes to qmp and the remainder
 * is returned through *ro.  The remainder is converted to a short by
 * rendering it as hex and parsing it back with strtol().
 */
static void
_sdiv(const char *msg, const MINT *nmp, short d, MINT *qmp, short *ro,
    BN_CTX *c)
{
	MINT *divisor, *rmint;
	BIGNUM *quot, *rem = NULL;
	char *hex;

	if ((quot = BN_new()) == NULL || (rem = BN_new()) == NULL)
		_bnerr(msg);
	divisor = _itom(msg, d);
	rmint = _itom(msg, 0);

	BN_ERRCHECK(msg, BN_div(quot, rem, nmp->bn, divisor->bn, c));
	_moveb(msg, quot, qmp);
	_moveb(msg, rem, rmint);

	hex = _mtox(msg, rmint);
	errno = 0;
	*ro = strtol(hex, NULL, 16);
	if (errno != 0)
		MPERR(("%s underflow or overflow", msg));
	free(hex);

	_mfree(msg, divisor);
	_mfree(msg, rmint);
	BN_free(rem);
	BN_free(quot);
}
/* Public wrapper: qmp = nmp / d, *ro = remainder. */
void
mp_sdiv(const MINT *nmp, short d, MINT *qmp, short *ro)
{
	BN_CTX *ctx;

	if ((ctx = BN_CTX_new()) == NULL)
		_bnerr("sdiv");
	_sdiv("sdiv", nmp, d, qmp, ro, ctx);
	BN_CTX_free(ctx);
}
/*
 * Allocate a fresh MINT whose value is parsed from the hexadecimal
 * string s.  Any allocation or BN_hex2bn() failure is fatal.
 */
static MINT *
_xtom(const char *msg, const char *s)
{
	MINT *res;

	if ((res = malloc(sizeof(*res))) == NULL)
		MPERR(("%s", msg));
	if ((res->bn = BN_new()) == NULL)
		_bnerr(msg);
	BN_ERRCHECK(msg, BN_hex2bn(&res->bn, s));
	return (res);
}
/* Public wrapper: new MINT parsed from the hex string s. */
MINT *
mp_xtom(const char *s)
{
	MINT *result;

	result = _xtom("xtom", s);
	return (result);
}