winewayland.drv: Ensure outputs can access xdg information robustly.
[wine.git] / dlls / rsaenh / mpi.c
blobf94693ad5e0a9576f8a311b012156dfd660f1bd9
1 /*
2 * dlls/rsaenh/mpi.c
3 * Multi Precision Integer functions
5 * Copyright 2004 Michael Jung
6 * Based on public domain code by Tom St Denis (tomstdenis@iahu.ca)
8 * This library is free software; you can redistribute it and/or
9 * modify it under the terms of the GNU Lesser General Public
10 * License as published by the Free Software Foundation; either
11 * version 2.1 of the License, or (at your option) any later version.
13 * This library is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 * Lesser General Public License for more details.
18 * You should have received a copy of the GNU Lesser General Public
19 * License along with this library; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
24 * This file contains code from the LibTomCrypt cryptographic
25 * library written by Tom St Denis (tomstdenis@iahu.ca). LibTomCrypt
26 * is in the public domain. The code in this file is tailored to
27 * special requirements. Take a look at http://libtomcrypt.org for the
28 * original version.
31 #include <stdarg.h>
32 #include <stdlib.h>
34 #include "windef.h"
35 #include "winbase.h"
36 #include "tomcrypt.h"
/* Known optimal configurations
 *
 * CPU                /Compiler    /MUL CUTOFF/SQR CUTOFF
 * -------------------------------------------------------------
 * Intel P4 Northwood /GCC v3.4.1  /        88/       128/LTM 0.32 ;-)
 */

/* Min. number of digits before Karatsuba multiplication is used. */
static const int KARATSUBA_MUL_CUTOFF = 88;

/* Min. number of digits before Karatsuba squaring is used. */
static const int KARATSUBA_SQR_CUTOFF = 128;
47 /* trim unused digits */
48 static void mp_clamp(mp_int *a);
50 /* compare |a| to |b| */
51 static int mp_cmp_mag(const mp_int *a, const mp_int *b);
53 /* Counts the number of lsbs which are zero before the first one bit */
54 static int mp_cnt_lsb(const mp_int *a);
56 /* computes a = B**n mod b without division or multiplication useful for
57 * normalizing numbers in a Montgomery system.
59 static int mp_montgomery_calc_normalization(mp_int *a, const mp_int *b);
61 /* computes x/R == x (mod N) via Montgomery Reduction */
62 static int mp_montgomery_reduce(mp_int *a, const mp_int *m, mp_digit mp);
64 /* setups the montgomery reduction */
65 static int mp_montgomery_setup(const mp_int *a, mp_digit *mp);
67 /* Barrett Reduction, computes a (mod b) with a precomputed value c
69 * Assumes that 0 < a <= b*b, note if 0 > a > -(b*b) then you can merely
70 * compute the reduction as -1 * mp_reduce(mp_abs(a)) [pseudo code].
72 static int mp_reduce(mp_int *a, const mp_int *b, const mp_int *c);
74 /* reduces a modulo b where b is of the form 2**p - k [0 <= a] */
75 static int mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d);
77 /* determines k value for 2k reduction */
78 static int mp_reduce_2k_setup(const mp_int *a, mp_digit *d);
80 /* used to setup the Barrett reduction for a given modulus b */
81 static int mp_reduce_setup(mp_int *a, const mp_int *b);
83 /* set to a digit */
84 static void mp_set(mp_int *a, mp_digit b);
86 /* b = a*a */
87 static int mp_sqr(const mp_int *a, mp_int *b);
89 /* c = a * a (mod b) */
90 static int mp_sqrmod(const mp_int *a, mp_int *b, mp_int *c);
93 static void bn_reverse(unsigned char *s, int len);
94 static int s_mp_add(mp_int *a, mp_int *b, mp_int *c);
95 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y);
96 #define s_mp_mul(a, b, c) s_mp_mul_digs(a, b, c, (a)->used + (b)->used + 1)
97 static int s_mp_mul_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
98 static int s_mp_mul_high_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
99 static int s_mp_sqr(const mp_int *a, mp_int *b);
100 static int s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c);
101 static int mp_exptmod_fast(const mp_int *G, const mp_int *X, mp_int *P, mp_int *Y, int mode);
102 static int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c);
103 static int mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c);
104 static int mp_karatsuba_sqr(const mp_int *a, mp_int *b);
106 /* grow as required */
107 static int mp_grow (mp_int * a, int size)
109 int i;
110 mp_digit *tmp;
112 /* if the alloc size is smaller alloc more ram */
113 if (a->alloc < size) {
114 /* ensure there are always at least MP_PREC digits extra on top */
115 size += (MP_PREC * 2) - (size % MP_PREC);
117 /* reallocate the array a->dp
119 * We store the return in a temporary variable
120 * in case the operation failed we don't want
121 * to overwrite the dp member of a.
123 tmp = realloc(a->dp, sizeof (mp_digit) * size);
124 if (tmp == NULL) {
125 /* reallocation failed but "a" is still valid [can be freed] */
126 return MP_MEM;
129 /* reallocation succeeded so set a->dp */
130 a->dp = tmp;
132 /* zero excess digits */
133 i = a->alloc;
134 a->alloc = size;
135 for (; i < a->alloc; i++) {
136 a->dp[i] = 0;
139 return MP_OKAY;
142 /* b = a/2 */
143 static int mp_div_2(const mp_int * a, mp_int * b)
145 int x, res, oldused;
147 /* copy */
148 if (b->alloc < a->used) {
149 if ((res = mp_grow (b, a->used)) != MP_OKAY) {
150 return res;
154 oldused = b->used;
155 b->used = a->used;
157 register mp_digit r, rr, *tmpa, *tmpb;
159 /* source alias */
160 tmpa = a->dp + b->used - 1;
162 /* dest alias */
163 tmpb = b->dp + b->used - 1;
165 /* carry */
166 r = 0;
167 for (x = b->used - 1; x >= 0; x--) {
168 /* get the carry for the next iteration */
169 rr = *tmpa & 1;
171 /* shift the current digit, add in carry and store */
172 *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
174 /* forward carry to next iteration */
175 r = rr;
178 /* zero excess digits */
179 tmpb = b->dp + b->used;
180 for (x = b->used; x < oldused; x++) {
181 *tmpb++ = 0;
184 b->sign = a->sign;
185 mp_clamp (b);
186 return MP_OKAY;
189 /* swap the elements of two integers, for cases where you can't simply swap the
190 * mp_int pointers around
192 static void
193 mp_exch (mp_int * a, mp_int * b)
195 mp_int t;
197 t = *a;
198 *a = *b;
199 *b = t;
202 /* init a new mp_int */
203 static int mp_init (mp_int * a)
205 int i;
207 /* allocate memory required and clear it */
208 a->dp = malloc(sizeof (mp_digit) * MP_PREC);
209 if (a->dp == NULL) {
210 return MP_MEM;
213 /* set the digits to zero */
214 for (i = 0; i < MP_PREC; i++) {
215 a->dp[i] = 0;
218 /* set the used to zero, allocated digits to the default precision
219 * and sign to positive */
220 a->used = 0;
221 a->alloc = MP_PREC;
222 a->sign = MP_ZPOS;
224 return MP_OKAY;
227 /* init an mp_init for a given size */
228 static int mp_init_size (mp_int * a, int size)
230 int x;
232 /* pad size so there are always extra digits */
233 size += (MP_PREC * 2) - (size % MP_PREC);
235 /* alloc mem */
236 a->dp = malloc(sizeof (mp_digit) * size);
237 if (a->dp == NULL) {
238 return MP_MEM;
241 /* set the members */
242 a->used = 0;
243 a->alloc = size;
244 a->sign = MP_ZPOS;
246 /* zero the digits */
247 for (x = 0; x < size; x++) {
248 a->dp[x] = 0;
251 return MP_OKAY;
254 /* clear one (frees) */
255 static void
256 mp_clear (mp_int * a)
258 int i;
260 /* only do anything if a hasn't been freed previously */
261 if (a->dp != NULL) {
262 /* first zero the digits */
263 for (i = 0; i < a->used; i++) {
264 a->dp[i] = 0;
267 /* free ram */
268 free(a->dp);
270 /* reset members to make debugging easier */
271 a->dp = NULL;
272 a->alloc = a->used = 0;
273 a->sign = MP_ZPOS;
277 /* set to zero */
278 static void
279 mp_zero (mp_int * a)
281 a->sign = MP_ZPOS;
282 a->used = 0;
283 memset (a->dp, 0, sizeof (mp_digit) * a->alloc);
286 /* b = |a|
288 * Simple function copies the input and fixes the sign to positive
290 static int
291 mp_abs (const mp_int * a, mp_int * b)
293 int res;
295 /* copy a to b */
296 if (a != b) {
297 if ((res = mp_copy (a, b)) != MP_OKAY) {
298 return res;
302 /* force the sign of b to positive */
303 b->sign = MP_ZPOS;
305 return MP_OKAY;
308 /* computes the modular inverse via binary extended euclidean algorithm,
309 * that is c = 1/a mod b
311 * Based on slow invmod except this is optimized for the case where b is
312 * odd as per HAC Note 14.64 on pp. 610
314 static int
315 fast_mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
317 mp_int x, y, u, v, B, D;
318 int res, neg;
320 /* 2. [modified] b must be odd */
321 if (mp_iseven (b) == 1) {
322 return MP_VAL;
325 /* init all our temps */
326 if ((res = mp_init_multi(&x, &y, &u, &v, &B, &D, NULL)) != MP_OKAY) {
327 return res;
330 /* x == modulus, y == value to invert */
331 if ((res = mp_copy (b, &x)) != MP_OKAY) {
332 goto __ERR;
335 /* we need y = |a| */
336 if ((res = mp_abs (a, &y)) != MP_OKAY) {
337 goto __ERR;
340 /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
341 if ((res = mp_copy (&x, &u)) != MP_OKAY) {
342 goto __ERR;
344 if ((res = mp_copy (&y, &v)) != MP_OKAY) {
345 goto __ERR;
347 mp_set (&D, 1);
349 top:
350 /* 4. while u is even do */
351 while (mp_iseven (&u) == 1) {
352 /* 4.1 u = u/2 */
353 if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
354 goto __ERR;
356 /* 4.2 if B is odd then */
357 if (mp_isodd (&B) == 1) {
358 if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
359 goto __ERR;
362 /* B = B/2 */
363 if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
364 goto __ERR;
368 /* 5. while v is even do */
369 while (mp_iseven (&v) == 1) {
370 /* 5.1 v = v/2 */
371 if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
372 goto __ERR;
374 /* 5.2 if D is odd then */
375 if (mp_isodd (&D) == 1) {
376 /* D = (D-x)/2 */
377 if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
378 goto __ERR;
381 /* D = D/2 */
382 if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
383 goto __ERR;
387 /* 6. if u >= v then */
388 if (mp_cmp (&u, &v) != MP_LT) {
389 /* u = u - v, B = B - D */
390 if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
391 goto __ERR;
394 if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
395 goto __ERR;
397 } else {
398 /* v - v - u, D = D - B */
399 if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
400 goto __ERR;
403 if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
404 goto __ERR;
408 /* if not zero goto step 4 */
409 if (mp_iszero (&u) == 0) {
410 goto top;
413 /* now a = C, b = D, gcd == g*v */
415 /* if v != 1 then there is no inverse */
416 if (mp_cmp_d (&v, 1) != MP_EQ) {
417 res = MP_VAL;
418 goto __ERR;
421 /* b is now the inverse */
422 neg = a->sign;
423 while (D.sign == MP_NEG) {
424 if ((res = mp_add (&D, b, &D)) != MP_OKAY) {
425 goto __ERR;
428 mp_exch (&D, c);
429 c->sign = neg;
430 res = MP_OKAY;
432 __ERR:mp_clear_multi (&x, &y, &u, &v, &B, &D, NULL);
433 return res;
436 /* computes xR**-1 == x (mod N) via Montgomery Reduction
438 * This is an optimized implementation of montgomery_reduce
439 * which uses the comba method to quickly calculate the columns of the
440 * reduction.
442 * Based on Algorithm 14.32 on pp.601 of HAC.
444 static int
445 fast_mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
447 int ix, res, olduse;
448 mp_word W[MP_WARRAY];
450 /* get old used count */
451 olduse = x->used;
453 /* grow a as required */
454 if (x->alloc < n->used + 1) {
455 if ((res = mp_grow (x, n->used + 1)) != MP_OKAY) {
456 return res;
460 /* first we have to get the digits of the input into
461 * an array of double precision words W[...]
464 register mp_word *_W;
465 register mp_digit *tmpx;
467 /* alias for the W[] array */
468 _W = W;
470 /* alias for the digits of x*/
471 tmpx = x->dp;
473 /* copy the digits of a into W[0..a->used-1] */
474 for (ix = 0; ix < x->used; ix++) {
475 *_W++ = *tmpx++;
478 /* zero the high words of W[a->used..m->used*2] */
479 for (; ix < n->used * 2 + 1; ix++) {
480 *_W++ = 0;
484 /* now we proceed to zero successive digits
485 * from the least significant upwards
487 for (ix = 0; ix < n->used; ix++) {
488 /* mu = ai * m' mod b
490 * We avoid a double precision multiplication (which isn't required)
491 * by casting the value down to a mp_digit. Note this requires
492 * that W[ix-1] have the carry cleared (see after the inner loop)
494 register mp_digit mu;
495 mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);
497 /* a = a + mu * m * b**i
499 * This is computed in place and on the fly. The multiplication
500 * by b**i is handled by offsetting which columns the results
501 * are added to.
503 * Note the comba method normally doesn't handle carries in the
504 * inner loop In this case we fix the carry from the previous
505 * column since the Montgomery reduction requires digits of the
506 * result (so far) [see above] to work. This is
507 * handled by fixing up one carry after the inner loop. The
508 * carry fixups are done in order so after these loops the
509 * first m->used words of W[] have the carries fixed
512 register int iy;
513 register mp_digit *tmpn;
514 register mp_word *_W;
516 /* alias for the digits of the modulus */
517 tmpn = n->dp;
519 /* Alias for the columns set by an offset of ix */
520 _W = W + ix;
522 /* inner loop */
523 for (iy = 0; iy < n->used; iy++) {
524 *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
528 /* now fix carry for next digit, W[ix+1] */
529 W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
532 /* now we have to propagate the carries and
533 * shift the words downward [all those least
534 * significant digits we zeroed].
537 register mp_digit *tmpx;
538 register mp_word *_W, *_W1;
540 /* nox fix rest of carries */
542 /* alias for current word */
543 _W1 = W + ix;
545 /* alias for next word, where the carry goes */
546 _W = W + ++ix;
548 for (; ix <= n->used * 2 + 1; ix++) {
549 *_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
552 /* copy out, A = A/b**n
554 * The result is A/b**n but instead of converting from an
555 * array of mp_word to mp_digit than calling mp_rshd
556 * we just copy them in the right order
559 /* alias for destination word */
560 tmpx = x->dp;
562 /* alias for shifted double precision result */
563 _W = W + n->used;
565 for (ix = 0; ix < n->used + 1; ix++) {
566 *tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
569 /* zero oldused digits, if the input a was larger than
570 * m->used+1 we'll have to clear the digits
572 for (; ix < olduse; ix++) {
573 *tmpx++ = 0;
577 /* set the max used and clamp */
578 x->used = n->used + 1;
579 mp_clamp (x);
581 /* if A >= m then A = A - m */
582 if (mp_cmp_mag (x, n) != MP_LT) {
583 return s_mp_sub (x, n, x);
585 return MP_OKAY;
588 /* Fast (comba) multiplier
590 * This is the fast column-array [comba] multiplier. It is
591 * designed to compute the columns of the product first
592 * then handle the carries afterwards. This has the effect
593 * of making the nested loops that compute the columns very
594 * simple and schedulable on super-scalar processors.
596 * This has been modified to produce a variable number of
597 * digits of output so if say only a half-product is required
598 * you don't have to compute the upper half (a feature
599 * required for fast Barrett reduction).
601 * Based on Algorithm 14.12 on pp.595 of HAC.
604 static int
605 fast_s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
607 int olduse, res, pa, ix, iz;
608 mp_digit W[MP_WARRAY];
609 register mp_word _W;
611 /* grow the destination as required */
612 if (c->alloc < digs) {
613 if ((res = mp_grow (c, digs)) != MP_OKAY) {
614 return res;
618 /* number of output digits to produce */
619 pa = MIN(digs, a->used + b->used);
621 /* clear the carry */
622 _W = 0;
623 for (ix = 0; ix <= pa; ix++) {
624 int tx, ty;
625 int iy;
626 mp_digit *tmpx, *tmpy;
628 /* get offsets into the two bignums */
629 ty = MIN(b->used-1, ix);
630 tx = ix - ty;
632 /* setup temp aliases */
633 tmpx = a->dp + tx;
634 tmpy = b->dp + ty;
636 /* This is the number of times the loop will iterate, essentially it's
637 while (tx++ < a->used && ty-- >= 0) { ... }
639 iy = MIN(a->used-tx, ty+1);
641 /* execute loop */
642 for (iz = 0; iz < iy; ++iz) {
643 _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
646 /* store term */
647 W[ix] = ((mp_digit)_W) & MP_MASK;
649 /* make next carry */
650 _W = _W >> ((mp_word)DIGIT_BIT);
653 /* setup dest */
654 olduse = c->used;
655 c->used = digs;
658 register mp_digit *tmpc;
659 tmpc = c->dp;
660 for (ix = 0; ix < digs; ix++) {
661 /* now extract the previous digit [below the carry] */
662 *tmpc++ = W[ix];
665 /* clear unused digits [that existed in the old copy of c] */
666 for (; ix < olduse; ix++) {
667 *tmpc++ = 0;
670 mp_clamp (c);
671 return MP_OKAY;
674 /* this is a modified version of fast_s_mul_digs that only produces
675 * output digits *above* digs. See the comments for fast_s_mul_digs
676 * to see how it works.
678 * This is used in the Barrett reduction since for one of the multiplications
679 * only the higher digits were needed. This essentially halves the work.
681 * Based on Algorithm 14.12 on pp.595 of HAC.
683 static int
684 fast_s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
686 int olduse, res, pa, ix, iz;
687 mp_digit W[MP_WARRAY];
688 mp_word _W;
690 /* grow the destination as required */
691 pa = a->used + b->used;
692 if (c->alloc < pa) {
693 if ((res = mp_grow (c, pa)) != MP_OKAY) {
694 return res;
698 /* number of output digits to produce */
699 pa = a->used + b->used;
700 _W = 0;
701 for (ix = digs; ix <= pa; ix++) {
702 int tx, ty, iy;
703 mp_digit *tmpx, *tmpy;
705 /* get offsets into the two bignums */
706 ty = MIN(b->used-1, ix);
707 tx = ix - ty;
709 /* setup temp aliases */
710 tmpx = a->dp + tx;
711 tmpy = b->dp + ty;
713 /* This is the number of times the loop will iterate, essentially it's
714 while (tx++ < a->used && ty-- >= 0) { ... }
716 iy = MIN(a->used-tx, ty+1);
718 /* execute loop */
719 for (iz = 0; iz < iy; iz++) {
720 _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
723 /* store term */
724 W[ix] = ((mp_digit)_W) & MP_MASK;
726 /* make next carry */
727 _W = _W >> ((mp_word)DIGIT_BIT);
730 /* setup dest */
731 olduse = c->used;
732 c->used = pa;
735 register mp_digit *tmpc;
737 tmpc = c->dp + digs;
738 for (ix = digs; ix <= pa; ix++) {
739 /* now extract the previous digit [below the carry] */
740 *tmpc++ = W[ix];
743 /* clear unused digits [that existed in the old copy of c] */
744 for (; ix < olduse; ix++) {
745 *tmpc++ = 0;
748 mp_clamp (c);
749 return MP_OKAY;
752 /* fast squaring
754 * This is the comba method where the columns of the product
755 * are computed first then the carries are computed. This
756 * has the effect of making a very simple inner loop that
757 * is executed the most
759 * W2 represents the outer products and W the inner.
761 * A further optimizations is made because the inner
762 * products are of the form "A * B * 2". The *2 part does
763 * not need to be computed until the end which is good
764 * because 64-bit shifts are slow!
766 * Based on Algorithm 14.16 on pp.597 of HAC.
769 /* the jist of squaring...
771 you do like mult except the offset of the tmpx [one that starts closer to zero]
772 can't equal the offset of tmpy. So basically you set up iy like before then you min it with
773 (ty-tx) so that it never happens. You double all those you add in the inner loop
775 After that loop you do the squares and add them in.
777 Remove W2 and don't memset W
781 static int fast_s_mp_sqr (const mp_int * a, mp_int * b)
783 int olduse, res, pa, ix, iz;
784 mp_digit W[MP_WARRAY], *tmpx;
785 mp_word W1;
787 /* grow the destination as required */
788 pa = a->used + a->used;
789 if (b->alloc < pa) {
790 if ((res = mp_grow (b, pa)) != MP_OKAY) {
791 return res;
795 /* number of output digits to produce */
796 W1 = 0;
797 for (ix = 0; ix <= pa; ix++) {
798 int tx, ty, iy;
799 mp_word _W;
800 mp_digit *tmpy;
802 /* clear counter */
803 _W = 0;
805 /* get offsets into the two bignums */
806 ty = MIN(a->used-1, ix);
807 tx = ix - ty;
809 /* setup temp aliases */
810 tmpx = a->dp + tx;
811 tmpy = a->dp + ty;
813 /* This is the number of times the loop will iterate, essentially it's
814 while (tx++ < a->used && ty-- >= 0) { ... }
816 iy = MIN(a->used-tx, ty+1);
818 /* now for squaring tx can never equal ty
819 * we halve the distance since they approach at a rate of 2x
820 * and we have to round because odd cases need to be executed
822 iy = MIN(iy, (ty-tx+1)>>1);
824 /* execute loop */
825 for (iz = 0; iz < iy; iz++) {
826 _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
829 /* double the inner product and add carry */
830 _W = _W + _W + W1;
832 /* even columns have the square term in them */
833 if ((ix&1) == 0) {
834 _W += ((mp_word)a->dp[ix>>1])*((mp_word)a->dp[ix>>1]);
837 /* store it */
838 W[ix] = _W;
840 /* make next carry */
841 W1 = _W >> ((mp_word)DIGIT_BIT);
844 /* setup dest */
845 olduse = b->used;
846 b->used = a->used+a->used;
849 mp_digit *tmpb;
850 tmpb = b->dp;
851 for (ix = 0; ix < pa; ix++) {
852 *tmpb++ = W[ix] & MP_MASK;
855 /* clear unused digits [that existed in the old copy of c] */
856 for (; ix < olduse; ix++) {
857 *tmpb++ = 0;
860 mp_clamp (b);
861 return MP_OKAY;
864 /* computes a = 2**b
866 * Simple algorithm which zeroes the int, grows it then just sets one bit
867 * as required.
869 static int
870 mp_2expt (mp_int * a, int b)
872 int res;
874 /* zero a as per default */
875 mp_zero (a);
877 /* grow a to accommodate the single bit */
878 if ((res = mp_grow (a, b / DIGIT_BIT + 1)) != MP_OKAY) {
879 return res;
882 /* set the used count of where the bit will go */
883 a->used = b / DIGIT_BIT + 1;
885 /* put the single bit in its place */
886 a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
888 return MP_OKAY;
891 /* high level addition (handles signs) */
892 int mp_add (mp_int * a, mp_int * b, mp_int * c)
894 int sa, sb, res;
896 /* get sign of both inputs */
897 sa = a->sign;
898 sb = b->sign;
900 /* handle two cases, not four */
901 if (sa == sb) {
902 /* both positive or both negative */
903 /* add their magnitudes, copy the sign */
904 c->sign = sa;
905 res = s_mp_add (a, b, c);
906 } else {
907 /* one positive, the other negative */
908 /* subtract the one with the greater magnitude from */
909 /* the one of the lesser magnitude. The result gets */
910 /* the sign of the one with the greater magnitude. */
911 if (mp_cmp_mag (a, b) == MP_LT) {
912 c->sign = sb;
913 res = s_mp_sub (b, a, c);
914 } else {
915 c->sign = sa;
916 res = s_mp_sub (a, b, c);
919 return res;
923 /* single digit addition */
924 static int
925 mp_add_d (mp_int * a, mp_digit b, mp_int * c)
927 int res, ix, oldused;
928 mp_digit *tmpa, *tmpc, mu;
930 /* grow c as required */
931 if (c->alloc < a->used + 1) {
932 if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
933 return res;
937 /* if a is negative and |a| >= b, call c = |a| - b */
938 if (a->sign == MP_NEG && (a->used > 1 || a->dp[0] >= b)) {
939 /* temporarily fix sign of a */
940 a->sign = MP_ZPOS;
942 /* c = |a| - b */
943 res = mp_sub_d(a, b, c);
945 /* fix sign */
946 a->sign = c->sign = MP_NEG;
948 return res;
951 /* old number of used digits in c */
952 oldused = c->used;
954 /* sign always positive */
955 c->sign = MP_ZPOS;
957 /* source alias */
958 tmpa = a->dp;
960 /* destination alias */
961 tmpc = c->dp;
963 /* if a is positive */
964 if (a->sign == MP_ZPOS) {
965 /* add digit, after this we're propagating
966 * the carry.
968 *tmpc = *tmpa++ + b;
969 mu = *tmpc >> DIGIT_BIT;
970 *tmpc++ &= MP_MASK;
972 /* now handle rest of the digits */
973 for (ix = 1; ix < a->used; ix++) {
974 *tmpc = *tmpa++ + mu;
975 mu = *tmpc >> DIGIT_BIT;
976 *tmpc++ &= MP_MASK;
978 /* set final carry */
979 ix++;
980 *tmpc++ = mu;
982 /* setup size */
983 c->used = a->used + 1;
984 } else {
985 /* a was negative and |a| < b */
986 c->used = 1;
988 /* the result is a single digit */
989 if (a->used == 1) {
990 *tmpc++ = b - a->dp[0];
991 } else {
992 *tmpc++ = b;
995 /* setup count so the clearing of oldused
996 * can fall through correctly
998 ix = 1;
1001 /* now zero to oldused */
1002 while (ix++ < oldused) {
1003 *tmpc++ = 0;
1005 mp_clamp(c);
1007 return MP_OKAY;
1010 /* trim unused digits
1012 * This is used to ensure that leading zero digits are
1013 * trimmed and the leading "used" digit will be non-zero
1014 * Typically very fast. Also fixes the sign if there
1015 * are no more leading digits
1017 void
1018 mp_clamp (mp_int * a)
1020 /* decrease used while the most significant digit is
1021 * zero.
1023 while (a->used > 0 && a->dp[a->used - 1] == 0) {
1024 --(a->used);
1027 /* reset the sign flag if used == 0 */
1028 if (a->used == 0) {
1029 a->sign = MP_ZPOS;
1033 void WINAPIV mp_clear_multi(mp_int *mp, ...)
1035 mp_int* next_mp = mp;
1036 va_list args;
1037 va_start(args, mp);
1038 while (next_mp != NULL) {
1039 mp_clear(next_mp);
1040 next_mp = va_arg(args, mp_int*);
1042 va_end(args);
1045 /* compare two ints (signed)*/
1047 mp_cmp (const mp_int * a, const mp_int * b)
1049 /* compare based on sign */
1050 if (a->sign != b->sign) {
1051 if (a->sign == MP_NEG) {
1052 return MP_LT;
1053 } else {
1054 return MP_GT;
1058 /* compare digits */
1059 if (a->sign == MP_NEG) {
1060 /* if negative compare opposite direction */
1061 return mp_cmp_mag(b, a);
1062 } else {
1063 return mp_cmp_mag(a, b);
1067 /* compare a digit */
1068 int mp_cmp_d(const mp_int * a, mp_digit b)
1070 /* compare based on sign */
1071 if (a->sign == MP_NEG) {
1072 return MP_LT;
1075 /* compare based on magnitude */
1076 if (a->used > 1) {
1077 return MP_GT;
1080 /* compare the only digit of a to b */
1081 if (a->dp[0] > b) {
1082 return MP_GT;
1083 } else if (a->dp[0] < b) {
1084 return MP_LT;
1085 } else {
1086 return MP_EQ;
1090 /* compare maginitude of two ints (unsigned) */
1091 int mp_cmp_mag (const mp_int * a, const mp_int * b)
1093 int n;
1094 mp_digit *tmpa, *tmpb;
1096 /* compare based on # of non-zero digits */
1097 if (a->used > b->used) {
1098 return MP_GT;
1101 if (a->used < b->used) {
1102 return MP_LT;
1105 /* alias for a */
1106 tmpa = a->dp + (a->used - 1);
1108 /* alias for b */
1109 tmpb = b->dp + (a->used - 1);
1111 /* compare based on digits */
1112 for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
1113 if (*tmpa > *tmpb) {
1114 return MP_GT;
1117 if (*tmpa < *tmpb) {
1118 return MP_LT;
1121 return MP_EQ;
1124 static const int lnz[16] = {
1125 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
1128 /* Counts the number of lsbs which are zero before the first zero bit */
1129 int mp_cnt_lsb(const mp_int *a)
1131 int x;
1132 mp_digit q, qq;
1134 /* easy out */
1135 if (mp_iszero(a) == 1) {
1136 return 0;
1139 /* scan lower digits until non-zero */
1140 for (x = 0; x < a->used && a->dp[x] == 0; x++);
1141 q = a->dp[x];
1142 x *= DIGIT_BIT;
1144 /* now scan this digit until a 1 is found */
1145 if ((q & 1) == 0) {
1146 do {
1147 qq = q & 15;
1148 x += lnz[qq];
1149 q >>= 4;
1150 } while (qq == 0);
1152 return x;
1155 /* copy, b = a */
1157 mp_copy (const mp_int * a, mp_int * b)
1159 int res, n;
1161 /* if dst == src do nothing */
1162 if (a == b) {
1163 return MP_OKAY;
1166 /* grow dest */
1167 if (b->alloc < a->used) {
1168 if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1169 return res;
1173 /* zero b and copy the parameters over */
1175 register mp_digit *tmpa, *tmpb;
1177 /* pointer aliases */
1179 /* source */
1180 tmpa = a->dp;
1182 /* destination */
1183 tmpb = b->dp;
1185 /* copy all the digits */
1186 for (n = 0; n < a->used; n++) {
1187 *tmpb++ = *tmpa++;
1190 /* clear high digits */
1191 for (; n < b->used; n++) {
1192 *tmpb++ = 0;
1196 /* copy used count and sign */
1197 b->used = a->used;
1198 b->sign = a->sign;
1199 return MP_OKAY;
1202 /* returns the number of bits in an int */
1204 mp_count_bits (const mp_int * a)
1206 int r;
1207 mp_digit q;
1209 /* shortcut */
1210 if (a->used == 0) {
1211 return 0;
1214 /* get number of digits and add that */
1215 r = (a->used - 1) * DIGIT_BIT;
1217 /* take the last digit and count the bits in it */
1218 q = a->dp[a->used - 1];
1219 while (q > 0) {
1220 ++r;
1221 q >>= ((mp_digit) 1);
1223 return r;
1226 /* calc a value mod 2**b */
1227 static int
1228 mp_mod_2d (const mp_int * a, int b, mp_int * c)
1230 int x, res;
1232 /* if b is <= 0 then zero the int */
1233 if (b <= 0) {
1234 mp_zero (c);
1235 return MP_OKAY;
1238 /* if the modulus is larger than the value than return */
1239 if (b > a->used * DIGIT_BIT) {
1240 res = mp_copy (a, c);
1241 return res;
1244 /* copy */
1245 if ((res = mp_copy (a, c)) != MP_OKAY) {
1246 return res;
1249 /* zero digits above the last digit of the modulus */
1250 for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
1251 c->dp[x] = 0;
1253 /* clear the digit that is not completely outside/inside the modulus */
1254 c->dp[b / DIGIT_BIT] &= (1 << ((mp_digit)b % DIGIT_BIT)) - 1;
1255 mp_clamp (c);
1256 return MP_OKAY;
1259 /* shift right a certain amount of digits */
1260 static void mp_rshd (mp_int * a, int b)
1262 int x;
1264 /* if b <= 0 then ignore it */
1265 if (b <= 0) {
1266 return;
1269 /* if b > used then simply zero it and return */
1270 if (a->used <= b) {
1271 mp_zero (a);
1272 return;
1276 register mp_digit *bottom, *top;
1278 /* shift the digits down */
1280 /* bottom */
1281 bottom = a->dp;
1283 /* top [offset into digits] */
1284 top = a->dp + b;
1286 /* this is implemented as a sliding window where
1287 * the window is b-digits long and digits from
1288 * the top of the window are copied to the bottom
1290 * e.g.
1292 b-2 | b-1 | b0 | b1 | b2 | ... | bb | ---->
1293 /\ | ---->
1294 \-------------------/ ---->
1296 for (x = 0; x < (a->used - b); x++) {
1297 *bottom++ = *top++;
1300 /* zero the top digits */
1301 for (; x < a->used; x++) {
1302 *bottom++ = 0;
1306 /* remove excess digits */
1307 a->used -= b;
1310 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
1311 static int mp_div_2d (const mp_int * a, int b, mp_int * c, mp_int * d)
1313 mp_digit D, r, rr;
1314 int x, res;
1315 mp_int t;
1318 /* if the shift count is <= 0 then we do no work */
1319 if (b <= 0) {
1320 res = mp_copy (a, c);
1321 if (d != NULL) {
1322 mp_zero (d);
1324 return res;
1327 if ((res = mp_init (&t)) != MP_OKAY) {
1328 return res;
1331 /* get the remainder */
1332 if (d != NULL) {
1333 if ((res = mp_mod_2d (a, b, &t)) != MP_OKAY) {
1334 mp_clear (&t);
1335 return res;
1339 /* copy */
1340 if ((res = mp_copy (a, c)) != MP_OKAY) {
1341 mp_clear (&t);
1342 return res;
1345 /* shift by as many digits in the bit count */
1346 if (b >= DIGIT_BIT) {
1347 mp_rshd (c, b / DIGIT_BIT);
1350 /* shift any bit count < DIGIT_BIT */
1351 D = (mp_digit) (b % DIGIT_BIT);
1352 if (D != 0) {
1353 register mp_digit *tmpc, mask, shift;
1355 /* mask */
1356 mask = (((mp_digit)1) << D) - 1;
1358 /* shift for lsb */
1359 shift = DIGIT_BIT - D;
1361 /* alias */
1362 tmpc = c->dp + (c->used - 1);
1364 /* carry */
1365 r = 0;
1366 for (x = c->used - 1; x >= 0; x--) {
1367 /* get the lower bits of this word in a temp */
1368 rr = *tmpc & mask;
1370 /* shift the current word and mix in the carry bits from the previous word */
1371 *tmpc = (*tmpc >> D) | (r << shift);
1372 --tmpc;
1374 /* set the carry to the carry bits of the current word found above */
1375 r = rr;
1378 mp_clamp (c);
1379 if (d != NULL) {
1380 mp_exch (&t, d);
1382 mp_clear (&t);
1383 return MP_OKAY;
1386 /* shift left a certain amount of digits */
1387 static int mp_lshd (mp_int * a, int b)
1389 int x, res;
1391 /* if it's less than zero return */
1392 if (b <= 0) {
1393 return MP_OKAY;
1396 /* grow to fit the new digits */
1397 if (a->alloc < a->used + b) {
1398 if ((res = mp_grow (a, a->used + b)) != MP_OKAY) {
1399 return res;
1404 register mp_digit *top, *bottom;
1406 /* increment the used by the shift amount then copy upwards */
1407 a->used += b;
1409 /* top */
1410 top = a->dp + a->used - 1;
1412 /* base */
1413 bottom = a->dp + a->used - 1 - b;
1415 /* much like mp_rshd this is implemented using a sliding window
1416 * except the window goes the other way around. Copying from
1417 * the bottom to the top. see bn_mp_rshd.c for more info.
1419 for (x = a->used - 1; x >= b; x--) {
1420 *top-- = *bottom--;
1423 /* zero the lower digits */
1424 top = a->dp;
1425 for (x = 0; x < b; x++) {
1426 *top++ = 0;
1429 return MP_OKAY;
1432 /* shift left by a certain bit count */
1433 static int mp_mul_2d (const mp_int * a, int b, mp_int * c)
1435 mp_digit d;
1436 int res;
1438 /* copy */
1439 if (a != c) {
1440 if ((res = mp_copy (a, c)) != MP_OKAY) {
1441 return res;
1445 if (c->alloc < c->used + b/DIGIT_BIT + 1) {
1446 if ((res = mp_grow (c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
1447 return res;
1451 /* shift by as many digits in the bit count */
1452 if (b >= DIGIT_BIT) {
1453 if ((res = mp_lshd (c, b / DIGIT_BIT)) != MP_OKAY) {
1454 return res;
1458 /* shift any bit count < DIGIT_BIT */
1459 d = (mp_digit) (b % DIGIT_BIT);
1460 if (d != 0) {
1461 register mp_digit *tmpc, shift, mask, r, rr;
1462 register int x;
1464 /* bitmask for carries */
1465 mask = (((mp_digit)1) << d) - 1;
1467 /* shift for msbs */
1468 shift = DIGIT_BIT - d;
1470 /* alias */
1471 tmpc = c->dp;
1473 /* carry */
1474 r = 0;
1475 for (x = 0; x < c->used; x++) {
1476 /* get the higher bits of the current word */
1477 rr = (*tmpc >> shift) & mask;
1479 /* shift the current word and OR in the carry */
1480 *tmpc = ((*tmpc << d) | r) & MP_MASK;
1481 ++tmpc;
1483 /* set the carry to the carry bits of the current word */
1484 r = rr;
1487 /* set final carry */
1488 if (r != 0) {
1489 c->dp[(c->used)++] = r;
1492 mp_clamp (c);
1493 return MP_OKAY;
1496 /* multiply by a digit */
1497 static int
1498 mp_mul_d (const mp_int * a, mp_digit b, mp_int * c)
1500 mp_digit u, *tmpa, *tmpc;
1501 mp_word r;
1502 int ix, res, olduse;
1504 /* make sure c is big enough to hold a*b */
1505 if (c->alloc < a->used + 1) {
1506 if ((res = mp_grow (c, a->used + 1)) != MP_OKAY) {
1507 return res;
1511 /* get the original destinations used count */
1512 olduse = c->used;
1514 /* set the sign */
1515 c->sign = a->sign;
1517 /* alias for a->dp [source] */
1518 tmpa = a->dp;
1520 /* alias for c->dp [dest] */
1521 tmpc = c->dp;
1523 /* zero carry */
1524 u = 0;
1526 /* compute columns */
1527 for (ix = 0; ix < a->used; ix++) {
1528 /* compute product and carry sum for this term */
1529 r = ((mp_word) u) + ((mp_word)*tmpa++) * ((mp_word)b);
1531 /* mask off higher bits to get a single digit */
1532 *tmpc++ = (mp_digit) (r & ((mp_word) MP_MASK));
1534 /* send carry into next iteration */
1535 u = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
1538 /* store final carry [if any] */
1539 *tmpc++ = u;
1541 /* now zero digits above the top */
1542 while (ix++ < olduse) {
1543 *tmpc++ = 0;
1546 /* set used count */
1547 c->used = a->used + 1;
1548 mp_clamp(c);
1550 return MP_OKAY;
1553 /* integer signed division.
1554 * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
1555 * HAC pp.598 Algorithm 14.20
1557 * Note that the description in HAC is horribly
1558 * incomplete. For example, it doesn't consider
1559 * the case where digits are removed from 'x' in
1560 * the inner loop. It also doesn't consider the
1561 * case that y has fewer than three digits, etc..
1563 * The overall algorithm is as described as
1564 * 14.20 from HAC but fixed to treat these cases.
1566 static int mp_div (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
1568 mp_int q, x, y, t1, t2;
1569 int res, n, t, i, norm, neg;
1571 /* is divisor zero ? */
1572 if (mp_iszero (b) == 1) {
1573 return MP_VAL;
1576 /* if a < b then q=0, r = a */
1577 if (mp_cmp_mag (a, b) == MP_LT) {
1578 if (d != NULL) {
1579 res = mp_copy (a, d);
1580 } else {
1581 res = MP_OKAY;
1583 if (c != NULL) {
1584 mp_zero (c);
1586 return res;
1589 if ((res = mp_init_size (&q, a->used + 2)) != MP_OKAY) {
1590 return res;
1592 q.used = a->used + 2;
1594 if ((res = mp_init (&t1)) != MP_OKAY) {
1595 goto __Q;
1598 if ((res = mp_init (&t2)) != MP_OKAY) {
1599 goto __T1;
1602 if ((res = mp_init_copy (&x, a)) != MP_OKAY) {
1603 goto __T2;
1606 if ((res = mp_init_copy (&y, b)) != MP_OKAY) {
1607 goto __X;
1610 /* fix the sign */
1611 neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1612 x.sign = y.sign = MP_ZPOS;
1614 /* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1615 norm = mp_count_bits(&y) % DIGIT_BIT;
1616 if (norm < DIGIT_BIT-1) {
1617 norm = (DIGIT_BIT-1) - norm;
1618 if ((res = mp_mul_2d (&x, norm, &x)) != MP_OKAY) {
1619 goto __Y;
1621 if ((res = mp_mul_2d (&y, norm, &y)) != MP_OKAY) {
1622 goto __Y;
1624 } else {
1625 norm = 0;
1628 /* note hac does 0 based, so if used==5 then it's 0,1,2,3,4, e.g. use 4 */
1629 n = x.used - 1;
1630 t = y.used - 1;
1632 /* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1633 if ((res = mp_lshd (&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1634 goto __Y;
1637 while (mp_cmp (&x, &y) != MP_LT) {
1638 ++(q.dp[n - t]);
1639 if ((res = mp_sub (&x, &y, &x)) != MP_OKAY) {
1640 goto __Y;
1644 /* reset y by shifting it back down */
1645 mp_rshd (&y, n - t);
1647 /* step 3. for i from n down to (t + 1) */
1648 for (i = n; i >= (t + 1); i--) {
1649 if (i > x.used) {
1650 continue;
1653 /* step 3.1 if xi == yt then set q{i-t-1} to b-1,
1654 * otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1655 if (x.dp[i] == y.dp[t]) {
1656 q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1657 } else {
1658 mp_word tmp;
1659 tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1660 tmp |= ((mp_word) x.dp[i - 1]);
1661 tmp /= ((mp_word) y.dp[t]);
1662 if (tmp > (mp_word) MP_MASK)
1663 tmp = MP_MASK;
1664 q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1667 /* while (q{i-t-1} * (yt * b + y{t-1})) >
1668 xi * b**2 + xi-1 * b + xi-2
1670 do q{i-t-1} -= 1;
1672 q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1673 do {
1674 q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1676 /* find left hand */
1677 mp_zero (&t1);
1678 t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1679 t1.dp[1] = y.dp[t];
1680 t1.used = 2;
1681 if ((res = mp_mul_d (&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1682 goto __Y;
1685 /* find right hand */
1686 t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1687 t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1688 t2.dp[2] = x.dp[i];
1689 t2.used = 3;
1690 } while (mp_cmp_mag(&t1, &t2) == MP_GT);
1692 /* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1693 if ((res = mp_mul_d (&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1694 goto __Y;
1697 if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1698 goto __Y;
1701 if ((res = mp_sub (&x, &t1, &x)) != MP_OKAY) {
1702 goto __Y;
1705 /* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1706 if (x.sign == MP_NEG) {
1707 if ((res = mp_copy (&y, &t1)) != MP_OKAY) {
1708 goto __Y;
1710 if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1711 goto __Y;
1713 if ((res = mp_add (&x, &t1, &x)) != MP_OKAY) {
1714 goto __Y;
1717 q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1721 /* now q is the quotient and x is the remainder
1722 * [which we have to normalize]
1725 /* get sign before writing to c */
1726 x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1728 if (c != NULL) {
1729 mp_clamp (&q);
1730 mp_exch (&q, c);
1731 c->sign = neg;
1734 if (d != NULL) {
1735 mp_div_2d (&x, norm, &x, NULL);
1736 mp_exch (&x, d);
1739 res = MP_OKAY;
1741 __Y:mp_clear (&y);
1742 __X:mp_clear (&x);
1743 __T2:mp_clear (&t2);
1744 __T1:mp_clear (&t1);
1745 __Q:mp_clear (&q);
1746 return res;
1749 static BOOL s_is_power_of_two(mp_digit b, int *p)
1751 int x;
1753 for (x = 1; x < DIGIT_BIT; x++) {
1754 if (b == (((mp_digit)1)<<x)) {
1755 *p = x;
1756 return TRUE;
1759 return FALSE;
1762 /* single digit division (based on routine from MPI) */
1763 static int mp_div_d (const mp_int * a, mp_digit b, mp_int * c, mp_digit * d)
1765 mp_int q;
1766 mp_word w;
1767 mp_digit t;
1768 int res, ix;
1770 /* cannot divide by zero */
1771 if (b == 0) {
1772 return MP_VAL;
1775 /* quick outs */
1776 if (b == 1 || mp_iszero(a) == 1) {
1777 if (d != NULL) {
1778 *d = 0;
1780 if (c != NULL) {
1781 return mp_copy(a, c);
1783 return MP_OKAY;
1786 /* power of two ? */
1787 if (s_is_power_of_two(b, &ix)) {
1788 if (d != NULL) {
1789 *d = a->dp[0] & ((((mp_digit)1)<<ix) - 1);
1791 if (c != NULL) {
1792 return mp_div_2d(a, ix, c, NULL);
1794 return MP_OKAY;
1797 /* no easy answer [c'est la vie]. Just division */
1798 if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
1799 return res;
1802 q.used = a->used;
1803 q.sign = a->sign;
1804 w = 0;
1805 for (ix = a->used - 1; ix >= 0; ix--) {
1806 w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
1808 if (w >= b) {
1809 t = (mp_digit)(w / b);
1810 w -= ((mp_word)t) * ((mp_word)b);
1811 } else {
1812 t = 0;
1814 q.dp[ix] = t;
1817 if (d != NULL) {
1818 *d = (mp_digit)w;
1821 if (c != NULL) {
1822 mp_clamp(&q);
1823 mp_exch(&q, c);
1825 mp_clear(&q);
1827 return res;
1830 /* reduce "x" in place modulo "n" using the Diminished Radix algorithm.
1832 * Based on algorithm from the paper
1834 * "Generating Efficient Primes for Discrete Log Cryptosystems"
1835 * Chae Hoon Lim, Pil Loong Lee,
1836 * POSTECH Information Research Laboratories
1838 * The modulus must be of a special format [see manual]
1840 * Has been modified to use algorithm 7.10 from the LTM book instead
1842 * Input x must be in the range 0 <= x <= (n-1)**2
1844 static int
1845 mp_dr_reduce (mp_int * x, const mp_int * n, mp_digit k)
1847 int err, i, m;
1848 mp_word r;
1849 mp_digit mu, *tmpx1, *tmpx2;
1851 /* m = digits in modulus */
1852 m = n->used;
1854 /* ensure that "x" has at least 2m digits */
1855 if (x->alloc < m + m) {
1856 if ((err = mp_grow (x, m + m)) != MP_OKAY) {
1857 return err;
1861 /* top of loop, this is where the code resumes if
1862 * another reduction pass is required.
1864 top:
1865 /* aliases for digits */
1866 /* alias for lower half of x */
1867 tmpx1 = x->dp;
1869 /* alias for upper half of x, or x/B**m */
1870 tmpx2 = x->dp + m;
1872 /* set carry to zero */
1873 mu = 0;
1875 /* compute (x mod B**m) + k * [x/B**m] inline and inplace */
1876 for (i = 0; i < m; i++) {
1877 r = ((mp_word)*tmpx2++) * ((mp_word)k) + *tmpx1 + mu;
1878 *tmpx1++ = (mp_digit)(r & MP_MASK);
1879 mu = (mp_digit)(r >> ((mp_word)DIGIT_BIT));
1882 /* set final carry */
1883 *tmpx1++ = mu;
1885 /* zero words above m */
1886 for (i = m + 1; i < x->used; i++) {
1887 *tmpx1++ = 0;
1890 /* clamp, sub and return */
1891 mp_clamp (x);
1893 /* if x >= n then subtract and reduce again
1894 * Each successive "recursion" makes the input smaller and smaller.
1896 if (mp_cmp_mag (x, n) != MP_LT) {
1897 s_mp_sub(x, n, x);
1898 goto top;
1900 return MP_OKAY;
1903 /* sets the value of "d" required for mp_dr_reduce */
1904 static void mp_dr_setup(const mp_int *a, mp_digit *d)
1906 /* the casts are required if DIGIT_BIT is one less than
1907 * the number of bits in a mp_digit [e.g. DIGIT_BIT==31]
1909 *d = (mp_digit)((((mp_word)1) << ((mp_word)DIGIT_BIT)) -
1910 ((mp_word)a->dp[0]));
1913 /* this is a shell function that calls either the normal or Montgomery
1914 * exptmod functions. Originally the call to the montgomery code was
1915 * embedded in the normal function but that wasted a lot of stack space
1916 * for nothing (since 99% of the time the Montgomery code would be called)
1918 int mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
1920 int dr;
1922 /* modulus P must be positive */
1923 if (P->sign == MP_NEG) {
1924 return MP_VAL;
1927 /* if exponent X is negative we have to recurse */
1928 if (X->sign == MP_NEG) {
1929 mp_int tmpG, tmpX;
1930 int err;
1932 /* first compute 1/G mod P */
1933 if ((err = mp_init(&tmpG)) != MP_OKAY) {
1934 return err;
1936 if ((err = mp_invmod(G, P, &tmpG)) != MP_OKAY) {
1937 mp_clear(&tmpG);
1938 return err;
1941 /* now get |X| */
1942 if ((err = mp_init(&tmpX)) != MP_OKAY) {
1943 mp_clear(&tmpG);
1944 return err;
1946 if ((err = mp_abs(X, &tmpX)) != MP_OKAY) {
1947 mp_clear_multi(&tmpG, &tmpX, NULL);
1948 return err;
1951 /* and now compute (1/G)**|X| instead of G**X [X < 0] */
1952 err = mp_exptmod(&tmpG, &tmpX, P, Y);
1953 mp_clear_multi(&tmpG, &tmpX, NULL);
1954 return err;
1957 dr = 0;
1959 /* if the modulus is odd use the fast method */
1960 if (mp_isodd (P) == 1) {
1961 return mp_exptmod_fast (G, X, P, Y, dr);
1962 } else {
1963 /* otherwise use the generic Barrett reduction technique */
1964 return s_mp_exptmod (G, X, P, Y);
1968 /* computes Y == G**X mod P, HAC pp.616, Algorithm 14.85
1970 * Uses a left-to-right k-ary sliding window to compute the modular exponentiation.
1971 * The value of k changes based on the size of the exponent.
1973 * Uses Montgomery or Diminished Radix reduction [whichever appropriate]
1977 mp_exptmod_fast (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y, int redmode)
1979 mp_int M[256], res;
1980 mp_digit buf, mp;
1981 int err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
1983 /* use a pointer to the reduction algorithm. This allows us to use
1984 * one of many reduction algorithms without modding the guts of
1985 * the code with if statements everywhere.
1987 int (*redux)(mp_int*,const mp_int*,mp_digit);
1989 /* find window size */
1990 x = mp_count_bits (X);
1991 if (x <= 7) {
1992 winsize = 2;
1993 } else if (x <= 36) {
1994 winsize = 3;
1995 } else if (x <= 140) {
1996 winsize = 4;
1997 } else if (x <= 450) {
1998 winsize = 5;
1999 } else if (x <= 1303) {
2000 winsize = 6;
2001 } else if (x <= 3529) {
2002 winsize = 7;
2003 } else {
2004 winsize = 8;
2007 /* init M array */
2008 /* init first cell */
2009 if ((err = mp_init(&M[1])) != MP_OKAY) {
2010 return err;
2013 /* now init the second half of the array */
2014 for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
2015 if ((err = mp_init(&M[x])) != MP_OKAY) {
2016 for (y = 1<<(winsize-1); y < x; y++) {
2017 mp_clear (&M[y]);
2019 mp_clear(&M[1]);
2020 return err;
2024 /* determine and setup reduction code */
2025 if (redmode == 0) {
2026 /* now setup montgomery */
2027 if ((err = mp_montgomery_setup (P, &mp)) != MP_OKAY) {
2028 goto __M;
2031 /* automatically pick the comba one if available (saves quite a few calls/ifs) */
2032 if (((P->used * 2 + 1) < MP_WARRAY) &&
2033 P->used < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2034 redux = fast_mp_montgomery_reduce;
2035 } else {
2036 /* use slower baseline Montgomery method */
2037 redux = mp_montgomery_reduce;
2039 } else if (redmode == 1) {
2040 /* setup DR reduction for moduli of the form B**k - b */
2041 mp_dr_setup(P, &mp);
2042 redux = mp_dr_reduce;
2043 } else {
2044 /* setup DR reduction for moduli of the form 2**k - b */
2045 if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
2046 goto __M;
2048 redux = mp_reduce_2k;
2051 /* setup result */
2052 if ((err = mp_init (&res)) != MP_OKAY) {
2053 goto __M;
2056 /* create M table
2060 * The first half of the table is not computed though accept for M[0] and M[1]
2063 if (redmode == 0) {
2064 /* now we need R mod m */
2065 if ((err = mp_montgomery_calc_normalization (&res, P)) != MP_OKAY) {
2066 goto __RES;
2069 /* now set M[1] to G * R mod m */
2070 if ((err = mp_mulmod (G, &res, P, &M[1])) != MP_OKAY) {
2071 goto __RES;
2073 } else {
2074 mp_set(&res, 1);
2075 if ((err = mp_mod(G, P, &M[1])) != MP_OKAY) {
2076 goto __RES;
2080 /* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times */
2081 if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
2082 goto __RES;
2085 for (x = 0; x < (winsize - 1); x++) {
2086 if ((err = mp_sqr (&M[1 << (winsize - 1)], &M[1 << (winsize - 1)])) != MP_OKAY) {
2087 goto __RES;
2089 if ((err = redux (&M[1 << (winsize - 1)], P, mp)) != MP_OKAY) {
2090 goto __RES;
2094 /* create upper table */
2095 for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
2096 if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
2097 goto __RES;
2099 if ((err = redux (&M[x], P, mp)) != MP_OKAY) {
2100 goto __RES;
2104 /* set initial mode and bit cnt */
2105 mode = 0;
2106 bitcnt = 1;
2107 buf = 0;
2108 digidx = X->used - 1;
2109 bitcpy = 0;
2110 bitbuf = 0;
2112 for (;;) {
2113 /* grab next digit as required */
2114 if (--bitcnt == 0) {
2115 /* if digidx == -1 we are out of digits so break */
2116 if (digidx == -1) {
2117 break;
2119 /* read next digit and reset bitcnt */
2120 buf = X->dp[digidx--];
2121 bitcnt = DIGIT_BIT;
2124 /* grab the next msb from the exponent */
2125 y = (buf >> (DIGIT_BIT - 1)) & 1;
2126 buf <<= (mp_digit)1;
2128 /* if the bit is zero and mode == 0 then we ignore it
2129 * These represent the leading zero bits before the first 1 bit
2130 * in the exponent. Technically this opt is not required but it
2131 * does lower the # of trivial squaring/reductions used
2133 if (mode == 0 && y == 0) {
2134 continue;
2137 /* if the bit is zero and mode == 1 then we square */
2138 if (mode == 1 && y == 0) {
2139 if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2140 goto __RES;
2142 if ((err = redux (&res, P, mp)) != MP_OKAY) {
2143 goto __RES;
2145 continue;
2148 /* else we add it to the window */
2149 bitbuf |= (y << (winsize - ++bitcpy));
2150 mode = 2;
2152 if (bitcpy == winsize) {
2153 /* ok window is filled so square as required and multiply */
2154 /* square first */
2155 for (x = 0; x < winsize; x++) {
2156 if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2157 goto __RES;
2159 if ((err = redux (&res, P, mp)) != MP_OKAY) {
2160 goto __RES;
2164 /* then multiply */
2165 if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
2166 goto __RES;
2168 if ((err = redux (&res, P, mp)) != MP_OKAY) {
2169 goto __RES;
2172 /* empty window and reset */
2173 bitcpy = 0;
2174 bitbuf = 0;
2175 mode = 1;
2179 /* if bits remain then square/multiply */
2180 if (mode == 2 && bitcpy > 0) {
2181 /* square then multiply if the bit is set */
2182 for (x = 0; x < bitcpy; x++) {
2183 if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2184 goto __RES;
2186 if ((err = redux (&res, P, mp)) != MP_OKAY) {
2187 goto __RES;
2190 /* get next bit of the window */
2191 bitbuf <<= 1;
2192 if ((bitbuf & (1 << winsize)) != 0) {
2193 /* then multiply */
2194 if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
2195 goto __RES;
2197 if ((err = redux (&res, P, mp)) != MP_OKAY) {
2198 goto __RES;
2204 if (redmode == 0) {
2205 /* fixup result if Montgomery reduction is used
2206 * recall that any value in a Montgomery system is
2207 * actually multiplied by R mod n. So we have
2208 * to reduce one more time to cancel out the factor
2209 * of R.
2211 if ((err = redux(&res, P, mp)) != MP_OKAY) {
2212 goto __RES;
2216 /* swap res with Y */
2217 mp_exch (&res, Y);
2218 err = MP_OKAY;
2219 __RES:mp_clear (&res);
2220 __M:
2221 mp_clear(&M[1]);
2222 for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
2223 mp_clear (&M[x]);
2225 return err;
2228 /* Greatest Common Divisor using the binary method */
2229 int mp_gcd (const mp_int * a, const mp_int * b, mp_int * c)
2231 mp_int u, v;
2232 int k, u_lsb, v_lsb, res;
2234 /* either zero than gcd is the largest */
2235 if (mp_iszero (a) == 1 && mp_iszero (b) == 0) {
2236 return mp_abs (b, c);
2238 if (mp_iszero (a) == 0 && mp_iszero (b) == 1) {
2239 return mp_abs (a, c);
2242 /* optimized. At this point if a == 0 then
2243 * b must equal zero too
2245 if (mp_iszero (a) == 1) {
2246 mp_zero(c);
2247 return MP_OKAY;
2250 /* get copies of a and b we can modify */
2251 if ((res = mp_init_copy (&u, a)) != MP_OKAY) {
2252 return res;
2255 if ((res = mp_init_copy (&v, b)) != MP_OKAY) {
2256 goto __U;
2259 /* must be positive for the remainder of the algorithm */
2260 u.sign = v.sign = MP_ZPOS;
2262 /* B1. Find the common power of two for u and v */
2263 u_lsb = mp_cnt_lsb(&u);
2264 v_lsb = mp_cnt_lsb(&v);
2265 k = MIN(u_lsb, v_lsb);
2267 if (k > 0) {
2268 /* divide the power of two out */
2269 if ((res = mp_div_2d(&u, k, &u, NULL)) != MP_OKAY) {
2270 goto __V;
2273 if ((res = mp_div_2d(&v, k, &v, NULL)) != MP_OKAY) {
2274 goto __V;
2278 /* divide any remaining factors of two out */
2279 if (u_lsb != k) {
2280 if ((res = mp_div_2d(&u, u_lsb - k, &u, NULL)) != MP_OKAY) {
2281 goto __V;
2285 if (v_lsb != k) {
2286 if ((res = mp_div_2d(&v, v_lsb - k, &v, NULL)) != MP_OKAY) {
2287 goto __V;
2291 while (mp_iszero(&v) == 0) {
2292 /* make sure v is the largest */
2293 if (mp_cmp_mag(&u, &v) == MP_GT) {
2294 /* swap u and v to make sure v is >= u */
2295 mp_exch(&u, &v);
2298 /* subtract smallest from largest */
2299 if ((res = s_mp_sub(&v, &u, &v)) != MP_OKAY) {
2300 goto __V;
2303 /* Divide out all factors of two */
2304 if ((res = mp_div_2d(&v, mp_cnt_lsb(&v), &v, NULL)) != MP_OKAY) {
2305 goto __V;
2309 /* multiply by 2**k which we divided out at the beginning */
2310 if ((res = mp_mul_2d (&u, k, c)) != MP_OKAY) {
2311 goto __V;
2313 c->sign = MP_ZPOS;
2314 res = MP_OKAY;
2315 __V:mp_clear (&u);
2316 __U:mp_clear (&v);
2317 return res;
2320 /* get the lower 32-bits of an mp_int */
2321 unsigned long mp_get_int(const mp_int * a)
2323 int i;
2324 unsigned long res;
2326 if (a->used == 0) {
2327 return 0;
2330 /* get number of digits of the lsb we have to read */
2331 i = MIN(a->used,(int)((sizeof(unsigned long)*CHAR_BIT+DIGIT_BIT-1)/DIGIT_BIT))-1;
2333 /* get most significant digit of result */
2334 res = DIGIT(a,i);
2336 while (--i >= 0) {
2337 res = (res << DIGIT_BIT) | DIGIT(a,i);
2340 /* force result to 32-bits always so it is consistent on non 32-bit platforms */
2341 return res & 0xFFFFFFFFUL;
2344 /* creates "a" then copies b into it */
2345 int mp_init_copy (mp_int * a, const mp_int * b)
2347 int res;
2349 if ((res = mp_init (a)) != MP_OKAY) {
2350 return res;
2352 return mp_copy (b, a);
2355 int WINAPIV mp_init_multi(mp_int *mp, ...)
2357 mp_err res = MP_OKAY; /* Assume ok until proven otherwise */
2358 int n = 0; /* Number of ok inits */
2359 mp_int* cur_arg = mp;
2360 va_list args;
2362 va_start(args, mp); /* init args to next argument from caller */
2363 while (cur_arg != NULL) {
2364 if (mp_init(cur_arg) != MP_OKAY) {
2365 /* Oops - error! Back-track and mp_clear what we already
2366 succeeded in init-ing, then return error.
2368 va_list clean_args;
2370 /* now start cleaning up */
2371 cur_arg = mp;
2372 va_start(clean_args, mp);
2373 while (n--) {
2374 mp_clear(cur_arg);
2375 cur_arg = va_arg(clean_args, mp_int*);
2377 va_end(clean_args);
2378 res = MP_MEM;
2379 break;
2381 n++;
2382 cur_arg = va_arg(args, mp_int*);
2384 va_end(args);
2385 return res; /* Assumed ok, if error flagged above. */
2388 /* hac 14.61, pp608 */
2389 int mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
2391 /* b cannot be negative */
2392 if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2393 return MP_VAL;
2396 /* if the modulus is odd we can use a faster routine instead */
2397 if (mp_isodd (b) == 1) {
2398 return fast_mp_invmod (a, b, c);
2401 return mp_invmod_slow(a, b, c);
2404 /* hac 14.61, pp608 */
2405 int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c)
2407 mp_int x, y, u, v, A, B, C, D;
2408 int res;
2410 /* b cannot be negative */
2411 if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2412 return MP_VAL;
2415 /* init temps */
2416 if ((res = mp_init_multi(&x, &y, &u, &v,
2417 &A, &B, &C, &D, NULL)) != MP_OKAY) {
2418 return res;
2421 /* x = a, y = b */
2422 if ((res = mp_copy (a, &x)) != MP_OKAY) {
2423 goto __ERR;
2425 if ((res = mp_copy (b, &y)) != MP_OKAY) {
2426 goto __ERR;
2429 /* 2. [modified] if x,y are both even then return an error! */
2430 if (mp_iseven (&x) == 1 && mp_iseven (&y) == 1) {
2431 res = MP_VAL;
2432 goto __ERR;
2435 /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
2436 if ((res = mp_copy (&x, &u)) != MP_OKAY) {
2437 goto __ERR;
2439 if ((res = mp_copy (&y, &v)) != MP_OKAY) {
2440 goto __ERR;
2442 mp_set (&A, 1);
2443 mp_set (&D, 1);
2445 top:
2446 /* 4. while u is even do */
2447 while (mp_iseven (&u) == 1) {
2448 /* 4.1 u = u/2 */
2449 if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
2450 goto __ERR;
2452 /* 4.2 if A or B is odd then */
2453 if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
2454 /* A = (A+y)/2, B = (B-x)/2 */
2455 if ((res = mp_add (&A, &y, &A)) != MP_OKAY) {
2456 goto __ERR;
2458 if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
2459 goto __ERR;
2462 /* A = A/2, B = B/2 */
2463 if ((res = mp_div_2 (&A, &A)) != MP_OKAY) {
2464 goto __ERR;
2466 if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
2467 goto __ERR;
2471 /* 5. while v is even do */
2472 while (mp_iseven (&v) == 1) {
2473 /* 5.1 v = v/2 */
2474 if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
2475 goto __ERR;
2477 /* 5.2 if C or D is odd then */
2478 if (mp_isodd (&C) == 1 || mp_isodd (&D) == 1) {
2479 /* C = (C+y)/2, D = (D-x)/2 */
2480 if ((res = mp_add (&C, &y, &C)) != MP_OKAY) {
2481 goto __ERR;
2483 if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
2484 goto __ERR;
2487 /* C = C/2, D = D/2 */
2488 if ((res = mp_div_2 (&C, &C)) != MP_OKAY) {
2489 goto __ERR;
2491 if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
2492 goto __ERR;
2496 /* 6. if u >= v then */
2497 if (mp_cmp (&u, &v) != MP_LT) {
2498 /* u = u - v, A = A - C, B = B - D */
2499 if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
2500 goto __ERR;
2503 if ((res = mp_sub (&A, &C, &A)) != MP_OKAY) {
2504 goto __ERR;
2507 if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
2508 goto __ERR;
2510 } else {
2511 /* v - v - u, C = C - A, D = D - B */
2512 if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
2513 goto __ERR;
2516 if ((res = mp_sub (&C, &A, &C)) != MP_OKAY) {
2517 goto __ERR;
2520 if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
2521 goto __ERR;
2525 /* if not zero goto step 4 */
2526 if (mp_iszero (&u) == 0)
2527 goto top;
2529 /* now a = C, b = D, gcd == g*v */
2531 /* if v != 1 then there is no inverse */
2532 if (mp_cmp_d (&v, 1) != MP_EQ) {
2533 res = MP_VAL;
2534 goto __ERR;
2537 /* if it's too low */
2538 while (mp_cmp_d(&C, 0) == MP_LT) {
2539 if ((res = mp_add(&C, b, &C)) != MP_OKAY) {
2540 goto __ERR;
2544 /* too big */
2545 while (mp_cmp_mag(&C, b) != MP_LT) {
2546 if ((res = mp_sub(&C, b, &C)) != MP_OKAY) {
2547 goto __ERR;
2551 /* C is now the inverse */
2552 mp_exch (&C, c);
2553 res = MP_OKAY;
2554 __ERR:mp_clear_multi (&x, &y, &u, &v, &A, &B, &C, &D, NULL);
2555 return res;
2558 /* c = |a| * |b| using Karatsuba Multiplication using
2559 * three half size multiplications
2561 * Let B represent the radix [e.g. 2**DIGIT_BIT] and
2562 * let n represent half of the number of digits in
2563 * the min(a,b)
2565 * a = a1 * B**n + a0
2566 * b = b1 * B**n + b0
2568 * Then, a * b =>
2569 a1b1 * B**2n + ((a1 - a0)(b1 - b0) + a0b0 + a1b1) * B + a0b0
2571 * Note that a1b1 and a0b0 are used twice and only need to be
2572 * computed once. So in total three half size (half # of
2573 * digit) multiplications are performed, a0b0, a1b1 and
2574 * (a1-b1)(a0-b0)
2576 * Note that a multiplication of half the digits requires
2577 * 1/4th the number of single precision multiplications so in
2578 * total after one call 25% of the single precision multiplications
2579 * are saved. Note also that the call to mp_mul can end up back
2580 * in this function if the a0, a1, b0, or b1 are above the threshold.
2581 * This is known as divide-and-conquer and leads to the famous
2582 * O(N**lg(3)) or O(N**1.584) work which is asymptotically lower than
2583 * the standard O(N**2) that the baseline/comba methods use.
2584 * Generally though the overhead of this method doesn't pay off
2585 * until a certain size (N ~ 80) is reached.
2587 int mp_karatsuba_mul (const mp_int * a, const mp_int * b, mp_int * c)
2589 mp_int x0, x1, y0, y1, t1, x0y0, x1y1;
2590 int B, err;
2592 /* default the return code to an error */
2593 err = MP_MEM;
2595 /* min # of digits */
2596 B = MIN (a->used, b->used);
2598 /* now divide in two */
2599 B = B >> 1;
2601 /* init copy all the temps */
2602 if (mp_init_size (&x0, B) != MP_OKAY)
2603 goto ERR;
2604 if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2605 goto X0;
2606 if (mp_init_size (&y0, B) != MP_OKAY)
2607 goto X1;
2608 if (mp_init_size (&y1, b->used - B) != MP_OKAY)
2609 goto Y0;
2611 /* init temps */
2612 if (mp_init_size (&t1, B * 2) != MP_OKAY)
2613 goto Y1;
2614 if (mp_init_size (&x0y0, B * 2) != MP_OKAY)
2615 goto T1;
2616 if (mp_init_size (&x1y1, B * 2) != MP_OKAY)
2617 goto X0Y0;
2619 /* now shift the digits */
2620 x0.used = y0.used = B;
2621 x1.used = a->used - B;
2622 y1.used = b->used - B;
2625 register int x;
2626 register mp_digit *tmpa, *tmpb, *tmpx, *tmpy;
2628 /* we copy the digits directly instead of using higher level functions
2629 * since we also need to shift the digits
2631 tmpa = a->dp;
2632 tmpb = b->dp;
2634 tmpx = x0.dp;
2635 tmpy = y0.dp;
2636 for (x = 0; x < B; x++) {
2637 *tmpx++ = *tmpa++;
2638 *tmpy++ = *tmpb++;
2641 tmpx = x1.dp;
2642 for (x = B; x < a->used; x++) {
2643 *tmpx++ = *tmpa++;
2646 tmpy = y1.dp;
2647 for (x = B; x < b->used; x++) {
2648 *tmpy++ = *tmpb++;
2652 /* only need to clamp the lower words since by definition the
2653 * upper words x1/y1 must have a known number of digits
2655 mp_clamp (&x0);
2656 mp_clamp (&y0);
2658 /* now calc the products x0y0 and x1y1 */
2659 /* after this x0 is no longer required, free temp [x0==t2]! */
2660 if (mp_mul (&x0, &y0, &x0y0) != MP_OKAY)
2661 goto X1Y1; /* x0y0 = x0*y0 */
2662 if (mp_mul (&x1, &y1, &x1y1) != MP_OKAY)
2663 goto X1Y1; /* x1y1 = x1*y1 */
2665 /* now calc x1-x0 and y1-y0 */
2666 if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2667 goto X1Y1; /* t1 = x1 - x0 */
2668 if (mp_sub (&y1, &y0, &x0) != MP_OKAY)
2669 goto X1Y1; /* t2 = y1 - y0 */
2670 if (mp_mul (&t1, &x0, &t1) != MP_OKAY)
2671 goto X1Y1; /* t1 = (x1 - x0) * (y1 - y0) */
2673 /* add x0y0 */
2674 if (mp_add (&x0y0, &x1y1, &x0) != MP_OKAY)
2675 goto X1Y1; /* t2 = x0y0 + x1y1 */
2676 if (mp_sub (&x0, &t1, &t1) != MP_OKAY)
2677 goto X1Y1; /* t1 = x0y0 + x1y1 - (x1-x0)*(y1-y0) */
2679 /* shift by B */
2680 if (mp_lshd (&t1, B) != MP_OKAY)
2681 goto X1Y1; /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
2682 if (mp_lshd (&x1y1, B * 2) != MP_OKAY)
2683 goto X1Y1; /* x1y1 = x1y1 << 2*B */
2685 if (mp_add (&x0y0, &t1, &t1) != MP_OKAY)
2686 goto X1Y1; /* t1 = x0y0 + t1 */
2687 if (mp_add (&t1, &x1y1, c) != MP_OKAY)
2688 goto X1Y1; /* t1 = x0y0 + t1 + x1y1 */
2690 /* Algorithm succeeded set the return code to MP_OKAY */
2691 err = MP_OKAY;
2693 X1Y1:mp_clear (&x1y1);
2694 X0Y0:mp_clear (&x0y0);
2695 T1:mp_clear (&t1);
2696 Y1:mp_clear (&y1);
2697 Y0:mp_clear (&y0);
2698 X1:mp_clear (&x1);
2699 X0:mp_clear (&x0);
2700 ERR:
2701 return err;
2704 /* Karatsuba squaring, computes b = a*a using three
2705 * half size squarings
2707 * See comments of karatsuba_mul for details. It
2708 * is essentially the same algorithm but merely
2709 * tuned to perform recursive squarings.
2711 int mp_karatsuba_sqr (const mp_int * a, mp_int * b)
2713 mp_int x0, x1, t1, t2, x0x0, x1x1;
2714 int B, err;
2716 err = MP_MEM;
2718 /* min # of digits */
2719 B = a->used;
2721 /* now divide in two */
2722 B = B >> 1;
2724 /* init copy all the temps */
2725 if (mp_init_size (&x0, B) != MP_OKAY)
2726 goto ERR;
2727 if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2728 goto X0;
2730 /* init temps */
2731 if (mp_init_size (&t1, a->used * 2) != MP_OKAY)
2732 goto X1;
2733 if (mp_init_size (&t2, a->used * 2) != MP_OKAY)
2734 goto T1;
2735 if (mp_init_size (&x0x0, B * 2) != MP_OKAY)
2736 goto T2;
2737 if (mp_init_size (&x1x1, (a->used - B) * 2) != MP_OKAY)
2738 goto X0X0;
2741 register int x;
2742 register mp_digit *dst, *src;
2744 src = a->dp;
2746 /* now shift the digits */
2747 dst = x0.dp;
2748 for (x = 0; x < B; x++) {
2749 *dst++ = *src++;
2752 dst = x1.dp;
2753 for (x = B; x < a->used; x++) {
2754 *dst++ = *src++;
2758 x0.used = B;
2759 x1.used = a->used - B;
2761 mp_clamp (&x0);
2763 /* now calc the products x0*x0 and x1*x1 */
2764 if (mp_sqr (&x0, &x0x0) != MP_OKAY)
2765 goto X1X1; /* x0x0 = x0*x0 */
2766 if (mp_sqr (&x1, &x1x1) != MP_OKAY)
2767 goto X1X1; /* x1x1 = x1*x1 */
2769 /* now calc (x1-x0)**2 */
2770 if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2771 goto X1X1; /* t1 = x1 - x0 */
2772 if (mp_sqr (&t1, &t1) != MP_OKAY)
2773 goto X1X1; /* t1 = (x1 - x0) * (x1 - x0) */
2775 /* add x0y0 */
2776 if (s_mp_add (&x0x0, &x1x1, &t2) != MP_OKAY)
2777 goto X1X1; /* t2 = x0x0 + x1x1 */
2778 if (mp_sub (&t2, &t1, &t1) != MP_OKAY)
2779 goto X1X1; /* t1 = x0x0 + x1x1 - (x1-x0)*(x1-x0) */
2781 /* shift by B */
2782 if (mp_lshd (&t1, B) != MP_OKAY)
2783 goto X1X1; /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
2784 if (mp_lshd (&x1x1, B * 2) != MP_OKAY)
2785 goto X1X1; /* x1x1 = x1x1 << 2*B */
2787 if (mp_add (&x0x0, &t1, &t1) != MP_OKAY)
2788 goto X1X1; /* t1 = x0x0 + t1 */
2789 if (mp_add (&t1, &x1x1, b) != MP_OKAY)
2790 goto X1X1; /* t1 = x0x0 + t1 + x1x1 */
2792 err = MP_OKAY;
2794 X1X1:mp_clear (&x1x1);
2795 X0X0:mp_clear (&x0x0);
2796 T2:mp_clear (&t2);
2797 T1:mp_clear (&t1);
2798 X1:mp_clear (&x1);
2799 X0:mp_clear (&x0);
2800 ERR:
2801 return err;
2804 /* computes least common multiple as |a*b|/(a, b) */
2805 int mp_lcm (const mp_int * a, const mp_int * b, mp_int * c)
2807 int res;
2808 mp_int t1, t2;
2811 if ((res = mp_init_multi (&t1, &t2, NULL)) != MP_OKAY) {
2812 return res;
2815 /* t1 = get the GCD of the two inputs */
2816 if ((res = mp_gcd (a, b, &t1)) != MP_OKAY) {
2817 goto __T;
2820 /* divide the smallest by the GCD */
2821 if (mp_cmp_mag(a, b) == MP_LT) {
2822 /* store quotient in t2 so that t2 * b is the LCM */
2823 if ((res = mp_div(a, &t1, &t2, NULL)) != MP_OKAY) {
2824 goto __T;
2826 res = mp_mul(b, &t2, c);
2827 } else {
2828 /* store quotient in t2 so that t2 * a is the LCM */
2829 if ((res = mp_div(b, &t1, &t2, NULL)) != MP_OKAY) {
2830 goto __T;
2832 res = mp_mul(a, &t2, c);
2835 /* fix the sign to positive */
2836 c->sign = MP_ZPOS;
2838 __T:
2839 mp_clear_multi (&t1, &t2, NULL);
2840 return res;
2843 /* c = a mod b, 0 <= c < b */
2845 mp_mod (const mp_int * a, mp_int * b, mp_int * c)
2847 mp_int t;
2848 int res;
2850 if ((res = mp_init (&t)) != MP_OKAY) {
2851 return res;
2854 if ((res = mp_div (a, b, NULL, &t)) != MP_OKAY) {
2855 mp_clear (&t);
2856 return res;
2859 if (t.sign != b->sign) {
2860 res = mp_add (b, &t, c);
2861 } else {
2862 res = MP_OKAY;
2863 mp_exch (&t, c);
2866 mp_clear (&t);
2867 return res;
2870 static int
2871 mp_mod_d (const mp_int * a, mp_digit b, mp_digit * c)
2873 return mp_div_d(a, b, NULL, c);
2876 /* b = a*2 */
2877 static int mp_mul_2(const mp_int * a, mp_int * b)
2879 int x, res, oldused;
2881 /* grow to accommodate result */
2882 if (b->alloc < a->used + 1) {
2883 if ((res = mp_grow (b, a->used + 1)) != MP_OKAY) {
2884 return res;
2888 oldused = b->used;
2889 b->used = a->used;
2892 register mp_digit r, rr, *tmpa, *tmpb;
2894 /* alias for source */
2895 tmpa = a->dp;
2897 /* alias for dest */
2898 tmpb = b->dp;
2900 /* carry */
2901 r = 0;
2902 for (x = 0; x < a->used; x++) {
2904 /* get what will be the *next* carry bit from the
2905 * MSB of the current digit
2907 rr = *tmpa >> ((mp_digit)(DIGIT_BIT - 1));
2909 /* now shift up this digit, add in the carry [from the previous] */
2910 *tmpb++ = ((*tmpa++ << ((mp_digit)1)) | r) & MP_MASK;
2912 /* copy the carry that would be from the source
2913 * digit into the next iteration
2915 r = rr;
2918 /* new leading digit? */
2919 if (r != 0) {
2920 /* add a MSB which is always 1 at this point */
2921 *tmpb = 1;
2922 ++(b->used);
2925 /* now zero any excess digits on the destination
2926 * that we didn't write to
2928 tmpb = b->dp + b->used;
2929 for (x = b->used; x < oldused; x++) {
2930 *tmpb++ = 0;
2933 b->sign = a->sign;
2934 return MP_OKAY;
2938 * shifts with subtractions when the result is greater than b.
2940 * The method is slightly modified to shift B unconditionally up to just under
2941 * the leading bit of b. This saves a lot of multiple precision shifting.
2943 int mp_montgomery_calc_normalization (mp_int * a, const mp_int * b)
2945 int x, bits, res;
2947 /* how many bits of last digit does b use */
2948 bits = mp_count_bits (b) % DIGIT_BIT;
2951 if (b->used > 1) {
2952 if ((res = mp_2expt (a, (b->used - 1) * DIGIT_BIT + bits - 1)) != MP_OKAY) {
2953 return res;
2955 } else {
2956 mp_set(a, 1);
2957 bits = 1;
2961 /* now compute C = A * B mod b */
2962 for (x = bits - 1; x < DIGIT_BIT; x++) {
2963 if ((res = mp_mul_2 (a, a)) != MP_OKAY) {
2964 return res;
2966 if (mp_cmp_mag (a, b) != MP_LT) {
2967 if ((res = s_mp_sub (a, b, a)) != MP_OKAY) {
2968 return res;
2973 return MP_OKAY;
2976 /* computes xR**-1 == x (mod N) via Montgomery Reduction */
2978 mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
2980 int ix, res, digs;
2981 mp_digit mu;
2983 /* can the fast reduction [comba] method be used?
2985 * Note that unlike in mul you're safely allowed *less*
2986 * than the available columns [255 per default] since carries
2987 * are fixed up in the inner loop.
2989 digs = n->used * 2 + 1;
2990 if ((digs < MP_WARRAY) &&
2991 n->used <
2992 (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2993 return fast_mp_montgomery_reduce (x, n, rho);
2996 /* grow the input as required */
2997 if (x->alloc < digs) {
2998 if ((res = mp_grow (x, digs)) != MP_OKAY) {
2999 return res;
3002 x->used = digs;
3004 for (ix = 0; ix < n->used; ix++) {
3005 /* mu = ai * rho mod b
3007 * The value of rho must be precalculated via
3008 * montgomery_setup() such that
3009 * it equals -1/n0 mod b this allows the
3010 * following inner loop to reduce the
3011 * input one digit at a time
3013 mu = (mp_digit) (((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK);
3015 /* a = a + mu * m * b**i */
3017 register int iy;
3018 register mp_digit *tmpn, *tmpx, u;
3019 register mp_word r;
3021 /* alias for digits of the modulus */
3022 tmpn = n->dp;
3024 /* alias for the digits of x [the input] */
3025 tmpx = x->dp + ix;
3027 /* set the carry to zero */
3028 u = 0;
3030 /* Multiply and add in place */
3031 for (iy = 0; iy < n->used; iy++) {
3032 /* compute product and sum */
3033 r = ((mp_word)mu) * ((mp_word)*tmpn++) +
3034 ((mp_word) u) + ((mp_word) * tmpx);
3036 /* get carry */
3037 u = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
3039 /* fix digit */
3040 *tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK));
3042 /* At this point the ix'th digit of x should be zero */
3045 /* propagate carries upwards as required*/
3046 while (u) {
3047 *tmpx += u;
3048 u = *tmpx >> DIGIT_BIT;
3049 *tmpx++ &= MP_MASK;
3054 /* at this point the n.used'th least
3055 * significant digits of x are all zero
3056 * which means we can shift x to the
3057 * right by n.used digits and the
3058 * residue is unchanged.
3061 /* x = x/b**n.used */
3062 mp_clamp(x);
3063 mp_rshd (x, n->used);
3065 /* if x >= n then x = x - n */
3066 if (mp_cmp_mag (x, n) != MP_LT) {
3067 return s_mp_sub (x, n, x);
3070 return MP_OKAY;
3073 /* setups the montgomery reduction stuff */
3075 mp_montgomery_setup (const mp_int * n, mp_digit * rho)
3077 mp_digit x, b;
3079 /* fast inversion mod 2**k
3081 * Based on the fact that
3083 * XA = 1 (mod 2**n) => (X(2-XA)) A = 1 (mod 2**2n)
3084 * => 2*X*A - X*X*A*A = 1
3085 * => 2*(1) - (1) = 1
3087 b = n->dp[0];
3089 if ((b & 1) == 0) {
3090 return MP_VAL;
3093 x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
3094 x *= 2 - b * x; /* here x*a==1 mod 2**8 */
3095 x *= 2 - b * x; /* here x*a==1 mod 2**16 */
3096 x *= 2 - b * x; /* here x*a==1 mod 2**32 */
3098 /* rho = -1/m mod b */
3099 *rho = (((mp_word)1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK;
3101 return MP_OKAY;
3104 /* high level multiplication (handles sign) */
3105 int mp_mul (const mp_int * a, const mp_int * b, mp_int * c)
3107 int res, neg;
3108 neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
3110 /* use Karatsuba? */
3111 if (MIN (a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
3112 res = mp_karatsuba_mul (a, b, c);
3113 } else
3115 /* can we use the fast multiplier?
3117 * The fast multiplier can be used if the output will
3118 * have less than MP_WARRAY digits and the number of
3119 * digits won't affect carry propagation
3121 int digs = a->used + b->used + 1;
3123 if ((digs < MP_WARRAY) &&
3124 MIN(a->used, b->used) <=
3125 (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
3126 res = fast_s_mp_mul_digs (a, b, c, digs);
3127 } else
3128 res = s_mp_mul (a, b, c); /* uses s_mp_mul_digs */
3130 c->sign = (c->used > 0) ? neg : MP_ZPOS;
3131 return res;
3134 /* d = a * b (mod c) */
3136 mp_mulmod (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
3138 int res;
3139 mp_int t;
3141 if ((res = mp_init (&t)) != MP_OKAY) {
3142 return res;
3145 if ((res = mp_mul (a, b, &t)) != MP_OKAY) {
3146 mp_clear (&t);
3147 return res;
3149 res = mp_mod (&t, c, d);
3150 mp_clear (&t);
3151 return res;
3154 /* table of first PRIME_SIZE primes */
3155 static const mp_digit __prime_tab[] = {
3156 0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
3157 0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
3158 0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
3159 0x0061, 0x0065, 0x0067, 0x006B, 0x006D, 0x0071, 0x007F, 0x0083,
3160 0x0089, 0x008B, 0x0095, 0x0097, 0x009D, 0x00A3, 0x00A7, 0x00AD,
3161 0x00B3, 0x00B5, 0x00BF, 0x00C1, 0x00C5, 0x00C7, 0x00D3, 0x00DF,
3162 0x00E3, 0x00E5, 0x00E9, 0x00EF, 0x00F1, 0x00FB, 0x0101, 0x0107,
3163 0x010D, 0x010F, 0x0115, 0x0119, 0x011B, 0x0125, 0x0133, 0x0137,
3165 0x0139, 0x013D, 0x014B, 0x0151, 0x015B, 0x015D, 0x0161, 0x0167,
3166 0x016F, 0x0175, 0x017B, 0x017F, 0x0185, 0x018D, 0x0191, 0x0199,
3167 0x01A3, 0x01A5, 0x01AF, 0x01B1, 0x01B7, 0x01BB, 0x01C1, 0x01C9,
3168 0x01CD, 0x01CF, 0x01D3, 0x01DF, 0x01E7, 0x01EB, 0x01F3, 0x01F7,
3169 0x01FD, 0x0209, 0x020B, 0x021D, 0x0223, 0x022D, 0x0233, 0x0239,
3170 0x023B, 0x0241, 0x024B, 0x0251, 0x0257, 0x0259, 0x025F, 0x0265,
3171 0x0269, 0x026B, 0x0277, 0x0281, 0x0283, 0x0287, 0x028D, 0x0293,
3172 0x0295, 0x02A1, 0x02A5, 0x02AB, 0x02B3, 0x02BD, 0x02C5, 0x02CF,
3174 0x02D7, 0x02DD, 0x02E3, 0x02E7, 0x02EF, 0x02F5, 0x02F9, 0x0301,
3175 0x0305, 0x0313, 0x031D, 0x0329, 0x032B, 0x0335, 0x0337, 0x033B,
3176 0x033D, 0x0347, 0x0355, 0x0359, 0x035B, 0x035F, 0x036D, 0x0371,
3177 0x0373, 0x0377, 0x038B, 0x038F, 0x0397, 0x03A1, 0x03A9, 0x03AD,
3178 0x03B3, 0x03B9, 0x03C7, 0x03CB, 0x03D1, 0x03D7, 0x03DF, 0x03E5,
3179 0x03F1, 0x03F5, 0x03FB, 0x03FD, 0x0407, 0x0409, 0x040F, 0x0419,
3180 0x041B, 0x0425, 0x0427, 0x042D, 0x043F, 0x0443, 0x0445, 0x0449,
3181 0x044F, 0x0455, 0x045D, 0x0463, 0x0469, 0x047F, 0x0481, 0x048B,
3183 0x0493, 0x049D, 0x04A3, 0x04A9, 0x04B1, 0x04BD, 0x04C1, 0x04C7,
3184 0x04CD, 0x04CF, 0x04D5, 0x04E1, 0x04EB, 0x04FD, 0x04FF, 0x0503,
3185 0x0509, 0x050B, 0x0511, 0x0515, 0x0517, 0x051B, 0x0527, 0x0529,
3186 0x052F, 0x0551, 0x0557, 0x055D, 0x0565, 0x0577, 0x0581, 0x058F,
3187 0x0593, 0x0595, 0x0599, 0x059F, 0x05A7, 0x05AB, 0x05AD, 0x05B3,
3188 0x05BF, 0x05C9, 0x05CB, 0x05CF, 0x05D1, 0x05D5, 0x05DB, 0x05E7,
3189 0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
3190 0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653
3193 /* determines if an integers is divisible by one
3194 * of the first PRIME_SIZE primes or not
3196 * sets result to 0 if not, 1 if yes
3198 static int mp_prime_is_divisible (const mp_int * a, int *result)
3200 int err, ix;
3201 mp_digit res;
3203 /* default to not */
3204 *result = MP_NO;
3206 for (ix = 0; ix < PRIME_SIZE; ix++) {
3207 /* what is a mod __prime_tab[ix] */
3208 if ((err = mp_mod_d (a, __prime_tab[ix], &res)) != MP_OKAY) {
3209 return err;
3212 /* is the residue zero? */
3213 if (res == 0) {
3214 *result = MP_YES;
3215 return MP_OKAY;
3219 return MP_OKAY;
3222 /* Miller-Rabin test of "a" to the base of "b" as described in
3223 * HAC pp. 139 Algorithm 4.24
3225 * Sets result to 0 if definitely composite or 1 if probably prime.
3226 * Randomly the chance of error is no more than 1/4 and often
3227 * very much lower.
3229 static int mp_prime_miller_rabin (mp_int * a, const mp_int * b, int *result)
3231 mp_int n1, y, r;
3232 int s, j, err;
3234 /* default */
3235 *result = MP_NO;
3237 /* ensure b > 1 */
3238 if (mp_cmp_d(b, 1) != MP_GT) {
3239 return MP_VAL;
3242 /* get n1 = a - 1 */
3243 if ((err = mp_init_copy (&n1, a)) != MP_OKAY) {
3244 return err;
3246 if ((err = mp_sub_d (&n1, 1, &n1)) != MP_OKAY) {
3247 goto __N1;
3250 /* set 2**s * r = n1 */
3251 if ((err = mp_init_copy (&r, &n1)) != MP_OKAY) {
3252 goto __N1;
3255 /* count the number of least significant bits
3256 * which are zero
3258 s = mp_cnt_lsb(&r);
3260 /* now divide n - 1 by 2**s */
3261 if ((err = mp_div_2d (&r, s, &r, NULL)) != MP_OKAY) {
3262 goto __R;
3265 /* compute y = b**r mod a */
3266 if ((err = mp_init (&y)) != MP_OKAY) {
3267 goto __R;
3269 if ((err = mp_exptmod (b, &r, a, &y)) != MP_OKAY) {
3270 goto __Y;
3273 /* if y != 1 and y != n1 do */
3274 if (mp_cmp_d (&y, 1) != MP_EQ && mp_cmp (&y, &n1) != MP_EQ) {
3275 j = 1;
3276 /* while j <= s-1 and y != n1 */
3277 while ((j <= (s - 1)) && mp_cmp (&y, &n1) != MP_EQ) {
3278 if ((err = mp_sqrmod (&y, a, &y)) != MP_OKAY) {
3279 goto __Y;
3282 /* if y == 1 then composite */
3283 if (mp_cmp_d (&y, 1) == MP_EQ) {
3284 goto __Y;
3287 ++j;
3290 /* if y != n1 then composite */
3291 if (mp_cmp (&y, &n1) != MP_EQ) {
3292 goto __Y;
3296 /* probably prime now */
3297 *result = MP_YES;
3298 __Y:mp_clear (&y);
3299 __R:mp_clear (&r);
3300 __N1:mp_clear (&n1);
3301 return err;
3304 /* performs a variable number of rounds of Miller-Rabin
3306 * Probability of error after t rounds is no more than
3309 * Sets result to 1 if probably prime, 0 otherwise
3311 static int mp_prime_is_prime (mp_int * a, int t, int *result)
3313 mp_int b;
3314 int ix, err, res;
3316 /* default to no */
3317 *result = MP_NO;
3319 /* valid value of t? */
3320 if (t <= 0 || t > PRIME_SIZE) {
3321 return MP_VAL;
3324 /* is the input equal to one of the primes in the table? */
3325 for (ix = 0; ix < PRIME_SIZE; ix++) {
3326 if (mp_cmp_d(a, __prime_tab[ix]) == MP_EQ) {
3327 *result = 1;
3328 return MP_OKAY;
3332 /* first perform trial division */
3333 if ((err = mp_prime_is_divisible (a, &res)) != MP_OKAY) {
3334 return err;
3337 /* return if it was trivially divisible */
3338 if (res == MP_YES) {
3339 return MP_OKAY;
3342 /* now perform the miller-rabin rounds */
3343 if ((err = mp_init (&b)) != MP_OKAY) {
3344 return err;
3347 for (ix = 0; ix < t; ix++) {
3348 /* set the prime */
3349 mp_set (&b, __prime_tab[ix]);
3351 if ((err = mp_prime_miller_rabin (a, &b, &res)) != MP_OKAY) {
3352 goto __B;
3355 if (res == MP_NO) {
3356 goto __B;
3360 /* passed the test */
3361 *result = MP_YES;
3362 __B:mp_clear (&b);
3363 return err;
/* Number of Rabin-Miller trials (t) for a prime of k bits,
 * sorted by ascending k.
 */
static const struct {
    int k, t;
} sizes[] = {
    {   128,    28 },
    {   256,    16 },
    {   384,    10 },
    {   512,     7 },
    {   640,     6 },
    {   768,     5 },
    {   896,     4 },
    {  1024,     4 }
};

/* returns # of RM trials required for a given bit size
 *
 * An exact match returns that entry's count.  A size that falls between
 * two entries returns the count of the smaller entry, a size below the
 * first entry returns the first entry's count, and a size above the last
 * entry returns the last entry's count plus one.
 */
int mp_prime_rabin_miller_trials(int size)
{
    const unsigned int count = sizeof(sizes) / sizeof(sizes[0]);
    unsigned int i;

    for (i = 0; i < count; i++) {
        if (sizes[i].k == size) {
            return sizes[i].t;
        }
        if (sizes[i].k > size) {
            /* first entry larger than size: use the one before it */
            return (i == 0) ? sizes[0].t : sizes[i - 1].t;
        }
    }

    /* larger than every table entry */
    return sizes[count - 1].t + 1;
}
3394 /* makes a truly random prime of a given size (bits),
3396 * Flags are as follows:
3398 * LTM_PRIME_BBS - make prime congruent to 3 mod 4
3399 * LTM_PRIME_SAFE - make sure (p-1)/2 is prime as well (implies LTM_PRIME_BBS)
3400 * LTM_PRIME_2MSB_OFF - make the 2nd highest bit zero
3401 * LTM_PRIME_2MSB_ON - make the 2nd highest bit one
3403 * You have to supply a callback which fills in a buffer with random bytes. "dat" is a parameter you can
3404 * have passed to the callback (e.g. a state or something). This function doesn't use "dat" itself
3405 * so it can be NULL
3409 /* This is possibly the mother of all prime generation functions, muahahahahaha! */
3410 int mp_prime_random_ex(mp_int *a, int t, int size, int flags, ltm_prime_callback cb, void *dat)
3412 unsigned char *tmp, maskAND, maskOR_msb, maskOR_lsb;
3413 int res, err, bsize, maskOR_msb_offset;
3415 /* sanity check the input */
3416 if (size <= 1 || t <= 0) {
3417 return MP_VAL;
3420 /* LTM_PRIME_SAFE implies LTM_PRIME_BBS */
3421 if (flags & LTM_PRIME_SAFE) {
3422 flags |= LTM_PRIME_BBS;
3425 /* calc the byte size */
3426 bsize = (size>>3)+((size&7)?1:0);
3428 /* we need a buffer of bsize bytes */
3429 tmp = malloc(bsize);
3430 if (tmp == NULL) {
3431 return MP_MEM;
3434 /* calc the maskAND value for the MSbyte*/
3435 maskAND = ((size&7) == 0) ? 0xFF : (0xFF >> (8 - (size & 7)));
3437 /* calc the maskOR_msb */
3438 maskOR_msb = 0;
3439 maskOR_msb_offset = ((size & 7) == 1) ? 1 : 0;
3440 if (flags & LTM_PRIME_2MSB_ON) {
3441 maskOR_msb |= 1 << ((size - 2) & 7);
3442 } else if (flags & LTM_PRIME_2MSB_OFF) {
3443 maskAND &= ~(1 << ((size - 2) & 7));
3446 /* get the maskOR_lsb */
3447 maskOR_lsb = 0;
3448 if (flags & LTM_PRIME_BBS) {
3449 maskOR_lsb |= 3;
3452 do {
3453 /* read the bytes */
3454 if (cb(tmp, bsize, dat) != bsize) {
3455 err = MP_VAL;
3456 goto error;
3459 /* work over the MSbyte */
3460 tmp[0] &= maskAND;
3461 tmp[0] |= 1 << ((size - 1) & 7);
3463 /* mix in the maskORs */
3464 tmp[maskOR_msb_offset] |= maskOR_msb;
3465 tmp[bsize-1] |= maskOR_lsb;
3467 /* read it in */
3468 if ((err = mp_read_unsigned_bin(a, tmp, bsize)) != MP_OKAY) { goto error; }
3470 /* is it prime? */
3471 if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY) { goto error; }
3472 if (res == MP_NO) {
3473 continue;
3476 if (flags & LTM_PRIME_SAFE) {
3477 /* see if (a-1)/2 is prime */
3478 if ((err = mp_sub_d(a, 1, a)) != MP_OKAY) { goto error; }
3479 if ((err = mp_div_2(a, a)) != MP_OKAY) { goto error; }
3481 /* is it prime? */
3482 if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY) { goto error; }
3484 } while (res == MP_NO);
3486 if (flags & LTM_PRIME_SAFE) {
3487 /* restore a to the original value */
3488 if ((err = mp_mul_2(a, a)) != MP_OKAY) { goto error; }
3489 if ((err = mp_add_d(a, 1, a)) != MP_OKAY) { goto error; }
3492 err = MP_OKAY;
3493 error:
3494 free(tmp);
3495 return err;
3498 /* reads an unsigned char array, assumes the msb is stored first [big endian] */
3500 mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c)
3502 int res;
3504 /* make sure there are at least two digits */
3505 if (a->alloc < 2) {
3506 if ((res = mp_grow(a, 2)) != MP_OKAY) {
3507 return res;
3511 /* zero the int */
3512 mp_zero (a);
3514 /* read the bytes in */
3515 while (c-- > 0) {
3516 if ((res = mp_mul_2d (a, 8, a)) != MP_OKAY) {
3517 return res;
3520 a->dp[0] |= *b++;
3521 a->used += 1;
3523 mp_clamp (a);
3524 return MP_OKAY;
3527 /* reduces x mod m, assumes 0 < x < m**2, mu is
3528 * precomputed via mp_reduce_setup.
3529 * From HAC pp.604 Algorithm 14.42
3532 mp_reduce (mp_int * x, const mp_int * m, const mp_int * mu)
3534 mp_int q;
3535 int res, um = m->used;
3537 /* q = x */
3538 if ((res = mp_init_copy (&q, x)) != MP_OKAY) {
3539 return res;
3542 /* q1 = x / b**(k-1) */
3543 mp_rshd (&q, um - 1);
3545 /* according to HAC this optimization is ok */
3546 if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
3547 if ((res = mp_mul (&q, mu, &q)) != MP_OKAY) {
3548 goto CLEANUP;
3550 } else {
3551 if ((res = s_mp_mul_high_digs (&q, mu, &q, um - 1)) != MP_OKAY) {
3552 goto CLEANUP;
3556 /* q3 = q2 / b**(k+1) */
3557 mp_rshd (&q, um + 1);
3559 /* x = x mod b**(k+1), quick (no division) */
3560 if ((res = mp_mod_2d (x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
3561 goto CLEANUP;
3564 /* q = q * m mod b**(k+1), quick (no division) */
3565 if ((res = s_mp_mul_digs (&q, m, &q, um + 1)) != MP_OKAY) {
3566 goto CLEANUP;
3569 /* x = x - q */
3570 if ((res = mp_sub (x, &q, x)) != MP_OKAY) {
3571 goto CLEANUP;
3574 /* If x < 0, add b**(k+1) to it */
3575 if (mp_cmp_d (x, 0) == MP_LT) {
3576 mp_set (&q, 1);
3577 if ((res = mp_lshd (&q, um + 1)) != MP_OKAY)
3578 goto CLEANUP;
3579 if ((res = mp_add (x, &q, x)) != MP_OKAY)
3580 goto CLEANUP;
3583 /* Back off if it's too big */
3584 while (mp_cmp (x, m) != MP_LT) {
3585 if ((res = s_mp_sub (x, m, x)) != MP_OKAY) {
3586 goto CLEANUP;
3590 CLEANUP:
3591 mp_clear (&q);
3593 return res;
3596 /* reduces a modulo n where n is of the form 2**p - d */
3598 mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d)
3600 mp_int q;
3601 int p, res;
3603 if ((res = mp_init(&q)) != MP_OKAY) {
3604 return res;
3607 p = mp_count_bits(n);
3608 top:
3609 /* q = a/2**p, a = a mod 2**p */
3610 if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
3611 goto ERR;
3614 if (d != 1) {
3615 /* q = q * d */
3616 if ((res = mp_mul_d(&q, d, &q)) != MP_OKAY) {
3617 goto ERR;
3621 /* a = a + q */
3622 if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
3623 goto ERR;
3626 if (mp_cmp_mag(a, n) != MP_LT) {
3627 s_mp_sub(a, n, a);
3628 goto top;
3631 ERR:
3632 mp_clear(&q);
3633 return res;
3636 /* determines the setup value */
3637 static int
3638 mp_reduce_2k_setup(const mp_int *a, mp_digit *d)
3640 int res, p;
3641 mp_int tmp;
3643 if ((res = mp_init(&tmp)) != MP_OKAY) {
3644 return res;
3647 p = mp_count_bits(a);
3648 if ((res = mp_2expt(&tmp, p)) != MP_OKAY) {
3649 mp_clear(&tmp);
3650 return res;
3653 if ((res = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) {
3654 mp_clear(&tmp);
3655 return res;
3658 *d = tmp.dp[0];
3659 mp_clear(&tmp);
3660 return MP_OKAY;
3663 /* pre-calculate the value required for Barrett reduction
3664 * For a given modulus "b" it calculates the value required in "a"
3666 int mp_reduce_setup (mp_int * a, const mp_int * b)
3668 int res;
3670 if ((res = mp_2expt (a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
3671 return res;
3673 return mp_div (a, b, a, NULL);
3676 /* set to a digit */
3677 void mp_set (mp_int * a, mp_digit b)
3679 mp_zero (a);
3680 a->dp[0] = b & MP_MASK;
3681 a->used = (a->dp[0] != 0) ? 1 : 0;
3684 /* set a 32-bit const */
3685 int mp_set_int (mp_int * a, unsigned long b)
3687 int x, res;
3689 mp_zero (a);
3691 /* set four bits at a time */
3692 for (x = 0; x < 8; x++) {
3693 /* shift the number up four bits */
3694 if ((res = mp_mul_2d (a, 4, a)) != MP_OKAY) {
3695 return res;
3698 /* OR in the top four bits of the source */
3699 a->dp[0] |= (b >> 28) & 15;
3701 /* shift the source up to the next four bits */
3702 b <<= 4;
3704 /* ensure that digits are not clamped off */
3705 a->used += 1;
3707 mp_clamp (a);
3708 return MP_OKAY;
3711 /* shrink a bignum */
3712 int mp_shrink (mp_int * a)
3714 mp_digit *tmp;
3715 if (a->alloc != a->used && a->used > 0) {
3716 if ((tmp = realloc(a->dp, sizeof (mp_digit) * a->used)) == NULL) {
3717 return MP_MEM;
3719 a->dp = tmp;
3720 a->alloc = a->used;
3722 return MP_OKAY;
3725 /* computes b = a*a */
3727 mp_sqr (const mp_int * a, mp_int * b)
3729 int res;
3731 if (a->used >= KARATSUBA_SQR_CUTOFF) {
3732 res = mp_karatsuba_sqr (a, b);
3733 } else
3735 /* can we use the fast comba multiplier? */
3736 if ((a->used * 2 + 1) < MP_WARRAY &&
3737 a->used <
3738 (1 << (sizeof(mp_word) * CHAR_BIT - 2*DIGIT_BIT - 1))) {
3739 res = fast_s_mp_sqr (a, b);
3740 } else
3741 res = s_mp_sqr (a, b);
3743 b->sign = MP_ZPOS;
3744 return res;
3747 /* c = a * a (mod b) */
3749 mp_sqrmod (const mp_int * a, mp_int * b, mp_int * c)
3751 int res;
3752 mp_int t;
3754 if ((res = mp_init (&t)) != MP_OKAY) {
3755 return res;
3758 if ((res = mp_sqr (a, &t)) != MP_OKAY) {
3759 mp_clear (&t);
3760 return res;
3762 res = mp_mod (&t, b, c);
3763 mp_clear (&t);
3764 return res;
3767 /* high level subtraction (handles signs) */
3769 mp_sub (mp_int * a, mp_int * b, mp_int * c)
3771 int sa, sb, res;
3773 sa = a->sign;
3774 sb = b->sign;
3776 if (sa != sb) {
3777 /* subtract a negative from a positive, OR */
3778 /* subtract a positive from a negative. */
3779 /* In either case, ADD their magnitudes, */
3780 /* and use the sign of the first number. */
3781 c->sign = sa;
3782 res = s_mp_add (a, b, c);
3783 } else {
3784 /* subtract a positive from a positive, OR */
3785 /* subtract a negative from a negative. */
3786 /* First, take the difference between their */
3787 /* magnitudes, then... */
3788 if (mp_cmp_mag (a, b) != MP_LT) {
3789 /* Copy the sign from the first */
3790 c->sign = sa;
3791 /* The first has a larger or equal magnitude */
3792 res = s_mp_sub (a, b, c);
3793 } else {
3794 /* The result has the *opposite* sign from */
3795 /* the first number. */
3796 c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
3797 /* The second has a larger magnitude */
3798 res = s_mp_sub (b, a, c);
3801 return res;
3804 /* single digit subtraction */
3806 mp_sub_d (mp_int * a, mp_digit b, mp_int * c)
3808 mp_digit *tmpa, *tmpc, mu;
3809 int res, ix, oldused;
3811 /* grow c as required */
3812 if (c->alloc < a->used + 1) {
3813 if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
3814 return res;
3818 /* if a is negative just do an unsigned
3819 * addition [with fudged signs]
3821 if (a->sign == MP_NEG) {
3822 a->sign = MP_ZPOS;
3823 res = mp_add_d(a, b, c);
3824 a->sign = c->sign = MP_NEG;
3825 return res;
3828 /* setup regs */
3829 oldused = c->used;
3830 tmpa = a->dp;
3831 tmpc = c->dp;
3833 /* if a <= b simply fix the single digit */
3834 if ((a->used == 1 && a->dp[0] <= b) || a->used == 0) {
3835 if (a->used == 1) {
3836 *tmpc++ = b - *tmpa;
3837 } else {
3838 *tmpc++ = b;
3840 ix = 1;
3842 /* negative/1digit */
3843 c->sign = MP_NEG;
3844 c->used = 1;
3845 } else {
3846 /* positive/size */
3847 c->sign = MP_ZPOS;
3848 c->used = a->used;
3850 /* subtract first digit */
3851 *tmpc = *tmpa++ - b;
3852 mu = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3853 *tmpc++ &= MP_MASK;
3855 /* handle rest of the digits */
3856 for (ix = 1; ix < a->used; ix++) {
3857 *tmpc = *tmpa++ - mu;
3858 mu = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3859 *tmpc++ &= MP_MASK;
3863 /* zero excess digits */
3864 while (ix++ < oldused) {
3865 *tmpc++ = 0;
3867 mp_clamp(c);
3868 return MP_OKAY;
3871 /* store in unsigned [big endian] format */
3873 mp_to_unsigned_bin (const mp_int * a, unsigned char *b)
3875 int x, res;
3876 mp_int t;
3878 if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
3879 return res;
3882 x = 0;
3883 while (mp_iszero (&t) == 0) {
3884 b[x++] = (unsigned char) (t.dp[0] & 255);
3885 if ((res = mp_div_2d (&t, 8, &t, NULL)) != MP_OKAY) {
3886 mp_clear (&t);
3887 return res;
3890 bn_reverse (b, x);
3891 mp_clear (&t);
3892 return MP_OKAY;
3895 /* get the size for an unsigned equivalent */
3897 mp_unsigned_bin_size (const mp_int * a)
3899 int size = mp_count_bits (a);
3900 return (size / 8 + ((size & 7) != 0 ? 1 : 0));
/* reverse an array, used for radix code
 *
 * Reverses the first len bytes of s in place.  Lengths of 0 and 1 leave
 * the buffer untouched.
 */
static void
bn_reverse (unsigned char *s, int len)
{
    int lo, hi;
    unsigned char tmp;

    /* swap from both ends towards the middle */
    for (lo = 0, hi = len - 1; lo < hi; ++lo, --hi) {
        tmp = s[lo];
        s[lo] = s[hi];
        s[hi] = tmp;
    }
}
3921 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
3922 static int
3923 s_mp_add (mp_int * a, mp_int * b, mp_int * c)
3925 mp_int *x;
3926 int olduse, res, min, max;
3928 /* find sizes, we let |a| <= |b| which means we have to sort
3929 * them. "x" will point to the input with the most digits
3931 if (a->used > b->used) {
3932 min = b->used;
3933 max = a->used;
3934 x = a;
3935 } else {
3936 min = a->used;
3937 max = b->used;
3938 x = b;
3941 /* init result */
3942 if (c->alloc < max + 1) {
3943 if ((res = mp_grow (c, max + 1)) != MP_OKAY) {
3944 return res;
3948 /* get old used digit count and set new one */
3949 olduse = c->used;
3950 c->used = max + 1;
3953 register mp_digit u, *tmpa, *tmpb, *tmpc;
3954 register int i;
3956 /* alias for digit pointers */
3958 /* first input */
3959 tmpa = a->dp;
3961 /* second input */
3962 tmpb = b->dp;
3964 /* destination */
3965 tmpc = c->dp;
3967 /* zero the carry */
3968 u = 0;
3969 for (i = 0; i < min; i++) {
3970 /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
3971 *tmpc = *tmpa++ + *tmpb++ + u;
3973 /* U = carry bit of T[i] */
3974 u = *tmpc >> ((mp_digit)DIGIT_BIT);
3976 /* take away carry bit from T[i] */
3977 *tmpc++ &= MP_MASK;
3980 /* now copy higher words if any, that is in A+B
3981 * if A or B has more digits add those in
3983 if (min != max) {
3984 for (; i < max; i++) {
3985 /* T[i] = X[i] + U */
3986 *tmpc = x->dp[i] + u;
3988 /* U = carry bit of T[i] */
3989 u = *tmpc >> ((mp_digit)DIGIT_BIT);
3991 /* take away carry bit from T[i] */
3992 *tmpc++ &= MP_MASK;
3996 /* add carry */
3997 *tmpc++ = u;
3999 /* clear digits above oldused */
4000 for (i = c->used; i < olduse; i++) {
4001 *tmpc++ = 0;
4005 mp_clamp (c);
4006 return MP_OKAY;
4009 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
4011 mp_int M[256], res, mu;
4012 mp_digit buf;
4013 int err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
4015 /* find window size */
4016 x = mp_count_bits (X);
4017 if (x <= 7) {
4018 winsize = 2;
4019 } else if (x <= 36) {
4020 winsize = 3;
4021 } else if (x <= 140) {
4022 winsize = 4;
4023 } else if (x <= 450) {
4024 winsize = 5;
4025 } else if (x <= 1303) {
4026 winsize = 6;
4027 } else if (x <= 3529) {
4028 winsize = 7;
4029 } else {
4030 winsize = 8;
4033 /* init M array */
4034 /* init first cell */
4035 if ((err = mp_init(&M[1])) != MP_OKAY) {
4036 return err;
4039 /* now init the second half of the array */
4040 for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4041 if ((err = mp_init(&M[x])) != MP_OKAY) {
4042 for (y = 1<<(winsize-1); y < x; y++) {
4043 mp_clear (&M[y]);
4045 mp_clear(&M[1]);
4046 return err;
4050 /* create mu, used for Barrett reduction */
4051 if ((err = mp_init (&mu)) != MP_OKAY) {
4052 goto __M;
4054 if ((err = mp_reduce_setup (&mu, P)) != MP_OKAY) {
4055 goto __MU;
4058 /* create M table
4060 * The M table contains powers of the base,
4061 * e.g. M[x] = G**x mod P
4063 * The first half of the table is not
4064 * computed though accept for M[0] and M[1]
4066 if ((err = mp_mod (G, P, &M[1])) != MP_OKAY) {
4067 goto __MU;
4070 /* compute the value at M[1<<(winsize-1)] by squaring
4071 * M[1] (winsize-1) times
4073 if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
4074 goto __MU;
4077 for (x = 0; x < (winsize - 1); x++) {
4078 if ((err = mp_sqr (&M[1 << (winsize - 1)],
4079 &M[1 << (winsize - 1)])) != MP_OKAY) {
4080 goto __MU;
4082 if ((err = mp_reduce (&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
4083 goto __MU;
4087 /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
4088 * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
4090 for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
4091 if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
4092 goto __MU;
4094 if ((err = mp_reduce (&M[x], P, &mu)) != MP_OKAY) {
4095 goto __MU;
4099 /* setup result */
4100 if ((err = mp_init (&res)) != MP_OKAY) {
4101 goto __MU;
4103 mp_set (&res, 1);
4105 /* set initial mode and bit cnt */
4106 mode = 0;
4107 bitcnt = 1;
4108 buf = 0;
4109 digidx = X->used - 1;
4110 bitcpy = 0;
4111 bitbuf = 0;
4113 for (;;) {
4114 /* grab next digit as required */
4115 if (--bitcnt == 0) {
4116 /* if digidx == -1 we are out of digits */
4117 if (digidx == -1) {
4118 break;
4120 /* read next digit and reset the bitcnt */
4121 buf = X->dp[digidx--];
4122 bitcnt = DIGIT_BIT;
4125 /* grab the next msb from the exponent */
4126 y = (buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
4127 buf <<= (mp_digit)1;
4129 /* if the bit is zero and mode == 0 then we ignore it
4130 * These represent the leading zero bits before the first 1 bit
4131 * in the exponent. Technically this opt is not required but it
4132 * does lower the # of trivial squaring/reductions used
4134 if (mode == 0 && y == 0) {
4135 continue;
4138 /* if the bit is zero and mode == 1 then we square */
4139 if (mode == 1 && y == 0) {
4140 if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4141 goto __RES;
4143 if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4144 goto __RES;
4146 continue;
4149 /* else we add it to the window */
4150 bitbuf |= (y << (winsize - ++bitcpy));
4151 mode = 2;
4153 if (bitcpy == winsize) {
4154 /* ok window is filled so square as required and multiply */
4155 /* square first */
4156 for (x = 0; x < winsize; x++) {
4157 if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4158 goto __RES;
4160 if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4161 goto __RES;
4165 /* then multiply */
4166 if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
4167 goto __RES;
4169 if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4170 goto __RES;
4173 /* empty window and reset */
4174 bitcpy = 0;
4175 bitbuf = 0;
4176 mode = 1;
4180 /* if bits remain then square/multiply */
4181 if (mode == 2 && bitcpy > 0) {
4182 /* square then multiply if the bit is set */
4183 for (x = 0; x < bitcpy; x++) {
4184 if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4185 goto __RES;
4187 if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4188 goto __RES;
4191 bitbuf <<= 1;
4192 if ((bitbuf & (1 << winsize)) != 0) {
4193 /* then multiply */
4194 if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
4195 goto __RES;
4197 if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4198 goto __RES;
4204 mp_exch (&res, Y);
4205 err = MP_OKAY;
4206 __RES:mp_clear (&res);
4207 __MU:mp_clear (&mu);
4208 __M:
4209 mp_clear(&M[1]);
4210 for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4211 mp_clear (&M[x]);
4213 return err;
4216 /* multiplies |a| * |b| and only computes up to digs digits of result
4217 * HAC pp. 595, Algorithm 14.12 Modified so you can control how
4218 * many digits of output are created.
4220 static int
4221 s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4223 mp_int t;
4224 int res, pa, pb, ix, iy;
4225 mp_digit u;
4226 mp_word r;
4227 mp_digit tmpx, *tmpt, *tmpy;
4229 /* can we use the fast multiplier? */
4230 if (((digs) < MP_WARRAY) &&
4231 MIN (a->used, b->used) <
4232 (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4233 return fast_s_mp_mul_digs (a, b, c, digs);
4236 if ((res = mp_init_size (&t, digs)) != MP_OKAY) {
4237 return res;
4239 t.used = digs;
4241 /* compute the digits of the product directly */
4242 pa = a->used;
4243 for (ix = 0; ix < pa; ix++) {
4244 /* set the carry to zero */
4245 u = 0;
4247 /* limit ourselves to making digs digits of output */
4248 pb = MIN (b->used, digs - ix);
4250 /* setup some aliases */
4251 /* copy of the digit from a used within the nested loop */
4252 tmpx = a->dp[ix];
4254 /* an alias for the destination shifted ix places */
4255 tmpt = t.dp + ix;
4257 /* an alias for the digits of b */
4258 tmpy = b->dp;
4260 /* compute the columns of the output and propagate the carry */
4261 for (iy = 0; iy < pb; iy++) {
4262 /* compute the column as a mp_word */
4263 r = ((mp_word)*tmpt) +
4264 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4265 ((mp_word) u);
4267 /* the new column is the lower part of the result */
4268 *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4270 /* get the carry word from the result */
4271 u = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4273 /* set carry if it is placed below digs */
4274 if (ix + iy < digs) {
4275 *tmpt = u;
4279 mp_clamp (&t);
4280 mp_exch (&t, c);
4282 mp_clear (&t);
4283 return MP_OKAY;
4286 /* multiplies |a| * |b| and does not compute the lower digs digits
4287 * [meant to get the higher part of the product]
4289 static int
4290 s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4292 mp_int t;
4293 int res, pa, pb, ix, iy;
4294 mp_digit u;
4295 mp_word r;
4296 mp_digit tmpx, *tmpt, *tmpy;
4298 /* can we use the fast multiplier? */
4299 if (((a->used + b->used + 1) < MP_WARRAY)
4300 && MIN (a->used, b->used) < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4301 return fast_s_mp_mul_high_digs (a, b, c, digs);
4304 if ((res = mp_init_size (&t, a->used + b->used + 1)) != MP_OKAY) {
4305 return res;
4307 t.used = a->used + b->used + 1;
4309 pa = a->used;
4310 pb = b->used;
4311 for (ix = 0; ix < pa; ix++) {
4312 /* clear the carry */
4313 u = 0;
4315 /* left hand side of A[ix] * B[iy] */
4316 tmpx = a->dp[ix];
4318 /* alias to the address of where the digits will be stored */
4319 tmpt = &(t.dp[digs]);
4321 /* alias for where to read the right hand side from */
4322 tmpy = b->dp + (digs - ix);
4324 for (iy = digs - ix; iy < pb; iy++) {
4325 /* calculate the double precision result */
4326 r = ((mp_word)*tmpt) +
4327 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4328 ((mp_word) u);
4330 /* get the lower part */
4331 *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4333 /* carry the carry */
4334 u = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4336 *tmpt = u;
4338 mp_clamp (&t);
4339 mp_exch (&t, c);
4340 mp_clear (&t);
4341 return MP_OKAY;
4344 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
4345 static int
4346 s_mp_sqr (const mp_int * a, mp_int * b)
4348 mp_int t;
4349 int res, ix, iy, pa;
4350 mp_word r;
4351 mp_digit u, tmpx, *tmpt;
4353 pa = a->used;
4354 if ((res = mp_init_size (&t, 2*pa + 1)) != MP_OKAY) {
4355 return res;
4358 /* default used is maximum possible size */
4359 t.used = 2*pa + 1;
4361 for (ix = 0; ix < pa; ix++) {
4362 /* first calculate the digit at 2*ix */
4363 /* calculate double precision result */
4364 r = ((mp_word) t.dp[2*ix]) +
4365 ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
4367 /* store lower part in result */
4368 t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
4370 /* get the carry */
4371 u = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4373 /* left hand side of A[ix] * A[iy] */
4374 tmpx = a->dp[ix];
4376 /* alias for where to store the results */
4377 tmpt = t.dp + (2*ix + 1);
4379 for (iy = ix + 1; iy < pa; iy++) {
4380 /* first calculate the product */
4381 r = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
4383 /* now calculate the double precision result, note we use
4384 * addition instead of *2 since it's easier to optimize
4386 r = ((mp_word) *tmpt) + r + r + ((mp_word) u);
4388 /* store lower part */
4389 *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4391 /* get carry */
4392 u = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4394 /* propagate upwards */
4395 while (u != 0) {
4396 r = ((mp_word) *tmpt) + ((mp_word) u);
4397 *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4398 u = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4402 mp_clamp (&t);
4403 mp_exch (&t, b);
4404 mp_clear (&t);
4405 return MP_OKAY;
4408 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
4410 s_mp_sub (const mp_int * a, const mp_int * b, mp_int * c)
4412 int olduse, res, min, max;
4414 /* find sizes */
4415 min = b->used;
4416 max = a->used;
4418 /* init result */
4419 if (c->alloc < max) {
4420 if ((res = mp_grow (c, max)) != MP_OKAY) {
4421 return res;
4424 olduse = c->used;
4425 c->used = max;
4428 register mp_digit u, *tmpa, *tmpb, *tmpc;
4429 register int i;
4431 /* alias for digit pointers */
4432 tmpa = a->dp;
4433 tmpb = b->dp;
4434 tmpc = c->dp;
4436 /* set carry to zero */
4437 u = 0;
4438 for (i = 0; i < min; i++) {
4439 /* T[i] = A[i] - B[i] - U */
4440 *tmpc = *tmpa++ - *tmpb++ - u;
4442 /* U = carry bit of T[i]
4443 * Note this saves performing an AND operation since
4444 * if a carry does occur it will propagate all the way to the
4445 * MSB. As a result a single shift is enough to get the carry
4447 u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4449 /* Clear carry from T[i] */
4450 *tmpc++ &= MP_MASK;
4453 /* now copy higher words if any, e.g. if A has more digits than B */
4454 for (; i < max; i++) {
4455 /* T[i] = A[i] - U */
4456 *tmpc = *tmpa++ - u;
4458 /* U = carry bit of T[i] */
4459 u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4461 /* Clear carry from T[i] */
4462 *tmpc++ &= MP_MASK;
4465 /* clear digits above used (since we may not have grown result above) */
4466 for (i = c->used; i < olduse; i++) {
4467 *tmpc++ = 0;
4471 mp_clamp (c);
4472 return MP_OKAY;