#include <sys/types.h>
#include <sys/param.h>
#include <sys/byteorder.h>
#include <sys/systm.h>
#include <sys/sysmacros.h>
#include <netinet/in.h>
#include <netinet/in_systm.h>
#include <netinet/ip.h>
#include <netinet/tcp.h>
#include <net/ppp_defs.h>
#include <net/vjcompress.h>
/*
 * INCR() bumps one of the statistics counters in comp->stats.  It refers
 * to a variable literally named "comp", so it can only be used where such
 * a variable is in scope.  Defining VJ_NO_STATS compiles the counters out.
 */
#ifndef VJ_NO_STATS
#define INCR(counter) (++comp->stats.counter)
#else
#define INCR(counter) ((void)0)
#endif

/*
 * bcmp()/bcopy() wrappers that cast both pointers to char * and the length
 * to unsigned int.  Note bcopy() takes (source, destination, length).
 */
#define BCMP(p1, p2, n) bcmp((char *)(p1), (char *)(p2), (unsigned int)(n))
#undef BCOPY
#define BCOPY(p1, p2, n) bcopy((char *)(p1), (char *)(p2), (unsigned int)(n))

/*
 * Byte-level header field accessors usable on a pointer of any type.  The
 * argument is parenthesized so an expression such as (p + n) is evaluated
 * in its own pointer type before the cast to uchar_t *.
 *   getip_hl  - IP header length in 32-bit words (low nibble of byte 0)
 *   getth_off - TCP data offset in 32-bit words (high nibble of byte 12)
 *   getip_p   - IP protocol byte; setip_p stores v into it
 */
#define getip_hl(bp) (((uchar_t *)(bp))[0] & 0x0F)
#define getth_off(bp) (((uchar_t *)(bp))[12] >> 4)
#define getip_p(bp) (((uchar_t *)(bp))[offsetof(struct ip, ip_p)])
#define setip_p(bp, v) (((uchar_t *)(bp))[offsetof(struct ip, ip_p)] = (v))
/*
 * Reset a VJ compression context.
 *
 * comp      - context to initialize; it is zeroed in its entirety first.
 * max_state - highest transmit slot id to use, or -1 for MAX_STATES - 1.
 *             Any other value outside 0..MAX_STATES - 1 is clamped to
 *             MAX_STATES - 1 so the loop below cannot index past the end
 *             of comp->tstate (assumed to hold MAX_STATES entries -- TODO
 *             confirm against the header).
 *
 * Transmit slots 0..max_state are linked into a circular list (slot i
 * points at slot i - 1, slot 0 points back at slot max_state) and
 * comp->last_cs starts at slot 0.  last_recv and last_xmit start at 255,
 * and VJF_TOSS is set so that compressed packets which do not name a slot
 * are dropped until the receive side is synchronized.
 */
void
vj_compress_init(struct vjcompress *comp, int max_state)
{
	register uint_t i;
	register struct cstate *tstate = comp->tstate;

	/* -1 means "use every slot"; out-of-range values are clamped too. */
	if (max_state < 0 || max_state > MAX_STATES - 1) {
		max_state = MAX_STATES - 1;
	}
	bzero((char *)comp, sizeof (*comp));
	for (i = max_state; i > 0; --i) {
		tstate[i].cs_id = i & 0xff;
		tstate[i].cs_next = &tstate[i - 1];
	}
	/* Close the ring: slot 0 links back to the highest slot. */
	tstate[0].cs_next = &tstate[max_state];
	tstate[0].cs_id = 0;
	comp->last_cs = &tstate[0];
	comp->last_recv = 255;
	comp->last_xmit = 255;
	comp->flags = VJF_TOSS;
}
/*
 * Delta encoders used by vj_compress_tcp().  Both append to, and advance,
 * a byte pointer named "cp" that must be in scope.  A value in 1..255 is
 * stored as that single byte; otherwise a zero byte is written followed by
 * the low 16 bits, most significant byte first.  ENCODE must never be given
 * 0: a lone zero byte would be read back as the start of the long form.
 * ENCODEZ is for fields that may legitimately be 0 and uses the long form
 * for it.  The argument is evaluated several times, so it must be free of
 * side effects.  The do { } while (0) wrapper lets each macro be used as a
 * single statement, including in an unbraced if/else.
 */
#define ENCODE(n) do { \
	if ((ushort_t)(n) >= 256) { \
		*cp++ = 0; \
		cp[1] = (n) & 0xff; \
		cp[0] = ((n) >> 8) & 0xff; \
		cp += 2; \
	} else { \
		*cp++ = (n) & 0xff; \
	} \
} while (0)
#define ENCODEZ(n) do { \
	if ((ushort_t)(n) >= 256 || (ushort_t)(n) == 0) { \
		*cp++ = 0; \
		cp[1] = (n) & 0xff; \
		cp[0] = ((n) >> 8) & 0xff; \
		cp += 2; \
	} else { \
		*cp++ = (n) & 0xff; \
	} \
} while (0)

/*
 * Delta decoders used by vj_uncompress_tcp(); the inverse of the encoders
 * above.  Each reads one encoded value at "cp" (which must be in scope) and
 * advances it by 1 or 3 bytes.  f is a field kept in network byte order:
 *   DECODEL - adds the delta to a 32-bit field
 *   DECODES - adds the delta to a 16-bit field
 *   DECODEU - replaces a 16-bit field with the decoded value
 * The caller is responsible for ensuring the encoded bytes are in bounds.
 */
#define DECODEL(f) do { \
	if (*cp == 0) { \
		uint32_t dtmp = ntohl(f) + ((cp[1] << 8) | cp[2]); \
		(f) = htonl(dtmp); \
		cp += 3; \
	} else { \
		uint32_t dtmp = ntohl(f) + (uint32_t)*cp++; \
		(f) = htonl(dtmp); \
	} \
} while (0)
#define DECODES(f) do { \
	if (*cp == 0) { \
		ushort_t dtmp = ntohs(f) + ((cp[1] << 8) | cp[2]); \
		(f) = htons(dtmp); \
		cp += 3; \
	} else { \
		ushort_t dtmp = ntohs(f) + (uint32_t)*cp++; \
		(f) = htons(dtmp); \
	} \
} while (0)
#define DECODEU(f) do { \
	if (*cp == 0) { \
		(f) = htons((cp[1] << 8) | cp[2]); \
		cp += 3; \
	} else { \
		(f) = htons((uint32_t)*cp++); \
	} \
} while (0)
/*
 * Try to compress the TCP/IP header of the packet at ip.
 *
 * ip           - start of the packet's IP header.  On TYPE_COMPRESSED_TCP
 *                the compressed header is written into this buffer,
 *                ending where the TCP header ended; on
 *                TYPE_UNCOMPRESSED_TCP the IP protocol field is replaced
 *                by the connection slot id.
 * mlen         - number of valid bytes at ip (presumably the contiguous
 *                header bytes available -- TODO confirm caller contract).
 * comp         - compressor state holding saved per-connection headers.
 * compress_cid - when nonzero, the slot id is omitted if it matches the
 *                last one transmitted.
 * vjhdrp       - on TYPE_COMPRESSED_TCP, set to the start of the
 *                compressed header.
 *
 * Returns TYPE_IP if the packet cannot be handled (fragment, too short,
 * malformed header lengths, or SYN/FIN/RST set or ACK clear),
 * TYPE_UNCOMPRESSED_TCP if it must be sent in full so the receiver can
 * (re)load the slot, or TYPE_COMPRESSED_TCP.
 */
uint_t
vj_compress_tcp(register struct ip *ip, uint_t mlen, struct vjcompress *comp,
    int compress_cid, uchar_t **vjhdrp)
{
	register struct cstate *cs = comp->last_cs->cs_next;
	register uint_t hlen = getip_hl(ip);
	register struct tcphdr *oth;
	register struct tcphdr *th;
	register uint_t deltaS;
	register uint_t deltaA;
	register uint_t changes = 0;
	uchar_t new_seq[16];
	register uchar_t *cp = new_seq;
	register uint_t thlen;

	/* Fragments and packets shorter than minimal IP + TCP headers. */
	if ((ip->ip_off & htons(0x3fff)) || mlen < 40) {
		return (TYPE_IP);
	}
	/*
	 * Reject a malformed IP header length and make sure the fixed 20-byte
	 * TCP header lies inside mlen before th is dereferenced: with IP
	 * options the TCP header can start beyond the 40 bytes checked above.
	 */
	if (hlen < 5 || (hlen << 2) + 20 > mlen) {
		return (TYPE_IP);
	}
	th = (struct tcphdr *)&((int *)ip)[hlen];
	if ((th->th_flags & (TH_SYN|TH_FIN|TH_RST|TH_ACK)) != TH_ACK) {
		return (TYPE_IP);
	}
	/* A TCP data offset below 5 words is malformed. */
	if (getth_off(th) < 5) {
		return (TYPE_IP);
	}
	/* thlen: combined IP + TCP header length in bytes. */
	thlen = (hlen + getth_off(th)) << 2;
	if (thlen > mlen) {
		return (TYPE_IP);
	}
	INCR(vjs_packets);

	/*
	 * Locate the saved state for this connection (same addresses and the
	 * first 32-bit word of the TCP header, i.e. the port pair).  Slots
	 * form a ring whose head is comp->last_cs->cs_next.
	 */
	if (ip->ip_src.s_addr != cs->cs_ip.ip_src.s_addr ||
	    ip->ip_dst.s_addr != cs->cs_ip.ip_dst.s_addr ||
	    *(int *)th != ((int *)&cs->cs_ip)[getip_hl(&cs->cs_ip)]) {
		register struct cstate *lcs;
		register struct cstate *lastcs = comp->last_cs;

		do {
			lcs = cs;
			cs = cs->cs_next;
			INCR(vjs_searches);
			if (ip->ip_src.s_addr == cs->cs_ip.ip_src.s_addr &&
			    ip->ip_dst.s_addr == cs->cs_ip.ip_dst.s_addr &&
			    *(int *)th == ((int *)
			    &cs->cs_ip)[getip_hl(&cs->cs_ip)]) {
				goto found;
			}
		} while (cs != lastcs);

		/*
		 * Miss: cs is the old tail; making lcs the tail turns cs into
		 * the head, and it is reused for this connection.
		 */
		INCR(vjs_misses);
		comp->last_cs = lcs;
		goto uncompressed;

found:
		/* Hit: move cs to the head of the ring. */
		if (cs == lastcs) {
			comp->last_cs = lcs;
		} else {
			lcs->cs_next = cs->cs_next;
			cs->cs_next = lastcs->cs_next;
			lastcs->cs_next = cs;
		}
	}

	oth = (struct tcphdr *)&((int *)&cs->cs_ip)[hlen];
	deltaS = hlen;
	/*
	 * Anything that is not delta-encoded must be unchanged: 16-bit
	 * words 0, 3 and 4 of the IP header (version/hl/TOS, fragment
	 * field, TTL/protocol), the TCP data offset, and any IP or TCP
	 * options.  Otherwise send the packet uncompressed.
	 */
	if (((ushort_t *)ip)[0] != ((ushort_t *)&cs->cs_ip)[0] ||
	    ((ushort_t *)ip)[3] != ((ushort_t *)&cs->cs_ip)[3] ||
	    ((ushort_t *)ip)[4] != ((ushort_t *)&cs->cs_ip)[4] ||
	    getth_off(th) != getth_off(oth) ||
	    (deltaS > 5 &&
	    BCMP(ip + 1, &cs->cs_ip + 1, (deltaS - 5) << 2)) ||
	    (getth_off(th) > 5 &&
	    BCMP(th + 1, oth + 1, (getth_off(th) - 5) << 2))) {
		goto uncompressed;
	}

	/* Encode the changed TCP fields into new_seq. */
	if (th->th_flags & TH_URG) {
		deltaS = ntohs(th->th_urp);
		ENCODEZ(deltaS);
		changes |= NEW_U;
	} else if (th->th_urp != oth->th_urp) {
		goto uncompressed;
	}
	if ((deltaS = (ushort_t)(ntohs(th->th_win) - ntohs(oth->th_win))) > 0) {
		ENCODE(deltaS);
		changes |= NEW_W;
	}
	if ((deltaA = ntohl(th->th_ack) - ntohl(oth->th_ack)) > 0) {
		if (deltaA > 0xffff) {
			goto uncompressed;
		}
		ENCODE(deltaA);
		changes |= NEW_A;
	}
	if ((deltaS = ntohl(th->th_seq) - ntohl(oth->th_seq)) > 0) {
		if (deltaS > 0xffff) {
			goto uncompressed;
		}
		ENCODE(deltaS);
		changes |= NEW_S;
	}

	switch (changes) {
	case 0:
		/*
		 * Nothing changed.  Compress only if the length changed and
		 * the previous packet carried no payload (its ip_len equals
		 * the header length); otherwise this is presumably a
		 * retransmission or duplicate ACK, which is sent uncompressed.
		 */
		if (ip->ip_len != cs->cs_ip.ip_len &&
		    ntohs(cs->cs_ip.ip_len) == thlen) {
			break;
		}
		/* FALLTHROUGH */
	case SPECIAL_I:
	case SPECIAL_D:
		/* A genuine change set equal to a special code can't be sent. */
		goto uncompressed;
	case NEW_S|NEW_A:
		/* seq and ack both advanced by the previous payload length. */
		if (deltaS == deltaA &&
		    deltaS == ntohs(cs->cs_ip.ip_len) - thlen) {
			changes = SPECIAL_I;
			cp = new_seq;
		}
		break;
	case NEW_S:
		/* seq alone advanced by the previous payload length. */
		if (deltaS == ntohs(cs->cs_ip.ip_len) - thlen) {
			changes = SPECIAL_D;
			cp = new_seq;
		}
		break;
	}

	/*
	 * IP id: an increment of exactly 1 is implicit.  The difference is
	 * taken modulo 2^16 so that a wrap from 0xffff to 0 also counts as 1.
	 */
	deltaS = (ushort_t)(ntohs(ip->ip_id) - ntohs(cs->cs_ip.ip_id));
	if (deltaS != 1) {
		ENCODEZ(deltaS);
		changes |= NEW_I;
	}
	if (th->th_flags & TH_PUSH) {
		changes |= TCP_PUSH_BIT;
	}
	/* The TCP checksum is always sent verbatim. */
	deltaA = ntohs(th->th_sum);
	BCOPY(ip, &cs->cs_ip, thlen);

	/*
	 * Write the compressed header into the packet so that it ends where
	 * the original TCP header ended: change mask, [slot id,] checksum,
	 * then the encoded deltas.
	 */
	deltaS = cp - new_seq;
	cp = (uchar_t *)ip;
	if (compress_cid == 0 || comp->last_xmit != cs->cs_id) {
		comp->last_xmit = cs->cs_id;
		thlen -= deltaS + 4;
		*vjhdrp = (cp += thlen);
		*cp++ = changes | NEW_C;
		*cp++ = cs->cs_id;
	} else {
		thlen -= deltaS + 3;
		*vjhdrp = (cp += thlen);
		*cp++ = changes & 0xff;
	}
	*cp++ = (deltaA >> 8) & 0xff;
	*cp++ = deltaA & 0xff;
	BCOPY(new_seq, cp, deltaS);
	INCR(vjs_compressed);
	return (TYPE_COMPRESSED_TCP);

uncompressed:
	/* Save the full header and tag the packet with its slot id. */
	BCOPY(ip, &cs->cs_ip, thlen);
	ip->ip_p = cs->cs_id;
	comp->last_xmit = cs->cs_id;
	return (TYPE_UNCOMPRESSED_TCP);
}
/*
 * Mark the receive side as out of sync (presumably called by the framing
 * layer after a damaged or lost frame -- confirm with callers).  The error
 * counter is bumped and VJF_TOSS is set, so compressed packets that do not
 * carry an explicit slot id are dropped until the receiver resynchronizes.
 */
void
vj_uncompress_err(struct vjcompress *comp)
{
	INCR(vjs_errorin);
	comp->flags = comp->flags | VJF_TOSS;
}
/*
 * Handle a received TYPE_UNCOMPRESSED_TCP packet, whose IP protocol byte
 * carries the connection slot id instead of IPPROTO_TCP.
 *
 * buf    - start of the packet's IP header; its protocol byte is restored
 *          to IPPROTO_TCP in place on success.
 * buflen - number of valid bytes at buf.
 * comp   - decompressor state; the IP + TCP headers are saved in the
 *          named slot of comp->rstate and that slot becomes last_recv.
 *
 * Returns 1 on success.  If the slot id is out of range, or the headers are
 * truncated, malformed (IP or TCP header shorter than 20 bytes) or longer
 * than MAX_HDR, VJF_TOSS is set, the error is counted, and 0 is returned.
 */
int
vj_uncompress_uncomp(uchar_t *buf, int buflen, struct vjcompress *comp)
{
	register uint_t iphlen;
	register uint_t thlen;
	register uint_t hlen;
	register struct cstate *cs;

	/* No header field may be read until a minimal IP header is present. */
	if (buflen < (int)sizeof (struct ip)) {
		goto bad;
	}
	iphlen = getip_hl(buf) << 2;
	if (getip_p(buf) >= MAX_STATES ||
	    iphlen < sizeof (struct ip) ||
	    iphlen + sizeof (struct tcphdr) > (uint_t)buflen) {
		goto bad;
	}
	thlen = getth_off(buf + iphlen) << 2;
	hlen = iphlen + thlen;
	if (thlen < sizeof (struct tcphdr) || hlen > (uint_t)buflen ||
	    hlen > MAX_HDR) {
		goto bad;
	}

	cs = &comp->rstate[comp->last_recv = getip_p(buf)];
	comp->flags &= ~VJF_TOSS;
	setip_p(buf, IPPROTO_TCP);
	BCOPY(buf, &cs->cs_ip, hlen);
	cs->cs_hlen = hlen & 0xff;
	INCR(vjs_uncompressedin);
	return (1);

bad:
	comp->flags |= VJF_TOSS;
	INCR(vjs_errorin);
	return (0);
}
int
vj_uncompress_tcp(uchar_t *buf, int buflen, int total_len,
struct vjcompress *comp, uchar_t **hdrp, uint_t *hlenp)
{
register uchar_t *cp;
register uint_t hlen;
register uint_t changes;
register struct tcphdr *th;
register struct cstate *cs;
register ushort_t *bp;
register uint_t vjlen;
register uint32_t tmp;
INCR(vjs_compressedin);
cp = buf;
changes = *cp++;
if (changes & NEW_C) {
if (*cp >= MAX_STATES) {
goto bad;
}
comp->flags &= ~VJF_TOSS;
comp->last_recv = *cp++;
} else {
if (comp->flags & VJF_TOSS) {
INCR(vjs_tossed);
return (-1);
}
}
cs = &comp->rstate[comp->last_recv];
hlen = getip_hl(&cs->cs_ip) << 2;
th = (struct tcphdr *)((uint32_t *)&cs->cs_ip+hlen/sizeof (uint32_t));
th->th_sum = htons((*cp << 8) | cp[1]);
cp += 2;
if (changes & TCP_PUSH_BIT) {
th->th_flags |= TH_PUSH;
} else {
th->th_flags &= ~TH_PUSH;
}
switch (changes & SPECIALS_MASK) {
case SPECIAL_I:
{
register uint32_t i;
i = ntohs(cs->cs_ip.ip_len) - cs->cs_hlen;
tmp = ntohl(th->th_ack) + i;
th->th_ack = htonl(tmp);
tmp = ntohl(th->th_seq) + i;
th->th_seq = htonl(tmp);
}
break;
case SPECIAL_D:
tmp = ntohl(th->th_seq) + ntohs(cs->cs_ip.ip_len) - cs->cs_hlen;
th->th_seq = htonl(tmp);
break;
default:
if (changes & NEW_U) {
th->th_flags |= TH_URG;
DECODEU(th->th_urp);
} else {
th->th_flags &= ~TH_URG;
}
if (changes & NEW_W) {
DECODES(th->th_win);
}
if (changes & NEW_A) {
DECODEL(th->th_ack);
}
if (changes & NEW_S) {
DECODEL(th->th_seq);
}
break;
}
if (changes & NEW_I) {
DECODES(cs->cs_ip.ip_id);
} else {
cs->cs_ip.ip_id = ntohs(cs->cs_ip.ip_id) + 1;
cs->cs_ip.ip_id = htons(cs->cs_ip.ip_id);
}
vjlen = cp - buf;
buflen -= vjlen;
if (buflen < 0) {
goto bad;
}
total_len += cs->cs_hlen - vjlen;
cs->cs_ip.ip_len = htons(total_len);
bp = (ushort_t *)&cs->cs_ip;
cs->cs_ip.ip_sum = 0;
for (changes = 0; hlen > 0; hlen -= 2) {
changes += *bp++;
}
changes = (changes & 0xffff) + (changes >> 16);
changes = (changes & 0xffff) + (changes >> 16);
cs->cs_ip.ip_sum = ~ changes;
*hdrp = (uchar_t *)&cs->cs_ip;
*hlenp = cs->cs_hlen;
return (vjlen);
bad:
comp->flags |= VJF_TOSS;
INCR(vjs_errorin);
return (-1);
}