5e2a8be41f4c4b167669ef16f70b23436d34421f
3 #import numpy # numpy provides Fast Arrays, for future optimization
6 from common_dmv
import *
7 SEALS
= [GOR
, RGOL
, SEAL
, NGOR
, NRGOL
] # overwriting here
10 if __name__
== "__main__":
11 print "cnf_dmv module tests:"
13 def make_GO_AT(p_STOP
,p_ATTACH
):
15 for (a
,h
,dir), p_ah
in p_ATTACH
.iteritems():
16 p_GO_AT
[a
,h
,dir, NON
] = p_ah
* (1-p_STOP
[h
, dir, NON
])
17 p_GO_AT
[a
,h
,dir, ADJ
] = p_ah
* (1-p_STOP
[h
, dir, ADJ
])
20 class CNF_DMV_Grammar(io
.Grammar
):
24 p_STOP, p_ROOT, p_ATTACH, p_terminals
25 These are changed in the Maximization step, then used to set the
26 new probabilities of each CNF_DMV_Rule.
28 __p_rules is private, but we can still say stuff like:
29 for r in g.all_rules():
30 r.prob = (1-p_STOP[...]) * p_ATTACH[...]
34 for r
in self
.all_rules():
35 str += "%s\n" % r
.__str
__(self
.numtag
)
39 return [ROOT
] + [(s_h
,h
)
40 for h
in self
.headnums()
def sent_rules(self, sent_nums):
    '''Collect the rules usable for a sentence with the given POS numbers.

    For every LHS of the grammar, the candidates are whatever
    self.arg_rules(LHS, sent_nums) yields; a candidate is kept only if the
    POS of its left daughter and the POS of its right daughter are both
    either in sent_nums or equal to POS(STOP).

    sent_nums -- list of POS numbers (it is concatenated with a list).
    Returns a new list of rules, in LHS order then arg_rules order.
    '''
    # STOP daughters must pass the filter too, so add its POS once up front.
    allowed = sent_nums + [POS(STOP)]
    selected = []
    for lhs in self.LHSs():
        for rule in self.arg_rules(lhs, sent_nums):
            # left daughter is checked first, right daughter only if needed
            if POS(rule.L()) not in allowed:
                continue
            if POS(rule.R()) not in allowed:
                continue
            selected.append(rule)
    return selected
51 def mothersR(self
, w_node
, argnums
):
52 '''For all LHS and x, return all rules of the form 'LHS->x w_node'.'''
53 return [r
for LHS
in self
.LHSs()
54 for r
in self
.arg_rules(LHS
, argnums
)
57 def mothersL(self
, w_node
, argnums
):
58 '''For all LHS and x, return all rules of the form 'LHS->w_node x'.'''
59 return [r
for LHS
in self
.LHSs()
60 for r
in self
.arg_rules(LHS
, argnums
)
def arg_rules(self, LHS, argnums):
    '''Return the rules of LHS that have a daughter whose POS is in argnums.

    Iterates self.rules(LHS) and keeps a rule when POS of its right
    daughter is in argnums or, failing that, POS of its left daughter is.
    The result is a new list in the order given by self.rules(LHS).
    '''
    matching = []
    for rule in self.rules(LHS):
        # right daughter first; the left one is only looked at if needed
        if POS(rule.R()) in argnums:
            matching.append(rule)
        elif POS(rule.L()) in argnums:
            matching.append(rule)
    return matching
def make_all_rules(self):
    '''(Re)build the full rule list of the grammar.

    For each LHS in self.LHSs(), the rules produced by
    self._make_rules(LHS, self.headnums()) are gathered into one flat list,
    which is then handed to self.new_rules(). Returns None.
    '''
    collected = []
    for lhs in self.LHSs():
        # headnums() is queried afresh for every LHS
        collected.extend(self._make_rules(lhs, self.headnums()))
    self.new_rules(collected)
74 def _make_rules(self
, LHS
, argnums
):
75 '''This is where the CNF grammar is defined. Also, s_dir_typ shows how
76 useful it'd be to split up the seals into direction and
80 return [CNF_DMV_Rule(LEFT
, LHS
, (SEAL
,h
), STOP
, self
.p_ROOT
[h
])
81 for h
in set(argnums
)]
84 return [] # only terminals from here on
85 s_dir_type
= { # seal of LHS
86 RGOL
: (RIGHT
, 'STOP'), NGOR
: (RIGHT
, 'ATTACH'),
87 SEAL
: (LEFT
, 'STOP'), NRGOL
: (LEFT
, 'ATTACH') }
88 dir_s_adj
= { # seal of h_daughter
89 RIGHT
: [(GOR
, True),(NGOR
, False)] ,
90 LEFT
: [(RGOL
,True),(NRGOL
,False)] }
91 dir,type = s_dir_type
[s_h
]
93 'ATTACH': [CNF_DMV_Rule(dir, LHS
, (s
, h
), (SEAL
,a
), self
.p_GO_AT
[a
,h
,dir,adj
])
94 for a
in set(argnums
) if (a
,h
,dir) in self
.p_ATTACH
95 for s
, adj
in dir_s_adj
[dir]] ,
96 'STOP': [CNF_DMV_Rule(dir, LHS
, (s
, h
), STOP
, self
.p_STOP
[h
,dir,adj
])
97 for s
, adj
in dir_s_adj
[dir]] }
101 def __init__(self
, numtag
, tagnum
, p_ROOT
, p_STOP
, p_ATTACH
, p_terminals
):
102 io
.Grammar
.__init
__(self
, numtag
, tagnum
, [], p_terminals
)
104 self
.p_ATTACH
= p_ATTACH
106 self
.p_GO_AT
= make_GO_AT(self
.p_STOP
, self
.p_ATTACH
)
107 self
.make_all_rules()
110 class CNF_DMV_Rule(io
.CNF_Rule
):
111 '''A single CNF rule in the PCFG, of the form
113 where LHS, L and R are 'nodes', eg. of the form (seals, head).
121 Different rule-types have different probabilities associated with
122 them, see formulas.pdf
125 return seals(self
.LHS())
128 return POS(self
.LHS())
130 def __init__(self
, dir, LHS
, h_daughter
, a_daughter
, prob
):
133 L
, R
= a_daughter
, h_daughter
135 L
, R
= h_daughter
, a_daughter
137 raise ValueError, "dir must be LEFT or RIGHT, given: %s"%dir
138 for b_h
in [LHS
, L
, R
]:
139 if seals(b_h
) not in SEALS
:
140 raise ValueError("seals must be in %s; was given: %s"
141 % (SEALS
, seals(b_h
)))
142 io
.CNF_Rule
.__init
__(self
, LHS
, L
, R
, prob
)
145 "'undefined' for ROOT"
146 if self
.__dir
== LEFT
:
147 return seals(self
.R()) == RGOL
149 return seals(self
.L()) == GOR
151 def __str__(self
, tag
=lambda x
:x
):
152 if self
.adj(): adj_str
= "adj"
153 else: adj_str
= "non_adj"
154 if self
.LHS() == ROOT
: adj_str
= ""
155 return "%s --> %s %s\t[%.2f] %s" % (node_str(self
.LHS(), tag
),
156 node_str(self
.L(), tag
),
157 node_str(self
.R(), tag
),
167 ###################################
168 # dmv-specific version of inner() #
169 ###################################
170 def inner(i
, j
, LHS
, g
, sent
, ichart
={}):
171 ''' A CNF rewrite of io.inner(), to take STOP rules into account. '''
175 sent_nums
= g
.sent_nums(sent
)
179 "Tabs for debug output"
181 if (i
, j
, LHS
) in ichart
:
183 print "%s*= %.4f in ichart: i:%d j:%d LHS:%s" % (tab(), ichart
[i
, j
, LHS
], i
, j
, node_str(LHS
))
184 return ichart
[i
, j
, LHS
]
186 # if seals(LHS) == RGOL then we have to STOP first
187 if i
== j
-1 and seals(LHS
) == GOR
:
188 if (LHS
, O(i
,j
)) in g
.p_terminals
:
189 prob
= g
.p_terminals
[LHS
, O(i
,j
)] # "b[LHS, O(s)]" in Lari&Young
193 print "%sLACKING TERMINAL:" % tab()
195 print "%s*= %.4f (terminal: %s -> %s)" % (tab(),prob
, node_str(LHS
), O(i
,j
))
198 p
= 0.0 # "sum over j,k in a[LHS,j,k]"
199 for rule
in g
.arg_rules(LHS
, sent_nums
):
201 print "%ssumming rule %s i:%d j:%d" % (tab(),rule
,i
,j
)
204 # if it's a STOP rule, rewrite for the same xrange:
205 if (L
== STOP
) or (R
== STOP
):
207 pLR
= e(i
, j
, R
, n_t
+1)
209 pLR
= e(i
, j
, L
, n_t
+1)
212 print "%sp= %.4f (STOP)" % (tab(), p
)
214 elif j
> i
+1 and seals(LHS
) != GOR
:
215 # not a STOP, attachment rewrite:
216 for k
in xtween(i
, j
): # i<k<j
217 p_L
= e(i
, k
, L
, n_t
+1)
218 p_R
= e(k
, j
, R
, n_t
+1)
219 p
+= rule
.p() * p_L
* p_R
221 print "%sp= %.4f (ATTACH, p_L:%.4f, p_R:%.4f, rule:%.4f)" % (tab(), p
,p_L
,p_R
,rule
.p())
222 ichart
[i
, j
, LHS
] = p
226 inner_prob
= e(i
,j
,LHS
, 0)
228 print debug_ichart(g
,sent
,ichart
)
230 # end of cnf_dmv.inner(i, j, LHS, g, sent, ichart={})
233 def debug_ichart(g
,sent
,ichart
):
234 str = "---ICHART:---\n"
235 for (i
,j
,LHS
),v
in ichart
.iteritems():
236 if type(v
) == dict: # skip 'tree'
238 str += "%s -> %s ... %s: \t%.4f\n" % (node_str(LHS
,g
.numtag
),
239 sent
[i
], sent
[j
-1], v
)
240 str += "---ICHART:end---\n"
def inner_sent(g, sent, ichart=None):
    '''Return the inside probability of the whole sentence under grammar g.

    This is inner() for ROOT over the full span 0..len(sent).

    g      -- grammar, handed on to inner()
    sent   -- the sentence; must support len()
    ichart -- optional chart dict that inner() reads and fills; pass your
              own dict if you want to keep the chart after the call.

    The default is None rather than a dict literal: a default {} is built
    once, at definition time, and shared by every call. Since inner()
    caches under (i, j, LHS) keys, with no reference to the sentence or
    the grammar, a later call with a different sentence or grammar would
    have been answered from stale entries.
    '''
    if ichart is None:
        ichart = {}
    # sum() over a one-element list only added 0; return the value directly.
    return inner(0, len(sent), ROOT, g, sent, ichart)
248 #######################################
249 # cnf_dmv-specific version of outer() #
250 #######################################
251 def outer(i
,j
,w_node
, g
, sent
, ichart
={}, ochart
={}):
253 # or we could just look it up in ichart, assuming ichart to be done
254 return inner(i
, j
, LHS
, g
, sent
, ichart
)
256 sent_nums
= g
.sent_nums(sent
)
257 if POS(w_node
) not in sent_nums
[i
:j
]:
258 # sanity check, w must be able to dominate sent[i:j]
262 if (i
,j
,w_node
) in ochart
:
263 return ochart
[(i
, j
, w_node
)]
265 if i
== 0 and j
== len(sent
):
267 else: # ROOT may only be used on full sentence
268 return 0.0 # but we may have non-ROOTs over full sentence too
272 for rule
in g
.mothersL(w_node
, sent_nums
): # rule.L() == w_node
273 if 'OUTER' in DEBUG
: print "w_node:%s (L) ; %s"%(node_str(w_node
),rule
)
275 p0
= f(i
,j
,rule
.LHS()) * rule
.p()
276 if 'OUTER' in DEBUG
: print p0
279 for k
in xgt(j
,sent
): # i<j<k
280 p0
= f(i
,k
, rule
.LHS() ) * rule
.p() * e(j
,k
, rule
.R() )
281 if 'OUTER' in DEBUG
: print p0
284 for rule
in g
.mothersR(w_node
, sent_nums
): # rule.R() == w_node
285 if 'OUTER' in DEBUG
: print "w_node:%s (R) ; %s"%(node_str(w_node
),rule
)
287 p0
= f(i
,j
,rule
.LHS()) * rule
.p()
288 if 'OUTER' in DEBUG
: print p0
291 for k
in xlt(i
): # k<i<j
292 p0
= e(k
,i
, rule
.L() ) * rule
.p() * f(k
,j
, rule
.LHS() )
293 if 'OUTER' in DEBUG
: print p0
296 ochart
[i
,j
,w_node
] = p
301 # end outer(i,j,w_node, g,sent, ichart,ochart)
305 ##############################
306 # reestimation, todo: #
307 ##############################
308 def reest_zeros(rules
):
311 for nd
in ['num','den']:
312 f
[nd
, r
.LHS(), r
.L(), r
.R()] = 0.0
315 def reest_freq(g
, corpus
):
316 ''' P_STOP(-STOP|...) = 1 - P_STOP(STOP|...) '''
317 f
= reest_zeros(g
.all_rules())
321 p_sent
= None # 50 % speed increase on storing this locally
323 def c_g(i
,j
,LHS
,sent
):
326 return e_g(i
,j
,LHS
,sent
) * f_g(i
,j
,LHS
,sent
) / p_sent
328 def w1_g(i
,j
,rule
,sent
): # unary (stop) rules, LHS -> child_node
329 if rule
.L() == STOP
: child
= rule
.R()
330 elif rule
.R() == STOP
: child
= rule
.L()
331 else: raise ValueError, "expected a stop rule: %s"%(rule
,)
333 if p_sent
== 0.0: return 0.0
335 p_out
= f_g(i
,j
,rule
.LHS(),sent
)
336 if p_out
== 0.0: return 0.0
338 return rule
.p() * e_g(i
,j
,child
,sent
) * p_out
/ p_sent
340 def w_g(i
,j
,rule
,sent
):
341 if p_sent
== 0.0 or i
+1 == j
: return 0.0
343 p_out
= f_g(i
,j
,rule
.LHS(),sent
)
344 if p_out
== 0.0: return 0.0
347 for k
in xtween(i
,j
):
348 p
+= rule
.p() * e_g(i
,k
,rule
.L(),sent
) * e_g(k
,j
,rule
.R(),sent
) * p_out
351 def f_g(i
,j
,LHS
,sent
):
352 if (i
,j
,LHS
) in ochart
:
354 return ochart
[i
,j
,LHS
]
356 return outer(i
,j
,LHS
,g
,sent
,ichart
,ochart
)
358 def e_g(i
,j
,LHS
,sent
):
359 if (i
,j
,LHS
) in ichart
:
361 return ichart
[i
,j
,LHS
]
363 return inner(i
,j
,LHS
,g
,sent
,ichart
)
365 for sn
,sent
in enumerate(corpus
):
366 if sn
%1==0: print "sentence number %d"%sn
367 if 'REEST' in DEBUG
: print sent
370 # since we keep re-using p_sent, it seems better to have
371 # sentences as the outer loop; o/w we'd have to keep every chart
372 p_sent
= inner_sent(g
, sent
, ichart
)
374 sent_nums
= g
.sent_nums(sent
)
375 sent_rules
= g
.sent_rules(sent_nums
)
378 LHS
, L
, R
= r
.LHS(), r
.L(), r
.R()
379 if 'REEST' in DEBUG
: print r
381 f
['num',LHS
,L
,R
] += r
.p() * e_g(0, len(sent
), R
, sent
)
382 f
['den',LHS
,L
,R
] += p_sent
383 continue # !!! o/w we add wrong values to it below
384 if L
== STOP
or R
== STOP
:
388 for i
in xlt(len(sent
)):
389 for j
in xgt(i
, sent
):
390 f
['num',LHS
,L
,R
] += w(i
,j
, r
, sent
)
391 f
['den',LHS
,L
,R
] += c_g(i
,j
, LHS
, sent
)
394 def reestimate(g
, corpus
):
395 f
= reest_freq(g
, corpus
)
396 print "applying f to rules"
397 for r
in g
.all_rules():
398 r
.prob
= f
['den',r
.LHS(),r
.L(),r
.R()]
400 r
.prob
= f
['num',r
.LHS(),r
.L(),r
.R()] / r
.prob
404 ##############################
405 # testing functions: #
406 ##############################
408 # make sure we use the same data:
409 from loc_h_dmv
import testcorpus
413 return cnf_harmonic
.initialize(testcorpus
)
415 def testreestimation():
416 from loc_h_dmv
import testcorpus
418 f
= reestimate(g
, testcorpus
)
421 def testgrammar_a(): # Non, Adj
422 _h_
= CNF_DMV_Rule((SEAL
,0), STOP
, ( RGOL
,0), 1.0, 1.0) # LSTOP
423 h_S
= CNF_DMV_Rule(( RGOL
,0),(GOR
,0), STOP
, 0.4, 0.3) # RSTOP
424 h_A
= CNF_DMV_Rule(( RGOL
,0),(SEAL
,0),( RGOL
,0),0.2, 0.1) # Lattach
425 h_Aa
= CNF_DMV_Rule(( RGOL
,0),(SEAL
,1),( RGOL
,0),0.4, 0.6) # Lattach to a
426 h
= CNF_DMV_Rule((GOR
,0),(GOR
,0),(SEAL
,0), 1.0, 1.0) # Rattach
427 ha
= CNF_DMV_Rule((GOR
,0),(GOR
,0),(SEAL
,1), 1.0, 1.0) # Rattach to a
428 rh
= CNF_DMV_Rule( ROOT
, STOP
, (SEAL
,0), 0.9, 0.9) # ROOT
430 _a_
= CNF_DMV_Rule((SEAL
,1), STOP
, ( RGOL
,1), 1.0, 1.0) # LSTOP
431 a_S
= CNF_DMV_Rule(( RGOL
,1),(GOR
,1), STOP
, 0.4, 0.3) # RSTOP
432 a_A
= CNF_DMV_Rule(( RGOL
,1),(SEAL
,1),( RGOL
,1),0.4, 0.6) # Lattach
433 a_Ah
= CNF_DMV_Rule(( RGOL
,1),(SEAL
,0),( RGOL
,1),0.2, 0.1) # Lattach to h
434 a
= CNF_DMV_Rule((GOR
,1),(GOR
,1),(SEAL
,1), 1.0, 1.0) # Rattach
435 ah
= CNF_DMV_Rule((GOR
,1),(GOR
,1),(SEAL
,0), 1.0, 1.0) # Rattach to h
436 ra
= CNF_DMV_Rule( ROOT
, STOP
, (SEAL
,1), 0.1, 0.1) # ROOT
438 p_rules
= [ h_Aa
, ha
, a_Ah
, ah
, ra
, _a_
, a_S
, a_A
, a
, rh
, _h_
, h_S
, h_A
, h
]
442 b
[(GOR
, 0), 'h'] = 1.0
443 b
[(GOR
, 1), 'a'] = 1.0
445 return CNF_DMV_Grammar({0:'h',1:'a'}, {'h':0,'a':1},
450 p_ROOT
, p_STOP
, p_ATTACH
, p_ORDER
= {},{},{},{}
452 p_STOP
[h
,LEFT
,NON
] = 1.0
453 p_STOP
[h
,LEFT
,ADJ
] = 1.0
454 p_STOP
[h
,RIGHT
,NON
] = 0.4
455 p_STOP
[h
,RIGHT
,ADJ
] = 0.3
456 p_ATTACH
[h
,h
,LEFT
] = 1.0 # not used
457 p_ATTACH
[h
,h
,RIGHT
] = 1.0 # not used
459 p_terminals
[(GOR
, 0), 'h'] = 1.0
461 g
= CNF_DMV_Grammar({h
:'h'}, {'h':h
}, p_ROOT
, p_STOP
, p_ATTACH
, p_terminals
)
463 g
.p_GO_AT
[h
,h
,LEFT
,NON
] = 0.6 # these probabilities are impossible
464 g
.p_GO_AT
[h
,h
,LEFT
,ADJ
] = 0.7 # so add them manually...
465 g
.p_GO_AT
[h
,h
,RIGHT
,NON
] = 1.0
466 g
.p_GO_AT
[h
,h
,RIGHT
,ADJ
] = 1.0
471 def testreestimation_h():
474 return reestimate(g
,['h h h'.split()])
476 def regression_tests():
477 test("0.1830", # = .120 + .063, since we have no loc_h
478 "%.4f" % inner(0, 2, (SEAL
,0), testgrammar_h(), 'h h'.split(), {}))
480 test("0.1842", # = .0498 + .1092 +.0252
481 "%.4f" % inner(0, 3, (SEAL
,0), testgrammar_h(), 'h h h'.split(), {}))
483 "%.4f" % inner_sent(testgrammar_h(), 'h h h'.split()))
486 "%.2f" % outer(1, 3, ( RGOL
,0), testgrammar_h(),'h h h'.split(),{},{}))
488 "%.2f" % outer(1, 3, (NRGOL
,0), testgrammar_h(),'h h h'.split(),{},{}))
491 if __name__
== "__main__":
495 # profile.run('testreestimation()')
497 # DEBUG.add('reest_attach')
499 # print timeit.Timer("cnf_dmv.testreestimation_h()",'''import cnf_dmv
500 # reload(cnf_dmv)''').timeit(1)
502 if __name__
== "__main__":