Use sympify instead of Basic.sympify (#667)
[sympy.git] / sympy / plotting / plot_mode.py
blobf1c76466fe1579b881c860da77b513d5f2b815c8
1 from sympy import Basic, Symbol, symbols, sympify
2 from plot_interval import PlotInterval
3 from plot_object import PlotObject
4 from color_scheme import ColorScheme
5 from util import parse_option_string
class PlotMode(PlotObject):
    """
    Grandparent class for plotting
    modes. Serves as interface for
    registration, lookup, and init
    of modes.

    To create a new plot mode,
    inherit from PlotModeBase
    or one of its children, such
    as PlotSurface or PlotCurve.
    """

    ## Class-level attributes
    ## used to register and lookup
    ## plot modes. See PlotModeBase
    ## for descriptions and usage.

    i_vars, d_vars = '', ''
    intervals = []
    aliases = []
    is_default = False

    ## Draw is the only method here which
    ## is meant to be overridden in child
    ## classes, and PlotModeBase provides
    ## a base implementation.
    def draw(self):
        """Render the mode. Must be overridden by child classes."""
        raise NotImplementedError()

    ## Everything else in this file has to
    ## do with registration and retrieval
    ## of plot modes. This is where I've
    ## hidden much of the ugliness of automatic
    ## plot mode divination...

    ## Plot mode registry data structures
    _mode_alias_list = []
    _mode_map = {
        1: {1: {}, 2: {}},
        2: {1: {}, 2: {}},
        3: {1: {}, 2: {}},
    }  # [d][i][alias_str]: class
    _mode_default_map = {
        1: {},
        2: {},
        3: {},
    }  # [d][i]: class
    _i_var_max, _d_var_max = 2, 3

    def __new__(cls, *args, **kwargs):
        """
        This is the function which interprets
        arguments given to Plot.__init__ and
        Plot.__setattr__. Returns an initialized
        instance of the appropriate child class.

        String positional args are parsed as options; the
        'mode' option (a class or an alias string) selects
        the mode, otherwise a default mode is chosen from
        the number of independent and dependent variables.
        """
        nargs, nkwargs = PlotMode._extract_options(args, kwargs)
        mode_arg = nkwargs.get('mode', '')

        # Interpret the arguments
        d_vars, intervals = PlotMode._interpret_args(nargs)
        i_vars = PlotMode._find_i_vars(d_vars, intervals)
        # Intervals that name no variable add nothing to i_vars,
        # so there can be more intervals than i_vars found.
        i, d = max(len(i_vars), len(intervals)), len(d_vars)

        # Find the appropriate mode
        subcls = PlotMode._get_mode(mode_arg, i, d)

        # Create the object
        o = object.__new__(subcls)

        # Do some setup for the mode instance
        o.d_vars = d_vars
        o._fill_i_vars(i_vars)
        o._fill_intervals(intervals)
        o.options = nkwargs

        return o

    @staticmethod
    def _get_mode(mode_arg, i_var_count, d_var_count):
        """
        Tries to return an appropriate mode class.
        Intended to be called only by __new__.

        mode_arg
            Can be a string or a class. If it is a
            PlotMode subclass, it is returned after
            checking that it was initialized and
            supports the given variable counts.
            If it is a string, it can be an alias for
            a mode or an empty string. In the latter
            case, we try to find a default mode for
            the i_var_count and d_var_count.

        i_var_count
            The number of independent variables
            needed to evaluate the d_vars.

        d_var_count
            The number of dependent variables;
            usually the number of functions to
            be evaluated in plotting.

        For example, a cartesian function y = f(x) has
        one i_var (x) and one d_var (y). A parametric
        form x,y,z = f(u,v), f(u,v), f(u,v) has two
        i_vars (u,v) and three d_vars (x,y,z).

        Raises ValueError if no suitable mode exists or
        mode_arg is neither a class nor a string.
        """
        # if the mode_arg is simply a PlotMode class,
        # check that the mode supports the numbers
        # of independent and dependent vars, then
        # return it
        m = None
        try:
            if issubclass(mode_arg, PlotMode):
                m = mode_arg
        except TypeError:
            # issubclass() raises TypeError when mode_arg is
            # not a class (e.g. an alias string); fall through.
            pass
        if m is not None:
            if not m._was_initialized:
                raise ValueError(("To use unregistered plot mode %s "
                                  "you must first call %s._init_mode().")
                                 % (m.__name__, m.__name__))
            if d_var_count != m.d_var_count:
                raise ValueError(("%s can only plot functions "
                                  "with %i dependent variables.")
                                 % (m.__name__,
                                    m.d_var_count))
            if i_var_count > m.i_var_count:
                raise ValueError(("%s cannot plot functions "
                                  "with more than %i independent "
                                  "variables.")
                                 % (m.__name__,
                                    m.i_var_count))
            return m
        # If it is a string, there are two possibilities.
        if isinstance(mode_arg, str):
            i, d = i_var_count, d_var_count
            if i > PlotMode._i_var_max:
                raise ValueError(var_count_error(True, True))
            if d > PlotMode._d_var_max:
                raise ValueError(var_count_error(False, True))
            # If the string is '', try to find a suitable
            # default mode
            if not mode_arg:
                return PlotMode._get_default_mode(i, d)
            # Otherwise, interpret the string as a mode
            # alias (e.g. 'cartesian', 'parametric', etc)
            return PlotMode._get_aliased_mode(mode_arg, i, d)
        raise ValueError("PlotMode argument must be "
                         "a class or a string")

    @staticmethod
    def _get_default_mode(i, d, i_vars=-1):
        """
        Return the default mode class for d dependent variables
        and the smallest independent-variable count >= i (up to
        PlotMode._i_var_max) that has one.

        i_vars is the originally requested count, used only in
        the error message; it defaults to i.

        Raises ValueError if no default mode is found.
        """
        if i_vars == -1:
            i_vars = i
        # Keep looking for modes in higher i var counts
        # which support the given d var count until we
        # reach the max i_var count.
        while True:
            try:
                return PlotMode._mode_default_map[d][i]
            except KeyError:
                pass
            if i >= PlotMode._i_var_max:
                raise ValueError(("Couldn't find a default mode "
                                  "for %i independent and %i "
                                  "dependent variables.") % (i_vars, d))
            i += 1

    @staticmethod
    def _get_aliased_mode(alias, i, d, i_vars=-1):
        """
        Return the mode class registered under alias for d
        dependent variables and the smallest independent-variable
        count >= i (up to PlotMode._i_var_max) that has one.

        i_vars is the originally requested count, used only in
        the error message; it defaults to i.

        Raises ValueError if alias was never registered, or no
        mode with that alias supports these variable counts.
        """
        if i_vars == -1:
            i_vars = i
        if alias not in PlotMode._mode_alias_list:
            raise ValueError(("Couldn't find a mode called"
                              " %s. Known modes: %s.")
                             % (alias, ", ".join(PlotMode._mode_alias_list)))
        # Keep looking for modes in higher i var counts
        # which support the given d var count and alias
        # until we reach the max i_var count.
        while True:
            try:
                return PlotMode._mode_map[d][i][alias]
            except KeyError:
                pass
            if i >= PlotMode._i_var_max:
                raise ValueError(("Couldn't find a %s mode "
                                  "for %i independent and %i "
                                  "dependent variables.")
                                 % (alias, i_vars, d))
            i += 1

    @classmethod
    def _register(cls):
        """
        Called once for each user-usable plot mode.
        For Cartesian2D, it is invoked after the
        class definition: Cartesian2D._register()

        Initializes the mode via _init_mode, then records it in
        _mode_map under each of its aliases and, if is_default
        is set, in _mode_default_map. Errors from _init_mode
        propagate unchanged; errors while recording are re-raised
        as Exception("Failed to register plot mode ...").
        """
        name = cls.__name__
        cls._init_mode()

        try:
            i, d = cls.i_var_count, cls.d_var_count
            # Add the mode to _mode_map under all
            # given aliases
            for a in cls.aliases:
                if a not in PlotMode._mode_alias_list:
                    # Also track valid aliases, so
                    # we can quickly know when given
                    # an invalid one in _get_mode.
                    PlotMode._mode_alias_list.append(a)
                PlotMode._mode_map[d][i][a] = cls
            if cls.is_default:
                # If this mode was marked as the
                # default for this d,i combination,
                # also set that.
                PlotMode._mode_default_map[d][i] = cls
        except Exception:
            # sys.exc_info() fetches the active exception on both
            # Python 2 and 3 ("except E, e" is 2.x-only syntax).
            import sys
            e = sys.exc_info()[1]
            raise Exception(("Failed to register "
                             "plot mode %s. Reason: %s")
                            % (name, str(e)))

    @classmethod
    def _init_mode(cls):
        """
        Initializes the plot mode based on
        the 'mode-specific parameters' above.
        Only intended to be called by
        PlotMode._register(). To use a mode without
        registering it, you can directly call
        ModeSubclass._init_mode().

        Converts cls.i_vars and cls.d_vars from strings to lists
        of Symbols, sets i_var_count, d_var_count and
        primary_alias, and replaces cls.intervals with a new list
        of variable-less PlotIntervals. Calling it again on a
        class that is already initialized does nothing.

        Raises ValueError if the mode has too many i_vars or
        d_vars, or if its default intervals are not one
        [min, max, steps] triple per i_var.
        """
        # The conversions below cannot be repeated (they would try
        # to re-parse lists of Symbols and PlotIntervals), and
        # _register() calls this even if the user already did.
        # Only this class's own __dict__ is checked, so a subclass
        # of an initialized mode still gets initialized itself.
        if cls.__dict__.get('_was_initialized', False):
            return

        def symbols_list(symbol_str):
            """
            symbols() doesn't behave exactly
            like I need. I need a list even
            when len(str) == 1 or 0.
            """
            if len(symbol_str) == 0:
                return []
            if len(symbol_str) == 1:
                return [Symbol(symbol_str)]
            return symbols(symbol_str)

        # Convert the vars strs into
        # lists of symbols.
        cls.i_vars = symbols_list(cls.i_vars)
        cls.d_vars = symbols_list(cls.d_vars)

        # Var count is used often, calculate
        # it once here
        cls.i_var_count = len(cls.i_vars)
        cls.d_var_count = len(cls.d_vars)

        if cls.i_var_count > PlotMode._i_var_max:
            raise ValueError(var_count_error(True, False))
        if cls.d_var_count > PlotMode._d_var_max:
            raise ValueError(var_count_error(False, False))

        # Try to use first alias as primary_alias
        if len(cls.aliases) > 0:
            cls.primary_alias = cls.aliases[0]
        else:
            cls.primary_alias = cls.__name__

        di = cls.intervals
        if len(di) != cls.i_var_count:
            raise ValueError("Plot mode must provide a "
                             "default interval for each i_var.")
        # Default intervals are given as [min, max, steps] (no var),
        # in the same order as i_vars. Each becomes an incomplete
        # interval, to later be filled with a var when the mode is
        # instantiated. A new list is built instead of editing
        # cls.intervals in place, because that list object may be
        # inherited from (and shared with) a parent mode.
        initialized = []
        for default in di:
            if len(default) != 3:
                raise ValueError("Plot mode default intervals must be "
                                 "given as [min, max, steps].")
            initialized.append(PlotInterval(None, *default))
        cls.intervals = initialized

        # To prevent people from using modes
        # without these required fields set up.
        cls._was_initialized = True

    _was_initialized = False

    ## Initializer Helper Methods

    @staticmethod
    def _find_i_vars(functions, intervals):
        """
        Return the independent variables: first the variables
        named by the intervals, in order, then any other Symbols
        appearing in the functions.

        Raises ValueError if two intervals name the same variable.
        """
        i_vars = []

        # First, collect i_vars in the
        # order they are given in any
        # intervals.
        for interval in intervals:
            if interval.v is None:
                continue
            if interval.v in i_vars:
                raise ValueError(("Multiple intervals given "
                                  "for %s.") % (str(interval.v)))
            i_vars.append(interval.v)

        # Then, find any remaining
        # i_vars in given functions
        # (aka d_vars)
        for f in functions:
            for a in f.atoms(type=Symbol):
                if a not in i_vars:
                    i_vars.append(a)

        return i_vars

    def _fill_i_vars(self, i_vars):
        """
        Give this instance its own i_vars list.

        The list starts as fresh Symbols named after the mode's
        default i_vars. Its leading entries are then replaced by
        the given i_vars.
        """
        # copy default i_vars
        self.i_vars = [Symbol(str(i)) for i in self.i_vars]
        # replace with given i_vars
        for index, var in enumerate(i_vars):
            self.i_vars[index] = var

    def _fill_intervals(self, intervals):
        """
        Give this instance its own intervals list.

        The list starts as PlotInterval(...) copies of the mode's
        default intervals. The leading copies are updated with
        fill_from from the given intervals. Any interval still
        lacking a variable is then assigned the first i_var not
        used by another interval.
        """
        # copy default intervals
        self.intervals = [PlotInterval(default) for default in self.intervals]
        # track i_vars used so far
        v_used = []
        # fill copy of default
        # intervals with given info
        for index, given in enumerate(intervals):
            target = self.intervals[index]
            target.fill_from(given)
            if target.v is not None:
                v_used.append(target.v)
        # Find any orphan intervals and
        # assign them i_vars
        for interval in self.intervals:
            if interval.v is None:
                u = [v for v in self.i_vars if v not in v_used]
                # Internal invariant: a free i_var should remain
                # for every orphan interval.
                assert len(u) != 0
                interval.v = u[0]
                v_used.append(u[0])

    @staticmethod
    def _interpret_args(args):
        """
        Sort positional plot arguments into functions and intervals.

        Anything PlotInterval.try_parse accepts is an interval.
        Everything else is passed through sympify and kept as a
        function. Returns (functions, intervals).

        Raises ValueError in three cases:
        - an interval appears before any function;
        - an argument is a str/list/tuple that is not an interval;
        - an argument cannot be sympified.
        """
        interval_wrong_order = "PlotInterval %s was given before any function(s)."
        interpret_error = "Could not interpret %s as a function or interval."

        functions, intervals = [], []
        for a in args:
            i = PlotInterval.try_parse(a)
            if i is not None:
                if len(functions) == 0:
                    raise ValueError(interval_wrong_order % (str(i)))
                intervals.append(i)
                continue
            if isinstance(a, (str, list, tuple)):
                raise ValueError(interpret_error % (str(a)))
            try:
                f = sympify(a)
            except Exception:
                # Deliberately not a bare except: KeyboardInterrupt
                # and SystemExit must not become a ValueError.
                raise ValueError(interpret_error % str(a))
            functions.append(f)

        return functions, intervals

    @staticmethod
    def _extract_options(args, kwargs):
        """
        Split args into option strings and everything else.

        Each str in args is parsed with parse_option_string and
        merged into the options dict; later strings override
        earlier ones. kwargs are merged last, so they override
        string options. Returns (remaining_args, options).
        """
        nkwargs, nargs = {}, []
        for a in args:
            if isinstance(a, str):
                nkwargs.update(parse_option_string(a))
            else:
                nargs.append(a)
        nkwargs.update(kwargs)
        return nargs, nkwargs
def var_count_error(is_independent, is_plotting):
    """
    Return the message for a "too many variables" ValueError.

    is_independent selects the independent-variable limit
    (PlotMode._i_var_max) instead of the dependent one
    (PlotMode._d_var_max). is_plotting selects wording about
    plotting instead of about registering plot modes.
    Used by PlotMode._get_mode and PlotMode._init_mode.
    """
    if is_independent:
        limit, kind = PlotMode._i_var_max, "independent"
    else:
        limit, kind = PlotMode._d_var_max, "dependent"

    if is_plotting:
        action = "Plotting"
    else:
        action = "Registering plot modes"

    template = "%s with more than %i %s variables is not supported."
    return template % (action, limit, kind)