close to finishing pCHOOSE
[dmvccm.git] / src / harmonic.py
blobf1e35544826479d676fe7a5391af016948d18dab
1 # harmonic.py, initialization for dmv
2 #
3 # initialization is a package within dmvccm
5 from dmv import * # better way to do this?
# Smoothing constants for the harmonic initializer.  todo: tweak this
# Starting pseudo-count of every (tag, 'STOP', dir_adj) counter (init_zeros).
FSTOP_MIN = 5
# Starting pseudo-count of every (tag, '-STOP', dir_adj) counter (init_zeros).
FNONSTOP_MIN = 25
# Constant added to each 1/distance "harmonic" weight (init_freq).
HARMONIC_C = 0.0
12 ##############################
13 # Initialization #
14 ##############################
def taglist(corpus):
    '''Return a sorted list of the distinct tags occurring in corpus.

    corpus is of this form:
    [['tag', ...], ['tag2', ...], ...]

    The list has to be ordered for enumerating it to be consistent (tag
    numbers are assigned by position in this list), so it is sorted
    instead of being left in arbitrary set-iteration order.

    Fortunately only has to run once.

    Raises ValueError if 'ROOT' occurs as a tag, since 'ROOT' is also
    used as a special key/symbol by the initializer.
    '''
    tagset = set()
    for sent in corpus:
        tagset.update(sent)
    if 'ROOT' in tagset:
        raise ValueError("it seems we must have a new ROOT symbol")
    # sorted() gives a deterministic order across runs; list(set) does not.
    return sorted(tagset)
def init_zeros(tags, fstop_min=None, fnonstop_min=None):
    '''Return a frequency dictionary with the DMV-relevant keys initialized.

    Keys created:
      - ('sum', 'ROOT') = 0
      - for every tag in tags:
        - ('ROOT', tag) = 0
        - (tag, 'STOP', dir_adj) = fstop_min and
          (tag, '-STOP', dir_adj) = fnonstop_min,
          for dir_adj in 'LN', 'LA', 'RN', 'RA'
        - (tag, 'L') = {} and (tag, 'R') = {}  (fresh dict per key)
        - (tag, 'sum', 'L') = 0.0 and (tag, 'sum', 'R') = 0.0

    fstop_min / fnonstop_min default to the module-level FSTOP_MIN /
    FNONSTOP_MIN, looked up at call time so rebinding those globals still
    takes effect.

    Todo: tweak (especially for f_STOP).'''
    if fstop_min is None:
        fstop_min = FSTOP_MIN
    if fnonstop_min is None:
        fnonstop_min = FNONSTOP_MIN
    # Set once, outside the tag loop; present even for an empty tag list.
    f = {('sum', 'ROOT'): 0}
    for tag in tags:
        f['ROOT', tag] = 0
        for dir_adj in ('LN', 'LA', 'RN', 'RA'):
            f[tag, 'STOP', dir_adj] = fstop_min
            f[tag, '-STOP', dir_adj] = fnonstop_min
        f[tag, 'R'] = {}
        f[tag, 'L'] = {}
        f[tag, 'sum', 'R'] = 0.0
        f[tag, 'sum', 'L'] = 0.0
    return f
def init_freq(corpus, tags):
    '''Return f, a dictionary of frequency counts for harmonic initialization.

    f starts out as init_zeros(tags) and has these types of keys:
    - ('ROOT', tag) is basically just the frequency of tag; ('sum', 'ROOT')
      is the total number of tokens seen
    - (tag, 'STOP', 'LN') is for P_STOP(STOP|tag, left, non_adj); likewise
      for 'LA', 'RN', 'RA', and (tag, '-STOP', ...) for not stopping
    - (tag, 'L') is a dictionary of arg:f, where head could take arg to
      direction 'L' (likewise 'R') and f is "harmonically" divided by
      distance, used for finding P_CHOOSE; (tag, 'sum', 'L') and
      (tag, 'sum', 'R') hold the totals of those weights

    Does this stuff, for each word (head) of each sentence in corpus:
    1. counts word frequencies for f_ROOT
    2. adds 1 to (head, 'STOP', dir_adj) if head is at the stop position
       for dir_adj, otherwise to (head, '-STOP', dir_adj):
           LA: first word           LN: second word
           RA: last word            RN: second-to-last word
    3. adds to f_CHOOSE(arg|head) a "harmonic" number: 1/distance between
       arg and head, plus HARMONIC_C

    corpus -- list of sentences, each a list of tag strings
    tags   -- every tag occurring in corpus (used to pre-create the keys)
    '''
    f = init_zeros(tags)

    for sent in corpus:  # sent is ['VBD', 'NN', ...]
        n = len(sent)
        # (dir_adj, position at which a head counts as STOP); every other
        # position counts as -STOP. For n < 2 the 'RN' position n - 2 is
        # negative and so never matches.
        # NOTE(review): confirm whether LN/RN should also count the first/
        # last word ("first or second"); only the second / second-to-last
        # position counts here.
        stop_positions = (('LN', 1), ('LA', 0), ('RN', n - 2), ('RA', n - 1))

        # NOTE: head in DMV_Rule is a number, while this is the string
        for loc_h, head in enumerate(sent):
            # todo grok: how is this different from just using straight head
            # frequency counts, for the ROOT probabilities?
            f['ROOT', head] += 1
            f['sum', 'ROOT'] += 1

            for dir_adj, stop_loc in stop_positions:
                if loc_h == stop_loc:
                    f[head, 'STOP', dir_adj] += 1
                else:
                    f[head, '-STOP', dir_adj] += 1

            # this is where we make the "harmonic" distribution. quite.
            for loc_a, arg in enumerate(sent):
                if loc_h == loc_a:
                    continue
                harmony = 1.0 / abs(loc_h - loc_a) + HARMONIC_C
                direction = 'L' if loc_h > loc_a else 'R'
                args = f[head, direction]
                args[arg] = args.get(arg, 0.0) + harmony
                f[head, 'sum', direction] += harmony
                # todo, optimization: possible to do both directions
                # at once here, and later on rule out the ones we've
                # done? does it actually speed things up?

    return f
def init_normalize(f, tags, tagnum, numtag):
    '''Use the frequencies (and sums) in f to create p_ROOT, p_STOP and
    p_CHOOSE, at the same time adding the context-free rules to the
    grammar using these probabilities.

    f      -- frequency dictionary as built by init_freq()
    tags   -- list of tag strings; a tag's index in it is its number n_h
    tagnum -- dict mapping tag string -> tag number
    numtag -- dict mapping tag number -> tag string

    Return a usable grammar (a DMV_Grammar).'''
    p_rules = []
    p_STOP, p_ROOT, p_CHOOSE, p_terminals = {}, {}, {}, {}
    for n_h, head in enumerate(tags):
        p_ROOT[n_h] = float(f['ROOT', head]) / f['sum', 'ROOT']
        p_rules.append(DMV_Rule(ROOT, STOP, (SEAL, n_h),
                                p_ROOT[n_h],
                                p_ROOT[n_h]))

        # p_STOP = STOP / (STOP + NOT_STOP)
        for direction in ('L', 'R'):
            for adj in ('N', 'A'):
                dir_adj = direction + adj
                n_stop = f[head, 'STOP', dir_adj]
                p_STOP[n_h, dir_adj] = \
                    float(n_stop) / (n_stop + f[head, '-STOP', dir_adj])
        # make rules using the just-computed non-adjacent (probN) and
        # adjacent (probA) stop probabilities:
        p_rules.append(DMV_Rule((RGO_L, n_h), (GO_R, n_h), STOP,
                                p_STOP[n_h, 'RN'],
                                p_STOP[n_h, 'RA']))
        p_rules.append(DMV_Rule((SEAL, n_h), STOP, (RGO_L, n_h),
                                p_STOP[n_h, 'LN'],
                                p_STOP[n_h, 'LA']))

        p_terminals[(GO_R, n_h), head] = 1.0
        # Terminals for (RGO_L, n_h) and (SEAL, n_h) (weighted by the
        # adjacent stop probabilities) would spare inner() the long
        # non-branching STOP chains, but as added rules they made things
        # slower (2.77s with, 1.87s without), so only GO_R gets one.

        for direction in ('L', 'R'):
            total = f[head, 'sum', direction]
            for arg, val in f[head, direction].items():
                p_CHOOSE[tagnum[arg], n_h, direction] = float(val) / total

    # after the head tag-loop, add every head-argument rule. The head only
    # takes an argument when it does not stop, hence the (1 - p_STOP) factor.
    for (n_a, n_h, direction), p_C in p_CHOOSE.items():
        p_go_N = p_C * (1 - p_STOP[n_h, direction + 'N'])
        p_go_A = p_C * (1 - p_STOP[n_h, direction + 'A'])
        if direction == 'L':  # arg is to the left of head
            rule = DMV_Rule((RGO_L, n_h), (SEAL, n_a), (RGO_L, n_h),
                            p_go_N, p_go_A)
        else:  # direction == 'R': arg is to the right of head
            rule = DMV_Rule((GO_R, n_h), (GO_R, n_h), (SEAL, n_a),
                            p_go_N, p_go_A)
        p_rules.append(rule)

    return DMV_Grammar(p_rules, p_terminals, p_STOP, p_CHOOSE, p_ROOT,
                       numtag, tagnum)
def initialize(corpus):
    '''Build and return an initialized DMV_Grammar for corpus.

    corpus is a list of lists of tags, e.g. [['DT', 'NN', 'VBD'], ...].
    Tags are numbered by their position in taglist(corpus); both
    directions of that numbering (tag -> number and number -> tag) are
    handed on to the grammar.'''
    tags = taglist(corpus)
    tagnum = dict((tag, num) for num, tag in enumerate(tags))
    numtag = dict(enumerate(tags))
    # frequency counts used in initialization, mostly distances
    freqs = init_freq(corpus, tags)
    return init_normalize(freqs, tags, tagnum, numtag)
if __name__ == "__main__":
    # todo: grok why there's so little difference in probN and probA values

    print("--------initialization testing------------")
    print(initialize([['foo', 'two', 'foo', 'foo'],
                      ['zero', 'one', 'two', 'three']]))

    # Try the harmonic initializer under different STOP pseudo-counts by
    # rebinding the module-level constants init_zeros() reads.
    for (nonstop_min, stop_min) in [(95, 5), (5, 5)]:
        FNONSTOP_MIN = nonstop_min
        FSTOP_MIN = stop_min

        sentences = ['det nn vbd c nn vbd nn', 'det nn vbd c nn vbd pp nn',
                     'det nn vbd nn', 'det nn vbd c nn vbd pp nn',
                     'det nn vbd nn', 'det nn vbd c nn vbd pp nn',
                     'det nn vbd nn', 'det nn vbd c nn vbd pp nn',
                     'det nn vbd nn', 'det nn vbd c nn vbd pp nn',
                     'det nn vbd pp nn', 'det nn vbd det nn']
        testcorpus = [sentence.split() for sentence in sentences]
        g = initialize(testcorpus)

        # Sum non-adjacent (N) / adjacent (A) probabilities separately for
        # rules with a STOP child and for all other (rewrite) rules.
        stopn = nstopn = nstopa = stopa = rewriten = rewritea = 0.0
        for rule in g.all_rules():
            if rule.L() == STOP or rule.R() == STOP:
                stopn += rule.probN
                nstopa += 1 - rule.probA
                nstopn += 1 - rule.probN
                stopa += rule.probA
            else:
                rewriten += rule.probN
                rewritea += rule.probA
        print("sn:%.2f (nsn:%.2f) sa:%.2f (nsa:%.2f) rn:%.2f ra:%.2f"
              % (stopn, nstopn, stopa, nstopa, rewriten, rewritea))
def tagset_brown():
    '''Return the Brown corpus tag set (472 tags) as a new set.

    Extracting these with tagset() takes a while, so they are hardcoded
    here. A fresh set is built on every call.'''
    brown_tags = [
        'BEDZ-NC', 'NP$', 'AT-TL', 'CS', 'NP+HVZ', 'IN-TL-HL', 'NR-HL', 'CC-TL-HL',
        'NNS$-HL', 'JJS-HL', 'JJ-HL', 'WRB-TL', 'JJT-TL', 'WRB', 'DOD*', 'BER*-NC',
        ')-HL', 'NPS$-HL', 'RB-HL', 'FW-PPSS', 'NP+HVZ-NC', 'NNS$', '--', 'CC-TL',
        'FW-NN-TL', 'NP-TL-HL', 'PPSS+MD', 'NPS', 'RBR+CS', 'DTI', 'NPS-TL', 'BEM',
        'FW-AT+NP-TL', 'EX+BEZ', 'BEG', 'BED', 'BEZ', 'DTX', 'DOD*-TL', 'FW-VB-NC',
        'DTS', 'DTS+BEZ', 'QL-HL', 'NP$-TL', 'WRB+DOD*', 'JJR+CS', 'NN+MD', 'NN-TL-HL',
        'HVD-HL', 'NP+BEZ-NC', 'VBN+TO', '*-TL', 'WDT-HL', 'MD', 'NN-HL', 'FW-BE',
        'DT$', 'PN-TL', 'DT-HL', 'FW-NR-TL', 'VBG', 'VBD', 'VBN', 'DOD',
        'FW-VBG-TL', 'DOZ', 'ABN-TL', 'VB+JJ-NC', 'VBZ', 'RB+CS', 'FW-PN', 'CS-NC',
        'VBG-NC', 'BER-HL', 'MD*', '``', 'WPS-TL', 'OD-TL', 'PPSS-HL', 'PPS+MD',
        'DO*', 'DO-HL', 'HVG-HL', 'WRB-HL', 'JJT', 'JJS', 'JJR', 'HV+TO',
        'WQL', 'DOD-NC', 'CC-HL', 'FW-PPSS+HV', 'FW-NP-TL', 'MD+TO', 'VB+IN', 'JJT-NC',
        'WDT+BEZ-TL', '---HL', 'PN$', 'VB+PPO', 'BE-TL', 'VBG-TL', 'NP$-HL', 'VBZ-TL',
        'UH', 'FW-WPO', 'AP+AP-NC', 'FW-IN', 'NRS-TL', 'ABL', 'ABN', 'TO-TL',
        'ABX', '*-HL', 'FW-WPS', 'VB-NC', 'HVD*', 'PPS+HVD', 'FW-IN+AT', 'FW-NP',
        'QLP', 'FW-NR', 'FW-NN', 'PPS+HVZ', 'NNS-NC', 'DT+BEZ-NC', 'PPO', 'PPO-NC',
        'EX-HL', 'AP$', 'OD-NC', 'RP', 'WPS+BEZ', 'NN+BEZ', '.-TL', ',',
        'FW-DT+BEZ', 'RB', 'FW-PP$-NC', 'RN', 'JJ$-TL', 'MD-NC', 'VBD-NC', 'PPSS+BER-N',
        'RB+BEZ-NC', 'WPS-HL', 'VBN-NC', 'BEZ-HL', 'PPL-NC', 'BER-TL', 'PP$$', 'NNS+MD',
        'PPS-NC', 'FW-UH-NC', 'PPS+BEZ-NC', 'PPSS+BER-TL', 'NR-NC', 'FW-JJ', 'PPS+BEZ-HL', 'NPS$',
        'RB-TL', 'VB-TL', 'BEM*', 'MD*-HL', 'FW-CC', 'NP+MD', 'EX+HVZ', 'FW-CD',
        'EX+HVD', 'IN-HL', 'FW-CS', 'JJR-HL', 'FW-IN+NP-TL', 'JJ-TL-HL', 'FW-UH', 'EX',
        'FW-NNS-NC', 'FW-JJ-NC', 'VBZ-HL', 'VB+RP', 'BEZ-NC', 'PPSS+HV-TL', 'HV*', 'IN',
        'PP$-NC', 'NP-NC', 'BEN', 'PP$-TL', 'FW-*-TL', 'FW-OD-TL', 'WPS', 'WPO',
        'MD+PPSS', 'WDT+BER', 'WDT+BEZ', 'CD-HL', 'WDT+BEZ-NC', 'WP$', 'DO+PPSS', 'HV-HL',
        'DT-NC', 'PN-NC', 'FW-VBZ', 'HVD', 'HVG', 'NN+BEZ-TL', 'HVZ', 'FW-VBD',
        'FW-VBG', 'NNS$-TL', 'JJ-TL', 'FW-VBN', 'MD-TL', 'WDT+DOD', 'HV-TL', 'NN-TL',
        'PPSS', 'NR$', 'BER', 'FW-VB', 'DT', 'PN+BEZ', 'VBG-HL', 'FW-PPL+VBZ',
        'FW-NPS-TL', 'RB$', 'FW-IN+NN', 'FW-CC-TL', 'RBT', 'RBR', 'PPS-TL', 'PPSS+HV',
        'JJS-TL', 'NPS-HL', 'WPS+BEZ-TL', 'NNS-TL-HL', 'VBN-TL-NC', 'QL-TL', 'NN+NN-NC', 'JJR-TL',
        'NN$-TL', 'FW-QL', 'IN-TL', 'BED-NC', 'NRS', '.-HL', 'QL', 'PP$-HL',
        'WRB+BER', 'JJ', 'WRB+BEZ', 'NNS$-TL-HL', 'PPSS+BEZ', '(', 'PPSS+BER', 'DT+MD',
        'DOZ-TL', 'PPSS+BEM', 'FW-PP$', 'RB+BEZ-HL', 'FW-RB+CC', 'FW-PPS', 'VBG+TO', 'DO*-HL',
        'NR+MD', 'PPLS', 'IN+IN', 'BEZ*', 'FW-PPL', 'FW-PPO', 'NNS-HL', 'NIL',
        'HVN', 'PPSS+BER-NC', 'AP-TL', 'FW-DT', '(-HL', 'DTI-TL', 'JJ+JJ-NC', 'FW-RB',
        'FW-VBD-TL', 'BER-NC', 'NNS$-NC', 'JJ-NC', 'NPS$-TL', 'VB+VB-NC', 'PN', 'VB+TO',
        'AT-TL-HL', 'BEM-NC', 'PPL-TL', 'ABN-HL', 'RB-NC', 'DO-NC', 'BE-HL', 'WRB+IN',
        'FW-UH-TL', 'PPO-HL', 'FW-CD-TL', 'TO-HL', 'PPS+BEZ', 'CD$', 'DO', 'EX+MD',
        'HVZ-TL', 'TO-NC', 'IN-NC', '.', 'WRB+DO', 'CD-NC', 'FW-PPO+IN', 'FW-NN$-TL',
        'WDT+BEZ-HL', 'RP-HL', 'CC', 'NN+HVZ-TL', 'FW-NNS-TL', 'DT+BEZ', 'WPS+HVZ', 'BEDZ*',
        'NP-TL', ':-TL', 'NN-NC', 'WPO-TL', 'QL-NC', 'FW-AT+NN-TL', 'WDT+HVZ', '.-NC',
        'FW-DTS', 'NP-HL', ':-HL', 'RBR-NC', 'OD-HL', 'BEDZ-HL', 'VBD-TL', 'NPS-NC',
        ')', 'TO+VB', 'FW-IN+NN-TL', 'PPL', 'PPS', 'PPSS+VB', 'DT-TL', 'RP-NC',
        'VB', 'FW-VB-TL', 'PP$', 'VBD-HL', 'DTI-HL', 'NN-TL-NC', 'PPL-HL', 'DOZ*',
        'NR-TL', 'WRB+MD', 'PN+HVZ', 'FW-IN-TL', 'PN+HVD', 'BEN-TL', 'BE', 'WDT',
        'WPS+HVD', 'DO-TL', 'FW-NN-NC', 'WRB+BEZ-TL', 'UH-TL', 'JJR-NC', 'NNS', 'PPSS-NC',
        'WPS+BEZ-NC', ',-TL', 'NN$', 'VBN-TL-HL', 'WDT-NC', 'OD', 'FW-OD-NC', 'DOZ*-TL',
        'PPSS+HVD', 'CS-TL', 'WRB+DOZ', 'CC-NC', 'HV', 'NN$-HL', 'FW-WDT', 'WRB+DOD',
        'NN+HVZ', 'AT-NC', 'NNS-TL', 'FW-BEZ', 'CS-HL', 'WPO-NC', 'FW-BER', 'NNS-TL-NC',
        'BEZ-TL', 'FW-IN+AT-T', 'ABN-NC', 'NR-TL-HL', 'BEDZ', 'NP+BEZ', 'FW-AT-TL', 'BER*',
        'WPS+MD', 'MD-HL', 'BED*', 'HV-NC', 'WPS-NC', 'VBN-HL', 'FW-TO+VB', 'PPSS+MD-NC',
        'HVZ*', 'PPS-HL', 'WRB-NC', 'VBN-TL', 'CD-TL-HL', ',-NC', 'RP-TL', 'AP-HL',
        'FW-HV', 'WQL-TL', 'FW-AT', 'NN', 'NR$-TL', 'VBZ-NC', '*', 'PPSS-TL',
        'JJT-HL', 'FW-NNS', 'NP', 'UH-HL', 'NR', ':', 'FW-NN$', 'RP+IN',
        ',-HL', 'JJ-TL-NC', 'AP-NC', '*-NC', 'VB-HL', 'HVZ-NC', 'DTS-HL', 'FW-JJT',
        'FW-JJR', 'FW-JJ-TL', 'FW-*', 'RB+BEZ', "''", 'VB+AT', 'PN-HL', 'PPO-TL',
        'CD-TL', 'UH-NC', 'FW-NN-TL-NC', 'EX-NC', 'PPSS+BEZ*', 'TO', 'WDT+DO+PPS', 'IN+PPO',
        'AP', 'AT', 'DOZ-HL', 'FW-RB-TL', 'CD', 'NN+IN', 'FW-AT-HL', 'PN+MD',
        "'", 'FW-PP$-TL', 'FW-NPS', 'WDT+BER+PP', 'NN+HVD-TL', 'MD+HV', 'AT-HL', 'FW-IN+AT-TL',
    ]
    return set(brown_tags)