root/security/apparmor/match.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * AppArmor security module
 *
 * This file contains AppArmor dfa based regular expression matching engine
 *
 * Copyright (C) 1998-2008 Novell/SUSE
 * Copyright 2009-2012 Canonical Ltd.
 */

#include <linux/errno.h>
#include <linux/kernel.h>
#include <linux/mm.h>
#include <linux/slab.h>
#include <linux/vmalloc.h>
#include <linux/err.h>
#include <linux/kref.h>
#include <linux/unaligned.h>

#include "include/lib.h"
#include "include/match.h"

/*
 * Extract the next/check table index (low 24 bits) from a base table entry;
 * the remaining bits of a base entry carry state flags such as
 * MATCH_FLAG_DIFF_ENCODE and MATCH_FLAG_OOB_TRANSITION.
 */
#define base_idx(X) ((X) & 0xffffff)

/**
 * unpack_table - unpack a dfa table (one of accept, default, base, next check)
 * @blob: data to unpack (NOT NULL)
 * @bsize: size of blob
 *
 * Reads the big-endian table header at @blob, validates the id, element
 * width and length, and copies the entries into a freshly allocated table
 * converted to CPU byte order.
 *
 * Returns: pointer to table else NULL on failure
 *
 * NOTE: must be freed by kvfree (not kfree)
 */
static struct table_header *unpack_table(char *blob, size_t bsize)
{
        struct table_header hdr;
        struct table_header *tbl;
        size_t total;

        if (bsize < sizeof(struct table_header))
                return NULL;

        /* serialized td_id values are 1-based; store them 0-based so they
         * can be used directly as an index
         */
        hdr.td_id = get_unaligned_be16(blob) - 1;
        if (hdr.td_id > YYTD_ID_MAX)
                return NULL;
        hdr.td_flags = get_unaligned_be16(blob + 2);
        hdr.td_lolen = get_unaligned_be32(blob + 8);
        blob += sizeof(struct table_header);

        if (hdr.td_flags != YYTD_DATA8 && hdr.td_flags != YYTD_DATA16 &&
            hdr.td_flags != YYTD_DATA32)
                return NULL;

        /* a table that is present must hold at least one entry */
        if (!hdr.td_lolen)
                return NULL;
        total = table_size(hdr.td_lolen, hdr.td_flags);
        if (total > bsize)
                return NULL;

        tbl = kvzalloc(total, GFP_KERNEL);
        if (!tbl)
                return NULL;

        tbl->td_id = hdr.td_id;
        tbl->td_flags = hdr.td_flags;
        tbl->td_lolen = hdr.td_lolen;

        switch (hdr.td_flags) {
        case YYTD_DATA8:
                memcpy(tbl->td_data, blob, hdr.td_lolen);
                break;
        case YYTD_DATA16:
                UNPACK_ARRAY(tbl->td_data, blob, hdr.td_lolen,
                             u16, __be16, get_unaligned_be16);
                break;
        case YYTD_DATA32:
                UNPACK_ARRAY(tbl->td_data, blob, hdr.td_lolen,
                             u32, __be32, get_unaligned_be32);
                break;
        default:
                kvfree(tbl);
                return NULL;
        }

        /* the table goes live to all cpus: if it was vmalloced make sure
         * the page tables are synced before it is used
         */
        if (is_vmalloc_addr(tbl))
                vm_unmap_aliases();

        return tbl;
}

/**
 * verify_table_headers - verify that the tables headers are as expected
 * @tables: array of dfa tables to check (NOT NULL)
 * @flags: flags controlling what type of accept table are acceptable
 *
 * Checks that the mandatory tables are present and that tables which are
 * indexed in parallel have matching lengths.
 *
 * Assumes dfa has gone through the first pass verification done by unpacking
 * NOTE: this does not validate accept table values
 *
 * Returns: %0 else -EPROTO on failure to verify
 */
static int verify_table_headers(struct table_header **tables, int flags)
{
        size_t nstates, ntrans;

        /* default, base, next and check tables are mandatory */
        if (!tables[YYTD_ID_DEF] || !tables[YYTD_ID_BASE] ||
            !tables[YYTD_ID_NXT] || !tables[YYTD_ID_CHK])
                return -EPROTO;

        /* per-state tables: accept.size == default.size == base.size */
        nstates = tables[YYTD_ID_BASE]->td_lolen;
        if (ACCEPT1_FLAGS(flags) &&
            (!tables[YYTD_ID_ACCEPT] ||
             tables[YYTD_ID_ACCEPT]->td_lolen != nstates))
                return -EPROTO;
        if (ACCEPT2_FLAGS(flags) &&
            (!tables[YYTD_ID_ACCEPT2] ||
             tables[YYTD_ID_ACCEPT2]->td_lolen != nstates))
                return -EPROTO;
        if (tables[YYTD_ID_DEF]->td_lolen != nstates)
                return -EPROTO;

        /* per-transition tables: next.size == chk.size */
        ntrans = tables[YYTD_ID_NXT]->td_lolen;
        if (tables[YYTD_ID_CHK]->td_lolen != ntrans)
                return -EPROTO;

        /* an equivalence class table must map all 256 byte values */
        if (tables[YYTD_ID_EC] && tables[YYTD_ID_EC]->td_lolen != 256)
                return -EPROTO;

        return 0;
}

/**
 * verify_dfa - verify that transitions and states in the tables are in bounds.
 * @dfa: dfa to test  (NOT NULL)
 *
 * Assumes dfa has gone through the first pass verification done by unpacking
 * (aa_dfa_unpack() runs verify_table_headers() before calling this).
 * NOTE: this does not validate accept table values
 *
 * Side effect: sets MARK_DIFF_ENCODE / MARK_DIFF_ENCODE_VERIFIED bits in the
 * BASE table entries of differentially encoded states.
 *
 * Returns: %0 else error code on failure to verify
 */
static int verify_dfa(struct aa_dfa *dfa)
{
        size_t i, state_count, trans_count;
        int error = -EPROTO;

        state_count = dfa->tables[YYTD_ID_BASE]->td_lolen;
        trans_count = dfa->tables[YYTD_ID_NXT]->td_lolen;
        if (state_count == 0)
                goto out;
        /* per state checks: default target, base flags and base range */
        for (i = 0; i < state_count; i++) {
                if (DEFAULT_TABLE(dfa)[i] >= state_count) {
                        pr_err("AppArmor DFA default state out of bounds");
                        goto out;
                }
                if (BASE_TABLE(dfa)[i] & MATCH_FLAGS_INVALID) {
                        pr_err("AppArmor DFA state with invalid match flags");
                        goto out;
                }
                if ((BASE_TABLE(dfa)[i] & MATCH_FLAG_DIFF_ENCODE)) {
                        if (!(dfa->flags & YYTH_FLAG_DIFF_ENCODE)) {
                                pr_err("AppArmor DFA diff encoded transition state without header flag");
                                goto out;
                        }
                }
                if ((BASE_TABLE(dfa)[i] & MATCH_FLAG_OOB_TRANSITION)) {
                        /* out of band transitions are looked up below
                         * base_idx() (see aa_dfa_outofband_transition()),
                         * so the index must leave room for max_oob of them
                         */
                        if (base_idx(BASE_TABLE(dfa)[i]) < dfa->max_oob) {
                                pr_err("AppArmor DFA out of bad transition out of range");
                                goto out;
                        }
                        if (!(dfa->flags & YYTH_FLAG_OOB_TRANS)) {
                                pr_err("AppArmor DFA out of bad transition state without header flag");
                                goto out;
                        }
                }
                /* base_idx() plus any byte value (0..255) must stay inside
                 * the next/check tables
                 */
                if (base_idx(BASE_TABLE(dfa)[i]) + 255 >= trans_count) {
                        pr_err("AppArmor DFA next/check upper bounds error\n");
                        goto out;
                }
        }

        /* every transition must target, and be owned by, a valid state */
        for (i = 0; i < trans_count; i++) {
                if (NEXT_TABLE(dfa)[i] >= state_count)
                        goto out;
                if (CHECK_TABLE(dfa)[i] >= state_count)
                        goto out;
        }

        /* Now that all the other tables are verified, verify diffencoding */
        /*
         * For each diff encoded state, walk its chain of default states.
         * MARK_DIFF_ENCODE tags states on the chain currently being walked
         * (meeting one again means the chain loops); the walk stops at a
         * state that is not diff encoded or is already VERIFIED. The second
         * walk then converts this chain's marks to MARK_DIFF_ENCODE_VERIFIED.
         */
        for (i = 0; i < state_count; i++) {
                size_t j, k;

                for (j = i;
                     ((BASE_TABLE(dfa)[j] & MATCH_FLAG_DIFF_ENCODE) &&
                      !(BASE_TABLE(dfa)[j] & MARK_DIFF_ENCODE_VERIFIED));
                     j = k) {
                        if (BASE_TABLE(dfa)[j] & MARK_DIFF_ENCODE)
                                /* loop in current chain */
                                goto out;
                        k = DEFAULT_TABLE(dfa)[j];
                        if (j == k)
                                /* self loop */
                                goto out;
                        BASE_TABLE(dfa)[j] |= MARK_DIFF_ENCODE;
                }
                /* move mark to verified */
                for (j = i;
                     (BASE_TABLE(dfa)[j] & MATCH_FLAG_DIFF_ENCODE);
                     j = k) {
                        k = DEFAULT_TABLE(dfa)[j];
                        if (j < i)
                                /* jumps to state/chain that has been
                                 * verified
                                 */
                                break;
                        BASE_TABLE(dfa)[j] &= ~MARK_DIFF_ENCODE;
                        BASE_TABLE(dfa)[j] |= MARK_DIFF_ENCODE_VERIFIED;
                }
        }
        error = 0;

out:
        return error;
}

/**
 * dfa_free - free a dfa allocated by aa_dfa_unpack
 * @dfa: the dfa to free  (MAYBE NULL)
 *
 * Requires: reference count to dfa == 0
 */
static void dfa_free(struct aa_dfa *dfa)
{
        if (dfa) {
                int i;

                for (i = 0; i < ARRAY_SIZE(dfa->tables); i++) {
                        kvfree(dfa->tables[i]);
                        dfa->tables[i] = NULL;
                }
                kfree(dfa);
        }
}

/**
 * aa_dfa_free_kref - free aa_dfa by kref (called by aa_put_dfa)
 * @kref: kref callback for freeing of a dfa  (NOT NULL)
 *
 * Recovers the enclosing dfa from its embedded refcount and frees it.
 */
void aa_dfa_free_kref(struct kref *kref)
{
        dfa_free(container_of(kref, struct aa_dfa, count));
}



/**
 * remap_data16_to_data32 - remap u16 @old table to a u32 based table
 * @old: table to remap
 *
 * Returns: new table with u32 entries instead of u16, or NULL if the
 * allocation failed.
 *
 * Note: will free @old in every case (success or failure) so caller does
 * not have to
 */
static struct table_header *remap_data16_to_data32(struct table_header *old)
{
        struct table_header *wide;
        u32 idx;

        wide = kvzalloc(table_size(old->td_lolen, YYTD_DATA32), GFP_KERNEL);
        if (wide) {
                wide->td_id = old->td_id;
                wide->td_flags = YYTD_DATA32;
                wide->td_lolen = old->td_lolen;
                for (idx = 0; idx < old->td_lolen; idx++)
                        TABLE_DATAU32(wide)[idx] = TABLE_DATAU16(old)[idx];
        }

        kvfree(old);

        /* sync page tables before a vmalloced table goes live */
        if (wide && is_vmalloc_addr(wide))
                vm_unmap_aliases();

        return wide;
}

/**
 * aa_dfa_unpack - unpack the binary tables of a serialized dfa
 * @blob: aligned serialized stream of data to unpack  (NOT NULL)
 * @size: size of data to unpack
 * @flags: flags controlling what type of accept tables are acceptable
 *
 * Unpack a dfa that has been serialized.  To find information on the dfa
 * format look in Documentation/admin-guide/LSM/apparmor.rst
 * Assumes the dfa @blob stream has been aligned on a 8 byte boundary
 *
 * Returns: an unpacked dfa ready for matching or ERR_PTR on failure
 *          (-EPROTO for a malformed stream, -ENOMEM when an allocation
 *          outside of table unpacking fails)
 */
struct aa_dfa *aa_dfa_unpack(void *blob, size_t size, int flags)
{
        size_t hsize;
        int error = -ENOMEM;
        char *data = blob;
        struct table_header *table = NULL;
        struct aa_dfa *dfa = kzalloc_obj(struct aa_dfa);
        if (!dfa)
                goto fail;

        kref_init(&dfa->count);

        error = -EPROTO;

        /* get dfa table set header */
        if (size < sizeof(struct table_set_header))
                goto fail;

        if (get_unaligned_be32(data) != YYTH_MAGIC)
                goto fail;

        /* serialized as u32; keep it unsigned rather than storing in an int */
        hsize = get_unaligned_be32(data + 4);
        if (size < hsize)
                goto fail;

        dfa->flags = get_unaligned_be16(data + 12);
        if (dfa->flags & ~(YYTH_FLAGS))
                goto fail;

        /*
         * TODO: needed for dfa to support more than 1 oob
         * if (dfa->flags & YYTH_FLAGS_OOB_TRANS) {
         *      if (hsize < 16 + 4)
         *              goto fail;
         *      dfa->max_oob = get_unaligned_be32(data + 16);
         *      if (dfa->max <= MAX_OOB_SUPPORTED) {
         *              pr_err("AppArmor DFA OOB greater than supported\n");
         *              goto fail;
         *      }
         * }
         */
        dfa->max_oob = 1;

        data += hsize;
        size -= hsize;

        while (size > 0) {
                size_t tsize;

                table = unpack_table(data, size);
                if (!table)
                        goto fail;

                switch (table->td_id) {
                case YYTD_ID_ACCEPT:
                        if (!(table->td_flags & ACCEPT1_FLAGS(flags)))
                                goto fail;
                        break;
                case YYTD_ID_ACCEPT2:
                        if (!(table->td_flags & ACCEPT2_FLAGS(flags)))
                                goto fail;
                        break;
                case YYTD_ID_BASE:
                        if (table->td_flags != YYTD_DATA32)
                                goto fail;
                        break;
                case YYTD_ID_DEF:
                case YYTD_ID_NXT:
                case YYTD_ID_CHK:
                        if (!(table->td_flags == YYTD_DATA16 ||
                              table->td_flags == YYTD_DATA32)) {
                                goto fail;
                        }
                        break;
                case YYTD_ID_EC:
                        if (table->td_flags != YYTD_DATA8)
                                goto fail;
                        break;
                default:
                        goto fail;
                }
                /* check for duplicate table entry */
                if (dfa->tables[table->td_id])
                        goto fail;

                /*
                 * advance by the serialized size before any remap changes
                 * td_flags; unpack_table() already checked that it fits in
                 * @size, so this cannot underflow
                 */
                tsize = table_size(table->td_lolen, table->td_flags);
                data += tsize;
                size -= tsize;

                /*
                 * for now straight remap, later have dfa support both.
                 * remap_data16_to_data32() frees @table whether or not it
                 * succeeds, so the table is only published in dfa->tables
                 * once it is final. Publishing it before the remap left a
                 * dangling pointer that dfa_free() freed a second time when
                 * the remap allocation failed.
                 */
                switch (table->td_id) {
                case YYTD_ID_DEF:
                case YYTD_ID_NXT:
                case YYTD_ID_CHK:
                        if (table->td_flags == YYTD_DATA16) {
                                table = remap_data16_to_data32(table);
                                if (!table) {
                                        error = -ENOMEM;
                                        goto fail;
                                }
                        }
                        break;
                }
                dfa->tables[table->td_id] = table;
                /* ownership moved to dfa; don't free it again on fail */
                table = NULL;
        }
        error = verify_table_headers(dfa->tables, flags);
        if (error)
                goto fail;

        if (flags & DFA_FLAG_VERIFY_STATES) {
                error = verify_dfa(dfa);
                if (error)
                        goto fail;
        }

        return dfa;

fail:
        /* @table is only non-NULL while it is not yet owned by dfa */
        kvfree(table);
        dfa_free(dfa);
        return ERR_PTR(error);
}

/*
 * match_char - take one transition in the dfa on input symbol @C
 *
 * Updates @state in place. pos = base_idx(base[state]) + @C selects a
 * next/check slot; if the check table says that slot belongs to @state,
 * @state becomes next[pos]. Otherwise @state moves to def[state]; when the
 * state being left has MATCH_FLAG_DIFF_ENCODE set the lookup of @C is
 * retried from that default state (the do/while(1) loop), otherwise the
 * default state is the result.
 *
 * NOTE: @C is re-evaluated on every retry, so it must not have side
 * effects. The break/continue statements bind to this macro's own loop,
 * not to a loop the macro is used in.
 */
#define match_char(state, def, base, next, check, C)    \
do {                                                    \
        u32 b = (base)[(state)];                        \
        unsigned int pos = base_idx(b) + (C);           \
        if ((check)[pos] != (state)) {                  \
                (state) = (def)[(state)];               \
                if (b & MATCH_FLAG_DIFF_ENCODE)         \
                        continue;                       \
                break;                                  \
        }                                               \
        (state) = (next)[pos];                          \
        break;                                          \
} while (1)

/**
 * aa_dfa_match_len - traverse @dfa to find state @str stops at
 * @dfa: the dfa to match @str against  (NOT NULL)
 * @start: the state of the dfa to start matching in
 * @str: the string of bytes to match against the dfa  (NOT NULL)
 * @len: length of the string of bytes to match
 *
 * Feeds exactly @len bytes of @str through the dfa, including any 0 bytes
 * (the input is not treated as NUL terminated), and returns the state the
 * walk ends in. That state can be used to look up the accepting label or
 * as the start state of a continuing match.
 *
 * Returns: final state reached after input is consumed, or DFA_NOMATCH
 *          when @start is DFA_NOMATCH
 */
aa_state_t aa_dfa_match_len(struct aa_dfa *dfa, aa_state_t start,
                            const char *str, int len)
{
        u32 *def = DEFAULT_TABLE(dfa);
        u32 *base = BASE_TABLE(dfa);
        u32 *next = NEXT_TABLE(dfa);
        u32 *check = CHECK_TABLE(dfa);
        aa_state_t state = start;

        if (state == DFA_NOMATCH)
                return DFA_NOMATCH;

        if (!dfa->tables[YYTD_ID_EC]) {
                /* no equivalence classes: each byte is its own symbol */
                for (; len != 0; len--, str++)
                        match_char(state, def, base, next, check, (u8) *str);
        } else {
                /* collapse each byte to its equivalence class first */
                u8 *ec = EQUIV_TABLE(dfa);

                for (; len != 0; len--, str++) {
                        u8 cls = ec[(u8) *str];

                        match_char(state, def, base, next, check, cls);
                }
        }

        return state;
}

/**
 * aa_dfa_match - traverse @dfa to find state @str stops at
 * @dfa: the dfa to match @str against  (NOT NULL)
 * @start: the state of the dfa to start matching in
 * @str: the null terminated string of bytes to match against the dfa (NOT NULL)
 *
 * Feeds every byte of @str up to (not including) its terminating NUL
 * through the dfa and returns the state the walk ends in. That state can be
 * used to look up the accepting label or as the start state of a
 * continuing match.
 *
 * Returns: final state reached after input is consumed, or DFA_NOMATCH
 *          when @start is DFA_NOMATCH
 */
aa_state_t aa_dfa_match(struct aa_dfa *dfa, aa_state_t start, const char *str)
{
        u32 *def = DEFAULT_TABLE(dfa);
        u32 *base = BASE_TABLE(dfa);
        u32 *next = NEXT_TABLE(dfa);
        u32 *check = CHECK_TABLE(dfa);
        aa_state_t state = start;

        if (state == DFA_NOMATCH)
                return DFA_NOMATCH;

        if (!dfa->tables[YYTD_ID_EC]) {
                /* no equivalence classes: each byte is its own symbol */
                for (; *str; str++)
                        match_char(state, def, base, next, check, (u8) *str);
        } else {
                /* collapse each byte to its equivalence class first */
                u8 *ec = EQUIV_TABLE(dfa);

                for (; *str; str++) {
                        u8 cls = ec[(u8) *str];

                        match_char(state, def, base, next, check, cls);
                }
        }

        return state;
}

/**
 * aa_dfa_next - step one character to the next state in the dfa
 * @dfa: the dfa to traverse (NOT NULL)
 * @state: the state to start in
 * @c: the input character to transition on
 *
 * aa_dfa_next steps through the dfa by the single input character @c,
 * mapping it through the equivalence class table when one is present.
 *
 * Returns: state reached after input @c
 */
aa_state_t aa_dfa_next(struct aa_dfa *dfa, aa_state_t state, const char c)
{
        u32 *def = DEFAULT_TABLE(dfa);
        u32 *base = BASE_TABLE(dfa);
        u32 *next = NEXT_TABLE(dfa);
        u32 *check = CHECK_TABLE(dfa);
        u8 sym = (u8) c;

        if (dfa->tables[YYTD_ID_EC])
                sym = EQUIV_TABLE(dfa)[sym];

        match_char(state, def, base, next, check, sym);

        return state;
}

/**
 * aa_dfa_outofband_transition - take the out of band transition from @state
 * @dfa: the dfa to traverse (NOT NULL)
 * @state: the state to transition from
 *
 * The out of band transition uses input symbol -1, i.e. the next/check slot
 * just below base_idx() of the state; verify_dfa() requires flagged states
 * to have base_idx() >= dfa->max_oob. The symbol is not remapped through
 * the equivalence class table.
 *
 * NOTE(review): if @state is diff encoded and its check fails, match_char()
 * retries -1 from the default state, which is not required to carry
 * MATCH_FLAG_OOB_TRANSITION; confirm such default states always have
 * base_idx() >= 1.
 *
 * Returns: the state reached, or DFA_NOMATCH if @state has no out of band
 *          transition
 */
aa_state_t aa_dfa_outofband_transition(struct aa_dfa *dfa, aa_state_t state)
{
        u32 *def = DEFAULT_TABLE(dfa);
        u32 *base = BASE_TABLE(dfa);
        u32 *next = NEXT_TABLE(dfa);
        u32 *check = CHECK_TABLE(dfa);

        if (!(base[state] & MATCH_FLAG_OOB_TRANSITION))
                return DFA_NOMATCH;

        match_char(state, def, base, next, check, -1);

        return state;
}

/**
 * aa_dfa_match_until - traverse @dfa until accept state or end of input
 * @dfa: the dfa to match @str against  (NOT NULL)
 * @start: the state of the dfa to start matching in
 * @str: the null terminated string of bytes to match against the dfa (NOT NULL)
 * @retpos: first character in str after match OR end of string
 *          (set to @str when @start is DFA_NOMATCH)
 *
 * aa_dfa_match_until will match @str against the dfa, stopping right after
 * the first character that moves the dfa into a state with a non-zero
 * accept entry, and return the state it finished matching in. The final
 * state can be used to look up the accepting label, or as the start state
 * of a continuing match.
 *
 * Transitions are taken with match_char() so differentially encoded states
 * are followed the same way as in aa_dfa_match().
 *
 * Returns: final state reached after input is consumed
 */
aa_state_t aa_dfa_match_until(struct aa_dfa *dfa, aa_state_t start,
                                const char *str, const char **retpos)
{
        u32 *def = DEFAULT_TABLE(dfa);
        u32 *base = BASE_TABLE(dfa);
        u32 *next = NEXT_TABLE(dfa);
        u32 *check = CHECK_TABLE(dfa);
        u32 *accept = ACCEPT_TABLE(dfa);
        aa_state_t state = start;

        if (state == DFA_NOMATCH) {
                /* nothing consumed; never leave *retpos unset */
                *retpos = str;
                return DFA_NOMATCH;
        }

        /* current state is <state>, matching character *str */
        if (dfa->tables[YYTD_ID_EC]) {
                /* Equivalence class table defined */
                u8 *equiv = EQUIV_TABLE(dfa);

                while (*str) {
                        u8 c = equiv[(u8) *str];

                        match_char(state, def, base, next, check, c);
                        str++;
                        if (accept[state])
                                break;
                }
        } else {
                /* default is direct to next state */
                while (*str) {
                        match_char(state, def, base, next, check, (u8) *str);
                        str++;
                        if (accept[state])
                                break;
                }
        }

        *retpos = str;
        return state;
}

/**
 * aa_dfa_matchn_until - traverse @dfa until accept or @n bytes consumed
 * @dfa: the dfa to match @str against  (NOT NULL)
 * @start: the state of the dfa to start matching in
 * @str: the string of bytes to match against the dfa  (NOT NULL)
 * @n: length of the string of bytes to match
 * @retpos: first character in str after match OR str + n
 *          (set to NULL when @start is DFA_NOMATCH)
 *
 * aa_dfa_matchn_until will match @str against the dfa, stopping right after
 * the first character that moves the dfa into a state with a non-zero
 * accept entry, and return the state it finished matching in. The final
 * state can be used to look up the accepting label, or as the start state
 * of a continuing match.
 *
 * This function will happily match again the 0 byte and only finishes
 * when @n input is consumed or an accepting state is reached.
 *
 * Transitions are taken with match_char() so differentially encoded states
 * are followed the same way as in aa_dfa_match().
 *
 * Returns: final state reached after input is consumed
 */
aa_state_t aa_dfa_matchn_until(struct aa_dfa *dfa, aa_state_t start,
                                 const char *str, int n, const char **retpos)
{
        u32 *def = DEFAULT_TABLE(dfa);
        u32 *base = BASE_TABLE(dfa);
        u32 *next = NEXT_TABLE(dfa);
        u32 *check = CHECK_TABLE(dfa);
        u32 *accept = ACCEPT_TABLE(dfa);
        aa_state_t state = start;

        *retpos = NULL;
        if (state == DFA_NOMATCH)
                return DFA_NOMATCH;

        /* current state is <state>, matching character *str */
        if (dfa->tables[YYTD_ID_EC]) {
                /* Equivalence class table defined */
                u8 *equiv = EQUIV_TABLE(dfa);

                for (; n; n--) {
                        u8 c = equiv[(u8) *str];

                        match_char(state, def, base, next, check, c);
                        str++;
                        if (accept[state])
                                break;
                }
        } else {
                /* default is direct to next state */
                for (; n; n--) {
                        match_char(state, def, base, next, check, (u8) *str);
                        str++;
                        if (accept[state])
                                break;
                }
        }

        *retpos = str;
        return state;
}

/*
 * inc_wb_pos - advance the history ring buffer position of @wb by one slot
 * and grow the recorded history length, saturating at WB_HISTORY_SIZE.
 * @wb is evaluated several times, so it must be free of side effects; it is
 * parenthesized so any pointer expression (e.g. &buf) can be passed.
 */
#define inc_wb_pos(wb)                                                  \
do {                                                                    \
        BUILD_BUG_ON_NOT_POWER_OF_2(WB_HISTORY_SIZE);                   \
        (wb)->pos = ((wb)->pos + 1) & (WB_HISTORY_SIZE - 1);            \
        (wb)->len = ((wb)->len + 1) > WB_HISTORY_SIZE ?                 \
                        WB_HISTORY_SIZE : (wb)->len + 1;                \
} while (0)

/*
 * is_loop - look for @state in the recent match history
 * (fallback for DFAs that don't support extended tagging of states)
 * @wb: history ring buffer; wb->history[wb->pos] holds the most recently
 *      recorded state
 * @state: the state just entered
 * @adjust: set to how many slots back @state was found; only written when
 *          true is returned
 *
 * Returns false straight away when the most recent history entry is
 * smaller than @state; otherwise scans back over wb->len slots.
 *
 * Returns: true if @state was found in the scanned history
 */
static bool is_loop(struct match_workbuf *wb, aa_state_t state,
                    unsigned int *adjust)
{
        int slot = wb->pos;
        int back = 0;

        if (wb->history[slot] < state)
                return false;

        while (back < wb->len) {
                if (wb->history[slot] == state) {
                        *adjust = back;
                        return true;
                }
                /* step one slot back; -1 wraps to WB_HISTORY_SIZE - 1 */
                slot = (slot - 1) & (WB_HISTORY_SIZE - 1);
                back++;
        }

        return false;
}

/*
 * leftmatch_fb - match @str while tracking how many characters were
 * consumed before the match settled into a loop
 * @dfa: the dfa to match @str against  (NOT NULL)
 * @start: the state of the dfa to start matching in
 * @str: the null terminated string of bytes to match (NOT NULL)
 * @wb: history work buffer used for loop detection (NOT NULL)
 * @count: out: characters consumed, minus the loop adjustment reported by
 *         is_loop(); 0 when the final state is DFA_NOMATCH (NOT NULL)
 *
 * Once is_loop() reports the current state was seen recently, the rest of
 * @str is matched with aa_dfa_match() without further counting.
 *
 * Transitions are taken with match_char() so differentially encoded states
 * are followed the same way as in aa_dfa_match().
 *
 * Returns: final state reached after input is consumed
 */
static aa_state_t leftmatch_fb(struct aa_dfa *dfa, aa_state_t start,
                                 const char *str, struct match_workbuf *wb,
                                 unsigned int *count)
{
        u32 *def, *base, *next, *check;
        aa_state_t state = start;

        AA_BUG(!dfa);
        AA_BUG(!str);
        AA_BUG(!wb);
        AA_BUG(!count);

        *count = 0;
        if (state == DFA_NOMATCH)
                return DFA_NOMATCH;

        /* only dereference @dfa after the sanity checks above */
        def = DEFAULT_TABLE(dfa);
        base = BASE_TABLE(dfa);
        next = NEXT_TABLE(dfa);
        check = CHECK_TABLE(dfa);

        /* current state is <state>, matching character *str */
        if (dfa->tables[YYTD_ID_EC]) {
                /* Equivalence class table defined */
                u8 *equiv = EQUIV_TABLE(dfa);

                while (*str) {
                        unsigned int adjust;
                        u8 c = equiv[(u8) *str];

                        wb->history[wb->pos] = state;
                        match_char(state, def, base, next, check, c);
                        str++;
                        if (is_loop(wb, state, &adjust)) {
                                state = aa_dfa_match(dfa, state, str);
                                *count -= adjust;
                                goto out;
                        }
                        inc_wb_pos(wb);
                        (*count)++;
                }
        } else {
                /* default is direct to next state */
                while (*str) {
                        unsigned int adjust;

                        wb->history[wb->pos] = state;
                        match_char(state, def, base, next, check, (u8) *str);
                        str++;
                        if (is_loop(wb, state, &adjust)) {
                                state = aa_dfa_match(dfa, state, str);
                                *count -= adjust;
                                goto out;
                        }
                        inc_wb_pos(wb);
                        (*count)++;
                }
        }

out:
        if (!state)
                *count = 0;
        return state;
}

/**
 * aa_dfa_leftmatch - traverse @dfa to find state @str stops at
 * @dfa: the dfa to match @str against  (NOT NULL)
 * @start: the state of the dfa to start matching in
 * @str: the null terminated string of bytes to match against the dfa (NOT NULL)
 * @count: current count of longest left.
 *
 * aa_dfa_leftmatch matches @str against the dfa using a fresh on-stack
 * history work buffer and returns the state it finished matching in. The
 * final state can be used to look up the accepting label, or as the start
 * state of a continuing match.
 *
 * Returns: final state reached after input is consumed
 */
aa_state_t aa_dfa_leftmatch(struct aa_dfa *dfa, aa_state_t start,
                            const char *str, unsigned int *count)
{
        /* TODO: match for extended state dfas */
        DEFINE_MATCH_WB(workbuf);

        return leftmatch_fb(dfa, start, str, &workbuf, count);
}