#include <inttypes.h>
#include "../../../../include/linux/kernel.h"
#include "aolib.h"
/*
 * Number of bytes the servers are asked to serve via test_server_run();
 * anything else is reported as a failure.
 */
const size_t quota = 1000;
/*
 * Per-message size handed to test_client_verify(); the clients ask for
 * quota / packet_sz messages.
 */
const size_t packet_sz = 100;
/* Listen backlog passed to test_listen_socket() by the active-RST server. */
const unsigned int backlog = 0;
/*
 * Compare three TCP-AO netstat counters between two snapshots and emit one
 * test result per counter, in this order:
 *   "TCPAORequired" - must not have grown,
 *   "TCPAOGood"     - must have grown,
 *   "TCPAOBad"      - must not have grown.
 *
 * @before: snapshot taken before the scenario ran
 * @after:  snapshot taken after the scenario ran
 * @msg:    label embedded into every reported message
 */
static void netstats_check(struct netstat *before, struct netstat *after,
			   char *msg)
{
	uint64_t was, now;

	/* Any growth here is reported as segments lacking an AO signature. */
	was = netstat_get(before, "TCPAORequired", NULL);
	now = netstat_get(after, "TCPAORequired", NULL);
	if (now <= was)
		test_ok("No segments without AO sign (%s)", msg);
	else
		test_fail("Segments without AO sign (%s): %" PRIu64 " => %" PRIu64,
			  msg, was, now);

	/* The scenario is expected to have produced signed segments. */
	was = netstat_get(before, "TCPAOGood", NULL);
	now = netstat_get(after, "TCPAOGood", NULL);
	if (now > was)
		test_ok("Signed AO segments (%s): %" PRIu64 " => %" PRIu64,
			msg, was, now);
	else
		test_fail("Signed AO segments (%s): %" PRIu64 " => %" PRIu64,
			  msg, was, now);

	/* Any growth here is reported as segments with a bad AO signature. */
	was = netstat_get(before, "TCPAOBad", NULL);
	now = netstat_get(after, "TCPAOBad", NULL);
	if (now <= was)
		test_ok("No segments with bad AO sign (%s)", msg);
	else
		test_fail("Segments with bad AO sign (%s): %" PRIu64 " => %" PRIu64,
			  msg, was, now);
}
/*
 * Close @sk with SO_LINGER enabled and a linger time of zero.
 * NOTE(review): for TCP this is expected to abort the connection (RST)
 * instead of performing an orderly shutdown - confirm against socket(7).
 */
static void close_forced(int sk)
{
	struct linger zero_linger = {
		.l_onoff = 1,
		.l_linger = 0,
	};

	if (setsockopt(sk, SOL_SOCKET, SO_LINGER,
		       &zero_linger, sizeof(zero_linger)) != 0)
		test_error("setsockopt(SO_LINGER)");
	close(sk);
}
static void test_server_active_rst(unsigned int port)
{
struct tcp_counters cnt1, cnt2;
ssize_t bytes;
int sk, lsk;
lsk = test_listen_socket(this_ip_addr, port, backlog);
if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
test_error("setsockopt(TCP_AO_ADD_KEY)");
if (test_get_tcp_counters(lsk, &cnt1))
test_error("test_get_tcp_counters()");
synchronize_threads();
if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
test_error("test_wait_fd()");
sk = accept(lsk, NULL, NULL);
if (sk < 0)
test_error("accept()");
synchronize_threads();
if (test_get_tcp_counters(lsk, &cnt2))
test_error("test_get_tcp_counters()");
synchronize_threads();
close(lsk);
bytes = test_server_run(sk, quota, 0);
if (bytes != quota)
test_error("servered only %zd bytes", bytes);
else
test_ok("servered %zd bytes", bytes);
synchronize_threads();
close_forced(sk);
synchronize_threads();
synchronize_threads();
if (test_assert_counters("active RST server", &cnt1, &cnt2, TEST_CNT_GOOD))
test_fail("MKT counters (server) have not only good packets");
else
test_ok("MKT counters are good on server");
}
static void test_server_passive_rst(unsigned int port)
{
struct tcp_counters cnt1, cnt2;
int sk, lsk;
ssize_t bytes;
lsk = test_listen_socket(this_ip_addr, port, 1);
if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
test_error("setsockopt(TCP_AO_ADD_KEY)");
synchronize_threads();
if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
test_error("test_wait_fd()");
sk = accept(lsk, NULL, NULL);
if (sk < 0)
test_error("accept()");
synchronize_threads();
close(lsk);
if (test_get_tcp_counters(sk, &cnt1))
test_error("test_get_tcp_counters()");
bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
if (bytes != quota) {
if (bytes > 0)
test_fail("server served: %zd", bytes);
else
test_fail("server returned %zd", bytes);
}
synchronize_threads();
synchronize_threads();
if (test_get_tcp_counters(sk, &cnt2))
test_error("test_get_tcp_counters()");
close(sk);
synchronize_threads();
test_assert_counters("passive RST server", &cnt1, &cnt2, TEST_CNT_GOOD);
synchronize_threads();
}
/*
 * Server thread entry point handed to test_init().
 *
 * Runs the active-RST scenario on test_server_port and the passive-RST
 * scenario on the following port, then checks the netstat counters that
 * changed in between. @arg is not used. Always returns NULL.
 */
static void *server_fn(void *arg)
{
	const unsigned int base_port = test_server_port;
	struct netstat *ns_start, *ns_end;

	ns_start = netstat_read();

	test_server_active_rst(base_port);
	test_server_passive_rst(base_port + 1);

	ns_end = netstat_read();
	netstats_check(ns_start, ns_end, "server");
	netstat_free(ns_end);
	netstat_free(ns_start);

	synchronize_threads(); /* matches the last barrier in client_fn() */

	/*
	 * client_fn() has no counterpart for this barrier.
	 * NOTE(review): presumably this parks the server thread until the
	 * process exits from the client side - confirm in the test library.
	 */
	synchronize_threads();
	return NULL;
}
/*
 * Wait until @wait_for of the descriptors in @sk[] were reported by select()
 * as writable or as having an exceptional condition.
 *
 * @sk:          descriptors to watch
 * @nr:          number of entries in @sk (and in @is_writable, if given)
 * @is_writable: optional; every entry is reset to false first and set to
 *               true for each descriptor select() reported as writable
 * @wait_for:    how many descriptors have to be reported before returning
 * @sec:         timeout in seconds; 0 means no timeout is passed to select()
 *
 * A descriptor that was reported once is not waited for again.
 *
 * Return: 0 on success,
 *         -EBADF     if a descriptor is negative or >= FD_SETSIZE,
 *         -ENOENT    if no descriptor is left to wait for,
 *         -ETIMEDOUT if select() timed out,
 *         -errno     if select() failed.
 */
static int test_wait_fds(int sk[], size_t nr, bool is_writable[],
			 ssize_t wait_for, time_t sec)
{
	struct timeval tv = { .tv_sec = sec };
	struct timeval *ptv = NULL;
	fd_set left;
	size_t i;
	int ret;

	/*
	 * FD_SET()/FD_ISSET() are undefined for descriptors outside of
	 * [0, FD_SETSIZE): reject those before touching any fd_set or
	 * the caller's @is_writable array.
	 */
	for (i = 0; i < nr; i++) {
		if (sk[i] < 0 || sk[i] >= FD_SETSIZE)
			return -EBADF;
	}

	FD_ZERO(&left);
	for (i = 0; i < nr; i++) {
		FD_SET(sk[i], &left);
		if (is_writable)
			is_writable[i] = false;
	}

	if (sec)
		ptv = &tv;

	do {
		bool is_empty = true;
		fd_set fds, efds;
		int nfd = 0;

		/* Rebuild the sets from the descriptors still pending. */
		FD_ZERO(&fds);
		FD_ZERO(&efds);
		for (i = 0; i < nr; i++) {
			if (!FD_ISSET(sk[i], &left))
				continue;

			if (sk[i] > nfd)
				nfd = sk[i];

			FD_SET(sk[i], &fds);
			FD_SET(sk[i], &efds);
			is_empty = false;
		}
		if (is_empty)
			return -ENOENT;

		/*
		 * The same timeval is handed to every call.
		 * NOTE(review): Linux select() updates it with the time left,
		 * which would make @sec a budget for the whole wait - confirm.
		 */
		errno = 0;
		ret = select(nfd + 1, NULL, &fds, &efds, ptv);
		if (ret < 0)
			return -errno;
		if (!ret)
			return -ETIMEDOUT;

		for (i = 0; i < nr; i++) {
			if (FD_ISSET(sk[i], &fds)) {
				if (is_writable)
					is_writable[i] = true;
				FD_CLR(sk[i], &left);
				wait_for--;
				/* Count a descriptor once even if it is in both sets. */
				continue;
			}
			if (FD_ISSET(sk[i], &efds)) {
				FD_CLR(sk[i], &left);
				wait_for--;
			}
		}
	} while (wait_for > 0);

	return 0;
}
/*
 * Client half of the active-reset scenario; the peer is
 * test_server_active_rst(), which passes the same six barriers.
 *
 * Three sockets with a TCP-AO key are created. The first two are connected
 * before the server closes its listening socket, the third connect() is
 * issued right before that. Data is verified on the first socket, and after
 * the server force-closed its end every socket that was seen writable is
 * required to report ECONNRESET via SO_ERROR.
 *
 * All sockets created here are closed before returning.
 */
static void test_client_active_rst(unsigned int port)
{
	int i, sk[3], err;
	bool is_writable[ARRAY_SIZE(sk)] = {false};
	unsigned int last = ARRAY_SIZE(sk) - 1;

	for (i = 0; i < ARRAY_SIZE(sk); i++) {
		sk[i] = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
		if (sk[i] < 0)
			test_error("socket()");
		if (test_add_key(sk[i], DEFAULT_TEST_PASSWORD,
				 this_ip_dest, -1, 100, 100))
			test_error("setsockopt(TCP_AO_ADD_KEY)");
	}

	synchronize_threads(); /* 1: keys installed on both sides */
	for (i = 0; i < last; i++) {
		/*
		 * NOTE(review): the last argument is presumably an
		 * "asynchronous connect" flag, i.e. only sk[0] connects
		 * synchronously - confirm in the test library.
		 */
		err = _test_connect_socket(sk[i], this_ip_dest, port, i != 0);
		if (err < 0)
			test_error("failed to connect()");
	}

	synchronize_threads(); /* 2: server accept()ed one connection */
	/* Records in is_writable[] which of the first two sockets connected. */
	err = test_wait_fds(sk, last, is_writable, last, TEST_TIMEOUT_SEC);
	if (err < 0)
		test_error("test_wait_fds(): %d", err);

	/* Third connection attempt, issued just before the listener closes. */
	err = _test_connect_socket(sk[last], this_ip_dest, port, 1);
	if (err < 0)
		test_error("failed to connect()");

	synchronize_threads(); /* 3: server closes its listening socket */
	if (test_client_verify(sk[0], packet_sz, quota / packet_sz))
		test_fail("Failed to send data on connected socket");
	else
		test_ok("Verified established tcp connection");

	synchronize_threads(); /* 4: data exchange finished */

	synchronize_threads(); /* 5: server force-closed the accepted socket */
	/* Only the first two sockets are waited for; is_writable[] is kept. */
	err = test_wait_fds(sk, last, NULL, last, TEST_TIMEOUT_SEC);
	if (err < 0)
		test_error("select(): %d", err);

	for (i = 0; i < ARRAY_SIZE(sk); i++) {
		socklen_t slen = sizeof(err);

		if (getsockopt(sk[i], SOL_SOCKET, SO_ERROR, &err, &slen))
			test_error("getsockopt()");
		/* sk[last] was never marked writable, so it passes as is. */
		if (is_writable[i] && err != ECONNRESET) {
			test_fail("sk[%d] = %d, err = %d, connection wasn't reset",
				  i, sk[i], err);
		} else {
			test_ok("sk[%d] = %d%s", i, sk[i],
				is_writable[i] ? ", connection was reset" : "");
		}
	}

	synchronize_threads(); /* 6: both sides evaluate their results */

	/*
	 * Release the descriptors only after the last barrier, so nothing
	 * the server observes before it changes. Without this the sockets
	 * (including the never-established sk[last]) stay alive for the
	 * rest of the process.
	 */
	for (i = 0; i < ARRAY_SIZE(sk); i++)
		close(sk[i]);
}
static void test_client_passive_rst(unsigned int port)
{
struct tcp_counters cnt1, cnt2;
struct tcp_ao_repair ao_img;
struct tcp_sock_state img;
sockaddr_af saddr;
int sk, err;
sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
if (sk < 0)
test_error("socket()");
if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
test_error("setsockopt(TCP_AO_ADD_KEY)");
synchronize_threads();
if (test_connect_socket(sk, this_ip_dest, port) <= 0)
test_error("failed to connect()");
synchronize_threads();
if (test_client_verify(sk, packet_sz, quota / packet_sz))
test_fail("Failed to send data on connected socket");
else
test_ok("Verified established tcp connection");
synchronize_threads();
test_enable_repair(sk);
test_sock_checkpoint(sk, &img, &saddr);
test_ao_checkpoint(sk, &ao_img);
test_disable_repair(sk);
synchronize_threads();
img.out.seq += 5;
sleep(TEST_RETRANSMIT_SEC);
synchronize_threads();
test_kill_sk(sk);
sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
if (sk < 0)
test_error("socket()");
test_enable_repair(sk);
test_sock_restore(sk, &img, &saddr, this_ip_dest, port);
if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, this_ip_dest, -1, 100, 100))
test_error("setsockopt(TCP_AO_ADD_KEY)");
test_ao_restore(sk, &ao_img);
if (test_get_tcp_counters(sk, &cnt1))
test_error("test_get_tcp_counters()");
test_disable_repair(sk);
test_sock_state_free(&img);
err = test_client_verify(sk, packet_sz, quota / packet_sz);
if (err && err == -ECONNRESET)
test_ok("client sock was passively reset post-seq-adjust");
else if (err)
test_fail("client sock was not reset post-seq-adjust: %d", err);
else
test_fail("client sock is yet connected post-seq-adjust");
if (test_get_tcp_counters(sk, &cnt2))
test_error("test_get_tcp_counters()");
synchronize_threads();
close(sk);
test_assert_counters("client passive RST", &cnt1, &cnt2, TEST_CNT_GOOD);
}
/*
 * Client thread entry point handed to test_init().
 *
 * Runs the active-RST scenario against test_server_port and the passive-RST
 * scenario against the following port, then checks the netstat counters
 * that changed in between. @arg is not used. Always returns NULL.
 */
static void *client_fn(void *arg)
{
	const unsigned int base_port = test_server_port;
	struct netstat *ns_start, *ns_end;

	ns_start = netstat_read();

	test_client_active_rst(base_port);
	test_client_passive_rst(base_port + 1);

	ns_end = netstat_read();
	netstats_check(ns_start, ns_end, "client");
	netstat_free(ns_end);
	netstat_free(ns_start);

	synchronize_threads(); /* matches the first closing barrier in server_fn() */
	return NULL;
}
/*
 * Program entry point: hands the server and client thread bodies to
 * test_init(). Command-line arguments are ignored.
 */
int main(int argc, char *argv[])
{
	/* NOTE(review): presumably the number of planned test results - confirm. */
	enum { NR_TESTS = 15 };

	(void)argc;
	(void)argv;

	test_init(NR_TESTS, server_fn, client_fn);
	return 0;
}