#include <crypto/mldsa.h>
#include <crypto/sha3.h>
#include <kunit/visibility.h>
#include <linux/export.h>
#include <linux/module.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/unaligned.h>
#include "fips-mldsa.h"
/* The prime modulus q = 2^23 - 2^13 + 1 */
#define Q 8380417
/* Multiplicative inverse of q modulo 2^32, used by Zq_mult() */
#define QINV_MOD_2_32 58728449
/* Number of coefficients in one ring element */
#define N 256
/* Left shift applied to every decoded t_1 coefficient in decode_t1_elem() */
#define D 13
/* Length in bytes of the seed rho that starts the public key */
#define RHO_LEN 32
/* Largest output of encode_w1(): N coefficients at 6 bits each, in bytes */
#define MAX_W1_ENCODED_LEN 192
/*
 * Twiddle factors consumed by ntt() (indices 1..255, ascending) and by
 * invntt_and_mul_2_32() (indices 255..1, descending, negated).  Entry 0 is
 * never read by either transform.
 *
 * As the name says, every entry carries an extra factor of 2^32 (mod q); that
 * factor cancels the 2^-32 introduced by Zq_mult().
 *
 * NOTE(review): presumably these are the FIPS 204 zetas, i.e. powers of the
 * root of unity 1753 in bit-reversed order, centered around zero -- confirm
 * against the script that generated the table.
 */
static const s32 zetas_times_2_32[N] = {
-4186625, 25847, -2608894, -518909, 237124, -777960, -876248,
466468, 1826347, 2353451, -359251, -2091905, 3119733, -2884855,
3111497, 2680103, 2725464, 1024112, -1079900, 3585928, -549488,
-1119584, 2619752, -2108549, -2118186, -3859737, -1399561, -3277672,
1757237, -19422, 4010497, 280005, 2706023, 95776, 3077325,
3530437, -1661693, -3592148, -2537516, 3915439, -3861115, -3043716,
3574422, -2867647, 3539968, -300467, 2348700, -539299, -1699267,
-1643818, 3505694, -3821735, 3507263, -2140649, -1600420, 3699596,
811944, 531354, 954230, 3881043, 3900724, -2556880, 2071892,
-2797779, -3930395, -1528703, -3677745, -3041255, -1452451, 3475950,
2176455, -1585221, -1257611, 1939314, -4083598, -1000202, -3190144,
-3157330, -3632928, 126922, 3412210, -983419, 2147896, 2715295,
-2967645, -3693493, -411027, -2477047, -671102, -1228525, -22981,
-1308169, -381987, 1349076, 1852771, -1430430, -3343383, 264944,
508951, 3097992, 44288, -1100098, 904516, 3958618, -3724342,
-8578, 1653064, -3249728, 2389356, -210977, 759969, -1316856,
189548, -3553272, 3159746, -1851402, -2409325, -177440, 1315589,
1341330, 1285669, -1584928, -812732, -1439742, -3019102, -3881060,
-3628969, 3839961, 2091667, 3407706, 2316500, 3817976, -3342478,
2244091, -2446433, -3562462, 266997, 2434439, -1235728, 3513181,
-3520352, -3759364, -1197226, -3193378, 900702, 1859098, 909542,
819034, 495491, -1613174, -43260, -522500, -655327, -3122442,
2031748, 3207046, -3556995, -525098, -768622, -3595838, 342297,
286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044,
2842341, 2691481, -2590150, 1265009, 4055324, 1247620, 2486353,
1595974, -3767016, 1250494, 2635921, -3548272, -2994039, 1869119,
1903435, -1050970, -1333058, 1237275, -3318210, -1430225, -451100,
1312455, 3306115, -1962642, -1279661, 1917081, -2546312, -1374803,
1500165, 777191, 2235880, 3406031, -542412, -2831860, -1671176,
-1846953, -2584293, -3724270, 594136, -3776993, -2013608, 2432395,
2454455, -164721, 1957272, 3369112, 185531, -1207385, -3183426,
162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107,
-3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735,
472078, -426683, 1723600, -1803090, 1910376, -1667432, -1104333,
-260646, -3833893, -2939036, -2235985, -420899, -2286327, 183443,
-976891, 1612842, -3545687, -554416, 3919660, -48306, -1362209,
3937738, 1400424, -846154, 1976782
};
/*
 * Per-algorithm constants, indexed by enum mldsa_alg.  mldsa_verify() looks up
 * its parameter set here and passes the individual fields to the helpers.
 */
static const struct mldsa_parameter_set {
/* Number of rows of the matrix A; also the number of hint/w1 elements */
u8 k;
/* Number of columns of the matrix A; also the number of elements in z */
u8 l;
/* Length in bytes of the commitment hash ctilde at the start of a signature */
u8 ctilde_len;
/* Maximum total number of 1's accepted by decode_hint_vector() */
u8 omega;
/* Number of nonzero (+1 or -1) coefficients produced by sample_in_ball() */
u8 tau;
/* decode_z() rejects any z coefficient with magnitude >= gamma1 - beta */
u8 beta;
/* Required public key length in bytes */
u16 pk_len;
/* Required signature length in bytes */
u16 sig_len;
/* Offset used to center the decoded z coefficients; see decode_z() */
s32 gamma1;
} mldsa_parameter_sets[] = {
[MLDSA44] = {
.k = 4,
.l = 4,
.ctilde_len = 32,
.omega = 80,
.tau = 39,
.beta = 78,
.pk_len = MLDSA44_PUBLIC_KEY_SIZE,
.sig_len = MLDSA44_SIGNATURE_SIZE,
.gamma1 = 1 << 17,
},
[MLDSA65] = {
.k = 6,
.l = 5,
.ctilde_len = 48,
.omega = 55,
.tau = 49,
.beta = 196,
.pk_len = MLDSA65_PUBLIC_KEY_SIZE,
.sig_len = MLDSA65_SIGNATURE_SIZE,
.gamma1 = 1 << 19,
},
[MLDSA87] = {
.k = 8,
.l = 7,
.ctilde_len = 64,
.omega = 75,
.tau = 60,
.beta = 120,
.pk_len = MLDSA87_PUBLIC_KEY_SIZE,
.sig_len = MLDSA87_SIGNATURE_SIZE,
.gamma1 = 1 << 19,
},
};
/* One ring element: N signed 32-bit coefficients. */
struct mldsa_ring_elem {
s32 x[N];
};
/*
 * Scratch memory for one call to mldsa_verify().  It is heap-allocated there
 * with room for l elements of z[] followed by k * N bytes for the hint vector.
 */
struct mldsa_verification_workspace {
/* SHAKE context used for the challenge c, then mu, then ctildeprime */
struct shake_ctx shake;
/*
 * These buffers share storage.  mldsa_verify() consumes each one before
 * it writes the next, so no two are needed at the same time.
 */
union {
/* SHAKE256 hash of the public key */
u8 tr[64];
/* Hash of tr, the two-byte message prefix and the message */
u8 mu[64];
/* Squeeze buffer for rej_ntt_poly(), plus one padding byte */
u8 block[SHAKE128_BLOCK_SIZE + 1];
/* Output buffer of encode_w1() */
u8 w1_encoded[MAX_W1_ENCODED_LEN];
/* Recomputed commitment hash; only params->ctilde_len bytes are used */
u8 ctildeprime[64];
};
/* SHAKE context used by rej_ntt_poly() to generate elements of A */
struct shake_ctx a_shake;
/*
 * Either the current matrix element from rej_ntt_poly() or the current
 * element of t_1 from decode_t1_elem(); they are used one after the other.
 */
union {
struct mldsa_ring_elem a;
struct mldsa_ring_elem t1_scaled;
};
/* The challenge from sample_in_ball(), after ntt() */
struct mldsa_ring_elem c;
/* Accumulator for one row of the computation in mldsa_verify() */
struct mldsa_ring_elem tmp;
/*
 * The l elements decoded by decode_z().  The hint vector h (k * N bytes)
 * is placed directly after z[l - 1] in the same allocation.
 */
struct mldsa_ring_elem z[];
};
/*
 * Montgomery multiplication: return a value congruent to a * b * 2^-32 mod q.
 *
 * A multiple of q is subtracted from the 64-bit product so that the low 32
 * bits of the difference are zero; the high 32 bits are then returned.  The
 * result is signed and is not brought into a canonical range.
 */
static inline s32 Zq_mult(s32 a, s32 b)
{
	const s64 product = (s64)a * b;
	/* Multiply as unsigned so that wrapping modulo 2^32 is well defined. */
	const s32 factor = (u32)product * QINV_MOD_2_32;
	const s64 adjusted = product - (s64)factor * Q;

	return adjusted >> 32;
}
/*
 * In-place forward number-theoretic transform of @w.
 *
 * Runs eight butterfly layers with half-block sizes 128, 64, ..., 1.  Each
 * block takes the next twiddle factor from zetas_times_2_32[], starting at
 * index 1.  Coefficients are not reduced modulo q here.
 */
static void ntt(struct mldsa_ring_elem *w)
{
	const s32 *zeta = &zetas_times_2_32[1];

	for (int half = 128; half > 0; half >>= 1) {
		for (int base = 0; base < 256; base += 2 * half) {
			const s32 z = *zeta++;
			s32 *lo = &w->x[base];
			s32 *hi = lo + half;

			for (int n = 0; n < half; n++) {
				const s32 t = Zq_mult(z, hi[n]);

				hi[n] = lo[n] - t;
				lo[n] += t;
			}
		}
	}
}
/*
 * In-place inverse number-theoretic transform of @w.  As the name says, the
 * result is left with an extra factor of 2^32.
 *
 * The butterfly layers use half-block sizes 1, 2, ..., 128 and walk
 * zetas_times_2_32[] backwards from index 255, negating each entry.
 */
static void invntt_and_mul_2_32(struct mldsa_ring_elem *w)
{
	const s32 *zeta = &zetas_times_2_32[256];

	/*
	 * Bring every coefficient to magnitude < q first, to keep the sums
	 * built up by the butterflies below from overflowing s32.
	 */
	for (int n = 0; n < 256; n++)
		w->x[n] %= Q;

	for (int half = 1; half < 256; half <<= 1) {
		for (int base = 0; base < 256; base += 2 * half) {
			const s32 z = -*--zeta;
			s32 *lo = &w->x[base];
			s32 *hi = lo + half;

			for (int n = 0; n < half; n++) {
				const s32 u = lo[n];
				const s32 v = hi[n];

				lo[n] = u + v;
				hi[n] = Zq_mult(z, u - v);
			}
		}
	}

	/*
	 * Final scaling by the constant 41978 (NOTE(review): presumably
	 * 2^64 / 256 mod q, which gives the 1/256 normalization and leaves the
	 * factor of 2^32 -- confirm).  Afterwards q is added to every negative
	 * coefficient.
	 */
	for (int n = 0; n < 256; n++) {
		const s32 scaled = Zq_mult(w->x[n], 41978);

		w->x[n] = scaled + ((scaled >> 31) & Q);
	}
}
/*
 * Decode one element of t_1 from @t1_encoded into @out.
 *
 * The encoding packs four 10-bit coefficients into every 5 bytes, in
 * little-endian bit order.  Each coefficient is shifted left by D, and the
 * element is then transformed with ntt().
 *
 * Return: pointer to the byte following the consumed input (N / 4 * 5 bytes).
 */
static const u8 *decode_t1_elem(struct mldsa_ring_elem *out,
				const u8 *t1_encoded)
{
	const u8 *src = t1_encoded;

	/* The largest possible scaled coefficient is still below q. */
	static_assert(0x3ff << D < Q);

	for (int base = 0; base < N; base += 4) {
		const u32 low = get_unaligned_le32(src);
		/* Fourth coefficient: top 2 bits of 'low' plus the fifth byte. */
		const u32 top = (low >> 30) | (src[4] << 2);

		out->x[base + 0] = ((low >> 0) & 0x3ff) << D;
		out->x[base + 1] = ((low >> 10) & 0x3ff) << D;
		out->x[base + 2] = ((low >> 20) & 0x3ff) << D;
		out->x[base + 3] = top << D;
		src += 5;
	}
	ntt(out);
	return src;
}
/*
 * Decode the @l elements of the response vector z from the signature.
 *
 * Coefficients are packed as 18-bit fields (4 per 9 bytes) when @l is 4 and
 * as 20-bit fields (4 per 10 bytes) otherwise.  Each raw value v is mapped to
 * gamma1 - v.  Decoding fails as soon as a mapped coefficient has magnitude
 * >= gamma1 - beta.  Every fully checked element is transformed with ntt().
 *
 * On success *@sig_ptr is advanced past the encoded vector.  On failure it is
 * left unchanged and @z is only partially filled.
 *
 * Return: true if all coefficients are within bounds, false otherwise.
 */
static bool decode_z(struct mldsa_ring_elem z[], int l, s32 gamma1,
		     int beta, const u8 **sig_ptr)
{
	const u8 *src = *sig_ptr;
	const int limit = gamma1 - beta;

	for (int row = 0; row < l; row++) {
		s32 *coef = z[row].x;

		if (l == 4) {
			for (int n = 0; n < N; n += 4) {
				const u64 bits = get_unaligned_le64(src);

				coef[n + 0] = (bits >> 0) & 0x3ffff;
				coef[n + 1] = (bits >> 18) & 0x3ffff;
				coef[n + 2] = (bits >> 36) & 0x3ffff;
				/* Top 10 bits of 'bits' plus the ninth byte. */
				coef[n + 3] = (bits >> 54) | (src[8] << 10);
				src += 9;
			}
		} else {
			for (int n = 0; n < N; n += 4) {
				const u64 bits = get_unaligned_le64(src);

				coef[n + 0] = (bits >> 0) & 0xfffff;
				coef[n + 1] = (bits >> 20) & 0xfffff;
				coef[n + 2] = (bits >> 40) & 0xfffff;
				/* Top 4 bits of 'bits' plus bytes 8 and 9. */
				coef[n + 3] =
					(bits >> 60) |
					(get_unaligned_le16(&src[8]) << 4);
				src += 10;
			}
		}

		for (int n = 0; n < N; n++) {
			const s32 centered = gamma1 - coef[n];

			coef[n] = centered;
			if (centered <= -limit || centered >= limit)
				return false;
		}
		ntt(&z[row]);
	}
	*sig_ptr = src;
	return true;
}
/*
 * Decode the hint vector from its encoding @y into @h (@k * N bytes, each 0
 * or 1).
 *
 * @y holds @omega position bytes followed by @k cumulative counts.  The count
 * for element i is the number of position bytes that belong to elements
 * 0 through i.  The encoding is rejected when:
 *  - a cumulative count decreases or exceeds @omega,
 *  - the positions within one element are not strictly increasing, or
 *  - any unused position byte is nonzero.
 *
 * Return: true if the encoding is well-formed, false otherwise.
 */
static bool decode_hint_vector(u8 h[], int k, int omega, const u8 *y)
{
	const u8 *counts = &y[omega];
	int used = 0;

	memset(h, 0, k * N);
	for (int row = 0; row < k; row++) {
		const int end = counts[row];
		u8 *bits = &h[row * N];
		int last = -1;

		if (end < used || end > omega)
			return false;
		while (used < end) {
			const int pos = y[used];

			if (pos <= last)
				return false;
			last = pos;
			bits[pos] = 1;
			used++;
		}
	}
	return mem_is_zero(&y[used], omega - used);
}
static void sample_in_ball(struct mldsa_ring_elem *c, const u8 *seed,
size_t seed_len, int tau, struct shake_ctx *shake)
{
u64 signs;
u8 j;
shake256_init(shake);
shake_update(shake, seed, seed_len);
shake_squeeze(shake, (u8 *)&signs, sizeof(signs));
le64_to_cpus(&signs);
*c = (struct mldsa_ring_elem){};
for (int i = N - tau; i < N; i++, signs >>= 1) {
do {
shake_squeeze(shake, &j, 1);
} while (j > i);
c->x[i] = c->x[j];
c->x[j] = 1 - 2 * (s32)(signs & 1);
}
}
/*
 * Fill @out with N coefficients in [0, q) by rejection sampling from SHAKE128
 * applied to @rho followed by the two bytes of @row_and_column.
 *
 * The output stream is consumed three bytes at a time.  Each group is read as
 * a little-endian integer, masked to 23 bits, and kept only if it is below q.
 *
 * @shake is scratch state.  @block is a scratch buffer of one SHAKE128 block
 * plus one extra byte.
 */
static void rej_ntt_poly(struct mldsa_ring_elem *out, const u8 rho[RHO_LEN],
			 __le16 row_and_column, struct shake_ctx *shake,
			 u8 block[SHAKE128_BLOCK_SIZE + 1])
{
	int filled = 0;

	/* A block must hold a whole number of 3-byte groups. */
	static_assert(SHAKE128_BLOCK_SIZE % 3 == 0);

	shake128_init(shake);
	shake_update(shake, rho, RHO_LEN);
	shake_update(shake, (u8 *)&row_and_column, sizeof(row_and_column));

	while (filled < N) {
		shake_squeeze(shake, block, SHAKE128_BLOCK_SIZE);
		/*
		 * The 4-byte load for the last group reads one byte past the
		 * block.  That byte is masked off below, but it is zeroed here
		 * so the load never reads uninitialized memory.
		 */
		block[SHAKE128_BLOCK_SIZE] = 0;

		for (int off = 0; off < SHAKE128_BLOCK_SIZE; off += 3) {
			u32 cand;

			if (filled == N)
				break;
			cand = get_unaligned_le32(&block[off]) & 0x7fffff;
			if (cand >= Q)
				continue;
			out->x[filled++] = cand;
		}
	}
}
/*
 * Return the high part of @r with respect to 2 * @gamma2, adjusted by the
 * hint bit @h.
 *
 * With m = (q - 1) / (2 * gamma2), the result is in [0, m).  If @h is 0 the
 * high part r1 is returned unchanged.  If @h is nonzero it is moved by one
 * (mod m): up when r lies above r1 * 2 * gamma2, down otherwise.
 *
 * NOTE(review): assumes 0 <= r < q -- confirm against callers.
 *
 * Forced inline so that a constant @gamma2 lets the compiler fold the
 * divisions.
 */
static __always_inline s32 use_hint(u8 h, s32 r, const s32 gamma2)
{
	const s32 width = 2 * gamma2;
	const s32 m = (Q - 1) / width;
	s32 r1;

	/*
	 * Top of the range: here the rounding division below would give m
	 * rather than wrapping to 0, so this case is handled separately.
	 */
	if (r >= Q - gamma2)
		return h ? m - 1 : 0;

	/* Rounding division; r == r1 * width + r0 with -gamma2 < r0 <= gamma2. */
	r1 = (u32)(r + gamma2 - 1) / width;
	if (!h)
		return r1;
	if (r > r1 * width)
		return (u32)(r1 + 1) % m;
	return (u32)(r1 + m - 1) % m;
}
/*
 * Apply use_hint() to every coefficient of @w in place, using the matching
 * byte of the hint element @h.
 */
static __always_inline void use_hint_elem(struct mldsa_ring_elem *w,
					  const u8 h[N], const s32 gamma2)
{
	const u8 *hint = h;

	for (s32 *coef = w->x; coef < &w->x[N]; coef++, hint++)
		*coef = use_hint(*hint, *coef, gamma2);
}
/*
 * Non-inline wrapper around use_hint(), built only when the ML-DSA KUnit test
 * is enabled and exported only for KUnit.
 */
#if IS_ENABLED(CONFIG_CRYPTO_LIB_MLDSA_KUNIT_TEST)
s32 mldsa_use_hint(u8 h, s32 r, s32 gamma2)
{
return use_hint(h, r, gamma2);
}
EXPORT_SYMBOL_IF_KUNIT(mldsa_use_hint);
#endif
/*
 * Pack the coefficients of @w1 into @out in little-endian bit order.
 *
 * When @k is 4, each coefficient takes 6 bits (4 coefficients per 3 bytes).
 * Otherwise each takes 4 bits (2 coefficients per byte).  Coefficients are
 * not masked here, so they must already fit in the field width.
 *
 * Return: number of bytes written to @out.
 */
static size_t encode_w1(u8 out[MAX_W1_ENCODED_LEN],
			const struct mldsa_ring_elem *w1, int k)
{
	u8 *dst = out;

	/* The 6-bit packing is the larger one and exactly fills the buffer. */
	static_assert(N * 6 / 8 == MAX_W1_ENCODED_LEN);

	if (k != 4) {
		for (int n = 0; n < N; n += 2)
			*dst++ = w1->x[n] | (w1->x[n + 1] << 4);
		return dst - out;
	}

	for (int n = 0; n < N; n += 4) {
		const u32 packed = (w1->x[n + 0] << 0) | (w1->x[n + 1] << 6) |
				   (w1->x[n + 2] << 12) | (w1->x[n + 3] << 18);

		*dst++ = packed >> 0;
		*dst++ = packed >> 8;
		*dst++ = packed >> 16;
	}
	return dst - out;
}
/*
 * mldsa_verify() - verify an ML-DSA signature
 * @alg: parameter set; used directly as an index into mldsa_parameter_sets[]
 * @sig: the signature
 * @sig_len: length of @sig in bytes
 * @msg: the message
 * @msg_len: length of @msg in bytes
 * @pk: the public key
 * @pk_len: length of @pk in bytes
 *
 * Allocates a temporary workspace with GFP_KERNEL.  The workspace is freed
 * with kfree_sensitive() on every return path once it has been allocated.
 *
 * Return: 0 if the signature is valid, -EBADMSG if a length is wrong or the
 * signature is malformed, -ENOMEM if the workspace cannot be allocated, or
 * -EKEYREJECTED if the recomputed commitment hash does not match.
 */
int mldsa_verify(enum mldsa_alg alg, const u8 *sig, size_t sig_len,
const u8 *msg, size_t msg_len, const u8 *pk, size_t pk_len)
{
const struct mldsa_parameter_set *params = &mldsa_parameter_sets[alg];
const int k = params->k, l = params->l;
/*
 * Two zero bytes hashed between tr and the message.
 * NOTE(review): presumably the domain separator and the context length
 * for pure ML-DSA with an empty context string -- confirm against FIPS 204.
 */
static const u8 msg_prefix[2] = { 0, 0 };
/* The commitment hash at the start of the signature */
const u8 *ctilde;
/* Next encoded element of t_1; t_1 follows rho in the public key */
const u8 *t1_encoded = &pk[RHO_LEN];
/* The decoded hint vector, k * N bytes */
u8 *h;
size_t w1_enc_len;
/* Both lengths are fixed by the parameter set. */
if (pk_len != params->pk_len || sig_len != params->sig_len)
return -EBADMSG;
/* Workspace plus l elements of z[] plus k * N bytes for the hint vector */
struct mldsa_verification_workspace *ws __free(kfree_sensitive) =
kmalloc(sizeof(*ws) + (l * sizeof(ws->z[0])) + (k * N),
GFP_KERNEL);
if (!ws)
return -ENOMEM;
/* The hint vector lives directly after the last element of z[]. */
h = (u8 *)&ws->z[l];
/* Split the signature into ctilde, z and the encoded hint vector. */
ctilde = sig;
sig += params->ctilde_len;
if (!decode_z(ws->z, l, params->gamma1, params->beta, &sig))
return -EBADMSG;
if (!decode_hint_vector(h, k, params->omega, sig))
return -EBADMSG;
/* Rebuild the challenge c from ctilde and transform it. */
sample_in_ball(&ws->c, ctilde, params->ctilde_len, params->tau,
&ws->shake);
ntt(&ws->c);
/* tr = SHAKE256(pk); mu = SHAKE256(tr || msg_prefix || msg) */
shake256(pk, pk_len, ws->tr, sizeof(ws->tr));
shake256_init(&ws->shake);
shake_update(&ws->shake, ws->tr, sizeof(ws->tr));
shake_update(&ws->shake, msg_prefix, sizeof(msg_prefix));
shake_update(&ws->shake, msg, msg_len);
shake_squeeze(&ws->shake, ws->mu, sizeof(ws->mu));
/* Start hashing ctildeprime: mu first, then each encoded w1 element. */
shake256_init(&ws->shake);
shake_update(&ws->shake, ws->mu, sizeof(ws->mu));
/* Process one row of the matrix A per iteration. */
for (int i = 0; i < k; i++) {
/* tmp = sum over j of A[i][j] * z[j], coefficient-wise via Zq_mult() */
ws->tmp = (struct mldsa_ring_elem){};
for (int j = 0; j < l; j++) {
/* Regenerate A[i][j] from rho, which is the first RHO_LEN bytes of pk. */
rej_ntt_poly(&ws->a, pk ,
cpu_to_le16((i << 8) | j), &ws->a_shake,
ws->block);
for (int n = 0; n < N; n++)
ws->tmp.x[n] +=
Zq_mult(ws->a.x[n], ws->z[j].x[n]);
}
/* This overwrites ws->a, which shares storage with ws->t1_scaled. */
t1_encoded = decode_t1_elem(&ws->t1_scaled, t1_encoded);
/* tmp -= c * t1_scaled, coefficient-wise via Zq_mult() */
for (int j = 0; j < N; j++)
ws->tmp.x[j] -= Zq_mult(ws->c.x[j], ws->t1_scaled.x[j]);
/* Inverse transform; this also adds q to any negative coefficient. */
invntt_and_mul_2_32(&ws->tmp);
/*
 * Apply the hints for this row.  gamma2 is passed as a literal
 * constant in each branch; it is (q - 1) / 88 when k is 4 and
 * (q - 1) / 32 otherwise.
 */
if (k == 4)
use_hint_elem(&ws->tmp, &h[i * N], (Q - 1) / 88);
else
use_hint_elem(&ws->tmp, &h[i * N], (Q - 1) / 32);
/* Encode the result and feed it into the ctildeprime hash. */
w1_enc_len = encode_w1(ws->w1_encoded, &ws->tmp, k);
shake_update(&ws->shake, ws->w1_encoded, w1_enc_len);
}
/* Finish ctildeprime and compare it with the signer's ctilde. */
shake_squeeze(&ws->shake, ws->ctildeprime, params->ctilde_len);
if (memcmp(ws->ctildeprime, ctilde, params->ctilde_len) != 0)
return -EKEYREJECTED;
/* The bound on the coefficients of z was already enforced by decode_z(). */
return 0;
}
EXPORT_SYMBOL_GPL(mldsa_verify);
#ifdef CONFIG_CRYPTO_FIPS
/*
 * Self-test run at subsys_initcall time.  When fips_enabled is set, verify the
 * built-in ML-DSA-65 test vector and panic if verification fails.  Does
 * nothing otherwise.
 *
 * Return: always 0.
 */
static int __init mldsa_mod_init(void)
{
	int err;

	if (!fips_enabled)
		return 0;

	err = mldsa_verify(MLDSA65, fips_test_mldsa65_signature,
			   sizeof(fips_test_mldsa65_signature),
			   fips_test_mldsa65_message,
			   sizeof(fips_test_mldsa65_message),
			   fips_test_mldsa65_public_key,
			   sizeof(fips_test_mldsa65_public_key));
	if (err)
		panic("mldsa: FIPS self-test failed; err=%pe\n",
		      ERR_PTR(err));
	return 0;
}
subsys_initcall(mldsa_mod_init);
/* Module exit handler; intentionally empty, as there is nothing to tear down. */
static void __exit mldsa_mod_exit(void)
{
}
module_exit(mldsa_mod_exit);
#endif
MODULE_DESCRIPTION("ML-DSA signature verification");
MODULE_LICENSE("GPL");