buf_size: allow strncmp("foo", bar, 100) where 100 is larger than "foo"
[smatch.git] / avl.c
blob5cfb85804a48183d32ffdab3ba2a204afc510288
1 /*
2 * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice shall be included in
12 * all copies or substantial portions of the Software.
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 * THE SOFTWARE.
23 #include <assert.h>
24 #include <stdlib.h>
26 #include "smatch.h"
27 #include "smatch_slist.h"
29 static AvlNode *mkNode(const struct sm_state *sm);
30 static void freeNode(AvlNode *node);
32 static AvlNode *lookup(const struct stree *avl, AvlNode *node, const struct sm_state *sm);
34 static bool insert_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm);
35 static bool remove_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret);
36 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret);
38 static int sway(AvlNode **p, int sway);
39 static void balance(AvlNode **p, int side);
41 static bool checkBalances(AvlNode *node, int *height);
42 static bool checkOrder(struct stree *avl);
43 static size_t countNode(AvlNode *node);
/*
 * Utility macros for converting between
 * "balance" values (-1 or 1) and "side" values (0 or 1).
 *
 * bal(0) == -1
 * bal(1) == +1
 * side(-1) == 0
 * side(+1) == 1
 *
 * These must remain function-like macros rather than static inline
 * functions: this file also uses "bal" and "side" as plain variable names
 * (e.g. "int bal = bal(side);"), which only works because a function-like
 * macro name is expanded solely when followed by '('.
 */
#define bal(side)	((side) ? 1 : -1)
#define side(bal)	((bal) == 1)
/*
 * Collapse a three-way comparison result to exactly -1, 0 or +1 so it can
 * be used directly as a balance delta and, via side(), as a child index.
 */
static int sign(int cmp)
{
	return (cmp > 0) - (cmp < 0);
}
66 struct stree *avl_new(void)
68 struct stree *avl = malloc(sizeof(*avl));
70 assert(avl != NULL);
72 avl->root = NULL;
73 avl->count = 0;
74 avl->stree_id = 0;
75 avl->references = 1;
76 return avl;
79 void free_stree(struct stree **avl)
81 if (!*avl)
82 return;
84 assert((*avl)->references > 0);
86 (*avl)->references--;
87 if ((*avl)->references != 0) {
88 *avl = NULL;
89 return;
92 freeNode((*avl)->root);
93 free(*avl);
94 *avl = NULL;
97 struct sm_state *avl_lookup(const struct stree *avl, const struct sm_state *sm)
99 AvlNode *found;
101 if (!avl)
102 return NULL;
103 found = lookup(avl, avl->root, sm);
104 if (!found)
105 return NULL;
106 return (struct sm_state *)found->sm;
109 AvlNode *avl_lookup_node(const struct stree *avl, const struct sm_state *sm)
111 return lookup(avl, avl->root, sm);
114 size_t stree_count(const struct stree *avl)
116 if (!avl)
117 return 0;
118 return avl->count;
121 static struct stree *clone_stree_real(struct stree *orig)
123 struct stree *new = avl_new();
124 AvlIter i;
126 avl_foreach(i, orig)
127 avl_insert(&new, i.sm);
129 return new;
132 bool avl_insert(struct stree **avl, const struct sm_state *sm)
134 size_t old_count;
136 if (!*avl)
137 *avl = avl_new();
138 if ((*avl)->references > 1) {
139 (*avl)->references--;
140 *avl = clone_stree_real(*avl);
142 old_count = (*avl)->count;
143 insert_sm(*avl, &(*avl)->root, sm);
144 return (*avl)->count != old_count;
147 bool avl_remove(struct stree **avl, const struct sm_state *sm)
149 AvlNode *node = NULL;
151 if (!*avl)
152 return false;
153 /* it's fairly rare for smatch to call avl_remove */
154 if ((*avl)->references > 1) {
155 (*avl)->references--;
156 *avl = clone_stree_real(*avl);
159 remove_sm(*avl, &(*avl)->root, sm, &node);
161 if ((*avl)->count == 0)
162 free_stree(avl);
164 if (node == NULL) {
165 return false;
166 } else {
167 free(node);
168 return true;
172 static AvlNode *mkNode(const struct sm_state *sm)
174 AvlNode *node = malloc(sizeof(*node));
176 assert(node != NULL);
178 node->sm = sm;
179 node->lr[0] = NULL;
180 node->lr[1] = NULL;
181 node->balance = 0;
182 return node;
185 static void freeNode(AvlNode *node)
187 if (node) {
188 freeNode(node->lr[0]);
189 freeNode(node->lr[1]);
190 free(node);
194 static AvlNode *lookup(const struct stree *avl, AvlNode *node, const struct sm_state *sm)
196 int cmp;
198 if (node == NULL)
199 return NULL;
201 cmp = cmp_tracker(sm, node->sm);
203 if (cmp < 0)
204 return lookup(avl, node->lr[0], sm);
205 if (cmp > 0)
206 return lookup(avl, node->lr[1], sm);
207 return node;
211 * Insert an sm into a subtree, rebalancing if necessary.
213 * Return true if the subtree's height increased.
215 static bool insert_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm)
217 if (*p == NULL) {
218 *p = mkNode(sm);
219 avl->count++;
220 return true;
221 } else {
222 AvlNode *node = *p;
223 int cmp = sign(cmp_tracker(sm, node->sm));
225 if (cmp == 0) {
226 node->sm = sm;
227 return false;
230 if (!insert_sm(avl, &node->lr[side(cmp)], sm))
231 return false;
233 /* If tree's balance became -1 or 1, it means the tree's height grew due to insertion. */
234 return sway(p, cmp) != 0;
239 * Remove the node matching the given sm.
240 * If present, return the removed node through *ret .
241 * The returned node's lr and balance are meaningless.
243 * Return true if the subtree's height decreased.
245 static bool remove_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret)
247 if (p == NULL || *p == NULL) {
248 return false;
249 } else {
250 AvlNode *node = *p;
251 int cmp = sign(cmp_tracker(sm, node->sm));
253 if (cmp == 0) {
254 *ret = node;
255 avl->count--;
257 if (node->lr[0] != NULL && node->lr[1] != NULL) {
258 AvlNode *replacement;
259 int side;
260 bool shrunk;
262 /* Pick a subtree to pull the replacement from such that
263 * this node doesn't have to be rebalanced. */
264 side = node->balance <= 0 ? 0 : 1;
266 shrunk = removeExtremum(&node->lr[side], 1 - side, &replacement);
268 replacement->lr[0] = node->lr[0];
269 replacement->lr[1] = node->lr[1];
270 replacement->balance = node->balance;
271 *p = replacement;
273 if (!shrunk)
274 return false;
276 replacement->balance -= bal(side);
278 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
279 return replacement->balance == 0;
282 if (node->lr[0] != NULL)
283 *p = node->lr[0];
284 else
285 *p = node->lr[1];
287 return true;
289 } else {
290 if (!remove_sm(avl, &node->lr[side(cmp)], sm, ret))
291 return false;
293 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
294 return sway(p, -cmp) == 0;
300 * Remove either the left-most (if side == 0) or right-most (if side == 1)
301 * node in a subtree, returning the removed node through *ret .
302 * The returned node's lr and balance are meaningless.
304 * The subtree must not be empty (i.e. *p must not be NULL).
306 * Return true if the subtree's height decreased.
308 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret)
310 AvlNode *node = *p;
312 if (node->lr[side] == NULL) {
313 *ret = node;
314 *p = node->lr[1 - side];
315 return true;
318 if (!removeExtremum(&node->lr[side], side, ret))
319 return false;
321 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
322 return sway(p, -bal(side)) == 0;
326 * Rebalance a node if necessary. Think of this function
327 * as a higher-level interface to balance().
329 * sway must be either -1 or 1, and indicates what was added to
330 * the balance of this node by a prior operation.
332 * Return the new balance of the subtree.
334 static int sway(AvlNode **p, int sway)
336 if ((*p)->balance != sway)
337 (*p)->balance += sway;
338 else
339 balance(p, side(sway));
341 return (*p)->balance;
345 * Perform tree rotations on an unbalanced node.
347 * side == 0 means the node's balance is -2 .
348 * side == 1 means the node's balance is +2 .
350 static void balance(AvlNode **p, int side)
352 AvlNode *node = *p,
353 *child = node->lr[side];
354 int opposite = 1 - side;
355 int bal = bal(side);
357 if (child->balance != -bal) {
358 /* Left-left (side == 0) or right-right (side == 1) */
359 node->lr[side] = child->lr[opposite];
360 child->lr[opposite] = node;
361 *p = child;
363 child->balance -= bal;
364 node->balance = -child->balance;
366 } else {
367 /* Left-right (side == 0) or right-left (side == 1) */
368 AvlNode *grandchild = child->lr[opposite];
370 node->lr[side] = grandchild->lr[opposite];
371 child->lr[opposite] = grandchild->lr[side];
372 grandchild->lr[side] = child;
373 grandchild->lr[opposite] = node;
374 *p = grandchild;
376 node->balance = 0;
377 child->balance = 0;
379 if (grandchild->balance == bal)
380 node->balance = -bal;
381 else if (grandchild->balance == -bal)
382 child->balance = bal;
384 grandchild->balance = 0;
389 /************************* avl_check_invariants() *************************/
391 bool avl_check_invariants(struct stree *avl)
393 int dummy;
395 return checkBalances(avl->root, &dummy)
396 && checkOrder(avl)
397 && countNode(avl->root) == avl->count;
400 static bool checkBalances(AvlNode *node, int *height)
402 if (node) {
403 int h0, h1;
405 if (!checkBalances(node->lr[0], &h0))
406 return false;
407 if (!checkBalances(node->lr[1], &h1))
408 return false;
410 if (node->balance != h1 - h0 || node->balance < -1 || node->balance > 1)
411 return false;
413 *height = (h0 > h1 ? h0 : h1) + 1;
414 return true;
415 } else {
416 *height = 0;
417 return true;
421 static bool checkOrder(struct stree *avl)
423 AvlIter i;
424 const struct sm_state *last = NULL;
425 bool last_set = false;
427 avl_foreach(i, avl) {
428 if (last_set && cmp_tracker(last, i.sm) >= 0)
429 return false;
430 last = i.sm;
431 last_set = true;
434 return true;
437 static size_t countNode(AvlNode *node)
439 if (node)
440 return 1 + countNode(node->lr[0]) + countNode(node->lr[1]);
441 else
442 return 0;
446 /************************* Traversal *************************/
448 void avl_iter_begin(AvlIter *iter, struct stree *avl, AvlDirection dir)
450 AvlNode *node;
452 iter->stack_index = 0;
453 iter->direction = dir;
455 if (!avl || !avl->root) {
456 iter->sm = NULL;
457 iter->node = NULL;
458 return;
460 node = avl->root;
462 while (node->lr[dir] != NULL) {
463 iter->stack[iter->stack_index++] = node;
464 node = node->lr[dir];
467 iter->sm = (struct sm_state *) node->sm;
468 iter->node = node;
471 void avl_iter_next(AvlIter *iter)
473 AvlNode *node = iter->node;
474 AvlDirection dir = iter->direction;
476 if (node == NULL)
477 return;
479 node = node->lr[1 - dir];
480 if (node != NULL) {
481 while (node->lr[dir] != NULL) {
482 iter->stack[iter->stack_index++] = node;
483 node = node->lr[dir];
485 } else if (iter->stack_index > 0) {
486 node = iter->stack[--iter->stack_index];
487 } else {
488 iter->sm = NULL;
489 iter->node = NULL;
490 return;
493 iter->node = node;
494 iter->sm = (struct sm_state *) node->sm;
497 struct stree *clone_stree(struct stree *orig)
499 if (!orig)
500 return NULL;
502 orig->references++;
503 return orig;
506 void set_stree_id(struct stree *stree, int stree_id)
508 assert(stree->stree_id == 0);
510 stree->stree_id = stree_id;
513 int get_stree_id(struct stree *stree)
515 if (!stree)
516 return 0;
517 return stree->stree_id;