/* Schoenhage's fast multiplication modulo 2^N+1.

   Contributed by Paul Zimmermann.

   THE FUNCTIONS IN THIS FILE ARE INTERNAL WITH MUTABLE INTERFACES.  IT IS ONLY
   SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES.  IN FACT, IT IS ALMOST
   GUARANTEED THAT THEY WILL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.

Copyright 1998-2010, 2012, 2013 Free Software Foundation, Inc.

This file is part of the GNU MP Library.

The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of either:

  * the GNU Lesser General Public License as published by the Free
    Software Foundation; either version 3 of the License, or (at your
    option) any later version.

or

  * the GNU General Public License as published by the Free Software
    Foundation; either version 2 of the License, or (at your option) any
    later version.

or both in parallel, as here.

The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received copies of the GNU General Public License and the
GNU Lesser General Public License along with the GNU MP Library.  If not,
see https://www.gnu.org/licenses/.  */

/* References:

   Schnelle Multiplikation grosser Zahlen, by Arnold Schoenhage and Volker
   Strassen, Computing 7, p. 281-292, 1971.

   Asymptotically fast algorithms for the numerical multiplication and division
   of polynomials with complex coefficients, by Arnold Schoenhage, Computer
   Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982.

   Tapes versus Pointers, a study in implementing fast algorithms, by Arnold
   Schoenhage, Bulletin of the EATCS, 30, p. 23-32, 1986.

   TODO:

   Implement some of the tricks published at ISSAC'2007 by Gaudry, Kruppa, and
   Zimmermann.

   It might be possible to avoid a small number of MPN_COPYs by using a
   rotating temporary or two.

   Cleanup and simplify the code!
*/

#ifdef TRACE
#undef TRACE
#define TRACE(x) x
#include <stdio.h>
#else
#define TRACE(x)
#endif

#include "gmp.h"
#include "gmp-impl.h"

#ifdef WANT_ADDSUB
#include "generic/add_n_sub_n.c"
#define HAVE_NATIVE_mpn_add_n_sub_n 1
#endif

static mp_limb_t mpn_mul_fft_internal (mp_ptr, mp_size_t, int, mp_ptr *,
                                       mp_ptr *, mp_ptr, mp_ptr, mp_size_t,
                                       mp_size_t, mp_size_t, int **, mp_ptr, int);
static void mpn_mul_fft_decompose (mp_ptr, mp_ptr *, mp_size_t, mp_size_t, mp_srcptr,
                                   mp_size_t, mp_size_t, mp_size_t, mp_ptr);


/* Find the best k to use for a mod 2^(m*GMP_NUMB_BITS)+1 FFT for m >= n.
   We have sqr=0 for a multiply, sqr=1 for a square.
   There are three generations of this code; we keep the old ones as long as
   some gmp-mparam.h is not updated.  */

/*****************************************************************************/

#if TUNE_PROGRAM_BUILD || (defined (MUL_FFT_TABLE3) && defined (SQR_FFT_TABLE3))

#ifndef FFT_TABLE3_SIZE /* When tuning this is defined in gmp-impl.h */
#if defined (MUL_FFT_TABLE3_SIZE) && defined (SQR_FFT_TABLE3_SIZE)
#if MUL_FFT_TABLE3_SIZE > SQR_FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE MUL_FFT_TABLE3_SIZE
#else
#define FFT_TABLE3_SIZE SQR_FFT_TABLE3_SIZE
#endif
#endif
#endif

#ifndef FFT_TABLE3_SIZE
#define FFT_TABLE3_SIZE 200
#endif

FFT_TABLE_ATTRS struct fft_table_nk mpn_fft_table3[2][FFT_TABLE3_SIZE] =
{
  MUL_FFT_TABLE3,
  SQR_FFT_TABLE3
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  const struct fft_table_nk *fft_tab, *tab;
  mp_size_t tab_n, thres;
  int last_k;

  fft_tab = mpn_fft_table3[sqr];
  last_k = fft_tab->k;
  for (tab = fft_tab + 1; ; tab++)
    {
      tab_n = tab->n;
      thres = tab_n << last_k;
      if (n <= thres)
        break;
      last_k = tab->k;
    }
  return last_k;
}

#define MPN_FFT_BEST_READY 1
#endif

/*****************************************************************************/

#if ! defined (MPN_FFT_BEST_READY)
FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] =
{
  MUL_FFT_TABLE,
  SQR_FFT_TABLE
};

int
mpn_fft_best_k (mp_size_t n, int sqr)
{
  int i;

  for (i = 0; mpn_fft_table[sqr][i] != 0; i++)
    if (n < mpn_fft_table[sqr][i])
      return i + FFT_FIRST_K;

  /* treat 4*last as one further entry */
  if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1])
    return i + FFT_FIRST_K;
  else
    return i + FFT_FIRST_K + 1;
}
#endif

/*****************************************************************************/

/* Returns smallest possible number of limbs >= pl for an fft of size 2^k,
   i.e. smallest multiple of 2^k >= pl.

   Don't declare static: needed by tuneup.
*/

mp_size_t
mpn_fft_next_size (mp_size_t pl, int k)
{
  pl = 1 + ((pl - 1) >> k); /* ceil (pl/2^k) */
  return pl << k;
}
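
/* As a concrete example of the rounding above: with k = 4 the transform
   length is 2^4 = 16 pieces, so pl = 1000 gives 1 + ((1000 - 1) >> 4) = 63
   pieces, i.e. 63 << 4 = 1008 limbs, the smallest multiple of 16 >= 1000.  */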

/* Initialize l[i][j] with bitrev(j) */
static void
mpn_fft_initl (int **l, int k)
{
  int i, j, K;
  int *li;

  l[0][0] = 0;
  for (i = 1, K = 1; i <= k; i++, K *= 2)
    {
      li = l[i];
      for (j = 0; j < K; j++)
        {
          li[j] = 2 * l[i - 1][j];
          li[K + j] = 1 + li[j];
        }
    }
}
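
/* Concrete values produced by the recurrence above: starting from
   l[0] = {0}, one gets l[1] = {0,1}, l[2] = {0,2,1,3} and
   l[3] = {0,4,2,6,1,5,3,7}, i.e. l[k][j] is j with its k low bits
   reversed, as used to address the decimation-in-time transforms.  */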

/* r <- a*2^d mod 2^(n*GMP_NUMB_BITS)+1 with a = {a, n+1}
   Assumes a is semi-normalized, i.e. a[n] <= 1.
   r and a must have n+1 limbs, and not overlap.
*/
static void
mpn_fft_mul_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t d, mp_size_t n)
{
  unsigned int sh;
  mp_size_t m;
  mp_limb_t cc, rd;

  sh = d % GMP_NUMB_BITS;
  m = d / GMP_NUMB_BITS;

  if (m >= n) /* negate */
    {
      /* r[0..m-1]  <-- lshift(a[n-m]..a[n-1], sh)
         r[m..n-1]  <-- -lshift(a[0]..a[n-m-1], sh) */

      m -= n;
      if (sh != 0)
        {
          /* no out shift below since a[n] <= 1 */
          mpn_lshift (r, a + n - m, m + 1, sh);
          rd = r[m];
          cc = mpn_lshiftc (r + m, a, n - m, sh);
        }
      else
        {
          MPN_COPY (r, a + n - m, m);
          rd = a[n];
          mpn_com (r + m, a, n - m);
          cc = 0;
        }

      /* add cc to r[0], and add rd to r[m] */

      /* now add 1 in r[m], subtract 1 in r[n], i.e. add 1 in r[0] */

      r[n] = 0;
      /* cc < 2^sh <= 2^(GMP_NUMB_BITS-1) thus no overflow here */
      cc++;
      mpn_incr_u (r, cc);

      rd++;
      /* rd might overflow when sh=GMP_NUMB_BITS-1 */
      cc = (rd == 0) ? 1 : rd;
      r = r + m + (rd == 0);
      mpn_incr_u (r, cc);
    }
  else
    {
      /* r[0..m-1]  <-- -lshift(a[n-m]..a[n-1], sh)
         r[m..n-1]  <-- lshift(a[0]..a[n-m-1], sh)  */
      if (sh != 0)
        {
          /* no out bits below since a[n] <= 1 */
          mpn_lshiftc (r, a + n - m, m + 1, sh);
          rd = ~r[m];
          /* {r, m+1} = {a+n-m, m+1} << sh */
          cc = mpn_lshift (r + m, a, n - m, sh); /* {r+m, n-m} = {a, n-m}<<sh */
        }
      else
        {
          /* r[m] is not used below, but we save a test for m=0 */
          mpn_com (r, a + n - m, m + 1);
          rd = a[n];
          MPN_COPY (r + m, a, n - m);
          cc = 0;
        }

      /* now complement {r, m}, subtract cc from r[0], subtract rd from r[m] */

      /* if m=0 we just have r[0]=a[n] << sh */
      if (m != 0)
        {
          /* now add 1 in r[0], subtract 1 in r[m] */
          if (cc-- == 0) /* then add 1 to r[0] */
            cc = mpn_add_1 (r, r, n, CNST_LIMB(1));
          cc = mpn_sub_1 (r, r, m, cc) + 1;
          /* add 1 to cc instead of rd since rd might overflow */
        }

      /* now subtract cc and rd from r[m..n] */

      r[n] = -mpn_sub_1 (r + m, r + m, n - m, cc);
      r[n] -= mpn_sub_1 (r + m, r + m, n - m, rd);
      if (r[n] & GMP_LIMB_HIGHBIT)
        r[n] = mpn_add_1 (r, r, n, CNST_LIMB(1));
    }
}
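
/* A small worked example of the wrap-around (assuming 64-bit limbs, n = 1):
   the modulus is 2^64+1, so 2^64 == -1.  With a = {2^63, 0} and d = 2 we get
   a*2^d = 2^65 == -2 == 2^64 - 1, i.e. r = {2^64-1, 0}.  Whatever part of the
   shifted value crosses the 2^(n*GMP_NUMB_BITS) boundary re-enters at the low
   end with the opposite sign, which is what the two branches above implement
   limb-wise with mpn_lshift/mpn_lshiftc.  */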

/* r <- a+b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_add_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] + b[n] + mpn_add_n (r, a, b, n);
  /* 0 <= c <= 3 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (c - 1) & -(c != 0);
  r[n] = c - x;
  MPN_DECR_U (r, n + 1, x);
#endif
#if 0
  if (c > 1)
    {
      r[n] = 1;                       /* r[n] - c = 1 */
      MPN_DECR_U (r, n + 1, c - 1);
    }
  else
    {
      r[n] = c;
    }
#endif
}
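
/* Case-by-case check of the branchless normalization above (c is the carry
   into limb n, and 2^(n*GMP_NUMB_BITS) == -1 modulo the FFT modulus):
     c = 0:  x = 0, r[n] = 0, nothing subtracted;
     c = 1:  x = 0, r[n] = 1, nothing subtracted;
     c = 2:  x = 1, r[n] = 1, and 1 is subtracted from {r, n+1};
     c = 3:  x = 2, r[n] = 1, and 2 is subtracted from {r, n+1}.
   Trading a carry of x at position n for a subtraction of x changes the value
   by x*(1 + 2^(n*GMP_NUMB_BITS)), a multiple of the modulus, and leaves the
   result semi-normalized (r[n] <= 1).  */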

/* r <- a-b mod 2^(n*GMP_NUMB_BITS)+1.
   Assumes a and b are semi-normalized.
*/
static inline void
mpn_fft_sub_modF (mp_ptr r, mp_srcptr a, mp_srcptr b, mp_size_t n)
{
  mp_limb_t c, x;

  c = a[n] - b[n] - mpn_sub_n (r, a, b, n);
  /* -2 <= c <= 1 */

#if 1
  /* GCC 4.1 outsmarts most expressions here, and generates a 50% branch.  The
     result is slower code, of course.  But the following outsmarts GCC.  */
  x = (-c) & -((c & GMP_LIMB_HIGHBIT) != 0);
  r[n] = x + c;
  MPN_INCR_U (r, n + 1, x);
#endif
#if 0
  if ((c & GMP_LIMB_HIGHBIT) != 0)
    {
      r[n] = 0;
      MPN_INCR_U (r, n + 1, -c);
    }
  else
    {
      r[n] = c;
    }
#endif
}
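
/* Case-by-case check of the borrow handling above (c is the signed borrow
   out of limb n, stored in two's complement):
     c =  1:  x = 0, r[n] = 1, nothing added;
     c =  0:  x = 0, r[n] = 0, nothing added;
     c = -1:  x = 1, r[n] = 0, and 1 is added to {r, n+1};
     c = -2:  x = 2, r[n] = 0, and 2 is added to {r, n+1}.
   Adding x to the low part while removing a borrow of x at position n changes
   the value by x*(1 + 2^(n*GMP_NUMB_BITS)), again a multiple of the modulus.  */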

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
          N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */

static void
mpn_fft_fft (mp_ptr *Ap, mp_size_t K, int **ll,
             mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1);
      cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[inc][n] can be -1 or -2 */
        Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;
      int *lk = *ll;

      mpn_fft_fft (Ap, K2, ll-1, 2 * omega, n, inc * 2, tp);
      mpn_fft_fft (Ap+inc, K2, ll-1, 2 * omega, n, inc * 2, tp);
      /* A[2*j*inc]   <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc]
         A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */
      for (j = 0; j < K2; j++, lk += 2, Ap += 2 * inc)
        {
          /* Ap[inc] <- Ap[0] + Ap[inc] * 2^(lk[1] * omega)
             Ap[0]   <- Ap[0] + Ap[inc] * 2^(lk[0] * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[inc], lk[0] * omega, n);
          mpn_fft_sub_modF (Ap[inc], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
        }
    }
}

/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where
          N=n*GMP_NUMB_BITS, and 2^omega is a primitive root mod 2^N+1
   output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1
   tp must have space for 2*(n+1) limbs.
*/

/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*GMP_NUMB_BITS)+1,
   by subtracting that modulus if necessary.

   If ap[0..n] is exactly 2^(n*GMP_NUMB_BITS) then mpn_sub_1 produces a
   borrow and the limbs must be zeroed out again.  This will occur very
   infrequently.  */

static inline void
mpn_fft_normalize (mp_ptr ap, mp_size_t n)
{
  if (ap[n] != 0)
    {
      MPN_DECR_U (ap, n + 1, CNST_LIMB(1));
      if (ap[n] == 0)
        {
          /* This happens with very low probability; we have yet to trigger it,
             and thereby make sure this code is correct.  */
          MPN_ZERO (ap, n);
          ap[n] = 1;
        }
      else
        ap[n] = 0;
    }
}

/* a[i] <- a[i]*b[i] mod 2^(n*GMP_NUMB_BITS)+1 for 0 <= i < K */
static void
mpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, mp_size_t K)
{
  int i;
  int sqr = (ap == bp);
  TMP_DECL;

  TMP_MARK;

  if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2, nprime2, Nprime2, M2, maxLK, l, Mp2;
      int k;
      int **fft_l, *tmp;
      mp_ptr *Ap, *Bp, A, B, T;

      k = mpn_fft_best_k (n, sqr);
      K2 = (mp_size_t) 1 << k;
      ASSERT_ALWAYS((n & (K2 - 1)) == 0);
      maxLK = (K2 > GMP_NUMB_BITS) ? K2 : GMP_NUMB_BITS;
      M2 = n * GMP_NUMB_BITS >> k;
      l = n >> k;
      Nprime2 = ((2 * M2 + k + 2 + maxLK) / maxLK) * maxLK;
      /* Nprime2 = ceil((2*M2+k+3)/maxLK)*maxLK*/
      nprime2 = Nprime2 / GMP_NUMB_BITS;

      /* we should ensure that nprime2 is a multiple of the next K */
      if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
        {
          mp_size_t K3;
          for (;;)
            {
              K3 = (mp_size_t) 1 << mpn_fft_best_k (nprime2, sqr);
              if ((nprime2 & (K3 - 1)) == 0)
                break;
              nprime2 = (nprime2 + K3 - 1) & -K3;
              Nprime2 = nprime2 * GMP_LIMB_BITS;
              /* warning: since nprime2 changed, K3 may change too! */
            }
        }
      ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */

      Mp2 = Nprime2 >> k;

      Ap = TMP_BALLOC_MP_PTRS (K2);
      Bp = TMP_BALLOC_MP_PTRS (K2);
      A = TMP_BALLOC_LIMBS (2 * (nprime2 + 1) << k);
      T = TMP_BALLOC_LIMBS (2 * (nprime2 + 1));
      B = A + ((nprime2 + 1) << k);
      fft_l = TMP_BALLOC_TYPE (k + 1, int *);
      tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
      for (i = 0; i <= k; i++)
        {
          fft_l[i] = tmp;
          tmp += (mp_size_t) 1 << i;
        }

      mpn_fft_initl (fft_l, k);

      TRACE (printf ("recurse: %ldx%ld limbs -> %ld times %ldx%ld (%1.2f)\n", n,
                     n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2));
      for (i = 0; i < K; i++, ap++, bp++)
        {
          mp_limb_t cy;
          mpn_fft_normalize (*ap, n);
          if (!sqr)
            mpn_fft_normalize (*bp, n);

          mpn_mul_fft_decompose (A, Ap, K2, nprime2, *ap, (l << k) + 1, l, Mp2, T);
          if (!sqr)
            mpn_mul_fft_decompose (B, Bp, K2, nprime2, *bp, (l << k) + 1, l, Mp2, T);

          cy = mpn_mul_fft_internal (*ap, n, k, Ap, Bp, A, B, nprime2,
                                     l, Mp2, fft_l, T, sqr);
          (*ap)[n] = cy;
        }
    }
  else
    {
      mp_ptr a, b, tp, tpn;
      mp_limb_t cc;
      mp_size_t n2 = 2 * n;
      tp = TMP_BALLOC_LIMBS (n2);
      tpn = tp + n;
      TRACE (printf ("  mpn_mul_n %ld of %ld limbs\n", K, n));
      for (i = 0; i < K; i++)
        {
          a = *ap++;
          b = *bp++;
          if (sqr)
            mpn_sqr (tp, a, n);
          else
            mpn_mul_n (tp, b, a, n);
          if (a[n] != 0)
            cc = mpn_add_n (tpn, tpn, b, n);
          else
            cc = 0;
          if (b[n] != 0)
            cc += mpn_add_n (tpn, tpn, a, n) + a[n];
          if (cc != 0)
            {
              /* FIXME: use MPN_INCR_U here, since carry is not expected.  */
              cc = mpn_add_1 (tp, tp, n2, cc);
              ASSERT (cc == 0);
            }
          a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1));
        }
    }
  TMP_FREE;
}

/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]]
   output: K*A[0] K*A[K-1] ... K*A[1].
   Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1.
   This condition is also fulfilled at exit.
*/
static void
mpn_fft_fftinv (mp_ptr *Ap, mp_size_t K, mp_size_t omega, mp_size_t n, mp_ptr tp)
{
  if (K == 2)
    {
      mp_limb_t cy;
#if HAVE_NATIVE_mpn_add_n_sub_n
      cy = mpn_add_n_sub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;
#else
      MPN_COPY (tp, Ap[0], n + 1);
      mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1);
      cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);
#endif
      if (Ap[0][n] > 1) /* can be 2 or 3 */
        Ap[0][n] = 1 - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - 1);
      if (cy) /* Ap[1][n] can be -1 or -2 */
        Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + 1);
    }
  else
    {
      mp_size_t j, K2 = K >> 1;

      mpn_fft_fftinv (Ap, K2, 2 * omega, n, tp);
      mpn_fft_fftinv (Ap + K2, K2, 2 * omega, n, tp);
      /* A[j]     <- A[j] + omega^j A[j+K/2]
         A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */
      for (j = 0; j < K2; j++, Ap++)
        {
          /* Ap[K2] <- Ap[0] + Ap[K2] * 2^((j + K2) * omega)
             Ap[0]  <- Ap[0] + Ap[K2] * 2^(j * omega) */
          mpn_fft_mul_2exp_modF (tp, Ap[K2], j * omega, n);
          mpn_fft_sub_modF (Ap[K2], Ap[0], tp, n);
          mpn_fft_add_modF (Ap[0], Ap[0], tp, n);
        }
    }
}

/* R <- A/2^k mod 2^(n*GMP_NUMB_BITS)+1 */
static void
mpn_fft_div_2exp_modF (mp_ptr r, mp_srcptr a, mp_bitcnt_t k, mp_size_t n)
{
  mp_bitcnt_t i;

  ASSERT (r != a);
  i = (mp_bitcnt_t) 2 * n * GMP_NUMB_BITS - k;
  mpn_fft_mul_2exp_modF (r, a, i, n);
  /* 1/2^k = 2^(2nL-k) mod 2^(n*GMP_NUMB_BITS)+1 */
  /* normalize so that R < 2^(n*GMP_NUMB_BITS)+1 */
  mpn_fft_normalize (r, n);
}

/* {rp,n} <- {ap,an} mod 2^(n*GMP_NUMB_BITS)+1, n <= an <= 3*n.
   Returns carry out, i.e. 1 iff {ap,an} = -1 mod 2^(n*GMP_NUMB_BITS)+1,
   then {rp,n}=0.
*/
static mp_size_t
mpn_fft_norm_modF (mp_ptr rp, mp_size_t n, mp_ptr ap, mp_size_t an)
{
  mp_size_t l, m, rpn;
  mp_limb_t cc;

  ASSERT ((n <= an) && (an <= 3 * n));
  m = an - 2 * n;
  if (m > 0)
    {
      l = n;
      /* add {ap, m} and {ap+2n, m} in {rp, m} */
      cc = mpn_add_n (rp, ap, ap + 2 * n, m);
      /* copy {ap+m, n-m} to {rp+m, n-m} */
      rpn = mpn_add_1 (rp + m, ap + m, n - m, cc);
    }
  else
    {
      l = an - n; /* l <= n */
      MPN_COPY (rp, ap, n);
      rpn = 0;
    }

  /* remains to subtract {ap+n, l} from {rp, n+1} */
  cc = mpn_sub_n (rp, rp, ap + n, l);
  rpn -= mpn_sub_1 (rp + l, rp + l, n - l, cc);
  if (rpn < 0) /* necessarily rpn = -1 */
    rpn = mpn_add_1 (rp, rp, n, CNST_LIMB(1));
  return rpn;
}
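
/* The reduction above relies on 2^(n*GMP_NUMB_BITS) == -1: writing
   {ap,an} = lo + mid*2^(n*GMP_NUMB_BITS) + hi*2^(2*n*GMP_NUMB_BITS)
   with lo and mid of n limbs and hi of an-2n limbs, the residue is
   lo - mid + hi mod 2^(n*GMP_NUMB_BITS)+1, which is exactly the
   add-then-subtract sequence performed on {rp, n+1}.  */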

/* store in A[0..nprime] the first M bits from {n, nl},
   in A[nprime+1..] the following M bits, ...
   Assumes M is a multiple of GMP_NUMB_BITS (M = l * GMP_NUMB_BITS).
   T must have space for at least (nprime + 1) limbs.
   We must have nl <= 2*K*l.
*/
static void
mpn_mul_fft_decompose (mp_ptr A, mp_ptr *Ap, mp_size_t K, mp_size_t nprime,
                       mp_srcptr n, mp_size_t nl, mp_size_t l, mp_size_t Mp,
                       mp_ptr T)
{
  mp_size_t i, j;
  mp_ptr tmp;
  mp_size_t Kl = K * l;
  TMP_DECL;
  TMP_MARK;

  if (nl > Kl) /* normalize {n, nl} mod 2^(Kl*GMP_NUMB_BITS)+1 */
    {
      mp_size_t dif = nl - Kl;
      mp_limb_signed_t cy;

      tmp = TMP_BALLOC_LIMBS(Kl + 1);

      if (dif > Kl)
        {
          int subp = 0;

          cy = mpn_sub_n (tmp, n, n + Kl, Kl);
          n += 2 * Kl;
          dif -= Kl;

          /* now dif > 0 */
          while (dif > Kl)
            {
              if (subp)
                cy += mpn_sub_n (tmp, tmp, n, Kl);
              else
                cy -= mpn_add_n (tmp, tmp, n, Kl);
              subp ^= 1;
              n += Kl;
              dif -= Kl;
            }
          /* now dif <= Kl */
          if (subp)
            cy += mpn_sub (tmp, tmp, Kl, n, dif);
          else
            cy -= mpn_add (tmp, tmp, Kl, n, dif);
          if (cy >= 0)
            cy = mpn_add_1 (tmp, tmp, Kl, cy);
          else
            cy = mpn_sub_1 (tmp, tmp, Kl, -cy);
        }
      else /* dif <= Kl, i.e. nl <= 2 * Kl */
        {
          cy = mpn_sub (tmp, n, Kl, n + Kl, dif);
          cy = mpn_add_1 (tmp, tmp, Kl, cy);
        }
      tmp[Kl] = cy;
      nl = Kl + 1;
      n = tmp;
    }
  for (i = 0; i < K; i++)
    {
      Ap[i] = A;
      /* store the next M bits of n into A[0..nprime] */
      if (nl > 0) /* nl is the number of remaining limbs */
        {
          j = (l <= nl && i < K - 1) ? l : nl; /* store j next limbs */
          nl -= j;
          MPN_COPY (T, n, j);
          MPN_ZERO (T + j, nprime + 1 - j);
          n += l;
          mpn_fft_mul_2exp_modF (A, T, i * Mp, nprime);
        }
      else
        MPN_ZERO (A, nprime + 1);
      A += nprime + 1;
    }
  ASSERT_ALWAYS (nl == 0);
  TMP_FREE;
}
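
/* In other words: piece i receives limbs [i*l, (i+1)*l) of the (reduced)
   input and is pre-multiplied by the weight 2^(i*Mp) mod
   2^(nprime*GMP_NUMB_BITS)+1.  Since 2^(K*Mp) == -1 modulo that modulus,
   these weights combined with the length-K transforms below yield the
   negacyclic (mod 2^N+1) convolution rather than a plain cyclic one.  */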

/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*GMP_NUMB_BITS
   op is pl limbs, its high bit is returned.
   One must have pl = mpn_fft_next_size (pl, k).
   T must have space for 2 * (nprime + 1) limbs.
*/

static mp_limb_t
mpn_mul_fft_internal (mp_ptr op, mp_size_t pl, int k,
                      mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B,
                      mp_size_t nprime, mp_size_t l, mp_size_t Mp,
                      int **fft_l, mp_ptr T, int sqr)
{
  mp_size_t K, i, pla, lo, sh, j;
  mp_ptr p;
  mp_limb_t cc;

  K = (mp_size_t) 1 << k;

  /* direct fft's */
  mpn_fft_fft (Ap, K, fft_l + k, 2 * Mp, nprime, 1, T);
  if (!sqr)
    mpn_fft_fft (Bp, K, fft_l + k, 2 * Mp, nprime, 1, T);

  /* term to term multiplications */
  mpn_fft_mul_modF_K (Ap, sqr ? Ap : Bp, nprime, K);

  /* inverse fft's */
  mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T);

  /* division of terms after inverse fft */
  Bp[0] = T + nprime + 1;
  mpn_fft_div_2exp_modF (Bp[0], Ap[0], k, nprime);
  for (i = 1; i < K; i++)
    {
      Bp[i] = Ap[i - 1];
      mpn_fft_div_2exp_modF (Bp[i], Ap[i], k + (K - i) * Mp, nprime);
    }

  /* addition of terms in result p */
  MPN_ZERO (T, nprime + 1);
  pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
  p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. enough */
  MPN_ZERO (p, pla);
  cc = 0; /* will accumulate the (signed) carry at p[pla] */
  for (i = K - 1, lo = l * i + nprime, sh = l * i; i >= 0; i--, lo -= l, sh -= l)
    {
      mp_ptr n = p + sh;

      j = (K - i) & (K - 1);

      if (mpn_add_n (n, n, Bp[j], nprime + 1))
        cc += mpn_add_1 (n + nprime + 1, n + nprime + 1,
                         pla - sh - nprime - 1, CNST_LIMB(1));
      T[2 * l] = i + 1; /* T = (i + 1)*2^(2*M) */
      if (mpn_cmp (Bp[j], T, nprime + 1) > 0)
        { /* subtract 2^N'+1 */
          cc -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1));
          cc -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1));
        }
    }
  if (cc == -CNST_LIMB(1))
    {
      if ((cc = mpn_add_1 (p + pla - pl, p + pla - pl, pl, CNST_LIMB(1))))
        {
          /* p[pla-pl]...p[pla-1] are all zero */
          mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1));
          mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1));
        }
    }
  else if (cc == 1)
    {
      if (pla >= 2 * pl)
        {
          while ((cc = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, cc)))
            ;
        }
      else
        {
          cc = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, cc);
          ASSERT (cc == 0);
        }
    }
  else
    ASSERT (cc == 0);

  /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ]
     < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... ]
     < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */
  return mpn_fft_norm_modF (op, pl, p, pla);
}

/* return the lcm of a and 2^k */
static mp_bitcnt_t
mpn_mul_fft_lcm (mp_bitcnt_t a, int k)
{
  mp_bitcnt_t l = k;

  while (a % 2 == 0 && k > 0)
    {
      a >>= 1;
      k--;
    }
  return a << l;
}
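
/* Example run (assuming GMP_NUMB_BITS = 64):  with k = 3, the loop strips
   the three factors of 2 (a becomes 8) and the result is 8 << 3 = 64 =
   lcm(64, 8).  With k = 7 the loop stops once a reaches 1, after six
   halvings, and the result is 1 << 7 = 128 = lcm(64, 128).  */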

mp_limb_t
mpn_mul_fft (mp_ptr op, mp_size_t pl,
             mp_srcptr n, mp_size_t nl,
             mp_srcptr m, mp_size_t ml,
             int k)
{
  int i;
  mp_size_t K, maxLK;
  mp_size_t N, Nprime, nprime, M, Mp, l;
  mp_ptr *Ap, *Bp, A, T, B;
  int **fft_l, *tmp;
  int sqr = (n == m && nl == ml);
  mp_limb_t h;
  TMP_DECL;

  TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k));
  ASSERT_ALWAYS (mpn_fft_next_size (pl, k) == pl);

  TMP_MARK;
  N = pl * GMP_NUMB_BITS;
  fft_l = TMP_BALLOC_TYPE (k + 1, int *);
  tmp = TMP_BALLOC_TYPE ((size_t) 2 << k, int);
  for (i = 0; i <= k; i++)
    {
      fft_l[i] = tmp;
      tmp += (mp_size_t) 1 << i;
    }

  mpn_fft_initl (fft_l, k);
  K = (mp_size_t) 1 << k;
  M = N >> k; /* N = 2^k M */
  l = 1 + (M - 1) / GMP_NUMB_BITS;
  maxLK = mpn_mul_fft_lcm (GMP_NUMB_BITS, k); /* lcm (GMP_NUMB_BITS, 2^k) */

  Nprime = (1 + (2 * M + k + 2) / maxLK) * maxLK;
  /* Nprime = ceil((2*M+k+3)/maxLK)*maxLK; */
  nprime = Nprime / GMP_NUMB_BITS;
  TRACE (printf ("N=%ld K=%ld, M=%ld, l=%ld, maxLK=%ld, Np=%ld, np=%ld\n",
                 N, K, M, l, maxLK, Nprime, nprime));
  /* we should ensure that recursively, nprime is a multiple of the next K */
  if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD))
    {
      mp_size_t K2;
      for (;;)
        {
          K2 = (mp_size_t) 1 << mpn_fft_best_k (nprime, sqr);
          if ((nprime & (K2 - 1)) == 0)
            break;
          nprime = (nprime + K2 - 1) & -K2;
          Nprime = nprime * GMP_LIMB_BITS;
          /* warning: since nprime changed, K2 may change too! */
        }
      TRACE (printf ("new maxLK=%ld, Np=%ld, np=%ld\n", maxLK, Nprime, nprime));
    }
  ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */

  T = TMP_BALLOC_LIMBS (2 * (nprime + 1));
  Mp = Nprime >> k;

  TRACE (printf ("%ldx%ld limbs -> %ld times %ldx%ld limbs (%1.2f)\n",
                 pl, pl, K, nprime, nprime, 2.0 * (double) N / Nprime / K);
         printf (" temp space %ld\n", 2 * K * (nprime + 1)));

  A = TMP_BALLOC_LIMBS (K * (nprime + 1));
  Ap = TMP_BALLOC_MP_PTRS (K);
  mpn_mul_fft_decompose (A, Ap, K, nprime, n, nl, l, Mp, T);
  if (sqr)
    {
      mp_size_t pla;
      pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */
      B = TMP_BALLOC_LIMBS (pla);
      Bp = TMP_BALLOC_MP_PTRS (K);
    }
  else
    {
      B = TMP_BALLOC_LIMBS (K * (nprime + 1));
      Bp = TMP_BALLOC_MP_PTRS (K);
      mpn_mul_fft_decompose (B, Bp, K, nprime, m, ml, l, Mp, T);
    }
  h = mpn_mul_fft_internal (op, pl, k, Ap, Bp, A, B, nprime, l, Mp, fft_l, T, sqr);

  TMP_FREE;
  return h;
}
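
/* Usage sketch (illustrative only, not part of the library): a caller that
   wants u*v mod 2^(rn*GMP_NUMB_BITS)+1 picks k and rounds rn up the same way
   mpn_mul_fft_full does below.  The wrapper and its variable names are
   invented for this example.  */
#if 0
static mp_limb_t
example_mulmod_2expp1 (mp_ptr rp, mp_size_t rn_min,
                       mp_srcptr up, mp_size_t un,
                       mp_srcptr vp, mp_size_t vn)
{
  int k = mpn_fft_best_k (rn_min, up == vp && un == vn);
  mp_size_t rn = mpn_fft_next_size (rn_min, k); /* multiple of 2^k, >= rn_min */
  /* {rp, rn} gets the product mod 2^(rn*GMP_NUMB_BITS)+1; the return value
     is the limb of weight 2^(rn*GMP_NUMB_BITS), i.e. 0 or 1.  */
  return mpn_mul_fft (rp, rn, up, un, vp, vn, k);
}
#endif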

#if WANT_OLD_FFT_FULL
/* multiply {n, nl} by {m, ml}, and put the result in {op, nl+ml} */
void
mpn_mul_fft_full (mp_ptr op,
                  mp_srcptr n, mp_size_t nl,
                  mp_srcptr m, mp_size_t ml)
{
  mp_ptr pad_op;
  mp_size_t pl, pl2, pl3, l;
  mp_size_t cc, c2, oldcc;
  int k2, k3;
  int sqr = (n == m && nl == ml);

  pl = nl + ml; /* total number of limbs of the result */

  /* perform an fft mod 2^(2N)+1 and one mod 2^(3N)+1.
     We must have pl3 = 3/2 * pl2, with pl2 a multiple of 2^k2, and
     pl3 a multiple of 2^k3.  Since k3 >= k2, both are multiples of 2^k2,
     and pl2 must be an even multiple of 2^k2.  Thus (pl2,pl3) =
     (2*j*2^k2,3*j*2^k2), which works for 3*j <= pl/2^k2 <= 5*j.
     We need that consecutive intervals overlap, i.e. 5*j >= 3*(j+1),
     which requires j >= 2.  Thus this scheme requires pl >= 6 * 2^FFT_FIRST_K. */

  /*  ASSERT_ALWAYS(pl >= 6 * (1 << FFT_FIRST_K)); */
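
  /* Informal sketch of the reconstruction performed below: let P be the
     exact product (pl limbs), L = GMP_NUMB_BITS, mu = P mod 2^(pl3*L)+1 and
     lambda = P mod 2^(pl2*L)+1.  Writing P = mu + q*(2^(pl3*L)+1) with
     0 <= q < 2^(pl2*L), and using 2^(pl3*L) == -2^(l*L) mod 2^(pl2*L)+1
     (where l = pl3 - pl2 = pl2/2), we get lambda - mu == q*(1 - 2^(l*L)),
     hence q == (lambda - mu)/(1 - 2^(l*L)) mod 2^(pl2*L)+1.  The code below
     computes q in pad_op and finally forms P = mu + q + q*2^(pl3*L).  */
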
  pl2 = (2 * pl - 1) / 5; /* ceil (2pl/5) - 1 */
  do
    {
      pl2++;
      k2 = mpn_fft_best_k (pl2, sqr); /* best fft size for pl2 limbs */
      pl2 = mpn_fft_next_size (pl2, k2);
      pl3 = 3 * pl2 / 2; /* since k>=FFT_FIRST_K=4, pl2 is a multiple of 2^4,
                            thus pl2 / 2 is exact */
      k3 = mpn_fft_best_k (pl3, sqr);
    }
  while (mpn_fft_next_size (pl3, k3) != pl3);

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl2=%ld pl3=%ld k=%d\n",
                 nl, ml, pl2, pl3, k2));

  ASSERT_ALWAYS(pl3 <= pl);
  cc = mpn_mul_fft (op, pl3, n, nl, m, ml, k3);     /* mu */
  ASSERT(cc == 0);
  pad_op = __GMP_ALLOCATE_FUNC_LIMBS (pl2);
  cc = mpn_mul_fft (pad_op, pl2, n, nl, m, ml, k2); /* lambda */
  cc = -cc + mpn_sub_n (pad_op, pad_op, op, pl2);   /* lambda - low(mu) */
  /* 0 <= cc <= 1 */
  ASSERT(0 <= cc && cc <= 1);
  l = pl3 - pl2; /* l = pl2 / 2 since pl3 = 3/2 * pl2 */
  c2 = mpn_add_n (pad_op, pad_op, op + pl2, l);
  cc = mpn_add_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2) - cc;
  ASSERT(-1 <= cc && cc <= 1);
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  ASSERT(0 <= cc && cc <= 1);
  /* now lambda-mu = {pad_op, pl2} - cc mod 2^(pl2*GMP_NUMB_BITS)+1 */
  oldcc = cc;
#if HAVE_NATIVE_mpn_add_n_sub_n
  c2 = mpn_add_n_sub_n (pad_op + l, pad_op, pad_op, pad_op + l, l);
  /* c2 & 1 is the borrow, c2 & 2 is the carry */
  cc += c2 >> 1; /* carry out from high <- low + high */
  c2 = c2 & 1; /* borrow out from low <- low - high */
#else
  {
    mp_ptr tmp;
    TMP_DECL;

    TMP_MARK;
    tmp = TMP_BALLOC_LIMBS (l);
    MPN_COPY (tmp, pad_op, l);
    c2 = mpn_sub_n (pad_op, pad_op, pad_op + l, l);
    cc += mpn_add_n (pad_op + l, tmp, pad_op + l, l);
    TMP_FREE;
  }
#endif
  c2 += oldcc;
  /* first normalize {pad_op, pl2} before dividing by 2: c2 is the borrow
     at pad_op + l, cc is the carry at pad_op + pl2 */
  /* 0 <= cc <= 2 */
  cc -= mpn_sub_1 (pad_op + l, pad_op + l, l, (mp_limb_t) c2);
  /* -1 <= cc <= 2 */
  if (cc > 0)
    cc = -mpn_sub_1 (pad_op, pad_op, pl2, (mp_limb_t) cc);
  /* now -1 <= cc <= 0 */
  if (cc < 0)
    cc = mpn_add_1 (pad_op, pad_op, pl2, (mp_limb_t) -cc);
  /* now {pad_op, pl2} is normalized, with 0 <= cc <= 1 */
  if (pad_op[0] & 1) /* if odd, add 2^(pl2*GMP_NUMB_BITS)+1 */
    cc += 1 + mpn_add_1 (pad_op, pad_op, pl2, CNST_LIMB(1));
  /* now 0 <= cc <= 2, but cc=2 cannot occur since it would give a carry
     out below */
  mpn_rshift (pad_op, pad_op, pl2, 1); /* divide by two */
  if (cc) /* then cc=1 */
    pad_op[pl2 - 1] |= (mp_limb_t) 1 << (GMP_NUMB_BITS - 1);
  /* now {pad_op,pl2}-cc = (lambda-mu)/(1-2^(l*GMP_NUMB_BITS))
     mod 2^(pl2*GMP_NUMB_BITS) + 1 */
  c2 = mpn_add_n (op, op, pad_op, pl2); /* no need to add cc (is 0) */
  /* since pl2+pl3 >= pl, necessarily the extra limbs (including cc) are zero */
  MPN_COPY (op + pl3, pad_op, pl - pl3);
  ASSERT_MPN_ZERO_P (pad_op + pl - pl3, pl2 + pl3 - pl);
  __GMP_FREE_FUNC_LIMBS (pad_op, pl2);
  /* since the final result has at most pl limbs, no carry out below */
  mpn_add_1 (op + pl2, op + pl2, pl - pl2, (mp_limb_t) c2);
}
#endif