#include <string.h>
#include <openssl/err.h>
#include <openssl/core_dispatch.h>
#include <openssl/core_names.h>
#include <openssl/params.h>
#include <openssl/rand.h>
#include <openssl/proverr.h>
#include "slh_dsa_local.h"
#include "slh_dsa_key.h"
#include "internal/encoder.h"
/*
 * Forward declaration; the definition appears later in this file.
 * When |validate| is non-zero the computed root is compared against the
 * root already held in |out| instead of being written into it.
 */
static int slh_dsa_compute_pk_root(SLH_DSA_HASH_CTX *ctx, SLH_DSA_KEY *out,
                                   int validate);
/*
 * Release the property query string and the fetched digest / MAC
 * algorithms held by |key|.
 *
 * Every released field is reset to NULL so that calling this function
 * more than once on the same key is harmless. That matters because
 * slh_dsa_key_hash_init() calls it on failure and its caller,
 * ossl_slh_dsa_key_new(), then calls ossl_slh_dsa_key_free(), which
 * calls it a second time.
 */
static void slh_dsa_key_hash_cleanup(SLH_DSA_KEY *key)
{
    OPENSSL_free(key->propq);
    key->propq = NULL;

    /* md_big may be an alias of md; only free it when it is distinct */
    if (key->md_big != key->md)
        EVP_MD_free(key->md_big);
    key->md_big = NULL;

    EVP_MD_free(key->md);
    key->md = NULL;

    EVP_MAC_free(key->hmac);
    key->hmac = NULL;
}
/*
 * Fetch the algorithms required by the parameter set of |key| and select
 * the matching address / hash function tables.
 *
 * SHAKE parameter sets fetch only SHAKE-256. SHA2 parameter sets fetch
 * SHA2-256 plus HMAC; for security category 1 md_big aliases md,
 * otherwise SHA2-512 is fetched into md_big.
 *
 * Returns 1 on success, 0 on failure.
 */
static int slh_dsa_key_hash_init(SLH_DSA_KEY *key)
{
    const SLH_DSA_PARAMS *prm = key->params;
    int shake = prm->is_shake;
    int category = prm->security_category;
    int uses_sha2 = (shake == 0);

    key->md = EVP_MD_fetch(key->libctx, shake ? "SHAKE-256" : "SHA2-256",
                           key->propq);
    if (key->md == NULL)
        return 0;

    if (uses_sha2) {
        if (category != 1) {
            key->md_big = EVP_MD_fetch(key->libctx, "SHA2-512", key->propq);
            if (key->md_big == NULL)
                goto err;
        } else {
            /* Category 1 reuses the already fetched SHA2-256 */
            key->md_big = key->md;
        }

        key->hmac = EVP_MAC_fetch(key->libctx, "HMAC", key->propq);
        if (key->hmac == NULL)
            goto err;
    }

    key->adrs_func = ossl_slh_get_adrs_fn(uses_sha2);
    key->hash_func = ossl_slh_get_hash_fn(shake);
    return 1;

err:
    slh_dsa_key_hash_cleanup(key);
    return 0;
}
/*
 * Take an extra reference on each algorithm object held by |src|.
 * |dst| is not accessed here; the caller copies the pointers itself.
 * When md_big aliases md only a single reference is taken for the pair.
 */
static void slh_dsa_key_hash_dup(SLH_DSA_KEY *dst, const SLH_DSA_KEY *src)
{
    const int big_is_distinct = (src->md_big != src->md);

    if (src->md != NULL)
        EVP_MD_up_ref(src->md);

    if (big_is_distinct && src->md_big != NULL)
        EVP_MD_up_ref(src->md_big);

    if (src->hmac != NULL)
        EVP_MAC_up_ref(src->hmac);
}
/* Return the library context stored in |key|, or NULL if |key| is NULL. */
OSSL_LIB_CTX *ossl_slh_dsa_key_get0_libctx(const SLH_DSA_KEY *key)
{
    if (key == NULL)
        return NULL;
    return key->libctx;
}
/*
 * Allocate a zeroed key for the parameter set named |alg|.
 *
 * |libctx| is stored as-is; |propq| (which may be NULL) is duplicated.
 * The required hash algorithms are fetched before returning.
 *
 * Returns the new key, or NULL if |alg| is unknown, an allocation fails
 * or the algorithms could not be fetched.
 */
SLH_DSA_KEY *ossl_slh_dsa_key_new(OSSL_LIB_CTX *libctx, const char *propq,
                                  const char *alg)
{
    const SLH_DSA_PARAMS *alg_params = ossl_slh_dsa_params_get(alg);
    SLH_DSA_KEY *key;
    int ok = 1;

    if (alg_params == NULL)
        return NULL;

    key = OPENSSL_zalloc(sizeof(*key));
    if (key == NULL)
        return NULL;

    key->libctx = libctx;
    key->params = alg_params;

    if (propq != NULL) {
        key->propq = OPENSSL_strdup(propq);
        ok = (key->propq != NULL);
    }
    if (ok)
        ok = slh_dsa_key_hash_init(key);

    if (!ok) {
        ossl_slh_dsa_key_free(key);
        return NULL;
    }
    return key;
}
/*
 * Destroy |key|: release its algorithm objects, wipe the first half of
 * the priv buffer (presumably the secret seed/PRF portion - confirm
 * against the struct layout) and free the structure. NULL is a no-op.
 */
void ossl_slh_dsa_key_free(SLH_DSA_KEY *key)
{
    if (key != NULL) {
        slh_dsa_key_hash_cleanup(key);
        OPENSSL_cleanse(&key->priv, sizeof(key->priv) / 2);
        OPENSSL_free(key);
    }
}
/*
 * Create a copy of |src| restricted to the parts chosen by |selection|.
 *
 * The whole structure (including the key bytes held in priv) is copied
 * first. The pub pointer and has_priv flag are then only re-established
 * when the corresponding OSSL_KEYMGMT_SELECT_* bits are set.
 *
 * Returns the new key, or NULL if |src| is NULL or an allocation fails.
 */
SLH_DSA_KEY *ossl_slh_dsa_key_dup(const SLH_DSA_KEY *src, int selection)
{
    SLH_DSA_KEY *copy;

    if (src == NULL)
        return NULL;

    copy = OPENSSL_zalloc(sizeof(*copy));
    if (copy == NULL)
        return NULL;

    /* Struct assignment copies everything; fix up the owned fields below */
    *copy = *src;
    copy->propq = NULL;
    copy->pub = NULL;
    copy->has_priv = 0;
    slh_dsa_key_hash_dup(copy, src);

    if (src->propq != NULL) {
        copy->propq = OPENSSL_strdup(src->propq);
        if (copy->propq == NULL) {
            ossl_slh_dsa_key_free(copy);
            return NULL;
        }
    }

    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) == 0)
        return copy;

    /* pub must point into the copy's own buffer, not into |src| */
    if (src->pub != NULL)
        copy->pub = SLH_DSA_PUB(copy);
    if ((selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0)
        copy->has_priv = src->has_priv;

    return copy;
}
/*
 * Compare two keys according to |selection|.
 *
 * Keys with different parameter sets never match. If no key-pair bit is
 * selected the keys are considered equal. Otherwise the public parts are
 * compared when selected and present in both keys; failing that, the
 * private parts are compared when selected and present in both keys.
 * If nothing could be compared the result is 0.
 *
 * Returns 1 if the keys match, 0 otherwise.
 */
int ossl_slh_dsa_key_equal(const SLH_DSA_KEY *key1, const SLH_DSA_KEY *key2,
                           int selection)
{
    if (key1->params != key2->params)
        return 0;

    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) == 0)
        return 1;

    if ((selection & OSSL_KEYMGMT_SELECT_PUBLIC_KEY) != 0
            && key1->pub != NULL
            && key2->pub != NULL)
        return memcmp(key1->pub, key2->pub, key1->params->pk_len) == 0;

    /*
     * The private comparison also uses pk_len bytes, taken from the start
     * of priv.
     * NOTE(review): memcmp() is not constant time; consider whether
     * CRYPTO_memcmp() is warranted for the private comparison.
     */
    if ((selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0
            && key1->has_priv
            && key2->has_priv)
        return memcmp(key1->priv, key2->priv, key1->params->pk_len) == 0;

    return 0;
}
/*
 * Report whether |key| holds the parts requested by |selection|.
 *
 * Returns 0 when no key-pair bit is selected or when the public key is
 * absent. If the private key is selected it must be present as well.
 */
int ossl_slh_dsa_key_has(const SLH_DSA_KEY *key, int selection)
{
    if ((selection & OSSL_KEYMGMT_SELECT_KEYPAIR) == 0)
        return 0;

    /* A public key is required for any key-pair selection */
    if (key->pub == NULL)
        return 0;

    return (selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) == 0
        || key->has_priv != 0;
}
/*
 * Check that the public root stored in |key| matches the root recomputed
 * from the key's own seeds.
 *
 * Returns 0 if either the public or private part is missing, if a hash
 * context cannot be created, or if the recomputed root differs.
 */
int ossl_slh_dsa_key_pairwise_check(const SLH_DSA_KEY *key)
{
    SLH_DSA_HASH_CTX *hctx;
    int ok = 0;

    if (key->has_priv == 0 || key->pub == NULL)
        return 0;

    hctx = ossl_slh_dsa_hash_ctx_new(key);
    if (hctx != NULL) {
        /* validate mode (last argument 1) only compares; see the callee */
        ok = slh_dsa_compute_pk_root(hctx, (SLH_DSA_KEY *)key, 1);
        ossl_slh_dsa_hash_ctx_free(hctx);
    }
    return ok;
}
/*
 * Mark |key| as holding no key material. The priv buffer is wiped only
 * if a private key was flagged as present.
 */
void ossl_slh_dsa_key_reset(SLH_DSA_KEY *key)
{
    key->pub = NULL;

    if (!key->has_priv)
        return;

    key->has_priv = 0;
    OPENSSL_cleanse(key->priv, sizeof(key->priv));
}
/*
 * Load key material into |key| from |params|.
 *
 * When |include_private| is set and a private key parameter is present,
 * it must be either the full private key length or exactly half of it.
 * A full-length value completes the key on its own. A half-length value
 * (or no private value at all) requires a public key parameter of exactly
 * half the private key length.
 *
 * Returns 1 on success, 0 on failure. Most failures reset |key|; a
 * failure to read the private octet string returns without resetting.
 */
int ossl_slh_dsa_key_fromdata(SLH_DSA_KEY *key, const OSSL_PARAM params[],
                              int include_private)
{
    const OSSL_PARAM *prm;
    size_t full_len, half_len, got = 0;
    void *buf;

    if (key == NULL)
        return 0;

    full_len = ossl_slh_dsa_key_get_priv_len(key);
    half_len = full_len >> 1;

    if (include_private) {
        prm = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PRIV_KEY);
        if (prm != NULL) {
            buf = key->priv;
            if (!OSSL_PARAM_get_octet_string(prm, &buf, full_len, &got))
                return 0;

            /* A full-length value already carries the public part */
            if (got == full_len) {
                key->has_priv = 1;
                key->pub = SLH_DSA_PUB(key);
                return 1;
            }

            /* Anything else must be exactly the first half */
            if (got != half_len)
                goto err;
            key->has_priv = 1;
        }
    }

    /* The public part is mandatory from here on */
    buf = SLH_DSA_PUB(key);
    prm = OSSL_PARAM_locate_const(params, OSSL_PKEY_PARAM_PUB_KEY);
    if (prm == NULL)
        goto err;
    if (!OSSL_PARAM_get_octet_string(prm, &buf, half_len, &got))
        goto err;
    if (got != half_len)
        goto err;

    key->pub = buf;
    return 1;

err:
    ossl_slh_dsa_key_reset(key);
    return 0;
}
/*
 * Compute the public root via ossl_slh_xmss_node() for the top layer
 * (layer address d - 1), using the seeds of the key attached to |ctx|.
 *
 * With |validate| == 0 the root is written into |out|.
 * With |validate| != 0 the root is computed into a local buffer and
 * compared against the root already stored in |out|.
 *
 * Returns 1 on success (and, when validating, a match), 0 otherwise.
 */
static int slh_dsa_compute_pk_root(SLH_DSA_HASH_CTX *ctx, SLH_DSA_KEY *out,
                                   int validate)
{
    const SLH_DSA_KEY *key = ctx->key;
    SLH_ADRS_FUNC_DECLARE(key, adrsf);
    SLH_ADRS_DECLARE(adrs);
    const SLH_DSA_PARAMS *params = key->params;
    size_t n = params->n;
    uint8_t scratch_root[SLH_DSA_MAX_N];
    uint8_t *target;

    adrsf->zero(adrs);
    adrsf->set_layer_address(adrs, params->d - 1);

    if (validate)
        target = scratch_root;
    else
        target = SLH_DSA_PK_ROOT(out);

    if (!ossl_slh_xmss_node(ctx, SLH_DSA_SK_SEED(key), 0, params->hm,
                            SLH_DSA_PK_SEED(key), adrs, target, n))
        return 0;

    if (validate == 0)
        return 1;

    return memcmp(target, SLH_DSA_PK_ROOT(out), n) == 0;
}
/*
 * Generate a key pair into |out|.
 *
 * If |entropy| is supplied (non-NULL with non-zero length) it must be
 * exactly 3 * n bytes and is copied to the start of the private buffer.
 * Otherwise 2 * n private random bytes and n public random bytes are
 * drawn from |lib_ctx|. The public root is then computed using |ctx|.
 *
 * Returns 1 on success. On failure returns 0, marks |out| as empty and
 * wipes the first 2 * n bytes of the private buffer.
 */
int ossl_slh_dsa_generate_key(SLH_DSA_HASH_CTX *ctx, SLH_DSA_KEY *out,
                              OSSL_LIB_CTX *lib_ctx,
                              const uint8_t *entropy, size_t entropy_len)
{
    const size_t n = out->params->n;
    const size_t sk_len = 2 * n;
    const size_t seed_len = n;
    uint8_t *sk = SLH_DSA_PRIV(out);
    uint8_t *pk = SLH_DSA_PUB(out);
    int ok;

    if (entropy == NULL || entropy_len == 0) {
        ok = RAND_priv_bytes_ex(lib_ctx, sk, sk_len, 0) > 0
            && RAND_bytes_ex(lib_ctx, pk, seed_len, 0) > 0;
    } else {
        ok = (entropy_len == sk_len + seed_len);
        if (ok)
            memcpy(sk, entropy, entropy_len);
    }

    if (ok)
        ok = (slh_dsa_compute_pk_root(ctx, out, 0) != 0);

    if (!ok) {
        out->pub = NULL;
        out->has_priv = 0;
        OPENSSL_cleanse(sk, sk_len);
        return 0;
    }

    out->pub = pk;
    out->has_priv = 1;
    return 1;
}
/*
 * Return 1 if the algorithm name of |key| equals |alg|, compared with
 * OPENSSL_strcasecmp(), otherwise 0.
 */
int ossl_slh_dsa_key_type_matches(const SLH_DSA_KEY *key, const char *alg)
{
    const char *own_alg = key->params->alg;

    return OPENSSL_strcasecmp(own_alg, alg) == 0;
}
/* Return the public key pointer of |key| (NULL if none has been set). */
const uint8_t *ossl_slh_dsa_key_get_pub(const SLH_DSA_KEY *key)
{
    const uint8_t *pub_bytes = key->pub;

    return pub_bytes;
}
size_t ossl_slh_dsa_key_get_pub_len(const SLH_DSA_KEY *key)
{
return 2 * key->params->n;
}
/* Return the private key buffer, or NULL if no private key is present. */
const uint8_t *ossl_slh_dsa_key_get_priv(const SLH_DSA_KEY *key)
{
    if (!key->has_priv)
        return NULL;
    return key->priv;
}
size_t ossl_slh_dsa_key_get_priv_len(const SLH_DSA_KEY *key)
{
return 4 * key->params->n;
}
size_t ossl_slh_dsa_key_get_n(const SLH_DSA_KEY *key)
{
return key->params->n;
}
size_t ossl_slh_dsa_key_get_sig_len(const SLH_DSA_KEY *key)
{
return key->params->sig_len;
}
const char *ossl_slh_dsa_key_get_name(const SLH_DSA_KEY *key)
{
return key->params->alg;
}
int ossl_slh_dsa_key_get_type(const SLH_DSA_KEY *key)
{
return key->params->type;
}
/*
 * Copy a full private key into |key|. |priv_len| must equal the key's
 * private key length. The public pointer is set as well.
 *
 * Returns 1 on success, 0 on a length mismatch.
 */
int ossl_slh_dsa_set_priv(SLH_DSA_KEY *key, const uint8_t *priv, size_t priv_len)
{
    const size_t expected = ossl_slh_dsa_key_get_priv_len(key);

    if (priv_len != expected)
        return 0;

    memcpy(key->priv, priv, expected);
    key->pub = SLH_DSA_PUB(key);
    key->has_priv = 1;
    return 1;
}
/*
 * Copy a public key into |key| and mark the key as having no private
 * part. |pub_len| must equal the key's public key length.
 *
 * Returns 1 on success, 0 on a length mismatch.
 */
int ossl_slh_dsa_set_pub(SLH_DSA_KEY *key, const uint8_t *pub, size_t pub_len)
{
    const size_t expected = ossl_slh_dsa_key_get_pub_len(key);

    if (pub_len != expected)
        return 0;

    key->pub = SLH_DSA_PUB(key);
    memcpy(key->pub, pub, expected);
    key->has_priv = 0;
    return 1;
}
#ifndef FIPS_MODULE
/*
 * Print |key| in text form to |out|.
 *
 * A public key must be present whatever |selection| is. If the private
 * key is selected it must also be present; a private header and the
 * private bytes are printed. Otherwise, if the public key is selected, a
 * public header is printed. The public bytes are printed in every case.
 *
 * Returns 1 on success, 0 on failure (an error is raised for NULL
 * arguments and for missing key material).
 */
int ossl_slh_dsa_key_to_text(BIO *out, const SLH_DSA_KEY *key, int selection)
{
    const char *alg_name;
    const uint8_t *pub_bytes;
    const uint8_t *priv_bytes;

    if (out == NULL || key == NULL) {
        ERR_raise(ERR_LIB_PROV, ERR_R_PASSED_NULL_PARAMETER);
        return 0;
    }

    alg_name = ossl_slh_dsa_key_get_name(key);
    pub_bytes = ossl_slh_dsa_key_get_pub(key);
    if (pub_bytes == NULL) {
        ERR_raise_data(ERR_LIB_PROV, PROV_R_MISSING_KEY,
                       "no %s key material available", alg_name);
        return 0;
    }

    if ((selection & OSSL_KEYMGMT_SELECT_PRIVATE_KEY) != 0) {
        priv_bytes = ossl_slh_dsa_key_get_priv(key);
        if (priv_bytes == NULL) {
            ERR_raise_data(ERR_LIB_PROV, PROV_R_MISSING_KEY,
                           "no %s key material available", alg_name);
            return 0;
        }
        if (BIO_printf(out, "%s Private-Key:\n", alg_name) <= 0
                || !ossl_bio_print_labeled_buf(out, "priv:", priv_bytes,
                                               ossl_slh_dsa_key_get_priv_len(key)))
            return 0;
    } else if ((selection & OSSL_KEYMGMT_SELECT_PUBLIC_KEY) != 0
               && BIO_printf(out, "%s Public-Key:\n", alg_name) <= 0) {
        return 0;
    }

    if (!ossl_bio_print_labeled_buf(out, "pub:", pub_bytes,
                                    ossl_slh_dsa_key_get_pub_len(key)))
        return 0;

    return 1;
}
#endif