root/tools/testing/vsock/vsock_test_zerocopy.c
// SPDX-License-Identifier: GPL-2.0-only
/* MSG_ZEROCOPY feature tests for vsock
 *
 * Copyright (C) 2023 SberDevices.
 *
 * Author: Arseniy Krasnov <avkrasnov@salutedevices.com>
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <unistd.h>
#include <poll.h>
#include <linux/errqueue.h>
#include <linux/kernel.h>
#include <linux/sockios.h>
#include <linux/time64.h>
#include <errno.h>

#include "control.h"
#include "timeout.h"
#include "vsock_test_zerocopy.h"
#include "msg_zerocopy_common.h"

/* Fallback used to size the test vectors when none of the included
 * headers already defines PAGE_SIZE.
 */
#ifndef PAGE_SIZE
#define PAGE_SIZE               4096
#endif

/* Capacity of the 'vecs' array in 'struct vsock_test_data'. */
#define VSOCK_TEST_DATA_MAX_IOV 3

/* Description of a single MSG_ZEROCOPY test case. The same entry is
 * passed to both the client and the server side of the test.
 */
struct vsock_test_data {
        /* This test case is for SOCK_STREAM only. */
        bool stream_only;
        /* Data must be zerocopied. This field is checked against
         * field 'ee_code' of the 'struct sock_extended_err', which
         * contains a bit to detect that zerocopy transmission fell
         * back to copy mode.
         */
        bool zerocopied;
        /* Enable SO_ZEROCOPY option on the socket. Without enabled
         * SO_ZEROCOPY, every MSG_ZEROCOPY transmission will behave
         * like without MSG_ZEROCOPY flag.
         */
        bool so_zerocopy;
        /* Expected 'errno' after the 'sendmsg()' call; 0 means the
         * call is expected to succeed.
         */
        int sendmsg_errno;
        /* Number of valid elements in 'vecs'. */
        int vecs_cnt;
        /* I/O vectors describing the payload; only the first
         * 'vecs_cnt' entries are used.
         */
        struct iovec vecs[VSOCK_TEST_DATA_MAX_IOV];
};

/* Table of MSG_ZEROCOPY scenarios, walked in the same order by the
 * client and the server side. The 'iov_base' values are templates that
 * the client passes to alloc_test_iovec(); how NULL, '(void *)1' and
 * MAP_FAILED are interpreted is up to that helper (not defined in this
 * file).
 */
static struct vsock_test_data test_data_array[] = {
        /* Last element has non-page aligned size. */
        {
                .so_zerocopy = true,
                .zerocopied = true,
                .sendmsg_errno = 0,
                .vecs_cnt = 3,
                .vecs = {
                        { .iov_base = NULL, .iov_len = PAGE_SIZE },
                        { .iov_base = NULL, .iov_len = PAGE_SIZE },
                        { .iov_base = NULL, .iov_len = 200 },
                },
        },
        /* All elements have page aligned base and size. */
        {
                .so_zerocopy = true,
                .zerocopied = true,
                .sendmsg_errno = 0,
                .vecs_cnt = 3,
                .vecs = {
                        { .iov_base = NULL, .iov_len = PAGE_SIZE },
                        { .iov_base = NULL, .iov_len = PAGE_SIZE * 2 },
                        { .iov_base = NULL, .iov_len = PAGE_SIZE * 3 },
                },
        },
        /* All elements have page aligned base and size. But
         * data length is bigger than 64Kb.
         */
        {
                .so_zerocopy = true,
                .zerocopied = true,
                .sendmsg_errno = 0,
                .vecs_cnt = 3,
                .vecs = {
                        { .iov_base = NULL, .iov_len = PAGE_SIZE * 16 },
                        { .iov_base = NULL, .iov_len = PAGE_SIZE * 16 },
                        { .iov_base = NULL, .iov_len = PAGE_SIZE * 16 },
                },
        },
        /* Middle element has both non-page aligned base and size. */
        {
                .so_zerocopy = true,
                .zerocopied = true,
                .sendmsg_errno = 0,
                .vecs_cnt = 3,
                .vecs = {
                        { .iov_base = NULL, .iov_len = PAGE_SIZE },
                        { .iov_base = (void *)1, .iov_len = 100 },
                        { .iov_base = NULL, .iov_len = PAGE_SIZE },
                },
        },
        /* Middle element is unmapped. */
        {
                .so_zerocopy = true,
                .zerocopied = false,
                .sendmsg_errno = ENOMEM,
                .vecs_cnt = 3,
                .vecs = {
                        { .iov_base = NULL, .iov_len = PAGE_SIZE },
                        { .iov_base = MAP_FAILED, .iov_len = PAGE_SIZE },
                        { .iov_base = NULL, .iov_len = PAGE_SIZE },
                },
        },
        /* Valid data, but SO_ZEROCOPY is off. This
         * will trigger fallback to copy.
         */
        {
                .so_zerocopy = false,
                .zerocopied = false,
                .sendmsg_errno = 0,
                .vecs_cnt = 1,
                .vecs = {
                        { .iov_base = NULL, .iov_len = PAGE_SIZE },
                },
        },
        /* Valid data, but message is bigger than peer's
         * buffer, so this will trigger fallback to copy.
         * This test is for SOCK_STREAM only, because
         * for SOCK_SEQPACKET, 'sendmsg()' returns EMSGSIZE.
         */
        {
                .stream_only = true,
                .so_zerocopy = true,
                .zerocopied = false,
                .sendmsg_errno = 0,
                .vecs_cnt = 1,
                .vecs = {
                        { .iov_base = NULL, .iov_len = 100 * PAGE_SIZE },
                },
        },
};

#define POLL_TIMEOUT_MS         100

/* Client half of one MSG_ZEROCOPY test case.
 *
 * Connects with vsock_seqpacket_connect() when 'sock_seqpacket' is set
 * and with vsock_stream_connect() otherwise, optionally enables
 * SO_ZEROCOPY, sends the iovec built from 'test_data' with MSG_ZEROCOPY
 * and checks that:
 *  - the error reported by 'sendmsg()' equals 'test_data->sendmsg_errno';
 *  - on success, 'sendmsg()' accepted the whole payload;
 *  - POLLERR is raised when SO_ZEROCOPY is on and the send succeeded.
 * Afterwards the payload hash (0 when the send was expected to fail)
 * and "DONE" are written to the control channel.
 *
 * Any mismatch terminates the process with EXIT_FAILURE.
 */
static void test_client(const struct test_opts *opts,
                        const struct vsock_test_data *test_data,
                        bool sock_seqpacket)
{
        struct pollfd fds = { 0 };
        struct msghdr msg = { 0 };
        ssize_t sendmsg_res;
        struct iovec *iovec;
        size_t expected_len;
        int sendmsg_err;
        int fd;

        if (sock_seqpacket)
                fd = vsock_seqpacket_connect(opts->peer_cid, opts->peer_port);
        else
                fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);

        if (fd < 0) {
                perror("connect");
                exit(EXIT_FAILURE);
        }

        if (test_data->so_zerocopy)
                enable_so_zerocopy_check(fd);

        iovec = alloc_test_iovec(test_data->vecs, test_data->vecs_cnt);

        msg.msg_iov = iovec;
        msg.msg_iovlen = test_data->vecs_cnt;

        sendmsg_res = sendmsg(fd, &msg, MSG_ZEROCOPY);

        /* 'errno' is only meaningful when the call reported a failure,
         * so latch it right away and treat success as "no error".
         */
        sendmsg_err = sendmsg_res < 0 ? errno : 0;

        if (sendmsg_err != test_data->sendmsg_errno) {
                fprintf(stderr, "expected 'errno' == %i, got %i\n",
                        test_data->sendmsg_errno, sendmsg_err);
                exit(EXIT_FAILURE);
        }

        if (!sendmsg_err) {
                expected_len = iovec_bytes(iovec, test_data->vecs_cnt);

                /* 'sendmsg_res' is non-negative here, the cast is safe. */
                if ((size_t)sendmsg_res != expected_len) {
                        fprintf(stderr, "expected 'sendmsg()' == %zu, got %zi\n",
                                expected_len, sendmsg_res);
                        exit(EXIT_FAILURE);
                }
        }

        /* No events are requested: POLLERR is reported regardless. */
        fds.fd = fd;
        fds.events = 0;

        if (poll(&fds, 1, POLL_TIMEOUT_MS) < 0) {
                perror("poll");
                exit(EXIT_FAILURE);
        }

        if (fds.revents & POLLERR) {
                vsock_recv_completion(fd, &test_data->zerocopied);
        } else if (test_data->so_zerocopy && !test_data->sendmsg_errno) {
                /* If we don't have data in the error queue, but
                 * SO_ZEROCOPY was enabled and 'sendmsg()' was
                 * successful - this is an error.
                 */
                fprintf(stderr, "POLLERR expected\n");
                exit(EXIT_FAILURE);
        }

        /* Let the server compare what it received with what was sent. */
        if (!test_data->sendmsg_errno)
                control_writeulong(iovec_hash_djb2(iovec, test_data->vecs_cnt));
        else
                control_writeulong(0);

        control_writeln("DONE");
        free_test_iovec(test_data->vecs, iovec, test_data->vecs_cnt);
        close(fd);
}

void test_stream_msgzcopy_client(const struct test_opts *opts)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(test_data_array); i++)
                test_client(opts, &test_data_array[i], false);
}

void test_seqpacket_msgzcopy_client(const struct test_opts *opts)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(test_data_array); i++) {
                if (test_data_array[i].stream_only)
                        continue;

                test_client(opts, &test_data_array[i], true);
        }
}

/* Server half of one MSG_ZEROCOPY test case.
 *
 * Accepts a connection (vsock_seqpacket_accept() when 'sock_seqpacket'
 * is set, vsock_stream_accept() otherwise), reads up to the number of
 * bytes described by 'test_data->vecs' and compares the hash of the
 * received data with the hash sent by the client over the control
 * channel. When the client's 'sendmsg()' is expected to fail, both
 * hashes are expected to be 0.
 *
 * Any mismatch terminates the process with EXIT_FAILURE.
 */
static void test_server(const struct test_opts *opts,
                        const struct vsock_test_data *test_data,
                        bool sock_seqpacket)
{
        unsigned long remote_hash;
        unsigned long local_hash;
        size_t total_bytes_rec;
        unsigned char *data;
        size_t data_len;
        int fd;

        if (sock_seqpacket)
                fd = vsock_seqpacket_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
        else
                fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);

        if (fd < 0) {
                perror("accept");
                exit(EXIT_FAILURE);
        }

        data_len = iovec_bytes(test_data->vecs, test_data->vecs_cnt);

        data = malloc(data_len);
        if (!data) {
                perror("malloc");
                exit(EXIT_FAILURE);
        }

        total_bytes_rec = 0;

        /* Stop on EOF or error as well: when the client's send is
         * expected to fail, nothing (or less than 'data_len') arrives.
         */
        while (total_bytes_rec != data_len) {
                ssize_t bytes_rec;

                bytes_rec = read(fd, data + total_bytes_rec,
                                 data_len - total_bytes_rec);
                if (bytes_rec <= 0)
                        break;

                total_bytes_rec += bytes_rec;
        }

        if (test_data->sendmsg_errno == 0) {
                /* A successful send must deliver the whole payload.
                 * Hashing a partially filled buffer would read
                 * uninitialized memory, so report the short read.
                 */
                if (total_bytes_rec != data_len) {
                        fprintf(stderr, "expected %zu bytes, got %zu\n",
                                data_len, total_bytes_rec);
                        exit(EXIT_FAILURE);
                }

                local_hash = hash_djb2(data, data_len);
        } else {
                local_hash = 0;
        }

        free(data);

        /* Waiting for some result. */
        remote_hash = control_readulong();
        if (remote_hash != local_hash) {
                fprintf(stderr, "hash mismatch\n");
                exit(EXIT_FAILURE);
        }

        control_expectln("DONE");
        close(fd);
}

void test_stream_msgzcopy_server(const struct test_opts *opts)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(test_data_array); i++)
                test_server(opts, &test_data_array[i], false);
}

void test_seqpacket_msgzcopy_server(const struct test_opts *opts)
{
        int i;

        for (i = 0; i < ARRAY_SIZE(test_data_array); i++) {
                if (test_data_array[i].stream_only)
                        continue;

                test_server(opts, &test_data_array[i], true);
        }
}

/* Client side of the empty error queue test: connect, send nothing and
 * require that reading the socket error queue fails. "DONE" is then
 * written to the control channel.
 */
void test_stream_msgzcopy_empty_errq_client(const struct test_opts *opts)
{
        char ctrl_buf[128];
        struct msghdr hdr = {
                .msg_control = ctrl_buf,
                .msg_controllen = sizeof(ctrl_buf),
        };
        ssize_t ret;
        int sock;

        sock = vsock_stream_connect(opts->peer_cid, opts->peer_port);
        if (sock < 0) {
                perror("connect");
                exit(EXIT_FAILURE);
        }

        /* Nothing was transmitted on this socket, so the test expects
         * the MSG_ERRQUEUE read to return -1.
         */
        ret = recvmsg(sock, &hdr, MSG_ERRQUEUE);
        if (ret != -1) {
                fprintf(stderr, "expected 'recvmsg(2)' failure, got %zi\n",
                        ret);
                exit(EXIT_FAILURE);
        }

        control_writeln("DONE");
        close(sock);
}

/* Server side of the empty error queue test: accept the connection and
 * keep it open until the client reports "DONE" on the control channel.
 */
void test_stream_msgzcopy_empty_errq_server(const struct test_opts *opts)
{
        int sock = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);

        if (sock < 0) {
                perror("accept");
                exit(EXIT_FAILURE);
        }

        control_expectln("DONE");
        close(sock);
}

#define GOOD_COPY_LEN   128     /* net/vmw_vsock/virtio_transport_common.c */

/* Client side of the "mangle" test.
 *
 * Sends a buffer of PAGE_SIZE + 1 bytes without MSG_ZEROCOPY, then a
 * GOOD_COPY_LEN buffer of random bytes with MSG_ZEROCOPY. The hash of
 * the second buffer is written to the control channel so the server can
 * verify what it received. Finally the socket is polled for POLLERR for
 * up to TIMEOUT seconds; the process exits with EXIT_FAILURE if it does
 * not show up.
 */
void test_stream_msgzcopy_mangle_client(const struct test_opts *opts)
{
        char sbuf1[PAGE_SIZE + 1], sbuf2[GOOD_COPY_LEN];
        struct pollfd fds = { 0 };
        unsigned long hash;
        size_t i;
        int ret;
        int fd;

        fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
        if (fd < 0) {
                perror("connect");
                exit(EXIT_FAILURE);
        }

        enable_so_zerocopy_check(fd);

        memset(sbuf1, 'x', sizeof(sbuf1));
        send_buf(fd, sbuf1, sizeof(sbuf1), 0, sizeof(sbuf1));

        for (i = 0; i < sizeof(sbuf2); i++)
                sbuf2[i] = rand() & 0xff;

        send_buf(fd, sbuf2, sizeof(sbuf2), MSG_ZEROCOPY, sizeof(sbuf2));

        hash = hash_djb2(sbuf2, sizeof(sbuf2));
        control_writeulong(hash);

        /* No events are requested: POLLERR is reported regardless. */
        fds.fd = fd;
        fds.events = 0;

        ret = poll(&fds, 1, TIMEOUT * MSEC_PER_SEC);
        if (ret < 0) {
                perror("poll");
                exit(EXIT_FAILURE);
        }

        /* A timeout or a wakeup without POLLERR does not set 'errno',
         * so report it explicitly instead of through perror().
         */
        if (ret != 1 || !(fds.revents & POLLERR)) {
                fprintf(stderr, "POLLERR expected\n");
                exit(EXIT_FAILURE);
        }

        close(fd);
}

/* Server side of the "mangle" test.
 *
 * Receives and drops the first PAGE_SIZE + 1 bytes, then receives
 * GOOD_COPY_LEN bytes and compares their hash with the value the
 * client published on the control channel. A mismatch terminates the
 * process with EXIT_FAILURE.
 */
void test_stream_msgzcopy_mangle_server(const struct test_opts *opts)
{
        char rbuf[PAGE_SIZE + 1];
        unsigned long expected;
        int sock;

        sock = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
        if (sock < 0) {
                perror("accept");
                exit(EXIT_FAILURE);
        }

        /* Wait, don't race the (buggy) skbs coalescence. */
        vsock_ioctl_int(sock, SIOCINQ, PAGE_SIZE + 1 + GOOD_COPY_LEN);

        /* Discard the first packet. */
        recv_buf(sock, rbuf, PAGE_SIZE + 1, 0, PAGE_SIZE + 1);

        /* The second packet overwrites the start of the same buffer. */
        recv_buf(sock, rbuf, GOOD_COPY_LEN, 0, GOOD_COPY_LEN);

        expected = control_readulong();
        if (hash_djb2(rbuf, GOOD_COPY_LEN) != expected) {
                fprintf(stderr, "Data received corrupted\n");
                exit(EXIT_FAILURE);
        }

        close(sock);
}