extra: modify match_comparison() so it can deal with range comparisons
[smatch.git] / smatch_type.c
blob90e2885aa2a3313ca011bc6eb1c43dd4a73566cd
/*
 * sparse/smatch_type.c
 *
 * Copyright (C) 2009 Dan Carpenter.
 *
 * Licensed under the Open Software License version 1.1
 */
/*
 * The idea here is that you have an expression and you
 * want to know what its type is.
 */
15 #include "smatch.h"
17 struct symbol *get_real_base_type(struct symbol *sym)
19 struct symbol *ret;
21 ret = get_base_type(sym);
22 if (ret && ret->type == SYM_RESTRICT)
23 return get_real_base_type(ret);
24 return ret;
27 static struct symbol *get_type_symbol(struct expression *expr)
29 if (!expr || expr->type != EXPR_SYMBOL || !expr->symbol)
30 return NULL;
32 return get_real_base_type(expr->symbol);
35 static struct symbol *get_symbol_from_deref(struct expression *expr)
37 struct ident *member;
38 struct symbol *struct_sym;
39 struct symbol *tmp;
41 if (!expr || expr->type != EXPR_DEREF)
42 return NULL;
44 member = expr->member;
45 struct_sym = get_type(expr->deref);
46 if (!struct_sym) {
47 // sm_msg("could not find struct type");
48 return NULL;
50 if (struct_sym->type == SYM_PTR)
51 struct_sym = get_real_base_type(struct_sym);
52 FOR_EACH_PTR(struct_sym->symbol_list, tmp) {
53 if (tmp->ident == member)
54 return get_real_base_type(tmp);
55 } END_FOR_EACH_PTR(tmp);
56 return NULL;
59 static struct symbol *get_return_type(struct expression *expr)
61 struct symbol *tmp;
63 tmp = get_type(expr->fn);
64 if (!tmp)
65 return NULL;
66 return get_real_base_type(tmp);
69 static struct symbol *get_pointer_type(struct expression *expr)
71 struct symbol *sym;
73 sym = get_type(expr);
74 if (!sym || (sym->type != SYM_PTR && sym->type != SYM_ARRAY))
75 return NULL;
76 return get_real_base_type(sym);
79 static struct symbol *fake_pointer_sym(struct expression *expr)
81 struct symbol *sym;
82 struct symbol *base;
84 sym = alloc_symbol(expr->pos, SYM_PTR);
85 expr = expr->unop;
86 base = get_type(expr);
87 if (!base)
88 return NULL;
89 sym->ctype.base_type = base;
90 return sym;
93 struct symbol *get_type(struct expression *expr)
95 struct symbol *tmp;
97 if (!expr)
98 return NULL;
99 expr = strip_parens(expr);
101 switch (expr->type) {
102 case EXPR_SYMBOL:
103 return get_type_symbol(expr);
104 case EXPR_DEREF:
105 return get_symbol_from_deref(expr);
106 case EXPR_PREOP:
107 if (expr->op == '&')
108 return fake_pointer_sym(expr);
109 if (expr->op == '*')
110 return get_pointer_type(expr->unop);
111 return get_type(expr->unop);
112 case EXPR_CAST:
113 case EXPR_FORCE_CAST:
114 case EXPR_IMPLIED_CAST:
115 return get_real_base_type(expr->cast_type);
116 case EXPR_BINOP:
117 if (expr->op != '+')
118 return NULL;
119 tmp = get_type(expr->left);
120 return tmp;
121 case EXPR_CALL:
122 return get_return_type(expr);
123 default:
124 return expr->ctype;
125 // sm_msg("unhandled type %d", expr->type);
129 return NULL;
132 int type_unsigned(struct symbol *base_type)
134 if (!base_type)
135 return 0;
136 if (base_type->ctype.modifiers & MOD_UNSIGNED)
137 return 1;
138 return 0;
/*
 * Returns 1 if the type of @expr is known and unsigned, else 0.
 * An unknown type (get_type() == NULL) yields 0 via type_unsigned().
 */
int expr_unsigned(struct expression *expr)
{
	return type_unsigned(get_type(expr));
}
153 int returns_unsigned(struct symbol *sym)
155 if (!sym)
156 return 0;
157 sym = get_base_type(sym);
158 if (!sym || sym->type != SYM_FN)
159 return 0;
160 sym = get_base_type(sym);
161 return type_unsigned(sym);
164 int returns_pointer(struct symbol *sym)
166 if (!sym)
167 return 0;
168 sym = get_base_type(sym);
169 if (!sym || sym->type != SYM_FN)
170 return 0;
171 sym = get_base_type(sym);
172 if (sym->type == SYM_PTR)
173 return 1;
174 return 0;
177 long long type_max(struct symbol *base_type)
179 long long ret = whole_range.max;
180 int bits;
182 if (!base_type || !base_type->bit_size)
183 return ret;
184 bits = base_type->bit_size;
185 if (bits == 64)
186 return ret;
187 if (!type_unsigned(base_type))
188 bits--;
189 ret >>= (63 - bits);
190 return ret;
193 long long type_min(struct symbol *base_type)
195 long long ret = whole_range.min;
196 int bits;
198 if (!base_type || !base_type->bit_size)
199 return ret;
200 if (type_unsigned(base_type))
201 return 0;
202 ret = whole_range.max;
203 bits = base_type->bit_size - 1;
204 ret >>= (63 - bits);
205 return -(ret + 1);