root/src/kits/network/libnetapi/NetEndpoint.cpp
/*
 * Copyright 2002-2008, Haiku, Inc. All Rights Reserved.
 * Distributed under the terms of the MIT License.
 */

#include <Message.h>
#include <NetEndpoint.h>

#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <new>
#include <string.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/types.h>
#include <unistd.h>


/*!     Creates an AF_INET endpoint of the given socket \a type (the value is
        passed to socket() as its type argument) with protocol 0, and opens
        the socket right away. The result of that attempt is stored in
        fStatus; use InitCheck() / Error() to inspect it.
*/
BNetEndpoint::BNetEndpoint(int type)
        :
        fStatus(B_NO_INIT),
        fFamily(AF_INET),
        fType(type),
        fProtocol(0),
        fSocket(-1),
        fTimeout(B_INFINITE_TIMEOUT)
{
        // The status is recorded in fStatus by the helper itself.
        (void)_SetupSocket();
}


/*!     Creates an endpoint with an explicit address \a family, socket \a type
        and \a protocol (all forwarded to socket()) and opens the socket
        immediately. Failure is reported through fStatus.
*/
BNetEndpoint::BNetEndpoint(int family, int type, int protocol)
        :
        fStatus(B_NO_INIT),
        fFamily(family),
        fType(type),
        fProtocol(protocol),
        fSocket(-1),
        fTimeout(B_INFINITE_TIMEOUT)
{
        // The status is recorded in fStatus by the helper itself.
        (void)_SetupSocket();
}


/*!     Reconstructs an endpoint from an \a archive written by Archive().

        The local and peer address fields are optional, because Archive()
        only stores them when the address could be retrieved. Failures while
        reading them are therefore not treated as fatal. The timeout and
        socket type fields are required. Only when both are found is a new
        socket opened. On failure the socket stays closed (InitCheck()
        returns B_NO_INIT) and Error() holds the last error encountered.
*/
BNetEndpoint::BNetEndpoint(BMessage* archive)
        :
        fStatus(B_NO_INIT),
        fFamily(AF_INET),
        // fType used to be left uninitialized here. Give it a deterministic
        // value; it is overwritten by the archived type when present.
        fType(SOCK_STREAM),
        fProtocol(0),
        fSocket(-1),
        fTimeout(B_INFINITE_TIMEOUT)
{
        if (!archive)
                return;

        in_addr addr, peer;
        unsigned short addrPort = 0, peerPort = 0;

        // Optional local address.
        status_t status = archive->FindInt32("_BNetEndpoint_addr_addr",
                (int32 *)&addr.s_addr);
        if (status == B_OK) {
                status = archive->FindInt16("_BNetEndpoint_addr_port",
                        (int16 *)&addrPort);
                if (status == B_OK)
                        fAddr.SetTo(addr, addrPort);
        }

        // Optional peer address.
        status = archive->FindInt32("_BNetEndpoint_peer_addr",
                (int32 *)&peer.s_addr);
        if (status == B_OK) {
                status = archive->FindInt16("_BNetEndpoint_peer_port",
                        (int16 *)&peerPort);
                if (status == B_OK)
                        fPeer.SetTo(peer, peerPort);
        }

        // Required fields: timeout and socket type.
        fStatus = archive->FindInt64("_BNetEndpoint_timeout", (int64 *)&fTimeout);
        if (fStatus == B_OK)
                fStatus = archive->FindInt32("_BNetEndpoint_proto", (int32 *)&fType);

        if (fStatus == B_OK)
                _SetupSocket();
}


/*!     Copy constructor. All settings and addresses are copied. An open
        socket of \a endpoint is dup()ed so both objects own an independent
        descriptor. If dup() fails, errno is stored in fStatus and this
        endpoint has no socket.
*/
BNetEndpoint::BNetEndpoint(const BNetEndpoint& endpoint)
        :
        fStatus(endpoint.fStatus),
        fFamily(endpoint.fFamily),
        fType(endpoint.fType),
        fProtocol(endpoint.fProtocol),
        fSocket(-1),
        fTimeout(endpoint.fTimeout),
        fAddr(endpoint.fAddr),
        fPeer(endpoint.fPeer)
{
        // Nothing to duplicate when the source has no open socket.
        if (endpoint.fSocket < 0)
                return;

        fSocket = dup(endpoint.fSocket);
        if (fSocket == -1)
                fStatus = errno;
}


/*!     Private constructor, used only by Accept(). It copies the settings of
        the listening endpoint \a parent and takes ownership of the already
        accepted descriptor \a acceptedSocket. The given local and peer
        addresses become fAddr and fPeer.
*/
BNetEndpoint::BNetEndpoint(const BNetEndpoint& parent, int acceptedSocket,
        const struct sockaddr_in& local,
        const struct sockaddr_in& remote)
        :
        fStatus(parent.fStatus),
        fFamily(parent.fFamily),
        fType(parent.fType),
        fProtocol(parent.fProtocol),
        fSocket(acceptedSocket),
        fTimeout(parent.fTimeout),
        fAddr(local),
        fPeer(remote)
{
        // Everything is set up by the member initializer list.
}


/*!     Assignment operator. The current socket (if any) is closed first. Then
        all settings are copied from \a endpoint, and its socket, if open, is
        dup()ed. A dup() failure is recorded in fStatus. Self-assignment is a
        no-op.
*/
BNetEndpoint&
BNetEndpoint::operator=(const BNetEndpoint& endpoint)
{
        if (&endpoint != this) {
                // Drop our own descriptor before taking over the new settings.
                Close();

                fStatus = endpoint.fStatus;
                fFamily = endpoint.fFamily;
                fType = endpoint.fType;
                fProtocol = endpoint.fProtocol;
                fTimeout = endpoint.fTimeout;
                fAddr = endpoint.fAddr;
                fPeer = endpoint.fPeer;

                fSocket = -1;
                if (endpoint.fSocket >= 0) {
                        fSocket = dup(endpoint.fSocket);
                        if (fSocket == -1)
                                fStatus = errno;
                }
        }

        return *this;
}


/*!     Destructor: releases the socket if one is still open.
*/
BNetEndpoint::~BNetEndpoint()
{
        if (fSocket < 0)
                return;

        Close();
}


// #pragma mark -


/*!     Stores \a address under the two given field names, but only if the
        address can be retrieved. An unavailable address is not an error
        (B_OK is returned). A failing Add*() call is passed back to the
        caller.
*/
static status_t
add_endpoint_address(BMessage* into, const BNetAddress& address,
        const char* addrField, const char* portField)
{
        in_addr hostAddress;
        unsigned short hostPort;

        if (address.GetAddr(hostAddress, &hostPort) != B_OK)
                return B_OK;

        status_t result = into->AddInt32(addrField, hostAddress.s_addr);
        if (result == B_OK)
                result = into->AddInt16(portField, hostPort);
        return result;
}


/*!     Archives the endpoint into \a into: the base BArchivable data, the
        optional local/peer addresses, the timeout and the socket type.
        Returns B_ERROR for a NULL message, otherwise the first failing
        status.
*/
status_t
BNetEndpoint::Archive(BMessage* into, bool deep) const
{
        if (into == NULL)
                return B_ERROR;

        status_t result = BArchivable::Archive(into, deep);
        if (result != B_OK)
                return result;

        result = add_endpoint_address(into, fAddr, "_BNetEndpoint_addr_addr",
                "_BNetEndpoint_addr_port");
        if (result != B_OK)
                return result;

        result = add_endpoint_address(into, fPeer, "_BNetEndpoint_peer_addr",
                "_BNetEndpoint_peer_port");
        if (result != B_OK)
                return result;

        result = into->AddInt64("_BNetEndpoint_timeout", fTimeout);
        if (result != B_OK)
                return result;

        return into->AddInt32("_BNetEndpoint_proto", fType);
}


/*!     Archive hook: creates a BNetEndpoint from \a archive when the archive
        is valid for "BNetEndpoint". Returns NULL for a NULL or foreign
        archive, on allocation failure, or when the new endpoint could not
        open its socket (InitCheck() != B_OK).
*/
BArchivable*
BNetEndpoint::Instantiate(BMessage* archive)
{
        if (archive == NULL)
                return NULL;

        if (!validate_instantiation(archive, "BNetEndpoint"))
                return NULL;

        // Use the non-throwing form of new, so that the NULL check below is
        // actually reachable on allocation failure.
        BNetEndpoint* endpoint = new (std::nothrow) BNetEndpoint(archive);
        if (endpoint != NULL && endpoint->InitCheck() == B_OK)
                return endpoint;

        delete endpoint;
        return NULL;
}


// #pragma mark -


/*!     Returns B_OK while the endpoint owns an open socket, B_NO_INIT
        otherwise. fStatus is not consulted.
*/
status_t
BNetEndpoint::InitCheck() const
{
        if (fSocket == -1)
                return B_NO_INIT;

        return B_OK;
}


int
BNetEndpoint::Socket() const
{
        return fSocket;
}


/*!     Returns the local address, as stored by Bind() or the archive/accept
        constructors.
*/
const BNetAddress&
BNetEndpoint::LocalAddr() const
{
        return this->fAddr;
}


/*!     Returns the peer address, as stored by Connect() or the archive/accept
        constructors.
*/
const BNetAddress&
BNetEndpoint::RemoteAddr() const
{
        return this->fPeer;
}


/*!     Closes the current socket and reopens it with a new socket type.
        Despite its name, \a protocol is the socket type (SOCK_DGRAM or
        SOCK_STREAM), so it replaces fType, not fProtocol. Returns the result
        of reopening the socket.
*/
status_t
BNetEndpoint::SetProtocol(int protocol)
{
        Close();
        fType = protocol;

        status_t result = _SetupSocket();
        return result;
}


/*!     Sets a socket option via setsockopt(), opening the socket first if
        necessary. Returns fStatus if the socket could not be opened,
        B_ERROR (with errno saved in fStatus) if setsockopt() fails, and
        B_OK otherwise.
*/
int
BNetEndpoint::SetOption(int32 option, int32 level,
        const void* data, unsigned int length)
{
        if (fSocket < 0) {
                if (_SetupSocket() != B_OK)
                        return fStatus;
        }

        int result = setsockopt(fSocket, level, option, data, length);
        if (result >= 0)
                return B_OK;

        fStatus = errno;
        return B_ERROR;
}


/*!     Toggles O_NONBLOCK on the socket, opening the socket first if needed.
        Returns fStatus if the socket could not be opened, B_ERROR (with
        errno in fStatus) if either fcntl() call fails, and B_OK otherwise.
*/
int
BNetEndpoint::SetNonBlocking(bool enable)
{
        if (fSocket < 0) {
                if (_SetupSocket() != B_OK)
                        return fStatus;
        }

        int currentFlags = fcntl(fSocket, F_GETFL);
        if (currentFlags < 0) {
                fStatus = errno;
                return B_ERROR;
        }

        int newFlags = enable
                ? (currentFlags | O_NONBLOCK) : (currentFlags & ~O_NONBLOCK);

        if (fcntl(fSocket, F_SETFL, newFlags) < 0) {
                fStatus = errno;
                return B_ERROR;
        }

        return B_OK;
}


/*!     Enables or disables SO_REUSEADDR on the socket (opening it first if
        needed). Returns the same codes as SetOption().
*/
int
BNetEndpoint::SetReuseAddr(bool enable)
{
        if (fSocket < 0) {
                if (_SetupSocket() != B_OK)
                        return fStatus;
        }

        int value = enable ? 1 : 0;
        return SetOption(SO_REUSEADDR, SOL_SOCKET, &value, sizeof(value));
}


/*!     Sets the receive timeout used by Receive()/ReceiveFrom(). Negative
        values are mapped to B_INFINITE_TIMEOUT.
*/
void
BNetEndpoint::SetTimeout(bigtime_t timeout)
{
        if (timeout < 0)
                fTimeout = B_INFINITE_TIMEOUT;
        else
                fTimeout = timeout;
}


/*!     Returns the last recorded status/error code.
*/
int
BNetEndpoint::Error() const
{
        return static_cast<int>(fStatus);
}


/*!     Returns the strerror() text for the last recorded status code.
*/
char*
BNetEndpoint::ErrorStr() const
{
        char* message = strerror(fStatus);
        return message;
}


// #pragma mark -


void
BNetEndpoint::Close()
{
        if (fSocket >= 0)
                close(fSocket);

        fSocket = -1;
        fStatus = B_NO_INIT;
}


/*!     Binds the socket to \a address (opening the socket first if needed).
        On success fAddr is updated with the actual bound address reported
        by getsockname(). On failure of bind() or getsockname(), the socket
        is closed and B_ERROR is returned; the errno value is kept in
        fStatus. An invalid \a address yields its GetAddr() error without
        touching the socket.
*/
status_t
BNetEndpoint::Bind(const BNetAddress& address)
{
        if (fSocket < 0 && _SetupSocket() != B_OK)
                return fStatus;

        struct sockaddr_in addr;
        status_t status = address.GetAddr(addr);
        if (status != B_OK)
                return status;

        if (bind(fSocket, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
                // Close() resets fStatus to B_NO_INIT and close() may clobber
                // errno, so capture the error first and store it afterwards.
                status_t error = errno;
                Close();
                fStatus = error;
                return B_ERROR;
        }

        socklen_t addrSize = sizeof(addr);
        if (getsockname(fSocket, (struct sockaddr *)&addr, &addrSize) < 0) {
                status_t error = errno;
                Close();
                fStatus = error;
                return B_ERROR;
        }

        fAddr.SetTo(addr);
        return B_OK;
}


/*!     Convenience overload: binds to \a port on INADDR_ANY.
*/
status_t
BNetEndpoint::Bind(int port)
{
        return Bind(BNetAddress(INADDR_ANY, port));
}


/*!     Connects the socket to \a address (opening the socket first if
        needed). On success fPeer is updated from getpeername(). If connect()
        or getpeername() fails, the socket is closed, the errno value is
        stored in fStatus and B_ERROR is returned. An invalid \a address
        yields B_ERROR without touching the socket.
*/
status_t
BNetEndpoint::Connect(const BNetAddress& address)
{
        if (fSocket < 0 && _SetupSocket() != B_OK)
                return fStatus;

        sockaddr_in addr;
        if (address.GetAddr(addr) != B_OK)
                return B_ERROR;

        if (connect(fSocket, (sockaddr *) &addr, sizeof(addr)) < 0) {
                // Save errno before Close(): close() may overwrite it.
                status_t error = errno;
                Close();
                fStatus = error;
                return B_ERROR;
        }

        socklen_t addrSize = sizeof(addr);
        if (getpeername(fSocket, (sockaddr *) &addr, &addrSize) < 0) {
                status_t error = errno;
                Close();
                fStatus = error;
                return B_ERROR;
        }
        fPeer.SetTo(addr);
        return B_OK;
}


/*!     Convenience overload: resolves \a hostname / \a port into a
        BNetAddress and connects to it.
*/
status_t
BNetEndpoint::Connect(const char *hostname, int port)
{
        return Connect(BNetAddress(hostname, port));
}


/*!     Puts the socket into listening mode with the given \a backlog
        (opening the socket first if needed). If listen() fails, the socket
        is closed, errno is kept in fStatus and B_ERROR is returned.
*/
status_t
BNetEndpoint::Listen(int backlog)
{
        if (fSocket < 0 && _SetupSocket() != B_OK)
                return fStatus;

        if (listen(fSocket, backlog) < 0) {
                // Save errno before Close(): close() may overwrite it.
                status_t error = errno;
                Close();
                fStatus = error;
                return B_ERROR;
        }
        return B_OK;
}


/*!     Waits up to \a timeout milliseconds (negative: forever) for an
        incoming connection and accepts it. Returns a newly allocated
        endpoint that owns the accepted socket, or NULL on timeout or error.
        On error, fStatus receives the errno value (or B_NO_MEMORY).
*/
BNetEndpoint*
BNetEndpoint::Accept(int32 timeout)
{
        // Without an open listening socket there is nothing to wait on, and
        // passing -1 to FD_SET() (inside IsDataPending()) is undefined.
        if (fSocket < 0)
                return NULL;

        if (!IsDataPending(timeout < 0 ? B_INFINITE_TIMEOUT : 1000LL * timeout))
                return NULL;

        struct sockaddr_in peerAddress;
        socklen_t peerAddressSize = sizeof(peerAddress);

        int socket
                = accept(fSocket, (struct sockaddr *)&peerAddress, &peerAddressSize);
        if (socket < 0) {
                // NOTE(review): the listening socket is closed on any accept()
                // failure, including transient ones. Kept for compatibility;
                // confirm whether callers rely on this.
                // errno is saved first because Close() may clobber it.
                status_t error = errno;
                Close();
                fStatus = error;
                return NULL;
        }

        struct sockaddr_in localAddress;
        socklen_t localAddressSize = sizeof(localAddress);
        if (getsockname(socket, (struct sockaddr *)&localAddress,
                        &localAddressSize) < 0) {
                status_t error = errno;
                close(socket);
                fStatus = error;
                return NULL;
        }

        BNetEndpoint* endpoint = new (std::nothrow) BNetEndpoint(*this, socket,
                localAddress, peerAddress);
        if (endpoint == NULL) {
                close(socket);
                fStatus = B_NO_MEMORY;
                return NULL;
        }

        return endpoint;
}


// #pragma mark -


/*!     Waits up to \a timeout microseconds (negative or very large values:
        forever) until the socket becomes readable. select() is restarted on
        EINTR. Returns true if data (or a connection) is pending. Returns
        false on timeout, on a select() error (errno stored in fStatus), or
        when there is no usable socket.
*/
bool
BNetEndpoint::IsDataPending(bigtime_t timeout)
{
        // FD_SET() with a negative descriptor, or one >= FD_SETSIZE, writes
        // outside the fd_set, so refuse such descriptors up front.
        if (fSocket < 0)
                return false;
        if (fSocket >= FD_SETSIZE) {
                fStatus = B_BAD_VALUE;
                return false;
        }

        struct timeval tv;
        fd_set fds;

        FD_ZERO(&fds);
        FD_SET(fSocket, &fds);

        // Make sure the timeout does not overflow. If it does, have an infinite
        // timeout instead. Note that this conveniently includes B_INFINITE_TIMEOUT.
        if (timeout > INT32_MAX * 1000000ll)
                timeout = -1;

        if (timeout >= 0) {
                tv.tv_sec = timeout / 1000000;
                tv.tv_usec = (timeout % 1000000);
        }

        int status;
        do {
                status = select(fSocket + 1, &fds, NULL, NULL,
                        timeout >= 0 ? &tv : NULL);
        } while (status == -1 && errno == EINTR);

        if (status < 0) {
                fStatus = errno;
                return false;
        }

        return FD_ISSET(fSocket, &fds);
}


/*!     Receives up to \a length bytes into \a buffer via recv(), opening the
        socket first if needed. When fTimeout is non-negative, the call first
        waits for data with IsDataPending(); if that reports nothing, 0 is
        returned. A negative recv() result is returned as-is, with errno
        stored in fStatus.
*/
int32
BNetEndpoint::Receive(void* buffer, size_t length, int flags)
{
        if (fSocket < 0) {
                if (_SetupSocket() != B_OK)
                        return fStatus;
        }

        if (fTimeout >= 0 && !IsDataPending(fTimeout))
                return 0;

        ssize_t received = recv(fSocket, buffer, length, flags);
        if (received < 0)
                fStatus = errno;

        return received;
}


/*!     Receives up to \a length bytes and appends whatever arrived to
        \a buffer. Returns the same value as Receive(void*, ...).
*/
int32
BNetEndpoint::Receive(BNetBuffer& buffer, size_t length, int flags)
{
        BNetBuffer scratch(length);
        ssize_t count = Receive(scratch.Data(), length, flags);
        if (count <= 0)
                return count;

        buffer.AppendData(scratch.Data(), count);
        return count;
}


/*!     Like Receive(void*, ...), but uses recvfrom() and stores the sender's
        address in \a address when the call succeeds.
*/
int32
BNetEndpoint::ReceiveFrom(void* buffer, size_t length,
        BNetAddress& address, int flags)
{
        if (fSocket < 0) {
                if (_SetupSocket() != B_OK)
                        return fStatus;
        }

        if (fTimeout >= 0 && !IsDataPending(fTimeout))
                return 0;

        struct sockaddr_in sender;
        socklen_t senderSize = sizeof(sender);

        ssize_t received = recvfrom(fSocket, buffer, length, flags,
                (struct sockaddr *)&sender, &senderSize);
        if (received >= 0)
                address.SetTo(sender);
        else
                fStatus = errno;

        return received;
}


/*!     Receives up to \a length bytes from any sender, appends them to
        \a buffer and stores the sender in \a address. Returns the same value
        as ReceiveFrom(void*, ...).
*/
int32
BNetEndpoint::ReceiveFrom(BNetBuffer& buffer, size_t length,
        BNetAddress& address, int flags)
{
        BNetBuffer scratch(length);
        ssize_t count = ReceiveFrom(scratch.Data(), length, address, flags);
        if (count <= 0)
                return count;

        buffer.AppendData(scratch.Data(), count);
        return count;
}


/*!     Sends \a length bytes from \a buffer via send(), opening the socket
        first if needed. A negative result is returned as-is, with errno
        stored in fStatus.
*/
int32
BNetEndpoint::Send(const void* buffer, size_t length, int flags)
{
        if (fSocket < 0) {
                if (_SetupSocket() != B_OK)
                        return fStatus;
        }

        const char* data = static_cast<const char*>(buffer);
        ssize_t sent = send(fSocket, data, length, flags);
        if (sent < 0)
                fStatus = errno;

        return sent;
}


/*!     Sends the complete contents of \a buffer.
*/
int32
BNetEndpoint::Send(BNetBuffer& buffer, int flags)
{
        return this->Send(buffer.Data(), buffer.Size(), flags);
}


/*!     Sends \a length bytes from \a buffer to \a address via sendto(),
        opening the socket first if needed. Returns B_ERROR for an invalid
        address. A negative sendto() result is returned as-is, with errno
        stored in fStatus.
*/
int32
BNetEndpoint::SendTo(const void* buffer, size_t length,
        const BNetAddress& address, int flags)
{
        if (fSocket < 0) {
                if (_SetupSocket() != B_OK)
                        return fStatus;
        }

        struct sockaddr_in target;
        if (address.GetAddr(target) != B_OK)
                return B_ERROR;

        ssize_t sent = sendto(fSocket, buffer, length, flags,
                (struct sockaddr *)&target, sizeof(target));
        if (sent < 0)
                fStatus = errno;

        return sent;
}


/*!     Sends the complete contents of \a buffer to \a address.
*/
int32
BNetEndpoint::SendTo(BNetBuffer& buffer,
        const BNetAddress& address, int flags)
{
        return this->SendTo(buffer.Data(), buffer.Size(), address, flags);
}


// #pragma mark -


/*!     Opens a new socket from fFamily/fType/fProtocol and stores it in
        fSocket. fStatus becomes B_OK on success or the errno value on
        failure, and is also returned. Callers must make sure no socket is
        currently open; an existing descriptor would be overwritten.
*/
status_t
BNetEndpoint::_SetupSocket()
{
        fSocket = socket(fFamily, fType, fProtocol);
        fStatus = fSocket < 0 ? errno : B_OK;
        return fStatus;
}


// #pragma mark -

// Non-const variant kept for binary compatibility; forwards to the const one.
status_t BNetEndpoint::InitCheck()
{
        const BNetEndpoint* self = static_cast<const BNetEndpoint*>(this);
        return self->InitCheck();
}


// Non-const variant kept for binary compatibility; forwards to the const one.
const BNetAddress& BNetEndpoint::LocalAddr()
{
        const BNetEndpoint* self = static_cast<const BNetEndpoint*>(this);
        return self->LocalAddr();
}


// Non-const variant kept for binary compatibility; forwards to the const one.
const BNetAddress& BNetEndpoint::RemoteAddr()
{
        const BNetEndpoint* self = static_cast<const BNetEndpoint*>(this);
        return self->RemoteAddr();
}


// #pragma mark -


// Reserved virtual slots: empty on purpose, they only exist so the vtable
// layout stays binary compatible.
void
BNetEndpoint::_ReservedBNetEndpointFBCCruft1()
{
}

void
BNetEndpoint::_ReservedBNetEndpointFBCCruft2()
{
}

void
BNetEndpoint::_ReservedBNetEndpointFBCCruft3()
{
}

void
BNetEndpoint::_ReservedBNetEndpointFBCCruft4()
{
}

void
BNetEndpoint::_ReservedBNetEndpointFBCCruft5()
{
}

void
BNetEndpoint::_ReservedBNetEndpointFBCCruft6()
{
}