#include <linux/slab.h>
#include <linux/completion.h>
#include <linux/sched/task.h>
#include <linux/sched/vhost_task.h>
#include <linux/sched/signal.h>
/*
 * Bit numbers within struct vhost_task::flags, manipulated with
 * set_bit()/test_bit(). The values are written out explicitly because they
 * are used as bit indices, not as ordinary enumerators.
 */
enum vhost_task_flags {
	/* Set by vhost_task_stop() to ask the worker loop to exit. */
	VHOST_TASK_FLAGS_STOP = 0,
	/* Set by the worker when it leaves its loop while STOP is not set. */
	VHOST_TASK_FLAGS_KILLED = 1,
};
/*
 * State shared between a vhost worker thread (vhost_task_fn) and the
 * vhost_task_* API below. Allocated by vhost_task_create(). Freed by
 * vhost_task_stop(), or by vhost_task_create() itself if copy_process() fails.
 */
struct vhost_task {
	/*
	 * Worker callback, called in a loop with @data. A false return makes
	 * the worker call schedule(); true makes it loop again immediately.
	 */
	bool (*fn)(void *data);
	/*
	 * Called with @data at most once, from the worker, when it leaves its
	 * loop without VHOST_TASK_FLAGS_STOP having been set.
	 */
	void (*handle_sigkill)(void *data);
	/* Opaque argument passed to both @fn and @handle_sigkill. */
	void *data;
	/*
	 * Completed by the worker as its last access to this struct before
	 * do_exit(). vhost_task_stop() waits on it before freeing.
	 */
	struct completion exited;
	/* Bits indexed by enum vhost_task_flags. */
	unsigned long flags;
	/*
	 * The worker thread. A reference is taken in vhost_task_create() and
	 * dropped in vhost_task_stop().
	 */
	struct task_struct *task;
	/*
	 * Serializes the STOP vs. KILLED decision between vhost_task_stop()
	 * and the worker's exit path.
	 */
	struct mutex exit_mutex;
};
/*
 * vhost_task_fn - main loop of a vhost worker thread
 * @data: the struct vhost_task passed as kernel_clone_args::fn_arg
 *
 * Repeatedly invokes vtsk->fn(vtsk->data), sleeping whenever it reports no
 * work. The loop ends when vhost_task_stop() sets VHOST_TASK_FLAGS_STOP or
 * when get_signal() returns true. On the latter path (STOP not set), the
 * worker marks itself KILLED and calls vtsk->handle_sigkill().
 *
 * Does not return: the thread ends in do_exit(0). The int return type is
 * what kernel_clone_args::fn requires.
 */
static int vhost_task_fn(void *data)
{
	struct vhost_task *vtsk = data;
	for (;;) {
		bool did_work;
		/*
		 * A true return from get_signal() is treated as "exit now".
		 * NOTE(review): presumably the fatal-signal (SIGKILL) case,
		 * given handle_sigkill below; confirm against get_signal().
		 */
		if (signal_pending(current)) {
			struct ksignal ksig;
			if (get_signal(&ksig))
				break;
		}
		/*
		 * Set the sleep state BEFORE testing STOP. set_current_state()
		 * implies a barrier, so a vhost_task_stop() that sets STOP and
		 * then wakes us cannot slip in between this test and
		 * schedule() and be lost.
		 */
		set_current_state(TASK_INTERRUPTIBLE);
		if (test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) {
			__set_current_state(TASK_RUNNING);
			break;
		}
		/* Note: fn() is entered with the state set to TASK_INTERRUPTIBLE above. */
		did_work = vtsk->fn(vtsk->data);
		/*
		 * Only sleep when there was nothing to do. A wake-up that
		 * arrived after set_current_state() makes schedule() return
		 * promptly instead of blocking.
		 */
		if (!did_work)
			schedule();
	}
	/*
	 * Decide, under exit_mutex, whether this is a stop or a kill.
	 * vhost_task_stop() checks KILLED and sets STOP under the same mutex,
	 * so at most one of the two flags ends up set. If STOP won a race
	 * with a signal, the signal is ignored and handle_sigkill is not run.
	 */
	mutex_lock(&vtsk->exit_mutex);
	if (!test_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags)) {
		set_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags);
		vtsk->handle_sigkill(vtsk->data);
	}
	mutex_unlock(&vtsk->exit_mutex);
	/*
	 * vhost_task_stop() may kfree(vtsk) as soon as this completes, so
	 * vtsk must not be touched after this point.
	 */
	complete(&vtsk->exited);
	do_exit(0);
}
/**
 * vhost_task_wake - wake the vhost worker thread
 * @vtsk: vhost_task whose worker should run
 *
 * Thin wrapper around wake_up_process() on the worker's task_struct; the
 * worker loop will call vtsk->fn() again once it runs.
 */
void vhost_task_wake(struct vhost_task *vtsk)
{
	struct task_struct *worker = vtsk->task;

	wake_up_process(worker);
}
EXPORT_SYMBOL_GPL(vhost_task_wake);
/**
 * vhost_task_stop - stop a vhost_task and free it
 * @vtsk: vhost_task to stop
 *
 * Unless the worker has already marked itself KILLED, sets
 * VHOST_TASK_FLAGS_STOP and wakes the worker so its loop exits. Then waits
 * for the worker to finish, drops the task reference and frees @vtsk.
 * @vtsk must not be used after this returns.
 *
 * NOTE(review): this waits on vtsk->exited, which only the running worker
 * completes, so presumably it must not be called on a task that was never
 * started with vhost_task_start(). Confirm.
 */
void vhost_task_stop(struct vhost_task *vtsk)
{
	/*
	 * handle_sigkill runs under the same mutex in vhost_task_fn(). So if
	 * KILLED is seen here, the worker's kill path has already finished and
	 * there is nothing to request. Otherwise setting STOP guarantees the
	 * worker will skip handle_sigkill.
	 */
	mutex_lock(&vtsk->exit_mutex);
	if (!test_bit(VHOST_TASK_FLAGS_KILLED, &vtsk->flags)) {
		set_bit(VHOST_TASK_FLAGS_STOP, &vtsk->flags);
		/* Kick the worker out of schedule() so it re-tests STOP. */
		vhost_task_wake(vtsk);
	}
	mutex_unlock(&vtsk->exit_mutex);
	/*
	 * Make sure vhost_task_fn() is no longer accessing vtsk before freeing
	 * it below.
	 */
	wait_for_completion(&vtsk->exited);
	put_task_struct(vtsk->task);
	kfree(vtsk);
}
EXPORT_SYMBOL_GPL(vhost_task_stop);
/**
 * vhost_task_create - create a (not yet running) vhost worker thread
 * @fn: worker callback, called repeatedly with @arg; returns true if it did work
 * @handle_sigkill: called with @arg if the worker leaves its loop before
 *                  vhost_task_stop() has requested a stop
 * @arg: opaque data passed to @fn and @handle_sigkill
 * @name: name for the new thread
 *
 * The thread is cloned but not woken: call vhost_task_start() to run it and
 * vhost_task_stop() to tear it down.
 *
 * Return: the new vhost_task, or an ERR_PTR() on failure. The error is
 * -ENOMEM if the allocation fails, otherwise whatever copy_process() returned.
 */
struct vhost_task *vhost_task_create(bool (*fn)(void *),
				     void (*handle_sigkill)(void *), void *arg,
				     const char *name)
{
	/*
	 * The worker is a thread in the caller's thread group sharing its mm,
	 * fs info and signal handlers. It is not traceable, is flagged as a
	 * user worker, and gets no file table.
	 */
	struct kernel_clone_args args = {
		.flags		= CLONE_FS | CLONE_UNTRACED | CLONE_VM |
				  CLONE_THREAD | CLONE_SIGHAND,
		.exit_signal	= 0,
		.fn		= vhost_task_fn,
		.name		= name,
		.user_worker	= 1,
		.no_files	= 1,
	};
	struct task_struct *worker;
	struct vhost_task *vtsk;

	vtsk = kzalloc_obj(*vtsk);
	if (!vtsk)
		return ERR_PTR(-ENOMEM);

	vtsk->fn = fn;
	vtsk->handle_sigkill = handle_sigkill;
	vtsk->data = arg;
	init_completion(&vtsk->exited);
	mutex_init(&vtsk->exit_mutex);

	/* vhost_task_fn() receives vtsk as its data argument. */
	args.fn_arg = vtsk;

	worker = copy_process(NULL, 0, NUMA_NO_NODE, &args);
	if (IS_ERR(worker))
		goto err_free_vtsk;

	/* Hold a reference until vhost_task_stop() drops it. */
	vtsk->task = get_task_struct(worker);
	return vtsk;

err_free_vtsk:
	kfree(vtsk);
	return ERR_CAST(worker);
}
EXPORT_SYMBOL_GPL(vhost_task_create);
/**
 * vhost_task_start - start a vhost_task created with vhost_task_create()
 * @vtsk: vhost_task returned by vhost_task_create()
 *
 * Performs the initial wake-up of the freshly cloned worker thread.
 */
void vhost_task_start(struct vhost_task *vtsk)
{
	struct task_struct *worker = vtsk->task;

	wake_up_new_task(worker);
}
EXPORT_SYMBOL_GPL(vhost_task_start);