1 #include <boost/foreach.hpp>
3 #include <ozulis/core/assert.hh>
4 #include <ozulis/ast/ast-cast.hh>
5 #include <ozulis/ast/node-factory.hh>
6 #include <ozulis/ast/cast-tables.hh>
7 #include <ozulis/ast/scope.hh>
8 #include <ozulis/ast/type-checker-visitor.hh>
/// Constructor of the type-checking visitor.
/// NOTE(review): the initializer list and body are not present in this copy
/// of the file; what it initializes (currentFunction_, scope_ and
/// replacement_ are all used below) cannot be confirmed here.
14 TypeCheckerVisitor::TypeCheckerVisitor()
/// Destructor of the type-checking visitor.
/// NOTE(review): body not present in this copy; whether it releases any
/// state (e.g. scope_) cannot be confirmed here.
20 TypeCheckerVisitor::~TypeCheckerVisitor()
/**
 * Visits the node pointed to by @p node and lets that visit substitute it.
 *
 * The only step visible in this copy is the last one: @p node is
 * overwritten with replacement_, cast to T *. Visits that rewrite a subtree
 * store the new root in replacement_ (pointerArith() sets it to its
 * ptr -> uint -> ptr cast chain), and check() splices that root into the
 * parent's child slot.
 *
 * @tparam T static type of the child slot being checked.
 * @param node reference to the parent's child pointer; rewritten in place.
 *
 * NOTE(review): the template header and the lines that prepare
 * replacement_ and dispatch the visit are missing here. Confirm that
 * replacement_ is reset to @p node before the visit; otherwise a stale
 * replacement from an earlier node could be spliced in. Also confirm that
 * ast_cast<T *> cannot fail for the value it holds.
 */
26 TypeCheckerVisitor::check(T
*& node
)
32 node
= ast_cast
<T
*> (replacement_
);
/// Visit of a File node.
/// NOTE(review): the body is not present in this copy of the file.
38 TypeCheckerVisitor::visit(File
& node
)
/**
 * Visit of a Function node.
 *
 * Records @p node in currentFunction_. visit(Return &) later compares each
 * returned expression's type against currentFunction_->returnType.
 *
 * NOTE(review): the rest of the body is missing from this copy. Confirm that
 * currentFunction_ does not keep pointing at this node after the function
 * has been visited, in case the node is later freed or replaced.
 */
45 TypeCheckerVisitor::visit(Function
& node
)
47 /// @todo check that the last statement is a branch or a return
48 currentFunction_
= &node
;
/**
 * Visit of a Block node: the variable declarations are iterated before the
 * statements.
 *
 * Both loops bind each element by reference (VarDecl *&, Node *&). This lets
 * the loop body replace the element in place, presumably via check(). The
 * loop bodies are not present in this copy — confirm.
 * Presumably the declarations come first so that the symbols they introduce
 * are known when statements using them are checked — TODO confirm.
 */
53 TypeCheckerVisitor::visit(Block
& node
)
56 // We must respect this order: variable declarations first, then statements.
57 BOOST_FOREACH (VarDecl
*& varDecl
, (*node
.varDecls
))
59 BOOST_FOREACH (Node
*& statement
, (*node
.statements
))
/**
 * Visit of a Return node.
 *
 * When the returned expression's type differs from the enclosing function's
 * return type (as judged by isSameType), a CastExp to
 * currentFunction_->returnType is built around node.exp.
 *
 * NOTE(review): the statement that stores castExp back into node.exp is not
 * present in this copy. Confirm it exists; otherwise the cast is built and
 * then dropped.
 * NOTE(review): node.exp->type is dereferenced unconditionally. If a bare
 * `return;` is represented with a null exp, this crashes — confirm.
 * NOTE(review): other visits (DereferenceByIndexExp, CallExp, BangExp) call
 * castToType(type, exp) for this kind of conversion. If it builds an
 * equivalent CastExp, using it here would keep cast construction in one
 * place.
 */
64 TypeCheckerVisitor::visit(Return
& node
)
67 if (!isSameType(node
.exp
->type
, currentFunction_
->returnType
))
69 CastExp
* castExp
= new CastExp
;
70 castExp
->exp
= node
.exp
;
71 castExp
->type
= currentFunction_
->returnType
;
/**
 * Visit of an AssignExp node.
 *
 * Builds a CastExp that converts node.value to the destination's
 * unreferenced type, unreferencedType(node.dest->type). The assignment
 * expression itself is given that same type.
 *
 * NOTE(review): unlike visit(Return &), the CastExp is created even when the
 * value already has the destination type (no isSameType check). This may
 * produce no-op casts.
 * NOTE(review): the line that stores castExp into node.value is not present
 * in this copy — confirm it exists.
 */
77 TypeCheckerVisitor::visit(AssignExp
& node
)
82 CastExp
* castExp
= new CastExp();
83 castExp
->type
= unreferencedType(node
.dest
->type
);
84 castExp
->exp
= node
.value
;
86 node
.type
= castExp
->type
;
/**
 * Visit of an IdExp node.
 *
 * Two steps are visible in this copy:
 *  - a test of whether node.symbol has an address whose node type is
 *    RegisterAddress;
 *  - a lookup of node.symbol->name in scope_.
 *
 * NOTE(review): the lines between the test and the lookup, and every use of
 * the looked-up symbol s, are missing from this copy. Whether the lookup is
 * guarded by the test, and whether s is checked for null, cannot be
 * confirmed here.
 */
90 TypeCheckerVisitor::visit(IdExp
& node
)
92 if (node
.symbol
&& node
.symbol
->address
&&
93 node
.symbol
->address
->nodeType
== RegisterAddress::nodeTypeId())
97 Symbol
* s
= scope_
->findSymbol(node
.symbol
->name
);
/**
 * Visit of an AtExp (address-of) node.
 *
 * Builds a PointerType whose pointee is the unreferenced type of node.exp.
 * It then asserts that the operand is addressable: either the operand's
 * type is a ReferenceType, or the operand is a DereferenceExp or a
 * DereferenceByIndexExp.
 *
 * NOTE(review): the statement that assigns the new PointerType to node.type
 * is not present in this copy — confirm it exists.
 * NOTE(review): the addressability check uses plain assert(). If that is
 * compiled out in release builds, taking the address of a non-addressable
 * expression goes undiagnosed (see the @todo).
 */
106 TypeCheckerVisitor::visit(AtExp
& node
)
109 PointerType
* type
= new PointerType
;
110 type
->type
= unreferencedType(node
.exp
->type
);
112 /// @todo check that we can get the address of exp
113 assert(node
.exp
->type
->nodeType
== ReferenceType::nodeTypeId() ||
114 node
.exp
->nodeType
== DereferenceExp::nodeTypeId() ||
115 node
.exp
->nodeType
== DereferenceByIndexExp::nodeTypeId());
/**
 * Visit of a DereferenceExp node.
 *
 * The operand's unreferenced type must be a PointerType, which is checked
 * with assert_msg. The expression's type becomes that pointer's pointee
 * type.
 */
119 TypeCheckerVisitor::visit(DereferenceExp
& node
)
122 Type
* unrefType
= unreferencedType(node
.exp
->type
);
123 assert_msg(unrefType
->nodeType
== PointerType::nodeTypeId(),
124 "Error: you can't dereference a non pointer type");
125 PointerType
* type
= ast_cast
<PointerType
*> (unrefType
);
126 node
.type
= type
->type
;
/**
 * Visit of a DereferenceByIndexExp (indexed dereference) node.
 *
 * The type rule is the same as in visit(DereferenceExp &): the unreferenced
 * type of node.exp must be a PointerType, and the result has its pointee
 * type. In addition, node.index is converted with castToType() to a freshly
 * allocated signed IntegerType.
 *
 * NOTE(review): no visible line sets the width of that IntegerType (see the
 * @todo). Confirm what default size IntegerType has.
 * NOTE(review): the pointer check duplicates visit(DereferenceExp &). A
 * shared helper would keep the two checks and messages in sync.
 */
130 TypeCheckerVisitor::visit(DereferenceByIndexExp
& node
)
134 Type
* unrefType
= unreferencedType(node
.exp
->type
);
135 assert_msg(unrefType
->nodeType
== PointerType::nodeTypeId(),
136 "Error: you can't dereference a non pointer type");
137 PointerType
* type
= ast_cast
<PointerType
*> (unrefType
);
138 node
.type
= type
->type
;
140 // @todo select the itype depending on the platform pointer size
141 IntegerType
* itype
= new IntegerType
;
142 itype
->isSigned
= true;
144 node
.index
= castToType(itype
, node
.index
);
/**
 * Visit of a Symbol node.
 *
 * Looks the symbol up by node.name in scope_ and copies the found symbol's
 * address into node.address.
 *
 * NOTE(review): no null check on s is visible before s->address, because
 * the lines in between are missing from this copy. visit(CallExp &) guards
 * the same kind of lookup with assert_msg(s, ...). Confirm that an
 * equivalent guard exists here.
 */
148 TypeCheckerVisitor::visit(Symbol
& node
)
151 const Symbol
* s
= scope_
->findSymbol(node
.name
);
156 node
.address
= s
->address
;
/// Visit of a CastExp node.
/// NOTE(review): the body is not present in this copy of the file.
160 TypeCheckerVisitor::visit(CastExp
& node
)
/**
 * Visit of a CallExp node.
 *
 * Steps, in order:
 *  - call super_t::visit(node);
 *  - look the callee up by node.id in scope_ and assert it was found;
 *  - take its type as the call's FunctionType (node.ftype);
 *  - set the call's type to the function's return type;
 *  - convert each argument with castToType() to the matching parameter
 *    type.
 *
 * NOTE(review): ast_cast<FunctionType *> is applied to s->type without
 * checking that the symbol actually names a function.
 * NOTE(review): the assert accepts argsType->size() >= args->size(). A call
 * with FEWER arguments than parameters therefore passes, and the missing
 * parameters are never checked. Unless default arguments are intended, this
 * should probably be ==.
 */
167 TypeCheckerVisitor::visit(CallExp
& node
)
169 super_t::visit(node
);
170 /// @todo generate the function's signature depending on parameters type
171 const Symbol
* s
= scope_
->findSymbol(node
.id
);
172 assert_msg(s
, "function not found in symbol table");
175 node
.ftype
= ast_cast
<FunctionType
*> (s
->type
);
176 node
.type
= node
.ftype
->returnType
;
178 assert(node
.ftype
->argsType
->size() >= node
.args
->size());
179 for (unsigned i
= 0; i
< node
.args
->size(); i
++)
180 (*node
.args
)[i
] = castToType((*node
.ftype
->argsType
)[i
], (*node
.args
)[i
]);
/**
 * Gives both operands of the binary expression @p node a common type.
 *
 * castToBestType() is called with the left operand's type as its first
 * argument; its other argument is on a line missing from this copy. The
 * returned CastExp is handled as follows:
 *  - if its type is the left operand's type, it is wrapped around the right
 *    operand and installed as node.right;
 *  - otherwise it is wrapped around the left operand;
 *  - node.type is then set to the cast's type.
 * A separate path sets node.type = node.left->type. Presumably it handles
 * the case where no cast is needed; the surrounding lines are missing —
 * confirm.
 *
 * NOTE(review): `castExp->type == node.left->type` compares Type pointers,
 * not types. It relies on castToBestType() reusing one of its argument Type
 * objects rather than allocating a new one — confirm.
 * NOTE(review): the store of castExp into node.left on the second path is
 * not present in this copy — confirm.
 */
184 TypeCheckerVisitor::homogenizeTypes(BinaryExp
& node
)
186 CastExp
* castExp
= castToBestType(node
.left
->type
,
190 node
.type
= node
.left
->type
;
194 if (castExp
->type
== node
.left
->type
)
196 castExp
->exp
= node
.right
;
197 node
.right
= castExp
;
201 castExp
->exp
= node
.left
;
204 node
.type
= castExp
->type
;
209 * - compute the new address:
210 * - addrui = ptrtoui addr
211 * - offset = index * sizeof(*addr)
212 * - newaddrui = addrui + offset
213 * - newaddr = cast(*, newaddrui)
/**
 * Lowers arithmetic between @p pointer and @p offset to integer arithmetic
 * on the address. The new nodes are:
 *  - cast1: @p pointer converted to an unsigned integer
 *    (NodeFactory::createCastPtrToUInt);
 *  - mul: a MulExp whose right operand is sizeof(pointee)
 *    (NodeFactory::createSizeofValue);
 *  - add: an AddExp;
 *  - cast2: add converted back to the pointer type
 *    (NodeFactory::createCastUIntToPtr).
 * cast2 is stored in replacement_, so that check() splices it in place of
 * the visited node.
 *
 * @param pointer operand whose unreferenced type is a PointerType. The
 *        VISIT_ADD_EXP callers test this before calling.
 * @param offset  the other operand.
 *
 * NOTE(review): the lines that wire mul->left and add's operands are
 * missing from this copy. Presumably mul->left = offset and
 * add = cast1 + mul — confirm.
 * NOTE(review): an AddExp is always built, yet VISIT_ADD_EXP also routes
 * SubExp here. `p - n` therefore appears to be lowered as
 * `p + n * sizeof(*p)` unless a missing line negates the offset — verify.
 * `n - p` and `p - q` are not meaningful through this path either.
 */
216 TypeCheckerVisitor::pointerArith(Exp
* pointer
, Exp
* offset
)
218 PointerType
* pointerType
=
219 ast_cast
<PointerType
*> (unreferencedType(pointer
->type
));
221 /* cast the pointer operand to uint */
222 CastExp
* cast1
= NodeFactory::createCastPtrToUInt(pointer
);
224 /* compute the offset */
225 MulExp
* mul
= new MulExp
;
227 mul
->right
= NodeFactory::createSizeofValue(pointerType
->type
);
229 /* add the offset to the uint addr */
230 AddExp
* add
= new AddExp
;
234 /* cast uint addr to ptr */
235 CastExp
* cast2
= NodeFactory::createCastUIntToPtr(add
, unreferencedType(pointerType
));
238 replacement_
= cast2
;
/*
 * Visits for additive expressions (VISIT_ADD_EXP: AddExp, SubExp) and
 * multiplicative expressions (VISIT_BINARY_ARITH: MulExp, DivExp, ModExp).
 *
 * VISIT_ADD_EXP:
 *  - if the left operand's unreferenced type is a PointerType, the node is
 *    lowered with pointerArith(left, right);
 *  - else, if the right operand's is, with pointerArith(right, left);
 *  - otherwise (presumably; the `else` line is missing from this copy),
 *    homogenizeTypes(node) is called.
 * VISIT_BINARY_ARITH: the operands are brought to a common type with
 * homogenizeTypes(node).
 *
 * NOTE(review): SubExp shares the pointer path with AddExp, but
 * pointerArith() only ever builds an AddExp — check pointer subtraction.
 * NOTE(review): several lines of both macros are missing from this copy
 * (return type, braces, `else`, possibly a visit of the children), so they
 * cannot be checked here.
 */
241 #define VISIT_ADD_EXP(Type) \
243 TypeCheckerVisitor::visit(Type & node) \
247 if (unreferencedType(node.left->type)->nodeType == \
248 PointerType::nodeTypeId()) \
249 pointerArith(node.left, node.right); \
250 else if (unreferencedType(node.right->type)->nodeType == \
251 PointerType::nodeTypeId()) \
252 pointerArith(node.right, node.left); \
254 homogenizeTypes(node); \
257 #define VISIT_BINARY_ARITH(Type) \
259 TypeCheckerVisitor::visit(Type & node) \
263 homogenizeTypes(node); \
266 VISIT_ADD_EXP(AddExp
)
267 VISIT_ADD_EXP(SubExp
)
268 VISIT_BINARY_ARITH(MulExp
)
269 VISIT_BINARY_ARITH(DivExp
)
270 VISIT_BINARY_ARITH(ModExp
)
/*
 * Visits for bitwise and shift expressions (AndExp, OrExp, XorExp, ShlExp,
 * AShrExp, LShrExp). Neither operand may have a floating-point type
 * (checked with assert_msg). The operands are then brought to a common type
 * with homogenizeTypes(node).
 *
 * NOTE(review): the assert message reads "throw and error here", a typo for
 * "an error". It is a runtime string and is left unchanged here.
 * NOTE(review): for shifts, homogenizeTypes() may convert the shifted value
 * to the shift amount's type, or the reverse — confirm this is intended.
 */
272 #define VISIT_BITWISE_EXP(Type) \
274 TypeCheckerVisitor::visit(Type & node) \
279 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.left->type) && \
280 !AST_IS_FLOATING_POINT_TYPE(node.right->type), \
281 "throw and error here, can't do bitwise operation on float."); \
282 homogenizeTypes(node); \
285 VISIT_BITWISE_EXP(AndExp
)
286 VISIT_BITWISE_EXP(OrExp
)
287 VISIT_BITWISE_EXP(XorExp
)
289 VISIT_BITWISE_EXP(ShlExp
)
290 VISIT_BITWISE_EXP(AShrExp
)
291 VISIT_BITWISE_EXP(LShrExp
)
/*
 * Visits for comparison expressions (EqExp, NeqExp, LtExp, LtEqExp, GtExp,
 * GtEqExp). The operands are brought to a common type with
 * homogenizeTypes(node), and then the expression's type is set to
 * NodeFactory::createBoolType().
 *
 * The order matters: homogenizeTypes() itself writes node.type, so the bool
 * type must be assigned after it.
 *
 * NOTE(review): the "@todo look for an overloaded operator '+'" line inside
 * the macro appears copy-pasted from the arithmetic case. It should name the
 * comparison operator.
 */
293 #define VISIT_BINARY_CMP_EXP(Type) \
295 TypeCheckerVisitor::visit(Type & node) \
300 * @todo check if type match \
301 * @todo look for an overloaded operator '+' \
302 * @todo the cast can be applied on both nodes \
304 homogenizeTypes(node); \
305 node.type = NodeFactory::createBoolType(); \
308 VISIT_BINARY_CMP_EXP(EqExp
)
309 VISIT_BINARY_CMP_EXP(NeqExp
)
310 VISIT_BINARY_CMP_EXP(LtExp
)
311 VISIT_BINARY_CMP_EXP(LtEqExp
)
312 VISIT_BINARY_CMP_EXP(GtExp
)
313 VISIT_BINARY_CMP_EXP(GtEqExp
)
/*
 * Visits for logical expressions (OrOrExp, AndAndExp):
 *  - the result type is NodeFactory::createBoolType();
 *  - floating-point operands are rejected with assert_msg;
 *  - each operand is converted to that bool type with castToType().
 *
 * NOTE(review): the assert message talks about a "bitwise operation",
 * although these are logical operators. It is a runtime string and is left
 * unchanged here.
 */
315 #define VISIT_BINARY_BOOL_EXP(Type) \
317 TypeCheckerVisitor::visit(Type & node) \
319 node.type = NodeFactory::createBoolType(); \
324 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node.left->type) && \
325 !AST_IS_FLOATING_POINT_TYPE(node.right->type), \
326 "throw and error here, can't do bitwise operation on float."); \
328 node.left = castToType(node.type, node.left); \
329 node.right = castToType(node.type, node.right); \
332 VISIT_BINARY_BOOL_EXP(OrOrExp
)
333 VISIT_BINARY_BOOL_EXP(AndAndExp
)
/**
 * Visit of a NegExp node: the result has the operand's type.
 *
 * NOTE(review): node.exp->type is copied as-is. visit(AtExp &) and the
 * dereference visits strip references with unreferencedType(). If a
 * reference-typed operand can reach here, the result would be a reference
 * type — confirm.
 * NOTE(review): no check that the operand is numeric is visible.
 */
336 TypeCheckerVisitor::visit(NegExp
& node
)
339 node
.type
= node
.exp
->type
;
/**
 * Visit of a NotExp node, a bitwise operation according to its assert
 * message. Floating-point operands are rejected with assert_msg, and the
 * result has the operand's type.
 *
 * NOTE(review): as in visit(NegExp &), the operand's type is copied without
 * unreferencedType(). Confirm that a reference-typed operand cannot make
 * the result a reference type.
 */
343 TypeCheckerVisitor::visit(NotExp
& node
)
346 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node
.exp
->type
),
347 "throw and error here, can't do bitwise operation on float.");
348 node
.type
= node
.exp
->type
;
/**
 * Visit of a BangExp node, presumably logical negation since its result is
 * a bool.
 *
 * Floating-point operands are rejected with assert_msg. The result type is
 * NodeFactory::createBoolType(), and the operand is converted to it with
 * castToType().
 *
 * NOTE(review): the assert message mentions a "bitwise operation", although
 * the result is a bool. It is a runtime string and is left unchanged here.
 */
352 TypeCheckerVisitor::visit(BangExp
& node
)
355 assert_msg(!AST_IS_FLOATING_POINT_TYPE(node
.exp
->type
),
356 "throw and error here, can't do bitwise operation on float.");
357 node
.type
= NodeFactory::createBoolType();
358 node
.exp
= castToType(node
.type
, node
.exp
);
362 TypeCheckerVisitor::visit(ConditionalBranch
& node
)
365 node
.cond
= castToType(NodeFactory::createBoolType(), node
.cond
);