avl: introduce avl_clone()
[smatch.git] / avl.c
blob16b686cdf69fd694eecc6c0a1993a6d8d951eba4
/*
 * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
23 #include <assert.h>
24 #include <stdlib.h>
26 #include "smatch.h"
27 #include "smatch_slist.h"
29 static AvlNode *mkNode(const struct sm_state *sm);
30 static void freeNode(AvlNode *node);
32 static AvlNode *lookup(const AVL *avl, AvlNode *node, const struct sm_state *sm);
34 static bool insert_sm(AVL *avl, AvlNode **p, const struct sm_state *sm);
35 static bool remove_sm(AVL *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret);
36 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret);
38 static int sway(AvlNode **p, int sway);
39 static void balance(AvlNode **p, int side);
41 static bool checkBalances(AvlNode *node, int *height);
42 static bool checkOrder(AVL *avl);
43 static size_t countNode(AvlNode *node);
/*
 * Conversions between a "balance" value (-1 or +1) and a "side" index
 * into AvlNode.lr (0 or 1):
 *
 *   bal(0)  == -1        side(-1) == 0
 *   bal(1)  == +1        side(+1) == 1
 *
 * Any nonzero side maps to +1; any balance other than +1 maps to side 0.
 */
#define bal(side) ((side) ? 1 : -1)
#define side(bal) ((bal) == 1)
/*
 * Collapse a comparison result to -1, 0 or +1 so it can be used directly
 * as a balance delta and, via side(), as an lr[] index.
 */
static int sign(int cmp)
{
	return (cmp > 0) - (cmp < 0);
}
66 AVL *avl_new(void)
68 AVL *avl = malloc(sizeof(*avl));
70 assert(avl != NULL);
72 avl->root = NULL;
73 avl->count = 0;
74 return avl;
77 void avl_free(AVL **avl)
79 freeNode((*avl)->root);
80 free(*avl);
81 *avl = NULL;
84 struct sm_state *avl_lookup(const AVL *avl, const struct sm_state *sm)
86 AvlNode *found = lookup(avl, avl->root, sm);
87 return found ? (struct sm_state *) found->sm : NULL;
90 AvlNode *avl_lookup_node(const AVL *avl, const struct sm_state *sm)
92 return lookup(avl, avl->root, sm);
95 size_t avl_count(const AVL *avl)
97 return avl->count;
100 bool avl_insert(AVL **avl, const struct sm_state *sm)
102 size_t old_count;
104 if (!*avl)
105 *avl = avl_new();
106 old_count = (*avl)->count;
107 insert_sm(*avl, &(*avl)->root, sm);
108 return (*avl)->count != old_count;
111 bool avl_remove(AVL **avl, const struct sm_state *sm)
113 AvlNode *node = NULL;
115 remove_sm(*avl, &(*avl)->root, sm, &node);
117 if ((*avl)->count == 0)
118 avl_free(avl);
120 if (node == NULL) {
121 return false;
122 } else {
123 free(node);
124 return true;
128 static AvlNode *mkNode(const struct sm_state *sm)
130 AvlNode *node = malloc(sizeof(*node));
132 assert(node != NULL);
134 node->sm = sm;
135 node->lr[0] = NULL;
136 node->lr[1] = NULL;
137 node->balance = 0;
138 return node;
141 static void freeNode(AvlNode *node)
143 if (node) {
144 freeNode(node->lr[0]);
145 freeNode(node->lr[1]);
146 free(node);
150 static AvlNode *lookup(const AVL *avl, AvlNode *node, const struct sm_state *sm)
152 int cmp;
154 if (node == NULL)
155 return NULL;
157 cmp = cmp_tracker(sm, node->sm);
159 if (cmp < 0)
160 return lookup(avl, node->lr[0], sm);
161 if (cmp > 0)
162 return lookup(avl, node->lr[1], sm);
163 return node;
167 * Insert an sm into a subtree, rebalancing if necessary.
169 * Return true if the subtree's height increased.
171 static bool insert_sm(AVL *avl, AvlNode **p, const struct sm_state *sm)
173 if (*p == NULL) {
174 *p = mkNode(sm);
175 avl->count++;
176 return true;
177 } else {
178 AvlNode *node = *p;
179 int cmp = sign(cmp_tracker(sm, node->sm));
181 if (cmp == 0) {
182 node->sm = sm;
183 return false;
186 if (!insert_sm(avl, &node->lr[side(cmp)], sm))
187 return false;
189 /* If tree's balance became -1 or 1, it means the tree's height grew due to insertion. */
190 return sway(p, cmp) != 0;
195 * Remove the node matching the given sm.
196 * If present, return the removed node through *ret .
197 * The returned node's lr and balance are meaningless.
199 * Return true if the subtree's height decreased.
201 static bool remove_sm(AVL *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret)
203 if (*p == NULL) {
204 return false;
205 } else {
206 AvlNode *node = *p;
207 int cmp = sign(cmp_tracker(sm, node->sm));
209 if (cmp == 0) {
210 *ret = node;
211 avl->count--;
213 if (node->lr[0] != NULL && node->lr[1] != NULL) {
214 AvlNode *replacement;
215 int side;
216 bool shrunk;
218 /* Pick a subtree to pull the replacement from such that
219 * this node doesn't have to be rebalanced. */
220 side = node->balance <= 0 ? 0 : 1;
222 shrunk = removeExtremum(&node->lr[side], 1 - side, &replacement);
224 replacement->lr[0] = node->lr[0];
225 replacement->lr[1] = node->lr[1];
226 replacement->balance = node->balance;
227 *p = replacement;
229 if (!shrunk)
230 return false;
232 replacement->balance -= bal(side);
234 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
235 return replacement->balance == 0;
238 if (node->lr[0] != NULL)
239 *p = node->lr[0];
240 else
241 *p = node->lr[1];
243 return true;
245 } else {
246 if (!remove_sm(avl, &node->lr[side(cmp)], sm, ret))
247 return false;
249 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
250 return sway(p, -cmp) == 0;
256 * Remove either the left-most (if side == 0) or right-most (if side == 1)
257 * node in a subtree, returning the removed node through *ret .
258 * The returned node's lr and balance are meaningless.
260 * The subtree must not be empty (i.e. *p must not be NULL).
262 * Return true if the subtree's height decreased.
264 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret)
266 AvlNode *node = *p;
268 if (node->lr[side] == NULL) {
269 *ret = node;
270 *p = node->lr[1 - side];
271 return true;
274 if (!removeExtremum(&node->lr[side], side, ret))
275 return false;
277 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
278 return sway(p, -bal(side)) == 0;
282 * Rebalance a node if necessary. Think of this function
283 * as a higher-level interface to balance().
285 * sway must be either -1 or 1, and indicates what was added to
286 * the balance of this node by a prior operation.
288 * Return the new balance of the subtree.
290 static int sway(AvlNode **p, int sway)
292 if ((*p)->balance != sway)
293 (*p)->balance += sway;
294 else
295 balance(p, side(sway));
297 return (*p)->balance;
301 * Perform tree rotations on an unbalanced node.
303 * side == 0 means the node's balance is -2 .
304 * side == 1 means the node's balance is +2 .
306 static void balance(AvlNode **p, int side)
308 AvlNode *node = *p,
309 *child = node->lr[side];
310 int opposite = 1 - side;
311 int bal = bal(side);
313 if (child->balance != -bal) {
314 /* Left-left (side == 0) or right-right (side == 1) */
315 node->lr[side] = child->lr[opposite];
316 child->lr[opposite] = node;
317 *p = child;
319 child->balance -= bal;
320 node->balance = -child->balance;
322 } else {
323 /* Left-right (side == 0) or right-left (side == 1) */
324 AvlNode *grandchild = child->lr[opposite];
326 node->lr[side] = grandchild->lr[opposite];
327 child->lr[opposite] = grandchild->lr[side];
328 grandchild->lr[side] = child;
329 grandchild->lr[opposite] = node;
330 *p = grandchild;
332 node->balance = 0;
333 child->balance = 0;
335 if (grandchild->balance == bal)
336 node->balance = -bal;
337 else if (grandchild->balance == -bal)
338 child->balance = bal;
340 grandchild->balance = 0;
345 /************************* avl_check_invariants() *************************/
347 bool avl_check_invariants(AVL *avl)
349 int dummy;
351 return checkBalances(avl->root, &dummy)
352 && checkOrder(avl)
353 && countNode(avl->root) == avl->count;
356 static bool checkBalances(AvlNode *node, int *height)
358 if (node) {
359 int h0, h1;
361 if (!checkBalances(node->lr[0], &h0))
362 return false;
363 if (!checkBalances(node->lr[1], &h1))
364 return false;
366 if (node->balance != h1 - h0 || node->balance < -1 || node->balance > 1)
367 return false;
369 *height = (h0 > h1 ? h0 : h1) + 1;
370 return true;
371 } else {
372 *height = 0;
373 return true;
377 static bool checkOrder(AVL *avl)
379 AvlIter i;
380 const struct sm_state *last = NULL;
381 bool last_set = false;
383 avl_foreach(i, avl) {
384 if (last_set && cmp_tracker(last, i.sm) >= 0)
385 return false;
386 last = i.sm;
387 last_set = true;
390 return true;
393 static size_t countNode(AvlNode *node)
395 if (node)
396 return 1 + countNode(node->lr[0]) + countNode(node->lr[1]);
397 else
398 return 0;
402 /************************* Traversal *************************/
404 void avl_iter_begin(AvlIter *iter, AVL *avl, AvlDirection dir)
406 AvlNode *node = avl->root;
408 iter->stack_index = 0;
409 iter->direction = dir;
411 if (node == NULL) {
412 iter->sm = NULL;
413 iter->node = NULL;
414 return;
417 while (node->lr[dir] != NULL) {
418 iter->stack[iter->stack_index++] = node;
419 node = node->lr[dir];
422 iter->sm = (struct sm_state *) node->sm;
423 iter->node = node;
426 void avl_iter_next(AvlIter *iter)
428 AvlNode *node = iter->node;
429 AvlDirection dir = iter->direction;
431 if (node == NULL)
432 return;
434 node = node->lr[1 - dir];
435 if (node != NULL) {
436 while (node->lr[dir] != NULL) {
437 iter->stack[iter->stack_index++] = node;
438 node = node->lr[dir];
440 } else if (iter->stack_index > 0) {
441 node = iter->stack[--iter->stack_index];
442 } else {
443 iter->sm = NULL;
444 iter->node = NULL;
445 return;
448 iter->node = node;
449 iter->sm = (struct sm_state *) node->sm;
452 AVL *avl_clone(AVL *orig)
454 AVL *new = NULL;
455 AvlIter i;
457 avl_foreach(i, orig)
458 avl_insert(&new, i.sm);
460 return new;