root/src/add-ons/kernel/network/stack/link.cpp
/*
 * Copyright 2006-2013, Haiku, Inc. All Rights Reserved.
 * Distributed under the terms of the MIT License.
 *
 * Authors:
 *              Axel Dörfler, axeld@pinc-software.de
 */


//! The net_protocol one talks to when using the AF_LINK protocol


#include "link.h"

#include <net/if_dl.h>
#include <net/if_types.h>
#include <new>
#include <stdlib.h>
#include <string.h>
#include <sys/sockio.h>

#include <KernelExport.h>

#include <lock.h>
#include <net_datalink.h>
#include <net_device.h>
#include <ProtocolUtilities.h>
#include <util/AutoLock.h>

#include "device_interfaces.h"
#include "domains.h"
#include "interfaces.h"
#include "stack_private.h"
#include "utility.h"


// Supplies the module interfaces the DatagramSocket template needs. The
// stack's own module instances (gNetStackModule, gNetBufferModule) are
// handed out directly.
class LocalStackBundle {
public:
        static net_stack_module_info* Stack() { return &gNetStackModule; }
        static net_buffer_module_info* Buffer() { return &gNetBufferModule; }
};

// Receive queue for link sockets; locking is done via MutexLocking (the
// methods below lock the inherited fLock with MutexLocker).
typedef DatagramSocket<MutexLocking, LocalStackBundle> LocalDatagramSocket;

// Per-socket state of an AF_LINK datagram socket. Buffers reach the socket's
// receive queue in one of two ways:
//  - StartMonitoring() registers fMonitor on a device (packet capture); every
//    packet seen by the monitor is cloned into the queue by _MonitorData().
//  - Bind() may register _ReceiveData() as the device handler for one frame
//    type; matching buffers are enqueued directly.
class LinkProtocol : public net_protocol, public LocalDatagramSocket {
public:
                                                                LinkProtocol(net_socket* socket);
        virtual                                         ~LinkProtocol();

                        status_t                        StartMonitoring(const char* deviceName);
                        status_t                        StopMonitoring(const char* deviceName);

                        status_t                        Bind(const sockaddr* address);
                        status_t                        Unbind();
                        bool                            IsBound() const
                                                                        { return fBoundToDevice != NULL; }

                        size_t                          MTU();

protected:
                        status_t                        SocketStatus(bool peek) const;

private:
                        status_t                        _Unregister();

        static  status_t                        _MonitorData(net_device_monitor* monitor,
                                                                        net_buffer* buffer);
        static  void                            _MonitorEvent(net_device_monitor* monitor,
                                                                        int32 event);
        static  status_t                        _ReceiveData(void* cookie, net_device* device,
                                                                        net_buffer* buffer);

private:
                        // Capture monitor; its cookie is this object.
                        net_device_monitor      fMonitor;
                        // Device reference held while capturing, NULL otherwise.
                        net_device_interface* fMonitoredDevice;
                        // Device reference held while bound, NULL otherwise.
                        net_device_interface* fBoundToDevice;
                        // Frame type registered by Bind(); 0 means no handler
                        // was registered.
                        uint32                          fBoundType;
};


// The AF_LINK domain. It is registered by link_std_ops() on B_MODULE_INIT and
// is passed to interface lookups and datalink module calls in this file.
// NOTE(review): this is not declared static, so the symbol has external
// linkage; confirm nothing outside this file relies on it before changing it.
struct net_domain* sDomain;


/*!	Creates an unbound, non-capturing link socket and prepares the capture
	monitor so that StartMonitoring() only has to register it.
*/
LinkProtocol::LinkProtocol(net_socket* socket)
        :
        LocalDatagramSocket("packet capture", socket),
        fMonitoredDevice(NULL),
        fBoundToDevice(NULL),
        fBoundType(0)
                // 0 == "no device handler registered"; never leave it indeterminate
{
        fMonitor.cookie = this;
        fMonitor.receive = _MonitorData;
        fMonitor.event = _MonitorEvent;
}


/*!	Releases everything the socket still holds. Capture and binding are torn
	down independently: StartMonitoring() does not refuse a bound socket, so
	both can be active at once. Skipping Unbind() in that case would leak the
	device reference and leave a device handler pointing at this destroyed
	object.
*/
LinkProtocol::~LinkProtocol()
{
        if (fMonitoredDevice != NULL) {
                unregister_device_monitor(fMonitoredDevice->device, &fMonitor);
                put_device_interface(fMonitoredDevice);
                fMonitoredDevice = NULL;
        }

        // Unbind() just returns B_BAD_VALUE when the socket is not bound.
        Unbind();
}


status_t
LinkProtocol::StartMonitoring(const char* deviceName)
{
        MutexLocker lock(fLock);

        // A socket captures from at most one device at a time.
        if (fMonitoredDevice != NULL)
                return B_BUSY;

        net_device_interface* deviceInterface = get_device_interface(deviceName);
        if (deviceInterface == NULL)
                return B_DEVICE_NOT_FOUND;

        const status_t result
                = register_device_monitor(deviceInterface->device, &fMonitor);
        if (result >= B_OK) {
                // Keep the reference for as long as the capture runs.
                fMonitoredDevice = deviceInterface;
                return B_OK;
        }

        put_device_interface(deviceInterface);
        return result;
}


status_t
LinkProtocol::StopMonitoring(const char* deviceName)
{
        MutexLocker lock(fLock);

        // Only the device that is actually being captured can be stopped.
        const bool capturingThisDevice = fMonitoredDevice != NULL
                && strcmp(fMonitoredDevice->device->name, deviceName) == 0;
        if (!capturingThisDevice)
                return B_BAD_VALUE;

        return _Unregister();
}


/*!	Binds the socket to the link-level interface described by \a address
	(a sockaddr_dl). If \a address names a frame type (sdl_type != 0), a
	device handler for that type is registered so matching buffers are
	queued on this socket.

	Returns B_NOT_ALLOWED for non-root callers or a NULL address, and B_BUSY
	when the socket is capturing or already bound. Returns B_BAD_VALUE when
	no matching interface exists. Otherwise returns the error from
	register_device_handler(). On failure no reference is kept.
*/
status_t
LinkProtocol::Bind(const sockaddr* address)
{
        // Only root is allowed to bind to a link layer interface
        if (address == NULL || geteuid() != 0)
                return B_NOT_ALLOWED;

        MutexLocker locker(fLock);

        // Binding twice would silently overwrite (and leak) the previous
        // device reference and handler registration.
        if (fMonitoredDevice != NULL || IsBound())
                return B_BUSY;

        Interface* interface = get_interface_for_link(sDomain, address);
        if (interface == NULL)
                return B_BAD_VALUE;

        net_device_interface* boundTo
                = acquire_device_interface(interface->DeviceInterface());

        interface->ReleaseReference();

        if (boundTo == NULL)
                return B_BAD_VALUE;

        const sockaddr_dl& linkAddress = *(const sockaddr_dl*)address;

        uint32 boundType = 0;
        if (linkAddress.sdl_type != 0) {
                boundType = B_NET_FRAME_TYPE(linkAddress.sdl_type,
                        ntohs(linkAddress.sdl_e_type));
                // Bind to the type requested - this is needed in order to
                // receive any buffers
                // TODO: this could be easily changed by introducing catch all or rule
                // based handlers!
                status_t status = register_device_handler(boundTo->device, boundType,
                        &LinkProtocol::_ReceiveData, this);
                if (status != B_OK) {
                        // Don't leak the reference acquired above.
                        put_device_interface(boundTo);
                        return status;
                }
        }

        // Commit state only once everything succeeded.
        fBoundType = boundType;
        fBoundToDevice = boundTo;
        socket->bound_to_device = boundTo->device->index;

        memcpy(&socket->address, address, sizeof(struct sockaddr_storage));
        socket->address.ss_len = sizeof(struct sockaddr_storage);

        return B_OK;
}


/*!	Undoes Bind(). It unregisters the frame-type handler (if one was
	registered), drops the device reference and clears the socket's bound
	address. Returns B_BAD_VALUE if the socket is not bound.
*/
status_t
LinkProtocol::Unbind()
{
        MutexLocker locker(fLock);

        if (fBoundToDevice == NULL)
                return B_BAD_VALUE;

        // Bind() only registers a device handler for a non-zero frame type.
        if (fBoundType != 0)
                unregister_device_handler(fBoundToDevice->device, fBoundType);
        put_device_interface(fBoundToDevice);

        // Forget the released interface. Otherwise IsBound()/MTU() would
        // dereference it, and a second Unbind() (e.g. from the destructor)
        // would release it twice.
        fBoundToDevice = NULL;
        fBoundType = 0;

        socket->bound_to_device = 0;
        socket->address.ss_len = 0;
        return B_OK;
}


/*!	Returns the MTU of the bound device, or 0 when the socket is unbound. */
size_t
LinkProtocol::MTU()
{
        MutexLocker lock(fLock);
        return IsBound() ? fBoundToDevice->device->mtu : 0;
}


/*!	A link socket with neither a capture nor a binding has no packet source,
	so it reports B_DEVICE_NOT_FOUND instead of the queue status.
*/
status_t
LinkProtocol::SocketStatus(bool peek) const
{
        const bool hasPacketSource = fMonitoredDevice != NULL || IsBound();
        if (!hasPacketSource)
                return B_DEVICE_NOT_FOUND;

        return LocalDatagramSocket::SocketStatus(peek);
}


/*!	Stops an active capture: unregisters the monitor and releases the device
	reference. Returns B_BAD_VALUE when nothing is being captured, otherwise
	the result of unregister_device_monitor(). Callers hold fLock.
*/
status_t
LinkProtocol::_Unregister()
{
        if (fMonitoredDevice == NULL)
                return B_BAD_VALUE;

        net_device_interface* monitored = fMonitoredDevice;
        const status_t result
                = unregister_device_monitor(monitored->device, &fMonitor);
        put_device_interface(monitored);
        fMonitoredDevice = NULL;

        return result;
}


/*!	Capture monitor callback: queues a clone of every packet seen on the
	monitored device.
*/
/*static*/ status_t
LinkProtocol::_MonitorData(net_device_monitor* monitor, net_buffer* packet)
{
        LinkProtocol* owner = static_cast<LinkProtocol*>(monitor->cookie);
        return owner->EnqueueClone(packet);
}


/*!	Capture monitor event callback. When the monitored device goes down, the
	capture is stopped. If no data is left to read, waiting readers are woken
	and select() is notified with B_DEVICE_NOT_FOUND.
*/
/*static*/ void
LinkProtocol::_MonitorEvent(net_device_monitor* monitor, int32 event)
{
        if (event != B_DEVICE_GOING_DOWN)
                return;

        LinkProtocol* owner = static_cast<LinkProtocol*>(monitor->cookie);
        MutexLocker lock(owner->fLock);

        owner->_Unregister();

        // Pending data can still be read; only report the loss once drained.
        if (!owner->IsEmpty())
                return;

        owner->WakeAll();
        notify_socket(owner->socket, B_SELECT_READ, B_DEVICE_NOT_FOUND);
}


/*!	Device handler registered by Bind(): queues the buffer on the socket
	passed as \a cookie.
*/
/*static*/ status_t
LinkProtocol::_ReceiveData(void* cookie, net_device* device, net_buffer* buffer)
{
        return static_cast<LinkProtocol*>(cookie)->Enqueue(buffer);
}


//      #pragma mark -


/*!	Copies the interface name of the ifreq at user address \a value into
	\a request and looks up the matching device interface.

	Returns false if the user memory could not be read. Otherwise returns
	true, and \a interface is the referenced device interface or NULL if no
	device has that name. The caller must put a non-NULL \a interface.
*/
static bool
user_request_get_device_interface(void* value, struct ifreq& request,
        net_device_interface*& interface)
{
        if (user_memcpy(&request, value, IF_NAMESIZE) < B_OK)
                return false;

        // Userland is not required to terminate the name; never let the
        // lookup run past the end of ifr_name.
        request.ifr_name[IF_NAMESIZE - 1] = '\0';

        interface = get_device_interface(request.ifr_name);
        return true;
}


//      #pragma mark - net_protocol module


/*!	Creates the per-socket protocol object; returns NULL on allocation or
	initialization failure.
*/
static net_protocol*
link_init_protocol(net_socket* socket)
{
        LinkProtocol* protocol = new (std::nothrow) LinkProtocol(socket);
        if (protocol == NULL)
                return NULL;

        if (protocol->InitCheck() < B_OK) {
                delete protocol;
                return NULL;
        }

        return protocol;
}


/*!	Destroys an object created by link_init_protocol(). The downcast is
	required so the full LinkProtocol destructor runs.
*/
static status_t
link_uninit_protocol(net_protocol* protocol)
{
        LinkProtocol* linkProtocol = static_cast<LinkProtocol*>(protocol);
        delete linkProtocol;
        return B_OK;
}


static status_t
link_open(net_protocol* protocol)
{
        // Nothing to do when an AF_LINK socket is opened.
        (void)protocol;
        return B_OK;
}


static status_t
link_close(net_protocol* protocol)
{
        // Nothing to do when an AF_LINK socket is closed.
        (void)protocol;
        return B_OK;
}


static status_t
link_free(net_protocol* protocol)
{
        // Nothing to do here; the object is destroyed in link_uninit_protocol().
        (void)protocol;
        return B_OK;
}


static status_t
link_connect(net_protocol* protocol, const struct sockaddr* address)
{
        // AF_LINK sockets are registered as SOCK_DGRAM and cannot connect.
        (void)protocol;
        (void)address;
        return B_NOT_SUPPORTED;
}


static status_t
link_accept(net_protocol* protocol, struct net_socket** _acceptedSocket)
{
        // Connectionless sockets never accept connections.
        (void)protocol;
        (void)_acceptedSocket;
        return B_NOT_SUPPORTED;
}


/*!	ioctl() handler for AF_LINK sockets. Interface queries and packet capture
	control are handled here. All other requests are passed on to the datalink
	module's control hook for the AF_LINK domain.

	Data is exchanged with userland through user_memcpy() only. Care is taken
	never to copy uninitialized kernel stack memory back to userland.
*/
static status_t
link_control(net_protocol* _protocol, int level, int option, void* value,
        size_t* _length)
{
        LinkProtocol* protocol = (LinkProtocol*)_protocol;

        switch (option) {
                case SIOCGIFINDEX:
                {
                        // get index of interface
                        net_device_interface* interface;
                        struct ifreq request;
                        if (!user_request_get_device_interface(value, request, interface))
                                return B_BAD_ADDRESS;

                        if (interface != NULL) {
                                request.ifr_index = interface->device->index;
                                put_device_interface(interface);
                        } else
                                request.ifr_index = 0;

                        // Only the name was read from userland, so the rest of
                        // "request" is uninitialized; copy back just the index.
                        return user_memcpy(&((struct ifreq*)value)->ifr_index,
                                &request.ifr_index, sizeof(request.ifr_index));
                }
                case SIOCGIFNAME:
                {
                        // get name of interface via index
                        struct ifreq request;
                        if (user_memcpy(&request, value, sizeof(struct ifreq)) < B_OK)
                                return B_BAD_ADDRESS;

                        net_device_interface* interface
                                = get_device_interface(request.ifr_index);
                        if (interface == NULL)
                                return B_DEVICE_NOT_FOUND;

                        strlcpy(request.ifr_name, interface->device->name, IF_NAMESIZE);
                        put_device_interface(interface);

                        return user_memcpy(value, &request, sizeof(struct ifreq));
                }

                case SIOCGIFCOUNT:
                {
                        // count number of interfaces
                        struct ifconf config;
                        // The whole struct is copied out; don't leak stack bytes
                        // through ifc_len, the rest of the union or padding.
                        memset(&config, 0, sizeof(config));
                        config.ifc_value = count_device_interfaces();

                        return user_memcpy(value, &config, sizeof(struct ifconf));
                }

                case SIOCGIFCONF:
                {
                        // retrieve available interfaces
                        struct ifconf config;
                        if (user_memcpy(&config, value, sizeof(struct ifconf)) < B_OK)
                                return B_BAD_ADDRESS;

                        if (config.ifc_len < 0)
                                return B_BAD_VALUE;

                        // ifc_len is not a size_t. Reading or writing it through
                        // a size_t* would touch the neighbouring bytes on 64-bit
                        // targets, so use a properly typed temporary.
                        size_t length = (size_t)config.ifc_len;
                        status_t result = list_device_interfaces(config.ifc_buf,
                                &length);
                        if (result != B_OK)
                                return result;

                        config.ifc_len = (int)length;

                        return user_memcpy(value, &config, sizeof(struct ifconf));
                }

                case SIOCGIFADDR:
                {
                        // get address of interface
                        net_device_interface* interface;
                        struct ifreq request;
                        if (!user_request_get_device_interface(value, request, interface))
                                return B_BAD_ADDRESS;

                        if (interface == NULL)
                                return B_DEVICE_NOT_FOUND;

                        sockaddr_storage address;
                        // If the address is not filled in, ss_len stays 0 and
                        // nothing is copied out.
                        memset(&address, 0, sizeof(address));
                        get_device_interface_address(interface, (sockaddr*)&address);
                        put_device_interface(interface);

                        return user_memcpy(&((struct ifreq*)value)->ifr_addr,
                                &address, address.ss_len);
                }

                case SIOCGIFFLAGS:
                {
                        // get flags of interface
                        net_device_interface* interface;
                        struct ifreq request;
                        if (!user_request_get_device_interface(value, request, interface))
                                return B_BAD_ADDRESS;

                        if (interface == NULL)
                                return B_DEVICE_NOT_FOUND;

                        request.ifr_flags = interface->device->flags;
                        put_device_interface(interface);

                        return user_memcpy(&((struct ifreq*)value)->ifr_flags,
                                &request.ifr_flags, sizeof(request.ifr_flags));
                }

                case SIOCGIFMEDIA:
                {
                        // get media
                        const size_t copylen = offsetof(ifreq, ifr_media) + sizeof(ifreq::ifr_media);
                        if (*_length > 0 && *_length < copylen)
                                return B_BAD_VALUE;

                        net_device_interface* interface;
                        struct ifreq request;
                        if (!user_request_get_device_interface(value, request, interface))
                                return B_BAD_ADDRESS;
                        if (interface == NULL)
                                return B_DEVICE_NOT_FOUND;

                        request.ifr_media = interface->device->media;

                        put_device_interface(interface);

                        return user_memcpy(value, &request, copylen);
                }

                case SIOCSPACKETCAP:
                {
                        // Only root is allowed to capture packets
                        if (geteuid() != 0)
                                return B_NOT_ALLOWED;

                        struct ifreq request;
                        if (user_memcpy(&request, value, IF_NAMESIZE) != B_OK)
                                return B_BAD_ADDRESS;
                        // The name comes from userland and may be unterminated.
                        request.ifr_name[IF_NAMESIZE - 1] = '\0';

                        return protocol->StartMonitoring(request.ifr_name);
                }

                case SIOCCPACKETCAP:
                {
                        struct ifreq request;
                        if (user_memcpy(&request, value, IF_NAMESIZE) != B_OK)
                                return B_BAD_ADDRESS;
                        // The name comes from userland and may be unterminated.
                        request.ifr_name[IF_NAMESIZE - 1] = '\0';

                        return protocol->StopMonitoring(request.ifr_name);
                }
        }

        return gNetDatalinkModule.control(sDomain, option, value, _length);
}


/*!	getsockopt() hook. It delegates to the next protocol in the chain if
	there is one, otherwise to the generic socket option handling.
*/
static status_t
link_getsockopt(net_protocol* protocol, int level, int option, void* value,
        int* length)
{
        if (protocol->next != NULL) {
                // The next module must receive its own protocol object, not ours.
                return protocol->next->module->getsockopt(protocol->next, level,
                        option, value, length);
        }

        return gNetSocketModule.get_option(protocol->socket, level, option, value,
                length);
}


/*!	setsockopt() hook. It delegates to the next protocol in the chain if
	there is one, otherwise to the generic socket option handling.
*/
static status_t
link_setsockopt(net_protocol* protocol, int level, int option,
        const void* value, int length)
{
        if (protocol->next != NULL) {
                // The next module must receive its own protocol object, not ours.
                return protocol->next->module->setsockopt(protocol->next, level,
                        option, value, length);
        }

        return gNetSocketModule.set_option(protocol->socket, level, option,
                value, length);
}


static status_t
link_bind(net_protocol* _protocol, const struct sockaddr* address)
{
        // See LinkProtocol::Bind() for the binding rules.
        return static_cast<LinkProtocol*>(_protocol)->Bind(address);
}


static status_t
link_unbind(net_protocol* _protocol, struct sockaddr* address)
{
        // The address is not needed; a link socket has at most one binding.
        return static_cast<LinkProtocol*>(_protocol)->Unbind();
}


static status_t
link_listen(net_protocol* protocol, int count)
{
        // Listening makes no sense for datagram link sockets.
        (void)protocol;
        (void)count;
        return B_NOT_SUPPORTED;
}


static status_t
link_shutdown(net_protocol* protocol, int direction)
{
        // Partial shutdown is not implemented for link sockets.
        (void)protocol;
        (void)direction;
        return B_NOT_SUPPORTED;
}


/*!	Hands outgoing buffers to the datalink module within the AF_LINK domain. */
static status_t
link_send_data(net_protocol* protocol, net_buffer* buffer)
{
        net_domain* domain = sDomain;
        return gNetDatalinkModule.send_data(protocol, domain, buffer);
}


/*!	Sends a buffer along \a route. Source and destination must both be
	AF_LINK addresses; framing is left to the datalink layer.
*/
static status_t
link_send_routed_data(net_protocol* protocol, struct net_route* route,
        net_buffer* buffer)
{
        const bool familiesMatch
                = buffer->destination->sa_family == buffer->source->sa_family;
        const bool isLinkDestination = buffer->destination->sa_family == AF_LINK;
        if (!familiesMatch || !isLinkDestination)
                return B_BAD_VALUE;

        return gNetDatalinkModule.send_routed_data(route, buffer);
}


/*!	Reports the socket's send buffer size, or B_ERROR if it is not bound. */
static ssize_t
link_send_avail(net_protocol* _protocol)
{
        LinkProtocol* linkProtocol = static_cast<LinkProtocol*>(_protocol);
        if (linkProtocol->IsBound())
                return linkProtocol->socket->send.buffer_size;

        return B_ERROR;
}


/*!	Dequeues the next received buffer. \a numBytes is not used. */
static status_t
link_read_data(net_protocol* protocol, size_t numBytes, uint32 flags,
        net_buffer** _buffer)
{
        LinkProtocol* linkProtocol = static_cast<LinkProtocol*>(protocol);
        return linkProtocol->Dequeue(flags, _buffer);
}


/*!	Returns the amount of data queued for reading. */
static ssize_t
link_read_avail(net_protocol* protocol)
{
        LinkProtocol* linkProtocol = static_cast<LinkProtocol*>(protocol);
        return linkProtocol->AvailableData();
}


static struct net_domain*
link_get_domain(net_protocol* protocol)
{
        // Every link socket belongs to the single AF_LINK domain.
        (void)protocol;
        return sDomain;
}


/*!	Returns the bound device's MTU (0 if unbound); \a address is ignored. */
static size_t
link_get_mtu(net_protocol* _protocol, const struct sockaddr* address)
{
        return static_cast<LinkProtocol*>(_protocol)->MTU();
}


static status_t
link_receive_data(net_buffer* buffer)
{
        // Incoming data arrives via the device monitor/handler callbacks of
        // LinkProtocol instead, so this path is never used.
        (void)buffer;
        return B_ERROR;
}


static status_t
link_error_received(net_error error, net_error_data* errorData, net_buffer* data)
{
        // Error processing is not supported for link sockets.
        (void)error;
        (void)errorData;
        (void)data;
        return B_ERROR;
}


static status_t
link_error_reply(net_protocol* protocol, net_buffer* cause, net_error error,
        net_error_data* errorData)
{
        // Error processing is not supported for link sockets.
        (void)protocol;
        (void)cause;
        (void)error;
        (void)errorData;
        return B_ERROR;
}


/*!	Module hook: registers the AF_LINK domain on load and removes it on
	unload. Any other operation is rejected with B_ERROR.
*/
static status_t
link_std_ops(int32 op, ...)
{
        if (op == B_MODULE_INIT)
                return register_domain(AF_LINK, "link", NULL, NULL, &sDomain);

        if (op == B_MODULE_UNINIT) {
                unregister_domain(sDomain);
                return B_OK;
        }

        return B_ERROR;
}


//      #pragma mark -


/*!	Registers this module as the AF_LINK datagram protocol. It also
	registers the ethernet framing datalink protocol for IFT_ETHER interfaces
	of that domain.
*/
void
link_init()
{
        const char* const linkProtocolModule = "network/stack/link/v1";
        const char* const ethernetFrameModule
                = "network/datalink_protocols/ethernet_frame/v1";

        register_domain_protocols(AF_LINK, SOCK_DGRAM, 0, linkProtocolModule,
                NULL);

        // TODO: this should actually be registered for all types (besides local)
        register_domain_datalink_protocols(AF_LINK, IFT_ETHER, ethernetFrameModule,
                NULL);
}


// Module descriptor for the AF_LINK protocol ("network/stack/link/v1").
// The hooks are given positionally, so their order must match the member
// order of net_protocol_module_info exactly. NULL entries are hooks this
// module does not provide.
net_protocol_module_info gLinkModule = {
        {
                "network/stack/link/v1",
                0,
                link_std_ops
        },
        NET_PROTOCOL_ATOMIC_MESSAGES,

        link_init_protocol,
        link_uninit_protocol,
        link_open,
        link_close,
        link_free,
        link_connect,
        link_accept,
        link_control,
        link_getsockopt,
        link_setsockopt,
        link_bind,
        link_unbind,
        link_listen,
        link_shutdown,
        link_send_data,
        link_send_routed_data,
        link_send_avail,
        link_read_data,
        link_read_avail,
        link_get_domain,
        link_get_mtu,
        link_receive_data,
        NULL,           // deliver_data
        link_error_received,
        link_error_reply,
        NULL,           // add_ancillary_data()
        NULL,           // process_ancillary_data()
        NULL,           // process_ancillary_data_no_container()
        NULL,           // send_data_no_buffer()
        NULL            // read_data_no_buffer()
};