#include <assert.h>
#include <stdlib.h>
#include "smatch.h"
#include "smatch_slist.h"
/* Internal AVL helpers; definitions follow later in this file. */
static AvlNode *mkNode(const struct sm_state *sm);
static void freeNode(AvlNode *node);
static AvlNode *lookup(const struct stree *avl, AvlNode *node, const struct sm_state *sm);
static bool insert_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm);
static bool remove_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret);
static bool removeExtremum(AvlNode **p, int side, AvlNode **ret);
static int sway(AvlNode **p, int sway);
static void balance(AvlNode **p, int side);
static bool checkBalances(AvlNode *node, int *height);
static bool checkOrder(struct stree *avl);
static size_t countNode(AvlNode *node);
/*
 * Count of strees created by avl_new() whose last reference has not yet
 * been dropped by free_stree().
 */
int unfree_stree;
/*
 * bal(side): child index -> balance step (0 -> -1, anything else -> +1).
 * side(bal): balance step -> child index (1 -> 1, anything else -> 0).
 * A node's balance is height(lr[1]) - height(lr[0]) (see checkBalances).
 * These are function-like macros, so locals named `side` or `bal` used
 * without a following '(' are not expanded.
 */
#define bal(side) ((side) == 0 ? -1 : 1)
#define side(bal) ((bal) == 1 ? 1 : 0)
/*
 * Allocate an empty stree holding a single reference.
 *
 * has_states is a zeroed flag array with num_checks + 1 entries.
 * avl_insert() sets the entry for each sm->owner it stores, and
 * avl_lookup() uses it to skip the tree walk for owners never inserted.
 *
 * Allocation failure is fatal (assert), for both the stree and its
 * has_states array.
 */
static struct stree *avl_new(void)
{
	struct stree *avl = malloc(sizeof(*avl));

	assert(avl != NULL);
	unfree_stree++;

	avl->root = NULL;
	avl->base_stree = NULL;
	avl->has_states = calloc(num_checks + 1, sizeof(*avl->has_states));
	/* Previously unchecked: a NULL here crashed later in avl_insert(). */
	assert(avl->has_states != NULL);
	avl->count = 0;
	avl->stree_id = 0;
	avl->references = 1;
	return avl;
}
/*
 * Drop one reference to *avl and always clear the caller's pointer.
 *
 * When the last reference goes away, three things are released: the node
 * tree, the has_states array allocated by avl_new(), and the stree itself.
 * The sm_state objects referenced by the nodes are not freed.
 * A NULL *avl is a no-op.
 */
void free_stree(struct stree **avl)
{
	struct stree *tree = *avl;

	if (!tree)
		return;

	assert(tree->references > 0);

	/* The caller's handle is gone whether or not this was the last ref. */
	*avl = NULL;
	if (--tree->references != 0)
		return;

	unfree_stree--;
	freeNode(tree->root);
	/* Allocated in avl_new(); it used to leak on every final release. */
	free(tree->has_states);
	free(tree);
}
/*
 * Return the entry in avl that compares equal to sm under cmp_tracker().
 *
 * Returns NULL in three cases: the stree is NULL, sm's owner has never
 * been inserted (its has_states flag is clear), or no matching entry
 * exists.
 */
struct sm_state *avl_lookup(const struct stree *avl, const struct sm_state *sm)
{
	AvlNode *hit;

	if (avl == NULL)
		return NULL;

	/* Fast reject: this owner never inserted anything into this stree. */
	if (sm->owner != USHRT_MAX && avl->has_states[sm->owner] == 0)
		return NULL;

	hit = lookup(avl, avl->root, sm);
	return hit != NULL ? (struct sm_state *)hit->sm : NULL;
}
/*
 * Return the tree node whose entry compares equal to sm, or NULL if none.
 *
 * Unlike avl_lookup(), this applies no has_states shortcut. A NULL stree
 * (the representation of an empty stree in this file) now yields NULL
 * instead of being dereferenced.
 */
AvlNode *avl_lookup_node(const struct stree *avl, const struct sm_state *sm)
{
	if (!avl)
		return NULL;
	return lookup(avl, avl->root, sm);
}
/* Number of entries stored in avl; a NULL stree counts as empty. */
size_t stree_count(const struct stree *avl)
{
	return avl != NULL ? avl->count : 0;
}
/*
 * Build a private copy of orig's entries.
 *
 * The copy has new nodes, but the sm_state pointers are shared with orig.
 * It starts with one reference and stree_id 0; only base_stree is carried
 * over from orig. orig must be non-NULL (orig->base_stree is read).
 */
static struct stree *clone_stree_real(struct stree *orig)
{
	struct stree *copy = avl_new();
	AvlIter it;

	avl_foreach(it, orig)
		avl_insert(&copy, it.sm);

	copy->base_stree = orig->base_stree;
	return copy;
}
/*
 * Insert sm into *avl, replacing any entry that compares equal.
 *
 * A NULL *avl is first replaced by a fresh stree. If *avl is shared
 * (references > 1), the caller's reference is released and *avl is
 * repointed at a private copy before modifying, so other holders keep the
 * old contents.
 *
 * Returns true if a new entry was added, false if an existing one was
 * replaced.
 */
bool avl_insert(struct stree **avl, const struct sm_state *sm)
{
	struct stree *tree;
	size_t before;

	if (*avl == NULL)
		*avl = avl_new();

	tree = *avl;
	if (tree->references > 1) {
		tree->references--;
		tree = clone_stree_real(tree);
		*avl = tree;
	}

	before = tree->count;
	if (sm->owner != USHRT_MAX)
		tree->has_states[sm->owner] = 1;
	insert_sm(tree, &tree->root, sm);

	return tree->count != before;
}
/*
 * Remove the entry comparing equal to sm from *avl.
 *
 * If *avl is shared, the caller gets a private copy first, as in
 * avl_insert(). When the last entry is removed, the stree is released and
 * *avl becomes NULL.
 *
 * Returns true if an entry was removed.
 */
bool avl_remove(struct stree **avl, const struct sm_state *sm)
{
	AvlNode *removed = NULL;
	struct stree *tree = *avl;

	if (tree == NULL)
		return false;

	if (tree->references > 1) {
		tree->references--;
		tree = clone_stree_real(tree);
		*avl = tree;
	}

	remove_sm(tree, &tree->root, sm, &removed);

	/* Empty strees are represented by NULL. */
	if (tree->count == 0)
		free_stree(avl);

	if (removed == NULL)
		return false;

	/* The node is already unlinked from the tree; only it is freed. */
	free(removed);
	return true;
}
/* Allocate a balanced leaf node holding sm; allocation failure is fatal. */
static AvlNode *mkNode(const struct sm_state *sm)
{
	AvlNode *leaf = malloc(sizeof(*leaf));

	assert(leaf != NULL);
	leaf->lr[0] = leaf->lr[1] = NULL;
	leaf->balance = 0;
	leaf->sm = sm;
	return leaf;
}
/*
 * Free node and both subtrees, children first. The sm_state objects the
 * nodes point to are not touched. NULL is a no-op.
 */
static void freeNode(AvlNode *node)
{
	if (node == NULL)
		return;

	freeNode(node->lr[0]);
	freeNode(node->lr[1]);
	free(node);
}
/*
 * Binary-search the subtree rooted at node for an entry comparing equal to
 * sm under cmp_tracker(). Returns the node, or NULL if none is found.
 * avl is unused; it is kept for signature compatibility.
 */
static AvlNode *lookup(const struct stree *avl, AvlNode *node, const struct sm_state *sm)
{
	(void)avl;

	while (node != NULL) {
		int cmp = cmp_tracker(sm, node->sm);

		if (cmp == 0)
			return node;
		node = node->lr[cmp < 0 ? 0 : 1];
	}
	return NULL;
}
/*
 * Recursive AVL insertion of sm into the subtree rooted at *p.
 *
 * If an entry comparing equal already exists, its sm pointer is replaced
 * and avl->count is unchanged. Otherwise a new leaf is linked in,
 * avl->count is incremented, and balances are fixed on the way back up.
 *
 * Returns true if the subtree's height grew, in which case the caller must
 * adjust its own balance.
 */
static bool insert_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm)
{
	AvlNode *node = *p;
	int cmp;

	if (node == NULL) {
		*p = mkNode(sm);
		avl->count++;
		return true;
	}

	cmp = cmp_tracker(sm, node->sm);
	if (cmp == 0) {
		node->sm = sm;
		return false;
	}

	/*
	 * side() and sway() are only correct for -1/+1. For example, side(2)
	 * would descend left and sway(p, 2) would corrupt the balance factor.
	 * Normalise so that any negative/positive comparator result is safe.
	 */
	cmp = cmp < 0 ? -1 : 1;

	if (!insert_sm(avl, &node->lr[side(cmp)], sm))
		return false;

	/* Child subtree grew; this subtree grew unless it is now level. */
	return sway(p, cmp) != 0;
}
/*
 * Recursive AVL removal of the entry comparing equal to sm from the
 * subtree rooted at *p.
 *
 * On success the unlinked node is stored in *ret; the caller frees it.
 * avl->count is decremented. *ret is left untouched when no entry matches.
 *
 * Returns true if the subtree's height shrank.
 */
static bool remove_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret)
{
	AvlNode *node;
	int cmp;

	if (p == NULL || *p == NULL)
		return false;

	node = *p;
	cmp = cmp_tracker(sm, node->sm);
	if (cmp != 0) {
		/* side()/sway() require -1/+1; see insert_sm(). */
		cmp = cmp < 0 ? -1 : 1;
		if (!remove_sm(avl, &node->lr[side(cmp)], sm, ret))
			return false;
		/* Child shrank; this subtree shrank if it is now level. */
		return sway(p, -cmp) == 0;
	}

	*ret = node;
	avl->count--;

	if (node->lr[0] != NULL && node->lr[1] != NULL) {
		AvlNode *replacement;
		int side;
		bool shrunk;

		/*
		 * Take the replacement from the taller (or left, on a tie)
		 * subtree, so this position never needs a rotation.
		 */
		side = node->balance <= 0 ? 0 : 1;
		shrunk = removeExtremum(&node->lr[side], 1 - side, &replacement);

		replacement->lr[0] = node->lr[0];
		replacement->lr[1] = node->lr[1];
		replacement->balance = node->balance;
		*p = replacement;

		if (!shrunk)
			return false;
		replacement->balance -= bal(side);
		/* A balance of 0 here means this subtree's height dropped. */
		return replacement->balance == 0;
	}

	/* At most one child: splice it (or NULL) into this position. */
	*p = node->lr[0] != NULL ? node->lr[0] : node->lr[1];
	return true;
}
/*
 * Unlink the extreme node on `side` of the subtree rooted at *p
 * (side 0 = leftmost, 1 = rightmost). Its child on the other side
 * (possibly NULL) takes its place. The unlinked node is returned through
 * *ret. *p must be non-NULL.
 *
 * Returns true if the subtree's height shrank.
 */
static bool removeExtremum(AvlNode **p, int side, AvlNode **ret)
{
	AvlNode *cur = *p;

	if (cur->lr[side] != NULL) {
		if (removeExtremum(&cur->lr[side], side, ret) == false)
			return false;
		/* The `side` subtree shrank; we shrank if we are now level. */
		return sway(p, -bal(side)) == 0;
	}

	*ret = cur;
	*p = cur->lr[1 - side];
	return true;
}
/*
 * Shift the balance of *p by `sway` (expected to be -1 or +1).
 *
 * If the node is already leaning by exactly `sway`, it is rotated via
 * balance() instead. Returns the balance of whichever node now occupies
 * *p.
 */
static int sway(AvlNode **p, int sway)
{
	AvlNode *node = *p;

	if (node->balance == sway)
		balance(p, side(sway));
	else
		node->balance += sway;

	/* balance() may have replaced *p, so re-read it. */
	return (*p)->balance;
}
/*
 * Rotate the subtree at *p, whose `side` child is too tall, and store the
 * new subtree root back into *p. Called from sway() when the node already
 * leans by bal(side).
 *
 * Balances are height(lr[1]) - height(lr[0]).
 */
static void balance(AvlNode **p, int side)
{
AvlNode *node = *p,
*child = node->lr[side];
int opposite = 1 - side;
int bal = bal(side);
/* Child does not lean toward `opposite`: a single rotation suffices. */
if (child->balance != -bal) {
node->lr[side] = child->lr[opposite];
child->lr[opposite] = node;
*p = child;
child->balance -= bal;
node->balance = -child->balance;
} else {
/*
 * Child leans toward `opposite`: double rotation. The grandchild on
 * that side becomes the new root, with child and node as its
 * children.
 */
AvlNode *grandchild = child->lr[opposite];
node->lr[side] = grandchild->lr[opposite];
child->lr[opposite] = grandchild->lr[side];
grandchild->lr[side] = child;
grandchild->lr[opposite] = node;
*p = grandchild;
/* New balances depend on which way the grandchild used to lean. */
node->balance = 0;
child->balance = 0;
if (grandchild->balance == bal)
node->balance = -bal;
else if (grandchild->balance == -bal)
child->balance = bal;
grandchild->balance = 0;
}
}
/*
 * Consistency check for an stree. It verifies three things:
 *   - every node's balance equals height(right) - height(left) and lies
 *     in [-1, 1];
 *   - in-order traversal is strictly increasing under cmp_tracker();
 *   - avl->count matches the number of nodes.
 * A NULL stree is the empty stree and is trivially valid; it was
 * previously dereferenced.
 */
bool avl_check_invariants(struct stree *avl)
{
	int height;

	if (!avl)
		return true;

	return checkBalances(avl->root, &height) &&
	       checkOrder(avl) &&
	       countNode(avl->root) == avl->count;
}
/*
 * Recursively check that each node's balance equals
 * height(lr[1]) - height(lr[0]) and stays within [-1, 1].
 *
 * On success, *height receives the subtree height (0 for NULL). On
 * failure, *height is left unset.
 */
static bool checkBalances(AvlNode *node, int *height)
{
	int left, right;

	if (node == NULL) {
		*height = 0;
		return true;
	}

	if (!checkBalances(node->lr[0], &left) || !checkBalances(node->lr[1], &right))
		return false;

	if (node->balance < -1 || node->balance > 1 || node->balance != right - left)
		return false;

	*height = 1 + (left > right ? left : right);
	return true;
}
/*
 * Walk avl in iteration order and confirm that each entry compares
 * strictly greater than the previous one under cmp_tracker().
 */
static bool checkOrder(struct stree *avl)
{
	AvlIter it;
	const struct sm_state *prev = NULL;
	bool have_prev = false;

	avl_foreach(it, avl) {
		if (have_prev && cmp_tracker(prev, it.sm) >= 0)
			return false;
		have_prev = true;
		prev = it.sm;
	}
	return true;
}
/* Total number of nodes in the subtree rooted at node (0 for NULL). */
static size_t countNode(AvlNode *node)
{
	if (node == NULL)
		return 0;
	return countNode(node->lr[0]) + countNode(node->lr[1]) + 1;
}
/*
 * Position iter on the first node of avl when walking in direction `dir`.
 *
 * dir is used directly as a child index: lr[dir] is followed first. The
 * ancestors passed on the way are pushed onto iter->stack. For a NULL or
 * empty stree, iter->node and iter->sm are set to NULL.
 *
 * NOTE(review): assumes iter->stack is deep enough for any tree height —
 * confirm against the AvlIter definition.
 */
void avl_iter_begin(AvlIter *iter, struct stree *avl, AvlDirection dir)
{
	AvlNode *cur;

	iter->stack_index = 0;
	iter->direction = dir;

	cur = avl != NULL ? avl->root : NULL;
	if (cur == NULL) {
		iter->sm = NULL;
		iter->node = NULL;
		return;
	}

	for (; cur->lr[dir] != NULL; cur = cur->lr[dir])
		iter->stack[iter->stack_index++] = cur;

	iter->sm = (struct sm_state *)cur->sm;
	iter->node = cur;
}
/*
 * Advance iter to the next node in its direction. When the walk is
 * exhausted, iter->node and iter->sm become NULL. Calling this on an
 * exhausted iterator is a no-op.
 */
void avl_iter_next(AvlIter *iter)
{
	AvlDirection dir = iter->direction;
	AvlNode *next;

	if (iter->node == NULL)
		return;

	next = iter->node->lr[1 - dir];
	if (next != NULL) {
		/* Successor: first node, in `dir` order, of the other subtree. */
		while (next->lr[dir] != NULL) {
			iter->stack[iter->stack_index++] = next;
			next = next->lr[dir];
		}
	} else if (iter->stack_index == 0) {
		/* No other subtree and no pending ancestor: done. */
		iter->sm = NULL;
		iter->node = NULL;
		return;
	} else {
		next = iter->stack[--iter->stack_index];
	}

	iter->node = next;
	iter->sm = (struct sm_state *)next->sm;
}
/*
 * Share orig with another holder by bumping its reference count; no data
 * is copied. avl_insert() and avl_remove() make a private copy if a
 * shared stree is later modified. NULL is returned unchanged.
 */
struct stree *clone_stree(struct stree *orig)
{
	if (orig != NULL)
		orig->references++;
	return orig;
}
/*
 * Tag *stree with stree_id. *stree must be non-NULL.
 *
 * An stree that already carries an id is not relabelled. Instead, *stree
 * is repointed at a fresh copy, which then receives the new id.
 *
 * NOTE(review): unlike avl_insert()/avl_remove(), the old stree's
 * reference count is not decremented when it is replaced here. Confirm
 * this is intentional (e.g. other code still points at the old stree).
 */
void set_stree_id(struct stree **stree, int stree_id)
{
	if ((*stree)->stree_id)
		*stree = clone_stree_real(*stree);

	(*stree)->stree_id = stree_id;
}
/* The id assigned by set_stree_id() (0 if never set), or -1 for NULL. */
int get_stree_id(struct stree *stree)
{
	return stree != NULL ? stree->stree_id : -1;
}