#include <stdlib.h>
#include <openssl/curve25519.h>
#include <openssl/dh.h>
#include <openssl/ec.h>
#include <openssl/evp.h>
#include <openssl/mlkem.h>
#include "bytestring.h"
#include "ssl_local.h"
#include "tls_internal.h"
/*
 * State for one TLS key exchange share. Which members are used depends on
 * nid: finite field DHE uses dhe/dhe_peer, ECP curves use ecdhe/ecdhe_peer,
 * X25519 uses the x25519_* members, and the X25519MLKEM768 hybrid uses the
 * x25519_* members together with the mlkem_* members. All pointers start out
 * NULL (the struct is calloc'd) and are released by tls_key_share_free().
 */
struct tls_key_share {
int nid; /* NID selecting the key exchange algorithm. */
uint16_t group_id; /* TLS group identifier; 0 for DHE created by NID. */
size_t key_bits; /* DHE size to choose parameters for; 0 = use preset dhe. */
DH *dhe; /* Our DH key/parameters. */
DH *dhe_peer; /* Peer's DH key/parameters (same group as dhe). */
EC_KEY *ecdhe; /* Our EC key for ECP curves. */
EC_KEY *ecdhe_peer; /* Peer's EC public key. */
uint8_t *x25519_public; /* Our X25519 public key, X25519_KEY_LENGTH bytes. */
uint8_t *x25519_private; /* Our X25519 private key, X25519_KEY_LENGTH bytes. */
uint8_t *x25519_peer_public; /* Peer's X25519 public key. */
uint8_t *mlkem_public; /* Client: our encoded ML-KEM public key. */
size_t mlkem_public_len;
MLKEM_private_key *mlkem_private; /* Client: our ML-KEM private key. */
MLKEM_public_key *mlkem_peer_public; /* Server: client's parsed ML-KEM key. */
uint8_t *mlkem_encap; /* ML-KEM ciphertext (produced by server, received by client). */
size_t mlkem_encap_len;
uint8_t *mlkem_shared_secret; /* ML-KEM shared secret (encap or decap result). */
size_t mlkem_shared_secret_len;
};
/*
 * Allocate a zeroed key share for the given NID and TLS group identifier.
 * Returns NULL on allocation failure.
 */
static struct tls_key_share *
tls_key_share_new_internal(int nid, uint16_t group_id)
{
	struct tls_key_share *share;

	/* calloc() leaves every key pointer NULL and every length 0. */
	share = calloc(1, sizeof(*share));
	if (share != NULL) {
		share->nid = nid;
		share->group_id = group_id;
	}

	return share;
}
/*
 * Create a key share for a TLS group identifier. Returns NULL if the group
 * is unknown or allocation fails.
 */
struct tls_key_share *
tls_key_share_new(uint16_t group_id)
{
	int nid;

	if (tls1_ec_group_id2nid(group_id, &nid))
		return tls_key_share_new_internal(nid, group_id);

	return NULL;
}
/*
 * Create a key share from a NID. Finite field DHE is accepted with a group
 * identifier of 0; any other NID must map onto a known TLS group, otherwise
 * NULL is returned.
 */
struct tls_key_share *
tls_key_share_new_nid(int nid)
{
	uint16_t group_id = 0;

	if (nid != NID_dhKeyAgreement && !tls1_ec_nid2group_id(nid, &group_id))
		return NULL;

	return tls_key_share_new_internal(nid, group_id);
}
/*
 * Release a key share and everything it owns, zeroing raw key material
 * before it is freed. Passing NULL is a no-op.
 */
void
tls_key_share_free(struct tls_key_share *ks)
{
	if (ks != NULL) {
		/* Finite field and elliptic curve key objects. */
		DH_free(ks->dhe);
		DH_free(ks->dhe_peer);
		EC_KEY_free(ks->ecdhe);
		EC_KEY_free(ks->ecdhe_peer);

		/* X25519 raw keys. */
		freezero(ks->x25519_public, X25519_KEY_LENGTH);
		freezero(ks->x25519_private, X25519_KEY_LENGTH);
		freezero(ks->x25519_peer_public, X25519_KEY_LENGTH);

		/* ML-KEM keys, ciphertext and shared secret. */
		freezero(ks->mlkem_public, ks->mlkem_public_len);
		MLKEM_private_key_free(ks->mlkem_private);
		MLKEM_public_key_free(ks->mlkem_peer_public);
		freezero(ks->mlkem_encap, ks->mlkem_encap_len);
		freezero(ks->mlkem_shared_secret, ks->mlkem_shared_secret_len);

		freezero(ks, sizeof(*ks));
	}
}
/* Return the TLS group identifier (0 for a DHE share created by NID). */
uint16_t
tls_key_share_group(struct tls_key_share *ks)
{
	uint16_t group_id = ks->group_id;

	return group_id;
}
/* Return the NID identifying this share's key exchange algorithm. */
int
tls_key_share_nid(struct tls_key_share *ks)
{
	int nid = ks->nid;

	return nid;
}
/*
 * Record the DHE key size to pick parameters for. Generation treats 0 as
 * "use the DH parameters already installed on the share".
 */
void
tls_key_share_set_key_bits(struct tls_key_share *ks, size_t key_bits)
{
	size_t bits = key_bits;

	ks->key_bits = bits;
}
/*
 * Install explicit DH group parameters on a DHE share. Both our key and the
 * peer's key receive a copy. Fails if the share is not DHE or already has
 * DH state.
 */
int
tls_key_share_set_dh_params(struct tls_key_share *ks, DH *dh_params)
{
	if (ks->nid != NID_dhKeyAgreement ||
	    ks->dhe != NULL || ks->dhe_peer != NULL)
		return 0;

	ks->dhe = DHparams_dup(dh_params);
	if (ks->dhe == NULL)
		return 0;

	ks->dhe_peer = DHparams_dup(dh_params);

	return ks->dhe_peer != NULL;
}
/*
 * Populate pkey from the peer's key: the DH key for DHE, a placeholder via
 * ssl_kex_dummy_ecdhe_x25519() for X25519, or the EC key otherwise.
 * Returns 0 if no suitable peer key is present.
 */
int
tls_key_share_peer_pkey(struct tls_key_share *ks, EVP_PKEY *pkey)
{
	int nid = ks->nid;

	if (nid == NID_dhKeyAgreement && ks->dhe_peer != NULL)
		return EVP_PKEY_set1_DH(pkey, ks->dhe_peer);
	if (nid == NID_X25519 && ks->x25519_peer_public != NULL)
		return ssl_kex_dummy_ecdhe_x25519(pkey);
	if (ks->ecdhe_peer == NULL)
		return 0;

	return EVP_PKEY_set1_EC_KEY(pkey, ks->ecdhe_peer);
}
/*
 * Generate our DHE key.
 *
 * With key_bits == 0, a key is generated in the group already installed in
 * ks->dhe (by tls_key_share_set_dh_params() or tls_key_share_peer_params()).
 * Otherwise fresh parameters are chosen for key_bits, and a copy of them is
 * kept as the peer's group.
 */
static int
tls_key_share_generate_dhe(struct tls_key_share *ks)
{
	if (ks->key_bits == 0) {
		if (ks->dhe != NULL)
			return ssl_kex_generate_dhe(ks->dhe, ks->dhe);
		return 0;
	}

	/* Choosing parameters requires a share with no DH state yet. */
	if (ks->dhe != NULL || ks->dhe_peer != NULL)
		return 0;

	if ((ks->dhe = DH_new()) == NULL)
		return 0;
	if (!ssl_kex_generate_dhe_params_auto(ks->dhe, ks->key_bits))
		return 0;

	/* The peer's key will live in the same group as ours. */
	ks->dhe_peer = DHparams_dup(ks->dhe);

	return ks->dhe_peer != NULL;
}
/*
 * Generate our EC key on the curve identified by ks->nid. Fails if a key
 * already exists; on success the share takes ownership of the new key.
 */
static int
tls_key_share_generate_ecdhe_ecp(struct tls_key_share *ks)
{
	EC_KEY *key;

	if (ks->ecdhe != NULL)
		return 0;

	if ((key = EC_KEY_new()) == NULL)
		return 0;
	if (!ssl_kex_generate_ecdhe_ecp(key, ks->nid)) {
		EC_KEY_free(key);
		return 0;
	}

	ks->ecdhe = key;

	return 1;
}
/*
 * Generate our X25519 key pair. Fails if either half already exists. On
 * failure, any partially allocated key buffer is zeroed and freed.
 */
static int
tls_key_share_generate_x25519(struct tls_key_share *ks)
{
	uint8_t *pub_key, *priv_key;

	if (ks->x25519_public != NULL || ks->x25519_private != NULL)
		return 0;

	if ((pub_key = calloc(1, X25519_KEY_LENGTH)) == NULL)
		return 0;
	if ((priv_key = calloc(1, X25519_KEY_LENGTH)) == NULL) {
		freezero(pub_key, X25519_KEY_LENGTH);
		return 0;
	}

	X25519_keypair(pub_key, priv_key);

	ks->x25519_public = pub_key;
	ks->x25519_private = priv_key;

	return 1;
}
/*
 * Generate an ML-KEM key pair of the given rank. The encoded public key and
 * the private key object are stored on the share. Fails if either already
 * exists.
 */
static int
tls_key_share_generate_mlkem(struct tls_key_share *ks, int rank)
{
	MLKEM_private_key *priv_key;
	uint8_t *encoded_pub = NULL;
	size_t encoded_pub_len = 0;

	if (ks->mlkem_public != NULL || ks->mlkem_private != NULL)
		return 0;

	if ((priv_key = MLKEM_private_key_new(rank)) == NULL)
		return 0;
	if (!MLKEM_generate_key(priv_key, &encoded_pub, &encoded_pub_len,
	    NULL, NULL)) {
		freezero(encoded_pub, encoded_pub_len);
		MLKEM_private_key_free(priv_key);
		return 0;
	}

	ks->mlkem_public = encoded_pub;
	ks->mlkem_public_len = encoded_pub_len;
	ks->mlkem_private = priv_key;

	return 1;
}
/*
 * Client half of X25519MLKEM768: generate a fresh ML-KEM-768 key pair
 * followed by a fresh X25519 key pair.
 */
static int
tls_key_share_client_generate_mlkem768x25519(struct tls_key_share *ks)
{
	return tls_key_share_generate_mlkem(ks, MLKEM768_RANK) &&
	    tls_key_share_generate_x25519(ks);
}
/*
 * Server half of X25519MLKEM768. This requires both parts of the client's
 * share to have been parsed already. It generates our X25519 key pair and
 * encapsulates to the client's ML-KEM key: the ciphertext is later sent in
 * our share, and the shared secret is kept for derivation.
 */
static int
tls_key_share_server_generate_mlkem768x25519(struct tls_key_share *ks)
{
	/* A server-side share never holds an ML-KEM private key. */
	if (ks->mlkem_private != NULL)
		return 0;

	if (ks->x25519_peer_public == NULL || ks->mlkem_peer_public == NULL)
		return 0;

	if (!tls_key_share_generate_x25519(ks))
		return 0;

	return MLKEM_encap(ks->mlkem_peer_public, &ks->mlkem_encap,
	    &ks->mlkem_encap_len, &ks->mlkem_shared_secret,
	    &ks->mlkem_shared_secret_len);
}
/*
 * Generate our key for the non-hybrid algorithms. Any NID other than DHE
 * or X25519 is handled as an ECP curve.
 */
static int
tls_key_share_generate(struct tls_key_share *ks)
{
	int nid = ks->nid;

	if (nid != NID_dhKeyAgreement && nid != NID_X25519)
		return tls_key_share_generate_ecdhe_ecp(ks);
	if (nid == NID_X25519)
		return tls_key_share_generate_x25519(ks);

	return tls_key_share_generate_dhe(ks);
}
/*
 * Generate the client's key share. Only the hybrid group needs
 * client-specific handling.
 */
int
tls_key_share_client_generate(struct tls_key_share *ks)
{
	if (ks->nid != NID_X25519MLKEM768)
		return tls_key_share_generate(ks);

	return tls_key_share_client_generate_mlkem768x25519(ks);
}
/*
 * Generate the server's key share. Only the hybrid group needs
 * server-specific handling.
 */
int
tls_key_share_server_generate(struct tls_key_share *ks)
{
	if (ks->nid != NID_X25519MLKEM768)
		return tls_key_share_generate(ks);

	return tls_key_share_server_generate_mlkem768x25519(ks);
}
/* Serialize our DH group parameters into cbb; fails if none are set. */
static int
tls_key_share_params_dhe(struct tls_key_share *ks, CBB *cbb)
{
	DH *dh = ks->dhe;

	if (dh != NULL)
		return ssl_kex_params_dhe(dh, cbb);

	return 0;
}
/*
 * Serialize explicit key exchange parameters. Only finite field DHE has
 * any; every other algorithm fails.
 */
int
tls_key_share_params(struct tls_key_share *ks, CBB *cbb)
{
	if (ks->nid != NID_dhKeyAgreement)
		return 0;

	return tls_key_share_params_dhe(ks, cbb);
}
/* Serialize our DH public value into cbb; fails if we have no DH key. */
static int
tls_key_share_public_dhe(struct tls_key_share *ks, CBB *cbb)
{
	DH *dh = ks->dhe;

	if (dh != NULL)
		return ssl_kex_public_dhe(dh, cbb);

	return 0;
}
/* Serialize our EC public point into cbb; fails if we have no EC key. */
static int
tls_key_share_public_ecdhe_ecp(struct tls_key_share *ks, CBB *cbb)
{
	EC_KEY *key = ks->ecdhe;

	if (key != NULL)
		return ssl_kex_public_ecdhe_ecp(key, cbb);

	return 0;
}
/* Append our raw X25519 public key to cbb; fails if none was generated. */
static int
tls_key_share_public_x25519(struct tls_key_share *ks, CBB *cbb)
{
	uint8_t *pub_key = ks->x25519_public;

	if (pub_key != NULL)
		return CBB_add_bytes(cbb, pub_key, X25519_KEY_LENGTH);

	return 0;
}
/*
 * Serialize our X25519MLKEM768 share into cbb as the ML-KEM part followed by
 * the X25519 public key. The ML-KEM part is the client's encoded
 * encapsulation key on the client, and the ciphertext produced by
 * MLKEM_encap() on the server. Returns 0 if the needed material is missing
 * or cbb cannot be extended.
 */
static int
tls_key_share_public_mlkem768x25519(struct tls_key_share *ks, CBB *cbb)
{
	uint8_t *mlkem_part;
	size_t mlkem_part_len;

	if (ks->x25519_public == NULL)
		return 0;

	/*
	 * Select the ML-KEM half by role, not by which buffer happens to be
	 * populated. Only the client holds an ML-KEM private key. The client
	 * also stores the server's ciphertext in mlkem_encap once it has parsed
	 * the server's share, so testing mlkem_encap first would make a client
	 * echo the server's ciphertext instead of its own encapsulation key.
	 */
	if (ks->mlkem_private != NULL) {
		mlkem_part = ks->mlkem_public;
		mlkem_part_len = ks->mlkem_public_len;
	} else {
		mlkem_part = ks->mlkem_encap;
		mlkem_part_len = ks->mlkem_encap_len;
	}
	if (mlkem_part == NULL)
		return 0;

	if (!CBB_add_bytes(cbb, mlkem_part, mlkem_part_len))
		return 0;

	return CBB_add_bytes(cbb, ks->x25519_public, X25519_KEY_LENGTH);
}
/*
 * Serialize our public key share into cbb, dispatching on the algorithm.
 * Anything not DHE, X25519 or the hybrid group is treated as an ECP curve.
 */
int
tls_key_share_public(struct tls_key_share *ks, CBB *cbb)
{
	int nid = ks->nid;

	if (nid == NID_dhKeyAgreement)
		return tls_key_share_public_dhe(ks, cbb);
	if (nid == NID_X25519)
		return tls_key_share_public_x25519(ks, cbb);
	if (nid == NID_X25519MLKEM768)
		return tls_key_share_public_mlkem768x25519(ks, cbb);

	return tls_key_share_public_ecdhe_ecp(ks, cbb);
}
/*
 * Decode the peer's DH group parameters into dhe_peer and copy them into
 * dhe, so that a local key can be generated in the same group. Only allowed
 * before any DH state exists on the share.
 */
static int
tls_key_share_peer_params_dhe(struct tls_key_share *ks, CBS *cbs,
    int *decode_error, int *invalid_params)
{
	if (ks->dhe != NULL || ks->dhe_peer != NULL)
		return 0;

	if ((ks->dhe_peer = DH_new()) == NULL)
		return 0;
	if (!ssl_kex_peer_params_dhe(ks->dhe_peer, cbs, decode_error,
	    invalid_params))
		return 0;

	ks->dhe = DHparams_dup(ks->dhe_peer);

	return ks->dhe != NULL;
}
/*
 * Parse explicit key exchange parameters sent by the peer. Only finite
 * field DHE carries any; other algorithms fail.
 *
 * *decode_error and *invalid_params are always defined on return. Several
 * failure paths return before the DH decoder runs: wrong algorithm,
 * parameters already present, or allocation failure. Without the reset
 * below, those paths would leave the out-parameters unset, and callers
 * presumably inspect them after a failure.
 */
int
tls_key_share_peer_params(struct tls_key_share *ks, CBS *cbs,
    int *decode_error, int *invalid_params)
{
	if (decode_error != NULL)
		*decode_error = 0;
	if (invalid_params != NULL)
		*invalid_params = 0;

	if (ks->nid != NID_dhKeyAgreement)
		return 0;

	return tls_key_share_peer_params_dhe(ks, cbs, decode_error,
	    invalid_params);
}
/*
 * Decode the peer's DH public value into dhe_peer. The peer's group must
 * already be in place on the share.
 */
static int
tls_key_share_peer_public_dhe(struct tls_key_share *ks, CBS *cbs,
    int *decode_error, int *invalid_key)
{
	if (ks->dhe_peer != NULL)
		return ssl_kex_peer_public_dhe(ks->dhe_peer, cbs, decode_error,
		    invalid_key);

	return 0;
}
/*
 * Decode the peer's EC public point for the curve ks->nid. Only one peer key
 * may be recorded; on success the share takes ownership of the new key.
 */
static int
tls_key_share_peer_public_ecdhe_ecp(struct tls_key_share *ks, CBS *cbs)
{
	EC_KEY *peer_key;

	if (ks->ecdhe_peer != NULL)
		return 0;

	if ((peer_key = EC_KEY_new()) == NULL)
		return 0;
	if (!ssl_kex_peer_public_ecdhe_ecp(peer_key, ks->nid, cbs)) {
		EC_KEY_free(peer_key);
		return 0;
	}

	ks->ecdhe_peer = peer_key;

	return 1;
}
/*
 * Copy the peer's raw X25519 public key out of cbs. The input must be
 * exactly X25519_KEY_LENGTH bytes, otherwise *decode_error is set. A second
 * peer key is rejected.
 */
static int
tls_key_share_peer_public_x25519(struct tls_key_share *ks, CBS *cbs,
    int *decode_error)
{
	size_t stowed_len;

	*decode_error = 0;

	if (ks->x25519_peer_public != NULL)
		return 0;

	if (CBS_len(cbs) == X25519_KEY_LENGTH)
		return CBS_stow(cbs, &ks->x25519_peer_public, &stowed_len);

	*decode_error = 1;
	return 0;
}
/*
 * Client side: parse the server's X25519MLKEM768 share. The share is the
 * ML-KEM ciphertext (sized for our private key) followed by the server's
 * X25519 public key, with no trailing data. The ciphertext is stored in
 * mlkem_encap for decapsulation at derive time.
 *
 * *decode_error is always initialized. It is set to 1 when the share is
 * malformed (truncated or over-long), matching the server-side and generic
 * parsers. Previously it was never written, so callers could read an
 * uninitialized value after a failure.
 */
static int
tls_key_share_client_peer_public_mlkem768x25519(struct tls_key_share *ks,
    CBS *cbs, int *decode_error)
{
	CBS x25519_cbs, mlkem_ciphertext_cbs;
	size_t out_len;

	*decode_error = 0;

	/* The server's share may only be processed once... */
	if (ks->mlkem_shared_secret != NULL)
		return 0;
	/* ...and only if we hold the ML-KEM private key to decapsulate with. */
	if (ks->mlkem_private == NULL)
		return 0;

	if (!CBS_get_bytes(cbs, &mlkem_ciphertext_cbs,
	    MLKEM_private_key_ciphertext_length(ks->mlkem_private)))
		goto decode_err;
	if (!CBS_get_bytes(cbs, &x25519_cbs, X25519_KEY_LENGTH))
		goto decode_err;
	if (CBS_len(cbs) != 0)
		goto decode_err;

	if (!CBS_stow(&x25519_cbs, &ks->x25519_peer_public, &out_len))
		return 0;
	if (!CBS_stow(&mlkem_ciphertext_cbs, &ks->mlkem_encap,
	    &ks->mlkem_encap_len))
		return 0;

	return 1;

 decode_err:
	*decode_error = 1;
	return 0;
}
/*
 * Server side: parse the client's X25519MLKEM768 share. The share is the
 * encoded ML-KEM-768 public key followed by the client's X25519 public key,
 * with no trailing data. The share must not already hold any hybrid state.
 * Once that check passes, every failure (including allocation) is reported
 * through *decode_error.
 */
static int
tls_key_share_server_peer_public_mlkem768x25519(struct tls_key_share *ks,
    CBS *cbs, int *decode_error)
{
	CBS mlkem_cbs, x25519_cbs;
	size_t stowed_len;

	*decode_error = 0;

	if (ks->mlkem_private != NULL || ks->mlkem_shared_secret != NULL ||
	    ks->mlkem_peer_public != NULL || ks->x25519_peer_public != NULL)
		return 0;

	ks->mlkem_peer_public = MLKEM_public_key_new(MLKEM768_RANK);

	if (ks->mlkem_peer_public != NULL &&
	    CBS_get_bytes(cbs, &mlkem_cbs,
	    MLKEM_public_key_encoded_length(ks->mlkem_peer_public)) &&
	    CBS_get_bytes(cbs, &x25519_cbs, X25519_KEY_LENGTH) &&
	    CBS_len(cbs) == 0 &&
	    CBS_stow(&x25519_cbs, &ks->x25519_peer_public, &stowed_len) &&
	    MLKEM_parse_public_key(ks->mlkem_peer_public,
	    CBS_data(&mlkem_cbs), CBS_len(&mlkem_cbs)))
		return 1;

	*decode_error = 1;
	return 0;
}
/*
 * Parse the peer's public key for the non-hybrid algorithms. The error
 * out-parameters are cleared first (invalid_key may be NULL). Anything not
 * DHE or X25519 is handled as an ECP curve.
 */
static int
tls_key_share_peer_public(struct tls_key_share *ks, CBS *cbs, int *decode_error,
    int *invalid_key)
{
	int nid = ks->nid;

	*decode_error = 0;
	if (invalid_key != NULL)
		*invalid_key = 0;

	if (nid == NID_dhKeyAgreement)
		return tls_key_share_peer_public_dhe(ks, cbs, decode_error,
		    invalid_key);
	if (nid == NID_X25519)
		return tls_key_share_peer_public_x25519(ks, cbs, decode_error);

	return tls_key_share_peer_public_ecdhe_ecp(ks, cbs);
}
/*
 * Client side: parse the server's key share.
 *
 * The error out-parameters are reset before dispatching, so they are
 * defined on every path. This matters for the hybrid parser, which never
 * receives invalid_key. invalid_key may be NULL.
 */
int
tls_key_share_client_peer_public(struct tls_key_share *ks, CBS *cbs,
    int *decode_error, int *invalid_key)
{
	*decode_error = 0;
	if (invalid_key != NULL)
		*invalid_key = 0;

	if (ks->nid == NID_X25519MLKEM768)
		return tls_key_share_client_peer_public_mlkem768x25519(ks, cbs,
		    decode_error);

	return tls_key_share_peer_public(ks, cbs, decode_error, invalid_key);
}
/*
 * Server side: parse the client's key share.
 *
 * The error out-parameters are reset before dispatching, so they are
 * defined on every path. This matters for the hybrid parser, which never
 * receives invalid_key. invalid_key may be NULL.
 */
int
tls_key_share_server_peer_public(struct tls_key_share *ks, CBS *cbs,
    int *decode_error, int *invalid_key)
{
	*decode_error = 0;
	if (invalid_key != NULL)
		*invalid_key = 0;

	if (ks->nid == NID_X25519MLKEM768)
		return tls_key_share_server_peer_public_mlkem768x25519(ks, cbs,
		    decode_error);

	return tls_key_share_peer_public(ks, cbs, decode_error, invalid_key);
}
/* Derive the DH shared secret; both our key and the peer's key must exist. */
static int
tls_key_share_derive_dhe(struct tls_key_share *ks,
    uint8_t **shared_key, size_t *shared_key_len)
{
	if (ks->dhe != NULL && ks->dhe_peer != NULL)
		return ssl_kex_derive_dhe(ks->dhe, ks->dhe_peer, shared_key,
		    shared_key_len);

	return 0;
}
/* Derive the ECDH shared secret; both our key and the peer's key must exist. */
static int
tls_key_share_derive_ecdhe_ecp(struct tls_key_share *ks,
    uint8_t **shared_key, size_t *shared_key_len)
{
	if (ks->ecdhe != NULL && ks->ecdhe_peer != NULL)
		return ssl_kex_derive_ecdhe_ecp(ks->ecdhe, ks->ecdhe_peer,
		    shared_key, shared_key_len);

	return 0;
}
/*
 * Derive the X25519 shared secret into a newly allocated buffer of
 * X25519_KEY_LENGTH bytes owned by the caller. If X25519() reports failure,
 * the buffer is zeroed and freed and 0 is returned.
 */
static int
tls_key_share_derive_x25519(struct tls_key_share *ks,
    uint8_t **shared_key, size_t *shared_key_len)
{
	uint8_t *secret;

	if (ks->x25519_private == NULL || ks->x25519_peer_public == NULL)
		return 0;

	if ((secret = calloc(1, X25519_KEY_LENGTH)) == NULL)
		return 0;
	if (!X25519(secret, ks->x25519_private, ks->x25519_peer_public)) {
		freezero(secret, X25519_KEY_LENGTH);
		return 0;
	}

	*shared_key = secret;
	*shared_key_len = X25519_KEY_LENGTH;

	return 1;
}
/*
 * Derive the X25519MLKEM768 shared secret as the ML-KEM shared secret
 * followed by the X25519 shared secret, in a newly allocated buffer.
 *
 * The server already holds the ML-KEM secret from MLKEM_encap(). The client
 * recovers it here by decapsulating the ciphertext it stored in mlkem_encap,
 * and caches the result on the share.
 */
static int
tls_key_share_derive_mlkem768x25519(struct tls_key_share *ks,
    uint8_t **out_shared_key, size_t *out_shared_key_len)
{
	uint8_t *x25519_part;
	CBB out;

	memset(&out, 0, sizeof(out));

	if (ks->x25519_private == NULL || ks->x25519_peer_public == NULL)
		goto fail;

	if (ks->mlkem_shared_secret == NULL) {
		if (ks->mlkem_private == NULL || ks->mlkem_encap == NULL)
			goto fail;
		if (!MLKEM_decap(ks->mlkem_private, ks->mlkem_encap,
		    MLKEM_private_key_ciphertext_length(ks->mlkem_private),
		    &ks->mlkem_shared_secret, &ks->mlkem_shared_secret_len))
			goto fail;
	}

	if (!CBB_init(&out, ks->mlkem_shared_secret_len + X25519_KEY_LENGTH))
		goto fail;
	if (!CBB_add_bytes(&out, ks->mlkem_shared_secret,
	    ks->mlkem_shared_secret_len))
		goto fail;

	/* Compute the X25519 secret directly into space reserved in the CBB. */
	if (!CBB_add_space(&out, &x25519_part, X25519_KEY_LENGTH))
		goto fail;
	if (!X25519(x25519_part, ks->x25519_private, ks->x25519_peer_public))
		goto fail;

	if (!CBB_finish(&out, out_shared_key, out_shared_key_len))
		goto fail;

	return 1;

 fail:
	CBB_cleanup(&out);
	return 0;
}
/*
 * Derive the shared secret for this share into a newly allocated buffer.
 * *shared_key must be NULL on entry; a non-NULL output pointer is refused
 * rather than overwritten. *shared_key_len is reset to 0 before derivation.
 * Anything not DHE, X25519 or the hybrid group is treated as an ECP curve.
 */
int
tls_key_share_derive(struct tls_key_share *ks, uint8_t **shared_key,
    size_t *shared_key_len)
{
	int nid = ks->nid;

	if (*shared_key != NULL)
		return 0;

	*shared_key_len = 0;

	if (nid == NID_dhKeyAgreement)
		return tls_key_share_derive_dhe(ks, shared_key, shared_key_len);
	if (nid == NID_X25519)
		return tls_key_share_derive_x25519(ks, shared_key,
		    shared_key_len);
	if (nid == NID_X25519MLKEM768)
		return tls_key_share_derive_mlkem768x25519(ks, shared_key,
		    shared_key_len);

	return tls_key_share_derive_ecdhe_ecp(ks, shared_key, shared_key_len);
}
/*
 * Apply the SSL security policy to the peer's key. Only finite field DHE is
 * checked, via ssl_security_dh(); every other algorithm returns 0.
 */
int
tls_key_share_peer_security(const SSL *ssl, struct tls_key_share *ks)
{
	if (ks->nid != NID_dhKeyAgreement)
		return 0;

	return ssl_security_dh(ssl, ks->dhe_peer);
}