App Engine Python SDK version 1.9.9
[gae.git] / python / google / appengine / api / search / stub / document_matcher.py
blobfae19d64f84c60aa01349b45603187c3000466e6
1 #!/usr/bin/env python
3 # Copyright 2007 Google Inc.
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
17 """Document matcher for Search API stub.
19 DocumentMatcher provides an approximation of the Search API's query matching.
20 """
23 from google.appengine.datastore import document_pb
25 from google.appengine._internal.antlr3 import tree
26 from google.appengine.api.search import geo_util
27 from google.appengine.api.search import query_parser
28 from google.appengine.api.search import QueryParser
29 from google.appengine.api.search import search_util
30 from google.appengine.api.search.stub import simple_tokenizer
31 from google.appengine.api.search.stub import tokens
# Number of milliseconds in one day (24 h * 60 min * 60 s * 1000 ms).
MSEC_PER_DAY = 24 * 60 * 60 * 1000
class ExpressionTreeException(Exception):
  """Raised when the expression parse tree cannot be analyzed or translated."""

  def __init__(self, msg):
    # Delegate to the base class so the message becomes the exception's args.
    super(ExpressionTreeException, self).__init__(msg)
class DistanceMatcher(object):
  """Matches geo field values by their distance from a reference point.

  The matcher is built from a reference geopoint and a threshold distance;
  IsMatch() then tests a list of geo field values against that threshold
  using an inequality operator.
  """

  def __init__(self, geopoint, distance):
    """Initializer.

    Args:
      geopoint: The reference point that field values are measured against.
      distance: The threshold the measured distance is compared to.
    """
    self._geopoint = geopoint
    self._distance = distance

  def _CheckOp(self, op):
    """Validates that op is an operator usable for distance matching.

    Args:
      op: The query node type of the comparison operator.

    Raises:
      ExpressionTreeException: If op is an equality (EQ/HAS) or != (NE)
        operator.
      search_util.UnsupportedOnDevError: If op is anything other than
        GT, GE, LESSTHAN or LE.
    """
    if op == QueryParser.EQ or op == QueryParser.HAS:
      raise ExpressionTreeException(
          'Equality comparison not available for Geo type')
    if op == QueryParser.NE:
      raise ExpressionTreeException('!= comparison operator is not available')
    supported_ops = (QueryParser.GT, QueryParser.GE,
                     QueryParser.LESSTHAN, QueryParser.LE)
    if op not in supported_ops:
      raise search_util.UnsupportedOnDevError(
          'Operator %s not supported for distance matches on development server.'
          % str(op))

  def _IsDistanceMatch(self, geopoint, op):
    """Tests one geopoint against the threshold distance.

    Note that GT and GE are treated identically (>=), as are LESSTHAN and
    LE (<=); strict and non-strict comparisons are not distinguished.

    Args:
      geopoint: The point to measure from the reference point.
      op: The query node type of the comparison operator.

    Returns:
      True iff the measured distance satisfies the comparison.

    Raises:
      AssertionError: If op is not one of GT, GE, LESSTHAN or LE.
    """
    # Subtracting two points presumably yields their distance -- confirm
    # against geo_util.LatLng.
    distance = geopoint - self._geopoint
    if op == QueryParser.GT or op == QueryParser.GE:
      return distance >= self._distance
    if op == QueryParser.LESSTHAN or op == QueryParser.LE:
      return distance <= self._distance
    # Call form of raise: valid on both Python 2 and Python 3.
    raise AssertionError('unexpected op %s' % str(op))

  def IsMatch(self, field_values, op):
    """Checks whether any of the given geo field values matches.

    Args:
      field_values: An iterable of field values exposing geo(), whose result
        provides lat() and lng().
      op: The query node type of the comparison operator.

    Returns:
      True iff at least one value satisfies the distance comparison.

    Raises:
      ExpressionTreeException: If op is an equality or != operator.
      search_util.UnsupportedOnDevError: If op is otherwise unsupported.
    """
    # Validate the operator up front so a bad operator is reported even when
    # there are no values to test.
    self._CheckOp(op)

    for field_value in field_values:
      geo_pb = field_value.geo()
      geopoint = geo_util.LatLng(geo_pb.lat(), geo_pb.lng())
      if self._IsDistanceMatch(geopoint, op):
        return True

    return False
class DocumentMatcher(object):
  """Decides whether documents match a parsed query tree.

  The matcher walks the query tree and tests each comparison against the
  fields of a document, consulting an inverted index for text tokens.
  """

  def __init__(self, query, inverted_index):
    """Initializer.

    Args:
      query: The root node of the parsed query tree.
      inverted_index: An index object providing GetPostingsForToken().
    """
    self._query = query
    self._inverted_index = inverted_index
    self._parser = simple_tokenizer.SimpleTokenizer()

  def _PostingsForToken(self, token):
    """Returns the postings for the token."""
    return self._inverted_index.GetPostingsForToken(token)

  def _PostingsForFieldToken(self, field, value):
    """Returns postings for the value occurring in the given field."""
    value = simple_tokenizer.NormalizeString(value)
    return self._PostingsForToken(
        tokens.Token(chars=value, field_name=field))

  def _MatchRawPhraseWithRawAtom(self, field_text, phrase_text):
    """Compares a field's text and a phrase after tokenizing both as atoms.

    Args:
      field_text: The raw text of the field.
      phrase_text: The raw text of the phrase.

    Returns:
      True iff both texts tokenize (as ATOM input) to equal token lists.
    """
    tokenized_phrase = self._parser.TokenizeText(
        phrase_text, input_field_type=document_pb.FieldValue.ATOM)
    tokenized_field_text = self._parser.TokenizeText(
        field_text, input_field_type=document_pb.FieldValue.ATOM)
    return tokenized_phrase == tokenized_field_text

  def _MatchPhrase(self, field, match, document):
    """Match a textual field with a phrase query node.

    Args:
      field: The document_pb.Field to test.
      match: The phrase query node.
      document: The document that the field is in.

    Returns:
      True iff the phrase occurs in the field.
    """
    field_text = field.value().string_value()
    phrase_text = query_parser.GetPhraseQueryNodeText(match)

    # Atom fields are compared as a whole rather than word by word.
    if field.value().type() == document_pb.FieldValue.ATOM:
      return self._MatchRawPhraseWithRawAtom(field_text, phrase_text)

    if not phrase_text:
      return False

    phrase = self._parser.TokenizeText(phrase_text)
    field_tokens = self._parser.TokenizeText(field_text)
    # A non-empty phrase that tokenizes to nothing matches any field.
    if not phrase:
      return True

    # Find this document's posting for the first word of the phrase.
    posting = None
    for post in self._PostingsForFieldToken(field.name(), phrase[0].chars):
      if post.doc_id == document.id():
        posting = post
        break
    if not posting:
      return False

    phrase_words = [token.chars for token in phrase]
    field_words = [token.chars for token in field_tokens]
    phrase_len = len(phrase_words)

    # At each position of the first word, the following words of the field
    # must equal the phrase; a window cut short by the end of the field is
    # shorter than the phrase and therefore cannot compare equal.
    for position in posting.positions:
      if field_words[position:][:phrase_len] == phrase_words:
        return True
    return False

  def _MatchTextField(self, field, match, document):
    """Check if a textual field matches a query tree node.

    Args:
      field: The document_pb.Field to test.
      match: The query node to match against.
      document: The document that the field is in.

    Returns:
      True iff the field matches the node.

    Raises:
      ExpressionTreeException: If the node is a negation.
    """
    # A fuzzy node is matched exactly like the node it wraps.
    if match.getType() == QueryParser.FUZZY:
      return self._MatchTextField(field, match.getChild(0), document)

    if match.getType() == QueryParser.VALUE:
      if query_parser.IsPhrase(match):
        return self._MatchPhrase(field, match, document)

      # Atom fields match only when the whole value is equal.
      if field.value().type() == document_pb.FieldValue.ATOM:
        return (field.value().string_value() ==
                query_parser.GetQueryNodeText(match))

      query_tokens = self._parser.TokenizeText(
          query_parser.GetQueryNodeText(match))

      # A value that tokenizes to nothing matches any field.
      if not query_tokens:
        return True

      # A value that tokenizes to several tokens matches only if every
      # token matches on its own.
      if len(query_tokens) > 1:
        def QueryNode(token):
          return query_parser.CreateQueryNode(token.chars, QueryParser.TEXT)
        return all(self._MatchTextField(field, QueryNode(token), document)
                   for token in query_tokens)

      token_text = query_tokens[0].chars
      matching_docids = [
          post.doc_id for post in self._PostingsForFieldToken(
              field.name(), token_text)]
      return document.id() in matching_docids

    def ExtractGlobalEq(node):
      """Unwraps an EQ/HAS node on the global field to its value child."""
      op = node.getType()
      if ((op == QueryParser.EQ or op == QueryParser.HAS) and
          len(node.children) >= 2):
        if node.children[0].getType() == QueryParser.GLOBAL:
          return node.children[1]
      return node

    if match.getType() == QueryParser.CONJUNCTION:
      return all(self._MatchTextField(field, ExtractGlobalEq(child), document)
                 for child in match.children)

    if match.getType() == QueryParser.DISJUNCTION:
      return any(self._MatchTextField(field, ExtractGlobalEq(child), document)
                 for child in match.children)

    if match.getType() == QueryParser.NEGATION:
      raise ExpressionTreeException('Unable to compare \"' + field.name() +
                                    '\" with negation')

    return False

  def _MatchDateField(self, field, match, operator, document):
    """Check if a date field matches a query tree node."""
    # Both sides of the comparison are converted with _DateStrToDays.
    return self._MatchComparableField(
        field, match, _DateStrToDays, operator, document)

  def _MatchNumericField(self, field, match, operator, document):
    """Check if a numeric field matches a query tree node."""
    return self._MatchComparableField(field, match, float, operator, document)

  def _MatchGeoField(self, field, matcher, operator, document):
    """Check if a geo field matches a query tree node.

    Args:
      field: The name of the field, or a query node containing the name.
      matcher: The DistanceMatcher to apply; any other object never matches.
      operator: The query node type of the comparison operator.
      document: The document to match.

    Returns:
      True iff one of the document's GEO values for the field matches.
    """
    if not isinstance(matcher, DistanceMatcher):
      return False

    if isinstance(field, tree.CommonTree):
      field = query_parser.GetQueryNodeText(field)
    # The loop variable has its own name so it cannot rebind the 'field'
    # parameter (list comprehension variables leak on Python 2).
    values = [geo_field.value() for geo_field in
              search_util.GetAllFieldInDocument(document, field)
              if geo_field.value().type() == document_pb.FieldValue.GEO]
    return matcher.IsMatch(values, operator)

  def _MatchComparableField(
      self, field, match, cast_to_type, op, document):
    """A generic method to test matching for comparable types.

    Comparable types are defined to be anything that supports <, >, <=, >=, ==.
    For our purposes, this is numbers and dates.

    Args:
      field: The document_pb.Field to test
      match: The query node to match against
      cast_to_type: The type to cast the node string values to
      op: The query node type representing the type of comparison to perform
      document: The document that the field is in

    Returns:
      True iff the field matches the query.

    Raises:
      UnsupportedOnDevError: Raised when an unsupported operator is used.
      ExpressionTreeException: Raised when a != inequality operator is used.
    """
    field_val = cast_to_type(field.value().string_value())

    # Only VALUE nodes can be compared; anything else never matches.
    if match.getType() != QueryParser.VALUE:
      return False
    try:
      match_val = cast_to_type(query_parser.GetPhraseQueryNodeText(match))
    except ValueError:
      # The query text cannot be converted, so it cannot match this field.
      return False

    if op == QueryParser.EQ or op == QueryParser.HAS:
      return field_val == match_val
    if op == QueryParser.NE:
      raise ExpressionTreeException('!= comparison operator is not available')
    if op == QueryParser.GT:
      return field_val > match_val
    if op == QueryParser.GE:
      return field_val >= match_val
    if op == QueryParser.LESSTHAN:
      return field_val < match_val
    if op == QueryParser.LE:
      return field_val <= match_val
    # Report the operator itself (as DistanceMatcher._CheckOp does) rather
    # than the text of the value node being compared.
    raise search_util.UnsupportedOnDevError(
        'Operator %s not supported for numerical fields on development server.'
        % str(op))

  def _MatchAnyField(self, field, match, operator, document):
    """Check if a field matches a query tree.

    Args:
      field: the name of the field, or a query node containing the field.
      match: A query node to match the field with.
      operator: The query node type corresponding to the type of match to
        perform (eg QueryParser.EQ, QueryParser.GT, etc).
      document: The document to match.

    Returns:
      True iff any field of that name in the document matches.
    """
    if isinstance(field, tree.CommonTree):
      field = query_parser.GetQueryNodeText(field)
    fields = search_util.GetAllFieldInDocument(document, field)
    return any(self._MatchField(f, match, operator, document) for f in fields)

  def _MatchField(self, field, match, operator, document):
    """Check if a field matches a query tree.

    Args:
      field: a document_pb.Field instance to match.
      match: A query node to match the field with.
      operator: The a query node type corresponding to the type of match to
        perform (eg QueryParser.EQ, QueryParser.GT, etc).
      document: The document to match.

    Returns:
      True iff the field matches.

    Raises:
      UnsupportedOnDevError: If the field's type is not handled here.
    """
    field_type = field.value().type()

    if field_type in search_util.TEXT_DOCUMENT_FIELD_TYPES:
      # Text fields only support equality-style operators.
      if operator != QueryParser.EQ and operator != QueryParser.HAS:
        return False
      return self._MatchTextField(field, match, document)

    if field_type in search_util.NUMBER_DOCUMENT_FIELD_TYPES:
      return self._MatchNumericField(field, match, operator, document)

    if field_type == document_pb.FieldValue.DATE:
      return self._MatchDateField(field, match, operator, document)

    # Geo fields never match through this path; in this class they are only
    # evaluated by _MatchGeoField, which _MatchFunction calls.
    if field_type == document_pb.FieldValue.GEO:
      return False

    type_name = document_pb.FieldValue.ContentType_Name(field_type).lower()
    raise search_util.UnsupportedOnDevError(
        'Matching fields of type %s is unsupported on dev server (searched for '
        'field %s)' % (type_name, field.name()))

  def _MatchGlobal(self, match, document):
    """Check if any field of the document equals the given query node.

    Args:
      match: The query node to match against.
      document: The document to match.

    Returns:
      True iff at least one field matches with the EQ operator.
    """
    for field in document.field_list():
      try:
        if self._MatchAnyField(field.name(), match, QueryParser.EQ, document):
          return True
      except search_util.UnsupportedOnDevError:
        # A field of an unsupported type must not abort the scan over the
        # remaining fields, so it is treated as a non-match.
        pass
    return False

  def _ResolveDistanceArg(self, node):
    """Converts one argument of distance() to a field name or a LatLng.

    Args:
      node: The query node of the argument.

    Returns:
      The node's text for a VALUE node, a geo_util.LatLng for a
      geopoint(lat, lng) function node, and None otherwise.
    """
    if node.getType() == QueryParser.VALUE:
      return query_parser.GetQueryNodeText(node)
    if node.getType() == QueryParser.FUNCTION:
      name, args = node.children
      if name.getText() == 'geopoint':
        lat, lng = (float(query_parser.GetQueryNodeText(v))
                    for v in args.children)
        return geo_util.LatLng(lat, lng)
    return None

  def _MatchFunction(self, node, match, operator, document):
    """Check if a function comparison matches the document.

    Only distance(field, geopoint(...)) is handled, with its arguments in
    either order; every other function never matches.

    Args:
      node: The FUNCTION query node on the left of the comparison.
      match: The query node holding the distance to compare with.
      operator: The query node type of the comparison operator.
      document: The document to match.

    Returns:
      True iff the comparison holds for the document.
    """
    name, args = node.children
    if name.getText() != 'distance':
      return False

    x, y = args.children
    x, y = self._ResolveDistanceArg(x), self._ResolveDistanceArg(y)
    # Normalize to (field name, geopoint) order.
    if isinstance(x, geo_util.LatLng) and isinstance(y, basestring):
      x, y = y, x
    if isinstance(x, basestring) and isinstance(y, geo_util.LatLng):
      distance = float(query_parser.GetQueryNodeText(match))
      matcher = DistanceMatcher(y, distance)
      return self._MatchGeoField(x, matcher, operator, document)
    return False

  def _IsHasGlobalValue(self, node):
    """Returns True iff node is a two-child HAS of GLOBAL and a VALUE."""
    if node.getType() != QueryParser.HAS or len(node.children) != 2:
      return False
    return (node.children[0].getType() == QueryParser.GLOBAL and
            node.children[1].getType() == QueryParser.VALUE)

  def _MatchGlobalPhrase(self, node, document):
    """Check if a document matches a parsed global phrase.

    Args:
      node: A query node whose children are each expected to be a HAS of
        the global field and a value.
      document: The document to match.

    Returns:
      True iff the children's values, joined by single spaces, equal the
      text of some field when both are tokenized as atoms.
    """
    if not all(self._IsHasGlobalValue(child) for child in node.children):
      return False

    phrase_text = ' '.join(
        query_parser.GetQueryNodeText(child.children[1])
        for child in node.children)
    for field in document.field_list():
      if self._MatchRawPhraseWithRawAtom(field.value().string_value(),
                                         phrase_text):
        return True
    return False

  def _CheckMatch(self, node, document):
    """Check if a document matches a query tree.

    Args:
      node: The query node to evaluate.
      document: The document to match.

    Returns:
      True iff the document satisfies the node.
    """
    node_type = node.getType()

    if node_type == QueryParser.SEQUENCE:
      # A sequence matches if every child matches, or failing that if the
      # children taken together match as a global phrase.
      result = all(self._CheckMatch(child, document) for child in node.children)
      return result or self._MatchGlobalPhrase(node, document)

    if node_type == QueryParser.CONJUNCTION:
      return all(self._CheckMatch(child, document) for child in node.children)

    if node_type == QueryParser.DISJUNCTION:
      return any(self._CheckMatch(child, document) for child in node.children)

    if node_type == QueryParser.NEGATION:
      return not self._CheckMatch(node.children[0], document)

    if node_type in query_parser.COMPARISON_TYPES:
      lhs, match = node.children
      if lhs.getType() == QueryParser.GLOBAL:
        return self._MatchGlobal(match, document)
      elif lhs.getType() == QueryParser.FUNCTION:
        return self._MatchFunction(lhs, match, node_type, document)
      return self._MatchAnyField(lhs, match, node_type, document)

    return False

  def Matches(self, document):
    """Returns True iff the document matches this matcher's query."""
    return self._CheckMatch(self._query, document)

  def FilterDocuments(self, documents):
    """Returns a generator over the documents that match the query."""
    return (doc for doc in documents if self.Matches(doc))
def _DateStrToDays(date_str):
  """Converts a serialized date string to a count of days.

  Args:
    date_str: A date string accepted by search_util.DeserializeDate.

  Returns:
    search_util.EpochTime of the parsed date divided by MSEC_PER_DAY
    (presumably days since the epoch -- confirm the unit EpochTime returns).
  """
  parsed_date = search_util.DeserializeDate(date_str)
  epoch_time = search_util.EpochTime(parsed_date)
  return epoch_time / MSEC_PER_DAY