conditions: add a NULL check
[smatch.git] / smatch_type.c
blobc34543c00a1091c0ad3a27a87c2e5378dd4ce530
1 /*
2 * Copyright (C) 2009 Dan Carpenter.
4 * This program is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU General Public License
6 * as published by the Free Software Foundation; either version 2
7 * of the License, or (at your option) any later version.
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, see http://www.gnu.org/copyleft/gpl.txt
19 * The idea here is that you have an expression and you
20 * want to know what the type is for that.
23 #include "smatch.h"
24 #include "smatch_slist.h"
26 struct symbol *get_real_base_type(struct symbol *sym)
28 struct symbol *ret;
30 if (!sym)
31 return NULL;
32 ret = get_base_type(sym);
33 if (!ret)
34 return NULL;
35 if (ret->type == SYM_RESTRICT || ret->type == SYM_NODE)
36 return get_real_base_type(ret);
37 return ret;
40 int type_bits(struct symbol *type)
42 if (!type)
43 return 0;
44 if (type->type == SYM_PTR) /* Sparse doesn't set this for &pointers */
45 return bits_in_pointer;
46 if (!type->examined)
47 examine_symbol_type(type);
48 return type->bit_size;
/*
 * type_bytes - size of @type in bytes.
 *
 * A negative bit size from type_bits() is reported as 0 bytes.
 */
int type_bytes(struct symbol *type)
{
	int nbits = type_bits(type);

	return nbits < 0 ? 0 : bits_to_bytes(nbits);
}
/*
 * type_positive_bits - number of value (non-sign) bits in @type.
 *
 * Unsigned types use every bit; other types lose one bit to the sign.
 * Returns 0 for a NULL type.
 */
int type_positive_bits(struct symbol *type)
{
	int sign_bits;

	if (!type)
		return 0;

	sign_bits = type_unsigned(type) ? 0 : 1;
	return type_bits(type) - sign_bits;
}
69 static struct symbol *get_binop_type(struct expression *expr)
71 struct symbol *left, *right;
73 left = get_type(expr->left);
74 if (!left)
75 return NULL;
77 if (expr->op == SPECIAL_LEFTSHIFT ||
78 expr->op == SPECIAL_RIGHTSHIFT) {
79 if (type_positive_bits(left) < 31)
80 return &int_ctype;
81 return left;
83 if (left->type == SYM_PTR || left->type == SYM_ARRAY)
84 return left;
86 right = get_type(expr->right);
87 if (!right)
88 return NULL;
90 if (right->type == SYM_PTR || right->type == SYM_ARRAY)
91 return right;
93 if (type_positive_bits(left) < 31 && type_positive_bits(right) < 31)
94 return &int_ctype;
96 if (type_positive_bits(left) > type_positive_bits(right))
97 return left;
98 return right;
101 static struct symbol *get_type_symbol(struct expression *expr)
103 if (!expr || expr->type != EXPR_SYMBOL || !expr->symbol)
104 return NULL;
106 return get_real_base_type(expr->symbol);
109 static struct symbol *get_member_symbol(struct symbol_list *symbol_list, struct ident *member)
111 struct symbol *tmp, *sub;
113 FOR_EACH_PTR(symbol_list, tmp) {
114 if (!tmp->ident) {
115 sub = get_real_base_type(tmp);
116 sub = get_member_symbol(sub->symbol_list, member);
117 if (sub)
118 return sub;
119 continue;
121 if (tmp->ident == member)
122 return tmp;
123 } END_FOR_EACH_PTR(tmp);
125 return NULL;
128 static struct symbol *get_symbol_from_deref(struct expression *expr)
130 struct ident *member;
131 struct symbol *sym;
133 if (!expr || expr->type != EXPR_DEREF)
134 return NULL;
136 member = expr->member;
137 sym = get_type(expr->deref);
138 if (!sym) {
139 // sm_msg("could not find struct type");
140 return NULL;
142 if (sym->type == SYM_PTR)
143 sym = get_real_base_type(sym);
144 sym = get_member_symbol(sym->symbol_list, member);
145 if (!sym)
146 return NULL;
147 return get_real_base_type(sym);
150 static struct symbol *get_return_type(struct expression *expr)
152 struct symbol *tmp;
154 tmp = get_type(expr->fn);
155 if (!tmp)
156 return NULL;
157 /* this is to handle __builtin_constant_p() */
158 if (tmp->type != SYM_FN)
159 tmp = get_base_type(tmp);
160 return get_real_base_type(tmp);
163 static struct symbol *get_expr_stmt_type(struct statement *stmt)
165 if (stmt->type != STMT_COMPOUND)
166 return NULL;
167 stmt = last_ptr_list((struct ptr_list *)stmt->stmts);
168 if (stmt->type == STMT_LABEL)
169 stmt = stmt->label_statement;
170 if (stmt->type != STMT_EXPRESSION)
171 return NULL;
172 return get_type(stmt->expression);
175 static struct symbol *get_select_type(struct expression *expr)
177 struct symbol *one, *two;
179 one = get_type(expr->cond_true);
180 two = get_type(expr->cond_false);
181 if (!one || !two)
182 return NULL;
184 * This is a hack. If the types are not equiv then we
185 * really don't know the type. But I think guessing is
186 * probably Ok here.
188 if (type_positive_bits(one) > type_positive_bits(two))
189 return one;
190 return two;
193 struct symbol *get_pointer_type(struct expression *expr)
195 struct symbol *sym;
197 sym = get_type(expr);
198 if (!sym)
199 return NULL;
200 if (sym->type == SYM_NODE) {
201 sym = get_real_base_type(sym);
202 if (!sym)
203 return NULL;
205 if (sym->type != SYM_PTR && sym->type != SYM_ARRAY)
206 return NULL;
207 return get_real_base_type(sym);
210 static struct symbol *fake_pointer_sym(struct expression *expr)
212 struct symbol *sym;
213 struct symbol *base;
215 sym = alloc_symbol(expr->pos, SYM_PTR);
216 expr = expr->unop;
217 base = get_type(expr);
218 if (!base)
219 return NULL;
220 sym->ctype.base_type = base;
221 return sym;
224 struct symbol *get_type(struct expression *expr)
226 struct symbol *ret;
228 expr = strip_parens(expr);
229 if (!expr)
230 return NULL;
232 if (expr->ctype)
233 return expr->ctype;
235 switch (expr->type) {
236 case EXPR_STRING:
237 ret = &string_ctype;
238 break;
239 case EXPR_SYMBOL:
240 ret = get_type_symbol(expr);
241 break;
242 case EXPR_DEREF:
243 ret = get_symbol_from_deref(expr);
244 break;
245 case EXPR_PREOP:
246 case EXPR_POSTOP:
247 if (expr->op == '&')
248 ret = fake_pointer_sym(expr);
249 else if (expr->op == '*')
250 ret = get_pointer_type(expr->unop);
251 else
252 ret = get_type(expr->unop);
253 break;
254 case EXPR_ASSIGNMENT:
255 ret = get_type(expr->left);
256 break;
257 case EXPR_CAST:
258 case EXPR_FORCE_CAST:
259 case EXPR_IMPLIED_CAST:
260 ret = get_real_base_type(expr->cast_type);
261 break;
262 case EXPR_COMPARE:
263 case EXPR_BINOP:
264 ret = get_binop_type(expr);
265 break;
266 case EXPR_CALL:
267 ret = get_return_type(expr);
268 break;
269 case EXPR_STATEMENT:
270 ret = get_expr_stmt_type(expr->statement);
271 break;
272 case EXPR_CONDITIONAL:
273 case EXPR_SELECT:
274 ret = get_select_type(expr);
275 break;
276 case EXPR_SIZEOF:
277 ret = &ulong_ctype;
278 break;
279 case EXPR_LOGICAL:
280 ret = &int_ctype;
281 break;
282 default:
283 return NULL;
286 if (ret && ret->type == SYM_TYPEOF)
287 ret = get_type(ret->initializer);
289 expr->ctype = ret;
290 return ret;
293 int type_unsigned(struct symbol *base_type)
295 if (!base_type)
296 return 0;
297 if (base_type->ctype.modifiers & MOD_UNSIGNED)
298 return 1;
299 return 0;
302 int type_signed(struct symbol *base_type)
304 if (!base_type)
305 return 0;
306 if (base_type->ctype.modifiers & MOD_SIGNED)
307 return 1;
308 return 0;
/*
 * expr_unsigned - 1 if the type of @expr is unsigned, else 0.
 *
 * An unknown type counts as not unsigned (type_unsigned() handles NULL).
 */
int expr_unsigned(struct expression *expr)
{
	return type_unsigned(get_type(expr));
}
/*
 * expr_signed - 1 if the type of @expr is explicitly signed, else 0.
 *
 * An unknown type counts as not signed (type_signed() handles NULL).
 */
int expr_signed(struct expression *expr)
{
	return type_signed(get_type(expr));
}
335 int returns_unsigned(struct symbol *sym)
337 if (!sym)
338 return 0;
339 sym = get_base_type(sym);
340 if (!sym || sym->type != SYM_FN)
341 return 0;
342 sym = get_base_type(sym);
343 return type_unsigned(sym);
346 int is_pointer(struct expression *expr)
348 struct symbol *sym;
350 sym = get_type(expr);
351 if (!sym)
352 return 0;
353 if (sym == &string_ctype)
354 return 0;
355 if (sym->type == SYM_PTR)
356 return 1;
357 return 0;
360 int returns_pointer(struct symbol *sym)
362 if (!sym)
363 return 0;
364 sym = get_base_type(sym);
365 if (!sym || sym->type != SYM_FN)
366 return 0;
367 sym = get_base_type(sym);
368 if (sym->type == SYM_PTR)
369 return 1;
370 return 0;
373 sval_t sval_type_max(struct symbol *base_type)
375 sval_t ret;
377 if (!base_type || !type_bits(base_type))
378 base_type = &llong_ctype;
379 ret.type = base_type;
381 ret.value = (~0ULL) >> (64 - type_positive_bits(base_type));
382 return ret;
385 sval_t sval_type_min(struct symbol *base_type)
387 sval_t ret;
389 if (!base_type || !type_bits(base_type))
390 base_type = &llong_ctype;
391 ret.type = base_type;
393 if (type_unsigned(base_type)) {
394 ret.value = 0;
395 return ret;
398 ret.value = (~0ULL) << type_positive_bits(base_type);
400 return ret;
/*
 * nr_bits - size in bits of the type of @expr.
 *
 * Returns 0 when the type is unknown (type_bits() handles NULL).
 */
int nr_bits(struct expression *expr)
{
	return type_bits(get_type(expr));
}
413 int is_void_pointer(struct expression *expr)
415 struct symbol *type;
417 type = get_type(expr);
418 if (!type || type->type != SYM_PTR)
419 return 0;
420 type = get_real_base_type(type);
421 if (type == &void_ctype)
422 return 1;
423 return 0;
426 int is_char_pointer(struct expression *expr)
428 struct symbol *type;
430 type = get_type(expr);
431 if (!type || type->type != SYM_PTR)
432 return 0;
433 type = get_real_base_type(type);
434 if (type == &char_ctype)
435 return 1;
436 return 0;
439 int is_string(struct expression *expr)
441 expr = strip_expr(expr);
442 if (!expr || expr->type != EXPR_STRING)
443 return 0;
444 if (expr->string)
445 return 1;
446 return 0;
449 int is_static(struct expression *expr)
451 char *name;
452 struct symbol *sym;
453 int ret = 0;
455 name = expr_to_str_sym(expr, &sym);
456 if (!name || !sym)
457 goto free;
459 if (sym->ctype.modifiers & MOD_STATIC)
460 ret = 1;
461 free:
462 free_string(name);
463 return ret;
466 int is_local_variable(struct expression *expr)
468 struct symbol *sym;
469 char *name;
471 name = expr_to_var_sym(expr, &sym);
472 free_string(name);
473 if (!sym || !sym->scope || !sym->scope->token)
474 return 0;
475 if (cmp_pos(sym->scope->token->pos, cur_func_sym->pos) < 0)
476 return 0;
477 if (is_static(expr))
478 return 0;
479 return 1;
482 int types_equiv(struct symbol *one, struct symbol *two)
484 if (!one && !two)
485 return 1;
486 if (!one || !two)
487 return 0;
488 if (one->type != two->type)
489 return 0;
490 if (one->type == SYM_PTR)
491 return types_equiv(get_real_base_type(one), get_real_base_type(two));
492 if (type_positive_bits(one) != type_positive_bits(two))
493 return 0;
494 return 1;
497 int fn_static(void)
499 return !!(cur_func_sym->ctype.modifiers & MOD_STATIC);
502 const char *global_static(void)
504 if (cur_func_sym->ctype.modifiers & MOD_STATIC)
505 return "static";
506 else
507 return "global";
510 struct symbol *cur_func_return_type(void)
512 struct symbol *sym;
514 sym = get_real_base_type(cur_func_sym);
515 if (!sym || sym->type != SYM_FN)
516 return NULL;
517 sym = get_real_base_type(sym);
518 return sym;
521 struct symbol *get_arg_type(struct expression *fn, int arg)
523 struct symbol *fn_type;
524 struct symbol *tmp;
525 struct symbol *arg_type;
526 int i;
528 fn_type = get_type(fn);
529 if (!fn_type)
530 return NULL;
531 if (fn_type->type == SYM_PTR)
532 fn_type = get_real_base_type(fn_type);
533 if (fn_type->type != SYM_FN)
534 return NULL;
536 i = 0;
537 FOR_EACH_PTR(fn_type->arguments, tmp) {
538 arg_type = get_real_base_type(tmp);
539 if (i == arg) {
540 return arg_type;
542 i++;
543 } END_FOR_EACH_PTR(tmp);
545 return NULL;
548 static struct symbol *get_member_from_string(struct symbol_list *symbol_list, char *name)
550 struct symbol *tmp, *sub;
551 int chunk_len;
553 if (strncmp(name, ".", 1) == 0)
554 name += 1;
555 if (strncmp(name, "->", 2) == 0)
556 name += 2;
558 FOR_EACH_PTR(symbol_list, tmp) {
559 if (!tmp->ident) {
560 sub = get_real_base_type(tmp);
561 sub = get_member_from_string(sub->symbol_list, name);
562 if (sub)
563 return sub;
564 continue;
567 if (strcmp(tmp->ident->name, name) == 0)
568 return tmp;
570 chunk_len = strlen(tmp->ident->name);
571 if (strncmp(tmp->ident->name, name, chunk_len) == 0 &&
572 (name[chunk_len] == '.' || name[chunk_len] == '-')) {
573 sub = get_real_base_type(tmp);
574 return get_member_from_string(sub->symbol_list, name + chunk_len);
577 } END_FOR_EACH_PTR(tmp);
579 return NULL;
582 struct symbol *get_member_type_from_key(struct expression *expr, char *key)
584 struct symbol *sym;
586 if (strcmp(key, "$") == 0)
587 return get_type(expr);
589 if (strcmp(key, "*$") == 0) {
590 sym = get_type(expr);
591 if (!sym || sym->type != SYM_PTR)
592 return NULL;
593 return get_real_base_type(sym);
596 sym = get_type(expr);
597 if (!sym)
598 return NULL;
599 if (sym->type == SYM_PTR)
600 sym = get_real_base_type(sym);
602 key = key + 1;
603 sym = get_member_from_string(sym->symbol_list, key);
604 if (!sym)
605 return NULL;
606 return get_real_base_type(sym);
609 int is_struct(struct expression *expr)
611 struct symbol *type;
613 type = get_type(expr);
614 if (type && type->type == SYM_STRUCT)
615 return 1;
616 return 0;
619 static struct {
620 struct symbol *sym;
621 const char *name;
622 } base_types[] = {
623 {&bool_ctype, "bool"},
624 {&void_ctype, "void"},
625 {&type_ctype, "type"},
626 {&char_ctype, "char"},
627 {&schar_ctype, "schar"},
628 {&uchar_ctype, "uchar"},
629 {&short_ctype, "short"},
630 {&sshort_ctype, "sshort"},
631 {&ushort_ctype, "ushort"},
632 {&int_ctype, "int"},
633 {&sint_ctype, "sint"},
634 {&uint_ctype, "uint"},
635 {&long_ctype, "long"},
636 {&slong_ctype, "slong"},
637 {&ulong_ctype, "ulong"},
638 {&llong_ctype, "llong"},
639 {&sllong_ctype, "sllong"},
640 {&ullong_ctype, "ullong"},
641 {&lllong_ctype, "lllong"},
642 {&slllong_ctype, "slllong"},
643 {&ulllong_ctype, "ulllong"},
644 {&float_ctype, "float"},
645 {&double_ctype, "double"},
646 {&ldouble_ctype, "ldouble"},
647 {&string_ctype, "string"},
648 {&ptr_ctype, "ptr"},
649 {&lazy_ptr_ctype, "lazy_ptr"},
650 {&incomplete_ctype, "incomplete"},
651 {&label_ctype, "label"},
652 {&bad_ctype, "bad"},
653 {&null_ctype, "null"},
656 static const char *base_type_str(struct symbol *sym)
658 int i;
660 for (i = 0; i < ARRAY_SIZE(base_types); i++) {
661 if (sym == base_types[i].sym)
662 return base_types[i].name;
664 return "<unknown>";
667 static int type_str_helper(char *buf, int size, struct symbol *type)
669 int n;
671 if (!type)
672 return snprintf(buf, size, "<unknown>");
674 if (type->type == SYM_BASETYPE) {
675 return snprintf(buf, size, base_type_str(type));
676 } else if (type->type == SYM_PTR) {
677 type = get_real_base_type(type);
678 n = type_str_helper(buf, size, type);
679 if (n > size)
680 return n;
681 return n + snprintf(buf + n, size - n, "*");
682 } else if (type->type == SYM_ARRAY) {
683 type = get_real_base_type(type);
684 n = type_str_helper(buf, size, type);
685 if (n > size)
686 return n;
687 return n + snprintf(buf + n, size - n, "[]");
688 } else if (type->type == SYM_STRUCT) {
689 return snprintf(buf, size, "struct %s", type->ident ? type->ident->name : "");
690 } else if (type->type == SYM_UNION) {
691 if (type->ident)
692 return snprintf(buf, size, "union %s", type->ident->name);
693 else
694 return snprintf(buf, size, "anonymous union");
695 } else if (type->type == SYM_FN) {
696 struct symbol *arg, *return_type, *arg_type;
697 int i;
699 return_type = get_real_base_type(type);
700 n = type_str_helper(buf, size, return_type);
701 if (n > size)
702 return n;
703 n += snprintf(buf + n, size - n, "(*)(");
704 if (n > size)
705 return n;
707 i = 0;
708 FOR_EACH_PTR(type->arguments, arg) {
709 if (i++)
710 n += snprintf(buf + n, size - n, ", ");
711 if (n > size)
712 return n;
713 arg_type = get_real_base_type(arg);
714 n += type_str_helper(buf + n, size - n, arg_type);
715 if (n > size)
716 return n;
717 } END_FOR_EACH_PTR(arg);
719 return n + snprintf(buf + n, size - n, ")");
720 } else if (type->type == SYM_NODE) {
721 n = snprintf(buf, size, "node {");
722 if (n > size)
723 return n;
724 type = get_real_base_type(type);
725 n += type_str_helper(buf + n, size - n, type);
726 if (n > size)
727 return n;
728 return n + snprintf(buf + n, size - n, "}");
729 } else {
730 return snprintf(buf, size, "<type %d>", type->type);
/*
 * type_to_str - human-readable description of @type.
 *
 * The result lives in a static 256-byte buffer that is overwritten by the
 * next call, so this is not reentrant; copy the string if it must persist.
 */
char *type_to_str(struct symbol *type)
{
	static char buf[256];

	*buf = '\0';
	type_str_helper(buf, sizeof buf, type);
	return buf;
}