root/drivers/md/dm-bio-prison-v2.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2012-2017 Red Hat, Inc.
 *
 * This file is released under the GPL.
 */

#include "dm.h"
#include "dm-bio-prison-v2.h"

#include <linux/spinlock.h>
#include <linux/mempool.h>
#include <linux/module.h>
#include <linux/slab.h>
#include <linux/rwsem.h>

/*----------------------------------------------------------------*/

/* Element count passed to mempool_init_slab_pool() for every prison's cell pool. */
#define MIN_CELLS 1024

/*
 * A bio prison: an rb-tree of cells ordered by cmp_keys(), together with
 * the mempool served by dm_bio_prison_alloc_cell_v2() and
 * dm_bio_prison_free_cell_v2().
 */
struct dm_bio_prison_v2 {
        /* Workqueue on which quiesce continuations are queued. */
        struct workqueue_struct *wq;

        /* Held around every tree lookup/insert/erase and cell state change in this file. */
        spinlock_t lock;
        /* Root of the rb-tree holding the cells currently in the prison. */
        struct rb_root cells;
        /* Cell pool, initialised on top of _cell_cache. */
        mempool_t cell_pool;
};

/*
 * Slab cache for struct dm_bio_prison_cell_v2.  Created by
 * dm_bio_prison_init_v2(), destroyed by dm_bio_prison_exit_v2(), and used
 * as the backing cache of each prison's cell mempool.
 */
static struct kmem_cache *_cell_cache;

/*----------------------------------------------------------------*/

/*
 * Allocate and initialise an empty prison.
 *
 * @wq is the workqueue that quiesce continuations will be queued on.
 * The cell mempool is initialised with MIN_CELLS elements on top of
 * _cell_cache.
 *
 * Returns the new prison, or NULL if either the prison itself or its
 * cell pool could not be set up.
 */
struct dm_bio_prison_v2 *dm_bio_prison_create_v2(struct workqueue_struct *wq)
{
        struct dm_bio_prison_v2 *prison;

        prison = kzalloc_obj(*prison);
        if (!prison)
                return NULL;

        prison->wq = wq;
        prison->cells = RB_ROOT;
        spin_lock_init(&prison->lock);

        if (mempool_init_slab_pool(&prison->cell_pool, MIN_CELLS, _cell_cache)) {
                /* Nothing else has been acquired yet, so only the prison goes back. */
                kfree(prison);
                return NULL;
        }

        return prison;
}
EXPORT_SYMBOL_GPL(dm_bio_prison_create_v2);

void dm_bio_prison_destroy_v2(struct dm_bio_prison_v2 *prison)
{
        mempool_exit(&prison->cell_pool);
        kfree(prison);
}
EXPORT_SYMBOL_GPL(dm_bio_prison_destroy_v2);

/*
 * Take a cell from @prison's mempool using the allocation flags @gfp.
 * The result of mempool_alloc() is returned unchanged.
 */
struct dm_bio_prison_cell_v2 *dm_bio_prison_alloc_cell_v2(struct dm_bio_prison_v2 *prison, gfp_t gfp)
{
        mempool_t *pool = &prison->cell_pool;

        return mempool_alloc(pool, gfp);
}
EXPORT_SYMBOL_GPL(dm_bio_prison_alloc_cell_v2);

void dm_bio_prison_free_cell_v2(struct dm_bio_prison_v2 *prison,
                                struct dm_bio_prison_cell_v2 *cell)
{
        mempool_free(cell, &prison->cell_pool);
}
EXPORT_SYMBOL_GPL(dm_bio_prison_free_cell_v2);

/*
 * Prepare @cell for insertion: wipe it, give it an empty bio list and
 * copy @key into it.
 */
static void __setup_new_cell(struct dm_cell_key_v2 *key,
                             struct dm_bio_prison_cell_v2 *cell)
{
        /* Start from an all-zero cell; everything below builds on that. */
        memset(cell, 0, sizeof(*cell));

        bio_list_init(&cell->bios);
        memcpy(&cell->key, key, sizeof(cell->key));
}

/*
 * Three-way comparison of two cell keys.
 *
 * Keys are ordered by ->virtual, then by ->dev.  When both match, the
 * block ranges decide: @lhs sorts first if it ends at or before the start
 * of @rhs, and last if it starts at or after the end of @rhs.  Any other
 * combination means the ranges overlap and the keys compare equal.
 *
 * Returns -1, 0 or 1.
 */
static int cmp_keys(struct dm_cell_key_v2 *lhs,
                    struct dm_cell_key_v2 *rhs)
{
        if (lhs->virtual != rhs->virtual)
                return lhs->virtual < rhs->virtual ? -1 : 1;

        if (lhs->dev != rhs->dev)
                return lhs->dev < rhs->dev ? -1 : 1;

        /* Same virtual flag and device: compare the block ranges. */
        if (lhs->block_end <= rhs->block_begin)
                return -1;

        if (lhs->block_begin >= rhs->block_end)
                return 1;

        return 0;
}

/*
 * Look up the cell whose key compares equal to @key (see cmp_keys()).
 *
 * If one exists, *@result is set to it and true is returned;
 * @cell_prealloc is left untouched.  Otherwise @cell_prealloc is
 * initialised with @key, linked into the tree, stored in *@result, and
 * false is returned.
 */
static bool __find_or_insert(struct dm_bio_prison_v2 *prison,
                             struct dm_cell_key_v2 *key,
                             struct dm_bio_prison_cell_v2 *cell_prealloc,
                             struct dm_bio_prison_cell_v2 **result)
{
        struct rb_node **link = &prison->cells.rb_node;
        struct rb_node *parent = NULL;

        while (*link) {
                struct dm_bio_prison_cell_v2 *candidate =
                        rb_entry(*link, struct dm_bio_prison_cell_v2, node);
                int order = cmp_keys(key, &candidate->key);

                if (!order) {
                        *result = candidate;
                        return true;
                }

                /* Descend, remembering where a new node would hang. */
                parent = *link;
                link = order < 0 ? &parent->rb_left : &parent->rb_right;
        }

        /* No match: @link now addresses the empty slot for the new cell. */
        __setup_new_cell(key, cell_prealloc);
        rb_link_node(&cell_prealloc->node, parent, link);
        rb_insert_color(&cell_prealloc->node, &prison->cells);
        *result = cell_prealloc;

        return false;
}

/*
 * Try to take a shared reference on the cell matching @key, inserting
 * @cell_prealloc if no such cell exists.  *@cell always ends up pointing
 * at the cell that was found or inserted.
 *
 * Returns false, after adding @inmate to the cell's bio list, when the
 * cell is exclusively locked at a level greater than or equal to
 * @lock_level.  Otherwise the shared count is bumped (or set to 1 for a
 * new cell) and true is returned.
 */
static bool __get(struct dm_bio_prison_v2 *prison,
                  struct dm_cell_key_v2 *key,
                  unsigned int lock_level,
                  struct bio *inmate,
                  struct dm_bio_prison_cell_v2 *cell_prealloc,
                  struct dm_bio_prison_cell_v2 **cell)
{
        struct dm_bio_prison_cell_v2 *found;

        if (!__find_or_insert(prison, key, cell_prealloc, cell)) {
                /* Brand new cell: this caller is its first shared holder. */
                (*cell)->shared_count = 1;
                return true;
        }

        found = *cell;
        if (found->exclusive_lock && lock_level <= found->exclusive_level) {
                bio_list_add(&found->bios, inmate);
                return false;
        }

        found->shared_count++;
        return true;
}

/*
 * Locked wrapper around __get(): try to take a shared reference on the
 * cell matching @key while holding prison->lock.
 *
 * *@cell_result is set to the cell that was found, or to @cell_prealloc
 * if it had to be inserted.
 *
 * Returns true if the shared reference was taken.  Returns false if the
 * cell is exclusively locked at a level >= @lock_level, in which case
 * @inmate has been added to the cell's bio list.
 */
bool dm_cell_get_v2(struct dm_bio_prison_v2 *prison,
                    struct dm_cell_key_v2 *key,
                    unsigned int lock_level,
                    struct bio *inmate,
                    struct dm_bio_prison_cell_v2 *cell_prealloc,
                    struct dm_bio_prison_cell_v2 **cell_result)
{
        /* __get() returns bool; keep the local the same type as the result. */
        bool r;

        spin_lock_irq(&prison->lock);
        r = __get(prison, key, lock_level, inmate, cell_prealloc, cell_result);
        spin_unlock_irq(&prison->lock);

        return r;
}
EXPORT_SYMBOL_GPL(dm_cell_get_v2);

/*
 * Drop one shared reference on @cell.  It is a BUG to call this with no
 * shared references outstanding.
 *
 * When the last shared reference goes away:
 *  - if the cell is not exclusively locked it is erased from the tree and
 *    true is returned;
 *  - if it is exclusively locked and a quiesce continuation is pending,
 *    that continuation is queued on prison->wq and cleared.
 *
 * Returns false in every case where the cell stays in the tree.
 */
static bool __put(struct dm_bio_prison_v2 *prison,
                  struct dm_bio_prison_cell_v2 *cell)
{
        BUG_ON(!cell->shared_count);

        // FIXME: shared locks granted above the lock level could starve this
        if (--cell->shared_count)
                return false;

        if (!cell->exclusive_lock) {
                rb_erase(&cell->node, &prison->cells);
                return true;
        }

        if (cell->quiesce_continuation) {
                queue_work(prison->wq, cell->quiesce_continuation);
                cell->quiesce_continuation = NULL;
        }

        return false;
}

/*
 * Locked wrapper around __put(): drop a shared reference on @cell.
 *
 * Returns true if this removed the cell from the prison, false otherwise.
 *
 * NOTE(review): this is the only entry point in the file using the
 * irqsave/irqrestore lock variants; presumably it may be called with
 * interrupts already disabled -- confirm against callers.
 */
bool dm_cell_put_v2(struct dm_bio_prison_v2 *prison,
                    struct dm_bio_prison_cell_v2 *cell)
{
        unsigned long irq_flags;
        bool removed;

        spin_lock_irqsave(&prison->lock, irq_flags);
        removed = __put(prison, cell);
        spin_unlock_irqrestore(&prison->lock, irq_flags);

        return removed;
}
EXPORT_SYMBOL_GPL(dm_cell_put_v2);

static int __lock(struct dm_bio_prison_v2 *prison,
                  struct dm_cell_key_v2 *key,
                  unsigned int lock_level,
                  struct dm_bio_prison_cell_v2 *cell_prealloc,
                  struct dm_bio_prison_cell_v2 **cell_result)
{
        struct dm_bio_prison_cell_v2 *cell;

        if (__find_or_insert(prison, key, cell_prealloc, &cell)) {
                if (cell->exclusive_lock)
                        return -EBUSY;

                cell->exclusive_lock = true;
                cell->exclusive_level = lock_level;
                *cell_result = cell;

                // FIXME: we don't yet know what level these shared locks
                // were taken at, so have to quiesce them all.
                return cell->shared_count > 0;

        } else {
                cell = cell_prealloc;
                cell->shared_count = 0;
                cell->exclusive_lock = true;
                cell->exclusive_level = lock_level;
                *cell_result = cell;
        }

        return 0;
}

/*
 * Locked wrapper around __lock(): take the exclusive lock on the cell
 * matching @key while holding prison->lock.
 *
 * Returns -EBUSY if the cell is already exclusively locked, 1 if the lock
 * was taken while shared holders are still present, and 0 if it was taken
 * with no shared holders.
 */
int dm_cell_lock_v2(struct dm_bio_prison_v2 *prison,
                    struct dm_cell_key_v2 *key,
                    unsigned int lock_level,
                    struct dm_bio_prison_cell_v2 *cell_prealloc,
                    struct dm_bio_prison_cell_v2 **cell_result)
{
        int status;

        spin_lock_irq(&prison->lock);
        status = __lock(prison, key, lock_level, cell_prealloc, cell_result);
        spin_unlock_irq(&prison->lock);

        return status;
}
EXPORT_SYMBOL_GPL(dm_cell_lock_v2);

/*
 * Arrange for @continuation to run once @cell has no shared holders.
 *
 * With shared holders present the continuation is parked on the cell,
 * where __put() picks it up when the count reaches zero.  With none, it
 * is queued on prison->wq straight away.
 */
static void __quiesce(struct dm_bio_prison_v2 *prison,
                      struct dm_bio_prison_cell_v2 *cell,
                      struct work_struct *continuation)
{
        if (cell->shared_count) {
                cell->quiesce_continuation = continuation;
                return;
        }

        queue_work(prison->wq, continuation);
}

void dm_cell_quiesce_v2(struct dm_bio_prison_v2 *prison,
                        struct dm_bio_prison_cell_v2 *cell,
                        struct work_struct *continuation)
{
        spin_lock_irq(&prison->lock);
        __quiesce(prison, cell, continuation);
        spin_unlock_irq(&prison->lock);
}
EXPORT_SYMBOL_GPL(dm_cell_quiesce_v2);

/*
 * Change the level of an exclusive lock that is already held on @cell.
 *
 * Returns -EINVAL if @cell is not exclusively locked.  Otherwise the level
 * is set to @new_lock_level and the result is 1 if shared holders are
 * present, 0 if not.
 */
static int __promote(struct dm_bio_prison_v2 *prison,
                     struct dm_bio_prison_cell_v2 *cell,
                     unsigned int new_lock_level)
{
        if (cell->exclusive_lock) {
                cell->exclusive_level = new_lock_level;
                return cell->shared_count > 0;
        }

        return -EINVAL;
}

/*
 * Locked wrapper around __promote(): set the exclusive lock level of
 * @cell to @new_lock_level.
 *
 * Returns -EINVAL if @cell holds no exclusive lock, 1 if shared holders
 * are present after the change, 0 otherwise.
 */
int dm_cell_lock_promote_v2(struct dm_bio_prison_v2 *prison,
                            struct dm_bio_prison_cell_v2 *cell,
                            unsigned int new_lock_level)
{
        int status;

        spin_lock_irq(&prison->lock);
        status = __promote(prison, cell, new_lock_level);
        spin_unlock_irq(&prison->lock);

        return status;
}
EXPORT_SYMBOL_GPL(dm_cell_lock_promote_v2);

/*
 * Release the exclusive lock on @cell.  It is a BUG to call this on a
 * cell that is not exclusively locked.
 *
 * The cell's bio list is merged into @bios (via bio_list_merge_init())
 * in every case.  If no shared holders remain the cell is erased from the
 * tree and true is returned; otherwise only the exclusive flag is cleared
 * and false is returned.
 */
static bool __unlock(struct dm_bio_prison_v2 *prison,
                     struct dm_bio_prison_cell_v2 *cell,
                     struct bio_list *bios)
{
        BUG_ON(!cell->exclusive_lock);

        bio_list_merge_init(bios, &cell->bios);

        if (!cell->shared_count) {
                rb_erase(&cell->node, &prison->cells);
                return true;
        }

        /* Shared holders still reference the cell, so it stays in the tree. */
        cell->exclusive_lock = false;
        return false;
}

/*
 * Locked wrapper around __unlock(): drop the exclusive lock on @cell and
 * merge the cell's bio list into @bios.
 *
 * Returns true if the cell was removed from the prison, false if shared
 * holders keep it there.
 */
bool dm_cell_unlock_v2(struct dm_bio_prison_v2 *prison,
                       struct dm_bio_prison_cell_v2 *cell,
                       struct bio_list *bios)
{
        bool removed;

        spin_lock_irq(&prison->lock);
        removed = __unlock(prison, cell, bios);
        spin_unlock_irq(&prison->lock);

        return removed;
}
EXPORT_SYMBOL_GPL(dm_cell_unlock_v2);

/*----------------------------------------------------------------*/

/*
 * Create the slab cache that backs every prison's cell pool.
 *
 * Returns 0 on success or -ENOMEM if the cache could not be created.
 */
int __init dm_bio_prison_init_v2(void)
{
        _cell_cache = KMEM_CACHE(dm_bio_prison_cell_v2, 0);

        return _cell_cache ? 0 : -ENOMEM;
}

/*
 * Destroy the cell slab cache created by dm_bio_prison_init_v2() and
 * clear the global pointer to it.
 */
void dm_bio_prison_exit_v2(void)
{
        kmem_cache_destroy(_cell_cache);
        /* Clear the pointer so the destroyed cache is not left dangling. */
        _cell_cache = NULL;
}