root/src/tests/system/network/unix_dgram_test.cpp
/*
 * Copyright 2023, Trung Nguyen, trungnt282910@gmail.com.
 * Distributed under the terms of the MIT License.
 */


#include <errno.h>
#include <stdio.h>
#include <string.h>

#include <fcntl.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <unistd.h>


/*
 * Print a diagnostic to stderr, prefixed with "file:line: " of the call site.
 * `msg` must be a string literal because it is pasted into the format string;
 * a newline is always appended after it. The expansion evaluates to the
 * fprintf() return value.
 */
#define REPORT_ERROR(msg, ...) \
	fprintf(stderr, \
		"%s:%d: " msg "\n", \
		__FILE__, __LINE__, ##__VA_ARGS__)


/*
 * Exercise connect() on AF_UNIX datagram sockets.
 *
 * Checks that two bound sockets can connect to each other in both
 * directions, that an already-connected socket may be connected again,
 * that connecting to a socket already connected to a different peer fails
 * with EPERM, and that connecting to the path of a closed socket fails with
 * ECONNREFUSED.
 *
 * Returns 0 if every check passes, 1 on the first failure (after printing a
 * diagnostic via REPORT_ERROR).
 */
int
connect_test()
{
	// Remove socket files left over from an earlier run.
	unlink("test.sock");
	unlink("test1.sock");
	unlink("test2.sock");

	int status;

	int sock = socket(AF_UNIX, SOCK_DGRAM, 0);
	if (sock == -1) {
		REPORT_ERROR("socket() failed: %s\n", strerror(errno));
		return 1;
	}

	// Zero the whole address so no uninitialized bytes are handed to bind().
	struct sockaddr_un addr;
	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;
	strcpy(addr.sun_path, "test.sock");
	status = bind(sock, (struct sockaddr*)&addr, sizeof(addr));
	if (status == -1) {
		REPORT_ERROR("bind() failed: %s\n", strerror(errno));
		return 1;
	}

	int sock1 = socket(AF_UNIX, SOCK_DGRAM, 0);
	if (sock1 == -1) {
		REPORT_ERROR("socket() failed: %s\n", strerror(errno));
		return 1;
	}

	struct sockaddr_un addr1;
	memset(&addr1, 0, sizeof(addr1));
	addr1.sun_family = AF_UNIX;
	strcpy(addr1.sun_path, "test1.sock");
	status = bind(sock1, (struct sockaddr*)&addr1, sizeof(addr1));
	if (status == -1) {
		REPORT_ERROR("bind() failed: %s\n", strerror(errno));
		return 1;
	}

	// Set non-blocking on both sockets. Each socket's own current flags are
	// read first so that only O_NONBLOCK is added.
	int flags = fcntl(sock, F_GETFL, 0);
	if (flags == -1) {
		REPORT_ERROR("fcntl() failed: %s\n", strerror(errno));
		return 1;
	}
	status = fcntl(sock, F_SETFL, flags | O_NONBLOCK);
	if (status == -1) {
		REPORT_ERROR("fcntl() failed: %s\n", strerror(errno));
		return 1;
	}

	flags = fcntl(sock1, F_GETFL, 0);
	if (flags == -1) {
		REPORT_ERROR("fcntl() failed: %s\n", strerror(errno));
		return 1;
	}
	status = fcntl(sock1, F_SETFL, flags | O_NONBLOCK);
	if (status == -1) {
		REPORT_ERROR("fcntl() failed: %s\n", strerror(errno));
		return 1;
	}

	status = connect(sock, (struct sockaddr*)&addr1, sizeof(addr1));
	if (status == -1) {
		REPORT_ERROR("connect() failed: %s\n", strerror(errno));
		return 1;
	}

	// Connect in the opposite way
	status = connect(sock1, (struct sockaddr*)&addr, sizeof(addr));
	if (status == -1) {
		REPORT_ERROR("connect() failed: %s\n", strerror(errno));
		return 1;
	}

	// Reconnect a connected DGRAM socket
	status = connect(sock, (struct sockaddr*)&addr1, sizeof(addr1));
	if (status == -1) {
		REPORT_ERROR("connect() failed: %s\n", strerror(errno));
		return 1;
	}

	int sock2 = socket(AF_UNIX, SOCK_DGRAM, 0);
	if (sock2 == -1) {
		REPORT_ERROR("socket() failed: %s\n", strerror(errno));
		return 1;
	}

	struct sockaddr_un addr2;
	memset(&addr2, 0, sizeof(addr2));
	addr2.sun_family = AF_UNIX;
	strcpy(addr2.sun_path, "test2.sock");
	status = bind(sock2, (struct sockaddr*)&addr2, sizeof(addr2));
	if (status == -1) {
		REPORT_ERROR("bind() failed: %s\n", strerror(errno));
		return 1;
	}

	// Connect to a socket that is already connected to another peer (sock1's
	// peer is sock); this is expected to fail with EPERM.
	status = connect(sock2, (struct sockaddr*)&addr1, sizeof(addr1));
	if (status != -1) {
		REPORT_ERROR("connect() succeeded unexpectedly\n");
		return 1;
	}
	if (errno != EPERM) {
		REPORT_ERROR("connect() failed with unexpected error: %s\n", strerror(errno));
		return 1;
	}

	status = close(sock2);
	if (status == -1) {
		REPORT_ERROR("close() failed: %s\n", strerror(errno));
		return 1;
	}

	// Connect to a closed socket. The "test2.sock" path still exists, but no
	// socket is behind it any more, so ECONNREFUSED is expected.
	status = connect(sock, (struct sockaddr*)&addr2, sizeof(addr2));
	if (status != -1) {
		REPORT_ERROR("connect() succeeded unexpectedly\n");
		return 1;
	}
	if (errno != ECONNREFUSED) {
		REPORT_ERROR("connect() failed with unexpected error: %s\n", strerror(errno));
		return 1;
	}

	close(sock);
	close(sock1);

	unlink("test.sock");
	unlink("test1.sock");
	unlink("test2.sock");

	return 0;
}


/*
 * Exercise send()/sendto()/recv()/recvfrom() on AF_UNIX datagram sockets.
 *
 * Checks that:
 *  - send() on an unconnected socket fails,
 *  - sendto() and send() on a connected socket deliver datagrams in order,
 *  - sendto() to a socket connected to some other peer fails with EPERM,
 *  - recvfrom() reports the sender's bound path,
 *  - a datagram received into a too-small buffer is truncated and its
 *    remainder discarded (the next recv() sees the next datagram),
 *  - send() to a closed peer fails with ECONNREFUSED.
 *
 * Returns 0 if every check passes, 1 on the first failure.
 */
int
send_test()
{
	// Remove socket files left over from an earlier run.
	unlink("test.sock");
	unlink("test1.sock");
	unlink("test2.sock");

	int status;

	int sock = socket(AF_UNIX, SOCK_DGRAM, 0);
	if (sock == -1) {
		REPORT_ERROR("socket() failed: %s\n", strerror(errno));
		return 1;
	}

	struct sockaddr_un addr;
	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;
	strcpy(addr.sun_path, "test.sock");
	status = bind(sock, (struct sockaddr*)&addr, sizeof(addr));
	if (status == -1) {
		REPORT_ERROR("bind() failed: %s\n", strerror(errno));
		return 1;
	}

	// Sending without a peer or destination address must fail. errno is not
	// checked, presumably because platforms differ: POSIX lists both ENOTCONN
	// and EDESTADDRREQ for this situation.
	status = send(sock, "test", 4, 0);
	if (status != -1) {
		REPORT_ERROR("send() succeeded unexpectedly\n");
		return 1;
	}

	int sock1 = socket(AF_UNIX, SOCK_DGRAM, 0);
	if (sock1 == -1) {
		REPORT_ERROR("socket() failed: %s\n", strerror(errno));
		return 1;
	}

	struct sockaddr_un addr1;
	memset(&addr1, 0, sizeof(addr1));
	addr1.sun_family = AF_UNIX;
	strcpy(addr1.sun_path, "test1.sock");
	status = bind(sock1, (struct sockaddr*)&addr1, sizeof(addr1));
	if (status == -1) {
		REPORT_ERROR("bind() failed: %s\n", strerror(errno));
		return 1;
	}

	// Set non-blocking on both sockets, keeping their other status flags.
	int flags = fcntl(sock, F_GETFL, 0);
	if (flags == -1) {
		REPORT_ERROR("fcntl() failed: %s\n", strerror(errno));
		return 1;
	}
	status = fcntl(sock, F_SETFL, flags | O_NONBLOCK);
	if (status == -1) {
		REPORT_ERROR("fcntl() failed: %s\n", strerror(errno));
		return 1;
	}
	flags = fcntl(sock1, F_GETFL, 0);
	if (flags == -1) {
		REPORT_ERROR("fcntl() failed: %s\n", strerror(errno));
		return 1;
	}
	status = fcntl(sock1, F_SETFL, flags | O_NONBLOCK);
	if (status == -1) {
		REPORT_ERROR("fcntl() failed: %s\n", strerror(errno));
		return 1;
	}

	status = sendto(sock, "test1", 5, 0, (struct sockaddr*)&addr1, sizeof(addr1));
	if (status == -1) {
		REPORT_ERROR("sendto() failed: %s\n", strerror(errno));
		return 1;
	}

	status = connect(sock, (struct sockaddr*)&addr1, sizeof(addr1));
	if (status == -1) {
		REPORT_ERROR("connect() failed: %s\n", strerror(errno));
		return 1;
	}
	status = connect(sock1, (struct sockaddr*)&addr, sizeof(addr));
	if (status == -1) {
		REPORT_ERROR("connect() failed: %s\n", strerror(errno));
		return 1;
	}

	status = send(sock, "test2", 5, 0);
	if (status == -1) {
		REPORT_ERROR("send() failed: %s\n", strerror(errno));
		return 1;
	}

	int sock2 = socket(AF_UNIX, SOCK_DGRAM, 0);
	if (sock2 == -1) {
		REPORT_ERROR("socket() failed: %s\n", strerror(errno));
		return 1;
	}

	struct sockaddr_un addr2;
	memset(&addr2, 0, sizeof(addr2));
	addr2.sun_family = AF_UNIX;
	strcpy(addr2.sun_path, "test2.sock");
	status = bind(sock2, (struct sockaddr*)&addr2, sizeof(addr2));
	if (status == -1) {
		REPORT_ERROR("bind() failed: %s\n", strerror(errno));
		return 1;
	}

	// sock1 is connected to sock, so a datagram from sock2 must be rejected.
	status = sendto(sock2, "test3", 5, MSG_NOSIGNAL, (struct sockaddr*)&addr1, sizeof(addr1));
	if (status != -1) {
		REPORT_ERROR("sendto() succeeded unexpectedly\n");
		return 1;
	}
	if (errno != EPERM) {
		REPORT_ERROR("sendto() failed with unexpected error: %s\n", strerror(errno));
		return 1;
	}

	char buf[16];
	memset(buf, 0, sizeof(buf));
	status = recv(sock1, buf, sizeof(buf), 0);
	if (status == -1) {
		REPORT_ERROR("recv() failed: %s\n", strerror(errno));
		return 1;
	}
	if (strcmp(buf, "test1") != 0) {
		REPORT_ERROR("recv() received unexpected data: %s\n", buf);
		return 1;
	}

	memset(buf, 0, sizeof(buf));
	struct sockaddr_un addr3;
	memset(&addr3, 0, sizeof(addr3));
	socklen_t addrlen = sizeof(addr3);
	status = recvfrom(sock1, buf, sizeof(buf), 0, (struct sockaddr*)&addr3, &addrlen);
	if (status == -1) {
		REPORT_ERROR("recv() failed: %s\n", strerror(errno));
		return 1;
	}
	if (strcmp(buf, "test2") != 0) {
		REPORT_ERROR("recv() received unexpected data: %s\n", buf);
		return 1;
	}
	if (strcmp(addr.sun_path, addr3.sun_path) != 0) {
		REPORT_ERROR("recv() received unexpected address: %s\n", addr3.sun_path);
		return 1;
	}

	// Send the full 5-byte "test4" so that the 4-byte recv() below really
	// truncates it; sending only 4 bytes would make the truncation check
	// vacuous.
	status = send(sock, "test4", 5, 0);
	if (status == -1) {
		REPORT_ERROR("send() failed: %s\n", strerror(errno));
		return 1;
	}

	status = send(sock, "test5", 5, 0);
	if (status == -1) {
		REPORT_ERROR("send() failed: %s\n", strerror(errno));
		return 1;
	}

	memset(buf, 0, sizeof(buf));
	status = recv(sock1, buf, 4, 0);
	if (status == -1) {
		REPORT_ERROR("recv() failed: %s\n", strerror(errno));
		return 1;
	}
	if (strcmp(buf, "test") != 0) {
		REPORT_ERROR("recv() received unexpected data: %s\n", buf);
		return 1;
	}

	// The last byte of the previous datagram should be discarded.
	memset(buf, 0, sizeof(buf));
	status = recv(sock1, buf, sizeof(buf), 0);
	if (status == -1) {
		REPORT_ERROR("recv() failed: %s\n", strerror(errno));
		return 1;
	}
	if (strcmp(buf, "test5") != 0) {
		REPORT_ERROR("recv() received unexpected data: %s\n", buf);
		return 1;
	}

	// Once the connected peer is closed, sending to it must be refused.
	close(sock1);
	status = send(sock, "test6", 5, 0);
	if (status != -1) {
		REPORT_ERROR("send() succeeded unexpectedly\n");
		return 1;
	}
	if (errno != ECONNREFUSED) {
		REPORT_ERROR("send() failed with unexpected error: %s\n", strerror(errno));
		return 1;
	}

	close(sock);
	close(sock2);

	unlink("test.sock");
	unlink("test1.sock");
	unlink("test2.sock");

	return 0;
}


/*
 * Check that a datagram sent from an unbound AF_UNIX socket is delivered,
 * and that recvfrom() reports an empty sender path for it.
 *
 * Returns 0 on success, 1 on failure. Once created, the sockets and the
 * "test-socket-unix" file are released on every return path, including
 * success.
 */
int
send_unbound_test()
{
	unlink("test-socket-unix");
	int sock_server = socket(AF_UNIX, SOCK_DGRAM, 0);
	if (sock_server == -1) {
		REPORT_ERROR("socket() failed: %s\n", strerror(errno));
		return 1;
	}

	struct sockaddr_un addr;
	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;
	strcpy(addr.sun_path, "test-socket-unix");
	int status = bind(sock_server, (struct sockaddr*)&addr, sizeof(addr));
	if (status == -1) {
		REPORT_ERROR("bind() failed: %s\n", strerror(errno));
		unlink("test-socket-unix");
		close(sock_server);
		return 1;
	}

	// The client socket is intentionally never bound.
	int sock_client = socket(AF_UNIX, SOCK_DGRAM, 0);
	if (sock_client == -1) {
		REPORT_ERROR("socket() failed: %s\n", strerror(errno));
		unlink("test-socket-unix");
		close(sock_server);
		return 1;
	}

	status = sendto(sock_client, "t", 1, 0, (struct sockaddr*)&addr, sizeof(addr));
	if (status != 1) {
		REPORT_ERROR("sendto() failed: %s\n", strerror(errno));
		unlink("test-socket-unix");
		close(sock_server);
		close(sock_client);
		return 1;
	}

	char buf[1024];
	memset(buf, 0, sizeof(buf));
	struct sockaddr_un addr_sender;
	memset(&addr_sender, 0, sizeof(addr_sender));
	socklen_t addrlen = sizeof(addr_sender);
	status = recvfrom(sock_server, buf, sizeof(buf), 0, (struct sockaddr*)&addr_sender, &addrlen);

	int result = 0;
	if (status == -1) {
		REPORT_ERROR("recvfrom() failed: %s\n", strerror(errno));
		result = 1;
	} else if (status != 1 || strcmp(buf, "t") != 0
		|| strcmp(addr_sender.sun_path, "") != 0) {
		// The call itself succeeded, so errno carries no information here.
		REPORT_ERROR("recvfrom() returned unexpected data or sender address\n");
		result = 1;
	}

	// Release everything on both the success and the failure path.
	unlink("test-socket-unix");
	close(sock_server);
	close(sock_client);
	return result;
}


int
shutdown_test()
{
        unlink("test.sock");
        unlink("test1.sock");

        int status;

        int sock = socket(AF_UNIX, SOCK_DGRAM, 0);
        if (sock == -1) {
                REPORT_ERROR("socket() failed: %s\n", strerror(errno));
                return 1;
        }

        struct sockaddr_un addr;
        addr.sun_family = AF_UNIX;
        strcpy(addr.sun_path, "test.sock");
        status = bind(sock, (struct sockaddr*)&addr, sizeof(addr));
        if (status == -1) {
                REPORT_ERROR("bind() failed: %s\n", strerror(errno));
                return 1;
        }

        int sock1 = socket(AF_UNIX, SOCK_DGRAM, 0);
        if (sock1 == -1) {
                REPORT_ERROR("socket() failed: %s\n", strerror(errno));
                return 1;
        }

        struct sockaddr_un addr1;
        addr1.sun_family = AF_UNIX;
        strcpy(addr1.sun_path, "test1.sock");
        status = bind(sock1, (struct sockaddr*)&addr1, sizeof(addr1));
        if (status == -1) {
                REPORT_ERROR("bind() failed: %s\n", strerror(errno));
                return 1;
        }

        status = shutdown(sock, SHUT_WR);
        if (status == -1) {
                REPORT_ERROR("shutdown() failed: %s\n", strerror(errno));
                return 1;
        }

        status = sendto(sock, "test", 4, MSG_NOSIGNAL, (struct sockaddr*)&addr1, sizeof(addr1));
        if (status != -1) {
                REPORT_ERROR("send() succeeded unexpectedly\n");
                return 1;
        }
        if (errno != EPIPE) {
                REPORT_ERROR("send() failed with unexpected error: %s\n", strerror(errno));
                return 1;
        }

        status = sendto(sock1, "test", 4, 0, (struct sockaddr*)&addr, sizeof(addr));
        if (status == -1) {
                REPORT_ERROR("send() failed: %s\n", strerror(errno));
                return 1;
        }

        status = shutdown(sock, SHUT_RD);
        if (status == -1) {
                REPORT_ERROR("shutdown() failed: %s\n", strerror(errno));
                return 1;
        }

        status = sendto(sock1, "test", 4, MSG_NOSIGNAL, (struct sockaddr*)&addr, sizeof(addr));
        if (status != -1) {
                REPORT_ERROR("send() succeeded unexpectedly\n");
                return 1;
        }
        if (errno != EPIPE) {
                REPORT_ERROR("send() failed with unexpected error: %s\n", strerror(errno));
                return 1;
        }

        char buf[16];
        memset(buf, 0, sizeof(buf));
        status = recv(sock, buf, sizeof(buf), 0);
        if (status == -1) {
                REPORT_ERROR("recv() failed: %s\n", strerror(errno));
                return 1;
        }
        if (status != 0) {
                REPORT_ERROR("recv() received unexpected data\n");
                return 1;
        }

        close(sock);
        close(sock1);

        unlink("test.sock");
        unlink("test1.sock");

        return 0;
}


/*
 * Check that a file descriptor can be passed over a connected AF_UNIX
 * datagram socket with SCM_RIGHTS.
 *
 * A shared-memory object is opened, its descriptor is sent from `sock` to
 * `sock1` with sendmsg(), received with recvmsg(), and the received
 * descriptor is verified via fstat() to refer to the same device/inode.
 *
 * Returns 0 if every check passes, 1 on the first failure.
 */
int
send_fd_test()
{
	// Remove socket files left over from an earlier run.
	unlink("test.sock");
	unlink("test1.sock");

	int status;

	int sock = socket(AF_UNIX, SOCK_DGRAM, 0);
	if (sock == -1) {
		REPORT_ERROR("socket() failed: %s\n", strerror(errno));
		return 1;
	}

	// Zero the whole address so no uninitialized bytes are handed to bind().
	struct sockaddr_un addr;
	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;
	strcpy(addr.sun_path, "test.sock");
	status = bind(sock, (struct sockaddr*)&addr, sizeof(addr));
	if (status == -1) {
		REPORT_ERROR("bind() failed: %s\n", strerror(errno));
		return 1;
	}

	int sock1 = socket(AF_UNIX, SOCK_DGRAM, 0);
	if (sock1 == -1) {
		REPORT_ERROR("socket() failed: %s\n", strerror(errno));
		return 1;
	}

	struct sockaddr_un addr1;
	memset(&addr1, 0, sizeof(addr1));
	addr1.sun_family = AF_UNIX;
	strcpy(addr1.sun_path, "test1.sock");
	status = bind(sock1, (struct sockaddr*)&addr1, sizeof(addr1));
	if (status == -1) {
		REPORT_ERROR("bind() failed: %s\n", strerror(errno));
		return 1;
	}

	status = connect(sock, (struct sockaddr*)&addr1, sizeof(addr1));
	if (status == -1) {
		REPORT_ERROR("connect() failed: %s\n", strerror(errno));
		return 1;
	}

	int fd = shm_open("test_shm", O_CREAT | O_RDWR, 0666);
	if (fd == -1) {
		REPORT_ERROR("shm_open() failed: %s\n", strerror(errno));
		return 1;
	}
	// Only the open descriptor is needed from here on; drop the name.
	shm_unlink("test_shm");

	// Send FD
	char iobuf[] = "test";
	struct iovec iov = {
		.iov_base = iobuf,
		.iov_len = sizeof(iobuf),
	};

	struct msghdr msg;
	memset(&msg, 0, sizeof(msg));

	// Wrap the control buffer in a union with struct cmsghdr so that the
	// header returned by CMSG_FIRSTHDR() is correctly aligned.
	union {
		char buf[CMSG_SPACE(sizeof(fd))];
		struct cmsghdr align;
	} control;

	struct cmsghdr *cmsg;
	memset(&control, 0, sizeof(control));
	msg.msg_control = control.buf;
	msg.msg_controllen = sizeof(control.buf);
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type = SCM_RIGHTS;
	cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
	memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));
	msg.msg_controllen = cmsg->cmsg_len;

	status = sendmsg(sock, &msg, 0);
	if (status == -1) {
		REPORT_ERROR("sendmsg() failed: %s\n", strerror(errno));
		return 1;
	}

	// Receive FD
	memset(&control, 0, sizeof(control));
	msg.msg_control = control.buf;
	msg.msg_controllen = sizeof(control.buf);
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	status = recvmsg(sock1, &msg, 0);
	if (status == -1) {
		REPORT_ERROR("recvmsg() failed: %s\n", strerror(errno));
		return 1;
	}

	cmsg = CMSG_FIRSTHDR(&msg);
	if (cmsg == NULL) {
		REPORT_ERROR("recvmsg() failed: no control message\n");
		return 1;
	}
	if (cmsg->cmsg_level != SOL_SOCKET) {
		REPORT_ERROR("recvmsg() failed: unexpected level %d\n", cmsg->cmsg_level);
		return 1;
	}
	if (cmsg->cmsg_type != SCM_RIGHTS) {
		REPORT_ERROR("recvmsg() failed: unexpected type %d\n", cmsg->cmsg_type);
		return 1;
	}
	if (cmsg->cmsg_len != CMSG_LEN(sizeof(fd))) {
		// cmsg_len's type differs between platforms (socklen_t vs size_t),
		// so cast explicitly to match %ld.
		REPORT_ERROR("recvmsg() failed: unexpected length %ld\n", (long)cmsg->cmsg_len);
		return 1;
	}

	int fd1;
	memcpy(&fd1, CMSG_DATA(cmsg), sizeof(fd1));
	if (fd1 == -1) {
		REPORT_ERROR("recvmsg() failed: unexpected fd %d\n", fd1);
		return 1;
	}

	// Check that the FD refers to the same file
	struct stat statbuf;
	status = fstat(fd, &statbuf);
	if (status == -1) {
		REPORT_ERROR("fstat() failed: %s\n", strerror(errno));
		return 1;
	}

	struct stat statbuf1;
	status = fstat(fd1, &statbuf1);
	if (status == -1) {
		REPORT_ERROR("fstat() failed: %s\n", strerror(errno));
		return 1;
	}

	if (statbuf.st_dev != statbuf1.st_dev) {
		REPORT_ERROR("recvmsg() failed: unexpected device %ld\n", (long)statbuf1.st_dev);
		return 1;
	}
	if (statbuf.st_ino != statbuf1.st_ino) {
		REPORT_ERROR("recvmsg() failed: unexpected inode %ld\n", (long)statbuf1.st_ino);
		return 1;
	}

	close(sock);
	close(sock1);
	close(fd);
	close(fd1);

	unlink("test.sock");
	unlink("test1.sock");

	return 0;
}


/*
 * Run each test function in order and stop at the first one that reports
 * failure. Exits with 0 when all tests pass and 1 otherwise.
 */
int
main()
{
	static int (* const kTests[])() = {
		connect_test,
		send_test,
		send_unbound_test,
		shutdown_test,
		send_fd_test,
	};
	const size_t testCount = sizeof(kTests) / sizeof(kTests[0]);

	for (size_t index = 0; index < testCount; index++) {
		if (kTests[index]() != 0)
			return 1;
	}

	return 0;
}