35ea49d791248df8857d42f3abdbca62eecce797
3 # dmv reestimation and inside-outside probabilities using loc_h, and
6 #import numpy # numpy provides Fast Arrays, for future optimization
8 from common_dmv
import *
10 ### todo: debug with @accepts once in a while, but it's SLOW
11 # from typecheck import accepts, Any
# Announce the module's self-tests when the file is executed as a script.
# The parenthesised single-argument form prints the same text under both the
# Python 2 print statement and the Python 3 print function.
if __name__ == "__main__":
    print("loc_h_dmv module tests:")
def adj(middle, loc_h):
    """Tell whether `middle` is adjacent to the head location `loc_h`.

    middle is eg. k when rewriting for i<k<j (inside probabilities).
    Returns True when `middle` equals `loc_h` or `loc_h + 1`, else False.
    """
    return middle == loc_h or middle == loc_h + 1  # ADJ == True
20 def make_GO_AT(p_STOP
,p_ATTACH
):
22 for (a
,h
,dir), p_ah
in p_ATTACH
.iteritems():
23 p_GO_AT
[a
,h
,dir, NON
] = p_ah
* (1-p_STOP
[h
, dir, NON
])
24 p_GO_AT
[a
,h
,dir, ADJ
] = p_ah
* (1-p_STOP
[h
, dir, ADJ
])
27 class DMV_Grammar(io
.Grammar
):
30 return "%d=%s" % (n
, self
.numtag(n
))
37 p_L
= p(self
.p_ATTACH
,(a
,h
,LEFT
))
38 p_R
= p(self
.p_ATTACH
,(a
,h
,RIGHT
))
39 if p_L
== 0.0 and p_R
== 0.0:
43 str = "p_ATTACH[ %s|%s,L] = %.4f" % (t(a
), t(h
), p_L
)
48 str += "p_ATTACH[ %s|%s,R] = %.4f" % (t(a
), t(h
), p_R
)
51 root
, stop
, att
, ord = "","","",""
52 for h
in self
.headnums():
53 root
+= "p_ROOT[%s] = %.4f\n" % (t(h
), p(self
.p_ROOT
, (h
)))
54 stop
+= "p_STOP[stop|%s,L,adj] = %.4f\t" % (t(h
), p(self
.p_STOP
, (h
,LEFT
,ADJ
)))
55 stop
+= "p_STOP[stop|%s,R,adj] = %.4f\n" % (t(h
), p(self
.p_STOP
, (h
,RIGHT
,ADJ
)))
56 stop
+= "p_STOP[stop|%s,L,non] = %.4f\t" % (t(h
), p(self
.p_STOP
, (h
,LEFT
,NON
)))
57 stop
+= "p_STOP[stop|%s,R,non] = %.4f\n" % (t(h
), p(self
.p_STOP
, (h
,RIGHT
,NON
)))
58 att
+= ''.join([p_a(a
,h
) for a
in self
.headnums()])
59 ord += "p_ORDER[ left-first|%s ] = %.4f\t" % (t(h
), p(self
.p_ORDER
, (GOL
,h
)))
60 ord += "p_ORDER[right-first|%s ] = %.4f\n" % (t(h
), p(self
.p_ORDER
, (GOR
,h
)))
61 return root
+ stop
+ att
+ ord
63 def __init__(self
, numtag
, tagnum
, p_ROOT
, p_STOP
, p_ATTACH
, p_ORDER
):
64 io
.Grammar
.__init
__(self
, numtag
, tagnum
)
65 self
.p_ROOT
= p_ROOT
# p_ROOT[w] = p
66 self
.p_ORDER
= p_ORDER
# p_ORDER[seals, w] = p
67 self
.p_STOP
= p_STOP
# p_STOP[w, LEFT, NON] = p (etc. for LA,RN,RA)
68 self
.p_ATTACH
= p_ATTACH
# p_ATTACH[a, h, LEFT] = p (etc. for R)
69 # p_GO_AT[a, h, LEFT, NON] = p (etc. for LA,RN,RA)
70 self
.p_GO_AT
= make_GO_AT(self
.p_STOP
, self
.p_ATTACH
)
72 def p_GO_AT_or0(self
, a
, h
, dir, adj
):
74 return self
.p_GO_AT
[a
, h
, dir, adj
]
79 def locs(sent_nums
, start
, stop
):
80 '''Return the between-word locations of all words in some fragment of
81 sent. We make sure to offset the locations correctly so that for
82 any w in the returned list, sent[w]==loc_w.
84 start is inclusive, stop is exclusive, as in klein-thesis and
85 Python's list-slicing.'''
86 for i0
,w
in enumerate(sent_nums
[start
:stop
]):
90 ###################################################
91 # P_INSIDE (dmv-specific) #
92 ###################################################
94 #@accepts(int, int, (int, int), int, Any(), [str], {tuple:float}, IsOneOf(None,{}))
95 def inner(i
, j
, node
, loc_h
, g
, sent
, ichart
={}, mpptree
=None):
96 ''' The ichart is of this form:
97 ichart[i,j,LHS, loc_h]
98 where i and j are between-word positions.
100 loc_h gives adjacency (along with k for attachment rules), and is
101 needed in P_STOP reestimation.
103 sent_nums
= g
.sent_nums(sent
)
105 def terminal(i
,j
,node
, loc_h
, tabs
):
106 if not i
<= loc_h
< j
:
108 print "%s*= 0.0 (wrong loc_h)" % tabs
110 elif POS(node
) == sent_nums
[i
] and node
in g
.p_ORDER
:
111 # todo: add to ichart perhaps? Although, it _is_ simple lookup..
112 prob
= g
.p_ORDER
[node
]
115 print "%sLACKING TERMINAL:" % tabs
118 print "%s*= %.4f (terminal: %s -> %s_%d)" % (tabs
,prob
, node_str(node
), sent
[i
], loc_h
)
121 def e(i
,j
, (s_h
,h
), loc_h
, n_t
):
124 key
= (i
,j
, (s_h
,h
), loc_h
)
125 if key
not in mpptree
:
126 mpptree
[key
] = (p
, L
, R
)
127 elif mpptree
[key
][0] < p
:
128 mpptree
[key
] = (p
, L
, R
)
131 "Tabs for debug output"
134 if (i
, j
, (s_h
,h
), loc_h
) in ichart
:
136 print "%s*= %.4f in ichart: i:%d j:%d node:%s loc:%s" % (tab(),ichart
[i
, j
, (s_h
,h
), loc_h
], i
, j
,
137 node_str((s_h
,h
)), loc_h
)
138 return ichart
[i
, j
, (s_h
,h
), loc_h
]
140 # Either terminal rewrites, using p_ORDER:
141 if i
+1 == j
and (s_h
== GOR
or s_h
== GOL
):
142 return terminal(i
, j
, (s_h
,h
), loc_h
, tab())
143 else: # Or not at terminal level yet:
145 print "%s%s (%.1f) from %d to %d" % (tab(),node_str((s_h
,h
)),loc_h
,i
,j
)
147 p_RGOL
= g
.p_STOP
[h
, LEFT
, adj(i
,loc_h
)] * e(i
,j
,(RGOL
,h
),loc_h
,n_t
+1)
148 p_LGOR
= g
.p_STOP
[h
, RIGHT
, adj(j
,loc_h
)] * e(i
,j
,(LGOR
,h
),loc_h
,n_t
+1)
150 to_mpp(p_RGOL
, STOPKEY
, (i
,j
, (RGOL
,h
),loc_h
))
151 to_mpp(p_LGOR
, (i
,j
, (RGOL
,h
),loc_h
), STOPKEY
)
153 print "%sp= %.4f (STOP)" % (tab(), p
)
154 elif s_h
== RGOL
or s_h
== GOL
:
157 p
= g
.p_STOP
[h
, RIGHT
, adj(j
,loc_h
)] * e(i
,j
, (GOR
,h
),loc_h
,n_t
+1)
158 to_mpp(p
, (i
,j
, (GOR
,h
),loc_h
), STOPKEY
)
159 for k
in xgo_left(i
, loc_h
): # i < k <= loc_l(h)
160 p_R
= e(k
, j
, ( s_h
,h
), loc_h
, n_t
+1)
162 for loc_a
,a
in locs(sent_nums
, i
, k
):
163 p_ah
= g
.p_GO_AT_or0(a
, h
, LEFT
, adj(k
,loc_h
))
165 p_L
= e(i
, k
, (SEAL
,a
), loc_a
, n_t
+1)
166 p_add
= p_L
* p_ah
* p_R
169 (i
, k
, (SEAL
,a
), loc_a
),
170 (k
, j
, ( s_h
,h
), loc_h
))
172 print "%sp= %.4f (ATTACH)" % (tab(), p
)
173 elif s_h
== GOR
or s_h
== LGOR
:
176 p
= g
.p_STOP
[h
, LEFT
, adj(i
,loc_h
)] * e(i
,j
, (GOL
,h
),loc_h
,n_t
+1)
177 to_mpp(p
, (i
,j
, (GOL
,h
),loc_h
), STOPKEY
)
178 for k
in xgo_right(loc_h
, j
): # loc_l(h) < k < j
179 p_L
= e(i
, k
, ( s_h
,h
), loc_h
, n_t
+1)
181 for loc_a
,a
in locs(sent_nums
,k
,j
):
182 p_ah
= g
.p_GO_AT_or0(a
, h
, RIGHT
, adj(k
,loc_h
))
183 p_R
= e(k
, j
, (SEAL
,a
), loc_a
, n_t
+1)
184 p_add
= p_L
* p_ah
* p_R
187 (i
, k
, ( s_h
,h
), loc_h
),
188 (k
, j
, (SEAL
,a
), loc_a
))
191 print "%sp= %.4f (ATTACH)" % (tab(), p
)
192 # elif s_h == GOL: # todo
194 ichart
[i
, j
, (s_h
,h
), loc_h
] = p
198 inner_prob
= e(i
,j
,node
,loc_h
, 0)
200 print debug_ichart(g
,sent
,ichart
)
202 # end of dmv.inner(i, j, node, loc_h, g, sent, ichart={})
205 def debug_ichart(g
,sent
,ichart
):
206 str = "---ICHART:---\n"
207 for (s
,t
,LHS
,loc_h
),v
in ichart
.iteritems():
208 str += "%s -> %s_%d ... %s_%d (loc_h:%s):\t%.4f\n" % (node_str(LHS
,g
.numtag
),
209 sent
[s
], s
, sent
[s
], t
, loc_h
, v
)
210 str += "---ICHART:end---\n"
def inner_sent(g, sent, ichart=None):
    """Return the summed, root-weighted inside probability of `sent`.

    For every (loc_w, w) pair that locs() yields over the whole sentence,
    adds g.p_ROOT[w] * inner(0, len(sent), (SEAL, w), loc_w, g, sent, ichart).

    g      -- grammar object providing p_ROOT and sent_nums()
    sent   -- the sentence, as a sequence of tags
    ichart -- optional chart dict handed on to inner(). When omitted, a
              fresh dict is created for this call: chart entries are keyed
              by (i, j, node, loc_h) only, so a dict shared between calls
              (as the former `ichart={}` default was) could hand back
              entries computed for a different sentence.
    """
    if ichart is None:
        ichart = {}
    n = len(sent)
    return sum([g.p_ROOT[w] * inner(0, n, (SEAL, w), loc_w, g, sent, ichart)
                for loc_w, w in locs(g.sent_nums(sent), 0, n)])
222 ###################################################
223 # P_OUTSIDE (dmv-specific) #
224 ###################################################
226 #@accepts(int, int, (int, int), int, Any(), [str], {tuple:float}, {tuple:float})
227 def outer(i
,j
,w_node
,loc_w
, g
, sent
, ichart
={}, ochart
={}):
228 ''' http://www.student.uib.no/~kun041/dmvccm/DMVCCM.html#outer
230 w_node is a pair (seals,POS); the w in klein-thesis is made up of
233 sent_nums
= g
.sent_nums(sent
)
234 if POS(w_node
) not in sent_nums
[i
:j
]:
235 # sanity check, w must be able to dominate sent[i:j]
239 def e(i
,j
,LHS
,loc_h
): # P_{INSIDE}
241 return ichart
[i
,j
,LHS
,loc_h
]
243 return inner(i
,j
,LHS
,loc_h
,g
,sent
,ichart
)
245 def f(i
,j
,w_node
,loc_w
):
246 if not (i
<= loc_w
< j
):
248 if (i
,j
,w_node
,loc_w
) in ochart
:
249 return ochart
[i
,j
, w_node
,loc_w
]
251 if i
== 0 and j
== len(sent
):
253 else: # ROOT may only be used on full sentence
255 # but we may have non-ROOTs (stops) over full sentence too:
259 # todo: try either if p_M > 0.0: or sum(), and speed-test them
261 if s_w
== SEAL
: # w == a
262 # todo: do the i<sent<j check here to save on calls?
263 p
= g
.p_ROOT
[w
] * f(i
,j
,ROOT
,loc_w
)
265 for k
in xgt(j
, sent
): # j<k<len(sent)+1
266 for loc_h
,h
in locs(sent_nums
,j
,k
):
267 p_wh
= g
.p_GO_AT_or0(w
, h
, LEFT
, adj(j
, loc_h
))
268 for s_h
in [RGOL
, GOL
]:
269 p
+= f(i
,k
,(s_h
,h
),loc_h
) * p_wh
* e(j
,k
,(s_h
,h
),loc_h
)
271 for k
in xlt(i
): # k<i
272 for loc_h
,h
in locs(sent_nums
,k
,i
):
273 p_wh
= g
.p_GO_AT_or0(w
, h
, RIGHT
, adj(i
, loc_h
))
274 for s_h
in [LGOR
, GOR
]:
275 p
+= e(k
,i
,(s_h
,h
), loc_h
) * p_wh
* f(k
,j
,(s_h
,h
), loc_h
)
277 elif s_w
== RGOL
or s_w
== GOL
: # w == h, left stop + left attach
282 p
= g
.p_STOP
[w
, LEFT
, adj(i
,loc_w
)] * f(i
,j
,( s_h
,w
),loc_w
)
283 for k
in xlt(i
): # k<i
284 for loc_a
,a
in locs(sent_nums
,k
,i
):
285 p_aw
= g
.p_GO_AT_or0(a
, w
, LEFT
, adj(i
, loc_w
))
286 p
+= e(k
,i
, (SEAL
,a
),loc_a
) * p_aw
* f(k
,j
,w_node
,loc_w
)
288 elif s_w
== GOR
or s_w
== LGOR
: # w == h, right stop + right attach
293 p
= g
.p_STOP
[w
, RIGHT
, adj(j
,loc_w
)] * f(i
,j
,( s_h
,w
),loc_w
)
294 for k
in xgt(j
, sent
): # j<k<len(sent)+1
295 for loc_a
,a
in locs(sent_nums
,j
,k
):
296 p_ah
= g
.p_GO_AT_or0(a
, w
, RIGHT
, adj(j
, loc_w
))
297 p
+= f(i
,k
,w_node
,loc_w
) * p_ah
* e(j
,k
,(SEAL
,a
),loc_a
)
299 ochart
[i
,j
,w_node
,loc_w
] = p
303 return f(i
,j
,w_node
,loc_w
)
304 # end outer(i,j,w_node,loc_w, g,sent, ichart,ochart)
309 ###################################################
311 ###################################################
# todo: it seems we have to rewrite attachment reestimation so that we
# have 'a' as the outer loop, then sentences... but this means running
# through sentences several times, and that would require storing
# inner probabilities...agh!
318 def reest_zeros(h_nums
):
319 '''A dict to hold numerators and denominators for our 6+ reestimation
322 fr
= { ('ROOT','den'):0.0 } # holds sum over p_sent
324 fr
['ROOT','num',h
] = 0.0
325 for s_h
in [GOR
,GOL
,RGOL
,LGOR
]:
327 fr
['hat_a','den',x
] = 0.0 # = c()
328 # not all arguments are attached to, so we just initialize
329 # fr['hat_a','num',a,(s_h,h)] as they show up, in reest_freq
330 for adj
in [NON
, ADJ
]:
331 for nd
in ['num','den']:
332 fr
['STOP',nd
,x
,adj
] = 0.0
335 def reest_freq(g
, corpus
):
336 fr
= reest_zeros(g
.headnums())
339 p_sent
= None # 50 % speed increase on storing this locally
341 # local functions altogether 2x faster than global
342 def c(i
,j
,LHS
,loc_h
,sent
):
346 p_in
= e(i
,j
, LHS
,loc_h
,sent
)
350 p_out
= f(i
,j
, LHS
,loc_h
,sent
)
351 return p_in
* p_out
/ p_sent
354 def f(i
,j
,LHS
,loc_h
,sent
): # P_{OUTSIDE}
356 return ochart
[i
,j
,LHS
,loc_h
]
358 return outer(i
,j
,LHS
,loc_h
,g
,sent
,ichart
,ochart
)
361 def e(i
,j
,LHS
,loc_h
,sent
): # P_{INSIDE}
363 return ichart
[i
,j
,LHS
,loc_h
]
365 return inner(i
,j
,LHS
,loc_h
,g
,sent
,ichart
)
368 def w_left(i
,j
, x
,loc_h
,sent
,sent_nums
):
369 if not p_sent
> 0.0: return
373 for k
in xtween(i
, j
):
374 p_out
= f(i
,j
, x
,loc_h
, sent
)
377 p_R
= e(k
,j
, x
,loc_h
, sent
)
381 for loc_a
,a
in locs(sent_nums
, i
,k
): # i<=loc_l(a)<k
382 p_rule
= g
.p_GO_AT_or0(a
, h
, LEFT
, adj(k
, loc_h
))
383 p_L
= e(i
,k
, (SEAL
,a
), loc_a
, sent
)
384 p
= p_L
* p_out
* p_R
* p_rule
388 for a
,p
in a_k
.iteritems():
389 try: fr
['hat_a','num',a
,x
] += p
/ p_sent
390 except: fr
['hat_a','num',a
,x
] = p
/ p_sent
391 # end reest_freq.w_left()
393 def w_right(i
,j
, x
,loc_h
,sent
,sent_nums
):
394 if not p_sent
> 0.0: return
397 for k
in xtween(i
, j
):
398 p_out
= f(i
,j
, x
,loc_h
, sent
)
401 p_L
= e(i
,k
, x
,loc_h
, sent
)
405 for loc_a
,a
in locs(sent_nums
, k
,j
): # k<=loc_l(a)<j
406 p_rule
= g
.p_GO_AT_or0(a
, h
, RIGHT
, adj(k
, loc_h
))
407 p_R
= e(k
,j
, (SEAL
,a
),loc_a
, sent
)
408 p
= p_L
* p_out
* p_R
* p_rule
410 fr
['hat_a','num',a
,x
] += p
412 fr
['hat_a','num',a
,x
] = p
413 # end reest_freq.w_right()
421 p_sent
= inner_sent(g
, sent
, ichart
)
422 fr
['ROOT','den'] += p_sent
424 sent_nums
= g
.sent_nums(sent
)
426 for loc_h
,h
in locs(sent_nums
,0,len(sent
)+1): # locs-stop is exclusive, thus +1
428 fr
['ROOT','num',h
] += g
.p_ROOT
[h
] * e(0,len(sent
), (SEAL
,h
),loc_h
, sent
)
433 # left non-adjacent stop:
434 for i
in xlt(loc_l_h
):
435 fr
['STOP','num',(GOL
,h
),NON
] += c(loc_l_h
, j
, (LGOR
, h
),loc_h
, sent
)
436 fr
['STOP','den',(GOL
,h
),NON
] += c(loc_l_h
, j
, (GOL
, h
),loc_h
, sent
)
437 for j
in xgteq(loc_r_h
, sent
):
438 fr
['STOP','num',(RGOL
,h
),NON
] += c(i
, j
, (SEAL
, h
),loc_h
, sent
)
439 fr
['STOP','den',(RGOL
,h
),NON
] += c(i
, j
, (RGOL
, h
),loc_h
, sent
)
440 # left adjacent stop, i = loc_l_h
441 fr
['STOP','num',(GOL
,h
),ADJ
] += c(loc_l_h
, loc_r_h
, (LGOR
, h
),loc_h
, sent
)
442 fr
['STOP','den',(GOL
,h
),ADJ
] += c(loc_l_h
, loc_r_h
, (GOL
, h
),loc_h
, sent
)
443 for j
in xgteq(loc_r_h
, sent
):
444 fr
['STOP','num',(RGOL
,h
),ADJ
] += c(loc_l_h
, j
, (SEAL
, h
),loc_h
, sent
)
445 fr
['STOP','den',(RGOL
,h
),ADJ
] += c(loc_l_h
, j
, (RGOL
, h
),loc_h
, sent
)
446 # right non-adjacent stop:
447 for j
in xgt(loc_r_h
, sent
):
448 fr
['STOP','num',(GOR
,h
),NON
] += c(loc_l_h
, j
, (RGOL
, h
),loc_h
, sent
)
449 fr
['STOP','den',(GOR
,h
),NON
] += c(loc_l_h
, j
, (GOR
, h
),loc_h
, sent
)
450 for i
in xlteq(loc_l_h
):
451 fr
['STOP','num',(LGOR
,h
),NON
] += c(loc_l_h
, j
, (SEAL
, h
),loc_h
, sent
)
452 fr
['STOP','den',(LGOR
,h
),NON
] += c(loc_l_h
, j
, (LGOR
, h
),loc_h
, sent
)
453 # right adjacent stop, j = loc_r_h
454 fr
['STOP','num',(GOR
,h
),ADJ
] += c(loc_l_h
, loc_r_h
, (RGOL
, h
),loc_h
, sent
)
455 fr
['STOP','den',(GOR
,h
),ADJ
] += c(loc_l_h
, loc_r_h
, (GOR
, h
),loc_h
, sent
)
456 for i
in xlteq(loc_l_h
):
457 fr
['STOP','num',(LGOR
,h
),ADJ
] += c(loc_l_h
, j
, (SEAL
, h
),loc_h
, sent
)
458 fr
['STOP','den',(LGOR
,h
),ADJ
] += c(loc_l_h
, j
, (LGOR
, h
),loc_h
, sent
)
461 if 'REEST_ATTACH' in DEBUG
:
462 print "Lattach %s: for i < %s"%(g
.numtag(h
),sent
[0:loc_h
+1])
463 for s_h
in [RGOL
, GOL
]:
465 for i
in xlt(loc_l_h
): # i < loc_l(h)
466 if 'REEST_ATTACH' in DEBUG
:
467 print "\tfor j >= %s"%sent
[loc_h
:len(sent
)]
468 for j
in xgteq(loc_r_h
, sent
): # j >= loc_r(h)
469 fr
['hat_a','den',x
] += c(i
,j
, x
,loc_h
, sent
) # v_q in L&Y
470 if 'REEST_ATTACH' in DEBUG
:
471 print "\t\tc( %d , %d, %s, %s, sent)=%.4f"%(i
,j
,node_str(x
),loc_h
,fr
['hat_a','den',x
])
472 w_left(i
, j
, x
,loc_h
, sent
,sent_nums
) # compute w for all a in sent
475 if 'REEST_ATTACH' in DEBUG
:
476 print "Rattach %s: for i <= %s"%(g
.numtag(h
),sent
[0:loc_h
+1])
477 for s_h
in [GOR
, LGOR
]:
479 for i
in xlteq(loc_l_h
): # i <= loc_l(h)
480 if 'REEST_ATTACH' in DEBUG
:
481 print "\tfor j > %s"%sent
[loc_h
:len(sent
)]
482 for j
in xgt(loc_r_h
, sent
): # j > loc_r(h)
483 fr
['hat_a','den',x
] += c(i
,j
, x
,loc_h
, sent
) # v_q in L&Y
484 if 'REEST_ATTACH' in DEBUG
:
485 print "\t\tc( %d , %d, %s, %s, sent)=%.4f"%(loc_h
,j
,node_str(x
),loc_h
,fr
['hat_a','den',x
])
486 w_right(loc_l_h
,j
, x
,loc_h
, sent
,sent_nums
) # compute w for all a in sent
493 def reestimate(g
, corpus
):
494 fr
= reest_freq(g
, corpus
)
495 p_ROOT
, p_STOP
, p_ATTACH
= {},{},{}
497 for h
in g
.headnums():
498 reest_head(h
, fr
, g
, p_ROOT
, p_STOP
, p_ATTACH
)
501 g
.p_ATTACH
= p_ATTACH
502 g
.p_GO_AT
= make_GO_AT(p_STOP
,p_ATTACH
)
507 def reest_head(h
, fr
, g
, p_ROOT
, p_STOP
, p_ATTACH
):
508 "Given a single head, update g with the reestimated probability."
509 # remove 0-prob stuff? todo
511 p_ROOT
[h
] = fr
['ROOT','num',h
] / fr
['ROOT','den']
513 p_ROOT
[h
] = fr
['ROOT','den']
515 for dir in [LEFT
,RIGHT
]:
516 for adj
in [ADJ
, NON
]: # p_STOP
517 p_STOP
[h
, dir, adj
] = 0.0
518 for s_h
in dirseal(dir):
520 p
= fr
['STOP','den', x
, adj
]
522 p
= fr
['STOP', 'num', x
, adj
] / p
523 p_STOP
[h
, dir, adj
] += p
525 for s_h
in dirseal(dir): # make hat_a for p_ATTACH
529 p_c
= fr
['hat_a','den',x
]
530 for a
in g
.headnums():
532 hat_a
[a
,x
] = fr
['hat_a','num',a
,x
] / p_c
536 sum_hat_a
= sum([hat_a
[w
,x
] for w
in g
.headnums()
539 for a
in g
.headnums():
540 if (a
,h
,dir) not in p_ATTACH
:
541 p_ATTACH
[a
,h
,dir] = 0.0
542 try: # (a,x) might not be in hat_a
543 p_ATTACH
[a
,h
,dir] += hat_a
[a
,x
] / sum_hat_a
554 ###################################################
555 # Most Probable Parse: #
556 ###################################################
# Special keys for the most-probable-parse tree. They have the same
# 4-tuple layout as the other keys, (i, j, node, loc), but use -1 for every
# position, with STOP / ROOT as the node.
STOPKEY = (-1, -1, STOP, -1)
ROOTKEY = (-1, -1, ROOT, -1)
561 def make_mpptree(g
, sent
):
562 '''Tell inner() to make an mpptree, connect ROOT to this. (Logically,
563 this should be part of inner_sent though...)'''
565 mpptree
= { ROOTKEY
:(0.0, ROOTKEY
, None) }
566 for loc_w
,w
in locs(g
.sent_nums(sent
),0,len(sent
)):
567 p
= g
.p_ROOT
[w
] * inner(0, len(sent
), (SEAL
,w
), loc_w
, g
, sent
, ichart
, mpptree
)
569 R
= (0,len(sent
), (SEAL
,w
), loc_w
)
570 if mpptree
[ROOTKEY
][0] < p
:
571 mpptree
[ROOTKEY
] = (p
, L
, R
)
574 def parse_mpptree(mpptree
, sent
):
575 '''mpptree is a dict of the form {k:(p,L,R),...}; where k, L and R
576 are `keys' of the form (i,j,node,loc).
578 returns an mpp of the form [((head, loc_h),(arg, loc_a)), ...],
579 where head and arg are tags.'''
580 # local functions for clear access to mpptree:
584 return POS(k_node(key
))
586 return seals(k_node(key
))
588 return (k_node(key
),key
[3])
590 return (k_POS(key
),key
[3])
592 s_k
= k_seals(key
) # i+1 == j
593 return key
[0] + 1 == key
[1] and (s_k
== GOR
or s_k
== GOL
)
599 # arbitrarily, "ROOT attaches to right". We add it here to
600 # avoid further complications:
601 firstkey
= t_R(mpptree
[ROOTKEY
])
602 deps
= set([ (k_locPOS(ROOTKEY
), k_locPOS(firstkey
), RIGHT
) ])
610 L
= t_L( mpptree
[k
] )
611 R
= t_R( mpptree
[k
] )
612 if k_locnode( k
) == k_locnode( L
): # Rattach
613 deps
.add((k_locPOS( k
), k_locPOS( R
), LEFT
))
615 elif k_locnode( k
) == k_locnode( R
): # Lattach
616 deps
.add((k_locPOS( k
), k_locPOS( L
), RIGHT
))
625 tagf
= g
.numtag
# localized function, todo: speed-test
626 mpptree
= make_mpptree(g
, sent
)
627 return set([((tagf(h
), loc_h
), (tagf(a
), loc_a
))
628 for (h
, loc_h
),(a
,loc_a
),dir in parse_mpptree(mpptree
,sent
)])
631 ########################################################################
632 # testing functions: #
633 ########################################################################
# Small corpus of tag sequences used by the test functions in this module;
# each sentence string is split on whitespace into a list of tags.
testcorpus = [s.split() for s in ['det nn vbd c vbd', 'vbd nn c vbd',
                                  'det nn vbd', 'det nn vbd c pp',
                                  'det nn vbd', 'det vbd vbd c pp',
                                  'det nn vbd', 'det nn vbd c vbd',
                                  'det nn vbd', 'det nn vbd c vbd',
                                  'det nn vbd', 'det nn vbd c vbd',
                                  'det nn vbd', 'det nn vbd c pp',
                                  'det nn vbd pp', 'det nn vbd', ]]
645 import loc_h_harmonic
646 reload(loc_h_harmonic
)
648 # make sure these are the way they were when setting up the tests:
649 loc_h_harmonic
.HARMONIC_C
= 0.0
650 loc_h_harmonic
.FNONSTOP_MIN
= 25
651 loc_h_harmonic
.FSTOP_MIN
= 5
652 loc_h_harmonic
.RIGHT_FIRST
= 1.0
654 return loc_h_harmonic
.initialize(testcorpus
)
def ig(s, t, LHS, loc_h):
    """Shorthand for inner() on the fixed sentence 'det nn vbd'.

    Calls inner(s, t, LHS, loc_h, ...) with a grammar freshly obtained from
    testgrammar() and a new, empty chart, and returns its result.
    """
    return inner(s, t, LHS, loc_h, testgrammar(), 'det nn vbd'.split(), {})
659 def testreestimation():
662 # DEBUG.add('REEST_ATTACH')
663 f
= reestimate(g
, testcorpus
)
665 testreestimation_regression(f
)
668 def testreestimation_regression(fr
):
669 f_stops
= {('STOP', 'den', (RGOL
,3),NON
): 12.212773236178391, ('STOP', 'den', (GOR
,2),ADJ
): 4.0, ('STOP', 'num', (GOR
,4),NON
): 2.5553487221351365, ('STOP', 'den', (RGOL
,2),NON
): 1.274904052793207, ('STOP', 'num', (RGOL
,1),ADJ
): 14.999999999999995, ('STOP', 'den', (GOR
,3),ADJ
): 15.0, ('STOP', 'num', (RGOL
,4),ADJ
): 16.65701084787457, ('STOP', 'num', (RGOL
,0),ADJ
): 4.1600647714443468, ('STOP', 'den', (RGOL
,4),NON
): 6.0170669155897105, ('STOP', 'num', (RGOL
,3),ADJ
): 2.7872267638216113, ('STOP', 'num', (RGOL
,2),ADJ
): 2.9723139990470515, ('STOP', 'den', (RGOL
,2),ADJ
): 4.0, ('STOP', 'den', (GOR
,3),NON
): 12.945787931730905, ('STOP', 'den', (RGOL
,3),ADJ
): 14.999999999999996, ('STOP', 'den', (GOR
,2),NON
): 0.0, ('STOP', 'den', (RGOL
,0),ADJ
): 8.0, ('STOP', 'num', (GOR
,4),ADJ
): 19.44465127786486, ('STOP', 'den', (GOR
,1),NON
): 3.1966410324085777, ('STOP', 'den', (RGOL
,1),ADJ
): 14.999999999999995, ('STOP', 'num', (GOR
,3),ADJ
): 4.1061665495365558, ('STOP', 'den', (GOR
,0),NON
): 4.8282499043902476, ('STOP', 'num', (RGOL
,4),NON
): 5.3429891521254289, ('STOP', 'num', (GOR
,2),ADJ
): 4.0, ('STOP', 'den', (RGOL
,4),ADJ
): 22.0, ('STOP', 'num', (GOR
,1),ADJ
): 12.400273895299103, ('STOP', 'num', (RGOL
,2),NON
): 1.0276860009529487, ('STOP', 'num', (GOR
,0),ADJ
): 3.1717500956097533, ('STOP', 'num', (RGOL
,3),NON
): 12.212773236178391, ('STOP', 'den', (GOR
,4),ADJ
): 22.0, ('STOP', 'den', (GOR
,4),NON
): 2.8705211946979836, ('STOP', 'num', (RGOL
,0),NON
): 3.8399352285556518, ('STOP', 'num', (RGOL
,1),NON
): 0.0, ('STOP', 'num', (GOR
,0),NON
): 4.8282499043902476, ('STOP', 'num', (GOR
,1),NON
): 2.5997261047008959, ('STOP', 'den', (RGOL
,1),NON
): 0.0, ('STOP', 'den', (GOR
,0),ADJ
): 8.0, ('STOP', 'num', (GOR
,2),NON
): 0.0, ('STOP', 'den', (RGOL
,0),NON
): 4.6540557322109795, ('STOP', 'den', (GOR
,1),ADJ
): 15.0, ('STOP', 'num', (GOR
,3),NON
): 10.893833450463443}
670 for k
,v
in f_stops
.iteritems():
672 print '''Regression in P_STOP reestimation, should be fr[%s]=%.4f,
673 but %s not in fr'''%(k
,v
,k
)
674 elif not "%.10f"%fr[k
] == "%.10f"%v
:
675 print '''Regression in P_STOP reestimation, should be fr[%s]=%.4f,
676 got fr[%s]=%.4f.'''%(k
,v
,k
,fr
[k
])
678 def testmpp_regression(mpptree
,k_n
):
679 mpp
= {ROOTKEY
: (2.877072116829971e-05, STOPKEY
, (0, 3, (2, 3), 1)),
680 (0, 1, (1, 1), 0): (0.1111111111111111, (0, 1, (0, 1), 0), STOPKEY
),
681 (0, 1, (2, 1), 0): (0.049382716049382713, STOPKEY
, (0, 1, (1, 1), 0)),
682 (0, 3, (1, 3), 1): (0.00027619892321567721,
685 (0, 3, (2, 3), 1): (0.00012275507698474543, STOPKEY
, (0, 3, (1, 3), 1)),
686 (1, 3, (0, 3), 1): (0.025280986819448362,
689 (1, 3, (1, 3), 1): (0.0067415964851862296, (1, 3, (0, 3), 1), STOPKEY
),
690 (2, 3, (1, 4), 2): (0.32692307692307693, (2, 3, (0, 4), 2), STOPKEY
),
691 (2, 3, (2, 4), 2): (0.037721893491124266, STOPKEY
, (2, 3, (1, 4), 2))}
692 for k
,(v
,L
,R
) in mpp
.iteritems():
693 k2
= k
[0:k_n
] # 3 if the new does not check loc_h
696 if k2
not in mpptree
:
697 print "mpp regression, %s missing"%(k2
,)
699 vnew
= mpptree
[k2
][0]
700 if not "%.10f"%vnew
== "%.10f"%v
:
701 print "mpp regression, wanted %s=%.5f, got %.5f"%(k2
,v
,vnew
)
706 p_ROOT
, p_STOP
, p_ATTACH
, p_ORDER
= {},{},{},{}
709 p_STOP
[h
,LEFT
,NON
] = 1.0
710 p_STOP
[h
,LEFT
,ADJ
] = 1.0
711 p_STOP
[h
,RIGHT
,NON
] = 0.4 # RSTOP
712 p_STOP
[h
,RIGHT
,ADJ
] = 0.3 # RSTOP
713 p_STOP
[a
,LEFT
,NON
] = 1.0
714 p_STOP
[a
,LEFT
,ADJ
] = 1.0
715 p_STOP
[a
,RIGHT
,NON
] = 0.4 # RSTOP
716 p_STOP
[a
,RIGHT
,ADJ
] = 0.3 # RSTOP
717 p_ATTACH
[a
,h
,LEFT
] = 1.0 # not used
718 p_ATTACH
[a
,h
,RIGHT
] = 1.0 # not used
719 p_ATTACH
[h
,a
,LEFT
] = 1.0 # not used
720 p_ATTACH
[h
,a
,RIGHT
] = 1.0 # not used
721 p_ATTACH
[h
,h
,LEFT
] = 1.0 # not used
722 p_ATTACH
[h
,h
,RIGHT
] = 1.0 # not used
723 p_ORDER
[(GOR
, h
)] = 1.0
724 p_ORDER
[(GOL
, h
)] = 0.0
725 p_ORDER
[(GOR
, a
)] = 1.0
726 p_ORDER
[(GOL
, a
)] = 0.0
727 g
= DMV_Grammar({h
:'h',a
:'a'}, {'h':h
,'a':a
}, p_ROOT
, p_STOP
, p_ATTACH
, p_ORDER
)
728 # these probabilities are impossible so add them manually:
729 g
.p_GO_AT
[a
,a
,LEFT
,NON
] = 0.4 # Lattach
730 g
.p_GO_AT
[a
,a
,LEFT
,ADJ
] = 0.6 # Lattach
731 g
.p_GO_AT
[h
,a
,LEFT
,NON
] = 0.2 # Lattach to h
732 g
.p_GO_AT
[h
,a
,LEFT
,ADJ
] = 0.1 # Lattach to h
733 g
.p_GO_AT
[a
,a
,RIGHT
,NON
] = 1.0 # Rattach
734 g
.p_GO_AT
[a
,a
,RIGHT
,ADJ
] = 1.0 # Rattach
735 g
.p_GO_AT
[h
,a
,RIGHT
,NON
] = 1.0 # Rattach to h
736 g
.p_GO_AT
[h
,a
,RIGHT
,ADJ
] = 1.0 # Rattach to h
737 g
.p_GO_AT
[h
,h
,LEFT
,NON
] = 0.2 # Lattach
738 g
.p_GO_AT
[h
,h
,LEFT
,ADJ
] = 0.1 # Lattach
739 g
.p_GO_AT
[a
,h
,LEFT
,NON
] = 0.4 # Lattach to a
740 g
.p_GO_AT
[a
,h
,LEFT
,ADJ
] = 0.6 # Lattach to a
741 g
.p_GO_AT
[h
,h
,RIGHT
,NON
] = 1.0 # Rattach
742 g
.p_GO_AT
[h
,h
,RIGHT
,ADJ
] = 1.0 # Rattach
743 g
.p_GO_AT
[a
,h
,RIGHT
,NON
] = 1.0 # Rattach to a
744 g
.p_GO_AT
[a
,h
,RIGHT
,ADJ
] = 1.0 # Rattach to a
750 p_ROOT
, p_STOP
, p_ATTACH
, p_ORDER
= {},{},{},{}
752 p_STOP
[h
,LEFT
,NON
] = 1.0
753 p_STOP
[h
,LEFT
,ADJ
] = 1.0
754 p_STOP
[h
,RIGHT
,NON
] = 0.4
755 p_STOP
[h
,RIGHT
,ADJ
] = 0.3
756 p_ATTACH
[h
,h
,LEFT
] = 1.0 # not used
757 p_ATTACH
[h
,h
,RIGHT
] = 1.0 # not used
758 p_ORDER
[(GOR
, h
)] = 1.0
759 p_ORDER
[(GOL
, h
)] = 0.0
760 g
= DMV_Grammar({h
:'h'}, {'h':h
}, p_ROOT
, p_STOP
, p_ATTACH
, p_ORDER
)
761 g
.p_GO_AT
[h
,h
,LEFT
,NON
] = 0.6 # these probabilities are impossible
762 g
.p_GO_AT
[h
,h
,LEFT
,ADJ
] = 0.7 # so add them manually...
763 g
.p_GO_AT
[h
,h
,RIGHT
,NON
] = 1.0
764 g
.p_GO_AT
[h
,h
,RIGHT
,ADJ
] = 1.0
769 def testreestimation_h():
772 reestimate(g
,['h h h'.split()])
def test(wanted, got):
    """Raise Warning unless `wanted == got`.

    wanted -- the expected value
    got    -- the value actually produced

    Returns None when the two compare equal; otherwise raises Warning with
    a message naming both values.
    """
    if not wanted == got:
        # Call form instead of `raise Warning, "..."`: the same exception
        # and message, but also valid syntax on Python 3.
        raise Warning("Regression! Should be %s: %s" % (wanted, got))


# This is an assertion helper, not a test case; the flag keeps nose/pytest
# from trying to collect it (its two arguments would be taken for fixtures).
test.__test__ = False
779 def regression_tests():
780 testmpp_regression(make_mpptree(testgrammar(), testcorpus
[2]),4)
784 "%.3f" % inner(0, 2, (SEAL
,h
), 0, testgrammar_h(), 'h h'.split(),{}))
786 "%.3f" % inner(0, 2, (SEAL
,h
), 1, testgrammar_h(), 'h h'.split(),{}))
788 "%.4f" % inner_sent(testgrammar_h(), 'h h h'.split()))
791 "%.4f" % inner(0, 3, (SEAL
,0), 0, testgrammar_h(), 'h h h'.split(),{}))
793 "%.4f" % inner(0, 3, (SEAL
,0), 1, testgrammar_h(), 'h h h'.split(),{}))
795 "%.4f" % inner(0, 3, (SEAL
,h
), 2, testgrammar_h(), 'h h h'.split(),{}))
798 "%.2f" % outer(1, 3, (RGOL
,h
), 2, testgrammar_h(),'h h h'.split(),{},{}))
799 test("0.61" , # ftw? can't be right... there's an 0.4 shared between these two...
800 "%.2f" % outer(1, 3, (RGOL
,h
), 1, testgrammar_h(),'h h h'.split(),{},{}))
803 "%.2f" % outer(1, 3, (RGOL
,h
), 0, testgrammar_h(),'h h h'.split(),{},{}))
805 "%.2f" % outer(1, 3, (RGOL
,h
), 3, testgrammar_h(),'h h h'.split(),{},{}))
808 "%.4f" % outer(0, 1, (GOR
,h
), 0,testgrammar_a(),'h a'.split(),{},{}))
810 "%.4f" % outer(0, 2, (GOR
,h
), 0,testgrammar_a(),'h a'.split(),{},{}))
812 "%.4f" % outer(0, 3, (GOR
,h
), 0,testgrammar_a(),'h a'.split(),{},{}))
814 # todo: add more of these tests...
816 if __name__
== "__main__":
820 # profile.run('testreestimation()')
823 # print timeit.Timer("loc_h_dmv.testreestimation()",'''import loc_h_dmv
824 # reload(loc_h_dmv)''').timeit(1)
830 # for s in testcorpus:
831 # print "sent:%s\nparse:set(\n%s)"%(s,pprint.pformat(list(mpp(testgrammar(), s)),
836 # pprint.pprint( testreestimation())
842 inners
= [(sent
, inner_sent(g
, sent
, {})) for sent
in testcorpus
]