avl: include smatch.h and smatch_slist.h into avl.c
[smatch.git] / avl.c
blob8606089b5b5a80f761019ac695ce54dbe1ae8fd8
1 /*
2 * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice shall be included in
12 * all copies or substantial portions of the Software.
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 * THE SOFTWARE.
23 #include <assert.h>
24 #include <stdlib.h>
26 #include "smatch.h"
27 #include "smatch_slist.h"
29 static AvlNode *mkNode(const void *key, const void *value);
30 static void freeNode(AvlNode *node);
32 static AvlNode *lookup(const AVL *avl, AvlNode *node, const void *key);
34 static bool insert_sm(AVL *avl, AvlNode **p, const void *key, const void *value);
35 static bool remove_sm(AVL *avl, AvlNode **p, const void *key, AvlNode **ret);
36 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret);
38 static int sway(AvlNode **p, int sway);
39 static void balance(AvlNode **p, int side);
41 static bool checkBalances(AvlNode *node, int *height);
42 static bool checkOrder(AVL *avl);
43 static size_t countNode(AvlNode *node);
/*
 * Conversions between a "balance" value (-1 or +1) and a "side" value
 * (0 selects lr[0], 1 selects lr[1]):
 *
 *   bal(0) == -1        side(-1) == 0
 *   bal(1) == +1        side(+1) == 1
 *
 * bal() treats any non-zero side as 1; side() maps anything other than
 * +1 (including 0) to 0.
 */
#define bal(s) ((s) != 0 ? 1 : -1)
#define side(b) ((b) == 1)
/* Collapse an arbitrary comparator result to exactly -1, 0 or +1. */
static int sign(int cmp)
{
	return (cmp > 0) - (cmp < 0);
}
66 AVL *avl_new(AvlCompare compare)
68 AVL *avl = malloc(sizeof(*avl));
70 assert(avl != NULL);
72 avl->compare = compare;
73 avl->root = NULL;
74 avl->count = 0;
75 return avl;
78 void avl_free(AVL *avl)
80 freeNode(avl->root);
81 free(avl);
84 void *avl_lookup(const AVL *avl, const void *key)
86 AvlNode *found = lookup(avl, avl->root, key);
87 return found ? (void*) found->value : NULL;
90 AvlNode *avl_lookup_node(const AVL *avl, const void *key)
92 return lookup(avl, avl->root, key);
95 size_t avl_count(const AVL *avl)
97 return avl->count;
100 bool avl_insert(AVL *avl, const void *key, const void *value)
102 size_t old_count = avl->count;
103 insert_sm(avl, &avl->root, key, value);
104 return avl->count != old_count;
107 bool avl_remove(AVL *avl, const void *key)
109 AvlNode *node = NULL;
111 remove_sm(avl, &avl->root, key, &node);
113 if (node == NULL) {
114 return false;
115 } else {
116 free(node);
117 return true;
121 static AvlNode *mkNode(const void *key, const void *value)
123 AvlNode *node = malloc(sizeof(*node));
125 assert(node != NULL);
127 node->key = key;
128 node->value = value;
129 node->lr[0] = NULL;
130 node->lr[1] = NULL;
131 node->balance = 0;
132 return node;
135 static void freeNode(AvlNode *node)
137 if (node) {
138 freeNode(node->lr[0]);
139 freeNode(node->lr[1]);
140 free(node);
144 static AvlNode *lookup(const AVL *avl, AvlNode *node, const void *key)
146 int cmp;
148 if (node == NULL)
149 return NULL;
151 cmp = avl->compare(key, node->key);
153 if (cmp < 0)
154 return lookup(avl, node->lr[0], key);
155 if (cmp > 0)
156 return lookup(avl, node->lr[1], key);
157 return node;
161 * Insert a key/value into a subtree, rebalancing if necessary.
163 * Return true if the subtree's height increased.
165 static bool insert_sm(AVL *avl, AvlNode **p, const void *key, const void *value)
167 if (*p == NULL) {
168 *p = mkNode(key, value);
169 avl->count++;
170 return true;
171 } else {
172 AvlNode *node = *p;
173 int cmp = sign(avl->compare(key, node->key));
175 if (cmp == 0) {
176 node->key = key;
177 node->value = value;
178 return false;
181 if (!insert_sm(avl, &node->lr[side(cmp)], key, value))
182 return false;
184 /* If tree's balance became -1 or 1, it means the tree's height grew due to insertion. */
185 return sway(p, cmp) != 0;
190 * Remove the node matching the given key.
191 * If present, return the removed node through *ret .
192 * The returned node's lr and balance are meaningless.
194 * Return true if the subtree's height decreased.
196 static bool remove_sm(AVL *avl, AvlNode **p, const void *key, AvlNode **ret)
198 if (*p == NULL) {
199 return false;
200 } else {
201 AvlNode *node = *p;
202 int cmp = sign(avl->compare(key, node->key));
204 if (cmp == 0) {
205 *ret = node;
206 avl->count--;
208 if (node->lr[0] != NULL && node->lr[1] != NULL) {
209 AvlNode *replacement;
210 int side;
211 bool shrunk;
213 /* Pick a subtree to pull the replacement from such that
214 * this node doesn't have to be rebalanced. */
215 side = node->balance <= 0 ? 0 : 1;
217 shrunk = removeExtremum(&node->lr[side], 1 - side, &replacement);
219 replacement->lr[0] = node->lr[0];
220 replacement->lr[1] = node->lr[1];
221 replacement->balance = node->balance;
222 *p = replacement;
224 if (!shrunk)
225 return false;
227 replacement->balance -= bal(side);
229 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
230 return replacement->balance == 0;
233 if (node->lr[0] != NULL)
234 *p = node->lr[0];
235 else
236 *p = node->lr[1];
238 return true;
240 } else {
241 if (!remove_sm(avl, &node->lr[side(cmp)], key, ret))
242 return false;
244 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
245 return sway(p, -cmp) == 0;
251 * Remove either the left-most (if side == 0) or right-most (if side == 1)
252 * node in a subtree, returning the removed node through *ret .
253 * The returned node's lr and balance are meaningless.
255 * The subtree must not be empty (i.e. *p must not be NULL).
257 * Return true if the subtree's height decreased.
259 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret)
261 AvlNode *node = *p;
263 if (node->lr[side] == NULL) {
264 *ret = node;
265 *p = node->lr[1 - side];
266 return true;
269 if (!removeExtremum(&node->lr[side], side, ret))
270 return false;
272 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
273 return sway(p, -bal(side)) == 0;
277 * Rebalance a node if necessary. Think of this function
278 * as a higher-level interface to balance().
280 * sway must be either -1 or 1, and indicates what was added to
281 * the balance of this node by a prior operation.
283 * Return the new balance of the subtree.
285 static int sway(AvlNode **p, int sway)
287 if ((*p)->balance != sway)
288 (*p)->balance += sway;
289 else
290 balance(p, side(sway));
292 return (*p)->balance;
296 * Perform tree rotations on an unbalanced node.
298 * side == 0 means the node's balance is -2 .
299 * side == 1 means the node's balance is +2 .
301 static void balance(AvlNode **p, int side)
303 AvlNode *node = *p,
304 *child = node->lr[side];
305 int opposite = 1 - side;
306 int bal = bal(side);
308 if (child->balance != -bal) {
309 /* Left-left (side == 0) or right-right (side == 1) */
310 node->lr[side] = child->lr[opposite];
311 child->lr[opposite] = node;
312 *p = child;
314 child->balance -= bal;
315 node->balance = -child->balance;
317 } else {
318 /* Left-right (side == 0) or right-left (side == 1) */
319 AvlNode *grandchild = child->lr[opposite];
321 node->lr[side] = grandchild->lr[opposite];
322 child->lr[opposite] = grandchild->lr[side];
323 grandchild->lr[side] = child;
324 grandchild->lr[opposite] = node;
325 *p = grandchild;
327 node->balance = 0;
328 child->balance = 0;
330 if (grandchild->balance == bal)
331 node->balance = -bal;
332 else if (grandchild->balance == -bal)
333 child->balance = bal;
335 grandchild->balance = 0;
340 /************************* avl_check_invariants() *************************/
342 bool avl_check_invariants(AVL *avl)
344 int dummy;
346 return checkBalances(avl->root, &dummy)
347 && checkOrder(avl)
348 && countNode(avl->root) == avl->count;
351 static bool checkBalances(AvlNode *node, int *height)
353 if (node) {
354 int h0, h1;
356 if (!checkBalances(node->lr[0], &h0))
357 return false;
358 if (!checkBalances(node->lr[1], &h1))
359 return false;
361 if (node->balance != h1 - h0 || node->balance < -1 || node->balance > 1)
362 return false;
364 *height = (h0 > h1 ? h0 : h1) + 1;
365 return true;
366 } else {
367 *height = 0;
368 return true;
372 static bool checkOrder(AVL *avl)
374 AvlIter i;
375 const void *last = NULL;
376 bool last_set = false;
378 avl_foreach(i, avl) {
379 if (last_set && avl->compare(last, i.key) >= 0)
380 return false;
381 last = i.key;
382 last_set = true;
385 return true;
388 static size_t countNode(AvlNode *node)
390 if (node)
391 return 1 + countNode(node->lr[0]) + countNode(node->lr[1]);
392 else
393 return 0;
397 /************************* Traversal *************************/
399 void avl_iter_begin(AvlIter *iter, AVL *avl, AvlDirection dir)
401 AvlNode *node = avl->root;
403 iter->stack_index = 0;
404 iter->direction = dir;
406 if (node == NULL) {
407 iter->key = NULL;
408 iter->value = NULL;
409 iter->node = NULL;
410 return;
413 while (node->lr[dir] != NULL) {
414 iter->stack[iter->stack_index++] = node;
415 node = node->lr[dir];
418 iter->key = (void*) node->key;
419 iter->value = (void*) node->value;
420 iter->node = node;
423 void avl_iter_next(AvlIter *iter)
425 AvlNode *node = iter->node;
426 AvlDirection dir = iter->direction;
428 if (node == NULL)
429 return;
431 node = node->lr[1 - dir];
432 if (node != NULL) {
433 while (node->lr[dir] != NULL) {
434 iter->stack[iter->stack_index++] = node;
435 node = node->lr[dir];
437 } else if (iter->stack_index > 0) {
438 node = iter->stack[--iter->stack_index];
439 } else {
440 iter->key = NULL;
441 iter->value = NULL;
442 iter->node = NULL;
443 return;
446 iter->node = node;
447 iter->key = (void*) node->key;
448 iter->value = (void*) node->value;