root/tools/testing/selftests/net/tcp_ao/connect-deny.c
// SPDX-License-Identifier: GPL-2.0
/* Author: Dmitry Safonov <dima@arista.com> */
#include <inttypes.h>
#include "aolib.h"

/*
 * fault(type) - true when the fault-injection mode of the enclosing function
 * equals FAULT_<type>. The expansion refers to a variable named "inj", so the
 * macro can only be used where such a variable is in scope.
 */
#define fault(type)     (inj == FAULT_ ## type)
/*
 * Status word whose address is handed to test_skpair_wait_poll() and
 * test_skpair_connect_poll(); try_accept() and try_connect() additionally
 * store the negative error code they got back into it.
 */
static volatile int sk_pair;

static inline int test_add_key_maclen(int sk, const char *key, uint8_t maclen,
                                      union tcp_addr in_addr, uint8_t prefix,
                                      uint8_t sndid, uint8_t rcvid)
{
        struct tcp_ao_add tmp = {};
        int err;

        if (prefix > DEFAULT_TEST_PREFIX)
                prefix = DEFAULT_TEST_PREFIX;

        err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr, false, false,
                               prefix, 0, sndid, rcvid, maclen,
                               0, strlen(key), key);
        if (err)
                return err;

        err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
        if (err < 0)
                return -errno;

        return test_verify_socket_key(sk, &tmp);
}

/*
 * try_accept() - listener half of one connect-deny scenario.
 * @tst_name:     scenario name, used as a prefix in every reported result
 * @port:         port the listening socket is created on
 * @pwd:          password of the key added to the listener, or NULL to add
 *                no key (the per-socket counter checks are skipped as well)
 * @addr:         peer address for the key, forwarded to test_add_key_maclen()
 * @prefix:       prefix length for the key, forwarded likewise
 * @sndid:        send id of the key
 * @rcvid:        receive id of the key
 * @maclen:       MAC length of the key
 * @cnt_name:     netstat counter that must strictly increase during the
 *                scenario, or NULL to skip that check
 * @cnt_expected: expectation passed to test_assert_counters(); unless it is
 *                TEST_CNT_GOOD it is also passed to test_skpair_wait_poll()
 * @inj:          fault the scenario is expected to hit (see fault())
 *
 * The function calls synchronize_threads() three times ("preparations done",
 * "before counter checks", "close()"); try_connect() issues calls with the
 * same labels.
 */
static void try_accept(const char *tst_name, unsigned int port, const char *pwd,
                       union tcp_addr addr, uint8_t prefix,
                       uint8_t sndid, uint8_t rcvid, uint8_t maclen,
                       const char *cnt_name, test_cnt cnt_expected,
                       fault_t inj)
{
        struct tcp_counters cnt1, cnt2;
        uint64_t before_cnt = 0, after_cnt = 0; /* silence GCC */
        test_cnt poll_cnt = (cnt_expected == TEST_CNT_GOOD) ? 0 : cnt_expected;
        /* -1 marks "nothing accepted": 0 is a valid descriptor number */
        int lsk, err, sk = -1;

        lsk = test_listen_socket(this_ip_addr, port, 1);

        if (pwd && test_add_key_maclen(lsk, pwd, maclen, addr, prefix, sndid, rcvid))
                test_error("setsockopt(TCP_AO_ADD_KEY)");

        /* Snapshot the counters before the peer starts connecting */
        if (cnt_name)
                before_cnt = netstat_get_one(cnt_name, NULL);
        if (pwd && test_get_tcp_counters(lsk, &cnt1))
                test_error("test_get_tcp_counters()");

        synchronize_threads(); /* preparations done */

        err = test_skpair_wait_poll(lsk, 0, poll_cnt, &sk_pair);
        if (err == -ETIMEDOUT) {
                /* A timeout is only acceptable when it was the injected fault */
                sk_pair = err;
                if (!fault(TIMEOUT))
                        test_fail("%s: timed out for accept()", tst_name);
        } else if (err == -EKEYREJECTED) {
                /* Likewise a key rejection is only fine for FAULT_KEYREJECT */
                if (!fault(KEYREJECT))
                        test_fail("%s: key was rejected", tst_name);
        } else if (err < 0) {
                test_error("test_skpair_wait_poll()");
        } else {
                if (fault(TIMEOUT))
                        test_fail("%s: ready to accept", tst_name);

                sk = accept(lsk, NULL, NULL);
                if (sk < 0) {
                        test_error("accept()");
                } else {
                        if (fault(TIMEOUT))
                                test_fail("%s: accepted", tst_name);
                }
        }

        synchronize_threads(); /* before counter checks */
        if (pwd && test_get_tcp_counters(lsk, &cnt2))
                test_error("test_get_tcp_counters()");

        close(lsk);

        if (pwd)
                test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected);

        if (!cnt_name)
                goto out;

        after_cnt = netstat_get_one(cnt_name, NULL);

        if (after_cnt <= before_cnt) {
                test_fail("%s: %s counter did not increase: %" PRIu64 " <= %" PRIu64,
                                tst_name, cnt_name, after_cnt, before_cnt);
        } else {
                test_ok("%s: counter %s increased %" PRIu64  " => %" PRIu64,
                        tst_name, cnt_name, before_cnt, after_cnt);
        }

out:
        synchronize_threads(); /* close() */
        /* Close whatever accept() returned, including descriptor 0 */
        if (sk >= 0)
                close(sk);
}

/*
 * Server side of the test: runs try_accept() once per scenario, each on its
 * own port, counting up from test_server_port. client_fn() walks through the
 * scenarios of the same names in the same order, so the two sequences (and
 * their port increments) have to be kept in step.
 *
 * @arg is unused. Always returns NULL.
 */
static void *server_fn(void *arg)
{
        union tcp_addr wrong_addr, network_addr;
        unsigned int port = test_server_port;

        if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
                test_error("Can't convert ip address %s", TEST_WRONG_IP);

        /* Listener without a key (pwd == NULL) */
        try_accept("Non-AO server + AO client", port++, NULL,
                   this_ip_dest, -1, 100, 100, 0,
                   "TCPAOKeyNotFound", TEST_CNT_NS_KEY_NOT_FOUND, FAULT_TIMEOUT);

        /* Listener has a key; client_fn() connects without one */
        try_accept("AO server + Non-AO client", port++, DEFAULT_TEST_PASSWORD,
                   this_ip_dest, -1, 100, 100, 0,
                   "TCPAORequired", TEST_CNT_AO_REQUIRED, FAULT_TIMEOUT);

        /* Listener key uses a password other than DEFAULT_TEST_PASSWORD */
        try_accept("Wrong password", port++, "something that is not DEFAULT_TEST_PASSWORD",
                   this_ip_dest, -1, 100, 100, 0,
                   "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT);

        /* Listener rcvid is 101 while client_fn() uses 100/100 */
        try_accept("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
                   this_ip_dest, -1, 100, 101, 0,
                   "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);

        /* Listener sndid is 101 while client_fn() uses 100/100 */
        try_accept("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
                   this_ip_dest, -1, 101, 100, 0,
                   "TCPAOGood", TEST_CNT_GOOD, FAULT_TIMEOUT);

        /* Listener key is added with maclen 8 instead of 0 */
        try_accept("Different maclen", port++, DEFAULT_TEST_PASSWORD,
                   this_ip_dest, -1, 100, 100, 8,
                   "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT);

        /* Listener key is added for wrong_addr (TEST_WRONG_IP) */
        try_accept("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
                   wrong_addr, -1, 100, 100, 0,
                   "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);

        /* Key rejected by the other side, failing short through skpair */
        try_accept("Client: Wrong addr", port++, NULL,
                   this_ip_dest, -1, 100, 100, 0, NULL, 0, FAULT_KEYREJECT);

        /* From here on no fault is injected (inj == 0) */
        /* Listener uses sndid 200 / rcvid 100; client_fn() uses 100 / 200 */
        try_accept("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
                   this_ip_dest, -1, 200, 100, 0,
                   "TCPAOGood", TEST_CNT_GOOD, 0);

        if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1)
                test_error("Can't convert ip address %s", TEST_NETWORK);

        /* Listener key is added for network_addr with a prefix length of 16 */
        try_accept("Server: prefix match", port++, DEFAULT_TEST_PASSWORD,
                   network_addr, 16, 100, 100, 0,
                   "TCPAOGood", TEST_CNT_GOOD, 0);

        /* Here client_fn() is the side that uses network_addr with prefix 16 */
        try_accept("Client: prefix match", port++, DEFAULT_TEST_PASSWORD,
                   this_ip_dest, -1, 100, 100, 0,
                   "TCPAOGood", TEST_CNT_GOOD, 0);

        /* client exits */
        synchronize_threads();
        return NULL;
}

/*
 * try_connect() - connecting half of one connect-deny scenario.
 * @tst_name:     scenario name, used as a prefix in every reported result
 * @port:         destination port on this_ip_dest
 * @pwd:          password of the key added to the socket, or NULL to add no
 *                key (the per-socket counter checks are skipped as well)
 * @addr:         peer address for the key, forwarded to test_add_key()
 * @prefix:       prefix length for the key, forwarded likewise
 * @sndid:        send id of the key
 * @rcvid:        receive id of the key
 * @cnt_expected: expectation passed to test_skpair_connect_poll() and, on a
 *                successful connect, to test_assert_counters()
 * @inj:          fault the scenario is expected to hit (see fault())
 *
 * The function calls synchronize_threads() three times ("preparations done",
 * "before counter checks", "close()"); try_accept() issues calls with the
 * same labels.
 */
static void try_connect(const char *tst_name, unsigned int port,
                        const char *pwd, union tcp_addr addr, uint8_t prefix,
                        uint8_t sndid, uint8_t rcvid,
                        test_cnt cnt_expected, fault_t inj)
{
        struct tcp_counters cnt1, cnt2;
        int sk, ret;

        sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
        if (sk < 0)
                test_error("socket()");

        if (pwd && test_add_key(sk, pwd, addr, prefix, sndid, rcvid))
                test_error("setsockopt(TCP_AO_ADD_KEY)");

        /* Snapshot the socket counters before connecting */
        if (pwd && test_get_tcp_counters(sk, &cnt1))
                test_error("test_get_tcp_counters()");

        synchronize_threads(); /* preparations done */

        ret = test_skpair_connect_poll(sk, this_ip_dest, port, cnt_expected, &sk_pair);
        synchronize_threads(); /* before counter checks */
        if (ret < 0) {
                sk_pair = ret;
                /* A failure passes only if it matches the injected fault */
                if (fault(KEYREJECT) && ret == -EKEYREJECTED) {
                        test_ok("%s: connect() was prevented", tst_name);
                } else if (ret == -ETIMEDOUT && fault(TIMEOUT)) {
                        test_ok("%s", tst_name);
                } else if (ret == -ECONNREFUSED &&
                                (fault(TIMEOUT) || fault(KEYREJECT))) {
                        test_ok("%s: refused to connect", tst_name);
                } else {
                        test_error("%s: connect() returned %d", tst_name, ret);
                }
                /*
                 * The snapshot is not handed to test_assert_counters() on
                 * this path, so release it here, as the ret == 0 case below
                 * does.
                 */
                if (pwd)
                        test_tcp_counters_free(&cnt1);
                goto out;
        }

        if (fault(TIMEOUT) || fault(KEYREJECT))
                test_fail("%s: connected", tst_name);
        else
                test_ok("%s: connected", tst_name);
        if (pwd && ret > 0) {
                if (test_get_tcp_counters(sk, &cnt2))
                        test_error("test_get_tcp_counters()");
                test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected);
        } else if (pwd) {
                test_tcp_counters_free(&cnt1);
        }
out:
        synchronize_threads(); /* close() */

        /*
         * NOTE(review): sk is closed here only when ret > 0; presumably
         * test_skpair_connect_poll() disposes of the socket itself in the
         * other cases - confirm against the helper.
         */
        if (ret > 0)
                close(sk);
}

/*
 * Client side of the test: runs try_connect() once per scenario, each against
 * its own port, counting up from test_server_port. server_fn() walks through
 * the scenarios of the same names in the same order, so the two sequences
 * (and their port increments) have to be kept in step.
 *
 * Several scenarios are preceded by a trace_*_expect() call that is given the
 * port of the upcoming connection attempt; presumably this registers the
 * trace event the scenario should produce - confirm against aolib.
 *
 * @arg is unused. Always returns NULL.
 */
static void *client_fn(void *arg)
{
        union tcp_addr wrong_addr, network_addr, addr_any = {};
        unsigned int port = test_server_port;

        if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
                test_error("Can't convert ip address %s", TEST_WRONG_IP);

        trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
                              -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
        try_connect("Non-AO server + AO client", port++, DEFAULT_TEST_PASSWORD,
                        this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

        /* The only scenario in which the client adds no key (pwd == NULL) */
        trace_hash_event_expect(TCP_HASH_AO_REQUIRED, this_ip_addr, this_ip_dest,
                                -1, port, 0, 0, 1, 0, 0, 0);
        try_connect("AO server + Non-AO client", port++, NULL,
                        this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

        trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_addr, this_ip_dest,
                              -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
        try_connect("Wrong password", port++, DEFAULT_TEST_PASSWORD,
                        this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

        trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
                              -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
        try_connect("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
                        this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

        /*
         * XXX: The test doesn't increase any counters, see tcp_make_synack().
         * Potentially, it can be sped up by setting sk_pair = -ETIMEDOUT
         * but the price would be increased complexity of the tracer thread.
         */
        trace_ao_event_sk_expect(TCP_AO_SYNACK_NO_KEY, this_ip_dest, addr_any,
                                 port, 0, 100, 100);
        try_connect("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
                        this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

        trace_ao_event_expect(TCP_AO_WRONG_MACLEN, this_ip_addr, this_ip_dest,
                              -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
        try_connect("Different maclen", port++, DEFAULT_TEST_PASSWORD,
                        this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

        trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest,
                              -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1);
        try_connect("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
                        this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);

        /* Client key is added for wrong_addr (TEST_WRONG_IP) */
        try_connect("Client: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
                        wrong_addr, -1, 100, 100, 0, FAULT_KEYREJECT);

        /* From here on no fault is injected (inj == 0) */
        /* Client uses sndid 100 / rcvid 200; server_fn() uses 200 / 100 */
        try_connect("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
                        this_ip_dest, -1, 100, 200, TEST_CNT_GOOD, 0);

        if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1)
                test_error("Can't convert ip address %s", TEST_NETWORK);

        /* Here server_fn() is the side that uses network_addr with prefix 16 */
        try_connect("Server: prefix match", port++, DEFAULT_TEST_PASSWORD,
                        this_ip_dest, -1, 100, 100, TEST_CNT_GOOD, 0);

        /* Client key is added for network_addr with a prefix length of 16 */
        try_connect("Client: prefix match", port++, DEFAULT_TEST_PASSWORD,
                        network_addr, 16, 100, 100, TEST_CNT_GOOD, 0);

        return NULL;
}

/*
 * Entry point: hands server_fn() and client_fn() to test_init().
 * NOTE(review): 22 is presumably the number of test results the run is
 * expected to report - confirm against aolib and keep it in sync when
 * scenarios are added or removed.
 * argc/argv are unused.
 */
int main(int argc, char *argv[])
{
        test_init(22, server_fn, client_fn);
        return 0;
}