allow for non-keyword arguments in symbol
[PyX.git] / pyx / graph / style.py
blob763b0f9511c2d7dddab5485838f29331ad323570
1 #!/usr/bin/env python
2 # -*- coding: ISO-8859-1 -*-
5 # Copyright (C) 2002-2004 Jörg Lehmann <joergl@users.sourceforge.net>
6 # Copyright (C) 2003-2004 Michael Schindler <m-schindler@users.sourceforge.net>
7 # Copyright (C) 2002-2004 André Wobst <wobsta@users.sourceforge.net>
9 # This file is part of PyX (http://pyx.sourceforge.net/).
11 # PyX is free software; you can redistribute it and/or modify
12 # it under the terms of the GNU General Public License as published by
13 # the Free Software Foundation; either version 2 of the License, or
14 # (at your option) any later version.
16 # PyX is distributed in the hope that it will be useful,
17 # but WITHOUT ANY WARRANTY; without even the implied warranty of
18 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 # GNU General Public License for more details.
21 # You should have received a copy of the GNU General Public License
22 # along with PyX; if not, write to the Free Software
23 # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
26 import re, math
27 from pyx import attr, deco, style, color, unit, canvas, path
28 from pyx import text as textmodule
class _style:
    """Common base of the graph styles."""

    def setdatapattern(self, graph, columns, pattern):
        """Take the first column whose name matches pattern out of columns.

        The first group of pattern must contain the full axis name. On a
        match the column is removed from the (modified in place) columns
        dictionary and the tuple (axis, column index) is returned; the axis
        is looked up in graph.axes. None is returned when no column matches.
        """
        for columnname in list(columns.keys()):
            found = pattern.match(columnname)
            if not found:
                continue
            # XXX the first group must contain the full axis name
            columnindex = columns[columnname]
            del columns[columnname]
            return graph.axes[found.group(1)], columnindex
        return None

    def key_pt(self, c, x_pt, y_pt, width_pt, height_pt, data):
        # styles able to draw a graph key override this method
        raise RuntimeError("style doesn't provide a key")
class symbolline(_style):
    """Style drawing symbols, errorbars and a line connecting the points.

    Each of the three parts is switched off by passing None as its
    attribute list (symbolattrs, errorbarattrs or lineattrs)."""

    def cross(self, x_pt, y_pt, size_pt):
        return (path.moveto_pt(x_pt-0.5*size_pt, y_pt-0.5*size_pt),
                path.lineto_pt(x_pt+0.5*size_pt, y_pt+0.5*size_pt),
                path.moveto_pt(x_pt-0.5*size_pt, y_pt+0.5*size_pt),
                path.lineto_pt(x_pt+0.5*size_pt, y_pt-0.5*size_pt))

    def plus(self, x_pt, y_pt, size_pt):
        return (path.moveto_pt(x_pt-0.707106781*size_pt, y_pt),
                path.lineto_pt(x_pt+0.707106781*size_pt, y_pt),
                path.moveto_pt(x_pt, y_pt-0.707106781*size_pt),
                path.lineto_pt(x_pt, y_pt+0.707106781*size_pt))

    def square(self, x_pt, y_pt, size_pt):
        return (path.moveto_pt(x_pt-0.5*size_pt, y_pt-0.5*size_pt),
                path.lineto_pt(x_pt+0.5*size_pt, y_pt-0.5*size_pt),
                path.lineto_pt(x_pt+0.5*size_pt, y_pt+0.5*size_pt),
                path.lineto_pt(x_pt-0.5*size_pt, y_pt+0.5*size_pt),
                path.closepath())

    def triangle(self, x_pt, y_pt, size_pt):
        return (path.moveto_pt(x_pt-0.759835685*size_pt, y_pt-0.438691337*size_pt),
                path.lineto_pt(x_pt+0.759835685*size_pt, y_pt-0.438691337*size_pt),
                path.lineto_pt(x_pt, y_pt+0.877382675*size_pt),
                path.closepath())

    def circle(self, x_pt, y_pt, size_pt):
        return (path.arc_pt(x_pt, y_pt, 0.564189583*size_pt, 0, 360),
                path.closepath())

    def diamond(self, x_pt, y_pt, size_pt):
        return (path.moveto_pt(x_pt-0.537284965*size_pt, y_pt),
                path.lineto_pt(x_pt, y_pt-0.930604859*size_pt),
                path.lineto_pt(x_pt+0.537284965*size_pt, y_pt),
                path.lineto_pt(x_pt, y_pt+0.930604859*size_pt),
                path.closepath())

    # symbol sequences (selected via attr.selectattr when several data
    # sets share this style)
    changecross = attr.changelist([cross, plus, square, triangle, circle, diamond])
    changeplus = attr.changelist([plus, square, triangle, circle, diamond, cross])
    changesquare = attr.changelist([square, triangle, circle, diamond, cross, plus])
    changetriangle = attr.changelist([triangle, circle, diamond, cross, plus, square])
    changecircle = attr.changelist([circle, diamond, cross, plus, square, triangle])
    changediamond = attr.changelist([diamond, cross, plus, square, triangle, circle])
    changesquaretwice = attr.changelist([square, square, triangle, triangle, circle, circle, diamond, diamond])
    changetriangletwice = attr.changelist([triangle, triangle, circle, circle, diamond, diamond, square, square])
    changecircletwice = attr.changelist([circle, circle, diamond, diamond, square, square, triangle, triangle])
    changediamondtwice = attr.changelist([diamond, diamond, square, square, triangle, triangle, circle, circle])

    changestrokedfilled = attr.changelist([deco.stroked, deco.filled])
    changefilledstroked = attr.changelist([deco.filled, deco.stroked])

    changelinestyle = attr.changelist([style.linestyle.solid,
                                       style.linestyle.dashed,
                                       style.linestyle.dotted,
                                       style.linestyle.dashdotted])

    defaultsymbolattrs = [deco.stroked]
    defaulterrorbarattrs = []
    defaultlineattrs = [changelinestyle]

    def __init__(self, symbol=changecross,
                       size="0.2 cm",
                       errorscale=0.5,
                       symbolattrs=[],
                       errorbarattrs=[],
                       lineattrs=[],
                       epsilon=1e-10):
        """Store the style parameters.

        symbol -- symbol function or a changelist of symbol functions
        size -- symbol size (a length string or a changelist of those)
        errorscale -- size of the errorbar caps relative to the symbol size
        symbolattrs, errorbarattrs, lineattrs -- attributes added to the
            defaults; None switches the corresponding part off
        epsilon -- tolerance when checking whether a point is inside the graph
        """
        # the attribute lists are never modified in place (only concatenated
        # in selectstyle), so the shared list defaults are safe; None has to
        # stay distinguishable as it disables the corresponding part
        self.size_str = size
        self.symbol = symbol
        self.errorscale = errorscale
        self.symbolattrs = symbolattrs
        self.errorbarattrs = errorbarattrs
        self.lineattrs = lineattrs
        self.epsilon = epsilon

    def setdata(self, graph, columns, data):
        """Analyse the column information and store it in data.

        - the instance should be considered read-only
          (it might be shared between several data)
        - data is the place where to store information
        - returns the dictionary of columns not used by the style

        Raises ValueError for inconsistent axes, missing columns, multiple
        errorbar definitions or errorbar deltas without a start value."""
        # nested index dictionary containing column numbers, e.g.
        # data.index["x"]["x"], data.index["y"]["dmin"] etc.; the first key
        # is an axis name (without the axis number), the second is one of
        # the datanames "x", "min", "max", "d", "dmin", "dmax"
        data.index = {}
        # mapping from axis name (without axis number) to the axis
        data.axes = {}

        columns = columns.copy()
        for axisname in graph.axisnames:
            for dataname, pattern in [("x", re.compile(r"(%s([2-9]|[1-9][0-9]+)?)$" % axisname)),
                                      ("min", re.compile(r"(%s([2-9]|[1-9][0-9]+)?)min$" % axisname)),
                                      ("max", re.compile(r"(%s([2-9]|[1-9][0-9]+)?)max$" % axisname)),
                                      ("d", re.compile(r"d(%s([2-9]|[1-9][0-9]+)?)$" % axisname)),
                                      ("dmin", re.compile(r"d(%s([2-9]|[1-9][0-9]+)?)min$" % axisname)),
                                      ("dmax", re.compile(r"d(%s([2-9]|[1-9][0-9]+)?)max$" % axisname))]:
                matchresult = self.setdatapattern(graph, columns, pattern)
                if matchresult is not None:
                    axis, index = matchresult
                    if axisname in data.axes:
                        if data.axes[axisname] != axis:
                            raise ValueError("axis mismatch for axis name '%s'" % axisname)
                        data.index[axisname][dataname] = index
                    else:
                        data.index[axisname] = {dataname: index}
                        data.axes[axisname] = axis
            if axisname not in data.axes:
                raise ValueError("missing columns for axis name '%s'" % axisname)
            indices = data.index[axisname]
            # each errorbar end may be defined only once
            for first, second in (("min", "d"), ("min", "dmin"), ("d", "dmin"),
                                  ("max", "d"), ("max", "dmax"), ("d", "dmax")):
                if first in indices and second in indices:
                    raise ValueError("multiple errorbar definition for axis name '%s'" % axisname)
            # deltas are relative to the point value
            if "x" not in indices and ("d" in indices or "dmin" in indices or "dmax" in indices):
                raise ValueError("errorbar definition start value missing for axis name '%s'" % axisname)
        return columns

    def _selectattrs(self, defaultattrs, attrs, selectindex, selecttotal):
        """Select from defaultattrs + attrs; None (part switched off) stays None."""
        if attrs is None:
            return None
        return attr.selectattrs(defaultattrs + attrs, selectindex, selecttotal)

    def selectstyle(self, selectindex, selecttotal, data):
        """Select the symbol, size and attributes for data set selectindex."""
        data.symbol = attr.selectattr(self.symbol, selectindex, selecttotal)
        data.size_pt = unit.topt(unit.length(attr.selectattr(self.size_str, selectindex, selecttotal), default_type="v"))
        data.errorsize_pt = self.errorscale * data.size_pt
        data.symbolattrs = self._selectattrs(self.defaultsymbolattrs, self.symbolattrs, selectindex, selecttotal)
        data.errorbarattrs = self._selectattrs(self.defaulterrorbarattrs, self.errorbarattrs, selectindex, selecttotal)
        data.lineattrs = self._selectattrs(self.defaultlineattrs, self.lineattrs, selectindex, selecttotal)

    def adjustaxes(self, columns, data):
        """Adjust the ranges of the axes using the data columns given in columns."""
        # reverse lookup for axisnames
        # TODO: the reverse lookup is ugly
        axisnames = []
        for column in columns:
            for axisname in data.index.keys():
                for thiscolumn in data.index[axisname].values():
                    if thiscolumn == column and axisname not in axisnames:
                        axisnames.append(axisname)
        # TODO: perform check to verify that all columns for a given axisname are available at the same time
        for axisname in axisnames:
            indices = data.index[axisname]
            axis = data.axes[axisname]
            for dataname in ("x", "min", "max"):
                if dataname in indices:
                    axis.adjustrange(data.points, indices[dataname])
            if "d" in indices:
                axis.adjustrange(data.points, indices["x"], deltaindex=indices["d"])
            if "dmin" in indices:
                axis.adjustrange(data.points, indices["x"], deltaminindex=indices["dmin"])
            if "dmax" in indices:
                axis.adjustrange(data.points, indices["x"], deltamaxindex=indices["dmax"])

    def drawsymbol_pt(self, c, x_pt, y_pt, data, point=None):
        """Draw the selected symbol at (x_pt, y_pt) unless symbols are switched off."""
        if data.symbolattrs is not None:
            c.draw(path.path(*data.symbol(self, x_pt, y_pt, data.size_pt)), data.symbolattrs)

    def _interpolate(self, vstartpos, vendpos, fraction):
        """Return the graph coordinates at fraction of the way from vstartpos to vendpos."""
        return [vstart + (vend - vstart) * fraction for vstart, vend in zip(vstartpos, vendpos)]

    def _addlinebasepoints(self, data, linebasepoints):
        """Append a polyline through linebasepoints to data.path (needs at least two points)."""
        if len(linebasepoints) > 1:
            data.path.append(path.moveto_pt(*linebasepoints[0]))
            if len(linebasepoints) > 2:
                data.path.append(path.multilineto_pt(linebasepoints[1:]))
            else:
                data.path.append(path.lineto_pt(*linebasepoints[1]))

    def _firstvalue(self, calculations):
        """Return the result of the first calculation not raising an exception, or None."""
        for calculation in calculations:
            try:
                return calculation()
            except Exception:
                pass
        return None

    def _errorrange(self, point, errorindex):
        """Return the (minimum, maximum) data values of an errorbar of point.

        The minimum is taken from x-d, x-dmin or min and the maximum from
        x+d, x+dmax or max, using the first one which can be evaluated;
        None stands for a value which is not available."""
        def value(dataname):
            return point[errorindex[dataname]]
        minvalue = self._firstvalue([lambda: value("x") - value("d"),
                                     lambda: value("x") - value("dmin"),
                                     lambda: value("min")])
        maxvalue = self._firstvalue([lambda: value("x") + value("d"),
                                     lambda: value("x") + value("dmax"),
                                     lambda: value("max")])
        return minvalue, maxvalue

    def _tryconvert(self, axis, value):
        """Convert value to a graph coordinate of axis; None when this fails."""
        try:
            return axis.convert(value)
        except Exception:
            return None

    def _errorbarend(self, vpos, erroraxisindex, v, vcaps):
        """Return the graph coordinates of one end of an errorbar.

        v is the graph coordinate of the errorbar end (None if unavailable).
        Ends inside the graph get a cap (their position is appended to
        vcaps); ends outside are clipped to the nearest graph border.
        Without v the point itself is used; None is returned when the
        point's coordinate along the errorbar is unavailable as well."""
        if v is None:
            if vpos[erroraxisindex] is None:
                return None
            return vpos
        endpos = vpos[:]
        if -self.epsilon < v < 1 + self.epsilon:
            endpos[erroraxisindex] = v
            vcaps.append(endpos)
        elif v < 0:
            endpos[erroraxisindex] = 0
        else:
            endpos[erroraxisindex] = 1
        return endpos

    def _drawerrorbars(self, graph, data, point, vpos, errorlist, errorbarcanvas):
        """Stroke the errorbars of a single point into errorbarcanvas.

        vpos are the graph coordinates of the point (None for components
        which could not be converted); errorlist contains tuples
        (axisname, position in vpos) of all axes having errorbar columns."""
        for erroraxisname, erroraxisindex in errorlist:
            # all point components except the errorbar direction must be valid
            othersvalid = 1
            i = 0
            for v in vpos:
                if v is None and i != erroraxisindex:
                    othersvalid = 0
                i += 1
            if not othersvalid:
                continue

            # calculate vmin and vmax
            minvalue, maxvalue = self._errorrange(point, data.index[erroraxisname])
            vmin = self._tryconvert(data.axes[erroraxisname], minvalue)
            vmax = self._tryconvert(data.axes[erroraxisname], maxvalue)

            # create vminpos and vmaxpos
            vcaps = []
            vminpos = self._errorbarend(vpos, erroraxisindex, vmin, vcaps)
            vmaxpos = self._errorbarend(vpos, erroraxisindex, vmax, vcaps)
            if vminpos is None or vmaxpos is None:
                continue
            # nothing visible to draw, e.g. when the whole errorbar lies
            # beyond the same graph border (both ends clipped to it)
            if not vcaps and vminpos[erroraxisindex] == vmaxpos[erroraxisindex]:
                continue

            # create path for errorbars
            errorpath = path.path()
            errorpath += graph.vgeodesic(*(vminpos + vmaxpos))
            for vcap in vcaps:
                for axisname in graph.axisnames:
                    if axisname != erroraxisname:
                        errorpath += graph.vcap_pt(axisname, data.errorsize_pt, *vcap)

            # stroke errorpath
            if len(errorpath.path):
                errorbarcanvas.stroke(errorpath, data.errorbarattrs)

    def drawpoints(self, graph, data):
        """Draw symbols, errorbars and the connecting line of all points.

        The line is cut at the graph borders; it is interrupted at points
        which can't be converted to graph coordinates."""
        if data.lineattrs is not None:
            # TODO: bbox shortcut
            linecanvas = graph.insert(canvas.canvas())
        errorbarcanvas = None
        if data.errorbarattrs is not None:
            # TODO: bbox shortcut
            errorbarcanvas = graph.insert(canvas.canvas())
        data.path = path.path()
        linebasepoints = [] # points of the line segment currently built
        lastvpos = None # graph coordinates of the previous valid point
        # (axisname, position in vpos) of the axes having errorbar columns
        errorlist = []
        if data.errorbarattrs is not None:
            axisindex = 0
            for axisname in graph.axisnames:
                if [dataname for dataname in data.index[axisname].keys() if dataname != "x"]:
                    errorlist.append((axisname, axisindex))
                axisindex += 1

        for point in data.points:
            # calculate vpos
            vpos = [] # list containing the graph coordinates of the point
            validvpos = 1 # valid position (but might be outside of the graph)
            drawsymbol = 1 # valid position inside the graph
            for axisname in graph.axisnames:
                try:
                    v = data.axes[axisname].convert(point[data.index[axisname]["x"]])
                except Exception:
                    validvpos = 0
                    drawsymbol = 0
                    vpos.append(None)
                else:
                    if v < - self.epsilon or v > 1 + self.epsilon:
                        drawsymbol = 0
                    vpos.append(v)

            # draw symbol
            if drawsymbol:
                xpos, ypos = graph.vpos_pt(*vpos)
                self.drawsymbol_pt(graph, xpos, ypos, data, point=point)

            # append linebasepoints
            if validvpos:
                if linebasepoints:
                    # the last point was inside the graph
                    if drawsymbol:
                        linebasepoints.append((xpos, ypos))
                    else:
                        # cut end
                        cut = 1
                        for vstart, vend in zip(lastvpos, vpos):
                            newcut = None
                            if vend > 1:
                                # 1 = vstart + (vend - vstart) * cut
                                newcut = (1 - vstart)/(vend - vstart)
                            if vend < 0:
                                # 0 = vstart + (vend - vstart) * cut
                                newcut = - vstart/(vend - vstart)
                            if newcut is not None and newcut < cut:
                                cut = newcut
                        linebasepoints.append(graph.vpos_pt(*self._interpolate(lastvpos, vpos, cut)))
                        validvpos = 0 # clear linebasepoints below
                elif lastvpos is not None:
                    # the last point was outside the graph
                    if drawsymbol:
                        # cut beginning
                        cut = 0
                        for vstart, vend in zip(lastvpos, vpos):
                            newcut = None
                            if vstart > 1:
                                # 1 = vstart + (vend - vstart) * cut
                                newcut = (1 - vstart)/(vend - vstart)
                            if vstart < 0:
                                # 0 = vstart + (vend - vstart) * cut
                                newcut = - vstart/(vend - vstart)
                            if newcut is not None and newcut > cut:
                                cut = newcut
                        linebasepoints.append(graph.vpos_pt(*self._interpolate(lastvpos, vpos, cut)))
                        linebasepoints.append(graph.vpos_pt(*vpos))
                    else:
                        # sometimes cut beginning and end
                        cutfrom = 0
                        cutto = 1
                        for vstart, vend in zip(lastvpos, vpos):
                            newcutfrom = None
                            if vstart > 1:
                                # 1 = vstart + (vend - vstart) * cutfrom
                                newcutfrom = (1 - vstart)/(vend - vstart)
                            if vstart < 0:
                                # 0 = vstart + (vend - vstart) * cutfrom
                                newcutfrom = - vstart/(vend - vstart)
                            if newcutfrom is not None and newcutfrom > cutfrom:
                                cutfrom = newcutfrom
                            newcutto = None
                            if vend > 1:
                                # 1 = vstart + (vend - vstart) * cutto
                                newcutto = (1 - vstart)/(vend - vstart)
                            if vend < 0:
                                # 0 = vstart + (vend - vstart) * cutto
                                newcutto = - vstart/(vend - vstart)
                            if newcutto is not None and newcutto < cutto:
                                cutto = newcutto
                        if cutfrom < cutto:
                            linebasepoints.append(graph.vpos_pt(*self._interpolate(lastvpos, vpos, cutfrom)))
                            linebasepoints.append(graph.vpos_pt(*self._interpolate(lastvpos, vpos, cutto)))
                            validvpos = 0 # clear linebasepoints below
                lastvpos = vpos
            else:
                lastvpos = None

            if not validvpos:
                # the current line segment ends here
                self._addlinebasepoints(data, linebasepoints)
                linebasepoints = []

            # errorbar loop over the different directions having errorbars
            self._drawerrorbars(graph, data, point, vpos, errorlist, errorbarcanvas)

        # add the remaining linebasepoints to data.path
        self._addlinebasepoints(data, linebasepoints)

        # stroke data.path
        if data.lineattrs is not None:
            linecanvas.stroke(data.path, data.lineattrs)

    def key_pt(self, c, x_pt, y_pt, width_pt, height_pt, data):
        """Draw the graph key: the symbol centred on a horizontal line."""
        self.drawsymbol_pt(c, x_pt+0.5*width_pt, y_pt+0.5*height_pt, data)
        if data.lineattrs is not None:
            c.stroke(path.line_pt(x_pt, y_pt+0.5*height_pt, x_pt+width_pt, y_pt+0.5*height_pt), data.lineattrs)
class symbol(symbolline):
    """Symbol style: a symbolline without the connecting line."""

    def __init__(self, *args, **kwargs):
        # positional and keyword arguments are handed on to symbolline
        # unchanged, while lineattrs is fixed to None to switch the line off
        initargs = (self,) + args
        symbolline.__init__(lineattrs=None, *initargs, **kwargs)
class line(symbolline):
    """Line style: a symbolline drawing neither symbols nor errorbars."""

    def __init__(self, lineattrs=[]):
        # symbols and errorbars are switched off by passing None
        symbolline.__init__(self, lineattrs=lineattrs,
                                  errorbarattrs=None,
                                  symbolattrs=None)
class text(symbol):
    """Symbol style additionally writing the content of a "text" column.

    The text is placed at the offset (textdx, textdy) relative to the
    symbol; textattrs=None switches the text off. Remaining keyword
    arguments are passed to symbol."""

    defaulttextattrs = [textmodule.halign.center, textmodule.vshift.mathaxis]

    def __init__(self, textdx="0", textdy="0.3 cm", textattrs=[], **kwargs):
        self.textdx_str = textdx
        self.textdy_str = textdy
        self.textattrs = textattrs
        symbol.__init__(self, **kwargs)

    def setdata(self, graph, columns, data):
        """Take the "text" column and leave the remaining ones to symbol.setdata."""
        columns = columns.copy()
        data.textindex = columns["text"]
        del columns["text"]
        return symbol.setdata(self, graph, columns, data)

    def selectstyle(self, selectindex, selecttotal, data):
        if self.textattrs is not None:
            data.textattrs = attr.selectattrs(self.defaulttextattrs + self.textattrs, selectindex, selecttotal)
        else:
            data.textattrs = None
        symbol.selectstyle(self, selectindex, selecttotal, data)

    def drawsymbol_pt(self, c, x, y, data, point=None):
        """Draw the symbol and, for a data point, the text next to it."""
        symbol.drawsymbol_pt(self, c, x, y, data, point)
        # point is None when no data point is given (e.g. when drawing the
        # graph key) -- there is no text to write then
        if point is None or data.textattrs is None:
            return
        if None not in (x, y, point[data.textindex]):
            c.text_pt(x + data.textdx_pt, y + data.textdy_pt, str(point[data.textindex]), data.textattrs)

    def drawpoints(self, graph, data):
        # the text offsets are needed by drawsymbol_pt
        data.textdx = unit.length(self.textdx_str, default_type="v")
        data.textdy = unit.length(self.textdy_str, default_type="v")
        data.textdx_pt = unit.topt(data.textdx)
        data.textdy_pt = unit.topt(data.textdy)
        symbol.drawpoints(self, graph, data)
class arrow(_style):
    """Style drawing an arrow centred at each point of a two-dimensional graph.

    Besides the x and y columns the data must provide the columns "size"
    (scaling both the line length and the arrow size) and "angle" (in
    degrees). lineattrs=None or arrowattrs=None switches the arrows off."""

    defaultlineattrs = []
    defaultarrowattrs = []

    def __init__(self, linelength="0.25 cm", arrowsize="0.15 cm", lineattrs=[], arrowattrs=[], epsilon=1e-10):
        self.linelength_str = linelength
        self.arrowsize_str = arrowsize
        self.lineattrs = lineattrs
        self.arrowattrs = arrowattrs
        self.epsilon = epsilon

    def setdata(self, graph, columns, data):
        """Take the x, y, "size" and "angle" columns; return the remaining ones."""
        if len(graph.axisnames) != 2:
            raise TypeError("arrow style restricted on two-dimensional graphs")
        columns = columns.copy()
        data.xaxis, data.xindex = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)$" % graph.axisnames[0]))
        data.yaxis, data.yindex = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)$" % graph.axisnames[1]))
        data.sizeindex = columns["size"]
        del columns["size"]
        data.angleindex = columns["angle"]
        del columns["angle"]
        return columns

    def adjustaxes(self, columns, data):
        if data.xindex in columns:
            data.xaxis.adjustrange(data.points, data.xindex)
        if data.yindex in columns:
            data.yaxis.adjustrange(data.points, data.yindex)

    def selectstyle(self, selectindex, selecttotal, data):
        if self.lineattrs is not None:
            data.lineattrs = attr.selectattrs(self.defaultlineattrs + self.lineattrs, selectindex, selecttotal)
        else:
            data.lineattrs = None
        if self.arrowattrs is not None:
            data.arrowattrs = attr.selectattrs(self.defaultarrowattrs + self.arrowattrs, selectindex, selecttotal)
        else:
            data.arrowattrs = None

    def drawpoints(self, graph, data):
        """Stroke one arrow per point; points with size <= epsilon are skipped."""
        if data.lineattrs is None or data.arrowattrs is None:
            return
        # the arrow size is passed to deco.earrow as a length, the line
        # length is needed in pts
        arrowsize = unit.length(self.arrowsize_str, default_type="v")
        linelength_pt = unit.topt(unit.length(self.linelength_str, default_type="v"))
        for point in data.points:
            xpos, ypos = graph.pos_pt(point[data.xindex], point[data.yindex], xaxis=data.xaxis, yaxis=data.yaxis)
            size = point[data.sizeindex]
            if size > self.epsilon:
                angle = math.radians(point[data.angleindex])
                # the line extends by half its length to each side of the point
                dx_pt = 0.5*math.cos(angle)*linelength_pt*size
                dy_pt = 0.5*math.sin(angle)*linelength_pt*size
                graph.stroke(path.line_pt(xpos-dx_pt, ypos-dy_pt, xpos+dx_pt, ypos+dy_pt),
                             data.lineattrs + [deco.earrow(data.arrowattrs, size=arrowsize*size)])
class rect(_style):
    """Style filling rectangles with colors taken from a palette.

    The rectangles are given by the columns "<x>min", "<x>max", "<y>min"
    and "<y>max" of a two-dimensional graph; the column "color" is passed
    to palette.getcolor."""

    def __init__(self, palette=color.palette.Gray):
        self.palette = palette

    def setdata(self, graph, columns, data):
        """Take the min/max and "color" columns; return the remaining ones."""
        if len(graph.axisnames) != 2:
            raise TypeError("rect style restricted on two-dimensional graphs")
        columns = columns.copy()
        data.xaxis, data.xminindex = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)min$" % graph.axisnames[0]))
        data.yaxis, data.yminindex = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)min$" % graph.axisnames[1]))
        xaxis, data.xmaxindex = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)max$" % graph.axisnames[0]))
        yaxis, data.ymaxindex = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)max$" % graph.axisnames[1]))
        if xaxis != data.xaxis or yaxis != data.yaxis:
            raise ValueError("min/max values should use the same axes")
        data.colorindex = columns["color"]
        del columns["color"]
        return columns

    def selectstyle(self, selectindex, selecttotal, data):
        pass

    def adjustaxes(self, columns, data):
        if data.xminindex in columns:
            data.xaxis.adjustrange(data.points, data.xminindex)
        if data.xmaxindex in columns:
            data.xaxis.adjustrange(data.points, data.xmaxindex)
        if data.yminindex in columns:
            data.yaxis.adjustrange(data.points, data.yminindex)
        if data.ymaxindex in columns:
            data.yaxis.adjustrange(data.points, data.ymaxindex)

    def drawpoints(self, graph, data):
        """Fill the rectangles clipped to the graph; invalid points are skipped."""
        # TODO: bbox shortcut
        c = graph.insert(canvas.canvas())
        # color value the canvas was last set to; the palette lookup and
        # c.set are only done when the color value changes
        lastcolorvalue = None
        for point in data.points:
            try:
                xvmin = data.xaxis.convert(point[data.xminindex])
                xvmax = data.xaxis.convert(point[data.xmaxindex])
                yvmin = data.yaxis.convert(point[data.yminindex])
                yvmax = data.yaxis.convert(point[data.ymaxindex])
                colorvalue = point[data.colorindex]
                if colorvalue != lastcolorvalue:
                    fillcolor = self.palette.getcolor(colorvalue)
            except Exception:
                continue
            # skip rectangles completely outside of the graph
            if ((xvmin < 0 and xvmax < 0) or (xvmin > 1 and xvmax > 1) or
                (yvmin < 0 and yvmax < 0) or (yvmin > 1 and yvmax > 1)):
                continue
            # clip the rectangle at the graph borders
            xvmin = min(max(xvmin, 0), 1)
            xvmax = min(max(xvmax, 0), 1)
            yvmin = min(max(yvmin, 0), 1)
            yvmax = min(max(yvmax, 0), 1)
            p = graph.vgeodesic(xvmin, yvmin, xvmax, yvmin)
            p.append(graph.vgeodesic_el(xvmax, yvmin, xvmax, yvmax))
            p.append(graph.vgeodesic_el(xvmax, yvmax, xvmin, yvmax))
            p.append(graph.vgeodesic_el(xvmin, yvmax, xvmin, yvmin))
            p.append(path.closepath())
            if colorvalue != lastcolorvalue:
                c.set([fillcolor])
                lastcolorvalue = colorvalue
            c.fill(p)
class bar(_style):
    """Bar style for two-dimensional graphs.

    Exactly one axis is the value axis (column "<axisname>"), the other one
    the name axis (column "<axisname>name"). Further columns
    "<valueaxisname>stack1", "...stack2", ... are stacked on top of the bar.
    barattrs=None switches the bars off, frompathattrs=None the line drawn
    at fromvalue."""

    defaultfrompathattrs = []
    defaultbarattrs = [color.palette.Rainbow, deco.stroked([color.gray.black])]

    def __init__(self, fromvalue=None, frompathattrs=[], barattrs=[], subnames=None, epsilon=1e-10):
        self.fromvalue = fromvalue
        self.frompathattrs = frompathattrs
        self.barattrs = barattrs
        self.subnames = subnames
        self.epsilon = epsilon

    def setdata(self, graph, columns, data):
        """Take the value, name and stack columns; return the remaining ones."""
        # TODO: remove limitation to 2d graphs
        if len(graph.axisnames) != 2:
            raise TypeError("bar style currently restricted on two-dimensional graphs")
        columns = columns.copy()
        xvalue = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)$" % graph.axisnames[0]))
        yvalue = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)$" % graph.axisnames[1]))
        if (xvalue is None) == (yvalue is None):
            raise TypeError("must specify exactly one value axis")
        if xvalue is not None:
            data.valuepos = 0
            nameaxisname = graph.axisnames[1]
            valueaxis, valueindex = xvalue
        else:
            data.valuepos = 1
            nameaxisname = graph.axisnames[0]
            valueaxis, valueindex = yvalue
        data.nameaxis, data.nameindex = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)name$" % nameaxisname))
        data.valueaxis = valueaxis
        data.valueindices = [valueindex]
        # collect the stacked values until no further stack column exists
        valueaxisname = graph.axisnames[data.valuepos]
        i = 1
        while 1:
            stackvalue = _style.setdatapattern(self, graph, columns, re.compile(r"(%s([2-9]|[1-9][0-9]+)?)stack%i$" % (valueaxisname, i)))
            if stackvalue is None:
                break
            stackaxis, stackindex = stackvalue
            if data.valueaxis != stackaxis:
                raise ValueError("different value axes for stacked bars")
            data.valueindices.append(stackindex)
            i += 1
        return columns

    def selectstyle(self, selectindex, selecttotal, data):
        # the from path is drawn for the first data set only
        if selectindex or self.frompathattrs is None:
            data.frompathattrs = None
        else:
            data.frompathattrs = self.defaultfrompathattrs + self.frompathattrs
        if self.barattrs is None:
            data.barattrs = None
        elif selecttotal > 1:
            data.barattrs = attr.selectattrs(self.defaultbarattrs + self.barattrs, selectindex, selecttotal)
        else:
            data.barattrs = self.defaultbarattrs + self.barattrs
        data.selectindex = selectindex
        data.selecttotal = selecttotal
        if data.selecttotal != 1 and self.subnames is not None:
            raise ValueError("subnames not allowed when iterating over bars")

    def adjustaxes(self, columns, data):
        if data.nameindex in columns:
            if data.selecttotal == 1:
                data.nameaxis.adjustrange(data.points, data.nameindex, subnames=self.subnames)
            else:
                for i in range(data.selecttotal):
                    data.nameaxis.adjustrange(data.points, data.nameindex, subnames=[i])
        for valueindex in data.valueindices:
            if valueindex in columns:
                data.valueaxis.adjustrange(data.points, valueindex)

    def _vnamerange(self, data, point):
        """Return (vnamemin, vnamemax) of the bar of point on the name axis, or None."""
        # when iterating over several data sets the bars are sub-divided by selectindex
        if data.selecttotal == 1:
            subname = ()
        else:
            subname = (data.selectindex,)
        try:
            vnamemin = data.nameaxis.convert((point[data.nameindex],) + subname + (0,))
            vnamemax = data.nameaxis.convert((point[data.nameindex],) + subname + (1,))
        except Exception:
            return None
        return vnamemin, vnamemax

    def drawpoints(self, graph, data):
        """Draw the from path and fill the (stacked) bars of all points."""
        if self.fromvalue is not None:
            vfromvalue = data.valueaxis.convert(self.fromvalue)
            if vfromvalue < -self.epsilon:
                vfromvalue = 0
            if vfromvalue > 1 + self.epsilon:
                vfromvalue = 1
            if data.frompathattrs is not None and self.epsilon < vfromvalue < 1 - self.epsilon:
                if data.valuepos:
                    p = graph.vgeodesic(0, vfromvalue, 1, vfromvalue)
                else:
                    p = graph.vgeodesic(vfromvalue, 0, vfromvalue, 1)
                graph.stroke(p, data.frompathattrs)
        else:
            vfromvalue = 0
        l = len(data.valueindices)
        if data.barattrs is None:
            barattrslist = [None] * l
        elif l > 1:
            barattrslist = [attr.selectattrs(data.barattrs, i, l) for i in range(l)]
        else:
            barattrslist = [data.barattrs]
        for point in data.points:
            # the name range is the same for all stacked values of the point
            vnamerange = self._vnamerange(data, point)
            if vnamerange is None:
                continue
            vnamemin, vnamemax = vnamerange
            vvaluemax = vfromvalue
            for valueindex, barattrs in zip(data.valueindices, barattrslist):
                # each stacked bar starts where the previous one ended
                vvaluemin = vvaluemax
                try:
                    vvaluemax = data.valueaxis.convert(point[valueindex])
                except Exception:
                    continue
                if barattrs is None:
                    continue
                if data.valuepos:
                    p = graph.vgeodesic(vnamemin, vvaluemin, vnamemin, vvaluemax)
                    p.append(graph.vgeodesic_el(vnamemin, vvaluemax, vnamemax, vvaluemax))
                    p.append(graph.vgeodesic_el(vnamemax, vvaluemax, vnamemax, vvaluemin))
                    p.append(graph.vgeodesic_el(vnamemax, vvaluemin, vnamemin, vvaluemin))
                    p.append(path.closepath())
                else:
                    p = graph.vgeodesic(vvaluemin, vnamemin, vvaluemin, vnamemax)
                    p.append(graph.vgeodesic_el(vvaluemin, vnamemax, vvaluemax, vnamemax))
                    p.append(graph.vgeodesic_el(vvaluemax, vnamemax, vvaluemax, vnamemin))
                    p.append(graph.vgeodesic_el(vvaluemax, vnamemin, vvaluemin, vnamemin))
                    p.append(path.closepath())
                graph.fill(p, barattrs)

    def key_pt(self, c, x_pt, y_pt, width_pt, height_pt, data):
        """Draw the graph key: one filled box per stacked value."""
        if data.barattrs is None:
            # bars are switched off
            return
        l = len(data.valueindices)
        if l > 1:
            for i in range(l):
                c.fill(path.rect_pt(x_pt+i*width_pt/l, y_pt, width_pt/l, height_pt), attr.selectattrs(data.barattrs, i, l))
        else:
            c.fill(path.rect_pt(x_pt, y_pt, width_pt, height_pt), data.barattrs)