root/tools/testing/selftests/bpf/progs/bpf_iter_setsockopt.c
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2021 Facebook */
#include <vmlinux.h>
#include "bpf_tracing_net.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>

/*
 * Convert a struct sock_common pointer into a struct tcp_sock pointer.
 *
 * This macro writes to two variables that must already be declared in the
 * calling scope:
 *   tp - result of bpf_skc_to_tcp_sock(), or NULL when skc is NULL
 *   sk - the same pointer as tp, viewed as a struct sock *
 *
 * The expression evaluates to tp, so it is NULL both when skc is NULL and
 * when bpf_skc_to_tcp_sock() returns NULL. skc is evaluated exactly once.
 */
#define bpf_tcp_sk(skc) ({                                              \
        struct sock_common *common_ = (skc);                            \
        tp = common_ ? bpf_skc_to_tcp_sock(common_) : NULL;             \
        sk = (struct sock *)tp;                                         \
        tp;                                                             \
})

/*
 * Knobs read by change_tcp_cc(). Nothing in this file writes them, so they
 * are presumably filled in by the user-space side of the test before the
 * iterator runs (TODO confirm against the loader).
 */

/*
 * Port numbers in host byte order. A socket is selected when its local port
 * (sk_num) equals either value, or when its destination port (sk_dport,
 * converted with bpf_ntohs()) equals listen_hport.
 */
unsigned short reuse_listen_hport = 0;
unsigned short listen_hport = 0;
/* Only sockets whose current congestion control has this name are changed. */
const char cubic_cc[] = "bpf_cubic";
/* Congestion control name installed on the selected sockets. */
char dctcp_cc[TCP_CA_NAME_MAX] = "bpf_dctcp";
/*
 * When true, change_tcp_cc() returns 1 instead of switching the socket
 * whenever bpf_get_prandom_u32() % 4 == 1.
 */
bool random_retry = false;


/*
 * TCP socket iterator program: switch the congestion control of selected
 * sockets from cubic_cc to dctcp_cc.
 *
 * A socket is selected when all of the following hold:
 *   - it converts to a tcp_sock,
 *   - its family is AF_INET6,
 *   - its state is TCP_LISTEN or TCP_ESTABLISHED,
 *   - its local port is reuse_listen_hport or listen_hport, or its
 *     destination port is listen_hport,
 *   - its current congestion control name equals cubic_cc.
 *
 * Returns 0 in every case except the random_retry path, which returns 1
 * (presumably asking the iterator to present the same socket again; see the
 * BPF iterator documentation). The result of bpf_setsockopt() is not
 * checked.
 */
SEC("iter/tcp")
int change_tcp_cc(struct bpf_iter__tcp *ctx)
{
        char cur_cc[TCP_CA_NAME_MAX];
        struct tcp_sock *tp;
        struct sock *sk;

        /* bpf_tcp_sk() assigns both tp and sk as a side effect. */
        if (bpf_tcp_sk(ctx->sk_common) == NULL)
                return 0;

        /* Only IPv6 sockets are of interest. */
        if (sk->sk_family != AF_INET6)
                return 0;

        /* Only listeners and established connections. */
        if (sk->sk_state != TCP_LISTEN && sk->sk_state != TCP_ESTABLISHED)
                return 0;

        /* The socket must involve one of the configured ports. */
        if (sk->sk_num != reuse_listen_hport &&
            sk->sk_num != listen_hport &&
            bpf_ntohs(sk->sk_dport) != listen_hport)
                return 0;

        /* Fetch the current congestion control name; skip on failure. */
        if (bpf_getsockopt(tp, SOL_TCP, TCP_CONGESTION,
                           cur_cc, sizeof(cur_cc)) != 0)
                return 0;

        /* bpf_strncmp() yields 0 on a match: leave non-cubic sockets alone. */
        if (bpf_strncmp(cur_cc, TCP_CA_NAME_MAX, cubic_cc) != 0)
                return 0;

        if (random_retry) {
                /* Roughly one call in four bails out with 1. */
                if (bpf_get_prandom_u32() % 4 == 1)
                        return 1;
        }

        bpf_setsockopt(tp, SOL_TCP, TCP_CONGESTION, dctcp_cc, sizeof(dctcp_cc));
        return 0;
}

/* License string emitted into the "license" ELF section of the BPF object. */
char _license[] SEC("license") = "GPL";