- removed isinstance() calls
[PyX/mjg.git] / test / experimental / solve.py
blob0e6efe6c8f3fabd8694be5424d7026a2f601365a
1 #!/usr/bin/env python
2 # -*- coding: ISO-8859-1 -*-
5 # Copyright (C) 2004 André Wobst <wobsta@users.sourceforge.net>
7 # This file is part of PyX (http://pyx.sourceforge.net/).
9 # PyX is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 2 of the License, or
12 # (at your option) any later version.
14 # PyX is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
19 # You should have received a copy of the GNU General Public License
20 # along with PyX; if not, write to the Free Software
21 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
23 import types, sets
24 import Numeric, LinearAlgebra
# built-in numeric types (Python 2 type objects from the types module);
# note that nothing in the visible part of this file references this tuple
valuetypes = (types.IntType, types.LongType, types.FloatType)
class scalar:
    # this class represents a scalar variable
    #
    # Arithmetic on a scalar builds "term" instances. The == operator does
    # not compare anything: it forwards to term.__eq__ and thereby states an
    # equation. A scalar has no value (None) until one is stored in it.

    def __init__(self, varname="(no variable name provided)"):
        # varname is only used by __str__ as long as no value is stored
        self.id = id(self) # compare the id to check for the same variable
                           # (the __eq__ method is used to define "equalities")
        self.varname = varname
        self.value = None

    def term(self):
        # the linear term 1*self + 0
        return term([1], [self], 0)

    def __add__(self, other):
        return self.term() + other

    __radd__ = __add__

    def __sub__(self, other):
        return self.term() - other

    def __rsub__(self, other):
        # other - self
        return term([-1], [self], 0) + other

    def __neg__(self):
        return term([-1], [self], 0)

    def __mul__(self, other):
        # other is used as the prefactor as it is
        return term([other], [self], 0)

    __rmul__ = __mul__

    def __div__(self, other):
        # 1.0 forces a true division: a plain 1/other floor-divides for an
        # integer divisor under Python 2, which turned x/2 into 0*x
        return term([1.0/other], [self], 0)

    # the / operator calls __truediv__ whenever true division is in effect
    __truediv__ = __div__

    def __eq__(self, other):
        # states the equation self == other; returns whatever term.__eq__
        # returns
        return self.term() == other

    def is_set(self):
        return self.value is not None

    def set(self, value):
        # store a value; raises RuntimeError when a value is already stored
        if self.is_set():
            raise RuntimeError("variable already defined")
        self.value = value

    def get(self):
        # return the stored value; raises RuntimeError when there is none
        if not self.is_set():
            raise RuntimeError("variable not yet defined")
        return self.value

    def __str__(self):
        # the value once it is known, the variable name otherwise
        if self.is_set():
            return str(self.value)
        else:
            return self.varname

    def __float__(self):
        # __float__ has to return a real float, but the stored value might
        # be an integer (raises RuntimeError when no value is stored)
        return float(self.get())
class term:
    # this class represents the linear term:
    # sum([p*v.value for p, v in zip(self.prefactors, self.vars)]) + self.const
    #
    # Arithmetic returns new term instances; the operands are not modified.
    # The == operator does not compare anything: it hands the equation
    # "self - other == 0" over to the module-level solver.

    def __init__(self, prefactors, vars, const):
        # prefactors and vars are parallel lists: prefactors[i] belongs to
        # vars[i]; const is the constant part of the term
        assert len(prefactors) == len(vars)
        self.id = id(self) # compare the id to check for the same term
                           # (the __eq__ method is used to define "equalities")
        self.prefactors = prefactors
        self.vars = vars
        self.const = const

    def term(self):
        return self

    def __add__(self, other):
        # everything providing a term() method is converted by that method,
        # anything else is taken as a constant; errors raised inside term()
        # are not swallowed (the former bare "except:" hid all of them)
        to_term = getattr(other, "term", None)
        if to_term is None:
            other = term([], [], other)
        else:
            other = to_term()
        vars = list(self.vars)
        prefactors = list(self.prefactors)
        # position of each variable id within vars (first occurrence wins);
        # kept up to date while appending, hence a variable occurring more
        # than once in other is merged as well
        positions = {}
        for i, v in enumerate(vars):
            positions.setdefault(v.id, i)
        for p, v in zip(other.prefactors, other.vars):
            if v.id in positions:
                prefactors[positions[v.id]] += p
            else:
                positions[v.id] = len(vars)
                vars.append(v)
                prefactors.append(p)
        return term(prefactors, vars, self.const + other.const)

    __radd__ = __add__

    def __sub__(self, other):
        return self + (-other)

    def __neg__(self):
        # the variable list is copied so that the terms do not share it
        return term([-p for p in self.prefactors], list(self.vars), -self.const)

    def __rsub__(self, other):
        # other - self
        return -self+other

    def __mul__(self, other):
        return term([p*other for p in self.prefactors], list(self.vars), self.const*other)

    __rmul__ = __mul__

    def __div__(self, other):
        # 1.0*other promotes an integer divisor to a float: dividing
        # integers by an integer floor-divides under Python 2
        divisor = 1.0*other
        return term([p/divisor for p in self.prefactors], list(self.vars), self.const/divisor)

    # the / operator calls __truediv__ whenever true division is in effect
    __truediv__ = __div__

    def __eq__(self, other):
        # registers the equation self - other == 0 at the module-level
        # solver; there is no return value (None), thus the result can not
        # be used as a truth value
        solver.add(self-other)

    def __str__(self):
        return "+".join(["%s*%s" % pv for pv in zip(self.prefactors, self.vars)]) + "+" + str(self.const)
class Solver:
    # linear equation solver
    #
    # Equations are terms which should be zero. They are collected in
    # self.eqs until a subset of them contains exactly as many unknown
    # variables as equations; such a subset is solved, the results are
    # stored in the variables and the equations are removed from self.eqs.

    def __init__(self):
        self.eqs = [] # equations still to be taken into account

    def add(self, equation):
        # the equation is just a term which should be zero
        self.eqs.append(equation)

        # try to solve some combinations of equations
        # (solving a subset modifies self.eqs and can make further subsets
        # solvable, hence the combinations are created anew after each
        # success)
        while 1:
            for eqs in self.combine(self.eqs):
                if self.solve(eqs):
                    break # restart for loop
            else:
                break # quit while loop

    def combine(self, eqs):
        # create combinations of equations
        # (a generator yielding all subsets of eqs as lists, the empty
        # subset first; note that there are 2**len(eqs) subsets)
        if not len(eqs):
            yield []
        else:
            for x in self.combine(eqs[1:]):
                yield x
                yield [eqs[0]] + x

    def solve(self, eqs):
        # try to solve a set of equations
        # returns 1 when the equations were solved (the values are stored
        # in the variables and the equations are removed from self.eqs)
        # and 0 otherwise
        l = len(eqs)
        if not l:
            return 0

        # ids of the variables without a value, each id once
        vids = []
        for eq in eqs:
            vids.extend([v.id for v in eq.vars if v.id not in vids and not v.is_set()])
        if len(vids) != l:
            # fewer unknown variables than equations are not handled
            assert len(vids) > l, "more equations than unknown variables"
            return 0

        # set up a*x = b; the Float typecode is crucial, since Numeric.zeros
        # creates integer arrays by default, which can not hold non-integer
        # prefactors and constants properly
        a = Numeric.zeros((l, l), Numeric.Float)
        b = Numeric.zeros((l, ), Numeric.Float)
        index = {}
        for i, vid in enumerate(vids):
            index[vid] = i
        vars = {}
        for i, eq in enumerate(eqs):
            for p, v in zip(eq.prefactors, eq.vars):
                if v.is_set():
                    # known variables contribute to the right hand side
                    b[i] -= p*v.value
                else:
                    a[i, index[v.id]] += p
                    vars[index[v.id]] = v
            b[i] -= eq.const
        for i, value in enumerate(LinearAlgebra.solve_linear_equations(a, b)):
            vars[i].value = value
        # the equations are done; the unpacking fails unless each equation
        # is contained in self.eqs exactly once
        for eq in eqs:
            i, = [i for i, selfeq in enumerate(self.eqs) if selfeq.id == eq.id]
            del self.eqs[i]
        return 1
# the module-wide solver instance all equations are handed to by term.__eq__
solver = Solver()


if __name__ == "__main__":

    # small demonstration: three variables and three equations

    x, y, z = scalar("x"), scalar("y"), scalar("z")

    # each comparison is evaluated for its side effect only: it states an
    # equation instead of calculating a truth value
    x + y == 2*x - y + 3
    x - y == z
    5 == z

    # a single parenthesized expression prints like the print statement does
    print("x=%s" % x)
    print("y=%s" % y)
    print("z=%s" % z)