root/usr/src/lib/libktest/common/libktest.c
/*
 * This file and its contents are supplied under the terms of the
 * Common Development and Distribution License ("CDDL"), version 1.0.
 * You may only use this file in accordance with the terms of version
 * 1.0 of the CDDL.
 *
 * A full copy of the text of the CDDL should have accompanied this
 * source.  A copy of the CDDL is also available via the Internet at
 * http://www.illumos.org/license/CDDL.
 */

/*
 * Copyright 2025 Oxide Computer Company
 * Copyright 2024 Ryan Zezeski
 */

#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <errno.h>
#include <fcntl.h>
#include <string.h>
#include <dirent.h>
#include <atomic.h>
#include <sys/ktest.h>
#include <sys/systeminfo.h>
#include <sys/modctl.h>
#include <upanic.h>

#include "libktest_impl.h"

/*
 * Create a handle for talking to the ktest device.
 *
 * The device at KTEST_DEV_PATH is opened read-only (close-on-exec) and the
 * resulting descriptor is stored in a freshly allocated handle.
 *
 * Returns the handle on success.  On failure NULL is returned and errno
 * reflects the open() or malloc() error.
 */
ktest_hdl_t *
ktest_init(void)
{
        ktest_hdl_t *hdl;
        int fd;

        if ((fd = open(KTEST_DEV_PATH, O_RDONLY | O_CLOEXEC, 0)) < 0) {
                return (NULL);
        }

        if ((hdl = malloc(sizeof (ktest_hdl_t))) == NULL) {
                /* Do not let close() disturb the errno left by malloc() */
                const int saved = errno;
                (void) close(fd);
                errno = saved;
                return (NULL);
        }

        hdl->kt_fd = fd;
        return (hdl);
}

/*
 * Release a ktest device handle: close its descriptor (if one is held) and
 * free the handle itself.  Passing NULL is a no-op.
 */
void
ktest_fini(ktest_hdl_t *hdl)
{
        if (hdl != NULL) {
                if (hdl->kt_fd >= 0) {
                        (void) close(hdl->kt_fd);
                        hdl->kt_fd = -1;
                }
                free(hdl);
        }
}

/* Initial size of the buffer offered to KTEST_IOCTL_LIST_TESTS */
#define DEFAULT_LIST_BUF_SZ     (64 * 1024)

/*
 * Fetch the packed test listing from the ktest device and unpack it.
 *
 * The listing is requested with a default-sized buffer.  If the ioctl fails
 * with ENOBUFS, the request is retried once with a buffer of op.klo_resp_len
 * bytes (relying on the ioctl having updated that field with the size needed).
 *
 * Returns the unpacked nvlist (with the serialization-format nvpair removed)
 * on success; the caller owns it and releases it with nvlist_free().  On
 * failure NULL is returned with errno set:
 *  - the malloc() or ioctl() error,
 *  - the error returned by nvlist_unpack(), or
 *  - EINVAL if the serialization format marker is missing or unexpected.
 */
static nvlist_t *
ktest_query_tests(ktest_hdl_t *hdl)
{
        char *resp = malloc(DEFAULT_LIST_BUF_SZ);
        if (resp == NULL) {
                return (NULL);
        }

        ktest_list_op_t op = {
                .klo_resp = resp,
                .klo_resp_len = DEFAULT_LIST_BUF_SZ,
        };

        int ret = ioctl(hdl->kt_fd, KTEST_IOCTL_LIST_TESTS, &op);

        /* Resize buffer and retry, if ioctl indicates it was too small */
        if (ret == -1 && errno == ENOBUFS) {
                free(resp);

                if ((resp = malloc(op.klo_resp_len)) == NULL) {
                        return (NULL);
                }
                op.klo_resp = resp;
                ret = ioctl(hdl->kt_fd, KTEST_IOCTL_LIST_TESTS, &op);
        }
        if (ret == -1) {
                /*
                 * Callers report the ioctl failure through errno, so keep
                 * free() from having any chance to overwrite it.
                 */
                const int err = errno;
                free(resp);
                errno = err;
                return (NULL);
        }

        nvlist_t *tests = NULL;
        ret = nvlist_unpack(resp, op.klo_resp_len, &tests, 0);
        /* The packed form is not needed once unpacking has been attempted */
        free(resp);
        if (ret != 0) {
                errno = ret;
                return (NULL);
        }

        /*
         * Verify that the response is marked with the expected serialization
         * format. Remove the nvpair so that only the modules remain.
         */
        uint64_t vsn = 0;
        if (nvlist_lookup_uint64(tests, KTEST_SER_FMT_KEY, &vsn) != 0 ||
            vsn != KTEST_SER_FMT_VSN) {
                nvlist_free(tests);
                errno = EINVAL;
                return (NULL);
        }
        fnvlist_remove(tests, KTEST_SER_FMT_KEY);

        return (tests);
}

/*
 * Advance the iterator's test cursor (kli_test) to the next usable test in
 * the current suite's test list (kli_tests).
 *
 * A usable test is an nvpair of type DATA_TYPE_NVLIST whose nvlist holds a
 * boolean KTEST_TEST_INPUT_KEY value; that value is recorded in
 * kli_req_input.  Pairs not matching this shape are skipped.  When no further
 * test is found, kli_test ends up NULL.
 *
 * With do_reset false, an already-exhausted cursor (kli_test == NULL) is left
 * alone.  With do_reset true, a NULL cursor means "start from the beginning".
 */
static void
ktest_iter_next_test(ktest_list_iter_t *iter, bool do_reset)
{
        if ((iter->kli_test == NULL && !do_reset) || iter->kli_tests == NULL) {
                return;
        }

        while ((iter->kli_test = nvlist_next_nvpair(iter->kli_tests,
            iter->kli_test)) != NULL) {
                boolean_t needs_input;

                if (nvpair_type(iter->kli_test) != DATA_TYPE_NVLIST) {
                        continue;
                }
                if (nvlist_lookup_boolean_value(
                    fnvpair_value_nvlist(iter->kli_test), KTEST_TEST_INPUT_KEY,
                    &needs_input) != 0) {
                        continue;
                }

                iter->kli_req_input = needs_input;
                break;
        }
}

/*
 * Advance the iterator's suite cursor (kli_suite) to the next usable suite in
 * the current module's suite list (kli_suites), then position the test cursor
 * on that suite's first usable test.
 *
 * A usable suite is an nvpair of type DATA_TYPE_NVLIST whose nvlist contains
 * an nvlist under KTEST_SUITE_TESTS_KEY; that nvlist becomes kli_tests.  The
 * test state (kli_tests, kli_test) is cleared before searching, so it is NULL
 * if no further suite is found.
 *
 * With do_reset false, an already-exhausted cursor (kli_suite == NULL) is left
 * alone.  With do_reset true, a NULL cursor means "start from the beginning".
 */
static void
ktest_iter_next_suite(ktest_list_iter_t *iter, bool do_reset)
{
        if ((iter->kli_suite == NULL && !do_reset) ||
            iter->kli_suites == NULL) {
                return;
        }

        iter->kli_tests = NULL;
        iter->kli_test = NULL;

        while ((iter->kli_suite = nvlist_next_nvpair(iter->kli_suites,
            iter->kli_suite)) != NULL) {
                if (nvpair_type(iter->kli_suite) != DATA_TYPE_NVLIST) {
                        continue;
                }
                if (nvlist_lookup_nvlist(fnvpair_value_nvlist(iter->kli_suite),
                    KTEST_SUITE_TESTS_KEY, &iter->kli_tests) == 0) {
                        break;
                }
        }

        /* Position on the first test of whatever suite (if any) was found */
        ktest_iter_next_test(iter, true);
}

/*
 * Advance the iterator's module cursor (kli_module) to the next usable module
 * in the top-level list (kli_modules), then position the suite and test
 * cursors within it.
 *
 * A usable module is an nvpair of type DATA_TYPE_NVLIST whose nvlist contains
 * an nvlist under KTEST_MODULE_SUITES_KEY; that nvlist becomes kli_suites.
 * The suite state (kli_suites, kli_suite) is cleared before searching.
 *
 * With do_reset false, an already-exhausted cursor (kli_module == NULL) is
 * left alone.  With do_reset true, a NULL cursor means "start from the
 * beginning".  The iterator must hold a module list (enforced via VERIFY).
 */
static void
ktest_iter_next_module(ktest_list_iter_t *iter, bool do_reset)
{
        if (iter->kli_module == NULL && !do_reset) {
                return;
        }
        VERIFY(iter->kli_modules != NULL);

        iter->kli_suites = NULL;
        iter->kli_suite = NULL;

        while ((iter->kli_module = nvlist_next_nvpair(iter->kli_modules,
            iter->kli_module)) != NULL) {
                if (nvpair_type(iter->kli_module) != DATA_TYPE_NVLIST) {
                        continue;
                }
                if (nvlist_lookup_nvlist(fnvpair_value_nvlist(iter->kli_module),
                    KTEST_MODULE_SUITES_KEY, &iter->kli_suites) == 0) {
                        break;
                }
        }

        /* Position on the first suite of whatever module (if any) was found */
        ktest_iter_next_suite(iter, true);
}

/*
 * Build an iterator over the ktests currently reported by the ktest device.
 *
 * The test listing is queried through the handle and stored in a newly
 * allocated iterator, which is positioned on the first available test.  The
 * iterator keeps a pointer to hdl.  Release it with ktest_list_free().
 *
 * Returns the iterator on success, otherwise NULL (setting errno).
 */
ktest_list_iter_t *
ktest_list(ktest_hdl_t *hdl)
{
        nvlist_t *listing;
        ktest_list_iter_t *iter;

        if ((listing = ktest_query_tests(hdl)) == NULL) {
                return (NULL);
        }

        if ((iter = malloc(sizeof (ktest_list_iter_t))) == NULL) {
                /* Report the allocation failure, not anything from cleanup */
                const int saved = errno;
                nvlist_free(listing);
                errno = saved;
                return (NULL);
        }

        iter->kli_hdl = hdl;
        iter->kli_modules = listing;
        iter->kli_module = NULL;

        /* A NULL module cursor plus do_reset starts from the first module */
        ktest_iter_next_module(iter, true);

        return (iter);
}

/*
 * Get the next ktest entry from a test list iterator.
 *
 * Returns true if an item was available from the iterator (populating entry
 * with the module, suite and test names and whether the test requires input),
 * otherwise false.  The name strings are those returned by nvpair_name() for
 * the iterator's nvpairs; they are not copied.
 */
bool
ktest_list_next(ktest_list_iter_t *iter, ktest_entry_t *entry)
{
        for (;;) {
                if (iter->kli_module == NULL) {
                        /* Every module has been consumed */
                        return (false);
                }

                if (iter->kli_test == NULL) {
                        /*
                         * Current suite has no more tests: try the next suite
                         * and, failing that, the next module.  Then re-check.
                         */
                        ktest_iter_next_suite(iter, false);
                        if (iter->kli_suite == NULL) {
                                ktest_iter_next_module(iter, false);
                        }
                        continue;
                }

                /*
                 * Emit the test under the cursor, then step the cursor so a
                 * subsequent call picks up where this one left off.
                 */
                entry->ke_module = nvpair_name(iter->kli_module);
                entry->ke_suite = nvpair_name(iter->kli_suite);
                entry->ke_test = nvpair_name(iter->kli_test);
                entry->ke_requires_input = iter->kli_req_input;
                ktest_iter_next_test(iter, false);
                return (true);
        }
}

/*
 * Rewind a ktest list iterator so that the next ktest_list_next() call starts
 * again from the first available test.
 */
void
ktest_list_reset(ktest_list_iter_t *iter)
{
        /* Clearing the module cursor makes the reset scan start at the top */
        iter->kli_module = NULL;

        ktest_iter_next_module(iter, true);
}

/*
 * Destroy a ktest list iterator, releasing the test listing it holds along
 * with the iterator itself.  Passing NULL is a no-op.
 */
void
ktest_list_free(ktest_list_iter_t *iter)
{
        if (iter != NULL) {
                /* Drop the cursors before the nvlist they point into is freed */
                iter->kli_module = NULL;
                iter->kli_suite = NULL;
                iter->kli_test = NULL;

                if (iter->kli_modules != NULL) {
                        nvlist_free(iter->kli_modules);
                        iter->kli_modules = NULL;
                }
                free(iter);
        }
}

/*
 * Run a ktest.
 *
 * Requests that the ktest module run a ktest specified by module/suite/test
 * triple in the ktest_run_req_t.  If the test requires input, that too is
 * expected to be set in the run request.  Names which do not fit in the
 * corresponding ktest_run_op_t fields are truncated so that the fields handed
 * to the ioctl are always NUL-terminated.
 *
 * If the test was able to be run (regardless of its actual result), and the
 * emitted results data processed, ktest_run() will return true.  Any message
 * output from the test will be placed in krr_msg (NULL if there was none),
 * which the caller is responsible for free()-ing.
 *
 * If an error was encountered while attempting to run the test, or process its
 * results, ktest_run() will return false, and the ktest_run_result_t will not
 * be populated.
 */
bool
ktest_run(ktest_hdl_t *hdl, const ktest_run_req_t *req, ktest_run_result_t *res)
{
        ktest_run_op_t kro = {
                .kro_input_bytes = req->krq_input,
                .kro_input_len = req->krq_input_len,
        };

        /*
         * strncpy() does not terminate the destination when the source fills
         * it.  The initializer above zeroes every member not named, so
         * copying at most size - 1 bytes leaves the final byte as the NUL.
         */
        (void) strncpy(kro.kro_module, req->krq_module,
            sizeof (kro.kro_module) - 1);
        (void) strncpy(kro.kro_suite, req->krq_suite,
            sizeof (kro.kro_suite) - 1);
        (void) strncpy(kro.kro_test, req->krq_test,
            sizeof (kro.kro_test) - 1);

        if (ioctl(hdl->kt_fd, KTEST_IOCTL_RUN_TEST, &kro) == -1) {
                return (false);
        }

        const ktest_result_t *kres = &kro.kro_result;

        /*
         * The message buffers are measured with bounded strnlen(), and those
         * same bounds are applied when formatting below, so an unterminated
         * buffer is never read past its end.
         */
        const size_t prepend_len =
            strnlen(kres->kr_msg_prepend, sizeof (kres->kr_msg_prepend));
        const size_t body_len = strnlen(kres->kr_msg, sizeof (kres->kr_msg));

        char *msg = NULL;
        if (prepend_len + body_len != 0) {
                if (asprintf(&msg, "%.*s%.*s",
                    (int)prepend_len, kres->kr_msg_prepend,
                    (int)body_len, kres->kr_msg) == -1) {
                        return (false);
                }
        }

        /* Only touch the caller's result once nothing else can fail */
        res->krr_code = (ktest_code_t)kres->kr_type;
        res->krr_line = (uint_t)kres->kr_line;
        res->krr_msg = msg;

        return (true);
}


/*
 * Translate a ktest_code_t into its human-readable name.
 *
 * An unrecognized code is treated as a programming error: the process is
 * terminated through upanic() rather than returning.
 */
const char *
ktest_code_name(ktest_code_t code)
{
        const char *name = NULL;

        switch (code) {
        case KTEST_CODE_NONE:
                name = "NONE";
                break;
        case KTEST_CODE_PASS:
                name = "PASS";
                break;
        case KTEST_CODE_FAIL:
                name = "FAIL";
                break;
        case KTEST_CODE_SKIP:
                name = "SKIP";
                break;
        case KTEST_CODE_ERROR:
                name = "ERROR";
                break;
        default:
                break;
        }

        if (name == NULL) {
                const char errmsg[] = "unexpected ktest_code value";
                upanic(errmsg, sizeof (errmsg));
        }
        return (name);
}

/* Suffix carried by every test module name, and where such modules live */
#define KTEST_MODULE_SUFFIX     "_ktest"
#define KTEST_BASE_MODULE_DIR   "/usr/kernel/misc/ktest"

/* Lazily built "<base dir>/<arch>" path; never freed once installed */
static char *ktest_cached_module_dir;

/*
 * Get the directory holding test modules for this system:
 * KTEST_BASE_MODULE_DIR followed by the SI_ARCHITECTURE_64 sysinfo string.
 *
 * The path is built on first use and cached.  It is installed with
 * atomic_cas_ptr(); if the cache turned out to be populated already, the
 * string built by this call is freed and the cached one is returned.
 *
 * Returns the path, or NULL if sysinfo() or asprintf() failed.
 */
static const char *
ktest_mod_directory(void)
{
        char arch[20];
        char *built = NULL;

        if (ktest_cached_module_dir != NULL) {
                return (ktest_cached_module_dir);
        }

        if (sysinfo(SI_ARCHITECTURE_64, arch, sizeof (arch)) < 0) {
                return (NULL);
        }
        if (asprintf(&built, "%s/%s", KTEST_BASE_MODULE_DIR, arch) < 0) {
                return (NULL);
        }

        char *prior = atomic_cas_ptr(&ktest_cached_module_dir, NULL, built);
        if (prior != NULL) {
                /* Cache was filled in the meantime; discard our copy */
                free(built);
                return (ktest_cached_module_dir);
        }
        return (built);
}

/*
 * Build the filesystem path of the test module for `name`:
 * "<module directory>/<name>_ktest".
 *
 * buf must have room for MAXPATHLEN bytes.
 *
 * Returns true if the complete path was written to buf.  Returns false if the
 * module directory could not be determined, if formatting failed, or if the
 * path does not fit in MAXPATHLEN bytes (errno is set to ENAMETOOLONG in that
 * last case) so that a truncated path is never handed on to be loaded.
 */
static bool
ktest_mod_path(const char *name, char *buf)
{
        const char *base = ktest_mod_directory();
        if (base == NULL) {
                return (false);
        }

        const int len = snprintf(buf, MAXPATHLEN, "%s/%s" KTEST_MODULE_SUFFIX,
            base, name);
        if (len < 0) {
                return (false);
        }
        /* snprintf() reports the untruncated length; >= size means cut off */
        if ((size_t)len >= MAXPATHLEN) {
                errno = ENAMETOOLONG;
                return (false);
        }
        return (true);
}

static int
ktest_mod_id_for_name(const char *name)
{
        struct modinfo modinfo = {
                .mi_info = MI_INFO_ONE | MI_INFO_BY_NAME,
        };
        (void) snprintf(modinfo.mi_name, sizeof (modinfo.mi_name),
            "%s" KTEST_MODULE_SUFFIX, name);

        if (modctl(MODINFO, 0, &modinfo) < 0) {
                return (-1);
        }
        return (modinfo.mi_id);
}

/*
 * Attempt to load the test module for `name`.
 *
 * Returns true if the module was loaded by this call or a positive module ID
 * was already found for it, otherwise false (setting errno).
 */
bool
ktest_mod_load(const char *name)
{
        char path[MAXPATHLEN];
        int id = 0;

        /* Nothing to do when the module is already present */
        if (ktest_mod_id_for_name(name) > 0) {
                return (true);
        }

        if (!ktest_mod_path(name, path)) {
                return (false);
        }

        return (modctl(MODLOAD, 0, path, &id) == 0);
}

/*
 * Attempt to unload the test module for `name`.
 *
 * If no positive module ID is found for it, nothing is done.  The outcome of
 * the unload request itself is not reported.
 */
void
ktest_mod_unload(const char *name)
{
        const int mod_id = ktest_mod_id_for_name(name);

        if (mod_id <= 0) {
                return;
        }
        (void) modctl(MODUNLOAD, mod_id);
}

/*
 * Invoke cb once for every candidate test module found in the ktest module
 * directory.  A candidate is a directory entry whose name ends in
 * KTEST_MODULE_SUFFIX; cb receives the name with that suffix removed.
 * Entries whose remaining name is empty, or MODMAXNAMELEN or longer, are
 * skipped.
 *
 * Returns true if the directory was opened and read through to its end.
 * Returns false (with errno set) if the module directory could not be
 * determined or opened, or if reading it failed part way through; in the
 * latter case cb may already have run for some entries.  The directory stream
 * is closed on every path once it has been opened.
 */
static bool
ktest_mod_for_each(void (*cb)(const char *))
{
        const char *dpath;

        if ((dpath = ktest_mod_directory()) == NULL) {
                return (false);
        }

        DIR *dp = opendir(dpath);
        if (dp == NULL) {
                return (false);
        }

        struct dirent *de;
        for (;;) {
                /*
                 * readdir() returns NULL both at end-of-directory and on
                 * error; only errno tells them apart.  It must be cleared on
                 * each pass since cb may have left it set.
                 */
                errno = 0;
                if ((de = readdir(dp)) == NULL) {
                        break;
                }

                char *suffix = strrchr(de->d_name, '_');

                if (suffix == NULL ||
                    strcmp(suffix, KTEST_MODULE_SUFFIX) != 0) {
                        continue;
                }
                /*
                 * Drop the suffix from the name, and confirm that it is
                 * appropriately sized for a module.
                 */
                *suffix = '\0';
                const size_t name_sz = strnlen(de->d_name, MODMAXNAMELEN);
                if (name_sz == 0 || name_sz >= MODMAXNAMELEN) {
                        continue;
                }

                /* Execute on valid candidate */
                cb(de->d_name);
        }

        /* Capture the readdir() verdict before closedir() can alter errno */
        const int err = errno;
        (void) closedir(dp);
        if (err != 0) {
                errno = err;
                return (false);
        }
        return (true);
}

/*
 * ktest_mod_for_each() callback: try to load the named test module.  The
 * result of the load attempt is deliberately discarded.
 */
static void
ktest_mod_load_cb(const char *name)
{
        const bool loaded = ktest_mod_load(name);

        (void) loaded;
}

/*
 * Attempt to load all known test modules.
 *
 * A load is attempted for each test module found in the module directory.
 * The return value is that of the directory walk (ktest_mod_for_each()):
 * false (with errno set) means the walk itself failed.  The outcome of each
 * individual load attempt is not reflected in the result, so a true return
 * does not guarantee that every module was loaded.
 */
bool
ktest_mod_load_all(void)
{
        const bool walked = ktest_mod_for_each(ktest_mod_load_cb);

        return (walked);
}

/*
 * ktest_mod_for_each() callback: try to unload the named test module.
 */
static void
ktest_mod_unload_cb(const char *module_name)
{
        ktest_mod_unload(module_name);
}

/*
 * Attempt to unload all known test modules.
 *
 * An unload is attempted for each test module found in the module directory.
 * Returns true if the test modules could be iterated over and the unload
 * operations attempted.  A false result (with its accompanying errno)
 * indicates a problem with the directory in which the tests reside.
 */
bool
ktest_mod_unload_all(void)
{
        const bool walked = ktest_mod_for_each(ktest_mod_unload_cb);

        return (walked);
}

/*
 * Report the largest amount of input data, in bytes, that may be supplied to
 * a test.  This is the KTEST_IOCTL_MAX_LEN constant.
 */
size_t
ktest_max_input_size(void)
{
        const size_t max_len = KTEST_IOCTL_MAX_LEN;

        return (max_len);
}