ok
[sympyx.git] / sympy.py
blobd0db8824a14ac43b1dd23cc93fae83e3175a587f
# Integer tags identifying the kind of an expression node.  The node classes
# in this file pass one of them to ``Basic.__new__``, which stores it in the
# ``type`` attribute.  (BASIC itself is not referenced by the visible code.)
BASIC, SYMBOL, ADD, MUL, POW, INTEGER = range(6)
def hash_seq(args):
    """Fold the hashes of the items of *args* into one hash value.

    The result depends on the order of the items; an empty sequence
    yields the seed value 2.
    """
    # TODO: make this more robust.
    acc = 2
    for item in args:
        # ``+`` binds tighter than ``^``: the shifted accumulator is XORed
        # with the item's hash, and the result is hashed again.
        mixed = (acc + 1001) ^ hash(item)
        acc = hash(mixed)
    return acc
class Basic(object):
    """Common base class of every expression node.

    A node stores an integer ``type`` tag and an immutable tuple of child
    nodes, exposed read-only as ``args``.  The arithmetic operators build
    ``Add``, ``Mul`` and ``Pow`` nodes out of their operands.
    """

    def __new__(cls, type, args):
        """Create a node tagged with *type* whose children are *args*."""
        obj = object.__new__(cls)
        obj.type = type
        obj._args = tuple(args)
        return obj

    def __repr__(self):
        return str(self)

    def __str__(self):
        # Generic fallback for nodes without their own __str__.  Without it
        # object.__str__ would call __repr__, which calls str() again and
        # recurses without end.
        children = ", ".join([str(a) for a in self._args])
        return "%s(%s)" % (self.__class__.__name__, children)

    def __hash__(self):
        return hash_seq(self.args)

    @property
    def args(self):
        """The child nodes, as a tuple."""
        return self._args

    def as_coeff_rest(self):
        """Split into ``(coefficient, rest)``; by default the coefficient is 1."""
        return (Integer(1), self)

    def as_base_exp(self):
        """Split into ``(base, exponent)``; by default the exponent is 1."""
        return (self, Integer(1))

    def __add__(x, y):
        return Add((x, y))

    def __radd__(x, y):
        return x.__add__(y)

    def __sub__(x, y):
        return Add((x, -y))

    def __rsub__(x, y):
        return Add((y, -x))

    def __mul__(x, y):
        return Mul((x, y))

    def __rmul__(x, y):
        return Mul((y, x))

    def __div__(x, y):
        # Division is expressed as multiplication by the inverse power.
        return Mul((x, Pow((y, Integer(-1)))))

    def __rdiv__(x, y):
        return Mul((y, Pow((x, Integer(-1)))))

    # ``/`` dispatches to __truediv__ under true division (always on
    # Python 3); keep the __div__ names for classic Python 2 division.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(x, y):
        return Pow((x, y))

    def __rpow__(x, y):
        return Pow((y, x))

    def __neg__(x):
        return Mul((Integer(-1), x))

    def __pos__(x):
        return x

    def __ne__(self, x):
        return not self.__eq__(x)

    def __eq__(self, o):
        """Nodes are equal when their type tags and children are equal.

        Anything that is not an expression after ``sympify`` (e.g. ``None``
        or a string) compares unequal instead of raising AttributeError.
        """
        o = sympify(o)
        if not isinstance(o, Basic):
            return False
        if o.type == self.type:
            return self.args == o.args
        return False
class Integer(Basic):
    """An integer constant; the wrapped Python value is stored in ``i``."""

    def __new__(cls, i):
        obj = Basic.__new__(cls, INTEGER, [])
        obj.i = i
        return obj

    def __hash__(self):
        # Same hash as the wrapped value, consistent with __eq__ accepting
        # plain Python ints.
        return hash(self.i)

    def __eq__(self, o):
        """Equal to another Integer (or a plain int) with the same value.

        Operands that are not expressions after ``sympify`` compare unequal
        instead of raising AttributeError.
        """
        o = sympify(o)
        if isinstance(o, Basic) and o.type == INTEGER:
            return self.i == o.i
        return False

    def __str__(self):
        return str(self.i)

    def __add__(self, o):
        """Add numerically when *o* is an integer, symbolically otherwise."""
        o = sympify(o)
        if o.type == INTEGER:
            return Integer(self.i + o.i)
        return Basic.__add__(self, o)

    def __mul__(self, o):
        """Multiply numerically when *o* is an integer, symbolically otherwise."""
        o = sympify(o)
        if o.type == INTEGER:
            return Integer(self.i * o.i)
        return Basic.__mul__(self, o)
class Symbol(Basic):
    """A named symbol; two symbols are equal when their names are equal."""

    def __new__(cls, name):
        obj = Basic.__new__(cls, SYMBOL, [])
        obj.name = name
        return obj

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, o):
        """Equal to another Symbol with the same name.

        Operands that are not expressions after ``sympify`` (e.g. ``None``
        or a string) compare unequal instead of raising AttributeError.
        """
        o = sympify(o)
        if isinstance(o, Basic) and o.type == SYMBOL:
            return self.name == o.name
        return False

    def __str__(self):
        return self.name
class Add(Basic):
    """A sum of terms.

    ``Add(args)`` sympifies and canonicalizes the terms and may therefore
    return a node that is not an ``Add`` (an ``Integer``, or the single
    remaining term).  ``Add(args, False)`` stores *args* verbatim.
    """

    def __new__(cls, args, canonicalize=True):
        if canonicalize == False:
            obj = Basic.__new__(cls, ADD, args)
            return obj
        args = [sympify(x) for x in args]
        return Add.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Combine like terms of *args* and return the simplified sum.

        Nested sums are flattened one level, integer terms are added into a
        single constant, and terms with the same non-numeric part have
        their coefficients summed.  Terms that reduce to an integer (for
        instance a coefficient that cancelled to zero) are folded into the
        constant instead of being kept as separate terms.
        """
        # Flatten first so the terms of a nested sum, including its integer
        # constant, go through exactly the same path as top-level terms.
        terms = []
        for a in args:
            if a.type == ADD:
                terms.extend(a.args)
            else:
                terms.append(a)

        d = {}
        num = Integer(0)
        for a in terms:
            if a.type == INTEGER:
                num += a
                continue
            coeff, key = a.as_coeff_rest()
            if key in d:
                d[key] += coeff
            else:
                d[key] = coeff

        args = []
        # items() works on both Python 2 and Python 3.
        for key, coeff in d.items():
            term = Mul((key, coeff))
            if term.type == INTEGER:
                # e.g. x - x leaves 0*x == 0: absorb it into the constant.
                num += term
            else:
                args.append(term)

        if not args:
            return num
        if num.i != 0:
            args.insert(0, num)
        if len(args) == 1:
            return args[0]
        return Add(args, False)

    def __hash__(self):
        # Order-insensitive, to agree with __eq__ (same scheme as
        # Mul.__hash__).  Defining it explicitly also keeps Add hashable on
        # Python 3, where overriding __eq__ alone disables hashing.
        a = list(self.args[:])
        a.sort(key=hash)
        return hash_seq(a)

    def __eq__(self, o):
        """Equal to another sum with the same terms, compared after sorting by hash.

        Operands that are not expressions after ``sympify`` compare unequal
        instead of raising AttributeError.
        """
        o = sympify(o)
        if not isinstance(o, Basic) or o.type != ADD:
            return False
        return sorted(self.args, key=hash) == sorted(o.args, key=hash)

    def __str__(self):
        parts = []
        for x in self.args:
            s = str(x)
            if x.type == ADD:
                # Parenthesize only the nested sum, not the text so far.
                s = "(%s)" % s
            parts.append(s)
        return " + ".join(parts)
class Mul(Basic):
    """A product of factors.

    ``Mul(args)`` sympifies and canonicalizes the factors and may therefore
    return a node that is not a ``Mul`` (an ``Integer``, or the single
    remaining factor).  ``Mul(args, False)`` stores *args* verbatim.
    """

    def __new__(cls, args, canonicalize=True):
        if canonicalize == False:
            obj = Basic.__new__(cls, MUL, args)
            return obj
        args = [sympify(x) for x in args]
        return Mul.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Combine like factors of *args* and return the simplified product.

        Nested products are flattened one level, integer factors are
        multiplied into a single coefficient, and factors with the same
        base have their exponents summed.  Factors that reduce to an
        integer (for instance ``x^0``) are folded into the coefficient.
        When the product is not an integer and the coefficient is not 1,
        the coefficient is placed first.
        """
        # Flatten first so the factors of a nested product, including its
        # integer coefficient, go through the same path as top-level ones.
        factors = []
        for a in args:
            if a.type == MUL:
                factors.extend(a.args)
            else:
                factors.append(a)

        d = {}
        num = Integer(1)
        for a in factors:
            if a.type == INTEGER:
                num *= a
                continue
            key, coeff = a.as_base_exp()
            if key in d:
                d[key] += coeff
            else:
                d[key] = coeff

        args = []
        # items() works on both Python 2 and Python 3.
        for base, exp in d.items():
            factor = Pow((base, exp))
            if factor.type == INTEGER:
                # e.g. x * x^-1 leaves x^0 == 1: absorb it into the coefficient.
                num *= factor
            else:
                args.append(factor)

        if num.i == 0 or not args:
            return num
        if num.i != 1:
            # as_coeff_rest() relies on the coefficient being the first arg.
            args.insert(0, num)
        if len(args) == 1:
            return args[0]
        return Mul(args, False)

    def __hash__(self):
        # Order-insensitive, to agree with __eq__.
        a = list(self.args[:])
        a.sort(key=hash)
        return hash_seq(a)

    def __eq__(self, o):
        """Equal to another product with the same factors, compared after sorting by hash.

        Operands that are not expressions after ``sympify`` compare unequal
        instead of raising AttributeError.
        """
        o = sympify(o)
        if not isinstance(o, Basic) or o.type != MUL:
            return False
        return sorted(self.args, key=hash) == sorted(o.args, key=hash)

    def as_coeff_rest(self):
        """Split off a leading integer coefficient: ``2*x*y -> (2, x*y)``."""
        if self.args[0].type == INTEGER:
            return (self.args[0], Mul(self.args[1:]))
        return (Integer(1), self)

    def __str__(self):
        parts = []
        for x in self.args:
            s = str(x)
            if x.type == ADD or x.type == MUL:
                # A sum inside a product needs parentheses to read
                # correctly; a nested product is grouped for clarity.
                s = "(%s)" % s
            parts.append(s)
        return "*".join(parts)
class Pow(Basic):
    """A power ``base^exp``; ``args`` is the pair ``(base, exp)``.

    ``Pow(args)`` sympifies and canonicalizes and may therefore return a
    node that is not a ``Pow``.  ``Pow(args, False)`` stores *args* verbatim.
    """

    def __new__(cls, args, canonicalize=True):
        if canonicalize == False:
            obj = Basic.__new__(cls, POW, args)
            return obj
        args = [sympify(x) for x in args]
        return Pow.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Simplify ``base^exp`` for *args* = ``[base, exp]``.

        Rules, in order: ``x^0 -> 1`` (including ``0^0``), ``x^1 -> x``,
        ``1^x -> 1``, an integer raised to a positive integer is evaluated,
        ``0^x -> 0`` for a non-integer exponent, and ``(x^a)^b -> x^(a*b)``.
        ``0`` raised to a negative integer is left unevaluated rather than
        being reported as 0.
        """
        base, exp = args
        # Exponent rules come first so that 0^0 gives 1, not 0.
        if exp.type == INTEGER:
            if exp.i == 0:
                return Integer(1)
            if exp.i == 1:
                return base
        if base.type == INTEGER:
            if base.i == 1:
                return Integer(1)
            if exp.type == INTEGER:
                if exp.i > 0:
                    return Integer(base.i ** exp.i)
                # Negative integer exponent: no Integer result exists
                # (and 0^-n is undefined), so keep the power symbolic.
                return Pow(args, False)
            if base.i == 0:
                return Integer(0)
        if base.type == POW:
            return Pow((base.args[0], base.args[1] * exp))
        return Pow(args, False)

    def as_base_exp(self):
        """Return ``(base, exp)`` so products can merge powers of one base.

        With the generic ``Basic.as_base_exp`` a power is treated as an
        opaque base with exponent 1, so ``x^2 * x^3`` is never combined.
        """
        return (self.args[0], self.args[1])

    def __str__(self):
        base, exp = self.args
        b = str(base)
        # Anything that binds looser than ^ (or would read ambiguously)
        # gets parentheses: sums, products, nested powers, negative ints.
        if base.type in (ADD, MUL, POW) or (base.type == INTEGER and base.i < 0):
            b = "(%s)" % b
        e = str(exp)
        if exp.type in (ADD, MUL):
            e = "(%s)" % e
        return "%s^%s" % (b, e)
def sympify(x):
    """Wrap a plain Python ``int`` in ``Integer``; return anything else unchanged."""
    return Integer(x) if isinstance(x, int) else x