root/src/tests/system/kernel/syscall_restart_test.cpp
#include <errno.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

#include <OS.h>


// The ways a test reacts to the SIGINT sent to its blocked thread.
enum run_mode {
        // SIGINT is set to SIG_IGN; no handler runs.
        RUN_IGNORE_SIGNAL = 0,
        // A handler is installed without SA_RESTART.
        RUN_HANDLE_SIGNAL = 1,
        // A handler is installed with SA_RESTART.
        RUN_HANDLE_SIGNAL_RESTART = 2
};


class Test {
public:
        Test(const char* name)
                : fName(name)
        {
        }

        virtual ~Test()
        {
        }

        bool Run(run_mode mode)
        {
                fRunMode = mode;
                fSignalHandlerCalled = false;

                status_t error = Prepare();
                if (error != B_OK)
                        return Error("Failed to prepare test: %s", strerror(error));

                thread_id thread = spawn_thread(_ThreadEntry, fName, B_NORMAL_PRIORITY,
                        this);
                if (thread < 0)
                        return Error("Failed to spawn thread: %s\n", strerror(thread));

                resume_thread(thread);

                // ...
                // * interrupt without restart
                // * interrupt with restart

                snooze(100000);
                kill(thread, SIGINT);

                PrepareFinish();

                status_t result;
                wait_for_thread(thread, &result);

                if (result != (Interrupted() ? B_INTERRUPTED : B_OK)) {
                        return Error("Unexpected syscall return value: %s\n",
                                strerror(result));
                }

                if ((RunMode() == RUN_IGNORE_SIGNAL) == fSignalHandlerCalled) {
                        if (RunMode() == RUN_IGNORE_SIGNAL)
                                return Error("Handler was called but shouldn't have been!");
                        else
                                return Error("Handler was not called!");
                }

                return Finish(Interrupted());
        }

        void Run()
        {
                printf("%s\n", fName);

                struct {
                        const char*     name;
                        run_mode        mode;
                } tests[] = {
                        { "ignore signal", RUN_IGNORE_SIGNAL },
                        { "handle signal no restart", RUN_HANDLE_SIGNAL },
                        { "handle signal restart", RUN_HANDLE_SIGNAL_RESTART },
                        {}
                };

                for (int i = 0; tests[i].name != NULL; i++) {
                        printf("  %-30s: ", tests[i].name);
                        fflush(stdout);
                        ClearError();
                        if (Run(tests[i].mode))
                                printf("ok\n");
                        else
                                printf("failed (%s)\n", fError);

                        Cleanup();
                }
        }

        run_mode RunMode() const { return fRunMode; }
        bool Interrupted() const { return RunMode() == RUN_HANDLE_SIGNAL; }
        bigtime_t TimeWaited() const { return fTimeWaited; }

protected:
        virtual status_t Prepare()
        {
                return B_OK;
        }

        virtual status_t DoSyscall() = 0;

        virtual void HandleSignal()
        {
        }

        virtual void PrepareFinish()
        {
        }

        virtual bool Finish(bool interrupted)
        {
                return true;
        }

        virtual void Cleanup()
        {
        }

        bool Error(const char* format,...)
        {
                va_list args;
                va_start(args, format);
                vsnprintf(fError, sizeof(fError), format, args);
                va_end(args);

                return false;
        }

        bool Check(bool condition, const char* format,...)
        {
                if (condition)
                        return true;

                va_list args;
                va_start(args, format);
                vsnprintf(fError, sizeof(fError), format, args);
                va_end(args);

                return false;
        }

        void ClearError()
        {
                fError[0] = '\0';
        }

private:
        static status_t _ThreadEntry(void* data)
        {
                return ((Test*)data)->_TestThread();
        }

        static void _SignalHandler(int signal, char* userData)
        {
                Test* self = (Test*)userData;

                self->fSignalHandlerCalled = true;
                self->HandleSignal();
        }

        status_t _TestThread()
        {
                // install handler
                struct sigaction action;
                if (RunMode() == RUN_IGNORE_SIGNAL)
                        action.sa_handler = SIG_IGN;
                else
                        action.sa_handler = (void (*)(int))_SignalHandler;

                action.sa_flags = RunMode() == RUN_HANDLE_SIGNAL_RESTART
                        ? SA_RESTART : 0;

                sigemptyset(&action.sa_mask);
                action.sa_userdata = this;

                sigaction(SIGINT, &action, NULL);

                bigtime_t startTime = system_time();
                status_t status = DoSyscall();
                fTimeWaited = system_time() - startTime;

                return status;
        }

private:
        const char*     fName;
        run_mode        fRunMode;
        bool            fSignalHandlerCalled;
        bigtime_t       fTimeWaited;
        char            fError[1024];
};


class SnoozeTest : public Test {
public:
        SnoozeTest()
                : Test("snooze")
        {
        }

        virtual status_t DoSyscall()
        {
                return snooze(1000000);
        }

        virtual bool Finish(bool interrupted)
        {
                if (interrupted)
                        return Check(TimeWaited() < 200000, "waited too long");

                return Check(TimeWaited() > 900000 && TimeWaited() < 1100000,
                        "waited %lld us instead of 1000000 us", TimeWaited());
        }
};


// Tests interrupting and restarting a read() that blocks on an empty pipe.
class ReadTest : public Test {
public:
        ReadTest()
                : Test("read"),
                fBytesRead(-1)
        {
                fFDs[0] = -1;
                fFDs[1] = -1;
        }

        // Creates the pipe the test thread will read from.
        virtual status_t Prepare()
        {
                fBytesRead = -1;
                fFDs[0] = -1;
                fFDs[1] = -1;

                if (pipe(fFDs) != 0)
                        return errno;

                return B_OK;
        }

        // Reads from the still empty pipe.
        virtual status_t DoSyscall()
        {
                char buffer[256];
                fBytesRead = read(fFDs[0], buffer, sizeof(buffer));

                return fBytesRead < 0 ? errno : B_OK;
        }

        // Writes the four bytes a read() that is still blocked is waiting for.
        // A failure here shows up as a wrong byte count in Finish().
        virtual void PrepareFinish()
        {
                write(fFDs[1], "Ingo", 4);
        }

        // When interrupted, nothing may have been read. Otherwise exactly the
        // four bytes written by PrepareFinish() are expected.
        virtual bool Finish(bool interrupted)
        {
                if (interrupted)
                        return Check(fBytesRead < 0, "unexpected read");

                return Check(fBytesRead == 4, "should have read 4 bytes, read only %ld "
                        "bytes", (long)fBytesRead);
        }

        // Closes both pipe ends, if open, and marks them closed, so a repeated
        // call cannot close an unrelated descriptor.
        virtual void Cleanup()
        {
                for (int i = 0; i < 2; i++) {
                        if (fFDs[i] >= 0) {
                                close(fFDs[i]);
                                fFDs[i] = -1;
                        }
                }
        }

private:
        ssize_t         fBytesRead;
        int                     fFDs[2];
};


// Tests interrupting and restarting a write() that blocks on a full pipe.
class WriteTest : public Test {
public:
        WriteTest()
                : Test("write"),
                fBytesWritten(-1)
        {
                fFDs[0] = -1;
                fFDs[1] = -1;
        }

        // Creates a pipe and fills it up, so that the next write will block.
        virtual status_t Prepare()
        {
                fBytesWritten = -1;
                fFDs[0] = -1;
                fFDs[1] = -1;

                if (pipe(fFDs) != 0)
                        return errno;

                // fill pipe -- the write end must really be non-blocking, or the
                // loop below would never terminate
                if (fcntl(fFDs[1], F_SETFL, O_NONBLOCK) == -1)
                        return errno;

                while (write(fFDs[1], "a", 1) == 1) {
                }

                return B_OK;
        }

        // Switches the write end back to blocking mode and writes four bytes to
        // the full pipe.
        virtual status_t DoSyscall()
        {
                // blocking wait
                if (fcntl(fFDs[1], F_SETFL, 0) == -1)
                        return errno;

                fBytesWritten = write(fFDs[1], "Ingo", 4);

                return fBytesWritten < 0 ? errno : B_OK;
        }

        // Drains up to 256 bytes from the pipe, making room for a write() that
        // is still blocked.
        virtual void PrepareFinish()
        {
                char buffer[256];
                read(fFDs[0], buffer, sizeof(buffer));
        }

        // When interrupted, nothing may have been written. Otherwise all four
        // bytes are expected to have been written.
        virtual bool Finish(bool interrupted)
        {
                if (interrupted)
                        return Check(fBytesWritten < 0, "unexpected write");

                return Check(fBytesWritten == 4, "should have written 4 bytes, wrote only %ld "
                        "bytes", (long)fBytesWritten);
        }

        // Closes both pipe ends, if open, and marks them closed.
        virtual void Cleanup()
        {
                for (int i = 0; i < 2; i++) {
                        if (fFDs[i] >= 0) {
                                close(fFDs[i]);
                                fFDs[i] = -1;
                        }
                }
        }

private:
        ssize_t         fBytesWritten;
        int                     fFDs[2];
};


// Tests interrupting and restarting acquire_sem(), or switch_sem() when
// constructed with useSwitch == true.
class AcquireSwitchSemTest : public Test {
public:
        AcquireSwitchSemTest(bool useSwitch)
                : Test(useSwitch ? "switch_sem" : "acquire_sem"),
                fSwitch(useSwitch)
        {
        }

        // Creates the semaphore with a count of 0, so acquiring it blocks.
        virtual status_t Prepare()
        {
                fSemaphore = create_sem(0, "test sem");
                if (fSemaphore < 0)
                        return fSemaphore;

                return B_OK;
        }

        // Blocks on the semaphore using the call selected at construction.
        virtual status_t DoSyscall()
        {
                return fSwitch
                        ? switch_sem(-1, fSemaphore) : acquire_sem(fSemaphore);
        }

        // Releases the semaphore for a thread that is still blocked on it.
        virtual void PrepareFinish()
        {
                release_sem(fSemaphore);
        }

        // TODO: There is no Finish() override yet; it could verify the
        // semaphore count via get_sem_count().

        virtual void Cleanup()
        {
                delete_sem(fSemaphore);
        }

protected:
        sem_id          fSemaphore;
        bool            fSwitch;
};


class AcquireSwitchSemEtcTest : public Test {
public:
        AcquireSwitchSemEtcTest(bool useSwitch)
                : Test(useSwitch ? "switch_sem_etc" : "acquire_sem_etc"),
                fSwitch(useSwitch)
        {
        }

        virtual status_t Prepare()
        {
                fSemaphore = create_sem(0, "test sem");

                return fSemaphore >= 0 ? B_OK : fSemaphore;
        }

        virtual status_t DoSyscall()
        {
                status_t status;
                if (fSwitch) {
                        status = switch_sem_etc(-1, fSemaphore, 1, B_RELATIVE_TIMEOUT,
                                1000000);
                } else {
                        status = acquire_sem_etc(fSemaphore, 1, B_RELATIVE_TIMEOUT,
                                1000000);
                }

                if (!Interrupted() && status == B_TIMED_OUT)
                        return B_OK;

                return status;
        }

        virtual bool Finish(bool interrupted)
        {
                if (interrupted)
                        return Check(TimeWaited() < 200000, "waited too long");

                return Check(TimeWaited() > 900000 && TimeWaited() < 1100000,
                        "waited %lld us instead of 1000000 us", TimeWaited());
        }

        virtual void Cleanup()
        {
                delete_sem(fSemaphore);
        }

protected:
        sem_id          fSemaphore;
        bool            fSwitch;
};


// Tests interrupting and restarting an accept() that blocks on a listening
// TCP socket bound to the loopback address.
class AcceptTest : public Test {
public:
        AcceptTest()
                : Test("accept"),
                fServerSocket(-1),
                fAcceptedSocket(-1)
        {
        }

        // Creates the server socket, binds it to the loopback address with a
        // system chosen port, and puts it into listening mode.
        virtual status_t Prepare()
        {
                fAcceptedSocket = -1;

                fServerSocket = socket(AF_INET, SOCK_STREAM, 0);
                if (fServerSocket < 0)
                        return _ReportError("Could not open server socket");

                int reuse = 1;
                if (setsockopt(fServerSocket, SOL_SOCKET, SO_REUSEADDR, &reuse,
                                sizeof(reuse)) == -1) {
                        return _ReportError("Could not make server socket reusable");
                }

                memset(&fServerAddress, 0, sizeof(fServerAddress));
                fServerAddress.sin_family = AF_INET;
                // s_addr is in network byte order, INADDR_LOOPBACK in host byte
                // order.
                fServerAddress.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

                if (bind(fServerSocket, (struct sockaddr *)&fServerAddress,
                                sizeof(fServerAddress)) == -1) {
                        return _ReportError("Could not bind server socket");
                }

                // Retrieve the address actually bound to; PrepareFinish() connects
                // to it.
                socklen_t length = sizeof(fServerAddress);
                if (getsockname(fServerSocket, (sockaddr*)&fServerAddress,
                                &length) == -1) {
                        return _ReportError("Could not get server socket address");
                }

                if (listen(fServerSocket, 10) == -1)
                        return _ReportError("Could not listen on server socket");

                return B_OK;
        }

        // Waits for an incoming connection on the server socket.
        virtual status_t DoSyscall()
        {
                sockaddr_in clientAddress;
                socklen_t length = sizeof(clientAddress);

                fAcceptedSocket = accept(fServerSocket,
                        (struct sockaddr *)&clientAddress, &length);
                if (fAcceptedSocket == -1)
                        return errno;

                return B_OK;
        }

        // Connects to the server socket, so that an accept() that is still
        // blocked returns. Does nothing if the call is expected to have been
        // interrupted.
        virtual void PrepareFinish()
        {
                if (Interrupted())
                        return;

                int clientSocket = socket(AF_INET, SOCK_STREAM, 0);
                if (clientSocket == -1) {
                        fprintf(stderr, "Could not open client socket: %s\n",
                                strerror(errno));
                        return;
                }

                if (connect(clientSocket, (struct sockaddr *)&fServerAddress,
                                sizeof(fServerAddress)) == -1) {
                        fprintf(stderr, "Could not connect to server socket: %s\n",
                                strerror(errno));
                }

                close(clientSocket);
        }

        // When interrupted, no connection may have been accepted. Otherwise
        // one must have been.
        virtual bool Finish(bool interrupted)
        {
                if (interrupted)
                        return Check(fAcceptedSocket < 0, "got socket");

                return Check(fAcceptedSocket >= 0, "got no socket");
        }

        // Closes the sockets that are open and marks them closed.
        virtual void Cleanup()
        {
                if (fAcceptedSocket >= 0) {
                        close(fAcceptedSocket);
                        fAcceptedSocket = -1;
                }

                if (fServerSocket >= 0) {
                        close(fServerSocket);
                        fServerSocket = -1;
                }
        }

private:
        // Prints "<what>: <error description>" to stderr and returns the errno
        // value the failed call left behind. errno is saved first, since
        // fprintf() may change it.
        static status_t _ReportError(const char* what)
        {
                int error = errno;
                fprintf(stderr, "%s: %s\n", what, strerror(error));
                return error;
        }

protected:
        int                     fServerSocket;
        sockaddr_in     fServerAddress;
        int                     fAcceptedSocket;
};


// Tests interrupting and restarting a recv() that blocks on a connected TCP
// socket on which no data is pending.
class ReceiveTest : public Test {
public:
        ReceiveTest()
                : Test("recv"),
                fServerSocket(-1),
                fAcceptedSocket(-1),
                fClientSocket(-1),
                fBytesRead(-1)
        {
        }

        // Sets up a listening socket on the loopback address, connects a client
        // socket to it, and accepts that connection.
        virtual status_t Prepare()
        {
                fBytesRead = -1;
                fAcceptedSocket = -1;
                fClientSocket = -1;

                fServerSocket = socket(AF_INET, SOCK_STREAM, 0);
                if (fServerSocket < 0)
                        return _ReportError("Could not open server socket");

                int reuse = 1;
                if (setsockopt(fServerSocket, SOL_SOCKET, SO_REUSEADDR, &reuse,
                                sizeof(reuse)) == -1) {
                        return _ReportError("Could not make server socket reusable");
                }

                memset(&fServerAddress, 0, sizeof(fServerAddress));
                fServerAddress.sin_family = AF_INET;
                // s_addr is in network byte order, INADDR_LOOPBACK in host byte
                // order.
                fServerAddress.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

                if (bind(fServerSocket, (struct sockaddr *)&fServerAddress,
                                sizeof(fServerAddress)) == -1) {
                        return _ReportError("Could not bind server socket");
                }

                // Retrieve the address actually bound to; the client connects to
                // it.
                socklen_t length = sizeof(fServerAddress);
                if (getsockname(fServerSocket, (sockaddr*)&fServerAddress,
                                &length) == -1) {
                        return _ReportError("Could not get server socket address");
                }

                if (listen(fServerSocket, 10) == -1)
                        return _ReportError("Could not listen on server socket");

                fClientSocket = socket(AF_INET, SOCK_STREAM, 0);
                if (fClientSocket == -1)
                        return _ReportError("Could not open client socket");

                // Connect without blocking, since the connection is only accepted
                // further down in this very thread.
                if (fcntl(fClientSocket, F_SETFL, O_NONBLOCK) == -1) {
                        return _ReportError(
                                "Could not make client socket non-blocking");
                }

                if (connect(fClientSocket, (struct sockaddr *)&fServerAddress,
                                sizeof(fServerAddress)) == -1
                        && errno != EINPROGRESS) {
                        return _ReportError("Could not connect to server socket");
                }

                sockaddr_in clientAddress;
                length = sizeof(clientAddress);

                fAcceptedSocket = accept(fServerSocket,
                        (struct sockaddr *)&clientAddress, &length);
                if (fAcceptedSocket == -1)
                        return _ReportError("Could not accept connection");

                if (fcntl(fClientSocket, F_SETFL, 0) == -1)
                        return _ReportError("Could not make client socket blocking");

                // Presumably gives the connection time to get fully established.
                snooze(100000);

                return B_OK;
        }

        // Receives from the accepted socket, on which no data is pending yet.
        virtual status_t DoSyscall()
        {
                char buffer[256];
                fBytesRead = recv(fAcceptedSocket, buffer, sizeof(buffer), 0);

                return fBytesRead < 0 ? errno : B_OK;
        }

        // Sends the four bytes a recv() that is still blocked is waiting for.
        // A failure here shows up as a wrong byte count in Finish().
        virtual void PrepareFinish()
        {
                write(fClientSocket, "Axel", 4);
        }

        // When interrupted, nothing may have been received. Otherwise exactly
        // the four bytes sent by PrepareFinish() are expected.
        virtual bool Finish(bool interrupted)
        {
                if (interrupted)
                        return Check(fBytesRead < 0, "unexpected read");

                return Check(fBytesRead == 4, "should have read 4 bytes, read only %ld "
                        "bytes", (long)fBytesRead);
        }

        // Closes the sockets that are open and marks them closed.
        virtual void Cleanup()
        {
                _Close(fAcceptedSocket);
                _Close(fServerSocket);
                _Close(fClientSocket);
        }

private:
        // Prints "<what>: <error description>" to stderr and returns the errno
        // value the failed call left behind. errno is saved first, since
        // fprintf() may change it.
        static status_t _ReportError(const char* what)
        {
                int error = errno;
                fprintf(stderr, "%s: %s\n", what, strerror(error));
                return error;
        }

        // Closes the socket, if open, and resets the variable to -1.
        static void _Close(int& socket)
        {
                if (socket >= 0) {
                        close(socket);
                        socket = -1;
                }
        }

protected:
        int                     fServerSocket;
        sockaddr_in     fServerAddress;
        int                     fAcceptedSocket;
        int                     fClientSocket;
        ssize_t         fBytesRead;
};


// Runs all syscall restart tests one after the other, in all run modes.
// Results are printed by the tests themselves; the exit code is always 0.
int
main()
{
        // NULL terminated list of the tests to run
        Test* tests[] = {
                new SnoozeTest,
                new ReadTest,
                new WriteTest,
                new AcquireSwitchSemTest(false),
                new AcquireSwitchSemTest(true),
                new AcquireSwitchSemEtcTest(false),
                new AcquireSwitchSemEtcTest(true),
                new AcceptTest,
                new ReceiveTest,
                NULL
        };

        for (int i = 0; tests[i] != NULL; i++) {
                tests[i]->Run();

                // The test is not needed anymore once it has run.
                delete tests[i];
                tests[i] = NULL;
        }

        return 0;
}