#include <err.h>
#include <stdlib.h>
#include <unistd.h>
#include <stdbool.h>
#include <sys/sysmacros.h>
#include <sys/debug.h>
#include <thread.h>
#include <synch.h>
#include <strings.h>
#include "nvme_ioctl_util.h"
/*
 * Upper bound on the number of contending lock requests in a single test.
 * It sizes the lot_locks[] array of each test, the lock_results[] log, and
 * the per-test array of thread IDs.
 */
#define MAX_LOCKS 10
/*
 * One entry in the acquisition log. A worker thread fills one of these in
 * once its lock request has returned.
 */
typedef struct {
/* ID of the worker thread (from thr_self()) that obtained the lock. */
thread_t loi_thread;
/* The lock template the worker was started with; not a private copy. */
const nvme_ioctl_lock_t *loi_lock;
} lock_order_info_t;
/*
 * State shared between the main thread and the worker threads. The code in
 * this file reads and writes lock_results, lock_nextres and lock_valid with
 * lock_mutex held.
 */
static mutex_t lock_mutex;
/* Log of lock acquisitions, appended to in the order workers record them. */
static lock_order_info_t lock_results[MAX_LOCKS];
/* Index of the next free slot in lock_results[]. */
static uint32_t lock_nextres;
/*
 * Set to false while a test is being set up and to true just before the main
 * thread closes its descriptor. A worker that gets its lock while this is
 * still false aborts the whole program.
 */
static bool lock_valid;
typedef struct lock_order_test lock_order_test_t;
/*
 * Verification callback. It is handed the test that was run and the number of
 * worker threads that were created, and returns true when the recorded
 * acquisition order is acceptable.
 */
typedef bool (*lock_order_valif_f)(const lock_order_test_t *, uint32_t);
/*
 * Description of a single lock ordering test.
 */
struct lock_order_test {
/* Human-readable description used in pass/fail messages. */
const char *lot_desc;
/* Lock the main thread takes before any worker thread is started. */
const nvme_ioctl_lock_t *lot_initlock;
/*
 * Locks requested by the worker threads, one thread per entry, started in
 * array order. The first NULL entry ends the list.
 */
const nvme_ioctl_lock_t *lot_locks[MAX_LOCKS];
/* Checks the recorded acquisition order once all workers have been joined. */
lock_order_valif_f lot_verif;
};
/*
 * Print the acquisition log, one "{ entity, level }" line per recorded lock,
 * in the order the entries were recorded. This does not take lock_mutex
 * itself; the callers in this file invoke it with the mutex already held.
 */
static void
lock_verify_dump(void)
{
	size_t idx = 0;

	while (idx < lock_nextres) {
		const nvme_ioctl_lock_t *lock = lock_results[idx].loi_lock;
		const char *targ = "namespace";
		const char *level = "write";

		/* Anything that is not a controller lock is shown as namespace. */
		if (lock->nil_ent == NVME_LOCK_E_CTRL)
			targ = "controller";
		/* Anything that is not a read lock is shown as write. */
		if (lock->nil_level == NVME_LOCK_L_READ)
			level = "read";

		(void) printf("\t[%zu] = { %s, %s }\n", idx, targ, level);
		idx++;
	}
}
/*
 * Verification callback: check that, in the recorded acquisition order, every
 * write-level lock was obtained before any read-level lock. Only the level is
 * examined; the entity (controller vs. namespace) is ignored.
 *
 * test - the test whose lot_locks[] determines how many writers and readers
 *        are expected.
 * nthr - number of worker threads that were run; must equal the number of
 *        non-NULL entries in lot_locks[].
 *
 * Returns true when the order is as expected. On a mismatch a warning is
 * issued for each offending slot, the whole log is dumped, and false is
 * returned.
 */
static bool
lock_verify_write_before_read(const lock_order_test_t *test, uint32_t nthr)
{
	size_t nwrite = 0;
	size_t nread = 0;
	bool ok = true;

	/* Tally the requested levels; a NULL slot ends a short lot_locks[]. */
	for (size_t idx = 0; idx < MAX_LOCKS && test->lot_locks[idx] != NULL;
	    idx++) {
		if (test->lot_locks[idx]->nil_level == NVME_LOCK_L_READ)
			nread++;
		else
			nwrite++;
	}
	VERIFY3U(nwrite + nread, ==, nthr);

	mutex_enter(&lock_mutex);
	for (size_t idx = 0; idx < nthr; idx++) {
		const lock_order_info_t *rec = &lock_results[idx];
		nvme_lock_level_t exp_level = NVME_LOCK_L_READ;
		const char *str = "READ";

		/* The first nwrite slots must be writers, the rest readers. */
		if (nwrite == 0) {
			nread--;
		} else {
			nwrite--;
			exp_level = NVME_LOCK_L_WRITE;
			str = "WRITE";
		}

		if (rec->loi_lock->nil_level == exp_level)
			continue;

		ok = false;
		warnx("TEST FAILED: %s: lock %zu (tid %u, ent %u, "
		    "level %u) was the wrong level, expected level %u "
		    "(%s)", test->lot_desc, idx, rec->loi_thread,
		    rec->loi_lock->nil_ent, rec->loi_lock->nil_level,
		    exp_level, str);
	}
	/* Every expected writer and reader must have been accounted for. */
	VERIFY3U(nwrite, ==, 0);
	VERIFY3U(nread, ==, 0);

	if (!ok)
		lock_verify_dump();
	mutex_exit(&lock_mutex);

	return (ok);
}
/*
 * Verification callback: check that, in the recorded acquisition order, every
 * controller lock was obtained before any non-controller (namespace) lock.
 * Only the entity is examined; the level (read vs. write) is ignored.
 *
 * test - the test whose lot_locks[] determines how many controller and
 *        namespace locks are expected.
 * nthr - number of worker threads that were run; must equal the number of
 *        non-NULL entries in lot_locks[].
 *
 * Returns true when the order is as expected. On a mismatch a warning is
 * issued for each offending slot, the whole log is dumped, and false is
 * returned.
 */
static bool
lock_verify_ctrl_before_ns(const lock_order_test_t *test, uint32_t nthr)
{
	size_t nctrl = 0;
	size_t nns = 0;
	bool ok = true;

	/* Tally the requested entities; a NULL slot ends a short lot_locks[]. */
	for (size_t idx = 0; idx < MAX_LOCKS && test->lot_locks[idx] != NULL;
	    idx++) {
		if (test->lot_locks[idx]->nil_ent == NVME_LOCK_E_CTRL)
			nctrl++;
		else
			nns++;
	}
	VERIFY3U(nctrl + nns, ==, nthr);

	mutex_enter(&lock_mutex);
	for (size_t idx = 0; idx < nthr; idx++) {
		const lock_order_info_t *rec = &lock_results[idx];
		nvme_lock_ent_t exp_ent = NVME_LOCK_E_NS;
		const char *str = "ns";

		/* The first nctrl slots must be controller locks. */
		if (nctrl == 0) {
			nns--;
		} else {
			nctrl--;
			exp_ent = NVME_LOCK_E_CTRL;
			str = "ctrl";
		}

		if (rec->loi_lock->nil_ent == exp_ent)
			continue;

		ok = false;
		warnx("TEST FAILED: %s: lock %zu (tid %u, ent %u, "
		    "level %u) was the wrong entity, expected type %u "
		    "(%s)", test->lot_desc, idx, rec->loi_thread,
		    rec->loi_lock->nil_ent, rec->loi_lock->nil_level,
		    exp_ent, str);
	}
	/* Every expected controller and namespace lock must be accounted for. */
	VERIFY3U(nctrl, ==, 0);
	VERIFY3U(nns, ==, 0);

	if (!ok)
		lock_verify_dump();
	mutex_exit(&lock_mutex);

	return (ok);
}
/*
 * Verification callback that applies both ordering checks: controller locks
 * before namespace locks, and write locks before read locks. Both checks are
 * always run, even if the first one fails, so that each can report its own
 * diagnostics. Returns true only when both pass.
 */
static bool
lock_verif_ent_level(const lock_order_test_t *test, uint32_t nthr)
{
	const bool ent_ok = lock_verify_ctrl_before_ns(test, nthr);
	const bool level_ok = lock_verify_write_before_read(test, nthr);

	return (ent_ok && level_ok);
}
/*
 * Table of lock ordering tests. For each entry the main thread first takes
 * lot_initlock, then starts one worker thread per lot_locks[] entry in array
 * order, waiting for each worker to be reported as blocked before starting
 * the next. After the main thread closes its descriptor and all workers have
 * been joined, lot_verif checks the order in which the workers recorded
 * obtaining their locks. The "ns(rd)", "ns(wr)", "ctrl(rd)" and "ctrl(wr)"
 * prefixes in the descriptions name the initial lock held by the main thread.
 */
static const lock_order_test_t lock_order_tests[] = { {
.lot_desc = "ns(rd): pending ns writer doesn't allow more ns readers",
.lot_initlock = &nvme_test_ns_rdlock,
.lot_locks = { &nvme_test_ns_wrlock, &nvme_test_ns_rdlock },
.lot_verif = lock_verify_write_before_read,
}, {
.lot_desc = "ns(wr): pending ns writer beats waiting ns reader",
.lot_initlock = &nvme_test_ns_wrlock,
.lot_locks = { &nvme_test_ns_rdlock, &nvme_test_ns_wrlock },
.lot_verif = lock_verify_write_before_read,
}, {
.lot_desc = "ns(rd): all pend ns writers beat prior pend readers",
.lot_initlock = &nvme_test_ns_rdlock,
.lot_locks = { &nvme_test_ns_wrlock, &nvme_test_ns_rdlock,
&nvme_test_ns_rdlock, &nvme_test_ns_wrlock, &nvme_test_ns_rdlock,
&nvme_test_ns_wrlock },
.lot_verif = lock_verify_write_before_read,
}, {
.lot_desc = "ns(rd): pending ctrl writer doesn't allow more ns readers",
.lot_initlock = &nvme_test_ns_rdlock,
.lot_locks = { &nvme_test_ctrl_wrlock, &nvme_test_ns_rdlock,
&nvme_test_ns_rdlock },
.lot_verif = lock_verify_write_before_read,
}, {
.lot_desc = "ns(wr): pending ctrl writer beats prior pend ns readers",
.lot_initlock = &nvme_test_ns_wrlock,
.lot_locks = { &nvme_test_ns_rdlock, &nvme_test_ns_rdlock,
&nvme_test_ctrl_wrlock, &nvme_test_ns_rdlock },
.lot_verif = lock_verify_write_before_read,
}, {
.lot_desc = "ns(rd): pending ctrl writer doesn't allow ctrl readers",
.lot_initlock = &nvme_test_ns_rdlock,
.lot_locks = { &nvme_test_ctrl_wrlock, &nvme_test_ctrl_rdlock,
&nvme_test_ctrl_rdlock },
.lot_verif = lock_verify_write_before_read,
}, {
.lot_desc = "ns(rd): pending ctrl writer beats pending ns writer "
"and readers",
.lot_initlock = &nvme_test_ns_rdlock,
.lot_locks = { &nvme_test_ns_wrlock, &nvme_test_ns_rdlock,
&nvme_test_ctrl_wrlock, &nvme_test_ctrl_rdlock },
.lot_verif = lock_verify_ctrl_before_ns,
}, {
.lot_desc = "ctrl(rd): pending ctrl writer blocks ns read",
.lot_initlock = &nvme_test_ctrl_rdlock,
.lot_locks = { &nvme_test_ctrl_wrlock, &nvme_test_ns_rdlock,
&nvme_test_ns_rdlock },
.lot_verif = lock_verif_ent_level,
}, {
.lot_desc = "ctrl(rd): pending ctrl writer blocks ns writer",
.lot_initlock = &nvme_test_ctrl_rdlock,
.lot_locks = { &nvme_test_ctrl_wrlock, &nvme_test_ns_wrlock },
.lot_verif = lock_verif_ent_level,
}, {
.lot_desc = "ctrl(rd): pending ctrl writer blocks ctrl reader",
.lot_initlock = &nvme_test_ctrl_rdlock,
.lot_locks = { &nvme_test_ctrl_wrlock, &nvme_test_ctrl_rdlock },
.lot_verif = lock_verify_write_before_read,
}, {
.lot_desc = "ctrl(wr): ctrl writer beats all pending readers",
.lot_initlock = &nvme_test_ctrl_wrlock,
.lot_locks = { &nvme_test_ctrl_rdlock, &nvme_test_ctrl_rdlock,
&nvme_test_ns_rdlock, &nvme_test_ns_rdlock,
&nvme_test_ctrl_wrlock },
.lot_verif = lock_verify_write_before_read,
}, {
.lot_desc = "ctrl(wr): ns writer beats all pending ns readers",
.lot_initlock = &nvme_test_ctrl_wrlock,
.lot_locks = { &nvme_test_ns_rdlock, &nvme_test_ns_rdlock,
&nvme_test_ns_wrlock, &nvme_test_ns_rdlock, &nvme_test_ns_wrlock },
.lot_verif = lock_verify_write_before_read,
} };
/*
 * Worker thread body. Opens its own descriptor, requests the lock described
 * by the template passed in arg, and once the request returns appends an
 * entry to the acquisition log.
 *
 * arg - pointer to the const nvme_ioctl_lock_t template for this worker. The
 *       template itself is never modified; a private copy is handed to the
 *       lock call.
 *
 * The thread terminates through thr_exit() and never returns normally. If the
 * lock request returns before the main thread has set lock_valid, the whole
 * program exits with a failure.
 */
static void *
lock_thread(void *arg)
{
	const nvme_ioctl_lock_t *tmpl = arg;
	const char *targ = "namespace";
	const char *level = "write";
	nvme_ioctl_lock_t lock;
	int ctrlfd;

	lock = *tmpl;
	ctrlfd = nvme_ioctl_test_get_fd(0);

	if (tmpl->nil_ent == NVME_LOCK_E_CTRL)
		targ = "controller";
	if (tmpl->nil_level == NVME_LOCK_L_READ)
		level = "read";

	/*
	 * Clear the don't-block flag in the private copy; presumably this
	 * makes the request wait for the lock rather than fail immediately
	 * (NOTE(review): confirm against the NVMe ioctl definition).
	 */
	lock.nil_flags &= ~NVME_LOCK_F_DONT_BLOCK;
	nvme_ioctl_test_lock(ctrlfd, &lock);

	mutex_enter(&lock_mutex);
	if (!lock_valid) {
		errx(EXIT_FAILURE, "TEST FAILED: thread 0x%x managed to return "
		    "with held %s %s lock before main thread unlocked: test "
		    "cannot continue", thr_self(), targ, level);
	}

	/* Log this acquisition in the next free slot. */
	VERIFY3U(lock_nextres, <, MAX_LOCKS);
	const uint32_t slot = lock_nextres;
	lock_results[slot].loi_thread = thr_self();
	lock_results[slot].loi_lock = tmpl;
	lock_nextres = slot + 1;
	mutex_exit(&lock_mutex);

	/*
	 * The descriptor is closed only after the log entry has been written
	 * (presumably closing it is what releases the lock -- TODO confirm).
	 */
	VERIFY0(close(ctrlfd));

	thr_exit(NULL);
}
/*
 * Run a single lock ordering test.
 *
 * The main thread takes the test's initial lock, resets the shared log, and
 * then starts one worker per lot_locks[] entry. Each worker is polled until it
 * is reported as blocked before the next one is created, so the requests queue
 * up in a known order. The main thread then marks the log as valid, closes its
 * descriptor, joins every worker, and hands the log to the test's
 * verification callback.
 *
 * test - the test to run.
 *
 * Returns true (after printing a "TEST PASSED" line) when the callback accepts
 * the order, false otherwise. Failure to create or join a thread terminates
 * the program.
 */
static bool
lock_order_test(const lock_order_test_t *test)
{
	thread_t thrids[MAX_LOCKS];
	uint32_t nthr = 0;
	int ctrlfd = nvme_ioctl_test_get_fd(0);

	nvme_ioctl_test_lock(ctrlfd, test->lot_initlock);

	/* Start from an empty log that workers are not yet allowed to use. */
	mutex_enter(&lock_mutex);
	lock_valid = false;
	lock_nextres = 0;
	(void) memset(&lock_results, 0, sizeof (lock_results));
	mutex_exit(&lock_mutex);

	while (nthr < MAX_LOCKS && test->lot_locks[nthr] != NULL) {
		const struct timespec delay = {
			.tv_sec = 0,
			.tv_nsec = MSEC2NSEC(10)
		};
		int err = thr_create(NULL, 0, lock_thread,
		    (void *)test->lot_locks[nthr], 0, &thrids[nthr]);

		if (err != 0) {
			errc(EXIT_FAILURE, err, "TEST FAILED: %s: cannot "
			    "continue because we failed to create thread %u",
			    test->lot_desc, nthr);
		}

		/* Poll every 10ms until this worker is seen to be blocked. */
		while (!nvme_ioctl_test_thr_blocked(thrids[nthr]))
			(void) nanosleep(&delay, NULL);

		nthr++;
	}

	/*
	 * From here on workers may legitimately obtain their locks. Closing
	 * the descriptor presumably drops the initial lock and lets the
	 * waiters proceed (TODO confirm against the driver's semantics).
	 */
	mutex_enter(&lock_mutex);
	lock_valid = true;
	mutex_exit(&lock_mutex);
	VERIFY0(close(ctrlfd));

	for (uint32_t t = 0; t < nthr; t++) {
		int err = thr_join(thrids[t], NULL, NULL);

		if (err != 0) {
			errc(EXIT_FAILURE, err, "TEST FAILED: %s: cannot "
			    "continue because we failed to join thread %u",
			    test->lot_desc, t);
		}
	}

	/* Every worker must have logged exactly one acquisition. */
	mutex_enter(&lock_mutex);
	VERIFY3U(lock_nextres, ==, nthr);
	mutex_exit(&lock_mutex);

	if (!test->lot_verif(test, nthr))
		return (false);

	(void) printf("TEST PASSED: %s\n", test->lot_desc);
	return (true);
}
/*
 * Run every entry in lock_order_tests[]. All tests are run even after one
 * fails. Returns EXIT_SUCCESS only if every test passed, EXIT_FAILURE
 * otherwise.
 */
int
main(void)
{
	size_t nfail = 0;

	VERIFY0(mutex_init(&lock_mutex, USYNC_THREAD | LOCK_ERRORCHECK, NULL));

	for (size_t t = 0; t < ARRAY_SIZE(lock_order_tests); t++) {
		if (lock_order_test(&lock_order_tests[t]))
			continue;
		nfail++;
	}

	VERIFY0(mutex_destroy(&lock_mutex));

	return (nfail == 0 ? EXIT_SUCCESS : EXIT_FAILURE);
}