#### changes by KBU:
# 2008-05-24
# - prettier printout for DMV_Rule
# - DMV_Rule changed a bit. head, L and R are now all pairs of the
#   form (bars, head).
# - Started on P_STOP, a bit less pseudo now..
#
# 2008-05-27
# - started on initialization. So far, I have frequencies for
#   everything, very harmonic. Still need to make these into 1-summing
#   probabilities.
#
# 2008-05-28
# - more work on initialization (init_freq and init_normalize),
#   getting closer to probabilities now.
#
# 2008-05-29
# - init_normalize is done, it creates p_STOP, p_ROOT and p_CHOOSE,
#   and also adds the relevant probabilities to p_rules in a grammar.
#   Still, each individual rule has to store both adjacent and non_adj
#   probabilities, and inner() should be able to send some parameter
#   which lets the rule choose... hopefully... Is this possible to do
#   top-down even, when the sentence could be all the same words?
#   todo: extensive testing of identical words in sentences!
# - frequencies (only used in initialization) are stored as strings,
#   but in the rules and p_STOP etc. there are only numbers.
#
# 2008-05-30
# - copied inner() into this file, to make the very dmv-specific
#   adjacency stuff work (have to factor that out later on, when it
#   works).
#
# 2008-06-01
# - finished typing in dmv.inner(), still have to test and debug
#   it. The chart is now four times as big, since for any rule we may
#   have attachments to either the left or the right below, which
#   upper rules depend on for selecting probN or probA.
#
# 2008-06-03
# - fixed a number of little bugs in initialization, where certain
#   rules were simply not created, or created "backwards"
# - dmv.inner() should work now...
#
# 2008-06-04
# - moved initialization to harmonic.py
#
# 2008-06-09
# - prune() finished, seems to be working.
# - started on implementing the other reestimation formulas, in
#   reestimate()

# import numpy # numpy provides Fast Arrays, for future optimization
import pprint
import io
import harmonic

# non-tweakable/constant "lookup" globals
BARS = [0,1,2]
RBAR = 1
LRBAR = 2
NOBAR = 0
ROOT = (LRBAR, -1)
STOP = (NOBAR, -2)

if __name__ == "__main__":
    print "DMV module tests:"


def node(bars, head):
    '''Useless function, but just here as documentation. Nodes make up
    LHS, R and L in each DMV_Rule.'''
    return (bars, head)

def bars(node):
    return node[0]

def head(node):
    return node[1]
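
# Illustration (not part of the original module): how the (bars, head)
# node encoding is meant to be read, going by the DMV_Rule docstring
# further down, with 0 standing in for some head tag:
#
#   node(NOBAR, 0) == (0, 0)   #  h  : may still attach to the right
#   node(RBAR,  0) == (1, 0)   #  h_ : has taken its right STOP
#   node(LRBAR, 0) == (2, 0)   # _h_ : sealed on both sides
#
#   bars(node(RBAR, 0)) == RBAR  and  head(node(RBAR, 0)) == 0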


class DMV_Grammar(io.Grammar):
    '''The DMV-PCFG.

    Public members:
    p_STOP, p_ROOT, p_CHOOSE, p_terminals
    These are changed in the Maximization step, then used to set the
    new probabilities of each DMV_Rule.

    Todo: make p_terminals private? (But it has to be changeable in the
    maximization step due to the short-cutting rules... could of course
    make a DMV_Grammar function to update the short-cut rules...)

    __p_rules is private, but we can still say stuff like:
    for r in g.all_rules():
        r.probN = newProbN

    What other representations do we need? (The P_STOP formula uses
    deps_D(h,l/r) at least.)'''
    def __str__(self):
        str = ""
        for r in self.all_rules():
            str += "%s\n" % r.__str__(self.numtag)
        return str

    def h_rules(self, h):
        return [r for r in self.all_rules() if r.head() == h]

    def rules(self, LHS):
        return [r for r in self.all_rules() if r.LHS() == LHS]

    def sent_rules(self, LHS, sent_nums):
        "Used in dmv.inner"
        # We don't want to rule out STOPs!
        nums = sent_nums + [ head(STOP) ]
        return [r for r in self.all_rules() if r.LHS() == LHS
                and head(r.L()) in nums and head(r.R()) in nums]

    def deps_L(self, head): # todo: do I use this at all?
        return [r.L() for r in self.all_rules() if r.head() == head]

    def deps_R(self, head): # todo: do I use this at all?
        return [r.R() for r in self.all_rules() if r.head() == head]

    def __init__(self, p_rules, p_terminals, p_STOP, p_CHOOSE, p_ROOT, numtag, tagnum):
        io.Grammar.__init__(self, p_rules, p_terminals, numtag, tagnum)
        self.p_STOP = p_STOP
        self.p_CHOOSE = p_CHOOSE
        self.p_ROOT = p_ROOT
        self.head_nums = [k for k,v in numtag.iteritems()]
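
# Illustration (not part of the original module): a minimal sketch of how
# the accessors above can be queried, using the small single-head test
# grammar built by testgrammar_h() further down in this file.
def _grammar_query_sketch():
    g = testgrammar_h()
    for r in g.h_rules(0):                      # all rules headed by tag 0 ('h')
        print r.__str__(g.numtag)
    print len(g.rules((LRBAR, 0)))              # rules rewriting the sealed _h_
    print len(g.sent_rules((RBAR, 0), [0, 0]))  # filtered against a sentence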


class DMV_Rule(io.CNF_Rule):
    '''A single CNF rule in the PCFG, of the form
    LHS -> L R
    where LHS, L and R are 'nodes', eg. of the form (bars, head).

    Public members:
    probN, probA

    Private members:
    __L, __R, __LHS

    Different rule-types have different probabilities associated with
    them:

    _h_ -> STOP  h_     P( STOP|h,L, adj)
    _h_ -> STOP  h_     P( STOP|h,L,non_adj)
     h_ ->  h   STOP    P( STOP|h,R, adj)
     h_ ->  h   STOP    P( STOP|h,R,non_adj)
     h_ -> _a_   h_     P(-STOP|h,L, adj) * P(a|h,L)
     h_ -> _a_   h_     P(-STOP|h,L,non_adj) * P(a|h,L)
     h  ->  h   _a_     P(-STOP|h,R, adj) * P(a|h,R)
     h  ->  h   _a_     P(-STOP|h,R,non_adj) * P(a|h,R)
    '''
    def p(self, adj, *arg):
        if adj:
            return self.probA
        else:
            return self.probN

    def p_STOP(self, s, t, loc_h):
        '''Returns the correct probability, adjacent if we're rewriting from
        the (either left or right) end of the fragment.'''
        if self.L() == STOP:
            return self.p(s == loc_h)
        elif self.R() == STOP:
            if not loc_h == s:
                io.debug( "(%s given loc_h:%d but s:%d. Todo: optimize away!)"
                          % (self, loc_h, s) )
                return 0.0
            else:
                return self.p(t == loc_h)

    def p_ATTACH(self, r, loc_h, s=None):
        '''Returns the correct probability, adjacent if we haven't attached
        anything before.'''
        if self.LHS() == self.L():
            if not loc_h == s:
                io.debug( "(%s given loc_h (loc_L):%d but s:%d. Todo: optimize away!)"
                          % (self, loc_h, s) )
                return 0.0
            else:
                return self.p(r == loc_h)
        elif self.LHS() == self.R():
            return self.p(r+1 == loc_h)

    def bars(self):
        return bars(self.LHS())

    def head(self):
        return head(self.LHS())

    def __init__(self, LHS, L, R, probN, probA):
        for b_h in [LHS, L, R]:
            if bars(b_h) not in BARS:
                raise ValueError("bars must be in %s; was given: %s"
                                 % (BARS, bars(b_h)))
        io.CNF_Rule.__init__(self, LHS, L, R, probN)
        self.probA = probA # adjacent
        self.probN = probN # non_adj

    @classmethod # so we can call DMV_Rule.bar_str(b_h)
    def bar_str(cls, b_h, tag=lambda x:x):
        if(b_h == ROOT):
            return 'ROOT'
        elif(b_h == STOP):
            return 'STOP'
        elif(bars(b_h) == RBAR):
            return " %s_ " % tag(head(b_h))
        elif(bars(b_h) == LRBAR):
            return "_%s_ " % tag(head(b_h))
        else:
            return " %s " % tag(head(b_h))

    def __str__(self, tag=lambda x:x):
        return "%s-->%s %s\t[N %.2f] [A %.2f]" % (self.bar_str(self.LHS(), tag),
                                                  self.bar_str(self.L(), tag),
                                                  self.bar_str(self.R(), tag),
                                                  self.probN,
                                                  self.probA)
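
# Illustration (not part of the original module): a small sketch of how one
# DMV_Rule carries both probabilities and how p()/p_STOP() pick between
# them; the 0.4/0.3 values are arbitrary, as in testgrammar_h() below.
def _rule_prob_sketch():
    # h_ -> h STOP, a right-STOP rule for head tag 0
    stop_r = DMV_Rule((RBAR, 0), (NOBAR, 0), STOP, 0.4, 0.3)
    assert stop_r.p(True) == stop_r.probA    # adjacent
    assert stop_r.p(False) == stop_r.probN   # non-adjacent
    # For a right STOP, the head must sit at the left end of the fragment
    # (loc_h == s); it is adjacent iff nothing was attached to its right:
    assert stop_r.p_STOP(0, 0, 0) == stop_r.probA
    assert stop_r.p_STOP(0, 2, 0) == stop_r.probN
    print stop_r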


###################################
# dmv-specific version of inner() #
###################################
def locs(h, sent, s=0, t=None, remove=None):
    '''Return the locations of h in sent, or some fragment of sent (in the
    latter case we make sure to offset the locations correctly, so that
    for any x in the returned list, sent[x]==h).

    t is inclusive, to match the way indices work with inner()
    (although python list-slicing has "exclusive" end indices).'''
    if t is None:
        t = len(sent)-1
    return [i+s for i,w in enumerate(sent[s:t+1])
            if w == h and not (i+s) == remove]
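
# Illustration (not part of the original module): what locs() returns on a
# small made-up list of tag numbers; s offsets the returned indices, t is
# inclusive, and remove= drops one position (eg. the head's own location).
def _locs_sketch():
    nums = [0, 1, 0, 0]
    assert locs(0, nums) == [0, 2, 3]
    assert locs(0, nums, s=1, t=3) == [2, 3]
    assert locs(0, nums, remove=2) == [0, 3]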


def inner(s, t, LHS, loc_h, g, sent, chart):
    '''A rewrite of io.inner(), to take adjacency into account.

    The chart is now of this form:
    chart[(s,t,LHS,loc_h)]

    loc_h gives adjacency (along with r and the location of the other
    child for attachment rules), and is needed in P_STOP reestimation.

    Todo: if possible, refactor (move dmv-specific stuff back into
    dmv, so this is "general" enough to be in io.py).'''

    def O(s):
        return sent[s]

    sent_nums = [g.tagnum(tag) for tag in sent]

    def e(s,t,LHS, loc_h, n_t):
        def tab():
            "Tabs for debug output"
            return "\t"*n_t

        if (s, t, LHS, loc_h) in chart:
            io.debug("%s*= %.4f in chart: s:%d t:%d LHS:%s loc:%d"
                     %(tab(),chart[(s, t, LHS, loc_h)], s, t,
                       DMV_Rule.bar_str(LHS), loc_h))
            return chart[(s, t, LHS, loc_h)]
        else:
            if s == t:
                if not loc_h == s:
                    # terminals are always F,F for attachment
                    io.debug("%s*= 0.0 (wrong loc_h)" % tab())
                    return 0.0
                elif (LHS, O(s)) in g.p_terminals:
                    prob = g.p_terminals[LHS, O(s)] # "b[LHS, O(s)]" in Lari&Young
                else:
                    # todo: assuming this is how to deal w/lacking
                    # rules, since we add prob.s, and 0 is identity
                    prob = 0.0
                    io.debug( "%sLACKING TERMINAL:" % tab())
                # todo: add to chart perhaps? Although, it _is_ simple lookup..
                io.debug( "%s*= %.4f (terminal: %s -> %s_%d)"
                          % (tab(),prob, DMV_Rule.bar_str(LHS), O(s), loc_h) )
                return prob
            else:
                p = 0.0 # "sum over j,k in a[LHS,j,k]"
                for rule in g.sent_rules(LHS, sent_nums):
                    io.debug( "%ssumming rule %s s:%d t:%d loc:%d" % (tab(),rule,s,t,loc_h) )
                    L = rule.L()
                    R = rule.R()
                    # if it's a STOP rule, rewrite for the same range:
                    if (L == STOP) or (R == STOP):
                        if L == STOP:
                            pLR = e(s, t, R, loc_h, n_t+1)
                        elif R == STOP:
                            pLR = e(s, t, L, loc_h, n_t+1)
                        p += rule.p_STOP(s, t, loc_h) * pLR
                        io.debug( "%sp= %.4f (STOP)" % (tab(), p) )

                    else: # not a STOP, an attachment rewrite:
                        for r in range(s, t):
                            # if loc_h == t, no need to try right-attachments,
                            # if loc_h == s, no need to try left-attachments... todo
                            p_h = rule.p_ATTACH(r, loc_h, s=s)
                            if rule.LHS() == L:
                                locs_L = [loc_h]
                                locs_R = locs(head(R), sent_nums, r+1, t, loc_h)
                            elif rule.LHS() == R:
                                locs_L = locs(head(L), sent_nums, s, r, loc_h)
                                locs_R = [loc_h]
                            # see http://tinyurl.com/4ffhhw
                            p += sum([e(s, r, L, loc_L, n_t+1) *
                                      p_h *
                                      e(r+1, t, R, loc_R, n_t+1)
                                      for loc_L in locs_L
                                      for loc_R in locs_R])
                        io.debug( "%sp= %.4f (ATTACH)" % (tab(), p) )
                chart[(s, t, LHS, loc_h)] = p
                return p
    # end of e-function

    inner_prob = e(s,t,LHS,loc_h, 0)
    if 1 in io.DEBUG:
        print debug_chart(g,sent,chart)
    return inner_prob
# end of dmv.inner(s, t, LHS, loc_h, g, sent, chart)


def debug_chart(g,sent,chart):
    str = "---CHART:---\n"
    for (s,t,LHS,loc_h),v in chart.iteritems():
        str += "%s -> %s_%d ... %s_%d (loc_h:%s):\t%.4f\n" % (DMV_Rule.bar_str(LHS,g.numtag),
                                                              sent[s], s, sent[t], t, loc_h, v)
    str += "---CHART:end---\n"
    return str


def inner_sent(loc_h, g, sent, chart):
    return inner(0, len(sent)-1, ROOT, loc_h, g, sent, chart)
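
# Illustration (not part of the original module): a minimal sketch of how
# inner()/inner_sent() fill a chart keyed on (s, t, LHS, loc_h), using the
# single-head test grammar defined further down in this file.
def _chart_sketch():
    g = testgrammar_h()
    sent = 'h h'.split()
    chart = {}
    # inner probability of ROOT over the whole sentence, taking the root's
    # head to sit at position 0:
    p = inner_sent(0, g, sent, chart)
    print "inner_sent: %.4f" % p
    for (s, t, LHS, loc_h), v in chart.iteritems():
        print s, t, DMV_Rule.bar_str(LHS, g.numtag), loc_h, v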


def prune(s,t,LHS,loc_h, g, sent_nums, chart):
    '''Removes unused subtrees with positive probability from the
    chart.

    Unused := all parent subtrees have (eventually, looking upwards)
    probability 0.0'''
    def prune_helper(keep,s,t,LHS,loc_h):
        keep = keep and chart[(s,t,LHS,loc_h)] > 0.0
        for rule in g.sent_rules(LHS, sent_nums):
            L = rule.L()
            R = rule.R()
            if R == STOP:
                if (s,t,L,loc_h) in chart:
                    prune_helper(keep, s,t, L,loc_h)
            elif L == STOP:
                if (s,t,R,loc_h) in chart:
                    prune_helper(keep, s,t, R,loc_h)
            else:
                for r in range(s,t):
                    for loc_L in locs(head(L), sent_nums, s, r):
                        if (s,r,rule.L(),loc_L) in chart:
                            prune_helper(keep, s,r, rule.L(),loc_L)
                    for loc_R in locs(head(R), sent_nums, r+1, t):
                        if (r+1,t,rule.R(),loc_R) in chart:
                            prune_helper(keep, r+1,t, rule.R(),loc_R)

        if not (s,t,LHS,loc_h) in keepchart:
            keepchart[(s,t,LHS,loc_h)] = keep
        else: # eg. if previously some parent rule had 0.0, but then a
              # later rule said "No, I've got a use for this subtree"
            keepchart[(s,t,LHS,loc_h)] += keep
        return None

    keepchart = {}
    prune_helper(True,s,t,LHS,loc_h)
    for (s,t,LHS,loc_h),v in keepchart.iteritems():
        if not v:
            chart.pop((s,t,LHS,loc_h))
# end prune(s,t,LHS,loc_h, g, sent_nums, chart)


def prune_sent(loc_h, g, sent_nums, chart):
    return prune(0, len(sent_nums)-1, ROOT, loc_h, g, sent_nums, chart)


##############################
# reestimation, todo:        #
##############################
def reestimate_zeros(h_nums):
    # todo: p_ROOT, p_CHOOSE, p_terminals
    f = {}
    for h in h_nums:
        f[('STOP','num',h)] = 0.0
        f[('STOP','den',h)] = 0.0
    return f


def reestimate(g, corpus):
    '''Current todo.
    P_STOP(-STOP|...) = 1 - P_STOP(STOP|...)'''
    f = reestimate_zeros(g.head_nums)
    for sent in corpus:
        chart = {}
        sent_nums = [g.tagnum(w) for i,w in enumerate(sent)]
        for loc_h,h in zip(range(len(sent)), sent_nums):
            inner_sent(loc_h, g, sent, chart)
            io.debug( debug_chart(g,sent,chart) ,'reest_chart')
            prune_sent(loc_h, g, sent_nums, chart)
            io.debug( debug_chart(g,sent,chart) ,'reest_chart')
            for s in range(loc_h): # s<loc(h), range gives strictly less
                for t in range(loc_h, len(sent)):
                    io.debug( "s:%s t:%s loc:%d"%(s,t,loc_h) , 'reest')
                    if (s, t, (LRBAR,h), loc_h) in chart:
                        f[('STOP','num',h)] += chart[(s, t, (LRBAR,h), loc_h)]
                        io.debug( "num+=%s"%chart[(s, t, (LRBAR,h), loc_h)] , 'reest')
                    if (s, t, (RBAR,h), loc_h) in chart:
                        f[('STOP','den',h)] += chart[(s, t, (RBAR,h), loc_h)]
                        io.debug( "den+=%s"%chart[(s, t, (RBAR,h), loc_h)] , 'reest')
            io.debug( "num:%s den:%s"%(f[('STOP','num',h)],f[('STOP','den',h)]), 'reest')
            for s in [loc_h]:
                # adjacent left attachment
                pass

    # todo: use sum([chart[(s, t...)] etc? but can we then
    # keep den and num separate within _one_ sum()-call? use map?

    for r in g.all_rules():
        if r.L() == STOP:
            if f[('STOP','den',h)] > 0.0:
                r.probN = f[('STOP','num',h)] / f[('STOP','den',h)]
                io.debug( "p(STOP|%s,L,N): %s / %s = %s"
                          % (h, f[('STOP','num',h)], f[('STOP','den',h)],
                             f[('STOP','num',h)] / f[('STOP','den',h)]), 'reest')
            else:
                io.debug( "p(STOP|%s,L,N): %s / 0"%(h,f[('STOP','num',h)]), 'reest')
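
# Illustration (not part of the original module): the update that the loop
# above is working towards, written out for a single head h. With the
# 'num'/'den' sums accumulated from the chart (sealed _h_ spans over h_
# spans, respectively), the maximization step is, roughly,
#
#   P_STOP( STOP|h,...) = f[('STOP','num',h)] / f[('STOP','den',h)]
#   P_STOP(-STOP|h,...) = 1 - P_STOP(STOP|h,...)
#
# splitting these counts by direction and adjacency is still marked as a
# todo above.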


##############################
# testing functions:         #
##############################

def testreestimation():
    corpus = [s.split() for s in ['det vbd nn c vbd','det nn vbd c nn vbd pp',
                                  'det vbd nn', 'det vbd nn c vbd pp',
                                  'det vbd nn', 'det vbd c nn vbd pp',
                                  'det vbd nn', 'det nn vbd nn c vbd pp',
                                  'det vbd nn', 'det nn vbd c det vbd pp',
                                  'det vbd nn', 'det nn vbd c vbd det det det pp',
                                  'det nn vbd', 'det nn vbd c vbd pp',
                                  'det nn vbd', 'det nn vbd c vbd det pp',
                                  'det nn vbd', 'det nn vbd c vbd pp',
                                  'det nn vbd pp', 'det nn vbd det', ]]
    g = harmonic.initialize(corpus)
    reestimate(g, corpus)


def testgrammar_h(): # Non, Adj
    _h_ = DMV_Rule((LRBAR,0), STOP,      ( RBAR,0), 1.0, 1.0) # LSTOP
    h_S = DMV_Rule(( RBAR,0), (NOBAR,0), STOP,      0.4, 0.3) # RSTOP
    h_A = DMV_Rule(( RBAR,0), (LRBAR,0), ( RBAR,0), 0.6, 0.7) # Lattach
    h   = DMV_Rule((NOBAR,0), (NOBAR,0), (LRBAR,0), 1.0, 1.0) # Rattach
    rh  = DMV_Rule( ROOT,     STOP,      (LRBAR,0), 1.0, 1.0) # ROOT
    b2 = {}
    b2[(NOBAR, 0), 'h'] = 1.0
    b2[(RBAR, 0), 'h'] = h_S.probA
    b2[(LRBAR, 0), 'h'] = h_S.probA * _h_.probA

    return DMV_Grammar([ rh, _h_, h_S, h_A, h ], b2, 0, 0, 0, {0:'h'}, {'h':0})


def testprune():
    g = testgrammar_h()
    sent = 'h h h'.split()
    chart = {}

    inner(0,2,ROOT,2, g,sent,chart)
    print debug_chart(g,sent,chart)
    prune(0,2,ROOT,2, g,[0,0,0],chart)
    print debug_chart(g,sent,chart)


def testreestimation_h():
    io.DEBUG=['reest']
    g = testgrammar_h()
    reestimate(g,['h h h'.split()])


if __name__ == "__main__":
    io.DEBUG = []
    import timeit
    timeit.Timer("dmv.testreestimation_h()",'''import dmv
reload(dmv)''').timeit(1)

    g_dup = testgrammar_h()

    io.DEBUG = []
    test0 = inner(0, 1, (LRBAR,0), 0, g_dup, 'h h'.split(), {})
    if not "0.120" == "%.3f" % test0:
        print "Should be 0.120: %.3f" % test0

    test1 = inner(0, 1, (LRBAR,0), 1, g_dup, 'h h'.split(), {})
    if not "0.063" == "%.3f" % test1:
        print "Should be 0.063: %.3f" % test1

    test3 = inner(0, 2, (LRBAR,0), 2, g_dup, 'h h h'.split(), {})
    if not "0.0498" == "%.4f" % test3:
        print "Should be 0.0498: %.4f" % test3