3 # - prettier printout for DMV_Rule
4 # - DMV_Rule changed a bit. head, L and R are now all pairs of the
6 # - Started on P_STOP, a bit less pseudo now..
9 # - started on initialization. So far, I have frequencies for
10 # everything, very harmonic. Still need to make these into 1-summing
14 # - more work on initialization (init_freq and init_normalize),
15 # getting closer to probabilities now.
18 # - init_normalize is done, it creates p_STOP, p_ROOT and p_CHOOSE,
19 # and also adds the relevant probabilities to p_rules in a grammar.
20 # Still, each individual rule has to store both adjacent and non_adj
21 # probabilities, and inner() should be able to send some parameter
22 # which lets the rule choose... hopefully... Is this possible to do
23 # top-down even? when the sentence could be all the same words?
24 # todo: extensive testing of identical words in sentences!
25 # - frequencies (only used in initialization) are stored as strings,
26 # but in the rules and p_STOP etc, there are only numbers.
29 # - copied inner() into this file, to make the very dmv-specific
30 # adjacency stuff work (have to factor that out later on, when it
34 # - finished typing in dmv.inner(), still have to test and debug
35 # it. The chart is now four times as big since for any rule we may
36 # have attachments to either the left or the right below, which
37 # upper rules depend on, for selecting probN or probA
40 # - fixed a number of little bugs in initialization, where certain
41 # rules were simply not created, or created "backwards"
42 # - dmv.inner() should Work now...
45 # - moved initialization to harmonic.py
48 # - prune() finished, seems to be working.
49 # - started on implementing the other reestimation formulas, in
53 # import numpy # numpy provides Fast Arrays, for future optimization
58 # non-tweakable/constant "lookup" globals
66 if __name__
== "__main__":
67 print "DMV module tests:"
71 '''Useless function, but just here as documentation. Nodes make up
72 LHS, R and L in each DMV_Rule'''
82 class DMV_Grammar(io
.Grammar
):
86 p_STOP, p_ROOT, p_CHOOSE, p_terminals
87 These are changed in the Maximation step, then used to set the
88 new probabilities of each DMV_Rule.
90 Todo: make p_terminals private? (But it has to be changable in
91 maximation step due to the short-cutting rules... could of course
92 make a DMV_Grammar function to update the short-cut rules...)
94 __p_rules is private, but we can still say stuff like:
95 for r in g.all_rules():
98 What other representations do we need? (P_STOP formula uses
99 deps_D(h,l/r) at least)'''
102 for r
in self
.all_rules():
103 str += "%s\n" % r
.__str
__(self
.numtag
)
106 def h_rules(self
, h
):
107 return [r
for r
in self
.all_rules() if r
.head() == h
]
109 def rules(self
, LHS
):
110 return [r
for r
in self
.all_rules() if r
.LHS() == LHS
]
112 def sent_rules(self
, LHS
, sent_nums
):
114 # We don't want to rule out STOPs!
115 nums
= sent_nums
+ [ head(STOP
) ]
116 return [r
for r
in self
.all_rules() if r
.LHS() == LHS
117 and head(r
.L()) in nums
and head(r
.R()) in nums
]
119 def deps_L(self
, head
): # todo: do I use this at all?
120 # todo test, probably this list comprehension doesn't work
121 return [a
for r
in self
.all_rules() if r
.head() == head
and a
== r
.L()]
123 def deps_R(self
, head
):
124 # todo test, probably this list comprehension doesn't work
125 return [a
for r
in self
.all_rules() if r
.head() == head
and a
== r
.R()]
127 def __init__(self
, p_rules
, p_terminals
, p_STOP
, p_CHOOSE
, p_ROOT
, numtag
, tagnum
):
128 io
.Grammar
.__init
__(self
, p_rules
, p_terminals
, numtag
, tagnum
)
130 self
.p_CHOOSE
= p_CHOOSE
132 self
.head_nums
= [k
for k
,v
in numtag
.iteritems()]
135 class DMV_Rule(io
.CNF_Rule
):
136 '''A single CNF rule in the PCFG, of the form
138 where LHS, L and R are 'nodes', eg. of the form (bars, head).
146 Different rule-types have different probabilities associated with
149 _h_ -> STOP h_ P( STOP|h,L, adj)
150 _h_ -> STOP h_ P( STOP|h,L,non_adj)
151 h_ -> h STOP P( STOP|h,R, adj)
152 h_ -> h STOP P( STOP|h,R,non_adj)
153 h_ -> _a_ h_ P(-STOP|h,L, adj) * P(a|h,L)
154 h_ -> _a_ h_ P(-STOP|h,L,non_adj) * P(a|h,L)
155 h -> h _a_ P(-STOP|h,R, adj) * P(a|h,R)
156 h -> h _a_ P(-STOP|h,R,non_adj) * P(a|h,R)
158 def p(self
, adj
, *arg
):
164 def p_STOP(self
, s
, t
, loc_h
):
165 '''Returns the correct probability, adjacent if we're rewriting from
166 the (either left or right) end of the fragment. '''
168 return self
.p(s
== loc_h
)
169 elif self
.R() == STOP
:
171 io
.debug( "(%s given loc_h:%d but s:%d. Todo: optimize away!)"
175 return self
.p(t
== loc_h
)
177 def p_ATTACH(self
, r
, loc_h
, s
=None):
178 '''Returns the correct probability, adjacent if we haven't attached
180 if self
.LHS() == self
.L():
182 io
.debug( "(%s given loc_h (loc_L):%d but s:%d. Todo: optimize away!)"
186 return self
.p(r
== loc_h
)
187 elif self
.LHS() == self
.R():
188 return self
.p(r
+1 == loc_h
)
191 return bars(self
.LHS())
194 return head(self
.LHS())
196 def __init__(self
, LHS
, L
, R
, probN
, probA
):
197 for b_h
in [LHS
, L
, R
]:
198 if bars(b_h
) not in BARS
:
199 raise ValueError("bars must be in %s; was given: %s"
201 io
.CNF_Rule
.__init
__(self
, LHS
, L
, R
, probN
)
202 self
.probA
= probA
# adjacent
203 self
.probN
= probN
# non_adj
205 @classmethod # so we can call DMV_Rule.bar_str(b_h)
206 def bar_str(cls
, b_h
, tag
=lambda x
:x
):
211 elif(bars(b_h
) == RBAR
):
212 return " %s_ " % tag(head(b_h
))
213 elif(bars(b_h
) == LRBAR
):
214 return "_%s_ " % tag(head(b_h
))
216 return " %s " % tag(head(b_h
))
219 def __str__(self
, tag
=lambda x
:x
):
220 return "%s-->%s %s\t[N %.2f] [A %.2f]" % (self
.bar_str(self
.LHS(), tag
),
221 self
.bar_str(self
.L(), tag
),
222 self
.bar_str(self
.R(), tag
),
232 ###################################
233 # dmv-specific version of inner() #
234 ###################################
def locs(h, sent, s=0, t=None, remove=None):
    '''Return the locations of h in sent, or some fragment of sent (in the
    latter case we make sure to offset the locations correctly so that
    for any x in the returned list, sent[x]==h).

    t is inclusive, to match the way indices work with inner()
    (although python list-splicing has "exclusive" end indices)

    Bug fix: the t=None default was unusable -- sent[s:t+1] raised
    TypeError (None + 1).  Now t=None means "to the end of sent".
    `remove` excludes one absolute position from the result (used to
    keep a head from being counted as its own dependent).'''
    if t is None:
        t = len(sent) - 1  # inclusive end, per the docstring convention
    return [i+s for i, w in enumerate(sent[s:t+1])
            if w == h and (i+s) != remove]
248 def inner(s
, t
, LHS
, loc_h
, g
, sent
, chart
):
249 ''' A rewrite of io.inner(), to take adjacency into accord.
251 The chart is now of this form:
252 chart[(s,t,LHS, loc_h)]
254 loc_h gives adjacency (along with r and location of other child
255 for attachment rules), and is needed in P_STOP reestimation.
257 Todo: if possible, refactor (move dmv-specific stuff back into
258 dmv, so this is "general" enough to be in io.py)
264 sent_nums
= [g
.tagnum(tag
) for tag
in sent
]
266 def e(s
,t
,LHS
, loc_h
, n_t
):
268 "Tabs for debug output"
271 if (s
, t
, LHS
, loc_h
) in chart
:
272 io
.debug("%s*= %.4f in chart: s:%d t:%d LHS:%s loc:%d"
273 %(tab(),chart
[(s
, t
, LHS
, loc_h
)], s
, t
,
274 DMV_Rule
.bar_str(LHS
), loc_h
))
275 return chart
[(s
, t
, LHS
, loc_h
)]
279 # terminals are always F,F for attachment
280 io
.debug("%s*= 0.0 (wrong loc_h)" % tab())
282 elif (LHS
, O(s
)) in g
.p_terminals
:
283 prob
= g
.p_terminals
[LHS
, O(s
)] # "b[LHS, O(s)]" in Lari&Young
285 # todo: assuming this is how to deal w/lacking
286 # rules, since we add prob.s, and 0 is identity
288 io
.debug( "%sLACKING TERMINAL:" % tab())
289 # todo: add to chart perhaps? Although, it _is_ simple lookup..
290 io
.debug( "%s*= %.4f (terminal: %s -> %s_%d)"
291 % (tab(),prob
, DMV_Rule
.bar_str(LHS
), O(s
), loc_h
) )
294 p
= 0.0 # "sum over j,k in a[LHS,j,k]"
295 for rule
in g
.sent_rules(LHS
, sent_nums
):
296 io
.debug( "%ssumming rule %s s:%d t:%d loc:%d" % (tab(),rule
,s
,t
,loc_h
) )
299 # if it's a STOP rule, rewrite for the same range:
300 if (L
== STOP
) or (R
== STOP
):
302 pLR
= e(s
, t
, R
, loc_h
, n_t
+1)
304 pLR
= e(s
, t
, L
, loc_h
, n_t
+1)
305 p
+= rule
.p_STOP(s
, t
, loc_h
) * pLR
306 io
.debug( "%sp= %.4f (STOP)" % (tab(), p
) )
308 else: # not a STOP, an attachment rewrite:
309 for r
in range(s
, t
):
310 # if loc_h == t, no need to try right-attachments,
311 # if loc_h == s, no need to try left-attachments... todo
312 p_h
= rule
.p_ATTACH(r
, loc_h
, s
=s
)
315 locs_R
= locs(head(R
), sent_nums
, r
+1, t
, loc_h
)
316 elif rule
.LHS() == R
:
317 locs_L
= locs(head(L
), sent_nums
, s
, r
, loc_h
)
319 # see http://tinyurl.com/4ffhhw
320 p
+= sum([e(s
, r
, L
, loc_L
, n_t
+1) *
322 e(r
+1, t
, R
, loc_R
, n_t
+1)
324 for loc_R
in locs_R
])
325 io
.debug( "%sp= %.4f (ATTACH)" % (tab(), p
) )
326 chart
[(s
, t
, LHS
, loc_h
)] = p
330 inner_prob
= e(s
,t
,LHS
,loc_h
, 0)
332 print debug_chart(g
,sent
,chart
)
334 # end of dmv.inner(s, t, LHS, loc_h, g, sent, chart)
337 def debug_chart(g
,sent
,chart
):
338 str = "---CHART:---\n"
339 for (s
,t
,LHS
,loc_h
),v
in chart
.iteritems():
340 str += "%s -> %s_%d ... %s_%d (loc_h:%s):\t%.4f\n" % (DMV_Rule
.bar_str(LHS
,g
.numtag
),
341 sent
[s
], s
, sent
[s
], t
, loc_h
, v
)
342 str += "---CHART:end---\n"
def inner_sent(loc_h, g, sent, chart):
    '''Convenience wrapper: compute inner() over the whole sentence span
    (0 .. len(sent)-1, inclusive) starting from ROOT.'''
    last = len(sent) - 1
    return inner(0, last, ROOT, loc_h, g, sent, chart)
349 def prune(s
,t
,LHS
,loc_h
, g
, sent_nums
, chart
):
350 '''Removes unused subtrees with positive probability from the
353 Unused := all parent subtrees have (eventually, looking upwards)
355 def prune_helper(keep
,s
,t
,LHS
,loc_h
):
356 keep
= keep
and chart
[(s
,t
,LHS
,loc_h
)] > 0.0
357 for rule
in g
.sent_rules(LHS
, sent_nums
):
361 if (s
,t
,L
,loc_h
) in chart
:
362 prune_helper(keep
, s
,t
, L
,loc_h
)
364 if (s
,t
,R
,loc_h
) in chart
:
365 prune_helper(keep
, s
,t
, R
,loc_h
)
368 for loc_L
in locs(head(L
), sent_nums
, s
, r
):
369 if (s
,r
,rule
.L(),loc_L
) in chart
:
370 prune_helper(keep
, s
,r
,rule
.L(),loc_L
)
371 for loc_R
in locs(head(R
), sent_nums
, r
+1, t
):
372 if (r
+1,t
,rule
.R(),loc_R
) in chart
:
373 prune_helper(keep
,r
+1,t
,rule
.R(),loc_R
)
375 if not (s
,t
,LHS
,loc_h
) in keepchart
:
376 keepchart
[(s
,t
,LHS
,loc_h
)] = keep
377 else: # eg. if previously some parent rule had 0.0, but then a
378 # later rule said "No, I've got a use for this subtree"
379 keepchart
[(s
,t
,LHS
,loc_h
)] += keep
383 prune_helper(True,s
,t
,LHS
,loc_h
)
384 for (s
,t
,LHS
,loc_h
),v
in keepchart
.iteritems():
386 chart
.pop((s
,t
,LHS
,loc_h
))
387 # end prune(s,t,LHS,loc_h, g, sent_nums, chart)
def prune_sent(loc_h, g, sent_nums, chart):
    '''Convenience wrapper: prune() the chart over the full sentence span
    (0 .. len(sent_nums)-1, inclusive) starting from ROOT.'''
    last = len(sent_nums) - 1
    return prune(0, last, ROOT, loc_h, g, sent_nums, chart)
395 ##############################
396 # reestimation, todo: #
397 ##############################
398 def reestimate_zeros(h_nums
):
399 # todo: p_ROOT, p_CHOOSE, p_terminals
402 f
[('STOP','num',h
)] = 0.0
403 f
[('STOP','den',h
)] = 0.0
406 def reestimate(g
, corpus
):
408 P_STOP(-STOP|...) = 1 - P_STOP(STOP|...) '''
409 f
= reestimate_zeros(g
.head_nums
)
412 sent_nums
= [g
.tagnum(w
) for i
,w
in enumerate(sent
)]
413 for loc_h
,h
in zip(range(len(sent
)), sent_nums
):
414 inner_sent(loc_h
, g
, sent
, chart
)
415 io
.debug( debug_chart (g
,sent
,chart
) ,'reest_chart')
416 prune_sent(loc_h
, g
, sent_nums
, chart
)
417 io
.debug( debug_chart( g
,sent
,chart
) ,'reest_chart')
418 for s
in range(loc_h
): # s<loc(h), range gives strictly less
419 for t
in range(loc_h
, len(sent
)):
420 io
.debug( "s:%s t:%s loc:%d"%(s
,t
,loc_h
) , 'reest')
421 if (s
, t
, (LRBAR
,h
), loc_h
) in chart
:
422 f
[('STOP','num',h
)] += chart
[(s
, t
, (LRBAR
,h
), loc_h
)]
423 io
.debug( "num+=%s"%chart
[(s
, t
, (LRBAR
,h
), loc_h
)] , 'reest')
424 if (s
, t
, (RBAR
,h
), loc_h
) in chart
:
425 f
[('STOP','den',h
)] += chart
[(s
, t
, (RBAR
,h
), loc_h
)]
426 io
.debug( "den+=%s"%chart
[(s
, t
, (RBAR
,h
), loc_h
)] , 'reest')
427 io
.debug( "num:%s den:%s"%(f
[('STOP','num',h
)],f
[('STOP','den',h
)]), 'reest')
429 # adjacent left attachment
432 # todo: use sum([chart[(s, t...)] etc? but can we then
433 # keep den and num separate within _one_ sum()-call? use map?
435 for r
in g
.all_rules():
437 if f
[('STOP','den',h
)] > 0.0:
438 r
.probN
= f
[('STOP','num',h
)] / f
[('STOP','den',h
)]
439 io
.debug( "p(STOP|%s,L,N): %s / %s = %s"%(h
,f
[('STOP','num',h
)],f
[('STOP','den',h
)],f
[('STOP','num',h
)]/f
[('STOP','den',h
)]) , 'reest')
441 io
.debug( "p(STOP|%s,L,N): %s / 0 = %s"%(h
,f
[('STOP','num',h
)]), 'reest')
451 ##############################
452 # testing functions: #
453 ##############################
def testreestimation():
    '''Smoke test: harmonic-initialize a grammar from a small hand-written
    tag corpus, then run one round of reestimate() on it.'''
    sentences = ['det vbd nn c vbd','det nn vbd c nn vbd pp',
                 'det vbd nn', 'det vbd nn c vbd pp',
                 'det vbd nn', 'det vbd c nn vbd pp',
                 'det vbd nn', 'det nn vbd nn c vbd pp',
                 'det vbd nn', 'det nn vbd c det vbd pp',
                 'det vbd nn', 'det nn vbd c vbd det det det pp',
                 'det nn vbd', 'det nn vbd c vbd pp',
                 'det nn vbd', 'det nn vbd c vbd det pp',
                 'det nn vbd', 'det nn vbd c vbd pp',
                 'det nn vbd pp', 'det nn vbd det', ]
    corpus = [sentence.split() for sentence in sentences]
    g = harmonic.initialize(corpus)
    reestimate(g, corpus)
469 def testgrammar_h(): # Non, Adj
470 _h_
= DMV_Rule((LRBAR
,0), STOP
, ( RBAR
,0), 1.0, 1.0) # LSTOP
471 h_S
= DMV_Rule(( RBAR
,0),(NOBAR
,0), STOP
, 0.4, 0.3) # RSTOP
472 h_A
= DMV_Rule(( RBAR
,0),(LRBAR
,0),( RBAR
,0), 0.6, 0.7) # Lattach
473 h
= DMV_Rule((NOBAR
,0),(NOBAR
,0),(LRBAR
,0), 1.0, 1.0) # Rattach
474 rh
= DMV_Rule( ROOT
, STOP
, (LRBAR
,0), 1.0, 1.0) # ROOT
476 b2
[(NOBAR
, 0), 'h'] = 1.0
477 b2
[(RBAR
, 0), 'h'] = h_S
.probA
478 b2
[(LRBAR
, 0), 'h'] = h_S
.probA
* _h_
.probA
480 return DMV_Grammar([ rh
, _h_
, h_S
, h_A
, h
],b2
,0,0,0, {0:'h'}, {'h':0})
485 sent
= 'h h h'.split()
488 inner(0,2,ROOT
,2, g
,sent
,chart
)
489 print debug_chart(g
,sent
,chart
)
490 prune( 0,2,ROOT
,2, g
,[0,0,0],chart
)
491 print debug_chart(g
,sent
,chart
)
493 def testreestimation_h():
496 reestimate(g
,['h h h'.split()])
498 if __name__
== "__main__":
501 timeit
.Timer("dmv.testreestimation_h()",'''import dmv
502 reload(dmv)''').timeit(1)
504 g_dup
= testgrammar_h()
507 test0
= inner(0, 1, (LRBAR
,0), 0, g_dup
, 'h h'.split(), {})
508 if not "0.120"=="%.3f" % test0
:
509 print "Should be 0.120: %.3f" % test0
511 test1
= inner(0, 1, (LRBAR
,0), 1, g_dup
, 'h h'.split(), {})
512 if not "0.063"=="%.3f" % test1
:
513 print "Should be 0.063: %.3f" % test1
515 test3
= inner(0, 2, (LRBAR
,0), 2, g_dup
, 'h h h'.split(), {})
516 if not "0.0498"=="%.4f" % test3
:
517 print "Should be 0.0498: %.4f" % test3