kernel/locking/test-ww_mutex.c
// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * Module-based API test facility for ww_mutexes
 */

#include <linux/kernel.h>

#include <linux/completion.h>
#include <linux/delay.h>
#include <linux/kthread.h>
#include <linux/module.h>
#include <linux/prandom.h>
#include <linux/slab.h>
#include <linux/ww_mutex.h>

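/*
 * One wait-die and one wound-wait lock class; run_test_classes() runs
 * the full suite against both algorithms.
 */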
static DEFINE_WD_CLASS(wd_class);
static DEFINE_WW_CLASS(ww_class);
static struct workqueue_struct *wq;

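/*
 * With CONFIG_DEBUG_WW_MUTEX_SLOWPATH, ww_acquire_init() arms periodic
 * -EDEADLK injection; pushing the countdown to ~0U effectively disables
 * it, so the tests below only see -EDEADLK for genuine cycles.
 */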
#ifdef CONFIG_DEBUG_WW_MUTEX_SLOWPATH
#define ww_acquire_init_noinject(a, b) do { \
                ww_acquire_init((a), (b)); \
                (a)->deadlock_inject_countdown = ~0U; \
        } while (0)
#else
#define ww_acquire_init_noinject(a, b) ww_acquire_init((a), (b))
#endif

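/*
 * Basic mutual-exclusion test: a worker contends for a ww_mutex that
 * the parent already holds.  The TEST_MTX_* flags select whether the
 * worker spins on trylock or blocks, and whether the parent locks with
 * an acquire context; test_mutex() iterates over all combinations.
 */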
struct test_mutex {
        struct work_struct work;
        struct ww_mutex mutex;
        struct completion ready, go, done;
        unsigned int flags;
};

#define TEST_MTX_SPIN BIT(0)
#define TEST_MTX_TRY BIT(1)
#define TEST_MTX_CTX BIT(2)
#define __TEST_MTX_LAST BIT(3)

static void test_mutex_work(struct work_struct *work)
{
        struct test_mutex *mtx = container_of(work, typeof(*mtx), work);

        complete(&mtx->ready);
        wait_for_completion(&mtx->go);

        if (mtx->flags & TEST_MTX_TRY) {
                while (!ww_mutex_trylock(&mtx->mutex, NULL))
                        cond_resched();
        } else {
                ww_mutex_lock(&mtx->mutex, NULL);
        }
        complete(&mtx->done);
        ww_mutex_unlock(&mtx->mutex);
}

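/*
 * Take the lock, release the worker, and verify that its "done"
 * completion cannot fire within TIMEOUT while we still hold the lock.
 */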
static int __test_mutex(struct ww_class *class, unsigned int flags)
{
#define TIMEOUT (HZ / 16)
        struct test_mutex mtx;
        struct ww_acquire_ctx ctx;
        int ret;

        ww_mutex_init(&mtx.mutex, class);
        if (flags & TEST_MTX_CTX)
                ww_acquire_init(&ctx, class);

        INIT_WORK_ONSTACK(&mtx.work, test_mutex_work);
        init_completion(&mtx.ready);
        init_completion(&mtx.go);
        init_completion(&mtx.done);
        mtx.flags = flags;

        queue_work(wq, &mtx.work);

        wait_for_completion(&mtx.ready);
        ww_mutex_lock(&mtx.mutex, (flags & TEST_MTX_CTX) ? &ctx : NULL);
        complete(&mtx.go);
        if (flags & TEST_MTX_SPIN) {
                unsigned long timeout = jiffies + TIMEOUT;

                ret = 0;
                do {
                        if (completion_done(&mtx.done)) {
                                ret = -EINVAL;
                                break;
                        }
                        cond_resched();
                } while (time_before(jiffies, timeout));
        } else {
                ret = wait_for_completion_timeout(&mtx.done, TIMEOUT);
        }
        ww_mutex_unlock(&mtx.mutex);
        if (flags & TEST_MTX_CTX)
                ww_acquire_fini(&ctx);

        if (ret) {
                pr_err("%s(flags=%x): mutual exclusion failure\n",
                       __func__, flags);
                ret = -EINVAL;
        }

        flush_work(&mtx.work);
        destroy_work_on_stack(&mtx.work);
        return ret;
#undef TIMEOUT
}

static int test_mutex(struct ww_class *class)
{
        int ret;
        int i;

        for (i = 0; i < __TEST_MTX_LAST; i++) {
                ret = __test_mutex(class, i);
                if (ret)
                        return ret;
        }

        return 0;
}

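/*
 * AA (recursive locking) test: with the lock held, a second trylock
 * must fail with or without a context, and relocking through the same
 * acquire context must return -EALREADY.
 */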
static int test_aa(struct ww_class *class, bool trylock)
{
        struct ww_mutex mutex;
        struct ww_acquire_ctx ctx;
        int ret;
        const char *from = trylock ? "trylock" : "lock";

        ww_mutex_init(&mutex, class);
        ww_acquire_init(&ctx, class);

        if (!trylock) {
                ret = ww_mutex_lock(&mutex, &ctx);
                if (ret) {
                        pr_err("%s: initial lock failed!\n", __func__);
                        goto out;
                }
        } else {
                ret = !ww_mutex_trylock(&mutex, &ctx);
                if (ret) {
                        pr_err("%s: initial trylock failed!\n", __func__);
                        goto out;
                }
        }

        if (ww_mutex_trylock(&mutex, NULL)) {
                pr_err("%s: trylocked itself without context from %s!\n", __func__, from);
                ww_mutex_unlock(&mutex);
                ret = -EINVAL;
                goto out;
        }

        if (ww_mutex_trylock(&mutex, &ctx)) {
                pr_err("%s: trylocked itself with context from %s!\n", __func__, from);
                ww_mutex_unlock(&mutex);
                ret = -EINVAL;
                goto out;
        }

        ret = ww_mutex_lock(&mutex, &ctx);
        if (ret != -EALREADY) {
                pr_err("%s: missed deadlock for recursing, ret=%d from %s\n",
                       __func__, ret, from);
                if (!ret)
                        ww_mutex_unlock(&mutex);
                ret = -EINVAL;
                goto out;
        }

        ww_mutex_unlock(&mutex);
        ret = 0;
out:
        ww_acquire_fini(&ctx);
        return ret;
}

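/*
 * ABBA deadlock test: the worker takes B then A while the parent takes
 * A then B, guaranteeing a cycle.  One side is expected to see
 * -EDEADLK; with "resolve" set, the loser backs off and reacquires via
 * ww_mutex_lock_slow().
 */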
struct test_abba {
        struct work_struct work;
        struct ww_class *class;
        struct ww_mutex a_mutex;
        struct ww_mutex b_mutex;
        struct completion a_ready;
        struct completion b_ready;
        bool resolve, trylock;
        int result;
};

static void test_abba_work(struct work_struct *work)
{
        struct test_abba *abba = container_of(work, typeof(*abba), work);
        struct ww_acquire_ctx ctx;
        int err;

        ww_acquire_init_noinject(&ctx, abba->class);
        if (!abba->trylock)
                ww_mutex_lock(&abba->b_mutex, &ctx);
        else
                WARN_ON(!ww_mutex_trylock(&abba->b_mutex, &ctx));

        WARN_ON(READ_ONCE(abba->b_mutex.ctx) != &ctx);

        complete(&abba->b_ready);
        wait_for_completion(&abba->a_ready);

        err = ww_mutex_lock(&abba->a_mutex, &ctx);
        if (abba->resolve && err == -EDEADLK) {
                ww_mutex_unlock(&abba->b_mutex);
                ww_mutex_lock_slow(&abba->a_mutex, &ctx);
                err = ww_mutex_lock(&abba->b_mutex, &ctx);
        }

        if (!err)
                ww_mutex_unlock(&abba->a_mutex);
        ww_mutex_unlock(&abba->b_mutex);
        ww_acquire_fini(&ctx);

        abba->result = err;
}

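/*
 * Parent side of the ABBA test: mirror test_abba_work() with the locks
 * in the opposite order, then check that the deadlock was either
 * resolved (both sides succeeded) or correctly reported.
 */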
static int test_abba(struct ww_class *class, bool trylock, bool resolve)
{
        struct test_abba abba;
        struct ww_acquire_ctx ctx;
        int err, ret;

        ww_mutex_init(&abba.a_mutex, class);
        ww_mutex_init(&abba.b_mutex, class);
        INIT_WORK_ONSTACK(&abba.work, test_abba_work);
        init_completion(&abba.a_ready);
        init_completion(&abba.b_ready);
        abba.class = class;
        abba.trylock = trylock;
        abba.resolve = resolve;

        queue_work(wq, &abba.work);

        ww_acquire_init_noinject(&ctx, class);
        if (!trylock)
                ww_mutex_lock(&abba.a_mutex, &ctx);
        else
                WARN_ON(!ww_mutex_trylock(&abba.a_mutex, &ctx));

        WARN_ON(READ_ONCE(abba.a_mutex.ctx) != &ctx);

        complete(&abba.a_ready);
        wait_for_completion(&abba.b_ready);

        err = ww_mutex_lock(&abba.b_mutex, &ctx);
        if (resolve && err == -EDEADLK) {
                ww_mutex_unlock(&abba.a_mutex);
                ww_mutex_lock_slow(&abba.b_mutex, &ctx);
                err = ww_mutex_lock(&abba.a_mutex, &ctx);
        }

        if (!err)
                ww_mutex_unlock(&abba.b_mutex);
        ww_mutex_unlock(&abba.a_mutex);
        ww_acquire_fini(&ctx);

        flush_work(&abba.work);
        destroy_work_on_stack(&abba.work);

        ret = 0;
        if (resolve) {
                if (err || abba.result) {
                        pr_err("%s: failed to resolve ABBA deadlock, A err=%d, B err=%d\n",
                               __func__, err, abba.result);
                        ret = -EINVAL;
                }
        } else {
                if (err != -EDEADLK && abba.result != -EDEADLK) {
                        pr_err("%s: missed ABBA deadlock, A err=%d, B err=%d\n",
                               __func__, err, abba.result);
                        ret = -EINVAL;
                }
        }
        return ret;
}

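/*
 * Cyclic deadlock test: worker n locks mutex n and then blocks on
 * mutex n+1, with the last worker wrapping around to mutex 0.  Every
 * worker must untangle the resulting cycle by backing off on -EDEADLK
 * and finish with both locks held.
 */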
struct test_cycle {
        struct work_struct work;
        struct ww_class *class;
        struct ww_mutex a_mutex;
        struct ww_mutex *b_mutex;
        struct completion *a_signal;
        struct completion b_signal;
        int result;
};

static void test_cycle_work(struct work_struct *work)
{
        struct test_cycle *cycle = container_of(work, typeof(*cycle), work);
        struct ww_acquire_ctx ctx;
        int err, erra = 0;

        ww_acquire_init_noinject(&ctx, cycle->class);
        ww_mutex_lock(&cycle->a_mutex, &ctx);

        complete(cycle->a_signal);
        wait_for_completion(&cycle->b_signal);

        err = ww_mutex_lock(cycle->b_mutex, &ctx);
        if (err == -EDEADLK) {
                err = 0;
                ww_mutex_unlock(&cycle->a_mutex);
                ww_mutex_lock_slow(cycle->b_mutex, &ctx);
                erra = ww_mutex_lock(&cycle->a_mutex, &ctx);
        }

        if (!err)
                ww_mutex_unlock(cycle->b_mutex);
        if (!erra)
                ww_mutex_unlock(&cycle->a_mutex);
        ww_acquire_fini(&ctx);

        cycle->result = err ?: erra;
}

static int __test_cycle(struct ww_class *class, unsigned int nthreads)
{
        struct test_cycle *cycles;
        unsigned int n, last = nthreads - 1;
        int ret;

        cycles = kmalloc_array(nthreads, sizeof(*cycles), GFP_KERNEL);
        if (!cycles)
                return -ENOMEM;

        for (n = 0; n < nthreads; n++) {
                struct test_cycle *cycle = &cycles[n];

                cycle->class = class;
                ww_mutex_init(&cycle->a_mutex, class);
                if (n == last)
                        cycle->b_mutex = &cycles[0].a_mutex;
                else
                        cycle->b_mutex = &cycles[n + 1].a_mutex;

                if (n == 0)
                        cycle->a_signal = &cycles[last].b_signal;
                else
                        cycle->a_signal = &cycles[n - 1].b_signal;
                init_completion(&cycle->b_signal);

                INIT_WORK(&cycle->work, test_cycle_work);
                cycle->result = 0;
        }

        for (n = 0; n < nthreads; n++)
                queue_work(wq, &cycles[n].work);

        flush_workqueue(wq);

        ret = 0;
        for (n = 0; n < nthreads; n++) {
                struct test_cycle *cycle = &cycles[n];

                if (!cycle->result)
                        continue;

                pr_err("cyclic deadlock not resolved, ret[%d/%d] = %d\n",
                       n, nthreads, cycle->result);
                ret = -EINVAL;
                break;
        }

        for (n = 0; n < nthreads; n++)
                ww_mutex_destroy(&cycles[n].a_mutex);
        kfree(cycles);
        return ret;
}

static int test_cycle(struct ww_class *class, unsigned int ncpus)
{
        unsigned int n;
        int ret;

        for (n = 2; n <= ncpus + 1; n++) {
                ret = __test_cycle(class, n);
                if (ret)
                        return ret;
        }

        return 0;
}

struct stress {
        struct work_struct work;
        struct ww_mutex *locks;
        struct ww_class *class;
        unsigned long timeout;
        int nlocks;
};

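/* File-local PRNG; the lock makes it safe for concurrent stress workers. */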
static struct rnd_state rng;
static DEFINE_SPINLOCK(rng_lock);

static inline u32 prandom_u32_below(u32 ceil)
{
        u32 ret;

        spin_lock(&rng_lock);
        ret = prandom_u32_state(&rng) % ceil;
        spin_unlock(&rng_lock);
        return ret;
}

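/* Return the indices 0..count-1 in a Fisher-Yates-shuffled order. */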
static int *get_random_order(int count)
{
        int *order;
        int n, r;

        order = kmalloc_array(count, sizeof(*order), GFP_KERNEL);
        if (!order)
                return order;

        for (n = 0; n < count; n++)
                order[n] = n;

        for (n = count - 1; n > 1; n--) {
                r = prandom_u32_below(n + 1);
                if (r != n)
                        swap(order[n], order[r]);
        }

        return order;
}

static void dummy_load(struct stress *stress)
{
        usleep_range(1000, 2000);
}

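/*
 * Stress: repeatedly acquire all locks in one fixed, randomly chosen
 * order.  On -EDEADLK, drop everything, take the contended lock via the
 * slow path, and retry with that lock already held.
 */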
static void stress_inorder_work(struct work_struct *work)
{
        struct stress *stress = container_of(work, typeof(*stress), work);
        const int nlocks = stress->nlocks;
        struct ww_mutex *locks = stress->locks;
        struct ww_acquire_ctx ctx;
        int *order;

        order = get_random_order(nlocks);
        if (!order)
                return;

        do {
                int contended = -1;
                int n, err;

                ww_acquire_init(&ctx, stress->class);
retry:
                err = 0;
                for (n = 0; n < nlocks; n++) {
                        if (n == contended)
                                continue;

                        err = ww_mutex_lock(&locks[order[n]], &ctx);
                        if (err < 0)
                                break;
                }
                if (!err)
                        dummy_load(stress);

                if (contended > n)
                        ww_mutex_unlock(&locks[order[contended]]);
                contended = n;
                while (n--)
                        ww_mutex_unlock(&locks[order[n]]);

                if (err == -EDEADLK) {
                        if (!time_after(jiffies, stress->timeout)) {
                                ww_mutex_lock_slow(&locks[order[contended]], &ctx);
                                goto retry;
                        }
                }

                ww_acquire_fini(&ctx);
                if (err) {
                        pr_err_once("stress (%s) failed with %d\n",
                                    __func__, err);
                        break;
                }
        } while (!time_after(jiffies, stress->timeout));

        kfree(order);
}

struct reorder_lock {
        struct list_head link;
        struct ww_mutex *lock;
};

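/*
 * Stress: as above, but the locks are kept on a list and the contended
 * lock is moved to the list head on every backoff, so the acquisition
 * order changes across retries.
 */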
static void stress_reorder_work(struct work_struct *work)
{
        struct stress *stress = container_of(work, typeof(*stress), work);
        LIST_HEAD(locks);
        struct ww_acquire_ctx ctx;
        struct reorder_lock *ll, *ln;
        int *order;
        int n, err;

        order = get_random_order(stress->nlocks);
        if (!order)
                return;

        for (n = 0; n < stress->nlocks; n++) {
                ll = kmalloc(sizeof(*ll), GFP_KERNEL);
                if (!ll)
                        goto out;

                ll->lock = &stress->locks[order[n]];
                list_add(&ll->link, &locks);
        }
        kfree(order);
        order = NULL;

        do {
                ww_acquire_init(&ctx, stress->class);

                list_for_each_entry(ll, &locks, link) {
                        err = ww_mutex_lock(ll->lock, &ctx);
                        if (!err)
                                continue;

                        ln = ll;
                        list_for_each_entry_continue_reverse(ln, &locks, link)
                                ww_mutex_unlock(ln->lock);

                        if (err != -EDEADLK) {
                                pr_err_once("stress (%s) failed with %d\n",
                                            __func__, err);
                                break;
                        }

                        ww_mutex_lock_slow(ll->lock, &ctx);
                        list_move(&ll->link, &locks); /* restarts iteration */
                }

                dummy_load(stress);
                list_for_each_entry(ll, &locks, link)
                        ww_mutex_unlock(ll->lock);

                ww_acquire_fini(&ctx);
        } while (!time_after(jiffies, stress->timeout));

out:
        list_for_each_entry_safe(ll, ln, &locks, link)
                kfree(ll);
        kfree(order);
}

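/* Stress: hammer one randomly chosen lock, without an acquire context. */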
static void stress_one_work(struct work_struct *work)
{
        struct stress *stress = container_of(work, typeof(*stress), work);
        const int nlocks = stress->nlocks;
        struct ww_mutex *lock = stress->locks + get_random_u32_below(nlocks);
        int err;

        do {
                err = ww_mutex_lock(lock, NULL);
                if (!err) {
                        dummy_load(stress);
                        ww_mutex_unlock(lock);
                } else {
                        pr_err_once("stress (%s) failed with %d\n",
                                    __func__, err);
                        break;
                }
        } while (!time_after(jiffies, stress->timeout));
}

#define STRESS_INORDER BIT(0)
#define STRESS_REORDER BIT(1)
#define STRESS_ONE BIT(2)
#define STRESS_ALL (STRESS_INORDER | STRESS_REORDER | STRESS_ONE)

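/*
 * Spawn nthreads workers, assigned round-robin to the stress variants
 * selected in flags, all running against the same array of nlocks locks
 * for roughly two seconds.
 */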
static int stress(struct ww_class *class, int nlocks, int nthreads, unsigned int flags)
{
        struct ww_mutex *locks;
        struct stress *stress_array;
        int n, count;

        locks = kmalloc_array(nlocks, sizeof(*locks), GFP_KERNEL);
        if (!locks)
                return -ENOMEM;

        stress_array = kmalloc_array(nthreads, sizeof(*stress_array), GFP_KERNEL);
        if (!stress_array) {
                kfree(locks);
                return -ENOMEM;
        }

        for (n = 0; n < nlocks; n++)
                ww_mutex_init(&locks[n], class);

        count = 0;
        for (n = 0; nthreads; n++) {
                struct stress *stress;
                void (*fn)(struct work_struct *work);

                fn = NULL;
                switch (n & 3) {
                case 0:
                        if (flags & STRESS_INORDER)
                                fn = stress_inorder_work;
                        break;
                case 1:
                        if (flags & STRESS_REORDER)
                                fn = stress_reorder_work;
                        break;
                case 2:
                        if (flags & STRESS_ONE)
                                fn = stress_one_work;
                        break;
                }

                if (!fn)
                        continue;

                stress = &stress_array[count++];

                INIT_WORK(&stress->work, fn);
                stress->class = class;
                stress->locks = locks;
                stress->nlocks = nlocks;
                stress->timeout = jiffies + 2*HZ;

                queue_work(wq, &stress->work);
                nthreads--;
        }

        flush_workqueue(wq);

        for (n = 0; n < nlocks; n++)
                ww_mutex_destroy(&locks[n]);
        kfree(stress_array);
        kfree(locks);

        return 0;
}

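/*
 * Full suite for one ww_class: basic exclusion, AA, ABBA (lock and
 * trylock, with and without resolution), cyclic deadlocks, and finally
 * the stress variants.
 */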
static int run_tests(struct ww_class *class)
{
        int ncpus = num_online_cpus();
        int ret, i;

        ret = test_mutex(class);
        if (ret)
                return ret;

        ret = test_aa(class, false);
        if (ret)
                return ret;

        ret = test_aa(class, true);
        if (ret)
                return ret;

        for (i = 0; i < 4; i++) {
                ret = test_abba(class, i & 1, i & 2);
                if (ret)
                        return ret;
        }

        ret = test_cycle(class, ncpus);
        if (ret)
                return ret;

        ret = stress(class, 16, 2 * ncpus, STRESS_INORDER);
        if (ret)
                return ret;

        ret = stress(class, 16, 2 * ncpus, STRESS_REORDER);
        if (ret)
                return ret;

        ret = stress(class, 2046, hweight32(STRESS_ALL) * ncpus, STRESS_ALL);
        if (ret)
                return ret;

        return 0;
}

static int run_test_classes(void)
{
        int ret;

        pr_info("Beginning ww (wound) mutex selftests\n");

        ret = run_tests(&ww_class);
        if (ret)
                return ret;

        pr_info("Beginning ww (die) mutex selftests\n");
        ret = run_tests(&wd_class);
        if (ret)
                return ret;

        pr_info("All ww mutex selftests passed\n");
        return 0;
}

static DEFINE_MUTEX(run_lock);

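/*
 * Writing to /sys/kernel/test_ww_mutex/run_tests reruns the suite;
 * run_lock serializes runs triggered from sysfs with the run at module
 * init.
 */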
static ssize_t run_tests_store(struct kobject *kobj, struct kobj_attribute *attr,
                               const char *buf, size_t count)
{
        if (!mutex_trylock(&run_lock)) {
                pr_err("Test already running\n");
                return count;
        }

        run_test_classes();
        mutex_unlock(&run_lock);

        return count;
}

static struct kobj_attribute run_tests_attribute =
        __ATTR(run_tests, 0200, NULL, run_tests_store);

static struct attribute *attrs[] = {
        &run_tests_attribute.attr,
        NULL,   /* need to NULL terminate the list of attributes */
};

static struct attribute_group attr_group = {
        .attrs = attrs,
};

static struct kobject *test_ww_mutex_kobj;

static int __init test_ww_mutex_init(void)
{
        int ret;

        prandom_seed_state(&rng, get_random_u64());

        wq = alloc_workqueue("test-ww_mutex", WQ_UNBOUND, 0);
        if (!wq)
                return -ENOMEM;

        test_ww_mutex_kobj = kobject_create_and_add("test_ww_mutex", kernel_kobj);
        if (!test_ww_mutex_kobj) {
                destroy_workqueue(wq);
                return -ENOMEM;
        }

        /* Create the files associated with this kobject */
        ret = sysfs_create_group(test_ww_mutex_kobj, &attr_group);
        if (ret) {
                kobject_put(test_ww_mutex_kobj);
                destroy_workqueue(wq);
                return ret;
        }

        mutex_lock(&run_lock);
        ret = run_test_classes();
        mutex_unlock(&run_lock);

        return ret;
}

static void __exit test_ww_mutex_exit(void)
{
        kobject_put(test_ww_mutex_kobj);
        destroy_workqueue(wq);
}

module_init(test_ww_mutex_init);
module_exit(test_ww_mutex_exit);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Intel Corporation");
MODULE_DESCRIPTION("API test facility for ww_mutexes");