root/tools/testing/selftests/bpf/prog_tests/test_ldsx_insn.c
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2023 Meta Platforms, Inc. and affiliates.*/

#include <test_progs.h>
#include <network_helpers.h>
#include "test_ldsx_insn.skel.h"

/*
 * Subtest: enable rdonly_map_prog, map_val_prog and test_ptr_struct_arg,
 * load and attach the skeleton, trigger a module test read, and then verify
 * the result variables found in the skeleton's .bss section.
 */
static void test_map_val_and_probed_memory(void)
{
        struct test_ldsx_insn *obj = test_ldsx_insn__open();

        if (!ASSERT_OK_PTR(obj, "test_ldsx_insn__open"))
                return;

        if (obj->rodata->skip) {
                /* The object's read-only 'skip' flag is set: report a skip. */
                test__skip();
        } else {
                /* Only the programs exercised by this subtest get loaded. */
                bpf_program__set_autoload(obj->progs.rdonly_map_prog, true);
                bpf_program__set_autoload(obj->progs.map_val_prog, true);
                bpf_program__set_autoload(obj->progs.test_ptr_struct_arg, true);

                /* Attach is attempted only when the load succeeded. */
                if (ASSERT_OK(test_ldsx_insn__load(obj), "test_ldsx_insn__load") &&
                    ASSERT_OK(test_ldsx_insn__attach(obj), "test_ldsx_insn__attach")) {
                        ASSERT_OK(trigger_module_test_read(256), "trigger_read");

                        ASSERT_EQ(obj->bss->done1, 1, "done1");
                        ASSERT_EQ(obj->bss->ret1, 1, "ret1");
                        ASSERT_EQ(obj->bss->done2, 1, "done2");
                        ASSERT_EQ(obj->bss->ret2, 1, "ret2");
                        ASSERT_EQ(obj->bss->int_member, -1, "int_member");
                }
        }

        test_ldsx_insn__destroy(obj);
}

/*
 * Subtest: join the "/ldsx_test" cgroup, attach the _getsockopt program to
 * it, issue one getsockopt(IP_TTL) on a fresh TCP socket, and check that
 * set_optlen and set_retval in the skeleton's .bss both read back as -1.
 */
static void test_ctx_member_sign_ext(void)
{
        struct test_ldsx_insn *obj;
        char optval[16] = {0};
        socklen_t optval_len = sizeof(optval);
        int cg_fd, sock_fd;

        cg_fd = test__join_cgroup("/ldsx_test");
        if (!ASSERT_GE(cg_fd, 0, "join_cgroup /ldsx_test"))
                return;

        obj = test_ldsx_insn__open();
        if (!ASSERT_OK_PTR(obj, "test_ldsx_insn__open"))
                goto release_cgroup;

        /* The object's read-only 'skip' flag is set: report a skip. */
        if (obj->rodata->skip) {
                test__skip();
                goto release_skel;
        }

        bpf_program__set_autoload(obj->progs._getsockopt, true);

        if (!ASSERT_OK(test_ldsx_insn__load(obj), "test_ldsx_insn__load"))
                goto release_skel;

        /*
         * The link is stored in the skeleton rather than a local; this
         * function never destroys it separately from the skeleton.
         */
        obj->links._getsockopt =
                bpf_program__attach_cgroup(obj->progs._getsockopt, cg_fd);
        if (!ASSERT_OK_PTR(obj->links._getsockopt, "getsockopt_link"))
                goto release_skel;

        sock_fd = socket(AF_INET, SOCK_STREAM, 0);
        if (!ASSERT_GE(sock_fd, 0, "socket"))
                goto release_skel;

        /*
         * The syscall result is deliberately ignored; only the values in
         * .bss (presumably written by the attached program) are checked.
         */
        (void)getsockopt(sock_fd, SOL_IP, IP_TTL, optval, &optval_len);

        ASSERT_EQ(obj->bss->set_optlen, -1, "optlen");
        ASSERT_EQ(obj->bss->set_retval, -1, "retval");

        close(sock_fd);
release_skel:
        test_ldsx_insn__destroy(obj);
release_cgroup:
        close(cg_fd);
}

static void test_ctx_member_narrow_sign_ext(void)
{
        struct test_ldsx_insn *skel;
        struct __sk_buff skb = {};
        LIBBPF_OPTS(bpf_test_run_opts, topts,
                    .data_in = &pkt_v4,
                    .data_size_in = sizeof(pkt_v4),
                    .ctx_in = &skb,
                    .ctx_size_in = sizeof(skb),
        );
        int err, prog_fd;

        skel = test_ldsx_insn__open();
        if (!ASSERT_OK_PTR(skel, "test_ldsx_insn__open"))
                return;

        if (skel->rodata->skip) {
                test__skip();
                goto out;
        }

        bpf_program__set_autoload(skel->progs._tc, true);

        err = test_ldsx_insn__load(skel);
        if (!ASSERT_OK(err, "test_ldsx_insn__load"))
                goto out;

        prog_fd = bpf_program__fd(skel->progs._tc);
        err = bpf_prog_test_run_opts(prog_fd, &topts);
        ASSERT_OK(err, "test_run");

        ASSERT_EQ(skel->bss->set_mark, -2, "set_mark");

out:
        test_ldsx_insn__destroy(skel);
}

void test_ldsx_insn(void)
{
        if (test__start_subtest("map_val and probed_memory"))
                test_map_val_and_probed_memory();
        if (test__start_subtest("ctx_member_sign_ext"))
                test_ctx_member_sign_ext();
        if (test__start_subtest("ctx_member_narrow_sign_ext"))
                test_ctx_member_narrow_sign_ext();
}