root/arch/csky/abiv1/alignment.c
// SPDX-License-Identifier: GPL-2.0
// Copyright (C) 2018 Hangzhou C-SKY Microsystems co.,ltd.

#include <linux/kernel.h>
#include <linux/uaccess.h>
#include <linux/ptrace.h>

static int align_kern_enable = 1;
static int align_usr_enable = 1;
static int align_kern_count = 0;
static int align_usr_count = 0;

static inline uint32_t get_ptreg(struct pt_regs *regs, uint32_t rx)
{
        return rx == 15 ? regs->lr : *((uint32_t *)&(regs->a0) - 2 + rx);
}

static inline void put_ptreg(struct pt_regs *regs, uint32_t rx, uint32_t val)
{
        if (rx == 15)
                regs->lr = val;
        else
                *((uint32_t *)&(regs->a0) - 2 + rx) = val;
}

/*
 * Get byte-value from addr and set it to *valp.
 *
 * Success: return 0
 * Failure: return 1
 */
static int ldb_asm(uint32_t addr, uint32_t *valp)
{
        uint32_t val;
        int err;

        asm volatile (
                "movi   %0, 0\n"
                "1:\n"
                "ldb    %1, (%2)\n"
                "br     3f\n"
                "2:\n"
                "movi   %0, 1\n"
                "br     3f\n"
                ".section __ex_table,\"a\"\n"
                ".align 2\n"
                ".long  1b, 2b\n"
                ".previous\n"
                "3:\n"
                : "=&r"(err), "=r"(val)
                : "r" (addr)
        );

        *valp = val;

        return err;
}

/*
 * Put byte-value to addr.
 *
 * Success: return 0
 * Failure: return 1
 */
static int stb_asm(uint32_t addr, uint32_t val)
{
        int err;

        asm volatile (
                "movi   %0, 0\n"
                "1:\n"
                "stb    %1, (%2)\n"
                "br     3f\n"
                "2:\n"
                "movi   %0, 1\n"
                "br     3f\n"
                ".section __ex_table,\"a\"\n"
                ".align 2\n"
                ".long  1b, 2b\n"
                ".previous\n"
                "3:\n"
                : "=&r"(err)
                : "r"(val), "r" (addr)
        );

        return err;
}

/*
 * Get half-word from [rx + imm]
 *
 * Success: return 0
 * Failure: return 1
 */
static int ldh_c(struct pt_regs *regs, uint32_t rz, uint32_t addr)
{
        uint32_t byte0, byte1;

        if (ldb_asm(addr, &byte0))
                return 1;
        addr += 1;
        if (ldb_asm(addr, &byte1))
                return 1;

        byte0 |= byte1 << 8;
        put_ptreg(regs, rz, byte0);

        return 0;
}

/*
 * Store half-word to [rx + imm]
 *
 * Success: return 0
 * Failure: return 1
 */
static int sth_c(struct pt_regs *regs, uint32_t rz, uint32_t addr)
{
        uint32_t byte0, byte1;

        byte0 = byte1 = get_ptreg(regs, rz);

        byte0 &= 0xff;

        if (stb_asm(addr, byte0))
                return 1;

        addr += 1;
        byte1 = (byte1 >> 8) & 0xff;
        if (stb_asm(addr, byte1))
                return 1;

        return 0;
}

/*
 * Get word from [rx + imm]
 *
 * Success: return 0
 * Failure: return 1
 */
static int ldw_c(struct pt_regs *regs, uint32_t rz, uint32_t addr)
{
        uint32_t byte0, byte1, byte2, byte3;

        if (ldb_asm(addr, &byte0))
                return 1;

        addr += 1;
        if (ldb_asm(addr, &byte1))
                return 1;

        addr += 1;
        if (ldb_asm(addr, &byte2))
                return 1;

        addr += 1;
        if (ldb_asm(addr, &byte3))
                return 1;

        byte0 |= byte1 << 8;
        byte0 |= byte2 << 16;
        byte0 |= byte3 << 24;

        put_ptreg(regs, rz, byte0);

        return 0;
}

/*
 * Store word to [rx + imm]
 *
 * Success: return 0
 * Failure: return 1
 */
static int stw_c(struct pt_regs *regs, uint32_t rz, uint32_t addr)
{
        uint32_t byte0, byte1, byte2, byte3;

        byte0 = byte1 = byte2 = byte3 = get_ptreg(regs, rz);

        byte0 &= 0xff;

        if (stb_asm(addr, byte0))
                return 1;

        addr += 1;
        byte1 = (byte1 >> 8) & 0xff;
        if (stb_asm(addr, byte1))
                return 1;

        addr += 1;
        byte2 = (byte2 >> 16) & 0xff;
        if (stb_asm(addr, byte2))
                return 1;

        addr += 1;
        byte3 = (byte3 >> 24) & 0xff;
        if (stb_asm(addr, byte3))
                return 1;

        return 0;
}

extern int fixup_exception(struct pt_regs *regs);

#define OP_LDH 0xc000
#define OP_STH 0xd000
#define OP_LDW 0x8000
#define OP_STW 0x9000

void csky_alignment(struct pt_regs *regs)
{
        int ret;
        uint16_t tmp;
        uint32_t opcode = 0;
        uint32_t rx     = 0;
        uint32_t rz     = 0;
        uint32_t imm    = 0;
        uint32_t addr   = 0;

        if (!user_mode(regs))
                goto kernel_area;

        if (!align_usr_enable) {
                pr_err("%s user disabled.\n", __func__);
                goto bad_area;
        }

        align_usr_count++;

        ret = get_user(tmp, (uint16_t *)instruction_pointer(regs));
        if (ret) {
                pr_err("%s get_user failed.\n", __func__);
                goto bad_area;
        }

        goto good_area;

kernel_area:
        if (!align_kern_enable) {
                pr_err("%s kernel disabled.\n", __func__);
                goto bad_area;
        }

        align_kern_count++;

        tmp = *(uint16_t *)instruction_pointer(regs);

good_area:
        opcode = (uint32_t)tmp;

        rx  = opcode & 0xf;
        imm = (opcode >> 4) & 0xf;
        rz  = (opcode >> 8) & 0xf;
        opcode &= 0xf000;

        if (rx == 0 || rx == 1 || rz == 0 || rz == 1)
                goto bad_area;

        switch (opcode) {
        case OP_LDH:
                addr = get_ptreg(regs, rx) + (imm << 1);
                ret = ldh_c(regs, rz, addr);
                break;
        case OP_LDW:
                addr = get_ptreg(regs, rx) + (imm << 2);
                ret = ldw_c(regs, rz, addr);
                break;
        case OP_STH:
                addr = get_ptreg(regs, rx) + (imm << 1);
                ret = sth_c(regs, rz, addr);
                break;
        case OP_STW:
                addr = get_ptreg(regs, rx) + (imm << 2);
                ret = stw_c(regs, rz, addr);
                break;
        }

        if (ret)
                goto bad_area;

        regs->pc += 2;

        return;

bad_area:
        if (!user_mode(regs)) {
                if (fixup_exception(regs))
                        return;

                bust_spinlocks(1);
                pr_alert("%s opcode: %x, rz: %d, rx: %d, imm: %d, addr: %x.\n",
                                __func__, opcode, rz, rx, imm, addr);
                show_regs(regs);
                bust_spinlocks(0);
                make_task_dead(SIGKILL);
        }

        force_sig_fault(SIGBUS, BUS_ADRALN, (void __user *)addr);
}

static const struct ctl_table alignment_tbl[] = {
        {
                .procname = "kernel_enable",
                .data = &align_kern_enable,
                .maxlen = sizeof(align_kern_enable),
                .mode = 0666,
                .proc_handler = &proc_dointvec
        },
        {
                .procname = "user_enable",
                .data = &align_usr_enable,
                .maxlen = sizeof(align_usr_enable),
                .mode = 0666,
                .proc_handler = &proc_dointvec
        },
        {
                .procname = "kernel_count",
                .data = &align_kern_count,
                .maxlen = sizeof(align_kern_count),
                .mode = 0666,
                .proc_handler = &proc_dointvec
        },
        {
                .procname = "user_count",
                .data = &align_usr_count,
                .maxlen = sizeof(align_usr_count),
                .mode = 0666,
                .proc_handler = &proc_dointvec
        },
};

static int __init csky_alignment_init(void)
{
        register_sysctl_init("csky/csky_alignment", alignment_tbl);
        return 0;
}

arch_initcall(csky_alignment_init);