root/usr.sbin/nsd/packet.c
/*
 * packet.c -- low-level DNS packet encoding and decoding functions.
 *
 * Copyright (c) 2001-2006, NLnet Labs. All rights reserved.
 *
 * See LICENSE for the license.
 *
 */

#include "config.h"

#include <string.h>

#include "packet.h"
#include "query.h"
#include "rdata.h"
#include "dns.h"

/*
 * When nonzero, packet_encode_rrset() rotates which RR of an
 * answer-section RRset is emitted first (not for AXFR/IXFR queries).
 */
int round_robin = 0;
/*
 * NOTE(review): not read by any code shown in this file; presumably a
 * run-time switch for minimal responses consulted elsewhere -- confirm.
 */
int minimal_responses = 0;

/*
 * Write the name of DOMAIN into the packet of Q, label by label,
 * walking up through the parent links.
 *
 * For every domain that has no offset recorded yet (a stored offset of
 * 0 is treated as "not written before") the current packet position is
 * remembered and the leading label of the domain is emitted.  The walk
 * ends either at a domain that already has an offset, in which case a
 * 16-bit pointer (0xc000 | offset) is written, or at the domain without
 * a parent, in which case a single zero octet terminates the name.
 */
static void
encode_dname(query_type *q, domain_type *domain)
{
        const uint8_t *label;

        for (;;) {
                if (!domain->parent) {
                        /* No parent left: terminate with the empty label. */
                        buffer_write_u8(q->packet, 0);
                        return;
                }
                if (query_get_dname_offset(q, domain) != 0) {
                        /* Already present in the packet: point at it. */
                        break;
                }

                /* Remember where this domain starts for later references. */
                query_put_dname_offset(q, domain, buffer_position(q->packet));
                DEBUG(DEBUG_NAME_COMPRESSION, 2,
                      (LOG_INFO, "dname: %s, number: %lu, offset: %u\n",
                       domain_to_string(domain),
                       (unsigned long) domain->number,
                       query_get_dname_offset(q, domain)));

                /* Emit the first label only: length octet plus its data. */
                label = dname_name(domain_dname(domain));
                buffer_write(q->packet, label, label_length(label) + 1U);

                domain = domain->parent;
        }

        DEBUG(DEBUG_NAME_COMPRESSION, 2,
              (LOG_INFO, "dname: %s, number: %lu, pointer: %u\n",
               domain_to_string(domain),
               (unsigned long) domain->number,
               query_get_dname_offset(q, domain)));
        assert(query_get_dname_offset(q, domain) <= MAX_COMPRESSION_OFFSET);
        buffer_write_u16(q->packet,
                         0xc000 | query_get_dname_offset(q, domain));
}

/*
 * Append one resource record to the packet of Q.
 *
 * The owner name is written through encode_dname(), followed by the
 * type, class, the supplied TTL (which may differ from rr->ttl), a
 * 16-bit rdlength placeholder and the rdata produced by the type
 * descriptor's write_rdata callback.
 *
 * Returns 1 when the record fits; the rdlength placeholder is then
 * patched with the number of rdata octets written.  Returns 0 when
 * query_overflow() reports an overflow; in that case the packet
 * position and the recorded dname offsets are rolled back to the state
 * before the call.
 */
int
packet_encode_rr(query_type *q, domain_type *owner, rr_type *rr, uint32_t ttl)
{
        size_t truncation_mark;
        uint16_t rdlength = 0;
        size_t rdlength_pos;
        const nsd_type_descriptor_type *descriptor;

        assert(q);
        assert(owner);
        assert(rr);

        /*
         * Look up the descriptor only after the arguments have been
         * checked, so that the assert on rr can actually fire before rr
         * is dereferenced.
         */
        descriptor = nsd_type_descriptor(rr->type);

        /*
         * If the record does not fit in the packet, the packet position
         * will be restored to this mark.
         */
        truncation_mark = buffer_position(q->packet);

        encode_dname(q, owner);
        buffer_write_u16(q->packet, rr->type);
        buffer_write_u16(q->packet, rr->klass);
        buffer_write_u32(q->packet, ttl);

        /* Reserve space for rdlength. */
        rdlength_pos = buffer_position(q->packet);
        buffer_skip(q->packet, sizeof(rdlength));

        descriptor->write_rdata(q, rr);

        if (!query_overflow(q)) {
                /* Everything after the placeholder is rdata. */
                rdlength = (buffer_position(q->packet) - rdlength_pos
                            - sizeof(rdlength));
                buffer_write_u16_at(q->packet, rdlength_pos, rdlength);
                return 1;
        } else {
                /* Undo the partial record, including compression offsets. */
                buffer_set_position(q->packet, truncation_mark);
                query_clear_dname_offsets(q, truncation_mark);
                assert(!query_overflow(q));
                return 0;
        }
}

/*
 * Append all records of RRSET (owned by OWNER) to the packet of QUERY
 * and, when the query has the DNSSEC OK flag set and the zone is
 * secure, the RRSIG records at OWNER that cover the RRset's type.
 *
 * When round robin is enabled and the RRset goes into the answer
 * section of a non-AXFR/IXFR query, encoding starts at a rotating
 * index and wraps around to the beginning of the RRset.
 *
 * If not every record fits:
 *  - for the answer and authority sections (and, without
 *    MINIMAL_RESPONSES, the optional authority section) the whole
 *    RRset is removed again and the TC flag is set;
 *  - with MINIMAL_RESPONSES, for non-TCP queries and sections at or
 *    beyond OPTIONAL_AUTHORITY_SECTION, the whole RRset is removed when
 *    it does not fit or when the packet grows past minimal_respsize,
 *    and *done is set to 1.
 *
 * Returns the number of records (including RRSIGs) left in the packet
 * by this call.
 */
int
packet_encode_rrset(query_type *query,
                    domain_type *owner,
                    rrset_type *rrset,
                    int section,
#ifdef MINIMAL_RESPONSES
                    size_t minimal_respsize,
                    int* done)
#else
                    size_t ATTR_UNUSED(minimal_respsize),
                    int* ATTR_UNUSED(done))
#endif
{
        uint16_t i;
        size_t truncation_mark;
        uint16_t added = 0;
        int all_added = 1;
#ifdef MINIMAL_RESPONSES
        int minimize_response = (section >= OPTIONAL_AUTHORITY_SECTION);
        int truncate_rrset = (section == ANSWER_SECTION ||
                                section == AUTHORITY_SECTION);
#else
        int truncate_rrset = (section == ANSWER_SECTION ||
                                section == AUTHORITY_SECTION ||
                                section == OPTIONAL_AUTHORITY_SECTION);
#endif
        /*
         * Rotation counter shared by all calls.  It is unsigned on
         * purpose: it is incremented without bound, and a signed
         * counter would eventually overflow (undefined behaviour) and,
         * once negative, make the modulo below negative so that 'start'
         * would index far outside rrset->rrs.
         */
        static unsigned int round_robin_off = 0;
        int do_robin = (round_robin && section == ANSWER_SECTION &&
                query->qtype != TYPE_AXFR && query->qtype != TYPE_IXFR);
        uint16_t start;
        rrset_type *rrsig;

        assert(rrset->rr_count > 0);

        truncation_mark = buffer_position(query->packet);

        /* Guard rr_count as well, so the modulo can never divide by zero. */
        if(do_robin && rrset->rr_count)
                start = (uint16_t)(round_robin_off++ % rrset->rr_count);
        else    start = 0;

        /* First pass: from the rotation point to the end of the RRset. */
        for (i = start; i < rrset->rr_count; ++i) {
                if (packet_encode_rr(query, owner, rrset->rrs[i],
                        rrset->rrs[i]->ttl)) {
                        ++added;
                } else {
                        all_added = 0;
                        /* Makes the wrap-around pass below a no-op. */
                        start = 0;
                        break;
                }
        }
        /* Second pass: the records in front of the rotation point. */
        for (i = 0; i < start; ++i) {
                if (packet_encode_rr(query, owner, rrset->rrs[i],
                        rrset->rrs[i]->ttl)) {
                        ++added;
                } else {
                        all_added = 0;
                        break;
                }
        }

        if (all_added &&
            query->edns.dnssec_ok &&
            zone_is_secure(rrset->zone) &&
            rrset_rrtype(rrset) != TYPE_RRSIG &&
            (rrsig = domain_find_rrset(owner, rrset->zone, TYPE_RRSIG)))
        {
                for (i = 0; i < rrsig->rr_count; ++i) {
                        if (rr_rrsig_type_covered(rrsig->rrs[i])
                            == rrset_rrtype(rrset))
                        {
                                /*
                                 * A signature covering the SOA is sent
                                 * with the TTL of the SOA record itself;
                                 * all others keep their own TTL.
                                 */
                                if (packet_encode_rr(query, owner,
                                        rrsig->rrs[i],
                                        rrset_rrtype(rrset)==TYPE_SOA?rrset->rrs[0]->ttl:rrsig->rrs[i]->ttl))
                                {
                                        ++added;
                                } else {
                                        all_added = 0;
                                        break;
                                }
                        }
                }
        }

#ifdef MINIMAL_RESPONSES
        if ((!all_added || buffer_position(query->packet) > minimal_respsize)
            && !query->tcp && minimize_response) {
                /* Truncate entire RRset. */
                buffer_set_position(query->packet, truncation_mark);
                query_clear_dname_offsets(query, truncation_mark);
                added = 0;
                *done = 1;
        }
#endif

        if (!all_added && truncate_rrset) {
                /* Truncate entire RRset and set truncate flag. */
                buffer_set_position(query->packet, truncation_mark);
                query_clear_dname_offsets(query, truncation_mark);
                TC_SET(query->packet);
                added = 0;
        }

        return added;
}

/*
 * Advance the read position of PACKET past one domain name.
 *
 * Labels are skipped until either a zero length octet is read or an
 * octet with any of the two top bits set is read; in the latter case
 * one additional octet is skipped and the name is considered finished.
 *
 * Returns 1 on success and 0 when the packet ends before the name is
 * complete.
 */
int
packet_skip_dname(buffer_type *packet)
{
        for (;;) {
                uint8_t len;

                if (!buffer_available(packet, 1))
                        return 0;
                len = buffer_read_u8(packet);

                /* Empty label: end of the name. */
                if (len == 0)
                        return 1;

                /* Top bits set: two-octet form, consume the second octet. */
                if ((len & 0xc0) != 0) {
                        if (!buffer_available(packet, 1))
                                return 0;
                        buffer_skip(packet, 1);
                        return 1;
                }

                /* Ordinary label: step over its data. */
                if (!buffer_available(packet, len))
                        return 0;
                buffer_skip(packet, len);
        }
}

/*
 * Advance the read position of PACKET past one record.
 *
 * With QUESTION_SECTION nonzero the record consists of a name followed
 * by 4 octets.  Otherwise the name is followed by 8 octets, a 16-bit
 * rdata length and that many octets of rdata.
 *
 * Returns 1 on success, 0 when the packet is too short.
 */
int
packet_skip_rr(buffer_type *packet, int question_section)
{
        uint16_t rdlen;

        if (!packet_skip_dname(packet))
                return 0;

        if (question_section) {
                /* Fixed 4-octet tail after the name. */
                if (!buffer_available(packet, 4))
                        return 0;
                buffer_skip(packet, 4);
                return 1;
        }

        /* 8 octets to skip plus the 2-octet rdata length. */
        if (!buffer_available(packet, 10))
                return 0;
        buffer_skip(packet, 8);
        rdlen = buffer_read_u16(packet);

        if (!buffer_available(packet, rdlen))
                return 0;
        buffer_skip(packet, rdlen);

        return 1;
}

/*
 * Read one record from PACKET at the current position.
 *
 * The owner name is built with dname_make_from_packet() in REGION and
 * inserted into the OWNERS domain table.  For a question-section entry
 * (QUESTION_SECTION nonzero) a record allocated from REGION is returned
 * with only owner, type and class set and ttl/rdlength zeroed.  For any
 * other section the TTL and rdata length are read and the rdata is
 * decoded by the read_rdata callback of the type's descriptor, which
 * supplies the returned record.
 *
 * Returns NULL when the name cannot be read, the packet is too short,
 * or read_rdata reports a negative code.
 */
rr_type *
packet_read_rr(region_type *region, domain_table_type *owners,
               buffer_type *packet, int question_section)
{
        const nsd_type_descriptor_type *type_desc;
        const dname_type *owner_name;
        struct domain* owner_domain;
        uint16_t rrtype, rrclass, rdata_len;
        uint32_t rrttl;
        struct rr *record;
        int32_t status;

        owner_name = dname_make_from_packet(region, packet, 1, 1);
        if (!owner_name) {
                return NULL;
        }
        /* Type and class must both be present. */
        if (!buffer_available(packet, 2*sizeof(uint16_t))) {
                return NULL;
        }

        owner_domain = domain_table_insert(owners, owner_name);
        rrtype = buffer_read_u16(packet);
        rrclass = buffer_read_u16(packet);

        if (question_section) {
                /* Question entries carry no TTL and no rdata. */
                rr_type *question = region_alloc(region, sizeof(*question));
                question->owner = owner_domain;
                question->type = rrtype;
                question->klass = rrclass;
                question->ttl = 0;
                question->rdlength = 0;
                return question;
        }

        /* TTL and rdata length must both be present. */
        if (!buffer_available(packet, sizeof(uint32_t) + sizeof(uint16_t))) {
                return NULL;
        }
        rrttl = buffer_read_u32(packet);
        rdata_len = buffer_read_u16(packet);

        if (!buffer_available(packet, rdata_len)) {
                return NULL;
        }

        type_desc = nsd_type_descriptor(rrtype);
        status = type_desc->read_rdata(owners, rdata_len, packet, &record);
        if (status < 0) {
                return NULL;
        }

        record->owner = owner_domain;
        record->type = rrtype;
        record->klass = rrclass;
        record->ttl = rrttl;

        return record;
}

/*
 * Read the question name, type and class from PACKET, starting at the
 * current position.
 *
 * The name is copied verbatim (length octets included, terminating zero
 * octet included) into DST, which must have room for MAXDOMAINLEN
 * octets.  Names that contain an octet with either of the top two bits
 * set where a length is expected, that run past the end of the packet,
 * or that exceed MAXDOMAINLEN octets are rejected.
 *
 * On success the packet position is left just behind the class field,
 * *qtype and *qclass are filled in and 1 is returned.  On failure 0 is
 * returned; the packet position is unchanged but DST may have been
 * partially written.
 */
int packet_read_query_section(buffer_type *packet,
        uint8_t* dst, uint16_t* qtype, uint16_t* qclass)
{
        uint8_t *query_name = buffer_current(packet);
        uint8_t *src = query_name;
        size_t len;

        /*
         * The very first length octet must lie inside the packet before
         * it is read.  Inside the loop every check guarantees that the
         * next length octet is available, so src stays in front of
         * buffer_end() whenever it is dereferenced.
         */
        if (!buffer_available(packet, 1)) {
                return 0;
        }

        while (*src) {
                size_t label_len = *src;

                /*
                 * Reject if we have a pointer in the question dname, if
                 * the label plus the following length octet would run
                 * past the end of the packet, or if the name would grow
                 * beyond MAXDOMAINLEN.  The comparisons are done on
                 * lengths so no pointer outside the packet is formed.
                 */
                if ((label_len & 0xc0) ||
                    (label_len + 2 > (size_t)(buffer_end(packet) - src)) ||
                    ((size_t)(src - query_name) + label_len + 2 > MAXDOMAINLEN))
                {
                        return 0;
                }
                memcpy(dst, src, label_len + 1);
                dst += label_len + 1;
                src += label_len + 1;
        }
        /* Copy the terminating zero octet. */
        *dst++ = *src++;

        /* Make sure name is not too long or we have stripped packet... */
        len = src - query_name;
        if (len > MAXDOMAINLEN ||
            ((size_t)(buffer_end(packet) - src) < 2*sizeof(uint16_t)))
        {
                return 0;
        }
        buffer_set_position(packet, src - buffer_begin(packet));

        *qtype = buffer_read_u16(packet);
        *qclass = buffer_read_u16(packet);
        return 1;
}

/*
 * Look for the first SOA record behind the question section of PACKET
 * and store its serial number in *SERIAL.
 *
 * Parsing starts at QHEADERSZ.  All question entries are skipped, then
 * up to ANCOUNT + NSCOUNT + ARCOUNT records are examined.  The search
 * stops at the first record of type SOA: if its class is CLASS_IN and
 * its rdata can be parsed (two names followed by a 32-bit value), that
 * value is stored in *SERIAL and 1 is returned.
 *
 * Returns 0 when the counts are implausibly large, when the packet is
 * malformed or when no usable SOA is found.  The packet position is
 * restored to its value on entry in every case.
 */
int packet_find_notify_serial(buffer_type *packet, uint32_t* serial)
{
        const size_t restore_pos = buffer_position(packet);
        /* number of RRs that follow the question section */
        const size_t records = (size_t)ANCOUNT(packet) + (size_t)NSCOUNT(packet) + (size_t)ARCOUNT(packet);
        const size_t questions = (size_t)QDCOUNT(packet);
        size_t n;
        int found = 0;

        buffer_set_position(packet, QHEADERSZ);

        /* Counts this high are treated as a parse error. */
        if (questions > 64 || records > 65530)
                goto done;

        /* Step over the question section. */
        for (n = 0; n < questions; ++n) {
                if (!packet_skip_rr(packet, 1))
                        goto done;
        }

        /* Walk the remaining records until an SOA shows up. */
        for (n = 0; n < records; ++n) {
                uint16_t rdlen;

                if (!packet_skip_dname(packet))
                        break;
                /* type, class, ttl and rdata length: 10 octets */
                if (!buffer_available(packet, 10))
                        break;

                if (buffer_read_u16(packet) != TYPE_SOA) {
                        /* Not an SOA: skip class and ttl, then the rdata. */
                        buffer_skip(packet, 6);
                        rdlen = buffer_read_u16(packet);
                        if (!buffer_available(packet, rdlen))
                                break;
                        buffer_skip(packet, rdlen);
                        continue;
                }

                if (buffer_read_u16(packet) != CLASS_IN)
                        break;
                buffer_skip(packet, 4); /* ttl */
                rdlen = buffer_read_u16(packet);
                if (!buffer_available(packet, rdlen))
                        break;
                /* two names precede the serial in the rdata */
                if (!packet_skip_dname(packet) ||
                        !packet_skip_dname(packet))
                        break;
                if (!buffer_available(packet, 4))
                        break;
                *serial = buffer_read_u32(packet);
                found = 1;
                break;
        }

done:
        buffer_set_position(packet, restore_pos);
        return found;
}