# This Python file uses the following encoding: utf-8
'''
Created on Apr 21, 2011

This is roughly based on the interface for nltk's ContextFreeGrammar

@author: mjacob
'''
from nltk.grammar import Nonterminal
from nltk.tree import Tree
import yaml
from collections import deque
from mjacob.annotations.memoized import Memoize
from mjacob.nltk.grammar.TagProduction import TagProduction
from mjacob.nltk.grammar.TagNonterminal import TagNonterminal
from functools import reduce

# some constants
START = 'start'
PRODUCTIONS = 'productions'

class InvalidGrammarException(Exception): pass

class TreeAdjoiningGrammar(object):
    """class that represents a TAG grammar

    A Tree Adjoining Grammar (TAG) is a tuple G = ⟨N, T, I, A, S, f_OA, f_SA⟩ where
      * N, T are disjoint alphabets of non-terminal and terminal symbols
      * S ∈ N is a specific start symbol
      * I is a finite set of initial trees, and A a finite set of auxiliary trees,
        with node labels from N and T ∪ {ε}
      * f_OA and f_SA are functions that represent adjunction constraints:
        - f_OA : {v | v is a vertex in some γ ∈ I ∪ A} → {0, 1} s.t. f_OA(v) = 0 for every v with out-degree 0.
        - f_SA : {v | v is a vertex in some γ ∈ I ∪ A} → P(A) s.t. f_SA(v) = ∅ for every v with out-degree 0.

    Every tree in I ∪ A is called an elementary tree.
    f_OA specifies whether adjunction is obligatory at a node (1/0), and
    f_SA gives the set of auxiliary trees that can be adjoined at a node.
    Only internal nodes can allow adjunction; this specifically rules out
    leaves as adjunction sites.
    """
    @classmethod
    def _convert_cfg(cls, cfg):
        """helper method for converting a CFG to a TAG"""
        productions = []
        for rule in cfg.productions():
            production = TagProduction(rule)
            productions.append(production)
        return {START: cfg.start(),
                PRODUCTIONS: frozenset(productions)}
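    # For illustration only (the CFG below is hypothetical): each CFG rule is
    # wrapped in its own TagProduction, so a plain context-free grammar such as
    #
    #     cfg = <an nltk ContextFreeGrammar with rules S -> NP VP, NP -> 'dog', ...>
    #     tag = TreeAdjoiningGrammar(cfg=cfg)
    #
    # yields one elementary tree per CFG production.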
    def __init__(self,
                 filename=None,
                 grammar=None,
                 cfg=None,
                 productions=None,
                 start=None,
                 loader=yaml.load):
        """create a new TAG from some specification: either a YAML string or file, a CFG,
        or an explicit collection of productions.
        The CFG is assumed to be an nltk CFG grammar, nltk.grammar.ContextFreeGrammar.

        usage:

        TreeAdjoiningGrammar(filename=filename, loader=LOADER) # filename must name a TAG grammar file
                                                                # loader is an optional pickle loader (defaults to yaml)

        TreeAdjoiningGrammar(grammar=grammar, loader=LOADER)    # grammar is a string containing a TAG grammar
                                                                # loader is an optional pickle loader (defaults to yaml)

        TreeAdjoiningGrammar(cfg=cfg)                           # cfg is an NLTK ContextFreeGrammar

        TreeAdjoiningGrammar(productions=productions, start=start) # productions is a collection of TagProduction
                                                                    # start is a Nonterminal
        """
        if len([x for x in (grammar, cfg, filename, productions) if x is not None]) != 1:
            raise ValueError("exactly one of grammar, cfg, filename or productions must be specified")

        if productions is None:
            if filename or grammar:
                if filename:
                    with open(filename) as fh:
                        gram = loader(fh)
                elif grammar:
                    gram = loader(grammar)
                self._productions = frozenset(TagProduction(tree) for tree in gram[PRODUCTIONS])
                self._start = Nonterminal(gram[START]) # S
            else: # cfg
                gram = self._convert_cfg(cfg)
                self._productions = gram[PRODUCTIONS]
                self._start = gram[START] # S
        else:
            self._productions = frozenset(productions)
            self._start = start

        self.__lexical_productions_by_word = self._find_lexical_productions_by_word()
        self.__production_dependencies = self._find_production_dependencies()
        self.__epsilonic = self._find_epsilonic()

        self._precompute()
        self._validate()
    def _precompute(self):
        self.__filtered_productions = {}
        self._terminals = frozenset(self._find_terminals(self._productions))
        self._nonterminals = frozenset(self._find_nonterminals(self._productions)) # N

    def _find_epsilonic(self):
        epsilonic = []
        for production in self._productions:
            if production.is_initial() and production.is_nonlexical() and production.epsilon_treepositions():
                epsilonic.append(production)
        return frozenset(epsilonic)
    def save(self, filename=None, dumper=yaml.dump):
        """turn the current TAG into YAML, or optionally write it to a file.

        useful for seeing what your TAG-ified CFGs look like."""
        yam = dumper({
            START: self._start,
            PRODUCTIONS: self._productions,
        })
        if filename:
            with open(filename, 'w') as fh:
                fh.write(yam)
        else:
            return yam
    def _find_terminals(self, productions):
        for production in productions:
            for leaf_position in production.leaf_treepositions():
                yield production[leaf_position]

    def _find_nonterminals(self, productions):
        for production in productions:
            for pos in production.treepositions():
                node = production.get_node(pos)
                if isinstance(node, TagNonterminal):
                    yield Nonterminal(node.symbol())
    def _find_production_dependencies(self):
        """for a given rule α, find all rules γ into which it may adjoin or substitute."""
        deps_by_prod = {}

        for alpha in self._productions:
            label = alpha.root().symbol()
            deps = set()

            if alpha.is_auxiliary():
                for gamma in self._productions:
                    if gamma.adjunction_treepositions()[label]:
                        deps.add(gamma)
            else:
                for gamma in self._productions:
                    if gamma.substitution_treepositions()[label]:
                        deps.add(gamma)

            deps_by_prod[alpha] = frozenset(deps)

        return deps_by_prod
    def _find_lexical_productions_by_word(self):
        lexical_productions_by_word = {}
        for production in self._productions:
            for leaf in production.terminals():
                if leaf in lexical_productions_by_word:
                    lexical_productions_by_word[leaf].add(production)
                else:
                    lexical_productions_by_word[leaf] = set((production,))
        return lexical_productions_by_word

    def check_coverage(self, tokens):
        missing = [token for token in tokens if token not in self._terminals]
        if missing:
            raise ValueError("Grammar does not cover some of the "
                             "input words: %r." % missing)
    def _validate(self):
        """this method raises an exception if there is something grossly wrong with the specified grammar"""
        if len(self._terminals) == 0:
            raise InvalidGrammarException("there must be at least one terminal in the grammar")
        if len(self._nonterminals) == 0:
            raise InvalidGrammarException("there must be at least one non-terminal symbol in the grammar")
        if not self._start:
            raise InvalidGrammarException("there must be a start symbol in the grammar")

        initial_trees = False
        for production in self.productions(is_initial=True):
            if Nonterminal(production.root().symbol()) == self._start:
                initial_trees = True

            for pos, subtree in production.pos_subtrees():
                node = production.get_node(pos)

                if isinstance(subtree, Tree):
                    if not isinstance(node, TagNonterminal):
                        raise InvalidGrammarException("unexpected node found in non-terminal position: %s" % (subtree.node))

                if isinstance(node, TagNonterminal):
                    if Nonterminal(node.symbol()) not in self._nonterminals:
                        raise InvalidGrammarException("unknown non-terminal '%s' found in tree %s (should be in %s)" % (node.symbol(), production, self._nonterminals))
                    if node.is_foot():
                        raise InvalidGrammarException("foot nodes are not allowed in initial trees %s" % (production))
                else:
                    if node not in self._terminals:
                        raise InvalidGrammarException("unknown terminal '%s' found in tree %s" % (node, production))

        for production in self.productions(is_auxiliary=True):
            foot_node_count = 0

            for pos, subtree in production.pos_subtrees():
                node = production.get_node(pos)

                if isinstance(subtree, Tree):
                    if not isinstance(node, TagNonterminal):
                        raise InvalidGrammarException("unexpected node found in non-terminal position: %s" % (subtree.node))

                if isinstance(node, TagNonterminal):
                    if Nonterminal(node.symbol()) not in self._nonterminals:
                        raise InvalidGrammarException("unknown non-terminal '%s' found in tree %s" % (node.symbol(), production))

                    if node.is_foot():
                        if not isinstance(subtree, TagNonterminal):
                            raise InvalidGrammarException("foot node must be a leaf node!")
                        if node.symbol() != production.root().symbol():
                            raise InvalidGrammarException("foot node label '%s' must be the same as the root node label in %s" % (node.symbol(), production))
                        foot_node_count += 1
                else:
                    if node not in self._terminals:
                        raise InvalidGrammarException("unknown terminal '%s' found in tree %s" % (node, production))

            if foot_node_count == 0:
                raise InvalidGrammarException("no foot node found in auxiliary tree %s" % (production))
            if foot_node_count > 1:
                raise InvalidGrammarException("only one foot node is allowed in an auxiliary tree %s" % (production))

        if not initial_trees:
            raise InvalidGrammarException("there are no starting trees in the given grammar")
    def _filter_production(self, production, **filters):
        for name, value in filters.items():
            if getattr(production, name)() != value:
                return False
        return True

    def _filter_productions(self, **filters):
        keys = tuple(sorted(filters))
        values = tuple(filters[key] for key in keys)
        if keys in self.__filtered_productions and values in self.__filtered_productions[keys]:
            return self.__filtered_productions[keys][values]

        filtered = tuple(production
                         for production in self._productions
                         if self._filter_production(production, **filters))

        if keys not in self.__filtered_productions:
            self.__filtered_productions[keys] = {}
        self.__filtered_productions[keys][values] = filtered
        return filtered
    def terminals(self):
        return self._terminals

    def nonterminals(self):
        return self._nonterminals

    def productions(self, *roots, **filters):
        """returns an iterator of productions, given the specified root(s) and filter(s)"""
        if filters:
            to_iter = self._filter_productions(**filters)
        else:
            to_iter = self._productions

        if roots:
            def symbol(s):
                if isinstance(s, Nonterminal):
                    return s.symbol()
                else:
                    return s

            roots = frozenset(symbol(r) for r in roots)

            return (tree for tree in to_iter if tree.root().symbol() in roots)
        else:
            return (tree for tree in to_iter)
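    # For illustration (hypothetical grammar object `tag`): productions() can be
    # filtered both by root symbol and by TagProduction predicate methods, e.g.
    #
    #     tag.productions(Nonterminal('S'), is_initial=True)  # initial trees rooted in S
    #     tag.productions(is_auxiliary=True)                   # all auxiliary trees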
    def start(self):
        """
        @return: The start symbol of the grammar
        @rtype: L{Nonterminal}
        """
        return self._start

    @Memoize
    def is_lexical(self):
        """
        True if all productions are lexicalised.
        """
        return reduce(bool.__and__, (production.is_lexical() for production in self._productions))

    def is_nonlexical(self):
        """
        True if all lexical rules are "preterminals", that is,
        unary rules which can be separated in a preprocessing step.

        This means that all productions are of the forms
        A -> B1 ... Bn (n>=0), or A -> "s".

        Note: is_lexical() and is_nonlexical() are not opposites.
        There are grammars which are neither, and grammars which are both.
        """
        return reduce(bool.__and__, (production.is_nonlexical() for production in self._productions))

    def is_chomsky_normal_form(self):
        """
        A grammar is of Chomsky normal form if all productions
        are of the forms A -> B C, or A -> "s".
        """
        return reduce(bool.__and__, (production.is_chomsky_normal_form() for production in self._productions))

    def is_binarised(self):
        """
        True if all productions are at most binary.
        Note that there can still be empty and unary productions.
        """
        return reduce(bool.__and__, (production.is_binarised() for production in self._productions))

    def is_flexible_chomsky_normal_form(self):
        """
        True if all productions are of the forms
        A -> B C, A -> B, or A -> "s".
        """
        return reduce(bool.__and__, (production.is_flexible_chomsky_normal_form() for production in self._productions))

    def is_nonempty(self):
        """
        True if there are no empty productions.
        """
        return reduce(bool.__and__, (production.is_nonempty() for production in self._productions))

    def __repr__(self):
        return '<TreeAdjoiningGrammar with %d productions>' % len(self._productions)

    def __str__(self):
        return 'TreeAdjoiningGrammar with %d productions' % len(self._productions)
    def _is_subsequence(self, production, tokens):
        """True if the production's terminal leaves occur, in order, as a
        (not necessarily contiguous) subsequence of the input tokens."""
        leaves = production.terminals()
        i = 0
        j = 0
        while i < len(leaves) and j < len(tokens):
            if leaves[i] == tokens[j]:
                i += 1
                j += 1
            else:
                j += 1
        return i == len(leaves)
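    # For example (illustrative values only): leaves ('the', 'dog') form a
    # subsequence of the tokens ('the', 'big', 'dog'), but not of ('dog', 'the').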
    def reduce(self, tokens):
        """return a SubTagGrammar containing only the productions that could
        take part in a derivation of the given tokens."""
        productions = set()

        for production in self.__epsilonic:
            productions.add(production)

        missing = []
        for token in tokens:
            if token not in self.__lexical_productions_by_word:
                missing.append(token)
            else:
                for production in self.__lexical_productions_by_word[token]:
                    if production in productions:
                        continue
                    if self._is_subsequence(production, tokens):
                        productions.add(production)

        if missing:
            raise ValueError("Grammar does not cover some of the "
                             "input words: %r." % missing)

        fringe = deque(productions)
        while fringe:
            alpha = fringe.popleft()
            for gamma in self.__production_dependencies[alpha]:
                if gamma in productions:
                    continue
                else:
                    productions.add(gamma)
                    fringe.append(gamma)

        initial_rule_found = False
        for production in productions:
            if production.is_initial() and production.root().symbol() == self._start.symbol():
                initial_rule_found = True
                break
        if not initial_rule_found:
            productions = ()

        #print "%s productions reduced to %s" % (len(self._productions), len(productions))

        return SubTagGrammar(self._start, productions)
class SubTagGrammar(TreeAdjoiningGrammar):
    def __init__(self, start, productions):
        self._start = start
        self._productions = productions

        self._precompute()
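
# A minimal usage sketch (the file name and tokens are hypothetical, and the
# YAML layout is only assumed from how __init__ reads the grammar):
if __name__ == '__main__':
    tag = TreeAdjoiningGrammar(filename='grammar.yaml')  # hypothetical grammar file
    tokens = ['the', 'dog', 'barks']                     # hypothetical input
    tag.check_coverage(tokens)                           # raises ValueError on unknown words
    sub = tag.reduce(tokens)                             # grammar restricted to these tokens
    for production in sub.productions(is_initial=True):
        print(production)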