Add whitespace.
[sympyx.git] / sympy_py.py
blob6e228b09938c6955879c25ef108a22817a7d6686
# Integer tags identifying the kind of each expression node; they are stored
# in the ``type`` attribute of every ``Basic`` instance.
BASIC, SYMBOL, ADD, MUL, POW, INTEGER = range(6)
def hash_seq(args):
    """
    Fold the hashes of the items of ``args`` into one hash value.

    The result *depends* on the order of the elements; an empty sequence
    hashes to the seed value 2.
    """
    # TODO: make this more robust
    acc = 2
    for item in args:
        # ``+`` binds tighter than ``^``, so the shifted accumulator is
        # xor-ed with the item's hash before being re-hashed.
        mixed = (acc + 1001) ^ hash(item)
        acc = hash(mixed)
    return acc
class Basic(object):
    """
    Base class of every expression node.

    Each instance stores:

    type  -- the integer tag identifying the kind of node
    _args -- the child expressions as a tuple (read-only via ``args``)
    mhash -- the cached hash value, or None while it has not been computed

    The arithmetic operators build ``Add``, ``Mul`` and ``Pow`` nodes.
    """

    def __new__(cls, type, args):
        # ``type`` shadows the builtin, but the name is part of the
        # constructor's signature and is therefore kept.
        obj = object.__new__(cls)
        obj.type = type
        obj._args = tuple(args)
        obj.mhash = None
        return obj

    def __repr__(self):
        return str(self)

    def __str__(self):
        # Generic fallback.  Without it, ``__repr__`` -> ``str(self)`` ->
        # ``object.__str__`` -> ``repr(self)`` recurses forever for any class
        # that does not define ``__str__`` itself.
        return "%s(%s)" % (self.__class__.__name__,
                           ", ".join([str(a) for a in self.args]))

    def __hash__(self):
        # The hash is computed lazily from the arguments and then cached.
        if self.mhash is None:
            self.mhash = hash_seq(self.args)
        return self.mhash

    @property
    def args(self):
        """The child expressions of this node, as a tuple."""
        return self._args

    def as_coeff_rest(self):
        """Return ``(coeff, rest)`` with ``self == coeff*rest``; here ``(1, self)``."""
        return (Integer(1), self)

    def as_base_exp(self):
        """Return ``(base, exp)`` with ``self == base^exp``; here ``(self, 1)``."""
        return (self, Integer(1))

    def expand(self):
        """Return the expanded form; a generic node is returned unchanged."""
        return self

    def __add__(x, y):
        return Add((x, y))

    def __radd__(x, y):
        return x.__add__(y)

    def __sub__(x, y):
        return Add((x, -y))

    def __rsub__(x, y):
        return Add((y, -x))

    def __mul__(x, y):
        return Mul((x, y))

    def __rmul__(x, y):
        return Mul((y, x))

    def __div__(x, y):
        return Mul((x, Pow((y, Integer(-1)))))

    def __rdiv__(x, y):
        return Mul((y, Pow((x, Integer(-1)))))

    # ``/`` dispatches to __truediv__ under true division (always on
    # Python 3); alias it so division keeps working there.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(x, y):
        return Pow((x, y))

    def __rpow__(x, y):
        return Pow((y, x))

    def __neg__(x):
        return Mul((Integer(-1), x))

    def __pos__(x):
        return x

    def __ne__(self, x):
        return not self.__eq__(x)

    def __eq__(self, o):
        o = sympify(o)
        # ``getattr`` keeps comparison with a non-symbolic object (None, a
        # string, ...) from raising AttributeError; such objects are unequal.
        if getattr(o, "type", None) == self.type:
            return self.args == o.args
        return False
class Integer(Basic):
    """
    Leaf node wrapping a Python integer, stored in the attribute ``i``.

    Adding or multiplying two ``Integer`` objects (or an ``Integer`` and a
    Python ``int``) is evaluated immediately; any other operand falls back to
    the symbolic operators of ``Basic``.
    """

    def __new__(cls, i):
        obj = Basic.__new__(cls, INTEGER, [])
        obj.i = i
        return obj

    def __hash__(self):
        # Hash of the wrapped value, computed lazily and cached.
        if self.mhash is None:
            self.mhash = hash(self.i)
        return self.mhash

    def __eq__(self, o):
        o = sympify(o)
        # ``getattr`` keeps comparison with a non-symbolic object (None, a
        # string, ...) from raising AttributeError; such objects are unequal.
        if getattr(o, "type", None) == INTEGER:
            return self.i == o.i
        return False

    def __str__(self):
        return str(self.i)

    def __add__(self, o):
        o = sympify(o)
        if o.type == INTEGER:
            return Integer(self.i + o.i)
        return Basic.__add__(self, o)

    def __mul__(self, o):
        o = sympify(o)
        if o.type == INTEGER:
            return Integer(self.i * o.i)
        return Basic.__mul__(self, o)
class Symbol(Basic):
    """
    Leaf node representing a named symbolic variable.

    The name is stored in the attribute ``name``; two symbols are equal
    exactly when their names are equal.
    """

    def __new__(cls, name):
        obj = Basic.__new__(cls, SYMBOL, [])
        obj.name = name
        return obj

    def __hash__(self):
        # Hash of the name, computed lazily and cached.
        if self.mhash is None:
            self.mhash = hash(self.name)
        return self.mhash

    def __eq__(self, o):
        o = sympify(o)
        # ``getattr`` keeps comparison with a non-symbolic object (None, a
        # string, ...) from raising AttributeError; such objects are unequal.
        if getattr(o, "type", None) == SYMBOL:
            return self.name == o.name
        return False

    def __str__(self):
        return self.name
class Add(Basic):
    """
    Sum of terms.

    ``Add(args)`` canonicalizes: integers are summed into a single number,
    nested sums are flattened one level and like terms are collected, so the
    result may be an ``Add``, a single term or an ``Integer``.
    ``Add(args, False)`` stores ``args`` verbatim without any processing.

    Equality ignores the order of the terms (the arguments are compared as
    frozensets).
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            obj = Basic.__new__(cls, ADD, args)
            obj._args_set = None
            return obj
        args = [sympify(x) for x in args]
        return Add.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """
        Build the canonical sum of the (already sympified) ``args``.

        Returns an ``Integer`` when no symbolic term survives, the single
        remaining term when there is exactly one, and an ``Add`` otherwise.
        """
        # Developer switch: use the csympy hash table instead of a dict.
        use_glib = 0
        if use_glib:
            from csympy import HashTable
            d = HashTable()
            iter_items = d.iteritems
        else:
            d = {}
            iter_items = d.items

        # d maps each non-numeric "rest" to its accumulated coefficient;
        # num accumulates the purely numeric part.
        num = Integer(0)
        for a in args:
            # Flatten one level of nested sums.
            terms = a.args if a.type == ADD else (a,)
            for b in terms:
                if b.type == INTEGER:
                    num += b
                else:
                    coeff, key = b.as_coeff_rest()
                    if key in d:
                        d[key] += coeff
                    else:
                        d[key] = coeff

        terms = []
        for key, coeff in iter_items():
            term = Mul((key, coeff))
            if term.type == INTEGER:
                # The term collapsed to a number (e.g. the coefficients
                # cancelled to 0); fold it into the numeric part rather than
                # keeping a spurious "0" argument.
                num += term
            else:
                terms.append(term)

        if not terms:
            return num
        if num.i != 0:
            terms.insert(0, num)
        if len(terms) == 1:
            return terms[0]
        return Add(terms, False)

    def freeze_args(self):
        """Cache the arguments as a frozenset (used by ``__eq__``)."""
        if self._args_set is None:
            self._args_set = frozenset(self.args)

    def __eq__(self, o):
        o = sympify(o)
        # ``getattr`` keeps comparison with a non-symbolic object (None, a
        # string, ...) from raising AttributeError; such objects are unequal.
        if getattr(o, "type", None) == ADD:
            self.freeze_args()
            o.freeze_args()
            return self._args_set == o._args_set
        return False

    def __str__(self):
        parts = []
        for x in self.args:
            text = str(x)
            if x.type == ADD:
                # Parenthesize only the nested sum itself.
                text = "(%s)" % text
            parts.append(text)
        return " + ".join(parts)

    def __hash__(self):
        if self.mhash is None:
            # XXX: it is surprising, but hashing the frozenset of the
            # arguments is *not* faster than this.
            # Sorting by hash makes the result independent of the order of
            # the terms, consistent with ``__eq__``.
            a = sorted(self.args, key=hash)
            self.mhash = hash_seq(a)
        return self.mhash

    def expand(self):
        """Return the sum of the expanded terms."""
        r = Integer(0)
        for term in self.args:
            r += term.expand()
        return r
class Mul(Basic):
    """
    Product of factors.

    ``Mul(args)`` canonicalizes: integers are multiplied into a single
    number, nested products are flattened one level and powers of the same
    base are combined, so the result may be a ``Mul``, a single factor or an
    ``Integer``.  ``Mul(args, False)`` stores ``args`` verbatim without any
    processing.

    Equality ignores the order of the factors (the arguments are compared as
    frozensets).
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            obj = Basic.__new__(cls, MUL, args)
            obj._args_set = None
            return obj
        args = [sympify(x) for x in args]
        return Mul.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """
        Build the canonical product of the (already sympified) ``args``.

        Returns an ``Integer`` when the product is 0 or no symbolic factor
        survives, the single remaining factor when there is exactly one, and
        a ``Mul`` otherwise.
        """
        # Developer switch: use the csympy hash table instead of a dict.
        use_glib = 0
        if use_glib:
            from csympy import HashTable
            d = HashTable()
            iter_items = d.iteritems
        else:
            d = {}
            iter_items = d.items

        # d maps each base to its accumulated exponent; num accumulates the
        # purely numeric part.
        num = Integer(1)
        for a in args:
            # Flatten one level of nested products.
            factors = a.args if a.type == MUL else (a,)
            for b in factors:
                if b.type == INTEGER:
                    num *= b
                else:
                    key, coeff = b.as_base_exp()
                    if key in d:
                        d[key] += coeff
                    else:
                        d[key] = coeff
        if num.i == 0:
            return num

        factors = []
        for key, coeff in iter_items():
            power = Pow((key, coeff))
            if power.type == INTEGER:
                # The power collapsed to a number (e.g. the exponents
                # cancelled to 0, giving 1); fold it into the numeric part
                # rather than keeping a spurious "1" argument.
                num *= power
            else:
                factors.append(power)

        if num.i == 0 or not factors:
            return num
        if num.i != 1:
            factors.insert(0, num)
        if len(factors) == 1:
            return factors[0]
        return Mul(factors, False)

    def __hash__(self):
        if self.mhash is None:
            # In contrast to Add, hashing the frozenset of the arguments is
            # faster here than sorting the arguments by hash.
            self.freeze_args()
            self.mhash = hash(self._args_set)
        return self.mhash

    def freeze_args(self):
        """Cache the arguments as a frozenset (used by ``__eq__``/``__hash__``)."""
        if self._args_set is None:
            self._args_set = frozenset(self.args)

    def __eq__(self, o):
        o = sympify(o)
        # ``getattr`` keeps comparison with a non-symbolic object (None, a
        # string, ...) from raising AttributeError; such objects are unequal.
        if getattr(o, "type", None) == MUL:
            self.freeze_args()
            o.freeze_args()
            return self._args_set == o._args_set
        return False

    def as_coeff_rest(self):
        """
        Return ``(coeff, rest)`` with ``self == coeff*rest``.

        The coefficient is the leading argument when that is an ``Integer``,
        and ``Integer(1)`` otherwise.
        """
        if self.args[0].type == INTEGER:
            return self.as_two_terms()
        return (Integer(1), self)

    def as_two_terms(self):
        """Return ``(first argument, product of the remaining arguments)``."""
        return (self.args[0], Mul(self.args[1:]))

    def __str__(self):
        s = str(self.args[0])
        if self.args[0].type in [ADD, MUL]:
            s = "(%s)" % str(s)
        for x in self.args[1:]:
            if x.type in [ADD, MUL]:
                s = "%s * (%s)" % (s, str(x))
            else:
                s = "%s*%s" % (s, str(x))
        return s

    @classmethod
    def expand_two(cls, a, b):
        """
        Return the product ``a*b`` with sums distributed over each other.

        Both a and b are assumed to be expanded.
        """
        if a.type == ADD and b.type == ADD:
            r = Integer(0)
            for x in a.args:
                for y in b.args:
                    r += x*y
            return r
        if a.type == ADD:
            r = Integer(0)
            for x in a.args:
                r += x*b
            return r
        if b.type == ADD:
            r = Integer(0)
            for y in b.args:
                r += a*y
            return r
        return a*b

    def expand(self):
        """Return the expanded form of the product."""
        a, b = self.as_two_terms()
        r = Mul.expand_two(a, b)
        if r == self:
            # Nothing was distributed at this level: expand the two parts
            # and distribute once more.
            a = a.expand()
            b = b.expand()
            return Mul.expand_two(a, b)
        else:
            return r.expand()
class Pow(Basic):
    """
    Power ``base^exp``; the arguments are ``(base, exp)``.

    ``Pow(args)`` canonicalizes (see ``canonicalize``); ``Pow(args, False)``
    stores ``args`` verbatim without any processing.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            return Basic.__new__(cls, POW, args)
        args = [sympify(x) for x in args]
        return Pow.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """
        Simplify the trivial cases of ``base^exp``.

        Applied in this order: 0^x -> 0, 1^x -> 1, x^0 -> 1, x^1 -> x and
        (x^a)^b -> x^(a*b); anything else becomes a ``Pow`` node.
        """
        base, exp = args
        if base.type == INTEGER:
            # NOTE(review): the base is tested before the exponent, so 0^0
            # and 0^(negative) both give 0 here -- confirm this is intended.
            if base.i == 0:
                return Integer(0)
            if base.i == 1:
                return Integer(1)
        if exp.type == INTEGER:
            if exp.i == 0:
                return Integer(1)
            if exp.i == 1:
                return base
        if base.type == POW:
            return Pow((base.args[0], base.args[1]*exp))
        return Pow(args, False)

    def __str__(self):
        base, exp = self.args
        s = str(base)
        # Parenthesize compound operands so that e.g. (x*y)^2 is not printed
        # as the ambiguous "x*y^2".
        if base.type in (ADD, MUL, POW):
            s = "(%s)" % s
        if exp.type in (ADD, MUL, POW):
            return "%s^(%s)" % (s, str(exp))
        return "%s^%s" % (s, str(exp))

    def as_base_exp(self):
        """Return the ``(base, exp)`` pair."""
        return self.args

    def expand(self):
        """
        Expand ``(a + b + ...)^n`` for a non-negative integer ``n`` using the
        multinomial theorem; any other power is returned unchanged.
        """
        base, exp = self.args
        # The multinomial theorem only holds for non-negative integer
        # exponents; a negative one must be left alone.
        if base.type != ADD or exp.type != INTEGER or exp.i < 0:
            return self
        coeffs = multinomial_coefficients(len(base.args), exp.i)
        terms = []
        for powers, coeff in coeffs.items():
            factors = []
            if coeff != 1:
                factors.append(Integer(coeff))
            for x, p in zip(base.args, powers):
                if p != 0:
                    factors.append(Pow((x, p)))
            # Use the canonicalizing constructors: a power of a term can be
            # an Integer or a product, which must be merged into the term.
            terms.append(Mul(factors))
        return Add(terms)
def sympify(x):
    """Wrap a Python ``int`` in an ``Integer``; return any other object as is."""
    return Integer(x) if isinstance(x, int) else x
def var(s):
    """
    Create a symbolic variable with the name *s*.

    INPUT:
        s -- a string, either a single variable name, or
             a space or comma separated list of variable names, or
             a list (or tuple) of variable names.

    OUTPUT:
        None if no name was given, the symbol itself for a single name,
        and a tuple of symbols otherwise.

    NOTE: The new variable is both returned and automatically injected into
    the caller's *global* namespace. It's recommended not to use "var" in
    library code.

    EXAMPLES:
    We define some symbolic variables:
        >>> var('m')
        m
        >>> var('n xx yy zz')
        (n, xx, yy, zz)
        >>> n
        n
    """
    import re
    import inspect
    frame = inspect.currentframe().f_back

    try:
        if not isinstance(s, (list, tuple)):
            # Raw string: "\s" is not a valid escape in a plain literal.
            s = re.split(r'\s|,', s)

        res = []

        for t in s:
            # skip empty strings
            if not t:
                continue
            sym = Symbol(t)
            frame.f_globals[t] = sym
            res.append(sym)

        res = tuple(res)
        if len(res) == 0:   # var('')
            res = None
        elif len(res) == 1: # var('x')
            res = res[0]
                            # otherwise var('a b ...')
        return res

    finally:
        # we should explicitly break cyclic dependencies as stated in inspect
        # doc
        del frame
def binomial_coefficients(n):
    """Return a dictionary containing pairs {(k1,k2) : C_kn} where
    C_kn are binomial coefficients and n=k1+k2."""
    d = {(0, n): 1, (n, 0): 1}
    a = 1
    # ``range`` works on both Python 2 and Python 3 (``xrange`` does not).
    for k in range(1, n//2 + 1):
        # C(n, k) = C(n, k-1) * (n-k+1) / k; the division is always exact.
        a = (a * (n - k + 1))//k
        # The coefficients are symmetric, so fill both ends at once.
        d[k, n - k] = d[n - k, k] = a
    return d
def binomial_coefficients_list(n):
    """Return row ``n`` of Pascal's triangle as a list, i.e. the binomial
    coefficients ``[C(n,0), C(n,1), ..., C(n,n)]``.
    """
    d = [1] * (n + 1)
    a = 1
    # ``range`` works on both Python 2 and Python 3 (``xrange`` does not).
    for k in range(1, n//2 + 1):
        # C(n, k) = C(n, k-1) * (n-k+1) / k; the division is always exact.
        a = (a * (n - k + 1))//k
        # The row is symmetric, so fill both ends at once.
        d[k] = d[n - k] = a
    return d
def multinomial_coefficients(m, n, _tuple=tuple, _zip=zip):
    """Return a dictionary containing pairs ``{(k1,k2,..,km) : C_kn}``
    where ``C_kn`` are multinomial coefficients such that
    ``n=k1+k2+..+km``.

    For example:

    >>> multinomial_coefficients(2, 5) == {(0, 5): 1, (1, 4): 5, (2, 3): 10, (3, 2): 10, (4, 1): 5, (5, 0): 1}
    True

    For ``m == 0`` the result is ``{(): 1}`` when ``n == 0`` and ``{}``
    otherwise.  ``_tuple`` and ``_zip`` only bind the builtins as fast
    local names.

    The algorithm is based on the following result:

       Consider a polynomial and its ``n``-th power::

         P(x) = sum_{i=0}^m p_i x^i
         P(x)^n = sum_{k=0}^{m n} a(n,k) x^k

       The coefficients ``a(n,k)`` can be computed using the
       J.C.P. Miller Pure Recurrence [see D.E.Knuth, Seminumerical
       Algorithms, The art of Computer Programming v.2, Addison
       Wesley, Reading, 1981;]::

         a(n,k) = 1/(k p_0) sum_{i=1}^m p_i ((n+1)i-k) a(n,k-i),

       where ``a(n,0) = p_0^n``.
    """
    if not m:
        # No variables: only the empty product, and only for n == 0.
        return {} if n else {(): 1}
    if m == 2:
        return binomial_coefficients(n)
    # Unit exponent tuples e_0 .. e_{m-1}, one per variable.
    symbols = [(0,)*i + (1,) + (0,)*(m - i - 1) for i in range(m)]
    s0 = symbols[0]
    # p0[i] = e_i - e_0: the exponent shift applied by step i of the
    # recurrence.
    p0 = [_tuple(aa - bb for aa, bb in _zip(s, s0)) for s in symbols]
    r = {_tuple(aa*n for aa in s0): 1}
    r_update = r.update
    # l[k] holds the (exponents, coefficient) pairs produced at step k.
    l = [0] * (n*(m - 1) + 1)
    # Take a snapshot: on Python 3 ``r.items()`` is a live view that would
    # also show the entries added to ``r`` by later steps.
    l[0] = list(r.items())
    for k in range(1, n*(m - 1) + 1):
        d = {}
        d_get = d.get
        for i in range(1, min(m, k + 1)):
            nn = (n + 1)*i - k
            if not nn:
                continue
            t = p0[i]
            for t2, c2 in l[k - i]:
                tt = _tuple([aa + bb for aa, bb in _zip(t2, t)])
                cc = nn * c2
                b = d_get(tt)
                if b is None:
                    d[tt] = cc
                else:
                    cc = b + cc
                    if cc:
                        d[tt] = cc
                    else:
                        # The contributions cancelled; drop the entry.
                        del d[tt]
        # The 1/k factor of the recurrence.
        r1 = [(t, c//k) for (t, c) in d.items()]
        l[k] = r1
        r_update(r1)
    return r