root/tools/testing/selftests/bpf/prog_tests/btf_skc_cls_ingress.c
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2020 Facebook */

#define _GNU_SOURCE
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <sched.h>
#include <net/if.h>
#include <linux/compiler.h>
#include <bpf/libbpf.h>

#include "network_helpers.h"
#include "test_progs.h"
#include "test_btf_skc_cls_ingress.skel.h"

#define TEST_NS "skc_cls_ingress"

#define BIT(n)          (1 << (n))
#define TEST_MODE_IPV4  BIT(0)
#define TEST_MODE_IPV6  BIT(1)
#define TEST_MODE_DUAL  (TEST_MODE_IPV4 | TEST_MODE_IPV6)

#define SERVER_ADDR_IPV4        "127.0.0.1"
#define SERVER_ADDR_IPV6        "::1"
#define SERVER_ADDR_DUAL        "::0"
/* RFC791, 576 for minimal IPv4 datagram, minus 40 bytes of TCP header */
#define MIN_IPV4_MSS            536

static struct netns_obj *prepare_netns(struct test_btf_skc_cls_ingress *skel)
{
        LIBBPF_OPTS(bpf_tc_hook, qdisc_lo, .attach_point = BPF_TC_INGRESS);
        LIBBPF_OPTS(bpf_tc_opts, tc_attach,
                    .prog_fd = bpf_program__fd(skel->progs.cls_ingress));
        struct netns_obj *ns = NULL;

        ns = netns_new(TEST_NS, true);
        if (!ASSERT_OK_PTR(ns, "create and join netns"))
                return ns;

        qdisc_lo.ifindex = if_nametoindex("lo");
        if (!ASSERT_OK(bpf_tc_hook_create(&qdisc_lo), "qdisc add dev lo clsact"))
                goto free_ns;

        if (!ASSERT_OK(bpf_tc_attach(&qdisc_lo, &tc_attach),
                       "filter add dev lo ingress"))
                goto free_ns;

        /* Ensure 20 bytes options (i.e. in total 40 bytes tcp header) for the
         * bpf_tcp_gen_syncookie() helper.
         */
        if (write_sysctl("/proc/sys/net/ipv4/tcp_window_scaling", "1") ||
            write_sysctl("/proc/sys/net/ipv4/tcp_timestamps", "1") ||
            write_sysctl("/proc/sys/net/ipv4/tcp_sack", "1"))
                goto free_ns;

        return ns;

free_ns:
        netns_free(ns);
        return NULL;
}

static void reset_test(struct test_btf_skc_cls_ingress *skel)
{
        memset(&skel->bss->srv_sa4, 0, sizeof(skel->bss->srv_sa4));
        memset(&skel->bss->srv_sa6, 0, sizeof(skel->bss->srv_sa6));
        skel->bss->listen_tp_sport = 0;
        skel->bss->req_sk_sport = 0;
        skel->bss->recv_cookie = 0;
        skel->bss->gen_cookie = 0;
        skel->bss->linum = 0;
        skel->bss->mss = 0;
}

static void print_err_line(struct test_btf_skc_cls_ingress *skel)
{
        if (skel->bss->linum)
                printf("bpf prog error at line %u\n", skel->bss->linum);
}

static int v6only_true(int fd, void *opts)
{
        int mode = true;

        return setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &mode, sizeof(mode));
}

static int v6only_false(int fd, void *opts)
{
        int mode = false;

        return setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &mode, sizeof(mode));
}

static void run_test(struct test_btf_skc_cls_ingress *skel, bool gen_cookies,
                     int ip_mode)
{
        const char *tcp_syncookies = gen_cookies ? "2" : "1";
        int listen_fd = -1, cli_fd = -1, srv_fd = -1, err;
        struct network_helper_opts opts = { 0 };
        struct sockaddr_storage *addr;
        struct sockaddr_in6 srv_sa6;
        struct sockaddr_in srv_sa4;
        socklen_t addr_len;
        int sock_family;
        char *srv_addr;
        int srv_port;

        switch (ip_mode) {
        case TEST_MODE_IPV4:
                sock_family = AF_INET;
                srv_addr = SERVER_ADDR_IPV4;
                addr = (struct sockaddr_storage *)&srv_sa4;
                addr_len = sizeof(srv_sa4);
                break;
        case TEST_MODE_IPV6:
                opts.post_socket_cb = v6only_true;
                sock_family = AF_INET6;
                srv_addr = SERVER_ADDR_IPV6;
                addr = (struct sockaddr_storage *)&srv_sa6;
                addr_len = sizeof(srv_sa6);
                break;
        case TEST_MODE_DUAL:
                opts.post_socket_cb = v6only_false;
                sock_family = AF_INET6;
                srv_addr = SERVER_ADDR_DUAL;
                addr = (struct sockaddr_storage *)&srv_sa6;
                addr_len = sizeof(srv_sa6);
                break;
        default:
                PRINT_FAIL("Unknown IP mode %d", ip_mode);
                return;
        }

        if (write_sysctl("/proc/sys/net/ipv4/tcp_syncookies", tcp_syncookies))
                return;

        listen_fd = start_server_str(sock_family, SOCK_STREAM, srv_addr,  0,
                                     &opts);
        if (!ASSERT_OK_FD(listen_fd, "start server"))
                return;

        err = getsockname(listen_fd, (struct sockaddr *)addr, &addr_len);
        if (!ASSERT_OK(err, "getsockname(listen_fd)"))
                goto done;

        switch (ip_mode) {
        case TEST_MODE_IPV4:
                memcpy(&skel->bss->srv_sa4, &srv_sa4, sizeof(srv_sa4));
                srv_port = ntohs(srv_sa4.sin_port);
                break;
        case TEST_MODE_IPV6:
        case TEST_MODE_DUAL:
                memcpy(&skel->bss->srv_sa6, &srv_sa6, sizeof(srv_sa6));
                srv_port = ntohs(srv_sa6.sin6_port);
                break;
        default:
                goto done;
        }

        cli_fd = connect_to_fd(listen_fd, 0);
        if (!ASSERT_OK_FD(cli_fd, "connect client"))
                goto done;

        srv_fd = accept(listen_fd, NULL, NULL);
        if (!ASSERT_OK_FD(srv_fd, "accept connection"))
                goto done;

        ASSERT_EQ(skel->bss->listen_tp_sport, srv_port, "listen tp src port");

        if (!gen_cookies) {
                ASSERT_EQ(skel->bss->req_sk_sport, srv_port,
                          "request socket source port with syncookies disabled");
                ASSERT_EQ(skel->bss->gen_cookie, 0,
                          "generated syncookie with syncookies disabled");
                ASSERT_EQ(skel->bss->recv_cookie, 0,
                          "received syncookie with syncookies disabled");
        } else {
                ASSERT_EQ(skel->bss->req_sk_sport, 0,
                          "request socket source port with syncookies enabled");
                ASSERT_NEQ(skel->bss->gen_cookie, 0,
                           "syncookie properly generated");
                ASSERT_EQ(skel->bss->gen_cookie, skel->bss->recv_cookie,
                          "matching syncookies on client and server");
                ASSERT_GT(skel->bss->mss, MIN_IPV4_MSS,
                          "MSS in cookie min value");
                ASSERT_LT(skel->bss->mss, USHRT_MAX,
                          "MSS in cookie max value");
        }

done:
        if (listen_fd != -1)
                close(listen_fd);
        if (cli_fd != -1)
                close(cli_fd);
        if (srv_fd != -1)
                close(srv_fd);
}

static void test_conn_ipv4(struct test_btf_skc_cls_ingress *skel)
{
        run_test(skel, false, TEST_MODE_IPV4);
}

static void test_conn_ipv6(struct test_btf_skc_cls_ingress *skel)
{
        run_test(skel, false, TEST_MODE_IPV6);
}

static void test_conn_dual(struct test_btf_skc_cls_ingress *skel)
{
        run_test(skel, false, TEST_MODE_DUAL);
}

static void test_syncookie_ipv4(struct test_btf_skc_cls_ingress *skel)
{
        run_test(skel, true, TEST_MODE_IPV4);
}

static void test_syncookie_ipv6(struct test_btf_skc_cls_ingress *skel)
{
        run_test(skel, true, TEST_MODE_IPV6);
}

static void test_syncookie_dual(struct test_btf_skc_cls_ingress *skel)
{
        run_test(skel, true, TEST_MODE_DUAL);
}

struct test {
        const char *desc;
        void (*run)(struct test_btf_skc_cls_ingress *skel);
};

#define DEF_TEST(name) { #name, test_##name }
static struct test tests[] = {
        DEF_TEST(conn_ipv4),
        DEF_TEST(conn_ipv6),
        DEF_TEST(conn_dual),
        DEF_TEST(syncookie_ipv4),
        DEF_TEST(syncookie_ipv6),
        DEF_TEST(syncookie_dual),
};

void test_btf_skc_cls_ingress(void)
{
        struct test_btf_skc_cls_ingress *skel;
        struct netns_obj *ns;
        int i;

        skel = test_btf_skc_cls_ingress__open_and_load();
        if (!ASSERT_OK_PTR(skel, "test_btf_skc_cls_ingress__open_and_load"))
                return;

        for (i = 0; i < ARRAY_SIZE(tests); i++) {
                if (!test__start_subtest(tests[i].desc))
                        continue;

                ns = prepare_netns(skel);
                if (!ns)
                        break;

                tests[i].run(skel);

                print_err_line(skel);
                reset_test(skel);
                netns_free(ns);
        }

        test_btf_skc_cls_ingress__destroy(skel);
}