Added a warning when constructing a Matrix without bracket + test modified
[sympy.git] / sympy / plotting / plot_mode.py
blob a87a29d7c988f2c7b1f2461d53f8f74be83b96c6
1 from sympy import Basic, Symbol, symbols, sympify
2 from plot_interval import PlotInterval
3 from plot_object import PlotObject
4 from color_scheme import ColorScheme
5 from util import parse_option_string
6 from sympy.geometry.entity import GeometryEntity
class PlotMode(PlotObject):
    """
    Grandparent class for plotting
    modes. Serves as interface for
    registration, lookup, and init
    of modes.

    To create a new plot mode,
    inherit from PlotModeBase
    or one of its children, such
    as PlotSurface or PlotCurve.
    """

    ## Class-level attributes
    ## used to register and lookup
    ## plot modes. See PlotModeBase
    ## for descriptions and usage.

    # Before _init_mode() runs, i_vars/d_vars are strings of variable
    # names and intervals is a list of [min, max, steps] triples; after
    # it runs they hold lists of Symbols and PlotIntervals respectively.
    i_vars, d_vars = '', ''
    intervals = []
    aliases = []
    is_default = False

    ## Draw is the only method here which
    ## is meant to be overridden in child
    ## classes, and PlotModeBase provides
    ## a base implementation.
    def draw(self):
        raise NotImplementedError()

    ## Everything else in this file has to
    ## do with registration and retrieval
    ## of plot modes. This is where I've
    ## hidden much of the ugliness of automatic
    ## plot mode divination...

    ## Plot mode registry data structures
    _mode_alias_list = []
    _mode_map = {
        1: {1: {}, 2: {}},
        2: {1: {}, 2: {}},
        3: {1: {}, 2: {}},
    }  # [d][i][alias_str]: class
    _mode_default_map = {
        1: {},
        2: {},
        3: {},
    }  # [d][i]: class
    _i_var_max, _d_var_max = 2, 3

    def __new__(cls, *args, **kwargs):
        """
        This is the function which interprets
        arguments given to Plot.__init__ and
        Plot.__setattr__. Returns an initialized
        instance of the appropriate child class.

        String positional arguments are treated as option strings;
        the remaining positional arguments are interpreted as
        functions and intervals. The 'mode' option (class or alias
        string, default '') selects the mode class.
        """
        nargs, nkwargs = PlotMode._extract_options(args, kwargs)
        mode_arg = nkwargs.get('mode', '')

        # Interpret the arguments
        d_vars, intervals = PlotMode._interpret_args(nargs)
        i_vars = PlotMode._find_i_vars(d_vars, intervals)
        i, d = max(len(i_vars), len(intervals)), len(d_vars)

        # Find the appropriate mode
        subcls = PlotMode._get_mode(mode_arg, i, d)

        # Create the object
        o = object.__new__(subcls)

        # Do some setup for the mode instance
        o.d_vars = d_vars
        o._fill_i_vars(i_vars)
        o._fill_intervals(intervals)
        o.options = nkwargs

        return o

    @staticmethod
    def _get_mode(mode_arg, i_var_count, d_var_count):
        """
        Tries to return an appropriate mode class.
        Intended to be called only by __new__.

        mode_arg
            Can be a string or a class. If it is a
            PlotMode subclass, it is returned after
            checking that it is initialized and
            supports the given variable counts.
            If it is a string, it can be an alias for
            a mode or an empty string. In the latter
            case, we try to find a default mode for
            the i_var_count and d_var_count.

        i_var_count
            The number of independent variables
            needed to evaluate the d_vars.

        d_var_count
            The number of dependent variables;
            usually the number of functions to
            be evaluated in plotting.

        For example, a cartesian function y = f(x) has
        one i_var (x) and one d_var (y). A parametric
        form x,y,z = f(u,v), f(u,v), f(u,v) has two
        i_vars (u,v) and three d_vars (x,y,z).

        Raises ValueError when no suitable mode can be found, when an
        explicit mode class is uninitialized or incompatible with the
        variable counts, or when mode_arg is neither a class nor a
        string.
        """
        # issubclass() raises TypeError when mode_arg is not a class at
        # all (e.g. an alias string); that only means "not a mode class",
        # so fall through to the string handling below.
        try:
            is_mode_class = issubclass(mode_arg, PlotMode)
        except TypeError:
            is_mode_class = False

        # If mode_arg is simply a PlotMode class, check that the mode
        # supports the numbers of independent and dependent vars, then
        # return it.
        if is_mode_class:
            m = mode_arg
            if not m._was_initialized:
                raise ValueError(("To use unregistered plot mode %s "
                                  "you must first call %s._init_mode().")
                                 % (m.__name__, m.__name__))
            if d_var_count != m.d_var_count:
                raise ValueError(("%s can only plot functions "
                                  "with %i dependent variables.")
                                 % (m.__name__,
                                    m.d_var_count))
            if i_var_count > m.i_var_count:
                raise ValueError(("%s cannot plot functions "
                                  "with more than %i independent "
                                  "variables.")
                                 % (m.__name__,
                                    m.i_var_count))
            return m

        # If it is a string, there are two possibilities.
        if isinstance(mode_arg, str):
            i, d = i_var_count, d_var_count
            if i > PlotMode._i_var_max:
                raise ValueError(var_count_error(True, True))
            if d > PlotMode._d_var_max:
                raise ValueError(var_count_error(False, True))
            # If the string is '', try to find a suitable
            # default mode
            if not mode_arg:
                return PlotMode._get_default_mode(i, d)
            # Otherwise, interpret the string as a mode
            # alias (e.g. 'cartesian', 'parametric', etc)
            return PlotMode._get_aliased_mode(mode_arg, i, d)

        raise ValueError("PlotMode argument must be "
                         "a class or a string")

    @staticmethod
    def _get_default_mode(i, d, i_vars=-1):
        """
        Return the default mode class registered for d dependent
        variables at the smallest independent-variable count >= i
        (searching up to _i_var_max) that has one.

        i_vars is the originally requested independent-variable
        count and is used only in the error message; -1 means
        "same as i".

        Raises ValueError if no such default mode is registered.
        """
        if i_vars == -1:
            i_vars = i
        defaults_for_d = PlotMode._mode_default_map.get(d, {})
        # Keep looking for modes in higher i var counts
        # which support the given d var count until we
        # reach the max i_var count.
        while i not in defaults_for_d:
            if i >= PlotMode._i_var_max:
                raise ValueError(("Couldn't find a default mode "
                                  "for %i independent and %i "
                                  "dependent variables.") % (i_vars, d))
            i += 1
        return defaults_for_d[i]

    @staticmethod
    def _get_aliased_mode(alias, i, d, i_vars=-1):
        """
        Return the mode class registered under alias for d dependent
        variables at the smallest independent-variable count >= i
        (searching up to _i_var_max) that has one.

        i_vars is the originally requested independent-variable
        count and is used only in the error message; -1 means
        "same as i".

        Raises ValueError if alias was never registered, or if no
        mode with that alias supports the given variable counts.
        """
        if i_vars == -1:
            i_vars = i
        if alias not in PlotMode._mode_alias_list:
            raise ValueError(("Couldn't find a mode called"
                              " %s. Known modes: %s.")
                             % (alias, ", ".join(PlotMode._mode_alias_list)))
        modes_for_d = PlotMode._mode_map.get(d, {})
        # Keep looking for modes in higher i var counts
        # which support the given d var count and alias
        # until we reach the max i_var count.
        while alias not in modes_for_d.get(i, {}):
            if i >= PlotMode._i_var_max:
                raise ValueError(("Couldn't find a %s mode "
                                  "for %i independent and %i "
                                  "dependent variables.")
                                 % (alias, i_vars, d))
            i += 1
        return modes_for_d[i][alias]

    @classmethod
    def _register(cls):
        """
        Called once for each user-usable plot mode.
        For Cartesian2D, it is invoked after the
        class definition: Cartesian2D._register()

        Initializes the mode via _init_mode(), then records it in
        _mode_map under each of its aliases and, if is_default is
        set, in _mode_default_map for its (d, i) variable counts.

        Errors from _init_mode() propagate unchanged. Any error while
        recording the mode is re-raised as an Exception that names
        the mode and the original reason.
        """
        name = cls.__name__
        cls._init_mode()

        try:
            i, d = cls.i_var_count, cls.d_var_count
            # Add the mode to _mode_map under all
            # given aliases
            for a in cls.aliases:
                if a not in PlotMode._mode_alias_list:
                    # Also track valid aliases, so
                    # we can quickly know when given
                    # an invalid one in _get_mode.
                    PlotMode._mode_alias_list.append(a)
                PlotMode._mode_map[d][i][a] = cls
            if cls.is_default:
                # If this mode was marked as the
                # default for this d,i combination,
                # also set that.
                PlotMode._mode_default_map[d][i] = cls

        except Exception:
            # sys.exc_info() fetches the active exception without the
            # Python-2-only "except Exception, e" syntax, so this parses
            # on both Python 2 and Python 3.
            import sys
            e = sys.exc_info()[1]
            raise Exception(("Failed to register "
                             "plot mode %s. Reason: %s")
                            % (name, (str(e))))

    @classmethod
    def _init_mode(cls):
        """
        Initializes the plot mode based on
        the 'mode-specific parameters' above.
        Only intended to be called by
        PlotMode._register(). To use a mode without
        registering it, you can directly call
        ModeSubclass._init_mode().

        This method does the following:
        - converts the i_vars/d_vars name strings into lists of Symbols;
        - sets i_var_count, d_var_count and primary_alias;
        - replaces the list of default [min, max, steps] intervals with
          a new list of var-less PlotIntervals;
        - marks the class as initialized.

        Raises ValueError if a variable count exceeds the supported
        maximum, or if the default intervals are missing or malformed.
        """
        def symbols_list(symbol_str):
            """
            symbols() doesn't behave exactly
            like I need. I need a list even
            when len(str) == 1 or 0.
            """
            if len(symbol_str) == 0:
                return []
            if len(symbol_str) == 1:
                return [Symbol(symbol_str)]
            return symbols(symbol_str)

        # Convert the vars strs into
        # lists of symbols.
        cls.i_vars = symbols_list(cls.i_vars)
        cls.d_vars = symbols_list(cls.d_vars)

        # Var count is used often, calculate
        # it once here
        cls.i_var_count = len(cls.i_vars)
        cls.d_var_count = len(cls.d_vars)

        if cls.i_var_count > PlotMode._i_var_max:
            raise ValueError(var_count_error(True, False))
        if cls.d_var_count > PlotMode._d_var_max:
            raise ValueError(var_count_error(False, False))

        # Try to use first alias as primary_alias
        if cls.aliases:
            cls.primary_alias = cls.aliases[0]
        else:
            cls.primary_alias = cls.__name__

        default_intervals = cls.intervals
        if len(default_intervals) != cls.i_var_count:
            raise ValueError("Plot mode must provide a "
                             "default interval for each i_var.")
        # Default intervals must be given as [min, max, steps]
        # (no var, but they must be in the same order as i_vars).
        # Use an explicit check rather than assert, which is
        # stripped when Python runs with -O.
        for triple in default_intervals:
            if len(triple) != 3:
                raise ValueError("Plot mode default intervals must be "
                                 "given as [min, max, steps].")
        # Initialize incomplete intervals, to later be filled with a
        # var when the mode is instantiated. Build a new list instead
        # of rewriting cls.intervals in place: the list object may be
        # inherited from, and shared with, a parent class.
        cls.intervals = [PlotInterval(None, *triple)
                         for triple in default_intervals]

        # To prevent people from using modes
        # without these required fields set up.
        cls._was_initialized = True

    _was_initialized = False

    ## Initializer Helper Methods

    @staticmethod
    def _find_i_vars(functions, intervals):
        """
        Return the list of independent variables for a plot.

        Variables named by the intervals come first, in interval
        order. They are followed by any other Symbols found in the
        functions, in the order f.atoms(Symbol) yields them.

        Raises ValueError if two intervals name the same variable.
        """
        i_vars = []

        # First, collect i_vars in the
        # order they are given in any
        # intervals.
        for interval in intervals:
            if interval.v is None:
                continue
            if interval.v in i_vars:
                raise ValueError(("Multiple intervals given "
                                  "for %s.") % (str(interval.v)))
            i_vars.append(interval.v)

        # Then, find any remaining
        # i_vars in given functions
        # (aka d_vars)
        for f in functions:
            for a in f.atoms(Symbol):
                if a not in i_vars:
                    i_vars.append(a)

        return i_vars

    def _fill_i_vars(self, i_vars):
        """
        Give this instance fresh Symbols named like the class's
        default i_vars, then overwrite the leading positions with
        the given i_vars.
        """
        # copy default i_vars
        self.i_vars = [Symbol(str(i)) for i in self.i_vars]
        # replace with given i_vars
        for index, var in enumerate(i_vars):
            self.i_vars[index] = var

    def _fill_intervals(self, intervals):
        """
        Give this instance its own copies of the class's default
        intervals, fill the leading ones from the given intervals
        (via PlotInterval.fill_from), then assign each interval that
        still has no variable the first i_var not yet used by
        another interval.
        """
        # copy default intervals
        self.intervals = [PlotInterval(i) for i in self.intervals]
        # track i_vars used so far
        v_used = []
        # fill copy of default
        # intervals with given info
        for index, given in enumerate(intervals):
            self.intervals[index].fill_from(given)
            if self.intervals[index].v is not None:
                v_used.append(self.intervals[index].v)
        # Find any orphan intervals and
        # assign them i_vars
        for interval in self.intervals:
            if interval.v is None:
                unused = [v for v in self.i_vars if v not in v_used]
                assert len(unused) != 0
                interval.v = unused[0]
                v_used.append(unused[0])

    @staticmethod
    def _interpret_args(args):
        """
        Split positional plot arguments into (functions, intervals).

        If the first argument is a GeometryEntity, the coordinates of
        its arbitrary_point() become the functions, and its
        plot_interval() becomes the single interval. In that case any
        further arguments are ignored.

        Otherwise, each argument is tried as a PlotInterval first.
        Intervals must follow at least one function. Any other
        argument is sympified into a function. Strings, lists and
        tuples that do not parse as intervals are rejected.

        Raises ValueError for arguments that cannot be interpreted,
        or for an interval given before any function.
        """
        interval_wrong_order = "PlotInterval %s was given before any function(s)."
        interpret_error = "Could not interpret %s as a function or interval."

        functions, intervals = [], []
        # Guard against empty args: without it, args[0] raises an opaque
        # IndexError. With the guard, the caller gets a ValueError later
        # saying that no mode fits zero variables.
        if args and isinstance(args[0], GeometryEntity):
            functions.extend(args[0].arbitrary_point())
            intervals.append(PlotInterval.try_parse(args[0].plot_interval()))
        else:
            for a in args:
                i = PlotInterval.try_parse(a)
                if i is not None:
                    if len(functions) == 0:
                        raise ValueError(interval_wrong_order % (str(i)))
                    intervals.append(i)
                    continue
                if isinstance(a, (str, list, tuple)):
                    raise ValueError(interpret_error % (str(a)))
                # Catch ordinary errors only, so that KeyboardInterrupt
                # and SystemExit are not turned into a ValueError.
                try:
                    f = sympify(a)
                except Exception:
                    raise ValueError(interpret_error % str(a))
                functions.append(f)

        return functions, intervals

    @staticmethod
    def _extract_options(args, kwargs):
        """
        Separate option strings from plot arguments.

        Each string in args is parsed with parse_option_string and
        merged into the options dict, with later strings overriding
        earlier ones. Explicit kwargs are merged last, so they override
        string options. Returns (non-string args list, options dict).
        """
        nkwargs, nargs = {}, []
        for a in args:
            if isinstance(a, str):
                nkwargs.update(parse_option_string(a))
            else:
                nargs.append(a)
        nkwargs.update(kwargs)
        return nargs, nkwargs
def var_count_error(is_independent, is_plotting):
    """
    Build the message for a too-many-variables error.

    is_plotting selects the context: True for plotting,
    False for registering plot modes. is_independent
    selects the kind of variable: True for independent,
    False for dependent. The limit quoted comes from
    PlotMode._i_var_max or PlotMode._d_var_max.
    """
    action = "Registering plot modes"
    if is_plotting:
        action = "Plotting"

    if not is_independent:
        limit, kind = PlotMode._d_var_max, "dependent"
    else:
        limit, kind = PlotMode._i_var_max, "independent"

    template = "%s with more than %i %s variables is not supported."
    return template % (action, limit, kind)