add_root fn written
[dmvccm.git] / src / cnf_dmv.py
blob952e55ad8ca28a654495252f20f249bd97c12b24
1 # cnf_dmv.py
3 #import numpy # numpy provides Fast Arrays, for future optimization
4 from common_dmv import *
5 import io
if __name__ == "__main__":
    # Announce the self-tests when this file is run as a script.
    print("cnf_dmv module tests:")
def att(node):
    """Return the third component (index 2) of node.

    Only used in cnf_, right.
    NOTE(review): presumably index 2 holds an attachment marker of a
    cnf_ node -- confirm against the cnf_ code.
    """
    third = node[2]
    return third
class DMV_Grammar(io.Grammar):
    '''The DMV-PCFG.

    Public members:
    p_STOP, p_ROOT, p_CHOOSE, p_terminals
    These are changed in the Maximization step, then used to set the
    new probabilities of each DMV_Rule.

    Todo: make p_terminals private? (But it has to be changeable in the
    maximization step due to the short-cutting rules... could of course
    make a DMV_Grammar function to update the short-cut rules...)

    __p_rules is private, but we can still say stuff like:
    for r in g.all_rules():
        r.probN = newProbN

    What other representations do we need? (P_STOP formula uses
    deps_D(h,l/r) at least)

    NOTE(review): all_rules(), numtag and p_terminals come from
    io.Grammar (the project's own io module, not the standard library),
    which is not visible here.'''

    def __init__(self, p_rules, p_terminals, p_STOP, p_CHOOSE, p_ROOT, numtag, tagnum):
        io.Grammar.__init__(self, p_rules, p_terminals, numtag, tagnum)
        self.p_STOP = p_STOP
        self.p_CHOOSE = p_CHOOSE
        self.p_ROOT = p_ROOT
        # every tag number known to the grammar (the keys of numtag)
        self.head_nums = list(numtag.keys())

    def __str__(self):
        '''One line per rule, with tags rendered through self.numtag.'''
        return "".join(["%s\n" % r.__str__(self.numtag)
                        for r in self.all_rules()])

    def h_rules(self, h):
        '''All rules whose LHS has head h.'''
        return [r for r in self.all_rules() if r.head() == h]

    def mothersL(self, Node, sent_nums, loc_N):
        '''Rules with Node as left child whose right child is STOP or has
        a head occurring to the right of loc_N in sent_nums.'''
        # todo: speed-test with and without sent_nums/loc_N cut-off
        right_heads = set(sent_nums[loc_N+1:])
        return [r for r in self.all_rules() if r.L() == Node
                and (head(r.R()) in right_heads or r.R() == STOP)]

    def mothersR(self, Node, sent_nums, loc_N):
        '''Rules with Node as right child whose left child is STOP or has
        a head occurring to the left of loc_N in sent_nums.'''
        left_heads = set(sent_nums[:loc_N])
        return [r for r in self.all_rules() if r.R() == Node
                and (head(r.L()) in left_heads or r.L() == STOP)]

    def rules(self, LHS):
        '''All rules rewriting LHS.'''
        return [r for r in self.all_rules() if r.LHS() == LHS]

    def sent_rules(self, LHS, sent_nums):
        '''Rules rewriting LHS whose children are headed by tags occurring
        in sent_nums (or by STOP).  Used in inner().

        Todo: this takes a _lot_ of time, it seems. Could use some more
        space and cache some of this somehow perhaps?'''
        # A set gives O(1) membership tests for every rule checked.
        nums = set(sent_nums)
        nums.add(head(STOP))  # we don't want to rule out STOPs!
        return [r for r in self.all_rules() if r.LHS() == LHS
                and head(r.L()) in nums and head(r.R()) in nums]

    def deps_L(self, head):  # todo: do I use this at all?
        '''Left children of every rule whose head is `head`.'''
        return [r.L() for r in self.all_rules() if r.head() == head]

    def deps_R(self, head):
        '''Right children of every rule whose head is `head`.'''
        return [r.R() for r in self.all_rules() if r.head() == head]
class DMV_Rule(io.CNF_Rule):
    '''A single CNF rule in the PCFG, of the form
        LHS -> L R
    where LHS, L and R are 'nodes', eg. of the form (seals, head).

    Public members:
    probN, probA

    Private members:
    __L, __R, __LHS

    Different rule-types have different probabilities associated with
    them:

    _h_ -> STOP  h_     P( STOP|h,L,    adj)
    _h_ -> STOP  h_     P( STOP|h,L,non_adj)
     h_ ->  h  STOP     P( STOP|h,R,    adj)
     h_ ->  h  STOP     P( STOP|h,R,non_adj)
     h_ -> _a_  h_      P(-STOP|h,L,    adj) * P(a|h,L)
     h_ -> _a_  h_      P(-STOP|h,L,non_adj) * P(a|h,L)
     h  ->  h  _a_      P(-STOP|h,R,    adj) * P(a|h,R)
     h  ->  h  _a_      P(-STOP|h,R,non_adj) * P(a|h,R)

    NOTE(review): L(), R(), LHS() and the private members come from
    io.CNF_Rule (the project's own io module), which is not shown here.
    '''

    def __init__(self, LHS, L, R, probN, probA):
        '''Raises ValueError unless every node's seal value is in SEALS.'''
        for b_h in [LHS, L, R]:
            if seals(b_h) not in SEALS:
                raise ValueError("seals must be in %s; was given: %s"
                                 % (SEALS, seals(b_h)))
        io.CNF_Rule.__init__(self, LHS, L, R, probN)
        self.probA = probA  # adjacent
        self.probN = probN  # non_adj

    def p(self, adj, *arg):
        '''probA if adj is true, else probN (extra arguments are ignored).'''
        if adj:
            return self.probA
        else:
            return self.probN

    def p_STOP(self, s, t, loc_h):
        '''Return the STOP probability of this rule for a fragment s..t
        headed at loc_h; adjacent if we're rewriting from the (either left
        or right) end of the fragment.

          _h_ -> STOP h_ : adjacent iff s == loc_h
           h_ -> h STOP  : the h child must start at loc_h (else 0.0);
                           adjacent iff t == loc_h

        Raises ValueError if neither child is STOP.'''
        if self.L() == STOP:
            return self.p(s == loc_h)
        elif self.R() == STOP:
            if not loc_h == s:
                if 'TODO' in DEBUG:
                    print("(%s given loc_h:%d but s:%d. Todo: optimize away!)" % (self, loc_h, s))
                return 0.0
            else:
                return self.p(t == loc_h)
        # Previously fell off the end and returned None, which only
        # failed later inside an arithmetic expression.
        raise ValueError("p_STOP called on a rule without STOP: %s" % self)

    def p_ATTACH(self, r, loc_h, s=None):
        '''Return the attachment probability of this rule; adjacent if we
        haven't attached anything before on that side of the head.

        The left child ends at r and the right child starts at r+1.
           h  -> h _a_  (LHS == L): the head child must start at
                        s == loc_h (else 0.0); adjacent iff r == loc_h
           h_ -> _a_ h_ (LHS == R): adjacent iff r+1 == loc_h

        Raises ValueError for any other rule shape.'''
        if self.LHS() == self.L():
            if not loc_h == s:
                if 'TODO' in DEBUG:
                    print("(%s given loc_h (loc_L):%d but s:%d. Todo: optimize away!)" % (self, loc_h, s))
                return 0.0
            else:
                return self.p(r == loc_h)
        elif self.LHS() == self.R():
            return self.p(r+1 == loc_h)
        raise ValueError("p_ATTACH called on a non-attachment rule: %s" % self)

    def seals(self):
        '''Seal value of the LHS node.'''
        return seals(self.LHS())

    def head(self):
        '''Head (tag number) of the LHS node.'''
        return head(self.LHS())

    @classmethod  # so we can call DMV_Rule.bar_str(b_h)
    def bar_str(cls, b_h, tag=lambda x: x):
        '''Render node b_h as 'ROOT', 'STOP', " h_ " (RGO_L), "_h_ " (SEAL)
        or " h " (any other seal value), passing the head through tag.'''
        if b_h == ROOT:
            return 'ROOT'
        elif b_h == STOP:
            return 'STOP'
        elif seals(b_h) == RGO_L:
            return " %s_ " % tag(head(b_h))
        elif seals(b_h) == SEAL:
            return "_%s_ " % tag(head(b_h))
        else:
            return " %s " % tag(head(b_h))

    def __str__(self, tag=lambda x: x):
        return "%s-->%s %s\t[N %.2f] [A %.2f]" % (self.bar_str(self.LHS(), tag),
                                                  self.bar_str(self.L(), tag),
                                                  self.bar_str(self.R(), tag),
                                                  self.probN,
                                                  self.probA)
###################################
# dmv-specific version of inner() #
###################################
def locs(h, sent, s=0, t=None, remove=None):
    '''Return the locations of h in sent, or in the fragment sent[s..t].

    Locations are offset by s, so that sent[x] == h for every x in the
    returned list.  A location equal to `remove` is left out.

    t is inclusive, to match the way indices work with inner()
    (although python list-slicing has "exclusive" end indices); it
    defaults to the last index of sent.'''
    if t is None:
        t = len(sent) - 1
    return [i + s for i, w in enumerate(sent[s:t + 1])
            if w == h and i + s != remove]
def inner(s, t, LHS, loc_h, g, sent, ichart=None):
    ''' A rewrite of io.inner(), to take adjacency into account.

    Returns the inner probability of node LHS spanning sent[s..t]
    (t inclusive) with its head at position loc_h.

    The ichart is now of this form:
    ichart[s,t,LHS, loc_h]

    loc_h gives adjacency (along with r and location of other child
    for attachment rules), and is needed in P_STOP reestimation.

    ichart memoises already computed values: pass the same dict to
    repeated calls on the same sentence and grammar to share work.
    When omitted, a fresh dict is used for this call only, so values
    never leak between calls on different sentences or grammars.

    Todo: if possible, refactor (move dmv-specific stuff back into
    dmv, so this is "general" enough to be in io.py)
    '''
    if ichart is None:
        ichart = {}

    def O(s):
        return sent[s]

    sent_nums = g.sent_nums(sent)
    tree = {}  # (s,t,LHS,loc_h) -> set of child cells with nonzero prob

    def e(s, t, LHS, loc_h, n_t):
        def tab():
            "Tabs for debug output"
            return "\t" * n_t

        if (s, t, LHS, loc_h) in ichart:
            if 'INNER' in DEBUG:
                print("%s*= %.4f in ichart: s:%d t:%d LHS:%s loc:%d" % (
                    tab(), ichart[s, t, LHS, loc_h], s, t,
                    DMV_Rule.bar_str(LHS), loc_h))
            return ichart[s, t, LHS, loc_h]
        else:
            if s == t and seals(LHS) == GO_R:
                # a single word: the GO_R node rewrites to the terminal
                if not loc_h == s:
                    if 'INNER' in DEBUG:
                        print("%s*= 0.0 (wrong loc_h)" % tab())
                    return 0.0
                elif (LHS, O(s)) in g.p_terminals:
                    prob = g.p_terminals[LHS, O(s)]  # "b[LHS, O(s)]" in Lari&Young
                else:
                    # todo: assuming this is how to deal w/lacking
                    # rules, since we add prob.s, and 0 is identity
                    prob = 0.0
                    if 'INNER' in DEBUG:
                        print("%sLACKING TERMINAL:" % tab())
                # todo: add to ichart perhaps? Although, it _is_ simple lookup..
                if 'INNER' in DEBUG:
                    print("%s*= %.4f (terminal: %s -> %s_%d)" % (
                        tab(), prob, DMV_Rule.bar_str(LHS), O(s), loc_h))
                return prob
            else:
                p = 0.0  # "sum over j,k in a[LHS,j,k]"
                for rule in g.sent_rules(LHS, sent_nums):
                    if 'INNER' in DEBUG:
                        print("%ssumming rule %s s:%d t:%d loc:%d" % (tab(), rule, s, t, loc_h))
                    L = rule.L()
                    R = rule.R()
                    if (s, t, LHS, loc_h) not in tree:
                        tree[s, t, LHS, loc_h] = set()
                    # a right attachment needs room right of the head,
                    # a left attachment needs room left of it
                    if loc_h == t and rule.LHS() == L:
                        continue  # todo: speed-test
                    if loc_h == s and rule.LHS() == R:
                        continue
                    # if it's a STOP rule, rewrite for the same xrange:
                    if (L == STOP) or (R == STOP):
                        if L == STOP:
                            pLR = e(s, t, R, loc_h, n_t+1)
                            if pLR > 0.0:
                                tree[s, t, LHS, loc_h].add((s, t, R, loc_h))
                        elif R == STOP:
                            pLR = e(s, t, L, loc_h, n_t+1)
                            if pLR > 0.0:
                                tree[s, t, LHS, loc_h].add((s, t, L, loc_h))
                        p += rule.p_STOP(s, t, loc_h) * pLR
                        if 'INNER' in DEBUG:
                            print("%sp= %.4f (STOP)" % (tab(), p))
                    elif t > s:  # not a STOP, attachment rewrite:
                        rp_ATTACH = rule.p_ATTACH  # todo: profile/speedtest
                        for r in xrange(s, t):
                            p_h = rp_ATTACH(r, loc_h, s=s)
                            if rule.LHS() == L:
                                locs_L = [loc_h]
                                locs_R = locs(head(R), sent_nums, r+1, t, loc_h)
                            elif rule.LHS() == R:
                                locs_L = locs(head(L), sent_nums, s, r, loc_h)
                                locs_R = [loc_h]
                            for loc_L in locs_L:
                                pL = e(s, r, L, loc_L, n_t+1)
                                if pL > 0.0:
                                    for loc_R in locs_R:
                                        pR = e(r+1, t, R, loc_R, n_t+1)
                                        if pR > 0.0:  # and pL > 0.0
                                            tree[s, t, LHS, loc_h].add((s, r, L, loc_L))
                                            tree[s, t, LHS, loc_h].add((r+1, t, R, loc_R))
                                            p += pL * p_h * pR
                        if 'INNER' in DEBUG:
                            print("%sp= %.4f (ATTACH)" % (tab(), p))
                ichart[s, t, LHS, loc_h] = p
                return p
    # end of e-function

    inner_prob = e(s, t, LHS, loc_h, 0)
    # NOTE(review): this stores an empty dict rather than the `tree`
    # built above, so the tree is currently discarded -- confirm intent.
    ichart['tree'] = {}
    if 'INNER' in DEBUG:
        print(debug_ichart(g, sent, ichart))
    return inner_prob
# end of dmv.inner(s, t, LHS, loc_h, g, sent, ichart=None)
def debug_ichart(g, sent, ichart):
    '''Return a human-readable dump of the inner chart, one line per
    (s, t, LHS, loc_h) entry; dict-valued entries ('tree') are skipped.'''
    lines = ["---ICHART:---\n"]
    for key, v in ichart.items():
        # Skip before unpacking: non-cell keys like 'tree' are not
        # (s, t, LHS, loc_h) tuples.
        if isinstance(v, dict):
            continue
        s, t, LHS, loc_h = key
        lines.append("%s -> %s_%d ... %s_%d (loc_h:%s):\t%.4f\n" % (
            DMV_Rule.bar_str(LHS, g.numtag), sent[s], s, sent[t], t, loc_h, v))
    lines.append("---ICHART:end---\n")
    return "".join(lines)
def inner_sent(g, sent, ichart=None):
    '''Total inner probability of sent: the sum over every position loc_h
    of ROOT spanning the whole sentence with its head at loc_h.

    All positions share one chart, so values computed for one loc_h are
    reused for the next.  When ichart is omitted a fresh dict is used.'''
    if ichart is None:
        ichart = {}
    return sum([inner(0, len(sent)-1, ROOT, loc_h, g, sent, ichart)
                for loc_h in xrange(len(sent))])
def c(s, t, LHS, loc_h, g, sent, ichart=None, ochart=None):
    '''Expected count (posterior probability) of node LHS spanning
    sent[s..t] with its head at loc_h:  inner * outer / P(sent).

    Returns p_sent itself (i.e. 0.0) when P(sent) is not positive.
    Charts default to fresh dicts, so nothing leaks between calls.'''
    if ichart is None:
        ichart = {}
    if ochart is None:
        ochart = {}
    # assuming P_sent = P(D(ROOT)) = inner(sent). todo: check K&M about this
    p_sent = inner_sent(g, sent, ichart)
    p_in = inner(s, t, LHS, loc_h, g, sent, ichart)
    p_out = outer(s, t, LHS, loc_h, g, sent, ichart, ochart)
    if p_sent > 0.0:
        return p_in * p_out / p_sent
    else:
        return p_sent
###################################
# dmv-specific version of outer() #
###################################
def outer(s, t, Node, loc_N, g, sent, ichart=None, ochart=None):
    ''' http://www.student.uib.no/~kun041/dmvccm/DMVCCM.html#outer

    Outer probability of Node spanning sent[s..t] with its head at
    loc_N.  ROOT has outer probability 1.0 over the whole sentence and
    0.0 elsewhere.  Results are memoised in ochart; inner values go
    through ichart.  Both default to fresh dicts, so nothing leaks
    between calls on different sentences or grammars.'''
    if ichart is None:
        ichart = {}
    if ochart is None:
        ochart = {}

    def e(s, t, LHS, loc_h):
        # or we could just look it up in ichart, assuming ichart to be done
        return inner(s, t, LHS, loc_h, g, sent, ichart)

    T = len(sent) - 1
    sent_nums = g.sent_nums(sent)

    def f(s, t, Node, loc_N):
        if (s, t, Node, loc_N) in ochart:
            return ochart[(s, t, Node, loc_N)]
        if Node == ROOT:
            if s == 0 and t == T:
                return 1.0
            else:  # ROOT may only be used on full sentence
                return 0.0  # but we may have non-ROOTs over full sentence too
        p = 0.0

        for mom in g.mothersL(Node, sent_nums, loc_N):  # mom.L() == Node
            R = mom.R()
            mLHS = mom.LHS()
            if R == STOP:
                p += f(s, t, mLHS, loc_N) * mom.p_STOP(s, t, loc_N)  # == loc_m
            else:
                if seals(mLHS) == RGO_L:  # left attachment, head(mLHS) == head(R)
                    for r in xrange(t+1, T+1):  # t+1 to lasT
                        for loc_m in locs(head(mLHS), sent_nums, t+1, r):
                            p_m = mom.p(t+1 == loc_m)
                            p += f(s, r, mLHS, loc_m) * p_m * e(t+1, r, R, loc_m)
                elif seals(mLHS) == GO_R:  # right attachment, head(mLHS) == head(Node)
                    loc_m = loc_N
                    p_m = mom.p(t == loc_m)
                    for r in xrange(t+1, T+1):  # t+1 to lasT
                        for loc_R in locs(head(R), sent_nums, t+1, r):
                            p += f(s, r, mLHS, loc_m) * p_m * e(t+1, r, R, loc_R)

        for mom in g.mothersR(Node, sent_nums, loc_N):  # mom.R() == Node
            L = mom.L()
            mLHS = mom.LHS()
            if L == STOP:
                p += f(s, t, mLHS, loc_N) * mom.p_STOP(s, t, loc_N)  # == loc_m
            else:
                if seals(mLHS) == RGO_L:  # left attachment, head(mLHS) == head(Node)
                    loc_m = loc_N
                    p_m = mom.p(s == loc_m)
                    for r in xrange(0, s):  # first to s-1
                        for loc_L in locs(head(L), sent_nums, r, s-1):
                            p += e(r, s-1, L, loc_L) * p_m * f(r, t, mLHS, loc_m)
                elif seals(mLHS) == GO_R:  # right attachment, head(mLHS) == head(L)
                    for r in xrange(0, s):  # first to s-1
                        for loc_m in locs(head(mLHS), sent_nums, r, s-1):
                            p_m = mom.p(s-1 == loc_m)
                            p += e(r, s-1, L, loc_m) * p_m * f(r, t, mLHS, loc_m)
        ochart[s, t, Node, loc_N] = p
        return p

    return f(s, t, Node, loc_N)
# end outer(s,t,Node,loc_N, g,sent, ichart,ochart)
##############################
#     reestimation, todo:    #
##############################
def reest_zeros(h_nums):
    '''Return a fresh counter table for the heads in h_nums.

    Every STOP num/den entry and every CHOOSE den entry is set to 0.0,
    keyed as (kind, 'num'|'den', h).  CHOOSE numerators also need the
    dependent's tag and are not created here.'''
    # todo: p_ROOT? ... p_terminals?
    stop_kinds = ('LNSTOP', 'LASTOP', 'RNSTOP', 'RASTOP')
    zero_keys = [(kind, nd) for kind in stop_kinds for nd in ('num', 'den')]
    zero_keys += [(choice, 'den') for choice in ('RCHOOSE', 'LCHOOSE')]
    table = {}
    for h in h_nums:
        for kind, nd in zero_keys:
            table[kind, nd, h] = 0.0
    return table
def reest_freq(g, corpus):
    '''Collect the expected counts used to reestimate P_STOP and P_CHOOSE.

    Returns a dict f, initialised by reest_zeros(g.head_nums) and summed
    over every sentence in corpus:

      f[S, 'num'|'den', h]     S in LNSTOP, LASTOP, RNSTOP, RASTOP
      f[C, 'den', h]           C in LCHOOSE, RCHOOSE
      f[C, 'num', h, a]        only for (h, a) pairs that were seen

    reest_rule divides num by den; P_STOP(-STOP|...) = 1 - P_STOP(STOP|...).

    Both CHOOSE numerators are expected rule counts,
        e(left child) * e(right child) * outer(mother) * p(rule) / P(sent),
    summed over all spans and split points.  For each h they should
    therefore add up (over a) to the matching denominator.
    Sentences with zero inner probability contribute nothing.'''
    f = reest_zeros(g.head_nums)
    debug_attach = 'reest_attach' in DEBUG
    # The charts and p_sent are rebound for each sentence in the loop
    # below; the nested helpers read the current bindings via closure.
    ichart = {}
    ochart = {}
    p_sent = None  # 50 % speed increase on storing this locally

    def c_g(s, t, LHS, loc_h, sent):  # altogether 2x faster than the global c()
        '''Posterior count inner * outer / P(sent) of a node.'''
        if (s, t, LHS, loc_h) in ichart:
            p_in = ichart[s, t, LHS, loc_h]
        else:
            p_in = inner(s, t, LHS, loc_h, g, sent, ichart)
        if (s, t, LHS, loc_h) in ochart:
            p_out = ochart[s, t, LHS, loc_h]
        else:
            p_out = outer(s, t, LHS, loc_h, g, sent, ichart, ochart)
        if p_sent > 0.0:
            return p_in * p_out / p_sent
        else:
            return p_sent

    def f_g(s, t, LHS, loc_h, sent):
        '''Outer probability, looked up in ochart when available.'''
        if (s, t, LHS, loc_h) in ochart:
            return ochart[s, t, LHS, loc_h]
        else:
            return outer(s, t, LHS, loc_h, g, sent, ichart, ochart)

    def e_g(s, t, LHS, loc_h, sent):
        '''Inner probability, looked up in ichart when available.'''
        if (s, t, LHS, loc_h) in ichart:
            return ichart[s, t, LHS, loc_h]
        else:
            return inner(s, t, LHS, loc_h, g, sent, ichart)

    def p_g(r, LHS, L, R, loc_h, sent_nums, s=None):
        '''Attachment probability of the rule LHS -> L R split at r.

        s is passed on to DMV_Rule.p_ATTACH (right attachments need
        s == loc_h there).  A rule absent from the grammar contributes
        probability 0.0.'''
        rules = [rule for rule in g.sent_rules(LHS, sent_nums)
                 if rule.L() == L and rule.R() == R]
        if not rules:
            return 0.0
        if len(rules) > 1:
            raise Exception("Several rules matching a[i,j,k]")
        return rules[0].p_ATTACH(r, loc_h, s=s)

    for sent in corpus:
        if 'reest' in DEBUG:
            print(sent)
        ichart = {}
        ochart = {}
        p_sent = inner_sent(g, sent, ichart)
        if not p_sent > 0.0:
            # impossible sentence: every count would be 0 (and the
            # numerators would divide by zero)
            continue

        sent_nums = g.sent_nums(sent)
        # todo: use sum([ichart[s, t...] etc? but can we then
        # keep den and num separate within _one_ sum()-call?
        for loc_h, h in enumerate(sent_nums):
            for t in xrange(loc_h, len(sent)):
                for s in xrange(loc_h):  # s<loc(h), xrange gives strictly less
                    # left non-adjacent stop:
                    f['LNSTOP', 'num', h] += c_g(s, t, (SEAL, h), loc_h, sent)
                    f['LNSTOP', 'den', h] += c_g(s, t, (RGO_L, h), loc_h, sent)
                # left adjacent stop:
                f['LASTOP', 'num', h] += c_g(loc_h, t, (SEAL, h), loc_h, sent)
                f['LASTOP', 'den', h] += c_g(loc_h, t, (RGO_L, h), loc_h, sent)
            for t in xrange(loc_h+1, len(sent)):
                # right non-adjacent stop:
                f['RNSTOP', 'num', h] += c_g(loc_h, t, (RGO_L, h), loc_h, sent)
                f['RNSTOP', 'den', h] += c_g(loc_h, t, (GO_R, h), loc_h, sent)
            # right adjacent stop:
            f['RASTOP', 'num', h] += c_g(loc_h, loc_h, (RGO_L, h), loc_h, sent)
            f['RASTOP', 'den', h] += c_g(loc_h, loc_h, (GO_R, h), loc_h, sent)

            # right attachment:  h(loc_h..t) -> h(loc_h..r-1) _a_(r..t)
            if debug_attach:
                print("Rattach %s: for t in %s" % (g.numtag(h), sent[loc_h+1:len(sent)]))
            for t in xrange(loc_h+1, len(sent)):
                c_M = c_g(loc_h, t, (GO_R, h), loc_h, sent)
                f['RCHOOSE', 'den', h] += c_M
                if debug_attach:
                    print("\tc_g( %d , %d, %s, %s, sent)=%.4f" % (loc_h, t, g.numtag(h), loc_h, c_M))
                f_M = f_g(loc_h, t, (GO_R, h), loc_h, sent)
                for r in xrange(loc_h+1, t+1):  # loc_h < r <= t
                    e_L = e_g(loc_h, r-1, (GO_R, h), loc_h, sent)
                    if debug_attach:
                        print("\t\te_g( %d , %d, %s, %d, sent)=%.4f" % (loc_h, r-1, g.numtag(h), loc_h, e_L))
                    for i, a in enumerate(sent_nums[r:t+1]):
                        loc_a = i + r
                        e_R = e_g(r, t, (SEAL, a), loc_a, sent)
                        p_a = p_g(r-1, (GO_R, h), (GO_R, h), (SEAL, a), loc_h, sent_nums, s=loc_h)
                        key = ('RCHOOSE', 'num', h, a)
                        f[key] = f.get(key, 0.0) + e_L * e_R * f_M * p_a / p_sent
                        if debug_attach:
                            print("\t\t\te_g( %d , %d, _%s_, %d, sent)=%.4f" % (r, t, g.numtag(a), loc_a, e_R))

            # left attachment:  h_(s..t) -> _a_(s..r) h_(r+1..t)
            if debug_attach:
                print("Lattach %s: for s in %s" % (g.numtag(h), sent[0:loc_h]))
            for s in xrange(0, loc_h):
                if debug_attach:
                    print("\tfor t in %s" % sent[loc_h:len(sent)])
                for t in xrange(loc_h, len(sent)):
                    c_M = c_g(s, t, (RGO_L, h), loc_h, sent)  # v_q in L&Y
                    f['LCHOOSE', 'den', h] += c_M
                    if debug_attach:
                        print("\t\tc_g( %d , %d, %s_, %s, sent)=%.4f" % (s, t, g.numtag(h), loc_h, c_M))
                        print("\t\tfor r in %s" % (sent[s:loc_h]))
                    f_M = f_g(s, t, (RGO_L, h), loc_h, sent)
                    args = {}  # for summing w_q's in L&Y, without 1/P_q
                    for r in xrange(s, loc_h):  # s <= r < loc_h <= t
                        e_R = e_g(r+1, t, (RGO_L, h), loc_h, sent)
                        if debug_attach:
                            print("\t\te_g( %d , %d, %s_, %d, sent)=%.4f" % (r+1, t, g.numtag(h), loc_h, e_R))
                        for i, a in enumerate(sent_nums[s:r+1]):
                            loc_a = i + s
                            e_L = e_g(s, r, (SEAL, a), loc_a, sent)
                            if a not in args:
                                args[a] = 0.0
                            args[a] += e_L * e_R * f_M * p_g(r, (RGO_L, h), (SEAL, a), (RGO_L, h), loc_h, sent_nums)
                    # accumulate: every (s, t) span and every sentence
                    # contributes to the same numerator
                    for a, sum_a in args.items():
                        key = ('LCHOOSE', 'num', h, a)
                        f[key] = f.get(key, 0.0) + sum_a / p_sent
    return f
def reestimate(g, corpus):
    '''Recompute the probabilities of every rule of g from the expected
    counts collected on corpus; returns the count table from reest_freq.'''
    counts = reest_freq(g, corpus)
    # reest_rule itself leaves ROOT rules alone
    for rule in g.all_rules():
        reest_rule(rule, counts, g)
    return counts
def reest_rule(r, f, g):  # g just for numtag / debug output, remove eventually?
    '''Reestimate r.probN and r.probA in place from the counts f.

    f is the table built by reest_freq.  STOP rules get
    P(STOP|h,dir,non_adj/adj); attachment rules get
    (1 - P(STOP|h,dir,...)) * P(a|h,dir).  A STOP probability whose
    denominator is 0 becomes 0.0; an attachment rule whose CHOOSE
    denominator is 0 is left unchanged.  ROOT rules are not reestimated
    yet.  Returns None.

    todo: remove 0-prob rules?'''
    h = r.head()
    if r.LHS() == ROOT:
        return None  # not sure what todo yet here

    LHS, L, R = r.LHS(), r.L(), r.R()
    a = None  # the dependent's tag, for attachment rules
    # Classify by rule shape rather than by comparing heads: a right
    # attachment h -> h _a_ with a == h would otherwise be mistaken for
    # a left attachment.
    if L == STOP:
        direction = 'L'
    elif R == STOP:
        direction = 'R'
    elif LHS == R:  # left attachment:  h_ -> _a_ h_
        direction = 'L'
        a = head(L)
    elif LHS == L:  # right attachment: h -> h _a_
        direction = 'R'
        a = head(R)
    else:
        raise Exception("Odd rule in reestimation.")

    def ratio(kind):
        '''num/den for f[kind, ..., h], or den itself when den <= 0.'''
        den = f[kind, 'den', h]
        if den > 0.0:
            return f.get((kind, 'num', h), 0.0) / den
        return den

    p_stopN = ratio(direction + 'NSTOP')
    p_stopA = ratio(direction + 'ASTOP')

    if a is None:  # stop rules
        if 'reest' in DEBUG:
            print("p(STOP|%d=%s,%s,N): %.4f (was: %.4f)" % (h, g.numtag(h), direction, p_stopN, r.probN))
            print("p(STOP|%d=%s,%s,A): %.4f (was: %.4f)" % (h, g.numtag(h), direction, p_stopA, r.probA))
        r.probN = p_stopN
        r.probA = p_stopA
    else:  # attachment rules
        choose_den = f[direction + 'CHOOSE', 'den', h]
        if choose_den > 0.0:
            # (h, a) pairs never seen in the corpus have no numerator
            pchoose = f.get((direction + 'CHOOSE', 'num', h, a), 0.0) / choose_den
            r.probN = (1 - p_stopN) * pchoose
            r.probA = (1 - p_stopA) * pchoose
            if 'reest' in DEBUG:
                print("p(%d=%s|%d=%s,%s): %.4f,\tprobN: %.4f, probA: %.4f" % (
                    a, g.numtag(a), h, g.numtag(h), direction, pchoose, r.probN, r.probA))
##############################
#     testing functions:     #
##############################

# A small corpus of POS-tag sentences, one whitespace-separated string each.
testcorpus = [sentence.split() for sentence in (
    'det nn vbd c vbd', 'vbd nn c vbd',
    'det nn vbd', 'det nn vbd c pp',
    'det nn vbd', 'det vbd vbd c pp',
    'det nn vbd', 'det nn vbd c vbd',
    'det nn vbd', 'det nn vbd c vbd',
    'det nn vbd', 'det nn vbd c vbd',
    'det nn vbd', 'det nn vbd c pp',
    'det nn vbd pp', 'det nn vbd',
)]
def testgrammar():
    '''Return harmonic.initialize(testcorpus), reloading the harmonic
    module first so that edits to it are picked up.'''
    import harmonic
    reload(harmonic)
    grammar = harmonic.initialize(testcorpus)
    return grammar
def testreestimation():
    '''Reestimate the harmonic grammar on testcorpus and compare the
    P_STOP counts with previously recorded values.

    A line is printed for every recorded key that is missing from, or
    differs (beyond float rounding) from, the new counts.  Returns the
    count table from reestimate().'''
    DEBUG.add('reest')
    g = testgrammar()
    f = reestimate(g, testcorpus)
    f_stops = {('LNSTOP', 'den', 3): 12.212773236178391,
               ('RASTOP', 'den', 2): 4.0,
               ('RNSTOP', 'num', 4): 2.5553487221351365,
               ('LNSTOP', 'den', 2): 1.274904052793207,
               ('LASTOP', 'num', 1): 14.999999999999995,
               ('RASTOP', 'den', 3): 15.0,
               ('LASTOP', 'num', 4): 16.65701084787457,
               ('LASTOP', 'num', 0): 4.1600647714443468,
               ('LNSTOP', 'den', 4): 6.0170669155897105,
               ('LASTOP', 'num', 3): 2.7872267638216113,
               ('LASTOP', 'num', 2): 2.9723139990470515,
               ('LASTOP', 'den', 2): 4.0,
               ('RNSTOP', 'den', 3): 12.945787931730905,
               ('LASTOP', 'den', 3): 14.999999999999996,
               ('RNSTOP', 'den', 2): 0.0,
               ('LASTOP', 'den', 0): 8.0,
               ('RASTOP', 'num', 4): 19.44465127786486,
               ('RNSTOP', 'den', 1): 3.1966410324085777,
               ('LASTOP', 'den', 1): 14.999999999999995,
               ('RASTOP', 'num', 3): 4.1061665495365558,
               ('RNSTOP', 'den', 0): 4.8282499043902476,
               ('LNSTOP', 'num', 4): 5.3429891521254289,
               ('RASTOP', 'num', 2): 4.0,
               ('LASTOP', 'den', 4): 22.0,
               ('RASTOP', 'num', 1): 12.400273895299103,
               ('LNSTOP', 'num', 2): 1.0276860009529487,
               ('RASTOP', 'num', 0): 3.1717500956097533,
               ('LNSTOP', 'num', 3): 12.212773236178391,
               ('RASTOP', 'den', 4): 22.0,
               ('RNSTOP', 'den', 4): 2.8705211946979836,
               ('LNSTOP', 'num', 0): 3.8399352285556518,
               ('LNSTOP', 'num', 1): 0.0,
               ('RNSTOP', 'num', 0): 4.8282499043902476,
               ('RNSTOP', 'num', 1): 2.5997261047008959,
               ('LNSTOP', 'den', 1): 0.0,
               ('RASTOP', 'den', 0): 8.0,
               ('RNSTOP', 'num', 2): 0.0,
               ('LNSTOP', 'den', 0): 4.6540557322109795,
               ('RASTOP', 'den', 1): 15.0,
               ('RNSTOP', 'num', 3): 10.893833450463443}
    for k, v in f_stops.items():
        if k not in f:
            print("Regression!(?) Something changed in the P_STOP reestimation, "
                  "expected f[%s]=%.4f, but %s not in f" % (k, v, k))
        # compare with a relative tolerance: exact float equality is
        # too brittle for sums computed in a different order
        elif abs(f[k] - v) > 1e-9 * max(1.0, abs(v)):
            print("Regression!(?) Something changed in the P_STOP reestimation, "
                  "expected f[%s]=%.4f, got f[%s]=%.4f." % (k, v, k, f[k]))
    return f
def testgrammar_a():  # Non, Adj
    '''Two-tag toy grammar over h (= 0) and a (= 1), used by oa/ia/ca
    and regression_tests.  DMV_Rule arguments are
    (LHS, L, R, probN, probA).'''
    # rules headed by h
    h_lstop = DMV_Rule((SEAL, 0), STOP, (RGO_L, 0), 1.0, 1.0)             # LSTOP
    h_rstop = DMV_Rule((RGO_L, 0), (GO_R, 0), STOP, 0.4, 0.3)             # RSTOP
    h_left_h = DMV_Rule((RGO_L, 0), (SEAL, 0), (RGO_L, 0), 0.2, 0.1)      # Lattach
    h_left_a = DMV_Rule((RGO_L, 0), (SEAL, 1), (RGO_L, 0), 0.4, 0.6)      # Lattach to a
    h_right_h = DMV_Rule((GO_R, 0), (GO_R, 0), (SEAL, 0), 1.0, 1.0)       # Rattach
    h_right_a = DMV_Rule((GO_R, 0), (GO_R, 0), (SEAL, 1), 1.0, 1.0)       # Rattach to a
    root_h = DMV_Rule(ROOT, STOP, (SEAL, 0), 0.9, 0.9)                    # ROOT

    # rules headed by a
    a_lstop = DMV_Rule((SEAL, 1), STOP, (RGO_L, 1), 1.0, 1.0)             # LSTOP
    a_rstop = DMV_Rule((RGO_L, 1), (GO_R, 1), STOP, 0.4, 0.3)             # RSTOP
    a_left_a = DMV_Rule((RGO_L, 1), (SEAL, 1), (RGO_L, 1), 0.4, 0.6)      # Lattach
    a_left_h = DMV_Rule((RGO_L, 1), (SEAL, 0), (RGO_L, 1), 0.2, 0.1)      # Lattach to h
    a_right_a = DMV_Rule((GO_R, 1), (GO_R, 1), (SEAL, 1), 1.0, 1.0)       # Rattach
    a_right_h = DMV_Rule((GO_R, 1), (GO_R, 1), (SEAL, 0), 1.0, 1.0)       # Rattach to h
    root_a = DMV_Rule(ROOT, STOP, (SEAL, 1), 0.1, 0.1)                    # ROOT

    terminals = {((GO_R, 0), 'h'): 1.0,
                 ((GO_R, 1), 'a'): 1.0}

    rule_list = [h_left_a, h_right_a, a_left_h, a_right_h, root_a,
                 a_lstop, a_rstop, a_left_a, a_right_a,
                 root_h, h_lstop, h_rstop, h_left_h, h_right_h]
    return DMV_Grammar(rule_list, terminals, 0, 0, 0,
                       {0: 'h', 1: 'a'}, {'h': 0, 'a': 1})
def oa(s, t, LHS, loc_h):
    '''outer() of a node over the sentence "h a" with testgrammar_a(),
    using fresh charts so no cached values come in from other calls.'''
    return outer(s, t, LHS, loc_h, testgrammar_a(), 'h a'.split(), {}, {})
def ia(s, t, LHS, loc_h):
    '''inner() of a node over the sentence "h a" with testgrammar_a(),
    using a fresh chart so no cached values come in from other calls.'''
    return inner(s, t, LHS, loc_h, testgrammar_a(), 'h a'.split(), {})
def ca(s, t, LHS, loc_h):
    '''Expected count c() of a node over the sentence "h a" with
    testgrammar_a(), using fresh charts so no cached values come in
    from other calls.'''
    return c(s, t, LHS, loc_h, testgrammar_a(), 'h a'.split(), {}, {})
def testgrammar_h():  # Non, Adj
    '''One-tag toy grammar over h (= 0), used by regression_tests and
    testreestimation_h.  DMV_Rule arguments are
    (LHS, L, R, probN, probA).'''
    lstop = DMV_Rule((SEAL, 0), STOP, (RGO_L, 0), 1.0, 1.0)           # LSTOP
    rstop = DMV_Rule((RGO_L, 0), (GO_R, 0), STOP, 0.4, 0.3)           # RSTOP
    left = DMV_Rule((RGO_L, 0), (SEAL, 0), (RGO_L, 0), 0.6, 0.7)      # Lattach
    right = DMV_Rule((GO_R, 0), (GO_R, 0), (SEAL, 0), 1.0, 1.0)       # Rattach
    root = DMV_Rule(ROOT, STOP, (SEAL, 0), 1.0, 1.0)                  # ROOT

    terminals = {((GO_R, 0), 'h'): 1.0}

    return DMV_Grammar([root, lstop, rstop, left, right], terminals,
                       0, 0, 0, {0: 'h'}, {'h': 0})
def testreestimation_h():
    '''Reestimate testgrammar_h() on the single sentence "h h h", with
    'reest' debug output switched on.'''
    DEBUG.add('reest')
    grammar = testgrammar_h()
    corpus = ['h h h'.split()]
    reestimate(grammar, corpus)
def regression_tests():
    '''Recompute a few known inner()/outer() values on the toy grammars
    and print a "Regression!" line for each that no longer matches.
    Values are compared as formatted strings.'''
    def test(wanted, got):
        if wanted != got:
            print("Regression! Should be %s: %s" % (wanted, got))

    g_dup = testgrammar_h()

    got = "%.3f" % inner(0, 1, (SEAL, 0), 0, g_dup, 'h h'.split(), {})
    test("0.120", got)

    got = "%.3f" % inner(0, 1, (SEAL, 0), 1, g_dup, 'h h'.split(), {})
    test("0.063", got)

    got = "%.4f" % inner(0, 2, (SEAL, 0), 2, g_dup, 'h h h'.split(), {})
    test("0.0498", got)

    got = "%.2f" % outer(1, 2, (1, 0), 2, testgrammar_h(), 'h h h'.split(), {}, {})
    test("0.58", got)

    got = "%.4f" % outer(0, 0, (0, 0), 0, testgrammar_a(), 'h a'.split(), {}, {})
    test("0.1089", got)
if __name__ == "__main__":
    DEBUG.clear()

    # import profile
    # profile.run('testreestimation()')

    # DEBUG.add('reest_attach')
    import timeit
    # Time *this* module's reestimation test (it previously imported
    # and timed the separate `dmv` module by mistake).
    print(timeit.Timer("cnf_dmv.testreestimation()",
                       "import cnf_dmv\nreload(cnf_dmv)").timeit(1))
    print("TODO: P_CHOOSE can not be right... (certain p's > 1.0)")

    regression_tests()