3 # - moved prune() to junk.py, now using outer() instead. outer() is
4 # written, but needs testing.
7 # - prune() finished, seems to be working.
8 # - started on implementing the other reestimation formulas, in
12 # - moved initialization to harmonic.py
15 # - fixed a number of little bugs in initialization, where certain
16 # rules were simply not created, or created "backwards"
17 # - dmv.inner() should Work now...
20 # - finished typing in dmv.inner(), still have to test and debug
21 # it. The ichart is now four times as big since for any rule we may
22 # have attachments to either the left or the right below, which
23 # upper rules depend on, for selecting probN or probA
26 # - copied inner() into this file, to make the very dmv-specific
27 # adjacency stuff work (have to factor that out later on, when it
31 # - init_normalize is done, it creates p_STOP, p_ROOT and p_CHOOSE,
32 # and also adds the relevant probabilities to p_rules in a grammar.
33 # Still, each individual rule has to store both adjacent and non_adj
34 # probabilities, and inner() should be able to send some parameter
35 # which lets the rule choose... hopefully... Is this possible to do
36 # top-down even? when the sentence could be all the same words?
37 # todo: extensive testing of identical words in sentences!
38 # - frequencies (only used in initialization) are stored as strings,
39 # but in the rules and p_STOP etc, there are only numbers.
42 # - more work on initialization (init_freq and init_normalize),
43 # getting closer to probabilities now.
46 # - started on initialization. So far, I have frequencies for
47 # everything, very harmonic. Still need to make these into 1-summing
51 # - prettier printout for DMV_Rule
52 # - DMV_Rule changed a bit. head, L and R are now all pairs of the
54 # - Started on P_STOP, a bit less pseudo now..
58 #import numpy # numpy provides Fast Arrays, for future optimization
61 # non-tweakable/constant "lookup" globals
if __name__ == "__main__":
    # Banner shown when the module is run directly as a script.
    print("DMV module tests:")
74 '''Useless function, but just here as documentation. Nodes make up
75 LHS, R and L in each DMV_Rule'''
85 class DMV_Grammar(io
.Grammar
):
89 p_STOP, p_ROOT, p_CHOOSE, p_terminals
90 These are changed in the Maximation step, then used to set the
91 new probabilities of each DMV_Rule.
93 Todo: make p_terminals private? (But it has to be changable in
94 maximation step due to the short-cutting rules... could of course
95 make a DMV_Grammar function to update the short-cut rules...)
97 __p_rules is private, but we can still say stuff like:
98 for r in g.all_rules():
101 What other representations do we need? (P_STOP formula uses
102 deps_D(h,l/r) at least)'''
105 for r
in self
.all_rules():
106 str += "%s\n" % r
.__str
__(self
.numtag
)
109 def h_rules(self
, h
):
110 return [r
for r
in self
.all_rules() if r
.head() == h
]
112 def mothersL(self
, Node
, sent_nums
):
113 return [r
for r
in self
.all_rules() if r
.L() == Node
]
115 def mothersR(self
, Node
, sent_nums
):
116 return [r
for r
in self
.all_rules() if r
.R() == Node
]
118 def rules(self
, LHS
):
119 return [r
for r
in self
.all_rules() if r
.LHS() == LHS
]
121 def sent_rules(self
, LHS
, sent_nums
):
122 '''Used in dmv.inner. Todo: this takes a _lot_ of time, it
123 seems. Could use some more space and cache some of this
125 # We don't want to rule out STOPs!
126 nums
= sent_nums
+ [ head(STOP
) ]
127 return [r
for r
in self
.all_rules() if r
.LHS() == LHS
128 and head(r
.L()) in nums
and head(r
.R()) in nums
]
130 def deps_L(self
, head
): # todo: do I use this at all?
131 # todo test, probably this list comprehension doesn't work
132 return [a
for r
in self
.all_rules() if r
.head() == head
and a
== r
.L()]
134 def deps_R(self
, head
):
135 # todo test, probably this list comprehension doesn't work
136 return [a
for r
in self
.all_rules() if r
.head() == head
and a
== r
.R()]
138 def __init__(self
, p_rules
, p_terminals
, p_STOP
, p_CHOOSE
, p_ROOT
, numtag
, tagnum
):
139 io
.Grammar
.__init
__(self
, p_rules
, p_terminals
, numtag
, tagnum
)
141 self
.p_CHOOSE
= p_CHOOSE
143 self
.head_nums
= [k
for k
in numtag
.iterkeys()]
146 class DMV_Rule(io
.CNF_Rule
):
147 '''A single CNF rule in the PCFG, of the form
149 where LHS, L and R are 'nodes', eg. of the form (bars, head).
157 Different rule-types have different probabilities associated with
160 _h_ -> STOP h_ P( STOP|h,L, adj)
161 _h_ -> STOP h_ P( STOP|h,L,non_adj)
162 h_ -> h STOP P( STOP|h,R, adj)
163 h_ -> h STOP P( STOP|h,R,non_adj)
164 h_ -> _a_ h_ P(-STOP|h,L, adj) * P(a|h,L)
165 h_ -> _a_ h_ P(-STOP|h,L,non_adj) * P(a|h,L)
166 h -> h _a_ P(-STOP|h,R, adj) * P(a|h,R)
167 h -> h _a_ P(-STOP|h,R,non_adj) * P(a|h,R)
169 def p(self
, adj
, *arg
):
175 def p_STOP(self
, s
, t
, loc_h
):
176 '''Returns the correct probability, adjacent if we're rewriting from
177 the (either left or right) end of the fragment. '''
179 return self
.p(s
== loc_h
)
180 elif self
.R() == STOP
:
182 if 'TODO' in io
.DEBUG
:
183 print "(%s given loc_h:%d but s:%d. Todo: optimize away!)" % (self
, loc_h
, s
)
186 return self
.p(t
== loc_h
)
188 def p_ATTACH(self
, r
, loc_h
, s
=None):
189 '''Returns the correct probability, adjacent if we haven't attached
191 if self
.LHS() == self
.L():
193 if 'TODO' in io
.DEBUG
:
194 print "(%s given loc_h (loc_L):%d but s:%d. Todo: optimize away!)" % (self
, loc_h
, s
)
197 return self
.p(r
== loc_h
)
198 elif self
.LHS() == self
.R():
199 return self
.p(r
+1 == loc_h
)
202 return bars(self
.LHS())
205 return head(self
.LHS())
207 def __init__(self
, LHS
, L
, R
, probN
, probA
):
208 for b_h
in [LHS
, L
, R
]:
209 if bars(b_h
) not in BARS
:
210 raise ValueError("bars must be in %s; was given: %s"
212 io
.CNF_Rule
.__init
__(self
, LHS
, L
, R
, probN
)
213 self
.probA
= probA
# adjacent
214 self
.probN
= probN
# non_adj
216 @classmethod # so we can call DMV_Rule.bar_str(b_h)
217 def bar_str(cls
, b_h
, tag
=lambda x
:x
):
222 elif(bars(b_h
) == RBAR
):
223 return " %s_ " % tag(head(b_h
))
224 elif(bars(b_h
) == LRBAR
):
225 return "_%s_ " % tag(head(b_h
))
227 return " %s " % tag(head(b_h
))
230 def __str__(self
, tag
=lambda x
:x
):
231 return "%s-->%s %s\t[N %.2f] [A %.2f]" % (self
.bar_str(self
.LHS(), tag
),
232 self
.bar_str(self
.L(), tag
),
233 self
.bar_str(self
.R(), tag
),
243 ###################################
244 # dmv-specific version of inner() #
245 ###################################
def locs(h, sent, s=0, t=None, remove=None):
    '''Return the locations of h in sent, or in the fragment sent[s..t].

    Locations are offsets into the whole of sent, so that sent[x] == h
    for every x in the returned list.

    t is inclusive, to match the way indices work with inner()
    (although python list-slicing has "exclusive" end indices). When t
    is None the fragment runs to the end of sent.

    remove: one location to leave out of the result, or None.'''
    if t is None:
        # Previously t=None crashed on `t+1`; treat it as "last word".
        t = len(sent) - 1
    return [i + s for i, w in enumerate(sent[s:t + 1])
            if w == h and i + s != remove]
259 def inner(s
, t
, LHS
, loc_h
, g
, sent
, ichart
):
260 ''' A rewrite of io.inner(), to take adjacency into accord.
262 The ichart is now of this form:
263 ichart[s,t,LHS, loc_h]
265 loc_h gives adjacency (along with r and location of other child
266 for attachment rules), and is needed in P_STOP reestimation.
268 Todo: if possible, refactor (move dmv-specific stuff back into
269 dmv, so this is "general" enough to be in io.py)
275 sent_nums
= [g
.tagnum(tag
) for tag
in sent
]
278 def e(s
,t
,LHS
, loc_h
, n_t
):
280 "Tabs for debug output"
283 if (s
, t
, LHS
, loc_h
) in ichart
:
284 if 'INNER' in io
.DEBUG
:
285 print "%s*= %.4f in ichart: s:%d t:%d LHS:%s loc:%d" % (tab(),ichart
[s
, t
, LHS
, loc_h
], s
, t
,
286 DMV_Rule
.bar_str(LHS
), loc_h
)
287 return ichart
[s
, t
, LHS
, loc_h
]
291 if 'INNER' in io
.DEBUG
:
292 print "%s*= 0.0 (wrong loc_h)" % tab()
294 elif (LHS
, O(s
)) in g
.p_terminals
:
295 prob
= g
.p_terminals
[LHS
, O(s
)] # "b[LHS, O(s)]" in Lari&Young
297 # todo: assuming this is how to deal w/lacking
298 # rules, since we add prob.s, and 0 is identity
300 if 'INNER' in io
.DEBUG
:
301 print "%sLACKING TERMINAL:" % tab()
302 # todo: add to ichart perhaps? Although, it _is_ simple lookup..
303 if 'INNER' in io
.DEBUG
:
304 print "%s*= %.4f (terminal: %s -> %s_%d)" % (tab(),prob
, DMV_Rule
.bar_str(LHS
), O(s
), loc_h
)
307 p
= 0.0 # "sum over j,k in a[LHS,j,k]"
308 for rule
in g
.sent_rules(LHS
, sent_nums
):
309 if 'INNER' in io
.DEBUG
:
310 print "%ssumming rule %s s:%d t:%d loc:%d" % (tab(),rule
,s
,t
,loc_h
)
313 if (s
,t
,LHS
,loc_h
) not in tree
:
314 tree
[s
,t
,LHS
,loc_h
] = set()
315 if loc_h
== t
and rule
.LHS() == L
:
316 break # 25% faster with these cut-offs
317 if loc_h
== s
and rule
.LHS() == R
:
319 # if it's a STOP rule, rewrite for the same range:
320 if (L
== STOP
) or (R
== STOP
):
322 pLR
= e(s
, t
, R
, loc_h
, n_t
+1)
324 tree
[s
,t
,LHS
,loc_h
].add((s
,t
,R
,loc_h
))
326 pLR
= e(s
, t
, L
, loc_h
, n_t
+1)
328 tree
[s
,t
,LHS
,loc_h
].add((s
,t
,L
,loc_h
))
329 p
+= rule
.p_STOP(s
, t
, loc_h
) * pLR
330 if 'INNER' in io
.DEBUG
:
331 print "%sp= %.4f (STOP)" % (tab(), p
)
333 else: # not a STOP, an attachment rewrite:
334 rp_ATTACH
= rule
.p_ATTACH
# todo: profile/speedtest
335 for r
in xrange(s
, t
):
336 p_h
= rp_ATTACH(r
, loc_h
, s
=s
)
339 locs_R
= locs(head(R
), sent_nums
, r
+1, t
, loc_h
)
340 elif rule
.LHS() == R
:
341 locs_L
= locs(head(L
), sent_nums
, s
, r
, loc_h
)
344 pL
= e(s
, r
, L
, loc_L
, n_t
+1)
347 pR
= e(r
+1, t
, R
, loc_R
, n_t
+1)
348 if pR
> 0.0: # and pL > 0.0
349 tree
[s
,t
,LHS
,loc_h
].add(( s
,r
,L
,loc_L
))
350 tree
[s
,t
,LHS
,loc_h
].add((r
+1,t
,R
,loc_R
))
352 if 'INNER' in io
.DEBUG
:
353 print "%sp= %.4f (ATTACH)" % (tab(), p
)
354 ichart
[s
, t
, LHS
, loc_h
] = p
358 inner_prob
= e(s
,t
,LHS
,loc_h
, 0)
359 ichart
['tree'] = tree
360 if 'INNER' in io
.DEBUG
:
361 print debug_ichart(g
,sent
,ichart
)
363 # end of dmv.inner(s, t, LHS, loc_h, g, sent, ichart)
def debug_ichart(g,sent,ichart):
    '''Return a human-readable dump of the inside chart as a string.

    Each (s, t, LHS, loc_h) entry is shown with the words at both ends
    of its span (sent[s] and sent[t]) and its value. Keys that are not
    such 4-tuples are skipped; inner() stores its backpointer table
    under ichart['tree'].'''
    lines = ["---ICHART:---\n"]
    for key, v in ichart.items():
        if not (isinstance(key, tuple) and len(key) == 4):
            continue  # e.g. the 'tree' entry, which is not a chart cell
        s, t, LHS, loc_h = key
        # The end word is sent[t]; sent[s] was printed twice before.
        lines.append("%s -> %s_%d ... %s_%d (loc_h:%s):\t%.4f\n"
                     % (DMV_Rule.bar_str(LHS, g.numtag),
                        sent[s], s, sent[t], t, loc_h, v))
    lines.append("---ICHART:end---\n")
    return "".join(lines)
def inner_sent(loc_h, g, sent, ichart):
    '''Inside probability of ROOT spanning the whole sentence, with the
    head at loc_h. Delegates to inner(), which also fills ichart.'''
    first, last = 0, len(sent) - 1
    return inner(first, last, ROOT, loc_h, g, sent, ichart)
def c(s,t,LHS,loc_h, g=None, sent=None, ichart=None, ochart=None):
    '''Return inner(s,t,LHS,loc_h) * outer(s,t,LHS,loc_h) for sentence
    sent under grammar g.

    g and sent are required keyword-position parameters: calling without
    them raises TypeError, which is what every call used to do (inner()
    and outer() were called with no arguments). ichart and ochart default
    to fresh empty charts.

    Todo: divide by the sentence probability?'''
    if g is None or sent is None:
        raise TypeError("c() needs a grammar g and a sentence sent")
    if ichart is None:
        ichart = {}
    if ochart is None:
        ochart = {}
    return (inner(s, t, LHS, loc_h, g, sent, ichart)
            * outer(s, t, LHS, loc_h, g, sent, ochart, ichart))
386 ###################################
387 # dmv-specific version of outer() #
388 ###################################
389 def outer(s
,t
,Node
,loc_N
, g
, sent
, ochart
, ichart
):
390 ''' http://www.student.uib.no/~kun041/dmvccm/DMVCCM.html#outer
392 def e(s
,t
,LHS
,loc_h
):
393 # or we could just look it up in ichart, assuming ichart to be done
394 return inner(s
, t
, LHS
, loc_h
, g
, sent
, ichart
)
397 sent_nums
= [i
for i
,w
in enumerate(sent
)]
399 def f(s
,t
,Node
,loc_N
):
400 if (s
,t
,Node
) in ochart
:
401 return ochart
[(s
, t
, Node
,loc_N
)]
403 if s
== 0 and t
== T
:
405 else: # ROOT may only be used on full sentence
406 return 0.0 # but we may have non-ROOTs over full sentence too
409 for mom
in g
.mothersL(Node
, sent_nums
): # mom.L() == Node
413 p
+= f(s
,t
,mLHS
,loc_N
) * mom
.p_STOP(s
,t
,loc_N
) # == loc_m
415 if bars(mLHS
) == RBAR
: # left attachment, head(mLHS) == head(L)
416 for r
in xrange(t
+1,T
+1): # t+1 to lasT
417 for loc_m
in locs(head(mLHS
),sent_nums
,t
+1,r
):
418 p_m
= mom
.p(t
+1 == loc_m
)
419 p
+= f(s
,r
,mLHS
,loc_m
) * p_m
* e(t
+1,r
,R
,loc_m
)
420 else: # right attachment, head(mLHS) == head(Node)
422 p_m
= mom
.p( t
== loc_m
)
423 for r
in xrange(t
+1,T
+1): # t+1 to lasT
424 for loc_R
in locs(head(mLHS
),sent_nums
,t
+1,r
):
425 p
+= f(s
,r
,mLHS
,loc_m
) * p_m
* e(t
+1,r
,R
,loc_R
)
427 for mom
in g
.mothersR(Node
, sent_nums
):
431 p
+= f(s
,t
,mLHS
,loc_N
) * mom
.p_STOP(s
,t
,loc_N
) # == loc_m
433 if bars(mLHS
) == RBAR
: # left attachment, head(mLHS) == head(Node)
435 p_m
= mom
.p( s
== loc_m
)
436 for r
in xrange(0,s
): # first to s-1
437 for loc_L
in locs(head(L
),sent_nums
,r
,s
-1):
438 p
+= e(r
,s
-1,L
, loc_L
) * p_m
* f(r
,t
,mLHS
,loc_m
)
439 else: # right attachment, head(mLHS) == head(R)
440 for r
in xrange(0,s
): # first to s-1
441 for loc_m
in locs(head(mLHS
),sent_nums
,r
,s
-1):
442 p_m
= mom
.p(s
-1 == loc_m
)
443 p
+= e(r
,s
-1,L
, loc_m
) * p_m
* f(r
,t
,mLHS
,loc_m
)
444 ochart
[s
,t
,Node
,loc_N
] = p
448 return f(s
,t
,Node
,loc_N
)
449 # end outer(s,t,Node,loc_N, g,sent, ochart,ichart)
453 ##############################
454 # reestimation, todo: #
455 ##############################
def reestimate_zeros(h_nums):
    '''Return a dict of zeroed reestimation counters.

    For each head number h in h_nums the dict holds the keys
    ('LNSTOP','num',h) and ('LNSTOP','den',h), both 0.0. reestimate()
    then accumulates chart values into them with +=.'''
    # todo: p_ROOT, p_CHOOSE, p_terminals
    f = {}
    for h in h_nums:
        f[('LNSTOP','num',h)] = 0.0
        f[('LNSTOP','den',h)] = 0.0
    return f
464 def reestimate(g
, corpus
):
466 P_STOP(-STOP|...) = 1 - P_STOP(STOP|...) '''
467 f
= reestimate_zeros(g
.head_nums
)
469 sent_nums
= [g
.tagnum(w
) for i
,w
in enumerate(sent
)]
471 for loc_h
,h
in zip(xrange(len(sent
)), sent_nums
):
472 inner_sent(loc_h
, g
, sent
, ichart
)
473 for loc_h
,h
in zip(xrange(len(sent
)), sent_nums
):
474 for s
in xrange(loc_h
): # s<loc(h), range gives strictly less
475 for t
in xrange(loc_h
, len(sent
)):
476 if 'reest' in io
.DEBUG
:
477 print "s:%s t:%s loc:%d"%(s
,t
,loc_h
)
478 if (s
, t
, (LRBAR
,h
), loc_h
) in ichart
:
479 f
[('LNSTOP','num',h
)] += ichart
[s
, t
, (LRBAR
,h
), loc_h
]
480 if 'reest' in io
.DEBUG
:
481 print "num+=%s"%ichart
[s
, t
, (LRBAR
,h
), loc_h
]
482 if (s
, t
, (RBAR
,h
), loc_h
) in ichart
:
483 f
[('LNSTOP','den',h
)] += ichart
[s
, t
, (RBAR
,h
), loc_h
]
484 if 'reest' in io
.DEBUG
:
485 print "den+=%s"%ichart
[s
, t
, (RBAR
,h
), loc_h
]
486 if 'reest' in io
.DEBUG
:
487 print "num:%s den:%s"%(f
[('LNSTOP','num',h
)],f
[('LNSTOP','den',h
)])
489 # adjacent left attachment
492 # todo: use sum([ichart[s, t...] etc? but can we then
493 # keep den and num separate within _one_ sum()-call? use map?
495 for r
in g
.all_rules():
498 if f
[('LNSTOP','den',h
)] > 0.0:
499 r
.probN
= f
[('LNSTOP','num',h
)] / f
[('LNSTOP','den',h
)]
500 if 'reest' in io
.DEBUG
:
501 print "p(STOP|%d=%s,L,N): %s / %s = %s"%(h
,g
.numtag(h
),f
[('LNSTOP','num',h
)],f
[('LNSTOP','den',h
)],f
[('LNSTOP','num',h
)]/f
[('LNSTOP','den',h
)])
503 if 'reest' in io
.DEBUG
:
504 print "p(STOP|%d=%s,L,N): %s / 0"%(h
,g
.numtag(h
),f
[('LNSTOP','num',h
)])
514 ##############################
515 # testing functions: #
516 ##############################
def testreestimation():
    '''Initialize a grammar from a small tag corpus with
    harmonic.initialize(), then run reestimate() on one test sentence.'''
    training = ['det vbd nn c vbd', 'det nn vbd c nn vbd pp',
                'det vbd nn', 'det vbd nn c vbd pp',
                'det vbd nn', 'det vbd c nn vbd pp',
                'det vbd nn', 'det nn vbd nn c vbd pp',
                'det vbd nn', 'det nn vbd c det vbd pp',
                'det vbd nn', 'det nn vbd c vbd det det det pp',
                'det nn vbd', 'det nn vbd c vbd pp',
                'det nn vbd', 'det nn vbd c vbd det pp',
                'det nn vbd', 'det nn vbd c vbd pp',
                'det nn vbd pp', 'det nn vbd det', ]
    g = harmonic.initialize([line.split() for line in training])

    test_sentences = ['det nn vbd det nn']
    reestimate(g, [line.split() for line in test_sentences])
def testgrammar_h(): # Non, Adj
    '''Build a toy DMV_Grammar with a single tag 'h' (tag number 0).

    Each DMV_Rule gets (probN, probA) as its last two arguments. b2 holds
    the terminal probabilities, keyed by ((bars, head), word).'''
    _h_ = DMV_Rule((LRBAR,0), STOP,    ( RBAR,0), 1.0, 1.0) # LSTOP
    h_S = DMV_Rule(( RBAR,0),(NOBAR,0), STOP,    0.4, 0.3) # RSTOP
    h_A = DMV_Rule(( RBAR,0),(LRBAR,0),( RBAR,0), 0.6, 0.7) # Lattach
    h   = DMV_Rule((NOBAR,0),(NOBAR,0),(LRBAR,0), 1.0, 1.0) # Rattach
    rh  = DMV_Rule( ROOT,   STOP,    (LRBAR,0), 1.0, 1.0) # ROOT

    b2 = {}  # was used without ever being created (NameError)
    b2[(NOBAR, 0), 'h'] = 1.0
    b2[(RBAR, 0), 'h'] = h_S.probA
    b2[(LRBAR, 0), 'h'] = h_S.probA * _h_.probA

    return DMV_Grammar([ rh, _h_, h_S, h_A, h ],b2,0,0,0, {0:'h'}, {'h':0})
def testreestimation_h():
    '''Run reestimate() on the toy 'h' grammar from testgrammar_h(),
    using the single sentence "h h h".'''
    g = testgrammar_h()  # g was previously used without being defined
    reestimate(g,['h h h'.split()])
def regression_tests():
    '''Check a few inner()/outer() values on the toy 'h' grammar against
    known-good numbers. For each mismatch a "Should be ..." line is
    printed; nothing is printed when everything matches.'''
    g_dup = testgrammar_h()

    # (expected text, comparison format, complaint, deferred computation)
    # The computations are lambdas so each runs right before its own
    # check, in the same order as the individual tests always ran.
    checks = [
        ("0.120", "%.3f", "Should be 0.120: %.3f",
         lambda: inner(0, 1, (LRBAR,0), 0, g_dup, 'h h'.split(), {})),
        ("0.063", "%.3f", "Should be 0.063: %.3f",
         lambda: inner(0, 1, (LRBAR,0), 1, g_dup, 'h h'.split(), {})),
        ("0.0498", "%.4f", "Should be 0.0498: %.4f",
         lambda: inner(0, 2, (LRBAR,0), 2, g_dup, 'h h h'.split(), {})),
        ("0.58", "%.2f", "Should be 0.58: %.2f",
         lambda: outer(1,2,(1,0),2,testgrammar_h(),'h h h'.split(),{},{})),
    ]
    for expected, fmt, complaint, compute in checks:
        value = compute()
        if expected != fmt % value:
            print(complaint % value)
574 if __name__
== "__main__":
575 io
.DEBUG
= ['reest_ichart']
578 # profile.run('testreestimation()')
579 # print timeit.Timer("dmv.testreestimation()",'''import dmv
580 # reload(dmv)''').timeit(1)