kernel: parse ATOMIC_SET() manually
[smatch.git] / avl.c
blobf87ceb0ec26bb592740b93e83de2bdd01367e37d
1 /*
2 * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to deal
6 * in the Software without restriction, including without limitation the rights
7 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 * copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice shall be included in
12 * all copies or substantial portions of the Software.
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 * THE SOFTWARE.
23 #include <assert.h>
24 #include <stdlib.h>
26 #include "smatch.h"
27 #include "smatch_slist.h"
29 static AvlNode *mkNode(const struct sm_state *sm);
30 static void freeNode(AvlNode *node);
32 static AvlNode *lookup(const struct stree *avl, AvlNode *node, const struct sm_state *sm);
34 static bool insert_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm);
35 static bool remove_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret);
36 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret);
38 static int sway(AvlNode **p, int sway);
39 static void balance(AvlNode **p, int side);
41 static bool checkBalances(AvlNode *node, int *height);
42 static bool checkOrder(struct stree *avl);
43 static size_t countNode(AvlNode *node);
45 int unfree_stree;
/*
 * Convert between "balance" values (-1 or +1) and "side" values
 * (0 = left, 1 = right), the index used for node->lr[]:
 *
 *   bal(0)  == -1        side(-1) == 0
 *   bal(1)  == +1        side(+1) == 1
 *
 * Any nonzero side maps to +1, and any balance other than exactly 1 maps
 * to 0, so callers must pass normalized values.
 *
 * These stay function-like macros rather than inline functions because
 * balance() declares a local variable named `bal` that is initialized from
 * bal(side).
 */
#define bal(s)  ((s) ? 1 : -1)
#define side(b) ((b) == 1)
59 struct stree *avl_new(void)
61 struct stree *avl = malloc(sizeof(*avl));
63 unfree_stree++;
64 assert(avl != NULL);
66 avl->root = NULL;
67 avl->base_stree = NULL;
68 avl->count = 0;
69 avl->stree_id = 0;
70 avl->references = 1;
71 return avl;
74 void free_stree(struct stree **avl)
76 if (!*avl)
77 return;
79 assert((*avl)->references > 0);
81 (*avl)->references--;
82 if ((*avl)->references != 0) {
83 *avl = NULL;
84 return;
87 unfree_stree--;
89 freeNode((*avl)->root);
90 free(*avl);
91 *avl = NULL;
94 struct sm_state *avl_lookup(const struct stree *avl, const struct sm_state *sm)
96 AvlNode *found;
98 if (!avl)
99 return NULL;
100 found = lookup(avl, avl->root, sm);
101 if (!found)
102 return NULL;
103 return (struct sm_state *)found->sm;
106 AvlNode *avl_lookup_node(const struct stree *avl, const struct sm_state *sm)
108 return lookup(avl, avl->root, sm);
111 size_t stree_count(const struct stree *avl)
113 if (!avl)
114 return 0;
115 return avl->count;
118 static struct stree *clone_stree_real(struct stree *orig)
120 struct stree *new = avl_new();
121 AvlIter i;
123 avl_foreach(i, orig)
124 avl_insert(&new, i.sm);
126 new->base_stree = orig->base_stree;
127 return new;
130 bool avl_insert(struct stree **avl, const struct sm_state *sm)
132 size_t old_count;
134 if (!*avl)
135 *avl = avl_new();
136 if ((*avl)->references > 1) {
137 (*avl)->references--;
138 *avl = clone_stree_real(*avl);
140 old_count = (*avl)->count;
141 insert_sm(*avl, &(*avl)->root, sm);
142 return (*avl)->count != old_count;
145 bool avl_remove(struct stree **avl, const struct sm_state *sm)
147 AvlNode *node = NULL;
149 if (!*avl)
150 return false;
151 /* it's fairly rare for smatch to call avl_remove */
152 if ((*avl)->references > 1) {
153 (*avl)->references--;
154 *avl = clone_stree_real(*avl);
157 remove_sm(*avl, &(*avl)->root, sm, &node);
159 if ((*avl)->count == 0)
160 free_stree(avl);
162 if (node == NULL) {
163 return false;
164 } else {
165 free(node);
166 return true;
170 static AvlNode *mkNode(const struct sm_state *sm)
172 AvlNode *node = malloc(sizeof(*node));
174 assert(node != NULL);
176 node->sm = sm;
177 node->lr[0] = NULL;
178 node->lr[1] = NULL;
179 node->balance = 0;
180 return node;
183 static void freeNode(AvlNode *node)
185 if (node) {
186 freeNode(node->lr[0]);
187 freeNode(node->lr[1]);
188 free(node);
192 static AvlNode *lookup(const struct stree *avl, AvlNode *node, const struct sm_state *sm)
194 int cmp;
196 if (node == NULL)
197 return NULL;
199 cmp = cmp_tracker(sm, node->sm);
201 if (cmp < 0)
202 return lookup(avl, node->lr[0], sm);
203 if (cmp > 0)
204 return lookup(avl, node->lr[1], sm);
205 return node;
209 * Insert an sm into a subtree, rebalancing if necessary.
211 * Return true if the subtree's height increased.
213 static bool insert_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm)
215 if (*p == NULL) {
216 *p = mkNode(sm);
217 avl->count++;
218 return true;
219 } else {
220 AvlNode *node = *p;
221 int cmp = cmp_tracker(sm, node->sm);
223 if (cmp == 0) {
224 node->sm = sm;
225 return false;
228 if (!insert_sm(avl, &node->lr[side(cmp)], sm))
229 return false;
231 /* If tree's balance became -1 or 1, it means the tree's height grew due to insertion. */
232 return sway(p, cmp) != 0;
237 * Remove the node matching the given sm.
238 * If present, return the removed node through *ret .
239 * The returned node's lr and balance are meaningless.
241 * Return true if the subtree's height decreased.
243 static bool remove_sm(struct stree *avl, AvlNode **p, const struct sm_state *sm, AvlNode **ret)
245 if (p == NULL || *p == NULL) {
246 return false;
247 } else {
248 AvlNode *node = *p;
249 int cmp = cmp_tracker(sm, node->sm);
251 if (cmp == 0) {
252 *ret = node;
253 avl->count--;
255 if (node->lr[0] != NULL && node->lr[1] != NULL) {
256 AvlNode *replacement;
257 int side;
258 bool shrunk;
260 /* Pick a subtree to pull the replacement from such that
261 * this node doesn't have to be rebalanced. */
262 side = node->balance <= 0 ? 0 : 1;
264 shrunk = removeExtremum(&node->lr[side], 1 - side, &replacement);
266 replacement->lr[0] = node->lr[0];
267 replacement->lr[1] = node->lr[1];
268 replacement->balance = node->balance;
269 *p = replacement;
271 if (!shrunk)
272 return false;
274 replacement->balance -= bal(side);
276 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
277 return replacement->balance == 0;
280 if (node->lr[0] != NULL)
281 *p = node->lr[0];
282 else
283 *p = node->lr[1];
285 return true;
287 } else {
288 if (!remove_sm(avl, &node->lr[side(cmp)], sm, ret))
289 return false;
291 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
292 return sway(p, -cmp) == 0;
298 * Remove either the left-most (if side == 0) or right-most (if side == 1)
299 * node in a subtree, returning the removed node through *ret .
300 * The returned node's lr and balance are meaningless.
302 * The subtree must not be empty (i.e. *p must not be NULL).
304 * Return true if the subtree's height decreased.
306 static bool removeExtremum(AvlNode **p, int side, AvlNode **ret)
308 AvlNode *node = *p;
310 if (node->lr[side] == NULL) {
311 *ret = node;
312 *p = node->lr[1 - side];
313 return true;
316 if (!removeExtremum(&node->lr[side], side, ret))
317 return false;
319 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
320 return sway(p, -bal(side)) == 0;
324 * Rebalance a node if necessary. Think of this function
325 * as a higher-level interface to balance().
327 * sway must be either -1 or 1, and indicates what was added to
328 * the balance of this node by a prior operation.
330 * Return the new balance of the subtree.
332 static int sway(AvlNode **p, int sway)
334 if ((*p)->balance != sway)
335 (*p)->balance += sway;
336 else
337 balance(p, side(sway));
339 return (*p)->balance;
/*
 * Perform tree rotations on an unbalanced node.
 *
 * side == 0 means the node's balance is -2 .
 * side == 1 means the node's balance is +2 .
 *
 * @side is the heavier side. sway() calls this instead of performing the
 * second step, so the node's stored balance is still bal(side) on entry.
 * The rotation then writes the final balances of every node it moves
 * directly. On return, *p points at the new subtree root.
 */
static void balance(AvlNode **p, int side)
{
	AvlNode *node = *p,
		*child = node->lr[side];
	int opposite = 1 - side;
	int bal = bal(side);	/* +/-1 leaning toward @side */

	if (child->balance != -bal) {
		/* Left-left (side == 0) or right-right (side == 1) */
		/* Single rotation: child becomes the root; its inner subtree moves under node. */
		node->lr[side] = child->lr[opposite];
		child->lr[opposite] = node;
		*p = child;

		/*
		 * child leaned toward @side or was even (not -bal); afterwards
		 * child and node carry mirror-image balances.
		 */
		child->balance -= bal;
		node->balance = -child->balance;

	} else {
		/* Left-right (side == 0) or right-left (side == 1) */
		/*
		 * Double rotation: child leans away from @side, so its inner child
		 * (grandchild) becomes the root with child and node as its children.
		 */
		AvlNode *grandchild = child->lr[opposite];

		node->lr[side] = grandchild->lr[opposite];
		child->lr[opposite] = grandchild->lr[side];
		grandchild->lr[side] = child;
		grandchild->lr[opposite] = node;
		*p = grandchild;

		/* The final balances depend only on which way grandchild leaned. */
		node->balance = 0;
		child->balance = 0;

		if (grandchild->balance == bal)
			node->balance = -bal;
		else if (grandchild->balance == -bal)
			child->balance = bal;

		grandchild->balance = 0;
	}
}
387 /************************* avl_check_invariants() *************************/
389 bool avl_check_invariants(struct stree *avl)
391 int dummy;
393 return checkBalances(avl->root, &dummy)
394 && checkOrder(avl)
395 && countNode(avl->root) == avl->count;
398 static bool checkBalances(AvlNode *node, int *height)
400 if (node) {
401 int h0, h1;
403 if (!checkBalances(node->lr[0], &h0))
404 return false;
405 if (!checkBalances(node->lr[1], &h1))
406 return false;
408 if (node->balance != h1 - h0 || node->balance < -1 || node->balance > 1)
409 return false;
411 *height = (h0 > h1 ? h0 : h1) + 1;
412 return true;
413 } else {
414 *height = 0;
415 return true;
419 static bool checkOrder(struct stree *avl)
421 AvlIter i;
422 const struct sm_state *last = NULL;
423 bool last_set = false;
425 avl_foreach(i, avl) {
426 if (last_set && cmp_tracker(last, i.sm) >= 0)
427 return false;
428 last = i.sm;
429 last_set = true;
432 return true;
435 static size_t countNode(AvlNode *node)
437 if (node)
438 return 1 + countNode(node->lr[0]) + countNode(node->lr[1]);
439 else
440 return 0;
444 /************************* Traversal *************************/
446 void avl_iter_begin(AvlIter *iter, struct stree *avl, AvlDirection dir)
448 AvlNode *node;
450 iter->stack_index = 0;
451 iter->direction = dir;
453 if (!avl || !avl->root) {
454 iter->sm = NULL;
455 iter->node = NULL;
456 return;
458 node = avl->root;
460 while (node->lr[dir] != NULL) {
461 iter->stack[iter->stack_index++] = node;
462 node = node->lr[dir];
465 iter->sm = (struct sm_state *) node->sm;
466 iter->node = node;
469 void avl_iter_next(AvlIter *iter)
471 AvlNode *node = iter->node;
472 AvlDirection dir = iter->direction;
474 if (node == NULL)
475 return;
477 node = node->lr[1 - dir];
478 if (node != NULL) {
479 while (node->lr[dir] != NULL) {
480 iter->stack[iter->stack_index++] = node;
481 node = node->lr[dir];
483 } else if (iter->stack_index > 0) {
484 node = iter->stack[--iter->stack_index];
485 } else {
486 iter->sm = NULL;
487 iter->node = NULL;
488 return;
491 iter->node = node;
492 iter->sm = (struct sm_state *) node->sm;
495 struct stree *clone_stree(struct stree *orig)
497 if (!orig)
498 return NULL;
500 orig->references++;
501 return orig;
504 void set_stree_id(struct stree **stree, int stree_id)
506 if ((*stree)->stree_id != 0)
507 *stree = clone_stree_real(*stree);
509 (*stree)->stree_id = stree_id;
512 int get_stree_id(struct stree *stree)
514 if (!stree)
515 return -1;
516 return stree->stree_id;