#include "opt_wlan.h"
#include <sys/param.h>
#include <sys/systm.h>
#include <sys/kernel.h>
#include <crypto/rijndael/rijndael.h>
#include <net80211/ieee80211_crypto_gcm.h>
/* Size in bytes of the blocks handled by the helpers in this file. */
#define AES_BLOCK_LEN 16
/* Unsigned mask with only bit number x set. */
#define BIT(x) (1U << (x))
/*
 * XOR the first len bytes of a into b in place (b[i] ^= a[i]).
 *
 * The index is a size_t so that it has the same type as len; a signed
 * int index would be compared against an unsigned length and could
 * overflow for very large buffers.
 */
static __inline void
xor_block(uint8_t *b, const uint8_t *a, size_t len)
{
	size_t i;

	for (i = 0; i < len; i++)
		b[i] ^= a[i];
}
/*
 * Store a 64-bit value into a[0..7] in big-endian byte order
 * (most significant byte at a[0]).
 */
static inline void
WPA_PUT_BE64(uint8_t *a, uint64_t val)
{
	int idx;

	/* Fill from the least significant byte backwards. */
	for (idx = 7; idx >= 0; idx--) {
		a[idx] = (uint8_t)(val & 0xff);
		val >>= 8;
	}
}
/*
 * Store a 32-bit value into a[0..3] in big-endian byte order
 * (most significant byte at a[0]).
 */
static inline void
WPA_PUT_BE32(uint8_t *a, uint32_t val)
{
	int idx;

	for (idx = 0; idx < 4; idx++)
		a[idx] = (uint8_t)((val >> (8 * (3 - idx))) & 0xff);
}
/*
 * Read a 32-bit big-endian value from a[0..3] (a[0] is the most
 * significant byte) and return it in host order.
 */
static inline uint32_t
WPA_GET_BE32(const uint8_t *a)
{
	uint32_t word = 0;
	int idx;

	for (idx = 0; idx < 4; idx++)
		word = (word << 8) | a[idx];
	return (word);
}
/*
 * Increment, modulo 2^32, the big-endian 32-bit counter stored in the
 * last four bytes of a 16-byte block.  The first twelve bytes are not
 * touched.
 */
static void
inc32(uint8_t *block)
{
	uint8_t *ctr = block + AES_BLOCK_LEN - 4;

	WPA_PUT_BE32(ctr, WPA_GET_BE32(ctr) + 1);
}
/*
 * Shift the 16-byte value in v (treated as one big-endian 128-bit
 * number) right by one bit, in place.  The bit shifted out of v[15]
 * is discarded and a zero is shifted into the top of v[0].
 */
static void
shift_right_block(uint8_t *v)
{
	uint32_t word;
	int w;

	/*
	 * Walk the four 32-bit words from the least significant one up.
	 * The carry into each word is the low bit of the byte just before
	 * it, which has not been rewritten yet at that point.
	 */
	for (w = 3; w >= 0; w--) {
		word = WPA_GET_BE32(v + 4 * w) >> 1;
		if (w > 0 && (v[4 * w - 1] & 0x01) != 0)
			word |= 0x80000000;
		WPA_PUT_BE32(v + 4 * w, word);
	}
}
/*
 * Multiply the 16-byte blocks x and y in GF(2^128) and store the
 * 16-byte product in z.
 *
 * z is cleared before x and y are read, so z must not alias either
 * input.
 */
static void
gf_mult(const uint8_t *x, const uint8_t *y, uint8_t *z)
{
	uint8_t v[16];
	int bit, carry;

	memset(z, 0, 16);
	memcpy(v, y, 16);

	/* Scan the 128 bits of x, most significant bit of x[0] first. */
	for (bit = 0; bit < 128; bit++) {
		if (x[bit / 8] & BIT(7 - (bit % 8)))
			xor_block(z, v, AES_BLOCK_LEN);

		/*
		 * Advance v by one bit position; when a one is shifted
		 * out of the low end, fold in the reduction constant.
		 */
		carry = v[15] & 0x01;
		shift_right_block(v);
		if (carry)
			v[0] ^= 0xe1;
	}
}
/*
 * Reset a 16-byte GHASH accumulator to all zeros before the first
 * call to ghash().
 */
static void
ghash_start(uint8_t *y)
{
	(void)memset(y, 0, 16 * sizeof(y[0]));
}
/*
 * Fold xlen bytes of x into the running 16-byte GHASH state y using
 * hash subkey h.
 *
 * Each 16-byte block of x is XORed into y and y is then multiplied by
 * h.  A trailing partial block is zero-padded to 16 bytes first.  y is
 * updated in place, so successive calls continue the same hash.
 */
static void
ghash(const uint8_t *h, const uint8_t *x, size_t xlen, uint8_t *y)
{
	uint8_t blk[16];
	size_t full = xlen / 16;
	size_t tail = xlen % 16;
	size_t n;

	/* Whole blocks. */
	for (n = 0; n < full; n++) {
		xor_block(y, x + n * 16, AES_BLOCK_LEN);
		gf_mult(y, h, blk);
		memcpy(y, blk, 16);
	}

	/* Remaining bytes, padded with zeros up to a full block. */
	if (tail != 0) {
		memcpy(blk, x + full * 16, tail);
		memset(blk + tail, 0, sizeof(blk) - tail);
		xor_block(y, blk, AES_BLOCK_LEN);
		gf_mult(y, h, blk);
		memcpy(y, blk, 16);
	}
}
/*
 * Counter-mode transform: XOR xlen bytes of x with the keystream
 * produced by encrypting successive counter blocks, starting from icb,
 * and write the result to y.
 *
 * The counter is advanced with inc32() after every full block.  A
 * final partial block uses only as many keystream bytes as remain.
 * Nothing is written when xlen is zero.
 *
 * For full blocks the keystream is written into y before x is read,
 * so x and y must not overlap.
 */
static void
aes_gctr(rijndael_ctx *aes, const uint8_t *icb,
    const uint8_t *x, size_t xlen, uint8_t *y)
{
	uint8_t ctr[AES_BLOCK_LEN], ks[AES_BLOCK_LEN];
	size_t blocks, rem, off, k;

	if (xlen == 0)
		return;

	blocks = xlen / 16;
	rem = xlen % 16;
	memcpy(ctr, icb, AES_BLOCK_LEN);

	/* Full blocks: keystream goes straight into the output buffer. */
	for (k = 0; k < blocks; k++) {
		off = k * AES_BLOCK_LEN;
		rijndael_encrypt(aes, ctr, y + off);
		xor_block(y + off, x + off, AES_BLOCK_LEN);
		inc32(ctr);
	}

	/* Trailing partial block, if any. */
	if (rem != 0) {
		off = blocks * AES_BLOCK_LEN;
		rijndael_encrypt(aes, ctr, ks);
		for (k = 0; k < rem; k++)
			y[off + k] = x[off + k] ^ ks[k];
	}
}
/*
 * Derive the GHASH subkey: H receives the encryption of an all-zero
 * 16-byte block under the key held in aes.
 */
static void
aes_gcm_init_hash_subkey(rijndael_ctx *aes, uint8_t *H)
{
	static const uint8_t zero_block[AES_BLOCK_LEN];

	rijndael_encrypt(aes, zero_block, H);
}
/*
 * Build the 16-byte pre-counter block J0 from the IV.
 *
 * A 12-byte IV is copied verbatim and followed by the bytes
 * 00 00 00 01.  Any other IV length is run through GHASH under H,
 * followed by a 16-byte block holding 64 zero bits and the IV length
 * in bits as a big-endian 64-bit value.
 */
static void
aes_gcm_prepare_j0(const uint8_t *iv, size_t iv_len, const uint8_t *H,
    uint8_t *J0)
{
	uint8_t lenblk[16];

	if (iv_len != 12) {
		ghash_start(J0);
		ghash(H, iv, iv_len, J0);
		WPA_PUT_BE64(lenblk, 0);
		WPA_PUT_BE64(lenblk + 8, iv_len * 8);
		ghash(H, lenblk, sizeof(lenblk), J0);
		return;
	}

	/* 96-bit IV: J0 = IV || 0^31 || 1 */
	memcpy(J0, iv, iv_len);
	memset(J0 + iv_len, 0, AES_BLOCK_LEN - iv_len);
	J0[AES_BLOCK_LEN - 1] = 0x01;
}
/*
 * Encrypt or decrypt len bytes from in to out in counter mode, using
 * inc32(J0) as the first counter block.  J0 itself is left unmodified.
 * Nothing is done when len is zero.
 */
static void
aes_gcm_gctr(rijndael_ctx *aes, const uint8_t *J0, const uint8_t *in,
    size_t len, uint8_t *out)
{
	uint8_t first_ctr[AES_BLOCK_LEN];

	if (len != 0) {
		memcpy(first_ctr, J0, AES_BLOCK_LEN);
		inc32(first_ctr);
		aes_gctr(aes, first_ctr, in, len, out);
	}
}
/*
 * Compute the 16-byte GHASH value S over the additional authenticated
 * data, the ciphertext, and a final 16-byte length block.
 *
 * The length block holds aad_len and crypt_len, each expressed in bits
 * as a big-endian 64-bit value.  aad and crypt are each zero-padded to
 * a 16-byte boundary by ghash().
 */
static void
aes_gcm_ghash(const uint8_t *H, const uint8_t *aad, size_t aad_len,
    const uint8_t *crypt, size_t crypt_len, uint8_t *S)
{
	uint8_t lengths[16];

	/* [len(A)]_64 || [len(C)]_64, both in bits. */
	WPA_PUT_BE64(lengths, aad_len * 8);
	WPA_PUT_BE64(lengths + 8, crypt_len * 8);

	ghash_start(S);
	ghash(H, aad, aad_len, S);
	ghash(H, crypt, crypt_len, S);
	ghash(H, lengths, sizeof(lengths), S);
}
/*
 * AES-GCM authenticated encryption.
 *
 * aes      - initialised Rijndael context holding the key
 * iv       - initialisation vector, iv_len bytes
 * plain    - plaintext input, plain_len bytes
 * aad      - additional authenticated data, aad_len bytes
 * crypt    - output buffer receiving plain_len bytes of ciphertext
 * tag      - output buffer receiving GCMP_MIC_LEN bytes of tag
 */
void
ieee80211_crypto_aes_gcm_ae(rijndael_ctx *aes, const uint8_t *iv, size_t iv_len,
    const uint8_t *plain, size_t plain_len,
    const uint8_t *aad, size_t aad_len, uint8_t *crypt, uint8_t *tag)
{
	uint8_t subkey[AES_BLOCK_LEN];
	uint8_t precounter[AES_BLOCK_LEN];
	uint8_t hash[GCMP_MIC_LEN];

	/* Hash subkey and pre-counter block. */
	aes_gcm_init_hash_subkey(aes, subkey);
	aes_gcm_prepare_j0(iv, iv_len, subkey, precounter);

	/* Ciphertext: counter mode starting at inc32(J0). */
	aes_gcm_gctr(aes, precounter, plain, plain_len, crypt);

	/* GHASH over the AAD and the ciphertext just produced. */
	aes_gcm_ghash(subkey, aad, aad_len, crypt, plain_len, hash);

	/* Tag: the GHASH value encrypted with counter block J0. */
	aes_gctr(aes, precounter, hash, sizeof(hash), tag);
}
/*
 * AES-GCM authenticated decryption.
 *
 * aes      - initialised Rijndael context holding the key
 * iv       - initialisation vector, iv_len bytes
 * crypt    - ciphertext input, crypt_len bytes
 * aad      - additional authenticated data, aad_len bytes
 * tag      - received authentication tag, 16 bytes are compared
 * plain    - output buffer receiving crypt_len bytes of plaintext
 *
 * Returns 0 when the computed tag matches the received one and -1
 * otherwise.  Note that plain is written before the tag is checked, so
 * callers must discard its contents when -1 is returned.
 */
int
ieee80211_crypto_aes_gcm_ad(rijndael_ctx *aes, const uint8_t *iv, size_t iv_len,
    const uint8_t *crypt, size_t crypt_len,
    const uint8_t *aad, size_t aad_len, const uint8_t *tag, uint8_t *plain)
{
	uint8_t H[AES_BLOCK_LEN];
	uint8_t J0[AES_BLOCK_LEN];
	uint8_t S[16], T[GCMP_MIC_LEN];
	uint8_t diff;
	size_t i;

	/* Hash subkey and pre-counter block. */
	aes_gcm_init_hash_subkey(aes, H);
	aes_gcm_prepare_j0(iv, iv_len, H, J0);

	/* Plaintext: counter mode starting at inc32(J0). */
	aes_gcm_gctr(aes, J0, crypt, crypt_len, plain);

	/* Expected tag: GHASH over AAD and ciphertext, encrypted with J0. */
	aes_gcm_ghash(H, aad, aad_len, crypt, crypt_len, S);
	aes_gctr(aes, J0, S, sizeof(S), T);

	/*
	 * Compare all sizeof(S) tag bytes without an early exit.  A
	 * memcmp()-style comparison stops at the first mismatch, which
	 * lets an attacker who can time the check learn how many leading
	 * bytes of a forged tag were correct.
	 */
	diff = 0;
	for (i = 0; i < sizeof(S); i++)
		diff |= (uint8_t)(tag[i] ^ T[i]);

	if (diff != 0)
		return (-1);
	return (0);
}