Readability fix in base logic
[Melange.git] / app / soc / logic / models / base.py
blob59f9363675cc06069940a1a5c5d2c229734504fe
1 #!/usr/bin/python2.5
3 # Copyright 2008 the Melange authors.
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
17 """Helper functions for updating different kinds of models in datastore.
18 """
# Contributors to this module, each formatted as '"Name" <email>'.
__authors__ = [
  '"Todd Larsen" <tlarsen@google.com>',
  '"Sverre Rabbelier" <sverre@rabbelier.nl>',
  '"Lennard de Rijk" <ljvderijk@gmail.com>',
  '"Pawel Solyga" <pawel.solyga@gmail.com>',
  ]
28 from google.appengine.ext import db
30 from django.utils.translation import ugettext
32 from soc.cache import sidebar
33 from soc.logic import dicts
34 from soc.views import out_of_band
class Error(Exception):
  """Base class for the exceptions defined in this module.

  InvalidArgumentError (and, through it, NoEntityError) derive from this
  class, so catching Error catches all of them.
  """
class InvalidArgumentError(Error):
  """Raised when an invalid argument is passed to a method.

  For example, when an argument that must always have a non-false value
  is None or empty.
  """
class NoEntityError(InvalidArgumentError):
  """Raised when no entity is passed to a method that requires one.

  Being an InvalidArgumentError, it is also caught by handlers for that
  more general error.
  """
class Logic(object):
  """Base logic for entity classes.

  Implements the datastore operations shared by all kinds of entities:
  lookup by key name or by key fields, querying, creating, updating and
  deleting. What is specific to one kind of entity is configured through
  the arguments passed to __init__ and by overriding the hook methods
  (_createField, _updateField, _onCreate, _onUpdate and _onDelete).
  """

  def __init__(self, model, base_model=None, scope_logic=None,
               name=None, skip_properties=None):
    """Defines the name, key_name and model for this entity.

    Args:
      model: the model class this logic operates on
      base_model: stored as given; not used by this base class itself
      scope_logic: object whose "logic" attribute is the Logic of the
        enclosing scope, or None if there is no enclosing scope
      name: the name used for this kind of entity in messages; defaults
        to the name of the model class
      skip_properties: names of the properties that are never changed by
        updateEntityProperties; defaults to an empty list
    """

    self._model = model
    self._base_model = base_model
    self._scope_logic = scope_logic
    self._name = name or self._model.__name__
    self._skip_properties = skip_properties or []

  def getModel(self):
    """Returns the model this logic class uses.
    """

    return self._model

  def getScopeLogic(self):
    """Returns the logic of the enclosing scope.
    """

    return self._scope_logic

  def getScopeDepth(self):
    """Returns the scope depth for this entity.

    The depth is 0 when there is no enclosing scope, otherwise one more
    than the depth of the enclosing scope.

    Returns None if any of the parent scopes return None.
    """

    if not self._scope_logic:
      return 0

    depth = self._scope_logic.logic.getScopeDepth()
    return None if (depth is None) else (depth + 1)

  def getKeyNameFromFields(self, fields):
    """Returns the KeyName constructed from fields dict for this type of entity.

    The KeyName is in the following format:
    <key_value1>/<key_value2>/.../<key_valueN>

    Args:
      fields: a dict containing at least the fields named by
        getKeyFieldNames(), each with a non-false value

    Raises:
      InvalidArgumentError if fields is empty, if a key field is missing
      or if the value of a key field is false
    """

    if not fields:
      raise InvalidArgumentError

    key_field_names = self.getKeyFieldNames()

    # check if all key_field_names for this entity are present in fields
    if not all(field in fields for field in key_field_names):
      raise InvalidArgumentError("Not all the required key fields are present")

    # collect the KeyValues in the order given by getKeyFieldNames()
    keyvalues = [fields[field] for field in key_field_names]

    if not all(keyvalues):
      raise InvalidArgumentError("Not all KeyValues are non-false")

    # construct the KeyName in the appropriate format
    return '/'.join(keyvalues)

  def getFullModelClassName(self):
    """Returns fully-qualified model module.class name string.
    """

    return '%s.%s' % (self._model.__module__, self._model.__name__)

  def getKeyValuesFromEntity(self, entity):
    """Extracts the key values from entity and returns them.

    The default implementation uses the scope and link_id as key values.

    Args:
      entity: the entity from which to extract the key values

    Raises:
      NoEntityError if entity has a false value
    """

    if not entity:
      raise NoEntityError

    return [entity.scope_path, entity.link_id]

  def getKeyValuesFromFields(self, fields):
    """Extracts the key values from a dict and returns them.

    The default implementation uses the scope and link_id as key values.

    Args:
      fields: the dict from which to extract the key values

    Raises:
      InvalidArgumentError if fields is empty (or None) or does not
      contain both 'scope_path' and 'link_id'
    """

    # the explicit false-value check turns None into the same error the
    # sibling methods raise, instead of a TypeError from the "in" test
    if not fields or ('scope_path' not in fields) or ('link_id' not in fields):
      raise InvalidArgumentError

    return [fields['scope_path'], fields['link_id']]

  def getKeyFieldNames(self):
    """Returns an array with the names of the Key Fields.

    The default implementation uses the scope and link_id as key values.
    """

    return ['scope_path', 'link_id']

  def getKeyFieldsFromFields(self, dictionary):
    """Does any required massaging and filtering of dictionary.

    The resulting dictionary contains just the key names, and has any
    required translations/modifications performed.

    Args:
      dictionary: The arguments to massage

    Returns:
      the result of dicts.zip() applied to getKeyFieldNames() and the
      values extracted by getKeyValuesFromFields()

    Raises:
      InvalidArgumentError if dictionary has a false value
    """

    if not dictionary:
      raise InvalidArgumentError

    keys = self.getKeyFieldNames()
    values = self.getKeyValuesFromFields(dictionary)

    return dicts.zip(keys, values)

  def getFromKeyName(self, key_name):
    """Returns entity for key_name or None if not found.

    Args:
      key_name: key name of entity

    Raises:
      InvalidArgumentError if key_name has a false value
    """

    if not key_name:
      raise InvalidArgumentError

    return self._model.get_by_key_name(key_name)

  def getFromKeyNameOr404(self, key_name):
    """Like getFromKeyName but expects to find an entity.

    Raises:
      out_of_band.Error if no entity is found
    """

    entity = self.getFromKeyName(key_name)

    if entity:
      return entity

    msg = ugettext('There is no "%(name)s" named %(key_name)s.') % {
        'name': self._name, 'key_name': key_name}

    raise out_of_band.Error(msg, status=404)

  def getFromKeyFields(self, fields):
    """Returns the entity for the specified key names, or None if not found.

    None is also returned, without querying the datastore, when any of
    the key fields has a false value.

    Args:
      fields: a dict containing the fields of the entity that
        uniquely identifies it

    Raises:
      InvalidArgumentError if fields has a false value
    """

    if not fields:
      raise InvalidArgumentError

    key_fields = self.getKeyFieldsFromFields(fields)

    if not all(key_fields.values()):
      return None

    key_name = self.getKeyNameFromFields(key_fields)
    return self.getFromKeyName(key_name)

  def getFromKeyFieldsOr404(self, fields):
    """Like getFromKeyFields but expects to find an entity.

    Raises:
      out_of_band.Error if no entity is found
    """

    entity = self.getFromKeyFields(fields)

    if entity:
      return entity

    key_fields = self.getKeyFieldsFromFields(fields)
    format_text = ugettext('"%(key)s" is "%(value)s"')

    # one '"<key>" is "<value>"' fragment for every key field
    msg_pairs = [format_text % {'key': key, 'value': value}
                 for key, value in key_fields.iteritems()]

    joined_pairs = ' and '.join(msg_pairs)

    msg = ugettext(
        'There is no "%(name)s" where %(pairs)s.') % {
        'name': self._name, 'pairs': joined_pairs}

    raise out_of_band.Error(msg, status=404)

  def getForFields(self, filter=None, unique=False,
                   limit=1000, offset=0, order=None):
    """Returns all entities that have the specified properties.

    Args:
      filter: a dict for the properties that the entities should have
      unique: if set, only the first item from the resultset will be returned
      limit: the amount of entities to fetch at most
      offset: the position to start at
      order: a list with the sort order

    Returns:
      a list with the fetched entities or, if unique is set, the first
      entity of the resultset (None when the resultset is empty)
    """

    if unique:
      # a single entity is wanted, so there is no need to fetch more
      limit = 1

    query = self.getQueryForFields(filter=filter, order=order)

    result = query.fetch(limit, offset)

    if unique:
      return result[0] if result else None

    return result

  def getQueryForFields(self, filter=None, order=None):
    """Returns a query with the specified properties.

    A filter value that is a list with a single item is treated as that
    item. Any other list value results in an 'IN' filter on the property;
    all other values are handed to query.filter() together with the
    property name as given.

    Args:
      filter: a dict for the properties that the entities should have
      order: a list with the sort order

    Returns:
      - Query object instantiated with the given properties

    Raises:
      InvalidArgumentError if, ignoring '-' characters at either end, the
      same name occurs more than once in order
    """

    if not filter:
      filter = {}

    if not order:
      order = []

    # every property may be sorted on at most once
    orderset = set([i.strip('-') for i in order])
    if len(orderset) != len(order):
      raise InvalidArgumentError

    query = db.Query(self._model)

    for key, value in filter.iteritems():
      if isinstance(value, list) and len(value) == 1:
        value = value[0]

      if isinstance(value, list):
        op = '%s IN' % key
        query.filter(op, value)
      else:
        query.filter(key, value)

    for key in order:
      query.order(key)

    return query

  def updateEntityProperties(self, entity, entity_properties, silent=False):
    """Update existing entity using supplied properties.

    Only properties defined by the model are considered. A property is
    left untouched if it is listed in skip_properties, if it is not
    present in entity_properties or if _updateField does not return a
    True value for it. The entity is stored afterwards in all cases.

    Args:
      entity: a model entity
      entity_properties: keyword arguments that correspond to entity
        properties and their values
      silent: iff True does not call _onUpdate method

    Returns:
      The original entity with any supplied properties changed.

    Raises:
      NoEntityError if entity has a false value
      InvalidArgumentError if entity_properties has a false value
    """

    if not entity:
      raise NoEntityError

    if not entity_properties:
      raise InvalidArgumentError

    properties = self._model.properties()

    for name, prop in properties.iteritems():
      # if the property is not updateable or is not updated, skip it
      if name in self._skip_properties or (name not in entity_properties):
        continue

      if self._updateField(entity, entity_properties, name):
        value = entity_properties[name]
        prop.__set__(entity, value)

    entity.put()

    # call the _onUpdate method
    if not silent:
      self._onUpdate(entity)

    return entity

  def updateOrCreateFromKeyName(self, properties, key_name):
    """Update existing entity, or create new one with supplied properties.

    _onUpdate is called after an update, _onCreate after a creation.

    Args:
      properties: dict with entity properties and their values
      key_name: the key_name of the entity that uniquely identifies it

    Returns:
      the entity corresponding to the key_name, with any supplied
      properties changed, or a new entity now associated with the
      supplied key_name and properties.
    """

    entity = self.getFromKeyName(key_name)

    if entity:
      # the entity exists already: update it silently so that _onUpdate
      # is called exactly once, right here
      entity = self.updateEntityProperties(entity, properties, silent=True)
      self._onUpdate(entity)
      return entity

    for property_name in properties:
      self._createField(properties, property_name)

    # The entity did not exist, so create one in a transaction. If someone
    # else already created the entity in the meantime (due to a race), the
    # properties are not updated here (as they 'won' the race).
    entity = self._model.get_or_insert(key_name, **properties)

    # a new entity has been created call _onCreate
    self._onCreate(entity)

    return entity

  def isDeletable(self, entity):
    """Returns whether the specified entity can be deleted.

    The default implementation allows every entity to be deleted.

    Args:
      entity: an existing entity in datastore
    """

    return True

  def delete(self, entity):
    """Delete existing entity from datastore.

    isDeletable is not consulted by this method; callers that care have
    to check it themselves.

    Args:
      entity: an existing entity in datastore

    Raises:
      NoEntityError if entity has a false value
    """

    if not entity:
      raise NoEntityError

    entity.delete()
    # entity has been deleted call _onDelete
    self._onDelete(entity)

  def getAll(self, query):
    """Retrieves all entities for the specified query.

    The entities are fetched in chunks of 999. Every fetch asks for one
    entity more than a chunk holds; receiving that extra entity means
    another chunk has to be fetched.

    Args:
      query: the query to fetch the entities from

    Returns:
      a list with all entities the query yields
    """

    chunk = 999
    offset = 0
    result = []
    more = True

    while more:
      data = query.fetch(chunk + 1, offset)

      more = len(data) > chunk

      # the extra entity, if any, is the first one of the next chunk
      result.extend(data[:chunk])
      offset += chunk

    return result

  def _createField(self, entity_properties, name):
    """Hook called when a field is created.

    To be exact, this method is called for each field (that has a value
    specified) on an entity that is being created.

    Subclasses should override if any special actions need to be
    taken when a field is created.

    Args:
      entity_properties: keyword arguments that correspond to entity
        properties and their values
      name: the name of the field to be created

    Raises:
      InvalidArgumentError if entity_properties has a false value or
      does not contain name
    """

    if not entity_properties or (name not in entity_properties):
      raise InvalidArgumentError

  def _updateField(self, entity, entity_properties, name):
    """Hook called when a field is updated.

    Subclasses should override if any special actions need to be
    taken when a field is updated. The field is not updated if the
    method does not return a True value.

    Args:
      entity: the unaltered entity
      entity_properties: keyword arguments that correspond to entity
        properties and their values
      name: the name of the field to be changed

    Raises:
      NoEntityError if entity has a false value
      InvalidArgumentError if entity_properties has a false value or
      does not contain name
    """

    if not entity:
      raise NoEntityError

    if not entity_properties or (name not in entity_properties):
      raise InvalidArgumentError

    return True

  def _onCreate(self, entity):
    """Called when an entity has been created.

    Classes that override this can use it to do any post-creation operations.
    The default implementation flushes the sidebar cache.

    Raises:
      NoEntityError if entity has a false value
    """

    if not entity:
      raise NoEntityError

    sidebar.flush()

  def _onUpdate(self, entity):
    """Called when an entity has been updated.

    Classes that override this can use it to do any post-update operations.

    Raises:
      NoEntityError if entity has a false value
    """

    if not entity:
      raise NoEntityError

  def _onDelete(self, entity):
    """Called when an entity has been deleted.

    Classes that override this can use it to do any post-deletion operations.

    Raises:
      NoEntityError if entity has a false value
    """

    if not entity:
      raise NoEntityError