1 // Solve a sudoku with ZDDs.
10 // Construct ZDD of sets containing exactly 1 digit in every box
11 // (r, c), starting at the next free pool entry.
22 // and repeats every 10 levels:
25 // and so on until 729 --- F, 729 ... T.
27 // This ZDD has 9^81 members.
28 void global_one_digit_per_box() {
31 uint32_t n
= zdd_next_node();
32 for(int i
= 1; i
<= 729; i
++) {
33 zdd_add_node(i
, (i
% 9) ? 1 : 0, -1);
35 zdd_set_hi(zdd_last_node(), n
+ next
);
37 if (!(i
% 9)) next
+= 9;
41 // Construct ZDD of sets containing exactly 1 occurrence of digit d in row r.
42 // For instance, if d = 3, r = 0:
46 // 4a === 5a === ... === 9a === ... === 11a === 12
47 // 4b === 5b === ... === 9b === ... === 11b === 13b
48 // 12 ... 13a, --- 13b
49 // 13a === ... === 20a === 21
50 // 13b === ... === 20b === 22b
52 // 74a === 75 ... F, --- 76
54 // 76 === ... === 729 === T
56 // This ZDD has 9*2^720 members.
57 // When intersected with the one-digit-per-box set, the result has
58 // 9*8^8*9^72 members. (Pick 1 of 9 positions for d, leaving 8 possible choices
59 // for each of the remaining 8 boxes in that row. The other 81 - 9 boxes can contain
60 // any single digit.)
61 // The intersection over all d for a fixed r has 9!*9^72 members.
62 // The intersection over all r for a fixed d has 9^9*8^72 members.
63 void unique_digit_per_row(int d
, int r
) {
65 // The order is determined by sorting by number, then by letter.
66 int next
= 81 * r
+ d
; // The next node involving the digit d.
74 // The first split in the ZDD.
75 zdd_add_node(v
, 1, 2);
76 } else if (state
< 9) {
77 // Fix previous node. We must not have a second occurrence of d.
78 uint32_t n
= zdd_last_node();
79 zdd_set_hilo(n
, n
+ 3);
80 // If this is the first occurrence of d, we're on notice.
81 zdd_add_node(v
, 1, 2);
83 // If we never saw d, then branch to FALSE.
84 // Otherwise reunite the branches.
85 zdd_add_node(v
, 0, 1);
88 } else if (state
== 0 || state
== 9) {
89 zdd_add_node(v
, 1, 1);
91 zdd_add_node(v
, 2, 2);
92 zdd_add_node(v
, 2, 2);
97 uint32_t n
= zdd_last_node();
98 if (zdd_lo(n
- 1) > n
) zdd_set_lo(n
- 1, 1);
99 if (zdd_hi(n
- 1) > n
) zdd_set_hi(n
- 1, 1);
100 if (zdd_lo(n
) > n
) zdd_set_lo(n
, 1);
101 if (zdd_hi(n
) > n
) zdd_set_hi(n
, 1);
104 // Construct ZDD of sets containing all elements in the given list.
105 // The list is terminated by -1.
106 void contains_all(int *list
) {
113 zdd_add_node(v
, 0, 1);
115 zdd_add_node(v
, 1, 1);
120 uint32_t n
= zdd_last_node();
121 if (zdd_lo(n
) > n
) zdd_set_lo(n
, 1);
122 if (zdd_hi(n
) > n
) zdd_set_hi(n
, 1);
125 void unique_digit_per_col(int d
, int col
) {
127 for(int i
= 0; i
< 9; i
++) {
128 list
[i
] = 81 * i
+ 9 * col
+ d
;
130 zdd_contains_exactly_1(list
, 9);
133 void unique_digit_per_3x3(int d
, int row
, int col
) {
135 for(int i
= 0; i
< 3; i
++) {
136 for(int j
= 0; j
< 3; j
++) {
137 list
[i
* 3 + j
] = 81 * (i
+ 3 * row
) + 9 * (j
+ 3 * col
) + d
;
140 zdd_contains_exactly_1(list
, 9);
145 // The universe is {1, ..., 9^3 = 729}.
147 // Number rows and columns from 0. Digits are integers [1..9].
148 // The digit d at (r, c) is represented by element 81 r + 9 c + d.
151 for(int i
= 0; i
< 9; i
++) {
152 for(int j
= 0; j
< 9; j
++) {
154 if (EOF
== c
) die("unexpected EOF");
155 if ('\n' == c
) die("unexpected newline");
156 if (c
>= '1' && c
<= '9') {
157 inta_append(list
, 81 * i
+ 9 * j
+ c
- '0');
161 if (EOF
== c
) die("unexpected EOF");
162 if ('\n' != c
) die("expected newline");
164 inta_append(list
, -1);
165 contains_all(inta_raw(list
));
168 global_one_digit_per_box();
171 // Number of ways you can put nine 1s into a sudoku is
175 for(int i
= 1; i
<= 9; i
++) {
176 for(int r
= 0; r
< 9; r
++) {
177 unique_digit_per_row(i
, r
);
178 if (r
) zdd_intersection();
182 for(int i
= 1; i
<= 9; i
++) {
183 for(int c
= 0; c
< 3; c
++) {
184 for(int r
= 0; r
< 3; r
++) {
185 printf("3x3 %d: %d, %d\n", i
, r
, c
);
187 unique_digit_per_3x3(i
, r
, c
);
188 if (r
) zdd_intersection();
190 if (c
) zdd_intersection();
194 for(int i
= 1; i
<= 9; i
++) {
195 for(int c
= 0; c
< 9; c
++) {
196 printf("cols %d: %d\n", i
, c
);
198 unique_digit_per_col(i
, c
);
199 if (c
) zdd_intersection();
204 void printsol(int *v
, int vcount
) {
205 for(int i
= 0; i
< vcount
; i
++) {
206 putchar(((v
[i
] - 1) % 9) + '1');
207 if (8 == (i
% 9)) putchar('\n');
211 zdd_forall(printsol
);