d3d8: Add tests for IDirect3D8Device_Reset.
[wine/hacks.git] / dlls / rsaenh / mpi.c
blob0c58e39e96b54a1540d2795d4247e2caeaa2c76b
1 /*
2 * dlls/rsaenh/mpi.c
3 * Multi Precision Integer functions
5 * Copyright 2004 Michael Jung
6 * Based on public domain code by Tom St Denis (tomstdenis@iahu.ca)
8 * This library is free software; you can redistribute it and/or
9 * modify it under the terms of the GNU Lesser General Public
10 * License as published by the Free Software Foundation; either
11 * version 2.1 of the License, or (at your option) any later version.
13 * This library is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16 * Lesser General Public License for more details.
18 * You should have received a copy of the GNU Lesser General Public
19 * License along with this library; if not, write to the Free Software
20 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
24 * This file contains code from the LibTomCrypt cryptographic
25 * library written by Tom St Denis (tomstdenis@iahu.ca). LibTomCrypt
26 * is in the public domain. The code in this file is tailored to
27 * special requirements. Take a look at http://libtomcrypt.org for the
28 * original version.
31 #include <stdarg.h>
33 #include "windef.h"
34 #include "winbase.h"
35 #include "tomcrypt.h"
37 /* Known optimal configurations
38 CPU /Compiler /MUL CUTOFF/SQR CUTOFF
39 -------------------------------------------------------------
40 Intel P4 Northwood /GCC v3.4.1 / 88/ 128/LTM 0.32 ;-)
42 static const int KARATSUBA_MUL_CUTOFF = 88, /* Min. number of digits before Karatsuba multiplication is used. */
43 KARATSUBA_SQR_CUTOFF = 128; /* Min. number of digits before Karatsuba squaring is used. */
45 static void bn_reverse(unsigned char *s, int len);
46 static int s_mp_add(mp_int *a, mp_int *b, mp_int *c);
47 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y);
48 #define s_mp_mul(a, b, c) s_mp_mul_digs(a, b, c, (a)->used + (b)->used + 1)
49 static int s_mp_mul_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
50 static int s_mp_mul_high_digs(const mp_int *a, const mp_int *b, mp_int *c, int digs);
51 static int s_mp_sqr(const mp_int *a, mp_int *b);
52 static int s_mp_sub(const mp_int *a, const mp_int *b, mp_int *c);
53 static int mp_exptmod_fast(const mp_int *G, const mp_int *X, mp_int *P, mp_int *Y, int mode);
54 static int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c);
55 static int mp_karatsuba_mul(const mp_int *a, const mp_int *b, mp_int *c);
56 static int mp_karatsuba_sqr(const mp_int *a, mp_int *b);
58 /* grow as required */
59 static int mp_grow (mp_int * a, int size)
61 int i;
62 mp_digit *tmp;
64 /* if the alloc size is smaller alloc more ram */
65 if (a->alloc < size) {
66 /* ensure there are always at least MP_PREC digits extra on top */
67 size += (MP_PREC * 2) - (size % MP_PREC);
69 /* reallocate the array a->dp
71 * We store the return in a temporary variable
72 * in case the operation failed we don't want
73 * to overwrite the dp member of a.
75 tmp = HeapReAlloc(GetProcessHeap(), 0, a->dp, sizeof (mp_digit) * size);
76 if (tmp == NULL) {
77 /* reallocation failed but "a" is still valid [can be freed] */
78 return MP_MEM;
81 /* reallocation succeeded so set a->dp */
82 a->dp = tmp;
84 /* zero excess digits */
85 i = a->alloc;
86 a->alloc = size;
87 for (; i < a->alloc; i++) {
88 a->dp[i] = 0;
91 return MP_OKAY;
94 /* b = a/2 */
95 static int mp_div_2(const mp_int * a, mp_int * b)
97 int x, res, oldused;
99 /* copy */
100 if (b->alloc < a->used) {
101 if ((res = mp_grow (b, a->used)) != MP_OKAY) {
102 return res;
106 oldused = b->used;
107 b->used = a->used;
109 register mp_digit r, rr, *tmpa, *tmpb;
111 /* source alias */
112 tmpa = a->dp + b->used - 1;
114 /* dest alias */
115 tmpb = b->dp + b->used - 1;
117 /* carry */
118 r = 0;
119 for (x = b->used - 1; x >= 0; x--) {
120 /* get the carry for the next iteration */
121 rr = *tmpa & 1;
123 /* shift the current digit, add in carry and store */
124 *tmpb-- = (*tmpa-- >> 1) | (r << (DIGIT_BIT - 1));
126 /* forward carry to next iteration */
127 r = rr;
130 /* zero excess digits */
131 tmpb = b->dp + b->used;
132 for (x = b->used; x < oldused; x++) {
133 *tmpb++ = 0;
136 b->sign = a->sign;
137 mp_clamp (b);
138 return MP_OKAY;
141 /* swap the elements of two integers, for cases where you can't simply swap the
142 * mp_int pointers around
144 static void
145 mp_exch (mp_int * a, mp_int * b)
147 mp_int t;
149 t = *a;
150 *a = *b;
151 *b = t;
154 /* init a new mp_int */
155 static int mp_init (mp_int * a)
157 int i;
159 /* allocate memory required and clear it */
160 a->dp = HeapAlloc(GetProcessHeap(), 0, sizeof (mp_digit) * MP_PREC);
161 if (a->dp == NULL) {
162 return MP_MEM;
165 /* set the digits to zero */
166 for (i = 0; i < MP_PREC; i++) {
167 a->dp[i] = 0;
170 /* set the used to zero, allocated digits to the default precision
171 * and sign to positive */
172 a->used = 0;
173 a->alloc = MP_PREC;
174 a->sign = MP_ZPOS;
176 return MP_OKAY;
179 /* init an mp_init for a given size */
180 static int mp_init_size (mp_int * a, int size)
182 int x;
184 /* pad size so there are always extra digits */
185 size += (MP_PREC * 2) - (size % MP_PREC);
187 /* alloc mem */
188 a->dp = HeapAlloc(GetProcessHeap(), 0, sizeof (mp_digit) * size);
189 if (a->dp == NULL) {
190 return MP_MEM;
193 /* set the members */
194 a->used = 0;
195 a->alloc = size;
196 a->sign = MP_ZPOS;
198 /* zero the digits */
199 for (x = 0; x < size; x++) {
200 a->dp[x] = 0;
203 return MP_OKAY;
206 /* clear one (frees) */
207 static void
208 mp_clear (mp_int * a)
210 int i;
212 /* only do anything if a hasn't been freed previously */
213 if (a->dp != NULL) {
214 /* first zero the digits */
215 for (i = 0; i < a->used; i++) {
216 a->dp[i] = 0;
219 /* free ram */
220 HeapFree(GetProcessHeap(), 0, a->dp);
222 /* reset members to make debugging easier */
223 a->dp = NULL;
224 a->alloc = a->used = 0;
225 a->sign = MP_ZPOS;
229 /* set to zero */
230 static void
231 mp_zero (mp_int * a)
233 a->sign = MP_ZPOS;
234 a->used = 0;
235 memset (a->dp, 0, sizeof (mp_digit) * a->alloc);
238 /* b = |a|
240 * Simple function copies the input and fixes the sign to positive
242 static int
243 mp_abs (const mp_int * a, mp_int * b)
245 int res;
247 /* copy a to b */
248 if (a != b) {
249 if ((res = mp_copy (a, b)) != MP_OKAY) {
250 return res;
254 /* force the sign of b to positive */
255 b->sign = MP_ZPOS;
257 return MP_OKAY;
260 /* computes the modular inverse via binary extended euclidean algorithm,
261 * that is c = 1/a mod b
263 * Based on slow invmod except this is optimized for the case where b is
264 * odd as per HAC Note 14.64 on pp. 610
266 static int
267 fast_mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
269 mp_int x, y, u, v, B, D;
270 int res, neg;
272 /* 2. [modified] b must be odd */
273 if (mp_iseven (b) == 1) {
274 return MP_VAL;
277 /* init all our temps */
278 if ((res = mp_init_multi(&x, &y, &u, &v, &B, &D, NULL)) != MP_OKAY) {
279 return res;
282 /* x == modulus, y == value to invert */
283 if ((res = mp_copy (b, &x)) != MP_OKAY) {
284 goto __ERR;
287 /* we need y = |a| */
288 if ((res = mp_abs (a, &y)) != MP_OKAY) {
289 goto __ERR;
292 /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
293 if ((res = mp_copy (&x, &u)) != MP_OKAY) {
294 goto __ERR;
296 if ((res = mp_copy (&y, &v)) != MP_OKAY) {
297 goto __ERR;
299 mp_set (&D, 1);
301 top:
302 /* 4. while u is even do */
303 while (mp_iseven (&u) == 1) {
304 /* 4.1 u = u/2 */
305 if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
306 goto __ERR;
308 /* 4.2 if B is odd then */
309 if (mp_isodd (&B) == 1) {
310 if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
311 goto __ERR;
314 /* B = B/2 */
315 if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
316 goto __ERR;
320 /* 5. while v is even do */
321 while (mp_iseven (&v) == 1) {
322 /* 5.1 v = v/2 */
323 if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
324 goto __ERR;
326 /* 5.2 if D is odd then */
327 if (mp_isodd (&D) == 1) {
328 /* D = (D-x)/2 */
329 if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
330 goto __ERR;
333 /* D = D/2 */
334 if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
335 goto __ERR;
339 /* 6. if u >= v then */
340 if (mp_cmp (&u, &v) != MP_LT) {
341 /* u = u - v, B = B - D */
342 if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
343 goto __ERR;
346 if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
347 goto __ERR;
349 } else {
350 /* v - v - u, D = D - B */
351 if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
352 goto __ERR;
355 if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
356 goto __ERR;
360 /* if not zero goto step 4 */
361 if (mp_iszero (&u) == 0) {
362 goto top;
365 /* now a = C, b = D, gcd == g*v */
367 /* if v != 1 then there is no inverse */
368 if (mp_cmp_d (&v, 1) != MP_EQ) {
369 res = MP_VAL;
370 goto __ERR;
373 /* b is now the inverse */
374 neg = a->sign;
375 while (D.sign == MP_NEG) {
376 if ((res = mp_add (&D, b, &D)) != MP_OKAY) {
377 goto __ERR;
380 mp_exch (&D, c);
381 c->sign = neg;
382 res = MP_OKAY;
384 __ERR:mp_clear_multi (&x, &y, &u, &v, &B, &D, NULL);
385 return res;
388 /* computes xR**-1 == x (mod N) via Montgomery Reduction
390 * This is an optimized implementation of montgomery_reduce
391 * which uses the comba method to quickly calculate the columns of the
392 * reduction.
394 * Based on Algorithm 14.32 on pp.601 of HAC.
396 static int
397 fast_mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
399 int ix, res, olduse;
400 mp_word W[MP_WARRAY];
402 /* get old used count */
403 olduse = x->used;
405 /* grow a as required */
406 if (x->alloc < n->used + 1) {
407 if ((res = mp_grow (x, n->used + 1)) != MP_OKAY) {
408 return res;
412 /* first we have to get the digits of the input into
413 * an array of double precision words W[...]
416 register mp_word *_W;
417 register mp_digit *tmpx;
419 /* alias for the W[] array */
420 _W = W;
422 /* alias for the digits of x*/
423 tmpx = x->dp;
425 /* copy the digits of a into W[0..a->used-1] */
426 for (ix = 0; ix < x->used; ix++) {
427 *_W++ = *tmpx++;
430 /* zero the high words of W[a->used..m->used*2] */
431 for (; ix < n->used * 2 + 1; ix++) {
432 *_W++ = 0;
436 /* now we proceed to zero successive digits
437 * from the least significant upwards
439 for (ix = 0; ix < n->used; ix++) {
440 /* mu = ai * m' mod b
442 * We avoid a double precision multiplication (which isn't required)
443 * by casting the value down to a mp_digit. Note this requires
444 * that W[ix-1] have the carry cleared (see after the inner loop)
446 register mp_digit mu;
447 mu = (mp_digit) (((W[ix] & MP_MASK) * rho) & MP_MASK);
449 /* a = a + mu * m * b**i
451 * This is computed in place and on the fly. The multiplication
452 * by b**i is handled by offsetting which columns the results
453 * are added to.
455 * Note the comba method normally doesn't handle carries in the
456 * inner loop In this case we fix the carry from the previous
457 * column since the Montgomery reduction requires digits of the
458 * result (so far) [see above] to work. This is
459 * handled by fixing up one carry after the inner loop. The
460 * carry fixups are done in order so after these loops the
461 * first m->used words of W[] have the carries fixed
464 register int iy;
465 register mp_digit *tmpn;
466 register mp_word *_W;
468 /* alias for the digits of the modulus */
469 tmpn = n->dp;
471 /* Alias for the columns set by an offset of ix */
472 _W = W + ix;
474 /* inner loop */
475 for (iy = 0; iy < n->used; iy++) {
476 *_W++ += ((mp_word)mu) * ((mp_word)*tmpn++);
480 /* now fix carry for next digit, W[ix+1] */
481 W[ix + 1] += W[ix] >> ((mp_word) DIGIT_BIT);
484 /* now we have to propagate the carries and
485 * shift the words downward [all those least
486 * significant digits we zeroed].
489 register mp_digit *tmpx;
490 register mp_word *_W, *_W1;
492 /* nox fix rest of carries */
494 /* alias for current word */
495 _W1 = W + ix;
497 /* alias for next word, where the carry goes */
498 _W = W + ++ix;
500 for (; ix <= n->used * 2 + 1; ix++) {
501 *_W++ += *_W1++ >> ((mp_word) DIGIT_BIT);
504 /* copy out, A = A/b**n
506 * The result is A/b**n but instead of converting from an
507 * array of mp_word to mp_digit than calling mp_rshd
508 * we just copy them in the right order
511 /* alias for destination word */
512 tmpx = x->dp;
514 /* alias for shifted double precision result */
515 _W = W + n->used;
517 for (ix = 0; ix < n->used + 1; ix++) {
518 *tmpx++ = (mp_digit)(*_W++ & ((mp_word) MP_MASK));
521 /* zero oldused digits, if the input a was larger than
522 * m->used+1 we'll have to clear the digits
524 for (; ix < olduse; ix++) {
525 *tmpx++ = 0;
529 /* set the max used and clamp */
530 x->used = n->used + 1;
531 mp_clamp (x);
533 /* if A >= m then A = A - m */
534 if (mp_cmp_mag (x, n) != MP_LT) {
535 return s_mp_sub (x, n, x);
537 return MP_OKAY;
540 /* Fast (comba) multiplier
542 * This is the fast column-array [comba] multiplier. It is
543 * designed to compute the columns of the product first
544 * then handle the carries afterwards. This has the effect
545 * of making the nested loops that compute the columns very
546 * simple and schedulable on super-scalar processors.
548 * This has been modified to produce a variable number of
549 * digits of output so if say only a half-product is required
550 * you don't have to compute the upper half (a feature
551 * required for fast Barrett reduction).
553 * Based on Algorithm 14.12 on pp.595 of HAC.
556 static int
557 fast_s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
559 int olduse, res, pa, ix, iz;
560 mp_digit W[MP_WARRAY];
561 register mp_word _W;
563 /* grow the destination as required */
564 if (c->alloc < digs) {
565 if ((res = mp_grow (c, digs)) != MP_OKAY) {
566 return res;
570 /* number of output digits to produce */
571 pa = MIN(digs, a->used + b->used);
573 /* clear the carry */
574 _W = 0;
575 for (ix = 0; ix <= pa; ix++) {
576 int tx, ty;
577 int iy;
578 mp_digit *tmpx, *tmpy;
580 /* get offsets into the two bignums */
581 ty = MIN(b->used-1, ix);
582 tx = ix - ty;
584 /* setup temp aliases */
585 tmpx = a->dp + tx;
586 tmpy = b->dp + ty;
588 /* This is the number of times the loop will iterate, essentially it's
589 while (tx++ < a->used && ty-- >= 0) { ... }
591 iy = MIN(a->used-tx, ty+1);
593 /* execute loop */
594 for (iz = 0; iz < iy; ++iz) {
595 _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
598 /* store term */
599 W[ix] = ((mp_digit)_W) & MP_MASK;
601 /* make next carry */
602 _W = _W >> ((mp_word)DIGIT_BIT);
605 /* setup dest */
606 olduse = c->used;
607 c->used = digs;
610 register mp_digit *tmpc;
611 tmpc = c->dp;
612 for (ix = 0; ix < digs; ix++) {
613 /* now extract the previous digit [below the carry] */
614 *tmpc++ = W[ix];
617 /* clear unused digits [that existed in the old copy of c] */
618 for (; ix < olduse; ix++) {
619 *tmpc++ = 0;
622 mp_clamp (c);
623 return MP_OKAY;
626 /* this is a modified version of fast_s_mul_digs that only produces
627 * output digits *above* digs. See the comments for fast_s_mul_digs
628 * to see how it works.
630 * This is used in the Barrett reduction since for one of the multiplications
631 * only the higher digits were needed. This essentially halves the work.
633 * Based on Algorithm 14.12 on pp.595 of HAC.
635 static int
636 fast_s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
638 int olduse, res, pa, ix, iz;
639 mp_digit W[MP_WARRAY];
640 mp_word _W;
642 /* grow the destination as required */
643 pa = a->used + b->used;
644 if (c->alloc < pa) {
645 if ((res = mp_grow (c, pa)) != MP_OKAY) {
646 return res;
650 /* number of output digits to produce */
651 pa = a->used + b->used;
652 _W = 0;
653 for (ix = digs; ix <= pa; ix++) {
654 int tx, ty, iy;
655 mp_digit *tmpx, *tmpy;
657 /* get offsets into the two bignums */
658 ty = MIN(b->used-1, ix);
659 tx = ix - ty;
661 /* setup temp aliases */
662 tmpx = a->dp + tx;
663 tmpy = b->dp + ty;
665 /* This is the number of times the loop will iterate, essentially it's
666 while (tx++ < a->used && ty-- >= 0) { ... }
668 iy = MIN(a->used-tx, ty+1);
670 /* execute loop */
671 for (iz = 0; iz < iy; iz++) {
672 _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
675 /* store term */
676 W[ix] = ((mp_digit)_W) & MP_MASK;
678 /* make next carry */
679 _W = _W >> ((mp_word)DIGIT_BIT);
682 /* setup dest */
683 olduse = c->used;
684 c->used = pa;
687 register mp_digit *tmpc;
689 tmpc = c->dp + digs;
690 for (ix = digs; ix <= pa; ix++) {
691 /* now extract the previous digit [below the carry] */
692 *tmpc++ = W[ix];
695 /* clear unused digits [that existed in the old copy of c] */
696 for (; ix < olduse; ix++) {
697 *tmpc++ = 0;
700 mp_clamp (c);
701 return MP_OKAY;
704 /* fast squaring
706 * This is the comba method where the columns of the product
707 * are computed first then the carries are computed. This
708 * has the effect of making a very simple inner loop that
709 * is executed the most
711 * W2 represents the outer products and W the inner.
713 * A further optimizations is made because the inner
714 * products are of the form "A * B * 2". The *2 part does
715 * not need to be computed until the end which is good
716 * because 64-bit shifts are slow!
718 * Based on Algorithm 14.16 on pp.597 of HAC.
721 /* the jist of squaring...
723 you do like mult except the offset of the tmpx [one that starts closer to zero]
724 can't equal the offset of tmpy. So basically you set up iy like before then you min it with
725 (ty-tx) so that it never happens. You double all those you add in the inner loop
727 After that loop you do the squares and add them in.
729 Remove W2 and don't memset W
733 static int fast_s_mp_sqr (const mp_int * a, mp_int * b)
735 int olduse, res, pa, ix, iz;
736 mp_digit W[MP_WARRAY], *tmpx;
737 mp_word W1;
739 /* grow the destination as required */
740 pa = a->used + a->used;
741 if (b->alloc < pa) {
742 if ((res = mp_grow (b, pa)) != MP_OKAY) {
743 return res;
747 /* number of output digits to produce */
748 W1 = 0;
749 for (ix = 0; ix <= pa; ix++) {
750 int tx, ty, iy;
751 mp_word _W;
752 mp_digit *tmpy;
754 /* clear counter */
755 _W = 0;
757 /* get offsets into the two bignums */
758 ty = MIN(a->used-1, ix);
759 tx = ix - ty;
761 /* setup temp aliases */
762 tmpx = a->dp + tx;
763 tmpy = a->dp + ty;
765 /* This is the number of times the loop will iterate, essentially it's
766 while (tx++ < a->used && ty-- >= 0) { ... }
768 iy = MIN(a->used-tx, ty+1);
770 /* now for squaring tx can never equal ty
771 * we halve the distance since they approach at a rate of 2x
772 * and we have to round because odd cases need to be executed
774 iy = MIN(iy, (ty-tx+1)>>1);
776 /* execute loop */
777 for (iz = 0; iz < iy; iz++) {
778 _W += ((mp_word)*tmpx++)*((mp_word)*tmpy--);
781 /* double the inner product and add carry */
782 _W = _W + _W + W1;
784 /* even columns have the square term in them */
785 if ((ix&1) == 0) {
786 _W += ((mp_word)a->dp[ix>>1])*((mp_word)a->dp[ix>>1]);
789 /* store it */
790 W[ix] = _W;
792 /* make next carry */
793 W1 = _W >> ((mp_word)DIGIT_BIT);
796 /* setup dest */
797 olduse = b->used;
798 b->used = a->used+a->used;
801 mp_digit *tmpb;
802 tmpb = b->dp;
803 for (ix = 0; ix < pa; ix++) {
804 *tmpb++ = W[ix] & MP_MASK;
807 /* clear unused digits [that existed in the old copy of c] */
808 for (; ix < olduse; ix++) {
809 *tmpb++ = 0;
812 mp_clamp (b);
813 return MP_OKAY;
816 /* computes a = 2**b
818 * Simple algorithm which zeroes the int, grows it then just sets one bit
819 * as required.
821 static int
822 mp_2expt (mp_int * a, int b)
824 int res;
826 /* zero a as per default */
827 mp_zero (a);
829 /* grow a to accommodate the single bit */
830 if ((res = mp_grow (a, b / DIGIT_BIT + 1)) != MP_OKAY) {
831 return res;
834 /* set the used count of where the bit will go */
835 a->used = b / DIGIT_BIT + 1;
837 /* put the single bit in its place */
838 a->dp[b / DIGIT_BIT] = ((mp_digit)1) << (b % DIGIT_BIT);
840 return MP_OKAY;
843 /* high level addition (handles signs) */
844 int mp_add (mp_int * a, mp_int * b, mp_int * c)
846 int sa, sb, res;
848 /* get sign of both inputs */
849 sa = a->sign;
850 sb = b->sign;
852 /* handle two cases, not four */
853 if (sa == sb) {
854 /* both positive or both negative */
855 /* add their magnitudes, copy the sign */
856 c->sign = sa;
857 res = s_mp_add (a, b, c);
858 } else {
859 /* one positive, the other negative */
860 /* subtract the one with the greater magnitude from */
861 /* the one of the lesser magnitude. The result gets */
862 /* the sign of the one with the greater magnitude. */
863 if (mp_cmp_mag (a, b) == MP_LT) {
864 c->sign = sb;
865 res = s_mp_sub (b, a, c);
866 } else {
867 c->sign = sa;
868 res = s_mp_sub (a, b, c);
871 return res;
875 /* single digit addition */
876 static int
877 mp_add_d (mp_int * a, mp_digit b, mp_int * c)
879 int res, ix, oldused;
880 mp_digit *tmpa, *tmpc, mu;
882 /* grow c as required */
883 if (c->alloc < a->used + 1) {
884 if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
885 return res;
889 /* if a is negative and |a| >= b, call c = |a| - b */
890 if (a->sign == MP_NEG && (a->used > 1 || a->dp[0] >= b)) {
891 /* temporarily fix sign of a */
892 a->sign = MP_ZPOS;
894 /* c = |a| - b */
895 res = mp_sub_d(a, b, c);
897 /* fix sign */
898 a->sign = c->sign = MP_NEG;
900 return res;
903 /* old number of used digits in c */
904 oldused = c->used;
906 /* sign always positive */
907 c->sign = MP_ZPOS;
909 /* source alias */
910 tmpa = a->dp;
912 /* destination alias */
913 tmpc = c->dp;
915 /* if a is positive */
916 if (a->sign == MP_ZPOS) {
917 /* add digit, after this we're propagating
918 * the carry.
920 *tmpc = *tmpa++ + b;
921 mu = *tmpc >> DIGIT_BIT;
922 *tmpc++ &= MP_MASK;
924 /* now handle rest of the digits */
925 for (ix = 1; ix < a->used; ix++) {
926 *tmpc = *tmpa++ + mu;
927 mu = *tmpc >> DIGIT_BIT;
928 *tmpc++ &= MP_MASK;
930 /* set final carry */
931 ix++;
932 *tmpc++ = mu;
934 /* setup size */
935 c->used = a->used + 1;
936 } else {
937 /* a was negative and |a| < b */
938 c->used = 1;
940 /* the result is a single digit */
941 if (a->used == 1) {
942 *tmpc++ = b - a->dp[0];
943 } else {
944 *tmpc++ = b;
947 /* setup count so the clearing of oldused
948 * can fall through correctly
950 ix = 1;
953 /* now zero to oldused */
954 while (ix++ < oldused) {
955 *tmpc++ = 0;
957 mp_clamp(c);
959 return MP_OKAY;
962 /* trim unused digits
964 * This is used to ensure that leading zero digits are
965 * trimed and the leading "used" digit will be non-zero
966 * Typically very fast. Also fixes the sign if there
967 * are no more leading digits
969 void
970 mp_clamp (mp_int * a)
972 /* decrease used while the most significant digit is
973 * zero.
975 while (a->used > 0 && a->dp[a->used - 1] == 0) {
976 --(a->used);
979 /* reset the sign flag if used == 0 */
980 if (a->used == 0) {
981 a->sign = MP_ZPOS;
985 void mp_clear_multi(mp_int *mp, ...)
987 mp_int* next_mp = mp;
988 va_list args;
989 va_start(args, mp);
990 while (next_mp != NULL) {
991 mp_clear(next_mp);
992 next_mp = va_arg(args, mp_int*);
994 va_end(args);
997 /* compare two ints (signed)*/
999 mp_cmp (const mp_int * a, const mp_int * b)
1001 /* compare based on sign */
1002 if (a->sign != b->sign) {
1003 if (a->sign == MP_NEG) {
1004 return MP_LT;
1005 } else {
1006 return MP_GT;
1010 /* compare digits */
1011 if (a->sign == MP_NEG) {
1012 /* if negative compare opposite direction */
1013 return mp_cmp_mag(b, a);
1014 } else {
1015 return mp_cmp_mag(a, b);
1019 /* compare a digit */
1020 int mp_cmp_d(const mp_int * a, mp_digit b)
1022 /* compare based on sign */
1023 if (a->sign == MP_NEG) {
1024 return MP_LT;
1027 /* compare based on magnitude */
1028 if (a->used > 1) {
1029 return MP_GT;
1032 /* compare the only digit of a to b */
1033 if (a->dp[0] > b) {
1034 return MP_GT;
1035 } else if (a->dp[0] < b) {
1036 return MP_LT;
1037 } else {
1038 return MP_EQ;
1042 /* compare maginitude of two ints (unsigned) */
1043 int mp_cmp_mag (const mp_int * a, const mp_int * b)
1045 int n;
1046 mp_digit *tmpa, *tmpb;
1048 /* compare based on # of non-zero digits */
1049 if (a->used > b->used) {
1050 return MP_GT;
1053 if (a->used < b->used) {
1054 return MP_LT;
1057 /* alias for a */
1058 tmpa = a->dp + (a->used - 1);
1060 /* alias for b */
1061 tmpb = b->dp + (a->used - 1);
1063 /* compare based on digits */
1064 for (n = 0; n < a->used; ++n, --tmpa, --tmpb) {
1065 if (*tmpa > *tmpb) {
1066 return MP_GT;
1069 if (*tmpa < *tmpb) {
1070 return MP_LT;
1073 return MP_EQ;
1076 static const int lnz[16] = {
1077 4, 0, 1, 0, 2, 0, 1, 0, 3, 0, 1, 0, 2, 0, 1, 0
1080 /* Counts the number of lsbs which are zero before the first zero bit */
1081 int mp_cnt_lsb(const mp_int *a)
1083 int x;
1084 mp_digit q, qq;
1086 /* easy out */
1087 if (mp_iszero(a) == 1) {
1088 return 0;
1091 /* scan lower digits until non-zero */
1092 for (x = 0; x < a->used && a->dp[x] == 0; x++);
1093 q = a->dp[x];
1094 x *= DIGIT_BIT;
1096 /* now scan this digit until a 1 is found */
1097 if ((q & 1) == 0) {
1098 do {
1099 qq = q & 15;
1100 x += lnz[qq];
1101 q >>= 4;
1102 } while (qq == 0);
1104 return x;
1107 /* copy, b = a */
1109 mp_copy (const mp_int * a, mp_int * b)
1111 int res, n;
1113 /* if dst == src do nothing */
1114 if (a == b) {
1115 return MP_OKAY;
1118 /* grow dest */
1119 if (b->alloc < a->used) {
1120 if ((res = mp_grow (b, a->used)) != MP_OKAY) {
1121 return res;
1125 /* zero b and copy the parameters over */
1127 register mp_digit *tmpa, *tmpb;
1129 /* pointer aliases */
1131 /* source */
1132 tmpa = a->dp;
1134 /* destination */
1135 tmpb = b->dp;
1137 /* copy all the digits */
1138 for (n = 0; n < a->used; n++) {
1139 *tmpb++ = *tmpa++;
1142 /* clear high digits */
1143 for (; n < b->used; n++) {
1144 *tmpb++ = 0;
1148 /* copy used count and sign */
1149 b->used = a->used;
1150 b->sign = a->sign;
1151 return MP_OKAY;
1154 /* returns the number of bits in an int */
1156 mp_count_bits (const mp_int * a)
1158 int r;
1159 mp_digit q;
1161 /* shortcut */
1162 if (a->used == 0) {
1163 return 0;
1166 /* get number of digits and add that */
1167 r = (a->used - 1) * DIGIT_BIT;
1169 /* take the last digit and count the bits in it */
1170 q = a->dp[a->used - 1];
1171 while (q > 0) {
1172 ++r;
1173 q >>= ((mp_digit) 1);
1175 return r;
1178 /* calc a value mod 2**b */
1179 static int
1180 mp_mod_2d (const mp_int * a, int b, mp_int * c)
1182 int x, res;
1184 /* if b is <= 0 then zero the int */
1185 if (b <= 0) {
1186 mp_zero (c);
1187 return MP_OKAY;
1190 /* if the modulus is larger than the value than return */
1191 if (b > a->used * DIGIT_BIT) {
1192 res = mp_copy (a, c);
1193 return res;
1196 /* copy */
1197 if ((res = mp_copy (a, c)) != MP_OKAY) {
1198 return res;
1201 /* zero digits above the last digit of the modulus */
1202 for (x = (b / DIGIT_BIT) + ((b % DIGIT_BIT) == 0 ? 0 : 1); x < c->used; x++) {
1203 c->dp[x] = 0;
1205 /* clear the digit that is not completely outside/inside the modulus */
1206 c->dp[b / DIGIT_BIT] &= (1 << ((mp_digit)b % DIGIT_BIT)) - 1;
1207 mp_clamp (c);
1208 return MP_OKAY;
1211 /* shift right a certain amount of digits */
1212 static void mp_rshd (mp_int * a, int b)
1214 int x;
1216 /* if b <= 0 then ignore it */
1217 if (b <= 0) {
1218 return;
1221 /* if b > used then simply zero it and return */
1222 if (a->used <= b) {
1223 mp_zero (a);
1224 return;
1228 register mp_digit *bottom, *top;
1230 /* shift the digits down */
1232 /* bottom */
1233 bottom = a->dp;
1235 /* top [offset into digits] */
1236 top = a->dp + b;
1238 /* this is implemented as a sliding window where
1239 * the window is b-digits long and digits from
1240 * the top of the window are copied to the bottom
1242 * e.g.
1244 b-2 | b-1 | b0 | b1 | b2 | ... | bb | ---->
1245 /\ | ---->
1246 \-------------------/ ---->
1248 for (x = 0; x < (a->used - b); x++) {
1249 *bottom++ = *top++;
1252 /* zero the top digits */
1253 for (; x < a->used; x++) {
1254 *bottom++ = 0;
1258 /* remove excess digits */
1259 a->used -= b;
1262 /* shift right by a certain bit count (store quotient in c, optional remainder in d) */
1263 static int mp_div_2d (const mp_int * a, int b, mp_int * c, mp_int * d)
1265 mp_digit D, r, rr;
1266 int x, res;
1267 mp_int t;
1270 /* if the shift count is <= 0 then we do no work */
1271 if (b <= 0) {
1272 res = mp_copy (a, c);
1273 if (d != NULL) {
1274 mp_zero (d);
1276 return res;
1279 if ((res = mp_init (&t)) != MP_OKAY) {
1280 return res;
1283 /* get the remainder */
1284 if (d != NULL) {
1285 if ((res = mp_mod_2d (a, b, &t)) != MP_OKAY) {
1286 mp_clear (&t);
1287 return res;
1291 /* copy */
1292 if ((res = mp_copy (a, c)) != MP_OKAY) {
1293 mp_clear (&t);
1294 return res;
1297 /* shift by as many digits in the bit count */
1298 if (b >= DIGIT_BIT) {
1299 mp_rshd (c, b / DIGIT_BIT);
1302 /* shift any bit count < DIGIT_BIT */
1303 D = (mp_digit) (b % DIGIT_BIT);
1304 if (D != 0) {
1305 register mp_digit *tmpc, mask, shift;
1307 /* mask */
1308 mask = (((mp_digit)1) << D) - 1;
1310 /* shift for lsb */
1311 shift = DIGIT_BIT - D;
1313 /* alias */
1314 tmpc = c->dp + (c->used - 1);
1316 /* carry */
1317 r = 0;
1318 for (x = c->used - 1; x >= 0; x--) {
1319 /* get the lower bits of this word in a temp */
1320 rr = *tmpc & mask;
1322 /* shift the current word and mix in the carry bits from the previous word */
1323 *tmpc = (*tmpc >> D) | (r << shift);
1324 --tmpc;
1326 /* set the carry to the carry bits of the current word found above */
1327 r = rr;
1330 mp_clamp (c);
1331 if (d != NULL) {
1332 mp_exch (&t, d);
1334 mp_clear (&t);
1335 return MP_OKAY;
1338 /* shift left a certain amount of digits */
1339 static int mp_lshd (mp_int * a, int b)
1341 int x, res;
1343 /* if its less than zero return */
1344 if (b <= 0) {
1345 return MP_OKAY;
1348 /* grow to fit the new digits */
1349 if (a->alloc < a->used + b) {
1350 if ((res = mp_grow (a, a->used + b)) != MP_OKAY) {
1351 return res;
1356 register mp_digit *top, *bottom;
1358 /* increment the used by the shift amount then copy upwards */
1359 a->used += b;
1361 /* top */
1362 top = a->dp + a->used - 1;
1364 /* base */
1365 bottom = a->dp + a->used - 1 - b;
1367 /* much like mp_rshd this is implemented using a sliding window
1368 * except the window goes the otherway around. Copying from
1369 * the bottom to the top. see bn_mp_rshd.c for more info.
1371 for (x = a->used - 1; x >= b; x--) {
1372 *top-- = *bottom--;
1375 /* zero the lower digits */
1376 top = a->dp;
1377 for (x = 0; x < b; x++) {
1378 *top++ = 0;
1381 return MP_OKAY;
1384 /* shift left by a certain bit count */
1385 static int mp_mul_2d (const mp_int * a, int b, mp_int * c)
1387 mp_digit d;
1388 int res;
1390 /* copy */
1391 if (a != c) {
1392 if ((res = mp_copy (a, c)) != MP_OKAY) {
1393 return res;
1397 if (c->alloc < c->used + b/DIGIT_BIT + 1) {
1398 if ((res = mp_grow (c, c->used + b / DIGIT_BIT + 1)) != MP_OKAY) {
1399 return res;
1403 /* shift by as many digits in the bit count */
1404 if (b >= DIGIT_BIT) {
1405 if ((res = mp_lshd (c, b / DIGIT_BIT)) != MP_OKAY) {
1406 return res;
1410 /* shift any bit count < DIGIT_BIT */
1411 d = (mp_digit) (b % DIGIT_BIT);
1412 if (d != 0) {
1413 register mp_digit *tmpc, shift, mask, r, rr;
1414 register int x;
1416 /* bitmask for carries */
1417 mask = (((mp_digit)1) << d) - 1;
1419 /* shift for msbs */
1420 shift = DIGIT_BIT - d;
1422 /* alias */
1423 tmpc = c->dp;
1425 /* carry */
1426 r = 0;
1427 for (x = 0; x < c->used; x++) {
1428 /* get the higher bits of the current word */
1429 rr = (*tmpc >> shift) & mask;
1431 /* shift the current word and OR in the carry */
1432 *tmpc = ((*tmpc << d) | r) & MP_MASK;
1433 ++tmpc;
1435 /* set the carry to the carry bits of the current word */
1436 r = rr;
1439 /* set final carry */
1440 if (r != 0) {
1441 c->dp[(c->used)++] = r;
1444 mp_clamp (c);
1445 return MP_OKAY;
1448 /* multiply by a digit */
1449 static int
1450 mp_mul_d (const mp_int * a, mp_digit b, mp_int * c)
1452 mp_digit u, *tmpa, *tmpc;
1453 mp_word r;
1454 int ix, res, olduse;
1456 /* make sure c is big enough to hold a*b */
1457 if (c->alloc < a->used + 1) {
1458 if ((res = mp_grow (c, a->used + 1)) != MP_OKAY) {
1459 return res;
1463 /* get the original destinations used count */
1464 olduse = c->used;
1466 /* set the sign */
1467 c->sign = a->sign;
1469 /* alias for a->dp [source] */
1470 tmpa = a->dp;
1472 /* alias for c->dp [dest] */
1473 tmpc = c->dp;
1475 /* zero carry */
1476 u = 0;
1478 /* compute columns */
1479 for (ix = 0; ix < a->used; ix++) {
1480 /* compute product and carry sum for this term */
1481 r = ((mp_word) u) + ((mp_word)*tmpa++) * ((mp_word)b);
1483 /* mask off higher bits to get a single digit */
1484 *tmpc++ = (mp_digit) (r & ((mp_word) MP_MASK));
1486 /* send carry into next iteration */
1487 u = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
1490 /* store final carry [if any] */
1491 *tmpc++ = u;
1493 /* now zero digits above the top */
1494 while (ix++ < olduse) {
1495 *tmpc++ = 0;
1498 /* set used count */
1499 c->used = a->used + 1;
1500 mp_clamp(c);
1502 return MP_OKAY;
1505 /* integer signed division.
1506 * c*b + d == a [e.g. a/b, c=quotient, d=remainder]
1507 * HAC pp.598 Algorithm 14.20
1509 * Note that the description in HAC is horribly
1510 * incomplete. For example, it doesn't consider
1511 * the case where digits are removed from 'x' in
1512 * the inner loop. It also doesn't consider the
1513 * case that y has fewer than three digits, etc..
1515 * The overall algorithm is as described as
1516 * 14.20 from HAC but fixed to treat these cases.
1518 static int mp_div (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
1520 mp_int q, x, y, t1, t2;
1521 int res, n, t, i, norm, neg;
1523 /* is divisor zero ? */
1524 if (mp_iszero (b) == 1) {
1525 return MP_VAL;
1528 /* if a < b then q=0, r = a */
1529 if (mp_cmp_mag (a, b) == MP_LT) {
1530 if (d != NULL) {
1531 res = mp_copy (a, d);
1532 } else {
1533 res = MP_OKAY;
1535 if (c != NULL) {
1536 mp_zero (c);
1538 return res;
1541 if ((res = mp_init_size (&q, a->used + 2)) != MP_OKAY) {
1542 return res;
1544 q.used = a->used + 2;
1546 if ((res = mp_init (&t1)) != MP_OKAY) {
1547 goto __Q;
1550 if ((res = mp_init (&t2)) != MP_OKAY) {
1551 goto __T1;
1554 if ((res = mp_init_copy (&x, a)) != MP_OKAY) {
1555 goto __T2;
1558 if ((res = mp_init_copy (&y, b)) != MP_OKAY) {
1559 goto __X;
1562 /* fix the sign */
1563 neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
1564 x.sign = y.sign = MP_ZPOS;
1566 /* normalize both x and y, ensure that y >= b/2, [b == 2**DIGIT_BIT] */
1567 norm = mp_count_bits(&y) % DIGIT_BIT;
1568 if (norm < DIGIT_BIT-1) {
1569 norm = (DIGIT_BIT-1) - norm;
1570 if ((res = mp_mul_2d (&x, norm, &x)) != MP_OKAY) {
1571 goto __Y;
1573 if ((res = mp_mul_2d (&y, norm, &y)) != MP_OKAY) {
1574 goto __Y;
1576 } else {
1577 norm = 0;
1580 /* note hac does 0 based, so if used==5 then its 0,1,2,3,4, e.g. use 4 */
1581 n = x.used - 1;
1582 t = y.used - 1;
1584 /* while (x >= y*b**n-t) do { q[n-t] += 1; x -= y*b**{n-t} } */
1585 if ((res = mp_lshd (&y, n - t)) != MP_OKAY) { /* y = y*b**{n-t} */
1586 goto __Y;
1589 while (mp_cmp (&x, &y) != MP_LT) {
1590 ++(q.dp[n - t]);
1591 if ((res = mp_sub (&x, &y, &x)) != MP_OKAY) {
1592 goto __Y;
1596 /* reset y by shifting it back down */
1597 mp_rshd (&y, n - t);
1599 /* step 3. for i from n down to (t + 1) */
1600 for (i = n; i >= (t + 1); i--) {
1601 if (i > x.used) {
1602 continue;
1605 /* step 3.1 if xi == yt then set q{i-t-1} to b-1,
1606 * otherwise set q{i-t-1} to (xi*b + x{i-1})/yt */
1607 if (x.dp[i] == y.dp[t]) {
1608 q.dp[i - t - 1] = ((((mp_digit)1) << DIGIT_BIT) - 1);
1609 } else {
1610 mp_word tmp;
1611 tmp = ((mp_word) x.dp[i]) << ((mp_word) DIGIT_BIT);
1612 tmp |= ((mp_word) x.dp[i - 1]);
1613 tmp /= ((mp_word) y.dp[t]);
1614 if (tmp > (mp_word) MP_MASK)
1615 tmp = MP_MASK;
1616 q.dp[i - t - 1] = (mp_digit) (tmp & (mp_word) (MP_MASK));
1619 /* while (q{i-t-1} * (yt * b + y{t-1})) >
1620 xi * b**2 + xi-1 * b + xi-2
1622 do q{i-t-1} -= 1;
1624 q.dp[i - t - 1] = (q.dp[i - t - 1] + 1) & MP_MASK;
1625 do {
1626 q.dp[i - t - 1] = (q.dp[i - t - 1] - 1) & MP_MASK;
1628 /* find left hand */
1629 mp_zero (&t1);
1630 t1.dp[0] = (t - 1 < 0) ? 0 : y.dp[t - 1];
1631 t1.dp[1] = y.dp[t];
1632 t1.used = 2;
1633 if ((res = mp_mul_d (&t1, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1634 goto __Y;
1637 /* find right hand */
1638 t2.dp[0] = (i - 2 < 0) ? 0 : x.dp[i - 2];
1639 t2.dp[1] = (i - 1 < 0) ? 0 : x.dp[i - 1];
1640 t2.dp[2] = x.dp[i];
1641 t2.used = 3;
1642 } while (mp_cmp_mag(&t1, &t2) == MP_GT);
1644 /* step 3.3 x = x - q{i-t-1} * y * b**{i-t-1} */
1645 if ((res = mp_mul_d (&y, q.dp[i - t - 1], &t1)) != MP_OKAY) {
1646 goto __Y;
1649 if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1650 goto __Y;
1653 if ((res = mp_sub (&x, &t1, &x)) != MP_OKAY) {
1654 goto __Y;
1657 /* if x < 0 then { x = x + y*b**{i-t-1}; q{i-t-1} -= 1; } */
1658 if (x.sign == MP_NEG) {
1659 if ((res = mp_copy (&y, &t1)) != MP_OKAY) {
1660 goto __Y;
1662 if ((res = mp_lshd (&t1, i - t - 1)) != MP_OKAY) {
1663 goto __Y;
1665 if ((res = mp_add (&x, &t1, &x)) != MP_OKAY) {
1666 goto __Y;
1669 q.dp[i - t - 1] = (q.dp[i - t - 1] - 1UL) & MP_MASK;
1673 /* now q is the quotient and x is the remainder
1674 * [which we have to normalize]
1677 /* get sign before writing to c */
1678 x.sign = x.used == 0 ? MP_ZPOS : a->sign;
1680 if (c != NULL) {
1681 mp_clamp (&q);
1682 mp_exch (&q, c);
1683 c->sign = neg;
1686 if (d != NULL) {
1687 mp_div_2d (&x, norm, &x, NULL);
1688 mp_exch (&x, d);
1691 res = MP_OKAY;
1693 __Y:mp_clear (&y);
1694 __X:mp_clear (&x);
1695 __T2:mp_clear (&t2);
1696 __T1:mp_clear (&t1);
1697 __Q:mp_clear (&q);
1698 return res;
1701 static int s_is_power_of_two(mp_digit b, int *p)
1703 int x;
1705 for (x = 1; x < DIGIT_BIT; x++) {
1706 if (b == (((mp_digit)1)<<x)) {
1707 *p = x;
1708 return 1;
1711 return 0;
1714 /* single digit division (based on routine from MPI) */
1715 static int mp_div_d (const mp_int * a, mp_digit b, mp_int * c, mp_digit * d)
1717 mp_int q;
1718 mp_word w;
1719 mp_digit t;
1720 int res, ix;
1722 /* cannot divide by zero */
1723 if (b == 0) {
1724 return MP_VAL;
1727 /* quick outs */
1728 if (b == 1 || mp_iszero(a) == 1) {
1729 if (d != NULL) {
1730 *d = 0;
1732 if (c != NULL) {
1733 return mp_copy(a, c);
1735 return MP_OKAY;
1738 /* power of two ? */
1739 if (s_is_power_of_two(b, &ix) == 1) {
1740 if (d != NULL) {
1741 *d = a->dp[0] & ((((mp_digit)1)<<ix) - 1);
1743 if (c != NULL) {
1744 return mp_div_2d(a, ix, c, NULL);
1746 return MP_OKAY;
1749 /* no easy answer [c'est la vie]. Just division */
1750 if ((res = mp_init_size(&q, a->used)) != MP_OKAY) {
1751 return res;
1754 q.used = a->used;
1755 q.sign = a->sign;
1756 w = 0;
1757 for (ix = a->used - 1; ix >= 0; ix--) {
1758 w = (w << ((mp_word)DIGIT_BIT)) | ((mp_word)a->dp[ix]);
1760 if (w >= b) {
1761 t = (mp_digit)(w / b);
1762 w -= ((mp_word)t) * ((mp_word)b);
1763 } else {
1764 t = 0;
1766 q.dp[ix] = t;
1769 if (d != NULL) {
1770 *d = (mp_digit)w;
1773 if (c != NULL) {
1774 mp_clamp(&q);
1775 mp_exch(&q, c);
1777 mp_clear(&q);
1779 return res;
1782 /* reduce "x" in place modulo "n" using the Diminished Radix algorithm.
1784 * Based on algorithm from the paper
1786 * "Generating Efficient Primes for Discrete Log Cryptosystems"
1787 * Chae Hoon Lim, Pil Loong Lee,
1788 * POSTECH Information Research Laboratories
1790 * The modulus must be of a special format [see manual]
1792 * Has been modified to use algorithm 7.10 from the LTM book instead
1794 * Input x must be in the range 0 <= x <= (n-1)**2
1796 static int
1797 mp_dr_reduce (mp_int * x, const mp_int * n, mp_digit k)
1799 int err, i, m;
1800 mp_word r;
1801 mp_digit mu, *tmpx1, *tmpx2;
1803 /* m = digits in modulus */
1804 m = n->used;
1806 /* ensure that "x" has at least 2m digits */
1807 if (x->alloc < m + m) {
1808 if ((err = mp_grow (x, m + m)) != MP_OKAY) {
1809 return err;
1813 /* top of loop, this is where the code resumes if
1814 * another reduction pass is required.
1816 top:
1817 /* aliases for digits */
1818 /* alias for lower half of x */
1819 tmpx1 = x->dp;
1821 /* alias for upper half of x, or x/B**m */
1822 tmpx2 = x->dp + m;
1824 /* set carry to zero */
1825 mu = 0;
1827 /* compute (x mod B**m) + k * [x/B**m] inline and inplace */
1828 for (i = 0; i < m; i++) {
1829 r = ((mp_word)*tmpx2++) * ((mp_word)k) + *tmpx1 + mu;
1830 *tmpx1++ = (mp_digit)(r & MP_MASK);
1831 mu = (mp_digit)(r >> ((mp_word)DIGIT_BIT));
1834 /* set final carry */
1835 *tmpx1++ = mu;
1837 /* zero words above m */
1838 for (i = m + 1; i < x->used; i++) {
1839 *tmpx1++ = 0;
1842 /* clamp, sub and return */
1843 mp_clamp (x);
1845 /* if x >= n then subtract and reduce again
1846 * Each successive "recursion" makes the input smaller and smaller.
1848 if (mp_cmp_mag (x, n) != MP_LT) {
1849 s_mp_sub(x, n, x);
1850 goto top;
1852 return MP_OKAY;
1855 /* sets the value of "d" required for mp_dr_reduce */
1856 static void mp_dr_setup(const mp_int *a, mp_digit *d)
1858 /* the casts are required if DIGIT_BIT is one less than
1859 * the number of bits in a mp_digit [e.g. DIGIT_BIT==31]
1861 *d = (mp_digit)((((mp_word)1) << ((mp_word)DIGIT_BIT)) -
1862 ((mp_word)a->dp[0]));
1865 /* this is a shell function that calls either the normal or Montgomery
1866 * exptmod functions. Originally the call to the montgomery code was
1867 * embedded in the normal function but that wasted a lot of stack space
1868 * for nothing (since 99% of the time the Montgomery code would be called)
1870 int mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
1872 int dr;
1874 /* modulus P must be positive */
1875 if (P->sign == MP_NEG) {
1876 return MP_VAL;
1879 /* if exponent X is negative we have to recurse */
1880 if (X->sign == MP_NEG) {
1881 mp_int tmpG, tmpX;
1882 int err;
1884 /* first compute 1/G mod P */
1885 if ((err = mp_init(&tmpG)) != MP_OKAY) {
1886 return err;
1888 if ((err = mp_invmod(G, P, &tmpG)) != MP_OKAY) {
1889 mp_clear(&tmpG);
1890 return err;
1893 /* now get |X| */
1894 if ((err = mp_init(&tmpX)) != MP_OKAY) {
1895 mp_clear(&tmpG);
1896 return err;
1898 if ((err = mp_abs(X, &tmpX)) != MP_OKAY) {
1899 mp_clear_multi(&tmpG, &tmpX, NULL);
1900 return err;
1903 /* and now compute (1/G)**|X| instead of G**X [X < 0] */
1904 err = mp_exptmod(&tmpG, &tmpX, P, Y);
1905 mp_clear_multi(&tmpG, &tmpX, NULL);
1906 return err;
1909 dr = 0;
1911 /* if the modulus is odd or dr != 0 use the fast method */
1912 if (mp_isodd (P) == 1 || dr != 0) {
1913 return mp_exptmod_fast (G, X, P, Y, dr);
1914 } else {
1915 /* otherwise use the generic Barrett reduction technique */
1916 return s_mp_exptmod (G, X, P, Y);
1920 /* computes Y == G**X mod P, HAC pp.616, Algorithm 14.85
1922 * Uses a left-to-right k-ary sliding window to compute the modular exponentiation.
1923 * The value of k changes based on the size of the exponent.
1925 * Uses Montgomery or Diminished Radix reduction [whichever appropriate]
1929 mp_exptmod_fast (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y, int redmode)
1931 mp_int M[256], res;
1932 mp_digit buf, mp;
1933 int err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
1935 /* use a pointer to the reduction algorithm. This allows us to use
1936 * one of many reduction algorithms without modding the guts of
1937 * the code with if statements everywhere.
1939 int (*redux)(mp_int*,const mp_int*,mp_digit);
1941 /* find window size */
1942 x = mp_count_bits (X);
1943 if (x <= 7) {
1944 winsize = 2;
1945 } else if (x <= 36) {
1946 winsize = 3;
1947 } else if (x <= 140) {
1948 winsize = 4;
1949 } else if (x <= 450) {
1950 winsize = 5;
1951 } else if (x <= 1303) {
1952 winsize = 6;
1953 } else if (x <= 3529) {
1954 winsize = 7;
1955 } else {
1956 winsize = 8;
1959 /* init M array */
1960 /* init first cell */
1961 if ((err = mp_init(&M[1])) != MP_OKAY) {
1962 return err;
1965 /* now init the second half of the array */
1966 for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
1967 if ((err = mp_init(&M[x])) != MP_OKAY) {
1968 for (y = 1<<(winsize-1); y < x; y++) {
1969 mp_clear (&M[y]);
1971 mp_clear(&M[1]);
1972 return err;
1976 /* determine and setup reduction code */
1977 if (redmode == 0) {
1978 /* now setup montgomery */
1979 if ((err = mp_montgomery_setup (P, &mp)) != MP_OKAY) {
1980 goto __M;
1983 /* automatically pick the comba one if available (saves quite a few calls/ifs) */
1984 if (((P->used * 2 + 1) < MP_WARRAY) &&
1985 P->used < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
1986 redux = fast_mp_montgomery_reduce;
1987 } else {
1988 /* use slower baseline Montgomery method */
1989 redux = mp_montgomery_reduce;
1991 } else if (redmode == 1) {
1992 /* setup DR reduction for moduli of the form B**k - b */
1993 mp_dr_setup(P, &mp);
1994 redux = mp_dr_reduce;
1995 } else {
1996 /* setup DR reduction for moduli of the form 2**k - b */
1997 if ((err = mp_reduce_2k_setup(P, &mp)) != MP_OKAY) {
1998 goto __M;
2000 redux = mp_reduce_2k;
2003 /* setup result */
2004 if ((err = mp_init (&res)) != MP_OKAY) {
2005 goto __M;
2008 /* create M table
2012 * The first half of the table is not computed though accept for M[0] and M[1]
2015 if (redmode == 0) {
2016 /* now we need R mod m */
2017 if ((err = mp_montgomery_calc_normalization (&res, P)) != MP_OKAY) {
2018 goto __RES;
2021 /* now set M[1] to G * R mod m */
2022 if ((err = mp_mulmod (G, &res, P, &M[1])) != MP_OKAY) {
2023 goto __RES;
2025 } else {
2026 mp_set(&res, 1);
2027 if ((err = mp_mod(G, P, &M[1])) != MP_OKAY) {
2028 goto __RES;
2032 /* compute the value at M[1<<(winsize-1)] by squaring M[1] (winsize-1) times */
2033 if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
2034 goto __RES;
2037 for (x = 0; x < (winsize - 1); x++) {
2038 if ((err = mp_sqr (&M[1 << (winsize - 1)], &M[1 << (winsize - 1)])) != MP_OKAY) {
2039 goto __RES;
2041 if ((err = redux (&M[1 << (winsize - 1)], P, mp)) != MP_OKAY) {
2042 goto __RES;
2046 /* create upper table */
2047 for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
2048 if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
2049 goto __RES;
2051 if ((err = redux (&M[x], P, mp)) != MP_OKAY) {
2052 goto __RES;
2056 /* set initial mode and bit cnt */
2057 mode = 0;
2058 bitcnt = 1;
2059 buf = 0;
2060 digidx = X->used - 1;
2061 bitcpy = 0;
2062 bitbuf = 0;
2064 for (;;) {
2065 /* grab next digit as required */
2066 if (--bitcnt == 0) {
2067 /* if digidx == -1 we are out of digits so break */
2068 if (digidx == -1) {
2069 break;
2071 /* read next digit and reset bitcnt */
2072 buf = X->dp[digidx--];
2073 bitcnt = DIGIT_BIT;
2076 /* grab the next msb from the exponent */
2077 y = (buf >> (DIGIT_BIT - 1)) & 1;
2078 buf <<= (mp_digit)1;
2080 /* if the bit is zero and mode == 0 then we ignore it
2081 * These represent the leading zero bits before the first 1 bit
2082 * in the exponent. Technically this opt is not required but it
2083 * does lower the # of trivial squaring/reductions used
2085 if (mode == 0 && y == 0) {
2086 continue;
2089 /* if the bit is zero and mode == 1 then we square */
2090 if (mode == 1 && y == 0) {
2091 if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2092 goto __RES;
2094 if ((err = redux (&res, P, mp)) != MP_OKAY) {
2095 goto __RES;
2097 continue;
2100 /* else we add it to the window */
2101 bitbuf |= (y << (winsize - ++bitcpy));
2102 mode = 2;
2104 if (bitcpy == winsize) {
2105 /* ok window is filled so square as required and multiply */
2106 /* square first */
2107 for (x = 0; x < winsize; x++) {
2108 if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2109 goto __RES;
2111 if ((err = redux (&res, P, mp)) != MP_OKAY) {
2112 goto __RES;
2116 /* then multiply */
2117 if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
2118 goto __RES;
2120 if ((err = redux (&res, P, mp)) != MP_OKAY) {
2121 goto __RES;
2124 /* empty window and reset */
2125 bitcpy = 0;
2126 bitbuf = 0;
2127 mode = 1;
2131 /* if bits remain then square/multiply */
2132 if (mode == 2 && bitcpy > 0) {
2133 /* square then multiply if the bit is set */
2134 for (x = 0; x < bitcpy; x++) {
2135 if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
2136 goto __RES;
2138 if ((err = redux (&res, P, mp)) != MP_OKAY) {
2139 goto __RES;
2142 /* get next bit of the window */
2143 bitbuf <<= 1;
2144 if ((bitbuf & (1 << winsize)) != 0) {
2145 /* then multiply */
2146 if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
2147 goto __RES;
2149 if ((err = redux (&res, P, mp)) != MP_OKAY) {
2150 goto __RES;
2156 if (redmode == 0) {
2157 /* fixup result if Montgomery reduction is used
2158 * recall that any value in a Montgomery system is
2159 * actually multiplied by R mod n. So we have
2160 * to reduce one more time to cancel out the factor
2161 * of R.
2163 if ((err = redux(&res, P, mp)) != MP_OKAY) {
2164 goto __RES;
2168 /* swap res with Y */
2169 mp_exch (&res, Y);
2170 err = MP_OKAY;
2171 __RES:mp_clear (&res);
2172 __M:
2173 mp_clear(&M[1]);
2174 for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
2175 mp_clear (&M[x]);
2177 return err;
2180 /* Greatest Common Divisor using the binary method */
2181 int mp_gcd (const mp_int * a, const mp_int * b, mp_int * c)
2183 mp_int u, v;
2184 int k, u_lsb, v_lsb, res;
2186 /* either zero than gcd is the largest */
2187 if (mp_iszero (a) == 1 && mp_iszero (b) == 0) {
2188 return mp_abs (b, c);
2190 if (mp_iszero (a) == 0 && mp_iszero (b) == 1) {
2191 return mp_abs (a, c);
2194 /* optimized. At this point if a == 0 then
2195 * b must equal zero too
2197 if (mp_iszero (a) == 1) {
2198 mp_zero(c);
2199 return MP_OKAY;
2202 /* get copies of a and b we can modify */
2203 if ((res = mp_init_copy (&u, a)) != MP_OKAY) {
2204 return res;
2207 if ((res = mp_init_copy (&v, b)) != MP_OKAY) {
2208 goto __U;
2211 /* must be positive for the remainder of the algorithm */
2212 u.sign = v.sign = MP_ZPOS;
2214 /* B1. Find the common power of two for u and v */
2215 u_lsb = mp_cnt_lsb(&u);
2216 v_lsb = mp_cnt_lsb(&v);
2217 k = MIN(u_lsb, v_lsb);
2219 if (k > 0) {
2220 /* divide the power of two out */
2221 if ((res = mp_div_2d(&u, k, &u, NULL)) != MP_OKAY) {
2222 goto __V;
2225 if ((res = mp_div_2d(&v, k, &v, NULL)) != MP_OKAY) {
2226 goto __V;
2230 /* divide any remaining factors of two out */
2231 if (u_lsb != k) {
2232 if ((res = mp_div_2d(&u, u_lsb - k, &u, NULL)) != MP_OKAY) {
2233 goto __V;
2237 if (v_lsb != k) {
2238 if ((res = mp_div_2d(&v, v_lsb - k, &v, NULL)) != MP_OKAY) {
2239 goto __V;
2243 while (mp_iszero(&v) == 0) {
2244 /* make sure v is the largest */
2245 if (mp_cmp_mag(&u, &v) == MP_GT) {
2246 /* swap u and v to make sure v is >= u */
2247 mp_exch(&u, &v);
2250 /* subtract smallest from largest */
2251 if ((res = s_mp_sub(&v, &u, &v)) != MP_OKAY) {
2252 goto __V;
2255 /* Divide out all factors of two */
2256 if ((res = mp_div_2d(&v, mp_cnt_lsb(&v), &v, NULL)) != MP_OKAY) {
2257 goto __V;
2261 /* multiply by 2**k which we divided out at the beginning */
2262 if ((res = mp_mul_2d (&u, k, c)) != MP_OKAY) {
2263 goto __V;
2265 c->sign = MP_ZPOS;
2266 res = MP_OKAY;
2267 __V:mp_clear (&u);
2268 __U:mp_clear (&v);
2269 return res;
2272 /* get the lower 32-bits of an mp_int */
2273 unsigned long mp_get_int(const mp_int * a)
2275 int i;
2276 unsigned long res;
2278 if (a->used == 0) {
2279 return 0;
2282 /* get number of digits of the lsb we have to read */
2283 i = MIN(a->used,(int)((sizeof(unsigned long)*CHAR_BIT+DIGIT_BIT-1)/DIGIT_BIT))-1;
2285 /* get most significant digit of result */
2286 res = DIGIT(a,i);
2288 while (--i >= 0) {
2289 res = (res << DIGIT_BIT) | DIGIT(a,i);
2292 /* force result to 32-bits always so it is consistent on non 32-bit platforms */
2293 return res & 0xFFFFFFFFUL;
2296 /* creates "a" then copies b into it */
2297 int mp_init_copy (mp_int * a, const mp_int * b)
2299 int res;
2301 if ((res = mp_init (a)) != MP_OKAY) {
2302 return res;
2304 return mp_copy (b, a);
2307 int mp_init_multi(mp_int *mp, ...)
2309 mp_err res = MP_OKAY; /* Assume ok until proven otherwise */
2310 int n = 0; /* Number of ok inits */
2311 mp_int* cur_arg = mp;
2312 va_list args;
2314 va_start(args, mp); /* init args to next argument from caller */
2315 while (cur_arg != NULL) {
2316 if (mp_init(cur_arg) != MP_OKAY) {
2317 /* Oops - error! Back-track and mp_clear what we already
2318 succeeded in init-ing, then return error.
2320 va_list clean_args;
2322 /* end the current list */
2323 va_end(args);
2325 /* now start cleaning up */
2326 cur_arg = mp;
2327 va_start(clean_args, mp);
2328 while (n--) {
2329 mp_clear(cur_arg);
2330 cur_arg = va_arg(clean_args, mp_int*);
2332 va_end(clean_args);
2333 res = MP_MEM;
2334 break;
2336 n++;
2337 cur_arg = va_arg(args, mp_int*);
2339 va_end(args);
2340 return res; /* Assumed ok, if error flagged above. */
2343 /* hac 14.61, pp608 */
2344 int mp_invmod (const mp_int * a, mp_int * b, mp_int * c)
2346 /* b cannot be negative */
2347 if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2348 return MP_VAL;
2351 /* if the modulus is odd we can use a faster routine instead */
2352 if (mp_isodd (b) == 1) {
2353 return fast_mp_invmod (a, b, c);
2356 return mp_invmod_slow(a, b, c);
2359 /* hac 14.61, pp608 */
2360 int mp_invmod_slow (const mp_int * a, mp_int * b, mp_int * c)
2362 mp_int x, y, u, v, A, B, C, D;
2363 int res;
2365 /* b cannot be negative */
2366 if (b->sign == MP_NEG || mp_iszero(b) == 1) {
2367 return MP_VAL;
2370 /* init temps */
2371 if ((res = mp_init_multi(&x, &y, &u, &v,
2372 &A, &B, &C, &D, NULL)) != MP_OKAY) {
2373 return res;
2376 /* x = a, y = b */
2377 if ((res = mp_copy (a, &x)) != MP_OKAY) {
2378 goto __ERR;
2380 if ((res = mp_copy (b, &y)) != MP_OKAY) {
2381 goto __ERR;
2384 /* 2. [modified] if x,y are both even then return an error! */
2385 if (mp_iseven (&x) == 1 && mp_iseven (&y) == 1) {
2386 res = MP_VAL;
2387 goto __ERR;
2390 /* 3. u=x, v=y, A=1, B=0, C=0,D=1 */
2391 if ((res = mp_copy (&x, &u)) != MP_OKAY) {
2392 goto __ERR;
2394 if ((res = mp_copy (&y, &v)) != MP_OKAY) {
2395 goto __ERR;
2397 mp_set (&A, 1);
2398 mp_set (&D, 1);
2400 top:
2401 /* 4. while u is even do */
2402 while (mp_iseven (&u) == 1) {
2403 /* 4.1 u = u/2 */
2404 if ((res = mp_div_2 (&u, &u)) != MP_OKAY) {
2405 goto __ERR;
2407 /* 4.2 if A or B is odd then */
2408 if (mp_isodd (&A) == 1 || mp_isodd (&B) == 1) {
2409 /* A = (A+y)/2, B = (B-x)/2 */
2410 if ((res = mp_add (&A, &y, &A)) != MP_OKAY) {
2411 goto __ERR;
2413 if ((res = mp_sub (&B, &x, &B)) != MP_OKAY) {
2414 goto __ERR;
2417 /* A = A/2, B = B/2 */
2418 if ((res = mp_div_2 (&A, &A)) != MP_OKAY) {
2419 goto __ERR;
2421 if ((res = mp_div_2 (&B, &B)) != MP_OKAY) {
2422 goto __ERR;
2426 /* 5. while v is even do */
2427 while (mp_iseven (&v) == 1) {
2428 /* 5.1 v = v/2 */
2429 if ((res = mp_div_2 (&v, &v)) != MP_OKAY) {
2430 goto __ERR;
2432 /* 5.2 if C or D is odd then */
2433 if (mp_isodd (&C) == 1 || mp_isodd (&D) == 1) {
2434 /* C = (C+y)/2, D = (D-x)/2 */
2435 if ((res = mp_add (&C, &y, &C)) != MP_OKAY) {
2436 goto __ERR;
2438 if ((res = mp_sub (&D, &x, &D)) != MP_OKAY) {
2439 goto __ERR;
2442 /* C = C/2, D = D/2 */
2443 if ((res = mp_div_2 (&C, &C)) != MP_OKAY) {
2444 goto __ERR;
2446 if ((res = mp_div_2 (&D, &D)) != MP_OKAY) {
2447 goto __ERR;
2451 /* 6. if u >= v then */
2452 if (mp_cmp (&u, &v) != MP_LT) {
2453 /* u = u - v, A = A - C, B = B - D */
2454 if ((res = mp_sub (&u, &v, &u)) != MP_OKAY) {
2455 goto __ERR;
2458 if ((res = mp_sub (&A, &C, &A)) != MP_OKAY) {
2459 goto __ERR;
2462 if ((res = mp_sub (&B, &D, &B)) != MP_OKAY) {
2463 goto __ERR;
2465 } else {
2466 /* v - v - u, C = C - A, D = D - B */
2467 if ((res = mp_sub (&v, &u, &v)) != MP_OKAY) {
2468 goto __ERR;
2471 if ((res = mp_sub (&C, &A, &C)) != MP_OKAY) {
2472 goto __ERR;
2475 if ((res = mp_sub (&D, &B, &D)) != MP_OKAY) {
2476 goto __ERR;
2480 /* if not zero goto step 4 */
2481 if (mp_iszero (&u) == 0)
2482 goto top;
2484 /* now a = C, b = D, gcd == g*v */
2486 /* if v != 1 then there is no inverse */
2487 if (mp_cmp_d (&v, 1) != MP_EQ) {
2488 res = MP_VAL;
2489 goto __ERR;
2492 /* if its too low */
2493 while (mp_cmp_d(&C, 0) == MP_LT) {
2494 if ((res = mp_add(&C, b, &C)) != MP_OKAY) {
2495 goto __ERR;
2499 /* too big */
2500 while (mp_cmp_mag(&C, b) != MP_LT) {
2501 if ((res = mp_sub(&C, b, &C)) != MP_OKAY) {
2502 goto __ERR;
2506 /* C is now the inverse */
2507 mp_exch (&C, c);
2508 res = MP_OKAY;
2509 __ERR:mp_clear_multi (&x, &y, &u, &v, &A, &B, &C, &D, NULL);
2510 return res;
2513 /* c = |a| * |b| using Karatsuba Multiplication using
2514 * three half size multiplications
2516 * Let B represent the radix [e.g. 2**DIGIT_BIT] and
2517 * let n represent half of the number of digits in
2518 * the min(a,b)
2520 * a = a1 * B**n + a0
2521 * b = b1 * B**n + b0
2523 * Then, a * b =>
2524 a1b1 * B**2n + ((a1 - a0)(b1 - b0) + a0b0 + a1b1) * B + a0b0
2526 * Note that a1b1 and a0b0 are used twice and only need to be
2527 * computed once. So in total three half size (half # of
2528 * digit) multiplications are performed, a0b0, a1b1 and
2529 * (a1-b1)(a0-b0)
2531 * Note that a multiplication of half the digits requires
2532 * 1/4th the number of single precision multiplications so in
2533 * total after one call 25% of the single precision multiplications
2534 * are saved. Note also that the call to mp_mul can end up back
2535 * in this function if the a0, a1, b0, or b1 are above the threshold.
2536 * This is known as divide-and-conquer and leads to the famous
2537 * O(N**lg(3)) or O(N**1.584) work which is asymptotically lower than
2538 * the standard O(N**2) that the baseline/comba methods use.
2539 * Generally though the overhead of this method doesn't pay off
2540 * until a certain size (N ~ 80) is reached.
2542 int mp_karatsuba_mul (const mp_int * a, const mp_int * b, mp_int * c)
2544 mp_int x0, x1, y0, y1, t1, x0y0, x1y1;
2545 int B, err;
2547 /* default the return code to an error */
2548 err = MP_MEM;
2550 /* min # of digits */
2551 B = MIN (a->used, b->used);
2553 /* now divide in two */
2554 B = B >> 1;
2556 /* init copy all the temps */
2557 if (mp_init_size (&x0, B) != MP_OKAY)
2558 goto ERR;
2559 if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2560 goto X0;
2561 if (mp_init_size (&y0, B) != MP_OKAY)
2562 goto X1;
2563 if (mp_init_size (&y1, b->used - B) != MP_OKAY)
2564 goto Y0;
2566 /* init temps */
2567 if (mp_init_size (&t1, B * 2) != MP_OKAY)
2568 goto Y1;
2569 if (mp_init_size (&x0y0, B * 2) != MP_OKAY)
2570 goto T1;
2571 if (mp_init_size (&x1y1, B * 2) != MP_OKAY)
2572 goto X0Y0;
2574 /* now shift the digits */
2575 x0.used = y0.used = B;
2576 x1.used = a->used - B;
2577 y1.used = b->used - B;
2580 register int x;
2581 register mp_digit *tmpa, *tmpb, *tmpx, *tmpy;
2583 /* we copy the digits directly instead of using higher level functions
2584 * since we also need to shift the digits
2586 tmpa = a->dp;
2587 tmpb = b->dp;
2589 tmpx = x0.dp;
2590 tmpy = y0.dp;
2591 for (x = 0; x < B; x++) {
2592 *tmpx++ = *tmpa++;
2593 *tmpy++ = *tmpb++;
2596 tmpx = x1.dp;
2597 for (x = B; x < a->used; x++) {
2598 *tmpx++ = *tmpa++;
2601 tmpy = y1.dp;
2602 for (x = B; x < b->used; x++) {
2603 *tmpy++ = *tmpb++;
2607 /* only need to clamp the lower words since by definition the
2608 * upper words x1/y1 must have a known number of digits
2610 mp_clamp (&x0);
2611 mp_clamp (&y0);
2613 /* now calc the products x0y0 and x1y1 */
2614 /* after this x0 is no longer required, free temp [x0==t2]! */
2615 if (mp_mul (&x0, &y0, &x0y0) != MP_OKAY)
2616 goto X1Y1; /* x0y0 = x0*y0 */
2617 if (mp_mul (&x1, &y1, &x1y1) != MP_OKAY)
2618 goto X1Y1; /* x1y1 = x1*y1 */
2620 /* now calc x1-x0 and y1-y0 */
2621 if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2622 goto X1Y1; /* t1 = x1 - x0 */
2623 if (mp_sub (&y1, &y0, &x0) != MP_OKAY)
2624 goto X1Y1; /* t2 = y1 - y0 */
2625 if (mp_mul (&t1, &x0, &t1) != MP_OKAY)
2626 goto X1Y1; /* t1 = (x1 - x0) * (y1 - y0) */
2628 /* add x0y0 */
2629 if (mp_add (&x0y0, &x1y1, &x0) != MP_OKAY)
2630 goto X1Y1; /* t2 = x0y0 + x1y1 */
2631 if (mp_sub (&x0, &t1, &t1) != MP_OKAY)
2632 goto X1Y1; /* t1 = x0y0 + x1y1 - (x1-x0)*(y1-y0) */
2634 /* shift by B */
2635 if (mp_lshd (&t1, B) != MP_OKAY)
2636 goto X1Y1; /* t1 = (x0y0 + x1y1 - (x1-x0)*(y1-y0))<<B */
2637 if (mp_lshd (&x1y1, B * 2) != MP_OKAY)
2638 goto X1Y1; /* x1y1 = x1y1 << 2*B */
2640 if (mp_add (&x0y0, &t1, &t1) != MP_OKAY)
2641 goto X1Y1; /* t1 = x0y0 + t1 */
2642 if (mp_add (&t1, &x1y1, c) != MP_OKAY)
2643 goto X1Y1; /* t1 = x0y0 + t1 + x1y1 */
2645 /* Algorithm succeeded set the return code to MP_OKAY */
2646 err = MP_OKAY;
2648 X1Y1:mp_clear (&x1y1);
2649 X0Y0:mp_clear (&x0y0);
2650 T1:mp_clear (&t1);
2651 Y1:mp_clear (&y1);
2652 Y0:mp_clear (&y0);
2653 X1:mp_clear (&x1);
2654 X0:mp_clear (&x0);
2655 ERR:
2656 return err;
2659 /* Karatsuba squaring, computes b = a*a using three
2660 * half size squarings
2662 * See comments of karatsuba_mul for details. It
2663 * is essentially the same algorithm but merely
2664 * tuned to perform recursive squarings.
2666 int mp_karatsuba_sqr (const mp_int * a, mp_int * b)
2668 mp_int x0, x1, t1, t2, x0x0, x1x1;
2669 int B, err;
2671 err = MP_MEM;
2673 /* min # of digits */
2674 B = a->used;
2676 /* now divide in two */
2677 B = B >> 1;
2679 /* init copy all the temps */
2680 if (mp_init_size (&x0, B) != MP_OKAY)
2681 goto ERR;
2682 if (mp_init_size (&x1, a->used - B) != MP_OKAY)
2683 goto X0;
2685 /* init temps */
2686 if (mp_init_size (&t1, a->used * 2) != MP_OKAY)
2687 goto X1;
2688 if (mp_init_size (&t2, a->used * 2) != MP_OKAY)
2689 goto T1;
2690 if (mp_init_size (&x0x0, B * 2) != MP_OKAY)
2691 goto T2;
2692 if (mp_init_size (&x1x1, (a->used - B) * 2) != MP_OKAY)
2693 goto X0X0;
2696 register int x;
2697 register mp_digit *dst, *src;
2699 src = a->dp;
2701 /* now shift the digits */
2702 dst = x0.dp;
2703 for (x = 0; x < B; x++) {
2704 *dst++ = *src++;
2707 dst = x1.dp;
2708 for (x = B; x < a->used; x++) {
2709 *dst++ = *src++;
2713 x0.used = B;
2714 x1.used = a->used - B;
2716 mp_clamp (&x0);
2718 /* now calc the products x0*x0 and x1*x1 */
2719 if (mp_sqr (&x0, &x0x0) != MP_OKAY)
2720 goto X1X1; /* x0x0 = x0*x0 */
2721 if (mp_sqr (&x1, &x1x1) != MP_OKAY)
2722 goto X1X1; /* x1x1 = x1*x1 */
2724 /* now calc (x1-x0)**2 */
2725 if (mp_sub (&x1, &x0, &t1) != MP_OKAY)
2726 goto X1X1; /* t1 = x1 - x0 */
2727 if (mp_sqr (&t1, &t1) != MP_OKAY)
2728 goto X1X1; /* t1 = (x1 - x0) * (x1 - x0) */
2730 /* add x0y0 */
2731 if (s_mp_add (&x0x0, &x1x1, &t2) != MP_OKAY)
2732 goto X1X1; /* t2 = x0x0 + x1x1 */
2733 if (mp_sub (&t2, &t1, &t1) != MP_OKAY)
2734 goto X1X1; /* t1 = x0x0 + x1x1 - (x1-x0)*(x1-x0) */
2736 /* shift by B */
2737 if (mp_lshd (&t1, B) != MP_OKAY)
2738 goto X1X1; /* t1 = (x0x0 + x1x1 - (x1-x0)*(x1-x0))<<B */
2739 if (mp_lshd (&x1x1, B * 2) != MP_OKAY)
2740 goto X1X1; /* x1x1 = x1x1 << 2*B */
2742 if (mp_add (&x0x0, &t1, &t1) != MP_OKAY)
2743 goto X1X1; /* t1 = x0x0 + t1 */
2744 if (mp_add (&t1, &x1x1, b) != MP_OKAY)
2745 goto X1X1; /* t1 = x0x0 + t1 + x1x1 */
2747 err = MP_OKAY;
2749 X1X1:mp_clear (&x1x1);
2750 X0X0:mp_clear (&x0x0);
2751 T2:mp_clear (&t2);
2752 T1:mp_clear (&t1);
2753 X1:mp_clear (&x1);
2754 X0:mp_clear (&x0);
2755 ERR:
2756 return err;
2759 /* computes least common multiple as |a*b|/(a, b) */
2760 int mp_lcm (const mp_int * a, const mp_int * b, mp_int * c)
2762 int res;
2763 mp_int t1, t2;
2766 if ((res = mp_init_multi (&t1, &t2, NULL)) != MP_OKAY) {
2767 return res;
2770 /* t1 = get the GCD of the two inputs */
2771 if ((res = mp_gcd (a, b, &t1)) != MP_OKAY) {
2772 goto __T;
2775 /* divide the smallest by the GCD */
2776 if (mp_cmp_mag(a, b) == MP_LT) {
2777 /* store quotient in t2 such that t2 * b is the LCM */
2778 if ((res = mp_div(a, &t1, &t2, NULL)) != MP_OKAY) {
2779 goto __T;
2781 res = mp_mul(b, &t2, c);
2782 } else {
2783 /* store quotient in t2 such that t2 * a is the LCM */
2784 if ((res = mp_div(b, &t1, &t2, NULL)) != MP_OKAY) {
2785 goto __T;
2787 res = mp_mul(a, &t2, c);
2790 /* fix the sign to positive */
2791 c->sign = MP_ZPOS;
2793 __T:
2794 mp_clear_multi (&t1, &t2, NULL);
2795 return res;
2798 /* c = a mod b, 0 <= c < b */
2800 mp_mod (const mp_int * a, mp_int * b, mp_int * c)
2802 mp_int t;
2803 int res;
2805 if ((res = mp_init (&t)) != MP_OKAY) {
2806 return res;
2809 if ((res = mp_div (a, b, NULL, &t)) != MP_OKAY) {
2810 mp_clear (&t);
2811 return res;
2814 if (t.sign != b->sign) {
2815 res = mp_add (b, &t, c);
2816 } else {
2817 res = MP_OKAY;
2818 mp_exch (&t, c);
2821 mp_clear (&t);
2822 return res;
2825 static int
2826 mp_mod_d (const mp_int * a, mp_digit b, mp_digit * c)
2828 return mp_div_d(a, b, NULL, c);
2831 /* b = a*2 */
2832 static int mp_mul_2(const mp_int * a, mp_int * b)
2834 int x, res, oldused;
2836 /* grow to accommodate result */
2837 if (b->alloc < a->used + 1) {
2838 if ((res = mp_grow (b, a->used + 1)) != MP_OKAY) {
2839 return res;
2843 oldused = b->used;
2844 b->used = a->used;
2847 register mp_digit r, rr, *tmpa, *tmpb;
2849 /* alias for source */
2850 tmpa = a->dp;
2852 /* alias for dest */
2853 tmpb = b->dp;
2855 /* carry */
2856 r = 0;
2857 for (x = 0; x < a->used; x++) {
2859 /* get what will be the *next* carry bit from the
2860 * MSB of the current digit
2862 rr = *tmpa >> ((mp_digit)(DIGIT_BIT - 1));
2864 /* now shift up this digit, add in the carry [from the previous] */
2865 *tmpb++ = ((*tmpa++ << ((mp_digit)1)) | r) & MP_MASK;
2867 /* copy the carry that would be from the source
2868 * digit into the next iteration
2870 r = rr;
2873 /* new leading digit? */
2874 if (r != 0) {
2875 /* add a MSB which is always 1 at this point */
2876 *tmpb = 1;
2877 ++(b->used);
2880 /* now zero any excess digits on the destination
2881 * that we didn't write to
2883 tmpb = b->dp + b->used;
2884 for (x = b->used; x < oldused; x++) {
2885 *tmpb++ = 0;
2888 b->sign = a->sign;
2889 return MP_OKAY;
2893 * shifts with subtractions when the result is greater than b.
2895 * The method is slightly modified to shift B unconditionally up to just under
2896 * the leading bit of b. This saves a lot of multiple precision shifting.
2898 int mp_montgomery_calc_normalization (mp_int * a, const mp_int * b)
2900 int x, bits, res;
2902 /* how many bits of last digit does b use */
2903 bits = mp_count_bits (b) % DIGIT_BIT;
2906 if (b->used > 1) {
2907 if ((res = mp_2expt (a, (b->used - 1) * DIGIT_BIT + bits - 1)) != MP_OKAY) {
2908 return res;
2910 } else {
2911 mp_set(a, 1);
2912 bits = 1;
2916 /* now compute C = A * B mod b */
2917 for (x = bits - 1; x < DIGIT_BIT; x++) {
2918 if ((res = mp_mul_2 (a, a)) != MP_OKAY) {
2919 return res;
2921 if (mp_cmp_mag (a, b) != MP_LT) {
2922 if ((res = s_mp_sub (a, b, a)) != MP_OKAY) {
2923 return res;
2928 return MP_OKAY;
2931 /* computes xR**-1 == x (mod N) via Montgomery Reduction */
2933 mp_montgomery_reduce (mp_int * x, const mp_int * n, mp_digit rho)
2935 int ix, res, digs;
2936 mp_digit mu;
2938 /* can the fast reduction [comba] method be used?
2940 * Note that unlike in mul you're safely allowed *less*
2941 * than the available columns [255 per default] since carries
2942 * are fixed up in the inner loop.
2944 digs = n->used * 2 + 1;
2945 if ((digs < MP_WARRAY) &&
2946 n->used <
2947 (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
2948 return fast_mp_montgomery_reduce (x, n, rho);
2951 /* grow the input as required */
2952 if (x->alloc < digs) {
2953 if ((res = mp_grow (x, digs)) != MP_OKAY) {
2954 return res;
2957 x->used = digs;
2959 for (ix = 0; ix < n->used; ix++) {
2960 /* mu = ai * rho mod b
2962 * The value of rho must be precalculated via
2963 * montgomery_setup() such that
2964 * it equals -1/n0 mod b this allows the
2965 * following inner loop to reduce the
2966 * input one digit at a time
2968 mu = (mp_digit) (((mp_word)x->dp[ix]) * ((mp_word)rho) & MP_MASK);
2970 /* a = a + mu * m * b**i */
2972 register int iy;
2973 register mp_digit *tmpn, *tmpx, u;
2974 register mp_word r;
2976 /* alias for digits of the modulus */
2977 tmpn = n->dp;
2979 /* alias for the digits of x [the input] */
2980 tmpx = x->dp + ix;
2982 /* set the carry to zero */
2983 u = 0;
2985 /* Multiply and add in place */
2986 for (iy = 0; iy < n->used; iy++) {
2987 /* compute product and sum */
2988 r = ((mp_word)mu) * ((mp_word)*tmpn++) +
2989 ((mp_word) u) + ((mp_word) * tmpx);
2991 /* get carry */
2992 u = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
2994 /* fix digit */
2995 *tmpx++ = (mp_digit)(r & ((mp_word) MP_MASK));
2997 /* At this point the ix'th digit of x should be zero */
3000 /* propagate carries upwards as required*/
3001 while (u) {
3002 *tmpx += u;
3003 u = *tmpx >> DIGIT_BIT;
3004 *tmpx++ &= MP_MASK;
3009 /* at this point the n.used'th least
3010 * significant digits of x are all zero
3011 * which means we can shift x to the
3012 * right by n.used digits and the
3013 * residue is unchanged.
3016 /* x = x/b**n.used */
3017 mp_clamp(x);
3018 mp_rshd (x, n->used);
3020 /* if x >= n then x = x - n */
3021 if (mp_cmp_mag (x, n) != MP_LT) {
3022 return s_mp_sub (x, n, x);
3025 return MP_OKAY;
3028 /* setups the montgomery reduction stuff */
3030 mp_montgomery_setup (const mp_int * n, mp_digit * rho)
3032 mp_digit x, b;
3034 /* fast inversion mod 2**k
3036 * Based on the fact that
3038 * XA = 1 (mod 2**n) => (X(2-XA)) A = 1 (mod 2**2n)
3039 * => 2*X*A - X*X*A*A = 1
3040 * => 2*(1) - (1) = 1
3042 b = n->dp[0];
3044 if ((b & 1) == 0) {
3045 return MP_VAL;
3048 x = (((b + 2) & 4) << 1) + b; /* here x*a==1 mod 2**4 */
3049 x *= 2 - b * x; /* here x*a==1 mod 2**8 */
3050 x *= 2 - b * x; /* here x*a==1 mod 2**16 */
3051 x *= 2 - b * x; /* here x*a==1 mod 2**32 */
3053 /* rho = -1/m mod b */
3054 *rho = (((mp_word)1 << ((mp_word) DIGIT_BIT)) - x) & MP_MASK;
3056 return MP_OKAY;
3059 /* high level multiplication (handles sign) */
3060 int mp_mul (const mp_int * a, const mp_int * b, mp_int * c)
3062 int res, neg;
3063 neg = (a->sign == b->sign) ? MP_ZPOS : MP_NEG;
3065 /* use Karatsuba? */
3066 if (MIN (a->used, b->used) >= KARATSUBA_MUL_CUTOFF) {
3067 res = mp_karatsuba_mul (a, b, c);
3068 } else
3070 /* can we use the fast multiplier?
3072 * The fast multiplier can be used if the output will
3073 * have less than MP_WARRAY digits and the number of
3074 * digits won't affect carry propagation
3076 int digs = a->used + b->used + 1;
3078 if ((digs < MP_WARRAY) &&
3079 MIN(a->used, b->used) <=
3080 (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
3081 res = fast_s_mp_mul_digs (a, b, c, digs);
3082 } else
3083 res = s_mp_mul (a, b, c); /* uses s_mp_mul_digs */
3085 c->sign = (c->used > 0) ? neg : MP_ZPOS;
3086 return res;
3089 /* d = a * b (mod c) */
3091 mp_mulmod (const mp_int * a, const mp_int * b, mp_int * c, mp_int * d)
3093 int res;
3094 mp_int t;
3096 if ((res = mp_init (&t)) != MP_OKAY) {
3097 return res;
3100 if ((res = mp_mul (a, b, &t)) != MP_OKAY) {
3101 mp_clear (&t);
3102 return res;
3104 res = mp_mod (&t, c, d);
3105 mp_clear (&t);
3106 return res;
3109 /* table of first PRIME_SIZE primes */
3110 static const mp_digit __prime_tab[] = {
3111 0x0002, 0x0003, 0x0005, 0x0007, 0x000B, 0x000D, 0x0011, 0x0013,
3112 0x0017, 0x001D, 0x001F, 0x0025, 0x0029, 0x002B, 0x002F, 0x0035,
3113 0x003B, 0x003D, 0x0043, 0x0047, 0x0049, 0x004F, 0x0053, 0x0059,
3114 0x0061, 0x0065, 0x0067, 0x006B, 0x006D, 0x0071, 0x007F, 0x0083,
3115 0x0089, 0x008B, 0x0095, 0x0097, 0x009D, 0x00A3, 0x00A7, 0x00AD,
3116 0x00B3, 0x00B5, 0x00BF, 0x00C1, 0x00C5, 0x00C7, 0x00D3, 0x00DF,
3117 0x00E3, 0x00E5, 0x00E9, 0x00EF, 0x00F1, 0x00FB, 0x0101, 0x0107,
3118 0x010D, 0x010F, 0x0115, 0x0119, 0x011B, 0x0125, 0x0133, 0x0137,
3120 0x0139, 0x013D, 0x014B, 0x0151, 0x015B, 0x015D, 0x0161, 0x0167,
3121 0x016F, 0x0175, 0x017B, 0x017F, 0x0185, 0x018D, 0x0191, 0x0199,
3122 0x01A3, 0x01A5, 0x01AF, 0x01B1, 0x01B7, 0x01BB, 0x01C1, 0x01C9,
3123 0x01CD, 0x01CF, 0x01D3, 0x01DF, 0x01E7, 0x01EB, 0x01F3, 0x01F7,
3124 0x01FD, 0x0209, 0x020B, 0x021D, 0x0223, 0x022D, 0x0233, 0x0239,
3125 0x023B, 0x0241, 0x024B, 0x0251, 0x0257, 0x0259, 0x025F, 0x0265,
3126 0x0269, 0x026B, 0x0277, 0x0281, 0x0283, 0x0287, 0x028D, 0x0293,
3127 0x0295, 0x02A1, 0x02A5, 0x02AB, 0x02B3, 0x02BD, 0x02C5, 0x02CF,
3129 0x02D7, 0x02DD, 0x02E3, 0x02E7, 0x02EF, 0x02F5, 0x02F9, 0x0301,
3130 0x0305, 0x0313, 0x031D, 0x0329, 0x032B, 0x0335, 0x0337, 0x033B,
3131 0x033D, 0x0347, 0x0355, 0x0359, 0x035B, 0x035F, 0x036D, 0x0371,
3132 0x0373, 0x0377, 0x038B, 0x038F, 0x0397, 0x03A1, 0x03A9, 0x03AD,
3133 0x03B3, 0x03B9, 0x03C7, 0x03CB, 0x03D1, 0x03D7, 0x03DF, 0x03E5,
3134 0x03F1, 0x03F5, 0x03FB, 0x03FD, 0x0407, 0x0409, 0x040F, 0x0419,
3135 0x041B, 0x0425, 0x0427, 0x042D, 0x043F, 0x0443, 0x0445, 0x0449,
3136 0x044F, 0x0455, 0x045D, 0x0463, 0x0469, 0x047F, 0x0481, 0x048B,
3138 0x0493, 0x049D, 0x04A3, 0x04A9, 0x04B1, 0x04BD, 0x04C1, 0x04C7,
3139 0x04CD, 0x04CF, 0x04D5, 0x04E1, 0x04EB, 0x04FD, 0x04FF, 0x0503,
3140 0x0509, 0x050B, 0x0511, 0x0515, 0x0517, 0x051B, 0x0527, 0x0529,
3141 0x052F, 0x0551, 0x0557, 0x055D, 0x0565, 0x0577, 0x0581, 0x058F,
3142 0x0593, 0x0595, 0x0599, 0x059F, 0x05A7, 0x05AB, 0x05AD, 0x05B3,
3143 0x05BF, 0x05C9, 0x05CB, 0x05CF, 0x05D1, 0x05D5, 0x05DB, 0x05E7,
3144 0x05F3, 0x05FB, 0x0607, 0x060D, 0x0611, 0x0617, 0x061F, 0x0623,
3145 0x062B, 0x062F, 0x063D, 0x0641, 0x0647, 0x0649, 0x064D, 0x0653
3148 /* determines if an integers is divisible by one
3149 * of the first PRIME_SIZE primes or not
3151 * sets result to 0 if not, 1 if yes
3153 static int mp_prime_is_divisible (const mp_int * a, int *result)
3155 int err, ix;
3156 mp_digit res;
3158 /* default to not */
3159 *result = MP_NO;
3161 for (ix = 0; ix < PRIME_SIZE; ix++) {
3162 /* what is a mod __prime_tab[ix] */
3163 if ((err = mp_mod_d (a, __prime_tab[ix], &res)) != MP_OKAY) {
3164 return err;
3167 /* is the residue zero? */
3168 if (res == 0) {
3169 *result = MP_YES;
3170 return MP_OKAY;
3174 return MP_OKAY;
3177 /* Miller-Rabin test of "a" to the base of "b" as described in
3178 * HAC pp. 139 Algorithm 4.24
3180 * Sets result to 0 if definitely composite or 1 if probably prime.
3181 * Randomly the chance of error is no more than 1/4 and often
3182 * very much lower.
3184 static int mp_prime_miller_rabin (mp_int * a, const mp_int * b, int *result)
3186 mp_int n1, y, r;
3187 int s, j, err;
3189 /* default */
3190 *result = MP_NO;
3192 /* ensure b > 1 */
3193 if (mp_cmp_d(b, 1) != MP_GT) {
3194 return MP_VAL;
3197 /* get n1 = a - 1 */
3198 if ((err = mp_init_copy (&n1, a)) != MP_OKAY) {
3199 return err;
3201 if ((err = mp_sub_d (&n1, 1, &n1)) != MP_OKAY) {
3202 goto __N1;
3205 /* set 2**s * r = n1 */
3206 if ((err = mp_init_copy (&r, &n1)) != MP_OKAY) {
3207 goto __N1;
3210 /* count the number of least significant bits
3211 * which are zero
3213 s = mp_cnt_lsb(&r);
3215 /* now divide n - 1 by 2**s */
3216 if ((err = mp_div_2d (&r, s, &r, NULL)) != MP_OKAY) {
3217 goto __R;
3220 /* compute y = b**r mod a */
3221 if ((err = mp_init (&y)) != MP_OKAY) {
3222 goto __R;
3224 if ((err = mp_exptmod (b, &r, a, &y)) != MP_OKAY) {
3225 goto __Y;
3228 /* if y != 1 and y != n1 do */
3229 if (mp_cmp_d (&y, 1) != MP_EQ && mp_cmp (&y, &n1) != MP_EQ) {
3230 j = 1;
3231 /* while j <= s-1 and y != n1 */
3232 while ((j <= (s - 1)) && mp_cmp (&y, &n1) != MP_EQ) {
3233 if ((err = mp_sqrmod (&y, a, &y)) != MP_OKAY) {
3234 goto __Y;
3237 /* if y == 1 then composite */
3238 if (mp_cmp_d (&y, 1) == MP_EQ) {
3239 goto __Y;
3242 ++j;
3245 /* if y != n1 then composite */
3246 if (mp_cmp (&y, &n1) != MP_EQ) {
3247 goto __Y;
3251 /* probably prime now */
3252 *result = MP_YES;
3253 __Y:mp_clear (&y);
3254 __R:mp_clear (&r);
3255 __N1:mp_clear (&n1);
3256 return err;
3259 /* performs a variable number of rounds of Miller-Rabin
3261 * Probability of error after t rounds is no more than
3264 * Sets result to 1 if probably prime, 0 otherwise
3266 static int mp_prime_is_prime (mp_int * a, int t, int *result)
3268 mp_int b;
3269 int ix, err, res;
3271 /* default to no */
3272 *result = MP_NO;
3274 /* valid value of t? */
3275 if (t <= 0 || t > PRIME_SIZE) {
3276 return MP_VAL;
3279 /* is the input equal to one of the primes in the table? */
3280 for (ix = 0; ix < PRIME_SIZE; ix++) {
3281 if (mp_cmp_d(a, __prime_tab[ix]) == MP_EQ) {
3282 *result = 1;
3283 return MP_OKAY;
3287 /* first perform trial division */
3288 if ((err = mp_prime_is_divisible (a, &res)) != MP_OKAY) {
3289 return err;
3292 /* return if it was trivially divisible */
3293 if (res == MP_YES) {
3294 return MP_OKAY;
3297 /* now perform the miller-rabin rounds */
3298 if ((err = mp_init (&b)) != MP_OKAY) {
3299 return err;
3302 for (ix = 0; ix < t; ix++) {
3303 /* set the prime */
3304 mp_set (&b, __prime_tab[ix]);
3306 if ((err = mp_prime_miller_rabin (a, &b, &res)) != MP_OKAY) {
3307 goto __B;
3310 if (res == MP_NO) {
3311 goto __B;
3315 /* passed the test */
3316 *result = MP_YES;
3317 __B:mp_clear (&b);
3318 return err;
/* Bit sizes (k) and the number of Miller-Rabin trials (t) used for them. */
static const struct {
  int k, t;
} sizes[] = {
  {   128,    28 },
  {   256,    16 },
  {   384,    10 },
  {   512,     7 },
  {   640,     6 },
  {   768,     5 },
  {   896,     4 },
  {  1024,     4 }
};

/* Returns the number of Miller-Rabin trials for a given bit size.
 *
 * An exact match returns the table entry.  A size that falls between
 * two entries returns the trial count of the next smaller entry (or of
 * the first entry when the size is below it).  A size above the last
 * entry returns the last trial count plus one.
 */
int mp_prime_rabin_miller_trials(int size)
{
  const int count = (int)(sizeof(sizes)/(sizeof(sizes[0])));
  int idx;

  for (idx = 0; idx < count; idx++) {
    if (sizes[idx].k == size) {
      return sizes[idx].t;
    }
    if (sizes[idx].k > size) {
      return sizes[(idx == 0) ? 0 : idx - 1].t;
    }
  }

  return sizes[count - 1].t + 1;
}
3349 /* makes a truly random prime of a given size (bits),
3351 * Flags are as follows:
3353 * LTM_PRIME_BBS - make prime congruent to 3 mod 4
3354 * LTM_PRIME_SAFE - make sure (p-1)/2 is prime as well (implies LTM_PRIME_BBS)
3355 * LTM_PRIME_2MSB_OFF - make the 2nd highest bit zero
3356 * LTM_PRIME_2MSB_ON - make the 2nd highest bit one
3358 * You have to supply a callback which fills in a buffer with random bytes. "dat" is a parameter you can
3359 * have passed to the callback (e.g. a state or something). This function doesn't use "dat" itself
3360 * so it can be NULL
3364 /* This is possibly the mother of all prime generation functions, muahahahahaha! */
3365 int mp_prime_random_ex(mp_int *a, int t, int size, int flags, ltm_prime_callback cb, void *dat)
3367 unsigned char *tmp, maskAND, maskOR_msb, maskOR_lsb;
3368 int res, err, bsize, maskOR_msb_offset;
3370 /* sanity check the input */
3371 if (size <= 1 || t <= 0) {
3372 return MP_VAL;
3375 /* LTM_PRIME_SAFE implies LTM_PRIME_BBS */
3376 if (flags & LTM_PRIME_SAFE) {
3377 flags |= LTM_PRIME_BBS;
3380 /* calc the byte size */
3381 bsize = (size>>3)+((size&7)?1:0);
3383 /* we need a buffer of bsize bytes */
3384 tmp = HeapAlloc(GetProcessHeap(), 0, bsize);
3385 if (tmp == NULL) {
3386 return MP_MEM;
3389 /* calc the maskAND value for the MSbyte*/
3390 maskAND = ((size&7) == 0) ? 0xFF : (0xFF >> (8 - (size & 7)));
3392 /* calc the maskOR_msb */
3393 maskOR_msb = 0;
3394 maskOR_msb_offset = ((size & 7) == 1) ? 1 : 0;
3395 if (flags & LTM_PRIME_2MSB_ON) {
3396 maskOR_msb |= 1 << ((size - 2) & 7);
3397 } else if (flags & LTM_PRIME_2MSB_OFF) {
3398 maskAND &= ~(1 << ((size - 2) & 7));
3401 /* get the maskOR_lsb */
3402 maskOR_lsb = 0;
3403 if (flags & LTM_PRIME_BBS) {
3404 maskOR_lsb |= 3;
3407 do {
3408 /* read the bytes */
3409 if (cb(tmp, bsize, dat) != bsize) {
3410 err = MP_VAL;
3411 goto error;
3414 /* work over the MSbyte */
3415 tmp[0] &= maskAND;
3416 tmp[0] |= 1 << ((size - 1) & 7);
3418 /* mix in the maskORs */
3419 tmp[maskOR_msb_offset] |= maskOR_msb;
3420 tmp[bsize-1] |= maskOR_lsb;
3422 /* read it in */
3423 if ((err = mp_read_unsigned_bin(a, tmp, bsize)) != MP_OKAY) { goto error; }
3425 /* is it prime? */
3426 if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY) { goto error; }
3427 if (res == MP_NO) {
3428 continue;
3431 if (flags & LTM_PRIME_SAFE) {
3432 /* see if (a-1)/2 is prime */
3433 if ((err = mp_sub_d(a, 1, a)) != MP_OKAY) { goto error; }
3434 if ((err = mp_div_2(a, a)) != MP_OKAY) { goto error; }
3436 /* is it prime? */
3437 if ((err = mp_prime_is_prime(a, t, &res)) != MP_OKAY) { goto error; }
3439 } while (res == MP_NO);
3441 if (flags & LTM_PRIME_SAFE) {
3442 /* restore a to the original value */
3443 if ((err = mp_mul_2(a, a)) != MP_OKAY) { goto error; }
3444 if ((err = mp_add_d(a, 1, a)) != MP_OKAY) { goto error; }
3447 err = MP_OKAY;
3448 error:
3449 HeapFree(GetProcessHeap(), 0, tmp);
3450 return err;
3453 /* reads an unsigned char array, assumes the msb is stored first [big endian] */
3455 mp_read_unsigned_bin (mp_int * a, const unsigned char *b, int c)
3457 int res;
3459 /* make sure there are at least two digits */
3460 if (a->alloc < 2) {
3461 if ((res = mp_grow(a, 2)) != MP_OKAY) {
3462 return res;
3466 /* zero the int */
3467 mp_zero (a);
3469 /* read the bytes in */
3470 while (c-- > 0) {
3471 if ((res = mp_mul_2d (a, 8, a)) != MP_OKAY) {
3472 return res;
3475 a->dp[0] |= *b++;
3476 a->used += 1;
3478 mp_clamp (a);
3479 return MP_OKAY;
3482 /* reduces x mod m, assumes 0 < x < m**2, mu is
3483 * precomputed via mp_reduce_setup.
3484 * From HAC pp.604 Algorithm 14.42
3487 mp_reduce (mp_int * x, const mp_int * m, const mp_int * mu)
3489 mp_int q;
3490 int res, um = m->used;
3492 /* q = x */
3493 if ((res = mp_init_copy (&q, x)) != MP_OKAY) {
3494 return res;
3497 /* q1 = x / b**(k-1) */
3498 mp_rshd (&q, um - 1);
3500 /* according to HAC this optimization is ok */
3501 if (((unsigned long) um) > (((mp_digit)1) << (DIGIT_BIT - 1))) {
3502 if ((res = mp_mul (&q, mu, &q)) != MP_OKAY) {
3503 goto CLEANUP;
3505 } else {
3506 if ((res = s_mp_mul_high_digs (&q, mu, &q, um - 1)) != MP_OKAY) {
3507 goto CLEANUP;
3511 /* q3 = q2 / b**(k+1) */
3512 mp_rshd (&q, um + 1);
3514 /* x = x mod b**(k+1), quick (no division) */
3515 if ((res = mp_mod_2d (x, DIGIT_BIT * (um + 1), x)) != MP_OKAY) {
3516 goto CLEANUP;
3519 /* q = q * m mod b**(k+1), quick (no division) */
3520 if ((res = s_mp_mul_digs (&q, m, &q, um + 1)) != MP_OKAY) {
3521 goto CLEANUP;
3524 /* x = x - q */
3525 if ((res = mp_sub (x, &q, x)) != MP_OKAY) {
3526 goto CLEANUP;
3529 /* If x < 0, add b**(k+1) to it */
3530 if (mp_cmp_d (x, 0) == MP_LT) {
3531 mp_set (&q, 1);
3532 if ((res = mp_lshd (&q, um + 1)) != MP_OKAY)
3533 goto CLEANUP;
3534 if ((res = mp_add (x, &q, x)) != MP_OKAY)
3535 goto CLEANUP;
3538 /* Back off if it's too big */
3539 while (mp_cmp (x, m) != MP_LT) {
3540 if ((res = s_mp_sub (x, m, x)) != MP_OKAY) {
3541 goto CLEANUP;
3545 CLEANUP:
3546 mp_clear (&q);
3548 return res;
3551 /* reduces a modulo n where n is of the form 2**p - d */
3553 mp_reduce_2k(mp_int *a, const mp_int *n, mp_digit d)
3555 mp_int q;
3556 int p, res;
3558 if ((res = mp_init(&q)) != MP_OKAY) {
3559 return res;
3562 p = mp_count_bits(n);
3563 top:
3564 /* q = a/2**p, a = a mod 2**p */
3565 if ((res = mp_div_2d(a, p, &q, a)) != MP_OKAY) {
3566 goto ERR;
3569 if (d != 1) {
3570 /* q = q * d */
3571 if ((res = mp_mul_d(&q, d, &q)) != MP_OKAY) {
3572 goto ERR;
3576 /* a = a + q */
3577 if ((res = s_mp_add(a, &q, a)) != MP_OKAY) {
3578 goto ERR;
3581 if (mp_cmp_mag(a, n) != MP_LT) {
3582 s_mp_sub(a, n, a);
3583 goto top;
3586 ERR:
3587 mp_clear(&q);
3588 return res;
3591 /* determines the setup value */
3592 int
3593 mp_reduce_2k_setup(const mp_int *a, mp_digit *d)
3595 int res, p;
3596 mp_int tmp;
3598 if ((res = mp_init(&tmp)) != MP_OKAY) {
3599 return res;
3602 p = mp_count_bits(a);
3603 if ((res = mp_2expt(&tmp, p)) != MP_OKAY) {
3604 mp_clear(&tmp);
3605 return res;
3608 if ((res = s_mp_sub(&tmp, a, &tmp)) != MP_OKAY) {
3609 mp_clear(&tmp);
3610 return res;
3613 *d = tmp.dp[0];
3614 mp_clear(&tmp);
3615 return MP_OKAY;
3618 /* pre-calculate the value required for Barrett reduction
3619 * For a given modulus "b" it calulates the value required in "a"
3621 int mp_reduce_setup (mp_int * a, const mp_int * b)
3623 int res;
3625 if ((res = mp_2expt (a, b->used * 2 * DIGIT_BIT)) != MP_OKAY) {
3626 return res;
3628 return mp_div (a, b, a, NULL);
3631 /* set to a digit */
3632 void mp_set (mp_int * a, mp_digit b)
3634 mp_zero (a);
3635 a->dp[0] = b & MP_MASK;
3636 a->used = (a->dp[0] != 0) ? 1 : 0;
3639 /* set a 32-bit const */
3640 int mp_set_int (mp_int * a, unsigned long b)
3642 int x, res;
3644 mp_zero (a);
3646 /* set four bits at a time */
3647 for (x = 0; x < 8; x++) {
3648 /* shift the number up four bits */
3649 if ((res = mp_mul_2d (a, 4, a)) != MP_OKAY) {
3650 return res;
3653 /* OR in the top four bits of the source */
3654 a->dp[0] |= (b >> 28) & 15;
3656 /* shift the source up to the next four bits */
3657 b <<= 4;
3659 /* ensure that digits are not clamped off */
3660 a->used += 1;
3662 mp_clamp (a);
3663 return MP_OKAY;
3666 /* shrink a bignum */
3667 int mp_shrink (mp_int * a)
3669 mp_digit *tmp;
3670 if (a->alloc != a->used && a->used > 0) {
3671 if ((tmp = HeapReAlloc(GetProcessHeap(), 0, a->dp, sizeof (mp_digit) * a->used)) == NULL) {
3672 return MP_MEM;
3674 a->dp = tmp;
3675 a->alloc = a->used;
3677 return MP_OKAY;
3680 /* get the size for an signed equivalent */
3681 int mp_signed_bin_size (const mp_int * a)
3683 return 1 + mp_unsigned_bin_size (a);
3686 /* computes b = a*a */
3688 mp_sqr (const mp_int * a, mp_int * b)
3690 int res;
3692 if (a->used >= KARATSUBA_SQR_CUTOFF) {
3693 res = mp_karatsuba_sqr (a, b);
3694 } else
3696 /* can we use the fast comba multiplier? */
3697 if ((a->used * 2 + 1) < MP_WARRAY &&
3698 a->used <
3699 (1 << (sizeof(mp_word) * CHAR_BIT - 2*DIGIT_BIT - 1))) {
3700 res = fast_s_mp_sqr (a, b);
3701 } else
3702 res = s_mp_sqr (a, b);
3704 b->sign = MP_ZPOS;
3705 return res;
3708 /* c = a * a (mod b) */
3710 mp_sqrmod (const mp_int * a, mp_int * b, mp_int * c)
3712 int res;
3713 mp_int t;
3715 if ((res = mp_init (&t)) != MP_OKAY) {
3716 return res;
3719 if ((res = mp_sqr (a, &t)) != MP_OKAY) {
3720 mp_clear (&t);
3721 return res;
3723 res = mp_mod (&t, b, c);
3724 mp_clear (&t);
3725 return res;
3728 /* high level subtraction (handles signs) */
3730 mp_sub (mp_int * a, mp_int * b, mp_int * c)
3732 int sa, sb, res;
3734 sa = a->sign;
3735 sb = b->sign;
3737 if (sa != sb) {
3738 /* subtract a negative from a positive, OR */
3739 /* subtract a positive from a negative. */
3740 /* In either case, ADD their magnitudes, */
3741 /* and use the sign of the first number. */
3742 c->sign = sa;
3743 res = s_mp_add (a, b, c);
3744 } else {
3745 /* subtract a positive from a positive, OR */
3746 /* subtract a negative from a negative. */
3747 /* First, take the difference between their */
3748 /* magnitudes, then... */
3749 if (mp_cmp_mag (a, b) != MP_LT) {
3750 /* Copy the sign from the first */
3751 c->sign = sa;
3752 /* The first has a larger or equal magnitude */
3753 res = s_mp_sub (a, b, c);
3754 } else {
3755 /* The result has the *opposite* sign from */
3756 /* the first number. */
3757 c->sign = (sa == MP_ZPOS) ? MP_NEG : MP_ZPOS;
3758 /* The second has a larger magnitude */
3759 res = s_mp_sub (b, a, c);
3762 return res;
3765 /* single digit subtraction */
3767 mp_sub_d (mp_int * a, mp_digit b, mp_int * c)
3769 mp_digit *tmpa, *tmpc, mu;
3770 int res, ix, oldused;
3772 /* grow c as required */
3773 if (c->alloc < a->used + 1) {
3774 if ((res = mp_grow(c, a->used + 1)) != MP_OKAY) {
3775 return res;
3779 /* if a is negative just do an unsigned
3780 * addition [with fudged signs]
3782 if (a->sign == MP_NEG) {
3783 a->sign = MP_ZPOS;
3784 res = mp_add_d(a, b, c);
3785 a->sign = c->sign = MP_NEG;
3786 return res;
3789 /* setup regs */
3790 oldused = c->used;
3791 tmpa = a->dp;
3792 tmpc = c->dp;
3794 /* if a <= b simply fix the single digit */
3795 if ((a->used == 1 && a->dp[0] <= b) || a->used == 0) {
3796 if (a->used == 1) {
3797 *tmpc++ = b - *tmpa;
3798 } else {
3799 *tmpc++ = b;
3801 ix = 1;
3803 /* negative/1digit */
3804 c->sign = MP_NEG;
3805 c->used = 1;
3806 } else {
3807 /* positive/size */
3808 c->sign = MP_ZPOS;
3809 c->used = a->used;
3811 /* subtract first digit */
3812 *tmpc = *tmpa++ - b;
3813 mu = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3814 *tmpc++ &= MP_MASK;
3816 /* handle rest of the digits */
3817 for (ix = 1; ix < a->used; ix++) {
3818 *tmpc = *tmpa++ - mu;
3819 mu = *tmpc >> (sizeof(mp_digit) * CHAR_BIT - 1);
3820 *tmpc++ &= MP_MASK;
3824 /* zero excess digits */
3825 while (ix++ < oldused) {
3826 *tmpc++ = 0;
3828 mp_clamp(c);
3829 return MP_OKAY;
3832 /* store in unsigned [big endian] format */
3834 mp_to_unsigned_bin (const mp_int * a, unsigned char *b)
3836 int x, res;
3837 mp_int t;
3839 if ((res = mp_init_copy (&t, a)) != MP_OKAY) {
3840 return res;
3843 x = 0;
3844 while (mp_iszero (&t) == 0) {
3845 b[x++] = (unsigned char) (t.dp[0] & 255);
3846 if ((res = mp_div_2d (&t, 8, &t, NULL)) != MP_OKAY) {
3847 mp_clear (&t);
3848 return res;
3851 bn_reverse (b, x);
3852 mp_clear (&t);
3853 return MP_OKAY;
3856 /* get the size for an unsigned equivalent */
3858 mp_unsigned_bin_size (const mp_int * a)
3860 int size = mp_count_bits (a);
3861 return (size / 8 + ((size & 7) != 0 ? 1 : 0));
/* Reverses the first len bytes of s in place; used for radix code. */
static void
bn_reverse (unsigned char *s, int len)
{
  int lo, hi;
  unsigned char swap;

  /* walk inwards from both ends, swapping as we go */
  for (lo = 0, hi = len - 1; lo < hi; ++lo, --hi) {
    swap  = s[lo];
    s[lo] = s[hi];
    s[hi] = swap;
  }
}
3882 /* low level addition, based on HAC pp.594, Algorithm 14.7 */
3883 static int
3884 s_mp_add (mp_int * a, mp_int * b, mp_int * c)
3886 mp_int *x;
3887 int olduse, res, min, max;
3889 /* find sizes, we let |a| <= |b| which means we have to sort
3890 * them. "x" will point to the input with the most digits
3892 if (a->used > b->used) {
3893 min = b->used;
3894 max = a->used;
3895 x = a;
3896 } else {
3897 min = a->used;
3898 max = b->used;
3899 x = b;
3902 /* init result */
3903 if (c->alloc < max + 1) {
3904 if ((res = mp_grow (c, max + 1)) != MP_OKAY) {
3905 return res;
3909 /* get old used digit count and set new one */
3910 olduse = c->used;
3911 c->used = max + 1;
3914 register mp_digit u, *tmpa, *tmpb, *tmpc;
3915 register int i;
3917 /* alias for digit pointers */
3919 /* first input */
3920 tmpa = a->dp;
3922 /* second input */
3923 tmpb = b->dp;
3925 /* destination */
3926 tmpc = c->dp;
3928 /* zero the carry */
3929 u = 0;
3930 for (i = 0; i < min; i++) {
3931 /* Compute the sum at one digit, T[i] = A[i] + B[i] + U */
3932 *tmpc = *tmpa++ + *tmpb++ + u;
3934 /* U = carry bit of T[i] */
3935 u = *tmpc >> ((mp_digit)DIGIT_BIT);
3937 /* take away carry bit from T[i] */
3938 *tmpc++ &= MP_MASK;
3941 /* now copy higher words if any, that is in A+B
3942 * if A or B has more digits add those in
3944 if (min != max) {
3945 for (; i < max; i++) {
3946 /* T[i] = X[i] + U */
3947 *tmpc = x->dp[i] + u;
3949 /* U = carry bit of T[i] */
3950 u = *tmpc >> ((mp_digit)DIGIT_BIT);
3952 /* take away carry bit from T[i] */
3953 *tmpc++ &= MP_MASK;
3957 /* add carry */
3958 *tmpc++ = u;
3960 /* clear digits above oldused */
3961 for (i = c->used; i < olduse; i++) {
3962 *tmpc++ = 0;
3966 mp_clamp (c);
3967 return MP_OKAY;
3970 static int s_mp_exptmod (const mp_int * G, const mp_int * X, mp_int * P, mp_int * Y)
3972 mp_int M[256], res, mu;
3973 mp_digit buf;
3974 int err, bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
3976 /* find window size */
3977 x = mp_count_bits (X);
3978 if (x <= 7) {
3979 winsize = 2;
3980 } else if (x <= 36) {
3981 winsize = 3;
3982 } else if (x <= 140) {
3983 winsize = 4;
3984 } else if (x <= 450) {
3985 winsize = 5;
3986 } else if (x <= 1303) {
3987 winsize = 6;
3988 } else if (x <= 3529) {
3989 winsize = 7;
3990 } else {
3991 winsize = 8;
3994 /* init M array */
3995 /* init first cell */
3996 if ((err = mp_init(&M[1])) != MP_OKAY) {
3997 return err;
4000 /* now init the second half of the array */
4001 for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4002 if ((err = mp_init(&M[x])) != MP_OKAY) {
4003 for (y = 1<<(winsize-1); y < x; y++) {
4004 mp_clear (&M[y]);
4006 mp_clear(&M[1]);
4007 return err;
4011 /* create mu, used for Barrett reduction */
4012 if ((err = mp_init (&mu)) != MP_OKAY) {
4013 goto __M;
4015 if ((err = mp_reduce_setup (&mu, P)) != MP_OKAY) {
4016 goto __MU;
4019 /* create M table
4021 * The M table contains powers of the base,
4022 * e.g. M[x] = G**x mod P
4024 * The first half of the table is not
4025 * computed though accept for M[0] and M[1]
4027 if ((err = mp_mod (G, P, &M[1])) != MP_OKAY) {
4028 goto __MU;
4031 /* compute the value at M[1<<(winsize-1)] by squaring
4032 * M[1] (winsize-1) times
4034 if ((err = mp_copy (&M[1], &M[1 << (winsize - 1)])) != MP_OKAY) {
4035 goto __MU;
4038 for (x = 0; x < (winsize - 1); x++) {
4039 if ((err = mp_sqr (&M[1 << (winsize - 1)],
4040 &M[1 << (winsize - 1)])) != MP_OKAY) {
4041 goto __MU;
4043 if ((err = mp_reduce (&M[1 << (winsize - 1)], P, &mu)) != MP_OKAY) {
4044 goto __MU;
4048 /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
4049 * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
4051 for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
4052 if ((err = mp_mul (&M[x - 1], &M[1], &M[x])) != MP_OKAY) {
4053 goto __MU;
4055 if ((err = mp_reduce (&M[x], P, &mu)) != MP_OKAY) {
4056 goto __MU;
4060 /* setup result */
4061 if ((err = mp_init (&res)) != MP_OKAY) {
4062 goto __MU;
4064 mp_set (&res, 1);
4066 /* set initial mode and bit cnt */
4067 mode = 0;
4068 bitcnt = 1;
4069 buf = 0;
4070 digidx = X->used - 1;
4071 bitcpy = 0;
4072 bitbuf = 0;
4074 for (;;) {
4075 /* grab next digit as required */
4076 if (--bitcnt == 0) {
4077 /* if digidx == -1 we are out of digits */
4078 if (digidx == -1) {
4079 break;
4081 /* read next digit and reset the bitcnt */
4082 buf = X->dp[digidx--];
4083 bitcnt = DIGIT_BIT;
4086 /* grab the next msb from the exponent */
4087 y = (buf >> (mp_digit)(DIGIT_BIT - 1)) & 1;
4088 buf <<= (mp_digit)1;
4090 /* if the bit is zero and mode == 0 then we ignore it
4091 * These represent the leading zero bits before the first 1 bit
4092 * in the exponent. Technically this opt is not required but it
4093 * does lower the # of trivial squaring/reductions used
4095 if (mode == 0 && y == 0) {
4096 continue;
4099 /* if the bit is zero and mode == 1 then we square */
4100 if (mode == 1 && y == 0) {
4101 if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4102 goto __RES;
4104 if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4105 goto __RES;
4107 continue;
4110 /* else we add it to the window */
4111 bitbuf |= (y << (winsize - ++bitcpy));
4112 mode = 2;
4114 if (bitcpy == winsize) {
4115 /* ok window is filled so square as required and multiply */
4116 /* square first */
4117 for (x = 0; x < winsize; x++) {
4118 if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4119 goto __RES;
4121 if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4122 goto __RES;
4126 /* then multiply */
4127 if ((err = mp_mul (&res, &M[bitbuf], &res)) != MP_OKAY) {
4128 goto __RES;
4130 if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4131 goto __RES;
4134 /* empty window and reset */
4135 bitcpy = 0;
4136 bitbuf = 0;
4137 mode = 1;
4141 /* if bits remain then square/multiply */
4142 if (mode == 2 && bitcpy > 0) {
4143 /* square then multiply if the bit is set */
4144 for (x = 0; x < bitcpy; x++) {
4145 if ((err = mp_sqr (&res, &res)) != MP_OKAY) {
4146 goto __RES;
4148 if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4149 goto __RES;
4152 bitbuf <<= 1;
4153 if ((bitbuf & (1 << winsize)) != 0) {
4154 /* then multiply */
4155 if ((err = mp_mul (&res, &M[1], &res)) != MP_OKAY) {
4156 goto __RES;
4158 if ((err = mp_reduce (&res, P, &mu)) != MP_OKAY) {
4159 goto __RES;
4165 mp_exch (&res, Y);
4166 err = MP_OKAY;
4167 __RES:mp_clear (&res);
4168 __MU:mp_clear (&mu);
4169 __M:
4170 mp_clear(&M[1]);
4171 for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
4172 mp_clear (&M[x]);
4174 return err;
4177 /* multiplies |a| * |b| and only computes up to digs digits of result
4178 * HAC pp. 595, Algorithm 14.12 Modified so you can control how
4179 * many digits of output are created.
4181 static int
4182 s_mp_mul_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4184 mp_int t;
4185 int res, pa, pb, ix, iy;
4186 mp_digit u;
4187 mp_word r;
4188 mp_digit tmpx, *tmpt, *tmpy;
4190 /* can we use the fast multiplier? */
4191 if (((digs) < MP_WARRAY) &&
4192 MIN (a->used, b->used) <
4193 (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4194 return fast_s_mp_mul_digs (a, b, c, digs);
4197 if ((res = mp_init_size (&t, digs)) != MP_OKAY) {
4198 return res;
4200 t.used = digs;
4202 /* compute the digits of the product directly */
4203 pa = a->used;
4204 for (ix = 0; ix < pa; ix++) {
4205 /* set the carry to zero */
4206 u = 0;
4208 /* limit ourselves to making digs digits of output */
4209 pb = MIN (b->used, digs - ix);
4211 /* setup some aliases */
4212 /* copy of the digit from a used within the nested loop */
4213 tmpx = a->dp[ix];
4215 /* an alias for the destination shifted ix places */
4216 tmpt = t.dp + ix;
4218 /* an alias for the digits of b */
4219 tmpy = b->dp;
4221 /* compute the columns of the output and propagate the carry */
4222 for (iy = 0; iy < pb; iy++) {
4223 /* compute the column as a mp_word */
4224 r = ((mp_word)*tmpt) +
4225 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4226 ((mp_word) u);
4228 /* the new column is the lower part of the result */
4229 *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4231 /* get the carry word from the result */
4232 u = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4234 /* set carry if it is placed below digs */
4235 if (ix + iy < digs) {
4236 *tmpt = u;
4240 mp_clamp (&t);
4241 mp_exch (&t, c);
4243 mp_clear (&t);
4244 return MP_OKAY;
4247 /* multiplies |a| * |b| and does not compute the lower digs digits
4248 * [meant to get the higher part of the product]
4250 static int
4251 s_mp_mul_high_digs (const mp_int * a, const mp_int * b, mp_int * c, int digs)
4253 mp_int t;
4254 int res, pa, pb, ix, iy;
4255 mp_digit u;
4256 mp_word r;
4257 mp_digit tmpx, *tmpt, *tmpy;
4259 /* can we use the fast multiplier? */
4260 if (((a->used + b->used + 1) < MP_WARRAY)
4261 && MIN (a->used, b->used) < (1 << ((CHAR_BIT * sizeof (mp_word)) - (2 * DIGIT_BIT)))) {
4262 return fast_s_mp_mul_high_digs (a, b, c, digs);
4265 if ((res = mp_init_size (&t, a->used + b->used + 1)) != MP_OKAY) {
4266 return res;
4268 t.used = a->used + b->used + 1;
4270 pa = a->used;
4271 pb = b->used;
4272 for (ix = 0; ix < pa; ix++) {
4273 /* clear the carry */
4274 u = 0;
4276 /* left hand side of A[ix] * B[iy] */
4277 tmpx = a->dp[ix];
4279 /* alias to the address of where the digits will be stored */
4280 tmpt = &(t.dp[digs]);
4282 /* alias for where to read the right hand side from */
4283 tmpy = b->dp + (digs - ix);
4285 for (iy = digs - ix; iy < pb; iy++) {
4286 /* calculate the double precision result */
4287 r = ((mp_word)*tmpt) +
4288 ((mp_word)tmpx) * ((mp_word)*tmpy++) +
4289 ((mp_word) u);
4291 /* get the lower part */
4292 *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4294 /* carry the carry */
4295 u = (mp_digit) (r >> ((mp_word) DIGIT_BIT));
4297 *tmpt = u;
4299 mp_clamp (&t);
4300 mp_exch (&t, c);
4301 mp_clear (&t);
4302 return MP_OKAY;
4305 /* low level squaring, b = a*a, HAC pp.596-597, Algorithm 14.16 */
4306 static int
4307 s_mp_sqr (const mp_int * a, mp_int * b)
4309 mp_int t;
4310 int res, ix, iy, pa;
4311 mp_word r;
4312 mp_digit u, tmpx, *tmpt;
4314 pa = a->used;
4315 if ((res = mp_init_size (&t, 2*pa + 1)) != MP_OKAY) {
4316 return res;
4319 /* default used is maximum possible size */
4320 t.used = 2*pa + 1;
4322 for (ix = 0; ix < pa; ix++) {
4323 /* first calculate the digit at 2*ix */
4324 /* calculate double precision result */
4325 r = ((mp_word) t.dp[2*ix]) +
4326 ((mp_word)a->dp[ix])*((mp_word)a->dp[ix]);
4328 /* store lower part in result */
4329 t.dp[ix+ix] = (mp_digit) (r & ((mp_word) MP_MASK));
4331 /* get the carry */
4332 u = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4334 /* left hand side of A[ix] * A[iy] */
4335 tmpx = a->dp[ix];
4337 /* alias for where to store the results */
4338 tmpt = t.dp + (2*ix + 1);
4340 for (iy = ix + 1; iy < pa; iy++) {
4341 /* first calculate the product */
4342 r = ((mp_word)tmpx) * ((mp_word)a->dp[iy]);
4344 /* now calculate the double precision result, note we use
4345 * addition instead of *2 since it's easier to optimize
4347 r = ((mp_word) *tmpt) + r + r + ((mp_word) u);
4349 /* store lower part */
4350 *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4352 /* get carry */
4353 u = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4355 /* propagate upwards */
4356 while (u != 0) {
4357 r = ((mp_word) *tmpt) + ((mp_word) u);
4358 *tmpt++ = (mp_digit) (r & ((mp_word) MP_MASK));
4359 u = (mp_digit)(r >> ((mp_word) DIGIT_BIT));
4363 mp_clamp (&t);
4364 mp_exch (&t, b);
4365 mp_clear (&t);
4366 return MP_OKAY;
4369 /* low level subtraction (assumes |a| > |b|), HAC pp.595 Algorithm 14.9 */
4371 s_mp_sub (const mp_int * a, const mp_int * b, mp_int * c)
4373 int olduse, res, min, max;
4375 /* find sizes */
4376 min = b->used;
4377 max = a->used;
4379 /* init result */
4380 if (c->alloc < max) {
4381 if ((res = mp_grow (c, max)) != MP_OKAY) {
4382 return res;
4385 olduse = c->used;
4386 c->used = max;
4389 register mp_digit u, *tmpa, *tmpb, *tmpc;
4390 register int i;
4392 /* alias for digit pointers */
4393 tmpa = a->dp;
4394 tmpb = b->dp;
4395 tmpc = c->dp;
4397 /* set carry to zero */
4398 u = 0;
4399 for (i = 0; i < min; i++) {
4400 /* T[i] = A[i] - B[i] - U */
4401 *tmpc = *tmpa++ - *tmpb++ - u;
4403 /* U = carry bit of T[i]
4404 * Note this saves performing an AND operation since
4405 * if a carry does occur it will propagate all the way to the
4406 * MSB. As a result a single shift is enough to get the carry
4408 u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4410 /* Clear carry from T[i] */
4411 *tmpc++ &= MP_MASK;
4414 /* now copy higher words if any, e.g. if A has more digits than B */
4415 for (; i < max; i++) {
4416 /* T[i] = A[i] - U */
4417 *tmpc = *tmpa++ - u;
4419 /* U = carry bit of T[i] */
4420 u = *tmpc >> ((mp_digit)(CHAR_BIT * sizeof (mp_digit) - 1));
4422 /* Clear carry from T[i] */
4423 *tmpc++ &= MP_MASK;
4426 /* clear digits above used (since we may not have grown result above) */
4427 for (i = c->used; i < olduse; i++) {
4428 *tmpc++ = 0;
4432 mp_clamp (c);
4433 return MP_OKAY;