refactorings, tests added
[sympyx.git] / sympy.py
blob 59f2b4c753b93a9f92c9c8654181a73b38375f0c
# Integer type tags stored on every expression node as ``obj.type``.
# Unpacked in order, so BASIC == 0, SYMBOL == 1, ..., INTEGER == 5.
BASIC, SYMBOL, ADD, MUL, POW, INTEGER = range(6)
def hash_seq(args):
    """Fold the hashes of the items of *args* into a single hash value.

    The fold starts from the seed 2 and depends on item order, so an empty
    sequence hashes to 2 and permuted sequences generally hash differently.
    """
    # TODO: make this more robust.
    acc = 2
    for item in args:
        # '+' binds tighter than '^', so this is (acc + 1001) ^ hash(item).
        acc = hash((acc + 1001) ^ hash(item))
    return acc
class Basic(object):
    """Base class of every expression node.

    Each node stores an integer ``type`` tag (one of the module-level
    constants) and a sequence of children exposed as ``args``.  The
    arithmetic operators build Add / Mul / Pow expressions, including the
    reflected forms so that a plain Python number may appear on the left
    (``2 + x``, ``3 * x``, ``1 - x``, ``2 ** x``).
    """

    def __new__(cls, type, args):
        obj = object.__new__(cls)
        obj.type = type
        obj._args = args
        return obj

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash_seq(self.args)

    @property
    def args(self):
        """The child nodes of this expression."""
        return self._args

    def as_coeff_rest(self):
        """Return ``(coefficient, rest)``; the default is ``(1, self)``."""
        return (Integer(1), self)

    def as_base_exp(self):
        """Return ``(base, exponent)``; the default is ``(self, 1)``."""
        return (self, Integer(1))

    def __add__(x, y):
        return Add((x, y))

    def __radd__(x, y):
        # Called for ``y + x`` when y (e.g. a Python int) can't handle it.
        return Add((y, x))

    def __sub__(x, y):
        return Add((x, -y))

    def __rsub__(x, y):
        # ``y - x``
        return Add((y, -x))

    def __mul__(x, y):
        return Mul((x, y))

    def __rmul__(x, y):
        # ``y * x``
        return Mul((y, x))

    def __pow__(x, y):
        return Pow((x, y))

    def __rpow__(x, y):
        # ``y ** x``
        return Pow((y, x))

    def __neg__(x):
        return Mul((Integer(-1), x))
class Integer(Basic):
    """An exact integer constant; the Python value is stored in ``self.i``."""

    def __new__(cls, i):
        obj = Basic.__new__(cls, INTEGER, [])
        obj.i = i
        return obj

    def __str__(self):
        return str(self.i)

    def __hash__(self):
        # Hash on the value: hashing the (always empty) args gave every
        # Integer the same hash.  Equality is still the identity-based
        # default, so equal objects keep equal hashes.
        # NOTE(review): value-based __eq__ is intentionally not added here;
        # it changes how Integers behave as dict keys in canonicalization.
        return hash(self.i)

    def __add__(self, o):
        """Fold Integer + integer; otherwise build a symbolic sum."""
        o = sympify(o)  # accept plain Python ints as well as Integer
        if not isinstance(o, Basic):
            return NotImplemented
        if o.type == INTEGER:
            return Integer(self.i + o.i)
        # Delegate to the generic implementation so Integer + expression
        # yields an Add instead of relying on the other operand to provide
        # a reflected method.
        return Basic.__add__(self, o)

    def __mul__(self, o):
        """Fold Integer * integer; otherwise build a symbolic product."""
        o = sympify(o)
        if not isinstance(o, Basic):
            return NotImplemented
        if o.type == INTEGER:
            return Integer(self.i * o.i)
        return Basic.__mul__(self, o)
class Symbol(Basic):
    """A named symbolic variable; two symbols are equal iff their names are."""

    def __new__(cls, name):
        obj = Basic.__new__(cls, SYMBOL, [])
        obj.name = name
        return obj

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, o):
        # Compare by name, matching __hash__, so separately created
        # Symbol("x") objects are the same dict key during canonicalization.
        return isinstance(o, Symbol) and self.name == o.name

    def __ne__(self, o):
        # Python 2 does not derive != from __eq__.
        return not self == o

    def __str__(self):
        return self.name
class Add(Basic):
    """A sum of terms.

    ``Add(args)`` sympifies and canonicalizes *args* and may therefore
    return something other than an Add (an Integer or a single term).
    ``Add(args, False)`` wraps *args* as-is, without canonicalization.
    """

    def __new__(cls, args, canonicalize=True):
        if canonicalize == False:
            obj = Basic.__new__(cls, ADD, args)
            return obj
        args = [sympify(x) for x in args]
        return Add.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Return the canonical form of the sum of *args*.

        - Nested sums are flattened one level.
        - Integer terms are folded into a single leading constant.
        - Other terms are grouped by the ``rest`` part of ``as_coeff_rest()``
          and their coefficients are added.
        - Terms whose coefficient cancels to 0 are dropped.

        An empty sum is Integer(0); a single surviving term is returned
        unwrapped.
        """
        terms = []
        for a in args:
            if a.type == ADD:
                terms.extend(a.args)
            else:
                terms.append(a)

        num = Integer(0)  # running total of the integer terms
        d = {}            # rest -> accumulated coefficient
        for t in terms:
            if t.type == INTEGER:
                num += t
                continue
            coeff, key = t.as_coeff_rest()
            if key in d:
                d[key] += coeff
            else:
                d[key] = coeff

        args = []
        # items() (not iteritems()) works on both Python 2 and 3.
        for key, coeff in d.items():
            term = Mul((key, coeff))
            if term.type == INTEGER and term.i == 0:
                continue  # coefficient cancelled, e.g. x - x
            args.append(term)
        if num.i != 0:
            args.insert(0, num)
        if not args:
            return Integer(0)
        if len(args) == 1:
            return args[0]
        return Add(args, False)

    def __str__(self):
        # Parenthesize each nested-sum child (not the text built so far).
        parts = []
        for x in self.args:
            if x.type == ADD:
                parts.append("(%s)" % x)
            else:
                parts.append(str(x))
        return " + ".join(parts)
class Mul(Basic):
    """A product of factors.

    ``Mul(args)`` sympifies and canonicalizes *args* and may return
    something other than a Mul (an Integer or a single factor).
    ``Mul(args, False)`` wraps *args* as-is.  After canonicalization an
    integer coefficient other than 1 is stored first in ``args``.
    """

    def __new__(cls, args, canonicalize=True):
        if canonicalize == False:
            obj = Basic.__new__(cls, MUL, args)
            return obj
        args = [sympify(x) for x in args]
        return Mul.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Return the canonical form of the product of *args*.

        - Nested products are flattened one level.
        - All integer factors are multiplied into one coefficient.
        - Other factors are grouped by base (from ``as_base_exp()``) and
          their exponents are added.
        - Factors that collapse to an integer (e.g. ``x**0``) are folded
          into the coefficient.

        A zero coefficient yields Integer(0), an empty product yields
        Integer(1), and a single factor is returned unwrapped.
        """
        factors = []
        for a in args:
            if a.type == MUL:
                factors.extend(a.args)
            else:
                factors.append(a)

        num = Integer(1)  # product of the integer factors
        d = {}            # base -> accumulated exponent
        for f in factors:
            if f.type == INTEGER:
                num *= f
                continue
            # Key on the base and add up the exponents, so x*x -> x**2.
            base, exp = f.as_base_exp()
            if base in d:
                d[base] += exp
            else:
                d[base] = exp

        args = []
        # items() (not iteritems()) works on both Python 2 and 3.
        for base, exp in d.items():
            p = Pow((base, exp))
            if p.type == INTEGER:
                num *= p  # e.g. exponent summed to 0 -> 1
            else:
                args.append(p)
        if num.i == 0:
            return num
        if num.i != 1:
            args.insert(0, num)
        if not args:
            return num  # empty product: num is Integer(1) here
        if len(args) == 1:
            return args[0]
        return Mul(args, False)

    def __hash__(self):
        # Order-insensitive: hash the factors sorted by their own hashes.
        return hash_seq(sorted(self.args, key=hash))

    def __eq__(self, o):
        # Non-Basic operands (e.g. plain ints) have no ``type``: not equal.
        if getattr(o, "type", None) != MUL:
            return False
        return sorted(self.args, key=hash) == sorted(o.args, key=hash)

    def __ne__(self, o):
        # Python 2 does not derive != from __eq__.
        return not self == o

    def as_coeff_rest(self):
        """Split off a leading integer coefficient: ``(coeff, rest)``."""
        if self.args[0].type == INTEGER:
            return (self.args[0], Mul(self.args[1:]))
        return (Integer(1), self)

    def __str__(self):
        # Parenthesize sums and nested products inside a product.
        parts = []
        for x in self.args:
            if x.type in (ADD, MUL):
                parts.append("(%s)" % x)
            else:
                parts.append(str(x))
        return "*".join(parts)
class Pow(Basic):
    """A power ``base ** exp``; ``args`` holds ``(base, exp)``.

    ``Pow(args)`` sympifies and canonicalizes (``x**0 -> 1``,
    ``x**1 -> x``) and may return a non-Pow.  ``Pow(args, False)`` wraps
    *args* as-is.
    """

    def __new__(cls, args, canonicalize=True):
        if canonicalize == False:
            # Tag with POW so code dispatching on ``type`` (e.g. the
            # flattening of MUL nodes) does not mistake a power for a
            # product of its base and exponent.
            obj = Basic.__new__(cls, POW, args)
            return obj
        args = [sympify(x) for x in args]
        return Pow.canonicalize(args)

    @classmethod
    def canonicalize(cls, args):
        """Simplify trivial integer exponents; otherwise build a Pow node."""
        base, exp = args
        if exp.type == INTEGER:
            if exp.i == 0:
                return Integer(1)
            if exp.i == 1:
                return base
        return Pow(args, False)

    def __str__(self):
        base, exp = self.args
        s = str(base)
        # Compound or negative bases need parentheses: (x+y)^2, (x*y)^2,
        # (-2)^x.
        if base.type in (ADD, MUL, POW) or (
                base.type == INTEGER and base.i < 0):
            s = "(%s)" % s
        if exp.type in (ADD, MUL, POW):
            return "%s^(%s)" % (s, exp)
        return "%s^%s" % (s, exp)
def sympify(x):
    """Convert a Python integer to an Integer; return anything else unchanged.

    ``numbers.Integral`` covers ``int`` and, on Python 2, ``long``.  The
    former ``isinstance(x, int)`` test missed ``long``, so large Python 2
    integers were left unconverted and failed later on ``.type`` access.
    ``bool`` is an ``int`` subclass and is still converted, as before.
    """
    import numbers  # local import keeps this block self-contained
    if isinstance(x, numbers.Integral):
        return Integer(x)
    return x