still working on pstop
[dmvccm.git] / src / dmv.py
blob61a090352772a199d5af47909fefbc4962f546f7
1 #### changes by KBU:
2 # 2008-05-24:
3 # - prettier printout for DMV_Rule
4 # - DMV_Rule changed a bit. head, L and R are now all pairs of the
5 # form (bars, head).
6 # - Started on P_STOP, a bit less pseudo now..
8 # 2008-05-27:
9 # - started on initialization. So far, I have frequencies for
10 # everything, very harmonic. Still need to make these into 1-summing
11 # probabilities
13 # 2008-05-28:
14 # - more work on initialization (init_freq and init_normalize),
15 # getting closer to probabilities now.
17 # 2008-05-29:
18 # - init_normalize is done, it creates p_STOP, p_ROOT and p_CHOOSE,
19 # and also adds the relevant probabilities to p_rules in a grammar.
20 # Still, each individual rule has to store both adjacent and non_adj
21 # probabilities, and inner() should be able to send some parameter
22 # which lets the rule choose... hopefully... Is this possible to do
23 # top-down even? when the sentence could be all the same words?
24 # todo: extensive testing of identical words in sentences!
25 # - frequencies (only used in initialization) are stored as strings,
26 # but in the rules and p_STOP etc, there are only numbers.
28 # 2008-05-30
29 # - copied inner() into this file, to make the very dmv-specific
30 # adjacency stuff work (have to factor that out later on, when it
31 # works).
33 # 2008-06-01
34 # - finished typing in inner_dmv(), still have to test and debug
35 # it. The chart is now four times as big since for any rule we may
36 # have attachments to either the left or the right below, which
37 # upper rules depend on, for selecting probN or probA
39 # 2008-06-03
40 # - fixed a number of little bugs in initialization, where certain
41 # rules were simply not created, or created "backwards"
42 # - inner_dmv() should Work now...
44 # 2008-06-04
45 # - moved initialization to harmonic.py
48 # import numpy # numpy provides Fast Arrays, for future optimization
49 import pprint
50 import io
51 import harmonic
# non-tweakable/constant "lookup" globals
# Bar levels of a (bars, head) node: h has no bars, h_ a right bar,
# _h_ both a left and a right bar (cf. DMV_Rule.bar_str).
NOBAR = 0
RBAR = 1
LRBAR = 2
BARS = [NOBAR, RBAR, LRBAR]
# Special nodes. Their heads are negative, presumably so they never
# collide with real tag numbers -- TODO confirm against tagnum().
ROOT = (LRBAR, -1)
STOP = (NOBAR, -2)

# todo: use these instead for attachment constants. Requires making
# the last two indices of chart[] one single pair, boring retyping and
# testing, worth it?
# Each pair is (Lattach, Rattach).
LRATT = (True, True)
LATT = (True, False)
RATT = (False, True)
NOATT = (False, False)

if __name__ == "__main__":
    print("DMV module tests:")
def node(bars, head):
    '''Build a DMV node, i.e. a (bars, head) pair.

    Kept mainly as documentation: nodes are what make up the LHS, L
    and R of every DMV_Rule.'''
    pair = (bars, head)
    return pair
def bars(node):
    '''Return the bar level, the first element of a (bars, head) node.'''
    bar_level = node[0]
    return bar_level
def head(node):
    '''Return the head, the second element of a (bars, head) node.'''
    head_num = node[1]
    return head_num
class DMV_Grammar(io.Grammar):
    '''The DMV-PCFG.

    Public members:
    p_STOP, p_ROOT, p_CHOOSE, p_terminals
    These are changed in the Maximization step, then used to set the
    new probabilities of each DMV_Rule.

    Todo: make p_terminals private? (But it has to be changeable in
    maximization step due to the short-cutting rules... could of course
    make a DMV_Grammar function to update the short-cut rules...)

    __p_rules is private, but we can still say stuff like:
    for r in g.all_rules():
        r.probN = newProbN

    What other representations do we need? (P_STOP formula uses
    deps_D(h,l/r) at least)'''
    def __str__(self):
        '''One line per rule, with heads shown through self.numtag.'''
        # join instead of += on a local that used to shadow builtin str
        return "".join("%s\n" % r.__str__(self.numtag)
                       for r in self.all_rules())

    def rules(self, LHS):
        "This function is no longer used in DMV, since we have sent_rules."
        return [r for r in self.all_rules()
                if r.head() == head(LHS) and r.bars() == bars(LHS)]

    def sent_rules(self, LHS, sent_nums):
        '''Return the rules rewriting LHS (same head and bars) whose L and
        R children both have heads among sent_nums, or are STOP.

        sent_nums is left untouched. Earlier versions appended head(STOP)
        to it on every call, so the caller's list grew by one element
        per call.'''
        # We don't want to rule out STOPs!
        allowed = set(sent_nums)
        allowed.add(head(STOP))
        lhs_head = head(LHS)
        lhs_bars = bars(LHS)
        return [r for r in self.all_rules()
                if r.head() == lhs_head and r.bars() == lhs_bars
                and head(r.L()) in allowed and head(r.R()) in allowed]

    def heads(self):
        '''Not sure yet what is needed here, or where this is needed.

        NOTE(review): this used to return the undefined global `numtag`
        (a NameError). It now returns self.numtag, the same tag lookup
        __str__ uses. Revisit once the intended return value is decided.'''
        return self.numtag

    def deps_L(self, head):
        '''Return r.L() for every rule whose head is `head`.

        NOTE(review): this is not filtered by rule type, so it also
        includes STOP and the head's own nodes. TODO decide whether only
        attachment rules should count as left dependents.'''
        return [r.L() for r in self.all_rules() if r.head() == head]

    def deps_R(self, head):
        '''Return r.R() for every rule whose head is `head`.

        NOTE(review): this is not filtered by rule type, so it also
        includes STOP and the head's own nodes. TODO decide whether only
        attachment rules should count as right dependents.'''
        return [r.R() for r in self.all_rules() if r.head() == head]

    def __init__(self, p_rules, p_terminals, p_STOP, p_CHOOSE, p_ROOT, numtag, tagnum):
        io.Grammar.__init__(self, p_rules, p_terminals, numtag, tagnum)
        self.p_STOP = p_STOP
        self.p_CHOOSE = p_CHOOSE
        self.p_ROOT = p_ROOT
class DMV_Rule(io.CNF_Rule):
    '''One binary CNF rule of the DMV-PCFG,
        LHS -> L R
    where LHS, L and R are nodes, i.e. (bars, head) pairs.

    Public members:
    probN, probA

    Private members (presumably managed by io.CNF_Rule):
    __L, __R, __LHS

    Every rule stores a non-adjacent (probN) and an adjacent (probA)
    probability. The rule type decides what these stand for:

    _h_ -> STOP h_  P( STOP|h,L, adj)
    _h_ -> STOP h_  P( STOP|h,L,non_adj)
    h_ -> h STOP    P( STOP|h,R, adj)
    h_ -> h STOP    P( STOP|h,R,non_adj)
    h_ -> _a_ h_    P(-STOP|h,L, adj) * P(a|h,L)
    h_ -> _a_ h_    P(-STOP|h,L,non_adj) * P(a|h,L)
    h -> h _a_      P(-STOP|h,R, adj) * P(a|h,R)
    h -> h _a_      P(-STOP|h,R,non_adj) * P(a|h,R)
    '''
    def p(self, LRattach, RLattach, *arg):
        '''Choose between the adjacent and non-adjacent probability.

        Any lower attachment, either on the right side of the left child
        (LRattach) or on the left side of the right child (RLattach),
        means we are no longer adjacent, so probN applies. Otherwise
        probA applies. Extra positional arguments are ignored.'''
        if LRattach or RLattach:
            return self.probN
        return self.probA

    def bars(self):
        lhs = self.LHS()
        return bars(lhs)

    def head(self):
        lhs = self.LHS()
        return head(lhs)

    def __init__(self, LHS, L, R, probN, probA):
        # Reject any node whose bar level is unknown, checking LHS, L, R in turn.
        for child in (LHS, L, R):
            level = bars(child)
            if level not in BARS:
                raise ValueError("bars must be in %s; was given: %s"
                                 % (BARS, level))
        io.CNF_Rule.__init__(self, LHS, L, R, probN)
        self.probA = probA # adjacent
        self.probN = probN # non_adj

    @classmethod # so we can call DMV_Rule.bar_str(b_h)
    def bar_str(cls, b_h, tag=lambda x:x):
        '''Render a node: ROOT/STOP by name, else the tagged head decorated
        with bars (" h ", " h_ ", "_h_ ").'''
        if b_h == ROOT:
            return 'ROOT'
        if b_h == STOP:
            return 'STOP'
        pattern = {RBAR: " %s_ ", LRBAR: "_%s_ "}.get(bars(b_h), " %s ")
        return pattern % tag(head(b_h))

    def __str__(self, tag=lambda x:x):
        lhs_s, left_s, right_s = [self.bar_str(n, tag)
                                  for n in (self.LHS(), self.L(), self.R())]
        return "%s-->%s %s\t[N %.2f] [A %.2f]" % (lhs_s, left_s, right_s,
                                                  self.probN, self.probA)
216 ###################################
217 # dmv-specific version of inner() #
218 ###################################
def rewrite_adj(bars, Lattach, Rattach):
    '''Enumerate the attachment flags allowed for the two children of an
    attachment rewrite.

    Returns a tuple of 4-tuples (LL, LR, RL, RR): left/right attachment
    below the left child, then left/right attachment below the right
    child. The tuple is empty when (bars, Lattach, Rattach) cannot come
    from an attachment rewrite, so such cases add no probability:

    - NOBAR rewrite rules cannot have Lattach below and must have/add
      Rattach. LL is then Lattach and LR, RL, RR take every combination.
    - RBAR rewrite rules must add Lattach but don't care about Rattach.
      RR is then Rattach (either value) and LL, LR, RL take every
      combination.

    Todo: make prettier? This is called very often, so keep it cheap.'''
    if bars == NOBAR and not Lattach and Rattach:
        return tuple((Lattach, lr, rl, rr)
                     for lr in (False, True)
                     for rl in (False, True)
                     for rr in (False, True))
    if bars == RBAR and Lattach:
        return tuple((ll, lr, rl, Rattach)
                     for ll in (False, True)
                     for lr in (False, True)
                     for rl in (False, True))
    return ()
def inner_dmv(s, t, LHS, g, sent, chart):
    ''' A rewrite of inner in io.py, to take adjacency into account.

    The chart is now 4 times bigger, since there are different values
    for with or without L/R attachments:
    chart[(s,t,LHS, Lattach, Rattach)]

    If Rattach==True then the rule has a right-attachment or there is
    one lower in the tree (meaning we're no longer
    adjacent). Adjacency depends on whether there is an attachment
    lower in the tree, cf. DMV_Rule.p(LRattach, RLattach).

    s and t are inclusive word positions in sent (s == t is a single
    word). LHS is the (bars, head) node to expand and g the grammar.
    chart is a dict memoizing the inner values; it is filled in place,
    so pass {} for a fresh computation.

    Returns [inner_prob, chart], where inner_prob sums the inner value
    of (s, t, LHS) over all four (Lattach, Rattach) combinations.

    Todo: if possible, refactor (move dmv-specific stuff back into
    dmv, so this is "general" enough to be in io.py)
    '''
    def debug_inner_dmv(tabs,s,t,LHS,Lattach,Rattach):
        if io.DEBUG:
            attach = {
                (True, True): "left and right attachments below",
                (True, False): "left attachment(s) below",
                (False, True): "right attachment(s) below",
                (False, False): "no attachments below" }
            info = (tabs,O(s),s,O(t),t, DMV_Rule.bar_str(LHS), attach[Lattach,Rattach])
            print("%sTrying from %s_%d to %s_%d with %s, %s:" % info)

    def O(s):
        return sent[s]

    sent_nums = [g.tagnum(tag) for tag in sent]

    def e(s,t,LHS, Lattach, Rattach, n_t):
        def tab():
            "Tabs for debug output"
            return "\t"*n_t

        key = (s, t, LHS, Lattach, Rattach)
        if key in chart:
            return chart[key]
        debug_inner_dmv(tab(),s,t,LHS, Lattach, Rattach)

        if s == t:
            if Lattach or Rattach:
                # terminals are always F,F for attachment
                io.debug("%s= 0.0 (1 word, no lower attach)" % tab())
                return 0.0
            elif (LHS, O(s)) in g.p_terminals:
                prob = g.p_terminals[LHS, O(s)] # b[LHS, O(s)] in Lari&Young
            else:
                # todo: assuming this is how to deal with lacking
                # rules, since we add prob.s, and 0 is identity
                prob = 0.0
                io.debug( "%sLACKING TERMINAL:" % tab())
            # todo: add to chart perhaps? Although, it _is_ simple lookup..
            io.debug( "%s= %.1f (terminal: %s -> %s)" % (tab(),prob,
                                                         DMV_Rule.bar_str(LHS),
                                                         O(s)) )
            return prob

        # Seed the entry so the += below can accumulate. As a side effect,
        # a recursive call on this very key reads the partial sum instead
        # of recursing forever.
        chart[key] = 0.0
        # sent_rules() may append to the list it is given, so hand it a
        # copy; otherwise sent_nums grows on every call.
        for rule in g.sent_rules(LHS, list(sent_nums)): # summing over j,k in a[LHS,j,k]
            io.debug( "%ssumming rule %s" % (tab(),rule) )
            L = rule.L()
            R = rule.R()
            # STOP rules rewrite for the same range:
            if L == STOP:
                p = rule.p(Lattach, False) # todo check
                pLR = e(s, t, R, Lattach, Rattach, n_t+1)
                chart[key] += p * pLR
            elif R == STOP:
                p = rule.p(False, Rattach) # todo check
                pLR = e(s, t, L, Lattach, Rattach, n_t+1)
                chart[key] += p * pLR
            else:
                # not a STOP, an attachment rewrite: try every split point r
                for r in range(s, t):
                    if head(L) in sent_nums[s:r+1] and head(R) in sent_nums[r+1:t+1]:
                        # LL etc are boolean attachment values
                        for (LL, LR, RL, RR) in rewrite_adj(rule.bars(), Lattach, Rattach):
                            p = rule.p(LR, RL) # probN or probA
                            pL = e(s, r, L, LL, LR, n_t+1)
                            pR = e(r+1, t, R, RL, RR, n_t+1)
                            chart[key] += p * pL * pR
        return chart[key]
    # end of e-function

    inner_prob = (e(s,t,LHS,True,True, 0) + e(s,t,LHS,True,False, 0)
                  + e(s,t,LHS,False,True, 0) + e(s,t,LHS,False,False, 0))
    if io.DEBUG:
        print("---CHART:---")
        for (c_s, c_t, c_LHS, c_L, c_R), v in chart.items():
            # the end position is now shown with its own word, O(c_t)
            print("\t%s -> %s_%d ... %s_%d (L:%s, R:%s):\t%.3f" % (DMV_Rule.bar_str(c_LHS,g.numtag),
                                                                  O(c_s), c_s,
                                                                  O(c_t), c_t,
                                                                  c_L, c_R, v))
        print("---CHART:end---")
    return [inner_prob, chart]
if __name__ == "__main__": # Non, Adj
    # Toy grammar with a single head 0 (tag 'h'). The last two numbers of
    # each DMV_Rule are probN (non-adjacent) and probA (adjacent).
    _h_ = DMV_Rule((LRBAR,0), STOP, ( RBAR,0), 1.0, 1.0) # LSTOP
    h_S = DMV_Rule(( RBAR,0),(NOBAR,0), STOP, 0.4, 0.3) # RSTOP
    h_A = DMV_Rule(( RBAR,0),(LRBAR,0),( RBAR,0), 0.6, 0.7) # Lattach
    h = DMV_Rule((NOBAR,0),(NOBAR,0),(LRBAR,0), 1.0, 1.0) # Rattach
    # Short-cut terminal probabilities for each bar level of head 0.
    b2 = {
        ((NOBAR, 0), 'h'): 1.0,
        ((RBAR, 0), 'h'): h_S.probA,
        ((LRBAR, 0), 'h'): h_S.probA * _h_.probA,
    }
    g_dup = DMV_Grammar([ _h_, h_S, h_A, h ], b2, 0, 0, 0, {0:'h'}, {'h':0})

    io.DEBUG = 0
    test1 = inner_dmv(0, 1, (LRBAR,0), g_dup, 'h h'.split(), {})
    if "%.3f" % test1[0] != "0.183":
        print("Should be 0.183: %.3f" % test1[0])
371 ##############################
372 # DMV-probabilities, todo: #
373 ##############################
def P_CHOOSE():
    '''Placeholder for the attachment probability; not implemented yet.'''
    placeholder = "todo"
    return placeholder
def DMV(sent, g):
    '''Here it seems like they store rule information on a per-head (per
    direction) basis, in deps_D(h, dir) which gives us a list.

    NOTE(review): this is still pseudocode and cannot run as written.
    deps, D, root and adj are not defined in this file, so calling DMV()
    stops with a NameError at root(sent). Also, P_STOP takes six
    arguments and P_CHOOSE none, but they are called below with four and
    three, and `STOP | h` is a bitwise or on a tuple, not a conditional.'''
    def P_h(h):
        # NOTE(review): assigning P_h here makes it a local number for the
        # whole function body, so the recursive P_h(D(a)) below would call
        # that number rather than this function. Rename the accumulator.
        P_h = 1 # ?
        for dir in ['l', 'r']:
            for a in deps(h, dir):
                # D(a)??
                P_h *= \
                    P_STOP (0, h, dir, adj) * \
                    P_CHOOSE (a, h, dir) * \
                    P_h(D(a)) * \
                    P_STOP (STOP | h, dir, adj)
        return P_h
    return P_h(root(sent))
def P_STOP(STOP, h, dir, adj, g, corpus):
    '''corpus is a list of sentences s (each a list of tags).

    This is based on the formula where STOP is True... not sure how we
    calculate if STOP is False.

    NOTE(review): STOP, dir and adj are currently ignored. The result is
    always the sum of inner(s,t,_h_) over the sum of inner(s,t,h_), for
    every occurrence loc(h) of h in every sentence and every span with
    s < loc(h) <= t. It returns 0.0 when the denominator is 0.

    I thought about instead having this:

    for rule in g.p_rules:
        rule.num = 0
        rule.den = 0
    for sent in corpus:
        for rule in g.p_rules:
            for s:
                for t:
                    set num and den using inner
    for rule in g.p_rules
        rule.prob = rule.num / rule.den

    ..the way I'm assuming we do it in the commented out io-function in
    io.py. Having sentences as the outer loop at least we can easily just
    go through the heads that are actually in the sentence... BUT, this
    means having to go through p_rules 3 times, not sure what is slower.

    Also, now inner_dmv makes sure it only goes through heads that are
    actually in the sentence, so that argument falls.

    oh, and:
    P_STOP(-STOP|...) = 1 - P_STOP(STOP|...)
    '''
    P_STOP_num = 0
    P_STOP_den = 0
    h_tag = g.numtag(h)
    for sent in corpus:
        # have to go through _all_ places where h appears in the
        # sentence...how? how to make sure it _works_?
        locs_h = [i for i,w in enumerate(sent) if w == h_tag]
        for loc_h in locs_h:
            if io.DEBUG:
                print(h_tag)
                print(loc_h)
            for s in range(loc_h): # s<loc(h), range gives strictly less
                # t >= loc(h), so the span s..t contains h. This used to
                # be range(i, ...), with i leaking out of the list
                # comprehension above (the last index of the sentence).
                for t in range(loc_h, len(sent)):
                    P_STOP_num += inner_dmv(s, t, (LRBAR,h), g, sent, {})[0]
                    P_STOP_den += inner_dmv(s, t, (RBAR,h), g, sent, {})[0]
    if io.DEBUG:
        print("den/num %.3f / %.3f"%(P_STOP_den, P_STOP_num))
    if P_STOP_den > 0.0:
        return P_STOP_num / P_STOP_den # upside down in article
    else:
        return 0.0
if __name__ == "__main__":
    pass
    # Disabled P_STOP experiment. NOTE(review): g_abc and corpus_abc are
    # not defined in this file, so they must be supplied before these
    # lines can be re-enabled.
    # inner_dmv(0, 2, ROOT, g_abc, 'det nn vbd'.split(), {})
    # io.DEBUG = 0
    # print g_abc
    # print P_STOP(True, 0, 'L', 'N', g_abc, corpus_abc)
# todo: some more testing on the Brown corpus:
if __name__ == "__main__":
    pass
    # Disabled: builds a grammar from Brown tag sequences with
    # harmonic.initialize and runs inner_dmv on a 4-tag prefix.
    # # first ten sentences of the Brown corpus:
    # g_brown = harmonic.initialize([['AT', 'NP-TL', 'NN-TL', 'JJ-TL', 'NN-TL', 'VBD', 'NR', 'AT', 'NN', 'IN', 'NP$', 'JJ', 'NN', 'NN', 'VBD', '``', 'AT', 'NN', "''", 'CS', 'DTI', 'NNS', 'VBD', 'NN', '.'], ['AT', 'NN', 'RBR', 'VBD', 'IN', 'NN', 'NNS', 'CS', 'AT', 'NN-TL', 'JJ-TL', 'NN-TL', ',', 'WDT', 'HVD', 'JJ', 'NN', 'IN', 'AT', 'NN', ',', '``', 'VBZ', 'AT', 'NN', 'CC', 'NNS', 'IN', 'AT', 'NN-TL', 'IN-TL', 'NP-TL', "''", 'IN', 'AT', 'NN', 'IN', 'WDT', 'AT', 'NN', 'BEDZ', 'VBN', '.'], ['AT', 'NP', 'NN', 'NN', 'HVD', 'BEN', 'VBN', 'IN', 'NP-TL', 'JJ-TL', 'NN-TL', 'NN-TL', 'NP', 'NP', 'TO', 'VB', 'NNS', 'IN', 'JJ', '``', 'NNS', "''", 'IN', 'AT', 'JJ', 'NN', 'WDT', 'BEDZ', 'VBN', 'IN', 'NN-TL', 'NP', 'NP', 'NP', '.'], ['``', 'RB', 'AT', 'JJ', 'NN', 'IN', 'JJ', 'NNS', 'BEDZ', 'VBN', "''", ',', 'AT', 'NN', 'VBD', ',', '``', 'IN', 'AT', 'JJ', 'NN', 'IN', 'AT', 'NN', ',', 'AT', 'NN', 'IN', 'NNS', 'CC', 'AT', 'NN', 'IN', 'DT', 'NN', "''", '.'], ['AT', 'NN', 'VBD', 'PPS', 'DOD', 'VB', 'CS', 'AP', 'IN', 'NP$', 'NN', 'CC', 'NN', 'NNS', '``', 'BER', 'JJ', 'CC', 'JJ', 'CC', 'RB', 'JJ', "''", '.'], ['PPS', 'VBD', 'CS', 'NP', 'NNS', 'VB', '``', 'TO', 'HV', 'DTS', 'NNS', 'VBN', 'CC', 'VBN', 'IN', 'AT', 'NN', 'IN', 'VBG', 'CC', 'VBG', 'PPO', "''", '.'], ['AT', 'JJ', 'NN', 'VBD', 'IN', 'AT', 'NN', 'IN', 'AP', 'NNS', ',', 'IN', 'PPO', 'AT', 'NP', 'CC', 'NP-TL', 'NN-TL', 'VBG', 'NNS', 'WDT', 'PPS', 'VBD', '``', 'BER', 'QL', 'VBN', 'CC', 'VB', 'RB', 'VBN', 'NNS', 'WDT', 'VB', 'IN', 'AT', 'JJT', 'NN', 'IN', 'ABX', 'NNS', "''", '.'], ['NN-HL', 'VBN-HL'], ['WRB', ',', 'AT', 'NN', 'VBD', 'PPS', 'VBZ', '``', 'DTS', 'CD', 'NNS', 'MD', 'BE', 'VBN', 'TO', 'VB', 'JJR', 'NN', 'CC', 'VB', 'AT', 'NN', 'IN', 'NN', "''", '.'], ['AT', 'NN-TL', 'VBG-TL', 'NN-TL', ',', 'AT', 'NN', 'VBD', ',', '``', 'BEZ', 'VBG', 'IN', 'VBN', 'JJ', 'NNS', 'CS', 'AT', 'NN', 'IN', 'NN', 'NNS', 'NNS', "''", '.']])
    # # 36:'AT' in g_brown.numtag, 40:'NP-TL'

    # io.DEBUG = 0
    # test_brown = inner_dmv(0,2, (LRBAR,36), g_brown, ['AT', 'NP-TL' ,'NN-TL','JJ-TL'], {})
    # if io.DEBUG:
    #     for r in g_brown.rules((2,36)) + g_brown.rules((1,36)) + g_brown.rules((0,36)):
    #         L = r.L()
    #         R = r.R()
    #         if head(L) in [36,40,-2] and head(R) in [36,40,-2]:
    #             print r
    # print "Brown-test gives: %.8f" % test_brown[0]
482 # this will give the tag sequences of all the 6218 Brown corpus
483 # sentences of length < 7:
484 # [[tag for (w, tag) in sent]
485 # for sent in nltk.corpus.brown.tagged_sents() if len(sent) < 7]