#include <sys/types.h>
#include <sys/cmn_err.h>
#include <sys/crc32.h>
#include <netinet/in.h>
#include <inet/ip.h>
#include <inet/ip6.h>
#include <inet/tcp.h>
#include <inet/udp_impl.h>
#include <inet/ilb.h>
#include "ilb_impl.h"
#include "ilb_stack.h"
#include "ilb_nat.h"
/* Base interval, in seconds, between runs of the NAT source entry timer. */
#define	ILB_NAT_SRC_TIMEOUT		30
/* The timer interval is extended by (gethrtime() modulo this value) seconds. */
#define	ILB_NAT_SRC_TIMEOUT_JITTER	5

/*
 * Compute a hash bucket index from two 32-bit keys.
 *
 * The CRC32 of key1 is computed with an initial value of -1U, then key2 is
 * folded in using that intermediate result as the starting value, and the
 * outcome is reduced modulo hash_size.  hash must be a modifiable uint32_t
 * lvalue; key1 and key2 are pointers to the 32-bit keys.
 *
 * The body is wrapped in do { } while (0) so that the macro behaves as a
 * single statement, e.g. when used as the unbraced body of an if/else.
 */
#define	ILB_NAT_SRC_HASH(hash, key1, key2, hash_size)			\
do {									\
	CRC32((hash), (key1), sizeof (uint32_t), -1U, crc32_table);	\
	CRC32((hash), (key2), sizeof (uint32_t), (hash), crc32_table);	\
	(hash) %= (hash_size);						\
} while (0)

/* Monotonic counter used to give every port arena a distinct name. */
static uint32_t ilb_nat_src_instance = 0;
/*
 * Increment the given address by one, in place.
 *
 * For a V4-mapped address only the low 32 bits (word 3) are incremented.
 * For any other address the increment is carried through the four 32-bit
 * words, from the least significant (s6_addr32[3]) up to the most
 * significant (s6_addr32[0]).  Each word is converted to host byte order
 * for the arithmetic and stored back in network byte order.
 *
 * Wrapping the whole address around to zero is not expected and trips an
 * ASSERT.
 */
static void
incr_addr(in6_addr_t *a)
{
	uint32_t i;
	int w;

	if (IN6_IS_ADDR_V4MAPPED(a)) {
		i = ntohl(a->s6_addr32[3]);
		a->s6_addr32[3] = htonl(++i);
		ASSERT(i != 0);
		return;
	}

	/*
	 * Words 3..1: add one and stop as soon as a word does not wrap.
	 * A wrapped word is stored as zero (htonl(0) == 0) and the carry
	 * moves on to the next more significant word.  The whole 32-bit
	 * word must be accessed (s6_addr32[]), not a single byte
	 * (s6_addr[]).
	 */
	for (w = 3; w > 0; w--) {
		i = ntohl(a->s6_addr32[w]) + 1;
		a->s6_addr32[w] = htonl(i);
		if (i != 0)
			return;
	}

	/* Most significant word: a wrap here means the address overflowed. */
	i = ntohl(a->s6_addr32[0]) + 1;
	a->s6_addr32[0] = htonl(i);
	ASSERT(i != 0);
}
/*
 * Set up the NAT source hash table of an ILB stack instance.
 *
 * Allocates ilbs_nat_src_hash_size zeroed buckets (sleeping allocation),
 * initializes the list and mutex of every bucket, and arms the first run
 * of ilb_nat_src_timer().  The timer delay is ILB_NAT_SRC_TIMEOUT seconds
 * plus gethrtime() modulo ILB_NAT_SRC_TIMEOUT_JITTER.
 */
void
ilb_nat_src_init(ilb_stack_t *ilbs)
{
	int idx;

	ilbs->ilbs_nat_src = kmem_zalloc(sizeof (ilb_nat_src_hash_t) *
	    ilbs->ilbs_nat_src_hash_size, KM_SLEEP);

	idx = 0;
	while (idx < ilbs->ilbs_nat_src_hash_size) {
		ilb_nat_src_hash_t *bucket = &ilbs->ilbs_nat_src[idx];

		list_create(&bucket->nsh_head, sizeof (ilb_nat_src_entry_t),
		    offsetof(ilb_nat_src_entry_t, nse_link));
		mutex_init(&bucket->nsh_lock, NULL, MUTEX_DEFAULT, NULL);
		idx++;
	}

	/* Schedule the first pass of the cleanup timer. */
	ilbs->ilbs_nat_src_tid = timeout(ilb_nat_src_timer, ilbs,
	    SEC_TO_TICK(ILB_NAT_SRC_TIMEOUT +
	    gethrtime() % ILB_NAT_SRC_TIMEOUT_JITTER));
}
/*
 * Tear down the NAT source hash table of an ILB stack instance.
 *
 * The timer id is cleared under ilbs_nat_src_lock first; the timer handler
 * checks for a zero id and does not re-arm itself in that case.  Any
 * pending timeout is then cancelled, every remaining entry is freed along
 * with its port arena, and finally the bucket array itself is released.
 */
void
ilb_nat_src_fini(ilb_stack_t *ilbs)
{
	ilb_nat_src_entry_t *entry;
	timeout_id_t pending;
	int idx;

	/* Tell the timer handler not to reschedule itself. */
	mutex_enter(&ilbs->ilbs_nat_src_lock);
	pending = ilbs->ilbs_nat_src_tid;
	ilbs->ilbs_nat_src_tid = 0;
	mutex_exit(&ilbs->ilbs_nat_src_lock);

	if (pending != 0)
		(void) untimeout(pending);
	mutex_destroy(&ilbs->ilbs_nat_src_lock);

	for (idx = 0; idx < ilbs->ilbs_nat_src_hash_size; idx++) {
		ilb_nat_src_hash_t *bucket = &ilbs->ilbs_nat_src[idx];

		/* Drain the bucket, releasing each entry and its arena. */
		for (;;) {
			entry = list_remove_head(&bucket->nsh_head);
			if (entry == NULL)
				break;
			vmem_destroy(entry->nse_port_arena);
			kmem_free(entry, sizeof (ilb_nat_src_entry_t));
		}
		mutex_destroy(&bucket->nsh_lock);
	}

	kmem_free(ilbs->ilbs_nat_src, sizeof (ilb_nat_src_hash_t) *
	    ilbs->ilbs_nat_src_hash_size);
	ilbs->ilbs_nat_src = NULL;
}
/*
 * Size of the buffer holding a port arena name of the form "ilb_ns_%u":
 * 7 characters of prefix, at most 10 decimal digits and the terminating NUL.
 */
#define	ARENA_NAMESZ	18
/* First port number managed by a NAT source port arena. */
#define	NAT_PORT_START	4096
/*
 * Number of ports managed by a NAT source port arena.  The expansion is
 * parenthesized so that it evaluates correctly inside larger expressions.
 */
#define	NAT_PORT_SIZE	(65535 - NAT_PORT_START)
/*
 * Look up, or create, the NAT source entry for a (NAT source address,
 * server address, port) combination.
 *
 * The bucket is selected by hashing the last 32-bit word of both addresses.
 * An existing entry matches when both addresses are equal and either the
 * ports are equal or one of the two ports is 0.  A matching entry has its
 * reference count incremented and is returned.
 *
 * Otherwise a new entry with a reference count of 1 is allocated without
 * sleeping, given its own port arena covering NAT_PORT_SIZE ports starting
 * at NAT_PORT_START, and appended to the bucket.
 *
 * Returns the entry, or NULL if the entry or its arena cannot be allocated.
 * The bucket lock is held for the whole operation and is released on every
 * return path.
 */
static ilb_nat_src_entry_t *
ilb_find_nat_src(ilb_stack_t *ilbs, const in6_addr_t *nat_src,
    const in6_addr_t *serv_addr, in_port_t port)
{
	ilb_nat_src_entry_t *tmp;
	uint32_t idx;
	char arena_name[ARENA_NAMESZ];
	list_t *head;
	kmutex_t *lock;

	ILB_NAT_SRC_HASH(idx, &nat_src->s6_addr32[3], &serv_addr->s6_addr32[3],
	    ilbs->ilbs_nat_src_hash_size);
	lock = &ilbs->ilbs_nat_src[idx].nsh_lock;
	head = &ilbs->ilbs_nat_src[idx].nsh_head;

	mutex_enter(lock);
	for (tmp = list_head(head); tmp != NULL; tmp = list_next(head, tmp)) {
		if (IN6_ARE_ADDR_EQUAL(&tmp->nse_src_addr, nat_src) &&
		    IN6_ARE_ADDR_EQUAL(&tmp->nse_serv_addr, serv_addr) &&
		    (port == tmp->nse_port || port == 0 ||
		    tmp->nse_port == 0)) {
			break;
		}
	}
	if (tmp != NULL) {
		/* Existing entry: just take another reference. */
		tmp->nse_refcnt++;
		goto out;
	}

	tmp = kmem_alloc(sizeof (ilb_nat_src_entry_t), KM_NOSLEEP);
	if (tmp == NULL)
		goto out;

	tmp->nse_src_addr = *nat_src;
	tmp->nse_serv_addr = *serv_addr;
	tmp->nse_port = port;
	tmp->nse_nsh_lock = lock;
	tmp->nse_refcnt = 1;

	/*
	 * NOTE(review): the arena is created with VM_SLEEP while the bucket
	 * mutex is held, although the entry itself is allocated with
	 * KM_NOSLEEP -- confirm that sleeping here is intended.
	 */
	(void) snprintf(arena_name, ARENA_NAMESZ, "ilb_ns_%u",
	    atomic_inc_32_nv(&ilb_nat_src_instance));
	tmp->nse_port_arena = vmem_create(arena_name,
	    (void *)NAT_PORT_START, NAT_PORT_SIZE, 1, NULL, NULL, NULL, 1,
	    VM_SLEEP | VMC_IDENTIFIER);
	if (tmp->nse_port_arena == NULL) {
		/* Undo the allocation; the lock is dropped at "out". */
		kmem_free(tmp, sizeof (*tmp));
		tmp = NULL;
		goto out;
	}

	list_insert_tail(head, tmp);
out:
	mutex_exit(lock);
	return (tmp);
}
/*
 * Build a NAT source object covering up to num consecutive addresses
 * beginning at *start, for the given server address and port.
 *
 * At most ILB_MAX_NAT_SRC addresses are used.  For every address an entry
 * is obtained through ilb_find_nat_src() and recorded in src_list[].
 *
 * On success 0 is returned and *nat_src points to the new object.  If any
 * allocation fails, the partially built object is destroyed, *nat_src is
 * set to NULL and ENOMEM is returned.
 */
int
ilb_create_nat_src(ilb_stack_t *ilbs, ilb_nat_src_t **nat_src,
    const in6_addr_t *srv_addr, in_port_t port, const in6_addr_t *start,
    int num)
{
	ilb_nat_src_t *src;
	in6_addr_t cur_addr;
	int i;

	src = kmem_zalloc(sizeof (ilb_nat_src_t), KM_NOSLEEP);
	if (src == NULL) {
		*nat_src = NULL;
		return (ENOMEM);
	}

	cur_addr = *start;
	i = 0;
	while (i < num && i < ILB_MAX_NAT_SRC) {
		ilb_nat_src_entry_t *entry;

		entry = ilb_find_nat_src(ilbs, &cur_addr, srv_addr, port);
		src->src_list[i] = entry;
		if (entry == NULL) {
			ilb_destroy_nat_src(&src);
			*nat_src = NULL;
			return (ENOMEM);
		}
		incr_addr(&cur_addr);
		/*
		 * num_src is only bumped after the slot is filled, so that
		 * ilb_destroy_nat_src() releases exactly the entries that
		 * were obtained when cleaning up after a failure.
		 */
		src->num_src++;
		i++;
	}

	*nat_src = src;
	return (0);
}
/*
 * Timer handler that reclaims idle NAT source entries.
 *
 * Every hash bucket is scanned under its lock.  An entry is removed and
 * freed, together with its port arena, when its reference count is zero
 * and no port is currently allocated from its arena.
 *
 * Afterwards the timer re-arms itself, unless ilbs_nat_src_tid has been
 * set to 0 (which ilb_nat_src_fini() does before cancelling the timer).
 */
void
ilb_nat_src_timer(void *arg)
{
	ilb_stack_t *ilbs = (ilb_stack_t *)arg;
	ilb_nat_src_entry_t *entry, *next;
	list_t *head;
	int idx;

	for (idx = 0; idx < ilbs->ilbs_nat_src_hash_size; idx++) {
		mutex_enter(&ilbs->ilbs_nat_src[idx].nsh_lock);
		head = &ilbs->ilbs_nat_src[idx].nsh_head;

		for (entry = list_head(head); entry != NULL; entry = next) {
			/* Fetch the successor before entry can go away. */
			next = list_next(head, entry);

			/* Still referenced, or ports still in use: keep. */
			if (entry->nse_refcnt != 0 ||
			    vmem_size(entry->nse_port_arena,
			    VMEM_ALLOC) != 0) {
				continue;
			}

			list_remove(head, entry);
			vmem_destroy(entry->nse_port_arena);
			kmem_free(entry, sizeof (ilb_nat_src_entry_t));
		}

		mutex_exit(&ilbs->ilbs_nat_src[idx].nsh_lock);
	}

	mutex_enter(&ilbs->ilbs_nat_src_lock);
	if (ilbs->ilbs_nat_src_tid != 0) {
		ilbs->ilbs_nat_src_tid = timeout(ilb_nat_src_timer, ilbs,
		    SEC_TO_TICK(ILB_NAT_SRC_TIMEOUT +
		    gethrtime() % ILB_NAT_SRC_TIMEOUT_JITTER));
	}
	mutex_exit(&ilbs->ilbs_nat_src_lock);
}
/*
 * Release a NAT source object.
 *
 * The reference held on each of the num_src entries is dropped under that
 * entry's bucket lock; the entries themselves are not freed here.  The
 * object is then freed and *nat_src is set to NULL.  Nothing happens if
 * *nat_src is already NULL.
 */
void
ilb_destroy_nat_src(ilb_nat_src_t **nat_src)
{
	ilb_nat_src_t *src = *nat_src;
	ilb_nat_src_entry_t *entry;
	kmutex_t *lock;
	int idx, count;

	if (src == NULL)
		return;

	count = src->num_src;
	idx = 0;
	while (idx < count) {
		entry = src->src_list[idx];
		lock = entry->nse_nsh_lock;

		mutex_enter(lock);
		entry->nse_refcnt -= 1;
		mutex_exit(lock);
		idx++;
	}

	kmem_free(src, sizeof (ilb_nat_src_t));
	*nat_src = NULL;
}
/*
 * Allocate a NAT source address and port from a NAT source object.
 *
 * If nat_src_idx is NULL, the search starts at the entry after the one
 * used last (src->cur is advanced) and every entry is tried in turn until
 * a port can be allocated.  If nat_src_idx is given, only the entry at
 * that index is tried.
 *
 * On success the chosen entry is returned, *addr is set to its NAT source
 * address and *port to the allocated port in network byte order.  NULL is
 * returned when no port is available, when the object holds no entries,
 * or when the supplied index is out of range.
 */
ilb_nat_src_entry_t *
ilb_alloc_nat_addr(ilb_nat_src_t *src, in6_addr_t *addr, in_port_t *port,
    uint16_t *nat_src_idx)
{
	int i, try, size;
	in_port_t p;

	size = src->num_src;
	/* Nothing to choose from; also avoids a modulo by zero below. */
	if (size <= 0)
		return (NULL);

	if (nat_src_idx == NULL) {
		i = (++src->cur) % size;
	} else {
		i = *nat_src_idx;
		/* Never index past the entries that are actually in use. */
		if (i >= size)
			return (NULL);
	}

	for (try = 0; try < size; try++) {
		/* The arena hands out port numbers; 0 means failure. */
		p = (in_port_t)(uintptr_t)vmem_alloc(
		    src->src_list[i]->nse_port_arena, 1, VM_NOSLEEP);
		if (p != 0)
			break;
		/* A caller-chosen entry is not substituted by another one. */
		if (nat_src_idx != NULL)
			return (NULL);
		i = (i + 1) % size;
	}
	if (try == size)
		return (NULL);

	*addr = src->src_list[i]->nse_src_addr;
	*port = htons(p);
	return (src->src_list[i]);
}
/*
 * Apply a precomputed adjustment to a 16-bit one's complement checksum.
 *
 * The stored checksum is complemented, added to adj_sum, the carries above
 * bit 15 are folded back in until none remain, and the complement of the
 * result is written back to *chksum.
 */
static void
adj_cksum(uint16_t *chksum, uint32_t adj_sum)
{
	uint32_t acc = adj_sum + (uint16_t)~(*chksum);

	/* End-around carry: fold the upper half into the lower half. */
	while (acc > 0xffff)
		acc = (acc >> 16) + (acc & 0xffff);

	*chksum = (uint16_t)~acc;
}
/*
 * Rewrite a TCP or UDP packet for full NAT.
 *
 * When c2s is set, the source becomes info->nat_src/nat_sport and the
 * destination becomes info->nat_dst/nat_dport.  Otherwise the source
 * becomes info->vip/dport and the destination becomes info->src/sport.
 *
 * l3 selects the IP header type behind iph (IPPROTO_IP or IPPROTO_IPV6)
 * and l4 the transport header type behind tph (IPPROTO_TCP or
 * IPPROTO_UDP).  The transport checksum is adjusted by adj_tp_sum; for
 * IPv4 the IP header checksum is additionally adjusted by adj_ip_sum.
 * Any other l3 or l4 value trips an ASSERT and leaves the packet untouched.
 */
void
ilb_full_nat(int l3, void *iph, int l4, void *tph, ilb_nat_info_t *info,
    uint32_t adj_ip_sum, uint32_t adj_tp_sum, boolean_t c2s)
{
	in_port_t *sportp, *dportp;
	in_port_t new_sport, new_dport;
	in6_addr_t *new_src, *new_dst;
	uint16_t *cksump;

	/* Locate the port and checksum fields of the transport header. */
	if (l4 == IPPROTO_TCP) {
		tcpha_t *tcpha = tph;

		sportp = &tcpha->tha_lport;
		dportp = &tcpha->tha_fport;
		cksump = &tcpha->tha_sum;
	} else if (l4 == IPPROTO_UDP) {
		udpha_t *udpha = tph;

		sportp = &udpha->uha_src_port;
		dportp = &udpha->uha_dst_port;
		cksump = &udpha->uha_checksum;
	} else {
		ASSERT(0);
		return;
	}

	/* Pick the replacement addresses and ports for this direction. */
	if (c2s) {
		new_src = &info->nat_src;
		new_dst = &info->nat_dst;
		new_sport = info->nat_sport;
		new_dport = info->nat_dport;
	} else {
		new_src = &info->vip;
		new_dst = &info->src;
		new_sport = info->dport;
		new_dport = info->sport;
	}

	if (l3 == IPPROTO_IP) {
		ipha_t *ipha = iph;

		IN6_V4MAPPED_TO_IPADDR(new_src, ipha->ipha_src);
		IN6_V4MAPPED_TO_IPADDR(new_dst, ipha->ipha_dst);
		*sportp = new_sport;
		*dportp = new_dport;
		adj_cksum(&ipha->ipha_hdr_checksum, adj_ip_sum);
		adj_cksum(cksump, adj_tp_sum);
	} else if (l3 == IPPROTO_IPV6) {
		ip6_t *ip6h = iph;

		ip6h->ip6_src = *new_src;
		ip6h->ip6_dst = *new_dst;
		*sportp = new_sport;
		*dportp = new_dport;
		/* Only the transport checksum is adjusted for IPv6. */
		adj_cksum(cksump, adj_tp_sum);
	} else {
		ASSERT(0);
	}
}
/*
 * Rewrite a TCP or UDP packet for half NAT.
 *
 * When c2s is set, only the destination is changed, to
 * info->nat_dst/nat_dport.  Otherwise only the source is changed, to
 * info->vip/dport.
 *
 * l3 selects the IP header type behind iph (IPPROTO_IP or IPPROTO_IPV6)
 * and l4 the transport header type behind tph (IPPROTO_TCP or
 * IPPROTO_UDP).  The transport checksum is adjusted by adj_tp_sum; for
 * IPv4 the IP header checksum is additionally adjusted by adj_ip_sum.
 * Any other l3 or l4 value trips an ASSERT and leaves the packet untouched.
 */
void
ilb_half_nat(int l3, void *iph, int l4, void *tph, ilb_nat_info_t *info,
    uint32_t adj_ip_sum, uint32_t adj_tp_sum, boolean_t c2s)
{
	in_port_t *portp;
	in_port_t new_port;
	in6_addr_t *new_addr;
	uint16_t *cksump;

	/* Locate the port to rewrite and the transport checksum. */
	if (l4 == IPPROTO_TCP) {
		tcpha_t *tcpha = tph;

		portp = c2s ? &tcpha->tha_fport : &tcpha->tha_lport;
		cksump = &tcpha->tha_sum;
	} else if (l4 == IPPROTO_UDP) {
		udpha_t *udpha = tph;

		portp = c2s ? &udpha->uha_dst_port : &udpha->uha_src_port;
		cksump = &udpha->uha_checksum;
	} else {
		ASSERT(0);
		return;
	}

	/* Pick the replacement address and port for this direction. */
	if (c2s) {
		new_addr = &info->nat_dst;
		new_port = info->nat_dport;
	} else {
		new_addr = &info->vip;
		new_port = info->dport;
	}

	if (l3 == IPPROTO_IP) {
		ipha_t *ipha = iph;

		if (c2s) {
			IN6_V4MAPPED_TO_IPADDR(new_addr, ipha->ipha_dst);
		} else {
			IN6_V4MAPPED_TO_IPADDR(new_addr, ipha->ipha_src);
		}
		*portp = new_port;
		adj_cksum(&ipha->ipha_hdr_checksum, adj_ip_sum);
		adj_cksum(cksump, adj_tp_sum);
	} else if (l3 == IPPROTO_IPV6) {
		ip6_t *ip6h = iph;

		if (c2s)
			ip6h->ip6_dst = *new_addr;
		else
			ip6h->ip6_src = *new_addr;
		*portp = new_port;
		/* Only the transport checksum is adjusted for IPv6. */
		adj_cksum(cksump, adj_tp_sum);
	} else {
		ASSERT(0);
	}
}
/*
 * Compute a partial checksum over the address part of an IPv6 header.
 *
 * The sixteen 16-bit words starting at ip6_src are added together with
 * htons(nxt_hdr).  The result is folded once (upper 16 bits added to the
 * lower 16 bits) and returned; it is not complemented.
 *
 * NOTE(review): reading 16 words from &ip6_src presumably relies on
 * ip6_dst directly following ip6_src in ip6_t -- confirm against the
 * header definition.
 */
uint32_t
ilb_pseudo_sum_v6(ip6_t *ip6h, uint8_t nxt_hdr)
{
	uint16_t *words = (uint16_t *)&ip6h->ip6_src;
	uint32_t total = htons(nxt_hdr);
	int k;

	for (k = 0; k < 16; k++)
		total += words[k];

	return ((total >> 16) + (total & 0xffff));
}
/*
 * Rewrite an ICMPv4 message together with the packet headers it carries.
 *
 * out_iph is rewritten as the header of the ICMP packet itself; in_iph,
 * sport and dport are presumably the IP header and transport ports quoted
 * inside the ICMP payload (verify against the caller).
 *
 * The destination of out_iph and the source of in_iph become
 * info->nat_dst, and *sport becomes info->nat_dport.  With full_nat set,
 * the source of out_iph and the destination of in_iph additionally become
 * info->nat_src, and *dport becomes info->nat_sport.
 *
 * The header checksum of out_iph is adjusted by sum, and the ICMP checksum
 * is recomputed from scratch over mp, starting after the out_iph header.
 */
void
ilb_nat_icmpv4(mblk_t *mp, ipha_t *out_iph, icmph_t *icmph, ipha_t *in_iph,
    in_port_t *sport, in_port_t *dport, ilb_nat_info_t *info, uint32_t sum,
    boolean_t full_nat)
{
	/* Header of the ICMP packet itself. */
	if (full_nat)
		IN6_V4MAPPED_TO_IPADDR(&info->nat_src, out_iph->ipha_src);
	IN6_V4MAPPED_TO_IPADDR(&info->nat_dst, out_iph->ipha_dst);
	adj_cksum(&out_iph->ipha_hdr_checksum, sum);

	/* The in_iph header and its ports. */
	IN6_V4MAPPED_TO_IPADDR(&info->nat_dst, in_iph->ipha_src);
	*sport = info->nat_dport;
	if (full_nat) {
		IN6_V4MAPPED_TO_IPADDR(&info->nat_src, in_iph->ipha_dst);
		*dport = info->nat_sport;
	}

	/* The checksum field must be zero while the new sum is computed. */
	icmph->icmph_checksum = 0;
	icmph->icmph_checksum = IP_CSUM(mp, IPH_HDR_LENGTH(out_iph), 0);
}
/*
 * Rewrite an ICMPv6 message together with the packet headers it carries.
 *
 * out_ip6h is rewritten as the header of the ICMPv6 packet itself;
 * in_ip6h, sport and dport are presumably the IPv6 header and transport
 * ports quoted inside the ICMPv6 payload (verify against the caller).
 *
 * The destination of out_ip6h and the source of in_ip6h become
 * info->nat_dst, and *sport becomes info->nat_dport.  With full_nat set,
 * the source of out_ip6h and the destination of in_ip6h additionally
 * become info->nat_src, and *dport becomes info->nat_sport.
 *
 * The ICMPv6 checksum is then recomputed over mp from the ICMPv6 header
 * onwards, seeded with ilb_pseudo_sum_v6() of the rewritten out_ip6h.
 */
void
ilb_nat_icmpv6(mblk_t *mp, ip6_t *out_ip6h, icmp6_t *icmp6h, ip6_t *in_ip6h,
    in_port_t *sport, in_port_t *dport, ilb_nat_info_t *info,
    boolean_t full_nat)
{
	int icmp_off;

	/* Header of the ICMPv6 packet itself. */
	if (full_nat)
		out_ip6h->ip6_src = info->nat_src;
	out_ip6h->ip6_dst = info->nat_dst;

	/* The in_ip6h header and its ports. */
	in_ip6h->ip6_src = info->nat_dst;
	*sport = info->nat_dport;
	if (full_nat) {
		in_ip6h->ip6_dst = info->nat_src;
		*dport = info->nat_sport;
	}

	/*
	 * The checksum field is preloaded with the payload length from the
	 * IPv6 header before the sum over the message is taken.
	 */
	icmp6h->icmp6_cksum = out_ip6h->ip6_plen;
	icmp_off = (char *)icmp6h - (char *)out_ip6h;
	icmp6h->icmp6_cksum = IP_CSUM(mp, icmp_off,
	    ilb_pseudo_sum_v6(out_ip6h, IPPROTO_ICMPV6));
}