inner() and outer() done for cnf_dmv.py; they seem to work. TODO: reestimation for CNF
[dmvccm.git] / src / cnf_dmv.py
blob b0d68aeb2474116db23f1894b1ff9c017124c1a5
1 # cnf_dmv.py
3 #import numpy # numpy provides Fast Arrays, for future optimization
4 import io
6 from common_dmv import *
# The seal types used by this CNF version of the DMV; CNF_DMV_Grammar and
# CNF_DMV_Rule iterate over / validate against this list.
# NOTE(review): "overwriting here" presumably means this shadows a SEALS
# brought in by the star import from common_dmv -- confirm.
SEALS = [GOR, RGOL, SEAL, NGOR, NRGOL] # overwriting here

if __name__ == "__main__":
    print "cnf_dmv module tests:"
def make_GO_AT(p_STOP,p_ATTACH):
    '''Combine attachment and not-stopping probabilities.

    For every key (a, h, dir) of p_ATTACH and for both adjacency values
    NON and ADJ, the returned dict maps (a, h, dir, adj) to
        p_ATTACH[a, h, dir] * (1 - p_STOP[h, dir, adj])
    Raises KeyError if p_STOP lacks (h, dir, NON) or (h, dir, ADJ).
    '''
    go_attach = {}
    for key in p_ATTACH:
        arg, head, direction = key
        p_attach = p_ATTACH[key]
        for adjacency in (NON, ADJ):
            p_go = 1 - p_STOP[head, direction, adjacency]
            go_attach[arg, head, direction, adjacency] = p_attach * p_go
    return go_attach
class CNF_DMV_Grammar(io.Grammar):
    '''The DMV-PCFG.

    Public members:
    p_STOP, p_ROOT, p_ATTACH, p_terminals
    These are changed in the Maximization step, then used to set the
    new probabilities of each CNF_DMV_Rule.

    p_GO_AT is computed once, in __init__, by make_GO_AT(p_STOP,
    p_ATTACH); changing p_STOP or p_ATTACH afterwards does not update it.
    It is a plain dict, so entries may also be overwritten by hand.

    No rules are handed to io.Grammar: rules() builds fresh CNF_DMV_Rule
    objects on every call.  So something like
        for r in g.all_rules():
            r.prob = (1-p_STOP[...]) * p_ATTACH[...]
    only changes those particular rule objects; to change the grammar,
    change the p_* tables.
    '''
    def __str__(self):
        '''One line per rule of all_rules(), tags shown via numtag.'''
        return "".join(["%s\n" % r.__str__(self.numtag)
                        for r in self.all_rules()])

    def all_rules(self):
        '''Return the rules for ROOT and for (s_h, h), for every seal s_h in
        SEALS and every head h in headnums(); arguments also range over
        headnums().'''
        # list() so that an iterator from headnums() is not exhausted by
        # the first use
        heads = list(self.headnums())
        LHSs = [ROOT] + [(s_h, h) for h in heads for s_h in SEALS]
        return [r for LHS in LHSs for r in self.rules(LHS, heads)]

    def _mothers(self, argnums, is_mother):
        '''Rules r, over ROOT and every sealed node of a head in argnums,
        for which is_mother(r) is true.'''
        # argnums is used once per LHS below; don't let an iterator run dry
        argnums = list(argnums)
        LHSs = [ROOT] + [(s_h, h) for h in set(argnums) for s_h in SEALS]
        return [r for LHS in LHSs
                for r in self.rules(LHS, argnums)
                if is_mother(r)]

    # used in outer:
    def mothersR(self, w_node, argnums):
        '''For all LHS and x, return all rules of the form 'LHS->x w_node'.'''
        return self._mothers(argnums, lambda r: r.R() == w_node)

    def mothersL(self, w_node, argnums):
        '''For all LHS and x, return all rules of the form 'LHS->w_node x'.'''
        return self._mothers(argnums, lambda r: r.L() == w_node)

    # used in inner:
    def rules(self, LHS, argnums):
        '''This is where the CNF grammar is defined (on the fly, so it's
        probably slow, but hey, this is for testing).  Return the rules
        rewriting LHS, with argument heads drawn from argnums:

          ROOT      -> STOP (SEAL,h)                    p_ROOT[h]
          (GOR,h)   -> (none; only terminals)
          (RGOL,h)  -> (GOR,h) STOP  | (NGOR,h) STOP     p_STOP[h,RIGHT,adj]
          (NGOR,h)  -> (GOR,h) (SEAL,a) | (NGOR,h) (SEAL,a)
                                                        p_GO_AT[a,h,RIGHT,adj]
          (SEAL,h)  -> STOP (RGOL,h) | STOP (NRGOL,h)    p_STOP[h,LEFT,adj]
          (NRGOL,h) -> (SEAL,a) (RGOL,h) | (SEAL,a) (NRGOL,h)
                                                        p_GO_AT[a,h,LEFT,adj]

        adj is True for the first alternative (GOR/RGOL head daughter) and
        False for the second.  Attachment rules exist only for a in
        argnums with (a, h, dir) in p_ATTACH.
        NOTE(review): p_GO_AT is built with NON/ADJ keys but looked up with
        True/False here, so presumably ADJ == True and NON == False in
        common_dmv -- confirm.

        Also, s_dir_type shows how useful it'd be to split up the seals
        into direction and type... todo?'''
        if LHS == ROOT:
            return [CNF_DMV_Rule(LEFT, LHS, (SEAL, h), STOP, self.p_ROOT[h])
                    for h in set(argnums)]
        h = POS(LHS)
        s_h = seals(LHS)
        if s_h == GOR:
            return [] # only terminals from here on
        s_dir_type = { # seal of LHS -> (direction, rule type)
            RGOL: (RIGHT, 'STOP'), NGOR: (RIGHT, 'ATTACH'),
            SEAL: (LEFT, 'STOP'), NRGOL: (LEFT, 'ATTACH') }
        dir_s_adj = { # direction -> [(seal of h_daughter, adjacency)]
            RIGHT: [(GOR, True), (NGOR, False)],
            LEFT: [(RGOL, True), (NRGOL, False)] }
        direction, kind = s_dir_type[s_h]
        # Build only the list we return: building both (as before) doubled
        # the work and could raise KeyError from the discarded one.
        if kind == 'ATTACH':
            return [CNF_DMV_Rule(direction, LHS, (s, h), (SEAL, a),
                                 self.p_GO_AT[a, h, direction, adj])
                    for a in set(argnums) if (a, h, direction) in self.p_ATTACH
                    for s, adj in dir_s_adj[direction]]
        return [CNF_DMV_Rule(direction, LHS, (s, h), STOP,
                             self.p_STOP[h, direction, adj])
                for s, adj in dir_s_adj[direction]]

    def __init__(self, numtag, tagnum, p_ROOT, p_STOP, p_ATTACH, p_terminals):
        '''numtag, tagnum and p_terminals are handed to io.Grammar (with an
        empty rule list); p_GO_AT is derived from p_STOP and p_ATTACH.'''
        io.Grammar.__init__(self, numtag, tagnum, [], p_terminals)
        self.p_STOP = p_STOP
        self.p_ATTACH = p_ATTACH
        self.p_ROOT = p_ROOT
        self.p_GO_AT = make_GO_AT(self.p_STOP, self.p_ATTACH)
class CNF_DMV_Rule(io.CNF_Rule):
    '''A single CNF rule in the PCFG, of the form
    LHS -> L R
    where LHS, L and R are 'nodes', eg. of the form (seals, head).

    Public members:
    prob

    Private members:
    __L, __R, __LHS, and __dir (LEFT or RIGHT, set here)

    Different rule-types have different probabilities associated with
    them, see formulas.pdf
    '''
    def seals(self):
        '''Seal type of the LHS node.'''
        return seals(self.LHS())

    def POS(self):
        '''Head (POS number) of the LHS node.'''
        return POS(self.LHS())

    def __init__(self, dir, LHS, h_daughter, a_daughter, prob):
        '''dir (LEFT or RIGHT) says on which side of h_daughter the
        a_daughter (an argument, or STOP) goes: LEFT gives
        LHS -> a_daughter h_daughter, RIGHT gives LHS -> h_daughter a_daughter.

        Raises ValueError if dir is neither LEFT nor RIGHT, or if LHS, L or
        R has a seal not in SEALS.'''
        self.__dir = dir
        if dir == LEFT:
            L, R = a_daughter, h_daughter
        elif dir == RIGHT:
            L, R = h_daughter, a_daughter
        else:
            # (dir,) so that a tuple-valued dir is reported, not a TypeError
            raise ValueError("dir must be LEFT or RIGHT, given: %s" % (dir,))
        for b_h in (LHS, L, R):
            if seals(b_h) not in SEALS:
                raise ValueError("seals must be in %s; was given: %s"
                                 % (SEALS, seals(b_h)))
        io.CNF_Rule.__init__(self, LHS, L, R, prob)

    def adj(self):
        '''True when the head daughter has the seal the grammar uses for the
        adjacent case: RGOL (right daughter) for LEFT rules, GOR (left
        daughter) for RIGHT rules.  'undefined' (meaningless) for ROOT.'''
        if self.__dir == LEFT:
            return seals(self.R()) == RGOL
        else: # RIGHT
            return seals(self.L()) == GOR

    def __str__(self, tag=lambda x:x):
        '''"LHS --> L R\t[prob] adj|non_adj", nodes shown via node_str(_, tag);
        ROOT rules get no adjacency label.'''
        if self.LHS() == ROOT:
            adj_str = ""
        elif self.adj():
            adj_str = "adj"
        else:
            adj_str = "non_adj"
        return "%s --> %s %s\t[%.2f] %s" % (node_str(self.LHS(), tag),
                                            node_str(self.L(), tag),
                                            node_str(self.R(), tag),
                                            self.prob,
                                            adj_str)
157 ###################################
158 # dmv-specific version of inner() #
159 ###################################
def inner(i, j, LHS, g, sent, ichart=None):
    ''' A CNF rewrite of io.inner(), to take STOP rules into accord.

    Return the inside probability of the node LHS spanning sent[i:j]
    under the grammar g.  Spans are half-open: the single word at
    position i is the span (i, i+1).

    ichart is a memo dict keyed on (i, j, LHS) and filled in place (word
    level terminal probabilities are not stored in it).  Its keys name
    neither the sentence nor the grammar, so a chart must only be shared
    between calls on the same g and sent.  When ichart is omitted a fresh
    dict is used; the old mutable default {} was shared between all calls
    and leaked values across sentences.
    '''
    if ichart is None:
        ichart = {}

    def O(i,j):
        return sent[i]

    sent_nums = g.sent_nums(sent)

    def e(i,j,LHS, n_t):
        '''Inside probability of LHS over (i, j); n_t is the recursion
        depth, only used to indent debug output.'''
        def tab():
            "Tabs for debug output"
            return "\t"*n_t
        if (i, j, LHS) in ichart:
            if 'INNER' in DEBUG:
                print("%s*= %.4f in ichart: i:%d j:%d LHS:%s" % (tab(), ichart[i, j, LHS], i, j, node_str(LHS)))
            return ichart[i, j, LHS]

        # a single word is only rewritten directly by a GOR node
        if i == j-1 and seals(LHS) == GOR:
            if (LHS, O(i,j)) in g.p_terminals:
                prob = g.p_terminals[LHS, O(i,j)] # "b[LHS, O(s)]" in Lari&Young
            else:
                prob = 0.0
                if 'INNER' in DEBUG:
                    print("%sLACKING TERMINAL:" % tab())
            if 'INNER' in DEBUG:
                print("%s*= %.4f (terminal: %s -> %s)" % (tab(), prob, node_str(LHS), O(i,j)))
            return prob

        p = 0.0 # "sum over j,k in a[LHS,j,k]"
        for rule in g.rules(LHS, sent_nums):
            if 'INNER' in DEBUG:
                print("%ssumming rule %s i:%d j:%d" % (tab(), rule, i, j))
            L = rule.L()
            R = rule.R()
            if L == STOP or R == STOP:
                # a STOP rule rewrites LHS as its other daughter over the
                # same span
                pLR = e(i, j, R if L == STOP else L, n_t+1)
                p += rule.p() * pLR
                if 'INNER' in DEBUG:
                    print("%sp= %.4f (STOP)" % (tab(), p))
            elif j > i+1 and seals(LHS) != GOR:
                # not a STOP, attachment rewrite: split the span at i<k<j
                for k in xtween(i, j):
                    p_L = e(i, k, L, n_t+1)
                    p_R = e(k, j, R, n_t+1)
                    p += rule.p() * p_L * p_R
                    if 'INNER' in DEBUG:
                        print("%sp= %.4f (ATTACH, p_L:%.4f, p_R:%.4f, rule:%.4f)" % (tab(), p, p_L, p_R, rule.p()))
        ichart[i, j, LHS] = p
        return p
    # end of e-function

    inner_prob = e(i,j,LHS, 0)
    if 'INNER' in DEBUG:
        print(debug_ichart(g,sent,ichart))
    return inner_prob
# end of cnf_dmv.inner(i, j, LHS, g, sent, ichart=None)
def debug_ichart(g,sent,ichart):
    '''Render an inside chart as text for debugging.

    One line "LHS -> first_word ... last_word: \tprob" per (i, j, LHS)
    entry, between "---ICHART:---" and "---ICHART:end---" marker lines.
    Entries whose value is a dict (the 'tree' entries) are skipped.
    '''
    lines = ["---ICHART:---\n"]
    for (i, j, LHS), value in ichart.items():
        if type(value) is dict: # skip 'tree'
            continue
        lines.append("%s -> %s ... %s: \t%.4f\n"
                     % (node_str(LHS, g.numtag), sent[i], sent[j-1], value))
    lines.append("---ICHART:end---\n")
    return "".join(lines)
def inner_sent(g, sent, ichart=None):
    '''Return the probability of sent under g: the inside probability of
    ROOT over the whole sentence, (0, len(sent)).

    ichart is filled in place; a fresh chart is used when it is omitted
    (the old mutable default {} was shared between calls).'''
    if ichart is None:
        ichart = {}
    return inner(0, len(sent), ROOT, g, sent, ichart)
def c(i,j,LHS,g,sent,ichart=None,ochart=None):
    '''Expected count of LHS spanning sent[i:j]:
    inner * outer / P(sent), or P(sent) itself when P(sent) is not
    positive.

    ichart / ochart are filled in place; fresh charts are used when they
    are omitted (the old mutable defaults were shared between calls).'''
    if ichart is None:
        ichart = {}
    if ochart is None:
        ochart = {}
    p_sent = inner_sent(g, sent, ichart)
    p_in = inner(i,j,LHS,g,sent,ichart)
    p_out = outer(i,j,LHS,g,sent,ichart,ochart)
    if p_sent > 0.0:
        return p_in * p_out / p_sent
    return p_sent
247 #######################################
248 # cnf_dmv-specific version of outer() # todo below
249 #######################################
def outer(i,j,w_node, g, sent, ichart=None, ochart=None):
    '''Return the outside probability of w_node spanning sent[i:j] under g.

    Spans are half-open, as in inner().  ROOT has outside probability 1.0
    over the whole sentence and 0.0 elsewhere.  Any other node sums, over
    every rule in which it is a daughter:
      - STOP rules:               outer(mother, i, j) * p(rule)
      - w_node as left daughter:  outer(mother, i, k) * p(rule)
                                  * inner(right sister, j, k), k in xgt(j, sent)
      - w_node as right daughter: inner(left sister, k, i) * p(rule)
                                  * outer(mother, k, j),       k in xlt(i)

    ichart / ochart are memo dicts keyed on (i, j, node), filled in place;
    fresh ones are used when omitted (the old mutable defaults were shared
    between calls).
    '''
    if ichart is None:
        ichart = {}
    if ochart is None:
        ochart = {}

    def e(i,j,LHS):
        # or we could just look it up in ichart, assuming ichart to be done
        return inner(i, j, LHS, g, sent, ichart)

    sent_nums = g.sent_nums(sent)

    def f(i,j,w_node):
        if (i,j,w_node) in ochart:
            return ochart[(i, j, w_node)]
        if w_node == ROOT:
            if i == 0 and j == len(sent):
                return 1.0
            else: # ROOT may only be used on full sentence
                return 0.0 # but we may have non-ROOTs over full sentence too

        p = 0.0

        for rule in g.mothersL(w_node, sent_nums): # rule.L() == w_node
            if 'OUTER' in DEBUG:
                print("w_node:%s (L) ; %s" % (node_str(w_node), rule))
            if rule.R() == STOP:
                p0 = f(i,j,rule.LHS()) * rule.p()
                if 'OUTER' in DEBUG:
                    print(p0)
                p += p0
            else:
                # right sister spans (j, k), the mother spans (i, k); this
                # used to read f(i,k,LHS) * p * e(t+1,k,R) with undefined
                # names (left over from inclusive loc_h spans)
                for k in xgt(j, sent):
                    p0 = f(i, k, rule.LHS()) * rule.p() * e(j, k, rule.R())
                    if 'OUTER' in DEBUG:
                        print(p0)
                    p += p0

        for rule in g.mothersR(w_node, sent_nums): # rule.R() == w_node
            if 'OUTER' in DEBUG:
                print("w_node:%s (R) ; %s" % (node_str(w_node), rule))
            if rule.L() == STOP:
                p0 = f(i,j,rule.LHS()) * rule.p()
                if 'OUTER' in DEBUG:
                    print(p0)
                p += p0
            else:
                # left sister spans (k, i), the mother spans (k, j)
                for k in xlt(i):
                    p0 = e(k,i,rule.L()) * rule.p() * f(k,j,rule.LHS())
                    if 'OUTER' in DEBUG:
                        print(p0)
                    p += p0

        ochart[i,j,w_node] = p
        return p

    return f(i,j,w_node)
# end outer(i,j,w_node, g,sent, ichart,ochart)
301 ##############################
302 # reestimation, todo: #
303 ##############################
def reest_zeros(h_nums):
    '''Return a fresh counts dict with zeroed entries for each head in h_nums.

    Keys created: (stop, 'num'|'den', h) for stop in LNSTOP, LASTOP,
    RNSTOP, RASTOP, and (choose, 'den', h) for choose in RCHOOSE, LCHOOSE.
    The CHOOSE numerators are keyed per argument, (choose, 'num', h, a),
    and are not created here.
    '''
    # todo: p_ROOT? ... p_terminals?
    stops = ('LNSTOP', 'LASTOP', 'RNSTOP', 'RASTOP')
    chooses = ('RCHOOSE', 'LCHOOSE')
    zeros = {}
    for h in h_nums:
        zeros.update(((stop, nd, h), 0.0)
                     for stop in stops for nd in ('num', 'den'))
        zeros.update(((choice, 'den', h), 0.0) for choice in chooses)
    return zeros
def reest_freq(g, corpus):
    ''' P_STOP(-STOP|...) = 1 - P_STOP(STOP|...)

    Collect expected counts for reestimating g from corpus (a sequence of
    sentences, each passed to g.sent_nums() and inner_sent()).  Returns
    the dict started by reest_zeros(g.headnums()), with
      f[stop, 'num'|'den', h]          stop in LNSTOP LASTOP RNSTOP RASTOP
      f['RCHOOSE'|'LCHOOSE', 'den', h]
      f['RCHOOSE'|'LCHOOSE', 'num', h, a]   created on demand

    NOTE(review): work in progress, ported from the loc_h version of the
    DMV; as written it cannot run against this module's inner()/outer():
      - c_g() and f_g() call outer() with 8 positional arguments (an extra
        loc_h), while outer() here takes at most 7 -> TypeError;
      - e_g() calls inner() with 7 positional arguments, while inner()
        here takes at most 6 -> TypeError;
      - the chart lookups use 4-tuple keys (s, t, LHS, loc_h), but
        inner()/outer() here store 3-tuple keys (i, j, LHS), so the
        lookups never hit;
      - the spans look inclusive (e.g. c_g(loc_h, loc_h, ...) for a single
        word), whereas inner()/outer() here use half-open spans (the word
        at i is (i, i+1)) -- TODO convert;
      - p_g() calls rule.p_ATTACH(), which CNF_DMV_Rule does not define
        (unless io.CNF_Rule does -- TODO confirm).
    NOTE(review): the nesting below was reconstructed from a listing with
    flattened indentation -- confirm against the repository.
    '''
    f = reest_zeros(g.headnums())
    ichart = {}
    ochart = {}

    # p_sent, ichart and ochart are rebound per sentence in the loop below;
    # the nested helpers read the current bindings through the closure.
    p_sent = None # 50 % speed increase on storing this locally
    def c_g(s,t,LHS,loc_h,sent): # altogether 2x faster than the global c()
        # expected count: inner * outer / P(sent), or P(sent) if not > 0
        if (s,t,LHS,loc_h) in ichart:
            p_in = ichart[s,t,LHS,loc_h]
        else:
            p_in = inner(s,t,LHS,g,sent,ichart)
        if (s,t,LHS,loc_h) in ochart:
            p_out = ochart[s,t,LHS,loc_h]
        else:
            p_out = outer(s,t,LHS,loc_h,g,sent,ichart,ochart)

        if p_sent > 0.0:
            return p_in * p_out / p_sent
        else:
            return p_sent

    def f_g(s,t,LHS,loc_h,sent): # todo: test with choose rules
        # outside probability (see NOTE above about outer()'s arity)
        if (s,t,LHS,loc_h) in ochart:
            return ochart[s,t,LHS,loc_h]
        else:
            return outer(s,t,LHS,loc_h,g,sent,ichart,ochart)

    def e_g(s,t,LHS,loc_h,sent): # todo: test with choose rules
        # inside probability (see NOTE above about inner()'s arity)
        if (s,t,LHS,loc_h) in ichart:
            return ichart[s,t,LHS,loc_h]
        else:
            return inner(s,t,LHS,loc_h,g,sent,ichart)

    def p_g(r,LHS,L,R,loc_h,sent):
        # the unique rule LHS -> L R; callers pass sent_nums as `sent`
        rules = [rule for rule in g.rules(LHS, sent)
                      if rule.L() == L and rule.R() == R]
        rule = rules[0]
        if len(rules) > 1:
            raise Exception("Several rules matching a[i,j,k]")
        return rule.p_ATTACH(r,loc_h)

    for sent in corpus:
        if 'reest' in DEBUG:
            print sent
        ichart = {}
        ochart = {}
        p_sent = inner_sent(g, sent, ichart)

        sent_nums = g.sent_nums(sent)
        # todo: use sum([ichart[s, t...] etc? but can we then
        # keep den and num separate within _one_ sum()-call?
        for loc_h,h in enumerate(sent_nums):
            for t in xrange(loc_h, len(sent)):
                for s in xrange(loc_h): # s<loc(h), xrange gives strictly less
                    # left non-adjacent stop:
                    f['LNSTOP','num',h] += c_g(s, t, (SEAL, h), loc_h,sent)
                    f['LNSTOP','den',h] += c_g(s, t, (RGOL,h), loc_h,sent)
                # left adjacent stop:
                f['LASTOP','num',h] += c_g(loc_h, t, (SEAL, h), loc_h,sent)
                f['LASTOP','den',h] += c_g(loc_h, t, (RGOL,h), loc_h,sent)
            for t in xrange(loc_h+1, len(sent)):
                # right non-adjacent stop:
                f['RNSTOP','num',h] += c_g(loc_h, t, (RGOL,h), loc_h,sent)
                f['RNSTOP','den',h] += c_g(loc_h, t, (GOR, h), loc_h,sent)
            # right adjacent stop:
            f['RASTOP','num',h] += c_g(loc_h, loc_h, (RGOL,h), loc_h,sent)
            f['RASTOP','den',h] += c_g(loc_h, loc_h, (GOR, h), loc_h,sent)

            # right attachment: TODO: try with p*e*e*f instead of c, for numerator
            if 'reest_attach' in DEBUG:
                print "Rattach %s: for t in %s"%(g.numtag(h),sent[loc_h+1:len(sent)])
            for t in xrange(loc_h+1, len(sent)):
                cM = c_g(loc_h,t,(GOR, h), loc_h, sent)
                f['RCHOOSE','den',h] += cM
                if 'reest_attach' in DEBUG:
                    print "\tc_g( %d , %d, %s, %s, sent)=%.4f"%(loc_h,t,g.numtag(h),loc_h,cM)
                for r in xrange(loc_h+1, t+1): # loc_h < r <= t
                    c_L = c_g(loc_h, r-1, (GOR, h), loc_h, sent)
                    if 'reest_attach' in DEBUG:
                        print "\t\tc_g( %d , %d, %s, %d, sent)=%.4f"%(loc_h,r-1,g.numtag(h),loc_h,c_L)
                    for i,a in enumerate(sent_nums[r:t+1]):
                        loc_a = i+r
                        c_R = c_g(r, t, (SEAL, a), loc_a, sent)
                        if ('RCHOOSE','num',h,a) not in f:
                            f['RCHOOSE','num',h,a] = 0.0
                        # NOTE(review): ZeroDivisionError if c_L == 0.0
                        f['RCHOOSE','num',h,a] += c_R / c_L
                        if 'reest_attach' in DEBUG:
                            print "\t\t\tc_g( %d , %d, _%s_, %d, sent)=%.4f, /c_L = %.4f"%(r,t,g.numtag(a),loc_a,c_R,c_R/c_L)

            # left attachment:
            if 'reest_attach' in DEBUG:
                print "Lattach %s: for s in %s"%(g.numtag(h),sent[0:loc_h])
            for s in xrange(0, loc_h):
                if 'reest_attach' in DEBUG:
                    print "\tfor t in %s"%sent[loc_h:len(sent)]
                for t in xrange(loc_h, len(sent)):
                    c_M = c_g(s,t,(RGOL, h), loc_h, sent) # v_q in L&Y
                    f['LCHOOSE','den',h] += c_M
                    if 'reest_attach' in DEBUG:
                        print "\t\tc_g( %d , %d, %s_, %s, sent)=%.4f"%(s,t,g.numtag(h),loc_h,c_M)
                    if 'reest_attach' in DEBUG:
                        print "\t\tfor r in %s"%(sent[s:loc_h])
                    args = {} # for summing w_q's in L&Y, without 1/P_q
                    for r in xrange(s, loc_h): # s <= r < loc_h <= t
                        e_R = e_g(r+1, t, (RGOL, h), loc_h, sent)
                        if 'reest_attach' in DEBUG:
                            print "\t\tc_g( %d , %d, %s_, %d, sent)=%.4f"%(r+1,t,g.numtag(h),loc_h,e_R)
                        for i,a in enumerate(sent_nums[s:r+1]):
                            loc_a = i+s
                            e_L = e_g( s , r, (SEAL, a), loc_a, sent)
                            if a not in args:
                                args[a] = 0.0
                            args[a] += e_L * e_R * f_g(s,t,(RGOL, h), loc_h, sent) * p_g(r,(RGOL, h),(SEAL, a),(RGOL, h),loc_h,sent_nums)
                    # NOTE(review): '=' overwrites the value from earlier
                    # (s, t) / sentences instead of accumulating -- confirm
                    # intended; ZeroDivisionError if p_sent == 0.0
                    for a,sum_a in args.iteritems():
                        f['LCHOOSE', 'num',h,a] = sum_a / p_sent
    return f
def reestimate(g, corpus):
    '''Collect expected counts over corpus with reest_freq(), run every
    rule of g.all_rules() through reest_rule(), and return the counts.

    NOTE(review): if g is a CNF_DMV_Grammar, all_rules() builds new rule
    objects on each call, so the values reest_rule() assigns to them are
    not kept by g.
    '''
    counts = reest_freq(g, corpus)
    # todo: we want to go through only non-ROOT left-STOPs..
    for rule in g.all_rules():
        reest_rule(rule, counts, g)
    return counts
def reest_rule(r,f, g): # g just for numtag / debug output, remove eventually?
    '''Set r.probN and r.probA from counts f as returned by reest_freq().

    "remove 0-prob rules? todo"

    ROOT rules are left alone (returns None).  For the other rules, with h
    the head and dir ('L' or 'R') the side the rule rewrites on:
      P(STOP|h,dir,N) = f[dir+'NSTOP','num',h] / f[dir+'NSTOP','den',h]
      P(STOP|h,dir,A) = f[dir+'ASTOP','num',h] / f[dir+'ASTOP','den',h]
    (the denominator itself, i.e. 0.0, when it is not positive).
    STOP rules get probN/probA = P(STOP|h,dir,N/A).  Attachment rules with
    argument a get probN/probA = (1 - P(STOP|h,dir,N/A)) * P(a|h,dir), where
    P(a|h,dir) = f[dir+'CHOOSE','num',h,a] / f[dir+'CHOOSE','den',h],
    or 0.0 if the denominator is not positive or (h, a) was never counted.

    NOTE(review): CNF_DMV_Rule itself only uses .prob; probN/probA are
    extra attributes set here (the loc_h rules had them).

    Raises Exception("Odd rule in reestimation.") if the LHS seal is not
    one that CNF_DMV_Grammar.rules() rewrites.
    '''
    if r.LHS() == ROOT:
        return None # not sure what todo yet here
    h = r.POS()
    # The direction is fixed by the LHS seal (see CNF_DMV_Grammar.rules).
    # Testing POS(daughter) == h instead mislabelled right attachments of an
    # argument with the head's own POS as left ones.
    seal_dir = {RGOL: 'R', NGOR: 'R', SEAL: 'L', NRGOL: 'L'}
    s_h = seals(r.LHS())
    if s_h not in seal_dir:
        raise Exception("Odd rule in reestimation.")
    direction = seal_dir[s_h]

    p_stopN = f[direction+'NSTOP','den',h]
    if p_stopN > 0.0:
        p_stopN = f[direction+'NSTOP','num',h] / p_stopN

    p_stopA = f[direction+'ASTOP','den',h]
    if p_stopA > 0.0:
        p_stopA = f[direction+'ASTOP','num',h] / p_stopA

    if r.L() == STOP or r.R() == STOP: # stop rules
        if 'reest' in DEBUG:
            # probN/probA do not exist until first set here; fall back to prob
            print("p(STOP|%d=%s,%s,N): %.4f (was: %.4f)" % (h, g.numtag(h), direction, p_stopN, getattr(r, 'probN', r.prob)))
            print("p(STOP|%d=%s,%s,A): %.4f (was: %.4f)" % (h, g.numtag(h), direction, p_stopA, getattr(r, 'probA', r.prob)))
        r.probN = p_stopN
        r.probA = p_stopA

    else: # attachment rules
        # the argument (SEAL,a) is the left daughter of a left attachment
        # and the right daughter of a right attachment; computed up front so
        # the debug output below never sees an unbound `a`
        if direction == 'L':
            a = POS(r.L())
        else:
            a = POS(r.R())
        pchoose = f[direction+'CHOOSE','den',h]
        if pchoose > 0.0:
            # num entries only exist for (h, a) pairs seen in the corpus
            pchoose = f.get((direction+'CHOOSE','num',h,a), 0.0) / pchoose
        r.probN = (1-p_stopN) * pchoose
        r.probA = (1-p_stopA) * pchoose
        if 'reest' in DEBUG:
            print("p(%d=%s|%d=%s,%s): %.4f,\tprobN: %.4f, probA: %.4f" % (a, g.numtag(a), h, g.numtag(h), direction, pchoose, r.probN, r.probA))
488 ##############################
489 # testing functions: #
490 ##############################
def testgrammar():
    '''Return cnf_harmonic's initial grammar for loc_h_dmv's testcorpus,
    reloading cnf_harmonic first so changes to its source are picked up.'''
    # make sure we use the same data:
    from loc_h_dmv import testcorpus
    import cnf_harmonic
    return reload(cnf_harmonic).initialize(testcorpus)
def testreestimation():
    '''Reestimate testgrammar() on loc_h_dmv's testcorpus; return the counts.'''
    # testcorpus is only bound locally inside testgrammar(), so use the same
    # corpus explicitly here instead of relying on a module-level name
    from loc_h_dmv import testcorpus
    g = testgrammar()
    return reestimate(g, testcorpus)
def testgrammar_a(): # Non, Adj
    '''Two-tag test grammar, h=0 and a=1, built from probability tables.

    The previous version built loc_h-style rules, CNF_DMV_Rule(LHS, L, R,
    probN, probA).  That does not match CNF_DMV_Rule(dir, LHS, h_daughter,
    a_daughter, prob) and failed on the first call.  It also passed None
    for p_ROOT/p_STOP/p_ATTACH.  The same probabilities, as
    (non-adjacent, adjacent):
      ROOT -> h 0.9, ROOT -> a 0.1
      left STOP 1.0/1.0 and right STOP 0.4/0.3, for both heads
      left attach (arg, head): (h,h) 0.2/0.1, (a,h) 0.4/0.6,
                               (a,a) 0.4/0.6, (h,a) 0.2/0.1
      right attach 1.0/1.0 for every (arg, head)
    '''
    h, a = 0, 1
    p_ROOT = {h: 0.9, a: 0.1}
    p_STOP, p_ATTACH = {}, {}
    for head in (h, a):
        p_STOP[head, LEFT, NON] = 1.0
        p_STOP[head, LEFT, ADJ] = 1.0
        p_STOP[head, RIGHT, NON] = 0.4
        p_STOP[head, RIGHT, ADJ] = 0.3
        for arg in (h, a):
            for direction in (LEFT, RIGHT):
                # value unused (p_GO_AT is set below), but rules() only
                # generates attachments whose key is in p_ATTACH
                p_ATTACH[arg, head, direction] = 1.0
    p_terminals = {}
    p_terminals[(GOR, h), 'h'] = 1.0
    p_terminals[(GOR, a), 'a'] = 1.0

    g = CNF_DMV_Grammar({h:'h', a:'a'}, {'h':h, 'a':a},
                        p_ROOT, p_STOP, p_ATTACH, p_terminals)

    # these probabilities are impossible given p_STOP/p_ATTACH,
    # so add them manually...
    left = {(h, h): (0.2, 0.1), (a, h): (0.4, 0.6),
            (a, a): (0.4, 0.6), (h, a): (0.2, 0.1)}
    for (arg, head), (p_non, p_adj) in left.items():
        g.p_GO_AT[arg, head, LEFT, NON] = p_non
        g.p_GO_AT[arg, head, LEFT, ADJ] = p_adj
        g.p_GO_AT[arg, head, RIGHT, NON] = 1.0
        g.p_GO_AT[arg, head, RIGHT, ADJ] = 1.0
    return g
def testgrammar_h():
    '''One-tag test grammar (h=0) used by regression_tests().

    p_GO_AT is overwritten after construction with values that p_STOP and
    p_ATTACH could not produce.
    '''
    h = 0
    p_ROOT = {h: 1.0}
    p_STOP = {(h, LEFT, NON): 1.0,
              (h, LEFT, ADJ): 1.0,
              (h, RIGHT, NON): 0.4,
              (h, RIGHT, ADJ): 0.3}
    # values not used (p_GO_AT is set below); the keys make rules()
    # generate the h<-h attachments
    p_ATTACH = {(h, h, LEFT): 1.0,
                (h, h, RIGHT): 1.0}
    p_terminals = {((GOR, 0), 'h'): 1.0}

    g = CNF_DMV_Grammar({h: 'h'}, {'h': h}, p_ROOT, p_STOP, p_ATTACH, p_terminals)

    # these probabilities are impossible, so add them manually...
    g.p_GO_AT.update({(h, h, LEFT, NON): 0.6,
                      (h, h, LEFT, ADJ): 0.7,
                      (h, h, RIGHT, NON): 1.0,
                      (h, h, RIGHT, ADJ): 1.0})
    return g
def testreestimation_h():
    '''Reestimate testgrammar_h() on the corpus ["h h h"], with 'reest'
    debug output switched on.'''
    DEBUG.add('reest')
    reestimate(testgrammar_h(), ['h h h'.split()])
557 def test(wanted, got):
558 if not wanted == got:
559 raise Warning, "Regression! Should be %s: %s" % (wanted, got)
def regression_tests():
    '''Check inner()/outer() on testgrammar_h() against known values.
    Each check gets a fresh grammar and fresh charts; test() raises Warning
    on the first mismatch.'''
    checks = [
        # = .120 + .063, since we have no loc_h
        ("0.1830", lambda: "%.4f" % inner(0, 2, (SEAL,0), testgrammar_h(),
                                          'h h'.split(), {})),
        # = .0498 + .1092 +.0252
        ("0.1842", lambda: "%.4f" % inner(0, 3, (SEAL,0), testgrammar_h(),
                                          'h h h'.split(), {})),
        ("0.61", lambda: "%.2f" % outer(1, 3, (RGOL,0), testgrammar_h(),
                                        'h h h'.split(), {}, {})),
        ("0.58", lambda: "%.2f" % outer(1, 3, (NRGOL,0), testgrammar_h(),
                                        'h h h'.split(), {}, {})),
    ]
    for wanted, compute in checks:
        test(wanted, compute())
if __name__ == "__main__":
    # start from a clean debug state, then run the regression checks
    DEBUG.clear()

    # profiling / timing snippets, kept for manual use:
    # import profile
    # profile.run('testreestimation()')

    # DEBUG.add('reest_attach')
    # import timeit
    # print timeit.Timer("cnf_dmv.testreestimation_h()",'''import cnf_dmv
    # reload(cnf_dmv)''').timeit(1)

    regression_tests()
    # g = testgrammar()
    # print g