2 * Multiprecision integer arithmetic, implementing mpint.h.
16 #define SIZE_T_BITS (CHAR_BIT * sizeof(size_t))
19 * Inline helpers to take min and max of size_t values, used
20 * throughout this code.
/*
 * Return the smaller of two size_t values.
 * (Body restored — the extraction dropped the return statement; the
 * intended behavior follows from the name and callers.)
 */
static inline size_t size_t_min(size_t a, size_t b)
{
    return a < b ? a : b;
}
/*
 * Return the larger of two size_t values.
 * (Body restored — the extraction dropped the return statement.)
 */
static inline size_t size_t_max(size_t a, size_t b)
{
    return a > b ? a : b;
}
32 * Helper to fetch a word of data from x with array overflow checking.
33 * If x is too short to have that word, 0 is returned.
35 static inline BignumInt
mp_word(mp_int
*x
, size_t i
)
37 return i
< x
->nw
? x
->w
[i
] : 0;
41 * Shift an ordinary C integer by BIGNUM_INT_BITS, in a way that
42 * avoids writing a shift operator whose RHS is greater or equal to
43 * the size of the type, because that's undefined behaviour in C.
45 * In fact we must avoid even writing it in a definitely-untaken
46 * branch of an if, because compilers will sometimes warn about
47 * that. So you can't just write 'shift too big ? 0 : n >> shift',
48 * because even if 'shift too big' is a constant-expression
49 * evaluating to false, you can still get complaints about the
50 * else clause of the ?:.
52 * So we have to re-check _inside_ that clause, so that the shift
53 * count is reset to something nonsensical but safe in the case
54 * where the clause wasn't going to be taken anyway.
56 static uintmax_t shift_right_by_one_word(uintmax_t n
)
58 bool shift_too_big
= BIGNUM_INT_BYTES
>= sizeof(n
);
59 return shift_too_big
? 0 :
60 n
>> (shift_too_big
? 0 : BIGNUM_INT_BITS
);
62 static uintmax_t shift_left_by_one_word(uintmax_t n
)
64 bool shift_too_big
= BIGNUM_INT_BYTES
>= sizeof(n
);
65 return shift_too_big
? 0 :
66 n
<< (shift_too_big
? 0 : BIGNUM_INT_BITS
);
69 mp_int
*mp_make_sized(size_t nw
)
71 mp_int
*x
= snew_plus(mp_int
, nw
* sizeof(BignumInt
));
72 assert(nw
); /* we outlaw the zero-word mp_int */
74 x
->w
= snew_plus_get_aux(x
);
79 mp_int
*mp_new(size_t maxbits
)
81 size_t words
= (maxbits
+ BIGNUM_INT_BITS
- 1) / BIGNUM_INT_BITS
;
82 return mp_make_sized(words
);
85 mp_int
*mp_resize(mp_int
*mp
, size_t newmaxbits
)
87 mp_int
*copy
= mp_new(newmaxbits
);
88 mp_copy_into(copy
, mp
);
93 mp_int
*mp_from_integer(uintmax_t n
)
95 mp_int
*x
= mp_make_sized(
96 (sizeof(n
) + BIGNUM_INT_BYTES
- 1) / BIGNUM_INT_BYTES
);
97 for (size_t i
= 0; i
< x
->nw
; i
++)
98 x
->w
[i
] = n
>> (i
* BIGNUM_INT_BITS
);
102 size_t mp_max_bytes(mp_int
*x
)
104 return x
->nw
* BIGNUM_INT_BYTES
;
107 size_t mp_max_bits(mp_int
*x
)
109 return x
->nw
* BIGNUM_INT_BITS
;
112 void mp_free(mp_int
*x
)
115 smemclr(x
, sizeof(*x
));
119 void mp_dump(FILE *fp
, const char *prefix
, mp_int
*x
, const char *suffix
)
121 fprintf(fp
, "%s0x", prefix
);
122 for (size_t i
= mp_max_bytes(x
); i
-- > 0 ;)
123 fprintf(fp
, "%02X", mp_get_byte(x
, i
));
127 void mp_copy_into(mp_int
*dest
, mp_int
*src
)
129 size_t copy_nw
= size_t_min(dest
->nw
, src
->nw
);
130 memmove(dest
->w
, src
->w
, copy_nw
* sizeof(BignumInt
));
131 smemclr(dest
->w
+ copy_nw
, (dest
->nw
- copy_nw
) * sizeof(BignumInt
));
134 void mp_copy_integer_into(mp_int
*r
, uintmax_t n
)
136 for (size_t i
= 0; i
< r
->nw
; i
++) {
138 n
= shift_right_by_one_word(n
);
143 * Conditional selection is done by negating 'which', to give a mask
144 * word which is all 1s if which==1 and all 0s if which==0. Then you
145 * can select between two inputs a,b without data-dependent control
146 * flow by XORing them to get their difference; ANDing with the mask
147 * word to replace that difference with 0 if which==0; and XORing that
148 * into a, which will either turn it into b or leave it alone.
150 * This trick will be used throughout this code and taken as read the
151 * rest of the time (or else I'd be here all week typing comments),
152 * but I felt I ought to explain it in words _once_.
154 void mp_select_into(mp_int
*dest
, mp_int
*src0
, mp_int
*src1
,
157 BignumInt mask
= -(BignumInt
)(1 & which
);
158 for (size_t i
= 0; i
< dest
->nw
; i
++) {
159 BignumInt srcword0
= mp_word(src0
, i
), srcword1
= mp_word(src1
, i
);
160 dest
->w
[i
] = srcword0
^ ((srcword1
^ srcword0
) & mask
);
164 void mp_cond_swap(mp_int
*x0
, mp_int
*x1
, unsigned swap
)
166 assert(x0
->nw
== x1
->nw
);
167 volatile BignumInt mask
= -(BignumInt
)(1 & swap
);
168 for (size_t i
= 0; i
< x0
->nw
; i
++) {
169 BignumInt diff
= (x0
->w
[i
] ^ x1
->w
[i
]) & mask
;
175 void mp_clear(mp_int
*x
)
177 smemclr(x
->w
, x
->nw
* sizeof(BignumInt
));
180 void mp_cond_clear(mp_int
*x
, unsigned clear
)
182 BignumInt mask
= ~-(BignumInt
)(1 & clear
);
183 for (size_t i
= 0; i
< x
->nw
; i
++)
188 * Common code between mp_from_bytes_{le,be} which reads bytes in an
189 * arbitrary arithmetic progression.
191 static mp_int
*mp_from_bytes_int(ptrlen bytes
, size_t m
, size_t c
)
193 size_t nw
= (bytes
.len
+ BIGNUM_INT_BYTES
- 1) / BIGNUM_INT_BYTES
;
194 nw
= size_t_max(nw
, 1);
195 mp_int
*n
= mp_make_sized(nw
);
196 for (size_t i
= 0; i
< bytes
.len
; i
++)
197 n
->w
[i
/ BIGNUM_INT_BYTES
] |=
198 (BignumInt
)(((const unsigned char *)bytes
.ptr
)[m
*i
+c
]) <<
199 (8 * (i
% BIGNUM_INT_BYTES
));
203 mp_int
*mp_from_bytes_le(ptrlen bytes
)
205 return mp_from_bytes_int(bytes
, 1, 0);
208 mp_int
*mp_from_bytes_be(ptrlen bytes
)
210 return mp_from_bytes_int(bytes
, -1, bytes
.len
- 1);
213 static mp_int
*mp_from_words(size_t nw
, const BignumInt
*w
)
215 mp_int
*x
= mp_make_sized(nw
);
216 memcpy(x
->w
, w
, x
->nw
* sizeof(BignumInt
));
221 * Decimal-to-binary conversion: just go through the input string
222 * adding on the decimal value of each digit, and then multiplying the
223 * number so far by 10.
225 mp_int
*mp_from_decimal_pl(ptrlen decimal
)
227 /* 196/59 is an upper bound (and also a continued-fraction
228 * convergent) for log2(10), so this conservatively estimates the
229 * number of bits that will be needed to store any number that can
230 * be written in this many decimal digits. */
231 assert(decimal
.len
< (~(size_t)0) / 196);
232 size_t bits
= 196 * decimal
.len
/ 59;
234 /* Now round that up to words. */
235 size_t words
= bits
/ BIGNUM_INT_BITS
+ 1;
237 mp_int
*x
= mp_make_sized(words
);
238 for (size_t i
= 0; i
< decimal
.len
; i
++) {
239 mp_add_integer_into(x
, x
, ((const char *)decimal
.ptr
)[i
] - '0');
241 if (i
+1 == decimal
.len
)
244 mp_mul_integer_into(x
, x
, 10);
249 mp_int
*mp_from_decimal(const char *decimal
)
251 return mp_from_decimal_pl(ptrlen_from_asciz(decimal
));
255 * Hex-to-binary conversion: _algorithmically_ simpler than decimal
256 * (none of those multiplications by 10), but there's some fiddly
257 * bit-twiddling needed to process each hex digit without diverging
258 * control flow depending on whether it's a letter or a number.
260 mp_int
*mp_from_hex_pl(ptrlen hex
)
262 assert(hex
.len
<= (~(size_t)0) / 4);
263 size_t bits
= hex
.len
* 4;
264 size_t words
= (bits
+ BIGNUM_INT_BITS
- 1) / BIGNUM_INT_BITS
;
265 words
= size_t_max(words
, 1);
266 mp_int
*x
= mp_make_sized(words
);
267 for (size_t nibble
= 0; nibble
< hex
.len
; nibble
++) {
268 BignumInt digit
= ((const char *)hex
.ptr
)[hex
.len
-1 - nibble
];
270 BignumInt lmask
= ~-((BignumInt
)((digit
-'a')|('f'-digit
))
271 >> (BIGNUM_INT_BITS
-1));
272 BignumInt umask
= ~-((BignumInt
)((digit
-'A')|('F'-digit
))
273 >> (BIGNUM_INT_BITS
-1));
275 BignumInt digitval
= digit
- '0';
276 digitval
^= (digitval
^ (digit
- 'a' + 10)) & lmask
;
277 digitval
^= (digitval
^ (digit
- 'A' + 10)) & umask
;
278 digitval
&= 0xF; /* at least be _slightly_ nice about weird input */
280 size_t word_idx
= nibble
/ (BIGNUM_INT_BYTES
*2);
281 size_t nibble_within_word
= nibble
% (BIGNUM_INT_BYTES
*2);
282 x
->w
[word_idx
] |= digitval
<< (nibble_within_word
* 4);
287 mp_int
*mp_from_hex(const char *hex
)
289 return mp_from_hex_pl(ptrlen_from_asciz(hex
));
292 mp_int
*mp_copy(mp_int
*x
)
294 return mp_from_words(x
->nw
, x
->w
);
297 uint8_t mp_get_byte(mp_int
*x
, size_t byte
)
299 return 0xFF & (mp_word(x
, byte
/ BIGNUM_INT_BYTES
) >>
300 (8 * (byte
% BIGNUM_INT_BYTES
)));
303 unsigned mp_get_bit(mp_int
*x
, size_t bit
)
305 return 1 & (mp_word(x
, bit
/ BIGNUM_INT_BITS
) >>
306 (bit
% BIGNUM_INT_BITS
));
309 uintmax_t mp_get_integer(mp_int
*x
)
312 for (size_t i
= x
->nw
; i
-- > 0 ;)
313 toret
= shift_left_by_one_word(toret
) | x
->w
[i
];
317 void mp_set_bit(mp_int
*x
, size_t bit
, unsigned val
)
319 size_t word
= bit
/ BIGNUM_INT_BITS
;
320 assert(word
< x
->nw
);
322 unsigned shift
= (bit
% BIGNUM_INT_BITS
);
324 x
->w
[word
] &= ~((BignumInt
)1 << shift
);
325 x
->w
[word
] |= (BignumInt
)(val
& 1) << shift
;
329 * Helper function used here and there to normalise any nonzero input
332 static inline unsigned normalise_to_1(BignumInt n
)
334 n
= (n
>> 1) | (n
& 1); /* ensure top bit is clear */
335 n
= (BignumInt
)(-n
) >> (BIGNUM_INT_BITS
- 1); /* normalise to 0 or 1 */
/*
 * uint64_t version of normalise_to_1: map 0 -> 0 and any nonzero value
 * -> 1, branchlessly. (Return statement restored.)
 */
static inline unsigned normalise_to_1_u64(uint64_t n)
{
    n = (n >> 1) | (n & 1);    /* ensure top bit is clear */
    n = (-n) >> 63;            /* normalise to 0 or 1 */
    return n;
}
346 * Find the highest nonzero word in a number. Returns the index of the
347 * word in x->w, and also a pair of output uint64_t in which that word
348 * appears in the high one shifted left by 'shift_wanted' bits, the
349 * words immediately below it occupy the space to the right, and the
350 * words below _that_ fill up the low one.
352 * If there is no nonzero word at all, the passed-by-reference output
353 * variables retain their original values.
355 static inline void mp_find_highest_nonzero_word_pair(
356 mp_int
*x
, size_t shift_wanted
, size_t *index
,
357 uint64_t *hi
, uint64_t *lo
)
359 uint64_t curr_hi
= 0, curr_lo
= 0;
361 for (size_t curr_index
= 0; curr_index
< x
->nw
; curr_index
++) {
362 BignumInt curr_word
= x
->w
[curr_index
];
363 unsigned indicator
= normalise_to_1(curr_word
);
365 curr_lo
= (BIGNUM_INT_BITS
< 64 ? (curr_lo
>> BIGNUM_INT_BITS
) : 0) |
366 (curr_hi
<< (64 - BIGNUM_INT_BITS
));
367 curr_hi
= (BIGNUM_INT_BITS
< 64 ? (curr_hi
>> BIGNUM_INT_BITS
) : 0) |
368 ((uint64_t)curr_word
<< shift_wanted
);
370 if (hi
) *hi
^= (curr_hi
^ *hi
) & -(uint64_t)indicator
;
371 if (lo
) *lo
^= (curr_lo
^ *lo
) & -(uint64_t)indicator
;
372 if (index
) *index
^= (curr_index
^ *index
) & -(size_t) indicator
;
376 size_t mp_get_nbits(mp_int
*x
)
378 /* Sentinel values in case there are no bits set at all: we
379 * imagine that there's a word at position -1 (i.e. the topmost
380 * fraction word) which is all 1s, because that way, we handle a
381 * zero input by considering its highest set bit to be the top one
382 * of that word, i.e. just below the units digit, i.e. at bit
383 * index -1, i.e. so we'll return 0 on output. */
384 size_t hiword_index
= -(size_t)1;
385 uint64_t hiword64
= ~(BignumInt
)0;
388 * Find the highest nonzero word and its index.
390 mp_find_highest_nonzero_word_pair(x
, 0, &hiword_index
, &hiword64
, NULL
);
391 BignumInt hiword
= hiword64
; /* in case BignumInt is a narrower type */
394 * Find the index of the highest set bit within hiword.
396 BignumInt hibit_index
= 0;
397 for (size_t i
= (1 << (BIGNUM_INT_BITS_BITS
-1)); i
!= 0; i
>>= 1) {
398 BignumInt shifted_word
= hiword
>> i
;
399 BignumInt indicator
=
400 (BignumInt
)(-shifted_word
) >> (BIGNUM_INT_BITS
-1);
401 hiword
^= (shifted_word
^ hiword
) & -indicator
;
402 hibit_index
+= i
& -(size_t)indicator
;
406 * Put together the result.
408 return (hiword_index
<< BIGNUM_INT_BITS_BITS
) + hibit_index
+ 1;
412 * Shared code between the hex and decimal output functions to get rid
413 * of leading zeroes on the output string. The idea is that we wrote
414 * out a fixed number of digits and a trailing \0 byte into 'buf', and
415 * now we want to shift it all left so that the first nonzero digit
416 * moves to buf[0] (or, if there are no nonzero digits at all, we move
417 * up by 'maxtrim', so that we return 0 as "0" instead of "").
419 static void trim_leading_zeroes(char *buf
, size_t bufsize
, size_t maxtrim
)
421 size_t trim
= maxtrim
;
424 * Look for the first character not equal to '0', to find the
428 for (size_t pos
= trim
; pos
-- > 0 ;) {
429 uint8_t diff
= buf
[pos
] ^ '0';
430 size_t mask
= -((((size_t)diff
) - 1) >> (SIZE_T_BITS
- 1));
431 trim
^= (trim
^ pos
) & ~mask
;
436 * Now do the shift, in log n passes each of which does a
437 * conditional shift by 2^i bytes if bit i is set in the shift
440 uint8_t *ubuf
= (uint8_t *)buf
;
441 for (size_t logd
= 0; bufsize
>> logd
; logd
++) {
442 uint8_t mask
= -(uint8_t)((trim
>> logd
) & 1);
443 size_t d
= (size_t)1 << logd
;
444 for (size_t i
= 0; i
+d
< bufsize
; i
++) {
445 uint8_t diff
= mask
& (ubuf
[i
] ^ ubuf
[i
+d
]);
453 * Binary to decimal conversion. Our strategy here is to extract each
454 * decimal digit by finding the input number's residue mod 10, then
455 * subtract that off to give an exact multiple of 10, which then means
456 * you can safely divide by 10 by means of shifting right one bit and
457 * then multiplying by the inverse of 5 mod 2^n.
459 char *mp_get_decimal(mp_int
*x_orig
)
461 mp_int
*x
= mp_copy(x_orig
), *y
= mp_make_sized(x
->nw
);
464 * The inverse of 5 mod 2^lots is 0xccccccccccccccccccccd, for an
465 * appropriate number of 'c's. Manually construct an integer the
468 mp_int
*inv5
= mp_make_sized(x
->nw
);
469 assert(BIGNUM_INT_BITS
% 8 == 0);
470 for (size_t i
= 0; i
< inv5
->nw
; i
++)
471 inv5
->w
[i
] = BIGNUM_INT_MASK
/ 5 * 4;
475 * 146/485 is an upper bound (and also a continued-fraction
476 * convergent) of log10(2), so this is a conservative estimate of
477 * the number of decimal digits needed to store a value that fits
478 * in this many binary bits.
480 assert(x
->nw
< (~(size_t)1) / (146 * BIGNUM_INT_BITS
));
481 size_t bufsize
= size_t_max(x
->nw
* (146 * BIGNUM_INT_BITS
) / 485, 1) + 2;
482 char *outbuf
= snewn(bufsize
, char);
483 outbuf
[bufsize
- 1] = '\0';
486 * Loop over the number generating digits from the least
487 * significant upwards, so that we write to outbuf in reverse
490 for (size_t pos
= bufsize
- 1; pos
-- > 0 ;) {
492 * Find the current residue mod 10. We do this by first
493 * summing the bytes of the number, with all but the lowest
494 * one multiplied by 6 (because 256^i == 6 mod 10 for all
495 * i>0). That gives us a single word congruent mod 10 to the
496 * input number, and then we reduce it further by manual
497 * multiplication and shifting, just in case the compiler
498 * target implements the C division operator in a way that has
499 * input-dependent timing.
501 uint32_t low_digit
= 0, maxval
= 0, mult
= 1;
502 for (size_t i
= 0; i
< x
->nw
; i
++) {
503 for (unsigned j
= 0; j
< BIGNUM_INT_BYTES
; j
++) {
504 low_digit
+= mult
* (0xFF & (x
->w
[i
] >> (8*j
)));
505 maxval
+= mult
* 0xFF;
509 * For _really_ big numbers, prevent overflow of t by
510 * periodically folding the top half of the accumulator
511 * into the bottom half, using the same rule 'multiply by
512 * 6 when shifting down by one or more whole bytes'.
514 if (maxval
> UINT32_MAX
- (6 * 0xFF * BIGNUM_INT_BYTES
)) {
515 low_digit
= (low_digit
& 0xFFFF) + 6 * (low_digit
>> 16);
516 maxval
= (maxval
& 0xFFFF) + 6 * (maxval
>> 16);
521 * Final reduction of low_digit. We multiply by 2^32 / 10
522 * (that's the constant 0x19999999) to get a 64-bit value
523 * whose top 32 bits are the approximate quotient
524 * low_digit/10; then we subtract off 10 times that; and
525 * finally we do one last trial subtraction of 10 by adding 6
526 * (which sets bit 4 if the number was just over 10) and then
529 low_digit
-= 10 * ((0x19999999ULL
* low_digit
) >> 32);
530 low_digit
-= 10 * ((low_digit
+ 6) >> 4);
532 assert(low_digit
< 10); /* make sure we did reduce fully */
533 outbuf
[pos
] = '0' + low_digit
;
536 * Now subtract off that digit, divide by 2 (using a right
537 * shift) and by 5 (using the modular inverse), to get the
538 * next output digit into the units position.
540 mp_sub_integer_into(x
, x
, low_digit
);
541 mp_rshift_fixed_into(y
, x
, 1);
542 mp_mul_into(x
, y
, inv5
);
549 trim_leading_zeroes(outbuf
, bufsize
, bufsize
- 2);
554 * Binary to hex conversion. Reasonably simple (only a spot of bit
555 * twiddling to choose whether to output a digit or a letter for each
558 static char *mp_get_hex_internal(mp_int
*x
, uint8_t letter_offset
)
560 size_t nibbles
= x
->nw
* BIGNUM_INT_BYTES
* 2;
561 size_t bufsize
= nibbles
+ 1;
562 char *outbuf
= snewn(bufsize
, char);
563 outbuf
[nibbles
] = '\0';
565 for (size_t nibble
= 0; nibble
< nibbles
; nibble
++) {
566 size_t word_idx
= nibble
/ (BIGNUM_INT_BYTES
*2);
567 size_t nibble_within_word
= nibble
% (BIGNUM_INT_BYTES
*2);
568 uint8_t digitval
= 0xF & (x
->w
[word_idx
] >> (nibble_within_word
* 4));
570 uint8_t mask
= -((digitval
+ 6) >> 4);
571 char digit
= digitval
+ '0' + (letter_offset
& mask
);
572 outbuf
[nibbles
-1 - nibble
] = digit
;
575 trim_leading_zeroes(outbuf
, bufsize
, nibbles
- 1);
579 char *mp_get_hex(mp_int
*x
)
581 return mp_get_hex_internal(x
, 'a' - ('0'+10));
584 char *mp_get_hex_uppercase(mp_int
*x
)
586 return mp_get_hex_internal(x
, 'A' - ('0'+10));
590 * Routines for reading and writing the SSH-1 and SSH-2 wire formats
591 * for multiprecision integers, declared in marshal.h.
593 * These can't avoid having control flow dependent on the true bit
594 * size of the number, because the wire format requires the number of
595 * output bytes to depend on that.
597 void BinarySink_put_mp_ssh1(BinarySink
*bs
, mp_int
*x
)
599 size_t bits
= mp_get_nbits(x
);
600 size_t bytes
= (bits
+ 7) / 8;
602 assert(bits
< 0x10000);
603 put_uint16(bs
, bits
);
604 for (size_t i
= bytes
; i
-- > 0 ;)
605 put_byte(bs
, mp_get_byte(x
, i
));
608 void BinarySink_put_mp_ssh2(BinarySink
*bs
, mp_int
*x
)
610 size_t bytes
= (mp_get_nbits(x
) + 8) / 8;
612 put_uint32(bs
, bytes
);
613 for (size_t i
= bytes
; i
-- > 0 ;)
614 put_byte(bs
, mp_get_byte(x
, i
));
617 mp_int
*BinarySource_get_mp_ssh1(BinarySource
*src
)
619 unsigned bitc
= get_uint16(src
);
620 ptrlen bytes
= get_data(src
, (bitc
+ 7) / 8);
622 return mp_from_integer(0);
624 mp_int
*toret
= mp_from_bytes_be(bytes
);
625 /* SSH-1.5 spec says that it's OK for the prefix uint16 to be
626 * _greater_ than the actual number of bits */
627 if (mp_get_nbits(toret
) > bitc
) {
628 src
->err
= BSE_INVALID
;
630 toret
= mp_from_integer(0);
636 mp_int
*BinarySource_get_mp_ssh2(BinarySource
*src
)
638 ptrlen bytes
= get_string(src
);
640 return mp_from_integer(0);
642 const unsigned char *p
= bytes
.ptr
;
643 if ((bytes
.len
> 0 &&
645 (p
[0] == 0 && (bytes
.len
<= 1 || !(p
[1] & 0x80)))))) {
646 src
->err
= BSE_INVALID
;
647 return mp_from_integer(0);
649 return mp_from_bytes_be(bytes
);
654 * Make an mp_int structure whose words array aliases a subinterval of
655 * some other mp_int. This makes it easy to read or write just the low
656 * or high words of a number, e.g. to add a number starting from a
657 * high bit position, or to reduce mod 2^{n*BIGNUM_INT_BITS}.
659 * The convention throughout this code is that when we store an mp_int
660 * directly by value, we always expect it to be an alias of some kind,
661 * so its words array won't ever need freeing. Whereas an 'mp_int *'
662 * has an owner, who knows whether it needs freeing or whether it was
663 * created by address-taking an alias.
665 static mp_int
mp_make_alias(mp_int
*in
, size_t offset
, size_t len
)
668 * Bounds-check the offset and length so that we always return
669 * something valid, even if it's not necessarily the length the
674 if (len
> in
->nw
- offset
)
675 len
= in
->nw
- offset
;
679 toret
.w
= in
->w
+ offset
;
684 * A special case of mp_make_alias: in some cases we preallocate a
685 * large mp_int to use as scratch space (to avoid pointless
686 * malloc/free churn in recursive or iterative work).
688 * mp_alloc_from_scratch creates an alias of size 'len' to part of
689 * 'pool', and adjusts 'pool' itself so that further allocations won't
690 * overwrite that space.
692 * There's no free function to go with this. Typically you just copy
693 * the pool mp_int by value, allocate from the copy, and when you're
694 * done with those allocations, throw the copy away and go back to the
695 * original value of pool. (A mark/release system.)
697 static mp_int
mp_alloc_from_scratch(mp_int
*pool
, size_t len
)
699 assert(len
<= pool
->nw
);
700 mp_int toret
= mp_make_alias(pool
, 0, len
);
701 *pool
= mp_make_alias(pool
, len
, pool
->nw
);
706 * Internal component common to lots of assorted add/subtract code.
707 * Reads words from a,b; writes into w_out (which might be NULL if the
708 * output isn't even needed). Takes an input carry flag in 'carry',
709 * and returns the output carry. Each word read from b is ANDed with
710 * b_and and then XORed with b_xor.
712 * So you can implement addition by setting b_and to all 1s and b_xor
713 * to 0; you can subtract by making b_xor all 1s too (effectively
714 * bit-flipping b) and also passing 1 as the input carry (to turn
715 * one's complement into two's complement). And you can do conditional
716 * add/subtract by choosing b_and to be all 1s or all 0s based on a
717 * condition, because the value of b will be totally ignored if b_and
720 static BignumCarry
mp_add_masked_into(
721 BignumInt
*w_out
, size_t rw
, mp_int
*a
, mp_int
*b
,
722 BignumInt b_and
, BignumInt b_xor
, BignumCarry carry
)
724 for (size_t i
= 0; i
< rw
; i
++) {
725 BignumInt aword
= mp_word(a
, i
), bword
= mp_word(b
, i
), out
;
726 bword
= (bword
& b_and
) ^ b_xor
;
727 BignumADC(out
, carry
, aword
, bword
, carry
);
735 * Like the public mp_add_into except that it returns the output carry.
737 static inline BignumCarry
mp_add_into_internal(mp_int
*r
, mp_int
*a
, mp_int
*b
)
739 return mp_add_masked_into(r
->w
, r
->nw
, a
, b
, ~(BignumInt
)0, 0, 0);
742 void mp_add_into(mp_int
*r
, mp_int
*a
, mp_int
*b
)
744 mp_add_into_internal(r
, a
, b
);
747 void mp_sub_into(mp_int
*r
, mp_int
*a
, mp_int
*b
)
749 mp_add_masked_into(r
->w
, r
->nw
, a
, b
, ~(BignumInt
)0, ~(BignumInt
)0, 1);
752 void mp_and_into(mp_int
*r
, mp_int
*a
, mp_int
*b
)
754 for (size_t i
= 0; i
< r
->nw
; i
++) {
755 BignumInt aword
= mp_word(a
, i
), bword
= mp_word(b
, i
);
756 r
->w
[i
] = aword
& bword
;
760 void mp_or_into(mp_int
*r
, mp_int
*a
, mp_int
*b
)
762 for (size_t i
= 0; i
< r
->nw
; i
++) {
763 BignumInt aword
= mp_word(a
, i
), bword
= mp_word(b
, i
);
764 r
->w
[i
] = aword
| bword
;
768 void mp_xor_into(mp_int
*r
, mp_int
*a
, mp_int
*b
)
770 for (size_t i
= 0; i
< r
->nw
; i
++) {
771 BignumInt aword
= mp_word(a
, i
), bword
= mp_word(b
, i
);
772 r
->w
[i
] = aword
^ bword
;
776 void mp_bic_into(mp_int
*r
, mp_int
*a
, mp_int
*b
)
778 for (size_t i
= 0; i
< r
->nw
; i
++) {
779 BignumInt aword
= mp_word(a
, i
), bword
= mp_word(b
, i
);
780 r
->w
[i
] = aword
& ~bword
;
784 static void mp_cond_negate(mp_int
*r
, mp_int
*x
, unsigned yes
)
786 BignumCarry carry
= yes
;
787 BignumInt flip
= -(BignumInt
)yes
;
788 for (size_t i
= 0; i
< r
->nw
; i
++) {
789 BignumInt xword
= mp_word(x
, i
);
791 BignumADC(r
->w
[i
], carry
, 0, xword
, carry
);
796 * Similar to mp_add_masked_into, but takes a C integer instead of an
797 * mp_int as the masked operand.
799 static BignumCarry
mp_add_masked_integer_into(
800 BignumInt
*w_out
, size_t rw
, mp_int
*a
, uintmax_t b
,
801 BignumInt b_and
, BignumInt b_xor
, BignumCarry carry
)
803 for (size_t i
= 0; i
< rw
; i
++) {
804 BignumInt aword
= mp_word(a
, i
);
806 b
= shift_right_by_one_word(b
);
808 bword
= (bword
^ b_xor
) & b_and
;
809 BignumADC(out
, carry
, aword
, bword
, carry
);
816 void mp_add_integer_into(mp_int
*r
, mp_int
*a
, uintmax_t n
)
818 mp_add_masked_integer_into(r
->w
, r
->nw
, a
, n
, ~(BignumInt
)0, 0, 0);
821 void mp_sub_integer_into(mp_int
*r
, mp_int
*a
, uintmax_t n
)
823 mp_add_masked_integer_into(r
->w
, r
->nw
, a
, n
,
824 ~(BignumInt
)0, ~(BignumInt
)0, 1);
828 * Sets r to a + n << (word_index * BIGNUM_INT_BITS), treating
829 * word_index as secret data.
831 static void mp_add_integer_into_shifted_by_words(
832 mp_int
*r
, mp_int
*a
, uintmax_t n
, size_t word_index
)
834 unsigned indicator
= 0;
835 BignumCarry carry
= 0;
837 for (size_t i
= 0; i
< r
->nw
; i
++) {
838 /* indicator becomes 1 when we reach the index that the least
839 * significant bits of n want to be placed at, and it stays 1
841 indicator
|= 1 ^ normalise_to_1(i
^ word_index
);
843 /* If indicator is 1, we add the low bits of n into r, and
844 * shift n down. If it's 0, we add zero bits into r, and
846 BignumInt bword
= n
& -(BignumInt
)indicator
;
847 uintmax_t new_n
= shift_right_by_one_word(n
);
848 n
^= (n
^ new_n
) & -(uintmax_t)indicator
;
850 BignumInt aword
= mp_word(a
, i
);
852 BignumADC(out
, carry
, aword
, bword
, carry
);
857 void mp_mul_integer_into(mp_int
*r
, mp_int
*a
, uint16_t n
)
859 BignumInt carry
= 0, mult
= n
;
860 for (size_t i
= 0; i
< r
->nw
; i
++) {
861 BignumInt aword
= mp_word(a
, i
);
862 BignumMULADD(carry
, r
->w
[i
], aword
, mult
, carry
);
867 void mp_cond_add_into(mp_int
*r
, mp_int
*a
, mp_int
*b
, unsigned yes
)
869 BignumInt mask
= -(BignumInt
)(yes
& 1);
870 mp_add_masked_into(r
->w
, r
->nw
, a
, b
, mask
, 0, 0);
873 void mp_cond_sub_into(mp_int
*r
, mp_int
*a
, mp_int
*b
, unsigned yes
)
875 BignumInt mask
= -(BignumInt
)(yes
& 1);
876 mp_add_masked_into(r
->w
, r
->nw
, a
, b
, mask
, mask
, 1 & mask
);
880 * Ordered comparison between unsigned numbers is done by subtracting
881 * one from the other and looking at the output carry.
883 unsigned mp_cmp_hs(mp_int
*a
, mp_int
*b
)
885 size_t rw
= size_t_max(a
->nw
, b
->nw
);
886 return mp_add_masked_into(NULL
, rw
, a
, b
, ~(BignumInt
)0, ~(BignumInt
)0, 1);
889 unsigned mp_hs_integer(mp_int
*x
, uintmax_t n
)
892 size_t nwords
= sizeof(n
)/BIGNUM_INT_BYTES
;
893 for (size_t i
= 0, e
= size_t_max(x
->nw
, nwords
); i
< e
; i
++) {
895 n
= shift_right_by_one_word(n
);
897 BignumADC(dummy_out
, carry
, mp_word(x
, i
), ~nword
, carry
);
904 * Equality comparison is done by bitwise XOR of the input numbers,
905 * ORing together all the output words, and normalising the result
906 * using our careful normalise_to_1 helper function.
908 unsigned mp_cmp_eq(mp_int
*a
, mp_int
*b
)
911 for (size_t i
= 0, limit
= size_t_max(a
->nw
, b
->nw
); i
< limit
; i
++)
912 diff
|= mp_word(a
, i
) ^ mp_word(b
, i
);
913 return 1 ^ normalise_to_1(diff
); /* return 1 if diff _is_ zero */
916 unsigned mp_eq_integer(mp_int
*x
, uintmax_t n
)
919 size_t nwords
= sizeof(n
)/BIGNUM_INT_BYTES
;
920 for (size_t i
= 0, e
= size_t_max(x
->nw
, nwords
); i
< e
; i
++) {
922 n
= shift_right_by_one_word(n
);
923 diff
|= mp_word(x
, i
) ^ nword
;
925 return 1 ^ normalise_to_1(diff
); /* return 1 if diff _is_ zero */
928 static void mp_neg_into(mp_int
*r
, mp_int
*a
)
932 mp_sub_into(r
, &zero
, a
);
935 mp_int
*mp_add(mp_int
*x
, mp_int
*y
)
937 mp_int
*r
= mp_make_sized(size_t_max(x
->nw
, y
->nw
) + 1);
938 mp_add_into(r
, x
, y
);
942 mp_int
*mp_sub(mp_int
*x
, mp_int
*y
)
944 mp_int
*r
= mp_make_sized(size_t_max(x
->nw
, y
->nw
));
945 mp_sub_into(r
, x
, y
);
950 * Internal routine: multiply and accumulate in the trivial O(N^2)
951 * way. Sets r <- r + a*b.
953 static void mp_mul_add_simple(mp_int
*r
, mp_int
*a
, mp_int
*b
)
955 BignumInt
*aend
= a
->w
+ a
->nw
, *bend
= b
->w
+ b
->nw
, *rend
= r
->w
+ r
->nw
;
957 for (BignumInt
*ap
= a
->w
, *rp
= r
->w
;
958 ap
< aend
&& rp
< rend
; ap
++, rp
++) {
960 BignumInt adata
= *ap
, carry
= 0, *rq
= rp
;
962 for (BignumInt
*bp
= b
->w
; bp
< bend
&& rq
< rend
; bp
++, rq
++) {
963 BignumInt bdata
= bp
< bend
? *bp
: 0;
964 BignumMULADD2(carry
, *rq
, adata
, bdata
, *rq
, carry
);
967 for (; rq
< rend
; rq
++)
968 BignumADC(*rq
, carry
, carry
, *rq
, 0);
972 #ifndef KARATSUBA_THRESHOLD /* allow redefinition via -D for testing */
973 #define KARATSUBA_THRESHOLD 24
/*
 * Upper bound on scratch words needed by mp_mul_internal for inputs of
 * effective length n. (The return was dropped by the extraction;
 * restored as n*6, which is the bound the surviving derivation comment
 * concludes with: "6n words of scratch space will suffice".)
 */
static inline size_t mp_mul_scratchspace_unary(size_t n)
{
    /*
     * Simplistic and overcautious bound on the amount of scratch
     * space that the recursive multiply function will need.
     *
     * On the main Karatsuba branch we allocate space for (a0+a1) and
     * (b0+b1) (each just over half of n) and their product (just over
     * n), then recurse at size just over n/2. Solving the resulting
     * recurrence conservatively, 6n words always suffice; this is
     * checked by assertion at every stage of the recursion.
     */
    return n * 6;
}
/*
 * Scratch words needed to multiply aw-by-bw words into an rw-word
 * result: sized by the effective input length, which cannot usefully
 * exceed the output width.
 */
static size_t mp_mul_scratchspace(size_t rw, size_t aw, size_t bw)
{
    size_t effective = size_t_min(rw, size_t_max(aw, bw));
    return mp_mul_scratchspace_unary(effective);
}
1012 static void mp_mul_internal(mp_int
*r
, mp_int
*a
, mp_int
*b
, mp_int scratch
)
1014 size_t inlen
= size_t_min(r
->nw
, size_t_max(a
->nw
, b
->nw
));
1015 assert(scratch
.nw
>= mp_mul_scratchspace_unary(inlen
));
1019 if (inlen
< KARATSUBA_THRESHOLD
|| a
->nw
== 0 || b
->nw
== 0) {
1021 * The input numbers are too small to bother optimising. Go
1022 * straight to the simple primitive approach.
1024 mp_mul_add_simple(r
, a
, b
);
1029 * Karatsuba divide-and-conquer algorithm. We cut each input in
1030 * half, so that it's expressed as two big 'digits' in a giant
1036 * Then the product is of course
1038 * ab = a_1 b_1 D^2 + (a_1 b_0 + a_0 b_1) D + a_0 b_0
1040 * and we compute the three coefficients by recursively calling
1041 * ourself to do half-length multiplications.
1043 * The clever bit that makes this worth doing is that we only need
1044 * _one_ half-length multiplication for the central coefficient
1045 * rather than the two that it obviouly looks like, because we can
1046 * use a single multiplication to compute
1048 * (a_1 + a_0) (b_1 + b_0) = a_1 b_1 + a_1 b_0 + a_0 b_1 + a_0 b_0
1050 * and then we subtract the other two coefficients (a_1 b_1 and
1051 * a_0 b_0) which we were computing anyway.
1053 * Hence we get to multiply two numbers of length N in about three
1054 * times as much work as it takes to multiply numbers of length
1055 * N/2, which is obviously better than the four times as much work
1056 * it would take if we just did a long conventional multiply.
1059 /* Break up the input as botlen + toplen, with botlen >= toplen.
1060 * The 'base' D is equal to 2^{botlen * BIGNUM_INT_BITS}. */
1061 size_t toplen
= inlen
/ 2;
1062 size_t botlen
= inlen
- toplen
;
1064 /* Alias bignums that address the two halves of a,b, and useful
1066 mp_int a0
= mp_make_alias(a
, 0, botlen
);
1067 mp_int b0
= mp_make_alias(b
, 0, botlen
);
1068 mp_int a1
= mp_make_alias(a
, botlen
, toplen
);
1069 mp_int b1
= mp_make_alias(b
, botlen
, toplen
);
1070 mp_int r0
= mp_make_alias(r
, 0, botlen
*2);
1071 mp_int r1
= mp_make_alias(r
, botlen
, r
->nw
);
1072 mp_int r2
= mp_make_alias(r
, botlen
*2, r
->nw
);
1074 /* Recurse to compute a0*b0 and a1*b1, in their correct positions
1075 * in the output bignum. They can't overlap. */
1076 mp_mul_internal(&r0
, &a0
, &b0
, scratch
);
1077 mp_mul_internal(&r2
, &a1
, &b1
, scratch
);
1079 if (r
->nw
< inlen
*2) {
1081 * The output buffer isn't large enough to require the whole
1082 * product, so some of a1*b1 won't have been stored. In that
1083 * case we won't try to do the full Karatsuba optimisation;
1084 * we'll just recurse again to compute a0*b1 and a1*b0 - or at
1085 * least as much of them as the output buffer size requires -
1086 * and add each one in.
1088 mp_int s
= mp_alloc_from_scratch(
1089 &scratch
, size_t_min(botlen
+toplen
, r1
.nw
));
1091 mp_mul_internal(&s
, &a0
, &b1
, scratch
);
1092 mp_add_into(&r1
, &r1
, &s
);
1093 mp_mul_internal(&s
, &a1
, &b0
, scratch
);
1094 mp_add_into(&r1
, &r1
, &s
);
1098 /* a0+a1 and b0+b1 */
1099 mp_int asum
= mp_alloc_from_scratch(&scratch
, botlen
+1);
1100 mp_int bsum
= mp_alloc_from_scratch(&scratch
, botlen
+1);
1101 mp_add_into(&asum
, &a0
, &a1
);
1102 mp_add_into(&bsum
, &b0
, &b1
);
1105 mp_int product
= mp_alloc_from_scratch(&scratch
, botlen
*2+1);
1106 mp_mul_internal(&product
, &asum
, &bsum
, scratch
);
1108 /* Subtract off the outer terms we already have */
1109 mp_sub_into(&product
, &product
, &r0
);
1110 mp_sub_into(&product
, &product
, &r2
);
1112 /* And add it in with the right offset. */
1113 mp_add_into(&r1
, &r1
, &product
);
1116 void mp_mul_into(mp_int
*r
, mp_int
*a
, mp_int
*b
)
1118 mp_int
*scratch
= mp_make_sized(mp_mul_scratchspace(r
->nw
, a
->nw
, b
->nw
));
1119 mp_mul_internal(r
, a
, b
, *scratch
);
1123 mp_int
*mp_mul(mp_int
*x
, mp_int
*y
)
1125 mp_int
*r
= mp_make_sized(x
->nw
+ y
->nw
);
1126 mp_mul_into(r
, x
, y
);
1130 void mp_lshift_fixed_into(mp_int
*r
, mp_int
*a
, size_t bits
)
1132 size_t words
= bits
/ BIGNUM_INT_BITS
;
1133 size_t bitoff
= bits
% BIGNUM_INT_BITS
;
1135 for (size_t i
= r
->nw
; i
-- > 0 ;) {
1139 r
->w
[i
] = mp_word(a
, i
- words
);
1143 r
->w
[i
] |= mp_word(a
, i
- words
- 1) >>
1144 (BIGNUM_INT_BITS
- bitoff
);
1150 void mp_rshift_fixed_into(mp_int
*r
, mp_int
*a
, size_t bits
)
1152 size_t words
= bits
/ BIGNUM_INT_BITS
;
1153 size_t bitoff
= bits
% BIGNUM_INT_BITS
;
1155 for (size_t i
= 0; i
< r
->nw
; i
++) {
1156 r
->w
[i
] = mp_word(a
, i
+ words
);
1159 r
->w
[i
] |= mp_word(a
, i
+ words
+ 1) << (BIGNUM_INT_BITS
- bitoff
);
1164 mp_int
*mp_lshift_fixed(mp_int
*x
, size_t bits
)
1166 size_t words
= (bits
+ BIGNUM_INT_BITS
- 1) / BIGNUM_INT_BITS
;
1167 mp_int
*r
= mp_make_sized(x
->nw
+ words
);
1168 mp_lshift_fixed_into(r
, x
, bits
);
1172 mp_int
*mp_rshift_fixed(mp_int
*x
, size_t bits
)
1174 size_t words
= bits
/ BIGNUM_INT_BITS
;
1175 size_t nw
= x
->nw
- size_t_min(x
->nw
, words
);
1176 mp_int
*r
= mp_make_sized(size_t_max(nw
, 1));
1177 mp_rshift_fixed_into(r
, x
, bits
);
1182 * Safe right shift is done using the same technique as
1183 * trim_leading_zeroes above: you make an n-word left shift by
1184 * composing an appropriate subset of power-of-2-sized shifts, so it
1185 * takes log_2(n) loop iterations each of which does a different shift
1186 * by a power of 2 words, using the usual bit twiddling to make the
1187 * whole shift conditional on the appropriate bit of n.
1189 static void mp_rshift_safe_in_place(mp_int
*r
, size_t bits
)
1191 size_t wordshift
= bits
/ BIGNUM_INT_BITS
;
1192 size_t bitshift
= bits
% BIGNUM_INT_BITS
;
1194 unsigned clear
= (r
->nw
- wordshift
) >> (CHAR_BIT
* sizeof(size_t) - 1);
1195 mp_cond_clear(r
, clear
);
1197 for (unsigned bit
= 0; r
->nw
>> bit
; bit
++) {
1198 size_t word_offset
= (size_t)1 << bit
;
1199 BignumInt mask
= -(BignumInt
)((wordshift
>> bit
) & 1);
1200 for (size_t i
= 0; i
< r
->nw
; i
++) {
1201 BignumInt w
= mp_word(r
, i
+ word_offset
);
1202 r
->w
[i
] ^= (r
->w
[i
] ^ w
) & mask
;
1207 * That's done the shifting by words; now we do the shifting by
1210 for (unsigned bit
= 0; bit
< BIGNUM_INT_BITS_BITS
; bit
++) {
1211 unsigned shift
= 1 << bit
, upshift
= BIGNUM_INT_BITS
- shift
;
1212 BignumInt mask
= -(BignumInt
)((bitshift
>> bit
) & 1);
1213 for (size_t i
= 0; i
< r
->nw
; i
++) {
1214 BignumInt w
= ((r
->w
[i
] >> shift
) | (mp_word(r
, i
+1) << upshift
));
1215 r
->w
[i
] ^= (r
->w
[i
] ^ w
) & mask
;
1220 mp_int
*mp_rshift_safe(mp_int
*x
, size_t bits
)
1222 mp_int
*r
= mp_copy(x
);
1223 mp_rshift_safe_in_place(r
, bits
);
1227 void mp_rshift_safe_into(mp_int
*r
, mp_int
*x
, size_t bits
)
1230 mp_rshift_safe_in_place(r
, bits
);
1233 static void mp_lshift_safe_in_place(mp_int
*r
, size_t bits
)
1235 size_t wordshift
= bits
/ BIGNUM_INT_BITS
;
1236 size_t bitshift
= bits
% BIGNUM_INT_BITS
;
1239 * Same strategy as mp_rshift_safe_in_place, but of course the
1243 unsigned clear
= (r
->nw
- wordshift
) >> (CHAR_BIT
* sizeof(size_t) - 1);
1244 mp_cond_clear(r
, clear
);
1246 for (unsigned bit
= 0; r
->nw
>> bit
; bit
++) {
1247 size_t word_offset
= (size_t)1 << bit
;
1248 BignumInt mask
= -(BignumInt
)((wordshift
>> bit
) & 1);
1249 for (size_t i
= r
->nw
; i
-- > 0 ;) {
1250 BignumInt w
= mp_word(r
, i
- word_offset
);
1251 r
->w
[i
] ^= (r
->w
[i
] ^ w
) & mask
;
1255 size_t downshift
= BIGNUM_INT_BITS
- bitshift
;
1256 size_t no_shift
= (downshift
>> BIGNUM_INT_BITS_BITS
);
1257 downshift
&= ~-(size_t)no_shift
;
1258 BignumInt downshifted_mask
= ~-(BignumInt
)no_shift
;
1260 for (size_t i
= r
->nw
; i
-- > 0 ;) {
1261 r
->w
[i
] = (r
->w
[i
] << bitshift
) |
1262 ((mp_word(r
, i
-1) >> downshift
) & downshifted_mask
);
1266 void mp_lshift_safe_into(mp_int
*r
, mp_int
*x
, size_t bits
)
1269 mp_lshift_safe_in_place(r
, bits
);
1272 void mp_reduce_mod_2to(mp_int
*x
, size_t p
)
1274 size_t word
= p
/ BIGNUM_INT_BITS
;
1275 size_t mask
= ((size_t)1 << (p
% BIGNUM_INT_BITS
)) - 1;
1276 for (; word
< x
->nw
; word
++) {
1283 * Inverse mod 2^n is computed by an iterative technique which doubles
1284 * the number of bits at each step.
1286 mp_int
*mp_invert_mod_2to(mp_int
*x
, size_t p
)
1288 /* Input checks: x must be coprime to the modulus, i.e. odd, and p
1291 assert(x
->w
[0] & 1);
1294 size_t rw
= (p
+ BIGNUM_INT_BITS
- 1) / BIGNUM_INT_BITS
;
1295 rw
= size_t_max(rw
, 1);
1296 mp_int
*r
= mp_make_sized(rw
);
1298 size_t mul_scratchsize
= mp_mul_scratchspace(2*rw
, rw
, rw
);
1299 mp_int
*scratch_orig
= mp_make_sized(6 * rw
+ mul_scratchsize
);
1300 mp_int scratch_per_iter
= *scratch_orig
;
1301 mp_int mul_scratch
= mp_alloc_from_scratch(
1302 &scratch_per_iter
, mul_scratchsize
);
1306 for (size_t b
= 1; b
< p
; b
<<= 1) {
1308 * In each step of this iteration, we have the inverse of x
1309 * mod 2^b, and we want the inverse of x mod 2^{2b}.
1311 * Write B = 2^b for convenience, so we want x^{-1} mod B^2.
1312 * Let x = x_0 + B x_1 + k B^2, with 0 <= x_0,x_1 < B.
1314 * We want to find r_0 and r_1 such that
1315 * (r_1 B + r_0) (x_1 B + x_0) == 1 (mod B^2)
1317 * To begin with, we know r_0 must be the inverse mod B of
1318 * x_0, i.e. of x, i.e. it is the inverse we computed in the
1319 * previous iteration. So now all we need is r_1.
1321 * Multiplying out, neglecting multiples of B^2, and writing
1322 * x_0 r_0 = K B + 1, we have
1324 * r_1 x_0 B + r_0 x_1 B + K B == 0 (mod B^2)
1325 * => r_1 x_0 B == - r_0 x_1 B - K B (mod B^2)
1326 * => r_1 x_0 == - r_0 x_1 - K (mod B)
1327 * => r_1 == r_0 (- r_0 x_1 - K) (mod B)
1329 * (the last step because we multiply through by the inverse
1330 * of x_0, which we already know is r_0).
1333 mp_int scratch_this_iter
= scratch_per_iter
;
1334 size_t Bw
= (b
+ BIGNUM_INT_BITS
- 1) / BIGNUM_INT_BITS
;
1335 size_t B2w
= (2*b
+ BIGNUM_INT_BITS
- 1) / BIGNUM_INT_BITS
;
1337 /* Start by finding K: multiply x_0 by r_0, and shift down. */
1338 mp_int x0
= mp_alloc_from_scratch(&scratch_this_iter
, Bw
);
1339 mp_copy_into(&x0
, x
);
1340 mp_reduce_mod_2to(&x0
, b
);
1341 mp_int r0
= mp_make_alias(r
, 0, Bw
);
1342 mp_int Kshift
= mp_alloc_from_scratch(&scratch_this_iter
, B2w
);
1343 mp_mul_internal(&Kshift
, &x0
, &r0
, mul_scratch
);
1344 mp_int K
= mp_alloc_from_scratch(&scratch_this_iter
, Bw
);
1345 mp_rshift_fixed_into(&K
, &Kshift
, b
);
1347 /* Now compute the product r_0 x_1, reusing the space of Kshift. */
1348 mp_int x1
= mp_alloc_from_scratch(&scratch_this_iter
, Bw
);
1349 mp_rshift_fixed_into(&x1
, x
, b
);
1350 mp_reduce_mod_2to(&x1
, b
);
1351 mp_int r0x1
= mp_make_alias(&Kshift
, 0, Bw
);
1352 mp_mul_internal(&r0x1
, &r0
, &x1
, mul_scratch
);
1354 /* Add K to that. */
1355 mp_add_into(&r0x1
, &r0x1
, &K
);
1358 mp_neg_into(&r0x1
, &r0x1
);
1360 /* Multiply by r_0. */
1361 mp_int r1
= mp_alloc_from_scratch(&scratch_this_iter
, Bw
);
1362 mp_mul_internal(&r1
, &r0
, &r0x1
, mul_scratch
);
1363 mp_reduce_mod_2to(&r1
, b
);
1365 /* That's our r_1, so add it on to r_0 to get the full inverse
1366 * output from this iteration. */
1367 mp_lshift_fixed_into(&K
, &r1
, (b
% BIGNUM_INT_BITS
));
1368 size_t Bpos
= b
/ BIGNUM_INT_BITS
;
1369 mp_int r1_position
= mp_make_alias(r
, Bpos
, B2w
-Bpos
);
1370 mp_add_into(&r1_position
, &r1_position
, &K
);
1373 /* Finally, reduce mod the precise desired number of bits. */
1374 mp_reduce_mod_2to(r
, p
);
1376 mp_free(scratch_orig
);
1380 static size_t monty_scratch_size(MontyContext
*mc
)
1382 return 3*mc
->rw
+ mc
->pw
+ mp_mul_scratchspace(mc
->pw
, mc
->rw
, mc
->rw
);
1385 MontyContext
*monty_new(mp_int
*modulus
)
1387 MontyContext
*mc
= snew(MontyContext
);
1389 mc
->rw
= modulus
->nw
;
1390 mc
->rbits
= mc
->rw
* BIGNUM_INT_BITS
;
1391 mc
->pw
= mc
->rw
* 2 + 1;
1393 mc
->m
= mp_copy(modulus
);
1395 mc
->minus_minv_mod_r
= mp_invert_mod_2to(mc
->m
, mc
->rbits
);
1396 mp_neg_into(mc
->minus_minv_mod_r
, mc
->minus_minv_mod_r
);
1398 mp_int
*r
= mp_make_sized(mc
->rw
+ 1);
1400 mc
->powers_of_r_mod_m
[0] = mp_mod(r
, mc
->m
);
1403 for (size_t j
= 1; j
< lenof(mc
->powers_of_r_mod_m
); j
++)
1404 mc
->powers_of_r_mod_m
[j
] = mp_modmul(
1405 mc
->powers_of_r_mod_m
[0], mc
->powers_of_r_mod_m
[j
-1], mc
->m
);
1407 mc
->scratch
= mp_make_sized(monty_scratch_size(mc
));
1412 void monty_free(MontyContext
*mc
)
1415 for (size_t j
= 0; j
< 3; j
++)
1416 mp_free(mc
->powers_of_r_mod_m
[j
]);
1417 mp_free(mc
->minus_minv_mod_r
);
1418 mp_free(mc
->scratch
);
1419 smemclr(mc
, sizeof(*mc
));
1424 * The main Montgomery reduction step.
1426 static mp_int
monty_reduce_internal(MontyContext
*mc
, mp_int
*x
, mp_int scratch
)
1429 * The trick with Montgomery reduction is that on the one hand we
1430 * want to reduce the size of the input by a factor of about r,
1431 * and on the other hand, the two numbers we just multiplied were
1432 * both stored with an extra factor of r multiplied in. So we
1433 * computed ar*br = ab r^2, but we want to return abr, so we need
1434 * to divide by r - and if we can do that by _actually dividing_
1435 * by r then this also reduces the size of the number.
1437 * But we can only do that if the number we're dividing by r is a
1438 * multiple of r. So first we must add an adjustment to it which
1439 * clears its bottom 'rbits' bits. That adjustment must be a
1440 * multiple of m in order to leave the residue mod n unchanged, so
1441 * the question is, what multiple of m can we add to x to make it
1442 * congruent to 0 mod r? And the answer is, x * (-m)^{-1} mod r.
1446 mp_int x_lo
= mp_make_alias(x
, 0, mc
->rbits
);
1448 /* x * (-m)^{-1}, i.e. the number we want to multiply by m */
1449 mp_int k
= mp_alloc_from_scratch(&scratch
, mc
->rw
);
1450 mp_mul_internal(&k
, &x_lo
, mc
->minus_minv_mod_r
, scratch
);
1452 /* m times that, i.e. the number we want to add to x */
1453 mp_int mk
= mp_alloc_from_scratch(&scratch
, mc
->pw
);
1454 mp_mul_internal(&mk
, mc
->m
, &k
, scratch
);
1457 mp_add_into(&mk
, x
, &mk
);
1459 /* Reduce mod r, by simply making an alias to the upper words of x */
1460 mp_int toret
= mp_make_alias(&mk
, mc
->rw
, mk
.nw
- mc
->rw
);
1463 * We'll generally be doing this after a multiplication of two
1464 * fully reduced values. So our input could be anything up to m^2,
1465 * and then we added up to rm to it. Hence, the maximum value is
1466 * rm+m^2, and after dividing by r, that becomes r + m(m/r) < 2r.
1467 * So a single trial-subtraction will finish reducing to the
1470 mp_cond_sub_into(&toret
, &toret
, mc
->m
, mp_cmp_hs(&toret
, mc
->m
));
1474 void monty_mul_into(MontyContext
*mc
, mp_int
*r
, mp_int
*x
, mp_int
*y
)
1476 assert(x
->nw
<= mc
->rw
);
1477 assert(y
->nw
<= mc
->rw
);
1479 mp_int scratch
= *mc
->scratch
;
1480 mp_int tmp
= mp_alloc_from_scratch(&scratch
, 2*mc
->rw
);
1481 mp_mul_into(&tmp
, x
, y
);
1482 mp_int reduced
= monty_reduce_internal(mc
, &tmp
, scratch
);
1483 mp_copy_into(r
, &reduced
);
1484 mp_clear(mc
->scratch
);
1487 mp_int
*monty_mul(MontyContext
*mc
, mp_int
*x
, mp_int
*y
)
1489 mp_int
*toret
= mp_make_sized(mc
->rw
);
1490 monty_mul_into(mc
, toret
, x
, y
);
1494 mp_int
*monty_modulus(MontyContext
*mc
)
1499 mp_int
*monty_identity(MontyContext
*mc
)
1501 return mc
->powers_of_r_mod_m
[0];
1504 mp_int
*monty_invert(MontyContext
*mc
, mp_int
*x
)
1506 /* Given xr, we want to return x^{-1}r = (xr)^{-1} r^2 =
1507 * monty_reduce((xr)^{-1} r^3) */
1508 mp_int
*tmp
= mp_invert(x
, mc
->m
);
1509 mp_int
*toret
= monty_mul(mc
, tmp
, mc
->powers_of_r_mod_m
[2]);
1515 * Importing a number into Montgomery representation involves
1516 * multiplying it by r and reducing mod m. We use the general-purpose
1517 * mp_modmul for this, in case the input number is out of range.
1519 mp_int
*monty_import(MontyContext
*mc
, mp_int
*x
)
1521 return mp_modmul(x
, mc
->powers_of_r_mod_m
[0], mc
->m
);
1524 void monty_import_into(MontyContext
*mc
, mp_int
*r
, mp_int
*x
)
1526 mp_int
*imported
= monty_import(mc
, x
);
1527 mp_copy_into(r
, imported
);
1532 * Exporting a number means multiplying it by r^{-1}, which is exactly
1533 * what monty_reduce does anyway, so we just do that.
1535 void monty_export_into(MontyContext
*mc
, mp_int
*r
, mp_int
*x
)
1537 assert(x
->nw
<= 2*mc
->rw
);
1538 mp_int reduced
= monty_reduce_internal(mc
, x
, *mc
->scratch
);
1539 mp_copy_into(r
, &reduced
);
1540 mp_clear(mc
->scratch
);
1543 mp_int
*monty_export(MontyContext
*mc
, mp_int
*x
)
1545 mp_int
*toret
= mp_make_sized(mc
->rw
);
1546 monty_export_into(mc
, toret
, x
);
1550 #define MODPOW_LOG2_WINDOW_SIZE 5
1551 #define MODPOW_WINDOW_SIZE (1 << MODPOW_LOG2_WINDOW_SIZE)
1552 mp_int
*monty_pow(MontyContext
*mc
, mp_int
*base
, mp_int
*exponent
)
1555 * Modular exponentiation is done from the top down, using a
1556 * fixed-window technique.
1558 * We have a table storing every power of the base from base^0 up
1559 * to base^{w-1}, where w is a small power of 2, say 2^k. (k is
1560 * defined above as MODPOW_LOG2_WINDOW_SIZE, and w = 2^k is
1561 * defined as MODPOW_WINDOW_SIZE.)
1563 * We break the exponent up into k-bit chunks, from the bottom up,
1566 * exponent = c_0 + 2^k c_1 + 2^{2k} c_2 + ... + 2^{nk} c_n
1568 * and we compute base^exponent by computing in turn
1571 * base^{2^k c_n + c_{n-1}}
1572 * base^{2^{2k} c_n + 2^k c_{n-1} + c_{n-2}}
1575 * where each line is obtained by raising the previous line to the
1576 * power 2^k (i.e. squaring it k times) and then multiplying in
1577 * a value base^{c_i}, which we can look up in our table.
1579 * Side-channel considerations: the exponent is secret, so
1580 * actually doing a single table lookup by using a chunk of
1581 * exponent bits as an array index would be an obvious leak of
1582 * secret information into the cache. So instead, in each
1583 * iteration, we read _all_ the table entries, and do a sequence
1584 * of mp_select operations to leave just the one we wanted in the
1585 * variable that will go into the multiplication. In other
1586 * contexts (like software AES) that technique is so prohibitively
1587 * slow that it makes you choose a strategy that doesn't use table
1588 * lookups at all (we do bitslicing in preference); but here, this
1589 * iteration through 2^k table elements is replacing k-1 bignum
1590 * _multiplications_ that you'd have to use instead if you did
1591 * simple square-and-multiply, and that makes it still a win.
1594 /* Table that holds base^0, ..., base^{w-1} */
1595 mp_int
*table
[MODPOW_WINDOW_SIZE
];
1596 table
[0] = mp_copy(monty_identity(mc
));
1597 for (size_t i
= 1; i
< MODPOW_WINDOW_SIZE
; i
++)
1598 table
[i
] = monty_mul(mc
, table
[i
-1], base
);
1600 /* out accumulates the output value */
1601 mp_int
*out
= mp_make_sized(mc
->rw
);
1602 mp_copy_into(out
, monty_identity(mc
));
1604 /* table_entry will hold each value we get out of the table */
1605 mp_int
*table_entry
= mp_make_sized(mc
->rw
);
1607 /* Bit index of the chunk of bits we're working on. Start with the
1608 * highest multiple of k strictly less than the size of our
1609 * bignum, i.e. the highest-index chunk of bits that might
1610 * conceivably contain any nonzero bit. */
1611 size_t i
= (exponent
->nw
* BIGNUM_INT_BITS
) - 1;
1612 i
-= i
% MODPOW_LOG2_WINDOW_SIZE
;
1614 bool first_iteration
= true;
1617 /* Construct the table index */
1618 unsigned table_index
= 0;
1619 for (size_t j
= 0; j
< MODPOW_LOG2_WINDOW_SIZE
; j
++)
1620 table_index
|= mp_get_bit(exponent
, i
+j
) << j
;
1622 /* Iterate through the table to do a side-channel-safe lookup,
1623 * ending up with table_entry = table[table_index] */
1624 mp_copy_into(table_entry
, table
[0]);
1625 for (size_t j
= 1; j
< MODPOW_WINDOW_SIZE
; j
++) {
1626 unsigned not_this_one
=
1627 ((table_index
^ j
) + MODPOW_WINDOW_SIZE
- 1)
1628 >> MODPOW_LOG2_WINDOW_SIZE
;
1629 mp_select_into(table_entry
, table
[j
], table_entry
, not_this_one
);
1632 if (!first_iteration
) {
1633 /* Multiply into the output */
1634 monty_mul_into(mc
, out
, out
, table_entry
);
1636 /* On the first iteration, we can save one multiplication
1637 * by just copying */
1638 mp_copy_into(out
, table_entry
);
1639 first_iteration
= false;
1642 /* If that was the bottommost chunk of bits, we're done */
1646 /* Otherwise, square k times and go round again. */
1647 for (size_t j
= 0; j
< MODPOW_LOG2_WINDOW_SIZE
; j
++)
1648 monty_mul_into(mc
, out
, out
, out
);
1650 i
-= MODPOW_LOG2_WINDOW_SIZE
;
1653 for (size_t i
= 0; i
< MODPOW_WINDOW_SIZE
; i
++)
1655 mp_free(table_entry
);
1656 mp_clear(mc
->scratch
);
1660 mp_int
*mp_modpow(mp_int
*base
, mp_int
*exponent
, mp_int
*modulus
)
1662 assert(modulus
->nw
> 0);
1663 assert(modulus
->w
[0] & 1);
1665 MontyContext
*mc
= monty_new(modulus
);
1666 mp_int
*m_base
= monty_import(mc
, base
);
1667 mp_int
*m_out
= monty_pow(mc
, m_base
, exponent
);
1668 mp_int
*out
= monty_export(mc
, m_out
);
1676 * Given two input integers a,b which are not both even, computes d =
1677 * gcd(a,b) and also two integers A,B such that A*a - B*b = d. A,B
1678 * will be the minimal non-negative pair satisfying that criterion,
1679 * which is equivalent to saying that 0 <= A < b/d and 0 <= B < a/d.
1681 * This algorithm is an adapted form of Stein's algorithm, which
1682 * computes gcd(a,b) using only addition and bit shifts (i.e. without
1683 * needing general division), using the following rules:
1685 * - if both of a,b are even, divide off a common factor of 2
1686 * - if one of a,b (WLOG a) is even, then gcd(a,b) = gcd(a/2,b), so
1687 * just divide a by 2
1688 * - if both of a,b are odd, then WLOG a>b, and gcd(a,b) =
1691 * Sometimes this function is used for modular inversion, in which
1692 * case we already know we expect the two inputs to be coprime, so to
1693 * save time the 'both even' initial case is assumed not to arise (or
1694 * to have been handled already by the caller). So this function just
1695 * performs a sequence of reductions in the following form:
1697 * - if a,b are both odd, sort them so that a > b, and replace a with
1698 * b-a; otherwise sort them so that a is the even one
1699 * - either way, now a is even and b is odd, so divide a by 2.
1701 * The big change to Stein's algorithm is that we need the Bezout
1702 * coefficients as output, not just the gcd. So we need to know how to
1703 * generate those in each case, based on the coefficients from the
1704 * reduced pair of numbers:
1706 * - If a is even, and u,v are such that u*(a/2) + v*b = d:
1707 * + if u is also even, then this is just (u/2)*a + v*b = d
1708 * + otherwise, (u+b)*(a/2) + (v-a/2)*b is also equal to d, and
1709 * since u and b are both odd, (u+b)/2 is an integer, so we have
1710 * ((u+b)/2)*a + (v-a/2)*b = d.
1712 * - If a,b are both odd, and u,v are such that u*b + v*(a-b) = d,
1713 * then v*a + (u-v)*b = d.
1715 * In the case where we passed from (a,b) to (b,(a-b)/2), we regard it
1716 * as having first subtracted b from a and then halved a, so both of
1717 * these transformations must be done in sequence.
1719 * The code below transforms this from a recursive to an iterative
1720 * algorithm. We first reduce a,b to 0,1, recording at each stage
1721 * whether we did the initial subtraction, and whether we had to swap
1722 * the two values; then we iterate backwards over that record of what
1723 * we did, applying the above rules for building up the Bezout
1724 * coefficients as we go. Of course, all the case analysis is done by
1725 * the usual bit-twiddling conditionalisation to avoid data-dependent
1728 * Also, since these mp_ints are generally treated as unsigned, we
1729 * store the coefficients by absolute value, with the semantics that
1730 * they always have opposite sign, and in the unwinding loop we keep a
1731 * bit indicating whether Aa-Bb is currently expected to be +d or -d,
1732 * so that we can do one final conditional adjustment if it's -d.
1734 * Once the reduction rules have managed to reduce the input numbers
1735 * to (0,d), then they are stable (the next reduction will always
1736 * divide the even one by 2, which maps 0 to 0). So it doesn't matter
1737 * if we do more steps of the algorithm than necessary; hence, for
1738 * constant time, we just need to find the maximum number we could
1739 * _possibly_ require, and do that many.
1741 * If a,b < 2^n, at most 2n iterations are required. Proof: consider
1742 * the quantity Q = log_2(a) + log_2(b). Every step halves one of the
1743 * numbers (and may also reduce one of them further by doing a
1744 * subtraction beforehand, but in the worst case, not by much or not
1745 * at all). So Q reduces by at least 1 per iteration, and it starts
1746 * off with a value at most 2n.
1748 * The worst case inputs (I think) are where x=2^{n-1} and y=2^n-1
1749 * (i.e. x is a power of 2 and y is all 1s). In that situation, the
1750 * first n-1 steps repeatedly halve x until it's 1, and then there are
1751 * n further steps each of which subtracts 1 from y and halves it.
1753 static void mp_bezout_into(mp_int
*a_coeff_out
, mp_int
*b_coeff_out
,
1754 mp_int
*gcd_out
, mp_int
*a_in
, mp_int
*b_in
)
1756 size_t nw
= size_t_max(1, size_t_max(a_in
->nw
, b_in
->nw
));
1758 /* Make mutable copies of the input numbers */
1759 mp_int
*a
= mp_make_sized(nw
), *b
= mp_make_sized(nw
);
1760 mp_copy_into(a
, a_in
);
1761 mp_copy_into(b
, b_in
);
1763 /* Space to build up the output coefficients, with an extra word
1764 * so that intermediate values can overflow off the top and still
1765 * right-shift back down to the correct value */
1766 mp_int
*ac
= mp_make_sized(nw
+ 1), *bc
= mp_make_sized(nw
+ 1);
1768 /* And a general-purpose temp register */
1769 mp_int
*tmp
= mp_make_sized(nw
);
1771 /* Space to record the sequence of reduction steps to unwind. We
1772 * make it a BignumInt for no particular reason except that (a)
1773 * mp_make_sized conveniently zeroes the allocation and mp_free
1774 * wipes it, and (b) this way I can use mp_dump() if I have to
1775 * debug this code. */
1776 size_t steps
= 2 * nw
* BIGNUM_INT_BITS
;
1777 mp_int
*record
= mp_make_sized(
1778 (steps
*2 + BIGNUM_INT_BITS
- 1) / BIGNUM_INT_BITS
);
1780 for (size_t step
= 0; step
< steps
; step
++) {
1782 * If a and b are both odd, we want to sort them so that a is
1783 * larger. But if one is even, we want to sort them so that a
1786 unsigned swap_if_both_odd
= mp_cmp_hs(b
, a
);
1787 unsigned swap_if_one_even
= a
->w
[0] & 1;
1788 unsigned both_odd
= a
->w
[0] & b
->w
[0] & 1;
1789 unsigned swap
= swap_if_one_even
^ (
1790 (swap_if_both_odd
^ swap_if_one_even
) & both_odd
);
1792 mp_cond_swap(a
, b
, swap
);
1795 * If a,b are both odd, then a is the larger number, so
1796 * subtract the smaller one from it.
1798 mp_cond_sub_into(a
, a
, b
, both_odd
);
1801 * Now a is even, so divide it by two.
1803 mp_rshift_fixed_into(a
, a
, 1);
1806 * Record the two 1-bit values both_odd and swap.
1808 mp_set_bit(record
, step
*2, both_odd
);
1809 mp_set_bit(record
, step
*2+1, swap
);
1813 * Now we expect to have reduced the two numbers to 0 and d,
1814 * although we don't know which way round. (But we avoid checking
1815 * this by assertion; sometimes we'll need to do this computation
1816 * without giving away that we already know the inputs were bogus.
1817 * So we'd prefer to just press on and return nonsense.)
1822 * At this point we can return the actual gcd. Since one of
1823 * a,b is it and the other is zero, the easiest way to get it
1824 * is to add them together.
1826 mp_add_into(gcd_out
, a
, b
);
1830 * If the caller _only_ wanted the gcd, and neither Bezout
1831 * coefficient is even required, we can skip the entire unwind
1834 if (a_coeff_out
|| b_coeff_out
) {
1837 * The Bezout coefficients of a,b at this point are simply 0
1838 * for whichever of a,b is zero, and 1 for whichever is
1839 * nonzero. The nonzero number equals gcd(a,b), which by
1840 * assumption is odd, so we can do this by just taking the low
1843 ac
->w
[0] = mp_get_bit(a
, 0);
1844 bc
->w
[0] = mp_get_bit(b
, 0);
1847 * Overwrite a,b themselves with those same numbers. This has
1848 * the effect of dividing both of them by d, which will
1849 * arrange that during the unwind stage we generate the
1850 * minimal coefficients instead of a larger pair.
1852 mp_copy_into(a
, ac
);
1853 mp_copy_into(b
, bc
);
1856 * We'll maintain the invariant as we unwind that ac * a - bc
1857 * * b is either +d or -d (or rather, +1/-1 after scaling by
1858 * d), and we'll remember which. (We _could_ keep it at +d the
1859 * whole time, but it would cost more work every time round
1860 * the loop, so it's cheaper to fix that up once at the end.)
1862 * Initially, the result is +d if a was the nonzero value after
1863 * reduction, and -d if b was.
1865 unsigned minus_d
= b
->w
[0];
1867 for (size_t step
= steps
; step
-- > 0 ;) {
1869 * Recover the data from the step we're unwinding.
1871 unsigned both_odd
= mp_get_bit(record
, step
*2);
1872 unsigned swap
= mp_get_bit(record
, step
*2+1);
1875 * Unwind the division: if our coefficient of a is odd, we
1876 * adjust the coefficients by +b and +a respectively.
1878 unsigned adjust
= ac
->w
[0] & 1;
1879 mp_cond_add_into(ac
, ac
, b
, adjust
);
1880 mp_cond_add_into(bc
, bc
, a
, adjust
);
1883 * Now ac is definitely even, so we divide it by two.
1885 mp_rshift_fixed_into(ac
, ac
, 1);
1888 * Now unwind the subtraction, if there was one, by adding
1891 mp_cond_add_into(bc
, bc
, ac
, both_odd
);
1894 * Undo the transformation of the input numbers, by
1895 * multiplying a by 2 and then adding b to a (the latter
1896 * only if both_odd).
1898 mp_lshift_fixed_into(a
, a
, 1);
1899 mp_cond_add_into(a
, a
, b
, both_odd
);
1902 * Finally, undo the swap. If we do swap, this also
1903 * reverses the sign of the current result ac*a+bc*b.
1905 mp_cond_swap(a
, b
, swap
);
1906 mp_cond_swap(ac
, bc
, swap
);
1911 * Now we expect to have recovered the input a,b (or rather,
1912 * the versions of them divided by d). But we might find that
1913 * our current result is -d instead of +d, that is, we have
1914 * A',B' such that A'a - B'b = -d.
1916 * In that situation, we set A = b-A' and B = a-B', giving us
1917 * Aa-Bb = ab - A'a - ab + B'b = +1.
1919 mp_sub_into(tmp
, b
, ac
);
1920 mp_select_into(ac
, ac
, tmp
, minus_d
);
1921 mp_sub_into(tmp
, a
, bc
);
1922 mp_select_into(bc
, bc
, tmp
, minus_d
);
1925 * Now we really are done. Return the outputs.
1928 mp_copy_into(a_coeff_out
, ac
);
1930 mp_copy_into(b_coeff_out
, bc
);
1942 mp_int
*mp_invert(mp_int
*x
, mp_int
*m
)
1944 mp_int
*result
= mp_make_sized(m
->nw
);
1945 mp_bezout_into(result
, NULL
, NULL
, x
, m
);
1949 void mp_gcd_into(mp_int
*a
, mp_int
*b
, mp_int
*gcd
, mp_int
*A
, mp_int
*B
)
1952 * Identify shared factors of 2. To do this we OR the two numbers
1953 * to get something whose lowest set bit is in the right place,
1954 * remove all higher bits by ANDing it with its own negation, and
1955 * use mp_get_nbits to find the location of the single remaining
1958 mp_int
*tmp
= mp_make_sized(size_t_max(a
->nw
, b
->nw
));
1959 for (size_t i
= 0; i
< tmp
->nw
; i
++)
1960 tmp
->w
[i
] = mp_word(a
, i
) | mp_word(b
, i
);
1961 BignumCarry carry
= 1;
1962 for (size_t i
= 0; i
< tmp
->nw
; i
++) {
1964 BignumADC(negw
, carry
, 0, ~tmp
->w
[i
], carry
);
1967 size_t shift
= mp_get_nbits(tmp
) - 1;
1971 * Make copies of a,b with those shared factors of 2 divided off,
1972 * so that at least one is odd (which is the precondition for
1973 * mp_bezout_into). Compute the gcd of those.
1975 mp_int
*as
= mp_rshift_safe(a
, shift
);
1976 mp_int
*bs
= mp_rshift_safe(b
, shift
);
1977 mp_bezout_into(A
, B
, gcd
, as
, bs
);
1982 * And finally shift the gcd back up (unless the caller didn't
1983 * even ask for it), to put the shared factors of 2 back in.
1986 mp_lshift_safe_in_place(gcd
, shift
);
1989 mp_int
*mp_gcd(mp_int
*a
, mp_int
*b
)
1991 mp_int
*gcd
= mp_make_sized(size_t_min(a
->nw
, b
->nw
));
1992 mp_gcd_into(a
, b
, gcd
, NULL
, NULL
);
1996 unsigned mp_coprime(mp_int
*a
, mp_int
*b
)
1998 mp_int
*gcd
= mp_gcd(a
, b
);
1999 unsigned toret
= mp_eq_integer(gcd
, 1);
static uint32_t recip_approx_32(uint32_t x)
{
    /*
     * Given an input x in [2^31,2^32), i.e. a uint32_t with its high
     * bit set, this function returns an approximation to 2^63/x,
     * computed using only multiplications and bit shifts just in case
     * the C divide operator has non-constant time (either because the
     * underlying machine instruction does, or because the operator
     * expands to a library function on a CPU without hardware
     * division).
     *
     * The coefficients are derived from those of the degree-9
     * polynomial which is the minimax-optimal approximation to that
     * function on the given interval (generated using the Remez
     * algorithm), converted into integer arithmetic with shifts used
     * to maximise the number of significant bits at every state. (A
     * sort of 'static floating point' - the exponent is statically
     * known at every point in the code, so it never needs to be
     * stored at run time or to influence runtime decisions.)
     *
     * Exhaustive iteration over the whole input space shows the
     * largest possible error to be 1686.54. (The input value
     * attaining that bound is 4226800006 == 0xfbefd986, whose true
     * reciprocal is 2182116973.540... == 0x8210766d.8a6..., whereas
     * this function returns 2182115287 == 0x82106fd7.)
     */
    uint64_t r = 0x92db03d6ULL;
    r = 0xf63e71eaULL - ((r*x) >> 34);
    r = 0xb63721e8ULL - ((r*x) >> 34);
    r = 0x9c2da00eULL - ((r*x) >> 33);
    r = 0xaada0bb8ULL - ((r*x) >> 32);
    r = 0xf75cd403ULL - ((r*x) >> 31);
    r = 0xecf97a41ULL - ((r*x) >> 31);
    r = 0x90d876cdULL - ((r*x) >> 31);
    r = 0x6682799a0ULL - ((r*x) >> 26);
    return r;                  /* restored: truncates to 32 bits */
}
/*
 * Division, implementing mpint.h's mp_divmod_into.
 *
 * n: numerator; d: denominator (asserted nonzero).
 * q_out: receives floor(n/d) if non-NULL; r_out: receives n mod d if
 * non-NULL. Runs in time dependent only on the nominal sizes of the
 * inputs, not their values (constant-time).
 */
void mp_divmod_into(mp_int *n, mp_int *d, mp_int *q_out, mp_int *r_out)
{
    assert(!mp_eq_integer(d, 0));

    /*
     * We do division by using Newton-Raphson iteration to converge to
     * the reciprocal of d (or rather, R/d for R a sufficiently large
     * power of 2); then we multiply that reciprocal by n; and we
     * finish up with conditional subtraction.
     *
     * But we have to do it in a fixed number of N-R iterations, so we
     * need some error analysis to know how many we might need.
     *
     * The iteration is derived by defining f(r) = d - R/r.
     * Differentiating gives f'(r) = R/r^2, and the Newton-Raphson
     * formula applied to those functions gives
     *
     *      r_{i+1} = r_i - f(r_i) / f'(r_i)
     *              = r_i - (d - R/r_i) r_i^2 / R
     *              = r_i (2 R - d r_i) / R
     *
     * Now let e_i be the error in a given iteration, in the sense
     * that
     *
     *        d r_i = R + e_i
     *   i.e. e_i/R = (r_i - r_true) / r_true
     *
     * so e_i is the _relative_ error in r_i.
     *
     * We must also introduce a rounding-error term, because the
     * division by R always gives an integer. This might make the
     * output off by up to 1 (in the negative direction, because
     * right-shifting gives floor of the true quotient). So when we
     * divide by R, we must imagine adding some f in [0,1). Then we
     * have
     *
     *    d r_{i+1} = d r_i (2 R - d r_i) / R - d f
     *              = (R + e_i) (R - e_i) / R - d f
     *              = (R^2 - e_i^2) / R - d f
     *              = R - (e_i^2 / R + d f)
     * => e_{i+1}  = - (e_i^2 / R + d f)
     *
     * The sum of two positive quantities is bounded above by twice
     * their max, and max |f| = 1, so we can bound this as follows:
     *
     *           |e_{i+1}| <= 2 max (e_i^2/R, d)
     *         |e_{i+1}/R| <= 2 max ((e_i/R)^2, d/R)
     *    log2 |R/e_{i+1}| >= min (2 log2 |R/e_i|, log2 |R/d|) - 1
     *
     * which tells us that the number of 'good' bits - i.e.
     * log2(R/e_i) - very nearly doubles at every iteration (apart
     * from that subtraction of 1), until it gets to the same size as
     * log2(R/d). In other words, the size of R in bits has to be the
     * size of denominator we're putting in, _plus_ the amount of
     * precision we want to get back out.
     *
     * So when we multiply n (the input numerator) by our final
     * reciprocal approximation r, but actually r differs from R/d by
     * up to 2, then it follows that
     *
     *   n/d - nr/R = n/d - [ n (R/d + e) ] / R
     *              = n/d - [ (n/d) R + n e ] / R
     *              = -ne/R
     *   => 0 <= n/d - nr/R < 2n/R
     *
     * so our computed quotient can differ from the true n/d by up to
     * 2n/R. Hence, as long as we also choose R large enough that 2n/R
     * is bounded above by a constant, we can guarantee a bounded
     * number of final conditional-subtraction steps.
     */

    /*
     * Get at least 32 of the most significant bits of the input
     * number.
     */
    size_t hiword_index = 0;
    uint64_t hibits = 0, lobits = 0;
    mp_find_highest_nonzero_word_pair(d, 64 - BIGNUM_INT_BITS,
                                      &hiword_index, &hibits, &lobits);

    /*
     * Make a shifted combination of those two words which puts the
     * topmost bit of the number at bit 63. Done as log2(64) fixed
     * passes of power-of-two shifts with constant-time conditional
     * swaps, so the shift distance never feeds a variable shift.
     */
    size_t shift_up = 0;
    for (size_t i = BIGNUM_INT_BITS_BITS; i-- > 0;) {
        size_t sl = (size_t)1 << i;  /* left shift count */
        size_t sr = 64 - sl;         /* complementary right-shift count */

        /* Should we shift up? (1 iff the top sl bits are all zero) */
        unsigned indicator = 1 ^ normalise_to_1_u64(hibits >> sr);

        /* If we do, what will we get? */
        uint64_t new_hibits = (hibits << sl) | (lobits >> sr);
        uint64_t new_lobits = lobits << sl;
        size_t new_shift_up = shift_up + sl;

        /* Conditionally swap those values in. */
        hibits   ^= (hibits   ^ new_hibits)   & -(uint64_t)indicator;
        lobits   ^= (lobits   ^ new_lobits)   & -(uint64_t)indicator;
        shift_up ^= (shift_up ^ new_shift_up) & -(size_t)  indicator;
    }

    /*
     * So now we know the most significant 32 bits of d are at the top
     * of hibits. Approximate the reciprocal of those bits.
     */
    lobits = (uint64_t)recip_approx_32(hibits >> 32) << 32;
    hibits = 0;  /* NOTE(review): reconstructed line - the reciprocal
                  * starts out in the low word only; confirm upstream */

    /*
     * And shift that up by as many bits as the input was shifted up
     * just now, so that the product of this approximation and the
     * actual input will be close to a fixed power of two regardless
     * of where the MSB was.
     *
     * I do this in another log n individual passes, partly in case
     * the CPU's register-controlled shift operation isn't
     * time-constant, and also in case the compiler code-generates
     * uint64_t shifts out of a variable number of smaller-word shift
     * instructions, e.g. by splitting up into cases.
     */
    for (size_t i = BIGNUM_INT_BITS_BITS; i-- > 0;) {
        size_t sl = (size_t)1 << i;  /* left shift count */
        size_t sr = 64 - sl;         /* complementary right-shift count */

        /* Should we shift up? (bit i of the recorded shift count) */
        unsigned indicator = 1 & (shift_up >> i);

        /* If we do, what will we get? */
        uint64_t new_hibits = (hibits << sl) | (lobits >> sr);
        uint64_t new_lobits = lobits << sl;

        /* Conditionally swap those values in. */
        hibits ^= (hibits ^ new_hibits) & -(uint64_t)indicator;
        lobits ^= (lobits ^ new_lobits) & -(uint64_t)indicator;
    }

    /*
     * The product of the 128-bit value now in hibits:lobits with the
     * 128-bit value we originally retrieved in the same variables
     * will be in the vicinity of 2^191. So we'll take log2(R) to be
     * 191, plus a multiple of BIGNUM_INT_BITS large enough to allow R
     * to hold the combined sizes of n and d.
     */
    size_t log2_R;
    {
        size_t max_log2_n = (n->nw + d->nw) * BIGNUM_INT_BITS;
        log2_R = max_log2_n + 3;
        log2_R -= size_t_min(191, log2_R);
        log2_R = (log2_R + BIGNUM_INT_BITS - 1) & ~(BIGNUM_INT_BITS - 1);
        log2_R += 191;  /* NOTE(review): reconstructed - R must exceed
                         * 2^191 for the subtraction below; confirm */
    }

    /* Number of words in a bignum capable of holding numbers the size
     * of twice R. */
    size_t rw = ((log2_R+2) + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS;

    /*
     * Now construct our full-sized starting reciprocal approximation.
     */
    mp_int *r_approx = mp_make_sized(rw);
    size_t output_bit_index;
    {
        /* Where in the input number did the input 128-bit value come
         * from? */
        size_t input_bit_index =
            (hiword_index * BIGNUM_INT_BITS) - (128 - BIGNUM_INT_BITS);

        /* So how far do we need to shift our 64-bit output, if the
         * product of those two fixed-size values is 2^191 and we want
         * to make it 2^log2_R instead? */
        output_bit_index = log2_R - 191 - input_bit_index;

        /* If we've done all that right, it should be a whole number
         * of words. */
        assert(output_bit_index % BIGNUM_INT_BITS == 0);
        size_t output_word_index = output_bit_index / BIGNUM_INT_BITS;

        mp_add_integer_into_shifted_by_words(
            r_approx, r_approx, lobits, output_word_index);
        mp_add_integer_into_shifted_by_words(
            r_approx, r_approx, hibits,
            output_word_index + 64 / BIGNUM_INT_BITS);
    }

    /*
     * Make the constant 2*R, which we'll need in the iteration.
     */
    mp_int *two_R = mp_make_sized(rw);
    BignumInt top_word = (BignumInt)1 << ((log2_R+1) % BIGNUM_INT_BITS);
    mp_add_integer_into_shifted_by_words(
        two_R, two_R, top_word, (log2_R+1) / BIGNUM_INT_BITS);

    /*
     * Scratch space for the iteration: d*r, (2R - d*r), their product,
     * multiplication workspace, and an alias viewing 'product' shifted
     * down by a whole number of words (the in-word remainder of the
     * shift is done by mp_rshift_fixed_into below).
     */
    mp_int *dr = mp_make_sized(rw + d->nw);
    mp_int *diff = mp_make_sized(size_t_max(rw, dr->nw));
    mp_int *product = mp_make_sized(rw + diff->nw);
    size_t scratchsize = size_t_max(
        mp_mul_scratchspace(dr->nw, r_approx->nw, d->nw),
        mp_mul_scratchspace(product->nw, r_approx->nw, diff->nw));
    mp_int *scratch = mp_make_sized(scratchsize);
    mp_int product_shifted = mp_make_alias(
        product, log2_R / BIGNUM_INT_BITS, product->nw);

    /*
     * Initial error estimate: the 32-bit output of recip_approx_32
     * differs by less than 2048 (== 2^11) from the true top 32 bits
     * of the reciprocal, so the relative error is at most 2^11
     * divided by the 32-bit reciprocal, which at worst is 2^11/2^31 =
     * 2^-20. So even in the worst case, we have 20 good bits of
     * reciprocal to start with.
     */
    size_t good_bits = 31 - 11;
    size_t good_bits_needed = BIGNUM_INT_BITS * n->nw + 4; /* add a few */

    /*
     * Now do Newton-Raphson iterations until we have reason to think
     * they're not converging any more.
     */
    while (good_bits < good_bits_needed) {
        /*
         * Compute the next iterate: r <- r (2R - d r) / R.
         */
        mp_mul_internal(dr, r_approx, d, *scratch);
        mp_sub_into(diff, two_R, dr);
        mp_mul_internal(product, r_approx, diff, *scratch);
        mp_rshift_fixed_into(r_approx, &product_shifted,
                             log2_R % BIGNUM_INT_BITS);

        /*
         * Adjust the error estimate: good bits double, minus one.
         */
        good_bits = good_bits * 2 - 1;
    }

    mp_free(dr);
    mp_free(diff);
    mp_free(product);
    mp_free(scratch);

    /*
     * Now we've got our reciprocal, we can compute the quotient, by
     * multiplying in n and then shifting down by log2_R bits.
     */
    mp_int *quotient_full = mp_mul(r_approx, n);
    mp_int quotient_alias = mp_make_alias(
        quotient_full, log2_R / BIGNUM_INT_BITS, quotient_full->nw);
    mp_int *quotient = mp_make_sized(n->nw);
    mp_rshift_fixed_into(quotient, &quotient_alias,
                         log2_R % BIGNUM_INT_BITS);

    /*
     * Next, compute the remainder.
     */
    mp_int *remainder = mp_make_sized(d->nw);
    mp_mul_into(remainder, quotient, d);
    mp_sub_into(remainder, n, remainder);

    /*
     * Finally, two conditional subtractions to fix up any remaining
     * rounding error. (I _think_ one should be enough, but this
     * routine isn't time-critical enough to take chances.)
     */
    unsigned q_correction = 0;
    for (unsigned iter = 0; iter < 2; iter++) {
        unsigned need_correction = mp_cmp_hs(remainder, d);
        mp_cond_sub_into(remainder, remainder, d, need_correction);
        q_correction += need_correction;
    }
    mp_add_integer_into(quotient, quotient, q_correction);

    /*
     * Now we should have a perfect answer, i.e. 0 <= r < d.
     */
    assert(!mp_cmp_hs(remainder, d));

    if (q_out)
        mp_copy_into(q_out, quotient);
    if (r_out)
        mp_copy_into(r_out, remainder);

    mp_free(r_approx);
    mp_free(two_R);
    mp_free(quotient_full);
    mp_free(quotient);
    mp_free(remainder);
}
2331 mp_int
*mp_div(mp_int
*n
, mp_int
*d
)
2333 mp_int
*q
= mp_make_sized(n
->nw
);
2334 mp_divmod_into(n
, d
, q
, NULL
);
2338 mp_int
*mp_mod(mp_int
*n
, mp_int
*d
)
2340 mp_int
*r
= mp_make_sized(d
->nw
);
2341 mp_divmod_into(n
, d
, NULL
, r
);
/*
 * Reduce x modulo the small non-secret integer m, in time independent
 * of the value of x. Avoids a hardware divide instruction on the
 * secret data (whose timing may be value-dependent) by using a
 * fixed-point reciprocal of m, processing x one byte at a time from
 * the top.
 */
uint32_t mp_mod_known_integer(mp_int *x, uint32_t m)
{
    uint64_t reciprocal = ((uint64_t)1 << 48) / m;
    uint64_t accumulator = 0;
    for (size_t i = mp_max_bytes(x); i-- > 0 ;) {
        accumulator = 0x100 * accumulator + mp_get_byte(x, i);

        /*
         * Let A be the value in 'accumulator' at this point, and let
         * R be the value it will have after we subtract quot*m below.
         *
         * Lemma 1: if A < 2^48, then R < 2m.
         *
         * Proof:
         *
         * By construction, we have 2^48/m - 1 < reciprocal <= 2^48/m.
         * Multiplying that by the accumulator gives
         *
         *      A/m * 2^48 - A < unshifted_quot <= A/m * 2^48
         * i.e. 0 <= (A/m * 2^48) - unshifted_quot < A
         * i.e. 0 <= A/m - unshifted_quot/2^48 < A/2^48
         *
         * So when we shift this quotient right by 48 bits, i.e. take
         * the floor of (unshifted_quot/2^48), the value we take the
         * floor of is at most A/2^48 less than the true rational
         * value A/m that we _wanted_ to take the floor of.
         *
         * Provided A < 2^48, this is less than 1. So the quotient
         * 'quot' that we've just produced is either the true quotient
         * floor(A/m), or one less than it. Hence, the output value R
         * is less than 2m. []
         *
         * Lemma 2: if A < 2^16 m, then the multiplication of
         * accumulator*reciprocal does not overflow.
         *
         * Proof: as above, we have reciprocal <= 2^48/m. Multiplying
         * by A gives unshifted_quot <= 2^48 * A / m < 2^48 * 2^16 =
         * 2^64. []
         */
        uint64_t unshifted_quot = accumulator * reciprocal;
        uint64_t quot = unshifted_quot >> 48;
        accumulator -= quot * m;
    }

    /*
     * Theorem 1: accumulator < 2m at the end of every iteration of
     * the above loop.
     *
     * Proof: induction on the loop.
     *
     * Base case: at the start of the first loop iteration, the
     * accumulator is 0, which is certainly < 2m.
     *
     * Inductive step: in each loop iteration, we take a value at most
     * 2m-1, multiply it by 2^8, and add another byte less than 2^8 to
     * generate the input value A to the reduction process above. So
     * we have A < 2m * 2^8 - 1. We know m < 2^32 (because it was
     * passed in as a uint32_t), so A < 2^41, which is enough to allow
     * us to apply Lemma 1, showing that the value of 'accumulator' at
     * the end of the loop is still < 2m. []
     *
     * Corollary: we need at most one final subtraction of m to
     * produce the canonical residue of x mod m, i.e. in the range
     * [0,m).
     *
     * Theorem 2: no multiplication in the inner loop overflows.
     *
     * Proof: in Theorem 1 we established A < 2m * 2^8 - 1 in every
     * iteration. That is less than m * 2^16, so Lemma 2 applies.
     *
     * The other multiplication, of quot * m, cannot overflow because
     * quot is at most A/m, so quot*m <= A < 2^64. []
     */

    /* Constant-time conditional subtraction: if 'reduced' underflowed
     * (top bit set), keep 'result'; otherwise take 'reduced'. */
    uint32_t result = accumulator;
    uint32_t reduced = result - m;
    uint32_t select = -(reduced >> 31);
    result = reduced ^ ((result ^ reduced) & select);
    return result;  /* NOTE(review): reconstructed - the return line
                     * was lost in extraction; confirm upstream */
}
/*
 * Compute the rounded-down nth root of y: the maximal x such that
 * x^n <= y. If remainder_out is non-NULL, it receives y - x^n.
 * Runs in time dependent only on the nominal size of y, via
 * constant-time conditional selection of each candidate bit.
 */
mp_int *mp_nthroot(mp_int *y, unsigned n, mp_int *remainder_out)
{
    /*
     * Allocate scratch space: the powers x^0..x^n, their candidate
     * replacements, and one general-purpose temporary.
     */
    mp_int **alloc, **powers, **newpowers, *scratch;
    size_t nalloc = 2*(n+1)+1;
    alloc = snewn(nalloc, mp_int *);
    for (size_t i = 0; i < nalloc; i++)
        alloc[i] = mp_make_sized(y->nw + 1);
    powers = alloc;                 /* NOTE(review): reconstructed line */
    newpowers = alloc + (n+1);
    scratch = alloc[2*n+2];

    /*
     * We're computing the rounded-down nth root of y, i.e. the
     * maximal x such that x^n <= y. We try to add 2^s to it for each
     * possible value of s, starting from the largest one that might
     * fit (i.e. such that 2^{n*s} fits in the size of y) downwards to
     * s = 0.
     *
     * We track all the smaller powers of x in the array 'powers'. In
     * each iteration, if we update x, we update all of those values
     * to match.
     */
    mp_copy_integer_into(powers[0], 1);
    for (size_t s = mp_max_bits(y) / n + 1; s-- > 0 ;) {
        /*
         * Let b = 2^s. We need to compute the powers (x+b)^i for each
         * i, starting from our recorded values of x^i, via the
         * binomial expansion:
         *
         *   (x+b)^i = x^i
         *           + (i choose 1) x^{i-1} b
         *           + (i choose 2) x^{i-2} b^2
         *           + ...
         *           + b^i
         */
        for (size_t i = 0; i < n+1; i++) {
            uint16_t binom = 1;        /* running binomial coefficient */
            mp_copy_into(newpowers[i], powers[i]);
            for (size_t j = 0; j < i; j++) {
                /* newpowers[i] += binom * powers[j] * 2^{(i-j)*s} */
                mp_mul_integer_into(scratch, powers[j], binom);
                mp_lshift_fixed_into(scratch, scratch, (i-j) * s);
                mp_add_into(newpowers[i], newpowers[i], scratch);

                /* Advance binom to the next coefficient, doing the
                 * update in 32 bits so the assert can catch overflow.
                 * NOTE(review): update lines reconstructed - verify
                 * against upstream. */
                uint32_t binom_mul = binom;
                binom_mul *= (i-j);
                binom_mul /= (j+1);
                assert(binom_mul < 0x10000);
                binom = binom_mul;
            }
        }

        /*
         * Now, is the new value of x^n still <= y? If so, update.
         */
        unsigned newbit = mp_cmp_hs(y, newpowers[n]);
        for (size_t i = 0; i < n+1; i++)
            mp_select_into(powers[i], powers[i], newpowers[i], newbit);
    }

    if (remainder_out)
        mp_sub_into(remainder_out, y, powers[n]);

    mp_int *root = mp_new(mp_max_bits(y) / n);
    mp_copy_into(root, powers[1]);

    /* Free all the scratch bignums and the pointer array itself. */
    for (size_t i = 0; i < nalloc; i++)
        mp_free(alloc[i]);
    sfree(alloc);

    return root;
}
2502 mp_int
*mp_modmul(mp_int
*x
, mp_int
*y
, mp_int
*modulus
)
2504 mp_int
*product
= mp_mul(x
, y
);
2505 mp_int
*reduced
= mp_mod(product
, modulus
);
2510 mp_int
*mp_modadd(mp_int
*x
, mp_int
*y
, mp_int
*modulus
)
2512 mp_int
*sum
= mp_add(x
, y
);
2513 mp_int
*reduced
= mp_mod(sum
, modulus
);
2518 mp_int
*mp_modsub(mp_int
*x
, mp_int
*y
, mp_int
*modulus
)
2520 mp_int
*diff
= mp_make_sized(size_t_max(x
->nw
, y
->nw
));
2521 mp_sub_into(diff
, x
, y
);
2522 unsigned negate
= mp_cmp_hs(y
, x
);
2523 mp_cond_negate(diff
, diff
, negate
);
2524 mp_int
*residue
= mp_mod(diff
, modulus
);
2525 mp_cond_negate(residue
, residue
, negate
);
2526 /* If we've just negated the residue, then it will be < 0 and need
2527 * the modulus adding to it to make it positive - *except* if the
2528 * residue was zero when we negated it. */
2529 unsigned make_positive
= negate
& ~mp_eq_integer(residue
, 0);
2530 mp_cond_add_into(residue
, residue
, modulus
, make_positive
);
2535 static mp_int
*mp_modadd_in_range(mp_int
*x
, mp_int
*y
, mp_int
*modulus
)
2537 mp_int
*sum
= mp_make_sized(modulus
->nw
);
2538 unsigned carry
= mp_add_into_internal(sum
, x
, y
);
2539 mp_cond_sub_into(sum
, sum
, modulus
, carry
| mp_cmp_hs(sum
, modulus
));
2543 static mp_int
*mp_modsub_in_range(mp_int
*x
, mp_int
*y
, mp_int
*modulus
)
2545 mp_int
*diff
= mp_make_sized(modulus
->nw
);
2546 mp_sub_into(diff
, x
, y
);
2547 mp_cond_add_into(diff
, diff
, modulus
, 1 ^ mp_cmp_hs(x
, y
));
2551 mp_int
*monty_add(MontyContext
*mc
, mp_int
*x
, mp_int
*y
)
2553 return mp_modadd_in_range(x
, y
, mc
->m
);
2556 mp_int
*monty_sub(MontyContext
*mc
, mp_int
*x
, mp_int
*y
)
2558 return mp_modsub_in_range(x
, y
, mc
->m
);
2561 void mp_min_into(mp_int
*r
, mp_int
*x
, mp_int
*y
)
2563 mp_select_into(r
, x
, y
, mp_cmp_hs(x
, y
));
2566 void mp_max_into(mp_int
*r
, mp_int
*x
, mp_int
*y
)
2568 mp_select_into(r
, y
, x
, mp_cmp_hs(x
, y
));
2571 mp_int
*mp_min(mp_int
*x
, mp_int
*y
)
2573 mp_int
*r
= mp_make_sized(size_t_min(x
->nw
, y
->nw
));
2574 mp_min_into(r
, x
, y
);
2578 mp_int
*mp_max(mp_int
*x
, mp_int
*y
)
2580 mp_int
*r
= mp_make_sized(size_t_max(x
->nw
, y
->nw
));
2581 mp_max_into(r
, x
, y
);
2585 mp_int
*mp_power_2(size_t power
)
2587 mp_int
*x
= mp_new(power
+ 1);
2588 mp_set_bit(x
, power
, 1);
/*
 * Context for taking modular square roots mod a fixed prime p, using
 * a constant-time adaptation of the Tonelli-Shanks algorithm (see
 * monty_modsqrt below).
 *
 * NOTE(review): some field declarations were lost in extraction and
 * have been reconstructed from their uses elsewhere in this file;
 * verify types against the repository.
 */
struct ModsqrtContext {
    mp_int *p;                      /* the prime */
    MontyContext *mc;               /* for doing arithmetic mod p */

    /* Decompose p-1 as 2^e k, for positive integer e and odd k */
    size_t e;
    mp_int *k;
    mp_int *km1o2;                  /* (k-1)/2 */

    /* The user-provided value z which is not a quadratic residue mod
     * p, and its kth power. Both in Montgomery form. */
    mp_int *z, *zk;
};
/*
 * Create a modular-square-root context for the prime p. The caller
 * supplies any_nonsquare_mod_p, a known quadratic non-residue mod p
 * (this routine cannot find one itself in constant time).
 */
ModsqrtContext *modsqrt_new(mp_int *p, mp_int *any_nonsquare_mod_p)
{
    ModsqrtContext *sc = snew(ModsqrtContext);
    memset(sc, 0, sizeof(ModsqrtContext));

    sc->p = mp_copy(p);  /* NOTE(review): reconstructed - sc->p is read
                          * by monty_new below; confirm upstream */
    sc->mc = monty_new(sc->p);
    sc->z = monty_import(sc->mc, any_nonsquare_mod_p);

    /* Find the lowest set bit in p-1. Since this routine expects p to
     * be non-secret (typically a well-known standard elliptic curve
     * parameter), for once we don't need clever bit tricks. */
    for (sc->e = 1; sc->e < BIGNUM_INT_BITS * p->nw; sc->e++)
        if (mp_get_bit(p, sc->e))
            break;

    /* Split p-1 = 2^e k: k is p shifted down past the trailing zero
     * bits of p-1 (valid because bit 0 of p is 1 for odd prime p). */
    sc->k = mp_rshift_fixed(p, sc->e);
    sc->km1o2 = mp_rshift_fixed(sc->k, 1);

    /* Leave zk to be filled in lazily, since it's more expensive to
     * compute. If this context turns out never to be needed, we can
     * save the bulk of the setup time this way. */

    return sc;
}
/*
 * Fill in sc->zk (z^k, in Montgomery form) the first time it's
 * actually needed: the modpow is the expensive part of setup, so it's
 * deferred out of modsqrt_new.
 */
static void modsqrt_lazy_setup(ModsqrtContext *sc)
{
    /* NOTE(review): the once-only guard was lost in extraction and is
     * reconstructed here; confirm against upstream. */
    if (!sc->zk)
        sc->zk = monty_pow(sc->mc, sc->z, sc->k);
}
/*
 * Free a ModsqrtContext and everything it owns.
 *
 * NOTE(review): the entire body of this function was lost in
 * extraction; the frees below are reconstructed from the fields the
 * context owns (see modsqrt_new). Verify against the repository.
 */
void modsqrt_free(ModsqrtContext *sc)
{
    monty_free(sc->mc);
    mp_free(sc->p);
    mp_free(sc->z);
    mp_free(sc->k);
    mp_free(sc->km1o2);
    if (sc->zk)   /* zk is filled in lazily, so may still be NULL */
        mp_free(sc->zk);
    sfree(sc);
}
/*
 * Compute a square root of x mod sc->p. Sets *success to 1 if x has a
 * square root (including x == 0), else 0; in the failure case the
 * returned value is meaningless but the computation still runs to
 * completion for constant-time behaviour. This is a wrapper that
 * converts into and out of Montgomery form around monty_modsqrt.
 */
mp_int *mp_modsqrt(ModsqrtContext *sc, mp_int *x, unsigned *success)
{
    mp_int *mx = monty_import(sc->mc, x);
    mp_int *mroot = monty_modsqrt(sc, mx, success);
    mp_free(mx);
    mp_int *root = monty_export(sc->mc, mroot);
    mp_free(mroot);
    return root;
}
/*
 * Modular square root, using an algorithm more or less similar to
 * Tonelli-Shanks but adapted for constant time.
 *
 * The basic idea is to write p-1 = k 2^e, where k is odd and e > 0.
 * Then the multiplicative group mod p (call it G) has a sequence of
 * e+1 nested subgroups G = G_0 > G_1 > G_2 > ... > G_e, where each
 * G_i is exactly half the size of G_{i-1} and consists of all the
 * squares of elements in G_{i-1}. So the innermost group G_e has
 * order k, which is odd, and hence within that group you can take a
 * square root by raising to the power (k+1)/2.
 *
 * Our strategy is to iterate over these groups one by one and make
 * sure the number x we're trying to take the square root of is inside
 * each one, by adjusting it if it isn't.
 *
 * Suppose g is a primitive root of p, i.e. a generator of G_0. (We
 * don't actually need to know what g _is_; we just imagine it for the
 * sake of understanding.) Then G_i consists of precisely the (2^i)th
 * powers of g, and hence, you can tell if a number is in G_i if
 * raising it to the power k 2^{e-i} gives 1. So the conceptual
 * algorithm goes: for each i, test whether x is in G_i by that
 * method. If it isn't, then the previous iteration ensured it's in
 * G_{i-1}, so it will be an odd power of g^{2^{i-1}}, and hence
 * multiplying by any other odd power of g^{2^{i-1}} will give x' in
 * G_i. And we have one of those, because our non-square z is an odd
 * power of g, so z^{2^{i-1}} is an odd power of g^{2^{i-1}}.
 *
 * (There's a special case in the very first iteration, where we don't
 * have a G_{i-1}. If it turns out that x is not even in G_1, that
 * means it's not a square, so we set *success to 0. We still run the
 * rest of the algorithm anyway, for the sake of constant time, but we
 * don't give a hoot what it returns.)
 *
 * When we get to the end and have x in G_e, then we can take its
 * square root by raising to (k+1)/2. But of course that's not the
 * square root of the original input - it's only the square root of
 * the adjusted version we produced during the algorithm. To get the
 * true output answer we also have to multiply by a power of z,
 * namely, z to the power of _half_ whatever we've been multiplying in
 * as we go along. (The power of z we multiplied in must have been
 * even, because the case in which we would have multiplied in an odd
 * power of z is the i=0 case, in which we instead set the failure
 * flag.)
 *
 * The code below is an optimised version of that basic idea, in which
 * we _start_ by computing x^k so as to be able to test membership in
 * G_i by only a few squarings rather than a full from-scratch modpow
 * every time; we also start by computing our candidate output value
 * x^{(k+1)/2}. So when the above description says 'adjust x by z^i'
 * for some i, we have to adjust our running values of x^k and
 * x^{(k+1)/2} by z^{ik} and z^{ik/2} respectively (the latter is safe
 * because, as above, i is always even). And it turns out that we
 * don't actually have to store the adjusted version of x itself at
 * all - we _only_ keep those two powers of it.
 */
mp_int *monty_modsqrt(ModsqrtContext *sc, mp_int *x, unsigned *success)
{
    modsqrt_lazy_setup(sc);

    mp_int *scratch_to_free = mp_make_sized(3 * sc->mc->rw);
    mp_int scratch = *scratch_to_free;

    /*
     * Compute toret = x^{(k+1)/2}, our starting point for the output
     * square root, and also xk = x^k which we'll use as we go along
     * for knowing when to apply correction factors. We do this by
     * first computing x^{(k-1)/2}, then multiplying it by x, then
     * multiplying the two together.
     */
    mp_int *toret = monty_pow(sc->mc, x, sc->km1o2);
    mp_int xk = mp_alloc_from_scratch(&scratch, sc->mc->rw);
    mp_copy_into(&xk, toret);
    monty_mul_into(sc->mc, toret, toret, x);
    monty_mul_into(sc->mc, &xk, toret, &xk);

    mp_int tmp = mp_alloc_from_scratch(&scratch, sc->mc->rw);

    /* Running correction factor: the relevant power of z^k for the
     * current iteration. */
    mp_int power_of_zk = mp_alloc_from_scratch(&scratch, sc->mc->rw);
    mp_copy_into(&power_of_zk, sc->zk);

    for (size_t i = 0; i < sc->e; i++) {
        /* Square xk (e-i-1) times into tmp, i.e. compute the adjusted
         * x to the power k 2^{e-i-1}, and test whether that is 1. */
        mp_copy_into(&tmp, &xk);
        for (size_t j = i+1; j < sc->e; j++)
            monty_mul_into(sc->mc, &tmp, &tmp, &tmp);
        unsigned eq1 = mp_cmp_eq(&tmp, monty_identity(sc->mc));

        /* NOTE(review): the if/else structure below was lost in
         * extraction and is reconstructed; confirm upstream. */
        if (i == 0) {
            /* One special case: if x=0, then no power of x will ever
             * equal 1, but we should still report success on the
             * grounds that 0 does have a square root mod p. */
            *success = eq1 | mp_eq_integer(x, 0);
        } else {
            /* Adjust the candidate root by the current power of z,
             * keeping it unchanged (constant-time select) if the
             * membership test already passed. */
            monty_mul_into(sc->mc, &tmp, toret, &power_of_zk);
            mp_select_into(toret, &tmp, toret, eq1);

            /* Square the correction factor for the next level. */
            monty_mul_into(sc->mc, &power_of_zk,
                           &power_of_zk, &power_of_zk);

            /* And adjust xk to match the adjusted root. */
            monty_mul_into(sc->mc, &tmp, &xk, &power_of_zk);
            mp_select_into(&xk, &tmp, &xk, eq1);
        }
    }

    mp_free(scratch_to_free);

    return toret;
}
2771 mp_int
*mp_random_bits_fn(size_t bits
, random_read_fn_t random_read
)
2773 size_t bytes
= (bits
+ 7) / 8;
2774 uint8_t *randbuf
= snewn(bytes
, uint8_t);
2775 random_read(randbuf
, bytes
);
2777 randbuf
[0] &= (2 << ((bits
-1) & 7)) - 1;
2778 mp_int
*toret
= mp_from_bytes_be(make_ptrlen(randbuf
, bytes
));
2779 smemclr(randbuf
, bytes
);
2784 mp_int
*mp_random_upto_fn(mp_int
*limit
, random_read_fn_t rf
)
2787 * It would be nice to generate our random numbers in such a way
2788 * as to make every possible outcome literally equiprobable. But
2789 * we can't do that in constant time, so we have to go for a very
2790 * close approximation instead. I'm going to take the view that a
2791 * factor of (1+2^-128) between the probabilities of two outcomes
2792 * is acceptable on the grounds that you'd have to examine so many
2793 * outputs to even detect it.
2795 mp_int
*unreduced
= mp_random_bits_fn(mp_max_bits(limit
) + 128, rf
);
2796 mp_int
*reduced
= mp_mod(unreduced
, limit
);
2801 mp_int
*mp_random_in_range_fn(mp_int
*lo
, mp_int
*hi
, random_read_fn_t rf
)
2803 mp_int
*n_outcomes
= mp_sub(hi
, lo
);
2804 mp_int
*addend
= mp_random_upto_fn(n_outcomes
, rf
);
2805 mp_int
*result
= mp_make_sized(hi
->nw
);
2806 mp_add_into(result
, addend
, lo
);
2808 mp_free(n_outcomes
);