3 # Copyright 2008 the Melange authors.
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
"""Helper functions for updating different kinds of models in datastore.
"""

__authors__ = [
    '"Todd Larsen" <tlarsen@google.com>',
    '"Sverre Rabbelier" <sverre@rabbelier.nl>',
    '"Lennard de Rijk" <ljvderijk@gmail.com>',
    '"Pawel Solyga" <pawel.solyga@gmail.com>',
]
28 from google
.appengine
.ext
import db
30 from django
.utils
.translation
import ugettext
32 from soc
.cache
import sidebar
33 from soc
.logic
import dicts
34 from soc
.views
import out_of_band
class Error(Exception):
    """Base class for all exceptions raised by this module.
    """


class InvalidArgumentError(Error):
    """Raised when an invalid argument is passed to a method.

    For example, if an argument is None, but must always be non-False.
    """


class NoEntityError(InvalidArgumentError):
    """Raised when no entity is passed to a method that requires one.
    """
"""Base logic for entity classes.

The BaseLogic class implements functionality specific to Entity classes
by relying on arguments passed to __init__.
def __init__(self, model, base_model=None, scope_logic=None,
             name=None, skip_properties=None):
    """Defines the name, model and scoping information for this entity.

    Args:
      model: the model class whose entities this logic operates on
      base_model: optional base model, stored as self._base_model
      scope_logic: optional logic of the enclosing scope (see
        getScopeLogic and getScopeDepth)
      name: human readable name used in error messages; defaults to the
        model's class name
      skip_properties: names of model properties that
        updateEntityProperties must never change; defaults to []
    """
    self._model = model
    self._base_model = base_model
    self._scope_logic = scope_logic

    # fall back to the model's class name when no explicit name is given
    self._name = name or self._model.__name__

    # a fresh list per instance, so instances never share one
    self._skip_properties = skip_properties or []
def getModel(self):
    """Returns the model this logic class uses.
    """
    # NOTE(review): the def line and body were missing from the garbled
    # source; the name and body are reconstructed -- confirm against callers.
    return self._model
def getScopeLogic(self):
    """Returns the logic of the enclosing scope.

    Returns:
      the scope_logic given to __init__, or None if none was given
    """
    return self._scope_logic
def getScopeDepth(self):
    """Returns the scope depth for this entity.

    An entity without an enclosing scope has depth 0; each enclosing scope
    adds one. Returns None if any of the parent scopes return None.
    """
    if not self._scope_logic:
        # no enclosing scope: this is a top-level entity
        return 0

    parent_depth = self._scope_logic.logic.getScopeDepth()
    if parent_depth is None:
        return None

    return parent_depth + 1
def getKeyNameFromFields(self, fields):
    """Returns the KeyName constructed from fields dict for this type of entity.

    The KeyName is in the following format:
    <key_value1>/<key_value2>/.../<key_valueN>

    Args:
      fields: dict that must hold a non-false value for every name
        returned by getKeyFieldNames(); other entries are ignored

    Raises:
      InvalidArgumentError: if fields is empty, a key field is missing,
        or a key field has a false value
    """
    if not fields:
        raise InvalidArgumentError

    key_field_names = self.getKeyFieldNames()

    # check if all key_field_names for this entity are present in fields
    if not all(field in fields for field in key_field_names):
        raise InvalidArgumentError("Not all the required key fields are present")

    if not all(fields[field] for field in key_field_names):
        raise InvalidArgumentError("Not all KeyValues are non-false")

    # join the KeyValues in the order given by getKeyFieldNames()
    return '/'.join([fields[name] for name in key_field_names])
def getFullModelClassName(self):
    """Returns fully-qualified model module.class name string.
    """
    model = self._model
    return '%s.%s' % (model.__module__, model.__name__)
def getKeyValuesFromEntity(self, entity):
    """Extracts the key values from entity and returns them.

    The default implementation uses the scope and link_id as key values.

    Args:
      entity: the entity from which to extract the key values

    Returns:
      [entity.scope_path, entity.link_id]

    Raises:
      NoEntityError: if entity is false
    """
    if not entity:
        raise NoEntityError

    return [entity.scope_path, entity.link_id]
def getKeyValuesFromFields(self, fields):
    """Extracts the key values from a dict and returns them.

    The default implementation uses the scope and link_id as key values.

    Args:
      fields: the dict from which to extract the key values

    Returns:
      [fields['scope_path'], fields['link_id']]

    Raises:
      InvalidArgumentError: if either key field is absent from fields
    """
    if 'scope_path' not in fields or 'link_id' not in fields:
        raise InvalidArgumentError

    return [fields['scope_path'], fields['link_id']]
def getKeyFieldNames(self):
    """Returns an array with the names of the Key Fields.

    The default implementation uses the scope and link_id as key values.
    A new list is returned on every call, so callers may modify it.
    """
    return ['scope_path', 'link_id']
def getKeyFieldsFromFields(self, dictionary):
    """Does any required massaging and filtering of dictionary.

    The resulting dictionary contains just the key names, and has any
    required translations/modifications performed.

    Args:
      dictionary: The arguments to massage

    Returns:
      the names from getKeyFieldNames() paired with the values from
      getKeyValuesFromFields(dictionary) via dicts.zip; callers use the
      result as a dict

    Raises:
      InvalidArgumentError: if dictionary is empty
    """
    if not dictionary:
        raise InvalidArgumentError

    keys = self.getKeyFieldNames()
    values = self.getKeyValuesFromFields(dictionary)
    return dicts.zip(keys, values)
def getFromKeyName(self, key_name):
    """Returns entity for key_name or None if not found.

    Args:
      key_name: key name of entity

    Raises:
      InvalidArgumentError: if key_name is empty
    """
    if not key_name:
        raise InvalidArgumentError

    return self._model.get_by_key_name(key_name)
def getFromKeyNameOr404(self, key_name):
    """Like getFromKeyName but expects to find an entity.

    Raises:
      out_of_band.Error if no entity is found
    """
    entity = self.getFromKeyName(key_name)
    if entity:
        return entity

    msg = ugettext('There is no "%(name)s" named %(key_name)s.') % {
        'name': self._name, 'key_name': key_name}

    raise out_of_band.Error(msg, status=404)
def getFromKeyFields(self, fields):
    """Returns the entity for the specified key names, or None if not found.

    Args:
      fields: a dict containing the fields of the entity that
        uniquely identifies it

    Raises:
      InvalidArgumentError: if fields is empty
    """
    if not fields:
        raise InvalidArgumentError

    key_fields = self.getKeyFieldsFromFields(fields)

    # a false key value can never form a valid key name
    if not all(key_fields.values()):
        return None

    key_name = self.getKeyNameFromFields(key_fields)
    return self.getFromKeyName(key_name)
def getFromKeyFieldsOr404(self, fields):
    """Like getFromKeyFields but expects to find an entity.

    Raises:
      out_of_band.Error if no entity is found
    """
    entity = self.getFromKeyFields(fields)
    if entity:
        return entity

    # describe every key field in the 404 message, e.g. '"a" is "b"'
    key_fields = self.getKeyFieldsFromFields(fields)
    format_text = ugettext('"%(key)s" is "%(value)s"')
    msg_pairs = [format_text % {'key': key, 'value': value}
                 for key, value in key_fields.items()]
    joined_pairs = ' and '.join(msg_pairs)

    msg = ugettext(
        'There is no "%(name)s" where %(pairs)s.') % {
        'name': self._name, 'pairs': joined_pairs}

    raise out_of_band.Error(msg, status=404)
def getForFields(self, filter=None, unique=False,
                 limit=1000, offset=0, order=None):
    """Returns all entities that have the specified properties.

    Args:
      filter: a dict for the properties that the entities should have
      unique: if set, only the first item from the resultset will be returned
      limit: the amount of entities to fetch at most
      offset: the position to start at
      order: a list with the sort order

    Returns:
      a list of at most `limit` entities; if unique is set, the first
      matching entity or None when nothing matches
    """
    if unique:
        # only the first entity is returned, so don't fetch more than that
        limit = 1

    query = self.getQueryForFields(filter=filter, order=order)
    result = query.fetch(limit, offset)

    if unique:
        return result[0] if result else None

    return result
def getQueryForFields(self, filter=None, order=None):
    """Returns a query with the specified properties.

    Args:
      filter: a dict for the properties that the entities should have; a
        list value matches any of its elements (a single-element list is
        treated as that element)
      order: a list with the sort order; a leading '-' marks a property
        for descending order

    Returns:
      - Query object instantiated with the given properties

    Raises:
      InvalidArgumentError: if a property occurs more than once in order
        (with or without '-')
    """
    if not filter:
        filter = {}

    if not order:
        order = []

    # reject orders that sort on the same property twice, e.g. ['a', '-a']
    orderset = set([i.strip('-') for i in order])
    if len(orderset) != len(order):
        raise InvalidArgumentError

    query = db.Query(self._model)

    for key, value in filter.items():
        if isinstance(value, list) and len(value) == 1:
            # a single-element list is an ordinary equality filter
            value = value[0]

        if isinstance(value, list):
            # match entities whose property equals any listed value
            query.filter('%s IN' % key, value)
        else:
            query.filter(key, value)

    for key in order:
        query.order(key)

    return query
def updateEntityProperties(self, entity, entity_properties, silent=False):
    """Update existing entity using supplied properties.

    Args:
      entity: a model entity
      entity_properties: keyword arguments that correspond to entity
        properties and their values
      silent: iff True does not call _onUpdate method

    Returns:
      The original entity with any supplied properties changed.

    Raises:
      NoEntityError: if entity is false
      InvalidArgumentError: if entity_properties is empty
    """
    if not entity:
        raise NoEntityError

    if not entity_properties:
        raise InvalidArgumentError

    for name, prop in self._model.properties().items():
        # if the property is not updateable or is not updated, skip it
        if name in self._skip_properties or name not in entity_properties:
            continue

        # the _updateField hook can veto the change by returning a false value
        if self._updateField(entity, entity_properties, name):
            prop.__set__(entity, entity_properties[name])

    # write the changed entity back to the datastore
    entity.put()

    if not silent:
        self._onUpdate(entity)

    return entity
def updateOrCreateFromKeyName(self, properties, key_name):
    """Update existing entity, or create new one with supplied properties.

    Args:
      properties: dict with entity properties and their values
      key_name: the key_name of the entity that uniquely identifies it

    Returns:
      the entity corresponding to the key_name, with any supplied
      properties changed, or a new entity now associated with the
      supplied key_name and properties.
    """
    entity = self.getFromKeyName(key_name)

    if entity:
        # update silently; _onUpdate is called exactly once, right below
        entity = self.updateEntityProperties(entity, properties, silent=True)
        self._onUpdate(entity)
        return entity

    # run the per-field creation hook for every supplied property
    for property_name in properties:
        self._createField(properties, property_name)

    # entity did not exist, so create one in a transaction. If someone else
    # already created it (due to a race), get_or_insert returns their entity
    # and our properties are not applied (as they 'won' the race).
    entity = self._model.get_or_insert(key_name, **properties)
    self._onCreate(entity)
    return entity
def isDeletable(self, entity):
    """Returns whether the specified entity can be deleted.

    This base implementation allows deleting any entity; subclasses may
    override it to add restrictions.

    Args:
      entity: an existing entity in datastore
    """
    # NOTE(review): the body was missing from the garbled source; a
    # permissive default is assumed here -- confirm against the original.
    return True
def delete(self, entity):
    """Delete existing entity from datastore.

    Args:
      entity: an existing entity in datastore
    """
    entity.delete()

    # entity has been deleted, call _onDelete
    self._onDelete(entity)
def getAll(self, query, chunk=999):
    """Retrieves all entities for the specified query.

    Entities are fetched in batches of at most `chunk` entities.

    Args:
      query: the query to run; only its fetch(limit, offset) method is used
      chunk: number of entities retrieved per batch; each request asks for
        one extra entity to detect whether another batch is needed

    Returns:
      a list with all entities matched by the query

    Raises:
      InvalidArgumentError: if chunk is smaller than 1
    """
    # NOTE(review): default batch size reconstructed from a garbled source
    # (chunk + 1 == 1000, matching getForFields' default limit) -- confirm.
    if chunk < 1:
        # a zero/negative batch size would never advance the offset
        raise InvalidArgumentError

    result = []
    offset = 0

    while True:
        data = query.fetch(chunk + 1, offset)

        # the extra entity only signals that another batch exists; it is
        # fetched again as the first entity of the next batch
        more = len(data) > chunk
        result.extend(data[:chunk])

        if not more:
            return result

        offset += chunk
def _createField(self, entity_properties, name):
    """Hook called when a field is created.

    To be exact, this method is called for each field (that has a value
    specified) on an entity that is being created.

    Subclasses should override if any special actions need to be
    taken when a field is created.

    Args:
      entity_properties: keyword arguments that correspond to entity
        properties and their values
      name: the name of the field to be created

    Raises:
      InvalidArgumentError: if entity_properties is empty or lacks name
    """
    if not entity_properties or name not in entity_properties:
        raise InvalidArgumentError
def _updateField(self, entity, entity_properties, name):
    """Hook called when a field is updated.

    Subclasses should override if any special actions need to be
    taken when a field is updated. The field is not updated if the
    method does not return a True value.

    Args:
      entity: the unaltered entity
      entity_properties: keyword arguments that correspond to entity
        properties and their values
      name: the name of the field to be changed

    Returns:
      True, so that by default every supplied field is updated

    Raises:
      InvalidArgumentError: if entity_properties is empty or lacks name
    """
    if not entity_properties or name not in entity_properties:
        raise InvalidArgumentError

    return True
def _onCreate(self, entity):
    """Called when an entity has been created.

    Classes that override this can use it to do any post-creation operations.

    Args:
      entity: the newly created entity

    Raises:
      NoEntityError: if entity is false
    """
    if not entity:
        raise NoEntityError

    # NOTE(review): body reconstructed from a garbled source; `sidebar` is
    # imported by this module but unused in the visible code, so this flush
    # is presumed to belong here -- confirm against the original.
    sidebar.flush()
def _onUpdate(self, entity):
    """Called when an entity has been updated.

    Classes that override this can use it to do any post-update operations.

    Args:
      entity: the updated entity

    Raises:
      NoEntityError: if entity is false
    """
    if not entity:
        raise NoEntityError
512 def _onDelete(self
, entity
):
513 """Called when an entity has been deleted.
515 Classes that override this can use it to do any post-deletion operations.