root/net/netfilter/ipvs/ip_vs_mh.c
// SPDX-License-Identifier: GPL-2.0
/* IPVS:        Maglev Hashing scheduling module
 *
 * Authors:     Inju Song <inju.song@navercorp.com>
 *
 */

/* The mh algorithm assigns each destination a preference list of all
 * the lookup table positions and populates the table with the
 * most-preferred positions of the destinations. It then selects a
 * destination by hashing the source IP address and looking up the
 * resulting position in the lookup table.
 *
 * The algorithm is detailed in:
 * [3.4 Consistent Hashing]
 * https://www.usenix.org/system/files/conference/nsdi16/nsdi16-paper-eisenbud.pdf
 *
 */

#define pr_fmt(fmt) "IPVS: " fmt

#include <linux/ip.h>
#include <linux/slab.h>
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/skbuff.h>

#include <net/ip_vs.h>

#include <linux/siphash.h>
#include <linux/bitops.h>
#include <linux/gcd.h>

/* MH-specific names for the IP_VS_SVC_F_SCHED1/2 service flag bits.
 * Both are tested against svc->flags in ip_vs_mh_schedule().
 */
#define IP_VS_SVC_F_SCHED_MH_FALLBACK   IP_VS_SVC_F_SCHED1 /* MH fallback */
#define IP_VS_SVC_F_SCHED_MH_PORT       IP_VS_SVC_F_SCHED2 /* MH use port */

/* One slot of the MH lookup table.
 *
 * dest is NULL while the slot is empty. A non-NULL dest carries a
 * reference taken with ip_vs_dest_hold() in ip_vs_mh_populate() and
 * dropped with ip_vs_dest_put() when the slot is replaced or reset.
 */
struct ip_vs_mh_lookup {
        struct ip_vs_dest __rcu *dest;  /* real server (cache) */
};

/* Per-destination scratch data used only while the lookup table is being
 * rebuilt; the array is allocated and freed inside ip_vs_mh_reassign().
 *
 * offset: first preferred slot, hash1 % IP_VS_MH_TAB_SIZE
 * skip:   permutation step, hash2 % (IP_VS_MH_TAB_SIZE - 1) + 1,
 *         i.e. always in the range 1..IP_VS_MH_TAB_SIZE - 1
 * perm:   current position in the permutation; starts at offset and is
 *         advanced by skip (mod table size) in ip_vs_mh_populate()
 * turns:  number of slots the destination claims before the next
 *         destination gets its turn; below 1 means "skip this dest"
 */
struct ip_vs_mh_dest_setup {
        unsigned int    offset; /* starting offset */
        unsigned int    skip;   /* skip */
        unsigned int    perm;   /* next_offset */
        int             turns;  /* weight / gcd() and rshift */
};

/* Available prime numbers for MH table.
 *
 * The table is read-only: it is only indexed (via IP_VS_MH_TAB_SIZE) and
 * never written, so it is const-qualified.
 */
static const int primes[] = {251, 509, 1021, 2039, 4093,
                             8191, 16381, 32749, 65521, 131071};

/* For IPVS MH entry hash table.
 *
 * CONFIG_IP_VS_MH_TAB_INDEX selects the table size; it falls back to 12
 * when the build configuration does not define it.
 *
 *   IP_VS_MH_TAB_BITS  - CONFIG_IP_VS_MH_TAB_INDEX / 2; subtracted from
 *                        fls(max weight / gcd) in ip_vs_mh_shift_weight()
 *   IP_VS_MH_TAB_INDEX - CONFIG_IP_VS_MH_TAB_INDEX - 8; index into primes[]
 *   IP_VS_MH_TAB_SIZE  - number of lookup table slots, primes[index]
 *                        (primes[4] == 4093 with the default of 12)
 *
 * primes[] has 10 entries, so CONFIG_IP_VS_MH_TAB_INDEX has to stay
 * within 8..17 for the index to be valid.
 */
#ifndef CONFIG_IP_VS_MH_TAB_INDEX
#define CONFIG_IP_VS_MH_TAB_INDEX       12
#endif
#define IP_VS_MH_TAB_BITS               (CONFIG_IP_VS_MH_TAB_INDEX / 2)
#define IP_VS_MH_TAB_INDEX              (CONFIG_IP_VS_MH_TAB_INDEX - 8)
#define IP_VS_MH_TAB_SIZE               primes[IP_VS_MH_TAB_INDEX]

/* Per-service scheduler state, stored in svc->sched_data.
 *
 * rcu_head:   used by call_rcu() in ip_vs_mh_done_svc() to defer the free
 * lookup:     array of IP_VS_MH_TAB_SIZE slots consulted when scheduling
 * dest_setup: temporary per-destination array; only non-NULL while
 *             ip_vs_mh_reassign() is running
 * hash1:      key for the offset hash and for the packet lookup hash
 * hash2:      key for the skip hash
 * gcd:        gcd of all positive last_weight values, 0 if there are none
 * rshift:     right shift applied to weight / gcd when computing turns
 */
struct ip_vs_mh_state {
        struct rcu_head                 rcu_head;
        struct ip_vs_mh_lookup          *lookup;
        struct ip_vs_mh_dest_setup      *dest_setup;
        hsiphash_key_t                  hash1, hash2;
        int                             gcd;
        int                             rshift;
};

/* Load the two hsiphash keys with fixed constants.
 *
 * Both words of a key receive the same value, and the two keys use
 * different values so that hash1 and hash2 yield different hashes.
 */
static inline void generate_hash_secret(hsiphash_key_t *hash1,
                                        hsiphash_key_t *hash2)
{
        const unsigned long seed1 = 2654435761UL;
        const unsigned long seed2 = 2654446892UL;

        hash1->key[0] = seed1;
        hash1->key[1] = seed1;

        hash2->key[0] = seed2;
        hash2->key[1] = seed2;
}

/* Helper function to determine if server is unavailable */
static inline bool is_unavailable(struct ip_vs_dest *dest)
{
        return atomic_read(&dest->weight) <= 0 ||
               dest->flags & IP_VS_DEST_F_OVERLOAD;
}

/* Compute the hsiphash of (offset + port + address) under @key.
 *
 * For AF_INET6 (when CONFIG_IP_VS_IPV6 is enabled) the four 32-bit words
 * of the address are XOR-folded into one word first; otherwise the IPv4
 * address is used as is. Port and address are converted to host byte
 * order before being summed. The caller reduces the result modulo the
 * table size.
 */
static inline unsigned int
ip_vs_mh_hashkey(int af, const union nf_inet_addr *addr,
                 __be16 port, hsiphash_key_t *key, unsigned int offset)
{
        __be32 folded;
        unsigned int seed;

#ifdef CONFIG_IP_VS_IPV6
        if (af == AF_INET6)
                folded = addr->ip6[0] ^ addr->ip6[1] ^
                         addr->ip6[2] ^ addr->ip6[3];
        else
#endif
                folded = addr->ip;

        seed = offset + ntohs(port) + ntohl(folded);
        return hsiphash(&seed, sizeof(seed), key);
}

/* Empty every slot of the lookup table: drop the reference held on the
 * cached destination, if any, and set the slot to NULL.
 */
static void ip_vs_mh_reset(struct ip_vs_mh_state *s)
{
        struct ip_vs_dest *old;
        int slot;

        for (slot = 0; slot < IP_VS_MH_TAB_SIZE; slot++) {
                old = rcu_dereference_protected(s->lookup[slot].dest, 1);
                if (!old)
                        continue;

                ip_vs_dest_put(old);
                RCU_INIT_POINTER(s->lookup[slot].dest, NULL);
        }
}

static int ip_vs_mh_permutate(struct ip_vs_mh_state *s,
                              struct ip_vs_service *svc)
{
        struct list_head *p;
        struct ip_vs_mh_dest_setup *ds;
        struct ip_vs_dest *dest;
        int lw;

        /* If gcd is smaller then 1, number of dests or
         * all last_weight of dests are zero. So, skip
         * permutation for the dests.
         */
        if (s->gcd < 1)
                return 0;

        /* Set dest_setup for the dests permutation */
        p = &svc->destinations;
        ds = &s->dest_setup[0];
        while ((p = p->next) != &svc->destinations) {
                dest = list_entry(p, struct ip_vs_dest, n_list);

                ds->offset = ip_vs_mh_hashkey(svc->af, &dest->addr,
                                              dest->port, &s->hash1, 0) %
                                              IP_VS_MH_TAB_SIZE;
                ds->skip = ip_vs_mh_hashkey(svc->af, &dest->addr,
                                            dest->port, &s->hash2, 0) %
                                            (IP_VS_MH_TAB_SIZE - 1) + 1;
                ds->perm = ds->offset;

                lw = atomic_read(&dest->last_weight);
                ds->turns = ((lw / s->gcd) >> s->rshift) ? : (lw != 0);
                ds++;
        }

        return 0;
}

/* Fill all IP_VS_MH_TAB_SIZE slots of s->lookup from the per-destination
 * permutations prepared in s->dest_setup.
 *
 * Destinations are visited round-robin in list order; on its turn a
 * destination claims up to ds->turns free slots, each found by walking
 * its permutation (perm += skip, mod table size) until an unclaimed slot
 * is hit. Slots whose destination does not change keep their existing
 * reference.
 *
 * Returns 0 on success (also when the table was merely reset because
 * s->gcd < 1) or -ENOMEM if the temporary bitmap cannot be allocated.
 */
static int ip_vs_mh_populate(struct ip_vs_mh_state *s,
                             struct ip_vs_service *svc)
{
        int n, c, dt_count;
        unsigned long *table;
        struct list_head *p;
        struct ip_vs_mh_dest_setup *ds;
        struct ip_vs_dest *dest, *new_dest;

        /* If gcd is smaller than 1, there are no dests or
         * all last_weight of dests are zero. So, skip
         * the population for the dests and reset lookup table.
         */
        if (s->gcd < 1) {
                ip_vs_mh_reset(s);
                return 0;
        }

        /* One bit per slot: set once the slot is claimed in this run */
        table = bitmap_zalloc(IP_VS_MH_TAB_SIZE, GFP_KERNEL);
        if (!table)
                return -ENOMEM;

        p = &svc->destinations;
        n = 0;          /* number of slots claimed so far */
        dt_count = 0;   /* slots claimed by the current dest in this turn */
        while (n < IP_VS_MH_TAB_SIZE) {
                /* At the list head (start or wrap-around): go to first dest */
                if (p == &svc->destinations)
                        p = p->next;

                /* ds walks dest_setup[] in step with p walking the list */
                ds = &s->dest_setup[0];
                while (p != &svc->destinations) {
                        /* Ignore added server with zero weight */
                        if (ds->turns < 1) {
                                p = p->next;
                                ds++;
                                continue;
                        }

                        /* Find the next unclaimed slot of this dest's
                         * permutation.
                         */
                        c = ds->perm;
                        while (test_bit(c, table)) {
                                /* Add skip, mod IP_VS_MH_TAB_SIZE */
                                ds->perm += ds->skip;
                                if (ds->perm >= IP_VS_MH_TAB_SIZE)
                                        ds->perm -= IP_VS_MH_TAB_SIZE;
                                c = ds->perm;
                        }

                        __set_bit(c, table);

                        /* Swap the cached dest only if it changed: drop the
                         * old reference, take one on the new dest.
                         */
                        dest = rcu_dereference_protected(s->lookup[c].dest, 1);
                        new_dest = list_entry(p, struct ip_vs_dest, n_list);
                        if (dest != new_dest) {
                                if (dest)
                                        ip_vs_dest_put(dest);
                                ip_vs_dest_hold(new_dest);
                                RCU_INIT_POINTER(s->lookup[c].dest, new_dest);
                        }

                        /* Stop as soon as every slot has been claimed */
                        if (++n == IP_VS_MH_TAB_SIZE)
                                goto out;

                        /* The dest used up its turns: move on to the next */
                        if (++dt_count >= ds->turns) {
                                dt_count = 0;
                                p = p->next;
                                ds++;
                        }
                }
        }

out:
        bitmap_free(table);
        return 0;
}

/* Look up the destination for @addr/@port in the lookup table.
 *
 * Returns the cached destination of the hashed slot, or NULL when the
 * slot is empty or its destination is unavailable.
 */
static inline struct ip_vs_dest *
ip_vs_mh_get(struct ip_vs_service *svc, struct ip_vs_mh_state *s,
             const union nf_inet_addr *addr, __be16 port)
{
        struct ip_vs_dest *dest;
        unsigned int slot;

        slot = ip_vs_mh_hashkey(svc->af, addr, port, &s->hash1, 0) %
               IP_VS_MH_TAB_SIZE;
        dest = rcu_dereference(s->lookup[slot].dest);
        if (!dest || is_unavailable(dest))
                return NULL;

        return dest;
}

/* As ip_vs_mh_get, but with fallback if selected server is unavailable.
 *
 * The slot picked by the plain hash is tried first. An empty slot yields
 * NULL. If the server in that slot is unavailable, the key is rehashed
 * with a varying offset (starting at the initial slot index and wrapping
 * modulo the table size) until an available server is found, an empty
 * slot is hit, or IP_VS_MH_TAB_SIZE attempts have been made.
 *
 * Returns the selected destination, or NULL if none could be found.
 * NOTE(review): uses rcu_dereference(), so this presumably runs inside
 * an RCU read-side critical section - confirm against the caller.
 */
static inline struct ip_vs_dest *
ip_vs_mh_get_fallback(struct ip_vs_service *svc, struct ip_vs_mh_state *s,
                      const union nf_inet_addr *addr, __be16 port)
{
        unsigned int offset, roffset;
        unsigned int hash, ihash;
        struct ip_vs_dest *dest;

        /* First try the dest it's supposed to go to */
        ihash = ip_vs_mh_hashkey(svc->af, addr, port,
                                 &s->hash1, 0) % IP_VS_MH_TAB_SIZE;
        dest = rcu_dereference(s->lookup[ihash].dest);
        if (!dest)
                return NULL;
        if (!is_unavailable(dest))
                return dest;

        /* Terminate the message with a newline like every other log
         * message in this file.
         */
        IP_VS_DBG_BUF(6, "MH: selected unavailable server %s:%u, reselecting\n",
                      IP_VS_DBG_ADDR(dest->af, &dest->addr), ntohs(dest->port));

        /* If the original dest is unavailable, loop around the table
         * starting from ihash to find a new dest
         */
        for (offset = 0; offset < IP_VS_MH_TAB_SIZE; offset++) {
                roffset = (offset + ihash) % IP_VS_MH_TAB_SIZE;
                hash = ip_vs_mh_hashkey(svc->af, addr, port, &s->hash1,
                                        roffset) % IP_VS_MH_TAB_SIZE;
                dest = rcu_dereference(s->lookup[hash].dest);
                /* An empty slot ends the search */
                if (!dest)
                        break;
                if (!is_unavailable(dest))
                        return dest;
                IP_VS_DBG_BUF(6,
                              "MH: selected unavailable server %s:%u (offset %u), reselecting\n",
                              IP_VS_DBG_ADDR(dest->af, &dest->addr),
                              ntohs(dest->port), roffset);
        }

        return NULL;
}

/* Rebuild the lookup table of @s from the current destinations of @svc.
 *
 * A temporary dest_setup array (one entry per destination) is allocated
 * for the duration of the call and released again before returning.
 *
 * Returns 0 on success, -EINVAL when there are more destinations than
 * table slots, or -ENOMEM on allocation failure.
 */
static int ip_vs_mh_reassign(struct ip_vs_mh_state *s,
                             struct ip_vs_service *svc)
{
        int err;

        if (svc->num_dests > IP_VS_MH_TAB_SIZE)
                return -EINVAL;

        if (svc->num_dests >= 1) {
                s->dest_setup = kzalloc_objs(struct ip_vs_mh_dest_setup,
                                             svc->num_dests);
                if (!s->dest_setup)
                        return -ENOMEM;
        }

        ip_vs_mh_permutate(s, svc);

        err = ip_vs_mh_populate(s, svc);
        if (err >= 0)
                IP_VS_DBG_BUF(6, "MH: reassign lookup table of %s:%u\n",
                              IP_VS_DBG_ADDR(svc->af, &svc->addr),
                              ntohs(svc->port));

        /* The setup array is scratch data, drop it on every exit path */
        if (svc->num_dests >= 1) {
                kfree(s->dest_setup);
                s->dest_setup = NULL;
        }

        return err;
}

/* Return the gcd of the last_weight of all destinations of @svc that
 * have a positive weight, or 0 when there is no such destination.
 */
static int ip_vs_mh_gcd_weight(struct ip_vs_service *svc)
{
        struct ip_vs_dest *dest;
        int result = 0;
        int w;

        list_for_each_entry(dest, &svc->destinations, n_list) {
                w = atomic_read(&dest->last_weight);
                if (w <= 0)
                        continue;

                /* The first positive weight seeds the running gcd */
                if (result <= 0) {
                        result = w;
                        continue;
                }

                result = gcd(w, result);
        }

        return result;
}

/* Compute the right shift to apply to weight / gcd so that huge
 * weights do not take an excessive number of turns in the MH table.
 *
 * The shift is the number of bits by which the largest weight / gcd
 * exceeds IP_VS_MH_TAB_BITS, and never negative. Returns 0 when @gcd is
 * below 1, i.e. when there are no dests or all last_weight are zero.
 */
static int ip_vs_mh_shift_weight(struct ip_vs_service *svc, int gcd)
{
        struct ip_vs_dest *dest;
        int max_weight = 0;
        int w, shift;

        if (gcd < 1)
                return 0;

        /* Find the largest last_weight among the destinations */
        list_for_each_entry(dest, &svc->destinations, n_list) {
                w = atomic_read(&dest->last_weight);
                if (w > max_weight)
                        max_weight = w;
        }

        /* gcd is at least 1 here, so the division is well defined.
         * shift = occupied bits of weight/gcd - MH highest bits
         */
        shift = fls(max_weight / gcd) - IP_VS_MH_TAB_BITS;
        if (shift < 0)
                shift = 0;

        return shift;
}

/* Free the lookup table and the state that embeds @head.
 *
 * Used as the call_rcu() callback in ip_vs_mh_done_svc() and called
 * directly on the error path of ip_vs_mh_init_svc(). It does not drop
 * destination references; both callers run ip_vs_mh_reset() first.
 */
static void ip_vs_mh_state_free(struct rcu_head *head)
{
        struct ip_vs_mh_state *state =
                container_of(head, struct ip_vs_mh_state, rcu_head);

        kfree(state->lookup);
        kfree(state);
}

/* Allocate and initialise the MH state of @svc and build its lookup
 * table from the current destinations.
 *
 * The state is attached to svc->sched_data only once nothing can fail
 * any more. Returns 0 on success or a negative errno.
 */
static int ip_vs_mh_init_svc(struct ip_vs_service *svc)
{
        struct ip_vs_mh_state *state;
        int err;

        /* Allocate the MH table for this service */
        state = kzalloc_obj(*state);
        if (!state)
                return -ENOMEM;

        state->lookup = kzalloc_objs(struct ip_vs_mh_lookup,
                                     IP_VS_MH_TAB_SIZE);
        if (!state->lookup) {
                kfree(state);
                return -ENOMEM;
        }

        generate_hash_secret(&state->hash1, &state->hash2);
        state->gcd = ip_vs_mh_gcd_weight(svc);
        state->rshift = ip_vs_mh_shift_weight(svc, state->gcd);

        IP_VS_DBG(6,
                  "MH lookup table (memory=%zdbytes) allocated for current service\n",
                  sizeof(struct ip_vs_mh_lookup) * IP_VS_MH_TAB_SIZE);

        /* Assign the lookup table with current dests */
        err = ip_vs_mh_reassign(state, svc);
        if (err < 0)
                goto err_free;

        /* No more failures, attach state */
        svc->sched_data = state;
        return 0;

err_free:
        /* The state was never published, so it can be freed right away */
        ip_vs_mh_reset(state);
        ip_vs_mh_state_free(&state->rcu_head);
        return err;
}

/* Tear down the MH state of @svc: release the destination references
 * held by the lookup table, then free the state through call_rcu().
 */
static void ip_vs_mh_done_svc(struct ip_vs_service *svc)
{
        struct ip_vs_mh_state *state = svc->sched_data;

        /* The lookup entries have to be cleaned up here */
        ip_vs_mh_reset(state);
        call_rcu(&state->rcu_head, ip_vs_mh_state_free);

        IP_VS_DBG(6, "MH lookup table (memory=%zdbytes) released\n",
                  sizeof(struct ip_vs_mh_lookup) * IP_VS_MH_TAB_SIZE);
}

/* Common add/del/upd destination handler: recompute gcd and shift from
 * the current weights and rebuild the lookup table.
 *
 * @dest is not used; the whole destination list of @svc is re-evaluated.
 * Returns the result of ip_vs_mh_reassign().
 */
static int ip_vs_mh_dest_changed(struct ip_vs_service *svc,
                                 struct ip_vs_dest *dest)
{
        struct ip_vs_mh_state *state = svc->sched_data;
        int g = ip_vs_mh_gcd_weight(svc);

        state->gcd = g;
        state->rshift = ip_vs_mh_shift_weight(svc, g);

        /* Assign the lookup table with the updated service */
        return ip_vs_mh_reassign(state, svc);
}

/* Extract the port to hash on from a TCP, UDP or SCTP packet.
 *
 * Returns the first port word of the transport header, or the second
 * one when ip_vs_iph_inverse() is true for @iph. Returns 0 for any other
 * protocol or when the port words cannot be read from the skb.
 */
static inline __be16
ip_vs_mh_get_port(const struct sk_buff *skb, struct ip_vs_iphdr *iph)
{
        __be16 _ports[2], *ports;

        /* Only these three protocols are handled; their ports sit at the
         * same position right after the network header.
         */
        switch (iph->protocol) {
        case IPPROTO_TCP:
        case IPPROTO_UDP:
        case IPPROTO_SCTP:
                break;
        default:
                return 0;
        }

        /* At this point we know that we have a valid packet of some kind.
         * Because ICMP packets are only guaranteed to have the first 8
         * bytes, just grab the two port words.
         */
        ports = skb_header_pointer(skb, iph->len, sizeof(_ports), &_ports);
        if (unlikely(!ports))
                return 0;

        if (unlikely(ip_vs_iph_inverse(iph)))
                return ports[1];

        return ports[0];
}

/* Maglev Hashing scheduling.
 *
 * Hashes the source address of the packet (the destination address when
 * ip_vs_iph_inverse() is true), optionally together with the port if
 * the service has IP_VS_SVC_F_SCHED_MH_PORT set, and looks the result
 * up in the MH table. With IP_VS_SVC_F_SCHED_MH_FALLBACK set, another
 * server is searched for when the selected one is unavailable.
 *
 * Returns the chosen destination, or NULL after reporting a scheduler
 * error when none is available.
 */
static struct ip_vs_dest *
ip_vs_mh_schedule(struct ip_vs_service *svc, const struct sk_buff *skb,
                  struct ip_vs_iphdr *iph)
{
        const union nf_inet_addr *hash_addr;
        struct ip_vs_mh_state *state;
        struct ip_vs_dest *chosen;
        __be16 port = 0;

        if (ip_vs_iph_inverse(iph))
                hash_addr = &iph->daddr;
        else
                hash_addr = &iph->saddr;

        IP_VS_DBG(6, "%s : Scheduling...\n", __func__);

        /* The port only takes part in the hash when asked for */
        if (svc->flags & IP_VS_SVC_F_SCHED_MH_PORT)
                port = ip_vs_mh_get_port(skb, iph);

        state = svc->sched_data;

        chosen = (svc->flags & IP_VS_SVC_F_SCHED_MH_FALLBACK) ?
                 ip_vs_mh_get_fallback(svc, state, hash_addr, port) :
                 ip_vs_mh_get(svc, state, hash_addr, port);
        if (!chosen) {
                ip_vs_scheduler_err(svc, "no destination available");
                return NULL;
        }

        IP_VS_DBG_BUF(6, "MH: source IP address %s:%u --> server %s:%u\n",
                      IP_VS_DBG_ADDR(svc->af, hash_addr),
                      ntohs(port),
                      IP_VS_DBG_ADDR(chosen->af, &chosen->addr),
                      ntohs(chosen->port));

        return chosen;
}

/* IPVS MH scheduler descriptor, registered under the name "mh".
 * Adding, deleting and updating a destination all go through the same
 * handler, which rebuilds the lookup table.
 */
static struct ip_vs_scheduler ip_vs_mh_scheduler = {
        .name = "mh",
        .refcnt = ATOMIC_INIT(0),
        .module = THIS_MODULE,
        .n_list = LIST_HEAD_INIT(ip_vs_mh_scheduler.n_list),
        .init_service = ip_vs_mh_init_svc,
        .done_service = ip_vs_mh_done_svc,
        .add_dest = ip_vs_mh_dest_changed,
        .del_dest = ip_vs_mh_dest_changed,
        .upd_dest = ip_vs_mh_dest_changed,
        .schedule = ip_vs_mh_schedule,
};

/* Module entry point: register the MH scheduler with IPVS.
 * Returns whatever register_ip_vs_scheduler() returns.
 */
static int __init ip_vs_mh_init(void)
{
        return register_ip_vs_scheduler(&ip_vs_mh_scheduler);
}

/* Module exit point: unregister the MH scheduler, then wait for
 * outstanding RCU callbacks before returning.
 */
static void __exit ip_vs_mh_cleanup(void)
{
        unregister_ip_vs_scheduler(&ip_vs_mh_scheduler);
        /* ip_vs_mh_done_svc() queues ip_vs_mh_state_free() via call_rcu();
         * make sure no such callback is still pending once we return.
         */
        rcu_barrier();
}

/* Module entry/exit hooks and module metadata */
module_init(ip_vs_mh_init);
module_exit(ip_vs_mh_cleanup);
MODULE_DESCRIPTION("Maglev hashing ipvs scheduler");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Inju Song <inju.song@navercorp.com>");