Cos -> cos transition
[sympy.git] / sympy / solvers / solvers.py
blob241c80b36c311ffa81828b404ba5ffb50040ac18
""" This module contains solvers for all kinds of equations:

    - algebraic, use solve()

    - recurrence, use rsolve() (not implemented)

    - differential, use dsolve() (only a few equation types are supported)

"""
12 from sympy.core import *
14 from sympy.simplify import simplify, collect
15 from sympy.matrices import Matrix, zeronm
16 from sympy.polynomials import roots, PolynomialException
17 from sympy.utilities import any
def _as_lhs_expression(eq):
    """Return ``eq`` rewritten as a single expression meant to equal zero.

    An ``Equality`` (built with the '==' syntax) becomes ``lhs - rhs``;
    anything else is passed through ``Basic.sympify``.
    """
    if isinstance(eq, Equality):
        # got equation, so move all the
        # terms to the left hand side
        return eq.lhs - eq.rhs
    return Basic.sympify(eq)


def solve(eq, syms, simplified=True):
    """Solves univariate polynomial equations and linear systems with
    arbitrary symbolic coefficients. This function is just a wrapper
    which analyses its arguments and executes more specific
    functions like 'roots' or 'solve_linear_system' etc.

    On input you have to specify an equation or a set of equations
    (in this case via a list) using '==' pretty syntax or via
    ordinary expressions (understood as 'expr == 0'), and a list of
    variables (a single symbol is accepted too). For a single
    equation only the first variable is solved for.

    On output you will get a list of solutions in the univariate case
    or a dictionary with variables as keys and solutions as values
    in the other case. Variables which can be assigned an arbitrary
    value are omitted from the output. For an inconsistent linear
    system None is returned.

    Optionally it is possible to have the solutions preprocessed
    using simplification routines if the 'simplified' flag is set.

    Raises NotImplementedError when a single equation is not a
    polynomial in the first variable, or when a list of equations
    is not linear in 'syms'.

    To solve recurrence relations or differential equations use
    'rsolve' or 'dsolve' functions respectively, which are also
    wrappers combining set of problem specific methods.

    >>> from sympy import *
    >>> x, y, a = symbols('xya')

    >>> r = solve(x**2 - 3*x + 2, x)
    >>> r.sort()
    >>> print r
    [1, 2]

    >>> solve(x**2 == a, x)
    [-a**(1/2), a**(1/2)]

    >>> solve(x**4 == 1, x)
    [I, 1, -1, -I]

    >>> solve([x + 5*y == 2, -3*x + 6*y == 15], [x, y])
    {y: 1, x: -3}

    """
    if isinstance(syms, Basic):
        syms = [syms]

    if not isinstance(eq, list):
        equ = _as_lhs_expression(eq)

        try:
            # 'roots' method will return all possible complex
            # solutions, however we have to remove duplicates
            solutions = list(set(roots(equ, syms[0])))
        except PolynomialException:
            raise NotImplementedError(
                "Not a polynomial equation. Can't solve it, yet.")

        if simplified:
            return [simplify(s) for s in solutions]
        return solutions

    if not eq:
        return {}

    # augmented matrix: one row per equation, one column per
    # variable, plus a last column holding the right hand side
    n, m = len(eq), len(syms)
    matrix = zeronm(n, m+1)

    index = dict((sym, j) for j, sym in enumerate(syms))

    for i, item in enumerate(eq):
        equ = _as_lhs_expression(item)
        content = collect(equ.expand(), syms, evaluate=False)

        for var, expr in content.items():
            if isinstance(var, Symbol) and not expr.has(*syms):
                matrix[i, index[var]] = expr
            elif isinstance(var, Basic.One) and not expr.has(*syms):
                # constant term goes to the right hand side
                matrix[i, m] = -expr
            else:
                # a product/power of variables, or a coefficient
                # which itself depends on the variables
                raise NotImplementedError(
                    "Not a linear system. Can't solve it, yet.")

    return solve_linear_system(matrix, syms, simplified)
def solve_linear_system(system, symbols, simplified=True):
    """Solve a system of N linear equations in M variables.

    Both square (Cramer) and non-square systems are supported. A linear
    system has zero, one or infinitely many solutions; respectively this
    procedure returns None or a dictionary with solutions. For an
    underdetermined system the arbitrary (free) variables are left out of
    the dictionary and the remaining ones are expressed in terms of them.
    This may produce an empty dictionary, which means that all symbols
    can be assigned arbitrary values.

    Input to this function is an Nx(M+1) matrix, which means it has
    to be in augmented form: the first M columns hold the coefficients
    of 'symbols' and the last one the right hand side. If you are
    unhappy with such setting use the 'solve' method instead, where you
    can input equations explicitly. The given matrix is not modified;
    a local copy is made.

    The algorithm used here is Gaussian elimination where every pivot
    row is divided by its pivot (so it is not fraction free), which
    results, after elimination, in an upper-triangular matrix with unit
    pivots. Then solutions are found using back-substitution. This
    approach is more efficient and compact than the Gauss-Jordan method.

    If 'simplified' is set, every solution is passed through 'simplify'.

    >>> from sympy import *
    >>> x, y = symbols('xy')

    Solve the following system:

           x + 4 y ==  2
        -2 x +   y == 14

    >>> system = Matrix(( (1, 4, 2), (-2, 1, 14)))
    >>> solve_linear_system(system, [x, y])
    {y: 2, x: -6}

    """
    matrix = system[:,:]
    syms = symbols[:]

    i, m = 0, matrix.cols-1  # don't count augmentation

    while i < matrix.lines:
        # Look for a pivot in row i among the coefficient columns
        # i..m-1 only. Once i >= m there are no such columns left, so
        # the augmentation column is never mistaken for a pivot and no
        # column past it is ever read.
        for k in range(i, m):
            if matrix[i, k] != 0:
                break
        else:
            # all coefficients of this row are zero
            if matrix[i, m] != 0:
                return None  # 0 == nonzero: no solutions
            # zero row or was a linear combination of
            # other rows so now we can safely skip it
            matrix.row_del(i)
            continue

        if k != i:
            # we want to change the order of columns so
            # the order of variables must also change
            syms[i], syms[k] = syms[k], syms[i]
            matrix.col_swap(i, k)

        pivot = matrix[i, i]

        # divide all elements in the current row by the pivot
        matrix.row(i, lambda x, _: x / pivot)

        for k in range(i+1, matrix.lines):
            if matrix[k, i] != 0:
                coeff = matrix[k, i]

                # subtract from row k the pivot row
                # multiplied by the extracted coefficient
                matrix.row(k, lambda x, j: x - matrix[i, j]*coeff)

        i += 1

    # The augmented matrix is now in row-echelon form with unit pivots
    # in columns 0..i-1 (and i == matrix.lines).
    if len(syms) < matrix.lines:
        return None  # no solutions

    # If len(syms) == i the system is Cramer equivalent: the parameter
    # loop below is empty and the solution is unique. Otherwise the
    # variables syms[i:] are free parameters.
    solutions = {}

    for k in range(i-1, -1, -1):
        content = matrix[k, m]

        # run back-substitution for already solved variables
        for j in range(k+1, i):
            content -= matrix[k, j]*solutions[syms[j]]

        # run back-substitution for parameters
        for j in range(i, m):
            content -= matrix[k, j]*syms[j]

        if simplified:
            solutions[syms[k]] = simplify(content)
        else:
            solutions[syms[k]] = content

    return solutions
def solve_undetermined_coeffs(equ, coeffs, sym, simplified=True):
    """Solve equation of a type p(x; a_1, ..., a_k) == q(x) where both
    p, q are univariate polynomials and p depends on k parameters.
    The result of this function is a dictionary with symbolic
    values of those parameters with respect to coefficients in q.

    This function accepts both Equality class instances and ordinary
    SymPy expressions (understood as 'expr == 0'). Specification of
    parameters and variable is obligatory for efficiency and
    simplicity reasons.

    Returns None when some collected coefficient still depends on
    'sym'; otherwise the result of solving the coefficient system
    with 'solve' (which may itself be None for an inconsistent system).

    >>> from sympy import *
    >>> a, b, c, x = symbols('a', 'b', 'c', 'x')

    >>> solve_undetermined_coeffs(2*a*x + a+b == x, [a, b], x)
    {a: 1/2, b: -1/2}

    >>> solve_undetermined_coeffs(a*c*x + a+b == x, [a, b], x)
    {a: 1/c, b: -1/c}

    """
    if isinstance(equ, Equality):
        # got equation, so move all the
        # terms to the left hand side
        equ = equ.lhs - equ.rhs
    else:
        equ = Basic.sympify(equ)

    # One coefficient expression per power of 'sym'; all must vanish.
    # list() matters: 'solve' only treats a real list as a system.
    system = list(collect(equ.expand(), sym, evaluate=False).values())

    if any([expr.has(sym) for expr in system]):
        return None  # no solutions

    # consecutive powers in the input expressions have
    # been successfully collected, so solve remaining
    # system using Gaussian elimination algorithm
    return solve(system, coeffs, simplified)
def solve_linear_system_LU(matrix, syms):
    """Solve a square linear system given in augmented form via LU.

    'matrix' must be an Nx(N+1) augmented matrix whose left NxN part is
    invertible (only invertible systems are handled). Returns a dictionary
    mapping syms[i] to the i-th component of the solution.

    Raises ValueError when 'matrix' is not of shape Nx(N+1). This check
    is explicit instead of an assert so it is not skipped under -O.
    """
    if matrix.lines != matrix.cols - 1:
        raise ValueError("expected an Nx(N+1) augmented matrix, got %dx%d"
                         % (matrix.lines, matrix.cols))

    n = matrix.lines
    A = matrix[:n, :n]
    b = matrix[:, n:]  # last column: right hand side
    soln = A.LUsolve(b)

    return dict((syms[i], soln[i, 0]) for i in range(soln.lines))
def dsolve(eq, funcs):
    """
    Solves any (supported) kind of differential equation.

    Usage
    =====
        dsolve(f, y(x)) -> Solve a differential equation f for the function y

    Details
    =======
        @param f: ordinary differential equation

        @param y: indeterminate function of one variable; either an
            applied function like y(x) or a one-element list/tuple of it

        - you can declare the derivative of an unknown function this way:
        >>> from sympy import *
        >>> x = Symbol('x') # x is the independent variable
        >>> f = Function(x) # f is a function of x
        >>> f_ = Derivative(f, x) # f_ will be the derivative of f with respect to x

        - This function just parses the equation "eq" and determines the type of
        differential equation, then it determines all the coefficients and then
        calls the particular solver, which just accepts the coefficients.

        - Recognized forms (a, b must not contain f):
          a*f'(x) + b, a*f''(x) + b*f(x), and equations matching
          a*(x*exp(-f(x)))''/(x*exp(-f(x))) (expanded), either directly
          or after multiplying eq by exp(f(x))/exp(-f(x)).

        - Raises NotImplementedError for any other equation, and when
        more than one function is given.

    Examples
    ========
        >>> from sympy import *
        >>> x = Symbol('x')
        >>> f = Function('f')
        >>> fx = f(x)
        >>> dsolve(Derivative(Derivative(fx,x),x)+9*fx, fx)
        C1*sin(3*x) + C2*cos(3*x)

    """
    # currently only solve for one function
    if not (isinstance(funcs, Basic) or len(funcs) == 1):
        raise NotImplementedError(
            "dsolve: can only solve for a single function")

    if isinstance(funcs, (list, tuple)):  # normalize args
        f = funcs[0]
    else:
        f = funcs

    # presumably f is an applied function such as f(x):
    # its first argument is the independent variable
    x = f[0]
    f = f.func

    a = Wild('a', exclude=[f])
    b = Wild('b', exclude=[f])

    # first order: a*f'(x) + b
    r = eq.match(a*Derivative(f(x), x) + b)
    if r:
        return solve_ODE_first_order(r[a], r[b], f(x), x)

    # second order without a first-derivative term: a*f''(x) + b*f(x)
    r = eq.match(a*Derivative(f(x), x, x) + b*f(x))
    if r:
        return solve_ODE_second_order(r[a], 0, r[b], f(x), x)

    # special equations, that we know how to solve
    t = x*S.Exp(-f(x))
    tt = (a*Derivative(t, x, x)/t).expand()

    r = eq.match(tt)
    if r:
        return solve_ODE_1(f(x), x)

    # same pattern after rescaling the equation
    # NOTE(review): exp(f)/exp(-f) == exp(2*f); confirm intended factor
    neq = eq*S.Exp(f(x))/S.Exp(-f(x))
    r = neq.match(tt)
    if r:
        return solve_ODE_1(f(x), x)

    raise NotImplementedError("dsolve: Cannot solve " + str(eq))
def solve_ODE_first_order(a, b, f, x):
    """Solve the first order equation a*f'(x) + b = 0 for f(x).

    Since f'(x) = -b/a, the general solution is the antiderivative of
    -b/a with respect to x plus an integration constant C1. The 'f'
    argument is accepted for a uniform signature but is not used.
    """
    from sympy.integrals.integrals import integrate

    derivative = -b/a
    particular = integrate(derivative, x)
    return particular + Symbol("C1")
def solve_ODE_second_order(a, b, c, f, x):
    """Solve a*f''(x) + b*f'(x) + c*f(x) = 0 for f(x).

    Only a very special case is handled: b = 0 with a and c not
    depending on x. The 'b' and 'f' arguments are not used. The
    returned general solution is

        C1*sin(sqrt(c/a)*x) + C2*cos(sqrt(c/a)*x)
    """
    # NOTE(review): b is neither checked nor used; dsolve passes 0 here
    arg = S.Sqrt(c/a)*x
    return Symbol("C1")*S.Sin(arg) + Symbol("C2")*Basic.cos(arg)
def solve_ODE_1(f, x):
    """Solve (x*exp(-f(x)))'' = 0 for f(x).

    Integrating twice gives x*exp(-f(x)) = C1*x + C2, hence
    f(x) = -log(C1 + C2/x). The 'f' argument is not used.
    """
    C1, C2 = Symbol("C1"), Symbol("C2")
    inner = C1 + C2/x
    return -S.Log(inner)