Slight tweaks.
[zddfun.git] / zdd.c
blobbc74c1da6c903be396060140be0e5c2abd2e333c
1 // ZDD stack-based calculator library.
2 #include <stdarg.h>
3 #include <stdint.h>
4 #include <stdlib.h>
5 #include <stdio.h>
6 #include <string.h>
7 #include <gmp.h>
8 #include "memo.h"
9 #include "darray.h"
10 #include "zdd.h"
11 #include "io.h"
13 struct node_s {
14 uint16_t v;
15 uint32_t lo, hi;
17 typedef struct node_s *node_ptr;
18 typedef struct node_s node_t[1];
20 static node_t pool[1<<24];
21 static uint32_t freenode, POOL_MAX = (1<<24) - 1;
22 static darray_t stack;
23 static uint16_t vmax;
24 static char vmax_is_set;
26 uint16_t zdd_set_vmax(int i) {
27 vmax_is_set = 1;
28 return vmax = i;
31 void vmax_check() {
32 if (!vmax_is_set) die("vmax not set");
35 void zdd_push() { darray_append(stack, (void *) freenode); }
36 void zdd_pop() { darray_remove_last(stack); }
38 void set_node(uint32_t n, uint16_t v, uint32_t lo, uint32_t hi) {
39 pool[n]->v = v;
40 pool[n]->lo = lo;
41 pool[n]->hi = hi;
44 uint32_t zdd_v(uint32_t n) { return pool[n]->v; }
45 uint32_t zdd_hi(uint32_t n) { return pool[n]->hi; }
46 uint32_t zdd_lo(uint32_t n) { return pool[n]->lo; }
47 uint32_t zdd_set_lo(uint32_t n, uint32_t lo) { return pool[n]->lo = lo; }
48 uint32_t zdd_set_hi(uint32_t n, uint32_t hi) { return pool[n]->hi = hi; }
49 uint32_t zdd_set_hilo(uint32_t n, uint32_t hilo) {
50 return pool[n]->lo = pool[n]->hi = hilo;
52 uint32_t zdd_next_node() { return freenode; }
53 uint32_t zdd_last_node() { return freenode - 1; }
55 static void pool_swap(uint32_t x, uint32_t y) {
56 struct node_s tmp = *pool[y];
57 *pool[y] = *pool[x];
58 *pool[x] = tmp;
59 for(uint32_t i = 2; i < freenode; i++) {
60 if (pool[i]->lo == x) pool[i]->lo = y;
61 else if (pool[i]->lo == y) pool[i]->lo = x;
62 if (pool[i]->hi == x) pool[i]->hi = y;
63 else if (pool[i]->hi == y) pool[i]->hi = y;
67 uint32_t zdd_root() { return (uint32_t) darray_last(stack); }
68 uint32_t zdd_set_root(uint32_t root) {
69 uint32_t i = zdd_root();
70 if (i != root) pool_swap(i, root);
71 return i;
74 void zdd_count(mpz_ptr z) {
75 static mpz_ptr count[1<<24];
76 uint32_t root = zdd_root();
77 // Count elements in ZDD rooted at node n.
78 void get_count(uint32_t n) {
79 if (count[n]) return;
80 count[n] = malloc(sizeof(mpz_t));
81 mpz_init(count[n]);
82 if (n <= 1) {
83 mpz_set_ui(count[n], n);
84 return;
86 uint32_t x = pool[n]->lo;
87 uint32_t y = pool[n]->hi;
88 if (!count[x]) get_count(x);
89 if (!count[y]) get_count(y);
90 if (n == root) {
91 mpz_add(z, count[x], count[y]);
92 } else {
93 mpz_add(count[n], count[x], count[y]);
97 get_count(root);
99 void clearz(uint32_t n) {
100 if (!count[n]) return;
101 mpz_clear(count[n]);
102 free(count[n]);
103 count[n] = NULL;
104 if (n <= 1) return;
105 clearz(pool[n]->lo);
106 clearz(pool[n]->hi);
108 clearz(root);
111 uint32_t zdd_abs_node(uint32_t v, uint32_t lo, uint32_t hi) {
112 set_node(freenode, v, lo, hi);
113 return freenode++;
116 uint32_t zdd_add_node(uint32_t v, int offlo, int offhi) {
117 int n = freenode;
118 uint32_t adjust(int off) {
119 if (!off) return 0;
120 if (-1 == off) return 1;
121 return n + off;
123 set_node(n, v, adjust(offlo), adjust(offhi));
124 return freenode++;
127 uint32_t zdd_intersection() {
128 vmax_check();
129 if (darray_count(stack) == 0) return 0;
130 if (darray_count(stack) == 1) return (uint32_t) darray_last(stack);
131 uint32_t z0 = (uint32_t) darray_at(stack, darray_count(stack) - 2);
132 uint32_t z1 = (uint32_t) darray_remove_last(stack);
133 struct node_template_s {
134 uint16_t v;
135 // NULL means this template have been instantiated.
136 // Otherwise it points to the LO template.
137 memo_it lo;
138 union {
139 // Points to HI template when template is not yet instantiated.
140 memo_it hi;
141 // During template instantiation we set n to the pool index
142 // of the newly created node.
143 uint32_t n;
146 typedef struct node_template_s *node_template_ptr;
147 typedef struct node_template_s node_template_t[1];
149 node_template_t top, bot;
150 bot->v = 0;
151 bot->lo = NULL;
152 bot->n = 0;
153 top->v = 1;
154 top->lo = NULL;
155 top->n = 1;
157 // Naive implementation with two tries. One stores templates, the other
158 // unique nodes. See Knuth for how to meld using just memory allocated
159 // for a pool of nodes.
160 memo_t tab;
161 memo_init(tab);
163 memo_it insert_template(uint32_t k0, uint32_t k1) {
164 uint32_t key[2];
165 // Taking advantage of symmetry of intersection appears to help a tiny bit.
166 if (k0 < k1) {
167 key[0] = k0;
168 key[1] = k1;
169 } else {
170 key[0] = k1;
171 key[1] = k0;
173 memo_it it;
174 int just_created = memo_it_insert_u(&it, tab, (void *) key, 8);
175 if (!just_created) return it;
176 if (!k0 || !k1) {
177 memo_it_put(it, bot);
178 return it;
180 if (k0 == 1 && k1 == 1) {
181 memo_it_put(it, top);
182 return it;
184 node_ptr n0 = pool[k0];
185 node_ptr n1 = pool[k1];
186 if (n0->v == n1->v) {
187 node_template_ptr t = malloc(sizeof(*t));
188 t->v = n0->v;
189 if (n0->lo == n0->hi && n1->lo == n0->hi) {
190 t->lo = t->hi = insert_template(n0->lo, n1->lo);
191 } else {
192 t->lo = insert_template(n0->lo, n1->lo);
193 t->hi = insert_template(n0->hi, n1->hi);
195 memo_it_put(it, t);
196 return it;
197 } else if (n0->v < n1->v) {
198 memo_it it2 = insert_template(n0->lo, k1);
199 memo_it_put(it, memo_it_data(it2));
200 return it2;
201 } else {
202 memo_it it2 = insert_template(k0, n1->lo);
203 memo_it_put(it, memo_it_data(it2));
204 return it2;
208 void dump(void* data, const char* key) {
209 uint32_t *n = (uint32_t *) key;
210 if (!data) {
211 printf("NULL %d:%d\n", n[0], n[1]);
212 return;
214 node_template_ptr t = (node_template_ptr) data;
215 if (!t->lo) {
216 printf("%d:%d = (%d)\n", n[0], n[1], t->n);
217 return;
219 uint32_t *l = (uint32_t *) memo_it_key(t->lo);
220 uint32_t *h = (uint32_t *) memo_it_key(t->hi);
221 printf("%d:%d = %d:%d, %d:%d\n", n[0], n[1], l[0], l[1], h[0], h[1]);
224 memo_t node_tab[vmax + 1];
225 for(uint16_t v = 1; v <= vmax; v++) memo_init(node_tab[v]);
227 uint32_t unique(uint16_t v, uint32_t lo, uint32_t hi) {
228 // Create or return existing node representing !v ? lo : hi.
229 uint32_t key[2] = { lo, hi };
230 memo_it it;
231 int just_created = memo_it_insert_u(&it, node_tab[v], (void *) key, 8);
232 if (just_created) {
233 memo_it_put(it, (void *) freenode);
234 node_ptr n = pool[freenode];
235 n->v = v;
236 n->lo = lo;
237 n->hi = hi;
238 if (!(freenode << 15)) printf("freenode = %x\n", freenode);
239 if (POOL_MAX == freenode) {
240 die("pool is full");
242 return freenode++;
244 return (uint32_t) memo_it_data(it);
247 uint32_t instantiate(memo_it it) {
248 node_template_ptr t = (node_template_ptr) memo_it_data(it);
249 // Return if already converted to node.
250 if (!t->lo) return t->n;
251 // Recurse on LO, HI edges.
252 uint32_t lo = instantiate(t->lo);
253 uint32_t hi = instantiate(t->hi);
254 // Remove HI edges pointing to FALSE.
255 if (!hi) {
256 t->lo = NULL;
257 t->n = lo;
258 return lo;
260 // Convert to node.
261 uint32_t r = unique(t->v, lo, hi);
262 t->lo = NULL;
263 t->n = r;
264 return r;
267 insert_template(z0, z1);
268 freenode = z0; // Overwrite input trees.
269 //memo_forall(tab, dump);
270 uint32_t key[2];
271 key[0] = z0;
272 key[1] = z1;
273 memo_it it = memo_it_at_u(tab, (void *) key, 8);
274 uint32_t root = instantiate(it);
275 // TODO: What if the intersection is node 0 or 1?
276 if (root <= 1) {
277 die("root is 0 or 1!");
279 if (root < z0) {
280 *pool[z0] = *pool[root];
281 } else if (root > z0) {
282 pool_swap(z0, root);
284 void clear_it(void* data, const char* key) {
285 node_template_ptr t = (node_template_ptr) data;
286 uint32_t *k = (uint32_t *) key;
287 if (k[0] == k[1] && t != top && t != bot) free(t);
289 memo_forall(tab, clear_it);
290 memo_clear(tab);
292 for(uint16_t v = 1; v <= vmax; v++) memo_clear(node_tab[v]);
293 return z0;
296 void zdd_check() {
297 memo_t node_tab;
298 memo_init(node_tab);
299 for (uint32_t i = 2; i < freenode; i++) {
300 memo_it it;
301 uint32_t key[3];
302 key[0] = pool[i]->lo;
303 key[1] = pool[i]->hi;
304 key[2] = pool[i]->v;
305 if (!memo_it_insert_u(&it, node_tab, (void *) key, 12)) {
306 printf("duplicate: %d %d\n", i, (int) it->data);
307 } else {
308 it->data = (void *) i;
310 if (!pool[i]->hi) {
311 printf("HI -> FALSE: %d\n", i);
313 if (i == pool[i]->lo) {
314 printf("LO self-loop: %d\n", i);
316 if (i == pool[i]->hi) {
317 printf("HI self-loop: %d\n", i);
320 memo_clear(node_tab);
323 void zdd_init() {
324 // Initialize TRUE and FALSE nodes.
325 pool[0]->v = ~0;
326 pool[0]->lo = 0;
327 pool[0]->hi = 0;
328 pool[1]->v = ~0;
329 pool[1]->lo = 1;
330 pool[1]->hi = 1;
331 freenode = 2;
332 darray_init(stack);
335 void zdd_dump() {
336 for(uint32_t i = (uint32_t) darray_last(stack); i < freenode; i++) {
337 printf("I%d: !%d ? %d : %d\n", i, pool[i]->v, pool[i]->lo, pool[i]->hi);
341 uint32_t zdd_powerset() {
342 vmax_check();
343 uint16_t r = zdd_next_node();
344 zdd_push();
345 for(int v = 1; v < vmax; v++) zdd_add_node(v, 1, 1);
346 zdd_add_node(vmax, -1, -1);
347 return r;
350 void zdd_forall(void (*fn)(int *, int)) {
351 vmax_check();
352 int v[vmax], vcount = 0;
353 void recurse(uint32_t p) {
354 if (!p) return;
355 if (1 == p) {
356 fn(v, vcount);
357 return;
359 if (zdd_lo(p)) recurse(zdd_lo(p));
360 v[vcount++] = zdd_v(p);
361 recurse(zdd_hi(p));
362 vcount--;
364 recurse(zdd_root());
367 uint16_t zdd_vmax() {
368 vmax_check();
369 return vmax;
372 uint32_t zdd_size() {
373 return zdd_next_node() - zdd_root() + 2;
376 // Construct ZDD of sets containing exactly 1 of the elements in the given list.
377 // Zero suppression means we must treat sequences in the list carefully.
378 void zdd_contains_exactly_1(const int *a, int count) {
379 vmax_check();
380 zdd_push();
381 int v = 1;
382 int i = 0;
383 while(v <= vmax) {
384 if (i >= count) {
385 // Don't care about the rest of the elements.
386 zdd_add_node(v++, 1, 1);
387 } else if (v == a[i]) {
388 // Find length of consecutive sequence.
389 int k;
390 for(k = 0; i + k < count && v + k == a[i + k]; k++);
391 uint32_t n = zdd_next_node();
392 uint32_t h = v + k > vmax ? 1 : n + k + (count != i + k);
393 if (i >= 1) {
394 // In the middle of the list: must fix previous node; we reach said node
395 // if we've seen an element in the list already, in which case the
396 // arrows must bypass the entire sequence, i.e. we need the whole
397 // sequence to be out of the set.
398 //set_node(n - 1, v - 1, h, h);
399 zdd_set_hilo(n - 1, h);
401 i += k;
402 k += v;
403 while (v < k) {
404 // If we see an element, bypass the rest of the sequence (see above),
405 // otherwise we look for the next element in the sequence.
406 zdd_add_node(v++, 1, 1);
407 zdd_set_hi(zdd_last_node(), h);
408 //set_node(n, v++, n + 1, h);
409 //n++;
411 //v--;
412 if (count == i) {
413 // If none of the list showed up, then return false, otherwise,
414 // onwards! (Through the rest of the elements to the end.)
415 //set_node(n - 1, v, 0, h);
416 zdd_set_lo(zdd_last_node(), 0);
417 zdd_set_hi(zdd_last_node(), h);
419 } else if (!i) {
420 // We don't care about the membership of elements before the list.
421 zdd_add_node(v++, 1, 1);
422 } else {
423 zdd_add_node(v, 2, 2);
424 zdd_add_node(v, 2, 2);
425 v++;
428 // Fix last node.
429 uint32_t last = zdd_last_node();
430 if (zdd_lo(last) > last) zdd_set_lo(last, 1);
431 if (zdd_hi(last) > last) zdd_set_hi(last, 1);
434 // Construct ZDD of sets containing at most 1 of the elements in the given
435 // list.
436 void zdd_contains_at_most_1(const int *a, int count) {
437 vmax_check();
438 zdd_push();
439 uint32_t n = zdd_last_node();
440 // Start with ZDD of all sets.
441 int v = 1;
442 while(v < vmax) {
443 zdd_add_node(v++, 1, 1);
445 zdd_add_node(v, -1, -1);
446 // If there is nothing or only one element in the list then we are done.
447 if (count <= 1) return;
449 // At this point, there are at least two elements in the list.
450 // Construct new branch for when elements of the list are detected. We
451 // branch off at the first element, then hop over all remaining elements,
452 // then rejoin.
453 v = a[0];
455 uint32_t n1 = zdd_next_node();
456 zdd_set_hi(n + v, n1);
457 v++;
458 uint32_t last = 0;
459 for(int i = 1; i < count; i++) {
460 int v1 = a[i];
461 while(v < v1) {
462 last = zdd_add_node(v++, 1, 1);
464 zdd_set_hi(n + v, zdd_next_node());
465 v++;
467 // v = last element of list + 1
469 // The HI edges of the last element of the list, and more generally, the last
470 // sequence of the list must be corrected.
471 for(int v1 = a[count - 1]; zdd_hi(n + v1) == zdd_next_node(); v1--) {
472 zdd_set_hi(n + v1, n + v);
475 if (vmax < v) {
476 // Special case: list ends with vmax. Especially troublesome if there's
477 // a little sequence, e.g. vmax - 2, vmax - 1, vmax.
478 for(v = vmax; zdd_hi(n + v) > n + vmax; v--) {
479 zdd_set_hi(n + v, 1);
481 // The following line is only needed if we added any nodes to the branch,
482 // but is harmless if we execute it unconditionally since the last node
483 // to be added was (!vmax ? 1 : 1).
484 zdd_set_hilo(zdd_last_node(), 1);
485 return;
488 // Rejoin main branch.
489 if (last) zdd_set_hilo(last, n + v);
492 // Construct ZDD of sets containing at least 1 of the elements in the given
493 // list.
494 void zdd_contains_at_least_1(const int *a, int count) {
495 vmax_check();
496 zdd_push();
497 uint32_t n = zdd_last_node();
498 // Start with ZDD of all sets.
499 int v = 1;
500 while(v < vmax) {
501 zdd_add_node(v++, 1, 1);
503 zdd_add_node(v, -1, -1);
504 if (!count) return;
506 // Construct new branch for when elements of the list are not found.
507 v = a[0];
508 if (1 == count) {
509 zdd_set_lo(n + v, 0);
510 return;
513 uint32_t n1 = zdd_next_node();
514 zdd_set_lo(n + v, n1);
515 v++;
516 for(int i = 1; i < count; i++) {
517 int v1 = a[i];
518 while(v <= v1) {
519 zdd_add_node(v++, 1, 1);
521 zdd_set_hi(zdd_last_node(), n + v);
524 zdd_set_lo(zdd_last_node(), 0);
525 if (vmax < v) zdd_set_hi(zdd_last_node(), 1);
528 // Construct ZDD of sets not containing any elements from the given list.
529 // Assumes not every variable is on the list.
530 void zdd_contains_0(const int *a, int count) {
531 vmax_check();
532 zdd_push();
533 int i = 1;
534 int v1 = count ? a[0] : -1;
535 for(int v = 1; v <= vmax; v++) {
536 if (v1 == v) {
537 v1 = i < count ? a[i++] : -1;
538 } else {
539 zdd_add_node(v, 1, 1);
542 uint32_t n = zdd_last_node();
543 zdd_set_lo(n, 1);
544 zdd_set_hi(n, 1);
547 // Construct ZDD of sets containing exactly 1 element for each interval
548 // [a_k, a_{k+1}) in given list. List must start with a_0 = 1, while there is an
549 // implied vmax + 1 at end of list, so the last interval is [a_n, vmax + 1).
551 // The ZDD begins:
552 // 1 ... 2
553 // 1 --- a_1
554 // 2 ... 3
555 // 2 --- a_1
556 // ...
557 // a_1 - 1 ... F
558 // a_1 - 1 --- a_1
560 // and so on:
561 // a_k ... a_k + 1
562 // a_k --- a_{k+1}
563 // and so on until vmax --- F, vmax ... T.
564 void zdd_1_per_interval(const int* list, int count) {
565 vmax_check();
566 zdd_push();
567 // Check list[0] is 1.
568 int i = 0;
569 uint32_t n = zdd_last_node();
570 int get() {
571 i++;
572 //return i < inta_count(a) ? inta_at(a, i) : -1;
573 return i < count ? list[i] : -1;
575 int target = get();
576 for (int v = 1; v <= vmax; v++) {
577 zdd_abs_node(v, n + v + 1, target > 0 ? n + target : 1);
578 if (v == target - 1 || v == vmax) {
579 zdd_set_lo(zdd_last_node(), 0);
580 target = get();
585 // Construct ZDD of sets containing exactly n of the elements in the
586 // given list.
587 void zdd_contains_exactly_n(int n, const int *a, int count) {
588 zdd_push();
589 if (n > count) {
590 die("unhandled special case (should return empty family");
592 // Lookup table for sub-ZDDs we construct recursively.
593 uint32_t tab[count][n + 1];
594 memset(tab, 0, count * (n + 1) * sizeof(uint32_t));
595 uint32_t recurse(int i, int n) {
596 // The outermost invocation is a special case, as other invocations
597 // assume part of the ZDD has already been built. We have i == -1
598 // during this special case.
599 int v = -1 == i ? 1 : a[i] + 1;
600 uint32_t root;
601 if (i == count - 1) {
602 // Base case: by now, n should be zero, so finish off the ZDD with
603 // everything leading to TRUE.
604 // We can reach here even in the first invocation of recurse(); this
605 // happens if there is nothing in the list.
606 if (-1 != i && tab[i][0]) return tab[i][0];
607 if (vmax < v) {
608 root = 1;
609 } else {
610 root = zdd_next_node();
611 while(v < vmax) zdd_add_node(v++, 1, 1);
612 zdd_add_node(v, -1, -1);
614 if (-1 != i) tab[i][0] = root;
615 return root;
617 if (-1 != i && tab[i][n]) return tab[i][n];
618 int v1 = a[i + 1];
619 int is_empty = v == v1;
620 root = zdd_next_node();
621 while(v < v1) zdd_add_node(v++, 1, 1);
622 if (!n) {
623 if (is_empty) {
624 root = recurse(i + 1, n);
625 } else {
626 uint32_t last = zdd_last_node();
627 zdd_set_hilo(last, recurse(i + 1, n));
629 if (-1 != i) tab[i][n] = root;
630 return root;
632 uint32_t last = zdd_add_node(v, 0, 0);
633 // If we include this variable, then that's one down, n - 1 more to go
634 // in the remaining.
635 zdd_set_hi(last, recurse(i + 1, n - 1));
636 if (n < count - i - 1) {
637 // If there are enough unexamined nodes we can leave this variable
638 // out and still make the quota.
639 zdd_set_lo(last, recurse(i + 1, n));
641 if (-1 != i) tab[i][n] = root;
642 return root;
644 recurse(-1, n);