another little fix to stop initialization
[dmvccm.git] / src / loc_h_dmv.py.before_betweens
blob3074268c70a45c55cfae322d22ff03265fef6c11
1 # loc_h_dmv.py
2
3 # dmv reestimation and inside-outside probabilities using loc_h
5 #import numpy # numpy provides Fast Arrays, for future optimization
6 import io
7 from common_dmv import *
9 if __name__ == "__main__":
10     print "loc_h_dmv module tests:"
12 class DMV_Grammar(io.Grammar):
13     '''The DMV-PCFG.
15     Public members:
16     p_STOP, p_ROOT, p_CHOOSE, p_terminals
17     These are changed in the Maximation step, then used to set the
18     new probabilities of each DMV_Rule.
20     Todo: make p_terminals private? (But it has to be changable in
21     maximation step due to the short-cutting rules... could of course
22     make a DMV_Grammar function to update the short-cut rules...)
24     __p_rules is private, but we can still say stuff like:
25     for r in g.all_rules():
26         r.probN = newProbN
27     
28     What other representations do we need? (P_STOP formula uses
29     deps_D(h,l/r) at least)'''
30     def __str__(self):
31         str = ""
32         for r in self.all_rules():
33              str += "%s\n" % r.__str__(self.numtag)
34         return str
36     def h_rules(self, h):
37         return [r for r in self.all_rules() if r.head() == h]
38     
39     def mothersL(self, Node, sent_nums, loc_N):
40         # todo: speed-test with and without sent_nums/loc_N cut-off
41         return [r for r in self.all_rules() if r.L() == Node
42                 and (head(r.R()) in sent_nums[loc_N+1:] or r.R() == STOP)]
43     
44     def mothersR(self, Node, sent_nums, loc_N):
45         return [r for r in self.all_rules() if r.R() == Node
46                 and (head(r.L()) in sent_nums[:loc_N] or r.L() == STOP)]
48     def rules(self, LHS):
49         return [r for r in self.all_rules() if r.LHS() == LHS]
50     
51     def sent_rules(self, LHS, sent_nums):
52         '''Used in dmv.inner. Todo: this takes a _lot_ of time, it
53         seems. Could use some more space and cache some of this
54         somehow perhaps?'''
55         # We don't want to rule out STOPs!
56         nums = sent_nums + [ head(STOP) ]
57         return [r for r in self.all_rules() if r.LHS() == LHS
58                 and head(r.L()) in nums and head(r.R()) in nums]
59     
60     def deps_L(self, head): # todo: do I use this at all?
61         # todo test, probably this list comprehension doesn't work 
62         return [a for r in self.all_rules() if r.head() == head and a == r.L()]
63     
64     def deps_R(self, head):
65         # todo test, probably this list comprehension doesn't work 
66         return [a for r in self.all_rules() if r.head() == head and a == r.R()]
67     
68     def __init__(self, p_rules, p_terminals, p_STOP, p_CHOOSE, p_ROOT, numtag, tagnum):
69         io.Grammar.__init__(self, p_rules, p_terminals, numtag, tagnum)
70         self.p_STOP = p_STOP
71         self.p_CHOOSE = p_CHOOSE
72         self.p_ROOT = p_ROOT
73         self.head_nums = [k for k in numtag.iterkeys()]
74         
76 class DMV_Rule(io.CNF_Rule):
77     '''A single CNF rule in the PCFG, of the form 
78     LHS -> L R
79     where LHS, L and R are 'nodes', eg. of the form (seals, head).
80     
81     Public members:
82     probN, probA
83     
84     Private members:
85     __L, __R, __LHS
86     
87     Different rule-types have different probabilities associated with
88     them:
90     _h_ -> STOP  h_     P( STOP|h,L,    adj)
91     _h_ -> STOP  h_     P( STOP|h,L,non_adj)
92      h_ ->  h  STOP     P( STOP|h,R,    adj)
93      h_ ->  h  STOP     P( STOP|h,R,non_adj)
94      h_ -> _a_   h_     P(-STOP|h,L,    adj) * P(a|h,L)
95      h_ -> _a_   h_     P(-STOP|h,L,non_adj) * P(a|h,L)
96      h  ->  h   _a_     P(-STOP|h,R,    adj) * P(a|h,R)
97      h  ->  h   _a_     P(-STOP|h,R,non_adj) * P(a|h,R) 
98     '''
99     def p(self, adj, *arg):
100         if adj:
101             return self.probA
102         else:
103             return self.probN
105     def adj(middle, loc_h):
106         "middle is eg. k when rewriting for i<k<j (inside probabilities)."
107         return middle == loc_h[0] or middle == loc_h[1]
109     def p_STOP(self, s, t, loc_h):
110         '''Returns the correct probability, adjacent if we're rewriting from
111         the (either left or right) end of the fragment.
112         '''
113         if self.L() == STOP:
114             return self.p(s == loc_h)
115         elif self.R() == STOP:
116             if not loc_h == s:
117                 if 'TODO' in io.DEBUG:
118                     print "(%s given loc_h:%d but s:%d. Todo: optimize away!)" % (self, loc_h, s) 
119                 return 0.0
120             else:
121                 return self.p(t == loc_h)
122             
123     def p_ATTACH(self, r, loc_h, s=None):
124         '''Returns the correct probability, adjacent if we haven't attached
125         anything before.
126         (This is actually p_choose*(1-p_stop).)'''
127         if self.LHS() == self.L():
128             if s and not loc_h == s:
129                 if 'TODO' in io.DEBUG:
130                     print "(%s given loc_h (loc_L):%d but s:%d. Todo: optimize away!)" % (self, loc_h, s) 
131                 return 0.0
132             else:
133                 return self.p(r == loc_h)
134         elif self.LHS() == self.R():
135             return self.p(r+1 == loc_h)
136         
137     def seals(self):
138         return seals(self.LHS())
139     
140     def head(self):
141         return head(self.LHS())
142     
143     def __init__(self, LHS, L, R, probN, probA):
144         for b_h in [LHS, L, R]:
145             if seals(b_h) not in SEALS:
146                 raise ValueError("seals must be in %s; was given: %s"
147                                  % (SEALS, seals(b_h)))
148         io.CNF_Rule.__init__(self, LHS, L, R, probN)
149         self.probA = probA # adjacent
150         self.probN = probN # non_adj
151         
152     @classmethod # so we can call DMV_Rule.bar_str(b_h) 
153     def bar_str(cls, b_h, tag=lambda x:x):
154         if(b_h == ROOT):
155             return 'ROOT'
156         elif(b_h == STOP):
157             return 'STOP'
158         elif(seals(b_h) == RGOL):
159             return " %s_ " % tag(head(b_h))
160         elif(seals(b_h) == SEAL):
161             return "_%s_ " % tag(head(b_h))
162         else:
163             return " %s  " % tag(head(b_h))
165     
166     def __str__(self, tag=lambda x:x):
167         return "%s-->%s %s\t[N %.2f] [A %.2f]" % (self.bar_str(self.LHS(), tag),
168                                                   self.bar_str(self.L(), tag),
169                                                   self.bar_str(self.R(), tag),
170                                                   self.probN,
171                                                   self.probA)
172     
174     
179 ###################################
180 # dmv-specific version of inner() #
181 ###################################
182 def locs(h, sent, s=0, t=None, remove=None):
183     '''Return the locations of h in sent, or some fragment of sent (in the
184     latter case we make sure to offset the locations correctly so that
185     for any x in the returned list, sent[x]==h).
187     t is inclusive, to match the way indices work with inner()
188     (although python list-splicing has "exclusive" end indices)'''
189     if t == None:
190         t = len(sent)-1
191     return [i+s for i,w in enumerate(sent[s:t+1])
192             if w == h and not (i+s) == remove]
195 def inner(s, t, LHS, loc_h, g, sent, ichart={}):
196     ''' A rewrite of io.inner(), to take adjacency into accord.
198     The ichart is now of this form:
199     ichart[s,t,LHS, loc_h]
200     
201     loc_h gives adjacency (along with r and location of other child
202     for attachment rules), and is needed in P_STOP reestimation.
203     
204     Todo: if possible, refactor (move dmv-specific stuff back into
205     dmv, so this is "general" enough to be in io.py)
206     '''
207     
208     def O(s):
209         return sent[s]
210     
211     sent_nums = g.sent_nums(sent)
212     
213     def e(s,t,LHS, loc_h, n_t):
214         def tab():
215             "Tabs for debug output"
216             return "\t"*n_t
217         
218         if (s, t, LHS, loc_h) in ichart:
219             if 'INNER' in io.DEBUG:
220                 print "%s*= %.4f in ichart: s:%d t:%d LHS:%s loc:%d" % (tab(),ichart[s, t, LHS, loc_h], s, t,
221                                                                        DMV_Rule.bar_str(LHS), loc_h)
222             return ichart[s, t, LHS, loc_h]
223         else:
224             if s == t and seals(LHS) == GOR:
225                 if not loc_h == s:
226                     if 'INNER' in io.DEBUG:
227                         print "%s*= 0.0 (wrong loc_h)" % tab()
228                     return 0.0
229                 elif (LHS, O(s)) in g.p_terminals:
230                     prob = g.p_terminals[LHS, O(s)] # "b[LHS, O(s)]" in Lari&Young
231                 else:
232                     # todo: assuming this is how to deal w/lacking
233                     # rules, since we add prob.s, and 0 is identity
234                     prob = 0.0 
235                     if 'INNER' in io.DEBUG:
236                         print "%sLACKING TERMINAL:" % tab()
237                 # todo: add to ichart perhaps? Although, it _is_ simple lookup..
238                 if 'INNER' in io.DEBUG:
239                     print "%s*= %.4f (terminal: %s -> %s_%d)" % (tab(),prob, DMV_Rule.bar_str(LHS), O(s), loc_h) 
240                 return prob
241             else:
242                 p = 0.0 # "sum over j,k in a[LHS,j,k]"
243                 for rule in g.sent_rules(LHS, sent_nums): 
244                     if 'INNER' in io.DEBUG:
245                         print "%ssumming rule %s s:%d t:%d loc:%d" % (tab(),rule,s,t,loc_h) 
246                     L = rule.L()
247                     R = rule.R()
248                     if loc_h == t and LHS == L:
249                         continue # todo: speed-test
250                     if loc_h == s and LHS == R:
251                         continue 
252                     # if it's a STOP rule, rewrite for the same xrange:
253                     if (L == STOP) or (R == STOP):
254                         if L == STOP:
255                             pLR = e(s, t, R, loc_h, n_t+1)
256                         elif R == STOP:
257                             pLR = e(s, t, L, loc_h, n_t+1)
258                         p += rule.p_STOP(s, t, loc_h) * pLR
259                         if 'INNER' in io.DEBUG:
260                             print "%sp= %.4f (STOP)" % (tab(), p) 
261                             
262                     elif t > s: # not a STOP, attachment rewrite:
263                         rp_ATTACH = rule.p_ATTACH # todo: profile/speedtest
264                         for r in xrange(s, t):
265                             p_h = rp_ATTACH(r, loc_h, s=s)
266                             if LHS == L: 
267                                 locs_L = [loc_h]
268                                 locs_R = locs(head(R), sent_nums, r+1, t, loc_h)
269                             elif LHS == R: 
270                                 locs_L = locs(head(L), sent_nums,  s,  r, loc_h)
271                                 locs_R = [loc_h]
272                             for loc_L in locs_L:
273                                 pL = e(s, r, L, loc_L, n_t+1)
274                                 if pL > 0.0: 
275                                     for loc_R in locs_R:
276                                         pR = e(r+1, t, R, loc_R, n_t+1)
277                                         p += pL * p_h * pR
278                             if 'INNER' in io.DEBUG:
279                                 print "%sp= %.4f (ATTACH)" % (tab(), p) 
280                 ichart[s, t, LHS, loc_h] = p
281                 return p
282     # end of e-function
283             
284     inner_prob = e(s,t,LHS,loc_h, 0)
285     if 'INNER' in io.DEBUG:
286         print debug_ichart(g,sent,ichart)
287     return inner_prob
288 # end of dmv.inner(s, t, LHS, loc_h, g, sent, ichart={})
291 def debug_ichart(g,sent,ichart):
292     str = "---ICHART:---\n"
293     for (s,t,LHS,loc_h),v in ichart.iteritems():
294         if type(v) == dict: # skip 'tree'
295             continue
296         str += "%s -> %s_%d ... %s_%d (loc_h:%s):\t%.4f\n" % (DMV_Rule.bar_str(LHS,g.numtag),
297                                                               sent[s], s, sent[s], t, loc_h, v)
298     str += "---ICHART:end---\n"
299     return str
302 def inner_sent(g, sent, ichart={}):
303     return sum([inner(0, len(sent)-1, ROOT, loc_h, g, sent, ichart)
304                 for loc_h in xrange(len(sent))])
307 ###################################
308 # dmv-specific version of outer() #
309 ###################################
310 def outer(s,t,Node,loc_N, g, sent, ichart={}, ochart={}):
311     ''' http://www.student.uib.no/~kun041/dmvccm/DMVCCM.html#outer
312     '''
313     def e(s,t,LHS,loc_h):
314         # or we could just look it up in ichart, assuming ichart to be done
315         return inner(s, t, LHS, loc_h, g, sent, ichart)
316     
317     T = len(sent)-1
318     sent_nums = g.sent_nums(sent)
319     
320     def f(s,t,Node,loc_N):
321         if (s,t,Node,loc_N) in ochart:
322             return ochart[(s, t, Node,loc_N)]
323         if Node == ROOT:
324             if s == 0 and t == T:
325                 return 1.0
326             else: # ROOT may only be used on full sentence
327                 return 0.0 # but we may have non-ROOTs over full sentence too
328         p = 0.0
329         
330         for mom in g.mothersL(Node, sent_nums, loc_N): # mom.L() == Node
331             R = mom.R()
332             mLHS = mom.LHS()
333             if R == STOP:
334                 p += f(s,t,mLHS,loc_N) * mom.p_STOP(s,t,loc_N) # == loc_m
335             else:
336                 if seals(mLHS) == RGOL: # left attachment, head(mLHS) == head(R)
337                     for r in xrange(t+1,T+1): # t+1 to lasT 
338                         for loc_m in locs(head(mLHS),sent_nums,t+1,r):
339                             p_m = mom.p(t+1 == loc_m)
340                             p += f(s,r,mLHS,loc_m) * p_m * e(t+1,r,R,loc_m)
341                 elif seals(mLHS) == GOR: # right attachment, head(mLHS) == head(Node)
342                     loc_m = loc_N
343                     p_m = mom.p( t  == loc_m)
344                     for r in xrange(t+1,T+1): # t+1 to lasT 
345                         for loc_R in locs(head(R),sent_nums,t+1,r):
346                             p += f(s,r,mLHS,loc_m) * p_m * e(t+1,r,R,loc_R)
347         
348         for mom in g.mothersR(Node, sent_nums, loc_N): # mom.R() == Node
349             L = mom.L()
350             mLHS = mom.LHS()
351             if L == STOP:
352                 p += f(s,t,mLHS,loc_N) * mom.p_STOP(s,t,loc_N) # == loc_m
353             else:
354                 if seals(mLHS) == RGOL: # left attachment, head(mLHS) == head(Node)
355                     loc_m = loc_N
356                     p_m = mom.p( s  == loc_m)
357                     for r in xrange(0,s): # first to s-1 
358                         for loc_L in locs(head(L),sent_nums,r,s-1):
359                             p += e(r,s-1,L, loc_L) * p_m * f(r,t,mLHS,loc_m)
360                 elif seals(mLHS) == GOR: # right attachment, head(mLHS) == head(L)
361                     for r in xrange(0,s): # first to s-1
362                         for loc_m in locs(head(mLHS),sent_nums,r,s-1): 
363                             p_m = mom.p(s-1 == loc_m)
364                             p += e(r,s-1,L, loc_m) * p_m * f(r,t,mLHS,loc_m)
365         ochart[s,t,Node,loc_N] = p
366         return p
368     
369     return f(s,t,Node,loc_N)
370 # end outer(s,t,Node,loc_N, g,sent, ichart,ochart)
374 ##############################
375 #      reestimation, todo:   #
376 ##############################
377 ## using local version instead
378 # def c(s,t,LHS,loc_h,g,sent,ichart={},ochart={}):
379 #     # assuming P_sent = P(D(ROOT)) = inner(sent). todo: check K&M about this
380 #     p_sent = inner_sent(g, sent, ichart)
381 #     p_in = inner(s,t,LHS,loc_h,g,sent,ichart) 
382 #     p_out = outer(s,t,LHS,loc_h,g,sent,ichart,ochart)
383 #     if p_sent > 0.0:
384 #         return p_in * p_out / p_sent
385 #     else:
386 #         return p_sent
388 def reest_zeros(h_nums):
389     # todo: p_ROOT? ... p_terminals?
390     f = {}
391     for h in h_nums:
392         for stop in ['LNSTOP','LASTOP','RNSTOP','RASTOP']:
393             for nd in ['num','den']:
394                 f[stop,nd,h] = 0.0
395         for choice in ['RCHOOSE', 'LCHOOSE']:
396             f[choice,'den',h] = 0.0
397     return f
399 def reest_freq(g, corpus):
400     ''' P_STOP(-STOP|...) = 1 - P_STOP(STOP|...) '''
401     f = reest_zeros(g.head_nums)
402     ichart = {}
403     ochart = {}
404     
405     p_sent = None # 50 % speed increase on storing this locally
406     def c_g(s,t,LHS,loc_h,sent): # altogether 2x faster than the global c()
407         if (s,t,LHS,loc_h) in ichart:
408             p_in = ichart[s,t,LHS,loc_h]
409         else:
410             p_in = inner(s,t,LHS,loc_h,g,sent,ichart) 
411         if (s,t,LHS,loc_h) in ochart:
412             p_out = ochart[s,t,LHS,loc_h]
413         else:
414             p_out = outer(s,t,LHS,loc_h,g,sent,ichart,ochart)
416         if p_sent > 0.0:
417             return p_in * p_out / p_sent
418         else:
419             return p_sent
421     def w_g(s,t,a,loc_a,LHS,loc_h,sent):
422         "Todo: should sum through all r in between s and t in sent(_nums)"
423         h = head(LHS)
424         b_h = seals(LHS)
425         if b_h == GOR:
426             return e_L * e_R * f_g(s,t,(GOR, h), loc_h, sent) * p_g(r,(GOR, h), (GOR, h), (SEAL, a), loc_h, sent_nums)
427         if b_h == RGOL:
428             return e_L * e_R * f_g(s,t,(RGOL, h), loc_h, sent) * p_g(r,(RGOL, h),(SEAL, a),(RGOL, h),loc_h,sent_nums)
430     def f_g(s,t,LHS,loc_h,sent): # todo: test with choose rules
431         if (s,t,LHS,loc_h) in ochart:
432             return ochart[s,t,LHS,loc_h]
433         else:
434             return outer(s,t,LHS,loc_h,g,sent,ichart,ochart)
436     def e_g(s,t,LHS,loc_h,sent): # todo: test with choose rules
437         if (s,t,LHS,loc_h) in ichart:
438             return ichart[s,t,LHS,loc_h]
439         else:
440             return inner(s,t,LHS,loc_h,g,sent,ichart) 
441         
442     def p_g(r,LHS,L,R,loc_h,sent):
443         rules = [rule for rule in g.sent_rules(LHS, sent)
444                  if rule.L() == L and rule.R() == R]
445         rule = rules[0]
446         if len(rules) > 1:
447             raise Exception("Several rules matching a[i,j,k]")
448         return rule.p_ATTACH(r,loc_h)
450     for sent in corpus:
451         if 'reest' in io.DEBUG:
452             print sent
453         ichart = {}
454         ochart = {}
455         p_sent = inner_sent(g, sent, ichart)
457         sent_nums = g.sent_nums(sent)
458         # todo: use sum([ichart[s, t...] etc? but can we then
459         # keep den and num separate within _one_ sum()-call?
460         for loc_h,h in enumerate(sent_nums):
461             for t in xrange(loc_h, len(sent)):
462                 for s in xrange(loc_h): # s<loc(h), xrange gives strictly less
463                     # left non-adjacent stop:
464                     f['LNSTOP','num',h] += c_g(s, t, (SEAL, h), loc_h,sent)
465                     f['LNSTOP','den',h] += c_g(s, t, (RGOL,h), loc_h,sent)
466                 # left adjacent stop:
467                 f['LASTOP','num',h] += c_g(loc_h, t, (SEAL, h), loc_h,sent)
468                 f['LASTOP','den',h] += c_g(loc_h, t, (RGOL,h), loc_h,sent)
469             for t in xrange(loc_h+1, len(sent)):
470                 # right non-adjacent stop:
471                 f['RNSTOP','num',h] += c_g(loc_h, t, (RGOL,h), loc_h,sent)
472                 f['RNSTOP','den',h] += c_g(loc_h, t, (GOR, h), loc_h,sent)
473             # right adjacent stop:
474             f['RASTOP','num',h] += c_g(loc_h, loc_h, (RGOL,h), loc_h,sent)
475             f['RASTOP','den',h] += c_g(loc_h, loc_h, (GOR, h), loc_h,sent)
477             # right attachment:  TODO: try with p*e*e*f instead of c, for numerator
478             if 'reest_attach' in io.DEBUG:
479                 print "Rattach %s: for t in %s"%(g.numtag(h),sent[loc_h+1:len(sent)])
480             for t in xrange(loc_h+1, len(sent)): 
481                 cM = c_g(loc_h,t,(GOR, h), loc_h, sent) # v_q in L&Y 
482                 f['RCHOOSE','den',h] += cM
483                 if 'reest_attach' in io.DEBUG:
484                     print "\tc_g( %d , %d, %s, %s, sent)=%.4f"%(loc_h,t,g.numtag(h),loc_h,cM)
485                 args = {} # for summing w_q's in L&Y, without 1/P_q
486                 for r in xrange(loc_h+1, t+1): # loc_h < r <= t 
487                     e_L = e_g(loc_h, r-1, (GOR, h), loc_h, sent)
488                     if 'reest_attach' in io.DEBUG:
489                         print "\t\te_g( %d , %d, %s, %d, sent)=%.4f"%(loc_h,r-1,g.numtag(h),loc_h,e_L)
490                     for i,a in enumerate(sent_nums[r:t+1]):
491                         loc_a = i+r
492                         e_R = e_g(r, t, (SEAL, a), loc_a, sent)
493                         if a not in args:
494                             args[a] = 0.0
495                         args[a] += e_L * e_R * f_g(loc_h,t,(GOR, h), loc_h, sent) * p_g(r,(GOR, h), (GOR, h), (SEAL, a), loc_h, sent_nums)
496                     for a,sum_a in args.iteritems():
497                         f['RCHOOSE','num',h,a] = sum_a / p_sent
498                         
500             # left attachment:
501             if 'reest_attach' in io.DEBUG:
502                 print "Lattach %s: for s in %s"%(g.numtag(h),sent[0:loc_h])
503             for s in xrange(0, loc_h):
504                 if 'reest_attach' in io.DEBUG:
505                     print "\tfor t in %s"%sent[loc_h:len(sent)]
506                 for t in xrange(loc_h, len(sent)):
507                     c_M = c_g(s,t,(RGOL, h), loc_h, sent) # v_q in L&Y 
508                     f['LCHOOSE','den',h] += c_M
509                     if 'reest_attach' in io.DEBUG:
510                         print "\t\tc_g( %d , %d, %s_, %s, sent)=%.4f"%(s,t,g.numtag(h),loc_h,c_M)
511                     if 'reest_attach' in io.DEBUG:
512                         print "\t\tfor r in %s"%(sent[s:loc_h])
513                     args = {} # for summing w_q's in L&Y, without 1/P_q
514                     for r in xrange(s, loc_h): # s <= r < loc_h <= t
515                         e_R = e_g(r+1, t, (RGOL, h), loc_h, sent)
516                         if 'reest_attach' in io.DEBUG:
517                             print "\t\te_g( %d , %d, %s_, %d, sent)=%.4f"%(r+1,t,g.numtag(h),loc_h,e_R)
518                         for i,a in enumerate(sent_nums[s:r+1]):
519                             loc_a = i+s
520                             e_L = e_g( s , r, (SEAL, a), loc_a, sent)
521                             if a not in args:
522                                 args[a] = 0.0
523                             args[a] += e_L * e_R * f_g(s,t,(RGOL, h), loc_h, sent) * p_g(r,(RGOL, h),(SEAL, a),(RGOL, h),loc_h,sent_nums)
524                     for a,sum_a in args.iteritems():
525                         f['LCHOOSE', 'num',h,a] = sum_a / p_sent 
526     return f
528 def reestimate(g, corpus):
529     ""
530     f = reest_freq(g, corpus)
531     # we want to go through only non-ROOT left-STOPs.. 
532     for r in g.all_rules():
533         reest_rule(r,f, g)
534     return f
537 def reest_rule(r,f, g): # g just for numtag / debug output, remove eventually?
538     "remove 0-prob rules? todo"
539     h = r.head()
540     if r.LHS() == ROOT:
541         return None # not sure what todo yet here
542     if r.L() == STOP or head(r.R()) == h:
543         dir = 'L'
544     elif r.R() == STOP or head(r.L()) == h:
545         dir = 'R'
546     else:
547         raise Exception("Odd rule in reestimation.")
549     p_stopN = f[dir+'NSTOP','den',h]
550     if p_stopN > 0.0:
551         p_stopN = f[dir+'NSTOP','num',h] / p_stopN
553     p_stopA = f[dir+'ASTOP','den',h]
554     if p_stopA > 0.0:
555         p_stopA = f[dir+'ASTOP','num',h] / p_stopA
557     if r.L() == STOP or r.R() == STOP: # stop rules
558         if 'reest' in io.DEBUG:
559             print "p(STOP|%d=%s,%s,N): %.4f (was: %.4f)"%(h,g.numtag(h),dir, p_stopN, r.probN) 
560             print "p(STOP|%d=%s,%s,A): %.4f (was: %.4f)"%(h,g.numtag(h),dir, p_stopA, r.probA) 
561         r.probN = p_stopN
562         r.probA = p_stopA
564     else: # attachment rules
565         pchoose = f[dir+'CHOOSE','den',h]
566         if pchoose > 0.0:
567             if head(r.R()) == h: # left attachment
568                 a = head(r.L())
569             elif head(r.L()) == h: # right attachment
570                 a = head(r.R())
571             pchoose = f[dir+'CHOOSE','num',h,a] / pchoose 
572             r.probN = (1-p_stopN) * pchoose
573             r.probA = (1-p_stopA) * pchoose
574             if 'reest' in io.DEBUG:
575                 print "p(%d=%s|%d=%s,%s): %.4f,\tprobN: %.4f, probA: %.4f"%(a,g.numtag(a),h,g.numtag(h),dir, pchoose,r.probN,r.probA) 
583 ##############################
584 #     testing functions:     #
585 ##############################
587 testcorpus = [s.split() for s in ['det nn vbd c vbd','vbd nn c vbd',
588                                   'det nn vbd',      'det nn vbd c pp', 
589                                   'det nn vbd',      'det vbd vbd c pp', 
590                                   'det nn vbd',      'det nn vbd c vbd', 
591                                   'det nn vbd',      'det nn vbd c vbd', 
592                                   'det nn vbd',      'det nn vbd c vbd', 
593                                   'det nn vbd',      'det nn vbd c pp', 
594                                   'det nn vbd pp',   'det nn vbd', ]]
596 def testgrammar():
597     import loc_h_harmonic
598     reload(loc_h_harmonic)
599     return loc_h_harmonic.initialize(testcorpus)
601 def testreestimation():
602     io.DEBUG.add('reest')
603     g = testgrammar()
604     f = reestimate(g, testcorpus)
605     f_stops = {('LNSTOP', 'den', 3): 12.212773236178391, ('RASTOP', 'den', 2): 4.0, ('RNSTOP', 'num', 4): 2.5553487221351365, ('LNSTOP', 'den', 2): 1.274904052793207, ('LASTOP', 'num', 1): 14.999999999999995, ('RASTOP', 'den', 3): 15.0, ('LASTOP', 'num', 4): 16.65701084787457, ('LASTOP', 'num', 0): 4.1600647714443468, ('LNSTOP', 'den', 4): 6.0170669155897105, ('LASTOP', 'num', 3): 2.7872267638216113, ('LASTOP', 'num', 2): 2.9723139990470515, ('LASTOP', 'den', 2): 4.0, ('RNSTOP', 'den', 3): 12.945787931730905, ('LASTOP', 'den', 3): 14.999999999999996, ('RNSTOP', 'den', 2): 0.0, ('LASTOP', 'den', 0): 8.0, ('RASTOP', 'num', 4): 19.44465127786486, ('RNSTOP', 'den', 1): 3.1966410324085777, ('LASTOP', 'den', 1): 14.999999999999995, ('RASTOP', 'num', 3): 4.1061665495365558, ('RNSTOP', 'den', 0): 4.8282499043902476, ('LNSTOP', 'num', 4): 5.3429891521254289, ('RASTOP', 'num', 2): 4.0, ('LASTOP', 'den', 4): 22.0, ('RASTOP', 'num', 1): 12.400273895299103, ('LNSTOP', 'num', 2): 1.0276860009529487, ('RASTOP', 'num', 0): 3.1717500956097533, ('LNSTOP', 'num', 3): 12.212773236178391, ('RASTOP', 'den', 4): 22.0, ('RNSTOP', 'den', 4): 2.8705211946979836, ('LNSTOP', 'num', 0): 3.8399352285556518, ('LNSTOP', 'num', 1): 0.0, ('RNSTOP', 'num', 0): 4.8282499043902476, ('RNSTOP', 'num', 1): 2.5997261047008959, ('LNSTOP', 'den', 1): 0.0, ('RASTOP', 'den', 0): 8.0, ('RNSTOP', 'num', 2): 0.0, ('LNSTOP', 'den', 0): 4.6540557322109795, ('RASTOP', 'den', 1): 15.0, ('RNSTOP', 'num', 3): 10.893833450463443}
606     for k,v in f_stops.iteritems():
607         if not k in f:
608             pass
609 #             print '''Regression!(?) Something changed in the P_STOP reestimation,
610 # expected f[%s]=%.4f, but %s not in f'''%(k,v,k)
611         elif not f[k] == v:
612             pass
613 #             print '''Regression!(?) Something changed in the P_STOP reestimation,
614 # expected f[%s]=%.4f, got f[%s]=.%4f.'''%(k,v,k,f[k])
617 def testgrammar_a():                            # Non, Adj
618     _h_ = DMV_Rule((SEAL,0), STOP,    ( RGOL,0), 1.0, 1.0) # LSTOP
619     h_S = DMV_Rule(( RGOL,0),(GOR,0),  STOP,    0.4, 0.3) # RSTOP
620     h_A = DMV_Rule(( RGOL,0),(SEAL,0),( RGOL,0),0.2, 0.1) # Lattach
621     h_Aa= DMV_Rule(( RGOL,0),(SEAL,1),( RGOL,0),0.4, 0.6) # Lattach to a
622     h   = DMV_Rule((GOR,0),(GOR,0),(SEAL,0),    1.0, 1.0) # Rattach
623     ha  = DMV_Rule((GOR,0),(GOR,0),(SEAL,1),    1.0, 1.0) # Rattach to a
624     rh  = DMV_Rule(   ROOT,   STOP,    (SEAL,0),  0.9, 0.9) # ROOT
626     _a_ = DMV_Rule((SEAL,1), STOP,    ( RGOL,1), 1.0, 1.0) # LSTOP
627     a_S = DMV_Rule(( RGOL,1),(GOR,1),  STOP,    0.4, 0.3) # RSTOP
628     a_A = DMV_Rule(( RGOL,1),(SEAL,1),( RGOL,1),0.4, 0.6) # Lattach
629     a_Ah= DMV_Rule(( RGOL,1),(SEAL,0),( RGOL,1),0.2, 0.1) # Lattach to h
630     a   = DMV_Rule((GOR,1),(GOR,1),(SEAL,1),    1.0, 1.0) # Rattach
631     ah  = DMV_Rule((GOR,1),(GOR,1),(SEAL,0),    1.0, 1.0) # Rattach to h
632     ra  = DMV_Rule(   ROOT,   STOP,    (SEAL,1),  0.1, 0.1) # ROOT
634     b2  = {}
635     b2[(GOR, 0), 'h'] = 1.0
636     b2[(GOR, 1), 'a'] = 1.0
637     
638     return DMV_Grammar([ h_Aa, ha, a_Ah, ah, ra, _a_, a_S, a_A, a, rh, _h_, h_S, h_A, h ],b2,0,0,0, {0:'h',1:'a'}, {'h':0,'a':1})
639 def oa(s,t,LHS,loc_h):
640     return outer(s,t,LHS,loc_h,testgrammar_a(),'h a'.split())
641 def ia(s,t,LHS,loc_h):
642     return inner(s,t,LHS,loc_h,testgrammar_a(),'h a'.split())
643 def ca(s,t,LHS,loc_h):
644     return c(s,t,LHS,loc_h,testgrammar_a(),'h a'.split())
646 def testgrammar_h():                            # Non, Adj
647     _h_ = DMV_Rule((SEAL,0), STOP,    ( RGOL,0), 1.0, 1.0) # LSTOP
648     h_S = DMV_Rule(( RGOL,0),(GOR,0),  STOP,    0.4, 0.3) # RSTOP
649     h_A = DMV_Rule(( RGOL,0),(SEAL,0),( RGOL,0), 0.6, 0.7) # Lattach
650     h   = DMV_Rule((GOR,0),(GOR,0),(SEAL,0), 1.0, 1.0) # Rattach
651     rh  = DMV_Rule(   ROOT,   STOP,    (SEAL,0), 1.0, 1.0) # ROOT
652     b2  = {}
653     b2[(GOR, 0), 'h'] = 1.0
654     
655     return DMV_Grammar([ rh, _h_, h_S, h_A, h ],b2,0,0,0, {0:'h'}, {'h':0})
656     
658 def testreestimation_h():
659     io.DEBUG.add('reest')
660     g = testgrammar_h()
661     reestimate(g,['h h h'.split()])
664 def regression_tests():
665     def test(wanted, got):
666         if not wanted == got:
667             print "Regression! Should be %s: %s" % (wanted, got)
668             
669     g_dup = testgrammar_h()
670         
671     test("0.120",
672          "%.3f" % inner(0, 1, (SEAL,0), 0, g_dup, 'h h'.split(), {}))
673     
674     test("0.063",
675          "%.3f" % inner(0, 1, (SEAL,0), 1, g_dup, 'h h'.split(), {}))
676         
677     test("0.0498",
678          "%.4f" % inner(0, 2, (SEAL,0), 2, g_dup, 'h h h'.split(), {}))
679     
680     test("0.58" ,
681          "%.2f" % outer(1,2,(1,0),2,testgrammar_h(),'h h h'.split(),{},{}))
683     test("0.1089" ,
684          "%.4f" % outer(0,0,(0,0),0,testgrammar_a(),'h a'.split(),{},{}))
686     
687 if __name__ == "__main__":
688     io.DEBUG.clear()
690 #     import profile
691 #     profile.run('testreestimation()')
693 #    io.DEBUG.add('reest_attach')
694     import timeit
695     print timeit.Timer("loc_h_dmv.testreestimation()",'''import loc_h_dmv
696 reload(loc_h_dmv)''').timeit(1)
697     print "TODO: P_CHOOSE needs to be divided by sum_x(a[x|h])"
699 if __name__ == "__main__":
700     regression_tests()