root/drivers/hid/bpf/progs/hid_bpf_async.h
/* SPDX-License-Identifier: GPL-2.0-only
 * Copyright (c) 2024 Benjamin Tissoires
 */

#ifndef __HID_BPF_ASYNC_H__
#define __HID_BPF_ASYNC_H__

#ifndef HID_BPF_ASYNC_MAX_CTX
#error "HID_BPF_ASYNC_MAX_CTX should be set to the maximum number of concurrent async functions"
#endif /* HID_BPF_ASYNC_MAX_CTX */

/*
 * Clock id handed to bpf_timer_init() when a context slot is set up.
 * NOTE(review): presumably defined locally because the usual time header
 * is not included by the BPF programs using this file - confirm.
 */
#define CLOCK_MONOTONIC         1

/*
 * Signature of the callbacks scheduled through this header: the wrapper
 * generated by HID_BPF_ASYNC_CB has this shape and is what
 * hid_bpf_async_delayed_call() passes to bpf_wq_set_callback().
 * The callbacks in this file treat @value as a struct hid_bpf_async_map_elem.
 */
typedef int (*hid_bpf_async_callback_t)(void *map, int *key, void *value);

/*
 * State of one slot of hid_bpf_async_ctx_map, as driven by the helpers
 * and macros of this file.
 */
enum hid_bpf_async_state {
        /* value 0; the only state hid_bpf_async_find_empty_key() will claim */
        HID_BPF_ASYNC_STATE_UNSET = 0,
        /* slot claimed by hid_bpf_async_find_empty_key(), timer/wq not set up yet */
        HID_BPF_ASYNC_STATE_INITIALIZING,
        /*
         * set by hid_bpf_async_get_ctx() once timer and wq are initialized,
         * and by the HID_BPF_ASYNC_CB wrapper after the user callback returned
         */
        HID_BPF_ASYNC_STATE_INITIALIZED,
        /* set by hid_bpf_async_delayed_call() before arming the timer/wq */
        HID_BPF_ASYNC_STATE_STARTING,
        /* set by the HID_BPF_ASYNC_CB wrapper while the user callback executes */
        HID_BPF_ASYNC_STATE_RUNNING,
};

/* Value type of hid_bpf_async_ctx_map: everything needed for one async context. */
struct hid_bpf_async_map_elem {
        /* held around the state check-and-set in find_empty_key() and delayed_call() */
        struct bpf_spin_lock lock;
        /* current position of the slot in its lifecycle */
        enum hid_bpf_async_state state;
        /* timer used for delayed calls; its callback starts @wq */
        struct bpf_timer t;
        /* workqueue item running the callback given to hid_bpf_async_delayed_call() */
        struct bpf_wq wq;
        /* hctx->hid->id recorded by delayed_call(), used to allocate a context in the callback */
        u32 hid;
};

/*
 * Array map holding HID_BPF_ASYNC_MAX_CTX async context slots.
 * The key returned by hid_bpf_async_get_ctx() is an index into this map.
 */
struct {
        __uint(type, BPF_MAP_TYPE_ARRAY);
        __uint(max_entries, HID_BPF_ASYNC_MAX_CTX);
        __type(key, u32);
        __type(value, struct hid_bpf_async_map_elem);
} hid_bpf_async_ctx_map SEC(".maps");

/**
 * HID_BPF_ASYNC_CB: macro to define an async callback used in a bpf_wq
 *
 * The macro emits a wrapper named @cb with the (map, key, value) signature
 * of hid_bpf_async_callback_t and ends on the still-open definition of an
 * always-inline inner function taking a struct hid_bpf_ctx *. The parameter
 * list and body written right after the macro invocation therefore become
 * that inner function.
 *
 * The wrapper allocates a HID-BPF context from the hid id stored in the map
 * element, marks the element RUNNING, calls the inner function, marks the
 * element INITIALIZED again and releases the context. The return value of
 * the inner function is discarded; the wrapper always returns 0.
 *
 * If no context can be allocated the wrapper returns 0 without touching the
 * element state.
 * NOTE(review): in that case the STARTING state set by
 * hid_bpf_async_delayed_call() is left in place - confirm this is intended.
 *
 * The caller is responsible for allocating a key in the async map
 * with hid_bpf_async_get_ctx().
 */
#define HID_BPF_ASYNC_CB(cb)                                    \
cb(void *map, int *key, void *value);                           \
static __always_inline int                                      \
____##cb(struct hid_bpf_ctx *ctx);                              \
typeof(cb(0, 0, 0)) cb(void *map, int *key, void *value)        \
{                                                               \
        struct hid_bpf_async_map_elem *e;                               \
        struct hid_bpf_ctx *ctx;                                \
                                                                \
        e = (struct hid_bpf_async_map_elem *)value;                     \
        ctx = hid_bpf_allocate_context(e->hid);                 \
        if (!ctx)                                               \
                return 0; /* EPERM check */                     \
                                                                \
        e->state = HID_BPF_ASYNC_STATE_RUNNING;                 \
                                                                \
        ____##cb(ctx);                                          \
                                                                \
        e->state = HID_BPF_ASYNC_STATE_INITIALIZED;             \
        hid_bpf_release_context(ctx);                           \
        return 0;                                               \
}                                                               \
static __always_inline int                                      \
____##cb

/**
 * HID_BPF_ASYNC_FUN: macro to automatically handle async callbacks contexts
 *
 * Placed in front of the definition of a function @fun taking a
 * struct hid_bpf_ctx *, it generates:
 *  - a global int ____key__<fun> holding the async map key of @fun,
 *  - ____async_init_<fun>(), which stores the result of
 *    hid_bpf_async_get_ctx() in that global and returns 0 on success or the
 *    negative value it got,
 *  - a HID_BPF_ASYNC_CB wrapper ____<fun>_cb whose inner function calls @fun,
 * and ends on the return type and name of @fun, so the parameter list and
 * body following the macro invocation define @fun itself.
 *
 * Needs to be used in conjunction with HID_BPF_ASYNC_INIT and HID_BPF_ASYNC_DELAYED_CALL
 */
#define HID_BPF_ASYNC_FUN(fun)                                          \
fun(struct hid_bpf_ctx *ctx);                                   \
int ____key__##fun;                                             \
static int ____async_init_##fun(void)                           \
{                                                               \
        ____key__##fun = hid_bpf_async_get_ctx();                       \
        if (____key__##fun < 0)                                 \
                return ____key__##fun;                          \
        return 0;                                               \
}                                                               \
static int HID_BPF_ASYNC_CB(____##fun##_cb)(struct hid_bpf_ctx *hctx)   \
{                                                               \
        return fun(hctx);                                       \
}                                                               \
typeof(fun(0)) fun

/*
 * HID_BPF_ASYNC_INIT: reserve and initialize the async context of a function
 * declared with HID_BPF_ASYNC_FUN. Evaluates to 0 on success or a negative
 * error.
 */
#define HID_BPF_ASYNC_INIT(fun) ____async_init_##fun()
/*
 * HID_BPF_ASYNC_DELAYED_CALL: schedule a function declared with
 * HID_BPF_ASYNC_FUN through hid_bpf_async_delayed_call(), using the key and
 * the callback wrapper generated for it. @delay is in milliseconds.
 */
#define HID_BPF_ASYNC_DELAYED_CALL(fun, ctx, delay)             \
        hid_bpf_async_delayed_call(ctx, delay, ____key__##fun, ____##fun##_cb)

/*
 * Internal timer callback: when the timer of a context slot fires, kick the
 * workqueue item stored in the same slot so the delayed work gets executed.
 * The result of bpf_wq_start() is not reported; the callback always
 * returns 0.
 */
static int __start_wq_timer_cb(void *map, int *key, void *value)
{
        struct hid_bpf_async_map_elem *slot = value;

        bpf_wq_start(&slot->wq, 0);
        return 0;
}

/*
 * Walk the async context map and claim the first slot that is still UNSET.
 *
 * The test and the transition to INITIALIZING are done under the slot's
 * spin lock so that a slot can only be claimed once.
 *
 * Returns the index of the claimed slot, -ENOMEM if a slot lookup fails
 * (should never happen) or -EINVAL when no slot is UNSET.
 */
static int hid_bpf_async_find_empty_key(void)
{
        int idx;

        bpf_for(idx, 0, HID_BPF_ASYNC_MAX_CTX) {
                struct hid_bpf_async_map_elem *slot;
                int lookup_key = idx;
                int claimed;

                slot = bpf_map_lookup_elem(&hid_bpf_async_ctx_map, &lookup_key);
                if (!slot)
                        return -ENOMEM; /* should never happen */

                bpf_spin_lock(&slot->lock);
                claimed = (slot->state == HID_BPF_ASYNC_STATE_UNSET);
                if (claimed)
                        slot->state = HID_BPF_ASYNC_STATE_INITIALIZING;
                bpf_spin_unlock(&slot->lock);

                if (claimed)
                        return idx;
        }

        return -EINVAL;
}

/*
 * Reserve a slot in the async context map and prepare it for use: the
 * slot's timer is initialized and bound to __start_wq_timer_cb(), its
 * workqueue item is initialized, and the slot is marked INITIALIZED.
 *
 * Returns the key of the slot (>= 0) on success. On failure a negative
 * value is returned: the error from hid_bpf_async_find_empty_key(),
 * -EINVAL if the slot lookup fails, or the error of the first failing
 * timer/workqueue setup step. When a setup step fails the slot is left in
 * the INITIALIZING state.
 */
static int hid_bpf_async_get_ctx(void)
{
        struct hid_bpf_async_map_elem *slot;
        int key, ret;

        key = hid_bpf_async_find_empty_key();
        if (key < 0)
                return key;

        slot = bpf_map_lookup_elem(&hid_bpf_async_ctx_map, &key);
        if (!slot)
                return -EINVAL;

        /* each step only runs if the previous one succeeded */
        ret = bpf_timer_init(&slot->t, &hid_bpf_async_ctx_map, CLOCK_MONOTONIC);
        if (!ret)
                ret = bpf_timer_set_callback(&slot->t, __start_wq_timer_cb);
        if (!ret)
                ret = bpf_wq_init(&slot->wq, &hid_bpf_async_ctx_map, 0);
        if (ret)
                return ret;

        slot->state = HID_BPF_ASYNC_STATE_INITIALIZED;

        return key;
}

/* Convert a duration expressed in milliseconds to nanoseconds. */
static inline u64 ms_to_ns(u64 milliseconds)
{
        const u64 ns_per_ms = 1000UL * 1000UL;

        return milliseconds * ns_per_ms;
}

/*
 * Schedule @wq_cb to run in the workqueue item of the async context @key,
 * either right away (@milliseconds == 0) or once the context's timer fires
 * after @milliseconds.
 *
 * @hctx: HID-BPF context of the caller; its hid id is stored in the slot so
 *        the callback can allocate its own context later.
 * @milliseconds: delay before the work is started, 0 for no delay.
 * @key: key of the slot, as returned by hid_bpf_async_get_ctx().
 * @wq_cb: callback to run, typically generated by HID_BPF_ASYNC_CB.
 *
 * Returns 0 on success, -EINVAL if @key does not resolve to a slot or the
 * slot is neither INITIALIZED nor RUNNING, or the error reported by the
 * failing timer/workqueue helper. When scheduling fails the slot is put back
 * in the state it had on entry, so that a later call can succeed instead of
 * being rejected forever because the slot is stuck in STARTING.
 */
static int hid_bpf_async_delayed_call(struct hid_bpf_ctx *hctx, u64 milliseconds, int key,
                              hid_bpf_async_callback_t wq_cb)
{
        struct hid_bpf_async_map_elem *elem;
        enum hid_bpf_async_state prev_state;
        int err;

        elem = bpf_map_lookup_elem(&hid_bpf_async_ctx_map, &key);
        if (!elem)
                return -EINVAL;

        bpf_spin_lock(&elem->lock);
        prev_state = elem->state;
        /* The wq must be:
         * - HID_BPF_ASYNC_STATE_INITIALIZED -> it's been initialized and ready to be called
         * - HID_BPF_ASYNC_STATE_RUNNING -> possible re-entry from the wq itself
         */
        if (prev_state != HID_BPF_ASYNC_STATE_INITIALIZED &&
            prev_state != HID_BPF_ASYNC_STATE_RUNNING) {
                bpf_spin_unlock(&elem->lock);
                return -EINVAL;
        }
        elem->state = HID_BPF_ASYNC_STATE_STARTING;
        bpf_spin_unlock(&elem->lock);

        elem->hid = hctx->hid->id;

        err = bpf_wq_set_callback(&elem->wq, wq_cb, 0);
        if (err)
                goto err_restore;

        if (milliseconds) {
                /* needed for every call because a cancel might unset this */
                err = bpf_timer_set_callback(&elem->t, __start_wq_timer_cb);
                if (err)
                        goto err_restore;

                err = bpf_timer_start(&elem->t, ms_to_ns(milliseconds), 0);
        } else {
                err = bpf_wq_start(&elem->wq, 0);
        }

        if (!err)
                return 0;

err_restore:
        /*
         * Nothing was scheduled: undo the STARTING transition, unless the
         * state was changed by someone else in the meantime.
         */
        bpf_spin_lock(&elem->lock);
        if (elem->state == HID_BPF_ASYNC_STATE_STARTING)
                elem->state = prev_state;
        bpf_spin_unlock(&elem->lock);

        return err;
}

/*
 * Schedule @wq_cb on the async context @key without any delay.
 * Shorthand for hid_bpf_async_delayed_call() with a delay of 0 ms; the
 * return value is passed through unchanged.
 */
static inline int hid_bpf_async_call(struct hid_bpf_ctx *ctx, int key,
                                     hid_bpf_async_callback_t wq_cb)
{
        const u64 no_delay_ms = 0;

        return hid_bpf_async_delayed_call(ctx, no_delay_ms, key, wq_cb);
}

#endif /* __HID_BPF_ASYNC_H__ */