#include <openssl/core_names.h>
#include "prov/ciphercommon.h"
#include "internal/nelem.h"
#include "cipher_cts.h"
/* CTS variant ids; compared against ctx->cts_mode in ossl_cipher_cbc_cts_block_update(). */
#define CTS_CS1 0
#define CTS_CS2 1
#define CTS_CS3 2

/* Block size, in bytes, that every CTS helper in this file works in. */
#define CTS_BLOCK_SIZE 16

/*
 * One CTS_BLOCK_SIZE-byte scratch block, accessed through |c|.
 * |align| is not read or written by the functions below; presumably it is
 * there only to give the byte array size_t alignment.
 */
typedef union {
    size_t align;
    unsigned char c[CTS_BLOCK_SIZE];
} aligned_16bytes;
/* Pairs a CTS_CS* id with its name string (OSSL_CIPHER_CTS_MODE_*). */
typedef struct cts_mode_name2id_st {
    unsigned int id;
    const char *name;
} CTS_MODE_NAME2ID;

/*
 * Id <-> name lookup table used by ossl_cipher_cbc_cts_mode_id2name() and
 * ossl_cipher_cbc_cts_mode_name2id(). It is only ever read, so it is const.
 */
static const CTS_MODE_NAME2ID cts_modes[] = {
    { CTS_CS1, OSSL_CIPHER_CTS_MODE_CS1 },
    { CTS_CS2, OSSL_CIPHER_CTS_MODE_CS2 },
    { CTS_CS3, OSSL_CIPHER_CTS_MODE_CS3 },
};
const char *ossl_cipher_cbc_cts_mode_id2name(unsigned int id)
{
size_t i;
for (i = 0; i < OSSL_NELEM(cts_modes); ++i) {
if (cts_modes[i].id == id)
return cts_modes[i].name;
}
return NULL;
}
/*
 * Map a CTS mode name to its CTS_CS* id, comparing case-insensitively with
 * OPENSSL_strcasecmp().
 * Returns the id (as int) on success, or -1 if |name| is NULL or unknown.
 */
int ossl_cipher_cbc_cts_mode_name2id(const char *name)
{
    size_t i;

    /* Reject a missing name before handing it to OPENSSL_strcasecmp(). */
    if (name == NULL)
        return -1;

    for (i = 0; i < OSSL_NELEM(cts_modes); ++i) {
        if (OPENSSL_strcasecmp(name, cts_modes[i].name) == 0)
            return (int)cts_modes[i].id;
    }
    return -1;
}
/*
 * CBC-CS1 ciphertext-stealing encryption.
 *
 * The block-aligned prefix is encrypted with ctx->hw->cipher. If the input
 * ends in a partial block of |residue| bytes, that block is zero-padded and
 * encrypted as one more block. Its output is written starting |residue| bytes
 * into the previous ciphertext block. Only the first |residue| bytes of that
 * previous block survive (C(n-1)*), followed by the full final block C(n).
 * The output is therefore exactly |len| bytes long.
 *
 * |len| must be at least CTS_BLOCK_SIZE.
 * Returns the number of bytes written (== |len|) on success, 0 on failure.
 */
static size_t cts128_cs1_encrypt(PROV_CIPHER_CTX *ctx, const unsigned char *in,
                                 unsigned char *out, size_t len)
{
    aligned_16bytes tmp_in;
    size_t residue;

    /*
     * At least one full block is required. Without it, the tail write at
     * out - CTS_BLOCK_SIZE + residue below would land before |out|.
     */
    if (len < CTS_BLOCK_SIZE)
        return 0;

    residue = len % CTS_BLOCK_SIZE;
    len -= residue;
    if (!ctx->hw->cipher(ctx, out, in, len))
        return 0;

    /* Block-aligned input: the result is the plain CBC output. */
    if (residue == 0)
        return len;

    in += len;
    out += len;

    /* Zero-pad the trailing partial plaintext block. */
    memset(tmp_in.c, 0, sizeof(tmp_in));
    memcpy(tmp_in.c, in, residue);
    /*
     * Encrypt the padded block. This overwrites the last
     * CTS_BLOCK_SIZE - residue bytes of the previous ciphertext block and
     * extends |residue| bytes past it.
     */
    if (!ctx->hw->cipher(ctx, out - CTS_BLOCK_SIZE + residue, tmp_in.c,
                         CTS_BLOCK_SIZE))
        return 0;
    return len + residue;
}
/*
 * Byte-wise XOR: out[k] = in1[k] ^ in2[k] for k in [0, len).
 * Each output byte is written after its two inputs are read, so |out| may
 * be the same buffer as |in1| or |in2|.
 */
static void do_xor(const unsigned char *in1, const unsigned char *in2,
                   size_t len, unsigned char *out)
{
    const unsigned char *p1 = in1;
    const unsigned char *p2 = in2;
    unsigned char *dst = out;

    while (len-- > 0)
        *dst++ = (unsigned char)(*p1++ ^ *p2++);
}
/*
 * CBC-CS1 ciphertext-stealing decryption.
 *
 * Expected input layout when |len| is not block-aligned:
 *   C(1) .. C(n-2) || C(n-1)* || C(n)
 * where C(n-1)* is the first |residue| bytes of C(n-1) and C(n) is a full
 * block. Block-aligned input is decrypted as plain CBC.
 *
 * The code toggles ctx->iv around ctx->hw->cipher calls. This presumably
 * relies on hw->cipher doing CBC chained through ctx->iv (TODO confirm).
 *
 * |len| must be at least CTS_BLOCK_SIZE.
 * Returns the number of bytes written (== |len|) on success, 0 on failure.
 */
static size_t cts128_cs1_decrypt(PROV_CIPHER_CTX *ctx, const unsigned char *in,
                                 unsigned char *out, size_t len)
{
    aligned_16bytes mid_iv, ct_mid, cn, pt_last;
    size_t residue;

    /*
     * At least one full block is required. Without it, the subtraction
     * len -= CTS_BLOCK_SIZE + residue below would wrap around.
     */
    if (len < CTS_BLOCK_SIZE)
        return 0;

    residue = len % CTS_BLOCK_SIZE;
    if (residue == 0) {
        /* No partial block: identical to CBC. */
        if (!ctx->hw->cipher(ctx, out, in, len))
            return 0;
        return len;
    }
    /* Decrypt everything except the trailing C(n-1)* || C(n) as plain CBC. */
    len -= CTS_BLOCK_SIZE + residue;
    if (len > 0) {
        if (!ctx->hw->cipher(ctx, out, in, len))
            return 0;
        in += len;
        out += len;
    }
    /* Chaining value needed later for decrypting the rebuilt C(n-1). */
    memcpy(mid_iv.c, ctx->iv, CTS_BLOCK_SIZE);
    /* Keep C(n); it becomes ctx->iv on return. */
    memcpy(cn.c, in + residue, CTS_BLOCK_SIZE);

    /* Decrypt C(n) with a zero IV to get the raw block decryption. */
    memset(ctx->iv, 0, CTS_BLOCK_SIZE);
    if (!ctx->hw->cipher(ctx, pt_last.c, in + residue, CTS_BLOCK_SIZE))
        return 0;

    /*
     * Rebuild the full C(n-1): its stolen tail bytes are the tail of the
     * raw decryption of C(n), and its head is C(n-1)*.
     */
    memcpy(ct_mid.c, in, residue);
    memcpy(ct_mid.c + residue, pt_last.c + residue, CTS_BLOCK_SIZE - residue);
    /* XOR in C(n-1) to recover the final partial plaintext P(n). */
    do_xor(ct_mid.c, pt_last.c, residue, out + CTS_BLOCK_SIZE);

    /* Decrypt the rebuilt C(n-1) with the saved chaining value. */
    memcpy(ctx->iv, mid_iv.c, CTS_BLOCK_SIZE);
    if (!ctx->hw->cipher(ctx, out, ct_mid.c, CTS_BLOCK_SIZE))
        return 0;

    memcpy(ctx->iv, cn.c, CTS_BLOCK_SIZE);
    return len + CTS_BLOCK_SIZE + residue;
}
/*
 * CBC-CS3 ciphertext-stealing encryption.
 *
 * For input longer than one block, the last two ciphertext blocks are always
 * swapped, even when |len| is block-aligned (residue is then a full block):
 *   C(1) .. C(n-2) || C(n) || C(n-1)*
 * where C(n-1)* is the first |residue| bytes of C(n-1).
 *
 * Input shorter than CTS_BLOCK_SIZE is rejected. Exactly one block is plain CBC.
 * Returns the number of bytes written (== |len|) on success, 0 on failure.
 */
static size_t cts128_cs3_encrypt(PROV_CIPHER_CTX *ctx, const unsigned char *in,
                                 unsigned char *out, size_t len)
{
    aligned_16bytes tmp_in;
    size_t residue;

    /* CS3 needs at least one full block. */
    if (len < CTS_BLOCK_SIZE)
        return 0;
    /* A single block has nothing to swap. */
    if (len == CTS_BLOCK_SIZE)
        return ctx->hw->cipher(ctx, out, in, len) ? len : 0;
    /* Length of the final plaintext block P(n); a full block when aligned. */
    residue = len % CTS_BLOCK_SIZE;
    if (residue == 0)
        residue = CTS_BLOCK_SIZE;
    len -= residue;
    /* CBC-encrypt everything before P(n); the last block written is C(n-1). */
    if (!ctx->hw->cipher(ctx, out, in, len))
        return 0;
    in += len;
    out += len;
    /* Copy P(n) into a zero-padded scratch block before |out| is written again. */
    memset(tmp_in.c, 0, sizeof(tmp_in));
    memcpy(tmp_in.c, in, residue);
    /* Emit C(n-1)*: the first |residue| bytes of C(n-1), moved to the end. */
    memcpy(out, out - CTS_BLOCK_SIZE, residue);
    /*
     * Encrypt the padded P(n) into C(n-1)'s slot, producing C(n). This
     * presumably chains from C(n-1) via ctx->iv left by the previous call.
     */
    if (!ctx->hw->cipher(ctx, out - CTS_BLOCK_SIZE, tmp_in.c, CTS_BLOCK_SIZE))
        return 0;
    return len + residue;
}
/*
 * CBC-CS3 ciphertext-stealing decryption.
 *
 * Expected input layout for more than one block:
 *   C(1) .. C(n-2) || C(n) || C(n-1)*
 * where C(n-1)* is |residue| bytes (a full block when |len| is aligned).
 *
 * The code toggles ctx->iv around ctx->hw->cipher calls. This presumably
 * relies on hw->cipher doing CBC chained through ctx->iv (TODO confirm).
 *
 * Input shorter than CTS_BLOCK_SIZE is rejected. Exactly one block is plain CBC.
 * Returns the number of bytes written (== |len|) on success, 0 on failure.
 */
static size_t cts128_cs3_decrypt(PROV_CIPHER_CTX *ctx, const unsigned char *in,
                                 unsigned char *out, size_t len)
{
    aligned_16bytes mid_iv, ct_mid, cn, pt_last;
    size_t residue;

    /* CS3 needs at least one full block. */
    if (len < CTS_BLOCK_SIZE)
        return 0;
    /* A single block has nothing to swap. */
    if (len == CTS_BLOCK_SIZE)
        return ctx->hw->cipher(ctx, out, in, len) ? len : 0;
    /* Length of the trailing C(n-1)*; a full block when aligned. */
    residue = len % CTS_BLOCK_SIZE;
    if (residue == 0)
        residue = CTS_BLOCK_SIZE;
    /* Decrypt everything except the trailing C(n) || C(n-1)* as plain CBC. */
    len -= CTS_BLOCK_SIZE + residue;
    if (len > 0) {
        if (!ctx->hw->cipher(ctx, out, in, len))
            return 0;
        in += len;
        out += len;
    }
    /* Chaining value needed later for decrypting the rebuilt C(n-1). */
    memcpy(mid_iv.c, ctx->iv, CTS_BLOCK_SIZE);
    /* Keep the full block that comes first (C(n)); it becomes ctx->iv on return. */
    memcpy(cn.c, in, CTS_BLOCK_SIZE);
    /* Decrypt C(n) with a zero IV to get the raw block decryption. */
    memset(ctx->iv, 0, CTS_BLOCK_SIZE);
    if (!ctx->hw->cipher(ctx, pt_last.c, in, CTS_BLOCK_SIZE))
        return 0;
    /*
     * Rebuild the full C(n-1): head from C(n-1)*, and (when partial) the
     * stolen tail from the raw decryption of C(n).
     */
    memcpy(ct_mid.c, in + CTS_BLOCK_SIZE, residue);
    if (residue != CTS_BLOCK_SIZE)
        memcpy(ct_mid.c + residue, pt_last.c + residue, CTS_BLOCK_SIZE - residue);
    /* XOR in C(n-1) to recover the final plaintext P(n). */
    do_xor(ct_mid.c, pt_last.c, residue, out + CTS_BLOCK_SIZE);
    /* Decrypt the rebuilt C(n-1) with the saved chaining value. */
    memcpy(ctx->iv, mid_iv.c, CTS_BLOCK_SIZE);
    if (!ctx->hw->cipher(ctx, out, ct_mid.c, CTS_BLOCK_SIZE))
        return 0;
    memcpy(ctx->iv, cn.c, CTS_BLOCK_SIZE);
    return len + CTS_BLOCK_SIZE + residue;
}
/*
 * CBC-CS2 encryption: block-aligned input is plain CBC via ctx->hw->cipher.
 * Any other length is handled exactly as CS3.
 * Returns the number of bytes written, or 0 on failure.
 */
static size_t cts128_cs2_encrypt(PROV_CIPHER_CTX *ctx, const unsigned char *in,
                                 unsigned char *out, size_t len)
{
    const size_t partial = len % CTS_BLOCK_SIZE;

    if (partial != 0)
        return cts128_cs3_encrypt(ctx, in, out, len);

    return ctx->hw->cipher(ctx, out, in, len) ? len : 0;
}
/*
 * CBC-CS2 decryption: block-aligned input is plain CBC via ctx->hw->cipher.
 * Any other length is handled exactly as CS3.
 * Returns the number of bytes written, or 0 on failure.
 */
static size_t cts128_cs2_decrypt(PROV_CIPHER_CTX *ctx, const unsigned char *in,
                                 unsigned char *out, size_t len)
{
    const size_t partial = len % CTS_BLOCK_SIZE;

    if (partial != 0)
        return cts128_cs3_decrypt(ctx, in, out, len);

    return ctx->hw->cipher(ctx, out, in, len) ? len : 0;
}
/*
 * One-shot CTS update: the whole message must be supplied in a single call.
 *
 * |inl| must be at least CTS_BLOCK_SIZE and |outsize| at least |inl|.
 * When |out| is NULL, only the output length (|inl|) is reported in *outl.
 * A call on a context whose ctx->updated flag is already 1 fails.
 * The helper is chosen from ctx->cts_mode (CTS_CS1/2/3) and ctx->enc; an
 * unrecognised mode fails.
 * Returns 1 on success (with *outl set), 0 on error.
 */
int ossl_cipher_cbc_cts_block_update(void *vctx, unsigned char *out, size_t *outl,
                                     size_t outsize, const unsigned char *in,
                                     size_t inl)
{
    PROV_CIPHER_CTX *ctx = (PROV_CIPHER_CTX *)vctx;
    size_t produced;

    /* CTS needs at least one full block, and room for all of it. */
    if (inl < CTS_BLOCK_SIZE || outsize < inl)
        return 0;

    /* Length query only. */
    if (out == NULL) {
        *outl = inl;
        return 1;
    }

    /* Only a single update per context is supported. */
    if (ctx->updated == 1)
        return 0;

    switch (ctx->cts_mode) {
    case CTS_CS1:
        produced = ctx->enc ? cts128_cs1_encrypt(ctx, in, out, inl)
                            : cts128_cs1_decrypt(ctx, in, out, inl);
        break;
    case CTS_CS2:
        produced = ctx->enc ? cts128_cs2_encrypt(ctx, in, out, inl)
                            : cts128_cs2_decrypt(ctx, in, out, inl);
        break;
    case CTS_CS3:
        produced = ctx->enc ? cts128_cs3_encrypt(ctx, in, out, inl)
                            : cts128_cs3_decrypt(ctx, in, out, inl);
        break;
    default:
        produced = 0;
        break;
    }

    if (produced == 0)
        return 0;

    ctx->updated = 1;
    *outl = produced;
    return 1;
}
/*
 * CTS final step. It emits no bytes: *outl is set to 0 and 1 (success) is
 * returned. All output is produced by ossl_cipher_cbc_cts_block_update().
 */
int ossl_cipher_cbc_cts_block_final(void *vctx, unsigned char *out, size_t *outl,
                                    size_t outsize)
{
    (void)vctx;
    (void)out;
    (void)outsize;

    *outl = 0;
    return 1;
}