2 # -*- coding: ISO-8859-1 -*-
5 # Copyright (C) 2004 André Wobst <wobsta@users.sourceforge.net>
7 # This file is part of PyX (http://pyx.sourceforge.net/).
9 # PyX is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 2 of the License, or
12 # (at your option) any later version.
14 # PyX is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 # GNU General Public License for more details.
19 # You should have received a copy of the GNU General Public License
20 # along with PyX; if not, write to the Free Software
21 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
24 import Numeric
, LinearAlgebra
27 valuetypes
= (types
.IntType
, types
.LongType
, types
.FloatType
)
31 # this class represents a scalar variable
33 def __init__(self
, varname
="(no variable name provided)"):
34 self
.id = id(self
) # compare the id to check for the same variable
35 # (the __eq__ method is used to define "equalities")
36 self
.varname
= varname
40 return term([1], [self
], 0)
42 def __add__(self
, other
):
43 return term([1], [self
], 0) + other
47 def __sub__(self
, other
):
48 return term([1], [self
], 0) - other
50 def __rsub__(self
, other
):
51 return term([-1], [self
], 0) + other
54 return term([-1], [self
], 0)
56 def __mul__(self
, other
):
57 return term([other
], [self
], 0)
61 def __div__(self
, other
):
62 return term([1/other
], [self
], 0)
64 def __eq__(self
, other
):
65 return term([1], [self
], 0) == other
68 return self
.value
is not None
72 raise RuntimeError("variable already defined")
77 raise RuntimeError("variable not yet defined")
82 return str(self
.value
)
91 # this class represents the linear term:
92 # sum([p*v.value for p, v in zip(self.prefactors, self.vars]) + self.const
94 def __init__(self
, prefactors
, vars, const
):
95 assert len(prefactors
) == len(vars)
96 self
.id = id(self
) # compare the id to check for the same term
97 # (the __eq__ method is used to define "equalities")
98 self
.prefactors
= prefactors
105 def __add__(self
, other
):
109 other
= term([], [], other
)
111 prefactors
= self
.prefactors
[:]
112 vids
= [v
.id for v
in vars]
113 for p
, v
in zip(other
.prefactors
, other
.vars):
115 prefactors
[vids
.index(v
.id)] += p
119 return term(prefactors
, vars, self
.const
+ other
.const
)
123 def __sub__(self
, other
):
124 return self
+ (-other
)
127 return term([-p
for p
in self
.prefactors
], self
.vars, -self
.const
)
129 def __rsub__(self
, other
):
132 def __mul__(self
, other
):
133 return term([p
*other
for p
in self
.prefactors
], self
.vars, self
.const
*other
)
137 def __div__(self
, other
):
138 return term([p
/other
for p
in self
.prefactors
], self
.vars, self
.const
/other
)
140 def __eq__(self
, other
):
141 solver
.add(self
-other
)
144 return "+".join(["%s*%s" % pv
for pv
in zip(self
.prefactors
, self
.vars)]) + "+" + str(self
.const
)
148 # linear equation solver
151 self
.eqs
= [] # equations still to be taken into account
153 def add(self
, equation
):
154 # the equation is just a term which should be zero
155 self
.eqs
.append(equation
)
157 # try to solve some combinations of equations
159 for eqs
in self
.combine(self
.eqs
):
161 break # restart for loop
163 break # quit while loop
165 def combine(self
, eqs
):
166 # create combinations of equations
170 for x
in self
.combine(eqs
[1:]):
174 def solve(self
, eqs
):
175 # try to solve a set of equations
180 vids
.extend([v
.id for v
in eq
.vars if v
.id not in vids
and not v
.is_set()])
182 a
= Numeric
.zeros((l
, l
))
183 b
= Numeric
.zeros((l
, ))
185 for i
, vid
in enumerate(vids
):
188 for i
, eq
in enumerate(eqs
):
189 for p
, v
in zip(eq
.prefactors
, eq
.vars):
193 a
[i
, index
[v
.id]] += p
194 vars[index
[v
.id]] = v
196 for i
, value
in enumerate(LinearAlgebra
.solve_linear_equations(a
, b
)):
197 vars[i
].value
= value
199 i
, = [i
for i
, selfeq
in enumerate(self
.eqs
) if selfeq
.id == eq
.id]
209 if __name__
== "__main__":