#include <err.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <openssl/mlkem.h>
#include "mlkem_internal.h"
#include "mlkem_tests_util.h"
#include "sha3_internal.h"
/*
 * Expected first 16 bytes of the key-generation seed that is squeezed from
 * the SHAKE128 DRNG on iteration 0.  Used as a quick sanity check that the
 * DRNG stream starts where it should before the long run begins.
 */
static const uint8_t kExpectedSeedStart[16] = {
	0x7f, 0x9c, 0x2b, 0xa4, 0xe8, 0x8f, 0x82, 0x7d, 0x61, 0x60, 0x45,
	0x50, 0x76, 0x05, 0x85, 0x3e
};

/*
 * Expected 32-byte SHAKE128 digest of every value produced over the whole
 * iterative run for ML-KEM-768 (public keys, marshalled private keys,
 * ciphertexts and shared secrets).
 */
static const uint8_t kExpectedAdam768[32] = {
	0xf9, 0x59, 0xd1, 0x8d, 0x3d, 0x11, 0x80, 0x12, 0x14, 0x33, 0xbf,
	0x0e, 0x05, 0xf1, 0x1e, 0x79, 0x08, 0xcf, 0x9d, 0x03, 0xed, 0xc1,
	0x50, 0xb2, 0xb0, 0x7c, 0xb9, 0x0b, 0xef, 0x5b, 0xc1, 0xc1
};

/* Same as kExpectedAdam768, but for the ML-KEM-1024 run. */
static const uint8_t kExpectedAdam1024[32] = {
	0xe3, 0xbf, 0x82, 0xb0, 0x13, 0x30, 0x7b, 0x2e, 0x9d, 0x47, 0xdd,
	0xe7, 0x91, 0xff, 0x6d, 0xfc, 0x82, 0xe6, 0x94, 0xe6, 0x38, 0x24,
	0x04, 0xab, 0xdb, 0x94, 0x8b, 0x90, 0x8b, 0x75, 0xba, 0xd5
};
/*
 * Perform one round of the iterative ML-KEM test for the given rank.
 *
 * Every pseudo-random input is squeezed from the SHAKE128 stream in `drng`:
 * the key-generation seed, the encapsulation entropy and a garbage
 * ciphertext.  Every output is absorbed into `results`: the encoded public
 * key, the marshalled private key, the ciphertext and both shared secrets.
 * The final digest depends on the exact order of these calls, so the
 * sequence below must not be rearranged.  On iteration 0 the seed is also
 * compared against `start`.  Any failure terminates the program via errx().
 */
static void
mlkem_iterative_round(int rank, int iter, sha3_ctx *drng, sha3_ctx *results,
    const uint8_t *start, size_t start_len)
{
	MLKEM_private_key *private_key;
	MLKEM_public_key *public_key;
	uint8_t keygen_seed[MLKEM_SEED_LENGTH] = {0};
	uint8_t entropy[MLKEM_ENCAP_ENTROPY];
	uint8_t *pub_bytes = NULL, *priv_bytes = NULL;
	uint8_t *ctext = NULL, *bogus_ctext = NULL, *secret = NULL;
	size_t pub_bytes_len, priv_bytes_len, ctext_len, bogus_ctext_len;
	size_t secret_len = 0;

	private_key = MLKEM_private_key_new(rank);
	if (private_key == NULL)
		errx(1, "malloc");
	public_key = MLKEM_public_key_new(rank);
	if (public_key == NULL)
		errx(1, "malloc");

	/* Deterministic key generation from a DRNG-derived seed. */
	shake_out(drng, keygen_seed, sizeof(keygen_seed));
	if (iter == 0 &&
	    compare_data(keygen_seed, start, start_len, "seed start") != 0)
		errx(1, "compare_data");
	if (!MLKEM_generate_key_external_entropy(private_key, &pub_bytes,
	    &pub_bytes_len, keygen_seed))
		errx(1, "generate_key_external_entropy");
	if (!MLKEM_public_from_private(private_key, public_key))
		errx(1, "public_from_private");
	shake_update(results, pub_bytes, pub_bytes_len);

	if (!MLKEM_marshal_private_key(private_key, &priv_bytes,
	    &priv_bytes_len))
		errx(1, "marshal private key");
	shake_update(results, priv_bytes, priv_bytes_len);
	freezero(priv_bytes, priv_bytes_len);

	/* Deterministic encapsulation against the derived public key. */
	shake_out(drng, entropy, sizeof(entropy));
	if (!MLKEM_encap_external_entropy(public_key, entropy, &ctext,
	    &ctext_len, &secret, &secret_len))
		errx(1, "encap_external_entropy");
	shake_update(results, ctext, ctext_len);
	shake_update(results, secret, secret_len);
	freezero(secret, secret_len);
	secret = NULL;

	/*
	 * Decapsulate a random, almost certainly invalid, ciphertext of the
	 * same length; the call is still required to succeed.
	 */
	bogus_ctext_len = ctext_len;
	bogus_ctext = calloc(1, bogus_ctext_len);
	if (bogus_ctext == NULL)
		errx(1, "malloc");
	shake_out(drng, bogus_ctext, bogus_ctext_len);
	if (!MLKEM_decap(private_key, bogus_ctext, bogus_ctext_len, &secret,
	    &secret_len))
		errx(1, "decap failed, iteration %d", iter);
	shake_update(results, secret, secret_len);
	freezero(secret, secret_len);
	freezero(bogus_ctext, bogus_ctext_len);

	MLKEM_private_key_free(private_key);
	MLKEM_public_key_free(public_key);
	freezero(pub_bytes, pub_bytes_len);
	freezero(ctext, ctext_len);
}

/*
 * Iterative ML-KEM test for the given rank.
 *
 * A SHAKE128 DRNG (nothing absorbed) drives 10000 rounds of keygen, encap
 * and decap.  All outputs are hashed into a second SHAKE128 context, and the
 * resulting 32-byte digest is compared against the expected value for the
 * rank.  Returns compare_data()'s result for that comparison.  An
 * unsupported rank, or any failure inside a round, terminates via errx().
 */
static int
MlkemIterativeTest(int rank)
{
	sha3_ctx drng, results;
	const uint8_t *expected;
	uint8_t digest[32];
	int iter;

	if (rank == MLKEM768_RANK)
		expected = kExpectedAdam768;
	else if (rank == MLKEM1024_RANK)
		expected = kExpectedAdam1024;
	else
		errx(1, "invalid rank %d", rank);

	shake128_init(&drng);
	shake128_init(&results);
	/* Switch the DRNG straight to squeezing without absorbing input. */
	shake_xof(&drng);

	for (iter = 0; iter < 10000; iter++)
		mlkem_iterative_round(rank, iter, &drng, &results,
		    kExpectedSeedStart, sizeof(kExpectedSeedStart));

	shake_xof(&results);
	shake_out(&results, digest, sizeof(digest));

	return compare_data(expected, digest, sizeof(digest),
	    "final result hash");
}
/*
 * Run the iterative test for ML-KEM-768 and then ML-KEM-1024.  The exit
 * status is the bitwise OR of both results, so it is non-zero if either
 * run reported a mismatch.
 */
int
main(void)
{
	int ret;

	ret = MlkemIterativeTest(MLKEM768_RANK);
	ret |= MlkemIterativeTest(MLKEM1024_RANK);

	return ret;
}