initial
[prop.git] / prop-src / rwgen3.pcc
blob7ac11bafea3f016b797d9b73b8c357f533abd5c8
1 ///////////////////////////////////////////////////////////////////////////////
2 //
3 //  This file implements the dynamic tree parser algorithm, which is
4 //  used to parse a tree grammar with associated reduction cost functions.
5 //
6 ///////////////////////////////////////////////////////////////////////////////
7 #include <iostream.h>
8 #include <AD/contain/bitset.h>
9 #include "funmap.ph"
10 #include "ir.ph"
11 #include "ast.ph"
12 #include "matchcom.ph"
13 #include "type.h"
14 #include "hashtab.h"
15 #include "rwgen.h"
16 #include "list.h"
17 #include "options.h"
19 extern Id redex_name(Ty);
21 ///////////////////////////////////////////////////////////////////////////////
23 //  Top level method to generate a dynamic tree parser.
24 //  We use a simple dynamic programming algorithm.
26 ///////////////////////////////////////////////////////////////////////////////
27 void RewritingCompiler::gen_dynamic_rewriter (FunctorMap& F)
28 {  
29    generate_state_record(F);        // generate the state record definition
30    generate_accept_rules_tables(F); // generate the accept rule tables
31    generate_closures(F);            // generate the closure routines
32    generate_dynamic_labelers(F);    // generate the labeler functions
33    generate_reducers(F);            // generate the reducer functions
34    // Generate report
35    if (options.generate_report) F.print_report(open_logfile());
38 ///////////////////////////////////////////////////////////////////////////////
40 //  Method to generate the state record.
42 ///////////////////////////////////////////////////////////////////////////////
43 void RewritingCompiler::generate_state_record (FunctorMap& F)
44 {  pr("\n"
45       "%^%/"
46       "%^// State record for rewrite class %s"
47       "%^%/"
48       "%^struct %s_StateRec {\n"
49       "%^   TreeTables::Cost cost[%i]; // cost for each non-terminal"
50       "%^   struct { // accept rule number",
51       F.class_name, F.class_name, F.nonterm_map.size()+1
52      );
54    foreach_entry (e, F.nonterm_rules_bits)
55    {  Id  lhs  = Id(e->k);
56       int bits = int(e->v);
57       pr("%^      unsigned int _%S : %i;", lhs, bits);
58    }
60    pr("%^   } rule;"
61       "%^};\n\n");
64 ///////////////////////////////////////////////////////////////////////////////
66 //  Method to generate the accept rule tables.
68 ///////////////////////////////////////////////////////////////////////////////
69 void RewritingCompiler::generate_accept_rules_tables (FunctorMap& F)
70 {  pr("%^%/"
71       "%^// Accept rules tables for rewrite class %s"
72       "%^%/",
73       F.class_name
74      );
76    foreach_entry (e, F.nonterm_rules)
77    {  Id         lhs   = Id(e->k);
78       MatchRules rules = MatchRules(e->v);
79       int max_rule     = 0;
81       match while (rules)
82       {  #[ one ... rest ]:
83          {  if (max_rule < one->rule_number) max_rule = one->rule_number;
84             rules = rest;
85          }
86       }
88       Id storage_class = max_rule < 128 ? "char" : "short";
90       pr ("%^const %s %s_%S_accept[] = { -1, ", 
91           storage_class, F.class_name, lhs);
93       rules = MatchRules(e->v);
94       match while (rules)
95       {  #[ one ... rest ]:
96          {  pr ("%i%s", one->rule_number, (rest != #[] ? ", " : ""));
97             rules = rest;
98          }
99       }
100       pr (" };\n\n");
101    }
104 ///////////////////////////////////////////////////////////////////////////////
106 //  Method to generate the closure routines for each non-terminal
107 //  which appears the rhs of a chain rule.
109 ///////////////////////////////////////////////////////////////////////////////
110 void RewritingCompiler::generate_closures (FunctorMap& F)
111 {  pr("%^%/"
112       "%^// Closure methods for rewrite class %s"
113       "%^%/",
114       F.class_name
115      );
117    // Generate the headers first
118    {  foreach_entry (e, F.chain_rules)
119       {  Id         rhs   = Id(e->k);
120          MatchRules rules = MatchRules(e->v);
121          Ty         ty    = rules->#1->ty; // type of states.
122          pr ("%^static void %s_%S_closure(%t,int cost);\n",
123              F.class_name, rhs, ty, "redex"
124             );
125       }
126    }
128    pr ("\n");
130    // Then generate the definitions.
131    {  foreach_entry (e, F.chain_rules)
132       {  Id         rhs   = Id(e->k);
133          MatchRules rules = MatchRules(e->v);
134          gen_closure(F,rhs,rules);
135       }
136    }
139 ///////////////////////////////////////////////////////////////////////////////
141 //  Method to generate the closure routine for one non-terminal.
143 ///////////////////////////////////////////////////////////////////////////////
144 void RewritingCompiler::gen_closure (FunctorMap& F, Id rhs, MatchRules rules)
145 {  Ty ty = rules->#1->ty; // type of states.
146    pr ("%^static void %s_%S_closure(%t,int cost__)\n"  
147        "%^{%+"
148        "%^%s_StateRec * _state_rec = (%s_StateRec *)(%s->get_state_rec());",
149         F.class_name, rhs, ty, "redex", F.class_name,
150         F.class_name, redex_name(ty)
151       );
152    int rule_no = 1; 
153    rules = rev(rules);
154    match while (rules)
155    {  #[ MATCHrule(lhs,pat,guard,cost,_) ... rest ]:
156       {  Exp cost_exp;
157          match (cost)
158          {  NOcost:        { cost_exp = LITERALexp(INTlit(0)); }
159          |  INTcost i:     { cost_exp = LITERALexp(INTlit(i)); }
160          |  EXPcost (e,_): 
161             {  // Avoid recomputation of cost
162                Id v = vars.new_label();
163                pr ("%^const int %s = %e;", v, e); 
164                cost_exp = IDexp(v); 
165             }
166          }
167          int nonterm_number = int(F.var_map[lhs]);
168         
169          if (nonterm_number > 0)
170          {  pr ("%^if (cost__ + %e < _state_rec->cost[%i])"
171                 "%^{  _state_rec->cost[%i] = cost__ + %e;"   
172                 "%^   _state_rec->rule._%S = %i;",
173                 cost_exp, nonterm_number, nonterm_number, cost_exp, lhs, rule_no
174                );
176             // Chain rules
177             if (F.chain_rules.contains(lhs))
178             {  pr ("%^   %s_%S_closure(redex,cost__ + %e);",
179                    F.class_name, lhs, cost_exp);
180             }
182             pr ("%^}");
183          }
184          rule_no++;
185          rules = rest;
186       }
187    }
189    pr ("%-%^}\n\n");
192 ///////////////////////////////////////////////////////////////////////////////
194 //  Method to generate the dynamic labelers
196 ///////////////////////////////////////////////////////////////////////////////
197 void RewritingCompiler::generate_dynamic_labelers (FunctorMap& F)
198 {  
199    ////////////////////////////////////////////////////////////////////////////
200    //  Generate a dynamic labeler for each datatype
201    ////////////////////////////////////////////////////////////////////////////
202    foreach_entry (e, F.type_map)
203    {  Ty ty = Ty(F.type_map.key(e));
204       debug_msg("[Rewrite class %s: generating dynamic labeler for datatype %T\n",
205                 F.class_name, ty);
206       gen_dynamic_datatype_labeler(F, ty);
207    }
210 ///////////////////////////////////////////////////////////////////////////////
212 //  Method to generate a labeler routine for one datatype.
214 ///////////////////////////////////////////////////////////////////////////////
215 void RewritingCompiler::gen_dynamic_datatype_labeler(FunctorMap& F, Ty ty)
216 {  
217    ///////////////////////////////////////////////////////////////////////////
218    // Generate the protocol of this labeler routine
219    ///////////////////////////////////////////////////////////////////////////
220    pr ("%^void %s::labeler (%t)"
221        "%^{%+"
222           "%^int cost__;",
223        F.class_name, ty, "redex");
225    ///////////////////////////////////////////////////////////////////////////
226    // Name of the redex inside this routine. 
227    ///////////////////////////////////////////////////////////////////////////
228    Id redex = redex_name(ty);
230    ///////////////////////////////////////////////////////////////////////////
231    //
232    // Allocate and initialize a state record.
233    //
234    ///////////////////////////////////////////////////////////////////////////
235    pr ("%^%s_StateRec * _state_rec = (%s_StateRec *)mem[sizeof(%s_StateRec)];"
236        "%^%s->set_state_rec(_state_rec);"
237        "%^_state_rec->cost[0] = 0;",
238        F.class_name, F.class_name, F.class_name, redex);
239    for (int i = 1; i <= F.nonterm_map.size(); i++)
240    {  pr("%^_state_rec->cost[%i] = ", i);
241    }
242    pr ("%i;\n", TreeTables::infinite_cost);
244    ///////////////////////////////////////////////////////////////////////////
245    // Generate code for bottomup traversal on the datatype
246    ///////////////////////////////////////////////////////////////////////////
247    gen_dynamic_traversals(F, ty);
249    ///////////////////////////////////////////////////////////////////////////
250    // Update the state record.
251    ///////////////////////////////////////////////////////////////////////////
253    ///////////////////////////////////////////////////////////////////////////
254    // End of this routine
255    ///////////////////////////////////////////////////////////////////////////
256    pr ("%^%-%^}\n\n");
259 ///////////////////////////////////////////////////////////////////////////////
261 //  Method to generate code for dynamic traversals of one datatype.
263 ///////////////////////////////////////////////////////////////////////////////
264 void RewritingCompiler::gen_dynamic_traversals(FunctorMap& F, Ty ty)
265 {  if (!F.rule_map.contains(ty))
266    {  bug("%Lgen_dynamic_traversals: %t\n", ty); }
267    MatchRules rules = MatchRules(F.rule_map[ty]);
268    MatchExps  exps  = #[ MATCHexp(IDexp("redex"),0) ];
269    rules = rev(rules);
270    gen_match_stmt(exps, rules, 
271       MATCHnocheck + MATCHnotrace + MATCHall + MATCHwithtreecost);
274 ///////////////////////////////////////////////////////////////////////////////
276 //  Method to annotate the matching tree with tree reduction cost nodes.
277 //  Return the set of rules that matches.  The idea is to hoist
278 //  the cost minimalization rules as near the root as possible.  
280 ///////////////////////////////////////////////////////////////////////////////
281 const BitSet * label_treecost (Match& m, int, MatchRules rules);
282 const BitSet * label_treecost (Match& m, int, MatchRules rules, Match&, Match&);
283 const BitSet * label_treecost (Match& m, int, MatchRules rules, int, Match[], Match&, Bool);
285 const BitSet * label_treecost (Match& m, int N, MatchRules rules)
286 {  match (m)
287    { FAILmatch || SUCCESSmatch _:    { return 0; } 
288    | TREECOSTmatch(m, set, rules):   { return set; }
289    | COSTmatch(n, cost, set, rules): { return set; }
290    | SUCCESSESmatch(_, set, _):      { return set; } 
291    | GUARDmatch(e, a, b):
292      {  return label_treecost(m,N,rules,a,b); }
293    | CONSmatch(pos, _, _, _, n, a, b):
294      {  return label_treecost(m,N,rules,n,a,b,true); }
295    | LITERALmatch(pos, e, ls, n, a, b):
296      {  return label_treecost(m,N,rules,n,a,b,false); }
297    | RANGEmatch(pos, e, lo, hi, a, b):
298      {  return label_treecost(m,N,rules,a,b); }
299    | _:  { bug("label_treecost: %M", m); return 0; }
300    }
303 ///////////////////////////////////////////////////////////////////////////////
305 //  Method to annotate the matching tree with tree reduction cost nodes.
306 //  Return the set of rules that matches.
308 ///////////////////////////////////////////////////////////////////////////////
309 const BitSet * label_treecost (Match& m, int N, MatchRules rules, Match& a, Match& b)
310 {  const BitSet * s1 = label_treecost(a,N,rules);
311    const BitSet * s2 = label_treecost(b,N,rules);
312    if (s1 == 0 || s2 == 0) return 0;
313    BitSet * S = new (mem_pool, N) BitSet;
314    S->Intersect(*s1,*s2);
315    if (S->count() == 0) return 0;
316    m = TREECOSTmatch(m,S,rules);
317    m->label = 0; m->shared = 0;
318    return S;
321 ///////////////////////////////////////////////////////////////////////////////
323 //  Method to annotate the matching tree with tree reduction cost nodes.
324 //  Return the set of rules that matches.
326 ///////////////////////////////////////////////////////////////////////////////
327 const BitSet * label_treecost (Match& m, int N, MatchRules rules, 
328                               int fanout, Match a[], Match& b, Bool ignore)
329 {  const BitSet * Sb = label_treecost(b,N,rules);
330    BitSet * S = new (mem_pool, N) BitSet;
331    Bool empty = ! ignore && Sb == 0;
332    if (! ignore) S->copy(*Sb);
333    else          S->complement();
334    for (int i = 0; i < fanout; i++) 
335    {  const BitSet * Sa = label_treecost(a[i],N,rules);
336       if (Sa) { if (! empty) S->Intersect(*Sa); }
337       else    empty = true;
338    }
339    if (empty || S->count() == 0) return 0;
340    m = TREECOSTmatch(m,S,rules);
341    debug_msg("[NEW TREE]\n");
342    m->label = 0; m->shared = 0;
343    return S;
346 ///////////////////////////////////////////////////////////////////////////////
348 //  Method to prune the matching tree with tree reduction cost nodes.
349 //  We reduce unnecessary cost minimalization nodes.
351 ///////////////////////////////////////////////////////////////////////////////
352 void prune_treecost (Match& m, const BitSet * ignore)
353 {  match (m)
354    { FAILmatch || SUCCESSmatch _: { return; }
355    | COSTmatch(n, cost, set, rules): 
356      { if (ignore) { set->Difference(*ignore); 
357                      if (set->count() == 0) m = FAILmatch;
358                    }
359      }
360    | SUCCESSESmatch(_, set, _):      
361      { if (ignore) { set->Difference(*ignore); 
362                      if (set->count() == 0) m = FAILmatch;
363                    }
364      }
365    | TREECOSTmatch(a, set, rules):   
366      {  BitSet * new_ignore;
367         if (ignore) { new_ignore = new (mem_pool, ignore->size()) BitSet;
368                       new_ignore->Union(*set);
369                     }
370         else new_ignore = set;
371         prune_treecost(a,new_ignore);
372         if (ignore) { set->Difference(*ignore); 
373                       if (set->count() == 0) m = a; 
374                     }
375      }
376    | GUARDmatch(e, a, b):
377      {  prune_treecost(a,ignore); prune_treecost(b,ignore); }
378    | RANGEmatch(pos, e, lo, hi, a, b):
379      {  prune_treecost(a,ignore); prune_treecost(b,ignore); }
380    | CONSmatch(pos, _, _, _, n, a, b):
381      {  for (int i = 0; i < n; i++) prune_treecost(a[i],ignore);
382         prune_treecost(b,ignore); }
383    | LITERALmatch(pos, e, ls, n, a, b):
384      {  for (int i = 0; i < n; i++) prune_treecost(a[i],ignore);
385         prune_treecost(b,ignore); }
386    | _:  { bug("prune_treecost: %M", m); }
387    }
390 ///////////////////////////////////////////////////////////////////////////////
392 //  Method to insert traversal code
394 ///////////////////////////////////////////////////////////////////////////////
395 void add_traversal (Match& m)
396 {  match (m)
397    {  CONSmatch(_, _, ty, alg_ty, n, a, _):
398       {  for (int i = 0; i < n; i++)
399             a[i] = TREELABELmatch(a[i],ty,alg_ty,i);
400       }
401    |  TREECOSTmatch(m,_,_): {  add_traversal(m); }
402    |  _:                    {  bug ("add_traversal: %M", m); }
403    }
406 ///////////////////////////////////////////////////////////////////////////////
408 //  Method to translate the matching tree into a tree with 
409 //  tree reduction cost nodes.
411 ///////////////////////////////////////////////////////////////////////////////
412 Match translate_treecost (Match m, MatchRules rules)
413 {  debug_msg("%Ltranslating rules into treecost\n");
414    label_treecost(m,length(rules),rules);
415    prune_treecost(m,0);
416    add_traversal(m);
417    return m; 
420 ///////////////////////////////////////////////////////////////////////////////
422 //  Return the encoded rule number.
424 ///////////////////////////////////////////////////////////////////////////////
425 static int rule_of(FunctorMap * Fmap, Id lhs, int r)
426 {  int rule_no = 1;
427    MatchRules rules = MatchRules(Fmap->nonterm_rules[lhs]);
428    match while (rules)
429    {  #[ one ... rest ]:
430       {  if (one->rule_number == r) return rule_no;
431          rules = rest;
432          rule_no++;
433       }
434    }
435    bug("rule_of");
436    return 0;
439 ///////////////////////////////////////////////////////////////////////////////
441 //  Method for generating labeling code for pattern parsing.
443 ///////////////////////////////////////////////////////////////////////////////
444 void RewritingCompiler::gen_treelabel_match (Match m, Ty ty, Ty alg_ty, int k)
445 {  // Generate traversal code
446    match (alg_ty) and (deref_all ty)
447    {  DATATYPEty({ terms ... },_), TYCONty(_,tys):
448       {  Cons cons = terms[k];
449          Ty arg_ty = apply_ty(cons->cons_ty,tys);
450          Exp e     = select(IDexp("redex"),cons);
451          if (cons->ty == NOty)
452          {  error("%Ltree parsing mode cannot be used on datatype with unit constructors: %T\n", alg_ty);
453          }
454          match (arg_ty)
455          {  NOty:         // skip
456          |  TUPLEty tys:           
457             {  int i = 1; 
458                for_each(Ty, t, tys)
459                {  if (Fmap->is_rewritable_type(t))
460                      pr("%^labeler(%e);", DOTexp(e,index_of(i))); i++;
461                } 
462             }
463          |  RECORDty (labels,_,ts):
464             {  Ids ids; Tys tys;
465                for (ids = labels, tys = ts; ids && ts; 
466                     ids = ids->#2, tys = tys->#2)
467                {  if (Fmap->is_rewritable_type(tys->#1))
468                      pr("%^labeler(%e);", DOTexp(e,ids->#1));
469                } 
470             }
471          |  ty: {  if (Fmap->is_rewritable_type(ty)) pr("%^labeler(%e);",e); }
472          }
473       }
474    |  _: { bug("RewritingCompiler::gen_treelabel_match"); } 
475    }
477    // Generate labeling code
478    gen(m);
481 ///////////////////////////////////////////////////////////////////////////////
483 //  Method for generating treecost minimalization code from a pattern
484 //  matching tree.
486 ///////////////////////////////////////////////////////////////////////////////
487 void RewritingCompiler::gen_treecost_match (Match m, const BitSet * set, 
488                                             MatchRules rules)
489 {  gen(m);
490    int rule_no = 0;
491    match while (rules)
492    {  #[r as MATCHrule(lhs,pat,guard,cost,_) ... rest]:
493       {  if (set->contains(rule_no))
494          {  if (lhs) 
495             {  pr ("%^// %r\n", r);
496                Exp cost_exp;
497                match (cost)
498                {  NOcost:        { cost_exp = LITERALexp(INTlit(0)); }
499                |  INTcost i:     { cost_exp = LITERALexp(INTlit(i)); }
500                |  EXPcost (e,_): { cost_exp = IDexp(vars.new_label()); 
501                                    pr("%^int %e = %e;",cost_exp,e);
502                                  }
503                }
504                int nonterm_number = int(Fmap->var_map[lhs]);
505                pr ("%^cost__ = %e + %e;"
506                    "%^if (cost__ < _state_rec->cost[%i])"
507                    "%^{   _state_rec->cost[%i] = cost__;"
508                    "%^    _state_rec->rule._%S = %i;",
509                    cost_exp, Fmap->cost_expr(lhs,pat),
510                    nonterm_number, nonterm_number, 
511                    lhs, rule_of(Fmap,lhs,r->rule_number));
513                if (Fmap->chain_rules.contains(lhs))
514                   pr ("%^   %s_%S_closure(redex, cost__);",
515                          Fmap->class_name, lhs); 
516                pr ("%^}"); 
517             }
518          }
519          rules = rest; rule_no++;
520       }
521    }