#include <sys/param.h>
#include <sys/systm.h>
#include <sys/mbuf.h>
#include <sys/malloc.h>
#include <sys/kernel.h>
#include <sys/socket.h>
#include <sys/endian.h>
#include <net/if.h>
#include <net/if_dl.h>
#include <net/if_media.h>
#include <netinet/in.h>
#include <netinet/if_ether.h>
#include <net80211/ieee80211_var.h>
#include <net80211/ieee80211_crypto.h>
#include <crypto/aes.h>
/*
 * Per-key private CCMP state.  ieee80211_ccmp_set_key() allocates one of
 * these, initializes aesctx from the key with AES_Setkey() and stores it
 * in k->k_priv; ieee80211_ccmp_delete_key() wipes and frees it.
 */
struct ieee80211_ccmp_ctx {
AES_CTX aesctx;
};
/*
 * Install key k for CCMP: allocate a zeroed private context, initialize
 * its AES context from the first 16 bytes of k->k_key (AES-128) and hang
 * it off k->k_priv.
 *
 * Returns 0 on success or ENOMEM if the context cannot be allocated, in
 * which case k is left untouched.
 *
 * NOTE(review): any previous k->k_priv is overwritten without being
 * freed; presumably callers delete the old key first - confirm.
 */
int
ieee80211_ccmp_set_key(struct ieee80211com *ic, struct ieee80211_key *k)
{
	struct ieee80211_ccmp_ctx *priv;

	priv = malloc(sizeof(*priv), M_DEVBUF, M_NOWAIT | M_ZERO);
	if (priv != NULL) {
		AES_Setkey(&priv->aesctx, k->k_key, 16);
		k->k_priv = priv;
		return 0;
	}
	return ENOMEM;
}
/*
 * Tear down the CCMP state attached to key k: detach k->k_priv, then
 * scrub the context (it holds key material) with explicit_bzero() so the
 * wipe is not optimized away, and free it.  Safe to call when no context
 * is attached.
 */
void
ieee80211_ccmp_delete_key(struct ieee80211com *ic, struct ieee80211_key *k)
{
	struct ieee80211_ccmp_ctx *priv = k->k_priv;

	k->k_priv = NULL;
	if (priv == NULL)
		return;
	explicit_bzero(priv, sizeof(*priv));
	free(priv, M_DEVBUF, sizeof(*priv));
}
/*
 * CCM setup shared by ieee80211_ccmp_encrypt() and ieee80211_ccmp_decrypt().
 *
 * Builds the AAD and the 13-byte CCM nonce from the 802.11 header wh and
 * the packet number pn, then produces:
 *   b  - CBC-MAC state after processing B_0 and the two AAD blocks; the
 *        caller continues the MAC over the plaintext body,
 *   a  - counter block A_0 (flags, nonce, counter 0); callers rewrite
 *        a[14..15] with subsequent counter values,
 *   s0 - E(A_0), used by the caller to mask the final MIC.
 * lm is the plaintext body length; only its low 16 bits are stored
 * (big-endian) in B_0[14..15].
 */
static void
ieee80211_ccmp_phase1(AES_CTX *ctx, const struct ieee80211_frame *wh,
u_int64_t pn, int lm, u_int8_t b[16], u_int8_t a[16], u_int8_t s0[16])
{
u_int8_t auth[32], nonce[13];
u_int8_t *aad;
u_int8_t tid = 0;
int la, i;
/* AAD is written from auth[2]; auth[0..1] receive its length below. */
aad = &auth[2];
/* FC byte 0; for data frames clear the subtype bits except the QoS bit. */
*aad = wh->i_fc[0];
if ((wh->i_fc[0] & IEEE80211_FC0_TYPE_MASK) ==
IEEE80211_FC0_TYPE_DATA)
*aad &= ~IEEE80211_FC0_SUBTYPE_MASK |
IEEE80211_FC0_SUBTYPE_QOS;
aad++;
/* FC byte 1 with Retry, PwrMgt and MoreData masked; Order too for QoS. */
*aad = wh->i_fc[1];
*aad &= ~(IEEE80211_FC1_RETRY | IEEE80211_FC1_PWR_MGT |
IEEE80211_FC1_MORE_DATA);
if (ieee80211_has_qos(wh))
*aad &= ~IEEE80211_FC1_ORDER;
aad++;
IEEE80211_ADDR_COPY(aad, wh->i_addr1); aad += IEEE80211_ADDR_LEN;
IEEE80211_ADDR_COPY(aad, wh->i_addr2); aad += IEEE80211_ADDR_LEN;
IEEE80211_ADDR_COPY(aad, wh->i_addr3); aad += IEEE80211_ADDR_LEN;
/* Sequence control: keep the fragment number (low 4 bits), zero the rest. */
*aad++ = wh->i_seq[0] & ~0xf0;
*aad++ = 0;
if (ieee80211_has_addr4(wh)) {
IEEE80211_ADDR_COPY(aad,
((const struct ieee80211_frame_addr4 *)wh)->i_addr4);
aad += IEEE80211_ADDR_LEN;
}
/* QoS control: only the TID bits are kept; the second byte is zeroed. */
if (ieee80211_has_qos(wh)) {
*aad++ = tid = ieee80211_get_qos(wh) & IEEE80211_QOS_TID;
*aad++ = 0;
}
/*
 * Nonce: [0] priority (TID, 0 for non-QoS) with bit 4 set for management
 * frames, [1..6] A2, [7..12] PN big-endian (PN5 first).
 */
nonce[ 0] = tid;
if ((wh->i_fc[0] & IEEE80211_FC0_TYPE_MASK) ==
IEEE80211_FC0_TYPE_MGT)
nonce[0] |= 1 << 4;
IEEE80211_ADDR_COPY(&nonce[1], wh->i_addr2);
nonce[ 7] = pn >> 40;
nonce[ 8] = pn >> 32;
nonce[ 9] = pn >> 24;
nonce[10] = pn >> 16;
nonce[11] = pn >> 8;
nonce[12] = pn;
/*
 * auth = 2-byte big-endian AAD length || AAD || zero padding.  The AAD is
 * at most 30 bytes (FC 2 + A1..A3 18 + SC 2 + A4 6 + QC 2), so auth is
 * always exactly two 16-byte blocks.
 */
la = aad - &auth[2];
auth[0] = la >> 8;
auth[1] = la & 0xff;
memset(aad, 0, 30 - la);
/*
 * B_0 flags: 64*Adata + 8*((M-2)/2) + (L-1) = 64 + 24 + 1 = 89, i.e. AAD
 * present, 8-byte MIC, 2-byte length field (RFC 3610 encoding).
 */
b[ 0] = 89;
memcpy(&b[1], nonce, 13);
b[14] = lm >> 8;
b[15] = lm & 0xff;
/* CBC-MAC: X_1 = E(B_0), then fold in the two AAD blocks. */
AES_Encrypt(ctx, b, b);
for (i = 0; i < 16; i++)
b[i] ^= auth[i];
AES_Encrypt(ctx, b, b);
for (i = 0; i < 16; i++)
b[i] ^= auth[16 + i];
AES_Encrypt(ctx, b, b);
/* A_0: flags (L-1 = 1), nonce, counter 0; S_0 = E(A_0). */
a[ 0] = 1;
memcpy(&a[1], nonce, 13);
a[14] = a[15] = 0;
AES_Encrypt(ctx, a, s0);
}
/*
 * CCMP-encrypt the 802.11 frame in m0 with key k.
 *
 * Returns a newly built mbuf chain: the 802.11 header copied from m0, the
 * 8-byte CCMP header, the encrypted body and an IEEE80211_CCMP_MICLEN-byte
 * MIC.  m0 is freed in every case.  On mbuf allocation failure NULL is
 * returned and ic->ic_stats.is_tx_nombuf is incremented.  Once the first
 * output mbuf has been obtained, k->k_tsc is incremented and used as this
 * frame's PN.
 */
struct mbuf *
ieee80211_ccmp_encrypt(struct ieee80211com *ic, struct mbuf *m0,
struct ieee80211_key *k)
{
struct ieee80211_ccmp_ctx *ctx = k->k_priv;
const struct ieee80211_frame *wh;
const u_int8_t *src;
u_int8_t *ivp, *mic, *dst;
u_int8_t a[16], b[16], s0[16], s[16];
struct mbuf *n0, *m, *n;
int hdrlen, left, moff, noff, len;
u_int16_t ctr;
int i, j;
/* First output mbuf; use a cluster if header+CCMP hdr+body+MIC >= MINCLSIZE. */
MGET(n0, M_DONTWAIT, m0->m_type);
if (n0 == NULL)
goto nospace;
if (m_dup_pkthdr(n0, m0, M_DONTWAIT))
goto nospace;
n0->m_pkthdr.len += IEEE80211_CCMP_HDRLEN;
n0->m_len = MHLEN;
if (n0->m_pkthdr.len >= MINCLSIZE - IEEE80211_CCMP_MICLEN) {
MCLGET(n0, M_DONTWAIT);
if (n0->m_flags & M_EXT)
n0->m_len = n0->m_ext.ext_size;
}
if (n0->m_len > n0->m_pkthdr.len)
n0->m_len = n0->m_pkthdr.len;
/* Copy the 802.11 header unchanged. */
wh = mtod(m0, struct ieee80211_frame *);
hdrlen = ieee80211_get_hdrlen(wh);
memcpy(mtod(n0, caddr_t), wh, hdrlen);
k->k_tsc++;
/* CCMP header: PN0 PN1 Rsvd KeyID|ExtIV PN2 PN3 PN4 PN5. */
ivp = mtod(n0, u_int8_t *) + hdrlen;
ivp[0] = k->k_tsc;
ivp[1] = k->k_tsc >> 8;
ivp[2] = 0;
ivp[3] = k->k_id << 6 | IEEE80211_WEP_EXTIV;
ivp[4] = k->k_tsc >> 16;
ivp[5] = k->k_tsc >> 24;
ivp[6] = k->k_tsc >> 32;
ivp[7] = k->k_tsc >> 40;
/* Initial CBC-MAC state b, counter block a and S_0 over the body length. */
ieee80211_ccmp_phase1(&ctx->aesctx, wh, k->k_tsc,
m0->m_pkthdr.len - hdrlen, b, a, s0);
/* First keystream block S_1 = E(A_1). */
ctr = 1;
a[14] = ctr >> 8;
a[15] = ctr & 0xff;
AES_Encrypt(&ctx->aesctx, a, s);
/*
 * Walk source chain m (from offset hdrlen) and output chain n (after the
 * CCMP header) in lockstep, appending output mbufs as each one fills.
 * j indexes the current 16-byte block.
 */
j = 0;
m = m0;
n = n0;
moff = hdrlen;
noff = hdrlen + IEEE80211_CCMP_HDRLEN;
left = m0->m_pkthdr.len - moff;
while (left > 0) {
/* Current source mbuf exhausted: move to the next one. */
if (moff == m->m_len) {
m = m->m_next;
moff = 0;
}
/* Current output mbuf full: append one sized for what is left. */
if (noff == n->m_len) {
MGET(n->m_next, M_DONTWAIT, n->m_type);
if (n->m_next == NULL)
goto nospace;
n = n->m_next;
n->m_len = MLEN;
if (left >= MINCLSIZE - IEEE80211_CCMP_MICLEN) {
MCLGET(n, M_DONTWAIT);
if (n->m_flags & M_EXT)
n->m_len = n->m_ext.ext_size;
}
if (n->m_len > left)
n->m_len = left;
noff = 0;
}
len = min(m->m_len - moff, n->m_len - noff);
src = mtod(m, u_int8_t *) + moff;
dst = mtod(n, u_int8_t *) + noff;
for (i = 0; i < len; i++) {
/* MAC the plaintext byte, then encrypt it with keystream S_ctr. */
b[j] ^= src[i];
dst[i] = src[i] ^ s[j];
if (++j < 16)
continue;
/* Full block: advance CBC-MAC and derive the next S_ctr. */
AES_Encrypt(&ctx->aesctx, b, b);
ctr++;
a[14] = ctr >> 8;
a[15] = ctr & 0xff;
AES_Encrypt(&ctx->aesctx, a, s);
j = 0;
}
moff += len;
noff += len;
left -= len;
}
/* Partial last block: bytes not XORed act as zero padding for the MAC. */
if (j != 0)
AES_Encrypt(&ctx->aesctx, b, b);
/* Make room for the MIC, appending an empty mbuf if n has no space. */
if (m_trailingspace(n) < IEEE80211_CCMP_MICLEN) {
MGET(n->m_next, M_DONTWAIT, n->m_type);
if (n->m_next == NULL)
goto nospace;
n = n->m_next;
n->m_len = 0;
}
/* MIC U = T XOR first MICLEN bytes of S_0. */
mic = mtod(n, u_int8_t *) + n->m_len;
for (i = 0; i < IEEE80211_CCMP_MICLEN; i++)
mic[i] = b[i] ^ s0[i];
n->m_len += IEEE80211_CCMP_MICLEN;
n0->m_pkthdr.len += IEEE80211_CCMP_MICLEN;
m_freem(m0);
return n0;
nospace:
ic->ic_stats.is_tx_nombuf++;
m_freem(m0);
m_freem(n0);
return NULL;
}
/*
 * Extract the 48-bit packet number (PN) from the CCMP header of frame m
 * and select the receive sequence counter (RSC) it must be checked
 * against.
 *
 * On success returns 0, stores the PN in *pn and points *prsc at
 * k->k_rsc[tid] for data frames (tid is 0 for non-QoS data) or at
 * k->k_mgmt_rsc for all other frames.  Returns EINVAL if the packet is
 * too short to contain a CCMP header or the ExtIV bit is clear.
 *
 * The 802.11 header itself is read in place from the first mbuf; the
 * CCMP header may lie anywhere in the chain.
 */
int
ieee80211_ccmp_get_pn(uint64_t *pn, uint64_t **prsc, struct mbuf *m,
    struct ieee80211_key *k)
{
	struct ieee80211_frame *wh;
	u_int8_t ivp[IEEE80211_CCMP_HDRLEN];
	int hdrlen;

	wh = mtod(m, struct ieee80211_frame *);
	hdrlen = ieee80211_get_hdrlen(wh);
	if (m->m_pkthdr.len < hdrlen + IEEE80211_CCMP_HDRLEN)
		return EINVAL;

	/*
	 * m_pkthdr.len only proves the CCMP header exists somewhere in the
	 * chain, not that it lies within the first mbuf.  Copy it out rather
	 * than dereferencing past m->m_len, which on a received frame would
	 * yield a bogus PN (and possibly poison the replay counter).
	 */
	m_copydata(m, hdrlen, IEEE80211_CCMP_HDRLEN, ivp);

	/* A CCMP header must have the ExtIV bit set. */
	if (!(ivp[3] & IEEE80211_WEP_EXTIV))
		return EINVAL;

	/* Data frames have a per-TID counter; others use k_mgmt_rsc. */
	if ((wh->i_fc[0] & IEEE80211_FC0_TYPE_MASK) ==
	    IEEE80211_FC0_TYPE_DATA) {
		u_int8_t tid = ieee80211_has_qos(wh) ?
		    ieee80211_get_qos(wh) & IEEE80211_QOS_TID : 0;
		*prsc = &k->k_rsc[tid];
	} else
		*prsc = &k->k_mgmt_rsc;

	/* Layout: PN0 PN1 Rsvd KeyID|ExtIV PN2 PN3 PN4 PN5. */
	*pn = (u_int64_t)ivp[0] |
	    (u_int64_t)ivp[1] << 8 |
	    (u_int64_t)ivp[4] << 16 |
	    (u_int64_t)ivp[5] << 24 |
	    (u_int64_t)ivp[6] << 32 |
	    (u_int64_t)ivp[7] << 40;
	return 0;
}
/*
 * Decrypt and authenticate the CCMP-protected 802.11 frame in m0 with
 * key k.
 *
 * On success returns a new mbuf chain holding the 802.11 header (with the
 * Protected bit cleared) followed by the plaintext body; the CCMP header
 * and MIC are stripped, and the matching replay counter is advanced to
 * the frame's PN.  m0 is freed in every case.  NULL is returned when:
 *   - the frame is too short, or the first mbuf does not hold both the
 *     802.11 and CCMP headers (silently dropped);
 *   - ieee80211_ccmp_get_pn() rejects the frame (silently dropped);
 *   - the PN does not exceed the last accepted one (is_ccmp_replays++);
 *   - the MIC does not verify (is_ccmp_dec_errs++);
 *   - mbuf allocation fails (is_rx_nombuf++).
 */
struct mbuf *
ieee80211_ccmp_decrypt(struct ieee80211com *ic, struct mbuf *m0,
    struct ieee80211_key *k)
{
	struct ieee80211_ccmp_ctx *ctx = k->k_priv;
	struct ieee80211_frame *wh;
	u_int64_t pn, *prsc;
	const u_int8_t *src;
	u_int8_t *dst;
	u_int8_t mic0[IEEE80211_CCMP_MICLEN];
	u_int8_t a[16], b[16], s0[16], s[16];
	struct mbuf *n0, *m, *n;
	int hdrlen, left, moff, noff, len;
	u_int16_t ctr;
	int i, j;

	wh = mtod(m0, struct ieee80211_frame *);
	hdrlen = ieee80211_get_hdrlen(wh);

	/*
	 * Drop runt frames.  Besides the total length, require the 802.11
	 * and CCMP headers to sit in the first mbuf: the header is copied
	 * from there, and the body loop below starts reading ciphertext at
	 * offset hdrlen + IEEE80211_CCMP_HDRLEN into m0.  A shorter first
	 * mbuf would make that loop compute a negative copy length, walk
	 * its offsets backwards and read past m0's data.
	 */
	if (m0->m_pkthdr.len < hdrlen + IEEE80211_CCMP_HDRLEN +
	    IEEE80211_CCMP_MICLEN ||
	    m0->m_len < hdrlen + IEEE80211_CCMP_HDRLEN) {
		m_freem(m0);
		return NULL;
	}

	/* Fetch the frame's PN and the last accepted PN (RSC) for it. */
	if (ieee80211_ccmp_get_pn(&pn, &prsc, m0, k) != 0) {
		m_freem(m0);
		return NULL;
	}

	/* Replayed frame: the PN must strictly increase. */
	if (pn <= *prsc) {
		ic->ic_stats.is_ccmp_replays++;
		m_freem(m0);
		return NULL;
	}

	/* Output chain: header + plaintext, without CCMP header and MIC. */
	MGET(n0, M_DONTWAIT, m0->m_type);
	if (n0 == NULL)
		goto nospace;
	if (m_dup_pkthdr(n0, m0, M_DONTWAIT))
		goto nospace;
	n0->m_pkthdr.len -= IEEE80211_CCMP_HDRLEN + IEEE80211_CCMP_MICLEN;
	n0->m_len = MHLEN;
	if (n0->m_pkthdr.len >= MINCLSIZE) {
		MCLGET(n0, M_DONTWAIT);
		if (n0->m_flags & M_EXT)
			n0->m_len = n0->m_ext.ext_size;
	}
	if (n0->m_len > n0->m_pkthdr.len)
		n0->m_len = n0->m_pkthdr.len;

	/*
	 * Initial CBC-MAC state b, counter block a and S_0.  wh still points
	 * at the received header in m0 here, Protected bit included.
	 */
	ieee80211_ccmp_phase1(&ctx->aesctx, wh, pn,
	    n0->m_pkthdr.len - hdrlen, b, a, s0);

	/* Copy the 802.11 header and clear the Protected bit in the copy. */
	memcpy(mtod(n0, caddr_t), wh, hdrlen);
	wh = mtod(n0, struct ieee80211_frame *);
	wh->i_fc[1] &= ~IEEE80211_FC1_PROTECTED;

	/* First keystream block S_1 = E(A_1). */
	ctr = 1;
	a[14] = ctr >> 8;
	a[15] = ctr & 0xff;
	AES_Encrypt(&ctx->aesctx, a, s);

	/*
	 * Walk source chain m (after the CCMP header) and output chain n
	 * (after the 802.11 header) in lockstep, appending output mbufs as
	 * each one fills.  j indexes the current 16-byte block.
	 */
	j = 0;
	m = m0;
	n = n0;
	moff = hdrlen + IEEE80211_CCMP_HDRLEN;
	noff = hdrlen;
	left = n0->m_pkthdr.len - noff;
	while (left > 0) {
		/* Current source mbuf exhausted: move to the next one. */
		if (moff == m->m_len) {
			m = m->m_next;
			moff = 0;
		}
		/* Current output mbuf full: append one sized for what is left. */
		if (noff == n->m_len) {
			MGET(n->m_next, M_DONTWAIT, n->m_type);
			if (n->m_next == NULL)
				goto nospace;
			n = n->m_next;
			n->m_len = MLEN;
			if (left >= MINCLSIZE) {
				MCLGET(n, M_DONTWAIT);
				if (n->m_flags & M_EXT)
					n->m_len = n->m_ext.ext_size;
			}
			if (n->m_len > left)
				n->m_len = left;
			noff = 0;
		}
		len = min(m->m_len - moff, n->m_len - noff);

		src = mtod(m, u_int8_t *) + moff;
		dst = mtod(n, u_int8_t *) + noff;
		for (i = 0; i < len; i++) {
			/* Decrypt with keystream S_ctr, then MAC the plaintext. */
			dst[i] = src[i] ^ s[j];
			b[j] ^= dst[i];
			if (++j < 16)
				continue;
			/* Full block: advance CBC-MAC and derive the next S_ctr. */
			AES_Encrypt(&ctx->aesctx, b, b);
			ctr++;
			a[14] = ctr >> 8;
			a[15] = ctr & 0xff;
			AES_Encrypt(&ctx->aesctx, a, s);
			j = 0;
		}

		moff += len;
		noff += len;
		left -= len;
	}
	/* Partial last block: bytes not XORed act as zero padding. */
	if (j != 0)
		AES_Encrypt(&ctx->aesctx, b, b);

	/* Expected MIC U = T XOR first MICLEN bytes of S_0. */
	for (i = 0; i < IEEE80211_CCMP_MICLEN; i++)
		b[i] ^= s0[i];

	/* Compare against the received MIC in constant time. */
	m_copydata(m, moff, IEEE80211_CCMP_MICLEN, mic0);
	if (timingsafe_bcmp(mic0, b, IEEE80211_CCMP_MICLEN) != 0) {
		ic->ic_stats.is_ccmp_dec_errs++;
		m_freem(m0);
		m_freem(n0);
		return NULL;
	}

	/* Only advance the replay counter once the MIC has verified. */
	*prsc = pn;

	m_freem(m0);
	return n0;
nospace:
	ic->ic_stats.is_rx_nombuf++;
	m_freem(m0);
	m_freem(n0);
	return NULL;
}