root/src/system/kernel/locks/user_mutex.cpp
/*
 * Copyright 2023, Haiku, Inc. All rights reserved.
 * Copyright 2018, Jérôme Duval, jerome.duval@gmail.com.
 * Copyright 2015, Hamish Morrison, hamishm53@gmail.com.
 * Copyright 2010, Ingo Weinhold, ingo_weinhold@gmx.de.
 * Distributed under the terms of the MIT License.
 */


#include <user_mutex.h>
#include <user_mutex_defs.h>

#include <condition_variable.h>
#include <kernel.h>
#include <lock.h>
#include <smp.h>
#include <syscall_restart.h>
#include <util/AutoLock.h>
#include <util/ThreadAutoLock.h>
#include <util/OpenHashTable.h>
#include <vm/vm.h>
#include <vm/VMArea.h>
#include <arch/generic/user_memory.h>


/*! One UserMutexEntry corresponds to one mutex address.
 *
 * The mutex's "waiting" state is controlled by the rw_lock: a waiter acquires
 * a "read" lock before initiating a wait, and an unblocker acquires a "write"
 * lock. That way, unblockers can be sure that no waiters will start waiting
 * during unblock, and they can thus safely (without races) unset WAITING.
 */
struct UserMutexEntry {
        // Hash key identifying the mutex. UserMutexContextFetcher supplies the
        // physical address for shared mutexes and the virtual address otherwise.
        generic_addr_t          address;
        // Chain link used by the hash table (see
        // UserMutexHashDefinition::GetLink()).
        UserMutexEntry*         hash_next;
        // Reference count, changed with atomic_add() in get_user_mutex_entry()
        // and put_user_mutex_entry(); the latter removes and deletes the entry
        // once the count has dropped to zero.
        int32                           ref_count;

        // Read-locked by threads about to wait, write-locked by unblockers.
        rw_lock                         lock;
        // Waiting threads are queued on this condition variable.
        ConditionVariable       condition;
};

/*! Hash table policy for UserMutexEntry objects: entries are keyed by their
 * address field and chained through their hash_next field.
 */
struct UserMutexHashDefinition {
        typedef generic_addr_t  KeyType;
        typedef UserMutexEntry  ValueType;

        // Hash of a bare key: the address with its two lowest bits shifted out.
        size_t HashKey(generic_addr_t key) const
        {
                return key >> 2;
        }

        // Hash of a stored entry; yields the same value as HashKey() of its
        // address.
        size_t Hash(const UserMutexEntry* value) const
        {
                return value->address >> 2;
        }

        // An entry matches a key when the addresses are equal.
        bool Compare(generic_addr_t key, const UserMutexEntry* value) const
        {
                return key == value->address;
        }

        // Gives the table access to the entry's chain link.
        UserMutexEntry*& GetLink(UserMutexEntry* value) const
        {
                return value->hash_next;
        }
};

// Hash table of UserMutexEntry objects, keyed by mutex address.
typedef BOpenHashTable<UserMutexHashDefinition> UserMutexTable;


// A table of mutex entries together with the rw_lock protecting it: lookups
// are done under the read lock, insertions and removals under the write lock
// (see get_user_mutex_entry() and put_user_mutex_entry()).
struct user_mutex_context {
        UserMutexTable table;
        rw_lock lock;
};
// Context used for mutexes passed with B_USER_MUTEX_SHARED; all other mutexes
// use the context stored in the calling thread's team.
static user_mutex_context sSharedUserMutexContext;
// Object type string handed to every entry's condition variable;
// dump_user_mutex() compares against it to recognize user mutex waits.
static const char* kUserMutexEntryType = "umtx entry";


// #pragma mark - user atomics


/*! Performs atomic_or() on the userland variable \a value.
 *
 * If \a isWired is true the access is done directly, bracketed by
 * arch_cpu_enable_user_access()/arch_cpu_disable_user_access(). Otherwise it
 * goes through user_access().
 *
 * Returns the result of atomic_or(), or INT32_MIN if user_access() reported
 * failure.
 */
static int32
user_atomic_or(int32* value, int32 orValue, bool isWired)
{
        if (!isWired) {
                int32 previous;
                const bool accessed = user_access([=, &previous] {
                        previous = atomic_or(value, orValue);
                });
                return accessed ? previous : INT32_MIN;
        }

        arch_cpu_enable_user_access();
        const int32 previous = atomic_or(value, orValue);
        arch_cpu_disable_user_access();
        return previous;
}


/*! Performs atomic_and() on the userland variable \a value.
 *
 * If \a isWired is true the access is done directly, bracketed by
 * arch_cpu_enable_user_access()/arch_cpu_disable_user_access(). Otherwise it
 * goes through user_access().
 *
 * Returns the result of atomic_and(), or INT32_MIN if user_access() reported
 * failure.
 */
static int32
user_atomic_and(int32* value, int32 andValue, bool isWired)
{
        if (!isWired) {
                int32 previous;
                const bool accessed = user_access([=, &previous] {
                        previous = atomic_and(value, andValue);
                });
                return accessed ? previous : INT32_MIN;
        }

        arch_cpu_enable_user_access();
        const int32 previous = atomic_and(value, andValue);
        arch_cpu_disable_user_access();
        return previous;
}


/*! Reads the userland variable \a value with atomic_get().
 *
 * If \a isWired is true the access is done directly, bracketed by
 * arch_cpu_enable_user_access()/arch_cpu_disable_user_access(). Otherwise it
 * goes through user_access().
 *
 * Returns the value read, or INT32_MIN if user_access() reported failure.
 */
static int32
user_atomic_get(int32* value, bool isWired)
{
        if (!isWired) {
                int32 current;
                const bool accessed = user_access([=, &current] {
                        current = atomic_get(value);
                });
                return accessed ? current : INT32_MIN;
        }

        arch_cpu_enable_user_access();
        const int32 current = atomic_get(value);
        arch_cpu_disable_user_access();
        return current;
}


/*! Performs atomic_test_and_set() on the userland variable \a value, passing
 * \a newValue and \a testAgainst through unchanged.
 *
 * If \a isWired is true the access is done directly, bracketed by
 * arch_cpu_enable_user_access()/arch_cpu_disable_user_access(). Otherwise it
 * goes through user_access().
 *
 * Returns the result of atomic_test_and_set(), or INT32_MIN if user_access()
 * reported failure.
 */
static int32
user_atomic_test_and_set(int32* value, int32 newValue, int32 testAgainst,
        bool isWired)
{
        if (!isWired) {
                int32 previous;
                const bool accessed = user_access([=, &previous] {
                        previous = atomic_test_and_set(value, newValue, testAgainst);
                });
                return accessed ? previous : INT32_MIN;
        }

        arch_cpu_enable_user_access();
        const int32 previous = atomic_test_and_set(value, newValue, testAgainst);
        arch_cpu_disable_user_access();
        return previous;
}


// #pragma mark - user mutex context


static int
dump_user_mutex(int argc, char** argv)
{
        if (argc != 2) {
                print_debugger_command_usage(argv[0]);
                return 0;
        }

        addr_t threadID = parse_expression(argv[1]);
        if (threadID == 0)
                return 0;

        Thread* thread = Thread::GetDebug(threadID);
        if (thread == NULL) {
                kprintf("no such thread\n");
                return 0;
        }

        if (thread->wait.type != THREAD_BLOCK_TYPE_CONDITION_VARIABLE) {
                kprintf("thread is not blocked on cvar (thus not user_mutex)\n");
                return 0;
        }

        ConditionVariable* variable = (ConditionVariable*)thread->wait.object;
        if (variable->ObjectType() != kUserMutexEntryType) {
                kprintf("thread is not blocked on user_mutex\n");
                return 0;
        }

        UserMutexEntry* entry = (UserMutexEntry*)variable->Object();

        const bool physical = (sSharedUserMutexContext.table.Lookup(entry->address) == entry);
        kprintf("user mutex entry %p\n", entry);
        kprintf("  address:  0x%" B_PRIxPHYSADDR " (%s)\n", entry->address,
                physical ? "physical" : "virtual");
        kprintf("  refcount: %" B_PRId32 "\n", entry->ref_count);
        kprintf("  lock:     %p\n", &entry->lock);

        int32 mutex = 0;
        status_t status = B_ERROR;
        if (!physical) {
                status = debug_memcpy(thread->team->id, &mutex,
                        (void*)entry->address, sizeof(mutex));
        }

        if (status == B_OK)
                kprintf("  mutex:    0x%" B_PRIx32 "\n", mutex);

        entry->condition.Dump();

        return 0;
}


void
user_mutex_init()
{
        sSharedUserMutexContext.lock = RW_LOCK_INITIALIZER("shared user mutex table");
        if (sSharedUserMutexContext.table.Init() != B_OK)
                panic("user_mutex_init(): Failed to init table!");

        add_debugger_command_etc("user_mutex", &dump_user_mutex,
                "Dump user-mutex info",
                "<thread>\n"
                "Prints info about the user-mutex a thread is blocked on.\n"
                "  <thread>  - Thread ID that is blocked on a user mutex\n", 0);
}


/*! Makes sure \a team has a user mutex context, allocating one if needed.
 *
 * The team must be locked by the caller (asserted).
 *
 * Returns B_OK if the team has a context afterwards (including the case that
 * it already had one), B_NO_MEMORY if the allocation failed, or the error
 * returned by the table's Init().
 */
status_t
allocate_team_user_mutex_context(Team* team)
{
        team->AssertLocked();

        // Nothing to do when the team already owns a context.
        if (team->user_mutex_context != NULL)
                return B_OK;

        user_mutex_context* newContext = new(std::nothrow) user_mutex_context;
        if (newContext == NULL)
                return B_NO_MEMORY;

        newContext->lock = RW_LOCK_INITIALIZER("user mutex table");

        const status_t initStatus = newContext->table.Init();
        if (initStatus == B_OK) {
                team->user_mutex_context = newContext;
                return B_OK;
        }

        // The table could not be set up; the team keeps having no context.
        delete newContext;
        return initStatus;
}


/*! Deletes \a context. A NULL \a context is ignored. The context's table is
 * asserted to be empty.
 */
void
delete_user_mutex_context(struct user_mutex_context* context)
{
        if (context != NULL) {
                // By the time a team is destroyed no entries should be left.
                ASSERT(context->table.IsEmpty());
                delete context;
        }
}


/*! Returns the entry for \a address in \a context with an additional
 * reference, which the caller must drop again with put_user_mutex_entry().
 *
 * \param context The context whose table is searched.
 * \param address The key (mutex address) to look up.
 * \param noInsert If true, no entry is created when none exists and NULL is
 *        returned instead.
 * \param alreadyLocked If true, the caller already holds the context's lock
 *        and the initial read lock is not acquired here.
 * \return The referenced entry, or NULL if none exists and \a noInsert is
 *         true, or if allocating a new entry failed.
 */
static UserMutexEntry*
get_user_mutex_entry(struct user_mutex_context* context,
        generic_addr_t address, bool noInsert = false, bool alreadyLocked = false)
{
        // Only attach the locker to the table lock when the caller does not
        // hold that lock already.
        ReadLocker tableReadLocker;
        if (!alreadyLocked)
                tableReadLocker.SetTo(context->lock, false);

        // Fast path: the entry exists, so only a reference has to be added.
        UserMutexEntry* entry = context->table.Lookup(address);
        if (entry != NULL) {
                atomic_add(&entry->ref_count, 1);
                return entry;
        } else if (noInsert)
                return entry;

        // Inserting requires the write lock.
        // NOTE(review): with alreadyLocked == true and noInsert == false the
        // caller would still hold its read lock at this point; the only caller
        // in this file passing alreadyLocked also passes noInsert.
        tableReadLocker.Unlock();
        WriteLocker tableWriteLocker(context->lock);

        // The table was unlocked in between, so another thread may have
        // inserted the entry meanwhile; look it up again.
        entry = context->table.Lookup(address);
        if (entry != NULL) {
                atomic_add(&entry->ref_count, 1);
                return entry;
        }

        entry = new(std::nothrow) UserMutexEntry;
        if (entry == NULL) {
                // If execution continues after the panic, NULL is returned and
                // has to be handled by the caller.
                panic("UserMutexEntry allocation failed!");
                return entry;
        }

        // The new entry starts out with the caller's reference.
        entry->address = address;
        entry->ref_count = 1;
        rw_lock_init(&entry->lock, "UserMutexEntry lock");
        entry->condition.Init(entry, kUserMutexEntryType);

        context->table.Insert(entry);
        return entry;
}


/*! Drops a reference acquired with get_user_mutex_entry(). When the reference
 * count has reached zero and nobody re-acquired the entry in the meantime, the
 * entry is removed from \a context's table and deleted.
 *
 * A NULL \a entry is ignored.
 */
static void
put_user_mutex_entry(struct user_mutex_context* context, UserMutexEntry* entry)
{
        if (entry == NULL)
                return;

        // Remember the key before dropping the reference: afterwards the entry
        // may be deleted by another thread, so it must not be dereferenced
        // until it has been found in the table again.
        const generic_addr_t address = entry->address;
        if (atomic_add(&entry->ref_count, -1) != 1)
                return;

        WriteLocker tableWriteLocker(context->lock);

        // Was it removed & deleted while we were waiting for the lock?
        if (context->table.Lookup(address) != entry)
                return;

        // Or did someone else acquire a reference to it?
        if (atomic_get(&entry->ref_count) > 0)
                return;

        // Once removed, the entry cannot be looked up anymore, so the table
        // lock is released before the entry is torn down.
        context->table.Remove(entry);
        tableWriteLocker.Unlock();

        rw_lock_destroy(&entry->lock);
        delete entry;
}


/*! Queues the calling thread on \a entry's condition variable, releases
 * \a locker and then waits.
 *
 * The wait entry is added before \a locker is unlocked. \a flags and
 * \a timeout are passed on to the wait. Returns the result of the wait.
 */
static status_t
user_mutex_wait_locked(UserMutexEntry* entry,
        uint32 flags, bigtime_t timeout, ReadLocker& locker)
{
        ConditionVariableEntry waitEntry;

        // Enqueue first, unlock second.
        entry->condition.Add(&waitEntry);
        locker.Unlock();

        const status_t waitResult = waitEntry.Wait(flags, timeout);
        return waitResult;
}


/*! Sets B_USER_MUTEX_LOCKED and B_USER_MUTEX_WAITING on \a mutex and reports
 * whether the caller has to wait.
 *
 * The caller must hold \a entry's lock for reading (asserted); it is held for
 * reading again on return, although it may have been dropped in between.
 *
 * \return true if the mutex was not locked before or is disabled, i.e. no
 *         waiting is needed; false if the caller has to wait.
 */
static bool
user_mutex_prepare_to_lock(UserMutexEntry* entry, int32* mutex, bool isWired)
{
        ASSERT_READ_LOCKED_RW_LOCK(&entry->lock);

        int32 oldValue = user_atomic_or(mutex,
                B_USER_MUTEX_LOCKED | B_USER_MUTEX_WAITING, isWired);
        if ((oldValue & B_USER_MUTEX_LOCKED) == 0
                        || (oldValue & B_USER_MUTEX_DISABLED) != 0) {
                // The WAITING flag was set above although no wait will happen. If it
                // was not set before, clear it again, provided nobody is queued.
                if ((oldValue & B_USER_MUTEX_WAITING) == 0) {
                        // Swap the read lock for the write lock, so that no other
                        // thread can start waiting while the flag is being cleared.
                        rw_lock_read_unlock(&entry->lock);
                        rw_lock_write_lock(&entry->lock);
                        if (entry->condition.EntriesCount() == 0)
                                user_atomic_and(mutex, ~(int32)B_USER_MUTEX_WAITING, isWired);
                        rw_lock_write_unlock(&entry->lock);
                        // Restore the lock state the caller expects.
                        rw_lock_read_lock(&entry->lock);
                }
                return true;
        }

        return false;
}


/*! Locks \a mutex, waiting on \a entry's condition variable if necessary.
 *
 * The caller holds \a entry's lock for reading via \a locker; when a wait is
 * needed, \a locker is unlocked as part of it.
 *
 * Returns B_OK if no wait was needed, otherwise the result of the wait. After
 * a failed wait the B_USER_MUTEX_WAITING flag is cleared if no waiters remain.
 */
static status_t
user_mutex_lock_locked(UserMutexEntry* entry, int32* mutex,
        uint32 flags, bigtime_t timeout, ReadLocker& locker, bool isWired)
{
        const bool needNoWait = user_mutex_prepare_to_lock(entry, mutex, isWired);
        if (needNoWait)
                return B_OK;

        const status_t waitStatus
                = user_mutex_wait_locked(entry, flags, timeout, locker);
        if (waitStatus == B_OK)
                return B_OK;

        // The wait failed. Check without the lock first, and only if there seem
        // to be no waiters repeat the check under the write lock before
        // clearing the flag.
        if (entry->condition.EntriesCount() == 0) {
                WriteLocker entryWriteLocker(entry->lock);
                if (entry->condition.EntriesCount() == 0)
                        user_atomic_and(mutex, ~(int32)B_USER_MUTEX_WAITING, isWired);
        }

        return waitStatus;
}


/*! Wakes up threads waiting on \a entry, holding its lock for writing.
 *
 * Without B_USER_MUTEX_UNBLOCK_ALL in \a flags, B_USER_MUTEX_LOCKED is set on
 * \a mutex before a single waiter is notified; all waiters are notified
 * instead if the mutex is disabled. With B_USER_MUTEX_UNBLOCK_ALL all waiters
 * are notified and the LOCKED flag is left alone.
 *
 * B_USER_MUTEX_WAITING is cleared whenever no waiters remain.
 */
static void
user_mutex_unblock(UserMutexEntry* entry, int32* mutex, uint32 flags, bool isWired)
{
        // With the write lock held no thread can start waiting concurrently.
        WriteLocker entryLocker(entry->lock);
        if (entry->condition.EntriesCount() == 0) {
                // Nobody is actually waiting at present.
                user_atomic_and(mutex, ~(int32)B_USER_MUTEX_WAITING, isWired);
                return;
        }

        int32 oldValue = 0;
        if ((flags & B_USER_MUTEX_UNBLOCK_ALL) == 0) {
                // This is not merely an unblock, but a hand-off.
                oldValue = user_atomic_or(mutex, B_USER_MUTEX_LOCKED, isWired);
                // If the mutex was already locked again, there is nothing to hand
                // off and nobody is woken up.
                if ((oldValue & B_USER_MUTEX_LOCKED) != 0)
                        return;
        }

        if ((flags & B_USER_MUTEX_UNBLOCK_ALL) != 0
                        || (oldValue & B_USER_MUTEX_DISABLED) != 0) {
                // unblock all waiting threads
                entry->condition.NotifyAll(B_OK);
        } else {
                // Nobody could be notified, so take back the LOCKED flag set for
                // the hand-off above.
                if (!entry->condition.NotifyOne(B_OK))
                        user_atomic_and(mutex, ~(int32)B_USER_MUTEX_LOCKED, isWired);
        }

        // No waiters are left after the notification: unset WAITING.
        if (entry->condition.EntriesCount() == 0)
                user_atomic_and(mutex, ~(int32)B_USER_MUTEX_WAITING, isWired);
}


/*! Acquires the userland semaphore \a sem, waiting on \a entry's condition
 * variable if necessary.
 *
 * While the value is non-negative, an attempt is made to decrement it
 * atomically. Decrementing a positive value acquires the semaphore and
 * returns B_OK. Otherwise the loop continues until the value is negative, and
 * the function then waits (releasing \a locker) and returns the wait's result.
 */
static status_t
user_mutex_sem_acquire_locked(UserMutexEntry* entry, int32* sem,
        uint32 flags, bigtime_t timeout, ReadLocker& locker, bool isWired)
{
        // The semaphore may have been released in the meantime, and it has to
        // be marked contended (negative) before waiting.
        int32 expected = user_atomic_get(sem, isWired);
        while (expected >= 0) {
                const int32 observed = user_atomic_test_and_set(sem, expected - 1,
                        expected, isWired);
                if (observed == expected && observed > 0)
                        return B_OK;
                expected = observed;
        }

        return user_mutex_wait_locked(entry, flags, timeout, locker);
}


/*! Releases the userland semaphore \a sem, holding \a entry's lock for
 * writing.
 *
 * If a waiter could be notified, the semaphore value is only touched to turn
 * -1 into 0 once no waiters remain. If there was no waiter, the value is
 * incremented with a compare-and-swap loop instead.
 */
static void
user_mutex_sem_release(UserMutexEntry* entry, int32* sem, bool isWired)
{
        WriteLocker entryLocker(entry->lock);
        if (entry->condition.NotifyOne(B_OK) == 0) {
                // no waiters - mark as uncontended and release
                int32 oldValue = user_atomic_get(sem, isWired);
                while (true) {
                        // A negative value is raised by 2, any other value by 1.
                        // NOTE(review): presumably -1 is the "contended" marker, so
                        // that -1 becomes 1 here - confirm against the userland side.
                        int32 inc = oldValue < 0 ? 2 : 1;
                        int32 value = user_atomic_test_and_set(sem, oldValue + inc, oldValue, isWired);
                        // Done once the swap was applied to the expected value;
                        // otherwise retry with the value actually found.
                        if (value == oldValue)
                                return;
                        oldValue = value;
                }
        }

        // A waiter was notified; check whether it was the last one.
        if (entry->condition.EntriesCount() == 0) {
                // mark the semaphore uncontended
                user_atomic_test_and_set(sem, 0, -1, isWired);
        }
}


// #pragma mark - syscalls


/*! Resolves the context and hash key to use for a user mutex.
 *
 * For mutexes flagged B_USER_MUTEX_SHARED the shared context is used, the page
 * containing the mutex is wired for the lifetime of this object, and the key
 * is the physical address. For all other mutexes the current team's context is
 * used (allocated on demand) and the key is the virtual address.
 *
 * InitCheck() must be checked before using the other accessors.
 */
struct UserMutexContextFetcher {
        UserMutexContextFetcher(int32* mutex, uint32 flags)
                :
                fInitStatus(B_OK),
                fShared((flags & B_USER_MUTEX_SHARED) != 0),
                fAddress(0)
        {
                if (fShared) {
                        fContext = &sSharedUserMutexContext;

                        // wire the page and get the physical address
                        fInitStatus = vm_wire_page(B_CURRENT_TEAM, (addr_t)mutex, true,
                                &fWiringInfo);
                        if (fInitStatus == B_OK)
                                fAddress = fWiringInfo.physicalAddress;
                        return;
                }

                Team* team = thread_get_current_thread()->team;
                fContext = team->user_mutex_context;
                if (fContext == NULL) {
                        // This should only happen for single-threaded applications (some
                        // appear to use mutexes as a sleep mechanism), as the context gets
                        // allocated on the creation of a second thread.
                        team->Lock();
                        fInitStatus = allocate_team_user_mutex_context(team);
                        fContext = team->user_mutex_context;
                        team->Unlock();
                        if (fInitStatus != B_OK)
                                return;
                }

                fAddress = (addr_t)mutex;
        }

        ~UserMutexContextFetcher()
        {
                // Only a successfully initialized, shared fetcher has wired a page.
                if (fInitStatus == B_OK && fShared)
                        vm_unwire_page(&fWiringInfo);
        }

        // B_OK, or the error that occurred while resolving the context.
        status_t InitCheck() const
        {
                return fInitStatus;
        }

        // The context the mutex belongs to.
        struct user_mutex_context* Context() const
        {
                return fContext;
        }

        // The key under which the mutex is stored in the context's table.
        generic_addr_t Address() const
        {
                return fAddress;
        }

        // Whether the mutex's page is wired, which is the case for shared
        // mutexes.
        bool IsWired() const
        {
                return fShared;
        }

private:
        status_t fInitStatus;
        bool fShared;
        struct user_mutex_context* fContext;
        VMPageWiringInfo fWiringInfo;
        generic_addr_t fAddress;
};


/*! Locks the user mutex at \a mutex, waiting if necessary.
 *
 * Returns the context fetcher's error if the context could not be resolved,
 * B_NO_MEMORY if no entry could be obtained, and otherwise the result of
 * user_mutex_lock_locked(). \a name is not used.
 */
static status_t
user_mutex_lock(int32* mutex, const char* name, uint32 flags, bigtime_t timeout)
{
        UserMutexContextFetcher fetcher(mutex, flags);
        const status_t initStatus = fetcher.InitCheck();
        if (initStatus != B_OK)
                return initStatus;

        struct user_mutex_context* context = fetcher.Context();

        // get the lock
        UserMutexEntry* entry = get_user_mutex_entry(context, fetcher.Address());
        if (entry == NULL)
                return B_NO_MEMORY;

        status_t result = B_OK;
        {
                // The locker has to go out of scope before the reference to the
                // entry is dropped.
                ReadLocker entryLocker(entry->lock);
                result = user_mutex_lock_locked(entry, mutex, flags, timeout,
                        entryLocker, fetcher.IsWired());
        }
        put_user_mutex_entry(context, entry);

        return result;
}


/*! Unlocks \a fromMutex and locks \a toMutex.
 *
 * The calling thread is queued as a waiter on \a toMutex (if it could not be
 * locked right away) before \a fromMutex is unlocked; the actual wait happens
 * afterwards.
 *
 * Returns a context fetcher's error if a context could not be resolved,
 * B_NO_MEMORY if no entry for \a toMutex could be obtained, B_OK if
 * \a toMutex could be locked without waiting, and otherwise the result of the
 * wait. \a name is not used.
 */
static status_t
user_mutex_switch_lock(int32* fromMutex, uint32 fromFlags,
        int32* toMutex, const char* name, uint32 toFlags, bigtime_t timeout)
{
        UserMutexContextFetcher fromFetcher(fromMutex, fromFlags);
        if (fromFetcher.InitCheck() != B_OK)
                return fromFetcher.InitCheck();

        UserMutexContextFetcher toFetcher(toMutex, toFlags);
        if (toFetcher.InitCheck() != B_OK)
                return toFetcher.InitCheck();

        // unlock the first mutex and lock the second one
        UserMutexEntry* fromEntry = NULL,
                *toEntry = get_user_mutex_entry(toFetcher.Context(), toFetcher.Address());
        if (toEntry == NULL)
                return B_NO_MEMORY;
        status_t error = B_OK;
        {
                ConditionVariableEntry waiter;

                // Try to lock the target mutex; if that is not possible, enqueue as
                // waiter while still holding the entry's read lock.
                bool alreadyLocked = false;
                {
                        ReadLocker entryLocker(toEntry->lock);
                        alreadyLocked = user_mutex_prepare_to_lock(toEntry, toMutex,
                                toFetcher.IsWired());
                        if (!alreadyLocked)
                                toEntry->condition.Add(&waiter);
                }

                // Unlock the source mutex. Waiters only need to be woken up if its
                // WAITING flag was set.
                const int32 oldValue = user_atomic_and(fromMutex, ~(int32)B_USER_MUTEX_LOCKED,
                        fromFetcher.IsWired());
                if ((oldValue & B_USER_MUTEX_WAITING) != 0) {
                        // Only an existing entry is of interest (noInsert): without one
                        // nothing is done here.
                        fromEntry = get_user_mutex_entry(fromFetcher.Context(),
                                fromFetcher.Address(), true);
                         if (fromEntry != NULL) {
                                 user_mutex_unblock(fromEntry, fromMutex, fromFlags,
                                         fromFetcher.IsWired());
                         }
                }

                // Now actually wait for the target mutex, if necessary.
                if (!alreadyLocked)
                        error = waiter.Wait(toFlags, timeout);
        }
        // put_user_mutex_entry() ignores a NULL fromEntry.
        put_user_mutex_entry(fromFetcher.Context(), fromEntry);
        put_user_mutex_entry(toFetcher.Context(), toEntry);

        return error;
}


/*! Syscall: locks the user mutex at \a mutex.
 *
 * Returns B_BAD_ADDRESS if \a mutex is NULL, not a userland address or not
 * 4-byte aligned. The wait is made interruptible (B_CAN_INTERRUPT), and the
 * timeout is run through the syscall restart helpers before and after the
 * lock attempt.
 */
status_t
_user_mutex_lock(int32* mutex, const char* name, uint32 flags,
        bigtime_t timeout)
{
        const bool validAddress = mutex != NULL && IS_USER_ADDRESS(mutex)
                && (addr_t)mutex % 4 == 0;
        if (!validAddress)
                return B_BAD_ADDRESS;

        syscall_restart_handle_timeout_pre(flags, timeout);

        const status_t lockStatus = user_mutex_lock(mutex, name,
                flags | B_CAN_INTERRUPT, timeout);

        return syscall_restart_handle_timeout_post(lockStatus, timeout);
}


/*! Syscall: wakes up waiters of the user mutex at \a mutex.
 *
 * Returns B_BAD_ADDRESS if \a mutex is NULL, not a userland address or not
 * 4-byte aligned, or the context fetcher's error if the context could not be
 * resolved. Otherwise B_OK is returned, whether or not an entry for the mutex
 * exists.
 */
status_t
_user_mutex_unblock(int32* mutex, uint32 flags)
{
        if (mutex == NULL || !IS_USER_ADDRESS(mutex) || (addr_t)mutex % 4 != 0)
                return B_BAD_ADDRESS;

        UserMutexContextFetcher contextFetcher(mutex, flags);
        if (contextFetcher.InitCheck() != B_OK)
                return contextFetcher.InitCheck();
        struct user_mutex_context* context = contextFetcher.Context();

        // In the case where there is no entry, we must hold the read lock until we
        // unset WAITING, because otherwise some other thread could initiate a wait.
        ReadLocker tableReadLocker(context->lock);
        // Look up only (noInsert), with the table lock already held by us.
        UserMutexEntry* entry = get_user_mutex_entry(context,
                contextFetcher.Address(), true, true);
        if (entry == NULL) {
                // No entry means no waiters: just clear the flag.
                user_atomic_and(mutex, ~(int32)B_USER_MUTEX_WAITING, contextFetcher.IsWired());
                tableReadLocker.Unlock();
        } else {
                // We hold a reference to the entry, so the table lock is no longer
                // needed.
                tableReadLocker.Unlock();
                user_mutex_unblock(entry, mutex, flags, contextFetcher.IsWired());
        }
        // put_user_mutex_entry() ignores a NULL entry.
        put_user_mutex_entry(context, entry);

        return B_OK;
}


/*! Syscall: unlocks \a fromMutex and locks \a toMutex.
 *
 * Returns B_BAD_ADDRESS if either pointer is NULL, not a userland address or
 * not 4-byte aligned. Otherwise the result of user_mutex_switch_lock() is
 * returned, with the wait made interruptible (B_CAN_INTERRUPT).
 */
status_t
_user_mutex_switch_lock(int32* fromMutex, uint32 fromFlags,
        int32* toMutex, const char* name, uint32 toFlags, bigtime_t timeout)
{
        const bool fromValid = fromMutex != NULL && IS_USER_ADDRESS(fromMutex)
                && (addr_t)fromMutex % 4 == 0;
        if (!fromValid)
                return B_BAD_ADDRESS;

        const bool toValid = toMutex != NULL && IS_USER_ADDRESS(toMutex)
                && (addr_t)toMutex % 4 == 0;
        if (!toValid)
                return B_BAD_ADDRESS;

        return user_mutex_switch_lock(fromMutex, fromFlags, toMutex, name,
                toFlags | B_CAN_INTERRUPT, timeout);
}


/*! Syscall: acquires the userland semaphore at \a sem, waiting if necessary.
 *
 * Returns B_BAD_ADDRESS if \a sem is NULL, not a userland address or not
 * 4-byte aligned, the context fetcher's error if the context could not be
 * resolved, and B_NO_MEMORY if no entry could be obtained. Otherwise the
 * result of the acquisition is returned after being passed through
 * syscall_restart_handle_timeout_post(). The wait is made interruptible
 * (B_CAN_INTERRUPT).
 */
status_t
_user_mutex_sem_acquire(int32* sem, const char* name, uint32 flags,
        bigtime_t timeout)
{
        const bool validAddress = sem != NULL && IS_USER_ADDRESS(sem)
                && (addr_t)sem % 4 == 0;
        if (!validAddress)
                return B_BAD_ADDRESS;

        syscall_restart_handle_timeout_pre(flags, timeout);

        UserMutexContextFetcher fetcher(sem, flags);
        const status_t initStatus = fetcher.InitCheck();
        if (initStatus != B_OK)
                return initStatus;

        struct user_mutex_context* context = fetcher.Context();

        UserMutexEntry* entry = get_user_mutex_entry(context, fetcher.Address());
        if (entry == NULL)
                return B_NO_MEMORY;

        status_t result;
        {
                // The locker has to go out of scope before the reference to the
                // entry is dropped.
                ReadLocker entryLocker(entry->lock);
                result = user_mutex_sem_acquire_locked(entry, sem,
                        flags | B_CAN_INTERRUPT, timeout, entryLocker,
                        fetcher.IsWired());
        }
        put_user_mutex_entry(context, entry);

        return syscall_restart_handle_timeout_post(result, timeout);
}


/*! Syscall: releases the userland semaphore at \a sem.
 *
 * Returns B_BAD_ADDRESS if \a sem is NULL, not a userland address or not
 * 4-byte aligned, the context fetcher's error if the context could not be
 * resolved, B_NO_MEMORY if no entry could be obtained, and B_OK otherwise.
 */
status_t
_user_mutex_sem_release(int32* sem, uint32 flags)
{
        if (sem == NULL || !IS_USER_ADDRESS(sem) || (addr_t)sem % 4 != 0)
                return B_BAD_ADDRESS;

        UserMutexContextFetcher contextFetcher(sem, flags);
        if (contextFetcher.InitCheck() != B_OK)
                return contextFetcher.InitCheck();
        struct user_mutex_context* context = contextFetcher.Context();

        // get_user_mutex_entry() returns NULL when a new entry cannot be
        // allocated. user_mutex_sem_release() dereferences the entry right
        // away, so bail out here, just like the other syscalls do.
        UserMutexEntry* entry = get_user_mutex_entry(context,
                contextFetcher.Address());
        if (entry == NULL)
                return B_NO_MEMORY;

        user_mutex_sem_release(entry, sem, contextFetcher.IsWired());
        put_user_mutex_entry(context, entry);

        return B_OK;
}