comparison: handle preops like "if (++a == b)"
[smatch.git] / check_buf_comparison.c
blob98cabc0624e357d3f3c2cba6bf53c150a9c06098
1 /*
2 * Copyright (C) 2012 Oracle.
4 * This program is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU General Public License
6 * as published by the Free Software Foundation; either version 2
7 * of the License, or (at your option) any later version.
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, see http://www.gnu.org/copyleft/gpl.txt
19 * The point here is to store that a buffer has x bytes even if we don't know
20 * the value of x.
24 #include "smatch.h"
25 #include "smatch_extra.h"
26 #include "smatch_slist.h"
/* smatch state owner ids, assigned by the two registration functions below. */
static int size_id;	/* key: buffer expression; data: its size expression */
static int link_id;	/* key: size expression; data: the buffer it sizes */
32 * We need this for code which does:
34 * if (size)
35 * foo = malloc(size);
37 * We want to record that the size of "foo" is "size" even after the merge.
40 static struct smatch_state *unmatched_state(struct sm_state *sm)
42 struct expression *size_expr;
43 sval_t sval;
45 if (!sm->state->data)
46 return &undefined;
47 size_expr = sm->state->data;
48 if (!get_implied_value(size_expr, &sval) || sval.value != 0)
49 return &undefined;
50 return sm->state;
53 static int expr_equiv(struct expression *one, struct expression *two)
55 struct symbol *one_sym, *two_sym;
56 char *one_name = NULL;
57 char *two_name = NULL;
58 int ret = 0;
60 if (!one || !two)
61 return 0;
62 if (one->type != two->type)
63 return 0;
64 one_name = expr_to_str_sym(one, &one_sym);
65 if (!one_name || !one_sym)
66 goto free;
67 two_name = expr_to_str_sym(two, &two_sym);
68 if (!two_name || !two_sym)
69 goto free;
70 if (one_sym != two_sym)
71 goto free;
72 if (strcmp(one_name, two_name) == 0)
73 ret = 1;
74 free:
75 free_string(one_name);
76 free_string(two_name);
77 return ret;
80 static void match_modify(struct sm_state *sm, struct expression *mod_expr)
82 struct expression *expr;
84 expr = sm->state->data;
85 if (!expr)
86 return;
87 set_state_expr(size_id, expr, &undefined);
90 static struct smatch_state *alloc_expr_state(struct expression *expr)
92 struct smatch_state *state;
93 char *name;
95 state = __alloc_smatch_state(0);
96 expr = strip_expr(expr);
97 name = expr_to_str(expr);
98 state->name = alloc_sname(name);
99 free_string(name);
100 state->data = expr;
101 return state;
104 static int bytes_per_element(struct expression *expr)
106 struct symbol *type;
108 type = get_type(expr);
109 if (!type)
110 return 0;
112 if (type->type != SYM_PTR && type->type != SYM_ARRAY)
113 return 0;
115 type = get_base_type(type);
116 return type_bytes(type);
119 static void db_save_type_links(struct expression *array, struct expression *size)
121 const char *array_name;
123 array_name = get_data_info_name(array);
124 if (!array_name)
125 array_name = "";
126 sql_insert_data_info(size, ARRAY_LEN, array_name);
129 static void match_alloc(const char *fn, struct expression *expr, void *_size_arg)
131 int size_arg = PTR_INT(_size_arg);
132 struct expression *pointer, *call, *arg;
133 struct sm_state *tmp;
135 pointer = strip_expr(expr->left);
136 call = strip_expr(expr->right);
137 arg = get_argument_from_call_expr(call->args, size_arg);
138 arg = strip_expr(arg);
140 if (arg->type == EXPR_BINOP && arg->op == '*') {
141 struct expression *left, *right;
142 sval_t sval;
144 left = strip_expr(arg->left);
145 right = strip_expr(arg->right);
147 if (get_implied_value(left, &sval) &&
148 sval.value == bytes_per_element(pointer))
149 arg = right;
151 if (get_implied_value(right, &sval) &&
152 sval.value == bytes_per_element(pointer))
153 arg = left;
156 db_save_type_links(pointer, arg);
157 tmp = set_state_expr(size_id, pointer, alloc_expr_state(arg));
158 if (!tmp)
159 return;
160 set_state_expr(link_id, arg, alloc_expr_state(pointer));
163 static void match_calloc(const char *fn, struct expression *expr, void *unused)
165 struct expression *pointer, *call, *arg;
166 struct sm_state *tmp;
167 sval_t sval;
169 pointer = strip_expr(expr->left);
170 call = strip_expr(expr->right);
171 arg = get_argument_from_call_expr(call->args, 0);
172 if (get_implied_value(arg, &sval) &&
173 sval.value == bytes_per_element(pointer))
174 arg = get_argument_from_call_expr(call->args, 1);
176 db_save_type_links(pointer, arg);
177 tmp = set_state_expr(size_id, pointer, alloc_expr_state(arg));
178 if (!tmp)
179 return;
180 set_state_expr(link_id, arg, alloc_expr_state(pointer));
183 struct expression *get_size_variable(struct expression *buf)
185 struct smatch_state *state;
187 state = get_state_expr(size_id, buf);
188 if (state)
189 return state->data;
190 return NULL;
193 static void array_check(struct expression *expr)
195 struct expression *array;
196 struct expression *size;
197 struct expression *offset;
198 char *array_str, *offset_str;
200 expr = strip_expr(expr);
201 if (!is_array(expr))
202 return;
204 array = get_array_base(expr);
205 size = get_size_variable(array);
206 if (!size)
207 return;
208 offset = get_array_offset(expr);
209 if (!possible_comparison(size, SPECIAL_EQUAL, offset))
210 return;
212 array_str = expr_to_str(array);
213 offset_str = expr_to_str(offset);
214 sm_msg("warn: potentially one past the end of array '%s[%s]'", array_str, offset_str);
215 free_string(array_str);
216 free_string(offset_str);
/* Cookie shared between db_var_is_array_limit() and db_limitter_callback(). */
struct db_info {
	char *name;	/* data_info name of the indexed array; may be NULL */
	int ret;	/* set to 1 when a matching ARRAY_LEN row is seen */
};

/*
 * run_sql() row callback for "select value from data_info ...".  Sets
 * info->ret when the row's value matches info->name, or when either side
 * does not say which array it belongs to.  Always returns 0 (keep going).
 */
static int db_limitter_callback(void *_info, int argc, char **argv, char **azColName)
{
	struct db_info *info = _info;
	const char *limited;

	if (argc < 1)
		return 0;
	/* sqlite passes a NULL pointer for an SQL NULL column value. */
	limited = argv[0];

	/*
	 * If possible the limitters are tied to the struct they limit.  If we
	 * aren't sure which struct they limit then we use them as limitters for
	 * everything.
	 */
	if (!info->name || !limited || limited[0] == '\0' ||
	    strcmp(info->name, limited) == 0)
		info->ret = 1;
	return 0;
}
238 static char *vsl_to_data_info_name(const char *name, struct var_sym_list *vsl)
240 struct var_sym *vs;
241 struct symbol *type;
242 static char buf[80];
243 const char *p;
245 if (ptr_list_size((struct ptr_list *)vsl) != 1)
246 return NULL;
247 vs = first_ptr_list((struct ptr_list *)vsl);
249 type = get_real_base_type(vs->sym);
250 if (!type || type->type != SYM_PTR)
251 goto top_level_name;
252 type = get_real_base_type(type);
253 if (!type || type->type != SYM_STRUCT)
254 goto top_level_name;
255 if (!type->ident)
256 goto top_level_name;
258 p = name;
259 while ((name = strstr(p, "->")))
260 p = name + 2;
262 snprintf(buf, sizeof(buf),"(struct %s)->%s", type->ident->name, p);
263 return alloc_sname(buf);
265 top_level_name:
266 if (!(vs->sym->ctype.modifiers & MOD_TOPLEVEL))
267 return NULL;
268 if (vs->sym->ctype.modifiers & MOD_STATIC)
269 snprintf(buf, sizeof(buf),"static %s", name);
270 else
271 snprintf(buf, sizeof(buf),"global %s", name);
272 return alloc_sname(buf);
275 static int db_var_is_array_limit(struct expression *array, const char *name, struct var_sym_list *vsl)
277 char *size_name;
278 char *array_name = get_data_info_name(array);
279 struct db_info db_info = {.name = array_name,};
281 size_name = vsl_to_data_info_name(name, vsl);
282 if (!size_name)
283 return 0;
285 run_sql(db_limitter_callback, &db_info,
286 "select value from data_info where type = %d and data = '%s';",
287 ARRAY_LEN, size_name);
289 return db_info.ret;
292 static int known_access_ok_comparison(struct expression *expr)
294 struct expression *array;
295 struct expression *size;
296 struct expression *offset;
297 int comparison;
299 array = get_array_base(expr);
300 size = get_size_variable(array);
301 if (!size)
302 return 0;
303 offset = get_array_offset(expr);
304 comparison = get_comparison(size, offset);
305 if (comparison == '>' || comparison == SPECIAL_UNSIGNED_GT)
306 return 1;
308 return 0;
311 static int known_access_ok_numbers(struct expression *expr)
313 struct expression *array;
314 struct expression *offset;
315 sval_t max;
316 int size;
318 array = get_array_base(expr);
319 offset = get_array_offset(expr);
321 size = get_array_size(array);
322 if (size <= 0)
323 return 0;
325 get_absolute_max(offset, &max);
326 if (max.uvalue < size)
327 return 1;
328 return 0;
331 static void array_check_data_info(struct expression *expr)
333 struct expression *array;
334 struct expression *offset;
335 struct state_list *slist;
336 struct sm_state *sm;
337 struct compare_data *comp;
338 char *offset_name;
339 const char *equal_name = NULL;
341 expr = strip_expr(expr);
342 if (!is_array(expr))
343 return;
345 if (known_access_ok_numbers(expr))
346 return;
347 if (known_access_ok_comparison(expr))
348 return;
350 array = get_array_base(expr);
351 offset = get_array_offset(expr);
352 offset_name = expr_to_var(offset);
353 if (!offset_name)
354 return;
355 slist = get_all_possible_equal_comparisons(offset);
356 if (!slist)
357 goto free;
359 FOR_EACH_PTR(slist, sm) {
360 comp = sm->state->data;
361 if (strcmp(comp->var1, offset_name) == 0) {
362 if (db_var_is_array_limit(array, comp->var2, comp->vsl2)) {
363 equal_name = comp->var2;
364 break;
366 } else if (strcmp(comp->var2, offset_name) == 0) {
367 if (db_var_is_array_limit(array, comp->var1, comp->vsl1)) {
368 equal_name = comp->var1;
369 break;
372 } END_FOR_EACH_PTR(sm);
374 if (equal_name) {
375 char *array_name = expr_to_str(array);
377 sm_msg("warn: potential off by one '%s[]' limit '%s'", array_name, equal_name);
378 free_string(array_name);
381 free:
382 free_slist(&slist);
383 free_string(offset_name);
/* Hook assignments from @func; @param (size argument index) reaches @call_back as a pointer. */
static void add_allocation_function(const char *func, void *call_back, int param)
{
	void *size_param = INT_PTR(param);

	add_function_assign_hook(func, call_back, size_param);
}
391 static char *buf_size_param_comparison(struct expression *array, struct expression_list *args)
393 struct expression *arg;
394 struct expression *size;
395 static char buf[32];
396 int i;
398 size = get_size_variable(array);
399 if (!size)
400 return NULL;
402 i = -1;
403 FOR_EACH_PTR(args, arg) {
404 i++;
405 if (arg == array)
406 continue;
407 if (!expr_equiv(arg, size))
408 continue;
409 snprintf(buf, sizeof(buf), "==$%d", i);
410 return buf;
411 } END_FOR_EACH_PTR(arg);
413 return NULL;
416 static void match_call(struct expression *call)
418 struct expression *arg;
419 char *compare;
420 int param;
422 param = -1;
423 FOR_EACH_PTR(call->args, arg) {
424 param++;
425 if (!is_pointer(arg))
426 continue;
427 compare = buf_size_param_comparison(arg, call->args);
428 if (!compare)
429 continue;
430 sql_insert_caller_info(call, ARRAY_LEN, param, "$", compare);
431 } END_FOR_EACH_PTR(arg);
434 static int get_param(int param, char **name, struct symbol **sym)
436 struct symbol *arg;
437 int i;
439 i = 0;
440 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, arg) {
442 * this is a temporary hack to work around a bug (I think in sparse?)
443 * 2.6.37-rc1:fs/reiserfs/journal.o
444 * If there is a function definition without parameter name found
445 * after a function implementation then it causes a crash.
446 * int foo() {}
447 * int bar(char *);
449 if (arg->ident->name < (char *)100)
450 continue;
451 if (i == param) {
452 *name = arg->ident->name;
453 *sym = arg;
454 return TRUE;
456 i++;
457 } END_FOR_EACH_PTR(arg);
459 return FALSE;
462 static void set_param_compare(const char *array_name, struct symbol *array_sym, char *key, char *value)
464 struct expression *array_expr;
465 struct expression *size_expr;
466 struct symbol *size_sym;
467 char *size_name;
468 long param;
469 struct sm_state *tmp;
471 if (strncmp(value, "==$", 3) != 0)
472 return;
473 param = strtol(value + 3, NULL, 10);
474 if (!get_param(param, &size_name, &size_sym))
475 return;
476 array_expr = symbol_expression(array_sym);
477 size_expr = symbol_expression(size_sym);
479 tmp = set_state_expr(size_id, array_expr, alloc_expr_state(size_expr));
480 if (!tmp)
481 return;
482 set_state_expr(link_id, size_expr, alloc_expr_state(array_expr));
487 static void munge_start_states(struct statement *stmt)
489 struct state_list *slist = NULL;
490 struct sm_state *sm;
491 struct sm_state *poss;
493 FOR_EACH_MY_SM(size_id, __get_cur_stree(), sm) {
494 if (sm->state != &merged)
495 continue;
497 * screw it. let's just assume that if one caller passes the
498 * size then they all do.
500 FOR_EACH_PTR(sm->possible, poss) {
501 if (poss->state != &merged &&
502 poss->state != &undefined) {
503 add_ptr_list(&slist, poss);
504 break;
506 } END_FOR_EACH_PTR(poss);
507 } END_FOR_EACH_SM(sm);
509 FOR_EACH_PTR(slist, sm) {
510 set_state(size_id, sm->name, sm->sym, sm->state);
511 } END_FOR_EACH_PTR(sm);
513 free_slist(&slist);
516 void check_buf_comparison(int id)
518 size_id = id;
520 add_unmatched_state_hook(size_id, &unmatched_state);
522 add_allocation_function("malloc", &match_alloc, 0);
523 add_allocation_function("memdup", &match_alloc, 1);
524 add_allocation_function("realloc", &match_alloc, 1);
525 if (option_project == PROJ_KERNEL) {
526 add_allocation_function("kmalloc", &match_alloc, 0);
527 add_allocation_function("kzalloc", &match_alloc, 0);
528 add_allocation_function("vmalloc", &match_alloc, 0);
529 add_allocation_function("__vmalloc", &match_alloc, 0);
530 add_allocation_function("sock_kmalloc", &match_alloc, 1);
531 add_allocation_function("kmemdup", &match_alloc, 1);
532 add_allocation_function("kmemdup_user", &match_alloc, 1);
533 add_allocation_function("dma_alloc_attrs", &match_alloc, 1);
534 add_allocation_function("pci_alloc_consistent", &match_alloc, 1);
535 add_allocation_function("pci_alloc_coherent", &match_alloc, 1);
536 add_allocation_function("devm_kmalloc", &match_alloc, 1);
537 add_allocation_function("devm_kzalloc", &match_alloc, 1);
538 add_allocation_function("kcalloc", &match_calloc, 0);
539 add_allocation_function("krealloc", &match_alloc, 1);
542 add_hook(&array_check, OP_HOOK);
543 add_hook(&array_check_data_info, OP_HOOK);
545 add_hook(&match_call, FUNCTION_CALL_HOOK);
546 select_caller_info_hook(set_param_compare, ARRAY_LEN);
547 add_hook(&munge_start_states, AFTER_DEF_HOOK);
550 void check_buf_comparison_links(int id)
552 link_id = id;
553 add_modification_hook(link_id, &match_modify);