root/src/system/boot/loader/net/UDP.cpp
/*
 * Copyright 2005, Ingo Weinhold <bonefish@cs.tu-berlin.de>.
 * All rights reserved. Distributed under the terms of the MIT License.
 */


#include <boot/net/UDP.h>

#include <stdio.h>

#include <KernelExport.h>

#include <boot/net/ChainBuffer.h>
#include <boot/net/NetStack.h>


//#define TRACE_UDP
#ifdef TRACE_UDP
#       define TRACE(x) dprintf x
#else
#       define TRACE(x) ;
#endif


using std::nothrow;


// #pragma mark - UDPPacket


/*!	Creates an empty packet that is not linked into any queue.

	Every member gets a defined value, so the accessors never return
	indeterminate data, even if SetTo() is never called or fails.
*/
UDPPacket::UDPPacket()
	:
	fNext(NULL),
	fData(NULL),
	fSize(0)
{
	// Assigned here rather than in the initializer list so as not to depend
	// on the members' declaration order (avoids -Wreorder).
	fSourceAddress = INADDR_ANY;
	fDestinationAddress = INADDR_ANY;
	fSourcePort = 0;
	fDestinationPort = 0;
}


/*!	Releases the payload copy owned by this packet (free(NULL) is a no-op,
	so this is safe for packets that were never set up).
*/
UDPPacket::~UDPPacket()
{
	void *payload = fData;
	fData = NULL;
	free(payload);
}


/*!	Initializes the packet with a private copy of the given payload and the
	addressing information it was received with.

	\param data Payload to copy; must not be \c NULL.
	\param size Number of payload bytes; 0 (an empty datagram) is valid.
	\param sourceAddress Sender's IP address.
	\param sourcePort Sender's UDP port.
	\param destinationAddress Receiver's IP address.
	\param destinationPort Receiver's UDP port.
	\return \c B_OK on success, \c B_BAD_VALUE if \a data is \c NULL, or
		\c B_NO_MEMORY if the copy could not be allocated. On failure the
		packet is left unchanged.
*/
status_t
UDPPacket::SetTo(const void *data, size_t size, ip_addr_t sourceAddress,
	uint16 sourcePort, ip_addr_t destinationAddress, uint16 destinationPort)
{
	if (data == NULL)
		return B_BAD_VALUE;

	// Clone the data. Allocate at least one byte: malloc(0) may legitimately
	// return NULL, and that must not be mistaken for an out-of-memory
	// condition when the datagram is empty.
	void *copy = malloc(size > 0 ? size : 1);
	if (copy == NULL)
		return B_NO_MEMORY;
	memcpy(copy, data, size);

	// Release the payload of a previous SetTo() call instead of leaking it.
	free(fData);
	fData = copy;

	fSize = size;
	fSourceAddress = sourceAddress;
	fDestinationAddress = destinationAddress;
	fSourcePort = sourcePort;
	fDestinationPort = destinationPort;

	return B_OK;
}


/*!	Returns the packet queued after this one in a socket's receive queue,
	or \c NULL if this is the last one.
*/
UDPPacket *
UDPPacket::Next() const
{
	return this->fNext;
}


/*!	Links \a next behind this packet in a socket's receive queue. */
void
UDPPacket::SetNext(UDPPacket *next)
{
	this->fNext = next;
}


const void *
UDPPacket::Data() const
{
        return fData;
}


size_t
UDPPacket::DataSize() const
{
        return fSize;
}


/*!	Returns the sender's IP address as passed to SetTo(). */
ip_addr_t
UDPPacket::SourceAddress() const
{
	return this->fSourceAddress;
}


/*!	Returns the sender's UDP port as passed to SetTo(). */
uint16
UDPPacket::SourcePort() const
{
	return this->fSourcePort;
}


/*!	Returns the receiver's IP address as passed to SetTo(). */
ip_addr_t
UDPPacket::DestinationAddress() const
{
	return this->fDestinationAddress;
}


/*!	Returns the receiver's UDP port as passed to SetTo(). */
uint16
UDPPacket::DestinationPort() const
{
	return this->fDestinationPort;
}


// #pragma mark - UDPSocket


/*!	Creates an unbound socket with an empty receive queue, attached to the
	default net stack's UDP service. If that service pointer is \c NULL,
	the socket's operations report \c B_NO_INIT.
*/
UDPSocket::UDPSocket()
	:
	fUDPService(NetStack::Default()->GetUDPService())
{
	// empty receive queue
	fFirstPacket = NULL;
	fLastPacket = NULL;

	// not bound yet
	fAddress = INADDR_ANY;
	fPort = 0;
}


/*!	Unregisters the socket from its UDP service (if it is bound and not
	detached) and frees every packet still waiting in the receive queue.
*/
UDPSocket::~UDPSocket()
{
	if (fPort != 0 && fUDPService != NULL)
		fUDPService->UnbindSocket(this);

	// Packets that were queued by the service but never fetched via
	// Receive() are owned by the socket; without this they would leak.
	while (UDPPacket *packet = PopPacket())
		delete packet;
}


/*!	Binds the socket to \a address and \a port via the UDP service.

	\return \c B_OK on success; \c B_NO_INIT without a service;
		\c B_BAD_VALUE for a broadcast address or port 0; \c EALREADY if
		the socket is already bound; otherwise the service's error.
*/
status_t
UDPSocket::Bind(ip_addr_t address, uint16 port)
{
	if (fUDPService == NULL) {
		printf("UDPSocket::Bind(): no UDP service\n");
		return B_NO_INIT;
	}

	const bool invalidTarget = address == INADDR_BROADCAST || port == 0;
	if (invalidTarget) {
		printf("UDPSocket::Bind(): broadcast IP or port 0\n");
		return B_BAD_VALUE;
	}

	const bool alreadyBound = fPort != 0;
	if (alreadyBound) {
		printf("UDPSocket::Bind(): already bound\n");
		// NOTE(review): unclear whether EALREADY is the intended code here.
		return EALREADY;
	}

	status_t status = fUDPService->BindSocket(this, address, port);
	if (status == B_OK) {
		fAddress = address;
		fPort = port;
		return B_OK;
	}

	printf("UDPSocket::Bind(): service BindSocket() failed\n");
	return status;
}


void
UDPSocket::Detach()
{
        fUDPService = NULL;
                // This will lead to subsequent methods returning B_NO_INIT
}



/*!	Sends the contents of \a buffer from this socket's port to
	\a destinationAddress:\a destinationPort.

	\return \c B_NO_INIT if the socket has no service, otherwise the
		service's result.
*/
status_t
UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort,
	ChainBuffer *buffer)
{
	UDPService *service = fUDPService;
	if (service != NULL) {
		return service->Send(fPort, destinationAddress, destinationPort,
			buffer);
	}

	return B_NO_INIT;
}


/*!	Convenience overload: wraps the flat \a data / \a size buffer in a
	temporary ChainBuffer and sends it.

	\return \c B_BAD_VALUE if \a data is \c NULL, otherwise the result of
		the ChainBuffer overload.
*/
status_t
UDPSocket::Send(ip_addr_t destinationAddress, uint16 destinationPort,
	const void *data, size_t size)
{
	if (data != NULL) {
		ChainBuffer chunk(const_cast<void*>(data), size);
		return Send(destinationAddress, destinationPort, &chunk);
	}

	return B_BAD_VALUE;
}


/*!	Polls the network until a packet for this socket is queued or
	\a timeout (in microseconds) has elapsed.

	The incoming path is always polled at least once. On success the
	caller takes ownership of \c *_packet.

	\return \c B_OK with \c *_packet set; \c B_NO_INIT without a service;
		\c B_BAD_VALUE if \a _packet is \c NULL; \c B_WOULD_BLOCK when
		\a timeout is 0 and nothing is queued; otherwise \c B_TIMED_OUT.
*/
status_t
UDPSocket::Receive(UDPPacket **_packet, bigtime_t timeout)
{
	if (fUDPService == NULL)
		return B_NO_INIT;
	if (_packet == NULL)
		return B_BAD_VALUE;

	const bigtime_t start = system_time();
	while (true) {
		// let the service dispatch whatever has arrived in the meantime
		fUDPService->ProcessIncomingPackets();

		UDPPacket *packet = PopPacket();
		*_packet = packet;
		if (packet != NULL)
			return B_OK;

		const bigtime_t elapsed = system_time() - start;
		if (elapsed > timeout) {
			if (timeout == 0)
				return B_WOULD_BLOCK;
			return B_TIMED_OUT;
		}
	}
}


/*!	Appends \a packet to the tail of the receive queue. The socket owns the
	packet until it is handed out by PopPacket().
*/
void
UDPSocket::PushPacket(UDPPacket *packet)
{
	if (fLastPacket == NULL)
		fFirstPacket = packet;
	else
		fLastPacket->SetNext(packet);

	packet->SetNext(NULL);
	fLastPacket = packet;
}


/*!	Removes and returns the head of the receive queue (unlinked from the
	rest), or \c NULL if the queue is empty.
*/
UDPPacket *
UDPSocket::PopPacket()
{
	UDPPacket *head = fFirstPacket;
	if (head != NULL) {
		fFirstPacket = head->Next();
		if (fFirstPacket == NULL)
			fLastPacket = NULL;
		head->SetNext(NULL);
	}

	return head;
}


// #pragma mark - UDPService


/*!	Creates the UDP sub-service on top of \a ipService. Registration with
	the IP layer happens later, in Init().
*/
UDPService::UDPService(IPService *ipService)
	:
	IPSubService(kUDPServiceName),
	fIPService(ipService)
{
	// nothing else to set up until Init()
}


/*!	Detaches all still-bound sockets so they stop using this service, then
	unregisters from the IP layer.
*/
UDPService::~UDPService()
{
	for (int index = 0; index < fSockets.Count(); index++)
		fSockets.ElementAt(index)->Detach();

	if (fIPService == NULL)
		return;

	fIPService->UnregisterIPSubService(this);
}


/*!	Registers the service with the IP layer.

	\return \c B_OK on success, \c B_BAD_VALUE without an IP service, or
		\c B_NO_MEMORY if registration fails.
*/
status_t
UDPService::Init()
{
	if (fIPService == NULL)
		return B_BAD_VALUE;

	return fIPService->RegisterIPSubService(this) ? B_OK : B_NO_MEMORY;
}


/*!	Returns the IP protocol number this sub-service handles (UDP). */
uint8
UDPService::IPProtocol() const
{
	const uint8 protocol = IPPROTO_UDP;
	return protocol;
}


/*!	Validates an incoming UDP datagram and queues its payload on the socket
	bound to the destination address/port.

	Datagrams that are too short, whose length field exceeds the received
	data, or whose (non-zero) checksum does not verify are silently dropped.
	The same happens when no socket matches or allocation fails.
*/
void
UDPService::HandleIPPacket(IPService *ipService, ip_addr_t sourceIP,
	ip_addr_t destinationIP, const void *data, size_t size)
{
	TRACE(("UDPService::HandleIPPacket(): source: %08lx, destination: %08lx, "
		"%lu - %lu bytes\n", sourceIP, destinationIP, size,
		sizeof(udp_header)));

	if (data == NULL || size < sizeof(udp_header))
		return;

	const udp_header *header = (const udp_header*)data;
	const uint16 udpLength = ntohs(header->length);

	// the length field must cover the header and fit the received data;
	// a zero checksum means the sender disabled checksumming
	bool valid = udpLength >= sizeof(udp_header) && udpLength <= size;
	if (valid && header->checksum != 0)
		valid = _ChecksumData(data, udpLength, sourceIP, destinationIP) == 0;

	if (!valid) {
		TRACE(("UDPService::HandleIPPacket(): dropping packet -- invalid size "
			"or checksum\n"));
		return;
	}

	const uint16 destinationPort = ntohs(header->destination);
	UDPSocket *socket = _FindSocket(destinationIP, destinationPort);
	if (socket == NULL)
		return;

	UDPPacket *packet = new(nothrow) UDPPacket;
	if (packet == NULL)
		return;

	const uint8 *payload = (const uint8*)data + sizeof(udp_header);
	if (packet->SetTo(payload, udpLength - sizeof(udp_header), sourceIP,
			ntohs(header->source), destinationIP, destinationPort) != B_OK) {
		delete packet;
		return;
	}

	socket->PushPacket(packet);
}


/*!	Prepends a UDP header (with checksum) to \a buffer and hands the
	datagram to the IP layer.

	\param sourcePort UDP port to send from.
	\param destinationAddress Destination IP address.
	\param destinationPort Destination UDP port.
	\param buffer Payload chain; must not be \c NULL.
	\return \c B_NO_INIT without an IP service; \c B_BAD_VALUE for a
		\c NULL buffer; \c EMSGSIZE if the payload does not fit a UDP
		datagram; otherwise the IP layer's result.
*/
status_t
UDPService::Send(uint16 sourcePort, ip_addr_t destinationAddress,
	uint16 destinationPort, ChainBuffer *buffer)
{
	TRACE(("UDPService::Send(source port: %hu, to: %08lx:%hu, %lu bytes)\n",
		sourcePort, destinationAddress, destinationPort,
		(buffer != NULL ? buffer->TotalSize() : 0)));

	if (fIPService == NULL)
		return B_NO_INIT;

	if (buffer == NULL)
		return B_BAD_VALUE;

	// The UDP length field and the pseudo-header length used for the
	// checksum are only 16 bits wide. A larger datagram would silently get a
	// truncated length and a wrong checksum, so reject it up front.
	if (buffer->TotalSize() > 0xffff - sizeof(udp_header))
		return EMSGSIZE;

	// prepend the UDP header
	udp_header header;
	ChainBuffer headerBuffer(&header, sizeof(header), buffer);
	const uint16 udpLength = headerBuffer.TotalSize();
	header.source = htons(sourcePort);
	header.destination = htons(destinationPort);
	header.length = htons(udpLength);

	// compute the checksum over pseudo-header + header (with a zeroed
	// checksum field) + payload
	header.checksum = 0;
	header.checksum = htons(_ChecksumBuffer(&headerBuffer,
		fIPService->IPAddress(), destinationAddress, udpLength));
	// 0 means checksum disabled; 0xffff is equivalent in this case
	if (header.checksum == 0)
		header.checksum = 0xffff;

	return fIPService->Send(destinationAddress, IPPROTO_UDP, &headerBuffer);
}


/*!	Asks the underlying IP service to process pending incoming packets
	(does nothing without an IP service).
*/
void
UDPService::ProcessIncomingPackets()
{
	if (fIPService == NULL)
		return;

	fIPService->ProcessIncomingPackets();
}


/*!	Registers \a socket for \a address:\a port.

	\return \c B_BAD_VALUE for a \c NULL socket, \c EADDRINUSE if an
		overlapping binding (see _FindSocket()) already exists, otherwise
		the result of adding it to the socket list.
*/
status_t
UDPService::BindSocket(UDPSocket *socket, ip_addr_t address, uint16 port)
{
	if (socket == NULL)
		return B_BAD_VALUE;

	if (_FindSocket(address, port) == NULL)
		return fSockets.Add(socket);

	printf("UDPService::BindSocket(): address in use\n");
	return EADDRINUSE;
}


/*!	Removes \a socket from the list of bound sockets. */
void
UDPService::UnbindSocket(UDPSocket *socket)
{
	this->fSockets.Remove(socket);
}


/*!	Computes the UDP checksum of \a buffer, including the IPv4 pseudo-header
	built from \a source, \a destination and \a length (host byte order).

	The pseudo-header is chained in front of \a buffer only for the duration
	of the computation and detached again afterwards, presumably so that
	destroying the temporary head buffer leaves the caller's chain alone.
*/
uint16
UDPService::_ChecksumBuffer(ChainBuffer *buffer, ip_addr_t source,
	ip_addr_t destination, uint16 length)
{
	struct pseudo_header {
		ip_addr_t	source;
		ip_addr_t	destination;
		uint8		pad;
		uint8		protocol;
		uint16		length;
	} __attribute__ ((__packed__));

	pseudo_header pseudoHeader;
	pseudoHeader.source = htonl(source);
	pseudoHeader.destination = htonl(destination);
	pseudoHeader.pad = 0;
	pseudoHeader.protocol = IPPROTO_UDP;
	pseudoHeader.length = htons(length);

	ChainBuffer chain(&pseudoHeader, sizeof(pseudoHeader), buffer);
	const uint16 sum = ip_checksum(&chain);
	chain.DetachNext();

	return sum;
}


/*!	Flat-buffer variant of _ChecksumBuffer(): checksums \a length bytes at
	\a data together with the pseudo-header.
*/
uint16
UDPService::_ChecksumData(const void *data, uint16 length, ip_addr_t source,
	ip_addr_t destination)
{
	ChainBuffer wrapper(const_cast<void*>(data), length);
	return _ChecksumBuffer(&wrapper, source, destination, length);
}


/*!	Returns the first bound socket on \a port whose address matches
	\a address. Either side being \c INADDR_ANY counts as a match.
	Returns \c NULL if there is none.
*/
UDPSocket *
UDPService::_FindSocket(ip_addr_t address, uint16 port)
{
	const int count = fSockets.Count();
	for (int index = 0; index < count; index++) {
		UDPSocket *candidate = fSockets.ElementAt(index);
		if (candidate->Port() != port)
			continue;

		const ip_addr_t boundAddress = candidate->Address();
		if (address == INADDR_ANY || boundAddress == INADDR_ANY
			|| boundAddress == address) {
			return candidate;
		}
	}

	return NULL;
}