minor cleanups
[PyX/mjg.git] / pyx / graph / style.py
blob08c82cd190b1aca7169cf98de222e312d9a08546
1 #!/usr/bin/env python
2 # -*- coding: ISO-8859-1 -*-
5 # Copyright (C) 2002-2004 Jörg Lehmann <joergl@users.sourceforge.net>
6 # Copyright (C) 2003-2004 Michael Schindler <m-schindler@users.sourceforge.net>
7 # Copyright (C) 2002-2004 André Wobst <wobsta@users.sourceforge.net>
9 # This file is part of PyX (http://pyx.sourceforge.net/).
11 # PyX is free software; you can redistribute it and/or modify
12 # it under the terms of the GNU General Public License as published by
13 # the Free Software Foundation; either version 2 of the License, or
14 # (at your option) any later version.
16 # PyX is distributed in the hope that it will be useful,
17 # but WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 # GNU General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with PyX; if not, write to the Free Software
23 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
26 import re, math
27 from pyx import attr, deco, style, color, unit, canvas, path
28 from pyx import text as textmodule
class _style:
    """Base class of the graph styles.

    It defines the hooks a graph uses to drive a style and provides
    do-nothing defaults for all of them, so a concrete style only overrides
    what it needs.
    """

    def setdatapattern(self, graph, columns, pattern):
        """Take the first column whose name matches pattern out of columns.

        The first group of pattern must capture the full axis name, which is
        then looked up in graph.axes.  On a match the column is deleted from
        columns (modified in place) and (axis, index) is returned, index being
        the value columns stored for that name.  Returns None when no column
        matches; columns is left untouched in that case.
        """
        for datakey in list(columns.keys()):
            match = pattern.match(datakey)
            if not match:
                continue
            # XXX the first group must contain the full axisname
            index = columns[datakey]
            del columns[datakey]
            return graph.axes[match.group(1)], index
        return None

    def key_pt(self, c, x_pt, y_pt, width_pt, height_pt, styledata):
        """Draw the graph key for this style; the base class has none."""
        raise RuntimeError("style doesn't provide a key")

    def adjustaxes(self, points, columns, styledata):
        """Hook for extending the axis ranges by the data points (no-op here)."""
        return None

    def setdata(self, graph, columns, styledata):
        """Hook for claiming data columns; returns the unused columns (all of them here)."""
        return columns

    def selectstyle(self, selectindex, selecttotal, styledata):
        """Hook for selecting attributes of data set selectindex out of selecttotal (no-op here)."""
        return None

    def initdrawpoints(self, graph, styledata):
        """Hook called before the points of a data set are drawn (no-op here)."""
        return None

    def drawpoint(self, graph, styledata):
        """Hook called for every single data point (no-op here)."""
        return None

    def donedrawpoints(self, graph, styledata):
        """Hook called after all points of a data set have been drawn (no-op here)."""
        return None
class pointpos(_style):
    """Style computing the graph coordinates of the data points.

    For every axis of the graph one data column is claimed, named like the
    axis and optionally followed by an axis number (e.g. "x2").  While
    drawing, styledata.vpos receives the converted coordinates of the
    current point styledata.point.
    """

    def __init__(self, epsilon=1e-10):
        # tolerance when deciding whether a position lies inside the graph
        self.epsilon = epsilon

    def setdata(self, graph, columns, styledata):
        """Claim one column per graph axis and return the remaining columns.

        styledata.pointposaxisindex becomes a list of (axis, column index)
        pairs in the order of graph.axisnames.

        Raises ValueError when there is no column for one of the axes; such
        data could not be positioned anyway.
        """
        styledata.pointposaxisindex = []
        columns = columns.copy()
        for axisname in graph.axisnames:
            pattern = re.compile(r"(%s([2-9]|[1-9][0-9]+)?)$" % axisname)
            axisindex = self.setdatapattern(graph, columns, pattern)
            if axisindex is None:
                raise ValueError("missing column for axis name '%s'" % axisname)
            styledata.pointposaxisindex.append(axisindex)
        return columns

    def adjustaxes(self, points, columns, styledata):
        """Extend the range of every position axis by the values in points."""
        for axis, index in styledata.pointposaxisindex:
            axis.adjustrange(points, index)

    def drawpoint(self, graph, styledata):
        """Convert the current point into graph coordinates.

        Sets styledata.vpos (one entry per axis, None where the conversion
        failed), styledata.vposavailable (all conversions succeeded, the
        point may still be outside of the graph) and styledata.vposvalid
        (all coordinates within [-epsilon, 1+epsilon], i.e. inside the graph).
        """
        styledata.vpos = []
        styledata.vposavailable = 1
        styledata.vposvalid = 1
        for axis, index in styledata.pointposaxisindex:
            try:
                v = axis.convert(styledata.point[index])
            except (ArithmeticError, KeyError, ValueError, TypeError):
                styledata.vposavailable = 0
                styledata.vposvalid = 0
                styledata.vpos.append(None)
            else:
                if v < - self.epsilon or v > 1 + self.epsilon:
                    styledata.vposvalid = 0
                styledata.vpos.append(v)
class rangepos(_style):
    """Style analysing position and range (errorbar) data columns.

    For every graph axis the data may provide the columns "<axis>",
    "<axis>min", "<axis>max", "d<axis>", "d<axis>min" and "d<axis>max",
    the axis name optionally followed by an axis number (e.g. "x2").  The
    lower end of a range is given by at most one of "min", "d" and "dmin",
    the upper end by at most one of "max", "d" and "dmax"; the "d" variants
    are distances from the "<axis>" value.
    """

    # data names in the order they are analysed, with their column name patterns
    _datapatterns = [("x", r"(%s([2-9]|[1-9][0-9]+)?)$"),
                     ("min", r"(%s([2-9]|[1-9][0-9]+)?)min$"),
                     ("max", r"(%s([2-9]|[1-9][0-9]+)?)max$"),
                     ("d", r"d(%s([2-9]|[1-9][0-9]+)?)$"),
                     ("dmin", r"d(%s([2-9]|[1-9][0-9]+)?)min$"),
                     ("dmax", r"d(%s([2-9]|[1-9][0-9]+)?)max$")]

    # exceptions signalling a missing or unusable data value
    _invalidvalueerrors = (ArithmeticError, IndexError, KeyError, TypeError, ValueError)

    def __init__(self, epsilon=1e-10):
        # tolerance when deciding whether a position lies inside the graph
        self.epsilon = epsilon

    def setdata(self, graph, columns, styledata):
        """Analyse the column information.

        - the instance should be considered read-only
          (it might be shared between several data)
        - styledata is the place where to store information
        - returns the dictionary of columns not used by the style

        Raises ValueError for missing, conflicting or incomplete columns.
        """
        # nested index dictionary containing column numbers, e.g.
        # styledata.index["x"]["x"], styledata.index["y"]["dmin"] etc.; the
        # first key is an axis name (without the axis number), the second is
        # one of the data names "x", "min", "max", "d", "dmin" and "dmax"
        styledata.index = {}
        # mapping from axis name (without axis number) to the axis
        styledata.axes = {}
        columns = columns.copy()
        for axisname in graph.axisnames:
            for dataname, pattern in self._datapatterns:
                matchresult = self.setdatapattern(graph, columns, re.compile(pattern % axisname))
                if matchresult is None:
                    continue
                axis, index = matchresult
                if axisname in styledata.axes:
                    if styledata.axes[axisname] != axis:
                        raise ValueError("axis mismatch for axis name '%s'" % axisname)
                    styledata.index[axisname][dataname] = index
                else:
                    styledata.index[axisname] = {dataname: index}
                    styledata.axes[axisname] = axis
            if axisname not in styledata.axes:
                raise ValueError("missing columns for axis name '%s'" % axisname)
            axisindex = styledata.index[axisname]
            # each end of the range may be defined only once
            for group in [["min", "d", "dmin"], ["max", "d", "dmax"]]:
                if len([name for name in group if name in axisindex]) > 1:
                    raise ValueError("multiple errorbar definition for axis name '%s'" % axisname)
            # distances need the start value they are measured from
            if ("x" not in axisindex and
                ("d" in axisindex or "dmin" in axisindex or "dmax" in axisindex)):
                raise ValueError("errorbar definition start value missing for axis name '%s'" % axisname)
        return columns

    def adjustaxes(self, points, columns, styledata):
        """Extend the axis ranges by all values and ranges in points.

        Only axes having at least one of their columns listed in columns are
        adjusted.
        """
        # reverse lookup for axisnames
        # TODO: the reverse lookup is ugly
        axisnames = []
        for column in columns:
            for axisname in styledata.index.keys():
                if column in styledata.index[axisname].values() and axisname not in axisnames:
                    axisnames.append(axisname)
        # TODO: perform check to verify that all columns for a given axisname are available at the same time
        for axisname in axisnames:
            axis = styledata.axes[axisname]
            index = styledata.index[axisname]
            for dataname in ["x", "min", "max"]:
                if dataname in index:
                    axis.adjustrange(points, index[dataname])
            for dataname, keyword in [("d", "deltaindex"), ("dmin", "deltaminindex"), ("dmax", "deltamaxindex")]:
                if dataname in index:
                    axis.adjustrange(points, index["x"], **{keyword: index[dataname]})

    def doerrorbars(self, styledata):
        """Compute the errorbar geometry of the current point.

        Uses styledata.errorlist ((axis name, axis index) pairs), styledata.vpos
        and styledata.point.  Returns a list of tuples
        (erroraxisname, vminpos, vmaxpos, vcaps): the graph coordinates of
        both ends of each drawable errorbar and the list of those ends that
        should get a cap.  Errorbars with nothing to draw are left out.
        """
        errorbars = []
        for erroraxisname, erroraxisindex in styledata.errorlist:
            geometry = self._errorbargeometry(styledata, erroraxisname, erroraxisindex)
            if geometry is not None:
                errorbars.append((erroraxisname,) + geometry)
        return errorbars

    def _errorbarvalue(self, point, errorindex, deltaname, boundname, upper):
        """Return the data value of one errorbar end, or None if unavailable.

        The value is x+d/x-d when a "d" column exists, else x+deltaname/x-deltaname,
        else the explicit boundname column (upper selects the sign).
        """
        try:
            if "d" in errorindex:
                delta = point[errorindex["d"]]
            elif deltaname in errorindex:
                delta = point[errorindex[deltaname]]
            elif boundname in errorindex:
                return point[errorindex[boundname]]
            else:
                return None
            if upper:
                return point[errorindex["x"]] + delta
            return point[errorindex["x"]] - delta
        except self._invalidvalueerrors:
            return None

    def _convert(self, axis, value):
        """Convert value to a graph coordinate; None when value is None or invalid."""
        if value is None:
            return None
        try:
            return axis.convert(value)
        except self._invalidvalueerrors:
            return None

    def _errorbarend(self, vpos, erroraxisindex, v, capped, vcaps):
        """Return a copy of vpos with the errorbar component v clipped to the graph.

        The position is appended to vcaps when capped is set and v was not clipped.
        """
        endpos = vpos[:]
        if v < -self.epsilon:
            endpos[erroraxisindex] = 0
        elif v > 1 + self.epsilon:
            endpos[erroraxisindex] = 1
        else:
            endpos[erroraxisindex] = v
            if capped:
                vcaps.append(endpos)
        return endpos

    def _errorbargeometry(self, styledata, erroraxisname, erroraxisindex):
        """Return (vminpos, vmaxpos, vcaps) of one errorbar, or None if there is nothing to draw."""
        vpos = styledata.vpos
        # the point must be valid and inside the graph in all other directions
        for i in range(len(vpos)):
            if i != erroraxisindex:
                if vpos[i] is None or vpos[i] < -self.epsilon or vpos[i] > 1 + self.epsilon:
                    return None
        errorindex = styledata.index[erroraxisname]
        axis = styledata.axes[erroraxisname]
        vmin = self._convert(axis, self._errorbarvalue(styledata.point, errorindex, "dmin", "min", 0))
        vmax = self._convert(axis, self._errorbarvalue(styledata.point, errorindex, "dmax", "max", 1))
        if vmin is None and vmax is None:
            return None
        # a missing end is replaced by the point itself
        vlower = vmin
        if vlower is None:
            vlower = vpos[erroraxisindex]
        vupper = vmax
        if vupper is None:
            vupper = vpos[erroraxisindex]
        if vlower is None or vupper is None:
            return None
        if ((vlower < -self.epsilon and vupper < -self.epsilon) or
            (vlower > 1 + self.epsilon and vupper > 1 + self.epsilon)):
            # the whole errorbar is outside of the graph
            return None
        vcaps = []
        vminpos = self._errorbarend(vpos, erroraxisindex, vlower, vmin is not None, vcaps)
        vmaxpos = self._errorbarend(vpos, erroraxisindex, vupper, vmax is not None, vcaps)
        return vminpos, vmaxpos, vcaps
def _crosssymbol(c, x_pt, y_pt, size_pt, attrs):
    """Draw an "x" centred at (x_pt, y_pt) on canvas c.

    The two strokes are the diagonals of a square with side length size_pt.
    """
    half_pt = 0.5*size_pt
    left_pt = x_pt-half_pt
    right_pt = x_pt+half_pt
    bottom_pt = y_pt-half_pt
    top_pt = y_pt+half_pt
    c.draw(path.path(path.moveto_pt(left_pt, bottom_pt),
                     path.lineto_pt(right_pt, top_pt),
                     path.moveto_pt(left_pt, top_pt),
                     path.lineto_pt(right_pt, bottom_pt)), attrs)
def _plussymbol(c, x_pt, y_pt, size_pt, attrs):
    """Draw a "+" centred at (x_pt, y_pt) on canvas c.

    0.707106781 is 1/sqrt(2), so each stroke is sqrt(2)*size_pt long, the
    same length as the strokes of the cross symbol.
    """
    arm_pt = 0.707106781*size_pt
    c.draw(path.path(path.moveto_pt(x_pt-arm_pt, y_pt),
                     path.lineto_pt(x_pt+arm_pt, y_pt),
                     path.moveto_pt(x_pt, y_pt-arm_pt),
                     path.lineto_pt(x_pt, y_pt+arm_pt)), attrs)
def _squaresymbol(c, x_pt, y_pt, size_pt, attrs):
    """Draw a closed square of side length size_pt centred at (x_pt, y_pt)."""
    half_pt = 0.5*size_pt
    left_pt = x_pt-half_pt
    right_pt = x_pt+half_pt
    bottom_pt = y_pt-half_pt
    top_pt = y_pt+half_pt
    c.draw(path.path(path.moveto_pt(left_pt, bottom_pt),
                     path.lineto_pt(right_pt, bottom_pt),
                     path.lineto_pt(right_pt, top_pt),
                     path.lineto_pt(left_pt, top_pt),
                     path.closepath()), attrs)
def _trianglesymbol(c, x_pt, y_pt, size_pt, attrs):
    """Draw an upward equilateral triangle centred (at its centroid) on (x_pt, y_pt).

    The constants give a side length of 2*0.759835685*size_pt, i.e. an
    area of size_pt**2 like the square symbol; the base lies at 1/3 and the
    apex at 2/3 of the triangle height from the centre.
    """
    halfside_pt = 0.759835685*size_pt
    base_pt = y_pt-0.438691337*size_pt
    apex_pt = y_pt+0.877382675*size_pt
    c.draw(path.path(path.moveto_pt(x_pt-halfside_pt, base_pt),
                     path.lineto_pt(x_pt+halfside_pt, base_pt),
                     path.lineto_pt(x_pt, apex_pt),
                     path.closepath()), attrs)
def _circlesymbol(c, x_pt, y_pt, size_pt, attrs):
    """Draw a circle centred at (x_pt, y_pt).

    The radius 0.564189583*size_pt equals size_pt/sqrt(pi), giving an area
    of size_pt**2 like the square symbol.
    """
    radius_pt = 0.564189583*size_pt
    circle = path.path(path.arc_pt(x_pt, y_pt, radius_pt, 0, 360),
                       path.closepath())
    c.draw(circle, attrs)
def _diamondsymbol(c, x_pt, y_pt, size_pt, attrs):
    """Draw a diamond (rhombus) centred at (x_pt, y_pt).

    The half diagonals 0.537284965*size_pt and 0.930604859*size_pt give an
    area of size_pt**2; their ratio 1:sqrt(3) makes the rhombus two
    equilateral triangles joined along the horizontal diagonal.
    """
    halfwidth_pt = 0.537284965*size_pt
    halfheight_pt = 0.930604859*size_pt
    c.draw(path.path(path.moveto_pt(x_pt-halfwidth_pt, y_pt),
                     path.lineto_pt(x_pt, y_pt-halfheight_pt),
                     path.lineto_pt(x_pt+halfwidth_pt, y_pt),
                     path.lineto_pt(x_pt, y_pt+halfheight_pt),
                     path.closepath()), attrs)
class symbol(_style):
    """Style drawing a symbol at every data point lying inside the graph."""

    # In Python 2 a plain function stored as a class attribute becomes an
    # unbound method when accessed through the class (symbol.cross), which
    # cannot be called with a canvas as its first argument; staticmethod
    # keeps these attributes plain functions.
    cross = staticmethod(_crosssymbol)
    plus = staticmethod(_plussymbol)
    square = staticmethod(_squaresymbol)
    triangle = staticmethod(_trianglesymbol)
    circle = staticmethod(_circlesymbol)
    diamond = staticmethod(_diamondsymbol)

    # the change lists hold the module-level functions themselves, since a
    # staticmethod object is not callable on its own (before Python 3.10)
    changecross = attr.changelist([_crosssymbol, _plussymbol, _squaresymbol, _trianglesymbol, _circlesymbol, _diamondsymbol])
    changeplus = attr.changelist([_plussymbol, _squaresymbol, _trianglesymbol, _circlesymbol, _diamondsymbol, _crosssymbol])
    changesquare = attr.changelist([_squaresymbol, _trianglesymbol, _circlesymbol, _diamondsymbol, _crosssymbol, _plussymbol])
    changetriangle = attr.changelist([_trianglesymbol, _circlesymbol, _diamondsymbol, _crosssymbol, _plussymbol, _squaresymbol])
    changecircle = attr.changelist([_circlesymbol, _diamondsymbol, _crosssymbol, _plussymbol, _squaresymbol, _trianglesymbol])
    changediamond = attr.changelist([_diamondsymbol, _crosssymbol, _plussymbol, _squaresymbol, _trianglesymbol, _circlesymbol])
    changesquaretwice = attr.changelist([_squaresymbol, _squaresymbol, _trianglesymbol, _trianglesymbol, _circlesymbol, _circlesymbol, _diamondsymbol, _diamondsymbol])
    changetriangletwice = attr.changelist([_trianglesymbol, _trianglesymbol, _circlesymbol, _circlesymbol, _diamondsymbol, _diamondsymbol, _squaresymbol, _squaresymbol])
    changecircletwice = attr.changelist([_circlesymbol, _circlesymbol, _diamondsymbol, _diamondsymbol, _squaresymbol, _squaresymbol, _trianglesymbol, _trianglesymbol])
    changediamondtwice = attr.changelist([_diamondsymbol, _diamondsymbol, _squaresymbol, _squaresymbol, _trianglesymbol, _trianglesymbol, _circlesymbol, _circlesymbol])

    changestrokedfilled = attr.changelist([deco.stroked, deco.filled])
    changefilledstroked = attr.changelist([deco.filled, deco.stroked])

    defaultsymbolattrs = [deco.stroked]

    def __init__(self, symbol=changecross, size=0.2*unit.v_cm, symbolattrs=[]):
        # symbolattrs=None disables drawing the symbols (and the key)
        self.symbol = symbol
        self.size = size
        self.symbolattrs = symbolattrs

    def selectstyle(self, selectindex, selecttotal, styledata):
        """Select symbol, size and attributes for data set selectindex of selecttotal."""
        styledata.symbol = attr.selectattr(self.symbol, selectindex, selecttotal)
        styledata.size_pt = unit.topt(attr.selectattr(self.size, selectindex, selecttotal))
        if self.symbolattrs is not None:
            styledata.symbolattrs = attr.selectattrs(self.defaultsymbolattrs + self.symbolattrs, selectindex, selecttotal)
        else:
            styledata.symbolattrs = None

    def drawpoint(self, graph, styledata):
        """Draw the symbol at the current point if it lies inside the graph."""
        if styledata.vposvalid and styledata.symbolattrs is not None:
            xpos, ypos = graph.vpos_pt(*styledata.vpos)
            styledata.symbol(graph, xpos, ypos, styledata.size_pt, styledata.symbolattrs)

    def key_pt(self, graph, x_pt, y_pt, width_pt, height_pt, styledata):
        """Draw the symbol in the middle of the key box."""
        if styledata.symbolattrs is not None:
            styledata.symbol(graph, x_pt+0.5*width_pt, y_pt+0.5*height_pt, styledata.size_pt, styledata.symbolattrs)
class line(_style):
    """Style connecting consecutive data points by straight lines.

    Segments crossing the border of the graph (graph coordinates 0 and 1)
    are cut at the border; the line is interrupted where it leaves the graph
    and at points without a valid position.
    """

    changelinestyle = attr.changelist([style.linestyle.solid,
                                       style.linestyle.dashed,
                                       style.linestyle.dotted,
                                       style.linestyle.dashdotted])

    defaultlineattrs = [changelinestyle]

    def __init__(self, lineattrs=[]):
        # lineattrs=None disables drawing the line (and the key)
        self.lineattrs = lineattrs

    def selectstyle(self, selectindex, selecttotal, styledata):
        """Select the line attributes for data set selectindex of selecttotal."""
        if self.lineattrs is not None:
            styledata.lineattrs = attr.selectattrs(self.defaultlineattrs + self.lineattrs, selectindex, selecttotal)
        else:
            styledata.lineattrs = None

    def initdrawpoints(self, graph, styledata):
        """Prepare canvas, path and point buffer for a new data set."""
        styledata.linecanvas = graph.insert(canvas.canvas())
        styledata.path = path.path()
        # pending points (in pt) of the current run of the line
        styledata.linebasepoints = []
        # graph coordinates of the previous point, None after an invalid point
        styledata.lastvpos = None

    def _vposat(self, vstartpos, vendpos, cut):
        """Return the graph coordinates at parameter cut on the segment vstartpos -> vendpos."""
        return [vstart + (vend - vstart) * cut for vstart, vend in zip(vstartpos, vendpos)]

    def appendlinebasepoints(self, graph, styledata):
        """Add the current point to styledata.linebasepoints, cutting at the graph border.

        Uses styledata.vpos, styledata.vposavailable and styledata.vposvalid
        and the previous position styledata.lastvpos.  Returns 1 when the
        current run of points is finished and must be added to the path by
        addpointstopath (the line left the graph or hit a point without a
        valid position), 0 otherwise.
        """
        if not styledata.vposavailable:
            # no position for this point: interrupt the line here
            styledata.lastvpos = None
            return 1

        lastvpos = styledata.lastvpos
        vpos = styledata.vpos
        styledata.lastvpos = vpos

        if styledata.linebasepoints:
            # the last point was inside the graph
            if styledata.vposvalid:
                # shortcut for the common case
                styledata.linebasepoints.append(graph.vpos_pt(*vpos))
                return 0
            # the line leaves the graph: cut the end of the segment
            cut = 1
            for vstart, vend in zip(lastvpos, vpos):
                try:
                    if vend > 1:
                        # 1 = vstart + (vend - vstart) * cut
                        newcut = (1 - vstart)/(vend - vstart)
                    elif vend < 0:
                        # 0 = vstart + (vend - vstart) * cut
                        newcut = - vstart/(vend - vstart)
                    else:
                        newcut = None
                except ArithmeticError:
                    break
                if newcut is not None and newcut < cut:
                    cut = newcut
            else:
                styledata.linebasepoints.append(graph.vpos_pt(*self._vposat(lastvpos, vpos, cut)))
            return 1

        if lastvpos is None:
            # first point after the start or an interruption: the segment
            # starting here is handled together with the next point
            return 0

        if styledata.vposvalid:
            # the line enters the graph: cut the beginning of the segment
            cut = 0
            for vstart, vend in zip(lastvpos, vpos):
                try:
                    if vstart > 1:
                        # 1 = vstart + (vend - vstart) * cut
                        newcut = (1 - vstart)/(vend - vstart)
                    elif vstart < 0:
                        # 0 = vstart + (vend - vstart) * cut
                        newcut = - vstart/(vend - vstart)
                    else:
                        newcut = None
                except ArithmeticError:
                    break
                if newcut is not None and newcut > cut:
                    cut = newcut
            else:
                styledata.linebasepoints.append(graph.vpos_pt(*self._vposat(lastvpos, vpos, cut)))
            styledata.linebasepoints.append(graph.vpos_pt(*vpos))
            return 0

        # the current point is outside; the segment may still cross the graph
        cutfrom = 0
        cutto = 1
        for vstart, vend in zip(lastvpos, vpos):
            if (vstart > 1 and vend > 1) or (vstart < 0 and vend < 0):
                # the segment passes the graph on one side
                break
            try:
                newcutfrom = None
                if vstart > 1:
                    # 1 = vstart + (vend - vstart) * cutfrom
                    newcutfrom = (1 - vstart)/(vend - vstart)
                elif vstart < 0:
                    # 0 = vstart + (vend - vstart) * cutfrom
                    newcutfrom = - vstart/(vend - vstart)
                newcutto = None
                if vend > 1:
                    # 1 = vstart + (vend - vstart) * cutto
                    newcutto = (1 - vstart)/(vend - vstart)
                elif vend < 0:
                    # 0 = vstart + (vend - vstart) * cutto
                    newcutto = - vstart/(vend - vstart)
            except ArithmeticError:
                break
            if newcutfrom is not None and newcutfrom > cutfrom:
                cutfrom = newcutfrom
            if newcutto is not None and newcutto < cutto:
                cutto = newcutto
        else:
            if cutfrom < cutto:
                # the visible part of the segment forms a run of its own
                styledata.linebasepoints.append(graph.vpos_pt(*self._vposat(lastvpos, vpos, cutfrom)))
                styledata.linebasepoints.append(graph.vpos_pt(*self._vposat(lastvpos, vpos, cutto)))
                return 1
        return 0

    def addpointstopath(self, styledata):
        """Add the pending points to styledata.path as one polyline and clear them.

        A single pending point cannot form a line and is dropped.
        """
        points = styledata.linebasepoints
        if len(points) > 1:
            styledata.path.append(path.moveto_pt(*points[0]))
            if len(points) > 2:
                styledata.path.append(path.multilineto_pt(points[1:]))
            else:
                styledata.path.append(path.lineto_pt(*points[1]))
        styledata.linebasepoints = []

    def drawpoint(self, graph, styledata):
        """Process the current point, flushing finished runs into the path."""
        if self.appendlinebasepoints(graph, styledata):
            self.addpointstopath(styledata)

    def donedrawpoints(self, graph, styledata):
        """Flush the last run and stroke the collected path."""
        self.addpointstopath(styledata)
        if styledata.lineattrs is not None and len(styledata.path.path):
            styledata.linecanvas.stroke(styledata.path, styledata.lineattrs)

    def key_pt(self, c, x_pt, y_pt, width_pt, height_pt, styledata):
        """Draw a horizontal line through the middle of the key box."""
        if styledata.lineattrs is not None:
            c.stroke(path.line_pt(x_pt, y_pt+0.5*height_pt, x_pt+width_pt, y_pt+0.5*height_pt), styledata.lineattrs)
class errorbars(_style):
    """Style stroking errorbars, with caps at ends lying inside the graph.

    It relies on information stored in styledata by other styles:
    styledata.index and styledata.axes (as created by rangepos.setdata),
    styledata.vpos (as computed by pointpos.drawpoint) and the current data
    point styledata.point.
    """

    defaulterrorbarattrs = []

    # exceptions signalling a missing or unusable data value
    _invalidvalueerrors = (ArithmeticError, IndexError, KeyError, TypeError, ValueError)

    def __init__(self, size=0.1*unit.v_cm,
                       errorbarattrs=[],
                       epsilon=1e-10):
        # size: length of the caps; errorbarattrs=None disables the errorbars
        self.size = size
        self.errorbarattrs = errorbarattrs
        # tolerance when deciding whether a position lies inside the graph
        self.epsilon = epsilon

    def selectstyle(self, selectindex, selecttotal, styledata):
        """Select cap size and attributes for data set selectindex of selecttotal."""
        styledata.errorsize_pt = unit.topt(attr.selectattr(self.size, selectindex, selecttotal))
        if self.errorbarattrs is not None:
            styledata.errorbarattrs = attr.selectattrs(self.defaulterrorbarattrs + self.errorbarattrs, selectindex, selecttotal)
        else:
            styledata.errorbarattrs = None

    def initdrawpoints(self, graph, styledata):
        """Create the errorbar canvas and find the axes having range columns.

        styledata.errorlist becomes a list of (axis name, axis index) pairs
        for all axes with columns besides the plain "x" position column.
        """
        styledata.errorbarcanvas = graph.insert(canvas.canvas())
        styledata.errorlist = []
        if styledata.errorbarattrs is not None:
            for axisindex in range(len(graph.axisnames)):
                axisname = graph.axisnames[axisindex]
                if list(styledata.index[axisname].keys()) != ["x"]:
                    styledata.errorlist.append((axisname, axisindex))

    def _errorbarvalue(self, point, errorindex, deltaname, boundname, upper):
        """Return the data value of one errorbar end, or None if unavailable.

        The value is x+d/x-d when a "d" column exists, else x+deltaname/x-deltaname,
        else the explicit boundname column (upper selects the sign).
        """
        try:
            if "d" in errorindex:
                delta = point[errorindex["d"]]
            elif deltaname in errorindex:
                delta = point[errorindex[deltaname]]
            elif boundname in errorindex:
                return point[errorindex[boundname]]
            else:
                return None
            if upper:
                return point[errorindex["x"]] + delta
            return point[errorindex["x"]] - delta
        except self._invalidvalueerrors:
            return None

    def _convert(self, axis, value):
        """Convert value to a graph coordinate; None when value is None or invalid."""
        if value is None:
            return None
        try:
            return axis.convert(value)
        except self._invalidvalueerrors:
            return None

    def _errorbarend(self, vpos, erroraxisindex, v, capped, vcaps):
        """Return a copy of vpos with the errorbar component v clipped to the graph.

        The position is appended to vcaps when capped is set and v was not clipped.
        """
        endpos = vpos[:]
        if v < -self.epsilon:
            endpos[erroraxisindex] = 0
        elif v > 1 + self.epsilon:
            endpos[erroraxisindex] = 1
        else:
            endpos[erroraxisindex] = v
            if capped:
                vcaps.append(endpos)
        return endpos

    def _errorbargeometry(self, styledata, erroraxisname, erroraxisindex):
        """Return (vminpos, vmaxpos, vcaps) of one errorbar, or None if there is nothing to draw."""
        vpos = styledata.vpos
        # the point must be valid and inside the graph in all other directions
        for i in range(len(vpos)):
            if i != erroraxisindex:
                if vpos[i] is None or vpos[i] < -self.epsilon or vpos[i] > 1 + self.epsilon:
                    return None
        errorindex = styledata.index[erroraxisname]
        axis = styledata.axes[erroraxisname]
        vmin = self._convert(axis, self._errorbarvalue(styledata.point, errorindex, "dmin", "min", 0))
        vmax = self._convert(axis, self._errorbarvalue(styledata.point, errorindex, "dmax", "max", 1))
        if vmin is None and vmax is None:
            return None
        # a missing end is replaced by the point itself
        vlower = vmin
        if vlower is None:
            vlower = vpos[erroraxisindex]
        vupper = vmax
        if vupper is None:
            vupper = vpos[erroraxisindex]
        if vlower is None or vupper is None:
            return None
        if ((vlower < -self.epsilon and vupper < -self.epsilon) or
            (vlower > 1 + self.epsilon and vupper > 1 + self.epsilon)):
            # the whole errorbar is outside of the graph
            return None
        vcaps = []
        vminpos = self._errorbarend(vpos, erroraxisindex, vlower, vmin is not None, vcaps)
        vmaxpos = self._errorbarend(vpos, erroraxisindex, vupper, vmax is not None, vcaps)
        return vminpos, vmaxpos, vcaps

    def doerrorbars(self, styledata, graph=None):
        """Stroke the errorbars of the current point.

        graph defaults to styledata.graph for callers using the old
        single-argument form.
        """
        if graph is None:
            graph = styledata.graph
        # errorbar loop over the different directions having errorbars
        for erroraxisname, erroraxisindex in styledata.errorlist:
            geometry = self._errorbargeometry(styledata, erroraxisname, erroraxisindex)
            if geometry is None:
                continue
            vminpos, vmaxpos, vcaps = geometry
            errorpath = path.path()
            errorpath += graph.vgeodesic(*(vminpos + vmaxpos))
            for vcap in vcaps:
                for axisname in graph.axisnames:
                    if axisname != erroraxisname:
                        errorpath += graph.vcap_pt(axisname, styledata.errorsize_pt, *vcap)
            if len(errorpath.path):
                styledata.errorbarcanvas.stroke(errorpath, styledata.errorbarattrs)

    def drawpoint(self, graph, styledata):
        """Draw the errorbars of the current point."""
        self.doerrorbars(styledata, graph)
class text(symbol):
    """Symbol style additionally writing a text label at each point.

    The label is str() of the value in the data column "text"; it is placed
    at an offset (textdx, textdy) from the symbol position.
    """

    defaulttextattrs = [textmodule.halign.center, textmodule.vshift.mathaxis]

    def __init__(self, textdx=0*unit.v_cm, textdy=0.3*unit.v_cm, textattrs=[], **kwargs):
        # textattrs=None disables the labels; remaining kwargs go to symbol
        self.textdx = textdx
        self.textdy = textdy
        self.textattrs = textattrs
        symbol.__init__(self, **kwargs)

    def setdata(self, graph, columns, styledata):
        """Claim the "text" column (KeyError if absent) and return the remaining columns."""
        columns = columns.copy()
        styledata.textindex = columns["text"]
        del columns["text"]
        return symbol.setdata(self, graph, columns, styledata)

    def selectstyle(self, selectindex, selecttotal, styledata):
        """Select text attributes and offsets, then the symbol attributes."""
        if self.textattrs is not None:
            styledata.textattrs = attr.selectattrs(self.defaulttextattrs + self.textattrs, selectindex, selecttotal)
        else:
            styledata.textattrs = None
        styledata.textdx_pt = unit.topt(self.textdx)
        styledata.textdy_pt = unit.topt(self.textdy)
        symbol.selectstyle(self, selectindex, selecttotal, styledata)

    def drawsymbol_pt(self, c, x, y, styledata, point=None):
        """Draw symbol and label of point at (x, y) (in pt) on canvas c."""
        if None in (x, y):
            return
        if styledata.symbolattrs is not None:
            styledata.symbol(c, x, y, styledata.size_pt, styledata.symbolattrs)
        if (point is not None and styledata.textattrs is not None and
            point[styledata.textindex] is not None):
            c.text_pt(x + styledata.textdx_pt, y + styledata.textdy_pt, str(point[styledata.textindex]), styledata.textattrs)

    def drawpoint(self, graph, styledata):
        """Draw symbol and label of the current point if it lies inside the graph."""
        if styledata.vposvalid:
            x_pt, y_pt = graph.vpos_pt(*styledata.vpos)
            self.drawsymbol_pt(graph, x_pt, y_pt, styledata, styledata.point)
class arrow(_style):
    """Style drawing an arrow at every data point of a two-dimensional graph.

    Data columns: one per graph axis (the position), "size" (scales line
    length and arrow head; points with size <= epsilon are skipped) and
    "angle" (direction in degrees).
    """

    defaultlineattrs = []
    defaultarrowattrs = []

    def __init__(self, linelength=0.25*unit.v_cm, arrowsize=0.15*unit.v_cm, lineattrs=[], arrowattrs=[], epsilon=1e-10):
        self.linelength = linelength
        self.arrowsize = arrowsize
        self.lineattrs = lineattrs
        self.arrowattrs = arrowattrs
        self.epsilon = epsilon

    def setdata(self, graph, columns, styledata):
        """Claim the position, "size" and "angle" columns; return the rest."""
        if len(graph.axisnames) != 2:
            raise TypeError("arrow style restricted on two-dimensional graphs")
        columns = columns.copy()
        positionpattern = r"(%s([2-9]|[1-9][0-9]+)?)$"
        styledata.xaxis, styledata.xindex = _style.setdatapattern(self, graph, columns, re.compile(positionpattern % graph.axisnames[0]))
        styledata.yaxis, styledata.yindex = _style.setdatapattern(self, graph, columns, re.compile(positionpattern % graph.axisnames[1]))
        styledata.sizeindex = columns["size"]
        del columns["size"]
        styledata.angleindex = columns["angle"]
        del columns["angle"]
        return columns

    def adjustaxes(self, points, columns, styledata):
        """Extend the position axes by the values of points (for listed columns only)."""
        for axis, index in [(styledata.xaxis, styledata.xindex),
                            (styledata.yaxis, styledata.yindex)]:
            if index in columns:
                axis.adjustrange(points, index)

    def selectstyle(self, selectindex, selecttotal, styledata):
        """Select line and arrow attributes; None disables the arrows."""
        lineattrs = None
        if self.lineattrs is not None:
            lineattrs = attr.selectattrs(self.defaultlineattrs + self.lineattrs, selectindex, selecttotal)
        styledata.lineattrs = lineattrs
        arrowattrs = None
        if self.arrowattrs is not None:
            arrowattrs = attr.selectattrs(self.defaultarrowattrs + self.arrowattrs, selectindex, selecttotal)
        styledata.arrowattrs = arrowattrs

    def drawpoints(self, points, graph, styledata):
        """Stroke one arrow, centred at the point position, per data point."""
        if styledata.lineattrs is None or styledata.arrowattrs is None:
            return
        linelength_pt = unit.topt(self.linelength)
        for point in points:
            xpos, ypos = graph.pos_pt(point[styledata.xindex], point[styledata.yindex], xaxis=styledata.xaxis, yaxis=styledata.yaxis)
            size = point[styledata.sizeindex]
            if size > self.epsilon:
                angle = point[styledata.angleindex]*math.pi/180
                halfdx = 0.5*math.cos(angle)*linelength_pt*size
                halfdy = 0.5*math.sin(angle)*linelength_pt*size
                graph.stroke(path.line_pt(xpos-halfdx, ypos-halfdy, xpos+halfdx, ypos+halfdy),
                             styledata.lineattrs + [deco.earrow(styledata.arrowattrs, size=self.arrowsize*size)])
class rect(_style):
    """Style filling palette-colored rectangles on a two-dimensional graph.

    Data columns: "<x>min", "<x>max", "<y>min" and "<y>max" (named after the
    two graph axes, optionally with an axis number) and "color", whose
    value is passed to palette.getcolor.
    """

    # exceptions signalling a missing or unusable data value
    _invalidvalueerrors = (ArithmeticError, IndexError, KeyError, TypeError, ValueError)

    def __init__(self, palette=color.palette.Gray):
        self.palette = palette

    def setdata(self, graph, columns, styledata):
        """Claim the min/max columns of both axes and the "color" column; return the rest."""
        if len(graph.axisnames) != 2:
            raise TypeError("rect style restricted on two-dimensional graphs")
        columns = columns.copy()
        minpattern = r"(%s([2-9]|[1-9][0-9]+)?)min$"
        maxpattern = r"(%s([2-9]|[1-9][0-9]+)?)max$"
        styledata.xaxis, styledata.xminindex = _style.setdatapattern(self, graph, columns, re.compile(minpattern % graph.axisnames[0]))
        styledata.yaxis, styledata.yminindex = _style.setdatapattern(self, graph, columns, re.compile(minpattern % graph.axisnames[1]))
        xaxis, styledata.xmaxindex = _style.setdatapattern(self, graph, columns, re.compile(maxpattern % graph.axisnames[0]))
        yaxis, styledata.ymaxindex = _style.setdatapattern(self, graph, columns, re.compile(maxpattern % graph.axisnames[1]))
        if xaxis != styledata.xaxis or yaxis != styledata.yaxis:
            raise ValueError("min/max values should use the same axes")
        styledata.colorindex = columns["color"]
        del columns["color"]
        return columns

    def selectstyle(self, selectindex, selecttotal, styledata):
        """Nothing to select: the colors come from the palette."""
        pass

    def adjustaxes(self, points, columns, styledata):
        """Extend both axes by the min/max values of points (for listed columns only)."""
        for axis, index in [(styledata.xaxis, styledata.xminindex),
                            (styledata.xaxis, styledata.xmaxindex),
                            (styledata.yaxis, styledata.yminindex),
                            (styledata.yaxis, styledata.ymaxindex)]:
            if index in columns:
                axis.adjustrange(points, index)

    def _clip(self, v):
        """Clip the graph coordinate v to the interval [0, 1]."""
        if v < 0:
            return 0
        if v > 1:
            return 1
        return v

    def drawpoints(self, points, graph, styledata):
        """Fill one rectangle per data point, clipped to the graph.

        Points with unusable values are skipped, as are rectangles lying
        completely outside of the graph.  The fill color is only looked up
        and set when the color value changes between points.
        """
        # TODO: bbox shortcut
        c = graph.insert(canvas.canvas())
        lastcolorvalue = None
        for point in points:
            try:
                xvmin = styledata.xaxis.convert(point[styledata.xminindex])
                xvmax = styledata.xaxis.convert(point[styledata.xmaxindex])
                yvmin = styledata.yaxis.convert(point[styledata.yminindex])
                yvmax = styledata.yaxis.convert(point[styledata.ymaxindex])
                colorvalue = point[styledata.colorindex]
                if colorvalue != lastcolorvalue:
                    fillcolor = self.palette.getcolor(colorvalue)
            except self._invalidvalueerrors:
                continue
            if ((xvmin < 0 and xvmax < 0) or (xvmin > 1 and xvmax > 1) or
                (yvmin < 0 and yvmax < 0) or (yvmin > 1 and yvmax > 1)):
                continue
            xvmin = self._clip(xvmin)
            xvmax = self._clip(xvmax)
            yvmin = self._clip(yvmin)
            yvmax = self._clip(yvmax)
            p = graph.vgeodesic(xvmin, yvmin, xvmax, yvmin)
            p.append(graph.vgeodesic_el(xvmax, yvmin, xvmax, yvmax))
            p.append(graph.vgeodesic_el(xvmax, yvmax, xvmin, yvmax))
            p.append(graph.vgeodesic_el(xvmin, yvmax, xvmin, yvmin))
            p.append(path.closepath())
            if colorvalue != lastcolorvalue:
                c.set([fillcolor])
                # remember the color set on the canvas to avoid redundant lookups
                lastcolorvalue = colorvalue
            c.fill(p)
class bar(_style):
    """Style drawing (optionally stacked) bars on a two-dimensional graph.

    One graph axis carries the values, the other one the bar names.  The
    value column is named after the value axis (optionally with an axis
    number), the name column after the other axis followed by "name", and
    stacked values after the value axis followed by "stack1", "stack2", ...
    """

    defaultfrompathattrs = []
    defaultbarattrs = [color.palette.Rainbow, deco.stroked([color.gray.black])]

    # exceptions signalling a missing or unusable data value
    _invalidvalueerrors = (ArithmeticError, IndexError, KeyError, TypeError, ValueError)

    def __init__(self, fromvalue=None, frompathattrs=[], barattrs=[], subnames=None, epsilon=1e-10):
        # fromvalue: value the bars start at (and where a base line is
        # stroked); None starts them at the border of the graph
        self.fromvalue = fromvalue
        self.frompathattrs = frompathattrs
        # barattrs=None disables filling the bars
        self.barattrs = barattrs
        self.subnames = subnames
        self.epsilon = epsilon

    def setdata(self, graph, columns, styledata):
        """Claim value, name and stacked value columns; return the rest.

        Raises TypeError unless the graph is two-dimensional with exactly one
        value axis, ValueError for a missing name column or stacked values
        on a different axis.
        """
        # TODO: remove limitation to 2d graphs
        if len(graph.axisnames) != 2:
            raise TypeError("bar style currently restricted on two-dimensional graphs")
        columns = columns.copy()
        valuepattern = r"(%s([2-9]|[1-9][0-9]+)?)$"
        xvalue = _style.setdatapattern(self, graph, columns, re.compile(valuepattern % graph.axisnames[0]))
        yvalue = _style.setdatapattern(self, graph, columns, re.compile(valuepattern % graph.axisnames[1]))
        if (xvalue is None and yvalue is None) or (xvalue is not None and yvalue is not None):
            raise TypeError("must specify exactly one value axis")
        if xvalue is not None:
            styledata.valuepos = 0
            valueaxis, valueindex = xvalue
        else:
            styledata.valuepos = 1
            valueaxis, valueindex = yvalue
        nameaxisname = graph.axisnames[1 - styledata.valuepos]
        namevalue = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)name$" % nameaxisname))
        if namevalue is None:
            raise ValueError("missing name column for axis name '%s'" % nameaxisname)
        styledata.nameaxis, styledata.nameindex = namevalue
        styledata.valueaxis = valueaxis
        styledata.valueindices = [valueindex]
        i = 1
        while 1:
            stackvalue = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)stack%i$" % (graph.axisnames[styledata.valuepos], i)))
            if stackvalue is None:
                break
            if styledata.valueaxis != stackvalue[0]:
                raise ValueError("different value axes for stacked bars")
            styledata.valueindices.append(stackvalue[1])
            i += 1
        return columns

    def selectstyle(self, selectindex, selecttotal, styledata):
        """Select attributes for data set selectindex of selecttotal.

        The base line is only drawn for the first data set.  Raises
        ValueError when subnames are combined with several data sets.
        """
        if selectindex or self.frompathattrs is None:
            styledata.frompathattrs = None
        else:
            styledata.frompathattrs = self.defaultfrompathattrs + self.frompathattrs
        if self.barattrs is None:
            styledata.barattrs = None
        elif selecttotal > 1:
            styledata.barattrs = attr.selectattrs(self.defaultbarattrs + self.barattrs, selectindex, selecttotal)
        else:
            styledata.barattrs = self.defaultbarattrs + self.barattrs
        styledata.selectindex = selectindex
        styledata.selecttotal = selecttotal
        if styledata.selecttotal != 1 and self.subnames is not None:
            raise ValueError("subnames not allowed when iterating over bars")

    def adjustaxes(self, points, columns, styledata):
        """Register names and extend the value axis (for listed columns only)."""
        if styledata.nameindex in columns:
            if styledata.selecttotal == 1:
                styledata.nameaxis.adjustrange(points, styledata.nameindex, subnames=self.subnames)
            else:
                for i in range(styledata.selecttotal):
                    styledata.nameaxis.adjustrange(points, styledata.nameindex, subnames=[i])
        for valueindex in styledata.valueindices:
            if valueindex in columns:
                styledata.valueaxis.adjustrange(points, valueindex)

    def drawpoints(self, points, graph, styledata):
        """Stroke the base line and fill the (stacked) bars of all points.

        Points whose name cannot be converted are skipped; a stacked value
        that cannot be converted is skipped and the next one continues from
        the previous top.
        """
        if self.fromvalue is not None:
            vfromvalue = styledata.valueaxis.convert(self.fromvalue)
            if vfromvalue < -self.epsilon:
                vfromvalue = 0
            if vfromvalue > 1 + self.epsilon:
                vfromvalue = 1
            if styledata.frompathattrs is not None and vfromvalue > self.epsilon and vfromvalue < 1 - self.epsilon:
                if styledata.valuepos:
                    p = graph.vgeodesic(0, vfromvalue, 1, vfromvalue)
                else:
                    p = graph.vgeodesic(vfromvalue, 0, vfromvalue, 1)
                graph.stroke(p, styledata.frompathattrs)
        else:
            vfromvalue = 0
        l = len(styledata.valueindices)
        if styledata.barattrs is None:
            barattrslist = [None] * l
        elif l > 1:
            barattrslist = [attr.selectattrs(styledata.barattrs, i, l) for i in range(l)]
        else:
            barattrslist = [styledata.barattrs]
        for point in points:
            # the name range is the same for all stacked values of a point
            try:
                name = point[styledata.nameindex]
                if styledata.selecttotal == 1:
                    vnamemin = styledata.nameaxis.convert((name, 0))
                    vnamemax = styledata.nameaxis.convert((name, 1))
                else:
                    vnamemin = styledata.nameaxis.convert((name, styledata.selectindex, 0))
                    vnamemax = styledata.nameaxis.convert((name, styledata.selectindex, 1))
            except self._invalidvalueerrors:
                continue
            vvaluemax = vfromvalue
            for valueindex, barattrs in zip(styledata.valueindices, barattrslist):
                vvaluemin = vvaluemax
                try:
                    vvaluemax = styledata.valueaxis.convert(point[valueindex])
                except self._invalidvalueerrors:
                    continue
                if barattrs is None:
                    continue
                if styledata.valuepos:
                    p = graph.vgeodesic(vnamemin, vvaluemin, vnamemin, vvaluemax)
                    p.append(graph.vgeodesic_el(vnamemin, vvaluemax, vnamemax, vvaluemax))
                    p.append(graph.vgeodesic_el(vnamemax, vvaluemax, vnamemax, vvaluemin))
                    p.append(graph.vgeodesic_el(vnamemax, vvaluemin, vnamemin, vvaluemin))
                    p.append(path.closepath())
                else:
                    p = graph.vgeodesic(vvaluemin, vnamemin, vvaluemin, vnamemax)
                    p.append(graph.vgeodesic_el(vvaluemin, vnamemax, vvaluemax, vnamemax))
                    p.append(graph.vgeodesic_el(vvaluemax, vnamemax, vvaluemax, vnamemin))
                    p.append(graph.vgeodesic_el(vvaluemax, vnamemin, vvaluemin, vnamemin))
                    p.append(path.closepath())
                graph.fill(p, barattrs)

    def key_pt(self, c, x_pt, y_pt, width_pt, height_pt, styledata):
        """Fill the key box, split into one part per stacked value."""
        if styledata.barattrs is None:
            return
        l = len(styledata.valueindices)
        if l > 1:
            for i in range(l):
                c.fill(path.rect_pt(x_pt+i*width_pt/l, y_pt, width_pt/l, height_pt), attr.selectattrs(styledata.barattrs, i, l))
        else:
            c.fill(path.rect_pt(x_pt, y_pt, width_pt, height_pt), styledata.barattrs)