Show how to use functions.
[sympyx.git] / sympy_py.py
blobaf5606c27d6bc5eed509166b15ffd4ad33ad385f
# Integer type tags stored in ``Basic.type``; assigned 0..5 in this order.
BASIC, SYMBOL, ADD, MUL, POW, INTEGER = range(6)
def hash_seq(args):
    """
    Fold the hashes of the elements of *args* into one integer.

    The running value is mixed with each element's hash in turn, so the
    result depends on the order of the elements.
    """
    # TODO: make this more robust
    acc = 2
    for item in args:
        # "+" binds tighter than "^": this is (acc + 1001) ^ hash(item)
        acc = hash((acc + 1001) ^ hash(item))
    return acc
class Basic(object):
    """Base class of all expression nodes.

    Every node carries an integer type tag (``self.type``), an immutable
    tuple of children (``self.args``) and a lazily computed, cached hash
    (``self.mhash``). The arithmetic operators build Add/Mul/Pow nodes.
    """

    def __new__(cls, type, args):
        obj = object.__new__(cls)
        obj.type = type
        obj._args = tuple(args)
        obj.mhash = None  # computed on the first __hash__ call
        return obj

    def __repr__(self):
        return str(self)

    def __str__(self):
        # Subclasses provide their own printers. This fallback keeps a bare
        # Basic printable; without it __repr__ -> str() -> object.__str__ ->
        # __repr__ recurses forever.
        return "%s(%s)" % (type(self).__name__,
                           ", ".join(str(a) for a in self.args))

    def __hash__(self):
        if self.mhash is None:
            self.mhash = hash_seq(self.args)
        return self.mhash

    @property
    def args(self):
        return self._args

    def as_coeff_rest(self):
        """Split into (integer coefficient, rest); generic nodes have coefficient 1."""
        return (Integer(1), self)

    def as_base_exp(self):
        """Split into (base, exponent); generic nodes have exponent 1."""
        return (self, Integer(1))

    def expand(self):
        return self

    def __add__(x, y):
        return Add((x, y))

    def __radd__(x, y):
        return x.__add__(y)

    def __sub__(x, y):
        return Add((x, -y))

    def __rsub__(x, y):
        return Add((y, -x))

    def __mul__(x, y):
        return Mul((x, y))

    def __rmul__(x, y):
        return Mul((y, x))

    def __div__(x, y):
        return Mul((x, Pow((y, Integer(-1)))))

    def __rdiv__(x, y):
        return Mul((y, Pow((x, Integer(-1)))))

    # Python 3 (and Python 2 with "from __future__ import division")
    # dispatches "/" to __truediv__/__rtruediv__ only.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(x, y):
        return Pow((x, y))

    def __rpow__(x, y):
        return Pow((y, x))

    def __neg__(x):
        return Mul((Integer(-1), x))

    def __pos__(x):
        return x

    def __ne__(self, x):
        return not self.__eq__(x)

    def __eq__(self, o):
        o = sympify(o)
        # Objects that are not expression nodes (None, strings, ...) have no
        # type tag; they compare unequal instead of raising AttributeError.
        if getattr(o, "type", None) == self.type:
            return self.args == o.args
        return False
class Integer(Basic):
    """Integer constant; the Python value is stored in ``self.i``."""

    def __new__(cls, i):
        obj = Basic.__new__(cls, INTEGER, [])
        obj.i = i
        return obj

    def __hash__(self):
        # Hash like the plain Python integer, since Integer(n) == n holds
        # (__eq__ sympifies ints before comparing).
        if self.mhash is None:
            self.mhash = hash(self.i)
        return self.mhash

    def __eq__(self, o):
        o = sympify(o)
        # Non-expression objects (None, strings, ...) have no type tag and
        # compare unequal instead of raising AttributeError.
        if getattr(o, "type", None) == INTEGER:
            return self.i == o.i
        return False

    def __str__(self):
        return str(self.i)

    def __add__(self, o):
        o = sympify(o)
        if o.type == INTEGER:
            # Integer + Integer is evaluated directly.
            return Integer(self.i + o.i)
        return Basic.__add__(self, o)

    def __mul__(self, o):
        o = sympify(o)
        if o.type == INTEGER:
            # Integer * Integer is evaluated directly.
            return Integer(self.i * o.i)
        return Basic.__mul__(self, o)
class Symbol(Basic):
    """Symbolic variable identified by its ``name``."""

    def __new__(cls, name):
        obj = Basic.__new__(cls, SYMBOL, [])
        obj.name = name
        return obj

    def __hash__(self):
        # Equality depends only on the name, so only the name is hashed.
        if self.mhash is None:
            self.mhash = hash(self.name)
        return self.mhash

    def __eq__(self, o):
        o = sympify(o)
        # Non-expression objects (None, strings, ...) have no type tag and
        # compare unequal instead of raising AttributeError.
        if getattr(o, "type", None) == SYMBOL:
            return self.name == o.name
        return False

    def __str__(self):
        return self.name
class Add(Basic):
    """Sum of terms.

    Canonical instances (the default ``canonicalize=True``) have:
    - all integer terms summed into one leading Integer (omitted when 0);
    - nested sums flattened;
    - like terms (equal up to an integer coefficient) collected;
    - terms whose coefficients cancel removed.
    """

    def __new__(cls, args, canonicalize=True):
        if canonicalize == False:
            # Trusted construction: args are stored verbatim.
            obj = Basic.__new__(cls, ADD, args)
            obj._args_set = None
            return obj
        args = [sympify(x) for x in args]
        return Add.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Build the canonical sum of the already-sympified *args*."""
        use_glib = 0
        if use_glib:
            # Optional: swap in csympy's HashTable for the term dict.
            from csympy import HashTable
            d = HashTable()
        else:
            d = {}
        num = Integer(0)
        for a in args:
            if a.type == INTEGER:
                num += a
            elif a.type == ADD:
                # Flatten one level of nesting.
                for b in a.args:
                    if b.type == INTEGER:
                        num += b
                    else:
                        coeff, key = b.as_coeff_rest()
                        if key in d:
                            d[key] += coeff
                        else:
                            d[key] = coeff
            else:
                coeff, key = a.as_coeff_rest()
                if key in d:
                    d[key] += coeff
                else:
                    d[key] = coeff
        # Drop terms whose coefficients cancelled (e.g. x + y - x); otherwise
        # they would survive as literal "0" terms of the sum.
        terms = [Mul((key, coeff)) for key, coeff in d.items()
                 if coeff.i != 0]
        if num.i != 0:
            terms.insert(0, num)
        if not terms:
            return num
        if len(terms) == 1:
            return terms[0]
        return Add(terms, False)

    def freeze_args(self):
        """Cache the terms as a frozenset for order-insensitive comparison."""
        if self._args_set is None:
            self._args_set = frozenset(self.args)

    def __eq__(self, o):
        o = sympify(o)
        # Non-expression objects (None, strings, ...) have no type tag and
        # compare unequal instead of raising AttributeError.
        if getattr(o, "type", None) == ADD:
            self.freeze_args()
            o.freeze_args()
            return self._args_set == o._args_set
        return False

    def __str__(self):
        parts = []
        for x in self.args:
            if x.type == ADD:
                # Parenthesize the nested sum itself, not the whole prefix.
                parts.append("(%s)" % x)
            else:
                parts.append(str(x))
        return " + ".join(parts)

    def __hash__(self):
        if self.mhash is None:
            # Sorting the terms by hash makes the result independent of term
            # order. Per the original benchmark notes, this was faster here
            # than hashing a frozenset of the terms.
            self.mhash = hash_seq(sorted(self.args, key=hash))
        return self.mhash

    def expand(self):
        return Add([term.expand() for term in self.args])
class Mul(Basic):
    """Product of factors.

    Canonical instances (the default ``canonicalize=True``) have:
    - all integer factors multiplied into one leading Integer (omitted when 1);
    - nested products flattened;
    - equal bases merged by adding their exponents;
    - factors whose exponents cancel removed.
    """

    def __new__(cls, args, canonicalize=True):
        if canonicalize == False:
            # Trusted construction: args are stored verbatim.
            obj = Basic.__new__(cls, MUL, args)
            obj._args_set = None
            return obj
        args = [sympify(x) for x in args]
        return Mul.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Build the canonical product of the already-sympified *args*."""
        use_glib = 0
        if use_glib:
            # Optional: swap in csympy's HashTable for the base dict.
            from csympy import HashTable
            d = HashTable()
        else:
            d = {}
        num = Integer(1)
        for a in args:
            if a.type == INTEGER:
                num *= a
            elif a.type == MUL:
                # Flatten one level of nesting.
                for b in a.args:
                    if b.type == INTEGER:
                        num *= b
                    else:
                        key, coeff = b.as_base_exp()
                        if key in d:
                            d[key] += coeff
                        else:
                            d[key] = coeff
            else:
                key, coeff = a.as_base_exp()
                if key in d:
                    d[key] += coeff
                else:
                    d[key] = coeff
        if num.i == 0 or len(d) == 0:
            return num
        factors = []
        for base, exp in d.items():
            factor = Pow((base, exp))
            if factor.type == INTEGER:
                # Exponents that cancel (x*x^-1 -> x^0) collapse to an
                # Integer; fold it into the coefficient instead of keeping
                # a stray "1" factor.
                num *= factor
            else:
                factors.append(factor)
        if num.i == 0 or not factors:
            return num
        if num.i != 1:
            factors.insert(0, num)
        if len(factors) == 1:
            return factors[0]
        return Mul(factors, False)

    def __hash__(self):
        if self.mhash is None:
            # Per the original benchmark notes, hashing the frozenset of
            # factors was faster here than the sort-based scheme.
            self.freeze_args()
            self.mhash = hash(self._args_set)
        return self.mhash

    def freeze_args(self):
        """Cache the factors as a frozenset for order-insensitive comparison."""
        if self._args_set is None:
            self._args_set = frozenset(self.args)

    def __eq__(self, o):
        o = sympify(o)
        # Non-expression objects (None, strings, ...) have no type tag and
        # compare unequal instead of raising AttributeError.
        if getattr(o, "type", None) == MUL:
            self.freeze_args()
            o.freeze_args()
            return self._args_set == o._args_set
        return False

    def as_coeff_rest(self):
        """Split off a leading Integer coefficient, if any."""
        if self.args[0].type == INTEGER:
            return self.as_two_terms()
        return (Integer(1), self)

    def as_two_terms(self):
        """Return (first factor, product of the remaining factors)."""
        args = self.args
        a0 = args[0]
        if len(args) == 2:
            return a0, args[1]
        return (a0, Mul(args[1:], False))

    def __str__(self):
        s = str(self.args[0])
        if self.args[0].type in [ADD, MUL]:
            s = "(%s)" % str(s)
        for x in self.args[1:]:
            if x.type in [ADD, MUL]:
                s = "%s * (%s)" % (s, str(x))
            else:
                s = "%s*%s" % (s, str(x))
        return s

    @classmethod
    def expand_two(cls, a, b):
        """
        Multiply *a* by *b*, distributing over whichever of them is an Add.

        Both a and b are assumed to be expanded.
        """
        if a.type == ADD and b.type == ADD:
            return Add([x*y for x in a.args for y in b.args])
        if a.type == ADD:
            return Add([x*b for x in a.args])
        if b.type == ADD:
            return Add([a*y for y in b.args])
        return a*b

    def expand(self):
        a, b = self.as_two_terms()
        r = Mul.expand_two(a, b)
        if r == self:
            # Distributing at the top level changed nothing: expand both
            # halves first, then distribute their results.
            return Mul.expand_two(a.expand(), b.expand())
        return r.expand()
class Pow(Basic):
    """Power node whose args are ``(base, exp)``."""

    def __new__(cls, args, canonicalize=True):
        if canonicalize == False:
            # Trusted construction: (base, exp) stored verbatim.
            return Basic.__new__(cls, POW, args)
        args = [sympify(x) for x in args]
        return Pow.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Simplify trivial exponents/bases and collapse nested powers."""
        base, exp = args
        if exp.type == INTEGER:
            if exp.i == 0:
                # x^0 == 1, including 0^0 (same convention as Python's 0**0).
                return Integer(1)
            if exp.i == 1:
                return base
        if base.type == INTEGER:
            if base.i == 0:
                return Integer(0)
            if base.i == 1:
                return Integer(1)
        if base.type == POW:
            return Pow((base.args[0], base.args[1]*exp))
        return Pow(args, False)

    def __str__(self):
        base, exp = self.args
        s = str(base)
        if base.type == ADD:
            s = "(%s)" % s
        if exp.type == ADD:
            return "%s^(%s)" % (s, exp)
        return "%s^%s" % (s, exp)

    def as_base_exp(self):
        return self.args

    def expand(self):
        """Expand a sum raised to a positive integer power (multinomial theorem).

        Any other power (negative or symbolic exponent, non-sum base) is
        returned unchanged.
        """
        base, exp = self.args
        # multinomial_coefficients only makes sense for n >= 1; a negative n
        # used to yield a wrong result (m == 2) or an IndexError (m > 2).
        if base.type == ADD and exp.type == INTEGER and exp.i > 0:
            terms = base.args
            d = multinomial_coefficients(len(terms), exp.i)
            r = []
            for powers, coeff in d.items():
                factors = []
                for term, p in zip(terms, powers):
                    if p == 0:
                        continue
                    if term.type == INTEGER:
                        # Fold numeric terms into the coefficient so the
                        # non-canonical Mul below never carries a bare
                        # integer factor (e.g. the "1" in "2*1*x").
                        coeff *= term.i ** p
                    else:
                        factors.append(Pow((term, p)))
                if coeff != 1 or not factors:
                    factors.insert(0, Integer(coeff))
                if len(factors) == 1:
                    r.append(factors[0])
                else:
                    r.append(Mul(factors, False))
            return Add(r, False)
        return self
def sympify(x):
    """Wrap a Python int in an Integer; hand anything else back untouched."""
    if not isinstance(x, int):
        return x
    return Integer(x)
def var(s):
    """
    Create symbolic variables and inject them into the caller's globals.

    INPUT:
        s -- a string holding a single variable name or several names
             separated by whitespace and/or commas, or
             a list of variable names.

    OUTPUT:
        A single Symbol when one name is given, a tuple of Symbols for
        several names, and None when no names are given.

    NOTE: Each new variable is both returned and injected into the caller's
    *global* namespace. Library code should not rely on this side effect;
    construct Symbol objects directly instead.

    EXAMPLES:
        We define some symbolic variables:
        >>> var('m')
        m
        >>> var('n xx yy zz')
        (n, xx, yy, zz)
        >>> n
        n
    """
    import re
    import inspect
    frame = inspect.currentframe().f_back

    try:
        if not isinstance(s, list):
            # Raw string: '\s' in a plain literal is an invalid escape
            # sequence (a warning on Python 3).
            s = re.split(r'\s|,', s)

        res = []
        for t in s:
            # skip empty strings produced by adjacent separators
            if not t:
                continue
            sym = Symbol(t)
            frame.f_globals[t] = sym
            res.append(sym)

        res = tuple(res)
        if len(res) == 0:  # var('')
            return None
        if len(res) == 1:  # var('x')
            return res[0]
        return res  # var('a b ...')

    finally:
        # Explicitly break the reference cycle through the frame object, as
        # recommended by the inspect documentation.
        del frame
def binomial_coefficients(n):
    """Return a dictionary containing pairs {(k1,k2) : C_kn} where
    C_kn are binomial coefficients and n=k1+k2."""
    d = {(0, n): 1, (n, 0): 1}
    a = 1
    # range (not xrange) works on both Python 2 and 3. Only half of the row
    # is computed; symmetry C(n, k) == C(n, n-k) fills the other half.
    for k in range(1, n // 2 + 1):
        a = (a * (n - k + 1)) // k
        d[k, n - k] = d[n - k, k] = a
    return d
def binomial_coefficients_list(n):
    """Return row *n* of Pascal's triangle, i.e. ``[C(n,0), ..., C(n,n)]``."""
    d = [1] * (n + 1)
    a = 1
    # range (not xrange) works on both Python 2 and 3; symmetry fills the
    # second half of the row.
    for k in range(1, n // 2 + 1):
        a = (a * (n - k + 1)) // k
        d[k] = d[n - k] = a
    return d
def multinomial_coefficients(m, n, _tuple=tuple, _zip=zip):
    """Return a dictionary containing pairs ``{(k1,k2,..,km) : C_kn}``
    where ``C_kn`` are multinomial coefficients such that
    ``n=k1+k2+..+km``.

    For example:

        >>> sorted(multinomial_coefficients(2, 5).items())
        [((0, 5), 1), ((1, 4), 5), ((2, 3), 10), ((3, 2), 10), ((4, 1), 5), ((5, 0), 1)]

    The algorithm is based on the following result:

    Consider a polynomial with ``m`` coefficients and its ``n``-th power::

        P(x) = sum_{i=0}^{m-1} p_i x^i
        P(x)^n = sum_{k=0}^{(m-1) n} a(n,k) x^k

    The coefficients ``a(n,k)`` can be computed using the
    J.C.P. Miller Pure Recurrence [see D.E.Knuth, Seminumerical
    Algorithms, The art of Computer Programming v.2, Addison
    Wesley, Reading, 1981;]::

        a(n,k) = 1/(k p_0) sum_{i=1}^{m-1} p_i ((n+1)i-k) a(n,k-i),

    where ``a(n,0) = p_0^n``.
    """
    if m == 2:
        return binomial_coefficients(n)
    # Unit exponent vectors e_0 .. e_{m-1}. A monomial in p_0..p_{m-1} is
    # represented by its exponent tuple.
    symbols = [(0,)*i + (1,) + (0,)*(m-i-1) for i in range(m)]
    s0 = symbols[0]
    # p0[i] = e_i - e_0: multiplying by p_i / p_0 moves one unit of exponent
    # from position 0 to position i.
    p0 = [_tuple(aa-bb for aa, bb in _zip(s, s0)) for s in symbols]
    # a(n,0) = p_0^n
    r = {_tuple(aa*n for aa in s0): 1}
    r_update = r.update
    # l[k] holds the (exponent tuple, coefficient) pairs that make up a(n,k).
    l = [0] * (n*(m-1)+1)
    # Snapshot with list(): on Python 3 dict.items() is a live view and would
    # pick up every entry added to r below.
    l[0] = list(r.items())
    for k in range(1, n*(m-1)+1):
        d = {}
        d_get = d.get
        for i in range(1, min(m, k+1)):
            nn = (n+1)*i-k
            if not nn:
                continue  # this term of the recurrence vanishes
            t = p0[i]
            for t2, c2 in l[k-i]:
                tt = _tuple([aa+bb for aa, bb in _zip(t2, t)])
                cc = nn * c2
                b = d_get(tt)
                if b is None:
                    d[tt] = cc
                else:
                    cc = b + cc
                    if cc:
                        d[tt] = cc
                    else:
                        del d[tt]  # drop terms that cancel
        # Apply the 1/k factor of the recurrence.
        r1 = [(t, c//k) for (t, c) in d.items()]
        l[k] = r1
        r_update(r1)
    return r
603 return r