2 """ This module contains solvers for all kinds of equations:
4 - algebraic, use solve()
6 - recurrence, use rsolve()
8 - differential, use dsolve()
10 - transcendental, use tsolve()
12 - nonlinear (numerically), use msolve() (you will need a good starting point)
16 from sympy
.core
.sympify
import sympify
17 from sympy
.core
.basic
import Basic
, S
, C
18 from sympy
.core
.symbol
import Symbol
, Wild
19 from sympy
.core
.relational
import Equality
20 from sympy
.core
.function
import Derivative
, diff
22 from sympy
.functions
import sqrt
, log
, exp
, LambertW
23 from sympy
.simplify
import simplify
, collect
24 from sympy
.matrices
import Matrix
, zeronm
25 from sympy
.polys
import roots
27 from sympy
.utilities
import any
, all
28 from sympy
.utilities
.lambdify
import lambdify
29 from sympy
.solvers
.numeric
import newton
31 from sympy
.solvers
.polysys
import solve_poly_system
# solve(f, *symbols, **flags): front-end dispatcher for algebraic equations
# and systems of equations.
# NOTE(review): this extract is missing many original lines of the body (the
# end of the docstring, several if/else headers and some assignments); the
# comments below describe only the statements that are visible.
33 def solve(f
, *symbols
, **flags
):
34 """Solves equations and systems of equations.
36 Currently supported are univariate polynomial and transcendental
37 equations and systems of linear and polynomial equations. Input
38 is formed as a single expression or an equation, or an iterable
39 container in case of an equation system. The type of output may
40 vary and depends heavily on the input. For more details refer to
41 more problem-specific functions.
43 By default all solutions are simplified to make the output more
44 readable. If this is not the expected behavior, e.g. because of
45 speed issues, set simplified=False in function arguments.
47 To solve equations and systems of equations of other kinds, e.g.
48 recurrence relations or differential equations, use rsolve() or
49 dsolve() functions respectively.
51 >>> from sympy import *
52 >>> x,y = symbols('xy')
54 Solve a polynomial equation:
59 Solve a linear system:
61 >>> solve((x+5*y-2, -3*x+6*y-15), x, y)
# At least one symbol is required (the guarding condition is not visible).
66 raise ValueError('no symbols were given')
# A single list/tuple/set argument is presumably unpacked into the symbol
# list; the assignment itself is not visible.
69 if isinstance(symbols
[0], (list, tuple, set)):
# NOTE(review): relies on Python 2 map() returning a list -- `symbols` is
# iterated again below and passed to len().
72 symbols
= map(sympify
, symbols
)
# Every sympified symbol must be a Symbol.
74 if any(not s
.is_Symbol
for s
in symbols
):
75 raise TypeError('not a Symbol')
# Single-equation branch: f is not a tuple/list/set.
77 if not isinstance(f
, (tuple, list, set)):
# An Equality is presumably rewritten as lhs - rhs (rewrite not visible).
80 if isinstance(f
, Equality
):
# Try to view f as a polynomial in the given symbols. When that works the
# roots (cubic and quartic formulas enabled) are used, otherwise tsolve();
# the branch conditions between these statements are not visible.
84 poly
= f
.as_poly(*symbols
)
87 result
= roots(poly
, cubics
=True, quartics
=True).keys()
89 result
= [tsolve(f
, *symbols
)]
# Presumably raised when a single equation is given several symbols
# (condition not visible).
91 raise NotImplementedError('multivariate equation')
# With simplified=True (the default) every solution goes through simplify().
93 if flags
.get('simplified', True):
94 return map(simplify
, result
)
# System branch: each equation g (loop header not visible) is converted to a
# polynomial; NotImplementedError is presumably raised when that fails.
106 if isinstance(g
, Equality
):
109 poly
= g
.as_poly(*symbols
)
114 raise NotImplementedError
# All-linear system: build the n x (m+1) augmented matrix (last column is the
# right-hand side) and delegate to solve_linear_system().
116 if all(p
.is_linear
for p
in polys
):
117 n
, m
= len(f
), len(symbols
)
118 matrix
= zeronm(n
, m
+ 1)
120 for i
, poly
in enumerate(polys
):
121 for coeff
, monom
in poly
.iter_terms():
# j: index of the symbol with exponent 1 in this monomial; storing coeff at
# (i, j) and the constant-term test happen on lines not visible here.
123 j
= list(monom
).index(1)
# Constant term moves to the right-hand-side column with its sign flipped.
126 matrix
[i
, m
] = -coeff
128 return solve_linear_system(matrix
, *symbols
, **flags
)
# Non-linear polynomial systems go to the polynomial-system solver.
130 return solve_poly_system(polys
)
# solve_linear_system(system, *symbols, **flags): Gaussian elimination on an
# augmented N x (M+1) matrix followed by back-substitution.
# NOTE(review): several original lines are missing from this extract,
# including the bindings of `matrix` and `syms` (presumably a copy of `system`
# and a list of `symbols`) and a number of loop/branch headers.
132 def solve_linear_system(system
, *symbols
, **flags
):
133 """Solve system of N linear equations with M variables, which means
134 both Cramer and overdetermined systems are supported. The possible
135 number of solutions is zero, one or infinite. Respectively this
136 procedure will return None or a dictionary with solutions. In the
137 case of an underdetermined system all arbitrary parameters are skipped.
138 This may cause a situation in which an empty dictionary is returned.
139 In this case it means all symbols can be assigned arbitrary values.
141 Input to this function is a Nx(M+1) matrix, which means it has
142 to be in augmented form. If you are unhappy with such a setting
143 use 'solve' method instead, where you can input equations
144 explicitly. And don't worry about the matrix, this function
145 is persistent and will make a local copy of it.
147 The algorithm used here is Gaussian elimination (each pivot row is
148 divided by its pivot), which results in an upper-triangular matrix.
149 Then solutions are found using back-substitution. This approach
150 is more efficient and compact than the Gauss-Jordan method.
152 >>> from sympy import *
153 >>> x, y = symbols('xy')
155 Solve the following system:
160 >>> system = Matrix(( (1, 4, 2), (-2, 1, 14)))
161 >>> solve_linear_system(system, x, y)
# i: current pivot row; m: number of unknown columns (column m is the RHS).
168 i
, m
= 0, matrix
.cols
-1 # don't count augmentation
# Forward elimination: one pivot row per iteration while rows remain.
170 while i
< matrix
.lines
:
171 if matrix
[i
, i
] == 0:
172 # there is no pivot in current column
173 # so try to find one in other columns
174 for k
in range(i
+1, m
):
175 if matrix
[i
, k
] != 0:
# Presumably reached when no non-zero coefficient exists in row i (the loop
# exit structure is not visible): a non-zero right-hand side means 0 == c.
178 if matrix
[i
, m
] != 0:
179 return None # no solutions
181 # zero row or was a linear combination of
182 # other rows so now we can safely skip it
186 # we want to change the order of columns so
187 # the order of variables must also change
188 syms
[i
], syms
[k
] = syms
[k
], syms
[i
]
189 matrix
.col_swap(i
, k
)
191 pivot
= matrix
[i
, i
]
193 # divide all elements in the current row by the pivot
194 matrix
.row(i
, lambda x
, _
: x
/ pivot
)
# Clear column i below the pivot. NOTE(review): `coeff` is bound on a line
# not visible here, presumably to matrix[k, i].
196 for k
in range(i
+1, matrix
.lines
):
197 if matrix
[k
, i
] != 0:
200 # subtract from the current row the row containing
201 # pivot and multiplied by extracted coefficient
202 matrix
.row(k
, lambda x
, j
: x
- matrix
[i
, j
]*coeff
)
206 # if there weren't any problems, augmented matrix is now
207 # in row-echelon form so we can check how many solutions
208 # there are and extract them using back substitution
# Presumably selects between the simplified and raw assignments below.
210 simplified
= flags
.get('simplified', True)
212 if len(syms
) == matrix
.lines
:
213 # this system is Cramer equivalent so there is
214 # exactly one solution to this system of equations
215 k
, solutions
= i
-1, {}
# Back-substitution from the last pivot row upwards (loop header not
# visible); content starts as the right-hand side of row k.
218 content
= matrix
[k
, m
]
220 # run back-substitution for variables
221 for j
in range(k
+1, m
):
222 content
-= matrix
[k
, j
]*solutions
[syms
[j
]]
# Presumably simplified when `simplified` is true, stored as-is otherwise.
225 solutions
[syms
[k
]] = simplify(content
)
227 solutions
[syms
[k
]] = content
232 elif len(syms
) > matrix
.lines
:
233 # this system will have infinite number of solutions
234 # dependent on exactly len(syms) - i parameters
235 k
, solutions
= i
-1, {}
238 content
= matrix
[k
, m
]
240 # run back-substitution for variables
241 for j
in range(k
+1, i
):
242 content
-= matrix
[k
, j
]*solutions
[syms
[j
]]
244 # run back-substitution for parameters
245 for j
in range(i
, m
):
246 content
-= matrix
[k
, j
]*syms
[j
]
249 solutions
[syms
[k
]] = simplify(content
)
251 solutions
[syms
[k
]] = content
# Presumably the remaining case (len(syms) < matrix.lines); the lines between
# the previous branch and this return are not visible.
257 return None # no solutions
def solve_undetermined_coeffs(equ, coeffs, sym, **flags):
    """Solve an equation of the type p(x; a_1, ..., a_k) == q(x), where
       both p and q are univariate polynomials in x and p depends on the
       k parameters a_1, ..., a_k. The result of this function is a
       dictionary with the symbolic values of those parameters in terms
       of the coefficients of q.

       The parameters are found by matching coefficients: after moving
       everything to one side and expanding, the coefficient of every
       power of ``sym`` must vanish, and the resulting system of
       coefficient equations is handed to solve().

       This function accepts both Equality instances and ordinary SymPy
       expressions (a plain expression is taken to be equal to zero).
       Specifying the parameters and the variable is obligatory for
       efficiency and simplicity reasons.

       Returns None when the expanded equation cannot be split into
       coefficients that are free of ``sym``.

       >>> from sympy import *
       >>> a, b, c, x = symbols('a', 'b', 'c', 'x')

       >>> solve_undetermined_coeffs(Eq(2*a*x + a+b, x), [a, b], x)
       {a: 1/2, b: -1/2}

       >>> solve_undetermined_coeffs(Eq(a*c*x + a+b, x), [a, b], x)
       {a: 1/c, b: -1/c}

    """
    if isinstance(equ, Equality):
        # got an equation, so move all the
        # terms to the left hand side
        equ = equ.lhs - equ.rhs

    # Group the expanded expression by powers of ``sym``; with
    # evaluate=False collect() returns a dict of collected terms, whose
    # values are the coefficients that must each vanish. list() makes
    # sure solve() receives a real list, since it only treats
    # tuple/list/set inputs as a system of equations.
    system = list(collect(equ.expand(), sym, evaluate=False).values())

    # Generator with its own loop name: the former list comprehension
    # reused ``equ`` as its loop variable, which rebinds it under
    # Python 2 scoping rules.
    if not any(coeff_expr.has(sym) for coeff_expr in system):
        # consecutive powers in the input expression have been
        # successfully collected, so solve the remaining system of
        # coefficient equations (Gaussian elimination when linear)
        return solve(system, *coeffs, **flags)

    return None # coefficients still depend on sym: cannot solve this way
# solve_linear_system_LU(matrix, syms): solve a square linear system given as
# an augmented n x (n+1) matrix.
# NOTE(review): the lines binding `soln` (presumably an LU solve of A*X = b)
# and `solutions` (presumably {}), and the final return, are not visible in
# this extract.
294 def solve_linear_system_LU(matrix
, syms
):
295 """Solve an augmented square system [A | b] with an LU-based solver; works for invertible A only. """
# The augmented matrix must have exactly one more column than rows.
# NOTE(review): assert is stripped under `python -O`; an explicit exception
# would make this check reliable.
296 assert matrix
.lines
== matrix
.cols
-1
# A: leading square coefficient block; b: last (augmentation) column.
297 A
= matrix
[:matrix
.lines
,:matrix
.lines
]
298 b
= matrix
[:,matrix
.cols
-1:]
# Map each symbol to the corresponding entry of the solution column.
301 for i
in range(soln
.lines
):
302 solutions
[syms
[i
]] = soln
[i
,0]
# dsolve(eq, funcs): entry point for ordinary differential equations; it
# dispatches on the order of the equation.
# NOTE(review): lines are missing from this extract, including the opening of
# the docstring and the code that binds `f` and `x` from `funcs`.
305 def dsolve(eq
, funcs
):
307 Solves any (supported) kind of differential equation.
311 dsolve(f, y(x)) -> Solve a differential equation f for the function y
316 @param eq: ordinary differential equation (either just the left hand
317 side, or an Equality instance)
319 @param funcs: indeterminate function of one variable, e.g. f(x), or a one-element list/tuple holding it
321 - you can declare the derivative of an unknown function this way:
322 >>> from sympy import *
323 >>> x = Symbol('x') # x is the independent variable
325 >>> f = Function("f")(x) # f is a function of x
326 >>> f_ = Derivative(f, x) # f_ will be the derivative of f with respect to x
328 - This function just parses the equation "eq" and determines the type of
329 differential equation by its order, then it determines all the coefficients and then
330 calls the particular solver, which just accepts the coefficients.
331 - "eq" can be either an Equality, or just the left hand side (in which
332 case the right hand side is assumed to be 0)
333 - see test_ode.py for many tests, that serve also as a set of examples
338 >>> from sympy import *
341 >>> f = Function('f')
342 >>> dsolve(Derivative(f(x),x,x)+9*f(x), f(x))
343 C1*sin(3*x) + C2*cos(3*x)
344 >>> dsolve(Eq(Derivative(f(x),x,x)+9*f(x)+1, 1), f(x))
345 C1*sin(3*x) + C2*cos(3*x)
# An Equality lhs == rhs is rewritten as lhs - rhs (== 0) and solved
# recursively.
349 if isinstance(eq
, Equality
):
351 return dsolve(eq
.lhs
-eq
.rhs
, funcs
)
354 #currently only solve for one function
355 if isinstance(funcs
, Basic
) or len(funcs
) == 1:
356 if isinstance(funcs
, (list, tuple)): # normalize args
# NOTE(review): the normalization body and the bindings of f and x are on
# lines not visible here.
364 #We first get the order of the equation, so that we can choose the
365 #corresponding methods. Currently, only first and second
366 #order odes can be handled.
367 order
= deriv_degree(eq
, f(x
))
# The conditions selecting among the next three statements are not visible;
# presumably unsupported orders raise, while order 2 and order 1 go to the
# specialised solvers.
370 raise NotImplementedError("dsolve: Cannot solve " + str(eq
))
372 return solve_ODE_second_order(eq
, f(x
))
374 return solve_ODE_first_order(eq
, f(x
))
# The branch leading here is not visible; the message suggests it handles
# input with no derivative of the unknown function.
376 raise NotImplementedError("Not a differential equation!")
# deriv_degree(expr, func): compute the highest derivative order occurring in
# expr (the ODE's order).
# NOTE(review): lines are missing from this extract, including the end of the
# docstring, some branch headers and the return statement. The Wild `a`
# created below is not used by any visible line.
378 def deriv_degree(expr
, func
):
379 """ get the order of a given ode, the function is implemented
381 a
= Wild('a', exclude
=[func
])
# A bare Derivative: its order is the number of differentiation symbols.
384 if isinstance(expr
, Derivative
):
385 order
= len(expr
.symbols
)
# Otherwise scan the arguments: each Derivative argument contributes its own
# number of differentiation symbols.
387 for arg
in expr
.args
:
388 if isinstance(arg
, Derivative
):
389 order
= max(order
, len(arg
.symbols
))
# Recurse into the sub-arguments and keep the maximum order found (the
# surrounding branch structure is not visible).
393 for arg1
in arg
.args
:
394 order
= max(order
, deriv_degree(arg1
, func
))
# solve_ODE_first_order(eq, f): solve a first-order ODE given as an expression
# equal to zero. Only the linear form a*f' + b*f + c = 0 is handled by the
# visible code.
# NOTE(review): `x` (and possibly a rebinding of `f`) are on lines not visible
# here; the code below calls f(x), so `f` is used as a function class there.
398 def solve_ODE_first_order(eq
, f
):
400 solves many kinds of first order odes, different methods are used
401 depending on the form of the given equation. Now the linear
404 from sympy
.integrals
.integrals
import integrate
408 #linear case: a(x)*f'(x)+b(x)*f(x)+c(x) = 0
409 a
= Wild('a', exclude
=[f(x
)])
410 b
= Wild('b', exclude
=[f(x
)])
411 c
= Wild('c', exclude
=[f(x
)])
413 r
= eq
.match(a
*diff(f(x
),x
) + b
*f(x
) + c
)
# Integrating factor t = exp(integral(b/a)); the general solution is
# f = (integral(t*(-c/a)) + C1)/t. The check that the match succeeded is on
# a line not visible here.
415 t
= C
.exp(integrate(r
[b
]/r
[a
], x
))
416 tt
= integrate(t
*(-r
[c
]/r
[a
]), x
)
417 return (tt
+ Symbol("C1"))/t
419 #other cases of first order odes will be implemented here
421 raise NotImplementedError("dsolve: Cannot solve " + str(eq
))
# solve_ODE_second_order(eq, f): solve constant-coefficient and a few special
# second-order ODEs given as an expression equal to zero.
# NOTE(review): `x`, `t` and the conditions guarding each return are on lines
# not visible in this extract; the code below calls f(x), so `f` is used as a
# function class there.
423 def solve_ODE_second_order(eq
, f
):
425 solves many kinds of second order odes, different methods are used
426 depending on the form of the given equation. Now the constant
427 coefficients case and a special case are implemented.
432 #constant coefficients case: af''(x)+bf'(x)+cf(x)=0
433 a
= Wild('a', exclude
=[x
])
434 b
= Wild('b', exclude
=[x
])
435 c
= Wild('c', exclude
=[x
])
# No first-derivative term: a*f'' + c*f = 0 gives
# f = C1*sin(w*x) + C2*cos(w*x) with w = sqrt(c/a).
437 r
= eq
.match(a
*f(x
).diff(x
,x
) + c
*f(x
))
439 return Symbol("C1")*C
.sin(sqrt(r
[c
]/r
[a
])*x
)+Symbol("C2")*C
.cos(sqrt(r
[c
]/r
[a
])*x
)
# General constant-coefficient form: the solution is built from the roots r1
# of the characteristic polynomial a*x**2 + b*x + c.
441 r
= eq
.match(a
*f(x
).diff(x
,x
) + b
*diff(f(x
),x
) + c
*f(x
))
443 r1
= solve(r
[a
]*x
**2 + r
[b
]*x
+ r
[c
], x
)
# Presumably a repeated root: (C1 + C2*x)*exp(r*x).
446 return (Symbol("C1") + Symbol("C2")*x
)*exp(r1
[0]*x
)
# Presumably two distinct real roots: C1*exp(r1[0]*x) + C2*exp(r1[1]*x).
448 return Symbol("C1")*exp(r1
[0]*x
) + Symbol("C2")*exp(r1
[1]*x
)
# Presumably a complex-conjugate pair: r2 = |imaginary part|, and the real
# part is (r1[0] + r1[1])/2.
450 r2
= abs((r1
[0] - r1
[1])/(2*S
.ImaginaryUnit
))
451 return (Symbol("C2")*C
.cos(r2
*x
) + Symbol("C1")*C
.sin(r2
*x
))*exp((r1
[0] + r1
[1])*x
/2)
453 #other cases of the second order odes will be implemented here
455 #special equations, that we know how to solve
# NOTE(review): `t` is bound on a line not visible here; each case below
# matches the equation against the expansion of a*t''/t.
457 tt
= a
*t
.diff(x
, x
)/t
458 r
= eq
.match(tt
.expand())
460 return -solve_ODE_1(f(x
), x
)
463 tt
= a
*t
.diff(x
, x
)/t
464 r
= eq
.match(tt
.expand())
466 #check, that we've rewritten the equation correctly:
467 #assert ( r[a]*t.diff(x,2)/t ) == eq.subs(f, t)
468 return solve_ODE_1(f(x
), x
)
# neq = eq*exp(f(x))/exp(-f(x)), i.e. eq multiplied by exp(2*f(x)), before
# matching the same pattern.
470 neq
= eq
*C
.exp(f(x
))/C
.exp(-f(x
))
471 r
= neq
.match(tt
.expand())
473 #check, that we've rewritten the equation correctly:
474 #assert ( t.diff(x,2)*r[a]/t ).expand() == eq
475 return solve_ODE_1(f(x
), x
)
477 raise NotImplementedError("cannot solve this")
# solve_ODE_1(f, x): closed-form solution of (x*exp(-f(x)))'' = 0.
479 def solve_ODE_1(f
, x
):
480 """Solve (x*exp(-f(x)))'' = 0; the general solution returned is -log(C1 + C2/x). """
# NOTE(review): C1 and C2 are bound on lines not visible here (presumably the
# symbols "C1"/"C2", as in the other solvers); parameter f is unused by the
# visible code.
483 return -C
.log(C1
+C2
/x
)
# Module-level pattern variables used by _generate_patterns() and tsolve():
# x is a dummy stand-in for the unknown, and a..h are Wild placeholders that
# may not contain x, so a match never absorbs the unknown into a coefficient.
# NOTE: under Python 2 the comprehension variable t also leaks into module
# scope.
485 x
= Symbol('x', dummy
=True)
486 a
,b
,c
,d
,e
,f
,g
,h
= [Wild(t
, exclude
=[x
]) for t
in 'abcdefgh']
# _generate_patterns(): build the table of (pattern, solution) pairs that
# tsolve() tries in order.
# NOTE(review): the lines that open/close the list and store it (presumably
# in a module-level `patterns`, per the docstring) are not visible in this
# extract.
489 def _generate_patterns():
490 """Generates patterns for transcendental equations.
492 This is lazily calculated (called) in the tsolve() function and stored in
493 the patterns global variable.
# tmp1/tmp2 are sub-expressions of the solution for the last pattern.
496 tmp1
= f
** (h
-(c
*g
/b
))
497 tmp2
= (-e
*tmp1
/a
)**(1/d
)
# Each entry pairs a pattern in the dummy x and the Wilds a..h (the equation
# is presumably pattern == 0) with the value of x returned when it matches.
# a*(b*x + c)**d + e  ->  x = ((-e/a)**(1/d) - c)/b
500 (a
*(b
*x
+c
)**d
+ e
, ((-(e
/a
))**(1/d
)-c
)/b
),
# b + c*exp(d*x + e)  ->  x = (log(-b/c) - e)/d
501 ( b
+c
*exp(d
*x
+e
) , (log(-b
/c
)-e
)/d
),
# a*x + b + c*exp(d*x + e)  ->  Lambert W solution
502 (a
*x
+b
+c
*exp(d
*x
+e
) , -b
/a
-LambertW(c
*d
*exp(e
-b
*d
/a
)/a
)/d
),
# b + c*f**(d*x + e), with f a Wild base
503 ( b
+c
*f
**(d
*x
+e
) , (log(-b
/c
)-e
*log(f
))/d
/log(f
)),
# a*x + b + c*f**(d*x + e)  ->  Lambert W solution
504 (a
*x
+b
+c
*f
**(d
*x
+e
) , -b
/a
-LambertW(c
*d
*f
**(e
-b
*d
/a
)*log(f
)/a
)/d
/log(f
)),
# b + c*log(d*x + e)  ->  x = (exp(-b/c) - e)/d
505 ( b
+c
*log(d
*x
+e
) , (exp(-b
/c
)-e
)/d
),
# a*x + b + c*log(d*x + e)  ->  Lambert W solution
506 (a
*x
+b
+c
*log(d
*x
+e
) , -e
/d
+c
/a
*LambertW(a
/c
/d
*exp(-b
/c
+a
*e
/c
/d
))),
# a*(b*x + c)**d + e*f**(g*x + h)  ->  Lambert W solution using tmp2
507 (a
*(b
*x
+c
)**d
+ e
*f
**(g
*x
+h
) , -c
/b
-d
*LambertW(-tmp2
*g
*log(f
)/b
/d
)/g
/log(f
))
512 Solves a transcendental equation with respect to the given
513 symbol. Various equations containing mixed linear terms, powers,
514 and logarithms, can be solved.
516 Only a single solution is returned. This solution is generally
517 not unique. In some cases, a complex solution may be returned
518 even though a real solution exists.
520 >>> from sympy import *
523 >>> tsolve(3**(2*x+5)-4, x)
524 (-5*log(3) + log(4))/(2*log(3))
526 >>> tsolve(log(x) + 2*x, x)
# tsolve body. NOTE(review): the def header, the end of the docstring, the
# Equality rewrite (presumably eq = eq.lhs - eq.rhs) and the lazy pattern
# generation are on lines not visible in this extract; the names eq and sym
# are inferred from their use below.
533 if isinstance(eq
, Equality
):
# Work in the module-level dummy x so the shared patterns (whose Wilds
# exclude x) apply.
536 eq2
= eq
.subs(sym
, x
)
537 # First see if the equation has a linear factor
538 # In that case, the other factor can contain x in any way (as long as it
539 # is finite), and we have a direct solution
# NOTE(review): r is bound on a line not visible here (presumably a Wild).
541 m
= eq2
.match((a
*x
+b
)*r
)
# The success check on m is not visible; on a match the root -b/a of the
# linear factor is returned, mapped back from x to sym.
543 return (-b
/a
).subs(m
).subs(x
, sym
)
# Otherwise try each (pattern, solution) pair; the match against p and its
# success check (presumably binding m) are on lines not visible here.
544 for p
, sol
in patterns
:
547 return sol
.subs(m
).subs(x
, sym
)
# No linear factor and no pattern matched.
548 raise ValueError("unable to solve the equation")
551 def msolve(args
, f
, x0
, tol
=None, maxsteps
=None, verbose
=False, norm
=None,
552 modules
=['mpmath', 'sympy']):
554 Solves a nonlinear equation system numerically.
556 f is a vector function of symbolic expressions representing the system.
557 args are the variables.
558 x0 is a starting vector close to a solution.
560 Be careful with x0, not using floats might give unexpected results.
562 Use modules to specify which modules should be used to evaluate the
563 function and the Jacobian matrix. Make sure to use a module that supports
564 matrices. For more information on the syntax, please see the docstring
567 Currently only fully determined systems are supported.
569 >>> from sympy import Symbol, Matrix
570 >>> x1 = Symbol('x1')
571 >>> x2 = Symbol('x2')
572 >>> f1 = 3 * x1**2 - 2 * x2**2 - 1
573 >>> f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8
574 >>> msolve((x1, x2), (f1, f2), (-1., 1.))
578 if isinstance(f
, (list, tuple)):
580 if len(args
) != f
.cols
:
581 raise NotImplementedError('need exactly as many variables as equations')
591 f
= lambdify(args
, f
.T
, modules
)
592 J
= lambdify(args
, J
, modules
)
593 # solve system using Newton's method
598 kwargs
['maxsteps'] = maxsteps
599 kwargs
['verbose'] = verbose
601 kwargs
['norm'] = norm
602 x
= newton(f
, x0
, J
, **kwargs
)