2to3 (compiles, not tested)
[tag_parser.git] / tests / integration / tag_chart / test_tag_with_cfg.py
bloba2d70ec38e87df9a3a4ce831193fc336cc922b62
1 # This Python file uses the following encoding: utf-8
2 '''
3 Created on May 14, 2011
5 @author: mjacob
6 '''
7 # This Python file uses the following encoding: utf-8
8 # -*- coding: UTF-8 -*-
9 from nltk.parse.earleychart import EarleyChartParser
10 import nltk
11 from nltk.tree import ImmutableTree
12 from mjacob.algorithms.generate_random import generate_random_sentence
13 """ integration tests - does a specific parser puke on any of these TAGs? """
14 import argparse
15 import sys
16 import re
17 import logging
18 logging.basicConfig(level=logging.WARN, stream=sys.stdout)
def get_class(module_name, class_name):
    """Import ``module_name`` and return its attribute named ``class_name``.

    Args:
        module_name: Dotted path of the module to import, e.g. ``"pkg.mod"``.
        class_name: Name of the attribute (normally a class) to look up on
            the imported module.

    Returns:
        The object bound to ``class_name`` in the imported module.

    Raises:
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no attribute ``class_name``.
    """
    # A non-empty ``fromlist`` makes ``__import__`` return the leaf module of
    # a dotted path instead of its top-level package.
    module = __import__(module_name, fromlist=[class_name])
    # Use the ``getattr`` builtin rather than calling the dunder directly.
    return getattr(module, class_name)
class ParseTests(object):
    """Compare a CFG parser under test against NLTK's Earley chart parser.

    Random sentences are generated from a context-free grammar and parsed by
    both parsers; every sentence for which the two sets of parses differ is
    recorded as an error.
    """

    def __init__(self, args):
        """Store the parsed command-line arguments.

        Args:
            args: Object with a ``parser`` attribute holding the dotted path
                of the parser class under test (``"package.module.Class"``).
        """
        self.args = args

    @staticmethod
    def _split_parser_spec(spec):
        """Split ``"package.module.Class"`` into ``(module_name, class_name)``.

        Raises:
            ValueError: If ``spec`` contains no ``.`` followed by a name.
        """
        # Raw string: ``\.`` and ``\w`` are regex escapes, not string escapes.
        match = re.match(r'(.*)\.(\w+)', spec)
        if match is None:
            raise ValueError(
                "parser must be given as 'module.ClassName', got %r" % (spec,))
        return match.groups()

    def run_tests(self, n=100, grammar_file='grammars/spanish_grammars/spanish2.cfg'):
        """Generate ``n`` random sentences and compare both parsers on each.

        Progress is printed to stdout, one line per sentence. Sentences for
        which the reference parser yields more than five parses are skipped.

        Args:
            n: Number of random sentences to generate.
            grammar_file: Resource path of the grammar, loaded through
                ``nltk.data.load``.

        Returns:
            1 if any sentence produced differing parses, otherwise 0.

        Raises:
            ValueError: If ``self.args.parser`` is not of the form
                ``"module.ClassName"``.
        """
        module_name, class_name = self._split_parser_spec(self.args.parser)
        parser_class = get_class(module_name, class_name)

        grammar = nltk.data.load(grammar_file)
        parser = parser_class(grammar)
        base_parser = EarleyChartParser(grammar)

        errors = []
        # Only one mismatch is dumped in full so the output stays readable.
        detail_shown = False
        for i in range(n):
            print(i, end=' ')
            sent = generate_random_sentence(grammar)
            print(len(errors), end=' ')
            print(sent, end=' ')
            tokens = sent.split(' ')
            good_parses = set(base_parser.nbest_parse(tokens, tree_class=ImmutableTree))
            print(len(good_parses), end=' ')
            if len(good_parses) > 5:
                print("skipping!")
                continue
            found_parses = set(parser.nbest_parse(tokens))
            print(len(found_parses))

            if found_parses == good_parses:
                continue

            if len(found_parses) != len(good_parses):
                errors.append("different number of parses found for \"%s\" (%s vs %s)" % (sent, len(good_parses), len(found_parses)))
            else:
                errors.append("different parses found for \"%s\" (%s)" % (sent, len(good_parses)))

            # NOTE(review): the source's indentation was lost; this assumes
            # the flag is set only once a detailed dump has been printed.
            if not detail_shown and len(good_parses) == 1:
                print("GOOD:")
                print("\n".join(repr(x) for x in good_parses))
                print("FOUND:")
                print("\n".join(repr(x) for x in found_parses))
                detail_shown = True

        if errors:
            print("%s errors (out of %s tests)" % (len(errors), n))
            print("\n".join(errors))
            return 1

        print("%s tests passed" % (n))
        # Explicit success status instead of an implicit ``None``.
        return 0
def parse_arguments(argv=None):
    """Parse the command line of the integration-test script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default), argparse reads the arguments from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` whose ``parser`` attribute holds the
        positional argument naming the CFG parser class to test.
    """
    arg_parser = argparse.ArgumentParser(
        description='integration tests for CFG parsers')
    arg_parser.add_argument('parser', metavar='parser class', type=str,
                            help='a cfg parser class')
    return arg_parser.parse_args(argv)
def main():
    """Run the parser comparison and exit with its status code.

    Parses the command line, runs ``ParseTests.run_tests`` with its default
    settings and terminates the process with the value it returns.
    """
    args = parse_arguments()
    test_runner = ParseTests(args)
    exit_code = test_runner.run_tests()
    # ``sys.exit`` rather than the ``exit`` builtin: the latter is injected by
    # the ``site`` module for interactive use and is not always available.
    sys.exit(exit_code)


if __name__ == "__main__":
    main()