#ifndef OPENSSL_NO_SM2
#include <limits.h>
#include <string.h>
#include <openssl/sm2.h>
#include <openssl/asn1t.h>
#include <openssl/x509.h>
#include <openssl/evp.h>
#include "err_local.h"
#include "evp_local.h"
#include "sm2_local.h"
/*
 * Per-context state of the SM2 EVP_PKEY method; hung off EVP_PKEY_CTX->data
 * by pkey_sm2_init() and released by pkey_sm2_cleanup().
 */
typedef struct sm2_pkey_ctx_st {
	/* Group chosen via EVP_PKEY_CTRL_EC_PARAMGEN_CURVE_NID; owned. */
	EC_GROUP *gen_group;
	/* Digest from EVP_PKEY_CTRL_MD; encrypt/decrypt use SM3 when NULL. */
	const EVP_MD *md;
	/* Set from EVP_PKEY_CTRL_DIGESTINIT's p2; borrowed, never freed here. */
	EVP_MD_CTX *md_ctx;
	/* Private copy of the user identifier (EVP_PKEY_CTRL_SM2_SET_UID). */
	uint8_t *uid;
	/* Number of bytes in uid; 0 when no uid is set. */
	size_t uid_len;
} SM2_PKEY_CTX;
/*
 * Allocate zeroed SM2 method state and attach it to ctx->data.
 * Returns 1 on success, 0 (with an SM2 error queued) on allocation failure.
 */
static int
pkey_sm2_init(EVP_PKEY_CTX *ctx)
{
	SM2_PKEY_CTX *state;

	state = calloc(1, sizeof(*state));
	if (state == NULL) {
		SM2error(ERR_R_MALLOC_FAILURE);
		return 0;
	}

	ctx->data = state;
	return 1;
}
/*
 * Release the SM2 method state attached to ctx, if any, and clear ctx->data.
 *
 * Safe to call with a NULL ctx or a ctx whose data is already NULL. The
 * state pointer is only read after that check; the original read ctx->data
 * before testing ctx for NULL.
 *
 * md_ctx is deliberately not freed: it is only ever assigned from the
 * caller-supplied pointer in EVP_PKEY_CTRL_DIGESTINIT and is not owned here.
 */
static void
pkey_sm2_cleanup(EVP_PKEY_CTX *ctx)
{
	SM2_PKEY_CTX *dctx;

	if (ctx == NULL || ctx->data == NULL)
		return;
	dctx = ctx->data;

	EC_GROUP_free(dctx->gen_group);
	free(dctx->uid);
	free(dctx);
	ctx->data = NULL;
}
/*
 * Duplicate the SM2 method state of src into dst.
 *
 * Deep-copies the paramgen group and the user identifier, and shares the
 * (const) digest pointer. Returns 1 on success; on failure everything
 * allocated for dst is released and 0 is returned.
 *
 * md_ctx is intentionally NOT copied. It is a borrowed back-pointer set by
 * EVP_PKEY_CTRL_DIGESTINIT. After pkey_sm2_init() dctx->md_ctx is NULL, so
 * the previous EVP_MD_CTX_copy(dctx->md_ctx, ...) handed a NULL destination
 * to the copy. Aliasing sctx->md_ctx instead would make
 * EVP_PKEY_CTRL_SM2_HASH_UID on the copy feed the source's digest. The copy
 * keeps md_ctx NULL until a later EVP_PKEY_CTRL_DIGESTINIT supplies one.
 */
static int
pkey_sm2_copy(EVP_PKEY_CTX *dst, EVP_PKEY_CTX *src)
{
	SM2_PKEY_CTX *dctx, *sctx;

	if (!pkey_sm2_init(dst))
		return 0;
	sctx = src->data;
	dctx = dst->data;

	if (sctx->gen_group != NULL) {
		if ((dctx->gen_group = EC_GROUP_dup(sctx->gen_group)) == NULL) {
			SM2error(ERR_R_MALLOC_FAILURE);
			goto err;
		}
	}

	if (sctx->uid != NULL) {
		if ((dctx->uid = malloc(sctx->uid_len)) == NULL) {
			SM2error(ERR_R_MALLOC_FAILURE);
			goto err;
		}
		memcpy(dctx->uid, sctx->uid, sctx->uid_len);
		dctx->uid_len = sctx->uid_len;
	}

	dctx->md = sctx->md;

	return 1;

 err:
	pkey_sm2_cleanup(dst);
	return 0;
}
/*
 * Sign tbs with the EC key of ctx using SM2_sign().
 *
 * When sig is NULL, only the upper bound reported by ECDSA_size() is stored
 * in *siglen. Otherwise *siglen must be at least that bound. On success
 * *siglen receives the actual signature length and 1 is returned. Failures
 * return 0, or SM2_sign()'s own non-positive result.
 */
static int
pkey_sm2_sign(EVP_PKEY_CTX *ctx, unsigned char *sig, size_t *siglen,
    const unsigned char *tbs, size_t tbslen)
{
	EC_KEY *eckey = ctx->pkey->pkey.ec;
	unsigned int written;
	int max_len, rv;

	max_len = ECDSA_size(eckey);
	if (max_len <= 0)
		return 0;

	if (sig != NULL) {
		if ((size_t)max_len > *siglen) {
			SM2error(SM2_R_BUFFER_TOO_SMALL);
			return 0;
		}
		rv = SM2_sign(tbs, tbslen, sig, &written, eckey);
		if (rv <= 0)
			return rv;
		*siglen = (size_t)written;
		return 1;
	}

	/* Length query only. */
	*siglen = max_len;
	return 1;
}
/*
 * Verify an SM2 signature over tbs with the EC key of ctx.
 * The result of SM2_verify() is returned unchanged.
 */
static int
pkey_sm2_verify(EVP_PKEY_CTX *ctx, const unsigned char *sig, size_t siglen,
    const unsigned char *tbs, size_t tbslen)
{
	EC_KEY *eckey = ctx->pkey->pkey.ec;
	int verdict;

	verdict = SM2_verify(tbs, tbslen, sig, siglen, eckey);
	return verdict;
}
/*
 * SM2-encrypt in[0..inlen) with the EC key of ctx.
 *
 * The digest is the one set with EVP_PKEY_CTRL_MD, or SM3 if none was set.
 * When out is NULL, only the ciphertext size is computed into *outlen:
 * 1 on success, -1 on failure. Otherwise the result of SM2_encrypt() is
 * returned.
 */
static int
pkey_sm2_encrypt(EVP_PKEY_CTX *ctx, unsigned char *out, size_t *outlen,
    const unsigned char *in, size_t inlen)
{
	SM2_PKEY_CTX *state = ctx->data;
	const EVP_MD *digest;

	digest = state->md != NULL ? state->md : EVP_sm3();

	if (out != NULL)
		return SM2_encrypt(ctx->pkey->pkey.ec, digest, in, inlen, out,
		    outlen);

	/* Size query only. */
	return SM2_ciphertext_size(ctx->pkey->pkey.ec, digest, inlen, outlen) ?
	    1 : -1;
}
/*
 * SM2-decrypt in[0..inlen) with the EC key of ctx.
 *
 * The digest is the one set with EVP_PKEY_CTRL_MD, or SM3 if none was set.
 * When out is NULL, only the plaintext size is computed into *outlen:
 * 1 on success, -1 on failure. Otherwise the result of SM2_decrypt() is
 * returned.
 *
 * NOTE(review): the size query is derived from inlen only, not from the
 * parsed ciphertext. Confirm that SM2_decrypt() bounds its write by *outlen
 * before relying on caller-sized buffers.
 */
static int
pkey_sm2_decrypt(EVP_PKEY_CTX *ctx, unsigned char *out, size_t *outlen,
    const unsigned char *in, size_t inlen)
{
	SM2_PKEY_CTX *state = ctx->data;
	const EVP_MD *digest;

	digest = state->md != NULL ? state->md : EVP_sm3();

	if (out != NULL)
		return SM2_decrypt(ctx->pkey->pkey.ec, digest, in, inlen, out,
		    outlen);

	/* Size query only. */
	return SM2_plaintext_size(ctx->pkey->pkey.ec, digest, inlen, outlen) ?
	    1 : -1;
}
/*
 * Control handler for the SM2 EVP_PKEY method.
 *
 * Returns 1 on success and 0 on failure, with an SM2 error queued. For
 * EVP_PKEY_CTRL_SM2_HASH_UID it returns the result of EVP_DigestUpdate().
 * Unsupported control types return -2.
 */
static int
pkey_sm2_ctrl(EVP_PKEY_CTX *ctx, int type, int p1, void *p2)
{
	SM2_PKEY_CTX *dctx = ctx->data;
	EC_GROUP *group = NULL;
	uint8_t *uid = NULL;

	switch (type) {
	case EVP_PKEY_CTRL_DIGESTINIT:
		/* Borrowed pointer to the digest context; not owned. */
		dctx->md_ctx = p2;
		return 1;

	case EVP_PKEY_CTRL_EC_PARAMGEN_CURVE_NID:
		if ((group = EC_GROUP_new_by_curve_name(p1)) == NULL) {
			SM2error(SM2_R_INVALID_CURVE);
			return 0;
		}
		EC_GROUP_free(dctx->gen_group);
		dctx->gen_group = group;
		return 1;

	case EVP_PKEY_CTRL_SM2_SET_UID:
		/*
		 * p1 is the uid length and p2 the uid bytes. Either both are
		 * set (p1 > 0, p2 != NULL), or p1 == 0 and p2 == NULL, which
		 * clears the uid.
		 */
		if ((p1 < 0) || ((p1 == 0) && (p2 != NULL))) {
			SM2error(SM2_R_INVALID_ARGUMENT);
			return 0;
		}
		if ((p1 > 0) && (p2 == NULL)) {
			SM2error(ERR_R_PASSED_NULL_PARAMETER);
			return 0;
		}
		if (p2 != NULL) {
			/*
			 * Allocate before releasing the old uid. An allocation
			 * failure then leaves uid/uid_len consistent and is
			 * reported as failure. Previously it returned 1 with
			 * uid == NULL and a stale uid_len.
			 */
			if ((uid = malloc(p1)) == NULL) {
				SM2error(ERR_R_MALLOC_FAILURE);
				return 0;
			}
			memcpy(uid, p2, p1);
		}
		free(dctx->uid);
		dctx->uid = uid;
		/* p1 is 0 whenever p2 (and hence uid) is NULL. */
		dctx->uid_len = (size_t)p1;
		return 1;

	case EVP_PKEY_CTRL_SM2_HASH_UID:
	{
		/* Feed the user-id digest (ZA) into the borrowed md_ctx. */
		const EVP_MD *md;
		uint8_t za[EVP_MAX_MD_SIZE] = {0};
		int md_len;

		if (dctx->uid == NULL) {
			SM2error(SM2_R_INVALID_ARGUMENT);
			return 0;
		}
		/* md_ctx is NULL until EVP_PKEY_CTRL_DIGESTINIT supplies it. */
		if (dctx->md_ctx == NULL ||
		    (md = EVP_MD_CTX_md(dctx->md_ctx)) == NULL) {
			SM2error(ERR_R_EVP_LIB);
			return 0;
		}
		if ((md_len = EVP_MD_size(md)) < 0) {
			SM2error(SM2_R_INVALID_DIGEST);
			return 0;
		}
		if (sm2_compute_userid_digest(za, md, dctx->uid, dctx->uid_len,
		    ctx->pkey->pkey.ec) != 1) {
			SM2error(SM2_R_DIGEST_FAILURE);
			return 0;
		}
		return EVP_DigestUpdate(dctx->md_ctx, za, md_len);
	}

	case EVP_PKEY_CTRL_SM2_GET_UID_LEN:
		if (p2 == NULL) {
			SM2error(ERR_R_PASSED_NULL_PARAMETER);
			return 0;
		}
		*(size_t *)p2 = dctx->uid_len;
		return 1;

	case EVP_PKEY_CTRL_SM2_GET_UID:
		/*
		 * Copies uid_len bytes into p2; no buffer size is passed in,
		 * so the caller must provide at least uid_len bytes.
		 */
		if (p2 == NULL) {
			SM2error(ERR_R_PASSED_NULL_PARAMETER);
			return 0;
		}
		if (dctx->uid_len == 0)
			return 1;
		memcpy(p2, dctx->uid, dctx->uid_len);
		return 1;

	case EVP_PKEY_CTRL_MD:
		dctx->md = p2;
		return 1;

	default:
		return -2;
	}
}
/*
 * String control handler for the SM2 EVP_PKEY method.
 *
 * Supported names:
 *   "ec_paramgen_curve"  value is a NIST name, short name or long name of a
 *                        curve.
 *   "sm2_uid"            value is used verbatim (without the terminating
 *                        NUL) as the user identifier.
 *
 * Returns the result of the underlying control call, 0 on a bad value, or
 * -2 for an unknown name. A NULL value is rejected rather than handed to
 * strlen()/strcmp(). A uid longer than INT_MAX is rejected rather than
 * silently truncated by the int length parameter.
 */
static int
pkey_sm2_ctrl_str(EVP_PKEY_CTX *ctx, const char *type, const char *value)
{
	size_t uid_len;
	int nid;

	if (strcmp(type, "ec_paramgen_curve") == 0) {
		if (value == NULL) {
			SM2error(ERR_R_PASSED_NULL_PARAMETER);
			return 0;
		}
		if (((nid = EC_curve_nist2nid(value)) == NID_undef) &&
		    ((nid = OBJ_sn2nid(value)) == NID_undef) &&
		    ((nid = OBJ_ln2nid(value)) == NID_undef)) {
			SM2error(SM2_R_INVALID_CURVE);
			return 0;
		}
		return EVP_PKEY_CTX_set_ec_paramgen_curve_nid(ctx, nid);
	}

	if (strcmp(type, "sm2_uid") == 0) {
		if (value == NULL) {
			SM2error(ERR_R_PASSED_NULL_PARAMETER);
			return 0;
		}
		if ((uid_len = strlen(value)) > INT_MAX) {
			SM2error(SM2_R_INVALID_ARGUMENT);
			return 0;
		}
		return EVP_PKEY_CTX_set_sm2_uid(ctx, (void *)value,
		    (int)uid_len);
	}

	return -2;
}
/*
 * EVP_PKEY_METHOD table wiring the SM2 handlers above to EVP_PKEY_SM2.
 * Callbacks not listed here are left zero-initialized.
 */
const EVP_PKEY_METHOD sm2_pkey_meth = {
	.pkey_id = EVP_PKEY_SM2,

	/* Context lifecycle. */
	.init = pkey_sm2_init,
	.copy = pkey_sm2_copy,
	.cleanup = pkey_sm2_cleanup,

	/* Parameter control. */
	.ctrl = pkey_sm2_ctrl,
	.ctrl_str = pkey_sm2_ctrl_str,

	/* Signatures. */
	.sign = pkey_sm2_sign,
	.verify = pkey_sm2_verify,

	/* Public-key encryption. */
	.encrypt = pkey_sm2_encrypt,
	.decrypt = pkey_sm2_decrypt,
};
#endif