root/sys/dev/tcp_log/tcp_log_dev.c
/*-
 * SPDX-License-Identifier: BSD-2-Clause
 *
 * Copyright (c) 2016-2017 Netflix, Inc.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 */

#include <sys/param.h>
#include <sys/conf.h>
#include <sys/fcntl.h>
#include <sys/filio.h>
#include <sys/kernel.h>
#include <sys/lock.h>
#include <sys/malloc.h>
#include <sys/module.h>
#include <sys/poll.h>
#include <sys/queue.h>
#include <sys/refcount.h>
#include <sys/mutex.h>
#include <sys/selinfo.h>
#include <sys/socket.h>
#include <sys/socketvar.h>
#include <sys/sysctl.h>
#include <sys/tree.h>
#include <sys/uio.h>
#include <machine/atomic.h>
#include <sys/counter.h>

#include <dev/tcp_log/tcp_log_dev.h>

#ifdef TCPLOG_DEBUG_COUNTERS
extern counter_u64_t tcp_log_que_read;
extern counter_u64_t tcp_log_que_freed;
#endif

static struct cdev *tcp_log_dev;
static struct selinfo tcp_log_sel;

static struct log_queueh tcp_log_dev_queue_head = STAILQ_HEAD_INITIALIZER(tcp_log_dev_queue_head);
static struct log_infoh tcp_log_dev_reader_head = STAILQ_HEAD_INITIALIZER(tcp_log_dev_reader_head);

MALLOC_DEFINE(M_TCPLOGDEV, "tcp_log_dev", "TCP log device data structures");

static int      tcp_log_dev_listeners = 0;

static struct mtx tcp_log_dev_queue_lock;

#define TCP_LOG_DEV_QUEUE_LOCK()        mtx_lock(&tcp_log_dev_queue_lock)
#define TCP_LOG_DEV_QUEUE_UNLOCK()      mtx_unlock(&tcp_log_dev_queue_lock)
#define TCP_LOG_DEV_QUEUE_LOCK_ASSERT() mtx_assert(&tcp_log_dev_queue_lock, MA_OWNED)
#define TCP_LOG_DEV_QUEUE_UNLOCK_ASSERT() mtx_assert(&tcp_log_dev_queue_lock, MA_NOTOWNED)
#define TCP_LOG_DEV_QUEUE_REF(tldq)     refcount_acquire(&((tldq)->tldq_refcnt))
#define TCP_LOG_DEV_QUEUE_UNREF(tldq)   refcount_release(&((tldq)->tldq_refcnt))
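
/*
 * Reference counting model: each entry on tcp_log_dev_queue_head is
 * reference-counted by the readers that have not yet consumed it.  A reader
 * (struct tcp_log_dev_info) tracks tldi_head, the first queue entry it has
 * not finished, and tldi_cur/tldi_off, its position within the buffer for
 * that entry.  When a reader finishes a buffer (or closes the device), it
 * drops its references; an entry is removed from the queue and destroyed
 * once the last interested reader has released it.
 */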

static void     tcp_log_dev_clear_refcount(struct tcp_log_dev_queue *entry);
static void     tcp_log_dev_clear_cdevpriv(void *data);
static int      tcp_log_dev_open(struct cdev *dev __unused, int flags,
    int devtype __unused, struct thread *td __unused);
static int      tcp_log_dev_write(struct cdev *dev __unused,
    struct uio *uio __unused, int flags __unused);
static int      tcp_log_dev_read(struct cdev *dev __unused, struct uio *uio,
    int flags __unused);
static int      tcp_log_dev_ioctl(struct cdev *dev __unused, u_long cmd,
    caddr_t data, int fflag __unused, struct thread *td __unused);
static int      tcp_log_dev_poll(struct cdev *dev __unused, int events,
    struct thread *td);

enum tcp_log_dev_queue_lock_state {
        QUEUE_UNLOCKED = 0,
        QUEUE_LOCKED,
};

static struct cdevsw tcp_log_cdevsw = {
        .d_version =    D_VERSION,
        .d_read =       tcp_log_dev_read,
        .d_open =       tcp_log_dev_open,
        .d_write =      tcp_log_dev_write,
        .d_poll =       tcp_log_dev_poll,
        .d_ioctl =      tcp_log_dev_ioctl,
#ifdef NOTYET
        .d_mmap =       tcp_log_dev_mmap,
#endif
        .d_name =       "tcp_log",
};
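
/*
 * Userland consumers open /dev/tcp_log (created read-only for root in
 * tcp_log_dev_modevent() below) and read back a stream of log buffers.
 * Each buffer begins with a struct tcp_log_common_header whose tlch_length
 * field gives the total size of that buffer, so a consumer can reassemble
 * buffers even when a single read(2) returns only part of one.  A minimal
 * sketch of the consumer loop (illustrative only; buffer management and
 * error handling omitted):
 *
 *	fd = open("/dev/tcp_log", O_RDONLY);
 *	while ((len = read(fd, p, avail)) > 0) {
 *		accumulate the bytes; once a struct tcp_log_common_header
 *		is buffered, tlch_length gives the full buffer size, and
 *		each complete buffer can be handed to a parser;
 *	}
 */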

static __inline void
tcp_log_dev_queue_validate_lock(int lockstate)
{

#ifdef INVARIANTS
        switch (lockstate) {
        case QUEUE_LOCKED:
                TCP_LOG_DEV_QUEUE_LOCK_ASSERT();
                break;
        case QUEUE_UNLOCKED:
                TCP_LOG_DEV_QUEUE_UNLOCK_ASSERT();
                break;
        default:
                kassert_panic("%s:%d: unknown queue lock state", __func__,
                    __LINE__);
        }
#endif
}

/*
 * Clear the refcount. If appropriate, it will remove the entry from the
 * queue and call the destructor.
 *
 * This must be called with the queue lock held.
 */
static void
tcp_log_dev_clear_refcount(struct tcp_log_dev_queue *entry)
{

        KASSERT(entry != NULL, ("%s: called with NULL entry", __func__));

        TCP_LOG_DEV_QUEUE_LOCK_ASSERT();

        if (TCP_LOG_DEV_QUEUE_UNREF(entry)) {
#ifdef TCPLOG_DEBUG_COUNTERS
                counter_u64_add(tcp_log_que_freed, 1);
#endif
                /* Remove the entry from the queue and call the destructor. */
                STAILQ_REMOVE(&tcp_log_dev_queue_head, entry, tcp_log_dev_queue,
                    tldq_queue);
                (*entry->tldq_dtor)(entry);
        }
}

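/*
 * Devfs cdevpriv destructor, registered in tcp_log_dev_open().  It runs when
 * a reader's last file descriptor reference goes away: drop the reader's
 * remaining queue references, remove it from the reader list, and free its
 * private data.
 */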
static void
tcp_log_dev_clear_cdevpriv(void *data)
{
        struct tcp_log_dev_info *priv;
        struct tcp_log_dev_queue *entry, *entry_tmp;

        priv = (struct tcp_log_dev_info *)data;
        if (priv == NULL)
                return;

        /*
         * Lock the queue and drop our references. We hold references to all
         * the entries from tldi_head to the end of the queue; if tldi_head
         * is NULL, we hold no references and have nothing to drop.
         *
         * We hold the queue lock for the entire walk so that no one can add
         * entries to the queue while we are releasing our references.
         */
        TCP_LOG_DEV_QUEUE_LOCK();
        if (priv->tldi_head != NULL) {
                entry = priv->tldi_head;
                STAILQ_FOREACH_FROM_SAFE(entry, &tcp_log_dev_queue_head,
                    tldq_queue, entry_tmp) {
                        tcp_log_dev_clear_refcount(entry);
                }
        }
        tcp_log_dev_listeners--;
        KASSERT(tcp_log_dev_listeners >= 0,
            ("%s: tcp_log_dev_listeners is unexpectedly negative", __func__));
        STAILQ_REMOVE(&tcp_log_dev_reader_head, priv, tcp_log_dev_info,
            tldi_list);
        TCP_LOG_DEV_QUEUE_LOCK_ASSERT();
        TCP_LOG_DEV_QUEUE_UNLOCK();
        free(priv, M_TCPLOGDEV);
}

static int
tcp_log_dev_open(struct cdev *dev __unused, int flags, int devtype __unused,
    struct thread *td __unused)
{
        struct tcp_log_dev_info *priv;
        struct tcp_log_dev_queue *entry;
        int rv;

        /*
         * Ideally, we shouldn't see these because of file system
         * permissions.
         */
        if (flags & (FWRITE | FEXEC | FAPPEND | O_TRUNC))
                return (ENODEV);

        /* Allocate space to hold information about where we are. */
        priv = malloc(sizeof(struct tcp_log_dev_info), M_TCPLOGDEV,
            M_ZERO | M_WAITOK);

        /* Stash the private data away. */
        rv = devfs_set_cdevpriv((void *)priv, tcp_log_dev_clear_cdevpriv);
        if (!rv) {
                /*
                 * Increase the listener count, add this reader to the list, and
                 * take references on all current queues.
                 */
                TCP_LOG_DEV_QUEUE_LOCK();
                tcp_log_dev_listeners++;
                STAILQ_INSERT_HEAD(&tcp_log_dev_reader_head, priv, tldi_list);
                priv->tldi_head = STAILQ_FIRST(&tcp_log_dev_queue_head);
                if (priv->tldi_head != NULL)
                        priv->tldi_cur = priv->tldi_head->tldq_buf;
                STAILQ_FOREACH(entry, &tcp_log_dev_queue_head, tldq_queue)
                        TCP_LOG_DEV_QUEUE_REF(entry);
                TCP_LOG_DEV_QUEUE_UNLOCK();
        } else {
                /* Free the entry. */
                free(priv, M_TCPLOGDEV);
        }
        return (rv);
}

static int
tcp_log_dev_write(struct cdev *dev __unused, struct uio *uio __unused,
    int flags __unused)
{

        return (ENODEV);
}

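/*
 * Advance a reader past the buffer it has finished: drop its reference on the
 * current head entry, point tldi_head at the next entry in the queue, and
 * clear tldi_cur so the read loop will locate (or wait for) the next buffer.
 * Acquires the queue lock if the caller does not already hold it and reports
 * the resulting state back through *lockstate.
 */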
static __inline void
tcp_log_dev_rotate_bufs(struct tcp_log_dev_info *priv, int *lockstate)
{
        struct tcp_log_dev_queue *entry;

        KASSERT(priv->tldi_head != NULL,
            ("%s:%d: priv->tldi_head unexpectedly NULL",
            __func__, __LINE__));
        KASSERT(priv->tldi_head->tldq_buf == priv->tldi_cur,
            ("%s:%d: buffer mismatch (%p vs %p)",
            __func__, __LINE__, priv->tldi_head->tldq_buf,
            priv->tldi_cur));
        tcp_log_dev_queue_validate_lock(*lockstate);

        if (*lockstate == QUEUE_UNLOCKED) {
                TCP_LOG_DEV_QUEUE_LOCK();
                *lockstate = QUEUE_LOCKED;
        }
        entry = priv->tldi_head;
        priv->tldi_head = STAILQ_NEXT(entry, tldq_queue);
        tcp_log_dev_clear_refcount(entry);
        priv->tldi_cur = NULL;
}

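/*
 * Copy log data out to a reader.  Position the reader on its next unread
 * buffer, sleeping when it has nothing left to read and the descriptor allows
 * blocking (or returning EAGAIN when it does not), and materialize the buffer
 * via tldq_xform if it has not been built yet.  At most the remainder of the
 * current buffer is copied per call; once a buffer has been fully consumed,
 * the reader is rotated to the next queue entry.
 */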
static int
tcp_log_dev_read(struct cdev *dev __unused, struct uio *uio, int flags)
{
        struct tcp_log_common_header *buf;
        struct tcp_log_dev_info *priv;
        struct tcp_log_dev_queue *entry;
        ssize_t len;
        int lockstate, rv;

        /* Get our private info. */
        rv = devfs_get_cdevpriv((void **)&priv);
        if (rv)
                return (rv);

        lockstate = QUEUE_UNLOCKED;

        /* Do we need to get a new buffer? */
        while (priv->tldi_cur == NULL ||
            priv->tldi_cur->tlch_length <= priv->tldi_off) {
                /* Did we somehow forget to rotate? */
                KASSERT(priv->tldi_cur == NULL,
                    ("%s:%d: tldi_cur is unexpectedly non-NULL", __func__,
                    __LINE__));
                if (priv->tldi_cur != NULL)
                        tcp_log_dev_rotate_bufs(priv, &lockstate);

                /*
                 * Before we start looking at tldi_head, we need a lock on the
                 * queue to make sure tldi_head stays stable.
                 */
                if (lockstate == QUEUE_UNLOCKED) {
                        TCP_LOG_DEV_QUEUE_LOCK();
                        lockstate = QUEUE_LOCKED;
                }

                /* We need the next buffer. Do we have one? */
                if (priv->tldi_head == NULL && (flags & FNONBLOCK)) {
                        rv = EAGAIN;
                        goto done;
                }
                if (priv->tldi_head == NULL) {
                        /* Sleep and wait for more things we can read. */
                        rv = mtx_sleep(&tcp_log_dev_listeners,
                            &tcp_log_dev_queue_lock, PCATCH, "tcplogdev", 0);
                        if (rv)
                                goto done;
                        if (priv->tldi_head == NULL)
                                continue;
                }

                /*
                 * We have an entry to read. We want to try to create a
                 * buffer, if one doesn't already exist.
                 */
                entry = priv->tldi_head;
                if (entry->tldq_buf == NULL) {
                        TCP_LOG_DEV_QUEUE_LOCK_ASSERT();
                        buf = (*entry->tldq_xform)(entry);
                        if (buf == NULL) {
                                rv = EBUSY;
                                goto done;
                        }
                        entry->tldq_buf = buf;
                }

                priv->tldi_cur = entry->tldq_buf;
                priv->tldi_off = 0;
        }

        /* Copy what we can from this buffer to the output buffer. */
        if (uio->uio_resid > 0) {
                /* Drop locks so we can take page faults. */
                if (lockstate == QUEUE_LOCKED)
                        TCP_LOG_DEV_QUEUE_UNLOCK();
                lockstate = QUEUE_UNLOCKED;

                KASSERT(priv->tldi_cur != NULL,
                    ("%s: priv->tldi_cur is unexpectedly NULL", __func__));

                /* Copy as much as we can to this uio. */
                len = priv->tldi_cur->tlch_length - priv->tldi_off;
                if (len > uio->uio_resid)
                        len = uio->uio_resid;
                rv = uiomove(((uint8_t *)priv->tldi_cur) + priv->tldi_off,
                    len, uio);
                if (rv != 0)
                        goto done;
                priv->tldi_off += len;
#ifdef TCPLOG_DEBUG_COUNTERS
                counter_u64_add(tcp_log_que_read, len);
#endif
        }
        /* Are we done with this buffer? If so, find the next one. */
        if (priv->tldi_off >= priv->tldi_cur->tlch_length) {
                KASSERT(priv->tldi_off == priv->tldi_cur->tlch_length,
                    ("%s: offset (%ju) exceeds length (%ju)", __func__,
                    (uintmax_t)priv->tldi_off,
                    (uintmax_t)priv->tldi_cur->tlch_length));
                tcp_log_dev_rotate_bufs(priv, &lockstate);
        }
done:
        tcp_log_dev_queue_validate_lock(lockstate);
        if (lockstate == QUEUE_LOCKED)
                TCP_LOG_DEV_QUEUE_UNLOCK();
        return (rv);
}

static int
tcp_log_dev_ioctl(struct cdev *dev __unused, u_long cmd, caddr_t data,
    int fflag __unused, struct thread *td __unused)
{
        struct tcp_log_dev_info *priv;
        int rv;

        /* Get our private info. */
        rv = devfs_get_cdevpriv((void **)&priv);
        if (rv)
                return (rv);

        /*
         * Set things. Here, we are most concerned about the non-blocking I/O
         * flag.
         */
        rv = 0;
        switch (cmd) {
        case FIONBIO:
                break;
        case FIOASYNC:
                if (*(int *)data != 0)
                        rv = EINVAL;
                break;
        default:
                rv = ENOIOCTL;
        }
        return (rv);
}

static int
tcp_log_dev_poll(struct cdev *dev __unused, int events, struct thread *td)
{
        struct tcp_log_dev_info *priv;
        int revents;

        /*
         * Get our private info. If this fails, claim that all events are
         * ready. That should prod the user to do something that will
         * make the error evident to them.
         */
        if (devfs_get_cdevpriv((void **)&priv))
                return (events);

        revents = 0;
        if (events & (POLLIN | POLLRDNORM)) {
                /*
                 * We can (probably) read right now if we are partway through
                 * a buffer or if we are just about to start a buffer.
                 * Because we are going to read tldi_head, we need to hold
                 * the queue lock while we check.
                 */
                TCP_LOG_DEV_QUEUE_LOCK();
                if ((priv->tldi_head != NULL && priv->tldi_cur == NULL) ||
                    (priv->tldi_cur != NULL &&
                    priv->tldi_off < priv->tldi_cur->tlch_length))
                        revents = events & (POLLIN | POLLRDNORM);
                else
                        selrecord(td, &tcp_log_sel);
                TCP_LOG_DEV_QUEUE_UNLOCK();
        } else {
                /*
                 * It only makes sense to poll for reading. So, again, prod the
                 * user to do something that will make the error of their ways
                 * apparent.
                 */
                revents = events;
        }
        return (revents);
}

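/*
 * Hand a completed log entry to the device for readers to consume.  The
 * caller must set tldq_dtor and at least one of tldq_buf (a ready-made
 * buffer) or tldq_xform (a callback that builds the buffer on first read).
 * On success the queue takes ownership: a reference is taken for every
 * current listener, and tldq_dtor runs once the last of them has released
 * the entry.  If no one is listening, ENXIO is returned and the caller keeps
 * ownership and must free the entry itself.  A minimal producer sketch
 * (illustrative only; M_CALLER, caller_xform, and caller_dtor stand in for
 * the caller's own malloc type and callbacks):
 *
 *	entry = malloc(sizeof(*entry), M_CALLER, M_WAITOK | M_ZERO);
 *	entry->tldq_xform = caller_xform;
 *	entry->tldq_dtor = caller_dtor;
 *	if (tcp_log_dev_add_log(entry) != 0)
 *		caller_dtor(entry);
 */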
int
tcp_log_dev_add_log(struct tcp_log_dev_queue *entry)
{
        struct tcp_log_dev_info *priv;
        int rv;
        bool wakeup_needed;

        KASSERT(entry->tldq_buf != NULL || entry->tldq_xform != NULL,
            ("%s: Called with both tldq_buf and tldq_xform set to NULL",
            __func__));
        KASSERT(entry->tldq_dtor != NULL,
            ("%s: Called with tldq_dtor set to NULL", __func__));

        /* Get a lock on the queue. */
        TCP_LOG_DEV_QUEUE_LOCK();

        /* If no one is listening, tell the caller to free the resources. */
        if (tcp_log_dev_listeners == 0) {
                rv = ENXIO;
                goto done;
        }

        /* Add this to the end of the queue. */
        STAILQ_INSERT_TAIL(&tcp_log_dev_queue_head, entry, tldq_queue);

        /* Add references for all current listeners. */
        refcount_init(&entry->tldq_refcnt, tcp_log_dev_listeners);

        /*
         * If any listener is currently stuck on NULL, that means they are
         * waiting. Point their head to this new entry.
         */
        wakeup_needed = false;
        STAILQ_FOREACH(priv, &tcp_log_dev_reader_head, tldi_list)
                if (priv->tldi_head == NULL) {
                        priv->tldi_head = entry;
                        wakeup_needed = true;
                }

        if (wakeup_needed) {
                selwakeup(&tcp_log_sel);
                wakeup(&tcp_log_dev_listeners);
        }

        rv = 0;

done:
        TCP_LOG_DEV_QUEUE_LOCK_ASSERT();
        TCP_LOG_DEV_QUEUE_UNLOCK();
        return (rv);
}

static int
tcp_log_dev_modevent(module_t mod __unused, int type, void *data __unused)
{

        /* TODO: Support intelligent unloading. */
        switch (type) {
        case MOD_LOAD:
                if (bootverbose)
                        printf("tcp_log: tcp_log device\n");
                memset(&tcp_log_sel, 0, sizeof(tcp_log_sel));
                memset(&tcp_log_dev_queue_lock, 0, sizeof(struct mtx));
                mtx_init(&tcp_log_dev_queue_lock, "tcp_log dev",
                         "tcp_log device queues", MTX_DEF);
                tcp_log_dev = make_dev_credf(MAKEDEV_ETERNAL_KLD,
                    &tcp_log_cdevsw, 0, NULL, UID_ROOT, GID_WHEEL, 0400,
                    "tcp_log");
                break;
        default:
                return (EOPNOTSUPP);
        }

        return (0);
}

DEV_MODULE(tcp_log_dev, tcp_log_dev_modevent, NULL);
MODULE_VERSION(tcp_log_dev, 1);