found a bug in w-formula, 1/P_s should be outside sums, starting to think we should...
[dmvccm.git] / src / cnf_dmv.py
blobd596deca1e1f1e3f672e207e8c45f3179b62875a
1 # cnf_dmv.py
3 #import numpy # numpy provides Fast Arrays, for future optimization
4 import io
6 from common_dmv import *
# Seal states allowed on any node of the CNF grammar: CNF_DMV_Rule.__init__
# rejects nodes whose seal is not in this list, and LHSs() enumerates one
# LHS per (seal, head). NGOR/NRGOL are the head-daughter states that
# _make_rules pairs with adj=False (non-adjacent). This rebinds the name
# SEALS, which may also come in via `from common_dmv import *` (presumably
# without NGOR/NRGOL there — verify).
SEALS = [GOR, RGOL, SEAL, NGOR, NRGOL] # overwriting here
if __name__ == "__main__":
    # Header line for the self-tests that run at the bottom of the module.
    print("cnf_dmv module tests:")
def make_GO_AT(p_STOP, p_ATTACH):
    '''Combine attachment and "don't stop" probabilities.

    For every key (a, h, dir) of p_ATTACH, the result holds
        p_ATTACH[a, h, dir] * (1 - p_STOP[h, dir, adjacency])
    under the key (a, h, dir, adjacency), for adjacency NON and then ADJ.
    '''
    result = {}
    for key, p_attach in p_ATTACH.items():
        arg, head, direction = key
        for adjacency in (NON, ADJ):
            p_go = 1 - p_STOP[head, direction, adjacency]
            result[arg, head, direction, adjacency] = p_attach * p_go
    return result
class CNF_DMV_Grammar(io.Grammar):
    '''The DMV-PCFG.

    Public members:
    p_STOP, p_ROOT, p_ATTACH, p_terminals
    These are changed in the Maximation step, then used to set the
    new probabilities of each CNF_DMV_Rule.

    __p_rules is private, but we can still say stuff like:
    for r in g.all_rules():
        r.prob = (1-p_STOP[...]) * p_ATTACH[...]
    '''
    def __str__(self):
        return "".join("%s\n" % r.__str__(self.numtag)
                       for r in self.all_rules())

    def LHSs(self):
        '''All left-hand-side nodes: ROOT, then (seal, head) for every
        head number and every seal in SEALS.'''
        return [ROOT] + [(s_h, h)
                         for h in self.headnums()
                         for s_h in SEALS]

    def sent_rules(self, sent_nums):
        '''Rules from arg_rules() whose two daughters each have a POS in
        sent_nums or equal to POS(STOP).'''
        sent_nums_stop = sent_nums + [POS(STOP)]
        return [r for LHS in self.LHSs()
                for r in self.arg_rules(LHS, sent_nums)
                if POS(r.L()) in sent_nums_stop
                and POS(r.R()) in sent_nums_stop]

    # used in outer:
    def mothersR(self, w_node, argnums):
        '''For all LHS and x, return all rules of the form 'LHS->x w_node',
        where x is STOP or has its POS in argnums.

        STOP siblings are kept explicitly: POS(STOP) is not one of the
        sentence's tag numbers (sent_rules has to append it by hand), so
        filtering on argnums alone would drop every unary STOP rule and
        outer() would never count them.

        The unfiltered rule list is cached per w_node; make_all_rules()
        clears the cache.'''
        if w_node not in self.__mothersR:
            self.__mothersR[w_node] = [r for LHS in self.LHSs()
                                       for r in self.rules(LHS)
                                       if r.R() == w_node]
        argset = set(argnums)
        return [r for r in self.__mothersR[w_node]
                if r.L() == STOP or POS(r.L()) in argset]

    def mothersL(self, w_node, argnums):
        '''For all LHS and x, return all rules of the form 'LHS->w_node x',
        where x is STOP or has its POS in argnums.

        STOP siblings are kept explicitly for the same reason as in
        mothersR(). The unfiltered rule list is cached per w_node;
        make_all_rules() clears the cache.'''
        if w_node not in self.__mothersL:
            self.__mothersL[w_node] = [r for LHS in self.LHSs()
                                       for r in self.rules(LHS)
                                       if r.L() == w_node]
        argset = set(argnums)
        return [r for r in self.__mothersL[w_node]
                if r.R() == STOP or POS(r.R()) in argset]

    # used in inner:
    def arg_rules(self, LHS, argnums):
        '''Rules of LHS where at least one daughter has its POS in argnums.'''
        return [r for r in self.rules(LHS)
                if (POS(r.R()) in argnums
                    or POS(r.L()) in argnums)]

    def make_all_rules(self):
        '''(Re)build every rule from the current p_ROOT/p_STOP/p_GO_AT tables.

        The mothersL/mothersR caches hold rule objects, so they are reset
        here too; otherwise a rebuild (as testgrammar_h does after editing
        p_GO_AT) would leave outer() working on the previous rule objects.
        This also initialises the caches for __init__.'''
        self.__mothersL = {}
        self.__mothersR = {}
        self.new_rules([r for LHS in self.LHSs()
                        for r in self._make_rules(LHS, self.headnums())])

    def _make_rules(self, LHS, argnums):
        '''This is where the CNF grammar is defined. Also, s_dir_typ shows how
        useful it'd be to split up the seals into direction and
        type... todo?

        ROOT rewrites to STOP (SEAL,h) for every h in argnums. GOR has no
        binary rules (only terminals). RGOL/SEAL are STOP rules (right/left),
        NGOR/NRGOL are attachment rules (right/left). Each of those comes in
        an adjacent and a non-adjacent variant, chosen by the seal of the
        head daughter.'''
        h = POS(LHS)
        if LHS == ROOT:
            return [CNF_DMV_Rule(LEFT, LHS, (SEAL, head), STOP, self.p_ROOT[head])
                    for head in set(argnums)]
        s_h = seals(LHS)
        if s_h == GOR:
            return [] # only terminals from here on
        s_dir_type = { # seal of LHS
            RGOL: (RIGHT, 'STOP'), NGOR: (RIGHT, 'ATTACH'),
            SEAL: (LEFT, 'STOP'), NRGOL: (LEFT, 'ATTACH') }
        dir_s_adj = { # seal of h_daughter
            RIGHT: [(GOR, True), (NGOR, False)],
            LEFT: [(RGOL, True), (NRGOL, False)] }
        direction, kind = s_dir_type[s_h]
        if kind == 'STOP':
            return [CNF_DMV_Rule(direction, LHS, (s, h), STOP,
                                 self.p_STOP[h, direction, adj])
                    for s, adj in dir_s_adj[direction]]
        # 'ATTACH': only arguments that p_ATTACH knows about for this head
        return [CNF_DMV_Rule(direction, LHS, (s, h), (SEAL, a),
                             self.p_GO_AT[a, h, direction, adj])
                for a in set(argnums) if (a, h, direction) in self.p_ATTACH
                for s, adj in dir_s_adj[direction]]

    def __init__(self, numtag, tagnum, p_ROOT, p_STOP, p_ATTACH, p_terminals):
        io.Grammar.__init__(self, numtag, tagnum, [], p_terminals)
        self.p_STOP = p_STOP
        self.p_ATTACH = p_ATTACH
        self.p_ROOT = p_ROOT
        self.p_GO_AT = make_GO_AT(self.p_STOP, self.p_ATTACH)
        # also (re)initialises the mothersL/mothersR caches
        self.make_all_rules()
class CNF_DMV_Rule(io.CNF_Rule):
    '''A single CNF rule in the PCFG, of the form
    LHS -> L R
    where LHS, L and R are 'nodes', eg. of the form (seals, head).

    Public members:
    prob

    Private members:
    __L, __R, __LHS

    Different rule-types have different probabilities associated with
    them, see formulas.pdf
    '''
    def seals(self):
        "Seal of the LHS node."
        return seals(self.LHS())

    def POS(self):
        "POS (head number) of the LHS node."
        return POS(self.LHS())

    def __init__(self, dir, LHS, h_daughter, a_daughter, prob):
        '''dir (LEFT or RIGHT) is the side a_daughter (an argument or STOP)
        is on relative to h_daughter.

        Raises ValueError if dir is neither LEFT nor RIGHT, or if LHS or
        either daughter has a seal outside SEALS.'''
        self.__dir = dir
        if dir == LEFT:
            L, R = a_daughter, h_daughter
        elif dir == RIGHT:
            L, R = h_daughter, a_daughter
        else:
            # Format (dir,) rather than dir: a node is a tuple, and
            # "%s" % tuple would raise TypeError instead of this ValueError.
            raise ValueError("dir must be LEFT or RIGHT, given: %s" % (dir,))
        for b_h in [LHS, L, R]:
            if seals(b_h) not in SEALS:
                raise ValueError("seals must be in %s; was given: %s"
                                 % (SEALS, seals(b_h)))
        io.CNF_Rule.__init__(self, LHS, L, R, prob)

    def adj(self):
        '''True if the head daughter is the adjacent variant: RGOL for a
        LEFT rule, GOR for a RIGHT rule. 'undefined' for ROOT.'''
        if self.__dir == LEFT:
            return seals(self.R()) == RGOL
        else: # RIGHT
            return seals(self.L()) == GOR

    def __str__(self, tag=lambda x: x):
        '''"LHS --> L R [prob] adj|non_adj"; tag is handed to node_str().'''
        if self.LHS() == ROOT:
            adj_str = ""
        elif self.adj():
            adj_str = "adj"
        else:
            adj_str = "non_adj"
        return "%s --> %s %s\t[%.2f] %s" % (node_str(self.LHS(), tag),
                                             node_str(self.L(), tag),
                                             node_str(self.R(), tag),
                                             self.prob,
                                             adj_str)
176 ###################################
177 # dmv-specific version of inner() #
178 ###################################
def inner(i, j, LHS, g, sent, ichart=None):
    ''' A CNF rewrite of io.inner(), to take STOP rules into accord.

    Returns the inside probability of LHS spanning sent[i:j] under g.

    ichart memoises every non-terminal result, keyed (i, j, LHS). Pass
    the same dict across calls on one sentence to share work. When it is
    omitted, a fresh dict is used: a shared mutable default would carry
    cached values over to the next sentence or grammar.
    '''
    if ichart is None:
        ichart = {}

    def O(i, j):
        return sent[i]

    sent_nums = g.sent_nums(sent)

    def e(i, j, LHS, n_t):
        def tab():
            "Tabs for debug output"
            return "\t" * n_t
        if (i, j, LHS) in ichart:
            if 'INNER' in DEBUG:
                print("%s*= %.4f in ichart: i:%d j:%d LHS:%s" % (tab(), ichart[i, j, LHS], i, j, node_str(LHS)))
            return ichart[i, j, LHS]

        # if seals(LHS) == RGOL then we have to STOP first
        if i == j-1 and seals(LHS) == GOR:
            # terminal: "b[LHS, O(s)]" in Lari&Young (not stored in ichart)
            if (LHS, O(i, j)) in g.p_terminals:
                prob = g.p_terminals[LHS, O(i, j)]
            else:
                prob = 0.0
                if 'INNER' in DEBUG:
                    print("%sLACKING TERMINAL:" % tab())
            if 'INNER' in DEBUG:
                print("%s*= %.4f (terminal: %s -> %s)" % (tab(), prob, node_str(LHS), O(i, j)))
            return prob

        p = 0.0 # "sum over j,k in a[LHS,j,k]"
        for rule in g.arg_rules(LHS, sent_nums):
            if 'INNER' in DEBUG:
                print("%ssumming rule %s i:%d j:%d" % (tab(), rule, i, j))
            L = rule.L()
            R = rule.R()
            if (L == STOP) or (R == STOP):
                # STOP rule: the other daughter rewrites over the same span
                child = R if L == STOP else L
                pLR = e(i, j, child, n_t+1)
                p += rule.p() * pLR
                if 'INNER' in DEBUG:
                    print("%sp= %.4f (STOP)" % (tab(), p))
            elif j > i+1 and seals(LHS) != GOR:
                # not a STOP, attachment rewrite over every split point i<k<j:
                for k in xtween(i, j):
                    p_L = e(i, k, L, n_t+1)
                    p_R = e(k, j, R, n_t+1)
                    p += rule.p() * p_L * p_R
                    if 'INNER' in DEBUG:
                        print("%sp= %.4f (ATTACH, p_L:%.4f, p_R:%.4f, rule:%.4f)" % (tab(), p, p_L, p_R, rule.p()))
        ichart[i, j, LHS] = p
        return p
    # end of e-function

    inner_prob = e(i, j, LHS, 0)
    if 'INNER' in DEBUG:
        print(debug_ichart(g, sent, ichart))
    return inner_prob
# end of cnf_dmv.inner(i, j, LHS, g, sent, ichart=None)
def debug_ichart(g, sent, ichart):
    '''Return a printable dump of ichart.

    Prints one "LHS -> first_word ... last_word: prob" line per entry,
    between header and footer lines. Dict-valued entries (the 'tree'
    entries) are skipped. g.numtag is handed to node_str() for tag names.
    '''
    lines = ["---ICHART:---\n"]
    for (i, j, LHS), v in ichart.items():
        if isinstance(v, dict): # skip 'tree'
            continue
        lines.append("%s -> %s ... %s: \t%.4f\n" % (node_str(LHS, g.numtag),
                                                   sent[i], sent[j-1], v))
    lines.append("---ICHART:end---\n")
    return "".join(lines)
def inner_sent(g, sent, ichart=None):
    '''Inside probability of the whole sentence, i.e. of ROOT over sent[0:len(sent)].

    ichart is handed to inner(), which fills it. When it is omitted, a
    fresh dict is used (a shared {} default would leak cached values
    between sentences).'''
    if ichart is None:
        ichart = {}
    return inner(0, len(sent), ROOT, g, sent, ichart)
257 #######################################
258 # cnf_dmv-specific version of outer() #
259 #######################################
def outer(i, j, w_node, g, sent, ichart=None, ochart=None):
    '''Outside probability of w_node spanning sent[i:j] under grammar g.

    ichart -- inside chart handed to inner(); a fresh dict if omitted
    ochart -- outside chart keyed (i, j, node); a fresh dict if omitted
    (Fresh dicts rather than shared {} defaults, which would leak cached
    values between sentences.)

    Mother rules come from g.mothersL / g.mothersR. A rule whose other
    daughter is STOP is unary and contributes f(i, j, LHS) * p(rule).
    Otherwise the sibling's inside probability is summed over every
    compatible split point.
    '''
    if ichart is None:
        ichart = {}
    if ochart is None:
        ochart = {}

    def e(i, j, LHS):
        # or we could just look it up in ichart, assuming ichart to be done
        return inner(i, j, LHS, g, sent, ichart)

    sent_nums = g.sent_nums(sent)
    if w_node != ROOT and POS(w_node) not in sent_nums[i:j]:
        # sanity check, w must be able to dominate sent[i:j]. ROOT is exempt:
        # f() below already gives it 1.0 over the full sentence, 0.0 elsewhere.
        return 0.0

    def f(i, j, w_node):
        if (i, j, w_node) in ochart:
            return ochart[(i, j, w_node)]
        if w_node == ROOT:
            if i == 0 and j == len(sent):
                return 1.0
            else: # ROOT may only be used on full sentence
                return 0.0 # but we may have non-ROOTs over full sentence too

        p = 0.0

        for rule in g.mothersL(w_node, sent_nums): # rule.L() == w_node
            if 'OUTER' in DEBUG:
                print("w_node:%s (L) ; %s" % (node_str(w_node), rule))
            if rule.R() == STOP:
                # unary: LHS -> w_node STOP over the same span
                p0 = f(i, j, rule.LHS()) * rule.p()
                if 'OUTER' in DEBUG:
                    print(p0)
                p += p0
            else:
                for k in xgt(j, sent): # i<j<k
                    p0 = f(i, k, rule.LHS()) * rule.p() * e(j, k, rule.R())
                    if 'OUTER' in DEBUG:
                        print(p0)
                    p += p0

        for rule in g.mothersR(w_node, sent_nums): # rule.R() == w_node
            if 'OUTER' in DEBUG:
                print("w_node:%s (R) ; %s" % (node_str(w_node), rule))
            if rule.L() == STOP:
                # unary: LHS -> STOP w_node over the same span
                p0 = f(i, j, rule.LHS()) * rule.p()
                if 'OUTER' in DEBUG:
                    print(p0)
                p += p0
            else:
                for k in xlt(i): # k<i<j
                    p0 = e(k, i, rule.L()) * rule.p() * f(k, j, rule.LHS())
                    if 'OUTER' in DEBUG:
                        print(p0)
                    p += p0

        ochart[i, j, w_node] = p
        return p

    return f(i, j, w_node)
# end outer(i,j,w_node, g,sent, ichart,ochart)
314 ##############################
315 # reestimation, todo: #
316 ##############################
def reest_zeros(rules):
    '''Return an all-zero frequency table for rules.

    There is one 0.0 entry per rule and counter kind, keyed
    (kind, LHS, L, R) with kind 'num' or 'den'.'''
    return dict(((kind, rule.LHS(), rule.L(), rule.R()), 0.0)
                for rule in rules
                for kind in ('num', 'den'))
def reest_freq(g, corpus):
    '''Collect expected rule frequencies over corpus for one EM step.

    P_STOP(-STOP|...) = 1 - P_STOP(STOP|...)

    Returns a dict keyed ('num'|'den', LHS, L, R) (see reest_zeros). For
    each rule that g.sent_rules() yields for a sentence, 'num' gets the
    expected number of uses of the rule and 'den' the expected number of
    uses of its LHS, summed over all spans. Each sentence's contribution
    is divided by that sentence's own probability p_sent, ROOT rules
    included. Sentences with p_sent == 0.0 contribute nothing.
    reestimate() turns the table into probabilities as num/den.
    '''
    f = reest_zeros(g.all_rules())
    # The charts and p_sent are rebound for every sentence in the loop
    # below. The nested helpers look these names up at call time, so they
    # always see the current sentence's values.
    ichart = {}
    ochart = {}

    p_sent = None # 50 % speed increase on storing this locally

    def c_g(i, j, LHS, sent):
        "Expected count of LHS over sent[i:j]: inside * outside / p_sent."
        if p_sent == 0.0:
            return 0.0
        return e_g(i, j, LHS, sent) * f_g(i, j, LHS, sent) / p_sent

    def w1_g(i, j, rule, sent): # unary (stop) rules, LHS -> child_node
        "Expected uses of a unary STOP rule over sent[i:j]."
        if rule.L() == STOP:
            child = rule.R()
        elif rule.R() == STOP:
            child = rule.L()
        else:
            raise ValueError("expected a stop rule: %s" % (rule,))

        if p_sent == 0.0:
            return 0.0

        p_out = f_g(i, j, rule.LHS(), sent)
        if p_out == 0.0:
            return 0.0

        return rule.p() * e_g(i, j, child, sent) * p_out / p_sent

    def w_g(i, j, rule, sent):
        "Expected uses of a binary rule over sent[i:j], summed over split points."
        if p_sent == 0.0 or i+1 == j:
            return 0.0

        p_out = f_g(i, j, rule.LHS(), sent)
        if p_out == 0.0:
            return 0.0

        p = 0.0
        for k in xtween(i, j):
            p += rule.p() * e_g(i, k, rule.L(), sent) * e_g(k, j, rule.R(), sent) * p_out
        return p / p_sent

    def f_g(i, j, LHS, sent):
        "Outside probability, served from ochart when already computed."
        if (i, j, LHS) in ochart:
            return ochart[i, j, LHS]
        return outer(i, j, LHS, g, sent, ichart, ochart)

    def e_g(i, j, LHS, sent):
        "Inside probability, served from ichart when already computed."
        if (i, j, LHS) in ichart:
            return ichart[i, j, LHS]
        return inner(i, j, LHS, g, sent, ichart)

    for sn, sent in enumerate(corpus):
        if sn % 1 == 0:
            print("sentence number %d" % sn)
        if 'REEST' in DEBUG:
            print(sent)
        ichart = {}
        ochart = {}
        # since we keep re-using p_sent, it seems better to have
        # sentences as the outer loop; o/w we'd have to keep every chart
        p_sent = inner_sent(g, sent, ichart)

        sent_nums = g.sent_nums(sent)
        for r in g.sent_rules(sent_nums):
            LHS, L, R = r.LHS(), r.L(), r.R()
            if 'REEST' in DEBUG:
                print(r)
            if LHS == ROOT:
                # ROOT only spans the whole sentence, with outside prob. 1,
                # so its expected count there is 1 and the rule's is
                # p(rule) * inside(child) / p_sent -- normalised per
                # sentence like every other count below.
                if p_sent != 0.0:
                    f['num', LHS, L, R] += r.p() * e_g(0, len(sent), R, sent) / p_sent
                    f['den', LHS, L, R] += 1.0
                continue # !!! o/w we add wrong values to it below
            if L == STOP or R == STOP:
                w = w1_g
            else:
                w = w_g
            for i in xlt(len(sent)):
                for j in xgt(i, sent):
                    f['num', LHS, L, R] += w(i, j, r, sent)
                    f['den', LHS, L, R] += c_g(i, j, LHS, sent)
    return f
def reestimate(g, corpus):
    '''Do one re-estimation step over corpus, setting every rule's prob in place.

    A rule's new prob is num/den from reest_freq() when den > 0.0.
    Otherwise the rule gets den itself, normally 0.0. Returns the
    frequency table.'''
    freqs = reest_freq(g, corpus)
    print("applying f to rules")
    for rule in g.all_rules():
        key = (rule.LHS(), rule.L(), rule.R())
        den = freqs[('den',) + key]
        if den > 0.0:
            rule.prob = freqs[('num',) + key] / den
        else:
            rule.prob = den
    return freqs
413 ##############################
414 # testing functions: #
415 ##############################
def testgrammar():
    '''Harmonic-initialised grammar built from loc_h_dmv's testcorpus.'''
    # make sure we use the same data:
    from loc_h_dmv import testcorpus
    import cnf_harmonic
    return reload(cnf_harmonic).initialize(testcorpus)
def testreestimation():
    '''Re-estimate the harmonic test grammar on the first four test sentences.

    Returns (f, g): the frequency table and the (updated) grammar.'''
    from loc_h_dmv import testcorpus
    g = testgrammar()
    return (reestimate(g, testcorpus[:4]), g)
def testgrammar_a(): # Non, Adj
    # Two-tag test grammar (0 = 'h', 1 = 'a').
    # NOTE(review): stale fixture. The calls below apparently follow an old
    # CNF_DMV_Rule argument order (LHS, L, R, p_non, p_adj), but
    # CNF_DMV_Rule.__init__ now takes (dir, LHS, h_daughter, a_daughter, prob).
    # The very first call therefore passes the node (SEAL,0) as dir, which
    # (presumably -- LEFT/RIGHT come from common_dmv) is neither LEFT nor
    # RIGHT, so __init__ raises before any grammar is built. Even past that,
    # p_rules is never used, and CNF_DMV_Grammar is given None for
    # p_ROOT/p_STOP/p_ATTACH, which make_GO_AT() cannot iterate. Rewrite in
    # terms of p_ROOT/p_STOP/p_ATTACH (like testgrammar_h) before using this.
    _h_ = CNF_DMV_Rule((SEAL,0), STOP, ( RGOL,0), 1.0, 1.0) # LSTOP
    h_S = CNF_DMV_Rule(( RGOL,0),(GOR,0), STOP, 0.4, 0.3) # RSTOP
    h_A = CNF_DMV_Rule(( RGOL,0),(SEAL,0),( RGOL,0),0.2, 0.1) # Lattach
    h_Aa= CNF_DMV_Rule(( RGOL,0),(SEAL,1),( RGOL,0),0.4, 0.6) # Lattach to a
    h = CNF_DMV_Rule((GOR,0),(GOR,0),(SEAL,0), 1.0, 1.0) # Rattach
    ha = CNF_DMV_Rule((GOR,0),(GOR,0),(SEAL,1), 1.0, 1.0) # Rattach to a
    rh = CNF_DMV_Rule( ROOT, STOP, (SEAL,0), 0.9, 0.9) # ROOT

    _a_ = CNF_DMV_Rule((SEAL,1), STOP, ( RGOL,1), 1.0, 1.0) # LSTOP
    a_S = CNF_DMV_Rule(( RGOL,1),(GOR,1), STOP, 0.4, 0.3) # RSTOP
    a_A = CNF_DMV_Rule(( RGOL,1),(SEAL,1),( RGOL,1),0.4, 0.6) # Lattach
    a_Ah= CNF_DMV_Rule(( RGOL,1),(SEAL,0),( RGOL,1),0.2, 0.1) # Lattach to h
    a = CNF_DMV_Rule((GOR,1),(GOR,1),(SEAL,1), 1.0, 1.0) # Rattach
    ah = CNF_DMV_Rule((GOR,1),(GOR,1),(SEAL,0), 1.0, 1.0) # Rattach to h
    ra = CNF_DMV_Rule( ROOT, STOP, (SEAL,1), 0.1, 0.1) # ROOT

    # NOTE(review): built but never passed to the grammar below.
    p_rules = [ h_Aa, ha, a_Ah, ah, ra, _a_, a_S, a_A, a, rh, _h_, h_S, h_A, h ]

    # terminal probabilities, keyed ((GOR, tag), word)
    b = {}
    b[(GOR, 0), 'h'] = 1.0
    b[(GOR, 1), 'a'] = 1.0

    return CNF_DMV_Grammar({0:'h',1:'a'}, {'h':0,'a':1},
                           None,None,None,b)
def testgrammar_h():
    '''Build the single-tag grammar ('h' = 0) used by regression_tests().

    p_ROOT[h] = 1.0. Left STOP is 1.0 for both adjacencies; right STOP is
    0.4 non-adjacent and 0.3 adjacent. p_ATTACH only has to exist so that
    _make_rules generates the attachment rules. Their GO_AT probabilities
    are then overwritten by hand and the rules rebuilt: with left STOP at
    1.0, make_GO_AT would give 0 for every left attachment.
    '''
    h = 0
    p_ROOT, p_STOP, p_ATTACH = {}, {}, {}
    p_ROOT[h] = 1.0
    p_STOP[h,LEFT,NON] = 1.0
    p_STOP[h,LEFT,ADJ] = 1.0
    p_STOP[h,RIGHT,NON] = 0.4
    p_STOP[h,RIGHT,ADJ] = 0.3
    p_ATTACH[h,h,LEFT] = 1.0 # not used
    p_ATTACH[h,h,RIGHT] = 1.0 # not used
    p_terminals = {}
    p_terminals[(GOR, 0), 'h'] = 1.0

    g = CNF_DMV_Grammar({h:'h'}, {'h':h}, p_ROOT, p_STOP, p_ATTACH, p_terminals)

    g.p_GO_AT[h,h,LEFT,NON] = 0.6 # these probabilities are impossible
    g.p_GO_AT[h,h,LEFT,ADJ] = 0.7 # so add them manually...
    g.p_GO_AT[h,h,RIGHT,NON] = 1.0
    g.p_GO_AT[h,h,RIGHT,ADJ] = 1.0
    g.make_all_rules() # rebuild so the rules pick up the new p_GO_AT
    return g
def testreestimation_h():
    '''Turn on REEST debugging and re-estimate testgrammar_h() on "h h h".'''
    DEBUG.add('REEST')
    corpus = ['h h h'.split()]
    return reestimate(testgrammar_h(), corpus)
def regression_tests():
    '''Compare inner()/outer() on testgrammar_h() with hand-computed values.'''
    cases = [
        # = .120 + .063, since we have no loc_h
        ("0.1830", "%.4f",
         lambda: inner(0, 2, (SEAL,0), testgrammar_h(), 'h h'.split(), {})),
        # = .0498 + .1092 +.0252
        ("0.1842", "%.4f",
         lambda: inner(0, 3, (SEAL,0), testgrammar_h(), 'h h h'.split(), {})),
        ("0.1842", "%.4f",
         lambda: inner_sent(testgrammar_h(), 'h h h'.split())),
        ("0.61", "%.2f",
         lambda: outer(1, 3, ( RGOL,0), testgrammar_h(), 'h h h'.split(), {}, {})),
        ("0.58", "%.2f",
         lambda: outer(1, 3, (NRGOL,0), testgrammar_h(), 'h h h'.split(), {}, {})),
    ]
    for expected, fmt, compute in cases:
        test(expected, fmt % compute())
if __name__ == "__main__":
    # start with every debug channel switched off
    DEBUG.clear()

    # Profiling / timing hooks; uncomment as needed:
    # import profile
    # profile.run('testreestimation()')

    # DEBUG.add('reest_attach')
    # import timeit
    # print timeit.Timer("cnf_dmv.testreestimation_h()",'''import cnf_dmv
    # reload(cnf_dmv)''').timeit(1)
if __name__ == "__main__":
    regression_tests()
    # To inspect the harmonic grammar instead:
    # g = testgrammar()
    # print g
    print("TODO!!!! fix outer (also, make mothersL and R faster somehow)")