- allow for intermediate, non-linear terms
[PyX/mjg.git] / test / experimental / solve.py
blobc8397cf9445a07efd7c2d876ff8dceedc06e435d
1 #!/usr/bin/env python
2 # -*- coding: ISO-8859-1 -*-
5 # Copyright (C) 2004 André Wobst <wobsta@users.sourceforge.net>
7 # This file is part of PyX (http://pyx.sourceforge.net/).
9 # PyX is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 2 of the License, or
12 # (at your option) any later version.
14 # PyX is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
19 # You should have received a copy of the GNU General Public License
20 # along with PyX; if not, write to the Free Software
21 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
23 import Numeric, LinearAlgebra
class scalar:
    """A scalar variable or constant.

    A scalar is undefined until a value is assigned, either through the
    constructor or through set(). A value can be assigned only once.
    """

    def __init__(self, value=None, varname="unnamed_scalar"):
        """Create a scalar named varname; a value other than None fixes it."""
        self.value = None
        if value is not None:
            self.set(value)
        self.varname = varname

    def addend(self):
        """Return this scalar wrapped as an addend with a single factor."""
        return addend([self], None)

    def term(self):
        """Return this scalar wrapped as a term with a single addend."""
        return term([self.addend()])

    def __add__(self, other):
        return self.term() + other

    __radd__ = __add__

    def __sub__(self, other):
        return self.term() - other

    def __rsub__(self, other):
        return -self.term() + other

    def __neg__(self):
        return -self.addend()

    def __mul__(self, other):
        return self.addend()*other

    __rmul__ = __mul__

    def __div__(self, other):
        return self.addend()/other

    # the "/" operator calls __truediv__ under Python 3 and under
    # "from __future__ import division"
    __truediv__ = __div__

    def is_set(self):
        """Tell whether a value has been assigned."""
        return self.value is not None

    def set(self, value):
        """Assign the value, converted to float.

        Raises RuntimeError when a value is already assigned or when
        the value can not be converted to a float.
        """
        if self.is_set():
            raise RuntimeError("variable already defined")
        try:
            self.value = float(value)
        except (TypeError, ValueError):
            # only conversion failures are translated; anything else
            # (e.g. KeyboardInterrupt) must not be swallowed here
            raise RuntimeError("float expected")

    def get(self):
        """Return the value; raises RuntimeError when none is assigned."""
        if not self.is_set():
            raise RuntimeError("variable not yet defined")
        return self.value

    def __str__(self):
        if self.is_set():
            return "%s{=%s}" % (self.varname, self.value)
        else:
            return self.varname

    def __float__(self):
        return self.get()
class vector:
    """A vector, i.e. a list of scalars."""

    # attribute names that give access to the first three components
    _component_index = {"x": 0, "y": 1, "z": 2}

    def __init__(self, dimension_or_values, varname="unnamed_vector"):
        """Create a vector from a sequence of values or from a dimension.

        When dimension_or_values can be iterated, each item becomes the
        value of one component; otherwise it is taken as the number of
        (still undefined) components. Raises RuntimeError when varname
        is not a string.
        """
        try:
            varname + ""
        except TypeError:
            raise RuntimeError("a vectors varname should be a string (you probably wanted to write vector([x, y]) instead of vector(x, y))")
        try:
            # values
            self.scalars = [scalar(value=value, varname="%s[%i]" % (varname, i))
                            for i, value in enumerate(dimension_or_values)]
        except (TypeError, AttributeError):
            # dimension
            self.scalars = [scalar(varname="%s[%i]" % (varname, i))
                            for i in range(dimension_or_values)]
        self.varname = varname

    def __len__(self):
        return len(self.scalars)

    def __getitem__(self, i):
        return self.scalars[i]

    def __getattr__(self, attr):
        """Provide the components as attributes x, y and z.

        Raises AttributeError for any other name and for a component
        the vector does not have.
        """
        try:
            index = self._component_index[attr]
        except KeyError:
            raise AttributeError(attr)
        try:
            return self[index]
        except IndexError:
            # hasattr() and getattr() with a default only understand
            # AttributeError, so a missing component must not leak an
            # IndexError
            raise AttributeError(attr)

    def addend(self):
        """Return this vector wrapped as an addend without scalar factors."""
        return addend([], self)

    def term(self):
        """Return this vector wrapped as a term with a single addend."""
        return term([self.addend()])

    def __add__(self, other):
        return self.term() + other

    __radd__ = __add__

    def __sub__(self, other):
        return self.term() - other

    def __rsub__(self, other):
        return -self.term() + other

    def __neg__(self):
        return -self.addend()

    def __mul__(self, other):
        return self.addend()*other

    __rmul__ = __mul__

    def __div__(self, other):
        return self.addend()/other

    # the "/" operator calls __truediv__ under Python 3 and under
    # "from __future__ import division"
    __truediv__ = __div__

    def __str__(self):
        components = ", ".join([str(component) for component in self.scalars])
        return "%s{=(%s)}" % (self.varname, components)
class addend:
    """An addend of a term.

    An addend is a product of a list of scalars and optionally one
    vector. The vector is None for a scalar addend.
    """

    def __init__(self, scalars, vector):
        # self.vector might be None for a scalar addend
        self.scalars = scalars
        self.vector = vector

    def __len__(self):
        # raises TypeError for a scalar addend since None has no length
        return len(self.vector)

    def __getitem__(self, i):
        """Return the scalar addend for component i of a vector addend."""
        return addend(self.scalars + [self.vector[i]], None)

    def addend(self):
        return self

    def term(self):
        """Return this addend wrapped as a term with a single addend."""
        return term([self.addend()])

    def is_linear(self):
        """Tell whether at most one factor is still undefined.

        May only be called on scalar addends.
        """
        assert self.vector is None
        return len([factor for factor in self.scalars if not factor.is_set()]) < 2

    def prefactor(self):
        """Return the product of the values of all defined factors."""
        assert self.is_linear()
        prefactor = 1
        for factor in self.scalars:
            if factor.is_set():
                prefactor *= factor.get()
        return prefactor

    def variable(self):
        """Return the single undefined factor, or None if there is none."""
        assert self.is_linear()
        undefined = [factor for factor in self.scalars if not factor.is_set()]
        if len(undefined) == 1:
            return undefined[0]
        return None

    def __add__(self, other):
        return self.term() + other

    __radd__ = __add__

    def __sub__(self, other):
        return self.term() - other

    def __rsub__(self, other):
        return -self.term() + other

    def __neg__(self):
        return addend([scalar(-1)] + self.scalars, self.vector)

    def __mul__(self, other):
        """Multiply by a number, scalar, vector, addend or term.

        The product of two vector addends is their scalar product and
        is returned as a term; multiplying by a term returns a term as
        well. All other products are addends. Raises RuntimeError when
        the vectors of a scalar product differ in length.
        """
        try:
            a = other.addend()
        except (TypeError, AttributeError):
            try:
                t = other.term()
            except (TypeError, AttributeError):
                # neither addend nor term: treat it as a plain number
                return self*scalar(other)
            else:
                return term([self*a for a in t.addends])
        if a.vector is None:
            return addend(self.scalars + a.scalars, self.vector)
        if self.vector is None:
            return addend(self.scalars + a.scalars, a.vector)
        if len(self.vector) != len(a.vector):
            raise RuntimeError("vector length mismatch in scalar product")
        # scalar product: one addend per component pair; both components
        # are added as individual factors, since the factor list must
        # contain scalars only (x*y would be an addend instead)
        return term([addend(self.scalars + a.scalars + [x, y], None)
                     for x, y in zip(self.vector, a.vector)])

    __rmul__ = __mul__

    def __div__(self, other):
        # 1.0 forces a floating point reciprocal; 1/other would be an
        # integer division for integer arguments under Python 2
        return addend([scalar(1.0/other)] + self.scalars, self.vector)

    # the "/" operator calls __truediv__ under Python 3 and under
    # "from __future__ import division"
    __truediv__ = __div__

    def __str__(self):
        scalarstring = " * ".join([str(factor) for factor in self.scalars])
        if self.vector is None:
            return scalarstring
        if len(scalarstring):
            scalarstring += " * "
        return scalarstring + str(self.vector)
class term:
    """A term, i.e. a sum represented by a list of addends.

    A term is either a scalar term (all addends are scalar addends) or
    a vector term (all addends are vector addends of equal length).
    """

    def __init__(self, addends):
        """Create a term from a non-empty list of addends.

        Raises RuntimeError when scalar and vector addends are mixed or
        when vector addends differ in length.
        """
        assert len(addends)
        try:
            self.length = len(addends[0])
        except (TypeError, AttributeError):
            # the first addend has no length: this is a scalar term
            for a in addends[1:]:
                try:
                    len(a)
                except (TypeError, AttributeError):
                    pass
                else:
                    raise RuntimeError("vector addend in scalar term")
            self.length = None
        else:
            for a in addends[1:]:
                try:
                    l = len(a)
                except (TypeError, AttributeError):
                    raise RuntimeError("scalar addend in vector term")
                if l != self.length:
                    raise RuntimeError("vector length mismatch in term constructor")
        self.addends = addends

    def __len__(self):
        # AttributeError is among the exceptions the callers of len()
        # in this file catch to detect a scalar
        if self.length is None:
            raise AttributeError("scalar term")
        else:
            return self.length

    def __getitem__(self, i):
        """Return the scalar term for component i of a vector term."""
        return term([a[i] for a in self.addends])

    def term(self):
        return self

    def is_linear(self):
        """Tell whether all addends are linear."""
        for a in self.addends:
            if not a.is_linear():
                return False
        return True

    def __add__(self, other):
        try:
            t = other.term()
        except (TypeError, AttributeError):
            # not convertible to a term: treat it as a plain number
            return self + scalar(other)
        else:
            return term(self.addends + t.addends)

    __radd__ = __add__

    def __neg__(self):
        return term([-a for a in self.addends])

    def __sub__(self, other):
        return -other+self

    def __rsub__(self, other):
        return -self+other

    def __mul__(self, other):
        return term([a*other for a in self.addends])

    __rmul__ = __mul__

    def __div__(self, other):
        return term([a/other for a in self.addends])

    # the "/" operator calls __truediv__ under Python 3 and under
    # "from __future__ import division"
    __truediv__ = __div__

    def __str__(self):
        return " + ".join([str(a) for a in self.addends])
class Solver:
    """Linear equation solver.

    Equations are kept as terms that are required to be zero. Whenever
    an equation is added, the solver tries to solve subsets of the
    pending linear equations and assigns the resulting values to the
    variables.
    """

    def __init__(self):
        self.eqs = [] # scalar equations not yet solved (an equation is a term to be zero here)

    def eq(self, lhs, rhs=None):
        """Add the equation lhs == rhs, or lhs == 0 when rhs is None.

        A vector equation is added as one scalar equation per component.
        """
        if rhs is None:
            eq = lhs
        else:
            eq = lhs - rhs
        eq = eq.term()
        try:
            # is it a vector equation?
            neqs = len(eq)
        except (TypeError, AttributeError):
            self.add(eq)
        else:
            for i in range(neqs):
                self.add(eq[i])

    def add(self, equation):
        """Add a scalar equation and solve whatever has become solvable."""
        # the equation is just a term which should be zero
        self.eqs.append(equation)

        # try to solve some combinations of equations
        while 1:
            for eqs in self.combine(self.eqs):
                if self.solve(eqs):
                    break # restart for loop
            else:
                break # quit while loop

    def combine(self, eqs):
        """Generate all subsets of eqs that contain linear equations only."""
        if not len(eqs):
            yield []
        else:
            # first all combinations without eqs[0] ...
            for x in self.combine(eqs[1:]):
                yield x
            # ... then, if it is linear, all combinations including it
            if eqs[0].is_linear():
                for x in self.combine(eqs[1:]):
                    yield [eqs[0]] + x

    def solve(self, eqs):
        """Try to solve a set of linear equations.

        When the number of unknowns equals the number of equations, the
        system is solved, the values are assigned to the variables, the
        equations are removed from self.eqs and a true value is
        returned. A false value is returned for an empty set and when
        there are more unknowns than equations. Raises RuntimeError
        when there are fewer unknowns than equations.
        """
        l = len(eqs)
        if not l:
            return False

        # collect the distinct unknowns in order of appearance
        unknowns = []
        for eq in eqs:
            for a in eq.addends:
                var = a.variable()
                if var is not None and var not in unknowns:
                    unknowns.append(var)

        if len(unknowns) < l:
            raise RuntimeError("equations are overdetermined")
        if len(unknowns) > l:
            return False

        # Numeric.zeros creates integer arrays unless a typecode is
        # given; the prefactors are floats
        a = Numeric.zeros((l, l), Numeric.Float)
        b = Numeric.zeros((l, ), Numeric.Float)
        for i, eq in enumerate(eqs):
            for summand in eq.addends:
                var = summand.variable()
                if var is not None:
                    a[i, unknowns.index(var)] += summand.prefactor()
                else:
                    # constant addends go to the right hand side
                    b[i] -= summand.prefactor()
        for var, value in zip(unknowns, LinearAlgebra.solve_linear_equations(a, b)):
            var.set(value)

        # remove the solved equations in place; they are identified by
        # object identity
        self.eqs[:] = [pending for pending in self.eqs
                       if not [solved for solved in eqs if solved is pending]]
        return True
# module-level solver instance
solver = Solver()


if __name__ == "__main__":

    # demo: three two-dimensional vectors with undefined components
    x, y, z = vector(2, "x"), vector(2, "y"), vector(2, "z")

    solver.eq(4*x + y, 2*x - y + vector([4, 0])) # => x + y = (2, 0)
    solver.eq(x[0] - y[0], z[1])
    solver.eq(x[1] - y[1], z[0])
    solver.eq(vector([5, 0]), z)

    # a parenthesized single argument prints the same as the print statement
    print(x)
    print(y)
    print(z)