written reest for cnf, todo: test it
[dmvccm.git] / src / cnf_dmv.py
blobc6dc4aecabe8f3f10c02eb606c4b317a9a106f65
1 # cnf_dmv.py
3 #import numpy # numpy provides Fast Arrays, for future optimization
4 import io
6 from common_dmv import *
# Override SEALS for this module (presumably replacing a SEALS list that
# comes in via the star import from common_dmv -- TODO confirm).
# CNF_DMV_Rule.__init__ rejects any node whose seal is not listed here.
SEALS = [GOR, RGOL, SEAL, NGOR, NRGOL]

if __name__ == "__main__":
    print("cnf_dmv module tests:")
def make_GO_AT(p_STOP,p_ATTACH):
    '''Combine attachment and continuation ("go") probabilities.

    For every key (a, h, dir) of p_ATTACH, stores under both
    (a, h, dir, NON) and (a, h, dir, ADJ) the value
        p_ATTACH[a, h, dir] * (1 - p_STOP[h, dir, adjacency]).
    Raises KeyError if p_STOP lacks one of the (h, dir, NON/ADJ) keys.
    '''
    p_GO_AT = {}
    for key, p_attach in p_ATTACH.items():
        a, h, direction = key
        for adjacency in (NON, ADJ):
            p_GO_AT[a, h, direction, adjacency] = \
                p_attach * (1 - p_STOP[h, direction, adjacency])
    return p_GO_AT
class CNF_DMV_Grammar(io.Grammar):
    '''The DMV-PCFG.

    Public members:
    p_STOP, p_ROOT, p_ATTACH, p_terminals
    These are changed in the Maximization step, then used to set the
    new probabilities of each CNF_DMV_Rule.

    p_GO_AT is derived from p_STOP and p_ATTACH by make_GO_AT() in
    __init__ only; if p_STOP or p_ATTACH are changed later, recompute it
    before calling make_all_rules() again.

    __p_rules is private, but we can still say stuff like:
    for r in g.all_rules():
        r.prob = (1-p_STOP[...]) * p_ATTACH[...]
    '''
    def __str__(self):
        '''One line per rule, heads rendered through self.numtag.'''
        return "".join("%s\n" % r.__str__(self.numtag)
                       for r in self.all_rules())

    def LHSs(self):
        '''ROOT followed by every (seal, head) node, for all heads and seals.'''
        return [ROOT] + [(s_h,h)
                         for h in self.headnums()
                         for s_h in SEALS]

    def sent_rules(self, sent_nums):
        '''Rules (over all LHSs) whose two daughters both have a head in
        sent_nums or are STOP.'''
        sent_nums_stop = sent_nums + [POS(STOP)]
        return [ r for LHS in self.LHSs()
                 for r in self.arg_rules(LHS, sent_nums)
                 if POS(r.L()) in sent_nums_stop
                 and POS(r.R()) in sent_nums_stop ]

    # used in outer:
    def mothersR(self, w_node, argnums):
        '''For all LHS and x, return all rules of the form 'LHS->x w_node'.'''
        return [r for LHS in self.LHSs()
                for r in self.arg_rules(LHS, argnums)
                if r.R() == w_node]

    def mothersL(self, w_node, argnums):
        '''For all LHS and x, return all rules of the form 'LHS->w_node x'.'''
        return [r for LHS in self.LHSs()
                for r in self.arg_rules(LHS, argnums)
                if r.L() == w_node]

    # used in inner:
    def arg_rules(self, LHS, argnums):
        '''Rules of LHS where at least one daughter's head is in argnums.'''
        return [r for r in self.rules(LHS)
                if (POS(r.R()) in argnums
                    or POS(r.L()) in argnums)]

    def make_all_rules(self):
        '''Build the rules of every LHS over all head numbers (reading the
        current p_ROOT, p_STOP, p_ATTACH and p_GO_AT) and hand them to
        new_rules().'''
        self.new_rules([r for LHS in self.LHSs()
                        for r in self._make_rules(LHS, self.headnums())])

    def _make_rules(self, LHS, argnums):
        '''This is where the CNF grammar is defined. Also, s_dir_typ shows how
        useful it'd be to split up the seals into direction and
        type... todo?

        ROOT gets one rule ROOT -> STOP (SEAL,h) per head h in argnums.
        GOR nodes only rewrite to terminals, so they get no rules here.
        RGOL/SEAL get the two STOP rules in their direction; NGOR/NRGOL get
        the ATTACH rules for every argument a with (a,h,dir) in p_ATTACH.
        Only the list actually needed is built, so e.g. p_STOP keys that
        an ATTACH LHS never uses are not looked up.
        '''
        h = POS(LHS)
        if LHS == ROOT:
            return [CNF_DMV_Rule(LEFT, LHS, (SEAL,head), STOP, self.p_ROOT[head])
                    for head in set(argnums)]
        s_h = seals(LHS)
        if s_h == GOR:
            return [] # only terminals from here on
        s_dir_type = { # seal of LHS
            RGOL: (RIGHT, 'STOP'), NGOR: (RIGHT, 'ATTACH'),
            SEAL: (LEFT, 'STOP'), NRGOL: (LEFT, 'ATTACH') }
        dir_s_adj = { # seal of h_daughter, and whether it is still adjacent
            RIGHT: [(GOR, True),(NGOR, False)] ,
            LEFT: [(RGOL,True),(NRGOL,False)] }
        direction, rule_type = s_dir_type[s_h]
        if rule_type == 'ATTACH':
            return [CNF_DMV_Rule(direction, LHS, (s, h), (SEAL,a),
                                 self.p_GO_AT[a,h,direction,adj])
                    for a in set(argnums) if (a,h,direction) in self.p_ATTACH
                    for s, adj in dir_s_adj[direction]]
        # rule_type == 'STOP'
        return [CNF_DMV_Rule(direction, LHS, (s, h), STOP,
                             self.p_STOP[h,direction,adj])
                for s, adj in dir_s_adj[direction]]

    def __init__(self, numtag, tagnum, p_ROOT, p_STOP, p_ATTACH, p_terminals):
        io.Grammar.__init__(self, numtag, tagnum, [], p_terminals)
        self.p_STOP = p_STOP
        self.p_ATTACH = p_ATTACH
        self.p_ROOT = p_ROOT
        self.p_GO_AT = make_GO_AT(self.p_STOP, self.p_ATTACH)
        self.make_all_rules()
class CNF_DMV_Rule(io.CNF_Rule):
    '''A single CNF rule in the PCFG, of the form
    LHS -> L R
    where LHS, L and R are 'nodes', eg. of the form (seals, head).

    Public members:
    prob

    Private members:
    __L, __R, __LHS

    Different rule-types have different probabilities associated with
    them, see formulas.pdf
    '''
    def seals(self):
        '''The seal of this rule's LHS.'''
        return seals(self.LHS())

    def POS(self):
        '''The head number of this rule's LHS.'''
        return POS(self.LHS())

    def __init__(self, dir, LHS, h_daughter, a_daughter, prob):
        '''dir gives the side of a_daughter: LEFT makes it L (and
        h_daughter R), RIGHT makes it R (and h_daughter L).

        Raises ValueError if dir is neither LEFT nor RIGHT, or if the
        seal of LHS, L or R is not in SEALS.
        '''
        self.__dir = dir
        if dir == LEFT:
            L, R = a_daughter, h_daughter
        elif dir == RIGHT:
            L, R = h_daughter, a_daughter
        else:
            # (dir,) so that a tuple dir is formatted instead of raising
            # TypeError from the % operator.
            raise ValueError("dir must be LEFT or RIGHT, given: %s" % (dir,))
        for b_h in [LHS, L, R]:
            if seals(b_h) not in SEALS:
                raise ValueError("seals must be in %s; was given: %s"
                                 % (SEALS, seals(b_h)))
        io.CNF_Rule.__init__(self, LHS, L, R, prob)

    def adj(self):
        '''Whether the head daughter is still adjacent: RGOL as R of a
        LEFT rule, GOR as L of a RIGHT rule. 'undefined' for ROOT.'''
        if self.__dir == LEFT:
            return seals(self.R()) == RGOL
        else: # RIGHT
            return seals(self.L()) == GOR

    def __str__(self, tag=lambda x:x):
        if self.adj(): adj_str = "adj"
        else: adj_str = "non_adj"
        if self.LHS() == ROOT: adj_str = ""
        return "%s --> %s %s\t[%.2f] %s" % (node_str(self.LHS(), tag),
                                            node_str(self.L(), tag),
                                            node_str(self.R(), tag),
                                            self.prob,
                                            adj_str)
###################################
# dmv-specific version of inner() #
###################################
def inner(i, j, LHS, g, sent, ichart=None):
    ''' A CNF rewrite of io.inner(), to take STOP rules into account.

    Returns the inside probability of LHS spanning sent[i:j] under g.

    ichart is a memo keyed on (i, j, LHS) for non-terminal spans; it is
    read and filled in place, so callers may pass the same dict to later
    calls on the same sentence and grammar. When omitted, a fresh dict is
    used for this call.
    '''
    if ichart is None:
        # A shared mutable default would return chart entries computed
        # for one sentence/grammar in calls on another.
        ichart = {}

    def O(i,j):
        "The word at position i (j is unused)."
        return sent[i]

    sent_nums = g.sent_nums(sent)

    def e(i,j,LHS, n_t):
        '''Memoised inside probability; n_t is the debug indent depth.'''
        def tab():
            "Tabs for debug output"
            return "\t"*n_t
        if (i, j, LHS) in ichart:
            if 'INNER' in DEBUG:
                print("%s*= %.4f in ichart: i:%d j:%d LHS:%s" % (tab(), ichart[i, j, LHS], i, j, node_str(LHS)))
            return ichart[i, j, LHS]
        # if seals(LHS) == RGOL then we have to STOP first
        if i == j-1 and seals(LHS) == GOR:
            # terminal span; these values are not stored in ichart
            if (LHS, O(i,j)) in g.p_terminals:
                prob = g.p_terminals[LHS, O(i,j)] # "b[LHS, O(s)]" in Lari&Young
            else:
                prob = 0.0
                if 'INNER' in DEBUG:
                    print("%sLACKING TERMINAL:" % tab())
            if 'INNER' in DEBUG:
                print("%s*= %.4f (terminal: %s -> %s)" % (tab(),prob, node_str(LHS), O(i,j)))
            return prob
        p = 0.0 # "sum over j,k in a[LHS,j,k]"
        for rule in g.arg_rules(LHS, sent_nums):
            if 'INNER' in DEBUG:
                print("%ssumming rule %s i:%d j:%d" % (tab(),rule,i,j))
            L = rule.L()
            R = rule.R()
            # if it's a STOP rule, rewrite for the same xrange:
            if (L == STOP) or (R == STOP):
                if L == STOP:
                    pLR = e(i, j, R, n_t+1)
                elif R == STOP:
                    pLR = e(i, j, L, n_t+1)
                p += rule.p() * pLR
                if 'INNER' in DEBUG:
                    print("%sp= %.4f (STOP)" % (tab(), p))
            elif j > i+1 and seals(LHS) != GOR:
                # not a STOP, attachment rewrite:
                for k in xtween(i, j): # i<k<j
                    p_L = e(i, k, L, n_t+1)
                    p_R = e(k, j, R, n_t+1)
                    p += rule.p() * p_L * p_R
                    if 'INNER' in DEBUG:
                        print("%sp= %.4f (ATTACH, p_L:%.4f, p_R:%.4f, rule:%.4f)" % (tab(), p,p_L,p_R,rule.p()))
        ichart[i, j, LHS] = p
        return p
    # end of e-function

    inner_prob = e(i,j,LHS, 0)
    if 'INNER' in DEBUG:
        print(debug_ichart(g,sent,ichart))
    return inner_prob
# end of cnf_dmv.inner(i, j, LHS, g, sent, ichart=None)
def debug_ichart(g,sent,ichart):
    '''Render ichart as text for debugging.

    One line per entry: the LHS (named through g.numtag), the first and
    last words of its span sent[i:j], and the probability. Entries whose
    value is a dict are skipped (the 'tree' entries).
    '''
    lines = ["---ICHART:---\n"]
    for (i,j,LHS),v in ichart.items():
        if isinstance(v, dict): # skip 'tree'
            continue
        lines.append("%s -> %s ... %s: \t%.4f\n" % (node_str(LHS,g.numtag),
                                                   sent[i], sent[j-1], v))
    lines.append("---ICHART:end---\n")
    return "".join(lines)
def inner_sent(g, sent, ichart=None):
    '''Inside probability of the whole sentence: ROOT over sent[0:len(sent)].

    ichart is passed to inner() and filled in place; when omitted, a
    fresh dict is used so that separate calls never share chart entries.
    '''
    if ichart is None:
        ichart = {}
    return inner(0, len(sent), ROOT, g, sent, ichart)
#######################################
# cnf_dmv-specific version of outer() #
#######################################
def outer(i,j,w_node, g, sent, ichart=None, ochart=None):
    '''Outside probability of w_node spanning sent[i:j] under g.

    ichart is the inside-chart handed to inner(); ochart memoises the
    outside values keyed on (i, j, node). Both are read and filled in
    place; fresh dicts are used when they are omitted.
    '''
    if ichart is None:
        ichart = {}
    if ochart is None:
        ochart = {}

    def e(i,j,LHS):
        # or we could just look it up in ichart, assuming ichart to be done
        return inner(i, j, LHS, g, sent, ichart)

    sent_nums = g.sent_nums(sent)
    if w_node != ROOT and POS(w_node) not in sent_nums[i:j]:
        # sanity check, w must be able to dominate sent[i:j]. ROOT is
        # exempt: f() below gives it 1.0 over the full sentence, and since
        # POS(ROOT) is presumably never a tag in sent_nums this check
        # would otherwise always turn that into 0.0.
        return 0.0

    def f(i,j,w_node):
        if (i,j,w_node) in ochart:
            return ochart[(i, j, w_node)]
        if w_node == ROOT:
            if i == 0 and j == len(sent):
                return 1.0
            else: # ROOT may only be used on full sentence
                return 0.0 # but we may have non-ROOTs over full sentence too

        p = 0.0

        for rule in g.mothersL(w_node, sent_nums): # rule.L() == w_node
            if 'OUTER' in DEBUG: print("w_node:%s (L) ; %s"%(node_str(w_node),rule))
            if rule.R() == STOP:
                p0 = f(i,j,rule.LHS()) * rule.p()
                if 'OUTER' in DEBUG: print(p0)
                p += p0
            else:
                for k in xgt(j,sent): # i<j<k
                    p0 = f(i,k, rule.LHS() ) * rule.p() * e(j,k, rule.R() )
                    if 'OUTER' in DEBUG: print(p0)
                    p += p0

        for rule in g.mothersR(w_node, sent_nums): # rule.R() == w_node
            if 'OUTER' in DEBUG: print("w_node:%s (R) ; %s"%(node_str(w_node),rule))
            if rule.L() == STOP:
                p0 = f(i,j,rule.LHS()) * rule.p()
                if 'OUTER' in DEBUG: print(p0)
                p += p0
            else:
                for k in xlt(i): # k<i<j
                    p0 = e(k,i, rule.L() ) * rule.p() * f(k,j, rule.LHS() )
                    if 'OUTER' in DEBUG: print(p0)
                    p += p0

        ochart[i,j,w_node] = p
        return p

    return f(i,j,w_node)
# end outer(i,j,w_node, g,sent, ichart,ochart)
##############################
#      reestimation, todo:   #
##############################
def reest_zeros(rules):
    '''Map ('num'|'den', LHS, L, R) to 0.0 for every rule in rules.'''
    zeros = {}
    for rule in rules:
        node_key = (rule.LHS(), rule.L(), rule.R())
        for part in ('num', 'den'):
            zeros[(part,) + node_key] = 0.0
    return zeros
def reest_freq(g, corpus):
    ''' P_STOP(-STOP|...) = 1 - P_STOP(STOP|...)

    Accumulate, over every sentence of corpus, expected-count numerators
    and denominators for the rules of g from each sentence's inside and
    outside charts. Returns a dict keyed on ('num'|'den', LHS, L, R) (as
    built by reest_zeros(g.all_rules())); reestimate() divides num by den.

    All counts are expected counts, i.e. divided by the sentence
    probability p_sent, so that every sentence contributes on the same
    scale. Sentences with p_sent == 0 contribute nothing.

    NOTE(review): num and den are only accumulated for g.sent_rules(),
    i.e. in sentences where both daughters' heads occur, so the den of an
    ATTACH rule skips sentences that contain the head but not the
    argument. Confirm this is the intended normalization.
    '''
    f = reest_zeros(g.all_rules())
    ichart = {}
    ochart = {}

    # The helpers below read p_sent, ichart and ochart from this scope at
    # call time, so they see the per-sentence values bound in the loop.
    p_sent = None # 50 % speed increase on storing this locally

    def c_g(i,j,LHS,sent):
        '''Expected count of LHS over sent[i:j].'''
        if p_sent == 0.0:
            return 0.0
        return e_g(i,j,LHS,sent) * f_g(i,j,LHS,sent) / p_sent

    def w1_g(i,j,rule,sent): # unary (stop) rules, LHS -> child_node
        '''Expected count of the STOP rule over sent[i:j].'''
        if rule.L() == STOP: child = rule.R()
        elif rule.R() == STOP: child = rule.L()
        else: raise ValueError("expected a stop rule: %s"%(rule,))

        if p_sent == 0.0: return 0.0

        p_out = f_g(i,j,rule.LHS(),sent)
        if p_out == 0.0: return 0.0

        return rule.p() * e_g(i,j,child,sent) * p_out / p_sent

    def w_g(i,j,rule,sent):
        '''Expected count of the binary rule over sent[i:j].'''
        if p_sent == 0.0 or i+1 == j: return 0.0

        p_out = f_g(i,j,rule.LHS(),sent)
        if p_out == 0.0: return 0.0

        p = 0.0
        for k in xtween(i,j):
            p += rule.p() * e_g(i,k,rule.L(),sent) * e_g(k,j,rule.R(),sent) * p_out
        return p / p_sent

    def f_g(i,j,LHS,sent):
        if (i,j,LHS) in ochart: return ochart[i,j,LHS]
        else: return outer(i,j,LHS,g,sent,ichart,ochart)

    def e_g(i,j,LHS,sent):
        if (i,j,LHS) in ichart: return ichart[i,j,LHS]
        else: return inner(i,j,LHS,g,sent,ichart)

    for sent in corpus:
        if 'REEST' in DEBUG: print(sent)
        ichart = {}
        ochart = {}
        # since we keep re-using p_sent, it seems better to have
        # sentences as the outer loop; o/w we'd have to keep every chart
        p_sent = inner_sent(g, sent, ichart)

        sent_nums = g.sent_nums(sent)
        sent_rules = g.sent_rules(sent_nums)
        for r in sent_rules:
            LHS, L, R = r.LHS(), r.L(), r.R()
            if 'REEST' in DEBUG: print(r)
            if LHS == ROOT:
                # The outside probability of ROOT over the whole sentence
                # is 1, so the expected count of ROOT -> STOP R is
                # p * inner(R) / p_sent, and ROOT itself occurs once per
                # parsable sentence. (Previously num/den were unnormalized
                # p*inner and p_sent, which weights sentences by their
                # probability when summed over a corpus.)
                if p_sent != 0.0:
                    f['num',LHS,L,R] += r.p() * e_g(0, len(sent), R, sent) / p_sent
                    f['den',LHS,L,R] += 1.0
                continue # !!! o/w we add wrong values to it below
            if L == STOP or R == STOP:
                w = w1_g
            else:
                w = w_g
            for i in xlt(len(sent)):
                for j in xgt(i, sent):
                    f['num',LHS,L,R] += w(i,j, r, sent)
                    f['den',LHS,L,R] += c_g(i,j, LHS, sent)
    return f
def reestimate(g, corpus):
    '''One reestimation pass over corpus.

    Each rule of g gets prob = num / den from reest_freq(g, corpus) when
    den is positive; otherwise prob is set to den itself. Returns the
    frequency dict from reest_freq.
    '''
    freqs = reest_freq(g, corpus)
    for rule in g.all_rules():
        nodes = (rule.LHS(), rule.L(), rule.R())
        den = freqs[('den',) + nodes]
        if den > 0.0:
            rule.prob = freqs[('num',) + nodes] / den
        else:
            rule.prob = den
    return freqs
##############################
#     testing functions:     #
##############################
def testgrammar():
    '''CNF grammar initialised by cnf_harmonic on loc_h_dmv.testcorpus.'''
    from loc_h_dmv import testcorpus  # use the same data as loc_h_dmv
    import cnf_harmonic
    reload(cnf_harmonic)  # re-execute it in case it was imported earlier
    return cnf_harmonic.initialize(testcorpus)
def testreestimation():
    '''Reestimate testgrammar() on loc_h_dmv.testcorpus; returns (f, g).'''
    from loc_h_dmv import testcorpus
    grammar = testgrammar()
    freqs = reestimate(grammar, testcorpus)
    return freqs, grammar
def testgrammar_a(): # Non, Adj
    '''Hand-written two-tag test grammar ('h' = 0, 'a' = 1).

    NOTE(review): this looks stale. Every CNF_DMV_Rule call below passes
    (LHS, L, R, p_non, p_adj), but CNF_DMV_Rule.__init__ takes
    (dir, LHS, h_daughter, a_daughter, prob). The first call therefore
    gets dir=(SEAL,0), which is presumably neither LEFT nor RIGHT, so it
    would raise in the dir check -- TODO confirm against common_dmv.
    Even past that, p_rules is never used, and the final
    CNF_DMV_Grammar(..., None, None, None, b) call would hand
    p_STOP=None and p_ATTACH=None to make_GO_AT(), which iterates over
    p_ATTACH. Nothing in the visible code calls testgrammar_a().
    '''
    # rules for head 0 ('h'); the two floats were meant as (non-adj, adj)
    _h_ = CNF_DMV_Rule((SEAL,0), STOP, ( RGOL,0), 1.0, 1.0) # LSTOP
    h_S = CNF_DMV_Rule(( RGOL,0),(GOR,0), STOP, 0.4, 0.3) # RSTOP
    h_A = CNF_DMV_Rule(( RGOL,0),(SEAL,0),( RGOL,0),0.2, 0.1) # Lattach
    h_Aa= CNF_DMV_Rule(( RGOL,0),(SEAL,1),( RGOL,0),0.4, 0.6) # Lattach to a
    h = CNF_DMV_Rule((GOR,0),(GOR,0),(SEAL,0), 1.0, 1.0) # Rattach
    ha = CNF_DMV_Rule((GOR,0),(GOR,0),(SEAL,1), 1.0, 1.0) # Rattach to a
    rh = CNF_DMV_Rule( ROOT, STOP, (SEAL,0), 0.9, 0.9) # ROOT

    # rules for head 1 ('a')
    _a_ = CNF_DMV_Rule((SEAL,1), STOP, ( RGOL,1), 1.0, 1.0) # LSTOP
    a_S = CNF_DMV_Rule(( RGOL,1),(GOR,1), STOP, 0.4, 0.3) # RSTOP
    a_A = CNF_DMV_Rule(( RGOL,1),(SEAL,1),( RGOL,1),0.4, 0.6) # Lattach
    a_Ah= CNF_DMV_Rule(( RGOL,1),(SEAL,0),( RGOL,1),0.2, 0.1) # Lattach to h
    a = CNF_DMV_Rule((GOR,1),(GOR,1),(SEAL,1), 1.0, 1.0) # Rattach
    ah = CNF_DMV_Rule((GOR,1),(GOR,1),(SEAL,0), 1.0, 1.0) # Rattach to h
    ra = CNF_DMV_Rule( ROOT, STOP, (SEAL,1), 0.1, 0.1) # ROOT

    # NOTE(review): collected but never passed on below
    p_rules = [ h_Aa, ha, a_Ah, ah, ra, _a_, a_S, a_A, a, rh, _h_, h_S, h_A, h ]

    # terminal probabilities: (GOR, head) -> word
    b = {}
    b[(GOR, 0), 'h'] = 1.0
    b[(GOR, 1), 'a'] = 1.0

    return CNF_DMV_Grammar({0:'h',1:'a'}, {'h':0,'a':1},
                           None,None,None,b)
def testgrammar_h():
    '''One-tag test grammar: word 'h' with head number 0.

    The p_GO_AT values wanted here are inconsistent with p_STOP and
    p_ATTACH (e.g. p_STOP LEFT is 1.0, which makes make_GO_AT() give 0.0
    for going left), so they are overwritten by hand after construction
    and the rules are rebuilt with make_all_rules().
    '''
    h = 0
    p_ROOT, p_STOP, p_ATTACH = {},{},{}
    p_ROOT[h] = 1.0
    p_STOP[h,LEFT,NON] = 1.0
    p_STOP[h,LEFT,ADJ] = 1.0
    p_STOP[h,RIGHT,NON] = 0.4
    p_STOP[h,RIGHT,ADJ] = 0.3
    # The values are irrelevant (p_GO_AT is overwritten below), but the
    # keys must exist: _make_rules only emits ATTACH rules for
    # (a,h,dir) in p_ATTACH.
    p_ATTACH[h,h,LEFT] = 1.0
    p_ATTACH[h,h,RIGHT] = 1.0
    p_terminals = {}
    p_terminals[(GOR, h), 'h'] = 1.0

    g = CNF_DMV_Grammar({h:'h'}, {'h':h}, p_ROOT, p_STOP, p_ATTACH, p_terminals)

    g.p_GO_AT[h,h,LEFT,NON] = 0.6 # these probabilities are impossible
    g.p_GO_AT[h,h,LEFT,ADJ] = 0.7 # so add them manually...
    g.p_GO_AT[h,h,RIGHT,NON] = 1.0
    g.p_GO_AT[h,h,RIGHT,ADJ] = 1.0
    g.make_all_rules()
    return g
def testreestimation_h():
    '''Reestimate testgrammar_h() on the one-sentence corpus "h h h".

    Turns on 'REEST' debug output first; DEBUG is left modified.
    '''
    DEBUG.add('REEST')
    grammar = testgrammar_h()
    corpus = ['h h h'.split()]
    return reestimate(grammar, corpus)
def regression_tests():
    '''Check inner(), inner_sent() and outer() on testgrammar_h() against
    stored values, in order, via test(expected, actual).'''
    checks = [
        # = .120 + .063, since we have no loc_h
        ("0.1830", lambda: "%.4f" % inner(0, 2, (SEAL,0), testgrammar_h(),
                                          'h h'.split(), {})),
        # = .0498 + .1092 +.0252
        ("0.1842", lambda: "%.4f" % inner(0, 3, (SEAL,0), testgrammar_h(),
                                          'h h h'.split(), {})),
        ("0.1842", lambda: "%.4f" % inner_sent(testgrammar_h(),
                                               'h h h'.split())),
        ("0.61", lambda: "%.2f" % outer(1, 3, ( RGOL,0), testgrammar_h(),
                                        'h h h'.split(), {}, {})),
        ("0.58", lambda: "%.2f" % outer(1, 3, (NRGOL,0), testgrammar_h(),
                                        'h h h'.split(), {}, {})),
    ]
    for expected, compute in checks:
        test(expected, compute())
if __name__ == "__main__":
    DEBUG.clear()
    # Profiling / timing hooks, kept for reference:
    #   import profile
    #   profile.run('testreestimation()')
    #   DEBUG.add('reest_attach')
    #   import timeit
    #   print timeit.Timer("cnf_dmv.testreestimation_h()",'''import cnf_dmv
    #   reload(cnf_dmv)''').timeit(1)
    regression_tests()
    # g = testgrammar()
    # print g