use pointers like this: a[4] works :-)
[ozulis.git] / src / ast / type-checker-visitor.cc
blobb174e542c16496eba8efabc83d25be5ce6946106
1 #include <boost/foreach.hpp>
3 #include <core/assert.hh>
4 #include <ast/node-factory.hh>
5 #include <ast/cast-tables.hh>
6 #include <ast/scope.hh>
7 #include <ast/type-checker-visitor.hh>
9 namespace ast
11 TypeCheckerVisitor::TypeCheckerVisitor()
12 : BrowseVisitor(),
13 scope_()
17 TypeCheckerVisitor::~TypeCheckerVisitor()
21 void
22 TypeCheckerVisitor::visit(File & node)
24 scope_ = node.scope;
25 super_t::visit(node);
28 void
29 TypeCheckerVisitor::visit(Function & node)
31 /// @todo check that the last statement is a branch or a return
32 currentFunction_ = &node;
33 super_t::visit(node);
36 void
37 TypeCheckerVisitor::visit(Block & node)
39 scope_ = node.scope;
40 // We must respect this order
41 BOOST_FOREACH (Node * varDecl, (*node.varDecls))
42 varDecl->accept(*this);
43 BOOST_FOREACH (Node * statement, (*node.statements))
44 statement->accept(*this);
47 void
48 TypeCheckerVisitor::visit(Return & node)
50 node.exp->accept(*this);
51 if (!isSameType(node.exp->type, currentFunction_->returnType))
53 CastExp * castExp = new CastExp;
54 castExp->exp = node.exp;
55 castExp->type = currentFunction_->returnType;
56 node.exp = castExp;
60 void
61 TypeCheckerVisitor::visit(AssignExp & node)
63 node.dest->accept(*this);
64 node.value->accept(*this);
65 CastExp * castExp = new CastExp();
66 castExp->type = node.dest->type;
67 castExp->exp = node.value;
68 assert(node.dest->type);
69 node.value = castExp;
70 node.type = node.dest->type;
73 void
74 TypeCheckerVisitor::visit(IdExp & node)
76 if (node.symbol && node.symbol->address &&
77 node.symbol->address->nodeType == RegisterAddress::nodeTypeId())
78 return;
80 assert(scope_);
81 Symbol * s = scope_->findSymbol(node.symbol->name);
82 assert(s);
83 assert(s->type);
84 node.type = s->type;
85 node.symbol = s;
88 void
89 TypeCheckerVisitor::visit(DereferenceExp & node)
91 node.exp->accept(*this);
92 Type * unrefType = unreferencedType(node.exp->type);
93 assert_msg(unrefType->nodeType == PointerType::nodeTypeId(),
94 "You can't dereference a non pointer type");
95 PointerType * type = reinterpret_cast<PointerType *> (unrefType);
96 node.type = type->type;
99 void
100 TypeCheckerVisitor::visit(DereferenceByIndexExp & node)
102 node.exp->accept(*this);
103 node.index->accept(*this);
104 Type * unrefType = unreferencedType(node.exp->type);
105 assert_msg(unrefType->nodeType == PointerType::nodeTypeId(),
106 "You can't dereference a non pointer type");
107 PointerType * type = reinterpret_cast<PointerType *> (unrefType);
108 node.type = type->type;
110 // @todo select the itype depending on the platforme pointer size
111 IntegerType * itype = new IntegerType;
112 itype->isSigned = true;
113 itype->size = 32;
114 node.index = castToType(itype, node.index);
117 void
118 TypeCheckerVisitor::visit(Symbol & node)
120 assert(scope_);
121 const Symbol * s = scope_->findSymbol(node.name);
122 assert(s);
123 assert(s->type);
124 node.type = s->type;
127 void
128 TypeCheckerVisitor::visit(CastExp & node)
130 assert(node.exp);
131 node.exp->accept(*this);
134 void
135 TypeCheckerVisitor::visit(CallExp & node)
137 super_t::visit(node);
138 /// @todo generate the function's signature depending on parameters type
139 const Symbol * s = scope_->findSymbol(node.id);
140 assert_msg(s, "function not found in symbol table");
141 assert(s->type);
143 assert(s->type->nodeType == FunctionType::nodeTypeId());
144 node.ftype = reinterpret_cast<FunctionType *> (s->type);
145 node.type = node.ftype->returnType;
147 assert(node.ftype->argsType->size() >= node.args->size());
148 for (unsigned i = 0; i < node.args->size(); i++)
149 (*node.args)[i] = castToType((*node.ftype->argsType)[i], (*node.args)[i]);
152 void
153 TypeCheckerVisitor::homogenizeTypes(BinaryExp & node)
155 CastExp * castExp = castToBestType(node.left->type,
156 node.right->type);
157 if (!castExp)
159 node.type = node.left->type;
160 return;
163 if (castExp->type == node.left->type)
165 castExp->exp = node.right;
166 node.right = castExp;
168 else
170 castExp->exp = node.left;
171 node.left = castExp;
173 node.type = castExp->type;
176 #define VISIT_BINARY_ARITH(Type) \
177 void \
178 TypeCheckerVisitor::visit(Type & node) \
180 node.left->accept(*this); \
181 node.right->accept(*this); \
182 /** \
183 * @todo check if type match \
184 * @todo look for an overloaded operator '+' \
185 * @todo the cast can be applied on both nodes \
186 */ \
187 homogenizeTypes(node); \
190 VISIT_BINARY_ARITH(AddExp)
191 VISIT_BINARY_ARITH(SubExp)
192 VISIT_BINARY_ARITH(MulExp)
193 VISIT_BINARY_ARITH(DivExp)
194 VISIT_BINARY_ARITH(ModExp)
196 #define VISIT_BITWISE_EXP(Type) \
197 void \
198 TypeCheckerVisitor::visit(Type & node) \
200 node.left->accept(*this); \
201 node.right->accept(*this); \
203 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.left->type) && \
204 !AST_IS_FLOATING_POINT_TYPE(node.right->type), \
205 "throw and error here, can't do bitwise operation on float."); \
206 homogenizeTypes(node); \
209 VISIT_BITWISE_EXP(AndExp)
210 VISIT_BITWISE_EXP(OrExp)
211 VISIT_BITWISE_EXP(XorExp)
213 VISIT_BITWISE_EXP(ShlExp)
214 VISIT_BITWISE_EXP(AShrExp)
215 VISIT_BITWISE_EXP(LShrExp)
217 #define VISIT_BINARY_CMP_EXP(Type) \
218 void \
219 TypeCheckerVisitor::visit(Type & node) \
221 node.left->accept(*this); \
222 node.right->accept(*this); \
223 /** \
224 * @todo check if type match \
225 * @todo look for an overloaded operator '+' \
226 * @todo the cast can be applied on both nodes \
227 */ \
228 homogenizeTypes(node); \
229 node.type = NodeFactory::createBoolType(); \
232 VISIT_BINARY_CMP_EXP(EqExp)
233 VISIT_BINARY_CMP_EXP(NeqExp)
234 VISIT_BINARY_CMP_EXP(LtExp)
235 VISIT_BINARY_CMP_EXP(LtEqExp)
236 VISIT_BINARY_CMP_EXP(GtExp)
237 VISIT_BINARY_CMP_EXP(GtEqExp)
239 #define VISIT_BINARY_BOOL_EXP(Type) \
240 void \
241 TypeCheckerVisitor::visit(Type & node) \
243 node.type = NodeFactory::createBoolType(); \
245 node.left->accept(*this); \
246 node.right->accept(*this); \
248 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.left->type) && \
249 !AST_IS_FLOATING_POINT_TYPE(node.right->type), \
250 "throw and error here, can't do bitwise operation on float."); \
252 node.left = castToType(node.type, node.left); \
253 node.right = castToType(node.type, node.right); \
256 VISIT_BINARY_BOOL_EXP(OrOrExp)
257 VISIT_BINARY_BOOL_EXP(AndAndExp)
259 void
260 TypeCheckerVisitor::visit(NegExp & node)
262 node.exp->accept(*this);
263 node.type = node.exp->type;
266 void
267 TypeCheckerVisitor::visit(NotExp & node)
269 node.exp->accept(*this);
270 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.exp->type),
271 "throw and error here, can't do bitwise operation on float.");
272 node.type = node.exp->type;
275 void
276 TypeCheckerVisitor::visit(BangExp & node)
278 node.exp->accept(*this);
279 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.exp->type),
280 "throw and error here, can't do bitwise operation on float.");
281 node.type = NodeFactory::createBoolType();
282 node.exp = castToType(node.type, node.exp);
285 void
286 TypeCheckerVisitor::visit(ConditionalBranch & node)
288 node.cond->accept(*this);
289 node.cond = castToType(NodeFactory::createBoolType(), node.cond);