outer seems to be working now
[dmvccm.git] / src / dmv.py
blob321286d6d9265653b8f2efb27c966bb5f0b245fd
1 #### changes by KBU:
2 # 2008-06-11
3 # - moved prune() to junk.py, now using outer() instead. outer() is
4 # written, but needs testing.
6 # 2008-06-09
7 # - prune() finished, seems to be working.
8 # - started on implementing the other reestimation formulas, in
9 # reestimate()
11 # 2008-06-04
12 # - moved initialization to harmonic.py
14 # 2008-06-03
15 # - fixed a number of little bugs in initialization, where certain
16 # rules were simply not created, or created "backwards"
17 # - dmv.inner() should Work now...
19 # 2008-06-01
20 # - finished typing in dmv.inner(), still have to test and debug
21 # it. The ichart is now four times as big since for any rule we may
22 # have attachments to either the left or the right below, which
23 # upper rules depend on, for selecting probN or probA
25 # 2008-05-30
26 # - copied inner() into this file, to make the very dmv-specific
27 # adjacency stuff work (have to factor that out later on, when it
28 # works).
30 # 2008-05-29
31 # - init_normalize is done, it creates p_STOP, p_ROOT and p_CHOOSE,
32 # and also adds the relevant probabilities to p_rules in a grammar.
33 # Still, each individual rule has to store both adjacent and non_adj
34 # probabilities, and inner() should be able to send some parameter
35 # which lets the rule choose... hopefully... Is this possible to do
36 # top-down even? when the sentence could be all the same words?
37 # todo: extensive testing of identical words in sentences!
38 # - frequencies (only used in initialization) are stored as strings,
39 # but in the rules and p_STOP etc, there are only numbers.
41 # 2008-05-28
42 # - more work on initialization (init_freq and init_normalize),
43 # getting closer to probabilities now.
45 # 2008-05-27
46 # - started on initialization. So far, I have frequencies for
47 # everything, very harmonic. Still need to make these into 1-summing
48 # probabilities
50 # 2008-05-24
51 # - prettier printout for DMV_Rule
52 # - DMV_Rule changed a bit. head, L and R are now all pairs of the
53 # form (bars, head).
54 # - Started on P_STOP, a bit less pseudo now..
58 #import numpy # numpy provides Fast Arrays, for future optimization
59 import io
# non-tweakable/constant "lookup" globals

# Bar-levels a node may carry (cf. DMV_Rule.bar_str: RBAR renders as
# "h_", LRBAR as "_h_", NOBAR as plain "h").
BARS = [0,1,2]
RBAR = 1
LRBAR = 2
NOBAR = 0
# Sentinel nodes with negative head numbers, so they can never collide
# with real tag numbers (DMV_Grammar.sent_rules appends head(STOP) to
# the allowed head numbers for exactly this reason).
ROOT = (LRBAR, -1)
STOP = (NOBAR, -2)
if __name__ == "__main__":
    # Banner for the interactive test run; the actual tests are invoked
    # from the second __main__ guard at the bottom of the module.
    # (Python 2 print statement.)
    print "DMV module tests:"
def node(bars, head):
    '''Build a node pair. Only here as documentation: the nodes that
    make up LHS, L and R in each DMV_Rule are simply (bars, head)
    tuples.'''
    pair = (bars, head)
    return pair
def bars(node):
    '''Return the bar-level component (first element) of a node pair.'''
    bar_level = node[0]
    return bar_level
def head(node):
    '''Return the head component (second element) of a node pair.'''
    head_num = node[1]
    return head_num
class DMV_Grammar(io.Grammar):
    '''The DMV-PCFG.

    Public members:
    p_STOP, p_ROOT, p_CHOOSE, p_terminals
    These are changed in the Maximization step, then used to set the
    new probabilities of each DMV_Rule.

    Todo: make p_terminals private? (But it has to be changeable in
    the maximization step due to the short-cutting rules... could of
    course make a DMV_Grammar function to update the short-cut rules...)

    __p_rules is private, but we can still say stuff like:
    for r in g.all_rules():
        r.probN = newProbN

    What other representations do we need? (P_STOP formula uses
    deps_D(h,l/r) at least)'''

    def __str__(self):
        # join() instead of quadratic += concatenation; the local no
        # longer shadows the builtin 'str' either.
        lines = ["%s\n" % r.__str__(self.numtag) for r in self.all_rules()]
        return "".join(lines)

    def h_rules(self, h):
        '''All rules whose LHS head is h.'''
        return [r for r in self.all_rules() if r.head() == h]

    def mothersL(self, Node, sent_nums):
        '''Rules that have Node as their left child.
        NOTE(review): sent_nums is accepted but unused -- confirm
        whether sentence-filtering (as in sent_rules) was intended.'''
        return [r for r in self.all_rules() if r.L() == Node]

    def mothersR(self, Node, sent_nums):
        '''Rules that have Node as their right child.
        NOTE(review): sent_nums unused, as in mothersL.'''
        return [r for r in self.all_rules() if r.R() == Node]

    def rules(self, LHS):
        '''All rules rewriting LHS.'''
        return [r for r in self.all_rules() if r.LHS() == LHS]

    def sent_rules(self, LHS, sent_nums):
        '''Rules rewriting LHS whose children's heads all occur in
        sent_nums (STOP is always allowed).

        Used in dmv.inner. Todo: this takes a _lot_ of time, it
        seems. Could use some more space and cache some of this
        somehow perhaps?'''
        # We don't want to rule out STOPs!
        nums = sent_nums + [ head(STOP) ]
        return [r for r in self.all_rules() if r.LHS() == LHS
                and head(r.L()) in nums and head(r.R()) in nums]

    def deps_L(self, head): # todo: do I use this at all?
        # bugfix: the old comprehension referenced an unbound name 'a'
        # (no 'for a in ...' clause) and raised NameError when called;
        # this returns the left children of rules headed by 'head',
        # which is what the name and the old condition pointed at.
        return [r.L() for r in self.all_rules() if r.head() == head]

    def deps_R(self, head):
        # bugfix: same unbound-name problem as deps_L; see above.
        return [r.R() for r in self.all_rules() if r.head() == head]

    def __init__(self, p_rules, p_terminals, p_STOP, p_CHOOSE, p_ROOT, numtag, tagnum):
        io.Grammar.__init__(self, p_rules, p_terminals, numtag, tagnum)
        self.p_STOP = p_STOP
        self.p_CHOOSE = p_CHOOSE
        self.p_ROOT = p_ROOT
        # all known head (tag) numbers; list(keys()) works in both
        # Python 2 and 3, unlike the old iterkeys()
        self.head_nums = list(numtag.keys())
class DMV_Rule(io.CNF_Rule):
    '''A single CNF rule in the PCFG, of the form
    LHS -> L R
    where LHS, L and R are 'nodes', eg. of the form (bars, head).

    Public members:
    probN, probA

    Private members:
    __L, __R, __LHS

    Different rule-types have different probabilities associated with
    them:

    _h_ -> STOP h_   P( STOP|h,L, adj)
    _h_ -> STOP h_   P( STOP|h,L,non_adj)
    h_ -> h STOP     P( STOP|h,R, adj)
    h_ -> h STOP     P( STOP|h,R,non_adj)
    h_ -> _a_ h_     P(-STOP|h,L, adj) * P(a|h,L)
    h_ -> _a_ h_     P(-STOP|h,L,non_adj) * P(a|h,L)
    h -> h _a_       P(-STOP|h,R, adj) * P(a|h,R)
    h -> h _a_       P(-STOP|h,R,non_adj) * P(a|h,R)
    '''
    # (bugfix: the closing quotes of the class docstring were missing in
    # this revision, which pulled the method definitions below into the
    # string and broke the module.)

    def p(self, adj, *arg):
        '''probA if adjacent, probN otherwise. Extra positional
        arguments are accepted and ignored.'''
        if adj:
            return self.probA
        else:
            return self.probN

    def p_STOP(self, s, t, loc_h):
        '''Returns the correct probability, adjacent if we're rewriting from
        the (either left or right) end of the fragment. '''
        if self.L() == STOP:
            return self.p(s == loc_h)
        elif self.R() == STOP:
            if not loc_h == s:
                if 'TODO' in io.DEBUG:
                    # parenthesized single-argument print: identical
                    # output in Python 2, and also valid Python 3
                    print("(%s given loc_h:%d but s:%d. Todo: optimize away!)" % (self, loc_h, s))
                return 0.0
            else:
                return self.p(t == loc_h)
        # NOTE(review): implicitly returns None when neither child is
        # STOP -- presumably only ever called on STOP rules; confirm.

    def p_ATTACH(self, r, loc_h, s=None):
        '''Returns the correct probability, adjacent if we haven't attached
        anything before.'''
        if self.LHS() == self.L():
            if not loc_h == s:
                if 'TODO' in io.DEBUG:
                    print("(%s given loc_h (loc_L):%d but s:%d. Todo: optimize away!)" % (self, loc_h, s))
                return 0.0
            else:
                return self.p(r == loc_h)
        elif self.LHS() == self.R():
            return self.p(r+1 == loc_h)

    def bars(self):
        '''Bar-level of this rule's LHS node.'''
        return bars(self.LHS())

    def head(self):
        '''Head (tag number) of this rule's LHS node.'''
        return head(self.LHS())

    def __init__(self, LHS, L, R, probN, probA):
        # validate every node's bar-level before constructing
        for b_h in [LHS, L, R]:
            if bars(b_h) not in BARS:
                raise ValueError("bars must be in %s; was given: %s"
                                 % (BARS, bars(b_h)))
        io.CNF_Rule.__init__(self, LHS, L, R, probN)
        self.probA = probA # adjacent
        self.probN = probN # non_adj

    @classmethod # so we can call DMV_Rule.bar_str(b_h)
    def bar_str(cls, b_h, tag=lambda x:x):
        '''Human-readable rendering of a node: ROOT, STOP, " h_ "
        (RBAR), "_h_ " (LRBAR) or plain " h " (NOBAR). tag maps tag
        numbers back to tag strings.'''
        if(b_h == ROOT):
            return 'ROOT'
        elif(b_h == STOP):
            return 'STOP'
        elif(bars(b_h) == RBAR):
            return " %s_ " % tag(head(b_h))
        elif(bars(b_h) == LRBAR):
            return "_%s_ " % tag(head(b_h))
        else:
            return " %s " % tag(head(b_h))

    def __str__(self, tag=lambda x:x):
        return "%s-->%s %s\t[N %.2f] [A %.2f]" % (self.bar_str(self.LHS(), tag),
                                                  self.bar_str(self.L(), tag),
                                                  self.bar_str(self.R(), tag),
                                                  self.probN,
                                                  self.probA)
243 ###################################
244 # dmv-specific version of inner() #
245 ###################################
def locs(h, sent, s=0, t=None, remove=None):
    '''Return the locations of h in sent, or some fragment of sent (in the
    latter case we make sure to offset the locations correctly so that
    for any x in the returned list, sent[x]==h).

    t is inclusive, to match the way indices work with inner()
    (although python list-splicing has "exclusive" end indices).

    remove, if given, is one position to exclude from the result.'''
    if t is None:  # idiom fix: identity comparison for None, not ==
        t = len(sent)-1
    return [i+s for i,w in enumerate(sent[s:t+1])
            if w == h and not (i+s) == remove]
def inner(s, t, LHS, loc_h, g, sent, ichart):
    ''' A rewrite of io.inner(), to take adjacency into accord.

    The ichart is now of this form:
    ichart[s,t,LHS, loc_h]

    loc_h gives adjacency (along with r and location of other child
    for attachment rules), and is needed in P_STOP reestimation.

    Todo: if possible, refactor (move dmv-specific stuff back into
    dmv, so this is "general" enough to be in io.py)

    Returns the inner (inside) probability of LHS spanning the
    inclusive fragment sent[s..t] with its head at position loc_h;
    fills ichart as a side effect, plus a bookkeeping entry
    ichart['tree'] mapping chart entries to the child entries used.
    '''
    def O(s):
        # terminal observation: the word/tag at position s
        return sent[s]

    sent_nums = [g.tagnum(tag) for tag in sent]
    tree = {}  # chart entry -> set of child chart entries that contributed

    def e(s,t,LHS, loc_h, n_t):
        # n_t is the recursion depth, used only for debug indentation
        def tab():
            "Tabs for debug output"
            return "\t"*n_t

        if (s, t, LHS, loc_h) in ichart:
            # memoized
            if 'INNER' in io.DEBUG:
                print "%s*= %.4f in ichart: s:%d t:%d LHS:%s loc:%d" % (tab(),ichart[s, t, LHS, loc_h], s, t,
                                                                        DMV_Rule.bar_str(LHS), loc_h)
            return ichart[s, t, LHS, loc_h]
        else:
            if s == t:
                # single-word fragment: only valid if the head is here
                if not loc_h == s:
                    if 'INNER' in io.DEBUG:
                        print "%s*= 0.0 (wrong loc_h)" % tab()
                    return 0.0
                elif (LHS, O(s)) in g.p_terminals:
                    prob = g.p_terminals[LHS, O(s)] # "b[LHS, O(s)]" in Lari&Young
                else:
                    # todo: assuming this is how to deal w/lacking
                    # rules, since we add prob.s, and 0 is identity
                    prob = 0.0
                    if 'INNER' in io.DEBUG:
                        print "%sLACKING TERMINAL:" % tab()
                # todo: add to ichart perhaps? Although, it _is_ simple lookup..
                if 'INNER' in io.DEBUG:
                    print "%s*= %.4f (terminal: %s -> %s_%d)" % (tab(),prob, DMV_Rule.bar_str(LHS), O(s), loc_h)
                return prob
            else:
                p = 0.0 # "sum over j,k in a[LHS,j,k]"
                for rule in g.sent_rules(LHS, sent_nums):
                    if 'INNER' in io.DEBUG:
                        print "%ssumming rule %s s:%d t:%d loc:%d" % (tab(),rule,s,t,loc_h)
                    L = rule.L()
                    R = rule.R()
                    if (s,t,LHS,loc_h) not in tree:
                        tree[s,t,LHS,loc_h] = set()
                    # NOTE(review): these cut-offs 'break' out of the
                    # whole rule loop rather than 'continue' to the next
                    # rule; correct only if no later rule in
                    # sent_rules() order could still contribute --
                    # confirm.
                    if loc_h == t and rule.LHS() == L:
                        break # 25% faster with these cut-offs
                    if loc_h == s and rule.LHS() == R:
                        break
                    # if it's a STOP rule, rewrite for the same range:
                    if (L == STOP) or (R == STOP):
                        if L == STOP:
                            pLR = e(s, t, R, loc_h, n_t+1)
                            if pLR > 0.0:
                                tree[s,t,LHS,loc_h].add((s,t,R,loc_h))
                        elif R == STOP:
                            pLR = e(s, t, L, loc_h, n_t+1)
                            if pLR > 0.0:
                                tree[s,t,LHS,loc_h].add((s,t,L,loc_h))
                        p += rule.p_STOP(s, t, loc_h) * pLR
                        if 'INNER' in io.DEBUG:
                            print "%sp= %.4f (STOP)" % (tab(), p)

                    else: # not a STOP, an attachment rewrite:
                        rp_ATTACH = rule.p_ATTACH # todo: profile/speedtest
                        # split the fragment at every possible point r
                        for r in xrange(s, t):
                            p_h = rp_ATTACH(r, loc_h, s=s)
                            if rule.LHS() == L:
                                # head is the left child; argument on the right
                                locs_L = [loc_h]
                                locs_R = locs(head(R), sent_nums, r+1, t, loc_h)
                            elif rule.LHS() == R:
                                # head is the right child; argument on the left
                                locs_L = locs(head(L), sent_nums, s, r, loc_h)
                                locs_R = [loc_h]
                            for loc_L in locs_L:
                                pL = e(s, r, L, loc_L, n_t+1)
                                if pL > 0.0:
                                    for loc_R in locs_R:
                                        pR = e(r+1, t, R, loc_R, n_t+1)
                                        if pR > 0.0: # and pL > 0.0
                                            tree[s,t,LHS,loc_h].add(( s ,r,L,loc_L))
                                            tree[s,t,LHS,loc_h].add((r+1,t,R,loc_R))
                                            p += pL * p_h * pR
                        if 'INNER' in io.DEBUG:
                            print "%sp= %.4f (ATTACH)" % (tab(), p)
                ichart[s, t, LHS, loc_h] = p
                return p
    # end of e-function

    inner_prob = e(s,t,LHS,loc_h, 0)
    # NOTE(review): this stores a plain-string key alongside the
    # (s,t,LHS,loc_h) tuple keys -- anything iterating the chart must
    # skip it (cf. debug_ichart).
    ichart['tree'] = tree
    if 'INNER' in io.DEBUG:
        print debug_ichart(g,sent,ichart)
    return inner_prob
# end of dmv.inner(s, t, LHS, loc_h, g, sent, ichart)
def debug_ichart(g,sent,ichart):
    '''Pretty-print the inner chart for debugging.

    Skips the special 'tree' bookkeeping entry that inner() stores in
    the same dict: it is a string key, not an (s,t,LHS,loc_h) tuple,
    and previously crashed the unpacking/formatting below.'''
    out = "---ICHART:---\n"   # renamed from 'str': don't shadow the builtin
    for key, v in ichart.items():   # .items() works in Python 2 and 3
        if key == 'tree':  # bugfix: non-chart entry, cannot be unpacked
            continue
        (s,t,LHS,loc_h) = key
        out += "%s -> %s_%d ... %s_%d (loc_h:%s):\t%.4f\n" % (DMV_Rule.bar_str(LHS,g.numtag),
                                                              sent[s], s, sent[s], t, loc_h, v)
    out += "---ICHART:end---\n"
    return out
def inner_sent(loc_h, g, sent, ichart):
    '''Inner probability of the full sentence rewritten from ROOT, with
    the sentence head at position loc_h.'''
    last = len(sent) - 1
    return inner(0, last, ROOT, loc_h, g, sent, ichart)
def c(s,t,LHS,loc_h):
    # NOTE(review): placeholder -- inner() and outer() are called here
    # without their required arguments, so this raises TypeError if it
    # is ever invoked; presumably the intended expected-count formula
    # is inner(...)*outer(...) normalized by the sentence probability.
    return inner() * outer() # divided by sentence probability? todo
386 ###################################
387 # dmv-specific version of outer() #
388 ###################################
def outer(s,t,Node,loc_N, g, sent, ochart, ichart):
    ''' http://www.student.uib.no/~kun041/dmvccm/DMVCCM.html#outer

    Outside probability of Node over the inclusive fragment
    sent[s..t], with the location of Node's head at loc_N. Results are
    memoized in ochart; inner probabilities are computed on demand via
    inner()/ichart.

    (bugfix: restored the docstring's closing quotes, which were
    missing in this revision and swallowed the code below into the
    string.)
    '''
    def e(s,t,LHS,loc_h):
        # or we could just look it up in ichart, assuming ichart to be done
        return inner(s, t, LHS, loc_h, g, sent, ichart)

    T = len(sent)-1
    # NOTE(review): these are the positions 0..len(sent)-1, not tag
    # numbers as in inner() (which uses g.tagnum) -- confirm that the
    # locs() calls below really mean to search positions, not tags.
    sent_nums = [i for i,w in enumerate(sent)]

    def f(s,t,Node,loc_N):
        # bugfix: the memo test used the 3-tuple (s,t,Node) although
        # entries are stored under 4-tuples, so the cache never hit
        # (and the 4-tuple lookup would have raised KeyError if it had).
        if (s,t,Node,loc_N) in ochart:
            return ochart[(s, t, Node,loc_N)]
        if Node == ROOT:
            if s == 0 and t == T:
                return 1.0
            else: # ROOT may only be used on full sentence
                return 0.0 # but we may have non-ROOTs over full sentence too
        p = 0.0

        for mom in g.mothersL(Node, sent_nums): # mom.L() == Node
            R = mom.R()
            mLHS = mom.LHS()
            if R == STOP:
                p += f(s,t,mLHS,loc_N) * mom.p_STOP(s,t,loc_N) # == loc_m
            else:
                if bars(mLHS) == RBAR: # left attachment, head(mLHS) == head(L)
                    for r in xrange(t+1,T+1): # t+1 to lasT
                        for loc_m in locs(head(mLHS),sent_nums,t+1,r):
                            p_m = mom.p(t+1 == loc_m)
                            p += f(s,r,mLHS,loc_m) * p_m * e(t+1,r,R,loc_m)
                else: # right attachment, head(mLHS) == head(Node)
                    loc_m = loc_N
                    p_m = mom.p( t == loc_m)
                    for r in xrange(t+1,T+1): # t+1 to lasT
                        for loc_R in locs(head(mLHS),sent_nums,t+1,r):
                            p += f(s,r,mLHS,loc_m) * p_m * e(t+1,r,R,loc_R)

        for mom in g.mothersR(Node, sent_nums):
            L = mom.L()
            mLHS = mom.LHS()
            if L == STOP:
                p += f(s,t,mLHS,loc_N) * mom.p_STOP(s,t,loc_N) # == loc_m
            else:
                if bars(mLHS) == RBAR: # left attachment, head(mLHS) == head(Node)
                    loc_m = loc_N
                    p_m = mom.p( s == loc_m)
                    for r in xrange(0,s): # first to s-1
                        for loc_L in locs(head(L),sent_nums,r,s-1):
                            p += e(r,s-1,L, loc_L) * p_m * f(r,t,mLHS,loc_m)
                else: # right attachment, head(mLHS) == head(R)
                    for r in xrange(0,s): # first to s-1
                        for loc_m in locs(head(mLHS),sent_nums,r,s-1):
                            p_m = mom.p(s-1 == loc_m)
                            p += e(r,s-1,L, loc_m) * p_m * f(r,t,mLHS,loc_m)

        ochart[s,t,Node,loc_N] = p
        return p

    return f(s,t,Node,loc_N)
# end outer(s,t,Node,loc_N, g,sent, ochart,ichart)
453 ##############################
454 # reestimation, todo: #
455 ##############################
def reestimate_zeros(h_nums):
    '''Fresh count table for reestimation: zeroed numerator and
    denominator for the left non-adjacent STOP counts, one pair per
    head number in h_nums.'''
    # todo: p_ROOT, p_CHOOSE, p_terminals
    counts = {}
    for kind in ('num', 'den'):
        for h in h_nums:
            counts[('LNSTOP', kind, h)] = 0.0
    return counts
def reestimate(g, corpus):
    '''current todo.
    P_STOP(-STOP|...) = 1 - P_STOP(STOP|...)

    Accumulates inner-chart probability mass into numerator/denominator
    counts over the corpus, then updates -- in place -- probN of the
    left-STOP rules in g. No other probabilities are reestimated yet.
    NOTE(review): only inner() is used for the counts (no outer()), so
    this is not yet the full expected-count reestimation formula.'''
    f = reestimate_zeros(g.head_nums)
    for sent in corpus:
        sent_nums = [g.tagnum(w) for i,w in enumerate(sent)]
        ichart = {}
        # fill the inner chart once per possible head location
        for loc_h,h in zip(xrange(len(sent)), sent_nums):
            inner_sent(loc_h, g, sent, ichart)
        for loc_h,h in zip(xrange(len(sent)), sent_nums):
            # left non-adjacent STOP: fragments starting strictly left of h
            for s in xrange(loc_h): # s<loc(h), range gives strictly less
                for t in xrange(loc_h, len(sent)):
                    if 'reest' in io.DEBUG:
                        print "s:%s t:%s loc:%d"%(s,t,loc_h)
                    # numerator: mass of fully-stopped _h_ entries
                    if (s, t, (LRBAR,h), loc_h) in ichart:
                        f[('LNSTOP','num',h)] += ichart[s, t, (LRBAR,h), loc_h]
                        if 'reest' in io.DEBUG:
                            print "num+=%s"%ichart[s, t, (LRBAR,h), loc_h]
                    # denominator: mass of h_ entries (not yet left-stopped)
                    if (s, t, (RBAR,h), loc_h) in ichart:
                        f[('LNSTOP','den',h)] += ichart[s, t, (RBAR,h), loc_h]
                        if 'reest' in io.DEBUG:
                            print "den+=%s"%ichart[s, t, (RBAR,h), loc_h]
            if 'reest' in io.DEBUG:
                print "num:%s den:%s"%(f[('LNSTOP','num',h)],f[('LNSTOP','den',h)])
            for s in [loc_h]:
                # adjacent left attachment
                pass

    # todo: use sum([ichart[s, t...] etc? but can we then
    # keep den and num separate within _one_ sum()-call? use map?

    # write the new non-adjacent left-STOP probability into each rule
    for r in g.all_rules():
        if r.L() == STOP:
            h = r.head()
            if f[('LNSTOP','den',h)] > 0.0:
                r.probN = f[('LNSTOP','num',h)] / f[('LNSTOP','den',h)]
                if 'reest' in io.DEBUG:
                    print "p(STOP|%d=%s,L,N): %s / %s = %s"%(h,g.numtag(h),f[('LNSTOP','num',h)],f[('LNSTOP','den',h)],f[('LNSTOP','num',h)]/f[('LNSTOP','den',h)])
            else:
                # zero denominator: leave probN untouched
                if 'reest' in io.DEBUG:
                    print "p(STOP|%d=%s,L,N): %s / 0"%(h,g.numtag(h),f[('LNSTOP','num',h)])
514 ##############################
515 # testing functions: #
516 ##############################
def testreestimation():
    '''Initialize a grammar harmonically on a hand-written training
    corpus, then run one reestimation pass over a single sentence.'''
    training = ['det vbd nn c vbd','det nn vbd c nn vbd pp',
                'det vbd nn', 'det vbd nn c vbd pp',
                'det vbd nn', 'det vbd c nn vbd pp',
                'det vbd nn', 'det nn vbd nn c vbd pp',
                'det vbd nn', 'det nn vbd c det vbd pp',
                'det vbd nn', 'det nn vbd c vbd det det det pp',
                'det nn vbd', 'det nn vbd c vbd pp',
                'det nn vbd', 'det nn vbd c vbd det pp',
                'det nn vbd', 'det nn vbd c vbd pp',
                'det nn vbd pp', 'det nn vbd det', ]
    corpus = [s.split() for s in training]
    import harmonic
    g = harmonic.initialize(corpus)
    corpus = [s.split() for s in ['det nn vbd det nn' ]]
    reestimate(g, corpus)
def testgrammar_h(): # Non, Adj
    '''A tiny one-tag grammar over the tag 'h' (tag number 0), small
    enough to check by hand; rule probabilities are (probN, probA).'''
    lstop   = DMV_Rule((LRBAR,0), STOP, (RBAR,0), 1.0, 1.0)        # LSTOP
    rstop   = DMV_Rule((RBAR,0), (NOBAR,0), STOP, 0.4, 0.3)        # RSTOP
    lattach = DMV_Rule((RBAR,0), (LRBAR,0), (RBAR,0), 0.6, 0.7)    # Lattach
    rattach = DMV_Rule((NOBAR,0), (NOBAR,0), (LRBAR,0), 1.0, 1.0)  # Rattach
    root    = DMV_Rule(ROOT, STOP, (LRBAR,0), 1.0, 1.0)            # ROOT
    # terminal (short-cut) probabilities for the word 'h'
    terminals = {
        ((NOBAR, 0), 'h'): 1.0,
        ((RBAR, 0), 'h'): rstop.probA,
        ((LRBAR, 0), 'h'): rstop.probA * lstop.probA,
    }
    rules = [root, lstop, rstop, lattach, rattach]
    return DMV_Grammar(rules, terminals, 0, 0, 0, {0:'h'}, {'h':0})
def testreestimation_h():
    '''Run reestimation on the tiny one-tag grammar with 'reest' debug
    output switched on.'''
    io.DEBUG = ['reest']
    grammar = testgrammar_h()
    sentences = ['h h h'.split()]
    reestimate(grammar, sentences)
def regression_tests():
    '''Compare inner() and outer() against hand-computed values on the
    one-tag test grammar; print a message for every mismatch.
    (Single-argument parenthesized prints: same output in Python 2.)'''
    g_dup = testgrammar_h()

    test0 = inner(0, 1, (LRBAR,0), 0, g_dup, 'h h'.split(), {})
    if "%.3f" % test0 != "0.120":
        print("Should be 0.120: %.3f" % test0)

    test1 = inner(0, 1, (LRBAR,0), 1, g_dup, 'h h'.split(), {})
    if "%.3f" % test1 != "0.063":
        print("Should be 0.063: %.3f" % test1)

    test3 = inner(0, 2, (LRBAR,0), 2, g_dup, 'h h h'.split(), {})
    if "%.4f" % test3 != "0.0498":
        print("Should be 0.0498: %.4f" % test3)

    test4 = outer(1,2,(1,0),2,testgrammar_h(),'h h h'.split(),{},{})
    if "%.2f" % test4 != "0.58":
        print("Should be 0.58: %.2f" % test4)
if __name__ == "__main__":
    io.DEBUG = ['reest_ichart']
    import timeit   # only referenced by the commented-out timing code below
    # import profile
    # profile.run('testreestimation()')
    # print timeit.Timer("dmv.testreestimation()",'''import dmv
    # reload(dmv)''').timeit(1)

    regression_tests()