# Add a special case to Mul.canonize that speeds things up.
# [sympyx.git] / sympy_pyx.pyx
# blob 428bbb97fd4b1074a6c732d87efca54838b8a7d9
# (1) Cython does not support __new__, so each class below gets a
#     module-level constructor function instead (Add, Mul, Pow, ...).
#
# (2) what to do if we want
#
# cdef class Base:
#     cdef virt_func(Base a, Base b):
#         # here we ensure that a & b are of the same type
#         ...
#
# cdef class Child(Base):
#     cdef virt_func(Child a, Child b):
#         ...
#
# ?
#
# currently we have to accept the base type and cast:
#
# cdef class Child(Base):
#     cdef virt_func(Child a, Base _b):
#         cdef Child b = <Child>_b
#
# (3) @staticmethod for cdef methods?
#
# (4) nested cdef like in here:
#
# if ...:
#     cdef _Basic a = ...
# Type tags stored in _Basic._type.  _Basic.equal() compares these before
# dispatching to the type-specific _equal().
DEF BASIC = 0
DEF SYMBOL = 1
DEF ADD = 2
DEF MUL = 3
DEF POW = 4
DEF INTEGER = 5
cdef int hash_seq(args):
    """
    Order-sensitive hash of a sequence: the same elements in a different
    order generally give a different value.  The result is a C int.
    """
    # make this more robust:
    cdef int h = 2
    for item in args:
        # + binds tighter than ^, so this is ((h + 1001) ^ hash(item))
        h = hash((h + 1001) ^ hash(item))
    return h
cdef class _Basic:
    """Common base class of all expression nodes.

    ``_type`` holds one of the DEF type tags (SYMBOL, ADD, ...).  Compound
    nodes (Add, Mul, Pow) keep their operands in the ``_args`` tuple.
    """
    cdef int _type      # type tag; compared first by equal()
    cdef int hash       # cached hash value; -1 means "not computed yet"
    cdef tuple _args    # XXX tuple -> list?

    def __cinit__(self):
        self.hash = -1

    def __repr__(self):
        return str(self)

    def __hash__(self):
        # Computed lazily through the (overridable) _hash() and cached.
        if self.hash == -1:
            self.hash = self._hash()
        return self.hash

    cdef int _hash(self):
        # Default: order-sensitive hash of the operands.
        return hash_seq(self._args)

    property args:
        def __get__(self):
            return self._args

    # for Basic.__new__
    def _set_rawargs(self, args):
        self._args = args

    property type:
        def __get__(self):
            return self._type

    # XXX struct2
    cpdef as_coeff_rest(self):
        # (numeric coefficient, rest); _Mul overrides this.
        return (Integer(1), self)

    cpdef as_base_exp(self):
        # (base, exponent); _Pow overrides this.
        return (self, Integer(1))

    cpdef _Basic expand(self):
        return self

    # NOTE: there are no __rxxx__ methods in Cython/Pyrex.  The arithmetic
    # methods below may therefore see a plain int as either operand
    # (e.g. 2 + x), so x and y are left untyped and Add/Mul/Pow sympify them.

    def __add__(x, y):
        return Add((x, y))

    def __sub__(x, y):
        return Add((x, -y))

    def __mul__(x, y):
        return Mul((x, y))

    def __div__(x, y):
        return Mul((x, Pow((y, Integer(-1)))))

    # Python 3, and Python 2 with `from __future__ import division`, call
    # __truediv__ instead of __div__.  Without it `x / y` raises TypeError.
    def __truediv__(x, y):
        return Mul((x, Pow((y, Integer(-1)))))

    # FIXME we should get rid of z?
    def __pow__(x, y, z):
        return Pow((x, y))

    def __neg__(x):
        return Mul((Integer(-1), x))

    def __pos__(x):
        return x

    # In subclasses you can be sure that _equal(a, b) is called with exactly
    # the same type.  For example, when _Add._equal(a, b) is called, a and b
    # are of ._type == ADD for sure.
    cdef int _equal(_Basic self, _Basic o):
        # by default we compare ._args
        return self._args == o._args

    cdef bint equal(_Basic self, _Basic o):
        # o may be None, because sympify(None) returns None.  Cython does not
        # None-check cdef attribute access, so reading o._type would treat
        # Py_None as a _Basic.  Report "not equal" instead.
        if o is None:
            return 0
        if self._type != o._type:
            return 0
        # now we know self and o are of the same type, lets dispatch to their
        # type's ._equal
        return self._equal(o)

    def __richcmp__(_Basic x, y, int op):
        y = sympify(y)
        if op == 2:        # ==
            return x.equal(y)
        elif op == 3:      # !=
            return not x.equal(y)
        else:
            # Ordering comparisons (<, <=, >, >=) are not supported for
            # expressions; they evaluate to False.
            return False
cpdef Integer(i):
    """Create an integer expression (an _Integer node) wrapping ``i``."""
    return _Integer(i)
cdef class _Integer(_Basic):
    """Integer literal wrapping the Python integer ``i``."""
    cdef object i  # XXX object -> pyint?

    def __cinit__(self, i):
        self.i = i
        self._type = INTEGER

    cdef int _hash(self):
        # hash like the wrapped Python integer
        return hash(self.i)

    cdef int _equal(_Integer self, _Basic o):
        # equal() has already checked that o is an INTEGER node
        return self.i == (<_Integer>o).i

    def __str__(_Integer self):
        return str(self.i)

    def __repr__(_Integer self):
        return 'Integer(%i)' % self.i

    # There is no __radd__ / __rmul__ in pyrex, so either operand may be a
    # plain int.  Both operands are sympified; two integer literals are folded
    # directly, and anything else goes through the generic _Basic operator.
    def __add__(_a, _b):
        cdef _Basic lhs = sympify(_a)
        cdef _Basic rhs = sympify(_b)
        if lhs._type != INTEGER or rhs._type != INTEGER:
            return _Basic.__add__(lhs, rhs)
        return Integer((<_Integer>lhs).i + (<_Integer>rhs).i)

    def __mul__(_a, _b):
        cdef _Basic lhs = sympify(_a)
        cdef _Basic rhs = sympify(_b)
        if lhs._type != INTEGER or rhs._type != INTEGER:
            return _Basic.__mul__(lhs, rhs)
        return Integer((<_Integer>lhs).i * (<_Integer>rhs).i)
# Symbol.__new__
cpdef _Basic Symbol(name):
    """Create a symbol expression called ``name``."""
    return _Symbol(name)
cdef class _Symbol(_Basic):
    """Named symbol.

    Two symbols compare equal exactly when their names compare equal.
    """
    cdef object name  # XXX object -> str

    def __cinit__(self, name):
        self.name = name
        self._type = SYMBOL

    cdef int _hash(self):
        return hash(self.name)

    cdef int _equal(_Symbol self, _Basic o):
        # equal() has already checked that o is a SYMBOL node
        return self.name == (<_Symbol>o).name

    def __str__(_Symbol self):
        return self.name

    def __repr__(_Symbol self):
        return 'Symbol(%s)' % self.name
# Add.__new__
cpdef _Basic Add(args):
    """Return the canonical sum of ``args`` (plain ints are sympified)."""
    cdef list terms = [sympify(term) for term in args]
    return _Add_canonicalize(terms)
# @staticmethod
cdef _Basic _Add_canonicalize(args):
    """Build a canonical sum from a list of already-sympified terms.

    Nested sums are flattened, and integer terms are added into a single
    numeric term.  Terms with the same "rest" (from as_coeff_rest) have
    their coefficients added.  Terms whose coefficients cancel to 0 are
    dropped, so x + y - x gives y rather than a sum with a stray 0 term.
    """
    cdef dict d = {}  # rest -> accumulated coefficient
    cdef _Basic a
    cdef _Basic b
    cdef _Integer num = Integer(0)
    cdef _Basic coeff
    cdef _Basic key

    for a in args:
        if a._type == INTEGER:
            num += a
        elif a._type == ADD:
            # flatten nested sums
            for b in a._args:
                if b._type == INTEGER:
                    num += b
                else:
                    coeff, key = b.as_coeff_rest()
                    if key in d:
                        d[key] += coeff
                    else:
                        d[key] = coeff
        else:
            coeff, key = a.as_coeff_rest()
            if key in d:
                d[key] += coeff
            else:
                d[key] = coeff

    args = []
    for key, coeff in d.iteritems():
        if coeff._type == INTEGER and (<_Integer>coeff).i == 0:
            # Like terms cancelled (e.g. x - x).  Mul((key, 0)) would only
            # contribute a stray Integer(0) term to the sum.
            continue
        args.append(Mul((key, coeff)))
    if num.i != 0:
        args.insert(0, num)
    if len(args) == 0:
        return num
    if len(args) == 1:
        return args[0]
    return _Add(args)
cdef class _Add(_Basic):
    """Sum node; ``_args`` holds the terms (their order is not significant)."""
    cdef object _args_set  # XXX object -> frozenset; filled by freeze_args()

    def __cinit__(_Add self, args):
        self._type = ADD
        self._args = tuple(args)
        self._args_set = None

    def freeze_args(self):
        """Cache the terms as a frozenset for order-insensitive equality."""
        if self._args_set is None:
            self._args_set = frozenset(self._args)

    cdef int _equal(_Add self, _Basic o):
        cdef _Add other = <_Add>o
        self.freeze_args()
        other.freeze_args()
        return self._args_set == other._args_set

    def __str__(_Basic self):
        cdef _Basic a = self._args[0]
        s = str(a)
        if a._type == ADD:
            s = "(%s)" % s
        for a in self._args[1:]:
            # Parenthesize only the nested sum itself, not everything
            # printed so far.
            if a._type == ADD:
                s = "%s + (%s)" % (s, str(a))
            else:
                s = "%s + %s" % (s, str(a))
        return s

    cdef int _hash(self):
        # Must agree with _equal, which ignores term order.  Sorting the
        # term *hashes* (rather than the terms by hash) gives the same input
        # to hash_seq for every ordering, even when two distinct terms share
        # a hash value.
        # XXX: it is surprising, but hashing frozenset(self._args) is *not*
        # faster.
        return hash_seq(sorted([hash(x) for x in self._args]))

    cpdef _Basic expand(self):
        return Add([term.expand() for term in self._args])
# Mul.__new__
cpdef _Basic Mul(args):
    """Return the canonical product of ``args`` (plain ints are sympified)."""
    cdef list factors = [sympify(factor) for factor in args]
    return _Mul_canonicalize(factors)
cdef _Basic _Mul_canonicalize(args):
    """Build a canonical product from a list of already-sympified factors.

    Nested products are flattened, and integer factors are multiplied into a
    single coefficient.  Factors with the same base (from as_base_exp) have
    their exponents added.  Factors whose exponents cancel to 0 are dropped,
    so x*y/x gives y rather than a product with a stray 1 factor.
    """
    cdef dict d = {}  # base -> accumulated exponent
    cdef _Basic a
    cdef _Basic b
    cdef _Basic key
    cdef _Basic coeff
    cdef _Integer num = Integer(1)

    for a in args:
        if a._type == INTEGER:
            num *= a
        elif a._type == MUL:
            # flatten nested products
            for b in a._args:
                if b._type == INTEGER:
                    num *= b
                else:
                    key, coeff = b.as_base_exp()
                    if key in d:
                        d[key] += coeff
                    else:
                        d[key] = coeff
        else:
            key, coeff = a.as_base_exp()
            if key in d:
                d[key] += coeff
            else:
                d[key] = coeff

    # A zero coefficient annihilates the whole product.
    if num.i == 0 or len(d) == 0:
        return num
    args = []
    for key, coeff in d.iteritems():
        if coeff._type == INTEGER and (<_Integer>coeff).i == 0:
            # Exponents cancelled (e.g. x/x).  Pow((key, 0)) would only
            # contribute a stray Integer(1) factor.
            continue
        args.append(Pow((key, coeff)))
    if num.i != 1:
        args.insert(0, num)
    if len(args) == 0:
        return num
    if len(args) == 1:
        return args[0]
    return _Mul(args)
# @staticmethod
cdef _Basic _Mul_expand_two(_Basic a, _Basic b):
    """
    Return a*b with the product distributed over a and/or b if they are sums.

    Both a and b are assumed to be expanded.
    """
    cdef _Basic x
    cdef _Basic y

    if a._type == ADD and b._type == ADD:
        terms = []
        for x in a._args:
            for y in b._args:
                terms.append(x*y)
        return Add(terms)
    if a._type == ADD:
        terms = []
        for x in a._args:
            terms.append(x*b)
        return Add(terms)
    if b._type == ADD:
        terms = []
        for y in b._args:
            terms.append(a*y)
        return Add(terms)
    return a*b
cdef class _Mul(_Basic):
    """Product node; ``_args`` holds the factors.

    Canonical products (see _Mul_canonicalize) put their Integer
    coefficient, if any, first.
    """
    cdef object _args_set  # XXX object -> frozenset

    def __cinit__(self, args):
        self._type = MUL
        self._args = tuple(args)
        self._args_set = None

    def freeze_args(self):
        # Cache the factors as a frozenset (order-insensitive), once.
        if self._args_set is not None:
            return
        self._args_set = frozenset(self._args)

    cdef int _hash(self):
        # Hash the order-insensitive factor set, consistently with _equal.
        # In contrast to Add, this is the faster option here; sorting the
        # factors by hash and using hash_seq is slower.
        self.freeze_args()
        return hash(self._args_set)

    cdef int _equal(_Mul self, _Basic o):
        cdef _Mul rhs = <_Mul>o
        self.freeze_args()
        rhs.freeze_args()
        return self._args_set == rhs._args_set

    # XXX struct2
    cpdef as_coeff_rest(self):
        cdef _Basic first = self._args[0]
        if first._type != INTEGER:
            return (Integer(1), self)
        return self.as_two_terms()

    # XXX struct2
    cpdef as_two_terms(_Mul self):
        # Split into (first factor, product of the remaining factors).
        cdef _Basic head = self._args[0]
        if len(self._args) != 2:
            # XXX _Mul is ok here (like ._new_rawargs)
            return (head, _Mul(self._args[1:]))
        return (head, self._args[1])

    def __str__(self):
        cdef _Basic f = self._args[0]
        if f._type in [ADD, MUL]:
            s = "(%s)" % str(f)
        else:
            s = str(f)
        for f in self._args[1:]:
            if f._type in [ADD, MUL]:
                s = "%s * (%s)" % (s, str(f))
            else:
                s = "%s*%s" % (s, str(f))
        return s

    cpdef _Basic expand(self):
        cdef _Basic a
        cdef _Basic b
        cdef _Basic r
        a, b = self.as_two_terms()
        r = _Mul_expand_two(a, b)
        if r != self:
            # something was distributed; keep expanding the result
            return r.expand()
        # nothing to distribute at the top level: expand both halves first
        return _Mul_expand_two(a.expand(), b.expand())
# Pow.__new__
cpdef _Basic Pow(args):
    """Return the canonical power for the (base, exp) pair ``args``."""
    cdef list pair = [sympify(arg) for arg in args]
    return _Pow_canonicalize(pair)
# @staticmethod
cdef _Basic _Pow_canonicalize(args):
    """Build a canonical power from a (base, exp) pair of sympified args.

    Simplifications: x**0 -> 1, x**1 -> x, 0**e -> 0, 1**e -> 1 and
    (b**e1)**e2 -> b**(e1*e2).  The exponent rules are checked before the
    base rules so that 0**0 gives 1 rather than 0.
    """
    cdef _Basic base
    cdef _Basic exp
    base, exp = args

    if exp._type == INTEGER:
        if (<_Integer>exp).i == 0:
            return Integer(1)
        if (<_Integer>exp).i == 1:
            return base
    if base._type == INTEGER:
        # NOTE(review): 0**e is returned as 0 even for negative or symbolic
        # e; there is no representation for an infinite result here.
        if (<_Integer>base).i == 0 or (<_Integer>base).i == 1:
            return base
    if base._type == POW:
        # NOTE(review): (b**e1)**e2 == b**(e1*e2) only holds in general for
        # integer e2; kept as the existing simplification.
        return Pow((base._args[0], base._args[1]*exp))
    return _Pow(args)
cdef class _Pow(_Basic):
    """Power node; ``_args`` is the (base, exponent) pair."""

    def __cinit__(self, args):
        self._type = POW
        self._args = tuple(args)

    def __str__(_Pow self):
        cdef _Basic b = self._args[0]
        cdef _Basic e = self._args[1]
        s = str(b)
        # Parenthesize compound operands.  Without this, (x*y)**2 printed as
        # "x*y^2" and x**(2*y) as "x^2*y".
        if b._type in [ADD, MUL]:
            s = "(%s)" % s
        if e._type in [ADD, MUL]:
            s = "%s^(%s)" % (s, str(e))
        else:
            s = "%s^%s" % (s, str(e))
        return s

    # XXX struct 2
    cpdef as_base_exp(_Pow self):
        return self._args

    cpdef _Basic expand(_Pow self):
        """Expand (t1 + ... + tm)**n for an integer n > 0 using the
        multinomial theorem; any other power is returned unchanged.

        Only this power is expanded; the terms t_i themselves are not.
        """
        cdef _Basic _base = self._args[0]
        cdef _Basic _exp = self._args[1]
        cdef _Add base
        cdef _Basic x
        cdef _Basic term
        cdef int p
        cdef bint fast
        cdef list r
        cdef list t

        if _base._type != ADD or _exp._type != INTEGER:
            return self
        n = (<_Integer>_exp).i
        # The multinomial expansion is only valid for positive n:
        # (x + y)**-1 must not become x**-1 + y**-1.
        if n <= 0:
            return self
        base = <_Add>_base

        # If every term is a Symbol, each product below is already a
        # canonical monomial, and distinct power tuples give distinct
        # monomials.  The raw _Mul/_Add constructors are then safe (and fast).
        # Integer, product or power terms, as in (x + 1)**2 or (x + x**2)**2,
        # would otherwise leave factors like 1 or x*x**2 and duplicate
        # monomials, so those cases go through the canonical Mul/Add.
        fast = True
        for x in base._args:
            if x._type != SYMBOL:
                fast = False
                break

        d = multinomial_coefficients(len(base._args), n)
        r = []
        for powers, coeff in d.iteritems():
            if coeff == 1:
                t = []
            else:
                t = [Integer(coeff)]
            for x, p in zip(base._args, powers):
                if p != 0:
                    t.append(Pow((x, p)))
            # n > 0, so at least one power is non-zero and t is non-empty.
            if len(t) == 1:
                term = t[0]
            elif fast:
                term = _Mul(t)
            else:
                term = Mul(t)
            r.append(term)
        if fast:
            return _Add(r)
        return Add(r)
cpdef _Basic sympify(x):
    """Convert ``x`` into an expression.

    Python integers (``int`` and, on Python 2, ``long``) become Integer
    nodes.  Any other value is returned unchanged, and the ``_Basic`` return
    type makes Cython type-check it (None passes through).  Previously a
    Python 2 ``long`` fell through to that check and raised TypeError.
    """
    if isinstance(x, (int, long)):
        return Integer(x)
    return x
def binomial_coefficients(n):
    """Return ``{(k1, k2): C(n, k1)}`` for every pair with ``k1 + k2 == n``.

    These are the coefficients of (a + b)**n, keyed by the exponent pair.
    """
    coeffs = {(0, n): 1, (n, 0): 1}
    c = 1
    for k in range(1, n // 2 + 1):
        # C(n, k) from C(n, k-1); the division is exact.
        c = c * (n - k + 1) // k
        coeffs[(k, n - k)] = c
        coeffs[(n - k, k)] = c
    return coeffs
def binomial_coefficients_list(n):
    """Return row ``n`` of Pascal's triangle as a list.

    The result is ``[C(n, 0), C(n, 1), ..., C(n, n)]``; for example
    ``binomial_coefficients_list(4) == [1, 4, 6, 4, 1]``.
    """
    d = [1] * (n + 1)
    a = 1
    for k in range(1, n // 2 + 1):
        # C(n, k) from C(n, k-1); the division is exact.  Fill both
        # symmetric positions at once.
        a = (a * (n - k + 1)) // k
        d[k] = d[n - k] = a
    return d
def multinomial_coefficients(m, n, _tuple=tuple, _zip=zip):
    """Return a dictionary containing pairs ``{(k1,k2,..,km) : C_kn}``
    where ``C_kn`` are multinomial coefficients such that
    ``n=k1+k2+..+km``.

    For example:

    >>> sorted(multinomial_coefficients(2, 5).items())
    [((0, 5), 1), ((1, 4), 5), ((2, 3), 10), ((3, 2), 10), ((4, 1), 5), ((5, 0), 1)]

    The algorithm is based on the following result.

    Consider a polynomial and its ``n``-th power::

        P(x) = sum_{i=0}^m p_i x^i
        P(x)^n = sum_{k=0}^{m n} a(n,k) x^k

    The coefficients ``a(n,k)`` can be computed using the
    J.C.P. Miller Pure Recurrence [see D.E.Knuth, Seminumerical
    Algorithms, The art of Computer Programming v.2, Addison
    Wesley, Reading, 1981;]::

        a(n,k) = 1/(k p_0) sum_{i=1}^m p_i ((n+1)i-k) a(n,k-i),

    where ``a(n,0) = p_0^n``.

    ``_tuple`` and ``_zip`` are bound as defaults only for fast local lookup.
    """
    if m == 2:
        return binomial_coefficients(n)
    # Monomials are tracked as exponent tuples.  symbols[i] is the unit
    # vector of variable i, and p0[i] moves one unit of exponent from
    # variable 0 to variable i.
    symbols = [(0,)*i + (1,) + (0,)*(m-i-1) for i in range(m)]
    s0 = symbols[0]
    p0 = [_tuple([aa-bb for aa, bb in _zip(s, s0)]) for s in symbols]
    # a(n, 0): all n units on variable 0, i.e. (n, 0, ..., 0).
    r = {_tuple([aa*n for aa in s0]): 1}
    r_update = r.update
    l = [0] * (n*(m-1)+1)
    # Take a list snapshot.  On Python 3, r.items() is a live view and r is
    # updated below, which would corrupt l[0] on later iterations.
    l[0] = list(r.items())
    for k in range(1, n*(m-1)+1):
        d = {}
        d_get = d.get
        for i in range(1, min(m, k+1)):
            nn = (n+1)*i-k
            if not nn:
                continue
            t = p0[i]
            for t2, c2 in l[k-i]:
                tt = _tuple([aa+bb for aa, bb in _zip(t2, t)])
                cc = nn * c2
                b = d_get(tt)
                if b is None:
                    d[tt] = cc
                else:
                    cc = b + cc
                    if cc:
                        d[tt] = cc
                    else:
                        del d[tt]
        # the 1/k factor of the recurrence
        r1 = [(t, c//k) for (t, c) in d.items()]
        l[k] = r1
        r_update(r1)
    return r