root/tools/perf/util/comm.c
// SPDX-License-Identifier: GPL-2.0
#include "comm.h"
#include <errno.h>
#include <string.h>
#include <internal/rc_check.h>
#include <linux/refcount.h>
#include <linux/zalloc.h>
#include <tools/libc_compat.h> // reallocarray

#include "rwsem.h"

/*
 * An interned comm string. Instances are created only by
 * comm_strs__findnew() and are shared by every comm with the same string.
 * While an entry sits in the global _comm_strs array, that array owns one of
 * the references counted in @refcnt.
 */
DECLARE_RC_STRUCT(comm_str) {
        refcount_t refcnt;
        char str[];     /* NUL-terminated, allocated inline by comm_str__new(). */
};

/*
 * Global table of interned comm_str entries, kept sorted in strcmp() order so
 * lookups can use bsearch(). @lock guards @strs, @num_strs and @capacity;
 * every pointer stored in @strs holds one reference on its comm_str.
 * Initialised lazily by comm_strs__get().
 */
static struct comm_strs {
        struct rw_semaphore lock;
        struct comm_str **strs;
        int num_strs;   /* Number of valid entries in @strs. */
        int capacity;   /* Number of slots allocated in @strs. */
} _comm_strs;

/* Forward declaration: comm_str__put() needs it before its definition below. */
static void comm_strs__remove_if_last(struct comm_str *cs);

/*
 * One-time initialiser for _comm_strs, run via pthread_once() from
 * comm_strs__get(). Starts with room for 16 entries.
 */
static void comm_strs__init(void)
        NO_THREAD_SAFETY_ANALYSIS /* Inherently single threaded due to pthread_once. */
{
        init_rwsem(&_comm_strs.lock);
        _comm_strs.num_strs = 0;
        _comm_strs.strs = calloc(16, sizeof(*_comm_strs.strs));
        /*
         * If the initial allocation failed, advertise no capacity. The first
         * insertion in comm_strs__findnew() then sees num_strs == capacity
         * and grows the array with reallocarray(), instead of storing
         * through a NULL pointer.
         */
        _comm_strs.capacity = _comm_strs.strs ? 16 : 0;
}

/*
 * Return the global comm string table, initialising it on first use.
 * pthread_once() guarantees comm_strs__init() runs exactly once even if
 * several threads get here at the same time.
 */
static struct comm_strs *comm_strs__get(void)
{
        static pthread_once_t init_once = PTHREAD_ONCE_INIT;

        pthread_once(&init_once, comm_strs__init);
        return &_comm_strs;
}

/* Address of @cs's reference count, reached through RC_CHK_ACCESS(). */
static refcount_t *comm_str__refcnt(struct comm_str *cs)
{
        refcount_t *counter = &RC_CHK_ACCESS(cs)->refcnt;

        return counter;
}

/* The NUL-terminated string stored inline in @cs. */
static const char *comm_str__str(const struct comm_str *cs)
{
        const char *text = RC_CHK_ACCESS(cs)->str;

        return text;
}

/*
 * Take a new reference on @cs and return the handle for it. Returns NULL
 * when @cs is NULL or when RC_CHK_GET() fails, and in that case the count
 * is left alone.
 */
static struct comm_str *comm_str__get(struct comm_str *cs)
{
        struct comm_str *ref;

        if (!RC_CHK_GET(ref, cs))
                return ref;

        refcount_inc_not_zero(comm_str__refcnt(cs));
        return ref;
}

static void comm_str__put(struct comm_str *cs)
{
        if (!cs)
                return;

        if (refcount_dec_and_test(comm_str__refcnt(cs))) {
                RC_CHK_FREE(cs);
        } else {
                if (refcount_read(comm_str__refcnt(cs)) == 1)
                        comm_strs__remove_if_last(cs);

                RC_CHK_PUT(cs);
        }
}

/*
 * Allocate a new comm_str holding a copy of @str, with a reference count of 1.
 * Returns NULL on allocation failure.
 */
static struct comm_str *comm_str__new(const char *str)
{
        struct comm_str *result = NULL;
        size_t size = strlen(str) + 1; /* Include the terminating NUL. */
        RC_STRUCT(comm_str) *cs;

        cs = malloc(sizeof(*cs) + size);
        if (ADD_RC_CHK(result, cs)) {
                refcount_set(comm_str__refcnt(result), 1);
                memcpy(&cs->str[0], str, size);
        } else {
                /*
                 * ADD_RC_CHK() can fail even though @cs was allocated (the
                 * reference-count-checking wrapper allocation failed). Release
                 * @cs so it isn't leaked. When @cs is NULL this is a no-op.
                 */
                free(cs);
        }
        return result;
}

/*
 * bsearch() comparator: @_key is the C string being looked up, and @_member
 * points at one `struct comm_str *` slot of the table. Ordering is strcmp().
 */
static int comm_str__search(const void *_key, const void *_member)
{
        const struct comm_str *const *slot = _member;

        return strcmp((const char *)_key, comm_str__str(*slot));
}

/*
 * If the only remaining reference to @cs is the one held by the global table,
 * remove it from the table and drop that reference, which frees the entry.
 * The count is re-checked under the write lock so that this cannot race with
 * comm_strs__findnew() handing out a new reference.
 */
static void comm_strs__remove_if_last(struct comm_str *cs)
{
        struct comm_strs *comm_strs = comm_strs__get();

        down_write(&comm_strs->lock);
        if (refcount_read(comm_str__refcnt(cs)) == 1) {
                struct comm_str **entry;

                entry = bsearch(comm_str__str(cs), comm_strs->strs, comm_strs->num_strs,
                                sizeof(*comm_strs->strs), comm_str__search);
                /*
                 * Every comm_str is created and inserted by comm_strs__findnew(),
                 * so the lookup should succeed. Don't dereference NULL if that
                 * invariant is ever broken.
                 */
                if (entry) {
                        /* Number of entries after @entry that shift down one slot. */
                        size_t tail = comm_strs->num_strs - (entry - comm_strs->strs) - 1;

                        /* Drops the table's (last) reference, freeing the entry. */
                        comm_str__put(*entry);
                        memmove(entry, entry + 1, tail * sizeof(*entry));
                        comm_strs->num_strs--;
                }
        }
        up_write(&comm_strs->lock);
}

/*
 * Look up @str in the table. The caller must hold @comm_strs->lock, at least
 * shared. On a hit, returns a new reference (the table keeps its own);
 * returns NULL if @str is not present.
 */
static struct comm_str *__comm_strs__find(struct comm_strs *comm_strs, const char *str)
        SHARED_LOCKS_REQUIRED(comm_strs->lock)
{
        struct comm_str **slot = bsearch(str, comm_strs->strs, comm_strs->num_strs,
                                         sizeof(*comm_strs->strs), comm_str__search);

        return slot ? comm_str__get(*slot) : NULL;
}

/*
 * Return a reference to the interned comm_str for @str, creating and
 * inserting it (in sorted position) if it does not exist yet. The caller owns
 * the returned reference and must release it with comm_str__put().
 * Returns NULL on allocation failure.
 */
static struct comm_str *comm_strs__findnew(const char *str)
{
        struct comm_strs *comm_strs = comm_strs__get();
        struct comm_str *result;

        if (!comm_strs)
                return NULL;

        /* Fast path: look up under the shared lock. */
        down_read(&comm_strs->lock);
        result = __comm_strs__find(comm_strs, str);
        up_read(&comm_strs->lock);
        if (result)
                return result;

        down_write(&comm_strs->lock);
        /*
         * Another thread may have inserted @str between up_read() and
         * down_write(). On a hit, @result already carries the caller's
         * reference from __comm_strs__find(), so no extra get is needed.
         */
        result = __comm_strs__find(comm_strs, str);
        if (!result) {
                if (comm_strs->num_strs == comm_strs->capacity) {
                        struct comm_str **tmp;

                        tmp = reallocarray(comm_strs->strs,
                                           comm_strs->capacity + 16,
                                           sizeof(*comm_strs->strs));
                        if (!tmp) {
                                up_write(&comm_strs->lock);
                                return NULL;
                        }
                        comm_strs->strs = tmp;
                        comm_strs->capacity += 16;
                }
                result = comm_str__new(str);
                if (result) {
                        int low = 0, high = comm_strs->num_strs - 1;
                        int insert = comm_strs->num_strs; /* Default to inserting at the end. */

                        /* Find the first entry >= @str to keep the table sorted. */
                        while (low <= high) {
                                int mid = low + (high - low) / 2;
                                int cmp = strcmp(comm_str__str(comm_strs->strs[mid]), str);

                                if (cmp < 0) {
                                        low = mid + 1;
                                } else {
                                        high = mid - 1;
                                        insert = mid;
                                }
                        }
                        memmove(&comm_strs->strs[insert + 1], &comm_strs->strs[insert],
                                (comm_strs->num_strs - insert) * sizeof(struct comm_str *));
                        comm_strs->num_strs++;
                        /* The table keeps the initial reference from comm_str__new(). */
                        comm_strs->strs[insert] = result;
                        /*
                         * Take the caller's reference while still holding the
                         * write lock. Otherwise a concurrent findnew()+put() could
                         * remove and free the entry before we get to it.
                         */
                        result = comm_str__get(result);
                }
        }
        up_write(&comm_strs->lock);
        return result;
}

/*
 * Allocate a zeroed comm that refers to the interned string for @str, starting
 * at @timestamp with the given @exec flag. Returns NULL if either the comm or
 * the interned string cannot be allocated.
 */
struct comm *comm__new(const char *str, u64 timestamp, bool exec)
{
        struct comm *comm;

        comm = zalloc(sizeof(*comm));
        if (comm == NULL)
                return NULL;

        comm->comm_str = comm_strs__findnew(str);
        if (comm->comm_str == NULL) {
                free(comm);
                return NULL;
        }

        comm->start = timestamp;
        comm->exec = exec;
        return comm;
}

/*
 * Switch @comm to the interned string for @str and set its start time to
 * @timestamp. @exec is sticky: it can set comm->exec but never clears it.
 * Returns 0 on success, or -ENOMEM if @str could not be interned, in which
 * case @comm is left unchanged.
 */
int comm__override(struct comm *comm, const char *str, u64 timestamp, bool exec)
{
        struct comm_str *old = comm->comm_str;
        struct comm_str *replacement = comm_strs__findnew(str);

        if (replacement == NULL)
                return -ENOMEM;

        /* The new reference is already held, so dropping @old is safe even if equal. */
        comm->comm_str = replacement;
        comm_str__put(old);

        comm->start = timestamp;
        if (exec)
                comm->exec = true;

        return 0;
}

/*
 * Release @comm's reference on its interned string and free @comm.
 * Like free(), a NULL @comm is accepted and ignored so cleanup paths can call
 * this unconditionally.
 */
void comm__free(struct comm *comm)
{
        if (!comm)
                return;

        comm_str__put(comm->comm_str);
        free(comm);
}

/*
 * The current string of @comm. The storage belongs to the interned comm_str,
 * not to the caller.
 */
const char *comm__str(const struct comm *comm)
{
        const struct comm_str *cs = comm->comm_str;

        return comm_str__str(cs);
}