Initial SymPy benchmark suite
[sympy.git] / sympy / mpmath / specfun.py
blob1d27aba8caf75893b4e4897d1cd343e4d0e02164
1 """
2 Miscellaneous special functions
3 """
5 from lib import *
6 from libmpc import *
7 from mptypes import *
8 from mptypes import constant
# Docstrings in this module are written as plain text (no markup).
__docformat__ = "plaintext"
12 #---------------------------------------------------------------------------#
13 # #
14 # First some mathematical constants #
15 # #
16 #---------------------------------------------------------------------------#
19 # The golden ratio is given by phi = (1 + sqrt(5))/2
@constant_memo
def phi_fixed(prec):
    """Golden ratio (1 + sqrt(5))/2 as a fixed-point integer with prec
    fractional bits, computed with 10 guard bits."""
    guard = 10
    wp = prec + guard
    # sqrt_fixed is used below 20000 bits, sqrt_fixed2 from there on
    if wp < 20000:
        root = sqrt_fixed
    else:
        root = sqrt_fixed2
    total = root(MP_FIVE << wp, wp) + (MP_ONE << wp)
    # One extra bit of shift performs the division by 2
    return total >> (guard + 1)
28 # Catalan's constant is computed using Lupas's rapidly convergent series
29 # (listed on http://mathworld.wolfram.com/CatalansConstant.html)
# K = \frac{1}{64} \sum_{n=1}^{\infty} \frac{(-1)^{n-1} 2^{8n} (40 n^2 - 24 n + 3) [(2n)!]^3 (n!)^2}{n^3 (2n-1) [(4n)!]^2}
@constant_memo
def catalan_fixed(prec):
    """Catalan's constant as a fixed-point integer with prec fractional
    bits, summed with Lupas's series using 20 guard bits."""
    guard = 20
    wp = prec + guard
    # ratio approximates 2**(8n) [(2n)!]**3 (n!)**2 / [(4n)!]**2 in fixed
    # point; it is updated by the term ratio of consecutive n
    ratio = MP_ONE << wp
    total = 0
    term = 1
    n = 1
    while term:
        ratio = ratio * (32 * n**3 * (2*n-1))
        ratio = ratio // (3 - 16*n + 16*n**2)**2
        term = ratio * (-1)**(n-1) * (40*n**2 - 24*n + 3) // (n**3 * (2*n-1))
        total += term
        n += 1
    # Divide by 64 and drop the guard bits
    return total >> (guard + 6)
51 # Euler's constant (gamma) is computed using the Brent-McMillan formula,
52 # gamma ~= A(n)/B(n) - log(n), where
54 # A(n) = sum_{k=0,1,2,...} (n**k / k!)**2 * H(k)
55 # B(n) = sum_{k=0,1,2,...} (n**k / k!)**2
56 # H(k) = 1 + 1/2 + 1/3 + ... + 1/k
58 # The error is bounded by O(exp(-4n)). Choosing n to be a power
59 # of two, 2**p, the logarithm becomes particularly easy to calculate.
61 # Reference:
62 # Xavier Gourdon & Pascal Sebah, The Euler constant: gamma
63 # http://numbers.computation.free.fr/Constants/Gamma/gamma.pdf
@constant_memo
def euler_fixed(prec):
    """Euler's constant gamma as a fixed-point integer with prec
    fractional bits, via the Brent-McMillan formula
    gamma ~= A(n)/B(n) - log(n) with n = 2**p (30 guard bits)."""
    prec += 30
    # Choose p such that exp(-4*(2**p)) < 2**-prec, i.e. the truncation
    # error O(exp(-4n)) of the formula is below the working precision.
    p = int(math.log((prec/4) * math.log(2), 2)) + 1
    n = MP_ONE << p
    r = one = MP_ONE << prec
    H, A, B, k = MP_ZERO, MP_ZERO, MP_ZERO, 1
    while r:
        # r = (n**k / k!)**2 and H = H(k), both in fixed point
        A += (r * H) >> prec
        B += r
        r = r * (n*n) // (k*k)
        H += one // k
        k += 1
    # log(n) = p*log(2) because n is a power of two
    S = ((A << prec) // B) - p*log2_fixed(prec)
    return S >> 30
82 # Khinchin's constant is relatively difficult to compute. Here
83 # we use the rational zeta series
# \log(K) \log(2) = \sum_{n=1}^{\infty} \frac{\zeta(2n) - 1}{n} \sum_{k=1}^{2n-1} \frac{(-1)^{k+1}}{k}
92 # which adds half a digit per term. The essential trick for achieving
93 # reasonable efficiency is to recycle both the values of the zeta
94 # function (essentially Bernoulli numbers) and the partial terms of
95 # the inner sum.
# An alternative might be to use K = 2*exp[1/log(2) X] where
#
# X = \int_0^1 \frac{1}{x(1+x)} \log\left[ \frac{\pi x (1-x^2)}{\sin(\pi x)} \right] dx
103 # and integrate numerically. In practice, this seems to be slightly
104 # slower than the zeta series at high precision.
@constant_memo
def khinchin_fixed(prec):
    """Khinchin's constant K as a fixed-point integer with prec fractional
    bits.

    Sums log(K)*log(2) = sum_{n>=1} (zeta(2n)-1)/n * sum_{k=1}^{2n-1}
    (-1)**(k+1)/k, taking zeta(2n) from the Bernoulli numbers, at about
    prec + sqrt(prec) + 15 bits. mp.prec is restored afterwards.
    """
    saved_prec = mp.prec
    try:
        mp.prec = int(prec + prec**0.5 + 15)
        total = mpf(0)
        unit = mpf(1)
        # partial = sum_{k=1}^{2n-1} (-1)**(k+1)/k for the current n
        partial = unit
        bernoullis = bernoulli_range()
        # denom = 2*(2n)!, pi_power = (2*pi)**(2n)
        denom = mpf(4)
        twopi_sq = (2*pi)**2
        pi_power = twopi_sq
        n = 1
        while True:
            zeta_2n = (-1)**(n+1) * bernoullis.next() * pi_power / denom
            term = ((zeta_2n - 1) * partial) / n
            if term < eps:
                break
            total += term
            partial += unit/(2*n+1) - unit/(2*n)
            n += 1
            denom *= (2*n)*(2*n-1)
            pi_power *= twopi_sq
        return to_fixed(exp(total/ln2)._mpf_, prec)
    finally:
        mp.prec = saved_prec
132 # Glaisher's constant is defined as A = exp(1/2 - zeta'(-1)).
133 # One way to compute it would be to perform direct numerical
134 # differentiation, but computing arbitrary Riemann zeta function
135 # values at high precision is expensive. We instead use the formula
# A = \exp\left( \frac{1}{12} \left( \frac{6 (-\zeta'(2))}{\pi^2} + \log(2\pi) + \gamma \right) \right)
#
# and compute zeta'(2) from the series representation
#
# -\zeta'(2) = \sum_{k=2}^{\infty} \frac{\log k}{k^2}
# This series converges exceptionally slowly, but can be accelerated
# using the Euler-Maclaurin formula. The important insight is that the
# E-M integral can be done in closed form and that the high-order
# derivatives are given by
#
# \frac{d^n}{dx^n} \left( \frac{\log x}{x^2} \right) = \frac{a + b \log x}{x^{2+n}}
160 # where a and b are integers given by a simple recurrence. Note
161 # that just one logarithm is needed. However, lots of integer
162 # logarithms are required for the initial summation.
164 # This algorithm could possibly be turned into a faster algorithm
165 # for general evaluation of zeta(s) or zeta'(s); this should be
166 # looked into.
@constant_memo
def glaisher_fixed(prec):
    """Glaisher's constant A as a fixed-point integer with prec fractional
    bits.

    -zeta'(2) = sum_{k>=2} log(k)/k**2 is summed directly for k < N and
    the tail k >= N is added with the Euler-Maclaurin formula. Then
    A = exp((6*(-zeta'(2))/pi**2 + log(2*pi) + euler)/12).
    mp.prec is restored afterwards.
    """
    orig = mp.prec
    try:
        mp.prec = prec + 30
        # The cutoff N must follow the precision being computed. mp.dps
        # tracks mp.prec, so it is read only *after* raising mp.prec.
        # Reading it earlier tied N to the caller's precision. When prec
        # exceeds that, the E-M correction terms below may never drop
        # under eps, and the loop would not terminate.
        dps = mp.dps
        N = int(1.0*dps + 5)
        logs = log_range()
        s = mpf(0)
        # E-M step 1: sum log(k)/k**2 for k = 2..N-1
        for n in range(2, N):
            logn = logs.next()
            s += logn / n**2
        logN = logs.next()
        # E-M step 2: integral of log(x)/x**2 from N to inf
        s += (1+logN)/N
        # E-M step 3: endpoint correction term f(N)/2
        s += logN/(N**2 * 2)
        # E-M step 4: the series of derivatives.
        # Invariant: D = (a + b*log(N))/pN with pN = N**j is the current
        # derivative of log(x)/x**2 at x = N. It starts as the first
        # derivative (1 - 2*log(x))/x**3.
        pN, a, b, j, B2k, fac, k = N**3, 1, -2, 3, bernoulli_range(), 2, 1
        while 1:
            # D(2*k-1) * B(2*k) / fac(2*k)   [D(n) = nth derivative]
            D = (a+b*logN)/pN
            B = B2k.next()
            term = B * D / fac
            if abs(term) < eps:
                break
            s -= term
            # Advance the derivative twice (only odd orders are needed)
            a, b, pN, j = b-a*j, -j*b, pN*N, j+1
            a, b, pN, j = b-a*j, -j*b, pN*N, j+1
            k += 1
            fac *= (2*k) * (2*k-1)
        A = exp((6*s/pi**2 + log(2*pi) + euler)/12)
        return to_fixed(A._mpf_, prec)
    finally:
        mp.prec = orig
208 # Apery's constant can be computed using the very rapidly convergent
209 # series
210 # oo
211 # ___ 2 10
212 # \ n 205 n + 250 n + 77 (n!)
213 # zeta(3) = ) (-1) ------------------- ----------
214 # /___ 64 5
215 # n = 0 ((2n+1)!)
@constant_memo
def apery_fixed(prec):
    """Apery's constant zeta(3) as a fixed-point integer with prec
    fractional bits, using 20 guard bits."""
    guard = 20
    wp = prec + guard
    # fac_ratio approximates (n!)**10 / ((2n+1)!)**5 in fixed point
    fac_ratio = MP_ONE << wp
    term = MP_BASE(77) << wp
    n = 1
    total = MP_ZERO
    while term:
        total += term
        fac_ratio = fac_ratio * (n**10)
        fac_ratio = fac_ratio // (((2*n+1)**5) * (2*n)**5)
        term = (-1)**n * (205*(n**2) + 250*n + 77) * fac_ratio
        n += 1
    # Divide by 64 and drop the guard bits
    return total >> (guard + 6)
fme = from_man_exp

# Each mpf_<name>(p, r) evaluates the fixed-point routine with 10 guard
# bits and rounds the resulting mantissa to p bits in rounding mode r.
def mpf_phi(p, r):
    wp = p + 10
    return fme(phi_fixed(wp), -wp, p, r)

def mpf_khinchin(p, r):
    wp = p + 10
    return fme(khinchin_fixed(wp), -wp, p, r)

def mpf_glaisher(p, r):
    wp = p + 10
    return fme(glaisher_fixed(wp), -wp, p, r)

def mpf_apery(p, r):
    wp = p + 10
    return fme(apery_fixed(wp), -wp, p, r)

def mpf_euler(p, r):
    wp = p + 10
    return fme(euler_fixed(wp), -wp, p, r)

def mpf_catalan(p, r):
    wp = p + 10
    return fme(catalan_fixed(wp), -wp, p, r)

# Constant objects exported by this module
phi = constant(mpf_phi, "Golden ratio (phi)")
catalan = constant(mpf_catalan, "Catalan's constant")
euler = constant(mpf_euler, "Euler's constant (gamma)")
khinchin = constant(mpf_khinchin, "Khinchin's constant")
glaisher = constant(mpf_glaisher, "Glaisher's constant")
apery = constant(mpf_apery, "Apery's constant")
249 #----------------------------------------------------------------------
250 # Factorial related functions
253 # For internal use
# For internal use
def int_fac(n, memo={0:1, 1:1}):
    """Return n factorial (for integers n >= 0 only).

    Factorials up to 1023! are memoized in the default dict, which is
    intentionally shared between calls. Its keys always form the
    contiguous range 0..len(memo)-1, so the largest cached value is
    memo[len(memo)-1].

    Raises ValueError for negative n. Previously a negative n silently
    returned the largest cached factorial.
    """
    if n < 0:
        raise ValueError("factorial of a negative integer")
    f = memo.get(n)
    if f:
        return f
    # Continue multiplying from the largest cached factorial
    k = len(memo)
    p = memo[k-1]
    while k <= n:
        p *= k
        if k < 1024:
            memo[k] = p
        k += 1
    return p
# When the gmpy backend is active, replace the pure-Python int_fac with
# gmpy's factorial function.
if MODE == "gmpy":
    int_fac = gmpy.fac
272 We compute the gamma function using Spouge's approximation
274 x! = (x+a)**(x+1/2) * exp(-x-a) * [c_0 + S(x) + eps]
276 where S(x) is the sum of c_k/(x+k) from k = 1 to a-1 and the coefficients
277 are given by
279 c_0 = sqrt(2*pi)
    c_k = (-1)**(k-1) / (k-1)! * (a-k)**(k-1/2) * exp(a-k),   k = 1, 2, ..., a-1
285 Due to an inequality proved by Spouge, if we choose a = int(1.26*n), the
286 error eps is less than 10**-n for any x in the right complex half-plane
287 (assuming a > 2). In practice, it seems that a can be chosen quite a bit
288 lower still (30-50%); this possibility should be investigated.
290 Reference:
291 John L. Spouge, "Computation of the gamma, digamma, and trigamma
292 functions", SIAM Journal on Numerical Analysis 31 (1994), no. 3, 931-944.
# Maps precision -> (precision, a, fixed-point coefficient list)
spouge_cache = dict()
def calc_spouge_coefficients(a, prec):
    """Return Spouge's coefficients [c_0, ..., c_{a-1}] as fixed-point
    integers with prec fractional bits.

    c_0 = sqrt(2*pi) and, for k >= 1,
    c_k = (-1)**(k-1) * (a-k)**(k-1/2) * exp(a-k) / (k-1)!.
    """
    wp = prec + int(a*1.4)
    coefs = [0] * a
    # scale = exp(a-k) / (k-1)!, starting at k = 1
    scale = fexp(from_int(a-1), wp)
    e = fexp(fone, wp)
    coefs[0] = to_fixed(fsqrt(fshift(fpi(wp), 1), wp), prec)
    for k in xrange(1, a):
        # (a-k)**(k-1/2) = (a-k)**k / sqrt(a-k)
        power = (-1)**(k-1) * (a-k)**k
        ck = fdiv(fmuli(scale, power, wp), fsqrt(from_int(a-k), wp), wp)
        coefs[k] = to_fixed(ck, prec)
        # Next scale: divide by e*k
        scale = fdiv(scale, fmul(e, from_int(k), wp), wp)
    return coefs
# Cached lookup of coefficients
def get_spouge_coefficients(prec):
    """Return (sprec, a, coefs) for Spouge's approximation at >= prec bits.

    Coefficients are cached in spouge_cache keyed by precision. An entry
    computed at a somewhat higher precision p (prec <= p < 1.25*prec) is
    reused because it is at least as accurate as requested. sprec is the
    precision the returned coefficients were computed at, and callers
    evaluate the Spouge sum at that precision.
    """
    # This exact precision has been used before
    if prec in spouge_cache:
        return spouge_cache[prec]
    # Reuse coefficients from a slightly *higher* precision. The old test
    # (0.8 <= p/prec < 1) picked an entry computed at a *lower* precision,
    # which silently lost up to 20% of the requested bits.
    for p in spouge_cache:
        if 0.8 <= float(prec)/p < 1:
            return spouge_cache[p]
    # Here we estimate the value of a based on Spouge's inequality for
    # the relative error: a ~= 1.26*n for n digits, about 0.39 per bit.
    a = max(3, int(0.39*prec))
    coefs = calc_spouge_coefficients(a, prec)
    spouge_cache[prec] = (prec, a, coefs)
    return spouge_cache[prec]
def spouge_sum_real(x, prec, a, c):
    """Return c_0 + sum_{k=1}^{a-1} c_k/(x+k) as an mpf value rounded
    down to prec bits. x is an mpf value; c holds fixed-point
    coefficients with prec fractional bits."""
    xf = to_fixed(x, prec)
    one = 1 << prec
    terms = ((c[k] << prec) // (xf + k*one) for k in xrange(1, a))
    return from_man_exp(sum(terms, c[0]), -prec, prec, round_floor)
# Unused: for fast computation of gamma(p/q)
def spouge_sum_rational(p, q, prec, a, c):
    """Spouge sum c_0 + sum_{k=1}^{a-1} c_k/(p/q + k) for a rational
    argument p/q, rounded down to prec bits."""
    terms = (c[k] * q // (p + q*k) for k in xrange(1, a))
    return from_man_exp(sum(terms, c[0]), -prec, prec, round_floor)
349 # For a complex number a + b*I, we have
# \frac{c_k}{(a + bI) + k} = \frac{(a+k) c_k}{M} - \frac{b c_k}{M} I,
#
# where M = (a+k)^2 + b^2 = (a^2 + b^2) + (2ak + k^2).
def spouge_sum_complex(re, im, prec, a, c):
    """Spouge sum c_0 + sum_{k=1}^{a-1} c_k/(x+k) for complex x = re + im*I.

    re and im are mpf values. The result is a pair of mpf values
    (real, imag), each rounded down to prec bits.
    """
    xr = to_fixed(re, prec)
    xi = to_fixed(im, prec)
    # |x|**2 in fixed point
    norm = ((xr**2) >> prec) + ((xi**2) >> prec)
    acc_re = c[0]
    acc_im = 0
    for k in xrange(1, a):
        # |x + k|**2 = |x|**2 + 2*k*Re(x) + k**2
        denom = norm + xr*(2*k) + ((k**2) << prec)
        acc_re += (c[k] * (xr + (k << prec))) // denom
        acc_im -= (c[k] * xi) // denom
    return (from_man_exp(acc_re, -prec, prec, round_floor),
            from_man_exp(acc_im, -prec, prec, round_floor))
def mpf_gamma(x, prec, rounding=round_fast, p1=1):
    """Gamma function of the mpf value x (p1=1), or the factorial x!
    (p1=0), rounded to prec bits in direction `rounding`.

    Small integers use an exact factorial, arguments that are negative or
    below 0.25 in magnitude use the reflection formula, and everything
    else uses Spouge's approximation. Raises ValueError at a pole.
    """
    sign, man, e, bc = x
    if e >= 0:
        # x is an integer
        if sign or (p1 and not man):
            raise ValueError("gamma function pole")
        if e + bc <= 10:
            # A direct factorial is fastest
            return from_int(int_fac((man << e) - p1), prec, rounding)
    wp = prec + 15
    if p1:
        # gamma(x) = (x-1)!, so from here on x! is computed
        x = fsub(x, fone, wp)
    # Tested on the original argument: negative, or |x| < 0.25
    if sign or e + bc < -1:
        # x! = pi*x / (sin(pi*x) * gamma(1-x))
        wp += 15
        pix = fmul(x, fpi(wp), wp)
        sine = fsin(pix, wp)
        g_reflect = mpf_gamma(fsub(fone, x, wp), wp)
        return fdiv(pix, fmul(sine, g_reflect, wp), prec, rounding)
    sprec, a, c = get_spouge_coefficients(wp)
    s = spouge_sum_real(x, sprec, a, c)
    # x! ~= exp(log(x+a)*(x+0.5) - (x+a)) * s
    xpa = fadd(x, from_int(a), wp)
    exponent = fsub(fmul(flog(xpa, wp), fadd(x, fhalf, wp), wp), xpa, wp)
    return fmul(fexp(exponent, wp), s, prec, rounding)
def mpc_gamma(x, prec, rounding=round_fast, p1=1):
    """Gamma function (p1=1) or factorial (p1=0) of a complex number.

    x is an mpc value tuple (re, im). The result is an mpc value tuple
    rounded to prec bits with the given rounding mode. Arguments with
    zero imaginary part are delegated to mpf_gamma.
    """
    re, im = x
    if im == fzero:
        return mpf_gamma(re, prec, rounding, p1), fzero
    wp = prec + 25
    # As in mpf_gamma, the sign/magnitude of the original real part
    # decides whether the reflection formula is used.
    sign, man, exp, bc = re
    if p1:
        re = fsub(re, fone, wp)
    x = re, im
    if sign or exp+bc < -1:
        # Reflection formula: x! = pi*x / (sin(pi*x) * gamma(1-x))
        wp += 15
        pi = fpi(wp), fzero
        pix = mpc_mul(x, pi, wp)
        t = mpc_sin(pix, wp)
        u = mpc_sub(mpc_one, x, wp)
        g = mpc_gamma(u, wp)
        w = mpc_mul(t, g, wp)
        # Round to the requested precision, as mpf_gamma does for its
        # reflection branch. This previously returned the unrounded
        # working-precision value.
        return mpc_div(pix, w, prec, rounding)
    sprec, a, c = get_spouge_coefficients(wp)
    s = spouge_sum_complex(re, im, sprec, a, c)
    # gamma = exp(log(x+a)*(x+0.5) - xpa) * s
    repa = fadd(re, from_int(a), wp)
    logxpa = mpc_log((repa, im), wp)
    reph = fadd(re, fhalf, wp)
    t = mpc_sub(mpc_mul(logxpa, (reph, im), wp), (repa, im), wp)
    t = mpc_mul(mpc_exp(t, wp), s, prec, rounding)
    return t
def gamma(x):
    """Gamma function of a real or complex number, at the current
    precision with round-to-nearest."""
    x = convert_lossless(x)
    prec = mp.prec
    if not isinstance(x, mpf):
        return make_mpc(mpc_gamma(x._mpc_, prec, round_nearest, 1))
    return make_mpf(mpf_gamma(x._mpf_, prec, round_nearest, 1))
def factorial(x):
    """Factorial x! = gamma(x+1) of a real or complex number, at the
    current precision with round-to-nearest."""
    x = convert_lossless(x)
    prec = mp.prec
    if not isinstance(x, mpf):
        return make_mpc(mpc_gamma(x._mpc_, prec, round_nearest, 0))
    return make_mpf(mpf_gamma(x._mpf_, prec, round_nearest, 0))
def isnpint(x):
    """Truthy if x (an mpf or mpc) is a nonpositive integer, i.e. a pole
    of the gamma function."""
    if not x:
        return True
    if isinstance(x, mpf):
        sign, man, exp, bc = x._mpf_
        if not sign:
            return sign
        # Negative and with a nonnegative exponent: an integer
        return exp >= 0
    if isinstance(x, mpc):
        if x.imag:
            return False
        return isnpint(x.real)
def gammaprod(a, b):
    """
    Computes the product / quotient of gamma functions

        G(a_0) G(a_1) ... G(a_p)
        ------------------------
        G(b_0) G(b_1) ... G(b_q)

    with proper cancellation of poles (interpreting the expression as a
    limit). Returns +inf if the limit diverges.
    """
    a = [convert_lossless(x) for x in a]
    b = [convert_lossless(x) for x in b]
    poles_num, regular_num = [], []
    for x in a:
        if isnpint(x):
            poles_num.append(x)
        else:
            regular_num.append(x)
    poles_den, regular_den = [], []
    for x in b:
        if isnpint(x):
            poles_den.append(x)
        else:
            regular_den.append(x)
    # One more pole in numerator or denominator gives 0 or inf
    if len(poles_num) < len(poles_den):
        return mpf(0)
    if len(poles_num) > len(poles_den):
        return mpf('+inf')
    p = mpf(1)
    orig = mp.prec
    try:
        mp.prec = orig + 15
        # All poles cancel pairwise; pair them from the end of each list
        for i, j in zip(poles_num[::-1], poles_den[::-1]):
            # lim G(i)/G(j) = (-1)**(i+j) * gamma(1-j) / gamma(1-i)
            p *= (-1)**(i+j) * gamma(1-j) / gamma(1-i)
        for x in regular_num:
            p *= gamma(x)
        for x in regular_den:
            p /= gamma(x)
    finally:
        mp.prec = orig
    return +p
def binomial(n, k):
    """Binomial coefficient, C(n,k) = n!/(k!*(n-k)!), computed as a
    limit of gamma functions (so non-integer n, k are accepted)."""
    numerator = [n+1]
    denominator = [k+1, n-k+1]
    return gammaprod(numerator, denominator)
def rf(x, n):
    """Rising factorial (Pochhammer symbol), x^(n) = gamma(x+n)/gamma(x)"""
    numerator = [x+n]
    denominator = [x]
    return gammaprod(numerator, denominator)
def ff(x, n):
    """Falling factorial, x_(n) = gamma(x+1)/gamma(x-n+1)"""
    numerator = [x+1]
    denominator = [x-n+1]
    return gammaprod(numerator, denominator)
507 #---------------------------------------------------------------------------#
509 # Riemann zeta function #
511 #---------------------------------------------------------------------------#
We use zeta(s) = eta(s) / (1 - 2**(1-s)) and Borwein's approximation

    eta(s) ~= -1/d_n * sum_{k=0}^{n-1} (-1)**k * (d_k - d_n) / (k+1)**s

where

    d_k = n * sum_{i=0}^{k} (n+i-1)! * 4**i / ((n-i)! * (2i)!).
If s = a + b*I, the absolute error for eta(s) is bounded by

    3 * (1 + 2|b|) * exp(|b|*pi/2) / (3 + sqrt(8))**n
536 Disregarding the linear term, we have approximately,
538 log(err) ~= log(exp(1.58*|b|)) - log(5.8**n)
539 log(err) ~= 1.58*|b| - log(5.8)*n
540 log(err) ~= 1.58*|b| - 1.76*n
541 log2(err) ~= 2.28*|b| - 2.54*n
543 So for p bits, we should choose n > (p + 2.28*|b|) / 2.54.
545 Reference:
546 Peter Borwein, "An Efficient Algorithm for the Riemann Zeta Function"
547 http://www.cecm.sfu.ca/personal/pborwein/PAPERS/P117.ps
549 http://en.wikipedia.org/wiki/Dirichlet_eta_function
# Maps n -> list [d_0, ..., d_n] of Borwein coefficients
d_cache = dict()
def zeta_coefs(n):
    """Return [d_0, ..., d_n] for Borwein's eta(s) algorithm, where
    d_k = n * sum_{i=0}^{k} (n+i-1)! 4**i / ((n-i)! (2i)!).
    Results are memoized in d_cache."""
    cached = d_cache.get(n)
    if cached is not None:
        return cached
    partial = [MP_ZERO] * (n+1)
    term = MP_ONE
    total = partial[0] = MP_ONE
    for i in range(1, n+1):
        # Ratio between consecutive summands of d_k
        term = term * 4 * (n+i-1) * (n-i+1) // ((2*i) * ((2*i)-1))
        total += term
        partial[i] = total
    d_cache[n] = partial
    return partial
# Integer logarithms: maps k -> (precision, log(k) as mpf)
_log_cache = dict()
def _logk(k):
    """log(k) at the current precision, memoized in _log_cache. A cached
    value is reused (rounded with unary +) if it was computed at least at
    the current precision."""
    prec = mp.prec
    entry = _log_cache.get(k)
    if entry is not None and entry[0] >= prec:
        return +entry[1]
    value = log(k)
    _log_cache[k] = (prec, value)
    return value
@extraprec(10, normalize_output=True)
def zeta(s):
    """Returns the Riemann zeta function of s."""
    s = convert_lossless(s)
    if s.real < 0:
        # Reflection formula (XXX: gets bad around the zeros)
        return 2**s * pi**(s-1) * sin(pi*s/2) * gamma(1-s) * zeta(1-s)
    prec = mp.prec
    # Number of terms from Borwein's error bound
    terms = int((prec + 2.28*abs(float(mpc(s).imag)))/2.54) + 3
    d = zeta_coefs(terms)
    dn = d[terms]
    if isinstance(s, mpf) and s == int(s):
        # Integer s: sum eta(s) in exact fixed-point arithmetic
        power = int(s)
        acc = 0
        sign = 1
        for k in range(terms):
            acc += ((sign * (d[k] - dn)) << prec) // (k+1)**power
            sign = -sign
        return (mpf((acc, -prec)) / -dn) / (1 - mpf(2)**(1-power))
    acc = mpf(0)
    sign = 1
    for k in range(terms):
        # (k+1)**(-s) = exp(-log(k+1)*s)
        acc += sign * mpf(d[k] - dn) * exp(-_logk(k+1)*s)
        sign = -sign
    return (acc / -dn) / (mpf(1) - exp(log(2)*(1-s)))
@extraprec(5, normalize_output=True)
def bernoulli(n):
    """nth Bernoulli number, B_n.

    For even n, B_n = (-1)**(n/2-1) * 2 * n! * zeta(n) / (2*pi)**n.
    """
    if n == 1:
        return mpf(-0.5)
    if n & 1:
        return mpf(0)
    sign = (-1)**(n//2 - 1)
    return sign * 2 * factorial(n) / (2*pi)**n * zeta(n)
# For sequential computation of Bernoulli numbers, we use Ramanujan's formula
#
# B_n = \frac{A(n) - S(n)}{\binom{n+3}{n}}
#
# where A(n) = (n+3)/3 when n = 0 or 2 (mod 6), A(n) = -(n+3)/6
# when n = 4 (mod 6), and
#
# S(n) = \sum_{k=1}^{\lfloor n/6 \rfloor} \binom{n+3}{n-6k} B_{n-6k}
def bernoulli_range():
    """Generates B(2), B(4), B(6), ...

    Infinite generator of the even-index Bernoulli numbers as mpf values,
    computed sequentially with Ramanujan's recurrence at 30 bits above the
    precision current when the generator is created. Each value is rounded
    to that creation-time precision with the rounding mode mp.rounding[0],
    also captured at creation.
    """
    oprec = mp.prec
    rounding = mp.rounding[0]
    # Working precision: 30 guard bits over the caller's precision
    prec = oprec + 30
    # computed[m] = B_m as a raw mpf value at precision prec (B_0 = 1 seed)
    computed = {0:fone}
    # bin = binomial(m+3, m); bin1 = binomial(m+3, m-6), the first
    # coefficient of S(m) (only used once m >= 6)
    m, bin1, bin = 2, MP_ONE, MP_BASE(10)
    f3 = from_int(3)
    f6 = from_int(6)
    while 1:
        case = m % 6
        # s accumulates S(m) = sum_{j=1}^{m//6} binomial(m+3, m-6j) * B_{m-6j}
        s = fzero
        if m < 6: a = MP_ZERO
        else: a = bin1
        for j in xrange(1, m//6+1):
            s = fadd(s, fmuli(computed[m-6*j], a, prec), prec)
            # Inner binomial coefficient: step a from binomial(m+3, m-6j)
            # to binomial(m+3, m-6(j+1))
            j6 = 6*j
            a *= ((m-5-j6)*(m-4-j6)*(m-3-j6)*(m-2-j6)*(m-1-j6)*(m-j6))
            a //= ((4+j6)*(5+j6)*(6+j6)*(7+j6)*(8+j6)*(9+j6))
        # A(m) = (m+3)/3 for m = 0, 2 (mod 6); -(m+3)/6 for m = 4 (mod 6)
        if case == 0: b = fdivi(m+3, f3, prec)
        if case == 2: b = fdivi(m+3, f3, prec)
        if case == 4: b = fdivi(-m-3, f6, prec)
        # B_m = (A(m) - S(m)) / binomial(m+3, m)
        b = fdiv(fsub(b, s, prec), from_int(bin), prec)
        computed[m] = b
        yield make_mpf(fpos(b, oprec, rounding))
        m += 2
        # Update both binomial coefficients for the new m
        bin = bin * ((m+2)*(m+3)) // (m*(m-1))
        if m > 6: bin1 = bin1 * ((2+m)*(3+m)) // ((m-7)*(m-6))
661 #---------------------------------------------------------------------------#
663 # Hypergeometric functions #
665 #---------------------------------------------------------------------------#
667 import operator
670 TODO:
671 * By counting the number of multiplications vs divisions,
672 the bit size of p can be kept around wp instead of growing
673 it to n*wp for some (possibly large) n
675 * Due to roundoff error, the series may fail to converge
676 when x is negative and the convergence is slow.
def hypsum(ar, af, ac, br, bf, bc, x):
    """
    Generic hypergeometric summation. This function computes:

        sum_{n=0}^{oo} [(a_1)_n (a_2)_n ... / ((b_1)_n (b_2)_n ...)] * x**n / n!

    where (c)_n = c (c+1) ... (c+n-1) is the rising factorial, i.e.

        1 + (a_1 a_2 ...)/(b_1 b_2 ...) * x/1!
          + (a_1 (a_1+1) a_2 (a_2+1) ...)/(b_1 (b_1+1) b_2 (b_2+1) ...) * x**2/2!
          + ...

    The a_i and b_i sequences are separated by type:

        ar - list of a_i rationals [p,q]
        af - list of a_i mpf value tuples
        ac - list of a_i mpc value tuples
        br - list of b_i rationals [p,q]
        bf - list of b_i mpf value tuples
        bc - list of b_i mpc value tuples

    Note: the rational coefficients will be updated in-place and must
    hence be mutable (lists rather than tuples). The af, bf, ac and bc
    lists are also overwritten in place (converted to fixed point and
    incremented each step).

    x must be an mpf or mpc instance.

    The series is summed in fixed point with 25 guard bits until a term
    is below 100 units of 2**-wp. The result is an mpf when x and all
    parameters are real, otherwise an mpc, rounded to mp.prec with the
    current rounding mode.
    """

    have_float = af or bf
    have_complex = ac or bc

    prec = mp.prec
    rnd = mp.rounding[0]
    wp = prec + 25

    if isinstance(x, mpf):
        x = to_fixed(x._mpf_, wp)
        y = MP_ZERO
    else:
        have_complex = 1
        x, y = x._mpc_
        x = to_fixed(x, wp)
        y = to_fixed(y, wp)

    # Running sum (sre, sim) and current term (pre, pim), fixed point
    sre = pre = one = MP_ONE << wp
    sim = pim = MP_ZERO

    n = 1

    # Need to shift down by wp once for each fixed-point multiply
    # At minimum, we multiply once by x each step
    shift = 1

    # Fixed-point real coefficients
    if have_float:
        len_af = len(af)
        len_bf = len(bf)
        range_af = range(len_af)
        range_bf = range(len_bf)
        for i in range_af: af[i] = to_fixed(af[i], wp)
        for i in range_bf: bf[i] = to_fixed(bf[i], wp)
        shift += len_af

    if have_complex:
        len_ac = len(ac)
        len_bc = len(bc)
        range_ac = range(len_ac)
        range_bc = range(len_bc)
        for i in range_ac: ac[i] = [to_fixed(ac[i][0], wp), to_fixed(ac[i][1], wp)]
        for i in range_bc: bc[i] = [to_fixed(bc[i][0], wp), to_fixed(bc[i][1], wp)]
        shift += len_ac

    # Denominators of the rational parameters are constant, so their
    # products are computed once
    aqs = [a[1] for a in ar]
    bqs = [b[1] for b in br]
    aqprod = reduce(operator.mul, aqs, 1)
    bqprod = reduce(operator.mul, bqs, 1)

    assert shift >= 0

    while 1:
        # Integer and rational part of product:
        # prod(ap/aq) / (n * prod(bp/bq)) = mul / div
        mul = bqprod
        div = n * aqprod
        for ap, aq in ar: mul *= ap
        for bp, bq in br: div *= bp

        if have_complex:
            # Multiply by rational factors
            pre *= mul
            pim *= mul
            # Multiply by z
            pre, pim = pre*x - pim*y, pim*x + pre*y
            # Multiply by real factors
            for a in af:
                pre *= a
                pim *= a
            # Multiply by complex factors
            for are, aim in ac:
                pre, pim = pre*are - pim*aim, pim*are + pre*aim
            # Divide by rational factors
            pre //= div
            pim //= div
            # Divide by real factors
            for b in bf:
                pre = (pre << wp) // b
                pim = (pim << wp) // b
            # Divide by complex factors
            for bre, bim in bc:
                mag = bre*bre + bim*bim
                re = pre*bre + pim*bim
                im = pim*bre - pre*bim
                pre = (re << wp) // mag
                pim = (im << wp) // mag
        elif have_float:
            # Multiply and divide by real and rational factors, and x
            for a in af: pre *= a
            for b in bf:
                pre = (pre << wp) // b
            pre = (pre * (mul * x)) // div

        else:
            # Multiply and divide by rational factors and x
            pre = (pre * (mul * x)) // div

        # Undo the scaling from the fixed-point multiplications this step
        pre >>= (wp*shift)
        sre += pre

        # Stop once the new term is negligible (below 100 units of 2**-wp)
        if have_complex:
            pim >>= (wp*shift)
            sim += pim
            if (-100 < pre < 100) and (-100 < pim < 100):
                break
        else:
            if -100 < pre < 100:
                break

        # Add 1 to all as and bs
        n += 1
        for ap_aq in ar: ap_aq[0] += ap_aq[1]
        for bp_bq in br: bp_bq[0] += bp_bq[1]
        if have_float:
            for i in range_af: af[i] += one
            for i in range_bf: bf[i] += one
        if have_complex:
            for i in range_ac: ac[i][0] += one
            for i in range_bc: bc[i][0] += one

    re = from_man_exp(sre, -wp, prec, rnd)
    if have_complex:
        return make_mpc((re, from_man_exp(sim, -wp, prec, rnd)))
    else:
        return make_mpf(re)
829 #---------------------------------------------------------------------------#
830 # Special-case implementation for rational parameters. These are #
831 # about 2x faster at low precision #
832 #---------------------------------------------------------------------------#
def sum_hyp0f1_rat(b, x):
    """Sum 0F1(; b; x) for rational b = bp/bq, given as a (bp, bq) pair.
    x must be mpf or mpc.

    The series is summed in fixed point with 25 guard bits until a term
    drops below 100 units of 2**-wp. The result is rounded to mp.prec
    with the current rounding mode.
    """
    # Unpack here rather than in the signature: tuple parameters are a
    # syntax error in Python 3 (PEP 3113). Positional callers are
    # unaffected.
    bp, bq = b
    prec = mp.prec
    rnd = mp.rounding[0]
    wp = prec + 25
    if isinstance(x, mpf):
        x = to_fixed(x._mpf_, wp)
        s = p = MP_ONE << wp
        n = 1
        while 1:
            # Term ratio: x / (n * (b+n-1))
            p = (p * (bq*x) // (n*bp)) >> wp
            if -100 < p < 100:
                break
            s += p
            n += 1
            bp += bq
        return make_mpf(from_man_exp(s, -wp, prec, rnd))
    zre, zim = x._mpc_
    zre = to_fixed(zre, wp)
    zim = to_fixed(zim, wp)
    sre = pre = MP_ONE << wp
    sim = pim = MP_ZERO
    n = 1
    while 1:
        r1 = bq
        r2 = n*bp
        pre, pim = pre*zre - pim*zim, pim*zre + pre*zim
        pre = ((pre * r1) // r2) >> wp
        pim = ((pim * r1) // r2) >> wp
        if -100 < pre < 100 and -100 < pim < 100:
            break
        sre += pre
        sim += pim
        n += 1
        bp += bq
    re = from_man_exp(sre, -wp, prec, rnd)
    im = from_man_exp(sim, -wp, prec, rnd)
    return make_mpc((re, im))
def sum_hyp1f1_rat(a, b, x):
    """Sum 1F1(a; b; x) for rational a = ap/aq and b = bp/bq, each given
    as a (p, q) pair. x must be mpf or mpc.

    The series is summed in fixed point with 25 guard bits until a term
    drops below 100 units of 2**-wp. The result is rounded to mp.prec
    with the current rounding mode.
    """
    # Unpack here rather than in the signature: tuple parameters are a
    # syntax error in Python 3 (PEP 3113). Positional callers such as
    # sum_hyp1f1_rat((1,2), (3,2), z) are unaffected.
    ap, aq = a
    bp, bq = b
    prec = mp.prec
    rnd = mp.rounding[0]
    wp = prec + 25
    if isinstance(x, mpf):
        x = to_fixed(x._mpf_, wp)
        s = p = MP_ONE << wp
        n = 1
        while 1:
            # Term ratio: (a+n-1) * x / (n * (b+n-1))
            p = (p * (ap*bq*x) // (n*aq*bp)) >> wp
            if -100 < p < 100:
                break
            s += p
            n += 1
            ap += aq
            bp += bq
        return make_mpf(from_man_exp(s, -wp, prec, rnd))
    zre, zim = x._mpc_
    zre = to_fixed(zre, wp)
    zim = to_fixed(zim, wp)
    sre = pre = MP_ONE << wp
    sim = pim = MP_ZERO
    n = 1
    while 1:
        r1 = ap*bq
        r2 = n*aq*bp
        pre, pim = pre*zre - pim*zim, pim*zre + pre*zim
        pre = ((pre * r1) // r2) >> wp
        pim = ((pim * r1) // r2) >> wp
        if -100 < pre < 100 and -100 < pim < 100:
            break
        sre += pre
        sim += pim
        n += 1
        ap += aq
        bp += bq
    re = from_man_exp(sre, -wp, prec, rnd)
    im = from_man_exp(sim, -wp, prec, rnd)
    return make_mpc((re, im))
def sum_hyp2f1_rat(a, b, c, x):
    """Sum 2F1(a, b; c; x) for rational a, b, c, each given as a (p, q)
    pair. x must be mpf or mpc.

    The power series is summed in fixed point with 25 guard bits until a
    term drops below 100 units of 2**-wp. It converges only for |x| < 1.
    The result is rounded to mp.prec with the current rounding mode.
    """
    # Unpack here rather than in the signature: tuple parameters are a
    # syntax error in Python 3 (PEP 3113). Positional callers are
    # unaffected.
    ap, aq = a
    bp, bq = b
    cp, cq = c
    prec = mp.prec
    rnd = mp.rounding[0]
    wp = prec + 25
    if isinstance(x, mpf):
        x = to_fixed(x._mpf_, wp)
        s = p = MP_ONE << wp
        n = 1
        while 1:
            # Term ratio: (a+n-1)(b+n-1) * x / (n * (c+n-1))
            p = (p * (ap*bp*cq*x) // (n*aq*bq*cp)) >> wp
            if -100 < p < 100:
                break
            s += p
            n += 1
            ap += aq
            bp += bq
            cp += cq
        return make_mpf(from_man_exp(s, -wp, prec, rnd))
    zre, zim = x._mpc_
    zre = to_fixed(zre, wp)
    zim = to_fixed(zim, wp)
    sre = pre = MP_ONE << wp
    sim = pim = MP_ZERO
    n = 1
    while 1:
        r1 = ap*bp*cq
        r2 = n*aq*bq*cp
        pre, pim = pre*zre - pim*zim, pim*zre + pre*zim
        pre = ((pre * r1) // r2) >> wp
        pim = ((pim * r1) // r2) >> wp
        if -100 < pre < 100 and -100 < pim < 100:
            break
        sre += pre
        sim += pim
        n += 1
        ap += aq
        bp += bq
        cp += cq
    re = from_man_exp(sre, -wp, prec, rnd)
    im = from_man_exp(sim, -wp, prec, rnd)
    return make_mpc((re, im))
def parse_param(x):
    """Classify a hypergeometric parameter.

    Returns a triple (rationals, reals, complexes) in which exactly one
    list holds the parameter: a (p, q) tuple or an integer becomes [[p, q]]
    in the first list, and anything else is converted with
    convert_lossless and stored as an mpf or mpc value tuple.
    """
    if isinstance(x, tuple):
        num, den = x
        return [[num, den]], [], []
    if isinstance(x, (int, long)):
        return [[x, 1]], [], []
    value = convert_lossless(x)
    if isinstance(value, mpf):
        return [], [value._mpf_], []
    if isinstance(value, mpc):
        return [], [], [value._mpc_]
class _mpq(tuple):
    """Exact rational number p/q stored as the tuple (p, q).

    Only addition and subtraction with another _mpq are implemented, and
    results are not reduced to lowest terms. For any other operand the
    methods return NotImplemented, so Python tries the other operand's
    reflected method.
    """

    @property
    def _mpf_(self):
        # mpf value tuple of p/q, computed at the current precision
        return (mpf(self[0])/self[1])._mpf_

    def __add__(self, other):
        if not isinstance(other, _mpq):
            return NotImplemented
        p1, q1 = self
        p2, q2 = other
        return _mpq((p1*q2 + q1*p2, q1*q2))

    def __sub__(self, other):
        if not isinstance(other, _mpq):
            return NotImplemented
        p1, q1 = self
        p2, q2 = other
        return _mpq((p1*q2 - q1*p2, q1*q2))

# The exact rationals 1 and 0
_1 = _mpq((1,1))
_0 = _mpq((0,1))

def _as_num(x):
    """Turn a [p, q] rational list (as built by parse_param) into an
    _mpq; any other value is returned unchanged."""
    if not isinstance(x, list):
        return x
    return _mpq(x)
def eval_hyp2f1(a,b,c,z):
    """Evaluate the Gauss hypergeometric function 2F1(a, b; c; z).

    a, b, c may be ints, (p, q) rational tuples, or values accepted by
    convert_lossless. z must be an mpf or mpc. For |z| <= 1 the power
    series is summed directly, using the all-rational routine when
    possible. For |z| > 1 the 1/z transformation

      2F1(a,b;c;z) = G(c)G(b-a)/(G(b)G(c-a)) (-z)**(-a) 2F1(a,1-c+a;1-b+a;1/z)
                   + G(c)G(a-b)/(G(a)G(c-b)) (-z)**(-b) 2F1(b,1-c+b;1-a+b;1/z)

    is evaluated with 15 extra bits. Rational parameters are kept exact
    through _mpq arithmetic.
    """
    ar, af, ac = parse_param(a)
    br, bf, bc = parse_param(b)
    cr, cf, cc = parse_param(c)
    absz = abs(z)
    if absz == 1:
        # TODO: determine whether it actually does, and otherwise
        # return infinity instead.
        # A single-argument print(...) prints the same text on Python 2
        # and, unlike the print statement, is also valid Python 3.
        print("Warning: 2F1 might not converge for |z| = 1")
    if absz <= 1:
        if ar and br and cr:
            return sum_hyp2f1_rat(ar[0], br[0], cr[0], z)
        return hypsum(ar+br, af+bf, ac+bc, cr, cf, cc, z)
    # Use 1/z transformation
    a = (ar and _as_num(ar[0])) or convert_lossless(a)
    b = (br and _as_num(br[0])) or convert_lossless(b)
    c = (cr and _as_num(cr[0])) or convert_lossless(c)
    orig = mp.prec
    try:
        mp.prec = orig + 15
        h1 = eval_hyp2f1(a, _1-c+a, _1-b+a, 1/z)
        h2 = eval_hyp2f1(b, _1-c+b, _1-a+b, 1/z)
        # gammaprod takes care of cancelling poles in the gamma ratios
        f1 = gammaprod([c,b-a],[b,c-a])
        f2 = gammaprod([c,a-b],[a,c-b])
        s1 = f1 * (-z)**(_0-a) * h1
        s2 = f2 * (-z)**(_0-b) * h2
        v = s1 + s2
    finally:
        mp.prec = orig
    return +v
1013 #---------------------------------------------------------------------------#
1014 # And now the user-friendly versions #
1015 #---------------------------------------------------------------------------#
def hyper(a_s, bs, z):
    """
    Hypergeometric function pFq(a_1, ..., a_p; b_1, ..., b_q; z).

    The parameter lists a_s and bs may contain real or complex numbers.
    Exact rational parameters can be given as tuples (p, q).
    """
    # The first parameter used to be named `as`, which is a reserved word
    # from Python 2.6 on (and in Python 3), so the definition did not
    # parse there. `as` could never be passed by keyword, so the rename
    # is backward compatible.
    p = len(a_s)
    q = len(bs)
    z = convert_lossless(z)
    degree = p, q
    if degree == (0, 1):
        br, bf, bc = parse_param(bs[0])
        if br:
            return sum_hyp0f1_rat(br[0], z)
        return hypsum([], [], [], br, bf, bc, z)
    if degree == (1, 1):
        ar, af, ac = parse_param(a_s[0])
        br, bf, bc = parse_param(bs[0])
        if ar and br:
            a, b = ar[0], br[0]
            return sum_hyp1f1_rat(a, b, z)
        return hypsum(ar, af, ac, br, bf, bc, z)
    if degree == (2, 1):
        return eval_hyp2f1(a_s[0],a_s[1],bs[0],z)
    # General case: split every parameter by type and sum generically
    ars, afs, acs, brs, bfs, bcs = [], [], [], [], [], []
    for a in a_s:
        r, f, c = parse_param(a)
        ars += r
        afs += f
        acs += c
    for b in bs:
        r, f, c = parse_param(b)
        brs += r
        bfs += f
        bcs += c
    return hypsum(ars, afs, acs, brs, bfs, bcs, z)
def hyp0f1(a, z):
    """Hypergeometric function 0F1(; a; z).

    Shorthand for hyper([], [a], z); see hyper() for the accepted
    parameter types."""
    upper, lower = [], [a]
    return hyper(upper, lower, z)
def hyp1f1(a,b,z):
    """Hypergeometric function 1F1(a; b; z).

    Shorthand for hyper([a], [b], z); see hyper() for the accepted
    parameter types."""
    upper, lower = [a], [b]
    return hyper(upper, lower, z)
def hyp2f1(a,b,c,z):
    """Hypergeometric function 2F1(a, b; c; z).

    Shorthand for hyper([a,b], [c], z); see hyper() for the accepted
    parameter types."""
    upper, lower = [a, b], [c]
    return hyper(upper, lower, z)
def funcwrapper(f):
    """Decorator: evaluate f(z) with 10 extra bits of precision.

    The argument is converted with convert_lossless, and f is called at
    mp.prec + 10. The caller's precision is restored even if f raises,
    and the result is rounded back to it with unary +. The wrapper
    copies f's __name__ and __doc__.
    """
    def g(z):
        orig = mp.prec
        try:
            z = convert_lossless(z)
            mp.prec = orig + 10
            v = f(z)
        finally:
            mp.prec = orig
        return +v
    g.__name__ = f.__name__
    g.__doc__ = f.__doc__
    return g
@extraprec(20, normalize_output=True)
def lower_gamma(a,z):
    """Lower incomplete gamma function gamma(a, z), computed from Kummer's
    series gamma(a, z) = 1F1(1; 1+a; z) * z**a * exp(-z) / a."""
    z = convert_lossless(z)
    if not isinstance(a, (int, long)):
        a = convert_lossless(a)
    # XXX: may need more precision
    series = hyp1f1(1, 1+a, z)
    return series * z**a * exp(-z) / a
@extraprec(20, normalize_output=True)
def upper_gamma(a,z):
    """Upper incomplete gamma function Gamma(a, z) = gamma(a) - gamma(a, z)."""
    complete = gamma(a)
    return complete - lower_gamma(a, z)
@funcwrapper
def erf(z):
    """Error function, erf(z) = 2*z/sqrt(pi) * 1F1(1/2; 3/2; -z**2)"""
    prefactor = 2/sqrt(pi)*z
    return prefactor * sum_hyp1f1_rat((1,2), (3,2), -z**2)
@funcwrapper
def ellipk(m):
    """Complete elliptic integral of the first kind, K(m). Note that
    the argument is the parameter m = k^2, not the modulus k.

    K(m) = pi/2 * 2F1(1/2, 1/2; 1; m), and K(1) = +inf."""
    if m == 1:
        return inf
    # NOTE(review): the underlying power series only converges for |m| < 1
    half = (1, 2)
    one = (1, 1)
    return pi/2 * sum_hyp2f1_rat(half, half, one, m)
@funcwrapper
def ellipe(m):
    """Complete elliptic integral of the second kind, E(m). Note that
    the argument is the parameter m = k^2, not the modulus k.

    E(m) = pi/2 * 2F1(1/2, -1/2; 1; m), and E(1) = 1."""
    if m == 1:
        return m
    # NOTE(review): the underlying power series only converges for |m| < 1
    half = (1, 2)
    minus_half = (-1, 2)
    one = (1, 1)
    return pi/2 * sum_hyp2f1_rat(half, minus_half, one, m)
# TODO: for complex a, b handle the branch cut correctly
@extraprec(15, normalize_output=True)
def agm(a, b):
    """Arithmetic-geometric mean of a and b.

    Iterates a, b = (a+b)/2, sqrt(a*b) until a and b agree to within
    16*eps relative to a. Returns a*b (i.e. zero) if either mean is zero.
    """
    a = convert_lossless(a)
    b = convert_lossless(b)
    if not a or not b:
        return a*b
    weps = eps * 16
    half = mpf(0.5)
    # Relative stopping test. The previous absolute test abs(a-b) > weps
    # stopped immediately, before convergence, when |a|, |b| << 1, and
    # might never stop when |a|, |b| >> 1.
    while abs(a-b) > weps*abs(a):
        a, b = (a+b)*half, (a*b)**half
        if not a or not b:
            # A mean hit zero (e.g. a == -b). The iteration then tends to
            # zero, and a relative test would never be satisfied.
            return a*b
    return a
def jacobi(n, a, b, x):
    """Jacobi polynomial P_n^(a,b)(x), evaluated as
    binomial(n+a, n) * 2F1(-n, 1+n+a+b; a+1; (1-x)/2) with 15 extra bits."""
    saved = mp.prec
    try:
        mp.prec = saved + 15
        x = convert_lossless(x)
        coeff = binomial(n+a, n)
        v = coeff * hyp2f1(-n, 1+n+a+b, a+1, (1-x)/2)
    finally:
        mp.prec = saved
    return +v
def legendre(n, x):
    """Legendre polynomial P_n(x).

    Evaluated as 2F1(-n, n+1; 1; (1-x)/2) with 15 extra bits. At x = -1
    the series argument is 1, so that point is special-cased:
    P_n(-1) = (-1)**n for integer n (P_n = P_{-n-1} covers n < 0), and
    inf for non-integer n.
    """
    orig = mp.prec
    try:
        mp.prec = orig + 15
        x = convert_lossless(x)
        if not isinstance(n, (int, long)):
            n = convert_lossless(n)
        if x == -1:
            # TODO: hyp2f1 should handle this
            # The integrality test is on the degree n. The old test
            # `x == int(x)` was always true here (x == -1), which made
            # the inf branch unreachable.
            if n == int(n):
                return (-1)**(n + (n>=0)) * mpf(-1)
            return inf
        v = hyp2f1(-n,n+1,1,(1-x)/2)
    finally:
        mp.prec = orig
    return +v
def chebyt(n, x):
    """Chebyshev polynomial of the first kind T_n(x).

    Uses T_n(x) = 2F1(-n, n; 1/2; (1-x)/2), evaluated with 15 guard bits.
    The result is rounded back to the caller's precision.
    """
    saved_prec = mp.prec
    try:
        mp.prec = saved_prec + 15
        x = convert_lossless(x)
        t_value = hyp2f1(-n, n, 0.5, (1-x)/2)
    finally:
        mp.prec = saved_prec
    return +t_value
def chebyu(n, x):
    """Chebyshev polynomial of the second kind U_n(x).

    Uses U_n(x) = (n+1) * 2F1(-n, n+2; 3/2; (1-x)/2), evaluated with 15
    guard bits. The result is rounded back to the caller's precision.
    """
    saved_prec = mp.prec
    try:
        mp.prec = saved_prec + 15
        x = convert_lossless(x)
        u_value = (n+1) * hyp2f1(-n, n+2, 1.5, (1-x)/2)
    finally:
        mp.prec = saved_prec
    return +u_value
# A Bessel function of the first kind of integer order, J_n(x), is
# given by the power series
#
#   J_n(x) = \sum_{k=0}^{\infty} \frac{(-1)^k}{k!\,(k+n)!} \left(\frac{x}{2}\right)^{2k+n}
#
1202 # Simplifying the quotient between two successive terms gives the
1203 # ratio x^2 / (-4*k*(k+n)). Hence, we only need one full-precision
1204 # multiplication and one division by a small integer per term.
1205 # The complex version is very similar, the only difference being
1206 # that the multiplication is actually 4 multiplies.
# In the general case (arbitrary order v), we have
#
#   J_v(x) = \frac{(x/2)^v}{v!} \; {}_0F_1\left(v+1; -\frac{x^2}{4}\right)
#
# where v! denotes \Gamma(v+1).
1211 # TODO: for extremely large x, we could use an asymptotic
1212 # trigonometric approximation.
1214 # TODO: recompute at higher precision if the fixed-point mantissa
1215 # is very small
def mpf_jn_series(n, x, prec):
    """Return J_n(x) as an mpf, for an integer order n and a raw mpf value x.

    The power series is summed in fixed point with 20 + bitcount(|n|)
    guard bits. The sum is rounded to nearest at `prec` bits. Negative
    orders use J_{-n}(x) = (-1)**n * J_n(x).
    """
    flip_sign = n < 0 and n & 1
    order = abs(n)
    wp = prec + 20 + bitcount(order)
    xf = to_fixed(x, wp)
    xsq = (xf**2) >> wp
    if order:
        # Leading term (x/2)**n / n!. x**n carries n*wp fraction bits.
        # Shifting by (n-1)*wp leaves wp bits; the extra n divides by 2**n.
        term = (xf**order // int_fac(order)) >> ((order-1)*wp + order)
    else:
        term = MP_ONE << wp
    total = term
    k = 1
    while term:
        # ratio of consecutive terms: x**2 / (-4*k*(k+n))
        term = ((term * xsq) // (-4*k*(k+order))) >> wp
        total += term
        k += 1
    if flip_sign:
        total = -total
    return make_mpf(from_man_exp(total, -wp, prec, round_nearest))
def mpc_jn_series(n, z, prec):
    """Return J_n(z) as an mpc, for an integer order n and z = (re, im) raw mpfs.

    This is the same fixed-point series as mpf_jn_series, except that each
    term is multiplied by the complex square z**2. Summation stops once
    |Re t| + |Im t| <= 3 in fixed-point units.
    """
    # NOTE(review): the stopping rule differs from mpf_jn_series, which
    # runs until the term is exactly zero.
    flip_sign = n < 0 and n & 1
    order = abs(n)
    wp = prec + 20 + bitcount(order)
    z_re, z_im = z
    re_fx = to_fixed(z_re, wp)
    im_fx = to_fixed(z_im, wp)
    # z**2 = (re**2 - im**2) + 2*re*im*i; shifting by wp-1 supplies the 2
    sq_re = (re_fx**2 - im_fx**2) >> wp
    sq_im = (re_fx*im_fx) >> (wp-1)
    if order:
        pow_re, pow_im = complex_int_pow(re_fx, im_fx, order)
        fac = int_fac(order)
        shift = (order-1)*wp + order
        term_re = (pow_re // fac) >> shift
        term_im = (pow_im // fac) >> shift
    else:
        term_re = MP_ONE << wp
        term_im = MP_ZERO
    sum_re = term_re
    sum_im = term_im
    k = 1
    while abs(term_re) + abs(term_im) > 3:
        divisor = -4*k*(k+order)
        # complex product t * z**2
        term_re, term_im = (term_re*sq_re - term_im*sq_im,
                            term_im*sq_re + term_re*sq_im)
        term_re = (term_re // divisor) >> wp
        term_im = (term_im // divisor) >> wp
        sum_re += term_re
        sum_im += term_im
        k += 1
    if flip_sign:
        sum_re = -sum_re
        sum_im = -sum_im
    out_re = from_man_exp(sum_re, -wp, prec, round_nearest)
    out_im = from_man_exp(sum_im, -wp, prec, round_nearest)
    return make_mpc((out_re, out_im))
def jv(v, x):
    """Bessel function J_v(x).

    For an integer order with a real or complex argument, the fixed-point
    series helpers are used. Otherwise the function falls back to
    J_v(x) = (x/2)**v / v! * 0F1(v+1; -(x/2)**2).
    """
    wp = mp.prec
    x = convert_lossless(x)
    if isinstance(v, int_types):
        if isinstance(x, mpf):
            return mpf_jn_series(v, x._mpf_, wp)
        elif isinstance(x, mpc):
            parts = (x.real._mpf_, x.imag._mpf_)
            return mpc_jn_series(v, parts, wp)
    h = x/2
    return h**v * hyp0f1(v+1, -h**2) / factorial(v)
# Integer-order name for the general Bessel function of the first kind.
jn = jv

def j0(x):
    """Bessel function J_0(x), i.e. jv with order 0."""
    order = 0
    return jv(order, x)

def j1(x):
    """Bessel function J_1(x), i.e. jv with order 1."""
    order = 1
    return jv(order, x)
1292 #---------------------------------------------------------------------------#
1294 # Miscellaneous #
1296 #---------------------------------------------------------------------------#
def log_range():
    """Generate log(2), log(3), log(4), ...

    Each step uses log(p+1) = log(p) + 2*atanh(1/(2p+1)). The atanh series
    is summed in fixed point with 20 guard bits over the working precision
    read when the generator first starts running.
    """
    wp = mp.prec + 20
    unit = 1 << wp
    log_p = log2_fixed(wp)
    p = 2
    while 1:
        yield mpf((log_p, -wp))
        m = 2*p + 1
        m2 = m**2
        # acc = m*atanh(1/m) = sum over odd k of 1/(k * m**(k-1))
        acc = 0
        power = unit
        k = 1
        while power:
            acc += power // k
            power //= m2
            k += 2
        log_p += 2*acc//m
        p += 1
@extraprec(30, normalize_output=True)
def lambertw(z, k=0, approx=None):
    """
    lambertw(z,k) gives the kth branch of the Lambert W function W(z),
    defined as the kth solution of z = W(z)*exp(W(z)).

    lambertw(z) == lambertw(z, k=0) gives the principal branch
    value (0th branch solution), which is real for z > -1/e .

    The k = -1 branch is real for -1/e < z < 0. All branches except
    k = 0 have a logarithmic singularity at 0.

    The definition, implementation and choice of branches is based
    on Corless et al, "On the Lambert W function", Adv. Comp. Math. 5
    (1996) 329-359, available online here:
    http://www.apmaths.uwo.ca/~djeffrey/Offprints/W-adv-cm.pdf

    The root is refined with at most 100 Halley steps. If these do not
    converge, a RuntimeWarning is issued and the last iterate is
    returned. The approx argument is currently unused.

    TODO: use a series expansion when extremely close to the branch point
    at -1/e and make sure that the proper branch is chosen there
    """
    import warnings
    z = convert_lossless(z)
    if isnan(z):
        return z
    # We must be extremely careful near the singularities at -1/e and 0
    u = exp(-1)
    if abs(z) <= u:
        if not z:
            # w(0,0) = 0; for all other branches we hit the pole
            if not k:
                return z
            return -inf
        if not k:
            w = z
        # For small real z < 0, the -1 branch behaves roughly like log(-z)
        elif k == -1 and not z.imag and z.real < 0:
            w = log(-z)
        # Use a simple asymptotic approximation.
        else:
            w = log(z)
            # The branches are roughly logarithmic. This approximation
            # gets better for large |k|; need to check that this always
            # works for k ~= -1, 0, 1.
            if k: w += k * 2*pi*j
    else:
        if z == inf: return z
        if z == -inf: return nan
        # Simple asymptotic approximation as above
        w = log(z)
        if k: w += k * 2*pi*j
    # Use Halley iteration to solve w*exp(w) = z
    two = mpf(2)
    weps = ldexp(eps, 15)
    for i in xrange(100):
        ew = exp(w)
        wew = w*ew
        wewz = wew-z
        wn = w - wewz/(wew+ew-(w+two)*wewz/(two*w+two))
        if abs(wn-w) < weps*abs(wn):
            return wn
        w = wn
    # Report through the warnings machinery instead of printing to stdout,
    # so callers can filter, log or escalate the failure.
    warnings.warn("Lambert W iteration failed to converge: " + str(z),
                  RuntimeWarning)
    return wn