root/sys/netgraph/ng_deflate.c
/*-
 * SPDX-License-Identifier: BSD-2-Clause
 *
 * Copyright (c) 2006 Alexander Motin <mav@alkar.net>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice unmodified, this list of conditions, and the following
 *    disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */

/*
 * Deflate PPP compression netgraph node type.
 */

#include <sys/param.h>
#include <sys/systm.h>
#include <sys/kernel.h>
#include <sys/mbuf.h>
#include <sys/malloc.h>
#include <sys/endian.h>
#include <sys/errno.h>
#include <sys/syslog.h>
#include <contrib/zlib/zlib.h>

#include <netgraph/ng_message.h>
#include <netgraph/netgraph.h>
#include <netgraph/ng_parse.h>
#include <netgraph/ng_deflate.h>

#include "opt_netgraph.h"

/* Malloc type used for the per-node private data. */
static MALLOC_DEFINE(M_NETGRAPH_DEFLATE, "netgraph_deflate",
    "netgraph deflate node");

/*
 * DEFLATE header length: the 16-bit sequence number that follows the
 * 2-byte PPP protocol field at the start of a compressed frame.
 */
#define DEFLATE_HDRLEN          2

/* PPP protocol number that marks a compressed datagram. */
#define PROT_COMPD              0x00fd

/*
 * Size of the linear work buffers; frames longer than this are rejected
 * by both the compressor and the decompressor.
 */
#define DEFLATE_BUF_SIZE        4096

/*
 * Node private data.  A node has a single hook and therefore handles one
 * direction only, recorded in "compress" when the hook is created.
 */
struct ng_deflate_private {
        struct ng_deflate_config cfg;           /* configuration; while
                                                   cfg.enable is set "cx"
                                                   holds a live zlib stream */
        u_char          inbuf[DEFLATE_BUF_SIZE];        /* linear copy of the
                                                           frame being worked on */
        u_char          outbuf[DEFLATE_BUF_SIZE];       /* zlib output / rebuilt
                                                           frame */
        z_stream        cx;                     /* compression context */
        struct ng_deflate_stats stats;          /* statistics */
        ng_ID_t         ctrlnode;               /* ID of the node that sent
                                                   the last config message;
                                                   reset-requests go there */
        uint16_t        seqnum;                 /* next sequence number to
                                                   send (compress) or expect
                                                   (decompress) */
        u_char          compress;               /* 1 = compress, 0 = decompress */
};
typedef struct ng_deflate_private *priv_p;

/* Netgraph node methods (wired up in ng_deflate_typestruct below). */
static ng_constructor_t ng_deflate_constructor;
static ng_rcvmsg_t      ng_deflate_rcvmsg;
static ng_shutdown_t    ng_deflate_shutdown;
static ng_newhook_t     ng_deflate_newhook;
static ng_rcvdata_t     ng_deflate_rcvdata;
static ng_disconnect_t  ng_deflate_disconnect;

/*
 * Helper functions: per-frame compression/decompression and the handler
 * for NGM_DEFLATE_RESETREQ.
 */
static int      ng_deflate_compress(node_p, struct mbuf *, struct mbuf **);
static int      ng_deflate_decompress(node_p, struct mbuf *, struct mbuf **);
static void     ng_deflate_reset_req(node_p);

/* Parse type for struct ng_deflate_config. */
static const struct ng_parse_struct_field ng_deflate_config_type_fields[]
        = NG_DEFLATE_CONFIG_INFO;
static const struct ng_parse_type ng_deflate_config_type = {
        &ng_parse_struct_type,
        ng_deflate_config_type_fields
};

/* Parse type for struct ng_deflate_stats. */
static const struct ng_parse_struct_field ng_deflate_stats_type_fields[]
        = NG_DEFLATE_STATS_INFO;
static const struct ng_parse_type ng_deflate_stat_type = {
        &ng_parse_struct_type,
        ng_deflate_stats_type_fields
};

/*
 * List of commands and how to convert arguments to/from ASCII.
 * Each entry: cookie, command, ASCII name, argument type, response type
 * (NULL where the message carries no argument / no response body).
 */
static const struct ng_cmdlist ng_deflate_cmds[] = {
        {
          NGM_DEFLATE_COOKIE,
          NGM_DEFLATE_CONFIG,
          "config",
          &ng_deflate_config_type,
          NULL
        },
        {
          NGM_DEFLATE_COOKIE,
          NGM_DEFLATE_RESETREQ,
          "resetreq",
          NULL,
          NULL
        },
        {
          NGM_DEFLATE_COOKIE,
          NGM_DEFLATE_GET_STATS,
          "getstats",
          NULL,
          &ng_deflate_stat_type
        },
        {
          NGM_DEFLATE_COOKIE,
          NGM_DEFLATE_CLR_STATS,
          "clrstats",
          NULL,
          NULL
        },
        {
          NGM_DEFLATE_COOKIE,
          NGM_DEFLATE_GETCLR_STATS,
          "getclrstats",
          NULL,
          &ng_deflate_stat_type
        },
        { 0 }   /* terminator */
};

/*
 * Node type descriptor: the node type named NG_DEFLATE_NODE_TYPE with the
 * methods and command list declared above.
 */
static struct ng_type ng_deflate_typestruct = {
        .version =      NG_ABI_VERSION,
        .name =         NG_DEFLATE_NODE_TYPE,
        .constructor =  ng_deflate_constructor,
        .rcvmsg =       ng_deflate_rcvmsg,
        .shutdown =     ng_deflate_shutdown,
        .newhook =      ng_deflate_newhook,
        .rcvdata =      ng_deflate_rcvdata,
        .disconnect =   ng_deflate_disconnect,
        .cmdlist =      ng_deflate_cmds,
};
NETGRAPH_INIT(deflate, &ng_deflate_typestruct);

/* Depend on separate zlib module. */
MODULE_DEPEND(ng_deflate, zlib, 1, 1, 1);

/* Store x in the local "error" and jump to the function's "done" label. */
#define ERROUT(x)       do { error = (x); goto done; } while (0)

/************************************************************************
                        NETGRAPH NODE STUFF
 ************************************************************************/

/*
 * Node constructor: attach zero-filled private state to the new node.
 * Zero-filling leaves cfg.enable clear (no zlib stream yet) and the zlib
 * allocator hooks unset.
 */
static int
ng_deflate_constructor(node_p node)
{
        priv_p np;

        np = malloc(sizeof(*np), M_NETGRAPH_DEFLATE, M_WAITOK | M_ZERO);
        NG_NODE_SET_PRIVATE(node, np);

        /*
         * The zlib context and the in/out work buffers live in the
         * private data, so all access to the node must be serialized.
         */
        NG_NODE_FORCE_WRITER(node);

        return (0);
}

/*
 * Accept the node's single hook.  Its name picks the direction: the
 * compression hook makes this a compressor, the decompression hook a
 * decompressor.  A second hook or an unknown name is refused with EINVAL.
 */
static int
ng_deflate_newhook(node_p node, hook_p hook, const char *name)
{
        const priv_p priv = NG_NODE_PRIVATE(node);

        /* Only one hook per node. */
        if (NG_NODE_NUMHOOKS(node) > 0)
                return (EINVAL);

        if (strcmp(name, NG_DEFLATE_HOOK_COMP) == 0) {
                priv->compress = 1;
                return (0);
        }
        if (strcmp(name, NG_DEFLATE_HOOK_DECOMP) == 0) {
                priv->compress = 0;
                return (0);
        }

        /* Unrecognized hook name. */
        return (EINVAL);
}

/*
 * Handle a control message addressed to this node.
 *
 * NGM_DEFLATE_CONFIG tears down any existing zlib stream, stores the new
 * configuration, creates a fresh deflate or inflate stream (according to
 * the hook direction) when enabled, zeroes the sequence number and
 * remembers the sender as the node that receives reset-requests.
 * NGM_DEFLATE_RESETREQ calls ng_deflate_reset_req().  The three *_STATS
 * messages return a copy of the counters and/or clear them.  Messages
 * with another cookie or command fail with EINVAL.
 */
static int
ng_deflate_rcvmsg(node_p node, item_p item, hook_p lasthook)
{
        const priv_p priv = NG_NODE_PRIVATE(node);
        struct ng_mesg *msg;
        struct ng_mesg *resp = NULL;
        int error = 0;

        NGI_GET_MSG(item, msg);

        if (msg->header.typecookie != NGM_DEFLATE_COOKIE)
                ERROUT(EINVAL);

        switch (msg->header.cmd) {
        case NGM_DEFLATE_CONFIG: {
                struct ng_deflate_config *const cfg =
                    (struct ng_deflate_config *)msg->data;
                int zrv;

                /* Validate the request before touching any node state. */
                if (msg->header.arglen != sizeof(*cfg))
                        ERROUT(EINVAL);
                if (!cfg->enable)
                        cfg->windowBits = 0;
                else if (cfg->windowBits < 8 || cfg->windowBits > 15)
                        ERROUT(EINVAL);

                /* Release the stream set up by the previous configuration. */
                if (priv->cfg.enable) {
                        if (priv->compress)
                                deflateEnd(&priv->cx);
                        else
                                inflateEnd(&priv->cx);
                        priv->cfg.enable = 0;
                }

                priv->cfg = *cfg;

                if (priv->cfg.enable) {
                        /* Raw deflate stream: negative windowBits. */
                        priv->cx.next_in = NULL;
                        if (priv->compress) {
                                zrv = deflateInit2(&priv->cx,
                                    Z_DEFAULT_COMPRESSION, Z_DEFLATED,
                                    -cfg->windowBits, 8,
                                    Z_DEFAULT_STRATEGY);
                                if (zrv != Z_OK)
                                        log(LOG_NOTICE,
                                            "deflateInit2: error %d, %s\n",
                                            zrv, priv->cx.msg);
                        } else {
                                zrv = inflateInit2(&priv->cx,
                                    -cfg->windowBits);
                                if (zrv != Z_OK)
                                        log(LOG_NOTICE,
                                            "inflateInit2: error %d, %s\n",
                                            zrv, priv->cx.msg);
                        }
                        if (zrv != Z_OK) {
                                /* No live stream: stay disabled. */
                                priv->cfg.enable = 0;
                                ERROUT(ENOMEM);
                        }
                }

                priv->seqnum = 0;

                /* Whoever configured us is where reset-requests go. */
                priv->ctrlnode = NGI_RETADDR(item);
                break;
        }

        case NGM_DEFLATE_GET_STATS:
        case NGM_DEFLATE_CLR_STATS:
        case NGM_DEFLATE_GETCLR_STATS:
                /* GET and GETCLR answer with a snapshot of the counters... */
                if (msg->header.cmd != NGM_DEFLATE_CLR_STATS) {
                        NG_MKRESPONSE(resp, msg, sizeof(priv->stats),
                            M_NOWAIT);
                        if (resp == NULL)
                                ERROUT(ENOMEM);
                        bcopy(&priv->stats, resp->data, sizeof(priv->stats));
                }
                /* ...while CLR and GETCLR zero them afterwards. */
                if (msg->header.cmd != NGM_DEFLATE_GET_STATS)
                        bzero(&priv->stats, sizeof(priv->stats));
                break;

        case NGM_DEFLATE_RESETREQ:
                ng_deflate_reset_req(node);
                break;

        default:
                ERROUT(EINVAL);
        }
done:
        NG_RESPOND_MSG(error, node, item, resp);
        NG_FREE_MSG(msg);
        return (error);
}

/*
 * Receive a data frame on our hook, (de)compress it and send the result
 * back out through the same hook.
 *
 * Frames arriving while the node is unconfigured are dropped with ENXIO.
 * If decompression fails, a NGM_DEFLATE_RESETREQ message is sent to the
 * controlling node (when one is known).  The decompression error is
 * returned whether or not that message could be sent.
 */
static int
ng_deflate_rcvdata(hook_p hook, item_p item)
{
        const node_p node = NG_HOOK_NODE(hook);
        const priv_p priv = NG_NODE_PRIVATE(node);
        struct mbuf *m, *out;
        int error;

        if (!priv->cfg.enable) {
                NG_FREE_ITEM(item);
                return (ENXIO);
        }

        NGI_GET_M(item, m);
        /* Compress */
        if (priv->compress) {
                /* The helper consumes "m" even on failure. */
                if ((error = ng_deflate_compress(node, m, &out)) != 0) {
                        NG_FREE_ITEM(item);
                        log(LOG_NOTICE, "%s: error: %d\n", __func__, error);
                        return (error);
                }
        } else { /* Decompress */
                /* The helper consumes "m" even on failure. */
                if ((error = ng_deflate_decompress(node, m, &out)) != 0) {
                        NG_FREE_ITEM(item);
                        log(LOG_NOTICE, "%s: error: %d\n", __func__, error);
                        if (priv->ctrlnode != 0) {
                                struct ng_mesg *msg;
                                int serror;

                                /*
                                 * Need to send a reset-request.  Its status
                                 * goes into a separate variable so the
                                 * decompression error is what we return.
                                 */
                                NG_MKMESSAGE(msg, NGM_DEFLATE_COOKIE,
                                    NGM_DEFLATE_RESETREQ, 0, M_NOWAIT);
                                if (msg != NULL) {
                                        NG_SEND_MSG_ID(serror, node, msg,
                                            priv->ctrlnode, 0);
                                        if (serror != 0)
                                                log(LOG_NOTICE,
                                                    "%s: reset-request "
                                                    "failed: %d\n",
                                                    __func__, serror);
                                }
                        }
                        return (error);
                }
        }

        NG_FWD_NEW_DATA(error, item, hook, out);
        return (error);
}

/*
 * Destroy the node: end any live zlib stream, release the private data
 * and drop the node reference.
 */
static int
ng_deflate_shutdown(node_p node)
{
        const priv_p priv = NG_NODE_PRIVATE(node);

        /* A live stream exists only while a configuration is enabled. */
        if (priv->cfg.enable && priv->compress)
                deflateEnd(&priv->cx);
        else if (priv->cfg.enable)
                inflateEnd(&priv->cx);

        NG_NODE_SET_PRIVATE(node, NULL);
        free(priv, M_NETGRAPH_DEFLATE);

        NG_NODE_UNREF(node);            /* let the node escape */
        return (0);
}

/*
 * Hook disconnection: end any live zlib stream and mark the node
 * unconfigured.  Once the last hook is gone, remove the node unless it is
 * already being torn down.
 */
static int
ng_deflate_disconnect(hook_p hook)
{
        const node_p node = NG_HOOK_NODE(hook);
        const priv_p priv = NG_NODE_PRIVATE(node);

        if (priv->cfg.enable) {
                priv->cfg.enable = 0;
                if (priv->compress)
                        deflateEnd(&priv->cx);
                else
                        inflateEnd(&priv->cx);
        }

        /* Still connected, or already invalid: nothing more to do. */
        if (NG_NODE_NUMHOOKS(node) != 0 || !NG_NODE_IS_VALID(node))
                return (0);

        ng_rmnode_self(node);
        return (0);
}

/************************************************************************
                        HELPER STUFF
 ************************************************************************/

/*
 * Compress a frame for transmission and put the result in *resultp.
 *
 * The mbuf chain "m" is always consumed.  On success *resultp holds the
 * chain to send.  That is either the chain rewritten in place as a
 * PROT_COMPD frame (protocol, sequence number, compressed data) or, when
 * compression does not shrink the frame or its output does not fit in
 * outbuf, the original uncompressed frame.  Every frame enters the
 * compressor's history either way.  On failure the chain has been freed,
 * *resultp is NULL and an errno value is returned.
 */
static int
ng_deflate_compress(node_p node, struct mbuf *m, struct mbuf **resultp)
{
        const priv_p    priv = NG_NODE_PRIVATE(node);
        int             outlen, inlen;
        int             rtn;
        int             sendplain;

        /* Initialize. */
        *resultp = NULL;

        inlen = m->m_pkthdr.len;

        priv->stats.FramesPlain++;
        priv->stats.InOctets+=inlen;

        if (inlen > DEFLATE_BUF_SIZE) {
                priv->stats.Errors++;
                NG_FREE_M(m);
                return (ENOMEM);
        }

        /*
         * An empty frame has no protocol field.  Without this check, the
         * protocol-compression test below would read a stale inbuf[0] and
         * could set avail_in to (uInt)-1, sending deflate() past inbuf.
         */
        if (inlen == 0) {
                priv->stats.Errors++;
                NG_FREE_M(m);
                return (EINVAL);
        }

        /* We must own the mbuf chain exclusively to modify it. */
        m = m_unshare(m, M_NOWAIT);
        if (m == NULL) {
                priv->stats.Errors++;
                return (ENOMEM);
        }

        /* Work with contiguous regions of memory. */
        m_copydata(m, 0, inlen, (caddr_t)priv->inbuf);
        outlen = DEFLATE_BUF_SIZE;

        /*
         * Compress "inbuf" into "outbuf", leaving room at the front for
         * the 2-byte PROT_COMPD protocol field and the DEFLATE header.
         * A leading zero protocol byte is not compressed (protocol
         * field compression).
         */
        if (priv->inbuf[0] != 0) {
                priv->cx.next_in = priv->inbuf;
                priv->cx.avail_in = inlen;
        } else {
                priv->cx.next_in = priv->inbuf + 1; /* compress protocol */
                priv->cx.avail_in = inlen - 1;
        }
        priv->cx.next_out = priv->outbuf + 2 + DEFLATE_HDRLEN;
        priv->cx.avail_out = outlen - 2 - DEFLATE_HDRLEN;

        /* Compress. */
        rtn = deflate(&priv->cx, Z_SYNC_FLUSH);

        /* Check return value. */
        if (rtn != Z_OK) {
                priv->stats.Errors++;
                log(LOG_NOTICE, "ng_deflate: compression error: %d (%s)\n",
                    rtn, priv->cx.msg);
                NG_FREE_M(m);
                return (EINVAL);
        }

        if (priv->cx.avail_out == 0) {
                /*
                 * outbuf filled up, so the sync flush may be incomplete and
                 * zlib may still hold pending output.  Sending what we have
                 * would produce a truncated frame and leave the pending
                 * bytes to corrupt the next one.  Finish the flush into
                 * scratch space instead, so the compressor ends on a
                 * byte-aligned block boundary with this frame in its
                 * history, and send the frame uncompressed.  The receiver
                 * adds uncompressed frames to its history too, so both
                 * histories stay identical.
                 */
                do {
                        priv->cx.next_out = priv->outbuf;
                        priv->cx.avail_out = DEFLATE_BUF_SIZE;
                        rtn = deflate(&priv->cx, Z_SYNC_FLUSH);
                        if (rtn != Z_OK) {
                                priv->stats.Errors++;
                                log(LOG_NOTICE,
                                    "ng_deflate: compression error: %d (%s)\n",
                                    rtn, priv->cx.msg);
                                NG_FREE_M(m);
                                return (EINVAL);
                        }
                } while (priv->cx.avail_out == 0);
                sendplain = 1;
        } else {
                /* Calculate resulting size. */
                outlen -= priv->cx.avail_out;
                /*
                 * Z_SYNC_FLUSH completes the current deflate block and
                 * follows it with an empty stored block that is three bits
                 * plus filler bits to the next byte, followed by four bytes
                 * (00 00 ff ff).  RFC 1979 Section 2.1, "Data" requires the
                 * four bytes be removed before transmission.
                 */
                outlen -= 4;
                MPASS(outlen > 0);
                MPASS(priv->outbuf[outlen + 0] == 0x00);
                MPASS(priv->outbuf[outlen + 1] == 0x00);
                MPASS(priv->outbuf[outlen + 2] == 0xff);
                MPASS(priv->outbuf[outlen + 3] == 0xff);

                /* If we can't compress this packet, send it as-is. */
                sendplain = (outlen > inlen);
        }

        if (sendplain) {
                /* Return original packet uncompressed. */
                *resultp = m;
                priv->stats.FramesUncomp++;
                priv->stats.OutOctets+=inlen;
        } else {
                /* Install header. */
                be16enc(priv->outbuf, PROT_COMPD);
                be16enc(priv->outbuf + 2, priv->seqnum);

                /* Return packet in an mbuf. */
                m_copyback(m, 0, outlen, (caddr_t)priv->outbuf);
                if (m->m_pkthdr.len < outlen) {
                        m_freem(m);
                        priv->stats.Errors++;
                        return (ENOMEM);
                } else if (outlen < m->m_pkthdr.len)
                        m_adj(m, outlen - m->m_pkthdr.len);
                *resultp = m;
                priv->stats.FramesComp++;
                priv->stats.OutOctets+=outlen;
        }

        /* Update sequence number. */
        priv->seqnum++;

        return (0);
}

/*
 * Decompress a received frame and put the result in *resultp.
 *
 * The mbuf chain "m" is always consumed.  For a PROT_COMPD frame, the
 * sequence number is checked, the payload is decompressed and the result
 * is written back into the (unshared) chain.  Any other frame is passed up
 * unchanged after its contents have been added to the decompressor's
 * history.  This mirrors ng_deflate_compress(), which feeds every frame
 * through deflate() whether or not it is sent compressed.  On failure the
 * chain has been freed, *resultp is NULL, the expected sequence number is
 * reset to zero and an errno value is returned.
 */
static int
ng_deflate_decompress(node_p node, struct mbuf *m, struct mbuf **resultp)
{
        const priv_p    priv = NG_NODE_PRIVATE(node);
        int             outlen, inlen, datalen;
        int             rtn;
        uint16_t        proto;
        int             offset;
        int             skip;
        uint16_t        rseqnum;
        u_char          headbuf[5];
        static u_char   EMPTY_BLOCK[4] = { 0x00, 0x00, 0xff, 0xff };

        /* Initialize. */
        *resultp = NULL;

        inlen = m->m_pkthdr.len;

        if (inlen > DEFLATE_BUF_SIZE) {
                priv->stats.Errors++;
                NG_FREE_M(m);
                priv->seqnum = 0;
                return (ENOMEM);
        }

        /* We must own the mbuf chain exclusively to modify it. */
        m = m_unshare(m, M_NOWAIT);
        if (m == NULL) {
                priv->stats.Errors++;
                return (ENOMEM);
        }

        /* Work with contiguous regions of memory. */
        m_copydata(m, 0, inlen, (caddr_t)priv->inbuf);

        /*
         * Separate proto: 1 byte when its low bit is set (protocol field
         * compression), else 2 bytes.  For a very short frame these reads
         * see stale inbuf bytes; the length check below rejects such
         * frames before the values are used.
         */
        if ((priv->inbuf[0] & 0x01) != 0) {
                proto = priv->inbuf[0];
                offset = 1;
        } else {
                proto = be16dec(priv->inbuf);
                offset = 2;
        }

        priv->stats.InOctets += inlen;

        /*
         * The frame must hold the whole protocol field and, if compressed,
         * the DEFLATE header (sequence number).  Otherwise "inlen - offset"
         * below would be negative.  It would then wrap to a huge avail_in
         * and let inflate() read far beyond inbuf.
         */
        if (inlen < offset + (proto == PROT_COMPD ? DEFLATE_HDRLEN : 0)) {
                priv->stats.Errors++;
                NG_FREE_M(m);
                priv->seqnum = 0;
                return (EINVAL);
        }

        /* Packet is compressed, so decompress. */
        if (proto == PROT_COMPD) {
                priv->stats.FramesComp++;

                /* Check sequence number. */
                rseqnum = be16dec(priv->inbuf + offset);
                offset += 2;
                if (rseqnum != priv->seqnum) {
                        priv->stats.Errors++;
                        log(LOG_NOTICE, "ng_deflate: wrong sequence: %u "
                            "instead of %u\n", rseqnum, priv->seqnum);
                        NG_FREE_M(m);
                        priv->seqnum = 0;
                        return (EPIPE);
                }

                outlen = DEFLATE_BUF_SIZE;

                /* Decompress "inbuf" into "outbuf". */
                /* Prepare to decompress. */
                priv->cx.next_in = priv->inbuf + offset;
                priv->cx.avail_in = inlen - offset;
                /* Reserve space for protocol decompression. */
                priv->cx.next_out = priv->outbuf + 1;
                priv->cx.avail_out = outlen - 1;

                /* Decompress. */
                rtn = inflate(&priv->cx, Z_SYNC_FLUSH);

                /* Check return value. */
                if (rtn != Z_OK && rtn != Z_STREAM_END) {
                        priv->stats.Errors++;
                        NG_FREE_M(m);
                        priv->seqnum = 0;
                        log(LOG_NOTICE, "%s: decompression error: %d (%s)\n",
                            __func__, rtn, priv->cx.msg);

                        switch (rtn) {
                        case Z_MEM_ERROR:
                                return (ENOMEM);
                        case Z_DATA_ERROR:
                                return (EIO);
                        default:
                                return (EINVAL);
                        }
                }

                /*
                 * Unconsumed input means the output did not fit in outbuf.
                 * Passing up a truncated frame and dropping the rest of the
                 * input would corrupt both the packet and the history, so
                 * treat this as a data error.
                 */
                if (priv->cx.avail_in != 0) {
                        priv->stats.Errors++;
                        NG_FREE_M(m);
                        priv->seqnum = 0;
                        log(LOG_NOTICE, "%s: decompressed frame exceeds "
                            "%d bytes\n", __func__, DEFLATE_BUF_SIZE);
                        return (EIO);
                }

                /* Handle the EMPTY_BLOCK omitted by sender */
                if (inflateSyncPoint(&priv->cx)) {
                        priv->cx.avail_in = 4;
                        priv->cx.next_in = EMPTY_BLOCK;
                        inflate(&priv->cx, Z_SYNC_FLUSH);
                }

                /* Calculate resulting size (includes the reserved byte). */
                outlen -= priv->cx.avail_out;

                /*
                 * A valid frame decompresses to at least a protocol byte.
                 * With no output, outbuf[1] below would be stale.
                 */
                if (outlen < 2) {
                        priv->stats.Errors++;
                        NG_FREE_M(m);
                        priv->seqnum = 0;
                        log(LOG_NOTICE, "%s: empty decompressed frame\n",
                            __func__);
                        return (EIO);
                }

                /* Decompress protocol. */
                if ((priv->outbuf[1] & 0x01) != 0) {
                        priv->outbuf[0] = 0;
                        /* Return packet in an mbuf. */
                        m_copyback(m, 0, outlen, (caddr_t)priv->outbuf);
                } else {
                        outlen--;
                        /* Return packet in an mbuf. */
                        m_copyback(m, 0, outlen, (caddr_t)(priv->outbuf + 1));
                }
                if (m->m_pkthdr.len < outlen) {
                        m_freem(m);
                        priv->stats.Errors++;
                        priv->seqnum = 0;
                        return (ENOMEM);
                } else if (outlen < m->m_pkthdr.len)
                        m_adj(m, outlen - m->m_pkthdr.len);
                *resultp = m;
                priv->stats.FramesPlain++;
                priv->stats.OutOctets+=outlen;

        } else {
                /* Packet is not compressed, just update dictionary. */
                priv->stats.FramesUncomp++;

                /*
                 * Feed the frame to the history exactly as the compressor
                 * consumes it: a leading zero protocol byte is skipped.
                 * The faked stored-block header must announce exactly the
                 * number of bytes fed afterwards.  Previously the length
                 * was derived from "offset", which disagreed with the feed
                 * for 2-byte protocols with a nonzero high byte.
                 */
                skip = (priv->inbuf[0] == 0) ? 1 : 0;
                datalen = inlen - skip;

                /*
                 * Fake a header for uncompressed data block
                 */
                headbuf[0] = 0x80;
                headbuf[1] = datalen & 0xff;
                headbuf[2] = datalen >> 8;
                headbuf[3] = (~datalen) & 0xff;
                headbuf[4] = (~datalen) >> 8;

                priv->cx.next_in = headbuf;
                priv->cx.avail_in = sizeof(headbuf);
                priv->cx.next_out = priv->outbuf;
                priv->cx.avail_out = DEFLATE_BUF_SIZE;

                rtn = inflate(&priv->cx, Z_NO_FLUSH);

                priv->cx.next_in = priv->inbuf + skip;
                priv->cx.avail_in = datalen;
                priv->cx.next_out = priv->outbuf;
                priv->cx.avail_out = DEFLATE_BUF_SIZE;

                rtn = inflate(&priv->cx, Z_SYNC_FLUSH);

                /* Check return value */
                if (rtn != Z_OK) {
                        priv->stats.Errors++;
                        log(LOG_NOTICE, "%s: inflate error: %d (%s)\n",
                            __func__, rtn, priv->cx.msg);
                        NG_FREE_M(m);
                        priv->seqnum = 0;
                        return (EINVAL);
                }

                *resultp = m;
                priv->stats.FramesPlain++;
                priv->stats.OutOctets += inlen;
        }

        /* Update sequence number. */
        priv->seqnum++;

        return (0);
}

/*
 * Handle NGM_DEFLATE_RESETREQ: restart the sequence number at zero and,
 * if a zlib stream is live, reset it.  That is the deflate stream on a
 * compressor node and the inflate stream on a decompressor node.
 */
static void
ng_deflate_reset_req(node_p node)
{
        const priv_p priv = NG_NODE_PRIVATE(node);

        priv->seqnum = 0;

        /* Without an enabled configuration there is no stream to reset. */
        if (!priv->cfg.enable)
                return;

        if (priv->compress)
                deflateReset(&priv->cx);
        else
                inflateReset(&priv->cx);
}