report almost done
[dmvccm.git] / src / io.py.~15~
blob2b20ad7d5a1e69547675652523ee7e055031e765
1 ##################################################################
2 #                            Changes:                            #
3 ##################################################################
4 # 2008-05-24, KBU:
5 # - more elaborate pseudo-code in the io()-function
7 # 2008-05-25, KBU:
8 # - CNF_Rule has a function LHS() which returns head, in DMV_Rule this
9 #   returns the pair of (bars, head). 
11 # 2008-05-27
12 # - CNF_Rule has __eq__ and __ne__ defined, so that we can use == and 
13 #   != on two such rules
16 # import numpy # numpy provides Fast Arrays, for future optimization
17 # import pprint # for pretty-printing
19 DEBUG = 0
21 # some of dmv-module bleeding in here... todo: prettier (in inner())
22 NOBAR = 0
23 STOP = (NOBAR, -2) 
25 def debug(string):
26     '''Easily turn on/off inline debug printouts with this global. There's
27 a lot of cluttering debug statements here, todo: clean up'''
28     if DEBUG:
29         print string
32 class Grammar():
33     '''The PCFG used in the I/O-algorithm.
35     Public members:
36     p_terminals
37     
38     Todo: as of now, this allows duplicate rules... should we check
39     for this?  (eg. g = Grammar([x,x],[]) where x.prob == 1 may give
40     inner probabilities of 2.)'''
41     def all_rules(self):
42         return self.__p_rules
43     
44     def rules(self, LHS):
45         return [rule for rule in self.all_rules() if rule.LHS() == LHS]
46     
47     def numtag(self):
48         return __numtag
50     def tagnum(self):
51         return __tagnum
53     def __init__(self, p_rules, p_terminals, numtag, tagnum):
54         '''rules and p_terminals should be arrays, where p_terminals are of
55         the form [preterminal, terminal], and rules are CNF_Rule's.'''        
56         self.__p_rules = p_rules # todo: could check for summing to 1 (+/- epsilon)
57         self.__numtag = numtag
58         self.__tagnum = tagnum
59         self.p_terminals = p_terminals
64 class CNF_Rule():
65     '''A single CNF rule in the PCFG, of the form 
66     LHS -> L R
67     where these are just integers
68     (where do we save the connection between number and symbol?
69     symbols being 'vbd' etc.)'''
70     def __eq__(self, other):
71         return self.LHS() == other.LHS() and self.R() == other.R() and self.L() == other.L()
72     def __ne__(self, other):
73         return self.LHS() != other.LHS() or self.R() != other.R() or self.L() != other.L()
74     def __str__(self):
75         return "%s -> %s %s [%.2f]" % (self.LHS(), self.L(), self.R(), self.prob)
76     def __init__(self, LHS, L, R, prob):
77         self.__LHS = LHS
78         self.__R = R
79         self.__L = L
80         self.prob = prob
81     def p(self, *arg):
82         "Return a probability, doesn't care about attachment..."
83         return self.prob
84     def LHS(self):
85         return self.__LHS
86     def L(self):
87         return self.__L
88     def R(self):
89         return self.__R
90     
91 def inner(s, t, LHS, g, sent, chart):
92     ''' Give the inner probability of having the node LHS cover whatever's
93     between s and t in sentence sent, using grammar g.
95     Returns a pair of the inner probability and the chart
97     For DMV, LHS is a pair (bar, h), but this function ought to be
98     agnostic about that.
100     e() is an internal function, so the variable chart (a dictionary)
101     is available to all calls of e().
103     Since terminal probabilities are just simple lookups, they are not
104     put in the chart (although we could put them in there later to
105     optimize)
106     '''
107     
108     def O(s):
109         return sent[s]
110     
111     def e(s,t,LHS):
112         '''Chart has lists of probability and whether or not we've attached
113 yet to L and R, each entry is a list [p, Rattach, Lattach], where if
114 Rattach==True then the rule has a right-attachment or there is one
115 lower in the tree (meaning we're no longer adjacent).'''
116         if (s, t, LHS) in chart:
117             return chart[(s, t, LHS)]
118         else:
119             debug( "trying from %d to %d with %s" % (s,t,LHS) )
120             if s == t:
121                 if (LHS, O(s)) in g.p_terminals:
122                     prob = g.p_terminals[LHS, O(s)] # b[LHS, O(s)]
123                 else:
124                     prob = 0.0 # todo: is this the right way to deal with lacking rules?
125                     print "\t LACKING TERMINAL:"
126                 debug( "\t terminal: %s -> %s : %.1f" % (LHS, O(s), prob) )
127                 # terminals have no attachment
128                 return prob
129             else:
130                 if (s,t,LHS) not in chart:
131                     # by default, not attachment yet
132                     chart[(s,t,LHS)] = 0.0 #, False, False]
133                 for rule in g.rules(LHS): # summing over j,k in a[LHS,j,k]
134                     debug( "\tsumming rule %s" % rule ) 
135                     L = rule.L()
136                     R = rule.R()
137                     for r in range(s, t): # summing over r = s to r = t-1
138                         p_L = e(s, r, L)
139                         p_R = e(r + 1, t, R)
140                         p = rule.p("todo","todo") 
141                         chart[(s, t, LHS)] += p * p_L * p_R
142                 debug( "\tchart[(%d,%d,%s)] = %.2f" % (s,t,LHS, chart[(s,t,LHS)]) )
143                 return chart[(s, t, LHS)]
144     # end of e-function
145     
146     inner_prob = e(s,t,LHS)
147     if DEBUG:
148         print "---CHART:---"
149         for k,v in chart.iteritems():
150             print "\t%s -> %s_%d ... %s_%d : %.1f" % (k[2], O(k[0]), k[0], O(k[1]), k[1], v)
151         print "---CHART:end---"
152     return [inner_prob, chart] # inner_prob == chart[(s,t,LHS)] 
161 if __name__ == "__main__":
162     print "IO-module tests:"
163     b = {}
164     s   = CNF_Rule(0,1,2, 1.0) # s->np vp
165     np  = CNF_Rule(1,3,4, 0.3) # np->n p
166     b[1, 'n'] = 0.7 # np->'n'
167     b[3, 'n'] = 1.0 # n->'n'
168     b[4, 'p'] = 1.0 # p->'p'
169     vp  = CNF_Rule(2,5,1, 0.1) # vp->v np (two parses use this rule)
170     vp2 = CNF_Rule(2,2,4, 0.9) # vp->vp p
171     b[5, 'v'] = 1.0 # v->'v'
172     
173     g = Grammar([s,np,vp,vp2], b, {0:'s',1:'np',2:'vp',3:'n',4:'p',5:'v'})
174     
175     print "The rules:"
176     for i in range(0,5):
177         for r in g.rules(i):
178             print r
179     print ""
180     
181     test1 = inner(0,0, 1, g, ['n'], {})
182     if test1[0] != 0.7:
183         print "should be 0.70 : %.2f" % test1[0]
184         print ""
185     
186     DEBUG = 1
187     test2 = inner(0,2, 2, g, ['v','n','p'], test1[1])
188     print "should be 0.?? (.09??) : %.2f" % test2[0]
189     print "------ trying the same again:----------"
190     test2 = inner(0,2, 2, g, ['v','n','p'], test2[1])
191     print "should be 0.?? (.09??) : %.2f" % test2[0]
193     
194 ##################################################################
195 #            just junk from here on down:                        #
196 ##################################################################
198 # def io(corpus):
199 #     "(pseudo-code / wishful thinking) "
200 #     g = initialize(corpus) # or corpus.tagset ?
201     
202 #     P = {('v','n','p'):0.09}
203 #     # P is used in v_q, w_q (expectation), so each sentence in the
204 #     # corpus needs some initial P.
205     
206 #     # --- Maximation: ---
207 #     #
208 #     # actually, this step (from Lari & Young) probably never happens
209 #     # with DMV, since instead of the a[i,j,k] and b[i,m] vectors, we
210 #     # have P_STOP and P_CHOOSE... or, in a sense it happens only we
211 #     # calculate P_STOP and P_CHOOSE..somehow.
212 #     for rule in g.p_rules:
213 #         rule.num = 0
214 #         rule.den = 0
215 #     for pre_term in range(len(g.p_terminals)): 
216 #         ptnum[pre_term] = 0
217 #         ptden[pre_term] = 0
219 #     # we could also flip this to make rules the outer loop, then we
220 #     # wouldn't have to initialize den/num in loops of their own
221 #     for sent in corpus:
222 #         for rule in g.p_rules # Equation 20
223 #             for s in range(len(sent)):
224 #                 for t in range(s, len(sent)):
225 #                     rule.num += w(s,t, rule.LHS(),rule.L,rule.R, g, sent, P[sent])
226 #                     rule.den += v(s,t, rule.LHS(), g, sent, P[sent])
227 #             # todo: do we need a "new-prob" vs "old-prob" distinction here?
228 #                     probably, since we use inner/outer which checks rule.prob()
229 #             # todo: also, this wouldn't work, since for each sentence, we'd
230 #             #       discard the old probability; should rules be the outer
231 #             #       loop then?
232 #             rule.prob = rule.num / rule.den
233 #         for pre_term in range(len(g.p_terminals)): # Equation 21
234 #             num = 0
235 #             den = 0
236 #             for s in range(len(sent)):
237 #                 for t in range(s, len(sent)):
238 #                     num += v(t,t,pre_term, g, sent, P[sent])
239 #                     den += v(s,t,pre_term, g, sent, P[sent])
241 #     for rule in g.rules:
242 #         rule.prob = rule.num / rule.den
243 #     for pre_term in range(len(g.p_terminals)): 
244 #         g.p_terminals[pre_term] = ptnum[pre_term] / ptden[pre_term]
247 #     # --- Expectation: ---
248 #     for sent in corpus: # Equation 11
249 #         inside = inner(0, len(sent), ROOT, g, sent)
250 #         P[sent] = inside[0]
251     
252 #     # todo: set inner.chart to {} again, how?
254 #     # todo: need a old-P new-P distinction to check if we're below
255 #     # threshold difference
256 #     return "todo"
258 # def w(s,t, LHS,L,R, g, sent, P_sent):
259 #     w = 0
260 #     rule = g.rule(LHS, L, R)
261 #     for r in range(s, t):
262 #         w += rule.prob() * inner(s,r, L, g, sent) * inner(r+1, t, R, g, sent) * outer(s,t,LHS,g,sent)
263 #     return w / P_sent
264         
265 # def v(s,t, LHS, g, sent, P_sent):
266 #     return ( inner(s,t, LHS, g, sent) * outer(s,t, LHS, g, sent) ) / P_sent
267