# Source: dmvccm.git / src / io.py (blob 7eec28f1a0360f2d99bf825a39a4acffe428dbdc)
# io.py, with sentence positions between word locations (i<k<j)

# Debug levels that are switched on: debug() prints a message only when
# the level it is given is a member of this set.
DEBUG = {'TODO'}

# some of dmv-module bleeding in here... todo: prettier (in inner())
ROOTNUM = -1         # tag number that Grammar maps to and from 'ROOT'
NOBAR = 0
STOP = (NOBAR, -2)   # a (bar, head) style pair built from NOBAR
def debug(string, level='TODO'):
    '''Easily turn on/off inline debug printouts with the global DEBUG.

    string is printed to standard output only if level is a member of
    the module-level DEBUG set; otherwise the call does nothing.

    There's a lot of cluttering debug statements here, todo: clean up'''
    if level in DEBUG:
        # Parenthesised single-argument form: prints the same text under
        # the Python 2 print statement and the Python 3 print function.
        print(string)
class Grammar():
    '''The PCFG used in the I/O-algorithm.

    Public members:
    p_terminals -- terminal probabilities.  inner() tests membership
        with (preterminal, terminal) and then indexes with the same
        pair, so in practice this is a dictionary keyed by such pairs.

    Todo: as of now, this allows duplicate rules... should we check
    for this? (eg. g = Grammar([x,x],[]) where x.prob == 1 may give
    inner probabilities of 2.)'''

    def __init__(self, numtag, tagnum, p_rules=None, p_terminals=None):
        '''Set up the grammar.

        numtag -- dictionary from tag number to tag string
        tagnum -- dictionary from tag string to tag number
        p_rules -- list of rule objects (CNF_Rule's); kept by reference,
            a fresh empty list is used when omitted
        p_terminals -- terminal probabilities (see class docstring);
            kept by reference, a fresh empty list is used when omitted
        '''
        self.__numtag = numtag
        self.__tagnum = tagnum
        # Snapshot of the tag numbers; ROOTNUM is only in here if the
        # caller put it into numtag.
        self.__head_nums = list(numtag)
        # None instead of [] as default: a literal [] default would be
        # one list object shared by every Grammar built without rules.
        # todo: could check for summing to 1 (+/- epsilon)
        self.__p_rules = [] if p_rules is None else p_rules
        self.p_terminals = [] if p_terminals is None else p_terminals

    def all_rules(self):
        '''Return the list of all rules (the stored list, not a copy).'''
        return self.__p_rules

    def new_rules(self, p_rules):
        '''Replace the whole rule list with p_rules.'''
        self.__p_rules = p_rules

    def rules(self, LHS):
        '''Return a new list of the rules whose LHS() equals LHS.'''
        return [rule for rule in self.all_rules() if rule.LHS() == LHS]

    def headnums(self):
        '''Return the tag numbers that were keys of numtag at creation.'''
        return self.__head_nums

    def sent_nums(self, sent):
        '''Return the tag numbers of the tags in sent, in order.'''
        return [self.tagnum(tag) for tag in sent]

    def numtag(self, num):
        '''Return the tag string for num; ROOTNUM gives 'ROOT'.

        Raises KeyError (for a dictionary numtag) on unknown numbers.'''
        # don't want these added to headnums (which we iter through)
        if num == ROOTNUM:
            return 'ROOT'
        return self.__numtag[num]

    def tagnum(self, tag):
        '''Return the tag number for tag; 'ROOT' gives ROOTNUM.

        Raises KeyError (for a dictionary tagnum) on unknown tags.'''
        if tag == 'ROOT':
            return ROOTNUM
        return self.__tagnum[tag]
class CNF_Rule():
    '''A single CNF rule in the PCFG, of the form
    LHS -> L R
    where these are just integers
    (where do we save the connection between number and symbol?
    symbols being 'vbd' etc.)

    Two rules are equal when LHS, L and R are equal; prob takes no part
    in equality or hashing.'''

    def __init__(self, LHS, L, R, prob):
        '''Store the three symbols and the rule probability prob.'''
        self.__LHS = LHS
        self.__R = R
        self.__L = L
        self.prob = prob

    def __eq__(self, other):
        '''Compare on LHS, R and L; anything that is not a CNF_Rule is
        left to Python's default handling instead of raising
        AttributeError.'''
        if not isinstance(other, CNF_Rule):
            return NotImplemented
        return (self.LHS() == other.LHS()
                and self.R() == other.R()
                and self.L() == other.L())

    def __ne__(self, other):
        '''The negation of __eq__.'''
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        '''Hash on the same three symbols that __eq__ compares, so equal
        rules land in the same set/dict slot.'''
        return hash((self.LHS(), self.L(), self.R()))

    def __str__(self):
        return "%s -> %s %s [%.2f]" % (self.LHS(), self.L(), self.R(), self.prob)

    def p(self, *arg):
        "Return a probability, doesn't care about attachment..."
        # Extra positional arguments are accepted and ignored.
        return self.prob

    def LHS(self):
        '''Return the left-hand side symbol.'''
        return self.__LHS

    def L(self):
        '''Return the left child symbol.'''
        return self.__L

    def R(self):
        '''Return the right child symbol.'''
        return self.__R
def inner(i, j, LHS, g, sent, chart):
    '''Give the inner probability of having the node LHS cover whatever's
    between positions i and j in sentence sent, using grammar g.

    Positions lie between the words, so the span (i, j) covers the words
    sent[i] .. sent[j-1], and a span with i+1 == j is a single terminal.

    Arguments:
    i, j -- start and end position of the span
    LHS -- the symbol heading the span
    g -- grammar; g.rules(LHS) must give rule objects with L(), R() and
        p(), and g.p_terminals is looked up with (LHS, word) pairs
    sent -- the sentence, indexable by position
    chart -- dictionary from (i, j, LHS) to inner probability; it is
        updated in place, and entries already in it are trusted and
        returned as they are

    Returns a pair (a two-element list) of the inner probability and the
    chart.

    For DMV, LHS is a pair (bar, h), but this function ought to be
    agnostic about that.

    e() is an internal function, so the variable chart (a dictionary)
    is available to all calls of e().

    Since terminal probabilities are just simple lookups, they are not
    put in the chart (although we could put them in there later to
    optimize)
    '''
    def O(i, j):
        '''The terminal observed between positions i and j == i+1.'''
        return sent[i]

    def word_at(pos):
        '''Word at index pos for the chart dump, or '?' if sent has no
        such index (the chart is caller-supplied, so do not trust it).'''
        if 0 <= pos < len(sent):
            return sent[pos]
        return '?'

    def e(i, j, LHS):
        '''Inner probability of LHS covering (i, j), memoised in chart
        for every span that is not a single terminal.'''
        if (i, j, LHS) in chart:
            return chart[i, j, LHS]

        debug("trying from %d to %d with %s" % (i, j, LHS), "IO")
        if i + 1 == j:
            terminal = (LHS, O(i, j))
            if terminal in g.p_terminals:
                prob = g.p_terminals[terminal]  # b[LHS, O(s)] in L&Y
            else:
                prob = 0.0
                print("\t LACKING TERMINAL:%s -> %s : %.1f" % (LHS, O(i, j), prob))
            debug("\t terminal: %s -> %s : %.1f" % (LHS, O(i, j), prob), "IO")
            # terminals have no attachment, and are not put in the chart
            return prob

        # Both sub-spans below are strictly shorter than (i, j), so e()
        # is never re-entered for this key and the sum can be collected
        # locally and stored once it is complete.
        total = 0.0
        for rule in g.rules(LHS):  # summing over rules headed by LHS, "a[i,j,k]"
            debug("\tsumming rule %s" % rule, "IO")
            L = rule.L()
            R = rule.R()
            for k in range(i + 1, j):  # i<k<j
                p_L = e(i, k, L)
                p_R = e(k, j, R)
                total += rule.p() * p_L * p_R
        chart[i, j, LHS] = total
        debug("\tchart[%d,%d,%s] = %.2f" % (i, j, LHS, total), "IO")
        return total
    # end of e-function

    inner_prob = e(i, j, LHS)
    if 'IO' in DEBUG:
        print("---CHART:---")
        for (start, end, head), prob in chart.items():
            # The span covers words start .. end-1, so the last word is
            # at index end-1; index end may lie past the sentence.
            print("\t%s -> %s_%d ... %s_%d : %.1f"
                  % (head, word_at(start), start, word_at(end - 1), end - 1, prob))
        print("---CHART:end---")
    return [inner_prob, chart]
if __name__ == "__main__":
    # Self-test: prints a "should be" line for every inner probability
    # that is off, and nothing for the ones that are right.
    print("IO-module tests:")

    # The probabilities are sums of float products, so compare within a
    # tolerance; exact == on floats reports correct results as wrong.
    EPSILON = 1e-9

    b = {}
    s = CNF_Rule(0, 1, 2, 1.0)    # s->np vp
    np = CNF_Rule(1, 3, 4, 0.3)   # np->n p
    b[1, 'n'] = 0.7               # np->'n'
    b[3, 'n'] = 1.0               # n->'n'
    b[4, 'p'] = 1.0               # p->'p'
    vp = CNF_Rule(2, 5, 1, 0.1)   # vp->v np (two parses use this rule)
    vp2 = CNF_Rule(2, 2, 4, 0.9)  # vp->vp p
    b[5, 'v'] = 1.0               # v->'v'

    g = Grammar({0: 's', 1: 'np', 2: 'vp', 3: 'n', 4: 'p', 5: 'v'},
                {'s': 0, 'np': 1, 'vp': 2, 'n': 3, 'p': 4, 'v': 5},
                [s, np, vp, vp2], b)

    test1 = inner(0, 1, 1, g, ['n'], {})
    if abs(test1[0] - 0.7) > EPSILON:
        print("should be 0.70 : %.3f" % test1[0])
        print("")

    test2 = inner(0, 3, 2, g, ['v', 'n', 'p'], test1[1])
    if abs(test2[0] - 0.0930) > EPSILON:
        print("should be 0.0930 : %.4f" % test2[0])
    # Same query again, now answered from the chart filled in above.
    test2 = inner(0, 3, 2, g, ['v', 'n', 'p'], test2[1])
    if abs(test2[0] - 0.0930) > EPSILON:
        print("should be 0.0930 : %.4f" % test2[0])