Somehow must get counts only from _completed_ trees.. how?
[dmvccm.git] / src / io.py
blob3ecdac38ab131896de4ecf9ffaf1612ddad387d1
1 ##################################################################
2 # Changes: #
3 ##################################################################
4 # 2008-05-24, KBU:
5 # - more elaborate pseudo-code in the io()-function
7 # 2008-05-25, KBU:
8 # - CNF_Rule has a function LHS() which returns head, in DMV_Rule this
9 # returns the pair of (bars, head).
11 # 2008-05-27
12 # - CNF_Rule has __eq__ and __ne__ defined, so that we can use == and
13 # != on two such rules
16 # import numpy # numpy provides Fast Arrays, for future optimization
17 # import pprint # for pretty-printing
# Debug levels that are switched on: debug() (and the chart dump in
# inner()) only print for levels that appear in this list.
DEBUG = list()

# Some of the DMV module bleeds in here (todo: prettier). Note that
# inner() itself does not reference these two names.
NOBAR = 0           # presumably the "no bar" level of a DMV (bars, head) symbol
STOP = (NOBAR, -2)  # presumably the DMV STOP symbol as a (bars, head) pair
def debug(string, level=1):
    '''Print string, but only when level is one of the enabled DEBUG levels.

    One global switch for the (many, cluttering -- todo: clean up)
    inline trace statements in this module.'''
    enabled = level in DEBUG
    if not enabled:
        return
    print(string)
class Grammar():
    '''The PCFG used in the I/O-algorithm.

    Public members:
    p_terminals -- dictionary mapping (preterminal, terminal) pairs to
                   their probability, e.g. p_terminals[3, 'n'] == 1.0

    Todo: as of now, this allows duplicate rules... should we check
    for this? (eg. a rule x given twice in p_rules, with x.prob == 1,
    may give inner probabilities of 2.)'''

    def __init__(self, p_rules, p_terminals, numtag, tagnum):
        '''p_rules is a list of CNF_Rule's.
        p_terminals is a dictionary mapping (preterminal, terminal) to a
        probability (see the class docstring).
        numtag maps symbol numbers to tag strings and tagnum maps tag
        strings back to numbers (dictionaries in this module's self-test).'''
        self.__p_rules = p_rules  # todo: could check for summing to 1 (+/- epsilon)
        self.__numtag = numtag
        self.__tagnum = tagnum
        self.p_terminals = p_terminals

    def all_rules(self):
        '''Return the rule list itself (not a copy).'''
        return self.__p_rules

    def rules(self, LHS):
        '''Return all rules whose left-hand side equals LHS, in list order.'''
        return [rule for rule in self.all_rules() if rule.LHS() == LHS]

    def numtag(self, num):
        '''Translate a symbol number into its tag.'''
        return self.__numtag[num]

    def tagnum(self, tag):
        '''Translate a tag into its symbol number.'''
        return self.__tagnum[tag]
class CNF_Rule():
    '''A single CNF rule in the PCFG, of the form
        LHS -> L R
    where these are just integers. The connection between these numbers
    and the tag symbols ('vbd' etc.) is kept by Grammar (numtag/tagnum).

    Two rules are equal when their LHS, L and R are equal; the
    probability is not compared. __hash__ agrees with that, so rules can
    be put in sets and used as dictionary keys.'''

    def __init__(self, LHS, L, R, prob):
        self.__LHS = LHS
        self.__R = R
        self.__L = L
        self.prob = prob

    def _rule_key(self):
        '''The (LHS, L, R) triple that identifies this rule.'''
        return (self.LHS(), self.L(), self.R())

    def __eq__(self, other):
        # Anything without the rule accessors is not a rule: let Python
        # fall back (so e.g. `rule == None` is False instead of raising
        # AttributeError).
        if not (hasattr(other, 'LHS') and hasattr(other, 'L')
                and hasattr(other, 'R')):
            return NotImplemented
        return self._rule_key() == (other.LHS(), other.L(), other.R())

    def __ne__(self, other):
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return NotImplemented
        return not equal

    def __hash__(self):
        # A class defining __eq__ without __hash__ is unhashable; equal
        # rules must hash equally, so hash the same triple __eq__ compares.
        return hash(self._rule_key())

    def __str__(self):
        return "%s -> %s %s [%.2f]" % (self.LHS(), self.L(), self.R(), self.prob)

    def p(self, *arg):
        "Return a probability, doesn't care about attachment..."
        return self.prob

    def LHS(self):
        return self.__LHS

    def L(self):
        return self.__L

    def R(self):
        return self.__R
def inner(s, t, LHS, g, sent, chart=None):
    '''Give the inner probability of having the node LHS cover whatever's
    between s and t (both inclusive) in sentence sent, using grammar g.

    Returns a list [inner_prob, chart]. chart is a dictionary mapping
    (s, t, LHS) to the inner probability of every multi-word span
    computed so far; pass it back in to reuse those results. If chart
    is None (the default), a fresh dictionary is used.
    NOTE(review): a chart is only meaningful for the sentence (and
    grammar) it was filled with -- don't reuse it across sentences.

    For DMV, LHS is a pair (bar, h), but this function ought to be
    agnostic about that.

    e() is an internal function, so the variable chart (a dictionary)
    is available to all calls of e().

    Since terminal probabilities are just simple lookups, they are not
    put in the chart (although we could put them in there later to
    optimize).'''
    if chart is None:
        chart = {}

    def O(s):
        # the word (terminal) at position s
        return sent[s]

    def e(s, t, LHS):
        '''Inner probability of LHS spanning s..t, memoised in chart.

        Chart entries are plain probabilities. (Also tracking whether we
        have attached to the left/right yet -- entries of the form
        [p, Rattach, Lattach] -- is still a todo.)'''
        if (s, t, LHS) in chart:
            return chart[(s, t, LHS)]
        debug("trying from %d to %d with %s" % (s, t, LHS))

        if s == t:
            # a single word: plain lookup of b[LHS, O(s)]
            if (LHS, O(s)) in g.p_terminals:
                prob = g.p_terminals[LHS, O(s)]
            else:
                prob = 0.0  # todo: is this the right way to deal with lacking rules?
                print("\t LACKING TERMINAL: %s -> %s" % (LHS, O(s)))
            debug("\t terminal: %s -> %s : %.1f" % (LHS, O(s), prob))
            # terminals have no attachment and are not memoised
            return prob

        # A longer span: sum over every rule LHS -> L R (j,k in a[LHS,j,k])
        # and every split point r = s .. t-1, with L covering s..r and R
        # covering r+1..t. Both sub-spans are strictly shorter than s..t,
        # so the recursion never revisits this chart cell.
        chart[(s, t, LHS)] = 0.0
        for rule in g.rules(LHS):
            debug("\tsumming rule %s" % rule)
            L = rule.L()
            R = rule.R()
            for r in range(s, t):
                p_L = e(s, r, L)
                p_R = e(r + 1, t, R)
                p = rule.p("todo", "todo")
                chart[(s, t, LHS)] += p * p_L * p_R
        debug("\tchart[(%d,%d,%s)] = %.2f" % (s, t, LHS, chart[(s, t, LHS)]))
        return chart[(s, t, LHS)]
    # end of e-function

    inner_prob = e(s, t, LHS)
    if 1 in DEBUG:
        print("---CHART:---")
        for k, v in chart.items():
            print("\t%s -> %s_%d ... %s_%d : %.1f" % (k[2], O(k[0]), k[0], O(k[1]), k[1], v))
        print("---CHART:end---")
    # for s < t, inner_prob == chart[(s, t, LHS)]; single words are not charted
    return [inner_prob, chart]
if __name__ == "__main__":
    # Self-test on a toy grammar over the symbols
    #   0:s  1:np  2:vp  3:n  4:p  5:v
    print("IO-module tests:")
    b = {}
    s = CNF_Rule(0, 1, 2, 1.0)    # s->np vp
    np = CNF_Rule(1, 3, 4, 0.3)   # np->n p
    b[1, 'n'] = 0.7               # np->'n'
    b[3, 'n'] = 1.0               # n->'n'
    b[4, 'p'] = 1.0               # p->'p'
    vp = CNF_Rule(2, 5, 1, 0.1)   # vp->v np (two parses use this rule)
    vp2 = CNF_Rule(2, 2, 4, 0.9)  # vp->vp p
    b[5, 'v'] = 1.0               # v->'v'

    g = Grammar([s, np, vp, vp2], b,
                {0: 's', 1: 'np', 2: 'vp', 3: 'n', 4: 'p', 5: 'v'},
                {'s': 0, 'np': 1, 'vp': 2, 'n': 3, 'p': 4, 'v': 5})

    print("The rules:")
    # List every rule (the old range(0,5) loop never looked at symbol 5).
    for r in g.all_rules():
        print(r)
    print("")

    test1 = inner(0, 0, 1, g, ['n'], {})
    if test1[0] != 0.7:
        print("should be 0.70 : %.2f" % test1[0])
    print("")

    DEBUG = [1]
    # Expected inner probability of vp over "v n p":
    #   vp -> v np, with np -> n p:           0.1 * 1.0 * (0.3 * 1.0 * 1.0) = 0.030
    #   vp -> vp p, with vp -> v np('n'):     0.9 * (0.1 * 1.0 * 0.7) * 1.0 = 0.063
    # total 0.093 (all other splits need rules/terminals the grammar lacks)
    test2 = inner(0, 2, 2, g, ['v', 'n', 'p'], test1[1])
    print("should be 0.093 : %.3f" % test2[0])
    print("------ trying the same again:----------")
    test2 = inner(0, 2, 2, g, ['v', 'n', 'p'], test2[1])
    print("should be 0.093 : %.3f" % test2[0])
195 ##################################################################
196 # just junk from here on down: #
197 ##################################################################
199 # def io(corpus):
200 # "(pseudo-code / wishful thinking) "
201 # g = initialize(corpus) # or corpus.tagset ?
203 # P = {('v','n','p'):0.09}
204 # # P is used in v_q, w_q (expectation), so each sentence in the
205 # # corpus needs some initial P.
207 # # --- Maximation: ---
209 # # actually, this step (from Lari & Young) probably never happens
210 # # with DMV, since instead of the a[i,j,k] and b[i,m] vectors, we
211 # # have P_STOP and P_CHOOSE... or, in a sense it happens only when we
212 # # calculate P_STOP and P_CHOOSE... somehow.
213 # for rule in g.p_rules:
214 # rule.num = 0
215 # rule.den = 0
216 # for pre_term in range(len(g.p_terminals)):
217 # ptnum[pre_term] = 0
218 # ptden[pre_term] = 0
220 # # we could also flip this to make rules the outer loop, then we
221 # # wouldn't have to initialize den/num in loops of their own
222 # for sent in corpus:
223 # for rule in g.p_rules # Equation 20
224 # for s in range(len(sent)):
225 # for t in range(s, len(sent)):
226 # rule.num += w(s,t, rule.LHS(),rule.L,rule.R, g, sent, P[sent])
227 # rule.den += v(s,t, rule.LHS(), g, sent, P[sent])
228 # # todo: do we need a "new-prob" vs "old-prob" distinction here?
229 # probably, since we use inner/outer which checks rule.prob()
230 # # todo: also, this wouldn't work, since for each sentence, we'd
231 # # discard the old probability; should rules be the outer
232 # # loop then?
233 # rule.prob = rule.num / rule.den
234 # for pre_term in range(len(g.p_terminals)): # Equation 21
235 # num = 0
236 # den = 0
237 # for s in range(len(sent)):
238 # for t in range(s, len(sent)):
239 # num += v(t,t,pre_term, g, sent, P[sent])
240 # den += v(s,t,pre_term, g, sent, P[sent])
242 # for rule in g.rules:
243 # rule.prob = rule.num / rule.den
244 # for pre_term in range(len(g.p_terminals)):
245 # g.p_terminals[pre_term] = ptnum[pre_term] / ptden[pre_term]
248 # # --- Expectation: ---
249 # for sent in corpus: # Equation 11
250 # inside = inner(0, len(sent), ROOT, g, sent)
251 # P[sent] = inside[0]
253 # # todo: set inner.chart to {} again, how?
255 # # todo: need a old-P new-P distinction to check if we're below
256 # # threshold difference
257 # return "todo"
259 # def w(s,t, LHS,L,R, g, sent, P_sent):
260 # w = 0
261 # rule = g.rule(LHS, L, R)
262 # for r in range(s, t):
263 # w += rule.prob() * inner(s,r, L, g, sent) * inner(r+1, t, R, g, sent) * outer(s,t,LHS,g,sent)
264 # return w / P_sent
266 # def v(s,t, LHS, g, sent, P_sent):
267 # return ( inner(s,t, LHS, g, sent) * outer(s,t, LHS, g, sent) ) / P_sent