extra: fix a bug in how pointers are set inside called functions
[smatch.git] / smatch_type.c
blobddc65ece5c14392a2ba03c6ef54dbb98f0dac569
1 /*
2 * sparse/smatch_types.c
4 * Copyright (C) 2009 Dan Carpenter.
6 * Licensed under the Open Software License version 1.1
8 */
/*
 * The idea here is that you have an expression and you
 * want to know what the type is for that.
 */
15 #include "smatch.h"
17 struct symbol *get_real_base_type(struct symbol *sym)
19 struct symbol *ret;
21 ret = get_base_type(sym);
22 if (ret && ret->type == SYM_RESTRICT)
23 return get_real_base_type(ret);
24 return ret;
27 int type_bits(struct symbol *type)
29 if (!type)
30 return 0;
31 if (type->type == SYM_PTR) /* Sparse doesn't set this for &pointers */
32 return bits_in_pointer;
33 return type->bit_size;
/*
 * type_positive_bits() - number of bits available for non-negative values.
 *
 * Unsigned types get the full type_bits() width; every other type gives up
 * one bit.  A NULL @type yields 0.
 */
int type_positive_bits(struct symbol *type)
{
	int bits;

	if (!type)
		return 0;
	bits = type_bits(type);
	if (!type_unsigned(type))
		bits--;
	return bits;
}
45 static struct symbol *get_binop_type(struct expression *expr)
47 struct symbol *left, *right;
49 left = get_type(expr->left);
50 right = get_type(expr->right);
52 if (!left || !right)
53 return NULL;
55 if (left->type == SYM_PTR || left->type == SYM_ARRAY)
56 return left;
57 if (right->type == SYM_PTR || right->type == SYM_ARRAY)
58 return right;
60 if (expr->op == SPECIAL_LEFTSHIFT ||
61 expr->op == SPECIAL_RIGHTSHIFT) {
62 if (type_positive_bits(left) < 31)
63 return &int_ctype;
64 return left;
67 if (type_positive_bits(left) < 31 && type_positive_bits(right) < 31)
68 return &int_ctype;
70 if (type_positive_bits(left) > type_positive_bits(right))
71 return left;
72 return right;
75 static struct symbol *get_type_symbol(struct expression *expr)
77 if (!expr || expr->type != EXPR_SYMBOL || !expr->symbol)
78 return NULL;
80 return get_real_base_type(expr->symbol);
83 static struct symbol *get_member_symbol(struct symbol_list *symbol_list, struct ident *member)
85 struct symbol *tmp, *sub;
87 FOR_EACH_PTR(symbol_list, tmp) {
88 if (!tmp->ident) {
89 sub = get_real_base_type(tmp);
90 sub = get_member_symbol(sub->symbol_list, member);
91 if (sub)
92 return sub;
93 continue;
95 if (tmp->ident == member)
96 return tmp;
97 } END_FOR_EACH_PTR(tmp);
99 return NULL;
102 static struct symbol *get_symbol_from_deref(struct expression *expr)
104 struct ident *member;
105 struct symbol *sym;
107 if (!expr || expr->type != EXPR_DEREF)
108 return NULL;
110 member = expr->member;
111 sym = get_type(expr->deref);
112 if (!sym) {
113 // sm_msg("could not find struct type");
114 return NULL;
116 if (sym->type == SYM_PTR)
117 sym = get_real_base_type(sym);
118 sym = get_member_symbol(sym->symbol_list, member);
119 if (!sym)
120 return NULL;
121 return get_real_base_type(sym);
124 static struct symbol *get_return_type(struct expression *expr)
126 struct symbol *tmp;
128 tmp = get_type(expr->fn);
129 if (!tmp)
130 return NULL;
131 return get_real_base_type(tmp);
134 static struct symbol *get_expr_stmt_type(struct statement *stmt)
136 if (stmt->type != STMT_COMPOUND)
137 return NULL;
138 stmt = last_ptr_list((struct ptr_list *)stmt->stmts);
139 if (!stmt || stmt->type != STMT_EXPRESSION)
140 return NULL;
141 return get_type(stmt->expression);
144 static struct symbol *get_select_type(struct expression *expr)
146 struct symbol *one, *two;
148 one = get_type(expr->cond_true);
149 two = get_type(expr->cond_false);
150 if (!one || !two)
151 return NULL;
153 * This is a hack. If the types are not equiv then we
154 * really don't know the type. But I think guessing is
155 * probably Ok here.
157 if (type_positive_bits(one) > type_positive_bits(two))
158 return one;
159 return two;
162 struct symbol *get_pointer_type(struct expression *expr)
164 struct symbol *sym;
166 sym = get_type(expr);
167 if (!sym || (sym->type != SYM_PTR && sym->type != SYM_ARRAY))
168 return NULL;
169 return get_real_base_type(sym);
172 static struct symbol *fake_pointer_sym(struct expression *expr)
174 struct symbol *sym;
175 struct symbol *base;
177 sym = alloc_symbol(expr->pos, SYM_PTR);
178 expr = expr->unop;
179 base = get_type(expr);
180 if (!base)
181 return NULL;
182 sym->ctype.base_type = base;
183 return sym;
186 struct symbol *get_type(struct expression *expr)
188 if (!expr)
189 return NULL;
190 expr = strip_parens(expr);
192 switch (expr->type) {
193 case EXPR_SYMBOL:
194 return get_type_symbol(expr);
195 case EXPR_DEREF:
196 return get_symbol_from_deref(expr);
197 case EXPR_PREOP:
198 case EXPR_POSTOP:
199 if (expr->op == '&')
200 return fake_pointer_sym(expr);
201 if (expr->op == '*')
202 return get_pointer_type(expr->unop);
203 return get_type(expr->unop);
204 case EXPR_ASSIGNMENT:
205 return get_type(expr->left);
206 case EXPR_CAST:
207 case EXPR_FORCE_CAST:
208 case EXPR_IMPLIED_CAST:
209 return get_real_base_type(expr->cast_type);
210 case EXPR_COMPARE:
211 case EXPR_BINOP:
212 return get_binop_type(expr);
213 case EXPR_CALL:
214 return get_return_type(expr);
215 case EXPR_STATEMENT:
216 return get_expr_stmt_type(expr->statement);
217 case EXPR_CONDITIONAL:
218 case EXPR_SELECT:
219 return get_select_type(expr);
220 case EXPR_SIZEOF:
221 return &ulong_ctype;
222 case EXPR_LOGICAL:
223 return &int_ctype;
224 default:
225 // sm_msg("unhandled type %d", expr->type);
226 return expr->ctype;
228 return NULL;
231 int type_unsigned(struct symbol *base_type)
233 if (!base_type)
234 return 0;
235 if (base_type->ctype.modifiers & MOD_UNSIGNED)
236 return 1;
237 return 0;
240 int type_signed(struct symbol *base_type)
242 if (!base_type)
243 return 0;
244 if (base_type->ctype.modifiers & MOD_UNSIGNED)
245 return 0;
246 return 1;
/*
 * expr_unsigned() - 1 if the type of @expr is unsigned, 0 otherwise.
 *
 * An expression whose type is unknown is reported as not unsigned, which is
 * exactly what type_unsigned() answers for a NULL type.
 */
int expr_unsigned(struct expression *expr)
{
	return type_unsigned(get_type(expr));
}
261 int returns_unsigned(struct symbol *sym)
263 if (!sym)
264 return 0;
265 sym = get_base_type(sym);
266 if (!sym || sym->type != SYM_FN)
267 return 0;
268 sym = get_base_type(sym);
269 return type_unsigned(sym);
272 int is_pointer(struct expression *expr)
274 struct symbol *sym;
276 sym = get_type(expr);
277 if (!sym)
278 return 0;
279 if (sym->type == SYM_PTR)
280 return 1;
281 return 0;
284 int returns_pointer(struct symbol *sym)
286 if (!sym)
287 return 0;
288 sym = get_base_type(sym);
289 if (!sym || sym->type != SYM_FN)
290 return 0;
291 sym = get_base_type(sym);
292 if (sym->type == SYM_PTR)
293 return 1;
294 return 0;
297 sval_t sval_type_max(struct symbol *base_type)
299 sval_t ret;
301 ret.value = (~0ULL) >> 1;
302 ret.type = base_type;
304 if (!base_type || !base_type->bit_size)
305 return ret;
307 ret.value = (~0ULL) >> (64 - type_positive_bits(base_type));
308 return ret;
311 sval_t sval_type_min(struct symbol *base_type)
313 sval_t ret;
315 if (!base_type || !base_type->bit_size)
316 base_type = &llong_ctype;
317 ret.type = base_type;
319 if (type_unsigned(base_type)) {
320 ret.value = 0;
321 return ret;
324 ret.value = (~0ULL) << type_positive_bits(base_type);
326 return ret;
/*
 * nr_bits() - width in bits of the type of @expr.
 *
 * Returns 0 when the type is unknown, which is what type_bits() answers
 * for a NULL type.
 */
int nr_bits(struct expression *expr)
{
	return type_bits(get_type(expr));
}
339 int is_static(struct expression *expr)
341 char *name;
342 struct symbol *sym;
343 int ret = 0;
345 name = get_variable_from_expr_complex(expr, &sym);
346 if (!name || !sym)
347 goto free;
349 if (sym->ctype.modifiers & MOD_STATIC)
350 ret = 1;
351 free:
352 free_string(name);
353 return ret;
356 int types_equiv(struct symbol *one, struct symbol *two)
358 if (!one && !two)
359 return 1;
360 if (!one || !two)
361 return 0;
362 if (one->type != two->type)
363 return 0;
364 if (one->type == SYM_PTR)
365 return types_equiv(get_real_base_type(one), get_real_base_type(two));
366 if (type_positive_bits(one) != type_positive_bits(two))
367 return 0;
368 return 1;
371 const char *global_static()
373 if (cur_func_sym->ctype.modifiers & MOD_STATIC)
374 return "static";
375 else
376 return "global";
379 struct symbol *cur_func_return_type(void)
381 struct symbol *sym;
383 sym = get_real_base_type(cur_func_sym);
384 if (!sym || sym->type != SYM_FN)
385 return NULL;
386 sym = get_real_base_type(sym);
387 return sym;
390 struct symbol *get_arg_type(struct expression *fn, int arg)
392 struct symbol *fn_type;
393 struct symbol *tmp;
394 struct symbol *arg_type;
395 int i;
397 fn_type = get_type(fn);
398 if (!fn_type)
399 return NULL;
400 if (fn_type->type == SYM_PTR)
401 fn_type = get_real_base_type(fn_type);
402 if (fn_type->type != SYM_FN)
403 return NULL;
405 i = 0;
406 FOR_EACH_PTR(fn_type->arguments, tmp) {
407 arg_type = get_real_base_type(tmp);
408 if (i == arg) {
409 return arg_type;
411 i++;
412 } END_FOR_EACH_PTR(tmp);
414 return NULL;
417 static struct symbol *get_member_from_string(struct symbol_list *symbol_list, char *name)
419 struct symbol *tmp, *sub;
420 int chunk_len;
422 if (strncmp(name, ".", 1) == 0)
423 name += 1;
424 if (strncmp(name, "->", 2) == 0)
425 name += 2;
427 FOR_EACH_PTR(symbol_list, tmp) {
428 if (!tmp->ident) {
429 sub = get_real_base_type(tmp);
430 sub = get_member_from_string(sub->symbol_list, name);
431 if (sub)
432 return sub;
433 continue;
436 if (strcmp(tmp->ident->name, name) == 0)
437 return tmp;
439 chunk_len = strlen(tmp->ident->name);
440 if (strncmp(tmp->ident->name, name, chunk_len) == 0 &&
441 (name[chunk_len] == '.' || name[chunk_len] == '-')) {
442 sub = get_real_base_type(tmp);
443 return get_member_from_string(sub->symbol_list, name + chunk_len);
446 } END_FOR_EACH_PTR(tmp);
448 return NULL;
451 struct symbol *get_member_type_from_key(struct symbol *sym, char *key)
453 if (strcmp(key, "$$") == 0)
454 return get_real_base_type(sym);
456 key = key + 2;
457 sym = get_real_base_type(sym);
458 if (sym->type == SYM_PTR)
459 sym = get_real_base_type(sym);
461 sym = get_member_from_string(sym->symbol_list, key);
462 if (!sym)
463 return NULL;
464 return get_real_base_type(sym);