Update to remove faulty gold parses. Seems done now. :)
[dmvccm.git] / src / wsjdep.py
blob1f66590505b00553fcdbb3397d153d8d2d857461
1 from nltk.corpus.reader.util import *
2 from nltk.corpus.reader.bracket_parse import BracketParseCorpusReader
3 from common_dmv import MPPROOT
class WSJDepCorpusReader(BracketParseCorpusReader):
    """Reader for the dependency-parsed WSJ10.

    One-word sentences are skipped, since these are not parsed (and
    thus not POS-tagged!).  All implemented foo_sents() functions
    should therefore be of length 6268, as there are 38 one-word
    sentences.
    """

    def __init__(self, root=None):
        # Bug fix: ``root`` used to be ignored in favour of the
        # hard-coded path.  Honor it when given; fall back to the
        # historical default so existing callers (which pass None)
        # behave exactly as before.
        if root is None:
            root = "../corpus/wsjdep"  # path to files
        BracketParseCorpusReader.__init__(self,
                                          root,
                                          ['wsj.combined.10.dep'])  # file-list or regexp

    def _read_block(self, stream):
        # One block per <sentence id="..."> element in the XML file.
        return read_regexp_block(stream,
                                 start_re=r'<sentence id=".+">')

    def _normalize(self, t):
        """Convert one XML sentence block to sexpr notation, more or less."""
        t = re.sub(r'<sentence id="10.+">', r"[ ", t)
        # Collapse the <text> element onto a single line, then turn it
        # into a parenthesized word list.
        t = re.sub(r"\s+<text>\s+(.*)\s+</text>", r"<text>\1</text>", t)
        t = re.sub(r"<text>((.|\n)*)</text>", r"(\1)\\n", t)
        # A <rel> becomes "((HEAD ...) (DEP ...) ...)".
        t = re.sub(r'<rel label=".*?">', r'', t)
        t = re.sub(r'\s+<head id="(\d+?)" pos="(.+?)">(.+)</head>', r'((\2 \1 \3)', t)
        t = re.sub(r'\s+<dep id="(\d+?)" pos="(.+?)">(.+)</dep>', r'(\2 \1 \3)', t)
        t = re.sub(r'\s+</rel>', r')\\n', t)
        t = re.sub(r"\s*</sentence>", r"]", t)
        # \\n means "add an \n later", since we keep removing them
        t = re.sub(r"\\n", r"\n", t)
        return t

    def _parse(self, t):
        return dep_parse(self._normalize(t))

    def _tag(self, t):
        # Pair each word with its POS tag, both in sentence order.
        tagonly = self._tagonly(t)
        tagged_sent = zip(self._word(t), tagonly)
        return tagged_sent

    def _word(self, t):
        # The first fully parenthesized span of the normalized block is
        # the sentence text produced by _normalize().
        PARENS = re.compile(r'\(.+\)')
        sentence = PARENS.findall(self._normalize(t))[0]
        WORD = re.compile(r'([^\s()]+)')
        words = WORD.findall(sentence)
        if len(words) < 2:
            return []  # skip one-word sentences!
        else:
            return words

    def _get_tagonly_sent(self, parse):
        "Convert dependency parse into a sorted taglist"
        if not parse:
            return None
        # Collect every (tag, loc) pair mentioned by the parse.
        tagset = set()
        for head, dep in parse:
            tagset.add(head)
            tagset.add(dep)
        # Sort by sentence position; equivalent to the old Python 2
        # cmp-sort lambda x,y: x[1]-y[1].
        taglist = sorted(tagset, key=lambda pair: pair[1])
        return [tag for tag, loc in taglist]

    # tags_and_parse
    def _tags_and_parse(self, t):
        parse = dep_parse(self._normalize(t))
        return (self._get_tagonly_sent(parse), parse)

    def _read_tags_and_parse_sent_block(self, stream):
        # Drop sentences with no tags or no parse (one-word sentences).
        tags_and_parse_sents = [self._tags_and_parse(t) for t in self._read_block(stream)]
        return [(tag, parse) for (tag, parse) in tags_and_parse_sents if tag and parse]

    def tagged_and_parsed_sents(self, files=None):
        return concat([StreamBackedCorpusView(filename,
                                              self._read_tags_and_parse_sent_block)
                       for filename in self.abspaths(files)])

    # tagonly:
    def _tagonly(self, t):
        parse = dep_parse(self._normalize(t))
        return self._get_tagonly_sent(parse)

    def _read_tagonly_sent_block(self, stream):
        tagonly_sents = [self._tagonly(t) for t in self._read_block(stream)]
        return [tagonly_sent for tagonly_sent in tagonly_sents if tagonly_sent]

    def tagonly_sents(self, files=None):
        return concat([StreamBackedCorpusView(filename,
                                              self._read_tagonly_sent_block)
                       for filename in self.abspaths(files)])
def dep_parse(s):
    """Parse a _normalize()d sentence block into a set of dependency
    pairs ((head_tag, head_loc), (dep_tag, dep_loc)).

    Returns None if the block contains no dependency relation (e.g. a
    one-word sentence).  Raises ValueError on malformed input.

    todo: add ROOT, which is implicitly the only non-dependent tagloc
    """
    SPACE = re.compile(r'\s*')
    WORD = re.compile(r'\s*([^\s\(\)]*)\s*')
    RELSTART = re.compile(r'\(\(')

    def read_tagloc(pos):
        # Read "TAG LOC WORD" starting just past a '((' or ')(' token;
        # only (tag, loc) is kept, the word itself is skipped.
        match = WORD.match(s, pos + 2)
        tag = match.group(1)
        pos = match.end()

        match = WORD.match(s, pos)
        loc = int(match.group(1))
        pos = match.end()

        match = WORD.match(s, pos)  # skip the actual word
        pos = match.end()

        return (pos, tag, loc)

    # Skip any initial whitespace and actual sentence
    match = RELSTART.search(s, 0)
    if match:
        pos = match.start()
    else:
        # eg. one word sentence, no dependency relation
        return None

    parse = set()
    head, loc_h = None, None
    while pos < len(s):
        # Beginning of a sentence
        if s[pos] == '[':
            pos = SPACE.match(s, pos + 1).end()
        # End of a sentence
        if s[pos] == ']':
            pos = SPACE.match(s, pos + 1).end()
            if pos != len(s):
                raise ValueError("Trailing garbage following sentence")
            return parse
        # Beginning of a relation, head:
        elif s[pos:pos + 2] == '((':
            pos, head, loc_h = read_tagloc(pos)
        # Dependent(s):
        elif s[pos:pos + 2] == ')(':
            pos, arg, loc_a = read_tagloc(pos)
            # Each head-arg relation gets its own pair in parse,
            # although in xml we may have
            # <rel><head/><dep/><dep/><dep/></rel>
            parse.add(((head, loc_h), (arg, loc_a)))
        elif s[pos:pos + 2] == '))':
            pos = SPACE.match(s, pos + 2).end()
        else:
            print("s: %s\ns[%d]=%s" % (s, pos, s[pos]))
            raise ValueError('unexpected token')

    # Bug fix: pos >= len(s) here, so the old diagnostic indexed s[pos]
    # out of range and raised IndexError instead of the ValueError below.
    print("s: %s\ns[%d] is past the end of the input" % (s, pos))
    raise ValueError('mismatched parens (or something)')
def add_root(parse):
    """Return parse with ROOT added.

    The root is the unique head that never occurs as a dependent.
    Raises RuntimeError if there is no such head, or more than one.
    """
    # Hoisted both set-builds out of the loop: the dependent set used
    # to be recomputed on every iteration (accidental O(n^2)).
    heads = set(h for h, a in parse)
    dependents = set(a for h, a in parse)
    rooted = None
    for head_loc in heads:
        if head_loc not in dependents:
            if rooted:
                raise RuntimeError("Several possible roots in parse: %s" % (list(parse),))
            else:
                rooted = head_loc
    if not rooted:
        raise RuntimeError("No root in parse!")
    rootpair = (MPPROOT, rooted)
    return parse.union(set([rootpair]))
174 if __name__ == "__main__":
175 print "WSJDepCorpusReader tests:"
176 reader = WSJDepCorpusReader(None)
178 print "Sentences:"
179 print reader.sents()
181 print "Tagged sentences:"
182 print reader.tagged_sents()
184 parsedsents = reader.parsed_sents()
185 # print "Number of sentences: %d"%len(parsedsents) # takes a while
186 import pprint
187 print "First parsed sentence:"
188 pprint.pprint(parsedsents[0])
190 tags_and_parses = reader.tagged_and_parsed_sents()
191 print "121st tagged and then parsed sentence:"
192 pprint.pprint(tags_and_parses[121])