graph.data and graph.style reorganization in progress: introduced styledata
[PyX.git] / pyx / graph / style.py
blobdc7b525761d1c6fa674176cdc1463653ac1d2982
1 #!/usr/bin/env python
2 # -*- coding: ISO-8859-1 -*-
5 # Copyright (C) 2002-2004 Jörg Lehmann <joergl@users.sourceforge.net>
6 # Copyright (C) 2003-2004 Michael Schindler <m-schindler@users.sourceforge.net>
7 # Copyright (C) 2002-2004 André Wobst <wobsta@users.sourceforge.net>
9 # This file is part of PyX (http://pyx.sourceforge.net/).
11 # PyX is free software; you can redistribute it and/or modify
12 # it under the terms of the GNU General Public License as published by
13 # the Free Software Foundation; either version 2 of the License, or
14 # (at your option) any later version.
16 # PyX is distributed in the hope that it will be useful,
17 # but WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 # GNU General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with PyX; if not, write to the Free Software
23 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
26 import re, math
27 from pyx import attr, deco, style, color, unit, canvas, path
28 from pyx import text as textmodule
class _style:
    """Common base class of the graph styles defined in this module."""

    def setdatapattern(self, graph, columns, pattern):
        """Claim the first column whose name matches pattern.

        graph   -- object providing an ``axes`` mapping indexed by axis name
        columns -- dictionary mapping column names to column indices; the
                   entry of the claimed column is removed from it
        pattern -- compiled regular expression applied to the column names

        Returns the tuple (axis, index) of the claimed column.  When no
        column name matches, the method falls off its end and thus
        returns None.
        """
        for name in columns.keys():
            found = pattern.match(name)
            if not found:
                continue
            # XXX the first group of the pattern must contain the full axisname
            axisname = found.group(1)
            index = columns[name]
            # returning right after the deletion keeps the loop from
            # touching the modified dictionary again
            del columns[name]
            return graph.axes[axisname], index

    def key_pt(self, c, x_pt, y_pt, width_pt, height_pt, styledata):
        """Draw a graph key entry; this base implementation always raises RuntimeError."""
        raise RuntimeError("style doesn't provide a key")
class symbolline(_style):
    """Graph style drawing data points as symbols with errorbars and a connecting line.

    Passing None as symbolattrs, errorbarattrs or lineattrs switches the
    corresponding part off (see selectstyle and the ``is not None`` checks
    in drawsymbol_pt and drawpoints).
    """

    # The following methods are symbol generators.  Each one returns a tuple
    # of path elements for a symbol centred at (x_pt, y_pt) and scaled by
    # size_pt; drawsymbol_pt calls the selected generator and turns the
    # tuple into a path.path.

    def cross(self, x_pt, y_pt, size_pt):
        """Return the path elements of a diagonal cross."""
        return (path.moveto_pt(x_pt-0.5*size_pt, y_pt-0.5*size_pt),
                path.lineto_pt(x_pt+0.5*size_pt, y_pt+0.5*size_pt),
                path.moveto_pt(x_pt-0.5*size_pt, y_pt+0.5*size_pt),
                path.lineto_pt(x_pt+0.5*size_pt, y_pt-0.5*size_pt))

    def plus(self, x_pt, y_pt, size_pt):
        """Return the path elements of an upright cross."""
        return (path.moveto_pt(x_pt-0.707106781*size_pt, y_pt),
                path.lineto_pt(x_pt+0.707106781*size_pt, y_pt),
                path.moveto_pt(x_pt, y_pt-0.707106781*size_pt),
                path.lineto_pt(x_pt, y_pt+0.707106781*size_pt))

    def square(self, x_pt, y_pt, size_pt):
        """Return the path elements of a closed square with edge length size_pt."""
        return (path.moveto_pt(x_pt-0.5*size_pt, y_pt-0.5*size_pt),
                path.lineto_pt(x_pt+0.5*size_pt, y_pt-0.5*size_pt),
                path.lineto_pt(x_pt+0.5*size_pt, y_pt+0.5*size_pt),
                path.lineto_pt(x_pt-0.5*size_pt, y_pt+0.5*size_pt),
                path.closepath())

    def triangle(self, x_pt, y_pt, size_pt):
        """Return the path elements of a closed triangle pointing upwards."""
        return (path.moveto_pt(x_pt-0.759835685*size_pt, y_pt-0.438691337*size_pt),
                path.lineto_pt(x_pt+0.759835685*size_pt, y_pt-0.438691337*size_pt),
                path.lineto_pt(x_pt, y_pt+0.877382675*size_pt),
                path.closepath())

    def circle(self, x_pt, y_pt, size_pt):
        """Return the path elements of a closed circle."""
        return (path.arc_pt(x_pt, y_pt, 0.564189583*size_pt, 0, 360),
                path.closepath())

    def diamond(self, x_pt, y_pt, size_pt):
        """Return the path elements of a closed diamond."""
        return (path.moveto_pt(x_pt-0.537284965*size_pt, y_pt),
                path.lineto_pt(x_pt, y_pt-0.930604859*size_pt),
                path.lineto_pt(x_pt+0.537284965*size_pt, y_pt),
                path.lineto_pt(x_pt, y_pt+0.930604859*size_pt),
                path.closepath())

    # attr.changelist instances listing the symbol generators in different
    # orders; the name tells the first entry of the list
    changecross = attr.changelist([cross, plus, square, triangle, circle, diamond])
    changeplus = attr.changelist([plus, square, triangle, circle, diamond, cross])
    changesquare = attr.changelist([square, triangle, circle, diamond, cross, plus])
    changetriangle = attr.changelist([triangle, circle, diamond, cross, plus, square])
    changecircle = attr.changelist([circle, diamond, cross, plus, square, triangle])
    changediamond = attr.changelist([diamond, cross, plus, square, triangle, circle])
    # the "twice" variants list each closed symbol two times in a row
    changesquaretwice = attr.changelist([square, square, triangle, triangle, circle, circle, diamond, diamond])
    changetriangletwice = attr.changelist([triangle, triangle, circle, circle, diamond, diamond, square, square])
    changecircletwice = attr.changelist([circle, circle, diamond, diamond, square, square, triangle, triangle])
    changediamondtwice = attr.changelist([diamond, diamond, square, square, triangle, triangle, circle, circle])

    changestrokedfilled = attr.changelist([deco.stroked, deco.filled])
    changefilledstroked = attr.changelist([deco.filled, deco.stroked])

    changelinestyle = attr.changelist([style.linestyle.solid,
                                       style.linestyle.dashed,
                                       style.linestyle.dotted,
                                       style.linestyle.dashdotted])

    # defaults which selectstyle prepends to the user supplied attributes
    defaultsymbolattrs = [deco.stroked]
    defaulterrorbarattrs = []
    defaultlineattrs = [changelinestyle]

    def __init__(self, symbol=changecross,
                       size="0.2 cm",
                       errorscale=0.5,
                       symbolattrs=[],
                       errorbarattrs=[],
                       lineattrs=[],
                       epsilon=1e-10):
        """Store the style parameters.

        symbol        -- symbol generator (or changeable attribute of them)
        size          -- symbol size; converted to a length in selectstyle
        errorscale    -- factor applied to the symbol size to get the
                         errorbar cap size
        symbolattrs   -- attributes for the symbols, None for no symbols
        errorbarattrs -- attributes for the errorbars, None for no errorbars
        lineattrs     -- attributes for the line, None for no line
        epsilon       -- tolerance when testing whether a graph coordinate
                         lies within the range 0 to 1

        The list defaults are never modified in place, they are only
        concatenated to the class defaults in selectstyle.
        """
        self.size_str = size
        self.symbol = symbol
        self.errorscale = errorscale
        self.symbolattrs = symbolattrs
        self.errorbarattrs = errorbarattrs
        self.lineattrs = lineattrs
        self.epsilon = epsilon

    def setdata(self, graph, columns, styledata):
        """
        - the instance should be considered read-only
          (it might be shared between several data)
        - styledata is the place where to store information
        - returns the dictionary of columns not used by the style

        Raises ValueError when an axis of the graph has no column at all,
        when columns of one axis name refer to different axes, when an
        errorbar is defined more than once or when an errorbar delta is
        given without the value it refers to."""

        # analyse column information
        styledata.index = {} # a nested index dictionary containing
                             # column numbers, e.g. styledata.index["x"]["x"],
                             # styledata.index["y"]["dmin"] etc.; the first key is a axis
                             # name (without the axis number), the second is one of
                             # the datanames ["x", "min", "max", "d", "dmin", "dmax"]
        styledata.axes = {}  # mapping from axis name (without axis number) to the axis

        # work on a copy: setdatapattern removes the columns it claims
        columns = columns.copy()
        for axisname in graph.axisnames:
            for dataname, pattern in [("x", re.compile(r"(%s([2-9]|[1-9][0-9]+)?)$" % axisname)),
                                      ("min", re.compile(r"(%s([2-9]|[1-9][0-9]+)?)min$" % axisname)),
                                      ("max", re.compile(r"(%s([2-9]|[1-9][0-9]+)?)max$" % axisname)),
                                      ("d", re.compile(r"d(%s([2-9]|[1-9][0-9]+)?)$" % axisname)),
                                      ("dmin", re.compile(r"d(%s([2-9]|[1-9][0-9]+)?)min$" % axisname)),
                                      ("dmax", re.compile(r"d(%s([2-9]|[1-9][0-9]+)?)max$" % axisname))]:
                matchresult = self.setdatapattern(graph, columns, pattern)
                if matchresult is not None:
                    axis, index = matchresult
                    if styledata.axes.has_key(axisname):
                        # all columns of one axis name have to use the same axis
                        if styledata.axes[axisname] != axis:
                            raise ValueError("axis mismatch for axis name '%s'" % axisname)
                        styledata.index[axisname][dataname] = index
                    else:
                        styledata.index[axisname] = {dataname: index}
                        styledata.axes[axisname] = axis
            if not styledata.axes.has_key(axisname):
                raise ValueError("missing columns for axis name '%s'" % axisname)
            # each end of an errorbar may be defined only once
            if ((styledata.index[axisname].has_key("min") and styledata.index[axisname].has_key("d")) or
                (styledata.index[axisname].has_key("min") and styledata.index[axisname].has_key("dmin")) or
                (styledata.index[axisname].has_key("d") and styledata.index[axisname].has_key("dmin")) or
                (styledata.index[axisname].has_key("max") and styledata.index[axisname].has_key("d")) or
                (styledata.index[axisname].has_key("max") and styledata.index[axisname].has_key("dmax")) or
                (styledata.index[axisname].has_key("d") and styledata.index[axisname].has_key("dmax"))):
                raise ValueError("multiple errorbar definition for axis name '%s'" % axisname)
            # deltas are relative to the value column, which thus must exist
            if (not styledata.index[axisname].has_key("x") and
                (styledata.index[axisname].has_key("d") or
                 styledata.index[axisname].has_key("dmin") or
                 styledata.index[axisname].has_key("dmax"))):
                raise ValueError("errorbar definition start value missing for axis name '%s'" % axisname)
        return columns

    def selectstyle(self, selectindex, selecttotal, styledata):
        """Store the attributes selected for data set selectindex of selecttotal in styledata.

        Sets styledata.symbol, size_pt, errorsize_pt, symbolattrs,
        errorbarattrs and lineattrs; an attribute list of None on the
        instance is passed on as None.
        """
        styledata.symbol = attr.selectattr(self.symbol, selectindex, selecttotal)
        styledata.size_pt = unit.topt(unit.length(attr.selectattr(self.size_str, selectindex, selecttotal), default_type="v"))
        styledata.errorsize_pt = self.errorscale * styledata.size_pt
        if self.symbolattrs is not None:
            styledata.symbolattrs = attr.selectattrs(self.defaultsymbolattrs + self.symbolattrs, selectindex, selecttotal)
        else:
            styledata.symbolattrs = None
        if self.errorbarattrs is not None:
            styledata.errorbarattrs = attr.selectattrs(self.defaulterrorbarattrs + self.errorbarattrs, selectindex, selecttotal)
        else:
            styledata.errorbarattrs = None
        if self.lineattrs is not None:
            styledata.lineattrs = attr.selectattrs(self.defaultlineattrs + self.lineattrs, selectindex, selecttotal)
        else:
            styledata.lineattrs = None

    def adjustaxes(self, columns, styledata):
        """Adjust the ranges of the axes affected by the given column indices.

        For each affected axis name the range is adjusted by the value
        column, the min/max columns and the value plus/minus the delta
        columns, as far as they are present in styledata.index.
        """
        # reverse lookup for axisnames
        # TODO: the reverse lookup is ugly
        axisnames = []
        for column in columns:
            for axisname in styledata.index.keys():
                for thiscolumn in styledata.index[axisname].values():
                    if thiscolumn == column and axisname not in axisnames:
                        axisnames.append(axisname)
        # TODO: perform check to verify that all columns for a given axisname are available at the same time
        for axisname in axisnames:
            if styledata.index[axisname].has_key("x"):
                styledata.axes[axisname].adjustrange(styledata.points, styledata.index[axisname]["x"])
            if styledata.index[axisname].has_key("min"):
                styledata.axes[axisname].adjustrange(styledata.points, styledata.index[axisname]["min"])
            if styledata.index[axisname].has_key("max"):
                styledata.axes[axisname].adjustrange(styledata.points, styledata.index[axisname]["max"])
            if styledata.index[axisname].has_key("d"):
                styledata.axes[axisname].adjustrange(styledata.points, styledata.index[axisname]["x"], deltaindex=styledata.index[axisname]["d"])
            if styledata.index[axisname].has_key("dmin"):
                styledata.axes[axisname].adjustrange(styledata.points, styledata.index[axisname]["x"], deltaminindex=styledata.index[axisname]["dmin"])
            if styledata.index[axisname].has_key("dmax"):
                styledata.axes[axisname].adjustrange(styledata.points, styledata.index[axisname]["x"], deltamaxindex=styledata.index[axisname]["dmax"])

    def drawsymbol_pt(self, c, x_pt, y_pt, styledata, point=None):
        """Draw the selected symbol at (x_pt, y_pt) on c unless symbols are switched off.

        point is not used here; it is part of the signature for
        overriding methods.
        """
        if styledata.symbolattrs is not None:
            # the symbol generators are plain functions taken from the class
            # body, hence self is passed explicitly
            c.draw(path.path(*styledata.symbol(self, x_pt, y_pt, styledata.size_pt)), styledata.symbolattrs)

    def vpos(self, styledata):
        """Convert styledata.point into graph coordinates.

        Sets styledata.vpos (one entry per graph axis, None where the
        conversion failed), styledata.validvpos and styledata.drawsymbol.
        """
        # calculate vpos
        styledata.vpos = [] # list containing the graph coordinates of the point
        styledata.validvpos = 1 # valid position (but might be outside of the graph)
        styledata.drawsymbol = 1 # valid position inside the graph
        for axisname in styledata.graph.axisnames:
            try:
                v = styledata.axes[axisname].convert(styledata.point[styledata.index[axisname]["x"]])
            except (ArithmeticError, KeyError, ValueError, TypeError):
                styledata.validvpos = 0
                styledata.drawsymbol = 0
                styledata.vpos.append(None)
            else:
                if v < - self.epsilon or v > 1 + self.epsilon:
                    styledata.drawsymbol = 0
                styledata.vpos.append(v)

    def appendlinebasepoints(self, styledata):
        """Extend styledata.linebasepoints by the line segment ending at the current point.

        Segments crossing the border of the graph are cut: along the
        straight connection from styledata.lastvpos to styledata.vpos the
        parameter at which a coordinate reaches 0 or 1 is calculated and
        the corresponding position is used instead of the outside point.
        styledata.validvpos is reset to 0 whenever the collected points
        have to be flushed to the path by the caller.
        """
        # append linebasepoints
        if styledata.validvpos:
            if len(styledata.linebasepoints):
                # the last point was inside the graph
                if styledata.drawsymbol:
                    styledata.linebasepoints.append((styledata.xpos, styledata.ypos))
                else:
                    # cut end
                    cut = 1
                    for vstart, vend in zip(styledata.lastvpos, styledata.vpos):
                        newcut = None
                        if vend > 1:
                            # 1 = vstart + (vend - vstart) * cut
                            try:
                                newcut = (1 - vstart)/(vend - vstart)
                            except ArithmeticError:
                                break
                        if vend < 0:
                            # 0 = vstart + (vend - vstart) * cut
                            try:
                                newcut = - vstart/(vend - vstart)
                            except ArithmeticError:
                                break
                        if newcut is not None and newcut < cut:
                            cut = newcut
                    else:
                        # only reached when no cut calculation failed
                        cutvpos = []
                        for vstart, vend in zip(styledata.lastvpos, styledata.vpos):
                            cutvpos.append(vstart + (vend - vstart) * cut)
                        styledata.linebasepoints.append(styledata.graph.vpos_pt(*cutvpos))
                    styledata.validvpos = 0 # clear linebasepoints below
            else:
                # the last point was outside the graph
                if styledata.lastvpos is not None:
                    if styledata.drawsymbol:
                        # cut beginning
                        cut = 0
                        for vstart, vend in zip(styledata.lastvpos, styledata.vpos):
                            newcut = None
                            if vstart > 1:
                                # 1 = vstart + (vend - vstart) * cut
                                try:
                                    newcut = (1 - vstart)/(vend - vstart)
                                except ArithmeticError:
                                    break
                            if vstart < 0:
                                # 0 = vstart + (vend - vstart) * cut
                                try:
                                    newcut = - vstart/(vend - vstart)
                                except ArithmeticError:
                                    break
                            if newcut is not None and newcut > cut:
                                cut = newcut
                        else:
                            cutvpos = []
                            for vstart, vend in zip(styledata.lastvpos, styledata.vpos):
                                cutvpos.append(vstart + (vend - vstart) * cut)
                            styledata.linebasepoints.append(styledata.graph.vpos_pt(*cutvpos))
                            # NOTE(review): nesting of the following line was
                            # reconstructed from an unindented listing -- confirm
                            styledata.linebasepoints.append(styledata.graph.vpos_pt(*styledata.vpos))
                    else:
                        # sometimes cut beginning and end
                        cutfrom = 0
                        cutto = 1
                        for vstart, vend in zip(styledata.lastvpos, styledata.vpos):
                            newcutfrom = None
                            if vstart > 1:
                                if vend > 1:
                                    # both ends beyond the same border: nothing visible
                                    break
                                # 1 = vstart + (vend - vstart) * cutfrom
                                try:
                                    newcutfrom = (1 - vstart)/(vend - vstart)
                                except ArithmeticError:
                                    break
                            if vstart < 0:
                                if vend < 0:
                                    # both ends beyond the same border: nothing visible
                                    break
                                # 0 = vstart + (vend - vstart) * cutfrom
                                try:
                                    newcutfrom = - vstart/(vend - vstart)
                                except ArithmeticError:
                                    break
                            if newcutfrom is not None and newcutfrom > cutfrom:
                                cutfrom = newcutfrom
                            newcutto = None
                            if vend > 1:
                                # 1 = vstart + (vend - vstart) * cutto
                                try:
                                    newcutto = (1 - vstart)/(vend - vstart)
                                except ArithmeticError:
                                    break
                            if vend < 0:
                                # 0 = vstart + (vend - vstart) * cutto
                                try:
                                    newcutto = - vstart/(vend - vstart)
                                except ArithmeticError:
                                    break
                            if newcutto is not None and newcutto < cutto:
                                cutto = newcutto
                        else:
                            # a visible part exists only if the entry into the
                            # graph comes before the exit
                            if cutfrom < cutto:
                                cutfromvpos = []
                                cuttovpos = []
                                for vstart, vend in zip(styledata.lastvpos, styledata.vpos):
                                    cutfromvpos.append(vstart + (vend - vstart) * cutfrom)
                                    cuttovpos.append(vstart + (vend - vstart) * cutto)
                                styledata.linebasepoints.append(styledata.graph.vpos_pt(*cutfromvpos))
                                styledata.linebasepoints.append(styledata.graph.vpos_pt(*cuttovpos))
                                # NOTE(review): nesting reconstructed; flushing an
                                # empty point list is a no-op, so the exact level
                                # within this branch does not alter the result
                                styledata.validvpos = 0 # clear linebasepoints below
            styledata.lastvpos = styledata.vpos
        else:
            styledata.lastvpos = None

    def addpointstopath(self, styledata):
        """Append the collected line base points to styledata.path and reset the collection.

        A single collected point does not form a line and is dropped.
        """
        # add baselinepoints to styledata.path
        if len(styledata.linebasepoints) > 1:
            styledata.path.append(path.moveto_pt(*styledata.linebasepoints[0]))
            if len(styledata.linebasepoints) > 2:
                styledata.path.append(path.multilineto_pt(styledata.linebasepoints[1:]))
            else:
                styledata.path.append(path.lineto_pt(*styledata.linebasepoints[1]))
        styledata.linebasepoints = []

    def doerrorbars(self, styledata):
        """Stroke the errorbars of the current point on styledata.errorbarcanvas.

        styledata.errorlist holds (axis name, axis position) pairs of the
        axes having errorbar columns.  Missing or unusable data is
        skipped silently; ``except Exception`` is used for that instead
        of a bare ``except`` so that KeyboardInterrupt and SystemExit are
        not swallowed.
        """
        # errorbar loop over the different direction having errorbars
        for erroraxisname, erroraxisindex in styledata.errorlist:

            # check for validity of other point components
            i = 0
            for v in styledata.vpos:
                if v is None and i != erroraxisindex:
                    break
                i += 1
            else:
                # calculate min and max
                # (tried in the order: symmetric delta, one-sided delta,
                # absolute value; a failing lookup or calculation falls
                # through to the next alternative)
                errorindex = styledata.index[erroraxisname]
                try:
                    minvalue = styledata.point[errorindex["x"]] - styledata.point[errorindex["d"]]
                except Exception:
                    try:
                        minvalue = styledata.point[errorindex["x"]] - styledata.point[errorindex["dmin"]]
                    except Exception:
                        try:
                            minvalue = styledata.point[errorindex["min"]]
                        except Exception:
                            minvalue = None
                try:
                    maxvalue = styledata.point[errorindex["x"]] + styledata.point[errorindex["d"]]
                except Exception:
                    try:
                        maxvalue = styledata.point[errorindex["x"]] + styledata.point[errorindex["dmax"]]
                    except Exception:
                        try:
                            maxvalue = styledata.point[errorindex["max"]]
                        except Exception:
                            maxvalue = None

                # calculate vmin and vmax
                try:
                    vmin = styledata.axes[erroraxisname].convert(minvalue)
                except Exception:
                    vmin = None
                try:
                    vmax = styledata.axes[erroraxisname].convert(maxvalue)
                except Exception:
                    vmax = None

                # create vminpos and vmaxpos
                # NOTE(review): an end outside the graph is put on 0 (lower
                # end) or 1 (upper end) regardless of the side on which it
                # lies -- confirm that this is intended
                vcaps = []
                if vmin is not None:
                    vminpos = styledata.vpos[:]
                    if vmin > - self.epsilon and vmin < 1 + self.epsilon:
                        vminpos[erroraxisindex] = vmin
                        vcaps.append(vminpos)
                    else:
                        vminpos[erroraxisindex] = 0
                elif styledata.vpos[erroraxisindex] is not None:
                    vminpos = styledata.vpos
                else:
                    # no position along this axis at all; the validity check
                    # above fails for all other axes in this case as well,
                    # so leaving the loop skips nothing drawable
                    break
                if vmax is not None:
                    vmaxpos = styledata.vpos[:]
                    if vmax > - self.epsilon and vmax < 1 + self.epsilon:
                        vmaxpos[erroraxisindex] = vmax
                        vcaps.append(vmaxpos)
                    else:
                        vmaxpos[erroraxisindex] = 1
                elif styledata.vpos[erroraxisindex] is not None:
                    vmaxpos = styledata.vpos
                else:
                    break

                # create path for errorbars
                errorpath = path.path()
                errorpath += styledata.graph.vgeodesic(*(vminpos + vmaxpos))
                for vcap in vcaps:
                    for axisname in styledata.graph.axisnames:
                        if axisname != erroraxisname:
                            errorpath += styledata.graph.vcap_pt(axisname, styledata.errorsize_pt, *vcap)

                # stroke errorpath
                if len(errorpath.path):
                    styledata.errorbarcanvas.stroke(errorpath, styledata.errorbarattrs)

    def drawpoints(self, graph, styledata):
        """Draw symbols, errorbars and the line for all points in styledata.points.

        Separate canvases are inserted into the graph for the line and
        the errorbars (in this order, before any symbol is drawn on the
        graph itself).
        """
        if styledata.lineattrs is not None:
            # TODO: bbox shortcut
            linecanvas = graph.insert(canvas.canvas())
        if styledata.errorbarattrs is not None:
            # TODO: bbox shortcut
            styledata.errorbarcanvas = graph.insert(canvas.canvas())
        styledata.path = path.path()
        styledata.linebasepoints = []
        styledata.lastvpos = None
        styledata.errorlist = []
        styledata.graph = graph
        if styledata.errorbarattrs is not None:
            # an axis has errorbars if it got any column besides its value
            axisindex = 0
            for axisname in graph.axisnames:
                if styledata.index[axisname].keys() != ["x"]:
                    styledata.errorlist.append((axisname, axisindex))
                axisindex += 1

        for point in styledata.points:
            styledata.point = point
            self.vpos(styledata)
            if styledata.drawsymbol:
                styledata.xpos, styledata.ypos = graph.vpos_pt(*styledata.vpos)
                self.drawsymbol_pt(graph, styledata.xpos, styledata.ypos, styledata, point=styledata.point)
            self.appendlinebasepoints(styledata)
            if not styledata.validvpos:
                # the line is interrupted here: flush what was collected
                self.addpointstopath(styledata)
            self.doerrorbars(styledata)
        self.addpointstopath(styledata)

        # stroke styledata.path
        if styledata.lineattrs is not None:
            linecanvas.stroke(styledata.path, styledata.lineattrs)

    def key_pt(self, c, x_pt, y_pt, width_pt, height_pt, styledata):
        """Draw a key entry on c: the symbol in the centre of the given box and a horizontal line through it."""
        self.drawsymbol_pt(c, x_pt+0.5*width_pt, y_pt+0.5*height_pt, styledata)
        if styledata.lineattrs is not None:
            c.stroke(path.line_pt(x_pt, y_pt+0.5*height_pt, x_pt+width_pt, y_pt+0.5*height_pt), styledata.lineattrs)
class symbol(symbolline):
    """symbolline variant which does not draw the connecting line."""

    def __init__(self, *args, **kwargs):
        # lineattrs=None switches the line off; all other arguments are
        # forwarded unchanged to symbolline.__init__ (passing lineattrs
        # here as well results in a TypeError from the call below)
        symbolline.__init__(self, lineattrs=None, *args, **kwargs)
class line(symbolline):
    """symbolline variant which draws the connecting line only."""

    def __init__(self, lineattrs=[]):
        # None (instead of an attribute list) switches the symbols and the
        # errorbars off; the lineattrs default is only read, never modified
        symbolline.__init__(self,
                            lineattrs=lineattrs,
                            symbolattrs=None,
                            errorbarattrs=None)
class text(symbol):
    """symbol style which additionally prints a text at each data point.

    The text is taken from the data column named "text".
    """

    defaulttextattrs = [textmodule.halign.center, textmodule.vshift.mathaxis]

    def __init__(self, textdx="0", textdy="0.3 cm", textattrs=[], **kwargs):
        """Store the text parameters and forward all other keyword arguments to symbol.

        textdx, textdy -- shift of the text relative to the point;
                          converted to lengths in drawpoints
        textattrs      -- attributes for the text, None for no text
        """
        self.textdx_str = textdx
        self.textdy_str = textdy
        self.textattrs = textattrs
        symbol.__init__(self, **kwargs)

    def setdata(self, graph, columns, data):
        """Claim the "text" column and hand the remaining columns to symbol.setdata.

        Raises KeyError when there is no "text" column.
        """
        columns = columns.copy()
        data.textindex = columns["text"]
        del columns["text"]
        return symbol.setdata(self, graph, columns, data)

    def selectstyle(self, selectindex, selecttotal, data):
        """Select the text attributes in addition to the attributes of the symbol style."""
        if self.textattrs is not None:
            data.textattrs = attr.selectattrs(self.defaulttextattrs + self.textattrs, selectindex, selecttotal)
        else:
            data.textattrs = None
        symbol.selectstyle(self, selectindex, selecttotal, data)

    def drawsymbol_pt(self, c, x, y, data, point=None):
        """Draw the symbol and, when a data point is given, its text.

        The inherited key_pt calls this method without a point; in that
        case only the symbol is drawn.  (Indexing the missing point used
        to raise a TypeError there.)
        """
        symbol.drawsymbol_pt(self, c, x, y, data, point)
        if point is None or data.textattrs is None:
            return
        if None not in (x, y, point[data.textindex]):
            c.text_pt(x + data.textdx_pt, y + data.textdy_pt, str(point[data.textindex]), data.textattrs)

    def drawpoints(self, graph, data):
        """Convert the text shift into lengths and points, then draw like the symbol style."""
        data.textdx = unit.length(self.textdx_str, default_type="v")
        data.textdy = unit.length(self.textdy_str, default_type="v")
        data.textdx_pt = unit.topt(data.textdx)
        data.textdy_pt = unit.topt(data.textdy)
        symbol.drawpoints(self, graph, data)
class arrow(_style):
    """Graph style drawing an arrow at each data point.

    Besides one column per graph axis the columns named "size" and
    "angle" are used; the angle is converted from degrees to radians
    in drawpoints.
    """

    defaultlineattrs = []
    defaultarrowattrs = []

    def __init__(self, linelength="0.25 cm", arrowsize="0.15 cm", lineattrs=[], arrowattrs=[], epsilon=1e-10):
        """Store the style parameters.

        linelength -- length of the arrow line for a size value of one
        arrowsize  -- size of the arrow head for a size value of one
        lineattrs  -- attributes for the line, None to draw nothing
        arrowattrs -- attributes for the arrow head, None to draw nothing
        epsilon    -- arrows with a size value not above epsilon are skipped
        """
        self.epsilon = epsilon
        self.arrowattrs = arrowattrs
        self.lineattrs = lineattrs
        self.arrowsize_str = arrowsize
        self.linelength_str = linelength

    def setdata(self, graph, columns, data):
        """Claim the position, "size" and "angle" columns; return the unused columns.

        Raises TypeError for graphs which do not have exactly two axis
        names and KeyError when the "size" or "angle" column is missing.
        """
        if len(graph.axisnames) != 2:
            raise TypeError("arrow style restricted on two-dimensional graphs")
        remaining = columns.copy()
        numbered = r"(%s([2-9]|[1-9][0-9]+)?)$"
        xpattern = re.compile(numbered % graph.axisnames[0])
        data.xaxis, data.xindex = _style.setdatapattern(self, graph, remaining, xpattern)
        ypattern = re.compile(numbered % graph.axisnames[1])
        data.yaxis, data.yindex = _style.setdatapattern(self, graph, remaining, ypattern)
        for name in ("size", "angle"):
            # stored as data.sizeindex and data.angleindex
            setattr(data, name + "index", remaining[name])
            del remaining[name]
        return remaining

    def adjustaxes(self, columns, data):
        """Adjust the range of the x and y axis when their column is among columns."""
        for axis, index in ((data.xaxis, data.xindex), (data.yaxis, data.yindex)):
            if index in columns:
                axis.adjustrange(data.points, index)

    def selectstyle(self, selectindex, selecttotal, data):
        """Store the selected line and arrow attributes (or None) in data."""
        lineattrs = self.lineattrs
        if lineattrs is not None:
            lineattrs = attr.selectattrs(self.defaultlineattrs + lineattrs, selectindex, selecttotal)
        data.lineattrs = lineattrs
        arrowattrs = self.arrowattrs
        if arrowattrs is not None:
            arrowattrs = attr.selectattrs(self.defaultarrowattrs + arrowattrs, selectindex, selecttotal)
        data.arrowattrs = arrowattrs

    def drawpoints(self, graph, data):
        """Stroke one arrow per data point, centred at the position of the point."""
        if data.lineattrs is None or data.arrowattrs is None:
            return
        arrowsize = unit.length(self.arrowsize_str, default_type="v")
        linelength = unit.length(self.linelength_str, default_type="v")
        arrowsize_pt = unit.topt(arrowsize) # not used below, kept from the original
        linelength_pt = unit.topt(linelength)
        for point in data.points:
            xpos, ypos = graph.pos_pt(point[data.xindex], point[data.yindex], xaxis=data.xaxis, yaxis=data.yaxis)
            size = point[data.sizeindex]
            if not size > self.epsilon:
                continue
            radians = point[data.angleindex]*math.pi/180
            dx = math.cos(radians)
            dy = math.sin(radians)
            # half of the arrow line, from the centre to either end
            xhalf = 0.5*dx*linelength_pt*size
            yhalf = 0.5*dy*linelength_pt*size
            shaft = path.line_pt(xpos-xhalf, ypos-yhalf, xpos+xhalf, ypos+yhalf)
            head = deco.earrow(data.arrowattrs, size=arrowsize*size)
            graph.stroke(shaft, data.lineattrs + [head])
class rect(_style):
    """Graph style filling a rectangle per data point in a color taken from a palette.

    The columns <axis>min and <axis>max of both graph axes and the column
    named "color" are used.
    """

    def __init__(self, palette=color.palette.Gray):
        """Store the palette used to convert the color values."""
        self.palette = palette

    def setdata(self, graph, columns, data):
        """Claim the min/max columns of both axes and the "color" column; return the unused columns.

        Raises TypeError for graphs which do not have exactly two axis
        names, ValueError when the min and max columns refer to
        different axes and KeyError when the "color" column is missing.
        """
        if len(graph.axisnames) != 2:
            # the message used to name the arrow style (copy and paste)
            raise TypeError("rect style restricted on two-dimensional graphs")
        columns = columns.copy()
        data.xaxis, data.xminindex = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)min$" % graph.axisnames[0]))
        data.yaxis, data.yminindex = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)min$" % graph.axisnames[1]))
        xaxis, data.xmaxindex = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)max$" % graph.axisnames[0]))
        yaxis, data.ymaxindex = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)max$" % graph.axisnames[1]))
        if xaxis != data.xaxis or yaxis != data.yaxis:
            raise ValueError("min/max values should use the same axes")
        data.colorindex = columns["color"]
        del columns["color"]
        return columns

    def selectstyle(self, selectindex, selecttotal, data):
        """Nothing to select for this style."""
        pass

    def adjustaxes(self, columns, data):
        """Adjust the axis ranges by those min/max columns which are among columns."""
        if data.xminindex in columns:
            data.xaxis.adjustrange(data.points, data.xminindex)
        if data.xmaxindex in columns:
            data.xaxis.adjustrange(data.points, data.xmaxindex)
        if data.yminindex in columns:
            data.yaxis.adjustrange(data.points, data.yminindex)
        if data.ymaxindex in columns:
            data.yaxis.adjustrange(data.points, data.ymaxindex)

    def drawpoints(self, graph, data):
        """Fill the rectangles on a canvas inserted into the graph.

        Points whose values cannot be converted are skipped, rectangles
        completely outside of the graph are skipped as well and the
        remaining ones are clipped to the graph.  The fill color is set
        on the canvas only when the color value differs from the one set
        before.
        """
        # TODO: bbox shortcut
        c = graph.insert(canvas.canvas())
        lastcolorvalue = None
        for point in data.points:
            try:
                xvmin = data.xaxis.convert(point[data.xminindex])
                xvmax = data.xaxis.convert(point[data.xmaxindex])
                yvmin = data.yaxis.convert(point[data.yminindex])
                yvmax = data.yaxis.convert(point[data.ymaxindex])
                colorvalue = point[data.colorindex]
                if colorvalue != lastcolorvalue:
                    # not called "color" to keep the color module accessible
                    fillcolor = self.palette.getcolor(colorvalue)
            except Exception:
                # invalid data: skip the point (Exception instead of a bare
                # except to not swallow KeyboardInterrupt and SystemExit)
                continue
            if ((xvmin < 0 and xvmax < 0) or (xvmin > 1 and xvmax > 1) or
                (yvmin < 0 and yvmax < 0) or (yvmin > 1 and yvmax > 1)):
                continue
            # clip the rectangle to the graph
            if xvmin < 0:
                xvmin = 0
            elif xvmin > 1:
                xvmin = 1
            if xvmax < 0:
                xvmax = 0
            elif xvmax > 1:
                xvmax = 1
            if yvmin < 0:
                yvmin = 0
            elif yvmin > 1:
                yvmin = 1
            if yvmax < 0:
                yvmax = 0
            elif yvmax > 1:
                yvmax = 1
            p = graph.vgeodesic(xvmin, yvmin, xvmax, yvmin)
            p.append(graph.vgeodesic_el(xvmax, yvmin, xvmax, yvmax))
            p.append(graph.vgeodesic_el(xvmax, yvmax, xvmin, yvmax))
            p.append(graph.vgeodesic_el(xvmin, yvmax, xvmin, yvmin))
            p.append(path.closepath())
            if colorvalue != lastcolorvalue:
                c.set([fillcolor])
                # remember what is set on the canvas; this was missing, so
                # that the color got set again for every single rectangle
                lastcolorvalue = colorvalue
            c.fill(p)
class bar(_style):
    """Graph style drawing (optionally stacked) bars.

    One graph axis provides the values (column <axis>, stacked values in
    the columns <axis>stack1, <axis>stack2, ...), the other one the
    names (column <axis>name).
    """

    defaultfrompathattrs = []
    defaultbarattrs = [color.palette.Rainbow, deco.stroked([color.gray.black])]

    def __init__(self, fromvalue=None, frompathattrs=[], barattrs=[], subnames=None, epsilon=1e-10):
        """Store the style parameters.

        fromvalue     -- value at which the bars start; None lets them
                         start at the graph coordinate 0
        frompathattrs -- attributes for the line stroked at fromvalue,
                         None for no such line
        barattrs      -- attributes for the bars, None to draw no bars
        subnames      -- passed to the adjustrange method of the name
                         axis for a single bar style; not allowed when
                         iterating over several bar styles
        epsilon       -- tolerance for comparisons of graph coordinates
        """
        self.fromvalue = fromvalue
        self.frompathattrs = frompathattrs
        self.barattrs = barattrs
        self.subnames = subnames
        self.epsilon = epsilon

    def setdata(self, graph, columns, data):
        """Claim the value, name and stack columns; return the unused columns.

        Raises TypeError for graphs which do not have exactly two axis
        names, when not exactly one value column is given or when the
        name column is missing, and ValueError when stacked values use
        an axis different from the one of the first value.
        """
        # TODO: remove limitation to 2d graphs
        if len(graph.axisnames) != 2:
            # the message used to name the arrow style (copy and paste)
            raise TypeError("bar style currently restricted on two-dimensional graphs")
        columns = columns.copy()
        xvalue = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)$" % graph.axisnames[0]))
        yvalue = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)$" % graph.axisnames[1]))
        if (xvalue is None and yvalue is None) or (xvalue is not None and yvalue is not None):
            raise TypeError("must specify exactly one value axis")
        if xvalue is not None:
            data.valuepos = 0
            namevalue = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)name$" % graph.axisnames[1]))
            firstvalue = xvalue
        else:
            data.valuepos = 1
            namevalue = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)name$" % graph.axisnames[0]))
            firstvalue = yvalue
        if namevalue is None:
            # was a TypeError before as well, caused by unpacking None
            raise TypeError("missing name column for bar style")
        data.nameaxis, data.nameindex = namevalue
        data.valueaxis = firstvalue[0]
        data.valueindices = [firstvalue[1]]
        i = 1
        while 1:
            # the end of the stack is detected by the None result instead of
            # a bare except, which also hid errors like an unknown axis
            stackvalue = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)stack%i$" % (graph.axisnames[data.valuepos], i)))
            if stackvalue is None:
                break
            valueaxis, valueindex = stackvalue
            if data.valueaxis != valueaxis:
                raise ValueError("different value axes for stacked bars")
            data.valueindices.append(valueindex)
            i += 1
        return columns

    def selectstyle(self, selectindex, selecttotal, data):
        """Store the attributes selected for bar style selectindex of selecttotal in data.

        Only the first style (selectindex 0) gets frompathattrs.  An
        attribute list of None on the instance is passed on as None in
        all cases; concatenating it to the defaults used to raise a
        TypeError.  Raises ValueError when subnames are combined with
        several bar styles.
        """
        if selectindex or self.frompathattrs is None:
            data.frompathattrs = None
        else:
            data.frompathattrs = self.defaultfrompathattrs + self.frompathattrs
        if self.barattrs is None:
            data.barattrs = None
        elif selecttotal > 1:
            data.barattrs = attr.selectattrs(self.defaultbarattrs + self.barattrs, selectindex, selecttotal)
        else:
            # kept unselected; drawpoints and key_pt select per stack level
            data.barattrs = self.defaultbarattrs + self.barattrs
        data.selectindex = selectindex
        data.selecttotal = selecttotal
        if data.selecttotal != 1 and self.subnames is not None:
            raise ValueError("subnames not allowed when iterating over bars")

    def adjustaxes(self, columns, data):
        """Adjust the ranges of the name axis and the value axis for the given column indices."""
        if data.nameindex in columns:
            if data.selecttotal == 1:
                data.nameaxis.adjustrange(data.points, data.nameindex, subnames=self.subnames)
            else:
                for i in range(data.selecttotal):
                    data.nameaxis.adjustrange(data.points, data.nameindex, subnames=[i])
        for valueindex in data.valueindices:
            if valueindex in columns:
                data.valueaxis.adjustrange(data.points, valueindex)

    def drawpoints(self, graph, data):
        """Stroke the line at fromvalue (if requested) and fill the bars of all points.

        Values and names which cannot be converted are skipped silently.
        """
        if self.fromvalue is not None:
            vfromvalue = data.valueaxis.convert(self.fromvalue)
            # bars start at the border of the graph at the latest
            if vfromvalue < -self.epsilon:
                vfromvalue = 0
            if vfromvalue > 1 + self.epsilon:
                vfromvalue = 1
            if data.frompathattrs is not None and vfromvalue > self.epsilon and vfromvalue < 1 - self.epsilon:
                if data.valuepos:
                    p = graph.vgeodesic(0, vfromvalue, 1, vfromvalue)
                else:
                    p = graph.vgeodesic(vfromvalue, 0, vfromvalue, 1)
                graph.stroke(p, data.frompathattrs)
        else:
            vfromvalue = 0
        l = len(data.valueindices)
        if data.barattrs is None:
            # bars are switched off; nothing is selected from None
            barattrslist = [None] * l
        elif l > 1:
            barattrslist = [attr.selectattrs(data.barattrs, i, l) for i in range(l)]
        else:
            barattrslist = [data.barattrs]
        for point in data.points:
            vvaluemax = vfromvalue
            for valueindex, barattrs in zip(data.valueindices, barattrslist):
                # each stack level starts where the previous one ended
                vvaluemin = vvaluemax
                try:
                    vvaluemax = data.valueaxis.convert(point[valueindex])
                except Exception:
                    continue

                if data.selecttotal == 1:
                    try:
                        vnamemin = data.nameaxis.convert((point[data.nameindex], 0))
                    except Exception:
                        continue
                    try:
                        vnamemax = data.nameaxis.convert((point[data.nameindex], 1))
                    except Exception:
                        continue
                else:
                    try:
                        vnamemin = data.nameaxis.convert((point[data.nameindex], data.selectindex, 0))
                    except Exception:
                        continue
                    try:
                        vnamemax = data.nameaxis.convert((point[data.nameindex], data.selectindex, 1))
                    except Exception:
                        continue

                if data.valuepos:
                    p = graph.vgeodesic(vnamemin, vvaluemin, vnamemin, vvaluemax)
                    p.append(graph.vgeodesic_el(vnamemin, vvaluemax, vnamemax, vvaluemax))
                    p.append(graph.vgeodesic_el(vnamemax, vvaluemax, vnamemax, vvaluemin))
                    p.append(graph.vgeodesic_el(vnamemax, vvaluemin, vnamemin, vvaluemin))
                    p.append(path.closepath())
                else:
                    p = graph.vgeodesic(vvaluemin, vnamemin, vvaluemin, vnamemax)
                    p.append(graph.vgeodesic_el(vvaluemin, vnamemax, vvaluemax, vnamemax))
                    p.append(graph.vgeodesic_el(vvaluemax, vnamemax, vvaluemax, vnamemin))
                    p.append(graph.vgeodesic_el(vvaluemax, vnamemin, vvaluemin, vnamemin))
                    p.append(path.closepath())
                if barattrs is not None:
                    graph.fill(p, barattrs)

    def key_pt(self, c, x_pt, y_pt, width_pt, height_pt, data):
        """Fill the key box on c, split horizontally into one part per stack level.

        Nothing is drawn when the bars are switched off.
        """
        if data.barattrs is None:
            return
        l = len(data.valueindices)
        if l > 1:
            for i in range(l):
                c.fill(path.rect_pt(x_pt+i*width_pt/l, y_pt, width_pt/l, height_pt), attr.selectattrs(data.barattrs, i, l))
        else:
            c.fill(path.rect_pt(x_pt, y_pt, width_pt, height_pt), data.barattrs)