root/net/netfilter/nf_conntrack_proto_sctp.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Connection tracking protocol helper module for SCTP.
 *
 * Copyright (c) 2004 Kiran Kumar Immidi <immidi_kiran@yahoo.com>
 * Copyright (c) 2004-2012 Patrick McHardy <kaber@trash.net>
 *
 * SCTP is defined in RFC 2960. References to various sections in this code
 * are to this RFC.
 */

#include <linux/types.h>
#include <linux/timer.h>
#include <linux/netfilter.h>
#include <linux/in.h>
#include <linux/ip.h>
#include <linux/sctp.h>
#include <linux/string.h>
#include <linux/seq_file.h>
#include <linux/spinlock.h>
#include <linux/interrupt.h>
#include <net/sctp/checksum.h>

#include <net/netfilter/nf_log.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_l4proto.h>
#include <net/netfilter/nf_conntrack_ecache.h>
#include <net/netfilter/nf_conntrack_timeout.h>

/* Printable name for each conntrack state, indexed by the state value.
 * Used by sctp_print_conntrack() and by the debug message in
 * sctp_new_state(). Only the states listed here have a name.
 */
static const char *const sctp_conntrack_names[] = {
        [SCTP_CONNTRACK_NONE]                   = "NONE",
        [SCTP_CONNTRACK_CLOSED]                 = "CLOSED",
        [SCTP_CONNTRACK_COOKIE_WAIT]            = "COOKIE_WAIT",
        [SCTP_CONNTRACK_COOKIE_ECHOED]          = "COOKIE_ECHOED",
        [SCTP_CONNTRACK_ESTABLISHED]            = "ESTABLISHED",
        [SCTP_CONNTRACK_SHUTDOWN_SENT]          = "SHUTDOWN_SENT",
        [SCTP_CONNTRACK_SHUTDOWN_RECD]          = "SHUTDOWN_RECD",
        [SCTP_CONNTRACK_SHUTDOWN_ACK_SENT]      = "SHUTDOWN_ACK_SENT",
        [SCTP_CONNTRACK_HEARTBEAT_SENT]         = "HEARTBEAT_SENT",
};

/* Default timeout for each conntrack state, expressed in jiffies.
 * nf_conntrack_sctp_init_net() copies these into the per-netns table.
 * Slots without an explicit initializer are zero.
 */
static const unsigned int sctp_timeouts[SCTP_CONNTRACK_MAX] = {
        [SCTP_CONNTRACK_CLOSED]                 = secs_to_jiffies(10),
        [SCTP_CONNTRACK_COOKIE_WAIT]            = secs_to_jiffies(3),
        [SCTP_CONNTRACK_COOKIE_ECHOED]          = secs_to_jiffies(3),
        [SCTP_CONNTRACK_ESTABLISHED]            = secs_to_jiffies(210),
        [SCTP_CONNTRACK_SHUTDOWN_SENT]          = secs_to_jiffies(3),
        [SCTP_CONNTRACK_SHUTDOWN_RECD]          = secs_to_jiffies(3),
        [SCTP_CONNTRACK_SHUTDOWN_ACK_SENT]      = secs_to_jiffies(3),
        [SCTP_CONNTRACK_HEARTBEAT_SENT]         = secs_to_jiffies(30),
};

/* Bit in ct->proto.sctp.flags: set by nf_conntrack_sctp_packet() when a
 * HEARTBEAT arrives whose vtag differs from the recorded one, and cleared
 * again by its HEARTBEAT / HEARTBEAT_ACK handling.
 */
#define SCTP_FLAG_HEARTBEAT_VTAG_FAILED 1

/* Short aliases that keep the state transition table readable.
 * sIV (SCTP_CONNTRACK_MAX) marks a transition that is treated as invalid.
 */
#define sNO SCTP_CONNTRACK_NONE
#define sCL SCTP_CONNTRACK_CLOSED
#define sCW SCTP_CONNTRACK_COOKIE_WAIT
#define sCE SCTP_CONNTRACK_COOKIE_ECHOED
#define sES SCTP_CONNTRACK_ESTABLISHED
#define sSS SCTP_CONNTRACK_SHUTDOWN_SENT
#define sSR SCTP_CONNTRACK_SHUTDOWN_RECD
#define sSA SCTP_CONNTRACK_SHUTDOWN_ACK_SENT
#define sHS SCTP_CONNTRACK_HEARTBEAT_SENT
#define sIV SCTP_CONNTRACK_MAX

/*
        These are the descriptions of the states:

NOTE: These state names are tantalizingly similar to the states of an
SCTP endpoint. But the interpretation of the states is a little different,
considering that these are the states of the connection and not of an end
point. Please note the subtleties. -Kiran

NONE              - Nothing so far.
COOKIE WAIT       - We have seen an INIT chunk in the original direction, or also
                    an INIT_ACK chunk in the reply direction.
COOKIE ECHOED     - We have seen a COOKIE_ECHO chunk in the original direction.
ESTABLISHED       - We have seen a COOKIE_ACK in the reply direction.
SHUTDOWN_SENT     - We have seen a SHUTDOWN chunk in the original direction.
SHUTDOWN_RECD     - We have seen a SHUTDOWN chunk in the reply direction.
SHUTDOWN_ACK_SENT - We have seen a SHUTDOWN_ACK chunk in the direction opposite
                    to that of the SHUTDOWN chunk.
CLOSED            - We have seen a SHUTDOWN_COMPLETE chunk in the direction of
                    the SHUTDOWN chunk. Connection is closed.
HEARTBEAT_SENT    - We have seen a HEARTBEAT in a new flow.
*/

/* TODO
 - I have assumed that the first INIT is in the original direction.
 This messes things when an INIT comes in the reply direction in CLOSED
 state.
 - Check the error type in the reply dir before transitioning from
cookie echoed to closed.
 - Sec 5.2.4 of RFC 2960
 - Full Multi Homing support.
*/

/* SCTP conntrack state transitions.
 *
 * Indexed as [direction][chunk row][current state] and yields the next
 * state. The chunk row (0..10) is selected from the chunk type by
 * sctp_new_state(); the column order is given in the header line of each
 * half. An sIV entry means the chunk is not acceptable in that state.
 */
static const u8 sctp_conntracks[2][11][SCTP_CONNTRACK_MAX] = {
        {
/*      ORIGINAL        */
/*                  sNO, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS */
/* init         */ {sCL, sCL, sCW, sCE, sES, sCL, sCL, sSA, sCW},
/* init_ack     */ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sCL},
/* abort        */ {sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL},
/* shutdown     */ {sCL, sCL, sCW, sCE, sSS, sSS, sSR, sSA, sCL},
/* shutdown_ack */ {sSA, sCL, sCW, sCE, sES, sSA, sSA, sSA, sSA},
/* error        */ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sCL},/* Can't have Stale cookie*/
/* cookie_echo  */ {sCL, sCL, sCE, sCE, sES, sSS, sSR, sSA, sCL},/* 5.2.4 - Big TODO */
/* cookie_ack   */ {sCL, sCL, sCW, sES, sES, sSS, sSR, sSA, sCL},/* Can't come in orig dir */
/* shutdown_comp*/ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sCL, sCL},
/* heartbeat    */ {sHS, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS},
/* heartbeat_ack*/ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS},
        },
        {
/*      REPLY   */
/*                  sNO, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS */
/* init         */ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sIV},/* INIT in sCL Big TODO */
/* init_ack     */ {sIV, sCW, sCW, sCE, sES, sSS, sSR, sSA, sIV},
/* abort        */ {sIV, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sIV},
/* shutdown     */ {sIV, sCL, sCW, sCE, sSR, sSS, sSR, sSA, sIV},
/* shutdown_ack */ {sIV, sCL, sCW, sCE, sES, sSA, sSA, sSA, sIV},
/* error        */ {sIV, sCL, sCW, sCL, sES, sSS, sSR, sSA, sIV},
/* cookie_echo  */ {sIV, sCL, sCE, sCE, sES, sSS, sSR, sSA, sIV},/* Can't come in reply dir */
/* cookie_ack   */ {sIV, sCL, sCW, sES, sES, sSS, sSR, sSA, sIV},
/* shutdown_comp*/ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sCL, sIV},
/* heartbeat    */ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS},
/* heartbeat_ack*/ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sES},
        }
};

#ifdef CONFIG_NF_CONNTRACK_PROCFS
/* Emit the name of the current SCTP conntrack state, followed by a single
 * space, into the seq_file @s.
 */
static void sctp_print_conntrack(struct seq_file *s, struct nf_conn *ct)
{
        const char *state_name = sctp_conntrack_names[ct->proto.sctp.state];

        seq_printf(s, "%s ", state_name);
}
#endif

/*
 * Iterate over the chunk headers of an SCTP packet.
 *
 * The walk starts right behind the SCTP common header located at @dataoff
 * and advances by each chunk's length rounded up to a multiple of 4.
 * On every iteration @sch points at the current chunk header (fetched with
 * skb_header_pointer(), using @_sch as the bounce buffer), @offset is that
 * header's offset inside @skb and @count is the number of chunks already
 * visited. The loop ends when @offset reaches the end of the skb or the
 * chunk header cannot be fetched.
 *
 * do_basic_checks ensures sch->length > 0, do not use before: a zero
 * length would leave @offset unchanged and the loop would not advance.
 */
#define for_each_sctp_chunk(skb, sch, _sch, offset, dataoff, count)     \
for ((offset) = (dataoff) + sizeof(struct sctphdr), (count) = 0;        \
        (offset) < (skb)->len &&                                        \
        ((sch) = skb_header_pointer((skb), (offset), sizeof(_sch), &(_sch)));   \
        (offset) += (ntohs((sch)->length) + 3) & ~3, (count)++)

/*
 * Some validity checks to make sure the chunks are fine.
 *
 * Walks every chunk of the packet and rejects the packet when
 *  - a COOKIE_ACK or COOKIE_ECHO chunk is not the first chunk,
 *  - an INIT, INIT_ACK or SHUTDOWN_COMPLETE chunk is not the only chunk,
 *  - a chunk header carries a zero length, or
 *  - not a single chunk header could be read.
 *
 * @ct:      conntrack entry, passed on to the invalid-packet log helper
 * @skb:     packet to inspect
 * @dataoff: offset of the SCTP common header inside @skb
 * @map:     optional bitmap; the bit numbered by each accepted chunk's
 *           type is set. May be NULL.
 * @state:   hook state, passed on to the invalid-packet log helper
 *
 * Returns 0 when the packet passed all checks, non-zero otherwise.
 */
static int do_basic_checks(struct nf_conn *ct,
                           const struct sk_buff *skb,
                           unsigned int dataoff,
                           unsigned long *map,
                           const struct nf_hook_state *state)
{
        u_int32_t offset, count;
        struct sctp_chunkhdr _sch, *sch;
        int flag;

        flag = 0;

        for_each_sctp_chunk (skb, sch, _sch, offset, dataoff, count) {
                /* Remember that a "must be alone" chunk was seen; any chunk
                 * at a position other than the first then fails below.
                 */
                if (sch->type == SCTP_CID_INIT ||
                    sch->type == SCTP_CID_INIT_ACK ||
                    sch->type == SCTP_CID_SHUTDOWN_COMPLETE)
                        flag = 1;

                /*
                 * Cookie Ack/Echo chunks not the first OR
                 * Init / Init Ack / Shutdown compl chunks not the only chunks
                 * OR zero-length.
                 */
                if (((sch->type == SCTP_CID_COOKIE_ACK ||
                      sch->type == SCTP_CID_COOKIE_ECHO ||
                      flag) &&
                     count != 0) || !sch->length) {
                        /* sch->length is big-endian on the wire; convert it
                         * so the logged value is the real chunk length.
                         */
                        nf_ct_l4proto_log_invalid(skb, ct, state,
                                                  "%s failed. chunk num %d, type %d, len %d flag %d\n",
                                                  __func__, count, sch->type,
                                                  ntohs(sch->length), flag);
                        return 1;
                }

                if (map)
                        set_bit(sch->type, map);
        }

        /* A packet without any readable chunk is invalid as well. */
        return count == 0;
}

/*
 * Look up the state that follows @cur_state when a chunk of @chunk_type is
 * seen in direction @dir.
 *
 * Chunk types that have no row in the transition table leave the state
 * unchanged. The result may be SCTP_CONNTRACK_MAX (sIV in the table), which
 * callers treat as an invalid transition.
 */
static int sctp_new_state(enum ip_conntrack_dir dir,
                          enum sctp_conntrack cur_state,
                          int chunk_type)
{
        int row = -1;   /* row of sctp_conntracks[dir][], -1: no row */

        switch (chunk_type) {
        case SCTP_CID_INIT:                     row = 0;        break;
        case SCTP_CID_INIT_ACK:                 row = 1;        break;
        case SCTP_CID_ABORT:                    row = 2;        break;
        case SCTP_CID_SHUTDOWN:                 row = 3;        break;
        case SCTP_CID_SHUTDOWN_ACK:             row = 4;        break;
        case SCTP_CID_ERROR:                    row = 5;        break;
        case SCTP_CID_COOKIE_ECHO:              row = 6;        break;
        case SCTP_CID_COOKIE_ACK:               row = 7;        break;
        case SCTP_CID_SHUTDOWN_COMPLETE:        row = 8;        break;
        case SCTP_CID_HEARTBEAT:                row = 9;        break;
        case SCTP_CID_HEARTBEAT_ACK:            row = 10;       break;
        default:
                break;
        }

        if (row < 0) {
                /* Other chunks like DATA or SACK do not change the state */
                pr_debug("Unknown chunk type %d, Will stay in %s\n",
                         chunk_type, sctp_conntrack_names[cur_state]);
                return cur_state;
        }

        return sctp_conntracks[dir][row][cur_state];
}

/*
 * Set up the SCTP protocol state of a conntrack entry from the packet that
 * creates it. Don't need lock here: this conntrack not in circulation yet.
 *
 * The protocol area is zeroed, then each chunk is checked against the
 * ORIGINAL-direction transitions out of SCTP_CONNTRACK_NONE and the
 * verification tags that can be learned from the packet are recorded.
 *
 * Returns false when the packet must not create a conntrack entry.
 */
static noinline bool
sctp_new(struct nf_conn *ct, const struct sk_buff *skb,
         const struct sctphdr *sh, unsigned int dataoff)
{
        const struct sctp_chunkhdr *chunk;
        struct sctp_chunkhdr chunk_buf;
        enum sctp_conntrack next;
        u32 offset, count;

        memset(&ct->proto.sctp, 0, sizeof(ct->proto.sctp));
        next = SCTP_CONNTRACK_MAX;

        for_each_sctp_chunk(skb, chunk, chunk_buf, offset, dataoff, count) {
                next = sctp_new_state(IP_CT_DIR_ORIGINAL,
                                      SCTP_CONNTRACK_NONE, chunk->type);

                /* Invalid: delete conntrack */
                if (next == SCTP_CONNTRACK_NONE ||
                    next == SCTP_CONNTRACK_MAX) {
                        pr_debug("nf_conntrack_sctp: invalid new deleting.\n");
                        return false;
                }

                /* Copy the vtag into the state info */
                switch (chunk->type) {
                case SCTP_CID_INIT: {
                        struct sctp_inithdr init_buf, *init;

                        /* Sec 8.5.1 (A) */
                        if (sh->vtag)
                                return false;

                        init = skb_header_pointer(skb,
                                                  offset + sizeof(chunk_buf),
                                                  sizeof(init_buf), &init_buf);
                        if (!init)
                                return false;

                        pr_debug("Setting vtag %x for new conn\n",
                                 init->init_tag);

                        ct->proto.sctp.vtag[IP_CT_DIR_REPLY] = init->init_tag;
                        break;
                }
                case SCTP_CID_HEARTBEAT:
                        pr_debug("Setting vtag %x for secondary conntrack\n",
                                 sh->vtag);
                        ct->proto.sctp.vtag[IP_CT_DIR_ORIGINAL] = sh->vtag;
                        break;
                case SCTP_CID_SHUTDOWN_ACK:
                        /* If it is a shutdown ack OOTB packet, we expect a
                         * return shutdown complete, otherwise an ABORT
                         * Sec 8.4 (5) and (8)
                         */
                        pr_debug("Setting vtag %x for new conn OOTB\n",
                                 sh->vtag);
                        ct->proto.sctp.vtag[IP_CT_DIR_REPLY] = sh->vtag;
                        break;
                default:
                        break;
                }

                ct->proto.sctp.state = SCTP_CONNTRACK_NONE;
        }

        return true;
}

static bool sctp_error(struct sk_buff *skb,
                       unsigned int dataoff,
                       const struct nf_hook_state *state)
{
        const struct sctphdr *sh;
        const char *logmsg;

        if (skb->len < dataoff + sizeof(struct sctphdr)) {
                logmsg = "nf_ct_sctp: short packet ";
                goto out_invalid;
        }
        if (state->hook == NF_INET_PRE_ROUTING &&
            state->net->ct.sysctl_checksum &&
            skb->ip_summed == CHECKSUM_NONE) {
                if (skb_ensure_writable(skb, dataoff + sizeof(*sh))) {
                        logmsg = "nf_ct_sctp: failed to read header ";
                        goto out_invalid;
                }
                sh = (const struct sctphdr *)(skb->data + dataoff);
                if (sh->checksum != sctp_compute_cksum(skb, dataoff)) {
                        logmsg = "nf_ct_sctp: bad CRC ";
                        goto out_invalid;
                }
                skb->ip_summed = CHECKSUM_UNNECESSARY;
        }
        return false;
out_invalid:
        nf_l4proto_log_invalid(skb, state, IPPROTO_SCTP, "%s", logmsg);
        return true;
}

/*
 * Track one SCTP packet belonging to conntrack entry @ct.
 *
 * The packet is validated with sctp_error() and do_basic_checks(); for an
 * unconfirmed entry the protocol state is initialised through sctp_new().
 * After the verification tag checks, every chunk is run through the state
 * transition table while holding ct->lock. Unless the packet was marked
 * "ignore", the conntrack timeout is refreshed with the timeout of the
 * resulting state.
 *
 * @ct:      conntrack entry the packet belongs to
 * @skb:     the packet
 * @dataoff: offset of the SCTP common header inside @skb
 * @ctinfo:  conntrack info, used to derive the packet direction
 * @state:   hook state
 *
 * Returns verdict for packet, or -NF_ACCEPT for invalid.
 */
int nf_conntrack_sctp_packet(struct nf_conn *ct,
                             struct sk_buff *skb,
                             unsigned int dataoff,
                             enum ip_conntrack_info ctinfo,
                             const struct nf_hook_state *state)
{
        enum sctp_conntrack new_state, old_state;
        enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
        const struct sctphdr *sh;
        struct sctphdr _sctph;
        const struct sctp_chunkhdr *sch;
        struct sctp_chunkhdr _sch;
        u_int32_t offset, count;
        unsigned int *timeouts;
        /* One bit per chunk type value (0..255). sizeof() counts bytes, so
         * the divisor has to be scaled by 8 to obtain bits per long;
         * dividing by sizeof() alone reserved many times the needed space
         * and zeroed it for every packet.
         */
        unsigned long map[256 / (8 * sizeof(unsigned long))] = { 0 };
        bool ignore = false;

        if (sctp_error(skb, dataoff, state))
                return -NF_ACCEPT;

        sh = skb_header_pointer(skb, dataoff, sizeof(_sctph), &_sctph);
        if (sh == NULL)
                goto out;

        /* Also fills @map with the chunk types present in the packet. */
        if (do_basic_checks(ct, skb, dataoff, map, state) != 0)
                goto out;

        if (!nf_ct_is_confirmed(ct)) {
                /* If an OOTB packet has any of these chunks discard (Sec 8.4) */
                if (test_bit(SCTP_CID_ABORT, map) ||
                    test_bit(SCTP_CID_SHUTDOWN_COMPLETE, map) ||
                    test_bit(SCTP_CID_COOKIE_ACK, map))
                        return -NF_ACCEPT;

                if (!sctp_new(ct, skb, sh, dataoff))
                        return -NF_ACCEPT;
        }

        /* Check the verification tag (Sec 8.5).
         * Packets carrying one of the chunk types listed here are exempt
         * from the generic check; those types get their own handling in
         * the per-chunk loop below.
         */
        if (!test_bit(SCTP_CID_INIT, map) &&
            !test_bit(SCTP_CID_SHUTDOWN_COMPLETE, map) &&
            !test_bit(SCTP_CID_COOKIE_ECHO, map) &&
            !test_bit(SCTP_CID_ABORT, map) &&
            !test_bit(SCTP_CID_SHUTDOWN_ACK, map) &&
            !test_bit(SCTP_CID_HEARTBEAT, map) &&
            !test_bit(SCTP_CID_HEARTBEAT_ACK, map) &&
            sh->vtag != ct->proto.sctp.vtag[dir]) {
                nf_ct_l4proto_log_invalid(skb, ct, state,
                                          "verification tag check failed %x vs %x for dir %d",
                                          sh->vtag, ct->proto.sctp.vtag[dir], dir);
                goto out;
        }

        old_state = new_state = SCTP_CONNTRACK_NONE;
        spin_lock_bh(&ct->lock);
        for_each_sctp_chunk (skb, sch, _sch, offset, dataoff, count) {
                /* Special cases of Verification tag check (Sec 8.5.1) */
                if (sch->type == SCTP_CID_INIT) {
                        /* (A) vtag MUST be zero */
                        if (sh->vtag != 0)
                                goto out_unlock;
                } else if (sch->type == SCTP_CID_ABORT) {
                        /* (B) vtag MUST match own vtag if T flag is unset OR
                         * MUST match peer's vtag if T flag is set
                         */
                        if ((!(sch->flags & SCTP_CHUNK_FLAG_T) &&
                             sh->vtag != ct->proto.sctp.vtag[dir]) ||
                            ((sch->flags & SCTP_CHUNK_FLAG_T) &&
                             sh->vtag != ct->proto.sctp.vtag[!dir]))
                                goto out_unlock;
                } else if (sch->type == SCTP_CID_SHUTDOWN_COMPLETE) {
                        /* (C) vtag MUST match own vtag if T flag is unset OR
                         * MUST match peer's vtag if T flag is set
                         */
                        if ((!(sch->flags & SCTP_CHUNK_FLAG_T) &&
                             sh->vtag != ct->proto.sctp.vtag[dir]) ||
                            ((sch->flags & SCTP_CHUNK_FLAG_T) &&
                             sh->vtag != ct->proto.sctp.vtag[!dir]))
                                goto out_unlock;
                } else if (sch->type == SCTP_CID_COOKIE_ECHO) {
                        /* (D) vtag must be same as init_vtag as found in INIT_ACK */
                        if (sh->vtag != ct->proto.sctp.vtag[dir])
                                goto out_unlock;
                } else if (sch->type == SCTP_CID_COOKIE_ACK) {
                        /* Handshake finished: forget which sides sent INIT. */
                        ct->proto.sctp.init[dir] = 0;
                        ct->proto.sctp.init[!dir] = 0;
                } else if (sch->type == SCTP_CID_HEARTBEAT) {
                        if (ct->proto.sctp.vtag[dir] == 0) {
                                pr_debug("Setting %d vtag %x for dir %d\n", sch->type, sh->vtag, dir);
                                ct->proto.sctp.vtag[dir] = sh->vtag;
                        } else if (sh->vtag != ct->proto.sctp.vtag[dir]) {
                                if (test_bit(SCTP_CID_DATA, map) || ignore)
                                        goto out_unlock;

                                /* Remember the mismatch and its direction;
                                 * the packet is let through without a state
                                 * change or timeout refresh.
                                 */
                                ct->proto.sctp.flags |= SCTP_FLAG_HEARTBEAT_VTAG_FAILED;
                                ct->proto.sctp.last_dir = dir;
                                ignore = true;
                                continue;
                        } else if (ct->proto.sctp.flags & SCTP_FLAG_HEARTBEAT_VTAG_FAILED) {
                                ct->proto.sctp.flags &= ~SCTP_FLAG_HEARTBEAT_VTAG_FAILED;
                        }
                } else if (sch->type == SCTP_CID_HEARTBEAT_ACK) {
                        if (ct->proto.sctp.vtag[dir] == 0) {
                                pr_debug("Setting vtag %x for dir %d\n",
                                         sh->vtag, dir);
                                ct->proto.sctp.vtag[dir] = sh->vtag;
                        } else if (sh->vtag != ct->proto.sctp.vtag[dir]) {
                                if (test_bit(SCTP_CID_DATA, map) || ignore)
                                        goto out_unlock;

                                /* Only accept the new vtag when a HEARTBEAT
                                 * with a mismatching vtag was recorded from
                                 * the opposite direction before.
                                 */
                                if ((ct->proto.sctp.flags & SCTP_FLAG_HEARTBEAT_VTAG_FAILED) == 0 ||
                                    ct->proto.sctp.last_dir == dir)
                                        goto out_unlock;

                                ct->proto.sctp.flags &= ~SCTP_FLAG_HEARTBEAT_VTAG_FAILED;
                                ct->proto.sctp.vtag[dir] = sh->vtag;
                                ct->proto.sctp.vtag[!dir] = 0;
                        } else if (ct->proto.sctp.flags & SCTP_FLAG_HEARTBEAT_VTAG_FAILED) {
                                ct->proto.sctp.flags &= ~SCTP_FLAG_HEARTBEAT_VTAG_FAILED;
                        }
                }

                old_state = ct->proto.sctp.state;
                new_state = sctp_new_state(dir, old_state, sch->type);

                /* Invalid */
                if (new_state == SCTP_CONNTRACK_MAX) {
                        nf_ct_l4proto_log_invalid(skb, ct, state,
                                                  "Invalid, old_state %d, dir %d, type %d",
                                                  old_state, dir, sch->type);

                        goto out_unlock;
                }

                /* If it is an INIT or an INIT ACK note down the vtag */
                if (sch->type == SCTP_CID_INIT) {
                        struct sctp_inithdr _ih, *ih;

                        ih = skb_header_pointer(skb, offset + sizeof(_sch), sizeof(*ih), &_ih);
                        if (!ih)
                                goto out_unlock;

                        if (ct->proto.sctp.init[dir] && ct->proto.sctp.init[!dir])
                                ct->proto.sctp.init[!dir] = 0;
                        ct->proto.sctp.init[dir] = 1;

                        pr_debug("Setting vtag %x for dir %d\n", ih->init_tag, !dir);
                        ct->proto.sctp.vtag[!dir] = ih->init_tag;

                        /* don't renew timeout on init retransmit so
                         * port reuse by client or NAT middlebox cannot
                         * keep entry alive indefinitely (incl. nat info).
                         */
                        if (new_state == SCTP_CONNTRACK_CLOSED &&
                            old_state == SCTP_CONNTRACK_CLOSED &&
                            nf_ct_is_confirmed(ct))
                                ignore = true;
                } else if (sch->type == SCTP_CID_INIT_ACK) {
                        struct sctp_inithdr _ih, *ih;
                        __be32 vtag;

                        ih = skb_header_pointer(skb, offset + sizeof(_sch), sizeof(*ih), &_ih);
                        if (!ih)
                                goto out_unlock;

                        vtag = ct->proto.sctp.vtag[!dir];
                        if (!ct->proto.sctp.init[!dir] && vtag && vtag != ih->init_tag)
                                goto out_unlock;
                        /* collision */
                        if (ct->proto.sctp.init[dir] && ct->proto.sctp.init[!dir] &&
                            vtag != ih->init_tag)
                                goto out_unlock;

                        pr_debug("Setting vtag %x for dir %d\n", ih->init_tag, !dir);
                        ct->proto.sctp.vtag[!dir] = ih->init_tag;
                }

                ct->proto.sctp.state = new_state;
                if (old_state != new_state) {
                        nf_conntrack_event_cache(IPCT_PROTOINFO, ct);
                        if (new_state == SCTP_CONNTRACK_ESTABLISHED &&
                            !test_and_set_bit(IPS_ASSURED_BIT, &ct->status))
                                nf_conntrack_event_cache(IPCT_ASSURED, ct);
                }
        }
        spin_unlock_bh(&ct->lock);

        /* allow but do not refresh timeout */
        if (ignore)
                return NF_ACCEPT;

        /* Prefer a timeout policy attached to the entry, fall back to the
         * per-netns defaults.
         */
        timeouts = nf_ct_timeout_lookup(ct);
        if (!timeouts)
                timeouts = nf_sctp_pernet(nf_ct_net(ct))->timeouts;

        nf_ct_refresh_acct(ct, ctinfo, skb, timeouts[new_state]);

        return NF_ACCEPT;

out_unlock:
        spin_unlock_bh(&ct->lock);
out:
        return -NF_ACCEPT;
}

static bool sctp_can_early_drop(const struct nf_conn *ct)
{
        switch (ct->proto.sctp.state) {
        case SCTP_CONNTRACK_SHUTDOWN_SENT:
        case SCTP_CONNTRACK_SHUTDOWN_RECD:
        case SCTP_CONNTRACK_SHUTDOWN_ACK_SENT:
                return true;
        default:
                break;
        }

        return false;
}

#if IS_ENABLED(CONFIG_NF_CT_NETLINK)

#include <linux/netfilter/nfnetlink.h>
#include <linux/netfilter/nfnetlink_conntrack.h>

/*
 * Dump the SCTP protocol state of @ct into @skb as a nested
 * CTA_PROTOINFO_SCTP attribute.
 *
 * The state is always emitted; both verification tags are added unless
 * @destroy is set. ct->lock is held while the values are read.
 *
 * Returns 0 on success, -1 when the message has no room left.
 */
static int sctp_to_nlattr(struct sk_buff *skb, struct nlattr *nla,
                          struct nf_conn *ct, bool destroy)
{
        struct nlattr *nest;
        int ret = -1;

        spin_lock_bh(&ct->lock);

        nest = nla_nest_start(skb, CTA_PROTOINFO_SCTP);
        if (!nest)
                goto unlock;

        if (nla_put_u8(skb, CTA_PROTOINFO_SCTP_STATE, ct->proto.sctp.state))
                goto unlock;

        if (!destroy) {
                if (nla_put_be32(skb, CTA_PROTOINFO_SCTP_VTAG_ORIGINAL,
                                 ct->proto.sctp.vtag[IP_CT_DIR_ORIGINAL]))
                        goto unlock;
                if (nla_put_be32(skb, CTA_PROTOINFO_SCTP_VTAG_REPLY,
                                 ct->proto.sctp.vtag[IP_CT_DIR_REPLY]))
                        goto unlock;
        }

        ret = 0;

unlock:
        spin_unlock_bh(&ct->lock);

        /* The nest is only closed once everything has been added. */
        if (ret == 0)
                nla_nest_end(skb, nest);

        return ret;
}

/* Validation policy for the attributes nested inside CTA_PROTOINFO_SCTP,
 * applied by nlattr_to_sctp(). The state attribute is capped at
 * SCTP_CONNTRACK_HEARTBEAT_SENT.
 */
static const struct nla_policy sctp_nla_policy[CTA_PROTOINFO_SCTP_MAX+1] = {
        [CTA_PROTOINFO_SCTP_STATE]          = NLA_POLICY_MAX(NLA_U8,
                                                         SCTP_CONNTRACK_HEARTBEAT_SENT),
        [CTA_PROTOINFO_SCTP_VTAG_ORIGINAL]  = { .type = NLA_U32 },
        [CTA_PROTOINFO_SCTP_VTAG_REPLY]     = { .type = NLA_U32 },
};

/* Space for three aligned netlink attributes with payloads of 1, 4 and 4
 * bytes - presumably the state and the two vtags written by
 * sctp_to_nlattr(). Exported through .nlattr_size of the l4proto ops.
 */
#define SCTP_NLATTR_SIZE ( \
                NLA_ALIGN(NLA_HDRLEN + 1) + \
                NLA_ALIGN(NLA_HDRLEN + 4) + \
                NLA_ALIGN(NLA_HDRLEN + 4))

static int nlattr_to_sctp(struct nlattr *cda[], struct nf_conn *ct)
{
        struct nlattr *attr = cda[CTA_PROTOINFO_SCTP];
        struct nlattr *tb[CTA_PROTOINFO_SCTP_MAX+1];
        int err;

        /* updates may not contain the internal protocol info, skip parsing */
        if (!attr)
                return 0;

        err = nla_parse_nested_deprecated(tb, CTA_PROTOINFO_SCTP_MAX, attr,
                                          sctp_nla_policy, NULL);
        if (err < 0)
                return err;

        if (!tb[CTA_PROTOINFO_SCTP_STATE] ||
            !tb[CTA_PROTOINFO_SCTP_VTAG_ORIGINAL] ||
            !tb[CTA_PROTOINFO_SCTP_VTAG_REPLY])
                return -EINVAL;

        spin_lock_bh(&ct->lock);
        ct->proto.sctp.state = nla_get_u8(tb[CTA_PROTOINFO_SCTP_STATE]);
        ct->proto.sctp.vtag[IP_CT_DIR_ORIGINAL] =
                nla_get_be32(tb[CTA_PROTOINFO_SCTP_VTAG_ORIGINAL]);
        ct->proto.sctp.vtag[IP_CT_DIR_REPLY] =
                nla_get_be32(tb[CTA_PROTOINFO_SCTP_VTAG_REPLY]);
        spin_unlock_bh(&ct->lock);

        return 0;
}
#endif

#ifdef CONFIG_NF_CONNTRACK_TIMEOUT

#include <linux/netfilter/nfnetlink.h>
#include <linux/netfilter/nfnetlink_cttimeout.h>

/*
 * Fill a timeout table from netlink attributes.
 *
 * @tb:   parsed CTA_TIMEOUT_SCTP_* attributes, values in seconds
 * @net:  network namespace supplying the default timeouts
 * @data: table to fill; when NULL the per-netns table itself is updated
 *
 * Every slot starts out as the per-netns default and is overridden by the
 * matching attribute if present. Slot CTA_TIMEOUT_SCTP_UNSPEC finally
 * mirrors the CLOSED timeout. Always returns 0.
 */
static int sctp_timeout_nlattr_to_obj(struct nlattr *tb[],
                                      struct net *net, void *data)
{
        struct nf_sctp_net *sn = nf_sctp_pernet(net);
        unsigned int *timeouts = data ? data : sn->timeouts;
        int idx;

        /* set default SCTP timeouts. */
        for (idx = 0; idx < SCTP_CONNTRACK_MAX; idx++)
                timeouts[idx] = sn->timeouts[idx];

        /* there's a 1:1 mapping between attributes and protocol states. */
        for (idx = CTA_TIMEOUT_SCTP_UNSPEC + 1;
             idx <= CTA_TIMEOUT_SCTP_MAX; idx++) {
                if (!tb[idx])
                        continue;

                timeouts[idx] = ntohl(nla_get_be32(tb[idx])) * HZ;
        }

        timeouts[CTA_TIMEOUT_SCTP_UNSPEC] = timeouts[CTA_TIMEOUT_SCTP_CLOSED];
        return 0;
}

/*
 * Dump a timeout table into @skb, one big-endian 32-bit attribute per
 * CTA_TIMEOUT_SCTP_* slot, converting jiffies back to seconds.
 *
 * Returns 0 on success or -ENOSPC when an attribute does not fit.
 */
static int
sctp_timeout_obj_to_nlattr(struct sk_buff *skb, const void *data)
{
        const unsigned int *timeouts = data;
        int attr = CTA_TIMEOUT_SCTP_UNSPEC + 1;

        while (attr <= CTA_TIMEOUT_SCTP_MAX) {
                if (nla_put_be32(skb, attr, htonl(timeouts[attr] / HZ)))
                        return -ENOSPC;
                attr++;
        }

        return 0;
}

/* Validation policy for the per-state timeout attributes: each one is a
 * 32-bit value. Registered through .ctnl_timeout.nla_policy of the l4proto
 * ops.
 */
static const struct nla_policy
sctp_timeout_nla_policy[CTA_TIMEOUT_SCTP_MAX+1] = {
        [CTA_TIMEOUT_SCTP_CLOSED]               = { .type = NLA_U32 },
        [CTA_TIMEOUT_SCTP_COOKIE_WAIT]          = { .type = NLA_U32 },
        [CTA_TIMEOUT_SCTP_COOKIE_ECHOED]        = { .type = NLA_U32 },
        [CTA_TIMEOUT_SCTP_ESTABLISHED]          = { .type = NLA_U32 },
        [CTA_TIMEOUT_SCTP_SHUTDOWN_SENT]        = { .type = NLA_U32 },
        [CTA_TIMEOUT_SCTP_SHUTDOWN_RECD]        = { .type = NLA_U32 },
        [CTA_TIMEOUT_SCTP_SHUTDOWN_ACK_SENT]    = { .type = NLA_U32 },
        [CTA_TIMEOUT_SCTP_HEARTBEAT_SENT]       = { .type = NLA_U32 },
        [CTA_TIMEOUT_SCTP_HEARTBEAT_ACKED]      = { .type = NLA_U32 },
};
#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */

/* Seed the per-netns SCTP timeout table of @net with the built-in
 * defaults.
 */
void nf_conntrack_sctp_init_net(struct net *net)
{
        struct nf_sctp_net *sn = nf_sctp_pernet(net);
        unsigned int slot;

        /* timeouts[0] is unused, init it so ->timeouts[0] contains
         * 'new' timeout, like udp or icmp.
         */
        sn->timeouts[0] = sctp_timeouts[SCTP_CONNTRACK_CLOSED];

        /* All remaining slots take their default unchanged. */
        for (slot = 1; slot < SCTP_CONNTRACK_MAX; slot++)
                sn->timeouts[slot] = sctp_timeouts[slot];
}

/* Layer 4 protocol descriptor for SCTP: wires the helpers defined in this
 * file into the conntrack core. The procfs, netlink and timeout-policy
 * callbacks are only present when the matching config option is enabled.
 */
const struct nf_conntrack_l4proto nf_conntrack_l4proto_sctp = {
        .l4proto                = IPPROTO_SCTP,
#ifdef CONFIG_NF_CONNTRACK_PROCFS
        .print_conntrack        = sctp_print_conntrack,
#endif
        .can_early_drop         = sctp_can_early_drop,
#if IS_ENABLED(CONFIG_NF_CT_NETLINK)
        .nlattr_size            = SCTP_NLATTR_SIZE,
        .to_nlattr              = sctp_to_nlattr,
        .from_nlattr            = nlattr_to_sctp,
        .tuple_to_nlattr        = nf_ct_port_tuple_to_nlattr,
        .nlattr_tuple_size      = nf_ct_port_nlattr_tuple_size,
        .nlattr_to_tuple        = nf_ct_port_nlattr_to_tuple,
        .nla_policy             = nf_ct_port_nla_policy,
#endif
#ifdef CONFIG_NF_CONNTRACK_TIMEOUT
        .ctnl_timeout           = {
                .nlattr_to_obj  = sctp_timeout_nlattr_to_obj,
                .obj_to_nlattr  = sctp_timeout_obj_to_nlattr,
                .nlattr_max     = CTA_TIMEOUT_SCTP_MAX,
                /* one unsigned int per conntrack state */
                .obj_size       = sizeof(unsigned int) * SCTP_CONNTRACK_MAX,
                .nla_policy     = sctp_timeout_nla_policy,
        },
#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */
};