wait for dead children, and then abandon the live ones
[heimdal.git] / lib / hcrypto / rsa-gmp.c
blobb3a994b8034907a3e3e6119837684b0c915d4302
1 /*
2 * Copyright (c) 2006 - 2007 Kungliga Tekniska Högskolan
3 * (Royal Institute of Technology, Stockholm, Sweden).
4 * All rights reserved.
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
10 * 1. Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
13 * 2. Redistributions in binary form must reproduce the above copyright
14 * notice, this list of conditions and the following disclaimer in the
15 * documentation and/or other materials provided with the distribution.
17 * 3. Neither the name of the Institute nor the names of its contributors
18 * may be used to endorse or promote products derived from this software
19 * without specific prior written permission.
21 * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND
22 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24 * ARE DISCLAIMED. IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
25 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
27 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
30 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
31 * SUCH DAMAGE.
34 #include <config.h>
36 #include <stdio.h>
37 #include <stdlib.h>
38 #include <krb5-types.h>
39 #include <assert.h>
41 #include <rsa.h>
43 #include <roken.h>
45 #ifdef HAVE_GMP
47 #include <gmp.h>
49 static void
50 BN2mpz(mpz_t s, const BIGNUM *bn)
52 size_t len;
53 void *p;
55 len = BN_num_bytes(bn);
56 p = malloc(len);
57 BN_bn2bin(bn, p);
58 mpz_init(s);
59 mpz_import(s, len, 1, 1, 1, 0, p);
61 free(p);
65 static BIGNUM *
66 mpz2BN(mpz_t s)
68 size_t size;
69 BIGNUM *bn;
70 void *p;
72 mpz_export(NULL, &size, 1, 1, 1, 0, s);
73 p = malloc(size);
74 if (p == NULL && size != 0)
75 return NULL;
76 mpz_export(p, &size, 1, 1, 1, 0, s);
77 bn = BN_bin2bn(p, size, NULL);
78 free(p);
79 return bn;
82 static int
83 rsa_private_calculate(mpz_t in, mpz_t p, mpz_t q,
84 mpz_t dmp1, mpz_t dmq1, mpz_t iqmp,
85 mpz_t out)
87 mpz_t vp, vq, u;
88 mpz_init(vp); mpz_init(vq); mpz_init(u);
90 /* vq = c ^ (d mod (q - 1)) mod q */
91 /* vp = c ^ (d mod (p - 1)) mod p */
92 mpz_fdiv_r(vp, in, p);
93 mpz_powm(vp, vp, dmp1, p);
94 mpz_fdiv_r(vq, in, q);
95 mpz_powm(vq, vq, dmq1, q);
97 /* C2 = 1/q mod p (iqmp) */
98 /* u = (vp - vq)C2 mod p. */
99 mpz_sub(u, vp, vq);
100 #if 0
101 if (mp_int_compare_zero(&u) < 0)
102 mp_int_add(&u, p, &u);
103 #endif
104 mpz_mul(u, iqmp, u);
105 mpz_fdiv_r(u, u, p);
107 /* c ^ d mod n = vq + u q */
108 mpz_mul(u, q, u);
109 mpz_add(out, u, vq);
111 mpz_clear(vp);
112 mpz_clear(vq);
113 mpz_clear(u);
115 return 0;
122 static int
123 gmp_rsa_public_encrypt(int flen, const unsigned char* from,
124 unsigned char* to, RSA* rsa, int padding)
126 unsigned char *p, *p0;
127 size_t size, padlen;
128 mpz_t enc, dec, n, e;
130 if (padding != RSA_PKCS1_PADDING)
131 return -1;
133 size = RSA_size(rsa);
135 if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
136 return -2;
138 BN2mpz(n, rsa->n);
139 BN2mpz(e, rsa->e);
141 p = p0 = malloc(size - 1);
142 if (p0 == NULL) {
143 mpz_clear(e);
144 mpz_clear(n);
145 return -3;
148 padlen = size - flen - 3;
149 assert(padlen >= 8);
151 *p++ = 2;
152 if (RAND_bytes(p, padlen) != 1) {
153 mpz_clear(e);
154 mpz_clear(n);
155 free(p0);
156 return -4;
158 while(padlen) {
159 if (*p == 0)
160 *p = 1;
161 padlen--;
162 p++;
164 *p++ = 0;
165 memcpy(p, from, flen);
166 p += flen;
167 assert((p - p0) == size - 1);
169 mpz_init(enc);
170 mpz_init(dec);
171 mpz_import(dec, size - 1, 1, 1, 1, 0, p0);
172 free(p0);
174 mpz_powm(enc, dec, e, n);
176 mpz_clear(dec);
177 mpz_clear(e);
178 mpz_clear(n);
180 size_t ssize;
181 mpz_export(to, &ssize, 1, 1, 1, 0, enc);
182 assert(size >= ssize);
183 size = ssize;
185 mpz_clear(enc);
187 return size;
190 static int
191 gmp_rsa_public_decrypt(int flen, const unsigned char* from,
192 unsigned char* to, RSA* rsa, int padding)
194 unsigned char *p;
195 size_t size;
196 mpz_t s, us, n, e;
198 if (padding != RSA_PKCS1_PADDING)
199 return -1;
201 if (flen > RSA_size(rsa))
202 return -2;
204 BN2mpz(n, rsa->n);
205 BN2mpz(e, rsa->e);
207 #if 0
208 /* Check that the exponent is larger then 3 */
209 if (mp_int_compare_value(&e, 3) <= 0) {
210 mp_int_clear(&n);
211 mp_int_clear(&e);
212 return -3;
214 #endif
216 mpz_init(s);
217 mpz_init(us);
218 mpz_import(s, flen, 1, 1, 1, 0, rk_UNCONST(from));
220 if (mpz_cmp(s, n) >= 0) {
221 mpz_clear(n);
222 mpz_clear(e);
223 return -4;
226 mpz_powm(us, s, e, n);
228 mpz_clear(s);
229 mpz_clear(n);
230 mpz_clear(e);
232 p = to;
234 mpz_export(p, &size, 1, 1, 1, 0, us);
235 assert(size <= RSA_size(rsa));
237 mpz_clear(us);
239 /* head zero was skipped by mp_int_to_unsigned */
240 if (*p == 0)
241 return -6;
242 if (*p != 1)
243 return -7;
244 size--; p++;
245 while (size && *p == 0xff) {
246 size--; p++;
248 if (size == 0 || *p != 0)
249 return -8;
250 size--; p++;
252 memmove(to, p, size);
254 return size;
257 static int
258 gmp_rsa_private_encrypt(int flen, const unsigned char* from,
259 unsigned char* to, RSA* rsa, int padding)
261 unsigned char *p, *p0;
262 size_t size;
263 mpz_t in, out, n, e;
265 if (padding != RSA_PKCS1_PADDING)
266 return -1;
268 size = RSA_size(rsa);
270 if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
271 return -2;
273 p0 = p = malloc(size);
274 *p++ = 0;
275 *p++ = 1;
276 memset(p, 0xff, size - flen - 3);
277 p += size - flen - 3;
278 *p++ = 0;
279 memcpy(p, from, flen);
280 p += flen;
281 assert((p - p0) == size);
283 BN2mpz(n, rsa->n);
284 BN2mpz(e, rsa->e);
286 mpz_init(in);
287 mpz_init(out);
288 mpz_import(in, size, 1, 1, 1, 0, p0);
289 free(p0);
291 #if 0
292 if(mp_int_compare_zero(&in) < 0 ||
293 mp_int_compare(&in, &n) >= 0) {
294 size = 0;
295 goto out;
297 #endif
299 if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
300 mpz_t p, q, dmp1, dmq1, iqmp;
302 BN2mpz(p, rsa->p);
303 BN2mpz(q, rsa->q);
304 BN2mpz(dmp1, rsa->dmp1);
305 BN2mpz(dmq1, rsa->dmq1);
306 BN2mpz(iqmp, rsa->iqmp);
308 rsa_private_calculate(in, p, q, dmp1, dmq1, iqmp, out);
310 mpz_clear(p);
311 mpz_clear(q);
312 mpz_clear(dmp1);
313 mpz_clear(dmq1);
314 mpz_clear(iqmp);
315 } else {
316 mpz_t d;
318 BN2mpz(d, rsa->d);
319 mpz_powm(out, in, d, n);
320 mpz_clear(d);
324 size_t ssize;
325 mpz_export(to, &ssize, 1, 1, 1, 0, out);
326 assert(size >= ssize);
327 size = ssize;
330 mpz_clear(e);
331 mpz_clear(n);
332 mpz_clear(in);
333 mpz_clear(out);
335 return size;
338 static int
339 gmp_rsa_private_decrypt(int flen, const unsigned char* from,
340 unsigned char* to, RSA* rsa, int padding)
342 unsigned char *ptr;
343 size_t size;
344 mpz_t in, out, n, e;
346 if (padding != RSA_PKCS1_PADDING)
347 return -1;
349 size = RSA_size(rsa);
350 if (flen > size)
351 return -2;
353 mpz_init(in);
354 mpz_init(out);
356 BN2mpz(n, rsa->n);
357 BN2mpz(e, rsa->e);
359 mpz_import(in, flen, 1, 1, 1, 0, from);
361 if(mpz_cmp_ui(in, 0) < 0 ||
362 mpz_cmp(in, n) >= 0) {
363 size = 0;
364 goto out;
367 if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
368 mpz_t p, q, dmp1, dmq1, iqmp;
370 BN2mpz(p, rsa->p);
371 BN2mpz(q, rsa->q);
372 BN2mpz(dmp1, rsa->dmp1);
373 BN2mpz(dmq1, rsa->dmq1);
374 BN2mpz(iqmp, rsa->iqmp);
376 rsa_private_calculate(in, p, q, dmp1, dmq1, iqmp, out);
378 mpz_clear(p);
379 mpz_clear(q);
380 mpz_clear(dmp1);
381 mpz_clear(dmq1);
382 mpz_clear(iqmp);
383 } else {
384 mpz_t d;
386 #if 0
387 if(mp_int_compare_zero(&in) < 0 ||
388 mp_int_compare(&in, &n) >= 0)
389 return MP_RANGE;
390 #endif
392 BN2mpz(d, rsa->d);
393 mpz_powm(out, in, d, n);
394 mpz_clear(d);
397 ptr = to;
399 size_t ssize;
400 mpz_export(ptr, &ssize, 1, 1, 1, 0, out);
401 assert(size >= ssize);
402 size = ssize;
405 /* head zero was skipped by mp_int_to_unsigned */
406 if (*ptr != 2)
407 return -3;
408 size--; ptr++;
409 while (size && *ptr != 0) {
410 size--; ptr++;
412 if (size == 0)
413 return -4;
414 size--; ptr++;
416 memmove(to, ptr, size);
418 out:
419 mpz_clear(e);
420 mpz_clear(n);
421 mpz_clear(in);
422 mpz_clear(out);
424 return size;
427 static int
428 random_num(mpz_t num, size_t len)
430 unsigned char *p;
432 len = (len + 7) / 8;
433 p = malloc(len);
434 if (p == NULL)
435 return 1;
436 if (RAND_bytes(p, len) != 1) {
437 free(p);
438 return 1;
440 mpz_import(num, len, 1, 1, 1, 0, p);
441 free(p);
442 return 0;
446 static int
447 gmp_rsa_generate_key(RSA *rsa, int bits, BIGNUM *e, BN_GENCB *cb)
449 mpz_t el, p, q, n, d, dmp1, dmq1, iqmp, t1, t2, t3;
450 int counter, ret;
452 if (bits < 789)
453 return -1;
455 ret = -1;
457 mpz_init(el);
458 mpz_init(p);
459 mpz_init(q);
460 mpz_init(n);
461 mpz_init(d);
462 mpz_init(dmp1);
463 mpz_init(dmq1);
464 mpz_init(iqmp);
465 mpz_init(t1);
466 mpz_init(t2);
467 mpz_init(t3);
469 BN2mpz(el, e);
471 /* generate p and q so that p != q and bits(pq) ~ bits */
473 counter = 0;
474 do {
475 BN_GENCB_call(cb, 2, counter++);
476 random_num(p, bits / 2 + 1);
477 mpz_nextprime(p, p);
479 mpz_sub_ui(t1, p, 1);
480 mpz_gcd(t2, t1, el);
481 } while(mpz_cmp_ui(t2, 1) != 0);
483 BN_GENCB_call(cb, 3, 0);
485 counter = 0;
486 do {
487 BN_GENCB_call(cb, 2, counter++);
488 random_num(q, bits / 2 + 1);
489 mpz_nextprime(q, q);
491 mpz_sub_ui(t1, q, 1);
492 mpz_gcd(t2, t1, el);
493 } while(mpz_cmp_ui(t2, 1) != 0);
495 /* make p > q */
496 if (mpz_cmp(p, q) < 0)
497 mpz_swap(p, q);
499 BN_GENCB_call(cb, 3, 1);
501 /* calculate n, n = p * q */
502 mpz_mul(n, p, q);
504 /* calculate d, d = 1/e mod (p - 1)(q - 1) */
505 mpz_sub_ui(t1, p, 1);
506 mpz_sub_ui(t2, q, 1);
507 mpz_mul(t3, t1, t2);
508 mpz_invert(d, el, t3);
510 /* calculate dmp1 dmp1 = d mod (p-1) */
511 mpz_mod(dmp1, d, t1);
512 /* calculate dmq1 dmq1 = d mod (q-1) */
513 mpz_mod(dmq1, d, t2);
514 /* calculate iqmp iqmp = 1/q mod p */
515 mpz_invert(iqmp, q, p);
517 /* fill in RSA key */
519 rsa->e = mpz2BN(el);
520 rsa->p = mpz2BN(p);
521 rsa->q = mpz2BN(q);
522 rsa->n = mpz2BN(n);
523 rsa->d = mpz2BN(d);
524 rsa->dmp1 = mpz2BN(dmp1);
525 rsa->dmq1 = mpz2BN(dmq1);
526 rsa->iqmp = mpz2BN(iqmp);
528 ret = 1;
530 mpz_clear(el);
531 mpz_clear(p);
532 mpz_clear(q);
533 mpz_clear(n);
534 mpz_clear(d);
535 mpz_clear(dmp1);
536 mpz_clear(dmq1);
537 mpz_clear(iqmp);
538 mpz_clear(t1);
539 mpz_clear(t2);
540 mpz_clear(t3);
542 return ret;
545 static int
546 gmp_rsa_init(RSA *rsa)
548 return 1;
551 static int
552 gmp_rsa_finish(RSA *rsa)
554 return 1;
557 const RSA_METHOD hc_rsa_gmp_method = {
558 "hcrypto GMP RSA",
559 gmp_rsa_public_encrypt,
560 gmp_rsa_public_decrypt,
561 gmp_rsa_private_encrypt,
562 gmp_rsa_private_decrypt,
563 NULL,
564 NULL,
565 gmp_rsa_init,
566 gmp_rsa_finish,
568 NULL,
569 NULL,
570 NULL,
571 gmp_rsa_generate_key
574 #endif /* HAVE_GMP */
577 * RSA implementation using Gnu Multipresistion Library.
580 const RSA_METHOD *
581 RSA_gmp_method(void)
583 #ifdef HAVE_GMP
584 return &hc_rsa_gmp_method;
585 #else
586 return NULL;
587 #endif