day 25 optimize and improve heuristics
[aoc_eblake.git] / 2019 / day22.c
blobc23346d9f01536dda8a9005d8029c37c721032a4
1 #define _GNU_SOURCE 1
2 #include <stdio.h>
3 #include <string.h>
4 #include <stdlib.h>
5 #include <stdarg.h>
6 #include <stdbool.h>
7 #include <stdint.h>
8 #include <inttypes.h>
/* Current debug verbosity; negative means "not yet read from $DEBUG". */
static int debug_level = -1;

/*
 * Load the debug verbosity from the DEBUG environment variable the first
 * time it is needed.  An unset variable is treated as "0".  Calls made
 * after the level is known (non-negative) leave it untouched.
 */
static void
debug_init(void) {
  const char *env;

  if (debug_level >= 0)
    return;
  env = getenv("DEBUG");
  debug_level = atoi(env ? env : "0");
}
17 static void __attribute__((format(printf, 2, 3)))
18 debug_raw(int level, const char *fmt, ...) {
19 va_list ap;
20 if (debug_level < 0)
21 debug_level = atoi(getenv("DEBUG") ?: "0");
22 if (debug_level >= level) {
23 va_start(ap, fmt);
24 vfprintf(stderr, fmt, ap);
25 va_end(ap);
28 #define debug(...) debug_raw(1, __VA_ARGS__)
/*
 * Report a fatal error and terminate.  The printf-style message is written
 * to stdout followed by a newline, then the process exits with status 1.
 * This function never returns.
 */
static int __attribute__((noreturn)) __attribute__((format(printf, 1, 2)))
die(const char *fmt, ...)
{
  va_list args;

  va_start(args, fmt);
  vfprintf(stdout, fmt, args);
  va_end(args);
  fputc('\n', stdout);
  exit(1);
}
/* Capacity of the actions[] table (maximum number of shuffle steps). */
#define LIMIT 100

/* Kinds of shuffle step: deal into new stack, cut, deal with increment. */
enum op {
  NEW,
  CUT,
  DEAL
};

/* One shuffle step. */
struct act {
  enum op op;    /* which technique this step applies */
  int n;         /* numeric argument of CUT / DEAL; left 0 for NEW */
  long long n1;  /* for DEAL: modular inverse of n for the current deck size */
};

/* Shuffle steps in input order; only the first nactions entries are in use. */
static struct act actions[LIMIT];
static int nactions;
/* https://www.geeksforgeeks.org/multiplicative-inverse-under-modulo-m/
 *
 * Return the multiplicative inverse of a modulo m, i.e. the x in [0, m)
 * with (a * x) % m == 1, computed with the extended Euclidean algorithm.
 *
 * a may be any value (it is first reduced into [0, m)); m should be > 1.
 * Returns 0 when no inverse exists: m <= 1, a is a multiple of m, or a and
 * m share a common factor.  0 is never a valid inverse for m > 1, so
 * callers can use it as a failure indicator. */
static long long
modInverse(long long a, long long m) {
  long long m0 = m;
  long long y = 0, x = 1;

  if (m <= 1)
    return 0;
  /* Canonicalize a so negative or oversized inputs behave like a mod m. */
  a %= m;
  if (a < 0)
    a += m;
  if (a == 0)
    return 0;

  while (a > 1) {
    long long q, t;

    /* Remainder hit 0 while gcd > 1: not coprime, and a / m would trap. */
    if (m == 0)
      return 0;
    q = a / m;
    t = m;
    m = a % m;
    a = t;
    t = y;
    y = x - q * y;
    x = t;
  }

  /* The Bezout coefficient may be negative; fold it back into [0, m0). */
  if (x < 0)
    x += m0;
  return x;
}
/* https://www.geeksforgeeks.org/how-to-avoid-overflow-in-modular-multiplication/
 *
 * Return (a * b) % mod as a value in [0, mod) without overflowing, using
 * binary "double and add" instead of a direct multiplication.
 *
 * mod must be positive.  a and b may be any values, including negative
 * ones: both are first reduced into [0, mod), so the result is always the
 * canonical non-negative residue. */
static long long
modMul(long long a, long long b, long long mod) {
  unsigned long long m = (unsigned long long)mod;
  unsigned long long addend;
  unsigned long long r = 0;

  a %= mod;
  if (a < 0)
    a += mod;
  b %= mod;
  if (b < 0)
    b += mod;

  /* With both terms below mod <= LLONG_MAX, the sums and doublings below
   * stay under 2^64, so unsigned arithmetic cannot wrap. */
  addend = (unsigned long long)a;
  while (b > 0) {
    if (b & 1)
      r = (r + addend) % m;
    addend = (addend * 2) % m;
    b /= 2;
  }
  return (long long)r;
}
/* https://en.wikipedia.org/wiki/Modular_exponentiation
 *
 * Raise base to the power exp modulo m by repeated squaring: each bit of
 * exp, lowest first, decides whether the current square is multiplied into
 * the result.  All products go through modMul().  exp is presumably meant
 * to be non-negative. */
static long long
modPow(long long base, long long exp, long long m) {
  long long acc = 1;

  for (; exp != 0; exp /= 2, base = modMul(base, base, m)) {
    if (exp & 1)
      acc = modMul(acc, base, m);
  }
  return acc;
}
98 static long long
99 shuffle(long long size, long long track) {
100 int i;
102 debug("shuffling deck size %lld, while tracking card %lld\n", size, track);
103 for (i = 0; i < nactions; i++) {
104 switch (actions[i].op) {
105 case NEW:
106 track = size - 1 - track;
107 debug("new deck moved card to %lld\n", track);
108 break;
109 case CUT:
110 track = (size + track - actions[i].n) % size;
111 debug("cut %d moved card to %lld\n", actions[i].n, track);
112 break;
113 case DEAL:
114 track = modMul(track, actions[i].n, size);
115 debug("deal %d moved card to %lld\n", actions[i].n, track);
116 break;
119 return track;
122 static long long
123 unshuffle(long long size, long long track) {
124 int i;
125 static int iter;
127 debug("on unshuffle %d, deck size %lld undoing pos %lld\n",
128 iter, size, track);
129 for (i = nactions - 1; i >= 0; i--) {
130 switch (actions[i].op) {
131 case NEW:
132 track = size - 1 - track;
133 debug("new deck moved card from %lld\n", track);
134 break;
135 case CUT:
136 track = (size + track + actions[i].n) % size;
137 debug("cut %d moved card from %lld\n", actions[i].n, track);
138 break;
139 case DEAL:
140 track = modMul(track, actions[i].n1, size);
141 debug("deal %d moved card from %lld\n", actions[i].n, track);
142 break;
145 return track;
148 int main(int argc, char **argv) {
149 size_t len = 0;
150 char *line;
151 long long size = 10007, track = 2019, remaining = 101741582076661, pos = 2020;
152 int i, part1;
153 long long a, b, y, z, part2;
155 debug_init();
156 if (argc > 1 && strcmp(argv[1], "-"))
157 if (!(stdin = freopen(argv[1], "r", stdin))) {
158 perror("failure");
159 exit(2);
162 if (argc > 2)
163 size = atoi(argv[2]);
164 if (argc > 3)
165 track = atoi(argv[3]);
166 if (0U + track > size)
167 die("card %lld not in deck of size %lld\n", track, size);
168 if (argc > 4)
169 remaining = atoll(argv[4]);
171 while ((i = getline(&line, &len, stdin)) >= 0) {
172 if (nactions >= LIMIT)
173 die("recompile with larger LIMIT");
174 if (line[0] == 'c') { /* "cut N" */
175 actions[nactions].op = CUT;
176 actions[nactions].n = atoi(line + 4);
177 } else if (line[5] == 'w') { /* "deal with increment N" */
178 actions[nactions].op = DEAL;
179 actions[nactions].n = atoi(line + 20);
180 } else { /* "deal into new stack */
181 actions[nactions].op = NEW;
183 nactions++;
185 part1 = shuffle(size, track);
186 printf("after shuffling %lld cards %d times, card %lld is at position %d\n",
187 size, nactions, track, part1);
188 for (i = 0; i < nactions; i++)
189 if (actions[i].op == DEAL)
190 actions[i].n1 = modInverse(actions[i].n, size);
191 printf("sanity check: pos %d started with card %lld\n", part1,
192 unshuffle(size, part1));
194 /* With hints from https://www.reddit.com/r/adventofcode/comments/ee0rqi/2019_day_22_solutions/fbnifwk/
195 X = 2020
196 Y = f(X)
197 Z = f(Y) = f(f(X))
198 f(i) = A*i + B
199 A*X + B = Y
200 A*Y + B = Z
201 A*(X - Y) = (Y - Z)
202 A = (Y - Z) * (X - Y)^-1
203 B = Y - A*X
204 f(f(x)) = A*(A*x + B)+ B = A^2*x + A*x*B + B
205 f^n(x) = A^n*x + A^(n-1)*B + A^(n-2)*B ... + B
206 = A^n*x + (A^(n-1) + A^(n-2) ... + 1) * B
207 = A^n*x + (A^(n-1) + A^(n-2) ... + 1)*(A-1) / (A-1) * B
208 = A^n*x + (A^n - 1) *(A-1)^-1 * B
210 size = 119315717514047;
211 printf("for reference, shuffle(0)=%lld shuffle(1)=%lld\n",
212 shuffle(size, 0), shuffle(size, 1));
213 printf("preparing to unshuffle %lld cards %lld times\n", size, remaining);
214 for (i = 0; i < nactions; i++)
215 if (actions[i].op == DEAL)
216 actions[i].n1 = modInverse(actions[i].n, size);
217 y = unshuffle(size, pos);
218 z = unshuffle(size, y);
219 a = modMul((y - z + size) % size, modInverse((pos + size - y) % size,
220 size), size);
221 b = (y + size - modMul(a, pos, size)) % size;
222 printf("computed y=%lld z=%lld a=%lld b=%lld\n", y, z, a, b);
223 part2 = (modMul(modPow(a, remaining, size), pos, size) +
224 modMul(modMul(modPow(a, remaining, size) - 1,
225 modInverse(a - 1, size), size), b, size)) % size;
226 printf("position %lld contains %lld\n", pos, part2);
227 return 0;