root/tools/tools/nvmf/nvmfd/nvmfd.c
/*-
 * SPDX-License-Identifier: BSD-2-Clause
 *
 * Copyright (c) 2023-2024 Chelsio Communications, Inc.
 * Written by: John Baldwin <jhb@FreeBSD.org>
 */

#include <sys/param.h>
#include <sys/event.h>
#include <sys/linker.h>
#include <sys/module.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <assert.h>
#include <err.h>
#include <errno.h>
#include <libnvmf.h>
#include <libutil.h>
#include <netdb.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "internal.h"

/* Command-line tunables, filled in by main() and read by other files. */
bool data_digests = false;              /* -G */
bool header_digests = false;            /* -g */
bool flow_control_disable = false;      /* -F */
bool kernel_io = false;                 /* -K */
uint32_t maxh2cdata = 256 * 1024;       /* -H, validated in main() */

/* Subsystem NQN: from -n, or generated from the host UUID in main(). */
static const char *subnqn;

/*
 * Set from handle_sig() and polled by the event loop.  A signal handler
 * may only write to an object of type volatile sig_atomic_t (ISO C
 * 7.14.1.1); a volatile bool is not guaranteed to be safe here.
 */
static volatile sig_atomic_t quit = 0;

/*
 * Print the command-line synopsis to stderr and exit with status 1.
 * The -K form (kernel I/O) takes no device arguments; the other form
 * requires at least one device.
 */
static void
usage(void)
{
        fprintf(stderr,
            "nvmfd -K [-dFGg] [-H MAXH2CDATA] [-P port] [-p port] [-t transport] [-n subnqn]\n"
            "nvmfd [-dFGg] [-H MAXH2CDATA] [-P port] [-p port] [-t transport] [-n subnqn]\n"
            "\tdevice [device [...]]\n"
            "\n"
            "Devices use one of the following syntaxes:\n"
            "\tpathname     - file or disk device\n"
            "\tramdisk:size - memory disk of given size\n");
        exit(1);
}

/*
 * Handler for the termination signals installed by handle_connections().
 * It does no work beyond raising the 'quit' flag; the event loop notices
 * the flag and returns.
 */
static void
handle_sig(int signo __unused)
{
        quit = 1;
}

/*
 * Put socket 's' into the listening state and add it to kqueue 'kqfd'
 * as a read (EVFILT_READ) event carrying 'udata', which the event loop
 * uses to tell listeners apart.  Any failure is fatal.
 */
static void
register_listen_socket(int kqfd, int s, void *udata)
{
        struct kevent change;
        int rc;

        /*
         * Backlog of -1: presumably asks the kernel for its configured
         * maximum queue length (see listen(2)).
         */
        rc = listen(s, -1);
        if (rc != 0)
                err(1, "listen");

        EV_SET(&change, s, EVFILT_READ, EV_ADD, 0, 0, udata);
        rc = kevent(kqfd, &change, 1, NULL, 0, NULL);
        if (rc == -1)
                err(1, "kevent: failed to add listen socket");
}

/*
 * Create, bind and register a listening TCP socket for every passive
 * local address getaddrinfo() returns for 'port'.
 *
 * Discovery listeners are registered with udata 1.  I/O controller
 * listeners are registered with udata 2 and are also passed to
 * discovery_add_io_controller() together with 'subnqn'.  Exits if
 * getaddrinfo() fails or if no listener at all could be created.
 */
static void
create_passive_sockets(int kqfd, const char *port, bool discovery)
{
        struct addrinfo hints, *ai, *list;
        bool created;
        int error, one, s, saved_errno;

        memset(&hints, 0, sizeof(hints));
        hints.ai_flags = AI_PASSIVE;
        hints.ai_family = AF_UNSPEC;
        hints.ai_socktype = SOCK_STREAM;
        hints.ai_protocol = IPPROTO_TCP;
        error = getaddrinfo(NULL, port, &hints, &list);
        if (error != 0)
                errx(1, "%s", gai_strerror(error));
        created = false;
        saved_errno = 0;

        for (ai = list; ai != NULL; ai = ai->ai_next) {
                s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
                if (s == -1) {
                        saved_errno = errno;
                        continue;
                }

                /*
                 * Let a restarted daemon rebind the port while connections
                 * from a previous instance are still in TIME_WAIT.
                 * Best effort: bind() below reports any real problem.
                 */
                one = 1;
                if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one,
                    sizeof(one)) != 0)
                        warn("setsockopt(SO_REUSEADDR)");

                if (bind(s, ai->ai_addr, ai->ai_addrlen) != 0) {
                        /* Save errno now; close() may clobber it. */
                        saved_errno = errno;
                        close(s);
                        continue;
                }

                if (discovery) {
                        register_listen_socket(kqfd, s, (void *)1);
                } else {
                        register_listen_socket(kqfd, s, (void *)2);
                        discovery_add_io_controller(s, subnqn);
                }
                created = true;
        }

        freeaddrinfo(list);
        if (!created) {
                /*
                 * errno is unspecified after freeaddrinfo()/close(), so
                 * report the last socket()/bind() failure explicitly.
                 */
                if (saved_errno != 0)
                        errc(1, saved_errno,
                            "Failed to create any listen sockets");
                errx(1, "Failed to create any listen sockets");
        }
}

/*
 * Main event loop.  Installs handle_sig() for SIGHUP, SIGINT, SIGQUIT
 * and SIGTERM, then waits on 'kqfd' for a listening socket to become
 * readable.  It accepts each connection and dispatches it by the udata
 * tag the listener was registered with (1 = discovery, 2 = I/O
 * controller).  Returns once 'quit' has been set.
 */
static void
handle_connections(int kqfd)
{
        static const int quit_signals[] = { SIGHUP, SIGINT, SIGQUIT, SIGTERM };
        struct kevent ev, sigev[nitems(quit_signals)];
        size_t i;
        int s;

        /*
         * handle_sig() only sets 'quit'.  A signal that arrives after the
         * loop condition is tested, but before kevent() blocks, would be
         * missed until the next connection.  EVFILT_SIGNAL coexists with
         * signal handlers (see kqueue(2)), so also watching the signals
         * turns such a late signal into a pending kqueue event.  kevent()
         * then returns and the loop can exit.
         */
        for (i = 0; i < nitems(quit_signals); i++) {
                signal(quit_signals[i], handle_sig);
                EV_SET(&sigev[i], quit_signals[i], EVFILT_SIGNAL, EV_ADD, 0,
                    0, NULL);
        }
        if (kevent(kqfd, sigev, (int)nitems(sigev), NULL, 0, NULL) == -1)
                err(1, "kevent: failed to add signal events");

        while (!quit) {
                if (kevent(kqfd, NULL, 0, &ev, 1, NULL) == -1) {
                        if (errno == EINTR)
                                continue;
                        err(1, "kevent");
                }

                if (ev.filter == EVFILT_SIGNAL) {
                        /*
                         * Set the flag here too rather than relying on
                         * the handler having already run.
                         */
                        quit = true;
                        continue;
                }

                assert(ev.filter == EVFILT_READ);

                s = accept(ev.ident, NULL, NULL);
                if (s == -1) {
                        warn("accept");
                        continue;
                }

                switch ((uintptr_t)ev.udata) {
                case 1:
                        handle_discovery_socket(s);
                        break;
                case 2:
                        handle_io_socket(s);
                        break;
                default:
                        __builtin_unreachable();
                }
        }
}

/*
 * nvmfd entry point.  The steps are, in order:
 *
 * - parse options and validate the transport;
 * - choose the subsystem NQN;
 * - register devices (userland I/O mode only);
 * - initialize the discovery and I/O services;
 * - optionally daemonize with a pidfile;
 * - serve discovery and I/O connections until a quit signal arrives.
 */
int
main(int ac, char **av)
{
        struct pidfh *pfh;
        const char *dport, *ioport, *transport;
        pid_t pid;
        uint64_t value;
        int ch, error, kqfd;
        bool daemonize;
        /* Buffer for the generated NQN used when -n is not given. */
        static char nqn[NVMF_NQN_MAX_LEN];

        /* 7.4.9.3 Default port for discovery */
        dport = "8009";

        pfh = NULL;
        daemonize = true;
        /* Port "0": presumably lets the kernel choose the I/O port(s). */
        ioport = "0";
        subnqn = NULL;
        transport = "tcp";
        while ((ch = getopt(ac, av, "dFgGH:Kn:P:p:t:")) != -1) {
                switch (ch) {
                case 'd':
                        daemonize = false;
                        break;
                case 'F':
                        flow_control_disable = true;
                        break;
                case 'G':
                        data_digests = true;
                        break;
                case 'g':
                        header_digests = true;
                        break;
                case 'H':
                        if (expand_number(optarg, &value) != 0)
                                errx(1, "Invalid MAXH2CDATA value %s", optarg);
                        /* At least 4096, fits in 32 bits, multiple of 4. */
                        if (value < 4096 || value > UINT32_MAX ||
                            value % 4 != 0)
                                errx(1, "Invalid MAXH2CDATA value %s", optarg);
                        maxh2cdata = value;
                        break;
                case 'K':
                        kernel_io = true;
                        break;
                case 'n':
                        subnqn = optarg;
                        break;
                case 'P':
                        dport = optarg;
                        break;
                case 'p':
                        ioport = optarg;
                        break;
                case 't':
                        transport = optarg;
                        break;
                default:
                        usage();
                }
        }

        av += optind;
        ac -= optind;

        if (kernel_io) {
                /* Kernel I/O mode serves no local devices. */
                if (ac > 0)
                        usage();
                /* Best effort: only warn if the nvmft module can't load. */
                if (modfind("nvmft") == -1 && kldload("nvmft") == -1)
                        warn("couldn't load nvmft");
        } else {
                if (ac < 1)
                        usage();
        }

        /* TCP is the only transport this daemon accepts. */
        if (strcasecmp(transport, "tcp") != 0)
                errx(1, "Invalid transport %s", transport);

        if (subnqn == NULL) {
                error = nvmf_nqn_from_hostuuid(nqn);
                if (error != 0)
                        errc(1, error, "Failed to generate NQN");
                subnqn = nqn;
        }

        if (!kernel_io)
                register_devices(ac, av);

        init_discovery();
        init_io(subnqn);

        if (daemonize) {
                pfh = pidfile_open(NULL, 0600, &pid);
                if (pfh == NULL) {
                        if (errno == EEXIST)
                                errx(1, "Daemon already running, pid: %jd",
                                    (intmax_t)pid);
                        /* Not fatal: run without a pidfile. */
                        warn("Cannot open or create pidfile");
                }

                if (daemon(0, 0) != 0) {
                        pidfile_remove(pfh);
                        err(1, "Failed to fork into the background");
                }

                pidfile_write(pfh);
        }

        /*
         * NOTE(review): kqueue descriptors are not inherited across
         * fork(2) (kqueue(2)), which is presumably why the kqueue and
         * its listeners are set up only after daemon().
         */
        kqfd = kqueue();
        if (kqfd == -1) {
                pidfile_remove(pfh);
                err(1, "kqueue");
        }

        create_passive_sockets(kqfd, dport, true);
        create_passive_sockets(kqfd, ioport, false);

        handle_connections(kqfd);
        shutdown_io();
        if (pfh != NULL)
                pidfile_remove(pfh);
        return (0);
}