root/net/netfilter/nft_chain_route.c
// SPDX-License-Identifier: GPL-2.0

#include <linux/skbuff.h>
#include <linux/netfilter.h>
#include <linux/netfilter_ipv4.h>
#include <linux/netfilter_ipv6.h>
#include <linux/netfilter/nfnetlink.h>
#include <linux/netfilter/nf_tables.h>
#include <net/netfilter/nf_tables.h>
#include <net/netfilter/nf_tables_ipv4.h>
#include <net/netfilter/nf_tables_ipv6.h>
#include <net/route.h>
#include <net/ip.h>

#ifdef CONFIG_NF_TABLES_IPV4
/*
 * IPv4 hook function for "route" chains.
 *
 * Samples the source address, destination address, TOS and skb mark,
 * runs the chain, and calls ip_route_me_harder() when the verdict is
 * NF_ACCEPT and at least one of those values no longer matches what was
 * sampled beforehand.
 *
 * Returns the verdict produced by nft_do_chain(), or NF_DROP_ERR(err)
 * when ip_route_me_harder() reports a negative error.
 */
static unsigned int nf_route_table_hook4(void *priv,
                                         struct sk_buff *skb,
                                         const struct nf_hook_state *state)
{
        __be32 old_saddr, old_daddr;
        const struct iphdr *iph;
        struct nft_pktinfo pkt;
        unsigned int verdict;
        u32 old_mark;
        u8 old_tos;
        int err;

        nft_set_pktinfo(&pkt, skb, state);
        nft_set_pktinfo_ipv4(&pkt);

        /* Snapshot the fields that are compared once the chain has run. */
        iph = ip_hdr(skb);
        old_saddr = iph->saddr;
        old_daddr = iph->daddr;
        old_tos = iph->tos;
        old_mark = skb->mark;

        verdict = nft_do_chain(&pkt, priv);
        if (verdict != NF_ACCEPT)
                return verdict;

        /* Look the header up again rather than reusing the pre-chain pointer. */
        iph = ip_hdr(skb);
        if (iph->saddr == old_saddr &&
            iph->daddr == old_daddr &&
            skb->mark == old_mark &&
            iph->tos == old_tos)
                return verdict;

        /* Something routing-relevant differs: request a new route. */
        err = ip_route_me_harder(state->net, state->sk, skb, RTN_UNSPEC);
        if (err < 0)
                return NF_DROP_ERR(err);

        return verdict;
}

/*
 * Chain type descriptor for "route" chains in the IPv4 family.
 * NF_INET_LOCAL_OUT is the only hook permitted by the mask and the only
 * slot with a hook function.
 */
static const struct nft_chain_type nft_chain_route_ipv4 = {
        .family         = NFPROTO_IPV4,
        .type           = NFT_CHAIN_T_ROUTE,
        .name           = "route",
        .hooks          = {
                [NF_INET_LOCAL_OUT]     = nf_route_table_hook4,
        },
        .hook_mask      = 1 << NF_INET_LOCAL_OUT,
};
#endif

#ifdef CONFIG_NF_TABLES_IPV6
static unsigned int nf_route_table_hook6(void *priv,
                                         struct sk_buff *skb,
                                         const struct nf_hook_state *state)
{
        struct in6_addr saddr, daddr;
        struct nft_pktinfo pkt;
        u32 mark, flowlabel;
        unsigned int ret;
        u8 hop_limit;
        int err;

        nft_set_pktinfo(&pkt, skb, state);
        nft_set_pktinfo_ipv6(&pkt);

        /* save source/dest address, mark, hoplimit, flowlabel, priority */
        memcpy(&saddr, &ipv6_hdr(skb)->saddr, sizeof(saddr));
        memcpy(&daddr, &ipv6_hdr(skb)->daddr, sizeof(daddr));
        mark = skb->mark;
        hop_limit = ipv6_hdr(skb)->hop_limit;

        /* flowlabel and prio (includes version, which shouldn't change either)*/
        flowlabel = *((u32 *)ipv6_hdr(skb));

        ret = nft_do_chain(&pkt, priv);
        if (ret == NF_ACCEPT &&
            (memcmp(&ipv6_hdr(skb)->saddr, &saddr, sizeof(saddr)) ||
             memcmp(&ipv6_hdr(skb)->daddr, &daddr, sizeof(daddr)) ||
             skb->mark != mark ||
             ipv6_hdr(skb)->hop_limit != hop_limit ||
             flowlabel != *((u32 *)ipv6_hdr(skb)))) {
                err = nf_ip6_route_me_harder(state->net, state->sk, skb);
                if (err < 0)
                        ret = NF_DROP_ERR(err);
        }

        return ret;
}

/*
 * Chain type descriptor for "route" chains in the IPv6 family.
 * NF_INET_LOCAL_OUT is the only hook permitted by the mask and the only
 * slot with a hook function.
 */
static const struct nft_chain_type nft_chain_route_ipv6 = {
        .family         = NFPROTO_IPV6,
        .type           = NFT_CHAIN_T_ROUTE,
        .name           = "route",
        .hooks          = {
                [NF_INET_LOCAL_OUT]     = nf_route_table_hook6,
        },
        .hook_mask      = 1 << NF_INET_LOCAL_OUT,
};
#endif

#ifdef CONFIG_NF_TABLES_INET
/*
 * Hook function for "route" chains in the inet family.
 *
 * Hands the packet to the IPv4 or IPv6 route hook according to
 * state->pf. For any other protocol family the chain is run with only
 * the generic packet info filled in, and its verdict is returned as is.
 */
static unsigned int nf_route_table_inet(void *priv,
                                        struct sk_buff *skb,
                                        const struct nf_hook_state *state)
{
        struct nft_pktinfo pkt;

        if (state->pf == NFPROTO_IPV4)
                return nf_route_table_hook4(priv, skb, state);

        if (state->pf == NFPROTO_IPV6)
                return nf_route_table_hook6(priv, skb, state);

        /* Neither IPv4 nor IPv6: no family-specific packet info, no re-route. */
        nft_set_pktinfo(&pkt, skb, state);

        return nft_do_chain(&pkt, priv);
}

/*
 * Chain type descriptor for "route" chains in the inet family.
 * NF_INET_LOCAL_OUT is the only hook permitted by the mask and the only
 * slot with a hook function.
 */
static const struct nft_chain_type nft_chain_route_inet = {
        .family         = NFPROTO_INET,
        .type           = NFT_CHAIN_T_ROUTE,
        .name           = "route",
        .hooks          = {
                [NF_INET_LOCAL_OUT]     = nf_route_table_inet,
        },
        .hook_mask      = 1 << NF_INET_LOCAL_OUT,
};
#endif

/*
 * Register the "route" chain type for every family that is enabled in
 * the kernel configuration (IPv6, then IPv4, then inet).
 */
void __init nft_chain_route_init(void)
{
#if defined(CONFIG_NF_TABLES_IPV6)
        nft_register_chain_type(&nft_chain_route_ipv6);
#endif /* CONFIG_NF_TABLES_IPV6 */

#if defined(CONFIG_NF_TABLES_IPV4)
        nft_register_chain_type(&nft_chain_route_ipv4);
#endif /* CONFIG_NF_TABLES_IPV4 */

#if defined(CONFIG_NF_TABLES_INET)
        nft_register_chain_type(&nft_chain_route_inet);
#endif /* CONFIG_NF_TABLES_INET */
}

/*
 * Unregister the "route" chain type for every family that is enabled in
 * the kernel configuration (IPv6, then IPv4, then inet).
 */
void __exit nft_chain_route_fini(void)
{
#if defined(CONFIG_NF_TABLES_IPV6)
        nft_unregister_chain_type(&nft_chain_route_ipv6);
#endif /* CONFIG_NF_TABLES_IPV6 */

#if defined(CONFIG_NF_TABLES_IPV4)
        nft_unregister_chain_type(&nft_chain_route_ipv4);
#endif /* CONFIG_NF_TABLES_IPV4 */

#if defined(CONFIG_NF_TABLES_INET)
        nft_unregister_chain_type(&nft_chain_route_inet);
#endif /* CONFIG_NF_TABLES_INET */
}