first attempt at fixing the w-bug, todo: test
[dmvccm.git] / src / loc_h_dmv.py
blob52620fdd887b1c19e48e03e9ff775b727fad519d
1 # loc_h_dmv.py
2 #
3 # dmv reestimation and inside-outside probabilities using loc_h, and
4 # no CNF-style rules
6 #import numpy # numpy provides Fast Arrays, for future optimization
7 import io
8 from common_dmv import *
10 ### todo: debug with @accepts once in a while, but it's SLOW
11 # from typecheck import accepts, Any
if __name__ == "__main__":
    # Announce the self-tests as soon as the module is run as a script;
    # the tests themselves are started at the bottom of the file.
    print("loc_h_dmv module tests:")
def adj(middle, loc_h):
    """Tell whether the split point `middle` is adjacent to the head.

    middle is e.g. k when rewriting for i<k<j (inside probabilities).
    The result is True (i.e. ADJ) when middle equals loc_h or loc_h+1,
    and False otherwise.
    """
    if middle == loc_h:
        return True
    return middle == loc_h + 1
def make_GO_AT(p_STOP, p_ATTACH):
    """Combine attachment and non-stop probabilities into one table.

    For every (a, h, direction) key in p_ATTACH, the result holds
        p_GO_AT[a, h, direction, adjacency]
            = p_ATTACH[a, h, direction] * (1 - p_STOP[h, direction, adjacency])
    for adjacency in NON and ADJ.
    """
    p_GO_AT = {}
    for (a, h, direction), p_ah in p_ATTACH.items():
        for adjacency in (NON, ADJ):
            p_continue = 1 - p_STOP[h, direction, adjacency]
            p_GO_AT[a, h, direction, adjacency] = p_ah * p_continue
    return p_GO_AT
class DMV_Grammar(io.Grammar):
    '''A DMV grammar: the probability tables used by inner/outer and
    reestimation, on top of the tag<->number mapping of io.Grammar.'''

    def __init__(self, numtag, tagnum, p_ROOT, p_STOP, p_ATTACH, p_ORDER):
        io.Grammar.__init__(self, numtag, tagnum)
        self.p_ROOT = p_ROOT      # p_ROOT[w] = p
        self.p_ORDER = p_ORDER    # p_ORDER[seals, w] = p
        self.p_STOP = p_STOP      # p_STOP[w, LEFT, NON] = p (etc. for LA,RN,RA)
        self.p_ATTACH = p_ATTACH  # p_ATTACH[a, h, LEFT] = p (etc. for R)
        # p_GO_AT[a, h, LEFT, NON] = p (etc. for LA,RN,RA)
        self.p_GO_AT = make_GO_AT(self.p_STOP, self.p_ATTACH)

    def __str__(self):
        '''Return a multi-line listing of p_ROOT, p_STOP, p_ATTACH and
        p_ORDER for all heads; probabilities missing from a table are
        shown as 0.0 (attachments with zero probability are left out).'''
        def t(n):
            return "%d=%s" % (n, self.numtag(n))

        def p(table, key):
            # missing entries count as probability zero
            return table.get(key, 0.0)

        def p_a(a, h):
            p_L = p(self.p_ATTACH, (a, h, LEFT))
            p_R = p(self.p_ATTACH, (a, h, RIGHT))
            if p_L == 0.0 and p_R == 0.0:
                return ''
            if p_L > 0.0:
                line = "p_ATTACH[ %s|%s,L] = %.4f" % (t(a), t(h), p_L)
            else:
                line = ''
            if p_R > 0.0:
                # pad so that the right-attachment column lines up
                line = line.ljust(40)
                line += "p_ATTACH[ %s|%s,R] = %.4f" % (t(a), t(h), p_R)
            return line + '\n'

        # Collect the pieces per table and join once at the end; the
        # output keeps all ROOT lines first, then STOP, ATTACH and ORDER.
        root, stop, att, order = [], [], [], []
        for h in self.headnums():
            root.append("p_ROOT[%s] = %.4f\n" % (t(h), p(self.p_ROOT, h)))
            stop.append("p_STOP[stop|%s,L,adj] = %.4f\t" % (t(h), p(self.p_STOP, (h, LEFT, ADJ))))
            stop.append("p_STOP[stop|%s,R,adj] = %.4f\n" % (t(h), p(self.p_STOP, (h, RIGHT, ADJ))))
            stop.append("p_STOP[stop|%s,L,non] = %.4f\t" % (t(h), p(self.p_STOP, (h, LEFT, NON))))
            stop.append("p_STOP[stop|%s,R,non] = %.4f\n" % (t(h), p(self.p_STOP, (h, RIGHT, NON))))
            att.extend([p_a(a, h) for a in self.headnums()])
            order.append("p_ORDER[ left-first|%s ] = %.4f\t" % (t(h), p(self.p_ORDER, (GOL, h))))
            order.append("p_ORDER[right-first|%s ] = %.4f\n" % (t(h), p(self.p_ORDER, (GOR, h))))
        return ''.join(root) + ''.join(stop) + ''.join(att) + ''.join(order)

    def p_GO_AT_or0(self, a, h, dir, adj):
        '''Return p_GO_AT[a, h, dir, adj], or 0.0 when the grammar has no
        such entry (only the missing-key case is treated as zero).'''
        return self.p_GO_AT.get((a, h, dir, adj), 0.0)
def locs(sent_nums, start, stop):
    '''Yield (loc_w, w) for each word of the fragment sent_nums[start:stop].

    The locations are offset by start, so that for every pair yielded,
    sent_nums[loc_w] == w.

    start is inclusive, stop is exclusive, as in klein-thesis and
    Python's list-slicing.'''
    for loc_w, w in enumerate(sent_nums[start:stop], start):
        yield (loc_w, w)
90 ###################################################
91 # P_INSIDE (dmv-specific) #
92 ###################################################
94 #@accepts(int, int, (int, int), int, Any(), [str], {tuple:float}, IsOneOf(None,{}))
def inner(i, j, node, loc_h, g, sent, ichart=None, mpptree=None):
    ''' Return the inside probability of node (a (seals, POS) pair) headed
    at loc_h and spanning the between-word positions i to j of sent.

    The ichart is of this form:
    ichart[i,j,LHS, loc_h]
    where i and j are between-word positions. Results for non-terminal
    rewrites are stored in it; pass the same dict again to reuse them. A
    fresh chart is used when ichart is None (a shared default dict would
    leak values between unrelated sentences and grammars).

    loc_h gives adjacency (along with k for attachment rules), and is
    needed in P_STOP reestimation.

    If mpptree is a dict (possibly empty), it is filled with the most
    probable rewrite seen for each chart key, as key -> (p, L, R) where
    L and R are the child keys (or STOPKEY).

    Raises ValueError for a node whose seals value is not one of SEAL,
    RGOL, GOL, GOR, LGOR.
    '''
    if ichart is None:
        ichart = {}
    sent_nums = g.sent_nums(sent)
    debug = 'INNER' in DEBUG

    def terminal(i, j, node, loc_h, tabs):
        if not i <= loc_h < j:
            if debug:
                print("%s*= 0.0 (wrong loc_h)" % tabs)
            return 0.0
        elif POS(node) == sent_nums[i] and node in g.p_ORDER:
            # todo: add to ichart perhaps? Although, it _is_ simple lookup..
            prob = g.p_ORDER[node]
        else:
            if debug:
                print("%sLACKING TERMINAL:" % tabs)
            prob = 0.0
        if debug:
            print("%s*= %.4f (terminal: %s -> %s_%d)" % (tabs, prob, node_str(node), sent[i], loc_h))
        return prob

    def e(i, j, node, loc_h, n_t):
        # n_t is the recursion depth, only used to indent debug output.
        s_h, h = node
        node = (s_h, h)
        key = (i, j, node, loc_h)
        tabs = "\t" * n_t

        def to_mpp(p, L, R):
            # Remember the best-scoring rewrite of this chart key.
            if mpptree is None:
                return
            if key not in mpptree or mpptree[key][0] < p:
                mpptree[key] = (p, L, R)

        if key in ichart:
            if debug:
                print("%s*= %.4f in ichart: i:%d j:%d node:%s loc:%s" % (tabs, ichart[key], i, j,
                                                                       node_str(node), loc_h))
            return ichart[key]

        # Either terminal rewrites, using p_ORDER:
        if i + 1 == j and (s_h == GOR or s_h == GOL):
            return terminal(i, j, node, loc_h, tabs)

        # Or not at terminal level yet:
        if debug:
            print("%s%s (%.1f) from %d to %d" % (tabs, node_str(node), loc_h, i, j))

        if s_h == SEAL:
            p_RGOL = g.p_STOP[h, LEFT, adj(i, loc_h)] * e(i, j, (RGOL, h), loc_h, n_t + 1)
            p_LGOR = g.p_STOP[h, RIGHT, adj(j, loc_h)] * e(i, j, (LGOR, h), loc_h, n_t + 1)
            p = p_RGOL + p_LGOR
            to_mpp(p_RGOL, STOPKEY, (i, j, (RGOL, h), loc_h))
            # the child of the right-stop rewrite is the LGOR node whose
            # probability was just used (was mislabelled RGOL)
            to_mpp(p_LGOR, (i, j, (LGOR, h), loc_h), STOPKEY)
            if debug:
                print("%sp= %.4f (STOP)" % (tabs, p))
        elif s_h == RGOL or s_h == GOL:
            p = 0.0
            if s_h == RGOL:
                p = g.p_STOP[h, RIGHT, adj(j, loc_h)] * e(i, j, (GOR, h), loc_h, n_t + 1)
                to_mpp(p, (i, j, (GOR, h), loc_h), STOPKEY)
            for k in xgo_left(i, loc_h):  # i < k <= loc_l(h)
                p_R = e(k, j, node, loc_h, n_t + 1)
                if p_R > 0.0:
                    for loc_a, a in locs(sent_nums, i, k):
                        p_ah = g.p_GO_AT_or0(a, h, LEFT, adj(k, loc_h))
                        if p_ah > 0.0:
                            p_L = e(i, k, (SEAL, a), loc_a, n_t + 1)
                            p_add = p_L * p_ah * p_R
                            p += p_add
                            to_mpp(p_add,
                                   (i, k, (SEAL, a), loc_a),
                                   (k, j, node, loc_h))
            if debug:
                print("%sp= %.4f (ATTACH)" % (tabs, p))
        elif s_h == GOR or s_h == LGOR:
            p = 0.0
            if s_h == LGOR:
                p = g.p_STOP[h, LEFT, adj(i, loc_h)] * e(i, j, (GOL, h), loc_h, n_t + 1)
                to_mpp(p, (i, j, (GOL, h), loc_h), STOPKEY)
            for k in xgo_right(loc_h, j):  # loc_l(h) < k < j
                p_L = e(i, k, node, loc_h, n_t + 1)
                if p_L > 0.0:
                    for loc_a, a in locs(sent_nums, k, j):
                        p_ah = g.p_GO_AT_or0(a, h, RIGHT, adj(k, loc_h))
                        p_R = e(k, j, (SEAL, a), loc_a, n_t + 1)
                        p_add = p_L * p_ah * p_R
                        p += p_add
                        to_mpp(p_add,
                               (i, k, node, loc_h),
                               (k, j, (SEAL, a), loc_a))
            if debug:
                print("%sp= %.4f (ATTACH)" % (tabs, p))
        else:
            # previously fell through to an unbound `p`
            raise ValueError("inner: cannot rewrite node %r" % (node,))

        ichart[key] = p
        return p
    # end of e-function

    inner_prob = e(i, j, node, loc_h, 0)
    if debug:
        print(debug_ichart(g, sent, ichart))
    return inner_prob
# end of dmv.inner(i, j, node, loc_h, g, sent, ichart=None, mpptree=None)
def debug_ichart(g, sent, ichart):
    """Return a printable dump of ichart, one line per chart entry,
    wrapped in ---ICHART:--- / ---ICHART:end--- marker lines."""
    lines = ["---ICHART:---\n"]
    for (s, t, LHS, loc_h), v in ichart.items():
        # NOTE(review): the right boundary t is labelled with sent[s], the
        # word at the *left* boundary -- possibly sent[t-1] was intended;
        # kept as is, confirm before changing.
        lines.append("%s -> %s_%d ... %s_%d (loc_h:%s):\t%.4f\n" % (node_str(LHS, g.numtag),
                                                                   sent[s], s, sent[s], t, loc_h, v))
    lines.append("---ICHART:end---\n")
    return "".join(lines)
def inner_sent(g, sent, ichart=None):
    """Return the probability of the whole sentence under g: the sum over
    all words w (at loc_w) of p_ROOT[w] times the inside probability of
    (SEAL, w) spanning 0..len(sent).

    ichart is filled by inner() and may be passed in to keep the chart
    afterwards. When omitted, a fresh chart is created for this call;
    chart keys carry no sentence or grammar identity, so a shared default
    dict would return stale values for the next sentence.
    """
    if ichart is None:
        ichart = {}
    return sum([g.p_ROOT[w] * inner(0, len(sent), (SEAL, w), loc_w, g, sent, ichart)
                for loc_w, w in locs(g.sent_nums(sent), 0, len(sent))])
222 ###################################################
223 # P_OUTSIDE (dmv-specific) #
224 ###################################################
226 #@accepts(int, int, (int, int), int, Any(), [str], {tuple:float}, {tuple:float})
def outer(i, j, w_node, loc_w, g, sent, ichart=None, ochart=None):
    ''' http://www.student.uib.no/~kun041/dmvccm/DMVCCM.html#outer

    Return the outside probability of w_node at loc_w over the
    between-word positions i to j of sent.

    w_node is a pair (seals,POS); the w in klein-thesis is made up of
    POS(w) and loc_w

    ichart and ochart hold already computed inside and outside
    probabilities, keyed on (i, j, node, loc); both are filled as a side
    effect. Fresh charts are created when they are None (a shared default
    dict would leak values between sentences and grammars).

    Raises ValueError for a node whose seals value is not one of SEAL,
    RGOL, GOL, GOR, LGOR (ROOT is handled separately).
    '''
    if ichart is None:
        ichart = {}
    if ochart is None:
        ochart = {}

    sent_nums = g.sent_nums(sent)
    if POS(w_node) not in sent_nums[i:j]:
        # sanity check, w must be able to dominate sent[i:j]
        return 0.0

    # local functions:
    def e(i, j, LHS, loc_h):  # P_{INSIDE}
        try:
            return ichart[i, j, LHS, loc_h]
        except KeyError:
            return inner(i, j, LHS, loc_h, g, sent, ichart)

    def f(i, j, w_node, loc_w):
        if not (i <= loc_w < j):
            return 0.0
        if (i, j, w_node, loc_w) in ochart:
            return ochart[i, j, w_node, loc_w]
        if w_node == ROOT:
            if i == 0 and j == len(sent):
                return 1.0
            else:  # ROOT may only be used on full sentence
                return 0.0
        # but we may have non-ROOTs (stops) over full sentence too:
        w = POS(w_node)
        s_w = seals(w_node)

        # todo: try either if p_M > 0.0: or sum(), and speed-test them

        if s_w == SEAL:  # w == a
            # todo: do the i<sent<j check here to save on calls?
            p = g.p_ROOT[w] * f(i, j, ROOT, loc_w)
            # left attach
            for k in xgt(j, sent):  # j<k<len(sent)+1
                for loc_h, h in locs(sent_nums, j, k):
                    p_wh = g.p_GO_AT_or0(w, h, LEFT, adj(j, loc_h))
                    for s_h in [RGOL, GOL]:
                        p += f(i, k, (s_h, h), loc_h) * p_wh * e(j, k, (s_h, h), loc_h)
            # right attach
            for k in xlt(i):  # k<i
                for loc_h, h in locs(sent_nums, k, i):
                    p_wh = g.p_GO_AT_or0(w, h, RIGHT, adj(i, loc_h))
                    for s_h in [LGOR, GOR]:
                        p += e(k, i, (s_h, h), loc_h) * p_wh * f(k, j, (s_h, h), loc_h)

        elif s_w == RGOL or s_w == GOL:  # w == h, left stop + left attach
            if s_w == RGOL:
                s_h = SEAL
            else:  # s_w == GOL
                s_h = LGOR
            p = g.p_STOP[w, LEFT, adj(i, loc_w)] * f(i, j, (s_h, w), loc_w)
            for k in xlt(i):  # k<i
                for loc_a, a in locs(sent_nums, k, i):
                    p_aw = g.p_GO_AT_or0(a, w, LEFT, adj(i, loc_w))
                    p += e(k, i, (SEAL, a), loc_a) * p_aw * f(k, j, w_node, loc_w)

        elif s_w == GOR or s_w == LGOR:  # w == h, right stop + right attach
            if s_w == GOR:
                s_h = RGOL
            else:  # s_w == LGOR
                s_h = SEAL
            p = g.p_STOP[w, RIGHT, adj(j, loc_w)] * f(i, j, (s_h, w), loc_w)
            for k in xgt(j, sent):  # j<k<len(sent)+1
                for loc_a, a in locs(sent_nums, j, k):
                    p_ah = g.p_GO_AT_or0(a, w, RIGHT, adj(j, loc_w))
                    p += f(i, k, w_node, loc_w) * p_ah * e(j, k, (SEAL, a), loc_a)

        else:
            # previously fell through to an unbound `p`
            raise ValueError("outer: cannot handle node %r" % (w_node,))

        ochart[i, j, w_node, loc_w] = p
        return p
    # end outer.f()

    return f(i, j, w_node, loc_w)
# end outer(i,j,w_node,loc_w, g,sent, ichart,ochart)
309 ###################################################
310 # Reestimation: #
311 ###################################################
# todo: it seems we have to rewrite attachment reestimation so that we
# have 'a' as the outer loop, then sentences... but this means running
# through sentences several times, and that would require storing
# inner probabilities...agh!
def reest_zeros(h_nums):
    '''Return a dict with zeroed numerators and denominators for our 6+
    reestimation formulas, for every head number in h_nums.

    Keys created: ('ROOT','den'); ('ROOT','num',h); and for each seal in
    GOR, GOL, RGOL, LGOR with x = (seal, h): ('hat_a','den',x) and
    ('STOP', 'num'|'den', x, NON|ADJ).'''
    # todo: p_ORDER?
    counts = {('ROOT', 'den'): 0.0}  # holds sum over p_sent
    for head in h_nums:
        counts['ROOT', 'num', head] = 0.0
        for seal in (GOR, GOL, RGOL, LGOR):
            node = (seal, head)
            counts['hat_a', 'den', node] = 0.0  # = c()
            # not all arguments are attached to, so we just initialize
            # fr['hat_a','num',a,(s_h,h)] as they show up, in reest_freq
            counts.update((('STOP', part, node, adjacency), 0.0)
                          for adjacency in (NON, ADJ)
                          for part in ('num', 'den'))
    return counts
def reest_freq(g, corpus):
    '''Run inside-outside over every sentence of corpus and return the
    dict of reestimation counts (see reest_zeros for the key layout;
    ('hat_a','num',a,x) keys are added as attachments show up).

    The stop counts cover these spans, for a head at loc_h with
    loc_l_h = loc_h and loc_r_h = loc_h+1:
      GOL  -> LGOR over (i, loc_r_h)   non-adjacent iff i < loc_l_h
      RGOL -> SEAL over (i, j)         non-adjacent iff i < loc_l_h
      GOR  -> RGOL over (loc_l_h, j)   non-adjacent iff j > loc_r_h
      LGOR -> SEAL over (i, j)         non-adjacent iff j > loc_r_h
    (the GOL and LGOR counts used to ignore the loop variable i and read
    a stale j left over from a neighbouring loop).
    '''
    fr = reest_zeros(g.headnums())
    # The charts and p_sent are rebound for each sentence below; the local
    # functions read the current binding through their closure.
    ichart = {}
    ochart = {}
    p_sent = None  # 50 % speed increase on storing this locally

    # local functions altogether 2x faster than global
    def c(i, j, LHS, loc_h, sent):
        if not p_sent > 0.0:
            return p_sent

        p_in = e(i, j, LHS, loc_h, sent)
        if not p_in > 0.0:
            return p_in

        p_out = f(i, j, LHS, loc_h, sent)
        return p_in * p_out / p_sent
    # end reest_freq.c()

    def f(i, j, LHS, loc_h, sent):  # P_{OUTSIDE}
        try:
            return ochart[i, j, LHS, loc_h]
        except KeyError:
            return outer(i, j, LHS, loc_h, g, sent, ichart, ochart)
    # end reest_freq.f()

    def e(i, j, LHS, loc_h, sent):  # P_{INSIDE}
        try:
            return ichart[i, j, LHS, loc_h]
        except KeyError:
            return inner(i, j, LHS, loc_h, g, sent, ichart)
    # end reest_freq.e()

    def add_w(a_k, x):
        # add the per-argument sums of one (i,j,x) span to the numerators
        for a, p in a_k.items():
            key = ('hat_a', 'num', a, x)
            fr[key] = fr.get(key, 0.0) + p / p_sent

    def w_left(i, j, x, loc_h, sent, sent_nums):
        if not p_sent > 0.0:
            return
        # the outside probability does not depend on k, compute it once
        p_out = f(i, j, x, loc_h, sent)
        if not p_out > 0.0:
            return

        h = POS(x)
        a_k = {}
        for k in xtween(i, j):
            p_R = e(k, j, x, loc_h, sent)
            if not p_R > 0.0:
                continue

            for loc_a, a in locs(sent_nums, i, k):  # i<=loc_l(a)<k
                p_rule = g.p_GO_AT_or0(a, h, LEFT, adj(k, loc_h))
                p_L = e(i, k, (SEAL, a), loc_a, sent)
                a_k[a] = a_k.get(a, 0.0) + p_L * p_out * p_R * p_rule

        add_w(a_k, x)
    # end reest_freq.w_left()

    def w_right(i, j, x, loc_h, sent, sent_nums):
        if not p_sent > 0.0:
            return
        # the outside probability does not depend on k, compute it once
        p_out = f(i, j, x, loc_h, sent)
        if not p_out > 0.0:
            return

        h = POS(x)
        a_k = {}
        for k in xtween(i, j):
            p_L = e(i, k, x, loc_h, sent)
            if not p_L > 0.0:
                continue

            for loc_a, a in locs(sent_nums, k, j):  # k<=loc_l(a)<j
                p_rule = g.p_GO_AT_or0(a, h, RIGHT, adj(k, loc_h))
                p_R = e(k, j, (SEAL, a), loc_a, sent)
                a_k[a] = a_k.get(a, 0.0) + p_L * p_out * p_R * p_rule

        add_w(a_k, x)
    # end reest_freq.w_right()

    # in reest_freq:
    for sent in corpus:
        if 'REEST' in DEBUG:
            print(sent)
        ichart = {}
        ochart = {}
        p_sent = inner_sent(g, sent, ichart)
        fr['ROOT', 'den'] += p_sent

        sent_nums = g.sent_nums(sent)

        for loc_h, h in locs(sent_nums, 0, len(sent) + 1):  # locs-stop is exclusive, thus +1
            # root:
            fr['ROOT', 'num', h] += g.p_ROOT[h] * e(0, len(sent), (SEAL, h), loc_h, sent)

            loc_l_h = loc_h
            loc_r_h = loc_l_h + 1

            # left non-adjacent stop:
            for i in xlt(loc_l_h):
                # GOL has not gone right yet, so its span ends at loc_r_h
                fr['STOP', 'num', (GOL, h), NON] += c(i, loc_r_h, (LGOR, h), loc_h, sent)
                fr['STOP', 'den', (GOL, h), NON] += c(i, loc_r_h, (GOL, h), loc_h, sent)
                for j in xgteq(loc_r_h, sent):
                    fr['STOP', 'num', (RGOL, h), NON] += c(i, j, (SEAL, h), loc_h, sent)
                    fr['STOP', 'den', (RGOL, h), NON] += c(i, j, (RGOL, h), loc_h, sent)
            # left adjacent stop, i = loc_l_h
            fr['STOP', 'num', (GOL, h), ADJ] += c(loc_l_h, loc_r_h, (LGOR, h), loc_h, sent)
            fr['STOP', 'den', (GOL, h), ADJ] += c(loc_l_h, loc_r_h, (GOL, h), loc_h, sent)
            for j in xgteq(loc_r_h, sent):
                fr['STOP', 'num', (RGOL, h), ADJ] += c(loc_l_h, j, (SEAL, h), loc_h, sent)
                fr['STOP', 'den', (RGOL, h), ADJ] += c(loc_l_h, j, (RGOL, h), loc_h, sent)
            # right non-adjacent stop:
            for j in xgt(loc_r_h, sent):
                fr['STOP', 'num', (GOR, h), NON] += c(loc_l_h, j, (RGOL, h), loc_h, sent)
                fr['STOP', 'den', (GOR, h), NON] += c(loc_l_h, j, (GOR, h), loc_h, sent)
                for i in xlteq(loc_l_h):
                    fr['STOP', 'num', (LGOR, h), NON] += c(i, j, (SEAL, h), loc_h, sent)
                    fr['STOP', 'den', (LGOR, h), NON] += c(i, j, (LGOR, h), loc_h, sent)
            # right adjacent stop, j = loc_r_h
            fr['STOP', 'num', (GOR, h), ADJ] += c(loc_l_h, loc_r_h, (RGOL, h), loc_h, sent)
            fr['STOP', 'den', (GOR, h), ADJ] += c(loc_l_h, loc_r_h, (GOR, h), loc_h, sent)
            for i in xlteq(loc_l_h):
                fr['STOP', 'num', (LGOR, h), ADJ] += c(i, loc_r_h, (SEAL, h), loc_h, sent)
                fr['STOP', 'den', (LGOR, h), ADJ] += c(i, loc_r_h, (LGOR, h), loc_h, sent)

            # left attachment:
            if 'REEST_ATTACH' in DEBUG:
                print("Lattach %s: for i < %s" % (g.numtag(h), sent[0:loc_h + 1]))
            for s_h in [RGOL, GOL]:
                x = (s_h, h)
                for i in xlt(loc_l_h):  # i < loc_l(h)
                    if 'REEST_ATTACH' in DEBUG:
                        print("\tfor j >= %s" % sent[loc_h:len(sent)])
                    for j in xgteq(loc_r_h, sent):  # j >= loc_r(h)
                        fr['hat_a', 'den', x] += c(i, j, x, loc_h, sent)  # v_q in L&Y
                        if 'REEST_ATTACH' in DEBUG:
                            print("\t\tc( %d , %d, %s, %s, sent)=%.4f" % (i, j, node_str(x), loc_h, fr['hat_a', 'den', x]))
                        w_left(i, j, x, loc_h, sent, sent_nums)  # compute w for all a in sent

            # right attachment:
            if 'REEST_ATTACH' in DEBUG:
                print("Rattach %s: for i <= %s" % (g.numtag(h), sent[0:loc_h + 1]))
            for s_h in [GOR, LGOR]:
                x = (s_h, h)
                for i in xlteq(loc_l_h):  # i <= loc_l(h)
                    if 'REEST_ATTACH' in DEBUG:
                        print("\tfor j > %s" % sent[loc_h:len(sent)])
                    for j in xgt(loc_r_h, sent):  # j > loc_r(h)
                        fr['hat_a', 'den', x] += c(i, j, x, loc_h, sent)  # v_q in L&Y
                        if 'REEST_ATTACH' in DEBUG:
                            print("\t\tc( %d , %d, %s, %s, sent)=%.4f" % (i, j, node_str(x), loc_h, fr['hat_a', 'den', x]))
                        # numerator over the same span (i,j) as the
                        # denominator above (was always loc_l_h)
                        w_right(i, j, x, loc_h, sent, sent_nums)  # compute w for all a in sent

        # end for loc_h,h
    # end for sent

    return fr
def reestimate(g, corpus):
    """Do one reestimation pass over corpus and update g in place.

    p_ROOT, p_STOP, p_ATTACH (and the derived p_GO_AT) of g are replaced
    by the reestimated tables; p_ORDER is left untouched. Returns the
    dict of counts from reest_freq.
    """
    fr = reest_freq(g, corpus)

    new_root, new_stop, new_attach = {}, {}, {}
    for head in g.headnums():
        reest_head(head, fr, g, new_root, new_stop, new_attach)

    g.p_ROOT = new_root
    g.p_STOP = new_stop
    g.p_ATTACH = new_attach
    g.p_GO_AT = make_GO_AT(new_stop, new_attach)
    return fr
def reest_head(h, fr, g, p_ROOT, p_STOP, p_ATTACH):
    """Given a single head h, fill p_ROOT, p_STOP and p_ATTACH with the
    probabilities reestimated from the counts in fr.

    The three dicts are updated in place; g is only used for its list of
    head numbers. Zero denominators give probability 0.0, and arguments
    that were never attached under a node contribute nothing to p_ATTACH.
    """
    # remove 0-prob stuff? todo
    root_den = fr['ROOT', 'den']
    if root_den != 0:
        p_ROOT[h] = fr.get(('ROOT', 'num', h), 0.0) / root_den
    else:
        # no sentence had any probability mass
        p_ROOT[h] = 0.0

    for direction in [LEFT, RIGHT]:
        for adjacency in [ADJ, NON]:  # p_STOP
            p_STOP[h, direction, adjacency] = 0.0
            for s_h in dirseal(direction):
                x = (s_h, h)
                p = fr['STOP', 'den', x, adjacency]
                if p > 0.0:
                    p = fr['STOP', 'num', x, adjacency] / p
                p_STOP[h, direction, adjacency] += p

        for s_h in dirseal(direction):  # make hat_a for p_ATTACH
            x = (s_h, h)
            hat_a = {}

            p_c = fr['hat_a', 'den', x]
            if p_c != 0:  # with a zero denominator no hat_a is defined
                for a in g.headnums():
                    num_key = ('hat_a', 'num', a, x)
                    if num_key in fr:  # only arguments that showed up
                        hat_a[a] = fr[num_key] / p_c

            sum_hat_a = sum([hat_a[w] for w in g.headnums() if w in hat_a])

            for a in g.headnums():
                if (a, h, direction) not in p_ATTACH:
                    p_ATTACH[a, h, direction] = 0.0
                # a might not be in hat_a, and the sum might be zero
                if a in hat_a and sum_hat_a != 0:
                    p_ATTACH[a, h, direction] += hat_a[a] / sum_hat_a
557 ###################################################
558 # Most Probable Parse: #
559 ###################################################
# Pseudo chart keys of the form (i, j, node, loc), used in the mpptree:
# STOPKEY marks the STOP side of a stop rewrite, ROOTKEY is the entry
# that the parse of the whole sentence hangs from.
STOPKEY = (-1, -1, STOP, -1)
ROOTKEY = (-1, -1, ROOT, -1)
def make_mpptree(g, sent):
    '''Have inner() fill in an mpptree for sent, and connect ROOT to the
    most probable sealed head. (Logically, this should be part of
    inner_sent though...)

    Returns the mpptree dict, key -> (p, L, R).'''
    n = len(sent)
    ichart = {}
    mpptree = {ROOTKEY: (0.0, ROOTKEY, None)}
    for loc_w, w in locs(g.sent_nums(sent), 0, n):
        sealed = (SEAL, w)
        p = g.p_ROOT[w] * inner(0, n, sealed, loc_w, g, sent, ichart, mpptree)
        # keep the best root so far; ROOT itself goes in the L slot
        if p > mpptree[ROOTKEY][0]:
            mpptree[ROOTKEY] = (p, ROOTKEY, (0, n, sealed, loc_w))
    return mpptree
def parse_mpptree(mpptree, sent):
    '''Walk the mpptree from ROOTKEY and collect the dependencies.

    mpptree is a dict of the form {k:(p,L,R),...}; where k, L and R
    are `keys' of the form (i,j,node,loc).

    Returns a set of triples ((head, loc_h), (arg, loc_a), direction),
    where head and arg are what POS() gives for the nodes. sent is not
    used.'''
    # local functions for clear access to mpptree keys:
    def loc_node(key):
        return (key[2], key[3])

    def loc_pos(key):
        return (POS(key[2]), key[3])

    def is_terminal(key):
        s_k = seals(key[2])  # i+1 == j
        return key[0] + 1 == key[1] and (s_k == GOR or s_k == GOL)

    # arbitrarily, "ROOT attaches to right". We add it here to
    # avoid further complications:
    firstkey = mpptree[ROOTKEY][2]
    deps = set([(loc_pos(ROOTKEY), loc_pos(firstkey), RIGHT)])
    pending = [firstkey]

    while pending:
        k = pending.pop()
        if is_terminal(k):
            continue
        entry = mpptree[k]
        L, R = entry[1], entry[2]
        here = loc_node(k)
        if here == loc_node(L):  # Rattach
            deps.add((loc_pos(k), loc_pos(R), LEFT))
            pending.extend([L, R])
        elif here == loc_node(R):  # Lattach
            deps.add((loc_pos(k), loc_pos(L), RIGHT))
            pending.extend([L, R])
        elif R == STOPKEY:
            pending.append(L)
        elif L == STOPKEY:
            pending.append(R)
    return deps
def mpp(g, sent):
    """Return the most probable parse of sent as a set of
    ((head_tag, loc_h), (arg_tag, loc_a)) pairs; the direction that
    parse_mpptree reports is dropped."""
    tagf = g.numtag  # localized function, todo: speed-test
    deps = parse_mpptree(make_mpptree(g, sent), sent)
    return set(((tagf(h), loc_h), (tagf(a), loc_a))
               for (h, loc_h), (a, loc_a), _direction in deps)
634 ########################################################################
635 # testing functions: #
636 ########################################################################
638 testcorpus = [s.split() for s in ['det nn vbd c vbd','vbd nn c vbd',
639 'det nn vbd', 'det nn vbd c pp',
640 'det nn vbd', 'det vbd vbd c pp',
641 'det nn vbd', 'det nn vbd c vbd',
642 'det nn vbd', 'det nn vbd c vbd',
643 'det nn vbd', 'det nn vbd c vbd',
644 'det nn vbd', 'det nn vbd c pp',
645 'det nn vbd pp', 'det nn vbd', ]]
def testgrammar():
    """Return the grammar that loc_h_harmonic.initialize builds from
    testcorpus, with the module constants pinned to the values the
    regression tests were set up with."""
    import loc_h_harmonic
    reload(loc_h_harmonic)

    # make sure these are the way they were when setting up the tests:
    pinned = (('HARMONIC_C', 0.0),
              ('FNONSTOP_MIN', 25),
              ('FSTOP_MIN', 5),
              ('RIGHT_FIRST', 1.0))
    for name, value in pinned:
        setattr(loc_h_harmonic, name, value)

    return loc_h_harmonic.initialize(testcorpus)
def ig(s, t, LHS, loc_h):
    """Shortcut: inner() of LHS at loc_h over s..t of 'det nn vbd', using
    a fresh testgrammar() and an empty chart."""
    grammar = testgrammar()
    words = 'det nn vbd'.split()
    return inner(s, t, LHS, loc_h, grammar, words, {})
def testreestimation():
    """Reestimate testgrammar() once on testcorpus, printing the grammar
    before and after, then check the counts against the stored
    regression values. Returns the counts."""
    grammar = testgrammar()
    print(grammar)
    # DEBUG.add('REEST_ATTACH')
    counts = reestimate(grammar, testcorpus)
    print(grammar)
    testreestimation_regression(counts)
    return counts
def testreestimation_regression(fr):
    """Print a message for every stored P_STOP count that is missing from
    fr, or differs from it when both are formatted with ten decimals."""
    f_stops = {
        ('STOP', 'den', (RGOL,3),NON): 12.212773236178391,
        ('STOP', 'den', (GOR,2),ADJ): 4.0,
        ('STOP', 'num', (GOR,4),NON): 2.5553487221351365,
        ('STOP', 'den', (RGOL,2),NON): 1.274904052793207,
        ('STOP', 'num', (RGOL,1),ADJ): 14.999999999999995,
        ('STOP', 'den', (GOR,3),ADJ): 15.0,
        ('STOP', 'num', (RGOL,4),ADJ): 16.65701084787457,
        ('STOP', 'num', (RGOL,0),ADJ): 4.1600647714443468,
        ('STOP', 'den', (RGOL,4),NON): 6.0170669155897105,
        ('STOP', 'num', (RGOL,3),ADJ): 2.7872267638216113,
        ('STOP', 'num', (RGOL,2),ADJ): 2.9723139990470515,
        ('STOP', 'den', (RGOL,2),ADJ): 4.0,
        ('STOP', 'den', (GOR,3),NON): 12.945787931730905,
        ('STOP', 'den', (RGOL,3),ADJ): 14.999999999999996,
        ('STOP', 'den', (GOR,2),NON): 0.0,
        ('STOP', 'den', (RGOL,0),ADJ): 8.0,
        ('STOP', 'num', (GOR,4),ADJ): 19.44465127786486,
        ('STOP', 'den', (GOR,1),NON): 3.1966410324085777,
        ('STOP', 'den', (RGOL,1),ADJ): 14.999999999999995,
        ('STOP', 'num', (GOR,3),ADJ): 4.1061665495365558,
        ('STOP', 'den', (GOR,0),NON): 4.8282499043902476,
        ('STOP', 'num', (RGOL,4),NON): 5.3429891521254289,
        ('STOP', 'num', (GOR,2),ADJ): 4.0,
        ('STOP', 'den', (RGOL,4),ADJ): 22.0,
        ('STOP', 'num', (GOR,1),ADJ): 12.400273895299103,
        ('STOP', 'num', (RGOL,2),NON): 1.0276860009529487,
        ('STOP', 'num', (GOR,0),ADJ): 3.1717500956097533,
        ('STOP', 'num', (RGOL,3),NON): 12.212773236178391,
        ('STOP', 'den', (GOR,4),ADJ): 22.0,
        ('STOP', 'den', (GOR,4),NON): 2.8705211946979836,
        ('STOP', 'num', (RGOL,0),NON): 3.8399352285556518,
        ('STOP', 'num', (RGOL,1),NON): 0.0,
        ('STOP', 'num', (GOR,0),NON): 4.8282499043902476,
        ('STOP', 'num', (GOR,1),NON): 2.5997261047008959,
        ('STOP', 'den', (RGOL,1),NON): 0.0,
        ('STOP', 'den', (GOR,0),ADJ): 8.0,
        ('STOP', 'num', (GOR,2),NON): 0.0,
        ('STOP', 'den', (RGOL,0),NON): 4.6540557322109795,
        ('STOP', 'den', (GOR,1),ADJ): 15.0,
        ('STOP', 'num', (GOR,3),NON): 10.893833450463443,
    }
    for key, wanted in f_stops.items():
        if key not in fr:
            print("Regression in P_STOP reestimation, should be fr[%s]=%.4f,\n"
                  "but %s not in fr" % (key, wanted, key))
        elif "%.10f" % fr[key] != "%.10f" % wanted:
            print("Regression in P_STOP reestimation, should be fr[%s]=%.4f,\n"
                  "got fr[%s]=%.4f." % (key, wanted, key, fr[key]))
def testmpp_regression(mpptree, k_n):
    """Compare the probabilities stored in mpptree with the stored ones
    and print a message for every missing or differing entry.

    k_n is how many leading elements of each stored key are used for the
    lookup in mpptree."""
    expected = {
        ROOTKEY: (2.877072116829971e-05, STOPKEY, (0, 3, (2, 3), 1)),
        (0, 1, (1, 1), 0): (0.1111111111111111, (0, 1, (0, 1), 0), STOPKEY),
        (0, 1, (2, 1), 0): (0.049382716049382713, STOPKEY, (0, 1, (1, 1), 0)),
        (0, 3, (1, 3), 1): (0.00027619892321567721,
                            (0, 1, (2, 1), 0),
                            (1, 3, (1, 3), 1)),
        (0, 3, (2, 3), 1): (0.00012275507698474543, STOPKEY, (0, 3, (1, 3), 1)),
        (1, 3, (0, 3), 1): (0.025280986819448362,
                            (1, 2, (0, 3), 1),
                            (2, 3, (2, 4), 2)),
        (1, 3, (1, 3), 1): (0.0067415964851862296, (1, 3, (0, 3), 1), STOPKEY),
        (2, 3, (1, 4), 2): (0.32692307692307693, (2, 3, (0, 4), 2), STOPKEY),
        (2, 3, (2, 4), 2): (0.037721893491124266, STOPKEY, (2, 3, (1, 4), 2)),
    }
    for k, (v, L, R) in expected.items():
        if type(k) is str:
            k2 = k
        else:
            k2 = k[0:k_n]  # 3 if the new does not check loc_h
        if k2 not in mpptree:
            print("mpp regression, %s missing" % (k2,))
            continue
        vnew = mpptree[k2][0]
        if "%.10f" % vnew != "%.10f" % v:
            print("mpp regression, wanted %s=%.5f, got %.5f" % (k2, v, vnew))
def testgrammar_a():
    """Return a hand-made two-tag grammar (h=0, a=1) for the tests; its
    p_GO_AT table is overwritten by hand after construction."""
    h, a = 0, 1
    tags = (h, a)

    p_ROOT = {h: 0.9, a: 0.1}

    p_STOP = {}
    for w in tags:
        p_STOP[w, LEFT, NON] = 1.0
        p_STOP[w, LEFT, ADJ] = 1.0
        p_STOP[w, RIGHT, NON] = 0.4  # RSTOP
        p_STOP[w, RIGHT, ADJ] = 0.3  # RSTOP

    # not used (p_GO_AT is set by hand below); there is no a-to-a entry
    p_ATTACH = {}
    for arg, head in ((a, h), (h, a), (h, h)):
        p_ATTACH[arg, head, LEFT] = 1.0
        p_ATTACH[arg, head, RIGHT] = 1.0

    p_ORDER = {}
    for w in tags:
        p_ORDER[(GOR, w)] = 1.0
        p_ORDER[(GOL, w)] = 0.0

    g = DMV_Grammar({h: 'h', a: 'a'}, {'h': h, 'a': a}, p_ROOT, p_STOP, p_ATTACH, p_ORDER)

    # these probabilities are impossible so add them manually:
    # Lattach depends on the first key element only, Rattach is always 1.0
    lattach = {a: {NON: 0.4, ADJ: 0.6},
               h: {NON: 0.2, ADJ: 0.1}}
    for first in tags:
        for second in tags:
            for adjacency in (NON, ADJ):
                g.p_GO_AT[first, second, LEFT, adjacency] = lattach[first][adjacency]
                g.p_GO_AT[first, second, RIGHT, adjacency] = 1.0
    return g
def testgrammar_h():
    """Return a hand-made one-tag grammar (h=0) for the tests; its
    p_GO_AT table is overwritten by hand after construction."""
    h = 0
    p_ROOT = {h: 1.0}
    p_STOP = {
        (h, LEFT, NON): 1.0,
        (h, LEFT, ADJ): 1.0,
        (h, RIGHT, NON): 0.4,
        (h, RIGHT, ADJ): 0.3,
    }
    p_ATTACH = {
        (h, h, LEFT): 1.0,   # not used
        (h, h, RIGHT): 1.0,  # not used
    }
    p_ORDER = {
        (GOR, h): 1.0,
        (GOL, h): 0.0,
    }
    g = DMV_Grammar({h: 'h'}, {'h': h}, p_ROOT, p_STOP, p_ATTACH, p_ORDER)
    # these probabilities are impossible so add them manually...
    g.p_GO_AT.update({
        (h, h, LEFT, NON): 0.6,
        (h, h, LEFT, ADJ): 0.7,
        (h, h, RIGHT, NON): 1.0,
        (h, h, RIGHT, ADJ): 1.0,
    })
    return g
def testreestimation_h():
    """Reestimate testgrammar_h() on the single sentence 'h h h', with
    'REEST' debug output switched on."""
    DEBUG.add('REEST')
    grammar = testgrammar_h()
    corpus = ['h h h'.split()]
    reestimate(grammar, corpus)
def test(wanted, got):
    """Raise Warning when got differs from wanted; return None otherwise.

    The message reads "Regression! Should be <wanted>: <got>".
    """
    if wanted != got:
        # call form of raise: valid on both Python 2 and Python 3
        raise Warning("Regression! Should be %s: %s" % (wanted, got))
def regression_tests():
    """Run the stored regression checks for mpp, inner and outer; test()
    raises Warning on the first value that differs."""
    testmpp_regression(make_mpptree(testgrammar(), testcorpus[2]), 4)
    h = 0

    # Each helper builds a fresh grammar, sentence and chart per call.
    def inner_h(fmt, j, loc, words):
        return fmt % inner(0, j, (SEAL, h), loc, testgrammar_h(), words.split(), {})

    def outer_h(loc):
        return "%.2f" % outer(1, 3, (RGOL, h), loc, testgrammar_h(), 'h h h'.split(), {}, {})

    def outer_a(j):
        return "%.4f" % outer(0, j, (GOR, h), 0, testgrammar_a(), 'h a'.split(), {}, {})

    test("0.120", inner_h("%.3f", 2, 0, 'h h'))
    test("0.063", inner_h("%.3f", 2, 1, 'h h'))
    test("0.1842", "%.4f" % inner_sent(testgrammar_h(), 'h h h'.split()))

    test("0.1092", inner_h("%.4f", 3, 0, 'h h h'))
    test("0.0252", inner_h("%.4f", 3, 1, 'h h h'))
    test("0.0498", inner_h("%.4f", 3, 2, 'h h h'))

    test("0.58", outer_h(2))
    # ftw? can't be right... there's an 0.4 shared between these two...
    test("0.61", outer_h(1))

    test("0.00", outer_h(0))
    test("0.00", outer_h(3))

    test("0.1089", outer_a(1))
    test("0.3600", outer_a(2))
    test("0.0000", outer_a(3))

    # todo: add more of these tests...
if __name__ == "__main__":
    # Run the regression tests with all debug output switched off.
    DEBUG.clear()

    # To profile:
    #     import profile
    #     profile.run('testreestimation()')

    # To time:
    #     import timeit
    #     print timeit.Timer("loc_h_dmv.testreestimation()",'''import loc_h_dmv
    #     reload(loc_h_dmv)''').timeit(1)

    regression_tests()

    # To print the most probable parse of every test sentence:
    #     print "mpp-test:"
    #     import pprint
    #     for s in testcorpus:
    #         print "sent:%s\nparse:set(\n%s)"%(s,pprint.pformat(list(mpp(testgrammar(), s)),
    #                                                           width=40))

    # To print the reestimation counts:
    #     import pprint
    #     pprint.pprint( testreestimation())
def testIO():
    """Return a list of (sent, inner_sent probability) pairs for every
    sentence of testcorpus under testgrammar(), each computed with its
    own empty chart."""
    grammar = testgrammar()
    return [(sent, inner_sent(grammar, sent, {})) for sent in testcorpus]