Add/Mul -- introduce ._new_rawargs()
[sympy.git] / sympy / core / operations.py
blob bfe2be4dfa843e9136ae6af11cbcd7d555d26554
2 from basic import Basic, S, C
3 from sympify import _sympify
4 from cache import cacheit
6 # from add import Add /cyclic/
7 # from mul import Mul /cyclic/
8 # from function import Lambda, WildFunction /cyclic/
9 from symbol import Symbol, Wild
class AssocOp(Basic):
    """ Associative operations, can separate noncommutative and
    commutative parts.

    (a op b) op c == a op (b op c) == a op b op c.

    Base class for Add and Mul.
    """

    # for performance reason, we don't let is_commutative go to assumptions,
    # and keep it right here
    __slots__ = ['is_commutative']

    @cacheit
    def __new__(cls, *args, **assumptions):
        """Build ``cls(*args)``, flattening the sympified arguments.

        * ``evaluate=False``: the sympified args are stored verbatim.
        * no args: ``cls.identity()`` is returned.
        * one arg: that argument (sympified) is returned as-is.
        * otherwise ``cls.flatten`` splits the args into commutative and
          noncommutative parts plus optional order symbols; if at most one
          argument survives, it (or the identity) is returned, else a new
          ``cls`` instance over ``c_part + nc_part`` is created. When
          ``flatten`` reports order symbols the result is wrapped in
          ``C.Order``.
        """
        if assumptions.get('evaluate') is False:
            # NOTE(review): this path does not set the ``is_commutative``
            # slot; presumably it is resolved via Basic's assumption
            # machinery -- verify.
            return Basic.__new__(cls, *map(_sympify, args), **assumptions)
        if not args:
            return cls.identity()
        if len(args) == 1:
            return _sympify(args[0])

        c_part, nc_part, order_symbols = cls.flatten(map(_sympify, args))
        if len(c_part) + len(nc_part) <= 1:
            # nothing left to combine: collapse to the lone argument/identity
            if c_part:
                obj = c_part[0]
            elif nc_part:
                obj = nc_part[0]
            else:
                obj = cls.identity()
        else:
            obj = Basic.__new__(cls, *(c_part + nc_part), **assumptions)
            # commutative exactly when no noncommutative part remains
            obj.is_commutative = not nc_part

        if order_symbols is not None:
            obj = C.Order(obj, *order_symbols)
        return obj

    def _new_rawargs(self, *args):
        """create new instance of own class with args exactly as provided by caller

        This is handy when we want to optimize things, e.g.

        >>> from sympy import Mul, symbols
        >>> x,y = symbols('xy')
        >>> e = Mul(3,x,y)
        >>> e.args
        (3, x, y)
        >>> Mul(*e.args[1:])
        x*y
        >>> e._new_rawargs(*e.args[1:]) # the same as above, but faster
        x*y

        No flattening or canonicalization is done, and the new object
        inherits ``self.is_commutative`` unchanged -- callers must pass args
        whose commutativity matches ``self``'s.
        """
        obj = Basic.__new__(type(self), *args)  # NB no assumptions for Add/Mul
        obj.is_commutative = self.is_commutative

        return obj

    @classmethod
    def identity(cls):
        """Return the identity element for the operation ``cls``.

        ``S.One`` for Mul, ``S.Zero`` for Add, and the identity ``Lambda``
        (over a dummy Symbol) for ``C.Composition``.

        Raises NotImplementedError for any other class.
        """
        # NOTE(review): Mul, Add and Lambda are not imported in this module
        # (see the /cyclic/ comments at the top); presumably they are bound
        # into this module's namespace elsewhere -- verify.
        if cls is Mul:
            return S.One
        if cls is Add:
            return S.Zero
        if cls is C.Composition:
            s = Symbol('x', dummy=True)
            return Lambda(s, s)
        raise NotImplementedError("identity not defined for class %r" % (cls.__name__))

    @classmethod
    def flatten(cls, seq):
        """Apply associativity only: splice nested ``cls`` instances in place.

        Arguments whose class is exactly ``cls`` (not a subclass) are
        replaced by their own arguments, recursively, preserving
        left-to-right order. No commutativity is used, so everything is
        returned as the noncommutative part.

        Returns the triple ``([], flattened_args, None)``, i.e.
        ``(c_part, nc_part, order_symbols)``.
        """
        # Private copy used as a stack (top = next element to process):
        # popping from the end is O(1), unlike the previous ``pop(0)`` which
        # made the loop quadratic, and the caller's sequence is not mutated.
        pending = list(seq)
        pending.reverse()
        new_seq = []
        while pending:
            o = pending.pop()
            if o.__class__ is cls:  # classes must match exactly
                # push children reversed so the first child is processed next
                nested = list(o[:])
                nested.reverse()
                pending.extend(nested)
                continue
            new_seq.append(o)
        return [], new_seq, None

    _eval_subs = Basic._seq_subs

    def _matches_commutative(pattern, expr, repl_dict=None, evaluate=False):
        """Match ``expr`` against ``pattern``, ignoring argument order.

        ``repl_dict`` holds bindings already established for Wild symbols
        (a fresh empty dict when omitted). With ``evaluate=True`` those
        bindings are first substituted into the pattern. Returns the
        resulting bindings dict on success, or None when no match is found.
        """
        if repl_dict is None:
            # fresh dict per call: a shared ``{}`` default would be reused by
            # every call, so any callee mutating it would leak bindings
            repl_dict = {}

        # apply repl_dict to pattern to eliminate fixed wild parts
        if evaluate:
            pat = pattern
            for old, new in repl_dict.items():
                pat = pat.subs(old, new)
            if pat != pattern:
                return pat.matches(expr, repl_dict)

        # handle simple patterns
        d = pattern._matches_simple(expr, repl_dict)
        if d is not None:
            return d

        # eliminate exact part from pattern: (2+a+w1+w2).matches(expr) -> (w1+w2).matches(expr-a-2)
        wild_part = []
        exact_part = []
        for p in pattern.args:
            if p.atoms(Wild, WildFunction):
                # not all Wild should stay Wilds, for example:
                # (w2+w3).matches(w1) -> (w1+w3).matches(w1) -> w3.matches(0)
                if (p not in repl_dict) and (p not in expr):
                    wild_part.append(p)
                    continue

            exact_part.append(p)

        if exact_part:
            newpattern = pattern.__class__(*wild_part)
            newexpr = pattern.__class__._combine_inverse(expr, pattern.__class__(*exact_part))
            return newpattern.matches(newexpr, repl_dict)

        # now to real work ;)
        if isinstance(expr, pattern.__class__):
            expr_list = list(expr.args)
        else:
            expr_list = [expr]

        # try every (wild, operand) pair, last operands first; on a partial
        # match, re-run the whole pattern with the new bindings applied
        while expr_list:
            last_op = expr_list.pop()
            tmp = wild_part[:]
            while tmp:
                w = tmp.pop()
                d1 = w.matches(last_op, repl_dict)
                if d1 is not None:
                    d2 = pattern.matches(expr, d1, evaluate=True)
                    if d2 is not None:
                        return d2
        return None

    def _eval_template_is_attr(self, is_attr):
        """Three-valued "every argument has property ``is_attr``".

        Returns None as soon as some argument's ``is_attr`` is None (even if
        an earlier argument was already False), otherwise False if any value
        is false, and True if all are true (including when there are no
        arguments).
        """
        r = True
        for t in self.args:
            a = getattr(t, is_attr)
            if a is None:
                return
            if r and not a:
                r = False
        return r

    _eval_evalf = Basic._seq_eval_evalf