root/sys/netgraph/ng_checksum.c
/*-
 * Copyright (c) 2015 Dmitry Vagin <daemon.hammer@ya.ru>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 */

#include <sys/cdefs.h>
#include "opt_inet.h"
#include "opt_inet6.h"

#include <sys/param.h>
#include <sys/systm.h>
#include <sys/kernel.h>
#include <sys/endian.h>
#include <sys/malloc.h>
#include <sys/mbuf.h>
#include <sys/socket.h>

#include <net/bpf.h>
#include <net/ethernet.h>
#include <net/if.h>
#include <net/if_vlan_var.h>

#include <netinet/in.h>
#include <netinet/ip.h>
#include <netinet/ip6.h>
#include <netinet/tcp.h>
#include <netinet/udp.h>
#include <machine/in_cksum.h>

#include <netgraph/ng_message.h>
#include <netgraph/ng_parse.h>
#include <netgraph/netgraph.h>

#include <netgraph/ng_checksum.h>

/*
 * Per-node private data.  Allocated (zeroed) by the constructor and
 * released by the shutdown method.
 */
struct ng_checksum_priv {
        hook_p in;      /* hook set by newhook for the "in" name; NULL when absent */
        hook_p out;     /* hook set by newhook for the "out" name; NULL when absent */
        uint8_t dlt;    /* DLT_XXX from bpf.h; DLT_RAW by default, DLT_EN10MB also accepted */
        struct ng_checksum_config *conf;        /* NULL until a setconfig message is received */
        struct ng_checksum_stats stats;         /* counters reported by the *stats messages */
};

/* Shorthand used throughout the file for a pointer to the private data. */
typedef struct ng_checksum_priv *priv_p;

/*
 * Netgraph node methods implemented by this file; they are hooked into
 * the node type through the typestruct definition below.
 */
static ng_constructor_t ng_checksum_constructor;
static ng_rcvmsg_t      ng_checksum_rcvmsg;
static ng_shutdown_t    ng_checksum_shutdown;
static ng_newhook_t     ng_checksum_newhook;
static ng_rcvdata_t     ng_checksum_rcvdata;
static ng_disconnect_t  ng_checksum_disconnect;
/*
 * Store an error code in the enclosing function's `error' variable and
 * jump to its `done' label.  The do { } while (0) wrapper makes the macro
 * expand to exactly one statement, so "if (c) ERROUT(x); else ..." parses
 * the way it reads.
 */
#define ERROUT(x) do { error = (x); goto done; } while (0)

static const struct ng_parse_struct_field ng_checksum_config_type_fields[]
        = NG_CHECKSUM_CONFIG_TYPE;
static const struct ng_parse_type ng_checksum_config_type = {
        &ng_parse_struct_type,
        &ng_checksum_config_type_fields
};

static const struct ng_parse_struct_field ng_checksum_stats_fields[]
        = NG_CHECKSUM_STATS_TYPE;
static const struct ng_parse_type ng_checksum_stats_type = {
        &ng_parse_struct_type,
        &ng_checksum_stats_fields
};

/*
 * Control messages understood by this node type.  Each entry names the
 * command, the parse type of its argument (NULL when it takes none) and
 * the parse type of its reply (NULL when it sends none).  The list is
 * terminated by a zeroed entry.
 */
static const struct ng_cmdlist ng_checksum_cmdlist[] = {
        {
                .cookie =       NGM_CHECKSUM_COOKIE,
                .cmd =          NGM_CHECKSUM_GETDLT,
                .name =         "getdlt",
                .mesgType =     NULL,
                .respType =     &ng_parse_uint8_type,
        },
        {
                .cookie =       NGM_CHECKSUM_COOKIE,
                .cmd =          NGM_CHECKSUM_SETDLT,
                .name =         "setdlt",
                .mesgType =     &ng_parse_uint8_type,
                .respType =     NULL,
        },
        {
                .cookie =       NGM_CHECKSUM_COOKIE,
                .cmd =          NGM_CHECKSUM_GETCONFIG,
                .name =         "getconfig",
                .mesgType =     NULL,
                .respType =     &ng_checksum_config_type,
        },
        {
                .cookie =       NGM_CHECKSUM_COOKIE,
                .cmd =          NGM_CHECKSUM_SETCONFIG,
                .name =         "setconfig",
                .mesgType =     &ng_checksum_config_type,
                .respType =     NULL,
        },
        {
                .cookie =       NGM_CHECKSUM_COOKIE,
                .cmd =          NGM_CHECKSUM_GET_STATS,
                .name =         "getstats",
                .mesgType =     NULL,
                .respType =     &ng_checksum_stats_type,
        },
        {
                .cookie =       NGM_CHECKSUM_COOKIE,
                .cmd =          NGM_CHECKSUM_CLR_STATS,
                .name =         "clrstats",
                .mesgType =     NULL,
                .respType =     NULL,
        },
        {
                .cookie =       NGM_CHECKSUM_COOKIE,
                .cmd =          NGM_CHECKSUM_GETCLR_STATS,
                .name =         "getclrstats",
                .mesgType =     NULL,
                .respType =     &ng_checksum_stats_type,
        },
        { 0 }
};

/* Node type descriptor: identity first, then lifecycle, hook and I/O methods. */
static struct ng_type typestruct = {
        .version =      NG_ABI_VERSION,
        .name =         NG_CHECKSUM_NODE_TYPE,
        .cmdlist =      ng_checksum_cmdlist,

        .constructor =  ng_checksum_constructor,
        .shutdown =     ng_checksum_shutdown,

        .newhook =      ng_checksum_newhook,
        .disconnect =   ng_checksum_disconnect,

        .rcvmsg =       ng_checksum_rcvmsg,
        .rcvdata =      ng_checksum_rcvdata,
};

/* Register the "checksum" node type with netgraph. */
NETGRAPH_INIT(checksum, &typestruct);

/*
 * Node constructor: allocate zeroed private data, select DLT_RAW as the
 * initial link type and attach the data to the node.  Always returns 0
 * (the allocation uses M_WAITOK).
 */
static int
ng_checksum_constructor(node_p node)
{
        priv_p state = malloc(sizeof(struct ng_checksum_priv), M_NETGRAPH,
            M_WAITOK | M_ZERO);

        state->dlt = DLT_RAW;
        NG_NODE_SET_PRIVATE(node, state);

        return (0);
}

/*
 * Hook-creation method.
 *
 * Accepts exactly two hook names, NG_CHECKSUM_HOOK_IN and
 * NG_CHECKSUM_HOOK_OUT, and remembers the hook in the matching slot of
 * the private data.  Any other name is refused with EINVAL.
 *
 * The comparison is an exact match on purpose: a prefix match would let
 * several differently named hooks (for example "in" and "in2") land in
 * the same slot, the later one silently replacing the earlier one.
 *
 * Returns 0 on success, EINVAL for an unknown hook name.
 */
static int
ng_checksum_newhook(node_p node, hook_p hook, const char *name)
{
        const priv_p priv = NG_NODE_PRIVATE(node);

        if (strcmp(name, NG_CHECKSUM_HOOK_IN) == 0)
                priv->in = hook;
        else if (strcmp(name, NG_CHECKSUM_HOOK_OUT) == 0)
                priv->out = hook;
        else
                return (EINVAL);

        return (0);
}

/*
 * Control-message method.
 *
 * Only messages carrying NGM_CHECKSUM_COOKIE are handled:
 *   GETDLT        reply with the current link type (one byte).
 *   SETDLT        accept DLT_EN10MB or DLT_RAW as the new link type.
 *   GETCONFIG     reply with a copy of the configuration; when none has
 *                 been set yet, succeed without a reply.
 *   SETCONFIG     mask the supplied flags down to the supported checksum
 *                 bits and install a private copy as the configuration.
 *   GET_STATS     reply with the counters.
 *   CLR_STATS     zero the counters.
 *   GETCLR_STATS  reply with the counters, then zero them.
 *
 * Returns 0 on success, EINVAL for a foreign cookie, an unknown command
 * or a malformed argument, ENOMEM when a reply cannot be allocated.
 */
static int
ng_checksum_rcvmsg(node_p node, item_p item, hook_p lasthook)
{
        const priv_p priv = NG_NODE_PRIVATE(node);
        struct ng_checksum_config *conf, *newconf;
        struct ng_mesg *resp = NULL;
        struct ng_mesg *msg;
        uint8_t wanted_dlt;
        int send_reply, reset;
        int error = 0;

        NGI_GET_MSG(item, msg);

        if (msg->header.typecookie != NGM_CHECKSUM_COOKIE)
                ERROUT(EINVAL);

        switch (msg->header.cmd) {
        case NGM_CHECKSUM_GETDLT:
                NG_MKRESPONSE(resp, msg, sizeof(uint8_t), M_WAITOK);
                if (resp == NULL)
                        ERROUT(ENOMEM);

                *((uint8_t *)resp->data) = priv->dlt;
                break;

        case NGM_CHECKSUM_SETDLT:
                if (msg->header.arglen != sizeof(uint8_t))
                        ERROUT(EINVAL);

                /* Only the two link types rcvdata can parse are allowed. */
                wanted_dlt = *(uint8_t *)msg->data;
                if (wanted_dlt != DLT_EN10MB && wanted_dlt != DLT_RAW)
                        ERROUT(EINVAL);

                priv->dlt = wanted_dlt;
                break;

        case NGM_CHECKSUM_GETCONFIG:
                /* Nothing configured yet: success, but no reply body. */
                if (priv->conf == NULL)
                        ERROUT(0);

                NG_MKRESPONSE(resp, msg, sizeof(struct ng_checksum_config),
                    M_WAITOK);
                if (resp == NULL)
                        ERROUT(ENOMEM);

                memcpy(resp->data, priv->conf,
                    sizeof(struct ng_checksum_config));
                break;

        case NGM_CHECKSUM_SETCONFIG:
                conf = (struct ng_checksum_config *)msg->data;

                if (msg->header.arglen != sizeof(struct ng_checksum_config))
                        ERROUT(EINVAL);

                /* Drop any bits outside the supported checksum flags. */
                conf->csum_flags &=
                    NG_CHECKSUM_CSUM_IPV4 | NG_CHECKSUM_CSUM_IPV6;
                conf->csum_offload &=
                    NG_CHECKSUM_CSUM_IPV4 | NG_CHECKSUM_CSUM_IPV6;

                newconf = malloc(sizeof(struct ng_checksum_config),
                    M_NETGRAPH, M_WAITOK | M_ZERO);
                memcpy(newconf, conf, sizeof(struct ng_checksum_config));

                /* Replace the previous configuration, if there was one. */
                if (priv->conf != NULL)
                        free(priv->conf, M_NETGRAPH);
                priv->conf = newconf;
                break;

        case NGM_CHECKSUM_GET_STATS:
        case NGM_CHECKSUM_CLR_STATS:
        case NGM_CHECKSUM_GETCLR_STATS:
                send_reply = (msg->header.cmd != NGM_CHECKSUM_CLR_STATS);
                reset = (msg->header.cmd != NGM_CHECKSUM_GET_STATS);

                if (send_reply) {
                        NG_MKRESPONSE(resp, msg,
                            sizeof(struct ng_checksum_stats), M_WAITOK);
                        if (resp == NULL)
                                ERROUT(ENOMEM);

                        memcpy(resp->data, &priv->stats,
                            sizeof(struct ng_checksum_stats));
                }

                if (reset)
                        bzero(&priv->stats, sizeof(struct ng_checksum_stats));
                break;

        default:
                ERROUT(EINVAL);
        }

done:
        NG_RESPOND_MSG(error, node, item, resp);
        NG_FREE_MSG(msg);

        return (error);
}

/*
 * PULLUP_CHECK(mbuf, length) -- helper-function variant.
 *
 * Advances the caller's local `pullup_len' by `length' and makes sure that
 * many bytes are present at the front of `mbuf':
 *   - if the packet is shorter than pullup_len, or pullup_len exceeds
 *     MHLEN, the ENCLOSING FUNCTION returns EINVAL;
 *   - if the first mbuf holds fewer than pullup_len bytes, m_pullup() is
 *     called and its result is assigned back to `mbuf'; when that result
 *     is NULL the ENCLOSING FUNCTION returns ENOBUFS.
 *
 * The macro therefore contains hidden return statements, requires an int
 * `pullup_len' in scope, and may change what `mbuf' points to.
 * NOTE(review): callers in this file do not free the mbuf after ENOBUFS,
 * which presumes m_pullup() releases the chain on failure -- confirm
 * against mbuf(9).
 */
#define PULLUP_CHECK(mbuf, length) do {                                 \
        pullup_len += length;                                           \
        if (((mbuf)->m_pkthdr.len < pullup_len) ||                      \
            (pullup_len > MHLEN)) {                                     \
                return (EINVAL);                                        \
        }                                                               \
        if ((mbuf)->m_len < pullup_len &&                               \
            (((mbuf) = m_pullup((mbuf), pullup_len)) == NULL)) {        \
                return (ENOBUFS);                                       \
        }                                                               \
} while (0)

#ifdef INET
/*
 * Fill in the IPv4 header checksum and the TCP/UDP checksum requested by
 * the CSUM_IP / CSUM_TCP / CSUM_UDP bits in m->m_pkthdr.csum_flags.
 *
 *   priv       node private data; priv->conf must not be NULL.
 *   m          packet whose IPv4 header starts at byte `l3_offset'.
 *   l3_offset  offset of the IPv4 header from the start of the packet.
 *
 * For each requested checksum the field is first primed (0 for the IP
 * header, the pseudo-header sum for TCP/UDP).  When the matching bit is
 * NOT set in priv->conf->csum_offload the full checksum is computed in
 * software and the request bit is cleared; otherwise the primed value and
 * the request bit are left in place.  Request bits that do not apply to
 * the packet (wrong protocol, fragment, any IPv6 bit) are cleared.
 * priv->stats.processed is bumped when at least one checksum was handled.
 *
 * Returns 0 on success, EOPNOTSUPP when the header is not IP version 4,
 * EINVAL when the lengths in the header are inconsistent with each other
 * or with the packet, and EINVAL/ENOBUFS as produced by PULLUP_CHECK.
 *
 * NOTE(review): `m' is received by value.  Should m_pullup() inside
 * PULLUP_CHECK hand back a different head mbuf, the caller's pointer is
 * not updated -- the caller has to guarantee the headers are already
 * contiguous, or this interface should become struct mbuf **.
 */
static int
checksum_ipv4(priv_p priv, struct mbuf *m, int l3_offset)
{
        struct ip *ip4;
        int pullup_len;
        int hlen, plen;
        int processed = 0;

        pullup_len = l3_offset;

        PULLUP_CHECK(m, sizeof(struct ip));
        ip4 = (struct ip *) mtodo(m, l3_offset);

        if (ip4->ip_v != IPVERSION)
                return (EOPNOTSUPP);

        hlen = ip4->ip_hl << 2;
        plen = ntohs(ip4->ip_len);

        /*
         * The header must be at least the fixed 20 bytes, the total length
         * must cover the header (otherwise plen - hlen below is negative
         * and the checksum range is inverted), and the packet must really
         * contain plen bytes.
         */
        if (hlen < (int)sizeof(struct ip) || plen < hlen ||
            m->m_pkthdr.len < l3_offset + plen)
                return (EINVAL);

        if (m->m_pkthdr.csum_flags & CSUM_IP) {
                ip4->ip_sum = 0;

                if ((priv->conf->csum_offload & CSUM_IP) == 0) {
                        /* in_cksum_hdr() only covers an option-less header. */
                        if (hlen == sizeof(struct ip))
                                ip4->ip_sum = in_cksum_hdr(ip4);
                        else
                                ip4->ip_sum = in_cksum_skip(m, l3_offset + hlen, l3_offset);

                        m->m_pkthdr.csum_flags &= ~CSUM_IP;
                }

                processed = 1;
        }

        pullup_len = l3_offset + hlen;

        /* A transport checksum cannot be computed over a single fragment. */
        if (ip4->ip_off & htons(IP_MF|IP_OFFMASK)) {
                m->m_pkthdr.csum_flags &= ~(CSUM_TCP|CSUM_UDP);
                goto finish;
        }

        switch (ip4->ip_p)
        {
                case IPPROTO_TCP:
                        if (m->m_pkthdr.csum_flags & CSUM_TCP) {
                                struct tcphdr *th;

                                PULLUP_CHECK(m, sizeof(struct tcphdr));
                                /* The pull-up may have moved the headers. */
                                ip4 = (struct ip *) mtodo(m, l3_offset);
                                th = (struct tcphdr *) mtodo(m, l3_offset + hlen);

                                th->th_sum = in_pseudo(ip4->ip_src.s_addr,
                                    ip4->ip_dst.s_addr, htons(ip4->ip_p + plen - hlen));

                                if ((priv->conf->csum_offload & CSUM_TCP) == 0) {
                                        th->th_sum = in_cksum_skip(m, l3_offset + plen, l3_offset + hlen);
                                        m->m_pkthdr.csum_flags &= ~CSUM_TCP;
                                }

                                processed = 1;
                        }

                        m->m_pkthdr.csum_flags &= ~CSUM_UDP;
                        break;

                case IPPROTO_UDP:
                        if (m->m_pkthdr.csum_flags & CSUM_UDP) {
                                struct udphdr *uh;

                                PULLUP_CHECK(m, sizeof(struct udphdr));
                                /* The pull-up may have moved the headers. */
                                ip4 = (struct ip *) mtodo(m, l3_offset);
                                uh = (struct udphdr *) mtodo(m, l3_offset + hlen);

                                uh->uh_sum = in_pseudo(ip4->ip_src.s_addr,
                                    ip4->ip_dst.s_addr, htons(ip4->ip_p + plen - hlen));

                                if ((priv->conf->csum_offload & CSUM_UDP) == 0) {
                                        uh->uh_sum = in_cksum_skip(m,
                                            l3_offset + plen, l3_offset + hlen);

                                        /* A computed 0 is stored as 0xffff. */
                                        if (uh->uh_sum == 0)
                                                uh->uh_sum = 0xffff;

                                        m->m_pkthdr.csum_flags &= ~CSUM_UDP;
                                }

                                processed = 1;
                        }

                        m->m_pkthdr.csum_flags &= ~CSUM_TCP;
                        break;

                default:
                        m->m_pkthdr.csum_flags &= ~(CSUM_TCP|CSUM_UDP);
                        break;
        }

finish:
        /* This is an IPv4 packet: no IPv6 request may remain on it. */
        m->m_pkthdr.csum_flags &= ~NG_CHECKSUM_CSUM_IPV6;

        if (processed)
                priv->stats.processed++;

        return (0);
}
#endif /* INET */

#ifdef INET6
/*
 * Fill in the TCP/UDP checksum of an IPv6 packet as requested by the
 * CSUM_TCP_IPV6 / CSUM_UDP_IPV6 bits in m->m_pkthdr.csum_flags.
 *
 *   priv       node private data; priv->conf must not be NULL.
 *   m          packet whose IPv6 header starts at byte `l3_offset'.
 *   l3_offset  offset of the IPv6 header from the start of the packet.
 *
 * Hop-by-hop, destination-options, routing and AH extension headers are
 * skipped to find the transport header.  The checksum field is primed
 * with the pseudo-header sum; when the matching bit is NOT set in
 * priv->conf->csum_offload the full checksum is computed in software and
 * the request bit is cleared.  Request bits that do not apply to the
 * packet (wrong protocol, fragment header present, any IPv4 bit) are
 * cleared.  priv->stats.processed is bumped when a checksum was handled.
 *
 * Returns 0 on success, EOPNOTSUPP when the header is not IP version 6,
 * EINVAL for inconsistent lengths or a malformed extension chain, and
 * EINVAL/ENOBUFS as produced by PULLUP_CHECK.
 *
 * NOTE(review): `m' is received by value.  Should m_pullup() inside
 * PULLUP_CHECK hand back a different head mbuf, the caller's pointer is
 * not updated -- the caller has to guarantee the headers are already
 * contiguous, or this interface should become struct mbuf **.
 */
static int
checksum_ipv6(priv_p priv, struct mbuf *m, int l3_offset)
{
        struct ip6_hdr *ip6;
        struct ip6_ext *ip6e = NULL;
        int pullup_len;
        int hlen, plen;
        int nxt;
        int processed = 0;

        pullup_len = l3_offset;

        PULLUP_CHECK(m, sizeof(struct ip6_hdr));
        ip6 = (struct ip6_hdr *) mtodo(m, l3_offset);

        if ((ip6->ip6_vfc & IPV6_VERSION_MASK) != IPV6_VERSION)
                return (EOPNOTSUPP);

        /* hlen: bytes of headers walked so far; plen: whole L3 length. */
        hlen = sizeof(struct ip6_hdr);
        plen = ntohs(ip6->ip6_plen) + hlen;

        if (m->m_pkthdr.len < l3_offset + plen)
                return (EINVAL);

        nxt = ip6->ip6_nxt;

        for (;;) {
                switch (nxt)
                {
                        case IPPROTO_DSTOPTS:
                        case IPPROTO_HOPOPTS:
                        case IPPROTO_ROUTING:
                                PULLUP_CHECK(m, sizeof(struct ip6_ext));
                                ip6e = (struct ip6_ext *) mtodo(m, l3_offset + hlen);
                                nxt = ip6e->ip6e_nxt;
                                /* Length is in 8-byte units, first unit excluded. */
                                hlen += (ip6e->ip6e_len + 1) << 3;
                                pullup_len = l3_offset + hlen;
                                break;

                        case IPPROTO_AH:
                                PULLUP_CHECK(m, sizeof(struct ip6_ext));
                                ip6e = (struct ip6_ext *) mtodo(m, l3_offset + hlen);
                                nxt = ip6e->ip6e_nxt;
                                /* AH length is in 4-byte units, minus two. */
                                hlen += (ip6e->ip6e_len + 2) << 2;
                                pullup_len = l3_offset + hlen;
                                break;

                        case IPPROTO_FRAGMENT:
                                /* A transport checksum cannot be computed over a single fragment. */
                                m->m_pkthdr.csum_flags &= ~(CSUM_TCP_IPV6|CSUM_UDP_IPV6);
                                goto finish;

                        default:
                                goto loopend;
                }

                /*
                 * Reject a chain that loops back to next-header 0 or whose
                 * extension headers claim more bytes than the packet length
                 * announces (plen - hlen would go negative below).
                 */
                if (nxt == 0 || hlen > plen)
                        return (EINVAL);
        }

loopend:

        switch (nxt)
        {
                case IPPROTO_TCP:
                        if (m->m_pkthdr.csum_flags & CSUM_TCP_IPV6) {
                                struct tcphdr *th;

                                PULLUP_CHECK(m, sizeof(struct tcphdr));
                                /* Earlier pull-ups may have moved the headers. */
                                ip6 = (struct ip6_hdr *) mtodo(m, l3_offset);
                                th = (struct tcphdr *) mtodo(m, l3_offset + hlen);

                                th->th_sum = in6_cksum_pseudo(ip6, plen - hlen, nxt, 0);

                                if ((priv->conf->csum_offload & CSUM_TCP_IPV6) == 0) {
                                        th->th_sum = in_cksum_skip(m, l3_offset + plen, l3_offset + hlen);
                                        m->m_pkthdr.csum_flags &= ~CSUM_TCP_IPV6;
                                }

                                processed = 1;
                        }

                        m->m_pkthdr.csum_flags &= ~CSUM_UDP_IPV6;
                        break;

                case IPPROTO_UDP:
                        if (m->m_pkthdr.csum_flags & CSUM_UDP_IPV6) {
                                struct udphdr *uh;

                                PULLUP_CHECK(m, sizeof(struct udphdr));
                                /* Earlier pull-ups may have moved the headers. */
                                ip6 = (struct ip6_hdr *) mtodo(m, l3_offset);
                                uh = (struct udphdr *) mtodo(m, l3_offset + hlen);

                                uh->uh_sum = in6_cksum_pseudo(ip6, plen - hlen, nxt, 0);

                                if ((priv->conf->csum_offload & CSUM_UDP_IPV6) == 0) {
                                        uh->uh_sum = in_cksum_skip(m,
                                            l3_offset + plen, l3_offset + hlen);

                                        /* A computed 0 is stored as 0xffff. */
                                        if (uh->uh_sum == 0)
                                                uh->uh_sum = 0xffff;

                                        m->m_pkthdr.csum_flags &= ~CSUM_UDP_IPV6;
                                }

                                processed = 1;
                        }

                        m->m_pkthdr.csum_flags &= ~CSUM_TCP_IPV6;
                        break;

                default:
                        m->m_pkthdr.csum_flags &= ~(CSUM_TCP_IPV6|CSUM_UDP_IPV6);
                        break;
        }

finish:
        /* This is an IPv6 packet: no IPv4 request may remain on it. */
        m->m_pkthdr.csum_flags &= ~NG_CHECKSUM_CSUM_IPV4;

        if (processed)
                priv->stats.processed++;

        return (0);
}
#endif /* INET6 */

#undef  PULLUP_CHECK

/*
 * Data method.
 *
 * Packets arriving on the "in" hook of a configured node get the
 * configured csum_flags OR-ed into their packet header; if any supported
 * IPv4/IPv6 checksum bit is then set, the link-layer framing selected by
 * priv->dlt is skipped (Ethernet with stacked VLAN tags, or none for
 * DLT_RAW) and checksum_ipv4()/checksum_ipv6() do the actual work.
 * Everything else (node not configured, data from the "out" hook, mbuf
 * without a packet header, parse errors other than ENOBUFS) is forwarded
 * untouched.
 *
 * Forwarding: "in" -> "out", or back out of "in" when "out" is not
 * connected; "out" -> "in" when "in" is connected.  With no destination
 * the packet is freed and counted as dropped.
 *
 * Returns the result of forwarding, or ENOMEM/ENOBUFS/EINVAL when the
 * packet had to be dropped here.
 */
static int
ng_checksum_rcvdata(hook_p hook, item_p item)
{
        const priv_p priv = NG_NODE_PRIVATE(NG_HOOK_NODE(hook));
        struct mbuf *m;
        hook_p out;
        int error = 0;

        priv->stats.received++;

        NGI_GET_M(item, m);

/*
 * Local variant of PULLUP_CHECK: instead of returning, a length problem
 * forwards the packet unmodified (bypass) and an m_pullup() failure drops
 * it (drop; the mbuf is not freed again on that path).
 */
#define PULLUP_CHECK(mbuf, length) do {                                 \
        pullup_len += length;                                           \
        if (((mbuf)->m_pkthdr.len < pullup_len) ||                      \
            (pullup_len > MHLEN)) {                                     \
                error = EINVAL;                                         \
                goto bypass;                                            \
        }                                                               \
        if ((mbuf)->m_len < pullup_len &&                               \
            (((mbuf) = m_pullup((mbuf), pullup_len)) == NULL)) {        \
                error = ENOBUFS;                                        \
                goto drop;                                              \
        }                                                               \
} while (0)

        if (!(priv->conf && hook == priv->in && m && (m->m_flags & M_PKTHDR)))
                goto bypass;

        m->m_pkthdr.csum_flags |= priv->conf->csum_flags;

        if (m->m_pkthdr.csum_flags & (NG_CHECKSUM_CSUM_IPV4|NG_CHECKSUM_CSUM_IPV6))
        {
                struct ether_header *eh;
                struct ng_checksum_vlan_header *vh;
                int pullup_len = 0;
                int contig_len;
                uint16_t etype;

                /* The headers are rewritten in place: get a private copy. */
                m = m_unshare(m, M_NOWAIT);

                if (m == NULL)
                        ERROUT(ENOMEM);

                /*
                 * checksum_ipv4()/checksum_ipv6() receive the mbuf pointer
                 * by value, so an m_pullup() inside them that replaced the
                 * head mbuf would leave `m' here pointing at freed memory.
                 * PULLUP_CHECK never accepts more than
                 * min(m_pkthdr.len, MHLEN) bytes, so making exactly that
                 * many bytes contiguous now guarantees every later
                 * pull-up is a no-op and `m' stays valid.
                 */
                contig_len = (m->m_pkthdr.len < MHLEN) ? m->m_pkthdr.len : MHLEN;

                if (m->m_len < contig_len &&
                    (m = m_pullup(m, contig_len)) == NULL) {
                        error = ENOBUFS;
                        goto drop;
                }

                switch (priv->dlt)
                {
                        case DLT_EN10MB:
                                PULLUP_CHECK(m, sizeof(struct ether_header));
                                eh = mtod(m, struct ether_header *);
                                etype = ntohs(eh->ether_type);

                                for (;;) {      /* QinQ support */
                                        switch (etype)
                                        {
                                                case 0x8100:
                                                case 0x88A8:
                                                case 0x9100:
                                                        /* Step over one VLAN tag; its trailing etype is next. */
                                                        PULLUP_CHECK(m, sizeof(struct ng_checksum_vlan_header));
                                                        vh = (struct ng_checksum_vlan_header *) mtodo(m,
                                                            pullup_len - sizeof(struct ng_checksum_vlan_header));
                                                        etype = ntohs(vh->etype);
                                                        break;

                                                default:
                                                        goto loopend;
                                        }
                                }
loopend:
#ifdef INET
                                if (etype == ETHERTYPE_IP &&
                                    (m->m_pkthdr.csum_flags & NG_CHECKSUM_CSUM_IPV4)) {
                                        error = checksum_ipv4(priv, m, pullup_len);
                                        if (error == ENOBUFS)
                                                goto drop;
                                } else
#endif
#ifdef INET6
                                if (etype == ETHERTYPE_IPV6 &&
                                    (m->m_pkthdr.csum_flags & NG_CHECKSUM_CSUM_IPV6)) {
                                        error = checksum_ipv6(priv, m, pullup_len);
                                        if (error == ENOBUFS)
                                                goto drop;
                                } else
#endif
                                {
                                        /* Not an IP frame we handle: drop the requests. */
                                        m->m_pkthdr.csum_flags &=
                                            ~(NG_CHECKSUM_CSUM_IPV4|NG_CHECKSUM_CSUM_IPV6);
                                }

                                break;

                        case DLT_RAW:
                                /*
                                 * No link header to tell the family: try
                                 * IPv4 first, then IPv6; each helper
                                 * rejects the other version itself.
                                 */
#ifdef INET
                                if (m->m_pkthdr.csum_flags & NG_CHECKSUM_CSUM_IPV4)
                                {
                                        error = checksum_ipv4(priv, m, pullup_len);

                                        if (error == 0)
                                                goto bypass;
                                        else if (error == ENOBUFS)
                                                goto drop;
                                }
#endif
#ifdef INET6
                                if (m->m_pkthdr.csum_flags & NG_CHECKSUM_CSUM_IPV6)
                                {
                                        error = checksum_ipv6(priv, m, pullup_len);

                                        if (error == 0)
                                                goto bypass;
                                        else if (error == ENOBUFS)
                                                goto drop;
                                }
#endif
                                if (error)
                                        m->m_pkthdr.csum_flags &=
                                            ~(NG_CHECKSUM_CSUM_IPV4|NG_CHECKSUM_CSUM_IPV6);

                                break;

                        default:
                                ERROUT(EINVAL);
                }
        }

#undef  PULLUP_CHECK

bypass:
        out = NULL;

        if (hook == priv->in) {
                /* return frames on 'in' hook if 'out' not connected */
                out = priv->out ? priv->out : priv->in;
        } else if (hook == priv->out && priv->in) {
                /* pass frames on 'out' hook if 'in' connected */
                out = priv->in;
        }

        if (out == NULL)
                ERROUT(0);

        NG_FWD_NEW_DATA(error, item, out, m);

        return (error);

done:
        /* We still own the mbuf here. */
        NG_FREE_M(m);
drop:
        /* Reached directly when the mbuf is already gone (pull-up failure). */
        NG_FREE_ITEM(item);

        priv->stats.dropped++;

        return (error);
}

/*
 * Node shutdown method: detach the private data from the node, drop the
 * node reference, then release the configuration (if one was set) and
 * the private data itself.  Always returns 0.
 */
static int
ng_checksum_shutdown(node_p node)
{
        const priv_p state = NG_NODE_PRIVATE(node);

        /* Unhook the state from the node before freeing anything. */
        NG_NODE_SET_PRIVATE(node, NULL);
        NG_NODE_UNREF(node);

        if (state->conf != NULL)
                free(state->conf, M_NETGRAPH);
        free(state, M_NETGRAPH);

        return (0);
}

/*
 * Hook disconnection method: forget the hook in whichever slot holds it
 * and, once the last hook is gone from a node that is still valid, ask
 * netgraph to remove the node.  Always returns 0.
 */
static int
ng_checksum_disconnect(hook_p hook)
{
        const priv_p state = NG_NODE_PRIVATE(NG_HOOK_NODE(hook));

        if (state->in == hook)
                state->in = NULL;

        if (state->out == hook)
                state->out = NULL;

        /* Other hooks remain: the node stays. */
        if (NG_NODE_NUMHOOKS(NG_HOOK_NODE(hook)) != 0)
                return (0);

        /* An invalid node is already shutting down; do not remove it twice. */
        if (!NG_NODE_IS_VALID(NG_HOOK_NODE(hook)))
                return (0);

        ng_rmnode_self(NG_HOOK_NODE(hook));

        return (0);
}