root/tools/testing/selftests/bpf/test_sockmap.c
// SPDX-License-Identifier: GPL-2.0
// Copyright (c) 2017-2018 Covalent IO, Inc. http://covalent.io
#include <stdio.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <sys/select.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <string.h>
#include <errno.h>
#include <stdbool.h>
#include <signal.h>
#include <fcntl.h>
#include <sys/wait.h>
#include <time.h>
#include <sched.h>

#include <sys/time.h>
#include <sys/types.h>
#include <sys/sendfile.h>

#include <linux/netlink.h>
#include <linux/socket.h>
#include <linux/sock_diag.h>
#include <linux/bpf.h>
#include <linux/if_link.h>
#include <linux/tls.h>
#include <assert.h>
#include <libgen.h>

#include <getopt.h>

#include <bpf/bpf.h>
#include <bpf/libbpf.h>

#include "bpf_util.h"
#include "cgroup_helpers.h"

/* Loop-control flag; presumably cleared by running_handler() (declared here,
 * defined later in the file, likely installed as a signal handler) — TODO
 * confirm against the definition.
 */
int running;
static void running_handler(int a);

/* Fallbacks for uapi headers that do not define the kTLS constants. */
#ifndef TCP_ULP
# define TCP_ULP 31
#endif
#ifndef SOL_TLS
# define SOL_TLS 282
#endif

/* randomly selected ports for testing on lo */
#define S1_PORT 10000
#define S2_PORT 10001

/* BPF object file names for the two map flavours. test_end_subtest() compares
 * the selected object (opt->map) against BPF_SOCKMAP_FILENAME to label output
 * as "sockmap" or "sockhash".
 */
#define BPF_SOCKMAP_FILENAME  "test_sockmap_kern.bpf.o"
#define BPF_SOCKHASH_FILENAME "test_sockhash_kern.bpf.o"
#define CG_PATH "/sockmap"

/* Private error code returned negated (-EDATAINTEGRITY) by msg_verify_data()
 * on payload mismatch; presumably picked so it cannot collide with a real
 * errno value.
 */
#define EDATAINTEGRITY 2001

/* global sockets: s1/s2 listen on S1_PORT/S2_PORT, c1/c2 are the clients
 * connected to them and p1/p2 the accepted peers (see sockmap_init_sockets()).
 */
int s1, s2, c1, c2, p1, p2;
/* Counters not referenced in the visible code; per-test results are kept in
 * env (struct test_env) — TODO confirm these are still used.
 */
int test_cnt;
int passed;
int failed;
/* libbpf map fds and object/program/link handles (up to 9 each); populated
 * outside the code shown here, presumably by the BPF loading setup.
 */
int map_fd[9];
struct bpf_map *maps[9];
struct bpf_program *progs[9];
struct bpf_link *links[9];

/* Test knobs, set from the command line (see long_options) or by individual
 * tests. test_reset() zeroes every txmsg_* knob and skb_use_parser between
 * subtests; ktls and peek_flag are not reset there.
 */
int txmsg_pass;
int txmsg_redir;
int txmsg_drop;
int txmsg_apply;
int txmsg_cork;
int txmsg_start;
int txmsg_end;
int txmsg_start_push;
int txmsg_end_push;
int txmsg_start_pop;
int txmsg_pop;
int txmsg_ingress;
int txmsg_redir_skb;
int txmsg_ktls_skb;
int txmsg_ktls_skb_drop;
int txmsg_ktls_skb_redir;
int ktls;
int peek_flag;
int skb_use_parser;
int txmsg_omit_skb_parser;
/* Expected push/pop windows for payload verification: computed by
 * msg_verify_date_prep() and consumed by msg_verify_data().
 */
int verify_push_start;
int verify_push_len;
int verify_pop_start;
int verify_pop_len;

/*
 * getopt_long() option table, also printed by usage().
 * Entries with a non-NULL flag pointer are plain switches: getopt_long()
 * stores their val (1) straight into the named global. The other entries
 * return their val character to the option-parsing switch.
 */
static const struct option long_options[] = {
        {"help",        no_argument,            NULL, 'h' },
        {"cgroup",      required_argument,      NULL, 'c' },
        {"rate",        required_argument,      NULL, 'r' },
        {"verbose",     optional_argument,      NULL, 'v' },
        {"iov_count",   required_argument,      NULL, 'i' },
        {"length",      required_argument,      NULL, 'l' },
        {"test",        required_argument,      NULL, 't' },
        {"data_test",   no_argument,            NULL, 'd' },
        {"txmsg",               no_argument,    &txmsg_pass,  1  },
        {"txmsg_redir",         no_argument,    &txmsg_redir, 1  },
        {"txmsg_drop",          no_argument,    &txmsg_drop, 1 },
        {"txmsg_apply", required_argument,      NULL, 'a'},
        {"txmsg_cork",  required_argument,      NULL, 'k'},
        {"txmsg_start", required_argument,      NULL, 's'},
        {"txmsg_end",   required_argument,      NULL, 'e'},
        {"txmsg_start_push", required_argument, NULL, 'p'},
        {"txmsg_end_push",   required_argument, NULL, 'q'},
        {"txmsg_start_pop",  required_argument, NULL, 'w'},
        {"txmsg_pop",        required_argument, NULL, 'x'},
        {"txmsg_ingress", no_argument,          &txmsg_ingress, 1 },
        {"txmsg_redir_skb", no_argument,        &txmsg_redir_skb, 1 },
        {"ktls", no_argument,                   &ktls, 1 },
        {"peek", no_argument,                   &peek_flag, 1 },
        {"txmsg_omit_skb_parser", no_argument,      &txmsg_omit_skb_parser, 1},
        {"whitelist", required_argument,        NULL, 'n' },
        {"blacklist", required_argument,        NULL, 'b' },
        {0, 0, NULL, 0 }
};

/* Bookkeeping for the currently running test plus overall pass/fail totals. */
struct test_env {
        const char *type;       /* BPF object file name of the map flavour (opt->map) */
        const char *subtest;    /* title of the current test */
        const char *prepend;    /* optional label printed before the title */

        int test_num;           /* bumped by test_start_subtest() */
        int subtest_num;        /* reset per test, bumped by test_start() */

        int succ_cnt;           /* tests that finished without new failures */
        int fail_cnt;           /* total failures recorded via test_fail() */
        int fail_last;          /* fail_cnt snapshot taken when the test began */
};

struct test_env env;

/* Per-run parameters shared by sendmsg_test() and msg_loop(). */
struct sockmap_options {
        int verbose;            /* verbosity level; >1 prints per-loop stats */
        bool base;              /* true: receive on p1, false: receive on p2 */
        bool sendpage;          /* transmit with sendfile() instead of sendmsg() */
        bool data_test;         /* fill payload with a pattern and verify it on rx */
        bool drop_expected;     /* sends are expected to fail; rx child exits at once */
        bool check_recved_len;  /* fail if more bytes arrive than expected */
        bool tx_wait_mem;       /* shrink c2/p2 buffers to 1024B; tx tolerates EACCES */
        int iov_count;          /* iovecs per sendmsg() */
        int iov_length;         /* bytes per iovec (per sendfile() when sendpage) */
        int rate;               /* number of send iterations (cnt in sendmsg_test) */
        char *map;              /* BPF object file name, compared to BPF_SOCKMAP_FILENAME */
        char *whitelist;        /* presumably test-name filters; not used in this view */
        char *blacklist;        /* presumably test-name filters; not used in this view */
        char *prepend;          /* label copied into env.prepend for output */
};

/* A named test case. tester() runs it; cg_fd is presumably the cgroup fd the
 * BPF programs are attached to (TODO confirm at the call site).
 */
struct _test {
        char *title;
        void (*tester)(int cg_fd, struct sockmap_options *opt);
};

/* Count one more subtest within the test that is currently running. */
static void test_start(void)
{
        env.subtest_num += 1;
}

/* Record one failure; test_end_subtest() compares against env.fail_last. */
static void test_fail(void)
{
        env.fail_cnt += 1;
}

/* Record one test that finished without new failures. */
static void test_pass(void)
{
        env.succ_cnt += 1;
}

/*
 * Zero every txmsg_* knob plus skb_use_parser so each subtest starts from a
 * clean configuration. ktls and peek_flag are not touched here.
 */
static void test_reset(void)
{
        txmsg_pass = 0;
        txmsg_redir = 0;
        txmsg_drop = 0;
        txmsg_apply = 0;
        txmsg_cork = 0;
        txmsg_start = 0;
        txmsg_end = 0;
        txmsg_start_push = 0;
        txmsg_end_push = 0;
        txmsg_start_pop = 0;
        txmsg_pop = 0;
        txmsg_ingress = 0;
        txmsg_redir_skb = 0;
        txmsg_ktls_skb = 0;
        txmsg_ktls_skb_drop = 0;
        txmsg_ktls_skb_redir = 0;
        txmsg_omit_skb_parser = 0;
        skb_use_parser = 0;
}

/*
 * Begin a new numbered test.
 *
 * The function:
 *   - resets the txmsg knobs;
 *   - restarts subtest numbering;
 *   - snapshots the failure count so test_end_subtest() can tell whether this
 *     test failed;
 *   - records the map flavour, label and title in env.
 *
 * Always returns 0.
 */
static int test_start_subtest(const struct _test *t, struct sockmap_options *o)
{
        test_reset();

        env.test_num += 1;
        env.subtest_num = 0;
        env.fail_last = env.fail_cnt;

        env.type = o->map;
        env.prepend = o->prepend;
        env.subtest = t->title;

        return 0;
}

/*
 * Finish the current test.
 *
 * If no failures were recorded since test_start_subtest(), the test is
 * counted as passed. A one-line OK/FAIL summary tagged with the map flavour
 * is printed to stdout.
 */
static void test_end_subtest(void)
{
        bool is_sockmap = strcmp(env.type, BPF_SOCKMAP_FILENAME) == 0;
        int error = env.fail_cnt - env.fail_last;
        const char *flavour = is_sockmap ? "sockmap" : "sockhash";
        const char *label = env.prepend ? env.prepend : "";
        const char *verdict = error ? "FAIL" : "OK";

        if (error == 0)
                test_pass();

        fprintf(stdout, "#%2d/%2d %8s:%s:%s:%s\n",
                env.test_num, env.subtest_num,
                flavour, label, env.subtest, verdict);
}

/* Print the overall pass/fail totals to stdout. */
static void test_print_results(void)
{
        printf("Pass: %d Fail: %d\n", env.succ_cnt, env.fail_cnt);
}

/*
 * Print command-line help built from long_options. Flag options show the
 * current value of their backing global; the others show their short letter.
 */
static void usage(char *argv[])
{
        const struct option *opt;

        printf(" Usage: %s --cgroup <cgroup_path>\n", argv[0]);
        printf(" options:\n");
        for (opt = long_options; opt->name; opt++) {
                printf(" --%-12s", opt->name);
                if (opt->flag)
                        printf(" flag (internal value:%d)\n", *opt->flag);
                else
                        printf(" -%c\n", opt->val);
        }
        printf("\n");
}

/*
 * Map a socket fd to a human-readable role name for log messages.
 * The first match wins, checked in the order c1, c2, s1, s2, p1, p2.
 * Returns "unknown" for fds that are not one of the global sockets.
 */
char *sock_to_string(int s)
{
        const struct {
                int fd;
                char *name;
        } known[] = {
                { c1, "client1" },
                { c2, "client2" },
                { s1, "server1" },
                { s2, "server2" },
                { p1, "peer1" },
                { p2, "peer2" },
        };
        size_t n;

        for (n = 0; n < sizeof(known) / sizeof(known[0]); n++) {
                if (known[n].fd == s)
                        return known[n].name;
        }
        return "unknown";
}

static int sockmap_init_ktls(int verbose, int s)
{
        struct tls12_crypto_info_aes_gcm_128 tls_tx = {
                .info = {
                        .version     = TLS_1_2_VERSION,
                        .cipher_type = TLS_CIPHER_AES_GCM_128,
                },
        };
        struct tls12_crypto_info_aes_gcm_128 tls_rx = {
                .info = {
                        .version     = TLS_1_2_VERSION,
                        .cipher_type = TLS_CIPHER_AES_GCM_128,
                },
        };
        int so_buf = 6553500;
        int err;

        err = setsockopt(s, 6, TCP_ULP, "tls", sizeof("tls"));
        if (err) {
                fprintf(stderr, "setsockopt: TCP_ULP(%s) failed with error %i\n", sock_to_string(s), err);
                return -EINVAL;
        }
        err = setsockopt(s, SOL_TLS, TLS_TX, (void *)&tls_tx, sizeof(tls_tx));
        if (err) {
                fprintf(stderr, "setsockopt: TLS_TX(%s) failed with error %i\n", sock_to_string(s), err);
                return -EINVAL;
        }
        err = setsockopt(s, SOL_TLS, TLS_RX, (void *)&tls_rx, sizeof(tls_rx));
        if (err) {
                fprintf(stderr, "setsockopt: TLS_RX(%s) failed with error %i\n", sock_to_string(s), err);
                return -EINVAL;
        }
        err = setsockopt(s, SOL_SOCKET, SO_SNDBUF, &so_buf, sizeof(so_buf));
        if (err) {
                fprintf(stderr, "setsockopt: (%s) failed sndbuf with error %i\n", sock_to_string(s), err);
                return -EINVAL;
        }
        err = setsockopt(s, SOL_SOCKET, SO_RCVBUF, &so_buf, sizeof(so_buf));
        if (err) {
                fprintf(stderr, "setsockopt: (%s) failed rcvbuf with error %i\n", sock_to_string(s), err);
                return -EINVAL;
        }

        if (verbose)
                fprintf(stdout, "socket(%s) kTLS enabled\n", sock_to_string(s));
        return 0;
}
/* perror() `what` and return the errno that caused the failure (saved first). */
static int sockmap_report_errno(const char *what)
{
        int err = errno;

        perror(what);
        return err;
}

/*
 * Create two loopback TCP connections and store all six fds in the globals:
 *   c1 -> s1 (127.0.0.1:S1_PORT), accepted as p1
 *   c2 -> s2 (127.0.0.1:S2_PORT), accepted as p2
 * The listening sockets s1/s2 get SO_REUSEADDR and are made non-blocking.
 *
 * Returns 0 on success or the errno of the first failing call. Sockets
 * created before a failure are left open in the globals.
 */
static int sockmap_init_sockets(int verbose)
{
        static const char *const socket_err[] = {
                "socket s1 failed()", "socket s2 failed()",
                "socket c1 failed()", "socket c2 failed()",
        };
        static const char *const ioctl_err[] = {
                "ioctl s1 failed()", "ioctl s2 failed()",
        };
        int *fds[4] = {&s1, &s2, &c1, &c2};
        struct sockaddr_in addr;
        int i, err, one = 1;

        s1 = s2 = p1 = p2 = c1 = c2 = 0;

        /* Init sockets */
        for (i = 0; i < 4; i++) {
                *fds[i] = socket(AF_INET, SOCK_STREAM, 0);
                if (*fds[i] < 0)
                        return sockmap_report_errno(socket_err[i]);
        }

        /* Allow reuse of the server ports (s1, s2) */
        for (i = 0; i < 2; i++) {
                err = setsockopt(*fds[i], SOL_SOCKET, SO_REUSEADDR,
                                 (char *)&one, sizeof(one));
                if (err)
                        return sockmap_report_errno("setsockopt failed()");
        }

        /* Non-blocking server sockets (s1, s2) */
        for (i = 0; i < 2; i++) {
                err = ioctl(*fds[i], FIONBIO, (char *)&one);
                if (err < 0)
                        return sockmap_report_errno(ioctl_err[i]);
        }

        /* Bind server sockets */
        memset(&addr, 0, sizeof(struct sockaddr_in));
        addr.sin_family = AF_INET;
        addr.sin_addr.s_addr = inet_addr("127.0.0.1");

        addr.sin_port = htons(S1_PORT);
        if (bind(s1, (struct sockaddr *)&addr, sizeof(addr)) < 0)
                return sockmap_report_errno("bind s1 failed()");

        addr.sin_port = htons(S2_PORT);
        if (bind(s2, (struct sockaddr *)&addr, sizeof(addr)) < 0)
                return sockmap_report_errno("bind s2 failed()");

        /* Listen server sockets */
        if (listen(s1, 32) < 0)
                return sockmap_report_errno("listen s1 failed()");
        if (listen(s2, 32) < 0)
                return sockmap_report_errno("listen s2 failed()");

        /* Initiate connections; EINPROGRESS is tolerated */
        addr.sin_port = htons(S1_PORT);
        if (connect(c1, (struct sockaddr *)&addr, sizeof(addr)) < 0 &&
            errno != EINPROGRESS)
                return sockmap_report_errno("connect c1 failed()");

        addr.sin_port = htons(S2_PORT);
        if (connect(c2, (struct sockaddr *)&addr, sizeof(addr)) < 0 &&
            errno != EINPROGRESS)
                return sockmap_report_errno("connect c2 failed()");

        /* Accept connections */
        p1 = accept(s1, NULL, NULL);
        if (p1 < 0)
                return sockmap_report_errno("accept s1 failed()");

        p2 = accept(s2, NULL, NULL);
        if (p2 < 0)
                return sockmap_report_errno("accept s2 failed()");

        if (verbose > 1) {
                printf("connected sockets: c1 <-> p1, c2 <-> p2\n");
                printf("cgroups binding: c1(%i) <-> s1(%i) - - - c2(%i) <-> s2(%i)\n",
                        c1, s1, c2, s2);
        }
        return 0;
}

/* Byte counters and CLOCK_MONOTONIC timestamps for one side of a test run. */
struct msg_stats {
        size_t bytes_sent;      /* sum of successful send sizes (tx side) */
        size_t bytes_recvd;     /* sum of recvmsg() sizes (rx side) */
        struct timespec start;  /* taken before the send/recv loop */
        struct timespec end;    /* taken after the loop (also on some error paths) */
};

/*
 * Transmit side of the sendpage variant: write cnt chunks of iov_length
 * pattern bytes (0, 1, 2, ... restarting at 0 for every chunk) to a tmpfile,
 * then sendfile() one chunk per iteration to fd, accumulating s->bytes_sent.
 *
 * If opt->drop_expected is set, every sendfile() must fail; a successful one
 * is reported as an error (-EIO).
 *
 * Returns:
 *   0       on success;
 *   1       if the tmpfile cannot be created or written;
 *   the negative sendfile() result on an unexpected send failure;
 *   -EIO    when a send unexpectedly succeeds.
 */
static int msg_loop_sendpage(int fd, int iov_length, int cnt,
                             struct msg_stats *s,
                             struct sockmap_options *opt)
{
        bool drop = opt->drop_expected;
        unsigned char k = 0;
        int i, j, fp;
        FILE *file;

        file = tmpfile();
        if (!file) {
                perror("create file for sendpage");
                return 1;
        }
        for (i = 0; i < cnt; i++, k = 0) {
                for (j = 0; j < iov_length; j++, k++) {
                        /* A short file would make sendfile() send less than
                         * expected and the rx side time out confusingly.
                         */
                        if (fwrite(&k, sizeof(char), 1, file) != 1) {
                                perror("write file for sendpage");
                                fclose(file);
                                return 1;
                        }
                }
        }
        /* Push buffered data to the fd and rewind it: sendfile() with a NULL
         * offset reads from the fd's current file offset.
         */
        if (fflush(file) || fseek(file, 0, SEEK_SET)) {
                perror("rewind file for sendpage");
                fclose(file);
                return 1;
        }

        fp = fileno(file);

        clock_gettime(CLOCK_MONOTONIC, &s->start);
        for (i = 0; i < cnt; i++) {
                int sent;

                errno = 0;
                sent = sendfile(fd, fp, NULL, iov_length);

                if (!drop && sent < 0) {
                        perror("sendpage loop error");
                        fclose(file);
                        return sent;
                } else if (drop && sent >= 0) {
                        printf("sendpage loop error expected: %i errno %i\n",
                               sent, errno);
                        fclose(file);
                        return -EIO;
                }

                if (sent > 0)
                        s->bytes_sent += sent;
        }
        clock_gettime(CLOCK_MONOTONIC, &s->end);
        fclose(file);
        return 0;
}

/*
 * Release every iovec buffer and the iovec array owned by msg, then leave
 * msg with no iovecs so a repeated call is harmless.
 */
static void msg_free_iov(struct msghdr *msg)
{
        struct iovec *vec = msg->msg_iov;
        size_t idx;

        for (idx = 0; idx < msg->msg_iovlen; idx++)
                free(vec[idx].iov_base);
        free(vec);

        msg->msg_iovlen = 0;
        msg->msg_iov = NULL;
}

/*
 * Allocate iov_count zeroed buffers of iov_length bytes each and attach them
 * to msg (msg_iov/msg_iovlen).
 *
 * When both data and xmit are set, the buffers are filled with a byte
 * counter (0, 1, 2, ...) that runs continuously across all iovecs and wraps
 * at 256. This is the pattern msg_verify_data() checks on the receive side.
 *
 * Returns 0 on success. On failure nothing is leaked and msg is left
 * untouched. The error is errno from calloc() if the iovec array cannot be
 * allocated, or -ENOMEM if a buffer cannot be allocated.
 */
static int msg_alloc_iov(struct msghdr *msg,
                         int iov_count, int iov_length,
                         bool data, bool xmit)
{
        unsigned char k = 0;
        struct iovec *iov;
        int i;

        iov = calloc(iov_count, sizeof(struct iovec));
        if (!iov)
                return errno;

        for (i = 0; i < iov_count; i++) {
                unsigned char *d = calloc(iov_length, sizeof(char));

                /* calloc(0, ...) may legitimately return NULL; a zero-length
                 * iovec with a NULL base is valid, so only fail for real
                 * allocations.
                 */
                if (!d && iov_length) {
                        fprintf(stderr, "iov_count %i/%i OOM\n", i, iov_count);
                        goto unwind_iov;
                }
                iov[i].iov_base = d;
                iov[i].iov_len = iov_length;

                if (data && xmit) {
                        int j;

                        for (j = 0; j < iov_length; j++)
                                d[j] = k++;
                }
        }

        msg->msg_iov = iov;
        msg->msg_iovlen = iov_count;

        return 0;
unwind_iov:
        /* msg->msg_iov is not assigned yet: unwind through the local array,
         * then release the array itself.
         */
        while (i-- > 0)
                free(iov[i].iov_base);
        free(iov);
        return -ENOMEM;
}

/*
 * Prepare the verify_* globals that msg_verify_data() uses in push/pop tests.
 *
 * The push window is [txmsg_start_push, txmsg_start_push + txmsg_end_push)
 * and the pop window is [txmsg_start_pop, txmsg_start_pop + txmsg_pop).
 * Start offsets are taken as-is. When both windows are non-empty and
 * overlap, the shared bytes are subtracted from both lengths (clamped at 0).
 */
static void msg_verify_date_prep(void)
{
        int push_last = txmsg_start_push + txmsg_end_push - 1;
        int pop_last = txmsg_start_pop + txmsg_pop - 1;
        bool overlap;
        int shared;

        verify_push_start = txmsg_start_push;
        verify_pop_start = txmsg_start_pop;
        verify_push_len = txmsg_end_push;
        verify_pop_len = txmsg_pop;

        overlap = txmsg_end_push && txmsg_pop &&
                  txmsg_start_push <= pop_last &&
                  txmsg_start_pop <= push_last;
        if (!overlap)
                return;

        if (txmsg_start_push < txmsg_start_pop)
                shared = min(push_last - txmsg_start_pop + 1, txmsg_pop);
        else
                shared = min(pop_last - txmsg_start_push + 1, txmsg_end_push);

        verify_push_len = max(txmsg_end_push - shared, 0);
        verify_pop_len = max(txmsg_pop - shared, 0);
}

/*
 * Check that the first `size` received bytes in msg match the sender's
 * pattern: a byte counter k that restarts at 0 every chunk_sz bytes.
 *
 * Push/pop adjustments:
 *   - pushed bytes (verify_push_start/verify_push_len) are skipped without
 *     checking their content;
 *   - popped bytes (verify_pop_start/verify_pop_len) never arrive, so the
 *     expected pattern jumps past them.
 *
 * Verification state is carried across calls through the in/out params:
 *   k_p         - next expected pattern byte
 *   bytes_cnt_p - pattern bytes accounted for in the current chunk,
 *                 including popped ones
 *   check_cnt_p - offset within the current chunk, counting pushed and
 *                 popped bytes; compared against the push/pop offsets
 *   push_p      - pushed bytes still to skip in the current chunk
 *
 * With txmsg_ktls_skb, iov[0] must start with a 4-byte "PASS" marker.
 *
 * Returns 0 on success or -EDATAINTEGRITY on mismatch. The in/out state is
 * written back only on success. A size <= 0 is a no-op.
 */
static int msg_verify_data(struct msghdr *msg, int size, int chunk_sz,
                           unsigned char *k_p, int *bytes_cnt_p,
                           int *check_cnt_p, int *push_p)
{
        int bytes_cnt = *bytes_cnt_p, check_cnt = *check_cnt_p, push = *push_p;
        unsigned char k = *k_p;
        int i, j;

        /* Nothing received (or recvmsg() failed): the iov contents are stale.
         * A negative size would also never reach 0 in the loops below and
         * the whole buffer would be "verified".
         */
        if (size <= 0)
                return 0;

        for (i = 0, j = 0; i < msg->msg_iovlen && size; i++, j = 0) {
                unsigned char *d = msg->msg_iov[i].iov_base;

                /* Special case test for skb ingress + ktls */
                if (i == 0 && txmsg_ktls_skb) {
                        if (msg->msg_iov[i].iov_len < 4)
                                return -EDATAINTEGRITY;
                        if (memcmp(d, "PASS", 4) != 0) {
                                fprintf(stderr,
                                        "detected skb data error with skb ingress update @iov[%i]:%i \"%02x %02x %02x %02x\" != \"PASS\"\n",
                                        i, 0, d[0], d[1], d[2], d[3]);
                                return -EDATAINTEGRITY;
                        }
                        j = 4; /* advance index past PASS header */
                }

                for (; j < msg->msg_iov[i].iov_len && size; j++) {
                        /* At the start of the (remaining) push window: skip
                         * up to `push` injected bytes, bounded by the end of
                         * this iov; any remainder carries over to the next.
                         */
                        if (push > 0 &&
                            check_cnt == verify_push_start + verify_push_len - push) {
                                int skipped;
revisit_push:
                                skipped = push;
                                if (j + push >= msg->msg_iov[i].iov_len)
                                        skipped = msg->msg_iov[i].iov_len - j;
                                push -= skipped;
                                size -= skipped;
                                j += skipped - 1;
                                check_cnt += skipped;
                                continue;
                        }

                        /* At the pop offset: the popped bytes never arrive,
                         * so advance the expected pattern past them.
                         */
                        if (verify_pop_len > 0 && check_cnt == verify_pop_start) {
                                bytes_cnt += verify_pop_len;
                                check_cnt += verify_pop_len;
                                k += verify_pop_len;

                                if (bytes_cnt == chunk_sz) {
                                        k = 0;
                                        bytes_cnt = 0;
                                        check_cnt = 0;
                                        push = verify_push_len;
                                }

                                /* The pop may land us right on a push window. */
                                if (push > 0 &&
                                    check_cnt == verify_push_start + verify_push_len - push)
                                        goto revisit_push;
                        }

                        if (d[j] != k++) {
                                /* Only show the following byte when it lies
                                 * inside this iov; d[j + 1] would otherwise
                                 * read past the end of the buffer.
                                 */
                                if (j + 1 < msg->msg_iov[i].iov_len)
                                        fprintf(stderr,
                                                "detected data corruption @iov[%i]:%i %02x != %02x, %02x ?= %02x\n",
                                                i, j, d[j], k - 1, d[j+1], k);
                                else
                                        fprintf(stderr,
                                                "detected data corruption @iov[%i]:%i %02x != %02x\n",
                                                i, j, d[j], k - 1);
                                return -EDATAINTEGRITY;
                        }
                        bytes_cnt++;
                        check_cnt++;
                        /* End of one sender chunk: the pattern restarts. */
                        if (bytes_cnt == chunk_sz) {
                                k = 0;
                                bytes_cnt = 0;
                                check_cnt = 0;
                                push = verify_push_len;
                        }
                        size--;
                }
        }
        *k_p = k;
        *bytes_cnt_p = bytes_cnt;
        *check_cnt_p = check_cnt;
        *push_p = push;
        return 0;
}

/*
 * Run one side of a sendmsg test on fd.
 *
 * tx == true:
 *   - perform cnt sendmsg() calls of iov_count x iov_length bytes,
 *     pattern-filled when opt->data_test, accumulating s->bytes_sent;
 *   - a failing send is an error unless opt->drop_expected is set, in which
 *     case a send that succeeds is the error;
 *   - with opt->tx_wait_mem, EACCES ends the loop without error.
 *
 * tx == false:
 *   - receive until the expected byte count (adjusted for the push/pop/apply
 *     knobs) has arrived;
 *   - each read is gated by select() with a 3s timeout (300ms when corking);
 *   - with peek_flag, each read is preceded by a MSG_PEEK read;
 *   - with opt->data_test, the payload is verified;
 *   - with opt->tx_wait_mem, fd is closed and 0 returned after the first
 *     readable select().
 *
 * Returns 0 on success, otherwise an errno-style value. Some paths return
 * negative values: -EIO on rx timeout or an unexpectedly successful send,
 * and -EDATAINTEGRITY from msg_verify_data().
 */
static int msg_loop(int fd, int iov_count, int iov_length, int cnt,
                    struct msg_stats *s, bool tx,
                    struct sockmap_options *opt)
{
        struct msghdr msg = {0}, msg_peek = {0};
        int err, i, flags = MSG_NOSIGNAL;
        bool drop = opt->drop_expected;
        bool data = opt->data_test;
        int iov_alloc_length = iov_length;

        /* Presumably doubled so that a single recvmsg() can return more than
         * expected and the check_recved_len overshoot check below sees it.
         */
        if (!tx && opt->check_recved_len)
                iov_alloc_length *= 2;

        err = msg_alloc_iov(&msg, iov_count, iov_alloc_length, data, tx);
        if (err)
                goto out_errno;
        if (peek_flag) {
                err = msg_alloc_iov(&msg_peek, iov_count, iov_length, data, tx);
                if (err)
                        goto out_errno;
        }

        if (tx) {
                clock_gettime(CLOCK_MONOTONIC, &s->start);
                for (i = 0; i < cnt; i++) {
                        int sent;

                        errno = 0;
                        sent = sendmsg(fd, &msg, flags);

                        if (!drop && sent < 0) {
                                if (opt->tx_wait_mem && errno == EACCES) {
                                        errno = 0;
                                        goto out_errno;
                                }
                                perror("sendmsg loop error");
                                goto out_errno;
                        } else if (drop && sent >= 0) {
                                fprintf(stderr,
                                        "sendmsg loop error expected: %i errno %i\n",
                                        sent, errno);
                                errno = -EIO;
                                goto out_errno;
                        }
                        if (sent > 0)
                                s->bytes_sent += sent;
                }
                clock_gettime(CLOCK_MONOTONIC, &s->end);
        } else {
                float total_bytes, txmsg_pop_total, txmsg_push_total;
                int slct, recvp = 0, recv, max_fd = fd;
                struct timeval timeout;
                unsigned char k = 0;
                int bytes_cnt = 0;
                int check_cnt = 0;
                int push = 0;
                fd_set w;

                /* fd's file status flags are not changed here: select() gates
                 * every read below, and EWOULDBLOCK is still tolerated in
                 * case the caller made fd non-blocking.
                 */

                /* Account for pop bytes noting each iteration of apply will
                 * call msg_pop_data helper so we need to account for this
                 * by calculating the number of apply iterations. Note user
                 * of the tool can create cases where no data is sent by
                 * manipulating pop/push/pull/etc. For example txmsg_apply 1
                 * with txmsg_pop 1 will try to apply 1B at a time but each
                 * iteration will then pop 1B so no data will ever be sent.
                 * This is really only useful for testing edge cases in code
                 * paths.
                 */
                total_bytes = (float)iov_length * (float)cnt;
                if (!opt->sendpage)
                        total_bytes *= (float)iov_count;
                if (txmsg_apply) {
                        txmsg_push_total = txmsg_end_push * (total_bytes / txmsg_apply);
                        txmsg_pop_total = txmsg_pop * (total_bytes / txmsg_apply);
                } else {
                        txmsg_push_total = txmsg_end_push * cnt;
                        txmsg_pop_total = txmsg_pop * cnt;
                }
                total_bytes += txmsg_push_total;
                total_bytes -= txmsg_pop_total;
                if (data) {
                        msg_verify_date_prep();
                        push = verify_push_len;
                }
                err = clock_gettime(CLOCK_MONOTONIC, &s->start);
                if (err < 0)
                        perror("recv start time");
                while (s->bytes_recvd < total_bytes) {
                        if (txmsg_cork) {
                                timeout.tv_sec = 0;
                                timeout.tv_usec = 300000;
                        } else {
                                timeout.tv_sec = 3;
                                timeout.tv_usec = 0;
                        }

                        /* FD sets */
                        FD_ZERO(&w);
                        FD_SET(fd, &w);

                        slct = select(max_fd + 1, &w, NULL, NULL, &timeout);
                        if (slct == -1) {
                                perror("select()");
                                clock_gettime(CLOCK_MONOTONIC, &s->end);
                                goto out_errno;
                        } else if (!slct) {
                                if (opt->verbose)
                                        fprintf(stderr, "unexpected timeout: recved %zu/%f pop_total %f\n", s->bytes_recvd, total_bytes, txmsg_pop_total);
                                errno = -EIO;
                                clock_gettime(CLOCK_MONOTONIC, &s->end);
                                goto out_errno;
                        }

                        if (opt->tx_wait_mem) {
                                FD_ZERO(&w);
                                FD_SET(fd, &w);
                                slct = select(max_fd + 1, NULL, NULL, &w, &timeout);
                                errno = 0;
                                close(fd);
                                goto out_errno;
                        }

                        errno = 0;
                        if (peek_flag) {
                                flags |= MSG_PEEK;
                                recvp = recvmsg(fd, &msg_peek, flags);
                                if (recvp < 0) {
                                        if (errno != EWOULDBLOCK) {
                                                clock_gettime(CLOCK_MONOTONIC, &s->end);
                                                goto out_errno;
                                        }
                                }
                                flags = 0;
                        }

                        recv = recvmsg(fd, &msg, flags);
                        if (recv < 0) {
                                if (errno != EWOULDBLOCK) {
                                        clock_gettime(CLOCK_MONOTONIC, &s->end);
                                        perror("recv failed()");
                                        goto out_errno;
                                }
                        }

                        if (recv > 0)
                                s->bytes_recvd += recv;

                        if (opt->check_recved_len && s->bytes_recvd > total_bytes) {
                                errno = EMSGSIZE;
                                fprintf(stderr, "recv failed(), bytes_recvd:%zd, total_bytes:%f\n",
                                                s->bytes_recvd, total_bytes);
                                goto out_errno;
                        }

                        if (data) {
                                int chunk_sz = opt->sendpage ?
                                                iov_length :
                                                iov_length * iov_count;

                                /* Only verify what was actually received: on
                                 * EWOULDBLOCK the buffers hold stale data.
                                 */
                                if (recv > 0) {
                                        errno = msg_verify_data(&msg, recv, chunk_sz, &k, &bytes_cnt,
                                                                &check_cnt, &push);
                                        if (errno) {
                                                perror("data verify msg failed");
                                                goto out_errno;
                                        }
                                }
                                if (recvp > 0) {
                                        errno = msg_verify_data(&msg_peek,
                                                                recvp,
                                                                chunk_sz,
                                                                &k,
                                                                &bytes_cnt,
                                                                &check_cnt,
                                                                &push);
                                        if (errno) {
                                                perror("data verify msg_peek failed");
                                                goto out_errno;
                                        }
                                }
                        }
                }
                clock_gettime(CLOCK_MONOTONIC, &s->end);
        }

        msg_free_iov(&msg);
        msg_free_iov(&msg_peek);
        return err;
out_errno:
        msg_free_iov(&msg);
        msg_free_iov(&msg_peek);
        return errno;
}

static float giga = 1e9f; /* bytes per GB (decimal), for the GB/s figures */

static inline float sentBps(struct msg_stats s)
{
        return s.bytes_sent / (s.end.tv_sec - s.start.tv_sec);
}

static inline float recvdBps(struct msg_stats s)
{
        return s.bytes_recvd / (s.end.tv_sec - s.start.tv_sec);
}

/*
 * Turn a child's waitpid() status into sendmsg_test()'s result.
 * Returns the child's exit code (0 == success). A child that did not exit
 * normally (e.g. was killed by a signal) is reported and counted as a
 * failure (1) rather than silently passing.
 */
static int sendmsg_child_err(int status, const char *who)
{
        int err;

        if (!WIFEXITED(status)) {
                if (WIFSIGNALED(status))
                        fprintf(stderr, "%s thread killed by signal %d.\n",
                                who, WTERMSIG(status));
                else
                        fprintf(stderr, "%s thread terminated abnormally.\n", who);
                return 1;
        }
        err = WEXITSTATUS(status);
        if (err)
                fprintf(stderr, "%s thread exited with err %d.\n", who, err);
        return err;
}

/*
 * Run one send/receive test.
 *
 * Setup:
 *   - the receive fd is p1 when opt->base is set, otherwise p2;
 *   - kTLS is optionally enabled on the rx fd and c1;
 *   - with opt->tx_wait_mem, c2/p2 buffers are shrunk.
 *
 * A receiver child (msg_loop rx on the rx fd) and a sender child (msg_loop
 * tx, or msg_loop_sendpage, on c1) are then forked and both are reaped.
 *
 * Returns 0 on success, an errno value for setup/fork/wait failures, or the
 * failing child's exit status (rx is checked first).
 */
static int sendmsg_test(struct sockmap_options *opt)
{
        float sent_Bps = 0, recvd_Bps = 0;
        int rx_fd, txpid, rxpid, err = 0;
        struct msg_stats s = {0};
        int iov_count = opt->iov_count;
        int iov_buf = opt->iov_length;
        int rx_status, tx_status;
        int cnt = opt->rate;

        errno = 0;

        if (opt->base)
                rx_fd = p1;
        else
                rx_fd = p2;

        if (ktls) {
                /* Redirecting into non-TLS socket which sends into a TLS
                 * socket is not a valid test. So in this case let's not
                 * enable kTLS but still run the test.
                 */
                if (!txmsg_redir || txmsg_ingress) {
                        err = sockmap_init_ktls(opt->verbose, rx_fd);
                        if (err)
                                return err;
                }
                err = sockmap_init_ktls(opt->verbose, c1);
                if (err)
                        return err;
        }

        if (opt->tx_wait_mem) {
                struct timeval timeout;
                int rxtx_buf_len = 1024;

                timeout.tv_sec = 3;
                timeout.tv_usec = 0;

                err = setsockopt(c2, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(struct timeval));
                err |= setsockopt(c2, SOL_SOCKET, SO_SNDBUFFORCE, &rxtx_buf_len, sizeof(int));
                err |= setsockopt(p2, SOL_SOCKET, SO_RCVBUFFORCE, &rxtx_buf_len, sizeof(int));
                if (err) {
                        perror("setsockopt failed()");
                        return errno;
                }
        }

        rxpid = fork();
        if (rxpid == 0) {
                if (opt->drop_expected || txmsg_ktls_skb_drop)
                        _exit(0);

                if (!iov_buf) /* zero bytes sent case */
                        _exit(0);

                if (opt->sendpage)
                        iov_count = 1;
                err = msg_loop(rx_fd, iov_count, iov_buf,
                               cnt, &s, false, opt);
                if (opt->verbose > 1)
                        fprintf(stderr,
                                "msg_loop_rx: iov_count %i iov_buf %i cnt %i err %i\n",
                                iov_count, iov_buf, cnt, err);
                if (s.end.tv_sec - s.start.tv_sec) {
                        sent_Bps = sentBps(s);
                        recvd_Bps = recvdBps(s);
                }
                if (opt->verbose > 1)
                        fprintf(stdout,
                                "rx_sendmsg: TX: %zuB %fB/s %fGB/s RX: %zuB %fB/s %fGB/s %s\n",
                                s.bytes_sent, sent_Bps, sent_Bps/giga,
                                s.bytes_recvd, recvd_Bps, recvd_Bps/giga,
                                peek_flag ? "(peek_msg)" : "");
                /* With corking, rx timeouts are tolerated; data corruption is not. */
                if (err && err != -EDATAINTEGRITY && txmsg_cork)
                        err = 0;
                exit(err ? 1 : 0);
        } else if (rxpid == -1) {
                err = errno;
                perror("msg_loop_rx");
                return err;
        }

        if (opt->tx_wait_mem)
                close(c2);

        txpid = fork();
        if (txpid == 0) {
                if (opt->sendpage)
                        err = msg_loop_sendpage(c1, iov_buf, cnt, &s, opt);
                else
                        err = msg_loop(c1, iov_count, iov_buf,
                                       cnt, &s, true, opt);

                if (err)
                        fprintf(stderr,
                                "msg_loop_tx: iov_count %i iov_buf %i cnt %i err %i\n",
                                iov_count, iov_buf, cnt, err);
                if (s.end.tv_sec - s.start.tv_sec) {
                        sent_Bps = sentBps(s);
                        recvd_Bps = recvdBps(s);
                }
                if (opt->verbose > 1)
                        fprintf(stdout,
                                "tx_sendmsg: TX: %zuB %fB/s %f GB/s RX: %zuB %fB/s %fGB/s\n",
                                s.bytes_sent, sent_Bps, sent_Bps/giga,
                                s.bytes_recvd, recvd_Bps, recvd_Bps/giga);
                exit(err ? 1 : 0);
        } else if (txpid == -1) {
                err = errno;
                perror("msg_loop_tx");
                /* Reap the rx child rather than leaving a zombie; it stops on
                 * its own select() timeout since nothing will be sent.
                 */
                waitpid(rxpid, NULL, 0);
                return err;
        }

        /* Reap both children before judging either. waitpid() must not live
         * inside assert(): under NDEBUG it would be compiled out, leaving
         * zombies and uninitialized statuses.
         */
        err = 0;
        if (waitpid(rxpid, &rx_status, 0) != rxpid) {
                err = errno;
                perror("waitpid rx");
        }
        if (waitpid(txpid, &tx_status, 0) != txpid) {
                if (!err)
                        err = errno;
                perror("waitpid tx");
        }
        if (err)
                return err;

        err = sendmsg_child_err(rx_status, "rx");
        if (err)
                return err;
        return sendmsg_child_err(tx_status, "tx");
}

/*
 * Bounce data between the test sockets until the global 'running' flag is
 * cleared. c1 is primed with one 1KiB send; then, on every pass, wait up to
 * 10s for any of c1/c2/p1/p2 to become readable and echo what was read back
 * on the same socket.
 *
 * @rate: seconds to sleep between passes (0 = no delay)
 * @opt:  test options; only ->verbose is consulted (prints a progress dot)
 *
 * Returns 0 when the loop ends (a select() error or timeout is reported and
 * ends the loop), or the negative send()/recv() result on an I/O failure.
 */
static int forever_ping_pong(int rate, struct sockmap_options *opt)
{
        char buf[1024] = {0};
        int sc;

        /* Ping/Pong data from client to server */
        sc = send(c1, buf, sizeof(buf), 0);
        if (sc < 0) {
                perror("send failed()");
                return sc;
        }

        do {
                struct timeval timeout;
                int s, rc, i, max_fd;
                fd_set w;

                /* select() may overwrite the timeout with the time left
                 * (Linux does), so re-arm it on every pass; otherwise it
                 * shrinks towards zero and we report a bogus timeout.
                 */
                timeout.tv_sec = 10;
                timeout.tv_usec = 0;

                /* FD sets */
                FD_ZERO(&w);
                FD_SET(c1, &w);
                FD_SET(c2, &w);
                FD_SET(p1, &w);
                FD_SET(p2, &w);

                /* select() needs the highest descriptor in the set; don't
                 * rely on the order in which the sockets were created.
                 */
                max_fd = c1;
                if (c2 > max_fd)
                        max_fd = c2;
                if (p1 > max_fd)
                        max_fd = p1;
                if (p2 > max_fd)
                        max_fd = p2;

                s = select(max_fd + 1, &w, NULL, NULL, &timeout);
                if (s == -1) {
                        perror("select()");
                        break;
                } else if (!s) {
                        fprintf(stderr, "unexpected timeout\n");
                        break;
                }

                for (i = 0; i <= max_fd && s > 0; ++i) {
                        if (!FD_ISSET(i, &w))
                                continue;

                        s--;

                        rc = recv(i, buf, sizeof(buf), 0);
                        if (rc < 0) {
                                if (errno != EWOULDBLOCK) {
                                        perror("recv failed()");
                                        return rc;
                                }
                                /* Nothing to read after all: never echo a
                                 * negative length (it would become a huge
                                 * size_t in send()).
                                 */
                                continue;
                        }

                        if (rc == 0) {
                                close(i);
                                break;
                        }

                        sc = send(i, buf, rc, 0);
                        if (sc < 0) {
                                perror("send failed()");
                                return sc;
                        }
                }

                if (rate)
                        sleep(rate);

                if (opt->verbose) {
                        printf(".");
                        fflush(stdout);
                }
        } while (running);

        return 0;
}

/* Test kinds. run_options() dispatches on every kind except SELFTESTS;
 * values are spelled out so they are stable if entries are reordered.
 */
enum {
        SELFTESTS     = 0,
        PING_PONG     = 1,
        SENDMSG       = 2,
        BASE          = 3,
        BASE_SENDPAGE = 4,
        SENDPAGE      = 5,
};

static int run_options(struct sockmap_options *options, int cg_fd,  int test)
{
        int i, key, next_key, err, zero = 0;
        struct bpf_program *tx_prog;

        /* If base test skip BPF setup */
        if (test == BASE || test == BASE_SENDPAGE)
                goto run;

        /* Attach programs to sockmap */
        if (!txmsg_omit_skb_parser) {
                links[0] = bpf_program__attach_sockmap(progs[0], map_fd[0]);
                if (!links[0]) {
                        fprintf(stderr,
                                "ERROR: bpf_program__attach_sockmap (sockmap %i->%i): (%s)\n",
                                bpf_program__fd(progs[0]), map_fd[0], strerror(errno));
                        return -1;
                }
        }

        links[1] = bpf_program__attach_sockmap(progs[1], map_fd[0]);
        if (!links[1]) {
                fprintf(stderr, "ERROR: bpf_program__attach_sockmap (sockmap): (%s)\n",
                        strerror(errno));
                return -1;
        }

        /* Attach programs to TLS sockmap */
        if (txmsg_ktls_skb) {
                if (!txmsg_omit_skb_parser) {
                        links[2] = bpf_program__attach_sockmap(progs[0], map_fd[8]);
                        if (!links[2]) {
                                fprintf(stderr,
                                        "ERROR: bpf_program__attach_sockmap (TLS sockmap %i->%i): (%s)\n",
                                        bpf_program__fd(progs[0]), map_fd[8], strerror(errno));
                                return -1;
                        }
                }

                links[3] = bpf_program__attach_sockmap(progs[2], map_fd[8]);
                if (!links[3]) {
                        fprintf(stderr, "ERROR: bpf_program__attach_sockmap (TLS sockmap): (%s)\n",
                                strerror(errno));
                        return -1;
                }
        }

        /* Attach to cgroups */
        err = bpf_prog_attach(bpf_program__fd(progs[3]), cg_fd, BPF_CGROUP_SOCK_OPS, 0);
        if (err) {
                fprintf(stderr, "ERROR: bpf_prog_attach (groups): %d (%s)\n",
                        err, strerror(errno));
                return err;
        }

run:
        err = sockmap_init_sockets(options->verbose);
        if (err) {
                fprintf(stderr, "ERROR: test socket failed: %d\n", err);
                goto out;
        }

        /* Attach txmsg program to sockmap */
        if (txmsg_pass)
                tx_prog = progs[4];
        else if (txmsg_redir)
                tx_prog = progs[5];
        else if (txmsg_apply)
                tx_prog = progs[6];
        else if (txmsg_cork)
                tx_prog = progs[7];
        else if (txmsg_drop)
                tx_prog = progs[8];
        else
                tx_prog = NULL;

        if (tx_prog) {
                int redir_fd;

                links[4] = bpf_program__attach_sockmap(tx_prog, map_fd[1]);
                if (!links[4]) {
                        fprintf(stderr,
                                "ERROR: bpf_program__attach_sockmap (txmsg): (%s)\n",
                                strerror(errno));
                        err = -1;
                        goto out;
                }

                i = 0;
                err = bpf_map_update_elem(map_fd[1], &i, &c1, BPF_ANY);
                if (err) {
                        fprintf(stderr,
                                "ERROR: bpf_map_update_elem (txmsg):  %d (%s\n",
                                err, strerror(errno));
                        goto out;
                }

                if (txmsg_redir)
                        redir_fd = c2;
                else
                        redir_fd = c1;

                err = bpf_map_update_elem(map_fd[2], &i, &redir_fd, BPF_ANY);
                if (err) {
                        fprintf(stderr,
                                "ERROR: bpf_map_update_elem (txmsg):  %d (%s\n",
                                err, strerror(errno));
                        goto out;
                }

                if (txmsg_apply) {
                        err = bpf_map_update_elem(map_fd[3],
                                                  &i, &txmsg_apply, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem (apply_bytes):  %d (%s\n",
                                        err, strerror(errno));
                                goto out;
                        }
                }

                if (txmsg_cork) {
                        err = bpf_map_update_elem(map_fd[4],
                                                  &i, &txmsg_cork, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem (cork_bytes):  %d (%s\n",
                                        err, strerror(errno));
                                goto out;
                        }
                }

                if (txmsg_start) {
                        err = bpf_map_update_elem(map_fd[5],
                                                  &i, &txmsg_start, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem (txmsg_start):  %d (%s)\n",
                                        err, strerror(errno));
                                goto out;
                        }
                }

                if (txmsg_end) {
                        i = 1;
                        err = bpf_map_update_elem(map_fd[5],
                                                  &i, &txmsg_end, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem (txmsg_end):  %d (%s)\n",
                                        err, strerror(errno));
                                goto out;
                        }
                }

                if (txmsg_start_push) {
                        i = 2;
                        err = bpf_map_update_elem(map_fd[5],
                                                  &i, &txmsg_start_push, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem (txmsg_start_push):  %d (%s)\n",
                                        err, strerror(errno));
                                goto out;
                        }
                }

                if (txmsg_end_push) {
                        i = 3;
                        err = bpf_map_update_elem(map_fd[5],
                                                  &i, &txmsg_end_push, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem %i@%i (txmsg_end_push):  %d (%s)\n",
                                        txmsg_end_push, i, err, strerror(errno));
                                goto out;
                        }
                }

                if (txmsg_start_pop) {
                        i = 4;
                        err = bpf_map_update_elem(map_fd[5],
                                                  &i, &txmsg_start_pop, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem %i@%i (txmsg_start_pop):  %d (%s)\n",
                                        txmsg_start_pop, i, err, strerror(errno));
                                goto out;
                        }
                } else {
                        i = 4;
                        bpf_map_update_elem(map_fd[5],
                                                  &i, &txmsg_start_pop, BPF_ANY);
                }

                if (txmsg_pop) {
                        i = 5;
                        err = bpf_map_update_elem(map_fd[5],
                                                  &i, &txmsg_pop, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem %i@%i (txmsg_pop):  %d (%s)\n",
                                        txmsg_pop, i, err, strerror(errno));
                                goto out;
                        }
                } else {
                        i = 5;
                        bpf_map_update_elem(map_fd[5],
                                            &i, &txmsg_pop, BPF_ANY);

                }

                if (txmsg_ingress) {
                        int in = BPF_F_INGRESS;

                        i = 0;
                        err = bpf_map_update_elem(map_fd[6], &i, &in, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem (txmsg_ingress): %d (%s)\n",
                                        err, strerror(errno));
                        }
                        i = 1;
                        err = bpf_map_update_elem(map_fd[1], &i, &p1, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem (p1 txmsg): %d (%s)\n",
                                        err, strerror(errno));
                        }
                        err = bpf_map_update_elem(map_fd[2], &i, &p1, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem (p1 redir): %d (%s)\n",
                                        err, strerror(errno));
                        }

                        i = 2;
                        err = bpf_map_update_elem(map_fd[2], &i, &p2, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem (p2 txmsg): %d (%s)\n",
                                        err, strerror(errno));
                        }
                }

                if (txmsg_ktls_skb) {
                        int ingress = BPF_F_INGRESS;

                        i = 0;
                        err = bpf_map_update_elem(map_fd[8], &i, &p2, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem (c1 sockmap): %d (%s)\n",
                                        err, strerror(errno));
                        }

                        if (txmsg_ktls_skb_redir) {
                                i = 1;
                                err = bpf_map_update_elem(map_fd[7],
                                                          &i, &ingress, BPF_ANY);
                                if (err) {
                                        fprintf(stderr,
                                                "ERROR: bpf_map_update_elem (txmsg_ingress): %d (%s)\n",
                                                err, strerror(errno));
                                }
                        }

                        if (txmsg_ktls_skb_drop) {
                                i = 1;
                                err = bpf_map_update_elem(map_fd[7], &i, &i, BPF_ANY);
                        }
                }

                if (txmsg_redir_skb) {
                        int skb_fd = (test == SENDMSG || test == SENDPAGE) ?
                                        p2 : p1;
                        int ingress = BPF_F_INGRESS;

                        i = 0;
                        err = bpf_map_update_elem(map_fd[7],
                                                  &i, &ingress, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem (txmsg_ingress): %d (%s)\n",
                                        err, strerror(errno));
                        }

                        i = 3;
                        err = bpf_map_update_elem(map_fd[0], &i, &skb_fd, BPF_ANY);
                        if (err) {
                                fprintf(stderr,
                                        "ERROR: bpf_map_update_elem (c1 sockmap): %d (%s)\n",
                                        err, strerror(errno));
                        }
                }
        }

        if (skb_use_parser) {
                i = 2;
                err = bpf_map_update_elem(map_fd[7], &i, &skb_use_parser, BPF_ANY);
        }

        if (txmsg_drop)
                options->drop_expected = true;

        if (test == PING_PONG)
                err = forever_ping_pong(options->rate, options);
        else if (test == SENDMSG) {
                options->base = false;
                options->sendpage = false;
                err = sendmsg_test(options);
        } else if (test == SENDPAGE) {
                options->base = false;
                options->sendpage = true;
                err = sendmsg_test(options);
        } else if (test == BASE) {
                options->base = true;
                options->sendpage = false;
                err = sendmsg_test(options);
        } else if (test == BASE_SENDPAGE) {
                options->base = true;
                options->sendpage = true;
                err = sendmsg_test(options);
        } else
                fprintf(stderr, "unknown test\n");
out:
        /* Detach and zero all the maps */
        bpf_prog_detach2(bpf_program__fd(progs[3]), cg_fd, BPF_CGROUP_SOCK_OPS);

        for (i = 0; i < ARRAY_SIZE(links); i++) {
                if (links[i])
                        bpf_link__detach(links[i]);
        }

        for (i = 0; i < ARRAY_SIZE(map_fd); i++) {
                key = next_key = 0;
                bpf_map_update_elem(map_fd[i], &key, &zero, BPF_ANY);
                while (bpf_map_get_next_key(map_fd[i], &key, &next_key) == 0) {
                        bpf_map_update_elem(map_fd[i], &key, &zero, BPF_ANY);
                        key = next_key;
                }
        }

        close(s1);
        close(s2);
        close(p1);
        close(p2);
        close(c1);
        close(c2);
        return err;
}

/* Label for the verbose " [TEST n]" line: only the sendmsg and sendpage
 * test kinds are named; every other kind prints as "unknown".
 */
static char *test_to_str(int test)
{
        char *label = "unknown";

        if (test == SENDMSG)
                label = "sendmsg";
        else if (test == SENDPAGE)
                label = "sendpage";

        return label;
}

/*
 * Append @src to the NUL-terminated string in @dst, whose buffer holds
 * @dst_cap bytes. Whatever does not fit is silently dropped, and the result
 * always stays NUL-terminated within the buffer.
 */
static void append_str(char *dst, const char *src, size_t dst_cap)
{
        size_t room = dst_cap - strlen(dst);

        /* strncat() copies up to n characters and then adds a NUL, so keep
         * one byte in reserve; with room <= 1 only the NUL would fit.
         */
        if (room > 1)
                strncat(dst, src, room - 1);
}

#define OPTSTRING 60
static void test_options(char *options)
{
        char tstr[OPTSTRING];

        memset(options, 0, OPTSTRING);

        if (txmsg_pass)
                append_str(options, "pass,", OPTSTRING);
        if (txmsg_redir)
                append_str(options, "redir,", OPTSTRING);
        if (txmsg_drop)
                append_str(options, "drop,", OPTSTRING);
        if (txmsg_apply) {
                snprintf(tstr, OPTSTRING, "apply %d,", txmsg_apply);
                append_str(options, tstr, OPTSTRING);
        }
        if (txmsg_cork) {
                snprintf(tstr, OPTSTRING, "cork %d,", txmsg_cork);
                append_str(options, tstr, OPTSTRING);
        }
        if (txmsg_start) {
                snprintf(tstr, OPTSTRING, "start %d,", txmsg_start);
                append_str(options, tstr, OPTSTRING);
        }
        if (txmsg_end) {
                snprintf(tstr, OPTSTRING, "end %d,", txmsg_end);
                append_str(options, tstr, OPTSTRING);
        }
        if (txmsg_start_pop) {
                snprintf(tstr, OPTSTRING, "pop (%d,%d),",
                         txmsg_start_pop, txmsg_start_pop + txmsg_pop);
                append_str(options, tstr, OPTSTRING);
        }
        if (txmsg_ingress)
                append_str(options, "ingress,", OPTSTRING);
        if (txmsg_redir_skb)
                append_str(options, "redir_skb,", OPTSTRING);
        if (txmsg_ktls_skb)
                append_str(options, "ktls_skb,", OPTSTRING);
        if (ktls)
                append_str(options, "ktls,", OPTSTRING);
        if (peek_flag)
                append_str(options, "peek,", OPTSTRING);
}

/*
 * Run a single test of kind @test via run_options().
 *
 * Sets opt->sendpage / opt->drop_expected from @test and txmsg_drop. In
 * verbose mode it prints the test header and the PASS/FAILED result. It
 * then bumps test_cnt and either passed or failed.
 *
 * Returns run_options()'s result (0 on success).
 */
static int __test_exec(int cgrp, int test, struct sockmap_options *opt)
{
        /* test_options() fills exactly OPTSTRING bytes. A stack buffer
         * avoids an unchecked heap allocation (a failed calloc() would
         * otherwise have been memset() through a NULL pointer).
         */
        char options[OPTSTRING];
        int err;

        opt->sendpage = (test == SENDPAGE);
        opt->drop_expected = (txmsg_drop != 0);

        test_options(options);

        if (opt->verbose) {
                fprintf(stdout,
                        " [TEST %i]: (%i, %i, %i, %s, %s): ",
                        test_cnt, opt->rate, opt->iov_count, opt->iov_length,
                        test_to_str(test), options);
                fflush(stdout);
        }
        err = run_options(opt, cgrp, test);
        if (opt->verbose)
                fprintf(stdout, " %s\n", !err ? "PASS" : "FAILED");

        test_cnt++;
        if (err)
                failed++;
        else
                passed++;

        return err;
}

/*
 * Run the current configuration once. With the sockmap object file the
 * test runs as SENDMSG; with any other object file it runs as SENDPAGE. The
 * run is bracketed by test_start(), and test_fail() is called if it fails.
 */
static void test_exec(int cgrp, struct sockmap_options *opt)
{
        int kind = strcmp(opt->map, BPF_SOCKMAP_FILENAME) ? SENDPAGE : SENDMSG;

        test_start();
        if (__test_exec(cgrp, kind, opt))
                test_fail();
}

/* Three small sends at rate 1: 1x1 byte, 1024 one-byte iovs, and one
 * 1024-byte iov.
 */
static void test_send_one(struct sockmap_options *opt, int cgrp)
{
        static const struct {
                int length;
                int count;
        } shapes[] = {
                { 1,    1    },
                { 1,    1024 },
                { 1024, 1    },
        };
        size_t n;

        for (n = 0; n < sizeof(shapes) / sizeof(shapes[0]); n++) {
                opt->iov_length = shapes[n].length;
                opt->iov_count = shapes[n].count;
                opt->rate = 1;
                test_exec(cgrp, opt);
        }
}

/* Single short iov (3 then 5 bytes) at a high rate (512 then 100). */
static void test_send_many(struct sockmap_options *opt, int cgrp)
{
        static const struct {
                int rate;
                int length;
        } runs[] = {
                { 512, 3 },
                { 100, 5 },
        };
        size_t n;

        for (n = 0; n < sizeof(runs) / sizeof(runs[0]); n++) {
                opt->iov_count = 1;
                opt->iov_length = runs[n].length;
                opt->rate = runs[n].rate;
                test_exec(cgrp, opt);
        }
}

/* One run of 32 iovs of 8192 bytes each, at rate 2. */
static void test_send_large(struct sockmap_options *opt, int cgrp)
{
        opt->rate = 2;
        opt->iov_count = 32;
        opt->iov_length = 8192;

        test_exec(cgrp, opt);
}

static void test_send(struct sockmap_options *opt, int cgrp)
{
        test_send_one(opt, cgrp);
        test_send_many(opt, cgrp);
        test_send_large(opt, cgrp);
        sched_yield();
}

/* Run the full send matrix (small and large iov counts/lengths) with the
 * pass txmsg program selected; redir/apply/cork are exercised elsewhere.
 */
static void test_txmsg_pass(int cgrp, struct sockmap_options *opt)
{
        txmsg_pass = 1;
        test_send(opt, cgrp);
}

/* Full send matrix with the redirect flag set (run_options() then points
 * the redirect map at c2).
 */
static void test_txmsg_redir(int cgrp, struct sockmap_options *opt)
{
        txmsg_redir = 1;

        test_send(opt, cgrp);
}

/*
 * Large redirect sends with opt->tx_wait_mem set, first without and then
 * with apply_bytes = 4097. The caller's tx_wait_mem value is restored on
 * return instead of being forced to false, matching how the other tests
 * save and restore opt fields.
 */
static void test_txmsg_redir_wait_sndmem(int cgrp, struct sockmap_options *opt)
{
        bool wait_mem = opt->tx_wait_mem;

        opt->tx_wait_mem = true;
        txmsg_redir = 1;
        test_send_large(opt, cgrp);

        txmsg_redir = 1;
        txmsg_apply = 4097;
        test_send_large(opt, cgrp);

        opt->tx_wait_mem = wait_mem;
}

/* Full send matrix with the drop flag set; __test_exec() then marks the
 * run as drop_expected.
 */
static void test_txmsg_drop(int cgrp, struct sockmap_options *opt)
{
        txmsg_drop = 1;

        test_send(opt, cgrp);
}

/* Full send matrix with redirect + ingress; pass and drop are explicitly
 * cleared so the redirect program is the one selected.
 */
static void test_txmsg_ingress_redir(int cgrp, struct sockmap_options *opt)
{
        txmsg_drop = 0;
        txmsg_pass = 0;
        txmsg_redir = 1;
        txmsg_ingress = 1;

        test_send(opt, cgrp);
}

/*
 * Exercise the TLS-sockmap skb path (txmsg_ktls_skb) with data verification
 * enabled: plain, with txmsg_ktls_skb_drop, and with txmsg_ktls_skb_redir.
 * Then repeat with txmsg_omit_skb_parser set, with and without ktls.
 *
 * opt->data_test and ktls are saved on entry and restored on exit. Other
 * globals touched here are not restored (e.g. txmsg_ktls_skb_redir is left
 * at 1 and txmsg_pass at 1). NOTE(review): presumably the caller resets
 * them between tests; confirm.
 */
static void test_txmsg_skb(int cgrp, struct sockmap_options *opt)
{
        bool data = opt->data_test;
        int k = ktls;

        opt->data_test = true;
        ktls = 1;

        txmsg_pass = txmsg_drop = 0;
        txmsg_ingress = txmsg_redir = 0;
        txmsg_ktls_skb = 1;
        txmsg_pass = 1;

        /* Using data verification so ensure iov layout is
         * expected from test receiver side. e.g. has enough
         * bytes to write test code.
         */
        opt->iov_length = 100;
        opt->iov_count = 1;
        opt->rate = 1;
        test_exec(cgrp, opt);

        txmsg_ktls_skb_drop = 1;
        test_exec(cgrp, opt);

        txmsg_ktls_skb_drop = 0;
        txmsg_ktls_skb_redir = 1;
        test_exec(cgrp, opt);
        txmsg_ktls_skb_redir = 0;

        /* Tests that omit skb_parser */
        /* NOTE(review): txmsg_ktls_skb is 0 from here on, and run_options()
         * only consults txmsg_ktls_skb_drop/_redir inside its
         * `if (txmsg_ktls_skb)` block. The drop/redir toggles below therefore
         * do not change what run_options() configures; confirm this is
         * intended.
         */
        txmsg_omit_skb_parser = 1;
        ktls = 0;
        txmsg_ktls_skb = 0;
        test_exec(cgrp, opt);

        txmsg_ktls_skb_drop = 1;
        test_exec(cgrp, opt);
        txmsg_ktls_skb_drop = 0;

        txmsg_ktls_skb_redir = 1;
        test_exec(cgrp, opt);

        ktls = 1;
        test_exec(cgrp, opt);
        txmsg_omit_skb_parser = 0;

        /* Restore the caller's settings saved on entry. */
        opt->data_test = data;
        ktls = k;
}

/*
 * Cork with data left hanging: these emulate a buggy user program that
 * corks data and never flushes it, so each case waits for a timeout and
 * takes a while. Covered: pass + apply + cork, redir + cork, and
 * redir + apply + cork. A cork size of 4097 is used with the large sends so
 * that the cork size never lines up with the send size.
 */
static void test_txmsg_cork_hangs(int cgrp, struct sockmap_options *opt)
{
        static const struct {
                int pass;
                int redir;
                int apply;
                int cork;
        } cases[] = {
                { 1, 0, 4097, 4097 },   /* pass + apply + cork */
                { 0, 1,    0, 4097 },   /* redir + cork */
                { 0, 1, 4097, 4097 },   /* redir + apply + cork */
        };
        size_t n;

        for (n = 0; n < sizeof(cases) / sizeof(cases[0]); n++) {
                txmsg_pass = cases[n].pass;
                txmsg_redir = cases[n].redir;
                txmsg_apply = cases[n].apply;
                txmsg_cork = cases[n].cork;
                test_send_large(opt, cgrp);
        }
}

/*
 * Pull-data tests: txmsg_start/txmsg_end are written by run_options() to
 * sock_bytes keys 0 and 1, combined with pass, redirect and cork settings.
 *
 * NOTE(review): txmsg_pass is set to 1 in the first case and never cleared
 * here. run_options() picks the txmsg program with an if/else chain that
 * tests txmsg_pass first, so the "redirect" and "cork" cases below still
 * attach the pass program (progs[4]). Only the redirect fd / cork map
 * values differ. Unless something resets txmsg_pass between cases, confirm
 * this is intended.
 */
static void test_txmsg_pull(int cgrp, struct sockmap_options *opt)
{
        /* Test basic start/end */
        txmsg_pass = 1;
        txmsg_start = 1;
        txmsg_end = 2;
        test_send(opt, cgrp);

        /* Test >4k pull */
        txmsg_pass = 1;
        txmsg_start = 4096;
        txmsg_end = 9182;
        test_send_large(opt, cgrp);

        /* Test pull + redirect */
        txmsg_redir = 1;
        txmsg_start = 1;
        txmsg_end = 2;
        test_send(opt, cgrp);

        /* Test pull + cork */
        txmsg_redir = 0;
        txmsg_cork = 512;
        txmsg_start = 1;
        txmsg_end = 2;
        test_send_many(opt, cgrp);

        /* Test pull + cork + redirect */
        txmsg_redir = 1;
        txmsg_cork = 512;
        txmsg_start = 1;
        txmsg_end = 2;
        test_send_many(opt, cgrp);
}

/*
 * Pop-data tests: txmsg_start_pop/txmsg_pop are written by run_options() to
 * sock_bytes keys 4 and 5, combined with pass, redirect and cork.
 * opt->data_test is saved on entry and restored on exit; it is disabled for
 * the cork cases (see the TODO below).
 *
 * NOTE(review): txmsg_pass stays 1 throughout. run_options() selects the
 * pass program before checking redir/cork, so the "redirect" and "cork"
 * cases still attach progs[4]; confirm this is intended.
 */
static void test_txmsg_pop(int cgrp, struct sockmap_options *opt)
{
        bool data = opt->data_test;

        /* Test basic pop */
        txmsg_pass = 1;
        txmsg_start_pop = 1;
        txmsg_pop = 2;
        test_send_many(opt, cgrp);

        /* Test pop with >4k */
        txmsg_pass = 1;
        txmsg_start_pop = 4096;
        txmsg_pop = 4096;
        test_send_large(opt, cgrp);

        /* Test pop + redirect */
        txmsg_redir = 1;
        txmsg_start_pop = 1;
        txmsg_pop = 2;
        test_send_many(opt, cgrp);

        /* TODO: Test for pop + cork should be different,
         * - It makes the layout of the received data difficult
         * - It makes it hard to calculate the total_bytes in the recvmsg
         * Temporarily skip the data integrity test for this case now.
         */
        opt->data_test = false;
        /* Test pop + cork */
        txmsg_redir = 0;
        txmsg_cork = 512;
        txmsg_start_pop = 1;
        txmsg_pop = 2;
        test_send_many(opt, cgrp);

        /* Test pop + redirect + cork */
        txmsg_redir = 1;
        txmsg_cork = 4;
        txmsg_start_pop = 1;
        txmsg_pop = 2;
        test_send_many(opt, cgrp);
        opt->data_test = data;
}

/*
 * Push-data tests: txmsg_start_push/txmsg_end_push are written by
 * run_options() to sock_bytes keys 2 and 3, combined with pass, redirect
 * and cork. opt->data_test is saved on entry and restored on exit; it is
 * disabled for the cork case (see the TODO below).
 *
 * NOTE(review): txmsg_pass stays 1 throughout. run_options() selects the
 * pass program before checking redir/cork, so the "redirect" and "cork"
 * cases still attach progs[4]; confirm this is intended.
 */
static void test_txmsg_push(int cgrp, struct sockmap_options *opt)
{
        bool data = opt->data_test;

        /* Test basic push */
        txmsg_pass = 1;
        txmsg_start_push = 1;
        txmsg_end_push = 1;
        test_send(opt, cgrp);

        /* Test push 4kB >4k */
        txmsg_pass = 1;
        txmsg_start_push = 4096;
        txmsg_end_push = 4096;
        test_send_large(opt, cgrp);

        /* Test push + redirect */
        txmsg_redir = 1;
        txmsg_start_push = 1;
        txmsg_end_push = 2;
        test_send_many(opt, cgrp);

        /* TODO: Test for push + cork should be different,
         * - It makes the layout of the received data difficult
         * - It makes it hard to calculate the total_bytes in the recvmsg
         * Temporarily skip the data integrity test for this case now.
         */
        opt->data_test = false;
        /* Test push + cork */
        txmsg_redir = 0;
        txmsg_cork = 512;
        txmsg_start_push = 1;
        txmsg_end_push = 2;
        test_send_many(opt, cgrp);
        opt->data_test = data;
}

/*
 * Combined push + pop with the pass program and large sends. The first four
 * cases have overlapping push/pop ranges; the last two do not.
 */
static void test_txmsg_push_pop(int cgrp, struct sockmap_options *opt)
{
        static const struct {
                int start_push;
                int end_push;
                int start_pop;
                int pop;
        } ranges[] = {
                /* overlapping */
                {  1, 10,  5,  4 },
                {  1, 10,  5, 16 },
                {  5,  4,  1, 10 },
                {  5, 16,  1, 10 },
                /* non-overlapping */
                {  1, 10, 16,  4 },
                { 16, 10,  5,  4 },
        };
        size_t n;

        for (n = 0; n < sizeof(ranges) / sizeof(ranges[0]); n++) {
                txmsg_pass = 1;
                txmsg_start_push = ranges[n].start_push;
                txmsg_end_push = ranges[n].end_push;
                txmsg_start_pop = ranges[n].start_pop;
                txmsg_pop = ranges[n].pop;
                test_send_large(opt, cgrp);
        }
}

static void test_txmsg_apply(int cgrp, struct sockmap_options *opt)
{
        txmsg_pass = 1;
        txmsg_redir = 0;
        txmsg_ingress = 0;
        txmsg_apply = 1;
        txmsg_cork = 0;
        test_send_one(opt, cgrp);

        txmsg_pass = 0;
        txmsg_redir = 1;
        txmsg_ingress = 0;
        txmsg_apply = 1;
        txmsg_cork = 0;
        test_send_one(opt, cgrp);

        txmsg_pass = 0;
        txmsg_redir = 1;
        txmsg_ingress = 1;
        txmsg_apply = 1;
        txmsg_cork = 0;
        test_send_one(opt, cgrp);

        txmsg_pass = 1;
        txmsg_redir = 0;
        txmsg_ingress = 0;
        txmsg_apply = 1024;
        txmsg_cork = 0;
        test_send_large(opt, cgrp);

        txmsg_pass = 0;
        txmsg_redir = 1;
        txmsg_ingress = 0;
        txmsg_apply = 1024;
        txmsg_cork = 0;
        test_send_large(opt, cgrp);

        txmsg_pass = 0;
        txmsg_redir = 1;
        txmsg_ingress = 1;
        txmsg_apply = 1024;
        txmsg_cork = 0;
        test_send_large(opt, cgrp);
}

/* Full send matrix with the pass program and cork_bytes = 1, first with
 * apply_bytes off (0) and then with apply_bytes = 1.
 */
static void test_txmsg_cork(int cgrp, struct sockmap_options *opt)
{
        int apply;

        for (apply = 0; apply <= 1; apply++) {
                txmsg_pass = 1;
                txmsg_redir = 0;
                txmsg_apply = apply;
                txmsg_cork = 1;
                test_send(opt, cgrp);
        }
}

/*
 * One 256-byte iov at rate 2 with the pass program and skb_use_parser set.
 * run_options() writes skb_use_parser to map_fd[7] key 2; the value is 512,
 * or 570 when ktls == 1.
 */
static void test_txmsg_ingress_parser(int cgrp, struct sockmap_options *opt)
{
        txmsg_pass = 1;
        skb_use_parser = (ktls == 1) ? 570 : 512;

        opt->rate = 2;
        opt->iov_count = 1;
        opt->iov_length = 256;
        test_exec(cgrp, opt);
}

/*
 * One 20-byte iov at rate 1 with skb_use_parser = 10 and
 * opt->check_recved_len enabled for this run only. Skipped entirely when
 * ktls == 1.
 */
static void test_txmsg_ingress_parser2(int cgrp, struct sockmap_options *opt)
{
        if (ktls != 1) {
                skb_use_parser = 10;
                opt->rate = 1;
                opt->iov_count = 1;
                opt->iov_length = 20;

                opt->check_recved_len = true;
                test_exec(cgrp, opt);
                opt->check_recved_len = false;
        }
}

/* BPF map names, looked up by populate_progs(): map_fd[i] and maps[i]
 * refer to map_names[i], so the index is what run_options() uses.
 */
char *map_names[] = {
        [0] = "sock_map",
        [1] = "sock_map_txmsg",
        [2] = "sock_map_redir",
        [3] = "sock_apply_bytes",
        [4] = "sock_cork_bytes",
        [5] = "sock_bytes",
        [6] = "sock_redir_flags",
        [7] = "sock_skb_opts",
        [8] = "tls_sock_map",
};

/*
 * Open and load the BPF object in @bpf_file and publish its contents through
 * the progs[], maps[] and map_fd[] globals (map_fd[i] is the map named
 * map_names[i]). Also clears links[] so nothing is considered attached.
 *
 * Returns 0 on success, or -1 if any of the following holds:
 * - the object cannot be opened or loaded;
 * - it holds more programs than progs[] can record;
 * - one of the named maps is missing.
 */
static int populate_progs(char *bpf_file)
{
        struct bpf_program *prog;
        struct bpf_object *obj;
        char err_buf[256];
        size_t i;
        long err;

        obj = bpf_object__open(bpf_file);
        err = libbpf_get_error(obj);
        if (err) {
                libbpf_strerror(err, err_buf, sizeof(err_buf));
                printf("Unable to load eBPF objects in file '%s' : %s\n",
                       bpf_file, err_buf);
                return -1;
        }

        /* The load result used to be discarded. Check it explicitly: newer
         * libbpf may hand out placeholder map fds before load, so the map
         * fd check below cannot be relied on to catch a failed load.
         */
        err = bpf_object__load(obj);
        if (err) {
                libbpf_strerror(err, err_buf, sizeof(err_buf));
                fprintf(stderr, "Unable to load eBPF programs in file '%s' : %s\n",
                        bpf_file, err_buf);
                bpf_object__close(obj);
                return -1;
        }

        i = 0;
        bpf_object__for_each_program(prog, obj) {
                /* Never write past the fixed-size progs[] array. */
                if (i >= ARRAY_SIZE(progs)) {
                        fprintf(stderr,
                                "load_bpf_file: '%s' has more than %zu programs\n",
                                bpf_file, (size_t)ARRAY_SIZE(progs));
                        return -1;
                }
                progs[i++] = prog;
        }

        for (i = 0; i < ARRAY_SIZE(map_fd); i++) {
                maps[i] = bpf_object__find_map_by_name(obj, map_names[i]);
                map_fd[i] = bpf_map__fd(maps[i]);
                if (map_fd[i] < 0) {
                        fprintf(stderr, "load_bpf_file: (%i) %s\n",
                                map_fd[i], strerror(errno));
                        return -1;
                }
        }

        for (i = 0; i < ARRAY_SIZE(links); i++)
                links[i] = NULL;

        return 0;
}

/*
 * Selftest table run by __test_selftests(), in order. Each entry's title
 * is also matched against the -n/-b filter lists.
 */
struct _test test[] = {
        { .title = "txmsg test passthrough",            .tester = test_txmsg_pass },
        { .title = "txmsg test redirect",               .tester = test_txmsg_redir },
        { .title = "txmsg test redirect wait send mem", .tester = test_txmsg_redir_wait_sndmem },
        { .title = "txmsg test drop",                   .tester = test_txmsg_drop },
        { .title = "txmsg test ingress redirect",       .tester = test_txmsg_ingress_redir },
        { .title = "txmsg test skb",                    .tester = test_txmsg_skb },
        { .title = "txmsg test apply",                  .tester = test_txmsg_apply },
        { .title = "txmsg test cork",                   .tester = test_txmsg_cork },
        { .title = "txmsg test hanging corks",          .tester = test_txmsg_cork_hangs },
        { .title = "txmsg test push_data",              .tester = test_txmsg_push },
        { .title = "txmsg test pull-data",              .tester = test_txmsg_pull },
        { .title = "txmsg test pop-data",               .tester = test_txmsg_pop },
        { .title = "txmsg test push/pop data",          .tester = test_txmsg_push_pop },
        { .title = "txmsg test ingress parser",         .tester = test_txmsg_ingress_parser },
        { .title = "txmsg test ingress parser2",        .tester = test_txmsg_ingress_parser2 },
};

/*
 * Decide whether test @t passes the comma-separated -n whitelist.
 *
 * An entry matches when it is a substring of opt->prepend (if set),
 * opt->map, or t->title.
 *
 * Returns 0 if there is no whitelist or some entry matches,
 * -ENOMEM if the list copy cannot be allocated, and -EINVAL otherwise.
 */
static int check_whitelist(struct _test *t, struct sockmap_options *opt)
{
        char *list, *tok;
        bool hit = false;

        if (opt->whitelist == NULL)
                return 0;

        /* strtok() writes into its input, so tokenize a private copy. */
        list = strdup(opt->whitelist);
        if (list == NULL)
                return -ENOMEM;

        for (tok = strtok(list, ","); tok != NULL; tok = strtok(NULL, ",")) {
                bool in_prepend = opt->prepend != NULL &&
                                  strstr(opt->prepend, tok) != NULL;

                if (in_prepend ||
                    strstr(opt->map, tok) != NULL ||
                    strstr(t->title, tok) != NULL) {
                        hit = true;
                        break;
                }
        }

        free(list);
        return hit ? 0 : -EINVAL;
}

/*
 * Decide whether test @t is excluded by the comma-separated -b blacklist.
 *
 * An entry matches when it is a substring of opt->prepend (if set),
 * opt->map, or t->title.
 *
 * Note the convention: 0 means "blacklisted" (the caller skips the test).
 * Returns 0 if some entry matches, -ENOMEM if the list copy cannot be
 * allocated, and -EINVAL otherwise (including when no blacklist was given).
 */
static int check_blacklist(struct _test *t, struct sockmap_options *opt)
{
        char *list, *tok;
        bool hit = false;

        if (opt->blacklist == NULL)
                return -EINVAL;

        /* strtok() writes into its input, so tokenize a private copy. */
        list = strdup(opt->blacklist);
        if (list == NULL)
                return -ENOMEM;

        for (tok = strtok(list, ","); tok != NULL; tok = strtok(NULL, ",")) {
                bool in_prepend = opt->prepend != NULL &&
                                  strstr(opt->prepend, tok) != NULL;

                if (in_prepend ||
                    strstr(opt->map, tok) != NULL ||
                    strstr(t->title, tok) != NULL) {
                        hit = true;
                        break;
                }
        }

        free(list);
        return hit ? 0 : -EINVAL;
}

static int __test_selftests(int cg_fd, struct sockmap_options *opt)
{
        int i, err;

        err = populate_progs(opt->map);
        if (err < 0) {
                fprintf(stderr, "ERROR: (%i) load bpf failed\n", err);
                return err;
        }

        /* Tests basic commands and APIs */
        for (i = 0; i < ARRAY_SIZE(test); i++) {
                struct _test t = test[i];

                if (check_whitelist(&t, opt) != 0)
                        continue;
                if (check_blacklist(&t, opt) == 0)
                        continue;

                test_start_subtest(&t, opt);
                t.tester(cg_fd, opt);
                test_end_subtest();
        }

        return err;
}

/* Run the selftest table against the sockmap BPF object; result ignored. */
static void test_selftests_sockmap(int cg_fd, struct sockmap_options *opt)
{
        char *obj_file = BPF_SOCKMAP_FILENAME;

        opt->map = obj_file;
        (void)__test_selftests(cg_fd, opt);
}

/* Run the selftest table against the sockhash BPF object; result ignored. */
static void test_selftests_sockhash(int cg_fd, struct sockmap_options *opt)
{
        char *obj_file = BPF_SOCKHASH_FILENAME;

        opt->map = obj_file;
        (void)__test_selftests(cg_fd, opt);
}

/*
 * Run the selftest table against the sockhash BPF object with the global
 * ktls flag raised and "ktls" as the filter prefix. The flag is lowered
 * again afterwards; opt->prepend is left set to "ktls".
 */
static void test_selftests_ktls(int cg_fd, struct sockmap_options *opt)
{
        ktls = 1;
        opt->prepend = "ktls";
        opt->map = BPF_SOCKHASH_FILENAME;

        (void)__test_selftests(cg_fd, opt);

        ktls = 0;
}

static int test_selftest(int cg_fd, struct sockmap_options *opt)
{
        test_selftests_sockmap(cg_fd, opt);
        test_selftests_sockhash(cg_fd, opt);
        test_selftests_ktls(cg_fd, opt);
        test_print_results();
        return 0;
}

/*
 * Entry point: parse options, make sure a cgroup is available (creating
 * CG_PATH when -c was not given), then either run the selftest suites
 * (default) or a single test selected with -t.
 *
 * All exits after cgroup setup go through the `out` label, so the filter
 * strings and cgroup fd are released and a created cgroup is removed.
 */
int main(int argc, char **argv)
{
        int iov_count = 1, length = 1024, rate = 1;
        struct sockmap_options options = {0};
        int opt, longindex, err, cg_fd = 0;
        char *bpf_file = BPF_SOCKMAP_FILENAME;
        int test = SELFTESTS;
        bool cg_created = false;

        while ((opt = getopt_long(argc, argv, ":dhv:c:r:i:l:t:p:q:n:b:",
                                  long_options, &longindex)) != -1) {
                switch (opt) {
                case 's':
                        txmsg_start = atoi(optarg);
                        break;
                case 'e':
                        txmsg_end = atoi(optarg);
                        break;
                case 'p':
                        txmsg_start_push = atoi(optarg);
                        break;
                case 'q':
                        txmsg_end_push = atoi(optarg);
                        break;
                case 'w':
                        txmsg_start_pop = atoi(optarg);
                        break;
                case 'x':
                        txmsg_pop = atoi(optarg);
                        break;
                case 'a':
                        txmsg_apply = atoi(optarg);
                        break;
                case 'k':
                        txmsg_cork = atoi(optarg);
                        break;
                case 'c':
                        /*
                         * The access mode belongs in the flags argument;
                         * open()'s third argument is only a creation mode.
                         */
                        cg_fd = open(optarg, O_DIRECTORY | O_RDONLY);
                        if (cg_fd < 0) {
                                fprintf(stderr,
                                        "ERROR: (%i) open cg path failed: %s\n",
                                        cg_fd, optarg);
                                return cg_fd;
                        }
                        break;
                case 'r':
                        rate = atoi(optarg);
                        break;
                case 'v':
                        options.verbose = 1;
                        if (optarg)
                                options.verbose = atoi(optarg);
                        break;
                case 'i':
                        iov_count = atoi(optarg);
                        break;
                case 'l':
                        length = atoi(optarg);
                        break;
                case 'd':
                        options.data_test = true;
                        break;
                case 't':
                        if (strcmp(optarg, "ping") == 0) {
                                test = PING_PONG;
                        } else if (strcmp(optarg, "sendmsg") == 0) {
                                test = SENDMSG;
                        } else if (strcmp(optarg, "base") == 0) {
                                test = BASE;
                        } else if (strcmp(optarg, "base_sendpage") == 0) {
                                test = BASE_SENDPAGE;
                        } else if (strcmp(optarg, "sendpage") == 0) {
                                test = SENDPAGE;
                        } else {
                                usage(argv);
                                return -1;
                        }
                        break;
                case 'n':
                        options.whitelist = strdup(optarg);
                        if (!options.whitelist)
                                return -ENOMEM;
                        break;
                case 'b':
                        options.blacklist = strdup(optarg);
                        if (!options.blacklist)
                                return -ENOMEM;
                        break;
                case 0:
                        break;
                case 'h':
                default:
                        usage(argv);
                        return -1;
                }
        }

        if (!cg_fd) {
                cg_fd = cgroup_setup_and_join(CG_PATH);
                if (cg_fd < 0)
                        return cg_fd;
                cg_created = true;
        }

        /* Use libbpf 1.0 API mode */
        libbpf_set_strict_mode(LIBBPF_STRICT_ALL);

        if (test == SELFTESTS) {
                err = test_selftest(cg_fd, &options);
                goto out;
        }

        err = populate_progs(bpf_file);
        if (err) {
                fprintf(stderr, "populate program: (%s) %s\n",
                        bpf_file, strerror(errno));
                /* Go through `out` so a cgroup created above is cleaned up. */
                err = 1;
                goto out;
        }
        running = 1;

        /* catch SIGINT */
        signal(SIGINT, running_handler);

        options.iov_count = iov_count;
        options.iov_length = length;
        options.rate = rate;

        err = run_options(&options, cg_fd, test);
out:
        free(options.whitelist);
        free(options.blacklist);
        close(cg_fd);
        if (cg_created)
                cleanup_cgroup_environment();
        return err;
}

/*
 * SIGINT handler installed by main(): clears the global `running` flag.
 * The signal number is not used.
 * NOTE(review): `running` is a plain int; volatile sig_atomic_t would be
 * the portable type for a flag written from a signal handler.
 */
static void running_handler(int a)
{
        (void)a;
        running = 0;
}