[PATCH] Avoid recursive inline function expansion
[smatch.git] / inline.c
blob1c40dc3e59a40b737d982fbe6c6de654c2732516
1 /*
2 * Sparse - a semantic source parser.
4 * Copyright (C) 2003 Transmeta Corp.
5 * 2003 Linus Torvalds
7 * Licensed under the Open Software License version 1.1
8 */
10 #include <stdlib.h>
11 #include <stdio.h>
13 #include "lib.h"
14 #include "token.h"
15 #include "parse.h"
16 #include "symbol.h"
17 #include "expression.h"
19 static struct expression * dup_expression(struct expression *expr)
21 struct expression *dup = alloc_expression(expr->pos, expr->type);
22 *dup = *expr;
23 return dup;
26 static struct statement * dup_statement(struct statement *stmt)
28 struct statement *dup = alloc_statement(stmt->pos, stmt->type);
29 *dup = *stmt;
30 return dup;
33 static struct symbol *copy_symbol(struct position pos, struct symbol *sym)
35 if (!sym)
36 return sym;
37 if (sym->ctype.modifiers & (MOD_STATIC | MOD_EXTERN | MOD_TOPLEVEL | MOD_INLINE))
38 return sym;
39 if (!sym->replace) {
40 warn(pos, "unreplaced symbol '%s'", show_ident(sym->ident));
41 return sym;
43 return sym->replace;
46 static struct symbol_list *copy_symbol_list(struct symbol_list *src)
48 struct symbol_list *dst = NULL;
49 struct symbol *sym;
51 FOR_EACH_PTR(src, sym) {
52 struct symbol *newsym = copy_symbol(sym->pos, sym);
53 add_symbol(&dst, newsym);
54 } END_FOR_EACH_PTR;
55 return dst;
58 static struct expression * copy_expression(struct expression *expr)
60 if (!expr)
61 return NULL;
63 switch (expr->type) {
65 * EXPR_SYMBOL is the interesting case, we may need to replace the
66 * symbol to the new copy.
68 case EXPR_SYMBOL: {
69 struct symbol *sym = copy_symbol(expr->pos, expr->symbol);
70 if (sym == expr->symbol)
71 break;
72 expr = dup_expression(expr);
73 expr->symbol = sym;
74 break;
77 /* Atomics, never change, just return the expression directly */
78 case EXPR_VALUE:
79 case EXPR_STRING:
80 case EXPR_FVALUE:
81 break;
83 /* Unops: check if the subexpression is unique */
84 case EXPR_PREOP:
85 case EXPR_POSTOP: {
86 struct expression *unop = copy_expression(expr->unop);
87 if (expr->unop == unop)
88 break;
89 expr = dup_expression(expr);
90 expr->unop = unop;
91 break;
94 /* Binops: copy left/right expressions */
95 case EXPR_BINOP:
96 case EXPR_COMMA:
97 case EXPR_COMPARE:
98 case EXPR_LOGICAL:
99 case EXPR_ASSIGNMENT: {
100 struct expression *left = copy_expression(expr->left);
101 struct expression *right = copy_expression(expr->right);
102 if (left == expr->left && right == expr->right)
103 break;
104 expr = dup_expression(expr);
105 expr->left = left;
106 expr->right = right;
107 break;
110 /* Dereference */
111 case EXPR_DEREF: {
112 struct expression *deref = copy_expression(expr->deref);
113 if (deref == expr->deref)
114 break;
115 expr = dup_expression(expr);
116 expr->deref = deref;
117 break;
120 /* Cast/sizeof/__alignof__ */
121 case EXPR_CAST:
122 case EXPR_SIZEOF:
123 case EXPR_ALIGNOF: {
124 struct expression *cast = copy_expression(expr->cast_expression);
125 if (cast == expr->cast_expression)
126 break;
127 expr = dup_expression(expr);
128 expr->cast_expression = cast;
129 break;
132 /* Conditional expression */
133 case EXPR_SELECT:
134 case EXPR_CONDITIONAL: {
135 struct expression *cond = copy_expression(expr->conditional);
136 struct expression *true = copy_expression(expr->cond_true);
137 struct expression *false = copy_expression(expr->cond_false);
138 if (cond == expr->conditional && true == expr->cond_true && false == expr->cond_false)
139 break;
140 expr = dup_expression(expr);
141 expr->conditional = cond;
142 expr->cond_true = true;
143 expr->cond_false = false;
144 break;
147 /* Statement expression */
148 case EXPR_STATEMENT: {
149 struct statement *stmt = alloc_statement(expr->pos, STMT_COMPOUND);
150 copy_statement(expr->statement, stmt);
151 expr = dup_expression(expr);
152 expr->statement = stmt;
153 break;
156 /* Call expression */
157 case EXPR_CALL: {
158 struct expression *fn = copy_expression(expr->fn);
159 struct expression_list *list = expr->args;
160 struct expression *arg;
162 expr = dup_expression(expr);
163 expr->fn = fn;
164 expr->args = NULL;
165 FOR_EACH_PTR(list, arg) {
166 add_expression(&expr->args, copy_expression(arg));
167 } END_FOR_EACH_PTR;
168 break;
171 /* Initializer list statement */
172 case EXPR_INITIALIZER: {
173 struct expression_list *list = expr->expr_list;
174 struct expression *entry;
175 expr = dup_expression(expr);
176 expr->expr_list = NULL;
177 FOR_EACH_PTR(list, entry) {
178 add_expression(&expr->expr_list, copy_expression(entry));
179 } END_FOR_EACH_PTR;
180 break;
183 /* Label in inline function - hmm. */
184 case EXPR_LABEL: {
185 struct symbol *label_symbol = copy_symbol(expr->pos, expr->label_symbol);
186 expr = dup_expression(expr);
187 expr->label_symbol = label_symbol;
188 break;
191 /* Identifier in member dereference is unchanged across a fn copy */
192 /* As is an array index expression */
193 case EXPR_INDEX:
194 case EXPR_IDENTIFIER:
195 break;
197 /* Position in initializer.. */
198 case EXPR_POS: {
199 struct expression *val = copy_expression(expr->init_expr);
200 if (val == expr->init_expr)
201 break;
202 expr = dup_expression(expr);
203 expr->init_expr = val;
204 break;
207 default:
208 warn(expr->pos, "trying to copy expression type %d", expr->type);
210 return expr;
213 void set_replace(struct symbol *old, struct symbol *new)
215 new->replace = old;
216 old->replace = new;
219 void unset_replace(struct symbol *sym)
221 struct symbol *r = sym->replace;
222 if (!r) {
223 warn(sym->pos, "symbol '%s' not replaced?", show_ident(sym->ident));
224 return;
226 r->replace = NULL;
227 sym->replace = NULL;
230 static void unset_replace_list(struct symbol_list *list)
232 struct symbol *sym;
233 FOR_EACH_PTR(list, sym) {
234 unset_replace(sym);
235 } END_FOR_EACH_PTR;
238 static struct statement *copy_one_statement(struct statement *stmt)
240 if (!stmt)
241 return NULL;
242 switch(stmt->type) {
243 case STMT_NONE:
244 break;
245 case STMT_EXPRESSION: {
246 struct expression *expr = copy_expression(stmt->expression);
247 if (expr == stmt->expression)
248 break;
249 stmt = dup_statement(stmt);
250 stmt->expression = expr;
251 break;
253 case STMT_COMPOUND: {
254 struct statement *new = alloc_statement(stmt->pos, STMT_COMPOUND);
255 copy_statement(stmt, new);
256 stmt = new;
257 break;
259 case STMT_IF: {
260 struct expression *cond = stmt->if_conditional;
261 struct statement *true = stmt->if_true;
262 struct statement *false = stmt->if_false;
264 cond = copy_expression(cond);
265 true = copy_one_statement(true);
266 false = copy_one_statement(false);
267 if (stmt->if_conditional == cond &&
268 stmt->if_true == true &&
269 stmt->if_false == false)
270 break;
271 stmt = dup_statement(stmt);
272 stmt->if_conditional = cond;
273 stmt->if_true = true;
274 stmt->if_false = false;
275 break;
277 case STMT_RETURN: {
278 struct expression *retval = copy_expression(stmt->ret_value);
279 struct symbol *sym = copy_symbol(stmt->pos, stmt->ret_target);
281 stmt = dup_statement(stmt);
282 stmt->ret_value = retval;
283 stmt->ret_target = sym;
284 break;
286 case STMT_CASE: {
287 stmt = dup_statement(stmt);
288 stmt->case_label = copy_symbol(stmt->pos, stmt->case_label);
289 stmt->case_expression = copy_expression(stmt->case_expression);
290 stmt->case_to = copy_expression(stmt->case_to);
291 stmt->case_statement = copy_one_statement(stmt->case_statement);
292 break;
294 case STMT_SWITCH: {
295 struct symbol *switch_break = copy_symbol(stmt->pos, stmt->switch_break);
296 struct symbol *switch_case = copy_symbol(stmt->pos, stmt->switch_case);
297 struct expression *expr = copy_expression(stmt->switch_expression);
298 struct statement *switch_stmt = copy_one_statement(stmt->switch_statement);
299 stmt = dup_statement(stmt);
300 stmt->switch_break = switch_break;
301 stmt->switch_case = switch_case;
302 stmt->switch_expression = expr;
303 stmt->switch_statement = switch_stmt;
304 break;
306 case STMT_ITERATOR: {
307 stmt = dup_statement(stmt);
308 stmt->iterator_break = copy_symbol(stmt->pos, stmt->iterator_break);
309 stmt->iterator_continue = copy_symbol(stmt->pos, stmt->iterator_continue);
310 stmt->iterator_syms = copy_symbol_list(stmt->iterator_syms);
312 stmt->iterator_pre_statement = copy_one_statement(stmt->iterator_pre_statement);
313 stmt->iterator_pre_condition = copy_expression(stmt->iterator_pre_condition);
315 stmt->iterator_statement = copy_one_statement(stmt->iterator_statement);
317 stmt->iterator_post_statement = copy_one_statement(stmt->iterator_post_statement);
318 stmt->iterator_post_condition = copy_expression(stmt->iterator_post_condition);
319 break;
321 case STMT_LABEL: {
322 stmt = dup_statement(stmt);
323 stmt->label_identifier = copy_symbol(stmt->pos, stmt->label_identifier);
324 stmt->label_statement = copy_one_statement(stmt->label_statement);
325 break;
327 case STMT_GOTO: {
328 stmt = dup_statement(stmt);
329 stmt->goto_label = copy_symbol(stmt->pos, stmt->goto_label);
330 stmt->goto_expression = copy_expression(stmt->goto_expression);
331 stmt->target_list = copy_symbol_list(stmt->target_list);
332 break;
334 case STMT_ASM: {
335 /* FIXME! */
336 break;
338 default:
339 warn(stmt->pos, "trying to copy statement type %d", stmt->type);
340 break;
342 return stmt;
346 * Copy a stateemnt tree from 'src' to 'dst', where both
347 * source and destination are of type STMT_COMPOUND.
349 * We do this for the tree-level inliner.
351 * This doesn't do the symbol replacement right: it's not
352 * re-entrant.
354 void copy_statement(struct statement *src, struct statement *dst)
356 struct statement *stmt;
357 struct symbol *sym;
359 FOR_EACH_PTR(src->syms, sym) {
360 struct symbol *newsym = copy_symbol(src->pos, sym);
361 newsym->initializer = copy_expression(sym->initializer);
362 add_symbol(&dst->syms, newsym);
363 } END_FOR_EACH_PTR;
365 FOR_EACH_PTR(src->stmts, stmt) {
366 add_statement(&dst->stmts, copy_one_statement(stmt));
367 } END_FOR_EACH_PTR;
369 dst->ret = copy_symbol(src->pos, src->ret);
372 static struct symbol *create_copy_symbol(struct symbol *orig)
374 struct symbol *sym = orig;
375 if (orig) {
376 sym = alloc_symbol(orig->pos, orig->type);
377 *sym = *orig;
378 set_replace(orig, sym);
379 orig = sym;
381 return orig;
384 static struct symbol_list *create_symbol_list(struct symbol_list *src)
386 struct symbol_list *dst = NULL;
387 struct symbol *sym;
389 FOR_EACH_PTR(src, sym) {
390 struct symbol *newsym = create_copy_symbol(sym);
391 add_symbol(&dst, newsym);
392 } END_FOR_EACH_PTR;
393 return dst;
396 int inline_function(struct expression *expr, struct symbol *sym)
398 struct symbol_list * fn_symbol_list;
399 struct symbol *fn = sym->ctype.base_type;
400 struct expression_list *arg_list = expr->args;
401 struct statement *stmt = alloc_statement(expr->pos, STMT_COMPOUND);
402 struct symbol_list *name_list;
403 struct symbol *name;
404 struct expression *arg;
406 if (!fn->stmt) {
407 warn(fn->pos, "marked inline, but without a definition");
408 return 0;
410 if (fn->expanding)
411 return 0;
412 fn->expanding = 1;
414 name_list = fn->arguments;
416 stmt = alloc_statement(expr->pos, STMT_COMPOUND);
418 expr->type = EXPR_STATEMENT;
419 expr->statement = stmt;
420 expr->ctype = fn->ctype.base_type;
422 fn_symbol_list = create_symbol_list(sym->symbol_list);
424 PREPARE_PTR_LIST(name_list, name);
425 FOR_EACH_PTR(arg_list, arg) {
426 struct symbol *a = alloc_symbol(arg->pos, SYM_NODE);
428 a->ctype.base_type = arg->ctype;
429 if (name) {
430 *a = *name;
431 set_replace(name, a);
432 add_symbol(&fn_symbol_list, a);
434 a->initializer = arg;
435 add_symbol(&stmt->syms, a);
437 NEXT_PTR_LIST(name);
438 } END_FOR_EACH_PTR;
439 FINISH_PTR_LIST(name);
441 copy_statement(fn->stmt, stmt);
443 unset_replace_list(fn_symbol_list);
445 evaluate_statement(stmt);
447 fn->expanding = 0;
448 return 1;