root/src/add-ons/kernel/network/protocols/unix/unix.cpp
/*
 * Copyright 2008, Ingo Weinhold, ingo_weinhold@gmx.de.
 * Distributed under the terms of the MIT License.
 */


#include <stdio.h>
#include <sys/un.h>

#include <new>

#include <AutoDeleter.h>
#include <StackOrHeapArray.h>

#include <fs/fd.h>
#include <lock.h>
#include <util/AutoLock.h>
#include <vfs.h>

#include <net_buffer.h>
#include <net_protocol.h>
#include <net_socket.h>
#include <net_stack.h>

#include "unix.h"
#include "UnixAddressManager.h"
#include "UnixEndpoint.h"


#define UNIX_MODULE_DEBUG_LEVEL 0
#define UNIX_DEBUG_LEVEL                UNIX_MODULE_DEBUG_LEVEL
#include "UnixDebug.h"


// The protocol module descriptor is defined at the end of this file; it is
// declared here so init_unix() can pass it to register_domain().
extern net_protocol_module_info gUnixModule;
        // extern only for forwarding

// Interfaces of the modules this protocol depends on. They are never assigned
// in this file; they are filled in through the module_dependencies table.
net_stack_module_info *gStackModule;
net_socket_module_info *gSocketModule;
net_buffer_module_info *gBufferModule;
// Explicitly constructed (placement new) in init_unix() and explicitly
// destroyed in uninit_unix().
UnixAddressManager gAddressManager;

// AF_UNIX domain handle returned by register_domain() in init_unix();
// handed out by unix_get_domain() and released in uninit_unix().
static struct net_domain *sDomain;


void
destroy_scm_rights_descriptors(const ancillary_data_header* header,
        void* data)
{
        int count = header->len / sizeof(file_descriptor*);
        file_descriptor** descriptors = (file_descriptor**)data;
        io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());

        for (int i = 0; i < count; i++) {
                if (descriptors[i] != NULL) {
                        close_fd(ioContext, descriptors[i]);
                        put_fd(descriptors[i]);
                }
        }
}


void
clone_scm_rights_descriptors(const ancillary_data_header* header, void* data)
{
        int count = header->len / sizeof(file_descriptor*);
        file_descriptor** descriptors = (file_descriptor**)data;

        for (int i = 0; i < count; i++) {
                if (descriptors[i] != NULL) {
                        inc_fd_ref_count(descriptors[i]);
                        inc_fd_open_count(descriptors[i]);
                }
        }
}


// #pragma mark -


/*!     Creates and initializes the UnixEndpoint backing \a socket.
        Returns the endpoint (as its net_protocol base), or NULL if either
        creation or Init() fails; a partially set up endpoint is deleted.
*/
net_protocol *
unix_init_protocol(net_socket *socket)
{
        TRACE("[%" B_PRId32 "] unix_init_protocol(%p)\n", find_thread(NULL),
                socket);

        UnixEndpoint* endpoint = NULL;
        if (UnixEndpoint::Create(socket, &endpoint) != B_OK)
                return NULL;

        if (endpoint->Init() == B_OK)
                return endpoint;

        delete endpoint;
        return NULL;
}


/*!     Tears down the endpoint behind \a _protocol via UnixEndpoint::Uninit().
        Always reports success.
*/
status_t
unix_uninit_protocol(net_protocol *_protocol)
{
        TRACE("[%" B_PRId32 "] unix_uninit_protocol(%p)\n", find_thread(NULL),
                _protocol);

        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(_protocol);
        endpoint->Uninit();

        return B_OK;
}


/*!     Forwards the open request to the endpoint's Open(). */
status_t
unix_open(net_protocol *_protocol)
{
        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(_protocol);
        return endpoint->Open();
}


/*!     Forwards the close request to the endpoint's Close(). */
status_t
unix_close(net_protocol *_protocol)
{
        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(_protocol);
        return endpoint->Close();
}


/*!     Forwards the free request to the endpoint's Free(). */
status_t
unix_free(net_protocol *_protocol)
{
        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(_protocol);
        return endpoint->Free();
}


/*!     Connects the endpoint to the peer named by \a address. */
status_t
unix_connect(net_protocol *_protocol, const struct sockaddr *address)
{
        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(_protocol);
        return endpoint->Connect(address);
}


/*!     Accepts a pending connection; the new socket is returned through
        \a _acceptedSocket by UnixEndpoint::Accept().
*/
status_t
unix_accept(net_protocol *_protocol, struct net_socket **_acceptedSocket)
{
        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(_protocol);
        return endpoint->Accept(_acceptedSocket);
}


/*!     No protocol-level control operations are implemented; every request
        is rejected with B_BAD_VALUE.
*/
status_t
unix_control(net_protocol *protocol, int level, int option, void *value,
        size_t *_length)
{
        (void)protocol;
        (void)level;
        (void)option;
        (void)value;
        (void)_length;

        return B_BAD_VALUE;
}


/*!     Handles SOL_SOCKET/SO_PEERCRED locally by querying the endpoint's peer
        credentials. All other options are delegated to the generic socket
        module.

        For SO_PEERCRED, \a *_length must be at least sizeof(ucred); it is set
        to exactly sizeof(ucred) before the credentials are fetched.
*/
status_t
unix_getsockopt(net_protocol *protocol, int level, int option, void *value,
        int *_length)
{
        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(protocol);

        const bool wantsPeerCredentials
                = level == SOL_SOCKET && option == SO_PEERCRED;
        if (!wantsPeerCredentials) {
                return gSocketModule->get_option(protocol->socket, level, option,
                        value, _length);
        }

        if ((int)sizeof(ucred) > *_length)
                return B_BAD_VALUE;

        *_length = sizeof(ucred);
        return endpoint->GetPeerCredentials(static_cast<ucred*>(value));
}


/*!     Sets a socket option.

        SOL_SOCKET/SO_RCVBUF is applied to the endpoint's receive buffer first;
        its value must be exactly one int. In every case where no error has
        occurred, the option is then also passed on to the generic socket
        module, so it is recorded there as well.
        \return B_BAD_VALUE for a malformed SO_RCVBUF value, the error from
                SetReceiveBufferSize(), or the result of the socket module's
                set_option().
*/
status_t
unix_setsockopt(net_protocol *protocol, int level, int option,
        const void *_value, int length)
{
        UnixEndpoint* endpoint = (UnixEndpoint*)protocol;

        if (level == SOL_SOCKET) {
                if (option == SO_RCVBUF) {
                        if (length != (int)sizeof(int))
                                return B_BAD_VALUE;

                        // _value is const input; read it without casting away const
                        status_t error = endpoint->SetReceiveBufferSize(
                                *(const int*)_value);
                        if (error != B_OK)
                                return error;
                } else if (option == SO_SNDBUF) {
                        // We don't have a send buffer, so there is nothing to adjust
                        // at the protocol level; the value is only recorded by the
                        // socket module below.
                }
        }

        return gSocketModule->set_option(protocol->socket, level, option,
                _value, length);
}


/*!     Binds the endpoint to \a _address. */
status_t
unix_bind(net_protocol *_protocol, const struct sockaddr *_address)
{
        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(_protocol);
        return endpoint->Bind(_address);
}


/*!     Unbinds the endpoint. \a _address is not consulted; the endpoint
        releases whatever address it is currently bound to.
*/
status_t
unix_unbind(net_protocol *_protocol, struct sockaddr *_address)
{
        (void)_address;

        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(_protocol);
        return endpoint->Unbind();
}


/*!     Puts the endpoint into listening state with a backlog of \a count. */
status_t
unix_listen(net_protocol *_protocol, int count)
{
        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(_protocol);
        return endpoint->Listen(count);
}


/*!     Shuts down the endpoint in the given \a direction. */
status_t
unix_shutdown(net_protocol *_protocol, int direction)
{
        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(_protocol);
        return endpoint->Shutdown(direction);
}


/*!     Routed, buffer-based sending is not supported; always fails with
        B_ERROR.
*/
status_t
unix_send_routed_data(net_protocol *_protocol, struct net_route *route,
        net_buffer *buffer)
{
        (void)_protocol;
        (void)route;
        (void)buffer;

        return B_ERROR;
}


/*!     Buffer-based sending is not supported (see unix_send_data_no_buffer());
        always fails with B_ERROR.
*/
status_t
unix_send_data(net_protocol *_protocol, net_buffer *buffer)
{
        (void)_protocol;
        (void)buffer;

        return B_ERROR;
}


/*!     Reports how much data can currently be sent, as computed by
        UnixEndpoint::Sendable().
*/
ssize_t
unix_send_avail(net_protocol *_protocol)
{
        return static_cast<UnixEndpoint*>(_protocol)->Sendable();
}


/*!     Buffer-based reading is not supported (see unix_read_data_no_buffer());
        always fails with B_ERROR.
*/
status_t
unix_read_data(net_protocol *_protocol, size_t numBytes, uint32 flags,
        net_buffer **_buffer)
{
        (void)_protocol;
        (void)numBytes;
        (void)flags;
        (void)_buffer;

        return B_ERROR;
}


/*!     Reports how much data can currently be received, as computed by
        UnixEndpoint::Receivable().
*/
ssize_t
unix_read_avail(net_protocol *_protocol)
{
        return static_cast<UnixEndpoint*>(_protocol)->Receivable();
}


/*!     Returns the AF_UNIX domain registered in init_unix(); the same for
        every endpoint.
*/
struct net_domain *
unix_get_domain(net_protocol *protocol)
{
        (void)protocol;

        return sDomain;
}


/*!     Returns the fixed UNIX_MAX_TRANSFER_UNIT, independent of endpoint and
        destination address.
*/
size_t
unix_get_mtu(net_protocol *protocol, const struct sockaddr *address)
{
        (void)protocol;
        (void)address;

        return UNIX_MAX_TRANSFER_UNIT;
}


/*!     Receiving net_buffers from a lower layer is not supported; always
        fails with B_ERROR.
*/
status_t
unix_receive_data(net_buffer *buffer)
{
        (void)buffer;

        return B_ERROR;
}


/*!     Delivering net_buffers to an endpoint is not supported; always fails
        with B_ERROR.
*/
status_t
unix_deliver_data(net_protocol *_protocol, net_buffer *buffer)
{
        (void)_protocol;
        (void)buffer;

        return B_ERROR;
}


/*!     Error notifications from lower layers are not handled; always fails
        with B_ERROR.
*/
status_t
unix_error_received(net_error error, net_error_data* errorData, net_buffer *data)
{
        (void)error;
        (void)errorData;
        (void)data;

        return B_ERROR;
}


/*!     Error replies are not generated by this protocol; always fails with
        B_ERROR.
*/
status_t
unix_error_reply(net_protocol *protocol, net_buffer *cause, net_error error,
        net_error_data *errorData)
{
        (void)protocol;
        (void)cause;
        (void)error;
        (void)errorData;

        return B_ERROR;
}


/*!     Translates a sender-supplied SCM_RIGHTS control message into ancillary
        data attached to \a container.

        Each FD listed in \a header is resolved in the current I/O context with
        get_open_fd(). The resulting file_descriptor pointers are stored as one
        SCM_RIGHTS entry whose destroy/clone callbacks are
        destroy_scm_rights_descriptors() and clone_scm_rights_descriptors().

        \return B_OK on success; B_BAD_VALUE for a non-SCM_RIGHTS message or a
                malformed/empty one; ENOBUFS if the descriptor array cannot be
                allocated; EBADF if an FD is not open; or the error from
                add_ancillary_data(). On failure, all descriptor references acquired
                here are released again.
*/
status_t
unix_add_ancillary_data(net_protocol *self, ancillary_data_container *container,
        const cmsghdr *header)
{
        TRACE("[%" B_PRId32 "] unix_add_ancillary_data(%p, %p, %p (level: %d, type: %d, "
                "len: %" B_PRId32 "))\n", find_thread(NULL), self, container, header,
                header->cmsg_level, header->cmsg_type, header->cmsg_len);

        // we support only SCM_RIGHTS
        if (header->cmsg_level != SOL_SOCKET || header->cmsg_type != SCM_RIGHTS)
                return B_BAD_VALUE;

        // cmsg_len must at least cover the cmsghdr itself; otherwise the unsigned
        // subtraction below would wrap around into a bogus descriptor count.
        if (header->cmsg_len < CMSG_LEN(0))
                return B_BAD_VALUE;

        int* fds = (int*)CMSG_DATA(header);
        int count = (header->cmsg_len - CMSG_LEN(0)) / sizeof(int);
        if (count <= 0)
                return B_BAD_VALUE;

        BStackOrHeapArray<file_descriptor*, 8> descriptors(count);
        if (!descriptors.IsValid())
                return ENOBUFS;
        // NULL marks slots that hold no references, so cleanup can skip them.
        memset(descriptors, 0, sizeof(file_descriptor*) * count);

        // Get the file descriptors. Each successful get_open_fd() leaves us with
        // a reference and an open reference, released via close_fd()/put_fd().
        io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());

        status_t error = B_OK;
        for (int i = 0; i < count; i++) {
                descriptors[i] = get_open_fd(ioContext, fds[i]);
                if (descriptors[i] == NULL) {
                        error = EBADF;
                        break;
                }
        }

        // attach the ancillary data to the container
        if (error == B_OK) {
                // Named distinctly so it does not shadow the cmsghdr parameter.
                ancillary_data_header dataHeader;
                dataHeader.level = SOL_SOCKET;
                dataHeader.type = SCM_RIGHTS;
                dataHeader.len = count * sizeof(file_descriptor*);

                TRACE("[%" B_PRId32 "] unix_add_ancillary_data(): adding %d FDs to "
                        "container\n", find_thread(NULL), count);

                error = gStackModule->add_ancillary_data(container, &dataHeader,
                        descriptors, destroy_scm_rights_descriptors,
                        clone_scm_rights_descriptors, NULL);
        }

        // cleanup on error: drop the references acquired above
        if (error != B_OK) {
                for (int i = 0; i < count; i++) {
                        if (descriptors[i] != NULL) {
                                close_fd(ioContext, descriptors[i]);
                                put_fd(descriptors[i]);
                        }
                }
        }

        return error;
}


/*!     Converts the SCM_RIGHTS ancillary data held in \a container into one
        control message written to \a buffer, installing every transferred
        descriptor as a new FD in the current I/O context.

        All SCM_RIGHTS entries of the container are merged into a single cmsghdr.
        MSG_CMSG_CLOEXEC / MSG_CMSG_CLOFORK in \a flags mark the new FDs
        close-on-exec / close-on-fork.

        \return The number of bytes of \a buffer used (CMSG_SPACE of the FD
                array) on success; B_BAD_VALUE if the container holds anything other
                than SCM_RIGHTS data or \a buffer is too small; or the error returned
                by new_fd(). If new_fd() fails, the FDs already installed by this
                call are closed again.
*/
ssize_t
unix_process_ancillary_data(net_protocol *self,
        const ancillary_data_container *container, void *buffer,
        size_t bufferSize, int flags)
{
        // One conversion per argument: thread, self, container, buffer, size.
        TRACE("[%" B_PRId32 "] unix_process_ancillary_data(%p, %p, %p, %"
                B_PRIuSIZE ")\n", find_thread(NULL), self, container, buffer,
                bufferSize);

        // First pass: validate the entry types and count all descriptors.
        int totalCount = 0;

        ancillary_data_header header;
        void* data = NULL;
        while ((data = gStackModule->next_ancillary_data(container, data, &header)) != NULL) {
                // we support only SCM_RIGHTS
                if (header.level != SOL_SOCKET || header.type != SCM_RIGHTS)
                        return B_BAD_VALUE;

                totalCount += header.len / sizeof(file_descriptor*);
        }

        // check if there's enough space in the buffer (no partial delivery: a
        // short buffer fails the whole call)
        size_t neededBufferSpace = CMSG_SPACE(sizeof(int) * totalCount);
        if (bufferSize < neededBufferSpace)
                return B_BAD_VALUE;

        // init header
        cmsghdr* messageHeader = (cmsghdr*)buffer;
        messageHeader->cmsg_level = SOL_SOCKET;
        messageHeader->cmsg_type = SCM_RIGHTS;
        messageHeader->cmsg_len = CMSG_LEN(sizeof(int) * totalCount);

        // Second pass: create FDs for the current process. 'i' indexes the merged
        // FD array across all entries.
        int* fds = (int*)CMSG_DATA(messageHeader);
        io_context* ioContext = get_current_io_context(!gStackModule->is_syscall());

        status_t error = B_OK;
        int i = 0;
        data = NULL;
        while ((data = gStackModule->next_ancillary_data(container, data, &header)) != NULL) {
                int count = header.len / sizeof(file_descriptor*);
                file_descriptor** descriptors = (file_descriptor**)data;

                for (int k = 0; k < count; k++, i++) {
                        // Get an additional reference which will go to the FD table index. The
                        // reference and open reference acquired in unix_add_ancillary_data()
                        // will be released when the container is destroyed.
                        inc_fd_ref_count(descriptors[k]);
                        fds[i] = new_fd(ioContext, descriptors[k]);

                        if (fds[i] < 0) {
                                error = fds[i];
                                // drop the reference taken just above
                                put_fd(descriptors[k]);

                                // close FD indices installed so far by this call
                                for (int j = i - 1; j >= 0; j--)
                                        close_fd_index(ioContext, fds[j]);
                                break;
                        }

                        // The flag setters are called with the context's write lock held;
                        // the locker releases it at the end of each iteration.
                        // NOTE(review): the FD is already visible between new_fd() and
                        // these calls; confirm that window is acceptable.
                        WriteLocker locker(ioContext->lock);
                        if ((flags & MSG_CMSG_CLOEXEC) != 0)
                                fd_set_close_on_exec(ioContext, fds[i], true);
                        if ((flags & MSG_CMSG_CLOFORK) != 0)
                                fd_set_close_on_fork(ioContext, fds[i], true);
                }
                if (error != B_OK)
                        break;
        }

        return error == B_OK ? neededBufferSpace : error;
}


/*!     Sends the I/O vectors (plus optional ancillary data) directly through
        UnixEndpoint::Send(), without building a net_buffer.
*/
ssize_t
unix_send_data_no_buffer(net_protocol *_protocol, const iovec *vecs,
        size_t vecCount, ancillary_data_container *ancillaryData,
        const struct sockaddr *address, socklen_t addressLength, int flags)
{
        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(_protocol);
        return endpoint->Send(vecs, vecCount, ancillaryData, address,
                addressLength, flags);
}


/*!     Receives directly into the I/O vectors through UnixEndpoint::Receive();
        any ancillary data and the sender address are returned through the
        out parameters.
*/
ssize_t
unix_read_data_no_buffer(net_protocol *_protocol, const iovec *vecs,
        size_t vecCount, ancillary_data_container **_ancillaryData,
        struct sockaddr *_address, socklen_t *_addressLength, int flags)
{
        UnixEndpoint* endpoint = static_cast<UnixEndpoint*>(_protocol);
        return endpoint->Receive(vecs, vecCount, _ancillaryData, _address,
                _addressLength, flags);
}


// #pragma mark -


/*!     Module initialization, in four steps:
        1. Constructs the global address manager and calls its Init().
        2. Registers the "network/protocols/unix/v1" protocol for AF_UNIX
           SOCK_STREAM, SOCK_DGRAM and SOCK_SEQPACKET.
        3. Registers the AF_UNIX domain itself.
        4. Stores the domain handle in sDomain.

        If any registration fails, the address manager is destroyed and the
        error is returned.
        NOTE(review): registrations that already succeeded are not undone on a
        later failure.
*/
status_t
init_unix()
{
        new(&gAddressManager) UnixAddressManager;
        status_t error = gAddressManager.Init();
        if (error != B_OK)
                return error;

        static const int kSocketTypes[] = {
                SOCK_STREAM,
                SOCK_DGRAM,
                SOCK_SEQPACKET
        };
        const size_t typeCount = sizeof(kSocketTypes) / sizeof(kSocketTypes[0]);

        // Register the protocol for each socket type, stopping at the first error.
        for (size_t index = 0; error == B_OK && index < typeCount; index++) {
                error = gStackModule->register_domain_protocols(AF_UNIX,
                        kSocketTypes[index], 0, "network/protocols/unix/v1", NULL);
        }

        if (error == B_OK) {
                error = gStackModule->register_domain(AF_UNIX, "unix", &gUnixModule,
                        &gAddressModule, &sDomain);
        }

        if (error != B_OK)
                gAddressManager.~UnixAddressManager();

        return error;
}


/*!     Module teardown, in reverse order of init_unix(): unregisters the
        AF_UNIX domain, then explicitly destroys the global address manager.
        Always reports success.
*/
status_t
uninit_unix()
{
        net_domain* domain = sDomain;
        gStackModule->unregister_domain(domain);

        // gAddressManager was placement-constructed in init_unix(), so it is
        // destroyed explicitly here as well.
        gAddressManager.~UnixAddressManager();

        return B_OK;
}


/*!     Standard module hook: B_MODULE_INIT runs init_unix(), B_MODULE_UNINIT
        runs uninit_unix(), and any other operation fails with B_ERROR.
*/
static status_t
unix_std_ops(int32 op, ...)
{
        if (op == B_MODULE_INIT)
                return init_unix();

        if (op == B_MODULE_UNINIT)
                return uninit_unix();

        return B_ERROR;
}


// Protocol module descriptor for AF_UNIX sockets. The initializer is
// positional: the hooks must stay in the order of the corresponding fields of
// net_protocol_module_info.
net_protocol_module_info gUnixModule = {
        {
                // module name (the same string passed to register_domain_protocols()
                // in init_unix()), module flags, std_ops hook
                "network/protocols/unix/v1",
                0,
                unix_std_ops
        },
        // protocol flags: none set (atomic-message semantics not advertised)
        0,      // NET_PROTOCOL_ATOMIC_MESSAGES,

        unix_init_protocol,
        unix_uninit_protocol,
        unix_open,
        unix_close,
        unix_free,
        unix_connect,
        unix_accept,
        unix_control,
        unix_getsockopt,
        unix_setsockopt,
        unix_bind,
        unix_unbind,
        unix_listen,
        unix_shutdown,
        unix_send_data,
        unix_send_routed_data,
        unix_send_avail,
        unix_read_data,
        unix_read_avail,
        unix_get_domain,
        unix_get_mtu,
        unix_receive_data,
        unix_deliver_data,
        unix_error_received,
        unix_error_reply,
        unix_add_ancillary_data,
        unix_process_ancillary_data,
        // NOTE(review): unset hook; presumably attach_ancillary_data, confirm
        // against net_protocol.h
        NULL,
        unix_send_data_no_buffer,
        unix_read_data_no_buffer
};

// Modules this module depends on. The referenced globals are never assigned
// in this file; each one is filled in through its entry here. The list is
// terminated by an empty entry.
module_dependency module_dependencies[] = {
        {NET_STACK_MODULE_NAME, (module_info **)&gStackModule},
        {NET_BUFFER_MODULE_NAME, (module_info **)&gBufferModule},
        // Disabled: no datalink module pointer is declared or used in this file.
//      {NET_DATALINK_MODULE_NAME, (module_info **)&sDatalinkModule},
        {NET_SOCKET_MODULE_NAME, (module_info **)&gSocketModule},
        {}
};

// NULL-terminated list of the modules exported by this add-on: just the
// AF_UNIX protocol module.
module_info *modules[] = {
        (module_info *)&gUnixModule,
        NULL
};