1 #include "curve25519.hpp"
/**
 * Overwrite a memory region with zero bytes so that secret material does not
 * linger in memory.
 *
 * @param ptr  Start of the region to wipe.
 * @param size Number of bytes to wipe. A size of 0 is a no-op and does not
 *             touch *ptr at all.
 *
 * The bytes are written through a pointer-to-volatile: every store is an
 * observable access, so the compiler may not discard the wipe as a dead store
 * even when the buffer is about to go out of scope.
 */
void zeroize(void* ptr, size_t size)
{
	//Empty region: nothing to wipe. The old memset + "while(vptr[0])" check read one
	//byte past an empty buffer and never terminated if that byte was non-zero.
	if(size == 0)
		return;
	//Do it like OpenSSL/GnuTLS: force the stores by going through volatile.
	volatile unsigned char* vptr = (volatile unsigned char*)ptr;
	for(size_t vidx = 0; vidx < size; vidx++)
		vptr[vidx] = 0;
}
16 #if defined(__x86_64__)
//Type of the small constant taken by multiply(const element&, smallval_t) and element(smallval_t).
20 typedef uint64_t smallval_t
;
//Type of the flag passed to swap_cond().
21 typedef uint64_t cond_t
;
//Type of one entry of n[].
22 typedef uint64_t limb_t
;
//Double-width type; limbs are cast to it before being multiplied in square()/multiply().
23 typedef __uint128_t wide_t
;
//Right-shift amount used to extract the carry out of a limb (carry = n[i] >> shift).
24 const static int shift
= 51;
//Low 'shift' bits set (2^51 - 1); ANDed with a value to keep the part that stays in the limb.
25 const static limb_t mask
= (1ULL << shift
) - 1;
27 inline void square(const element
& a
, unsigned count
= 1)
30 memcpy(n
, a
.n
, sizeof(n
));
31 for(unsigned i
= 0; i
< count
; i
++) {
32 s
[0] = (wide_t
)n
[0] * (wide_t
)n
[0] +
33 (wide_t
)(n
[1] << 1) * (wide_t
)(n
[4] * 19) +
34 (wide_t
)(n
[2] << 1) * (wide_t
)(n
[3] * 19);
35 s
[1] = (wide_t
)n
[0] * (wide_t
)(n
[1] << 1) +
36 (wide_t
)(n
[2] << 1) * (wide_t
)(n
[4] * 19) +
37 (wide_t
)n
[3] * (wide_t
)(n
[3] * 19);
38 s
[2] = ((wide_t
)n
[0] * (wide_t
)(n
[2] << 1) +
39 (wide_t
)n
[1] * (wide_t
)n
[1]) +
40 (wide_t
)(n
[3] << 1) * (wide_t
)(n
[4] * 19);
41 s
[3] = ((wide_t
)n
[0] * (wide_t
)n
[3] << 1) +
42 ((wide_t
)n
[1] * (wide_t
)n
[2] << 1) +
43 ((wide_t
)n
[4] * (wide_t
)(n
[4] * 19));
44 s
[4] = ((wide_t
)n
[0] * (wide_t
)n
[4] << 1) +
45 ((wide_t
)n
[1] * (wide_t
)n
[3] << 1) +
46 ((wide_t
)n
[2] * (wide_t
)n
[2]);
47 s
[1] += (s
[0] >> shift
);
48 s
[2] += (s
[1] >> shift
);
49 n
[2] = (limb_t
)s
[2] & mask
;
50 s
[3] += (s
[2] >> shift
);
51 n
[3] = (limb_t
)s
[3] & mask
;
52 s
[4] += (s
[3] >> shift
);
53 n
[4] = (limb_t
)s
[4] & mask
;
54 s
[0] = ((limb_t
)s
[0] & mask
) + 19 * (limb_t
)(s
[4] >> shift
);
55 n
[0] = (limb_t
)s
[0] & mask
;
56 s
[1] = ((limb_t
)s
[1] & mask
) + (limb_t
)(s
[0] >> shift
);
59 zeroize(s
, sizeof(s
));
62 inline void multiply(const element
& a
, const element
& b
)
65 s
[0] = (wide_t
)a
.n
[0] * (wide_t
)b
.n
[0] + (wide_t
)(a
.n
[4] * 19) * (wide_t
)b
.n
[1] +
66 (wide_t
)(a
.n
[3] * 19) * (wide_t
)b
.n
[2] + (wide_t
)a
.n
[2] * (wide_t
)(b
.n
[3] * 19) +
67 (wide_t
)a
.n
[1] * (wide_t
)(b
.n
[4] * 19);
68 s
[1] = (wide_t
)a
.n
[0] * (wide_t
)b
.n
[1] + (wide_t
)a
.n
[1] * (wide_t
)b
.n
[0] +
69 (wide_t
)(a
.n
[4] * 19) * (wide_t
)b
.n
[2] + (wide_t
)(a
.n
[3] * 19) * (wide_t
)b
.n
[3] +
70 (wide_t
)a
.n
[2] * (wide_t
)(b
.n
[4] * 19);
71 s
[2] = (wide_t
)a
.n
[0] * (wide_t
)b
.n
[2] + (wide_t
)a
.n
[1] * (wide_t
)b
.n
[1] +
72 (wide_t
)a
.n
[2] * (wide_t
)b
.n
[0] + (wide_t
)(a
.n
[4] * 19) * (wide_t
)b
.n
[3] +
73 (wide_t
)(a
.n
[3] * 19) * (wide_t
)b
.n
[4];
74 s
[3] = (wide_t
)a
.n
[0] * (wide_t
)b
.n
[3] + (wide_t
)a
.n
[1] * (wide_t
)b
.n
[2] +
75 (wide_t
)a
.n
[2] * (wide_t
)b
.n
[1] + (wide_t
)a
.n
[3] * (wide_t
)b
.n
[0] +
76 (wide_t
)(a
.n
[4] * 19) * (wide_t
)b
.n
[4];
77 s
[4] = (wide_t
)a
.n
[0] * (wide_t
)b
.n
[4] + (wide_t
)a
.n
[1] * (wide_t
)b
.n
[3] +
78 (wide_t
)a
.n
[2] * (wide_t
)b
.n
[2] + (wide_t
)a
.n
[3] * (wide_t
)b
.n
[1] +
79 (wide_t
)a
.n
[4] * (wide_t
)b
.n
[0];
80 s
[1] += (s
[0] >> shift
);
81 s
[2] += (s
[1] >> shift
);
82 n
[2] = (limb_t
)s
[2] & mask
;
83 s
[3] += (s
[2] >> shift
);
84 n
[3] = (limb_t
)s
[3] & mask
;
85 s
[4] += (s
[3] >> shift
);
86 n
[4] = (limb_t
)s
[4] & mask
;
87 s
[0] = ((limb_t
)s
[0] & mask
) + 19 * (limb_t
)(s
[4] >> shift
);
88 n
[0] = (limb_t
)s
[0] & mask
;
89 s
[1] = ((limb_t
)s
[1] & mask
) + (limb_t
)(s
[0] >> shift
);
91 zeroize(s
, sizeof(s
));
94 inline void diff_back(const element
& e
)
96 limb_t C1
= 2 * mask
- 2 * (19 - 1);
98 n
[0] = e
.n
[0] + C1
- n
[0];
99 for(unsigned i
= 1; i
< 5; i
++)
100 n
[i
] = e
.n
[i
] + C2
- n
[i
];
102 for(unsigned i
= 0; i
< 5; i
++) {
104 carry
= n
[i
] >> shift
;
110 //a * b -> self (with constant b).
111 inline void multiply(const element
& a
, smallval_t b
)
114 for(unsigned i
= 0; i
< 5; i
++) {
115 wide_t x
= (wide_t
)a
.n
[i
] * b
+ carry
;
122 //Reduce mod 2^255-19 and store to buffer.
123 inline void store(uint8_t* buffer
)
126 for(int i
= 0; i
< 5; i
++) {
128 carry
= n
[i
] >> shift
;
131 carry
= 19 - carry
* 19;
132 for(int i
= 0; i
< 5; i
++) {
134 carry
= (n
[i
] >> shift
) & 1;
137 for(unsigned i
= 0; i
< 32; i
++) {
138 buffer
[i
] = n
[8 * i
/ shift
] >> (8 * i
% shift
);
139 if(8 * i
% shift
> shift
- 8 && i
< 26)
140 buffer
[i
] |= n
[8 * i
/ shift
+ 1] << (shift
- 8 * i
% shift
);
144 inline explicit element(const uint8_t* buffer
)
146 memset(n
, 0, sizeof(n
));
147 for(unsigned i
= 0; i
< 32; i
++) {
148 n
[8 * i
/ shift
] |= (limb_t
)buffer
[i
] << (8 * i
% shift
);
149 n
[8 * i
/ shift
] &= mask
;
150 if(8 * i
% shift
> shift
- 8 && i
< 26) {
151 n
[8 * i
/ shift
+ 1] |= (limb_t
)buffer
[i
] >> (shift
- 8 * i
% shift
);
158 memset(n
, 0, sizeof(n
));
160 //Construct small value.
161 inline element(smallval_t sval
)
163 memset(n
, 0, sizeof(n
));
167 inline void sum(const element
& e
)
170 for(int i
= 0; i
< 5; i
++) {
171 n
[i
] = n
[i
] + e
.n
[i
] + carry
;
172 carry
= n
[i
] >> shift
;
177 //If condition=1, swap self,e.
//Parameters: e is the other element; condition is the swap flag (1 means swap, per the comment above this method).
178 inline void swap_cond(element
& e
, cond_t condition
)
//Unsigned negation: a condition of 1 becomes an all-ones mask, 0 stays 0.
180 condition
= -condition
;
//Visit each of the 5 limbs.
181 for(int i
= 0; i
< 5; i
++) {
//t = the bits in which the two limbs differ when the mask is all-ones, 0 when the mask is 0.
182 limb_t t
= condition
& (n
[i
] ^ e
.n
[i
]);
189 zeroize(n
, sizeof(n
));
191 void debug(const char* pfx
) const
194 std::cerr
<< pfx
<< ": ";
196 for(unsigned i
= 0; i
< 5*64; i
++) {
197 unsigned rbit
= shift
*(i
>>6)+(i
&63);
198 if((n
[i
>>6] >> (i
&63)) & 1)
199 buf
[rbit
>>3]|=(1<<(rbit
&7));
201 for(unsigned i
= 33; i
< 34; i
--)
202 std::cerr
<< std::setw(2) << std::setfill('0') << std::hex
<< std::uppercase
204 std::cerr
<< std::endl
;
213 typedef uint32_t smallval_t
;
214 typedef uint32_t cond_t
;
215 //a^(2^count) -> self
216 inline void square(const element
& a
, unsigned count
= 1)
220 memcpy(t
, a
.n
, sizeof(t
));
221 for(unsigned c
= 0; c
< count
; c
++) {
222 memset(x
, 0, sizeof(x
));
223 for(unsigned i
= 0; i
< 10; i
++) {
224 x
[i
+ i
] += (uint64_t)t
[i
] * t
[i
];
225 for(unsigned j
= 0; j
< i
; j
++)
226 x
[i
+ j
] += ((uint64_t)t
[i
] * t
[j
]) << 1;
229 for(unsigned i = 0; i < 20; i++) {
230 std::cerr << "2^" << std::hex << std::uppercase << 26*i << "*"
231 << std::hex << std::uppercase << x[i] << "+";
233 std::cerr << "0" << std::endl;
235 //Multiplication by 608 can overflow, so reduce these.
237 for(unsigned i
= 0; i
< 20; i
++) {
242 x
[19] += (carry2
<< 26);
244 for(unsigned i = 0; i < 20; i++) {
245 std::cerr << "2^" << std::hex << std::uppercase << 26*i << "*"
246 << std::hex << std::uppercase << x[i] << "+";
248 std::cerr << "0" << std::endl;
252 for(unsigned i
= 0; i
< 10; i
++) {
253 x
[i
] = x
[i
] + x
[10 + i
] * 608 + carry
;
260 for(unsigned i
= 0; i
< 10; i
++)
263 memcpy(n
, t
, sizeof(n
));
264 zeroize(x
, sizeof(x
));
265 zeroize(t
, sizeof(t
));
268 inline void multiply(const element
& a
, const element
& b
)
271 memset(x
, 0, sizeof(x
));
272 for(unsigned i
= 0; i
< 10; i
++)
273 for(unsigned j
= 0; j
< 10; j
++)
274 x
[i
+ j
] += (uint64_t)a
.n
[i
] * b
.n
[j
];
276 //Multiplication by 608 can overflow, so reduce these.
278 for(unsigned i
= 9; i
< 20; i
++) {
283 x
[19] += (carry2
<< 26);
287 for(unsigned i
= 0; i
< 10; i
++) {
288 x
[i
] = x
[i
] + x
[10 + i
] * 608 + carry
;
294 for(unsigned i
= 0; i
< 10; i
++)
296 zeroize(x
, sizeof(x
));
299 inline void diff_back(const element
& e
)
301 uint32_t C1
= (1<<28)-2432;
302 uint32_t C2
= (1<<28)-4;
303 n
[0] = e
.n
[0] + C1
- n
[0];
304 for(unsigned i
= 1; i
< 10; i
++)
305 n
[i
] = e
.n
[i
] + C2
- n
[i
];
307 for(unsigned i
= 0; i
< 10; i
++) {
312 n
[9] |= (carry
<< 26);
314 //a * b -> self (with constant b).
315 inline void multiply(const element
& a
, smallval_t b
)
318 for(unsigned i
= 0; i
< 10; i
++) {
319 uint64_t x
= (uint64_t)a
.n
[i
] * b
+ carry
;
320 n
[i
] = x
& 0x3FFFFFF;
323 carry
= ((carry
<< 5) | (n
[9] >> 21)) * 19;
327 //Reduce mod 2^255-19 and store to buffer.
328 inline void store(uint8_t* buffer
)
330 uint32_t carry
= (n
[9] >> 21) * 19 + 19;
332 for(int i
= 0; i
< 10; i
++) {
335 n
[i
] = n
[i
] & 0x3FFFFFF;
337 carry
= 19 - (n
[9] >> 21) * 19;
338 for(int i
= 0; i
< 10; i
++) {
340 carry
= (n
[i
] >> 26) & 1;
341 n
[i
] = n
[i
] & 0x3FFFFFF;
344 for(unsigned i
= 0; i
< 32; i
++) {
345 buffer
[i
] = n
[8 * i
/ 26] >> (8 * i
% 26);
347 buffer
[i
] |= n
[8 * i
/ 26 + 1] << (26 - 8 * i
% 26);
351 inline explicit element(const uint8_t* buffer
)
353 memset(n
, 0, sizeof(n
));
354 for(unsigned i
= 0; i
< 32; i
++) {
355 n
[8 * i
/ 26] |= (uint32_t)buffer
[i
] << (8 * i
% 26);
356 n
[8 * i
/ 26] &= 0x3FFFFFF;
357 if(8 * i
% 26 > 18) {
358 n
[8 * i
/ 26 + 1] |= (uint32_t)buffer
[i
] >> (26 - 8 * i
% 26);
365 memset(n
, 0, sizeof(n
));
367 //Construct small value.
368 inline element(smallval_t sval
)
370 memset(n
, 0, sizeof(n
));
374 inline void sum(const element
& e
)
377 for(int i
= 0; i
< 10; i
++) {
378 n
[i
] = n
[i
] + e
.n
[i
] + carry
;
380 n
[i
] = n
[i
] & 0x3FFFFFF;
384 //If condition=1, swap self,e.
385 inline void swap_cond(element
& e
, cond_t condition
)
387 condition
= -condition
;
388 for(int i
= 0; i
< 10; i
++) {
389 uint32_t t
= condition
& (n
[i
] ^ e
.n
[i
]);
396 zeroize(n
, sizeof(n
));
398 void debug(const char* pfx
) const
401 std::cerr
<< pfx
<< ": ";
403 for(unsigned i
= 0; i
< 10*32; i
++) {
404 unsigned rbit
= 26*(i
>>5)+(i
&31);
405 if((n
[i
>>5] >> (i
&31)) & 1)
406 buf
[rbit
>>3]|=(1<<(rbit
&7));
408 for(unsigned i
= 33; i
< 34; i
--)
409 std::cerr
<< std::setw(2) << std::setfill('0') << std::hex
<< std::uppercase
411 std::cerr
<< std::endl
;
419 static void montgomery(element
& dblx
, element
& dblz
, element
& sumx
, element
& sumz
,
420 element
& ax
, element
& az
, element
& bx
, element
& bz
, const element
& diff
)
429 oax
.multiply(az
, bx
);
430 obx
.multiply(ax
, bz
);
433 dblx
.multiply(bx
, bz
);
435 tmp
.multiply(bz
, 121665);
437 dblz
.multiply(bx
, bz
);
443 sumz
.multiply(bz
, diff
);
446 static void cmultiply(element
& ox
, element
& oz
, const uint8_t* key
, const element
& base
)
448 element
x1a(1), z1a
, x2a(base
), z2a(1), x1b
, z1b(1), x2b
, z2b(1);
450 element::cond_t lbit
= 0;
451 for(unsigned i
= 31; i
< 32; i
--) {
453 for(unsigned j
= 0; j
< 4; j
++) {
454 element::cond_t bit
= (x
>> 7);
455 x1a
.swap_cond(x2a
, bit
^ lbit
);
456 z1a
.swap_cond(z2a
, bit
^ lbit
);
457 montgomery(x1b
, z1b
, x2b
, z2b
, x1a
, z1a
, x2a
, z2a
, base
);
461 x1b
.swap_cond(x2b
, bit
^ lbit
);
462 z1b
.swap_cond(z2b
, bit
^ lbit
);
463 montgomery(x1a
, z1a
, x2a
, z2a
, x1b
, z1b
, x2b
, z2b
, base
);
468 x1a
.swap_cond(x2a
, lbit
);
469 z1a
.swap_cond(z2a
, lbit
);
474 static void invert(element
& out
, const element
& in
)
476 element r
, y
, g
, b
, c
;
//Entry point: takes an output buffer (_out), a key buffer (key) and a base buffer (_base).
501 void curve25519(uint8_t* _out
, const uint8_t* key
, const uint8_t* _base
)
//base is constructed from the bytes at _base; outx, outz and zinv are default-constructed.
503 element
base(_base
), outx
, outz
, zinv
;
//Hands key and base to cmultiply() together with outx/outz, which it takes by non-const reference.
504 cmultiply(outx
, outz
, key
, base
);
//NOTE(review): presumably stores outx * zinv into outz -- confirm against multiply(const element&, const element&).
506 outz
.multiply(outx
, zinv
);
510 void curve25519_clamp(uint8_t* key
)
517 const uint8_t curve25519_base
[32] = {9};
519 #ifdef CURVE25519_VALGRIND_TEST
524 curve25519(buf
+64, buf
, buf
+32);
525 std::cerr
<< buf
[64] << std::endl
;
531 #ifdef CURVE25519_TEST_MODE
//Returns a 64-bit counter value read with the x86 RDTSC instruction.
533 uint64_t arch_get_tsc()
//"=a" binds a to the A register (low half of the counter), "=d" binds b to the D register (high half).
536 asm volatile("rdtsc" : "=a"(a
), "=d"(b
));
//Combine the halves: b supplies bits 32-63, a supplies the low bits.
537 return ((uint64_t)b
<< 32) | a
;
543 int curve25519_donna(uint8_t *mypublic
, const uint8_t *secret
, const uint8_t *basepoint
);
548 uint8_t buf
[128] = {0};
549 FILE* fd
= fopen("/dev/urandom", "rb");
554 uint64_t tmin
= 999999999;
557 fread(buf
, 1, 32, fd
);
562 for(unsigned i
= 0; i
< 32768; i
++) {
564 curve25519(buf
+64, buf
, buf
+32);
565 _t
= arch_get_tsc() - _t
;
568 if(_t
< tmin
) tmin
= _t
;
572 std::cerr
<< "Time: " << tsum
<< "+-" << sqrt(tsqr
- tsum
* tsum
) << " >=" << tmin
<< std::endl
;
574 fread(buf
, 1, 32, fd
);
578 curve25519(buf
+64, buf
, buf
+32);
579 curve25519_donna(buf
+96, buf
, buf
+32);
580 if(memcmp(buf
+64,buf
+96,32)) {
581 std::cerr
<< "Fail test: " << std::endl
;
582 std::cerr
<< "key:\t";
583 for(unsigned i
= 31; i
< 32; i
--)
584 std::cerr
<< std::hex
<< std::uppercase
<< std::setw(2) << std::setfill('0')
586 std::cerr
<< std::endl
;
587 std::cerr
<< "point:\t";
588 for(unsigned i
= 31; i
< 32; i
--)
589 std::cerr
<< std::hex
<< std::uppercase
<< std::setw(2) << std::setfill('0')
591 std::cerr
<< std::endl
;
592 std::cerr
<< "res1:\t";
593 for(unsigned i
= 31; i
< 32; i
--)
594 std::cerr
<< std::hex
<< std::uppercase
<< std::setw(2) << std::setfill('0')
596 std::cerr
<< std::endl
;
597 std::cerr
<< "res2:\t";
598 for(unsigned i
= 31; i
< 32; i
--)
599 std::cerr
<< std::hex
<< std::uppercase
<< std::setw(2) << std::setfill('0')
601 std::cerr
<< std::endl
;
604 if(++ctr
% 10000 == 0)
605 std::cerr
<< "Passed " << ctr
<< " tests." << std::endl
;