root/drivers/iommu/iommufd/selftest.c
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
 *
 * Kernel side components to support tools/testing/selftests/iommu
 */
#include <linux/anon_inodes.h>
#include <linux/debugfs.h>
#include <linux/dma-buf.h>
#include <linux/dma-resv.h>
#include <linux/fault-inject.h>
#include <linux/file.h>
#include <linux/iommu.h>
#include <linux/platform_device.h>
#include <linux/slab.h>
#include <linux/xarray.h>
#include <uapi/linux/iommufd.h>
#include <linux/generic_pt/iommu.h>
#include "../iommu-pages.h"

#include "../iommu-priv.h"
#include "io_pagetable.h"
#include "iommufd_private.h"
#include "iommufd_test.h"

/* Fault-injection attribute; NOTE(review): its users are outside this view. */
static DECLARE_FAULT_ATTR(fail_iommufd);
/* debugfs dentry; presumably the selftest's debugfs root, set up elsewhere. */
static struct dentry *dbgfs_root;
/* Platform device for the mock IOMMU; created/registered outside this view. */
static struct platform_device *selftest_iommu_dev;
/* Forward declarations: both ops tables are defined later in this file. */
static const struct iommu_ops mock_ops;
static struct iommu_domain_ops domain_nested_ops;

/* Global test limit (default 65536); NOTE(review): consumer not visible here, confirm units. */
size_t iommufd_test_memory_limit = 65536;

/*
 * A bus_type for mock devices, paired with a notifier_block.
 * mock_probe_device() only accepts devices whose dev->bus is this bus.
 */
struct mock_bus_type {
        struct bus_type bus;
        /* NOTE(review): notifier registration is not visible in this chunk. */
        struct notifier_block nb;
};

/* The single mock bus instance, named "iommufd_mock". */
static struct mock_bus_type iommufd_mock_bus_type = {
        .bus = {
                .name = "iommufd_mock",
        },
};

/* ID allocator for mock_dev::id; mock_dev_release() returns IDs to it. */
static DEFINE_IDA(mock_dev_ida);

/* Bits for mock_iommu_domain::flags. */
enum {
        /* Set while dirty tracking is enabled on the domain */
        MOCK_DIRTY_TRACK = 1,
};

/* Forward declarations: the attach callbacks below call these before their definitions. */
static int mock_dev_enable_iopf(struct device *dev, struct iommu_domain *domain);
static void mock_dev_disable_iopf(struct device *dev, struct iommu_domain *domain);

/*
 * Syzkaller has trouble randomizing the correct iova to use since it is linked
 * to the map ioctl's output, and it has no idea about that. So, simplify things.
 * In syzkaller mode the 64 bit IOVA is converted into an nth area and offset
 * value. This has a much smaller randomization space and syzkaller can hit it.
 */
static unsigned long __iommufd_test_syz_conv_iova(struct io_pagetable *iopt,
                                                  u64 *iova)
{
        /* Reinterpret the 64-bit IOVA as {area index, offset within area}. */
        struct syz_layout {
                __u32 nth_area;
                __u32 offset;
        };
        struct syz_layout *layout = (void *)iova;
        unsigned int remaining = layout->nth_area;
        unsigned long result = 0;
        struct iopt_area *cur;

        /* Walk the areas in order under iova_rwsem; 0 if the index is past the end. */
        down_read(&iopt->iova_rwsem);
        cur = iopt_area_iter_first(iopt, 0, ULONG_MAX);
        while (cur) {
                if (!remaining) {
                        result = iopt_area_iova(cur) + layout->offset;
                        break;
                }
                remaining--;
                cur = iopt_area_iter_next(cur, 0, ULONG_MAX);
        }
        up_read(&iopt->iova_rwsem);

        return result;
}

static unsigned long iommufd_test_syz_conv_iova(struct iommufd_access *access,
                                                u64 *iova)
{
        unsigned long converted = 0;

        /*
         * Read access->ioas and walk its page table under ioas_lock;
         * an access with no IOAS converts to 0.
         */
        mutex_lock(&access->ioas_lock);
        if (access->ioas)
                converted = __iommufd_test_syz_conv_iova(&access->ioas->iopt,
                                                         iova);
        mutex_unlock(&access->ioas_lock);

        return converted;
}

void iommufd_test_syz_conv_iova_id(struct iommufd_ucmd *ucmd,
                                   unsigned int ioas_id, u64 *iova, u32 *flags)
{
        struct iommufd_ioas *ioas;

        /* Nothing to do unless syzkaller-style IOVA encoding was requested. */
        if (!(*flags & MOCK_FLAGS_ACCESS_SYZ))
                return;

        /* The flag is always stripped, even if the IOAS lookup fails. */
        *flags &= ~(u32)MOCK_FLAGS_ACCESS_SYZ;

        ioas = iommufd_get_ioas(ucmd->ictx, ioas_id);
        if (!IS_ERR(ioas)) {
                *iova = __iommufd_test_syz_conv_iova(&ioas->iopt, iova);
                iommufd_put_object(ucmd->ictx, &ioas->obj);
        }
}

/*
 * Mock paging domain. The core iommu_domain and the generic page table
 * objects (pt_iommu, pt_iommu_amdv1) share storage through the anonymous
 * union; mock_domain_alloc_pgtable() chooses which amdv1 init (mock or
 * non-mock) populates it.
 */
struct mock_iommu_domain {
        union {
                struct iommu_domain domain;
                struct pt_iommu iommu;
                struct pt_iommu_amdv1 amdv1;
        };
        /* MOCK_DIRTY_TRACK while dirty tracking is enabled */
        unsigned long flags;
};
/*
 * generic_pt compile-time layout checks; presumably they assert that the
 * embedded iommu_domain lines up with the named pt_iommu member so the union
 * aliasing above is valid. Confirm against generic_pt/iommu.h.
 */
PT_IOMMU_CHECK_DOMAIN(struct mock_iommu_domain, iommu, domain);
PT_IOMMU_CHECK_DOMAIN(struct mock_iommu_domain, amdv1.iommu, domain);

/* Map a core iommu_domain back to the mock paging domain that embeds it. */
static inline struct mock_iommu_domain *
to_mock_domain(struct iommu_domain *domain)
{
        struct mock_iommu_domain *mock =
                container_of(domain, struct mock_iommu_domain, domain);

        return mock;
}

/* Mock nested domain. */
struct mock_iommu_domain_nested {
        struct iommu_domain domain;
        /* Owning vIOMMU; NULL when allocated via mock_domain_alloc_nested() */
        struct mock_viommu *mock_viommu;
        /* Fake IOTLB entries: seeded from user data, zeroed by invalidation */
        u32 iotlb[MOCK_NESTED_DOMAIN_IOTLB_NUM];
};

/* Map a core iommu_domain back to the mock nested domain that embeds it. */
static inline struct mock_iommu_domain_nested *
to_mock_nested(struct iommu_domain *domain)
{
        struct mock_iommu_domain_nested *nested =
                container_of(domain, struct mock_iommu_domain_nested, domain);

        return nested;
}

/* Driver-private vIOMMU object embedding the iommufd core vIOMMU. */
struct mock_viommu {
        struct iommufd_viommu core;
        /* Paging domain passed to mock_viommu_init() as the parent */
        struct mock_iommu_domain *s2_parent;
        /* HW queues by index; accessed under queue_mutex */
        struct mock_hw_queue *hw_queue[IOMMU_TEST_HW_QUEUE_MAX];
        struct mutex queue_mutex;

        /* Offset from iommufd_viommu_alloc_mmap(); destroyed only if non-zero */
        unsigned long mmap_offset;
        /* Two pages allocated in mock_viommu_init() when user data is given */
        u32 *page; /* Mmap page to test u32 type of in_data */
};

/* Map the core vIOMMU back to the mock_viommu that embeds it. */
static inline struct mock_viommu *to_mock_viommu(struct iommufd_viommu *viommu)
{
        struct mock_viommu *mv = container_of(viommu, struct mock_viommu, core);

        return mv;
}

/* Driver-private HW queue object embedding the iommufd core HW queue. */
struct mock_hw_queue {
        struct iommufd_hw_queue core;
        /* Owning vIOMMU */
        struct mock_viommu *mock_viommu;
        /* Queue at index - 1 that this queue depends on; NULL for index 0 */
        struct mock_hw_queue *prev;
        /*
         * Slot cleared by mock_hw_queue_destroy(); NOTE(review): must match
         * the slot the queue was stored in at init time.
         */
        u16 index;
};

/* Map the core HW queue back to the mock_hw_queue that embeds it. */
static inline struct mock_hw_queue *
to_mock_hw_queue(struct iommufd_hw_queue *hw_queue)
{
        struct mock_hw_queue *q =
                container_of(hw_queue, struct mock_hw_queue, core);

        return q;
}

/* Kinds of selftest_obj; only the mock iommufd device kind is defined. */
enum selftest_obj_type {
        TYPE_IDEV,
};

/* A device on the mock bus. */
struct mock_dev {
        struct device dev;
        /* vIOMMU of the attached nested domain; written under viommu_rwsem */
        struct mock_viommu *viommu;
        struct rw_semaphore viommu_rwsem;
        /* MOCK_FLAGS_DEVICE_* bits, e.g. MOCK_FLAGS_DEVICE_NO_DIRTY */
        unsigned long flags;
        /* Virtual device ID within @viommu; written under viommu_rwsem */
        unsigned long vdev_id;
        /* ID returned to mock_dev_ida on release */
        int id;
        /* Fake per-device cache entries cleared by vIOMMU invalidation */
        u32 cache[MOCK_DEV_CACHE_NUM];
        /* One-shot fake error state for PASID 1024 set_dev_pasid tests */
        atomic_t pasid_1024_fake_error;
        /* Count of enabled IOPF users (domains with an iopf_handler) */
        unsigned int iopf_refcount;
        /* Domain last attached through attach_dev (not PASID attaches) */
        struct iommu_domain *domain;
};

/* Map a struct device back to the mock_dev that embeds it. */
static inline struct mock_dev *to_mock_dev(struct device *dev)
{
        struct mock_dev *mdev = container_of(dev, struct mock_dev, dev);

        return mdev;
}

/* iommufd object owned by the selftest; @type selects the union member. */
struct selftest_obj {
        struct iommufd_object obj;
        enum selftest_obj_type type;

        union {
                /* TYPE_IDEV: a mock device with its iommufd_device and context */
                struct {
                        struct iommufd_device *idev;
                        struct iommufd_ctx *ictx;
                        struct mock_dev *mock_dev;
                } idev;
        };
};

/* Map a core iommufd_object back to the selftest_obj that embeds it. */
static inline struct selftest_obj *to_selftest_obj(struct iommufd_object *obj)
{
        struct selftest_obj *sobj = container_of(obj, struct selftest_obj, obj);

        return sobj;
}

/*
 * attach_dev op shared by the blocking, paging and nested mock domains.
 *
 * Rejects dirty-tracking domains on MOCK_FLAGS_DEVICE_NO_DIRTY devices. For a
 * nested domain allocated through a vIOMMU, looks up the device's virtual ID
 * and records the vIOMMU/vdev_id in the mock_dev under viommu_rwsem. Moves
 * the IOPF enable reference from the previously attached domain to @domain.
 *
 * All steps that can fail run before any mock_dev state is modified, so a
 * failed attach leaves the device exactly as it was.
 *
 * Returns 0 on success or a negative errno.
 */
static int mock_domain_nop_attach(struct iommu_domain *domain,
                                  struct device *dev, struct iommu_domain *old)
{
        struct mock_dev *mdev = to_mock_dev(dev);
        struct mock_viommu *new_viommu = NULL;
        unsigned long vdev_id = 0;
        int rc;

        if (domain->dirty_ops && (mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY))
                return -EINVAL;

        iommu_group_mutex_assert(dev);
        if (domain->type == IOMMU_DOMAIN_NESTED) {
                new_viommu = to_mock_nested(domain)->mock_viommu;
                if (new_viommu) {
                        rc = iommufd_viommu_get_vdev_id(&new_viommu->core, dev,
                                                        &vdev_id);
                        if (rc)
                                return rc;
                }
        }

        /*
         * Enabling IOPF can fail, so do it before publishing the new vIOMMU;
         * otherwise a failed attach would leave mdev->viommu/vdev_id pointing
         * at the new domain's vIOMMU while the old domain stays attached.
         */
        rc = mock_dev_enable_iopf(dev, domain);
        if (rc)
                return rc;

        if (new_viommu != mdev->viommu) {
                down_write(&mdev->viommu_rwsem);
                mdev->viommu = new_viommu;
                mdev->vdev_id = vdev_id;
                up_write(&mdev->viommu_rwsem);
        }

        mock_dev_disable_iopf(dev, mdev->domain);
        mdev->domain = domain;

        return 0;
}

static int mock_domain_set_dev_pasid_nop(struct iommu_domain *domain,
                                         struct device *dev, ioasid_t pasid,
                                         struct iommu_domain *old)
{
        struct mock_dev *mdev = to_mock_dev(dev);
        int ret;

        /*
         * PASID 1024 drives a one-shot injected failure so tests can verify
         * that the core rolls back to the old domain (e.g. on replace):
         *  - attaching the blocking domain disarms the flag;
         *  - a non-blocking set with the flag armed disarms it and fails
         *    with -ENOMEM;
         *  - any other non-blocking set arms it for the next call.
         * Callers must therefore expect a third call to succeed again.
         */
        if (pasid == 1024) {
                bool blocking = domain->type == IOMMU_DOMAIN_BLOCKED;

                if (!blocking && atomic_read(&mdev->pasid_1024_fake_error)) {
                        atomic_set(&mdev->pasid_1024_fake_error, 0);
                        return -ENOMEM;
                }
                atomic_set(&mdev->pasid_1024_fake_error, blocking ? 0 : 1);
        }

        /* Take the IOPF reference for the new domain before dropping the old one. */
        ret = mock_dev_enable_iopf(dev, domain);
        if (!ret)
                mock_dev_disable_iopf(dev, old);

        return ret;
}

/*
 * Ops for the blocking domain; the attach callbacks are the same ones used by
 * the paging and nested mock domains.
 */
static const struct iommu_domain_ops mock_blocking_ops = {
        .attach_dev = mock_domain_nop_attach,
        .set_dev_pasid = mock_domain_set_dev_pasid_nop
};

/* Single static blocking domain; mock_ops uses it as default and blocked domain. */
static struct iommu_domain mock_blocking_domain = {
        .type = IOMMU_DOMAIN_BLOCKED,
        .ops = &mock_blocking_ops,
};

/*
 * hw_info op: report a struct iommu_test_hw_info carrying
 * IOMMU_HW_INFO_SELFTEST_REGVAL. Accepts the DEFAULT or SELFTEST type and
 * always reports SELFTEST back. The caller owns the returned allocation.
 */
static void *mock_domain_hw_info(struct device *dev, u32 *length,
                                 enum iommu_hw_info_type *type)
{
        struct iommu_test_hw_info *hw;
        bool supported = *type == IOMMU_HW_INFO_TYPE_DEFAULT ||
                         *type == IOMMU_HW_INFO_TYPE_SELFTEST;

        if (!supported)
                return ERR_PTR(-EOPNOTSUPP);

        hw = kzalloc_obj(*hw);
        if (!hw)
                return ERR_PTR(-ENOMEM);

        hw->test_reg = IOMMU_HW_INFO_SELFTEST_REGVAL;
        *type = IOMMU_HW_INFO_TYPE_SELFTEST;
        *length = sizeof(*hw);

        return hw;
}

/*
 * set_dirty_tracking op: toggle MOCK_DIRTY_TRACK in the domain flags.
 * Enabling requires a domain allocated with dirty ops (-EINVAL otherwise);
 * requesting the current state is a no-op.
 */
static int mock_domain_set_dirty_tracking(struct iommu_domain *domain,
                                          bool enable)
{
        struct mock_iommu_domain *mock = to_mock_domain(domain);
        bool tracking = mock->flags & MOCK_DIRTY_TRACK;

        if (enable && !domain->dirty_ops)
                return -EINVAL;

        if (enable == tracking)
                return 0;

        if (enable)
                mock->flags |= MOCK_DIRTY_TRACK;
        else
                mock->flags &= ~MOCK_DIRTY_TRACK;

        return 0;
}

static struct mock_iommu_domain_nested *
__mock_domain_alloc_nested(const struct iommu_user_data *user_data)
{
        struct mock_iommu_domain_nested *mock_nested;
        struct iommu_hwpt_selftest user_cfg;
        int rc, i;

        if (user_data->type != IOMMU_HWPT_DATA_SELFTEST)
                return ERR_PTR(-EOPNOTSUPP);

        rc = iommu_copy_struct_from_user(&user_cfg, user_data,
                                         IOMMU_HWPT_DATA_SELFTEST, iotlb);
        if (rc)
                return ERR_PTR(rc);

        mock_nested = kzalloc_obj(*mock_nested);
        if (!mock_nested)
                return ERR_PTR(-ENOMEM);
        mock_nested->domain.ops = &domain_nested_ops;
        mock_nested->domain.type = IOMMU_DOMAIN_NESTED;
        for (i = 0; i < MOCK_NESTED_DOMAIN_IOTLB_NUM; i++)
                mock_nested->iotlb[i] = user_cfg.iotlb;
        return mock_nested;
}

/*
 * domain_alloc_nested op: allocate a selftest nested domain over a paging
 * @parent without an associated vIOMMU. Only IOMMU_HWPT_ALLOC_PASID is
 * accepted in @flags.
 *
 * Returns the new domain or an ERR_PTR(): -EOPNOTSUPP for unsupported flags,
 * -EINVAL when @parent is missing or not a paging domain.
 */
static struct iommu_domain *
mock_domain_alloc_nested(struct device *dev, struct iommu_domain *parent,
                         u32 flags, const struct iommu_user_data *user_data)
{
        struct mock_iommu_domain_nested *mock_nested;

        if (flags & ~IOMMU_HWPT_ALLOC_PASID)
                return ERR_PTR(-EOPNOTSUPP);
        /*
         * container_of() on a non-NULL @parent can never yield NULL, so the
         * checks here are the only meaningful parent validation.
         */
        if (!parent || !(parent->type & __IOMMU_DOMAIN_PAGING))
                return ERR_PTR(-EINVAL);

        mock_nested = __mock_domain_alloc_nested(user_data);
        if (IS_ERR(mock_nested))
                return ERR_CAST(mock_nested);
        return &mock_nested->domain;
}

/* free op for mock paging domains: tear down the page table, then the wrapper. */
static void mock_domain_free(struct iommu_domain *domain)
{
        struct mock_iommu_domain *mock = to_mock_domain(domain);
        struct pt_iommu *pt = &mock->iommu;

        pt_iommu_deinit(pt);
        kfree(mock);
}

/*
 * iotlb_sync op: the mock has no IOTLB to flush, so it only releases the
 * page-table pages gathered on @gather's freelist.
 */
static void mock_iotlb_sync(struct iommu_domain *domain,
                            struct iommu_iotlb_gather *gather)
{
        iommu_put_pages_list(&(gather->freelist));
}

/*
 * Domain ops for MOCK_IOMMUPT_DEFAULT: IOMMU_PT_DOMAIN_OPS() supplies the
 * page-table callbacks of the amdv1_mock generic_pt instance; the remaining
 * fields are the shared mock callbacks.
 */
static const struct iommu_domain_ops amdv1_mock_ops = {
        IOMMU_PT_DOMAIN_OPS(amdv1_mock),
        .free = mock_domain_free,
        .attach_dev = mock_domain_nop_attach,
        .set_dev_pasid = mock_domain_set_dev_pasid_nop,
        .iotlb_sync = &mock_iotlb_sync,
};

/*
 * Domain ops for MOCK_IOMMUPT_HUGE. The initializers match amdv1_mock_ops; the
 * huge-page behaviour comes from the pgsize_bitmap that
 * mock_domain_alloc_pgtable() sets for this type.
 */
static const struct iommu_domain_ops amdv1_mock_huge_ops = {
        IOMMU_PT_DOMAIN_OPS(amdv1_mock),
        .free = mock_domain_free,
        .attach_dev = mock_domain_nop_attach,
        .set_dev_pasid = mock_domain_set_dev_pasid_nop,
        .iotlb_sync = &mock_iotlb_sync,
};
/*
 * NOTE(review): no matching #define of this name is visible here; presumably
 * it cancels an override of the map_pages callback set up elsewhere. Confirm.
 */
#undef pt_iommu_amdv1_mock_map_pages

/* Dirty-tracking ops installed for amdv1_mock domains on IOMMU_HWPT_ALLOC_DIRTY_TRACKING. */
static const struct iommu_dirty_ops amdv1_mock_dirty_ops = {
        IOMMU_PT_DIRTY_OPS(amdv1_mock),
        .set_dirty_tracking = mock_domain_set_dirty_tracking,
};

/* Domain ops for MOCK_IOMMUPT_AMDV1, backed by the non-mock amdv1 generic_pt instance. */
static const struct iommu_domain_ops amdv1_ops = {
        IOMMU_PT_DOMAIN_OPS(amdv1),
        .free = mock_domain_free,
        .attach_dev = mock_domain_nop_attach,
        .set_dev_pasid = mock_domain_set_dev_pasid_nop,
        .iotlb_sync = &mock_iotlb_sync,
};

/* Dirty-tracking ops installed for amdv1 domains on IOMMU_HWPT_ALLOC_DIRTY_TRACKING. */
static const struct iommu_dirty_ops amdv1_dirty_ops = {
        IOMMU_PT_DIRTY_OPS(amdv1),
        .set_dirty_tracking = mock_domain_set_dirty_tracking,
};

/*
 * Allocate a mock paging domain whose page table is selected by
 * user_cfg->pagetable_type:
 *   MOCK_IOMMUPT_DEFAULT / MOCK_IOMMUPT_HUGE - amdv1_mock page table
 *       (hw_max_vasz_lg2 56, hw_max_oasz_lg2 51, starting level 2); HUGE
 *       restricts pgsize_bitmap to MOCK_HUGE_PAGE_SIZE | PAGE_SIZE, and
 *       DEFAULT narrows the aperture to MOCK_APERTURE_START..LAST.
 *   MOCK_IOMMUPT_AMDV1 - amdv1 page table (hw_max_vasz_lg2 64,
 *       hw_max_oasz_lg2 52) with the DYNAMIC_TOP, AMDV1_ENCRYPT_TABLES and
 *       AMDV1_FORCE_COHERENCE features.
 * IOMMU_HWPT_ALLOC_DIRTY_TRACKING in @flags installs the matching dirty ops.
 *
 * Returns the new domain or an ERR_PTR(); -EOPNOTSUPP for an unknown type.
 */
static struct mock_iommu_domain *
mock_domain_alloc_pgtable(struct device *dev,
                          const struct iommu_hwpt_selftest *user_cfg, u32 flags)
{
        struct mock_iommu_domain *mock;
        int rc;

        mock = kzalloc_obj(*mock);
        if (!mock)
                return ERR_PTR(-ENOMEM);
        mock->domain.type = IOMMU_DOMAIN_UNMANAGED;

        mock->amdv1.iommu.nid = NUMA_NO_NODE;

        switch (user_cfg->pagetable_type) {
        case MOCK_IOMMUPT_DEFAULT:
        case MOCK_IOMMUPT_HUGE: {
                struct pt_iommu_amdv1_cfg cfg = {};

                /* The mock version has a 2k page size */
                cfg.common.hw_max_vasz_lg2 = 56;
                cfg.common.hw_max_oasz_lg2 = 51;
                cfg.starting_level = 2;
                /* Ops are selected before the page-table init call below. */
                if (user_cfg->pagetable_type == MOCK_IOMMUPT_HUGE)
                        mock->domain.ops = &amdv1_mock_huge_ops;
                else
                        mock->domain.ops = &amdv1_mock_ops;
                rc = pt_iommu_amdv1_mock_init(&mock->amdv1, &cfg, GFP_KERNEL);
                if (rc)
                        goto err_free;

                /*
                 * In huge mode userspace should only provide huge pages, we
                 * have to include PAGE_SIZE for the domain to be accepted by
                 * iommufd.
                 */
                if (user_cfg->pagetable_type == MOCK_IOMMUPT_HUGE)
                        mock->domain.pgsize_bitmap = MOCK_HUGE_PAGE_SIZE |
                                                     PAGE_SIZE;
                if (flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING)
                        mock->domain.dirty_ops = &amdv1_mock_dirty_ops;
                break;
        }

        case MOCK_IOMMUPT_AMDV1: {
                struct pt_iommu_amdv1_cfg cfg = {};

                cfg.common.hw_max_vasz_lg2 = 64;
                cfg.common.hw_max_oasz_lg2 = 52;
                cfg.common.features = BIT(PT_FEAT_DYNAMIC_TOP) |
                                      BIT(PT_FEAT_AMDV1_ENCRYPT_TABLES) |
                                      BIT(PT_FEAT_AMDV1_FORCE_COHERENCE);
                cfg.starting_level = 2;
                mock->domain.ops = &amdv1_ops;
                rc = pt_iommu_amdv1_init(&mock->amdv1, &cfg, GFP_KERNEL);
                if (rc)
                        goto err_free;
                if (flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING)
                        mock->domain.dirty_ops = &amdv1_dirty_ops;
                break;
        }
        default:
                rc = -EOPNOTSUPP;
                goto err_free;
        }

        /*
         * Override the real aperture to the MOCK aperture for test purposes.
         */
        if (user_cfg->pagetable_type == MOCK_IOMMUPT_DEFAULT) {
                WARN_ON(mock->domain.geometry.aperture_start != 0);
                WARN_ON(mock->domain.geometry.aperture_end < MOCK_APERTURE_LAST);

                mock->domain.geometry.aperture_start = MOCK_APERTURE_START;
                mock->domain.geometry.aperture_end = MOCK_APERTURE_LAST;
        }

        return mock;
err_free:
        /* NOTE(review): assumes a failed *_init() leaves nothing to deinit. */
        kfree(mock);
        return ERR_PTR(rc);
}

/*
 * domain_alloc_paging_flags op: validate @flags and the optional selftest
 * user data, then build a paging domain via mock_domain_alloc_pgtable().
 * Dirty tracking is refused for MOCK_FLAGS_DEVICE_NO_DIRTY devices.
 * Without user data the zeroed default iommu_hwpt_selftest config is used.
 */
static struct iommu_domain *
mock_domain_alloc_paging_flags(struct device *dev, u32 flags,
                               const struct iommu_user_data *user_data)
{
        const u32 PAGING_FLAGS = IOMMU_HWPT_ALLOC_DIRTY_TRACKING |
                                 IOMMU_HWPT_ALLOC_NEST_PARENT |
                                 IOMMU_HWPT_ALLOC_PASID;
        struct mock_dev *mdev = to_mock_dev(dev);
        struct iommu_hwpt_selftest user_cfg = {};
        struct mock_iommu_domain *mock;

        if (flags & ~PAGING_FLAGS)
                return ERR_PTR(-EOPNOTSUPP);
        if ((flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING) &&
            (mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY))
                return ERR_PTR(-EOPNOTSUPP);

        if (user_data) {
                int rc;

                if (user_data->type != IOMMU_HWPT_DATA_SELFTEST &&
                    user_data->type != IOMMU_HWPT_DATA_NONE)
                        return ERR_PTR(-EOPNOTSUPP);

                rc = iommu_copy_struct_from_user(&user_cfg, user_data,
                                                 IOMMU_HWPT_DATA_SELFTEST,
                                                 iotlb);
                if (rc)
                        return ERR_PTR(rc);
        }

        mock = mock_domain_alloc_pgtable(dev, &user_cfg, flags);
        if (IS_ERR(mock))
                return ERR_CAST(mock);
        return &mock->domain;
}

/*
 * capable op: cache coherency is always reported; dirty tracking unless the
 * device was created with MOCK_FLAGS_DEVICE_NO_DIRTY; nothing else.
 */
static bool mock_domain_capable(struct device *dev, enum iommu_cap cap)
{
        struct mock_dev *mdev = to_mock_dev(dev);

        if (cap == IOMMU_CAP_CACHE_COHERENCY)
                return true;
        if (cap == IOMMU_CAP_DIRTY_TRACKING)
                return !(mdev->flags & MOCK_FLAGS_DEVICE_NO_DIRTY);

        return false;
}

/* IOPF queue for mock devices; mock_dev_enable_iopf() fails with -ENODEV while it is NULL. */
static struct iopf_queue *mock_iommu_iopf_queue;

/*
 * The single mock IOMMU instance. @users is incremented per vIOMMU in
 * mock_viommu_init() and dropped in mock_viommu_destroy(), which signals
 * @complete when the count reaches zero.
 */
static struct mock_iommu_device {
        struct iommu_device iommu_dev;
        struct completion complete;
        refcount_t users;
} mock_iommu;

static struct iommu_device *mock_probe_device(struct device *dev)
{
        if (dev->bus != &iommufd_mock_bus_type.bus)
                return ERR_PTR(-ENODEV);
        return &mock_iommu.iommu_dev;
}

/* page_response op. Empty: the mock IOMMU does nothing with page responses. */
static void mock_domain_page_response(struct device *dev, struct iopf_fault *evt,
                                      struct iommu_page_response *msg)
{
}

/*
 * Take an IOPF reference for @dev on behalf of @domain. Domains without an
 * iopf_handler need nothing. The first reference adds the device to
 * mock_iommu_iopf_queue; later ones only bump the count.
 *
 * Returns 0, -ENODEV if no queue exists, or the iopf_queue_add_device() error.
 */
static int mock_dev_enable_iopf(struct device *dev, struct iommu_domain *domain)
{
        struct mock_dev *mdev = to_mock_dev(dev);
        int err;

        if (!domain || !domain->iopf_handler)
                return 0;
        if (!mock_iommu_iopf_queue)
                return -ENODEV;

        if (mdev->iopf_refcount == 0) {
                err = iopf_queue_add_device(mock_iommu_iopf_queue, dev);
                if (err)
                        return err;
        }
        mdev->iopf_refcount++;

        return 0;
}

/*
 * Drop the IOPF reference taken by mock_dev_enable_iopf() for @domain.
 * Domains without an iopf_handler hold no reference. The last reference
 * removes the device from mock_iommu_iopf_queue.
 */
static void mock_dev_disable_iopf(struct device *dev, struct iommu_domain *domain)
{
        struct mock_dev *mdev = to_mock_dev(dev);

        if (!domain || !domain->iopf_handler)
                return;

        /*
         * An unbalanced disable would wrap the unsigned count, desynchronising
         * it from the device's queue membership; flag it and bail out.
         */
        if (WARN_ON(!mdev->iopf_refcount))
                return;

        if (--mdev->iopf_refcount)
                return;

        iopf_queue_remove_device(mock_iommu_iopf_queue, dev);
}

/*
 * viommu destroy op: drop this vIOMMU's reference on the mock IOMMU
 * (signalling mock_iommu->complete on the last one), remove the mmap region
 * and release what mock_viommu_init() allocated.
 */
static void mock_viommu_destroy(struct iommufd_viommu *viommu)
{
        struct mock_iommu_device *mock_iommu = container_of(
                viommu->iommu_dev, struct mock_iommu_device, iommu_dev);
        struct mock_viommu *mock_viommu = to_mock_viommu(viommu);

        if (refcount_dec_and_test(&mock_iommu->users))
                complete(&mock_iommu->complete);
        if (mock_viommu->mmap_offset)
                iommufd_viommu_destroy_mmap(&mock_viommu->core,
                                            mock_viommu->mmap_offset);
        /*
         * mock_viommu_init() allocates an order-1 (two page) block, so free it
         * with the same order; free_page() would leak the second page.
         * free_pages() ignores a NULL/0 address (no user data case).
         */
        free_pages((unsigned long)mock_viommu->page, 1);
        mutex_destroy(&mock_viommu->queue_mutex);

        /* iommufd core frees mock_viommu and viommu */
}

/*
 * viommu alloc_domain_nested op: like mock_domain_alloc_nested() but bound
 * to @viommu. Only IOMMU_HWPT_ALLOC_PASID is accepted in @flags.
 */
static struct iommu_domain *
mock_viommu_alloc_domain_nested(struct iommufd_viommu *viommu, u32 flags,
                                const struct iommu_user_data *user_data)
{
        struct mock_iommu_domain_nested *nested;

        if (flags & ~IOMMU_HWPT_ALLOC_PASID)
                return ERR_PTR(-EOPNOTSUPP);

        nested = __mock_domain_alloc_nested(user_data);
        if (IS_ERR(nested))
                return ERR_CAST(nested);

        /* Attach uses this to look up the device's vdev_id in the vIOMMU. */
        nested->mock_viommu = to_mock_viommu(viommu);

        return &nested->domain;
}

/*
 * viommu cache_invalidate op: for each selftest command, zero one entry (or,
 * with IOMMU_TEST_INVALIDATE_FLAG_ALL, every entry) of the fake cache of the
 * device found by vdev_id. On return array->entry_num holds the number of
 * commands handled before the first error.
 */
static int mock_viommu_cache_invalidate(struct iommufd_viommu *viommu,
                                        struct iommu_user_data_array *array)
{
        struct iommu_viommu_invalidate_selftest *cmds, *cur, *end;
        int rc;

        /* A zero-length array of the right type only validates the type. */
        if (!array->entry_num &&
            array->type == IOMMU_VIOMMU_INVALIDATE_DATA_SELFTEST)
                return 0;

        cmds = kzalloc_objs(*cmds, array->entry_num);
        if (!cmds)
                return -ENOMEM;
        cur = cmds;
        end = cmds + array->entry_num;

        static_assert(sizeof(*cmds) == 3 * sizeof(u32));
        rc = iommu_copy_struct_from_full_user_array(
                cmds, sizeof(*cmds), array,
                IOMMU_VIOMMU_INVALIDATE_DATA_SELFTEST);
        if (rc)
                goto out;

        for (; cur != end; cur++) {
                struct mock_dev *mdev;
                struct device *dev;
                int idx;

                if (cur->flags & ~IOMMU_TEST_INVALIDATE_FLAG_ALL) {
                        rc = -EOPNOTSUPP;
                        break;
                }
                if (cur->cache_id > MOCK_DEV_CACHE_ID_MAX) {
                        rc = -EINVAL;
                        break;
                }

                /* Hold the vdevs lock while the found device's cache is touched. */
                xa_lock(&viommu->vdevs);
                dev = iommufd_viommu_find_dev(viommu,
                                              (unsigned long)cur->vdev_id);
                if (!dev) {
                        xa_unlock(&viommu->vdevs);
                        rc = -EINVAL;
                        break;
                }
                mdev = container_of(dev, struct mock_dev, dev);

                if (!(cur->flags & IOMMU_TEST_INVALIDATE_FLAG_ALL)) {
                        mdev->cache[cur->cache_id] = 0;
                } else {
                        /* Invalidate all cache entries and ignore cache_id */
                        for (idx = 0; idx < MOCK_DEV_CACHE_NUM; idx++)
                                mdev->cache[idx] = 0;
                }
                xa_unlock(&viommu->vdevs);
        }
out:
        array->entry_num = cur - cmds;
        kfree(cmds);
        return rc;
}

/* get_hw_queue_size op: mock_hw_queue's size for the selftest type, else 0. */
static size_t mock_viommu_get_hw_queue_size(struct iommufd_viommu *viommu,
                                            enum iommu_hw_queue_type queue_type)
{
        if (queue_type == IOMMU_HW_QUEUE_TYPE_SELFTEST)
                return HW_QUEUE_STRUCT_SIZE(struct mock_hw_queue, core);
        return 0;
}

/*
 * HW queue destroy callback: free the queue's slot in the owning vIOMMU and
 * drop the dependency on the previous queue, all under queue_mutex.
 */
static void mock_hw_queue_destroy(struct iommufd_hw_queue *hw_queue)
{
        struct mock_hw_queue *q = to_mock_hw_queue(hw_queue);
        struct mock_viommu *owner = q->mock_viommu;

        mutex_lock(&owner->queue_mutex);
        owner->hw_queue[q->index] = NULL;
        if (q->prev)
                iommufd_hw_queue_undepend(q, q->prev, core);
        mutex_unlock(&owner->queue_mutex);
}

/*
 * hw_queue_init_phys op; also tests iommufd_hw_queue_depend/undepend().
 *
 * A queue at @index > 0 requires the queue at @index - 1 to exist and takes
 * a dependency on it, which mock_hw_queue_destroy() drops again.
 *
 * Returns 0, -EINVAL for an out-of-range index, -EEXIST if the slot is in use,
 * -EIO if the previous queue is missing, -EFAULT if @base_addr_pa does not
 * match the parent domain's translation of base_addr, or a depend() error.
 */
static int mock_hw_queue_init_phys(struct iommufd_hw_queue *hw_queue, u32 index,
                                   phys_addr_t base_addr_pa)
{
        struct mock_viommu *mock_viommu = to_mock_viommu(hw_queue->viommu);
        struct mock_hw_queue *mock_hw_queue = to_mock_hw_queue(hw_queue);
        struct mock_hw_queue *prev = NULL;
        int rc = 0;

        if (index >= IOMMU_TEST_HW_QUEUE_MAX)
                return -EINVAL;

        mutex_lock(&mock_viommu->queue_mutex);

        if (mock_viommu->hw_queue[index]) {
                rc = -EEXIST;
                goto unlock;
        }

        if (index) {
                prev = mock_viommu->hw_queue[index - 1];
                if (!prev) {
                        rc = -EIO;
                        goto unlock;
                }
        }

        /*
         * Test to catch a kernel bug if the core converted the physical address
         * incorrectly. Let mock_domain_iova_to_phys() WARN_ON if it fails.
         */
        if (base_addr_pa != iommu_iova_to_phys(&mock_viommu->s2_parent->domain,
                                               hw_queue->base_addr)) {
                rc = -EFAULT;
                goto unlock;
        }

        if (prev) {
                rc = iommufd_hw_queue_depend(mock_hw_queue, prev, core);
                if (rc)
                        goto unlock;
        }

        mock_hw_queue->prev = prev;
        mock_hw_queue->mock_viommu = mock_viommu;
        /*
         * mock_hw_queue_destroy() clears hw_queue[mock_hw_queue->index]; record
         * the slot, otherwise destroy clears the wrong slot and leaves a
         * dangling pointer in hw_queue[index] for the next queue's prev.
         */
        mock_hw_queue->index = index;
        mock_viommu->hw_queue[index] = mock_hw_queue;

        hw_queue->destroy = &mock_hw_queue_destroy;
unlock:
        mutex_unlock(&mock_viommu->queue_mutex);
        return rc;
}

static struct iommufd_viommu_ops mock_viommu_ops = {
        .destroy = mock_viommu_destroy,
        .alloc_domain_nested = mock_viommu_alloc_domain_nested,
        .cache_invalidate = mock_viommu_cache_invalidate,
        .get_hw_queue_size = mock_viommu_get_hw_queue_size,
        .hw_queue_init_phys = mock_hw_queue_init_phys,
};

/* get_viommu_size op: mock_viommu's size for the selftest type, else 0. */
static size_t mock_get_viommu_size(struct device *dev,
                                   enum iommu_viommu_type viommu_type)
{
        if (viommu_type == IOMMU_VIOMMU_TYPE_SELFTEST)
                return VIOMMU_STRUCT_SIZE(struct mock_viommu, core);
        return 0;
}

/*
 * viommu_init op.
 *
 * With @user_data: allocate a zeroed two-page (order-1) buffer, expose it to
 * userspace via iommufd_viommu_alloc_mmap(), and loop back in_data into both
 * the first u32 of the buffer and out_data (also reporting the mmap offset
 * and length). Then take a reference on the mock IOMMU device, initialise
 * the queue mutex and record @parent_domain as the parent.
 *
 * Returns 0 or a negative errno; on error everything allocated here is
 * released before returning.
 */
static int mock_viommu_init(struct iommufd_viommu *viommu,
                            struct iommu_domain *parent_domain,
                            const struct iommu_user_data *user_data)
{
        struct mock_iommu_device *mock_iommu = container_of(
                viommu->iommu_dev, struct mock_iommu_device, iommu_dev);
        struct mock_viommu *mock_viommu = to_mock_viommu(viommu);
        struct iommu_viommu_selftest data;
        int rc;

        if (user_data) {
                rc = iommu_copy_struct_from_user(
                        &data, user_data, IOMMU_VIOMMU_TYPE_SELFTEST, out_data);
                if (rc)
                        return rc;

                /* Allocate two pages */
                mock_viommu->page =
                        (u32 *)__get_free_pages(GFP_KERNEL | __GFP_ZERO, 1);
                if (!mock_viommu->page)
                        return -ENOMEM;

                rc = iommufd_viommu_alloc_mmap(&mock_viommu->core,
                                               __pa(mock_viommu->page),
                                               PAGE_SIZE * 2,
                                               &mock_viommu->mmap_offset);
                if (rc)
                        goto err_free_page;

                /* For loopback tests on both the page and out_data */
                *mock_viommu->page = data.in_data;
                data.out_data = data.in_data;
                data.out_mmap_length = PAGE_SIZE * 2;
                data.out_mmap_offset = mock_viommu->mmap_offset;
                rc = iommu_copy_struct_to_user(
                        user_data, &data, IOMMU_VIOMMU_TYPE_SELFTEST, out_data);
                if (rc)
                        goto err_destroy_mmap;
        }

        refcount_inc(&mock_iommu->users);
        mutex_init(&mock_viommu->queue_mutex);
        mock_viommu->s2_parent = to_mock_domain(parent_domain);

        viommu->ops = &mock_viommu_ops;
        return 0;

err_destroy_mmap:
        iommufd_viommu_destroy_mmap(&mock_viommu->core,
                                    mock_viommu->mmap_offset);
err_free_page:
        /* Free with the same order (1) as the __get_free_pages() above. */
        free_pages((unsigned long)mock_viommu->page, 1);
        return rc;
}

/*
 * iommu_ops of the mock IOMMU. mock_probe_device() restricts it to devices on
 * the iommufd_mock bus; paging, nested and vIOMMU objects are allocated by
 * the mock_* callbacks above.
 */
static const struct iommu_ops mock_ops = {
        /*
         * IOMMU_DOMAIN_BLOCKED cannot be returned from def_domain_type()
         * because it is zero.
         */
        .default_domain = &mock_blocking_domain,
        .blocked_domain = &mock_blocking_domain,
        .owner = THIS_MODULE,
        .hw_info = mock_domain_hw_info,
        .domain_alloc_paging_flags = mock_domain_alloc_paging_flags,
        .domain_alloc_nested = mock_domain_alloc_nested,
        .capable = mock_domain_capable,
        .device_group = generic_device_group,
        .probe_device = mock_probe_device,
        .page_response = mock_domain_page_response,
        .user_pasid_table = true,
        .get_viommu_size = mock_get_viommu_size,
        .viommu_init = mock_viommu_init,
};

/* free op for mock nested domains. */
static void mock_domain_free_nested(struct iommu_domain *domain)
{
        struct mock_iommu_domain_nested *nested = to_mock_nested(domain);

        kfree(nested);
}

/*
 * cache_invalidate_user op for mock nested domains: for each selftest
 * request, zero one fake IOTLB entry, or all of them with
 * IOMMU_TEST_INVALIDATE_FLAG_ALL (iotlb_id is still range-checked).
 *
 * On return array->entry_num holds the number of requests handled before the
 * first error. Returns 0, -EINVAL for a wrong array type or out-of-range
 * iotlb_id, -EOPNOTSUPP for unknown flags, or a copy error.
 */
static int
mock_domain_cache_invalidate_user(struct iommu_domain *domain,
                                  struct iommu_user_data_array *array)
{
        struct mock_iommu_domain_nested *mock_nested = to_mock_nested(domain);
        struct iommu_hwpt_invalidate_selftest inv;
        /*
         * Doubles as the array index. It is u32 like entry_num so that a huge
         * user-supplied count cannot overflow a signed counter.
         */
        u32 processed = 0;
        int rc = 0;
        int j;

        if (array->type != IOMMU_HWPT_INVALIDATE_DATA_SELFTEST) {
                rc = -EINVAL;
                goto out;
        }

        for (; processed < array->entry_num; processed++) {
                rc = iommu_copy_struct_from_user_array(&inv, array,
                                                       IOMMU_HWPT_INVALIDATE_DATA_SELFTEST,
                                                       processed, iotlb_id);
                if (rc)
                        break;

                if (inv.flags & ~IOMMU_TEST_INVALIDATE_FLAG_ALL) {
                        rc = -EOPNOTSUPP;
                        break;
                }

                if (inv.iotlb_id > MOCK_NESTED_DOMAIN_IOTLB_ID_MAX) {
                        rc = -EINVAL;
                        break;
                }

                if (inv.flags & IOMMU_TEST_INVALIDATE_FLAG_ALL) {
                        /* Invalidate all mock iotlb entries and ignore iotlb_id */
                        for (j = 0; j < MOCK_NESTED_DOMAIN_IOTLB_NUM; j++)
                                mock_nested->iotlb[j] = 0;
                } else {
                        mock_nested->iotlb[inv.iotlb_id] = 0;
                }
        }

out:
        array->entry_num = processed;
        return rc;
}

/*
 * Ops for mock nested domains. NOTE(review): could be const, but the forward
 * declaration near the top of the file would have to change with it.
 */
static struct iommu_domain_ops domain_nested_ops = {
        .free = mock_domain_free_nested,
        .attach_dev = mock_domain_nop_attach,
        .cache_invalidate_user = mock_domain_cache_invalidate_user,
        .set_dev_pasid = mock_domain_set_dev_pasid_nop,
};

/*
 * Look up object @mockpt_id of type @hwpt_type and return the hw_pagetable
 * embedding it, or an ERR_PTR. On success the caller owns a reference and
 * must drop it with iommufd_put_object().
 */
static inline struct iommufd_hw_pagetable *
__get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id, u32 hwpt_type)
{
        struct iommufd_object *pt_obj =
                iommufd_get_object(ucmd->ictx, mockpt_id, hwpt_type);

        if (!IS_ERR(pt_obj))
                return container_of(pt_obj, struct iommufd_hw_pagetable, obj);
        return ERR_CAST(pt_obj);
}

/*
 * Look up a HWPT_PAGING object and require that it wraps an UNMANAGED domain
 * owned by mock_ops. On success *@mock points at the mock domain and the
 * returned hwpt holds a reference the caller must put. Otherwise returns an
 * ERR_PTR and *@mock is left untouched.
 */
static inline struct iommufd_hw_pagetable *
get_md_pagetable(struct iommufd_ucmd *ucmd, u32 mockpt_id,
                 struct mock_iommu_domain **mock)
{
        struct iommufd_hw_pagetable *hwpt =
                __get_md_pagetable(ucmd, mockpt_id, IOMMUFD_OBJ_HWPT_PAGING);
        bool is_mock_domain;

        if (IS_ERR(hwpt))
                return hwpt;

        is_mock_domain = hwpt->domain->type == IOMMU_DOMAIN_UNMANAGED &&
                         hwpt->domain->owner == &mock_ops;
        if (!is_mock_domain) {
                iommufd_put_object(ucmd->ictx, &hwpt->obj);
                return ERR_PTR(-EINVAL);
        }

        *mock = to_mock_domain(hwpt->domain);
        return hwpt;
}

/*
 * Look up a HWPT_NESTED object and require that it is a mock nested domain
 * (NESTED type using domain_nested_ops). On success *@mock_nested is set and
 * the returned hwpt holds a reference the caller must put.
 */
static inline struct iommufd_hw_pagetable *
get_md_pagetable_nested(struct iommufd_ucmd *ucmd, u32 mockpt_id,
                        struct mock_iommu_domain_nested **mock_nested)
{
        struct iommufd_hw_pagetable *hwpt =
                __get_md_pagetable(ucmd, mockpt_id, IOMMUFD_OBJ_HWPT_NESTED);
        struct iommu_domain *domain;

        if (IS_ERR(hwpt))
                return hwpt;

        domain = hwpt->domain;
        if (domain->type == IOMMU_DOMAIN_NESTED &&
            domain->ops == &domain_nested_ops) {
                *mock_nested = to_mock_nested(domain);
                return hwpt;
        }

        iommufd_put_object(ucmd->ictx, &hwpt->obj);
        return ERR_PTR(-EINVAL);
}

/* Device release callback: give the ID back to mock_dev_ida and free. */
static void mock_dev_release(struct device *dev)
{
        struct mock_dev *mock = to_mock_dev(dev);

        ida_free(&mock_dev_ida, mock->id);
        kfree(mock);
}

static struct mock_dev *mock_dev_create(unsigned long dev_flags)
{
        struct property_entry prop[] = {
                PROPERTY_ENTRY_U32("pasid-num-bits", 0),
                {},
        };
        const u32 valid_flags = MOCK_FLAGS_DEVICE_NO_DIRTY |
                                MOCK_FLAGS_DEVICE_PASID;
        struct mock_dev *mdev;
        int rc, i;

        if (dev_flags & ~valid_flags)
                return ERR_PTR(-EINVAL);

        mdev = kzalloc_obj(*mdev);
        if (!mdev)
                return ERR_PTR(-ENOMEM);

        init_rwsem(&mdev->viommu_rwsem);
        device_initialize(&mdev->dev);
        mdev->flags = dev_flags;
        mdev->dev.release = mock_dev_release;
        mdev->dev.bus = &iommufd_mock_bus_type.bus;
        for (i = 0; i < MOCK_DEV_CACHE_NUM; i++)
                mdev->cache[i] = IOMMU_TEST_DEV_CACHE_DEFAULT;

        rc = ida_alloc(&mock_dev_ida, GFP_KERNEL);
        if (rc < 0)
                goto err_put;
        mdev->id = rc;

        rc = dev_set_name(&mdev->dev, "iommufd_mock%u", mdev->id);
        if (rc)
                goto err_put;

        if (dev_flags & MOCK_FLAGS_DEVICE_PASID)
                prop[0] = PROPERTY_ENTRY_U32("pasid-num-bits", MOCK_PASID_WIDTH);

        rc = device_create_managed_software_node(&mdev->dev, prop, NULL);
        if (rc) {
                dev_err(&mdev->dev, "add pasid-num-bits property failed, rc: %d", rc);
                goto err_put;
        }

        rc = iommu_mock_device_add(&mdev->dev, &mock_iommu.iommu_dev);
        if (rc)
                goto err_put;
        return mdev;

err_put:
        put_device(&mdev->dev);
        return ERR_PTR(rc);
}

/*
 * Unregister a mock device created by mock_dev_create(); the memory is freed
 * by mock_dev_release() once the last reference is dropped.
 */
static void mock_dev_destroy(struct mock_dev *mdev)
{
        struct device *dev = &mdev->dev;

        device_unregister(dev);
}

/*
 * Mock devices are identified by their release callback, which is installed
 * by mock_dev_create().
 */
bool iommufd_selftest_is_mock_dev(struct device *dev)
{
        return mock_dev_release == dev->release;
}

/*
 * Create an hw_pagetable with the mock domain so we can test the domain ops.
 *
 * Allocates a TYPE_IDEV selftest object and creates a mock device for it. The
 * device is bound to the iommufd context and attached (IOMMU_NO_PASID) to the
 * object named by cmd->id. The hwpt id, selftest object id and device id are
 * then reported back and the selftest object is finalized. Every failure
 * unwinds the completed steps in reverse order.
 */
static int iommufd_test_mock_domain(struct iommufd_ucmd *ucmd,
                                    struct iommu_test_cmd *cmd)
{
        struct iommufd_device *idev;
        struct selftest_obj *sobj;
        u32 pt_id = cmd->id;
        u32 dev_flags = 0;
        u32 idev_id;
        int rc;

        sobj = iommufd_object_alloc(ucmd->ictx, sobj, IOMMUFD_OBJ_SELFTEST);
        if (IS_ERR(sobj))
                return PTR_ERR(sobj);

        sobj->idev.ictx = ucmd->ictx;
        sobj->type = TYPE_IDEV;

        /* Only the _FLAGS variant of the op carries device flags */
        if (cmd->op == IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS)
                dev_flags = cmd->mock_domain_flags.dev_flags;

        sobj->idev.mock_dev = mock_dev_create(dev_flags);
        if (IS_ERR(sobj->idev.mock_dev)) {
                rc = PTR_ERR(sobj->idev.mock_dev);
                goto out_sobj;
        }

        idev = iommufd_device_bind(ucmd->ictx, &sobj->idev.mock_dev->dev,
                                   &idev_id);
        if (IS_ERR(idev)) {
                rc = PTR_ERR(idev);
                goto out_mdev;
        }
        sobj->idev.idev = idev;

        /*
         * pt_id is passed by pointer; its value afterwards is what gets
         * reported to userspace as out_hwpt_id.
         */
        rc = iommufd_device_attach(idev, IOMMU_NO_PASID, &pt_id);
        if (rc)
                goto out_unbind;

        /* Userspace must destroy the device_id to destroy the object */
        cmd->mock_domain.out_hwpt_id = pt_id;
        cmd->mock_domain.out_stdev_id = sobj->obj.id;
        cmd->mock_domain.out_idev_id = idev_id;
        rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
        if (rc)
                goto out_detach;
        /* Only publish the object once nothing else can fail */
        iommufd_object_finalize(ucmd->ictx, &sobj->obj);
        return 0;

out_detach:
        iommufd_device_detach(idev, IOMMU_NO_PASID);
out_unbind:
        iommufd_device_unbind(idev);
out_mdev:
        mock_dev_destroy(sobj->idev.mock_dev);
out_sobj:
        iommufd_object_abort(ucmd->ictx, &sobj->obj);
        return rc;
}

/*
 * Look up a TYPE_IDEV selftest object by @id. Returns it with a reference
 * held (drop with iommufd_put_object()), or an ERR_PTR: the lookup error, or
 * -EINVAL when the object is not a TYPE_IDEV.
 */
static struct selftest_obj *
iommufd_test_get_selftest_obj(struct iommufd_ctx *ictx, u32 id)
{
        struct iommufd_object *obj;
        struct selftest_obj *sobj;

        /*
         * Prefer to use the OBJ_SELFTEST because the destroy_rwsem will ensure
         * it doesn't race with detach, which is not allowed.
         */
        obj = iommufd_get_object(ictx, id, IOMMUFD_OBJ_SELFTEST);
        if (IS_ERR(obj))
                return ERR_CAST(obj);

        sobj = to_selftest_obj(obj);
        if (sobj->type == TYPE_IDEV)
                return sobj;

        iommufd_put_object(ictx, obj);
        return ERR_PTR(-EINVAL);
}

/*
 * Replace the mock domain with a manually allocated hw_pagetable.
 *
 * Replaces the IOMMU_NO_PASID attachment of selftest device @device_id with
 * @pt_id, then reports the resulting pt_id back in mock_domain_replace.pt_id.
 */
static int iommufd_test_mock_domain_replace(struct iommufd_ucmd *ucmd,
                                            unsigned int device_id, u32 pt_id,
                                            struct iommu_test_cmd *cmd)
{
        struct selftest_obj *sobj;
        int rc;

        sobj = iommufd_test_get_selftest_obj(ucmd->ictx, device_id);
        if (IS_ERR(sobj))
                return PTR_ERR(sobj);

        rc = iommufd_device_replace(sobj->idev.idev, IOMMU_NO_PASID, &pt_id);
        if (!rc) {
                cmd->mock_domain_replace.pt_id = pt_id;
                rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
        }

        iommufd_put_object(ucmd->ictx, &sobj->obj);
        return rc;
}

/*
 * Add an additional reserved IOVA range [@start, @start + @length - 1] to the
 * IOAS @mockpt_id. An empty range is -EINVAL; a range that wraps is
 * -EOVERFLOW.
 */
static int iommufd_test_add_reserved(struct iommufd_ucmd *ucmd,
                                     unsigned int mockpt_id,
                                     unsigned long start, size_t length)
{
        struct iommufd_ioas *ioas;
        struct io_pagetable *iopt;
        unsigned long last;
        int rc;

        if (!length)
                return -EINVAL;
        if (check_add_overflow(start, length - 1, &last))
                return -EOVERFLOW;

        ioas = iommufd_get_ioas(ucmd->ictx, mockpt_id);
        if (IS_ERR(ioas))
                return PTR_ERR(ioas);

        iopt = &ioas->iopt;
        down_write(&iopt->iova_rwsem);
        rc = iopt_reserve_iova(iopt, start, last, NULL);
        up_write(&iopt->iova_rwsem);

        iommufd_put_object(ucmd->ictx, &ioas->obj);
        return rc;
}

/*
 * Check that every pfn under each iova matches the pfn under a user VA.
 *
 * Walks [@iova, @iova + @length) of mock paging domain @mockpt_id in steps of
 * the domain's smallest page size. At each step it compares iova_to_phys()
 * with the physical address backing the matching byte of @uptr.
 *
 * Returns 0 when everything matches. Returns -EINVAL on misalignment, user
 * range overflow or a mismatch, or the error from the hwpt lookup or
 * get_user_pages_fast().
 */
static int iommufd_test_md_check_pa(struct iommufd_ucmd *ucmd,
                                    unsigned int mockpt_id, unsigned long iova,
                                    size_t length, void __user *uptr)
{
        struct iommufd_hw_pagetable *hwpt;
        struct mock_iommu_domain *mock;
        unsigned int page_size;
        uintptr_t end;
        int rc;

        hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
        if (IS_ERR(hwpt))
                return PTR_ERR(hwpt);

        page_size = 1 << __ffs(mock->domain.pgsize_bitmap);
        if (iova % page_size || length % page_size ||
            (uintptr_t)uptr % page_size ||
            check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end)) {
                rc = -EINVAL;
                goto out_put;
        }

        for (; length; length -= page_size) {
                struct page *pages[1];
                phys_addr_t io_phys;
                unsigned long pfn;
                long npages;

                npages = get_user_pages_fast((uintptr_t)uptr & PAGE_MASK, 1, 0,
                                             pages);
                if (npages < 0) {
                        rc = npages;
                        goto out_put;
                }
                if (WARN_ON(npages != 1)) {
                        rc = -EFAULT;
                        goto out_put;
                }
                pfn = page_to_pfn(pages[0]);
                put_page(pages[0]);

                /*
                 * Widen pfn before scaling. On 32-bit kernels with physical
                 * addresses above 4GiB, pfn * PAGE_SIZE overflows unsigned
                 * long and would report a false mismatch.
                 */
                io_phys = mock->domain.ops->iova_to_phys(&mock->domain, iova);
                if (io_phys != (phys_addr_t)pfn * PAGE_SIZE +
                                       ((uintptr_t)uptr % PAGE_SIZE)) {
                        rc = -EINVAL;
                        goto out_put;
                }
                iova += page_size;
                uptr += page_size;
        }
        rc = 0;

out_put:
        iommufd_put_object(ucmd->ictx, &hwpt->obj);
        return rc;
}

/*
 * Check that the page ref count matches, to look for missing pin/unpins.
 *
 * For every non-compound page backing [@uptr, @uptr + @length), the page
 * refcount divided by GUP_PIN_COUNTING_BIAS must equal @refs, otherwise -EIO.
 * Compound pages are skipped. @uptr and @length must be PAGE_SIZE aligned.
 */
static int iommufd_test_md_check_refs(struct iommufd_ucmd *ucmd,
                                      void __user *uptr, size_t length,
                                      unsigned int refs)
{
        uintptr_t end;

        if (length % PAGE_SIZE || (uintptr_t)uptr % PAGE_SIZE ||
            check_add_overflow((uintptr_t)uptr, (uintptr_t)length, &end))
                return -EINVAL;

        for (; length; length -= PAGE_SIZE, uptr += PAGE_SIZE) {
                struct page *page;
                bool mismatch = false;
                long got;

                got = get_user_pages_fast((uintptr_t)uptr, 1, 0, &page);
                if (got < 0)
                        return got;
                if (WARN_ON(got != 1))
                        return -EFAULT;

                if (!PageCompound(page)) {
                        unsigned int count = page_ref_count(page);

                        mismatch = count / GUP_PIN_COUNTING_BIAS != refs;
                }
                /* Drop the reference taken by the lookup itself */
                put_page(page);
                if (mismatch)
                        return -EIO;
        }
        return 0;
}

/*
 * Compare entry @iotlb_id of the mock IOTLB of nested domain @mockpt_id with
 * @iotlb.
 *
 * Returns 0 on match. Returns -EINVAL when @iotlb_id is out of range or the
 * entry differs, or the lookup error for @mockpt_id.
 */
static int iommufd_test_md_check_iotlb(struct iommufd_ucmd *ucmd, u32 mockpt_id,
                                       unsigned int iotlb_id, u32 iotlb)
{
        struct mock_iommu_domain_nested *mock_nested;
        struct iommufd_hw_pagetable *hwpt;
        int rc = 0;

        /* On success this already hands back the mock nested domain */
        hwpt = get_md_pagetable_nested(ucmd, mockpt_id, &mock_nested);
        if (IS_ERR(hwpt))
                return PTR_ERR(hwpt);

        if (iotlb_id > MOCK_NESTED_DOMAIN_IOTLB_ID_MAX ||
            mock_nested->iotlb[iotlb_id] != iotlb)
                rc = -EINVAL;
        iommufd_put_object(ucmd->ictx, &hwpt->obj);
        return rc;
}

/*
 * Compare entry @cache_id of mock device @idev_id's cache with @cache.
 *
 * Returns 0 on match. Returns -EINVAL when the device is not a mock device,
 * @cache_id is out of range or the entry differs, or the device lookup error.
 */
static int iommufd_test_dev_check_cache(struct iommufd_ucmd *ucmd, u32 idev_id,
                                        unsigned int cache_id, u32 cache)
{
        struct iommufd_device *idev;
        struct mock_dev *mdev;
        int rc = 0;

        idev = iommufd_get_device(ucmd, idev_id);
        if (IS_ERR(idev))
                return PTR_ERR(idev);

        /*
         * iommufd_get_device() is not limited to mock devices. Only a mock
         * device embeds struct mock_dev, so reject anything else before
         * converting, to avoid reading unrelated memory.
         */
        if (!iommufd_selftest_is_mock_dev(idev->dev)) {
                rc = -EINVAL;
                goto out_put;
        }
        mdev = to_mock_dev(idev->dev);

        if (cache_id > MOCK_DEV_CACHE_ID_MAX || mdev->cache[cache_id] != cache)
                rc = -EINVAL;
out_put:
        iommufd_put_object(ucmd->ictx, &idev->obj);
        return rc;
}

/*
 * State behind one selftest access file descriptor.
 *
 * @access: iommufd access object. Only set after iommufd_test_create_access()
 *          fully succeeds; NULL before that.
 * @file: anonymous inode file whose private_data points back here.
 * @lock: protects @items; @next_id is also only advanced under it.
 * @items: selftest_access_item ranges that are currently pinned.
 * @next_id: id given to the next item added to @items.
 * @destroying: NOTE(review): not referenced in this part of the file.
 */
struct selftest_access {
        struct iommufd_access *access;
        struct file *file;
        struct mutex lock;
        struct list_head items;
        unsigned int next_id;
        bool destroying;
};

/* One pinned IOVA range [iova, iova + length) owned by a selftest_access. */
struct selftest_access_item {
        struct list_head items_elm;
        unsigned long iova;
        size_t length;
        unsigned int id;
};

/*
 * Forward declaration: iommufd_access_get() identifies selftest access files
 * by comparing f_op against this table.
 */
static const struct file_operations iommfd_test_staccess_fops;

/*
 * Resolve @fd to the selftest_access behind it. On success the file reference
 * taken by fget() is kept, and the caller drops it with fput(staccess->file).
 * Returns -EBADFD when @fd is not open or is not a selftest access file.
 */
static struct selftest_access *iommufd_access_get(int fd)
{
        struct file *file = fget(fd);

        if (file && file->f_op == &iommfd_test_staccess_fops)
                return file->private_data;

        if (file)
                fput(file);
        return ERR_PTR(-EBADFD);
}

/*
 * Access unmap callback. Every pinned item that overlaps
 * [@iova, @iova + @length - 1] is unpinned, unlinked and freed under
 * staccess->lock.
 */
static void iommufd_test_access_unmap(void *data, unsigned long iova,
                                      unsigned long length)
{
        struct selftest_access *staccess = data;
        unsigned long iova_last = iova + length - 1;
        struct selftest_access_item *item, *next;

        mutex_lock(&staccess->lock);
        list_for_each_entry_safe(item, next, &staccess->items, items_elm) {
                unsigned long item_last = item->iova + item->length - 1;

                if (iova <= item_last && item->iova <= iova_last) {
                        list_del(&item->items_elm);
                        iommufd_access_unpin_pages(staccess->access, item->iova,
                                                   item->length);
                        kfree(item);
                }
        }
        mutex_unlock(&staccess->lock);
}

static int iommufd_test_access_item_destroy(struct iommufd_ucmd *ucmd,
                                            unsigned int access_id,
                                            unsigned int item_id)
{
        struct selftest_access_item *item;
        struct selftest_access *staccess;

        staccess = iommufd_access_get(access_id);
        if (IS_ERR(staccess))
                return PTR_ERR(staccess);

        mutex_lock(&staccess->lock);
        list_for_each_entry(item, &staccess->items, items_elm) {
                if (item->id == item_id) {
                        list_del(&item->items_elm);
                        iommufd_access_unpin_pages(staccess->access, item->iova,
                                                   item->length);
                        mutex_unlock(&staccess->lock);
                        kfree(item);
                        fput(staccess->file);
                        return 0;
                }
        }
        mutex_unlock(&staccess->lock);
        fput(staccess->file);
        return -ENOENT;
}

/*
 * File release: unpin every remaining item and destroy the access (if
 * creation got that far), then free the selftest_access itself.
 */
static int iommufd_test_staccess_release(struct inode *inode,
                                         struct file *filep)
{
        struct selftest_access *staccess = filep->private_data;
        struct iommufd_access *access = staccess->access;

        /* access stays NULL unless iommufd_test_create_access() succeeded */
        if (access) {
                iommufd_test_access_unmap(staccess, 0, ULONG_MAX);
                iommufd_access_destroy(access);
        }

        mutex_destroy(&staccess->lock);
        kfree(staccess);
        return 0;
}

/*
 * Ops for accesses created with MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES;
 * iommufd_test_access_pages() only accepts accesses using this table.
 */
static const struct iommufd_access_ops selftest_access_ops_pin = {
        .needs_pin_pages = 1,
        .unmap = iommufd_test_access_unmap,
};

/* Ops for accesses created without the pin-pages flag */
static const struct iommufd_access_ops selftest_access_ops = {
        .unmap = iommufd_test_access_unmap,
};

/* Only release is provided: closing the fd tears the access down */
static const struct file_operations iommfd_test_staccess_fops = {
        .release = iommufd_test_staccess_release,
};

/*
 * Allocate a selftest_access and the anonymous file that owns it.
 *
 * The access pointer is left NULL; the caller publishes it later. Returns the
 * new state or an ERR_PTR. Once the file exists, dropping it with fput()
 * frees everything through iommufd_test_staccess_release().
 */
static struct selftest_access *iommufd_test_alloc_access(void)
{
        struct selftest_access *staccess;
        struct file *filep;

        staccess = kzalloc_obj(*staccess, GFP_KERNEL_ACCOUNT);
        if (!staccess)
                return ERR_PTR(-ENOMEM);
        INIT_LIST_HEAD(&staccess->items);
        mutex_init(&staccess->lock);

        filep = anon_inode_getfile("[iommufd_test_staccess]",
                                   &iommfd_test_staccess_fops, staccess,
                                   O_RDWR);
        if (IS_ERR(filep)) {
                /* Mirror the teardown done by iommufd_test_staccess_release() */
                mutex_destroy(&staccess->lock);
                kfree(staccess);
                return ERR_CAST(filep);
        }
        staccess->file = filep;
        return staccess;
}

/*
 * Create a selftest access attached to IOAS @ioas_id and return a new fd for
 * it in cmd->create_access.out_access_fd.
 *
 * @flags: only MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES is accepted; it
 *         selects selftest_access_ops_pin.
 *
 * staccess->access is published and the fd installed only after every other
 * step has succeeded. Until then a failure unwinds without userspace ever
 * seeing the fd.
 */
static int iommufd_test_create_access(struct iommufd_ucmd *ucmd,
                                      unsigned int ioas_id, unsigned int flags)
{
        struct iommu_test_cmd *cmd = ucmd->cmd;
        struct selftest_access *staccess;
        struct iommufd_access *access;
        u32 id;
        int fdno;
        int rc;

        if (flags & ~MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES)
                return -EOPNOTSUPP;

        staccess = iommufd_test_alloc_access();
        if (IS_ERR(staccess))
                return PTR_ERR(staccess);

        fdno = get_unused_fd_flags(O_CLOEXEC);
        if (fdno < 0) {
                /* Report the real reason, e.g. -EMFILE, not a blanket -ENOMEM */
                rc = fdno;
                goto out_free_staccess;
        }

        access = iommufd_access_create(
                ucmd->ictx,
                (flags & MOCK_FLAGS_ACCESS_CREATE_NEEDS_PIN_PAGES) ?
                        &selftest_access_ops_pin :
                        &selftest_access_ops,
                staccess, &id);
        if (IS_ERR(access)) {
                rc = PTR_ERR(access);
                goto out_put_fdno;
        }
        rc = iommufd_access_attach(access, ioas_id);
        if (rc)
                goto out_destroy;
        cmd->create_access.out_access_fd = fdno;
        rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
        if (rc)
                goto out_destroy;

        staccess->access = access;
        fd_install(fdno, staccess->file);
        return 0;

out_destroy:
        iommufd_access_destroy(access);
out_put_fdno:
        put_unused_fd(fdno);
out_free_staccess:
        /* staccess->access is still NULL, so release() only frees memory */
        fput(staccess->file);
        return rc;
}

/* Move the access behind fd @access_id over to IOAS @ioas_id. */
static int iommufd_test_access_replace_ioas(struct iommufd_ucmd *ucmd,
                                            unsigned int access_id,
                                            unsigned int ioas_id)
{
        struct selftest_access *staccess = iommufd_access_get(access_id);
        int rc;

        if (IS_ERR(staccess))
                return PTR_ERR(staccess);

        rc = iommufd_access_replace(staccess->access, ioas_id);

        /* Drop the file reference taken by iommufd_access_get() */
        fput(staccess->file);
        return rc;
}

/*
 * Check that the pages in a page array match the pages in the user VA.
 *
 * Compares @pages[i] with the page currently backing @uptr + i * PAGE_SIZE
 * for each of the @npages entries. Returns -EBADE on the first mismatch.
 */
static int iommufd_test_check_pages(void __user *uptr, struct page **pages,
                                    size_t npages)
{
        size_t idx;

        for (idx = 0; idx < npages; idx++, uptr += PAGE_SIZE) {
                struct page *user_page;
                long got;

                got = get_user_pages_fast((uintptr_t)uptr, 1, 0, &user_page);
                if (got < 0)
                        return got;
                if (WARN_ON(got != 1))
                        return -EFAULT;
                /* Only the pointer identity is compared, so drop the ref now */
                put_page(user_page);
                if (user_page != pages[idx])
                        return -EBADE;
        }
        return 0;
}

/*
 * Pin [@iova, @iova + @length) through the access behind fd @access_id and
 * record it as a new selftest_access_item. The item id is returned in
 * cmd->access_pages.out_access_pages_id.
 *
 * @flags: MOCK_FLAGS_ACCESS_WRITE requests a writable pin;
 *         MOCK_FLAGS_ACCESS_SYZ converts the iova with
 *         iommufd_test_syz_conv_iova().
 * @uptr: if non-NULL, the pinned pages must be the ones backing @uptr.
 *
 * Only accesses created with selftest_access_ops_pin are accepted. The pin
 * and the list insertion both happen under staccess->lock.
 */
static int iommufd_test_access_pages(struct iommufd_ucmd *ucmd,
                                     unsigned int access_id, unsigned long iova,
                                     size_t length, void __user *uptr,
                                     u32 flags)
{
        struct iommu_test_cmd *cmd = ucmd->cmd;
        struct selftest_access_item *item;
        struct selftest_access *staccess;
        struct page **pages;
        size_t npages;
        int rc;

        /* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
        if (length > 16 * 1024 * 1024)
                return -ENOMEM;

        if (flags & ~(MOCK_FLAGS_ACCESS_WRITE | MOCK_FLAGS_ACCESS_SYZ))
                return -EOPNOTSUPP;

        staccess = iommufd_access_get(access_id);
        if (IS_ERR(staccess))
                return PTR_ERR(staccess);

        if (staccess->access->ops != &selftest_access_ops_pin) {
                rc = -EOPNOTSUPP;
                goto out_put;
        }

        if (flags & MOCK_FLAGS_ACCESS_SYZ)
                iova = iommufd_test_syz_conv_iova(staccess->access,
                                                  &cmd->access_pages.iova);

        /* Number of PAGE_SIZE pages touched by the (possibly unaligned) range */
        npages = (ALIGN(iova + length, PAGE_SIZE) -
                  ALIGN_DOWN(iova, PAGE_SIZE)) /
                 PAGE_SIZE;
        pages = kvzalloc_objs(*pages, npages, GFP_KERNEL_ACCOUNT);
        if (!pages) {
                rc = -ENOMEM;
                goto out_put;
        }

        /*
         * Drivers will need to think very carefully about this locking. The
         * core code can do multiple unmaps instantaneously after
         * iommufd_access_pin_pages() and *all* the unmaps must not return until
         * the range is unpinned. This simple implementation puts a global lock
         * around the pin, which may not suit drivers that want this to be a
         * performance path. drivers that get this wrong will trigger WARN_ON
         * races and cause EDEADLOCK failures to userspace.
         */
        mutex_lock(&staccess->lock);
        rc = iommufd_access_pin_pages(staccess->access, iova, length, pages,
                                      flags & MOCK_FLAGS_ACCESS_WRITE);
        if (rc)
                goto out_unlock;

        /* For syzkaller allow uptr to be NULL to skip this check */
        if (uptr) {
                /* Step uptr back by iova's in-page offset to align the walk */
                rc = iommufd_test_check_pages(
                        uptr - (iova - ALIGN_DOWN(iova, PAGE_SIZE)), pages,
                        npages);
                if (rc)
                        goto out_unaccess;
        }

        item = kzalloc_obj(*item, GFP_KERNEL_ACCOUNT);
        if (!item) {
                rc = -ENOMEM;
                goto out_unaccess;
        }

        item->iova = iova;
        item->length = length;
        item->id = staccess->next_id++;
        list_add_tail(&item->items_elm, &staccess->items);

        cmd->access_pages.out_access_pages_id = item->id;
        rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
        if (rc)
                goto out_free_item;
        /* Success: keep the pin and the item, release only the temporaries */
        goto out_unlock;

out_free_item:
        list_del(&item->items_elm);
        kfree(item);
out_unaccess:
        iommufd_access_unpin_pages(staccess->access, iova, length);
out_unlock:
        mutex_unlock(&staccess->lock);
        kvfree(pages);
out_put:
        fput(staccess->file);
        return rc;
}

/*
 * Read or write @length bytes at @iova through the access behind fd
 * @access_id. A kernel bounce buffer is used, copied from @ubuf before a
 * write or to @ubuf after a read.
 *
 * @flags: MOCK_ACCESS_RW_WRITE / MOCK_ACCESS_RW_SLOW_PATH are passed straight
 *         to iommufd_access_rw() (their values match, see the static_asserts
 *         below); MOCK_FLAGS_ACCESS_SYZ converts the iova first.
 */
static int iommufd_test_access_rw(struct iommufd_ucmd *ucmd,
                                  unsigned int access_id, unsigned long iova,
                                  size_t length, void __user *ubuf,
                                  unsigned int flags)
{
        struct iommu_test_cmd *cmd = ucmd->cmd;
        bool is_write = flags & MOCK_ACCESS_RW_WRITE;
        struct selftest_access *staccess;
        void *bounce;
        int rc;

        /* Prevent syzkaller from triggering a WARN_ON in kvzalloc() */
        if (length > 16 * 1024 * 1024)
                return -ENOMEM;
        if (flags & ~(MOCK_ACCESS_RW_WRITE | MOCK_ACCESS_RW_SLOW_PATH |
                      MOCK_FLAGS_ACCESS_SYZ))
                return -EOPNOTSUPP;

        staccess = iommufd_access_get(access_id);
        if (IS_ERR(staccess))
                return PTR_ERR(staccess);

        bounce = kvzalloc(length, GFP_KERNEL_ACCOUNT);
        if (!bounce) {
                rc = -ENOMEM;
                goto out_put;
        }

        if (is_write && copy_from_user(bounce, ubuf, length)) {
                rc = -EFAULT;
                goto out_free;
        }

        if (flags & MOCK_FLAGS_ACCESS_SYZ)
                iova = iommufd_test_syz_conv_iova(staccess->access,
                                                  &cmd->access_rw.iova);

        rc = iommufd_access_rw(staccess->access, iova, bounce, length, flags);
        if (!rc && !is_write && copy_to_user(ubuf, bounce, length))
                rc = -EFAULT;

out_free:
        kvfree(bounce);
out_put:
        fput(staccess->file);
        return rc;
}
static_assert((unsigned int)MOCK_ACCESS_RW_WRITE == IOMMUFD_ACCESS_RW_WRITE);
static_assert((unsigned int)MOCK_ACCESS_RW_SLOW_PATH ==
              __IOMMUFD_ACCESS_RW_SLOW_PATH);

/*
 * Mark IOVAs dirty in mock paging domain @mockpt_id, driven by a user bitmap.
 *
 * @iova, @length: range covered by the bitmap; both page_size aligned.
 * @page_size: granule that one bitmap bit stands for.
 * @uptr: user bitmap holding one bit per page_size chunk of the range.
 * @flags: currently unused.
 *
 * The domain must have MOCK_DIRTY_TRACK set and its ops must provide
 * set_dirty. The number of chunks marked is reported in
 * cmd->dirty.out_nr_dirty.
 */
static int iommufd_test_dirty(struct iommufd_ucmd *ucmd, unsigned int mockpt_id,
                              unsigned long iova, size_t length,
                              unsigned long page_size, void __user *uptr,
                              u32 flags)
{
        struct iommu_test_cmd *cmd = ucmd->cmd;
        struct iommufd_hw_pagetable *hwpt;
        struct mock_iommu_domain *mock;
        unsigned long *bitmap;
        unsigned long i, max;
        int rc, count = 0;

        if (!page_size || !length || iova % page_size || length % page_size ||
            !uptr)
                return -EINVAL;

        /*
         * Bound the bitmap the way the other test ops bound their buffers.
         * With an unbounded bit count (e.g. page_size == 1), the
         * DIV_ROUND_UP() sizing below can wrap, and syzkaller can trigger the
         * WARN_ON in kvzalloc() for allocations above INT_MAX.
         */
        max = length / page_size;
        if (max > 16UL * 1024 * 1024 * BITS_PER_BYTE)
                return -ENOMEM;

        hwpt = get_md_pagetable(ucmd, mockpt_id, &mock);
        if (IS_ERR(hwpt))
                return PTR_ERR(hwpt);

        if (!(mock->flags & MOCK_DIRTY_TRACK) || !mock->iommu.ops->set_dirty) {
                rc = -EINVAL;
                goto out_put;
        }

        bitmap = kvzalloc(DIV_ROUND_UP(max, BITS_PER_LONG) * sizeof(unsigned long),
                          GFP_KERNEL_ACCOUNT);
        if (!bitmap) {
                rc = -ENOMEM;
                goto out_put;
        }

        /* Copy only the bytes that carry bits; the rest stays zeroed */
        if (copy_from_user(bitmap, uptr, DIV_ROUND_UP(max, BITS_PER_BYTE))) {
                rc = -EFAULT;
                goto out_free;
        }

        for (i = 0; i < max; i++) {
                if (!test_bit(i, bitmap))
                        continue;
                mock->iommu.ops->set_dirty(&mock->iommu, iova + i * page_size);
                count++;
        }

        cmd->dirty.out_nr_dirty = count;
        rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
out_free:
        kvfree(bitmap);
out_put:
        iommufd_put_object(ucmd->ictx, &hwpt->obj);
        return rc;
}

static int iommufd_test_trigger_iopf(struct iommufd_ucmd *ucmd,
                                     struct iommu_test_cmd *cmd)
{
        struct iopf_fault event = {};
        struct iommufd_device *idev;

        idev = iommufd_get_device(ucmd, cmd->trigger_iopf.dev_id);
        if (IS_ERR(idev))
                return PTR_ERR(idev);

        event.fault.prm.flags = IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE;
        if (cmd->trigger_iopf.pasid != IOMMU_NO_PASID)
                event.fault.prm.flags |= IOMMU_FAULT_PAGE_REQUEST_PASID_VALID;
        event.fault.type = IOMMU_FAULT_PAGE_REQ;
        event.fault.prm.addr = cmd->trigger_iopf.addr;
        event.fault.prm.pasid = cmd->trigger_iopf.pasid;
        event.fault.prm.grpid = cmd->trigger_iopf.grpid;
        event.fault.prm.perm = cmd->trigger_iopf.perm;

        iommu_report_device_fault(idev->dev, &event);
        iommufd_put_object(ucmd->ictx, &idev->obj);

        return 0;
}

static int iommufd_test_trigger_vevent(struct iommufd_ucmd *ucmd,
                                       struct iommu_test_cmd *cmd)
{
        struct iommu_viommu_event_selftest test = {};
        struct iommufd_device *idev;
        struct mock_dev *mdev;
        int rc = -ENOENT;

        idev = iommufd_get_device(ucmd, cmd->trigger_vevent.dev_id);
        if (IS_ERR(idev))
                return PTR_ERR(idev);
        mdev = to_mock_dev(idev->dev);

        down_read(&mdev->viommu_rwsem);
        if (!mdev->viommu || !mdev->vdev_id)
                goto out_unlock;

        test.virt_id = mdev->vdev_id;
        rc = iommufd_viommu_report_event(&mdev->viommu->core,
                                         IOMMU_VEVENTQ_TYPE_SELFTEST, &test,
                                         sizeof(test));
out_unlock:
        up_read(&mdev->viommu_rwsem);
        iommufd_put_object(ucmd->ictx, &idev->obj);

        return rc;
}

/*
 * Look up object @id and require it to be a HWPT_NESTED or HWPT_PAGING.
 * Returns the hw_pagetable with a reference held, or an ERR_PTR (-EINVAL for
 * any other object type).
 */
static inline struct iommufd_hw_pagetable *
iommufd_get_hwpt(struct iommufd_ucmd *ucmd, u32 id)
{
        struct iommufd_object *pt_obj;

        pt_obj = iommufd_get_object(ucmd->ictx, id, IOMMUFD_OBJ_ANY);
        if (IS_ERR(pt_obj))
                return ERR_CAST(pt_obj);

        switch (pt_obj->type) {
        case IOMMUFD_OBJ_HWPT_NESTED:
        case IOMMUFD_OBJ_HWPT_PAGING:
                return container_of(pt_obj, struct iommufd_hw_pagetable, obj);
        default:
                iommufd_put_object(ucmd->ictx, pt_obj);
                return ERR_PTR(-EINVAL);
        }
}

/*
 * Verify which domain is attached to pasid cmd->pasid_check.pasid of
 * selftest device cmd->id. A hwpt_id of 0 asserts that nothing is attached;
 * otherwise the attached domain must be that hwpt's domain. Mismatches
 * return -EINVAL.
 */
static int iommufd_test_pasid_check_hwpt(struct iommufd_ucmd *ucmd,
                                         struct iommu_test_cmd *cmd)
{
        u32 hwpt_id = cmd->pasid_check.hwpt_id;
        struct iommu_domain *attached_domain;
        struct iommu_attach_handle *handle;
        struct iommufd_hw_pagetable *hwpt;
        struct selftest_obj *sobj;
        int rc = 0;

        sobj = iommufd_test_get_selftest_obj(ucmd->ictx, cmd->id);
        if (IS_ERR(sobj))
                return PTR_ERR(sobj);

        handle = iommu_attach_handle_get(sobj->idev.mock_dev->dev.iommu_group,
                                         cmd->pasid_check.pasid, 0);
        attached_domain = IS_ERR(handle) ? NULL : handle->domain;

        if (!hwpt_id) {
                /* hwpt_id == 0 means to check if pasid is detached */
                if (attached_domain)
                        rc = -EINVAL;
        } else {
                hwpt = iommufd_get_hwpt(ucmd, hwpt_id);
                if (IS_ERR(hwpt)) {
                        rc = PTR_ERR(hwpt);
                } else {
                        if (attached_domain != hwpt->domain)
                                rc = -EINVAL;
                        iommufd_put_object(ucmd->ictx, &hwpt->obj);
                }
        }

        iommufd_put_object(ucmd->ictx, &sobj->obj);
        return rc;
}

/*
 * Attach pasid cmd->pasid_attach.pasid of selftest device cmd->id to
 * cmd->pasid_attach.pt_id and report the command back. The attach is undone
 * if the response cannot be delivered.
 */
static int iommufd_test_pasid_attach(struct iommufd_ucmd *ucmd,
                                     struct iommu_test_cmd *cmd)
{
        struct selftest_obj *sobj;
        int rc;

        sobj = iommufd_test_get_selftest_obj(ucmd->ictx, cmd->id);
        if (IS_ERR(sobj))
                return PTR_ERR(sobj);

        rc = iommufd_device_attach(sobj->idev.idev, cmd->pasid_attach.pasid,
                                   &cmd->pasid_attach.pt_id);
        if (!rc) {
                rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));
                if (rc)
                        iommufd_device_detach(sobj->idev.idev,
                                              cmd->pasid_attach.pasid);
        }

        iommufd_put_object(ucmd->ictx, &sobj->obj);
        return rc;
}

/*
 * Replace the domain attached to pasid cmd->pasid_attach.pasid of selftest
 * device cmd->id with cmd->pasid_attach.pt_id and report the command back.
 */
static int iommufd_test_pasid_replace(struct iommufd_ucmd *ucmd,
                                      struct iommu_test_cmd *cmd)
{
        struct selftest_obj *sobj;
        int rc;

        sobj = iommufd_test_get_selftest_obj(ucmd->ictx, cmd->id);
        if (IS_ERR(sobj))
                return PTR_ERR(sobj);

        rc = iommufd_device_replace(sobj->idev.idev, cmd->pasid_attach.pasid,
                                    &cmd->pasid_attach.pt_id);
        if (!rc)
                rc = iommufd_ucmd_respond(ucmd, sizeof(*cmd));

        iommufd_put_object(ucmd->ictx, &sobj->obj);
        return rc;
}

/* Detach pasid cmd->pasid_detach.pasid of selftest device cmd->id. */
static int iommufd_test_pasid_detach(struct iommufd_ucmd *ucmd,
                                     struct iommu_test_cmd *cmd)
{
        struct selftest_obj *sobj =
                iommufd_test_get_selftest_obj(ucmd->ictx, cmd->id);

        if (IS_ERR(sobj))
                return PTR_ERR(sobj);

        iommufd_device_detach(sobj->idev.idev, cmd->pasid_detach.pasid);

        iommufd_put_object(ucmd->ictx, &sobj->obj);
        return 0;
}

/*
 * Destroy callback for selftest objects.
 *
 * For a TYPE_IDEV object this detaches the device from its no-PASID
 * attachment, unbinds it from iommufd and finally destroys the underlying
 * mock device.
 */
void iommufd_selftest_destroy(struct iommufd_object *obj)
{
        struct selftest_obj *sobj = to_selftest_obj(obj);
        struct iommufd_device *idev;

        switch (sobj->type) {
        case TYPE_IDEV:
                idev = sobj->idev.idev;
                iommufd_device_detach(idev, IOMMU_NO_PASID);
                iommufd_device_unbind(idev);
                mock_dev_destroy(sobj->idev.mock_dev);
                break;
        }
}

/*
 * Private exporter state of a dma-buf created by IOMMU_TEST_OP_DMABUF_GET
 * (see iommufd_test_dmabuf_get()); stored in dma_buf->priv.
 */
struct iommufd_test_dma_buf {
        /* Backing memory from kzalloc(), freed by the .release callback */
        void *memory;
        /* Size of @memory in bytes, rounded up to a multiple of PAGE_SIZE */
        size_t length;
        /*
         * Toggled by IOMMU_TEST_OP_DMABUF_REVOKE under the reservation lock;
         * while set, iommufd_test_dma_buf_iommufd_map() fails with -ENODEV.
         */
        bool revoked;
};

/* Accept every importer; this exporter keeps no per-attachment state. */
static int iommufd_test_dma_buf_attach(struct dma_buf *dmabuf,
                                       struct dma_buf_attachment *attachment)
{
        return 0;
}

/* Nothing to undo: iommufd_test_dma_buf_attach() allocates nothing. */
static void iommufd_test_dma_buf_detach(struct dma_buf *dmabuf,
                                        struct dma_buf_attachment *attachment)
{
}

/*
 * Scatter-gather mapping is not supported by this test exporter and always
 * fails with -EOPNOTSUPP. The physical range is instead exposed through
 * iommufd_test_dma_buf_iommufd_map().
 * NOTE(review): the stub presumably exists because the dma-buf core requires
 * a non-NULL .map_dma_buf to export — confirm.
 */
static struct sg_table *
iommufd_test_dma_buf_map(struct dma_buf_attachment *attachment,
                         enum dma_data_direction dir)
{
        return ERR_PTR(-EOPNOTSUPP);
}

/*
 * Counterpart of iommufd_test_dma_buf_map(); since that never hands out a
 * table, there is nothing to release here.
 */
static void iommufd_test_dma_buf_unmap(struct dma_buf_attachment *attachment,
                                       struct sg_table *sgt,
                                       enum dma_data_direction dir)
{
}

/*
 * .release callback: free the backing memory and the private state that
 * iommufd_test_dmabuf_get() allocated for this dma-buf.
 */
static void iommufd_test_dma_buf_release(struct dma_buf *dmabuf)
{
        struct iommufd_test_dma_buf *buf = dmabuf->priv;

        kfree(buf->memory);
        kfree(buf);
}

/*
 * Exporter callbacks for dma-bufs created by iommufd_test_dmabuf_get().
 * The address of this table also identifies selftest dma-bufs (see the
 * ->ops comparisons in the map and revoke helpers).
 */
static const struct dma_buf_ops iommufd_test_dmabuf_ops = {
        .attach = iommufd_test_dma_buf_attach,
        .detach = iommufd_test_dma_buf_detach,
        .map_dma_buf = iommufd_test_dma_buf_map,
        .unmap_dma_buf = iommufd_test_dma_buf_unmap,
        .release = iommufd_test_dma_buf_release,
};

int iommufd_test_dma_buf_iommufd_map(struct dma_buf_attachment *attachment,
                                     struct phys_vec *phys)
{
        struct iommufd_test_dma_buf *priv = attachment->dmabuf->priv;

        dma_resv_assert_held(attachment->dmabuf->resv);

        if (attachment->dmabuf->ops != &iommufd_test_dmabuf_ops)
                return -EOPNOTSUPP;

        if (priv->revoked)
                return -ENODEV;

        phys->paddr = virt_to_phys(priv->memory);
        phys->len = priv->length;
        return 0;
}

/*
 * IOMMU_TEST_OP_DMABUF_GET: export a dma-buf backed by zeroed kernel memory
 * and install it into a new file descriptor.
 *
 * @ucmd:       the test ioctl (not used by this helper)
 * @open_flags: used both as the dma-buf file flags and as the fd flags
 * @len:        requested size in bytes; rounded up to PAGE_SIZE and must be
 *              non-zero and at most 512 pages after rounding
 *
 * Returns the new fd on success or a negative errno.
 */
static int iommufd_test_dmabuf_get(struct iommufd_ucmd *ucmd,
                                   unsigned int open_flags,
                                   size_t len)
{
        DEFINE_DMA_BUF_EXPORT_INFO(exp_info);
        struct iommufd_test_dma_buf *priv;
        struct dma_buf *dmabuf;
        int rc;

        /* A wrap-around in ALIGN() yields 0 and is rejected as well */
        len = ALIGN(len, PAGE_SIZE);
        if (len == 0 || len > PAGE_SIZE * 512)
                return -EINVAL;

        priv = kzalloc_obj(*priv);
        if (!priv)
                return -ENOMEM;

        priv->length = len;
        priv->memory = kzalloc(len, GFP_KERNEL);
        if (!priv->memory) {
                rc = -ENOMEM;
                goto err_free;
        }

        exp_info.ops = &iommufd_test_dmabuf_ops;
        exp_info.size = len;
        exp_info.flags = open_flags;
        exp_info.priv = priv;

        dmabuf = dma_buf_export(&exp_info);
        if (IS_ERR(dmabuf)) {
                rc = PTR_ERR(dmabuf);
                goto err_free;
        }

        /*
         * From here on priv belongs to the dma-buf and is freed by
         * iommufd_test_dma_buf_release(). If the fd cannot be installed,
         * drop the export reference instead of leaking the dma-buf, priv
         * and its backing memory.
         */
        rc = dma_buf_fd(dmabuf, open_flags);
        if (rc < 0)
                dma_buf_put(dmabuf);
        return rc;

err_free:
        kfree(priv->memory);
        kfree(priv);
        return rc;
}

/*
 * IOMMU_TEST_OP_DMABUF_REVOKE: set or clear the revoked state of the
 * selftest dma-buf referenced by @fd and notify importers via
 * dma_buf_move_notify(). The flag is flipped under the reservation lock.
 *
 * Returns 0, -EOPNOTSUPP if @fd is a dma-buf not created by this selftest,
 * or the error from dma_buf_get().
 */
static int iommufd_test_dmabuf_revoke(struct iommufd_ucmd *ucmd, int fd,
                                      bool revoked)
{
        struct dma_buf *dmabuf;
        int rc = 0;

        dmabuf = dma_buf_get(fd);
        if (IS_ERR(dmabuf))
                return PTR_ERR(dmabuf);

        if (dmabuf->ops == &iommufd_test_dmabuf_ops) {
                struct iommufd_test_dma_buf *priv = dmabuf->priv;

                dma_resv_lock(dmabuf->resv, NULL);
                priv->revoked = revoked;
                dma_buf_move_notify(dmabuf);
                dma_resv_unlock(dmabuf->resv);
        } else {
                rc = -EOPNOTSUPP;
        }

        dma_buf_put(dmabuf);
        return rc;
}

/*
 * Entry point for the IOMMU_TEST_CMD ioctl: dispatch on cmd->op to the
 * matching selftest helper and return its result. Unknown ops fail with
 * -EOPNOTSUPP.
 */
int iommufd_test(struct iommufd_ucmd *ucmd)
{
        struct iommu_test_cmd *cmd = ucmd->cmd;

        switch (cmd->op) {
        case IOMMU_TEST_OP_ADD_RESERVED:
                return iommufd_test_add_reserved(ucmd, cmd->id,
                                                 cmd->add_reserved.start,
                                                 cmd->add_reserved.length);
        case IOMMU_TEST_OP_MOCK_DOMAIN:
        case IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS:
                return iommufd_test_mock_domain(ucmd, cmd);
        case IOMMU_TEST_OP_MOCK_DOMAIN_REPLACE:
                return iommufd_test_mock_domain_replace(
                        ucmd, cmd->id, cmd->mock_domain_replace.pt_id, cmd);
        case IOMMU_TEST_OP_MD_CHECK_MAP:
                return iommufd_test_md_check_pa(
                        ucmd, cmd->id, cmd->check_map.iova,
                        cmd->check_map.length,
                        u64_to_user_ptr(cmd->check_map.uptr));
        case IOMMU_TEST_OP_MD_CHECK_REFS:
                return iommufd_test_md_check_refs(
                        ucmd, u64_to_user_ptr(cmd->check_refs.uptr),
                        cmd->check_refs.length, cmd->check_refs.refs);
        case IOMMU_TEST_OP_MD_CHECK_IOTLB:
                return iommufd_test_md_check_iotlb(ucmd, cmd->id,
                                                   cmd->check_iotlb.id,
                                                   cmd->check_iotlb.iotlb);
        case IOMMU_TEST_OP_DEV_CHECK_CACHE:
                return iommufd_test_dev_check_cache(ucmd, cmd->id,
                                                    cmd->check_dev_cache.id,
                                                    cmd->check_dev_cache.cache);
        case IOMMU_TEST_OP_CREATE_ACCESS:
                return iommufd_test_create_access(ucmd, cmd->id,
                                                  cmd->create_access.flags);
        case IOMMU_TEST_OP_ACCESS_REPLACE_IOAS:
                return iommufd_test_access_replace_ioas(
                        ucmd, cmd->id, cmd->access_replace_ioas.ioas_id);
        case IOMMU_TEST_OP_ACCESS_PAGES:
                return iommufd_test_access_pages(
                        ucmd, cmd->id, cmd->access_pages.iova,
                        cmd->access_pages.length,
                        u64_to_user_ptr(cmd->access_pages.uptr),
                        cmd->access_pages.flags);
        case IOMMU_TEST_OP_ACCESS_RW:
                return iommufd_test_access_rw(
                        ucmd, cmd->id, cmd->access_rw.iova,
                        cmd->access_rw.length,
                        u64_to_user_ptr(cmd->access_rw.uptr),
                        cmd->access_rw.flags);
        case IOMMU_TEST_OP_DESTROY_ACCESS_PAGES:
                return iommufd_test_access_item_destroy(
                        ucmd, cmd->id, cmd->destroy_access_pages.access_pages_id);
        case IOMMU_TEST_OP_SET_TEMP_MEMORY_LIMIT:
                /*
                 * Protect _batch_init(): the limit must hold at least one
                 * element (elmsz = sizeof(unsigned long) + sizeof(u32)).
                 */
                if (cmd->memory_limit.limit <
                    sizeof(unsigned long) + sizeof(u32))
                        return -EINVAL;
                iommufd_test_memory_limit = cmd->memory_limit.limit;
                return 0;
        case IOMMU_TEST_OP_DIRTY:
                return iommufd_test_dirty(ucmd, cmd->id, cmd->dirty.iova,
                                          cmd->dirty.length,
                                          cmd->dirty.page_size,
                                          u64_to_user_ptr(cmd->dirty.uptr),
                                          cmd->dirty.flags);
        case IOMMU_TEST_OP_TRIGGER_IOPF:
                return iommufd_test_trigger_iopf(ucmd, cmd);
        case IOMMU_TEST_OP_TRIGGER_VEVENT:
                return iommufd_test_trigger_vevent(ucmd, cmd);
        case IOMMU_TEST_OP_PASID_ATTACH:
                return iommufd_test_pasid_attach(ucmd, cmd);
        case IOMMU_TEST_OP_PASID_REPLACE:
                return iommufd_test_pasid_replace(ucmd, cmd);
        case IOMMU_TEST_OP_PASID_DETACH:
                return iommufd_test_pasid_detach(ucmd, cmd);
        case IOMMU_TEST_OP_PASID_CHECK_HWPT:
                return iommufd_test_pasid_check_hwpt(ucmd, cmd);
        case IOMMU_TEST_OP_DMABUF_GET:
                return iommufd_test_dmabuf_get(ucmd, cmd->dmabuf_get.open_flags,
                                               cmd->dmabuf_get.length);
        case IOMMU_TEST_OP_DMABUF_REVOKE:
                return iommufd_test_dmabuf_revoke(ucmd,
                                                  cmd->dmabuf_revoke.dmabuf_fd,
                                                  cmd->dmabuf_revoke.revoked);
        default:
                return -EOPNOTSUPP;
        }
}

/*
 * Fault-injection hook: returns true when the "fail_iommufd" fault attribute
 * (exposed in debugfs by iommufd_test_init()) decides this call should fail.
 */
bool iommufd_should_fail(void)
{
        return should_fail(&fail_iommufd, 1);
}

/*
 * Set up the selftest infrastructure: the "fail_iommufd" debugfs fault
 * attribute, a platform device that parents the mock IOMMU, the mock bus,
 * the mock IOMMU's sysfs entry and its registration on the mock bus.
 *
 * On failure, everything already set up is torn down in reverse order and
 * the negative errno is returned.
 */
int __init iommufd_test_init(void)
{
        struct platform_device_info pdevinfo = {
                .name = "iommufd_selftest_iommu",
        };
        int rc;

        dbgfs_root =
                fault_create_debugfs_attr("fail_iommufd", NULL, &fail_iommufd);

        selftest_iommu_dev = platform_device_register_full(&pdevinfo);
        if (IS_ERR(selftest_iommu_dev)) {
                rc = PTR_ERR(selftest_iommu_dev);
                goto err_dbgfs;
        }

        rc = bus_register(&iommufd_mock_bus_type.bus);
        if (rc)
                goto err_platform;

        rc = iommu_device_sysfs_add(&mock_iommu.iommu_dev,
                                    &selftest_iommu_dev->dev, NULL, "%s",
                                    dev_name(&selftest_iommu_dev->dev));
        if (rc)
                goto err_bus;

        rc = iommu_device_register_bus(&mock_iommu.iommu_dev, &mock_ops,
                                       &iommufd_mock_bus_type.bus,
                                       &iommufd_mock_bus_type.nb);
        if (rc)
                goto err_sysfs;

        /*
         * Initial reference, dropped again by iommufd_test_wait_for_users()
         * during module exit.
         */
        refcount_set(&mock_iommu.users, 1);
        init_completion(&mock_iommu.complete);

        /*
         * NOTE(review): a NULL return from iopf_queue_alloc() is not treated
         * as an error here; iommufd_test_exit() tolerates NULL, but confirm
         * that all users of mock_iommu_iopf_queue do too.
         */
        mock_iommu_iopf_queue = iopf_queue_alloc("mock-iopfq");
        mock_iommu.iommu_dev.max_pasids = (1 << MOCK_PASID_WIDTH);

        return 0;

err_sysfs:
        iommu_device_sysfs_remove(&mock_iommu.iommu_dev);
err_bus:
        bus_unregister(&iommufd_mock_bus_type.bus);
err_platform:
        platform_device_unregister(selftest_iommu_dev);
err_dbgfs:
        debugfs_remove_recursive(dbgfs_root);
        return rc;
}

/*
 * Drop the initial mock IOMMU reference taken by iommufd_test_init() and,
 * if other users still hold references, wait up to 10 seconds for
 * mock_iommu.complete (presumably signalled when the last user drops its
 * reference). A timeout triggers a WARN_ON.
 */
static void iommufd_test_wait_for_users(void)
{
        unsigned long time_left;

        if (refcount_dec_and_test(&mock_iommu.users))
                return;

        /*
         * This only demonstrates the pattern: the selftest is built into
         * the iommufd module, so the mock IOMMU is only unplugged when the
         * module is unloaded. By then no iommufd FDs should be open, so the
         * WARN_ON below is not expected to trigger.
         */
        time_left = wait_for_completion_timeout(&mock_iommu.complete,
                                                msecs_to_jiffies(10000));
        WARN_ON(!time_left);
}

/*
 * Tear down everything set up by iommufd_test_init(): the IOPF queue (if it
 * was allocated), the mock IOMMU (after waiting for its users), the mock
 * bus, the platform device and the debugfs fault attribute.
 */
void iommufd_test_exit(void)
{
        /* The queue may be NULL: its allocation is not checked at init */
        if (mock_iommu_iopf_queue) {
                iopf_queue_free(mock_iommu_iopf_queue);
                mock_iommu_iopf_queue = NULL;
        }

        iommufd_test_wait_for_users();
        iommu_device_sysfs_remove(&mock_iommu.iommu_dev);
        iommu_device_unregister_bus(&mock_iommu.iommu_dev,
                                    &iommufd_mock_bus_type.bus,
                                    &iommufd_mock_bus_type.nb);
        bus_unregister(&iommufd_mock_bus_type.bus);
        platform_device_unregister(selftest_iommu_dev);
        debugfs_remove_recursive(dbgfs_root);
}

MODULE_IMPORT_NS("GENERIC_PT_IOMMU");