stree fallout: implications not working 100%
[smatch.git] / avl.c
blob 9935840d837ca4189e340208f185c2fa8712bb0b
1 /*
2 * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice shall be included in
12 * all copies or substantial portions of the Software.
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 * THE SOFTWARE.
23 #include <assert.h>
24 #include <stdlib.h>
26 #include "smatch.h"
27 #include "smatch_slist.h"
29 static AvlNode *mkNode(const struct sm_state *sm);
30 static void freeNode(AvlNode *node);
32 static AvlNode *lookup(const struct stree *avl, AvlNode *node, const struct sm_state *sm);
34 static bool insert_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm);
35 static bool remove_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret);
36 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret);
38 static int sway(AvlNode **p, int sway);
39 static void balance(AvlNode **p, int side);
41 static bool checkBalances(AvlNode *node, int *height);
42 static bool checkOrder(struct stree *avl);
43 static size_t countNode(AvlNode *node);
/*
 * Conversions between a node "balance" value and a "side" index into
 * AvlNode.lr[]:
 *
 *   bal(0)  == -1      side(-1) == 0
 *   bal(1)  == +1      side(+1) == 1
 *
 * bal() maps any non-zero side to +1; side() maps every balance other
 * than +1 to 0.
 */
#define bal(s)  ((s) ? 1 : -1)
#define side(b) ((b) == 1)
/*
 * Collapse a comparison result to -1, 0 or +1.
 */
static int sign(int cmp)
{
	/* Each relational test yields 0 or 1, so the difference is the sign. */
	return (cmp > 0) - (cmp < 0);
}
66 struct stree *avl_new(void)
68 struct stree *avl = malloc(sizeof(*avl));
70 assert(avl != NULL);
72 avl->root = NULL;
73 avl->count = 0;
74 return avl;
77 void free_stree(struct stree **avl)
79 if (*avl) {
80 freeNode((*avl)->root);
81 free(*avl);
83 *avl = NULL;
86 struct sm_state *avl_lookup(const struct stree *avl, const struct sm_state *sm)
88 AvlNode *found;
90 if (!avl)
91 return NULL;
92 found = lookup(avl, avl->root, sm);
93 if (!found)
94 return NULL;
95 return (struct sm_state *)found->sm;
98 AvlNode *avl_lookup_node(const struct stree *avl, const struct sm_state *sm)
100 return lookup(avl, avl->root, sm);
103 size_t stree_count(const struct stree *avl)
105 if (!avl)
106 return 0;
107 return avl->count;
110 bool avl_insert(struct stree **avl, const struct sm_state *sm)
112 size_t old_count;
114 if (!*avl)
115 *avl = avl_new();
116 old_count = (*avl)->count;
117 insert_sm(*avl, &(*avl)->root, sm);
118 return (*avl)->count != old_count;
121 bool avl_remove(struct stree **avl, const struct sm_state *sm)
123 AvlNode *node = NULL;
125 if (!*avl)
126 return false;
128 remove_sm(*avl, &(*avl)->root, sm, &node);
130 if ((*avl)->count == 0)
131 free_stree(avl);
133 if (node == NULL) {
134 return false;
135 } else {
136 free(node);
137 return true;
141 static AvlNode *mkNode(const struct sm_state *sm)
143 AvlNode *node = malloc(sizeof(*node));
145 assert(node != NULL);
147 node->sm = sm;
148 node->lr[0] = NULL;
149 node->lr[1] = NULL;
150 node->balance = 0;
151 return node;
154 static void freeNode(AvlNode *node)
156 if (node) {
157 freeNode(node->lr[0]);
158 freeNode(node->lr[1]);
159 free(node);
163 static AvlNode *lookup(const struct stree *avl, AvlNode *node, const struct sm_state *sm)
165 int cmp;
167 if (node == NULL)
168 return NULL;
170 cmp = cmp_tracker(sm, node->sm);
172 if (cmp < 0)
173 return lookup(avl, node->lr[0], sm);
174 if (cmp > 0)
175 return lookup(avl, node->lr[1], sm);
176 return node;
180 * Insert an sm into a subtree, rebalancing if necessary.
182 * Return true if the subtree's height increased.
184 static bool insert_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm)
186 if (*p == NULL) {
187 *p = mkNode(sm);
188 avl->count++;
189 return true;
190 } else {
191 AvlNode *node = *p;
192 int cmp = sign(cmp_tracker(sm, node->sm));
194 if (cmp == 0) {
195 node->sm = sm;
196 return false;
199 if (!insert_sm(avl, &node->lr[side(cmp)], sm))
200 return false;
202 /* If tree's balance became -1 or 1, it means the tree's height grew due to insertion. */
203 return sway(p, cmp) != 0;
208 * Remove the node matching the given sm.
209 * If present, return the removed node through *ret .
210 * The returned node's lr and balance are meaningless.
212 * Return true if the subtree's height decreased.
214 static bool remove_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret)
216 if (p == NULL || *p == NULL) {
217 return false;
218 } else {
219 AvlNode *node = *p;
220 int cmp = sign(cmp_tracker(sm, node->sm));
222 if (cmp == 0) {
223 *ret = node;
224 avl->count--;
226 if (node->lr[0] != NULL && node->lr[1] != NULL) {
227 AvlNode *replacement;
228 int side;
229 bool shrunk;
231 /* Pick a subtree to pull the replacement from such that
232 * this node doesn't have to be rebalanced. */
233 side = node->balance <= 0 ? 0 : 1;
235 shrunk = removeExtremum(&node->lr[side], 1 - side, &replacement);
237 replacement->lr[0] = node->lr[0];
238 replacement->lr[1] = node->lr[1];
239 replacement->balance = node->balance;
240 *p = replacement;
242 if (!shrunk)
243 return false;
245 replacement->balance -= bal(side);
247 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
248 return replacement->balance == 0;
251 if (node->lr[0] != NULL)
252 *p = node->lr[0];
253 else
254 *p = node->lr[1];
256 return true;
258 } else {
259 if (!remove_sm(avl, &node->lr[side(cmp)], sm, ret))
260 return false;
262 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
263 return sway(p, -cmp) == 0;
269 * Remove either the left-most (if side == 0) or right-most (if side == 1)
270 * node in a subtree, returning the removed node through *ret .
271 * The returned node's lr and balance are meaningless.
273 * The subtree must not be empty (i.e. *p must not be NULL).
275 * Return true if the subtree's height decreased.
277 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret)
279 AvlNode *node = *p;
281 if (node->lr[side] == NULL) {
282 *ret = node;
283 *p = node->lr[1 - side];
284 return true;
287 if (!removeExtremum(&node->lr[side], side, ret))
288 return false;
290 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
291 return sway(p, -bal(side)) == 0;
295 * Rebalance a node if necessary. Think of this function
296 * as a higher-level interface to balance().
298 * sway must be either -1 or 1, and indicates what was added to
299 * the balance of this node by a prior operation.
301 * Return the new balance of the subtree.
303 static int sway(AvlNode **p, int sway)
305 if ((*p)->balance != sway)
306 (*p)->balance += sway;
307 else
308 balance(p, side(sway));
310 return (*p)->balance;
314 * Perform tree rotations on an unbalanced node.
316 * side == 0 means the node's balance is -2 .
317 * side == 1 means the node's balance is +2 .
319 static void balance(AvlNode **p, int side)
321 AvlNode *node = *p,
322 *child = node->lr[side];
323 int opposite = 1 - side;
324 int bal = bal(side);
326 if (child->balance != -bal) {
327 /* Left-left (side == 0) or right-right (side == 1) */
328 node->lr[side] = child->lr[opposite];
329 child->lr[opposite] = node;
330 *p = child;
332 child->balance -= bal;
333 node->balance = -child->balance;
335 } else {
336 /* Left-right (side == 0) or right-left (side == 1) */
337 AvlNode *grandchild = child->lr[opposite];
339 node->lr[side] = grandchild->lr[opposite];
340 child->lr[opposite] = grandchild->lr[side];
341 grandchild->lr[side] = child;
342 grandchild->lr[opposite] = node;
343 *p = grandchild;
345 node->balance = 0;
346 child->balance = 0;
348 if (grandchild->balance == bal)
349 node->balance = -bal;
350 else if (grandchild->balance == -bal)
351 child->balance = bal;
353 grandchild->balance = 0;
358 /************************* avl_check_invariants() *************************/
360 bool avl_check_invariants(struct stree *avl)
362 int dummy;
364 return checkBalances(avl->root, &dummy)
365 && checkOrder(avl)
366 && countNode(avl->root) == avl->count;
369 static bool checkBalances(AvlNode *node, int *height)
371 if (node) {
372 int h0, h1;
374 if (!checkBalances(node->lr[0], &h0))
375 return false;
376 if (!checkBalances(node->lr[1], &h1))
377 return false;
379 if (node->balance != h1 - h0 || node->balance < -1 || node->balance > 1)
380 return false;
382 *height = (h0 > h1 ? h0 : h1) + 1;
383 return true;
384 } else {
385 *height = 0;
386 return true;
390 static bool checkOrder(struct stree *avl)
392 AvlIter i;
393 const struct sm_state *last = NULL;
394 bool last_set = false;
396 avl_foreach(i, avl) {
397 if (last_set && cmp_tracker(last, i.sm) >= 0)
398 return false;
399 last = i.sm;
400 last_set = true;
403 return true;
406 static size_t countNode(AvlNode *node)
408 if (node)
409 return 1 + countNode(node->lr[0]) + countNode(node->lr[1]);
410 else
411 return 0;
415 /************************* Traversal *************************/
417 void avl_iter_begin(AvlIter *iter, struct stree *avl, AvlDirection dir)
419 AvlNode *node;
421 iter->stack_index = 0;
422 iter->direction = dir;
424 if (!avl || !avl->root) {
425 iter->sm = NULL;
426 iter->node = NULL;
427 return;
429 node = avl->root;
431 while (node->lr[dir] != NULL) {
432 iter->stack[iter->stack_index++] = node;
433 node = node->lr[dir];
436 iter->sm = (struct sm_state *) node->sm;
437 iter->node = node;
440 void avl_iter_next(AvlIter *iter)
442 AvlNode *node = iter->node;
443 AvlDirection dir = iter->direction;
445 if (node == NULL)
446 return;
448 node = node->lr[1 - dir];
449 if (node != NULL) {
450 while (node->lr[dir] != NULL) {
451 iter->stack[iter->stack_index++] = node;
452 node = node->lr[dir];
454 } else if (iter->stack_index > 0) {
455 node = iter->stack[--iter->stack_index];
456 } else {
457 iter->sm = NULL;
458 iter->node = NULL;
459 return;
462 iter->node = node;
463 iter->sm = (struct sm_state *) node->sm;
466 struct stree *clone_stree(struct stree *orig)
468 struct stree *new = NULL;
469 AvlIter i;
471 avl_foreach(i, orig)
472 avl_insert(&new, i.sm);
474 return new;