comparison: improve "foo = min(...);" assignment handling
[smatch.git] / smatch_type.c
blobaf59f3900097052378775336c7ed3190e9189720
1 /*
2 * Copyright (C) 2009 Dan Carpenter.
4 * This program is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU General Public License
6 * as published by the Free Software Foundation; either version 2
7 * of the License, or (at your option) any later version.
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, see http://www.gnu.org/copyleft/gpl.txt
19 * The idea here is that you have an expression and you
20 * want to know what the type is for that.
23 #include "smatch.h"
24 #include "smatch_slist.h"
26 struct symbol *get_real_base_type(struct symbol *sym)
28 struct symbol *ret;
30 if (!sym)
31 return NULL;
32 ret = get_base_type(sym);
33 if (!ret)
34 return NULL;
35 if (ret->type == SYM_RESTRICT || ret->type == SYM_NODE)
36 return get_real_base_type(ret);
37 return ret;
40 int type_bytes(struct symbol *type)
42 int bits;
44 if (type && type->type == SYM_ARRAY)
45 return array_bytes(type);
47 bits = type_bits(type);
48 if (bits < 0)
49 return 0;
50 return bits_to_bytes(bits);
53 int array_bytes(struct symbol *type)
55 if (!type || type->type != SYM_ARRAY)
56 return 0;
57 return bits_to_bytes(type->bit_size);
60 static struct symbol *get_binop_type(struct expression *expr)
62 struct symbol *left, *right;
64 left = get_type(expr->left);
65 if (!left)
66 return NULL;
68 if (expr->op == SPECIAL_LEFTSHIFT ||
69 expr->op == SPECIAL_RIGHTSHIFT) {
70 if (type_positive_bits(left) < 31)
71 return &int_ctype;
72 return left;
74 right = get_type(expr->right);
75 if (!right)
76 return NULL;
78 if (left->type == SYM_PTR || left->type == SYM_ARRAY)
79 return left;
80 if (right->type == SYM_PTR || right->type == SYM_ARRAY)
81 return right;
83 if (type_positive_bits(left) < 31 && type_positive_bits(right) < 31)
84 return &int_ctype;
86 if (type_positive_bits(left) > type_positive_bits(right))
87 return left;
88 return right;
91 static struct symbol *get_type_symbol(struct expression *expr)
93 if (!expr || expr->type != EXPR_SYMBOL || !expr->symbol)
94 return NULL;
96 return get_real_base_type(expr->symbol);
99 static struct symbol *get_member_symbol(struct symbol_list *symbol_list, struct ident *member)
101 struct symbol *tmp, *sub;
103 FOR_EACH_PTR(symbol_list, tmp) {
104 if (!tmp->ident) {
105 sub = get_real_base_type(tmp);
106 sub = get_member_symbol(sub->symbol_list, member);
107 if (sub)
108 return sub;
109 continue;
111 if (tmp->ident == member)
112 return tmp;
113 } END_FOR_EACH_PTR(tmp);
115 return NULL;
118 static struct symbol *get_symbol_from_deref(struct expression *expr)
120 struct ident *member;
121 struct symbol *sym;
123 if (!expr || expr->type != EXPR_DEREF)
124 return NULL;
126 member = expr->member;
127 sym = get_type(expr->deref);
128 if (!sym) {
129 // sm_msg("could not find struct type");
130 return NULL;
132 if (sym->type == SYM_PTR)
133 sym = get_real_base_type(sym);
134 sym = get_member_symbol(sym->symbol_list, member);
135 if (!sym)
136 return NULL;
137 return get_real_base_type(sym);
140 static struct symbol *get_return_type(struct expression *expr)
142 struct symbol *tmp;
144 tmp = get_type(expr->fn);
145 if (!tmp)
146 return NULL;
147 /* this is to handle __builtin_constant_p() */
148 if (tmp->type != SYM_FN)
149 tmp = get_base_type(tmp);
150 return get_real_base_type(tmp);
153 static struct symbol *get_expr_stmt_type(struct statement *stmt)
155 if (stmt->type != STMT_COMPOUND)
156 return NULL;
157 stmt = last_ptr_list((struct ptr_list *)stmt->stmts);
158 if (stmt->type == STMT_LABEL)
159 stmt = stmt->label_statement;
160 if (stmt->type != STMT_EXPRESSION)
161 return NULL;
162 return get_type(stmt->expression);
165 static struct symbol *get_select_type(struct expression *expr)
167 struct symbol *one, *two;
169 one = get_type(expr->cond_true);
170 two = get_type(expr->cond_false);
171 if (!one || !two)
172 return NULL;
174 * This is a hack. If the types are not equiv then we
175 * really don't know the type. But I think guessing is
176 * probably Ok here.
178 if (type_positive_bits(one) > type_positive_bits(two))
179 return one;
180 return two;
183 struct symbol *get_pointer_type(struct expression *expr)
185 struct symbol *sym;
187 sym = get_type(expr);
188 if (!sym)
189 return NULL;
190 if (sym->type == SYM_NODE) {
191 sym = get_real_base_type(sym);
192 if (!sym)
193 return NULL;
195 if (sym->type != SYM_PTR && sym->type != SYM_ARRAY)
196 return NULL;
197 return get_real_base_type(sym);
200 static struct symbol *fake_pointer_sym(struct expression *expr)
202 struct symbol *sym;
203 struct symbol *base;
205 sym = alloc_symbol(expr->pos, SYM_PTR);
206 expr = expr->unop;
207 base = get_type(expr);
208 if (!base)
209 return NULL;
210 sym->ctype.base_type = base;
211 return sym;
214 static struct symbol *get_type_helper(struct expression *expr)
216 struct symbol *ret;
218 expr = strip_parens(expr);
219 if (!expr)
220 return NULL;
222 if (expr->ctype)
223 return expr->ctype;
225 switch (expr->type) {
226 case EXPR_STRING:
227 ret = &string_ctype;
228 break;
229 case EXPR_SYMBOL:
230 ret = get_type_symbol(expr);
231 break;
232 case EXPR_DEREF:
233 ret = get_symbol_from_deref(expr);
234 break;
235 case EXPR_PREOP:
236 case EXPR_POSTOP:
237 if (expr->op == '&')
238 ret = fake_pointer_sym(expr);
239 else if (expr->op == '*')
240 ret = get_pointer_type(expr->unop);
241 else
242 ret = get_type(expr->unop);
243 break;
244 case EXPR_ASSIGNMENT:
245 ret = get_type(expr->left);
246 break;
247 case EXPR_CAST:
248 case EXPR_FORCE_CAST:
249 case EXPR_IMPLIED_CAST:
250 ret = get_real_base_type(expr->cast_type);
251 break;
252 case EXPR_COMPARE:
253 case EXPR_BINOP:
254 ret = get_binop_type(expr);
255 break;
256 case EXPR_CALL:
257 ret = get_return_type(expr);
258 break;
259 case EXPR_STATEMENT:
260 ret = get_expr_stmt_type(expr->statement);
261 break;
262 case EXPR_CONDITIONAL:
263 case EXPR_SELECT:
264 ret = get_select_type(expr);
265 break;
266 case EXPR_SIZEOF:
267 ret = &ulong_ctype;
268 break;
269 case EXPR_LOGICAL:
270 ret = &int_ctype;
271 break;
272 default:
273 return NULL;
276 if (ret && ret->type == SYM_TYPEOF)
277 ret = get_type(ret->initializer);
279 expr->ctype = ret;
280 return ret;
/*
 * get_final_type_helper - result type of @expr in the cases where it
 * differs from the type the operation is performed in.
 *
 * Returns NULL when there is no special case; get_final_type() then falls
 * back to get_type_helper().
 *
 * NOTE(review): as written, this function unconditionally returns NULL.
 * The early "return NULL;" below makes the rest of the body unreachable,
 * so the EXPR_COMPARE and pointer-subtraction cases are currently disabled.
 * This looks deliberate; confirm before re-enabling.
 */
static struct symbol *get_final_type_helper(struct expression *expr)
{
	/*
	 * I'm not totally positive I understand types...
	 *
	 * So, when you're doing pointer math, and you do a subtraction, then
	 * the sval_binop() and whatever need to know the type of the pointer
	 * so they can figure out the alignment. But the result is going to be
	 * an ssize_t. So get_operation_type() gives you the pointer type
	 * and get_type() gives you ssize_t.
	 *
	 * Most of the time the operation type and the final type are the same
	 * but this just handles the few places where they are different.
	 *
	 */

	/* short-circuits everything below (see NOTE above) */
	return NULL;

	if (!expr)
		return NULL;

	switch (expr->type) {
	case EXPR_COMPARE:
		/* a comparison always yields int */
		return &int_ctype;
	case EXPR_BINOP: {
		struct symbol *left, *right;

		if (expr->op != '-')
			return NULL;

		/* pointer - pointer yields ssize_t */
		left = get_type(expr->left);
		right = get_type(expr->right);
		if (type_is_ptr(left) && type_is_ptr(right))
			return ssize_t_ctype;
		}
	}

	return NULL;
}
/*
 * get_type - public entry point: the type of @expr, or NULL if unknown.
 * The result is computed (and cached on the expression) by get_type_helper().
 */
struct symbol *get_type(struct expression *expr)
{
	struct symbol *type;

	type = get_type_helper(expr);
	return type;
}
/*
 * get_final_type - the type of @expr's result.
 *
 * Uses get_final_type_helper() when it reports a special case, and
 * otherwise falls back to the ordinary type from get_type_helper().
 */
struct symbol *get_final_type(struct expression *expr)
{
	struct symbol *final = get_final_type_helper(expr);

	return final ? final : get_type_helper(expr);
}
338 struct symbol *get_promoted_type(struct symbol *left, struct symbol *right)
340 struct symbol *ret = &int_ctype;
342 if (type_positive_bits(left) > type_positive_bits(ret))
343 ret = left;
344 if (type_positive_bits(right) > type_positive_bits(ret))
345 ret = right;
347 if (type_is_ptr(left))
348 ret = left;
349 if (type_is_ptr(right))
350 ret = right;
352 return ret;
355 int type_signed(struct symbol *base_type)
357 if (!base_type)
358 return 0;
359 if (base_type->ctype.modifiers & MOD_SIGNED)
360 return 1;
361 return 0;
/*
 * expr_unsigned - 1 if @expr's type is known and unsigned, 0 otherwise.
 */
int expr_unsigned(struct expression *expr)
{
	struct symbol *type = get_type(expr);

	return type && type_unsigned(type);
}
/*
 * expr_signed - 1 if @expr's type is known and signed, 0 otherwise.
 */
int expr_signed(struct expression *expr)
{
	struct symbol *type = get_type(expr);

	return type && type_signed(type);
}
388 int returns_unsigned(struct symbol *sym)
390 if (!sym)
391 return 0;
392 sym = get_base_type(sym);
393 if (!sym || sym->type != SYM_FN)
394 return 0;
395 sym = get_base_type(sym);
396 return type_unsigned(sym);
399 int is_pointer(struct expression *expr)
401 struct symbol *sym;
403 sym = get_type(expr);
404 if (!sym)
405 return 0;
406 if (sym == &string_ctype)
407 return 0;
408 if (sym->type == SYM_PTR)
409 return 1;
410 return 0;
413 int returns_pointer(struct symbol *sym)
415 if (!sym)
416 return 0;
417 sym = get_base_type(sym);
418 if (!sym || sym->type != SYM_FN)
419 return 0;
420 sym = get_base_type(sym);
421 if (sym->type == SYM_PTR)
422 return 1;
423 return 0;
426 sval_t sval_type_max(struct symbol *base_type)
428 sval_t ret;
430 if (!base_type || !type_bits(base_type))
431 base_type = &llong_ctype;
432 ret.type = base_type;
434 ret.value = (~0ULL) >> (64 - type_positive_bits(base_type));
435 return ret;
438 sval_t sval_type_min(struct symbol *base_type)
440 sval_t ret;
442 if (!base_type || !type_bits(base_type))
443 base_type = &llong_ctype;
444 ret.type = base_type;
446 if (type_unsigned(base_type)) {
447 ret.value = 0;
448 return ret;
451 ret.value = (~0ULL) << type_positive_bits(base_type);
453 return ret;
/*
 * nr_bits - width in bits of @expr's type, or 0 if the type is unknown.
 */
int nr_bits(struct expression *expr)
{
	struct symbol *type = get_type(expr);

	return type ? type_bits(type) : 0;
}
466 int is_void_pointer(struct expression *expr)
468 struct symbol *type;
470 type = get_type(expr);
471 if (!type || type->type != SYM_PTR)
472 return 0;
473 type = get_real_base_type(type);
474 if (type == &void_ctype)
475 return 1;
476 return 0;
479 int is_char_pointer(struct expression *expr)
481 struct symbol *type;
483 type = get_type(expr);
484 if (!type || type->type != SYM_PTR)
485 return 0;
486 type = get_real_base_type(type);
487 if (type == &char_ctype)
488 return 1;
489 return 0;
492 int is_string(struct expression *expr)
494 expr = strip_expr(expr);
495 if (!expr || expr->type != EXPR_STRING)
496 return 0;
497 if (expr->string)
498 return 1;
499 return 0;
502 int is_static(struct expression *expr)
504 char *name;
505 struct symbol *sym;
506 int ret = 0;
508 name = expr_to_str_sym(expr, &sym);
509 if (!name || !sym)
510 goto free;
512 if (sym->ctype.modifiers & MOD_STATIC)
513 ret = 1;
514 free:
515 free_string(name);
516 return ret;
519 int is_local_variable(struct expression *expr)
521 struct symbol *sym;
522 char *name;
524 name = expr_to_var_sym(expr, &sym);
525 free_string(name);
526 if (!sym || !sym->scope || !sym->scope->token || !cur_func_sym)
527 return 0;
528 if (cmp_pos(sym->scope->token->pos, cur_func_sym->pos) < 0)
529 return 0;
530 if (is_static(expr))
531 return 0;
532 return 1;
535 int types_equiv(struct symbol *one, struct symbol *two)
537 if (!one && !two)
538 return 1;
539 if (!one || !two)
540 return 0;
541 if (one->type != two->type)
542 return 0;
543 if (one->type == SYM_PTR)
544 return types_equiv(get_real_base_type(one), get_real_base_type(two));
545 if (type_positive_bits(one) != type_positive_bits(two))
546 return 0;
547 return 1;
550 int fn_static(void)
552 return !!(cur_func_sym->ctype.modifiers & MOD_STATIC);
555 const char *global_static(void)
557 if (cur_func_sym->ctype.modifiers & MOD_STATIC)
558 return "static";
559 else
560 return "global";
563 struct symbol *cur_func_return_type(void)
565 struct symbol *sym;
567 sym = get_real_base_type(cur_func_sym);
568 if (!sym || sym->type != SYM_FN)
569 return NULL;
570 sym = get_real_base_type(sym);
571 return sym;
574 struct symbol *get_arg_type(struct expression *fn, int arg)
576 struct symbol *fn_type;
577 struct symbol *tmp;
578 struct symbol *arg_type;
579 int i;
581 fn_type = get_type(fn);
582 if (!fn_type)
583 return NULL;
584 if (fn_type->type == SYM_PTR)
585 fn_type = get_real_base_type(fn_type);
586 if (fn_type->type != SYM_FN)
587 return NULL;
589 i = 0;
590 FOR_EACH_PTR(fn_type->arguments, tmp) {
591 arg_type = get_real_base_type(tmp);
592 if (i == arg) {
593 return arg_type;
595 i++;
596 } END_FOR_EACH_PTR(tmp);
598 return NULL;
601 static struct symbol *get_member_from_string(struct symbol_list *symbol_list, const char *name)
603 struct symbol *tmp, *sub;
604 int chunk_len;
606 if (strncmp(name, ".", 1) == 0)
607 name += 1;
608 if (strncmp(name, "->", 2) == 0)
609 name += 2;
611 FOR_EACH_PTR(symbol_list, tmp) {
612 if (!tmp->ident) {
613 sub = get_real_base_type(tmp);
614 sub = get_member_from_string(sub->symbol_list, name);
615 if (sub)
616 return sub;
617 continue;
620 if (strcmp(tmp->ident->name, name) == 0)
621 return tmp;
623 chunk_len = strlen(tmp->ident->name);
624 if (strncmp(tmp->ident->name, name, chunk_len) == 0 &&
625 (name[chunk_len] == '.' || name[chunk_len] == '-')) {
626 sub = get_real_base_type(tmp);
627 return get_member_from_string(sub->symbol_list, name + chunk_len);
630 } END_FOR_EACH_PTR(tmp);
632 return NULL;
635 struct symbol *get_member_type_from_key(struct expression *expr, const char *key)
637 struct symbol *sym;
639 if (strcmp(key, "$") == 0)
640 return get_type(expr);
642 if (strcmp(key, "*$") == 0) {
643 sym = get_type(expr);
644 if (!sym || sym->type != SYM_PTR)
645 return NULL;
646 return get_real_base_type(sym);
649 sym = get_type(expr);
650 if (!sym)
651 return NULL;
652 if (sym->type == SYM_PTR)
653 sym = get_real_base_type(sym);
655 key = key + 1;
656 sym = get_member_from_string(sym->symbol_list, key);
657 if (!sym)
658 return NULL;
659 return get_real_base_type(sym);
662 int is_struct(struct expression *expr)
664 struct symbol *type;
666 type = get_type(expr);
667 if (type && type->type == SYM_STRUCT)
668 return 1;
669 return 0;
672 static struct {
673 struct symbol *sym;
674 const char *name;
675 } base_types[] = {
676 {&bool_ctype, "bool"},
677 {&void_ctype, "void"},
678 {&type_ctype, "type"},
679 {&char_ctype, "char"},
680 {&schar_ctype, "schar"},
681 {&uchar_ctype, "uchar"},
682 {&short_ctype, "short"},
683 {&sshort_ctype, "sshort"},
684 {&ushort_ctype, "ushort"},
685 {&int_ctype, "int"},
686 {&sint_ctype, "sint"},
687 {&uint_ctype, "uint"},
688 {&long_ctype, "long"},
689 {&slong_ctype, "slong"},
690 {&ulong_ctype, "ulong"},
691 {&llong_ctype, "llong"},
692 {&sllong_ctype, "sllong"},
693 {&ullong_ctype, "ullong"},
694 {&lllong_ctype, "lllong"},
695 {&slllong_ctype, "slllong"},
696 {&ulllong_ctype, "ulllong"},
697 {&float_ctype, "float"},
698 {&double_ctype, "double"},
699 {&ldouble_ctype, "ldouble"},
700 {&string_ctype, "string"},
701 {&ptr_ctype, "ptr"},
702 {&lazy_ptr_ctype, "lazy_ptr"},
703 {&incomplete_ctype, "incomplete"},
704 {&label_ctype, "label"},
705 {&bad_ctype, "bad"},
706 {&null_ctype, "null"},
709 static const char *base_type_str(struct symbol *sym)
711 int i;
713 for (i = 0; i < ARRAY_SIZE(base_types); i++) {
714 if (sym == base_types[i].sym)
715 return base_types[i].name;
717 return "<unknown>";
720 static int type_str_helper(char *buf, int size, struct symbol *type)
722 int n;
724 if (!type)
725 return snprintf(buf, size, "<unknown>");
727 if (type->type == SYM_BASETYPE) {
728 return snprintf(buf, size, base_type_str(type));
729 } else if (type->type == SYM_PTR) {
730 type = get_real_base_type(type);
731 n = type_str_helper(buf, size, type);
732 if (n > size)
733 return n;
734 return n + snprintf(buf + n, size - n, "*");
735 } else if (type->type == SYM_ARRAY) {
736 type = get_real_base_type(type);
737 n = type_str_helper(buf, size, type);
738 if (n > size)
739 return n;
740 return n + snprintf(buf + n, size - n, "[]");
741 } else if (type->type == SYM_STRUCT) {
742 return snprintf(buf, size, "struct %s", type->ident ? type->ident->name : "");
743 } else if (type->type == SYM_UNION) {
744 if (type->ident)
745 return snprintf(buf, size, "union %s", type->ident->name);
746 else
747 return snprintf(buf, size, "anonymous union");
748 } else if (type->type == SYM_FN) {
749 struct symbol *arg, *return_type, *arg_type;
750 int i;
752 return_type = get_real_base_type(type);
753 n = type_str_helper(buf, size, return_type);
754 if (n > size)
755 return n;
756 n += snprintf(buf + n, size - n, "(*)(");
757 if (n > size)
758 return n;
760 i = 0;
761 FOR_EACH_PTR(type->arguments, arg) {
762 if (i++)
763 n += snprintf(buf + n, size - n, ", ");
764 if (n > size)
765 return n;
766 arg_type = get_real_base_type(arg);
767 n += type_str_helper(buf + n, size - n, arg_type);
768 if (n > size)
769 return n;
770 } END_FOR_EACH_PTR(arg);
772 return n + snprintf(buf + n, size - n, ")");
773 } else if (type->type == SYM_NODE) {
774 n = snprintf(buf, size, "node {");
775 if (n > size)
776 return n;
777 type = get_real_base_type(type);
778 n += type_str_helper(buf + n, size - n, type);
779 if (n > size)
780 return n;
781 return n + snprintf(buf + n, size - n, "}");
782 } else {
783 return snprintf(buf, size, "<type %d>", type->type);
/*
 * type_to_str - human-readable rendering of @type.
 *
 * Returns a pointer to a static 256-byte buffer that is overwritten on
 * every call, so the result must be used or copied before the next call.
 * Not reentrant.  Output longer than the buffer is truncated by snprintf().
 */
char *type_to_str(struct symbol *type)
{
	static char buf[256];

	*buf = '\0';
	(void)type_str_helper(buf, sizeof buf, type);
	return buf;
}