root/tools/testing/selftests/net/psock_snd.c
// SPDX-License-Identifier: GPL-2.0

#define _GNU_SOURCE

#include <arpa/inet.h>
#include <errno.h>
#include <error.h>
#include <fcntl.h>
#include <limits.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/if_packet.h>
#include <linux/if_vlan.h>
#include <linux/virtio_net.h>
#include <net/if.h>
#include <net/ethernet.h>
#include <netinet/ip.h>
#include <netinet/udp.h>
#include <poll.h>
#include <sched.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

#include "psock_lib.h"

/*
 * Test configuration. The boolean flags start out false and are switched on
 * by the matching command line option in parse_opts().
 */
/* -b: bind the tx socket in do_tx() and send with write() instead of sendto() */
static bool     cfg_use_bind;
/* -c: request checksum offload in the virtio_net_hdr and seed the UDP check field */
static bool     cfg_use_csum_off;
/* -C: with -c, move csum_start so the check field extends past the packet end */
static bool     cfg_use_csum_off_bad;
/* -d: open the tx socket as SOCK_DGRAM and leave the first ETH_HLEN bytes out of the send */
static bool     cfg_use_dgram;
/* -g: fill in gso_type and gso_size in the virtio_net_hdr */
static bool     cfg_use_gso;
/* -q: set PACKET_QDISC_BYPASS on the tx socket */
static bool     cfg_use_qdisc_bypass;
/* -V: build the ethernet header with an 802.1Q ethertype and 4 extra bytes */
static bool     cfg_use_vlan;
/* -v: set PACKET_VNET_HDR on the tx socket and keep the virtio_net_hdr in the send */
static bool     cfg_use_vnet;

/* interface used for bind() and sendto(); no option changes it */
static char     *cfg_ifname = "lo";
/* only used to derive gso_size in build_vnet_header() */
static int      cfg_mtu = 1500;
/* -l: number of payload bytes following the UDP header */
static int      cfg_payload_len = DATA_LEN;
/* -t: upper bound on the packet length handed to do_send() */
static int      cfg_truncate_len = INT_MAX;
/* UDP destination port of the test packet and port the receive socket binds to */
static uint16_t cfg_port = 8000;

/* test sending up to max mtu + 1 */
#define TEST_SZ (sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)

/* tbuf holds the packet built by build_packet(); rbuf is the receive buffer of do_rx() */
static char tbuf[TEST_SZ], rbuf[TEST_SZ];

/*
 * Add up @num_u16 consecutive 16-bit words starting at @start.
 *
 * The words are summed exactly as stored in memory (no byte swapping) and
 * the total is returned unfolded, so carries above bit 15 are left for the
 * caller to deal with. A count of zero or less yields 0.
 */
static unsigned long add_csum_hword(const uint16_t *start, int num_u16)
{
        const uint16_t *word = start;
        unsigned long total = 0;
        int remaining;

        for (remaining = num_u16; remaining > 0; remaining--)
                total += *word++;

        return total;
}

/*
 * Compute a 16-bit ones'-complement checksum.
 *
 * @start:   first 16-bit word to include
 * @num_u16: number of 16-bit words to include
 * @sum:     initial value added to the words before folding
 *
 * Returns the bitwise complement of the folded sum, truncated to 16 bits.
 */
static uint16_t build_ip_csum(const uint16_t *start, int num_u16,
                              unsigned long sum)
{
        unsigned long acc = sum + add_csum_hword(start, num_u16);

        /* end-around carry: fold until nothing is left above bit 15 */
        while ((acc >> 16) != 0)
                acc = (acc >> 16) + (acc & 0xffff);

        return (uint16_t)~acc;
}

/*
 * Fill in the struct virtio_net_hdr at @header according to the cfg_use_*
 * flags and return its size in bytes.
 *
 * Fields that are not assigned here keep whatever value @header already
 * holds.
 */
static int build_vnet_header(void *header)
{
        const size_t l4_start = ETH_HLEN + sizeof(struct iphdr);
        struct virtio_net_hdr *vnet = header;

        /* headers span ethernet + IPv4 + UDP */
        vnet->hdr_len = l4_start + sizeof(struct udphdr);

        if (cfg_use_csum_off) {
                vnet->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM;
                vnet->csum_offset = __builtin_offsetof(struct udphdr, check);
                vnet->csum_start = l4_start;

                /*
                 * Bad variant: shift csum_start so that csum_start +
                 * csum_offset lands on the last byte of an untagged frame,
                 * leaving the 16-bit check field one byte past the end.
                 */
                if (cfg_use_csum_off_bad)
                        vnet->csum_start += sizeof(struct udphdr) +
                                            cfg_payload_len -
                                            vnet->csum_offset - 1;
        }

        if (cfg_use_gso) {
                vnet->gso_size = cfg_mtu - sizeof(struct iphdr);
                vnet->gso_type = VIRTIO_NET_HDR_GSO_UDP;
        }

        return sizeof(*vnet);
}

/*
 * Write the ethertype of the ethernet header at @header and return the
 * header length in bytes.
 *
 * Without cfg_use_vlan the ethertype is ETH_P_IP and the length is ETH_HLEN.
 * With cfg_use_vlan the ethertype is ETH_P_8021Q, ETH_P_IP is stored in the
 * second 16-bit word after the ethernet header, and the length grows by 4.
 * MAC addresses and the first 16-bit word of the tag are not written.
 */
static int build_eth_header(void *header)
{
        struct ethhdr *eth = header;
        uint16_t *vlan;

        if (!cfg_use_vlan) {
                eth->h_proto = htons(ETH_P_IP);
                return ETH_HLEN;
        }

        /* the tag sits directly behind the regular ethernet header */
        vlan = (uint16_t *)((char *)header + ETH_HLEN);

        eth->h_proto = htons(ETH_P_8021Q);
        vlan[1] = htons(ETH_P_IP);

        return ETH_HLEN + 4;
}

/*
 * Build an IPv4 header for a UDP datagram at @header.
 *
 * @header:      where the header is written
 * @payload_len: number of payload bytes that follow the UDP header
 *
 * The packet goes from 172.17.0.2 to 172.17.0.1. Fields that are not
 * assigned here keep whatever value @header already holds.
 *
 * Returns the header length in bytes (ihl * 4).
 */
static int build_ipv4_header(void *header, int payload_len)
{
        struct iphdr *iph = header;

        iph->ihl = 5;
        iph->version = 4;
        iph->ttl = 8;
        iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
        iph->id = htons(1337);
        iph->protocol = IPPROTO_UDP;
        /* unsigned constants: 172 << 24 does not fit in a signed int */
        iph->saddr = htonl((172U << 24) | (17U << 16) | 2U);
        iph->daddr = htonl((172U << 24) | (17U << 16) | 1U);

        /* the checksum covers the header with its own field set to zero */
        iph->check = 0;
        /* ihl counts 32-bit words, so ihl << 1 is the number of 16-bit words */
        iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0);

        return iph->ihl << 2;
}

/*
 * Build the UDP header at @header for a datagram carrying @payload_len
 * bytes: source port 9, destination port cfg_port.
 *
 * With cfg_use_csum_off the check field is seeded from the 8 bytes located
 * immediately before @header (the IPv4 source and destination addresses when
 * called from build_packet()) plus the protocol number and the UDP length.
 * Otherwise the check field is zero.
 *
 * NOTE(review): the stored seed is the complemented sum returned by
 * build_ip_csum(); confirm this is what the checksum offload consumer
 * expects.
 *
 * Returns the size of the UDP header in bytes.
 */
static int build_udp_header(void *header, int payload_len)
{
        const int addr_len = sizeof(uint32_t);
        struct udphdr *udp = header;
        const int dgram_len = sizeof(*udp) + payload_len;
        const void *addrs = (const char *)header - (2 * addr_len);

        udp->source = htons(9);
        udp->dest = htons(cfg_port);
        udp->len = htons(dgram_len);
        udp->check = 0;

        /* two addresses of addr_len bytes each are addr_len 16-bit words */
        if (cfg_use_csum_off)
                udp->check = build_ip_csum(addrs, addr_len,
                                           htons(IPPROTO_UDP) + udp->len);

        return sizeof(*udp);
}

/*
 * Assemble the complete test packet in tbuf: virtio_net_hdr, ethernet,
 * IPv4 and UDP headers followed by @payload_len bytes of DATA_CHAR.
 *
 * Exits with an error if @payload_len is negative or if the payload does
 * not fit in tbuf behind the headers.
 *
 * Returns the total number of bytes written, headers included.
 */
static int build_packet(int payload_len)
{
        int off = 0;

        /*
         * Reject negative lengths up front: a small negative value would
         * slip past the size check below and reach memset() as a huge size.
         */
        if (payload_len < 0)
                error(1, 0, "payload length is negative");

        off += build_vnet_header(tbuf);
        off += build_eth_header(tbuf + off);
        off += build_ipv4_header(tbuf + off, payload_len);
        off += build_udp_header(tbuf + off, payload_len);

        /* compare against the remaining room to avoid signed overflow */
        if ((size_t)payload_len > sizeof(tbuf) - off)
                error(1, 0, "payload length exceeds max");

        memset(tbuf + off, DATA_CHAR, payload_len);

        return off + payload_len;
}

/*
 * Bind packet socket @fd to interface cfg_ifname for protocol ETH_P_IP.
 * Exits with an error if the interface cannot be resolved or bind() fails.
 */
static void do_bind(int fd)
{
        struct sockaddr_ll laddr = {
                .sll_family = AF_PACKET,
                .sll_protocol = htons(ETH_P_IP),
        };

        laddr.sll_ifindex = if_nametoindex(cfg_ifname);
        if (laddr.sll_ifindex == 0)
                error(1, errno, "if_nametoindex");

        if (bind(fd, (void *)&laddr, sizeof(laddr)) != 0)
                error(1, errno, "bind");
}

/*
 * Send the packet in @buf over packet socket @fd.
 *
 * @buf: packet as laid out by build_packet(), virtio_net_hdr first
 * @len: number of bytes of @buf to consider
 *
 * The leading virtio_net_hdr is skipped unless cfg_use_vnet is set, and the
 * following ETH_HLEN bytes are skipped when cfg_use_dgram is set. With
 * cfg_use_bind the data is sent with write(); otherwise with sendto() to
 * interface cfg_ifname.
 *
 * Exits with an error if nothing is left to send after skipping headers, if
 * the send fails, or if fewer bytes than requested were sent.
 */
static void do_send(int fd, char *buf, int len)
{
        int ret;

        if (!cfg_use_vnet) {
                buf += sizeof(struct virtio_net_hdr);
                len -= sizeof(struct virtio_net_hdr);
        }
        if (cfg_use_dgram) {
                buf += ETH_HLEN;
                len -= ETH_HLEN;
        }

        /* a negative length would reach write()/sendto() as a huge size_t */
        if (len < 0)
                error(1, 0, "write: length %d is shorter than the skipped headers",
                      len);

        if (cfg_use_bind) {
                ret = write(fd, buf, len);
        } else {
                struct sockaddr_ll laddr = {0};

                /* fill in the address family, as do_bind() does */
                laddr.sll_family = AF_PACKET;
                laddr.sll_protocol = htons(ETH_P_IP);
                laddr.sll_ifindex = if_nametoindex(cfg_ifname);
                if (!laddr.sll_ifindex)
                        error(1, errno, "if_nametoindex");

                ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr));
        }

        if (ret == -1)
                error(1, errno, "write");
        if (ret != len)
                error(1, 0, "write: %u %u", ret, len);

        fprintf(stderr, "tx: %u\n", ret);
}

/*
 * Open the transmit packet socket, apply the configured socket options,
 * build the test packet and send it.
 *
 * Returns the packet length handed to do_send(): the length reported by
 * build_packet(), capped at cfg_truncate_len.
 */
static int do_tx(void)
{
        const int sock_type = cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW;
        const int enable = 1;
        int fd, len;

        fd = socket(PF_PACKET, sock_type, 0);
        if (fd == -1)
                error(1, errno, "socket t");

        if (cfg_use_bind)
                do_bind(fd);

        if (cfg_use_qdisc_bypass) {
                if (setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS,
                               &enable, sizeof(enable)))
                        error(1, errno, "setsockopt qdisc bypass");
        }

        if (cfg_use_vnet) {
                if (setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR,
                               &enable, sizeof(enable)))
                        error(1, errno, "setsockopt vnet");
        }

        len = build_packet(cfg_payload_len);
        if (len > cfg_truncate_len)
                len = cfg_truncate_len;

        do_send(fd, tbuf, len);

        if (close(fd) != 0)
                error(1, errno, "close t");

        return len;
}

/*
 * Create the UDP socket that receives the test packet: bound to cfg_port on
 * INADDR_ANY, with a receive timeout of 100 milliseconds.
 *
 * Returns the socket descriptor; exits with an error on any failure.
 */
static int setup_rx(void)
{
        struct sockaddr_in raddr = {
                .sin_family = AF_INET,
                .sin_port = htons(cfg_port),
                .sin_addr = { .s_addr = htonl(INADDR_ANY) },
        };
        struct timeval timeout = { .tv_sec = 0, .tv_usec = 100 * 1000 };
        int fd = socket(PF_INET, SOCK_DGRAM, 0);

        if (fd == -1)
                error(1, errno, "socket r");

        if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != 0)
                error(1, errno, "setsockopt rcv timeout");

        if (bind(fd, (void *)&raddr, sizeof(raddr)) != 0)
                error(1, errno, "bind r");

        return fd;
}

/*
 * Receive one message from @fd into rbuf and verify it.
 *
 * @fd:           socket to read from
 * @expected_len: number of bytes the message must have
 * @expected:     bytes the message must match
 *
 * Exits with an error if recv() fails or if the length or contents differ.
 */
static void do_rx(int fd, int expected_len, char *expected)
{
        int got = recv(fd, rbuf, sizeof(rbuf), 0);

        if (got == -1)
                error(1, errno, "recv");

        if (got != expected_len)
                error(1, 0, "recv: %u != %u", got, expected_len);

        if (memcmp(rbuf, expected, got) != 0)
                error(1, 0, "recv: data mismatch");

        fprintf(stderr, "rx: %u\n", got);
}

/*
 * Create the raw packet socket used to observe the transmitted frame: 100
 * millisecond receive timeout, filter installed by pair_udp_setfilter(),
 * bound through do_bind().
 *
 * Returns the socket descriptor; exits with an error on any failure.
 */
static int setup_sniffer(void)
{
        struct timeval timeout = { .tv_sec = 0, .tv_usec = 100 * 1000 };
        int fd = socket(PF_PACKET, SOCK_RAW, 0);

        if (fd == -1)
                error(1, errno, "socket p");

        if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) != 0)
                error(1, errno, "setsockopt rcv timeout");

        /* install the filter before binding, then attach to the interface */
        pair_udp_setfilter(fd);
        do_bind(fd);

        return fd;
}

/*
 * Convert the argument of option -@opt to a non-negative int.
 *
 * Exits with an error if @arg is empty, has trailing characters, is
 * negative, or does not fit in an int. The base is auto-detected, so
 * decimal, octal (leading 0) and hex (leading 0x) are accepted.
 */
static int parse_len_opt(const char *arg, char opt)
{
        char *end;
        long val;

        errno = 0;
        val = strtol(arg, &end, 0);
        if (end == arg || *end != '\0' || errno == ERANGE ||
            val < 0 || val > INT_MAX)
                error(1, 0, "option -%c: invalid length '%s'", opt, arg);

        return val;
}

/*
 * Parse the command line into the cfg_* variables.
 *
 * Exits with an error on an unknown option, a malformed numeric argument,
 * or a combination of options that the test does not support.
 */
static void parse_opts(int argc, char **argv)
{
        int c;

        while ((c = getopt(argc, argv, "bcCdgl:qt:vV")) != -1) {
                switch (c) {
                case 'b':
                        cfg_use_bind = true;
                        break;
                case 'c':
                        cfg_use_csum_off = true;
                        break;
                case 'C':
                        cfg_use_csum_off_bad = true;
                        break;
                case 'd':
                        cfg_use_dgram = true;
                        break;
                case 'g':
                        cfg_use_gso = true;
                        break;
                case 'l':
                        cfg_payload_len = parse_len_opt(optarg, 'l');
                        break;
                case 'q':
                        cfg_use_qdisc_bypass = true;
                        break;
                case 't':
                        cfg_truncate_len = parse_len_opt(optarg, 't');
                        break;
                case 'v':
                        cfg_use_vnet = true;
                        break;
                case 'V':
                        cfg_use_vlan = true;
                        break;
                default:
                        error(1, 0, "%s: parse error", argv[0]);
                }
        }

        /* reject option combinations the packet builders cannot honor */
        if (cfg_use_vlan && cfg_use_dgram)
                error(1, 0, "option vlan (-V) conflicts with dgram (-d)");

        if (cfg_use_csum_off && !cfg_use_vnet)
                error(1, 0, "option csum offload (-c) requires vnet (-v)");

        if (cfg_use_csum_off_bad && !cfg_use_csum_off)
                error(1, 0, "option csum bad (-C) requires csum offload (-c)");

        if (cfg_use_gso && !cfg_use_csum_off)
                error(1, 0, "option gso (-g) requires csum offload (-c)");
}

/*
 * Run one send/receive cycle: open the UDP receiver and the sniffer, send
 * the test packet, then check what arrives on each socket.
 *
 * The sniffer is only checked for the default payload length without vlan.
 * The UDP receiver must always see cfg_payload_len bytes equal to the tail
 * of the transmitted packet.
 */
static void run_test(void)
{
        const int vnet_hlen = sizeof(struct virtio_net_hdr);
        int fdr = setup_rx();
        int fds = setup_sniffer();
        int total_len = do_tx();
        bool check_sniffer;

        /* BPF filter accepts only this length, vlan changes MAC */
        check_sniffer = cfg_payload_len == DATA_LEN && !cfg_use_vlan;
        if (check_sniffer)
                do_rx(fds, total_len - vnet_hlen, tbuf + vnet_hlen);

        do_rx(fdr, cfg_payload_len, tbuf + total_len - cfg_payload_len);

        if (close(fds) != 0)
                error(1, errno, "close s");
        if (close(fdr) != 0)
                error(1, errno, "close r");
}

/*
 * Run @cmd through system() and exit with @what as the error message unless
 * it reports success.
 *
 * errno is only included in the message when system() itself fails (returns
 * -1); a non-zero status from the command does not set errno.
 */
static void run_setup_cmd(const char *cmd, const char *what)
{
        int status = system(cmd);

        if (status == -1)
                error(1, errno, "%s", what);
        if (status != 0)
                error(1, 0, "%s", what);
}

/*
 * Parse the options, configure the loopback device (MTU 1500, address
 * 172.17.0.1/24, accept_local enabled) and run the test.
 *
 * Prints "OK" and returns 0 on success; every failure exits with status 1.
 */
int main(int argc, char **argv)
{
        parse_opts(argc, argv);

        run_setup_cmd("ip link set dev lo mtu 1500", "ip link set mtu");
        run_setup_cmd("ip addr add dev lo 172.17.0.1/24", "ip addr add");
        run_setup_cmd("sysctl -w net.ipv4.conf.lo.accept_local=1",
                      "sysctl lo.accept_local");

        run_test();

        fprintf(stderr, "OK\n\n");
        return 0;
}