draft implementation of arbitrary PDF
[sympy.git] / sympy / solvers / solvers.py
blob6157b2a2f32e7aa3422740de0a7b00994cc5c6ad
""" This module contains solvers for all kinds of equations:

    - algebraic, use solve()

    - recurrence, use rsolve()

    - differential, use dsolve()

    - transcendental, use tsolve()

    - nonlinear (numerically), use msolve() (you will need a good starting point)

"""
16 from sympy.core.sympify import sympify
17 from sympy.core.basic import Basic, S, C
18 from sympy.core.symbol import Symbol, Wild
19 from sympy.core.relational import Equality
20 from sympy.core.function import Derivative, diff
22 from sympy.functions import sqrt, log, exp, LambertW
23 from sympy.simplify import simplify, collect
24 from sympy.matrices import Matrix, zeros
25 from sympy.polys import roots
27 from sympy.utilities import any, all
28 from sympy.utilities.lambdify import lambdify
29 from sympy.solvers.numeric import newton
31 from sympy.solvers.polysys import solve_poly_system
def solve(f, *symbols, **flags):
    """Solves equations and systems of equations.

       Currently supported are univariate polynomial and transcendental
       equations and systems of linear and polynomial equations. Input
       is formed as a single expression or an equation, or an iterable
       container in case of an equation system. The type of output may
       vary and depends heavily on the input. For more details refer to
       more problem specific functions.

       By default all solutions are simplified to make the output more
       readable. If this is not the expected behavior, e.g. because of
       speed issues, set simplified=False in function arguments.

       To solve equations and systems of equations of other kinds, e.g.
       recurrence relations or differential equations, use rsolve() or
       dsolve() respectively.

       Symbols may be given either as separate arguments or as a single
       list, tuple or set. Raises ValueError when no symbols are given
       and TypeError when one of them is not a Symbol.

       >>> from sympy import *
       >>> x,y = symbols('xy')

       Solve a polynomial equation:

       >>> solve(x**4-1, x)
       [1, -1, -I, I]

       Solve a linear system:

       >>> solve((x+5*y-2, -3*x+6*y-15), x, y)
       {x: -3, y: 1}

    """
    if not symbols:
        raise ValueError('no symbols were given')

    # symbols may also be passed as one container argument
    if len(symbols) == 1:
        if isinstance(symbols[0], (list, tuple, set)):
            symbols = symbols[0]

    # materialize as a list: it is measured and iterated several times
    symbols = [sympify(s) for s in symbols]

    # the container argument above may have been empty
    if not symbols:
        raise ValueError('no symbols were given')

    if any(not s.is_Symbol for s in symbols):
        raise TypeError('not a Symbol')

    if not isinstance(f, (tuple, list, set)):
        # a single equation
        f = sympify(f)

        if isinstance(f, Equality):
            f = f.lhs - f.rhs

        if len(symbols) != 1:
            raise NotImplementedError('multivariate equation')

        poly = f.as_poly(*symbols)

        if poly is not None:
            result = roots(poly, cubics=True, quartics=True).keys()
        else:
            result = [tsolve(f, *symbols)]

        if flags.get('simplified', True):
            return [simplify(r) for r in result]
        else:
            return list(result)

    # a system of equations
    if not f:
        return {}

    polys = []

    for g in f:
        g = sympify(g)

        if isinstance(g, Equality):
            g = g.lhs - g.rhs

        poly = g.as_poly(*symbols)

        if poly is None:
            raise NotImplementedError

        polys.append(poly)

    if not all(p.is_linear for p in polys):
        return solve_poly_system(polys)

    # build the augmented matrix [A | b] of the linear system A*X = b
    n, m = len(polys), len(symbols)
    matrix = zeros((n, m + 1))

    for i, poly in enumerate(polys):
        for coeff, monom in poly.iter_terms():
            try:
                j = list(monom).index(1)
            except ValueError:
                # constant term: it moves to the right hand side
                matrix[i, m] = -coeff
            else:
                matrix[i, j] = coeff

    return solve_linear_system(matrix, *symbols, **flags)
def solve_linear_system(system, *symbols, **flags):
    """Solve system of N linear equations with M variables. Square
       (Cramer), overdetermined and underdetermined systems are all
       supported. The possible number of solutions is zero, one or
       infinite; respectively this procedure returns None or a
       dictionary with solutions. When there are infinitely many
       solutions, the free parameters are not keys of the dictionary;
       they appear symbolically in the values of the other variables.
       This may cause an empty dictionary to be returned, which means
       all symbols can be assigned arbitrary values.

       Input to this function is a Nx(M+1) matrix, which means it has
       to be in augmented form. If you are unhappy with such a setting,
       use the 'solve' method instead, where you can input equations
       explicitly. And don't worry about the matrix, this function
       is persistent and will make a local copy of it.

       The algorithm used here is Gaussian elimination with column
       pivoting: every pivot row is divided by its pivot, so after
       elimination the matrix is upper-triangular with ones on the
       diagonal. Then solutions are found using back-substitution. This
       approach is more efficient and compact than the Gauss-Jordan method.

       Pass simplified=False to skip simplification of the solutions.

       >>> from sympy import *
       >>> x, y = symbols('xy')

       Solve the following system:

              x + 4 y ==  2
           -2 x +   y == 14

       >>> system = Matrix(( (1, 4, 2), (-2, 1, 14)))
       >>> solve_linear_system(system, x, y)
       {x: -6, y: 2}

    """
    matrix = system[:,:]
    syms = list(symbols)

    i, m = 0, matrix.cols-1  # don't count augmentation

    while i < matrix.lines:
        if i == m:
            # an overdetermined system: the remaining rows have zero
            # coefficients, so they are consistent only if their right
            # hand sides vanish too
            if any(matrix[i:,m]):
                return None  # no solutions
            else:
                # remove trailing rows
                matrix = matrix[:i,:]
                break

        if not matrix[i, i]:
            # there is no pivot in current column
            # so try to find one in other columns of this row
            for k in range(i+1, m):
                if matrix[i, k]:
                    break
            else:
                if matrix[i, m]:
                    return None  # no solutions (0 == nonzero)
                else:
                    # zero row or was a linear combination of
                    # other rows so now we can safely skip it
                    matrix.row_del(i)
                    continue

            # we want to change the order of columns so
            # the order of variables must also change
            syms[i], syms[k] = syms[k], syms[i]
            matrix.col_swap(i, k)

        pivot_inv = S.One / matrix[i, i]

        # divide all elements in the current row by the pivot
        matrix.row(i, lambda x, _: x * pivot_inv)

        for k in range(i+1, matrix.lines):
            if matrix[k, i]:
                coeff = matrix[k, i]

                # subtract from row k the pivot row
                # multiplied by the extracted coefficient
                matrix.row(k, lambda x, j: simplify(x - matrix[i, j]*coeff))

        i += 1

    # if there weren't any problems, the augmented matrix is now in
    # row-echelon form with i == matrix.lines rows left, so we can check
    # how many solutions there are and extract them by back-substitution
    simplified = flags.get('simplified', True)

    if len(syms) < matrix.lines:
        return None  # no solutions

    # len(syms) == matrix.lines: exactly one solution (Cramer equivalent)
    # len(syms) >  matrix.lines: infinitely many solutions dependent on
    #                            exactly len(syms) - i parameters
    return _back_substitute(matrix, syms, i, m, simplified)


def _back_substitute(matrix, syms, rank, m, simplified):
    """Back-substitution over the first ``rank`` rows of an augmented
       upper-triangular matrix with unit diagonal (right hand side in
       column ``m``). Symbols ``syms[rank:m]`` act as free parameters;
       the result maps ``syms[:rank]`` to their values."""
    solutions = {}

    for k in range(rank - 1, -1, -1):
        content = matrix[k, m]

        # run back-substitution for already solved variables
        for j in range(k + 1, rank):
            content -= matrix[k, j]*solutions[syms[j]]

        # run back-substitution for parameters
        for j in range(rank, m):
            content -= matrix[k, j]*syms[j]

        if simplified:
            solutions[syms[k]] = simplify(content)
        else:
            solutions[syms[k]] = content

    return solutions
def solve_undetermined_coeffs(equ, coeffs, sym, **flags):
    """Solve equation of a type p(x; a_1, ..., a_k) == q(x) where both
       p, q are univariate polynomials and p depends on k parameters.
       The result of this function is a dictionary with symbolic
       values of those parameters with respect to coefficients in q.

       This function accepts both Equality class instances and ordinary
       SymPy expressions. Specification of parameters and variable is
       obligatory for efficiency and simplicity reasons.

       Returns None when some collected coefficient still depends on
       ``sym`` (the powers of ``sym`` could not be separated).

       >>> from sympy import *
       >>> a, b, c, x = symbols('a', 'b', 'c', 'x')

       >>> solve_undetermined_coeffs(Eq(2*a*x + a+b, x), [a, b], x)
       {a: 1/2, b: -1/2}

       >>> solve_undetermined_coeffs(Eq(a*c*x + a+b, x), [a, b], x)
       {a: 1/c, b: -1/c}

    """
    if isinstance(equ, Equality):
        # got equation, so move all the
        # terms to the left hand side
        equ = equ.lhs - equ.rhs

    # coefficients of the powers of sym; solve() requires a real list
    system = list(collect(equ.expand(), sym, evaluate=False).values())

    # use a distinct loop name: a Python 2 list comprehension over
    # ``equ`` would leak and rebind the parameter
    if any(term.has(sym) for term in system):
        return None  # no solutions

    # consecutive powers in the input expressions have
    # been successfully collected, so solve remaining
    # system using the Gaussian elimination algorithm
    return solve(system, *coeffs, **flags)
def solve_linear_system_LU(matrix, syms):
    """Solve an augmented linear system [A | b] using LU decomposition.

       ``matrix`` must have exactly one more column than rows; its left
       square block A is passed to LUsolve, which only works when A is
       invertible. ``syms`` lists the unknowns in column order. Returns
       a dictionary mapping each symbol to its value.

       Raises ValueError if ``matrix`` is not N x (N+1).
    """
    if matrix.lines != matrix.cols - 1:
        # explicit check rather than ``assert``: asserts vanish under -O
        raise ValueError("expected an N x (N+1) augmented matrix, got %d x %d"
                         % (matrix.lines, matrix.cols))

    n = matrix.lines
    A = matrix[:n, :n]
    b = matrix[:, n:]
    soln = A.LUsolve(b)

    solutions = {}
    for i in range(soln.lines):
        solutions[syms[i]] = soln[i, 0]
    return solutions
def dsolve(eq, funcs):
    """
    Solves any (supported) kind of differential equation.

    Usage
    =====
        dsolve(eq, f(x)) -> Solve a differential equation eq for the function f


    Details
    =======
        @param eq: ordinary differential equation (either just the left hand
            side, or the Equality class)

        @param funcs: indeterminate function of one variable, e.g. f(x),
            or a one-element list/tuple containing it

        - you can declare the derivative of an unknown function this way:
        >>> from sympy import *
        >>> x = Symbol('x') # x is the independent variable

        >>> f = Function("f")(x) # f is a function of x
        >>> f_ = Derivative(f, x) # f_ will be the derivative of f with respect to x

        - This function just parses the equation "eq" and determines the type of
        differential equation by its order, then it determines all the coefficients and then
        calls the particular solver, which just accepts the coefficients.
        - "eq" can be either an Equality, or just the left hand side (in which
            case the right hand side is assumed to be 0)
        - see test_ode.py for many tests, that serve also as a set of examples
            how to use dsolve
        - only first and second order equations in one unknown function
            are handled; NotImplementedError is raised otherwise

    Examples
    ========
        >>> from sympy import *
        >>> x = Symbol('x')

        >>> f = Function('f')
        >>> dsolve(Derivative(f(x),x,x)+9*f(x), f(x))
        C1*sin(3*x) + C2*cos(3*x)
        >>> dsolve(Eq(Derivative(f(x),x,x)+9*f(x)+1, 1), f(x))
        C1*sin(3*x) + C2*cos(3*x)

    """
    if isinstance(eq, Equality):
        if eq.rhs != 0:
            return dsolve(eq.lhs-eq.rhs, funcs)
        eq = eq.lhs

    # currently only solve for one function; fail loudly instead of
    # falling through and returning None
    if not (isinstance(funcs, Basic) or len(funcs) == 1):
        raise NotImplementedError("dsolve: only one unknown function is supported")

    if isinstance(funcs, (list, tuple)):  # normalize args
        f = funcs[0]
    else:
        f = funcs

    x = f.args[0]
    f = f.func

    # We first get the order of the equation, so that we can choose the
    # corresponding methods. Currently, only first and second
    # order odes can be handled.
    order = deriv_degree(eq, f(x))

    if order > 2:
        raise NotImplementedError("dsolve: Cannot solve " + str(eq))
    elif order == 2:
        return solve_ODE_second_order(eq, f(x))
    elif order == 1:
        return solve_ODE_first_order(eq, f(x))
    else:
        raise NotImplementedError("Not a differential equation!")
def deriv_degree(expr, func):
    """ get the order of a given ode, the function is implemented
    recursively

    Returns the largest number of differentiation symbols of any
    Derivative reached in ``expr`` (0 if none is found). Every Derivative
    reached is counted, whether or not it differentiates ``func``;
    non-Derivative arguments that do not involve ``func`` are not
    searched.
    """
    if isinstance(expr, Derivative):
        return len(expr.symbols)

    a = Wild('a', exclude=[func])

    order = 0
    for arg in expr.args:
        if isinstance(arg, Derivative):
            order = max(order, len(arg.symbols))
        elif arg.match(a):
            # this argument does not involve func: skip it, but keep
            # the order already found in earlier arguments
            continue
        else:
            for arg1 in arg.args:
                order = max(order, deriv_degree(arg1, func))

    return order
def solve_ODE_first_order(eq, f):
    """
    Solve the first order ODE ``eq`` (== 0) for the applied function
    ``f`` (such as f(x)). Only the linear form

        a(x)*f'(x) + b(x)*f(x) + c(x) = 0

    is implemented, using the integrating factor t = exp(Integral(b/a)):
    the result is (Integral(t*(-c/a)) + C1)/t.

    Raises NotImplementedError for any other form.
    """
    from sympy.integrals.integrals import integrate

    var = f.args[0]
    fx = f.func(var)

    # wildcards for the coefficients; none of them may contain f(x)
    wa, wb, wc = [Wild(label, exclude=[fx]) for label in ('a', 'b', 'c')]

    found = eq.match(wa*diff(fx, var) + wb*fx + wc)

    if not found:
        # other cases of first order odes will be implemented here
        raise NotImplementedError("dsolve: Cannot solve " + str(eq))

    lead, lin, free = found[wa], found[wb], found[wc]
    factor = C.exp(integrate(lin/lead, var))
    integral = integrate(factor*(-free/lead), var)
    return (integral + Symbol("C1"))/factor
def solve_ODE_second_order(eq, f):
    """
    Solve the second order ODE ``eq`` (== 0) for the applied function
    ``f`` (such as f(x)). The following forms are tried in order:

    - constant coefficients without first derivative, a*f'' + c*f = 0
    - constant coefficients, a*f'' + b*f' + c*f = 0, solved through the
      roots of the characteristic polynomial a*r**2 + b*r + c
    - special equations that can be rewritten as (x*exp(f(x)))'' = 0
      or (x*exp(-f(x)))'' = 0

    Raises NotImplementedError when none of them applies.
    """
    var = f.args[0]
    fx = f.func(var)

    # constant coefficients case: a*f''(x) + b*f'(x) + c*f(x) = 0
    wa = Wild('a', exclude=[var])
    wb = Wild('b', exclude=[var])
    wc = Wild('c', exclude=[var])

    found = eq.match(wa*fx.diff(var, var) + wc*fx)
    if found:
        omega = sqrt(found[wc]/found[wa])
        return Symbol("C1")*C.sin(omega*var) + Symbol("C2")*C.cos(omega*var)

    found = eq.match(wa*fx.diff(var, var) + wb*diff(fx, var) + wc*fx)
    if found:
        # roots of the characteristic polynomial
        char_roots = solve(found[wa]*var**2 + found[wb]*var + found[wc], var)
        first = char_roots[0]

        if not first.is_real:
            # roots not known to be real: treat them as a complex pair
            beta = abs((first - char_roots[1])/(2*S.ImaginaryUnit))
            envelope = exp((first + char_roots[1])*var/2)
            return (Symbol("C2")*C.cos(beta*var) + Symbol("C1")*C.sin(beta*var))*envelope

        if len(char_roots) == 1:
            # a double real root
            return (Symbol("C1") + Symbol("C2")*var)*exp(first*var)

        return Symbol("C1")*exp(first*var) + Symbol("C2")*exp(char_roots[1]*var)

    # other cases of the second order odes will be implemented here

    # special equations, that we know how to solve
    t = var*C.exp(fx)
    if eq.match((wa*t.diff(var, var)/t).expand()):
        return -solve_ODE_1(fx, var)

    t = var*C.exp(-fx)
    pattern = (wa*t.diff(var, var)/t).expand()
    if eq.match(pattern):
        return solve_ODE_1(fx, var)

    # the same special form, after multiplying through by exp(2*f(x))
    rescaled = eq*C.exp(fx)/C.exp(-fx)
    if rescaled.match(pattern):
        return solve_ODE_1(fx, var)

    raise NotImplementedError("cannot solve this")
def solve_ODE_1(f, x):
    """Return the solution of (x*exp(-f(x)))'' = 0.

    Integrating twice gives x*exp(-f) = C1*x + C2, so the result is
    -log(C1 + C2/x). The ``f`` argument is not used."""
    C1, C2 = Symbol("C1"), Symbol("C2")
    argument = C1 + C2/x
    return -C.log(argument)
# Shared by _generate_patterns() and tsolve(): ``x`` is a dummy symbol that
# tsolve() substitutes for the user's variable, and a..h are wildcards
# excluding ``x``, so they only match sub-expressions free of it.
x = Symbol('x', dummy=True)
# a generator expression (unlike a Python 2 list comprehension) does not
# leak its loop variable into the module namespace
a, b, c, d, e, f, g, h = tuple(Wild(name, exclude=[x]) for name in 'abcdefgh')
# table of (pattern, solution) pairs, built lazily by _generate_patterns()
patterns = None
def _generate_patterns():
    """Build the table of transcendental equation patterns.

    Each entry is a pair (pattern, solution): ``pattern`` is an expression
    in the dummy symbol ``x`` and the wildcards a..h, and ``solution`` is
    the matching value of ``x`` written in those wildcards. The table is
    stored in the module-level ``patterns`` variable; tsolve() calls this
    lazily when it finds ``patterns`` still unset.
    """
    global patterns

    # sub-expressions of the last (power plus exponential) solution
    scaled_pow = f ** (h-(c*g/b))
    root_part = (-e*scaled_pow/a)**(1/d)

    table = [
        # a*(b*x + c)**d + e = 0
        (a*(b*x+c)**d + e, ((-(e/a))**(1/d)-c)/b),
        # b + c*exp(d*x + e) = 0
        (b+c*exp(d*x+e), (log(-b/c)-e)/d),
        # linear term plus exponential
        (a*x+b+c*exp(d*x+e), -b/a-LambertW(c*d*exp(e-b*d/a)/a)/d),
        # b + c*f**(d*x + e) = 0
        (b+c*f**(d*x+e), (log(-b/c)-e*log(f))/d/log(f)),
        # linear term plus general power
        (a*x+b+c*f**(d*x+e), -b/a-LambertW(c*d*f**(e-b*d/a)*log(f)/a)/d/log(f)),
        # b + c*log(d*x + e) = 0
        (b+c*log(d*x+e), (exp(-b/c)-e)/d),
        # linear term plus logarithm
        (a*x+b+c*log(d*x+e), -e/d+c/a*LambertW(a/c/d*exp(-b/c+a*e/c/d))),
        # power of a linear term plus general power
        (a*(b*x+c)**d + e*f**(g*x+h), -c/b-d*LambertW(-root_part*g*log(f)/b/d)/g/log(f)),
    ]
    patterns = table
def tsolve(eq, sym):
    """
    Solves a transcendental equation with respect to the given
    symbol. Various equations containing mixed linear terms, powers,
    and logarithms, can be solved.

    Only a single solution is returned. This solution is generally
    not unique. In some cases, a complex solution may be returned
    even though a real solution exists.

    Raises ValueError if none of the tried methods solves the equation.

        >>> from sympy import *
        >>> x = Symbol('x')

        >>> tsolve(3**(2*x+5)-4, x)
        (-5*log(3) + log(4))/(2*log(3))

        >>> tsolve(log(x) + 2*x, x)
        1/2*LambertW(2)

    """
    if patterns is None:
        _generate_patterns()
    eq = sympify(eq)
    if isinstance(eq, Equality):
        eq = eq.lhs - eq.rhs
    sym = sympify(sym)
    eq2 = eq.subs(sym, x)
    # First see if the equation has a linear factor
    # In that case, the other factor can contain x in any way (as long as it
    # is finite), and we have a direct solution
    r = Wild('r')
    m = eq2.match((a*x+b)*r)
    if m and m[a]:
        return (-b/a).subs(m).subs(x, sym)
    for p, sol in patterns:
        m = eq2.match(p)
        if m:
            return sol.subs(m).subs(x, sym)

    # let's also try to invert the equation
    lhs = eq
    rhs = S.Zero

    while True:
        indep, dep = lhs.as_independent(sym)

        # dep + indep == rhs
        if lhs.is_Add:
            # this indicates we have done it all
            if indep is S.Zero:
                break

            lhs = dep
            rhs -= indep

        # dep * indep == rhs
        else:
            # this indicates we have done it all
            if indep is S.One:
                break

            lhs = dep
            rhs /= indep

    #                    -1
    # f(x) = g  ->  x = f (g)
    if lhs.is_Function and lhs.nargs == 1 and hasattr(lhs, 'inverse'):
        rhs = lhs.inverse()(rhs)
        lhs = lhs.args[0]

        sol = solve(lhs-rhs, sym)
        if sol:
            return sol[0]

    elif lhs.is_Add:
        # just a simple case - we do variable substitution for first function,
        # and if it removes all functions - let's call solve.
        #      x    -x                   -1
        # UC: e  + e   = y      ->  t + t   = y
        t = Symbol('t', dummy=True)

        # the substitution needs at least one Function term; without one
        # (e.g. a plain polynomial) we cannot proceed this way
        functions = [term for term in lhs.args if term.is_Function]

        if functions:
            f1 = functions[0]

            # perform the substitution
            lhs_ = lhs.subs(f1, t)

            # if no Functions left, we can proceed with usual solve
            if not (lhs_.is_Function or
                    any(term.is_Function for term in lhs_.args)):

                # FIXME at present solve cannot solve x + 1/x = y
                # FIXME so we do this:
                numer, denom = lhs_.as_numer_denom()
                sol = solve(numer-rhs*denom, t)
                if sol:
                    return tsolve(sol[0]-f1, sym)

    raise ValueError("unable to solve the equation")
def msolve(args, f, x0, tol=None, maxsteps=None, verbose=False, norm=None,
           modules=None):
    """
    Solves a nonlinear equation system numerically.

    f is a vector function of symbolic expressions representing the system.
    args are the variables.
    x0 is a starting vector close to a solution.

    Be careful with x0, not using floats might give unexpected results.

    Use modules to specify which modules should be used to evaluate the
    function and the Jacobian matrix. Make sure to use a module that supports
    matrices. For more information on the syntax, please see the docstring
    of lambdify. When modules is not given, ['mpmath', 'sympy'] is used.

    tol, maxsteps and norm are forwarded to newton() only when they are
    given (not None); verbose is always forwarded.

    Currently only fully determined systems are supported; otherwise
    NotImplementedError is raised.

    >>> from sympy import Symbol, Matrix
    >>> x1 = Symbol('x1')
    >>> x2 = Symbol('x2')
    >>> f1 = 3 * x1**2 - 2 * x2**2 - 1
    >>> f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8
    >>> msolve((x1, x2), (f1, f2), (-1., 1.))
    [-1.19287309935246]
    [ 1.27844411169911]
    """
    if modules is None:
        # build a fresh list per call instead of sharing a mutable default
        modules = ['mpmath', 'sympy']
    if isinstance(f, (list, tuple)):
        f = Matrix(f).T
    if len(args) != f.cols:
        raise NotImplementedError('need exactly as many variables as equations')
    if verbose:
        print('f(x):')
        print(f)
    # derive Jacobian
    J = f.jacobian(args)
    if verbose:
        print('J(x):')
        print(J)
    # create functions
    f = lambdify(args, f.T, modules)
    J = lambdify(args, J, modules)
    # solve system using Newton's method; explicit None checks so that
    # a deliberately passed 0 is not silently dropped
    kwargs = {}
    if tol is not None:
        kwargs['tol'] = tol
    if maxsteps is not None:
        kwargs['maxsteps'] = maxsteps
    kwargs['verbose'] = verbose
    if norm is not None:
        kwargs['norm'] = norm
    x = newton(f, x0, J, **kwargs)
    return x