My first backup
[dmvccm.git] / src / dmv.py~20080528~
blobe995222bc1ed9fab24aa59710ad6d2d26a8fc104
1 #### changes by KBU:
2 # 2008-05-24:
3 # - prettier printout for DMV_Rule
4 # - DMV_Rule changed a bit. head, L and R are now all pairs of the
5 #   form (bars, head).
6 # - Started on P_STOP, a bit less pseudo now..
8 # 2008-05-27:
9 # - started on initialization. So far, I have frequencies for
10 #   everything, very harmonic. Still need to make these into 1-summing
11 #   probabilities
13 # import numpy # numpy provides Fast Arrays, for future optimization
15 import io
16 BARS = [0,1,2]
17 RBAR = 1
18 LRBAR = 2
19 NOBAR = 0
20 # is this the best way to do ROOT and STOP?
21 ROOT = (LRBAR, -1) 
22 STOP = -2 
24 # the following will only print when in this module:
25 if __name__ == "__main__":
26     print "DMV-module tests:"
27     a = io.CNF_Rule(3,1,3,0.1)
28     if(a.head != 3):
29         print "import io not working"
31 class DMV_Grammar(io.Grammar):
32     '''The DMV-PCFG.
34     We need to be able to access rules per mother node, sum every H, every
35     H_, ..., every H', every H'_, etc. for the IO-algorithm.
37     What other representations do we need? (P_STOP formula uses
38     deps_D(h,l/r) at least)'''
39     def rules(self, b_h):
40         bars = b_h[0]
41         head = b_h[1]
42         return [r for r in self.p_rules if r.head == head and r.bars == bars]
43     
44     def heads(self):
45         return "some structure full of rule heads.."
47     def deps(self, h, dir):
48         return "all dir-dependents of rules with head h"
50     def __init__(self, p_rules, p_terminals):
51         io.Grammar.__init__(self, p_rules, p_terminals)
52         
54 class DMV_Rule(io.CNF_Rule):
55     '''A single CNF rule in the PCFG, of the form 
56     LHS -> L R
57     where LHS = (bars, head)
58     
59     todo: possibly just store b_h instead of bars and head? (then b_h
60     = LHS, while we need new accessor functions for bars and head)
61     
62     Different rule-types have different probabilities associated with
63     them:
65     _h_ -> STOP  h_     P( STOP|h,L,    adj)
66     _h_ -> STOP  h_     P( STOP|h,L,non_adj)
67      h_ ->  h  STOP     P( STOP|h,R,    adj)
68      h_ ->  h  STOP     P( STOP|h,R,non_adj)
69      h_ -> _a_   h_     P(-STOP|h,L,    adj) * P(a|h,L)
70      h_ -> _a_   h_     P(-STOP|h,L,non_adj) * P(a|h,L)
71      h  ->  h   _a_     P(-STOP|h,R,    adj) * P(a|h,R)
72      h  ->  h   _a_     P(-STOP|h,R,non_adj) * P(a|h,R)
74     Todo, togrok: How do we know whether we use adj or non in inner()?
75     
76     Todo, togrok: How does "STOP" work in the inner-function?
77     '''
78     def __str__(self):
79         def bar_str(b_h):
80             str = "%d" % b_h[1]
81             if(b_h[0] == RBAR):
82                 str = "_%d" % b_h[1]
83             if(b_h[0] == LRBAR):
84                 str = "_%d_" % b_h[1]
85             return str
86         return "%s -> %s %s [%.2f]" % (bar_str((self.bars,self.head)),
87                                        bar_str(self.L),
88                                        bar_str(self.R),
89                                        self.prob)
91     def LHS(self):
92         return (self.bars, self.head)
94     def __init__(self, b_h, b_L, b_R, prob):
95         io.CNF_Rule.__init__(self, b_h[1], b_L, b_R, prob)
96         if b_h[0] in BARS:
97             self.bars = b_h[0]
98         else: # hmm, should perhaps check b_L and b_R too? todo
99             raise ValueError("bars must be in %s; was given: %s" % (BARS, b_h[0]))
102 # the following will only print when in this module:
103 if __name__ == "__main__":
104     # these are not Real rules, just testing the classes. todo: make
105     # a rule-set to test inner() on.
106     b = {}
107     s   = DMV_Rule((LRBAR,0), (NOBAR,1),(NOBAR,2), 1.0) # s->np vp
108     np  = DMV_Rule((NOBAR,1), (NOBAR,3),(NOBAR,4), 0.3) # np->n p
109     b[(NOBAR,1), 'n'] = 0.7 # np->'n'
110     b[(NOBAR,3), 'n'] = 1.0 # n->'n'
111     b[(NOBAR,4), 'p'] = 1.0 # p->'p'
112     vp  = DMV_Rule((NOBAR,2), (NOBAR,5),(NOBAR,1), 0.1) # vp->v np (two parses use this rule)
113     vp2 = DMV_Rule((NOBAR,2), (NOBAR,2),(NOBAR,4), 0.9) # vp->vp p
114     b[(NOBAR,5), 'v'] = 1.0 # v->'v'
115     
116     g = DMV_Grammar([s,np,vp,vp2], b)
117     
118     io.DEBUG = 0
119     test1 = io.inner(0,0, (NOBAR,1), g, ['n'], {})
120     if test1[0] != 0.7:
121         print "should be 0.70 : %.2f" % test1[0]
122         print ""
123     
124     test2 = io.inner(0,2, (NOBAR,2), g, ['v','n','p'], {})
125     if "%.2f" % test2[0] != "0.09": # 0.092999 etc, don't care about that
126         print "should be 0.09 if the io.py-test is right : %.2f" % test2[0]
127     # the following should manage to look stuff up in the chart:
128     test2 = io.inner(0,2, (NOBAR,2), g, ['v','n','p'], test2[1])
129     if "%.2f" % test2[0] != "0.09":
130         print "should be 0.09 if the io.py-test is right : %.2f" % test2[0]
132     
135 def all_inner_dmv(sent, g):
136     for h in g.heads():
137         for bars in BARS:
138             for s in range(s, len(sent)): # summing over r = s to r = t-1
139                 for t in range(s+1, len(sent)):
140                     io.inner(s, t, (bars, h), g, sent)
141     
144 # DMV-probabilities, todo:
146 def P_STOP(STOP, h, dir, adj, corpus):
147     '''corpus is a list of sentences s. 
149 This is based on the formula where STOP is True... not sure how we
150 calculate if STOP is False.
153 I thought about instead having this:
155 for rule in g.p_rules:
156     rule.num = 0
157     rule.den = 0
158 for sent in corpus:
159     for rule in g.p_rules:
160        for s:
161            for t:
162                set num and den using inner
163 for rule in g.p_rules
164     rule.prob = rule.num / rule.den
166 ..the way I'm assuming we do it in the commented out io-function in
167 io.py. Having sentences as the outer loop at least we can easily just
168 go through the heads that are actually in the sentence... BUT, this
169 means having to go through rules 3 times, not sure what is slower.
171 oh, and:
172 P_STOP(-STOP|...) = 1 - P_STOP(STOP|...)
176     P_STOP_num = 0
177     P_STOP_den = 0
178     for sent in corpus:
179         # here we should somehow make each word in the sentence
180         # unique, decorate them with subscripts or something. We have
181         # to run through the sentence as many times as h appears
182         # there. This also means changing inner(), I suspect.  Have to
183         # make sure we separate reading of inner_prob from changing of
184         # inner_prob...
185         for s in range(loc(h)): # i<loc(h), where h is in the sentence. 
186             for t in range(i, len(sent)):
187                 P_STOP_num += inner(s, t, h-r, g, sent, chart)
188                 P_STOP_den += inner(s, t, l-h-r, g, sent, chart) 
189     return P_STOP_num / P_STOP_den # possibly other way round? todo
192 def P_CHOOSE():
193     return "todo"
195 def DMV(sent, g):
196     '''Here it seems like they store rule information on a per-head (per
197     direction) basis, in deps_D(h, dir) which gives us a list. '''
198     def P_h(h):
199         P_h = 1 # ?
200         for dir in ['l', 'r']:
201             for a in deps(h, dir):
202                 # D(a)??
203                 P_h *= \
204                     P_STOP (0, h, dir, adj) * \
205                     P_CHOOSE (a, h, dir) * \
206                     P_h(D(a)) * \
207                 P_STOP (STOP | h, dir, adj)
208         return P_h
209     return P_h(root(sent))
213 # Initialization, todo
214 def tagset(corpus):
215     '''sents is of this form:
216 [['tag', ...], ['tag2', ...], ...]
218 Return a set of the tags. 
219 Fortunately only has to run once.
221     tagset = set()
222     for sent in corpus:
223         for tag in sent:
224             tagset.add(tag)
225     if 'ROOT' in tagset:
226         raise ValueError("it seems we must have a new ROOT symbol")
227     return tagset
229 def tagset_brown():
230     "472 tags, takes a while to extract with tagset(), hardcoded here."
231     return set(['BEDZ-NC', 'NP$', 'AT-TL', 'CS', 'NP+HVZ', 'IN-TL-HL', 'NR-HL', 'CC-TL-HL', 'NNS$-HL', 'JJS-HL', 'JJ-HL', 'WRB-TL', 'JJT-TL', 'WRB', 'DOD*', 'BER*-NC', ')-HL', 'NPS$-HL', 'RB-HL', 'FW-PPSS', 'NP+HVZ-NC', 'NNS$', '--', 'CC-TL', 'FW-NN-TL', 'NP-TL-HL', 'PPSS+MD', 'NPS', 'RBR+CS', 'DTI', 'NPS-TL', 'BEM', 'FW-AT+NP-TL', 'EX+BEZ', 'BEG', 'BED', 'BEZ', 'DTX', 'DOD*-TL', 'FW-VB-NC', 'DTS', 'DTS+BEZ', 'QL-HL', 'NP$-TL', 'WRB+DOD*', 'JJR+CS', 'NN+MD', 'NN-TL-HL', 'HVD-HL', 'NP+BEZ-NC', 'VBN+TO', '*-TL', 'WDT-HL', 'MD', 'NN-HL', 'FW-BE', 'DT$', 'PN-TL', 'DT-HL', 'FW-NR-TL', 'VBG', 'VBD', 'VBN', 'DOD', 'FW-VBG-TL', 'DOZ', 'ABN-TL', 'VB+JJ-NC', 'VBZ', 'RB+CS', 'FW-PN', 'CS-NC', 'VBG-NC', 'BER-HL', 'MD*', '``', 'WPS-TL', 'OD-TL', 'PPSS-HL', 'PPS+MD', 'DO*', 'DO-HL', 'HVG-HL', 'WRB-HL', 'JJT', 'JJS', 'JJR', 'HV+TO', 'WQL', 'DOD-NC', 'CC-HL', 'FW-PPSS+HV', 'FW-NP-TL', 'MD+TO', 'VB+IN', 'JJT-NC', 'WDT+BEZ-TL', '---HL', 'PN$', 'VB+PPO', 'BE-TL', 'VBG-TL', 'NP$-HL', 'VBZ-TL', 'UH', 'FW-WPO', 'AP+AP-NC', 'FW-IN', 'NRS-TL', 'ABL', 'ABN', 'TO-TL', 'ABX', '*-HL', 'FW-WPS', 'VB-NC', 'HVD*', 'PPS+HVD', 'FW-IN+AT', 'FW-NP', 'QLP', 'FW-NR', 'FW-NN', 'PPS+HVZ', 'NNS-NC', 'DT+BEZ-NC', 'PPO', 'PPO-NC', 'EX-HL', 'AP$', 'OD-NC', 'RP', 'WPS+BEZ', 'NN+BEZ', '.-TL', ',', 'FW-DT+BEZ', 'RB', 'FW-PP$-NC', 'RN', 'JJ$-TL', 'MD-NC', 'VBD-NC', 'PPSS+BER-N', 'RB+BEZ-NC', 'WPS-HL', 'VBN-NC', 'BEZ-HL', 'PPL-NC', 'BER-TL', 'PP$$', 'NNS+MD', 'PPS-NC', 'FW-UH-NC', 'PPS+BEZ-NC', 'PPSS+BER-TL', 'NR-NC', 'FW-JJ', 'PPS+BEZ-HL', 'NPS$', 'RB-TL', 'VB-TL', 'BEM*', 'MD*-HL', 'FW-CC', 'NP+MD', 'EX+HVZ', 'FW-CD', 'EX+HVD', 'IN-HL', 'FW-CS', 'JJR-HL', 'FW-IN+NP-TL', 'JJ-TL-HL', 'FW-UH', 'EX', 'FW-NNS-NC', 'FW-JJ-NC', 'VBZ-HL', 'VB+RP', 'BEZ-NC', 'PPSS+HV-TL', 'HV*', 'IN', 'PP$-NC', 'NP-NC', 'BEN', 'PP$-TL', 'FW-*-TL', 'FW-OD-TL', 'WPS', 'WPO', 'MD+PPSS', 'WDT+BER', 'WDT+BEZ', 'CD-HL', 'WDT+BEZ-NC', 'WP$', 'DO+PPSS', 'HV-HL', 'DT-NC', 'PN-NC', 'FW-VBZ', 'HVD', 'HVG', 'NN+BEZ-TL', 'HVZ', 'FW-VBD', 'FW-VBG', 'NNS$-TL', 'JJ-TL', 'FW-VBN', 'MD-TL', 'WDT+DOD', 'HV-TL', 'NN-TL', 'PPSS', 'NR$', 'BER', 'FW-VB', 'DT', 'PN+BEZ', 'VBG-HL', 'FW-PPL+VBZ', 'FW-NPS-TL', 'RB$', 'FW-IN+NN', 'FW-CC-TL', 'RBT', 'RBR', 'PPS-TL', 'PPSS+HV', 'JJS-TL', 'NPS-HL', 'WPS+BEZ-TL', 'NNS-TL-HL', 'VBN-TL-NC', 'QL-TL', 'NN+NN-NC', 'JJR-TL', 'NN$-TL', 'FW-QL', 'IN-TL', 'BED-NC', 'NRS', '.-HL', 'QL', 'PP$-HL', 'WRB+BER', 'JJ', 'WRB+BEZ', 'NNS$-TL-HL', 'PPSS+BEZ', '(', 'PPSS+BER', 'DT+MD', 'DOZ-TL', 'PPSS+BEM', 'FW-PP$', 'RB+BEZ-HL', 'FW-RB+CC', 'FW-PPS', 'VBG+TO', 'DO*-HL', 'NR+MD', 'PPLS', 'IN+IN', 'BEZ*', 'FW-PPL', 'FW-PPO', 'NNS-HL', 'NIL', 'HVN', 'PPSS+BER-NC', 'AP-TL', 'FW-DT', '(-HL', 'DTI-TL', 'JJ+JJ-NC', 'FW-RB', 'FW-VBD-TL', 'BER-NC', 'NNS$-NC', 'JJ-NC', 'NPS$-TL', 'VB+VB-NC', 'PN', 'VB+TO', 'AT-TL-HL', 'BEM-NC', 'PPL-TL', 'ABN-HL', 'RB-NC', 'DO-NC', 'BE-HL', 'WRB+IN', 'FW-UH-TL', 'PPO-HL', 'FW-CD-TL', 'TO-HL', 'PPS+BEZ', 'CD$', 'DO', 'EX+MD', 'HVZ-TL', 'TO-NC', 'IN-NC', '.', 'WRB+DO', 'CD-NC', 'FW-PPO+IN', 'FW-NN$-TL', 'WDT+BEZ-HL', 'RP-HL', 'CC', 'NN+HVZ-TL', 'FW-NNS-TL', 'DT+BEZ', 'WPS+HVZ', 'BEDZ*', 'NP-TL', ':-TL', 'NN-NC', 'WPO-TL', 'QL-NC', 'FW-AT+NN-TL', 'WDT+HVZ', '.-NC', 'FW-DTS', 'NP-HL', ':-HL', 'RBR-NC', 'OD-HL', 'BEDZ-HL', 'VBD-TL', 'NPS-NC', ')', 'TO+VB', 'FW-IN+NN-TL', 'PPL', 'PPS', 'PPSS+VB', 'DT-TL', 'RP-NC', 'VB', 'FW-VB-TL', 'PP$', 'VBD-HL', 'DTI-HL', 'NN-TL-NC', 'PPL-HL', 'DOZ*', 'NR-TL', 'WRB+MD', 'PN+HVZ', 'FW-IN-TL', 'PN+HVD', 'BEN-TL', 'BE', 'WDT', 'WPS+HVD', 'DO-TL', 'FW-NN-NC', 'WRB+BEZ-TL', 'UH-TL', 'JJR-NC', 'NNS', 'PPSS-NC', 'WPS+BEZ-NC', ',-TL', 'NN$', 'VBN-TL-HL', 'WDT-NC', 'OD', 'FW-OD-NC', 'DOZ*-TL', 'PPSS+HVD', 'CS-TL', 'WRB+DOZ', 'CC-NC', 'HV', 'NN$-HL', 'FW-WDT', 'WRB+DOD', 'NN+HVZ', 'AT-NC', 'NNS-TL', 'FW-BEZ', 'CS-HL', 'WPO-NC', 'FW-BER', 'NNS-TL-NC', 'BEZ-TL', 'FW-IN+AT-T', 'ABN-NC', 'NR-TL-HL', 'BEDZ', 'NP+BEZ', 'FW-AT-TL', 'BER*', 'WPS+MD', 'MD-HL', 'BED*', 'HV-NC', 'WPS-NC', 'VBN-HL', 'FW-TO+VB', 'PPSS+MD-NC', 'HVZ*', 'PPS-HL', 'WRB-NC', 'VBN-TL', 'CD-TL-HL', ',-NC', 'RP-TL', 'AP-HL', 'FW-HV', 'WQL-TL', 'FW-AT', 'NN', 'NR$-TL', 'VBZ-NC', '*', 'PPSS-TL', 'JJT-HL', 'FW-NNS', 'NP', 'UH-HL', 'NR', ':', 'FW-NN$', 'RP+IN', ',-HL', 'JJ-TL-NC', 'AP-NC', '*-NC', 'VB-HL', 'HVZ-NC', 'DTS-HL', 'FW-JJT', 'FW-JJR', 'FW-JJ-TL', 'FW-*', 'RB+BEZ', "''", 'VB+AT', 'PN-HL', 'PPO-TL', 'CD-TL', 'UH-NC', 'FW-NN-TL-NC', 'EX-NC', 'PPSS+BEZ*', 'TO', 'WDT+DO+PPS', 'IN+PPO', 'AP', 'AT', 'DOZ-HL', 'FW-RB-TL', 'CD', 'NN+IN', 'FW-AT-HL', 'PN+MD', "'", 'FW-PP$-TL', 'FW-NPS', 'WDT+BER+PP', 'NN+HVD-TL', 'MD+HV', 'AT-HL', 'FW-IN+AT-TL'])
235 def init_zeros(tags):
236     "Return a frequency dictionary with DMV-relevant keys set to 0 / {}."
237     f = {} 
238     for tag in tags:
239         f['ROOT', tag] = 0
240         f[tag, 'LN'] = 0
241         f[tag, 'LA'] = 0
242         f[tag, 'RN'] = 0
243         f[tag, 'RA'] = 0
244         f[tag, 'R'] = {}
245         f[tag, 'L'] = {}
246     return f
248 def init_freq(corpus):
249     '''Returns f, a dictionary with these types of keys:
250     - ('ROOT', tag) is basically just the frequency of tag
251     - (tag, 'LN') is for P_STOP(STOP|tag, left, non_adj); etc. for
252       'RN', 'LA', 'LN'.
253     - (tag, 'L') is a dictionary of arg:f, where head could take arg
254       to direction 'L' (etc. for 'R') and f is "harmonically" divided
255       by distance, used for finding P_CHOOSE
256     '''
257     f = init_zeros(tagset(corpus))
258     
259     for sent in corpus: # sent is ['VBD', 'NN', ...]
261         # NOTE: head in DMV_Rule is a number, while this is the string
262         for i_h, head in enumerate(sent): 
263             # todo grok: how is this different from just using straight head
264             # frequency counts, for the ROOT probabilities?
265             f['ROOT', head] += 1
267             if i_h <= 1: # first two words
268                 f[head, 'LN'] += 1
269                 if i_h == 0: # very first word
270                     f[head, 'LA'] += 1
272             if i_h >= len(sent) - 2: # last two words
273                 f[head, 'RN'] += 1
274                 if i_h == len(sent) - 1: # very last word
275                     f[head, 'RA'] += 1
277             for i_a, arg in enumerate(sent):
278                 # todo, optimization: possible to do both directions
279                 # at once here, and later on rule out the ones we've
280                 # done? does it actually speed things up?
281                 if arg != head:
282                     C = 0.0 # todo: tweak
283                     if i_h > i_a: 
284                         if arg not in f[head, 'L']:
285                             f[head, 'L'][arg] = 0.0
286                         f[head, 'L'][arg] += 1.0/(i_h - i_a) + C
287                     if i_h < i_a: 
288                         if arg not in f[head, 'R']:
289                             f[head, 'R'][arg] = 0.0
290                         f[head, 'R'][arg] += 1.0/(i_a - i_h) + C
292     return f # end init_sent
294 def initialize(corpus):
295     '''Return an initialized DMV_Grammar (todo)
296     corpus is a list of lists of tags.'''
297     p_terminals = {}
298     for tag in tagset(corpus):
299         p_terminals[(NOBAR, tag), tag] = 1 # I guess....
300     
301     # f: frequency counts used in initialization, mostly distances
302     f = init_freq(corpus)
303         
304     #p_rules = make_stop_root(tags)
305     # lots todo; make probabilities... not sure whether make_stop_root
306     # is useful
307     return f
310 if __name__ == "__main__":
311     print "--------------------"
312     import pprint
313     pprint.pprint(initialize([['zero', 'one','two','three']]))
314     print '''    - ('ROOT', tag) is basically just the frequency of tag
315     - (tag, 'LN') is for P_STOP(STOP|tag, left, non_adj); etc. for
316       'RN', 'LA', 'LN'.
317     - (tag, 'L') is a dictionary of arg:f, where head could take arg
318       to direction 'L' (etc. for 'R') and f is "harmonically" divided
319       by distance, used for finding P_CHOOSE'''
324 ####################
325 #    junk:         #
326 ####################
328 def make_stop_root(tagset):
329     "not sure if this has anything to it."
330     # p_rules could be a set, if we have DMV_Rule equality function
331     # that lets two rules be equal if LHS, L and R are equal, although
332     # prob may be different; this would ensure no duplicate rules at
333     # the very least.
334     p_rules = []
335     for num, tag in enumerate(tagset):
336         p_rules.append( DMV_Rule((LRBAR, num), STOP, (RBAR, num), 0.0) )
337         p_rules.append( DMV_Rule((RBAR, num), (NOBAR, num), STOP, 0.0) )
338         p_rules.append( DMV_Rule(ROOT, (LRBAR,num), STOP, 0.0) )
339     return p_rules