#include <lib/libsa/stand.h>
#include "aes_xts.h"
/*
 * aes_xts_reinit: load a new IV into the context and derive the initial
 * tweak for the first block that follows.
 *
 * ctx: context whose key2 has already been set; its tweak is overwritten.
 * iv:  AES_XTS_IVSIZE bytes, read as a native-endian 64-bit integer
 *      (assumes AES_XTS_IVSIZE <= sizeof(u_int64_t), i.e. 8 -- TODO confirm
 *      against aes_xts.h).
 */
void
aes_xts_reinit(struct aes_xts_ctx *ctx, u_int8_t *iv)
{
	u_int64_t sector;
	u_int idx;

	/*
	 * Store the integer into the low half of the tweak, least-significant
	 * byte first, so the tweak always holds its little-endian encoding
	 * regardless of host byte order.
	 */
	bcopy(iv, &sector, AES_XTS_IVSIZE);
	for (idx = 0; idx < AES_XTS_IVSIZE; idx++)
		ctx->tweak[idx] = (u_int8_t)(sector >> (8 * idx));

	/* The upper half of the tweak is always zero. */
	bzero(&ctx->tweak[AES_XTS_IVSIZE], AES_XTS_IVSIZE);

	/* Initial tweak is the IV block encrypted under key2, in place. */
	rijndael_encrypt(&ctx->key2, ctx->tweak, ctx->tweak);
}
/*
 * aes_xts_crypt: transform one AES_XTS_BLOCKSIZE-byte block in place, then
 * advance the tweak for the next block.
 *
 * ctx:        context with key1 set and a tweak prepared by aes_xts_reinit
 *             (or by the previous call); the tweak is updated.
 * data:       the block, overwritten with the result.
 * do_encrypt: nonzero to encrypt with key1, zero to decrypt with key1.
 */
void
aes_xts_crypt(struct aes_xts_ctx *ctx, u_int8_t *data, u_int do_encrypt)
{
	u_int8_t buf[AES_XTS_BLOCKSIZE];
	u_int n, msb, lsb;

	/* Pre-whiten: buf = data ^ tweak. */
	for (n = 0; n < AES_XTS_BLOCKSIZE; n++)
		buf[n] = ctx->tweak[n] ^ data[n];

	/* Block cipher under key1; the output goes straight into data. */
	if (!do_encrypt)
		rijndael_decrypt(&ctx->key1, buf, data);
	else
		rijndael_encrypt(&ctx->key1, buf, data);

	/* Post-whiten: data ^= tweak. */
	for (n = 0; n < AES_XTS_BLOCKSIZE; n++)
		data[n] ^= ctx->tweak[n];

	/*
	 * XTS tweak update ("multiply by alpha"): shift the whole tweak left
	 * one bit, treating tweak[0] as the least-significant byte. The bit
	 * shifted out of the top byte is folded back into tweak[0] through
	 * AES_XTS_ALPHA (presumably the 0x87 reduction constant -- see
	 * aes_xts.h).
	 */
	lsb = 0;
	for (n = 0; n < AES_XTS_BLOCKSIZE; n++) {
		msb = (ctx->tweak[n] >> 7) & 1;
		ctx->tweak[n] = (u_int8_t)((ctx->tweak[n] << 1) | lsb);
		lsb = msb;
	}
	/* 0u - lsb is either 0 or all-ones, so this applies ALPHA without a branch. */
	ctx->tweak[0] ^= (AES_XTS_ALPHA & (0u - lsb));

	/* Wipe the intermediate block from the stack. */
	explicit_bzero(buf, sizeof(buf));
}
void
aes_xts_encrypt(struct aes_xts_ctx *ctx, u_int8_t *data)
{
aes_xts_crypt(ctx, data, 1);
}
void
aes_xts_decrypt(struct aes_xts_ctx *ctx, u_int8_t *data)
{
aes_xts_crypt(ctx, data, 0);
}
/*
 * aes_xts_setkey: install an XTS key pair into ctx.
 *
 * key: len bytes. The first half keys the data cipher (key1); the second
 *      half keys the tweak cipher (key2).
 * len: total key length in bytes; only 32 or 64 are accepted.
 *
 * Returns -1 (ctx untouched) for any other length, 0 otherwise. The
 * return value of rijndael_set_key is not checked.
 */
int
aes_xts_setkey(struct aes_xts_ctx *ctx, u_int8_t *key, int len)
{
	int half, bits;

	switch (len) {
	case 32:
	case 64:
		break;
	default:
		return -1;
	}

	/* Each half gets its own schedule; bits is the size of one half. */
	half = len / 2;
	bits = half * 8;
	rijndael_set_key(&ctx->key1, key, bits);
	rijndael_set_key(&ctx->key2, key + half, bits);

	return 0;
}
/*
 * aes_xts_zerokey: wipe the entire context (both key schedules and the
 * tweak) using explicit_bzero, which is not elided by the compiler.
 */
void
aes_xts_zerokey(struct aes_xts_ctx *ctx)
{
	explicit_bzero(ctx, sizeof(*ctx));
}