d596deca1e1f1e3f672e207e8c45f3179b62875a
3 #import numpy # numpy provides Fast Arrays, for future optimization
6 from common_dmv
import *
7 SEALS
= [GOR
, RGOL
, SEAL
, NGOR
, NRGOL
] # overwriting here
10 if __name__
== "__main__":
11 print "cnf_dmv module tests:"
13 def make_GO_AT(p_STOP
,p_ATTACH
):
15 for (a
,h
,dir), p_ah
in p_ATTACH
.iteritems():
16 p_GO_AT
[a
,h
,dir, NON
] = p_ah
* (1-p_STOP
[h
, dir, NON
])
17 p_GO_AT
[a
,h
,dir, ADJ
] = p_ah
* (1-p_STOP
[h
, dir, ADJ
])
20 class CNF_DMV_Grammar(io
.Grammar
):
24 p_STOP, p_ROOT, p_ATTACH, p_terminals
25 These are changed in the Maximation step, then used to set the
26 new probabilities of each CNF_DMV_Rule.
28 __p_rules is private, but we can still say stuff like:
29 for r in g.all_rules():
30 r.prob = (1-p_STOP[...]) * p_ATTACH[...]
34 for r
in self
.all_rules():
35 str += "%s\n" % r
.__str
__(self
.numtag
)
39 return [ROOT
] + [(s_h
,h
)
40 for h
in self
.headnums()
43 def sent_rules(self
, sent_nums
):
44 sent_nums_stop
= sent_nums
+ [POS(STOP
)]
45 return [ r
for LHS
in self
.LHSs()
46 for r
in self
.arg_rules(LHS
, sent_nums
)
47 if POS(r
.L()) in sent_nums_stop
48 and POS(r
.R()) in sent_nums_stop
]
51 def mothersR(self
, w_node
, argnums
):
52 '''For all LHS and x, return all rules of the form 'LHS->x w_node'.'''
53 if w_node
not in self
.__mothersR
:
54 self
.__mothersR
[w_node
] = [r
for LHS
in self
.LHSs()
55 for r
in self
.rules(LHS
)
57 return [r
for r
in self
.__mothersR
[w_node
]
58 if POS(r
.L()) in argnums
]
60 def mothersL(self
, w_node
, argnums
):
61 '''For all LHS and x, return all rules of the form 'LHS->w_node x'.'''
62 if w_node
not in self
.__mothersL
:
63 self
.__mothersL
[w_node
] = [r
for LHS
in self
.LHSs()
64 for r
in self
.rules(LHS
)
66 return [r
for r
in self
.__mothersL
[w_node
]
67 if POS(r
.R()) in argnums
]
71 def arg_rules(self
, LHS
, argnums
):
72 return [r
for r
in self
.rules(LHS
)
73 if (POS(r
.R()) in argnums
74 or POS(r
.L()) in argnums
)]
77 def make_all_rules(self
):
78 self
.new_rules([r
for LHS
in self
.LHSs()
79 for r
in self
._make
_rules
(LHS
, self
.headnums())])
81 def _make_rules(self
, LHS
, argnums
):
82 '''This is where the CNF grammar is defined. Also, s_dir_typ shows how
83 useful it'd be to split up the seals into direction and
87 return [CNF_DMV_Rule(LEFT
, LHS
, (SEAL
,h
), STOP
, self
.p_ROOT
[h
])
88 for h
in set(argnums
)]
91 return [] # only terminals from here on
92 s_dir_type
= { # seal of LHS
93 RGOL
: (RIGHT
, 'STOP'), NGOR
: (RIGHT
, 'ATTACH'),
94 SEAL
: (LEFT
, 'STOP'), NRGOL
: (LEFT
, 'ATTACH') }
95 dir_s_adj
= { # seal of h_daughter
96 RIGHT
: [(GOR
, True),(NGOR
, False)] ,
97 LEFT
: [(RGOL
,True),(NRGOL
,False)] }
98 dir,type = s_dir_type
[s_h
]
100 'ATTACH': [CNF_DMV_Rule(dir, LHS
, (s
, h
), (SEAL
,a
), self
.p_GO_AT
[a
,h
,dir,adj
])
101 for a
in set(argnums
) if (a
,h
,dir) in self
.p_ATTACH
102 for s
, adj
in dir_s_adj
[dir]] ,
103 'STOP': [CNF_DMV_Rule(dir, LHS
, (s
, h
), STOP
, self
.p_STOP
[h
,dir,adj
])
104 for s
, adj
in dir_s_adj
[dir]] }
108 def __init__(self
, numtag
, tagnum
, p_ROOT
, p_STOP
, p_ATTACH
, p_terminals
):
109 io
.Grammar
.__init
__(self
, numtag
, tagnum
, [], p_terminals
)
111 self
.p_ATTACH
= p_ATTACH
113 self
.p_GO_AT
= make_GO_AT(self
.p_STOP
, self
.p_ATTACH
)
114 self
.make_all_rules()
119 class CNF_DMV_Rule(io
.CNF_Rule
):
120 '''A single CNF rule in the PCFG, of the form
122 where LHS, L and R are 'nodes', eg. of the form (seals, head).
130 Different rule-types have different probabilities associated with
131 them, see formulas.pdf
134 return seals(self
.LHS())
137 return POS(self
.LHS())
139 def __init__(self
, dir, LHS
, h_daughter
, a_daughter
, prob
):
142 L
, R
= a_daughter
, h_daughter
144 L
, R
= h_daughter
, a_daughter
146 raise ValueError, "dir must be LEFT or RIGHT, given: %s"%dir
147 for b_h
in [LHS
, L
, R
]:
148 if seals(b_h
) not in SEALS
:
149 raise ValueError("seals must be in %s; was given: %s"
150 % (SEALS
, seals(b_h
)))
151 io
.CNF_Rule
.__init
__(self
, LHS
, L
, R
, prob
)
154 "'undefined' for ROOT"
155 if self
.__dir
== LEFT
:
156 return seals(self
.R()) == RGOL
158 return seals(self
.L()) == GOR
160 def __str__(self
, tag
=lambda x
:x
):
161 if self
.adj(): adj_str
= "adj"
162 else: adj_str
= "non_adj"
163 if self
.LHS() == ROOT
: adj_str
= ""
164 return "%s --> %s %s\t[%.2f] %s" % (node_str(self
.LHS(), tag
),
165 node_str(self
.L(), tag
),
166 node_str(self
.R(), tag
),
176 ###################################
177 # dmv-specific version of inner() #
178 ###################################
179 def inner(i
, j
, LHS
, g
, sent
, ichart
={}):
180 ''' A CNF rewrite of io.inner(), to take STOP rules into accord. '''
184 sent_nums
= g
.sent_nums(sent
)
188 "Tabs for debug output"
190 if (i
, j
, LHS
) in ichart
:
192 print "%s*= %.4f in ichart: i:%d j:%d LHS:%s" % (tab(), ichart
[i
, j
, LHS
], i
, j
, node_str(LHS
))
193 return ichart
[i
, j
, LHS
]
195 # if seals(LHS) == RGOL then we have to STOP first
196 if i
== j
-1 and seals(LHS
) == GOR
:
197 if (LHS
, O(i
,j
)) in g
.p_terminals
:
198 prob
= g
.p_terminals
[LHS
, O(i
,j
)] # "b[LHS, O(s)]" in Lari&Young
202 print "%sLACKING TERMINAL:" % tab()
204 print "%s*= %.4f (terminal: %s -> %s)" % (tab(),prob
, node_str(LHS
), O(i
,j
))
207 p
= 0.0 # "sum over j,k in a[LHS,j,k]"
208 for rule
in g
.arg_rules(LHS
, sent_nums
):
210 print "%ssumming rule %s i:%d j:%d" % (tab(),rule
,i
,j
)
213 # if it's a STOP rule, rewrite for the same xrange:
214 if (L
== STOP
) or (R
== STOP
):
216 pLR
= e(i
, j
, R
, n_t
+1)
218 pLR
= e(i
, j
, L
, n_t
+1)
221 print "%sp= %.4f (STOP)" % (tab(), p
)
223 elif j
> i
+1 and seals(LHS
) != GOR
:
224 # not a STOP, attachment rewrite:
225 for k
in xtween(i
, j
): # i<k<j
226 p_L
= e(i
, k
, L
, n_t
+1)
227 p_R
= e(k
, j
, R
, n_t
+1)
228 p
+= rule
.p() * p_L
* p_R
230 print "%sp= %.4f (ATTACH, p_L:%.4f, p_R:%.4f, rule:%.4f)" % (tab(), p
,p_L
,p_R
,rule
.p())
231 ichart
[i
, j
, LHS
] = p
235 inner_prob
= e(i
,j
,LHS
, 0)
237 print debug_ichart(g
,sent
,ichart
)
239 # end of cnf_dmv.inner(i, j, LHS, g, sent, ichart={})
242 def debug_ichart(g
,sent
,ichart
):
243 str = "---ICHART:---\n"
244 for (i
,j
,LHS
),v
in ichart
.iteritems():
245 if type(v
) == dict: # skip 'tree'
247 str += "%s -> %s ... %s: \t%.4f\n" % (node_str(LHS
,g
.numtag
),
248 sent
[i
], sent
[j
-1], v
)
249 str += "---ICHART:end---\n"
253 def inner_sent(g
, sent
, ichart
={}):
254 return sum([inner(0, len(sent
), ROOT
, g
, sent
, ichart
)])
257 #######################################
258 # cnf_dmv-specific version of outer() #
259 #######################################
260 def outer(i
,j
,w_node
, g
, sent
, ichart
={}, ochart
={}):
262 # or we could just look it up in ichart, assuming ichart to be done
263 return inner(i
, j
, LHS
, g
, sent
, ichart
)
265 sent_nums
= g
.sent_nums(sent
)
266 if POS(w_node
) not in sent_nums
[i
:j
]:
267 # sanity check, w must be able to dominate sent[i:j]
271 if (i
,j
,w_node
) in ochart
:
272 return ochart
[(i
, j
, w_node
)]
274 if i
== 0 and j
== len(sent
):
276 else: # ROOT may only be used on full sentence
277 return 0.0 # but we may have non-ROOTs over full sentence too
281 for rule
in g
.mothersL(w_node
, sent_nums
): # rule.L() == w_node
282 if 'OUTER' in DEBUG
: print "w_node:%s (L) ; %s"%(node_str(w_node
),rule
)
284 p0
= f(i
,j
,rule
.LHS()) * rule
.p()
285 if 'OUTER' in DEBUG
: print p0
288 for k
in xgt(j
,sent
): # i<j<k
289 p0
= f(i
,k
, rule
.LHS() ) * rule
.p() * e(j
,k
, rule
.R() )
290 if 'OUTER' in DEBUG
: print p0
293 for rule
in g
.mothersR(w_node
, sent_nums
): # rule.R() == w_node
294 if 'OUTER' in DEBUG
: print "w_node:%s (R) ; %s"%(node_str(w_node
),rule
)
296 p0
= f(i
,j
,rule
.LHS()) * rule
.p()
297 if 'OUTER' in DEBUG
: print p0
300 for k
in xlt(i
): # k<i<j
301 p0
= e(k
,i
, rule
.L() ) * rule
.p() * f(k
,j
, rule
.LHS() )
302 if 'OUTER' in DEBUG
: print p0
305 ochart
[i
,j
,w_node
] = p
310 # end outer(i,j,w_node, g,sent, ichart,ochart)
314 ##############################
315 # reestimation, todo: #
316 ##############################
317 def reest_zeros(rules
):
320 for nd
in ['num','den']:
321 f
[nd
, r
.LHS(), r
.L(), r
.R()] = 0.0
324 def reest_freq(g
, corpus
):
325 ''' P_STOP(-STOP|...) = 1 - P_STOP(STOP|...) '''
326 f
= reest_zeros(g
.all_rules())
330 p_sent
= None # 50 % speed increase on storing this locally
332 def c_g(i
,j
,LHS
,sent
):
335 return e_g(i
,j
,LHS
,sent
) * f_g(i
,j
,LHS
,sent
) / p_sent
337 def w1_g(i
,j
,rule
,sent
): # unary (stop) rules, LHS -> child_node
338 if rule
.L() == STOP
: child
= rule
.R()
339 elif rule
.R() == STOP
: child
= rule
.L()
340 else: raise ValueError, "expected a stop rule: %s"%(rule
,)
342 if p_sent
== 0.0: return 0.0
344 p_out
= f_g(i
,j
,rule
.LHS(),sent
)
345 if p_out
== 0.0: return 0.0
347 return rule
.p() * e_g(i
,j
,child
,sent
) * p_out
/ p_sent
349 def w_g(i
,j
,rule
,sent
):
350 if p_sent
== 0.0 or i
+1 == j
: return 0.0
352 p_out
= f_g(i
,j
,rule
.LHS(),sent
)
353 if p_out
== 0.0: return 0.0
356 for k
in xtween(i
,j
):
357 p
+= rule
.p() * e_g(i
,k
,rule
.L(),sent
) * e_g(k
,j
,rule
.R(),sent
) * p_out
360 def f_g(i
,j
,LHS
,sent
):
361 if (i
,j
,LHS
) in ochart
:
363 return ochart
[i
,j
,LHS
]
365 return outer(i
,j
,LHS
,g
,sent
,ichart
,ochart
)
367 def e_g(i
,j
,LHS
,sent
):
368 if (i
,j
,LHS
) in ichart
:
370 return ichart
[i
,j
,LHS
]
372 return inner(i
,j
,LHS
,g
,sent
,ichart
)
374 for sn
,sent
in enumerate(corpus
):
375 if sn
%1==0: print "sentence number %d"%sn
376 if 'REEST' in DEBUG
: print sent
379 # since we keep re-using p_sent, it seems better to have
380 # sentences as the outer loop; o/w we'd have to keep every chart
381 p_sent
= inner_sent(g
, sent
, ichart
)
383 sent_nums
= g
.sent_nums(sent
)
384 sent_rules
= g
.sent_rules(sent_nums
)
387 LHS
, L
, R
= r
.LHS(), r
.L(), r
.R()
388 if 'REEST' in DEBUG
: print r
390 f
['num',LHS
,L
,R
] += r
.p() * e_g(0, len(sent
), R
, sent
)
391 f
['den',LHS
,L
,R
] += p_sent
392 continue # !!! o/w we add wrong values to it below
393 if L
== STOP
or R
== STOP
:
397 for i
in xlt(len(sent
)):
398 for j
in xgt(i
, sent
):
399 f
['num',LHS
,L
,R
] += w(i
,j
, r
, sent
)
400 f
['den',LHS
,L
,R
] += c_g(i
,j
, LHS
, sent
)
403 def reestimate(g
, corpus
):
404 f
= reest_freq(g
, corpus
)
405 print "applying f to rules"
406 for r
in g
.all_rules():
407 r
.prob
= f
['den',r
.LHS(),r
.L(),r
.R()]
409 r
.prob
= f
['num',r
.LHS(),r
.L(),r
.R()] / r
.prob
413 ##############################
414 # testing functions: #
415 ##############################
417 # make sure we use the same data:
418 from loc_h_dmv
import testcorpus
422 return cnf_harmonic
.initialize(testcorpus
)
424 def testreestimation():
425 from loc_h_dmv
import testcorpus
427 f
= reestimate(g
, testcorpus
[0:4])
430 def testgrammar_a(): # Non, Adj
431 _h_
= CNF_DMV_Rule((SEAL
,0), STOP
, ( RGOL
,0), 1.0, 1.0) # LSTOP
432 h_S
= CNF_DMV_Rule(( RGOL
,0),(GOR
,0), STOP
, 0.4, 0.3) # RSTOP
433 h_A
= CNF_DMV_Rule(( RGOL
,0),(SEAL
,0),( RGOL
,0),0.2, 0.1) # Lattach
434 h_Aa
= CNF_DMV_Rule(( RGOL
,0),(SEAL
,1),( RGOL
,0),0.4, 0.6) # Lattach to a
435 h
= CNF_DMV_Rule((GOR
,0),(GOR
,0),(SEAL
,0), 1.0, 1.0) # Rattach
436 ha
= CNF_DMV_Rule((GOR
,0),(GOR
,0),(SEAL
,1), 1.0, 1.0) # Rattach to a
437 rh
= CNF_DMV_Rule( ROOT
, STOP
, (SEAL
,0), 0.9, 0.9) # ROOT
439 _a_
= CNF_DMV_Rule((SEAL
,1), STOP
, ( RGOL
,1), 1.0, 1.0) # LSTOP
440 a_S
= CNF_DMV_Rule(( RGOL
,1),(GOR
,1), STOP
, 0.4, 0.3) # RSTOP
441 a_A
= CNF_DMV_Rule(( RGOL
,1),(SEAL
,1),( RGOL
,1),0.4, 0.6) # Lattach
442 a_Ah
= CNF_DMV_Rule(( RGOL
,1),(SEAL
,0),( RGOL
,1),0.2, 0.1) # Lattach to h
443 a
= CNF_DMV_Rule((GOR
,1),(GOR
,1),(SEAL
,1), 1.0, 1.0) # Rattach
444 ah
= CNF_DMV_Rule((GOR
,1),(GOR
,1),(SEAL
,0), 1.0, 1.0) # Rattach to h
445 ra
= CNF_DMV_Rule( ROOT
, STOP
, (SEAL
,1), 0.1, 0.1) # ROOT
447 p_rules
= [ h_Aa
, ha
, a_Ah
, ah
, ra
, _a_
, a_S
, a_A
, a
, rh
, _h_
, h_S
, h_A
, h
]
451 b
[(GOR
, 0), 'h'] = 1.0
452 b
[(GOR
, 1), 'a'] = 1.0
454 return CNF_DMV_Grammar({0:'h',1:'a'}, {'h':0,'a':1},
459 p_ROOT
, p_STOP
, p_ATTACH
, p_ORDER
= {},{},{},{}
461 p_STOP
[h
,LEFT
,NON
] = 1.0
462 p_STOP
[h
,LEFT
,ADJ
] = 1.0
463 p_STOP
[h
,RIGHT
,NON
] = 0.4
464 p_STOP
[h
,RIGHT
,ADJ
] = 0.3
465 p_ATTACH
[h
,h
,LEFT
] = 1.0 # not used
466 p_ATTACH
[h
,h
,RIGHT
] = 1.0 # not used
468 p_terminals
[(GOR
, 0), 'h'] = 1.0
470 g
= CNF_DMV_Grammar({h
:'h'}, {'h':h
}, p_ROOT
, p_STOP
, p_ATTACH
, p_terminals
)
472 g
.p_GO_AT
[h
,h
,LEFT
,NON
] = 0.6 # these probabilities are impossible
473 g
.p_GO_AT
[h
,h
,LEFT
,ADJ
] = 0.7 # so add them manually...
474 g
.p_GO_AT
[h
,h
,RIGHT
,NON
] = 1.0
475 g
.p_GO_AT
[h
,h
,RIGHT
,ADJ
] = 1.0
480 def testreestimation_h():
483 return reestimate(g
,['h h h'.split()])
485 def regression_tests():
486 test("0.1830", # = .120 + .063, since we have no loc_h
487 "%.4f" % inner(0, 2, (SEAL
,0), testgrammar_h(), 'h h'.split(), {}))
489 test("0.1842", # = .0498 + .1092 +.0252
490 "%.4f" % inner(0, 3, (SEAL
,0), testgrammar_h(), 'h h h'.split(), {}))
492 "%.4f" % inner_sent(testgrammar_h(), 'h h h'.split()))
495 "%.2f" % outer(1, 3, ( RGOL
,0), testgrammar_h(),'h h h'.split(),{},{}))
497 "%.2f" % outer(1, 3, (NRGOL
,0), testgrammar_h(),'h h h'.split(),{},{}))
500 if __name__
== "__main__":
504 # profile.run('testreestimation()')
506 # DEBUG.add('reest_attach')
508 # print timeit.Timer("cnf_dmv.testreestimation_h()",'''import cnf_dmv
509 # reload(cnf_dmv)''').timeit(1)
511 if __name__
== "__main__":
515 print "TODO!!!! fix outer (also, make mothersL and R faster somehow)"