Line endings converted to LF (#612)
[sympy.git] / examples / sample.py
blobc2798ae55d0835d2b74438e80442af51e468cc9d
1 """
Utility functions for plotting sympy functions.

See examples/mplot2d.py and examples/mplot3d.py for usable 2d and 3d
graphing functions using matplotlib.
6 """
8 from numpy import repeat, arange, empty, ndarray, array
9 from sympy import Symbol, Basic, Real, Rational, I, sympify
def sample2d(f, x_args):
    """
    Samples a 2d function f over a specified interval and returns two
    arrays (X, Y) suitable for plotting with matlab (matplotlib)
    syntax. See examples/mplot2d.py.

    f is a function of one variable, such as x**2.
    x_args is an interval given in the form (var, min, max, n), where n
    is the number of equal subintervals (n >= 1).

    X holds the n + 1 sample points running from min to max and Y holds
    the value of f at each of them. Points at which f cannot be converted
    to a float are stored as NaN.

    Raises ValueError if x_args does not have the expected form, if n is
    less than 1, or if f cannot be sympified.
    """
    # Check the cheap interval arguments first so that malformed input is
    # reported before any symbolic work is done.
    try:
        x, x_min, x_max, x_n = x_args
        n = int(x_n)
    except (TypeError, ValueError):
        raise ValueError("x_args must be a tuple of the form (var, min, max, n)")
    if n < 1:
        raise ValueError("x_args must specify at least one interval (n >= 1)")

    try:
        f = sympify(f)
    except Exception:
        raise ValueError("f could not be interpretted as a SymPy function")

    x_d = float(x_max - x_min)/n
    # Scale an integer range rather than calling arange with a float step:
    # a float step can produce one point too many because of rounding,
    # whereas this always yields exactly n + 1 points.
    X = float(x_min) + arange(n + 1)*x_d

    Y = empty(len(X))
    for i, x_val in enumerate(X):
        try:
            Y[i] = float(f.subs(x, x_val))
        except Exception:
            # Best effort: mark points that cannot be evaluated instead of
            # aborting the whole sample. Exception (not a bare except) keeps
            # Ctrl-C and SystemExit working during long loops.
            Y[i] = float("nan")
    return X, Y
def _axis_points(lower, upper, count):
    """
    Returns count + 1 evenly spaced floats running from lower to upper.

    Raises ValueError if count is not an integer-like value >= 1.
    """
    try:
        count = int(count)
    except (TypeError, ValueError):
        raise ValueError("the number of intervals n must be an integer")
    if count < 1:
        raise ValueError("the number of intervals n must be at least 1")
    step = float(upper - lower)/count
    # Scale an integer range rather than calling arange with a float step:
    # a float step can produce one point too many because of rounding.
    return float(lower) + arange(count + 1)*step


def sample3d(f, x_args, y_args):
    """
    Samples a 3d function f over specified intervals and returns three
    2d arrays (X, Y, Z) suitable for plotting with matlab (matplotlib)
    syntax. See examples/mplot3d.py.

    f is a function of two variables, such as x**2 + y**2.
    x_args and y_args are intervals given in the form (var, min, max, n),
    where n is the number of equal subintervals (n >= 1).

    All three arrays have one row per y sample and one column per x
    sample. X repeats the x samples along every row, Y repeats the y
    samples along every column and Z[j][k] is f evaluated at
    (X[j][k], Y[j][k]). Points at which f cannot be converted to a float
    are stored as 0.

    Raises ValueError if x_args or y_args does not have the expected
    form, if either n is less than 1, or if f cannot be sympified.
    """
    # Check the cheap interval arguments first so that malformed input is
    # reported before any symbolic work is done.
    try:
        x, x_min, x_max, x_n = x_args
        y, y_min, y_max, y_n = y_args
    except (TypeError, ValueError):
        raise ValueError("x_args and y_args must be tuples of the form (var, min, max, intervals)")

    x_a = _axis_points(x_min, x_max, x_n)
    y_a = _axis_points(y_min, y_max, y_n)

    try:
        f = sympify(f)
    except Exception:
        raise ValueError("f could not be interpretted as a SymPy function")

    # Build the coordinate grids: x varies along columns, y along rows.
    n_rows, n_cols = len(y_a), len(x_a)
    X = repeat(x_a.reshape(1, n_cols), n_rows, 0)
    Y = repeat(y_a.reshape(n_rows, 1), n_cols, 1)

    Z = empty((n_rows, n_cols))
    for j, y_val in enumerate(y_a):
        for k, x_val in enumerate(x_a):
            try:
                Z[j, k] = float(f.subs(x, x_val).subs(y, y_val))
            except Exception:
                # Best effort: keep sampling and store 0 for points that
                # cannot be evaluated. Exception (not a bare except) keeps
                # Ctrl-C and SystemExit working during long loops.
                Z[j, k] = 0
    return X, Y, Z
def sample(f, *var_args):
    """
    Samples a 2d or 3d function over specified intervals and returns
    a dataset suitable for plotting with matlab (matplotlib) syntax.
    Wrapper for sample2d and sample3d.

    f is a function of one or two variables, such as x**2.
    var_args are intervals for each variable given in the form (var, min, max, n)

    With one interval the result of sample2d is returned; with two
    intervals the result of sample3d is returned.

    Raises ValueError for any other number of intervals.
    """
    if len(var_args) == 1:
        return sample2d(f, var_args[0])
    if len(var_args) == 2:
        return sample3d(f, var_args[0], var_args[1])
    raise ValueError("Only 2d and 3d sampling are supported at this time.")