refactoring
[sympyx.git] / sympy.py
blob91aa8b1c1ae420d63f2400b2e188aee7860c5e8b
# Integer type tags identifying each expression-node class; every node stores
# its tag in ``Basic.type`` and the classes dispatch on it.
BASIC, SYMBOL, ADD, MUL, POW, INTEGER = range(6)
def hash_seq(args):
    """
    Return a hash of the sequence *args* that depends on element order.

    The element hashes are folded left to right into an accumulator that
    starts at 2, so reordering the elements generally changes the result.
    """
    # TODO: the mixing step is very simple; make this more robust.
    acc = 2
    for item in args:
        # ``+`` binds tighter than ``^``: this is (acc + 1001) ^ hash(item)
        acc = hash((acc + 1001) ^ hash(item))
    return acc
def compare_lists(a, b):
    """
    Compare two sequences as multisets.

    The sequences are equal when they contain the same elements the same
    number of times, in any order.  Elements must be hashable.
    """
    from collections import Counter
    # A plain set comparison would also ignore multiplicity, wrongly treating
    # [x, x, y] and [x, y, y] as equal.
    return Counter(a) == Counter(b)
class Basic(object):
    """Base class for all expression nodes.

    Each instance carries ``type`` (an integer tag such as SYMBOL, ADD, MUL,
    POW or INTEGER) and ``args``, a tuple of child nodes.  The arithmetic
    operators build Add/Mul/Pow nodes, which canonicalize themselves on
    construction.
    """

    def __new__(cls, type, args):
        # ``type`` shadows the builtin inside this method only; the name is
        # kept for backward compatibility with existing callers.
        obj = object.__new__(cls)
        obj.type = type
        obj._args = tuple(args)
        return obj

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash_seq(self.args)

    @property
    def args(self):
        """Tuple of child nodes."""
        return self._args

    def as_coeff_rest(self):
        """Return ``(coefficient, rest)``; by default the coefficient is 1."""
        return (Integer(1), self)

    def as_base_exp(self):
        """Return ``(base, exponent)``; by default the exponent is 1."""
        return (self, Integer(1))

    def expand(self):
        """Return self unchanged; Add, Mul and Pow override this."""
        return self

    def __add__(x, y):
        return Add((x, y))

    def __radd__(x, y):
        return x.__add__(y)

    def __sub__(x, y):
        return Add((x, -y))

    def __rsub__(x, y):
        return Add((y, -x))

    def __mul__(x, y):
        return Mul((x, y))

    def __rmul__(x, y):
        return Mul((y, x))

    def __div__(x, y):
        return Mul((x, Pow((y, Integer(-1)))))

    def __rdiv__(x, y):
        return Mul((y, Pow((x, Integer(-1)))))

    # Python 3 (and Python 2 under ``from __future__ import division``)
    # dispatches ``/`` only to __truediv__/__rtruediv__; without these
    # aliases ``x / y`` raises TypeError there.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(x, y):
        return Pow((x, y))

    def __rpow__(x, y):
        return Pow((y, x))

    def __neg__(x):
        return Mul((Integer(-1), x))

    def __pos__(x):
        return x

    def __ne__(self, x):
        return not self.__eq__(x)

    def __eq__(self, o):
        """Structural equality: same type tag and equal ``args`` tuples."""
        o = sympify(o)
        if not isinstance(o, Basic):
            # e.g. comparison with None or a string: report "not equal"
            # instead of failing with AttributeError on ``o.type``.
            return False
        return o.type == self.type and self.args == o.args
class Integer(Basic):
    """Integer constant; the wrapped Python integer is stored in ``i``."""

    def __new__(cls, i):
        obj = Basic.__new__(cls, INTEGER, [])
        obj.i = i
        return obj

    def __hash__(self):
        # Same as hash(i): Integer(n) compares equal to the int n (via
        # sympify in __eq__), so their hashes must agree too.
        return hash(self.i)

    def __eq__(self, o):
        o = sympify(o)
        if not isinstance(o, Basic):
            # non-expression operand (None, str, ...): simply not equal
            return False
        return o.type == INTEGER and self.i == o.i

    def __str__(self):
        return str(self.i)

    def __add__(self, o):
        # Integer + Integer folds to a new Integer; anything else builds an Add.
        o = sympify(o)
        if o.type == INTEGER:
            return Integer(self.i + o.i)
        return Basic.__add__(self, o)

    def __mul__(self, o):
        # Integer * Integer folds to a new Integer; anything else builds a Mul.
        o = sympify(o)
        if o.type == INTEGER:
            return Integer(self.i * o.i)
        return Basic.__mul__(self, o)
class Symbol(Basic):
    """Named symbolic variable; identity is determined by ``name`` alone."""

    def __new__(cls, name):
        obj = Basic.__new__(cls, SYMBOL, [])
        obj.name = name
        return obj

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, o):
        o = sympify(o)
        if not isinstance(o, Basic):
            # non-expression operand (None, str, ...): simply not equal
            return False
        return o.type == SYMBOL and self.name == o.name

    def __str__(self):
        return self.name
class Add(Basic):
    """Sum of terms, canonicalized on construction.

    ``Add(args)`` sympifies and canonicalizes *args*; ``Add(args, False)``
    wraps *args* verbatim and is used once they are already canonical.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            return Basic.__new__(cls, ADD, args)
        return Add.canonicalize([sympify(x) for x in args])

    @classmethod
    def canonicalize(cls, args):
        """Combine like terms of *args* and return the canonical sum.

        Integer terms are summed into one numeric term.  Every other term is
        split by ``as_coeff_rest`` into an Integer coefficient and a rest,
        and the coefficients of equal rests are added.  Nested Adds are
        flattened one level.  Returns an Integer, a single term, or an Add
        whose first arg is the numeric term when that term is non-zero.
        """
        num = Integer(0)
        coeffs = {}
        for a in args:
            # flatten one level of nested sums
            parts = a.args if a.type == ADD else (a,)
            for t in parts:
                if t.type == INTEGER:
                    num += t
                else:
                    coeff, key = t.as_coeff_rest()
                    if key in coeffs:
                        coeffs[key] += coeff
                    else:
                        coeffs[key] = coeff
        terms = []
        for key, coeff in coeffs.items():
            term = Mul((key, coeff))
            if term.type == INTEGER:
                # e.g. a coefficient that cancelled to zero (x - x): fold it
                # into num instead of emitting a stray "0 + ..." term
                num += term
            else:
                terms.append(term)
        if not terms:
            return num
        if num.i != 0:
            terms.insert(0, num)
        if len(terms) == 1:
            return terms[0]
        return Add(terms, False)

    def __eq__(self, o):
        """Order-insensitive comparison of the terms of two sums."""
        o = sympify(o)
        if not isinstance(o, Basic):
            return False
        if o.type == ADD:
            return compare_lists(self.args, o.args)
        return False

    def __str__(self):
        # Parenthesize any nested Add term individually.
        parts = []
        for x in self.args:
            if x.type == ADD:
                parts.append("(%s)" % x)
            else:
                parts.append(str(x))
        return " + ".join(parts)

    def __hash__(self):
        # Order-insensitive to match __eq__, and independent of how terms
        # with colliding hashes happen to be ordered.  The ADD tag keeps
        # x + y and x * y from hashing alike.
        return hash((ADD, frozenset(self.args)))

    def expand(self):
        """Return the sum of the expanded terms."""
        # One canonicalization of all terms instead of re-adding a growing sum.
        return Add([term.expand() for term in self.args])
class Mul(Basic):
    """Product of factors, canonicalized on construction.

    ``Mul(args)`` sympifies and canonicalizes *args*; ``Mul(args, False)``
    wraps *args* verbatim and is used once they are already canonical.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            return Basic.__new__(cls, MUL, args)
        return Mul.canonicalize([sympify(x) for x in args])

    @classmethod
    def canonicalize(cls, args):
        """Collect equal bases of *args* and return the canonical product.

        Integer factors are multiplied into one numeric coefficient.  Every
        other factor is split by ``as_base_exp``, and the exponents of equal
        bases are added.  Nested Muls are flattened one level.  Returns an
        Integer, a single factor, or a Mul whose first arg is the numeric
        coefficient when that coefficient is not 1.
        """
        num = Integer(1)
        exps = {}
        for a in args:
            # flatten one level of nested products
            parts = a.args if a.type == MUL else (a,)
            for f in parts:
                if f.type == INTEGER:
                    num *= f
                else:
                    base, exp = f.as_base_exp()
                    if base in exps:
                        exps[base] += exp
                    else:
                        exps[base] = exp
        if num.i == 0:
            return num
        factors = []
        for base, exp in exps.items():
            factor = Pow((base, exp))
            if factor.type == INTEGER:
                # e.g. x * x**-1 gives x**0 == 1: fold numeric results into
                # the coefficient instead of leaving a stray "1*" factor
                num *= factor
            else:
                factors.append(factor)
        if num.i == 0 or not factors:
            return num
        if num.i != 1:
            factors.insert(0, num)
        if len(factors) == 1:
            return factors[0]
        return Mul(factors, False)

    def __hash__(self):
        # Order-insensitive to match __eq__; the MUL tag keeps x * y and
        # x + y from hashing alike.
        return hash((MUL, frozenset(self.args)))

    def __eq__(self, o):
        """Order-insensitive comparison of the factors of two products."""
        o = sympify(o)
        if not isinstance(o, Basic):
            return False
        if o.type == MUL:
            return compare_lists(self.args, o.args)
        return False

    def as_coeff_rest(self):
        """Split off a leading Integer coefficient, if there is one."""
        if self.args[0].type == INTEGER:
            return self.as_two_terms()
        return (Integer(1), self)

    def as_two_terms(self):
        """Return ``(first factor, product of the remaining factors)``."""
        return (self.args[0], Mul(self.args[1:]))

    def __str__(self):
        s = str(self.args[0])
        if self.args[0].type in [ADD, MUL]:
            s = "(%s)" % str(s)
        for x in self.args[1:]:
            if x.type in [ADD, MUL]:
                s = "%s * (%s)" % (s, str(x))
            else:
                s = "%s*%s" % (s, str(x))
        return s

    @classmethod
    def expand_two(cls, a, b):
        """Return ``a*b`` distributed over any Add among *a* and *b*.

        Both a and b are assumed to be expanded already.
        """
        # Build every partial product first and canonicalize the sum once,
        # rather than re-adding a growing sum term by term.
        if a.type == ADD and b.type == ADD:
            return Add([x*y for x in a.args for y in b.args])
        if a.type == ADD:
            return Add([x*b for x in a.args])
        if b.type == ADD:
            return Add([a*y for y in b.args])
        return a*b

    def expand(self):
        """Return the product fully distributed over its Add factors.

        The product is split into its first factor and the rest.  If
        distributing those two changes nothing, each part is expanded and
        the results are distributed; otherwise the distributed result is
        expanded again.
        """
        a, b = self.as_two_terms()
        r = Mul.expand_two(a, b)
        if r == self:
            return Mul.expand_two(a.expand(), b.expand())
        return r.expand()
class Pow(Basic):
    """Power ``base**exp``, stored as ``args == (base, exp)``."""

    def __new__(cls, args, canonicalize=True):
        # canonicalize=False wraps (base, exp) verbatim; canonicalize() uses
        # it once no simplification applies.
        if not canonicalize:
            return Basic.__new__(cls, POW, args)
        return Pow.canonicalize([sympify(x) for x in args])

    @classmethod
    def canonicalize(cls, args):
        """Simplify ``base**exp`` where possible.

        Rules, checked in this order:
          * ``x**0 -> 1`` (including ``0**0``, matching Python's ``0 ** 0``)
          * ``x**1 -> x``
          * ``0**n`` with a negative Integer ``n`` raises ZeroDivisionError
            (as Python's ``0 ** -1`` does); any other ``0**x -> 0``
          * ``1**x -> 1``
          * Integer base with a positive Integer exponent is evaluated
          * ``(b**e1)**e2 -> b**(e1*e2)``
        Otherwise an unevaluated Pow is returned.
        """
        base, exp = args
        if exp.type == INTEGER:
            if exp.i == 0:
                return Integer(1)
            if exp.i == 1:
                return base
        if base.type == INTEGER:
            if base.i == 0:
                if exp.type == INTEGER and exp.i < 0:
                    raise ZeroDivisionError(
                        "0 cannot be raised to a negative power")
                return Integer(0)
            if base.i == 1:
                return Integer(1)
            if exp.type == INTEGER and exp.i > 0:
                return Integer(base.i ** exp.i)
        if base.type == POW:
            # NOTE(review): (b**e1)**e2 == b**(e1*e2) does not hold for every
            # base/exponent in general; kept as the existing simplification.
            return Pow((base.args[0], base.args[1]*exp))
        return Pow(args, False)

    def __str__(self):
        base, exp = self.args
        b = str(base)
        # "^" binds tighter than "*" and "+", so compound bases and negative
        # numbers need parentheses: (x*y)^2, (-2)^x.
        if base.type in (ADD, MUL, POW) or (base.type == INTEGER and base.i < 0):
            b = "(%s)" % b
        e = str(exp)
        # likewise x^(2*y) rather than the misleading x^2*y
        if exp.type in (ADD, MUL, POW):
            e = "(%s)" % e
        return "%s^%s" % (b, e)

    def as_base_exp(self):
        return self.args

    def expand(self):
        """Expand ``(t1 + ... + tm)**n`` for a positive Integer ``n``.

        Uses the multinomial theorem; every other power (including negative
        exponents, which the multinomial expansion does not apply to) is
        returned unchanged.
        """
        base, exp = self.args
        if base.type == ADD and exp.type == INTEGER and exp.i > 0:
            coeffs = multinomial_coefficients(len(base.args), exp.i)
            terms = []
            for powers, coeff in coeffs.items():
                factors = [Integer(coeff)]
                for x, p in zip(base.args, powers):
                    factors.append(Pow((x, p)))
                terms.append(Mul(factors))
            # one canonicalization of all terms instead of re-adding a
            # growing sum term by term
            return Add(terms)
        return self
def sympify(x):
    """Convert *x* to an expression node where possible.

    Any integral number (``int``, Python 2 ``long``, ``bool``, other
    ``numbers.Integral`` types) becomes an ``Integer`` wrapping a plain
    Python int; anything else is returned unchanged.
    """
    import numbers
    # isinstance(x, int) alone missed Python 2 longs, so e.g. x + 10**20
    # crashed later on ``.type``.
    if isinstance(x, numbers.Integral):
        # int() normalizes bools (and foreign integer types) so that
        # Integer(True) does not print as "True".
        return Integer(int(x))
    return x
def var(s):
    """
    Create symbolic variables named by *s*.

    INPUT:
        s -- a string, either a single variable name, or a list of names
             separated by whitespace and/or commas, or a list of names.

    OUTPUT:
        None if no names are given, a single Symbol for one name, otherwise
        a tuple of Symbols.

    NOTE: The new variables are both returned and automatically injected
    into the caller's *global* namespace. It's recommended not to use "var"
    in library code; construct Symbol objects directly instead.

    EXAMPLES:
        We define some symbolic variables:
        >>> var('m')
        m
        >>> var('n xx yy zz')
        (n, xx, yy, zz)
        >>> n
        n
    """
    import re
    import inspect
    frame = inspect.currentframe().f_back
    try:
        if not isinstance(s, list):
            # raw string: '\s' in a plain literal is an invalid escape
            s = re.split(r'\s|,', s)

        res = []
        for t in s:
            # skip empty strings produced by adjacent separators
            if not t:
                continue
            sym = Symbol(t)
            frame.f_globals[t] = sym
            res.append(sym)

        res = tuple(res)
        if len(res) == 0:  # var('')
            res = None
        elif len(res) == 1:  # var('x')
            res = res[0]
        # otherwise var('a b ...')
        return res

    finally:
        # explicitly break the frame reference cycle, as recommended in the
        # inspect documentation
        del frame
def binomial_coefficients(n):
    """Return ``{(k1, k2): C(n, k1)}`` for every split ``k1 + k2 == n``.

    Only the first half of row *n* is computed; the mirrored entries follow
    from the symmetry ``C(n, k) == C(n, n - k)``.
    """
    result = {(0, n): 1, (n, 0): 1}
    coeff = 1
    half = n // 2
    for k in range(1, half + 1):
        # exact integer step: C(n, k) = C(n, k - 1) * (n - k + 1) / k
        coeff = coeff * (n - k + 1) // k
        result[(k, n - k)] = coeff
        result[(n - k, k)] = coeff
    return result
def binomial_coefficients_list(n):
    """Return row *n* of Pascal's triangle, ``[C(n, 0), ..., C(n, n)]``.

    The row is filled from both ends at once using the symmetry
    ``C(n, k) == C(n, n - k)``.
    """
    row = [1] * (n + 1)
    coeff = 1
    for k in range(1, n // 2 + 1):
        # exact integer step: C(n, k) = C(n, k - 1) * (n - k + 1) / k
        coeff = coeff * (n - k + 1) // k
        row[k] = coeff
        row[n - k] = coeff
    return row
def multinomial_coefficients(m, n, _tuple=tuple, _zip=zip):
    """Return a dictionary containing pairs ``{(k1,k2,..,km) : C_kn}``
    where ``C_kn`` are multinomial coefficients such that
    ``n=k1+k2+..+km``.

    For example:

    >>> sorted(multinomial_coefficients(2, 5).items())
    [((0, 5), 1), ((1, 4), 5), ((2, 3), 10), ((3, 2), 10), ((4, 1), 5), ((5, 0), 1)]

    For ``m == 0`` the result is ``{(): 1}`` when ``n == 0`` and ``{}``
    otherwise.

    The algorithm is based on the following result:

    Consider a polynomial and its ``n``-th power::

        P(x) = sum_{i=0}^m p_i x^i
        P(x)^n = sum_{k=0}^{m n} a(n,k) x^k

    The coefficients ``a(n,k)`` can be computed using the
    J.C.P. Miller Pure Recurrence [see D.E.Knuth, Seminumerical
    Algorithms, The art of Computer Programming v.2, Addison
    Wesley, Reading, 1981;]::

        a(n,k) = 1/(k p_0) sum_{i=1}^m p_i ((n+1)i-k) a(n,k-i),

    where ``a(n,0) = p_0^n``.

    ``_tuple`` and ``_zip`` presumably exist only to make those lookups
    local in the inner loops; callers should not pass them.
    """
    if not m:
        # empty product: only the empty exponent tuple, and only for n == 0
        return {(): 1} if n == 0 else {}
    if m == 2:
        return binomial_coefficients(n)
    symbols = [(0,)*i + (1,) + (0,)*(m-i-1) for i in range(m)]
    s0 = symbols[0]
    p0 = [_tuple(aa - bb for aa, bb in _zip(s, s0)) for s in symbols]
    r = {_tuple(aa*n for aa in s0): 1}
    r_update = r.update
    levels = [0] * (n*(m-1) + 1)
    # Snapshot the items: on Python 3 r.items() is a live view, and the
    # r_update() calls below would otherwise leak later levels into level 0.
    levels[0] = list(r.items())
    for k in range(1, n*(m-1) + 1):
        d = {}
        d_get = d.get
        for i in range(1, min(m, k+1)):
            nn = (n+1)*i - k
            if not nn:
                continue
            t = p0[i]
            for t2, c2 in levels[k-i]:
                tt = _tuple([aa + bb for aa, bb in _zip(t2, t)])
                cc = nn * c2
                b = d_get(tt)
                if b is None:
                    d[tt] = cc
                else:
                    cc = b + cc
                    if cc:
                        d[tt] = cc
                    else:
                        # contributions cancelled exactly; drop the entry
                        del d[tt]
        # the recurrence's 1/k factor (p_0 == 1 here); division is exact
        r1 = [(t, c//k) for (t, c) in d.items()]
        levels[k] = r1
        r_update(r1)
    return r