simplified DereferenceByIndex with the type checker's pointer arithmetic
[ozulis.git] / src / ast / type-checker-visitor.cc
blobf25423e81e1cdb449fe4987677d6f92ab2472c36
1 #include <boost/foreach.hpp>
3 #include <core/assert.hh>
4 #include <ast/node-factory.hh>
5 #include <ast/cast-tables.hh>
6 #include <ast/scope.hh>
7 #include <ast/type-checker-visitor.hh>
9 namespace ast
11 TypeCheckerVisitor::TypeCheckerVisitor()
12 : BrowseVisitor(),
13 scope_()
17 TypeCheckerVisitor::~TypeCheckerVisitor()
21 template <typename T>
22 void
23 TypeCheckerVisitor::check(T *& node)
25 replacement_ = 0;
26 node->accept(*this);
27 if (replacement_)
29 node = reinterpret_cast<T *> (replacement_);
30 replacement_ = 0;
34 void
35 TypeCheckerVisitor::visit(File & node)
37 scope_ = node.scope;
38 super_t::visit(node);
41 void
42 TypeCheckerVisitor::visit(Function & node)
44 /// @todo check that the last statement is a branch or a return
45 currentFunction_ = &node;
46 super_t::visit(node);
49 void
50 TypeCheckerVisitor::visit(Block & node)
52 scope_ = node.scope;
53 // We must respect this order
54 BOOST_FOREACH (VarDecl *& varDecl, (*node.varDecls))
55 check(varDecl);
56 BOOST_FOREACH (Node *& statement, (*node.statements))
57 check(statement);
60 void
61 TypeCheckerVisitor::visit(Return & node)
63 check(node.exp);
64 if (!isSameType(node.exp->type, currentFunction_->returnType))
66 CastExp * castExp = new CastExp;
67 castExp->exp = node.exp;
68 castExp->type = currentFunction_->returnType;
69 node.exp = castExp;
73 void
74 TypeCheckerVisitor::visit(AssignExp & node)
76 check(node.dest);
77 check(node.value);
78 CastExp * castExp = new CastExp();
79 castExp->type = node.dest->type;
80 castExp->exp = node.value;
81 assert(node.dest->type);
82 node.value = castExp;
83 node.type = node.dest->type;
86 void
87 TypeCheckerVisitor::visit(IdExp & node)
89 if (node.symbol && node.symbol->address &&
90 node.symbol->address->nodeType == RegisterAddress::nodeTypeId())
91 return;
93 assert(scope_);
94 Symbol * s = scope_->findSymbol(node.symbol->name);
95 assert(s);
96 assert(s->type);
97 node.type = s->type;
98 node.symbol = s;
101 void
102 TypeCheckerVisitor::visit(DereferenceExp & node)
104 check(node.exp);
105 Type * unrefType = unreferencedType(node.exp->type);
106 assert_msg(unrefType->nodeType == PointerType::nodeTypeId(),
107 "You can't dereference a non pointer type");
108 PointerType * type = reinterpret_cast<PointerType *> (unrefType);
109 node.type = type->type;
112 void
113 TypeCheckerVisitor::visit(DereferenceByIndexExp & node)
115 check(node.exp);
116 check(node.index);
117 Type * unrefType = unreferencedType(node.exp->type);
118 assert_msg(unrefType->nodeType == PointerType::nodeTypeId(),
119 "You can't dereference a non pointer type");
120 PointerType * type = reinterpret_cast<PointerType *> (unrefType);
121 node.type = type->type;
123 // @todo select the itype depending on the platforme pointer size
124 IntegerType * itype = new IntegerType;
125 itype->isSigned = true;
126 itype->size = 32;
127 node.index = castToType(itype, node.index);
130 void
131 TypeCheckerVisitor::visit(Symbol & node)
133 assert(scope_);
134 const Symbol * s = scope_->findSymbol(node.name);
135 assert(s);
136 assert(s->type);
137 node.type = s->type;
140 void
141 TypeCheckerVisitor::visit(CastExp & node)
143 assert(node.exp);
144 check(node.exp);
147 void
148 TypeCheckerVisitor::visit(CallExp & node)
150 super_t::visit(node);
151 /// @todo generate the function's signature depending on parameters type
152 const Symbol * s = scope_->findSymbol(node.id);
153 assert_msg(s, "function not found in symbol table");
154 assert(s->type);
156 assert(s->type->nodeType == FunctionType::nodeTypeId());
157 node.ftype = reinterpret_cast<FunctionType *> (s->type);
158 node.type = node.ftype->returnType;
160 assert(node.ftype->argsType->size() >= node.args->size());
161 for (unsigned i = 0; i < node.args->size(); i++)
162 (*node.args)[i] = castToType((*node.ftype->argsType)[i], (*node.args)[i]);
165 void
166 TypeCheckerVisitor::homogenizeTypes(BinaryExp & node)
168 CastExp * castExp = castToBestType(node.left->type,
169 node.right->type);
170 if (!castExp)
172 node.type = node.left->type;
173 return;
176 if (castExp->type == node.left->type)
178 castExp->exp = node.right;
179 node.right = castExp;
181 else
183 castExp->exp = node.left;
184 node.left = castExp;
186 node.type = castExp->type;
190 * @internal
191 * - compute the new address:
192 * - addrui = ptrtoui addr
193 * - offset = index * sizeof(*addr)
194 * - newaddrui = ptrtoui + offset
195 * - newaddr = cast(*, newaddrui)
197 void
198 TypeCheckerVisitor::pointerArith(Exp * pointer, Exp * offset)
200 assert(unreferencedType(offset->type)->nodeType != PointerType::nodeTypeId());
201 PointerType * pointerType =
202 reinterpret_cast<PointerType *> (unreferencedType(pointer->type));
204 /* cast node.exp to uint */
205 CastExp * cast1 = NodeFactory::createCastPtrToUInt(pointer);
207 /* compute the offset */
208 MulExp * mul = new MulExp;
209 mul->left = offset;
210 mul->right = NodeFactory::createSizeofValue(pointerType->type);
212 /* add the offset to the uint addr */
213 AddExp * add = new AddExp;
214 add->left = cast1;
215 add->right = mul;
217 /* cast uint addr to ptr */
218 CastExp * cast2 = NodeFactory::createCastUIntToPtr(add, unreferencedType(pointerType));
219 check(cast2);
221 replacement_ = cast2;
224 #define VISIT_ADD_EXP(Type) \
225 void \
226 TypeCheckerVisitor::visit(Type & node) \
228 check(node.left); \
229 check(node.right); \
230 if (unreferencedType(node.left->type)->nodeType == \
231 PointerType::nodeTypeId()) \
232 pointerArith(node.left, node.right); \
233 else if (unreferencedType(node.right->type)->nodeType == \
234 PointerType::nodeTypeId()) \
235 pointerArith(node.right, node.left); \
236 else \
237 homogenizeTypes(node); \
240 #define VISIT_BINARY_ARITH(Type) \
241 void \
242 TypeCheckerVisitor::visit(Type & node) \
244 check(node.left); \
245 check(node.right); \
246 homogenizeTypes(node); \
249 VISIT_ADD_EXP(AddExp)
250 VISIT_ADD_EXP(SubExp)
251 VISIT_BINARY_ARITH(MulExp)
252 VISIT_BINARY_ARITH(DivExp)
253 VISIT_BINARY_ARITH(ModExp)
255 #define VISIT_BITWISE_EXP(Type) \
256 void \
257 TypeCheckerVisitor::visit(Type & node) \
259 check(node.left); \
260 check(node.right); \
262 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.left->type) && \
263 !AST_IS_FLOATING_POINT_TYPE(node.right->type), \
264 "throw and error here, can't do bitwise operation on float."); \
265 homogenizeTypes(node); \
268 VISIT_BITWISE_EXP(AndExp)
269 VISIT_BITWISE_EXP(OrExp)
270 VISIT_BITWISE_EXP(XorExp)
272 VISIT_BITWISE_EXP(ShlExp)
273 VISIT_BITWISE_EXP(AShrExp)
274 VISIT_BITWISE_EXP(LShrExp)
276 #define VISIT_BINARY_CMP_EXP(Type) \
277 void \
278 TypeCheckerVisitor::visit(Type & node) \
280 check(node.left); \
281 check(node.right); \
282 /** \
283 * @todo check if type match \
284 * @todo look for an overloaded operator '+' \
285 * @todo the cast can be applied on both nodes \
286 */ \
287 homogenizeTypes(node); \
288 node.type = NodeFactory::createBoolType(); \
291 VISIT_BINARY_CMP_EXP(EqExp)
292 VISIT_BINARY_CMP_EXP(NeqExp)
293 VISIT_BINARY_CMP_EXP(LtExp)
294 VISIT_BINARY_CMP_EXP(LtEqExp)
295 VISIT_BINARY_CMP_EXP(GtExp)
296 VISIT_BINARY_CMP_EXP(GtEqExp)
298 #define VISIT_BINARY_BOOL_EXP(Type) \
299 void \
300 TypeCheckerVisitor::visit(Type & node) \
302 node.type = NodeFactory::createBoolType(); \
304 check(node.left); \
305 check(node.right); \
307 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.left->type) && \
308 !AST_IS_FLOATING_POINT_TYPE(node.right->type), \
309 "throw and error here, can't do bitwise operation on float."); \
311 node.left = castToType(node.type, node.left); \
312 node.right = castToType(node.type, node.right); \
315 VISIT_BINARY_BOOL_EXP(OrOrExp)
316 VISIT_BINARY_BOOL_EXP(AndAndExp)
318 void
319 TypeCheckerVisitor::visit(NegExp & node)
321 check(node.exp);
322 node.type = node.exp->type;
325 void
326 TypeCheckerVisitor::visit(NotExp & node)
328 check(node.exp);
329 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.exp->type),
330 "throw and error here, can't do bitwise operation on float.");
331 node.type = node.exp->type;
334 void
335 TypeCheckerVisitor::visit(BangExp & node)
337 check(node.exp);
338 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.exp->type),
339 "throw and error here, can't do bitwise operation on float.");
340 node.type = NodeFactory::createBoolType();
341 node.exp = castToType(node.type, node.exp);
// Type-checks a conditional branch: the condition is checked, then cast to
// the bool type so the branch always tests a boolean value.
344 void
345 TypeCheckerVisitor::visit(ConditionalBranch & node)
347 check(node.cond);
// castToType() is applied whatever the condition's type was.
348 node.cond = castToType(NodeFactory::createBoolType(), node.cond);