root/src/add-ons/kernel/file_systems/netfs/shared/RequestChannel.cpp
// RequestChannel.cpp

#include "RequestChannel.h"

#include <new>
#include <typeinfo>

#include <stdlib.h>
#include <string.h>

#include <AutoDeleter.h>
#include <ByteOrder.h>

#include "Channel.h"
#include "Compatibility.h"
#include "DebugSupport.h"
#include "Request.h"
#include "RequestDumper.h"
#include "RequestFactory.h"
#include "RequestFlattener.h"
#include "RequestUnflattener.h"

// Upper bound for the flattened payload size of a single request. Both
// SendRequest() and ReceiveRequest() reject sizes that are negative or exceed
// this value with B_BAD_DATA.
static const int32 kMaxSaneRequestSize  = 128 * 1024;   // 128 KB
// Size of the send buffer the RequestChannel constructor tries to allocate.
static const int32 kDefaultBufferSize   = 4096;                 // 4 KB

// ChannelWriter
//
// Writer that gathers small writes in a caller-supplied staging buffer and
// hands them to a Channel in larger pieces. Data that are bigger than the
// whole staging buffer are sent directly, after anything pending has been
// flushed. The writer neither allocates nor frees the staging buffer.
class RequestChannel::ChannelWriter : public Writer {
public:
        // channel:    the channel all data are eventually sent through
        // buffer:     staging buffer used to coalesce writes
        // bufferSize: capacity of the staging buffer in bytes
        ChannelWriter(Channel* channel, void* buffer, int32 bufferSize)
                : fChannel(channel),
                  fBuffer(buffer),
                  fBufferSize(bufferSize),
                  fBytesWritten(0)
        {
        }

        // Queues or sends `size` bytes starting at `buffer`.
        // Returns B_OK on success, otherwise the error reported by the
        // channel's Send().
        virtual status_t Write(const void* buffer, int32 size)
        {
                // Pending bytes have to leave first when the new data would
                // not fit behind them.
                if (fBytesWritten + size > fBufferSize) {
                        status_t flushError = Flush();
                        if (flushError != B_OK)
                                return flushError;
                }

                // Data that fit into the staging buffer are merely appended.
                if (size <= fBufferSize) {
                        memcpy((uint8*)fBuffer + fBytesWritten, buffer, size);
                        fBytesWritten += size;
                        return B_OK;
                }

                // Data larger than the whole staging buffer bypass it.
                return fChannel->Send(buffer, size);
        }

        // Sends whatever is pending in the staging buffer. The buffer is
        // only marked empty when the send succeeded.
        status_t Flush()
        {
                if (fBytesWritten != 0) {
                        status_t sendError
                                = fChannel->Send(fBuffer, fBytesWritten);
                        if (sendError != B_OK)
                                return sendError;
                        fBytesWritten = 0;
                }
                return B_OK;
        }

private:
        Channel*        fChannel;
        void*           fBuffer;
        int32           fBufferSize;
        int32           fBytesWritten;
};


// MemoryReader
//
// Reader that serves data out of a fixed memory buffer supplied by the
// caller. The reader keeps a cursor (fBytesRead) into the buffer and never
// allocates or frees memory itself.
class RequestChannel::MemoryReader : public Reader {
public:
        // buffer:     start of the data to read from
        // bufferSize: number of valid bytes at `buffer`
        MemoryReader(void* buffer, int32 bufferSize)
                : Reader(),
                  fBuffer(buffer),
                  fBufferSize(bufferSize),
                  fBytesRead(0)
        {
        }

        // Copies the next `size` bytes into `buffer` and advances the cursor.
        // Returns B_BAD_VALUE if `buffer` is NULL, `size` is negative, or
        // fewer than `size` bytes remain; B_OK otherwise (including for
        // size == 0, in which case nothing is copied).
        virtual status_t Read(void* buffer, int32 size)
        {
                // check parameters
                if (!buffer || size < 0)
                        return B_BAD_VALUE;
                if (size == 0)
                        return B_OK;
                // get pointer into data buffer
                void* localBuffer;
                bool mustFree;
                status_t error = Read(size, &localBuffer, &mustFree);
                if (error != B_OK)
                        return error;
                // copy data into supplied buffer
                memcpy(buffer, localBuffer, size);
                return B_OK;
        }

        // Returns in `*buffer` a pointer to the next `size` bytes inside the
        // underlying buffer and advances the cursor. `*mustFree` is always
        // set to false, since the returned pointer refers to memory this
        // reader does not own.
        // Returns B_BAD_VALUE on invalid parameters or when fewer than `size`
        // bytes remain.
        virtual status_t Read(int32 size, void** buffer, bool* mustFree)
        {
                // check parameters
                if (size < 0 || !buffer || !mustFree)
                        return B_BAD_VALUE;
                // Compare against the remaining byte count instead of
                // computing fBytesRead + size: the requested size may stem
                // from a length field in received data, and the sum could
                // overflow int32 and slip past the bounds check.
                if (size > fBufferSize - fBytesRead)
                        return B_BAD_VALUE;
                // get the data pointer
                *buffer = (uint8*)fBuffer + fBytesRead;
                *mustFree = false;
                fBytesRead += size;
                return B_OK;
        }

        // Returns whether the cursor has reached the end of the buffer.
        bool AllBytesRead() const
        {
                return (fBytesRead == fBufferSize);
        }

private:
        void*   fBuffer;
        int32   fBufferSize;
        int32   fBytesRead;
};


// RequestHeader
//
// Fixed-size header that precedes every request on the channel.
// SendRequest() fills both fields via B_HOST_TO_BENDIAN_INT32() before
// writing the header, and ReceiveRequest() reads exactly
// sizeof(RequestHeader) bytes into this struct.
struct RequestChannel::RequestHeader {
        // request type as returned by Request::GetType(); passed to
        // RequestFactory::CreateRequest() on the receiving side
        uint32  type;
        // number of bytes of flattened request data following the header
        int32   size;
};


// constructor
//
// Remembers the channel to operate on and tries to allocate a send buffer of
// kDefaultBufferSize bytes. Should the allocation fail, the object is left
// with a NULL buffer and a buffer size of 0.
RequestChannel::RequestChannel(Channel* channel)
        : fChannel(channel),
          fBuffer(NULL),
          fBufferSize(0)
{
        void* sendBuffer = malloc(kDefaultBufferSize);
        if (sendBuffer == NULL)
                return;

        // only advertise a capacity once the memory is really there
        fBuffer = sendBuffer;
        fBufferSize = kDefaultBufferSize;
}

// destructor
//
// Releases the send buffer allocated with malloc() in the constructor.
// fBuffer may be NULL if that allocation failed; free(NULL) is a no-op.
RequestChannel::~RequestChannel()
{
        free(fBuffer);
}

// SendRequest
//
// Flattens `request` and sends it over the channel, preceded by a
// RequestHeader carrying the request type and the flattened size (both
// converted with B_HOST_TO_BENDIAN_INT32()).
//
// Returns B_BAD_VALUE for a NULL request, B_BAD_DATA if the computed size is
// negative or exceeds kMaxSaneRequestSize, or the error reported while
// computing the size, flattening, or sending. B_OK on success.
status_t
RequestChannel::SendRequest(Request* request)
{
        if (request == NULL)
                RETURN_ERROR(B_BAD_VALUE);
        PRINT("%p->RequestChannel::SendRequest(): request: %p, type: %s\n", this, request, typeid(*request).name());

        // find out how many bytes the flattened request will occupy
        int32 payloadSize;
        status_t status = _GetRequestSize(request, &payloadSize);
        if (status != B_OK)
                RETURN_ERROR(status);

        const bool saneSize
                = payloadSize >= 0 && payloadSize <= kMaxSaneRequestSize;
        if (!saneSize) {
                ERROR("RequestChannel::SendRequest(): ERROR: Invalid request size: "
                        "%" B_PRId32 "\n", payloadSize);
                RETURN_ERROR(B_BAD_DATA);
        }

        // the header goes first
        RequestHeader header;
        header.size = B_HOST_TO_BENDIAN_INT32(payloadSize);
        header.type = B_HOST_TO_BENDIAN_INT32(request->GetType());

        ChannelWriter writer(fChannel, fBuffer, fBufferSize);
        status = writer.Write(&header, sizeof(header));
        if (status != B_OK)
                RETURN_ERROR(status);

        // followed by the flattened request itself
        RequestFlattener flattener(&writer);
        request->Flatten(&flattener);
        status = flattener.GetStatus();
        if (status != B_OK)
                RETURN_ERROR(status);

        // push out whatever is still sitting in the writer's buffer
        status = writer.Flush();
        RETURN_ERROR(status);
}

// ReceiveRequest
//
// Receives one request from the channel: reads the RequestHeader, creates a
// request object of the announced type via RequestFactory, receives the
// announced number of payload bytes into a RequestBuffer attached to the
// request, and unflattens the request from that buffer.
//
// _request: set to the newly created request on success; the caller takes
//           over the object. Left untouched on failure.
//
// Returns B_BAD_VALUE if `_request` is NULL, B_BAD_DATA if the announced size
// is negative or exceeds kMaxSaneRequestSize or if unflattening did not
// consume the payload exactly, B_NO_MEMORY if the payload buffer could not be
// allocated, or the error reported by the channel, the factory, or the
// unflattener. B_OK on success.
status_t
RequestChannel::ReceiveRequest(Request** _request)
{
        if (!_request)
                RETURN_ERROR(B_BAD_VALUE);

        // get the request header
        RequestHeader header;
        status_t error = fChannel->Receive(&header, sizeof(RequestHeader));
        if (error != B_OK)
                RETURN_ERROR(error);
        // The header travels in big-endian byte order (see SendRequest()), so
        // convert it back to host order before using it.
        header.type = B_BENDIAN_TO_HOST_INT32(header.type);
        header.size = B_BENDIAN_TO_HOST_INT32(header.size);
        if (header.size < 0 || header.size > kMaxSaneRequestSize) {
                ERROR("RequestChannel::ReceiveRequest(): ERROR: Invalid request size: "
                        "%" B_PRId32 "\n", header.size);
                RETURN_ERROR(B_BAD_DATA);
        }

        // create the request
        Request* request;
        error = RequestFactory::CreateRequest(header.type, &request);
        if (error != B_OK)
                RETURN_ERROR(error);
        // deletes the request on every early return below
        ObjectDeleter<Request> requestDeleter(request);

        // allocate a buffer for the data and read them
        if (header.size > 0) {
                RequestBuffer* requestBuffer = RequestBuffer::Create(header.size);
                if (!requestBuffer)
                        RETURN_ERROR(B_NO_MEMORY);
                // Attach right away, before anything else can fail.
                // NOTE(review): presumably the request takes over ownership
                // of the buffer here -- confirm against Request.
                request->AttachBuffer(requestBuffer);

                // receive the data
                error = fChannel->Receive(requestBuffer->GetData(), header.size);
                if (error != B_OK)
                        RETURN_ERROR(error);

                // unflatten the request
                MemoryReader reader(requestBuffer->GetData(), header.size);
                RequestUnflattener unflattener(&reader);
                request->Unflatten(&unflattener);
                error = unflattener.GetStatus();
                if (error != B_OK)
                        RETURN_ERROR(error);
                // leftover bytes mean the payload does not match the request
                // type's layout
                if (!reader.AllBytesRead())
                        RETURN_ERROR(B_BAD_DATA);
        }

        // success: hand the request over to the caller
        requestDeleter.Detach();
        *_request = request;
        PRINT("%p->RequestChannel::ReceiveRequest(): request: %p, type: %s\n", this, request, typeid(*request).name());
        return B_OK;
}

// _GetRequestSize
//
// Determines the flattened size of `request` by running a RequestFlattener
// over it that writes to a DummyWriter and reading back the flattener's byte
// count.
//
// size: set to the number of bytes the flattener reported; only written when
//       the flattener's status is B_OK.
//
// Returns the flattener's status.
status_t
RequestChannel::_GetRequestSize(Request* request, int32* size)
{
        DummyWriter sink;
        RequestFlattener counter(&sink);
        request->ShowAround(&counter);

        status_t status = counter.GetStatus();
        if (status == B_OK)
                *size = counter.GetBytesWritten();
        return status;
}