root/usr/src/test/os-tests/tests/ksid/ksid.c
/*
 * This file and its contents are supplied under the terms of the
 * Common Development and Distribution License ("CDDL"), version 1.0.
 * You may only use this file in accordance with the terms of version
 * 1.0 of the CDDL.
 *
 * A full copy of the text of the CDDL should have accompanied this
 * source.  A copy of the CDDL is also available via the Internet at
 * http://www.illumos.org/license/CDDL.
 */

/*
 * Copyright 2020 Tintri by DDN, Inc. All rights reserved.
 * Copyright 2025 RackTop Systems, Inc.
 */

/*
 * Test the kernel ksid interfaces used by ZFS and SMB
 */

#include <sys/sid.h>
#include <sys/time.h>
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>
#include <errno.h>

/*
 * Defined outside this file.  run_test() compares the number of sids being
 * searched against this value to decide which array the pid to look up is
 * taken from.
 * NOTE(review): presumably the list size at which the ksid code switches
 * between its search strategies -- confirm against the implementation.
 */
extern int ksl_bin_search_cutoff;

/* Number of random domain SID strings generated for entries in the sidlist */
#define MAX_DOMAINS 8
/* Defaults used by main() when the corresponding argument is not given */
#define DEFAULT_NUMSIDS 128
#define DEFAULT_ITERS 1

/*
 * Domain SID strings.  The first MAX_DOMAINS are used for sidlist entries;
 * the final two are kept distinct from those and used only for the
 * negative-test entries in bad_ksids.
 */
char ksid_sids[MAX_DOMAINS+2][256];
/*
 * Entries that are deliberately absent from the sidlist; allocated and
 * filled in by main(), read by run_test() for negative lookups.
 */
ksid_t *bad_ksids;

/*
 * Perform one lookup, repeated iters times, against the first numsids
 * entries of ksl.
 *
 *   ksl      sidlist under test; ksl_nsid is temporarily set to numsids
 *            and restored before returning
 *   idx      entry to look up; a value at or past the list's real ksl_nsid
 *            selects bad_ksids[idx - ksl_nsid] instead of a list entry
 *   numsids  number of leading entries the lookup may consider
 *   iters    number of repetitions; when > 1 the average time is printed
 *   pid      B_TRUE to exercise ksidlist_has_pid(), B_FALSE to exercise
 *            ksidlist_has_sid()
 *
 * A hit is expected exactly when idx < numsids.  Returns B_TRUE when every
 * iteration produced the expected result.
 */
boolean_t
run_test(ksidlist_t *ksl, uint32_t idx, uint32_t numsids,
    uint32_t iters, boolean_t pid)
{
        uint64_t savenum = ksl->ksl_nsid;
        uint64_t success = 0;
        hrtime_t start, end;
        uint32_t i, id;
        boolean_t expect_success = (idx < numsids);
        ksid_t *ks;

        ksl->ksl_nsid = numsids;

        start = gethrtime();
        if (pid) {
                /*
                 * Indices at or past savenum refer to bad_ksids; this must
                 * be checked first so that neither array of the sidlist is
                 * indexed beyond its savenum entries.
                 *
                 * Otherwise, only the first savenum entries are sorted,
                 * but ksl_sorted is only used when numsids > cutoff.
                 * Use ksl_sorted when idx is among them.
                 */
                if (idx >= savenum)
                        id = bad_ksids[idx - savenum].ks_id;
                else if (numsids <= ksl_bin_search_cutoff)
                        id = ksl->ksl_sids[idx].ks_id;
                else
                        id = ksl->ksl_sorted[idx]->ks_id;

                for (i = 0; i < iters; i++) {
                        success += (ksidlist_has_pid(ksl, id) ==
                            expect_success) ? 1 : 0;
                }
        } else {
                if (idx >= savenum)
                        ks = &bad_ksids[idx - savenum];
                else
                        ks = &ksl->ksl_sids[idx];

                for (i = 0; i < iters; i++) {
                        success += (ksidlist_has_sid(ksl,
                            ksid_getdomain(ks), ks->ks_rid) ==
                            expect_success) ? 1 : 0;
                }
        }
        end = gethrtime();

        /* Timing is only reported when the lookup was actually repeated */
        if (iters > 1) {
                printf("avg time to %s %s in %u sids "
                    "over %u iterations: %llu\n",
                    (expect_success) ? "find" : "not find",
                    (pid) ? "pid" : "sid", numsids, iters,
                    (unsigned long long)((end - start) / iters));
        }

        ksl->ksl_nsid = savenum;
        return (success == iters);
}

/*
 * Write the command-line synopsis for this program to stderr.
 * prog is the name to show for the program (main() passes argv[0]).
 */
void
usage(char *prog)
{
        (void) fprintf(stderr,
            "usage: %s [num sids] [num iters]\n",
            prog);
}

int
main(int argc, char *argv[])
{
        credsid_t *kcr;
        ksidlist_t *ksl;
        uint64_t num_failures = 0;
        uint32_t i, j, numsids, iters;
        boolean_t retry;

        if (argc > 1) {
                errno = 0;
                numsids = strtoul(argv[1], NULL, 0);
                if (errno != 0) {
                        fprintf(stderr, "error decoding numsids (%s): \n",
                            argv[1]);
                        usage(argv[0]);
                        perror(argv[0]);
                        return (errno);
                }
        } else {
                numsids = DEFAULT_NUMSIDS;
        }

        if (argc > 2) {
                iters = strtoul(argv[2], NULL, 0);
                if (errno != 0) {
                        fprintf(stderr, "error decoding iters (%s): \n",
                            argv[2]);
                        usage(argv[0]);
                        perror(argv[0]);
                        return (errno);
                }
        } else {
                iters = DEFAULT_ITERS;
        }

        if (numsids < 1 || iters < 1) {
                fprintf(stderr, "both numsids and iters "
                    "need to be at least 1\n");
                usage(argv[0]);
                return (2);
        }

        /* create MAX_DOMAINS random SIDs */
        for (i = 0; i < MAX_DOMAINS; i++) {
                (void) snprintf(ksid_sids[i], sizeof (ksid_sids[0]),
                    "S-1-5-21-%u-%u-%u",
                    arc4random(), arc4random(), arc4random());
        }

        /* create two unique SIDs for negative testing */
        for (j = MAX_DOMAINS; j < MAX_DOMAINS+2; j++) {
                do {
                        retry = B_FALSE;
                        (void) snprintf(ksid_sids[j], sizeof (ksid_sids[0]),
                            "S-1-5-21-%u-%u-%u",
                            arc4random(), arc4random(), arc4random());
                        for (i = 0; i < MAX_DOMAINS; i++) {
                                if (strncmp(ksid_sids[i], ksid_sids[j],
                                    sizeof (ksid_sids[0])) == 0) {
                                        retry = B_TRUE;
                                        break;
                                }
                        }
                } while (retry);
        }

        ksl = calloc(1, KSIDLIST_MEM(numsids));
        ksl->ksl_ref = 1;
        ksl->ksl_nsid = numsids;
        ksl->ksl_neid = 0;

        bad_ksids = malloc(sizeof (ksid_t) * numsids);

        /* initialize numsids random sids and ids in the sidlist */
        for (i = 0; i < numsids; i++) {
                uint32_t idx = arc4random_uniform(MAX_DOMAINS);

                ksl->ksl_sids[i].ks_id = arc4random();
                ksl->ksl_sids[i].ks_rid = arc4random();
                ksl->ksl_sids[i].ks_domain = ksid_lookupdomain(ksid_sids[idx]);
                ksl->ksl_sids[i].ks_attr = 0;
        }

        /*
         * create numsids random sids, whose sids and ids aren't in
         * the sidlist for negative testing
         */
        for (i = 0; i < numsids; i++) {
                bad_ksids[i].ks_attr = 0;
                bad_ksids[i].ks_rid = arc4random();
                bad_ksids[i].ks_domain =
                    ksid_lookupdomain(ksid_sids[MAX_DOMAINS +
                    (arc4random() % 2)]);

                do {
                        retry = B_FALSE;
                        bad_ksids[i].ks_id = arc4random();
                        for (j = 0; j < numsids; j++) {
                                if (ksl->ksl_sids[j].ks_id ==
                                    bad_ksids[i].ks_id) {
                                        retry = B_TRUE;
                                        break;
                                }
                        }
                } while (retry);
        }

        kcr = kcrsid_setsidlist(NULL, ksl);

        /* run tests */
        for (i = 1; i <= numsids; i++) {
                uint32_t s_idx = arc4random_uniform(i);
                uint32_t f_idx = numsids + i - 1;

                if (!run_test(ksl, s_idx, i, iters, B_FALSE)) {
                        fprintf(stderr, "Sid search failed unexpectedly: "
                            "numsids %u\n", i);
                        fprintf(stderr, "Bad SID: id %u rid %u domain %s\n",
                            ksl->ksl_sids[s_idx].ks_id,
                            ksl->ksl_sids[s_idx].ks_rid,
                            ksid_getdomain(&ksl->ksl_sids[s_idx]));
                        num_failures++;
                }
                if (!run_test(ksl, s_idx, i, iters, B_TRUE)) {
                        fprintf(stderr, "Pid search failed unexpectedly: "
                            "numsids %u\n", i);
                        fprintf(stderr, "Bad PID: id %u rid %u domain %s\n",
                            ksl->ksl_sorted[s_idx]->ks_id,
                            ksl->ksl_sorted[s_idx]->ks_rid,
                            ksid_getdomain(ksl->ksl_sorted[s_idx]));
                        num_failures++;
                }

                if (!run_test(ksl, f_idx, i, iters, B_FALSE)) {
                        fprintf(stderr, "Sid search succeeded unexpectedly: "
                            "numsids %u\n", i);
                        fprintf(stderr, "Bad SID: id %u rid %u domain %s\n",
                            ksl->ksl_sids[f_idx].ks_id,
                            ksl->ksl_sids[f_idx].ks_rid,
                            ksid_getdomain(&ksl->ksl_sids[f_idx]));
                        num_failures++;
                }
                if (!run_test(ksl, f_idx, i, iters, B_TRUE)) {
                        fprintf(stderr, "Pid search succeeded unexpectedly: "
                            "numsids %u\n", i);
                        fprintf(stderr, "Bad PID: id %u rid %u domain %s\n",
                            ksl->ksl_sids[f_idx].ks_id,
                            ksl->ksl_sids[f_idx].ks_rid,
                            ksid_getdomain(&ksl->ksl_sids[f_idx]));
                        num_failures++;
                }
        }

        for (i = 0; i < numsids - 1; i++) {
                if (ksl->ksl_sorted[i]->ks_id > ksl->ksl_sorted[i + 1]->ks_id) {
                        fprintf(stderr, "PID %u is not sorted correctly: "
                            "%u %u\n", i, ksl->ksl_sorted[i]->ks_id,
                            ksl->ksl_sorted[i + 1]->ks_id);
                        num_failures++;
                }
                if (ksl->ksl_sids[i].ks_rid > ksl->ksl_sids[i + 1].ks_rid) {
                        fprintf(stderr, "RID %u is not sorted correctly: "
                            "%u %u\n", i, ksl->ksl_sids[i].ks_rid,
                            ksl->ksl_sids[i + 1].ks_rid);
                        num_failures++;
                } else if (ksl->ksl_sids[i].ks_rid ==
                    ksl->ksl_sids[i + 1].ks_rid &&
                    strcmp(ksid_getdomain(&ksl->ksl_sids[i]),
                    ksid_getdomain(&ksl->ksl_sids[i + 1])) > 0) {
                        fprintf(stderr, "SID %u is not sorted correctly: "
                            "%s %s\n", i, ksl->ksl_sids[i].ks_rid,
                            ksl->ksl_sids[i + 1].ks_rid);
                        num_failures++;
                }
        }

        if (num_failures != 0) {
                fprintf(stderr, "%d failures detected; dumping SID table\n",
                    num_failures);
                for (i = 0; i < numsids; i++) {
                        fprintf(stderr, "SID %u: %s-%u -> %u\n", i,
                            ksid_getdomain(&ksl->ksl_sids[i]),
                            ksl->ksl_sids[i].ks_rid,
                            ksl->ksl_sids[i].ks_id);
                }

                for (i = 0; i < numsids; i++) {
                        fprintf(stderr, "SID %u: %s-%u -> %u\n",
                            i + numsids,
                            ksid_getdomain(&bad_ksids[i]),
                            bad_ksids[i].ks_rid,
                            bad_ksids[i].ks_id);
                }

                for (i = 0; i < numsids; i++) {
                        fprintf(stderr, "PID %u: %u\n", i,
                            ksl->ksl_sorted[i]->ks_id);
                }
        } else {
                printf("all tests completed successfully!\n");
        }

        /* Clear ksidlist. This should free the credsid_t. */
        kcr = kcrsid_setsidlist(kcr, NULL);
        /* Add a core ksid. This should allocate a credsid_t. */
        kcr = kcrsid_setsid(kcr, &ksl->ksl_sids[0], KSID_GROUP);
        /* Set the ksidlist to NULL again. This should be a no-op. */
        kcr = kcrsid_setsidlist(kcr, NULL);
        kcrsid_rele(kcr);

        return (num_failures);
}