Lua: Don't lua_error() out of context with pending dtors
[lsnes.git] / src / library / curve25519.cpp
blob6c7d39a5389268ddcce66f69f74a97715f643abb
1 #include "curve25519.hpp"
2 #include <cstring>
3 #include <iostream>
4 #include <iomanip>
6 namespace
//Overwrite size bytes at ptr with zeros. The memset is followed by a volatile read
//of the buffer (same approach as OpenSSL/GnuTLS), intended to stop the compiler from
//discarding the clearing as a dead store.
void zeroize(void* ptr, size_t size)
{
	//Nothing to clear. Returning early also keeps the volatile read below from
	//touching byte 0 of an empty (possibly null) buffer.
	if(size == 0)
		return;
	volatile char* vptr = (volatile char*)ptr;
	volatile size_t vidx = 0;
	do { memset(ptr, 0, size); } while(vptr[vidx]);
}
16 #if defined(__x86_64__)
17 //Generic (slow).
18 struct element
20 typedef uint64_t smallval_t;
21 typedef uint64_t cond_t;
22 typedef uint64_t limb_t;
23 typedef __uint128_t wide_t;
24 const static int shift = 51;
25 const static limb_t mask = (1ULL << shift) - 1;
26 //a^(2^count) -> self
27 inline void square(const element& a, unsigned count = 1)
29 wide_t s[5];
30 memcpy(n, a.n, sizeof(n));
31 for(unsigned i = 0; i < count; i++) {
32 s[0] = (wide_t)n[0] * (wide_t)n[0] +
33 (wide_t)(n[1] << 1) * (wide_t)(n[4] * 19) +
34 (wide_t)(n[2] << 1) * (wide_t)(n[3] * 19);
35 s[1] = (wide_t)n[0] * (wide_t)(n[1] << 1) +
36 (wide_t)(n[2] << 1) * (wide_t)(n[4] * 19) +
37 (wide_t)n[3] * (wide_t)(n[3] * 19);
38 s[2] = ((wide_t)n[0] * (wide_t)(n[2] << 1) +
39 (wide_t)n[1] * (wide_t)n[1]) +
40 (wide_t)(n[3] << 1) * (wide_t)(n[4] * 19);
41 s[3] = ((wide_t)n[0] * (wide_t)n[3] << 1) +
42 ((wide_t)n[1] * (wide_t)n[2] << 1) +
43 ((wide_t)n[4] * (wide_t)(n[4] * 19));
44 s[4] = ((wide_t)n[0] * (wide_t)n[4] << 1) +
45 ((wide_t)n[1] * (wide_t)n[3] << 1) +
46 ((wide_t)n[2] * (wide_t)n[2]);
47 s[1] += (s[0] >> shift);
48 s[2] += (s[1] >> shift);
49 n[2] = (limb_t)s[2] & mask;
50 s[3] += (s[2] >> shift);
51 n[3] = (limb_t)s[3] & mask;
52 s[4] += (s[3] >> shift);
53 n[4] = (limb_t)s[4] & mask;
54 s[0] = ((limb_t)s[0] & mask) + 19 * (limb_t)(s[4] >> shift);
55 n[0] = (limb_t)s[0] & mask;
56 s[1] = ((limb_t)s[1] & mask) + (limb_t)(s[0] >> shift);
57 n[1] = (limb_t)s[1];
59 zeroize(s, sizeof(s));
61 //a * b -> self
62 inline void multiply(const element& a, const element& b)
64 wide_t s[5];
65 s[0] = (wide_t)a.n[0] * (wide_t)b.n[0] + (wide_t)(a.n[4] * 19) * (wide_t)b.n[1] +
66 (wide_t)(a.n[3] * 19) * (wide_t)b.n[2] + (wide_t)a.n[2] * (wide_t)(b.n[3] * 19) +
67 (wide_t)a.n[1] * (wide_t)(b.n[4] * 19);
68 s[1] = (wide_t)a.n[0] * (wide_t)b.n[1] + (wide_t)a.n[1] * (wide_t)b.n[0] +
69 (wide_t)(a.n[4] * 19) * (wide_t)b.n[2] + (wide_t)(a.n[3] * 19) * (wide_t)b.n[3] +
70 (wide_t)a.n[2] * (wide_t)(b.n[4] * 19);
71 s[2] = (wide_t)a.n[0] * (wide_t)b.n[2] + (wide_t)a.n[1] * (wide_t)b.n[1] +
72 (wide_t)a.n[2] * (wide_t)b.n[0] + (wide_t)(a.n[4] * 19) * (wide_t)b.n[3] +
73 (wide_t)(a.n[3] * 19) * (wide_t)b.n[4];
74 s[3] = (wide_t)a.n[0] * (wide_t)b.n[3] + (wide_t)a.n[1] * (wide_t)b.n[2] +
75 (wide_t)a.n[2] * (wide_t)b.n[1] + (wide_t)a.n[3] * (wide_t)b.n[0] +
76 (wide_t)(a.n[4] * 19) * (wide_t)b.n[4];
77 s[4] = (wide_t)a.n[0] * (wide_t)b.n[4] + (wide_t)a.n[1] * (wide_t)b.n[3] +
78 (wide_t)a.n[2] * (wide_t)b.n[2] + (wide_t)a.n[3] * (wide_t)b.n[1] +
79 (wide_t)a.n[4] * (wide_t)b.n[0];
80 s[1] += (s[0] >> shift);
81 s[2] += (s[1] >> shift);
82 n[2] = (limb_t)s[2] & mask;
83 s[3] += (s[2] >> shift);
84 n[3] = (limb_t)s[3] & mask;
85 s[4] += (s[3] >> shift);
86 n[4] = (limb_t)s[4] & mask;
87 s[0] = ((limb_t)s[0] & mask) + 19 * (limb_t)(s[4] >> shift);
88 n[0] = (limb_t)s[0] & mask;
89 s[1] = ((limb_t)s[1] & mask) + (limb_t)(s[0] >> shift);
90 n[1] = (limb_t)s[1];
91 zeroize(s, sizeof(s));
93 //e - self -> self
94 inline void diff_back(const element& e)
96 limb_t C1 = 2 * mask - 2 * (19 - 1);
97 limb_t C2 = 2 * mask;
98 n[0] = e.n[0] + C1 - n[0];
99 for(unsigned i = 1; i < 5; i++)
100 n[i] = e.n[i] + C2 - n[i];
101 limb_t carry = 0;
102 for(unsigned i = 0; i < 5; i++) {
103 n[i] += carry;
104 carry = n[i] >> shift;
105 n[i] &= mask;
107 carry *= 19;
108 n[0] += carry;
110 //a * b -> self (with constant b).
111 inline void multiply(const element& a, smallval_t b)
113 limb_t carry = 0;
114 for(unsigned i = 0; i < 5; i++) {
115 wide_t x = (wide_t)a.n[i] * b + carry;
116 n[i] = x & mask;
117 carry = x >> shift;
119 carry *= 19;
120 n[0] += carry;
122 //Reduce mod 2^255-19 and store to buffer.
123 inline void store(uint8_t* buffer)
125 limb_t carry = 19;
126 for(int i = 0; i < 5; i++) {
127 n[i] = n[i] + carry;
128 carry = n[i] >> shift;
129 n[i] = n[i] & mask;
131 carry = 19 - carry * 19;
132 for(int i = 0; i < 5; i++) {
133 n[i] = n[i] - carry;
134 carry = (n[i] >> shift) & 1;
135 n[i] = n[i] & mask;
137 for(unsigned i = 0; i < 32; i++) {
138 buffer[i] = n[8 * i / shift] >> (8 * i % shift);
139 if(8 * i % shift > shift - 8 && i < 26)
140 buffer[i] |= n[8 * i / shift + 1] << (shift - 8 * i % shift);
143 //Load from buffer.
144 inline explicit element(const uint8_t* buffer)
146 memset(n, 0, sizeof(n));
147 for(unsigned i = 0; i < 32; i++) {
148 n[8 * i / shift] |= (limb_t)buffer[i] << (8 * i % shift);
149 n[8 * i / shift] &= mask;
150 if(8 * i % shift > shift - 8 && i < 26) {
151 n[8 * i / shift + 1] |= (limb_t)buffer[i] >> (shift - 8 * i % shift);
155 //Construct 0.
156 inline element()
158 memset(n, 0, sizeof(n));
160 //Construct small value.
161 inline element(smallval_t sval)
163 memset(n, 0, sizeof(n));
164 n[0] = sval;
166 //self + e -> self.
167 inline void sum(const element& e)
169 limb_t carry = 0;
170 for(int i = 0; i < 5; i++) {
171 n[i] = n[i] + e.n[i] + carry;
172 carry = n[i] >> shift;
173 n[i] = n[i] & mask;
175 n[0] += carry * 19;
177 //If condition=1, swap self,e.
178 inline void swap_cond(element& e, cond_t condition)
180 condition = -condition;
181 for(int i = 0; i < 5; i++) {
182 limb_t t = condition & (n[i] ^ e.n[i]);
183 n[i] ^= t;
184 e.n[i] ^= t;
187 inline ~element()
189 zeroize(n, sizeof(n));
191 void debug(const char* pfx) const
193 uint8_t buf[34];
194 std::cerr << pfx << ": ";
195 memset(buf, 0, 34);
196 for(unsigned i = 0; i < 5*64; i++) {
197 unsigned rbit = shift*(i>>6)+(i&63);
198 if((n[i>>6] >> (i&63)) & 1)
199 buf[rbit>>3]|=(1<<(rbit&7));
201 for(unsigned i = 33; i < 34; i--)
202 std::cerr << std::setw(2) << std::setfill('0') << std::hex << std::uppercase
203 << (int)buf[i];
204 std::cerr << std::endl;
206 private:
207 limb_t n[5];
209 #else
210 //Generic (slow).
211 struct element
213 typedef uint32_t smallval_t;
214 typedef uint32_t cond_t;
215 //a^(2^count) -> self
216 inline void square(const element& a, unsigned count = 1)
218 uint64_t x[20];
219 uint32_t t[10];
220 memcpy(t, a.n, sizeof(t));
221 for(unsigned c = 0; c < count; c++) {
222 memset(x, 0, sizeof(x));
223 for(unsigned i = 0; i < 10; i++) {
224 x[i + i] += (uint64_t)t[i] * t[i];
225 for(unsigned j = 0; j < i; j++)
226 x[i + j] += ((uint64_t)t[i] * t[j]) << 1;
229 for(unsigned i = 0; i < 20; i++) {
230 std::cerr << "2^" << std::hex << std::uppercase << 26*i << "*"
231 << std::hex << std::uppercase << x[i] << "+";
233 std::cerr << "0" << std::endl;
235 //Multiplication by 608 can overflow, so reduce these.
236 uint64_t carry2 = 0;
237 for(unsigned i = 0; i < 20; i++) {
238 x[i] += carry2;
239 carry2 = x[i] >> 26;
240 x[i] &= 0x3FFFFFF;
242 x[19] += (carry2 << 26);
244 for(unsigned i = 0; i < 20; i++) {
245 std::cerr << "2^" << std::hex << std::uppercase << 26*i << "*"
246 << std::hex << std::uppercase << x[i] << "+";
248 std::cerr << "0" << std::endl;
250 //Reduce and fold.
251 uint64_t carry = 0;
252 for(unsigned i = 0; i < 10; i++) {
253 x[i] = x[i] + x[10 + i] * 608 + carry;
254 carry = x[i] >> 26;
255 x[i] &= 0x3FFFFFF;
257 //Final reduction.
258 x[0] += carry * 608;
260 for(unsigned i = 0; i < 10; i++)
261 t[i] = x[i];
263 memcpy(n, t, sizeof(n));
264 zeroize(x, sizeof(x));
265 zeroize(t, sizeof(t));
267 //a * b -> self
268 inline void multiply(const element& a, const element& b)
270 uint64_t x[20];
271 memset(x, 0, sizeof(x));
272 for(unsigned i = 0; i < 10; i++)
273 for(unsigned j = 0; j < 10; j++)
274 x[i + j] += (uint64_t)a.n[i] * b.n[j];
276 //Multiplication by 608 can overflow, so reduce these.
277 uint64_t carry2 = 0;
278 for(unsigned i = 9; i < 20; i++) {
279 x[i] += carry2;
280 carry2 = x[i] >> 26;
281 x[i] &= 0x3FFFFFF;
283 x[19] += (carry2 << 26);
285 //Reduce and fold.
286 uint64_t carry = 0;
287 for(unsigned i = 0; i < 10; i++) {
288 x[i] = x[i] + x[10 + i] * 608 + carry;
289 carry = x[i] >> 26;
290 x[i] &= 0x3FFFFFF;
292 //Final reduction.
293 x[0] += carry * 608;
294 for(unsigned i = 0; i < 10; i++)
295 n[i] = x[i];
296 zeroize(x, sizeof(x));
298 //e - self -> self
299 inline void diff_back(const element& e)
301 uint32_t C1 = (1<<28)-2432;
302 uint32_t C2 = (1<<28)-4;
303 n[0] = e.n[0] + C1 - n[0];
304 for(unsigned i = 1; i < 10; i++)
305 n[i] = e.n[i] + C2 - n[i];
306 uint32_t carry = 0;
307 for(unsigned i = 0; i < 10; i++) {
308 n[i] += carry;
309 carry = n[i] >> 26;
310 n[i] &= 0x3FFFFFF;
312 n[9] |= (carry << 26);
314 //a * b -> self (with constant b).
315 inline void multiply(const element& a, smallval_t b)
317 uint64_t carry = 0;
318 for(unsigned i = 0; i < 10; i++) {
319 uint64_t x = (uint64_t)a.n[i] * b + carry;
320 n[i] = x & 0x3FFFFFF;
321 carry = x >> 26;
323 carry = ((carry << 5) | (n[9] >> 21)) * 19;
324 n[9] &= 0x1FFFFF;
325 n[0] += carry;
327 //Reduce mod 2^255-19 and store to buffer.
328 inline void store(uint8_t* buffer)
330 uint32_t carry = (n[9] >> 21) * 19 + 19;
331 n[9] &= 0x1FFFFF;
332 for(int i = 0; i < 10; i++) {
333 n[i] = n[i] + carry;
334 carry = n[i] >> 26;
335 n[i] = n[i] & 0x3FFFFFF;
337 carry = 19 - (n[9] >> 21) * 19;
338 for(int i = 0; i < 10; i++) {
339 n[i] = n[i] - carry;
340 carry = (n[i] >> 26) & 1;
341 n[i] = n[i] & 0x3FFFFFF;
343 n[9] &= 0x1FFFFF;
344 for(unsigned i = 0; i < 32; i++) {
345 buffer[i] = n[8 * i / 26] >> (8 * i % 26);
346 if(8 * i % 26 > 18)
347 buffer[i] |= n[8 * i / 26 + 1] << (26 - 8 * i % 26);
350 //Load from buffer.
351 inline explicit element(const uint8_t* buffer)
353 memset(n, 0, sizeof(n));
354 for(unsigned i = 0; i < 32; i++) {
355 n[8 * i / 26] |= (uint32_t)buffer[i] << (8 * i % 26);
356 n[8 * i / 26] &= 0x3FFFFFF;
357 if(8 * i % 26 > 18) {
358 n[8 * i / 26 + 1] |= (uint32_t)buffer[i] >> (26 - 8 * i % 26);
362 //Construct 0.
363 inline element()
365 memset(n, 0, sizeof(n));
367 //Construct small value.
368 inline element(smallval_t sval)
370 memset(n, 0, sizeof(n));
371 n[0] = sval;
373 //self + e -> self.
374 inline void sum(const element& e)
376 uint32_t carry = 0;
377 for(int i = 0; i < 10; i++) {
378 n[i] = n[i] + e.n[i] + carry;
379 carry = n[i] >> 26;
380 n[i] = n[i] & 0x3FFFFFF;
382 n[0] += carry * 608;
384 //If condition=1, swap self,e.
385 inline void swap_cond(element& e, cond_t condition)
387 condition = -condition;
388 for(int i = 0; i < 10; i++) {
389 uint32_t t = condition & (n[i] ^ e.n[i]);
390 n[i] ^= t;
391 e.n[i] ^= t;
394 inline ~element()
396 zeroize(n, sizeof(n));
398 void debug(const char* pfx) const
400 uint8_t buf[34];
401 std::cerr << pfx << ": ";
402 memset(buf, 0, 34);
403 for(unsigned i = 0; i < 10*32; i++) {
404 unsigned rbit = 26*(i>>5)+(i&31);
405 if((n[i>>5] >> (i&31)) & 1)
406 buf[rbit>>3]|=(1<<(rbit&7));
408 for(unsigned i = 33; i < 34; i--)
409 std::cerr << std::setw(2) << std::setfill('0') << std::hex << std::uppercase
410 << (int)buf[i];
411 std::cerr << std::endl;
413 private:
414 uint32_t n[10];
416 #endif
419 static void montgomery(element& dblx, element& dblz, element& sumx, element& sumz,
420 element& ax, element& az, element& bx, element& bz, const element& diff)
422 element tmp;
423 element oax = ax;
424 ax.sum(az);
425 az.diff_back(oax);
426 element obx = bx;
427 bx.sum(bz);
428 bz.diff_back(obx);
429 oax.multiply(az, bx);
430 obx.multiply(ax, bz);
431 bx.square(ax);
432 bz.square(az);
433 dblx.multiply(bx, bz);
434 bz.diff_back(bx);
435 tmp.multiply(bz, 121665);
436 bx.sum(tmp);
437 dblz.multiply(bx, bz);
438 bx = oax;
439 oax.sum(obx);
440 obx.diff_back(bx);
441 sumx.square(oax);
442 bz.square(obx);
443 sumz.multiply(bz, diff);
446 static void cmultiply(element& ox, element& oz, const uint8_t* key, const element& base)
448 element x1a(1), z1a, x2a(base), z2a(1), x1b, z1b(1), x2b, z2b(1);
450 element::cond_t lbit = 0;
451 for(unsigned i = 31; i < 32; i--) {
452 uint8_t x = key[i];
453 for(unsigned j = 0; j < 4; j++) {
454 element::cond_t bit = (x >> 7);
455 x1a.swap_cond(x2a, bit ^ lbit);
456 z1a.swap_cond(z2a, bit ^ lbit);
457 montgomery(x1b, z1b, x2b, z2b, x1a, z1a, x2a, z2a, base);
458 lbit = bit;
459 x <<= 1;
460 bit = (x >> 7);
461 x1b.swap_cond(x2b, bit ^ lbit);
462 z1b.swap_cond(z2b, bit ^ lbit);
463 montgomery(x1a, z1a, x2a, z2a, x1b, z1b, x2b, z2b, base);
464 x <<= 1;
465 lbit = bit;
468 x1a.swap_cond(x2a, lbit);
469 z1a.swap_cond(z2a, lbit);
470 ox = x1a;
471 oz = z1a;
474 static void invert(element& out, const element& in)
476 element r, y, g, b, c;
477 y.square(in);
478 g.square(y, 2);
479 b.multiply(g, in);
480 r.multiply(b, y);
481 y.square(r);
482 g.multiply(y,b);
483 y.square(g, 5);
484 b.multiply(y, g);
485 y.square(b, 10);
486 g.multiply(y, b);
487 y.square(g, 20);
488 c.multiply(y, g);
489 g.square(c, 10);
490 y.multiply(g, b);
491 b.square(y, 50);
492 g.multiply(y, b);
493 b.square(g, 100);
494 c.multiply(g, b);
495 g.square(c, 50);
496 b.multiply(y, g);
497 y.square(b, 5);
498 out.multiply(r, y);
501 void curve25519(uint8_t* _out, const uint8_t* key, const uint8_t* _base)
503 element base(_base), outx, outz, zinv;
504 cmultiply(outx, outz, key, base);
505 invert(zinv, outz);
506 outz.multiply(outx, zinv);
507 outz.store(_out);
//Clamp a 32-byte little-endian scalar in place: clear its three lowest bits (making it
//a multiple of 8), clear bit 255 and set bit 254.
void curve25519_clamp(uint8_t* key)
{
	key[0] = static_cast<uint8_t>(key[0] & 0xF8);
	key[31] = static_cast<uint8_t>((key[31] & 0x7F) | 0x40);
}
//Base point encoding: the value 9 as 32 little-endian bytes (all other bytes zero).
const uint8_t curve25519_base[32] = {9, 0};
#ifdef CURVE25519_VALGRIND_TEST
//Valgrind harness. The scratch buffer is intentionally left uninitialized, presumably
//so memcheck reports any branch or memory index that depends on key/point bytes.
int main()
{
	uint8_t scratch[128];
	uint8_t* key = scratch;
	uint8_t* point = scratch + 32;
	uint8_t* result = scratch + 64;
	curve25519(result, key, point);
	//Consume the output so the computation is observable.
	std::cerr << result[0] << std::endl;
	return 0;
}
#endif
#ifdef CURVE25519_TEST_MODE
#include <cmath>
#include <cstdio>
#include <cstdlib>
//Read the x86 time-stamp counter (rdtsc) for cycle timing.
uint64_t arch_get_tsc()
{
	uint32_t a, b;
	asm volatile("rdtsc" : "=a"(a), "=d"(b));
	return ((uint64_t)b << 32) | a;
}

//For comparison: reference implementation, which must be linked in for this mode.
extern "C"
{
int curve25519_donna(uint8_t *mypublic, const uint8_t *secret, const uint8_t *basepoint);
}

//Fill key with 32 random bytes from fd and clamp them; aborts on a short read
//instead of silently reusing stale bytes.
static void test_random_key(FILE* fd, uint8_t* key)
{
	if(fread(key, 1, 32, fd) != 32) {
		std::cerr << "Failed to read random key" << std::endl;
		abort();
	}
	key[0] &= 248;
	key[31] &= 127;
	key[31] |= 64;
}

//Print label followed by a 32-byte little-endian value as big-endian hex.
static void test_dump(const char* label, const uint8_t* v)
{
	std::cerr << label;
	for(unsigned i = 31; i < 32; i--)
		std::cerr << std::hex << std::uppercase << std::setw(2) << std::setfill('0')
			<< (int)v[i];
	std::cerr << std::endl;
}

//Benchmark curve25519() with one random key, then compare against curve25519_donna()
//on fresh random keys forever, aborting on the first mismatch.
int main()
{
	uint8_t buf[128] = {0};
	FILE* fd = fopen("/dev/urandom", "rb");
	if(!fd) {
		std::cerr << "Can't open /dev/urandom" << std::endl;
		return 1;
	}
	uint64_t ctr = 0;
	//Accumulators must start at zero (previously uninitialized).
	double tsum = 0;
	double tsqr = 0;
	uint64_t tmin = ~(uint64_t)0;
	const unsigned rounds = 32768;

	buf[32] = 9;
	test_random_key(fd, buf);

	for(unsigned i = 0; i < rounds; i++) {
		uint64_t t = arch_get_tsc();
		curve25519(buf+64, buf, buf+32);
		t = arch_get_tsc() - t;
		//Square in floating point: t * t in uint64_t can overflow.
		double dt = (double)t;
		tsum += dt;
		tsqr += dt * dt;
		if(t < tmin) tmin = t;
	}
	tsum /= rounds;
	tsqr /= rounds;
	double var = tsqr - tsum * tsum;
	//Rounding can push the variance slightly negative; avoid sqrt() of a negative.
	std::cerr << "Time: " << tsum << "+-" << (var > 0 ? std::sqrt(var) : 0.0) << " >=" << tmin
		<< std::endl;
	while(true) {
		test_random_key(fd, buf);
		curve25519(buf+64, buf, buf+32);
		curve25519_donna(buf+96, buf, buf+32);
		if(memcmp(buf+64,buf+96,32)) {
			std::cerr << "Fail test: " << std::endl;
			test_dump("key:\t", buf);
			test_dump("point:\t", buf+32);
			test_dump("res1:\t", buf+64);
			test_dump("res2:\t", buf+96);
			abort();
		}
		if(++ctr % 10000 == 0)
			std::cerr << "Passed " << ctr << " tests." << std::endl;
	}
}
#endif