2 """ This module contain solvers for all kinds of equations:
4 - algebraic, use solve()
6 - recurrence, use rsolve()
8 - differential, use dsolve()
10 -transcendental, use tsolve()
12 -nonlinear (numerically), use msolve() (you will need a good starting point)
16 from sympy
.core
.sympify
import sympify
17 from sympy
.core
.basic
import Basic
, S
, C
18 from sympy
.core
.symbol
import Symbol
, Wild
19 from sympy
.core
.relational
import Equality
20 from sympy
.core
.function
import Derivative
, diff
22 from sympy
.functions
import sqrt
, log
, exp
, LambertW
23 from sympy
.simplify
import simplify
, collect
24 from sympy
.matrices
import Matrix
, zeros
25 from sympy
.polys
import roots
27 from sympy
.utilities
import any
, all
28 from sympy
.utilities
.lambdify
import lambdify
29 from sympy
.solvers
.numeric
import newton
31 from sympy
.solvers
.polysys
import solve_poly_system
33 def solve(f
, *symbols
, **flags
):
34 """Solves equations and systems of equations.
36 Currently supported are univariate polynomial and transcendental
37 equations and systems of linear and polynomial equations. Input
38 is formed as a single expression or an equation, or an iterable
39 container in case of an equation system. The type of output may
40 vary and depends heavily on the input. For more details refer to
41 more problem specific functions.
43 By default all solutions are simplified to make the output more
44 readable. If this is not the expected behavior, eg. because of
45 speed issues, set simplified=False in function arguments.
47 To solve equations and systems of equations of other kind, eg.
48 recurrence relations of differential equations use rsolve() or
49 dsolve() functions respectively.
51 >>> from sympy import *
52 >>> x,y = symbols('xy')
54 Solve a polynomial equation:
59 Solve a linear system:
61 >>> solve((x+5*y-2, -3*x+6*y-15), x, y)
66 raise ValueError('no symbols were given')
69 if isinstance(symbols
[0], (list, tuple, set)):
72 symbols
= map(sympify
, symbols
)
74 if any(not s
.is_Symbol
for s
in symbols
):
75 raise TypeError('not a Symbol')
77 if not isinstance(f
, (tuple, list, set)):
80 if isinstance(f
, Equality
):
84 poly
= f
.as_poly(*symbols
)
87 result
= roots(poly
, cubics
=True, quartics
=True).keys()
89 result
= [tsolve(f
, *symbols
)]
91 raise NotImplementedError('multivariate equation')
93 if flags
.get('simplified', True):
94 return map(simplify
, result
)
106 if isinstance(g
, Equality
):
109 poly
= g
.as_poly(*symbols
)
114 raise NotImplementedError
116 if all(p
.is_linear
for p
in polys
):
117 n
, m
= len(f
), len(symbols
)
118 matrix
= zeros((n
, m
+ 1))
120 for i
, poly
in enumerate(polys
):
121 for coeff
, monom
in poly
.iter_terms():
123 j
= list(monom
).index(1)
126 matrix
[i
, m
] = -coeff
128 return solve_linear_system(matrix
, *symbols
, **flags
)
130 return solve_poly_system(polys
)
132 def solve_linear_system(system
, *symbols
, **flags
):
133 """Solve system of N linear equations with M variables, which means
134 both Cramer and over defined systems are supported. The possible
135 number of solutions is zero, one or infinite. Respectively this
136 procedure will return None or dictionary with solutions. In the
137 case of over definend system all arbitrary parameters are skiped.
138 This may cause situation in with empty dictionary is returned.
139 In this case it means all symbols can be assigne arbitray values.
141 Input to this functions is a Nx(M+1) matrix, which means it has
142 to be in augmented form. If you are unhappy with such setting
143 use 'solve' method instead, where you can input equations
144 explicitely. And don't worry aboute the matrix, this function
145 is persistent and will make a local copy of it.
147 The algorithm used here is fraction free Gaussian elimination,
148 which results, after elimination, in upper-triangular matrix.
149 Then solutions are found using back-substitution. This approach
150 is more efficient and compact than the Gauss-Jordan method.
152 >>> from sympy import *
153 >>> x, y = symbols('xy')
155 Solve the following system:
160 >>> system = Matrix(( (1, 4, 2), (-2, 1, 14)))
161 >>> solve_linear_system(system, x, y)
168 i
, m
= 0, matrix
.cols
-1 # don't count augmentation
170 while i
< matrix
.lines
:
172 # an overdetermined system
173 if any(matrix
[i
:,m
]):
174 return None # no solutions
176 # remove trailing rows
177 matrix
= matrix
[:i
,:]
181 # there is no pivot in current column
182 # so try to find one in other colums
183 for k
in xrange(i
+1, m
):
188 return None # no solutions
190 # zero row or was a linear combination of
191 # other rows so now we can safely skip it
195 # we want to change the order of colums so
196 # the order of variables must also change
197 syms
[i
], syms
[k
] = syms
[k
], syms
[i
]
198 matrix
.col_swap(i
, k
)
200 pivot_inv
= S
.One
/ matrix
[i
, i
]
202 # divide all elements in the current row by the pivot
203 matrix
.row(i
, lambda x
, _
: x
* pivot_inv
)
205 for k
in xrange(i
+1, matrix
.lines
):
209 # subtract from the current row the row containing
210 # pivot and multiplied by extracted coefficient
211 matrix
.row(k
, lambda x
, j
: simplify(x
- matrix
[i
, j
]*coeff
))
215 # if there weren't any problmes, augmented matrix is now
216 # in row-echelon form so we can check how many solutions
217 # there are and extract them using back substitution
219 simplified
= flags
.get('simplified', True)
221 if len(syms
) == matrix
.lines
:
222 # this system is Cramer equivalent so there is
223 # exactly one solution to this system of equations
224 k
, solutions
= i
-1, {}
227 content
= matrix
[k
, m
]
229 # run back-substitution for variables
230 for j
in xrange(k
+1, m
):
231 content
-= matrix
[k
, j
]*solutions
[syms
[j
]]
234 solutions
[syms
[k
]] = simplify(content
)
236 solutions
[syms
[k
]] = content
241 elif len(syms
) > matrix
.lines
:
242 # this system will have infinite number of solutions
243 # dependent on exactly len(syms) - i parameters
244 k
, solutions
= i
-1, {}
247 content
= matrix
[k
, m
]
249 # run back-substitution for variables
250 for j
in xrange(k
+1, i
):
251 content
-= matrix
[k
, j
]*solutions
[syms
[j
]]
253 # run back-substitution for parameters
254 for j
in xrange(i
, m
):
255 content
-= matrix
[k
, j
]*syms
[j
]
258 solutions
[syms
[k
]] = simplify(content
)
260 solutions
[syms
[k
]] = content
266 return None # no solutions
268 def solve_undetermined_coeffs(equ
, coeffs
, sym
, **flags
):
269 """Solve equation of a type p(x; a_1, ..., a_k) == q(x) where both
270 p, q are univariate polynomials and f depends on k parameters.
271 The result of this functions is a dictionary with symbolic
272 values of those parameters with respect to coefficiens in q.
274 This functions accepts both Equations class instances and ordinary
275 SymPy expressions. Specification of parameters and variable is
276 obligatory for efficiency and simplicity reason.
278 >>> from sympy import *
279 >>> a, b, c, x = symbols('a', 'b', 'c', 'x')
281 >>> solve_undetermined_coeffs(Eq(2*a*x + a+b, x), [a, b], x)
284 >>> solve_undetermined_coeffs(Eq(a*c*x + a+b, x), [a, b], x)
288 if isinstance(equ
, Equality
):
289 # got equation, so move all the
290 # terms to the left hand side
291 equ
= equ
.lhs
- equ
.rhs
293 system
= collect(equ
.expand(), sym
, evaluate
=False).values()
295 if not any([ equ
.has(sym
) for equ
in system
]):
296 # consecutive powers in the input expressions have
297 # been successfully collected, so solve remaining
298 # system using Gaussian ellimination algorithm
299 return solve(system
, *coeffs
, **flags
)
301 return None # no solutions
303 def solve_linear_system_LU(matrix
, syms
):
304 """ LU function works for invertible only """
305 assert matrix
.lines
== matrix
.cols
-1
306 A
= matrix
[:matrix
.lines
,:matrix
.lines
]
307 b
= matrix
[:,matrix
.cols
-1:]
310 for i
in range(soln
.lines
):
311 solutions
[syms
[i
]] = soln
[i
,0]
314 def dsolve(eq
, funcs
):
316 Solves any (supported) kind of differential equation.
320 dsolve(f, y(x)) -> Solve a differential equation f for the function y
325 @param f: ordinary differential equation (either just the left hand
326 side, or the Equality class)
328 @param y: indeterminate function of one variable
330 - you can declare the derivative of an unknown function this way:
331 >>> from sympy import *
332 >>> x = Symbol('x') # x is the independent variable
334 >>> f = Function("f")(x) # f is a function of x
335 >>> f_ = Derivative(f, x) # f_ will be the derivative of f with respect to x
337 - This function just parses the equation "eq" and determines the type of
338 differential equation by its order, then it determines all the coefficients and then
339 calls the particular solver, which just accepts the coefficients.
340 - "eq" can be either an Equality, or just the left hand side (in which
341 case the right hand side is assumed to be 0)
342 - see test_ode.py for many tests, that serve also as a set of examples
347 >>> from sympy import *
350 >>> f = Function('f')
351 >>> dsolve(Derivative(f(x),x,x)+9*f(x), f(x))
352 C1*sin(3*x) + C2*cos(3*x)
353 >>> dsolve(Eq(Derivative(f(x),x,x)+9*f(x)+1, 1), f(x))
354 C1*sin(3*x) + C2*cos(3*x)
358 if isinstance(eq
, Equality
):
360 return dsolve(eq
.lhs
-eq
.rhs
, funcs
)
363 #currently only solve for one function
364 if isinstance(funcs
, Basic
) or len(funcs
) == 1:
365 if isinstance(funcs
, (list, tuple)): # normalize args
373 #We first get the order of the equation, so that we can choose the
374 #corresponding methods. Currently, only first and second
375 #order odes can be handled.
376 order
= deriv_degree(eq
, f(x
))
379 raise NotImplementedError("dsolve: Cannot solve " + str(eq
))
381 return solve_ODE_second_order(eq
, f(x
))
383 return solve_ODE_first_order(eq
, f(x
))
385 raise NotImplementedError("Not a differential equation!")
387 def deriv_degree(expr
, func
):
388 """ get the order of a given ode, the function is implemented
390 a
= Wild('a', exclude
=[func
])
393 if isinstance(expr
, Derivative
):
394 order
= len(expr
.symbols
)
396 for arg
in expr
.args
:
397 if isinstance(arg
, Derivative
):
398 order
= max(order
, len(arg
.symbols
))
402 for arg1
in arg
.args
:
403 order
= max(order
, deriv_degree(arg1
, func
))
407 def solve_ODE_first_order(eq
, f
):
409 solves many kinds of first order odes, different methods are used
410 depending on the form of the given equation. Now the linear
413 from sympy
.integrals
.integrals
import integrate
417 #linear case: a(x)*f'(x)+b(x)*f(x)+c(x) = 0
418 a
= Wild('a', exclude
=[f(x
)])
419 b
= Wild('b', exclude
=[f(x
)])
420 c
= Wild('c', exclude
=[f(x
)])
422 r
= eq
.match(a
*diff(f(x
),x
) + b
*f(x
) + c
)
424 t
= C
.exp(integrate(r
[b
]/r
[a
], x
))
425 tt
= integrate(t
*(-r
[c
]/r
[a
]), x
)
426 return (tt
+ Symbol("C1"))/t
428 #other cases of first order odes will be implemented here
430 raise NotImplementedError("solve_ODE_first_order: Cannot solve " + str(eq
))
432 def solve_ODE_second_order(eq
, f
):
434 solves many kinds of second order odes, different methods are used
435 depending on the form of the given equation. Now the constanst
436 coefficients case and a special case are implemented.
441 #constant coefficients case: af''(x)+bf'(x)+cf(x)=0
442 a
= Wild('a', exclude
=[x
])
443 b
= Wild('b', exclude
=[x
])
444 c
= Wild('c', exclude
=[x
])
446 r
= eq
.match(a
*f(x
).diff(x
,x
) + c
*f(x
))
448 return Symbol("C1")*C
.sin(sqrt(r
[c
]/r
[a
])*x
)+Symbol("C2")*C
.cos(sqrt(r
[c
]/r
[a
])*x
)
450 r
= eq
.match(a
*f(x
).diff(x
,x
) + b
*diff(f(x
),x
) + c
*f(x
))
452 r1
= solve(r
[a
]*x
**2 + r
[b
]*x
+ r
[c
], x
)
455 return (Symbol("C1") + Symbol("C2")*x
)*exp(r1
[0]*x
)
457 return Symbol("C1")*exp(r1
[0]*x
) + Symbol("C2")*exp(r1
[1]*x
)
459 r2
= abs((r1
[0] - r1
[1])/(2*S
.ImaginaryUnit
))
460 return (Symbol("C2")*C
.cos(r2
*x
) + Symbol("C1")*C
.sin(r2
*x
))*exp((r1
[0] + r1
[1])*x
/2)
462 #other cases of the second order odes will be implemented here
464 #special equations, that we know how to solve
466 tt
= a
*t
.diff(x
, x
)/t
467 r
= eq
.match(tt
.expand())
469 return -solve_ODE_1(f(x
), x
)
472 tt
= a
*t
.diff(x
, x
)/t
473 r
= eq
.match(tt
.expand())
475 #check, that we've rewritten the equation correctly:
476 #assert ( r[a]*t.diff(x,2)/t ) == eq.subs(f, t)
477 return solve_ODE_1(f(x
), x
)
479 neq
= eq
*C
.exp(f(x
))/C
.exp(-f(x
))
480 r
= neq
.match(tt
.expand())
482 #check, that we've rewritten the equation correctly:
483 #assert ( t.diff(x,2)*r[a]/t ).expand() == eq
484 return solve_ODE_1(f(x
), x
)
486 raise NotImplementedError("solve_ODE_second_order: cannot solve " + str(eq
))
488 def solve_ODE_1(f
, x
):
489 """ (x*exp(-f(x)))'' = 0 """
492 return -C
.log(C1
+C2
/x
)
494 x
= Symbol('x', dummy
=True)
495 a
,b
,c
,d
,e
,f
,g
,h
= [Wild(t
, exclude
=[x
]) for t
in 'abcdefgh']
498 def _generate_patterns():
499 """Generates patterns for transcendental equations.
501 This is lazily calculated (called) in the tsolve() function and stored in
502 the patterns global variable.
505 tmp1
= f
** (h
-(c
*g
/b
))
506 tmp2
= (-e
*tmp1
/a
)**(1/d
)
509 (a
*(b
*x
+c
)**d
+ e
, ((-(e
/a
))**(1/d
)-c
)/b
),
510 ( b
+c
*exp(d
*x
+e
) , (log(-b
/c
)-e
)/d
),
511 (a
*x
+b
+c
*exp(d
*x
+e
) , -b
/a
-LambertW(c
*d
*exp(e
-b
*d
/a
)/a
)/d
),
512 ( b
+c
*f
**(d
*x
+e
) , (log(-b
/c
)-e
*log(f
))/d
/log(f
)),
513 (a
*x
+b
+c
*f
**(d
*x
+e
) , -b
/a
-LambertW(c
*d
*f
**(e
-b
*d
/a
)*log(f
)/a
)/d
/log(f
)),
514 ( b
+c
*log(d
*x
+e
) , (exp(-b
/c
)-e
)/d
),
515 (a
*x
+b
+c
*log(d
*x
+e
) , -e
/d
+c
/a
*LambertW(a
/c
/d
*exp(-b
/c
+a
*e
/c
/d
))),
516 (a
*(b
*x
+c
)**d
+ e
*f
**(g
*x
+h
) , -c
/b
-d
*LambertW(-tmp2
*g
*log(f
)/b
/d
)/g
/log(f
))
521 Solves a transcendental equation with respect to the given
522 symbol. Various equations containing mixed linear terms, powers,
523 and logarithms, can be solved.
525 Only a single solution is returned. This solution is generally
526 not unique. In some cases, a complex solution may be returned
527 even though a real solution exists.
529 >>> from sympy import *
532 >>> tsolve(3**(2*x+5)-4, x)
533 (-5*log(3) + log(4))/(2*log(3))
535 >>> tsolve(log(x) + 2*x, x)
542 if isinstance(eq
, Equality
):
545 eq2
= eq
.subs(sym
, x
)
546 # First see if the equation has a linear factor
547 # In that case, the other factor can contain x in any way (as long as it
548 # is finite), and we have a direct solution
550 m
= eq2
.match((a
*x
+b
)*r
)
552 return (-b
/a
).subs(m
).subs(x
, sym
)
553 for p
, sol
in patterns
:
556 return sol
.subs(m
).subs(x
, sym
)
558 # let's also try to inverse the equation
563 indep
, dep
= lhs
.as_independent(sym
)
567 # this indicates we have done it all
576 # this indicates we have done it all
584 # f(x) = g -> x = f (g)
585 if lhs
.is_Function
and lhs
.nargs
==1 and hasattr(lhs
, 'inverse'):
586 rhs
= lhs
.inverse() (rhs
)
589 sol
= solve(lhs
-rhs
, sym
)
593 # just a simple case - we do variable substitution for first function,
594 # and if it removes all functions - let's call solve.
596 # UC: e + e = y -> t + t = y
597 t
= Symbol('t', dummy
=True)
600 # find first term which is Function
605 assert False, 'tsolve: at least one Function expected at this point'
607 # perform the substitution
608 lhs_
= lhs
.subs(f1
, t
)
610 # if no Functions left, we can proceed with usual solve
611 if not (lhs_
.is_Function
or
612 any(term
.is_Function
for term
in lhs_
.args
)):
614 # FIXME at present solve cannot solve x + 1/x = y
615 # FIXME so we do this:
616 numer
, denom
= lhs_
.as_numer_denom()
617 sol
= solve(numer
-rhs
*denom
, t
)
618 #sol = solve(lhs_-rhs, t)
621 sol
= tsolve(sol
-f1
, sym
)
627 raise ValueError("unable to solve the equation")
630 def msolve(args
, f
, x0
, tol
=None, maxsteps
=None, verbose
=False, norm
=None,
631 modules
=['mpmath', 'sympy']):
633 Solves a nonlinear equation system numerically.
635 f is a vector function of symbolic expressions representing the system.
636 args are the variables.
637 x0 is a starting vector close to a solution.
639 Be careful with x0, not using floats might give unexpected results.
641 Use modules to specify which modules should be used to evaluate the
642 function and the Jacobian matrix. Make sure to use a module that supports
643 matrices. For more information on the syntax, please see the docstring
646 Currently only fully determined systems are supported.
648 >>> from sympy import Symbol, Matrix
649 >>> x1 = Symbol('x1')
650 >>> x2 = Symbol('x2')
651 >>> f1 = 3 * x1**2 - 2 * x2**2 - 1
652 >>> f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8
653 >>> msolve((x1, x2), (f1, f2), (-1., 1.))
657 if isinstance(f
, (list, tuple)):
659 if len(args
) != f
.cols
:
660 raise NotImplementedError('need exactly as many variables as equations')
670 f
= lambdify(args
, f
.T
, modules
)
671 J
= lambdify(args
, J
, modules
)
672 # solve system using Newton's method
677 kwargs
['maxsteps'] = maxsteps
678 kwargs
['verbose'] = verbose
680 kwargs
['norm'] = norm
681 x
= newton(f
, x0
, J
, **kwargs
)