couple bugs fixed
[sympyx.git] / sympy.py
blobfee5c5fe0582c55d3ac9396c9a4d2e70ec4bb6db
# Integer type tags stored in the ``type`` attribute of every node; the node
# classes pass their tag to Basic.__new__.  Values are 0..5 in this order.
BASIC, SYMBOL, ADD, MUL, POW, INTEGER = range(6)
def hash_seq(args):
    """Fold the hashes of the items of ``args`` into one integer hash.

    The result depends on item order; callers that need an
    order-independent hash must sort the items first.  An empty sequence
    hashes to 2.
    """
    # TODO: the mixing step is ad hoc; a stronger combiner would be more robust.
    acc = 2
    for item in args:
        mixed = (acc + 1001) ^ hash(item)
        acc = hash(mixed)
    return acc
class Basic(object):
    """Base class of every expression node.

    Each node has an integer type tag in ``type`` (one of the module-level
    constants) and an immutable tuple of children in ``args``.  The
    arithmetic operators build Add/Mul/Pow nodes, whose constructors
    canonicalize their operands.
    """

    def __new__(cls, type, args):
        # ``type`` shadows the builtin, but it is part of the public signature.
        obj = object.__new__(cls)
        obj.type = type
        obj._args = tuple(args)
        return obj

    def __repr__(self):
        return str(self)

    def __hash__(self):
        # Order-dependent hash of the children; consistent with __eq__ below.
        return hash_seq(self.args)

    def __eq__(self, o):
        """Structural equality: same type tag and equal ``args`` in order.

        Subclasses with extra state (Integer, Symbol) or order-insensitive
        children (Mul) override this.  Non-Basic operands compare unequal.
        """
        if not isinstance(o, Basic):
            return False
        return self.type == o.type and self.args == o.args

    def __ne__(self, x):
        return not self == x

    @property
    def args(self):
        """Tuple of child nodes."""
        return self._args

    def as_coeff_rest(self):
        """Split into (numeric coefficient, remainder); here (1, self)."""
        return (Integer(1), self)

    def as_base_exp(self):
        """Split into (base, exponent); here (self, 1)."""
        return (self, Integer(1))

    def __add__(x, y):
        return Add((x, y))

    def __radd__(x, y):
        return x.__add__(y)

    def __sub__(x, y):
        return Add((x, -y))

    def __rsub__(x, y):
        return Add((y, -x))

    def __mul__(x, y):
        return Mul((x, y))

    def __rmul__(x, y):
        return Mul((y, x))

    def __div__(x, y):
        # x / y is represented as x * y**-1.
        return Mul((x, Pow((y, Integer(-1)))))

    def __rdiv__(x, y):
        return Mul((y, Pow((x, Integer(-1)))))

    # Python 3 (and ``from __future__ import division``) dispatch ``/`` to
    # __truediv__/__rtruediv__; alias them so division works there too.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    def __pow__(x, y):
        return Pow((x, y))

    def __rpow__(x, y):
        return Pow((y, x))

    def __neg__(x):
        # -x is represented as (-1) * x.
        return Mul((Integer(-1), x))

    def __pos__(x):
        return x
class Integer(Basic):
    """Integer atom; the wrapped Python number is stored in ``i``."""

    def __new__(cls, i):
        obj = Basic.__new__(cls, INTEGER, [])
        obj.i = i
        return obj

    def __hash__(self):
        # Must be defined next to __eq__: on Python 3 a class that defines
        # __eq__ without __hash__ is unhashable.  Hashing like the wrapped
        # number stays consistent with ``Integer(n) == n`` and avoids giving
        # every Integer the same hash (that of an empty arg tuple).
        return hash(self.i)

    def __eq__(self, o):
        """Compare by value with Integers and plain ints; anything else is unequal."""
        o = sympify(o)
        # Guard: non-Basic operands (strings, None, floats) have no ``type``.
        if isinstance(o, Basic) and o.type == INTEGER:
            return self.i == o.i
        return False

    def __str__(self):
        return str(self.i)

    def __add__(self, o):
        """Integer + Integer/int is evaluated; other operands build an Add."""
        o = sympify(o)
        if o.type == INTEGER:
            return Integer(self.i + o.i)
        return Basic.__add__(self, o)

    def __mul__(self, o):
        """Integer * Integer/int is evaluated; other operands build a Mul."""
        o = sympify(o)
        if o.type == INTEGER:
            return Integer(self.i * o.i)
        return Basic.__mul__(self, o)
class Symbol(Basic):
    """Named symbolic atom; two symbols are equal when their names are."""

    def __new__(cls, name):
        obj = Basic.__new__(cls, SYMBOL, [])
        obj.name = name
        return obj

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, o):
        # Non-Basic operands (ints, strings, None) have no ``type`` attribute;
        # treat them as unequal instead of raising AttributeError.
        if isinstance(o, Basic) and o.type == SYMBOL:
            return self.name == o.name
        return False

    def __str__(self):
        return self.name
class Add(Basic):
    """Sum of terms, stored flat with like terms collected.

    ``Add(args)`` sympifies and canonicalizes its operands, so the result
    may be an Integer or a single term rather than an Add.
    ``Add(args, False)`` builds the node from ``args`` verbatim.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            return Basic.__new__(cls, ADD, args)
        args = [sympify(x) for x in args]
        return Add.canonicalize(args)

    @classmethod
    def _flatten_terms(cls, args):
        """Yield the summands of ``args``, expanding nested Add nodes."""
        for a in args:
            if a.type == ADD:
                for b in cls._flatten_terms(a.args):
                    yield b
            else:
                yield a

    @classmethod
    def canonicalize(cls, args):
        """Collect like terms of ``args`` and return the simplest sum.

        Integer summands are folded into one constant.  Every other term is
        split by ``as_coeff_rest`` into (coefficient, rest), and the
        coefficients of equal ``rest`` parts are added.  Terms whose
        coefficients cancel to zero are dropped.

        Returns Integer(0) when nothing survives, the lone surviving term
        when exactly one does, and otherwise a non-canonicalized Add.
        """
        constant = Integer(0)
        coeffs = {}
        order = []  # first-seen key order, so the output order is deterministic
        for a in cls._flatten_terms(args):
            if a.type == INTEGER:
                constant += a
                continue
            coeff, key = a.as_coeff_rest()
            if key in coeffs:
                coeffs[key] += coeff
            else:
                coeffs[key] = coeff
                order.append(key)

        terms = []
        for key in order:
            coeff = coeffs[key]
            if coeff == 0:
                continue  # like terms cancelled, e.g. x - x
            term = Mul((key, coeff))
            if term.type == INTEGER:
                constant += term
            else:
                terms.append(term)
        if constant != 0:
            terms.append(constant)

        if not terms:
            return Integer(0)
        if len(terms) == 1:
            return terms[0]
        return Add(terms, False)

    def __hash__(self):
        # Order-independent: sorting by hash makes a + b and b + a hash alike.
        return hash_seq(sorted(self.args, key=hash))

    def __eq__(self, o):
        """Two sums are equal when their terms agree as multisets."""
        o = sympify(o)
        if not isinstance(o, Basic) or o.type != ADD:
            return False
        if len(self.args) != len(o.args):
            return False
        counts = {}
        for a in self.args:
            counts[a] = counts.get(a, 0) + 1
        for b in o.args:
            left = counts.get(b, 0)
            if not left:
                return False
            counts[b] = left - 1
        return True

    def __str__(self):
        if not self.args:
            return "0"
        parts = []
        for x in self.args:
            s = str(x)
            if x.type == ADD:
                s = "(%s)" % s  # only possible for non-canonical nested sums
            parts.append(s)
        return " + ".join(parts)
class Mul(Basic):
    """Product of factors, stored flat with powers of equal bases combined.

    ``Mul(args)`` sympifies and canonicalizes its operands, so the result
    may be an Integer, a single factor, or a Mul whose first arg is the
    numeric coefficient (omitted when it is 1).  ``Mul(args, False)``
    builds the node from ``args`` verbatim.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            return Basic.__new__(cls, MUL, args)
        args = [sympify(x) for x in args]
        return Mul.canonicalize(args)

    @classmethod
    def _flatten_factors(cls, args):
        """Yield the factors of ``args``, expanding nested Mul nodes."""
        for a in args:
            if a.type == MUL:
                for b in cls._flatten_factors(a.args):
                    yield b
            else:
                yield a

    @classmethod
    def canonicalize(cls, args):
        """Multiply integers together and add exponents of equal bases.

        Each non-integer factor is split by ``as_base_exp`` into
        (base, exponent); exponents of equal bases are summed and the
        factor rebuilt as ``Pow((base, exponent))``.  Factors that collapse
        to an Integer (e.g. x * x**-1 -> x**0 -> 1) are folded into the
        coefficient.
        """
        num = Integer(1)
        exps = {}
        order = []  # first-seen base order, so the output order is deterministic
        for a in cls._flatten_factors(args):
            if a.type == INTEGER:
                num *= a
                continue
            base, exp = a.as_base_exp()
            if base in exps:
                exps[base] += exp
            else:
                exps[base] = exp
                order.append(base)
        if num.i == 0:
            return num

        factors = []
        for base in order:
            p = Pow((base, exps[base]))
            if p.type == INTEGER:
                num *= p
            else:
                factors.append(p)

        if not factors:
            return num
        if num.i != 1:
            factors.insert(0, num)
        if len(factors) == 1:
            return factors[0]
        return Mul(factors, False)

    def __hash__(self):
        # Order-independent: sorting by hash makes a*b and b*a hash alike.
        a = list(self.args)
        a.sort(key=hash)
        return hash_seq(a)

    def __eq__(self, o):
        """Two products are equal when their factors agree as multisets."""
        o = sympify(o)
        if not isinstance(o, Basic) or o.type != MUL:
            return False
        if len(self.args) != len(o.args):
            return False
        counts = {}
        for a in self.args:
            counts[a] = counts.get(a, 0) + 1
        for b in o.args:
            left = counts.get(b, 0)
            if not left:
                return False
            counts[b] = left - 1
        return True

    def as_coeff_rest(self):
        """Split off a leading Integer coefficient, if present."""
        if self.args and self.args[0].type == INTEGER:
            return (self.args[0], Mul(self.args[1:]))
        return (Integer(1), self)

    def __str__(self):
        if not self.args:
            return "1"
        parts = []
        for x in self.args:
            s = str(x)
            # Parenthesize sums so (x + y)*z does not print as x + y*z.
            if x.type in (ADD, MUL):
                s = "(%s)" % s
            parts.append(s)
        return "*".join(parts)
class Pow(Basic):
    """Power node with ``args == (base, exponent)``.

    ``Pow(args)`` sympifies and canonicalizes: an Integer exponent of 0
    gives Integer(1) and an exponent of 1 gives the base itself.
    ``Pow(args, False)`` builds the node verbatim.
    """

    def __new__(cls, args, canonicalize=True):
        if not canonicalize:
            return Basic.__new__(cls, POW, args)
        args = [sympify(x) for x in args]
        return Pow.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Apply x**0 -> 1 and x**1 -> x; otherwise build a Pow node."""
        base, exp = args
        if exp.type == INTEGER:
            if exp.i == 0:
                return Integer(1)
            if exp.i == 1:
                return base
        return Pow(args, False)

    def as_base_exp(self):
        """Return (base, exponent) so Mul can combine x**a * x**b into x**(a+b)."""
        return (self.args[0], self.args[1])

    def __hash__(self):
        # Defined alongside __eq__ (required for hashability on Python 3).
        return hash_seq(self.args)

    def __eq__(self, o):
        """Structural equality: equal base and equal exponent."""
        o = sympify(o)
        if not isinstance(o, Basic) or o.type != POW:
            return False
        return self.args == o.args

    def __str__(self):
        base, exp = self.args
        b = str(base)
        # Parenthesize compound or negative bases: (x + y)^2, (x*y)^2, (-1)^x.
        if base.type in (ADD, MUL, POW) or (base.type == INTEGER and base.i < 0):
            b = "(%s)" % b
        e = str(exp)
        if exp.type in (ADD, MUL, POW):
            e = "(%s)" % e
        return "%s^%s" % (b, e)
def sympify(x):
    """Convert a plain integer to Integer; return anything else unchanged.

    Any ``numbers.Integral`` is accepted (``int``, Python 2 ``long``,
    ``bool``, ...) and normalized with ``int()`` first, so ``True`` becomes
    ``Integer(1)`` rather than an Integer wrapping a bool.
    """
    import numbers  # local import: this module has no top-level import block
    if isinstance(x, numbers.Integral):
        return Integer(int(x))
    return x