root/tools/testing/selftests/arm64/fp/vlset.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2015-2019 ARM Limited.
 * Original author: Dave Martin <Dave.Martin@arm.com>
 */
#define _GNU_SOURCE
#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <getopt.h>
#include <unistd.h>
#include <sys/auxv.h>
#include <sys/prctl.h>
#include <asm/hwcap.h>
#include <asm/sigcontext.h>

/* Command-line state; written while the options are parsed. */
static int inherit;             /* set by -i / --inherit */
static int no_inherit;          /* set to 1 by getopt_long() for --no-inherit */
static int force;               /* set by -f / --force */
static unsigned long vl;        /* requested vector length; 0 means "not chosen yet" */

/* prctl() operations used to set/query the vector length (SVE by default). */
static int set_ctl = PR_SVE_SET_VL;
static int get_ctl = PR_SVE_GET_VL;

/*
 * Long options understood by getopt_long().  Entries without a .flag
 * make getopt_long() return .val; "no-inherit" instead stores .val into
 * no_inherit and makes getopt_long() return 0.
 */
static const struct option options[] = {
        { .name = "force",      .has_arg = no_argument, .val = 'f' },
        { .name = "inherit",    .has_arg = no_argument, .val = 'i' },
        { .name = "max",        .has_arg = no_argument, .val = 'M' },
        { .name = "no-inherit", .has_arg = no_argument,
          .flag = &no_inherit,  .val = 1 },
        { .name = "sme",        .has_arg = no_argument, .val = 's' },
        { .name = "help",       .has_arg = no_argument, .val = '?' },
        { NULL, 0, NULL, 0 }    /* terminator */
};

/* Basename of argv[0], used as the prefix of every diagnostic. */
static const char *program_name;

/*
 * Parse the command line into the file-scope option state.
 *
 * Sets program_name to the basename of argv[0], consumes the options,
 * then (unless -M/--max already chose a length) parses the vector length
 * operand.  On success optind indexes the command to execute.
 *
 * Returns 0 on success.  On any usage error (unknown option, -h/--help,
 * both --inherit and --no-inherit, missing or malformed vector length,
 * missing command) prints the usage text to stderr and returns -1.
 *
 * NOTE(review): with GNU getopt's default argv permutation, arguments of
 * <command> that start with '-' are presumably parsed as vlset options
 * unless the caller separates them with "--" - confirm before relying on it.
 */
static int parse_options(int argc, char **argv)
{
        int c;
        char *rest;

        program_name = strrchr(argv[0], '/');
        if (program_name)
                ++program_name;
        else
                program_name = argv[0];

        while ((c = getopt_long(argc, argv, "Mfhi", options, NULL)) != -1) {
                switch (c) {
                case 'M':
                        vl = SVE_VL_MAX;
                        break;
                case 'f':
                        force = 1;
                        break;
                case 'i':
                        inherit = 1;
                        break;
                case 's':
                        set_ctl = PR_SME_SET_VL;
                        get_ctl = PR_SME_GET_VL;
                        break;
                case 0:
                        /* Long option that stored its value via a flag. */
                        break;
                default:
                        /* -h, --help and anything unrecognised. */
                        goto error;
                }
        }

        if (inherit && no_inherit)
                goto error;

        if (!vl) {
                /* vector length */
                if (optind >= argc)
                        goto error;

                errno = 0;
                vl = strtoul(argv[optind], &rest, 0);
                /*
                 * Reject trailing junk and also the case where no digits
                 * were consumed at all (e.g. an empty argument), which
                 * strtoul() reports only by leaving rest at the start.
                 */
                if (rest == argv[optind] || *rest) {
                        vl = ULONG_MAX;
                        errno = EINVAL;
                }
                if (vl == ULONG_MAX && errno) {
                        fprintf(stderr, "%s: %s: %s\n",
                                program_name, argv[optind], strerror(errno));
                        goto error;
                }

                ++optind;
        }

        /* command */
        if (optind >= argc)
                goto error;

        return 0;

error:
        fprintf(stderr,
                "Usage: %s [-f | --force] "
                "[-i | --inherit | --no-inherit] "
                "[--sme] "
                "{-M | --max | <vector length>} "
                "<command> [<arguments> ...]\n",
                program_name);
        return -1;
}

/*
 * Set the vector length to apply on exec, then exec the given command.
 *
 * The prctl pair in set_ctl/get_ctl decides whether the SVE or the SME
 * vector length is configured; the matching hardware capability is
 * checked first and its absence is fatal unless --force was given.
 *
 * Exit status (when the exec does not succeed):
 *   2   - bad usage or a vector length with bits outside the length mask
 *   126 - feature absent, prctl() failure, or exec failure other than ENOENT
 *   127 - command not found (ENOENT from execvp())
 */
int main(int argc, char **argv)
{
        int ret = 126;  /* same as sh(1) command-not-executable error */
        unsigned long flags;
        char *path;
        int sme, present;
        int t, e;

        if (parse_options(argc, argv))
                return 2;       /* same as sh(1) builtin incorrect-usage */

        if (vl & ~(unsigned long)PR_SVE_VL_LEN_MASK) {
                fprintf(stderr, "%s: Invalid vector length %lu\n",
                        program_name, vl);
                return 2;       /* same as sh(1) builtin incorrect-usage */
        }

        /* Check the capability that matches the prctl we are about to use. */
        sme = (set_ctl == PR_SME_SET_VL);
        if (sme)
                present = !!(getauxval(AT_HWCAP2) & HWCAP2_SME);
        else
                present = !!(getauxval(AT_HWCAP) & HWCAP_SVE);

        if (!present) {
                fprintf(stderr, "%s: %s not present\n", program_name,
                        sme ? "Scalable Matrix Extension"
                            : "Scalable Vector Extension");

                if (!force)
                        goto error;

                fputs("Going ahead anyway (--force):  "
                      "This is a debug option.  Don't rely on it.\n",
                      stderr);
        }

        /*
         * The PR_SVE_* flag constants are used for both prctls, as before;
         * linux/prctl.h defines the PR_SME_* equivalents with the same values.
         */
        flags = PR_SVE_SET_VL_ONEXEC;
        if (inherit)
                flags |= PR_SVE_VL_INHERIT;

        t = prctl(set_ctl, vl | flags);
        if (t < 0) {
                fprintf(stderr, "%s: %s: %s\n", program_name,
                        sme ? "PR_SME_SET_VL" : "PR_SVE_SET_VL",
                        strerror(errno));
                goto error;
        }

        /* Query the configuration back; only the failure case is acted on. */
        t = prctl(get_ctl);
        if (t == -1) {
                fprintf(stderr, "%s: %s: %s\n", program_name,
                        sme ? "PR_SME_GET_VL" : "PR_SVE_GET_VL",
                        strerror(errno));
                goto error;
        }

        assert(optind < argc);
        path = argv[optind];

        execvp(path, &argv[optind]);
        /* Only reached if the exec failed; save errno before any other call. */
        e = errno;
        if (e == ENOENT)
                ret = 127;      /* same as sh(1) not-found error */
        fprintf(stderr, "%s: %s: %s\n", program_name, path, strerror(e));

error:
        return ret;             /* same as sh(1) not-executable error */
}