Catch exception related to number field handling in older SDK.
[gae-samples.git] / search / product_search_python / docs.py
blobccf939f43ca9594c34350f181cc20110306141dd
1 #!/usr/bin/env python
3 # Copyright 2012 Google Inc.
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
17 """ Contains 'helper' classes for managing search.Documents.
18 BaseDocumentManager provides some common utilities, and the Product subclass
19 adds some Product-document-specific helper methods.
20 """
22 import collections
23 import copy
24 import logging
25 import re
26 import string
27 import urllib
29 import categories
30 import config
31 import errors
32 import models
34 from google.appengine.api import search
35 from google.appengine.ext import ndb
class BaseDocumentManager(object):
    """Abstract class. Provides helper methods to manage search.Documents.

    Subclasses set ``_INDEX_NAME`` to the name of the search index they
    operate on; the class methods below all go through ``getIndex()``.
    """

    # Name of the index used by getIndex(); overridden by subclasses.
    _INDEX_NAME = None
    # All printable ASCII characters except whitespace.
    _VISIBLE_PRINTABLE_ASCII = frozenset(
        set(string.printable) - set(string.whitespace))

    def __init__(self, doc):
        """Wrap `doc` and build a dict of its fields keyed by field name.

        ``self.doc_map`` maps each field name to the list of fields with that
        name, in the order they appear in ``doc.fields``, for efficient access.
        """
        self.doc = doc
        self.doc_map = {}
        for field in doc.fields:
            self.doc_map.setdefault(field.name, []).append(field)

    def getFirstFieldVal(self, fname):
        """Get the value of the (first) document field with the given name.

        Returns None when the document has no field named `fname`.
        """
        try:
            return self.doc_map[fname][0].value
        except (KeyError, IndexError, TypeError):
            return None

    def setFirstField(self, new_field):
        """Replace the (first) document field that has `new_field`'s name.

        Returns:
          True if a field with that name was found and replaced, else False.
        """
        for i, field in enumerate(self.doc.fields):
            if field.name == new_field.name:
                self.doc.fields[i] = new_field
                # Keep the name -> fields cache in step with the document, so
                # that a later getFirstFieldVal() sees the new value.
                cached = self.doc_map.get(new_field.name)
                if cached:
                    cached[0] = new_field
                return True
        return False

    @classmethod
    def isValidDocId(cls, doc_id):
        """Checks if the given id is a non-empty, visible printable ASCII
        string not starting with '!'. Whitespace characters are excluded.
        """
        # An empty (or missing) id can never be used as a doc id.
        if not doc_id:
            return False
        if doc_id.startswith('!'):
            return False
        return all(char in cls._VISIBLE_PRINTABLE_ASCII for char in doc_id)

    @classmethod
    def getIndex(cls):
        """Return the search.Index named by this class's _INDEX_NAME."""
        return search.Index(name=cls._INDEX_NAME)

    @classmethod
    def deleteAllInIndex(cls):
        """Delete all the docs in the given index.

        Search errors are logged, not raised.
        """
        docindex = cls.getIndex()

        try:
            while True:
                # until no more documents, get a list of documents,
                # constraining the returned objects to contain only the doc
                # ids, extract the doc ids, and delete the docs.
                document_ids = [
                    document.doc_id
                    for document in docindex.list_documents(ids_only=True)]
                if not document_ids:
                    break
                docindex.remove(document_ids)
        except search.Error:
            logging.exception("Error removing documents:")

    @classmethod
    def getDoc(cls, doc_id):
        """Return the document with the given doc id, or None if not found.

        One way to do this is via the list_documents method, as shown here.
        If the doc id is not in the index, the first doc in the list will be
        returned instead, so we need to check for that case.
        """
        if not doc_id:
            return None
        try:
            index = cls.getIndex()
            response = index.list_documents(
                start_doc_id=doc_id, limit=1, include_start_doc=True)
            if response.results and response.results[0].doc_id == doc_id:
                return response.results[0]
            return None
        except search.InvalidRequest:  # catches ill-formed doc ids
            return None

    @classmethod
    def removeDocById(cls, doc_id):
        """Remove the doc with the given doc id.

        Search errors are logged, not raised.
        """
        try:
            cls.getIndex().remove(doc_id)
        except search.Error:
            logging.exception("Error removing doc id %s.", doc_id)

    @classmethod
    def add(cls, documents):
        """Wrapper for search index add method; specifies the index name.

        Returns:
          Whatever the index's add() returns, or None if a search.Error was
          raised (the error is logged, not propagated). Callers must be
          prepared for the None result.
        """
        try:
            return cls.getIndex().add(documents)
        except search.Error:
            logging.exception("Error adding documents.")
            return None
class Product(BaseDocumentManager):
    """Provides helper methods to manage Product documents.

    All Product documents built using these methods will include a core set
    of fields (see the _buildCoreProductFields method). We use the given
    product id (the Product entity key) as the doc_id. This is not required
    for the entity/document design -- each explicitly points to the other,
    allowing their ids to be decoupled -- but using the product id as the doc
    id allows a document to be reindexed given its product info, without
    having to fetch the existing document.
    """

    _INDEX_NAME = config.PRODUCT_INDEX_NAME

    # 'core' product document field names
    PID = 'pid'
    DESCRIPTION = 'description'
    CAT = 'cat'
    CATNAME = 'catname'
    PNAME = 'name'
    PRICE = 'price'
    AR = 'ar'  # average rating

    # Each entry is [sort keyword, menu label, SortExpression].
    _SORT_OPTIONS = [
        [AR, 'average rating', search.SortExpression(
            expression=AR,
            direction=search.SortExpression.DESCENDING, default_value=1)],
        [PRICE, 'price', search.SortExpression(
            expression=PRICE,
            direction=search.SortExpression.ASCENDING, default_value=1)],
        [CATNAME, 'category', search.SortExpression(
            expression=CATNAME,
            direction=search.SortExpression.ASCENDING, default_value='')],
        [PNAME, 'product name', search.SortExpression(
            expression=PNAME,
            direction=search.SortExpression.ASCENDING, default_value='')],
    ]

    # Lazily built by getSortMenu() / getSortDict().
    _SORT_MENU = None
    _SORT_DICT = None

    @classmethod
    def deleteAllInProductIndex(cls):
        """Delete every document in the product index."""
        cls.deleteAllInIndex()

    @classmethod
    def getSortMenu(cls):
        """Return the list of (keyword, label) sort options, building it on
        first use."""
        if not cls._SORT_MENU:
            cls._buildSortMenu()
        return cls._SORT_MENU

    @classmethod
    def getSortDict(cls):
        """Return the dict of sort keyword -> SortExpression, building it on
        first use."""
        if not cls._SORT_DICT:
            cls._buildSortDict()
        return cls._SORT_DICT

    @classmethod
    def _buildSortMenu(cls):
        """Build the default set of sort options used for Product search.

        Of these options, all but 'relevance' reference core fields that
        all Products will have.
        """
        res = [(elt[0], elt[1]) for elt in cls._SORT_OPTIONS]
        cls._SORT_MENU = [('relevance', 'relevance')] + res

    @classmethod
    def _buildSortDict(cls):
        """Build a dict that maps sort option keywords to their corresponding
        SortExpressions."""
        cls._SORT_DICT = dict((elt[0], elt[2]) for elt in cls._SORT_OPTIONS)

    @classmethod
    def getDocFromPid(cls, pid):
        """Given a pid, get its doc. We're using the pid as the doc id, so we
        can do this via a direct fetch."""
        return cls.getDoc(pid)

    @classmethod
    def removeProductDocByPid(cls, pid):
        """Given a doc's pid, remove the doc matching it from the product
        index."""
        cls.removeDocById(pid)

    @classmethod
    def updateRatingInDoc(cls, doc_id, avg_rating):
        """Fetch the doc with the given id and set its average rating field.

        The doc is modified in memory only; it is not reindexed here.

        Returns:
          The updated document.

        Raises:
          errors.OperationFailedError: if no doc with that id is found.
        """
        # get the associated doc from the doc id in the product entity
        doc = cls.getDoc(doc_id)
        if not doc:
            raise errors.OperationFailedError(
                'Could not retrieve doc associated with id %s' % (doc_id,))
        pdoc = cls(doc)
        cat = pdoc.getCategory()
        # The category cast to int is to avoid a current dev appserver issue
        # when reindexing. Not an issue for a deployed app.
        pdoc.setCategory(int(cat))
        pdoc.setAvgRating(avg_rating)
        # The use of the same id will cause the existing doc to be reindexed.
        return doc

    @classmethod
    def updateRatingsInfo(cls, doc_id, avg_rating):
        """Given a doc id and an average rating, update the associated
        document with that rating and reindex it.

        Returns:
          The result of the index add operation (None if indexing failed).
        """
        ndoc = cls.updateRatingInDoc(doc_id, avg_rating)
        # reindex the returned updated doc
        return cls.add(ndoc)

    # 'accessor' methods

    def getPID(self):
        """Get the value of the 'pid' field of a Product doc."""
        return self.getFirstFieldVal(self.PID)

    def getName(self):
        """Get the value of the 'name' field of a Product doc."""
        return self.getFirstFieldVal(self.PNAME)

    def getDescription(self):
        """Get the value of the 'description' field of a Product doc."""
        return self.getFirstFieldVal(self.DESCRIPTION)

    def getCategory(self):
        """Get the value of the 'cat' field of a Product doc."""
        return self.getFirstFieldVal(self.CAT)

    def getCategoryName(self):
        """Get the value of the 'catname' field of a Product doc."""
        return self.getFirstFieldVal(self.CATNAME)

    def setCategory(self, cat):
        """Set the value of the 'cat' (category) field of a Product doc."""
        return self.setFirstField(
            search.NumberField(name=self.CAT, value=cat))

    def getAvgRating(self):
        """Get the value of the 'ar' (average rating) field of a Product
        doc."""
        return self.getFirstFieldVal(self.AR)

    def setAvgRating(self, ar):
        """Set the value of the 'ar' field of a Product doc."""
        return self.setFirstField(search.NumberField(name=self.AR, value=ar))

    def getPrice(self):
        """Get the value of the 'price' field of a Product doc."""
        return self.getFirstFieldVal(self.PRICE)

    @classmethod
    def generateRatingsBuckets(cls, query_string):
        """Builds a dict of ratings 'buckets' and their counts, based on the
        value of the 'ar' (average rating) field for the documents retrieved
        by the given query. See the 'generateRatingsLinks' method. This
        information will be used to generate sidebar links that allow the
        user to drill down in query results based on rating.

        For demonstration purposes only; this will be expensive for large
        data sets.

        Returns:
          A defaultdict(int) mapping int(average rating) to a document count,
          or None if the search raised a search.Error.
        """
        # do the query on the *full* search results
        # to generate the facet information, imitating what may in future be
        # provided by the FTS API.
        # NOTE(review): no result limit is passed to the query, so presumably
        # only the API's default number of results is counted -- confirm.
        try:
            sq = search.Query(query_string=query_string.strip())
            search_results = cls.getIndex().search(sq)
        except search.Error:
            logging.exception('An error occurred on search.')
            return None

        ratings_buckets = collections.defaultdict(int)
        # populate the buckets; a missing rating counts as bucket 0
        for res in search_results:
            ratings_buckets[int(cls(res).getAvgRating() or 0)] += 1
        return ratings_buckets

    @classmethod
    def generateRatingsLinks(cls, query, phash):
        """Builds a list of (link, text) pairs for the ratings sidebar.

        Runs generateRatingsBuckets on `query`, then for each rating from
        config.RATING_MIN to config.RATING_MAX builds a '/psearch' link that
        re-runs the query filtered by that ratings interval.

        Args:
          query: the query string to bucket.
          phash: dict of url parameters for the link; it is copied, not
            modified.

        Returns:
          A list of (href, link text) tuples, or None if there are no
          buckets (search failed or returned no documents).
        """
        ratings_buckets = cls.generateRatingsBuckets(query)
        if not ratings_buckets:
            return None
        # Work on a copy so the caller's parameter dict is not left with a
        # stray 'rating' entry.
        link_params = dict(phash)
        rlist = []
        for k in range(config.RATING_MIN, config.RATING_MAX + 1):
            # buckets with no documents are reported with a count of 0
            v = ratings_buckets.get(k, 0)
            # build html
            # NOTE(review): 5 is presumably config.RATING_MAX -- confirm.
            if k < 5:
                htext = '%s-%s (%s)' % (k, k + 1, v)
            else:
                htext = '%s (%s)' % (k, v)
            link_params['rating'] = k
            hlink = '/psearch?' + urllib.urlencode(link_params)
            rlist.append((hlink, htext))
        return rlist

    @classmethod
    def _buildCoreProductFields(
            cls, pid, name, description, category, category_name, price):
        """Construct a 'core' document field list for the fields common to
        all Products. The various categories (as defined in the file
        'categories.py'), may add additional specialized fields; these will
        be appended to this core list. (see _buildProductFields)."""
        # strip the markup from the description value, which can
        # potentially come from user input. We do this so that
        # we don't need to sanitize the description in the
        # templates, showing off the Search API's ability to mark up query
        # terms in generated snippets. This is done only for
        # demonstration purposes; in an actual app,
        # it would be preferable to use a library like Beautiful Soup
        # instead.
        # We'll let the templating library escape all other rendered
        # values for us, so this is the only field we do this for.
        # A missing description is treated as empty rather than crashing
        # re.sub.
        plain_description = re.sub(r'<[^>]*?>', '', description or '')
        fields = [
            search.TextField(name=cls.PID, value=pid),
            search.TextField(name=cls.PNAME, value=name),
            search.TextField(name=cls.DESCRIPTION, value=plain_description),
            search.NumberField(name=cls.CAT, value=category),
            search.TextField(name=cls.CATNAME, value=category_name),
            # new documents start with an average rating of 0.0
            search.NumberField(name=cls.AR, value=0.0),
            search.NumberField(name=cls.PRICE, value=price),
        ]
        return fields

    @classmethod
    def _buildProductFields(cls, pid=None, category=None, name=None,
                            description=None, category_name=None, price=None,
                            **params):
        """Build the full field list for a document of the given product type.

        Starts from the 'core' fields, then appends the additional non-core
        fields specified for the category in categories.product_dict, taking
        their values from `params`. All such additional category-specific
        fields are treated as required.

        Returns:
          The list of document fields.

        Raises:
          errors.OperationFailedError: if a category-specific field is
            missing from `params`, or a number field has a bad value.
        """
        fields = cls._buildCoreProductFields(
            pid, name, description, category, category_name, price)
        # get the specification of additional (non-'core') fields for this
        # category
        pdict = categories.product_dict.get(category_name)
        if not pdict:
            # no field specification for this category: core fields only.
            # (category_name is a named argument, so it is never in params.)
            logging.warning(
                'product field information not found for category name %s',
                category_name)
            return fields

        for k, field_type in pdict.iteritems():
            # every field in the specification must have a value in params
            if k not in params:
                error_message = (
                    'value not given for field "%s" of field type "%s"'
                    % (k, field_type))
                logging.warning(error_message)
                raise errors.OperationFailedError(error_message)
            v = params[k]
            if field_type == search.NumberField:
                try:
                    val = float(v)
                    fields.append(search.NumberField(name=k, value=val))
                except (TypeError, ValueError):
                    error_message = (
                        'bad value %s for field %s of type %s'
                        % (v, k, field_type))
                    logging.error(error_message)
                    raise errors.OperationFailedError(error_message)
            elif field_type == search.TextField:
                fields.append(search.TextField(name=k, value=str(v)))
            else:
                # TODO -- add handling of other field types for generality.
                # Not needed for our current sample data.
                logging.warning(
                    'not processed: %s, %s, of type %s', k, v, field_type)
        return fields

    @classmethod
    def _createDocument(
            cls, pid=None, category=None, name=None, description=None,
            category_name=None, price=None, **params):
        """Create a Document object from given params.

        Raises:
          errors.OperationFailedError: if pid, category or name is missing,
            if the pid is not a legal doc id, or if field building fails.
        """
        # check for the fields that are always required.
        if not (pid and category and name):
            raise errors.OperationFailedError('Missing parameter.')
        # Check that the given pid has only visible ascii characters,
        # and does not contain whitespace. The pid will be used as the
        # doc_id, which has these requirements.
        if not cls.isValidDocId(pid):
            raise errors.OperationFailedError("Illegal pid %s" % pid)
        # construct the document fields from the params
        resfields = cls._buildProductFields(
            pid=pid, category=category, name=name,
            description=description,
            category_name=category_name, price=price, **params)
        # build the document. Use the pid (product id) as the doc id.
        # (If we did not do this, and left the doc_id unspecified, an id
        # would be auto-generated.)
        return search.Document(doc_id=pid, fields=resfields)

    @classmethod
    def _normalizeParams(cls, params):
        """Normalize the submitted params for building a product.

        Works on a deep copy of `params`: strips 'pid' and 'name', moves the
        submitted category name to 'category_name', replaces 'category' with
        its integer id from models.Category.getCategoryDict(), and converts
        'price' to a float.

        Returns:
          The normalized copy of the params dict.

        Raises:
          errors.OperationFailedError: if a required key is missing, the
            category is unknown, or the price is not a number.
        """
        params = copy.deepcopy(params)
        chash = models.Category.getCategoryDict()
        try:
            params['pid'] = params['pid'].strip()
            params['name'] = params['name'].strip()
            category_name = params['category']
            params['category_name'] = category_name
            category_id = chash.get(category_name)
            if category_id is None:
                # int(None) would otherwise raise an uncaught TypeError
                error_message = 'unknown category: %s' % category_name
                logging.error(error_message)
                raise errors.OperationFailedError(error_message)
            params['category'] = int(category_id)
            try:
                params['price'] = float(params['price'])
            except (TypeError, ValueError):
                error_message = 'bad price value: %s' % params['price']
                logging.error(error_message)
                raise errors.OperationFailedError(error_message)
            return params
        except KeyError as e1:
            raise errors.OperationFailedError(e1)
        except errors.Error as e2:
            logging.debug(
                'Problem with params: %s: %s', params, e2.error_message)
            raise errors.OperationFailedError(e2.error_message)

    @classmethod
    def buildProductBatch(cls, rows):
        """Build product documents and their related datastore entities, in
        batch, given a list of params dicts. Should be used for new products,
        as does not handle updates of existing product entities.

        Rows that fail normalization or document creation are logged and
        skipped.

        Raises:
          errors.OperationFailedError: if indexing fails or returns a
            different number of doc ids than documents submitted.
        """
        docs = []
        dbps = []
        for row in rows:
            try:
                params = cls._normalizeParams(row)
                doc = cls._createDocument(**params)
                # create product entity, sans doc_id
                dbp = models.Product(
                    id=params['pid'], price=params['price'],
                    category=params['category'])
            except errors.OperationFailedError:
                logging.error('error creating document from data: %s', row)
            else:
                # append as a pair so the two lists always stay aligned
                docs.append(doc)
                dbps.append(dbp)
        doc_ids = cls.add(docs)
        # add() returns None when the indexing call raised a search error
        if doc_ids is None:
            raise errors.OperationFailedError(
                'Error: indexing operation failed; no doc ids returned')
        if len(doc_ids) != len(dbps):
            raise errors.OperationFailedError(
                'Error: wrong number of doc ids returned from indexing '
                'operation')
        # now set the entities with the doc ids, the list of which are
        # returned in the same order as the list of docs given to the indexers
        for dbp, added in zip(dbps, doc_ids):
            dbp.doc_id = added.document_id
        # persist the entities
        ndb.put_multi(dbps)

    @classmethod
    def buildProduct(cls, params):
        """Create/update a product document and its related datastore entity.

        The product id and the field values are taken from the params dict.

        Returns:
          The created or updated models.Product entity.

        Raises:
          errors.OperationFailedError: if the params are invalid or the
            document could not be indexed.
        """
        params = cls._normalizeParams(params)
        # check to see if doc already exists
        curr_doc = cls.getDocFromPid(params['pid'])
        d = cls._createDocument(**params)
        if curr_doc:  # don't overwrite ratings info from existing doc
            try:
                avg_rating = cls(curr_doc).getAvgRating()
                cls(d).setAvgRating(avg_rating)
            except TypeError:
                # catch potential issue with 0-valued numeric fields in older
                # SDK
                logging.exception("catch 0-valued field error:")

        # This will reindex if a doc with that doc id already exists
        doc_ids = cls.add(d)
        # add() returns None on a search error; an empty result is a failure
        # too.
        if not doc_ids:
            raise errors.OperationFailedError('could not index document')
        doc_id = doc_ids[0].document_id
        logging.debug(
            'got new doc id %s for product: %s', doc_id, params['pid'])

        # now update the entity
        def _tx():
            # Check whether the product entity exists. If so, we want to
            # update from the params, but preserve its ratings-related info.
            prod = models.Product.get_by_id(params['pid'])
            if prod:  # update
                prod.update_core(params, doc_id)
            else:  # create new entity
                prod = models.Product.create(params, doc_id)
            prod.put()
            return prod

        prod = ndb.transaction(_tx)
        logging.debug('prod: %s', prod)
        return prod