added pi digits example script
[sympy.git] / sympy / modules / numerics / float_.py
blobffe7837ce7cae6d8d812cd9ce004455b9a48525a
1 """
2 This module implements a class Float for arbitrary-precision binary
3 floating-point arithmetic. It is typically 10-100 times faster
4 than Python's standard Decimals. For details on usage, refer to the
5 docstrings in the Float class.
6 """
8 import math
9 _clog = math.log
10 _csqrt = math.sqrt
12 from sympy import Rational
13 from utils_ import bitcount, trailing_zeros
15 #----------------------------------------------------------------------
16 # Rounding modes
19 class _RoundingMode(int):
20 def __new__(cls, level, name):
21 a = int.__new__(cls, level)
22 a.name = name
23 return a
24 def __repr__(self):
25 return self.name
27 ROUND_DOWN = _RoundingMode(1, 'ROUND_DOWN')
28 ROUND_UP = _RoundingMode(2, 'ROUND_UP')
29 ROUND_FLOOR = _RoundingMode(3, 'ROUND_FLOOR')
30 ROUND_CEILING = _RoundingMode(4, 'ROUND_CEILING')
31 ROUND_HALF_UP = _RoundingMode(5, 'ROUND_HALF_UP')
32 ROUND_HALF_DOWN = _RoundingMode(6, 'ROUND_HALF_DOWN')
33 ROUND_HALF_EVEN = _RoundingMode(7, 'ROUND_HALF_EVEN')
#----------------------------------------------------------------------
# Helper functions for bit manipulation
#

def rshift(x, n, mode):
    """
    Return x / 2**n rounded to an integer according to `mode`.

    A negative n shifts x left by -n bits instead; that is exact, so no
    rounding happens. When n or x is 0, x is returned unchanged.
    """
    if not (n and x):
        return x
    if n < 0:
        return x << -n

    # Python's >> floors: that truncates positive x and rounds negative x
    # away from zero, so each directed mode is a single shift of x
    # (floor) or of -x (ceiling).
    if mode < ROUND_HALF_UP:
        positive = x > 0
        if mode == ROUND_FLOOR:
            use_floor = True
        elif mode == ROUND_CEILING:
            use_floor = False
        elif mode == ROUND_DOWN:
            use_floor = positive
        elif mode == ROUND_UP:
            use_floor = not positive
        else:
            use_floor = None
        if use_floor is not None:
            if use_floor:
                return x >> n
            return -((-x) >> n)

    # Round to nearest: t holds |x| shifted so that its lowest bit is the
    # first discarded bit (the "half" bit).
    magnitude = x if x > 0 else -x
    t = magnitude >> (n - 1)
    quotient = t >> 1
    if t & 1:
        # x and -x share the same trailing zero bits, so x itself tells
        # whether everything below the half bit is zero (an exact tie).
        exact_half = not (x & ((1 << (n - 1)) - 1))
        if mode == ROUND_HALF_UP or \
           (mode == ROUND_HALF_DOWN and not exact_half) or \
           (mode == ROUND_HALF_EVEN and (t & 2 or not exact_half)):
            quotient += 1
    if x > 0:
        return quotient
    return -quotient
def normalize(man, exp, prec, mode):
    """
    Round the binary number man * 2**exp to at most prec bits.

    Rounding follows `mode` (it is delegated to rshift). Trailing zero
    bits are then moved from the mantissa into the exponent, and zero
    always comes back as (0, 0). Returns the new (man, exp) pair.
    """
    if man == 0:
        return man, 0
    excess = bitcount(man) - prec
    if excess > 0:
        man = rshift(man, excess, mode)
        exp += excess
    # Stripping trailing zeros is not needed for correctness, but it gives
    # every value one canonical form, so Floats with the same exponent
    # can be compared by mantissa alone.
    zeros = trailing_zeros(man)
    if zeros:
        man >>= zeros
        exp += zeros
    if man == 0:
        exp = 0
    return man, exp
#----------------------------------------------------------------------
# Other helper functions
#

def binary_to_decimal(man, exp, n):
    """Represent man * 2**exp as a decimal string with at most n digits.

    The value is computed in a copy of the current decimal context with
    its precision set to n. The caller's context (precision and flags) is
    never modified, even if the computation raises.
    """
    import decimal
    # Work in a private copy instead of temporarily changing the global
    # context: the old code left the precision changed if an exception
    # occurred, and always set Inexact/Rounded flags on the caller's
    # context.
    ctx = decimal.getcontext().copy()
    ctx.prec = n
    if exp >= 0:
        d = ctx.multiply(decimal.Decimal(man), decimal.Decimal(1 << exp))
    else:
        d = ctx.divide(decimal.Decimal(man), decimal.Decimal(1 << -exp))
    return str(d)
118 _convratio = _clog(10,2) # 3.3219...
#---------------------------------------------------------------------------#
#                              Float class                                  #
#---------------------------------------------------------------------------#

class Float(object):
    """
    A Float is a rational number of the form

        man * 2**exp

    ("man" and "exp" are short for "mantissa" and "exponent"). Both man
    and exp are integers, possibly negative, and may be arbitrarily large.
    Essentially, a larger mantissa corresponds to a higher precision
    and a larger exponent corresponds to larger magnitude.

    When performing an arithmetic operation on two Floats, or creating a
    new Float from an existing numerical value, the result gets rounded
    to a fixed precision level, just like with ordinary Python floats.
    Unlike regular Python floats, however, the precision level can be
    set arbitrarily high. You can also change the rounding mode (the
    modes listed below have the same names and meanings as the
    corresponding rounding modes of the decimal module).

    The precision level and rounding mode are stored as properties of
    the Float class. (An individual Float instance does not have any
    precision or rounding data associated with it.) The precision level
    and rounding mode make up the current working context. You can
    change the working context through static methods of the Float
    class:

        Float.setprec(n)    -- set precision to n bits
        Float.extraprec(n)  -- increase precision by n bits
        Float.setdps(n)     -- set precision equivalent to n decimals
        Float.setmode(mode) -- set rounding mode

    Corresponding methods are available for inspection:

        Float.getprec()
        Float.getdps()
        Float.getmode()

    There are also two methods Float.store() and Float.revert(). If
    you call Float.store() before changing precision or mode, the
    old context can be restored with Float.revert(). (If Float.revert()
    is called one time too many, the default settings are restored.)
    You can nest multiple uses of store() and revert().

    (In the future, it will also be possible to use the 'with'
    statement to change contexts.)

    Note that precision is measured in bits. Since the ratio between
    binary and decimal precision levels is irrational, setprec and
    setdps work slightly differently. When you set the precision with
    setdps, the bit precision is set slightly higher than the exact
    corresponding precision to account for the fact that decimal
    numbers cannot generally be represented exactly in binary (the
    classical example is 0.1). The exact number given to setdps
    is however used by __str__ to determine the number of digits to
    display. Likewise, when you set a bit precision, the decimal
    printing precision used for __str__ is set slightly lower.

    The following rounding modes are available:

        ROUND_DOWN      -- toward zero
        ROUND_UP        -- away from zero
        ROUND_FLOOR     -- towards -oo
        ROUND_CEILING   -- towards +oo
        ROUND_HALF_UP   -- to nearest; 0.5 to 1
        ROUND_HALF_DOWN -- to nearest; 0.5 to 0
        ROUND_HALF_EVEN -- to nearest; 0.5 to 0 and 1.5 to 2

    The rounding modes are available both as global constants defined
    in this module and as properties of the Float class, e.g.
    Float.ROUND_CEILING.

    The default precision level is 53 bits and the default rounding
    mode is ROUND_HALF_EVEN. In this mode, Floats should round exactly
    like regular Python floats (in the absence of bugs!).
    """

    #------------------------------------------------------------------
    # Static methods for context management
    #

    # Also make these constants available from the class
    ROUND_DOWN = ROUND_DOWN
    ROUND_UP = ROUND_UP
    ROUND_FLOOR = ROUND_FLOOR
    ROUND_CEILING = ROUND_CEILING
    ROUND_HALF_UP = ROUND_HALF_UP
    ROUND_HALF_DOWN = ROUND_HALF_DOWN
    ROUND_HALF_EVEN = ROUND_HALF_EVEN

    # Working context shared by all Floats: precision in bits, digits
    # printed by __str__, rounding mode, and the contexts saved by store()
    _prec = 53
    _dps = 15
    _mode = ROUND_HALF_EVEN
    _stack = []

    @staticmethod
    def store():
        """Save the current precision/rounding context so that it can be
        restored by calling Float.revert()."""
        Float._stack.append((Float._prec, Float._dps, Float._mode))

    @staticmethod
    def revert():
        """Restore the context most recently saved with Float.store().
        If nothing is saved, the defaults (53 bits, 15 digits,
        ROUND_HALF_EVEN) are restored instead."""
        if Float._stack:
            Float._prec, Float._dps, Float._mode = Float._stack.pop()
        else:
            Float._prec, Float._dps, Float._mode = 53, 15, ROUND_HALF_EVEN

    @staticmethod
    def setprec(n):
        """Set the precision to n bits.

        The printing precision (getdps) is set slightly lower than n bits
        correspond to, but never below one digit, because the decimal
        module rejects a precision smaller than 1 when printing."""
        n = int(n)
        Float._prec = n
        Float._dps = max(1, int(round(n/_convratio)-1))

    @staticmethod
    def setdps(n):
        """Set the precision high enough to allow representing numbers
        with at least n decimal places without loss. n also becomes the
        number of digits printed by str()."""
        n = int(n)
        Float._prec = int(round((n+1)*_convratio))
        Float._dps = n

    @staticmethod
    def extraprec(n):
        """Increase the precision by n bits (through setprec, so the
        printing precision is updated as well)."""
        Float.setprec(Float._prec + n)

    @staticmethod
    def setmode(mode):
        """Set the rounding mode to one of the ROUND_* constants.

        Raises TypeError if mode is not one of those constants."""
        # Explicit check rather than assert, which is removed under -O
        if not isinstance(mode, _RoundingMode):
            raise TypeError("rounding mode must be one of the ROUND_* "
                            "constants, not %r" % (mode,))
        Float._mode = mode

    @staticmethod
    def getprec():
        """Return the working precision in bits."""
        return Float._prec

    @staticmethod
    def getdps():
        """Return the number of decimal digits used by __str__."""
        return Float._dps

    @staticmethod
    def getmode():
        """Return the current rounding mode."""
        return Float._mode

    #------------------------------------------------------------------
    # Core object functionality
    #

    __slots__ = ["man", "exp"]

    def __init__(s, x=0, prec=None, mode=None):
        """
        Float(x) creates a new Float instance with value x. The usual
        types are supported for x:

            >>> Float(3)
            Float('3')
            >>> Float(3.5)
            Float('3.5')
            >>> Float('3.5')
            Float('3.5')
            >>> Float(Rational(7,2))
            Float('3.5')

        You can also create a Float from a tuple specifying its
        mantissa and exponent:

            >>> Float((5, -3))
            Float('0.625')

        Use the prec and mode arguments to specify a custom precision
        level (in bits) and rounding mode. If these arguments are
        omitted, the current working precision is used instead.

            >>> Float('0.500001', prec=3, mode=ROUND_DOWN)
            Float('0.5')
            >>> Float('0.500001', prec=3, mode=ROUND_UP)
            Float('0.625')

        Any other type of x raises TypeError.
        """
        prec = prec or s._prec
        mode = mode or s._mode
        if isinstance(x, tuple):
            s.man, s.exp = normalize(x[0], x[1], prec, mode)
        elif isinstance(x, Float):
            s.man, s.exp = normalize(x.man, x.exp, prec, mode)
        elif isinstance(x, (int, long)):
            s.man, s.exp = normalize(x, 0, prec, mode)
        elif isinstance(x, float):
            # frexp gives x == m * 2**e with 0.5 <= |m| < 1, so m*2**53 is
            # an integer holding all 53 bits of the double
            m, e = math.frexp(x)
            s.man, s.exp = normalize(int(m*2**53), e-53, prec, mode)
        elif isinstance(x, (str, Rational)):
            if isinstance(x, str):
                x = Rational(x)
            # Compute the floored quotient p/q with extra guard bits
            # (beyond prec and the size of q) before rounding to prec
            n = prec + bitcount(x.q) + 2
            s.man, s.exp = normalize((x.p<<n)//x.q, -n, prec, mode)
        else:
            raise TypeError("cannot create a Float from %r" % (x,))

    def __pos__(s):
        """s.__pos__() <==> +s

        Normalize s to the current working precision, rounding according
        to the current rounding mode."""
        return Float(s)

    def __repr__(s):
        """Represent s as a decimal string, with sufficiently many
        digits included to ensure that Float(repr(s)) == s at the
        current working precision."""
        st = "Float('%s')"
        return st % binary_to_decimal(s.man, s.exp, Float._dps + 2)

    def __str__(s):
        """Print slightly more prettily than __repr__"""
        return binary_to_decimal(s.man, s.exp, Float._dps)

    def __float__(s):
        """Convert s to a Python float. OverflowError will be raised
        if the magnitude of s is too large."""
        try:
            return math.ldexp(s.man, s.exp)
        except OverflowError:
            # ldexp also overflows when the mantissa alone is too wide to
            # convert to a float; retry with the mantissa cut to 64 bits.
            n = bitcount(s.man) - 64
            if n <= 0:
                # The mantissa was narrow, so the exponent itself is out
                # of range. (Shifting by n <= 0 would raise ValueError.)
                raise OverflowError("Float too large to convert to float")
            return math.ldexp(s.man >> n, s.exp + n)

    #------------------------------------------------------------------
    # Comparison
    #

    def __cmp__(s, t):
        """__cmp__(s, t) <==> cmp(s, t)

        Returns -1 if s < t, 0 if s == t, and 1 if s > t. A non-Float t
        is first converted with Float(t), i.e. rounded to the working
        precision."""
        if not isinstance(t, Float):
            t = Float(t)
        sm, se, tm, te = s.man, s.exp, t.man, t.exp
        if tm == 0: return cmp(sm, 0)
        if sm == 0: return cmp(0, tm)
        if sm > 0 and tm < 0: return 1
        if sm < 0 and tm > 0: return -1
        if se == te: return cmp(sm, tm)
        # Same sign: compare magnitudes by the position of the top bit
        # (bitcount presumably returns the bit length of |man|)
        a = bitcount(sm) + se
        b = bitcount(tm) + te
        if a != b:
            # For negative values the larger magnitude is the smaller value
            # (the old code tested `a < b` twice and missed the a > b case)
            if (a < b) == (sm > 0):
                return -1
            return 1
        return cmp((s-t).man, 0)

    def ae(s, t, rel_eps=None, abs_eps=None):
        """
        "ae" is short for "almost equal"

        Determine whether the difference between s and t is smaller
        than a given epsilon.

        Both a maximum relative difference and a maximum difference
        (or 'epsilons') may be specified. The absolute difference is
        defined as |s-t| and the relative difference is defined
        as |s-t|/max(|s|, |t|).

        If only one epsilon is given, both are set to the same value.
        If none is given, both epsilons are set to 2**(4-prec), where
        prec is the current working precision.
        """
        if not isinstance(t, Float):
            t = Float(t)
        if abs_eps is None and rel_eps is None:
            rel_eps = Float((1, -s._prec+4))
        if abs_eps is None:
            abs_eps = rel_eps
        elif rel_eps is None:
            rel_eps = abs_eps
        diff = abs(s-t)
        if diff <= abs_eps:
            return True
        abss = abs(s)
        abst = abs(t)
        # Divide by the larger magnitude. Dividing by the signed value made
        # err negative for negative inputs, so every such pair passed.
        if abss < abst:
            err = diff/abst
        else:
            err = diff/abss
        return err <= rel_eps

    def almost_zero(s, prec):
        """Quick magnitude check using only bit positions.

        Returns True when bitcount(s.man) + s.exp < prec. Assuming
        bitcount gives the bit length of |man|, for nonzero s this is
        equivalent to |s| < 2**(prec-1).
        """
        # NOTE(review): this was documented as "|s| < 2**-prec", which
        # only matches if callers pass a negated precision (e.g. -prec).
        # Confirm against the call sites before changing the comparison.
        return bitcount(s.man) + s.exp < prec

    def __nonzero__(s):
        """True unless s is zero."""
        return bool(s.man)

    #------------------------------------------------------------------
    # Arithmetic
    #

    def __abs__(s):
        """|s|: s itself when nonnegative, otherwise -s (which is
        rounded to the working precision)."""
        if s.man < 0:
            return -s
        return s

    def __add__(s, t):
        """s + t for a Float or integer t, rounded once to the working
        precision. Other types of t return NotImplemented."""
        if isinstance(t, Float):
            tman, texp = t.man, t.exp
        elif isinstance(t, (int, long)):
            # Use the integer exactly, as the pair (t, 0). Converting it
            # with Float(t) first would round it to the working precision
            # and then round again after the addition.
            tman, texp = t, 0
        else:
            return NotImplemented
        sman, sexp = s.man, s.exp
        if texp > sexp:
            sman, sexp, tman, texp = tman, texp, sman, sexp
        # Align on the smaller exponent so the sum is exact before rounding
        return Float((tman + (sman << (sexp - texp)), texp))

    __radd__ = __add__

    def __neg__(s):
        """-s, rounded to the working precision."""
        return Float((-s.man, s.exp))

    def __sub__(s, t):
        """s - t, computed as s + (-t)."""
        return s + (-t)

    def __rsub__(s, t):
        """t - s, computed as (-s) + t."""
        return (-s) + t

    def __mul__(s, t):
        """s * t for a Float or integer t, rounded to the working
        precision. Other types of t return NotImplemented."""
        if isinstance(t, Float):
            return Float((s.man*t.man, s.exp+t.exp))
        if isinstance(t, (int, long)):
            # Previously a bare `return`, so Float * int gave None
            return Float((s.man*t, s.exp))
        return NotImplemented

    __rmul__ = __mul__

    def __div__(s, t):
        """s / t for a Float or integer t, rounded to the working
        precision. Raises ZeroDivisionError if t is zero; other types of
        t return NotImplemented."""
        if isinstance(t, Float):
            tman, texp = t.man, t.exp
        elif isinstance(t, (int, long)):
            tman, texp = t, 0
        else:
            return NotImplemented
        if tman == 0:
            raise ZeroDivisionError
        # Pre-shift the dividend so the integer quotient keeps roughly
        # prec+4 bits. Clamp at 0: a mantissa wider than that (e.g. one
        # created before the precision was lowered) would otherwise
        # produce a negative shift count.
        extra = s._prec - bitcount(s.man) + bitcount(tman) + 4
        if extra < 0:
            extra = 0
        return Float(((s.man<<extra)//tman, s.exp-texp-extra))

    __truediv__ = __div__

    def __pow__(s, n):
        """Calculate (man*2**exp)**n for integral n. n == 0.5 is
        delegated to s.sqrt(); other exponents raise ValueError."""
        if isinstance(n, (int, long)):
            if n == 0: return Float((1, 0))
            if n == 1: return +s
            if n == 2: return s * s
            # Divide by an explicit Float 1 instead of relying on __rdiv__
            if n == -1: return Float((1, 0)) / s
            if n < 0:
                # Compute the reciprocal with 2 guard bits. The precision
                # is restored even if the division raises (s == 0).
                Float._prec += 2
                try:
                    r = Float((1, 0)) / (s ** -n)
                finally:
                    Float._prec -= 2
                return +r
            # Binary exponentiation on raw (man, exp) pairs, carrying
            # guard bits that grow with log2(n) and rounding toward -oo
            prec2 = Float._prec + int(4*_clog(n, 2) + 4)
            man, exp = normalize(s.man, s.exp, prec2, ROUND_FLOOR)
            pm, pe = 1, 0
            while n:
                if n & 1:
                    pm, pe = normalize(pm*man, pe+exp, prec2, ROUND_FLOOR)
                    n -= 1
                man, exp = normalize(man*man, exp+exp, prec2, ROUND_FLOOR)
                n = n // 2
            return Float((pm, pe))
        if n == 0.5:
            return s.sqrt()
        raise ValueError("unsupported exponent for Float: %r" % (n,))