Mul.canonize ported from pure python.
[sympyx.git] / sympy_pyx.pyx
blob0c038c583daaf6ae46ed8ff9037ebaa7e70f086d
1 # (1) Cython does not support __new__
3 # (2) what to do if we want
5 # cdef class Base:
6 # cdef virt_func(Base a, Base b):
7 # # here we ensure that a & b are of the same type
8 # ...
10 # cdef class Child(Base):
11 # cdef virt_func(Child a, Child b):
12 # ...
14 # ?
16 # currently we have to do:
18 # cdef class Child:
19 # cdef virt_func(Child a, _Basic _b):
20 # cdef Child b = <Child>_b
23 # (3) @staticmethod for cdef methods?
25 # (4) nested cdef like in here:
27 # if ...:
28 # cdef _Basic a = ...
# Compile-time integer tags identifying the node kind; each node class
# stores its own tag in ``_Basic._type`` from ``__cinit__``.
DEF BASIC   = 0
DEF SYMBOL  = 1
DEF ADD     = 2
DEF MUL     = 3
DEF POW     = 4
DEF INTEGER = 5
cdef int hash_seq(args):
    """
    Combine the hashes of the items of ``args`` into a single value that
    *depends* on the order of the items.

    Starting from the seed 2, every item updates the running value ``acc``
    to ``hash((acc + 1001) ^ hash(item))``.
    """
    # TODO: make this more robust
    cdef int acc = 2
    for item in args:
        # "+" binds tighter than "^"; the parentheses only make that explicit
        acc = hash((acc + 1001) ^ hash(item))
    return acc
cdef class _Basic:
    """
    Base class of every expression node in this module.

    A node stores an integer type tag (``_type``), a cached hash (``hash``,
    where -1 means "not computed yet") and a tuple of arguments (``_args``).
    """
    cdef int _type
    cdef int hash
    cdef tuple _args    # XXX tuple -> list?

    def __cinit__(self):
        # -1 marks the cached hash as not yet computed
        self.hash = -1

    def __repr__(self):
        return str(self)

    def __hash__(self):
        # serve the cached value when present, otherwise compute it once
        if self.hash != -1:
            return self.hash
        self.hash = self._hash()
        return self.hash

    cdef int _hash(self):
        # default: order-dependent hash of the arguments
        return hash_seq(self._args)

    property args:
        # read-only view of the argument tuple
        def __get__(self):
            return self._args

    # for Basic.__new__
    def _set_rawargs(self, args):
        """Store ``args`` as the argument tuple without any processing."""
        self._args = args

    property type:
        # read-only view of the integer type tag
        def __get__(self):
            return self._type

    # XXX struct2
    cpdef as_coeff_rest(self):
        """Return ``(Integer(1), self)``: a generic node has coefficient 1."""
        return (Integer(1), self)

    cpdef as_base_exp(self):
        """Return ``(self, Integer(1))``: a generic node is its own base."""
        return (self, Integer(1))

    cpdef _Basic expand(self):
        """Return ``self``; subclasses override this where there is work."""
        return self

    # NOTE: there is no __rxxx__ methods in Cython/Pyrex

    def __add__(x, y):
        return Add((x, y))

    def __sub__(x, y):
        # subtraction is addition of the negated right operand
        negated = -y
        return Add((x, negated))

    def __mul__(x, y):
        return Mul((x, y))

    def __div__(x, y):
        # division is multiplication by y^-1
        inverse = Pow((y, Integer(-1)))
        return Mul((x, inverse))

    # FIXME we should get rid of z?
    def __pow__(x, y, z):
        # the modulus argument z is ignored
        return Pow((x, y))

    def __neg__(x):
        return Mul((Integer(-1), x))

    def __pos__(x):
        return x

    # in subclasses, you can be sure that _equal(a, b) is called with exactly
    # the same type, e.g.
    #
    # when _Add._equal(a, b) is called a and b are of ._type=ADD for sure
    cdef int _equal(_Basic self, _Basic o):
        # by default we compare ._args
        return self._args == o._args

    cdef bint equal(_Basic self, _Basic o):
        # only nodes carrying the same type tag are handed to the
        # type-specific ._equal; anything else is unequal
        if self._type == o._type:
            return self._equal(o)
        return 0

    def __richcmp__(_Basic x, y, int op):
        #print '__richcmp__ %s %s %i' % (x,y,op)
        y = sympify(y)

        if op == 2:
            # eq
            return x.equal(y)
        if op == 3:
            # ne
            return not x.equal(y)
        # every other comparison operator yields False
        return False
cpdef Integer(i):
    """Create an ``_Integer`` node wrapping the value ``i``."""
    node = _Integer(i)
    return node
cdef class _Integer(_Basic):
    """Integer node; the wrapped value is kept in ``i``."""
    cdef object i   # XXX object -> pyint?

    def __cinit__(self, i):
        self._type = INTEGER
        self.i = i

    cdef int _hash(self):
        # an integer node hashes like the value it wraps
        return hash(self.i)

    cdef int _equal(_Integer self, _Basic o):
        # per the _Basic._equal contract, o carries the same type tag
        cdef _Integer rhs = <_Integer>o
        return self.i == rhs.i

    def __str__(_Integer self):
        return str(self.i)

    def __repr__(_Integer self):
        return 'Integer(%i)' % self.i

    # there is no __radd__ in pyrex
    def __add__(_a, _b):
        cdef _Basic lhs = sympify(_a)
        cdef _Basic rhs = sympify(_b)
        # anything that is not Integer + Integer goes through the generic path
        if lhs._type != INTEGER or rhs._type != INTEGER:
            return _Basic.__add__(lhs, rhs)
        return Integer( (<_Integer>lhs).i + (<_Integer>rhs).i )

    # there is no __rmul__ in pyrex
    def __mul__(_a, _b):
        cdef _Basic lhs = sympify(_a)
        cdef _Basic rhs = sympify(_b)
        # anything that is not Integer * Integer goes through the generic path
        if lhs._type != INTEGER or rhs._type != INTEGER:
            return _Basic.__mul__(lhs, rhs)
        return Integer( (<_Integer>lhs).i * (<_Integer>rhs).i )
# Symbol.__new__
cpdef _Basic Symbol(name):
    """Create a ``_Symbol`` node called ``name``."""
    return _Symbol(name)
cdef class _Symbol(_Basic):
    """Symbol node; identity is determined by ``name``."""
    cdef object name    # XXX object -> str

    def __cinit__(self, name):
        self._type = SYMBOL
        self.name = name

    cdef int _hash(self):
        # a symbol hashes like its name
        return hash(self.name)

    cdef int _equal(_Symbol self, _Basic o):
        # per the _Basic._equal contract, o carries the same type tag
        cdef _Symbol rhs = <_Symbol>o
        #print 'Symbol._equal %s %s' % (self.name, rhs.name)
        return self.name == rhs.name

    def __str__(_Symbol self):
        return self.name

    def __repr__(_Symbol self):
        return 'Symbol(%s)' % self.name
# Add.__new__
cpdef _Basic Add(args):
    """Sympify every item of ``args`` and return their canonical sum."""
    sympified = [sympify(x) for x in args]
    return _Add_canonicalize(sympified)
# @staticmethod
cdef _Basic _Add_canonicalize(args):
    """
    Build the canonical sum of the already sympified ``args``.

    Integer terms are folded into one numeric term, nested sums are
    flattened one level, and all other terms are collected by their
    non-numeric part (see ``as_coeff_rest``) with the coefficients added up.

    Terms whose collected coefficient cancels to zero are dropped, so that
    e.g. ``x + y - x`` yields ``y`` rather than a sum containing
    ``Integer(0)``.

    Returns the numeric term alone if nothing else is left, the single
    remaining term if there is exactly one, and an ``_Add`` node otherwise.
    """
    # use_glib = 0
    # if use_glib:
    #     from csympy import HashTable
    #     d = HashTable()
    # else:
    #     d = {}

    cdef dict d = {}    # non-numeric part -> accumulated coefficient

    cdef _Basic a
    cdef _Basic b
    cdef _Integer num = Integer(0)

    cdef _Basic coeff
    cdef _Basic key
    cdef _Basic term

    for a in args:
        if a._type == INTEGER:
            num += a
        elif a._type == ADD:
            # flatten one level of nested sums
            for b in a._args:
                if b._type == INTEGER:
                    num += b
                else:
                    coeff, key = b.as_coeff_rest()
                    if key in d:
                        d[key] += coeff
                    else:
                        d[key] = coeff
        else:
            coeff, key = a.as_coeff_rest()
            if key in d:
                d[key] += coeff
            else:
                d[key] = coeff

    terms = []
    for key, coeff in d.iteritems():
        term = Mul((key, coeff))
        if term._type == INTEGER:
            # the coefficient cancelled (key * 0 -> 0): fold the purely
            # numeric result into num instead of keeping it as a term
            num += term
        else:
            terms.append(term)

    if num.i != 0:
        terms.insert(0, num)

    if len(terms) == 0:
        # nothing but a vanishing numeric part is left
        return num
    if len(terms) == 1:
        return terms[0]
    return _Add(terms)
cdef class _Add(_Basic):
    """
    Sum node.  Equality and hashing do not depend on the order of the
    arguments.
    """
    cdef object _args_set   # XXX object -> frozenset

    def __cinit__(_Add self, args):
        self._type = ADD
        self._args = tuple(args)

    def freeze_args(self):
        """Build the frozenset of the arguments once and keep it."""
        #print "add is freezing"
        if self._args_set is None:
            self._args_set = frozenset(self._args)
        #print "done"

    cdef int _equal(_Add self, _Basic o):
        # per the _Basic._equal contract, o carries the same type tag
        cdef _Add rhs = <_Add>o
        self.freeze_args()
        rhs.freeze_args()
        # order-independent comparison
        return self._args_set == rhs._args_set

    def __str__(_Basic self):
        cdef _Basic a = self._args[0]
        text = str(a)
        if a._type == ADD:
            text = "(%s)" % str(text)
        for a in self._args[1:]:
            text = "%s + %s" % (text, str(a))
            # NOTE(review): the nesting of this test inside the loop is
            # reconstructed from a whitespace-stripped source -- confirm.
            # As written it wraps the whole string built so far.
            if a._type == ADD:
                text = "(%s)" % text
        return text

    cdef int _hash(self):
        # XXX: it is surprising, but this is *not* faster:
        #self.freeze_args()
        #h = hash(self._args_set)

        # this is faster: order the arguments by their hash so the
        # combined hash does not depend on the argument order
        ordered = sorted(self._args, key=hash)
        return hash_seq(ordered)

    cpdef _Basic expand(self):
        """Expand every term and add the results back together."""
        expanded = []
        for term in self._args:
            expanded.append(term.expand())
        return Add(expanded)
cpdef _Basic Mul(args):
    """Sympify every item of ``args`` and return their canonical product."""
    sympified = [sympify(x) for x in args]
    return _Mul_canonicalize(sympified)
cdef _Basic _Mul_canonicalize(args):
    """
    Build the canonical product of the already sympified ``args``.

    Integer factors are folded into one numeric coefficient, nested
    products are flattened one level, and all other factors are collected
    by base (see ``as_base_exp``) with the exponents added up.

    Factors that collapse to a plain integer once their exponents are
    combined (e.g. ``x * x^-1`` -> 1) are folded into the numeric
    coefficient instead of being kept as separate factors, and a
    coefficient of 1 is never emitted -- this also holds for the
    ``(product, integer)`` fast path.

    Returns the coefficient alone if nothing else is left (or if it is 0),
    the single remaining factor if there is exactly one, and a ``_Mul``
    node otherwise.
    """
    # use_glib = 0
    # if use_glib:
    #     from csympy import HashTable
    #     d = HashTable()
    # else:
    #     d = {}

    cdef dict d = {}    # base -> accumulated exponent

    cdef _Basic a
    cdef _Basic b
    cdef _Basic head
    cdef _Basic term
    cdef _Integer num

    # fast path: (existing product) * (integer)
    if len(args) == 2 and args[0].type == MUL and args[1].type == INTEGER:
        a, b = args
        if (<_Integer>b).i == 1:
            return a
        if (<_Integer>b).i == 0:
            return b
        rest = a._args
        head = rest[0]
        if head._type != INTEGER:
            return _Mul((b,) + rest)
        # merge the new integer with the product's leading coefficient
        num = b * head
        rest = rest[1:]
        if num.i != 1:
            return _Mul((num,) + rest)
        # the coefficients cancelled to 1: do not keep it as a factor
        if len(rest) == 0:
            return num
        if len(rest) == 1:
            return rest[0]
        return _Mul(rest)

    num = Integer(1)

    for a in args:
        if a._type == INTEGER:
            num *= a
        elif a._type == MUL:
            # flatten one level of nested products
            for b in a._args:
                if b._type == INTEGER:
                    num *= b
                else:
                    key, coeff = b.as_base_exp()
                    if key in d:
                        d[key] += coeff
                    else:
                        d[key] = coeff
        else:
            key, coeff = a.as_base_exp()
            if key in d:
                d[key] += coeff
            else:
                d[key] = coeff

    if num.i == 0 or len(d) == 0:
        return num

    factors = []
    for key, coeff in d.iteritems():
        term = Pow((key, coeff))
        if term._type == INTEGER:
            # exponents cancelled (base^0 -> 1) or the power evaluated to a
            # plain integer: fold it into the numeric coefficient
            num *= term
        else:
            factors.append(term)

    if num.i == 0 or len(factors) == 0:
        return num
    if num.i != 1:
        factors.insert(0, num)
    if len(factors) == 1:
        return factors[0]
    return _Mul(factors)
# @staticmethod
cdef _Basic _Mul_expand_two(_Basic a, _Basic b):
    """
    Multiply ``a`` by ``b``, distributing over whichever of them is a sum.

    Both a and b are assumed to be expanded.
    """
    cdef _Basic x
    cdef _Basic y

    # neither operand is a sum: nothing to distribute
    if a._type != ADD and b._type != ADD:
        return a*b

    terms = []
    if a._type != ADD:
        # only b is a sum
        for y in b._args:
            terms.append(a*y)
    elif b._type != ADD:
        # only a is a sum
        for x in a._args:
            terms.append(x*b)
    else:
        # both are sums: every term of a times every term of b
        for x in a._args:
            for y in b._args:
                terms.append(x*y)
    return Add(terms)
cdef class _Mul(_Basic):
    """
    Product node.  Equality and hashing do not depend on the order of the
    arguments.
    """
    cdef object _args_set   # XXX object -> frozenset

    def __cinit__(self, args):
        self._type = MUL
        self._args = tuple(args)
        self._args_set = None

    cdef int _hash(self):
        # in contrast to Add, here it is faster:
        self.freeze_args()
        return hash(self._args_set)
        # this is slower:
        #a = list(self._args[:])
        #a.sort(key=hash)
        #h = hash_seq(a)
        #return h

    def freeze_args(self):
        """Build the frozenset of the arguments once and keep it."""
        #print "mul is freezing"
        if self._args_set is None:
            self._args_set = frozenset(self._args)
        #print "done"

    cdef int _equal(_Mul self, _Basic o):
        # per the _Basic._equal contract, o carries the same type tag
        cdef _Mul rhs = <_Mul>o
        self.freeze_args()
        rhs.freeze_args()
        # order-independent comparison
        return self._args_set == rhs._args_set

    cpdef as_coeff_rest(self):
        """
        Return ``(coefficient, rest)``: the leading argument and the
        remainder when the leading argument is an integer, otherwise
        ``(Integer(1), self)``.
        """
        cdef _Basic head = self._args[0]

        if head._type != INTEGER:
            return (Integer(1), self)
        return self.as_two_terms()

    # XXX struct2
    cpdef as_two_terms(_Mul self):
        """
        Return ``(first, rest)``: the first argument, and either the second
        argument (two-argument product) or a ``_Mul`` of all the others.
        """
        cdef _Basic first = self._args[0]
        if len(self._args) != 2:
            # XXX _Mul is ok here (like ._new_rawargs)
            return (first, _Mul(self._args[1:]))
        return first, self._args[1]

    def __str__(self):
        cdef _Basic a = self._args[0]
        text = str(a)
        if a._type == ADD or a._type == MUL:
            text = "(%s)" % str(text)
        for a in self._args[1:]:
            # sums and products are parenthesised as factors
            if a._type == ADD or a._type == MUL:
                text = "%s * (%s)" % (text, str(a))
            else:
                text = "%s*%s" % (text, str(a))
        return text

    cpdef _Basic expand(self):
        """
        Expand the product by splitting it with ``as_two_terms`` and
        distributing with ``_Mul_expand_two``.
        """
        cdef _Basic a
        cdef _Basic b
        cdef _Basic r
        a, b = self.as_two_terms()
        r = _Mul_expand_two(a, b)
        if not (r == self):
            # distributing changed something: keep expanding the result
            return r.expand()
        # nothing changed at this level: expand both parts, then distribute
        a = a.expand()
        b = b.expand()
        return _Mul_expand_two(a, b)
# Pow.__new__
cpdef _Basic Pow(args):
    """Sympify ``(base, exp)`` from ``args`` and return the canonical power."""
    sympified = [sympify(x) for x in args]
    return _Pow_canonicalize(sympified)
# @staticmethod
cdef _Basic _Pow_canonicalize(args):
    """
    Build the canonical power from ``args == (base, exp)``.

    Rules, applied in this order: an integer base 0 or 1 is returned
    unchanged; an integer exponent 0 gives ``Integer(1)`` and an integer
    exponent 1 gives the base; a base that is itself a power has its
    exponent multiplied by ``exp``; anything else becomes a ``_Pow`` node.
    """
    cdef _Basic base
    cdef _Basic exp
    cdef _Integer b
    cdef _Integer e

    base, exp = args

    # unchecked casts: b.i / e.i are only read after the matching
    # _type == INTEGER test below
    b = <_Integer>base
    e = <_Integer>exp

    if base._type == INTEGER and (b.i == 0 or b.i == 1):
        return b    # Integer(0) or Integer(1)

    if exp._type == INTEGER:
        if e.i == 0:
            return Integer(1)
        if e.i == 1:
            return base

    if base._type == POW:
        # power of a power: multiply the exponents
        inner_base = base._args[0]
        inner_exp = base._args[1]
        return Pow((inner_base, inner_exp*exp))

    return _Pow(args)
cdef class _Pow(_Basic):
    """Power node; ``_args`` is ``(base, exponent)``."""

    def __cinit__(self, args):
        self._type = POW
        self._args = tuple(args)

    def __str__(_Pow self):
        cdef _Basic b = self._args[0]
        cdef _Basic e = self._args[1]
        s = str(b)
        # a sum is parenthesised, both as base and as exponent
        if b._type == ADD:
            s = "(%s)" % s

        if e._type == ADD:
            s = "%s^(%s)" % (s, str(e))
        else:
            s = "%s^%s" % (s, str(e))
        return s

    # XXX struct 2
    cpdef as_base_exp(_Pow self):
        """Return the ``(base, exponent)`` tuple."""
        return self._args

    cpdef _Basic expand(_Pow self):
        """
        Expand ``(sum)^n`` for a positive integer ``n`` with the multinomial
        theorem; every other power is returned unchanged.

        Non-positive integer exponents are deliberately left alone: the
        multinomial expansion is only valid for positive ``n`` (applying
        it to ``(a+b)^-2`` would yield ``a^-2 + b^-2``).
        """
        cdef _Basic _base = self._args[0]
        cdef _Basic _exp = self._args[1]

        cdef _Add base
        cdef _Integer exp

        cdef int p

        cdef list r
        cdef list t
        cdef _Basic term
        cdef _Basic ret
        cdef _Basic tt

        if _base._type != ADD or _exp._type != INTEGER:
            return self

        # the casts are safe now that both type tags have been checked
        base = <_Add>_base
        exp = <_Integer>_exp

        n = exp.i
        if n <= 0:
            # multinomial expansion does not apply; keep the power as is
            return self

        m = len(base._args)
        #print "multi"
        d = multinomial_coefficients(m, n)
        #print "assembly"
        r = []
        for powers, coeff in d.items():
            # one term per exponent tuple: coeff * prod(x_i ^ p_i)
            if coeff == 1:
                t = []
            else:
                t = [Integer(coeff)]
            for x, p in zip(base._args, powers):
                if p != 0:
                    if p == 1:
                        tt = x
                    else:
                        tt = _Pow_canonicalize((x, Integer(p)))
                    t.append(tt)
                    #t.append(_Pow((x, Integer(p)))) # XXX _Pow -> Pow
            assert len(t) != 0
            if len(t) == 1:
                term = t[0]
            else:
                term = _Mul(t)
            r.append(term)

        ret = _Add(r)
        #print "done"
        return ret
cpdef _Basic sympify(x):
    """Wrap a Python ``int`` in an Integer node; return anything else as is."""
    if not isinstance(x, int):
        return x
    return Integer(x)
def binomial_coefficients(n):
    """Return a dictionary containing pairs {(k1,k2) : C_kn} where
    C_kn are binomial coefficients and n=k1+k2."""
    d = {(0, n): 1, (n, 0): 1}
    a = 1
    # build C(n, k) from C(n, k-1) and fill both symmetric entries at once;
    # range() works on Python 2 and 3 alike (xrange is Python 2 only)
    for k in range(1, n//2 + 1):
        a = (a * (n - k + 1)) // k
        d[k, n - k] = d[n - k, k] = a
    return d
def binomial_coefficients_list(n):
    """Return row ``n`` of Pascal's triangle as a list, i.e. the binomial
    coefficients ``[C(n, 0), C(n, 1), ..., C(n, n)]``.
    """
    d = [1] * (n + 1)
    a = 1
    # build C(n, k) from C(n, k-1) and fill both symmetric entries at once;
    # range() works on Python 2 and 3 alike (xrange is Python 2 only)
    for k in range(1, n//2 + 1):
        a = (a * (n - k + 1)) // k
        d[k] = d[n - k] = a
    return d
def multinomial_coefficients(m, n, _tuple=tuple, _zip=zip):
    """Return a dictionary containing pairs ``{(k1,k2,..,km) : C_kn}``
    where ``C_kn`` are multinomial coefficients such that
    ``n=k1+k2+..+km``.

    For example:

    >>> sorted(multinomial_coefficients(2, 5).items())
    [((0, 5), 1), ((1, 4), 5), ((2, 3), 10), ((3, 2), 10), ((4, 1), 5), ((5, 0), 1)]

    The algorithm is based on the following result:

       Consider a polynomial and its ``n``-th power::

         P(x) = sum_{i=0}^m p_i x^i
         P(x)^n = sum_{k=0}^{m n} a(n,k) x^k

       The coefficients ``a(n,k)`` can be computed using the
       J.C.P. Miller Pure Recurrence [see D.E.Knuth, Seminumerical
       Algorithms, The art of Computer Programming v.2, Addison
       Wesley, Reading, 1981;]::

         a(n,k) = 1/(k p_0) sum_{i=1}^m p_i ((n+1)i-k) a(n,k-i),

       where ``a(n,0) = p_0^n``.

    ``_tuple`` and ``_zip`` are bound as defaults only to make them local
    names inside the loops; callers are not expected to pass them.
    """
    if m == 2:
        return binomial_coefficients(n)

    # unit exponent vectors e_0 .. e_{m-1}
    symbols = [(0,)*i + (1,) + (0,)*(m-i-1) for i in range(m)]
    s0 = symbols[0]
    # p0[i] = e_i - e_0: the exponent shift that moves one power from the
    # first variable to the i-th one
    p0 = [_tuple([aa - bb for aa, bb in _zip(s, s0)]) for s in symbols]
    r = {_tuple([aa*n for aa in s0]): 1}
    r_update = r.update
    # l[k] holds the (exponent tuple, coefficient) pairs produced at step k
    l = [0] * (n*(m-1) + 1)
    # take a snapshot: on Python 3 r.items() is a live view that would grow
    # as r is updated below
    l[0] = list(r.items())
    for k in range(1, n*(m-1) + 1):
        d = {}
        d_get = d.get
        for i in range(1, min(m, k+1)):
            nn = (n+1)*i - k
            if not nn:
                continue
            t = p0[i]
            for t2, c2 in l[k-i]:
                tt = _tuple([aa + bb for aa, bb in _zip(t2, t)])
                cc = nn * c2
                b = d_get(tt)
                if b is None:
                    d[tt] = cc
                else:
                    cc = b + cc
                    if cc:
                        d[tt] = cc
                    else:
                        # contributions cancelled exactly; drop the entry
                        del d[tt]
        r1 = [(t, c//k) for (t, c) in d.items()]
        l[k] = r1
        r_update(r1)
    return r