Some tweaks to Lua docs
[lsnes.git] / src / library / curve25519.cpp
bloba8dd0ed41c0aaf8f66422afddf52d00c195264f0
#include "curve25519.hpp"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iomanip>
#include <iostream>
6 namespace
//Generic (slow) element of GF(p), p = 2^255 - 19.
//
//Representation: ten little-endian limbs of nominally 26 bits each (radix 2^26, 260 bits).
//Limbs may be left slightly above 26 bits between operations (lazy reduction). Anything that
//overflows past 2^260 is folded back into the low limbs using
//2^260 = 32 * 2^255 = 32 * 19 = 608 (mod p). Only store() produces the canonical value.
struct element
{
	typedef uint32_t smallval_t;
	typedef uint32_t cond_t;
	//a^(2^count) -> self, i.e. count successive squarings of a (count = 0 just copies a).
	//Safe when a aliases self: a is copied into a scratch buffer before anything is written.
	inline void square(const element& a, unsigned count = 1)
	{
		uint64_t x[20];
		uint32_t t[10];
		memcpy(t, a.n, sizeof(t));
		for(unsigned c = 0; c < count; c++) {
			memset(x, 0, sizeof(x));
			//Schoolbook squaring: each diagonal product once, each off-diagonal product doubled.
			for(unsigned i = 0; i < 10; i++) {
				x[i + i] += (uint64_t)t[i] * t[i];
				for(unsigned j = 0; j < i; j++)
					x[i + j] += ((uint64_t)t[i] * t[j]) << 1;
			}
			//Multiplication by 608 can overflow, so reduce these.
			uint64_t carry2 = 0;
			for(unsigned i = 0; i < 20; i++) {
				x[i] += carry2;
				carry2 = x[i] >> 26;
				x[i] &= 0x3FFFFFF;
			}
			//Bits carried out of the top limb are kept in limb 19 rather than dropped.
			x[19] += (carry2 << 26);
			//Reduce and fold: limb 10+i has weight 2^260 * 2^(26i) = 608 * 2^(26i) (mod p).
			uint64_t carry = 0;
			for(unsigned i = 0; i < 10; i++) {
				x[i] = x[i] + x[10 + i] * 608 + carry;
				carry = x[i] >> 26;
				x[i] &= 0x3FFFFFF;
			}
			//Final reduction: the carry out of limb 9 again has weight 2^260 = 608.
			x[0] += carry * 608;
			for(unsigned i = 0; i < 10; i++)
				t[i] = x[i];
		}
		memcpy(n, t, sizeof(n));
	}
	//a * b -> self. Safe when self aliases a or b: limbs are written only at the very end.
	inline void multiply(const element& a, const element& b)
	{
		uint64_t x[20];
		memset(x, 0, sizeof(x));
		for(unsigned i = 0; i < 10; i++)
			for(unsigned j = 0; j < 10; j++)
				x[i + j] += (uint64_t)a.n[i] * b.n[j];
		//Multiplication by 608 can overflow, so reduce these (limb 9 onward feeds limbs 10-19).
		uint64_t carry2 = 0;
		for(unsigned i = 9; i < 20; i++) {
			x[i] += carry2;
			carry2 = x[i] >> 26;
			x[i] &= 0x3FFFFFF;
		}
		x[19] += (carry2 << 26);
		//Reduce and fold (see square()).
		uint64_t carry = 0;
		for(unsigned i = 0; i < 10; i++) {
			x[i] = x[i] + x[10 + i] * 608 + carry;
			carry = x[i] >> 26;
			x[i] &= 0x3FFFFFF;
		}
		//Final reduction.
		x[0] += carry * 608;
		for(unsigned i = 0; i < 10; i++)
			n[i] = x[i];
	}
	//e - self -> self.
	//To avoid unsigned underflow, 4*(2^260 - 608) is added first. That multiple is
	//= 0 (mod p) and is laid out as limb 0 = 4*(2^26 - 608), other limbs = 4*(2^26 - 1).
	//Assumes self's limbs are below roughly 2^28 (TODO confirm bound for all callers).
	//Bits carried out of limb 9 stay in n[9] above bit 26.
	inline void diff_back(const element& e)
	{
		uint32_t C1 = (1<<28)-2432;
		uint32_t C2 = (1<<28)-4;
		n[0] = e.n[0] + C1 - n[0];
		for(unsigned i = 1; i < 10; i++)
			n[i] = e.n[i] + C2 - n[i];
		uint32_t carry = 0;
		for(unsigned i = 0; i < 10; i++) {
			n[i] += carry;
			carry = n[i] >> 26;
			n[i] &= 0x3FFFFFF;
		}
		n[9] |= (carry << 26);
	}
	//a * b -> self (with constant b).
	//Everything at or above bit 255 (the carry out of limb 9 plus bits 21+ of limb 9)
	//is folded back into limb 0 times 19, since 2^255 = 19 (mod p).
	inline void multiply(const element& a, smallval_t b)
	{
		uint64_t carry = 0;
		for(unsigned i = 0; i < 10; i++) {
			uint64_t x = (uint64_t)a.n[i] * b + carry;
			n[i] = x & 0x3FFFFFF;
			carry = x >> 26;
		}
		carry = ((carry << 5) | (n[9] >> 21)) * 19;
		n[9] &= 0x1FFFFF;
		n[0] += carry;
	}
	//Reduce mod 2^255-19 and store to buffer (32 bytes, little-endian).
	//Modifies self: afterwards the limbs hold the canonical value.
	inline void store(uint8_t* buffer)
	{
		//Fold bits >= 255 (times 19) and add 19. If the result reaches 2^255, the value was
		//>= p and the answer is the result minus 2^255; otherwise subtract the 19 back off.
		uint32_t carry = (n[9] >> 21) * 19 + 19;
		n[9] &= 0x1FFFFF;
		for(int i = 0; i < 10; i++) {
			n[i] = n[i] + carry;
			carry = n[i] >> 26;
			n[i] = n[i] & 0x3FFFFFF;
		}
		carry = 19 - (n[9] >> 21) * 19;
		for(int i = 0; i < 10; i++) {
			n[i] = n[i] - carry;
			//A wrapped-around (negative) limb has bit 26 set: propagate it as a borrow.
			carry = (n[i] >> 26) & 1;
			n[i] = n[i] & 0x3FFFFFF;
		}
		n[9] &= 0x1FFFFF;
		//Byte i covers bits 8i..8i+7; it straddles two limbs when its offset within a limb
		//is above 18.
		for(unsigned i = 0; i < 32; i++) {
			buffer[i] = n[8 * i / 26] >> (8 * i % 26);
			if(8 * i % 26 > 18)
				buffer[i] |= n[8 * i / 26 + 1] << (26 - 8 * i % 26);
		}
	}
	//Load from buffer (32 bytes, little-endian).
	//The most significant bit (bit 7 of buffer[31]) is ignored, as RFC 7748 section 5
	//requires for X25519 u-coordinates. This matches curve25519-donna and NaCl.
	inline explicit element(const uint8_t* buffer)
	{
		memset(n, 0, sizeof(n));
		for(unsigned i = 0; i < 32; i++) {
			n[8 * i / 26] |= (uint32_t)buffer[i] << (8 * i % 26);
			n[8 * i / 26] &= 0x3FFFFFF;
			if(8 * i % 26 > 18) {
				n[8 * i / 26 + 1] |= (uint32_t)buffer[i] >> (26 - 8 * i % 26);
			}
		}
		//Limb 9 holds bits 234..259; keep only bits 234..254 so bit 255 is dropped.
		n[9] &= 0x1FFFFF;
	}
	//Construct 0.
	inline element()
	{
		memset(n, 0, sizeof(n));
	}
	//Construct small value (placed in limb 0 as-is).
	inline element(smallval_t sval)
	{
		memset(n, 0, sizeof(n));
		n[0] = sval;
	}
	//self + e -> self. The carry out of limb 9 (weight 2^260) is folded back as 608.
	inline void sum(const element& e)
	{
		uint32_t carry = 0;
		for(int i = 0; i < 10; i++) {
			n[i] = n[i] + e.n[i] + carry;
			carry = n[i] >> 26;
			n[i] = n[i] & 0x3FFFFFF;
		}
		n[0] += carry * 608;
	}
	//If condition=1, swap self,e. condition must be 0 or 1.
	//Uses a mask (0 or all-ones) instead of a branch on condition.
	inline void swap_cond(element& e, cond_t condition)
	{
		condition = -condition;
		for(int i = 0; i < 10; i++) {
			uint32_t t = condition & (n[i] ^ e.n[i]);
			n[i] ^= t;
			e.n[i] ^= t;
		}
	}
	//Dump to std::cerr as a big-endian hex string prefixed by pfx.
	//Limb bits above 26 are OR-ed (not added) into the next limb's position, so the dump is
	//exact only for normalized limbs. The formatting state of std::cerr is restored afterwards.
	void debug(const char* pfx) const
	{
		uint8_t buf[34];
		std::ios_base::fmtflags oflags = std::cerr.flags();
		char ofill = std::cerr.fill();
		std::cerr << pfx << ": ";
		memset(buf, 0, 34);
		for(unsigned i = 0; i < 10*32; i++) {
			unsigned rbit = 26*(i>>5)+(i&31);
			if((n[i>>5] >> (i&31)) & 1)
				buf[rbit>>3]|=(1<<(rbit&7));
		}
		for(unsigned i = 33; i < 34; i--)
			std::cerr << std::setw(2) << std::setfill('0') << std::hex << std::uppercase
				<< (int)buf[i];
		std::cerr << std::endl;
		std::cerr.flags(oflags);
		std::cerr.fill(ofill);
	}
private:
	uint32_t n[10];
};
209 static void montgomery(element& dblx, element& dblz, element& sumx, element& sumz,
210 element& ax, element& az, element& bx, element& bz, const element& diff)
212 element tmp;
213 element oax = ax;
214 ax.sum(az);
215 az.diff_back(oax);
216 element obx = bx;
217 bx.sum(bz);
218 bz.diff_back(obx);
219 oax.multiply(az, bx);
220 obx.multiply(ax, bz);
221 bx.square(ax);
222 bz.square(az);
223 dblx.multiply(bx, bz);
224 bz.diff_back(bx);
225 tmp.multiply(bz, 121665);
226 bx.sum(tmp);
227 dblz.multiply(bx, bz);
228 bx = oax;
229 oax.sum(obx);
230 obx.diff_back(bx);
231 sumx.square(oax);
232 bz.square(obx);
233 sumz.multiply(bz, diff);
236 static void cmultiply(element& ox, element& oz, const uint8_t* key, const element& base)
238 element x1a(1), z1a, x2a(base), z2a(1), x1b, z1b(1), x2b, z2b(1);
240 for(unsigned i = 31; i < 32; i--) {
241 uint8_t x = key[i];
242 for(unsigned j = 0; j < 4; j++) {
243 element::cond_t bit = (x >> 7);
244 x1a.swap_cond(x2a, bit);
245 z1a.swap_cond(z2a, bit);
246 montgomery(x1b, z1b, x2b, z2b, x1a, z1a, x2a, z2a, base);
247 x1b.swap_cond(x2b, bit);
248 z1b.swap_cond(z2b, bit);
249 x <<= 1;
250 bit = (x >> 7);
251 x1b.swap_cond(x2b, bit);
252 z1b.swap_cond(z2b, bit);
253 montgomery(x1a, z1a, x2a, z2a, x1b, z1b, x2b, z2b, base);
254 x1a.swap_cond(x2a, bit);
255 z1a.swap_cond(z2a, bit);
256 x <<= 1;
259 ox = x1a;
260 oz = z1a;
263 static void invert(element& out, const element& in)
265 element r, y, g, b, c;
266 y.square(in);
267 g.square(y, 2);
268 b.multiply(g, in);
269 r.multiply(b, y);
270 y.square(r);
271 g.multiply(y,b);
272 y.square(g, 5);
273 b.multiply(y, g);
274 y.square(b, 10);
275 g.multiply(y, b);
276 y.square(g, 20);
277 c.multiply(y, g);
278 g.square(c, 10);
279 y.multiply(g, b);
280 b.square(y, 50);
281 g.multiply(y, b);
282 b.square(g, 100);
283 c.multiply(g, b);
284 g.square(c, 50);
285 b.multiply(y, g);
286 y.square(b, 5);
287 out.multiply(r, y);
290 void curve25519(uint8_t* _out, const uint8_t* key, const uint8_t* _base)
292 element base(_base), outx, outz, zinv;
293 cmultiply(outx, outz, key, base);
294 invert(zinv, outz);
295 outz.multiply(outx, zinv);
296 outz.store(_out);
//Clamp a 32-byte little-endian secret key in place.
//Clears its three lowest bits, clears bit 255 and sets bit 254.
void curve25519_clamp(uint8_t* key)
{
	key[0] = static_cast<uint8_t>(key[0] & ~7u);
	key[31] = static_cast<uint8_t>((key[31] & 0x3F) | 0x40);
}
//Curve25519 base point u = 9, as a 32-byte little-endian string (all other bytes zero).
const uint8_t curve25519_base[32] = {
	9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
};
//Reference implementation (curve25519-donna), declared for comparison with curve25519().
extern "C" int curve25519_donna(uint8_t *mypublic, const uint8_t *secret, const uint8_t *basepoint);
315 int main()
317 uint8_t buf[128] = {0};
318 FILE* fd = fopen("/dev/urandom", "rb");
319 uint64_t ctr = 0;
320 buf[32] = 9;
321 while(true) {
322 fread(buf, 1, 32, fd);
323 buf[0] &= 248;
324 buf[31] &= 127;
325 buf[31] |= 64;
326 curve25519(buf+64, buf, buf+32);
327 curve25519_donna(buf+96, buf, buf+32);
328 if(memcmp(buf+64,buf+96,32)) {
329 std::cerr << "Fail test: " << std::endl;
330 std::cerr << "key:\t";
331 for(unsigned i = 31; i < 32; i--)
332 std::cerr << std::hex << std::uppercase << std::setw(2) << std::setfill('0')
333 << (int)buf[i];
334 std::cerr << std::endl;
335 std::cerr << "point:\t";
336 for(unsigned i = 31; i < 32; i--)
337 std::cerr << std::hex << std::uppercase << std::setw(2) << std::setfill('0')
338 << (int)buf[i+32];
339 std::cerr << std::endl;
340 std::cerr << "res1:\t";
341 for(unsigned i = 31; i < 32; i--)
342 std::cerr << std::hex << std::uppercase << std::setw(2) << std::setfill('0')
343 << (int)buf[i+64];
344 std::cerr << std::endl;
345 std::cerr << "res2:\t";
346 for(unsigned i = 31; i < 32; i--)
347 std::cerr << std::hex << std::uppercase << std::setw(2) << std::setfill('0')
348 << (int)buf[i+96];
349 std::cerr << std::endl;
350 abort();
352 if(++ctr % 10000 == 0)
353 std::cerr << "Passed " << ctr << " tests." << std::endl;