#include <security/cryptoki.h>
#include <security/pkcs11.h>
#include <netsmb/nsmb_kcrypt.h>
#include <sys/cmn_err.h>
#include <sys/debug.h>
#include <sys/stream.h>
#include <sys/strsun.h>
#include <umem.h>
#include <strings.h>
size_t msgsize(mblk_t *);
static int copy_mblks(void *buf, size_t buflen, enum uio_rw, mblk_t *m);
/*
 * Check whether a PKCS#11 session supporting mechanism `id' can be
 * obtained.  Any session that is opened is closed again right away;
 * only the availability of the mechanism is of interest here.
 *
 * Returns 0 if SUNW_C_GetMechSession() succeeds, -1 otherwise.
 */
static int
find_mech(CK_MECHANISM_TYPE id)
{
	CK_SESSION_HANDLE session;

	if (SUNW_C_GetMechSession(id, &session) != CKR_OK)
		return (-1);

	(void) C_CloseSession(session);
	return (0);
}
/*
 * Select AES-CCM as the mechanism in `mech' if the PKCS#11 library
 * provides it.
 *
 * Returns 0 on success (mech->mechanism set to CKM_AES_CCM).  Returns -1,
 * after logging a notice, when the mechanism is unavailable; `mech' is
 * left untouched in that case.
 */
int
nsmb_aes_ccm_getmech(smb_crypto_mech_t *mech)
{
	if (find_mech(CKM_AES_CCM) == 0) {
		mech->mechanism = CKM_AES_CCM;
		return (0);
	}

	cmn_err(CE_NOTE, "PKCS#11: no mech AES_CCM");
	return (-1);
}
/*
 * Select AES-GCM as the mechanism in `mech' if the PKCS#11 library
 * provides it.
 *
 * Returns 0 on success (mech->mechanism set to CKM_AES_GCM).  Returns -1,
 * after logging a notice, when the mechanism is unavailable; `mech' is
 * left untouched in that case.
 */
int
nsmb_aes_gcm_getmech(smb_crypto_mech_t *mech)
{
	if (find_mech(CKM_AES_GCM) == 0) {
		mech->mechanism = CKM_AES_GCM;
		return (0);
	}

	cmn_err(CE_NOTE, "PKCS#11: no mech AES_GCM");
	return (-1);
}
/*
 * Fill in the AES-CCM parameter block kept inside `ctx' and attach it to
 * ctx->mech.
 *
 * nonce/noncesize:  nonce buffer; only its first SMB3_AES_CCM_NONCE_SIZE
 *                   bytes are used (noncesize must be at least that).
 * auth/authsize:    additional authenticated data.
 * datasize:         payload length recorded in the parameters.
 *
 * The tag (MAC) length is fixed at SMB2_SIG_SIZE bytes.  The buffers are
 * referenced, not copied, so they must stay valid while ctx->mech is used.
 */
void
nsmb_crypto_init_ccm_param(smb_enc_ctx_t *ctx,
    uint8_t *nonce, size_t noncesize,
    uint8_t *auth, size_t authsize,
    size_t datasize)
{
	ASSERT3U(noncesize, >=, SMB3_AES_CCM_NONCE_SIZE);

	/* Point the mechanism at the CCM parameters stored in ctx. */
	ctx->mech.pParameter = (caddr_t)&ctx->param.ccm;
	ctx->mech.ulParameterLen = sizeof (ctx->param.ccm);

	/* Nonce (truncated to the SMB3 CCM nonce length). */
	ctx->param.ccm.pNonce = nonce;
	ctx->param.ccm.ulNonceLen = SMB3_AES_CCM_NONCE_SIZE;

	/* Additional authenticated data. */
	ctx->param.ccm.pAAD = auth;
	ctx->param.ccm.ulAADLen = authsize;

	/* Payload length and tag length. */
	ctx->param.ccm.ulDataLen = datasize;
	ctx->param.ccm.ulMACLen = SMB2_SIG_SIZE;
}
/*
 * Fill in the AES-GCM parameter block kept inside `ctx' and attach it to
 * ctx->mech.
 *
 * nonce/noncesize:  IV buffer; only its first SMB3_AES_GCM_NONCE_SIZE
 *                   bytes are used (noncesize must be at least that).
 * auth/authsize:    additional authenticated data.
 *
 * The tag length is SMB2_SIG_SIZE bytes, expressed here in bits.  The
 * buffers are referenced, not copied, so they must stay valid while
 * ctx->mech is used.
 */
void
nsmb_crypto_init_gcm_param(smb_enc_ctx_t *ctx,
    uint8_t *nonce, size_t noncesize,
    uint8_t *auth, size_t authsize)
{
	ASSERT3U(noncesize, >=, SMB3_AES_GCM_NONCE_SIZE);

	/* Point the mechanism at the GCM parameters stored in ctx. */
	ctx->mech.pParameter = (caddr_t)&ctx->param.gcm;
	ctx->mech.ulParameterLen = sizeof (ctx->param.gcm);

	/* IV (truncated to the SMB3 GCM nonce length). */
	ctx->param.gcm.pIv = nonce;
	ctx->param.gcm.ulIvLen = SMB3_AES_GCM_NONCE_SIZE;

	/* Additional authenticated data. */
	ctx->param.gcm.pAAD = auth;
	ctx->param.gcm.ulAADLen = authsize;

	/* Tag length in bits (bytes * 8). */
	ctx->param.gcm.ulTagBits = SMB2_SIG_SIZE << 3;
}
/*
 * Prepare `ctxp' for encryption with the raw key `key' (keylen bytes).
 *
 * The steps are:
 *  1. open a PKCS#11 session for ctxp->mech.mechanism into ctxp->ctx;
 *  2. wrap the key in a temporary key object;
 *  3. start an encrypt operation with ctxp->mech;
 *  4. destroy the temporary key object.
 *
 * Returns 0 on success and -1 on any failure.  Only a C_EncryptInit
 * failure is logged.  The session in ctxp->ctx is not closed here on
 * failure; nsmb_enc_ctx_done() is the routine that closes it
 * (NOTE(review): presumably the caller always invokes it).
 */
int
nsmb_encrypt_init(smb_enc_ctx_t *ctxp,
    uint8_t *key, size_t keylen)
{
	CK_MECHANISM *mechp = &ctxp->mech;
	CK_OBJECT_HANDLE keyobj = 0;
	CK_RV status;
	int ret = -1;

	if (SUNW_C_GetMechSession(mechp->mechanism, &ctxp->ctx) != CKR_OK)
		return (ret);

	if (SUNW_C_KeyToObject(ctxp->ctx, mechp->mechanism,
	    key, keylen, &keyobj) != CKR_OK)
		return (ret);

	status = C_EncryptInit(ctxp->ctx, mechp, keyobj);
	if (status == CKR_OK)
		ret = 0;
	else
		cmn_err(CE_WARN, "C_EncryptInit failed: 0x%lx", status);

	/*
	 * The key object is dropped immediately after init, whether or not
	 * init succeeded.
	 */
	(void) C_DestroyObject(ctxp->ctx, keyobj);
	return (ret);
}
/*
 * Prepare `ctxp' for decryption with the raw key `key' (keylen bytes).
 *
 * The steps are:
 *  1. open a PKCS#11 session for ctxp->mech.mechanism into ctxp->ctx;
 *  2. wrap the key in a temporary key object;
 *  3. start a decrypt operation with ctxp->mech;
 *  4. destroy the temporary key object.
 *
 * Returns 0 on success and -1 on any failure.  Only a C_DecryptInit
 * failure is logged.  The session in ctxp->ctx is not closed here on
 * failure; nsmb_enc_ctx_done() is the routine that closes it
 * (NOTE(review): presumably the caller always invokes it).
 */
int
nsmb_decrypt_init(smb_enc_ctx_t *ctxp,
    uint8_t *key, size_t keylen)
{
	CK_MECHANISM *mechp = &ctxp->mech;
	CK_OBJECT_HANDLE keyobj = 0;
	CK_RV status;
	int ret = -1;

	if (SUNW_C_GetMechSession(mechp->mechanism, &ctxp->ctx) != CKR_OK)
		return (ret);

	if (SUNW_C_KeyToObject(ctxp->ctx, mechp->mechanism,
	    key, keylen, &keyobj) != CKR_OK)
		return (ret);

	status = C_DecryptInit(ctxp->ctx, mechp, keyobj);
	if (status == CKR_OK)
		ret = 0;
	else
		cmn_err(CE_WARN, "C_DecryptInit failed: 0x%lx", status);

	/*
	 * The key object is dropped immediately after init, whether or not
	 * init succeeded.
	 */
	(void) C_DestroyObject(ctxp->ctx, keyobj);
	return (ret);
}
/*
 * Release the PKCS#11 session held in ctxp->ctx, if there is one, and
 * reset the handle to 0.  Calling this again afterwards does nothing.
 */
void
nsmb_enc_ctx_done(smb_enc_ctx_t *ctxp)
{
	if (ctxp->ctx == 0)
		return;

	(void) C_CloseSession(ctxp->ctx);
	ctxp->ctx = 0;
}
/*
 * Encrypt the first `clearlen' bytes of the mblk chain `mp' in place.
 *
 * The cleartext is gathered into a temporary buffer and encrypted with
 * the operation set up by nsmb_encrypt_init().  The clearlen +
 * SMB2_SIG_SIZE bytes of mechanism output are then scattered back over
 * the start of the chain, which must therefore hold at least that many
 * bytes (asserted).
 *
 * Returns 0 on success and -1 on failure.  C_Encrypt errors and
 * unexpected output lengths are logged.
 */
int
nsmb_encrypt_mblks(smb_enc_ctx_t *ctxp, mblk_t *mp, size_t clearlen)
{
	size_t bufsz = clearlen + SMB2_SIG_SIZE;
	ulong_t produced;
	uint8_t *work;
	CK_RV rv;
	int rc;

	ASSERT(msgsize(mp) >= bufsz);

	work = umem_alloc(bufsz, 0);
	if (work == NULL)
		return (-1);

	rc = copy_mblks(work, clearlen, UIO_WRITE, mp);
	if (rc == 0) {
		/* The output (with tag) overwrites the input in the same buffer. */
		produced = bufsz;
		rv = C_Encrypt(ctxp->ctx, work, clearlen, work, &produced);
		if (rv != CKR_OK) {
			cmn_err(CE_WARN, "C_Encrypt failed: 0x%lx", rv);
			rc = -1;
		} else if (produced != bufsz) {
			cmn_err(CE_WARN, "nsmb_encrypt_mblks outlen %d vs %d",
			    (int)produced, (int)bufsz);
			rc = -1;
		} else {
			rc = copy_mblks(work, bufsz, UIO_READ, mp);
		}
	}

	umem_free(work, bufsz);
	return (rc);
}
/*
 * Decrypt the first `cipherlen' bytes of the mblk chain `mp' in place.
 *
 * The input covers the ciphertext plus its SMB2_SIG_SIZE-byte tag, so
 * cipherlen must exceed SMB2_SIG_SIZE.  The input is gathered into a
 * temporary buffer and decrypted with the operation set up by
 * nsmb_decrypt_init().  The cipherlen - SMB2_SIG_SIZE bytes of plaintext
 * are scattered back over the start of the chain.
 *
 * Returns 0 on success and -1 on failure.  C_Decrypt errors and
 * unexpected output lengths are logged.
 */
int
nsmb_decrypt_mblks(smb_enc_ctx_t *ctxp, mblk_t *mp, size_t cipherlen)
{
	size_t plainlen;
	ulong_t produced;
	uint8_t *work;
	CK_RV rv;
	int rc;

	if (cipherlen <= SMB2_SIG_SIZE)
		return (-1);
	plainlen = cipherlen - SMB2_SIG_SIZE;

	ASSERT(msgsize(mp) >= cipherlen);

	work = umem_alloc(cipherlen, 0);
	if (work == NULL)
		return (-1);

	rc = copy_mblks(work, cipherlen, UIO_WRITE, mp);
	if (rc == 0) {
		/* The plaintext overwrites the input in the same buffer. */
		produced = plainlen;
		rv = C_Decrypt(ctxp->ctx, work, cipherlen, work, &produced);
		if (rv != CKR_OK) {
			cmn_err(CE_WARN, "C_Decrypt failed: 0x%lx", rv);
			rc = -1;
		} else if (produced != plainlen) {
			cmn_err(CE_WARN, "nsmb_decrypt_mblks outlen %d vs %d",
			    (int)produced, (int)plainlen);
			rc = -1;
		} else {
			rc = copy_mblks(work, plainlen, UIO_READ, mp);
		}
	}

	umem_free(work, cipherlen);
	return (rc);
}
/*
 * Copy `buflen' bytes between the flat buffer `buf' and the data in the
 * M_DATA mblk chain starting at `m', walking b_cont links.
 *
 *   rw == UIO_READ:  buf -> mblks (scatter)
 *   otherwise:       mblks -> buf (gather)
 *
 * Returns 0 on success.  Returns -1 if the chain runs out before buflen
 * bytes have been moved; a partial copy may already have happened then.
 */
static int
copy_mblks(void *buf, size_t buflen, enum uio_rw rw, mblk_t *m)
{
	uchar_t *cur = buf;
	size_t left;

	for (left = buflen; left != 0; m = m->b_cont) {
		size_t avail, chunk;

		if (m == NULL)
			return (-1);
		ASSERT(m->b_datap->db_type == M_DATA);

		avail = MBLKL(m);
		chunk = (avail < left) ? avail : left;

		switch (rw) {
		case UIO_READ:
			bcopy(cur, m->b_rptr, chunk);
			break;
		default:
			bcopy(m->b_rptr, cur, chunk);
			break;
		}

		cur += chunk;
		left -= chunk;
	}

	return (0);
}
/*
 * Return the total number of data bytes (b_wptr - b_rptr) across the
 * mblk chain `mp', following b_cont.  A NULL chain yields 0.
 */
size_t
msgsize(mblk_t *mp)
{
	size_t total = 0;

	while (mp != NULL) {
		total += MBLKL(mp);
		mp = mp->b_cont;
	}

	return (total);
}