root/tools/testing/selftests/arm64/mte/check_tags_inclusion.c
// SPDX-License-Identifier: GPL-2.0
// Copyright (C) 2020 ARM Limited

#define _GNU_SOURCE

#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ucontext.h>
#include <sys/wait.h>

#include "kselftest.h"
#include "mte_common_util.h"
#include "mte_def.h"

/* Size of the region that each test tags and writes: five MTE granules. */
#define BUFFER_SIZE             (5 * MT_GRANULE_SIZE)
/* Number of tag-generation attempts made per include mask: twice the tag count. */
#define RUNS                    (MT_TAG_COUNT * 2)
/*
 * NOTE(review): not referenced by any code visible in this file; intended
 * meaning unclear - confirm against other users before relying on or removing it.
 */
#define MTE_LAST_TAG_MASK       (0x7FFF)

/*
 * Exercise a (possibly tagged) pointer and check the recorded fault state.
 *
 * Step 1: fill the BUFFER_SIZE bytes starting at @ptr; no fault may be
 *         recorded for this access.
 * Step 2: only when @ptr carries a nonzero tag, write a single byte at
 *         offset BUFFER_SIZE; a fault must be recorded for this access.
 *
 * @ptr:  start of the buffer under test
 * @mode: MTE mode handed to mte_initialize_current_context() and reported
 *        in the diagnostic messages
 *
 * Returns KSFT_PASS if every step behaved as expected, KSFT_FAIL otherwise.
 */
static int verify_mte_pointer_validity(char *ptr, int mode)
{
        uintptr_t addr = (uintptr_t)ptr;

        /* Step 1: a write inside the buffer must go through cleanly. */
        mte_initialize_current_context(mode, addr, BUFFER_SIZE);
        memset(ptr, '1', BUFFER_SIZE);
        mte_wait_after_trig();
        if (cur_mte_cxt.fault_valid) {
                ksft_print_msg("Unexpected fault recorded for %p-%p in mode %x\n",
                               ptr, ptr + BUFFER_SIZE, mode);
                return KSFT_FAIL;
        }

        /* A zero tag ends the check here; step 2 is for nonzero tags only. */
        if (MT_FETCH_TAG(addr) == 0)
                return KSFT_PASS;

        /* Step 2: the first byte past the buffer must raise a fault. */
        mte_initialize_current_context(mode, addr, BUFFER_SIZE + 1);
        ptr[BUFFER_SIZE] = '2';
        mte_wait_after_trig();
        if (cur_mte_cxt.fault_valid)
                return KSFT_PASS;

        ksft_print_msg("No valid fault recorded for %p in mode %x\n",
                       ptr, mode);
        return KSFT_FAIL;
}

/*
 * For every tag value in turn, switch to the include mask produced by
 * MT_INCLUDE_VALID_TAG(tag) and repeatedly re-tag the buffer. Getting back
 * a pointer whose tag equals the current tag value is reported as a
 * failure; otherwise the pointer is checked with
 * verify_mte_pointer_validity().
 *
 * @mem_type: allocation type passed through to the MTE memory helpers
 * @mode:     MTE mode passed to mte_switch_mode() and the validity check
 *
 * Returns KSFT_PASS if no attempt failed, KSFT_FAIL on the first failure.
 */
static int check_single_included_tags(int mem_type, int mode)
{
        int verdict = KSFT_PASS;
        int left_out, attempt;
        char *buf;

        /* One spare granule is allocated behind the BUFFER_SIZE test region. */
        buf = mte_allocate_memory(BUFFER_SIZE + MT_GRANULE_SIZE, mem_type, 0, false);
        if (check_allocated_memory(buf, BUFFER_SIZE + MT_GRANULE_SIZE,
                                   mem_type, false) != KSFT_PASS)
                return KSFT_FAIL;

        for (left_out = 0; left_out < MT_TAG_COUNT; left_out++) {
                if (mte_switch_mode(mode, MT_INCLUDE_VALID_TAG(left_out), false) != 0) {
                        verdict = KSFT_FAIL;
                        goto release;
                }

                /* Try to catch an excluded tag by a number of tries. */
                for (attempt = 0; attempt < RUNS; attempt++) {
                        buf = mte_insert_tags(buf, BUFFER_SIZE);

                        /* The tag left out of the mask must never show up. */
                        if (MT_FETCH_TAG((uintptr_t)buf) == left_out) {
                                ksft_print_msg("FAIL: wrong tag = 0x%lx with include mask=0x%x\n",
                                               MT_FETCH_TAG((uintptr_t)buf),
                                               MT_INCLUDE_VALID_TAG(left_out));
                                verdict = KSFT_FAIL;
                                goto release;
                        }

                        verdict = verify_mte_pointer_validity(buf, mode);
                        if (verdict != KSFT_PASS)
                                goto release;
                }
        }

release:
        mte_free_memory_tag_range(buf, BUFFER_SIZE, mem_type, 0, MT_GRANULE_SIZE);
        return verdict;
}

/*
 * Grow an exclusion mask one tag at a time (tags 0..tag) and, for each
 * resulting include mask MT_INCLUDE_VALID_TAGS(excl_mask), repeatedly
 * re-tag the buffer. A returned pointer whose tag is one of the tags
 * accumulated in excl_mask is reported as a failure; otherwise the pointer
 * is checked with verify_mte_pointer_validity().
 *
 * The loop stops at MT_TAG_COUNT - 1 so the highest tag value is never
 * added to excl_mask.
 *
 * @mem_type: allocation type passed through to the MTE memory helpers
 * @mode:     MTE mode passed to mte_switch_mode() and the validity check
 *
 * Returns KSFT_PASS if no attempt failed, KSFT_FAIL otherwise (including
 * when mte_switch_mode() reports an error).
 */
static int check_multiple_included_tags(int mem_type, int mode)
{
        char *ptr;
        int tag, run, ret, result = KSFT_PASS;
        unsigned long excl_mask = 0;
        unsigned long got;

        /* One spare granule is allocated behind the BUFFER_SIZE test region. */
        ptr = mte_allocate_memory(BUFFER_SIZE + MT_GRANULE_SIZE, mem_type, 0, false);
        if (check_allocated_memory(ptr, BUFFER_SIZE + MT_GRANULE_SIZE,
                                   mem_type, false) != KSFT_PASS)
                return KSFT_FAIL;

        for (tag = 0; (tag < MT_TAG_COUNT - 1) && (result == KSFT_PASS); tag++) {
                excl_mask |= 1UL << tag;
                /* A failed mode switch fails the test, as in the sibling tests. */
                ret = mte_switch_mode(mode, MT_INCLUDE_VALID_TAGS(excl_mask), false);
                if (ret != 0)
                        result = KSFT_FAIL;
                /* Try to catch an excluded tag by a number of tries. */
                for (run = 0; (run < RUNS) && (result == KSFT_PASS); run++) {
                        ptr = mte_insert_tags(ptr, BUFFER_SIZE);
                        /*
                         * excl_mask now holds bits 0..tag, so every tag value
                         * up to and including 'tag' is wrong here.
                         */
                        got = MT_FETCH_TAG((uintptr_t)ptr);
                        if (got <= (unsigned long)tag) {
                                ksft_print_msg("FAIL: wrong tag = 0x%lx with include mask=0x%lx\n",
                                               got,
                                               MT_INCLUDE_VALID_TAGS(excl_mask));
                                result = KSFT_FAIL;
                                break;
                        }
                        result = verify_mte_pointer_validity(ptr, mode);
                }
        }
        mte_free_memory_tag_range(ptr, BUFFER_SIZE, mem_type, 0, MT_GRANULE_SIZE);
        return result;
}

/*
 * Switch to MT_INCLUDE_TAG_MASK and repeatedly re-tag the buffer. No
 * particular tag value is expected, so each returned pointer is only
 * checked with verify_mte_pointer_validity().
 *
 * @mem_type: allocation type passed through to the MTE memory helpers
 * @mode:     MTE mode passed to mte_switch_mode() and the validity check
 *
 * Returns KSFT_PASS if every attempt passed, KSFT_FAIL otherwise (including
 * when mte_switch_mode() reports an error). The buffer is released on every
 * path once it has been allocated successfully.
 */
static int check_all_included_tags(int mem_type, int mode)
{
        char *ptr;
        int run, ret, result = KSFT_PASS;

        /* One spare granule is allocated behind the BUFFER_SIZE test region. */
        ptr = mte_allocate_memory(BUFFER_SIZE + MT_GRANULE_SIZE, mem_type, 0, false);
        if (check_allocated_memory(ptr, BUFFER_SIZE + MT_GRANULE_SIZE,
                                   mem_type, false) != KSFT_PASS)
                return KSFT_FAIL;

        /*
         * Do not return early on failure: fall through to the common free
         * below so the allocation is not leaked. The loop is skipped because
         * result is no longer KSFT_PASS.
         */
        ret = mte_switch_mode(mode, MT_INCLUDE_TAG_MASK, false);
        if (ret != 0)
                result = KSFT_FAIL;
        /* Re-tag a number of times; any generated tag is acceptable. */
        for (run = 0; (run < RUNS) && (result == KSFT_PASS); run++) {
                ptr = (char *)mte_insert_tags(ptr, BUFFER_SIZE);
                /*
                 * Here tag byte can be between 0x0 to 0xF (full allowed range)
                 * so no need to match so just verify if it is writable.
                 */
                result = verify_mte_pointer_validity(ptr, mode);
        }
        mte_free_memory_tag_range(ptr, BUFFER_SIZE, mem_type, 0, MT_GRANULE_SIZE);
        return result;
}

/*
 * Switch to MT_EXCLUDE_TAG_MASK and repeatedly re-tag the buffer. Every
 * returned pointer is expected to carry tag 0, and writing BUFFER_SIZE
 * bytes through it must not record a fault.
 *
 * @mem_type: allocation type passed through to the MTE memory helpers
 * @mode:     MTE mode passed to mte_switch_mode() and the fault context
 *
 * Returns KSFT_PASS if every attempt passed, KSFT_FAIL otherwise (including
 * when mte_switch_mode() reports an error). The buffer is released on every
 * path once it has been allocated successfully.
 */
static int check_none_included_tags(int mem_type, int mode)
{
        char *ptr;
        int run, ret, result = KSFT_PASS;

        ptr = mte_allocate_memory(BUFFER_SIZE, mem_type, 0, false);
        if (check_allocated_memory(ptr, BUFFER_SIZE, mem_type, false) != KSFT_PASS)
                return KSFT_FAIL;

        ret = mte_switch_mode(mode, MT_EXCLUDE_TAG_MASK, false);
        if (ret != 0) {
                /* Nothing has been tagged yet; release the buffer before bailing out. */
                mte_free_memory(ptr, BUFFER_SIZE, mem_type, false);
                return KSFT_FAIL;
        }
        /* Try to catch an included tag by a number of tries. */
        for (run = 0; run < RUNS; run++) {
                ptr = (char *)mte_insert_tags(ptr, BUFFER_SIZE);
                /* Here all tags excluded so tag value generated should be 0 */
                if (MT_FETCH_TAG((uintptr_t)ptr)) {
                        ksft_print_msg("FAIL: included tag value found\n");
                        /* ptr carries a nonzero tag here, so free with tags == true. */
                        mte_free_memory((void *)ptr, BUFFER_SIZE, mem_type, true);
                        return KSFT_FAIL;
                }
                mte_initialize_current_context(mode, (uintptr_t)ptr, BUFFER_SIZE);
                /* Check the write validity of the untagged pointer */
                memset(ptr, '1', BUFFER_SIZE);
                mte_wait_after_trig();
                if (cur_mte_cxt.fault_valid) {
                        /* Latch the verdict now rather than re-reading it after the free. */
                        result = KSFT_FAIL;
                        break;
                }
        }
        mte_free_memory(ptr, BUFFER_SIZE, mem_type, false);
        return result;
}

/*
 * Test entry point: set up MTE, install the SIGSEGV handler, run the four
 * tag-inclusion tests with USE_MMAP memory in MTE_SYNC_ERR mode and report
 * each result through evaluate_test().
 *
 * Returns the error from mte_default_setup() if setup fails; otherwise
 * KSFT_PASS when no test failure was counted and KSFT_FAIL if any was.
 */
int main(int argc, char *argv[])
{
        int status;

        status = mte_default_setup();
        if (status)
                return status;

        /* Register SIGSEGV handler */
        mte_register_signal(SIGSEGV, mte_default_handler, false);

        /* Four results are reported below. */
        ksft_set_plan(4);

        status = check_single_included_tags(USE_MMAP, MTE_SYNC_ERR);
        evaluate_test(status, "Check an included tag value with sync mode\n");

        status = check_multiple_included_tags(USE_MMAP, MTE_SYNC_ERR);
        evaluate_test(status, "Check different included tags value with sync mode\n");

        status = check_none_included_tags(USE_MMAP, MTE_SYNC_ERR);
        evaluate_test(status, "Check none included tags value with sync mode\n");

        status = check_all_included_tags(USE_MMAP, MTE_SYNC_ERR);
        evaluate_test(status, "Check all included tags value with sync mode\n");

        mte_restore_setup();
        ksft_print_cnts();

        if (ksft_get_fail_cnt() != 0)
                return KSFT_FAIL;
        return KSFT_PASS;
}