lib/hcrypto: mpz2BN return NULL if mp_ubin_size(s) returns zero
[heimdal.git] / lib / hcrypto / rsa-ltm.c
blob1d5b73e60e5a99d02311309e6e9bf07c614349d5
/*
 * Copyright (c) 2006 - 2007, 2010 Kungliga Tekniska Högskolan
 * (Royal Institute of Technology, Stockholm, Sweden).
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the Institute nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */
34 #include <config.h>
35 #include <roken.h>
36 #include <krb5-types.h>
37 #include <assert.h>
39 #include <rsa.h>
41 #include "tommath.h"
/*
 * Error-propagation helpers for chains of libtommath calls.
 *
 * They expect two locals in the calling scope:
 *   ret   - running status; once it is not MP_OKAY the THEN_* helpers
 *           skip their expression/statement.
 *   where - set to a source line number (the invocation line, or that
 *           line plus one) so callers can report the failing step.
 */

/* Evaluate `expr` while still OK; on failure jump to the caller's `out:`. */
#define CHECK(expr)                                             \
    do {                                                        \
        if (ret == MP_OKAY && ((ret = (expr))) != MP_OKAY)      \
            goto out;                                           \
    } while (0)

/* Unconditionally start a chain with the status of `expr`. */
#define FIRST(expr)                                             \
    do {                                                        \
        ret = (expr);                                           \
    } while (0)

/* Start a chain from an allocation: non-NULL -> MP_OKAY, NULL -> MP_MEM. */
#define FIRST_ALLOC(expr)                                       \
    do {                                                        \
        where = __LINE__;                                       \
        ret = ((expr)) ? MP_OKAY : MP_MEM;                      \
    } while (0)

/* Evaluate the mp_err expression `expr` only while the chain is OK. */
#define THEN_MP(expr)                                           \
    do {                                                        \
        where = __LINE__ + 1;                                   \
        if (ret == MP_OKAY)                                     \
            ret = (expr);                                       \
    } while (0)

/* Like THEN_MP, but only when `pred` also holds. */
#define THEN_IF_MP(pred, expr)                                  \
    do {                                                        \
        where = __LINE__ + 1;                                   \
        if (ret == MP_OKAY && (pred))                           \
            ret = (expr);                                       \
    } while (0)

/* Run statement `stmt` (no status) while OK and `pred` holds. */
#define THEN_IF_VOID(pred, stmt)                                \
    do {                                                        \
        if (ret == MP_OKAY && (pred))                           \
            stmt;                                               \
    } while (0)

/* Run statement `stmt` (no status) while the chain is OK. */
#define THEN_VOID(stmt)                                         \
    do {                                                        \
        if (ret == MP_OKAY)                                     \
            stmt;                                               \
    } while (0)

/* Continue a chain with an allocation: NULL turns the status into MP_MEM. */
#define THEN_ALLOC(expr)                                        \
    do {                                                        \
        where = __LINE__ + 1;                                   \
        if (ret == MP_OKAY)                                     \
            ret = ((expr)) ? MP_OKAY : MP_MEM;                  \
    } while (0)
59 static mp_err
60 random_num(mp_int *num, size_t len)
62 unsigned char *p;
63 mp_err ret = MP_MEM;
65 len = (len + 7) / 8; /* bits to bytes */
66 if ((p = malloc(len)) && RAND_bytes(p, len) != 1)
67 ret = MP_ERR;
68 if (p)
69 ret = mp_from_ubin(num, p, len);
70 free(p);
71 return ret;
74 static mp_err
75 BN2mpz(mp_int *s, const BIGNUM *bn)
77 size_t len;
78 mp_err ret = MP_MEM;
79 void *p;
81 len = BN_num_bytes(bn);
82 p = malloc(len);
83 if (p) {
84 BN_bn2bin(bn, p);
85 ret = mp_from_ubin(s, p, len);
87 free(p);
88 return ret;
91 static mp_err
92 setup_blind(mp_int *n, mp_int *b, mp_int *bi)
94 mp_err ret;
96 ret = random_num(b, mp_count_bits(n));
97 if (ret == MP_OKAY) ret = mp_mod(b, n, b);
98 if (ret == MP_OKAY) ret = mp_invmod(b, n, bi);
99 return ret;
102 static mp_err
103 blind(mp_int *in, mp_int *b, mp_int *e, mp_int *n)
105 mp_err ret;
106 mp_int t1;
108 ret = mp_init(&t1);
109 /* in' = (in * b^e) mod n */
110 if (ret == MP_OKAY) ret = mp_exptmod(b, e, n, &t1);
111 if (ret == MP_OKAY) ret = mp_mul(&t1, in, in);
112 if (ret == MP_OKAY) ret = mp_mod(in, n, in);
113 mp_clear(&t1);
114 return ret;
117 static mp_err
118 unblind(mp_int *out, mp_int *bi, mp_int *n)
120 mp_err ret;
122 /* out' = (out * 1/b) mod n */
123 ret = mp_mul(out, bi, out);
124 if (ret == MP_OKAY) ret = mp_mod(out, n, out);
125 return ret;
128 static mp_err
129 ltm_rsa_private_calculate(mp_int * in, mp_int * p, mp_int * q,
130 mp_int * dmp1, mp_int * dmq1, mp_int * iqmp,
131 mp_int * out)
133 mp_err ret;
134 mp_int vp, vq, u;
135 int where HEIMDAL_UNUSED_ATTRIBUTE = 0;
137 FIRST(mp_init_multi(&vp, &vq, &u, NULL));
139 /* vq = c ^ (d mod (q - 1)) mod q */
140 /* vp = c ^ (d mod (p - 1)) mod p */
141 THEN_MP(mp_mod(in, p, &u));
142 THEN_MP(mp_exptmod(&u, dmp1, p, &vp));
143 THEN_MP(mp_mod(in, q, &u));
144 THEN_MP(mp_exptmod(&u, dmq1, q, &vq));
146 /* C2 = 1/q mod p (iqmp) */
147 /* u = (vp - vq)C2 mod p. */
148 THEN_MP(mp_sub(&vp, &vq, &u));
149 THEN_IF_MP(mp_isneg(&u), mp_add(&u, p, &u));
150 THEN_MP(mp_mul(&u, iqmp, &u));
151 THEN_MP(mp_mod(&u, p, &u));
153 /* c ^ d mod n = vq + u q */
154 THEN_MP(mp_mul(&u, q, &u));
155 THEN_MP(mp_add(&u, &vq, out));
157 mp_clear_multi(&vp, &vq, &u, NULL);
158 return ret;
165 static int
166 ltm_rsa_public_encrypt(int flen, const unsigned char* from,
167 unsigned char* to, RSA* rsa, int padding)
169 unsigned char *p = NULL, *p0 = NULL;
170 size_t size, ssize, padlen;
171 mp_int enc, dec, n, e;
172 mp_err ret;
173 int where = __LINE__;
175 if (padding != RSA_PKCS1_PADDING)
176 return -1;
178 FIRST(mp_init_multi(&n, &e, &enc, &dec, NULL));
180 size = RSA_size(rsa);
181 THEN_IF_MP((size < RSA_PKCS1_PADDING_SIZE ||
182 size - RSA_PKCS1_PADDING_SIZE < flen),
183 MP_ERR);
184 THEN_MP(BN2mpz(&n, rsa->n));
185 THEN_MP(BN2mpz(&e, rsa->e));
186 THEN_IF_MP((mp_cmp_d(&e, 3) == MP_LT), MP_ERR);
187 THEN_ALLOC((p = p0 = malloc(size - 1)));
189 if (ret == MP_OKAY) {
190 padlen = size - flen - 3;
191 *p++ = 2;
193 THEN_IF_MP((RAND_bytes(p, padlen) != 1), MP_ERR);
195 if (ret == MP_OKAY) {
196 while (padlen) {
197 if (*p == 0)
198 *p = 1;
199 padlen--;
200 p++;
202 *p++ = 0;
203 memcpy(p, from, flen);
204 p += flen;
205 assert((p - p0) == size - 1);
208 THEN_MP(mp_from_ubin(&dec, p0, size - 1));
209 THEN_MP(mp_exptmod(&dec, &e, &n, &enc));
210 THEN_VOID(ssize = mp_ubin_size(&enc));
211 THEN_VOID(assert(size >= ssize));
212 THEN_MP(mp_to_ubin(&enc, to, SIZE_MAX, NULL));
213 THEN_VOID(size = ssize);
215 mp_clear_multi(&dec, &e, &n, NULL);
216 mp_clear(&enc);
217 free(p0);
218 return ret == MP_OKAY ? size : -where;
221 static int
222 ltm_rsa_public_decrypt(int flen, const unsigned char* from,
223 unsigned char* to, RSA* rsa, int padding)
225 unsigned char *p;
226 mp_err ret;
227 size_t size;
228 mp_int s, us, n, e;
229 int where = 0;
231 if (padding != RSA_PKCS1_PADDING)
232 return -1;
234 if (flen > RSA_size(rsa))
235 return -2;
237 FIRST(mp_init_multi(&e, &n, &s, &us, NULL));
238 THEN_MP(BN2mpz(&n, rsa->n));
239 THEN_MP(BN2mpz(&e, rsa->e));
240 THEN_MP((mp_cmp_d(&e, 3) == MP_LT) ? MP_ERR : MP_OKAY);
241 THEN_MP(mp_from_ubin(&s, rk_UNCONST(from), (size_t)flen));
242 THEN_MP((mp_cmp(&s, &n) >= 0) ? MP_ERR : MP_OKAY);
243 THEN_MP(mp_exptmod(&s, &e, &n, &us));
245 THEN_VOID(p = to);
246 THEN_VOID(size = mp_ubin_size(&us));
247 THEN_VOID(assert(size <= RSA_size(rsa)));
248 THEN_MP(mp_to_ubin(&us, p, SIZE_MAX, NULL));
250 mp_clear_multi(&e, &n, &s, NULL);
251 mp_clear(&us);
253 if (ret != MP_OKAY)
254 return -where;
256 /* head zero was skipped by mp_to_unsigned_bin */
257 if (*p == 0)
258 return -where;
259 if (*p != 1)
260 return -(where + 1);
261 size--; p++;
262 while (size && *p == 0xff) {
263 size--; p++;
265 if (size == 0 || *p != 0)
266 return -(where + 2);
267 size--; p++;
268 memmove(to, p, size);
269 return size;
272 static int
273 ltm_rsa_private_encrypt(int flen, const unsigned char* from,
274 unsigned char* to, RSA* rsa, int padding)
276 unsigned char *ptr, *ptr0 = NULL;
277 mp_err ret;
278 mp_int in, out, n, e;
279 mp_int bi, b;
280 size_t size;
281 int blinding = (rsa->flags & RSA_FLAG_NO_BLINDING) == 0;
282 int do_unblind = 0;
283 int where = 0;
285 if (padding != RSA_PKCS1_PADDING)
286 return -1;
288 FIRST(mp_init_multi(&e, &n, &in, &out, &b, &bi, NULL));
290 size = RSA_size(rsa);
291 if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
292 return -2;
294 THEN_ALLOC((ptr0 = ptr = malloc(size)));
295 if (ret == MP_OKAY) {
296 *ptr++ = 0;
297 *ptr++ = 1;
298 memset(ptr, 0xff, size - flen - 3);
299 ptr += size - flen - 3;
300 *ptr++ = 0;
301 memcpy(ptr, from, flen);
302 ptr += flen;
303 assert((ptr - ptr0) == size);
306 THEN_MP(BN2mpz(&n, rsa->n));
307 THEN_MP(BN2mpz(&e, rsa->e));
308 THEN_IF_MP((mp_cmp_d(&e, 3) == MP_LT), MP_ERR);
309 THEN_MP(mp_from_ubin(&in, ptr0, size));
310 free(ptr0);
312 THEN_IF_MP((mp_isneg(&in) || mp_cmp(&in, &n) >= 0), MP_ERR);
314 if (blinding) {
315 THEN_MP(setup_blind(&n, &b, &bi));
316 THEN_MP(blind(&in, &b, &e, &n));
317 do_unblind = 1;
320 if (ret == MP_OKAY && rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 &&
321 rsa->iqmp) {
322 mp_int p, q, dmp1, dmq1, iqmp;
324 FIRST(mp_init_multi(&p, &q, &dmp1, &dmq1, &iqmp, NULL));
325 THEN_MP(BN2mpz(&p, rsa->p));
326 THEN_MP(BN2mpz(&q, rsa->q));
327 THEN_MP(BN2mpz(&dmp1, rsa->dmp1));
328 THEN_MP(BN2mpz(&dmq1, rsa->dmq1));
329 THEN_MP(BN2mpz(&iqmp, rsa->iqmp));
330 THEN_MP(ltm_rsa_private_calculate(&in, &p, &q, &dmp1, &dmq1, &iqmp,
331 &out));
332 mp_clear_multi(&p, &q, &dmp1, &dmq1, &iqmp, NULL);
333 if (ret != MP_OKAY) goto out;
334 } else if (ret == MP_OKAY) {
335 mp_int d;
337 THEN_MP(BN2mpz(&d, rsa->d));
338 THEN_MP(mp_exptmod(&in, &d, &n, &out));
339 mp_clear(&d);
340 if (ret != MP_OKAY) goto out;
343 if (do_unblind)
344 THEN_MP(unblind(&out, &bi, &n));
346 if (ret == MP_OKAY && size > 0) {
347 size_t ssize;
349 ssize = mp_ubin_size(&out);
350 assert(size >= ssize);
351 THEN_MP(mp_to_ubin(&out, to, SIZE_MAX, NULL));
352 size = ssize;
355 out:
356 mp_clear_multi(&e, &n, &in, &out, &b, &bi, NULL);
357 return ret == MP_OKAY ? size : -where;
360 static int
361 ltm_rsa_private_decrypt(int flen, const unsigned char* from,
362 unsigned char* to, RSA* rsa, int padding)
364 unsigned char *ptr;
365 size_t size;
366 mp_err ret;
367 mp_int in, out, n, e, b, bi;
368 int blinding = (rsa->flags & RSA_FLAG_NO_BLINDING) == 0;
369 int do_unblind = 0;
370 int where = 0;
372 if (padding != RSA_PKCS1_PADDING)
373 return -1;
375 size = RSA_size(rsa);
376 if (flen > size)
377 return -2;
379 FIRST(mp_init_multi(&in, &n, &e, &out, &b, &bi, NULL));
380 THEN_MP(BN2mpz(&n, rsa->n));
381 THEN_MP(BN2mpz(&e, rsa->e));
382 THEN_IF_MP((mp_cmp_d(&e, 3) == MP_LT), MP_ERR);
383 THEN_MP(mp_from_ubin(&in, rk_UNCONST(from), flen));
384 THEN_IF_MP((mp_isneg(&in) || mp_cmp(&in, &n) >= 0), MP_ERR);
386 if (blinding) {
387 THEN_MP(setup_blind(&n, &b, &bi));
388 THEN_MP(blind(&in, &b, &e, &n));
389 do_unblind = 1;
392 if (ret == MP_OKAY && rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 &&
393 rsa->iqmp) {
394 mp_int p, q, dmp1, dmq1, iqmp;
396 THEN_MP(mp_init_multi(&p, &q, &dmp1, &dmq1, &iqmp, NULL));
397 THEN_MP(BN2mpz(&p, rsa->p));
398 THEN_MP(BN2mpz(&q, rsa->q));
399 THEN_MP(BN2mpz(&dmp1, rsa->dmp1));
400 THEN_MP(BN2mpz(&dmq1, rsa->dmq1));
401 THEN_MP(BN2mpz(&iqmp, rsa->iqmp));
402 THEN_MP(ltm_rsa_private_calculate(&in, &p, &q, &dmp1, &dmq1, &iqmp, &out));
403 mp_clear_multi(&p, &q, &dmp1, &dmq1, &iqmp, NULL);
404 if (ret != MP_OKAY) goto out;
405 } else if (ret == MP_OKAY) {
406 mp_int d;
408 THEN_IF_MP((mp_isneg(&in) || mp_cmp(&in, &n) >= 0), MP_ERR);
409 THEN_MP(BN2mpz(&d, rsa->d));
410 THEN_MP(mp_exptmod(&in, &d, &n, &out));
411 mp_clear(&d);
412 if (ret != MP_OKAY) goto out;
415 if (do_unblind)
416 THEN_MP(unblind(&out, &bi, &n));
418 if (ret == MP_OKAY) {
419 size_t ssize;
421 ptr = to;
422 ssize = mp_ubin_size(&out);
423 assert(size >= ssize);
424 ret = mp_to_ubin(&out, ptr, SIZE_MAX, NULL);
425 if (ret != MP_OKAY) goto out;
426 size = ssize;
428 /* head zero was skipped by mp_int_to_unsigned */
429 if (*ptr != 2) {
430 where = __LINE__;
431 goto out;
433 size--; ptr++;
434 while (size && *ptr != 0) {
435 size--; ptr++;
437 if (size == 0) {
438 where = __LINE__;
439 goto out;
441 size--; ptr++;
442 memmove(to, ptr, size);
445 out:
446 mp_clear_multi(&e, &n, &in, &out, &b, &bi, NULL);
447 return (ret == MP_OKAY) ? size : -where;
450 static BIGNUM *
451 mpz2BN(mp_int *s)
453 size_t size;
454 BIGNUM *bn;
455 mp_err ret;
456 void *p;
458 size = mp_ubin_size(s);
459 if (size == 0)
460 return NULL;
462 p = malloc(size);
463 if (p == NULL)
464 return NULL;
466 ret = mp_to_ubin(s, p, SIZE_MAX, NULL);
467 if (ret == MP_OKAY)
468 bn = BN_bin2bn(p, size, NULL);
469 free(p);
470 return (ret == MP_OKAY) ? bn : NULL;
/* Selects which nibble of a high-nibble pair gen_p() puts at the top of its prime. */
enum gen_pq_type {
    GEN_P,  /* take the upper nibble of the pair */
    GEN_Q   /* take the lower nibble of the pair */
};
475 static int
476 gen_p(int bits, enum gen_pq_type pq_type, uint8_t nibble_pair, mp_int *p, mp_int *e, BN_GENCB *cb)
478 unsigned char *buf = NULL;
479 mp_bool res;
480 mp_err ret = MP_MEM;
481 mp_int t1, t2;
482 size_t len = (bits + 7) / 8;
483 int trials = mp_prime_rabin_miller_trials(bits);
484 int counter = 0;
485 int where HEIMDAL_UNUSED_ATTRIBUTE = 0;
488 FIRST(mp_init_multi(&t1, &t2, NULL));
489 if (ret == MP_OKAY && (buf = malloc(len))) do {
490 BN_GENCB_call(cb, 2, counter++);
491 /* random bytes */
492 ret = (RAND_bytes(buf, len) == 1) ? MP_OKAY : MP_ERR;
494 /* make it odd */
495 buf[len - 1] |= 1;
497 /* ensure the high nibble of the product is at least 128 */
498 if (pq_type == GEN_P)
499 buf[0] = (nibble_pair & 0xf0) | (buf[0] & 0x0f);
500 else
501 buf[0] = ((nibble_pair & 0x0f) << 4) | (buf[0] & 0x0f);
503 /* load number */
504 THEN_MP(mp_from_ubin(p, buf, len));
506 /* test primality; repeat if not */
507 THEN_MP(mp_prime_is_prime(p, trials, &res));
508 if (ret == MP_OKAY && res == MP_NO) continue;
510 /* check gcd(p - 1, e) == 1 */
511 THEN_MP(mp_sub_d(p, 1, &t1));
512 THEN_MP(mp_gcd(&t1, e, &t2));
513 } while (ret == MP_OKAY && mp_cmp_d(&t2, 1) != MP_EQ);
515 mp_clear_multi(&t1, &t2, NULL);
516 free(buf);
517 return ret;
/*
 * Candidate top-nibble pairs for the RSA primes: the upper nibble is used
 * for p and the lower nibble for q.  Every pair satisfies
 * (hi * lo) >= 128, so the product of the two primes' top nibbles is at
 * least 0x80.  The table has 32 entries.  It is read-only, hence const.
 */
static const uint8_t pq_high_nibble_pairs[] = {
    0x9f, 0xad, 0xae, 0xaf, 0xbc, 0xbd, 0xbe, 0xbf, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
    0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, 0xf9,
    0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff
};
526 static int
527 ltm_rsa_generate_key(RSA *rsa, int bits, BIGNUM *e, BN_GENCB *cb)
529 mp_int el, p, q, n, d, dmp1, dmq1, iqmp, t1, t2, t3;
530 mp_err ret;
531 uint8_t high_nibbles = 0;
532 int bitsp;
533 int where = 0;
535 if (bits < 789)
536 return -1;
538 bitsp = (bits + 1) / 2;
540 FIRST(mp_init_multi(&el, &p, &q, &n, &d,
541 &dmp1, &dmq1, &iqmp,
542 &t1, &t2, &t3, NULL));
543 THEN_MP(BN2mpz(&el, e));
546 * randomly pick a pair of high nibbles for p and q to ensure the product's
547 * high nibble is at least 128
549 if (ret == MP_OKAY)
550 ret = (RAND_bytes(&high_nibbles, 1) == 1) ? MP_OKAY : MP_ERR;
551 high_nibbles %= sizeof(pq_high_nibble_pairs);
552 high_nibbles = pq_high_nibble_pairs[high_nibbles];
554 /* generate p and q so that p != q and bits(pq) ~ bits */
555 THEN_MP(gen_p(bitsp, GEN_P, high_nibbles, &p, &el, cb));
556 BN_GENCB_call(cb, 3, 0);
557 THEN_MP(gen_p(bitsp, GEN_Q, high_nibbles, &q, &el, cb));
559 /* make p > q */
560 if (mp_cmp(&p, &q) < 0) {
561 mp_int c;
562 c = p;
563 p = q;
564 q = c;
567 BN_GENCB_call(cb, 3, 1);
569 /* calculate n, n = p * q */
570 THEN_MP(mp_mul(&p, &q, &n));
572 /* calculate d, d = 1/e mod (p - 1)(q - 1) */
573 THEN_MP(mp_sub_d(&p, 1, &t1));
574 THEN_MP(mp_sub_d(&q, 1, &t2));
575 THEN_MP(mp_mul(&t1, &t2, &t3));
576 THEN_MP(mp_invmod(&el, &t3, &d));
578 /* calculate dmp1 dmp1 = d mod (p-1) */
579 THEN_MP(mp_mod(&d, &t1, &dmp1));
580 /* calculate dmq1 dmq1 = d mod (q-1) */
581 THEN_MP(mp_mod(&d, &t2, &dmq1));
582 /* calculate iqmp iqmp = 1/q mod p */
583 THEN_MP(mp_invmod(&q, &p, &iqmp));
585 /* fill in RSA key */
587 if (ret == MP_OKAY) {
588 rsa->e = mpz2BN(&el);
589 rsa->p = mpz2BN(&p);
590 rsa->q = mpz2BN(&q);
591 rsa->n = mpz2BN(&n);
592 rsa->d = mpz2BN(&d);
593 rsa->dmp1 = mpz2BN(&dmp1);
594 rsa->dmq1 = mpz2BN(&dmq1);
595 rsa->iqmp = mpz2BN(&iqmp);
598 mp_clear_multi(&el, &p, &q, &n, &d,
599 &dmp1, &dmq1, &iqmp,
600 &t1, &t2, &t3, NULL);
601 return (ret == MP_OKAY) ? 1 : -where;
604 static int
605 ltm_rsa_init(RSA *rsa)
607 return 1;
610 static int
611 ltm_rsa_finish(RSA *rsa)
613 return 1;
616 const RSA_METHOD hc_rsa_ltm_method = {
617 "hcrypto ltm RSA",
618 ltm_rsa_public_encrypt,
619 ltm_rsa_public_decrypt,
620 ltm_rsa_private_encrypt,
621 ltm_rsa_private_decrypt,
622 NULL,
623 NULL,
624 ltm_rsa_init,
625 ltm_rsa_finish,
627 NULL,
628 NULL,
629 NULL,
630 ltm_rsa_generate_key
633 const RSA_METHOD *
634 RSA_ltm_method(void)
636 return &hc_rsa_ltm_method;