root/src/add-ons/kernel/network/stack/net_socket.cpp
/*
 * Copyright 2006-2010, Haiku, Inc. All Rights Reserved.
 * Distributed under the terms of the MIT License.
 *
 * Authors:
 *              Axel Dörfler, axeld@pinc-software.de
 */


#include "stack_private.h"

#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/time.h>

#include <new>

#include <Drivers.h>
#include <KernelExport.h>
#include <Select.h>

#include <AutoDeleter.h>
#include <team.h>
#include <util/AutoLock.h>
#include <util/list.h>
#include <WeakReferenceable.h>

#include <fs/select_sync_pool.h>
#include <kernel.h>

#include <net_protocol.h>
#include <net_stack.h>
#include <net_stat.h>

#include "ancillary_data.h"
#include "utility.h"


// Uncomment the following define to route the TRACE() calls in this file to
// dprintf(), prefixed with STACK_DEBUG_PREFIX.
//#define TRACE_SOCKET
#ifdef TRACE_SOCKET
#       define TRACE(x...) dprintf(STACK_DEBUG_PREFIX x)
#else
// Tracing is disabled: every TRACE() call expands to an empty statement.
#       define TRACE(x...) ;
#endif


struct net_socket_private;
typedef DoublyLinkedList<net_socket_private> SocketList;

/*!     The stack's private view of a socket: the public net_socket extended by
        the bookkeeping needed for listening sockets (parent/child relation),
        select notifications, and the global socket list.
*/
struct net_socket_private : net_socket,
                DoublyLinkedListLinkImpl<net_socket_private>,
                BWeakReferenceable {
        net_socket_private();
        ~net_socket_private();

        // Detaches this socket from its parent and moves it into the global
        // socket list.
        void RemoveFromParent();

        // Weak reference to the socket this one was spawned from; set by
        // socket_spawn_pending(), cleared by RemoveFromParent().
        BWeakReference<net_socket_private> parent;
        // Team that owns the socket: the opening team, or inherited from the
        // parent for spawned sockets (-1 until set).
        team_id                                         owner;
        // Limit set via socket_set_max_backlog().
        uint32                                          max_backlog;
        // Number of children currently linked in the two lists below.
        uint32                                          child_count;
        // Children added by socket_spawn_pending() that are not yet connected.
        SocketList                                      pending_children;
        // Children moved over by socket_connected(), waiting to be dequeued.
        SocketList                                      connected_children;

        // Select entries registered via socket_request_notification().
        struct select_sync_pool*        select_pool;
        // Held by the functions in this file while they touch the child lists,
        // child_count, and select_pool.
        mutex                                           lock;

        // Set by socket_connected().
        bool                                            is_connected;
        // Whether the socket is currently linked into sSocketList.
        bool                                            is_in_socket_list;
};


// Forward declarations of functions that are defined further down in this
// file.
int socket_bind(net_socket* socket, const struct sockaddr* address,
        socklen_t addressLength);
int socket_setsockopt(net_socket* socket, int level, int option,
        const void* value, int length);
ssize_t socket_read_avail(net_socket* socket);

// List of the sockets that are not linked into a parent's child list (opened
// sockets and children that were removed from their parent).
static SocketList sSocketList;
// Guards sSocketList.
static mutex sSocketLock;


/*!     Creates a socket without any protocol chain, address, or peer, and
        fills in the default buffer settings.
*/
net_socket_private::net_socket_private()
        :
        owner(-1),
        max_backlog(0),
        child_count(0),
        select_pool(NULL),
        is_connected(false),
        is_in_socket_list(false)
{
        mutex_init(&lock, "socket");

        // there is no protocol chain yet
        first_info = NULL;
        first_protocol = NULL;

        // no local or peer address, and not bound to a device
        peer.ss_len = 0;
        address.ss_len = 0;
        bound_to_device = 0;

        error = 0;
        linger = 0;
        options = 0;

        // default buffer settings (may be overridden by the protocols)
        receive.timeout = B_INFINITE_TIMEOUT;
        receive.low_water_mark = 1;
        receive.buffer_size = 65535;

        send.timeout = B_INFINITE_TIMEOUT;
        send.low_water_mark = 1;
        send.buffer_size = 65535;
}


/*!     Unlinks the socket from the global list, detaches all children that
        are still queued on it, and releases its protocol chain.
        Panics if the socket is still attached to a parent.
*/
net_socket_private::~net_socket_private()
{
        TRACE("delete net_socket %p\n", this);

        if (parent != NULL)
                panic("socket still has a parent!");

        if (is_in_socket_list) {
                mutex_lock(&sSocketLock);
                sSocketList.Remove(this);
                mutex_unlock(&sSocketLock);
        }

        {
                MutexLocker locker(lock);

                // Detach every remaining child; RemoveFromParent() also drops
                // one reference of the child.
                for (net_socket_private* child = pending_children.RemoveHead();
                        child != NULL; child = pending_children.RemoveHead()) {
                        child->RemoveFromParent();
                }

                for (net_socket_private* child
                                = connected_children.RemoveHead();
                        child != NULL; child = connected_children.RemoveHead()) {
                        child->RemoveFromParent();
                }
        }

        put_domain_protocols(this);

        mutex_destroy(&lock);
}


void
net_socket_private::RemoveFromParent()
{
        ASSERT(!is_in_socket_list && parent != NULL);

        parent = NULL;

        mutex_lock(&sSocketLock);
        sSocketList.Add(this);
        mutex_unlock(&sSocketLock);

        is_in_socket_list = true;

        ReleaseReference();
}


//      #pragma mark -


/*!     Allocates a new socket for the given family/type/protocol triple and
        attaches the matching domain protocols to it.
        On success the socket is returned in \a _socket; on failure nothing is
        left allocated and the error is returned.
*/
static status_t
create_socket(int family, int type, int protocol, net_socket_private** _socket)
{
        net_socket_private* created = new(std::nothrow) net_socket_private;
        if (created == NULL)
                return B_NO_MEMORY;

        status_t result = created->InitCheck();
        if (result == B_OK) {
                created->family = family;
                created->type = type;
                created->protocol = protocol;

                result = get_domain_protocols(created);
        }

        if (result != B_OK) {
                delete created;
                return result;
        }

        TRACE("create net_socket %p (%u.%u.%u):\n", created, created->family,
                created->type, created->protocol);

#ifdef TRACE_SOCKET
        int index = 0;
        for (net_protocol* current = created->first_protocol; current != NULL;
                current = current->next) {
                TRACE("  [%d] %p  %s\n", index, current,
                        current->module->info.name);
                index++;
        }
#endif

        *_socket = created;
        return B_OK;
}


/*!     Walks the control messages in \a data and hands each of them to the
        protocol's add_ancillary_data() hook.

        \param socket The socket whose protocol receives the messages.
        \param container The container the protocol adds the data to.
        \param data Buffer holding \a dataLen bytes of consecutive cmsghdr
                structures.
        \param dataLen Size of \a data in bytes; 0 means there is nothing to do.
        \return \c B_OK on success, \c B_NOT_SUPPORTED if the protocol has no
                add_ancillary_data() hook, \c B_BAD_VALUE if the buffer is
                missing or a message header does not fit into the remaining
                data, or the error returned by the protocol hook.
*/
static status_t
add_ancillary_data(net_socket* socket, ancillary_data_container* container,
        void* data, size_t dataLen)
{
        if (dataLen == 0)
                return B_OK;

        if (socket->first_info->add_ancillary_data == NULL)
                return B_NOT_SUPPORTED;

        if (data == NULL)
                return B_BAD_VALUE;

        cmsghdr* header = (cmsghdr*)data;

        while (true) {
                // Make sure a complete header is available before cmsg_len is
                // read -- otherwise we would read past the end of the buffer.
                if (dataLen < CMSG_LEN(0) || header->cmsg_len < CMSG_LEN(0)
                        || header->cmsg_len > dataLen) {
                        return B_BAD_VALUE;
                }

                status_t status = socket->first_info->add_ancillary_data(
                        socket->first_protocol, container, header);
                if (status != B_OK)
                        return status;

                // the next header starts at the aligned end of this message
                const size_t alignedLength = CMSG_ALIGN(header->cmsg_len);
                if (dataLen <= alignedLength)
                        break;

                dataLen -= alignedLength;
                header = (cmsghdr*)((uint8*)header + alignedLength);
        }

        return B_OK;
}


/*!     Lets the protocol write the ancillary data stored in \a container into
        the control buffer of \a messageHeader, and updates msg_controllen to
        the number of bytes written.
        Without a container or a control buffer, msg_controllen is set to 0.
*/
static status_t
process_ancillary_data(net_socket* socket, ancillary_data_container* container,
        msghdr* messageHeader, int flags)
{
        uint8* control = (uint8*)messageHeader->msg_control;
        int controlLength = messageHeader->msg_controllen;

        if (control == NULL || container == NULL) {
                messageHeader->msg_controllen = 0;
                return B_OK;
        }

        if (socket->first_info->process_ancillary_data == NULL)
                return B_NOT_SUPPORTED;

        ssize_t written = socket->first_info->process_ancillary_data(
                socket->first_protocol, container, control, controlLength, flags);
        if (written >= 0) {
                messageHeader->msg_controllen = written;
                return B_OK;
        }

        // a negative value is the error code of the protocol
        return written;
}


/*!     Variant for protocols that deliver their ancillary data with the
        net_buffer instead of a container: lets the protocol fill the control
        buffer of \a messageHeader and updates msg_controllen accordingly.
        If there is no control buffer, or the protocol does not implement the
        hook, msg_controllen is set to 0.
*/
static status_t
process_ancillary_data(net_socket* socket,
        net_buffer* buffer, msghdr* messageHeader, int flags)
{
        void* control = messageHeader->msg_control;

        const bool supported = control != NULL
                && socket->first_info->process_ancillary_data_no_container != NULL;
        if (!supported) {
                messageHeader->msg_controllen = 0;
                return B_OK;
        }

        ssize_t written = socket->first_info->process_ancillary_data_no_container(
                socket->first_protocol, buffer, control,
                messageHeader->msg_controllen);
        if (written < 0)
                return written;

        messageHeader->msg_controllen = written;
        return B_OK;
}


/*!     Returns the number of bytes the caller's buffers can take: \a length
        (the size of the first buffer) plus the sizes of all further vectors
        of \a header, if there is a header.
*/
static size_t
calculate_total_length(msghdr* header, void* data, size_t length)
{
        if (header == NULL)
                return length;

        ASSERT((header->msg_iovlen == 0 && data == NULL)
                || data == header->msg_iov[0].iov_base);

        // the first vector is already covered by "length"
        size_t sum = length;
        int index = 1;
        while (index < header->msg_iovlen) {
                sum += header->msg_iov[index].iov_len;
                index++;
        }

        return sum;
}


/*!     Receives data via the protocol's read_data_no_buffer() hook directly
        into the caller's vectors (those of \a header, or \a data / \a length
        if there is no header), and processes any ancillary data returned.

        If more bytes were read than the vectors can hold, MSG_TRUNC is set in
        the header's msg_flags; the full size is only returned if the caller
        passed MSG_TRUNC in \a flags, otherwise the capacity is returned.
*/
static ssize_t
socket_receive_no_buffer(net_socket* socket, msghdr* header, void* data,
        size_t length, int flags)
{
        const int kControlFlags = MSG_CMSG_CLOEXEC | MSG_CMSG_CLOFORK;

        iovec stackVec = { data, length };

        iovec* vecs = &stackVec;
        int vecCount = 1;
        sockaddr* address = NULL;
        socklen_t* addressLen = NULL;
        if (header != NULL) {
                vecs = header->msg_iov;
                vecCount = header->msg_iovlen;
                address = (sockaddr*)header->msg_name;
                addressLen = &header->msg_namelen;
        }

        // the flags handled in here are not passed on to the protocol
        ancillary_data_container* ancillaryData = NULL;
        ssize_t bytesRead = socket->first_info->read_data_no_buffer(
                socket->first_protocol, vecs, vecCount, &ancillaryData, address,
                addressLen, flags & ~(kControlFlags | MSG_TRUNC));
        if (bytesRead < 0)
                return bytesRead;

        CObjectDeleter<
                ancillary_data_container, void, delete_ancillary_data_container>
                ancillaryDataDeleter(ancillaryData);

        if (header != NULL) {
                // hand the ancillary data over to the caller's control buffer
                status_t status = process_ancillary_data(socket, ancillaryData,
                        header, flags & kControlFlags);
                if (status != B_OK)
                        return status;

                header->msg_flags = 0;
        }

        const size_t capacity = calculate_total_length(header, data, length);
        if ((size_t)bytesRead <= capacity)
                return bytesRead;

        // the message did not fit into the supplied buffers
        if (header != NULL)
                header->msg_flags = MSG_TRUNC;

        if ((flags & MSG_TRUNC) != 0)
                return bytesRead;

        return capacity;
}


#if ENABLE_DEBUGGER_COMMANDS


/*!     Prints a one-line summary of \a socket, preceded by \a prefix.
        A socket that still has a parent is marked as connected "(c)" or
        pending "(p)".
*/
static void
print_socket_line(net_socket_private* socket, const char* prefix)
{
        BReference<net_socket_private> parent = socket->parent.GetReference();

        const char* state = "";
        if (parent.IsSet()) {
                if (socket->is_connected)
                        state = " (c)";
                else
                        state = " (p)";
        }

        kprintf("%s%p %2d.%2d.%2d %6" B_PRId32 " %p %p  %p%s\n", prefix, socket,
                socket->family, socket->type, socket->protocol, socket->owner,
                socket->first_protocol, socket->first_info, parent.Get(), state);
}


/*!     Debugger command: prints the fields of the socket at the address given
        as first argument, followed by its pending and connected children (if
        it has any).
        Always returns 0.
*/
static int
dump_socket(int argc, char** argv)
{
        if (argc < 2) {
                kprintf("usage: %s [address]\n", argv[0]);
                return 0;
        }

        net_socket_private* socket = (net_socket_private*)parse_expression(argv[1]);

        kprintf("SOCKET %p\n", socket);
        kprintf("  family.type.protocol: %d.%d.%d\n",
                socket->family, socket->type, socket->protocol);
        BReference<net_socket_private> parent = socket->parent.GetReference();
        kprintf("  parent:               %p\n", parent.Get());
        kprintf("  first protocol:       %p\n", socket->first_protocol);
        kprintf("  first module_info:    %p\n", socket->first_info);
        kprintf("  options:              %x\n", socket->options);
        kprintf("  linger:               %d\n", socket->linger);
        kprintf("  bound to device:      %" B_PRIu32 "\n", socket->bound_to_device);
        kprintf("  owner:                %" B_PRId32 "\n", socket->owner);
        // max_backlog is an uint32, so it has to be printed as unsigned
        kprintf("  max backlog:          %" B_PRIu32 "\n", socket->max_backlog);
        kprintf("  is connected:         %d\n", socket->is_connected);
        kprintf("  child_count:          %" B_PRIu32 "\n", socket->child_count);

        // without children there is nothing more to print
        if (socket->child_count == 0)
                return 0;

        kprintf("    pending children:\n");
        SocketList::Iterator iterator = socket->pending_children.GetIterator();
        while (net_socket_private* child = iterator.Next()) {
                print_socket_line(child, "      ");
        }

        kprintf("    connected children:\n");
        iterator = socket->connected_children.GetIterator();
        while (net_socket_private* child = iterator.Next()) {
                print_socket_line(child, "      ");
        }

        return 0;
}


static int
dump_sockets(int argc, char** argv)
{
        kprintf("address        kind  owner protocol   module_info parent\n");

        SocketList::Iterator iterator = sSocketList.GetIterator();
        while (net_socket_private* socket = iterator.Next()) {
                print_socket_line(socket, "");

                SocketList::Iterator childIterator
                        = socket->pending_children.GetIterator();
                while (net_socket_private* child = childIterator.Next()) {
                        print_socket_line(child, " ");
                }

                childIterator = socket->connected_children.GetIterator();
                while (net_socket_private* child = childIterator.Next()) {
                        print_socket_line(child, " ");
                }
        }

        return 0;
}


#endif  // ENABLE_DEBUGGER_COMMANDS


//      #pragma mark -


/*!     Creates a socket for the given family/type/protocol, opens its
        protocol chain, and adds it to the global socket list.
        The owner of the socket is the calling team.
*/
status_t
socket_open(int family, int type, int protocol, net_socket** _socket)
{
        net_socket_private* created;
        status_t result = create_socket(family, type, protocol, &created);
        if (result != B_OK)
                return result;

        result = created->first_info->open(created->first_protocol);
        if (result != B_OK) {
                delete created;
                return result;
        }

        created->owner = team_get_current_team_id();
        created->is_in_socket_list = true;

        {
                MutexLocker listLocker(sSocketLock);
                sSocketList.Add(created);
        }

        *_socket = created;
        return B_OK;
}


/*!     Forwards the close request to the first protocol of the socket and
        returns its result.
*/
status_t
socket_close(net_socket* _socket)
{
        net_socket_private* socket = (net_socket_private*)_socket;

        net_protocol* protocol = socket->first_protocol;
        return socket->first_info->close(protocol);
}


/*!     Lets the first protocol free its resources, and then releases one
        reference to the socket.
*/
void
socket_free(net_socket* _socket)
{
        net_socket_private* socket = (net_socket_private*)_socket;

        net_protocol* protocol = socket->first_protocol;
        socket->first_info->free(protocol);

        socket->ReleaseReference();
}


/*!     Handles the ioctls that are implemented by the socket layer itself
        (FIONREAD, B_SET_BLOCKING_IO, B_SET_NONBLOCKING_IO), and passes all
        others on to the protocol as LEVEL_DRIVER_IOCTL control request.
*/
status_t
socket_control(net_socket* socket, uint32 op, void* data, size_t length)
{
        switch (op) {
                case FIONREAD:
                {
                        if ((socket->options & SO_ACCEPTCONN) != 0 || data == NULL)
                                return B_BAD_VALUE;

                        // errors are reported as "nothing available"
                        int available = (int)socket_read_avail(socket);
                        if (available < 0)
                                available = 0;

                        if (!is_syscall()) {
                                *(int*)data = available;
                                return B_OK;
                        }

                        // the buffer comes from userland, copy with care
                        if (!IS_USER_ADDRESS(data))
                                return B_BAD_ADDRESS;
                        if (user_memcpy(data, &available, sizeof(available))
                                        != B_OK) {
                                return B_BAD_ADDRESS;
                        }

                        return B_OK;
                }

                case B_SET_BLOCKING_IO:
                case B_SET_NONBLOCKING_IO:
                {
                        int nonBlocking = (op == B_SET_NONBLOCKING_IO) ? 1 : 0;
                        return socket_setsockopt(socket, SOL_SOCKET, SO_NONBLOCK,
                                &nonBlocking, sizeof(int));
                }

                default:
                        break;
        }

        return socket->first_info->control(socket->first_protocol,
                LEVEL_DRIVER_IOCTL, op, data, &length);
}


/*!     Returns what the first protocol's read_avail() hook reports for this
        socket.
*/
ssize_t
socket_read_avail(net_socket* socket)
{
        net_protocol* protocol = socket->first_protocol;
        return socket->first_info->read_avail(protocol);
}


/*!     Returns what the first protocol's send_avail() hook reports for this
        socket.
*/
ssize_t
socket_send_avail(net_socket* socket)
{
        net_protocol* protocol = socket->first_protocol;
        return socket->first_info->send_avail(protocol);
}


/*!     Passes \a buffer to the first protocol's send_data() hook and returns
        its result.
*/
status_t
socket_send_data(net_socket* socket, net_buffer* buffer)
{
        net_protocol* protocol = socket->first_protocol;
        return socket->first_info->send_data(protocol, buffer);
}


/*!     Reads a buffer from the first protocol via read_data(). If the buffer
        returned is larger than \a length, it is trimmed down to \a length.
*/
status_t
socket_receive_data(net_socket* socket, size_t length, uint32 flags,
        net_buffer** _buffer)
{
        status_t status = socket->first_info->read_data(socket->first_protocol,
                length, flags, _buffer);
        if (status != B_OK)
                return status;

        net_buffer* buffer = *_buffer;
        if (buffer != NULL && buffer->size > length) {
                // drop everything beyond the amount requested
                gNetBufferModule.trim(buffer, length);
        }

        return B_OK;
}


/*!     Fills \a stat with information about the next socket in the global
        socket list.

        \param _cookie Index of the socket to report among the sockets that
                match \a family; it is advanced by one on success. Start with 0.
        \param family Only sockets of this address family are counted and
                reported; -1 matches every family.
        \param stat Receives the socket's family, type, protocol, owner,
                address, and peer. The protocol may add further details via
                its NET_STAT_SOCKET control request.
        \return \c B_OK on success, \c B_ENTRY_NOT_FOUND if there are no more
                matching sockets.
*/
status_t
socket_get_next_stat(uint32* _cookie, int family, struct net_stat* stat)
{
        MutexLocker locker(sSocketLock);

        net_socket_private* socket = NULL;
        SocketList::Iterator iterator = sSocketList.GetIterator();
        const uint32 cookie = *_cookie;
        uint32 count = 0;

        while (true) {
                socket = iterator.Next();
                if (socket == NULL)
                        return B_ENTRY_NOT_FOUND;

                // Sockets of other families must neither be counted nor
                // returned -- the family has to be checked before the cookie.
                if (family != -1 && family != socket->family)
                        continue;

                // TODO: also traverse the pending connections
                if (count == cookie)
                        break;

                count++;
        }

        *_cookie = count + 1;

        stat->family = socket->family;
        stat->type = socket->type;
        stat->protocol = socket->protocol;
        stat->owner = socket->owner;
        stat->state[0] = '\0';
        memcpy(&stat->address, &socket->address, sizeof(struct sockaddr_storage));
        memcpy(&stat->peer, &socket->peer, sizeof(struct sockaddr_storage));
        stat->receive_queue_size = 0;
        stat->send_queue_size = 0;

        // fill in protocol specific data (if supported by the protocol); the
        // result is ignored on purpose, the generic data is valid either way
        size_t length = sizeof(net_stat);
        socket->first_info->control(socket->first_protocol, socket->protocol,
                NET_STAT_SOCKET, stat, &length);

        return B_OK;
}


//      #pragma mark - connections


/*!     Acquires a reference to the socket, unless its reference count has
        already dropped to zero.
        \return \c true if a reference was acquired, \c false otherwise.
*/
bool
socket_acquire(net_socket* _socket)
{
        net_socket_private* socket = (net_socket_private*)_socket;

        // While the socket is being destroyed, it might still be reachable via
        // its endpoint protocol. The endpoint must not be able to acquire the
        // socket again at that point -- while not obvious, the endpoint
        // protocol is responsible for the proper locking here.
        const bool alive = socket->CountReferences() != 0;
        if (alive)
                socket->AcquireReference();

        return alive;
}


/*!     Releases one reference to the socket and returns the result of
        ReleaseReference().
*/
bool
socket_release(net_socket* _socket)
{
        net_socket_private* socket = (net_socket_private*)_socket;

        const bool result = socket->ReleaseReference();
        return result;
}


/*!     Creates a new socket as pending child of \a _parent. The child inherits
        buffer settings, a subset of the options, linger, owner, address, and
        peer from the parent.
        \return \c ENOBUFS if the parent already has too many children.
*/
status_t
socket_spawn_pending(net_socket* _parent, net_socket** _socket)
{
        net_socket_private* parent = (net_socket_private*)_parent;

        TRACE("%s(%p)\n", __FUNCTION__, parent);

        MutexLocker locker(parent->lock);

        // More pending connections than the backlog are accepted to
        // compensate for those that never complete; this also makes sure that
        // at least a single connection can always be accepted.
        if (parent->child_count > 3 * parent->max_backlog / 2)
                return ENOBUFS;

        net_socket_private* child;
        status_t status = create_socket(parent->family, parent->type,
                parent->protocol, &child);
        if (status != B_OK)
                return status;

        // inherit the parent's properties
        const int kInheritedOptions
                = SO_KEEPALIVE | SO_DONTROUTE | SO_LINGER | SO_OOBINLINE;
        child->options = parent->options & kInheritedOptions;
        child->linger = parent->linger;
        child->owner = parent->owner;
        child->send = parent->send;
        child->receive = parent->receive;
        memcpy(&child->address, &parent->address, parent->address.ss_len);
        memcpy(&child->peer, &parent->peer, parent->peer.ss_len);

        // queue the child as pending connection of the parent
        parent->pending_children.Add(child);
        child->parent = parent;
        parent->child_count++;

        *_socket = child;
        return B_OK;
}


/*!     Removes the first connected child from the parent socket and returns
        it in \a _socket, together with a reference to it.
        \return \c B_ENTRY_NOT_FOUND if there is no connected child.
*/
status_t
socket_dequeue_connected(net_socket* _parent, net_socket** _socket)
{
        net_socket_private* parent = (net_socket_private*)_parent;

        MutexLocker locker(parent->lock);

        net_socket_private* child = parent->connected_children.RemoveHead();
        if (child == NULL)
                return B_ENTRY_NOT_FOUND;

        // This is the reference for the caller; RemoveFromParent() releases
        // one reference of the child.
        child->AcquireReference();
        child->RemoveFromParent();
        parent->child_count--;

        *_socket = child;
        return B_OK;
}


/*!     Returns the number of connected children queued on the socket. */
ssize_t
socket_count_connected(net_socket* _parent)
{
        net_socket_private* parent = (net_socket_private*)_parent;

        MutexLocker locker(parent->lock);

        const ssize_t connected = parent->connected_children.Count();
        return connected;
}


/*!     Sets the maximum backlog of the socket (capped at 256). If the socket
        currently has more children than that, the excess children are
        removed, starting with the most recently queued pending ones.
*/
status_t
socket_set_max_backlog(net_socket* _socket, uint32 backlog)
{
        net_socket_private* socket = (net_socket_private*)_socket;

        // upper limit of connections waiting to be accepted
        const uint32 limit = backlog > 256 ? 256 : backlog;

        MutexLocker locker(socket->lock);

        // pending connections are dropped first, then the connected ones
        SocketList* queues[] = {
                &socket->pending_children,
                &socket->connected_children
        };

        for (size_t i = 0; i < sizeof(queues) / sizeof(queues[0]); i++) {
                while (socket->child_count > limit) {
                        net_socket_private* child = queues[i]->RemoveTail();
                        if (child == NULL)
                                break;

                        child->RemoveFromParent();
                        socket->child_count--;
                }
        }

        socket->max_backlog = limit;
        return B_OK;
}


/*!     Tells whether the socket still has a parent link. Note that the parent
        the link refers to might not be valid anymore.
*/
bool
socket_has_parent(net_socket* _socket)
{
        net_socket_private* socket = (net_socket_private*)_socket;

        if (socket->parent != NULL)
                return true;

        return false;
}


/*!     Marks the socket as connected. If it has a parent, it is moved from
        the parent's pending queue to its connected queue, and the parent's
        select pool is notified with B_SELECT_READ.
        \return \c B_BAD_VALUE if the parent is set but cannot be referenced
                anymore.
*/
status_t
socket_connected(net_socket* _socket)
{
        net_socket_private* socket = (net_socket_private*)_socket;

        TRACE("socket_connected(%p)\n", socket);

        // a socket without parent only needs its flag updated
        if (socket->parent == NULL) {
                socket->is_connected = true;
                return B_OK;
        }

        BReference<net_socket_private> owner = socket->parent.GetReference();
        if (!owner.IsSet())
                return B_BAD_VALUE;

        MutexLocker locker(owner->lock);

        owner->pending_children.Remove(socket);
        owner->connected_children.Add(socket);
        socket->is_connected = true;

        // let anyone selecting on the parent know
        if (owner->select_pool != NULL)
                notify_select_event_pool(owner->select_pool, B_SELECT_READ);

        return B_OK;
}


/*!     The socket has been aborted: removes it from the queue of its parent
        it is currently in, and detaches it from the parent (which releases
        one reference of the socket).
        \return \c B_BAD_VALUE if the socket has no parent that can be
                referenced.
*/
status_t
socket_aborted(net_socket* _socket)
{
        net_socket_private* socket = (net_socket_private*)_socket;

        TRACE("socket_aborted(%p)\n", socket);

        BReference<net_socket_private> owner = socket->parent.GetReference();
        if (!owner.IsSet())
                return B_BAD_VALUE;

        MutexLocker locker(owner->lock);

        // the connection state tells which queue the socket is in
        SocketList& queue = socket->is_connected
                ? owner->connected_children : owner->pending_children;
        queue.Remove(socket);

        owner->child_count--;
        socket->RemoveFromParent();

        return B_OK;
}


//      #pragma mark - notifications


/*!     Registers \a sync for \a event in the socket's select pool, and
        notifies it right away in case the event condition is already met.
*/
status_t
socket_request_notification(net_socket* _socket, uint8 event, selectsync* sync)
{
        net_socket_private* socket = (net_socket_private*)_socket;

        status_t status;
        {
                MutexLocker locker(socket->lock);
                status = add_select_sync_pool_entry(&socket->select_pool, sync,
                        event);
        }
        if (status != B_OK)
                return status;

        // check if the event is already present
        // TODO: add support for poll() types

        bool present = false;

        if (event == B_SELECT_READ) {
                // errors are reported as well
                const ssize_t available = socket_read_avail(socket);
                present = available < B_OK
                        || (ssize_t)socket->receive.low_water_mark <= available;
        } else if (event == B_SELECT_WRITE) {
                // listening sockets are never reported as writable here
                if ((socket->options & SO_ACCEPTCONN) == 0) {
                        const ssize_t available = socket_send_avail(socket);
                        present = available < B_OK
                                || (ssize_t)socket->send.low_water_mark <= available;
                }
        } else if (event == B_SELECT_ERROR)
                present = socket->error != B_OK;

        if (present)
                notify_select_event(sync, event);

        return B_OK;
}


/*!     Removes the entry for \a sync / \a event from the socket's select
        pool.
*/
status_t
socket_cancel_notification(net_socket* _socket, uint8 event, selectsync* sync)
{
        net_socket_private* socket = (net_socket_private*)_socket;

        MutexLocker locker(socket->lock);

        status_t status = remove_select_sync_pool_entry(&socket->select_pool,
                sync, event);
        return status;
}


/*!     Notifies the socket's select pool about \a event.

        Read and write events are suppressed if \a value is non-negative but
        below the respective low water mark, unless the protocol uses atomic
        messages. For error events, \a value is stored as the socket's error,
        and read and write are notified as well.
*/
status_t
socket_notify(net_socket* _socket, uint8 event, int32 value)
{
        net_socket_private* socket = (net_socket_private*)_socket;

        const bool atomic
                = (socket->first_info->flags & NET_PROTOCOL_ATOMIC_MESSAGES) != 0;
        bool suppress = false;

        if (event == B_SELECT_READ || event == B_SELECT_WRITE) {
                const ssize_t lowWaterMark = event == B_SELECT_READ
                        ? (ssize_t)socket->receive.low_water_mark
                        : (ssize_t)socket->send.low_water_mark;

                // negative values are errors, and always reported
                suppress = lowWaterMark > value && value >= B_OK && !atomic;
        } else if (event == B_SELECT_ERROR)
                socket->error = value;

        MutexLocker locker(socket->lock);

        if (suppress || socket->select_pool == NULL)
                return B_OK;

        notify_select_event_pool(socket->select_pool, event);

        if (event == B_SELECT_ERROR) {
                // always notify read/write on error
                notify_select_event_pool(socket->select_pool, B_SELECT_READ);
                notify_select_event_pool(socket->select_pool, B_SELECT_WRITE);
        }

        return B_OK;
}


//      #pragma mark - standard socket API


/*!     Accepts a connection on \a socket via the first protocol's accept() hook.

        \param socket The listening socket; must be of type SOCK_STREAM or
                SOCK_SEQPACKET and have SO_ACCEPTCONN set.
        \param address If non-NULL (and \c *_addressLength is greater than 0),
                receives the peer address of the accepted socket, truncated to
                \c *_addressLength bytes.
        \param _addressLength Only dereferenced when \a address is non-NULL; set
                to the peer address' ss_len on return.
        \param _acceptedSocket Receives the accepted socket on success.
        \return B_NOT_SUPPORTED for other socket types, B_BAD_VALUE if the socket
                is not listening, the hook's error, or B_OK.
*/
int
socket_accept(net_socket* socket, struct sockaddr* address,
        socklen_t* _addressLength, net_socket** _acceptedSocket)
{
        const bool connectionOriented
                = socket->type == SOCK_STREAM || socket->type == SOCK_SEQPACKET;
        if (!connectionOriented)
                return B_NOT_SUPPORTED;

        const bool listening = (socket->options & SO_ACCEPTCONN) != 0;
        if (!listening)
                return B_BAD_VALUE;

        net_socket* newSocket;
        status_t status = socket->first_info->accept(socket->first_protocol,
                &newSocket);
        if (status != B_OK)
                return status;

        if (address != NULL && *_addressLength > 0) {
                // never copy more than the peer address holds, a sockaddr_storage
                // can hold, or the caller's buffer can take
                size_t peerLength = min_c(newSocket->peer.ss_len,
                        sizeof(sockaddr_storage));
                size_t copyLength = min_c(*_addressLength, peerLength);

                memcpy(address, &newSocket->peer, copyLength);
                *_addressLength = newSocket->peer.ss_len;
        }

        *_acceptedSocket = newSocket;
        return B_OK;
}


/*!     Binds \a socket to \a address via the first protocol's bind() hook.

        A NULL \a address stands for an all-zero address of the socket's family
        (like INADDR_ANY). The address is stored in socket->address before the
        hook is invoked; if the hook fails, socket->address.ss_len is reset to 0
        again.

        \param addressLength Not consulted: sizeof(sockaddr) bytes are always
                copied from \a address.
        \return B_BAD_VALUE if socket->address.ss_len is already non-zero,
                otherwise the result of the bind() hook.
*/
int
socket_bind(net_socket* socket, const struct sockaddr* address,
        socklen_t addressLength)
{
        // a non-zero length marks the socket as having an address already
        if (socket->address.ss_len != 0)
                return B_BAD_VALUE;

        sockaddr wildcard;
        if (address == NULL) {
                // special - try to bind to an empty address, like INADDR_ANY
                memset(&wildcard, 0, sizeof(sockaddr));
                wildcard.sa_family = socket->family;
                wildcard.sa_len = sizeof(sockaddr);

                addressLength = sizeof(sockaddr);
                address = &wildcard;
        }

        memcpy(&socket->address, address, sizeof(sockaddr));
        socket->address.ss_len = sizeof(sockaddr_storage);

        status_t status = socket->first_info->bind(socket->first_protocol,
                (sockaddr*)address);
        if (status == B_OK)
                return B_OK;

        // binding failed: clear the address again
        socket->address.ss_len = 0;
        return status;
}


/*!     Connects \a socket to \a address through the first protocol's connect()
        hook. A socket without a local address (address.ss_len is 0) is bound
        with socket_bind(socket, NULL, 0) first.

        \return ENETUNREACH if \a address is NULL or \a addressLength is 0, the
                error of the implicit bind, or the result of the connect() hook.
*/
int
socket_connect(net_socket* socket, const struct sockaddr* address,
        socklen_t addressLength)
{
        const bool hasDestination = address != NULL && addressLength != 0;
        if (!hasDestination)
                return ENETUNREACH;

        const bool isBound = socket->address.ss_len != 0;
        if (!isBound) {
                // try to bind first
                status_t bindStatus = socket_bind(socket, NULL, 0);
                if (bindStatus != B_OK)
                        return bindStatus;
        }

        return socket->first_info->connect(socket->first_protocol, address);
}


/*!     Copies the peer address of \a _socket into \a address.

        At most \c *_addressLength bytes are copied; \c *_addressLength is then
        set to the peer address' ss_len.

        \return ENOTCONN if the socket neither has a parent nor is marked
                connected, or if its peer address is empty (ss_len of 0 or family
                AF_UNSPEC); B_OK otherwise.
*/
int
socket_getpeername(net_socket* _socket, struct sockaddr* address,
        socklen_t* _addressLength)
{
        net_socket_private* socket = (net_socket_private*)_socket;
        BReference<net_socket_private> parent = socket->parent.GetReference();

        const bool hasConnection = parent.IsSet() || socket->is_connected;
        const bool hasPeerAddress = socket->peer.ss_len != 0
                && socket->peer.ss_family != AF_UNSPEC;
        if (!hasConnection || !hasPeerAddress)
                return ENOTCONN;

        size_t copyLength = min_c(*_addressLength, socket->peer.ss_len);
        memcpy(address, &socket->peer, copyLength);

        *_addressLength = socket->peer.ss_len;
        return B_OK;
}


/*!     Copies the local address of \a socket into \a address.

        If the socket has no address yet (address.ss_len is 0), a zeroed sockaddr
        that only carries the socket's family is returned instead. In both cases
        at most \c *_addressLength bytes are copied and \c *_addressLength is set
        to the full size of the address that was available.

        \return Always B_OK.
*/
int
socket_getsockname(net_socket* socket, struct sockaddr* address,
        socklen_t* _addressLength)
{
        if (socket->address.ss_len != 0) {
                size_t copyLength = min_c(*_addressLength, socket->address.ss_len);
                memcpy(address, &socket->address, copyLength);

                *_addressLength = socket->address.ss_len;
                return B_OK;
        }

        // no address yet: report an empty one of the socket's family
        struct sockaddr unbound;
        memset(&unbound, 0, sizeof(unbound));
        unbound.sa_family = socket->family;

        size_t copyLength = min_c(*_addressLength, sizeof(unbound));
        memcpy(address, &unbound, copyLength);

        *_addressLength = sizeof(unbound);
        return B_OK;
}


/*!     Retrieves a SOL_SOCKET level option from the state stored in \a socket.

        \param socket The socket to query.
        \param level Must be SOL_SOCKET.
        \param option The SO_* option to retrieve.
        \param value Buffer that receives the option value.
        \param _length On input the size of \a value; it is only checked for the
                struct valued options (SO_RCVTIMEO, SO_SNDTIMEO, SO_LINGER). On
                success it is set to the size of the value that was stored.
        \return B_OK on success, B_BAD_VALUE if the buffer is too small for a
                struct valued option, ENOPROTOOPT for any other level or an unknown
                option.
*/
status_t
socket_get_option(net_socket* socket, int level, int option, void* value,
        int* _length)
{
        if (level != SOL_SOCKET)
                return ENOPROTOOPT;

        switch (option) {
                case SO_SNDBUF:
                {
                        uint32* size = (uint32*)value;
                        *size = socket->send.buffer_size;
                        *_length = sizeof(uint32);
                        return B_OK;
                }

                case SO_RCVBUF:
                {
                        uint32* size = (uint32*)value;
                        *size = socket->receive.buffer_size;
                        *_length = sizeof(uint32);
                        return B_OK;
                }

                case SO_SNDLOWAT:
                {
                        uint32* size = (uint32*)value;
                        *size = socket->send.low_water_mark;
                        *_length = sizeof(uint32);
                        return B_OK;
                }

                case SO_RCVLOWAT:
                {
                        uint32* size = (uint32*)value;
                        *size = socket->receive.low_water_mark;
                        *_length = sizeof(uint32);
                        return B_OK;
                }

                case SO_RCVTIMEO:
                case SO_SNDTIMEO:
                {
                        if (*_length < (int)sizeof(struct timeval))
                                return B_BAD_VALUE;

                        bigtime_t timeout;
                        if (option == SO_SNDTIMEO)
                                timeout = socket->send.timeout;
                        else
                                timeout = socket->receive.timeout;

                        // an infinite timeout is reported as a zero timeval
                        if (timeout == B_INFINITE_TIMEOUT)
                                timeout = 0;

                        struct timeval* timeval = (struct timeval*)value;
                        timeval->tv_sec = timeout / 1000000LL;
                        timeval->tv_usec = timeout % 1000000LL;

                        *_length = sizeof(struct timeval);
                        return B_OK;
                }

                case SO_LINGER:
                {
                        // Counterpart of the SO_LINGER handling in socket_set_option():
                        // the on/off state lives in the options bit field, the linger
                        // time in socket->linger.
                        if (*_length < (int)sizeof(struct linger))
                                return B_BAD_VALUE;

                        struct linger* linger = (struct linger*)value;
                        linger->l_onoff = (socket->options & SO_LINGER) != 0;
                        linger->l_linger = socket->linger;

                        *_length = sizeof(struct linger);
                        return B_OK;
                }

                case SO_NONBLOCK:
                {
                        // non-blocking means that both timeouts are zero
                        int32* _set = (int32*)value;
                        *_set = socket->receive.timeout == 0 && socket->send.timeout == 0;
                        *_length = sizeof(int32);
                        return B_OK;
                }

                case SO_ACCEPTCONN:
                case SO_BROADCAST:
                case SO_DEBUG:
                case SO_DONTROUTE:
                case SO_KEEPALIVE:
                case SO_OOBINLINE:
                case SO_REUSEADDR:
                case SO_REUSEPORT:
                case SO_USELOOPBACK:
                {
                        // boolean options are stored as bits in socket->options
                        int32* _set = (int32*)value;
                        *_set = (socket->options & option) != 0;
                        *_length = sizeof(int32);
                        return B_OK;
                }

                case SO_TYPE:
                {
                        int32* _set = (int32*)value;
                        *_set = socket->type;
                        *_length = sizeof(int32);
                        return B_OK;
                }

                case SO_ERROR:
                {
                        int32* _set = (int32*)value;
                        *_set = socket->error;
                        *_length = sizeof(int32);

                        socket->error = B_OK;
                                // clear error upon retrieval
                        return B_OK;
                }

                default:
                        break;
        }

        dprintf("socket_getsockopt: unknown option %d\n", option);
        return ENOPROTOOPT;
}


/*!     Forwards a getsockopt() request to the getsockopt() hook of the module
        of the socket's first protocol and returns that hook's result.
*/
int
socket_getsockopt(net_socket* socket, int level, int option, void* value,
        int* _length)
{
        const int result = socket->first_protocol->module->getsockopt(
                socket->first_protocol, level, option, value, _length);
        return result;
}


/*!     Calls the first protocol's listen() hook with \a backlog. If the hook
        succeeds, SO_ACCEPTCONN is set in the socket's options.
        \return The result of the listen() hook.
*/
int
socket_listen(net_socket* socket, int backlog)
{
        status_t status = socket->first_info->listen(socket->first_protocol,
                backlog);
        if (status != B_OK)
                return status;

        // mark the socket as accepting connections
        socket->options |= SO_ACCEPTCONN;
        return B_OK;
}


/*!     Receives data from \a socket through the first protocol's read hooks.

        If the protocol provides read_data_no_buffer(), the request is handed to
        socket_receive_no_buffer() (with only MSG_NOSIGNAL removed from
        \a flags). Otherwise a net_buffer is obtained via read_data() and copied
        into \a data and, if \a header is given, into the iovecs following the
        first one.

        \param socket The socket to read from.
        \param header Optional message header. When non-NULL its msg_namelen and
                msg_flags are reset, msg_controllen is zeroed if no ancillary data
                is processed, and msg_name/msg_iov[1..] are filled in.
        \param data Target of the first copy; by the convention noted below it
                corresponds to the first iovec of \a header.
        \param length Size of \a data.
        \param flags MSG_* flags. MSG_NOSIGNAL is dropped; in the buffered path
                MSG_TRUNC is masked out before read_data() is called and handled
                here instead.
        \return The number of bytes copied; the full size of the received buffer
                if the data was truncated and MSG_TRUNC was passed; 0 if the
                protocol returned no buffer; or an error code.
*/
ssize_t
socket_receive(net_socket* socket, msghdr* header, void* data, size_t length,
        int flags)
{
        // kept to be able to test for MSG_TRUNC after it was masked from flags
        const int originalFlags = flags;

        // MSG_NOSIGNAL is only meaningful for send(), not receive(), but it is
        // sometimes specified anyway. Mask it off to avoid unnecessary errors.
        flags &= ~MSG_NOSIGNAL;

        // If the protocol sports read_data_no_buffer() we use it.
        if (socket->first_info->read_data_no_buffer != NULL)
                return socket_receive_no_buffer(socket, header, data, length, flags);

        // Mask off flags handled in this function.
        flags &= ~(MSG_TRUNC);

        // presumably the combined size of data/length and the header's iovecs
        // (helper defined elsewhere) -- TODO confirm
        size_t totalLength = calculate_total_length(header, data, length);

        net_buffer* buffer;
        status_t status = socket->first_info->read_data(
                socket->first_protocol, totalLength, flags, &buffer);
        if (status != B_OK)
                return status;

        // process ancillary data
        if (header != NULL) {
                if (buffer != NULL && header->msg_control != NULL) {
                        // use the buffer's ancillary data container if it has one,
                        // otherwise let the buffer itself be processed
                        ancillary_data_container* container
                                = gNetBufferModule.get_ancillary_data(buffer);
                        if (container != NULL)
                                status = process_ancillary_data(socket, container, header, flags);
                        else
                                status = process_ancillary_data(socket, buffer, header, flags);
                        if (status != B_OK) {
                                gNetBufferModule.free(buffer);
                                return status;
                        }
                } else
                        header->msg_controllen = 0;
        }

        // TODO: - returning a NULL buffer when received 0 bytes
        //         may not make much sense as we still need the address

        // remember the caller's name buffer size before resetting the header
        size_t nameLen = 0;
        if (header != NULL) {
                // TODO: - consider the control buffer options
                nameLen = header->msg_namelen;
                header->msg_namelen = 0;
                header->msg_flags = 0;
        }

        if (buffer == NULL)
                return 0;

        const size_t bytesReceived = buffer->size;
        size_t bytesCopied = 0;

        size_t toRead = min_c(bytesReceived, length);
        status = gNetBufferModule.read(buffer, 0, data, toRead);
        if (status != B_OK) {
                gNetBufferModule.free(buffer);

                // only B_BAD_ADDRESS is passed through, everything else is
                // reported as ENOBUFS
                if (status == B_BAD_ADDRESS)
                        return status;
                return ENOBUFS;
        }

        // if first copy was a success, proceed to following copies as required
        bytesCopied += toRead;

        if (header != NULL) {
                // We start at iovec[1] as { data, length } is iovec[0].
                for (int i = 1; i < header->msg_iovlen && bytesCopied < bytesReceived; i++) {
                        iovec& vec = header->msg_iov[i];
                        toRead = min_c(bytesReceived - bytesCopied, vec.iov_len);
                        if (gNetBufferModule.read(buffer, bytesCopied, vec.iov_base,
                                        toRead) < B_OK) {
                                // A failing copy only ends the loop; the shortfall is
                                // reported like a truncation below.
                                break;
                        }

                        bytesCopied += toRead;
                }

                // return the sender's address, truncated to the caller's buffer
                if (header->msg_name != NULL) {
                        header->msg_namelen = min_c(nameLen, buffer->source->sa_len);
                        memcpy(header->msg_name, buffer->source, header->msg_namelen);
                }
        }

        gNetBufferModule.free(buffer);

        if (bytesCopied < bytesReceived) {
                // not everything fit (or could be copied): flag the truncation
                if (header != NULL)
                        header->msg_flags = MSG_TRUNC;

                // with MSG_TRUNC the caller wants the real size of the data
                if ((originalFlags & MSG_TRUNC) != 0)
                        return bytesReceived;
        }

        return bytesCopied;
}


/*!     Sends data on \a socket.

        If the first protocol provides send_data_no_buffer(), the iovecs are
        handed to that hook directly. Otherwise the data is packed into
        net_buffers of at most socket->send.buffer_size bytes each, which are
        passed to the send_data() hook one after another.

        \param socket The socket to send on. If it has no local address yet, it
                is bound via socket_bind(socket, NULL, 0) first.
        \param header Optional message header supplying the destination address
                (msg_name/msg_namelen), ancillary data (msg_control) and further
                iovecs.
        \param data First chunk of data; by convention equal to the first iovec
                of \a header, if given.
        \param length Size of \a data; must not exceed SSIZE_MAX.
        \param flags MSG_* flags. MSG_NOSIGNAL suppresses the SIGPIPE sent on
                EPIPE and is removed before the flags reach the protocol.
        \return The number of bytes sent, or an error: B_BAD_VALUE (length too
                large, or a non-zero msg_namelen with a NULL msg_name), B_NO_MEMORY,
                EISCONN (address given on a socket with a peer), EDESTADDRREQ,
                EMSGSIZE, ENOBUFS, or what the protocol hooks returned.
*/
ssize_t
socket_send(net_socket* socket, msghdr* header, const void* data, size_t length,
        int flags)
{
        const bool nosignal = ((flags & MSG_NOSIGNAL) != 0);
        flags &= ~MSG_NOSIGNAL;

        size_t bytesLeft = length;
        if (length > SSIZE_MAX)
                return B_BAD_VALUE;

        // The deleter frees the container on every early return until it is
        // detached below.
        ancillary_data_container* ancillaryData = NULL;
        CObjectDeleter<
                ancillary_data_container, void, delete_ancillary_data_container>
                ancillaryDataDeleter;

        const sockaddr* address = NULL;
        socklen_t addressLength = 0;
        if (header != NULL) {
                address = (const sockaddr*)header->msg_name;
                addressLength = header->msg_namelen;

                // get the ancillary data
                if (header->msg_control != NULL) {
                        ancillaryData = create_ancillary_data_container();
                        if (ancillaryData == NULL)
                                return B_NO_MEMORY;
                        ancillaryDataDeleter.SetTo(ancillaryData);

                        status_t status = add_ancillary_data(socket, ancillaryData,
                                (cmsghdr*)header->msg_control, header->msg_controllen);
                        if (status != B_OK)
                                return status;
                }
        }

        // a zero length means "no address", a length without address is invalid
        if (addressLength == 0)
                address = NULL;
        else if (address == NULL)
                return B_BAD_VALUE;

        if (socket->peer.ss_len != 0) {
                if (address != NULL)
                        return EISCONN;

                // socket is connected, we use that address
                address = (struct sockaddr*)&socket->peer;
                addressLength = socket->peer.ss_len;
        }

        if (address == NULL || addressLength == 0) {
                // don't know where to send to:
                return EDESTADDRREQ;
        }

        // NOTE(review): at this point bytesLeft only holds \a length; the sizes
        // of the remaining iovecs are added further below, so this check does
        // not account for them -- confirm this is intended.
        bool atomic = (socket->first_info->flags & NET_PROTOCOL_ATOMIC_MESSAGES) != 0;
        if (atomic && bytesLeft > socket->send.buffer_size)
                return EMSGSIZE;

        if (socket->address.ss_len == 0) {
                // try to bind first
                status_t status = socket_bind(socket, NULL, 0);
                if (status != B_OK)
                        return status;
        }

        // If the protocol has a send_data_no_buffer() hook, we use that one.
        if (socket->first_info->send_data_no_buffer != NULL) {
                // without a header, wrap { data, length } into a single iovec
                iovec stackVec = { (void*)data, length };
                iovec* vecs = header ? header->msg_iov : &stackVec;
                int vecCount = header ? header->msg_iovlen : 1;

                ssize_t written = socket->first_info->send_data_no_buffer(
                        socket->first_protocol, vecs, vecCount, ancillaryData, address,
                        addressLength, flags);

                // we only send signals when called from userland
                if (written == EPIPE && is_syscall() && !nosignal)
                        send_signal(find_thread(NULL), SIGPIPE);

                // the container is only given up if something was written
                if (written > 0)
                        ancillaryDataDeleter.Detach();
                return written;
        }

        // By convention, if a header is given, the (data, length) equals the first
        // iovec. So drop the header, if it is the only iovec. Otherwise compute
        // the size of the remaining ones.
        if (header != NULL) {
                if (header->msg_iovlen <= 1) {
                        header = NULL;
                } else {
                        for (int i = 1; i < header->msg_iovlen; i++)
                                bytesLeft += header->msg_iov[i].iov_len;
                }
        }

        ssize_t bytesSent = 0;
        size_t vecOffset = 0;
                // non-zero while the current iovec has only been sent partially
        uint32 vecIndex = 0;

        // For atomic protocols the loop is entered once even if there is no data
        // to send; "atomic" is cleared after the first successful send.
        while (bytesLeft > 0 || atomic) {
                // TODO: useful, maybe even computed header space!
                net_buffer* buffer = gNetBufferModule.create(256);
                if (buffer == NULL)
                        return ENOBUFS;

                // fill the buffer until it reaches the send buffer size or all
                // remaining data has been appended
                while (buffer->size < socket->send.buffer_size
                        && buffer->size < bytesLeft) {
                        if (vecIndex > 0 && vecOffset == 0) {
                                // retrieve next iovec buffer from header
                                data = header->msg_iov[vecIndex].iov_base;
                                length = header->msg_iov[vecIndex].iov_len;
                        }

                        // clamp to the space left in this buffer
                        size_t bytes = length;
                        if (buffer->size + bytes > socket->send.buffer_size)
                                bytes = socket->send.buffer_size - buffer->size;

                        if (gNetBufferModule.append(buffer, data, bytes) < B_OK) {
                                gNetBufferModule.free(buffer);
                                return ENOBUFS;
                        }

                        if (bytes != length) {
                                // partial send
                                vecOffset = bytes;
                                length -= vecOffset;
                                data = (uint8*)data + vecOffset;
                        } else if (header != NULL) {
                                // proceed with next buffer, if any
                                vecOffset = 0;
                                vecIndex++;

                                if (vecIndex >= (uint32)header->msg_iovlen)
                                        break;
                        }
                }

                // attach ancillary data to the first buffer
                status_t status;
                if (ancillaryData != NULL) {
                        gNetBufferModule.set_ancillary_data(buffer, ancillaryData);
                        ancillaryDataDeleter.Detach();
                        ancillaryData = NULL;
                }

                // remember the size, so that a partial send can be detected by
                // comparing it with the buffer's size after send_data() failed
                size_t bufferSize = buffer->size;
                buffer->msg_flags = flags;
                memcpy(buffer->source, &socket->address, socket->address.ss_len);
                memcpy(buffer->destination, address, addressLength);
                buffer->destination->sa_len = addressLength;

                status = socket->first_info->send_data(socket->first_protocol, buffer);
                if (status != B_OK) {
                        // we only send signals when called from userland
                        if (status == EPIPE && is_syscall() && !nosignal)
                                send_signal(find_thread(NULL), SIGPIPE);

                        // on failure the buffer is freed here
                        size_t sizeAfterSend = buffer->size;
                        gNetBufferModule.free(buffer);

                        if ((sizeAfterSend != bufferSize || bytesSent > 0)
                                && (status == B_INTERRUPTED || status == B_WOULD_BLOCK)) {
                                // this appears to be a partial write
                                return bytesSent + (bufferSize - sizeAfterSend);
                        }
                        return status;
                }

                bytesLeft -= bufferSize;
                bytesSent += bufferSize;
                atomic = false;
        }

        return bytesSent;
}


/*!     Applies a SOL_SOCKET level option to the state stored in \a socket.

        \param socket The socket to modify.
        \param level Must be SOL_SOCKET.
        \param option The SO_* option to set.
        \param value Points to the new option value.
        \param length Size of \a value; has to match the option's type exactly
                (SO_LINGER accepts anything that is at least a struct linger).
        \return B_OK on success, B_BAD_VALUE on a size mismatch, ENOPROTOOPT for
                any other level or an unknown option.
*/
status_t
socket_set_option(net_socket* socket, int level, int option, const void* value,
        int length)
{
        if (level != SOL_SOCKET)
                return ENOPROTOOPT;

        TRACE("%s(socket %p, option %d\n", __FUNCTION__, socket, option);

        switch (option) {
                // TODO: implement other options!
                case SO_LINGER:
                {
                        if (length < (int)sizeof(struct linger))
                                return B_BAD_VALUE;

                        const struct linger* request = (const struct linger*)value;
                        const bool enable = request->l_onoff != 0;

                        // the linger time is only kept while lingering is enabled
                        if (enable) {
                                socket->options |= SO_LINGER;
                                socket->linger = request->l_linger;
                        } else {
                                socket->options &= ~SO_LINGER;
                                socket->linger = 0;
                        }
                        return B_OK;
                }

                case SO_SNDBUF:
                case SO_RCVBUF:
                case SO_SNDLOWAT:
                case SO_RCVLOWAT:
                {
                        // all of these take a plain uint32
                        if (length != sizeof(uint32))
                                return B_BAD_VALUE;

                        const uint32 size = *(const uint32*)value;
                        if (option == SO_SNDBUF)
                                socket->send.buffer_size = size;
                        else if (option == SO_RCVBUF)
                                socket->receive.buffer_size = size;
                        else if (option == SO_SNDLOWAT)
                                socket->send.low_water_mark = size;
                        else
                                socket->receive.low_water_mark = size;
                        return B_OK;
                }

                case SO_RCVTIMEO:
                case SO_SNDTIMEO:
                {
                        if (length != sizeof(struct timeval))
                                return B_BAD_VALUE;

                        const struct timeval* request = (const struct timeval*)value;
                        bigtime_t timeout = request->tv_sec * 1000000LL + request->tv_usec;

                        // a zero timeval means to wait forever
                        if (timeout == 0)
                                timeout = B_INFINITE_TIMEOUT;

                        if (option == SO_SNDTIMEO)
                                socket->send.timeout = timeout;
                        else
                                socket->receive.timeout = timeout;
                        return B_OK;
                }

                case SO_NONBLOCK:
                {
                        if (length != sizeof(int32))
                                return B_BAD_VALUE;

                        // non-blocking is expressed by zero timeouts in both directions
                        const bool nonBlocking = *(const int32*)value != 0;
                        if (nonBlocking) {
                                socket->send.timeout = 0;
                                socket->receive.timeout = 0;
                        } else {
                                socket->send.timeout = B_INFINITE_TIMEOUT;
                                socket->receive.timeout = B_INFINITE_TIMEOUT;
                        }
                        return B_OK;
                }

                case SO_BROADCAST:
                case SO_DEBUG:
                case SO_DONTROUTE:
                case SO_KEEPALIVE:
                case SO_OOBINLINE:
                case SO_REUSEADDR:
                case SO_REUSEPORT:
                case SO_USELOOPBACK:
                {
                        if (length != sizeof(int32))
                                return B_BAD_VALUE;

                        // boolean options are stored as bits in socket->options
                        const bool enable = *(const int32*)value != 0;
                        if (enable)
                                socket->options |= option;
                        else
                                socket->options &= ~option;
                        return B_OK;
                }

                case SO_BINDTODEVICE:
                {
                        if (length != sizeof(uint32))
                                return B_BAD_VALUE;

                        // TODO: we might want to check if the device exists at all
                        // (although it doesn't really harm when we don't)
                        socket->bound_to_device = *(const uint32*)value;
                        return B_OK;
                }

                default:
                        break;
        }

        dprintf("socket_setsockopt: unknown option %d\n", option);
        return ENOPROTOOPT;
}


/*!     Forwards a setsockopt() request to the setsockopt() hook of the module
        of the socket's first protocol and returns that hook's result.
*/
int
socket_setsockopt(net_socket* socket, int level, int option, const void* value,
        int length)
{
        const int result = socket->first_protocol->module->setsockopt(
                socket->first_protocol, level, option, value, length);
        return result;
}


/*!     Passes the shutdown request for \a direction on to the first protocol's
        shutdown() hook and returns that hook's result.
*/
int
socket_shutdown(net_socket* socket, int direction)
{
        const int result = socket->first_info->shutdown(socket->first_protocol,
                direction);
        return result;
}


/*!     Creates two sockets that are connected to each other.

        The first socket is bound to an empty address and the second one is
        connected to it. For SOCK_STREAM and SOCK_SEQPACKET the first socket
        listens and is replaced by the socket it accepts; for all other types
        the first socket is connected back to the second one instead.

        \param sockets Receives the two sockets; both entries are NULL after a
                failure that occurred past the creation of the first socket.
        \return B_OK on success, the error of the failing step otherwise.
*/
status_t
socket_socketpair(int family, int type, int protocol, net_socket* sockets[2])
{
        const bool needsAccept = type == SOCK_STREAM || type == SOCK_SEQPACKET;

        sockets[0] = NULL;
        sockets[1] = NULL;

        // create sockets
        status_t error = socket_open(family, type, protocol, &sockets[0]);
        if (error != B_OK)
                return error;

        error = socket_open(family, type, protocol, &sockets[1]);

        // bind one
        if (error == B_OK)
                error = socket_bind(sockets[0], NULL, 0);

        // start listening
        if (error == B_OK && needsAccept)
                error = socket_listen(sockets[0], 1);

        // connect them
        if (error == B_OK) {
                error = socket_connect(sockets[1], (sockaddr*)&sockets[0]->address,
                        sockets[0]->address.ss_len);
        }

        if (error == B_OK && needsAccept) {
                // accept a socket
                net_socket* accepted = NULL;
                error = socket_accept(sockets[0], NULL, NULL, &accepted);
                if (error == B_OK) {
                        // everything worked: the listener is no longer needed
                        socket_close(sockets[0]);
                        socket_free(sockets[0]);
                        sockets[0] = accepted;
                }
        } else if (error == B_OK) {
                // connect the other side
                error = socket_connect(sockets[0], (sockaddr*)&sockets[1]->address,
                        sockets[1]->address.ss_len);
        }

        if (error == B_OK)
                return B_OK;

        // close sockets on error
        for (int index = 0; index < 2; index++) {
                if (sockets[index] == NULL)
                        continue;

                socket_close(sockets[index]);
                socket_free(sockets[index]);
                sockets[index] = NULL;
        }

        return error;
}


//      #pragma mark -


/*!     Standard module operations of the socket module.

        B_MODULE_INIT constructs the global socket list in place and initializes
        its lock; B_MODULE_UNINIT asserts that the list is empty and destroys the
        lock. With ENABLE_DEBUGGER_COMMANDS set, the "sockets" and "socket"
        debugger commands are added and removed alongside.

        \return B_OK for the two operations above, B_ERROR for any other \a op.
*/
static status_t
socket_std_ops(int32 op, ...)
{
        if (op == B_MODULE_INIT) {
                new (&sSocketList) SocketList;
                mutex_init(&sSocketLock, "socket list");

#if ENABLE_DEBUGGER_COMMANDS
                add_debugger_command("sockets", dump_sockets, "lists all sockets");
                add_debugger_command("socket", dump_socket, "dumps a socket");
#endif
                return B_OK;
        }

        if (op == B_MODULE_UNINIT) {
                ASSERT(sSocketList.IsEmpty());
                mutex_destroy(&sSocketLock);

#if ENABLE_DEBUGGER_COMMANDS
                remove_debugger_command("socket", dump_socket);
                remove_debugger_command("sockets", dump_sockets);
#endif
                return B_OK;
        }

        return B_ERROR;
}


/*!     The function table of the socket module.

        The initializer is positional: every entry has to stay in the slot that
        matches the member order of net_socket_module_info, so hooks must not be
        reordered here without changing that structure accordingly.
*/
net_socket_module_info gNetSocketModule = {
        {
                // embedded module header; socket_std_ops() handles
                // B_MODULE_INIT and B_MODULE_UNINIT
                NET_SOCKET_MODULE_NAME,
                0,
                socket_std_ops
        },
        socket_open,
        socket_close,
        socket_free,

        socket_control,

        socket_read_avail,
        socket_send_avail,

        socket_send_data,
        socket_receive_data,

        // SOL_SOCKET level option handlers (any other level is rejected with
        // ENOPROTOOPT)
        socket_get_option,
        socket_set_option,

        socket_get_next_stat,

        // connections
        socket_acquire,
        socket_release,
        socket_spawn_pending,
        socket_dequeue_connected,
        socket_count_connected,
        socket_set_max_backlog,
        socket_has_parent,
        socket_connected,
        socket_aborted,

        // notifications
        socket_request_notification,
        socket_cancel_notification,
        socket_notify,

        // standard socket API
        socket_accept,
        socket_bind,
        socket_connect,
        socket_getpeername,
        socket_getsockname,
        socket_getsockopt,
        socket_listen,
        socket_receive,
        socket_send,
        socket_setsockopt,
        socket_shutdown,
        socket_socketpair
};