Add: optimize collect of poly terms (2*x*y + 3*x*z) + general speedup
[sympy.git] / sympy / core / mul.py
blob 6c17614ddb78b4c8d5efa9ea429de31f14be8a93
2 from basic import Basic, S, C, sympify
3 from operations import AssocOp
4 from cache import cacheit
6 from logic import fuzzy_not
8 from symbol import Symbol, Wild
9 # from function import FunctionClass, WildFunction /cyclic/
10 # from numbers import Number, Integer, Real /cyclic/
11 # from add import Add /cyclic/
12 # from power import Pow /cyclic/
14 import sympy.mpmath as mpmath
class Mul(AssocOp):
    """Product of expressions: ``Mul(a, b, c)`` represents ``a*b*c``.

    Canonicalization (coefficient collection, power merging, ordering of the
    commutative part) is done by ``flatten``; the remaining methods answer
    assumption queries and implement expansion, differentiation,
    substitution and series helpers for products.
    """

    __slots__ = []

    is_Mul = True

    @classmethod
    def flatten(cls, seq):
        """Canonicalize the factors of a product.

        ``seq`` is the sequence of (already sympified) factors. It is copied,
        never mutated.

        Returns ``(c_part, nc_part, order_symbols)``:

        - ``c_part``: sorted list of commutative factors, with the numeric
          coefficient (if not 1) in front;
        - ``nc_part``: non-commutative factors in their original order, with
          adjacent powers of an equal base merged;
        - ``order_symbols``: symbols collected from any ``O(...)`` factor, or
          None.
        """
        # apply associativity, separate commutative part of seq
        c_part = []         # out: commutative factors
        nc_part = []        # out: non-commutative factors

        c_seq = []          # commutative factors still to be processed
        # Copy instead of aliasing: the loop below pops from nc_seq, and the
        # caller's sequence must not be consumed (this also accepts tuples).
        nc_seq = list(seq)

        coeff = S.One       # standalone term
                            # e.g. 3 * ...

        c_powers = {}       # base -> exp      z
                            # e.g. (x+y) -> z  for ... * (x+y) * ...

        exp_dict = {}       # num-base -> exp  y
                            # e.g.  3 -> y  for ... * 3  * ...

        inv_exp_dict = {}   # exp -> Mul(num-bases)   x    x
                            # e.g.  x -> 6  for ... * 2  * 3  * ...

        order_symbols = None

        # --- PART 1: scan the factors ---
        # Everything starts in nc_seq; commutative factors are moved over to
        # c_seq and processed with priority.
        while c_seq or nc_seq:

            # COMMUTATIVE
            if c_seq:
                # first process commutative objects
                o = c_seq.pop(0)

                # O(x)
                if o.is_Order:
                    o, order_symbols = o.as_expr_symbols(order_symbols)

                # Mul([...])
                if o.is_Mul:
                    # associativity
                    c_seq = list(o.args[:]) + c_seq
                    continue

                # 3
                if o.is_Number:
                    coeff *= o
                    continue

                #  y
                # x
                if o.is_Pow:
                    base, exponent = o.as_base_exp()

                    #  y
                    # 3
                    if base.is_Number:
                        # let's collect factors with numeric base
                        if base in exp_dict:
                            exp_dict[base] += exponent
                        else:
                            exp_dict[base] = exponent
                        continue

                # exp(x)
                if o.func is C.exp:
                    # exp(x) / exp(y) -> exp(x-y)
                    b = S.Exp1
                    e = o.args[0]

                # everything else
                else:
                    b, e = o.as_base_exp()

                # now we have
                # o = b**e

                #         n          n          n
                # (-3 + y)   ->  (-1)  * (3 - y)
                if b.is_Add and e.is_Number:
                    # found factor (x+y)**number; split off initial coefficient
                    c, t = b.as_coeff_terms()
                    # last time I checked, Add.as_coeff_terms returns One or
                    # NegativeOne, but this might change
                    if c.is_negative and not e.is_integer:
                        # extracting root from negative number: ignore sign
                        if c is not S.NegativeOne:
                            # make c positive (probably never occurs)
                            coeff *= (-c) ** e
                            assert len(t) == 1, repr(t)
                            b = -t[0]
                        # else: ignoring sign from NegativeOne: nothing to do!
                    elif c is not S.One:
                        coeff *= c ** e
                        assert len(t) == 1, repr(t)
                        b = t[0]
                    # else: c is One, so pass

                # let's collect factors with the same base, so e.g.
                #  y    z     y+z
                # x  * x  -> x
                if b in c_powers:
                    c_powers[b] += e
                else:
                    c_powers[b] = e

            # NON-COMMUTATIVE
            else:
                o = nc_seq.pop(0)
                if isinstance(o, WildFunction):
                    pass
                elif o.is_Order:
                    o, order_symbols = o.as_expr_symbols(order_symbols)

                # -> commutative
                if o.is_commutative:
                    # separate commutative symbols
                    c_seq.append(o)
                    continue

                # Mul([...])
                if o.__class__ is cls:
                    # associativity
                    nc_seq = list(o.args) + nc_seq
                    continue
                if not nc_part:
                    nc_part.append(o)
                    continue

                #                             b    c       b+c
                # try to combine last terms: a  * a   ->  a
                o1 = nc_part.pop()
                b1, e1 = o1.as_base_exp()
                b2, e2 = o.as_base_exp()
                if b1 == b2:
                    # re-queue the merged power so it can combine further
                    nc_seq.insert(0, b1 ** (e1 + e2))
                else:
                    nc_part.append(o1)
                    nc_part.append(o)

        # --- PART 2: rebuild commutative factors ---
        # now we have:
        # - coeff:
        # - c_powers:    b -> e
        # - exp_dict:    3 -> e

        # b**e for non-numeric (or non-integer-powered) bases
        for b, e in c_powers.items():
            if e is S.Zero:
                # x**0 -> 1: drop the factor
                continue

            if e is S.One:
                if b.is_Number:
                    coeff *= b
                else:
                    c_part.append(b)
            elif e.is_Integer and b.is_Number:
                coeff *= b ** e
            else:
                c_part.append(Pow(b, e))

        #  x    x     x
        # 2  * 3  -> 6
        for b, e in exp_dict.items():
            if e in inv_exp_dict:
                inv_exp_dict[e] *= b
            else:
                inv_exp_dict[e] = b

        for e, b in inv_exp_dict.items():
            if e is S.Zero:
                continue

            if e is S.One:
                if b.is_Number:
                    coeff *= b
                else:
                    c_part.append(b)
            elif e.is_Integer and b.is_Number:
                coeff *= b ** e
            else:
                obj = b**e
                if obj.is_Number:
                    coeff *= obj
                else:
                    c_part.append(obj)

        # --- PART 3: special coefficients ---
        # deal with
        # (oo|nan|zero) * ...
        if (coeff is S.Infinity) or (coeff is S.NegativeInfinity):
            # factors of known sign are absorbed into the infinite coefficient
            c_part, coeff = cls._absorb_signs(c_part, coeff)
            nc_part, coeff = cls._absorb_signs(nc_part, coeff)
            c_part.insert(0, coeff)
        elif (coeff is S.Zero) or (coeff is S.NaN):
            c_part, nc_part = [coeff], []
        elif coeff.is_Real:
            if coeff == Real(0):
                c_part, nc_part = [coeff], []
            elif coeff != Real(1):
                c_part.insert(0, coeff)
        elif coeff is not S.One:
            c_part.insert(0, coeff)

        # order commutative part canonically
        c_part.sort(Basic.compare)

        # we are done
        if len(c_part) == 2 and c_part[0].is_Number and c_part[1].is_Add:
            # 2*(1+a) -> 2 + 2 * a
            coeff = c_part[0]
            c_part = [Add(*[coeff*f for f in c_part[1].args])]

        return c_part, nc_part, order_symbols

    @staticmethod
    def _absorb_signs(terms, coeff):
        """Drop factors of known sign from ``terms``, folding the sign into ``coeff``.

        Used by ``flatten`` when ``coeff`` is +oo or -oo. Positive factors are
        dropped, negative factors are dropped and flip the sign of ``coeff``,
        and factors of unknown sign are kept. Returns
        ``(kept_terms, coeff)``.
        """
        kept = []
        for t in terms:
            if t.is_positive:
                continue
            if t.is_negative:
                coeff = -coeff
                continue
            kept.append(t)
        return kept, coeff

    def _eval_power(b, e):
        """Try to evaluate ``b**e`` for the product ``b``; None means unevaluated."""
        if e.is_Number:
            if b.is_commutative:
                if e.is_Integer:
                    # (a*b)**2 -> a**2 * b**2
                    return Mul(*[s**e for s in b.args])

                if e.is_rational:
                    coeff, rest = b.as_coeff_terms()
                    if coeff == -1:
                        return None
                    elif coeff < 0:
                        return (-coeff)**e * Mul(*((S.NegativeOne,) + rest))**e
                    else:
                        return coeff**e * Mul(*[s**e for s in rest])

                coeff, rest = b.as_coeff_terms()
                if coeff is not S.One:
                    # (2*a)**3 -> 2**3 * a**3
                    return coeff**e * Mul(*[s**e for s in rest])
            elif e.is_Integer:
                # non-commutative: (a*b)**-1 -> b**-1 * a**-1, so the factor
                # order is reversed for negative exponents
                coeff, rest = b.as_coeff_terms()
                l = [s**e for s in rest]
                if e.is_negative:
                    l.reverse()
                return coeff**e * Mul(*l)

        # even power: drop the sign of a negative numeric coefficient
        c, t = b.as_coeff_terms()
        if e.is_even and c.is_Number and c < 0:
            return (-c * Mul(*t)) ** e

    def _eval_evalf(self, prec):
        """Numerically evaluate via AssocOp, then expand the result."""
        return AssocOp._eval_evalf(self, prec).expand()

    @cacheit
    def as_two_terms(self):
        """Split the product into ``(first_factor, rest_of_product)``."""
        args = self.args

        if len(args) == 1:
            return S.One, self
        elif len(args) == 2:
            return args

        else:
            return args[0], self._new_rawargs(*args[1:])

    @cacheit
    def as_coeff_terms(self, x=None):
        """Return ``(coefficient, tuple_of_factors)``.

        Without ``x``: the coefficient is the leading Number argument (or One)
        and the rest are the remaining arguments. With ``x``: the coefficient
        is the product of all factors not containing ``x``, and the tuple holds
        the factors that do contain it.
        """
        if x is not None:
            l1 = []
            l2 = []
            for f in self.args:
                if f.has(x):
                    l2.append(f)
                else:
                    l1.append(f)
            return Mul(*l1), tuple(l2)
        coeff = self.args[0]
        if coeff.is_Number:
            return coeff, self.args[1:]
        return S.One, self.args

    @staticmethod
    def _expandsums(sums):
        """Multiply out a non-empty list of sums by divide and conquer.

        Each item is an Add or a one-element list holding a single factor.
        For a single item that item is returned unchanged; otherwise a list
        of expanded terms is returned.
        """
        if len(sums) == 1:
            return sums[0]
        half = len(sums) // 2
        left = Mul._expandsums(sums[:half])
        right = Mul._expandsums(sums[half:])
        if isinstance(right, Basic):
            right = right.args
        if isinstance(left, Basic):
            left = left.args

        if len(left) == 1 and len(right) == 1:
            # no expansion needed, bail out now to avoid infinite recursion
            return [Mul(left[0], right[0])]

        terms = []
        for a in left:
            for b in right:
                terms.append(Mul(a, b).expand())
        added = Add(*terms)
        if added.is_Add:
            terms = list(added.args)
        else:
            terms = [added]
        return terms

    def _eval_expand_basic(self):
        """Distribute the product over any Add factors; None if nothing changed."""
        plain, sums, rewrite = [], [], False

        for factor in self.args:
            terms = factor._eval_expand_basic()

            if terms is not None:
                factor = terms

            if factor.is_Add:
                sums.append(factor)
                rewrite = True
            else:
                if factor.is_commutative:
                    plain.append(factor)
                else:
                    # non-commutative factors keep their position among sums
                    sums.append([factor])

            if terms is not None:
                rewrite = True

        if not rewrite:
            return None
        else:
            if sums:
                terms = Mul._expandsums(sums)

                if isinstance(terms, Basic):
                    terms = terms.args

                plain = Mul(*plain)

                return Add(*(Mul(plain, term) for term in terms), **self.assumptions0)
            else:
                return Mul(*plain, **self.assumptions0)

    def _eval_derivative(self, s):
        """Product rule: sum over factors of (factor' * other factors)."""
        terms = list(self.args)
        factors = []
        for i, term in enumerate(terms):
            t = term.diff(s)
            if t is S.Zero:
                continue
            factors.append(Mul(*(terms[:i] + [t] + terms[i+1:])))
        return Add(*factors)

    def _matches_simple(pattern, expr, repl_dict):
        # handle (w*3).matches('x*5') -> {w: x*5/3}
        coeff, terms = pattern.as_coeff_terms()
        if len(terms) == 1:
            return terms[0].matches(expr / coeff, repl_dict)
        return

    def matches(pattern, expr, repl_dict=None, evaluate=False):
        """Match ``expr`` against this product pattern.

        ``repl_dict`` defaults to a fresh empty dict on every call; a shared
        mutable default would leak bindings between unrelated calls.
        """
        if repl_dict is None:
            repl_dict = {}
        expr = sympify(expr)
        if pattern.is_commutative and expr.is_commutative:
            return AssocOp._matches_commutative(pattern, expr, repl_dict, evaluate)
        # todo for commutative parts, until then use the default matches method for non-commutative products
        return Basic.matches(pattern, expr, repl_dict, evaluate)

    @staticmethod
    def _combine_inverse(lhs, rhs):
        """Return ``lhs / rhs``, short-circuiting to One for equal operands."""
        if lhs == rhs:
            return S.One
        return lhs / rhs

    def as_powers_dict(self):
        """Return ``{base: exponent}`` built from the factors' ``as_base_exp``."""
        return dict(term.as_base_exp() for term in self.args)

    def as_numer_denom(self):
        """Return ``(numerator, denominator)`` as products of the factors' parts."""
        numers, denoms = [], []
        for t in self.args:
            n, d = t.as_numer_denom()
            numers.append(n)
            denoms.append(d)
        return Mul(*numers), Mul(*denoms)

    @cacheit
    def count_ops(self, symbolic=True):
        """Count operations: those of each factor plus one MUL per join."""
        if symbolic:
            return Add(*[t.count_ops(symbolic) for t in self.args]) + Symbol('MUL') * (len(self.args)-1)
        return Add(*[t.count_ops(symbolic) for t in self.args]) + (len(self.args)-1)

    def _eval_is_polynomial(self, syms):
        """A product is polynomial in ``syms`` iff every factor is."""
        for term in self.args:
            if not term._eval_is_polynomial(syms):
                return False
        return True

    _eval_is_bounded = lambda self: self._eval_template_is_attr('is_bounded')
    _eval_is_commutative = lambda self: self._eval_template_is_attr('is_commutative')
    _eval_is_integer = lambda self: self._eval_template_is_attr('is_integer')
    _eval_is_comparable = lambda self: self._eval_template_is_attr('is_comparable')

    # I*I -> R, I*I*I -> -I

    def _eval_is_real(self):
        """Real iff every factor is real or imaginary and the imaginaries pair up.

        Exactly one known non-real (and non-imaginary) factor makes the
        product non-real. With several such factors the answer is unknown,
        e.g. ``(1+I)*(1-I) == 2``.
        """
        im_count = 0
        not_real_count = 0

        for t in self.args:
            if t.is_imaginary:
                im_count += 1
                continue

            t_real = t.is_real
            if t_real:
                continue

            elif fuzzy_not(t_real):
                not_real_count += 1

            else:
                return None

        if not_real_count == 1:
            return False
        if not_real_count > 1:
            # non-real factors may cancel each other's imaginary parts
            return None

        return (im_count % 2 == 0)

    def _eval_is_imaginary(self):
        """Imaginary iff all factors are real/imaginary with an odd imaginary count."""
        im_count = 0

        for t in self.args:
            if t.is_imaginary:
                im_count += 1

            elif t.is_real:
                continue

            # real=F|U
            else:
                return None

        return (im_count % 2 == 1)

    def _eval_is_irrational(self):
        # NOTE(review): returns True as soon as one factor is irrational; this
        # is not sound in general (irrational factors can multiply to a
        # rational) -- behavior left unchanged pending a proper rule.
        for t in self.args:
            a = t.is_irrational
            if a: return True
            if a is None: return
        return False

    def _eval_is_positive(self):
        """Decide positivity from the factors not already known to be positive."""
        terms = [t for t in self.args if not t.is_positive]
        if not terms:
            return True
        c = terms[0]
        if len(terms) == 1:
            if c.is_nonpositive:
                return False
            return
        r = Mul(*terms[1:])
        # check for positivity, >0: (-)*(-)
        if c.is_negative and r.is_negative:
            return True
        # check for nonpositivity, <=0
        if c.is_negative and r.is_nonnegative:
            return False
        if r.is_negative and c.is_nonnegative:
            return False
        if c.is_nonnegative and r.is_nonpositive:
            return False
        if r.is_nonnegative and c.is_nonpositive:
            return False

    def _eval_is_negative(self):
        """Decide negativity from the factors not already known to be positive."""
        terms = [t for t in self.args if not t.is_positive]
        if not terms:
            # every factor is known to be positive, so the product is too
            if self.is_positive:
                return False
            else:
                return None
        c = terms[0]
        if len(terms) == 1:
            # all other factors are positive, so the sign is that of c
            return c.is_negative
        r = Mul(*terms[1:])
        # check for negativity, <0: (-)*(+)
        if c.is_negative and r.is_positive:
            return True
        # check for nonnegativity, >=0
        if c.is_negative and r.is_nonpositive:
            return False
        if r.is_negative and c.is_nonpositive:
            return False
        if c.is_nonpositive and r.is_nonpositive:
            return False
        if c.is_nonnegative and r.is_nonnegative:
            return False

    def _eval_is_odd(self):
        """An integer product is odd iff no factor is even and all are odd."""
        is_integer = self.is_integer

        if is_integer:
            r = True
            for t in self.args:
                if t.is_even:
                    return False
                if t.is_odd is None:
                    r = None
            return r

        # !integer -> !odd
        elif is_integer == False:
            return False

    def _eval_is_even(self):
        """An integer product is even iff it is not odd."""
        is_integer = self.is_integer

        if is_integer:
            return fuzzy_not(self._eval_is_odd())

        elif is_integer == False:
            return False

    def _eval_subs(self, old, new):
        """Substitute ``new`` for ``old``, also matching contiguous sub-products."""
        if self == old:
            return new
        if isinstance(old, FunctionClass):
            return self.__class__(*[s.subs(old, new) for s in self.args ])
        coeff1, terms1 = self.as_coeff_terms()
        coeff2, terms2 = old.as_coeff_terms()
        if terms1 == terms2: # (2*a).subs(3*a,y) -> 2/3*y
            return new * coeff1/coeff2
        l1, l2 = len(terms1), len(terms2)
        if l2 == 0:
            # if old is just a number, go through the self.args one by one
            return Mul(*[x.subs(old, new) for x in self.args])
        elif l2 < l1:
            # old is some something more complex, like:
            # (a*b*c*d).subs(b*c,x) -> a*x*d
            # then we need to search where in self.args the "old" is, and then
            # correctly substitute both terms and coefficients.
            for i in xrange(l1-l2+1):
                if terms2 == terms1[i:i+l2]:
                    m1 = Mul(*terms1[:i]).subs(old, new)
                    m2 = Mul(*terms1[i+l2:]).subs(old, new)
                    return Mul(*([coeff1/coeff2, m1, new, m2]))
        return self.__class__(*[s.subs(old, new) for s in self.args])

    def _eval_oseries(self, order):
        """Series of the product truncated at ``order`` (first symbol of it)."""
        x = order.symbols[0]
        l = []
        r = []
        # separate terms containing "x" (r) and the rest (l)
        for t in self.args:
            if not t.has(x):
                l.append(t)
                continue
            r.append(t)
        # if r is empty or just one term, it's easy:
        if not r:
            if order.contains(1, x): return S.Zero
            return Mul(*l)
        if len(r) == 1:
            return Mul(*(l + [r[0].oseries(order)]))
        # otherwise, we need to calculate how many orders we need to calculate
        # in each term. Currently this is done using as_leading_term, but this
        # is fragile and slow, because this involves limits. Let's find some
        # more clever approach.
        lt = [t.as_leading_term(x) for t in r]
        for i in xrange(len(r)):
            m = Mul(*(lt[:i]+lt[i+1:]))
            # calculate how many orders we want
            o = order/m
            # expand each term and multiply things together
            l.append(r[i].oseries(o))
        # shouldn't we rather expand everything? This seems to me to leave
        # things as (x+x**2+...)*(x-x**2+...) etc.:
        return Mul(*l)

    def nseries(self, x, x0, n):
        """Multiply the factors' ``nseries`` expansions and expand the result."""
        terms = [t.nseries(x, x0, n) for t in self.args]
        return Mul(*terms).expand()

    def _eval_as_leading_term(self, x):
        """The leading term of a product is the product of leading terms."""
        return Mul(*[t.as_leading_term(x) for t in self.args])

    def _eval_conjugate(self):
        """The conjugate of a product is the product of the conjugates."""
        return Mul(*[t.conjugate() for t in self.args])

    def _sage_(self):
        """Convert to a Sage expression by multiplying the converted factors."""
        s = 1
        for x in self.args:
            s *= x._sage_()
        return s
# /cyclic/
# Publish Mul as a module attribute of basic, add, power, numbers and
# operations. The "/cyclic/" tag marks this as a circular-import workaround:
# presumably those modules cannot import mul at load time, so they refer to a
# module-global name Mul that is filled in here once the class exists.
# Each `del _` removes the temporary module alias from this namespace.
import basic as _
_.Mul = Mul
del _

import add as _
_.Mul = Mul
del _

import power as _
_.Mul = Mul
del _

import numbers as _
_.Mul = Mul
del _

import operations as _
_.Mul = Mul
del _