another little fix to stop initialization
[dmvccm.git] / src / before_betweens_harmonic.py
blobc0c862a797d5903b8a9ff6df1b3ef80b3e600430
1 # before_betweens_harmonic.py, initialization for before_betweens_dmv.py
3 # todo: remove old initialization and make initialization2 use
4 # DMV_Grammar instead of DMV_Grammar2
6 from before_betweens_dmv import * # better way to do this?
# todo: tweak this
# Constant added to every 1/distance weight in init_freq.
HARMONIC_C = 0.0
# Starting value of every '-STOP' counter created by init_zeros.
FNONSTOP_MIN = 25
# Starting value of every 'STOP' counter created by init_zeros.
FSTOP_MIN = 5
13 ##############################
14 # Initialization #
15 ##############################
def taglist(corpus):
    '''Return a sorted list of the distinct tags occurring in corpus.

    corpus is of this form:
    [['tag', ...], ['tag2', ...], ...]

    The result is sorted because callers enumerate it to number the
    tags; iteration order of a set is arbitrary, so returning the set's
    own order would not give a consistent numbering.

    Raises ValueError if 'ROOT' occurs among the tags.

    Fortunately only has to run once.
    '''
    tagset = set()
    for sent in corpus:
        tagset.update(sent)
    if 'ROOT' in tagset:
        raise ValueError("it seems we must have a new ROOT symbol")
    return sorted(tagset)
def init_zeros(tags):
    '''Return a fresh frequency dictionary holding every counter the
    initialization needs for the given tags.

    For each tag the dictionary gets:
    - ['ROOT', tag] = 0 (and the shared ['sum', 'ROOT'] = 0)
    - [tag, 'STOP', dir, adj] = FSTOP_MIN and
      [tag, '-STOP', dir, adj] = FNONSTOP_MIN, for dir in LEFT, RIGHT
      and adj in ADJ, NON
    - [tag, dir] = {} and [tag, 'sum', dir] = 0.0, for dir in LEFT, RIGHT

    Todo: tweak (especially for f_STOP).'''
    counts = {}
    for tag in tags:
        counts['ROOT', tag] = 0
        # shared total; (re)set on every pass, so it is absent when
        # tags is empty
        counts['sum', 'ROOT'] = 0
        for direction in (LEFT, RIGHT):
            for adjacency in (ADJ, NON):
                counts[tag, 'STOP', direction, adjacency] = FSTOP_MIN
                counts[tag, '-STOP', direction, adjacency] = FNONSTOP_MIN
        for direction in (RIGHT, LEFT):
            counts[tag, direction] = {}
        for direction in (RIGHT, LEFT):
            counts[tag, 'sum', direction] = 0.0
    return counts
def init_freq(corpus, tags):
    '''Return the counters from init_zeros(tags), filled in from corpus.

    corpus is a list of sentences, each a list of tag strings such as
    ['VBD', 'NN', ...].  For every word (head) of every sentence:

    1. ['ROOT', head] and ['sum', 'ROOT'] are incremented by one.
    2. For each of the four (direction, adjacency) pairs, either the
       'STOP' or the '-STOP' counter of head is incremented by one.
       'STOP' is chosen when head is
         - the second word       for (LEFT, NON)
         - the first word        for (LEFT, ADJ)
         - the second-to-last    for (RIGHT, NON)
         - the last word         for (RIGHT, ADJ)
       and '-STOP' otherwise.
    3. For every other word (arg) of the sentence, the "harmonic"
       weight 1/distance + HARMONIC_C is added to [head, dir][arg] and
       to [head, 'sum', dir], where dir is LEFT if arg precedes head
       and RIGHT if it follows.
    '''
    counts = init_zeros(tags)

    for sent in corpus:  # sent is ['VBD', 'NN', ...]
        n = len(sent)
        # position at which each (direction, adjacency) pair counts as a
        # STOP for the word standing there
        stop_positions = (
            (LEFT, NON, 1),        # second word
            (LEFT, ADJ, 0),        # first word
            (RIGHT, NON, n - 2),   # second-to-last
            (RIGHT, ADJ, n - 1),   # last word
        )
        # NOTE: head in DMV_Rule is a number, while this is the string
        for loc_h, head in enumerate(sent):
            # todo grok: how is this different from just using straight head
            # frequency counts, for the ROOT probabilities?
            counts['ROOT', head] += 1
            counts['sum', 'ROOT'] += 1

            for direction, adjacency, stop_pos in stop_positions:
                if loc_h == stop_pos:
                    counts[head, 'STOP', direction, adjacency] += 1
                else:
                    counts[head, '-STOP', direction, adjacency] += 1

            # this is where we make the "harmonic" distribution
            for loc_a, arg in enumerate(sent):
                if loc_a == loc_h:
                    continue
                harmony = 1.0/abs(loc_h - loc_a) + HARMONIC_C
                if loc_h > loc_a:
                    direction = LEFT
                else:
                    direction = RIGHT
                attachments = counts[head, direction]
                attachments[arg] = attachments.get(arg, 0.0) + harmony
                counts[head, 'sum', direction] += harmony
                # todo, optimization: possible to do both directions
                # at once here, and later on rule out the ones we've
                # done? does it actually speed things up?

    return counts
def init_normalize(f, tags, numtag, tagnum):
    '''Turn the frequencies (and sums) in f into p_ROOT, p_STOP and
    p_ATTACH, build the corresponding DMV_Rule objects, and return a
    DMV_Grammar made from all of them.

    f is a frequency dictionary as produced by init_freq; numtag maps
    tag numbers to tag strings and tagnum maps the other way.  tags is
    accepted but not used here.
    '''
    rules = []
    p_ROOT, p_STOP, p_ATTACH, p_terminals = {}, {}, {}, {}

    for n_h, head in numtag.iteritems():
        root_prob = float(f['ROOT', head]) / f['sum', 'ROOT']
        p_ROOT[n_h] = root_prob
        rules.append(DMV_Rule(ROOT, STOP, (SEAL, n_h),
                              root_prob,
                              root_prob))

        # p_STOP = STOP / (STOP + NOT_STOP)
        for direction in (LEFT, RIGHT):
            for adjacency in (NON, ADJ):
                stops = f[head, 'STOP', direction, adjacency]
                nonstops = f[head, '-STOP', direction, adjacency]
                p_STOP[n_h, direction, adjacency] = float(stops) / (stops + nonstops)

        # stop rules: probN comes from the NON entry, probA from ADJ
        rules.append(DMV_Rule((RGOL, n_h), (GOR, n_h), STOP,
                              p_STOP[n_h, RIGHT, NON],
                              p_STOP[n_h, RIGHT, ADJ]))
        rules.append(DMV_Rule((SEAL, n_h), STOP, (RGOL, n_h),
                              p_STOP[n_h, LEFT, NON],
                              p_STOP[n_h, LEFT, ADJ]))

        p_terminals[(GOR, n_h), head] = 1.0
        # inner() shouldn't have to deal with those long non-branching
        # stops. But actually, since these are added rules they just
        # make things take more time: 2.77s with, 1.87s without
        # p_terminals[(RGOL, n_h), head] = p_STOP[n_h, 'RA']
        # p_terminals[(SEAL, n_h), head] = p_STOP[n_h, 'RA'] * p_STOP[n_h, 'LA']

        for direction in (LEFT, RIGHT):
            total = f[head, 'sum', direction]
            for arg, val in f[head, direction].iteritems():
                p_ATTACH[tagnum[arg], n_h, direction] = float(val) / total

    # after the head tag-loop, add every head-argument rule:
    for (n_a, n_h, direction), p_A in p_ATTACH.iteritems():
        prob_non = p_A*(1-p_STOP[n_h, direction, NON])
        prob_adj = p_A*(1-p_STOP[n_h, direction, ADJ])
        if direction == LEFT:  # arg is to the left of head
            rules.append(DMV_Rule((RGOL, n_h), (SEAL, n_a), (RGOL, n_h),
                                  prob_non,
                                  prob_adj))
        if direction == RIGHT:
            rules.append(DMV_Rule((GOR, n_h), (GOR, n_h), (SEAL, n_a),
                                  prob_non,
                                  prob_adj))

    return DMV_Grammar(numtag, tagnum, rules, p_terminals, p_STOP, p_ATTACH, p_ROOT)
def initialize(corpus):
    '''Build and return the grammar that init_normalize creates for
    corpus, a list of lists of tags.'''
    tags = taglist(corpus)
    # number the tags by their position in the tag list, both ways
    numtag = dict(enumerate(tags))
    tagnum = dict((tag, num) for num, tag in enumerate(tags))
    # frequency counts used in initialization, mostly distances
    freqs = init_freq(corpus, tags)
    return init_normalize(freqs, tags, numtag, tagnum)
if __name__ == "__main__":
    # todo: grok why there's so little difference in probN and probA values

    print("--------initialization testing------------")
    print(initialize([['foo', 'two','foo','foo'],
                      ['zero', 'one','two','three']]))

    # try two settings of the minimum STOP / non-STOP counts and print
    # the summed rule probabilities for each
    for (nonstop_min, stop_min) in [(95,5),(5,5)]:
        FNONSTOP_MIN = nonstop_min
        FSTOP_MIN = stop_min

        sentences = ['det nn vbd c nn vbd nn','det nn vbd c nn vbd pp nn',
                     'det nn vbd nn','det nn vbd c nn vbd pp nn',
                     'det nn vbd nn','det nn vbd c nn vbd pp nn',
                     'det nn vbd nn','det nn vbd c nn vbd pp nn',
                     'det nn vbd nn','det nn vbd c nn vbd pp nn',
                     'det nn vbd pp nn','det nn vbd det nn', ]
        testcorpus = [sentence.split() for sentence in sentences]
        g = initialize(testcorpus)

        stopn = nstopn = nstopa = stopa = rewriten = rewritea = 0.0
        for r in g.all_rules():
            is_stop_rule = r.L() == STOP or r.R() == STOP
            if not is_stop_rule:
                rewriten += r.probN
                rewritea += r.probA
            else:
                stopn += r.probN
                nstopa += 1-r.probA
                nstopn += 1-r.probN
                stopa += r.probA
        print("sn:%.2f (nsn:%.2f) sa:%.2f (nsa:%.2f) rn:%.2f ra:%.2f" % (stopn, nstopn, stopa,nstopa, rewriten, rewritea))
def tagset_brown():
    """Return a hardcoded set of tag strings.

    The original note on this function counts 472 tags and says they
    are hardcoded because extracting them with tagset() takes a while.
    Presumably this is the Brown corpus tagset, going by the name.
    """
    brown_tags = [
        'BEDZ-NC', 'NP$', 'AT-TL', 'CS', 'NP+HVZ', 'IN-TL-HL', 'NR-HL',
        'CC-TL-HL', 'NNS$-HL', 'JJS-HL', 'JJ-HL', 'WRB-TL', 'JJT-TL', 'WRB',
        'DOD*', 'BER*-NC', ')-HL', 'NPS$-HL', 'RB-HL', 'FW-PPSS', 'NP+HVZ-NC',
        'NNS$', '--', 'CC-TL', 'FW-NN-TL', 'NP-TL-HL', 'PPSS+MD', 'NPS',
        'RBR+CS', 'DTI', 'NPS-TL', 'BEM', 'FW-AT+NP-TL', 'EX+BEZ', 'BEG',
        'BED', 'BEZ', 'DTX', 'DOD*-TL', 'FW-VB-NC', 'DTS', 'DTS+BEZ', 'QL-HL',
        'NP$-TL', 'WRB+DOD*', 'JJR+CS', 'NN+MD', 'NN-TL-HL', 'HVD-HL',
        'NP+BEZ-NC', 'VBN+TO', '*-TL', 'WDT-HL', 'MD', 'NN-HL', 'FW-BE',
        'DT$', 'PN-TL', 'DT-HL', 'FW-NR-TL', 'VBG', 'VBD', 'VBN', 'DOD',
        'FW-VBG-TL', 'DOZ', 'ABN-TL', 'VB+JJ-NC', 'VBZ', 'RB+CS', 'FW-PN',
        'CS-NC', 'VBG-NC', 'BER-HL', 'MD*', '``', 'WPS-TL', 'OD-TL',
        'PPSS-HL', 'PPS+MD', 'DO*', 'DO-HL', 'HVG-HL', 'WRB-HL', 'JJT', 'JJS',
        'JJR', 'HV+TO', 'WQL', 'DOD-NC', 'CC-HL', 'FW-PPSS+HV', 'FW-NP-TL',
        'MD+TO', 'VB+IN', 'JJT-NC', 'WDT+BEZ-TL', '---HL', 'PN$', 'VB+PPO',
        'BE-TL', 'VBG-TL', 'NP$-HL', 'VBZ-TL', 'UH', 'FW-WPO', 'AP+AP-NC',
        'FW-IN', 'NRS-TL', 'ABL', 'ABN', 'TO-TL', 'ABX', '*-HL', 'FW-WPS',
        'VB-NC', 'HVD*', 'PPS+HVD', 'FW-IN+AT', 'FW-NP', 'QLP', 'FW-NR',
        'FW-NN', 'PPS+HVZ', 'NNS-NC', 'DT+BEZ-NC', 'PPO', 'PPO-NC', 'EX-HL',
        'AP$', 'OD-NC', 'RP', 'WPS+BEZ', 'NN+BEZ', '.-TL', ',', 'FW-DT+BEZ',
        'RB', 'FW-PP$-NC', 'RN', 'JJ$-TL', 'MD-NC', 'VBD-NC', 'PPSS+BER-N',
        'RB+BEZ-NC', 'WPS-HL', 'VBN-NC', 'BEZ-HL', 'PPL-NC', 'BER-TL', 'PP$$',
        'NNS+MD', 'PPS-NC', 'FW-UH-NC', 'PPS+BEZ-NC', 'PPSS+BER-TL', 'NR-NC',
        'FW-JJ', 'PPS+BEZ-HL', 'NPS$', 'RB-TL', 'VB-TL', 'BEM*', 'MD*-HL',
        'FW-CC', 'NP+MD', 'EX+HVZ', 'FW-CD', 'EX+HVD', 'IN-HL', 'FW-CS',
        'JJR-HL', 'FW-IN+NP-TL', 'JJ-TL-HL', 'FW-UH', 'EX', 'FW-NNS-NC',
        'FW-JJ-NC', 'VBZ-HL', 'VB+RP', 'BEZ-NC', 'PPSS+HV-TL', 'HV*', 'IN',
        'PP$-NC', 'NP-NC', 'BEN', 'PP$-TL', 'FW-*-TL', 'FW-OD-TL', 'WPS',
        'WPO', 'MD+PPSS', 'WDT+BER', 'WDT+BEZ', 'CD-HL', 'WDT+BEZ-NC', 'WP$',
        'DO+PPSS', 'HV-HL', 'DT-NC', 'PN-NC', 'FW-VBZ', 'HVD', 'HVG',
        'NN+BEZ-TL', 'HVZ', 'FW-VBD', 'FW-VBG',
        'NNS$-TL', 'JJ-TL', 'FW-VBN', 'MD-TL', 'WDT+DOD', 'HV-TL', 'NN-TL',
        'PPSS', 'NR$', 'BER', 'FW-VB', 'DT', 'PN+BEZ', 'VBG-HL', 'FW-PPL+VBZ',
        'FW-NPS-TL', 'RB$', 'FW-IN+NN', 'FW-CC-TL', 'RBT', 'RBR', 'PPS-TL',
        'PPSS+HV', 'JJS-TL', 'NPS-HL', 'WPS+BEZ-TL', 'NNS-TL-HL', 'VBN-TL-NC',
        'QL-TL', 'NN+NN-NC', 'JJR-TL', 'NN$-TL', 'FW-QL', 'IN-TL', 'BED-NC',
        'NRS', '.-HL', 'QL', 'PP$-HL', 'WRB+BER', 'JJ', 'WRB+BEZ',
        'NNS$-TL-HL', 'PPSS+BEZ', '(', 'PPSS+BER', 'DT+MD', 'DOZ-TL',
        'PPSS+BEM', 'FW-PP$', 'RB+BEZ-HL', 'FW-RB+CC', 'FW-PPS', 'VBG+TO',
        'DO*-HL', 'NR+MD', 'PPLS', 'IN+IN', 'BEZ*', 'FW-PPL', 'FW-PPO',
        'NNS-HL', 'NIL', 'HVN', 'PPSS+BER-NC', 'AP-TL', 'FW-DT', '(-HL',
        'DTI-TL', 'JJ+JJ-NC', 'FW-RB', 'FW-VBD-TL', 'BER-NC', 'NNS$-NC',
        'JJ-NC', 'NPS$-TL', 'VB+VB-NC', 'PN', 'VB+TO', 'AT-TL-HL', 'BEM-NC',
        'PPL-TL', 'ABN-HL', 'RB-NC', 'DO-NC', 'BE-HL', 'WRB+IN', 'FW-UH-TL',
        'PPO-HL', 'FW-CD-TL', 'TO-HL', 'PPS+BEZ', 'CD$', 'DO', 'EX+MD',
        'HVZ-TL', 'TO-NC', 'IN-NC', '.', 'WRB+DO', 'CD-NC', 'FW-PPO+IN',
        'FW-NN$-TL', 'WDT+BEZ-HL', 'RP-HL', 'CC', 'NN+HVZ-TL', 'FW-NNS-TL',
        'DT+BEZ', 'WPS+HVZ', 'BEDZ*', 'NP-TL', ':-TL', 'NN-NC', 'WPO-TL',
        'QL-NC', 'FW-AT+NN-TL', 'WDT+HVZ', '.-NC', 'FW-DTS', 'NP-HL', ':-HL',
        'RBR-NC', 'OD-HL', 'BEDZ-HL', 'VBD-TL', 'NPS-NC', ')', 'TO+VB',
        'FW-IN+NN-TL', 'PPL', 'PPS', 'PPSS+VB', 'DT-TL', 'RP-NC', 'VB',
        'FW-VB-TL', 'PP$', 'VBD-HL', 'DTI-HL', 'NN-TL-NC', 'PPL-HL', 'DOZ*',
        'NR-TL', 'WRB+MD', 'PN+HVZ', 'FW-IN-TL', 'PN+HVD', 'BEN-TL', 'BE',
        'WDT', 'WPS+HVD', 'DO-TL', 'FW-NN-NC', 'WRB+BEZ-TL', 'UH-TL',
        'JJR-NC', 'NNS', 'PPSS-NC', 'WPS+BEZ-NC', ',-TL', 'NN$', 'VBN-TL-HL',
        'WDT-NC', 'OD', 'FW-OD-NC', 'DOZ*-TL', 'PPSS+HVD', 'CS-TL', 'WRB+DOZ',
        'CC-NC', 'HV', 'NN$-HL', 'FW-WDT', 'WRB+DOD', 'NN+HVZ', 'AT-NC',
        'NNS-TL', 'FW-BEZ', 'CS-HL', 'WPO-NC', 'FW-BER', 'NNS-TL-NC',
        'BEZ-TL', 'FW-IN+AT-T', 'ABN-NC', 'NR-TL-HL', 'BEDZ', 'NP+BEZ',
        'FW-AT-TL', 'BER*', 'WPS+MD', 'MD-HL', 'BED*', 'HV-NC', 'WPS-NC',
        'VBN-HL', 'FW-TO+VB', 'PPSS+MD-NC', 'HVZ*', 'PPS-HL', 'WRB-NC',
        'VBN-TL',
        'CD-TL-HL', ',-NC', 'RP-TL', 'AP-HL', 'FW-HV', 'WQL-TL', 'FW-AT',
        'NN', 'NR$-TL', 'VBZ-NC', '*', 'PPSS-TL', 'JJT-HL', 'FW-NNS', 'NP',
        'UH-HL', 'NR', ':', 'FW-NN$', 'RP+IN', ',-HL', 'JJ-TL-NC', 'AP-NC',
        '*-NC', 'VB-HL', 'HVZ-NC', 'DTS-HL', 'FW-JJT', 'FW-JJR', 'FW-JJ-TL',
        'FW-*', 'RB+BEZ', "''", 'VB+AT', 'PN-HL', 'PPO-TL', 'CD-TL', 'UH-NC',
        'FW-NN-TL-NC', 'EX-NC', 'PPSS+BEZ*', 'TO', 'WDT+DO+PPS', 'IN+PPO',
        'AP', 'AT', 'DOZ-HL', 'FW-RB-TL', 'CD', 'NN+IN', 'FW-AT-HL', 'PN+MD',
        "'", 'FW-PP$-TL', 'FW-NPS', 'WDT+BER+PP', 'NN+HVD-TL', 'MD+HV',
        'AT-HL', 'FW-IN+AT-TL',
    ]
    return set(brown_tags)