[refactor] moved classes into the ozulis namespace and created a plugins folder
[ozulis.git] / src / ozulis / ast / type-checker-visitor.cc
blob50fc48b03eb10c2b3bbd78c051833ae9c8dcdf1b
1 #include <boost/foreach.hpp>
3 #include <ozulis/core/assert.hh>
4 #include <ozulis/ast/ast-cast.hh>
5 #include <ozulis/ast/node-factory.hh>
6 #include <ozulis/ast/cast-tables.hh>
7 #include <ozulis/ast/scope.hh>
8 #include <ozulis/ast/type-checker-visitor.hh>
10 namespace ozulis
12 namespace ast
14 TypeCheckerVisitor::TypeCheckerVisitor()
15 : BrowseVisitor(),
16 scope_()
20 TypeCheckerVisitor::~TypeCheckerVisitor()
24 template <typename T>
25 void
26 TypeCheckerVisitor::check(T *& node)
28 replacement_ = 0;
29 node->accept(*this);
30 if (replacement_)
32 node = ast_cast<T *> (replacement_);
33 replacement_ = 0;
37 void
38 TypeCheckerVisitor::visit(File & node)
40 scope_ = node.scope;
41 super_t::visit(node);
44 void
45 TypeCheckerVisitor::visit(Function & node)
47 /// @todo check that the last statement is a branch or a return
48 currentFunction_ = &node;
49 super_t::visit(node);
52 void
53 TypeCheckerVisitor::visit(Block & node)
55 scope_ = node.scope;
56 // We must respect this order
57 BOOST_FOREACH (VarDecl *& varDecl, (*node.varDecls))
58 check(varDecl);
59 BOOST_FOREACH (Node *& statement, (*node.statements))
60 check(statement);
63 void
64 TypeCheckerVisitor::visit(Return & node)
66 check(node.exp);
67 if (!isSameType(node.exp->type, currentFunction_->returnType))
69 CastExp * castExp = new CastExp;
70 castExp->exp = node.exp;
71 castExp->type = currentFunction_->returnType;
72 node.exp = castExp;
76 void
77 TypeCheckerVisitor::visit(AssignExp & node)
79 check(node.dest);
80 check(node.value);
82 CastExp * castExp = new CastExp();
83 castExp->type = unreferencedType(node.dest->type);
84 castExp->exp = node.value;
85 node.value = castExp;
86 node.type = castExp->type;
89 void
90 TypeCheckerVisitor::visit(IdExp & node)
92 if (node.symbol && node.symbol->address &&
93 node.symbol->address->nodeType == RegisterAddress::nodeTypeId())
94 return;
96 assert(scope_);
97 Symbol * s = scope_->findSymbol(node.symbol->name);
98 assert(s);
99 assert(s->type);
100 assert(s->address);
101 node.symbol = s;
102 node.type = s->type;
105 void
106 TypeCheckerVisitor::visit(AtExp & node)
108 check(node.exp);
109 PointerType * type = new PointerType;
110 type->type = unreferencedType(node.exp->type);
111 node.type = type;
112 /// @todo check that i can get the address of exp
113 assert(node.exp->type->nodeType == ReferenceType::nodeTypeId() ||
114 node.exp->nodeType == DereferenceExp::nodeTypeId() ||
115 node.exp->nodeType == DereferenceByIndexExp::nodeTypeId());
118 void
119 TypeCheckerVisitor::visit(DereferenceExp & node)
121 check(node.exp);
122 Type * unrefType = unreferencedType(node.exp->type);
123 assert_msg(unrefType->nodeType == PointerType::nodeTypeId(),
124 "Error: you can't dereference a non pointer type");
125 PointerType * type = ast_cast<PointerType *> (unrefType);
126 node.type = type->type;
129 void
130 TypeCheckerVisitor::visit(DereferenceByIndexExp & node)
132 check(node.exp);
133 check(node.index);
134 Type * unrefType = unreferencedType(node.exp->type);
135 assert_msg(unrefType->nodeType == PointerType::nodeTypeId(),
136 "Error: you can't dereference a non pointer type");
137 PointerType * type = ast_cast<PointerType *> (unrefType);
138 node.type = type->type;
140 // @todo select the itype depending on the platforme pointer size
141 IntegerType * itype = new IntegerType;
142 itype->isSigned = true;
143 itype->size = 32;
144 node.index = castToType(itype, node.index);
147 void
148 TypeCheckerVisitor::visit(Symbol & node)
150 assert(scope_);
151 const Symbol * s = scope_->findSymbol(node.name);
152 assert(s);
153 assert(s->type);
154 assert(s->address);
155 node.type = s->type;
156 node.address = s->address;
159 void
160 TypeCheckerVisitor::visit(CastExp & node)
162 assert(node.exp);
163 check(node.exp);
166 void
167 TypeCheckerVisitor::visit(CallExp & node)
169 super_t::visit(node);
170 /// @todo generate the function's signature depending on parameters type
171 const Symbol * s = scope_->findSymbol(node.id);
172 assert_msg(s, "function not found in symbol table");
173 assert(s->type);
175 node.ftype = ast_cast<FunctionType *> (s->type);
176 node.type = node.ftype->returnType;
178 assert(node.ftype->argsType->size() >= node.args->size());
179 for (unsigned i = 0; i < node.args->size(); i++)
180 (*node.args)[i] = castToType((*node.ftype->argsType)[i], (*node.args)[i]);
183 void
184 TypeCheckerVisitor::homogenizeTypes(BinaryExp & node)
186 CastExp * castExp = castToBestType(node.left->type,
187 node.right->type);
188 if (!castExp)
190 node.type = node.left->type;
191 return;
194 if (castExp->type == node.left->type)
196 castExp->exp = node.right;
197 node.right = castExp;
199 else
201 castExp->exp = node.left;
202 node.left = castExp;
204 node.type = castExp->type;
208 * @internal
209 * - compute the new address:
210 * - addrui = ptrtoui addr
211 * - offset = index * sizeof(*addr)
212 * - newaddrui = ptrtoui + offset
213 * - newaddr = cast(*, newaddrui)
215 void
216 TypeCheckerVisitor::pointerArith(Exp * pointer, Exp * offset)
218 PointerType * pointerType =
219 ast_cast<PointerType *> (unreferencedType(pointer->type));
221 /* cast node.exp to uint */
222 CastExp * cast1 = NodeFactory::createCastPtrToUInt(pointer);
224 /* compute the offset */
225 MulExp * mul = new MulExp;
226 mul->left = offset;
227 mul->right = NodeFactory::createSizeofValue(pointerType->type);
229 /* add the offset to the uint addr */
230 AddExp * add = new AddExp;
231 add->left = cast1;
232 add->right = mul;
234 /* cast uint addr to ptr */
235 CastExp * cast2 = NodeFactory::createCastUIntToPtr(add, unreferencedType(pointerType));
236 check(cast2);
238 replacement_ = cast2;
241 #define VISIT_ADD_EXP(Type) \
242 void \
243 TypeCheckerVisitor::visit(Type & node) \
245 check(node.left); \
246 check(node.right); \
247 if (unreferencedType(node.left->type)->nodeType == \
248 PointerType::nodeTypeId()) \
249 pointerArith(node.left, node.right); \
250 else if (unreferencedType(node.right->type)->nodeType == \
251 PointerType::nodeTypeId()) \
252 pointerArith(node.right, node.left); \
253 else \
254 homogenizeTypes(node); \
257 #define VISIT_BINARY_ARITH(Type) \
258 void \
259 TypeCheckerVisitor::visit(Type & node) \
261 check(node.left); \
262 check(node.right); \
263 homogenizeTypes(node); \
266 VISIT_ADD_EXP(AddExp)
267 VISIT_ADD_EXP(SubExp)
268 VISIT_BINARY_ARITH(MulExp)
269 VISIT_BINARY_ARITH(DivExp)
270 VISIT_BINARY_ARITH(ModExp)
272 #define VISIT_BITWISE_EXP(Type) \
273 void \
274 TypeCheckerVisitor::visit(Type & node) \
276 check(node.left); \
277 check(node.right); \
279 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.left->type) && \
280 !AST_IS_FLOATING_POINT_TYPE(node.right->type), \
281 "throw and error here, can't do bitwise operation on float."); \
282 homogenizeTypes(node); \
285 VISIT_BITWISE_EXP(AndExp)
286 VISIT_BITWISE_EXP(OrExp)
287 VISIT_BITWISE_EXP(XorExp)
289 VISIT_BITWISE_EXP(ShlExp)
290 VISIT_BITWISE_EXP(AShrExp)
291 VISIT_BITWISE_EXP(LShrExp)
293 #define VISIT_BINARY_CMP_EXP(Type) \
294 void \
295 TypeCheckerVisitor::visit(Type & node) \
297 check(node.left); \
298 check(node.right); \
299 /** \
300 * @todo check if type match \
301 * @todo look for an overloaded operator '+' \
302 * @todo the cast can be applied on both nodes \
303 */ \
304 homogenizeTypes(node); \
305 node.type = NodeFactory::createBoolType(); \
308 VISIT_BINARY_CMP_EXP(EqExp)
309 VISIT_BINARY_CMP_EXP(NeqExp)
310 VISIT_BINARY_CMP_EXP(LtExp)
311 VISIT_BINARY_CMP_EXP(LtEqExp)
312 VISIT_BINARY_CMP_EXP(GtExp)
313 VISIT_BINARY_CMP_EXP(GtEqExp)
315 #define VISIT_BINARY_BOOL_EXP(Type) \
316 void \
317 TypeCheckerVisitor::visit(Type & node) \
319 node.type = NodeFactory::createBoolType(); \
321 check(node.left); \
322 check(node.right); \
324 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.left->type) && \
325 !AST_IS_FLOATING_POINT_TYPE(node.right->type), \
326 "throw and error here, can't do bitwise operation on float."); \
328 node.left = castToType(node.type, node.left); \
329 node.right = castToType(node.type, node.right); \
332 VISIT_BINARY_BOOL_EXP(OrOrExp)
333 VISIT_BINARY_BOOL_EXP(AndAndExp)
335 void
336 TypeCheckerVisitor::visit(NegExp & node)
338 check(node.exp);
339 node.type = node.exp->type;
342 void
343 TypeCheckerVisitor::visit(NotExp & node)
345 check(node.exp);
346 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.exp->type),
347 "throw and error here, can't do bitwise operation on float.");
348 node.type = node.exp->type;
351 void
352 TypeCheckerVisitor::visit(BangExp & node)
354 check(node.exp);
355 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.exp->type),
356 "throw and error here, can't do bitwise operation on float.");
357 node.type = NodeFactory::createBoolType();
358 node.exp = castToType(node.type, node.exp);
/// Type-checks a conditional branch: the condition expression is checked,
/// then cast to the boolean type so that the branch always tests a bool.
361 void
362 TypeCheckerVisitor::visit(ConditionalBranch & node)
364 check(node.cond);
365 node.cond = castToType(NodeFactory::createBoolType(), node.cond);