[visitors] ported llvm asm generator
[ozulis.git] / src / ozulis / visitors / type-checker.cc
blob149b47452560d7397272207ef142b153db5bb81f
1 #include <boost/foreach.hpp>
3 #include <ozulis/core/assert.hh>
4 #include <ozulis/ast/ast-cast.hh>
5 #include <ozulis/ast/node-factory.hh>
6 #include <ozulis/ast/cast-tables.hh>
7 #include <ozulis/ast/scope.hh>
8 #include <ozulis/visitors/type-checker.hh>
9 #include <ozulis/visitors/browser.hh>
11 namespace ozulis
13 namespace visitors
15 TypeChecker::TypeChecker()
16 : Visitor<TypeChecker>(),
17 currentFunction(0),
18 scope(0),
19 replacement(0)
23 TypeChecker::~TypeChecker()
27 template <typename T>
28 void
29 TypeChecker::check(T *& node)
31 replacement = 0;
32 TypeChecker::visit(*node, *this);
33 if (replacement)
35 node = ast::ast_cast<T *> (replacement);
36 replacement = 0;
40 static void visitFile(ast::Node & node_, TypeChecker & ctx)
42 ast::File & node = reinterpret_cast<ast::File &>(node_);
43 ctx.scope = node.scope;
44 Browser<TypeChecker>::visit(node, ctx);
47 static void visitFunction(ast::Node & node_, TypeChecker & ctx)
49 ast::Function & node = reinterpret_cast<ast::Function &>(node_);
50 /// @todo check that the last statement is a branch or a return
51 ctx.currentFunction = &node;
52 Browser<TypeChecker>::visit(node, ctx);
55 static void visitBlock(ast::Node & node_, TypeChecker & ctx)
57 ast::Block & node = reinterpret_cast<ast::Block &>(node_);
58 ctx.scope = node.scope;
59 // We must respect this order
60 BOOST_FOREACH (ast::VarDecl *& varDecl, (*node.varDecls))
61 ctx.check(varDecl);
62 BOOST_FOREACH (ast::Node *& statement, (*node.statements))
63 ctx.check(statement);
66 static void visitReturn(ast::Node & node_, TypeChecker & ctx)
68 ast::Return & node = reinterpret_cast<ast::Return &>(node_);
69 ctx.check(node.exp);
70 if (!isSameType(node.exp->type, ctx.currentFunction->returnType))
72 ast::CastExp * castExp = new ast::CastExp;
73 castExp->exp = node.exp;
74 castExp->type = ctx.currentFunction->returnType;
75 node.exp = castExp;
79 static void visitAssignExp(ast::Node & node_, TypeChecker & ctx)
81 ast::AssignExp & node = reinterpret_cast<ast::AssignExp &>(node_);
82 ctx.check(node.dest);
83 ctx.check(node.value);
85 ast::CastExp * castExp = new ast::CastExp();
86 castExp->type = unreferencedType(node.dest->type);
87 castExp->exp = node.value;
88 node.value = castExp;
89 node.type = castExp->type;
92 static void visitIdExp(ast::Node & node_, TypeChecker & ctx)
94 ast::IdExp & node = reinterpret_cast<ast::IdExp &>(node_);
95 if (node.symbol && node.symbol->address &&
96 node.symbol->address->nodeType == ast::RegisterAddress::nodeTypeId())
97 return;
99 assert(ctx.scope);
100 ast::Symbol * s = ctx.scope->findSymbol(node.symbol->name);
101 assert(s);
102 assert(s->type);
103 assert(s->address);
104 node.symbol = s;
105 node.type = s->type;
108 static void visitAtExp(ast::Node & node_, TypeChecker & ctx)
110 ast::AtExp & node = reinterpret_cast<ast::AtExp &>(node_);
111 ctx.check(node.exp);
112 ast::PointerType * type = new ast::PointerType;
113 type->type = unreferencedType(node.exp->type);
114 node.type = type;
115 /// @todo check that i can get the address of exp
116 assert(node.exp->type->nodeType == ast::ReferenceType::nodeTypeId() ||
117 node.exp->nodeType == ast::DereferenceExp::nodeTypeId() ||
118 node.exp->nodeType == ast::DereferenceByIndexExp::nodeTypeId());
121 static void visitDereferenceExp(ast::Node & node_, TypeChecker & ctx)
123 ast::DereferenceExp & node = reinterpret_cast<ast::DereferenceExp &>(node_);
124 ctx.check(node.exp);
125 ast::Type * unrefType = unreferencedType(node.exp->type);
126 assert_msg(unrefType->nodeType == ast::PointerType::nodeTypeId(),
127 "Error: you can't dereference a non pointer type");
128 ast::PointerType * type = ast::ast_cast<ast::PointerType *> (unrefType);
129 node.type = type->type;
132 static void visitDereferenceByIndexExp(ast::Node & node_, TypeChecker & ctx)
134 ast::DereferenceByIndexExp & node =
135 reinterpret_cast<ast::DereferenceByIndexExp &>(node_);
136 ctx.check(node.exp);
137 ctx.check(node.index);
138 ast::Type * unrefType = unreferencedType(node.exp->type);
139 assert_msg(unrefType->nodeType == ast::PointerType::nodeTypeId(),
140 "Error: you can't dereference a non pointer type");
141 ast::PointerType * type = ast::ast_cast<ast::PointerType *> (unrefType);
142 node.type = type->type;
144 // @todo select the itype depending on the platforme pointer size
145 ast::IntegerType * itype = new ast::IntegerType;
146 itype->isSigned = true;
147 itype->size = 32;
148 node.index = castToType(itype, node.index);
151 static void visitSymbol(ast::Node & node_, TypeChecker & ctx)
153 ast::Symbol & node = reinterpret_cast<ast::Symbol &>(node_);
154 assert(ctx.scope);
155 const ast::Symbol * s = ctx.scope->findSymbol(node.name);
156 assert(s);
157 assert(s->type);
158 assert(s->address);
159 node.type = s->type;
160 node.address = s->address;
163 static void visitCastExp(ast::Node & node_, TypeChecker & ctx)
165 ast::CastExp & node = reinterpret_cast<ast::CastExp &>(node_);
166 assert(node.exp);
167 ctx.check(node.exp);
170 static void visitCallExp(ast::Node & node_, TypeChecker & ctx)
172 ast::CallExp & node = reinterpret_cast<ast::CallExp &>(node_);
173 Browser<TypeChecker>::visit(node, ctx);
174 /// @todo generate the function's signature depending on parameters type
175 const ast::Symbol * s = ctx.scope->findSymbol(node.id);
176 assert_msg(s, "function not found in symbol table");
177 assert(s->type);
179 node.ftype = ast::ast_cast<ast::FunctionType *> (s->type);
180 node.type = node.ftype->returnType;
182 assert(node.ftype->argsType->size() >= node.args->size());
183 for (unsigned i = 0; i < node.args->size(); i++)
184 (*node.args)[i] = castToType((*node.ftype->argsType)[i], (*node.args)[i]);
187 void
188 TypeChecker::homogenizeTypes(ast::BinaryExp & node)
190 ast::CastExp * castExp = ast::castToBestType(node.left->type,
191 node.right->type);
192 if (!castExp)
194 node.type = node.left->type;
195 return;
198 if (castExp->type == node.left->type)
200 castExp->exp = node.right;
201 node.right = castExp;
203 else
205 castExp->exp = node.left;
206 node.left = castExp;
208 node.type = castExp->type;
212 * @internal
213 * - compute the new address:
214 * - addrui = ptrtoui addr
215 * - offset = index * sizeof(*addr)
216 * - newaddrui = ptrtoui + offset
217 * - newaddr = cast(*, newaddrui)
219 void
220 TypeChecker::pointerArith(ast::Exp * pointer, ast::Exp * offset)
222 ast::PointerType * pointerType =
223 ast::ast_cast<ast::PointerType *> (unreferencedType(pointer->type));
225 /* cast node.exp to uint */
226 ast::CastExp * cast1 = ast::NodeFactory::createCastPtrToUInt(pointer);
228 /* compute the offset */
229 ast::MulExp * mul = new ast::MulExp;
230 mul->left = offset;
231 mul->right = ast::NodeFactory::createSizeofValue(pointerType->type);
233 /* add the offset to the uint addr */
234 ast::AddExp * add = new ast::AddExp;
235 add->left = cast1;
236 add->right = mul;
238 /* cast uint addr to ptr */
239 ast::CastExp * cast2 = ast::NodeFactory::createCastUIntToPtr(add, unreferencedType(pointerType));
240 check(cast2);
242 replacement = cast2;
245 #define VISIT_ADD_EXP(Type) \
246 static void visit##Type(ast::Node & node_, TypeChecker & ctx) \
248 ast::Type & node = reinterpret_cast<ast::Type &> (node_); \
249 ctx.check(node.left); \
250 ctx.check(node.right); \
251 if (unreferencedType(node.left->type)->nodeType == \
252 ast::PointerType::nodeTypeId()) \
253 ctx.pointerArith(node.left, node.right); \
254 else if (unreferencedType(node.right->type)->nodeType == \
255 ast::PointerType::nodeTypeId()) \
256 ctx.pointerArith(node.right, node.left); \
257 else \
258 TypeChecker::homogenizeTypes(node); \
261 #define VISIT_BINARY_ARITH(Type) \
262 static void visit##Type(ast::Node & node_, TypeChecker & ctx) \
264 ast::Type & node = reinterpret_cast<ast::Type &> (node_); \
265 ctx.check(node.left); \
266 ctx.check(node.right); \
267 TypeChecker::homogenizeTypes(node); \
270 VISIT_ADD_EXP(AddExp)
271 VISIT_ADD_EXP(SubExp)
272 VISIT_BINARY_ARITH(MulExp)
273 VISIT_BINARY_ARITH(DivExp)
274 VISIT_BINARY_ARITH(ModExp)
276 #define VISIT_BITWISE_EXP(Type) \
277 static void visit##Type(ast::Node & node_, TypeChecker & ctx) \
279 ast::Type & node = reinterpret_cast<ast::Type &> (node_); \
280 ctx.check(node.left); \
281 ctx.check(node.right); \
283 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.left->type) && \
284 !AST_IS_FLOATING_POINT_TYPE(node.right->type), \
285 "throw and error here, can't do bitwise operation on float."); \
286 TypeChecker::homogenizeTypes(node); \
289 VISIT_BITWISE_EXP(AndExp)
290 VISIT_BITWISE_EXP(OrExp)
291 VISIT_BITWISE_EXP(XorExp)
293 VISIT_BITWISE_EXP(ShlExp)
294 VISIT_BITWISE_EXP(AShrExp)
295 VISIT_BITWISE_EXP(LShrExp)
297 #define VISIT_BINARY_CMP_EXP(Type) \
298 static void visit##Type(ast::Node & node_, TypeChecker & ctx) \
300 ast::Type & node = reinterpret_cast<ast::Type &> (node_); \
301 ctx.check(node.left); \
302 ctx.check(node.right); \
303 /** \
304 * @todo check if type match \
305 * @todo look for an overloaded operator '+' \
306 * @todo the cast can be applied on both nodes \
307 */ \
308 TypeChecker::homogenizeTypes(node); \
309 node.type = ast::NodeFactory::createBoolType(); \
312 VISIT_BINARY_CMP_EXP(EqExp)
313 VISIT_BINARY_CMP_EXP(NeqExp)
314 VISIT_BINARY_CMP_EXP(LtExp)
315 VISIT_BINARY_CMP_EXP(LtEqExp)
316 VISIT_BINARY_CMP_EXP(GtExp)
317 VISIT_BINARY_CMP_EXP(GtEqExp)
319 #define VISIT_BINARY_BOOL_EXP(Type) \
320 static void visit##Type(ast::Node & node_, TypeChecker & ctx) \
322 ast::Type & node = reinterpret_cast<ast::Type &> (node_); \
323 node.type = ast::NodeFactory::createBoolType(); \
325 ctx.check(node.left); \
326 ctx.check(node.right); \
328 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.left->type) && \
329 !AST_IS_FLOATING_POINT_TYPE(node.right->type), \
330 "throw and error here, can't do bitwise operation on float."); \
332 node.left = castToType(node.type, node.left); \
333 node.right = castToType(node.type, node.right); \
336 VISIT_BINARY_BOOL_EXP(OrOrExp)
337 VISIT_BINARY_BOOL_EXP(AndAndExp)
339 static void visitNegExp(ast::Node & node_, TypeChecker & ctx)
341 ast::NegExp & node = reinterpret_cast<ast::NegExp &>(node_);
342 ctx.check(node.exp);
343 node.type = node.exp->type;
346 static void visitNotExp(ast::Node & node_, TypeChecker & ctx)
348 ast::NotExp & node = reinterpret_cast<ast::NotExp &>(node_);
349 ctx.check(node.exp);
350 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.exp->type),
351 "throw and error here, can't do bitwise operation on float.");
352 node.type = node.exp->type;
355 static void visitBangExp(ast::Node & node_, TypeChecker & ctx)
357 ast::BangExp & node = reinterpret_cast<ast::BangExp &>(node_);
358 ctx.check(node.exp);
359 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.exp->type),
360 "throw and error here, can't do bitwise operation on float.");
361 node.type = ast::NodeFactory::createBoolType();
362 node.exp = castToType(node.type, node.exp);
365 static void visitConditionalBranch(ast::Node & node_, TypeChecker & ctx)
367 ast::ConditionalBranch & node = reinterpret_cast<ast::ConditionalBranch &>(node_);
368 ctx.check(node.cond);
369 node.cond = castToType(ast::NodeFactory::createBoolType(), node.cond);
372 void
373 TypeChecker::initBase()
375 #define REGISTER_METHOD(Class) \
376 registerMethod(ast::Class::nodeTypeId(), visit##Class)
378 REGISTER_METHOD(AssignExp);
380 REGISTER_METHOD(AddExp);
381 REGISTER_METHOD(SubExp);
382 REGISTER_METHOD(MulExp);
383 REGISTER_METHOD(DivExp);
384 REGISTER_METHOD(ModExp);
386 REGISTER_METHOD(AndExp);
387 REGISTER_METHOD(OrExp);
388 REGISTER_METHOD(XorExp);
390 REGISTER_METHOD(ShlExp);
391 REGISTER_METHOD(AShrExp);
392 REGISTER_METHOD(LShrExp);
394 REGISTER_METHOD(OrOrExp);
395 REGISTER_METHOD(AndAndExp);
397 REGISTER_METHOD(EqExp);
398 REGISTER_METHOD(NeqExp);
399 REGISTER_METHOD(LtExp);
400 REGISTER_METHOD(LtEqExp);
401 REGISTER_METHOD(GtExp);
402 REGISTER_METHOD(GtEqExp);
404 REGISTER_METHOD(NotExp);
405 REGISTER_METHOD(NegExp);
406 REGISTER_METHOD(BangExp);
408 REGISTER_METHOD(IdExp);
409 REGISTER_METHOD(AtExp);
410 REGISTER_METHOD(DereferenceExp);
411 REGISTER_METHOD(DereferenceByIndexExp);
412 REGISTER_METHOD(CastExp);
413 REGISTER_METHOD(CallExp);
415 REGISTER_METHOD(Symbol);
417 REGISTER_METHOD(File);
418 REGISTER_METHOD(Function);
419 REGISTER_METHOD(Block);
421 REGISTER_METHOD(Return);
423 REGISTER_METHOD(ConditionalBranch);
425 completeWith<Browser<TypeChecker> >();