#include <sys/atomic.h>
#include <sys/cmn_err.h>
#include <sys/cpuvar.h>
#include <sys/debug.h>
#include <sys/errno.h>
#include <sys/kmem.h>
#include <sys/list.h>
#include <sys/md5.h>
#include <sys/stdbool.h>
#include <sys/stream.h>
#include <sys/stropts.h>
#include <sys/strsubr.h>
#include <sys/strsun.h>
#include <sys/sysmacros.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <netinet/ip6.h>
#include <net/pfkeyv2.h>
#include <net/pfpolicy.h>
#include <inet/common.h>
#include <inet/mi.h>
#include <inet/ip.h>
#include <inet/ip6.h>
#include <inet/ip_if.h>
#include <inet/tcp_stats.h>
#include <inet/keysock.h>
#include <inet/sadb.h>
#include <inet/tcp_sig.h>
static void tcpsig_sa_free(tcpsig_sa_t *);
void
tcpsig_init(tcp_stack_t *tcps)
{
mutex_init(&tcps->tcps_sigdb_lock, NULL, MUTEX_DEFAULT, NULL);
}
void
tcpsig_fini(tcp_stack_t *tcps)
{
tcpsig_db_t *db;
if ((db = tcps->tcps_sigdb) != NULL) {
tcpsig_sa_t *sa;
rw_destroy(&db->td_lock);
while ((sa = list_remove_head(&db->td_salist)) != NULL)
tcpsig_sa_free(sa);
list_destroy(&db->td_salist);
kmem_free(tcps->tcps_sigdb, sizeof (tcpsig_db_t));
tcps->tcps_sigdb = NULL;
}
mutex_destroy(&tcps->tcps_sigdb_lock);
}
static tcpsig_db_t *
tcpsig_db(tcp_stack_t *tcps)
{
mutex_enter(&tcps->tcps_sigdb_lock);
if (tcps->tcps_sigdb == NULL) {
tcpsig_db_t *db = kmem_alloc(sizeof (tcpsig_db_t), KM_SLEEP);
rw_init(&db->td_lock, NULL, RW_DEFAULT, 0);
list_create(&db->td_salist, sizeof (tcpsig_sa_t),
offsetof(tcpsig_sa_t, ts_link));
tcps->tcps_sigdb = db;
}
mutex_exit(&tcps->tcps_sigdb_lock);
return ((tcpsig_db_t *)tcps->tcps_sigdb);
}
static uint8_t *
tcpsig_make_sa_ext(uint8_t *start, const uint8_t * const end,
const tcpsig_sa_t *sa)
{
sadb_sa_t *assoc;
ASSERT3P(end, >, start);
if (start == NULL || end - start < sizeof (*assoc))
return (NULL);
assoc = (sadb_sa_t *)start;
assoc->sadb_sa_exttype = SADB_EXT_SA;
assoc->sadb_sa_len = SADB_8TO64(sizeof (*assoc));
assoc->sadb_sa_auth = sa->ts_key.sak_algid;
assoc->sadb_sa_flags = SADB_X_SAFLAGS_TCPSIG;
assoc->sadb_sa_state = sa->ts_state;
return ((uint8_t *)(assoc + 1));
}
static size_t
tcpsig_addr_extsize(const tcpsig_sa_t *sa)
{
size_t addrsize = 0;
switch (sa->ts_family) {
case AF_INET:
addrsize = roundup(sizeof (sin_t) +
sizeof (sadb_address_t), sizeof (uint64_t));
break;
case AF_INET6:
addrsize = roundup(sizeof (sin6_t) +
sizeof (sadb_address_t), sizeof (uint64_t));
break;
}
return (addrsize);
}
static uint8_t *
tcpsig_make_addr_ext(uint8_t *start, const uint8_t * const end,
uint16_t exttype, sa_family_t af, const struct sockaddr_storage *addr)
{
uint8_t *cur = start;
unsigned int addrext_len;
sadb_address_t *addrext;
ASSERT(af == AF_INET || af == AF_INET6);
ASSERT3P(end, >, start);
if (cur == NULL)
return (NULL);
if (end - cur < sizeof (*addrext))
return (NULL);
addrext = (sadb_address_t *)cur;
addrext->sadb_address_proto = IPPROTO_TCP;
addrext->sadb_address_reserved = 0;
addrext->sadb_address_prefixlen = 0;
addrext->sadb_address_exttype = exttype;
cur = (uint8_t *)(addrext + 1);
if (af == AF_INET) {
sin_t *sin;
if (end - cur < sizeof (*sin))
return (NULL);
sin = (sin_t *)cur;
*sin = sin_null;
bcopy(addr, sin, sizeof (*sin));
cur = (uint8_t *)(sin + 1);
} else {
sin6_t *sin6;
if (end - cur < sizeof (*sin6))
return (NULL);
sin6 = (sin6_t *)cur;
*sin6 = sin6_null;
bcopy(addr, sin6, sizeof (*sin6));
cur = (uint8_t *)(sin6 + 1);
}
addrext_len = roundup(cur - start, sizeof (uint64_t));
addrext->sadb_address_len = SADB_8TO64(addrext_len);
if (end - start < addrext_len)
return (NULL);
return (start + addrext_len);
}
#define SET_EXPIRE(sa, delta, exp) do { \
if (((sa)->ts_ ## delta) != 0) { \
(sa)->ts_ ## exp = tcpsig_add_time((sa)->ts_addtime, \
(sa)->ts_ ## delta); \
} \
} while (0)
#define UPDATE_EXPIRE(sa, delta, exp) do { \
if (((sa)->ts_ ## delta) != 0) { \
time_t tmp = tcpsig_add_time((sa)->ts_usetime, \
(sa)->ts_ ## delta); \
if (((sa)->ts_ ## exp) == 0) \
(sa)->ts_ ## exp = tmp; \
else \
(sa)->ts_ ## exp = MIN((sa)->ts_ ## exp, tmp); \
} \
} while (0)
#define EXPIRED(sa, exp, now) \
((sa)->ts_ ## exp != 0 && sa->ts_ ## exp < (now))
#define TIME_MAX INT64_MAX
static time_t
tcpsig_add_time(time_t base, uint64_t delta)
{
if (delta > TIME_MAX)
delta = TIME_MAX;
if (base > 0) {
if (TIME_MAX - base < delta)
return (TIME_MAX);
}
return (base + delta);
}
static int
tcpsig_check_lifetimes(sadb_lifetime_t *hard, sadb_lifetime_t *soft)
{
if (hard == NULL || soft == NULL)
return (SADB_X_DIAGNOSTIC_NONE);
if (hard->sadb_lifetime_addtime != 0 &&
soft->sadb_lifetime_addtime != 0 &&
hard->sadb_lifetime_addtime < soft->sadb_lifetime_addtime) {
return (SADB_X_DIAGNOSTIC_ADDTIME_HSERR);
}
if (hard->sadb_lifetime_usetime != 0 &&
soft->sadb_lifetime_usetime != 0 &&
hard->sadb_lifetime_usetime < soft->sadb_lifetime_usetime) {
return (SADB_X_DIAGNOSTIC_USETIME_HSERR);
}
return (SADB_X_DIAGNOSTIC_NONE);
}
static void
tcpsig_update_lifetimes(tcpsig_sa_t *sa, sadb_lifetime_t *hard,
sadb_lifetime_t *soft)
{
const time_t now = gethrestime_sec();
mutex_enter(&sa->ts_lock);
if (hard != NULL) {
if (hard->sadb_lifetime_usetime != 0)
sa->ts_harduselt = hard->sadb_lifetime_usetime;
if (hard->sadb_lifetime_addtime != 0)
sa->ts_hardaddlt = hard->sadb_lifetime_addtime;
if (sa->ts_hardaddlt != 0)
SET_EXPIRE(sa, hardaddlt, hardexpiretime);
if (sa->ts_harduselt != 0 && sa->ts_usetime != 0)
UPDATE_EXPIRE(sa, harduselt, hardexpiretime);
if (sa->ts_state == SADB_SASTATE_DEAD &&
!EXPIRED(sa, hardexpiretime, now)) {
sa->ts_state = SADB_SASTATE_MATURE;
}
}
if (soft != NULL) {
if (soft->sadb_lifetime_usetime != 0) {
sa->ts_softuselt = MIN(sa->ts_harduselt,
soft->sadb_lifetime_usetime);
}
if (soft->sadb_lifetime_addtime != 0) {
sa->ts_softaddlt = MIN(sa->ts_hardaddlt,
soft->sadb_lifetime_addtime);
}
if (sa->ts_softaddlt != 0)
SET_EXPIRE(sa, softaddlt, softexpiretime);
if (sa->ts_softuselt != 0 && sa->ts_usetime != 0)
UPDATE_EXPIRE(sa, softuselt, softexpiretime);
if (sa->ts_state == SADB_SASTATE_DYING &&
!EXPIRED(sa, softexpiretime, now)) {
sa->ts_state = SADB_SASTATE_MATURE;
}
}
mutex_exit(&sa->ts_lock);
}
static void
tcpsig_sa_touch(tcpsig_sa_t *sa)
{
const time_t now = gethrestime_sec();
mutex_enter(&sa->ts_lock);
sa->ts_lastuse = now;
if (sa->ts_usetime == 0) {
sa->ts_usetime = now;
UPDATE_EXPIRE(sa, softuselt, softexpiretime);
UPDATE_EXPIRE(sa, harduselt, hardexpiretime);
}
mutex_exit(&sa->ts_lock);
}
static void
tcpsig_sa_expiremsg(keysock_t *ks, const tcpsig_sa_t *sa, int ltt)
{
size_t alloclen;
sadb_sa_t *assoc;
sadb_msg_t *samsg;
sadb_lifetime_t *lt;
uint8_t *cur, *end;
mblk_t *mp;
alloclen = sizeof (sadb_msg_t) + sizeof (sadb_sa_t) +
2 * sizeof (sadb_lifetime_t) + 2 * tcpsig_addr_extsize(sa);
mp = allocb(alloclen, BPRI_HI);
if (mp == NULL)
return;
bzero(mp->b_rptr, alloclen);
mp->b_wptr += alloclen;
end = mp->b_wptr;
samsg = (sadb_msg_t *)mp->b_rptr;
samsg->sadb_msg_version = PF_KEY_V2;
samsg->sadb_msg_type = SADB_EXPIRE;
samsg->sadb_msg_errno = 0;
samsg->sadb_msg_satype = SADB_X_SATYPE_TCPSIG;
samsg->sadb_msg_reserved = 0;
samsg->sadb_msg_seq = 0;
samsg->sadb_msg_pid = 0;
samsg->sadb_msg_len = (uint16_t)SADB_8TO64(alloclen);
cur = (uint8_t *)(samsg + 1);
cur = tcpsig_make_sa_ext(cur, end, sa);
cur = tcpsig_make_addr_ext(cur, end, SADB_EXT_ADDRESS_SRC,
sa->ts_family, &sa->ts_src);
cur = tcpsig_make_addr_ext(cur, end, SADB_EXT_ADDRESS_DST,
sa->ts_family, &sa->ts_dst);
if (cur == NULL) {
freeb(mp);
return;
}
lt = (sadb_lifetime_t *)cur;
lt->sadb_lifetime_len = SADB_8TO64(sizeof (*lt));
lt->sadb_lifetime_exttype = SADB_EXT_LIFETIME_CURRENT;
lt->sadb_lifetime_allocations = 0;
lt->sadb_lifetime_bytes = 0;
lt->sadb_lifetime_addtime = sa->ts_addtime;
lt->sadb_lifetime_usetime = sa->ts_usetime;
lt++;
lt->sadb_lifetime_len = SADB_8TO64(sizeof (*lt));
lt->sadb_lifetime_exttype = ltt;
lt->sadb_lifetime_allocations = 0;
lt->sadb_lifetime_bytes = 0;
lt->sadb_lifetime_addtime = sa->ts_hardaddlt;
lt->sadb_lifetime_usetime = sa->ts_harduselt;
keysock_passup(mp, (sadb_msg_t *)mp->b_rptr,
0, NULL, B_TRUE, ks->keysock_keystack);
}
static void
tcpsig_sa_age(keysock_t *ks, tcp_stack_t *tcps)
{
tcpsig_db_t *db = tcpsig_db(tcps);
tcpsig_sa_t *nextsa;
const time_t now = gethrestime_sec();
rw_enter(&db->td_lock, RW_WRITER);
nextsa = list_head(&db->td_salist);
while (nextsa != NULL) {
tcpsig_sa_t *sa = nextsa;
nextsa = list_next(&db->td_salist, sa);
mutex_enter(&sa->ts_lock);
if (sa->ts_tombstoned) {
mutex_exit(&sa->ts_lock);
continue;
}
if (EXPIRED(sa, hardexpiretime, now)) {
sa->ts_state = IPSA_STATE_DEAD;
tcpsig_sa_expiremsg(ks, sa, SADB_EXT_LIFETIME_HARD);
if (sa->ts_refcnt > 0) {
sa->ts_tombstoned = true;
mutex_exit(&sa->ts_lock);
} else {
list_remove(&db->td_salist, sa);
mutex_exit(&sa->ts_lock);
tcpsig_sa_free(sa);
}
continue;
}
if (EXPIRED(sa, softexpiretime, now) &&
sa->ts_state == IPSA_STATE_MATURE) {
sa->ts_state = IPSA_STATE_DYING;
tcpsig_sa_expiremsg(ks, sa, SADB_EXT_LIFETIME_SOFT);
}
mutex_exit(&sa->ts_lock);
}
rw_exit(&db->td_lock);
}
static void
tcpsig_sa_free(tcpsig_sa_t *sa)
{
ASSERT0(sa->ts_refcnt);
mutex_destroy(&sa->ts_lock);
kmem_free(sa->ts_key.sak_key, sa->ts_key.sak_keylen);
kmem_free(sa, sizeof (*sa));
}
void
tcpsig_sa_rele(tcpsig_sa_t *sa)
{
mutex_enter(&sa->ts_lock);
VERIFY3U(sa->ts_refcnt, >, 0);
sa->ts_refcnt--;
if (sa->ts_tombstoned && sa->ts_refcnt == 0) {
tcpsig_db_t *db = tcpsig_db(sa->ts_stack);
sa->ts_refcnt++;
mutex_exit(&sa->ts_lock);
rw_enter(&db->td_lock, RW_WRITER);
mutex_enter(&sa->ts_lock);
sa->ts_refcnt--;
mutex_exit(&sa->ts_lock);
list_remove(&db->td_salist, sa);
rw_exit(&db->td_lock);
tcpsig_sa_free(sa);
} else {
mutex_exit(&sa->ts_lock);
}
}
static bool
tcpsig_sa_match4(tcpsig_sa_t *sa, struct sockaddr_storage *src_s,
struct sockaddr_storage *dst_s)
{
sin_t msrc, mdst, *src, *dst, *sasrc, *sadst;
if (src_s->ss_family != AF_INET)
return (false);
src = (sin_t *)src_s;
dst = (sin_t *)dst_s;
if (sa->ts_family == AF_INET6) {
sin6_t *sasrc6 = (sin6_t *)&sa->ts_src;
sin6_t *sadst6 = (sin6_t *)&sa->ts_dst;
if (!IN6_IS_ADDR_V4MAPPED(&sasrc6->sin6_addr) ||
!IN6_IS_ADDR_V4MAPPED(&sadst6->sin6_addr)) {
return (false);
}
msrc = sin_null;
msrc.sin_family = AF_INET;
msrc.sin_port = sasrc6->sin6_port;
IN6_V4MAPPED_TO_INADDR(&sasrc6->sin6_addr, &msrc.sin_addr);
sasrc = &msrc;
mdst = sin_null;
mdst.sin_family = AF_INET;
mdst.sin_port = sadst6->sin6_port;
IN6_V4MAPPED_TO_INADDR(&sadst6->sin6_addr, &mdst.sin_addr);
sadst = &mdst;
} else {
sasrc = (sin_t *)&sa->ts_src;
sadst = (sin_t *)&sa->ts_dst;
}
if (sasrc->sin_port != 0 && sasrc->sin_port != src->sin_port)
return (false);
if (sadst->sin_port != 0 && sadst->sin_port != dst->sin_port)
return (false);
if (sasrc->sin_addr.s_addr != src->sin_addr.s_addr)
return (false);
if (sadst->sin_addr.s_addr != dst->sin_addr.s_addr)
return (false);
return (true);
}
static bool
tcpsig_sa_match6(tcpsig_sa_t *sa, struct sockaddr_storage *src_s,
struct sockaddr_storage *dst_s)
{
sin6_t *src, *dst, *sasrc, *sadst;
if (src_s->ss_family != AF_INET6 || sa->ts_src.ss_family != AF_INET6)
return (false);
src = (sin6_t *)src_s;
dst = (sin6_t *)dst_s;
sasrc = (sin6_t *)&sa->ts_src;
sadst = (sin6_t *)&sa->ts_dst;
if (sasrc->sin6_port != 0 && sasrc->sin6_port != src->sin6_port)
return (false);
if (sadst->sin6_port != 0 && sadst->sin6_port != dst->sin6_port)
return (false);
if (!IN6_ARE_ADDR_EQUAL(&sasrc->sin6_addr, &src->sin6_addr))
return (false);
if (!IN6_ARE_ADDR_EQUAL(&sadst->sin6_addr, &dst->sin6_addr))
return (false);
return (true);
}
static tcpsig_sa_t *
tcpsig_sa_find_held(struct sockaddr_storage *src, struct sockaddr_storage *dst,
tcp_stack_t *tcps)
{
tcpsig_db_t *db = tcpsig_db(tcps);
tcpsig_sa_t *sa = NULL;
const time_t now = gethrestime_sec();
ASSERT(RW_LOCK_HELD(&db->td_lock));
if (src->ss_family != dst->ss_family)
return (NULL);
for (sa = list_head(&db->td_salist); sa != NULL;
sa = list_next(&db->td_salist, sa)) {
mutex_enter(&sa->ts_lock);
if (sa->ts_tombstoned || EXPIRED(sa, hardexpiretime, now)) {
mutex_exit(&sa->ts_lock);
continue;
}
if (tcpsig_sa_match4(sa, src, dst) ||
tcpsig_sa_match6(sa, src, dst)) {
sa->ts_refcnt++;
mutex_exit(&sa->ts_lock);
break;
}
mutex_exit(&sa->ts_lock);
}
return (sa);
}
static tcpsig_sa_t *
tcpsig_sa_find(struct sockaddr_storage *src, struct sockaddr_storage *dst,
tcp_stack_t *tcps)
{
tcpsig_db_t *db = tcpsig_db(tcps);
tcpsig_sa_t *sa;
rw_enter(&db->td_lock, RW_READER);
sa = tcpsig_sa_find_held(src, dst, tcps);
rw_exit(&db->td_lock);
return (sa);
}
static int
tcpsig_sa_flush(keysock_t *ks, tcp_stack_t *tcps, int *diagp)
{
tcpsig_db_t *db = tcpsig_db(tcps);
tcpsig_sa_t *nextsa;
rw_enter(&db->td_lock, RW_WRITER);
nextsa = list_head(&db->td_salist);
while (nextsa != NULL) {
tcpsig_sa_t *sa = nextsa;
nextsa = list_next(&db->td_salist, sa);
mutex_enter(&sa->ts_lock);
if (sa->ts_refcnt > 0) {
sa->ts_tombstoned = true;
mutex_exit(&sa->ts_lock);
continue;
}
list_remove(&db->td_salist, sa);
mutex_exit(&sa->ts_lock);
tcpsig_sa_free(sa);
}
rw_exit(&db->td_lock);
return (0);
}
static int
tcpsig_sa_add(keysock_t *ks, tcp_stack_t *tcps, keysock_in_t *ksi,
sadb_ext_t **extv, int *diagp)
{
tcpsig_db_t *db;
sadb_address_t *srcext, *dstext;
sadb_lifetime_t *soft, *hard;
sadb_sa_t *assoc;
struct sockaddr_storage *src, *dst;
sadb_key_t *key;
tcpsig_sa_t *sa, *dupsa;
int ret = 0;
assoc = (sadb_sa_t *)extv[SADB_EXT_SA];
srcext = (sadb_address_t *)extv[SADB_EXT_ADDRESS_SRC];
dstext = (sadb_address_t *)extv[SADB_EXT_ADDRESS_DST];
key = (sadb_key_t *)extv[SADB_X_EXT_STR_AUTH];
soft = (sadb_lifetime_t *)extv[SADB_EXT_LIFETIME_SOFT];
hard = (sadb_lifetime_t *)extv[SADB_EXT_LIFETIME_HARD];
if (assoc == NULL) {
*diagp = SADB_X_DIAGNOSTIC_MISSING_SA;
return (EINVAL);
}
if (srcext == NULL) {
*diagp = SADB_X_DIAGNOSTIC_MISSING_SRC;
return (EINVAL);
}
if (dstext == NULL) {
*diagp = SADB_X_DIAGNOSTIC_MISSING_DST;
return (EINVAL);
}
if (key == NULL) {
*diagp = SADB_X_DIAGNOSTIC_MISSING_ASTR;
return (EINVAL);
}
if ((*diagp = tcpsig_check_lifetimes(hard, soft)) !=
SADB_X_DIAGNOSTIC_NONE) {
return (EINVAL);
}
src = (struct sockaddr_storage *)(srcext + 1);
dst = (struct sockaddr_storage *)(dstext + 1);
if (src->ss_family != dst->ss_family) {
*diagp = SADB_X_DIAGNOSTIC_AF_MISMATCH;
return (EINVAL);
}
if (src->ss_family != AF_INET && src->ss_family != AF_INET6) {
*diagp = SADB_X_DIAGNOSTIC_BAD_SRC_AF;
return (EINVAL);
}
if (assoc->sadb_sa_auth != SADB_AALG_MD5) {
*diagp = SADB_X_DIAGNOSTIC_BAD_AALG;
return (EINVAL);
}
if ((key->sadb_key_bits & 0x7) != 0) {
*diagp = SADB_X_DIAGNOSTIC_MALFORMED_AKEY;
return (EINVAL);
}
db = tcpsig_db(tcps);
sa = kmem_zalloc(sizeof (*sa), KM_NOSLEEP_LAZY);
if (sa == NULL)
return (ENOMEM);
sa->ts_stack = tcps;
sa->ts_family = src->ss_family;
if (sa->ts_family == AF_INET6) {
bcopy(src, (sin6_t *)&sa->ts_src, sizeof (sin6_t));
bcopy(dst, (sin6_t *)&sa->ts_dst, sizeof (sin6_t));
} else {
bcopy(src, (sin_t *)&sa->ts_src, sizeof (sin_t));
bcopy(dst, (sin_t *)&sa->ts_dst, sizeof (sin_t));
}
sa->ts_key.sak_algid = assoc->sadb_sa_auth;
sa->ts_key.sak_keylen = SADB_1TO8(key->sadb_key_bits);
sa->ts_key.sak_keybits = key->sadb_key_bits;
sa->ts_key.sak_key = kmem_alloc(sa->ts_key.sak_keylen,
KM_NOSLEEP_LAZY);
if (sa->ts_key.sak_key == NULL) {
kmem_free(sa, sizeof (*sa));
return (ENOMEM);
}
bcopy(key + 1, sa->ts_key.sak_key, sa->ts_key.sak_keylen);
bzero(key + 1, sa->ts_key.sak_keylen);
mutex_init(&sa->ts_lock, NULL, MUTEX_DEFAULT, NULL);
sa->ts_state = SADB_SASTATE_MATURE;
sa->ts_addtime = gethrestime_sec();
sa->ts_usetime = 0;
if (soft != NULL) {
sa->ts_softaddlt = soft->sadb_lifetime_addtime;
sa->ts_softuselt = soft->sadb_lifetime_usetime;
SET_EXPIRE(sa, softaddlt, softexpiretime);
}
if (hard != NULL) {
sa->ts_hardaddlt = hard->sadb_lifetime_addtime;
sa->ts_harduselt = hard->sadb_lifetime_usetime;
SET_EXPIRE(sa, hardaddlt, hardexpiretime);
}
sa->ts_refcnt = 0;
sa->ts_tombstoned = false;
rw_enter(&db->td_lock, RW_WRITER);
if ((dupsa = tcpsig_sa_find_held(src, dst, tcps)) != NULL) {
rw_exit(&db->td_lock);
tcpsig_sa_rele(dupsa);
tcpsig_sa_free(sa);
*diagp = SADB_X_DIAGNOSTIC_DUPLICATE_SA;
ret = EEXIST;
} else {
list_insert_tail(&db->td_salist, sa);
rw_exit(&db->td_lock);
}
return (ret);
}
static int
tcpsig_sa_update(keysock_t *ks, tcp_stack_t *tcps, keysock_in_t *ksi,
sadb_ext_t **extv, int *diagp)
{
tcpsig_db_t *db;
sadb_address_t *srcext, *dstext;
sadb_lifetime_t *soft, *hard;
struct sockaddr_storage *src, *dst;
tcpsig_sa_t *sa;
srcext = (sadb_address_t *)extv[SADB_EXT_ADDRESS_SRC];
dstext = (sadb_address_t *)extv[SADB_EXT_ADDRESS_DST];
soft = (sadb_lifetime_t *)extv[SADB_EXT_LIFETIME_SOFT];
hard = (sadb_lifetime_t *)extv[SADB_EXT_LIFETIME_HARD];
if (srcext == NULL) {
*diagp = SADB_X_DIAGNOSTIC_MISSING_SRC;
return (EINVAL);
}
if (dstext == NULL) {
*diagp = SADB_X_DIAGNOSTIC_MISSING_DST;
return (EINVAL);
}
if ((*diagp = tcpsig_check_lifetimes(hard, soft)) !=
SADB_X_DIAGNOSTIC_NONE) {
return (EINVAL);
}
src = (struct sockaddr_storage *)(srcext + 1);
dst = (struct sockaddr_storage *)(dstext + 1);
sa = tcpsig_sa_find(src, dst, tcps);
if (sa == NULL) {
*diagp = SADB_X_DIAGNOSTIC_PAIR_SA_NOTFOUND;
return (ESRCH);
}
tcpsig_update_lifetimes(sa, hard, soft);
tcpsig_sa_rele(sa);
tcpsig_sa_age(ks, tcps);
return (0);
}
static mblk_t *
tcpsig_dump_one(const tcpsig_sa_t *sa, sadb_msg_t *samsg)
{
size_t alloclen, keysize;
sadb_sa_t *assoc;
sadb_msg_t *newsamsg;
uint8_t *cur, *end;
sadb_key_t *key;
mblk_t *mp;
bool soft = false, hard = false;
ASSERT(MUTEX_HELD(&sa->ts_lock));
alloclen = sizeof (sadb_msg_t) + sizeof (sadb_sa_t) +
2 * tcpsig_addr_extsize(sa);
if (sa->ts_softaddlt != 0 || sa->ts_softuselt != 0) {
alloclen += sizeof (sadb_lifetime_t);
soft = true;
}
if (sa->ts_hardaddlt != 0 || sa->ts_harduselt != 0) {
alloclen += sizeof (sadb_lifetime_t);
hard = true;
}
if (soft || hard)
alloclen += sizeof (sadb_lifetime_t);
keysize = roundup(sizeof (sadb_key_t) + sa->ts_key.sak_keylen,
sizeof (uint64_t));
alloclen += keysize;
mp = allocb(alloclen, BPRI_HI);
if (mp == NULL)
return (NULL);
bzero(mp->b_rptr, alloclen);
mp->b_wptr += alloclen;
end = mp->b_wptr;
newsamsg = (sadb_msg_t *)mp->b_rptr;
*newsamsg = *samsg;
newsamsg->sadb_msg_len = (uint16_t)SADB_8TO64(alloclen);
cur = (uint8_t *)(newsamsg + 1);
cur = tcpsig_make_sa_ext(cur, end, sa);
cur = tcpsig_make_addr_ext(cur, end, SADB_EXT_ADDRESS_SRC,
sa->ts_family, &sa->ts_src);
cur = tcpsig_make_addr_ext(cur, end, SADB_EXT_ADDRESS_DST,
sa->ts_family, &sa->ts_dst);
if (cur == NULL) {
freeb(mp);
return (NULL);
}
if (soft || hard) {
sadb_lifetime_t *lt = (sadb_lifetime_t *)cur;
lt->sadb_lifetime_len = SADB_8TO64(sizeof (*lt));
lt->sadb_lifetime_exttype = SADB_EXT_LIFETIME_CURRENT;
lt->sadb_lifetime_allocations = 0;
lt->sadb_lifetime_bytes = 0;
lt->sadb_lifetime_addtime = sa->ts_addtime;
lt->sadb_lifetime_usetime = sa->ts_usetime;
lt++;
if (soft) {
lt->sadb_lifetime_len = SADB_8TO64(sizeof (*lt));
lt->sadb_lifetime_exttype = SADB_EXT_LIFETIME_SOFT;
lt->sadb_lifetime_allocations = 0;
lt->sadb_lifetime_bytes = 0;
lt->sadb_lifetime_addtime = sa->ts_softaddlt;
lt->sadb_lifetime_usetime = sa->ts_softuselt;
lt++;
}
if (hard) {
lt->sadb_lifetime_len = SADB_8TO64(sizeof (*lt));
lt->sadb_lifetime_exttype = SADB_EXT_LIFETIME_HARD;
lt->sadb_lifetime_allocations = 0;
lt->sadb_lifetime_bytes = 0;
lt->sadb_lifetime_addtime = sa->ts_hardaddlt;
lt->sadb_lifetime_usetime = sa->ts_harduselt;
lt++;
}
cur = (uint8_t *)lt;
}
key = (sadb_key_t *)cur;
key->sadb_key_exttype = SADB_X_EXT_STR_AUTH;
key->sadb_key_len = SADB_8TO64(keysize);
key->sadb_key_bits = sa->ts_key.sak_keybits;
key->sadb_key_reserved = 0;
bcopy(sa->ts_key.sak_key, (uint8_t *)(key + 1), sa->ts_key.sak_keylen);
return (mp);
}
static int
tcpsig_sa_dump(keysock_t *ks, tcp_stack_t *tcps, sadb_msg_t *samsg, int *diag)
{
tcpsig_db_t *db;
tcpsig_sa_t *sa;
db = tcpsig_db(tcps);
rw_enter(&db->td_lock, RW_READER);
for (sa = list_head(&db->td_salist); sa != NULL;
sa = list_next(&db->td_salist, sa)) {
mblk_t *mp;
mutex_enter(&sa->ts_lock);
if (sa->ts_tombstoned) {
mutex_exit(&sa->ts_lock);
continue;
}
mp = tcpsig_dump_one(sa, samsg);
mutex_exit(&sa->ts_lock);
if (mp == NULL) {
rw_exit(&db->td_lock);
return (ENOMEM);
}
keysock_passup(mp, (sadb_msg_t *)mp->b_rptr,
ks->keysock_serial, NULL, B_TRUE, ks->keysock_keystack);
}
rw_exit(&db->td_lock);
samsg->sadb_msg_seq = 0;
return (0);
}
static int
tcpsig_sa_delget(keysock_t *ks, tcp_stack_t *tcps, sadb_msg_t *samsg,
sadb_ext_t **extv, int *diagp)
{
sadb_address_t *srcext, *dstext;
struct sockaddr_storage *src, *dst;
tcpsig_sa_t *sa;
mblk_t *mp;
srcext = (sadb_address_t *)extv[SADB_EXT_ADDRESS_SRC];
dstext = (sadb_address_t *)extv[SADB_EXT_ADDRESS_DST];
if (srcext == NULL) {
*diagp = SADB_X_DIAGNOSTIC_MISSING_SRC;
return (EINVAL);
}
if (dstext == NULL) {
*diagp = SADB_X_DIAGNOSTIC_MISSING_DST;
return (EINVAL);
}
src = (struct sockaddr_storage *)(srcext + 1);
dst = (struct sockaddr_storage *)(dstext + 1);
sa = tcpsig_sa_find(src, dst, tcps);
if (sa == NULL) {
*diagp = SADB_X_DIAGNOSTIC_PAIR_SA_NOTFOUND;
return (ESRCH);
}
if (samsg->sadb_msg_type == SADB_GET) {
mutex_enter(&sa->ts_lock);
mp = tcpsig_dump_one(sa, samsg);
mutex_exit(&sa->ts_lock);
if (mp == NULL) {
tcpsig_sa_rele(sa);
return (ENOMEM);
}
keysock_passup(mp, (sadb_msg_t *)mp->b_rptr,
ks->keysock_serial, NULL, B_TRUE, ks->keysock_keystack);
tcpsig_sa_rele(sa);
return (0);
}
mutex_enter(&sa->ts_lock);
sa->ts_tombstoned = true;
mutex_exit(&sa->ts_lock);
tcpsig_sa_rele(sa);
return (0);
}
void
tcpsig_sa_handler(keysock_t *ks, mblk_t *mp, sadb_msg_t *samsg,
sadb_ext_t **extv)
{
keysock_stack_t *keystack = ks->keysock_keystack;
netstack_t *nst = keystack->keystack_netstack;
tcp_stack_t *tcps = nst->netstack_tcp;
keysock_in_t *ksi = (keysock_in_t *)mp->b_rptr;
int diag = SADB_X_DIAGNOSTIC_NONE;
int error;
tcpsig_sa_age(ks, tcps);
switch (samsg->sadb_msg_type) {
case SADB_ADD:
error = tcpsig_sa_add(ks, tcps, ksi, extv, &diag);
keysock_error(ks, mp, error, diag);
break;
case SADB_UPDATE:
error = tcpsig_sa_update(ks, tcps, ksi, extv, &diag);
keysock_error(ks, mp, error, diag);
break;
case SADB_GET:
case SADB_DELETE:
error = tcpsig_sa_delget(ks, tcps, samsg, extv, &diag);
keysock_error(ks, mp, error, diag);
break;
case SADB_FLUSH:
error = tcpsig_sa_flush(ks, tcps, &diag);
keysock_error(ks, mp, error, diag);
break;
case SADB_DUMP:
error = tcpsig_sa_dump(ks, tcps, samsg, &diag);
keysock_error(ks, mp, error, diag);
break;
default:
keysock_error(ks, mp, EOPNOTSUPP, diag);
break;
}
}
bool
tcpsig_sa_exists(tcp_t *tcp, bool inbound, tcpsig_sa_t **sap)
{
tcp_stack_t *tcps = tcp->tcp_tcps;
conn_t *connp = tcp->tcp_connp;
struct sockaddr_storage src, dst;
tcpsig_sa_t *sa;
bzero(&src, sizeof (src));
bzero(&dst, sizeof (dst));
if (connp->conn_ipversion == IPV6_VERSION) {
sin6_t *sin6;
sin6 = (sin6_t *)&src;
sin6->sin6_family = AF_INET6;
if (inbound) {
sin6->sin6_addr = connp->conn_faddr_v6;
sin6->sin6_port = connp->conn_fport;
} else {
sin6->sin6_addr = connp->conn_saddr_v6;
sin6->sin6_port = connp->conn_lport;
}
sin6 = (sin6_t *)&dst;
sin6->sin6_family = AF_INET6;
if (inbound) {
sin6->sin6_addr = connp->conn_saddr_v6;
sin6->sin6_port = connp->conn_lport;
} else {
sin6->sin6_addr = connp->conn_faddr_v6;
sin6->sin6_port = connp->conn_fport;
}
} else {
sin_t *sin;
sin = (sin_t *)&src;
sin->sin_family = AF_INET;
if (inbound) {
sin->sin_addr.s_addr = connp->conn_faddr_v4;
sin->sin_port = connp->conn_fport;
} else {
sin->sin_addr.s_addr = connp->conn_saddr_v4;
sin->sin_port = connp->conn_lport;
}
sin = (sin_t *)&dst;
sin->sin_family = AF_INET;
if (inbound) {
sin->sin_addr.s_addr = connp->conn_saddr_v4;
sin->sin_port = connp->conn_lport;
} else {
sin->sin_addr.s_addr = connp->conn_faddr_v4;
sin->sin_port = connp->conn_fport;
}
}
sa = tcpsig_sa_find(&src, &dst, tcps);
if (sa == NULL)
return (false);
if (sap != NULL)
*sap = sa;
else
tcpsig_sa_rele(sa);
return (true);
}
static void
tcpsig_pseudo_compute4(tcp_t *tcp, int tcplen, MD5_CTX *ctx, bool inbound)
{
struct ip_pseudo {
struct in_addr ipp_src;
struct in_addr ipp_dst;
uint8_t ipp_pad;
uint8_t ipp_proto;
uint16_t ipp_len;
} ipp;
conn_t *connp = tcp->tcp_connp;
if (inbound) {
ipp.ipp_src.s_addr = connp->conn_faddr_v4;
ipp.ipp_dst.s_addr = connp->conn_saddr_v4;
} else {
ipp.ipp_src.s_addr = connp->conn_saddr_v4;
ipp.ipp_dst.s_addr = connp->conn_faddr_v4;
}
ipp.ipp_pad = 0;
ipp.ipp_proto = IPPROTO_TCP;
ipp.ipp_len = htons(tcplen);
DTRACE_PROBE1(ipp4, struct ip_pseudo *, &ipp);
MD5Update(ctx, (char *)&ipp, sizeof (ipp));
}
static void
tcpsig_pseudo_compute6(tcp_t *tcp, int tcplen, MD5_CTX *ctx, bool inbound)
{
struct ip6_pseudo {
struct in6_addr ipp_src;
struct in6_addr ipp_dst;
uint32_t ipp_len;
uint32_t ipp_nxt;
} ip6p;
conn_t *connp = tcp->tcp_connp;
if (inbound) {
ip6p.ipp_src = connp->conn_faddr_v6;
ip6p.ipp_dst = connp->conn_saddr_v6;
} else {
ip6p.ipp_src = connp->conn_saddr_v6;
ip6p.ipp_dst = connp->conn_faddr_v6;
}
ip6p.ipp_len = htonl(tcplen);
ip6p.ipp_nxt = htonl(IPPROTO_TCP);
DTRACE_PROBE1(ipp6, struct ip6_pseudo *, &ip6p);
MD5Update(ctx, (char *)&ip6p, sizeof (ip6p));
}
bool
tcpsig_signature(mblk_t *mp, tcp_t *tcp, tcpha_t *tcpha, int tcplen,
uint8_t *digest, bool inbound)
{
tcp_stack_t *tcps = tcp->tcp_tcps;
conn_t *connp = tcp->tcp_connp;
tcpsig_sa_t *sa;
MD5_CTX context;
if (!inbound && (tcpha->tha_offset_and_reserved >> 4) > 10) {
TCP_STAT(tcps, tcp_sig_no_space);
return (false);
}
sa = inbound ? tcp->tcp_sig_sa_in : tcp->tcp_sig_sa_out;
if (sa == NULL) {
if (!tcpsig_sa_exists(tcp, inbound, &sa)) {
TCP_STAT(tcps, tcp_sig_match_failed);
return (false);
}
if (inbound)
tcp->tcp_sig_sa_in = sa;
else
tcp->tcp_sig_sa_out = sa;
}
tcpsig_sa_touch(sa);
VERIFY3U(sa->ts_key.sak_algid, ==, SADB_AALG_MD5);
MD5Init(&context);
if (connp->conn_ipversion == IPV6_VERSION)
tcpsig_pseudo_compute6(tcp, tcplen, &context, inbound);
else
tcpsig_pseudo_compute4(tcp, tcplen, &context, inbound);
uint16_t offset = tcpha->tha_offset_and_reserved;
uint16_t sum = tcpha->tha_sum;
if (!inbound) {
tcpha->tha_offset_and_reserved += (5 << 4);
}
tcpha->tha_sum = 0;
MD5Update(&context, tcpha, sizeof (*tcpha));
tcpha->tha_offset_and_reserved = offset;
tcpha->tha_sum = sum;
for (; mp != NULL; mp = mp->b_cont)
MD5Update(&context, mp->b_rptr, mp->b_wptr - mp->b_rptr);
MD5Update(&context, sa->ts_key.sak_key, sa->ts_key.sak_keylen);
MD5Final(digest, &context);
return (true);
}
bool
tcpsig_verify(mblk_t *mp, tcp_t *tcp, tcpha_t *tcpha, ip_recv_attr_t *ira,
uint8_t *digest)
{
uint8_t calc_digest[MD5_DIGEST_LENGTH];
if (!tcpsig_signature(mp, tcp, tcpha,
ira->ira_pktlen - ira->ira_ip_hdr_length, calc_digest, true)) {
return (false);
}
if (bcmp(digest, calc_digest, sizeof (calc_digest)) != 0) {
TCP_STAT(tcp->tcp_tcps, tcp_sig_verify_failed);
return (false);
}
return (true);
}