root/src/preferences/mail/DNSQuery.cpp
#include "DNSQuery.h"

#include <errno.h>
#include <stdio.h>

#include <ByteOrder.h>
#include <FindDirectory.h>
#include <NetAddress.h>
#include <NetEndpoint.h>
#include <Path.h> 

// Uncomment to enable the debug output emitted through PRINT() below.
// #define DEBUG 1

// Local tracing helper: PRINT() forwards its arguments to printf() when DEBUG
// is defined and expands to nothing otherwise. Any earlier definition of
// PRINT is discarded first.
#undef PRINT
#ifdef DEBUG
#define PRINT(a...) printf(a)
#else
#define PRINT(a...)
#endif


// Counter from which DNSQuery::_GetUniqueID() derives the ID of each query;
// it is incremented there with atomic_add().
static int32 gID = 1;


// Creates a buffer without initial data; _Init() sets the read and write
// positions to zero.
BRawNetBuffer::BRawNetBuffer()
{
        _Init(NULL, 0);
}


// Creates a buffer without initial data and then asks the underlying storage
// to take the given size. The result of SetSize() is not checked here.
BRawNetBuffer::BRawNetBuffer(off_t size)
{
        _Init(NULL, 0);
        fBuffer.SetSize(size);
}


// Creates a buffer whose storage is filled from "buf" ("size" bytes) by
// _Init(); the read and write positions both start at zero.
BRawNetBuffer::BRawNetBuffer(const void* buf, size_t size)
{
        _Init(buf, size);
}


/*!     Writes \a value at the current write position in big endian byte order
        and moves the write position past it.

        \return B_OK on success, B_NO_MEMORY if the underlying write reported
                that error.
*/
status_t
BRawNetBuffer::AppendUint16(uint16 value)
{
        const uint16 bigEndian = B_HOST_TO_BENDIAN_INT16(value);

        if (fBuffer.WriteAt(fWritePosition, &bigEndian, sizeof(uint16))
                        == B_NO_MEMORY) {
                return B_NO_MEMORY;
        }

        fWritePosition += sizeof(uint16);
        return B_OK;
}


/*!     Writes the C string \a string, including its terminating null byte, at
        the current write position and moves the write position past it.

        \return B_OK on success, B_NO_MEMORY if the underlying write reported
                that error.
*/
status_t
BRawNetBuffer::AppendString(const char* string)
{
        // the terminating null byte is part of what gets written
        const size_t bytesToWrite = strlen(string) + 1;

        if (fBuffer.WriteAt(fWritePosition, string, bytesToWrite)
                        == B_NO_MEMORY) {
                return B_NO_MEMORY;
        }

        fWritePosition += bytesToWrite;
        return B_OK;
}


/*!     Reads a big endian 16 bit value at the current read position and moves
        the read position past it.

        \param value Receives the value in host byte order; it is left
                untouched on failure.
        \return B_OK on success, B_ERROR if fewer than two bytes could be read.
*/
status_t
BRawNetBuffer::ReadUint16(uint16& value)
{
        uint16 netVal;
        ssize_t sizeR = fBuffer.ReadAt(fReadPosition, &netVal, sizeof(uint16));
        // Anything but a complete read (nothing left, a truncated value or an
        // error code) would leave netVal partly or fully uninitialized.
        if (sizeR != (ssize_t)sizeof(uint16))
                return B_ERROR;
        value = B_BENDIAN_TO_HOST_INT16(netVal);
        fReadPosition += sizeof(uint16);
        return B_OK;
}


/*!     Reads a big endian 32 bit value at the current read position and moves
        the read position past it.

        \param value Receives the value in host byte order; it is left
                untouched on failure.
        \return B_OK on success, B_ERROR if fewer than four bytes could be
                read.
*/
status_t
BRawNetBuffer::ReadUint32(uint32& value)
{
        uint32 netVal;
        ssize_t sizeR = fBuffer.ReadAt(fReadPosition, &netVal, sizeof(uint32));
        // Anything but a complete read (nothing left, a truncated value or an
        // error code) would leave netVal partly or fully uninitialized.
        if (sizeR != (ssize_t)sizeof(uint32))
                return B_ERROR;
        value = B_BENDIAN_TO_HOST_INT32(netVal);
        fReadPosition += sizeof(uint32);
        return B_OK;
}


/*!     Reads a string starting at the current read position via
        _ReadStringAt() and advances the read position by the number of bytes
        that call reports.

        \param string Cleared first, then filled with the string that was read.
        \return B_OK on success, B_ERROR if _ReadStringAt() returned a negative
                value.
*/
status_t
BRawNetBuffer::ReadString(BString& string)
{
        string = "";

        const ssize_t consumed = _ReadStringAt(string, fReadPosition);
        if (consumed >= 0) {
                fReadPosition += consumed;
                return B_OK;
        }

        return B_ERROR;
}


/*!     Moves the read position by \a skip bytes.

        \param skip Number of bytes to skip.
        \return B_OK on success, B_ERROR if the resulting position would lie
                before the start or past the end of the buffer; the read
                position is unchanged in that case.
*/
status_t
BRawNetBuffer::SkipReading(off_t skip)
{
        const off_t target = fReadPosition + skip;
        // a negative skip must not be able to move the position below zero
        if (target < 0 || target > (off_t)fBuffer.BufferLength())
                return B_ERROR;
        fReadPosition = target;
        return B_OK;
}


/*!     Resets both positions to the start of the buffer and writes \a size
        bytes from \a buf at offset zero. The write position is not moved past
        the data written here; the result of the write is not checked.
*/
void
BRawNetBuffer::_Init(const void* buf, size_t size)
{
        fReadPosition = 0;
        fWritePosition = 0;
        fBuffer.WriteAt(0, buf, size);
}


/*!     Reads a DNS encoded name (RFC 1035, sections 3.1 and 4.1.4) that starts
        at \a pos and appends it to \a string in its raw form, i.e. every label
        still preceded by its length byte and without the terminating zero
        byte. Compression pointers are followed.

        \param string The labels are appended to this string. It may have been
                partly filled when an error is returned.
        \param pos Offset into the buffer at which the name starts.
        \return The number of bytes the name occupies at \a pos (up to and
                including the terminating zero byte or the first compression
                pointer), or -1 if \a pos is out of range or the name is
                malformed (runs past the end of the buffer, uses a reserved
                label type, or contains too many pointers).
*/
ssize_t
BRawNetBuffer::_ReadStringAt(BString& string, off_t pos)
{
        const off_t length = (off_t)fBuffer.BufferLength();
        if (pos < 0 || pos >= length)
                return -1;

        const uint8* data = (const uint8*)fBuffer.Buffer();

        // Upper bound for the pointers followed within one name; it only
        // exists to stop pointer loops in malformed or hostile packets.
        const int kMaxPointerHops = 64;
        int pointerHops = 0;

        // bytes used at the original position; no longer counted once the
        // first pointer has been followed
        ssize_t bytesRead = 0;
        bool followedPointer = false;

        while (true) {
                if (pos >= length)
                        return -1;

                const uint8 labelLength = data[pos];
                if (labelLength == 0) {
                        // end of name
                        if (!followedPointer)
                                bytesRead++;
                        break;
                }

                if ((labelLength & 0xC0) == 0xC0) {
                        // compression pointer: the two top bits are set, the
                        // remaining 14 bits are an offset from the message start
                        if (pos + 1 >= length)
                                return -1;
                        if (++pointerHops > kMaxPointerHops)
                                return -1;
                        if (!followedPointer)
                                bytesRead += 2;
                        followedPointer = true;
                        pos = (off_t(labelLength & 0x3F) << 8) | data[pos + 1];
                        continue;
                }

                if ((labelLength & 0xC0) != 0) {
                        // the 01 and 10 label types are reserved
                        return -1;
                }

                // ordinary label: length byte followed by labelLength bytes
                if (pos + 1 + labelLength > length)
                        return -1;
                string.Append((const char*)&data[pos], labelLength + 1);
                if (!followedPointer)
                        bytesRead += labelLength + 1;
                pos += labelLength + 1;
        }

        return bytesRead;
}


// #pragma mark - DNSTools


/*!     Returns whether \a line starts with \a keyword immediately followed by
        a space or a tab.
*/
static bool
match_keyword(const char* line, const char* keyword)
{
        const size_t length = strlen(keyword);
        return strncmp(line, keyword, length) == 0
                && (line[length] == ' ' || line[length] == '\t');
}


/*!     Collects the name servers listed in "network/resolv.conf" below the
        system settings directory.

        At most two "nameserver" entries are read. Lines starting with ';' or
        '#' are skipped, and everything from the first ';', '#' or whitespace
        after the address is cut off.

        \param serverList Receives one newly allocated BString per name server.
        \return B_OK if the file could be read (even if it contained no name
                server), B_ENTRY_NOT_FOUND if the settings directory or the
                file could not be found or opened.
*/
status_t
DNSTools::GetDNSServers(BObjectList<BString, true>* serverList)
{
        // TODO: reading resolv.conf ourselves shouldn't be needed.
        // we should have some function to retrieve the dns list
        BPath path;
        if (find_directory(B_SYSTEM_SETTINGS_DIRECTORY, &path) != B_OK)
                return B_ENTRY_NOT_FOUND;

        path.Append("network/resolv.conf");

        FILE* fp = fopen(path.Path(), "r");
        if (fp == NULL) {
                fprintf(stderr, "failed to open '%s' to read nameservers: %s\n",
                        path.Path(), strerror(errno));
                return B_ENTRY_NOT_FOUND;
        }

        static const char kKeyword[] = "nameserver";
        const int kMaxNameServers = 2;
        int nserv = 0;
        char buf[1024];

        // read the config file; there is no point in reading on once enough
        // name servers have been collected
        while (nserv < kMaxNameServers && fgets(buf, sizeof(buf), fp) != NULL) {
                // skip comments
                if (*buf == ';' || *buf == '#')
                        continue;

                if (!match_keyword(buf, kKeyword))
                        continue;

                // skip the keyword and the whitespace following it
                char* cp = buf + sizeof(kKeyword) - 1;
                while (*cp == ' ' || *cp == '\t')
                        cp++;

                // cut off trailing comments, whitespace and the newline
                cp[strcspn(cp, ";# \t\n")] = '\0';
                if (*cp == '\0')
                        continue;

                BString* server = new BString(cp);
                if (!serverList->AddItem(server)) {
                        // the list did not take ownership
                        delete server;
                        continue;
                }
                nserv++;
        }

        fclose(fp);

        return B_OK;
}


/*!     Converts a dotted host name into the DNS label format (RFC 1035,
        section 3.1): every label is preceded by a byte holding its length and
        the dots are dropped, e.g. "mail.example.com" becomes
        "\x04mail\x07example\x03com". The terminating zero length byte is not
        part of the returned string.

        Empty labels (leading, trailing or doubled dots) are skipped. A name
        without any dot is encoded as a single label.

        NOTE(review): labels longer than the 63 bytes allowed by RFC 1035 are
        not rejected here.

        \param string The host name to convert.
        \return The name in label format; empty if \a string has no labels.
*/
BString
DNSTools::ConvertToDNSName(const BString& string)
{
        BString outString;
        const char* name = string.String();
        // the length bytes count bytes, not UTF-8 characters
        const int32 length = string.Length();

        int32 labelStart = 0;
        while (labelStart < length) {
                int32 labelEnd = string.FindFirst('.', labelStart);
                if (labelEnd < 0)
                        labelEnd = length;

                const int32 labelLength = labelEnd - labelStart;
                if (labelLength > 0) {
                        outString.Append((char)labelLength, 1);
                        outString.Append(name + labelStart, labelLength);
                }

                // continue behind the dot
                labelStart = labelEnd + 1;
        }

        return outString;
}


/*!     Converts a name in DNS label format (every label preceded by a length
        byte) back into a dotted host name: the leading length byte is removed
        and every following length byte is replaced by a '.'.

        \param string The name in label format, without terminating zero byte.
        \return The dotted name; \a string itself if it is empty.
*/
BString
DNSTools::ConvertFromDNSName(const BString& string)
{
        if (string.Length() == 0)
                return string;

        BString outString = string;
        // Length bytes are unsigned; reading them through a plain (possibly
        // signed) char could produce a negative index further down.
        int32 labelLength = uint8(string[0]);
        int32 nextDot = labelLength;
        outString.Remove(0, sizeof(char));

        while (nextDot < outString.Length()) {
                labelLength = uint8(outString[nextDot]);
                if (labelLength == 0)
                        break;
                // set a "."
                outString.SetByteAt(nextDot, '.');
                nextDot += labelLength + 1;
        }
        return outString;
}


// #pragma mark - DNSQuery
// see http://tools.ietf.org/html/rfc1035 for more information about DNS


// Nothing to set up here.
DNSQuery::DNSQuery()
{
}


// Nothing to clean up here.
DNSQuery::~DNSQuery()
{
}


/*!     Looks up the configured name servers and stores the address of the
        first one in \a add.

        \param add Receives the address as parsed by inet_aton().
        \return B_OK on success, the error of DNSTools::GetDNSServers() if the
                list could not be retrieved, B_ERROR if the list is empty or
                its first entry could not be parsed.
*/
status_t
DNSQuery::ReadDNSServer(in_addr* add)
{
        // the list owns its items, they are freed when it goes out of scope
        BObjectList<BString, true> servers(5);

        const status_t result = DNSTools::GetDNSServers(&servers);
        if (result != B_OK)
                return result;

        BString* primary = servers.ItemAt(0);
        if (primary == NULL)
                return B_ERROR;
        if (inet_aton(primary->String(), add) != 1)
                return B_ERROR;

        PRINT("dns server found: %s \n", primary->String());
        return B_OK;
}


/*!     Asks the first configured DNS server for the MX records of
        \a serverName using a single UDP query (see RFC 1035).

        \param serverName The domain to look up.
        \param mxList Receives one newly allocated mx_record per MX answer.
        \param timeout Passed to the endpoint's SetTimeout() before the answer
                is received.
        \return B_OK if at least one MX record was added to \a mxList, B_ERROR
                if no DNS server is known, on network failures, if no answer
                arrived, or if the answer contained no MX record.
*/
status_t
DNSQuery::GetMXRecords(const BString&  serverName,
        BObjectList<mx_record, true>* mxList, bigtime_t timeout)
{
        // get the DNS server to ask for the mx record
        in_addr dnsAddress;
        if (ReadDNSServer(&dnsAddress) != B_OK)
                return B_ERROR;

        // create dns query package
        BRawNetBuffer buffer;
        dns_header header;
        _SetMXHeader(&header);
        _AppendQueryHeader(buffer, &header);

        // question section: name, type (MX) and class (1)
        BString serverNameConv = DNSTools::ConvertToDNSName(serverName);
        buffer.AppendString(serverNameConv);
        buffer.AppendUint16(uint16(MX_RECORD));
        buffer.AppendUint16(uint16(1));

        // send the buffer
        PRINT("send buffer\n");
        BNetAddress netAddress(dnsAddress, 53);
        BNetEndpoint netEndpoint(SOCK_DGRAM);
        if (netEndpoint.InitCheck() != B_OK)
                return B_ERROR;

        if (netEndpoint.Connect(netAddress) != B_OK)
                return B_ERROR;
        PRINT("Connected\n");

        int32 bytesSend = netEndpoint.Send(buffer.Data(), buffer.Size());
        if (bytesSend < 0)
                return B_ERROR;
        PRINT("bytes send %i\n", int(bytesSend));

        // receive buffer
        BRawNetBuffer receiBuffer(512);
        netEndpoint.SetTimeout(timeout);

        int32 bytesRecei = netEndpoint.ReceiveFrom(receiBuffer.Data(), 512,
                netAddress);
        // A result of 0 means that nothing was received (presumably because
        // the timeout elapsed - TODO confirm against BNetEndpoint); parsing
        // the untouched receive buffer would only yield garbage.
        if (bytesRecei <= 0)
                return B_ERROR;
        PRINT("bytes received %i\n", int(bytesRecei));

        dns_header receiHeader;

        _ReadQueryHeader(receiBuffer, &receiHeader);
        PRINT("Package contains :");
        PRINT("%d Questions, ", receiHeader.q_count);
        PRINT("%d Answers, ", receiHeader.ans_count);
        PRINT("%d Authoritative Servers, ", receiHeader.auth_count);
        PRINT("%d Additional records\n", receiHeader.add_count);

        // remove name and Question
        BString dummyS;
        uint16 dummy;
        receiBuffer.ReadString(dummyS);
        receiBuffer.ReadUint16(dummy);
        receiBuffer.ReadUint16(dummy);

        bool mxRecordFound = false;
        for (int i = 0; i < receiHeader.ans_count; i++) {
                resource_record_head rrHead;
                // _ReadResourceRecord() does not report failures; make sure the
                // fields evaluated below are defined in any case
                rrHead.type = 0;
                rrHead.dataLength = 0;
                _ReadResourceRecord(receiBuffer, &rrHead);
                if (rrHead.type == MX_RECORD) {
                        mx_record* mxRec = new mx_record;
                        _ReadMXRecord(receiBuffer, mxRec);
                        PRINT("MX record found pri %i, name %s\n",
                                mxRec->priority, mxRec->serverName.String());
                        // Add mx record to the list
                        if (mxList->AddItem(mxRec))
                                mxRecordFound = true;
                        else {
                                // the list did not take ownership
                                delete mxRec;
                        }
                } else {
                        // Skip the record data in the *received* packet so that
                        // the next record is read from the right offset.
                        receiBuffer.SkipReading(rrHead.dataLength);
                }
        }

        if (!mxRecordFound)
                return B_ERROR;

        return B_OK;
}


/*!     Returns the ID for the next query. The IDs are taken from the global
        counter gID, which is incremented atomically, and wrap around within
        the 16 bit range of the DNS header's ID field.
*/
uint16
DNSQuery::_GetUniqueID()
{
        int32 nextId = atomic_add(&gID, 1);
        // Only the lower 16 bits fit into the header. Masking makes the IDs
        // wrap around instead of getting stuck at one value once the counter
        // has left the 16 bit range.
        return uint16(nextId & 0xFFFF);
}


/*!     Fills \a header for a standard query carrying exactly one question and
        requesting recursion. Every other flag and all record counts besides
        the question count are cleared.
*/
void
DNSQuery::_SetMXHeader(dns_header* header)
{
        // identification of this query
        header->id = _GetUniqueID();

        // flags
        header->qr = 0;
                // a query, not a response
        header->opcode = 0;
                // standard query
        header->aa = 0;
                // not authoritative
        header->tc = 0;
                // message is not truncated
        header->rd = 1;
                // recursion desired
        header->ra = 0;
                // recursion available: nothing we can offer
        header->z = 0;
        header->rcode = 0;

        // section counts: a single question, nothing else
        header->q_count = 1;
        header->ans_count = 0;
        header->auth_count = 0;
        header->add_count = 0;
}


/*!     Appends \a header to \a buffer as six 16 bit values: the ID, the packed
        flags and the four section counts.

        Layout of the flags from the most to the least significant bit:
        qr (bit 15), opcode (from bit 11), aa (10), tc (9), rd (8), ra (7),
        z (from bit 4), rcode (from bit 0).
*/
void
DNSQuery::_AppendQueryHeader(BRawNetBuffer& buffer, const dns_header* header)
{
        const uint16 flags = uint16((header->qr << 15)
                | (header->opcode << 11)
                | (header->aa << 10)
                | (header->tc << 9)
                | (header->rd << 8)
                | (header->ra << 7)
                | (header->z << 4)
                | header->rcode);

        buffer.AppendUint16(header->id);
        buffer.AppendUint16(flags);
        buffer.AppendUint16(header->q_count);
        buffer.AppendUint16(header->ans_count);
        buffer.AppendUint16(header->auth_count);
        buffer.AppendUint16(header->add_count);
}


/*!     Reads six 16 bit values from \a buffer into \a header: the ID, the
        packed flags (split up into the individual fields) and the four
        section counts. The results of the individual reads are not checked;
        the flags are taken as zero if they could not be read.
*/
void
DNSQuery::_ReadQueryHeader(BRawNetBuffer& buffer, dns_header* header)
{
        buffer.ReadUint16(header->id);

        uint16 flags = 0;
        buffer.ReadUint16(flags);

        // unpack from the most to the least significant bit
        header->qr = (flags & 0x8000) >> 15;
        header->opcode = (flags & 0x7800) >> 11;
        header->aa = (flags & 0x0400) >> 10;
        header->tc = (flags & 0x0200) >> 9;
        header->rd = (flags & 0x0100) >> 8;
        header->ra = (flags & 0x0080) >> 7;
        header->z = (flags & 0x0070) >> 4;
        header->rcode = flags & 0x000F;

        buffer.ReadUint16(header->q_count);
        buffer.ReadUint16(header->ans_count);
        buffer.ReadUint16(header->auth_count);
        buffer.ReadUint16(header->add_count);
}


/*!     Reads the data of an MX record from \a buffer: the 16 bit priority
        followed by the server name, which is converted from the DNS label
        format into a dotted name. The results of the reads are not checked.
*/
void
DNSQuery::_ReadMXRecord(BRawNetBuffer& buffer, mx_record* mxRecord)
{
        buffer.ReadUint16(mxRecord->priority);

        BString& name = mxRecord->serverName;
        buffer.ReadString(name);
        name = DNSTools::ConvertFromDNSName(name);
}


/*!     Reads the head of a resource record from \a buffer into \a rrHead, in
        this order: name, type, class, time to live and data length. The
        record data itself is left in the buffer. The results of the reads are
        not checked.
*/
void
DNSQuery::_ReadResourceRecord(BRawNetBuffer& buffer,
        resource_record_head *rrHead)
{
        resource_record_head& head = *rrHead;

        buffer.ReadString(head.name);
        buffer.ReadUint16(head.type);
        buffer.ReadUint16(head.dataClass);
        buffer.ReadUint32(head.ttl);
        buffer.ReadUint16(head.dataLength);
}