[HEIMDAL-646] malloc(0) checks for AIX
[heimdal.git] / lib / hcrypto / rsa-imath.c
bloba2b9d2a6787c0c7095e92c5da09f6cf3dc0ef82d
1 /*
2 * Copyright (c) 2006 - 2007 Kungliga Tekniska Högskolan
3 * (Royal Institute of Technology, Stockholm, Sweden).
4 * All rights reserved.
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
10 * 1. Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
13 * 2. Redistributions in binary form must reproduce the above copyright
14 * notice, this list of conditions and the following disclaimer in the
15 * documentation and/or other materials provided with the distribution.
17 * 3. Neither the name of the Institute nor the names of its contributors
18 * may be used to endorse or promote products derived from this software
19 * without specific prior written permission.
21 * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND
22 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24 * ARE DISCLAIMED. IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
25 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
27 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
28 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
29 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
30 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
31 * SUCH DAMAGE.
34 #include <config.h>
36 #include <stdio.h>
37 #include <stdlib.h>
38 #include <krb5-types.h>
39 #include <assert.h>
41 #include <rsa.h>
43 #include <roken.h>
45 #include "imath/imath.h"
46 #include "imath/iprime.h"
48 static void
49 BN2mpz(mpz_t *s, const BIGNUM *bn)
51 size_t len;
52 void *p;
54 mp_int_init(s);
56 len = BN_num_bytes(bn);
57 p = malloc(len);
58 BN_bn2bin(bn, p);
59 mp_int_read_unsigned(s, p, len);
60 free(p);
63 static BIGNUM *
64 mpz2BN(mpz_t *s)
66 size_t size;
67 BIGNUM *bn;
68 void *p;
70 size = mp_int_unsigned_len(s);
71 p = malloc(size);
72 if (p == NULL && size != 0)
73 return NULL;
74 mp_int_to_unsigned(s, p, size);
76 bn = BN_bin2bn(p, size, NULL);
77 free(p);
78 return bn;
81 static int random_num(mp_int, size_t);
83 static void
84 setup_blind(mp_int n, mp_int b, mp_int bi)
86 mp_int_init(b);
87 mp_int_init(bi);
88 random_num(b, mp_int_count_bits(n));
89 mp_int_mod(b, n, b);
90 mp_int_invmod(b, n, bi);
93 static void
94 blind(mp_int in, mp_int b, mp_int e, mp_int n)
96 mpz_t t1;
97 mp_int_init(&t1);
98 /* in' = (in * b^e) mod n */
99 mp_int_exptmod(b, e, n, &t1);
100 mp_int_mul(&t1, in, in);
101 mp_int_mod(in, n, in);
102 mp_int_clear(&t1);
105 static void
106 unblind(mp_int out, mp_int bi, mp_int n)
108 /* out' = (out * 1/b) mod n */
109 mp_int_mul(out, bi, out);
110 mp_int_mod(out, n, out);
113 static mp_result
114 rsa_private_calculate(mp_int in, mp_int p, mp_int q,
115 mp_int dmp1, mp_int dmq1, mp_int iqmp,
116 mp_int out)
118 mpz_t vp, vq, u;
119 mp_int_init(&vp); mp_int_init(&vq); mp_int_init(&u);
121 /* vq = c ^ (d mod (q - 1)) mod q */
122 /* vp = c ^ (d mod (p - 1)) mod p */
123 mp_int_mod(in, p, &u);
124 mp_int_exptmod(&u, dmp1, p, &vp);
125 mp_int_mod(in, q, &u);
126 mp_int_exptmod(&u, dmq1, q, &vq);
128 /* C2 = 1/q mod p (iqmp) */
129 /* u = (vp - vq)C2 mod p. */
130 mp_int_sub(&vp, &vq, &u);
131 if (mp_int_compare_zero(&u) < 0)
132 mp_int_add(&u, p, &u);
133 mp_int_mul(&u, iqmp, &u);
134 mp_int_mod(&u, p, &u);
136 /* c ^ d mod n = vq + u q */
137 mp_int_mul(&u, q, &u);
138 mp_int_add(&u, &vq, out);
140 mp_int_clear(&vp);
141 mp_int_clear(&vq);
142 mp_int_clear(&u);
144 return MP_OK;
151 static int
152 imath_rsa_public_encrypt(int flen, const unsigned char* from,
153 unsigned char* to, RSA* rsa, int padding)
155 unsigned char *p, *p0;
156 mp_result res;
157 size_t size, padlen;
158 mpz_t enc, dec, n, e;
160 if (padding != RSA_PKCS1_PADDING)
161 return -1;
163 size = RSA_size(rsa);
165 if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
166 return -2;
168 BN2mpz(&n, rsa->n);
169 BN2mpz(&e, rsa->e);
171 p = p0 = malloc(size - 1);
172 if (p0 == NULL) {
173 mp_int_clear(&e);
174 mp_int_clear(&n);
175 return -3;
178 padlen = size - flen - 3;
180 *p++ = 2;
181 if (RAND_bytes(p, padlen) != 1) {
182 mp_int_clear(&e);
183 mp_int_clear(&n);
184 free(p0);
185 return -4;
187 while(padlen) {
188 if (*p == 0)
189 *p = 1;
190 padlen--;
191 p++;
193 *p++ = 0;
194 memcpy(p, from, flen);
195 p += flen;
196 assert((p - p0) == size - 1);
198 mp_int_init(&enc);
199 mp_int_init(&dec);
200 mp_int_read_unsigned(&dec, p0, size - 1);
201 free(p0);
203 res = mp_int_exptmod(&dec, &e, &n, &enc);
205 mp_int_clear(&dec);
206 mp_int_clear(&e);
207 mp_int_clear(&n);
209 if (res != MP_OK)
210 return -4;
213 size_t ssize;
214 ssize = mp_int_unsigned_len(&enc);
215 assert(size >= ssize);
216 mp_int_to_unsigned(&enc, to, ssize);
217 size = ssize;
219 mp_int_clear(&enc);
221 return size;
224 static int
225 imath_rsa_public_decrypt(int flen, const unsigned char* from,
226 unsigned char* to, RSA* rsa, int padding)
228 unsigned char *p;
229 mp_result res;
230 size_t size;
231 mpz_t s, us, n, e;
233 if (padding != RSA_PKCS1_PADDING)
234 return -1;
236 if (flen > RSA_size(rsa))
237 return -2;
239 BN2mpz(&n, rsa->n);
240 BN2mpz(&e, rsa->e);
242 #if 0
243 /* Check that the exponent is larger then 3 */
244 if (mp_int_compare_value(&e, 3) <= 0) {
245 mp_int_clear(&n);
246 mp_int_clear(&e);
247 return -3;
249 #endif
251 mp_int_init(&s);
252 mp_int_init(&us);
253 mp_int_read_unsigned(&s, rk_UNCONST(from), flen);
255 if (mp_int_compare(&s, &n) >= 0) {
256 mp_int_clear(&n);
257 mp_int_clear(&e);
258 return -4;
261 res = mp_int_exptmod(&s, &e, &n, &us);
263 mp_int_clear(&s);
264 mp_int_clear(&n);
265 mp_int_clear(&e);
267 if (res != MP_OK)
268 return -5;
269 p = to;
272 size = mp_int_unsigned_len(&us);
273 assert(size <= RSA_size(rsa));
274 mp_int_to_unsigned(&us, p, size);
276 mp_int_clear(&us);
278 /* head zero was skipped by mp_int_to_unsigned */
279 if (*p == 0)
280 return -6;
281 if (*p != 1)
282 return -7;
283 size--; p++;
284 while (size && *p == 0xff) {
285 size--; p++;
287 if (size == 0 || *p != 0)
288 return -8;
289 size--; p++;
291 memmove(to, p, size);
293 return size;
296 static int
297 imath_rsa_private_encrypt(int flen, const unsigned char* from,
298 unsigned char* to, RSA* rsa, int padding)
300 unsigned char *p, *p0;
301 mp_result res;
302 int size;
303 mpz_t in, out, n, e, b, bi;
304 int blinding = (rsa->flags & RSA_FLAG_NO_BLINDING) == 0;
305 int do_unblind = 0;
307 if (padding != RSA_PKCS1_PADDING)
308 return -1;
310 size = RSA_size(rsa);
312 if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
313 return -2;
315 p0 = p = malloc(size);
316 *p++ = 0;
317 *p++ = 1;
318 memset(p, 0xff, size - flen - 3);
319 p += size - flen - 3;
320 *p++ = 0;
321 memcpy(p, from, flen);
322 p += flen;
323 assert((p - p0) == size);
325 BN2mpz(&n, rsa->n);
326 BN2mpz(&e, rsa->e);
328 mp_int_init(&in);
329 mp_int_init(&out);
330 mp_int_read_unsigned(&in, p0, size);
331 free(p0);
333 if(mp_int_compare_zero(&in) < 0 ||
334 mp_int_compare(&in, &n) >= 0) {
335 size = -3;
336 goto out;
339 if (blinding) {
340 setup_blind(&n, &b, &bi);
341 blind(&in, &b, &e, &n);
342 do_unblind = 1;
345 if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
346 mpz_t p, q, dmp1, dmq1, iqmp;
348 BN2mpz(&p, rsa->p);
349 BN2mpz(&q, rsa->q);
350 BN2mpz(&dmp1, rsa->dmp1);
351 BN2mpz(&dmq1, rsa->dmq1);
352 BN2mpz(&iqmp, rsa->iqmp);
354 res = rsa_private_calculate(&in, &p, &q, &dmp1, &dmq1, &iqmp, &out);
356 mp_int_clear(&p);
357 mp_int_clear(&q);
358 mp_int_clear(&dmp1);
359 mp_int_clear(&dmq1);
360 mp_int_clear(&iqmp);
362 if (res != MP_OK) {
363 size = -4;
364 goto out;
366 } else {
367 mpz_t d;
369 BN2mpz(&d, rsa->d);
370 res = mp_int_exptmod(&in, &d, &n, &out);
371 mp_int_clear(&d);
372 if (res != MP_OK) {
373 size = -5;
374 goto out;
378 if (do_unblind)
379 unblind(&out, &bi, &n);
381 if (size > 0) {
382 size_t ssize;
383 ssize = mp_int_unsigned_len(&out);
384 assert(size >= ssize);
385 mp_int_to_unsigned(&out, to, size);
386 size = ssize;
389 out:
390 if (do_unblind) {
391 mp_int_clear(&b);
392 mp_int_clear(&bi);
395 mp_int_clear(&e);
396 mp_int_clear(&n);
397 mp_int_clear(&in);
398 mp_int_clear(&out);
400 return size;
403 static int
404 imath_rsa_private_decrypt(int flen, const unsigned char* from,
405 unsigned char* to, RSA* rsa, int padding)
407 unsigned char *ptr;
408 mp_result res;
409 size_t size;
410 mpz_t in, out, n, e, b, bi;
411 int blinding = (rsa->flags & RSA_FLAG_NO_BLINDING) == 0;
412 int do_unblind = 0;
414 if (padding != RSA_PKCS1_PADDING)
415 return -1;
417 size = RSA_size(rsa);
418 if (flen > size)
419 return -2;
421 mp_int_init(&in);
422 mp_int_init(&out);
424 BN2mpz(&n, rsa->n);
425 BN2mpz(&e, rsa->e);
427 res = mp_int_read_unsigned(&in, rk_UNCONST(from), flen);
428 if (res != MP_OK) {
429 size = -1;
430 goto out;
433 if(mp_int_compare_zero(&in) < 0 ||
434 mp_int_compare(&in, &n) >= 0) {
435 size = -2;
436 goto out;
439 if (blinding) {
440 setup_blind(&n, &b, &bi);
441 blind(&in, &b, &e, &n);
442 do_unblind = 1;
445 if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
446 mpz_t p, q, dmp1, dmq1, iqmp;
448 BN2mpz(&p, rsa->p);
449 BN2mpz(&q, rsa->q);
450 BN2mpz(&dmp1, rsa->dmp1);
451 BN2mpz(&dmq1, rsa->dmq1);
452 BN2mpz(&iqmp, rsa->iqmp);
454 res = rsa_private_calculate(&in, &p, &q, &dmp1, &dmq1, &iqmp, &out);
456 mp_int_clear(&p);
457 mp_int_clear(&q);
458 mp_int_clear(&dmp1);
459 mp_int_clear(&dmq1);
460 mp_int_clear(&iqmp);
462 if (res != MP_OK) {
463 size = -3;
464 goto out;
467 } else {
468 mpz_t d;
470 if(mp_int_compare_zero(&in) < 0 ||
471 mp_int_compare(&in, &n) >= 0)
472 return MP_RANGE;
474 BN2mpz(&d, rsa->d);
475 res = mp_int_exptmod(&in, &d, &n, &out);
476 mp_int_clear(&d);
477 if (res != MP_OK) {
478 size = -4;
479 goto out;
483 if (do_unblind)
484 unblind(&out, &bi, &n);
486 ptr = to;
488 size_t ssize;
489 ssize = mp_int_unsigned_len(&out);
490 assert(size >= ssize);
491 mp_int_to_unsigned(&out, ptr, ssize);
492 size = ssize;
495 /* head zero was skipped by mp_int_to_unsigned */
496 if (*ptr != 2) {
497 size = -5;
498 goto out;
500 size--; ptr++;
501 while (size && *ptr != 0) {
502 size--; ptr++;
504 if (size == 0)
505 return -6;
506 size--; ptr++;
508 memmove(to, ptr, size);
510 out:
511 if (do_unblind) {
512 mp_int_clear(&b);
513 mp_int_clear(&bi);
516 mp_int_clear(&e);
517 mp_int_clear(&n);
518 mp_int_clear(&in);
519 mp_int_clear(&out);
521 return size;
524 static int
525 random_num(mp_int num, size_t len)
527 unsigned char *p;
528 mp_result res;
530 len = (len + 7) / 8;
531 p = malloc(len);
532 if (p == NULL)
533 return 1;
534 if (RAND_bytes(p, len) != 1) {
535 free(p);
536 return 1;
538 res = mp_int_read_unsigned(num, p, len);
539 free(p);
540 if (res != MP_OK)
541 return 1;
542 return 0;
/*
 * Jump to the enclosing function's `out' label unless (f) == (v).
 * Wrapped in do { } while (0) so that `CHECK(...);' is a single
 * statement and is safe inside an unbraced if/else.
 */
#define CHECK(f, v) do { if ((f) != (v)) goto out; } while (0)
547 static int
548 imath_rsa_generate_key(RSA *rsa, int bits, BIGNUM *e, BN_GENCB *cb)
550 mpz_t el, p, q, n, d, dmp1, dmq1, iqmp, t1, t2, t3;
551 int counter, ret;
553 if (bits < 789)
554 return -1;
556 ret = -1;
558 mp_int_init(&el);
559 mp_int_init(&p);
560 mp_int_init(&q);
561 mp_int_init(&n);
562 mp_int_init(&d);
563 mp_int_init(&dmp1);
564 mp_int_init(&dmq1);
565 mp_int_init(&iqmp);
566 mp_int_init(&t1);
567 mp_int_init(&t2);
568 mp_int_init(&t3);
570 BN2mpz(&el, e);
572 /* generate p and q so that p != q and bits(pq) ~ bits */
573 counter = 0;
574 do {
575 BN_GENCB_call(cb, 2, counter++);
576 CHECK(random_num(&p, bits / 2 + 1), 0);
577 CHECK(mp_int_find_prime(&p), MP_TRUE);
579 CHECK(mp_int_sub_value(&p, 1, &t1), MP_OK);
580 CHECK(mp_int_gcd(&t1, &el, &t2), MP_OK);
581 } while(mp_int_compare_value(&t2, 1) != 0);
583 BN_GENCB_call(cb, 3, 0);
585 counter = 0;
586 do {
587 BN_GENCB_call(cb, 2, counter++);
588 CHECK(random_num(&q, bits / 2 + 1), 0);
589 CHECK(mp_int_find_prime(&q), MP_TRUE);
591 if (mp_int_compare(&p, &q) == 0) /* don't let p and q be the same */
592 continue;
594 CHECK(mp_int_sub_value(&q, 1, &t1), MP_OK);
595 CHECK(mp_int_gcd(&t1, &el, &t2), MP_OK);
596 } while(mp_int_compare_value(&t2, 1) != 0);
598 /* make p > q */
599 if (mp_int_compare(&p, &q) < 0)
600 mp_int_swap(&p, &q);
602 BN_GENCB_call(cb, 3, 1);
604 /* calculate n, n = p * q */
605 CHECK(mp_int_mul(&p, &q, &n), MP_OK);
607 /* calculate d, d = 1/e mod (p - 1)(q - 1) */
608 CHECK(mp_int_sub_value(&p, 1, &t1), MP_OK);
609 CHECK(mp_int_sub_value(&q, 1, &t2), MP_OK);
610 CHECK(mp_int_mul(&t1, &t2, &t3), MP_OK);
611 CHECK(mp_int_invmod(&el, &t3, &d), MP_OK);
613 /* calculate dmp1 dmp1 = d mod (p-1) */
614 CHECK(mp_int_mod(&d, &t1, &dmp1), MP_OK);
615 /* calculate dmq1 dmq1 = d mod (q-1) */
616 CHECK(mp_int_mod(&d, &t2, &dmq1), MP_OK);
617 /* calculate iqmp iqmp = 1/q mod p */
618 CHECK(mp_int_invmod(&q, &p, &iqmp), MP_OK);
620 /* fill in RSA key */
622 rsa->e = mpz2BN(&el);
623 rsa->p = mpz2BN(&p);
624 rsa->q = mpz2BN(&q);
625 rsa->n = mpz2BN(&n);
626 rsa->d = mpz2BN(&d);
627 rsa->dmp1 = mpz2BN(&dmp1);
628 rsa->dmq1 = mpz2BN(&dmq1);
629 rsa->iqmp = mpz2BN(&iqmp);
631 ret = 1;
632 out:
633 mp_int_clear(&el);
634 mp_int_clear(&p);
635 mp_int_clear(&q);
636 mp_int_clear(&n);
637 mp_int_clear(&d);
638 mp_int_clear(&dmp1);
639 mp_int_clear(&dmq1);
640 mp_int_clear(&iqmp);
641 mp_int_clear(&t1);
642 mp_int_clear(&t2);
643 mp_int_clear(&t3);
645 return ret;
648 static int
649 imath_rsa_init(RSA *rsa)
651 return 1;
654 static int
655 imath_rsa_finish(RSA *rsa)
657 return 1;
/*
 * Method table for the imath-based RSA implementation.  Entries are
 * positional and must follow the member order of RSA_METHOD; the
 * NULL slots are hooks this backend leaves unset.
 */
const RSA_METHOD hc_rsa_imath_method = {
    "hcrypto imath RSA",
    imath_rsa_public_encrypt,
    imath_rsa_public_decrypt,
    imath_rsa_private_encrypt,
    imath_rsa_private_decrypt,
    NULL,
    NULL,
    imath_rsa_init,
    imath_rsa_finish,
    0,      /* NOTE(review): presumably the flags member -- confirm against RSA_METHOD in rsa.h */
    NULL,
    NULL,
    NULL,
    imath_rsa_generate_key
};
677 const RSA_METHOD *
678 RSA_imath_method(void)
680 return &hc_rsa_imath_method;