# report almost done
# [dmvccm.git] / src / loc_h_dmv.py
# blob 6c97920fa2b39cea4b22e2466b9f3445bf651d1b
1 # loc_h_dmv.py
2 #
3 # dmv reestimation and inside-outside probabilities using loc_h, and
4 # no CNF-style rules
6 # Table of Contents:
7 # 1. Grammar-class and related functions
8 # 2. P_INSIDE / inner() and inner_sent()
9 # 3. P_OUTSIDE / outer()
10 # 4. Reestimation v.1: sentences as outer loop
11 # 5. Reestimation v.2: head-types as outer loop
12 # 6. Most Probable Parse
13 # 7. Testing functions
15 import io
16 from common_dmv import *
18 ### todo: debug with @accepts once in a while, but it's SLOW
19 # from typecheck import accepts, Any
if __name__ == "__main__":
    # Banner printed before the module's self-tests when run as a script.
    print("loc_h_dmv module tests:")
def adj(middle, loc_h):
    '''Return True when `middle` is adjacent to the head at loc_h.

    middle is eg. k when rewriting for i<k<j (inside probabilities).
    It counts as adjacent when it equals the head's left edge (loc_h)
    or its right edge (loc_h+1). ADJ == True.'''
    return middle in (loc_h, loc_h + 1)
def make_GO_AT(p_STOP,p_ATTACH):
    '''Combine attachment and non-stop probabilities into one table.

    For every (a, h, dir) in p_ATTACH and adjacency in (NON, ADJ):
        p_GO_AT[a, h, dir, adjacency] =
            p_ATTACH[a, h, dir] * (1 - p_STOP[h, dir, adjacency])
    Raises KeyError if p_STOP lacks an entry for some (h, dir, adjacency).'''
    p_GO_AT = {}
    for (arg, head, direction), p_attach in p_ATTACH.items():
        for adjacency in (NON, ADJ):
            p_go = 1 - p_STOP[head, direction, adjacency]
            p_GO_AT[arg, head, direction, adjacency] = p_attach * p_go
    return p_GO_AT
class DMV_Grammar(io.Grammar):
    '''DMV grammar: root, stop, attachment and ordering probabilities on
    top of io.Grammar's number<->tag bookkeeping, plus a per-sentence
    cache of inside/outside charts used by reestimate2().'''

    def __str__(self):
        '''Render every non-zero probability, grouped into ROOT, STOP,
        ATTACH and ORDER sections with left/right variants side by side.
        Raises Exception if a stored probability exceeds 1.0 (beyond a
        small floating-point tolerance).'''
        width = 47  # column width of the left-hand entries

        def tag(n):
            return "%d=%s" % (n, self.numtag(n))

        def lookup(table, key):
            # missing keys read as 0.0
            if key not in table:
                return 0.0
            if table[key] > 1.00000001:  # stupid floating point comparisons
                raise Exception("probability > 1.0:%s=%s" % (key, table[key]))
            return table[key]

        def left(fmt, tagstr, prob):
            if prob > 0.0:
                return (fmt % (tagstr, prob)).ljust(width)
            return "".ljust(width)

        def right(fmt, tagstr, prob):
            if prob > 0.0:
                return fmt % (tagstr, prob)
            return ""

        def attach_line(a, h):
            p_left = lookup(self.p_ATTACH, (a, h, LEFT))
            p_right = lookup(self.p_ATTACH, (a, h, RIGHT))
            if p_left == 0.0 and p_right == 0.0:
                return ''
            line = ''
            if p_left > 0.0:
                line = ("p_ATTACH[%s|%s,L] = %s" % (tag(a), tag(h), p_left)).ljust(width)
            if p_right > 0.0:
                line = line.ljust(width) + \
                    "p_ATTACH[%s|%s,R] = %s" % (tag(a), tag(h), p_right)
            return '\n' + line

        roots, stops, attachments, orders = [], [], [], []
        for h in self.headnums():
            roots.append(left("\np_ROOT[%s] = %s", tag(h), lookup(self.p_ROOT, h)))
            stops.extend([
                '\n',
                left("p_STOP[stop|%s,L,adj] = %s", tag(h), lookup(self.p_STOP, (h, LEFT, ADJ))),
                right("p_STOP[stop|%s,R,adj] = %s", tag(h), lookup(self.p_STOP, (h, RIGHT, ADJ))),
                '\n',
                left("p_STOP[stop|%s,L,non] = %s", tag(h), lookup(self.p_STOP, (h, LEFT, NON))),
                right("p_STOP[stop|%s,R,non] = %s", tag(h), lookup(self.p_STOP, (h, RIGHT, NON))),
            ])
            attachments.append(''.join([attach_line(a, h) for a in self.headnums()]))
            orders.extend([
                '\n',
                left("p_ORDER[ left-first|%s ] = %s", tag(h), lookup(self.p_ORDER, (GOL, h))),
                right("p_ORDER[right-first|%s ] = %s", tag(h), lookup(self.p_ORDER, (GOR, h))),
            ])
        return ''.join(roots) + ''.join(stops) + ''.join(attachments) + ''.join(orders)

    def __init__(self, numtag, tagnum, p_ROOT, p_STOP, p_ATTACH, p_ORDER):
        io.Grammar.__init__(self, numtag, tagnum)
        self.p_ROOT = p_ROOT      # p_ROOT[w] = p
        self.p_ORDER = p_ORDER    # p_ORDER[seals, w] = p
        self.p_STOP = p_STOP      # p_STOP[w, LEFT, NON] = p (etc. for LA,RN,RA)
        self.p_ATTACH = p_ATTACH  # p_ATTACH[a, h, LEFT] = p (etc. for R)
        # p_GO_AT[a, h, LEFT, NON] = p (etc. for LA,RN,RA)
        self.p_GO_AT = make_GO_AT(self.p_STOP, self.p_ATTACH)
        # chart cache, used in reestimate2():
        self.reset_iocharts()

    def get_iochart(self, sent_nums):
        '''Return the cached (ichart, ochart) for sentence sent_nums,
        substituting a fresh empty dict for whichever is not cached.'''
        key = tuple(sent_nums)
        return (self._icharts.get(key, {}), self._ocharts.get(key, {}))

    def set_iochart(self, sent_nums, ichart, ochart):
        "Cache the charts of sentence sent_nums (keyed by its tuple)."
        key = tuple(sent_nums)
        self._icharts[key] = ichart
        self._ocharts[key] = ochart

    def reset_iocharts(self):
        "Forget all cached inside/outside charts."
        self._icharts = {}
        self._ocharts = {}

    def p_GO_AT_or0(self, a, h, dir, adj):
        "p_GO_AT[a, h, dir, adj], or 0.0 when that entry is absent."
        return self.p_GO_AT.get((a, h, dir, adj), 0.0)
def locs(sent_nums, start, stop):
    '''Yield (loc_w, w) for every word of the fragment sent_nums[start:stop].

    Locations are offset so that sent_nums[loc_w] == w for every pair
    yielded. start is inclusive, stop is exclusive, as in klein-thesis
    and Python's list-slicing.'''
    for loc_w, w in enumerate(sent_nums[start:stop], start):
        yield (loc_w, w)
132 ###################################################
133 # P_INSIDE (dmv-specific) #
134 ###################################################
136 #@accepts(int, int, (int, int), int, Any(), [str], {tuple:float}, IsOneOf(None,{}))
def inner(i, j, node, loc_h, g, sent, ichart, mpptree=None):
    ''' Inside probability of node (a (seals, POS) pair) headed at loc_h
    over sent[i:j]. The ichart is of this form:
        ichart[i,j,LHS, loc_h]
    where i and j are between-word positions.

    loc_h gives adjacency (along with k for attachment rules), and is
    needed in P_STOP reestimation.

    If mpptree is a dict (even an empty one), the best (p, L, R) split
    seen for each visited key (i, j, node, loc_h) is recorded in it;
    see parse_mpptree(). Raises ValueError for ROOT over a partial
    sentence or for an unknown seal state.'''
    sent_nums = g.sent_nums(sent)

    def terminal(i, j, node, loc_h, tabs):
        if not i <= loc_h < j:
            if 'INNER' in DEBUG:
                print("%s*= 0.0 (wrong loc_h)" % tabs)
            return 0.0
        elif POS(node) == sent_nums[i] and node in g.p_ORDER:
            # todo: add to ichart perhaps? Although, it _is_ simple lookup..
            prob = g.p_ORDER[node]
        else:
            if 'INNER' in DEBUG:
                print("%sLACKING TERMINAL:" % tabs)
            prob = 0.0
        if 'INNER' in DEBUG:
            print("%s*= %.4f (terminal: %s -> %s_%d)" % (tabs, prob, node_str(node), sent[i], loc_h))
        return prob

    def e(i, j, node, loc_h, n_t):
        s_h, h = node

        def to_mpp(p, L, R):
            "Keep (p, L, R) for the current key if it beats the stored split."
            # `is not None` so that an initially empty dict is still filled
            if mpptree is not None:
                key = (i, j, (s_h, h), loc_h)
                if key not in mpptree or mpptree[key][0] < p:
                    mpptree[key] = (p, L, R)

        def tab():
            "Tabs for debug output"
            return "\t" * n_t

        chart_key = (i, j, (s_h, h), loc_h)
        if chart_key in ichart:
            if 'INNER' in DEBUG:
                print("%s*= %.4f in ichart: i:%d j:%d node:%s loc:%s" % (
                    tab(), ichart[chart_key], i, j, node_str((s_h, h)), loc_h))
            return ichart[chart_key]

        # Either terminal rewrites, using p_ORDER:
        if i + 1 == j and (s_h == GOR or s_h == GOL):
            return terminal(i, j, (s_h, h), loc_h, tab())

        # Or not at terminal level yet:
        if 'INNER' in DEBUG:
            print("%s%s (%.1f) from %d to %d" % (tab(), node_str((s_h, h)), loc_h, i, j))
        if s_h == SEAL:
            if h == POS(ROOT):  # only used in testing, o/w we use inner_sent
                h = sent_nums[loc_h]
                if i != 0 or j != len(sent):
                    raise ValueError
                return g.p_ROOT[h] * e(i, j, (SEAL, h), loc_h, n_t + 1)
            p_RGOL = g.p_STOP[h, LEFT, adj(i, loc_h)] * e(i, j, (RGOL, h), loc_h, n_t + 1)
            p_LGOR = g.p_STOP[h, RIGHT, adj(j, loc_h)] * e(i, j, (LGOR, h), loc_h, n_t + 1)
            p = p_RGOL + p_LGOR
            to_mpp(p_RGOL, STOPKEY, (i, j, (RGOL, h), loc_h))
            # the LGOR derivation's child is the LGOR node (was wrongly RGOL)
            to_mpp(p_LGOR, (i, j, (LGOR, h), loc_h), STOPKEY)
            if 'INNER' in DEBUG:
                print("%sp= %.4f (STOP)" % (tab(), p))
        elif s_h == RGOL or s_h == GOL:
            p = 0.0
            if s_h == RGOL:
                p = g.p_STOP[h, RIGHT, adj(j, loc_h)] * e(i, j, (GOR, h), loc_h, n_t + 1)
                to_mpp(p, (i, j, (GOR, h), loc_h), STOPKEY)
            for k in xgo_left(i, loc_h):  # i < k <= loc_l(h)
                p_R = e(k, j, (s_h, h), loc_h, n_t + 1)
                if p_R > 0.0:
                    for loc_a, a in locs(sent_nums, i, k):
                        p_ah = g.p_GO_AT_or0(a, h, LEFT, adj(k, loc_h))
                        if p_ah > 0.0:
                            p_L = e(i, k, (SEAL, a), loc_a, n_t + 1)
                            p_add = p_L * p_ah * p_R
                            p += p_add
                            to_mpp(p_add,
                                   (i, k, (SEAL, a), loc_a),
                                   (k, j, (s_h, h), loc_h))
            if 'INNER' in DEBUG:
                print("%sp= %.4f (ATTACH)" % (tab(), p))
        elif s_h == GOR or s_h == LGOR:
            p = 0.0
            if s_h == LGOR:
                p = g.p_STOP[h, LEFT, adj(i, loc_h)] * e(i, j, (GOL, h), loc_h, n_t + 1)
                to_mpp(p, (i, j, (GOL, h), loc_h), STOPKEY)
            for k in xgo_right(loc_h, j):  # loc_l(h) < k < j
                p_L = e(i, k, (s_h, h), loc_h, n_t + 1)
                if p_L > 0.0:
                    for loc_a, a in locs(sent_nums, k, j):
                        p_ah = g.p_GO_AT_or0(a, h, RIGHT, adj(k, loc_h))
                        p_R = e(k, j, (SEAL, a), loc_a, n_t + 1)
                        p_add = p_L * p_ah * p_R
                        p += p_add
                        to_mpp(p_add,
                               (i, k, (s_h, h), loc_h),
                               (k, j, (SEAL, a), loc_a))
            if 'INNER' in DEBUG:
                print("%sp= %.4f (ATTACH)" % (tab(), p))
        else:
            raise ValueError("inner: unknown seal state %s" % (s_h,))

        ichart[chart_key] = p
        return p
    # end of e-function

    inner_prob = e(i, j, node, loc_h, 0)
    if 'INNER' in DEBUG:
        print(debug_ichart(g, sent, ichart))
    return inner_prob
def debug_ichart(g,sent,ichart):
    '''Return a printable dump of ichart.

    Each entry is shown on one line: the LHS node, the first and last
    word (with their locations) of the span it covers, its loc_h and
    its inside probability.'''
    lines = ["---ICHART:---\n"]
    for (s, t, LHS, loc_h), v in ichart.items():
        # spans are [s, t): the last covered word is sent[t-1]
        lines.append("%s -> %s_%d ... %s_%d (loc_h:%s):\t%s\n" % (
            node_str(LHS, g.numtag), sent[s], s, sent[t - 1], t - 1, loc_h, v))
    lines.append("---ICHART:end---\n")
    return "".join(lines)
def inner_sent(g, sent, ichart):
    '''Probability of the whole sentence.

    This is the sum, over every word w at loc_w, of p_ROOT[w] times the
    inside probability of (SEAL, w) headed at loc_w over all of sent.
    Fills ichart as a side effect.'''
    total = 0
    for loc_w, w in locs(g.sent_nums(sent), 0, len(sent)):
        p_inside = inner(0, len(sent), (SEAL, w), loc_w, g, sent, ichart)
        total += g.p_ROOT[w] * p_inside
    return total
268 ###################################################
269 # P_OUTSIDE (dmv-specific) #
270 ###################################################
272 #@accepts(int, int, (int, int), int, Any(), [str], {tuple:float}, {tuple:float})
def outer(i,j,w_node,loc_w, g, sent, ichart, ochart):
    ''' http://www.student.uib.no/~kun041/dmvccm/DMVCCM.html#outer

    Outside probability of w_node headed at loc_w over sent[i:j].
    w_node is a pair (seals,POS); the w in klein-thesis is made up of
    POS(w) and loc_w. Results are memoised in ochart. Inside
    probabilities are read from ichart; when missing they are computed
    with inner(), which also stores them there.'''
    sent_nums = g.sent_nums(sent)
    if POS(w_node) not in sent_nums[i:j]:
        # sanity check, w must be able to dominate sent[i:j]
        return 0.0

    def inside(start, end, LHS, loc):  # P_{INSIDE}
        key = (start, end, LHS, loc)
        if key in ichart:
            return ichart[key]
        return inner(start, end, LHS, loc, g, sent, ichart)

    def outside(start, end, node, loc):  # P_{OUTSIDE}
        if loc < start or loc >= end:
            return 0.0
        key = (start, end, node, loc)
        if key in ochart:
            return ochart[key]
        if node == ROOT:
            # ROOT may only be used on the full sentence
            if start == 0 and end == len(sent):
                return 1.0
            return 0.0
        # but we may have non-ROOTs (stops) over full sentence too:
        w = POS(node)
        s_w = seals(node)

        if s_w == SEAL:  # w is an argument (or the root word)
            p = g.p_ROOT[w] * outside(start, end, ROOT, loc)
            # w attached to the left of a head h lying in [end, k)
            for k in xgt(end, sent):  # end<k<len(sent)+1
                for loc_h, h in locs(sent_nums, end, k):
                    p_wh = g.p_GO_AT_or0(w, h, LEFT, adj(end, loc_h))
                    for s_h in [RGOL, GOL]:
                        p += outside(start, k, (s_h, h), loc_h) * p_wh * \
                            inside(end, k, (s_h, h), loc_h)
            # w attached to the right of a head h lying in [k, start)
            for k in xlt(start):  # k<start
                for loc_h, h in locs(sent_nums, k, start):
                    p_wh = g.p_GO_AT_or0(w, h, RIGHT, adj(start, loc_h))
                    for s_h in [LGOR, GOR]:
                        p += inside(k, start, (s_h, h), loc_h) * p_wh * \
                            outside(k, end, (s_h, h), loc_h)

        elif s_w == RGOL or s_w == GOL:  # w is the head: left stop + left attach
            parent = SEAL if s_w == RGOL else LGOR
            p = g.p_STOP[w, LEFT, adj(start, loc)] * outside(start, end, (parent, w), loc)
            for k in xlt(start):  # k<start
                for loc_a, a in locs(sent_nums, k, start):
                    p_aw = g.p_GO_AT_or0(a, w, LEFT, adj(start, loc))
                    p += inside(k, start, (SEAL, a), loc_a) * p_aw * \
                        outside(k, end, node, loc)

        elif s_w == GOR or s_w == LGOR:  # w is the head: right stop + right attach
            parent = RGOL if s_w == GOR else SEAL
            p = g.p_STOP[w, RIGHT, adj(end, loc)] * outside(start, end, (parent, w), loc)
            for k in xgt(end, sent):  # end<k<len(sent)+1
                for loc_a, a in locs(sent_nums, end, k):
                    p_ah = g.p_GO_AT_or0(a, w, RIGHT, adj(end, loc))
                    p += outside(start, k, node, loc) * p_ah * \
                        inside(end, k, (SEAL, a), loc_a)

        ochart[key] = p
        return p

    return outside(i, j, w_node, loc_w)
355 ###################################################
356 # Reestimation v.1: #
357 # Sentences as outer loop #
358 ###################################################
def reest_zeros(h_nums):
    '''Return a dict of zeroed numerators and denominators for our 6+
    reestimation formulas, for every head number in h_nums.'''
    # todo: p_ORDER?
    fr = {('ROOT', 'den'): 0.0}  # holds sum over f_sent!! not p_sent...
    for h in h_nums:
        fr['ROOT', 'num', h] = 0.0
        for node in [(s_h, h) for s_h in (GOR, GOL, RGOL, LGOR)]:
            fr['hat_a', 'den', node] = 0.0  # = c()
            # not every argument gets attached, so the numerators
            # fr['hat_a','num',a,(s_h,h)] are created lazily in reest_freq
            for adjacency in (NON, ADJ):
                fr['STOP', 'num', node, adjacency] = 0.0
                fr['STOP', 'den', node, adjacency] = 0.0
    return fr
def reest_freq(g, corpus):
    '''Expected counts for reestimation v.1 (sentences as outer loop).

    For every sentence and every head token loc_h it accumulates, into
    the dict from reest_zeros:
      fr['ROOT','num',h] and fr['ROOT','den']         root counts
      fr['STOP','num'/'den',(s_h,h),adj]              stop counts
      fr['hat_a','den',(s_h,h)], fr['hat_a','num',a,(s_h,h)]
                                                      attachment counts
    Counts are posteriors (already divided by the sentence probability).
    Zero-probability sentences only contribute to fr['ROOT','den'].'''
    fr = reest_zeros(g.headnums())
    # ichart, ochart and p_sent are rebound for every sentence; the
    # local functions below read the current values through closures.
    ichart = {}
    ochart = {}
    p_sent = None # 50 % speed increase on storing this locally

    # local functions altogether 2x faster than global
    def c(i,j,LHS,loc_h,sent):
        "Expected count of LHS headed at loc_h over sent[i:j]."
        if not p_sent > 0.0:
            return p_sent
        p_in = e(i,j, LHS,loc_h,sent)
        if not p_in > 0.0:
            return p_in
        p_out = f(i,j, LHS,loc_h,sent)
        return p_in * p_out / p_sent

    def f(i,j,LHS,loc_h,sent): # P_{OUTSIDE}
        try:
            return ochart[i,j,LHS,loc_h]
        except KeyError:
            return outer(i,j,LHS,loc_h,g,sent,ichart,ochart)

    def e(i,j,LHS,loc_h,sent): # P_{INSIDE}
        try:
            return ichart[i,j,LHS,loc_h]
        except KeyError:
            return inner(i,j,LHS,loc_h,g,sent,ichart)

    def add_hat_a_num(a_k, x):
        "Add the per-argument sums in a_k to the hat_a numerators of x."
        for a, p in a_k.items():
            key = ('hat_a', 'num', a, x)
            fr[key] = fr.get(key, 0.0) + p / p_sent

    def w_left(i,j, x,loc_h,sent,sent_nums):
        "hat_a numerators for arguments attached to the left of x over [i,j)."
        if not p_sent > 0.0:
            return
        # the outside probability of x over [i,j) is the same for every k
        p_out = f(i,j, x,loc_h, sent)
        if not p_out > 0.0:
            return
        h = POS(x)
        a_k = {}
        for k in xtween(i, j):
            p_R = e(k,j, x,loc_h, sent)
            if not p_R > 0.0:
                continue
            for loc_a,a in locs(sent_nums, i,k): # i<=loc_l(a)<k
                p_rule = g.p_GO_AT_or0(a, h, LEFT, adj(k, loc_h))
                p_L = e(i,k, (SEAL,a), loc_a, sent)
                p = p_L * p_out * p_R * p_rule
                a_k[a] = a_k.get(a, 0.0) + p
        add_hat_a_num(a_k, x)

    def w_right(i,j, x,loc_h,sent,sent_nums):
        "hat_a numerators for arguments attached to the right of x over [i,j)."
        if not p_sent > 0.0:
            return
        p_out = f(i,j, x,loc_h, sent)
        if not p_out > 0.0:
            return
        h = POS(x)
        a_k = {}
        for k in xtween(i, j):
            p_L = e(i,k, x,loc_h, sent)
            if not p_L > 0.0:
                continue
            for loc_a,a in locs(sent_nums, k,j): # k<=loc_l(a)<j
                p_rule = g.p_GO_AT_or0(a, h, RIGHT, adj(k, loc_h))
                p_R = e(k,j, (SEAL,a),loc_a, sent)
                p = p_L * p_out * p_R * p_rule
                a_k[a] = a_k.get(a, 0.0) + p
        add_hat_a_num(a_k, x)

    for sent in corpus:
        if 'REEST' in DEBUG:
            print(sent)
        ichart = {}
        ochart = {}
        p_sent = inner_sent(g, sent, ichart)
        fr['ROOT','den'] += 1 # divide by p_sent per h!

        sent_nums = g.sent_nums(sent)

        for loc_h,h in locs(sent_nums,0,len(sent)):
            # root (a zero-probability sentence contributes nothing):
            if p_sent > 0.0:
                fr['ROOT','num',h] += g.p_ROOT[h] * \
                    e(0,len(sent), (SEAL,h),loc_h, sent) / p_sent

            loc_l_h = loc_h
            loc_r_h = loc_l_h+1

            # Left stops; non-adjacent iff i < loc_l(h). GOL has not gone
            # right yet, so GOL/LGOR-by-stop spans always end at loc_r(h).
            for i in xlt(loc_l_h):
                fr['STOP','num',(GOL,h),NON] += c(i, loc_r_h, (LGOR, h),loc_h, sent)
                fr['STOP','den',(GOL,h),NON] += c(i, loc_r_h, (GOL, h),loc_h, sent)
                for j in xgteq(loc_r_h, sent):
                    fr['STOP','num',(RGOL,h),NON] += c(i, j, (SEAL, h),loc_h, sent)
                    fr['STOP','den',(RGOL,h),NON] += c(i, j, (RGOL, h),loc_h, sent)
            # left adjacent stop, i = loc_l_h
            fr['STOP','num',(GOL,h),ADJ] += c(loc_l_h, loc_r_h, (LGOR, h),loc_h, sent)
            fr['STOP','den',(GOL,h),ADJ] += c(loc_l_h, loc_r_h, (GOL, h),loc_h, sent)
            for j in xgteq(loc_r_h, sent):
                fr['STOP','num',(RGOL,h),ADJ] += c(loc_l_h, j, (SEAL, h),loc_h, sent)
                fr['STOP','den',(RGOL,h),ADJ] += c(loc_l_h, j, (RGOL, h),loc_h, sent)

            # Right stops; non-adjacent iff j > loc_r(h). GOR has not gone
            # left yet, so GOR/RGOL-by-stop spans always start at loc_l(h).
            for j in xgt(loc_r_h, sent):
                fr['STOP','num',(GOR,h),NON] += c(loc_l_h, j, (RGOL, h),loc_h, sent)
                fr['STOP','den',(GOR,h),NON] += c(loc_l_h, j, (GOR, h),loc_h, sent)
                for i in xlteq(loc_l_h):
                    fr['STOP','num',(LGOR,h),NON] += c(i, j, (SEAL, h),loc_h, sent)
                    fr['STOP','den',(LGOR,h),NON] += c(i, j, (LGOR, h),loc_h, sent)
            # right adjacent stop, j = loc_r_h
            fr['STOP','num',(GOR,h),ADJ] += c(loc_l_h, loc_r_h, (RGOL, h),loc_h, sent)
            fr['STOP','den',(GOR,h),ADJ] += c(loc_l_h, loc_r_h, (GOR, h),loc_h, sent)
            for i in xlteq(loc_l_h):
                fr['STOP','num',(LGOR,h),ADJ] += c(i, loc_r_h, (SEAL, h),loc_h, sent)
                fr['STOP','den',(LGOR,h),ADJ] += c(i, loc_r_h, (LGOR, h),loc_h, sent)

            # left attachment:
            if 'REEST_ATTACH' in DEBUG:
                print("Lattach %s: for i < %s"%(g.numtag(h),sent[0:loc_h+1]))
            for s_h in [RGOL, GOL]:
                x = (s_h, h)
                for i in xlt(loc_l_h): # i < loc_l(h)
                    if 'REEST_ATTACH' in DEBUG:
                        print("\tfor j >= %s"%sent[loc_h:len(sent)])
                    for j in xgteq(loc_r_h, sent): # j >= loc_r(h)
                        fr['hat_a','den',x] += c(i,j, x,loc_h, sent) # v_q in L&Y
                        if 'REEST_ATTACH' in DEBUG:
                            print("\t\tc( %d , %d, %s, %s, sent)=%.4f"%(i,j,node_str(x),loc_h,fr['hat_a','den',x]))
                        w_left(i, j, x,loc_h, sent,sent_nums) # compute w for all a in sent

            # right attachment:
            if 'REEST_ATTACH' in DEBUG:
                print("Rattach %s: for i <= %s"%(g.numtag(h),sent[0:loc_h+1]))
            for s_h in [GOR, LGOR]:
                x = (s_h, h)
                for i in xlteq(loc_l_h): # i <= loc_l(h)
                    if 'REEST_ATTACH' in DEBUG:
                        print("\tfor j > %s"%sent[loc_h:len(sent)])
                    for j in xgt(loc_r_h, sent): # j > loc_r(h)
                        fr['hat_a','den',x] += c(i,j, x,loc_h, sent) # v_q in L&Y
                        if 'REEST_ATTACH' in DEBUG:
                            print("\t\tc( %d , %d, %s, %s, sent)=%.4f"%(i,j,node_str(x),loc_h,fr['hat_a','den',x]))
                        w_right(i,j, x,loc_h, sent,sent_nums) # compute w for all a in sent
        # end for loc_h,h
    # end for sent

    return fr
def reestimate(old_g, corpus):
    '''One reestimation step, v.1 (sentences as the outer loop).

    Collects expected counts with reest_freq, turns them into
    probabilities with reest_head, and returns a new DMV_Grammar.
    p_ORDER and the number/tag maps are carried over from old_g.'''
    fr = reest_freq(old_g, corpus)
    p_ROOT, p_STOP, p_ATTACH = {}, {}, {}
    for h in old_g.headnums():
        # fills p_ROOT, p_STOP and p_ATTACH in place
        reest_head(h, fr, old_g, p_ROOT, p_STOP, p_ATTACH)
    numtag, tagnum = old_g.get_nums_tags()
    return DMV_Grammar(numtag, tagnum, p_ROOT, p_STOP, p_ATTACH, old_g.p_ORDER)
def reest_head(h, fr, g, p_ROOT, p_STOP, p_ATTACH):
    '''Turn the counts fr (from reest_freq) for a single head h into
    probabilities.

    Writes p_ROOT[h], p_STOP[h,dir,adj] and p_ATTACH[a,h,dir] (for
    every a in g.headnums()) into the given dicts. Missing or zero
    denominators yield 0.0 contributions instead of raising.'''
    # remove 0-prob stuff? todo
    try:
        p_ROOT[h] = fr['ROOT','num',h] / fr['ROOT','den']
    except (KeyError, ZeroDivisionError):
        # no counts for h, or an empty corpus
        p_ROOT[h] = 0.0

    for direction in [LEFT, RIGHT]:
        for adjacency in [ADJ, NON]:  # p_STOP
            p_STOP[h, direction, adjacency] = 0.0
            for s_h in dirseal(direction):
                x = (s_h, h)
                den = fr['STOP', 'den', x, adjacency]
                if den > 0.0:
                    p_STOP[h, direction, adjacency] += \
                        fr['STOP', 'num', x, adjacency] / den

        for s_h in dirseal(direction):  # make hat_a for p_ATTACH
            x = (s_h, h)
            p_c = fr['hat_a', 'den', x]
            for a in g.headnums():
                key = (a, h, direction)
                if key not in p_ATTACH:
                    p_ATTACH[key] = 0.0
                try:
                    p_ATTACH[key] += fr['hat_a', 'num', a, x] / p_c
                except (KeyError, ZeroDivisionError):
                    # (a,x) might not be in hat_a, or x was never seen
                    pass
587 ###################################################
588 # Reestimation v.2: #
589 # Heads as outer loop #
590 ###################################################
def locs_h(h, sent_nums):
    '''Return, in order, the locations (indices into sent_nums) of every
    token of h in the sentence.'''
    return [loc for loc, w in enumerate(sent_nums) if w == h]
def locs_a(a, sent_nums, start, stop):
    '''Return the locations of all tokens of a in the fragment
    sent_nums[start:stop].

    Locations are offset so that sent_nums[loc] == a for every returned
    loc. start is inclusive, stop is exclusive, as in klein-thesis and
    Python's list-slicing (eg. return left-loc).'''
    return [loc for loc, w in enumerate(sent_nums[start:stop], start) if w == a]
def inner2(i, j, node, loc_h, g, sent):
    '''Inside probability of node headed at loc_h over sent[i:j].

    Uses, and updates, the per-sentence charts cached on g.'''
    s_n = g.sent_nums(sent)
    ichart, ochart = g.get_iochart(s_n)
    try:
        p = ichart[i, j, node, loc_h]
    except KeyError:
        p = inner(i, j, node, loc_h, g, sent, ichart)
    g.set_iochart(s_n, ichart, ochart)
    return p
def inner_sent2(g, sent):
    '''Sentence probability via inner_sent.

    Uses, and updates, the per-sentence charts cached on g.'''
    s_n = g.sent_nums(sent)
    ichart, ochart = g.get_iochart(s_n)
    p = inner_sent(g, sent, ichart)
    g.set_iochart(s_n, ichart, ochart)
    return p
def outer2(i, j,w_node,loc_w, g, sent):
    '''Outside probability of w_node headed at loc_w over sent[i:j].

    Uses, and updates, the per-sentence charts cached on g.'''
    s_n = g.sent_nums(sent)
    ichart, ochart = g.get_iochart(s_n)
    try:
        p = ochart[i, j, w_node, loc_w]
    except KeyError:
        # outer() (not inner()) computes outside probabilities
        p = outer(i, j, w_node, loc_w, g, sent, ichart, ochart)
    g.set_iochart(s_n, ichart, ochart)
    return p
def reestimate2(old_g, corpus):
    '''One reestimation step, v.2 (head types as the outer loop).

    Builds a new DMV_Grammar from reest_head2's per-head estimates.
    p_ORDER and the number/tag maps are carried over from old_g.'''
    p_ROOT, p_STOP, p_ATTACH = {}, {}, {}
    for h in old_g.headnums():
        # fills p_ROOT, p_STOP and p_ATTACH in place
        reest_head2(h, old_g, corpus, p_ROOT, p_STOP, p_ATTACH)
    numtag, tagnum = old_g.get_nums_tags()
    return DMV_Grammar(numtag, tagnum, p_ROOT, p_STOP, p_ATTACH, old_g.p_ORDER)
def hat_d2(xbar, x, xi, xj, g, corpus): # stop helper
    '''Stop-probability ratio for reestimation v.2.

    Sums the expected counts c2(xbar) and c2(x) over every token of
    h = POS(x) in corpus and every span [i,j) with i in xi(loc_l(h))
    and j in xj(loc_r(h), sent_nums). Returns their ratio (eg.
    SEAL/RGOL, xbar/x), or 0.0 when the x-count is zero. Raises
    ValueError if xbar and x have different POS.'''
    h = POS(x)
    if h != POS(xbar):
        raise ValueError

    num = den = 0.0
    for s_n, sent in [(g.sent_nums(s), s) for s in corpus]:
        for loc_h in locs_h(h, s_n):
            spans = [(i, j) for i in xi(loc_h) for j in xj(loc_h + 1, s_n)]
            for i, j in spans:
                num += c2(xbar, loc_h, i, j, g, s_n, sent)
                den += c2(x, loc_h, i, j, g, s_n, sent)
    if den == 0.0:
        return den
    return num / den
def c2(x,loc_h,i,j,g,s_n,sent):
    '''Expected count of x headed at loc_h over sent[i:j].

    The count is P_INSIDE * P_OUTSIDE / P(sent), using the charts
    cached on g for sentence s_n. Returns P(sent), or the inside
    probability, when that value is not positive. The charts are
    written back to g on every path.'''
    ichart, ochart = g.get_iochart(s_n)

    def f(i,j,x,loc_h): # P_{OUTSIDE}
        try: return ochart[i,j,x,loc_h]
        except KeyError: return outer(i,j,x,loc_h,g,sent,ichart,ochart)
    def e(i,j,x,loc_h): # P_{INSIDE}
        try: return ichart[i,j,x,loc_h]
        except KeyError: return inner(i,j,x,loc_h,g,sent,ichart)

    try:
        p_sent = inner_sent(g, sent, ichart)
        if not p_sent > 0.0:
            return p_sent
        p_in = e(i,j, x,loc_h)
        if not p_in > 0.0:
            return p_in
        p_out = f(i,j, x,loc_h)
        return p_in * p_out / p_sent
    finally:
        # cache the charts even on the early returns above
        g.set_iochart(s_n,ichart,ochart)
def w2(a, x,loc_h, dir, i, j, g, s_n,sent):
    '''Expected count, over P(sent), of head x at loc_h spanning
    sent[i:j] taking an argument of type a in direction dir.

    The count is summed over split points k and argument locations.
    Returns 0.0 for a sentence of zero probability. The charts cached
    on g for s_n are used and written back.'''
    ichart, ochart = g.get_iochart(s_n)

    def f(i,j,x,loc_h): # P_{OUTSIDE}
        try: return ochart[i,j,x,loc_h]
        except KeyError: return outer(i,j,x,loc_h,g,sent,ichart,ochart)
    def e(i,j,x,loc_h): # P_{INSIDE}
        try: return ichart[i,j,x,loc_h]
        except KeyError: return inner(i,j,x,loc_h,g,sent,ichart)

    try:
        h = POS(x)
        p_sent = inner_sent(g, sent, ichart)
        if not p_sent > 0.0:
            return 0.0
        # the outside probability of x over [i,j) is the same for all k
        out = f(i,j,x,loc_h)
        if not out > 0.0:
            return 0.0

        if dir == LEFT:
            L, R = (SEAL,a),x
        else:
            L, R = x,(SEAL,a)
        w_sum = 0.0
        for k in xtween(i,j):
            if dir == LEFT:
                start, stop = i, k
            else:
                start, stop = k, j
            for loc_a in locs_a(a, s_n, start, stop):
                if dir == LEFT:
                    loc_L, loc_R = loc_a, loc_h
                else:
                    loc_L, loc_R = loc_h, loc_a
                p = g.p_GO_AT_or0(a,h,dir,adj(k,loc_h))
                in_L = e(i,k,L,loc_L)
                in_R = e(k,j,R,loc_R)
                w_sum += p * in_L * in_R * out
        return w_sum/p_sent
    finally:
        g.set_iochart(s_n,ichart,ochart)
def hat_a2(a, x, dir, g, corpus): # attachment helper
    '''Attachment ratio for reestimation v.2.

    Sums the expected attachment counts w2 of argument a to head x, and
    the expected counts c2 of x, over every token of POS(x) in corpus
    and every span allowed for direction dir. Returns their ratio, or
    0.0 when the x-count is zero.'''
    h = POS(x)
    if dir == LEFT:
        xi, xj = xlt, xgteq
    else:
        xi, xj = xlteq, xgt

    num = den = 0.0
    for s_n, sent in [(g.sent_nums(s), s) for s in corpus]:
        for loc_h in locs_h(h, s_n):
            spans = [(i, j) for i in xi(loc_h) for j in xj(loc_h + 1, sent)]
            for i, j in spans:
                num += w2(a, x, loc_h, dir, i, j, g, s_n, sent)
                den += c2(x, loc_h, i, j, g, s_n, sent)
    if den == 0.0:
        return den
    return num / den
def reest_root2(h,g,corpus):
    '''Reestimated p_ROOT[h]: the average over corpus sentences of the
    posterior probability that a token of h heads the whole sentence.

    Zero-probability sentences contribute 0 (but still count towards
    the corpus size); an empty corpus yields 0.0.'''
    total = 0.0
    corpus_size = 0.0
    for s_n, sent in [(g.sent_nums(s), s) for s in corpus]:
        corpus_size += 1.0
        ichart, ochart = g.get_iochart(s_n)
        p_sent = inner_sent(g, sent, ichart)
        num = 0.0
        for loc_h in locs_h(h,s_n):
            num += g.p_ROOT[h] * \
                inner(0, len(s_n), (SEAL,h), loc_h, g, sent, ichart)
        g.set_iochart(s_n, ichart, ochart)
        if p_sent > 0.0:
            total += num / p_sent
    if corpus_size == 0.0:
        return 0.0
    return total / corpus_size
758 def reest_head2(h, g, corpus, p_ROOT, p_STOP, p_ATTACH):
759 print "h: %d=%s ..."%(h,g.numtag(h)),
760 def hat_d(xbar,x,xi,xj): return hat_d2(xbar,x,xi,xj, g, corpus)
761 def hat_a(a, x, dir ): return hat_a2(a, x, dir, g, corpus)
763 p_STOP[h, LEFT,NON] = \
764 hat_d((SEAL,h),(RGOL,h),xlt, xgteq) + \
765 hat_d((LGOR,h),( GOL,h),xlt, xeq)
766 p_STOP[h, LEFT,ADJ] = \
767 hat_d((SEAL,h),(RGOL,h),xeq, xgteq) + \
768 hat_d((LGOR,h),( GOL,h),xeq, xeq)
769 p_STOP[h,RIGHT,NON] = \
770 hat_d((RGOL,h),( GOR,h),xeq, xgt) + \
771 hat_d((SEAL,h),(LGOR,h),xlteq,xgt)
772 p_STOP[h,RIGHT,ADJ] = \
773 hat_d((RGOL,h),( GOR,h),xeq, xeq) + \
774 hat_d((SEAL,h),(LGOR,h),xlteq,xeq)
775 print "stops done...",
777 p_ROOT[h] = reest_root2(h,g,corpus)
778 print "root done...",
780 for a in g.headnums():
781 p_ATTACH[a,h,LEFT] = \
782 hat_a(a, (GOL,h),LEFT) + \
783 hat_a(a,(RGOL,h),LEFT)
784 p_ATTACH[a,h,RIGHT] = \
785 hat_a(a, (GOR,h),RIGHT) + \
786 hat_a(a,(LGOR,h),RIGHT)
788 print "attachment done"
792 ###################################################
793 # Most Probable Parse: #
794 ###################################################
# Sentinel mpptree keys of the form (i, j, node, loc): STOPKEY stands in
# for the missing child of a stop rule; ROOTKEY is the key of the whole parse.
STOPKEY, ROOTKEY = (-1,-1,STOP,-1), (-1,-1,ROOT,-1)
def make_mpptree(g, sent):
    '''Tell inner() to make an mpptree, and connect ROOT to its best
    sealed head. (Logically, this should be part of inner_sent though...)

    Returns the mpptree dict; mpptree[ROOTKEY] is (p, ROOTKEY, R) for
    the best root key R, or (0.0, ROOTKEY, None) if no parse has p > 0.'''
    ichart = {}
    mpptree = {ROOTKEY: (0.0, ROOTKEY, None)}
    n = len(sent)
    for loc_w, w in locs(g.sent_nums(sent), 0, n):
        p = g.p_ROOT[w] * inner(0, n, (SEAL, w), loc_w, g, sent, ichart, mpptree)
        if p > mpptree[ROOTKEY][0]:
            mpptree[ROOTKEY] = (p, ROOTKEY, (0, n, (SEAL, w), loc_w))
    return mpptree
def parse_mpptree(mpptree, sent):
    '''mpptree is a dict of the form {k:(p,L,R),...}; where k, L and R
    are `keys' of the form (i,j,node,loc).

    returns an mpp as a set of ((head, loc_h), (arg, loc_a), dir), where
    head and arg are tags and dir is the side of the head the argument
    is on (LEFT or RIGHT, as in p_ATTACH). Returns an empty set when
    mpptree holds no parse. sent is unused.'''
    # local functions for clear access to mpptree:
    def k_node(key):
        return key[2]
    def k_POS(key):
        return POS(k_node(key))
    def k_seals(key):
        return seals(k_node(key))
    def k_locnode(key):
        return (k_node(key),key[3])
    def k_locPOS(key):
        return (k_POS(key),key[3])
    def k_terminal(key):
        s_k = k_seals(key) # i+1 == j
        return key[0] + 1 == key[1] and (s_k == GOR or s_k == GOL)
    def t_L(tree_entry):
        return tree_entry[1]
    def t_R(tree_entry):
        return tree_entry[2]

    firstkey = t_R(mpptree[ROOTKEY])
    if firstkey is None:
        # no parse with probability > 0 was found
        return set()
    # arbitrarily, "ROOT attaches to right". We add it here to
    # avoid further complications:
    deps = set([ (k_locPOS(ROOTKEY), k_locPOS(firstkey), RIGHT) ])
    q = [firstkey]

    while q:
        k = q.pop()
        if k_terminal(k):
            continue
        L = t_L( mpptree[k] )
        R = t_R( mpptree[k] )
        if k_locnode( k ) == k_locnode( L ): # Rattach: arg R right of head
            deps.add((k_locPOS( k ), k_locPOS( R ), RIGHT))
            q.extend( [L, R] )
        elif k_locnode( k ) == k_locnode( R ): # Lattach: arg L left of head
            deps.add((k_locPOS( k ), k_locPOS( L ), LEFT))
            q.extend( [L, R] )
        elif R == STOPKEY:
            q.append( L )
        elif L == STOPKEY:
            q.append( R )
    return deps
def mpp(g, sent):
    '''Most probable parse of sent.

    Returns a set of ((head_tag, loc_h), (arg_tag, loc_a)) pairs,
    with attachment directions dropped.'''
    to_tag = g.numtag
    deps = parse_mpptree(make_mpptree(g, sent), sent)
    return set(((to_tag(h), loc_h), (to_tag(a), loc_a))
               for (h, loc_h), (a, loc_a), _direction in deps)
869 ########################################################################
870 # testing functions: #
871 ########################################################################
# Small POS-tag corpus used by the test functions below.
testcorpus = [line.split() for line in (
    'det nn vbd c vbd', 'vbd nn c vbd',
    'det nn vbd',       'det nn vbd c pp',
    'det nn vbd',       'det vbd vbd c pp',
    'det nn vbd',       'det nn vbd c vbd',
    'det nn vbd',       'det nn vbd c vbd',
    'det nn vbd',       'det nn vbd c vbd',
    'det nn vbd',       'det nn vbd c pp',
    'det nn vbd pp',    'det nn vbd',
)]
def testgrammar():
    '''Initial harmonic grammar for testcorpus.

    loc_h_harmonic is reloaded and its settings are pinned to the values
    used when the regression tests were set up.'''
    import loc_h_harmonic
    reload(loc_h_harmonic)

    pinned = (('HARMONIC_C', 0.0),
              ('FNONSTOP_MIN', 25),
              ('FSTOP_MIN', 5),
              ('RIGHT_FIRST', 1.0),
              ('OLD_STOP_CALC', True))
    for name, value in pinned:
        setattr(loc_h_harmonic, name, value)

    return loc_h_harmonic.initialize(testcorpus)
def testreestimation2():
    '''Run one v.2 reestimation step on testcorpus and return the
    reestimated grammar. Previously the initial grammar was returned
    and reestimate2's result was thrown away.'''
    g2 = testgrammar()
    return reestimate2(g2, testcorpus)
def testreestimation():
    '''Run one v.1 reestimation step on testcorpus and return the
    reestimated grammar.'''
    return reestimate(testgrammar(), testcorpus)
def testmpp_regression(mpptree,k_n):
    '''Compare the probabilities in mpptree with known-good values.

    Only the first k_n elements of each key are used (3 if the new code
    does not check loc_h). Mismatches and missing keys are printed, not
    raised.'''
    expected = {ROOTKEY: (2.877072116829971e-05, STOPKEY, (0, 3, (2, 3), 1)),
                (0, 1, (1, 1), 0): (0.1111111111111111, (0, 1, (0, 1), 0), STOPKEY),
                (0, 1, (2, 1), 0): (0.049382716049382713, STOPKEY, (0, 1, (1, 1), 0)),
                (0, 3, (1, 3), 1): (0.00027619892321567721,
                                    (0, 1, (2, 1), 0),
                                    (1, 3, (1, 3), 1)),
                (0, 3, (2, 3), 1): (0.00012275507698474543, STOPKEY, (0, 3, (1, 3), 1)),
                (1, 3, (0, 3), 1): (0.025280986819448362,
                                    (1, 2, (0, 3), 1),
                                    (2, 3, (2, 4), 2)),
                (1, 3, (1, 3), 1): (0.0067415964851862296, (1, 3, (0, 3), 1), STOPKEY),
                (2, 3, (1, 4), 2): (0.32692307692307693, (2, 3, (0, 4), 2), STOPKEY),
                (2, 3, (2, 4), 2): (0.037721893491124266, STOPKEY, (2, 3, (1, 4), 2))}
    for key, (prob, _L, _R) in expected.items():
        short = key if type(key) == str else key[0:k_n]
        if short in mpptree:
            got = mpptree[short][0]
            if "%.10f" % got != "%.10f" % prob:
                print("mpp regression, wanted %s=%.5f, got %.5f" % (short, prob, got))
        else:
            print("mpp regression, %s missing" % (short,))
def testgrammar_a():
    '''Two-tag toy grammar (h=0, a=1) with hand-set p_GO_AT values.'''
    h, a = 0, 1
    p_ROOT = {h: 0.9, a: 0.1}
    p_STOP = {}
    for tag in (h, a):
        p_STOP[tag, LEFT, NON] = 1.0
        p_STOP[tag, LEFT, ADJ] = 1.0
        p_STOP[tag, RIGHT, NON] = 0.4  # RSTOP
        p_STOP[tag, RIGHT, ADJ] = 0.3  # RSTOP
    # p_ATTACH is not used for parsing: every p_GO_AT entry derived from
    # it is overwritten below
    p_ATTACH = {}
    for arg, head in ((a, h), (h, a), (h, h)):
        for direction in (LEFT, RIGHT):
            p_ATTACH[arg, head, direction] = 1.0
    p_ORDER = {(GOR, h): 1.0, (GOL, h): 0.0,
               (GOR, a): 1.0, (GOL, a): 0.0}
    g = DMV_Grammar({h:'h',a:'a'}, {'h':h,'a':a}, p_ROOT, p_STOP, p_ATTACH, p_ORDER)

    # these probabilities are impossible so add them manually:
    # left attachment depends only on the argument, right is always 1.0
    left_go_at = {a: (0.4, 0.6), h: (0.2, 0.1)}  # arg -> (NON, ADJ)
    for head in (a, h):
        for arg in (a, h):
            p_non, p_adj = left_go_at[arg]
            g.p_GO_AT[arg, head, LEFT, NON] = p_non
            g.p_GO_AT[arg, head, LEFT, ADJ] = p_adj
            g.p_GO_AT[arg, head, RIGHT, NON] = 1.0
            g.p_GO_AT[arg, head, RIGHT, ADJ] = 1.0
    return g
def testgrammar_h():
    '''One-tag toy grammar (h=0) with hand-set p_GO_AT values.'''
    h = 0
    p_ROOT = {h: 1.0}
    p_STOP = {(h, LEFT, NON): 1.0,
              (h, LEFT, ADJ): 1.0,
              (h, RIGHT, NON): 0.4,
              (h, RIGHT, ADJ): 0.3}
    # not used for parsing: overwritten in p_GO_AT below
    p_ATTACH = {(h, h, LEFT): 1.0, (h, h, RIGHT): 1.0}
    p_ORDER = {(GOR, h): 1.0, (GOL, h): 0.0}
    g = DMV_Grammar({h:'h'}, {'h':h}, p_ROOT, p_STOP, p_ATTACH, p_ORDER)
    # these probabilities are impossible so add them manually...
    g.p_GO_AT.update({(h, h, LEFT, NON): 0.6,
                      (h, h, LEFT, ADJ): 0.7,
                      (h, h, RIGHT, NON): 1.0,
                      (h, h, RIGHT, ADJ): 1.0})
    return g
def testreestimation_h():
    '''Reestimate testgrammar_h on the sentence "h h h" with REEST debug
    output enabled, and return the reestimated grammar.

    The 'REEST' debug flag is switched off again afterwards unless it
    was already set.'''
    was_set = 'REEST' in DEBUG
    DEBUG.add('REEST')
    try:
        g = testgrammar_h()
        return reestimate(g,['h h h'.split()])
    finally:
        # don't leave debug output switched on for later callers
        if not was_set:
            DEBUG.discard('REEST')
1003 def test(wanted, got):
1004 if not wanted == got:
1005 raise Warning, "Regression! Should be %s: %s" % (wanted, got)
def regression_tests():
    """Run the module's numeric regression checks.

    Each check formats a computed probability to the recorded precision and
    compares the resulting string with the recorded value via test(), which
    raises Warning on the first mismatch.  Every check builds a fresh
    grammar, a fresh word list and fresh (empty) charts.
    """
    # Most-probable-parse regression check (see testmpp_regression).
    testmpp_regression(make_mpptree(testgrammar(), testcorpus[2]), 4)
    h = 0

    def check(wanted, fmt, value):
        # Compare as strings at the recorded precision.
        test(wanted, fmt % value)

    def sealed_h_inner(i, j, loc_h, words):
        # inner() of (SEAL, h) over span i..j headed at loc_h, one-tag grammar.
        return inner(i, j, (SEAL, h), loc_h, testgrammar_h(), words.split(), {})

    def outer_of(i, j, node, loc_h, make_grammar, words):
        # outer() with a freshly built grammar and empty inner/outer charts.
        return outer(i, j, node, loc_h, make_grammar(), words.split(), {}, {})

    # Inside probabilities, one-tag grammar.
    check("0.120", "%.3f", sealed_h_inner(0, 2, 0, 'h h'))
    check("0.063", "%.3f", sealed_h_inner(0, 2, 1, 'h h'))
    check("0.1842", "%.4f", inner_sent(testgrammar_h(), 'h h h'.split(), {}))
    # The three heads below sum to the inner_sent value above.
    check("0.1092", "%.4f", sealed_h_inner(0, 3, 0, 'h h h'))
    check("0.0252", "%.4f", sealed_h_inner(0, 3, 1, 'h h h'))
    check("0.0498", "%.4f", sealed_h_inner(0, 3, 2, 'h h h'))

    # Outside probabilities, one-tag grammar.
    check("0.58", "%.2f", outer_of(1, 3, (RGOL, h), 2, testgrammar_h, 'h h h'))
    # ftw? can't be right... there's an 0.4 shared between these two...
    check("0.61", "%.2f", outer_of(1, 3, (RGOL, h), 1, testgrammar_h, 'h h h'))
    # loc_h outside the span 1..3: expected to be zero.
    check("0.00", "%.2f", outer_of(1, 3, (RGOL, h), 0, testgrammar_h, 'h h h'))
    check("0.00", "%.2f", outer_of(1, 3, (RGOL, h), 3, testgrammar_h, 'h h h'))

    # Outside probabilities, two-tag grammar (the last span runs past the
    # end of the two-word sentence and is expected to be zero).
    check("0.1089", "%.4f", outer_of(0, 1, (GOR, h), 0, testgrammar_a, 'h a'))
    check("0.3600", "%.4f", outer_of(0, 2, (GOR, h), 0, testgrammar_a, 'h a'))
    check("0.0000", "%.4f", outer_of(0, 3, (GOR, h), 0, testgrammar_a, 'h a'))

    # todo: add more of these tests...
def compare_grammars(g1, g2, name1="reestimate1", name2="reestimate2"):
    """Describe every difference between the parameters of two grammars.

    Compares p_ATTACH, p_STOP, p_ORDER and p_ROOT of g1 and g2 key by key.
    Values are compared through their str() form, so two floats that print
    identically count as equal.

    name1/name2 label g1/g2 in the report.  The defaults keep the original
    wording, which was used to compare the two reestimation implementations.

    Returns a string made of one "\\n"-prefixed entry per difference (keys
    missing on either side, or differing values); "" means no differences.
    """
    lines = []
    for d1, d2 in [(g1.p_ATTACH, g2.p_ATTACH), (g1.p_STOP, g2.p_STOP),
                   (g1.p_ORDER, g2.p_ORDER), (g1.p_ROOT, g2.p_ROOT)]:
        for k, v in d1.items():
            if k not in d2:
                lines.append("\n%s[%s]=%s missing from %s"
                             % (name1, k, v, name2))
            # str() instead of "%s" % value: the latter raises TypeError
            # when a value is itself a tuple.
            elif str(d2[k]) != str(v):
                lines.append("\n%s[%s]=%s while \n%s[%s]=%s."
                             % (name1, k, v, name2, k, d2[k]))
        for k, v in d2.items():
            if k not in d1:
                lines.append("\n%s[%s]=%s missing from %s"
                             % (name2, k, v, name1))
    return "".join(lines)
def testNVNgrammar():
    """Return loc_h_harmonic.initialize() applied to the corpus ['n v n'].

    The loc_h_harmonic settings are first pinned to the values in effect
    when the tests were set up.  NOTE: they are assigned on the
    loc_h_harmonic module itself and are not restored afterwards.
    """
    import loc_h_harmonic

    # make sure these are the way they were when setting up the tests:
    pinned_settings = [('HARMONIC_C', 0.0),
                       ('FNONSTOP_MIN', 25),
                       ('FSTOP_MIN', 5),
                       ('RIGHT_FIRST', 1.0),
                       ('OLD_STOP_CALC', True)]
    for setting, value in pinned_settings:
        setattr(loc_h_harmonic, setting, value)

    return loc_h_harmonic.initialize(['n v n'.split()])  # todo
def testIO():
    """Return a list of (sentence, inner_sent value) pairs for testcorpus.

    All sentences use the same testgrammar() instance, and each sentence
    gets its own empty chart.
    """
    grammar = testgrammar()
    pairs = []
    for sentence in testcorpus:
        pairs.append((sentence, inner_sent(grammar, sentence, {})))
    return pairs
if __name__ == "__main__":
    # Script entry point: clear the module-level DEBUG flags, then run the
    # numeric regression checks (test() raises Warning on the first mismatch).
    DEBUG.clear()
    regression_tests()

    # Optional: profile a reestimation run.
    # import profile
    # profile.run('testreestimation()')

    # Optional: time a single reestimation run.
    # import timeit
    # print timeit.Timer("loc_h_dmv.testreestimation()",'''import loc_h_dmv
    # reload(loc_h_dmv)''').timeit(1)

    # Optional: print the most probable parse of every test sentence.
    # print "mpp-test:"
    # import pprint
    # for s in testcorpus:
    #     print "sent:%s\nparse:set(\n%s)"%(s,pprint.pformat(list(mpp(testgrammar(), s)),
    #     width=40))

    # Optional: compare the two reestimation implementations.
    # g1 = testreestimation()
    # g2 = testreestimation2()
    # print compare_grammars(g1,g2)

    # Disabled scratch work (`if False:` -- this never executes).  It derives
    # expected reestimated parameters for the sentence 'n v n' by hand:
    #  - q_tree[1..7] are hand-entered probabilities of the seven trees
    #    described in the comments next to them;
    #  - f_T_q[i] is tree i's share q_tree[i] / q_sent of the sentence's
    #    inner_sent probability;
    #  - treediv(num, den) divides the summed shares of the trees listed in
    #    num by the summed shares of those listed in den (indices may repeat,
    #    presumably once per occurrence of the event in that tree).
    # The hand-derived values are printed next to two reestimate2() results.
    if False:
        g = testNVNgrammar()
        q_sent = inner_sent(g,'n v n'.split(),{})
        q_tree = {}
        q_tree[1] = 2.7213e-06 # n_0 -> v, n_0 -> n_2
        q_tree[2] = 9.738e-06 # n -> v -> n
        q_tree[3] = 2.268e-06 # n_0 -> n_2 -> v
        q_tree[4] = 2.7213e-06 # same as 1-3
        q_tree[5] = 9.738e-06
        q_tree[6] = 2.268e-06
        q_tree[7] = 1.086e-05 # n <- v -> n (e-05!!!)
        f_T_q = {}
        for i,q_t in q_tree.iteritems():
            f_T_q[i] = q_t / q_sent
        import pprint
        pprint.pprint(q_tree)
        pprint.pprint(f_T_q)
        print sum([f for f in f_T_q.values()])

        def treediv(num,den):
            return \
                sum([f_T_q[i] for i in num ]) / \
                sum([f_T_q[i] for i in den ])
        # NOTE: unlike the commented-out g2 above, g2 here is a plain dict of
        # hand-derived parameter values keyed by descriptive labels.
        g2 = {}
        # g2['root --> _n_'] = treediv( (1,2,3,4,5,6), (1,2,3,4,5,6,7) )
        # g2['root --> _v_'] = treediv( (7,), (1,2,3,4,5,6,7) )
        # g2['_n_ --> STOP n><'] = treediv( (1,2,3,4,5,6,7,1,2,3,4,5,6,7),
        # (1,2,3,4,5,6,7,1,2,3,4,5,6,7))

        # g2['_n_ --> STOP n>< NON'] = treediv( (3,4,5,6),
        # (3,4,5,6,4) )

        # g2['_v_ --> STOP v><'] = treediv( (1,2,3,4,5,6,7),
        # (1,2,3,4,5,6,7) )
        # nlrtrees = (1,2,3,4,5,6,7,1,2,3,4,5,6,7,
        # 3,4,4,5,6)
        # g2['n>< --> _n_ n><'] = treediv( ( 4, 6), nlrtrees )
        # g2['n>< --> _v_ n><'] = treediv( (3,4,5), nlrtrees )
        # g2['n>< --> n> STOP'] = treediv( (1,2,3,4,5,6,7,1,2,3,4,5,6,7),
        # nlrtrees )

        # g2['n>< --> n> STOP ADJ'] = treediv( ( 4,5, 7,1,2,3,4,5,6,7),
        # nlrtrees )
        # g2['n>< --> n> STOP NON'] = treediv( (1,2,3, 6),
        # nlrtrees )

        # vlrtrees = (1,2,3,4,5,6,7,
        # 7,5)
        # g2['v>< --> _n_ v><'] = treediv( (5,7), vlrtrees )
        # g2['v>< --> v> STOP'] = treediv( (1,2,3,4,5,6,7), vlrtrees )
        # nrtrees = (1,2,3,4,5,6,7,1,2,3,4,5,6,7,
        # 1,1,2,3,6)
        # g2['n> --> n> _n_'] = treediv( (1,3), nrtrees )
        # g2['n> --> n> _v_'] = treediv( (1,2,6), nrtrees )

        # g2['n> --> n> _n_ NON'] = treediv( (1,), nrtrees )
        # g2['n> --> n> _n_ ADJ'] = treediv( ( 3,), nrtrees )
        # g2['n> --> n> _v_ ADJ'] = treediv( ( 1,2, 6), nrtrees )

        # vrtrees = (1,2,3,4,5,6,7,
        # 7,2)
        # g2['v> --> v> _n_'] = treediv( (2,7), vrtrees )

        # g2[' v|n,R '] = treediv( (1, 2, 6),
        # (1,1,2,3,6) )
        # g2[' n|n,R '] = treediv( (1, 3),
        # (1,1,2,3,6) )

        g2[' stop|n,R,non '] = treediv( ( 1,2,3,6),
                                        (1,1,2,3,6) )
        g2[' v|n,left '] = treediv( ( 3,4,5),
                                    (6,4,3,4,5) )
        g2[' n|n,left '] = treediv( (6,4),
                                    (6,4,3,4,5) )

        pprint.pprint(g2)
        g3 = reestimate2(g, ['n v n'.split()])
        print g3
        # NOTE(review): g4 repeats the g3 call with identical arguments; if a
        # second reestimation iteration was intended this may need to be
        # reestimate2(g3, ...) -- confirm.
        g4 = reestimate2(g, ['n v n'.split()])
        print g4