close to finishing pCHOOSE
[dmvccm.git] / src / dmv.py
blobfec7c50054b8f87bd4542451be3af7f805460f20
1 #### changes by KBU:
2 # 2008-06-12
3 # - outer() seems to be working, wrote c(s,t,LHS,loc_h,...) too now.
4 #
5 # 2008-06-11
6 # - moved prune() to junk.py, now using outer() instead. outer() is
7 # written, but needs testing.
9 # 2008-06-09
10 # - prune() finished, seems to be working.
11 # - started on implementing the other reestimation formulas, in
12 # reestimate()
14 # 2008-06-04
15 # - moved initialization to harmonic.py
17 # 2008-06-03
18 # - fixed a number of little bugs in initialization, where certain
19 # rules were simply not created, or created "backwards"
20 # - dmv.inner() should Work now...
22 # 2008-06-01
23 # - finished typing in dmv.inner(), still have to test and debug
24 # it. The ichart is now four times as big since for any rule we may
25 # have attachments to either the left or the right below, which
26 # upper rules depend on, for selecting probN or probA
28 # 2008-05-30
29 # - copied inner() into this file, to make the very dmv-specific
30 # adjacency stuff work (have to factor that out later on, when it
31 # works).
33 # 2008-05-29
34 # - init_normalize is done, it creates p_STOP, p_ROOT and p_CHOOSE,
35 # and also adds the relevant probabilities to p_rules in a grammar.
36 # Still, each individual rule has to store both adjacent and non_adj
37 # probabilities, and inner() should be able to send some parameter
38 # which lets the rule choose... hopefully... Is this possible to do
39 # top-down even? when the sentence could be all the same words?
40 # todo: extensive testing of identical words in sentences!
41 # - frequencies (only used in initialization) are stored as strings,
42 # but in the rules and p_STOP etc, there are only numbers.
44 # 2008-05-28
45 # - more work on initialization (init_freq and init_normalize),
46 # getting closer to probabilities now.
48 # 2008-05-27
49 # - started on initialization. So far, I have frequencies for
50 # everything, very harmonic. Still need to make these into 1-summing
51 # probabilities
53 # 2008-05-24
54 # - prettier printout for DMV_Rule
55 # - DMV_Rule changed a bit. head, L and R are now all pairs of the
56 # form (seals, head).
57 # - Started on P_STOP, a bit less pseudo now..
61 #import numpy # numpy provides Fast Arrays, for future optimization
62 import io
# Non-tweakable/constant "lookup" globals: the "seal" state of a node.
GO_R = 0   # was: NOBAR
RGO_L = 1  # was: RBAR
SEAL = 2   # was: LRBAR

# probably need these for combined model, see thesis-appendix:
GO_L = 3
LGO_R = 4
SEALS = [GO_R, RGO_L, SEAL, GO_L, LGO_R]

# Special nodes, written as (seals, head) pairs like every other node.
# NOTE(review): the negative head numbers are presumably chosen so that
# they cannot clash with real tag numbers -- confirm against io.Grammar.
ROOT = (SEAL, -1)
STOP = (GO_R, -2)

if __name__ == "__main__":
    print("DMV module tests:")
def node(seals, head):
    '''Useless function, but just here as documentation. Nodes make up
    LHS, R and L in each DMV_Rule.

    Returns the pair (seals, head).'''
    return (seals, head)
def seals(node):
    '''Return the seals component (first element) of a (seals, head) node.'''
    return node[0]
def head(node):
    '''Return the head component (second element) of a (seals, head) node.'''
    return node[1]
class DMV_Grammar(io.Grammar):
    '''The DMV-PCFG.

    Public members:
        p_STOP, p_ROOT, p_CHOOSE, p_terminals
    These are changed in the Maximization step, then used to set the
    new probabilities of each DMV_Rule.

    Todo: make p_terminals private? (But it has to be changable in
    maximization step due to the short-cutting rules... could of course
    make a DMV_Grammar function to update the short-cut rules...)

    __p_rules is private, but we can still say stuff like:
        for r in g.all_rules():
            r.probN = newProbN

    What other representations do we need? (P_STOP formula uses
    deps_D(h,l/r) at least)

    all_rules(), numtag and sent_nums() are used here but not defined
    here; presumably they are inherited from io.Grammar.'''

    def __str__(self):
        '''One line per rule, each rendered with this grammar's numtag.'''
        return "".join("%s\n" % r.__str__(self.numtag)
                       for r in self.all_rules())

    def h_rules(self, h):
        '''All rules whose LHS has head h.'''
        return [r for r in self.all_rules() if r.head() == h]

    def mothersL(self, Node, sent_nums, loc_N):
        '''Rules with Node as left child, whose right child is either STOP
        or has a head occurring in sent_nums strictly after loc_N.'''
        # todo: speed-test with and without sent_nums/loc_N cut-off
        right_of_N = sent_nums[loc_N+1:]
        return [r for r in self.all_rules()
                if r.L() == Node
                and (head(r.R()) in right_of_N or r.R() == STOP)]

    def mothersR(self, Node, sent_nums, loc_N):
        '''Rules with Node as right child, whose left child is either STOP
        or has a head occurring in sent_nums strictly before loc_N.'''
        left_of_N = sent_nums[:loc_N]
        return [r for r in self.all_rules()
                if r.R() == Node
                and (head(r.L()) in left_of_N or r.L() == STOP)]

    def rules(self, LHS):
        '''All rules rewriting LHS.'''
        return [r for r in self.all_rules() if r.LHS() == LHS]

    def sent_rules(self, LHS, sent_nums):
        '''Used in dmv.inner. Todo: this takes a _lot_ of time, it
        seems. Could use some more space and cache some of this
        somehow perhaps?

        Returns the rules rewriting LHS whose children both have heads
        that occur in sent_nums (or are STOP).'''
        # We don't want to rule out STOPs!
        nums = sent_nums + [head(STOP)]
        return [r for r in self.all_rules()
                if r.LHS() == LHS
                and head(r.L()) in nums and head(r.R()) in nums]

    def deps_L(self, head):  # todo: do I use this at all?
        '''Left children of every rule whose LHS has the given head.

        The original comprehension referred to an unbound name `a` and
        could only raise NameError; this returns r.L() for each matching
        rule instead. NOTE(review): that includes STOP and the head's own
        node where those are the left child -- filter if only real
        dependents are wanted.'''
        return [r.L() for r in self.all_rules() if r.head() == head]

    def deps_R(self, head):
        '''Right children of every rule whose LHS has the given head.

        Same repair and same caveat as deps_L.'''
        return [r.R() for r in self.all_rules() if r.head() == head]

    def __init__(self, p_rules, p_terminals, p_STOP, p_CHOOSE, p_ROOT, numtag, tagnum):
        '''Store the DMV probability tables on top of the io.Grammar state.

        head_nums is the list of keys of numtag (iterated like a dict).'''
        io.Grammar.__init__(self, p_rules, p_terminals, numtag, tagnum)
        self.p_STOP = p_STOP
        self.p_CHOOSE = p_CHOOSE
        self.p_ROOT = p_ROOT
        self.head_nums = list(numtag)
class DMV_Rule(io.CNF_Rule):
    '''A single CNF rule in the PCFG, of the form
        LHS -> L R
    where LHS, L and R are 'nodes', eg. of the form (seals, head).

    Public members:
        probN, probA

    Private members:
        __L, __R, __LHS

    Different rule-types have different probabilities associated with
    them:

    _h_ -> STOP  h_     P( STOP|h,L,    adj)
    _h_ -> STOP  h_     P( STOP|h,L,non_adj)
     h_ ->  h  STOP     P( STOP|h,R,    adj)
     h_ ->  h  STOP     P( STOP|h,R,non_adj)
     h_ -> _a_   h_     P(-STOP|h,L,    adj) * P(a|h,L)
     h_ -> _a_   h_     P(-STOP|h,L,non_adj) * P(a|h,L)
     h  ->  h   _a_     P(-STOP|h,R,    adj) * P(a|h,R)
     h  ->  h   _a_     P(-STOP|h,R,non_adj) * P(a|h,R)
    '''

    def p(self, adj, *arg):
        '''Return probA if adj is true, else probN. Extra positional
        arguments are accepted and ignored.'''
        if adj:
            return self.probA
        else:
            return self.probN

    def p_STOP(self, s, t, loc_h):
        '''Returns the correct probability, adjacent if we're rewriting from
        the (either left or right) end of the fragment.

        For a right-STOP rule the head must sit at s, otherwise 0.0 is
        returned. If neither child is STOP the method falls through and
        returns None; callers in this file only use it on STOP rules.'''
        if self.L() == STOP:
            return self.p(s == loc_h)
        elif self.R() == STOP:
            if not loc_h == s:
                if 'TODO' in io.DEBUG:
                    print("(%s given loc_h:%d but s:%d. Todo: optimize away!)" % (self, loc_h, s))
                return 0.0
            else:
                return self.p(t == loc_h)

    def p_ATTACH(self, r, loc_h, s=None):
        '''Returns the correct probability, adjacent if we haven't attached
        anything before.

        r is the split point (left child spans up to r, right child from
        r+1). When LHS == L (head on the left) the head must sit at s,
        otherwise 0.0 is returned. If LHS equals neither child the method
        falls through and returns None.'''
        if self.LHS() == self.L():
            if not loc_h == s:
                if 'TODO' in io.DEBUG:
                    print("(%s given loc_h (loc_L):%d but s:%d. Todo: optimize away!)" % (self, loc_h, s))
                return 0.0
            else:
                return self.p(r == loc_h)
        elif self.LHS() == self.R():
            return self.p(r+1 == loc_h)

    def seals(self):
        '''Seals component of this rule's LHS node.'''
        return seals(self.LHS())

    def head(self):
        '''Head component of this rule's LHS node.'''
        return head(self.LHS())

    def __init__(self, LHS, L, R, probN, probA):
        '''Create the rule LHS -> L R.

        Raises ValueError if any of the three nodes has a seals value that
        is not in SEALS. probN is also handed to io.CNF_Rule.__init__.'''
        for b_h in [LHS, L, R]:
            if seals(b_h) not in SEALS:
                raise ValueError("seals must be in %s; was given: %s"
                                 % (SEALS, seals(b_h)))
        io.CNF_Rule.__init__(self, LHS, L, R, probN)
        self.probA = probA  # adjacent
        self.probN = probN  # non_adj

    @classmethod  # so we can call DMV_Rule.bar_str(b_h)
    def bar_str(cls, b_h, tag=lambda x:x):
        '''Printable form of node b_h; tag maps a head number to its
        display form (identity by default).'''
        if(b_h == ROOT):
            return 'ROOT'
        elif(b_h == STOP):
            return 'STOP'
        elif(seals(b_h) == RGO_L):
            return " %s_ " % tag(head(b_h))
        elif(seals(b_h) == SEAL):
            return "_%s_ " % tag(head(b_h))
        else:
            return " %s " % tag(head(b_h))

    def __str__(self, tag=lambda x:x):
        '''Printable form of the rule with both probabilities.'''
        return "%s-->%s %s\t[N %.2f] [A %.2f]" % (self.bar_str(self.LHS(), tag),
                                                  self.bar_str(self.L(), tag),
                                                  self.bar_str(self.R(), tag),
                                                  self.probN,
                                                  self.probA)
254 ###################################
255 # dmv-specific version of inner() #
256 ###################################
def locs(h, sent, s=0, t=None, remove=None):
    '''Return the locations of h in sent, or some fragment of sent (in the
    latter case we make sure to offset the locations correctly so that
    for any x in the returned list, sent[x]==h).

    t is inclusive, to match the way indices work with inner()
    (although python list-splicing has "exclusive" end indices).
    If t is None the fragment runs to the end of sent. A location equal
    to remove is left out of the result.'''
    if t is None:
        t = len(sent)-1
    return [i+s for i, w in enumerate(sent[s:t+1])
            if w == h and not (i+s) == remove]
def inner(s, t, LHS, loc_h, g, sent, ichart=None):
    ''' A rewrite of io.inner(), to take adjacency into accord.

    The ichart is now of this form:
    ichart[s,t,LHS, loc_h]

    loc_h gives adjacency (along with r and location of other child
    for attachment rules), and is needed in P_STOP reestimation.

    Todo: if possible, refactor (move dmv-specific stuff back into
    dmv, so this is "general" enough to be in io.py)

    s and t are inclusive sentence positions. ichart is filled in as a
    side effect; pass the same dict again to reuse results for the SAME
    grammar and sentence. If it is omitted a fresh chart is used: chart
    keys do not identify the sentence, so a chart shared between calls
    (as the old mutable default was) returns stale values.
    '''
    if ichart is None:
        ichart = {}
    debug = 'INNER' in io.DEBUG

    def O(s):
        return sent[s]

    sent_nums = g.sent_nums(sent)
    # Maps each chart key to the child keys that contributed probability.
    # NOTE(review): it is built but never exposed -- see ichart['tree'] below.
    tree = {}

    def e(s, t, LHS, loc_h, n_t):
        def tab():
            "Tabs for debug output"
            return "\t"*n_t

        if (s, t, LHS, loc_h) in ichart:
            if debug:
                print("%s*= %.4f in ichart: s:%d t:%d LHS:%s loc:%d" % (tab(), ichart[s, t, LHS, loc_h], s, t,
                                                                         DMV_Rule.bar_str(LHS), loc_h))
            return ichart[s, t, LHS, loc_h]

        if s == t and seals(LHS) == GO_R:
            # Terminal case: a single unsealed word.
            if not loc_h == s:
                if debug:
                    print("%s*= 0.0 (wrong loc_h)" % tab())
                return 0.0
            elif (LHS, O(s)) in g.p_terminals:
                prob = g.p_terminals[LHS, O(s)]  # "b[LHS, O(s)]" in Lari&Young
            else:
                # todo: assuming this is how to deal w/lacking
                # rules, since we add prob.s, and 0 is identity
                prob = 0.0
                if debug:
                    print("%sLACKING TERMINAL:" % tab())
            # todo: add to ichart perhaps? Although, it _is_ simple lookup..
            if debug:
                print("%s*= %.4f (terminal: %s -> %s_%d)" % (tab(), prob, DMV_Rule.bar_str(LHS), O(s), loc_h))
            return prob

        p = 0.0  # "sum over j,k in a[LHS,j,k]"
        for rule in g.sent_rules(LHS, sent_nums):
            if debug:
                print("%ssumming rule %s s:%d t:%d loc:%d" % (tab(), rule, s, t, loc_h))
            L = rule.L()
            R = rule.R()
            if (s, t, LHS, loc_h) not in tree:
                tree[s, t, LHS, loc_h] = set()
            if loc_h == t and rule.LHS() == L:
                continue  # todo: speed-test
            if loc_h == s and rule.LHS() == R:
                continue
            # if it's a STOP rule, rewrite for the same xrange:
            if (L == STOP) or (R == STOP):
                if L == STOP:
                    pLR = e(s, t, R, loc_h, n_t+1)
                    if pLR > 0.0:
                        tree[s, t, LHS, loc_h].add((s, t, R, loc_h))
                elif R == STOP:
                    pLR = e(s, t, L, loc_h, n_t+1)
                    if pLR > 0.0:
                        tree[s, t, LHS, loc_h].add((s, t, L, loc_h))
                p += rule.p_STOP(s, t, loc_h) * pLR
                if debug:
                    print("%sp= %.4f (STOP)" % (tab(), p))

            elif t > s:  # not a STOP, attachment rewrite:
                rp_ATTACH = rule.p_ATTACH  # todo: profile/speedtest
                for r in xrange(s, t):
                    p_h = rp_ATTACH(r, loc_h, s=s)
                    # The head keeps its location; the dependent may be at
                    # any matching location in the other half (never loc_h).
                    if rule.LHS() == L:
                        locs_L = [loc_h]
                        locs_R = locs(head(R), sent_nums, r+1, t, loc_h)
                    elif rule.LHS() == R:
                        locs_L = locs(head(L), sent_nums, s, r, loc_h)
                        locs_R = [loc_h]
                    for loc_L in locs_L:
                        pL = e(s, r, L, loc_L, n_t+1)
                        if pL > 0.0:
                            for loc_R in locs_R:
                                pR = e(r+1, t, R, loc_R, n_t+1)
                                if pR > 0.0:  # and pL > 0.0
                                    tree[s, t, LHS, loc_h].add((s, r, L, loc_L))
                                    tree[s, t, LHS, loc_h].add((r+1, t, R, loc_R))
                                    p += pL * p_h * pR
                    # NOTE(review): nesting level of this debug print was
                    # reconstructed; it only affects debug output.
                    if debug:
                        print("%sp= %.4f (ATTACH)" % (tab(), p))
        ichart[s, t, LHS, loc_h] = p
        return p
    # end of e-function

    inner_prob = e(s, t, LHS, loc_h, 0)
    # NOTE(review): stores an empty dict, not `tree` -- kept as in the original.
    ichart['tree'] = {}
    if debug:
        print(debug_ichart(g, sent, ichart))
    return inner_prob
# end of dmv.inner(s, t, LHS, loc_h, g, sent, ichart=None)
def debug_ichart(g, sent, ichart):
    '''Return a printable dump of ichart, one line per (s,t,LHS,loc_h) entry.

    Entries whose value is a dict (the 'tree' bookkeeping entry) are
    skipped before the key is unpacked. The span is shown as
    sent[s]_s ... sent[t]_t (the original printed sent[s] twice).'''
    lines = ["---ICHART:---\n"]
    for key, v in ichart.items():
        if isinstance(v, dict):  # skip 'tree'
            continue
        s, t, LHS, loc_h = key
        lines.append("%s -> %s_%d ... %s_%d (loc_h:%s):\t%.4f\n" % (DMV_Rule.bar_str(LHS, g.numtag),
                                                                   sent[s], s, sent[t], t, loc_h, v))
    lines.append("---ICHART:end---\n")
    return "".join(lines)
def inner_sent(g, sent, ichart=None):
    '''Inner probability of the whole sentence: ROOT over the full span,
    summed over every possible head location.

    ichart is filled as a side effect. A fresh chart is used when it is
    omitted, because chart keys do not identify the sentence and a chart
    shared between sentences gives stale results.'''
    if ichart is None:
        ichart = {}
    return sum([inner(0, len(sent)-1, ROOT, loc_h, g, sent, ichart)
                for loc_h in xrange(len(sent))])
def c(s,t,LHS,loc_h,g,sent,ichart=None,ochart=None):
    '''Return inner * outer / P(sent) for node LHS over s..t (inclusive)
    with head at loc_h, or P(sent) itself when P(sent) is not positive.

    ichart and ochart are filled as a side effect; fresh charts are used
    when they are omitted (charts must never be shared between different
    sentences or grammars).'''
    if ichart is None:
        ichart = {}
    if ochart is None:
        ochart = {}
    # assuming P_sent = P(D(ROOT)) = inner(sent). todo: check K&M about this
    p_sent = inner_sent(g, sent, ichart)
    p_in = inner(s,t,LHS,loc_h,g,sent,ichart)
    p_out = outer(s,t,LHS,loc_h,g,sent,ichart,ochart)
    if p_sent > 0.0:
        return p_in * p_out / p_sent
    else:
        return p_sent
403 ###################################
404 # dmv-specific version of outer() #
405 ###################################
def outer(s,t,Node,loc_N, g, sent, ichart=None, ochart=None):
    ''' http://www.student.uib.no/~kun041/dmvccm/DMVCCM.html#outer

    Outer probability of Node over s..t (inclusive) with head at loc_N.

    ochart memoises results per (s,t,Node,loc_N); ichart is handed on to
    inner(). Both are filled as a side effect. Fresh charts are used when
    they are omitted, because chart keys do not identify the sentence and
    charts shared between sentences give stale results.
    '''
    if ichart is None:
        ichart = {}
    if ochart is None:
        ochart = {}

    def e(s,t,LHS,loc_h):
        # or we could just look it up in ichart, assuming ichart to be done
        return inner(s, t, LHS, loc_h, g, sent, ichart)

    T = len(sent)-1
    sent_nums = g.sent_nums(sent)

    def f(s,t,Node,loc_N):
        if (s,t,Node,loc_N) in ochart:
            return ochart[(s, t, Node,loc_N)]
        if Node == ROOT:
            if s == 0 and t == T:
                return 1.0
            else:  # ROOT may only be used on full sentence
                return 0.0  # but we may have non-ROOTs over full sentence too
        p = 0.0

        for mom in g.mothersL(Node, sent_nums, loc_N):  # mom.L() == Node
            R = mom.R()
            mLHS = mom.LHS()
            if R == STOP:
                p += f(s,t,mLHS,loc_N) * mom.p_STOP(s,t,loc_N)  # == loc_m
            else:
                if seals(mLHS) == RGO_L:  # left attachment, head(mLHS) == head(R)
                    for r in xrange(t+1,T+1):  # t+1 to lasT
                        for loc_m in locs(head(mLHS),sent_nums,t+1,r):
                            p_m = mom.p(t+1 == loc_m)
                            p += f(s,r,mLHS,loc_m) * p_m * e(t+1,r,R,loc_m)
                elif seals(mLHS) == GO_R:  # right attachment, head(mLHS) == head(Node)
                    loc_m = loc_N
                    p_m = mom.p( t == loc_m)
                    for r in xrange(t+1,T+1):  # t+1 to lasT
                        for loc_R in locs(head(R),sent_nums,t+1,r):
                            p += f(s,r,mLHS,loc_m) * p_m * e(t+1,r,R,loc_R)

        for mom in g.mothersR(Node, sent_nums, loc_N):  # mom.R() == Node
            L = mom.L()
            mLHS = mom.LHS()
            if L == STOP:
                p += f(s,t,mLHS,loc_N) * mom.p_STOP(s,t,loc_N)  # == loc_m
            else:
                if seals(mLHS) == RGO_L:  # left attachment, head(mLHS) == head(Node)
                    loc_m = loc_N
                    p_m = mom.p( s == loc_m)
                    for r in xrange(0,s):  # first to s-1
                        for loc_L in locs(head(L),sent_nums,r,s-1):
                            p += e(r,s-1,L, loc_L) * p_m * f(r,t,mLHS,loc_m)
                elif seals(mLHS) == GO_R:  # right attachment, head(mLHS) == head(L)
                    for r in xrange(0,s):  # first to s-1
                        for loc_m in locs(head(mLHS),sent_nums,r,s-1):
                            p_m = mom.p(s-1 == loc_m)
                            p += e(r,s-1,L, loc_m) * p_m * f(r,t,mLHS,loc_m)
        ochart[s,t,Node,loc_N] = p
        return p

    return f(s,t,Node,loc_N)
# end outer(s,t,Node,loc_N, g,sent, ichart,ochart)
470 ##############################
471 # reestimation, todo: #
472 ##############################
def reest_zeros(h_nums):
    '''Return a frequency table with every STOP numerator/denominator and
    every CHOOSE denominator set to 0.0 for each head in h_nums.

    Keys are (name, 'num'|'den', h). CHOOSE numerators are keyed by
    (name, 'num', h, a) and are not created here.'''
    # todo: p_ROOT? ... p_terminals?
    f = {}
    for h in h_nums:
        for stop in ['LNSTOP','LASTOP','RNSTOP','RASTOP']:
            for nd in ['num','den']:
                f[stop,nd,h] = 0.0
        for choice in ['RCHOOSE', 'LCHOOSE']:
            f[choice,'den',h] = 0.0
    return f
def reest_freq(g, corpus):
    ''' P_STOP(-STOP|...) = 1 - P_STOP(STOP|...)

    Accumulate, over every sentence in corpus, the expected counts that
    the reestimation formulas need. Returns a dict with STOP keys
    (name, 'num'|'den', h), CHOOSE denominators (name, 'den', h) and
    CHOOSE numerators (name, 'num', h, a).

    Charts are rebuilt for every sentence. Progress/trace output is only
    printed when 'reest' is in io.DEBUG (the attachment traces used to be
    printed unconditionally from inside the innermost loops).'''
    f = reest_zeros(g.head_nums)
    # One-element holders so that c_g always sees the current sentence's
    # charts and probability.
    state = {'ichart': {}, 'ochart': {}, 'p_sent': None}

    def c_g(s,t,LHS,loc_h,sent):  # altogether 2x faster than the global c()
        ichart = state['ichart']
        ochart = state['ochart']
        p_sent = state['p_sent']  # 50 % speed increase on storing this locally
        if (s,t,LHS,loc_h) in ichart:
            p_in = ichart[s,t,LHS,loc_h]
        else:
            p_in = inner(s,t,LHS,loc_h,g,sent,ichart)
        if (s,t,LHS,loc_h) in ochart:
            p_out = ochart[s,t,LHS,loc_h]
        else:
            p_out = outer(s,t,LHS,loc_h,g,sent,ichart,ochart)

        if p_sent > 0.0:
            return p_in * p_out / p_sent
        else:
            return p_sent

    for sent in corpus:
        debug = 'reest' in io.DEBUG
        if debug:
            print(sent)
        state['ichart'] = {}
        state['ochart'] = {}
        state['p_sent'] = inner_sent(g, sent, state['ichart'])

        sent_nums = g.sent_nums(sent)
        # todo: use sum([ichart[s, t...] etc? but can we then
        # keep den and num separate within _one_ sum()-call?
        for loc_h,h in enumerate(sent_nums):
            for t in xrange(loc_h, len(sent)):
                for s in xrange(loc_h):  # s<loc(h), xrange gives strictly less
                    # left non-adjacent stop:
                    f['LNSTOP','num',h] += c_g(s, t, (SEAL, h), loc_h,sent)
                    f['LNSTOP','den',h] += c_g(s, t, (RGO_L,h), loc_h,sent)
                # left adjacent stop:
                f['LASTOP','num',h] += c_g(loc_h, t, (SEAL, h), loc_h,sent)
                f['LASTOP','den',h] += c_g(loc_h, t, (RGO_L,h), loc_h,sent)
            for t in xrange(loc_h+1, len(sent)):
                # right non-adjacent stop:
                f['RNSTOP','num',h] += c_g(loc_h, t, (RGO_L,h), loc_h,sent)
                f['RNSTOP','den',h] += c_g(loc_h, t, (GO_R, h), loc_h,sent)
            # right adjacent stop (once per head occurrence):
            f['RASTOP','num',h] += c_g(loc_h, loc_h, (RGO_L,h), loc_h,sent)
            f['RASTOP','den',h] += c_g(loc_h, loc_h, (GO_R, h), loc_h,sent)

            if debug:
                print(sent)

            # right attachment:
            for t in xrange(loc_h+1, len(sent)):
                f['RCHOOSE','den',h] += c_g(loc_h,t,(GO_R, h), loc_h, sent)
                for loc_a,a in enumerate(sent_nums):
                    if loc_a != loc_h:
                        if ('RCHOOSE','num',h,a) not in f:
                            f['RCHOOSE','num',h,a] = 0.0
                        for r in xrange(loc_h+1, t+1):  # loc_h < r <= t
                            cL = c_g(loc_h, r-1, (GO_R, h), loc_h, sent)
                            cR = c_g(r, t, (SEAL, a), loc_a, sent)
                            f['RCHOOSE','num',h,a] += cL * cR

            # left attachment:
            if debug:
                print("Lattach %s: for s in %s"%(g.numtag(h),sent[0:loc_h]))
            for s in xrange(0, loc_h):
                if debug:
                    print("\tfor t in %s"%sent[loc_h:len(sent)])
                for t in xrange(loc_h, len(sent)):
                    if debug:
                        print("\t\tc_g( %d , %d, _%s, loc_a, sent)"%(s,t,g.numtag(h)))
                    f['LCHOOSE','den',h] += c_g(s,t,(RGO_L, h), loc_h, sent)
                    if debug:
                        print("\t\tfor r in %s"%(sent[s:loc_h]))
                    for r in xrange(s, loc_h):  # s <= r < loc_h <= t
                        for i,a in enumerate(sent_nums[s:loc_h]):
                            loc_a = i+s
                            if debug:
                                print("\t\t\tc_g( %d , %d, _%s_, loc_a, sent)"%(s,r,g.numtag(a)))
                            cL = c_g( s , r, (SEAL, a), loc_a, sent)
                            if debug:
                                print("\t\t\tc_g( %d , %d, %s_, loc_h, sent)"%(r+1,t,g.numtag(h)))
                            cR = c_g(r+1, t, (RGO_L, h), loc_h, sent)
                            if ('LCHOOSE','num',h,a) not in f:
                                f['LCHOOSE','num',h,a] = 0.0
                            f['LCHOOSE','num',h,a] += cL * cR
    return f
def reestimate(g, corpus):
    '''Run one reestimation pass: collect expected counts over corpus with
    reest_freq(), then update probN/probA of every rule in g in place.

    Returns the frequency table from reest_freq().'''
    f = reest_freq(g, corpus)
    # we want to go through only non-ROOT left-STOPs..
    for r in g.all_rules():
        reest_rule(r, f, g)
    return f
def reest_rule(r,f, g):  # g just for numtag / debug output, remove eventually?
    '''Set r.probN and r.probA from the frequency table f.

    remove 0-prob rules? todo

    ROOT rules are left untouched (returns None). STOP rules get
    num/den of the matching STOP counts. Attachment rules get
    (1 - P_STOP) * P_CHOOSE, where the CHOOSE counts are taken for the
    rule's own direction; an (h, a) pair with no numerator entry counts
    as 0. A ratio whose denominator is not positive is left at 0.0.

    Direction is read off the rule shape, the same way p_ATTACH does:
    LHS == R is a left attachment (h_ -> _a_ h_), LHS == L a right
    attachment (h -> h _a_). Comparing heads instead misfiles a right
    attachment whose dependent has the same head as its governor.

    Raises ValueError for a rule that is neither STOP nor attachment.'''
    h = r.head()
    if r.LHS() == ROOT:
        return None  # not sure what todo yet here

    L = r.L()
    R = r.R()
    is_stop = (L == STOP or R == STOP)
    a = None
    if L == STOP:
        direction = 'L'
    elif R == STOP:
        direction = 'R'
    elif r.LHS() == R:  # left attachment
        direction = 'L'
        a = head(L)
    elif r.LHS() == L:  # right attachment
        direction = 'R'
        a = head(R)
    else:
        raise ValueError("can't tell attachment direction of rule: %s" % r)

    p_stopN = f[direction+'NSTOP','den',h]
    if p_stopN > 0.0:
        p_stopN = f[direction+'NSTOP','num',h] / p_stopN
    p_stopA = f[direction+'ASTOP','den',h]
    if p_stopA > 0.0:
        p_stopA = f[direction+'ASTOP','num',h] / p_stopA

    if is_stop:
        if 'reest' in io.DEBUG:
            print("p(STOP|%d=%s,%s,N): %.4f (was: %.4f)"%(h,g.numtag(h),direction, p_stopN, r.probN))
            print("p(STOP|%d=%s,%s,A): %.4f (was: %.4f)"%(h,g.numtag(h),direction, p_stopA, r.probA))
        r.probN = p_stopN
        r.probA = p_stopA
    else:  # attachments
        pchoose = f[direction+'CHOOSE','den',h]
        if pchoose > 0.0:
            pchoose = f.get((direction+'CHOOSE','num',h,a), 0.0) / pchoose
        if 'reest' in io.DEBUG:
            print("p(%d=%s|%d=%s,%s): %.4f"%(a,g.numtag(a),h,g.numtag(h),direction, pchoose))
        r.probA = (1-p_stopA) * pchoose
        r.probN = (1-p_stopN) * pchoose
618 ##############################
619 # testing functions: #
620 ##############################
# Small tagged corpus used by the test functions: one list of tags per sentence.
testcorpus = [s.split() for s in ['det nn vbd c vbd','vbd nn c vbd',
                                  'det nn vbd', 'det nn vbd c pp',
                                  'det nn vbd', 'det vbd vbd c pp',
                                  'det nn vbd', 'det nn vbd c vbd',
                                  'det nn vbd', 'det nn vbd c vbd',
                                  'det nn vbd', 'det nn vbd c vbd',
                                  'det nn vbd', 'det nn vbd c pp',
                                  'det nn vbd pp', 'det nn vbd', ]]
def testgrammar():
    '''Return a grammar initialised from testcorpus by harmonic.initialize().

    harmonic is reloaded first, so edits to it are picked up in an
    interactive session.'''
    import harmonic
    reload(harmonic)
    return harmonic.initialize(testcorpus)
def testreestimation():
    '''Regression check: reestimate the harmonic test grammar on testcorpus
    and complain if the frequency table differs from the recorded one.

    Turns on 'reest' debug output as a side effect. The comparison is
    exact float equality against values recorded from an earlier run.'''
    io.DEBUG.add('reest')
    g = testgrammar()
    expected = {('LNSTOP', 'den', 3): 12.212773236178391, ('RCHOOSE', 'num', 4, 4): 21.165993113329321, ('RASTOP', 'den', 2): 4.0, ('RNSTOP', 'num', 4): 2.5553487221351365, ('LNSTOP', 'den', 2): 1.274904052793207, ('LCHOOSE', 'num', 0, 4): 61.684182043830802, ('LASTOP', 'num', 1): 14.999999999999995, ('RASTOP', 'den', 3): 15.0, ('RCHOOSE', 'num', 4, 2): 14.336991917519207, ('LCHOOSE', 'num', 0, 1): 37.021815131880523, ('LASTOP', 'num', 4): 16.65701084787457, ('RCHOOSE', 'num', 0, 1): 7.0, ('LASTOP', 'num', 0): 4.1600647714443468, ('LCHOOSE', 'num', 2, 4): 24.584578647023847, ('LNSTOP', 'den', 4): 6.0170669155897105, ('RCHOOSE', 'num', 3, 2): 12.459705909121311, ('LASTOP', 'num', 3): 2.7872267638216113, ('RCHOOSE', 'num', 3, 1): 33.272425367699817, ('RCHOOSE', 'num', 3, 0): 31.14144593225323, ('LASTOP', 'num', 2): 2.9723139990470515, ('LCHOOSE', 'den', 4): 6.0170669155897105, ('LCHOOSE', 'num', 4, 2): 9.1747523009571736, ('LCHOOSE', 'num', 2, 1): 19.69614336189014, ('LASTOP', 'den', 2): 4.0, ('RCHOOSE', 'den', 4): 2.8705211946979836, ('RNSTOP', 'den', 3): 12.945787931730905, ('RCHOOSE', 'num', 4, 1): 18.8586091476162, ('LCHOOSE', 'num', 4, 4): 38.665348775409555, ('LASTOP', 'den', 3): 14.999999999999996, ('RCHOOSE', 'den', 3): 12.945787931730905, ('LCHOOSE', 'num', 0, 2): 11.067546558158417, ('RNSTOP', 'den', 2): 0.0, ('LCHOOSE', 'num', 3, 0): 7.0000000000000009, ('LASTOP', 'den', 0): 8.0, ('RCHOOSE', 'den', 2): 0.0, ('LCHOOSE', 'den', 1): 0.0, ('RASTOP', 'num', 4): 19.44465127786486, ('RNSTOP', 'den', 1): 3.1966410324085777, ('LASTOP', 'den', 1): 14.999999999999995, ('RCHOOSE', 'den', 1): 3.1966410324085777, ('LCHOOSE', 'den', 0): 4.6540557322109795, ('RASTOP', 'num', 3): 4.1061665495365558, ('RNSTOP', 'den', 0): 4.8282499043902476, ('LNSTOP', 'num', 4): 5.3429891521254289, ('RCHOOSE', 'num', 4, 3): 17.107681762009616, ('RCHOOSE', 'den', 0): 4.8282499043902476, ('LCHOOSE', 'den', 3): 12.212773236178391, ('RASTOP', 'num', 2): 4.0,
                ('LCHOOSE', 'num', 2, 3): 13.357495028303836, ('RCHOOSE', 'num', 0, 3): 7.0, ('LASTOP', 'den', 4): 22.0, ('LCHOOSE', 'den', 2): 1.274904052793207, ('RASTOP', 'num', 1): 12.400273895299103, ('LNSTOP', 'num', 2): 1.0276860009529487, ('RCHOOSE', 'num', 1, 0): 35.275732604734458, ('LCHOOSE', 'num', 4, 0): 38.815876450753834, ('RCHOOSE', 'num', 0, 2): 5.4153585824117219, ('RASTOP', 'num', 0): 3.1717500956097533, ('RCHOOSE', 'num', 1, 3): 44.726704286252108, ('LNSTOP', 'num', 3): 12.212773236178391, ('LCHOOSE', 'num', 4, 1): 79.825243724755524, ('RASTOP', 'den', 4): 22.0, ('RNSTOP', 'den', 4): 2.8705211946979836, ('LNSTOP', 'num', 0): 3.8399352285556518, ('RCHOOSE', 'num', 4, 0): 25.799832884567792, ('LCHOOSE', 'num', 0, 3): 29.062168307358732, ('LNSTOP', 'num', 1): 0.0, ('RCHOOSE', 'num', 0, 4): 17.272366992134707, ('LCHOOSE', 'num', 3, 4): 22.334127121479728, ('RCHOOSE', 'num', 3, 4): 68.67411582483453, ('RNSTOP', 'num', 0): 4.8282499043902476, ('LCHOOSE', 'num', 3, 2): 3.0, ('LCHOOSE', 'num', 2, 0): 13.84015209565697, ('RCHOOSE', 'num', 1, 2): 19.929658772364007, ('LCHOOSE', 'num', 3, 1): 48.942276305039663, ('RNSTOP', 'num', 1): 2.5997261047008959, ('LNSTOP', 'den', 1): 0.0, ('RCHOOSE', 'num', 1, 4): 84.97599440279366, ('RASTOP', 'den', 0): 8.0, ('RNSTOP', 'num', 2): 0.0, ('LNSTOP', 'den', 0): 4.6540557322109795, ('LCHOOSE', 'num', 4, 3): 58.697169235084786, ('RASTOP', 'den', 1): 15.0, ('RNSTOP', 'num', 3): 10.893833450463443}
    if not reestimate(g, testcorpus) == expected:
        print("sthg different in reestimation")
def testgrammar_a():
    '''Hand-built two-tag test grammar: tag 0 is 'h', tag 1 is 'a'.

    The last two DMV_Rule arguments are the non-adjacent and adjacent
    probabilities. p_STOP, p_CHOOSE and p_ROOT are passed as dummy 0s.'''
    #                                                Non, Adj
    _h_ = DMV_Rule((SEAL,0), STOP,    ( RGO_L,0), 1.0, 1.0)  # LSTOP
    h_S = DMV_Rule(( RGO_L,0),(GO_R,0), STOP,    0.4, 0.3)  # RSTOP
    h_A = DMV_Rule(( RGO_L,0),(SEAL,0),( RGO_L,0),0.2, 0.1)  # Lattach
    h_Aa= DMV_Rule(( RGO_L,0),(SEAL,1),( RGO_L,0),0.4, 0.6)  # Lattach to a
    h   = DMV_Rule((GO_R,0),(GO_R,0),(SEAL,0), 1.0, 1.0)  # Rattach
    ha  = DMV_Rule((GO_R,0),(GO_R,0),(SEAL,1), 1.0, 1.0)  # Rattach to a
    rh  = DMV_Rule( ROOT, STOP, (SEAL,0), 0.9, 0.9)  # ROOT

    _a_ = DMV_Rule((SEAL,1), STOP,    ( RGO_L,1), 1.0, 1.0)  # LSTOP
    a_S = DMV_Rule(( RGO_L,1),(GO_R,1), STOP,    0.4, 0.3)  # RSTOP
    a_A = DMV_Rule(( RGO_L,1),(SEAL,1),( RGO_L,1),0.4, 0.6)  # Lattach
    a_Ah= DMV_Rule(( RGO_L,1),(SEAL,0),( RGO_L,1),0.2, 0.1)  # Lattach to h
    a   = DMV_Rule((GO_R,1),(GO_R,1),(SEAL,1), 1.0, 1.0)  # Rattach
    ah  = DMV_Rule((GO_R,1),(GO_R,1),(SEAL,0), 1.0, 1.0)  # Rattach to h
    ra  = DMV_Rule( ROOT, STOP, (SEAL,1), 0.1, 0.1)  # ROOT

    # terminal probabilities, keyed by (node, word)
    b2 = {}
    b2[(GO_R, 0), 'h'] = 1.0
    b2[(GO_R, 1), 'a'] = 1.0

    return DMV_Grammar([ h_Aa, ha, a_Ah, ah, ra, _a_, a_S, a_A, a, rh, _h_, h_S, h_A, h ],b2,0,0,0, {0:'h',1:'a'}, {'h':0,'a':1})
def oa(s,t,LHS,loc_h):
    '''outer() on the sentence "h a" with a fresh testgrammar_a().

    Fresh charts are passed explicitly so repeated calls never share
    cached values.'''
    return outer(s,t,LHS,loc_h,testgrammar_a(),'h a'.split(), {}, {})
def ia(s,t,LHS,loc_h):
    '''inner() on the sentence "h a" with a fresh testgrammar_a().

    A fresh chart is passed explicitly so repeated calls never share
    cached values.'''
    return inner(s,t,LHS,loc_h,testgrammar_a(),'h a'.split(), {})
def ca(s,t,LHS,loc_h):
    '''c() on the sentence "h a" with a fresh testgrammar_a().

    Fresh charts are passed explicitly so repeated calls never share
    cached values.'''
    return c(s,t,LHS,loc_h,testgrammar_a(),'h a'.split(), {}, {})
def testgrammar_h():
    '''Hand-built one-tag test grammar: tag 0 is 'h'.

    The last two DMV_Rule arguments are the non-adjacent and adjacent
    probabilities. p_STOP, p_CHOOSE and p_ROOT are passed as dummy 0s.'''
    #                                                Non, Adj
    _h_ = DMV_Rule((SEAL,0), STOP,    ( RGO_L,0), 1.0, 1.0)  # LSTOP
    h_S = DMV_Rule(( RGO_L,0),(GO_R,0), STOP,    0.4, 0.3)  # RSTOP
    h_A = DMV_Rule(( RGO_L,0),(SEAL,0),( RGO_L,0), 0.6, 0.7)  # Lattach
    h   = DMV_Rule((GO_R,0),(GO_R,0),(SEAL,0), 1.0, 1.0)  # Rattach
    rh  = DMV_Rule( ROOT, STOP, (SEAL,0), 1.0, 1.0)  # ROOT
    # terminal probabilities, keyed by (node, word)
    b2 = {}
    b2[(GO_R, 0), 'h'] = 1.0

    return DMV_Grammar([ rh, _h_, h_S, h_A, h ],b2,0,0,0, {0:'h'}, {'h':0})
def testreestimation_h():
    '''Run one reestimation pass of testgrammar_h() on the sentence
    "h h h", with 'reest' debug output switched on.'''
    io.DEBUG.add('reest')
    g = testgrammar_h()
    reestimate(g, ['h h h'.split()])
def regression_tests():
    '''Compare a few inner/outer probabilities on the hand-built grammars
    with recorded values (as formatted strings) and print a message for
    every mismatch. Every call gets fresh charts.'''
    def test(wanted, got):
        if not wanted == got:
            print("Regression! Should be %s: %s" % (wanted, got))

    g_dup = testgrammar_h()

    test("0.120",
         "%.3f" % inner(0, 1, (SEAL,0), 0, g_dup, 'h h'.split(), {}))

    test("0.063",
         "%.3f" % inner(0, 1, (SEAL,0), 1, g_dup, 'h h'.split(), {}))

    test("0.0498",
         "%.4f" % inner(0, 2, (SEAL,0), 2, g_dup, 'h h h'.split(), {}))

    # (1,0) is (RGO_L,0)
    test("0.58" ,
         "%.2f" % outer(1,2,(1,0),2,testgrammar_h(),'h h h'.split(),{},{}))

    # (0,0) is (GO_R,0)
    test("0.1089" ,
         "%.4f" % outer(0,0,(0,0),0,testgrammar_a(),'h a'.split(),{},{}))
if __name__ == "__main__":
    # Script entry point: time one run of testreestimation() with all
    # debug flags cleared, then run the regression tests.
    io.DEBUG.clear()

    # import profile
    # profile.run('testreestimation()')

    import timeit
    # The setup code re-imports this module under its own name, dmv.
    print(timeit.Timer("dmv.testreestimation()",
                       "import dmv\nreload(dmv)").timeit(1))
    regression_tests()