// Rewrites "Semicolon" and "Comma" nodes found among root's children.
// For a Semicolon a new Block is allocated, for a Comma a new List; in both
// cases the new node copies the original's `loc`, has its `list` passed to
// traverseTree(), and is then set as the parent of every entry in that list.
// NOTE(review): this listing omits some original lines (the embedded line
// numbers jump, e.g. 12 -> 15), so where `node` is bound and how the new
// Block/List is put in place of the old node is not visible here -- confirm
// against the full file.
9 void simplify_ast(ASTNode
*root
)
11 NodeStore v
= root
->getChildren();
12 for (NodeStore::iterator it
= v
.begin(); it
!= v
.end(); ++it
) {
// Semicolon case: build a Block for this node.
15 if (node
->getType() == std::string("Semicolon")) {
16 Semicolon
*semi
= dynamic_cast<Semicolon
*>(node
);
19 Block
*block
= new Block();
// Carry over the source location of the Semicolon node.
21 block
->loc
= semi
->loc
;
// traverseTree() receives the new block's list (presumably to fill it -- confirm).
22 semi
->traverseTree(block
->list
, semi
);
// Point every entry of the list at the new Block as its parent.
23 for (NodeStore::iterator it2
= block
->list
.begin(); it2
!= block
->list
.end(); ++it2
)
24 (*it2
)->parent
= block
;
// Comma case: same steps, but the container is a List.
27 } else if (node
->getType() == std::string("Comma")) {
28 Comma
* comma
= dynamic_cast<Comma
*>(node
);
31 List
*block
= new List();
// Carry over the source location of the Comma node.
33 block
->loc
= comma
->loc
;
34 comma
->traverseTree(block
->list
, comma
);
// Point every entry of the list at the new List as its parent.
35 for (NodeStore::iterator it2
= block
->list
.begin(); it2
!= block
->list
.end(); ++it2
)
36 (*it2
)->parent
= block
;
// Consistency check: for every child of `root`, verifies that the child's
// `parent` pointer equals `root`, logging a message when it does not, and
// recurses into each child.
// NOTE(review): the return statements are not part of this listing; the
// recursion test below suggests false means "check failed" -- confirm.
46 bool tree_check_parents(ASTNode
* root
)
48 NodeStore v
= root
->getChildren();
49 for (NodeStore::iterator it
= v
.begin(); it
!= v
.end(); ++it
) {
// A child whose parent pointer is not `root` is reported.
52 if ((*it
)->parent
!= root
) {
53 LOG("error:%s: failed check: %s->%s\n", root
->locStr().c_str(),
54 root
->getType(), (*it
)->getType());
// Recurse into the child's own subtree.
57 if (!tree_check_parents(*it
))
// Searches from `p` for a node whose getType() equals `type`.
// The loop runs while `p` is non-null and its type differs from `type`.
// NOTE(review): the loop body and the return are not in this listing;
// presumably it steps to p->parent and returns p (null if none) -- confirm.
63 static inline ASTNode
* find_parent(ASTNode
* p
, std::string type
)
65 while (p
&& p
->getType() != type
)
// Gathers nodes of the given `type` from the subtree below `p` into `r`.
// Each child is tested against `type`, and the result of the recursive call
// on that child is appended to `r`.
// NOTE(review): the statement guarded by the type test and the final return
// are not in this listing (presumably push_back / return r -- confirm).
70 static inline NodeStore
find_all_children(ASTNode
* p
, std::string type
)
72 NodeStore r
, v
= p
->getChildren();
73 for (NodeStore::iterator it
= v
.begin(); it
!= v
.end(); ++it
) {
77 if (node
->getType() == type
)
// Recurse and append everything found below this child.
79 NodeStore tmp
= find_all_children(node
, type
);
80 r
.insert(r
.end(), tmp
.begin(), tmp
.end());
// Looks `name` up in the symbol tables held by `pi`, in this order:
// staticvars, pieces, locals (accessed through a pointer), then functions.
// Returns the AST node stored for the first match.
// NOTE(review): the not-found return and any null check on pi.locals are
// not in this listing -- confirm against the full file.
85 static ASTNode
* find_symbol(parse_info
&pi
, std::string name
)
// 1) static variables
88 SymbolTable::iterator it
= pi
.staticvars
.find(name
);
89 if (it
!= pi
.staticvars
.end())
90 return it
->second
.first
;
// 2) pieces
92 it
= pi
.pieces
.find(name
);
93 if (it
!= pi
.pieces
.end())
94 return it
->second
.first
;
// 3) locals (pi.locals is dereferenced with ->)
97 it
= pi
.locals
->find(name
);
98 if (it
!= pi
.locals
->end())
99 return it
->second
.first
;
// 4) functions: a FunctionTable entry exposes its node as `.ast`
104 FunctionTable::iterator it
= pi
.functions
.find(name
);
105 if (it
!= pi
.functions
.end())
106 return it
->second
.ast
;
// Checks that `leaf` has an enclosing "Loop" node, using find_parent().
// Logs where the search stopped, or reports "<type> not in loop" as an error.
// NOTE(review): the conditions selecting between the two log calls and the
// return values are not in this listing -- confirm.
111 static bool validate_loop_exit(ASTNode
*leaf
)
113 ASTNode
*p
= find_parent(leaf
, "Loop");
115 LOG("stopped at %s\n", p
->getType());
117 LOGERROR(leaf
, "%s not in loop\n", leaf
->getType());
// Validates where a Declaration appears, based on its Typename:
//   "var"                   -> must have an enclosing "Function" node
//   "static-var" / "piece"  -> must NOT have an enclosing "Function" node
//   anything else           -> reported as an unknown type
// NOTE(review): return statements and any null checks on the dynamic_cast
// results are not in this listing -- confirm.
121 static bool validate_declarations(ASTNode
*node
)
123 Declaration
*n
= dynamic_cast<Declaration
*>(node
);
125 Typename
*t
= dynamic_cast<Typename
*>(n
->type
);
// Local variable: only legal inside a function.
127 if (t
->name
== "var") {
128 if (!find_parent(node
, "Function")) {
129 LOGERROR(t
, "local var declaration when not in function\n");
// Global kinds: only legal outside a function.
132 } else if (t
->name
== "static-var" || t
->name
== "piece") {
133 if (find_parent(node
, "Function")) {
134 LOGERROR(t
, "global declaration when in function\n");
138 LOGERROR(t
, "unknown type: %s\n", t
->name
.c_str());
// Validates a Jump node: collects all "Label" nodes under the enclosing
// "Function" and compares each label's name with the jump's `dest`.
// Returns true only when `labelnum` equals 1.
// NOTE(review): the declaration and updates of `labelnum`, and any null
// checks on `j` / `fun`, are not in this listing; presumably labelnum counts
// matching labels -- confirm.
145 bool validate_jump(ASTNode
*node
)
148 Jump
*j
= dynamic_cast<Jump
*>(node
);
// The enclosing function limits the label search.
150 Function
*fun
= dynamic_cast<Function
*>(find_parent(node
, "Function"));
152 NodeStore labels
= find_all_children(fun
, "Label");
154 for (NodeStore::iterator it
= labels
.begin(); it
!= labels
.end(); ++it
) {
155 Label
*l
= dynamic_cast<Label
*>(*it
);
// Label name matches the jump destination.
157 if (l
->name
== j
->dest
) {
160 LOGERROR(j
, "%s - duplicate destination: \"%s\"\n", j
->getType(), j
->dest
.c_str());
166 LOGERROR(j
, "label \"%s\" not found in function %s\n", j
->dest
.c_str(), fun
->getFunctionName().c_str());
167 return labelnum
== 1;
// Inserts identifier `id` into symbol table `tbl`.
// find_symbol() is consulted first; the duplicate-declaration error and the
// "previously defined here" note refer to the node it returned.
// A new entry maps id->name to Symbol(id, tbl->size()), i.e. the second
// Symbol argument is the table size at the time of insertion.
// NOTE(review): the condition around the error path and the return values
// are not in this listing -- confirm.
170 static bool insert_into_symtable(parse_info
& pi
, SymbolTable
*tbl
, Ident
*id
)
174 ASTNode
* tmp
= find_symbol(pi
, id
->name
);
176 LOGERROR(id
, "duplicate declaration of \"%s\"\n",
178 LOGLOC(tmp
, "previously defined here\n");
181 tbl
->insert(SymbolTable::value_type(id
->name
, Symbol(id
, tbl
->size())));
// Inserts the identifier(s) described by `paramNode` into `sym`.
// The visible code handles paramNode as an Ident, as an Assign (its `left`
// is taken as the Ident), or as a List whose entries are Ident or Assign.
// Each failed insert_into_symtable() call contributes 1 to the result.
// Any other node type is reported and hits assert(0).
// NOTE(review): the type tests choosing between these cases, the declaration
// of `errors`, and the final return are not in this listing -- confirm.
187 static int insert_many_into_symtable(parse_info
& pi
, SymbolTable
& sym
, ASTNode
*paramNode
)
// Single identifier.
192 Ident
*id
= dynamic_cast<Ident
*>(paramNode
);
194 return !insert_into_symtable(pi
, &sym
, id
);
// Assignment: the identifier is the left-hand side.
197 Assign
*as
= dynamic_cast<Assign
*>(paramNode
);
199 id
= dynamic_cast<Ident
*>(as
->left
);
201 return !insert_into_symtable(pi
, &sym
, id
);
// List of identifiers and/or assignments.
205 List
*li
= dynamic_cast<List
*>(paramNode
);
207 for (NodeStore::iterator it
= li
->list
.begin(); it
!= li
->list
.end(); ++it
) {
208 as
= dynamic_cast<Assign
*>(*it
);
210 id
= dynamic_cast<Ident
*>(as
->left
);
212 id
= dynamic_cast<Ident
*>(*it
);
// Count one error per failed insertion.
214 errors
+= !insert_into_symtable(pi
, &sym
, id
);
// Unsupported node type.
219 EPRINTF("bad arguments: type %s\n", paramNode
->getType());
220 assert(0 && "bad arguments");
// Adds the names declared by a Declaration node to a table chosen by its
// Typename ("piece", "static-var" or "var"); other type names are reported
// as unknown. Returns true when insert_many_into_symtable() reports 0 errors.
// NOTE(review): only the "static-var" assignment (tbl = &pi.staticvars) is
// visible; the tables used for "piece" and "var", and the declaration of
// `tbl`, are not in this listing -- confirm.
224 static bool add_declarations_to_symtable(parse_info
& pi
, ASTNode
* node
)
226 Declaration
*decl
= dynamic_cast<Declaration
*>(node
);
228 Typename
*type
= dynamic_cast<Typename
*>(decl
->type
);
231 if (type
->name
== "piece")
233 else if (type
->name
== "static-var")
234 tbl
= &pi
.staticvars
;
235 else if (type
->name
== "var")
238 LOGERROR(type
, "unknown type %s\n", type
->name
.c_str());
// insert_many_into_symtable() returns an error count; invert it to a bool.
242 return !insert_many_into_symtable(pi
, *tbl
, decl
->list
);
// Registers a Label node in pi.labels, keyed by the label's name.
// When the name is not yet present it is inserted; otherwise a
// "duplicate label" error is logged together with the location of the
// existing entry.
// NOTE(review): return statements are not in this listing -- confirm.
245 static bool add_label(parse_info
&pi
, ASTNode
* node
)
247 Label
* label
= dynamic_cast<Label
*>(node
);
249 LabelTable::iterator it
= pi
.labels
.find(label
->name
);
// Name not seen before: record it.
250 if (it
== pi
.labels
.end()) {
251 pi
.labels
.insert(LabelTable::value_type(label
->name
, label
));
// Duplicate path: `it` refers to the earlier label.
255 assert(it
!= pi
.labels
.end());
257 LOGERROR(label
, "duplicate label \"%s\"\n", label
->name
.c_str());
258 LOGLOC(it
->second
, "first declared here\n");
// Reads the function's parameters: inserts them into `locals`, then sizes
// `params` to match and stores each name at the index kept in its symbol
// entry (it->second.second).
// NOTE(review): the return statement is not in this listing; presumably it
// returns `errors` -- confirm.
265 int function_info::readParams(parse_info
& pi
, ASTNode
*paramNode
)
267 int errors
= insert_many_into_symtable(pi
, locals
, paramNode
);
// One slot per local symbol.
268 params
.resize(locals
.size());
// Place each parameter name at its recorded index.
269 for (SymbolTable::iterator it
= locals
.begin(); it
!= locals
.end(); ++it
)
270 params
[it
->second
.second
] = it
->first
;
// Evaluates a binary operator whose two operands are Number nodes and
// returns the result as a double. The operator is chosen by the node's
// getType() string. Bitwise and shift operators work on the operands
// truncated to int; logical operators work on the operands converted to bool.
// An unrecognised operator is logged and reported by throwing
// std::runtime_error.
//
// Fixes relative to the previous revision:
//   * "LogicalOr" was folded with `&&`; it now uses `||`.
//   * "GreaterThan" was folded with `>=`; it now uses `>`.
//
// NOTE(review): `l` and `r` come from dynamic_cast and are dereferenced
// below; callers are expected to pass only Number operands -- confirm.
274 static double calculate_binop_val(BinaryOperator
* binop
)
276 Number
* l
= dynamic_cast<Number
*>(binop
->left
);
277 Number
* r
= dynamic_cast<Number
*>(binop
->right
);
279 std::string name
= binop
->getType();
// Arithmetic operators.
281 return l
->val
+ r
->val
;
282 else if (name
== "Subtract")
283 return l
->val
- r
->val
;
284 else if (name
== "Multiply")
285 return l
->val
* r
->val
;
286 else if (name
== "Divide")
287 return l
->val
/ r
->val
;
// Shift and bitwise operators: operands are truncated to int first.
288 else if (name
== "ShiftLeft")
289 return (int)(l
->val
) << (int)(r
->val
);
290 else if (name
== "ShiftRight")
291 return (int)(l
->val
) >> (int)(r
->val
);
292 else if (name
== "BitwiseAnd")
293 return (int)(l
->val
) & (int)(r
->val
);
294 else if (name
== "BitwiseOr")
295 return (int)(l
->val
) | (int)(r
->val
);
296 else if (name
== "BitwiseXor")
297 return (int)(l
->val
) ^ (int)(r
->val
);
// Logical operators.
298 else if (name
== "LogicalAnd")
299 return (bool)(l
->val
) && (bool)(r
->val
);
300 else if (name
== "LogicalOr")
// Logical OR must use ||.
301 return (bool)(l
->val
) || (bool)(r
->val
);
// Comparisons.
302 else if (name
== "LessThan")
303 return l
->val
< r
->val
;
304 else if (name
== "LessEqual")
305 return l
->val
<= r
->val
;
306 else if (name
== "GreaterEqual")
307 return l
->val
>= r
->val
;
308 else if (name
== "GreaterThan")
// Strict comparison: equal operands must yield false.
309 return l
->val
> r
->val
;
310 else if (name
== "Equal")
311 return l
->val
== r
->val
;
312 else if (name
== "NotEqual")
313 return l
->val
!= r
->val
;
// No operator matched.
315 LOGERROR(binop
, "unknown binary operator in constant folding: %s\n", binop
->getType());
316 throw std::runtime_error("unknown binary operator in constant folding");
// Evaluates a unary operator applied to a Number operand and returns the
// result as a double. The operator is chosen by the node's getType() string:
// "Increment", "Decrement", "Negation", "LogicalNot", "BitwiseNot".
// An unrecognised operator is logged and reported by throwing
// std::runtime_error.
// NOTE(review): only the LogicalNot return is visible in this listing; the
// return statements for the other operators are missing -- confirm.
320 static double calculate_unaop_val(UnaryOperator
* unaop
)
322 Number
* n
= dynamic_cast<Number
*>(unaop
->operand
);
324 std::string name
= unaop
->getType();
325 if (name
== "Increment")
327 else if (name
== "Decrement")
329 else if (name
== "Negation")
331 else if (name
== "LogicalNot")
// Operand is converted to bool, then negated.
332 return !(bool)n
->val
;
333 else if (name
== "BitwiseNot")
// No operator matched.
336 LOGERROR(unaop
, "unknown unary operator in constant folding: %s\n", unaop
->getType());
337 throw std::runtime_error("unknown unary operator in constant folding");
// Constant folding for a single expression node.
// `node` is viewed as a BinaryOperator and as a UnaryOperator.
// Binary case:
//   * both operands are Number            -> a new Number with the computed value
//   * left is Number, right is not        -> right is folded first, then retried
//   * right is Number, left is not        -> left is folded; if the folded left
//     is itself a BinaryOperator with a Number on its right and a compatible
//     operator, the two constants are combined into one (re-association) and
//     the function is called again on the rewritten node
//   * neither is Number                   -> both sides are folded in place
// Unary case: a Number operand is computed directly; otherwise the operand
// is folded first.
// Whenever a child is replaced, its `parent` pointer is set accordingly.
// NOTE(review): several original lines (returns, else branches, and the use
// of the local `n` in cases 2a/2b) are not in this listing; nodes replaced by
// `new Number(...)` are not visibly freed here -- confirm ownership in the
// full file.
341 static ASTNode
* constant_expression_fold(ASTNode
* node
)
344 BinaryOperator
*binop
= dynamic_cast<BinaryOperator
*>(node
);
345 UnaryOperator
*unaop
= dynamic_cast<UnaryOperator
*>(node
);
// Neither a binary nor a unary operator.
346 if (!binop
&& !unaop
)
351 assert(binop
->right
);
// Left operand is a constant.
352 if (binop
->left
->getType() == std::string("Number")) {
353 if (binop
->right
->getType() == std::string("Number")) {
354 return new Number(calculate_binop_val(binop
));
// Right operand is not a constant yet: fold it and re-check.
356 ASTNode
*newnode
= constant_expression_fold(binop
->right
);
357 newnode
->parent
= node
;
358 binop
->right
= newnode
;
359 if (newnode
->getType() == std::string("Number"))
360 return new Number(calculate_binop_val(binop
));
// Right operand is a constant, left is not.
366 else if (binop
->right
->getType() == std::string("Number")) {
368 binop
->left
= constant_expression_fold(binop
->left
);
369 binop
->left
->parent
= binop
;
370 // left is not a number
371 BinaryOperator
*l
= dynamic_cast<BinaryOperator
*>(binop
->left
);
// Re-association is attempted only when the left child has a constant on its
// right, the two operators are the same or a +/- or * / pair, and the left
// child's type is in the list below.
373 if (l
&& l
->right
->getType() == std::string("Number")
374 && (l
->oper
== binop
->oper
375 || (l
->oper
== "-" && binop
->oper
== "+")
376 || (l
->oper
== "+" && binop
->oper
== "-")
377 || (l
->oper
== "*" && binop
->oper
== "/")
378 || (l
->oper
== "/" && binop
->oper
== "*"))
379 && (l
->getType() == std::string("Add")
380 || l
->getType() == std::string("Multiply")
381 || l
->getType() == std::string("Subtract")
382 || l
->getType() == std::string("Divide")
383 || l
->getType() == std::string("BitwiseOr")
384 || l
->getType() == std::string("BitwiseAnd")
385 || l
->getType() == std::string("LogicalAnd")
386 || l
->getType() == std::string("LogicalOr"))) {
// Hoist the left child's non-constant operand up into this node.
388 l
->left
->parent
= binop
;
389 binop
->left
= l
->left
;
390 // prepare a node for computation
391 l
->left
= binop
->right
;
392 if (l
->getType() == std::string("Subtract")
393 || l
->getType() == std::string("Divide")) {
394 // case 1. invert left child and compute
395 BinaryOperator
* l2
= l
->invert();
396 binop
->right
= new Number(calculate_binop_val(l2
));
398 } else if (binop
->getType() == std::string("Subtract")) {
399 // case 2a. invert self and compute
400 Number
* n
= dynamic_cast<Number
*>(binop
->right
);
403 binop
->right
= new Number(calculate_binop_val(l
));
404 } else if (binop
->getType() == std::string("Divide")) {
405 // case 2b. invert self and compute
406 Number
* n
= dynamic_cast<Number
*>(binop
->right
);
409 binop
->right
= new Number(calculate_binop_val(l
));
412 // case 3. no need to invert anything, same operators
413 binop
->right
= new Number(calculate_binop_val(l
));
418 binop
->right
->parent
= binop
;
419 // scary, but works -- tree is shortening with every call
420 return constant_expression_fold(binop
);
423 // nothing scary to do
424 if (binop
->left
->getType() == std::string("Number"))
425 return new Number(calculate_binop_val(binop
));
// Fold both operands in place and fix up their parent pointers.
431 binop
->left
= constant_expression_fold(binop
->left
);
432 binop
->right
= constant_expression_fold(binop
->right
);
433 binop
->left
->parent
= binop
;
434 binop
->right
->parent
= binop
;
// Unary operator handling.
439 assert(unaop
->operand
);
440 if (unaop
->operand
->getType() == std::string("Number")) {
441 return new Number(calculate_unaop_val(unaop
));
// Operand is not a constant yet: fold it and re-check.
443 ASTNode
* newnode
= constant_expression_fold(unaop
->operand
);
444 newnode
->parent
= node
;
445 unaop
->operand
= newnode
;
446 if (newnode
->getType() == std::string("Number"))
447 return new Number(calculate_unaop_val(unaop
));
454 unaop
->operand
= constant_expression_fold(unaop
->operand
);
455 unaop
->operand
->parent
= unaop
;
// Applies constant folding across the tree below `root`.
// Each child that is a BinaryOperator or UnaryOperator is replaced, in the
// local child list `v`, by the result of constant_expression_fold() and
// re-parented to `root`; the function then recurses into the child, and
// finally writes the updated list back with root->setChildren(v).
// NOTE(review): the null tests on the dynamic_cast results and the braces
// of the loop are not in this listing -- confirm.
460 void constant_expression_optimization(ASTNode
* root
)
462 NodeStore v
= root
->getChildren();
464 for (NodeStore::iterator it
= v
.begin(); it
!= v
.end(); ++it
)
// Binary operator child.
466 BinaryOperator
* binop
= dynamic_cast<BinaryOperator
*>(*it
);
468 *it
= constant_expression_fold(binop
);
469 (*it
)->parent
= root
;
// Unary operator child.
471 UnaryOperator
* unaop
= dynamic_cast<UnaryOperator
*>(*it
);
473 *it
= constant_expression_fold(unaop
);
474 (*it
)->parent
= root
;
// Descend into the (possibly replaced) child.
477 constant_expression_optimization(*it
);
// Store the modified child list back on the node.
479 root
->setChildren(v
);
// Processes one node of a function body and accumulates an error count in
// `errors`. Dispatch is on node->getType():
//   "Block" / "List"  -> recurse into every list entry
//   "Declaration"     -> insert the declared names into fi.locals
//   "Ident"           -> report an unknown identifier if find_symbol() fails
//   "CallExpr"        -> first child is examined: an Ident that is neither a
//                        known symbol nor a known function is recorded in
//                        pi.undefinedFuncs; a non-Ident callee gets a warning;
//                        further children are processed recursively
//   anything else     -> recurse into all children
// NOTE(review): the declaration of `errors`, several conditions/else
// branches, iterator advances, and the return are not in this listing --
// confirm against the full file.
483 static int compile_function_step(parse_info
& pi
, function_info
& fi
, ASTNode
*node
)
486 if (node
->getType() == std::string("Block")) {
487 Block
* b
= dynamic_cast<Block
*>(node
);
488 for(NodeStore::iterator it
= b
->list
.begin(); it
!= b
->list
.end(); ++it
)
490 errors
+= compile_function_step(pi
, fi
, *it
);
491 } else if (node
->getType() == std::string("List")) {
492 List
* l
= dynamic_cast<List
*>(node
);
493 for(NodeStore::iterator it
= l
->list
.begin(); it
!= l
->list
.end(); ++it
)
495 errors
+= compile_function_step(pi
, fi
, *it
);
// Local declarations go into the function's own symbol table.
496 } else if (node
->getType() == std::string("Declaration")) {
497 Declaration
*decl
= dynamic_cast<Declaration
*>(node
);
498 errors
+= insert_many_into_symtable(pi
, fi
.locals
, decl
->list
);
499 } else if (node
->getType() == std::string("Ident")) {
500 Ident
* id
= dynamic_cast<Ident
*>(node
);
501 if (!find_symbol(pi
, id
->name
)) {
502 LOGERROR(id
, "unknown identifier \"%s\"\n", id
->name
.c_str());
505 } else if (node
->getType() == std::string("CallExpr")) {
506 NodeStore v
= node
->getChildren();
507 NodeStore::iterator it
= v
.begin();
// The first child is treated as the callee.
509 if ((*it
)->getType() == std::string("Ident")) {
510 Ident
* id
= dynamic_cast<Ident
*>(*it
);
512 ASTNode
* sym
= find_symbol(pi
, id
->name
);
// Unknown name: remember it as a function that is called but not defined.
513 if (!sym
&& pi
.functions
.find(id
->name
) == pi
.functions
.end()) {
514 pi
.undefinedFuncs
.insert(NodeTable::value_type(id
->name
, id
));
517 LOGERROR(id
, "symbol \"%s\" is not a function\n", id
->name
.c_str());
518 LOGLOC(sym
, "was defined here\n");
522 LOGWARNING(*it
, "warning: not calling by name\n");
526 errors
+= compile_function_step(pi
, fi
, *it
);
// Default: walk all children.
528 NodeStore v
= node
->getChildren();
529 for(NodeStore::iterator it
= v
.begin(); it
!= v
.end(); ++it
)
531 errors
+= compile_function_step(pi
, fi
, *it
);
// Registers and checks one Function node.
// Fills a function_info `fi` (index = current number of known functions,
// name from the prototype), points pi.locals at fi.locals, reads the
// parameters, reports a duplicate declaration if the name already resolves
// via find_symbol(), inserts `fi` into pi.functions, removes the name from
// pi.undefinedFuncs, and then processes the function body (f->instr).
// NOTE(review): the declarations of `fi` / `errors`, the condition around
// the duplicate report, and the return are not in this listing -- confirm.
// NOTE(review): the final LOG prints fi.params.size() with "%u"; if size()
// yields size_t this specifier does not match on 64-bit targets -- verify.
536 static int compile_function(parse_info
& pi
, ASTNode
* node
)
539 Function
* f
= dynamic_cast<Function
*>(node
);
541 FunctionProto
*proto
= dynamic_cast<FunctionProto
*>(f
->proto
);
546 fi
.index
= pi
.functions
.size();
547 fi
.name
= proto
->name
;
// Make this function's locals visible to find_symbol() through pi.
548 pi
.locals
= &fi
.locals
;
549 errors
+= fi
.readParams(pi
, proto
->params
);
550 ASTNode
*tmp
= find_symbol(pi
, fi
.name
);
552 LOGERROR(f
, "duplicate declaration of \"%s\"\n", fi
.name
.c_str());
553 LOGLOC(tmp
, "previously defined here\n");
// Record the function and clear any pending "undefined" entry for it.
556 pi
.functions
.insert(FunctionTable::value_type(fi
.name
, fi
));
557 pi
.undefinedFuncs
.erase(fi
.name
);
// Process the function body.
559 errors
+= compile_function_step(pi
, fi
, f
->instr
);
562 LOG("%s: read %u params\n", fi
.name
.c_str(), fi
.params
.size());
// Reports every entry left in pi.undefinedFuncs as "function called, but
// undefined" and returns the number of such entries.
// When the stored node is an Ident the message includes its name.
// NOTE(review): the condition choosing between the two LOGERROR calls is
// not in this listing (presumably a null test on `id`) -- confirm.
566 int check_parse_results(parse_info
& pi
)
568 for (NodeTable::iterator it
= pi
.undefinedFuncs
.begin();
569 it
!= pi
.undefinedFuncs
.end(); ++it
) {
571 Ident
* id
= dynamic_cast<Ident
*>(it
->second
);
573 LOGERROR(id
, "function called, but undefined: \"%s\"\n", id
->name
.c_str());
575 LOGERROR(it
->second
, "function called, but undefined\n");
// The count of unresolved calls is the result.
579 return (int)pi
.undefinedFuncs
.size();
// Walks the children of `root`, validating each by its getType() string and
// adding failures to pi.errors:
//   "Break" / "Continue" -> validate_loop_exit()
//   "Declaration"        -> validate_declarations(), add_declarations_to_symtable()
//   "Jump"               -> validate_jump()
//   "Label"              -> add_label()
// It also recurses via validate_tree(pi, node) and, for "Function" nodes,
// adds the result of compile_function().
// NOTE(review): the binding of `node`, the use of `validDecl`, the placement
// of the recursive call, and the return are not in this listing; the
// function may continue past the end of the visible text -- confirm.
582 int validate_tree(parse_info
& pi
, ASTNode
*root
)
584 NodeStore v
= root
->getChildren();
585 for (NodeStore::iterator it
= v
.begin(); it
!= v
.end(); ++it
) {
589 std::string type
= node
->getType();
// Each validator returns true on success, so its negation counts as one error.
590 if (type
== "Break" || type
== "Continue")
591 pi
.errors
+= !validate_loop_exit(node
);
592 else if (type
== "Declaration") {
593 bool validDecl
= validate_declarations(node
);
595 pi
.errors
+= !add_declarations_to_symtable(pi
, node
);
598 } else if (type
== "Jump") {
599 pi
.errors
+= !validate_jump(node
);
600 } else if (type
== "Label") {
601 pi
.errors
+= !add_label(pi
, node
);
// Recurse into the node's subtree.
603 validate_tree(pi
, node
);
// compile_function() returns an error count, added directly.
604 if (type
== "Function") {
605 pi
.errors
+= compile_function(pi
, node
);