views: add a BranchCompareView class
[git-cola.git] / cola / model.py
blob 87af67fcd761e1151d021c9edcba6df150ab1844
1 #!/usr/bin/env python
2 # Copyright (c) 2008 David Aguilar
3 import os
4 import imp
5 from cStringIO import StringIO
6 from types import DictType
7 from types import ListType
8 from types import TupleType
9 from types import StringTypes
10 from types import BooleanType
11 from types import IntType
12 from types import LongType
13 from types import FloatType
14 from types import ComplexType
15 from types import InstanceType
class Observable(object):
    """Handles subject/observer notifications.

    An observer is any object with a notify(*param) method; it is called
    by notify_observers() while notifications are enabled.
    """

    def __init__(self, *args, **kwargs):
        # NOTE(review): args/kwargs are accepted but unused here, presumably
        # so subclasses can forward their own constructor arguments.
        self.__observers = []
        self.__notify = True

    def get_notify(self):
        """Return the current notification flag."""
        return self.__notify

    def set_notify(self, notify=True):
        """Enable (default) or disable notify_observers()."""
        self.__notify = notify

    def add_observer(self, observer):
        """Register observer; registering the same observer twice is a no-op."""
        if observer not in self.__observers:
            self.__observers.append(observer)

    def remove_observer(self, observer):
        """Unregister observer; unknown observers are ignored."""
        if observer in self.__observers:
            self.__observers.remove(observer)

    def notify_observers(self, *param):
        """Call observer.notify(*param) on every registered observer.

        Does nothing while notifications are disabled.  The observer list
        is snapshotted first, so an observer may add or remove observers
        from inside its notify() without causing others to be skipped;
        such changes take effect from the next notification onwards.
        """
        if not self.__notify:
            return
        for observer in list(self.__observers):
            observer.notify(*param)
class Model(Observable):
    """Creates a generic model object with params specified
    as a name:value dictionary.

    get_name() and set_name(value) are created automatically
    for any of the parameters specified in the **kwargs.  Params can be
    exported with to_dict() and re-imported with from_dict(); the
    "__class__" key in such dicts records which class to rebuild.
    """

    def __init__(self, *args, **kwargs):
        Observable.__init__(self)
        self.__params = []
        self.from_dict(kwargs)
        self.init()

    def init(self):
        """init() is called by the built-in constructor.
        Subclasses should implement this if necessary."""
        pass

    def create(self, **kwargs):
        """Load the keyword arguments as params (see from_dict())."""
        return self.from_dict(kwargs)

    def get_param_names(self):
        """Return the registered param names as a tuple."""
        return tuple(self.__params)

    def notify_all(self):
        """Notify observers about every param in a single call."""
        self.notify_observers(*self.get_param_names())

    def clone(self, *args, **kwargs):
        """Return a copy of this model built from to_dict().

        NOTE(review): to_dict() always includes "__class__", so from_dict()
        returns a fresh cls() instance and the instance constructed from
        *args/**kwargs is discarded.
        """
        return self.__class__(*args, **kwargs).from_dict(self.to_dict())

    def has_param(self, param):
        """Return True if param has been registered (exact name match)."""
        return param in self.__params

    def get_param(self, param):
        """Return the current value of param."""
        return getattr(self, param)

    def __getattr__(self, param):
        """Provides automatic get/set/add/append methods.

        Python only calls this when normal attribute lookup fails.  The
        prefix match is case-insensitive and a separator underscore is
        optional, so getFoo(), get_foo() and getfoo() all read 'foo'.
        add/append helpers are cached on the instance after first use.
        """
        # Base case: we actually have this param.  Read __dict__ directly;
        # calling getattr() here could re-enter __getattr__.
        if param in self.__dict__:
            return self.__dict__[param]

        # Check for the translated variant of the param (this is where a
        # cached add/append helper is found on later calls).
        realparam = self.__translate(param, sep='')
        if realparam in self.__dict__:
            return self.__dict__[realparam]

        if realparam.startswith("get"):
            param = self.__translate(param, "get")
            return lambda: getattr(self, param)

        elif realparam.startswith("set"):
            param = self.__translate(param, "set")
            return lambda v: self.set_param(param, v,
                                            check_params=True)

        elif realparam.startswith("add") or realparam.startswith("append"):
            if realparam.startswith("add"):
                param = self.__translate(realparam, "add")
            else:
                param = self.__translate(realparam, "append")

            def array_append(*values):
                array = getattr(self, param)
                if array is None:
                    classnm = self.__class__.__name__
                    errmsg = ("%s object has no array named '%s'"
                              % (classnm, param))
                    raise AttributeError(errmsg)
                else:
                    array.extend(values)
            # Cache the function definition
            setattr(self, realparam, array_append)
            return array_append

        errmsg = ("%s object has no parameter '%s'"
                  % (self.__class__.__name__, param))
        raise AttributeError(errmsg)

    def set_param(self, param, value, notify=True, check_params=False):
        """Set param with optional notification and validity checks.

        The name is lower-cased first.  With check_params=True an unknown
        name raises AttributeError; otherwise it is registered as a new
        param.  Observers receive the name unless notify is False.
        """
        param = param.lower()
        if check_params and param not in self.__params:
            raise AttributeError("Parameter '%s' not available for %s"
                                 % (param, self.__class__.__name__))
        elif param not in self.__params:
            self.__params.append(param)

        setattr(self, param, value)
        if notify:
            self.notify_observers(param)

    def copy_params(self, model, params=None):
        """Copy params from model into self (default: all of self's params)."""
        if params is None:
            params = self.get_param_names()
        for param in params:
            self.set_param(param, model.get_param(param))

    def __translate(self, param, prefix='', sep='_'):
        """Translates an param name from the external name
        used in methods to those used internally. The default
        settings strip off '_' so that both get_foo() and getFoo()
        are valid incantations."""
        return param[len(prefix):].lstrip(sep).lower()

    @staticmethod
    def __load_json(filename):
        """Decode and return the JSON document stored in filename."""
        import simplejson
        fh = open(filename, 'r')
        try:
            return simplejson.load(fh)
        finally:
            fh.close()

    def save(self, filename):
        """Write to_dict() to filename as indented JSON.

        Returns without writing when simplejson is unavailable.
        """
        if not has_json():
            return
        import simplejson
        fh = open(filename, 'w')
        try:
            simplejson.dump(self.to_dict(), fh, indent=4)
        finally:
            # Close even if serialization fails part-way.
            fh.close()

    def load(self, filename):
        """Load params from a JSON file into this instance.

        Returns self, or None when simplejson is unavailable.
        """
        if not has_json():
            return
        ddict = Model.__load_json(filename)
        if "__class__" in ddict:
            # load params in-place.
            del ddict["__class__"]
        return self.from_dict(ddict)

    @staticmethod
    def instance(filename):
        """Build and return a model from a JSON file.

        The class named by the file's "__class__" entry is used when
        present, otherwise a plain Model.  Returns None when simplejson
        is unavailable.
        """
        if not has_json():
            return
        ddict = Model.__load_json(filename)
        if "__class__" in ddict:
            cls = Model.str_to_class(ddict["__class__"])
            del ddict["__class__"]
            return cls().from_dict(ddict)
        else:
            return Model().from_dict(ddict)

    def from_dict(self, source_dict):
        """Import a complex model from a dictionary.

        When source_dict has a "__class__" entry, a new instance of that
        class is created, populated and returned.  Otherwise the params are
        loaded into self in-place (without notifying observers) and self is
        returned.  source_dict itself is never modified.
        """
        if "__class__" in source_dict:
            # Work on a copy so the caller's dict keeps its "__class__".
            source_dict = dict(source_dict)
            clsstr = source_dict.pop("__class__")
            cls = Model.str_to_class(clsstr)
            return cls().from_dict(source_dict)

        # Not initiating a clone: load parameters in-place
        for param, val in source_dict.iteritems():
            self.set_param(param,
                           self.__obj_from_value(val),
                           notify=False)
        self.__params.sort()
        return self

    def __obj_from_value(self, val):
        """Rebuild a param value produced by __obj_to_value()."""
        # Atoms
        if is_atom(val):
            return val

        # Possibly nested lists
        elif is_list(val):
            return [self.__obj_from_value(v) for v in val]

        elif is_dict(val):
            # A param that maps to a Model-object
            if "__class__" in val:
                # Copy so the caller's nested dict is left untouched.
                val = dict(val)
                cls = Model.str_to_class(val.pop("__class__"))
                return cls().from_dict(val)
            newdict = {}
            for k, v in val.iteritems():
                newdict[k] = self.__obj_from_value(v)
            return newdict

        # All others
        return val

    def to_dict(self):
        """Exports a model to a dictionary.
        This simplifies serialization.

        The result maps each param to its exported value, plus a
        "__class__" entry naming this object's class.
        """
        params = {"__class__": Model.class_to_str(self)}
        for param in self.__params:
            params[param] = self.__obj_to_value(getattr(self, param))
        return params

    def __obj_to_value(self, item):
        """Convert item into plain data (atoms, lists and dicts).

        Raises NotImplementedError for values of unsupported types.
        """
        if is_atom(item):
            return item

        elif is_list(item):
            return [self.__obj_to_value(i) for i in item]

        elif is_dict(item):
            newdict = {}
            for k, v in item.iteritems():
                newdict[k] = self.__obj_to_value(v)
            return newdict

        elif is_instance(item):
            return item.to_dict()

        else:
            raise NotImplementedError("Unknown type:" + str(type(item)))

    # Formatting state shared by all nested __str__() calls.
    __INDENT__ = 0
    __PREINDENT__ = True
    __STRSTACK__ = []

    @staticmethod
    def INDENT(i=0):
        """Add i to the shared indent depth and return that many tabs."""
        Model.__INDENT__ += i
        return '\t' * Model.__INDENT__

    def __str__(self):
        """A convenient, recursively-defined stringification method."""
        # This avoids infinite recursion on cyclical structures
        if self in Model.__STRSTACK__:
            return 'REFERENCE' # TODO: implement references?
        Model.__STRSTACK__.append(self)

        # If stringifying a nested value raises, put the shared formatting
        # state back and always pop self; otherwise later str() calls on
        # this object would print 'REFERENCE' forever.
        saved_indent = Model.__INDENT__
        saved_preindent = Model.__PREINDENT__
        try:
            try:
                return self.__str_body()
            except Exception:
                Model.__INDENT__ = saved_indent
                Model.__PREINDENT__ = saved_preindent
                raise
        finally:
            Model.__STRSTACK__.remove(self)

    def __str_body(self):
        """Render the text for __str__(); the caller manages __STRSTACK__."""
        io = StringIO()

        if Model.__PREINDENT__:
            io.write(Model.INDENT())

        io.write(self.__class__.__name__ + '(')

        Model.INDENT(1)

        for param in self.__params:
            if param.startswith('_'):
                continue
            io.write('\n')

            inner = Model.INDENT() + param + " = "
            value = getattr(self, param)

            if type(value) == ListType:
                indent = Model.INDENT(1)
                io.write(inner + "[\n")
                for val in value:
                    if is_model(val):
                        io.write(str(val) + '\n')
                    else:
                        io.write(indent)
                        io.write(str(val))
                        io.write(",\n")

                io.write(Model.INDENT(-1))
                io.write("],")
            else:
                # Nested models are written inline after "name = ",
                # so suppress their leading indent.
                Model.__PREINDENT__ = False
                io.write(inner)
                io.write(str(value))
                io.write(',')
                Model.__PREINDENT__ = True

        io.write('\n' + Model.INDENT(-1) + ')')
        value = io.getvalue()
        io.close()
        return value

    @staticmethod
    def str_to_class(clstr):
        """Return the class named by a "package.module.ClassName" string.

        The module goes through the normal import machinery, so an
        already-imported module is reused rather than re-executed and the
        returned class is the same object the rest of the program sees.
        Raises Exception when clstr has no module part, and ImportError or
        AttributeError when the module or the class cannot be found.
        """
        import sys
        items = clstr.split('.')
        modname = '.'.join(items[:-1])
        classname = items[-1]
        if not modname:
            raise Exception("No class found for: %s" % clstr)
        __import__(modname)
        module = sys.modules[modname]
        return getattr(module, classname)

    @staticmethod
    def class_to_str(instance):
        """Return "module.ClassName" for instance (inverse of str_to_class)."""
        modname = instance.__module__
        classname = instance.__class__.__name__
        return "%s.%s" % (modname, classname)
343 #############################################################################
def has_json():
    """Return True if simplejson can be imported.

    When it cannot, print installation hints and return False rather than
    raising, so callers can silently skip JSON save/load.
    """
    try:
        import simplejson  # imported only to probe availability
    except ImportError:
        # The old code formatted this message with an undefined name
        # ("% action"), which raised NameError on this very path.
        print("Unable to import simplejson.")
        print("You do not have simplejson installed.")
        print("try: sudo apt-get install simplejson")
        return False
    return True
354 #############################################################################
def is_model(item):
    """Report whether item's __class__ is Model or derives from it."""
    item_class = item.__class__
    return issubclass(item_class, Model)
def is_dict(item):
    """Report whether item is exactly a dict (subclasses do not count)."""
    item_type = type(item)
    return item_type is DictType
def is_list(item):
    """Report whether item is exactly a list or a tuple (no subclasses)."""
    return type(item) in (ListType, TupleType)
def is_atom(item):
    """Return True for scalar values that serialize as themselves.

    Atoms are None plus values whose exact type is a string type, bool,
    int, long, float or complex.  Subclasses of those types are not atoms.
    """
    if item is None:
        # None is a plain scalar (JSON null).  Without this check, callers
        # such as Model's serializer fell through to their "Unknown type"
        # NotImplementedError for any param whose value was None.
        return True
    kind = type(item)
    return (kind in StringTypes
            or kind is BooleanType
            or kind is IntType
            or kind is LongType
            or kind is FloatType
            or kind is ComplexType)
def is_instance(item):
    """Report whether item is a Model object or a classic-class instance."""
    if is_model(item):
        return True
    return type(item) is InstanceType