# todo: fix P_STOP so it actually changes stuff
#### changes by KBU:
# 2008-05-24:
# - prettier printout for DMV_Rule
# - DMV_Rule changed a bit. head, L and R are now all pairs of the
#   form (bars, head).
# - Started on P_STOP, a bit less pseudo now..

# 2008-05-27:
# - started on initialization. So far, I have frequencies for
#   everything, very harmonic. Still need to make these into 1-summing
#   probabilities

# 2008-05-28:
# - more work on initialization (init_freq and init_normalize),
#   getting closer to probabilities now.

# 2008-05-29:
# - init_normalize is done, it creates p_STOP, p_ROOT and p_CHOOSE,
#   and also adds the relevant probabilities to p_rules in a grammar.
#   Still, each individual rule has to store both adjacent and non_adj
#   probabilities, and inner() should be able to send some parameter
#   which lets the rule choose... hopefully... Is this possible to do
#   top-down even? when the sentence could be all the same words?
#   todo: extensive testing of identical words in sentences!
# - frequencies (only used in initialization) are stored as strings,
#   but in the rules and p_STOP etc, there are only numbers.

# 2008-05-30
# - copied inner() into this file, to make the very dmv-specific
#   adjacency stuff work (have to factor that out later on, when it
#   works).

# 2008-06-01
# - finished typing in inner_dmv(), still have to test and debug
#   it. The chart is now four times as big since for any rule we may
#   have attachments to either the left or the right below, which
#   upper rules depend on, for selecting probN or probA

# 2008-06-03
# - fixed a number of little bugs in initialization, where certain
#   rules were simply not created, or created "backwards"
# - inner_dmv() should Work now...

# 2008-06-04
# - moved initialization to harmonic.py
# import numpy # numpy provides Fast Arrays, for future optimization
import pprint
import io
import harmonic

# non-tweakable/constant "lookup" globals
BARS = [0,1,2]
RBAR = 1
LRBAR = 2
NOBAR = 0
ROOT = (LRBAR, -1)
STOP = (NOBAR, -2)
if __name__ == "__main__":
    print "DMV module tests:"
def node(bars, head):
    '''Useless function, but just here as documentation. Nodes make up
    LHS, R and L in each DMV_Rule'''
    return (bars, head)

def bars(node):
    return node[0]

def head(node):
    return node[1]
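
# Reading the node encoding (from bar_str and the DMV_Rule table
# below): for a head with tag number 3, (NOBAR,3) prints as " 3 " (the
# bare head, still taking right dependents), (RBAR,3) as " 3_ " (done
# on the right, taking left dependents), and (LRBAR,3) as "_3_ " (the
# fully sealed head).
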
class DMV_Grammar(io.Grammar):
    '''The DMV-PCFG.

    Public members:
    p_STOP, p_ROOT, p_CHOOSE, p_terminals
    These are changed in the maximization step, then used to set the
    new probabilities of each DMV_Rule.

    Todo: make p_terminals private? (But it has to be changeable in
    the maximization step due to the short-cutting rules... could of
    course make a DMV_Grammar function to update the short-cut
    rules...)

    __p_rules is private, but we can still say stuff like:
    for r in g.all_rules():
        r.probN = newProbN

    What other representations do we need? (The P_STOP formula uses
    deps_D(h,l/r) at least.)'''

    def __str__(self):
        s = ""
        for r in self.all_rules():
            s += "%s\n" % r.__str__(self.numtag)
        return s
    def h_rules(self, h):
        return [r for r in self.all_rules() if r.head() == h]

    def rules(self, LHS):
        return [r for r in self.all_rules() if r.LHS() == LHS]

    def sent_rules(self, LHS, sent_nums):
        "Used in inner_dmv."
        # We don't want to rule out STOPs, so allow head(STOP) too
        # (added to a copy so we don't keep growing the caller's list):
        nums = sent_nums + [ head(STOP) ]
        return [r for r in self.all_rules() if r.LHS() == LHS
                and head(r.L()) in nums and head(r.R()) in nums]
    def heads(self):
        '''Not sure yet what is needed here, or where this is needed'''
        return self.numtag
    def deps_L(self, head):
        "The left children of all rules headed by head."
        return [r.L() for r in self.all_rules() if r.head() == head]

    def deps_R(self, head):
        "The right children of all rules headed by head."
        return [r.R() for r in self.all_rules() if r.head() == head]
    def __init__(self, p_rules, p_terminals, p_STOP, p_CHOOSE, p_ROOT, numtag, tagnum):
        io.Grammar.__init__(self, p_rules, p_terminals, numtag, tagnum)
        self.p_STOP = p_STOP
        self.p_CHOOSE = p_CHOOSE
        self.p_ROOT = p_ROOT
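
# A minimal usage sketch of DMV_Grammar's update pattern, as in the
# docstring above (new_probN/new_probA are hypothetical placeholders
# for whatever the maximization step computes):
#
#   for r in g.all_rules():
#       r.probN = new_probN[r.LHS(), r.L(), r.R()]
#       r.probA = new_probA[r.LHS(), r.L(), r.R()]
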
class DMV_Rule(io.CNF_Rule):
    '''A single CNF rule in the PCFG, of the form
    LHS -> L R
    where LHS, L and R are 'nodes', eg. of the form (bars, head).

    Public members:
    probN, probA

    Private members:
    __L, __R, __LHS

    Different rule-types have different probabilities associated with
    them:

    _h_ -> STOP  h_    P( STOP|h,L, adj)
    _h_ -> STOP  h_    P( STOP|h,L,non_adj)
     h_ ->  h   STOP   P( STOP|h,R, adj)
     h_ ->  h   STOP   P( STOP|h,R,non_adj)
     h_ -> _a_   h_    P(-STOP|h,L, adj) * P(a|h,L)
     h_ -> _a_   h_    P(-STOP|h,L,non_adj) * P(a|h,L)
     h  ->  h   _a_    P(-STOP|h,R, adj) * P(a|h,R)
     h  ->  h   _a_    P(-STOP|h,R,non_adj) * P(a|h,R)
    '''
    def p(self, adj, *arg):
        if adj:
            return self.probA
        else:
            return self.probN
    def p_STOP(self, s, t, loc_h):
        '''Returns the correct probability, adjacent or non-adjacent,
        depending on whether or not we're at the (either left or
        right) end of the fragment. '''
        if self.L() == STOP:
            return self.p(s == loc_h)
        elif self.R() == STOP:
            if not loc_h == s:
                io.debug( "(%s given loc_h:%d but s:%d. Todo: optimize away!)"
                          % (self, loc_h, s) )
                return 0.0
            else:
                return self.p(t == loc_h)
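
    # Example of how p()/p_STOP pick probA vs probN: for the left-stop
    # rule _h_ -> STOP h_ applied to the fragment s..t with the head at
    # loc_h, p_STOP(s, t, loc_h) returns probA when s == loc_h (the head
    # sits at the left edge, so the left stop is adjacent) and probN
    # otherwise.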
    def p_ATTACH(self, r, loc_L, loc_R, s=None):
        '''Returns the correct probability, adjacent or non-adjacent,
        depending on whether or not there is some lower attachment
        either on the right side of the left child, or the left side
        of the right child. '''
        if self.LHS() == self.L() and not loc_L == s:
            io.debug( "(%s given loc_h (loc_L):%d but s:%d. Todo: optimize away!)"
                      % (self, loc_L, s) )
            return 0.0
        else:
            return self.p(r == loc_L and r+1 == loc_R)
    def bars(self):
        return bars(self.LHS())

    def head(self):
        return head(self.LHS())

    def __init__(self, LHS, L, R, probN, probA):
        for b_h in [LHS, L, R]:
            if bars(b_h) not in BARS:
                raise ValueError("bars must be in %s; was given: %s"
                                 % (BARS, bars(b_h)))
        io.CNF_Rule.__init__(self, LHS, L, R, probN)
        self.probA = probA # adjacent
        self.probN = probN # non_adj
    @classmethod # so we can call DMV_Rule.bar_str(b_h)
    def bar_str(cls, b_h, tag=lambda x:x):
        if(b_h == ROOT):
            return 'ROOT'
        elif(b_h == STOP):
            return 'STOP'
        elif(bars(b_h) == RBAR):
            return " %s_ " % tag(head(b_h))
        elif(bars(b_h) == LRBAR):
            return "_%s_ " % tag(head(b_h))
        else:
            return " %s " % tag(head(b_h))

    def __str__(self, tag=lambda x:x):
        return "%s-->%s %s\t[N %.2f] [A %.2f]" % (self.bar_str(self.LHS(), tag),
                                                  self.bar_str(self.L(), tag),
                                                  self.bar_str(self.R(), tag),
                                                  self.probN,
                                                  self.probA)
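
# A small constructed example (same numbers as the test grammar
# further down): DMV_Rule((RBAR,0), (NOBAR,0), STOP, 0.4, 0.3) is the
# right-stop rule 0_ -> 0 STOP, with r.p(False) == probN == 0.4
# (non-adjacent) and r.p(True) == probA == 0.3 (adjacent); print r
# shows both probabilities via __str__.
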
###################################
# dmv-specific version of inner() #
###################################

def locs(h, sent, s=0, t=None):
    '''Return the locations of h in sent, or some fragment of sent (in the
    latter case we make sure to offset the locations correctly so that
    for any x in the returned list, sent[x]==h).'''
    if t is None:
        t = len(sent)
    return [i+s for i,w in enumerate(sent[s:t]) if w == h]
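
# Example:
#   locs(0, [0, 1, 0])       == [0, 2]
#   locs(0, [0, 1, 0], 1, 3) == [2]    (offsets by s, so sent[2] == 0)
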
def inner_dmv(s, t, LHS, loc_h, g, sent, chart):
    '''A rewrite of inner in io.py, to take adjacency into account.

    The chart is now of this form:
    chart[(s,t,LHS, loc_h)]

    loc_h gives adjacency (along with r and location of other child
    for attachment rules), and is needed in P_STOP reestimation.

    Todo: if possible, refactor (move the dmv-specific stuff back into
    dmv, so this is "general" enough to be in io.py).
    '''
    def O(s):
        return sent[s]

    sent_nums = [g.tagnum(tag) for tag in sent]

    def e(s,t,LHS, loc_h, n_t):
        def tab():
            "Tabs for debug output"
            return "\t"*n_t

        if (s, t, LHS, loc_h) in chart:
            io.debug("%s*= %.4f in chart: s:%d t:%d LHS:%s loc:%d"
                     %(tab(),chart[(s, t, LHS, loc_h)], s, t,
                       DMV_Rule.bar_str(LHS), loc_h))
            return chart[(s, t, LHS, loc_h)]
        else:
            if s == t:
                if not loc_h == s:
                    # terminals are always F,F for attachment
                    io.debug("%s*= 0.0 (wrong loc_h)" % tab())
                    return 0.0
                elif (LHS, O(s)) in g.p_terminals:
                    prob = g.p_terminals[LHS, O(s)] # "b[LHS, O(s)]" in Lari&Young
                else:
                    # todo: assuming this is how to deal w/lacking
                    # rules, since we add prob.s, and 0 is identity
                    prob = 0.0
                    io.debug( "%sLACKING TERMINAL:" % tab())
                # todo: add to chart perhaps? Although, it _is_ simple lookup..
                io.debug( "%s*= %.4f (terminal: %s -> %s_%d)"
                          % (tab(),prob, DMV_Rule.bar_str(LHS), O(s), loc_h) )
                return prob
            else:
                p = 0.0 # "sum over j,k in a[LHS,j,k]"
                for rule in g.sent_rules(LHS, sent_nums):
                    io.debug( "%ssumming rule %s s:%d t:%d loc:%d" % (tab(),rule,s,t,loc_h) )
                    L = rule.L()
                    R = rule.R()
                    # if it's a STOP rule, rewrite for the same range:
                    if (L == STOP) or (R == STOP):
                        if L == STOP:
                            pLR = e(s, t, R, loc_h, n_t+1)
                        elif R == STOP:
                            pLR = e(s, t, L, loc_h, n_t+1)
                        p += rule.p_STOP(s, t, loc_h) * pLR
                        io.debug( "%sp= %.4f (STOP)" % (tab(), p) )

                    else: # not a STOP, an attachment rewrite:
                        for r in range(s, t):
                            if rule.LHS() == L:
                                locs_L = [loc_h]
                                locs_R = locs(head(R), sent_nums, r+1, t+1)
                            elif rule.LHS() == R:
                                locs_L = locs(head(L), sent_nums, s, r+1)
                                locs_R = [loc_h]
                            # see http://tinyurl.com/4ffhhw
                            p += sum([e(s, r, L, loc_L, n_t+1) *
                                      rule.p_ATTACH(r, loc_L, loc_R, s=s) *
                                      e(r+1, t, R, loc_R, n_t+1)
                                      for loc_L in locs_L
                                      for loc_R in locs_R])
                        io.debug( "%sp= %.4f (ATTACH)" % (tab(), p) )
                chart[(s, t, LHS, loc_h)] = p
                return p
    # end of e-function
    inner_prob = e(s,t,LHS,loc_h, 0)
    if io.DEBUG:
        print "---CHART:---"
        for (s,t,LHS,loc_h),v in chart.iteritems():
            print "%s -> %s_%d ... %s_%d (loc_h:%s):\t%.3f" % (DMV_Rule.bar_str(LHS,g.numtag),
                                                               O(s), s, O(t), t, loc_h, v)
        print "---CHART:end---"
    return [inner_prob, chart]
# end of inner_dmv(s, t, LHS, loc_h, g, sent, chart)
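
# The chart maps (s, t, LHS, loc_h) to inner probabilities; e.g. after
# test0 below, test0[1][(0, 1, (LRBAR,0), 0)] is approximately 0.120.
# (Terminals are not cached in the chart, they are plain p_terminals
# lookups.)
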
if __name__ == "__main__": # Non, Adj
    _h_ = DMV_Rule((LRBAR,0), STOP, ( RBAR,0), 1.0, 1.0) # LSTOP
    h_S = DMV_Rule(( RBAR,0),(NOBAR,0), STOP, 0.4, 0.3) # RSTOP
    h_A = DMV_Rule(( RBAR,0),(LRBAR,0),( RBAR,0), 0.6, 0.7) # Lattach
    h = DMV_Rule((NOBAR,0),(NOBAR,0),(LRBAR,0), 1.0, 1.0) # Rattach
    b2 = {}
    b2[(NOBAR, 0), 'h'] = 1.0
    b2[(RBAR, 0), 'h'] = h_S.probA
    b2[(LRBAR, 0), 'h'] = h_S.probA * _h_.probA

    g_dup = DMV_Grammar([ _h_, h_S, h_A, h ],b2,0,0,0, {0:'h'}, {'h':0})

    io.DEBUG = 0
    test0 = inner_dmv(0, 1, (LRBAR,0), 0, g_dup, 'h h'.split(), {})
    if not "0.120"=="%.3f" % test0[0]:
        print "Should be 0.120: %.3f" % test0[0]

    test1 = inner_dmv(0, 1, (LRBAR,0), 1, g_dup, 'h h'.split(), {})
    if not "0.063"=="%.3f" % test1[0]:
        print "Should be 0.063: %.3f" % test1[0]

    test3 = inner_dmv(0, 2, (LRBAR,0), 2, g_dup, 'h h h'.split(), {})
    if not "0.0462"=="%.4f" % test3[0]:
        print "Should be 0.0462: %.4f" % test3[0]
##############################
# DMV-probabilities, todo:   #
##############################

def P_CHOOSE():
    return "todo"
def DMV(sent, g):
    '''Pseudocode sketch only. Here it seems like they store rule
    information on a per-head (per direction) basis, in deps_D(h, dir)
    which gives us a list. '''
    def P_h(h):
        p = 1 # the accumulated probability of h's subtree
        for dir in ['l', 'r']:
            for a in deps(h, dir):
                # D(a)??
                p *= \
                    P_STOP (0, h, dir, adj) * \
                    P_CHOOSE (a, h, dir) * \
                    P_h(D(a)) * \
                    P_STOP (STOP | h, dir, adj)
        return p
    return P_h(root(sent))
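
# For reference, the per-head subtree probability this sketches is
# roughly (adj says whether h has taken any dependent in direction dir
# yet):
#
#   P(D(h)) = prod_{dir in l,r}
#               ( [ prod_{a in deps_D(h,dir)}
#                     P_STOP(-STOP|h,dir,adj) * P_CHOOSE(a|h,dir) * P(D(a)) ]
#                 * P_STOP(STOP|h,dir,adj) )
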
def P_STOP(STOP, h, dir, adj, g, corpus):
    '''corpus is a list of sentences s.

    This is based on the formula where STOP is True... not sure how we
    calculate it if STOP is False.

    I thought about instead having this:

    for rule in g.p_rules:
        rule.num = 0
        rule.den = 0
    for sent in corpus:
        for rule in g.p_rules:
            for s:
                for t:
                    set num and den using inner
    for rule in g.p_rules:
        rule.prob = rule.num / rule.den

    ..the way I'm assuming we do it in the commented-out io-function in
    io.py. Having sentences as the outer loop at least lets us easily
    just go through the heads that are actually in the sentence... BUT,
    this means having to go through p_rules 3 times, and I'm not sure
    which is slower.

    Also, now inner_dmv makes sure it only goes through heads that are
    actually in the sentence, so that argument falls.

    oh, and:
    P_STOP(-STOP|...) = 1 - P_STOP(STOP|...)
    '''
    P_STOP_num = 0
    P_STOP_den = 0
    h_tag = g.numtag(h)
    for sent in corpus:
        # have to go through _all_ places where h appears in the
        # sentence...how? how to make sure it _works_?
        chart = {} # cuts time from 17s to 7s !
        if h_tag in sent:
            locs_h = locs(h_tag, sent)
            io.debug( "locs_h:%s, sent:%s"%(locs_h,sent) )
            for loc_h in locs_h:
                for s in range(loc_h): # s<loc(h), range gives strictly less
                    for t in range(loc_h, len(sent)): # should not be range(s,..), right? todo
                        P_STOP_num += inner_dmv(s, t, (LRBAR,h), loc_h, g, sent, chart)[0]
                        P_STOP_den += inner_dmv(s, t, (RBAR,h), loc_h, g, sent, chart)[0]

    if P_STOP_den > 0.0:
        io.debug( "num/den: %s / %s = %s"%(P_STOP_num, P_STOP_den, P_STOP_num / P_STOP_den))
        return P_STOP_num / P_STOP_den # upside down in article
    else:
        return 0.0
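
# Usage sketch (cf. testreestimation below): with g an initialized
# DMV_Grammar and corpus a list of tag-lists,
#   P_STOP(True, g.tagnum('nn'), 'L', 'N', g, corpus)
# gives a re-estimated P(STOP|nn,left,non-adjacent). Note that the dir
# and adj arguments are not used inside the function yet.
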
def testreestimation():
    testcorpus = [s.split() for s in ['det nn vbd c vbd','det nn vbd c nn vbd pp',
                                      'det nn vbd', 'det vbd nn c vbd pp',
                                      'det nn vbd', 'det vbd c nn vbd pp',
                                      'det nn vbd', 'det nn vbd nn c vbd pp',
                                      'det nn vbd', 'det nn vbd c det vbd pp',
                                      'det nn vbd', 'det nn vbd c vbd det det det pp',
                                      'det nn vbd', 'det nn vbd c vbd pp',
                                      'det nn vbd', 'det nn vbd c vbd det pp',
                                      'det nn vbd', 'det nn vbd c vbd pp',
                                      'det nn vbd pp', 'det nn vbd det', ]]
    g = harmonic.initialize(testcorpus)

    h_tag = 'nn'
    h = g.tagnum(h_tag)
    print "This will take some time. todo: figure out why it doesn't work"
    for r in g.h_rules(h):
        if r.L()==STOP:
            print r
            # print "off-set the rule, see what happens:"
            # r.probN = 0.7
            # print r
    for i in range(3):
        pstophln = P_STOP(True, h, 'L', 'N', g, testcorpus)
        print "p(STOP|%s,L,N):%s"%(h_tag,pstophln)

        for r in g.h_rules(h):
            if r.L()==STOP:
                print r
                r.probN = pstophln
                print r
    return "todo"
def testreestimation_h():
    _h_ = DMV_Rule((LRBAR,0), STOP, ( RBAR,0), 1.0, 1.0) # LSTOP
    h_S = DMV_Rule(( RBAR,0),(NOBAR,0), STOP, 0.4, 0.3) # RSTOP
    h_A = DMV_Rule(( RBAR,0),(LRBAR,0),( RBAR,0), 0.6, 0.7) # Lattach
    h = DMV_Rule((NOBAR,0),(NOBAR,0),(LRBAR,0), 1.0, 1.0) # Rattach
    b2 = {}
    b2[(NOBAR, 0), 'h'] = 1.0
    b2[(RBAR, 0), 'h'] = h_S.probA
    b2[(LRBAR, 0), 'h'] = h_S.probA * _h_.probA

    g_dup = DMV_Grammar([ _h_, h_S, h_A, h ],b2,0,0,0, {0:'h'}, {'h':0})

#    test3 = inner_dmv(0, 2, (LRBAR,0), 2, g_dup, 'h h h'.split(), {})
    h_tag = 'h'
    h = 0
    print "todo: figure out why it doesn't work"
    for r in g_dup.h_rules(h):
        if r.L()==STOP:
            print r
            # print "off-set the rule, see what happens:"
            # r.probN = 0.7
            # print r
    for i in range(3):
        pstophln = P_STOP(True, h, 'L', 'N', g_dup, ['h h h'.split()])
        print "p(STOP|%s,L,N):%s"%(h_tag,pstophln)

        for r in g_dup.h_rules(h):
            if r.L()==STOP:
                print r
                r.probN = pstophln
                print r
    return "todo"
if __name__ == "__main__":
    io.DEBUG = 0
    import timeit
    timeit.Timer("dmv.testreestimation_h()",'''import dmv
reload(dmv)''').timeit(1)
    pass
# todo: some more testing on the Brown corpus:

# # first ten sentences of the Brown corpus:
# g_brown = harmonic.initialize([['AT', 'NP-TL', 'NN-TL', 'JJ-TL', 'NN-TL', 'VBD', 'NR', 'AT', 'NN', 'IN', 'NP$', 'JJ', 'NN', 'NN', 'VBD', '``', 'AT', 'NN', "''", 'CS', 'DTI', 'NNS', 'VBD', 'NN', '.'], ['AT', 'NN', 'RBR', 'VBD', 'IN', 'NN', 'NNS', 'CS', 'AT', 'NN-TL', 'JJ-TL', 'NN-TL', ',', 'WDT', 'HVD', 'JJ', 'NN', 'IN', 'AT', 'NN', ',', '``', 'VBZ', 'AT', 'NN', 'CC', 'NNS', 'IN', 'AT', 'NN-TL', 'IN-TL', 'NP-TL', "''", 'IN', 'AT', 'NN', 'IN', 'WDT', 'AT', 'NN', 'BEDZ', 'VBN', '.'], ['AT', 'NP', 'NN', 'NN', 'HVD', 'BEN', 'VBN', 'IN', 'NP-TL', 'JJ-TL', 'NN-TL', 'NN-TL', 'NP', 'NP', 'TO', 'VB', 'NNS', 'IN', 'JJ', '``', 'NNS', "''", 'IN', 'AT', 'JJ', 'NN', 'WDT', 'BEDZ', 'VBN', 'IN', 'NN-TL', 'NP', 'NP', 'NP', '.'], ['``', 'RB', 'AT', 'JJ', 'NN', 'IN', 'JJ', 'NNS', 'BEDZ', 'VBN', "''", ',', 'AT', 'NN', 'VBD', ',', '``', 'IN', 'AT', 'JJ', 'NN', 'IN', 'AT', 'NN', ',', 'AT', 'NN', 'IN', 'NNS', 'CC', 'AT', 'NN', 'IN', 'DT', 'NN', "''", '.'], ['AT', 'NN', 'VBD', 'PPS', 'DOD', 'VB', 'CS', 'AP', 'IN', 'NP$', 'NN', 'CC', 'NN', 'NNS', '``', 'BER', 'JJ', 'CC', 'JJ', 'CC', 'RB', 'JJ', "''", '.'], ['PPS', 'VBD', 'CS', 'NP', 'NNS', 'VB', '``', 'TO', 'HV', 'DTS', 'NNS', 'VBN', 'CC', 'VBN', 'IN', 'AT', 'NN', 'IN', 'VBG', 'CC', 'VBG', 'PPO', "''", '.'], ['AT', 'JJ', 'NN', 'VBD', 'IN', 'AT', 'NN', 'IN', 'AP', 'NNS', ',', 'IN', 'PPO', 'AT', 'NP', 'CC', 'NP-TL', 'NN-TL', 'VBG', 'NNS', 'WDT', 'PPS', 'VBD', '``', 'BER', 'QL', 'VBN', 'CC', 'VB', 'RB', 'VBN', 'NNS', 'WDT', 'VB', 'IN', 'AT', 'JJT', 'NN', 'IN', 'ABX', 'NNS', "''", '.'], ['NN-HL', 'VBN-HL'], ['WRB', ',', 'AT', 'NN', 'VBD', 'PPS', 'VBZ', '``', 'DTS', 'CD', 'NNS', 'MD', 'BE', 'VBN', 'TO', 'VB', 'JJR', 'NN', 'CC', 'VB', 'AT', 'NN', 'IN', 'NN', "''", '.'], ['AT', 'NN-TL', 'VBG-TL', 'NN-TL', ',', 'AT', 'NN', 'VBD', ',', '``', 'BEZ', 'VBG', 'IN', 'VBN', 'JJ', 'NNS', 'CS', 'AT', 'NN', 'IN', 'NN', 'NNS', 'NNS', "''", '.']])

# # 36:'AT' in g_brown.numtag, 40:'NP-TL'

# io.DEBUG = 0
# test_brown = inner_dmv(0,2, (LRBAR,36), g_brown, ['AT', 'NP-TL' ,'NN-TL','JJ-TL'], {})
# if io.DEBUG:
#     for r in g_brown.rules((2,36)) + g_brown.rules((1,36)) + g_brown.rules((0,36)):
#         L = r.L()
#         R = r.R()
#         if head(L) in [36,40,-2] and head(R) in [36,40,-2]:
#             print r
# print "Brown-test gives: %.8f" % test_brown[0]
# this will give the tag sequences of all the 6218 Brown corpus
# sentences of length < 7:
# [[tag for (w, tag) in sent]
#  for sent in nltk.corpus.brown.tagged_sents() if len(sent) < 7]