root/src/tools/remote_disk_server/remote_disk_server.cpp
/*
 * Copyright 2005-2007, Ingo Weinhold <bonefish@cs.tu-berlin.de>.
 * All rights reserved. Distributed under the terms of the MIT License.
 */

#include <endian.h>
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/stat.h>

#include <boot/net/RemoteDiskDefs.h>


// When building for BeOS proper (__BEOS__ defined, __HAIKU__ not defined) a
// local socklen_t is declared here -- presumably because that platform's
// socket headers do not provide the type (TODO confirm).
#if defined(__BEOS__) && !defined(__HAIKU__)
typedef int socklen_t;
#endif


#if __BYTE_ORDER == __LITTLE_ENDIAN

// Returns \a data with the order of its eight bytes reversed (the least
// significant byte becomes the most significant one and vice versa).
static inline
uint64_t swap_uint64(uint64_t data)
{
        uint64_t swapped = 0;

        // Peel bytes off the low end of the input and push them onto the low
        // end of the result; after eight rounds the order is reversed.
        for (int i = 0; i < 8; i++) {
                swapped = (swapped << 8) | (data & 0xff);
                data >>= 8;
        }

        return swapped;
}

// Little-endian host: converting a 64-bit value between host and network
// byte order means reversing its bytes, in either direction.
#define host_to_net64(data)     swap_uint64(data)
#define net_to_host64(data)     swap_uint64(data)

#endif

// Big-endian host: both conversions leave the value unchanged.
#if __BYTE_ORDER == __BIG_ENDIAN
#define host_to_net64(data)     (data)
#define net_to_host64(data)     (data)
#endif

// Discard any htonll()/ntohll() macros that are already defined and map the
// names onto the conversions selected above.
#undef htonll
#undef ntohll
#define htonll(data)    host_to_net64(data)
#define ntohll(data)    net_to_host64(data)


class Server {
public:
        Server(const char *fileName)
                : fImagePath(fileName),
                  fImageFD(-1),
                  fImageSize(0),
                  fSocket(-1)
        {
        }

        int Run()
        {
                _CreateSocket();

                // main server loop
                for (;;) {
                        // receive
                        fClientAddress.sin_family = AF_INET;
                        fClientAddress.sin_port = 0;
                        fClientAddress.sin_addr.s_addr = htonl(INADDR_ANY);
                        socklen_t addrSize = sizeof(fClientAddress);
                        char buffer[2048];
                        ssize_t bytesRead = recvfrom(fSocket, buffer, sizeof(buffer), 0,
                                                                (sockaddr*)&fClientAddress, &addrSize);
                        // handle error
                        if (bytesRead < 0) {
                                if (errno == EINTR)
                                        continue;
                                fprintf(stderr, "Error: Failed to read from socket: %s.\n",
                                        strerror(errno));
                                exit(1);
                        }

                        // short package?
                        if (bytesRead < (ssize_t)sizeof(remote_disk_header)) {
                                fprintf(stderr, "Dropping short request package (%d bytes).\n",
                                        (int)bytesRead);
                                continue;
                        }

                        fRequest = (remote_disk_header*)buffer;
                        fRequestSize = bytesRead;

                        switch (fRequest->command) {
                                case REMOTE_DISK_HELLO_REQUEST:
                                        _HandleHelloRequest();
                                        break;

                                case REMOTE_DISK_READ_REQUEST:
                                        _HandleReadRequest();
                                        break;

                                case REMOTE_DISK_WRITE_REQUEST:
                                        _HandleWriteRequest();
                                        break;

                                default:
                                        fprintf(stderr, "Ignoring invalid request %d.\n",
                                                (int)fRequest->command);
                                        break;
                        }
                }

                return 0;
        }

private:
        void _OpenImage(bool reopen)
        {
                // already open?
                if (fImageFD >= 0) {
                        if (!reopen)
                                return;

                        close(fImageFD);
                        fImageFD = -1;
                        fImageSize = 0;
                }

                // open the image
                fImageFD = open(fImagePath, O_RDWR);
                if (fImageFD < 0) {
                        fprintf(stderr, "Error: Failed to open \"%s\": %s.\n", fImagePath,
                                strerror(errno));
                        exit(1);
                }

                // get its size
                struct stat st;
                if (fstat(fImageFD, &st) < 0) {
                        fprintf(stderr, "Error: Failed to stat \"%s\": %s.\n", fImagePath,
                                strerror(errno));
                        exit(1);
                }
                fImageSize = st.st_size;
        }

        void _CreateSocket()
        {
                // create a socket
                fSocket = socket(AF_INET, SOCK_DGRAM, 0);
                if (fSocket < 0) {
                        fprintf(stderr, "Error: Failed to create a socket: %s.",
                                strerror(errno));
                        exit(1);
                }

                // bind it to the port
                sockaddr_in addr;
                addr.sin_family = AF_INET;
                addr.sin_port = htons(REMOTE_DISK_SERVER_PORT);
                addr.sin_addr.s_addr = INADDR_ANY;
                if (bind(fSocket, (sockaddr*)&addr, sizeof(addr)) < 0) {
                        fprintf(stderr, "Error: Failed to bind socket to port %hu: %s\n",
                                REMOTE_DISK_SERVER_PORT, strerror(errno));
                        exit(1);
                }
        }

        void _HandleHelloRequest()
        {
                printf("HELLO request\n");

                _OpenImage(true);

                remote_disk_header reply;
                reply.offset = htonll(fImageSize);

                reply.command = REMOTE_DISK_HELLO_REPLY;
                _SendReply(&reply, sizeof(remote_disk_header));
        }

        void _HandleReadRequest()
        {
                _OpenImage(false);

                char buffer[2048];
                remote_disk_header *reply = (remote_disk_header*)buffer;
                uint64_t offset = ntohll(fRequest->offset);
                int16_t size = ntohs(fRequest->size);
                int16_t result = 0;

                printf("READ request: offset: %" PRIu64 ", %hd bytes\n", offset, size);

                if (offset < (uint64_t)fImageSize && size > 0) {
                        // always read 1024 bytes
                        size = REMOTE_DISK_BLOCK_SIZE;
                        if (offset + size > (uint64_t)fImageSize)
                                size = fImageSize - offset;

                        // seek to the offset
                        off_t oldOffset = lseek(fImageFD, offset, SEEK_SET);
                        if (oldOffset >= 0) {
                                // read
                                ssize_t bytesRead = read(fImageFD, reply->data, size);
                                if (bytesRead >= 0) {
                                        result = bytesRead;
                                } else {
                                        fprintf(stderr, "Error: Failed to read at position %" PRIu64 ": "
                                                "%s.", offset, strerror(errno));
                                        result = REMOTE_DISK_IO_ERROR;
                                }
                        } else {
                                fprintf(stderr, "Error: Failed to seek to position %" PRIu64 ": %s.",
                                        offset, strerror(errno));
                                result = REMOTE_DISK_IO_ERROR;
                        }
                }

                // send reply
                reply->command = REMOTE_DISK_READ_REPLY;
                reply->offset = htonll(offset);
                reply->size = htons(result);
                _SendReply(reply, sizeof(*reply) + (result >= 0 ? result : 0));
        }

        void _HandleWriteRequest()
        {
                _OpenImage(false);

                remote_disk_header reply;
                uint64_t offset = ntohll(fRequest->offset);
                int16_t size = ntohs(fRequest->size);
                int16_t result = 0;

                printf("WRITE request: offset: %" PRIu64 ", %hd bytes\n", offset, size);

                if (size < 0
                        || (uint32_t)size > fRequestSize - sizeof(remote_disk_header)
                        || offset > (uint64_t)fImageSize) {
                        result = REMOTE_DISK_BAD_REQUEST;
                } else if (offset < (uint64_t)fImageSize && size > 0) {
                        if (offset + size > (uint64_t)fImageSize)
                                size = fImageSize - offset;

                        // seek to the offset
                        off_t oldOffset = lseek(fImageFD, offset, SEEK_SET);
                        if (oldOffset >= 0) {
                                // write
                                ssize_t bytesWritten = write(fImageFD, fRequest->data, size);
                                if (bytesWritten >= 0) {
                                        result = bytesWritten;
                                } else {
                                        fprintf(stderr, "Error: Failed to write at position %" PRIu64 ": "
                                                "%s.", offset, strerror(errno));
                                        result = REMOTE_DISK_IO_ERROR;
                                }
                        } else {
                                fprintf(stderr, "Error: Failed to seek to position %" PRIu64 ": %s.",
                                        offset, strerror(errno));
                                result = REMOTE_DISK_IO_ERROR;
                        }
                }

                // send reply
                reply.command = REMOTE_DISK_WRITE_REPLY;
                reply.offset = htonll(offset);
                reply.size = htons(result);
                _SendReply(&reply, sizeof(reply));
        }

        void _SendReply(remote_disk_header *reply, int size)
        {
                reply->request_id = fRequest->request_id;
                reply->port = htons(REMOTE_DISK_SERVER_PORT);

                for (;;) {
                        ssize_t bytesSent = sendto(fSocket, reply, size, 0,
                                (const sockaddr*)&fClientAddress, sizeof(fClientAddress));

                        if (bytesSent < 0) {
                                if (errno == EINTR)
                                        continue;
                                fprintf(stderr, "Error: Failed to send reply to client: %s.\n",
                                        strerror(errno));
                        }
                        break;
                }
        }

private:
        const char                      *fImagePath;
        int                                     fImageFD;
        off_t                           fImageSize;
        int                                     fSocket;
        remote_disk_header      *fRequest;
        ssize_t                         fRequestSize;
        sockaddr_in                     fClientAddress;
};


// main
// Expects exactly one command line argument, the path of the image to serve.
// Prints a usage message and fails with exit status 1 otherwise.
int
main(int argc, const char *const *argv)
{
        if (argc == 2) {
                // hand the image path to the server and let it run
                Server server(argv[1]);
                return server.Run();
        }

        fprintf(stderr, "Usage: %s <image path>\n", argv[0]);
        return 1;
}