# error in main.py with cnf reestimation, todo
# [dmvccm.git] / src / wsjdep.py
# blob eb6a666aa9923eadb2a2593515f7ae5ae0ee56f
import re

from nltk.corpus.reader.util import *
from nltk.corpus.reader.bracket_parse import BracketParseCorpusReader
class WSJDepCorpusReader(BracketParseCorpusReader):
    '''Reader for the dependency-parsed WSJ10.  Will not include one-word
    sentences, since these are not parsed (and thus not
    POS-tagged!). All implemented foo_sents() functions should now be
    of length 6268 since there are 38 one-word sentences.'''

    def __init__(self, root):
        # `root` used to be ignored entirely; honour it when given, but
        # keep the historical hard-coded path as fallback so existing
        # callers passing None still work.
        BracketParseCorpusReader.__init__(
            self,
            root or "../corpus/wsjdep",   # path to files
            ['wsj.combined.10.dep'])      # file-list or regexp

    def _read_block(self, stream):
        # One raw block per <sentence ...> XML element.
        return read_regexp_block(stream,
                                 start_re=r'<sentence id=".+">')

    def _normalize(self, t):
        # convert XML to sexpr notation, more or less
        t = re.sub(r'<sentence id="10.+">', r"[ ", t)
        t = re.sub(r"\s+<text>\s+(.*)\s+</text>", r"<text>\1</text>", t)
        t = re.sub(r"<text>((.|\n)*)</text>", r"(\1)\\n", t)
        t = re.sub(r'<rel label=".*?">', r'', t)
        t = re.sub(r'\s+<head id="(\d+?)" pos="(.+?)">(.+)</head>', r'((\2 \1 \3)', t)
        t = re.sub(r'\s+<dep id="(\d+?)" pos="(.+?)">(.+)</dep>', r'(\2 \1 \3)', t)
        t = re.sub(r'\s+</rel>', r')\\n', t)
        t = re.sub(r"\s*</sentence>", r"]", t)
        # \\n means "add an \n later", since we keep removing them
        t = re.sub(r"\\n", r"\n", t)
        return t

    def _parse(self, t):
        # Dependency parse of one normalized sentence block.
        return dep_parse(self._normalize(t))

    def _tag(self, t):
        # Pair each word with its POS tag.  list() keeps the historical
        # Python 2 zip() behaviour (a materialized list) on Python 3.
        tagonly = self._tagonly(t)
        tagged_sent = list(zip(self._word(t), tagonly))
        return tagged_sent

    def _word(self, t):
        # The normalized form puts the raw sentence text in the first
        # parenthesized group; split it into words.
        PARENS = re.compile(r'\(.+\)')
        sentence = PARENS.findall(self._normalize(t))[0]
        WORD = re.compile(r'([^\s()]+)')
        words = WORD.findall(sentence)
        if len(words) < 2:
            return []  # skip one-word sentences!
        else:
            return words

    def _get_tagonly_sent(self, parse):
        "Convert dependency parse into a sorted taglist"
        if not parse:
            return None

        tagset = set()
        for head, dep in parse:
            tagset.add(head)
            tagset.add(dep)
        # (tag, loc) pairs sorted by sentence position; key= replaces the
        # old Python-2-only cmp-style sort(lambda x,y: x[1]-y[1]).
        taglist = sorted(tagset, key=lambda tagloc: tagloc[1])

        return [tag for tag, loc in taglist]

    # tags_and_parse
    def _tags_and_parse(self, t):
        # (sorted tag sequence, dependency parse) for one sentence.
        parse = dep_parse(self._normalize(t))
        return (self._get_tagonly_sent(parse), parse)

    def _read_tags_and_parse_sent_block(self, stream):
        tags_and_parse_sents = [self._tags_and_parse(t) for t in self._read_block(stream)]
        # drop one-word sentences (empty tags / None parse)
        return [(tag, parse) for (tag, parse) in tags_and_parse_sents if tag and parse]

    def tagged_and_parsed_sents(self, files=None):
        return concat([StreamBackedCorpusView(filename,
                                              self._read_tags_and_parse_sent_block)
                       for filename in self.abspaths(files)])

    # tagonly:
    def _tagonly(self, t):
        # POS tags only, in sentence order.
        parse = dep_parse(self._normalize(t))
        return self._get_tagonly_sent(parse)

    def _read_tagonly_sent_block(self, stream):
        tagonly_sents = [self._tagonly(t) for t in self._read_block(stream)]
        return [tagonly_sent for tagonly_sent in tagonly_sents if tagonly_sent]

    def tagonly_sents(self, files=None):
        return concat([StreamBackedCorpusView(filename,
                                              self._read_tagonly_sent_block)
                       for filename in self.abspaths(files)])
def dep_parse(s):
    """Parse a normalized sentence string into a set of dependency arcs.

    *s* is the sexpr-ish form produced by WSJDepCorpusReader._normalize:
    ``[ (raw words...)\\n((HEADTAG loc word)(DEPTAG loc word)...)\\n...]``.
    Returns a set of ((head_tag, head_loc), (dep_tag, dep_loc)) pairs,
    or None when the sentence has no dependency relation (e.g. a
    one-word sentence).  Raises ValueError on malformed input.

    todo: add ROOT, which is implicitly the only non-dependent tagloc
    """
    def read_tagloc(pos):
        # Read "TAG LOC WORD" starting just past a '((' or ')(' marker;
        # returns (new_pos, TAG, LOC) with the word itself discarded.
        match = WORD.match(s, pos+2)
        tag = match.group(1)
        pos = match.end()

        match = WORD.match(s, pos)
        loc = int(match.group(1))
        pos = match.end()

        match = WORD.match(s, pos)  # skip the actual word
        pos = match.end()

        return (pos, tag, loc)

    SPACE = re.compile(r'\s*')
    WORD = re.compile(r'\s*([^\s\(\)]*)\s*')
    RELSTART = re.compile(r'\(\(')

    # Skip any initial whitespace and actual sentence
    match = RELSTART.search(s, 0)
    if match:
        pos = match.start()
    else:
        # eg. one word sentence, no dependency relation
        return None

    parse = set()
    head, loc_h = None, None
    while pos < len(s):
        # Beginning of a sentence
        if s[pos] == '[':
            pos = SPACE.match(s, pos+1).end()
        # End of a sentence
        if s[pos] == ']':
            pos = SPACE.match(s, pos+1).end()
            if pos != len(s):
                raise ValueError("Trailing garbage following sentence")
            return parse
        # Beginning of a relation, head:
        elif s[pos:pos+2] == '((':
            pos, head, loc_h = read_tagloc(pos)
        # Dependent(s):
        elif s[pos:pos+2] == ')(':
            pos, arg, loc_a = read_tagloc(pos)
            # Each head-arg relation gets its own pair in parse,
            # although in xml we may have
            # <rel><head/><dep/><dep/><dep/></rel>
            parse.add(((head, loc_h), (arg, loc_a)))
        elif s[pos:pos+2] == '))':
            pos = SPACE.match(s, pos+2).end()
        else:
            print("s: %s\ns[%d]=%s" % (s, pos, s[pos]))
            raise ValueError('unexpected token')
    # pos has run past the end: do NOT index s[pos] here (the original
    # did, raising IndexError instead of the intended ValueError).
    raise ValueError('mismatched parens (or something): s=%r, pos=%d' % (s, pos))
def add_root(parse):
    """Return *parse* with an explicit (('ROOT', -1), root) arc added.

    The root is the unique (tag, loc) that appears as a head but never
    as a dependent.  Mutates *parse* in place and also returns it.
    Raises ValueError when no root or several candidate roots exist.
    """
    # Build the head/dependent sets once, not per loop iteration as the
    # original did (accidental O(n^2)).
    heads = set([h for h, a in parse])
    dependents = set([a for h, a in parse])

    rooted = None
    for tagloc in heads:
        if tagloc not in dependents:
            if rooted:
                raise ValueError("Several possible roots in parse")
            else:
                rooted = tagloc

    if not rooted:
        raise ValueError("No root in parse!")

    parse.add((('ROOT', -1), rooted))
    return parse
if __name__ == "__main__":
    # Smoke-test the reader against the default corpus location.
    # print(...) call form works identically on Python 2 and 3 for a
    # single argument; the old print statements were Python-2-only.
    print("WSJDepCorpusReader tests:")
    reader = WSJDepCorpusReader(None)

    print("Sentences:")
    print(reader.sents())

    print("Tagged sentences:")
    print(reader.tagged_sents())

    parsedsents = reader.parsed_sents()
    # print("Number of sentences: %d" % len(parsedsents))  # takes a while
    import pprint
    print("First parsed sentence:")
    pprint.pprint(parsedsents[0])

    tags_and_parses = reader.tagged_and_parsed_sents()
    print("121st tagged and then parsed sentence:")
    pprint.pprint(tags_and_parses[121])