1 // ZDD stack-based calculator library.
17 typedef struct node_s
*node_ptr
;
18 typedef struct node_s node_t
[1];
20 static node_t pool
[1<<24];
21 static uint32_t freenode
, POOL_MAX
= (1<<24) - 1;
22 static darray_t stack
;
24 static char vmax_is_set
;
26 uint16_t zdd_set_vmax(int i
) {
32 if (!vmax_is_set
) die("vmax not set");
35 void zdd_push() { darray_append(stack
, (void *) freenode
); }
36 void zdd_pop() { darray_remove_last(stack
); }
38 void set_node(uint32_t n
, uint16_t v
, uint32_t lo
, uint32_t hi
) {
44 uint32_t zdd_v(uint32_t n
) { return pool
[n
]->v
; }
45 uint32_t zdd_hi(uint32_t n
) { return pool
[n
]->hi
; }
46 uint32_t zdd_lo(uint32_t n
) { return pool
[n
]->lo
; }
47 uint32_t zdd_set_lo(uint32_t n
, uint32_t lo
) { return pool
[n
]->lo
= lo
; }
48 uint32_t zdd_set_hi(uint32_t n
, uint32_t hi
) { return pool
[n
]->hi
= hi
; }
49 uint32_t zdd_set_hilo(uint32_t n
, uint32_t hilo
) {
50 return pool
[n
]->lo
= pool
[n
]->hi
= hilo
;
52 uint32_t zdd_next_node() { return freenode
; }
53 uint32_t zdd_last_node() { return freenode
- 1; }
55 static void pool_swap(uint32_t x
, uint32_t y
) {
56 struct node_s tmp
= *pool
[y
];
59 for(uint32_t i
= 2; i
< freenode
; i
++) {
60 if (pool
[i
]->lo
== x
) pool
[i
]->lo
= y
;
61 else if (pool
[i
]->lo
== y
) pool
[i
]->lo
= x
;
62 if (pool
[i
]->hi
== x
) pool
[i
]->hi
= y
;
63 else if (pool
[i
]->hi
== y
) pool
[i
]->hi
= y
;
67 uint32_t zdd_root() { return (uint32_t) darray_last(stack
); }
68 uint32_t zdd_set_root(uint32_t root
) {
69 uint32_t i
= zdd_root();
70 if (i
!= root
) pool_swap(i
, root
);
74 void zdd_count(mpz_ptr z
) {
75 static mpz_ptr count
[1<<24];
76 uint32_t root
= zdd_root();
77 // Count elements in ZDD rooted at node n.
78 void get_count(uint32_t n
) {
80 count
[n
] = malloc(sizeof(mpz_t
));
83 mpz_set_ui(count
[n
], n
);
86 uint32_t x
= pool
[n
]->lo
;
87 uint32_t y
= pool
[n
]->hi
;
88 if (!count
[x
]) get_count(x
);
89 if (!count
[y
]) get_count(y
);
91 mpz_add(z
, count
[x
], count
[y
]);
93 mpz_add(count
[n
], count
[x
], count
[y
]);
99 void clearz(uint32_t n
) {
100 if (!count
[n
]) return;
111 uint32_t zdd_abs_node(uint32_t v
, uint32_t lo
, uint32_t hi
) {
112 set_node(freenode
, v
, lo
, hi
);
116 uint32_t zdd_add_node(uint32_t v
, int offlo
, int offhi
) {
118 uint32_t adjust(int off
) {
120 if (-1 == off
) return 1;
123 set_node(n
, v
, adjust(offlo
), adjust(offhi
));
127 uint32_t zdd_intersection() {
129 if (darray_count(stack
) == 0) return 0;
130 if (darray_count(stack
) == 1) return (uint32_t) darray_last(stack
);
131 uint32_t z0
= (uint32_t) darray_at(stack
, darray_count(stack
) - 2);
132 uint32_t z1
= (uint32_t) darray_remove_last(stack
);
133 struct node_template_s
{
135 // NULL means this template have been instantiated.
136 // Otherwise it points to the LO template.
139 // Points to HI template when template is not yet instantiated.
141 // During template instantiation we set n to the pool index
142 // of the newly created node.
146 typedef struct node_template_s
*node_template_ptr
;
147 typedef struct node_template_s node_template_t
[1];
149 node_template_t top
, bot
;
157 // Naive implementation with two tries. One stores templates, the other
158 // unique nodes. See Knuth for how to meld using just memory allocated
159 // for a pool of nodes.
163 memo_it
insert_template(uint32_t k0
, uint32_t k1
) {
165 // Taking advantage of symmetry of intersection appears to help a tiny bit.
174 int just_created
= memo_it_insert_u(&it
, tab
, (void *) key
, 8);
175 if (!just_created
) return it
;
177 memo_it_put(it
, bot
);
180 if (k0
== 1 && k1
== 1) {
181 memo_it_put(it
, top
);
184 node_ptr n0
= pool
[k0
];
185 node_ptr n1
= pool
[k1
];
186 if (n0
->v
== n1
->v
) {
187 node_template_ptr t
= malloc(sizeof(*t
));
189 if (n0
->lo
== n0
->hi
&& n1
->lo
== n0
->hi
) {
190 t
->lo
= t
->hi
= insert_template(n0
->lo
, n1
->lo
);
192 t
->lo
= insert_template(n0
->lo
, n1
->lo
);
193 t
->hi
= insert_template(n0
->hi
, n1
->hi
);
197 } else if (n0
->v
< n1
->v
) {
198 memo_it it2
= insert_template(n0
->lo
, k1
);
199 memo_it_put(it
, memo_it_data(it2
));
202 memo_it it2
= insert_template(k0
, n1
->lo
);
203 memo_it_put(it
, memo_it_data(it2
));
208 void dump(void* data
, const char* key
) {
209 uint32_t *n
= (uint32_t *) key
;
211 printf("NULL %d:%d\n", n
[0], n
[1]);
214 node_template_ptr t
= (node_template_ptr
) data
;
216 printf("%d:%d = (%d)\n", n
[0], n
[1], t
->n
);
219 uint32_t *l
= (uint32_t *) memo_it_key(t
->lo
);
220 uint32_t *h
= (uint32_t *) memo_it_key(t
->hi
);
221 printf("%d:%d = %d:%d, %d:%d\n", n
[0], n
[1], l
[0], l
[1], h
[0], h
[1]);
224 memo_t node_tab
[vmax
+ 1];
225 for(uint16_t v
= 1; v
<= vmax
; v
++) memo_init(node_tab
[v
]);
227 uint32_t unique(uint16_t v
, uint32_t lo
, uint32_t hi
) {
228 // Create or return existing node representing !v ? lo : hi.
229 uint32_t key
[2] = { lo
, hi
};
231 int just_created
= memo_it_insert_u(&it
, node_tab
[v
], (void *) key
, 8);
233 memo_it_put(it
, (void *) freenode
);
234 node_ptr n
= pool
[freenode
];
238 if (!(freenode
<< 15)) printf("freenode = %x\n", freenode
);
239 if (POOL_MAX
== freenode
) {
244 return (uint32_t) memo_it_data(it
);
247 uint32_t instantiate(memo_it it
) {
248 node_template_ptr t
= (node_template_ptr
) memo_it_data(it
);
249 // Return if already converted to node.
250 if (!t
->lo
) return t
->n
;
251 // Recurse on LO, HI edges.
252 uint32_t lo
= instantiate(t
->lo
);
253 uint32_t hi
= instantiate(t
->hi
);
254 // Remove HI edges pointing to FALSE.
261 uint32_t r
= unique(t
->v
, lo
, hi
);
267 insert_template(z0
, z1
);
268 freenode
= z0
; // Overwrite input trees.
269 //memo_forall(tab, dump);
273 memo_it it
= memo_it_at_u(tab
, (void *) key
, 8);
274 uint32_t root
= instantiate(it
);
275 // TODO: What if the intersection is node 0 or 1?
277 die("root is 0 or 1!");
280 *pool
[z0
] = *pool
[root
];
281 } else if (root
> z0
) {
284 void clear_it(void* data
, const char* key
) {
285 node_template_ptr t
= (node_template_ptr
) data
;
286 uint32_t *k
= (uint32_t *) key
;
287 if (k
[0] == k
[1] && t
!= top
&& t
!= bot
) free(t
);
289 memo_forall(tab
, clear_it
);
292 for(uint16_t v
= 1; v
<= vmax
; v
++) memo_clear(node_tab
[v
]);
299 for (uint32_t i
= 2; i
< freenode
; i
++) {
302 key
[0] = pool
[i
]->lo
;
303 key
[1] = pool
[i
]->hi
;
305 if (!memo_it_insert_u(&it
, node_tab
, (void *) key
, 12)) {
306 printf("duplicate: %d %d\n", i
, (int) it
->data
);
308 it
->data
= (void *) i
;
311 printf("HI -> FALSE: %d\n", i
);
313 if (i
== pool
[i
]->lo
) {
314 printf("LO self-loop: %d\n", i
);
316 if (i
== pool
[i
]->hi
) {
317 printf("HI self-loop: %d\n", i
);
320 memo_clear(node_tab
);
324 // Initialize TRUE and FALSE nodes.
336 for(uint32_t i
= (uint32_t) darray_last(stack
); i
< freenode
; i
++) {
337 printf("I%d: !%d ? %d : %d\n", i
, pool
[i
]->v
, pool
[i
]->lo
, pool
[i
]->hi
);
341 uint32_t zdd_powerset() {
343 uint16_t r
= zdd_next_node();
345 for(int v
= 1; v
< vmax
; v
++) zdd_add_node(v
, 1, 1);
346 zdd_add_node(vmax
, -1, -1);
350 void zdd_forall(void (*fn
)(int *, int)) {
352 int v
[vmax
], vcount
= 0;
353 void recurse(uint32_t p
) {
359 if (zdd_lo(p
)) recurse(zdd_lo(p
));
360 v
[vcount
++] = zdd_v(p
);
367 uint16_t zdd_vmax() {
372 uint32_t zdd_size() {
373 return zdd_next_node() - zdd_root() + 2;
376 // Construct ZDD of sets containing exactly 1 of the elements in the given list.
377 // Zero suppression means we must treat sequences in the list carefully.
378 void zdd_contains_exactly_1(const int *a
, int count
) {
385 // Don't care about the rest of the elements.
386 zdd_add_node(v
++, 1, 1);
387 } else if (v
== a
[i
]) {
388 // Find length of consecutive sequence.
390 for(k
= 0; i
+ k
< count
&& v
+ k
== a
[i
+ k
]; k
++);
391 uint32_t n
= zdd_next_node();
392 uint32_t h
= v
+ k
> vmax
? 1 : n
+ k
+ (count
!= i
+ k
);
394 // In the middle of the list: must fix previous node; we reach said node
395 // if we've seen an element in the list already, in which case the
396 // arrows must bypass the entire sequence, i.e. we need the whole
397 // sequence to be out of the set.
398 //set_node(n - 1, v - 1, h, h);
399 zdd_set_hilo(n
- 1, h
);
404 // If we see an element, bypass the rest of the sequence (see above),
405 // otherwise we look for the next element in the sequence.
406 zdd_add_node(v
++, 1, 1);
407 zdd_set_hi(zdd_last_node(), h
);
408 //set_node(n, v++, n + 1, h);
413 // If none of the list showed up, then return false, otherwise,
414 // onwards! (Through the rest of the elements to the end.)
415 //set_node(n - 1, v, 0, h);
416 zdd_set_lo(zdd_last_node(), 0);
417 zdd_set_hi(zdd_last_node(), h
);
420 // We don't care about the membership of elements before the list.
421 zdd_add_node(v
++, 1, 1);
423 zdd_add_node(v
, 2, 2);
424 zdd_add_node(v
, 2, 2);
429 uint32_t last
= zdd_last_node();
430 if (zdd_lo(last
) > last
) zdd_set_lo(last
, 1);
431 if (zdd_hi(last
) > last
) zdd_set_hi(last
, 1);
434 // Construct ZDD of sets containing at most 1 of the elements in the given
436 void zdd_contains_at_most_1(const int *a
, int count
) {
439 uint32_t n
= zdd_last_node();
440 // Start with ZDD of all sets.
443 zdd_add_node(v
++, 1, 1);
445 zdd_add_node(v
, -1, -1);
446 // If there is nothing or only one element in the list then we are done.
447 if (count
<= 1) return;
449 // At this point, there are at least two elements in the list.
450 // Construct new branch for when elements of the list are detected. We
451 // branch off at the first element, then hop over all remaining elements,
455 uint32_t n1
= zdd_next_node();
456 zdd_set_hi(n
+ v
, n1
);
459 for(int i
= 1; i
< count
; i
++) {
462 last
= zdd_add_node(v
++, 1, 1);
464 zdd_set_hi(n
+ v
, zdd_next_node());
467 // v = last element of list + 1
469 // The HI edges of the last element of the list, and more generally, the last
470 // sequence of the list must be corrected.
471 for(int v1
= a
[count
- 1]; zdd_hi(n
+ v1
) == zdd_next_node(); v1
--) {
472 zdd_set_hi(n
+ v1
, n
+ v
);
476 // Special case: list ends with vmax. Especially troublesome if there's
477 // a little sequence, e.g. vmax - 2, vmax - 1, vmax.
478 for(v
= vmax
; zdd_hi(n
+ v
) > n
+ vmax
; v
--) {
479 zdd_set_hi(n
+ v
, 1);
481 // The following line is only needed if we added any nodes to the branch,
482 // but is harmless if we execute it unconditionally since the last node
483 // to be added was (!vmax ? 1 : 1).
484 zdd_set_hilo(zdd_last_node(), 1);
488 // Rejoin main branch.
489 if (last
) zdd_set_hilo(last
, n
+ v
);
492 // Construct ZDD of sets containing at least 1 of the elements in the given
494 void zdd_contains_at_least_1(const int *a
, int count
) {
497 uint32_t n
= zdd_last_node();
498 // Start with ZDD of all sets.
501 zdd_add_node(v
++, 1, 1);
503 zdd_add_node(v
, -1, -1);
506 // Construct new branch for when elements of the list are not found.
509 zdd_set_lo(n
+ v
, 0);
513 uint32_t n1
= zdd_next_node();
514 zdd_set_lo(n
+ v
, n1
);
516 for(int i
= 1; i
< count
; i
++) {
519 zdd_add_node(v
++, 1, 1);
521 zdd_set_hi(zdd_last_node(), n
+ v
);
524 zdd_set_lo(zdd_last_node(), 0);
525 if (vmax
< v
) zdd_set_hi(zdd_last_node(), 1);
528 // Construct ZDD of sets not containing any elements from the given list.
529 // Assumes not every variable is on the list.
530 void zdd_contains_0(const int *a
, int count
) {
534 int v1
= count
? a
[0] : -1;
535 for(int v
= 1; v
<= vmax
; v
++) {
537 v1
= i
< count
? a
[i
++] : -1;
539 zdd_add_node(v
, 1, 1);
542 uint32_t n
= zdd_last_node();
547 // Construct ZDD of sets containing exactly 1 element for each interval
548 // [a_k, a_{k+1}) in given list. List must start with a_0 = 1, while there is an
549 // implied vmax + 1 at end of list, so the last interval is [a_n, vmax + 1).
563 // and so on until vmax --- F, vmax ... T.
564 void zdd_1_per_interval(const int* list
, int count
) {
567 // Check list[0] is 1.
569 uint32_t n
= zdd_last_node();
572 //return i < inta_count(a) ? inta_at(a, i) : -1;
573 return i
< count
? list
[i
] : -1;
576 for (int v
= 1; v
<= vmax
; v
++) {
577 zdd_abs_node(v
, n
+ v
+ 1, target
> 0 ? n
+ target
: 1);
578 if (v
== target
- 1 || v
== vmax
) {
579 zdd_set_lo(zdd_last_node(), 0);
585 // Construct ZDD of sets containing exactly n of the elements in the
587 void zdd_contains_exactly_n(int n
, const int *a
, int count
) {
590 die("unhandled special case (should return empty family");
592 // Lookup table for sub-ZDDs we construct recursively.
593 uint32_t tab
[count
][n
+ 1];
594 memset(tab
, 0, count
* (n
+ 1) * sizeof(uint32_t));
595 uint32_t recurse(int i
, int n
) {
596 // The outermost invocation is a special case, as other invocations
597 // assume part of the ZDD has already been built. We have i == -1
598 // during this special case.
599 int v
= -1 == i
? 1 : a
[i
] + 1;
601 if (i
== count
- 1) {
602 // Base case: by now, n should be zero, so finish off the ZDD with
603 // everything leading to TRUE.
604 // We can reach here even in the first invocation of recurse(); this
605 // happens if there is nothing in the list.
606 if (-1 != i
&& tab
[i
][0]) return tab
[i
][0];
610 root
= zdd_next_node();
611 while(v
< vmax
) zdd_add_node(v
++, 1, 1);
612 zdd_add_node(v
, -1, -1);
614 if (-1 != i
) tab
[i
][0] = root
;
617 if (-1 != i
&& tab
[i
][n
]) return tab
[i
][n
];
619 int is_empty
= v
== v1
;
620 root
= zdd_next_node();
621 while(v
< v1
) zdd_add_node(v
++, 1, 1);
624 root
= recurse(i
+ 1, n
);
626 uint32_t last
= zdd_last_node();
627 zdd_set_hilo(last
, recurse(i
+ 1, n
));
629 if (-1 != i
) tab
[i
][n
] = root
;
632 uint32_t last
= zdd_add_node(v
, 0, 0);
633 // If we include this variable, then that's one down, n - 1 more to go
635 zdd_set_hi(last
, recurse(i
+ 1, n
- 1));
636 if (n
< count
- i
- 1) {
637 // If there are enough unexamined nodes we can leave this variable
638 // out and still make the quota.
639 zdd_set_lo(last
, recurse(i
+ 1, n
));
641 if (-1 != i
) tab
[i
][n
] = root
;