root/tools/testing/selftests/bpf/prog_tests/test_task_local_data.c
// SPDX-License-Identifier: GPL-2.0
#include <pthread.h>
#include <bpf/btf.h>
#include <test_progs.h>

#define TLD_FREE_DATA_ON_THREAD_EXIT
#define TLD_DYN_DATA_SIZE (getpagesize() - 8)
#include "task_local_data.h"

/*
 * Multi-field payload stored in the "value2" TLD. The threads of the basic
 * subtest fill a..d with distinct per-thread numbers and compare them with
 * the copy that the BPF program leaves in skel->bss->test_value2.
 */
struct test_tld_struct {
        __u64 a;
        __u64 b;
        __u64 c;
        __u64 d;
};

#include "test_task_local_data.skel.h"

/*
 * Statically defined key for the int-sized TLD "value0". It is the only TLD
 * that reset_tld() keeps; all other keys are created with tld_create_key().
 */
TLD_DEFINE_KEY(value0_key, "value0", sizeof(int));

/*
 * Bring task local data back to its initial state between subtests: every
 * TLD created through tld_create_key() is dropped and only the statically
 * defined value0 is kept. Touching the metadata directly is safe here only
 * because subtests run sequentially. Users of the task local data library
 * should never modify library internals like this.
 */
static void reset_tld(void)
{
        /* Nothing to reset while tld_meta_p is still unset */
        if (!TLD_READ_ONCE(tld_meta_p))
                return;

        /* Keep a single entry (value0) ... */
        tld_meta_p->cnt = 1;
        tld_meta_p->size = TLD_DYN_DATA_SIZE;

        /* ... and wipe the metadata of every entry after it */
        memset(&tld_meta_p->metadata[1], 0,
               sizeof(struct tld_metadata) * (TLD_MAX_DATA_CNT - 1));
}

/*
 * Serializes access to the BPF program's global variables (skel->bss): a
 * thread holds it from running task_main until it has finished checking
 * the values the program stored there.
 */
static pthread_mutex_t global_mutex;

/*
 * Keys of the TLDs under test. Each subtest allocates TLD_MAX_DATA_CNT
 * entries and frees them before returning. Worker threads read (and, in the
 * race subtest, write) the array directly.
 */
static tld_key_t *tld_keys;

/* Number of worker threads spawned by the basic subtest */
#define TEST_BASIC_THREAD_NUM 32

void *test_task_local_data_basic_thread(void *arg)
{
        LIBBPF_OPTS(bpf_test_run_opts, opts);
        struct test_task_local_data *skel = (struct test_task_local_data *)arg;
        int fd, err, tid, *value0, *value1;
        struct test_tld_struct *value2;

        fd = bpf_map__fd(skel->maps.tld_data_map);

        value0 = tld_get_data(fd, value0_key);
        if (!ASSERT_OK_PTR(value0, "tld_get_data"))
                goto out;

        value1 = tld_get_data(fd, tld_keys[1]);
        if (!ASSERT_OK_PTR(value1, "tld_get_data"))
                goto out;

        value2 = tld_get_data(fd, tld_keys[2]);
        if (!ASSERT_OK_PTR(value2, "tld_get_data"))
                goto out;

        tid = sys_gettid();

        *value0 = tid + 0;
        *value1 = tid + 1;
        value2->a = tid + 2;
        value2->b = tid + 3;
        value2->c = tid + 4;
        value2->d = tid + 5;

        pthread_mutex_lock(&global_mutex);
        /* Run task_main that read task local data and save to global variables */
        err = bpf_prog_test_run_opts(bpf_program__fd(skel->progs.task_main), &opts);
        ASSERT_OK(err, "run task_main");
        ASSERT_OK(opts.retval, "task_main retval");

        ASSERT_EQ(skel->bss->test_value0, tid + 0, "tld_get_data value0");
        ASSERT_EQ(skel->bss->test_value1, tid + 1, "tld_get_data value1");
        ASSERT_EQ(skel->bss->test_value2.a, tid + 2, "tld_get_data value2.a");
        ASSERT_EQ(skel->bss->test_value2.b, tid + 3, "tld_get_data value2.b");
        ASSERT_EQ(skel->bss->test_value2.c, tid + 4, "tld_get_data value2.c");
        ASSERT_EQ(skel->bss->test_value2.d, tid + 5, "tld_get_data value2.d");
        pthread_mutex_unlock(&global_mutex);

        /* Make sure valueX are indeed local to threads */
        ASSERT_EQ(*value0, tid + 0, "value0");
        ASSERT_EQ(*value1, tid + 1, "value1");
        ASSERT_EQ(value2->a, tid + 2, "value2.a");
        ASSERT_EQ(value2->b, tid + 3, "value2.b");
        ASSERT_EQ(value2->c, tid + 4, "value2.c");
        ASSERT_EQ(value2->d, tid + 5, "value2.d");

        *value0 = tid + 5;
        *value1 = tid + 4;
        value2->a = tid + 3;
        value2->b = tid + 2;
        value2->c = tid + 1;
        value2->d = tid + 0;

        /* Run task_main again */
        pthread_mutex_lock(&global_mutex);
        err = bpf_prog_test_run_opts(bpf_program__fd(skel->progs.task_main), &opts);
        ASSERT_OK(err, "run task_main");
        ASSERT_OK(opts.retval, "task_main retval");

        ASSERT_EQ(skel->bss->test_value0, tid + 5, "tld_get_data value0");
        ASSERT_EQ(skel->bss->test_value1, tid + 4, "tld_get_data value1");
        ASSERT_EQ(skel->bss->test_value2.a, tid + 3, "tld_get_data value2.a");
        ASSERT_EQ(skel->bss->test_value2.b, tid + 2, "tld_get_data value2.b");
        ASSERT_EQ(skel->bss->test_value2.c, tid + 1, "tld_get_data value2.c");
        ASSERT_EQ(skel->bss->test_value2.d, tid + 0, "tld_get_data value2.d");
        pthread_mutex_unlock(&global_mutex);

out:
        pthread_exit(NULL);
}

/*
 * Basic subtest.
 *
 * Creates value1 and value2 next to the statically defined value0, checks
 * the error paths of tld_create_key() (-E2BIG, -EEXIST, -ENOSPC), then
 * starts TEST_BASIC_THREAD_NUM threads that verify the TLDs are
 * thread-specific.
 *
 * Only threads that were actually created are joined, and global_mutex is
 * destroyed again before returning.
 */
static void test_task_local_data_basic(void)
{
        struct test_task_local_data *skel;
        pthread_t thread[TEST_BASIC_THREAD_NUM];
        char dummy_key_name[TLD_NAME_LEN];
        int i, err, nr_threads = 0;
        tld_key_t key;

        reset_tld();

        /* The worker threads lock this mutex; it must be valid to continue */
        if (!ASSERT_OK(pthread_mutex_init(&global_mutex, NULL), "pthread_mutex_init"))
                return;

        skel = test_task_local_data__open_and_load();
        if (!ASSERT_OK_PTR(skel, "skel_open_and_load"))
                goto out_mutex;

        tld_keys = calloc(TLD_MAX_DATA_CNT, sizeof(tld_key_t));
        if (!ASSERT_OK_PTR(tld_keys, "calloc tld_keys"))
                goto out;

        ASSERT_FALSE(tld_key_is_err(value0_key), "TLD_DEFINE_KEY");
        tld_keys[1] = tld_create_key("value1", sizeof(int));
        ASSERT_FALSE(tld_key_is_err(tld_keys[1]), "tld_create_key");
        tld_keys[2] = tld_create_key("value2", sizeof(struct test_tld_struct));
        ASSERT_FALSE(tld_key_is_err(tld_keys[2]), "tld_create_key");

        /*
         * Storing data that exceeds a page must fail. Create a TLD just big
         * enough to exceed a page. TLDs already created are int value0, int
         * value1, and struct test_tld_struct value2.
         */
        key = tld_create_key("value_not_exist",
                             TLD_PAGE_SIZE - 2 * sizeof(int) - sizeof(struct test_tld_struct) + 1);
        ASSERT_EQ(tld_key_err_or_zero(key), -E2BIG, "tld_create_key");

        /* A name that is already taken must be rejected */
        key = tld_create_key("value2", sizeof(struct test_tld_struct));
        ASSERT_EQ(tld_key_err_or_zero(key), -EEXIST, "tld_create_key");

        /* Shouldn't be able to create the (TLD_MAX_DATA_CNT+1)-th TLD */
        for (i = 3; i < TLD_MAX_DATA_CNT; i++) {
                snprintf(dummy_key_name, sizeof(dummy_key_name), "dummy_value%d", i);
                tld_keys[i] = tld_create_key(dummy_key_name, sizeof(int));
                ASSERT_FALSE(tld_key_is_err(tld_keys[i]), "tld_create_key");
        }
        key = tld_create_key("value_not_exist", sizeof(struct test_tld_struct));
        ASSERT_EQ(tld_key_err_or_zero(key), -ENOSPC, "tld_create_key");

        /* Access TLDs from multiple threads and check if they are thread-specific */
        for (i = 0; i < TEST_BASIC_THREAD_NUM; i++) {
                err = pthread_create(&thread[i], NULL, test_task_local_data_basic_thread, skel);
                if (!ASSERT_OK(err, "pthread_create"))
                        break;
                nr_threads++;
        }

out:
        /*
         * Join only the threads that were successfully created; the rest of
         * thread[] is uninitialized and must not be passed to pthread_join().
         */
        for (i = 0; i < nr_threads; i++)
                pthread_join(thread[i], NULL);

        free(tld_keys);
        tld_keys = NULL;
        tld_free();
        test_task_local_data__destroy(skel);
out_mutex:
        /* All workers are joined by now, so the mutex is unlocked and unused */
        pthread_mutex_destroy(&global_mutex);
}

/*
 * One racing thread per tld_keys[] slot from 3 to TLD_MAX_DATA_CNT - 1.
 * Slots 0, 1 and 2 are taken by value0, value1 and value2.
 */
#define TEST_RACE_THREAD_NUM (TLD_MAX_DATA_CNT - 3)

/*
 * Worker of the race subtest.
 *
 * @arg: index into tld_keys[] owned by this thread, passed as an integer
 *       cast to a pointer.
 *
 * Issues a mix of invalid, contended and unique tld_create_key() calls so
 * that they can race with the same calls made by the sibling threads.
 *
 * Returns 0 on success, 1 if the oversized TLD was not rejected with
 * -E2BIG, or 2 if the thread's own key could not be created (the value is
 * returned as an integer cast to a pointer).
 */
void *test_task_local_data_race_thread(void *arg)
{
        int slot = (intptr_t)arg;
        char name[32];
        tld_key_t k;

        /* A TLD larger than a page must be refused no matter who races */
        k = tld_create_key("value_not_exist", TLD_PAGE_SIZE + 1);
        if (tld_key_err_or_zero(k) != -E2BIG)
                return (void *)(intptr_t)1;

        /* All threads try to create value1; only the winner records the key */
        k = tld_create_key("value1", sizeof(int));
        if (!tld_key_is_err(k))
                tld_keys[1] = k;

        /* Same for value2 */
        k = tld_create_key("value2", sizeof(struct test_tld_struct));
        if (!tld_key_is_err(k))
                tld_keys[2] = k;

        /* The per-thread key has a unique name and must always be created */
        snprintf(name, sizeof(name), "thread_%d", slot);
        tld_keys[slot] = tld_create_key(name, sizeof(int));

        return (void *)(intptr_t)(tld_key_is_err(tld_keys[slot]) ? 2 : 0);
}

/*
 * Race subtest.
 *
 * Repeats 100 times: reset the TLD metadata, let TEST_RACE_THREAD_NUM
 * threads call tld_create_key() concurrently, then verify from this thread
 * that every key in tld_keys[] maps to its own storage and that task_main
 * still runs successfully.
 *
 * Every thread that was created is joined before the keys are used or the
 * metadata is reset again. The test stops early if a thread could not be
 * created, because tld_keys[] is then incomplete.
 */
static void test_task_local_data_race(void)
{
        LIBBPF_OPTS(bpf_test_run_opts, opts);
        pthread_t thread[TEST_RACE_THREAD_NUM];
        struct test_task_local_data *skel;
        int fd, i, j, err, nr_threads, *data;
        void *ret;

        skel = test_task_local_data__open_and_load();
        if (!ASSERT_OK_PTR(skel, "skel_open_and_load"))
                return;

        tld_keys = calloc(TLD_MAX_DATA_CNT, sizeof(tld_key_t));
        if (!ASSERT_OK_PTR(tld_keys, "calloc tld_keys"))
                goto out;

        fd = bpf_map__fd(skel->maps.tld_data_map);

        ASSERT_FALSE(tld_key_is_err(value0_key), "TLD_DEFINE_KEY");
        tld_keys[0] = value0_key;

        for (j = 0; j < 100; j++) {
                reset_tld();

                /*
                 * Try to make tld_create_key() race with each other. Call
                 * tld_create_key(), both valid and invalid, from different threads.
                 * Thread n owns tld_keys[n + 3].
                 */
                for (nr_threads = 0; nr_threads < TEST_RACE_THREAD_NUM; nr_threads++) {
                        err = pthread_create(&thread[nr_threads], NULL,
                                             test_task_local_data_race_thread,
                                             (void *)(intptr_t)(nr_threads + 3));
                        if (CHECK_FAIL(err))
                                break;
                }

                /*
                 * Wait for all tld_create_key() to return. Join exactly the
                 * threads that were created and never stop early: a thread
                 * left running would race with the next reset_tld().
                 */
                for (i = 0; i < nr_threads; i++) {
                        ret = NULL;
                        err = pthread_join(thread[i], &ret);
                        CHECK_FAIL(err || ret);
                }

                /* Some keys were never created; nothing meaningful left to check */
                if (nr_threads < TEST_RACE_THREAD_NUM)
                        break;

                /* Write a unique number to each TLD */
                for (i = 0; i < TLD_MAX_DATA_CNT; i++) {
                        data = tld_get_data(fd, tld_keys[i]);
                        if (CHECK_FAIL(!data))
                                break;
                        *data = i;
                }

                /* Read TLDs and check the value to see if any address collides with another */
                for (i = 0; i < TLD_MAX_DATA_CNT; i++) {
                        data = tld_get_data(fd, tld_keys[i]);
                        /* A failed lookup is a test failure, not a NULL dereference */
                        if (CHECK_FAIL(!data || *data != i))
                                break;
                }

                /* Run task_main to make sure no invalid TLDs are added */
                err = bpf_prog_test_run_opts(bpf_program__fd(skel->progs.task_main), &opts);
                ASSERT_OK(err, "run task_main");
                ASSERT_OK(opts.retval, "task_main retval");
        }
out:
        free(tld_keys);
        tld_keys = NULL;
        tld_free();
        test_task_local_data__destroy(skel);
}

void test_task_local_data(void)
{
        if (test__start_subtest("task_local_data_basic"))
                test_task_local_data_basic();
        if (test__start_subtest("task_local_data_race"))
                test_task_local_data_race();
}