#ifndef _TCP_ECN_H
#define _TCP_ECN_H
#include <linux/tcp.h>
#include <linux/skbuff.h>
#include <linux/bitfield.h>
#include <net/inet_connection_sock.h>
#include <net/sock.h>
#include <net/tcp.h>
#include <net/inet_ecn.h>
/* Values for the tcp_ecn sysctl: how ECN is handled for incoming ("IN",
 * passive/server side) vs. outgoing ("OUT", active/client side)
 * connections.  Only the OUT side is visibly consumed in this header
 * (see tcp_ecn_send_syn()); IN-side semantics presumably live in the
 * listener path — confirm there.
 */
enum tcp_ecn_mode {
	TCP_ECN_IN_NOECN_OUT_NOECN = 0,		/* ECN disabled both ways */
	TCP_ECN_IN_ECN_OUT_ECN = 1,		/* RFC 3168 ECN in and out */
	TCP_ECN_IN_ECN_OUT_NOECN = 2,		/* accept ECN, never request it */
	TCP_ECN_IN_ACCECN_OUT_ACCECN = 3,	/* AccECN in and out */
	TCP_ECN_IN_ACCECN_OUT_ECN = 4,		/* accept AccECN, request RFC 3168 */
	TCP_ECN_IN_ACCECN_OUT_NOECN = 5,	/* accept AccECN, never request */
};
/* Policy levels for sending the AccECN TCP option (not referenced in
 * this header's helpers; presumably consumed by the option-emission
 * path — confirm there).
 */
enum tcp_accecn_option {
	TCP_ACCECN_OPTION_DISABLED = 0,	/* never send the option */
	TCP_ACCECN_OPTION_MINIMUM = 1,	/* send only when required */
	TCP_ACCECN_OPTION_FULL = 2,	/* send the full option */
	TCP_ACCECN_OPTION_PERSIST = 3,	/* keep sending regardless */
};
/* Set this socket's outgoing ECN codepoint according to the congestion
 * control's ECT(1) negotiation hook.  NOTE(review): assumes
 * tcp_ca_ect_1_negotiation() returns true when the CA wants ECT(1)
 * rather than ECT(0) — confirm against its definition in net/tcp.h.
 */
static inline void INET_ECN_xmit_ect_1_negotiation(struct sock *sk)
{
	__INET_ECN_xmit(sk, tcp_ca_ect_1_negotiation(sk));
}
/* Queue a CWR echo for the next outgoing segment; only meaningful in
 * RFC 3168 mode.
 */
static inline void tcp_ecn_queue_cwr(struct tcp_sock *tp)
{
	if (!tcp_ecn_mode_rfc3168(tp))
		return;

	tp->ecn_flags |= TCP_ECN_QUEUE_CWR;
}
/* React to an incoming CWR flag in RFC 3168 mode: stop demanding CWR
 * from the peer, and if the segment carries payload, ACK immediately so
 * the sender learns its CWR was seen without waiting for delayed ACKs.
 */
static inline void tcp_ecn_accept_cwr(struct sock *sk,
				      const struct sk_buff *skb)
{
	struct tcp_sock *tp = tcp_sk(sk);

	if (!tcp_ecn_mode_rfc3168(tp) || !tcp_hdr(skb)->cwr)
		return;

	tp->ecn_flags &= ~TCP_ECN_DEMAND_CWR;

	/* Data-bearing segment (seq != end_seq): force an immediate ACK */
	if (TCP_SKB_CB(skb)->seq != TCP_SKB_CB(skb)->end_seq)
		inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW;
}
/* Drop a previously queued CWR echo (caller decides when the echo is no
 * longer warranted).
 */
static inline void tcp_ecn_withdraw_cwr(struct tcp_sock *tp)
{
	tp->ecn_flags &= ~TCP_ECN_QUEUE_CWR;
}
/* Has an ACE-field failure been latched for the send direction? */
static inline bool tcp_accecn_ace_fail_send(const struct tcp_sock *tp)
{
	return (tp->accecn_fail_mode & TCP_ACCECN_ACE_FAIL_SEND) != 0;
}
/* Has an ACE-field failure been latched for the receive direction? */
static inline bool tcp_accecn_ace_fail_recv(const struct tcp_sock *tp)
{
	return (tp->accecn_fail_mode & TCP_ACCECN_ACE_FAIL_RECV) != 0;
}
/* Has an AccECN-option failure been latched for the send direction? */
static inline bool tcp_accecn_opt_fail_send(const struct tcp_sock *tp)
{
	return (tp->accecn_fail_mode & TCP_ACCECN_OPT_FAIL_SEND) != 0;
}
/* Has an AccECN-option failure been latched for the receive direction? */
static inline bool tcp_accecn_opt_fail_recv(const struct tcp_sock *tp)
{
	return (tp->accecn_fail_mode & TCP_ACCECN_OPT_FAIL_RECV) != 0;
}
/* Latch one or more TCP_ACCECN_*_FAIL_* bits.  Failures are sticky:
 * bits are only ever OR-ed in, never cleared here.
 */
static inline void tcp_accecn_fail_mode_set(struct tcp_sock *tp, u8 mode)
{
	tp->accecn_fail_mode |= mode;
}
/* Assemble the 3-bit ACE field from the header flags:
 * AE is bit 2, CWR bit 1, ECE bit 0.
 */
static inline u8 tcp_accecn_ace(const struct tcphdr *th)
{
	u8 ace = th->ece;

	ace |= th->cwr << 1;
	ace |= th->ae << 2;
	return ace;
}
/* Decode the ECT codepoint a peer reports having received on our SYN
 * (or SYN/ACK) from the 3-bit ACE value it reflected back.
 */
static inline int tcp_accecn_extract_syn_ect(u8 ace)
{
	switch (ace & 0x7) {
	case 0x0:
	case 0x4:
		return INET_ECN_ECT_0;
	case 0x2:
		return INET_ECN_NOT_ECT;
	case 0x6:
		return INET_ECN_CE;
	default:	/* 0x1, 0x3, 0x5, 0x7 */
		return INET_ECN_ECT_1;
	}
}
/* Is "sent codepoint snt came back as rcv" a transition the network is
 * allowed to perform?  Unchanged is always fine; otherwise both values
 * must be ECN-capable and the sent one must not already have been CE.
 */
static inline bool tcp_ect_transition_valid(u8 snt, u8 rcv)
{
	return rcv == snt ||
	       (snt != INET_ECN_NOT_ECT && rcv != INET_ECN_NOT_ECT &&
		snt != INET_ECN_CE);
}
/* Check that the ECT feedback in a handshake ACE value is a plausible
 * transformation of what we sent.  Always passes when the ecn_fallback
 * sysctl is off; an implausible transition latches ACE_FAIL_RECV.
 */
static inline bool tcp_accecn_validate_syn_feedback(struct sock *sk, u8 ace,
						    u8 sent_ect)
{
	struct tcp_sock *tp = tcp_sk(sk);

	if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_ecn_fallback))
		return true;

	if (tcp_ect_transition_valid(sent_ect,
				     tcp_accecn_extract_syn_ect(ace)))
		return true;

	tcp_accecn_fail_mode_set(tp, TCP_ACCECN_ACE_FAIL_RECV);
	return false;
}
/* Record the result of the first AccECN option inspection; a FAIL_SEEN
 * result also latches the option receive-failure mode.
 */
static inline void tcp_accecn_saw_opt_fail_recv(struct tcp_sock *tp,
						u8 saw_opt)
{
	tp->saw_accecn_opt = saw_opt;
	if (saw_opt == TCP_ACCECN_OPT_FAIL_SEEN)
		tcp_accecn_fail_mode_set(tp, TCP_ACCECN_OPT_FAIL_RECV);
}
/* Inspect the third packet of the AccECN handshake (the first ACK after
 * our SYN/ACK) and judge the peer's ACE feedback.
 *
 * @sent_ect: the ECT codepoint we transmitted on our SYN/ACK.
 */
static inline void tcp_accecn_third_ack(struct sock *sk,
					const struct sk_buff *skb, u8 sent_ect)
{
	u8 ace = tcp_accecn_ace(tcp_hdr(skb));
	struct tcp_sock *tp = tcp_sk(sk);

	switch (ace) {
	case 0x0:
		/* A zeroed ACE field is invalid AccECN feedback: latch
		 * both receive-failure modes, unless this ACK carries
		 * SACK state.  NOTE(review): the sacked exemption's
		 * rationale is inferred — confirm against the spec.
		 */
		if (!TCP_SKB_CB(skb)->sacked)
			tcp_accecn_fail_mode_set(tp, TCP_ACCECN_ACE_FAIL_RECV |
						     TCP_ACCECN_OPT_FAIL_RECV);
		break;
	case 0x7:
	case 0x5:
	case 0x1:
		/* Nothing to record for these values.  NOTE(review):
		 * presumably they are ambiguous/handshake-reserved
		 * encodings — confirm against the ACE value tables.
		 */
		break;
	default:
		/* Pure ACK (no data, no SACK) with plausible feedback:
		 * if it reports CE on our SYN/ACK and no CE has been
		 * accounted yet, credit one delivered CE.
		 */
		if (TCP_SKB_CB(skb)->seq == TCP_SKB_CB(skb)->end_seq &&
		    !TCP_SKB_CB(skb)->sacked &&
		    tcp_accecn_validate_syn_feedback(sk, ace, sent_ect)) {
			if ((tcp_accecn_extract_syn_ect(ace) == INET_ECN_CE) &&
			    !tp->delivered_ce)
				tp->delivered_ce++;
		}
		break;
	}
}
/* Raise (never lower) the number of AccECN options we still want to
 * send on upcoming segments.
 */
static inline void tcp_accecn_opt_demand_min(struct sock *sk,
					     u8 opt_demand_min)
{
	struct tcp_sock *tp = tcp_sk(sk);

	if (tp->accecn_opt_demand < opt_demand_min)
		tp->accecn_opt_demand = opt_demand_min;
}
/* Map an IP ECN codepoint to the order of its byte-counter field in the
 * AccECN option (0 = no field needed).
 */
static inline u8 tcp_ecnfield_to_accecn_optfield(u8 ecnfield)
{
	static const u8 ecn_to_optfield[4] = {
		[INET_ECN_NOT_ECT]	= 0,
		[INET_ECN_ECT_1]	= 1,
		[INET_ECN_ECT_0]	= 3,
		[INET_ECN_CE]		= 2,
	};

	return ecn_to_optfield[ecnfield & INET_ECN_MASK];
}
/* Protocol-defined initial offset of the byte counter associated with
 * an ECN codepoint (0 for Not-ECT, which has no counter).
 */
static inline u32 tcp_accecn_field_init_offset(u8 ecnfield)
{
	static const u32 ecn_init_offset[4] = {
		[INET_ECN_NOT_ECT]	= 0,
		[INET_ECN_ECT_1]	= TCP_ACCECN_E1B_INIT_OFFSET,
		[INET_ECN_ECT_0]	= TCP_ACCECN_E0B_INIT_OFFSET,
		[INET_ECN_CE]		= TCP_ACCECN_CEB_INIT_OFFSET,
	};

	return ecn_init_offset[ecnfield & INET_ECN_MASK];
}
/* Inverse of tcp_ecnfield_to_accecn_optfield(): which ECN codepoint a
 * given option field position refers to.  @order selects between the
 * two defined field orderings (TCPOPT_ACCECN0 vs TCPOPT_ACCECN1 style).
 */
static inline unsigned int tcp_accecn_optfield_to_ecnfield(unsigned int option,
							   bool order)
{
	unsigned int pos = option % 3;

	if (pos == 1)
		return INET_ECN_CE;
	if (pos == 0)
		return order ? INET_ECN_ECT_1 : INET_ECN_ECT_0;
	return order ? INET_ECN_ECT_0 : INET_ECN_ECT_1;
}
/* Fold one 24-bit byte-counter field from a received AccECN option into
 * the locally tracked 32-bit counter.
 *
 * @cnt:  local running counter, updated in place.
 * @from: points at the 3-byte big-endian field inside the option; the
 *        32-bit load starts one byte earlier and the stray high byte is
 *        masked off, avoiding a dedicated 24-bit read.
 * @init_offset: protocol-defined initial offset of this field, removed
 *        before comparing against the local counter.
 *
 * Returns the signed delta applied to *cnt.  The 24-bit difference is
 * sign-extended at bit 23 so a value that appears to have moved
 * backwards (e.g. due to reordering) yields a negative delta rather
 * than a huge positive one.
 */
static inline s32 tcp_update_ecn_bytes(u32 *cnt, const char *from,
				       u32 init_offset)
{
	u32 truncated = (get_unaligned_be32(from - 1) - init_offset) &
			0xFFFFFFU;
	u32 delta = (truncated - *cnt) & 0xFFFFFFU;

	delta = sign_extend32(delta, 23);
	*cnt += delta;
	return (s32)delta;
}
/* Update AccECN receive-side counters from one incoming segment.
 *
 * @len: payload bytes credited to the per-codepoint byte counters
 *       (0 for pure ACKs, so only the CE packet counters can move).
 */
static inline void tcp_ecn_received_counters(struct sock *sk,
					     const struct sk_buff *skb, u32 len)
{
	u8 ecnfield = TCP_SKB_CB(skb)->ip_dsfield & INET_ECN_MASK;
	u8 is_ce = INET_ECN_is_ce(ecnfield);
	struct tcp_sock *tp = tcp_sk(sk);
	bool ecn_edge;

	if (!INET_ECN_is_not_ect(ecnfield)) {
		/* A CE-marked GSO skb counts as its segment count
		 * (at least one); non-CE contributes zero packets.
		 */
		u32 pcount = is_ce * max_t(u16, 1, skb_shinfo(skb)->gso_segs);

		/* ECN-capable traffic seen.  Skipped in RFC 3168 mode;
		 * NOTE(review): presumably that mode sets TCP_ECN_SEEN
		 * elsewhere — confirm.
		 */
		if (!tcp_ecn_mode_rfc3168(tp))
			tp->ecn_flags |= TCP_ECN_SEEN;
		tp->received_ce += pcount;
		/* Pending CE echoes saturate at 0xf (4 bits) */
		tp->received_ce_pending = min(tp->received_ce_pending + pcount,
					      0xfU);
		if (len > 0) {
			u8 minlen = tcp_ecnfield_to_accecn_optfield(ecnfield);
			/* [ecnfield - 1] is safe: ecnfield != 0 here */
			u32 oldbytes = tp->received_ecn_bytes[ecnfield - 1];
			u32 bytes_mask = GENMASK_U32(31, 22);

			tp->received_ecn_bytes[ecnfield - 1] += len;
			tp->accecn_minlen = max_t(u8, tp->accecn_minlen,
						  minlen);
			/* Counter crossed a bit-22 boundary: demand an
			 * AccECN option so the peer can resync.
			 * NOTE(review): presumably chosen to stay ahead
			 * of the 24-bit wire field wrapping — confirm.
			 */
			if ((tp->received_ecn_bytes[ecnfield - 1] ^ oldbytes) &
			    bytes_mask) {
				tcp_accecn_opt_demand_min(sk, 1);
			}
		}
	}
	/* On an ECN-field change (or fresh CE) remember the new value
	 * and, in AccECN mode, push feedback out promptly: immediate ACK
	 * on an edge, plus a demand for two AccECN options.
	 */
	ecn_edge = tp->prev_ecnfield != ecnfield;
	if (ecn_edge || is_ce) {
		tp->prev_ecnfield = ecnfield;
		if (tcp_ecn_mode_accecn(tp)) {
			if (ecn_edge)
				inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW;
			tp->accecn_opt_demand = 2;
		}
	}
}
static inline void tcp_ecn_received_counters_payload(struct sock *sk,
const struct sk_buff *skb)
{
const struct tcphdr *th = (const struct tcphdr *)skb->data;
tcp_ecn_received_counters(sk, skb, skb->len - th->doff * 4);
}
/* Judge AccECN support from a cookie-validated ACK's ACE field: values
 * 0x0 and 0x1 are what a non-AccECN peer can produce, anything higher
 * indicates AccECN signalling survived.
 */
static inline bool cookie_accecn_ok(const struct tcphdr *th)
{
	return tcp_accecn_ace(th) >= 0x2;
}
/* Build the ACE reflector bits (packed into the TCPHDR_ACE position)
 * that echo the ECT codepoint received on the peer's SYN.
 */
static inline u16 tcp_accecn_reflector_flags(u8 ect)
{
	u8 ace;

	switch (ect & 0x3) {
	case INET_ECN_NOT_ECT:
		ace = 0b010;
		break;
	case INET_ECN_ECT_1:
		ace = 0b011;
		break;
	case INET_ECN_ECT_0:
		ace = 0b100;
		break;
	default:	/* INET_ECN_CE */
		ace = 0b110;
		break;
	}

	return FIELD_PREP(TCPHDR_ACE, ace);
}
/* Did the SYN ask for AccECN?  ACE 0x0 means no ECN at all and 0x3
 * (ECE+CWR) is the classic RFC 3168 request; every other value is an
 * AccECN request.
 */
static inline bool tcp_accecn_syn_requested(const struct tcphdr *th)
{
	u8 ace = tcp_accecn_ace(th);

	return !(ace == 0x0 || ace == 0x3);
}
/* Zero the three per-codepoint byte counters.  The BUILD_BUG_ONs pin
 * the slot layout this header relies on: index = codepoint - 1.
 */
static inline void __tcp_accecn_init_bytes_counters(int *counter_array)
{
	int i;

	BUILD_BUG_ON(INET_ECN_ECT_1 != 0x1);
	BUILD_BUG_ON(INET_ECN_ECT_0 != 0x2);
	BUILD_BUG_ON(INET_ECN_CE != 0x3);

	for (i = 0; i < 3; i++)
		counter_array[i] = 0;
}
/* Reset all AccECN bookkeeping on the socket (counters, option demand
 * state and the estimated ECN field).
 */
static inline void tcp_accecn_init_counters(struct tcp_sock *tp)
{
	tp->received_ce = 0;
	tp->received_ce_pending = 0;
	__tcp_accecn_init_bytes_counters(tp->received_ecn_bytes);
	__tcp_accecn_init_bytes_counters(tp->delivered_ecn_bytes);
	tp->accecn_opt_sent_w_dsack = 0;
	tp->accecn_minlen = 0;
	tp->accecn_opt_demand = 0;
	tp->est_ecnfield = 0;
}
/* Encode the ECT codepoint received on the SYN into the SYN/ACK's
 * AE/CWR/ECE bits (AccECN handshake reflection).  The code below yields:
 *   Not-ECT -> AE=0 CWR=1 ECE=0
 *   ECT(1)  -> AE=0 CWR=1 ECE=1
 *   ECT(0)  -> AE=1 CWR=0 ECE=0
 *   CE      -> AE=1 CWR=1 ECE=0
 */
static inline void tcp_accecn_echo_syn_ect(struct tcphdr *th, u8 ect)
{
	th->ae = !!(ect & INET_ECN_ECT_0);
	th->cwr = ect != INET_ECN_ECT_0;
	th->ece = ect == INET_ECN_ECT_1;
}
/* Place the low three bits of (received_ce + initial offset) into the
 * outgoing header's ACE field (ECE = bit 0, CWR = bit 1, AE = bit 2).
 * Skipped when the skb already carries ACE flags, so negotiated
 * handshake bits are never clobbered.
 */
static inline void tcp_accecn_set_ace(struct tcp_sock *tp, struct sk_buff *skb,
				      struct tcphdr *th)
{
	u32 wire_ace;

	if (likely(!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_ACE))) {
		wire_ace = tp->received_ce + TCP_ACCECN_CEP_INIT_OFFSET;
		th->ece = !!(wire_ace & 0x1);
		th->cwr = !!(wire_ace & 0x2);
		th->ae = !!(wire_ace & 0x4);
		/* All pending CE news is now reflected on the wire */
		tp->received_ce_pending = 0;
	}
}
/* Classify the first AccECN option seen on a connection.
 *
 * @opt_offset: byte offset of the option within the TCP header.
 *
 * Returns a TCP_ACCECN_OPT_*_SEEN state: FAIL_SEEN for a malformed
 * option (unexpected kind byte, or a zero 24-bit counter field —
 * NOTE(review): presumably impossible for a compliant peer because the
 * fields start from non-zero initial offsets; confirm against the
 * spec), EMPTY_SEEN for an option too short to hold a counter, and
 * COUNTER_SEEN when usable counter fields are present.
 */
static inline u8 tcp_accecn_option_init(const struct sk_buff *skb,
					u8 opt_offset)
{
	u8 *ptr = skb_transport_header(skb) + opt_offset;
	unsigned int optlen = ptr[1] - 2;

	if (WARN_ON_ONCE(ptr[0] != TCPOPT_ACCECN0 && ptr[0] != TCPOPT_ACCECN1))
		return TCP_ACCECN_OPT_FAIL_SEEN;
	ptr += 2;
	if (optlen < TCPOLEN_ACCECN_PERFIELD)
		return TCP_ACCECN_OPT_EMPTY_SEEN;
	/* First counter field must be non-zero */
	if (get_unaligned_be24(ptr) == 0)
		return TCP_ACCECN_OPT_FAIL_SEEN;
	if (optlen < TCPOLEN_ACCECN_PERFIELD * 3)
		return TCP_ACCECN_OPT_COUNTER_SEEN;
	/* Three fields present: also check the third one */
	ptr += TCPOLEN_ACCECN_PERFIELD * 2;
	if (get_unaligned_be24(ptr) == 0)
		return TCP_ACCECN_OPT_FAIL_SEEN;
	return TCP_ACCECN_OPT_COUNTER_SEEN;
}
/* The SYN/ACK carried a valid AccECN reflection: switch to AccECN mode,
 * remember the ECT codepoint the SYN/ACK arrived with, and grade any
 * AccECN option it included.
 *
 * NOTE(review): tp->rx_opt.accecn is used as the option's offset within
 * the TCP header (0 meaning absent) — confirm at the option parser.
 */
static inline void tcp_ecn_rcv_synack_accecn(struct sock *sk,
					     const struct sk_buff *skb, u8 dsf)
{
	struct tcp_sock *tp = tcp_sk(sk);

	tcp_ecn_mode_set(tp, TCP_ECN_MODE_ACCECN);
	tp->syn_ect_rcv = dsf & INET_ECN_MASK;
	if (tp->rx_opt.accecn &&
	    tp->saw_accecn_opt < TCP_ACCECN_OPT_COUNTER_SEEN) {
		u8 saw_opt = tcp_accecn_option_init(skb, tp->rx_opt.accecn);

		tcp_accecn_saw_opt_fail_recv(tp, saw_opt);
		/* Ask for two AccECN options on our next segments */
		tp->accecn_opt_demand = 2;
	}
}
/* Interpret the ACE field of a received SYN/ACK and settle the final
 * ECN mode for this connection (AccECN, RFC 3168, or disabled).
 */
static inline void tcp_ecn_rcv_synack(struct sock *sk, const struct sk_buff *skb,
				      const struct tcphdr *th, u8 ip_dsfield)
{
	struct tcp_sock *tp = tcp_sk(sk);
	u8 ace = tcp_accecn_ace(th);

	switch (ace) {
	case 0x0:
	case 0x7:
		/* 0x0: peer supports no ECN; 0x7: not a valid
		 * reflection of anything we sent -> no ECN.
		 */
		tcp_ecn_mode_set(tp, TCP_ECN_DISABLED);
		break;
	case 0x1:
		/* Classic RFC 3168 ECE echo: downgrade to RFC 3168
		 * unless the congestion control forbids that fallback.
		 */
		if (tcp_ca_no_fallback_rfc3168(sk))
			tcp_ecn_mode_set(tp, TCP_ECN_DISABLED);
		else
			tcp_ecn_mode_set(tp, TCP_ECN_MODE_RFC3168);
		break;
	case 0x5:
		/* Accepted as AccECN only if we actually requested it
		 * (mode still pending); a CE mark on the SYN/ACK is
		 * counted right away.
		 */
		if (tcp_ecn_mode_pending(tp)) {
			tcp_ecn_rcv_synack_accecn(sk, skb, ip_dsfield);
			if (INET_ECN_is_ce(ip_dsfield)) {
				tp->received_ce++;
				tp->received_ce_pending++;
			}
		}
		break;
	default:
		/* Remaining values: AccECN negotiated.  A CE mark on the
		 * SYN/ACK is counted only when the reflected ECT state
		 * is a plausible transition of what we sent.
		 */
		tcp_ecn_rcv_synack_accecn(sk, skb, ip_dsfield);
		if (INET_ECN_is_ce(ip_dsfield) &&
		    tcp_accecn_validate_syn_feedback(sk, ace,
						     tp->syn_ect_snt)) {
			tp->received_ce++;
			tp->received_ce_pending++;
		}
		break;
	}
}
/* Server side: interpret the ECN bits of a received SYN and pick the
 * mode to operate in — AccECN if requested, otherwise RFC 3168,
 * otherwise disabled.
 */
static inline void tcp_ecn_rcv_syn(struct sock *sk, const struct tcphdr *th,
				   const struct sk_buff *skb)
{
	struct tcp_sock *tp = tcp_sk(sk);

	if (tcp_ecn_mode_pending(tp)) {
		if (!tcp_accecn_syn_requested(th)) {
			/* Peer did not ask for AccECN: downgrade to the
			 * classic RFC 3168 negotiation below.
			 */
			tcp_ecn_mode_set(tp, TCP_ECN_MODE_RFC3168);
		} else {
			/* Remember the ECT codepoint the SYN arrived
			 * with; it seeds both the handshake reflection
			 * and the edge detector.
			 */
			tp->syn_ect_rcv = TCP_SKB_CB(skb)->ip_dsfield &
					  INET_ECN_MASK;
			tp->prev_ecnfield = tp->syn_ect_rcv;
			tcp_ecn_mode_set(tp, TCP_ECN_MODE_ACCECN);
		}
	}
	/* RFC 3168 requires both ECE and CWR on the SYN; also honour a
	 * congestion control that refuses RFC 3168 operation.
	 */
	if (tcp_ecn_mode_rfc3168(tp) &&
	    (!th->ece || !th->cwr || tcp_ca_no_fallback_rfc3168(sk)))
		tcp_ecn_mode_set(tp, TCP_ECN_DISABLED);
}
/* Does this segment carry a congestion echo we must react to?  Only an
 * ECE on a non-SYN segment in RFC 3168 mode qualifies.
 */
static inline bool tcp_ecn_rcv_ecn_echo(const struct tcp_sock *tp,
					const struct tcphdr *th)
{
	return tcp_ecn_mode_rfc3168(tp) && th->ece && !th->syn;
}
/* Shape the ECN bits of an outgoing SYN/ACK for the negotiated mode:
 * CWR is always cleared; ECE is cleared too when ECN is off; in AccECN
 * mode the whole ACE field is replaced by the reflector encoding of the
 * ECT value received on the SYN.
 */
static inline void tcp_ecn_send_synack(struct sock *sk, struct sk_buff *skb)
{
	struct tcp_sock *tp = tcp_sk(sk);

	TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_CWR;
	if (tcp_ecn_disabled(tp))
		TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_ECE;
	else if (tcp_ca_needs_ecn(sk) ||
		 tcp_bpf_ca_needs_ecn(sk))
		INET_ECN_xmit_ect_1_negotiation(sk);
	if (tp->ecn_flags & TCP_ECN_MODE_ACCECN) {
		TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_ACE;
		TCP_SKB_CB(skb)->tcp_flags |=
			tcp_accecn_reflector_flags(tp->syn_ect_rcv);
		/* Record the ECT codepoint this SYN/ACK leaves with, so
		 * the third-ACK feedback can be validated against it.
		 */
		tp->syn_ect_snt = inet_sk(sk)->tos & INET_ECN_MASK;
	}
}
/* Decide whether — and which flavour of — ECN to request on an
 * outgoing SYN, based on the tcp_ecn sysctl, congestion-control and BPF
 * preferences, and the route's RTAX_FEATURE_ECN.
 */
static inline void tcp_ecn_send_syn(struct sock *sk, struct sk_buff *skb)
{
	struct tcp_sock *tp = tcp_sk(sk);
	bool bpf_needs_ecn = tcp_bpf_ca_needs_ecn(sk);
	bool use_ecn, use_accecn;
	u8 tcp_ecn = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_ecn);

	use_accecn = tcp_ecn == TCP_ECN_IN_ACCECN_OUT_ACCECN ||
		     tcp_ca_needs_accecn(sk);
	/* AccECN implies ECN */
	use_ecn = tcp_ecn == TCP_ECN_IN_ECN_OUT_ECN ||
		  tcp_ecn == TCP_ECN_IN_ACCECN_OUT_ECN ||
		  tcp_ca_needs_ecn(sk) || bpf_needs_ecn || use_accecn;
	if (!use_ecn) {
		const struct dst_entry *dst = __sk_dst_get(sk);

		/* The route may mandate ECN even when the sysctl is off */
		if (dst && dst_feature(dst, RTAX_FEATURE_ECN))
			use_ecn = true;
	}
	tp->ecn_flags = 0;
	if (use_ecn) {
		if (tcp_ca_needs_ecn(sk) || bpf_needs_ecn)
			INET_ECN_xmit_ect_1_negotiation(sk);
		/* ECE+CWR is the RFC 3168 request; adding AE on top
		 * turns it into the AccECN request (ACE = 0b111).
		 */
		TCP_SKB_CB(skb)->tcp_flags |= TCPHDR_ECE | TCPHDR_CWR;
		if (use_accecn) {
			TCP_SKB_CB(skb)->tcp_flags |= TCPHDR_AE;
			/* Final mode is settled by the SYN/ACK */
			tcp_ecn_mode_set(tp, TCP_ECN_MODE_PENDING);
			tp->syn_ect_snt = inet_sk(sk)->tos & INET_ECN_MASK;
		} else {
			tcp_ecn_mode_set(tp, TCP_ECN_MODE_RFC3168);
		}
	}
}
/* Strip all ECN/AccECN negotiation bits (AE|CWR|ECE) from a SYN when
 * the ecn_fallback sysctl is enabled.  NOTE(review): presumably invoked
 * on SYN retransmission to get past ECN-hostile middleboxes — confirm
 * at the caller.
 */
static inline void tcp_ecn_clear_syn(struct sock *sk, struct sk_buff *skb)
{
	if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_ecn_fallback))
		return;

	TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_ACE;
}
/* Fill the ECN bits of a SYN/ACK built from a request sock.  On a first
 * transmission (or any non-retransmit synack type) reflect the AccECN
 * state or set plain ECE; on a timed-out AccECN retransmit clear all
 * three bits.  NOTE(review): clearing on retransmit presumably works
 * around middleboxes that drop AccECN SYN/ACKs — confirm.
 */
static inline void
tcp_ecn_make_synack(const struct request_sock *req, struct tcphdr *th,
		    enum tcp_synack_type synack_type)
{
	if (!req->num_timeout || synack_type != TCP_SYNACK_RETRANS) {
		if (tcp_rsk(req)->accecn_ok)
			tcp_accecn_echo_syn_ect(th, tcp_rsk(req)->syn_ect_rcv);
		else if (inet_rsk(req)->ecn_ok)
			th->ece = 1;
	} else if (tcp_rsk(req)->accecn_ok) {
		th->ae = 0;
		th->cwr = 0;
		th->ece = 0;
	}
}
/* Rate gate for "beacon" AccECN options: returns true when enough time
 * has elapsed since the last AccECN option (tp->accecn_opt_tstamp) to
 * sustain at least sysctl_tcp_ecn_option_beacon options per smoothed
 * RTT.  tp->srtt_us stores 8 * srtt, hence the >> 3.  Disabled (always
 * false) when the sysctl is 0.
 */
static inline bool tcp_accecn_option_beacon_check(const struct sock *sk)
{
	u32 ecn_beacon = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_ecn_option_beacon);
	const struct tcp_sock *tp = tcp_sk(sk);

	if (!ecn_beacon)
		return false;

	return tcp_stamp_us_delta(tp->tcp_mstamp, tp->accecn_opt_tstamp) * ecn_beacon >=
	       (tp->srtt_us >> 3);
}
#endif