fixed constant folding so it works correctly both on lhs and rhs
[bosc.git] / src / compiler.cpp
blob1be6bbf07fbb74aa53057ba6a1ef9991bc7c19c1
1 #include <string>
2 #include <stack>
3 #include <cassert>
4 #include <stdexcept>
6 #include "ast.h"
7 #include "compiler.h"
9 void simplify_ast(ASTNode *root)
11 NodeStore v = root->getChildren();
12 for (NodeStore::iterator it = v.begin(); it != v.end(); ++it) {
13 ASTNode *node = *it;
14 if (!node) continue;
15 if (node->getType() == std::string("Semicolon")) {
16 Semicolon *semi = dynamic_cast<Semicolon*>(node);
17 assert(semi);
18 semi->parent = NULL;
19 Block *block = new Block();
20 block->parent = root;
21 block->loc = semi->loc;
22 semi->traverseTree(block->list, semi);
23 for (NodeStore::iterator it2 = block->list.begin(); it2 != block->list.end(); ++it2)
24 (*it2)->parent = block;
25 *it = block;
26 delete semi;
27 } else if (node->getType() == std::string("Comma")) {
28 Comma* comma = dynamic_cast<Comma*>(node);
29 assert(comma);
30 comma->parent = NULL;
31 List *block = new List();
32 block->parent = root;
33 block->loc = comma->loc;
34 comma->traverseTree(block->list, comma);
35 for (NodeStore::iterator it2 = block->list.begin(); it2 != block->list.end(); ++it2)
36 (*it2)->parent = block;
37 *it = block;
38 delete comma;
40 simplify_ast(*it);
42 root->setChildren(v);
46 bool tree_check_parents(ASTNode* root)
48 NodeStore v = root->getChildren();
49 for (NodeStore::iterator it = v.begin(); it != v.end(); ++it) {
50 if (!(*it))
51 continue;
52 if ((*it)->parent != root) {
53 LOG("error:%s: failed check: %s->%s\n", root->locStr().c_str(),
54 root->getType(), (*it)->getType());
55 return false;
57 if (!tree_check_parents(*it))
58 return false;
60 return true;
63 static inline ASTNode* find_parent(ASTNode* p, std::string type)
65 while (p && p->getType() != type)
66 p = p->parent;
67 return p;
70 static inline NodeStore find_all_children(ASTNode* p, std::string type)
72 NodeStore r, v = p->getChildren();
73 for (NodeStore::iterator it = v.begin(); it != v.end(); ++it) {
74 if (!*it)
75 continue;
76 ASTNode *node = *it;
77 if (node->getType() == type)
78 r.push_back(node);
79 NodeStore tmp = find_all_children(node, type);
80 r.insert(r.end(), tmp.begin(), tmp.end());
82 return r;
85 static ASTNode* find_symbol(parse_info &pi, std::string name)
88 SymbolTable::iterator it = pi.staticvars.find(name);
89 if (it != pi.staticvars.end())
90 return it->second.first;
92 it = pi.pieces.find(name);
93 if (it != pi.pieces.end())
94 return it->second.first;
96 if (pi.locals) {
97 it = pi.locals->find(name);
98 if (it != pi.locals->end())
99 return it->second.first;
104 FunctionTable::iterator it = pi.functions.find(name);
105 if (it != pi.functions.end())
106 return it->second.ast;
108 return 0;
111 static bool validate_loop_exit(ASTNode *leaf)
113 ASTNode *p = find_parent(leaf, "Loop");
114 if (p)
115 LOG("stopped at %s\n", p->getType());
116 else
117 LOGERROR(leaf, "%s not in loop\n", leaf->getType());
118 return p != NULL;
121 static bool validate_declarations(ASTNode *node)
123 Declaration *n = dynamic_cast<Declaration*>(node);
124 assert(n);
125 Typename *t = dynamic_cast<Typename*>(n->type);
126 assert(t);
127 if (t->name == "var") {
128 if (!find_parent(node, "Function")) {
129 LOGERROR(t, "local var declaration when not in function\n");
130 return false;
132 } else if (t->name == "static-var" || t->name == "piece") {
133 if (find_parent(node, "Function")) {
134 LOGERROR(t, "global declaration when in function\n");
135 return false;
137 } else {
138 LOGERROR(t, "unknown type: %s\n", t->name.c_str());
139 return false;
141 return true;
145 bool validate_jump(ASTNode *node)
147 int labelnum = 0;
148 Jump *j = dynamic_cast<Jump*>(node);
149 assert(j);
150 Function *fun = dynamic_cast<Function*>(find_parent(node, "Function"));
151 assert(fun);
152 NodeStore labels = find_all_children(fun, "Label");
153 Label *prev = NULL;
154 for (NodeStore::iterator it = labels.begin(); it != labels.end(); ++it) {
155 Label *l = dynamic_cast<Label*>(*it);
156 assert(l);
157 if (l->name == j->dest) {
158 ++labelnum;
159 if (labelnum >= 2) {
160 LOGERROR(j, "%s - duplicate destination: \"%s\"\n", j->getType(), j->dest.c_str());
163 prev = l;
165 if (labelnum == 0)
166 LOGERROR(j, "label \"%s\" not found in function %s\n", j->dest.c_str(), fun->getFunctionName().c_str());
167 return labelnum == 1;
170 static bool insert_into_symtable(parse_info& pi, SymbolTable *tbl, Ident *id)
172 assert(tbl);
173 assert(id);
174 ASTNode* tmp = find_symbol(pi, id->name);
175 if (tmp) {
176 LOGERROR(id, "duplicate declaration of \"%s\"\n",
177 id->name.c_str());
178 LOGLOC(tmp, "previously defined here\n");
179 return false;
180 } else {
181 tbl->insert(SymbolTable::value_type(id->name, Symbol(id, tbl->size())));
182 return true;
187 static int insert_many_into_symtable(parse_info& pi, SymbolTable& sym, ASTNode *paramNode)
189 if (!paramNode)
190 return 0;
192 Ident *id = dynamic_cast<Ident*>(paramNode);
193 if (id) {
194 return !insert_into_symtable(pi, &sym, id);
197 Assign *as = dynamic_cast<Assign*>(paramNode);
198 if (as) {
199 id = dynamic_cast<Ident*>(as->left);
200 assert(id);
201 return !insert_into_symtable(pi, &sym, id);
204 int errors = 0;
205 List *li = dynamic_cast<List*>(paramNode);
206 if (li) {
207 for (NodeStore::iterator it = li->list.begin(); it != li->list.end(); ++it) {
208 as = dynamic_cast<Assign*>(*it);
209 if (as)
210 id = dynamic_cast<Ident*>(as->left);
211 else
212 id = dynamic_cast<Ident*>(*it);
213 assert(id);
214 errors += !insert_into_symtable(pi, &sym, id);
216 return errors;
219 EPRINTF("bad arguments: type %s\n", paramNode->getType());
220 assert(0 && "bad arguments");
221 return 666;
224 static bool add_declarations_to_symtable(parse_info& pi, ASTNode* node)
226 Declaration *decl = dynamic_cast<Declaration*>(node);
227 assert(decl);
228 Typename *type = dynamic_cast<Typename*>(decl->type);
229 assert(type);
230 SymbolTable* tbl;
231 if (type->name == "piece")
232 tbl = &pi.pieces;
233 else if (type->name == "static-var")
234 tbl = &pi.staticvars;
235 else if (type->name == "var")
236 return true;
237 else {
238 LOGERROR(type, "unknown type %s\n", type->name.c_str());
239 return false;
242 return !insert_many_into_symtable(pi, *tbl, decl->list);
245 static bool add_label(parse_info &pi, ASTNode* node)
247 Label* label = dynamic_cast<Label*>(node);
248 assert(label);
249 LabelTable::iterator it = pi.labels.find(label->name);
250 if (it == pi.labels.end()) {
251 pi.labels.insert(LabelTable::value_type(label->name, label));
252 return true;
254 else {
255 assert(it != pi.labels.end());
256 assert(it->second);
257 LOGERROR(label, "duplicate label \"%s\"\n", label->name.c_str());
258 LOGLOC(it->second, "first declared here\n");
259 return false;
265 int function_info::readParams(parse_info& pi, ASTNode *paramNode)
267 int errors = insert_many_into_symtable(pi, locals, paramNode);
268 params.resize(locals.size());
269 for (SymbolTable::iterator it = locals.begin(); it != locals.end(); ++it)
270 params[it->second.second] = it->first;
271 return errors;
274 static double calculate_binop_val(BinaryOperator* binop)
276 Number* l = dynamic_cast<Number*>(binop->left);
277 Number* r = dynamic_cast<Number*>(binop->right);
278 assert(l && r);
279 std::string name = binop->getType();
280 if (name == "Add")
281 return l->val + r->val;
282 else if (name == "Subtract")
283 return l->val - r->val;
284 else if (name == "Multiply")
285 return l->val * r->val;
286 else if (name == "Divide")
287 return l->val / r->val;
288 else if (name == "ShiftLeft")
289 return (int)(l->val) << (int)(r->val);
290 else if (name == "ShiftRight")
291 return (int)(l->val) >> (int)(r->val);
292 else if (name == "BitwiseAnd")
293 return (int)(l->val) & (int)(r->val);
294 else if (name == "BitwiseOr")
295 return (int)(l->val) | (int)(r->val);
296 else if (name == "BitwiseXor")
297 return (int)(l->val) ^ (int)(r->val);
298 else if (name == "LogicalAnd")
299 return (bool)(l->val) && (bool)(r->val);
300 else if (name == "LogicalOr")
301 return (bool)(l->val) && (bool)(r->val);
302 else if (name == "LessThan")
303 return l->val < r->val;
304 else if (name == "LessEqual")
305 return l->val <= r->val;
306 else if (name == "GreaterEqual")
307 return l->val >= r->val;
308 else if (name == "GreaterThan")
309 return l->val >= r->val;
310 else if (name == "Equal")
311 return l->val == r->val;
312 else if (name == "NotEqual")
313 return l->val != r->val;
314 else {
315 LOGERROR(binop, "unknown binary operator in constant folding: %s\n", binop->getType());
316 throw std::runtime_error("unknown binary operator in constant folding");
320 static double calculate_unaop_val(UnaryOperator* unaop)
322 Number* n = dynamic_cast<Number*>(unaop->operand);
323 assert(n);
324 std::string name = unaop->getType();
325 if (name == "Increment")
326 return n->val+1;
327 else if (name == "Decrement")
328 return n->val-1;
329 else if (name == "Negation")
330 return -n->val;
331 else if (name == "LogicalNot")
332 return !(bool)n->val;
333 else if (name == "BitwiseNot")
334 return ~(int)n->val;
335 else {
336 LOGERROR(unaop, "unknown unary operator in constant folding: %s\n", unaop->getType());
337 throw std::runtime_error("unknown unary operator in constant folding");
341 static ASTNode* constant_expression_fold(ASTNode* node)
343 assert(node);
344 BinaryOperator *binop = dynamic_cast<BinaryOperator*>(node);
345 UnaryOperator *unaop = dynamic_cast<UnaryOperator*>(node);
346 if (!binop && !unaop)
347 return node;
349 if (binop) {
350 assert(binop->left);
351 assert(binop->right);
352 if (binop->left->getType() == std::string("Number")) {
353 if (binop->right->getType() == std::string("Number")) {
354 return new Number(calculate_binop_val(binop));
355 } else {
356 ASTNode *newnode = constant_expression_fold(binop->right);
357 newnode->parent = node;
358 binop->right = newnode;
359 if (newnode->getType() == std::string("Number"))
360 return new Number(calculate_binop_val(binop));
361 else {
362 return binop;
366 else if (binop->right->getType() == std::string("Number")) {
367 // optimize first
368 binop->left = constant_expression_fold(binop->left);
369 binop->left->parent = binop;
370 // left is not a number
371 BinaryOperator *l = dynamic_cast<BinaryOperator*>(binop->left);
372 // FIXME this sucks
373 if (l && l->right->getType() == std::string("Number")
374 && (l->oper == binop->oper
375 || (l->oper == "-" && binop->oper == "+")
376 || (l->oper == "+" && binop->oper == "-")
377 || (l->oper == "*" && binop->oper == "/")
378 || (l->oper == "/" && binop->oper == "*"))
379 && (l->getType() == std::string("Add")
380 || l->getType() == std::string("Multiply")
381 || l->getType() == std::string("Subtract")
382 || l->getType() == std::string("Divide")
383 || l->getType() == std::string("BitwiseOr")
384 || l->getType() == std::string("BitwiseAnd")
385 || l->getType() == std::string("LogicalAnd")
386 || l->getType() == std::string("LogicalOr"))) {
387 // cut out a node
388 l->left->parent = binop;
389 binop->left = l->left;
390 // prepare a node for computation
391 l->left = binop->right;
392 if (l->getType() == std::string("Subtract")
393 || l->getType() == std::string("Divide")) {
394 // case 1. invert left child and compute
395 BinaryOperator* l2 = l->invert();
396 binop->right = new Number(calculate_binop_val(l2));
397 delete l2;
398 } else if (binop->getType() == std::string("Subtract")) {
399 // case 2a. invert self and compute
400 Number* n = dynamic_cast<Number*>(binop->right);
401 assert(n);
402 n->val = -n->val;
403 binop->right = new Number(calculate_binop_val(l));
404 } else if (binop->getType() == std::string("Divide")) {
405 // case 2b. invert self and compute
406 Number* n = dynamic_cast<Number*>(binop->right);
407 assert(n);
408 n->val = 1./n->val;
409 binop->right = new Number(calculate_binop_val(l));
411 else {
412 // case 3. no need to invert anything, same operators
413 binop->right = new Number(calculate_binop_val(l));
415 delete l->right;
416 delete l->left;
417 delete l;
418 binop->right->parent = binop;
419 // scary, but works -- tree is shortening with every call
420 return constant_expression_fold(binop);
422 else {
423 // nothing scary to do
424 if (binop->left->getType() == std::string("Number"))
425 return new Number(calculate_binop_val(binop));
426 else {
427 return binop;
430 } else {
431 binop->left = constant_expression_fold(binop->left);
432 binop->right = constant_expression_fold(binop->right);
433 binop->left->parent = binop;
434 binop->right->parent = binop;
435 return binop;
438 else if (unaop) {
439 assert(unaop->operand);
440 if (unaop->operand->getType() == std::string("Number")) {
441 return new Number(calculate_unaop_val(unaop));
442 } else {
443 ASTNode* newnode = constant_expression_fold(unaop->operand);
444 newnode->parent = node;
445 unaop->operand = newnode;
446 if (newnode->getType() == std::string("Number"))
447 return new Number(calculate_unaop_val(unaop));
448 else {
449 return unaop;
453 else {
454 unaop->operand = constant_expression_fold(unaop->operand);
455 unaop->operand->parent = unaop;
456 return unaop;
460 void constant_expression_optimization(ASTNode* root)
462 NodeStore v = root->getChildren();
464 for (NodeStore::iterator it = v.begin(); it != v.end(); ++it)
465 if (*it) {
466 BinaryOperator* binop = dynamic_cast<BinaryOperator*>(*it);
467 if (binop) {
468 *it = constant_expression_fold(binop);
469 (*it)->parent = root;
470 } else {
471 UnaryOperator* unaop = dynamic_cast<UnaryOperator*>(*it);
472 if (unaop) {
473 *it = constant_expression_fold(unaop);
474 (*it)->parent = root;
477 constant_expression_optimization(*it);
479 root->setChildren(v);
483 static int compile_function_step(parse_info& pi, function_info& fi, ASTNode *node)
485 int errors = 0;
486 if (node->getType() == std::string("Block")) {
487 Block* b = dynamic_cast<Block*>(node);
488 for(NodeStore::iterator it = b->list.begin(); it != b->list.end(); ++it)
489 if (*it)
490 errors += compile_function_step(pi, fi, *it);
491 } else if (node->getType() == std::string("List")) {
492 List* l = dynamic_cast<List*>(node);
493 for(NodeStore::iterator it = l->list.begin(); it != l->list.end(); ++it)
494 if (*it)
495 errors += compile_function_step(pi, fi, *it);
496 } else if (node->getType() == std::string("Declaration")) {
497 Declaration *decl = dynamic_cast<Declaration*>(node);
498 errors += insert_many_into_symtable(pi, fi.locals, decl->list);
499 } else if (node->getType() == std::string("Ident")) {
500 Ident* id = dynamic_cast<Ident*>(node);
501 if (!find_symbol(pi, id->name)) {
502 LOGERROR(id, "unknown identifier \"%s\"\n", id->name.c_str());
503 return 1;
505 } else {
506 NodeStore v = node->getChildren();
507 for(NodeStore::iterator it = v.begin(); it != v.end(); ++it)
508 if (*it)
509 errors += compile_function_step(pi, fi, *it);
511 return errors;
514 static int compile_function(parse_info& pi, ASTNode* node)
516 int errors = 0;
517 Function* f = dynamic_cast<Function*>(node);
518 assert(f);
519 FunctionProto *proto = dynamic_cast<FunctionProto*>(f->proto);
520 assert(proto);
522 function_info fi;
523 fi.ast = f;
524 fi.index = pi.functions.size();
525 fi.name = proto->name;
526 pi.locals = &fi.locals;
527 errors += fi.readParams(pi, proto->params);
528 ASTNode *tmp = find_symbol(pi, fi.name);
529 if (tmp) {
530 LOGERROR(f, "duplicate declaration of \"%s\"\n", fi.name.c_str());
531 LOGLOC(tmp, "previously defined here\n");
532 return errors+1;
534 pi.functions.insert(FunctionTable::value_type(fi.name, fi));
536 errors += compile_function_step(pi, fi, f->instr);
538 pi.locals = 0;
539 LOG("%s: read %u params\n", fi.name.c_str(), fi.params.size());
540 return errors;
543 int validate_tree(parse_info& pi, ASTNode *root)
545 NodeStore v = root->getChildren();
546 for (NodeStore::iterator it = v.begin(); it != v.end(); ++it) {
547 if (!(*it))
548 continue;
549 ASTNode *node = *it;
550 std::string type = node->getType();
551 if (type == "Break" || type == "Continue")
552 pi.errors += !validate_loop_exit(node);
553 else if (type == "Declaration") {
554 bool validDecl = validate_declarations(node);
555 if (validDecl)
556 pi.errors += !add_declarations_to_symtable(pi, node);
557 else
558 ++pi.errors;
559 } else if (type == "Jump") {
560 pi.errors += !validate_jump(node);
561 } else if (type == "Label") {
562 pi.errors += !add_label(pi, node);
564 validate_tree(pi, node);
565 if (type == "Function") {
566 pi.errors += compile_function(pi, node);
569 return pi.errors;