another little fix to stop initialization
[dmvccm.git] / src / before_betweens_dmv.py
blob38692aa669322dc6774bb24f1beb3aba421ff783
1 # before_betweens_dmv.py
2 #
3 # dmv reestimation and inside-outside probabilities using loc_h, but
4 # at-word sentence locations
6 #import numpy # numpy provides Fast Arrays, for future optimization
7 import io
8 from common_dmv import *
if "__main__" == __name__:
    # Only announce the tests when run as a script, never on import.
    print("before_betweens_dmv module tests:")
class DMV_Grammar(io.Grammar):
    '''The DMV-PCFG.

    Public members:
      p_STOP, p_ROOT, p_CHOOSE, p_terminals
    These are changed in the Maximation step, then used to set the
    new probabilities of each DMV_Rule.

    Todo: make p_terminals private? (But it has to be changable in
    maximation step due to the short-cutting rules... could of course
    make a DMV_Grammar function to update the short-cut rules...)

    __p_rules is private, but we can still say stuff like:
      for r in g.all_rules():
          r.probN = newProbN

    What other representations do we need? (P_STOP formula uses
    deps_D(h,l/r) at least)'''
    def __str__(self):
        "One line per rule, tags rendered through self.numtag."
        return "".join(["%s\n" % r.__str__(self.numtag)
                        for r in self.all_rules()])

    def h_rules(self, h):
        "All rules whose LHS has POS h."
        return [r for r in self.all_rules() if r.POS() == h]

    def mothersL(self, Node, sent_nums, loc_N):
        '''Rules having Node as left child, whose right child is STOP or
        has a POS occurring after position loc_N in sent_nums.'''
        # todo: speed-test with and without sent_nums/loc_N cut-off
        return [r for r in self.all_rules() if r.L() == Node
                and (POS(r.R()) in sent_nums[loc_N+1:] or r.R() == STOP)]

    def mothersR(self, Node, sent_nums, loc_N):
        '''Rules having Node as right child, whose left child is STOP or
        has a POS occurring before position loc_N in sent_nums.'''
        return [r for r in self.all_rules() if r.R() == Node
                and (POS(r.L()) in sent_nums[:loc_N] or r.L() == STOP)]

    def rules(self, LHS):
        "All rules rewriting LHS."
        return [r for r in self.all_rules() if r.LHS() == LHS]

    def sent_rules(self, LHS, sent_nums):
        '''Rules rewriting LHS whose children's POS tags all occur in
        sent_nums (STOP always counts as occurring). Used in dmv.inner.
        Todo: this takes a _lot_ of time, it seems. Could use some more
        space and cache some of this somehow perhaps?'''
        # We don't want to rule out STOPs! A set gives O(1) membership
        # tests instead of scanning the sentence once per rule child.
        nums = set(sent_nums)
        nums.add(POS(STOP))
        return [r for r in self.all_rules() if r.LHS() == LHS
                and POS(r.L()) in nums and POS(r.R()) in nums]

    def deps_L(self, head): # todo: do I use this at all?
        '''Left children of every rule whose LHS has POS head (this
        includes STOP and the head's own nodes, not only arguments).'''
        # The previous comprehension compared against an unbound name `a`
        # and therefore always raised NameError.
        return [r.L() for r in self.all_rules() if r.POS() == head]

    def deps_R(self, head):
        '''Right children of every rule whose LHS has POS head (this
        includes STOP and the head's own nodes, not only arguments).'''
        return [r.R() for r in self.all_rules() if r.POS() == head]

    def __init__(self, numtag, tagnum, p_rules, p_terminals, p_STOP, p_CHOOSE, p_ROOT):
        io.Grammar.__init__(self, numtag, tagnum, p_rules, p_terminals)
        self.p_STOP = p_STOP
        self.p_CHOOSE = p_CHOOSE
        self.p_ROOT = p_ROOT
        # every tag number known to the grammar (the keys of numtag)
        self.head_nums = list(numtag)
class DMV_Rule(io.CNF_Rule):
    '''A single CNF rule in the PCFG, of the form
      LHS -> L R
    where LHS, L and R are 'nodes', eg. of the form (seals, head).

    Public members:
      probN, probA

    Private members:
      __L, __R, __LHS

    Different rule-types have different probabilities associated with
    them:

     _h_ -> STOP  h_     P( STOP|h,L,    adj)
     _h_ -> STOP  h_     P( STOP|h,L,non_adj)
      h_ ->  h  STOP     P( STOP|h,R,    adj)
      h_ ->  h  STOP     P( STOP|h,R,non_adj)
      h_ -> _a_   h_     P(-STOP|h,L,    adj) * P(a|h,L)
      h_ -> _a_   h_     P(-STOP|h,L,non_adj) * P(a|h,L)
      h  ->  h   _a_     P(-STOP|h,R,    adj) * P(a|h,R)
      h  ->  h   _a_     P(-STOP|h,R,non_adj) * P(a|h,R)
    '''
    def p(self, adj, *arg):
        "probA if adj is true, else probN; extra arguments are ignored."
        if adj:
            return self.probA
        else:
            return self.probN

    @staticmethod
    def adj(middle, loc_h):
        '''middle is eg. k when rewriting for i<k<j (inside probabilities).
        NOTE(review): loc_h is indexed as a pair here, while the rest of
        this module passes loc_h as a single int -- confirm before use.'''
        # staticmethod: there is no self parameter, so as an ordinary method
        # the rule object itself would have been bound to `middle`.
        return middle == loc_h[0] or middle == loc_h[1]

    def p_STOP(self, s, t, loc_h):
        '''Returns the correct probability, adjacent if we're rewriting from
        the (either left or right) end of the fragment.

        Left STOP (L is STOP): adjacent when the head is at s.
        Right STOP (R is STOP): the span must start at the head
        (s == loc_h, otherwise 0.0); adjacent when the head is at t.
        Returns None if neither child is STOP.'''
        if self.L() == STOP:
            return self.p(s == loc_h)
        elif self.R() == STOP:
            if not loc_h == s:
                if 'TODO' in DEBUG:
                    print("(%s given loc_h:%d but s:%d. Todo: optimize away!)" % (self, loc_h, s))
                return 0.0
            else:
                return self.p(t == loc_h)

    def p_ATTACH(self, r, loc_h, s=None):
        '''Returns the correct probability, adjacent if we haven't attached
        anything before.
        (This is actually p_choose*(1-p_stop).)

        r is the split point as used by inner(): the left child ends at r
        and the right child starts at r+1.
        Right attachment (LHS == L): when s is given the left child must
        start at the head (else 0.0); adjacent when r == loc_h.
        Left attachment (LHS == R): adjacent when r+1 == loc_h.
        Returns None for any other rule.'''
        if self.LHS() == self.L():
            # `is not None`: s == 0 is a real start position and must be
            # checked like any other (a plain truth test skipped it).
            if s is not None and not loc_h == s:
                if 'TODO' in DEBUG:
                    print("(%s given loc_h (loc_L):%d but s:%d. Todo: optimize away!)" % (self, loc_h, s))
                return 0.0
            else:
                return self.p(r == loc_h)
        elif self.LHS() == self.R():
            return self.p(r+1 == loc_h)

    def seals(self):
        "Seal state of the LHS node."
        return seals(self.LHS())

    def POS(self):
        "POS tag of the LHS node."
        return POS(self.LHS())

    def __init__(self, LHS, L, R, probN, probA):
        for b_h in [LHS, L, R]:
            if seals(b_h) not in SEALS:
                raise ValueError("seals must be in %s; was given: %s"
                                 % (SEALS, seals(b_h)))
        io.CNF_Rule.__init__(self, LHS, L, R, probN)
        self.probA = probA # adjacent
        self.probN = probN # non_adj

    @classmethod # so we can call DMV_Rule.bar_str(b_h)
    def bar_str(cls, b_h, tag=lambda x:x):
        "Printable form of node b_h; tag maps POS numbers to names."
        if(b_h == ROOT):
            return 'ROOT'
        elif(b_h == STOP):
            return 'STOP'
        elif(seals(b_h) == RGOL):
            return " %s_ " % tag(POS(b_h))
        elif(seals(b_h) == SEAL):
            return "_%s_ " % tag(POS(b_h))
        else:
            return " %s  " % tag(POS(b_h))

    def __str__(self, tag=lambda x:x):
        return "%s-->%s %s\t[N %.2f] [A %.2f]" % (self.bar_str(self.LHS(), tag),
                                                  self.bar_str(self.L(), tag),
                                                  self.bar_str(self.R(), tag),
                                                  self.probN,
                                                  self.probA)
###################################
# dmv-specific version of inner() #
###################################
def locs(h, sent, s=0, t=None, remove=None):
    '''Return the locations of h in sent, or some fragment of sent (in the
    latter case we make sure to offset the locations correctly so that
    for any x in the returned list, sent[x]==h).

    t is inclusive, to match the way indices work with inner()
    (although python list-splicing has "exclusive" end indices).
    t defaults to the last index of sent. The location remove (if
    given) is never included in the result.'''
    if t is None:
        t = len(sent)-1
    return [i+s for i, w in enumerate(sent[s:t+1])
            if w == h and not (i+s) == remove]
def inner(s, t, LHS, loc_h, g, sent, ichart=None):
    ''' A rewrite of io.inner(), to take adjacency into accord.

    The ichart is now of this form:
      ichart[s,t,LHS, loc_h]

    loc_h gives adjacency (along with r and location of other child
    for attachment rules), and is needed in P_STOP reestimation.

    s and t are inclusive indices into sent. Pass the same ichart to
    several calls on one sentence and grammar to share results; when it
    is omitted a fresh chart is used. A chart must not be shared between
    different sentences or grammars, since its keys name neither.

    Todo: if possible, refactor (move dmv-specific stuff back into
    dmv, so this is "general" enough to be in io.py)
    '''
    if ichart is None:
        # A `{}` default would be one dict shared by every call.
        ichart = {}

    def O(s):
        return sent[s]

    sent_nums = g.sent_nums(sent)

    def e(s, t, LHS, loc_h, n_t):
        def tab():
            "Tabs for debug output"
            return "\t"*n_t

        if (s, t, LHS, loc_h) in ichart:
            if 'INNER' in DEBUG:
                print("%s*= %.4f in ichart: s:%d t:%d LHS:%s loc:%d" % (tab(), ichart[s, t, LHS, loc_h], s, t,
                                                                       DMV_Rule.bar_str(LHS), loc_h))
            return ichart[s, t, LHS, loc_h]

        if s == t and seals(LHS) == GOR:
            # a single-word GOR node rewrites to the terminal at s, which
            # must be the head itself
            if not loc_h == s:
                if 'INNER' in DEBUG:
                    print("%s*= 0.0 (wrong loc_h)" % tab())
                return 0.0
            elif (LHS, O(s)) in g.p_terminals:
                prob = g.p_terminals[LHS, O(s)] # "b[LHS, O(s)]" in Lari&Young
            else:
                # todo: assuming this is how to deal w/lacking
                # rules, since we add prob.s, and 0 is identity
                prob = 0.0
                if 'INNER' in DEBUG:
                    print("%sLACKING TERMINAL:" % tab())
            # todo: add to ichart perhaps? Although, it _is_ simple lookup..
            if 'INNER' in DEBUG:
                print("%s*= %.4f (terminal: %s -> %s_%d)" % (tab(), prob, DMV_Rule.bar_str(LHS), O(s), loc_h))
            return prob

        p = 0.0 # "sum over j,k in a[LHS,j,k]"
        for rule in g.sent_rules(LHS, sent_nums):
            if 'INNER' in DEBUG:
                print("%ssumming rule %s s:%d t:%d loc:%d" % (tab(), rule, s, t, loc_h))
            L = rule.L()
            R = rule.R()
            # the head child cannot be empty on the side where the head sits
            if loc_h == t and LHS == L:
                continue # todo: speed-test
            if loc_h == s and LHS == R:
                continue
            # if it's a STOP rule, rewrite for the same xrange:
            if (L == STOP) or (R == STOP):
                if L == STOP:
                    pLR = e(s, t, R, loc_h, n_t+1)
                elif R == STOP:
                    pLR = e(s, t, L, loc_h, n_t+1)
                p += rule.p_STOP(s, t, loc_h) * pLR
                if 'INNER' in DEBUG:
                    print("%sp= %.4f (STOP)" % (tab(), p))

            elif t > s: # not a STOP, attachment rewrite:
                rp_ATTACH = rule.p_ATTACH # todo: profile/speedtest
                for r in range(s, t):
                    # left child spans s..r, right child r+1..t
                    p_h = rp_ATTACH(r, loc_h, s=s)
                    if LHS == L:
                        locs_L = [loc_h]
                        locs_R = locs(POS(R), sent_nums, r+1, t, loc_h)
                    elif LHS == R:
                        locs_L = locs(POS(L), sent_nums, s, r, loc_h)
                        locs_R = [loc_h]
                    for loc_L in locs_L:
                        pL = e(s, r, L, loc_L, n_t+1)
                        if pL > 0.0:
                            for loc_R in locs_R:
                                pR = e(r+1, t, R, loc_R, n_t+1)
                                p += pL * p_h * pR
                if 'INNER' in DEBUG:
                    print("%sp= %.4f (ATTACH)" % (tab(), p))
        ichart[s, t, LHS, loc_h] = p
        return p
    # end of e-function

    inner_prob = e(s, t, LHS, loc_h, 0)
    if 'INNER' in DEBUG:
        print(debug_ichart(g, sent, ichart))
    return inner_prob
# end of dmv.inner(s, t, LHS, loc_h, g, sent, ichart=None)
def debug_ichart(g, sent, ichart):
    '''Return a printable dump of ichart: one line per entry showing the
    node, the words at both ends of its span, loc_h and the probability.
    Entries whose value is a dict are skipped.'''
    lines = ["---ICHART:---\n"]
    for (s, t, LHS, loc_h), v in ichart.items():
        if isinstance(v, dict): # skip 'tree'
            continue
        # the end of the span is word t (this used to print sent[s] twice)
        lines.append("%s -> %s_%d ... %s_%d (loc_h:%s):\t%.4f\n" % (DMV_Rule.bar_str(LHS, g.numtag),
                                                                    sent[s], s, sent[t], t, loc_h, v))
    lines.append("---ICHART:end---\n")
    return "".join(lines)
def inner_sent(g, sent, ichart=None):
    '''Probability of sent under g: the inside probability of ROOT over
    the whole sentence, summed over every possible head location.
    ichart is shared by all the inner() calls; a fresh one is used when
    omitted (a `{}` default would be shared by every call).'''
    if ichart is None:
        ichart = {}
    return sum([inner(0, len(sent)-1, ROOT, loc_h, g, sent, ichart)
                for loc_h in range(len(sent))])
###################################
# dmv-specific version of outer() #
###################################
def outer(s, t, Node, loc_N, g, sent, ichart=None, ochart=None):
    ''' http://www.student.uib.no/~kun041/dmvccm/DMVCCM.html#outer

    Outside probability of Node spanning s..t (inclusive) with its head
    at loc_N, in sent under grammar g. ichart/ochart cache inside/outside
    values for this one sentence and grammar; fresh charts are used when
    they are omitted (`{}` defaults would be shared by every call and
    return values computed for other sentences).'''
    if ichart is None:
        ichart = {}
    if ochart is None:
        ochart = {}

    def e(s, t, LHS, loc_h):
        # or we could just look it up in ichart, assuming ichart to be done
        return inner(s, t, LHS, loc_h, g, sent, ichart)

    T = len(sent)-1
    sent_nums = g.sent_nums(sent)

    def f(s, t, Node, loc_N):
        if (s, t, Node, loc_N) in ochart:
            return ochart[(s, t, Node, loc_N)]
        if Node == ROOT:
            if s == 0 and t == T:
                return 1.0
            else: # ROOT may only be used on full sentence
                return 0.0 # but we may have non-ROOTs over full sentence too
        p = 0.0

        for mom in g.mothersL(Node, sent_nums, loc_N): # mom.L() == Node
            R = mom.R()
            mLHS = mom.LHS()
            if R == STOP:
                p += f(s, t, mLHS, loc_N) * mom.p_STOP(s, t, loc_N) # == loc_m
            else:
                if seals(mLHS) == RGOL: # left attachment, POS(mLHS) == POS(R)
                    for r in range(t+1, T+1): # t+1 to lasT
                        for loc_m in locs(POS(mLHS), sent_nums, t+1, r):
                            p_m = mom.p(t+1 == loc_m)
                            p += f(s, r, mLHS, loc_m) * p_m * e(t+1, r, R, loc_m)
                elif seals(mLHS) == GOR: # right attachment, POS(mLHS) == POS(Node)
                    loc_m = loc_N
                    p_m = mom.p(t == loc_m)
                    for r in range(t+1, T+1): # t+1 to lasT
                        for loc_R in locs(POS(R), sent_nums, t+1, r):
                            p += f(s, r, mLHS, loc_m) * p_m * e(t+1, r, R, loc_R)

        for mom in g.mothersR(Node, sent_nums, loc_N): # mom.R() == Node
            L = mom.L()
            mLHS = mom.LHS()
            if L == STOP:
                p += f(s, t, mLHS, loc_N) * mom.p_STOP(s, t, loc_N) # == loc_m
            else:
                if seals(mLHS) == RGOL: # left attachment, POS(mLHS) == POS(Node)
                    loc_m = loc_N
                    p_m = mom.p(s == loc_m)
                    for r in range(0, s): # first to s-1
                        for loc_L in locs(POS(L), sent_nums, r, s-1):
                            p += e(r, s-1, L, loc_L) * p_m * f(r, t, mLHS, loc_m)
                elif seals(mLHS) == GOR: # right attachment, POS(mLHS) == POS(L)
                    for r in range(0, s): # first to s-1
                        for loc_m in locs(POS(mLHS), sent_nums, r, s-1):
                            p_m = mom.p(s-1 == loc_m)
                            p += e(r, s-1, L, loc_m) * p_m * f(r, t, mLHS, loc_m)
        ochart[s, t, Node, loc_N] = p
        return p

    return f(s, t, Node, loc_N)
# end outer(s,t,Node,loc_N, g,sent, ichart,ochart)
##############################
#     reestimation, todo:    #
##############################
## using local version instead
# def c(s,t,LHS,loc_h,g,sent,ichart={},ochart={}):
#     # assuming P_sent = P(D(ROOT)) = inner(sent). todo: check K&M about this
#     p_sent = inner_sent(g, sent, ichart)
#     p_in = inner(s,t,LHS,loc_h,g,sent,ichart)
#     p_out = outer(s,t,LHS,loc_h,g,sent,ichart,ochart)
#     if p_sent > 0.0:
#         return p_in * p_out / p_sent
#     else:
#         return p_sent

def reest_zeros(h_nums):
    '''Fresh counts dict for reest_freq: every head h in h_nums gets a 0.0
    numerator and denominator for each of LNSTOP, LASTOP, RNSTOP and
    RASTOP, plus a 0.0 denominator for RCHOOSE and LCHOOSE.'''
    # todo: p_ROOT? ... p_terminals?
    per_head = [(stop, nd)
                for stop in ('LNSTOP', 'LASTOP', 'RNSTOP', 'RASTOP')
                for nd in ('num', 'den')]
    per_head += [('RCHOOSE', 'den'), ('LCHOOSE', 'den')]
    return dict(((kind, nd, h), 0.0)
                for h in h_nums
                for (kind, nd) in per_head)
def reest_freq(g, corpus):
    '''Collect expected counts over corpus under grammar g for
    reestimating P_STOP and P_CHOOSE.

    Returns a dict f (initialised by reest_zeros) with keys
      (XNSTOP|XASTOP, 'num'|'den', h)   for X in L, R
      (XCHOOSE, 'den', h)
      (XCHOOSE, 'num', h, a)            only for (h, a) pairs seen attached
    Every count is summed over all t and all sentences. Sentences with
    probability 0 add nothing to the CHOOSE numerators.

    P_STOP(-STOP|...) = 1 - P_STOP(STOP|...) '''
    f = reest_zeros(g.head_nums)
    ichart = {}
    ochart = {}

    p_sent = None # 50 % speed increase on storing this locally

    def c_g(s, t, LHS, loc_h, sent): # altogether 2x faster than the global c()
        "Expected count of LHS over s..t with head at loc_h in sent."
        if (s, t, LHS, loc_h) in ichart:
            p_in = ichart[s, t, LHS, loc_h]
        else:
            p_in = inner(s, t, LHS, loc_h, g, sent, ichart)
        if (s, t, LHS, loc_h) in ochart:
            p_out = ochart[s, t, LHS, loc_h]
        else:
            p_out = outer(s, t, LHS, loc_h, g, sent, ichart, ochart)

        if p_sent > 0.0:
            return p_in * p_out / p_sent
        else:
            return p_sent

    def f_g(s, t, LHS, loc_h, sent): # todo: test with choose rules
        "Outside probability, from ochart when already computed."
        if (s, t, LHS, loc_h) in ochart:
            return ochart[s, t, LHS, loc_h]
        else:
            return outer(s, t, LHS, loc_h, g, sent, ichart, ochart)

    def e_g(s, t, LHS, loc_h, sent): # todo: test with choose rules
        "Inside probability, from ichart when already computed."
        if (s, t, LHS, loc_h) in ichart:
            return ichart[s, t, LHS, loc_h]
        else:
            return inner(s, t, LHS, loc_h, g, sent, ichart)

    def p_g(r, LHS, L, R, loc_h, sent_nums):
        '''p_ATTACH(r, loc_h) of the rule LHS -> L R, where r is the last
        index of the left child. 0.0 if the grammar lacks the rule, which
        matches inner(), where only existing rules contribute.'''
        rules = [rule for rule in g.sent_rules(LHS, sent_nums)
                 if rule.L() == L and rule.R() == R]
        if len(rules) > 1:
            raise Exception("Several rules matching a[i,j,k]")
        if not rules:
            return 0.0
        return rules[0].p_ATTACH(r, loc_h)

    def add_choose_num(choice, h, args):
        '''Accumulate the summed w's of args (argument tag a -> sum) into
        f[choice,'num',h,a], scaled by 1/P(sent).'''
        if p_sent > 0.0:
            for a, sum_a in args.items():
                key = (choice, 'num', h, a)
                f[key] = f.get(key, 0.0) + sum_a / p_sent

    for sent in corpus:
        if 'reest' in DEBUG:
            print(sent)
        ichart = {}
        ochart = {}
        p_sent = inner_sent(g, sent, ichart)

        sent_nums = g.sent_nums(sent)
        # todo: use sum([ichart[s, t...] etc? but can we then
        # keep den and num separate within _one_ sum()-call?
        for loc_h, h in enumerate(sent_nums):
            for t in range(loc_h, len(sent)):
                for s in range(loc_h): # s<loc(h), range gives strictly less
                    # left non-adjacent stop:
                    f['LNSTOP','num',h] += c_g(s, t, (SEAL, h), loc_h, sent)
                    f['LNSTOP','den',h] += c_g(s, t, (RGOL, h), loc_h, sent)
                # left adjacent stop:
                f['LASTOP','num',h] += c_g(loc_h, t, (SEAL, h), loc_h, sent)
                f['LASTOP','den',h] += c_g(loc_h, t, (RGOL, h), loc_h, sent)
            for t in range(loc_h+1, len(sent)):
                # right non-adjacent stop:
                f['RNSTOP','num',h] += c_g(loc_h, t, (RGOL, h), loc_h, sent)
                f['RNSTOP','den',h] += c_g(loc_h, t, (GOR, h), loc_h, sent)
            # right adjacent stop:
            f['RASTOP','num',h] += c_g(loc_h, loc_h, (RGOL, h), loc_h, sent)
            f['RASTOP','den',h] += c_g(loc_h, loc_h, (GOR, h), loc_h, sent)

            # right attachment: TODO: try with p*e*e*f instead of c, for numerator
            if 'reest_attach' in DEBUG:
                print("Rattach %s: for t in %s"%(g.numtag(h),sent[loc_h+1:len(sent)]))
            for t in range(loc_h+1, len(sent)):
                cM = c_g(loc_h, t, (GOR, h), loc_h, sent) # v_q in L&Y
                f['RCHOOSE','den',h] += cM
                if 'reest_attach' in DEBUG:
                    print("\tc_g( %d , %d, %s, %s, sent)=%.4f"%(loc_h,t,g.numtag(h),loc_h,cM))
                args = {} # for summing w_q's in L&Y, without 1/P_q
                f_M = f_g(loc_h, t, (GOR, h), loc_h, sent) # same for every r, a
                for r in range(loc_h+1, t+1): # loc_h < r <= t
                    e_L = e_g(loc_h, r-1, (GOR, h), loc_h, sent)
                    if 'reest_attach' in DEBUG:
                        print("\t\te_g( %d , %d, %s, %d, sent)=%.4f"%(loc_h,r-1,g.numtag(h),loc_h,e_L))
                    for i, a in enumerate(sent_nums[r:t+1]):
                        loc_a = i+r
                        e_R = e_g(r, t, (SEAL, a), loc_a, sent)
                        # The left child ends at r-1; that is the split point
                        # p_ATTACH expects (passing r made adjacency impossible).
                        args[a] = args.get(a, 0.0) + e_L * e_R * f_M * p_g(r-1, (GOR, h), (GOR, h), (SEAL, a), loc_h, sent_nums)
                add_choose_num('RCHOOSE', h, args)

            # left attachment:
            if 'reest_attach' in DEBUG:
                print("Lattach %s: for s in %s"%(g.numtag(h),sent[0:loc_h]))
            for s in range(0, loc_h):
                if 'reest_attach' in DEBUG:
                    print("\tfor t in %s"%sent[loc_h:len(sent)])
                for t in range(loc_h, len(sent)):
                    c_M = c_g(s, t, (RGOL, h), loc_h, sent) # v_q in L&Y
                    f['LCHOOSE','den',h] += c_M
                    if 'reest_attach' in DEBUG:
                        print("\t\tc_g( %d , %d, %s_, %s, sent)=%.4f"%(s,t,g.numtag(h),loc_h,c_M))
                    if 'reest_attach' in DEBUG:
                        print("\t\tfor r in %s"%(sent[s:loc_h]))
                    args = {} # for summing w_q's in L&Y, without 1/P_q
                    f_M = f_g(s, t, (RGOL, h), loc_h, sent) # same for every r, a
                    for r in range(s, loc_h): # s <= r < loc_h <= t
                        e_R = e_g(r+1, t, (RGOL, h), loc_h, sent)
                        if 'reest_attach' in DEBUG:
                            print("\t\te_g( %d , %d, %s_, %d, sent)=%.4f"%(r+1,t,g.numtag(h),loc_h,e_R))
                        for i, a in enumerate(sent_nums[s:r+1]):
                            loc_a = i+s
                            e_L = e_g(s, r, (SEAL, a), loc_a, sent)
                            args[a] = args.get(a, 0.0) + e_L * e_R * f_M * p_g(r, (RGOL, h), (SEAL, a), (RGOL, h), loc_h, sent_nums)
                    add_choose_num('LCHOOSE', h, args)
    return f
def reestimate(g, corpus):
    '''One reestimation step: gather expected counts over corpus with
    reest_freq(), then update every rule of g in place with reest_rule().
    Returns the counts dict.'''
    counts = reest_freq(g, corpus)
    for rule in g.all_rules():
        # reest_rule itself leaves ROOT rules untouched
        reest_rule(rule, counts, g)
    return counts
def reest_rule(r, f, g): # g just for numtag / debug output, remove eventually?
    '''Set r.probN / r.probA from the counts f returned by reest_freq.

    STOP rules get num/den of the matching (N|A)STOP counts (0.0 when
    the denominator is 0). Attachment rules get (1 - p_stop) * p_choose,
    where p_choose = num/den of the CHOOSE counts for the head and
    argument tag (0.0 when either count is missing or 0).
    ROOT rules are left untouched. Todo: remove 0-prob rules?'''
    h = r.POS()
    if r.LHS() == ROOT:
        return None # not sure what todo yet here
    # Decide the direction by which child carries the LHS node, not by POS
    # tags: with tags, an attachment of the same tag such as  h -> h _h_
    # was mistaken for a left rule.
    if r.L() == STOP or r.LHS() == r.R():
        side = 'L'
    elif r.R() == STOP or r.LHS() == r.L():
        side = 'R'
    else:
        raise Exception("Odd rule in reestimation.")

    p_stopN = f[side+'NSTOP','den',h]
    if p_stopN > 0.0:
        p_stopN = f[side+'NSTOP','num',h] / p_stopN

    p_stopA = f[side+'ASTOP','den',h]
    if p_stopA > 0.0:
        p_stopA = f[side+'ASTOP','num',h] / p_stopA

    if r.L() == STOP or r.R() == STOP: # stop rules
        if 'reest' in DEBUG:
            print("p(STOP|%d=%s,%s,N): %.4f (was: %.4f)"%(h,g.numtag(h),side, p_stopN, r.probN))
            print("p(STOP|%d=%s,%s,A): %.4f (was: %.4f)"%(h,g.numtag(h),side, p_stopA, r.probA))
        r.probN = p_stopN
        r.probA = p_stopA

    else: # attachment rules
        if side == 'L': # left attachment  h_ -> _a_ h_
            a = POS(r.L())
        else: # right attachment  h -> h _a_
            a = POS(r.R())
        pchoose = f[side+'CHOOSE','den',h]
        if pchoose > 0.0:
            # (h, a) pairs never seen attached have no numerator entry
            pchoose = f.get((side+'CHOOSE','num',h,a), 0.0) / pchoose
        r.probN = (1-p_stopN) * pchoose
        r.probA = (1-p_stopA) * pchoose
        if 'reest' in DEBUG:
            print("p(%d=%s|%d=%s,%s): %.4f,\tprobN: %.4f, probA: %.4f"%(a,g.numtag(a),h,g.numtag(h),side, pchoose,r.probN,r.probA))
##############################
#     testing functions:     #
##############################

# 16 small POS-tagged sentences, each stored as a list of tags
testcorpus = [s.split() for s in (
    'det nn vbd c vbd', 'vbd nn c vbd',
    'det nn vbd',       'det nn vbd c pp',
    'det nn vbd',       'det vbd vbd c pp',
    'det nn vbd',       'det nn vbd c vbd',
    'det nn vbd',       'det nn vbd c vbd',
    'det nn vbd',       'det nn vbd c vbd',
    'det nn vbd',       'det nn vbd c pp',
    'det nn vbd pp',    'det nn vbd',
)]
def testgrammar():
    "Harmonic initial grammar for testcorpus, from a freshly reloaded before_betweens_harmonic."
    import before_betweens_harmonic as harmonic
    harmonic = reload(harmonic) # pick up edits made since the last import
    return harmonic.initialize(testcorpus)
def testreestimation():
    '''Reestimate testgrammar() on testcorpus and print a message for every
    P_STOP count that is missing or differs (to 10 decimals) from the
    recorded values.'''
    g = testgrammar()
    f = reestimate(g, testcorpus)
    f_stops = {('LNSTOP', 'den', 3): 12.212773236178391, ('RASTOP', 'den', 2): 4.0, ('RNSTOP', 'num', 4): 2.5553487221351365, ('LNSTOP', 'den', 2): 1.274904052793207, ('LASTOP', 'num', 1): 14.999999999999995, ('RASTOP', 'den', 3): 15.0, ('LASTOP', 'num', 4): 16.65701084787457, ('LASTOP', 'num', 0): 4.1600647714443468, ('LNSTOP', 'den', 4): 6.0170669155897105, ('LASTOP', 'num', 3): 2.7872267638216113, ('LASTOP', 'num', 2): 2.9723139990470515, ('LASTOP', 'den', 2): 4.0, ('RNSTOP', 'den', 3): 12.945787931730905, ('LASTOP', 'den', 3): 14.999999999999996, ('RNSTOP', 'den', 2): 0.0, ('LASTOP', 'den', 0): 8.0, ('RASTOP', 'num', 4): 19.44465127786486, ('RNSTOP', 'den', 1): 3.1966410324085777, ('LASTOP', 'den', 1): 14.999999999999995, ('RASTOP', 'num', 3): 4.1061665495365558, ('RNSTOP', 'den', 0): 4.8282499043902476, ('LNSTOP', 'num', 4): 5.3429891521254289, ('RASTOP', 'num', 2): 4.0, ('LASTOP', 'den', 4): 22.0, ('RASTOP', 'num', 1): 12.400273895299103, ('LNSTOP', 'num', 2): 1.0276860009529487, ('RASTOP', 'num', 0): 3.1717500956097533, ('LNSTOP', 'num', 3): 12.212773236178391, ('RASTOP', 'den', 4): 22.0, ('RNSTOP', 'den', 4): 2.8705211946979836, ('LNSTOP', 'num', 0): 3.8399352285556518, ('LNSTOP', 'num', 1): 0.0, ('RNSTOP', 'num', 0): 4.8282499043902476, ('RNSTOP', 'num', 1): 2.5997261047008959, ('LNSTOP', 'den', 1): 0.0, ('RASTOP', 'den', 0): 8.0, ('RNSTOP', 'num', 2): 0.0, ('LNSTOP', 'den', 0): 4.6540557322109795, ('RASTOP', 'den', 1): 15.0, ('RNSTOP', 'num', 3): 10.893833450463443}
    for key, expected in f_stops.items():
        if key not in f:
            print(("Regression!(?) Something changed in the P_STOP reestimation,\n"
                   "expected f[%s]=%.4f, but %s not in f") % (key, expected, key))
        elif "%.10f" % f[key] != "%.10f" % expected:
            print(("Regression!(?) Something changed in the P_STOP reestimation,\n"
                   "expected f[%s]=%.4f, got f[%s]=%.4f.") % (key, expected, key, f[key]))
def testgrammar_a(): # Non, Adj
    '''Toy grammar with two tags, 0='h' and 1='a'. Each rule is given as
    (LHS, L, R, probN, probA).'''
    H, A = 0, 1
    # rules headed by h
    h_lstop    = DMV_Rule((SEAL,H), STOP,      (RGOL,H), 1.0, 1.0) # LSTOP
    h_rstop    = DMV_Rule((RGOL,H), (GOR,H),   STOP,     0.4, 0.3) # RSTOP
    h_l_h      = DMV_Rule((RGOL,H), (SEAL,H),  (RGOL,H), 0.2, 0.1) # Lattach
    h_l_a      = DMV_Rule((RGOL,H), (SEAL,A),  (RGOL,H), 0.4, 0.6) # Lattach to a
    h_r_h      = DMV_Rule((GOR,H),  (GOR,H),   (SEAL,H), 1.0, 1.0) # Rattach
    h_r_a      = DMV_Rule((GOR,H),  (GOR,H),   (SEAL,A), 1.0, 1.0) # Rattach to a
    root_h     = DMV_Rule(ROOT,     STOP,      (SEAL,H), 0.9, 0.9) # ROOT
    # rules headed by a
    a_lstop    = DMV_Rule((SEAL,A), STOP,      (RGOL,A), 1.0, 1.0) # LSTOP
    a_rstop    = DMV_Rule((RGOL,A), (GOR,A),   STOP,     0.4, 0.3) # RSTOP
    a_l_a      = DMV_Rule((RGOL,A), (SEAL,A),  (RGOL,A), 0.4, 0.6) # Lattach
    a_l_h      = DMV_Rule((RGOL,A), (SEAL,H),  (RGOL,A), 0.2, 0.1) # Lattach to h
    a_r_a      = DMV_Rule((GOR,A),  (GOR,A),   (SEAL,A), 1.0, 1.0) # Rattach
    a_r_h      = DMV_Rule((GOR,A),  (GOR,A),   (SEAL,H), 1.0, 1.0) # Rattach to h
    root_a     = DMV_Rule(ROOT,     STOP,      (SEAL,A), 0.1, 0.1) # ROOT

    terminals = {}
    terminals[(GOR, H), 'h'] = 1.0
    terminals[(GOR, A), 'a'] = 1.0

    # rule order kept as before: it fixes the summation order in inner()
    rules = [h_l_a, h_r_a, a_l_h, a_r_h, root_a, a_lstop, a_rstop, a_l_a,
             a_r_a, root_h, h_lstop, h_rstop, h_l_h, h_r_h]
    return DMV_Grammar({0:'h',1:'a'}, {'h':0,'a':1}, rules, terminals, 0, 0, 0)
def oa(s,t,LHS,loc_h):
    '''Outside probability of LHS over s..t (head at loc_h) in 'h a'
    under testgrammar_a().'''
    # Fresh charts passed explicitly, so no chart entries from an earlier
    # call (another grammar or sentence) can be reused.
    return outer(s,t,LHS,loc_h,testgrammar_a(),'h a'.split(),{},{})
def ia(s,t,LHS,loc_h):
    '''Inside probability of LHS over s..t (head at loc_h) in 'h a'
    under testgrammar_a().'''
    # Fresh chart passed explicitly, so no chart entries from an earlier
    # call (another grammar or sentence) can be reused.
    return inner(s,t,LHS,loc_h,testgrammar_a(),'h a'.split(),{})
def ca(s,t,LHS,loc_h):
    '''Expected count of LHS over s..t (head at loc_h) in 'h a' under
    testgrammar_a(): inner * outer / P(sentence), or P(sentence) itself
    when that is not positive.'''
    # The global c() this used to call is commented out, so compute the
    # count here the same way.
    g = testgrammar_a()
    sent = 'h a'.split()
    ichart, ochart = {}, {}
    p_sent = inner_sent(g, sent, ichart)
    if p_sent > 0.0:
        p_in = inner(s,t,LHS,loc_h,g,sent,ichart)
        p_out = outer(s,t,LHS,loc_h,g,sent,ichart,ochart)
        return p_in * p_out / p_sent
    return p_sent
def testgrammar_h(): # Non, Adj
    '''Toy grammar with a single tag, 0='h'. Each rule is given as
    (LHS, L, R, probN, probA).'''
    H = 0
    lstop   = DMV_Rule((SEAL,H), STOP,     (RGOL,H), 1.0, 1.0) # LSTOP
    rstop   = DMV_Rule((RGOL,H), (GOR,H),  STOP,     0.4, 0.3) # RSTOP
    lattach = DMV_Rule((RGOL,H), (SEAL,H), (RGOL,H), 0.6, 0.7) # Lattach
    rattach = DMV_Rule((GOR,H),  (GOR,H),  (SEAL,H), 1.0, 1.0) # Rattach
    root    = DMV_Rule(ROOT,     STOP,     (SEAL,H), 1.0, 1.0) # ROOT
    terminals = {}
    terminals[(GOR, H), 'h'] = 1.0

    return DMV_Grammar({0:'h'}, {'h':0},
                       [root, lstop, rstop, lattach, rattach],
                       terminals, 0, 0, 0)
def testreestimation_h():
    "Run one reestimation pass of testgrammar_h() on the single sentence 'h h h'."
    corpus = ['h h h'.split()]
    reestimate(testgrammar_h(), corpus)
def regression_tests():
    '''Print "Regression! ..." for every recorded inside/outside value
    that no longer comes out (compared as formatted strings).'''
    def test(wanted, got):
        if wanted != got:
            print("Regression! Should be %s: %s" % (wanted, got))

    # inside probabilities of _h_ (SEAL,0); one grammar shared by all three
    g_dup = testgrammar_h()
    inner_cases = [("0.120",  "%.3f", 0, 1, 0, 'h h'),
                   ("0.063",  "%.3f", 0, 1, 1, 'h h'),
                   ("0.0498", "%.4f", 0, 2, 2, 'h h h')]
    for wanted, fmt, s, t, loc_h, words in inner_cases:
        test(wanted, fmt % inner(s, t, (SEAL,0), loc_h, g_dup, words.split(), {}))

    # outside probabilities, each with a freshly built grammar
    outer_cases = [("0.58",   "%.2f", testgrammar_h, 1, 2, (1,0), 2, 'h h h'),
                   ("0.1089", "%.4f", testgrammar_a, 0, 0, (0,0), 0, 'h a'),
                   ("0.3600", "%.4f", testgrammar_a, 0, 1, (0,0), 0, 'h a'),
                   ("0.0000", "%.4f", testgrammar_a, 0, 2, (0,0), 0, 'h a')]
    for wanted, fmt, make_grammar, s, t, node, loc_N, words in outer_cases:
        test(wanted, fmt % outer(s, t, node, loc_N, make_grammar(), words.split(), {}, {}))
if "__main__" == __name__:
    # quiet run: switch off all debug output, then run the checks
    DEBUG.clear()
    regression_tests()
    testreestimation()
def testIO():
    "List of (sent, inner_sent(...)) for every sentence of testcorpus under testgrammar()."
    g = testgrammar()
    results = []
    for sent in testcorpus:
        results.append((sent, inner_sent(g, sent, {})))
    return results