[pointers] need to unref types
[ozulis.git] / src / ast / type-checker-visitor.cc
blob 580f7409d0cd7ec02644d04e16753c2ed1f023de
1 #include <boost/foreach.hpp>
3 #include <core/assert.hh>
5 #include <ast/cast-tables.hh>
6 #include <ast/scope.hh>
7 #include "type-checker-visitor.hh"
9 namespace ast
11 TypeCheckerVisitor::TypeCheckerVisitor()
12 : BrowseVisitor(),
13 scope_()
17 TypeCheckerVisitor::~TypeCheckerVisitor()
21 void
22 TypeCheckerVisitor::visit(File & node)
24 scope_ = node.scope;
25 super_t::visit(node);
28 void
29 TypeCheckerVisitor::visit(Function & node)
31 /// @todo check that the last statement is a branch or a return
32 currentFunction_ = &node;
33 super_t::visit(node);
36 void
37 TypeCheckerVisitor::visit(Block & node)
39 scope_ = node.scope;
40 // We must respect this order
41 BOOST_FOREACH (Node * varDecl, (*node.varDecls))
42 varDecl->accept(*this);
43 BOOST_FOREACH (Node * statement, (*node.statements))
44 statement->accept(*this);
47 void
48 TypeCheckerVisitor::visit(Return & node)
50 node.exp->accept(*this);
51 if (!isSameType(node.exp->type, currentFunction_->returnType))
53 CastExp * castExp = new CastExp;
54 castExp->exp = node.exp;
55 castExp->type = currentFunction_->returnType;
56 node.exp = castExp;
60 void
61 TypeCheckerVisitor::visit(AssignExp & node)
63 node.dest->accept(*this);
64 node.value->accept(*this);
65 CastExp * castExp = new CastExp();
66 castExp->type = node.dest->type;
67 castExp->exp = node.value;
68 assert(node.dest->type);
69 node.value = castExp;
70 node.type = node.dest->type;
73 void
74 TypeCheckerVisitor::visit(IdExp & node)
76 assert(scope_);
77 Symbol * s = scope_->findSymbol(node.symbol->name);
78 assert(s);
79 assert(s->type);
80 node.type = s->type;
81 node.symbol = s;
84 void
85 TypeCheckerVisitor::visit(DereferenceExp & node)
87 node.exp->accept(*this);
88 Type * unrefType = unreferencedType(node.exp->type);
89 assert_msg(unrefType->nodeType == PointerType::nodeTypeId(),
90 "You can't dereference a non pointer type");
91 PointerType * type = reinterpret_cast<PointerType *> (unrefType);
92 node.type = type->type;
95 void
96 TypeCheckerVisitor::visit(Symbol & node)
98 assert(scope_);
99 const Symbol * s = scope_->findSymbol(node.name);
100 assert(s);
101 assert(s->type);
102 node.type = s->type;
105 void
106 TypeCheckerVisitor::visit(CastExp & node)
108 assert(node.exp);
109 node.exp->accept(*this);
112 void
113 TypeCheckerVisitor::visit(CallExp & node)
115 super_t::visit(node);
116 /// @todo generate the function's signature depending on parameters type
117 const Symbol * s = scope_->findSymbol(node.id);
118 assert_msg(s, "function not found in symbol table");
119 assert(s->type);
121 assert(s->type->nodeType == FunctionType::nodeTypeId());
122 node.ftype = reinterpret_cast<FunctionType *> (s->type);
123 node.type = node.ftype->returnType;
125 assert(node.ftype->argsType->size() >= node.args->size());
126 for (unsigned i = 0; i < node.args->size(); i++)
127 (*node.args)[i] = castToType((*node.ftype->argsType)[i], (*node.args)[i]);
130 #define VISIT(Type) \
131 void \
132 TypeCheckerVisitor::visit(Type & node) \
134 node.left->accept(*this); \
135 node.right->accept(*this); \
136 /** \
137 * @todo check if type match \
138 * @todo look for an overloaded operator '+' \
139 * @todo the cast can be applied on both nodes \
140 */ \
141 CastExp * castExp = castToBestType(node.left->type, \
142 node.right->type); \
144 if (!castExp) \
146 node.type = node.left->type; \
147 return; \
150 if (castExp->type == node.left->type) \
152 castExp->exp = node.right; \
153 node.right = castExp; \
155 else \
157 castExp->exp = node.left; \
158 node.left = castExp; \
160 node.type = castExp->type; \
163 VISIT(AddExp)
164 VISIT(SubExp)
165 VISIT(MulExp)
166 VISIT(DivExp)
167 VISIT(ModExp)
169 #define VISIT_BITWISE_EXP(Type) \
170 void \
171 TypeCheckerVisitor::visit(Type & node) \
173 node.left->accept(*this); \
174 node.right->accept(*this); \
176 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.left->type) && \
177 !AST_IS_FLOATING_POINT_TYPE(node.right->type), \
178 "throw and error here, can't do bitwise operation on float."); \
179 CastExp * castExp = castToBestType(node.left->type, \
180 node.right->type); \
182 if (!castExp) \
184 node.type = node.left->type; \
185 return; \
188 if (castExp->type == node.left->type) \
190 castExp->exp = node.right; \
191 node.right = castExp; \
193 else \
195 castExp->exp = node.left; \
196 node.left = castExp; \
198 node.type = castExp->type; \
201 VISIT_BITWISE_EXP(AndExp)
202 VISIT_BITWISE_EXP(OrExp)
203 VISIT_BITWISE_EXP(XorExp)
205 VISIT_BITWISE_EXP(ShlExp)
206 VISIT_BITWISE_EXP(AShrExp)
207 VISIT_BITWISE_EXP(LShrExp)
209 #define VISIT_BINARY_CMP_EXP(Type) \
210 void \
211 TypeCheckerVisitor::visit(Type & node) \
213 node.left->accept(*this); \
214 node.right->accept(*this); \
215 node.type = new BoolType; \
216 /** \
217 * @todo check if type match \
218 * @todo look for an overloaded operator '+' \
219 * @todo the cast can be applied on both nodes \
220 */ \
221 CastExp * castExp = castToBestType(node.left->type, \
222 node.right->type); \
224 if (!castExp) \
225 return; \
227 if (castExp->type == node.left->type) \
229 castExp->exp = node.right; \
230 node.right = castExp; \
232 else \
234 castExp->exp = node.left; \
235 node.left = castExp; \
239 VISIT_BINARY_CMP_EXP(EqExp)
240 VISIT_BINARY_CMP_EXP(NeqExp)
241 VISIT_BINARY_CMP_EXP(LtExp)
242 VISIT_BINARY_CMP_EXP(LtEqExp)
243 VISIT_BINARY_CMP_EXP(GtExp)
244 VISIT_BINARY_CMP_EXP(GtEqExp)
246 #define VISIT_BINARY_BOOL_EXP(Type) \
247 void \
248 TypeCheckerVisitor::visit(Type & node) \
250 node.type = new BoolType; \
252 node.left->accept(*this); \
253 node.right->accept(*this); \
255 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.left->type) && \
256 !AST_IS_FLOATING_POINT_TYPE(node.right->type), \
257 "throw and error here, can't do bitwise operation on float."); \
259 if (node.left->type->nodeType != BoolType::nodeTypeId()) \
261 CastExp * castExp = new CastExp; \
262 castExp->type = node.type; \
263 castExp->exp = node.left; \
264 node.left = castExp; \
266 if (node.right->type->nodeType != BoolType::nodeTypeId()) \
268 CastExp * castExp = new CastExp; \
269 castExp->type = node.type; \
270 castExp->exp = node.right; \
271 node.right = castExp; \
275 VISIT_BINARY_BOOL_EXP(OrOrExp)
276 VISIT_BINARY_BOOL_EXP(AndAndExp)
278 void
279 TypeCheckerVisitor::visit(NegExp & node)
281 node.exp->accept(*this);
282 node.type = node.exp->type;
285 void
286 TypeCheckerVisitor::visit(NotExp & node)
288 node.exp->accept(*this);
289 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.exp->type),
290 "throw and error here, can't do bitwise operation on float.");
291 node.type = node.exp->type;
294 void
295 TypeCheckerVisitor::visit(BangExp & node)
297 node.exp->accept(*this);
298 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.exp->type),
299 "throw and error here, can't do bitwise operation on float.");
300 node.type = new BoolType;
301 if (node.exp->type->nodeType != BoolType::nodeTypeId())
303 CastExp * castExp = new CastExp;
304 castExp->exp = node.exp;
305 castExp->type = node.type;
306 node.exp = castExp;
310 void
311 TypeCheckerVisitor::visit(ConditionalBranch & node)
313 node.cond->accept(*this);
314 if (node.cond->type->nodeType == BoolType::nodeTypeId())
315 return;
317 CastExp * castExp = new CastExp;
318 castExp->exp = node.cond;
319 castExp->type = new BoolType;
320 node.cond = castExp;