Slight tweaks.
[zddfun.git] / sud.c
blob519dcd51e9d71839a1055786e384b1b6331a2df2
1 // Solve a sudoku with ZDDs.
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <stdio.h>
5 #include "inta.h"
6 #include "zdd.h"
7 #include <stdarg.h>
8 #include "io.h"
10 // Construct ZDD of sets containing exactly 1 digit for all boxes
11 // (r, c), starting at pool entry d.
13 // The ZDD begins:
14 // 1 ... 2
15 // 1 --- 10
16 // 2 ... 3
17 // 2 --- 10
18 // ...
19 // 9 ... F
20 // 9 --- 10
22 // and repeats every 10 levels:
23 // 10 ... 11
24 // 10 --- 19
25 // and so on until 729 --- F, 729 ... T.
27 // This ZDD has 9^81 members.
void global_one_digit_per_box() {
  // Pushes a new ZDD and builds one node per element 1..729, grouped into
  // 81 boxes of 9 consecutive elements each.
  zdd_push();
  // Index the first node built here will receive. Assuming nodes are
  // allocated consecutively (TODO confirm in zdd.h), the node for element v
  // is base + v - 1.
  uint32_t base = zdd_next_node();
  // Last element of the box currently being filled in.
  int box_end = 9;
  for (int v = 1; v <= 729; v++) {
    int closes_box = (0 == v % 9);
    // LO steps to the following node, except on the ninth digit of a box
    // where the LO offset is 0 (apparently FALSE: no digit was chosen for
    // the box). HI starts out as -1 and is usually overwritten below.
    zdd_add_node(v, closes_box ? 0 : 1, -1);
    // For every box but the last, choosing a digit jumps straight to the
    // first element of the next box, skipping the box's other digits.
    if (box_end < 729) {
      zdd_set_hi(zdd_last_node(), base + box_end);
    }
    if (closes_box) {
      box_end += 9;
    }
  }
}
41 // Construct ZDD of sets containing exactly 1 occurrence of digit d in row r.
42 // For instance, if d = 3, r = 0:
44 // 1 === 2 === 3
45 // 3 ... 4a, --- 4b
46 // 4a === 5a === ... === 9a === ... === 11a === 12
47 // 4b === 5b === ... === 9b === ... === 11b === 13b
48 // 12 ... 13a, --- 13b
49 // 13a === ... === 20a === 21
50 // 13b === ... === 20b === 22b
51 // and so on until:
52 // 74a === 75 ... F, --- 76
53 // 74b === 76
54 // 76 === ... === 729 === T
56 // This ZDD has 9*2^720 members.
57 // When intersected with the one-digit-per-box set, the result has
58 // 9*8^8*9^72 members. (Pick 1 of 9 positions for d, leaving 8 possible choices
59 // for the remaining 8 boxes in that row. The other 81 - 9 boxes can contain
60 // any single digit.)
61 // The intersection for all d and a fixed r has 9!*9^72 members.
62 // The intersection for all r and a fixed d has 9^9*8^72 members.
void unique_digit_per_row(int d, int r) {
  // Pushes a new ZDD whose nodes are emitted in order of element number.
  // Between the first and last element carrying digit d in row r, two nodes
  // are emitted per element (the "a" and "b" chains of the diagram).
  zdd_push();
  int target = 81 * r + d;  // Next element that places digit d in row r.
  int seen = 0;             // How many such elements have been emitted.
  for (int v = 1; v <= 729; v++) {
    if (v != target) {
      if (0 == seen || 9 == seen) {
        // Outside the row's span of d-elements: a single chain.
        zdd_add_node(v, 1, 1);
      } else {
        // Inside the span: one node for each of the two chains.
        zdd_add_node(v, 2, 2);
        zdd_add_node(v, 2, 2);
      }
      continue;
    }

    // v places digit d somewhere in row r.
    target += 9;
    seen++;
    if (seen >= 9) {
      // Last candidate position. If d never appeared, branch to FALSE;
      // otherwise the two chains are reunited.
      zdd_add_node(v, 0, 1);
      target = -1;  // No element matches from here on.
      continue;
    }
    if (seen > 1) {
      // Patch the node emitted just before this one so that it bypasses
      // the d-node: a second occurrence of d must not be possible.
      // (zdd_set_hilo presumably sets both children -- confirm in zdd.h.)
      uint32_t prev = zdd_last_node();
      zdd_set_hilo(prev, prev + 3);
    }
    // First split (seen == 1), or a further chance for d's first occurrence.
    zdd_add_node(v, 1, 2);
  }

  // The final two nodes may still point past the end of what was built;
  // redirect any such child to node 1 (presumably the TRUE terminal).
  uint32_t last = zdd_last_node();
  uint32_t penult = last - 1;
  if (zdd_lo(penult) > last) {
    zdd_set_lo(penult, 1);
  }
  if (zdd_hi(penult) > last) {
    zdd_set_hi(penult, 1);
  }
  if (zdd_lo(last) > last) {
    zdd_set_lo(last, 1);
  }
  if (zdd_hi(last) > last) {
    zdd_set_hi(last, 1);
  }
}
104 // Construct ZDD of sets containing all elements in the given list.
105 // The list is terminated by -1.
void contains_all(int *list) {
  // Pushes a new ZDD with one node per element 1..729. The cursor only
  // moves forward, so list entries are matched in increasing order of
  // element number; the -1 terminator never matches an element.
  zdd_push();
  int *want = list;
  for (int v = 1; v <= 729; v++) {
    int required = (*want == v);
    if (required) {
      want++;
    }
    // A required element gets LO offset 0 (apparently FALSE, so it cannot
    // be left out); any other element may be present or absent.
    zdd_add_node(v, required ? 0 : 1, 1);
  }

  // The node for 729 may point past the end of what was built; redirect
  // such children to node 1 (presumably the TRUE terminal).
  uint32_t last = zdd_last_node();
  if (zdd_lo(last) > last) {
    zdd_set_lo(last, 1);
  }
  if (zdd_hi(last) > last) {
    zdd_set_hi(last, 1);
  }
}
// Collect the 9 elements that place digit d in column col (one per row) and
// hand them to zdd_contains_exactly_1.
void unique_digit_per_col(int d, int col) {
  int list[9];
  // Part of the element index that is the same in every row.
  int base = 9 * col + d;
  for (int row = 0; row < 9; row++) {
    list[row] = 81 * row + base;
  }
  zdd_contains_exactly_1(list, 9);
}
// Collect the 9 elements that place digit d in the 3x3 region at region
// coordinates (row, col) and hand them to zdd_contains_exactly_1.
void unique_digit_per_3x3(int d, int row, int col) {
  int list[9];
  for (int k = 0; k < 9; k++) {
    // Slot k covers the box at offset (k / 3, k % 3) inside the region.
    int r = k / 3 + 3 * row;
    int c = k % 3 + 3 * col;
    list[k] = 81 * r + 9 * c + d;
  }
  zdd_contains_exactly_1(list, 9);
}
143 int main() {
144 zdd_init();
145 // The universe is {1, ..., 9^3 = 729}.
146 zdd_set_vmax(729);
147 // Number rows and columns from 0. Digits are integers [1..9].
148 // The digit d at (r, c) is represented by element 81 r + 9 c + d.
149 inta_t list;
150 inta_init(list);
151 for(int i = 0; i < 9; i++) {
152 for(int j = 0; j < 9; j++) {
153 int c = getchar();
154 if (EOF == c) die("unexpected EOF");
155 if ('\n' == c) die("unexpected newline");
156 if (c >= '1' && c <= '9') {
157 inta_append(list, 81 * i + 9 * j + c - '0');
160 int c = getchar();
161 if (EOF == c) die("unexpected EOF");
162 if ('\n' != c) die("expected newline");
164 inta_append(list, -1);
165 contains_all(inta_raw(list));
166 inta_clear(list);
168 global_one_digit_per_box();
169 zdd_intersection();
171 // Number of ways you can put nine 1s into a sudoku is
172 // 9*6*3*6*3*4*2*2.
173 printf("rows\n");
174 fflush(stdout);
175 for(int i = 1; i <= 9; i++) {
176 for(int r = 0; r < 9; r++) {
177 unique_digit_per_row(i, r);
178 if (r) zdd_intersection();
180 zdd_intersection();
182 for(int i = 1; i <= 9; i++) {
183 for(int c = 0; c < 3; c++) {
184 for(int r = 0; r < 3; r++) {
185 printf("3x3 %d: %d, %d\n", i, r, c);
186 fflush(stdout);
187 unique_digit_per_3x3(i, r, c);
188 if (r) zdd_intersection();
190 if (c) zdd_intersection();
192 zdd_intersection();
194 for(int i = 1; i <= 9; i++) {
195 for(int c = 0; c < 9; c++) {
196 printf("cols %d: %d\n", i, c);
197 fflush(stdout);
198 unique_digit_per_col(i, c);
199 if (c) zdd_intersection();
201 zdd_intersection();
204 void printsol(int *v, int vcount) {
205 for(int i = 0; i < vcount; i++) {
206 putchar(((v[i] - 1) % 9) + '1');
207 if (8 == (i % 9)) putchar('\n');
209 putchar('\n');
211 zdd_forall(printsol);
212 return 0;