math: use comparisons to handle subtraction better
[smatch.git] / inline.c
bloba3002c6bda5b41ebe069f9be31511ed195062833
1 /*
2 * Sparse - a semantic source parser.
4 * Copyright (C) 2003 Transmeta Corp.
5 * 2003-2004 Linus Torvalds
7 * Permission is hereby granted, free of charge, to any person obtaining a copy
8 * of this software and associated documentation files (the "Software"), to deal
9 * in the Software without restriction, including without limitation the rights
10 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 * copies of the Software, and to permit persons to whom the Software is
12 * furnished to do so, subject to the following conditions:
14 * The above copyright notice and this permission notice shall be included in
15 * all copies or substantial portions of the Software.
17 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23 * THE SOFTWARE.
26 #include <stdlib.h>
27 #include <stdio.h>
29 #include "lib.h"
30 #include "allocate.h"
31 #include "token.h"
32 #include "parse.h"
33 #include "symbol.h"
34 #include "expression.h"
36 static struct expression * dup_expression(struct expression *expr)
38 struct expression *dup = alloc_expression(expr->pos, expr->type);
39 *dup = *expr;
40 return dup;
43 static struct statement * dup_statement(struct statement *stmt)
45 struct statement *dup = alloc_statement(stmt->pos, stmt->type);
46 *dup = *stmt;
47 return dup;
50 static struct symbol *copy_symbol(struct position pos, struct symbol *sym)
52 if (!sym)
53 return sym;
54 if (sym->ctype.modifiers & (MOD_STATIC | MOD_EXTERN | MOD_TOPLEVEL | MOD_INLINE))
55 return sym;
56 if (!sym->replace) {
57 warning(pos, "unreplaced symbol '%s'", show_ident(sym->ident));
58 return sym;
60 return sym->replace;
63 static struct symbol_list *copy_symbol_list(struct symbol_list *src)
65 struct symbol_list *dst = NULL;
66 struct symbol *sym;
68 FOR_EACH_PTR(src, sym) {
69 struct symbol *newsym = copy_symbol(sym->pos, sym);
70 add_symbol(&dst, newsym);
71 } END_FOR_EACH_PTR(sym);
72 return dst;
75 static struct expression * copy_expression(struct expression *expr)
77 if (!expr)
78 return NULL;
80 switch (expr->type) {
82 * EXPR_SYMBOL is the interesting case, we may need to replace the
83 * symbol to the new copy.
85 case EXPR_SYMBOL: {
86 struct symbol *sym = copy_symbol(expr->pos, expr->symbol);
87 if (sym == expr->symbol)
88 break;
89 expr = dup_expression(expr);
90 expr->symbol = sym;
91 break;
94 /* Atomics, never change, just return the expression directly */
95 case EXPR_VALUE:
96 case EXPR_STRING:
97 case EXPR_FVALUE:
98 case EXPR_TYPE:
99 break;
101 /* Unops: check if the subexpression is unique */
102 case EXPR_PREOP:
103 case EXPR_POSTOP: {
104 struct expression *unop = copy_expression(expr->unop);
105 if (expr->unop == unop)
106 break;
107 expr = dup_expression(expr);
108 expr->unop = unop;
109 break;
112 case EXPR_SLICE: {
113 struct expression *base = copy_expression(expr->base);
114 expr = dup_expression(expr);
115 expr->base = base;
116 break;
119 /* Binops: copy left/right expressions */
120 case EXPR_BINOP:
121 case EXPR_COMMA:
122 case EXPR_COMPARE:
123 case EXPR_LOGICAL: {
124 struct expression *left = copy_expression(expr->left);
125 struct expression *right = copy_expression(expr->right);
126 if (left == expr->left && right == expr->right)
127 break;
128 expr = dup_expression(expr);
129 expr->left = left;
130 expr->right = right;
131 break;
134 case EXPR_ASSIGNMENT: {
135 struct expression *left = copy_expression(expr->left);
136 struct expression *right = copy_expression(expr->right);
137 if (expr->op == '=' && left == expr->left && right == expr->right)
138 break;
139 expr = dup_expression(expr);
140 expr->left = left;
141 expr->right = right;
142 break;
145 /* Dereference */
146 case EXPR_DEREF: {
147 struct expression *deref = copy_expression(expr->deref);
148 expr = dup_expression(expr);
149 expr->deref = deref;
150 break;
153 /* Cast/sizeof/__alignof__ */
154 case EXPR_CAST:
155 if (expr->cast_expression->type == EXPR_INITIALIZER) {
156 struct expression *cast = expr->cast_expression;
157 struct symbol *sym = expr->cast_type;
158 expr = dup_expression(expr);
159 expr->cast_expression = copy_expression(cast);
160 expr->cast_type = alloc_symbol(sym->pos, sym->type);
161 *expr->cast_type = *sym;
162 break;
164 case EXPR_FORCE_CAST:
165 case EXPR_IMPLIED_CAST:
166 case EXPR_SIZEOF:
167 case EXPR_PTRSIZEOF:
168 case EXPR_ALIGNOF: {
169 struct expression *cast = copy_expression(expr->cast_expression);
170 if (cast == expr->cast_expression)
171 break;
172 expr = dup_expression(expr);
173 expr->cast_expression = cast;
174 break;
177 /* Conditional expression */
178 case EXPR_SELECT:
179 case EXPR_CONDITIONAL: {
180 struct expression *cond = copy_expression(expr->conditional);
181 struct expression *true = copy_expression(expr->cond_true);
182 struct expression *false = copy_expression(expr->cond_false);
183 if (cond == expr->conditional && true == expr->cond_true && false == expr->cond_false)
184 break;
185 expr = dup_expression(expr);
186 expr->conditional = cond;
187 expr->cond_true = true;
188 expr->cond_false = false;
189 break;
192 /* Statement expression */
193 case EXPR_STATEMENT: {
194 struct statement *stmt = alloc_statement(expr->pos, STMT_COMPOUND);
195 copy_statement(expr->statement, stmt);
196 expr = dup_expression(expr);
197 expr->statement = stmt;
198 break;
201 /* Call expression */
202 case EXPR_CALL: {
203 struct expression *fn = copy_expression(expr->fn);
204 struct expression_list *list = expr->args;
205 struct expression *arg;
207 expr = dup_expression(expr);
208 expr->fn = fn;
209 expr->args = NULL;
210 FOR_EACH_PTR(list, arg) {
211 add_expression(&expr->args, copy_expression(arg));
212 } END_FOR_EACH_PTR(arg);
213 break;
216 /* Initializer list statement */
217 case EXPR_INITIALIZER: {
218 struct expression_list *list = expr->expr_list;
219 struct expression *entry;
220 expr = dup_expression(expr);
221 expr->expr_list = NULL;
222 FOR_EACH_PTR(list, entry) {
223 add_expression(&expr->expr_list, copy_expression(entry));
224 } END_FOR_EACH_PTR(entry);
225 break;
228 /* Label in inline function - hmm. */
229 case EXPR_LABEL: {
230 struct symbol *label_symbol = copy_symbol(expr->pos, expr->label_symbol);
231 expr = dup_expression(expr);
232 expr->label_symbol = label_symbol;
233 break;
236 case EXPR_INDEX: {
237 struct expression *sub_expr = copy_expression(expr->idx_expression);
238 expr = dup_expression(expr);
239 expr->idx_expression = sub_expr;
240 break;
243 case EXPR_IDENTIFIER: {
244 struct expression *sub_expr = copy_expression(expr->ident_expression);
245 expr = dup_expression(expr);
246 expr->ident_expression = sub_expr;
247 break;
250 /* Position in initializer.. */
251 case EXPR_POS: {
252 struct expression *val = copy_expression(expr->init_expr);
253 expr = dup_expression(expr);
254 expr->init_expr = val;
255 break;
257 case EXPR_OFFSETOF: {
258 struct expression *val = copy_expression(expr->down);
259 if (expr->op == '.') {
260 if (expr->down != val) {
261 expr = dup_expression(expr);
262 expr->down = val;
264 } else {
265 struct expression *idx = copy_expression(expr->index);
266 if (expr->down != val || expr->index != idx) {
267 expr = dup_expression(expr);
268 expr->down = val;
269 expr->index = idx;
272 break;
274 default:
275 warning(expr->pos, "trying to copy expression type %d", expr->type);
277 return expr;
280 static struct expression_list *copy_asm_constraints(struct expression_list *in)
282 struct expression_list *out = NULL;
283 struct expression *expr;
284 int state = 0;
286 FOR_EACH_PTR(in, expr) {
287 switch (state) {
288 case 0: /* identifier */
289 case 1: /* constraint */
290 state++;
291 add_expression(&out, expr);
292 continue;
293 case 2: /* expression */
294 state = 0;
295 add_expression(&out, copy_expression(expr));
296 continue;
298 } END_FOR_EACH_PTR(expr);
299 return out;
302 static void set_replace(struct symbol *old, struct symbol *new)
304 new->replace = old;
305 old->replace = new;
308 static void unset_replace(struct symbol *sym)
310 struct symbol *r = sym->replace;
311 if (!r) {
312 warning(sym->pos, "symbol '%s' not replaced?", show_ident(sym->ident));
313 return;
315 r->replace = NULL;
316 sym->replace = NULL;
319 static void unset_replace_list(struct symbol_list *list)
321 struct symbol *sym;
322 FOR_EACH_PTR(list, sym) {
323 unset_replace(sym);
324 } END_FOR_EACH_PTR(sym);
327 static struct statement *copy_one_statement(struct statement *stmt)
329 if (!stmt)
330 return NULL;
331 switch(stmt->type) {
332 case STMT_NONE:
333 break;
334 case STMT_DECLARATION: {
335 struct symbol *sym;
336 struct statement *newstmt = dup_statement(stmt);
337 newstmt->declaration = NULL;
338 FOR_EACH_PTR(stmt->declaration, sym) {
339 struct symbol *newsym = copy_symbol(stmt->pos, sym);
340 if (newsym != sym)
341 newsym->initializer = copy_expression(sym->initializer);
342 add_symbol(&newstmt->declaration, newsym);
343 } END_FOR_EACH_PTR(sym);
344 stmt = newstmt;
345 break;
347 case STMT_CONTEXT:
348 case STMT_EXPRESSION: {
349 struct expression *expr = copy_expression(stmt->expression);
350 if (expr == stmt->expression)
351 break;
352 stmt = dup_statement(stmt);
353 stmt->expression = expr;
354 break;
356 case STMT_RANGE: {
357 struct expression *expr = copy_expression(stmt->range_expression);
358 if (expr == stmt->expression)
359 break;
360 stmt = dup_statement(stmt);
361 stmt->range_expression = expr;
362 break;
364 case STMT_COMPOUND: {
365 struct statement *new = alloc_statement(stmt->pos, STMT_COMPOUND);
366 copy_statement(stmt, new);
367 stmt = new;
368 break;
370 case STMT_IF: {
371 struct expression *cond = stmt->if_conditional;
372 struct statement *true = stmt->if_true;
373 struct statement *false = stmt->if_false;
375 cond = copy_expression(cond);
376 true = copy_one_statement(true);
377 false = copy_one_statement(false);
378 if (stmt->if_conditional == cond &&
379 stmt->if_true == true &&
380 stmt->if_false == false)
381 break;
382 stmt = dup_statement(stmt);
383 stmt->if_conditional = cond;
384 stmt->if_true = true;
385 stmt->if_false = false;
386 break;
388 case STMT_RETURN: {
389 struct expression *retval = copy_expression(stmt->ret_value);
390 struct symbol *sym = copy_symbol(stmt->pos, stmt->ret_target);
392 stmt = dup_statement(stmt);
393 stmt->ret_value = retval;
394 stmt->ret_target = sym;
395 break;
397 case STMT_CASE: {
398 stmt = dup_statement(stmt);
399 stmt->case_label = copy_symbol(stmt->pos, stmt->case_label);
400 stmt->case_label->stmt = stmt;
401 stmt->case_expression = copy_expression(stmt->case_expression);
402 stmt->case_to = copy_expression(stmt->case_to);
403 stmt->case_statement = copy_one_statement(stmt->case_statement);
404 break;
406 case STMT_SWITCH: {
407 struct symbol *switch_break = copy_symbol(stmt->pos, stmt->switch_break);
408 struct symbol *switch_case = copy_symbol(stmt->pos, stmt->switch_case);
409 struct expression *expr = copy_expression(stmt->switch_expression);
410 struct statement *switch_stmt = copy_one_statement(stmt->switch_statement);
412 stmt = dup_statement(stmt);
413 switch_case->symbol_list = copy_symbol_list(switch_case->symbol_list);
414 stmt->switch_break = switch_break;
415 stmt->switch_case = switch_case;
416 stmt->switch_expression = expr;
417 stmt->switch_statement = switch_stmt;
418 break;
420 case STMT_ITERATOR: {
421 stmt = dup_statement(stmt);
422 stmt->iterator_break = copy_symbol(stmt->pos, stmt->iterator_break);
423 stmt->iterator_continue = copy_symbol(stmt->pos, stmt->iterator_continue);
424 stmt->iterator_syms = copy_symbol_list(stmt->iterator_syms);
426 stmt->iterator_pre_statement = copy_one_statement(stmt->iterator_pre_statement);
427 stmt->iterator_pre_condition = copy_expression(stmt->iterator_pre_condition);
429 stmt->iterator_statement = copy_one_statement(stmt->iterator_statement);
431 stmt->iterator_post_statement = copy_one_statement(stmt->iterator_post_statement);
432 stmt->iterator_post_condition = copy_expression(stmt->iterator_post_condition);
433 break;
435 case STMT_LABEL: {
436 stmt = dup_statement(stmt);
437 stmt->label_identifier = copy_symbol(stmt->pos, stmt->label_identifier);
438 stmt->label_statement = copy_one_statement(stmt->label_statement);
439 break;
441 case STMT_GOTO: {
442 stmt = dup_statement(stmt);
443 stmt->goto_label = copy_symbol(stmt->pos, stmt->goto_label);
444 stmt->goto_expression = copy_expression(stmt->goto_expression);
445 stmt->target_list = copy_symbol_list(stmt->target_list);
446 break;
448 case STMT_ASM: {
449 stmt = dup_statement(stmt);
450 stmt->asm_inputs = copy_asm_constraints(stmt->asm_inputs);
451 stmt->asm_outputs = copy_asm_constraints(stmt->asm_outputs);
452 /* no need to dup "clobbers", since they are all constant strings */
453 break;
455 default:
456 warning(stmt->pos, "trying to copy statement type %d", stmt->type);
457 break;
459 return stmt;
463 * Copy a statement tree from 'src' to 'dst', where both
464 * source and destination are of type STMT_COMPOUND.
466 * We do this for the tree-level inliner.
468 * This doesn't do the symbol replacement right: it's not
469 * re-entrant.
471 void copy_statement(struct statement *src, struct statement *dst)
473 struct statement *stmt;
475 FOR_EACH_PTR(src->stmts, stmt) {
476 add_statement(&dst->stmts, copy_one_statement(stmt));
477 } END_FOR_EACH_PTR(stmt);
478 dst->args = copy_one_statement(src->args);
479 dst->ret = copy_symbol(src->pos, src->ret);
480 dst->inline_fn = src->inline_fn;
483 static struct symbol *create_copy_symbol(struct symbol *orig)
485 struct symbol *sym = orig;
486 if (orig) {
487 sym = alloc_symbol(orig->pos, orig->type);
488 *sym = *orig;
489 sym->bb_target = NULL;
490 sym->pseudo = NULL;
491 set_replace(orig, sym);
492 orig = sym;
494 return orig;
497 static struct symbol_list *create_symbol_list(struct symbol_list *src)
499 struct symbol_list *dst = NULL;
500 struct symbol *sym;
502 FOR_EACH_PTR(src, sym) {
503 struct symbol *newsym = create_copy_symbol(sym);
504 add_symbol(&dst, newsym);
505 } END_FOR_EACH_PTR(sym);
506 return dst;
509 int inline_function(struct expression *expr, struct symbol *sym)
511 struct symbol_list * fn_symbol_list;
512 struct symbol *fn = sym->ctype.base_type;
513 struct expression_list *arg_list = expr->args;
514 struct statement *stmt = alloc_statement(expr->pos, STMT_COMPOUND);
515 struct symbol_list *name_list, *arg_decl;
516 struct symbol *name;
517 struct expression *arg;
519 if (!fn->inline_stmt) {
520 sparse_error(fn->pos, "marked inline, but without a definition");
521 return 0;
523 if (fn->expanding)
524 return 0;
526 fn->expanding = 1;
528 name_list = fn->arguments;
530 expr->type = EXPR_STATEMENT;
531 expr->statement = stmt;
532 expr->ctype = fn->ctype.base_type;
534 fn_symbol_list = create_symbol_list(sym->inline_symbol_list);
536 arg_decl = NULL;
537 PREPARE_PTR_LIST(name_list, name);
538 FOR_EACH_PTR(arg_list, arg) {
539 struct symbol *a = alloc_symbol(arg->pos, SYM_NODE);
541 a->ctype.base_type = arg->ctype;
542 if (name) {
543 *a = *name;
544 set_replace(name, a);
545 add_symbol(&fn_symbol_list, a);
547 a->initializer = arg;
548 add_symbol(&arg_decl, a);
550 NEXT_PTR_LIST(name);
551 } END_FOR_EACH_PTR(arg);
552 FINISH_PTR_LIST(name);
554 copy_statement(fn->inline_stmt, stmt);
556 if (arg_decl) {
557 struct statement *decl = alloc_statement(expr->pos, STMT_DECLARATION);
558 decl->declaration = arg_decl;
559 stmt->args = decl;
561 stmt->inline_fn = sym;
563 unset_replace_list(fn_symbol_list);
565 evaluate_statement(stmt);
567 fn->expanding = 0;
568 return 1;
571 void uninline(struct symbol *sym)
573 struct symbol *fn = sym->ctype.base_type;
574 struct symbol_list *arg_list = fn->arguments;
575 struct symbol *p;
577 sym->symbol_list = create_symbol_list(sym->inline_symbol_list);
578 FOR_EACH_PTR(arg_list, p) {
579 p->replace = p;
580 } END_FOR_EACH_PTR(p);
581 fn->stmt = alloc_statement(fn->pos, STMT_COMPOUND);
582 copy_statement(fn->inline_stmt, fn->stmt);
583 unset_replace_list(sym->symbol_list);
584 unset_replace_list(arg_list);