remove backups, chmod -x
[dmvccm.git] / src / pcnf_dmv.py
blob5bdbff7488395b4c36b08424983fd23239f0f145
1 # pcnf_dmv.py
3 #import numpy # numpy provides Fast Arrays, for future optimization
4 import io
6 from common_dmv import *
# The seal inventory this module works with.  The name is rebound here on
# purpose; presumably it shadows a SEALS brought in by the star-import from
# common_dmv (TODO confirm).  PCNF_DMV_Grammar.LHSs() and the validation in
# PCNF_DMV_Rule.__init__ both read this list.
SEALS = [GOR, RGOL, SEAL, NGOR, NRGOL] # overwriting here

# Print a banner when the file is executed as a script; the self-tests
# themselves are started by the __main__ guard at the end of the file.
if __name__ == "__main__":
    print "pcnf_dmv module tests:"
def make_GO_AT(p_STOP,p_ATTACH):
    '''Pre-multiply attachment probabilities with "do not stop" probabilities.

    p_ATTACH is keyed by (arg, head, direction) and p_STOP by
    (head, direction, adjacency).  For every key of p_ATTACH and for both
    adjacency values NON and ADJ, the returned dict maps
        (arg, head, direction, adjacency)
    to
        p_ATTACH[arg, head, direction] * (1 - p_STOP[head, direction, adjacency])
    '''
    go_at = {}
    for key in p_ATTACH:
        arg, head, direction = key
        attach_prob = p_ATTACH[key]
        # NON first, then ADJ, for every attachment entry
        for adjacency in (NON, ADJ):
            keep_going = 1 - p_STOP[head, direction, adjacency]
            go_at[arg, head, direction, adjacency] = attach_prob * keep_going
    return go_at
class PCNF_DMV_Grammar(io.Grammar):
    '''The DMV-PCFG.

    Public members:
      p_STOP, p_ROOT, p_ATTACH, p_terminals
    These are changed in the Maximization step, then used to set the
    new probabilities of each PCNF_DMV_Rule.  p_GO_AT is derived from
    p_STOP and p_ATTACH by make_GO_AT() in the constructor.

    __p_rules is private, but we can still say stuff like:
      for r in g.all_rules():
          r.prob = (1-p_STOP[...]) * p_ATTACH[...]
    '''

    def __init__(self, numtag, tagnum, p_ROOT, p_STOP, p_ATTACH, p_terminals):
        '''Store the probability tables and build all rules from them.

        numtag, tagnum and p_terminals are handed on to io.Grammar
        together with an empty initial rule list; p_ROOT, p_STOP and
        p_ATTACH are kept as public members and p_GO_AT is computed from
        the latter two.
        '''
        io.Grammar.__init__(self, numtag, tagnum, [], p_terminals)
        self.p_STOP = p_STOP
        self.p_ATTACH = p_ATTACH
        self.p_ROOT = p_ROOT
        self.p_GO_AT = make_GO_AT(self.p_STOP, self.p_ATTACH)
        # per-node caches filled lazily by mothersL() / mothersR()
        self.__mothersL = {}
        self.__mothersR = {}
        self.make_all_rules()

    def __str__(self):
        '''One line per rule, each rendered with self.numtag.'''
        return "".join(["%s\n" % r.__str__(self.numtag)
                        for r in self.all_rules()])

    def LHSs(self):
        '''ROOT followed by every (seal, head) pair, heads varying slowest.'''
        return [ROOT] + [(s_h, h)
                         for h in self.headnums()
                         for s_h in SEALS]

    def sent_rules(self, sent_nums):
        '''Rules whose two daughters both have a POS in sent_nums or are STOP.'''
        sent_nums_stop = sent_nums + [POS(STOP)]
        return [r for LHS in self.LHSs()
                for r in self.arg_rules(LHS, sent_nums)
                if POS(r.L()) in sent_nums_stop
                and POS(r.R()) in sent_nums_stop]

    # used in outer:
    def mothersR(self, w_node, argnums):
        '''For all LHS and x, return all rules of the form 'LHS->x w_node'.

        Only rules whose left daughter x has a POS in argnums, or is
        STOP, are returned.  argnums itself is left untouched.
        '''
        if w_node not in self.__mothersR:
            self.__mothersR[w_node] = [r for LHS in self.LHSs()
                                       for r in self.rules(LHS)
                                       if r.R() == w_node]
        # Work on a copy: appending to the caller's list would make it grow
        # by one element on every call (outer() re-uses one list throughout
        # its recursion).
        allowed = list(argnums) + [POS(STOP)]
        return [r for r in self.__mothersR[w_node]
                if POS(r.L()) in allowed]

    def mothersL(self, w_node, argnums):
        '''For all LHS and x, return all rules of the form 'LHS->w_node x'.

        Only rules whose right daughter x has a POS in argnums, or is
        STOP, are returned.  argnums itself is left untouched.
        '''
        if w_node not in self.__mothersL:
            self.__mothersL[w_node] = [r for LHS in self.LHSs()
                                       for r in self.rules(LHS)
                                       if r.L() == w_node]
        # same reasoning as in mothersR(): never mutate the argument
        allowed = list(argnums) + [POS(STOP)]
        return [r for r in self.__mothersL[w_node]
                if POS(r.R()) in allowed]

    # used in inner:
    def arg_rules(self, LHS, argnums):
        '''Rules for LHS with at least one daughter whose POS is in argnums.'''
        return [r for r in self.rules(LHS)
                if (POS(r.R()) in argnums
                    or POS(r.L()) in argnums)]

    def make_all_rules(self):
        '''(Re)build the full rule set from the current probability tables.'''
        # The mothers caches hold rule objects of the previous rule set, so
        # drop them whenever the rules are regenerated.
        self.__mothersL = {}
        self.__mothersR = {}
        self.new_rules([r for LHS in self.LHSs()
                        for r in self._make_rules(LHS, self.headnums())])

    def _make_rules(self, LHS, argnums):
        '''This is where the PCNF grammar is defined. Also, s_dir_type shows how
        useful it'd be to split up the seals into direction and
        type... todo?

        Returns the list of rules rewriting LHS:
          ROOT       -> one rule per head in argnums, weighted by p_ROOT
          GOR        -> no rules (only terminals from there on)
          RGOL, SEAL -> STOP rules, weighted by p_STOP
          NGOR,NRGOL -> ATTACH rules, weighted by p_GO_AT, one per argument
                        a in argnums for which (a, h, dir) is in p_ATTACH
        each of the last two in an adjacent and a non-adjacent variant.
        '''
        if LHS == ROOT:
            return [PCNF_DMV_Rule(LEFT, LHS, (SEAL, root_h), STOP,
                                  self.p_ROOT[root_h])
                    for root_h in set(argnums)]

        h = POS(LHS)
        s_h = seals(LHS)
        if s_h == GOR:
            return [] # only terminals from here on

        s_dir_type = { # seal of LHS
            RGOL: (RIGHT, 'STOP'), NGOR:  (RIGHT, 'ATTACH'),
            SEAL: (LEFT,  'STOP'), NRGOL: (LEFT,  'ATTACH') }
        dir_s_adj = { # seal of h_daughter
            RIGHT: [(GOR,  True), (NGOR,  False)],
            LEFT:  [(RGOL, True), (NRGOL, False)] }
        direction, rule_type = s_dir_type[s_h]
        h_daughters = dir_s_adj[direction]

        # Build only the kind of rule this seal calls for.
        if rule_type == 'STOP':
            return [PCNF_DMV_Rule(direction, LHS, (s, h), STOP,
                                  self.p_STOP[h, direction, adj])
                    for s, adj in h_daughters]
        return [PCNF_DMV_Rule(direction, LHS, (s, h), (SEAL, a),
                              self.p_GO_AT[a, h, direction, adj])
                for a in set(argnums) if (a, h, direction) in self.p_ATTACH
                for s, adj in h_daughters]
class PCNF_DMV_Rule(io.PCNF_Rule):
    '''A single PCNF rule in the PCFG, of the form
        LHS -> L R
    where LHS, L and R are 'nodes', eg. of the form (seals, head).

    Public members:
        prob

    Private members:
        __L, __R, __LHS, __dir

    Different rule-types have different probabilities associated with
    them, see formulas.pdf
    '''

    def __init__(self, dir, LHS, h_daughter, a_daughter, prob):
        '''Create the rule LHS -> L R with probability prob.

        dir says on which side of the head daughter the argument daughter
        sits: LEFT gives 'LHS -> a_daughter h_daughter', RIGHT gives
        'LHS -> h_daughter a_daughter'.

        Raises ValueError if dir is neither LEFT nor RIGHT, or if the
        seals of LHS, L or R are not in SEALS.
        '''
        self.__dir = dir
        if dir == LEFT:
            L, R = a_daughter, h_daughter
        elif dir == RIGHT:
            L, R = h_daughter, a_daughter
        else:
            # (dir,) rather than dir: a tuple passed by mistake must still
            # produce this message instead of a string-formatting TypeError.
            raise ValueError("dir must be LEFT or RIGHT, given: %s" % (dir,))

        for node in (LHS, L, R):
            node_seals = seals(node)
            if node_seals not in SEALS:
                raise ValueError("seals must be in %s; was given: %s"
                                 % (SEALS, node_seals))
        io.PCNF_Rule.__init__(self, LHS, L, R, prob)

    def seals(self):
        '''The seals of this rule's LHS node.'''
        return seals(self.LHS())

    def POS(self):
        '''The POS (head) of this rule's LHS node.'''
        return POS(self.LHS())

    def adj(self):
        "'undefined' for ROOT"
        # The head daughter is R for LEFT rules and L for RIGHT rules; the
        # rule counts as adjacent when that daughter carries the RGOL
        # (LEFT) resp. GOR (RIGHT) seal.
        if self.__dir == LEFT:
            return seals(self.R()) == RGOL
        return seals(self.L()) == GOR

    def __str__(self, tag=lambda x:x):
        '''Render as "LHS --> L R<TAB>[prob] adj|non_adj".

        tag is passed on to node_str() for each node.  ROOT rules get an
        empty adjacency label.
        '''
        if self.LHS() == ROOT:
            adj_str = ""
        elif self.adj():
            adj_str = "adj"
        else:
            adj_str = "non_adj"
        return "%s --> %s %s\t[%.2f] %s" % (node_str(self.LHS(), tag),
                                            node_str(self.L(), tag),
                                            node_str(self.R(), tag),
                                            self.prob,
                                            adj_str)
178 ###################################
179 # dmv-specific version of inner() #
180 ###################################
def inner(i, j, LHS, g, sent, ichart=None):
    ''' A PCNF rewrite of io.inner(), to take STOP rules into accord.

    Returns the inner probability computed for node LHS over the span
    (i, j) of sent under grammar g.

    ichart memoises results under the key (i, j, LHS).  Pass the same
    dict to later calls for the SAME sentence and grammar to reuse them.
    The keys carry neither sentence nor grammar, so a chart must not be
    shared between different sentences; when ichart is omitted a fresh
    chart is created for this call.
    '''
    if ichart is None:
        ichart = {}

    sent_nums = g.sent_nums(sent)

    def e(i, j, LHS, n_t):
        # n_t is the recursion depth; it only indents the debug output
        debug = 'INNER' in DEBUG
        tab = "\t" * n_t

        if (i, j, LHS) in ichart:
            if debug:
                print("%s*= %.4f in ichart: i:%d j:%d LHS:%s" % (tab, ichart[i, j, LHS], i, j, node_str(LHS)))
            return ichart[i, j, LHS]

        # if seals(LHS) == RGOL then we have to STOP first
        if i == j-1 and seals(LHS) == GOR:
            word = sent[i]
            if (LHS, word) in g.p_terminals:
                prob = g.p_terminals[LHS, word] # "b[LHS, O(s)]" in Lari&Young
            else:
                prob = 0.0
                if debug:
                    print("%sLACKING TERMINAL:" % tab)
            if debug:
                print("%s*= %.4f (terminal: %s -> %s)" % (tab, prob, node_str(LHS), word))
            # terminal look-ups are cheap and are not stored in ichart
            return prob

        p = 0.0 # "sum over j,k in a[LHS,j,k]"
        for rule in g.arg_rules(LHS, sent_nums):
            if debug:
                print("%ssumming rule %s i:%d j:%d" % (tab, rule, i, j))
            L = rule.L()
            R = rule.R()
            if (L == STOP) or (R == STOP):
                # a STOP rule is unary in effect: the non-STOP daughter
                # covers the very same span
                if L == STOP:
                    child = R
                else:
                    child = L
                p += rule.p() * e(i, j, child, n_t+1)
                if debug:
                    print("%sp= %.4f (STOP)" % (tab, p))
            elif j > i+1 and seals(LHS) != GOR:
                # not a STOP, attachment rewrite:
                for k in xtween(i, j): # i<k<j
                    p_L = e(i, k, L, n_t+1)
                    p_R = e(k, j, R, n_t+1)
                    p += rule.p() * p_L * p_R
                    if debug:
                        print("%sp= %.4f (ATTACH, p_L:%.4f, p_R:%.4f, rule:%.4f)" % (tab, p, p_L, p_R, rule.p()))
        ichart[i, j, LHS] = p
        return p
    # end of e-function

    inner_prob = e(i, j, LHS, 0)
    if 'INNER' in DEBUG:
        print(debug_ichart(g, sent, ichart))
    return inner_prob
# end of pcnf_dmv.inner(i, j, LHS, g, sent, ichart=None)
def debug_ichart(g,sent,ichart):
    '''Format the inner chart as text, one "node -> first ... last: prob"
    line per entry, framed by ---ICHART:--- markers.'''
    lines = ["---ICHART:---\n"]
    for key in ichart:
        i, j, LHS = key
        v = ichart[key]
        if type(v) == dict: # skip 'tree'
            continue
        first_word, last_word = sent[i], sent[j-1]
        lines.append("%s -> %s ... %s: \t%.4f\n" % (node_str(LHS,g.numtag),
                                                    first_word, last_word, v))
    lines.append("---ICHART:end---\n")
    return "".join(lines)
def inner_sent(g, sent, ichart=None):
    '''Inner probability of the whole sentence: ROOT over the span (0, len(sent)).

    ichart is the memo chart handed on to inner(); chart keys do not
    identify the sentence, so when no chart is given a fresh one is
    created for this call instead of sharing one between calls.
    '''
    if ichart is None:
        ichart = {}
    return inner(0, len(sent), ROOT, g, sent, ichart)
259 #######################################
260 # pcnf_dmv-specific version of outer() #
261 #######################################
def outer(i,j,w_node, g, sent, ichart=None, ochart=None):
    '''Outer probability of w_node over the span (i, j) of sent under g.

    ichart is the memo chart used for the inner probabilities, ochart
    the one for outer probabilities, keyed by (i, j, node).  Both are
    only valid for one sentence and grammar; when omitted, fresh charts
    are created for this call.

    Returns 0.0 straight away if the POS of w_node does not occur in
    sent[i:j].
    '''
    if ichart is None:
        ichart = {}
    if ochart is None:
        ochart = {}

    sent_nums = g.sent_nums(sent)
    if POS(w_node) not in sent_nums[i:j]:
        # sanity check, w must be able to dominate sent[i:j]
        return 0.0

    debug = 'OUTER' in DEBUG

    def e(i,j,LHS):
        # or we could just look it up in ichart, assuming ichart to be done
        return inner(i, j, LHS, g, sent, ichart)

    def f(i,j,w_node):
        if (i,j,w_node) in ochart:
            return ochart[(i, j, w_node)]
        if w_node == ROOT:
            if i == 0 and j == len(sent):
                return 1.0
            # ROOT may only be used on full sentence
            return 0.0 # but we may have non-ROOTs over full sentence too

        p = 0.0

        # The grammar look-ups get their own copy of sent_nums so that this
        # list stays the same size however deep the recursion goes.
        for rule in g.mothersL(w_node, list(sent_nums)): # rule.L() == w_node
            if debug: print("w_node:%s (L) ; %s"%(node_str(w_node),rule))
            if rule.R() == STOP:
                # STOP sister: the mother covers the same span
                p0 = f(i,j,rule.LHS()) * rule.p()
                if debug: print(p0)
                p += p0
            else:
                for k in xgt(j,sent): # i<j<k
                    p0 = f(i,k, rule.LHS() ) * rule.p() * e(j,k, rule.R() )
                    if debug: print(p0)
                    p += p0

        for rule in g.mothersR(w_node, list(sent_nums)): # rule.R() == w_node
            if debug: print("w_node:%s (R) ; %s"%(node_str(w_node),rule))
            if rule.L() == STOP:
                p0 = f(i,j,rule.LHS()) * rule.p()
                if debug: print(p0)
                p += p0
            else:
                for k in xlt(i): # k<i<j
                    p0 = e(k,i, rule.L() ) * rule.p() * f(k,j, rule.LHS() )
                    if debug: print(p0)
                    p += p0

        ochart[i,j,w_node] = p
        return p

    return f(i,j,w_node)
# end outer(i,j,w_node, g,sent, ichart,ochart)
316 ##############################
317 # reestimation, todo: #
318 ##############################
def reest_zeros(rules):
    '''Return the zero-initialised count table used by reest_freq().

    It holds one shared ('den', ROOT) entry plus a ('num', LHS, L, R)
    and a ('den', LHS, L, R) entry for every rule in rules.
    '''
    counts = {('den', ROOT): 0.0}
    for rule in rules:
        for part in ('num', 'den'):
            key = (part, rule.LHS(), rule.L(), rule.R())
            counts[key] = 0.0
    return counts
def reest_freq(g, corpus):
    ''' P_STOP(-STOP|...) = 1 - P_STOP(STOP|...)

    Collect the numerator/denominator counts for re-estimating the
    rules of g from corpus (an iterable of sentences).

    Returns the dict created by reest_zeros(): ('num', LHS, L, R) and
    ('den', LHS, L, R) for every rule, plus a single ('den', ROOT) entry
    shared by all ROOT rules.  reestimate() divides num by den.
    '''
    f = reest_zeros(g.all_rules())
    ichart = {}
    ochart = {}

    # The helpers below read p_sent, ichart and ochart from this scope;
    # all three are rebound for every sentence in the loop further down.
    p_sent = None # 50 % speed increase on storing this locally

    def c_g(i,j,LHS,sent):
        # inner * outer of LHS over (i, j), normalised by the sentence
        # probability; 0.0 if the sentence itself has probability 0.0
        if p_sent == 0.0:
            return 0.0
        return e_g(i,j,LHS,sent) * f_g(i,j,LHS,sent) / p_sent

    def w1_g(i,j,rule,sent): # unary (stop) rules, LHS -> child_node
        # the non-STOP daughter is the only real child
        if rule.L() == STOP: child = rule.R()
        elif rule.R() == STOP: child = rule.L()
        else: raise ValueError, "expected a stop rule: %s"%(rule,)

        if p_sent == 0.0: return 0.0

        p_out = f_g(i,j,rule.LHS(),sent)
        if p_out == 0.0: return 0.0 # skip the inner computation

        return rule.p() * e_g(i,j,child,sent) * p_out / p_sent

    def w_g(i,j,rule,sent):
        # binary (attachment) rules; a span of width one has no split point
        if p_sent == 0.0 or i+1 == j: return 0.0

        p_out = f_g(i,j,rule.LHS(),sent)
        if p_out == 0.0: return 0.0

        # sum over the split points k yielded by xtween(i,j)
        p = 0.0
        for k in xtween(i,j):
            p += rule.p() * e_g(i,k,rule.L(),sent) * e_g(k,j,rule.R(),sent) * p_out
        return p / p_sent

    def f_g(i,j,LHS,sent):
        # outer probability, served from ochart when already there
        if (i,j,LHS) in ochart:
            # print ".",
            return ochart[i,j,LHS]
        else:
            return outer(i,j,LHS,g,sent,ichart,ochart)

    def e_g(i,j,LHS,sent):
        # inner probability, served from ichart when already there
        if (i,j,LHS) in ichart:
            # print ".",
            return ichart[i,j,LHS]
        else:
            return inner(i,j,LHS,g,sent,ichart)

    for s_num,sent in enumerate(corpus):
        # progress output, every fifth sentence, all on one line
        if s_num%5==0: print "s.num %d"%s_num,
        if 'REEST' in DEBUG: print sent
        # charts are keyed by (i, j, node) only, so start afresh per sentence
        ichart = {}
        ochart = {}
        # since we keep re-using p_sent, it seems better to have
        # sentences as the outer loop; o/w we'd have to keep every chart
        p_sent = inner_sent(g, sent, ichart)

        sent_nums = g.sent_nums(sent)
        sent_rules = g.sent_rules(sent_nums)
        for r in sent_rules:
            LHS, L, R = r.LHS(), r.L(), r.R()
            if 'REEST' in DEBUG: print r
            if LHS == ROOT:
                # NOTE(review): ('den', ROOT) grows by p_sent once per ROOT
                # rule visited here, not once per sentence, and this
                # numerator is not divided by p_sent the way w_g/w1_g
                # are -- confirm that this is intended.
                f['num',LHS,L,R] += r.p() * e_g(0, len(sent), R, sent)
                f['den',ROOT] += p_sent
                continue # !!! o/w we add wrong values to it below
            # pick the count function matching the rule type
            if L == STOP or R == STOP:
                w = w1_g
            else:
                w = w_g
            # accumulate over the spans (i, j) produced by xlt / xgt
            for i in xlt(len(sent)):
                for j in xgt(i, sent):
                    f['num',LHS,L,R] += w(i,j, r, sent)
                    f['den',LHS,L,R] += c_g(i,j, LHS, sent) # v_q
    print "" # terminate the progress line
    return f
def reestimate(g, corpus):
    '''Run one re-estimation pass over corpus and return the updated g.

    Each rule's prob becomes num/den from reest_freq(); ROOT rules all
    share the ('den', ROOT) denominator.  If the denominator is not
    positive, prob is set to that denominator value itself.
    '''
    f = reest_freq(g, corpus)
    print("applying f to rules")
    for rule in g.all_rules():
        lhs = rule.LHS()
        if lhs == ROOT:
            denominator = f['den', ROOT]
        else:
            denominator = f['den', lhs, rule.L(), rule.R()]

        if denominator > 0.0:
            rule.prob = f['num', lhs, rule.L(), rule.R()] / denominator
        else:
            rule.prob = denominator
    return g
418 ##############################
419 # Testing functions: #
420 ##############################
def testgrammar():
    '''Return pcnf_harmonic.initialize() applied to loc_h_dmv's testcorpus,
    re-loading pcnf_harmonic first.'''
    # make sure we use the same data:
    from loc_h_dmv import testcorpus as corpus

    import pcnf_harmonic
    # reload() returns the module object it has just re-executed
    harmonic = reload(pcnf_harmonic)
    return harmonic.initialize(corpus)
def testreestimation():
    '''Re-estimate the test grammar once on the first four sentences of
    loc_h_dmv's testcorpus and return it.'''
    from loc_h_dmv import testcorpus
    return reestimate(testgrammar(), testcorpus[0:4])
def testgrammar_a():                            # Non, Adj
    # Hand-written two-word ('h', 'a') test grammar.
    #
    # NOTE(review): this fixture looks stale and probably cannot run as is:
    #  * the rules are written as (LHS, L, R, p_non, p_adj), whereas
    #    PCNF_DMV_Rule.__init__ takes (dir, LHS, h_daughter, a_daughter, prob)
    #    and raises unless dir is LEFT or RIGHT;
    #  * p_rules is built but never handed to the grammar;
    #  * the grammar is constructed with None for p_ROOT, p_STOP and p_ATTACH,
    #    which the constructor passes straight to make_GO_AT().
    # Nothing visible in this file calls it -- confirm before relying on it.
    _h_ = PCNF_DMV_Rule((SEAL,0), STOP,    ( RGOL,0), 1.0, 1.0) # LSTOP
    h_S = PCNF_DMV_Rule(( RGOL,0),(GOR,0),  STOP,     0.4, 0.3) # RSTOP
    h_A = PCNF_DMV_Rule(( RGOL,0),(SEAL,0),( RGOL,0),0.2, 0.1) # Lattach
    h_Aa= PCNF_DMV_Rule(( RGOL,0),(SEAL,1),( RGOL,0),0.4, 0.6) # Lattach to a
    h   = PCNF_DMV_Rule((GOR,0),(GOR,0),(SEAL,0),    1.0, 1.0) # Rattach
    ha  = PCNF_DMV_Rule((GOR,0),(GOR,0),(SEAL,1),    1.0, 1.0) # Rattach to a
    rh  = PCNF_DMV_Rule(   ROOT,  STOP,    (SEAL,0), 0.9, 0.9) # ROOT

    _a_ = PCNF_DMV_Rule((SEAL,1), STOP,    ( RGOL,1), 1.0, 1.0) # LSTOP
    a_S = PCNF_DMV_Rule(( RGOL,1),(GOR,1),  STOP,     0.4, 0.3) # RSTOP
    a_A = PCNF_DMV_Rule(( RGOL,1),(SEAL,1),( RGOL,1),0.4, 0.6) # Lattach
    a_Ah= PCNF_DMV_Rule(( RGOL,1),(SEAL,0),( RGOL,1),0.2, 0.1) # Lattach to h
    a   = PCNF_DMV_Rule((GOR,1),(GOR,1),(SEAL,1),    1.0, 1.0) # Rattach
    ah  = PCNF_DMV_Rule((GOR,1),(GOR,1),(SEAL,0),    1.0, 1.0) # Rattach to h
    ra  = PCNF_DMV_Rule(   ROOT,  STOP,    (SEAL,1), 0.1, 0.1) # ROOT

    p_rules = [ h_Aa, ha, a_Ah, ah, ra, _a_, a_S, a_A, a, rh, _h_, h_S, h_A, h ]

    # terminal probabilities: tag 0 emits 'h', tag 1 emits 'a'
    b = {}
    b[(GOR, 0), 'h'] = 1.0
    b[(GOR, 1), 'a'] = 1.0

    return PCNF_DMV_Grammar({0:'h',1:'a'}, {'h':0,'a':1},
                            None,None,None,b)
def testgrammar_h():
    '''Build the one-tag ('h') grammar used by the regression tests.

    The grammar is first constructed from consistent tables; p_GO_AT is
    then overwritten by hand and the rules are rebuilt from it.
    '''
    h = 0

    root_probs = {h: 1.0}

    stop_probs = {}
    for direction, adjacency, prob in [(LEFT,  NON, 1.0),
                                       (LEFT,  ADJ, 1.0),
                                       (RIGHT, NON, 0.4),
                                       (RIGHT, ADJ, 0.3)]:
        stop_probs[h, direction, adjacency] = prob

    attach_probs = {}
    for direction in (LEFT, RIGHT):
        attach_probs[h, h, direction] = 1.0 # not used

    terminal_probs = {((GOR, 0), 'h'): 1.0}

    g = PCNF_DMV_Grammar({h:'h'}, {'h':h},
                         root_probs, stop_probs, attach_probs, terminal_probs)

    # these probabilities are impossible
    # so add them manually...
    for direction, adjacency, prob in [(LEFT,  NON, 0.6),
                                       (LEFT,  ADJ, 0.7),
                                       (RIGHT, NON, 1.0),
                                       (RIGHT, ADJ, 1.0)]:
        g.p_GO_AT[h, h, direction, adjacency] = prob
    g.make_all_rules()
    return g
def testreestimation_h():
    '''Turn on 'REEST' debugging and re-estimate testgrammar_h() on the
    single sentence "h h h".'''
    DEBUG.add('REEST')
    grammar = testgrammar_h()
    corpus = [['h', 'h', 'h']]
    return reestimate(grammar, corpus)
def regression_tests():
    '''Check inner/outer probabilities of testgrammar_h() against known
    values, each on a freshly built grammar.'''
    def h_sentence(length):
        # ['h', 'h', ...], a new list for every test
        return ['h'] * length

    # = .120 + .063, since we have no loc_h
    got = "%.4f" % inner(0, 2, (SEAL,0), testgrammar_h(), h_sentence(2), {})
    test("0.1830", got)

    # = .0498 + .1092 +.0252
    got = "%.4f" % inner(0, 3, (SEAL,0), testgrammar_h(), h_sentence(3), {})
    test("0.1842", got)

    got = "%.4f" % inner_sent(testgrammar_h(), h_sentence(3))
    test("0.1842", got)

    got = "%.2f" % outer(1, 3, ( RGOL,0), testgrammar_h(), h_sentence(3), {}, {})
    test("0.61", got)

    got = "%.2f" % outer(1, 3, (NRGOL,0), testgrammar_h(), h_sentence(3), {}, {})
    test("0.58", got)
if __name__ == "__main__":
    # Script entry point: switch all debug output off, then run the
    # regression tests.
    DEBUG.clear()

    # Disabled profiling / timing helpers:
    #     import profile
    #     profile.run('testreestimation()')
    #
    #     DEBUG.add('reest_attach')
    #     import timeit
    #     print timeit.Timer("pcnf_dmv.testreestimation_h()",'''import pcnf_dmv
    # reload(pcnf_dmv)''').timeit(1)

    regression_tests()
    # Disabled: print the harmonic test grammar
    #    g = testgrammar()
    #    print g