#include <stdlib.h>
#include <string.h>
#include <openssl/mlkem.h>
#include "mlkem_internal.h"
static inline int
private_key_is_new(const MLKEM_private_key *key)
{
return (key != NULL &&
key->state == MLKEM_PRIVATE_KEY_UNINITIALIZED &&
(key->rank == MLKEM768_RANK || key->rank == MLKEM1024_RANK));
}
static inline int
private_key_is_valid(const MLKEM_private_key *key)
{
return (key != NULL &&
key->state == MLKEM_PRIVATE_KEY_INITIALIZED &&
(key->rank == MLKEM768_RANK || key->rank == MLKEM1024_RANK));
}
static inline int
public_key_is_new(const MLKEM_public_key *key)
{
return (key != NULL &&
key->state == MLKEM_PUBLIC_KEY_UNINITIALIZED &&
(key->rank == MLKEM768_RANK || key->rank == MLKEM1024_RANK));
}
static inline int
public_key_is_valid(const MLKEM_public_key *key)
{
return (key != NULL &&
key->state == MLKEM_PUBLIC_KEY_INITIALIZED &&
(key->rank == MLKEM768_RANK || key->rank == MLKEM1024_RANK));
}
int
MLKEM_generate_key_external_entropy(MLKEM_private_key *private_key,
uint8_t **out_encoded_public_key, size_t *out_encoded_public_key_len,
const uint8_t *entropy)
{
uint8_t *k = NULL;
size_t k_len = 0;
int ret = 0;
if (*out_encoded_public_key != NULL)
goto err;
if (!private_key_is_new(private_key))
goto err;
k_len = MLKEM768_PUBLIC_KEY_BYTES;
if (private_key->rank == MLKEM1024_RANK)
k_len = MLKEM1024_PUBLIC_KEY_BYTES;
if ((k = calloc(1, k_len)) == NULL)
goto err;
if (!mlkem_generate_key_external_entropy(k, private_key, entropy))
goto err;
private_key->state = MLKEM_PRIVATE_KEY_INITIALIZED;
*out_encoded_public_key = k;
*out_encoded_public_key_len = k_len;
k = NULL;
k_len = 0;
ret = 1;
err:
freezero(k, k_len);
return ret;
}
int
MLKEM_generate_key(MLKEM_private_key *private_key,
uint8_t **out_encoded_public_key, size_t *out_encoded_public_key_len,
uint8_t **out_optional_seed, size_t *out_optional_seed_len)
{
uint8_t *entropy_buf = NULL;
int ret = 0;
if (*out_encoded_public_key != NULL)
goto err;
if (out_optional_seed != NULL && *out_optional_seed != NULL)
goto err;
if ((entropy_buf = calloc(1, MLKEM_SEED_LENGTH)) == NULL)
goto err;
arc4random_buf(entropy_buf, MLKEM_SEED_LENGTH);
if (!MLKEM_generate_key_external_entropy(private_key,
out_encoded_public_key, out_encoded_public_key_len,
entropy_buf))
goto err;
if (out_optional_seed != NULL) {
*out_optional_seed = entropy_buf;
*out_optional_seed_len = MLKEM_SEED_LENGTH;
entropy_buf = NULL;
}
ret = 1;
err:
freezero(entropy_buf, MLKEM_SEED_LENGTH);
return ret;
}
LCRYPTO_ALIAS(MLKEM_generate_key);
int
MLKEM_private_key_from_seed(MLKEM_private_key *private_key,
const uint8_t *seed, size_t seed_len)
{
int ret = 0;
if (!private_key_is_new(private_key))
goto err;
if (seed_len != MLKEM_SEED_LENGTH)
goto err;
if (!mlkem_private_key_from_seed(seed, seed_len, private_key))
goto err;
private_key->state = MLKEM_PRIVATE_KEY_INITIALIZED;
ret = 1;
err:
return ret;
}
LCRYPTO_ALIAS(MLKEM_private_key_from_seed);
int
MLKEM_public_from_private(const MLKEM_private_key *private_key,
MLKEM_public_key *public_key)
{
if (!private_key_is_valid(private_key))
return 0;
if (!public_key_is_new(public_key))
return 0;
if (public_key->rank != private_key->rank)
return 0;
mlkem_public_from_private(private_key, public_key);
public_key->state = MLKEM_PUBLIC_KEY_INITIALIZED;
return 1;
}
LCRYPTO_ALIAS(MLKEM_public_from_private);
int
MLKEM_encap_external_entropy(const MLKEM_public_key *public_key,
const uint8_t *entropy, uint8_t **out_ciphertext,
size_t *out_ciphertext_len, uint8_t **out_shared_secret,
size_t *out_shared_secret_len)
{
uint8_t *secret = NULL;
uint8_t *ciphertext = NULL;
size_t ciphertext_len = 0;
int ret = 0;
if (*out_ciphertext != NULL)
goto err;
if (*out_shared_secret != NULL)
goto err;
if (!public_key_is_valid(public_key))
goto err;
if ((secret = calloc(1, MLKEM_SHARED_SECRET_LENGTH)) == NULL)
goto err;
ciphertext_len = MLKEM_public_key_ciphertext_length(public_key);
if ((ciphertext = calloc(1, ciphertext_len)) == NULL)
goto err;
mlkem_encap_external_entropy(ciphertext, secret, public_key, entropy);
*out_ciphertext = ciphertext;
*out_ciphertext_len = ciphertext_len;
ciphertext = NULL;
*out_shared_secret = secret;
*out_shared_secret_len = MLKEM_SHARED_SECRET_LENGTH;
secret = NULL;
ret = 1;
err:
freezero(secret, MLKEM_SHARED_SECRET_LENGTH);
freezero(ciphertext, ciphertext_len);
return ret;
}
int
MLKEM_encap(const MLKEM_public_key *public_key,
uint8_t **out_ciphertext, size_t *out_ciphertext_len,
uint8_t **out_shared_secret, size_t *out_shared_secret_len)
{
uint8_t entropy[MLKEM_ENCAP_ENTROPY];
int ret;
arc4random_buf(entropy, sizeof(entropy));
ret = MLKEM_encap_external_entropy(public_key, entropy, out_ciphertext,
out_ciphertext_len, out_shared_secret, out_shared_secret_len);
explicit_bzero(entropy, sizeof(entropy));
return ret;
}
LCRYPTO_ALIAS(MLKEM_encap);
int
MLKEM_decap(const MLKEM_private_key *private_key,
const uint8_t *ciphertext, size_t ciphertext_len,
uint8_t **out_shared_secret, size_t *out_shared_secret_len)
{
uint8_t *s = NULL;
int ret = 0;
if (*out_shared_secret != NULL)
goto err;
if (!private_key_is_valid(private_key))
goto err;
if (ciphertext_len != MLKEM_private_key_ciphertext_length(private_key))
goto err;
if ((s = calloc(1, MLKEM_SHARED_SECRET_LENGTH)) == NULL)
goto err;
mlkem_decap(private_key, ciphertext, ciphertext_len, s);
*out_shared_secret = s;
*out_shared_secret_len = MLKEM_SHARED_SECRET_LENGTH;
s = NULL;
ret = 1;
err:
freezero(s, MLKEM_SHARED_SECRET_LENGTH);
return ret;
}
LCRYPTO_ALIAS(MLKEM_decap);
int
MLKEM_marshal_public_key(const MLKEM_public_key *public_key, uint8_t **out,
size_t *out_len)
{
if (*out != NULL)
return 0;
if (!public_key_is_valid(public_key))
return 0;
return mlkem_marshal_public_key(public_key, out, out_len);
}
LCRYPTO_ALIAS(MLKEM_marshal_public_key);
int
MLKEM_marshal_private_key(const MLKEM_private_key *private_key, uint8_t **out,
size_t *out_len)
{
if (*out != NULL)
return 0;
if (!private_key_is_valid(private_key))
return 0;
return mlkem_marshal_private_key(private_key, out, out_len);
}
LCRYPTO_ALIAS(MLKEM_marshal_private_key);
int
MLKEM_parse_public_key(MLKEM_public_key *public_key, const uint8_t *in,
size_t in_len)
{
if (!public_key_is_new(public_key))
return 0;
if (in_len != MLKEM_public_key_encoded_length(public_key))
return 0;
if (!mlkem_parse_public_key(in, in_len, public_key))
return 0;
public_key->state = MLKEM_PUBLIC_KEY_INITIALIZED;
return 1;
}
LCRYPTO_ALIAS(MLKEM_parse_public_key);
int
MLKEM_parse_private_key(MLKEM_private_key *private_key, const uint8_t *in,
size_t in_len)
{
if (!private_key_is_new(private_key))
return 0;
if (in_len != MLKEM_private_key_encoded_length(private_key))
return 0;
if (!mlkem_parse_private_key(in, in_len, private_key))
return 0;
private_key->state = MLKEM_PRIVATE_KEY_INITIALIZED;
return 1;
}
LCRYPTO_ALIAS(MLKEM_parse_private_key);