root/usr/src/lib/lib9p/common/transport/socket.c
/*
 * Copyright 2016 Jakub Klama <jceel@FreeBSD.org>
 * All rights reserved
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted providing that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
 * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 * Copyright 2021 Joyent, Inc.
 */

#include <stdlib.h>
#include <errno.h>
#include <string.h>
#include <unistd.h>
#include <pthread.h>
#include <assert.h>
#include <sys/types.h>
#ifdef __APPLE__
# include "../apple_endian.h"
#elif __illumos__
# include <sys/param.h>
# include <port.h>
# include "../illumos_endian.h"
#else
# include <sys/endian.h>
#endif
#include <sys/socket.h>
#ifndef __illumos__
# include <sys/event.h>
#endif
#include <sys/uio.h>
#include <netdb.h>
#include "../lib9p.h"
#include "../lib9p_impl.h"
#include "../log.h"
#include "socket.h"

struct l9p_socket_softc
{
        struct l9p_connection *ls_conn;
        struct sockaddr ls_sockaddr;
        socklen_t ls_socklen;
        pthread_t ls_thread;
        int ls_fd;
};

#ifdef __FreeBSD__
struct event_svr {
        struct kevent *ev_kev;
        struct kevent *ev_event;
        int ev_kq;
};
#elif __illumos__
struct event_svr {
        port_event_t *ev_pe;
        int ev_port;
};
#else
#error "No event server defined"
#endif

static int l9p_init_event_svr(struct event_svr *, uint_t);
static uint_t l9p_get_server_addrs(const char *, const char *,
    struct addrinfo **);
static uint_t l9p_bind_addrs(struct event_svr *, struct addrinfo *, uint_t,
    int **);
static int l9p_event_get(struct l9p_server *, struct event_svr *, uint_t,
    void (*cb)(struct l9p_server *, int));
static int l9p_socket_readmsg(struct l9p_socket_softc *, void **, size_t *);
static int l9p_socket_get_response_buffer(struct l9p_request *,
    struct iovec *, size_t *, void *);
static int l9p_socket_send_response(struct l9p_request *, const struct iovec *,
    const size_t, const size_t, void *);
static void l9p_socket_drop_response(struct l9p_request *, const struct iovec *,
    size_t, void *);
static void *l9p_socket_thread(void *);
static ssize_t xread(int, void *, size_t);
static ssize_t xwrite(int, void *, size_t);

int
l9p_start_server(struct l9p_server *server, const char *host, const char *port)
{
        struct addrinfo *res = NULL;
        int *sockets = NULL;
        uint_t naddrs = 0;
        uint_t nsockets = 0;
        uint_t i;
        struct event_svr esvr;

        naddrs = l9p_get_server_addrs(host, port, &res);
        if (naddrs == 0)
                return (-1);

        if (l9p_init_event_svr(&esvr, naddrs) != 0) {
                freeaddrinfo(res);
                return (-1);
        }

        nsockets = l9p_bind_addrs(&esvr, res, naddrs, &sockets);

        /*
         * We don't need res, after this, so free it and NULL it to prevent
         * any possible use after free.
         */
        freeaddrinfo(res);
        res = NULL;

        if (nsockets == 0)
                goto fail;

        for (;;) {
                if (l9p_event_get(server, &esvr, nsockets,
                    l9p_socket_accept) < 0)
                        break;
        }

        /* We get here if something failed */
        for (i = 0; i < nsockets; i++)
                close(sockets[i]);

fail:
        free(sockets);

#ifdef __FreeBSD__
        close(esvr.ev_kq);
        free(esvr.ev_kev);
        free(esvr.ev_event);
#elif __illumos__
        close(esvr.ev_port);
        free(esvr.ev_pe);
#else
#error "Port me"
#endif

        return (-1);
}

static uint_t
l9p_get_server_addrs(const char *host, const char *port, struct addrinfo **resp)
{
        struct addrinfo *res, hints;
        uint_t naddrs;
        int rc;

        memset(&hints, 0, sizeof(hints));
        hints.ai_family = PF_UNSPEC;
        hints.ai_socktype = SOCK_STREAM;
        rc = getaddrinfo(host, port, &hints, resp);
        if (rc > 0) {
                L9P_LOG(L9P_ERROR, "getaddrinfo(): %s", gai_strerror(rc));
                return (0);
        }

        naddrs = 0;
        for (res = *resp; res != NULL; res = res->ai_next)
                naddrs++;

        if (naddrs == 0) {
                L9P_LOG(L9P_ERROR, "no addresses found for %s:%s", host, port);
        }

        return (naddrs);
}

#ifdef __FreeBSD__
static int
l9p_init_event_svr(struct event_svr *svr, uint_t nsockets)
{
        svr->ev_kev = calloc(nsockets, sizeof(struct kevent));
        if (svr->ev_kev == NULL) {
                L9P_LOG(L9P_ERROR, "calloc(): %s", strerror(errno));
                return (-1);
        }

        svr->ev_event = calloc(nsockets, sizeof(struct kevent));
        if (svr->ev_event == NULL) {
                L9P_LOG(L9P_ERROR, "calloc(): %s", strerror(errno));
                free(svr->ev_key);
                svr->ev_key = NULL;
                return (-1);
        }

        svr->ev_kq = kqueue();
        if (svr->ev_kq == -1) {
                L9P_LOG(L9P_ERROR, "kqueue(): %s", strerror(errno));
                free(svr->ev_kev);
                free(svr->ev_event);
                svr->ev_kev = NULL;
                svr->ev_event = NULL;
                return (-1);
        }

        return (0);
}
#elif __illumos__
static int
l9p_init_event_svr(struct event_svr *svr, uint_t nsockets)
{
        svr->ev_pe = calloc(nsockets, sizeof(port_event_t));
        if (svr->ev_pe == NULL) {
                L9P_LOG(L9P_ERROR, "calloc(): %s", strerror(errno));
                return (-1);
        }

        svr->ev_port = port_create();
        if (svr->ev_port == -1) {
                L9P_LOG(L9P_ERROR, "port_create(): %s", strerror(errno));
                return (-1);
        }

        return (0);
}
#else
#error "No event server defined"
#endif

static uint_t
l9p_bind_addrs(struct event_svr *svr, struct addrinfo *addrs, uint_t naddrs,
    int **socketsp)
{
        struct addrinfo *addr;
        uint_t i, j;

        *socketsp = calloc(naddrs, sizeof(int));
        if (*socketsp == NULL) {
                L9P_LOG(L9P_ERROR, "calloc(): %s", strerror(errno));
                return (0);
        }

        for (i = 0, addr = addrs; addr != NULL; addr = addr->ai_next) {
                int s;
                int val = 1;

                s = socket(addr->ai_family, addr->ai_socktype,
                    addr->ai_protocol);
                if (s == -1) {
                        L9P_LOG(L9P_ERROR, "socket(): %s", strerror(errno));
                        continue;
                }

                if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &val,
                    sizeof(val)) < 0) {
                        L9P_LOG(L9P_ERROR, "setsockopt(): %s", strerror(errno));
                        close(s);
                        continue;
                }

                if (bind(s, addr->ai_addr, addr->ai_addrlen) < 0) {
                        L9P_LOG(L9P_ERROR, "bind(): %s", strerror(errno));
                        close(s);
                        continue;
                }

                if (listen(s, 10) < 0) {
                        L9P_LOG(L9P_ERROR, "listen(): %s", strerror(errno));
                        close(s);
                        continue;
                }

#ifdef __FreeBSD__
                EV_SET(&svr->ev_kev[i], s, EVFILT_READ, EV_ADD | EV_ENABLE, 0,
                    0, 0);
#elif __illumos__
                if (port_associate(svr->ev_port, PORT_SOURCE_FD, s,
                    POLLIN|POLLHUP, NULL) < 0) {
                        L9P_LOG(L9P_ERROR, "port_associate(%d): %s", s,
                            strerror(errno));
                        close(s);
                        continue;
                }
#else
#error "Port me"
#endif

                *socketsp[i++] = s;
        }

        if (i < 1) {
                free(*socketsp);
                *socketsp = NULL;
                return (0);
        }

        for (j = i; j < naddrs; j++)
                *socketsp[j++] = -1;

#ifdef __FreeBSD__
        if (kevent(svr->ev_kq, svr->ev_kev, i, NULL, 0, NULL) < 0) {
                L9P_LOG(L9P_ERROR, "kevent(): %s", strerror(errno));

                for (j = 0; j < i; j++)
                        close(j);

                free(*socketsp);
                *socketsp = NULL;

                return (0);
        }
#endif

        return (i);
}

#ifdef __FreeBSD__
static int
l9p_event_get(struct l9p_server *l9svr, struct event_svr *esvr, uint_t nsockets,
    void (*cb)(struct l9p_server *, int))
{
        int i, evs;

        evs = kevent(esvr->ev_kq, NULL, 0, esvr->ev_event, nsockets, NULL);
        if (evs < 0) {
                if (errno == EINTR)
                        return (0);
                L9P_LOG(L9P_ERROR, "kevent(): %s", strerror(errno));
                return (-1);
        }

        for (i = 0; i < evs; i++)
                cb(l9svr, (int)sevr->ev_event[i].ident);

        return (0);
}
#elif __illumos__
static int
l9p_event_get(struct l9p_server *l9svr, struct event_svr *esvr, uint_t nsockets,
    void (*cb)(struct l9p_server *, int))
{
        uint_t evs = 1;
        int i;

        if (port_getn(esvr->ev_port, esvr->ev_pe, nsockets, &evs, NULL) < 0) {
                if (errno == EINTR)
                        return (0);
                L9P_LOG(L9P_ERROR, "port_getn(): %s", strerror(errno));
                return (-1);
        }

        for (i = 0; i < evs; i++) {
                if (esvr->ev_pe[i].portev_source != PORT_SOURCE_FD)
                        continue;

                cb(l9svr, (int)esvr->ev_pe[i].portev_object);
        }

        return (0);
}
#else
#error "Port me"
#endif

void
l9p_socket_accept(struct l9p_server *server, int svr_fd)
{
        struct l9p_socket_softc *sc;
        struct l9p_connection *conn;
        char host[NI_MAXHOST + 1];
        char serv[NI_MAXSERV + 1];
        struct sockaddr client_addr;
        socklen_t client_addr_len = sizeof(client_addr);
        int conn_fd, err;

        conn_fd = accept(svr_fd, &client_addr, &client_addr_len);
        if (conn_fd < 0) {
                L9P_LOG(L9P_WARNING, "accept(): %s", strerror(errno));
                return;
        }

        err = getnameinfo(&client_addr, client_addr_len, host, NI_MAXHOST,
            serv, NI_MAXSERV, NI_NUMERICHOST | NI_NUMERICSERV);

        if (err != 0) {
                L9P_LOG(L9P_WARNING, "cannot look up client name: %s",
                    gai_strerror(err));
        } else {
                L9P_LOG(L9P_INFO, "new connection from %s:%s", host, serv);
        }

        if (l9p_connection_init(server, &conn) != 0) {
                L9P_LOG(L9P_ERROR, "cannot create new connection");
                return;
        }

        sc = l9p_calloc(1, sizeof(*sc));
        sc->ls_conn = conn;
        sc->ls_fd = conn_fd;

        /*
         * Fill in transport handler functions and aux argument.
         */
        conn->lc_lt.lt_aux = sc;
        conn->lc_lt.lt_get_response_buffer = l9p_socket_get_response_buffer;
        conn->lc_lt.lt_send_response = l9p_socket_send_response;
        conn->lc_lt.lt_drop_response = l9p_socket_drop_response;

        err = pthread_create(&sc->ls_thread, NULL, l9p_socket_thread, sc);
        if (err) {
                L9P_LOG(L9P_ERROR,
                    "pthread_create (for connection from %s:%s): error %s",
                    host, serv, strerror(err));
                l9p_connection_close(sc->ls_conn);
                free(sc);
        }
}

static void *
l9p_socket_thread(void *arg)
{
        struct l9p_socket_softc *sc = (struct l9p_socket_softc *)arg;
        struct iovec iov;
        void *buf;
        size_t length;

        for (;;) {
                if (l9p_socket_readmsg(sc, &buf, &length) != 0)
                        break;

                iov.iov_base = buf;
                iov.iov_len = length;
                l9p_connection_recv(sc->ls_conn, &iov, 1, NULL);
                free(buf);
        }

        L9P_LOG(L9P_INFO, "connection closed");
        l9p_connection_close(sc->ls_conn);
        free(sc);
        return (NULL);
}

static int
l9p_socket_readmsg(struct l9p_socket_softc *sc, void **buf, size_t *size)
{
        uint32_t msize;
        size_t toread;
        ssize_t ret;
        void *buffer;
        int fd = sc->ls_fd;

        assert(fd > 0);

        buffer = l9p_malloc(sizeof(uint32_t));

        ret = xread(fd, buffer, sizeof(uint32_t));
        if (ret < 0) {
                L9P_LOG(L9P_ERROR, "read(): %s", strerror(errno));
                return (-1);
        }

        if (ret != sizeof(uint32_t)) {
                if (ret == 0) {
                        L9P_LOG(L9P_DEBUG, "%p: EOF", (void *)sc->ls_conn);
                } else {
                        L9P_LOG(L9P_ERROR,
                            "short read: %zd bytes of %zd expected",
                            ret, sizeof(uint32_t));
                }
                return (-1);
        }

        msize = le32toh(*(uint32_t *)buffer);
        toread = msize - sizeof(uint32_t);
        buffer = l9p_realloc(buffer, msize);

        ret = xread(fd, (char *)buffer + sizeof(uint32_t), toread);
        if (ret < 0) {
                L9P_LOG(L9P_ERROR, "read(): %s", strerror(errno));
                return (-1);
        }

        if (ret != (ssize_t)toread) {
                L9P_LOG(L9P_ERROR, "short read: %zd bytes of %zd expected",
                    ret, toread);
                return (-1);
        }

        *size = msize;
        *buf = buffer;
        L9P_LOG(L9P_INFO, "%p: read complete message, buf=%p size=%d",
            (void *)sc->ls_conn, buffer, msize);

        return (0);
}

static int
l9p_socket_get_response_buffer(struct l9p_request *req, struct iovec *iov,
    size_t *niovp, void *arg __unused)
{
        size_t size = req->lr_conn->lc_msize;
        void *buf;

        buf = l9p_malloc(size);
        iov[0].iov_base = buf;
        iov[0].iov_len = size;

        *niovp = 1;
        return (0);
}

static int
l9p_socket_send_response(struct l9p_request *req __unused,
    const struct iovec *iov, const size_t niov __unused, const size_t iolen,
    void *arg)
{
        struct l9p_socket_softc *sc = (struct l9p_socket_softc *)arg;

        assert(sc->ls_fd >= 0);

        L9P_LOG(L9P_DEBUG, "%p: sending reply, buf=%p, size=%d", arg,
            iov[0].iov_base, iolen);

        if (xwrite(sc->ls_fd, iov[0].iov_base, iolen) != (int)iolen) {
                L9P_LOG(L9P_ERROR, "short write: %s", strerror(errno));
                return (-1);
        }

        free(iov[0].iov_base);
        return (0);
}

static void
l9p_socket_drop_response(struct l9p_request *req __unused,
    const struct iovec *iov, size_t niov __unused, void *arg)
{

        L9P_LOG(L9P_DEBUG, "%p: drop buf=%p", arg, iov[0].iov_base);
        free(iov[0].iov_base);
}

static ssize_t
xread(int fd, void *buf, size_t count)
{
        size_t done = 0;
        ssize_t ret;

        while (done < count) {
                ret = read(fd, (char *)buf + done, count - done);
                if (ret < 0) {
                        if (errno == EINTR)
                                continue;

                        return (-1);
                }

                if (ret == 0)
                        return ((ssize_t)done);

                done += (size_t)ret;
        }

        return ((ssize_t)done);
}

static ssize_t
xwrite(int fd, void *buf, size_t count)
{
        size_t done = 0;
        ssize_t ret;

        while (done < count) {
                ret = write(fd, (char *)buf + done, count - done);
                if (ret < 0) {
                        if (errno == EINTR)
                                continue;

                        return (-1);
                }

                if (ret == 0)
                        return ((ssize_t)done);

                done += (size_t)ret;
        }

        return ((ssize_t)done);
}