root/sys/dev/virtio/scmi/virtio_scmi.c
/*-
 * SPDX-License-Identifier: BSD-2-Clause
 *
 * Copyright (c) 2023 Arm Ltd
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice unmodified, this list of conditions, and the following
 *    disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

/* Driver for VirtIO SCMI device. */

#include <sys/cdefs.h>
#include <sys/param.h>
#include <sys/types.h>
#include <sys/eventhandler.h>
#include <sys/kernel.h>
#include <sys/malloc.h>
#include <sys/module.h>
#include <sys/queue.h>
#include <sys/sglist.h>

#include <machine/bus.h>
#include <machine/resource.h>
#include <sys/bus.h>

#include <dev/virtio/virtio.h>
#include <dev/virtio/virtqueue.h>
#include <dev/virtio/scmi/virtio_scmi.h>

struct vtscmi_pdu {
        enum vtscmi_chan        chan;
        struct sglist           sg;
        struct sglist_seg       segs[2];
        void                    *buf;
        SLIST_ENTRY(vtscmi_pdu) next;
};

struct vtscmi_queue {
        device_t                                dev;
        int                                     vq_id;
        unsigned int                            vq_sz;
        struct virtqueue                        *vq;
        struct mtx                              vq_mtx;
        struct vtscmi_pdu                       *pdus;
        SLIST_HEAD(pdus_head, vtscmi_pdu)       p_head;
        struct mtx                              p_mtx;
        virtio_scmi_rx_callback_t               *rx_callback;
        void                                    *priv;
};

struct vtscmi_softc {
        device_t        vtscmi_dev;
        uint64_t        vtscmi_features;
        uint8_t         vtscmi_vqs_cnt;
        struct vtscmi_queue     vtscmi_queues[VIRTIO_SCMI_CHAN_MAX];
        bool            has_p2a;
        bool            has_shared;
};

static device_t vtscmi_dev;

static int vtscmi_modevent(module_t, int, void *);

static int      vtscmi_probe(device_t);
static int      vtscmi_attach(device_t);
static int      vtscmi_detach(device_t);
static int      vtscmi_shutdown(device_t);
static int      vtscmi_negotiate_features(struct vtscmi_softc *);
static int      vtscmi_setup_features(struct vtscmi_softc *);
static void     vtscmi_vq_intr(void *);
static int      vtscmi_alloc_virtqueues(struct vtscmi_softc *);
static int      vtscmi_alloc_queues(struct vtscmi_softc *);
static void     vtscmi_free_queues(struct vtscmi_softc *);
static void     *virtio_scmi_pdu_get(struct vtscmi_queue *, void *,
    unsigned int, unsigned int);
static void     virtio_scmi_pdu_put(device_t, struct vtscmi_pdu *);

static struct virtio_feature_desc vtscmi_feature_desc[] = {
        { VIRTIO_SCMI_F_P2A_CHANNELS, "P2AChannel" },
        { VIRTIO_SCMI_F_SHARED_MEMORY, "SharedMem" },
        { 0, NULL }
};

static device_method_t vtscmi_methods[] = {
        /* Device methods. */
        DEVMETHOD(device_probe,         vtscmi_probe),
        DEVMETHOD(device_attach,        vtscmi_attach),
        DEVMETHOD(device_detach,        vtscmi_detach),
        DEVMETHOD(device_shutdown,      vtscmi_shutdown),

        DEVMETHOD_END
};

static driver_t vtscmi_driver = {
        "vtscmi",
        vtscmi_methods,
        sizeof(struct vtscmi_softc)
};

VIRTIO_DRIVER_MODULE(virtio_scmi, vtscmi_driver, vtscmi_modevent, NULL);
MODULE_VERSION(virtio_scmi, 1);
MODULE_DEPEND(virtio_scmi, virtio, 1, 1, 1);

VIRTIO_SIMPLE_PNPINFO(virtio_scmi, VIRTIO_ID_SCMI, "VirtIO SCMI Adapter");

static int
vtscmi_modevent(module_t mod, int type, void *unused)
{
        int error;

        switch (type) {
        case MOD_LOAD:
        case MOD_QUIESCE:
        case MOD_UNLOAD:
        case MOD_SHUTDOWN:
                error = 0;
                break;
        default:
                error = EOPNOTSUPP;
                break;
        }

        return (error);
}

static int
vtscmi_probe(device_t dev)
{
        return (VIRTIO_SIMPLE_PROBE(dev, virtio_scmi));
}

static int
vtscmi_attach(device_t dev)
{
        struct vtscmi_softc *sc;
        int error;

        /* Only one SCMI device per-agent */
        if (vtscmi_dev != NULL)
                return (EEXIST);

        sc = device_get_softc(dev);
        sc->vtscmi_dev = dev;

        virtio_set_feature_desc(dev, vtscmi_feature_desc);
        error = vtscmi_setup_features(sc);
        if (error) {
                device_printf(dev, "cannot setup features\n");
                goto fail;
        }

        error = vtscmi_alloc_virtqueues(sc);
        if (error) {
                device_printf(dev, "cannot allocate virtqueues\n");
                goto fail;
        }

        error = vtscmi_alloc_queues(sc);
        if (error) {
                device_printf(dev, "cannot allocate queues\n");
                goto fail;
        }

        error = virtio_setup_intr(dev, INTR_TYPE_MISC);
        if (error) {
                device_printf(dev, "cannot setup intr\n");
                vtscmi_free_queues(sc);
                goto fail;
        }

        /* Save unique device */
        vtscmi_dev = sc->vtscmi_dev;

fail:

        return (error);
}

static int
vtscmi_detach(device_t dev)
{
        struct vtscmi_softc *sc;

        sc = device_get_softc(dev);

        /* These also disable related interrupts */
        virtio_scmi_channel_callback_set(dev, VIRTIO_SCMI_CHAN_A2P, NULL, NULL);
        virtio_scmi_channel_callback_set(dev, VIRTIO_SCMI_CHAN_P2A, NULL, NULL);

        virtio_stop(dev);

        vtscmi_free_queues(sc);

        return (0);
}

static int
vtscmi_shutdown(device_t dev)
{

        return (0);
}

static int
vtscmi_negotiate_features(struct vtscmi_softc *sc)
{
        device_t dev;
        uint64_t features;

        dev = sc->vtscmi_dev;
        /* We still don't support shared mem (stats)...so don't advertise it */
        features = VIRTIO_SCMI_F_P2A_CHANNELS;

        sc->vtscmi_features = virtio_negotiate_features(dev, features);
        return (virtio_finalize_features(dev));
}

static int
vtscmi_setup_features(struct vtscmi_softc *sc)
{
        device_t dev;
        int error;

        dev = sc->vtscmi_dev;
        error = vtscmi_negotiate_features(sc);
        if (error)
                return (error);

        if (virtio_with_feature(dev, VIRTIO_SCMI_F_P2A_CHANNELS))
                sc->has_p2a = true;
        if (virtio_with_feature(dev, VIRTIO_SCMI_F_SHARED_MEMORY))
                sc->has_shared = true;

        device_printf(dev, "Platform %s P2A channel.\n",
            sc->has_p2a ? "supports" : "does NOT support");

        return (0);
}

static int
vtscmi_alloc_queues(struct vtscmi_softc *sc)
{
        int idx;

        for (idx = VIRTIO_SCMI_CHAN_A2P; idx < VIRTIO_SCMI_CHAN_MAX; idx++) {
                int i, vq_sz;
                struct vtscmi_queue *q;
                struct vtscmi_pdu *pdu;

                if (idx == VIRTIO_SCMI_CHAN_P2A && !sc->has_p2a)
                        continue;

                q = &sc->vtscmi_queues[idx];
                q->dev = sc->vtscmi_dev;
                q->vq_id = idx;
                vq_sz = virtqueue_size(q->vq);
                q->vq_sz = idx != VIRTIO_SCMI_CHAN_A2P ? vq_sz : vq_sz / 2;

                q->pdus = mallocarray(q->vq_sz, sizeof(*pdu), M_DEVBUF,
                    M_ZERO | M_WAITOK);

                SLIST_INIT(&q->p_head);
                for (i = 0, pdu = q->pdus; i < q->vq_sz; i++, pdu++) {
                        pdu->chan = idx;
                        //XXX Maybe one seg redndant for P2A
                        sglist_init(&pdu->sg,
                            idx == VIRTIO_SCMI_CHAN_A2P ? 2 : 1, pdu->segs);
                        SLIST_INSERT_HEAD(&q->p_head, pdu, next);
                }

                mtx_init(&q->p_mtx, "vtscmi_pdus", "VTSCMI", MTX_SPIN);
                mtx_init(&q->vq_mtx, "vtscmi_vq", "VTSCMI", MTX_SPIN);
        }

        return (0);
}

static void
vtscmi_free_queues(struct vtscmi_softc *sc)
{
        int idx;

        for (idx = VIRTIO_SCMI_CHAN_A2P; idx < VIRTIO_SCMI_CHAN_MAX; idx++) {
                struct vtscmi_queue *q;

                if (idx == VIRTIO_SCMI_CHAN_P2A && !sc->has_p2a)
                        continue;

                q = &sc->vtscmi_queues[idx];
                if (q->vq_sz == 0)
                        continue;

                free(q->pdus, M_DEVBUF);
                mtx_destroy(&q->p_mtx);
                mtx_destroy(&q->vq_mtx);
        }
}

static void
vtscmi_vq_intr(void *arg)
{
        struct vtscmi_queue *q = arg;

        /*
         * TODO
         * - consider pressure on RX by msg floods
         *   + Does it need a taskqueue_ like virtio/net to postpone processing
         *     under pressure ? (SCMI is low_freq compared to network though)
         */
        for (;;) {
                struct vtscmi_pdu *pdu;
                uint32_t rx_len;

                mtx_lock_spin(&q->vq_mtx);
                pdu = virtqueue_dequeue(q->vq, &rx_len);
                mtx_unlock_spin(&q->vq_mtx);
                if (!pdu)
                        return;

                if (q->rx_callback)
                        q->rx_callback(pdu->buf, rx_len, q->priv);

                /* Note that this only frees the PDU, NOT the buffer itself */
                virtio_scmi_pdu_put(q->dev, pdu);
        }
}

static int
vtscmi_alloc_virtqueues(struct vtscmi_softc *sc)
{
        device_t dev;
        struct vq_alloc_info vq_info[VIRTIO_SCMI_CHAN_MAX];

        dev = sc->vtscmi_dev;
        sc->vtscmi_vqs_cnt = sc->has_p2a ? 2 : 1;

        VQ_ALLOC_INFO_INIT(&vq_info[VIRTIO_SCMI_CHAN_A2P], 0,
                           vtscmi_vq_intr,
                           &sc->vtscmi_queues[VIRTIO_SCMI_CHAN_A2P],
                           &sc->vtscmi_queues[VIRTIO_SCMI_CHAN_A2P].vq,
                           "%s cmdq", device_get_nameunit(dev));

        if (sc->has_p2a) {
                VQ_ALLOC_INFO_INIT(&vq_info[VIRTIO_SCMI_CHAN_P2A], 0,
                                   vtscmi_vq_intr,
                                   &sc->vtscmi_queues[VIRTIO_SCMI_CHAN_P2A],
                                   &sc->vtscmi_queues[VIRTIO_SCMI_CHAN_P2A].vq,
                                   "%s evtq", device_get_nameunit(dev));
        }

        return (virtio_alloc_virtqueues(dev, sc->vtscmi_vqs_cnt, vq_info));
}

static void *
virtio_scmi_pdu_get(struct vtscmi_queue *q, void *buf, unsigned int tx_len,
    unsigned int rx_len)
{
        struct vtscmi_pdu *pdu = NULL;

        if (rx_len == 0)
                return (NULL);

        mtx_lock_spin(&q->p_mtx);
        if (!SLIST_EMPTY(&q->p_head)) {
                pdu = SLIST_FIRST(&q->p_head);
                SLIST_REMOVE_HEAD(&q->p_head, next);
        }
        mtx_unlock_spin(&q->p_mtx);

        if (pdu == NULL) {
                device_printf(q->dev, "Cannot allocate PDU.\n");
                return (NULL);
        }

        /*Save msg buffer for easy access */
        pdu->buf = buf;
        if (tx_len != 0)
                sglist_append(&pdu->sg, pdu->buf, tx_len);
        sglist_append(&pdu->sg, pdu->buf, rx_len);

        return (pdu);
}

static void
virtio_scmi_pdu_put(device_t dev, struct vtscmi_pdu *pdu)
{
        struct vtscmi_softc *sc;
        struct vtscmi_queue *q;

        if (pdu == NULL)
                return;

        sc = device_get_softc(dev);
        q = &sc->vtscmi_queues[pdu->chan];

        sglist_reset(&pdu->sg);

        mtx_lock_spin(&q->p_mtx);
        SLIST_INSERT_HEAD(&q->p_head, pdu, next);
        mtx_unlock_spin(&q->p_mtx);
}

device_t
virtio_scmi_transport_get(void)
{
        return (vtscmi_dev);
}

int
virtio_scmi_channel_size_get(device_t dev, enum vtscmi_chan chan)
{
        struct vtscmi_softc *sc;

        sc = device_get_softc(dev);
        if (chan >= sc->vtscmi_vqs_cnt)
                return (0);

        return (sc->vtscmi_queues[chan].vq_sz);
}

int
virtio_scmi_channel_callback_set(device_t dev, enum vtscmi_chan chan,
    virtio_scmi_rx_callback_t *cb, void *priv)
{
        struct vtscmi_softc *sc;

        sc = device_get_softc(dev);
        if (chan >= sc->vtscmi_vqs_cnt)
                return (1);

        if (cb == NULL)
                virtqueue_disable_intr(sc->vtscmi_queues[chan].vq);

        sc->vtscmi_queues[chan].rx_callback = cb;
        sc->vtscmi_queues[chan].priv = priv;

        /* Enable Interrupt on VQ once the callback is set */
        if (cb != NULL)
                /*
                 * TODO
                 * Does this need a taskqueue_ task to process already pending
                 * messages ?
                 */
                virtqueue_enable_intr(sc->vtscmi_queues[chan].vq);

        device_printf(dev, "%sabled interrupts on VQ[%d].\n",
            cb ? "En" : "Dis", chan);

        return (0);
}

int
virtio_scmi_message_enqueue(device_t dev, enum vtscmi_chan chan,
    void *buf, unsigned int tx_len, unsigned int rx_len)
{
        struct vtscmi_softc *sc;
        struct vtscmi_pdu *pdu;
        struct vtscmi_queue *q;
        int ret;

        sc = device_get_softc(dev);
        if (chan >= sc->vtscmi_vqs_cnt)
                return (1);

        q = &sc->vtscmi_queues[chan];
        pdu = virtio_scmi_pdu_get(q, buf, tx_len, rx_len);
        if (pdu == NULL)
                return (ENXIO);

        mtx_lock_spin(&q->vq_mtx);
        ret = virtqueue_enqueue(q->vq, pdu, &pdu->sg,
            chan == VIRTIO_SCMI_CHAN_A2P ? 1 : 0, 1);
        if (ret == 0)
                virtqueue_notify(q->vq);
        mtx_unlock_spin(&q->vq_mtx);

        return (ret);
}

void *
virtio_scmi_message_poll(device_t dev, uint32_t *rx_len)
{
        struct vtscmi_softc *sc;
        struct vtscmi_queue *q;
        struct vtscmi_pdu *pdu;
        void *buf = NULL;

        sc = device_get_softc(dev);

        q = &sc->vtscmi_queues[VIRTIO_SCMI_CHAN_A2P];

        mtx_lock_spin(&q->vq_mtx);
        /* Not using virtqueue_poll since has no configurable timeout */
        pdu = virtqueue_dequeue(q->vq, rx_len);
        mtx_unlock_spin(&q->vq_mtx);
        if (pdu != NULL) {
                buf = pdu->buf;
                virtio_scmi_pdu_put(dev, pdu);
        }

        return (buf);
}