little refactoring
[sympyx.git] / sympy.py
blobe6f048af45f1a0239ffff2cf88a3b296172d5fd1
# Integer tags stored in ``Basic.type``; they identify the kind of each node.
BASIC, SYMBOL, ADD, MUL, POW, INTEGER = range(6)
def hash_seq(args):
    """
    Return a hash of the sequence *args* that depends on element order.

    The running value starts at 2 and, for every element, is replaced by
    ``hash((previous + 1001) ^ hash(element))``.
    """
    # TODO: make this more robust (weak mixing, easy collisions).
    h = 2
    for item in args:
        # ``+`` binds tighter than ``^``; the parentheses only make that explicit.
        h = hash((h + 1001) ^ hash(item))
    return h
def compare_lists(a, b):
    """
    Return True when sequences *a* and *b* contain the same elements.

    The comparison is done on the *sets* of elements: order is ignored and
    so is multiplicity (``[x, x, y]`` compares equal to ``[y, x]``).
    Elements must therefore be hashable.
    """
    return frozenset(a) == frozenset(b)
class Basic(object):
    """Base class of all symbolic expressions.

    Every instance carries an integer ``type`` tag (one of the module-level
    constants BASIC, SYMBOL, ADD, MUL, POW, INTEGER) and an immutable tuple
    of sub-expressions exposed as ``args``.  The arithmetic operators build
    ``Add``, ``Mul`` and ``Pow`` nodes, whose constructors canonicalize by
    default.
    """

    def __new__(cls, type, args):
        obj = object.__new__(cls)
        obj.type = type
        obj._args = tuple(args)
        return obj

    def __str__(self):
        # Generic fallback; subclasses override it.  Without it, __repr__
        # (which calls str()) and object.__str__ (which calls repr()) would
        # recurse forever on a node type that defines no __str__.
        return "%s(%s)" % (type(self).__name__,
                           ", ".join(str(a) for a in self.args))

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash_seq(self.args)

    @property
    def args(self):
        """The sub-expressions of this node, as a tuple."""
        return self._args

    def as_coeff_rest(self):
        """Split into ``(integer coefficient, rest)``; Add uses it to merge terms."""
        return (Integer(1), self)

    def as_base_exp(self):
        """Split into ``(base, exponent)``; Mul uses it to merge factors."""
        return (self, Integer(1))

    def expand(self):
        """Return the expanded form; the generic node is returned unchanged."""
        return self

    def __add__(x, y):
        return Add((x, y))

    def __radd__(x, y):
        # Addition is canonicalized order-independently, so y + x == x + y.
        return x.__add__(y)

    def __sub__(x, y):
        return Add((x, -y))

    def __rsub__(x, y):
        return Add((y, -x))

    def __mul__(x, y):
        return Mul((x, y))

    def __rmul__(x, y):
        return Mul((y, x))

    def __div__(x, y):
        return Mul((x, Pow((y, Integer(-1)))))

    def __rdiv__(x, y):
        return Mul((y, Pow((x, Integer(-1)))))

    # Python 3 (and Python 2 with ``from __future__ import division``)
    # dispatches ``/`` to these names instead of __div__/__rdiv__.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(x, y):
        return Pow((x, y))

    def __rpow__(x, y):
        return Pow((y, x))

    def __neg__(x):
        return Mul((Integer(-1), x))

    def __pos__(x):
        return x

    def __ne__(self, x):
        return not self.__eq__(x)

    def __eq__(self, o):
        """Structural equality: same ``type`` tag and equal ``args``.

        Plain ints are sympified first; any other non-symbolic object
        (``None``, strings, ...) compares unequal instead of raising.
        """
        o = sympify(o)
        if not isinstance(o, Basic):
            return False
        return o.type == self.type and self.args == o.args
class Integer(Basic):
    """An exact integer atom wrapping the Python integer *i* (stored as ``.i``)."""

    def __new__(cls, i):
        obj = Basic.__new__(cls, INTEGER, [])
        obj.i = i
        return obj

    def __hash__(self):
        # Same hash as the wrapped int, consistent with __eq__ treating
        # Integer(n) == n as true.
        return hash(self.i)

    def __eq__(self, o):
        """Equal to another Integer (or plain int) with the same value.

        Non-symbolic objects compare unequal instead of raising AttributeError.
        """
        o = sympify(o)
        if not isinstance(o, Basic):
            return False
        return o.type == INTEGER and self.i == o.i

    def __str__(self):
        return str(self.i)

    def __add__(self, o):
        # Two integers are added numerically; anything else builds an Add.
        o = sympify(o)
        if o.type == INTEGER:
            return Integer(self.i+o.i)
        return Basic.__add__(self, o)

    def __mul__(self, o):
        # Two integers are multiplied numerically; anything else builds a Mul.
        o = sympify(o)
        if o.type == INTEGER:
            return Integer(self.i*o.i)
        return Basic.__mul__(self, o)
class Symbol(Basic):
    """A named symbolic atom; two symbols are equal iff their names are equal."""

    def __new__(cls, name):
        obj = Basic.__new__(cls, SYMBOL, [])
        obj.name = name
        return obj

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, o):
        o = sympify(o)
        if not isinstance(o, Basic):
            # e.g. x == None or x == "x": not a symbolic object, so not equal
            # (previously this raised AttributeError on ``o.type``).
            return False
        return o.type == SYMBOL and self.name == o.name

    def __str__(self):
        return self.name
class Add(Basic):
    """A canonical sum of terms.

    Canonical form (built by ``canonicalize``): one level of nested sums is
    flattened, integer terms are folded into a single leading constant
    (omitted when zero), and terms ``coeff*key`` with the same ``key`` are
    merged by adding their integer coefficients.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            # Trusted fast path: the caller guarantees args are canonical.
            return Basic.__new__(cls, ADD, args)
        args = [sympify(x) for x in args]
        return Add.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Return the canonical form of the sum of the (sympified) *args*.

        Terms whose coefficients cancel (e.g. ``x + y - x``) are dropped
        rather than kept as explicit ``0`` terms.  The result is the integer
        constant when no other term survives, the lone surviving term when
        there is exactly one, and an ``Add`` otherwise.
        """
        d = {}
        num = Integer(0)
        for a in args:
            # A nested sum contributes its own terms.
            terms = a.args if a.type == ADD else (a,)
            for b in terms:
                if b.type == INTEGER:
                    num += b
                    continue
                coeff, key = b.as_coeff_rest()
                if key in d:
                    d[key] += coeff
                else:
                    d[key] = coeff
        args = []
        for key, coeff in d.items():
            term = Mul((key, coeff))
            if term.type == INTEGER:
                # e.g. the coefficient summed to 0 and Mul returned Integer(0)
                num += term
            else:
                args.append(term)
        if not args:
            return num
        if num.i != 0:
            args.insert(0, num)
        if len(args) == 1:
            return args[0]
        return Add(args, False)

    def __eq__(self, o):
        """Order-independent comparison of the terms (see compare_lists)."""
        o = sympify(o)
        if not isinstance(o, Basic):
            return False
        if o.type == ADD:
            return compare_lists(self.args, o.args)
        return False

    def __str__(self):
        parts = []
        for x in self.args:
            # Only a nested (non-canonical) sum needs its own parentheses.
            if x.type == ADD:
                parts.append("(%s)" % str(x))
            else:
                parts.append(str(x))
        return " + ".join(parts)

    def __hash__(self):
        # Must not depend on term order, to stay consistent with __eq__:
        # hash the terms in a fixed (hash-sorted) order.
        return hash_seq(sorted(self.args, key=hash))

    def expand(self):
        """Return the sum of the expanded terms."""
        r = Integer(0)
        for term in self.args:
            r += term.expand()
        return r
class Mul(Basic):
    """A canonical product of factors.

    Canonical form (built by ``canonicalize``): one level of nested products
    is flattened, integer factors are folded into a single leading
    coefficient (omitted when it is 1), and factors with the same base are
    merged by adding their exponents.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            # Trusted fast path: the caller guarantees args are canonical.
            return Basic.__new__(cls, MUL, args)
        args = [sympify(x) for x in args]
        return Mul.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Return the canonical form of the product of the (sympified) *args*.

        Factors whose exponents cancel (e.g. ``x*y/x``) are dropped rather
        than kept as explicit ``1`` factors.  The result is the integer
        coefficient when it is 0 or no other factor survives, the lone
        surviving factor when there is exactly one and the coefficient is 1,
        and a ``Mul`` otherwise.
        """
        d = {}
        num = Integer(1)
        for a in args:
            # A nested product contributes its own factors.
            factors = a.args if a.type == MUL else (a,)
            for b in factors:
                if b.type == INTEGER:
                    num *= b
                    continue
                key, coeff = b.as_base_exp()
                if key in d:
                    d[key] += coeff
                else:
                    d[key] = coeff
        if num.i == 0:
            return num
        args = []
        for base, exp in d.items():
            factor = Pow((base, exp))
            if factor.type == INTEGER:
                # e.g. the exponents summed to 0 and Pow returned Integer(1)
                num *= factor
            else:
                args.append(factor)
        if num.i == 0 or not args:
            return num
        if num.i != 1:
            args.insert(0, num)
        if len(args) == 1:
            return args[0]
        return Mul(args, False)

    def __hash__(self):
        # Must not depend on factor order, to stay consistent with __eq__:
        # hash the factors in a fixed (hash-sorted) order.
        return hash_seq(sorted(self.args, key=hash))

    def __eq__(self, o):
        """Order-independent comparison of the factors (see compare_lists)."""
        o = sympify(o)
        if not isinstance(o, Basic):
            return False
        if o.type == MUL:
            return compare_lists(self.args, o.args)
        return False

    def as_coeff_rest(self):
        """Split off a leading Integer coefficient, if there is one."""
        if self.args[0].type == INTEGER:
            return self.as_two_terms()
        return (Integer(1), self)

    def as_two_terms(self):
        """Return ``(first factor, product of the remaining factors)``."""
        return (self.args[0], Mul(self.args[1:]))

    def __str__(self):
        s = str(self.args[0])
        if self.args[0].type in [ADD, MUL]:
            s = "(%s)" % str(s)
        for x in self.args[1:]:
            if x.type in [ADD, MUL]:
                s = "%s * (%s)" % (s, str(x))
            else:
                s = "%s*%s" % (s, str(x))
        return s

    @classmethod
    def expand_two(cls, a, b):
        """
        Return the expanded product ``a*b``, distributing over sums.

        Both a and b are assumed to be expanded.
        """
        if a.type == ADD and b.type == ADD:
            r = Integer(0)
            for x in a.args:
                for y in b.args:
                    r += x*y
            return r
        if a.type == ADD:
            r = Integer(0)
            for x in a.args:
                r += x*b
            return r
        if b.type == ADD:
            r = Integer(0)
            for y in b.args:
                r += a*y
            return r
        return a*b

    def expand(self):
        """Expand the first factor and the rest, then distribute the product."""
        a, b = self.as_two_terms()
        a = a.expand()
        b = b.expand()
        return Mul.expand_two(a, b)
class Pow(Basic):
    """A power ``base**exp``, stored as ``args == (base, exp)``."""

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            # Trusted fast path: the caller guarantees args are canonical.
            return Basic.__new__(cls, POW, args)
        args = [sympify(x) for x in args]
        return Pow.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Return the canonical form of ``base**exp`` for ``args = (base, exp)``.

        Simplifications, in this order:
        ``0**e -> 0``, ``1**e -> 1``, ``b**0 -> 1``, ``b**1 -> b``,
        ``k**n -> Integer`` for integers ``k`` and ``n >= 2``,
        ``(b**e1)**e2 -> b**(e1*e2)``.
        """
        base, exp = args
        if base.type == INTEGER:
            if base.i == 0:
                # NOTE(review): this also maps 0**0 to 0 (Python gives 1) and
                # 0**-n to 0; kept as-is -- confirm whether intended.
                return Integer(0)
            if base.i == 1:
                return Integer(1)
        if exp.type == INTEGER:
            if exp.i == 0:
                return Integer(1)
            if exp.i == 1:
                return base
            if base.type == INTEGER and exp.i > 1:
                # Exact integer power.  Negative exponents would need
                # rationals, so those stay as unevaluated powers.
                return Integer(base.i ** exp.i)
        if base.type == POW:
            return Pow((base.args[0], base.args[1]*exp))
        return Pow(args, False)

    def __str__(self):
        base, exp = self.args
        s = str(base)
        # Parenthesize compound or negative bases so that e.g. (2*x)^2 is
        # not printed as 2*x^2.
        if base.type in [ADD, MUL, POW] or (base.type == INTEGER and base.i < 0):
            s = "(%s)" % s
        if exp.type in [ADD, MUL, POW]:
            s = "%s^(%s)" % (s, str(exp))
        else:
            s = "%s^%s" % (s, str(exp))
        return s

    def as_base_exp(self):
        return self.args

    def expand(self):
        """Expand ``(t1 + ... + tm)**n`` with the multinomial theorem.

        Only non-negative integer exponents are expanded.  Without rational
        numbers a negative power such as ``(x + y)**-2`` cannot be expanded,
        so it is returned unchanged.  The base is expanded first, so for
        example ``((x + 1)*(y + 1))**2`` is expanded too.
        """
        base, exp = self.args
        if exp.type != INTEGER or exp.i < 0:
            return self
        base = base.expand()
        if base.type != ADD:
            return Pow((base, exp))
        n = exp.i
        m = len(base.args)
        d = multinomial_coefficients(m, n)
        r = Integer(0)
        for powers, coeff in d.items():
            # coeff * t1**k1 * ... * tm**km
            t = Integer(coeff)
            for x, p in zip(base.args, powers):
                t *= x**p
            r += t
        return r
def sympify(x):
    """Wrap a Python ``int`` in an ``Integer``; return any other object as-is."""
    return Integer(x) if isinstance(x, int) else x
def var(s):
    """
    Create a symbolic variable with the name *s*.

    INPUT:
        s -- a string, either a single variable name, or a list of
             variable names separated by whitespace and/or commas, or
             a list (or tuple) of variable names.

    OUTPUT:
        None if no name was given, the Symbol for a single name, otherwise
        a tuple of Symbols.

    NOTE: The new variable is both returned and automatically injected into
    the parent's *global* namespace. It's recommended not to use "var" in
    library code, it is better to construct Symbol objects directly.

    EXAMPLES:
    We define some symbolic variables:
        >>> var('m')
        m
        >>> var('n xx yy zz')
        (n, xx, yy, zz)
        >>> n
        n
    """
    import re
    import inspect
    frame = inspect.currentframe().f_back

    try:
        if not isinstance(s, (list, tuple)):
            # Raw string: '\s' in a plain literal is an invalid escape
            # sequence (SyntaxWarning on Python 3.12+).
            s = re.split(r'\s|,', s)

        res = []

        for t in s:
            # skip empty strings produced by adjacent separators
            if not t:
                continue
            sym = Symbol(t)
            frame.f_globals[t] = sym
            res.append(sym)

        res = tuple(res)
        if len(res) == 0:   # var('')
            res = None
        elif len(res) == 1: # var('x')
            res = res[0]
        # otherwise var('a b ...')
        return res

    finally:
        # explicitly break the reference cycle through the frame object, as
        # recommended by the inspect documentation
        del frame
def binomial_coefficients(n):
    """Return a dictionary ``{(k1, k2): C(n, k1)}`` over all ``k1 + k2 == n``.

    Both orderings of each pair are present, e.g.
    ``binomial_coefficients(2) == {(0, 2): 1, (1, 1): 2, (2, 0): 1}``.

    Raises ValueError if *n* is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative, got %r" % (n,))
    d = {(0, n): 1, (n, 0): 1}
    a = 1
    # Only the first half is computed; the row is symmetric.
    for k in range(1, n//2 + 1):
        # C(n, k) = C(n, k-1) * (n-k+1) / k, and the division is exact.
        a = (a * (n-k+1)) // k
        d[k, n-k] = d[n-k, k] = a
    return d
def binomial_coefficients_list(n):
    """Return row *n* of Pascal's triangle, ``[C(n, 0), C(n, 1), ..., C(n, n)]``.

    Raises ValueError if *n* is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative, got %r" % (n,))
    d = [1] * (n+1)
    a = 1
    # Only the first half is computed; the row is symmetric.
    for k in range(1, n//2 + 1):
        # C(n, k) = C(n, k-1) * (n-k+1) / k, and the division is exact.
        a = (a * (n-k+1)) // k
        d[k] = d[n-k] = a
    return d
def multinomial_coefficients(m, n, _tuple=tuple, _zip=zip):
    """Return a dictionary containing pairs ``{(k1,k2,..,km) : C_kn}``
    where ``C_kn`` are multinomial coefficients such that
    ``n=k1+k2+..+km``.

    For example:

        >>> sorted(multinomial_coefficients(2, 5).items())
        [((0, 5), 1), ((1, 4), 5), ((2, 3), 10), ((3, 2), 10), ((4, 1), 5), ((5, 0), 1)]

    The algorithm is based on the following result:

    Consider a polynomial and its ``n``-th power::

        P(x) = sum_{i=0}^m p_i x^i
        P(x)^n = sum_{k=0}^{m n} a(n,k) x^k

    The coefficients ``a(n,k)`` can be computed using the
    J.C.P. Miller Pure Recurrence [see D.E.Knuth, Seminumerical
    Algorithms, The art of Computer Programming v.2, Addison
    Wesley, Reading, 1981;]::

        a(n,k) = 1/(k p_0) sum_{i=1}^m p_i ((n+1)i-k) a(n,k-i),

    where ``a(n,0) = p_0^n``.

    Raises ValueError if *m* or *n* is negative.  For ``m == 0`` the
    empty sum is 0, so the result is ``{(): 1}`` for ``n == 0`` and
    ``{}`` otherwise.
    """
    if m < 0 or n < 0:
        raise ValueError("m and n must be non-negative, got m=%r, n=%r" % (m, n))
    if m == 0:
        return {(): 1} if n == 0 else {}
    if m == 2:
        return binomial_coefficients(n)
    symbols = [(0,)*i + (1,) + (0,)*(m-i-1) for i in range(m)]
    s0 = symbols[0]
    p0 = [_tuple(aa-bb for aa, bb in _zip(s, s0)) for s in symbols]
    r = {_tuple(aa*n for aa in s0): 1}
    r_update = r.update
    l = [0] * (n*(m-1)+1)
    # Snapshot as a list: on Python 3 r.items() is a live view, which would
    # pick up the entries added by r_update() below and corrupt l[0].
    l[0] = list(r.items())
    for k in range(1, n*(m-1)+1):
        d = {}
        d_get = d.get
        for i in range(1, min(m, k+1)):
            nn = (n+1)*i-k
            if not nn:
                continue
            t = p0[i]
            for t2, c2 in l[k-i]:
                tt = _tuple([aa+bb for aa, bb in _zip(t2, t)])
                cc = nn * c2
                b = d_get(tt)
                if b is None:
                    d[tt] = cc
                else:
                    cc = b + cc
                    if cc:
                        d[tt] = cc
                    else:
                        # contributions cancelled: drop the entry
                        del d[tt]
        # The 1/k factor of the recurrence; the division is exact.
        r1 = [(t, c//k) for (t, c) in d.items()]
        l[k] = r1
        r_update(r1)
    return r