added SQLTable pickle test
[pygr.git] / pygr / schema.py
blob547d2506cf24b4e4818230bde6c75bdb8e510a7f
2 import types
4 # STORES DICTIONARY OF ATTRIBUTE-BOUND GRAPHS
5 # AND LIST OF UNBOUND GRAPHS
class SchemaDict(dict):
    """Container for schema rules bound to a class or object.

    A rule is a sequence whose first item is a graph and whose optional
    second item is the name of the attribute that graph is bound to.
    Rules are stored in two indexes for fast access: the dict itself maps
    each graph to the list of rules that mention it, and ``attrs`` maps
    each attribute name to its single rule.
    Use += and -= to add or remove rules.
    """

    # types.StringType only exists on Python 2; fall back to str elsewhere.
    _string_type = getattr(types, 'StringType', str)

    def __init__(self, newlist=(), baselist=()):
        """Initialize schema from base classes and a list of new rules.

        baselist: objects whose ``__class_schema__`` (if present) is
            inherited; later entries replace earlier ones per graph.
        newlist: rules added on top; they override inherited attribute
            bindings of the same name.
        """
        self.attrs = {}
        dict.__init__(self)
        # COMBINE SCHEMAS FROM PARENTS WITH NEW SCHEMA LIST
        for b in baselist:
            if hasattr(b, '__class_schema__'):
                base = b.__class_schema__
                for g, rules in base.items():
                    # Copy each rule list: sharing the parent's list would
                    # let += / -= on this schema silently edit the parent.
                    self[g] = list(rules)
                self.attrs.update(base.attrs)
        for i in newlist:  # newlist OVERRIDES OLD DEFS FROM baselist
            self += i

    def __iadd__(self, i):
        """Add a schema rule to this SchemaDict.

        Raises TypeError if the rule names an attribute that is not a
        string.  An existing rule for the same attribute is removed first.
        """
        g = i[0]
        if len(i) >= 2:
            name = i[1]
            if not isinstance(name, self._string_type):
                raise TypeError('Attribute name must be a string')
            if name in self.attrs:  # REMOVE OLD ENTRY
                self -= self.attrs[name]
            self.attrs[name] = i  # SAVE IN INDEX ACCORDING TO ATTR NAME
        self.setdefault(g, []).append(i)  # SAVE IN GRAPH INDEX
        return self  # REQUIRED FROM iadd()!!

    def __isub__(self, i):
        """Remove a schema rule from this SchemaDict.

        Raises KeyError if the graph or attribute is unknown, ValueError if
        the rule is not stored under its graph, and TypeError if the
        attribute name is not a string.  All checks run before anything is
        modified, so a failed removal leaves both indexes unchanged.
        """
        g = i[0]
        if g not in self:
            raise KeyError('graph not found in SchemaDict!')
        rules = self[g]
        if i not in rules:
            raise ValueError('rule not found in SchemaDict!')
        name = None
        if len(i) >= 2:
            name = i[1]
            if not isinstance(name, self._string_type):
                raise TypeError('Attribute name must be a string')
            if name not in self.attrs:
                raise KeyError('attribute not found in SchemaDict!')
        rules.remove(i)  # REMOVE OLD ENTRY
        if not rules:  # REMOVE EMPTY LIST
            del self[g]
        if name is not None:
            del self.attrs[name]  # REMOVE OLD ENTRY
        return self  # REQUIRED FROM isub()!!

    def initInstance(self, obj):
        "Add obj as new node to all graphs referenced by this SchemaDict."
        for g, l in self.items():  # GET ALL OUR RULES
            for s in l:
                if obj not in g:
                    g.__iadd__(obj, (s,))  # ADD OBJECT TO GRAPH USING RULE s

    def getschema(self, attr=None, graph=None):
        """Return list of schema rules that match attr / graph arguments.

        attr takes precedence over graph.  Returns an empty list when
        nothing matches; raises ValueError if neither argument is given.
        """
        if attr:
            if attr in self.attrs:
                return [self.attrs[attr]]
        elif graph:
            if graph in self:
                return self[graph]
        else:
            raise ValueError('You must specify an attribute or graph.')
        return []  # DIDN'T FIND ANYTHING
class SchemaList(list):
    """Temporary list of schema rules returned for an object.

    Remembers the object the rules describe so that ``+=`` on the list
    can attach a new rule to that object's own ``__schema__``.
    """

    def __init__(self, obj):
        super(SchemaList, self).__init__()  # START AS AN EMPTY LIST
        self.obj = obj  # OBJECT WHOSE SCHEMA THIS LIST DESCRIBES

    def __iadd__(self, rule):
        "Attach rule to the described object's __schema__, creating it if absent."
        target = self.obj
        if not hasattr(target, '__schema__'):
            target.__schema__ = SchemaDict()
        target.__schema__ += rule
        return self  # IN-PLACE OPERATORS MUST RETURN THE RESULT

    # TODO: PROBABLY NEED AN __isub__() TOO??
94 ######################
95 # getschema, getnodes, getedges
96 # these functions are analogous to getattr, except they get graph information
def getschema(o, attr=None, graph=None):
    """Get list of schema rules for object o that match attr / graph arguments.

    Rules from the object's own ``__schema__`` come first.  Rules from its
    ``__class_schema__`` follow, except those bound to an attribute name
    that the object's own schema already binds.  When attr is given and
    the object's own schema matches it, the class schema is not consulted.
    Returns a SchemaList (possibly empty).
    """
    def bound_name(rule):
        "Return the attribute name a rule is bound to, or None if unbound."
        # One-element rules (unbound graphs) have no rule[1] to inspect.
        if len(rule) >= 2 and isinstance(rule[1], types.StringType):
            return rule[1]
        return None

    found = SchemaList(o)
    attrs = {}  # ATTRIBUTE NAMES ALREADY BOUND BY THE OBJECT ITSELF
    if hasattr(o, '__schema__'):
        for s in o.__schema__.getschema(attr, graph):
            found.append(s)
            name = bound_name(s)
            if name is not None:
                attrs[name] = None
    if attr and len(found) > 0:  # DON'T PROCEED
        return found
    if hasattr(o, '__class_schema__'):
        for s in o.__class_schema__.getschema(attr, graph):
            if bound_name(s) not in attrs:
                found.append(s)  # DON'T OVERWRITE OBJECT __schema__ BINDINGS
    return found
def setschema(o, attr, graph):
    """Bind object to graph, and if attr not None, also bind graph to this attribute.

    Creates ``o.__schema__`` when the object does not have one yet, then
    adds the rule ``(graph, attr)``, or ``(graph,)`` when attr is None.
    """
    if not hasattr(o, '__schema__'):
        o.__schema__ = SchemaDict()
    # A rule with None in the attribute slot is rejected as a non-string
    # attribute name, so an unbound graph is added as a one-element rule.
    if attr is None:
        rule = (graph,)
    else:
        rule = (graph, attr)
    o.__schema__ += rule
def getnodes(o, attr=None, graph=None):
    """Get destination nodes from graph bindings of o, limited to the
    specific attribute or graph if specified.

    With attr: look attr up in o.__schema__, then o.__class_schema__, and
    return the attribute of o named by item [2] of the entry found;
    raise AttributeError if neither holds attr.
    With graph only: return graph[o].
    With neither: raise ValueError.
    """
    if not attr:
        if graph:  # JUST LOOK UP THE GRAPH TRIVIALLY
            return graph[o]
        # SHOULD WE GET ALL NODES FROM ALL SCHEMA ENTRIES? HOW??
        raise ValueError('You must pass an attribute name or graph')
    # NOTE(review): the schema is indexed directly by attribute name here;
    # if it is a graph-keyed SchemaDict this probably wants its ``attrs``
    # index instead -- confirm the intended rule layout before changing.
    for holder in ('__schema__', '__class_schema__'):
        if hasattr(o, holder):
            schema = getattr(o, holder)
            if attr in schema:
                return getattr(o, schema[attr][2])  # RETURN THE PRODUCT
    raise AttributeError('No attribute named %s in object %s' % (attr, o))
def getedges(o, attr=None, graph=None):
    """Get edges from graph bindings of o, limited to the specific attribute
    or graph if specified.

    Returns the result of calling ``edges()`` on whatever getnodes() finds,
    or None when that result is falsy or has no ``edges`` attribute.
    """
    nodes = getnodes(o, attr, graph)  # SAME LOOKUP RULES AS getnodes
    if not nodes or not hasattr(nodes, 'edges'):
        return None
    return nodes.edges()