2 # -*- coding: ISO-8859-1 -*-
5 # Copyright (C) 2004 André Wobst <wobsta@users.sourceforge.net>
7 # This file is part of PyX (http://pyx.sourceforge.net/).
9 # PyX is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 2 of the License, or
12 # (at your option) any later version.
14 # PyX is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
19 # You should have received a copy of the GNU General Public License
20 # along with PyX; if not, write to the Free Software
21 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
23 import Numeric
, LinearAlgebra
27 # represents a scalar variable or constant
29 def __init__(self
, value
=None, varname
="unnamed_scalar"):
33 self
.varname
= varname
36 return addend([self
], None)
39 return term([self
.addend()])
41 def __add__(self
, other
):
42 return self
.term() + other
46 def __sub__(self
, other
):
47 return self
.term() - other
49 def __rsub__(self
, other
):
50 return -self
.term() + other
55 def __mul__(self
, other
):
56 return self
.addend()*other
60 def __div__(self
, other
):
61 return self
.addend()/other
64 return self
.value
is not None
68 raise RuntimeError("variable already defined")
70 self
.value
= float(value
)
72 raise RuntimeError("float expected")
76 raise RuntimeError("variable not yet defined")
81 return "%s{=%s}" % (self
.varname
, self
.value
)
90 # represents a vector, i.e. a list of scalars
92 def __init__(self
, dimension_or_values
, varname
="unnamed_vector"):
96 raise RuntimeError("a vectors varname should be a string (you probably wanted to write vector([x, y]) instead of vector(x, y))")
99 self
.scalars
= [scalar(value
=value
, varname
="%s[%i]" % (varname
, i
))
100 for i
, value
in enumerate(dimension_or_values
)]
101 except (TypeError, AttributeError):
103 self
.scalars
= [scalar(varname
="%s[%i]" % (varname
, i
))
104 for i
in range(dimension_or_values
)]
105 self
.varname
= varname
108 return len(self
.scalars
)
110 def __getitem__(self
, i
):
111 return self
.scalars
[i
]
113 def __getattr__(self
, attr
):
121 raise AttributeError(attr
)
124 return addend([], self
)
127 return term([self
.addend()])
129 def __add__(self
, other
):
130 return self
.term() + other
134 def __sub__(self
, other
):
135 return self
.term() - other
137 def __rsub__(self
, other
):
138 return -self
.term() + other
141 return -self
.addend()
143 def __mul__(self
, other
):
144 return self
.addend()*other
148 def __div__(self
, other
):
149 return self
.addend()/other
152 return "%s{=(%s)}" % (self
.varname
, ", ".join([str(scalar
) for scalar
in self
.scalars
]))
156 # represents an addend of a term, i.e. a list of scalars and
157 # optionally a vector (for a vector term) otherwise the vector
160 def __init__(self
, scalars
, vector
):
161 # self.vector might be None for a scalar addend
162 self
.scalars
= scalars
166 return len(self
.vector
)
168 def __getitem__(self
, i
):
169 return addend(self
.scalars
+ [self
.vector
[i
]], None)
175 return term([self
.addend()])
178 assert self
.vector
is None
179 return len([scalar
for scalar
in self
.scalars
if not scalar
.is_set()]) < 2
182 assert self
.is_linear()
184 for scalar_set
in [scalar
for scalar
in self
.scalars
if scalar
.is_set()]:
185 prefactor
*= scalar_set
.get()
189 assert self
.is_linear()
191 variable
, = [scalar
for scalar
in self
.scalars
if not scalar
.is_set()]
197 def __add__(self
, other
):
198 return self
.term() + other
202 def __sub__(self
, other
):
203 return self
.term() - other
205 def __rsub__(self
, other
):
206 return -self
.term() + other
209 return addend([scalar(-1)] + self
.scalars
, self
.vector
)
211 def __mul__(self
, other
):
214 except (TypeError, AttributeError):
217 except (TypeError, AttributeError):
218 return self
*scalar(other
)
220 return term([self
*a
for a
in t
.addends
])
222 if a
.vector
is not None:
223 if self
.vector
is not None:
224 if len(self
.vector
) != len(a
.vector
):
225 raise RuntimeError("vector length mismatch in scalar product")
226 return term([addend(self
.scalars
+ a
.scalars
+ [x
*y
], None)
227 for x
, y
in zip(self
.vector
, a
.vector
)])
229 return addend(self
.scalars
+ a
.scalars
, a
.vector
)
231 return addend(self
.scalars
+ a
.scalars
, self
.vector
)
235 def __div__(self
, other
):
236 return addend([scalar(1/other
)] + self
.scalars
, self
.vector
)
239 scalarstring
= " * ".join([str(scalar
) for scalar
in self
.scalars
])
240 if self
.vector
is None:
243 if len(scalarstring
):
244 scalarstring
+= " * "
245 return scalarstring
+ str(self
.vector
)
249 # represents a term, i.e. a list of addends
251 def __init__(self
, addends
):
254 self
.length
= len(addends
[0])
255 except (TypeError, AttributeError):
256 for addend
in addends
[1:]:
259 except (TypeError, AttributeError):
262 raise RuntimeError("vector addend in scalar term")
265 for addend
in addends
[1:]:
268 except (TypeError, AttributeError):
269 raise RuntimeError("scalar addend in vector term")
271 raise RuntimeError("vector length mismatch in term constructor")
272 self
.addends
= addends
275 if self
.length
is None:
276 raise AttributeError("scalar term")
280 def __getitem__(self
, i
):
281 return term([addend
[i
] for addend
in self
.addends
])
288 for addend
in self
.addends
:
289 is_linear
= is_linear
and addend
.is_linear()
292 def __add__(self
, other
):
296 return self
+ scalar(other
)
298 return term(self
.addends
+ t
.addends
)
303 return term([-addend
for addend
in self
.addends
])
305 def __sub__(self
, other
):
308 def __rsub__(self
, other
):
311 def __mul__(self
, other
):
312 return term([addend
*other
for addend
in self
.addends
])
316 def __div__(self
, other
):
317 return term([addend
/other
for addend
in self
.addends
])
320 return " + ".join([str(addend
) for addend
in self
.addends
])
324 # linear equation solver
327 self
.eqs
= [] # scalar equations not yet solved (a equation is a term to be zero here)
329 def eq(self
, lhs
, rhs
=None):
336 # is it a vector equation?
338 except (TypeError, AttributeError):
341 for i
in range(neqs
):
344 def add(self
, equation
):
345 # the equation is just a term which should be zero
346 self
.eqs
.append(equation
)
348 # try to solve some combinations of equations
350 for eqs
in self
.combine(self
.eqs
):
352 break # restart for loop
354 break # quit while loop
356 def combine(self
, eqs
):
357 # create combinations of linear equations
361 for x
in self
.combine(eqs
[1:]):
363 if eqs
[0].is_linear():
364 for x
in self
.combine(eqs
[1:]):
367 def solve(self
, eqs
):
368 # try to solve a set of linear equations
373 for addend
in eq
.addends
:
374 var
= addend
.variable()
375 if var
is not None and var
not in vars:
378 a
= Numeric
.zeros((l
, l
))
379 b
= Numeric
.zeros((l
, ))
380 for i
, eq
in enumerate(eqs
):
381 for addend
in eq
.addends
:
382 var
= addend
.variable()
384 a
[i
, vars.index(var
)] += addend
.prefactor()
386 b
[i
] -= addend
.prefactor()
387 for i
, value
in enumerate(LinearAlgebra
.solve_linear_equations(a
, b
)):
390 i
, = [i
for i
, selfeq
in enumerate(self
.eqs
) if selfeq
== eq
]
394 raise RuntimeError("equations are overdetermined")
400 if __name__
== "__main__":
406 solver
.eq(4*x
+ y
, 2*x
- y
+ vector([4, 0])) # => x + y = (2, 0)
407 solver
.eq(x
[0] - y
[0], z
[1])
408 solver
.eq(x
[1] - y
[1], z
[0])
409 solver
.eq(vector([5, 0]), z
)