// SPDX-License-Identifier: GPL-2.0-only
#include <linux/alloc_tag.h>
#include <linux/execmem.h>
#include <linux/fs.h>
#include <linux/gfp.h>
#include <linux/kallsyms.h>
#include <linux/module.h>
#include <linux/page_ext.h>
#include <linux/proc_fs.h>
#include <linux/seq_buf.h>
#include <linux/seq_file.h>
#include <linux/string_choices.h>
#include <linux/vmalloc.h>
#include <linux/kmemleak.h>

#define ALLOCINFO_FILE_NAME             "allocinfo"
#define MODULE_ALLOC_TAG_VMAP_SIZE      (100000UL * sizeof(struct alloc_tag))
#define SECTION_START(NAME)             (CODETAG_SECTION_START_PREFIX NAME)
#define SECTION_STOP(NAME)              (CODETAG_SECTION_STOP_PREFIX NAME)

#ifdef CONFIG_MEM_ALLOC_PROFILING_ENABLED_BY_DEFAULT
static bool mem_profiling_support = true;
#else
static bool mem_profiling_support;
#endif

static struct codetag_type *alloc_tag_cttype;

#ifdef CONFIG_ARCH_MODULE_NEEDS_WEAK_PER_CPU
DEFINE_PER_CPU(struct alloc_tag_counters, _shared_alloc_tag);
EXPORT_SYMBOL(_shared_alloc_tag);
#endif

DEFINE_STATIC_KEY_MAYBE(CONFIG_MEM_ALLOC_PROFILING_ENABLED_BY_DEFAULT,
                        mem_alloc_profiling_key);
EXPORT_SYMBOL(mem_alloc_profiling_key);

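/*
 * When enabled, page-to-tag references are stored compressed in unused
 * page flag bits instead of in page_ext (see alloc_tag_sec_init()).
 */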
DEFINE_STATIC_KEY_FALSE(mem_profiling_compressed);

struct alloc_tag_kernel_section kernel_tags = { NULL, 0 };
unsigned long alloc_tag_ref_mask;
int alloc_tag_ref_offs;

struct allocinfo_private {
        struct codetag_iterator iter;
        bool print_header;
};

static void *allocinfo_start(struct seq_file *m, loff_t *pos)
{
        struct allocinfo_private *priv;
        loff_t node = *pos;

        priv = (struct allocinfo_private *)m->private;
        codetag_lock_module_list(alloc_tag_cttype, true);
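        /*
         * seq_file may call start() repeatedly while dumping one file;
         * restart the iteration only for position 0 and otherwise resume
         * from the iterator state saved in priv->iter.
         */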
        if (node == 0) {
                priv->print_header = true;
                priv->iter = codetag_get_ct_iter(alloc_tag_cttype);
                codetag_next_ct(&priv->iter);
        }
        return priv->iter.ct ? priv : NULL;
}

static void *allocinfo_next(struct seq_file *m, void *arg, loff_t *pos)
{
        struct allocinfo_private *priv = (struct allocinfo_private *)arg;
        struct codetag *ct = codetag_next_ct(&priv->iter);

        (*pos)++;
        if (!ct)
                return NULL;

        return priv;
}

static void allocinfo_stop(struct seq_file *m, void *arg)
{
        codetag_lock_module_list(alloc_tag_cttype, false);
}

static void print_allocinfo_header(struct seq_buf *buf)
{
        /* Output format version, so we can change it. */
        seq_buf_printf(buf, "allocinfo - version: 2.0\n");
        seq_buf_printf(buf, "#     <size>  <calls> <tag info>\n");
}

static void alloc_tag_to_text(struct seq_buf *out, struct codetag *ct)
{
        struct alloc_tag *tag = ct_to_alloc_tag(ct);
        struct alloc_tag_counters counter = alloc_tag_read(tag);
        s64 bytes = counter.bytes;

        seq_buf_printf(out, "%12lli %8llu ", bytes, counter.calls);
        codetag_to_text(out, ct);
        if (unlikely(alloc_tag_is_inaccurate(tag)))
                seq_buf_printf(out, " accurate:no");
        seq_buf_putc(out, ' ');
        seq_buf_putc(out, '\n');
}

static int allocinfo_show(struct seq_file *m, void *arg)
{
        struct allocinfo_private *priv = (struct allocinfo_private *)arg;
        char *bufp;
        size_t n = seq_get_buf(m, &bufp);
        struct seq_buf buf;

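        /* Format directly into the seq_file output buffer to avoid a copy */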
        seq_buf_init(&buf, bufp, n);
        if (priv->print_header) {
                print_allocinfo_header(&buf);
                priv->print_header = false;
        }
        alloc_tag_to_text(&buf, priv->iter.ct);
        seq_commit(m, seq_buf_used(&buf));
        return 0;
}

static const struct seq_operations allocinfo_seq_op = {
        .start  = allocinfo_start,
        .next   = allocinfo_next,
        .stop   = allocinfo_stop,
        .show   = allocinfo_show,
};

size_t alloc_tag_top_users(struct codetag_bytes *tags, size_t count, bool can_sleep)
{
        struct codetag_iterator iter;
        struct codetag *ct;
        struct codetag_bytes n;
        unsigned int i, nr = 0;

        if (IS_ERR_OR_NULL(alloc_tag_cttype))
                return 0;

        if (can_sleep)
                codetag_lock_module_list(alloc_tag_cttype, true);
        else if (!codetag_trylock_module_list(alloc_tag_cttype))
                return 0;

        iter = codetag_get_ct_iter(alloc_tag_cttype);
        while ((ct = codetag_next_ct(&iter))) {
                struct alloc_tag_counters counter = alloc_tag_read(ct_to_alloc_tag(ct));

                n.ct    = ct;
                n.bytes = counter.bytes;

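                /*
                 * Keep tags[] sorted by descending byte count: find the
                 * insertion slot, drop the smallest entry when the array
                 * is full (nr -= nr == count), then shift the tail down
                 * and insert.
                 */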
                for (i = 0; i < nr; i++)
                        if (n.bytes > tags[i].bytes)
                                break;

                if (i < count) {
                        nr -= nr == count;
                        memmove(&tags[i + 1],
                                &tags[i],
                                sizeof(tags[0]) * (nr - i));
                        nr++;
                        tags[i] = n;
                }
        }

        codetag_lock_module_list(alloc_tag_cttype, false);

        return nr;
}

void pgalloc_tag_split(struct folio *folio, int old_order, int new_order)
{
        int i;
        struct alloc_tag *tag;
        unsigned int nr_pages = 1 << new_order;

        if (!mem_alloc_profiling_enabled())
                return;

        tag = __pgalloc_tag_get(&folio->page);
        if (!tag)
                return;

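        /*
         * The first of the split folios keeps the existing tag reference;
         * point the head page of every other new_order-sized chunk at the
         * same tag.
         */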
        for (i = nr_pages; i < (1 << old_order); i += nr_pages) {
                union pgtag_ref_handle handle;
                union codetag_ref ref;

                if (get_page_tag_ref(folio_page(folio, i), &ref, &handle)) {
                        /* Set new reference to point to the original tag */
                        alloc_tag_ref_set(&ref, tag);
                        update_page_tag_ref(handle, &ref);
                        put_page_tag_ref(handle);
                }
        }
}

void pgalloc_tag_swap(struct folio *new, struct folio *old)
{
        union pgtag_ref_handle handle_old, handle_new;
        union codetag_ref ref_old, ref_new;
        struct alloc_tag *tag_old, *tag_new;

        if (!mem_alloc_profiling_enabled())
                return;

        tag_old = __pgalloc_tag_get(&old->page);
        if (!tag_old)
                return;
        tag_new = __pgalloc_tag_get(&new->page);
        if (!tag_new)
                return;

        if (!get_page_tag_ref(&old->page, &ref_old, &handle_old))
                return;
        if (!get_page_tag_ref(&new->page, &ref_new, &handle_new)) {
                put_page_tag_ref(handle_old);
                return;
        }

        /*
         * Clear tag references to avoid debug warning when using
         * __alloc_tag_ref_set() with non-empty reference.
         */
        set_codetag_empty(&ref_old);
        set_codetag_empty(&ref_new);

        /* swap tags */
        __alloc_tag_ref_set(&ref_old, tag_new);
        update_page_tag_ref(handle_old, &ref_old);
        __alloc_tag_ref_set(&ref_new, tag_old);
        update_page_tag_ref(handle_new, &ref_new);

        put_page_tag_ref(handle_old);
        put_page_tag_ref(handle_new);
}

static void shutdown_mem_profiling(bool remove_file)
{
        if (mem_alloc_profiling_enabled())
                static_branch_disable(&mem_alloc_profiling_key);

        if (!mem_profiling_support)
                return;

        if (remove_file)
                remove_proc_entry(ALLOCINFO_FILE_NAME, NULL);
        mem_profiling_support = false;
}

void __init alloc_tag_sec_init(void)
{
        struct alloc_tag *last_codetag;

        if (!mem_profiling_support)
                return;

        if (!static_key_enabled(&mem_profiling_compressed))
                return;

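        /*
         * Resolve the kernel's alloc_tag section boundaries so that tags
         * can be referenced by their index relative to first_tag.
         */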
        kernel_tags.first_tag = (struct alloc_tag *)kallsyms_lookup_name(
                                        SECTION_START(ALLOC_TAG_SECTION_NAME));
        last_codetag = (struct alloc_tag *)kallsyms_lookup_name(
                                        SECTION_STOP(ALLOC_TAG_SECTION_NAME));
        kernel_tags.count = last_codetag - kernel_tags.first_tag;

        /* Check if kernel tags fit into page flags */
        if (kernel_tags.count > (1UL << NR_UNUSED_PAGEFLAG_BITS)) {
                shutdown_mem_profiling(false); /* allocinfo file does not exist yet */
                pr_err("%lu allocation tags cannot be references using %d available page flag bits. Memory allocation profiling is disabled!\n",
                        kernel_tags.count, NR_UNUSED_PAGEFLAG_BITS);
                return;
        }

        alloc_tag_ref_offs = (LRU_REFS_PGOFF - NR_UNUSED_PAGEFLAG_BITS);
        alloc_tag_ref_mask = ((1UL << NR_UNUSED_PAGEFLAG_BITS) - 1);
        pr_debug("Memory allocation profiling compression is using %d page flag bits!\n",
                 NR_UNUSED_PAGEFLAG_BITS);
}

#ifdef CONFIG_MODULES

static struct maple_tree mod_area_mt = MTREE_INIT(mod_area_mt, MT_FLAGS_ALLOC_RANGE);
static struct vm_struct *vm_module_tags;
/* A dummy object used to indicate an unloaded module */
static struct module unloaded_mod;
/* A dummy object used to indicate padding prepended before a module's tags */
static struct module prepend_mod;

struct alloc_tag_module_section module_tags;

static inline unsigned long alloc_tag_align(unsigned long val)
{
        if (!static_key_enabled(&mem_profiling_compressed)) {
                /* No alignment requirements when we are not indexing the tags */
                return val;
        }

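        /* Round val up to the next multiple of sizeof(struct alloc_tag) */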
        if (val % sizeof(struct alloc_tag) == 0)
                return val;
        return ((val / sizeof(struct alloc_tag)) + 1) * sizeof(struct alloc_tag);
}

static bool ensure_alignment(unsigned long align, unsigned int *prepend)
{
        if (!static_key_enabled(&mem_profiling_compressed)) {
                /* No alignment requirements when we are not indexing the tags */
                return true;
        }

        /*
         * If alloc_tag size is not a multiple of required alignment, tag
         * indexing does not work.
         */
        if (!IS_ALIGNED(sizeof(struct alloc_tag), align))
                return false;

        /* Ensure prepend consumes a multiple of alloc_tag-sized blocks */
        if (*prepend)
                *prepend = alloc_tag_align(*prepend);

        return true;
}

static inline bool tags_addressable(void)
{
        unsigned long tag_idx_count;

        if (!static_key_enabled(&mem_profiling_compressed))
                return true; /* with page_ext, tags are always addressable */

        tag_idx_count = CODETAG_ID_FIRST + kernel_tags.count +
                        module_tags.size / sizeof(struct alloc_tag);

        return tag_idx_count < (1UL << NR_UNUSED_PAGEFLAG_BITS);
}

static bool needs_section_mem(struct module *mod, unsigned long size)
{
        if (!mem_profiling_support)
                return false;

        return size >= sizeof(struct alloc_tag);
}

static bool clean_unused_counters(struct alloc_tag *start_tag,
                                  struct alloc_tag *end_tag)
{
        struct alloc_tag *tag;
        bool ret = true;

        for (tag = start_tag; tag <= end_tag; tag++) {
                struct alloc_tag_counters counter;

                if (!tag->counters)
                        continue;

                counter = alloc_tag_read(tag);
                if (!counter.bytes) {
                        free_percpu(tag->counters);
                        tag->counters = NULL;
                } else {
                        ret = false;
                }
        }

        return ret;
}

/* Called with mod_area_mt locked */
static void clean_unused_module_areas_locked(void)
{
        MA_STATE(mas, &mod_area_mt, 0, module_tags.size);
        struct module *val;

        mas_for_each(&mas, val, module_tags.size) {
                struct alloc_tag *start_tag;
                struct alloc_tag *end_tag;

                if (val != &unloaded_mod)
                        continue;

                /* Release area if all tags are unused */
                start_tag = (struct alloc_tag *)(module_tags.start_addr + mas.index);
                end_tag = (struct alloc_tag *)(module_tags.start_addr + mas.last);
                if (clean_unused_counters(start_tag, end_tag))
                        mas_erase(&mas);
        }
}

/* Called with mod_area_mt locked */
static bool find_aligned_area(struct ma_state *mas, unsigned long section_size,
                              unsigned long size, unsigned int prepend, unsigned long align)
{
        bool cleanup_done = false;

repeat:
        /* Try finding an area of the exact size and hope its start is aligned */
        if (!mas_empty_area(mas, 0, section_size - 1, prepend + size)) {
                if (IS_ALIGNED(mas->index + prepend, align))
                        return true;

                /*
                 * Try finding a larger area to align within later: a window
                 * of size + prepend + align - 1 bytes always contains an
                 * aligned start for the payload.
                 */
                mas_reset(mas);
                if (!mas_empty_area(mas, 0, section_size - 1,
                                    size + prepend + align - 1))
                        return true;
        }

        /* No free area; clean up stale data and repeat the search once */
        if (!cleanup_done) {
                clean_unused_module_areas_locked();
                cleanup_done = true;
                mas_reset(mas);
                goto repeat;
        }

        return false;
}

static int vm_module_tags_populate(void)
{
        unsigned long phys_end = ALIGN_DOWN(module_tags.start_addr, PAGE_SIZE) +
                                 (vm_module_tags->nr_pages << PAGE_SHIFT);
        unsigned long new_end = module_tags.start_addr + module_tags.size;

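        /*
         * phys_end is the end of the currently populated backing pages;
         * allocate and map additional pages only if the tag area has
         * grown past it.
         */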
        if (phys_end < new_end) {
                struct page **next_page = vm_module_tags->pages + vm_module_tags->nr_pages;
                unsigned long old_shadow_end = ALIGN(phys_end, MODULE_ALIGN);
                unsigned long new_shadow_end = ALIGN(new_end, MODULE_ALIGN);
                unsigned long more_pages;
                unsigned long nr = 0;

                more_pages = ALIGN(new_end - phys_end, PAGE_SIZE) >> PAGE_SHIFT;
                while (nr < more_pages) {
                        unsigned long allocated;

                        allocated = alloc_pages_bulk_node(GFP_KERNEL | __GFP_NOWARN,
                                NUMA_NO_NODE, more_pages - nr, next_page + nr);

                        if (!allocated)
                                break;
                        nr += allocated;
                }

                if (nr < more_pages ||
                    vmap_pages_range(phys_end, phys_end + (nr << PAGE_SHIFT), PAGE_KERNEL,
                                     next_page, PAGE_SHIFT) < 0) {
                        release_pages_arg arg = { .pages = next_page };

                        /* Clean up and error out */
                        release_pages(arg, nr);
                        return -ENOMEM;
                }

                vm_module_tags->nr_pages += nr;

                /*
                 * KASAN allocates 1 byte of shadow for every 8 bytes of data.
                 * kasan_alloc_module_shadow() allocates shadow memory in
                 * whole pages, so the range must be aligned to MODULE_ALIGN
                 * here.
                 */
                if (old_shadow_end < new_shadow_end)
                        kasan_alloc_module_shadow((void *)old_shadow_end,
                                                  new_shadow_end - old_shadow_end,
                                                  GFP_KERNEL);
        }

        /*
         * Mark the pages as accessible, now that they are mapped.
         * With hardware tag-based KASAN, marking is skipped for
         * non-VM_ALLOC mappings, see __kasan_unpoison_vmalloc().
         */
        kasan_unpoison_vmalloc((void *)module_tags.start_addr,
                                new_end - module_tags.start_addr,
                                KASAN_VMALLOC_PROT_NORMAL);

        return 0;
}

static void *reserve_module_tags(struct module *mod, unsigned long size,
                                 unsigned int prepend, unsigned long align)
{
        unsigned long section_size = module_tags.end_addr - module_tags.start_addr;
        MA_STATE(mas, &mod_area_mt, 0, section_size - 1);
        unsigned long offset;
        void *ret = NULL;

        /* If the area cannot hold even one tag, return an error */
        if (size < sizeof(struct alloc_tag))
                return ERR_PTR(-EINVAL);

        /*
         * align is always a power of 2, so IS_ALIGNED() and ALIGN() can be
         * used. align of 0 or 1 means no alignment; normalize it to 1 to
         * simplify the code below.
         */
        if (!align)
                align = 1;

        if (!ensure_alignment(align, &prepend)) {
                shutdown_mem_profiling(true);
                pr_err("%s: alignment %lu is incompatible with allocation tag indexing. Memory allocation profiling is disabled!\n",
                        mod->name, align);
                return ERR_PTR(-EINVAL);
        }

        mas_lock(&mas);
        if (!find_aligned_area(&mas, section_size, size, prepend, align)) {
                ret = ERR_PTR(-ENOMEM);
                goto unlock;
        }

        /* Mark found area as reserved */
        offset = mas.index;
        offset += prepend;
        offset = ALIGN(offset, align);
        if (offset != mas.index) {
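                /*
                 * Alignment leaves a gap before the reserved area; record
                 * it as prepend_mod so it is released together with the
                 * module's area.
                 */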
                unsigned long pad_start = mas.index;

                mas.last = offset - 1;
                mas_store(&mas, &prepend_mod);
                if (mas_is_err(&mas)) {
                        ret = ERR_PTR(xa_err(mas.node));
                        goto unlock;
                }
                mas.index = offset;
                mas.last = offset + size - 1;
                mas_store(&mas, mod);
                if (mas_is_err(&mas)) {
                        mas.index = pad_start;
                        mas_erase(&mas);
                        ret = ERR_PTR(xa_err(mas.node));
                }
        } else {
                mas.last = offset + size - 1;
                mas_store(&mas, mod);
                if (mas_is_err(&mas))
                        ret = ERR_PTR(xa_err(mas.node));
        }
unlock:
        mas_unlock(&mas);

        if (IS_ERR(ret))
                return ret;

        if (module_tags.size < offset + size) {
                int grow_res;

                module_tags.size = offset + size;
                if (mem_alloc_profiling_enabled() && !tags_addressable()) {
                        shutdown_mem_profiling(true);
                        pr_warn("With module %s there are too many tags to fit in %d page flag bits. Memory allocation profiling is disabled!\n",
                                mod->name, NR_UNUSED_PAGEFLAG_BITS);
                }

                grow_res = vm_module_tags_populate();
                if (grow_res) {
                        shutdown_mem_profiling(true);
                        pr_err("Failed to allocate memory for allocation tags in the module %s. Memory allocation profiling is disabled!\n",
                               mod->name);
                        return ERR_PTR(grow_res);
                }
        }

        return (struct alloc_tag *)(module_tags.start_addr + offset);
}

static void release_module_tags(struct module *mod, bool used)
{
        MA_STATE(mas, &mod_area_mt, module_tags.size, module_tags.size);
        struct alloc_tag *start_tag;
        struct alloc_tag *end_tag;
        struct module *val;

        mas_lock(&mas);
        mas_for_each_rev(&mas, val, 0)
                if (val == mod)
                        break;

        if (!val) /* module not found */
                goto out;

        if (!used)
                goto release_area;

        start_tag = (struct alloc_tag *)(module_tags.start_addr + mas.index);
        end_tag = (struct alloc_tag *)(module_tags.start_addr + mas.last);
        if (!clean_unused_counters(start_tag, end_tag)) {
                struct alloc_tag *tag;

                for (tag = start_tag; tag <= end_tag; tag++) {
                        struct alloc_tag_counters counter;

                        if (!tag->counters)
                                continue;

                        counter = alloc_tag_read(tag);
                        pr_info("%s:%u module %s func:%s has %llu bytes allocated at module unload\n",
                                tag->ct.filename, tag->ct.lineno, tag->ct.modname,
                                tag->ct.function, counter.bytes);
                }
        } else {
                used = false;
        }
release_area:
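        /*
         * If counters are still charged, keep the range reserved as
         * unloaded_mod so the live counters stay valid; otherwise free it
         * together with any padding recorded before it.
         */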
        mas_store(&mas, used ? &unloaded_mod : NULL);
        val = mas_prev_range(&mas, 0);
        if (val == &prepend_mod)
                mas_store(&mas, NULL);
out:
        mas_unlock(&mas);
}

static int load_module(struct module *mod, struct codetag *start, struct codetag *stop)
{
        /* Allocate module alloc_tag percpu counters */
        struct alloc_tag *start_tag;
        struct alloc_tag *stop_tag;
        struct alloc_tag *tag;

        /* percpu counters for core allocations are already statically allocated */
        if (!mod)
                return 0;

        start_tag = ct_to_alloc_tag(start);
        stop_tag = ct_to_alloc_tag(stop);
        for (tag = start_tag; tag < stop_tag; tag++) {
                WARN_ON(tag->counters);
                tag->counters = alloc_percpu(struct alloc_tag_counters);
                if (!tag->counters) {
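                        /* Unwind the counters allocated so far */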
                        while (--tag >= start_tag) {
                                free_percpu(tag->counters);
                                tag->counters = NULL;
                        }
                        pr_err("Failed to allocate memory for allocation tag percpu counters in the module %s\n",
                               mod->name);
                        return -ENOMEM;
                }

                /*
                 * Avoid a kmemleak false positive. The pointer to the counters is stored
                 * in the alloc_tag section of the module and cannot be directly accessed.
                 */
                kmemleak_ignore_percpu(tag->counters);
        }
        return 0;
}

static void replace_module(struct module *mod, struct module *new_mod)
{
        MA_STATE(mas, &mod_area_mt, 0, module_tags.size);
        struct module *val;

        mas_lock(&mas);
        mas_for_each(&mas, val, module_tags.size) {
                if (val != mod)
                        continue;

                mas_store_gfp(&mas, new_mod, GFP_KERNEL);
                break;
        }
        mas_unlock(&mas);
}

static int __init alloc_mod_tags_mem(void)
{
        /*
         * Reserve address space into which module allocation tags are
         * copied; backing pages are populated on demand as modules load.
         */
        vm_module_tags = execmem_vmap(MODULE_ALLOC_TAG_VMAP_SIZE);
        if (!vm_module_tags) {
                pr_err("Failed to map %lu bytes for module allocation tags\n",
                        MODULE_ALLOC_TAG_VMAP_SIZE);
                module_tags.start_addr = 0;
                return -ENOMEM;
        }

        vm_module_tags->pages = kmalloc_objs(struct page *,
                                             get_vm_area_size(vm_module_tags) >> PAGE_SHIFT,
                                             GFP_KERNEL | __GFP_ZERO);
        if (!vm_module_tags->pages) {
                free_vm_area(vm_module_tags);
                return -ENOMEM;
        }

        module_tags.start_addr = (unsigned long)vm_module_tags->addr;
        module_tags.end_addr = module_tags.start_addr + MODULE_ALLOC_TAG_VMAP_SIZE;
        /* Ensure the base is alloc_tag aligned when required for indexing */
        module_tags.start_addr = alloc_tag_align(module_tags.start_addr);

        return 0;
}

static void __init free_mod_tags_mem(void)
{
        release_pages_arg arg = { .pages = vm_module_tags->pages };

        module_tags.start_addr = 0;
        release_pages(arg, vm_module_tags->nr_pages);
        kfree(vm_module_tags->pages);
        free_vm_area(vm_module_tags);
}

#else /* CONFIG_MODULES */

static inline int alloc_mod_tags_mem(void) { return 0; }
static inline void free_mod_tags_mem(void) {}

#endif /* CONFIG_MODULES */

/* See: Documentation/mm/allocation-profiling.rst */
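/*
 * Format: sysctl.vm.mem_profiling=never | <bool>[,compressed]
 * e.g. "sysctl.vm.mem_profiling=1,compressed" enables profiling with
 * compressed page-flag based tag references.
 */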
static int __init setup_early_mem_profiling(char *str)
{
        bool compressed = false;
        bool enable;

        if (!str || !str[0])
                return -EINVAL;

        if (!strncmp(str, "never", 5)) {
                enable = false;
                mem_profiling_support = false;
                pr_info("Memory allocation profiling is disabled!\n");
        } else {
                char *token = strsep(&str, ",");

                if (kstrtobool(token, &enable))
                        return -EINVAL;

                if (str) {
                        if (strcmp(str, "compressed"))
                                return -EINVAL;

                        compressed = true;
                }
                mem_profiling_support = true;
                pr_info("Memory allocation profiling is enabled %s compression and is turned %s!\n",
                        compressed ? "with" : "without", str_on_off(enable));
        }

        if (enable != mem_alloc_profiling_enabled()) {
                if (enable)
                        static_branch_enable(&mem_alloc_profiling_key);
                else
                        static_branch_disable(&mem_alloc_profiling_key);
        }
        if (compressed != static_key_enabled(&mem_profiling_compressed)) {
                if (compressed)
                        static_branch_enable(&mem_profiling_compressed);
                else
                        static_branch_disable(&mem_profiling_compressed);
        }

        return 0;
}
early_param("sysctl.vm.mem_profiling", setup_early_mem_profiling);

static __init bool need_page_alloc_tagging(void)
{
        if (static_key_enabled(&mem_profiling_compressed))
                return false;

        return mem_profiling_support;
}

static __init void init_page_alloc_tagging(void)
{
}

struct page_ext_operations page_alloc_tagging_ops = {
        .size = sizeof(union codetag_ref),
        .need = need_page_alloc_tagging,
        .init = init_page_alloc_tagging,
};
EXPORT_SYMBOL(page_alloc_tagging_ops);

#ifdef CONFIG_SYSCTL
/*
 * Not using proc_do_static_key() directly to prevent enabling profiling
 * after it was shut down.
 */
static int proc_mem_profiling_handler(const struct ctl_table *table, int write,
                                      void *buffer, size_t *lenp, loff_t *ppos)
{
        if (write) {
                /*
                 * A call from do_sysctl_args() is a no-op because the same
                 * value was already set by setup_early_mem_profiling().
                 * Return success to avoid a warning from do_sysctl_args().
                 */
                if (!current->mm)
                        return 0;

#ifdef CONFIG_MEM_ALLOC_PROFILING_DEBUG
                /* Toggling is not allowed with CONFIG_MEM_ALLOC_PROFILING_DEBUG */
                return -EACCES;
#endif
                if (!mem_profiling_support)
                        return -EINVAL;
        }

        return proc_do_static_key(table, write, buffer, lenp, ppos);
}

static const struct ctl_table memory_allocation_profiling_sysctls[] = {
        {
                .procname       = "mem_profiling",
                .data           = &mem_alloc_profiling_key,
                .mode           = 0644,
                .proc_handler   = proc_mem_profiling_handler,
        },
};

static void __init sysctl_init(void)
{
        register_sysctl_init("vm", memory_allocation_profiling_sysctls);
}
#else /* CONFIG_SYSCTL */
static inline void sysctl_init(void) {}
#endif /* CONFIG_SYSCTL */

static int __init alloc_tag_init(void)
{
        const struct codetag_type_desc desc = {
                .section                = ALLOC_TAG_SECTION_NAME,
                .tag_size               = sizeof(struct alloc_tag),
#ifdef CONFIG_MODULES
                .needs_section_mem      = needs_section_mem,
                .alloc_section_mem      = reserve_module_tags,
                .free_section_mem       = release_module_tags,
                .module_load            = load_module,
                .module_replaced        = replace_module,
#endif
        };
        int res;

        sysctl_init();

        if (!mem_profiling_support) {
                pr_info("Memory allocation profiling is not supported!\n");
                return 0;
        }

        if (!proc_create_seq_private(ALLOCINFO_FILE_NAME, 0400, NULL, &allocinfo_seq_op,
                                     sizeof(struct allocinfo_private), NULL)) {
                pr_err("Failed to create %s file\n", ALLOCINFO_FILE_NAME);
                shutdown_mem_profiling(false);
                return -ENOMEM;
        }

        res = alloc_mod_tags_mem();
        if (res) {
                pr_err("Failed to reserve address space for module tags, errno = %d\n", res);
                shutdown_mem_profiling(true);
                return res;
        }

        alloc_tag_cttype = codetag_register_type(&desc);
        if (IS_ERR(alloc_tag_cttype)) {
                pr_err("Allocation tags registration failed, errno = %pe\n", alloc_tag_cttype);
                free_mod_tags_mem();
                shutdown_mem_profiling(true);
                return PTR_ERR(alloc_tag_cttype);
        }

        return 0;
}
module_init(alloc_tag_init);