/*
 * Copyright (C) 2010 Joseph Adams <joeyadams3.14159@gmail.com>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
27 #include "smatch_slist.h"
29 static AvlNode
*mkNode(const void *key
, const void *value
);
30 static void freeNode(AvlNode
*node
);
32 static AvlNode
*lookup(const AVL
*avl
, AvlNode
*node
, const void *key
);
34 static bool insert_sm(AVL
*avl
, AvlNode
**p
, const void *key
, const void *value
);
35 static bool remove_sm(AVL
*avl
, AvlNode
**p
, const void *key
, AvlNode
**ret
);
36 static bool removeExtremum(AvlNode
**p
, int side
, AvlNode
**ret
);
38 static int sway(AvlNode
**p
, int sway
);
39 static void balance(AvlNode
**p
, int side
);
41 static bool checkBalances(AvlNode
*node
, int *height
);
42 static bool checkOrder(AVL
*avl
);
43 static size_t countNode(AvlNode
*node
);
/*
 * Utility macros for converting between
 * "balance" values (-1 or 1) and "side" values (0 or 1).
 *
 *   bal(0)   == -1
 *   bal(1)   ==  1
 *   side(-1) ==  0
 *   side(1)  ==  1
 *
 * These are function-like macros, so the plain identifiers `bal` and `side`
 * used as variable and parameter names elsewhere in this file are left alone;
 * only uses followed by '(' are expanded.
 */
#define bal(s)  ((s) != 0 ? 1 : -1)
#define side(b) ((b) == 1)
57 static int sign(int cmp
)
66 AVL
*avl_new(AvlCompare compare
)
68 AVL
*avl
= malloc(sizeof(*avl
));
72 avl
->compare
= compare
;
78 void avl_free(AVL
*avl
)
84 void *avl_lookup(const AVL
*avl
, const void *key
)
86 AvlNode
*found
= lookup(avl
, avl
->root
, key
);
87 return found
? (void*) found
->value
: NULL
;
90 AvlNode
*avl_lookup_node(const AVL
*avl
, const void *key
)
92 return lookup(avl
, avl
->root
, key
);
95 size_t avl_count(const AVL
*avl
)
100 bool avl_insert(AVL
*avl
, const void *key
, const void *value
)
102 size_t old_count
= avl
->count
;
103 insert_sm(avl
, &avl
->root
, key
, value
);
104 return avl
->count
!= old_count
;
107 bool avl_remove(AVL
*avl
, const void *key
)
109 AvlNode
*node
= NULL
;
111 remove_sm(avl
, &avl
->root
, key
, &node
);
121 static AvlNode
*mkNode(const void *key
, const void *value
)
123 AvlNode
*node
= malloc(sizeof(*node
));
125 assert(node
!= NULL
);
135 static void freeNode(AvlNode
*node
)
138 freeNode(node
->lr
[0]);
139 freeNode(node
->lr
[1]);
144 static AvlNode
*lookup(const AVL
*avl
, AvlNode
*node
, const void *key
)
151 cmp
= avl
->compare(key
, node
->key
);
154 return lookup(avl
, node
->lr
[0], key
);
156 return lookup(avl
, node
->lr
[1], key
);
/*
 * Insert a key/value into a subtree, rebalancing if necessary.
 *
 * Return true if the subtree's height increased.
 */
165 static bool insert_sm(AVL
*avl
, AvlNode
**p
, const void *key
, const void *value
)
168 *p
= mkNode(key
, value
);
173 int cmp
= sign(avl
->compare(key
, node
->key
));
181 if (!insert_sm(avl
, &node
->lr
[side(cmp
)], key
, value
))
184 /* If tree's balance became -1 or 1, it means the tree's height grew due to insertion. */
185 return sway(p
, cmp
) != 0;
/*
 * Remove the node matching the given key.
 * If present, return the removed node through *ret.
 * The returned node's lr and balance are meaningless.
 *
 * Return true if the subtree's height decreased.
 */
196 static bool remove_sm(AVL
*avl
, AvlNode
**p
, const void *key
, AvlNode
**ret
)
202 int cmp
= sign(avl
->compare(key
, node
->key
));
208 if (node
->lr
[0] != NULL
&& node
->lr
[1] != NULL
) {
209 AvlNode
*replacement
;
213 /* Pick a subtree to pull the replacement from such that
214 * this node doesn't have to be rebalanced. */
215 side
= node
->balance
<= 0 ? 0 : 1;
217 shrunk
= removeExtremum(&node
->lr
[side
], 1 - side
, &replacement
);
219 replacement
->lr
[0] = node
->lr
[0];
220 replacement
->lr
[1] = node
->lr
[1];
221 replacement
->balance
= node
->balance
;
227 replacement
->balance
-= bal(side
);
229 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
230 return replacement
->balance
== 0;
233 if (node
->lr
[0] != NULL
)
241 if (!remove_sm(avl
, &node
->lr
[side(cmp
)], key
, ret
))
244 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
245 return sway(p
, -cmp
) == 0;
/*
 * Remove either the left-most (if side == 0) or right-most (if side == 1)
 * node in a subtree, returning the removed node through *ret.
 * The returned node's lr and balance are meaningless.
 *
 * The subtree must not be empty (i.e. *p must not be NULL).
 *
 * Return true if the subtree's height decreased.
 */
259 static bool removeExtremum(AvlNode
**p
, int side
, AvlNode
**ret
)
263 if (node
->lr
[side
] == NULL
) {
265 *p
= node
->lr
[1 - side
];
269 if (!removeExtremum(&node
->lr
[side
], side
, ret
))
272 /* If tree's balance became 0, it means the tree's height shrank due to removal. */
273 return sway(p
, -bal(side
)) == 0;
/*
 * Rebalance a node if necessary.  Think of this function
 * as a higher-level interface to balance().
 *
 * sway must be either -1 or 1, and indicates what was added to
 * the balance of this node by a prior operation.
 *
 * Return the new balance of the subtree.
 */
285 static int sway(AvlNode
**p
, int sway
)
287 if ((*p
)->balance
!= sway
)
288 (*p
)->balance
+= sway
;
290 balance(p
, side(sway
));
292 return (*p
)->balance
;
/*
 * Perform tree rotations on an unbalanced node.
 *
 * side == 0 means the node's balance is -2.
 * side == 1 means the node's balance is +2.
 */
301 static void balance(AvlNode
**p
, int side
)
304 *child
= node
->lr
[side
];
305 int opposite
= 1 - side
;
308 if (child
->balance
!= -bal
) {
309 /* Left-left (side == 0) or right-right (side == 1) */
310 node
->lr
[side
] = child
->lr
[opposite
];
311 child
->lr
[opposite
] = node
;
314 child
->balance
-= bal
;
315 node
->balance
= -child
->balance
;
318 /* Left-right (side == 0) or right-left (side == 1) */
319 AvlNode
*grandchild
= child
->lr
[opposite
];
321 node
->lr
[side
] = grandchild
->lr
[opposite
];
322 child
->lr
[opposite
] = grandchild
->lr
[side
];
323 grandchild
->lr
[side
] = child
;
324 grandchild
->lr
[opposite
] = node
;
330 if (grandchild
->balance
== bal
)
331 node
->balance
= -bal
;
332 else if (grandchild
->balance
== -bal
)
333 child
->balance
= bal
;
335 grandchild
->balance
= 0;
340 /************************* avl_check_invariants() *************************/
342 bool avl_check_invariants(AVL
*avl
)
346 return checkBalances(avl
->root
, &dummy
)
348 && countNode(avl
->root
) == avl
->count
;
351 static bool checkBalances(AvlNode
*node
, int *height
)
356 if (!checkBalances(node
->lr
[0], &h0
))
358 if (!checkBalances(node
->lr
[1], &h1
))
361 if (node
->balance
!= h1
- h0
|| node
->balance
< -1 || node
->balance
> 1)
364 *height
= (h0
> h1
? h0
: h1
) + 1;
372 static bool checkOrder(AVL
*avl
)
375 const void *last
= NULL
;
376 bool last_set
= false;
378 avl_foreach(i
, avl
) {
379 if (last_set
&& avl
->compare(last
, i
.key
) >= 0)
388 static size_t countNode(AvlNode
*node
)
391 return 1 + countNode(node
->lr
[0]) + countNode(node
->lr
[1]);
397 /************************* Traversal *************************/
399 void avl_iter_begin(AvlIter
*iter
, AVL
*avl
, AvlDirection dir
)
401 AvlNode
*node
= avl
->root
;
403 iter
->stack_index
= 0;
404 iter
->direction
= dir
;
413 while (node
->lr
[dir
] != NULL
) {
414 iter
->stack
[iter
->stack_index
++] = node
;
415 node
= node
->lr
[dir
];
418 iter
->key
= (void*) node
->key
;
419 iter
->value
= (void*) node
->value
;
423 void avl_iter_next(AvlIter
*iter
)
425 AvlNode
*node
= iter
->node
;
426 AvlDirection dir
= iter
->direction
;
431 node
= node
->lr
[1 - dir
];
433 while (node
->lr
[dir
] != NULL
) {
434 iter
->stack
[iter
->stack_index
++] = node
;
435 node
= node
->lr
[dir
];
437 } else if (iter
->stack_index
> 0) {
438 node
= iter
->stack
[--iter
->stack_index
];
447 iter
->key
= (void*) node
->key
;
448 iter
->value
= (void*) node
->value
;