// SPDX-License-Identifier: GPL-2.0
#include <linux/kernel.h>
#include <linux/errno.h>
#include <linux/fs.h>
#include <linux/file.h>
#include <linux/mm.h>
#include <linux/slab.h>
#include <linux/nospec.h>
#include <linux/io_uring.h>

#include <uapi/linux/io_uring.h>

#include "filetable.h"
#include "io_uring.h"
#include "tctx.h"
#include "sqpoll.h"
#include "uring_cmd.h"
#include "poll.h"
#include "timeout.h"
#include "waitid.h"
#include "futex.h"
#include "cancel.h"
#include "wait.h"

struct io_cancel {
        struct file                     *file;
        u64                             addr;
        u32                             flags;
        s32                             fd;
        u8                              opcode;
};

#define CANCEL_FLAGS    (IORING_ASYNC_CANCEL_ALL | IORING_ASYNC_CANCEL_FD | \
                         IORING_ASYNC_CANCEL_ANY | IORING_ASYNC_CANCEL_FD_FIXED | \
                         IORING_ASYNC_CANCEL_USERDATA | IORING_ASYNC_CANCEL_OP)
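
/*
 * Sketch of how these flags are typically used from userspace, assuming
 * liburing is available (io_uring_get_sqe(), io_uring_prep_cancel64() and
 * io_uring_submit() are liburing helpers, not part of this file):
 *
 *	struct io_uring_sqe *sqe = io_uring_get_sqe(&ring);
 *
 *	// cancel every pending request submitted with user_data == 0xcafe
 *	io_uring_prep_cancel64(sqe, 0xcafe, IORING_ASYNC_CANCEL_ALL);
 *	io_uring_submit(&ring);
 */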

/*
 * Returns true if the request matches the criteria outlined by 'cd'.
 */
bool io_cancel_req_match(struct io_kiocb *req, struct io_cancel_data *cd)
{
        bool match_user_data = cd->flags & IORING_ASYNC_CANCEL_USERDATA;

        if (req->ctx != cd->ctx)
                return false;

        if (!(cd->flags & (IORING_ASYNC_CANCEL_FD | IORING_ASYNC_CANCEL_OP)))
                match_user_data = true;

        if (cd->flags & IORING_ASYNC_CANCEL_ANY)
                goto check_seq;
        if (cd->flags & IORING_ASYNC_CANCEL_FD) {
                if (req->file != cd->file)
                        return false;
        }
        if (cd->flags & IORING_ASYNC_CANCEL_OP) {
                if (req->opcode != cd->opcode)
                        return false;
        }
        if (match_user_data && req->cqe.user_data != cd->data)
                return false;
        if (cd->flags & IORING_ASYNC_CANCEL_ALL) {
check_seq:
                if (io_cancel_match_sequence(req, cd->seq))
                        return false;
        }

        return true;
}

static bool io_cancel_cb(struct io_wq_work *work, void *data)
{
        struct io_kiocb *req = container_of(work, struct io_kiocb, work);
        struct io_cancel_data *cd = data;

        return io_cancel_req_match(req, cd);
}

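/*
 * Try to cancel a request queued on (or running off) the io-wq belonging to
 * @tctx, translating the io-wq status into an errno: 0 if the work was
 * canceled before it started running, -EALREADY if it's already executing,
 * and -ENOENT if nothing matched.
 */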
static int io_async_cancel_one(struct io_uring_task *tctx,
                               struct io_cancel_data *cd)
{
        enum io_wq_cancel cancel_ret;
        int ret = 0;
        bool all;

        if (!tctx || !tctx->io_wq)
                return -ENOENT;

        all = cd->flags & (IORING_ASYNC_CANCEL_ALL|IORING_ASYNC_CANCEL_ANY);
        cancel_ret = io_wq_cancel_cb(tctx->io_wq, io_cancel_cb, cd, all);
        switch (cancel_ret) {
        case IO_WQ_CANCEL_OK:
                ret = 0;
                break;
        case IO_WQ_CANCEL_RUNNING:
                ret = -EALREADY;
                break;
        case IO_WQ_CANCEL_NOTFOUND:
                ret = -ENOENT;
                break;
        }

        return ret;
}

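/*
 * Try to cancel a request: first via io-wq, then by checking the poll,
 * waitid and futex lists, and finally the timeout list. Timeouts are
 * skipped for fd-based matching, as timeout requests don't carry a file.
 */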
int io_try_cancel(struct io_uring_task *tctx, struct io_cancel_data *cd,
                  unsigned issue_flags)
{
        struct io_ring_ctx *ctx = cd->ctx;
        int ret;

        WARN_ON_ONCE(!io_wq_current_is_worker() && tctx != current->io_uring);

        ret = io_async_cancel_one(tctx, cd);
        /*
         * Fall-through even for -EALREADY, as we may have a poll request
         * armed that needs unarming.
         */
        if (!ret)
                return 0;

        ret = io_poll_cancel(ctx, cd, issue_flags);
        if (ret != -ENOENT)
                return ret;

        ret = io_waitid_cancel(ctx, cd, issue_flags);
        if (ret != -ENOENT)
                return ret;

        ret = io_futex_cancel(ctx, cd, issue_flags);
        if (ret != -ENOENT)
                return ret;

        spin_lock(&ctx->completion_lock);
        if (!(cd->flags & IORING_ASYNC_CANCEL_FD))
                ret = io_timeout_cancel(ctx, cd);
        spin_unlock(&ctx->completion_lock);
        return ret;
}

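/*
 * Parse an IORING_OP_ASYNC_CANCEL SQE: ->addr is the user_data to match,
 * ->cancel_flags the IORING_ASYNC_CANCEL_* mask, ->fd the descriptor for
 * CANCEL_FD, and ->len the opcode for CANCEL_OP. CANCEL_ANY matches any
 * request and is therefore mutually exclusive with fd/opcode matching.
 */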
int io_async_cancel_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
{
        struct io_cancel *cancel = io_kiocb_to_cmd(req, struct io_cancel);

        if (unlikely(req->flags & REQ_F_BUFFER_SELECT))
                return -EINVAL;
        if (sqe->off || sqe->splice_fd_in)
                return -EINVAL;

        cancel->addr = READ_ONCE(sqe->addr);
        cancel->flags = READ_ONCE(sqe->cancel_flags);
        if (cancel->flags & ~CANCEL_FLAGS)
                return -EINVAL;
        if (cancel->flags & IORING_ASYNC_CANCEL_FD) {
                if (cancel->flags & IORING_ASYNC_CANCEL_ANY)
                        return -EINVAL;
                cancel->fd = READ_ONCE(sqe->fd);
        }
        if (cancel->flags & IORING_ASYNC_CANCEL_OP) {
                if (cancel->flags & IORING_ASYNC_CANCEL_ANY)
                        return -EINVAL;
                cancel->opcode = READ_ONCE(sqe->len);
        }

        return 0;
}

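/*
 * Core of async cancel: try the issuing task's own context first, then
 * fall back to walking every task registered with this ring. With
 * CANCEL_ALL or CANCEL_ANY set, returns the number of requests canceled;
 * otherwise returns the result of the first attempt that found something
 * (or -ENOENT).
 */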
static int __io_async_cancel(struct io_cancel_data *cd,
                             struct io_uring_task *tctx,
                             unsigned int issue_flags)
{
        bool all = cd->flags & (IORING_ASYNC_CANCEL_ALL|IORING_ASYNC_CANCEL_ANY);
        struct io_ring_ctx *ctx = cd->ctx;
        struct io_tctx_node *node;
        int ret, nr = 0;

        do {
                ret = io_try_cancel(tctx, cd, issue_flags);
                if (ret == -ENOENT)
                        break;
                if (!all)
                        return ret;
                nr++;
        } while (1);

        /* slow path, try all io-wq's */
        __set_current_state(TASK_RUNNING);
        io_ring_submit_lock(ctx, issue_flags);
        mutex_lock(&ctx->tctx_lock);
        ret = -ENOENT;
        list_for_each_entry(node, &ctx->tctx_list, ctx_node) {
                ret = io_async_cancel_one(node->task->io_uring, cd);
                if (ret != -ENOENT) {
                        if (!all)
                                break;
                        nr++;
                }
        }
        mutex_unlock(&ctx->tctx_lock);
        io_ring_submit_unlock(ctx, issue_flags);
        return all ? nr : ret;
}

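/*
 * Issue side of IORING_OP_ASYNC_CANCEL. For fd-based matching, the file is
 * resolved here (from the fixed or normal file table) so that matching can
 * compare struct file pointers rather than descriptor numbers.
 */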
int io_async_cancel(struct io_kiocb *req, unsigned int issue_flags)
{
        struct io_cancel *cancel = io_kiocb_to_cmd(req, struct io_cancel);
        struct io_cancel_data cd = {
                .ctx    = req->ctx,
                .data   = cancel->addr,
                .flags  = cancel->flags,
                .opcode = cancel->opcode,
                .seq    = atomic_inc_return(&req->ctx->cancel_seq),
        };
        struct io_uring_task *tctx = req->tctx;
        int ret;

        if (cd.flags & IORING_ASYNC_CANCEL_FD) {
                if (req->flags & REQ_F_FIXED_FILE ||
                    cd.flags & IORING_ASYNC_CANCEL_FD_FIXED) {
                        req->flags |= REQ_F_FIXED_FILE;
                        req->file = io_file_get_fixed(req, cancel->fd,
                                                        issue_flags);
                } else {
                        req->file = io_file_get_normal(req, cancel->fd);
                }
                if (!req->file) {
                        ret = -EBADF;
                        goto done;
                }
                cd.file = req->file;
        }

        ret = __io_async_cancel(&cd, tctx, issue_flags);
done:
        if (ret < 0)
                req_set_fail(req);
        io_req_set_res(req, ret, 0);
        return IOU_COMPLETE;
}

static int __io_sync_cancel(struct io_uring_task *tctx,
                            struct io_cancel_data *cd, int fd)
{
        struct io_ring_ctx *ctx = cd->ctx;

        /* fixed must be grabbed every time since we drop the uring_lock */
        if ((cd->flags & IORING_ASYNC_CANCEL_FD) &&
            (cd->flags & IORING_ASYNC_CANCEL_FD_FIXED)) {
                struct io_rsrc_node *node;

                node = io_rsrc_node_lookup(&ctx->file_table.data, fd);
                if (unlikely(!node))
                        return -EBADF;
                cd->file = io_slot_file(node);
                if (!cd->file)
                        return -EBADF;
        }

        return __io_async_cancel(cd, tctx, 0);
}

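/*
 * Synchronous cancel, reached via io_uring_register(2) with
 * IORING_REGISTER_SYNC_CANCEL. Unlike IORING_OP_ASYNC_CANCEL, this can wait
 * for the cancellation to complete, bounded by the caller-supplied timeout
 * (tv_sec/tv_nsec of -1 means wait indefinitely).
 *
 * A userspace sketch, assuming liburing's wrapper for this interface
 * (io_uring_register_sync_cancel(); not defined in this file):
 *
 *	struct io_uring_sync_cancel_reg reg = {
 *		.addr		= 0xcafe,	// user_data to match
 *		.timeout	= { .tv_sec = -1, .tv_nsec = -1 },
 *	};
 *	int ret = io_uring_register_sync_cancel(&ring, &reg);
 */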
int io_sync_cancel(struct io_ring_ctx *ctx, void __user *arg)
        __must_hold(&ctx->uring_lock)
{
        struct io_cancel_data cd = {
                .ctx    = ctx,
                .seq    = atomic_inc_return(&ctx->cancel_seq),
        };
        ktime_t timeout = KTIME_MAX;
        struct io_uring_sync_cancel_reg sc;
        struct file *file = NULL;
        DEFINE_WAIT(wait);
        int ret, i;

        if (copy_from_user(&sc, arg, sizeof(sc)))
                return -EFAULT;
        if (sc.flags & ~CANCEL_FLAGS)
                return -EINVAL;
        for (i = 0; i < ARRAY_SIZE(sc.pad); i++)
                if (sc.pad[i])
                        return -EINVAL;
        for (i = 0; i < ARRAY_SIZE(sc.pad2); i++)
                if (sc.pad2[i])
                        return -EINVAL;

        cd.data = sc.addr;
        cd.flags = sc.flags;
        cd.opcode = sc.opcode;

        /* we can grab a normal file descriptor upfront */
        if ((cd.flags & IORING_ASYNC_CANCEL_FD) &&
           !(cd.flags & IORING_ASYNC_CANCEL_FD_FIXED)) {
                file = fget(sc.fd);
                if (!file)
                        return -EBADF;
                cd.file = file;
        }

        ret = __io_sync_cancel(current->io_uring, &cd, sc.fd);

        /* found something, done! */
        if (ret != -EALREADY)
                goto out;

        if (sc.timeout.tv_sec != -1UL || sc.timeout.tv_nsec != -1UL) {
                struct timespec64 ts = {
                        .tv_sec         = sc.timeout.tv_sec,
                        .tv_nsec        = sc.timeout.tv_nsec
                };

                timeout = ktime_add_ns(timespec64_to_ktime(ts), ktime_get_ns());
        }

        /*
         * Keep looking until we get -ENOENT. We'll get woken every time a
         * request completes and will retry the cancellation.
         */
        do {
                cd.seq = atomic_inc_return(&ctx->cancel_seq);

                prepare_to_wait(&ctx->cq_wait, &wait, TASK_INTERRUPTIBLE);

                ret = __io_sync_cancel(current->io_uring, &cd, sc.fd);

                mutex_unlock(&ctx->uring_lock);
                if (ret != -EALREADY)
                        break;

                ret = io_run_task_work_sig(ctx);
                if (ret < 0)
                        break;
                ret = schedule_hrtimeout(&timeout, HRTIMER_MODE_ABS);
                if (!ret) {
                        ret = -ETIME;
                        break;
                }
                mutex_lock(&ctx->uring_lock);
        } while (1);

        finish_wait(&ctx->cq_wait, &wait);
        mutex_lock(&ctx->uring_lock);

        if (ret == -ENOENT || ret > 0)
                ret = 0;
out:
        if (file)
                fput(file);
        return ret;
}

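/*
 * Remove and cancel all requests on @list that match @tctx (any task, if
 * @cancel_all is set), invoking the opcode-specific @cancel callback for
 * each. Returns true if at least one request was canceled.
 */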
bool io_cancel_remove_all(struct io_ring_ctx *ctx, struct io_uring_task *tctx,
                          struct hlist_head *list, bool cancel_all,
                          bool (*cancel)(struct io_kiocb *))
{
        struct hlist_node *tmp;
        struct io_kiocb *req;
        bool found = false;

        lockdep_assert_held(&ctx->uring_lock);

        hlist_for_each_entry_safe(req, tmp, list, hash_node) {
                if (!io_match_task_safe(req, tctx, cancel_all))
                        continue;
                hlist_del_init(&req->hash_node);
                if (cancel(req))
                        found = true;
        }

        return found;
}

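/*
 * Cancel requests on @list that match the criteria in @cd, stopping after
 * the first hit unless IORING_ASYNC_CANCEL_ALL is set. Returns the number
 * of requests canceled, or -ENOENT if none matched.
 */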
int io_cancel_remove(struct io_ring_ctx *ctx, struct io_cancel_data *cd,
                     unsigned int issue_flags, struct hlist_head *list,
                     bool (*cancel)(struct io_kiocb *))
{
        struct hlist_node *tmp;
        struct io_kiocb *req;
        int nr = 0;

        io_ring_submit_lock(ctx, issue_flags);
        hlist_for_each_entry_safe(req, tmp, list, hash_node) {
                if (!io_cancel_req_match(req, cd))
                        continue;
                if (cancel(req))
                        nr++;
                if (!(cd->flags & IORING_ASYNC_CANCEL_ALL))
                        break;
        }
        io_ring_submit_unlock(ctx, issue_flags);
        return nr ?: -ENOENT;
}

static bool io_match_linked(struct io_kiocb *head)
{
        struct io_kiocb *req;

        io_for_each_link(req, head) {
                if (req->flags & REQ_F_INFLIGHT)
                        return true;
        }
        return false;
}

/*
 * As io_match_task() but protected against racing with linked timeouts.
 * Caller must not hold timeout_lock.
 */
bool io_match_task_safe(struct io_kiocb *head, struct io_uring_task *tctx,
                        bool cancel_all)
{
        bool matched;

        if (tctx && head->tctx != tctx)
                return false;
        if (cancel_all)
                return true;

        if (head->flags & REQ_F_LINK_TIMEOUT) {
                struct io_ring_ctx *ctx = head->ctx;

                /* protect against races with linked timeouts */
                raw_spin_lock_irq(&ctx->timeout_lock);
                matched = io_match_linked(head);
                raw_spin_unlock_irq(&ctx->timeout_lock);
        } else {
                matched = io_match_linked(head);
        }
        return matched;
}

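/*
 * Entry point for task exit/exec: unregister any ring fds this task has
 * registered, then cancel its pending requests.
 */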
void __io_uring_cancel(bool cancel_all)
{
        io_uring_unreg_ringfd();
        io_uring_cancel_generic(cancel_all, NULL);
}

struct io_task_cancel {
        struct io_uring_task *tctx;
        bool all;
};

static bool io_cancel_task_cb(struct io_wq_work *work, void *data)
{
        struct io_kiocb *req = container_of(work, struct io_kiocb, work);
        struct io_task_cancel *cancel = data;

        return io_match_task_safe(req, cancel->tctx, cancel->all);
}

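/*
 * Cancel matching requests sitting on the drain/defer list. The list is
 * scanned backwards for the last matching entry, and everything up to and
 * including it is cut off and failed with -ECANCELED.
 */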
static __cold bool io_cancel_defer_files(struct io_ring_ctx *ctx,
                                         struct io_uring_task *tctx,
                                         bool cancel_all)
{
        struct io_defer_entry *de;
        LIST_HEAD(list);

        list_for_each_entry_reverse(de, &ctx->defer_list, list) {
                if (io_match_task_safe(de->req, tctx, cancel_all)) {
                        list_cut_position(&list, &ctx->defer_list, &de->list);
                        break;
                }
        }
        if (list_empty(&list))
                return false;

        while (!list_empty(&list)) {
                de = list_first_entry(&list, struct io_defer_entry, list);
                list_del_init(&de->list);
                ctx->nr_drained -= io_linked_nr(de->req);
                io_req_task_queue_fail(de->req, -ECANCELED);
                kfree(de);
        }
        return true;
}

__cold bool io_cancel_ctx_cb(struct io_wq_work *work, void *data)
{
        struct io_kiocb *req = container_of(work, struct io_kiocb, work);

        return req->ctx == data;
}

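/*
 * Cancel any io-wq work belonging to @ctx across every task that has
 * registered with the ring. Returns true if any matching work was found.
 */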
static __cold bool io_uring_try_cancel_iowq(struct io_ring_ctx *ctx)
{
        struct io_tctx_node *node;
        enum io_wq_cancel cret;
        bool ret = false;

        mutex_lock(&ctx->uring_lock);
        mutex_lock(&ctx->tctx_lock);
        list_for_each_entry(node, &ctx->tctx_list, ctx_node) {
                struct io_uring_task *tctx = node->task->io_uring;

                /*
                 * io_wq will stay alive while we hold uring_lock, because it's
                 * killed after the ctx nodes, which requires taking the lock.
                 */
                if (!tctx || !tctx->io_wq)
                        continue;
                cret = io_wq_cancel_cb(tctx->io_wq, io_cancel_ctx_cb, ctx, true);
                ret |= (cret != IO_WQ_CANCEL_NOTFOUND);
        }
        mutex_unlock(&ctx->tctx_lock);
        mutex_unlock(&ctx->uring_lock);

        return ret;
}

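/*
 * One cancellation pass over @ctx: io-wq work, iopoll, deferred task_work,
 * drained requests, the poll/waitid/futex/uring_cmd lists, and timeouts.
 * A NULL @tctx means "match any task" (ring teardown). Returns true if
 * anything was found, in which case the caller should loop and retry.
 */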
__cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
                                         struct io_uring_task *tctx,
                                         bool cancel_all, bool is_sqpoll_thread)
{
        struct io_task_cancel cancel = { .tctx = tctx, .all = cancel_all, };
        enum io_wq_cancel cret;
        bool ret = false;

        /* set it so io_req_local_work_add() would wake us up */
        if (ctx->flags & IORING_SETUP_DEFER_TASKRUN) {
                atomic_set(&ctx->cq_wait_nr, 1);
                smp_mb();
        }

        /* failed during ring init, it couldn't have issued any requests */
        if (!ctx->rings)
                return false;

        if (!tctx) {
                ret |= io_uring_try_cancel_iowq(ctx);
        } else if (tctx->io_wq) {
                /*
                 * Cancels requests of all rings, not only @ctx, but
                 * it's fine as the task is in exit/exec.
                 */
                cret = io_wq_cancel_cb(tctx->io_wq, io_cancel_task_cb,
                                       &cancel, true);
                ret |= (cret != IO_WQ_CANCEL_NOTFOUND);
        }

        /* SQPOLL thread does its own polling */
        if ((!(ctx->flags & IORING_SETUP_SQPOLL) && cancel_all) ||
            is_sqpoll_thread) {
                while (!list_empty(&ctx->iopoll_list)) {
                        io_iopoll_try_reap_events(ctx);
                        ret = true;
                        cond_resched();
                }
        }

        if ((ctx->flags & IORING_SETUP_DEFER_TASKRUN) &&
            io_allowed_defer_tw_run(ctx))
                ret |= io_run_local_work(ctx, INT_MAX, INT_MAX) > 0;
        mutex_lock(&ctx->uring_lock);
        ret |= io_cancel_defer_files(ctx, tctx, cancel_all);
        ret |= io_poll_remove_all(ctx, tctx, cancel_all);
        ret |= io_waitid_remove_all(ctx, tctx, cancel_all);
        ret |= io_futex_remove_all(ctx, tctx, cancel_all);
        ret |= io_uring_try_cancel_uring_cmd(ctx, tctx, cancel_all);
        mutex_unlock(&ctx->uring_lock);
        ret |= io_kill_timeouts(ctx, tctx, cancel_all);
        if (tctx)
                ret |= io_run_task_work() > 0;
        else
                ret |= flush_delayed_work(&ctx->fallback_work);
        return ret;
}

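/*
 * How many requests the task still has in flight: either only those
 * explicitly tracked via REQ_F_INFLIGHT, or the full per-cpu inflight
 * count.
 */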
static s64 tctx_inflight(struct io_uring_task *tctx, bool tracked)
{
        if (tracked)
                return atomic_read(&tctx->inflight_tracked);
        return percpu_counter_sum(&tctx->inflight);
}

/*
 * Find any io_uring ctx that this task has registered or done IO on, and cancel
 * requests. @sqd must be non-NULL iff this is an SQPOLL thread cancellation.
 */
__cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd)
{
        struct io_uring_task *tctx = current->io_uring;
        struct io_ring_ctx *ctx;
        struct io_tctx_node *node;
        unsigned long index;
        s64 inflight;
        DEFINE_WAIT(wait);

        WARN_ON_ONCE(sqd && sqpoll_task_locked(sqd) != current);

        if (!current->io_uring)
                return;
        if (tctx->io_wq)
                io_wq_exit_start(tctx->io_wq);

        atomic_inc(&tctx->in_cancel);
        do {
                bool loop = false;

                io_uring_drop_tctx_refs(current);
                if (!tctx_inflight(tctx, !cancel_all))
                        break;

                /* read completions before cancellations */
                inflight = tctx_inflight(tctx, false);
                if (!inflight)
                        break;

                if (!sqd) {
                        xa_for_each(&tctx->xa, index, node) {
                                /* sqpoll task will cancel all its requests */
                                if (node->ctx->sq_data)
                                        continue;
                                loop |= io_uring_try_cancel_requests(node->ctx,
                                                        current->io_uring,
                                                        cancel_all,
                                                        false);
                        }
                } else {
                        list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
                                loop |= io_uring_try_cancel_requests(ctx,
                                                                     current->io_uring,
                                                                     cancel_all,
                                                                     true);
                }

                if (loop) {
                        cond_resched();
                        continue;
                }

                prepare_to_wait(&tctx->wait, &wait, TASK_INTERRUPTIBLE);
                io_run_task_work();
                io_uring_drop_tctx_refs(current);
                xa_for_each(&tctx->xa, index, node) {
                        if (io_local_work_pending(node->ctx)) {
                                WARN_ON_ONCE(node->ctx->submitter_task &&
                                             node->ctx->submitter_task != current);
                                goto end_wait;
                        }
                }
                /*
                 * If we've seen completions, retry without waiting. This
                 * avoids a race where a completion comes in before we did
                 * prepare_to_wait().
                 */
                if (inflight == tctx_inflight(tctx, !cancel_all))
                        schedule();
end_wait:
                finish_wait(&tctx->wait, &wait);
        } while (1);

        io_uring_clean_tctx(tctx);
        if (cancel_all) {
                /*
         * We shouldn't run task_works after cancel. Leave ->in_cancel set
         * for a normal exit; only a full cancel drops it again.
                 */
                atomic_dec(&tctx->in_cancel);
                /* for exec all current's requests should be gone, kill tctx */
                __io_uring_free(current);
        }
}