loc_h_dmv reestimation conforms to formulas.pdf, but cnf_ version still differs somehow
[dmvccm.git] / src / io.py
blob0853544ce00955f0b1b372278760dff25bdae2e7
1 # io.py, with sentence positions between word locations (i<k<j)
# Number used in place of a real tag number for the 'ROOT' tag
# (Grammar.numtag and Grammar.tagnum translate between the two).
ROOTNUM = -1

# some of dmv-module bleeding in here... todo: prettier (in inner())
NOBAR = 0
# NOTE(review): presumably the DMV stop symbol written as a (bar, head)
# pair -- confirm against the dmv module.
STOP = (NOBAR, -2)

# The debug levels for which debug() actually prints its message.
DEBUG = set(('TODO',))
def debug(string, level='TODO'):
    '''Easily turn on/off inline debug printouts with the global DEBUG.

    string is printed only if level is a member of DEBUG; otherwise
    nothing happens.

    There's a lot of cluttering debug statements here, todo: clean up'''
    if level in DEBUG:
        # One parenthesised argument prints the same under Python 2 and 3.
        print(string)
class Grammar():
    '''The PCFG used in the I/O-algorithm.

    Public members:
    p_terminals -- terminal probabilities; inner() looks them up as
                   p_terminals[preterminal, terminal], so a dictionary
                   keyed by such pairs is what works there.

    Tags are referred to by number; numtag and tagnum (given to the
    constructor) translate between numbers and tag strings.

    Todo: as of now, this allows duplicate rules... should we check
    for this? (eg. g = Grammar([x,x],[]) where x.prob == 1 may give
    inner probabilities of 2.)'''

    def __init__(self, numtag, tagnum, p_rules=None, p_terminals=None):
        '''numtag maps tag numbers to tags and tagnum maps tags to tag
        numbers. p_rules is a list of CNF_Rule's. p_terminals holds the
        terminal probabilities (see the class docstring).

        Leaving out p_rules or p_terminals gives this grammar a fresh
        empty list of its own, never one shared with other grammars.'''
        if p_rules is None:
            p_rules = []
        if p_terminals is None:
            p_terminals = []
        self.__numtag = numtag
        self.__tagnum = tagnum
        # Snapshot of the tag numbers at construction time.
        self.__head_nums = list(numtag)
        self.__p_rules = p_rules # todo: could check for summing to 1 (+/- epsilon)
        self.p_terminals = p_terminals

    def all_rules(self):
        '''Return the list of all rules (the list itself, not a copy).'''
        return self.__p_rules

    def new_rules(self, p_rules):
        '''Replace the list of rules with p_rules.'''
        self.__p_rules = p_rules

    def rules(self, LHS):
        '''Return a list of the rules whose left-hand side equals LHS.'''
        return [rule for rule in self.all_rules() if rule.LHS() == LHS]

    def headnums(self):
        '''Return the list of tag numbers known when the grammar was
        created (ROOTNUM is only in there if numtag contained it).'''
        return self.__head_nums

    def sent_nums(self, sent):
        '''Return the tag numbers of the tags in sent, in order.'''
        return [self.tagnum(tag) for tag in sent]

    def numtag(self, num):
        '''Return the tag for the tag number num.'''
        if num == ROOTNUM: # don't want these added to headnums (which we iter through)
            return 'ROOT'
        else:
            return self.__numtag[num]

    def tagnum(self, tag):
        '''Return the tag number for tag.'''
        if tag == 'ROOT':
            return ROOTNUM
        else:
            return self.__tagnum[tag]

    def get_nums_tags(self):
        '''Return the pair (numtag, tagnum) of translation tables.'''
        return (self.__numtag, self.__tagnum)
class CNF_Rule():
    '''A single CNF rule in the PCFG, of the form
    LHS -> L R
    where these are just integers
    (where do we save the connection between number and symbol?
    symbols being 'vbd' etc.)

    Two rules are equal when LHS, L and R are equal; the probability
    is not part of the comparison (nor of the hash).'''

    def __init__(self, LHS, L, R, prob):
        self.__LHS = LHS
        self.__R = R
        self.__L = L
        self.prob = prob

    def __eq__(self, other):
        try:
            return (self.LHS() == other.LHS()
                    and self.R() == other.R()
                    and self.L() == other.L())
        except AttributeError:
            # other is not rule-like (e.g. None): let Python fall back
            # to its default comparison instead of crashing.
            return NotImplemented

    def __ne__(self, other):
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        # Has to agree with __eq__: equal rules must hash equal.
        return hash((self.LHS(), self.L(), self.R()))

    def __str__(self):
        return "%s -> %s %s [%.2f]" % (self.LHS(), self.L(), self.R(), self.prob)

    def p(self, *arg):
        "Return a probability, doesn't care about attachment..."
        # Any arguments are accepted and ignored.
        return self.prob

    def LHS(self):
        '''Return the left-hand side of the rule.'''
        return self.__LHS

    def L(self):
        '''Return the left child.'''
        return self.__L

    def R(self):
        '''Return the right child.'''
        return self.__R
def inner(i, j, LHS, g, sent, chart):
    '''Give the inner probability of having the node LHS cover whatever's
    between positions i and j in sentence sent, using grammar g.

    Positions lie between the words (i<k<j), so the span (i, i+1) is
    the single word sent[i].

    Arguments:
    i, j  -- start and end position of the span
    LHS   -- the node to expand. For DMV, LHS is a pair (bar, h), but
             this function ought to be agnostic about that.
    g     -- the grammar; g.rules(LHS) and g.p_terminals are used
    sent  -- the sentence, indexed by word position
    chart -- dictionary from (i, j, LHS) to inner probability. It is
             filled in place. The keys hold no words, so a chart that
             already has entries is only valid for the same sentence
             and grammar.

    Returns the list [inner probability, chart].

    e() is an internal function, so the variable chart (a dictionary)
    is available to all calls of e().

    Since terminal probabilities are just simple lookups, they are not
    put in the chart (although we could put them in there later to
    optimize)
    '''

    def O(i, j):
        # The word covered by the one-word span (i, j).
        return sent[i]

    def word_at(pos):
        # Only used by the chart dump: must never raise, since the chart
        # may hold spans that don't fit inside sent.
        if 0 <= pos < len(sent):
            return sent[pos]
        return '?'

    def e(i, j, LHS):
        '''Return the inner probability of LHS covering i to j. Results
        for spans longer than one word are stored in chart, as a plain
        probability under the key (i, j, LHS).'''
        if (i, j, LHS) in chart:
            return chart[i, j, LHS]

        debug( "trying from %d to %d with %s" % (i,j,LHS) , "IO")
        if i+1 == j:
            if (LHS, O(i,j)) in g.p_terminals:
                prob = g.p_terminals[LHS, O(i,j)] # b[LHS, O(s)] in L&Y
            else:
                prob = 0.0
                print("\t LACKING TERMINAL:%s -> %s : %.1f" % (LHS, O(i,j), prob))
            debug( "\t terminal: %s -> %s : %.1f" % (LHS, O(i,j), prob) ,"IO")
            # terminals are not put in the chart
            return prob

        # Start from zero and add up every rule and every split point.
        chart[i, j, LHS] = 0.0
        for rule in g.rules(LHS): # summing over rules headed by LHS, "a[i,j,k]"
            debug( "\tsumming rule %s" % rule , "IO")
            L = rule.L()
            R = rule.R()
            for k in range(i+1, j): # i<k<j
                p_L = e(i, k, L)
                p_R = e(k, j, R)
                chart[i, j, LHS] += rule.p() * p_L * p_R
        debug( "\tchart[%d,%d,%s] = %.2f" % (i,j,LHS, chart[i,j,LHS]) ,"IO")
        return chart[i, j, LHS]
    # end of e-function

    inner_prob = e(i,j,LHS)
    if 'IO' in DEBUG:
        print("---CHART:---")
        for (start, end, node), prob in chart.items():
            # A span (start, end) covers the words start .. end-1; show
            # the first and the last of them with their word indices.
            print("\t%s -> %s_%d ... %s_%d : %.1f" % (
                node, word_at(start), start, word_at(end - 1), end - 1, prob))
        print("---CHART:end---")
    return [inner_prob, chart]
if __name__ == "__main__":
    # Small self-test: silent (after the headline) when inner() gives
    # the expected probabilities on a toy grammar.
    print("IO-module tests:")
    # Probabilities are sums of float products, so compare with a
    # tolerance instead of ==.
    EPSILON = 1e-9

    b = {}
    s   = CNF_Rule(0,1,2, 1.0) # s->np vp
    np  = CNF_Rule(1,3,4, 0.3) # np->n p
    b[1, 'n'] = 0.7 # np->'n'
    b[3, 'n'] = 1.0 # n->'n'
    b[4, 'p'] = 1.0 # p->'p'
    vp  = CNF_Rule(2,5,1, 0.1) # vp->v np (two parses use this rule)
    vp2 = CNF_Rule(2,2,4, 0.9) # vp->vp p
    b[5, 'v'] = 1.0 # v->'v'

    g = Grammar({0:'s',1:'np',2:'vp',3:'n',4:'p',5:'v'},
                {'s':0,'np':1,'vp':2,'n':3,'p':4,'v':5},
                [s,np,vp,vp2], b)

    # print "The rules:"
    # for i in range(0,5):
    #     for r in g.rules(i):
    #         print r
    # print ""

    test1 = inner(0,1, 1, g, ['n'], {})
    if abs(test1[0] - 0.7) > EPSILON:
        print("should be 0.70 : %.3f" % test1[0])
        print("")

    test2 = inner(0,3, 2, g, ['v','n','p'], test1[1])
    if abs(test2[0] - 0.0930) > EPSILON:
        print("should be 0.0930 : %.4f" % test2[0])
    # Same question again, this time answered from the filled-in chart.
    test2 = inner(0,3, 2, g, ['v','n','p'], test2[1])
    if abs(test2[0] - 0.0930) > EPSILON:
        print("should be 0.0930 : %.4f" % test2[0])