root/tools/testing/selftests/mm/write_to_hugetlbfs.c
// SPDX-License-Identifier: GPL-2.0
/*
 * This program reserves and uses hugetlb memory, supporting a bunch of
 * scenarios needed by the charged_reserved_hugetlb.sh test.
 */

#include <err.h>
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/shm.h>
#include <sys/stat.h>
#include <sys/mman.h>

/* Global definitions. */
/*
 * Allocation back-ends selectable with -m. The numeric values are what the
 * user passes on the command line, so they are spelled out explicitly.
 */
enum method {
        HUGETLBFS = 0,          /* mmap() of a file on a hugetlbfs mount */
        MMAP_MAP_HUGETLB = 1,   /* mmap() with the MAP_HUGETLB flag */
        SHM = 2,                /* shmget() with the SHM_HUGETLB flag */
        MAX_METHOD = 3          /* number of methods; also "not selected" */
};


/* Global variables. */
/* Program name; main() stores argv[0] here and exit_usage() prints it. */
static const char *self;
/*
 * Address at which the SysV shared memory segment is attached. main() sets
 * it only for the SHM method; sig_handler() tests it to decide whether there
 * is a segment to detach. It stays NULL for the other methods.
 */
static int *shmaddr;
/* Segment identifier returned by shmget(); used by main() and sig_handler(). */
static int shmid;

/*
 * Print the command-line synopsis to stdout and terminate the process with
 * EXIT_FAILURE. This function never returns.
 *
 * The program name comes from the file-scope `self` pointer. -m is listed
 * as mandatory because main() rejects an invocation without it, and all
 * three values of enum method are listed.
 */
static void exit_usage(void)
{
        printf("Usage: %s -p <path to hugetlbfs file> -s <size to map> "
               "-m <0=hugetlbfs | 1=mmap(MAP_HUGETLB) | 2=shm> [-l] [-r] "
               "[-o] [-w] [-n]\n",
               self);
        exit(EXIT_FAILURE);
}

/*
 * Signal handler; main() installs it for SIGINT.
 *
 * Reports the signal number, and when the signal is SIGINT and a shared
 * memory segment is attached (shmaddr != NULL), detaches the segment and
 * marks it for removal. Always terminates the process:
 *   exit status 4 - shmdt() failed (removal is still attempted),
 *   exit status 2 - every other case.
 *
 * NOTE(review): printf(), perror() and exit() are not on the POSIX list of
 * async-signal-safe functions; this is tolerated here as test-only code.
 */
void sig_handler(int signo)
{
        printf("Received %d.\n", signo);

        /* Nothing to tear down unless SIGINT arrived with a segment attached. */
        if (signo != SIGINT || !shmaddr)
                exit(2);

        printf("Deleting the memory\n");
        if (shmdt((const void *)shmaddr) == 0) {
                shmctl(shmid, IPC_RMID, NULL);
                printf("Done deleting the memory\n");
                exit(2);
        }

        /* Detach failed: report it, still try to remove the segment. */
        perror("Detach failure");
        shmctl(shmid, IPC_RMID, NULL);
        exit(4);
}

int main(int argc, char **argv)
{
        int fd = 0;
        int key = 0;
        int *ptr = NULL;
        int c = 0;
        size_t size = 0;
        char path[256] = "";
        enum method method = MAX_METHOD;
        int want_sleep = 0, private = 0;
        int populate = 0;
        int write = 0;
        int reserve = 1;

        if (signal(SIGINT, sig_handler) == SIG_ERR)
                err(1, "\ncan't catch SIGINT\n");

        /* Parse command-line arguments. */
        setvbuf(stdout, NULL, _IONBF, 0);
        self = argv[0];

        while ((c = getopt(argc, argv, "s:p:m:owlrn")) != -1) {
                switch (c) {
                case 's':
                        if (sscanf(optarg, "%zu", &size) != 1) {
                                perror("Invalid -s.");
                                exit_usage();
                        }
                        break;
                case 'p':
                        strncpy(path, optarg, sizeof(path) - 1);
                        break;
                case 'm':
                        if (atoi(optarg) >= MAX_METHOD) {
                                errno = EINVAL;
                                perror("Invalid -m.");
                                exit_usage();
                        }
                        method = atoi(optarg);
                        break;
                case 'o':
                        populate = 1;
                        break;
                case 'w':
                        write = 1;
                        break;
                case 'l':
                        want_sleep = 1;
                        break;
                case 'r':
                    private
                        = 1;
                        break;
                case 'n':
                        reserve = 0;
                        break;
                default:
                        errno = EINVAL;
                        perror("Invalid arg");
                        exit_usage();
                }
        }

        if (strncmp(path, "", sizeof(path)) != 0) {
                printf("Writing to this path: %s\n", path);
        } else {
                errno = EINVAL;
                perror("path not found");
                exit_usage();
        }

        if (size != 0) {
                printf("Writing this size: %zu\n", size);
        } else {
                errno = EINVAL;
                perror("size not found");
                exit_usage();
        }

        if (!populate)
                printf("Not populating.\n");
        else
                printf("Populating.\n");

        if (!write)
                printf("Not writing to memory.\n");

        if (method == MAX_METHOD) {
                errno = EINVAL;
                perror("-m Invalid");
                exit_usage();
        } else
                printf("Using method=%d\n", method);

        if (!private)
                printf("Shared mapping.\n");
        else
                printf("Private mapping.\n");

        if (!reserve)
                printf("NO_RESERVE mapping.\n");
        else
                printf("RESERVE mapping.\n");

        switch (method) {
        case HUGETLBFS:
                printf("Allocating using HUGETLBFS.\n");
                fd = open(path, O_CREAT | O_RDWR, 0777);
                if (fd == -1)
                        err(1, "Failed to open file.");

                ptr = mmap(NULL, size, PROT_READ | PROT_WRITE,
                           (private ? MAP_PRIVATE : MAP_SHARED) |
                                   (populate ? MAP_POPULATE : 0) |
                                   (reserve ? 0 : MAP_NORESERVE),
                           fd, 0);

                if (ptr == MAP_FAILED) {
                        close(fd);
                        err(1, "Error mapping the file");
                }
                break;
        case MMAP_MAP_HUGETLB:
                printf("Allocating using MAP_HUGETLB.\n");
                ptr = mmap(NULL, size, PROT_READ | PROT_WRITE,
                           (private ? (MAP_PRIVATE | MAP_ANONYMOUS) :
                                      MAP_SHARED) |
                                   MAP_HUGETLB | (populate ? MAP_POPULATE : 0) |
                                   (reserve ? 0 : MAP_NORESERVE),
                           -1, 0);

                if (ptr == MAP_FAILED)
                        err(1, "mmap");

                printf("Returned address is %p\n", ptr);
                break;
        case SHM:
                printf("Allocating using SHM.\n");
                shmid = shmget(key, size,
                               SHM_HUGETLB | IPC_CREAT | SHM_R | SHM_W);
                if (shmid < 0) {
                        shmid = shmget(++key, size,
                                       SHM_HUGETLB | IPC_CREAT | SHM_R | SHM_W);
                        if (shmid < 0)
                                err(1, "shmget");
                }
                printf("shmid: 0x%x, shmget key:%d\n", shmid, key);

                ptr = shmat(shmid, NULL, 0);
                if (ptr == (int *)-1) {
                        perror("Shared memory attach failure");
                        shmctl(shmid, IPC_RMID, NULL);
                        exit(2);
                }
                shmaddr = ptr;
                printf("shmaddr: %p\n", shmaddr);

                break;
        default:
                errno = EINVAL;
                err(1, "Invalid method.");
        }

        if (write) {
                printf("Writing to memory.\n");
                memset(ptr, 1, size);
        }

        if (want_sleep) {
                /* Signal to caller that we're done. */
                printf("DONE\n");

                /* Hold memory until external kill signal is delivered. */
                while (1)
                        sleep(100);
        }

        if (method == HUGETLBFS)
                close(fd);

        return 0;
}