"Fix" docstring where output depends on internal ordering.
[sympy.git] / sympy / solvers / solvers.py
blob938a0830086796381f4e8d93da6385882715f6e5
""" This module contains solvers for all kinds of equations:

    - algebraic, use solve()

    - recurrence, use rsolve()

    - differential, use dsolve()

    - transcendental, use tsolve()

    - nonlinear (numerically), use msolve() (you will need a good starting point)

"""
16 from sympy.core.sympify import sympify
17 from sympy.core.basic import Basic, S, C
18 from sympy.core.symbol import Symbol, Wild
19 from sympy.core.relational import Equality
20 from sympy.core.function import Derivative, diff
22 from sympy.functions import sqrt, log, exp, LambertW
23 from sympy.simplify import simplify, collect
24 from sympy.matrices import Matrix, zeronm
25 from sympy.polys import roots
27 from sympy.utilities import any, all
28 from sympy.utilities.lambdify import lambdify
29 from sympy.solvers.numeric import newton
31 from sympy.solvers.polysys import solve_poly_system
def solve(f, *symbols, **flags):
    """Solve a single equation or a system of equations.

       A single input (an expression, taken as equal to zero, or an
       Equality) must be univariate: it is solved through roots() when
       it is a polynomial in the given symbol and through tsolve()
       otherwise. A tuple, list or set of equations is treated as a
       system: it must be polynomial in the symbols and is passed to
       solve_linear_system() when every equation is linear and to
       solve_poly_system() otherwise. The type of the output therefore
       depends on the input; refer to those functions for details.

       The symbols may be given as separate arguments or as one list,
       tuple or set. Solutions of a single equation are simplified
       unless simplified=False is passed; the flags are forwarded
       unchanged to solve_linear_system().

       Raises ValueError when no symbols are given, TypeError when one
       of them is not a Symbol and NotImplementedError for multivariate
       single equations and for non-polynomial systems.

       For recurrence relations or differential equations use rsolve()
       or dsolve() respectively.

       >>> from sympy import *
       >>> x,y = symbols('xy')

       Solve a polynomial equation:

       >>> solve(x**4-1, x)
       [1, -1, -I, I]

       Solve a linear system:

       >>> solve((x+5*y-2, -3*x+6*y-15), x, y)
       {y: 1, x: -3}

    """
    if not symbols:
        raise ValueError('no symbols were given')

    # solve(eq, [x, y]) is accepted as a synonym of solve(eq, x, y)
    if len(symbols) == 1 and isinstance(symbols[0], (list, tuple, set)):
        symbols = symbols[0]

    symbols = [sympify(s) for s in symbols]

    if not all(s.is_Symbol for s in symbols):
        raise TypeError('not a Symbol')

    if isinstance(f, (tuple, list, set)):
        # --- system of equations ---
        if not f:
            return {}

        polys = []

        for equation in f:
            equation = sympify(equation)

            if isinstance(equation, Equality):
                equation = equation.lhs - equation.rhs

            poly = equation.as_poly(*symbols)

            if poly is None:
                raise NotImplementedError

            polys.append(poly)

        if not all(p.is_linear for p in polys):
            return solve_poly_system(polys)

        # build the augmented matrix of the linear system
        rows, last = len(f), len(symbols)
        augmented = zeronm(rows, last + 1)

        for row, poly in enumerate(polys):
            for coeff, monom in poly.iter_terms():
                try:
                    # a linear term has exponent 1 at exactly one symbol
                    col = list(monom).index(1)
                    augmented[row, col] = coeff
                except ValueError:
                    # constant term: moved to the right-hand side
                    augmented[row, last] = -coeff

        return solve_linear_system(augmented, *symbols, **flags)

    # --- single equation ---
    f = sympify(f)

    if isinstance(f, Equality):
        f = f.lhs - f.rhs

    if len(symbols) != 1:
        raise NotImplementedError('multivariate equation')

    poly = f.as_poly(*symbols)

    if poly is None:
        result = [tsolve(f, *symbols)]
    else:
        result = roots(poly, cubics=True, quartics=True).keys()

    if flags.get('simplified', True):
        return [simplify(solution) for solution in result]

    return result
def solve_linear_system(system, *symbols, **flags):
    """Solve system of N linear equations with M variables, which means
       both Cramer and over defined systems are supported. The possible
       number of solutions is zero, one or infinite. Respectively this
       procedure will return None or dictionary with solutions. In the
       case of over defined system all arbitrary parameters are skipped.
       This may cause situation in which empty dictionary is returned.
       In this case it means all symbols can be assigned arbitrary values.

       Input to this function is a Nx(M+1) matrix, which means it has
       to be in augmented form. If you are unhappy with such setting
       use 'solve' method instead, where you can input equations
       explicitly. And don't worry about the matrix, this function
       is persistent and will make a local copy of it.

       The algorithm used here is fraction free Gaussian elimination,
       which results, after elimination, in upper-triangular matrix.
       Then solutions are found using back-substitution. This approach
       is more efficient and compact than the Gauss-Jordan method.

       Solutions are passed through simplify() unless the flag
       simplified=False is given.

       >>> from sympy import *
       >>> x, y = symbols('xy')

       Solve the following system:

              x + 4 y ==  2
           -2 x +   y == 14

       >>> system = Matrix(( (1, 4, 2), (-2, 1, 14)))
       >>> solve_linear_system(system, x, y)
       {y: 2, x: -6}

    """
    matrix = system[:,:]
    syms = list(symbols)

    i, m = 0, matrix.cols-1  # don't count augmentation

    while i < matrix.lines:
        if i == m:
            # every variable column already holds a pivot, so the rows
            # that are left contain nothing but their constant term;
            # the augmentation column must never be used as a pivot
            for r in range(i, matrix.lines):
                if matrix[r, m] != 0:
                    return None   # no solutions

            # the surplus rows are all zero, drop them
            while matrix.lines > i:
                matrix.row_del(i)

            break

        if matrix[i, i] == 0:
            # there is no pivot in current column
            # so try to find one in other columns
            for k in range(i+1, m):
                if matrix[i, k] != 0:
                    break
            else:
                if matrix[i, m] != 0:
                    return None   # no solutions
                else:
                    # zero row or was a linear combination of
                    # other rows so now we can safely skip it
                    matrix.row_del(i)
                    continue

            # we want to change the order of columns so
            # the order of variables must also change
            syms[i], syms[k] = syms[k], syms[i]
            matrix.col_swap(i, k)

        pivot = matrix[i, i]

        # divide all elements in the current row by the pivot
        matrix.row(i, lambda x, _: x / pivot)

        for k in range(i+1, matrix.lines):
            if matrix[k, i] != 0:
                coeff = matrix[k, i]

                # subtract from the current row the row containing
                # pivot and multiplied by extracted coefficient
                matrix.row(k, lambda x, j: x - matrix[i, j]*coeff)

        i += 1

    # if there weren't any problems, augmented matrix is now
    # in row-echelon form so we can check how many solutions
    # there are and extract them using back substitution

    simplified = flags.get('simplified', True)

    if len(syms) == matrix.lines:
        # this system is Cramer equivalent so there is
        # exactly one solution to this system of equations
        k, solutions = i-1, {}

        while k >= 0:
            content = matrix[k, m]

            # run back-substitution for variables
            for j in range(k+1, m):
                content -= matrix[k, j]*solutions[syms[j]]

            if simplified:
                solutions[syms[k]] = simplify(content)
            else:
                solutions[syms[k]] = content

            k -= 1

        return solutions
    elif len(syms) > matrix.lines:
        # this system will have infinite number of solutions
        # dependent on exactly len(syms) - i parameters
        k, solutions = i-1, {}

        while k >= 0:
            content = matrix[k, m]

            # run back-substitution for variables
            for j in range(k+1, i):
                content -= matrix[k, j]*solutions[syms[j]]

            # run back-substitution for parameters
            for j in range(i, m):
                content -= matrix[k, j]*syms[j]

            if simplified:
                solutions[syms[k]] = simplify(content)
            else:
                solutions[syms[k]] = content

            k -= 1

        return solutions
    else:
        return None   # no solutions
def solve_undetermined_coeffs(equ, coeffs, sym, **flags):
    """Solve equation of a type p(x; a_1, ..., a_k) == q(x) where both
       p, q are univariate polynomials and p depends on k parameters.
       The result of this function is a dictionary with symbolic
       values of those parameters with respect to coefficients in q.

       Both Equality instances and ordinary SymPy expressions (taken
       as equal to zero) are accepted. The parameters (coeffs) and the
       variable (sym) always have to be specified. The flags are
       forwarded to solve(). None is returned when, after collecting
       powers of the variable, some coefficient still contains it.

       >>> from sympy import *
       >>> a, b, c, x = symbols('a', 'b', 'c', 'x')

       >>> solve_undetermined_coeffs(Eq(2*a*x + a+b, x), [a, b], x)
       {a: 1/2, b: -1/2}

       >>> solve_undetermined_coeffs(Eq(a*c*x + a+b, x), [a, b], x)
       {a: 1/c, b: -1/c}

    """
    expr = equ

    if isinstance(expr, Equality):
        # bring everything to the left hand side
        expr = expr.lhs - expr.rhs

    # one condition per collected power of the variable
    conditions = collect(expr.expand(), sym, evaluate=False).values()

    for condition in conditions:
        if condition.has(sym):
            # the powers could not be separated completely
            return None   # no solutions

    # what is left is a system in the parameters only
    return solve(conditions, *coeffs, **flags)
def solve_linear_system_LU(matrix, syms):
    """Solve an augmented square system with LUsolve().

       The matrix must have exactly one column more than it has rows
       (the last column being the right-hand side); as the original
       note says, the LU approach works for invertible systems only.
       Returns a dictionary mapping syms[i] to the i-th solution.
    """
    size = matrix.lines
    last = matrix.cols - 1

    assert size == last

    coefficients = matrix[:size, :size]
    rhs = matrix[:, last:]

    values = coefficients.LUsolve(rhs)

    return dict((syms[i], values[i, 0]) for i in range(values.lines))
def dsolve(eq, funcs):
    """
    Solves any (supported) kind of differential equation.

    Usage
    =====
        dsolve(eq, y(x)) -> Solve a differential equation eq for the
        function y

    Details
    =======
        @param eq: ordinary differential equation (either just the left
            hand side, in which case the right hand side is assumed to
            be 0, or an Equality)

        @param funcs: the indeterminate function of one variable, e.g.
            y(x), or a list/tuple holding exactly that one function

        @raise NotImplementedError: if more (or fewer) than one function
            is given, if the order of the equation is greater than 2,
            or if the equation contains no derivative of the function

        - you can declare the derivative of an unknown function this way:
        >>> from sympy import *
        >>> x = Symbol('x') # x is the independent variable

        >>> f = Function("f")(x) # f is a function of x
        >>> f_ = Derivative(f, x) # f_ will be the derivative of f with respect to x

        - This function just parses the equation "eq" and determines the
          type of differential equation by its order, then it calls the
          particular solver (solve_ODE_first_order or
          solve_ODE_second_order).
        - see test_ode.py for many tests, that serve also as a set of
          examples how to use dsolve

    Examples
    ========
        >>> from sympy import *
        >>> x = Symbol('x')

        >>> f = Function('f')
        >>> dsolve(Derivative(f(x),x,x)+9*f(x), f(x))
        C1*sin(3*x) + C2*cos(3*x)
        >>> dsolve(Eq(Derivative(f(x),x,x)+9*f(x)+1, 1), f(x))
        C1*sin(3*x) + C2*cos(3*x)

    """
    if isinstance(eq, Equality):
        if eq.rhs != 0:
            # move everything to the left hand side and start over
            return dsolve(eq.lhs-eq.rhs, funcs)
        eq = eq.lhs

    # currently only solve for one function; anything else used to fall
    # through without producing a result, so reject it explicitly
    if not (isinstance(funcs, Basic) or len(funcs) == 1):
        raise NotImplementedError(
            "dsolve: exactly one function to solve for is supported")

    if isinstance(funcs, (list, tuple)): # normalize args
        f = funcs[0]
    else:
        f = funcs

    x = f.args[0]
    f = f.func

    # We first get the order of the equation, so that we can choose the
    # corresponding methods. Currently, only first and second
    # order odes can be handled.
    order = deriv_degree(eq, f(x))

    if order > 2:
        raise NotImplementedError("dsolve: Cannot solve " + str(eq))
    elif order == 2:
        return solve_ODE_second_order(eq, f(x))
    elif order == 1:
        return solve_ODE_first_order(eq, f(x))
    else:
        raise NotImplementedError("Not a differential equation!")
def deriv_degree(expr, func):
    """ Return the order of the given ode, i.e. the largest number of
        differentiation symbols found on a Derivative inside expr.
        The search is recursive. """
    free_of_func = Wild('a', exclude=[func])

    if isinstance(expr, Derivative):
        return len(expr.symbols)

    highest = 0

    for term in expr.args:
        if isinstance(term, Derivative):
            highest = max(highest, len(term.symbols))
        elif expr.match(free_of_func):
            # NOTE(review): the match is done on the whole expression,
            # not on the current argument, and it resets the running
            # maximum -- confirm that this is intended
            highest = 0
        else:
            for part in term.args:
                highest = max(highest, deriv_degree(part, func))

    return highest
def solve_ODE_first_order(eq, f):
    """
    Solves a first order ode for the applied function f (e.g. y(x)).

    Only the linear case a(x)*f'(x) + b(x)*f(x) + c(x) = 0 is
    implemented; the result contains the integration constant C1.
    NotImplementedError is raised for every other equation.
    """
    from sympy.integrals.integrals import integrate

    x = f.args[0]
    fx = f.func(x)

    # linear case: a(x)*f'(x)+b(x)*f(x)+c(x) = 0
    a, b, c = [Wild(name, exclude=[fx]) for name in 'abc']

    found = eq.match(a*diff(fx, x) + b*fx + c)

    if not found:
        # other cases of first order odes will be implemented here
        raise NotImplementedError("dsolve: Cannot solve " + str(eq))

    factor = C.exp(integrate(found[b]/found[a], x))
    particular = integrate(factor*(-found[c]/found[a]), x)

    return (particular + Symbol("C1"))/factor
def solve_ODE_second_order(eq, f):
    """
    Solves a second order ode for the applied function f (e.g. y(x)).

    The forms are tried in this order:

        - a*f''(x) + c*f(x) with a, c free of x
        - a*f''(x) + b*f'(x) + c*f(x) with a, b, c free of x, using
          the roots of a*x**2 + b*x + c
        - a few special equations related to (x*exp(+-f(x)))'' that
          are answered through solve_ODE_1()

    The results contain the constants C1 and C2. NotImplementedError
    is raised when none of the forms matches.
    """
    x = f.args[0]
    func = f.func
    fx = func(x)

    C1 = Symbol("C1")
    C2 = Symbol("C2")

    # constant coefficients case: af''(x)+bf'(x)+cf(x)=0
    a = Wild('a', exclude=[x])
    b = Wild('b', exclude=[x])
    c = Wild('c', exclude=[x])

    found = eq.match(a*fx.diff(x, x) + c*fx)

    if found:
        omega = sqrt(found[c]/found[a])
        return C1*C.sin(omega*x) + C2*C.cos(omega*x)

    found = eq.match(a*fx.diff(x, x) + b*diff(fx, x) + c*fx)

    if found:
        char_roots = solve(found[a]*x**2 + found[b]*x + found[c], x)
        first = char_roots[0]

        if not first.is_real:
            second = char_roots[1]
            omega = abs((first - second)/(2*S.ImaginaryUnit))
            envelope = exp((first + second)*x/2)
            return (C2*C.cos(omega*x) + C1*C.sin(omega*x))*envelope

        if len(char_roots) == 1:
            return (C1 + C2*x)*exp(first*x)

        return C1*exp(first*x) + C2*exp(char_roots[1]*x)

    # other cases of the second order odes will be implemented here

    # special equations, that we know how to solve
    plus = x*C.exp(fx)
    plus_pattern = (a*plus.diff(x, x)/plus).expand()

    if eq.match(plus_pattern):
        return -solve_ODE_1(fx, x)

    minus = x*C.exp(-fx)
    minus_pattern = (a*minus.diff(x, x)/minus).expand()

    if eq.match(minus_pattern):
        return solve_ODE_1(fx, x)

    # the same pattern, tried on the rescaled equation
    rescaled = eq*C.exp(fx)/C.exp(-fx)

    if rescaled.match(minus_pattern):
        return solve_ODE_1(fx, x)

    raise NotImplementedError("cannot solve this")
def solve_ODE_1(f, x):
    """ (x*exp(-f(x)))'' = 0

        Returns -log(C1 + C2/x); the argument f is not used.
    """
    first, second = Symbol("C1"), Symbol("C2")
    inner = first + second/x
    return -C.log(inner)
# Module-private dummy symbol: tsolve() substitutes it for the user's symbol
# before matching an equation against the patterns.
x = Symbol('x', dummy=True)
# Wild symbols a..h used to write the patterns; each is created with
# exclude=[x].
a,b,c,d,e,f,g,h = [Wild(t, exclude=[x]) for t in 'abcdefgh']
# List of (pattern, solution) pairs; stays None until _generate_patterns()
# fills it in on the first call of tsolve().
patterns = None
def _generate_patterns():
    """Build the table of patterns for transcendental equations.

    Every entry is a pair (pattern, solution) written in terms of the
    module-level dummy symbol x and the Wild symbols a..h. The table is
    stored in the global variable patterns; tsolve() calls this
    function lazily, the first time it needs the table.
    """
    global patterns

    # sub-expressions shared by the last (most general) solution
    power = f ** (h-(c*g/b))
    root = (-e*power/a)**(1/d)

    table = [
        (a*(b*x+c)**d + e,
            ((-(e/a))**(1/d)-c)/b),
        (b+c*exp(d*x+e),
            (log(-b/c)-e)/d),
        (a*x+b+c*exp(d*x+e),
            -b/a-LambertW(c*d*exp(e-b*d/a)/a)/d),
        (b+c*f**(d*x+e),
            (log(-b/c)-e*log(f))/d/log(f)),
        (a*x+b+c*f**(d*x+e),
            -b/a-LambertW(c*d*f**(e-b*d/a)*log(f)/a)/d/log(f)),
        (b+c*log(d*x+e),
            (exp(-b/c)-e)/d),
        (a*x+b+c*log(d*x+e),
            -e/d+c/a*LambertW(a/c/d*exp(-b/c+a*e/c/d))),
        (a*(b*x+c)**d + e*f**(g*x+h),
            -c/b-d*LambertW(-root*g*log(f)/b/d)/g/log(f)),
    ]

    patterns = table
def tsolve(eq, sym):
    """
    Solves a transcendental equation with respect to the given
    symbol. Various equations containing mixed linear terms, powers,
    and logarithms, can be solved.

    The equation may be an expression (taken as equal to zero) or an
    Equality. Only a single solution is returned. This solution is
    generally not unique. In some cases, a complex solution may be
    returned even though a real solution exists. ValueError is raised
    when the equation matches none of the known patterns.

        >>> from sympy import *
        >>> x = Symbol('x')

        >>> tsolve(3**(2*x+5)-4, x)
        (-5*log(3) + log(4))/(2*log(3))

        >>> tsolve(log(x) + 2*x, x)
        1/2*LambertW(2)

    """
    if patterns is None:
        _generate_patterns()

    expr = sympify(eq)

    if isinstance(expr, Equality):
        expr = expr.lhs - expr.rhs

    sym = sympify(sym)

    # the patterns are written in terms of the module's dummy symbol
    target = expr.subs(sym, x)

    # First see if the equation has a linear factor
    # In that case, the other factor can contain x in any way (as long as it
    # is finite), and we have a direct solution
    cofactor = Wild('r')
    found = target.match((a*x+b)*cofactor)

    if found and found[a]:
        return (-b/a).subs(found).subs(x, sym)

    for pattern, solution in patterns:
        found = target.match(pattern)

        if found:
            return solution.subs(found).subs(x, sym)

    raise ValueError("unable to solve the equation")
def msolve(args, f, x0, tol=None, maxsteps=None, verbose=False, norm=None,
           modules=['mpmath', 'sympy']):
    """
    Solves a nonlinear equation system numerically.

    f is a vector function of symbolic expressions representing the system.
    args are the variables.
    x0 is a starting vector close to a solution.

    Be careful with x0, not using floats might give unexpected results.

    tol, maxsteps and norm are handed to the Newton solver only when they
    are set (true values); verbose is always handed over and additionally
    prints the system and its Jacobian.

    Use modules to specify which modules should be used to evaluate the
    function and the Jacobian matrix. Make sure to use a module that supports
    matrices. For more information on the syntax, please see the docstring
    of lambdify.

    Currently only fully determined systems are supported; otherwise
    NotImplementedError is raised.

    >>> from sympy import Symbol, Matrix
    >>> x1 = Symbol('x1')
    >>> x2 = Symbol('x2')
    >>> f1 = 3 * x1**2 - 2 * x2**2 - 1
    >>> f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8
    >>> msolve((x1, x2), (f1, f2), (-1., 1.))
    [-1.19287309935246]
    [ 1.27844411169911]

    """
    system = f

    if isinstance(system, (list, tuple)):
        system = Matrix(system).T

    if len(args) != system.cols:
        raise NotImplementedError('need exactly as many variables as equations')

    if verbose:
        print('f(x):')
        print(system)

    # derive Jacobian
    jacobian = system.jacobian(args)

    if verbose:
        print('J(x):')
        print(jacobian)

    # create functions
    func = lambdify(args, system.T, modules)
    jac = lambdify(args, jacobian, modules)

    # solve system using Newton's method; optional settings are only
    # forwarded when they were actually given
    kwargs = {'verbose': verbose}

    for key, value in (('tol', tol), ('maxsteps', maxsteps), ('norm', norm)):
        if value:
            kwargs[key] = value

    return newton(func, x0, jac, **kwargs)
603 return x