update to reflect API changes
[gae-samples.git] / search / product_search_python / docs.py
blob7e41a21e4c94fbbea7bd141b5ede7441c76c54a5
1 #!/usr/bin/env python
3 # Copyright 2012 Google Inc.
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
17 """ Contains 'helper' classes for managing search.Documents.
18 BaseDocumentManager provides some common utilities, and the Product subclass
19 adds some Product-document-specific helper methods.
20 """
22 import collections
23 import copy
24 import datetime
25 import logging
26 import re
27 import string
28 import urllib
30 import categories
31 import config
32 import errors
33 import models
35 from google.appengine.api import search
36 from google.appengine.ext import ndb
class BaseDocumentManager(object):
    """Abstract class. Provides helper methods to manage search.Documents.

    Instances wrap a single document (self.doc) and offer field accessors.
    The class methods operate on the search index named by _INDEX_NAME, which
    concrete subclasses are expected to override.
    """

    # Name of the search index managed by the class; set by subclasses.
    _INDEX_NAME = None
    # Characters permitted in a doc id: printable ASCII minus whitespace.
    _VISIBLE_PRINTABLE_ASCII = frozenset(
        set(string.printable) - set(string.whitespace))

    def __init__(self, doc):
        """Wrap the given document.

        Args:
          doc: the document whose fields the accessor methods read and
            replace. It is stored as-is on self.doc.
        """
        self.doc = doc

    def getFieldVal(self, fname):
        """Get the value of the document field with the given name.

        Args:
          fname: name of the field to look up.

        Returns:
          The field's value, or None if the lookup raises ValueError (the
          case of more than one field with that name).
        """
        try:
            return self.doc.field(fname).value
        except ValueError:
            return None

    def setFirstField(self, new_field):
        """Replace the first document field that shares new_field's name.

        Args:
          new_field: the field object to store in place of the existing one.

        Returns:
          True if a field with that name was found and replaced, else False.
        """
        doc_fields = self.doc.fields
        for position, existing in enumerate(doc_fields):
            if existing.name == new_field.name:
                doc_fields[position] = new_field
                return True
        return False

    @classmethod
    def isValidDocId(cls, doc_id):
        """Check that doc_id is visible printable ASCII not starting with '!'.

        Whitespace characters are not allowed anywhere in the id.

        Args:
          doc_id: the candidate document id string.

        Returns:
          True if every character is in _VISIBLE_PRINTABLE_ASCII and the id
          does not start with '!'; False otherwise.
        """
        if any(char not in cls._VISIBLE_PRINTABLE_ASCII for char in doc_id):
            return False
        return not doc_id.startswith('!')

    @classmethod
    def getIndex(cls):
        """Return the search.Index object for this class's _INDEX_NAME."""
        return search.Index(name=cls._INDEX_NAME)

    @classmethod
    def deleteAllInIndex(cls):
        """Delete all the docs in this class's index.

        Any search.Error raised along the way is logged and swallowed.
        """
        docindex = cls.getIndex()
        try:
            while True:
                # Until no more documents remain: fetch a batch restricted
                # to doc ids only, extract the ids, and delete those docs.
                document_ids = [
                    document.doc_id
                    for document in docindex.get_range(ids_only=True)]
                if not document_ids:
                    break
                docindex.delete(document_ids)
        except search.Error:
            logging.exception("Error removing documents:")

    @classmethod
    def getDoc(cls, doc_id):
        """Return the document with the given doc id, or None.

        The lookup is done via get_range starting at doc_id. If the doc id is
        not in the index, get_range returns some other document instead, so
        the id of the returned document is checked against the request.

        Args:
          doc_id: id of the document to fetch. A false value yields None.

        Returns:
          The matching document, or None if there is no such document or
          the id is ill-formed (search.InvalidRequest).
        """
        if not doc_id:
            return None
        try:
            response = cls.getIndex().get_range(
                start_id=doc_id, limit=1, include_start_object=True)
        except search.InvalidRequest:  # catches ill-formed doc ids
            return None
        results = response.results
        if results and results[0].doc_id == doc_id:
            return results[0]
        return None

    @classmethod
    def removeDocById(cls, doc_id):
        """Remove the doc with the given doc id.

        A search.Error is logged and swallowed.
        """
        try:
            cls.getIndex().delete(doc_id)
        except search.Error:
            logging.exception("Error removing doc id %s.", doc_id)

    @classmethod
    def add(cls, documents):
        """Wrapper for the search index put method; supplies the index.

        Args:
          documents: a document, or a list of documents, to index.

        Returns:
          Whatever the index's put() returns on success. On search.Error the
          error is logged and None is returned, so callers must be prepared
          for a None result.
        """
        try:
            return cls.getIndex().put(documents)
        except search.Error:
            logging.exception("Error adding documents.")
            return None
class Store(BaseDocumentManager):
    """Document manager bound to the store index.

    Adds no behaviour of its own; it only names the index and the fields
    used by store documents.
    """

    # Index name comes from the project configuration.
    _INDEX_NAME = config.STORE_INDEX_NAME

    # Store document field names.
    STORE_NAME = 'store_name'
    STORE_LOCATION = 'store_location'
    STORE_ADDRESS = 'store_address'
class Product(BaseDocumentManager):
    """Provides helper methods to manage Product documents.

    All Product documents built using these methods will include a core set
    of fields (see the _buildCoreProductFields method). We use the given
    product id (the Product entity key) as the doc_id. This is not required
    for the entity/document design -- each explicitly points to the other,
    allowing their ids to be decoupled -- but using the product id as the doc
    id allows a document to be reindexed given its product info, without
    having to fetch the existing document.
    """

    _INDEX_NAME = config.PRODUCT_INDEX_NAME

    # 'core' product document field names
    PID = 'pid'
    DESCRIPTION = 'description'
    CATEGORY = 'category'
    PRODUCT_NAME = 'name'
    PRICE = 'price'
    AVG_RATING = 'ar'  # average rating
    UPDATED = 'modified'

    # Each entry is [sort keyword, menu label, SortExpression].
    _SORT_OPTIONS = [
        [AVG_RATING, 'average rating', search.SortExpression(
            expression=AVG_RATING,
            direction=search.SortExpression.DESCENDING, default_value=0)],
        [PRICE, 'price', search.SortExpression(
            # other examples:
            # expression='max(price, 14.99)'
            # If you access _score in your sort expressions,
            # your SortOptions should include a scorer.
            # e.g. search.SortOptions(match_scorer=search.MatchScorer(),...)
            # Then, you can access the score to build expressions like:
            # expression='price * _score'
            expression=PRICE,
            direction=search.SortExpression.ASCENDING, default_value=9999)],
        [UPDATED, 'modified', search.SortExpression(
            expression=UPDATED,
            direction=search.SortExpression.DESCENDING, default_value=1)],
        [CATEGORY, 'category', search.SortExpression(
            expression=CATEGORY,
            direction=search.SortExpression.ASCENDING, default_value='')],
        [PRODUCT_NAME, 'product name', search.SortExpression(
            expression=PRODUCT_NAME,
            direction=search.SortExpression.ASCENDING, default_value='zzz')],
    ]

    # Lazily built caches derived from _SORT_OPTIONS.
    _SORT_MENU = None
    _SORT_DICT = None

    @classmethod
    def deleteAllInProductIndex(cls):
        """Delete every document in the product index."""
        cls.deleteAllInIndex()

    @classmethod
    def getSortMenu(cls):
        """Return the list of (keyword, label) sort menu entries.

        The list is built on first use and cached on the class.
        """
        if not cls._SORT_MENU:
            cls._buildSortMenu()
        return cls._SORT_MENU

    @classmethod
    def getSortDict(cls):
        """Return the dict mapping sort keywords to SortExpressions.

        The dict is built on first use and cached on the class.
        """
        if not cls._SORT_DICT:
            cls._buildSortDict()
        return cls._SORT_DICT

    @classmethod
    def _buildSortMenu(cls):
        """Build the default set of sort options used for Product search.

        Of these options, all but 'relevance' reference core fields that
        all Products will have.
        """
        res = [(elt[0], elt[1]) for elt in cls._SORT_OPTIONS]
        cls._SORT_MENU = [('relevance', 'relevance')] + res

    @classmethod
    def _buildSortDict(cls):
        """Build a dict that maps sort option keywords to their
        corresponding SortExpressions."""
        # Assign the finished dict in one step so a partially filled dict
        # is never visible through cls._SORT_DICT.
        cls._SORT_DICT = dict(
            (elt[0], elt[2]) for elt in cls._SORT_OPTIONS)

    @classmethod
    def getDocFromPid(cls, pid):
        """Given a pid, get its doc.

        We're using the pid as the doc id, so this is a direct fetch.
        """
        return cls.getDoc(pid)

    @classmethod
    def removeProductDocByPid(cls, pid):
        """Given a doc's pid, remove the matching doc from the product
        index."""
        cls.removeDocById(pid)

    @classmethod
    def updateRatingInDoc(cls, doc_id, avg_rating):
        """Set the average rating field on the doc with the given id.

        The document is only modified in memory; it is not reindexed here.

        Args:
          doc_id: id of the product document to update.
          avg_rating: the new average rating value.

        Returns:
          The fetched document, with its rating field replaced.

        Raises:
          errors.OperationFailedError: if no doc with that id is found.
        """
        doc = cls.getDoc(doc_id)
        if not doc:
            raise errors.OperationFailedError(
                'Could not retrieve doc associated with id %s' % (doc_id,))
        cls(doc).setAvgRating(avg_rating)
        # The doc keeps its id, so indexing it again replaces the old one.
        return doc

    @classmethod
    def updateRatingsInfo(cls, doc_id, avg_rating):
        """Update the doc with the given id to carry avg_rating, then
        reindex it.

        Args:
          doc_id: id of the product document to update.
          avg_rating: the product's current average rating.

        Returns:
          The result of cls.add() on the updated document.

        Raises:
          errors.OperationFailedError: if the doc cannot be retrieved.
        """
        ndoc = cls.updateRatingInDoc(doc_id, avg_rating)
        # reindex the returned updated doc
        return cls.add(ndoc)

    # 'accessor' convenience methods

    def getPID(self):
        """Get the value of the 'pid' field of a Product doc."""
        return self.getFieldVal(self.PID)

    def getName(self):
        """Get the value of the 'name' field of a Product doc."""
        return self.getFieldVal(self.PRODUCT_NAME)

    def getDescription(self):
        """Get the value of the 'description' field of a Product doc."""
        return self.getFieldVal(self.DESCRIPTION)

    def getCategory(self):
        """Get the value of the 'category' field of a Product doc."""
        return self.getFieldVal(self.CATEGORY)

    def setCategory(self, cat):
        """Set the value of the 'category' field of a Product doc.

        The category is indexed as an AtomField (see
        _buildCoreProductFields), so the replacement field must be an
        AtomField too.
        """
        return self.setFirstField(
            search.AtomField(name=self.CATEGORY, value=cat))

    def getAvgRating(self):
        """Get the value of the 'ar' (average rating) field of a Product
        doc."""
        return self.getFieldVal(self.AVG_RATING)

    def setAvgRating(self, ar):
        """Set the value of the 'ar' (average rating) field of a Product
        doc."""
        return self.setFirstField(
            search.NumberField(name=self.AVG_RATING, value=ar))

    def getPrice(self):
        """Get the value of the 'price' field of a Product doc."""
        return self.getFieldVal(self.PRICE)

    @classmethod
    def generateRatingsBuckets(cls, query_string):
        """Build a dict of ratings 'buckets' and their counts.

        The buckets are keyed by the integer part of the 'ar' (average
        rating) field of each document retrieved by the given query. See the
        'generateRatingsLinks' method. This information is used to generate
        sidebar links that allow the user to drill down in query results
        based on rating.

        For demonstration purposes only; this will be expensive for large
        data sets.

        Args:
          query_string: the query to run against the product index.

        Returns:
          A defaultdict(int) mapping bucket -> count, or None if the search
          raised search.Error.
        """
        # Run the query to generate the facet information, imitating what
        # may in future be provided by the FTS API.
        # NOTE(review): no query options are passed, so this presumably only
        # sees the API's default page of results -- confirm.
        try:
            sq = search.Query(query_string=query_string.strip())
            search_results = cls.getIndex().search(sq)
        except search.Error:
            logging.exception('An error occurred on search.')
            return None

        ratings_buckets = collections.defaultdict(int)
        # populate the buckets; a doc without a rating counts as 0
        for res in search_results:
            ratings_buckets[int(cls(res).getAvgRating() or 0)] += 1
        return ratings_buckets

    @classmethod
    def generateRatingsLinks(cls, query, phash):
        """Build the sidebar rating links for the results of a query.

        Each link re-runs the query, additionally filtered by the indicated
        ratings interval.

        Args:
          query: the query string whose results are bucketed by rating.
          phash: dict of the current request's query parameters. It is not
            modified; each link is built from a copy with 'rating' set.

        Returns:
          A list of (link, text) tuples, one per rating from
          config.RATING_MIN to config.RATING_MAX, or None if no bucket
          information is available.
        """
        ratings_buckets = cls.generateRatingsBuckets(query)
        if not ratings_buckets:
            return None
        # Work on a copy so the caller's dict is not left holding the last
        # 'rating' value written by the loop.
        link_params = dict(phash)
        rlist = []
        for k in range(config.RATING_MIN, config.RATING_MAX + 1):
            # .get avoids inserting empty buckets into the defaultdict.
            v = ratings_buckets.get(k, 0)
            # build html
            if k < 5:
                htext = '%s-%s (%s)' % (k, k + 1, v)
            else:
                htext = '%s (%s)' % (k, v)
            link_params['rating'] = k
            hlink = '/psearch?' + urllib.urlencode(link_params)
            rlist.append((hlink, htext))
        return rlist

    @classmethod
    def _buildCoreProductFields(
            cls, pid, name, description, category, category_name, price):
        """Construct the 'core' document field list common to all Products.

        The various categories (as defined in the file 'categories.py') may
        add additional specialized fields; these will be appended to this
        core list (see _buildProductFields).

        Args:
          pid: product id; stored in the 'pid' field.
          name: product name.
          description: product description; markup is stripped. None is
            treated as an empty description.
          category: category value, stored as an atom field.
          category_name: accepted for signature compatibility; not used.
          price: product price, stored as a number field.

        Returns:
          A list of search field objects.
        """
        # Strip the markup from the description value, which can
        # potentially come from user input. We do this so that we don't
        # need to sanitize the description in the templates, showing off
        # the Search API's ability to mark up query terms in generated
        # snippets. This is done only for demonstration purposes; in an
        # actual app, it would be preferable to use a library like
        # Beautiful Soup instead. We'll let the templating library escape
        # all other rendered values for us, so this is the only field we
        # do this for.
        plain_description = re.sub(r'<[^>]*?>', '', description or '')
        return [
            search.TextField(name=cls.PID, value=pid),
            # The 'updated' field is always set to the current date.
            search.DateField(name=cls.UPDATED,
                             value=datetime.datetime.now().date()),
            search.TextField(name=cls.PRODUCT_NAME, value=name),
            search.TextField(name=cls.DESCRIPTION, value=plain_description),
            search.AtomField(name=cls.CATEGORY, value=category),
            search.NumberField(name=cls.AVG_RATING, value=0.0),
            search.NumberField(name=cls.PRICE, value=price),
        ]

    @classmethod
    def _buildProductFields(cls, pid=None, category=None, name=None,
                            description=None, category_name=None, price=None,
                            **params):
        """Build the full field list for a document of the given category.

        Starts from the 'core' fields and appends the additional
        category-specific fields, whose values are taken from params. All
        such additional fields are treated as required.

        Args:
          pid, category, name, description, category_name, price: passed
            through to _buildCoreProductFields.
          **params: values for the category-specific fields, keyed by field
            name.

        Returns:
          The list of search field objects for the document.

        Raises:
          errors.OperationFailedError: if a category-specific field has no
            value in params, or a number field's value cannot be converted
            to float.
        """
        fields = cls._buildCoreProductFields(
            pid, name, description, category, category_name, price)
        # get the specification of additional (non-'core') fields for this
        # category
        pdict = categories.product_dict.get(category_name)
        if not pdict:
            # No field specification exists for this category; only the
            # core fields are indexed.
            logging.warning(
                'product field information not found for category name %s',
                category_name)
            return fields

        for k, field_type in pdict.items():
            # Every field in the specification must have a value in params.
            if k not in params:
                error_message = (
                    'value not given for field "%s" of field type "%s"'
                    % (k, field_type))
                logging.warning(error_message)
                raise errors.OperationFailedError(error_message)
            v = params[k]
            if field_type == search.NumberField:
                try:
                    val = float(v)
                except (TypeError, ValueError):
                    error_message = (
                        'bad value %s for field %s of type %s' %
                        (v, k, field_type))
                    logging.error(error_message)
                    raise errors.OperationFailedError(error_message)
                fields.append(search.NumberField(name=k, value=val))
            elif field_type == search.TextField:
                fields.append(search.TextField(name=k, value=str(v)))
            else:
                # you may want to add handling of other field types for
                # generality. Not needed for our current sample data.
                logging.warning(
                    'not processed: %s, %s, of type %s', k, v, field_type)
        return fields

    @classmethod
    def _createDocument(
            cls, pid=None, category=None, name=None, description=None,
            category_name=None, price=None, **params):
        """Create a Document object from the given params.

        Args:
          pid: product id; also used as the doc id. Required.
          category: product category. Required.
          name: product name. Required.
          description, category_name, price, **params: passed through to
            _buildProductFields.

        Returns:
          A search.Document whose doc_id is pid.

        Raises:
          errors.OperationFailedError: if a required parameter is missing,
            pid is not a legal doc id, or field construction fails.
        """
        # check for the fields that are always required.
        if not (pid and category and name):
            raise errors.OperationFailedError('Missing parameter.')
        # The pid will be used as the doc_id, so it must contain only
        # visible ascii characters and no whitespace.
        if not cls.isValidDocId(pid):
            raise errors.OperationFailedError("Illegal pid %s" % pid)
        # construct the document fields from the params
        resfields = cls._buildProductFields(
            pid=pid, category=category, name=name,
            description=description,
            category_name=category_name, price=price, **params)
        # Use the pid (product id) as the doc id. (If we did not do this,
        # and left the doc_id unspecified, an id would be auto-generated.)
        return search.Document(doc_id=pid, fields=resfields)

    @classmethod
    def _normalizeParams(cls, params):
        """Normalize the submitted params for building a product.

        Works on a deep copy: strips 'pid' and 'name', copies 'category'
        into 'category_name', and converts 'price' to float.

        Args:
          params: dict with at least 'pid', 'name', 'category' and 'price'.

        Returns:
          The normalized copy of params.

        Raises:
          errors.OperationFailedError: if a required key is missing or the
            price cannot be converted to a number.
        """
        params = copy.deepcopy(params)
        try:
            params['pid'] = params['pid'].strip()
            params['name'] = params['name'].strip()
            params['category_name'] = params['category']
            try:
                params['price'] = float(params['price'])
            except (TypeError, ValueError):
                error_message = 'bad price value: %s' % params['price']
                logging.error(error_message)
                raise errors.OperationFailedError(error_message)
            return params
        except KeyError as e1:
            logging.exception("key error")
            raise errors.OperationFailedError(e1)
        except errors.Error as e2:
            logging.debug(
                'Problem with params: %s: %s', params, e2.error_message)
            raise errors.OperationFailedError(e2.error_message)

    @classmethod
    def buildProductBatch(cls, rows):
        """Build product documents and their datastore entities, in batch.

        Should be used for new products, as it does not handle updates of
        existing product entities. This method does not require that the doc
        ids be tied to the product ids; it obtains the doc ids from the
        results of the document add.

        Rows that fail normalization or document creation are logged and
        skipped. If indexing fails, nothing is written to the datastore.

        Args:
          rows: a list of params dicts, one per product.

        Raises:
          errors.OperationFailedError: if the number of indexing results
            does not match the number of documents submitted.
        """
        docs = []
        dbps = []
        for row in rows:
            try:
                params = cls._normalizeParams(row)
                doc = cls._createDocument(**params)
                # create product entity, sans doc_id
                dbp = models.Product(
                    id=params['pid'], price=params['price'],
                    category=params['category'])
            except errors.OperationFailedError:
                logging.error('error creating document from data: %s', row)
                continue
            # Append both only once both exist, so the lists stay aligned.
            docs.append(doc)
            dbps.append(dbp)
        try:
            add_results = cls.add(docs)
        except search.Error:
            logging.exception('Add failed')
            return
        if add_results is None:
            # cls.add() logs a search.Error itself and returns None.
            return
        if len(add_results) != len(dbps):
            # this case should not be reached; an indexing problem should
            # have been reported as a search.Error, above.
            raise errors.OperationFailedError(
                'Error: wrong number of results returned from indexing '
                'operation')
        # now set the entities with the doc ids, the list of which are
        # returned in the same order as the list of docs given to the
        # indexers
        for dbp, result in zip(dbps, add_results):
            dbp.doc_id = result.id
        # persist the entities
        ndb.put_multi(dbps)

    @classmethod
    def buildProduct(cls, params):
        """Create/update a product document and its datastore entity.

        The product id and the field values are taken from the params dict.

        Args:
          params: dict of product parameters (see _normalizeParams).

        Returns:
          The created or updated models.Product entity.

        Raises:
          errors.OperationFailedError: if the params are invalid or the
            document could not be indexed.
        """
        params = cls._normalizeParams(params)
        # check to see if doc already exists. We do this because we need to
        # retain some information from the existing doc. We could skip the
        # fetch if this were not the case.
        curr_doc = cls.getDocFromPid(params['pid'])
        d = cls._createDocument(**params)
        if curr_doc:  # retain ratings info from existing doc
            avg_rating = cls(curr_doc).getAvgRating()
            # getAvgRating returns None when the field cannot be read; in
            # that case keep the new document's default rating.
            if avg_rating is not None:
                cls(d).setAvgRating(avg_rating)

        # This will reindex if a doc with that doc id already exists
        doc_ids = cls.add(d)
        # cls.add() returns None when indexing failed with a search.Error.
        if not doc_ids:
            raise errors.OperationFailedError('could not index document')
        doc_id = doc_ids[0].id
        logging.debug(
            'got new doc id %s for product: %s', doc_id, params['pid'])

        # now update the entity
        def _tx():
            # Check whether the product entity exists. If so, we want to
            # update from the params, but preserve its ratings-related
            # info.
            prod = models.Product.get_by_id(params['pid'])
            if prod:  # update
                prod.update_core(params, doc_id)
            else:  # create new entity
                prod = models.Product.create(params, doc_id)
            prod.put()
            return prod

        prod = ndb.transaction(_tx)
        logging.debug('prod: %s', prod)
        return prod