Fix range comparison.
[smatch.git] / smatch_extra.c
blob0e85bdc4a08f97f2e0bea73aa2657cefa5e515ff
1 /*
2 * sparse/smatch_extra.c
4 * Copyright (C) 2008 Dan Carpenter.
6 * Licensed under the Open Software License version 1.1
8 */
10 #include <stdlib.h>
11 #define __USE_ISOC99
12 #include <limits.h>
13 #include "parse.h"
14 #include "smatch.h"
15 #include "smatch_slist.h"
16 #include "smatch_extra.h"
18 static int my_id;
20 struct data_range whole_range = {
21 .min = LLONG_MIN,
22 .max = LLONG_MAX,
25 static struct smatch_state *alloc_extra_state_no_name(int val)
27 struct smatch_state *state;
29 state = __alloc_smatch_state(0);
30 if (val == UNDEFINED)
31 state->data = (void *)alloc_dinfo_range(whole_range.min, whole_range.max);
32 else
33 state->data = (void *)alloc_dinfo_range(val, val);
34 return state;
37 /* We do this because ->value_ranges is a list */
38 struct smatch_state *extra_undefined()
40 static struct data_info *dinfo;
41 static struct smatch_state *ret;
43 dinfo = alloc_dinfo_range(whole_range.min, whole_range.max);
44 ret = __alloc_smatch_state(0);
45 ret->name = "unknown";
46 ret->data = dinfo;
47 return ret;
50 struct smatch_state *alloc_extra_state(int val)
52 struct smatch_state *state;
54 if (val == UNDEFINED)
55 return extra_undefined();
56 state = alloc_extra_state_no_name(val);
57 state->name = show_ranges(((struct data_info *)state->data)->value_ranges);
58 return state;
61 struct smatch_state *filter_ranges(struct smatch_state *orig,
62 long long filter_min, long long filter_max)
64 struct smatch_state *ret;
65 struct data_info *orig_info;
66 struct data_info *ret_info;
68 if (!orig)
69 orig = extra_undefined();
70 orig_info = (struct data_info *)orig->data;
71 ret = alloc_extra_state_no_name(UNDEFINED);
72 ret_info = (struct data_info *)ret->data;
73 ret_info->value_ranges = remove_range(orig_info->value_ranges, filter_min, filter_max);
74 ret->name = show_ranges(ret_info->value_ranges);
75 return ret;
/*
 * Shorthand for filter_ranges() with a one-value window: returns a new
 * state equal to orig minus the single value num.
 */
struct smatch_state *add_filter(struct smatch_state *orig, long long num)
{
	const long long excluded = num;

	return filter_ranges(orig, excluded, excluded);
}
83 static struct smatch_state *merge_func(const char *name, struct symbol *sym,
84 struct smatch_state *s1,
85 struct smatch_state *s2)
87 struct data_info *info1 = (struct data_info *)s1->data;
88 struct data_info *info2 = (struct data_info *)s2->data;
89 struct data_info *ret_info;
90 struct smatch_state *tmp;
92 tmp = alloc_extra_state_no_name(UNDEFINED);
93 tmp->name = "extra_merged";
94 ret_info = (struct data_info *)tmp->data;
95 ret_info->merged = 1;
96 ret_info->value_ranges = range_list_union(info1->value_ranges, info2->value_ranges);
97 return tmp;
100 struct sm_state *__extra_merge(struct sm_state *one, struct state_list *slist1,
101 struct sm_state *two, struct state_list *slist2)
103 struct data_info *info1;
104 struct data_info *info2;
106 if (!one->state->data || !two->state->data) {
107 smatch_msg("internal error in smatch extra '%s = %s or %s'",
108 one->name, show_state(one->state),
109 show_state(two->state));
110 return alloc_state(one->name, one->owner, one->sym,
111 extra_undefined());
114 info1 = (struct data_info *)one->state->data;
115 info2 = (struct data_info *)two->state->data;
117 if (!info1->merged)
118 free_stack(&one->my_pools);
119 if (!info2->merged)
120 free_stack(&two->my_pools);
122 if (one == two && !one->my_pools) {
123 add_pool(&one->my_pools, slist1);
124 add_pool(&one->my_pools, slist2);
125 } else {
126 if (!one->my_pools)
127 add_pool(&one->my_pools, slist1);
128 if (!two->my_pools)
129 add_pool(&two->my_pools, slist2);
132 add_pool(&one->all_pools, slist1);
133 add_pool(&two->all_pools, slist2);
134 return merge_sm_states(one, two);
137 struct sm_state *__extra_and_merge(struct sm_state *sm,
138 struct state_list_stack *stack)
140 struct state_list *slist;
141 struct sm_state *ret = NULL;
142 struct sm_state *tmp;
143 int i = 0;
145 FOR_EACH_PTR(stack, slist) {
146 if (!i++) {
147 ret = get_sm_state_slist(slist, sm->name, sm->owner,
148 sm->sym);
149 } else {
150 tmp = get_sm_state_slist(slist, sm->name, sm->owner,
151 sm->sym);
152 ret = merge_sm_states(ret, tmp);
154 } END_FOR_EACH_PTR(slist);
155 if (!ret) {
156 smatch_msg("Internal error in __extra_and_merge");
157 return NULL;
159 ret->my_pools = stack;
160 ret->all_pools = clone_stack(stack);
161 return ret;
/*
 * Unmatched-state hook: always answers with a fresh "unknown" state
 * (presumably used when only one side of a merge has an extra state).
 */
static struct smatch_state *unmatched_state(struct sm_state *sm)
{
	(void)sm;	/* not consulted */
	return extra_undefined();
}
169 static void match_function_call(struct expression *expr)
171 struct expression *tmp;
172 struct symbol *sym;
173 char *name;
174 int i = 0;
176 FOR_EACH_PTR(expr->args, tmp) {
177 if (tmp->op == '&') {
178 name = get_variable_from_expr(tmp->unop, &sym);
179 if (name) {
180 set_state(name, my_id, sym, extra_undefined());
182 free_string(name);
184 i++;
185 } END_FOR_EACH_PTR(tmp);
188 static void match_assign(struct expression *expr)
190 struct expression *left;
191 struct symbol *sym;
192 char *name;
194 left = strip_expr(expr->left);
195 name = get_variable_from_expr(left, &sym);
196 if (!name)
197 return;
198 set_state(name, my_id, sym, alloc_extra_state(get_value(expr->right)));
199 free_string(name);
202 static void undef_expr(struct expression *expr)
204 struct symbol *sym;
205 char *name;
207 name = get_variable_from_expr(expr->unop, &sym);
208 if (!name)
209 return;
210 if (!get_state(name, my_id, sym)) {
211 free_string(name);
212 return;
214 set_state(name, my_id, sym, extra_undefined());
215 free_string(name);
218 static void match_declarations(struct symbol *sym)
220 const char *name;
222 if (sym->ident) {
223 name = sym->ident->name;
224 if (sym->initializer) {
225 set_state(name, my_id, sym, alloc_extra_state(get_value(sym->initializer)));
226 } else {
227 set_state(name, my_id, sym, extra_undefined());
232 static void match_function_def(struct symbol *sym)
234 struct symbol *arg;
236 FOR_EACH_PTR(sym->ctype.base_type->arguments, arg) {
237 if (!arg->ident) {
238 continue;
240 set_state(arg->ident->name, my_id, arg, extra_undefined());
241 } END_FOR_EACH_PTR(arg);
244 static void match_unop(struct expression *expr)
246 struct symbol *sym;
247 char *name;
248 const char *tmp;
251 name = get_variable_from_expr(expr->unop, &sym);
252 if (!name)
253 return;
255 tmp = show_special(expr->op);
256 if ((!strcmp(tmp, "--")) || (!strcmp(tmp, "++")))
257 set_state(name, my_id, sym, extra_undefined());
258 free_string(name);
261 int get_implied_value(struct expression *expr)
263 struct smatch_state *state;
264 int val;
265 struct symbol *sym;
266 char *name;
268 val = get_value(expr);
269 if (val != UNDEFINED)
270 return val;
272 name = get_variable_from_expr(expr, &sym);
273 if (!name)
274 return UNDEFINED;
275 state = get_state(name, my_id, sym);
276 free_string(name);
277 if (!state || !state->data)
278 return UNDEFINED;
279 return get_single_value_from_range((struct data_info *)state->data);
282 int true_comparison(int left, int comparison, int right)
284 switch(comparison){
285 case '<':
286 case SPECIAL_UNSIGNED_LT:
287 if (left < right)
288 return 1;
289 return 0;
290 case SPECIAL_UNSIGNED_LTE:
291 case SPECIAL_LTE:
292 if (left < right)
293 return 1;
294 case SPECIAL_EQUAL:
295 if (left == right)
296 return 1;
297 return 0;
298 case SPECIAL_UNSIGNED_GTE:
299 case SPECIAL_GTE:
300 if (left == right)
301 return 1;
302 case '>':
303 case SPECIAL_UNSIGNED_GT:
304 if (left > right)
305 return 1;
306 return 0;
307 case SPECIAL_NOTEQUAL:
308 if (left != right)
309 return 1;
310 return 0;
311 default:
312 smatch_msg("unhandled comparison %d\n", comparison);
313 return UNDEFINED;
315 return 0;
318 int true_comparison_range(struct data_range *left, int comparison, struct data_range *right)
320 switch(comparison){
321 case '<':
322 case SPECIAL_UNSIGNED_LT:
323 if (left->min < right->max)
324 return 1;
325 return 0;
326 case SPECIAL_UNSIGNED_LTE:
327 case SPECIAL_LTE:
328 if (left->min <= right->max)
329 return 1;
330 return 0;
331 case SPECIAL_EQUAL:
332 if (left->max < right->min)
333 return 0;
334 if (left->min > right->max)
335 return 0;
336 return 1;
337 case SPECIAL_UNSIGNED_GTE:
338 case SPECIAL_GTE:
339 if (left->max >= right->min)
340 return 1;
341 return 0;
342 case '>':
343 case SPECIAL_UNSIGNED_GT:
344 if (left->max > right->min)
345 return 1;
346 return 0;
347 case SPECIAL_NOTEQUAL:
348 if (left->min != left->max)
349 return 1;
350 if (right->min != right->max)
351 return 1;
352 if (left->min != right->min)
353 return 1;
354 return 0;
355 default:
356 smatch_msg("unhandled comparison %d\n", comparison);
357 return UNDEFINED;
359 return 0;
362 int false_comparison_range(struct data_range *left, int comparison, struct data_range *right)
364 switch(comparison){
365 case '<':
366 case SPECIAL_UNSIGNED_LT:
367 if (left->max >= right->min)
368 return 1;
369 return 0;
370 case SPECIAL_UNSIGNED_LTE:
371 case SPECIAL_LTE:
372 if (left->max > right->min)
373 return 1;
374 return 0;
375 case SPECIAL_EQUAL:
376 if (left->min != left->max)
377 return 1;
378 if (right->min != right->max)
379 return 1;
380 if (left->min != right->min)
381 return 1;
382 return 0;
383 case SPECIAL_UNSIGNED_GTE:
384 case SPECIAL_GTE:
385 if (left->min < right->max)
386 return 1;
387 return 0;
388 case '>':
389 case SPECIAL_UNSIGNED_GT:
390 if (left->min >= right->max)
391 return 1;
392 return 0;
393 case SPECIAL_NOTEQUAL:
394 if (left->max < right->min)
395 return 0;
396 if (left->min > right->max)
397 return 0;
398 return 1;
399 default:
400 smatch_msg("unhandled comparison %d\n", comparison);
401 return UNDEFINED;
403 return 0;
406 static int do_comparison(struct expression *expr)
408 int left, right, ret;
410 if ((left = get_implied_value(expr->left)) == UNDEFINED)
411 return UNDEFINED;
413 if ((right = get_implied_value(expr->right)) == UNDEFINED)
414 return UNDEFINED;
416 ret = true_comparison(left, expr->op, right);
417 if (ret == 1) {
418 SM_DEBUG("%d known condition: %d %s %d => true\n",
419 get_lineno(), left, show_special(expr->op), right);
420 } else if (ret == 0) {
421 SM_DEBUG("%d known condition: %d %s %d => false\n",
422 get_lineno(), left, show_special(expr->op), right);
424 return ret;
427 int last_stmt_val(struct statement *stmt)
429 struct expression *expr;
431 stmt = last_ptr_list((struct ptr_list *)stmt->stmts);
432 if (stmt->type != STMT_EXPRESSION)
433 return UNDEFINED;
434 expr = stmt->expression;
435 return get_value(expr);
438 static void match_comparison(struct expression *expr)
440 long long value;
441 char *name = NULL;
442 struct symbol *sym;
443 struct smatch_state *eq_state;
444 struct smatch_state *neq_state;
446 if (expr->op != SPECIAL_EQUAL && expr->op != SPECIAL_NOTEQUAL)
447 return;
448 value = get_value(expr->left);
449 if (value != UNDEFINED) {
450 name = get_variable_from_expr(expr->right, &sym);
451 } else {
452 value = get_value(expr->right);
453 name = get_variable_from_expr(expr->left, &sym);
455 if (value == UNDEFINED || !name || !sym)
456 goto free;
457 eq_state = alloc_extra_state(value);
458 neq_state = add_filter(extra_undefined(), value);
459 if (expr->op == SPECIAL_EQUAL)
460 set_true_false_states(name, my_id, sym, eq_state, neq_state);
461 else
462 set_true_false_states(name, my_id, sym, neq_state, eq_state);
463 free:
464 free_string(name);
467 /* this is actually hooked from smatch_implied.c... it's hacky, yes */
468 void __extra_match_condition(struct expression *expr)
470 struct symbol *sym;
471 char *name;
472 struct smatch_state *pre_state;
473 struct smatch_state *true_state;
474 struct smatch_state *false_state;
476 expr = strip_expr(expr);
477 switch(expr->type) {
478 case EXPR_PREOP:
479 case EXPR_SYMBOL:
480 case EXPR_DEREF:
481 name = get_variable_from_expr(expr, &sym);
482 if (!name)
483 return;
484 pre_state = get_state(name, my_id, sym);
485 true_state = add_filter(pre_state, 0);
486 false_state = alloc_extra_state(0);
487 set_true_false_states(name, my_id, sym, true_state, false_state);
488 free_string(name);
489 return;
490 case EXPR_COMPARE:
491 match_comparison(expr);
492 return;
496 static int variable_non_zero(struct expression *expr)
498 char *name;
499 struct symbol *sym;
500 struct smatch_state *state;
501 int ret = UNDEFINED;
503 name = get_variable_from_expr(expr, &sym);
504 if (!name || !sym)
505 goto exit;
506 state = get_state(name, my_id, sym);
507 if (!state || !state->data)
508 goto exit;
509 ret = true_comparison(get_single_value_from_range((struct data_info *)state->data),
510 SPECIAL_NOTEQUAL, 0);
511 exit:
512 free_string(name);
513 return ret;
516 int known_condition_true(struct expression *expr)
518 int tmp;
520 if (!expr)
521 return 0;
523 tmp = get_value(expr);
524 if (tmp && tmp != UNDEFINED)
525 return 1;
527 expr = strip_expr(expr);
528 switch(expr->type) {
529 case EXPR_PREOP:
530 if (expr->op == '!') {
531 if (known_condition_false(expr->unop))
532 return 1;
533 break;
535 break;
536 default:
537 break;
539 return 0;
542 int known_condition_false(struct expression *expr)
544 if (!expr)
545 return 0;
547 if (is_zero(expr))
548 return 1;
550 switch(expr->type) {
551 case EXPR_PREOP:
552 if (expr->op == '!') {
553 if (known_condition_true(expr->unop))
554 return 1;
555 break;
557 break;
558 default:
559 break;
561 return 0;
564 int implied_condition_true(struct expression *expr)
566 struct statement *stmt;
567 int tmp;
569 if (!expr)
570 return 0;
572 tmp = get_value(expr);
573 if (tmp && tmp != UNDEFINED)
574 return 1;
576 expr = strip_expr(expr);
577 switch(expr->type) {
578 case EXPR_COMPARE:
579 if (do_comparison(expr) == 1)
580 return 1;
581 break;
582 case EXPR_PREOP:
583 if (expr->op == '!') {
584 if (implied_condition_false(expr->unop))
585 return 1;
586 break;
588 stmt = get_block_thing(expr);
589 if (stmt && (last_stmt_val(stmt) == 1))
590 return 1;
591 break;
592 default:
593 if (variable_non_zero(expr) == 1)
594 return 1;
595 break;
597 return 0;
600 int implied_condition_false(struct expression *expr)
602 struct statement *stmt;
603 struct expression *tmp;
605 if (!expr)
606 return 0;
608 if (is_zero(expr))
609 return 1;
611 switch(expr->type) {
612 case EXPR_COMPARE:
613 if (do_comparison(expr) == 0)
614 return 1;
615 case EXPR_PREOP:
616 if (expr->op == '!') {
617 if (implied_condition_true(expr->unop))
618 return 1;
619 break;
621 stmt = get_block_thing(expr);
622 if (stmt && (last_stmt_val(stmt) == 0))
623 return 1;
624 tmp = strip_expr(expr);
625 if (tmp != expr)
626 return implied_condition_false(tmp);
627 break;
628 default:
629 if (variable_non_zero(expr) == 0)
630 return 1;
631 break;
633 return 0;
636 void register_smatch_extra(int id)
638 my_id = id;
639 add_merge_hook(my_id, &merge_func);
640 add_unmatched_state_hook(my_id, &unmatched_state);
641 add_hook(&undef_expr, OP_HOOK);
642 add_hook(&match_function_def, FUNC_DEF_HOOK);
643 add_hook(&match_function_call, FUNCTION_CALL_HOOK);
644 add_hook(&match_assign, ASSIGNMENT_HOOK);
645 add_hook(&match_declarations, DECLARATION_HOOK);
646 add_hook(&match_unop, OP_HOOK);
647 add_hook(&free_data_info_allocs, END_FUNC_HOOK);
649 #ifdef KERNEL
650 /* I don't know how to test for the ATTRIB_NORET attribute. :( */
651 add_function_hook("panic", &__match_nullify_path_hook, NULL);
652 add_function_hook("do_exit", &__match_nullify_path_hook, NULL);
653 add_function_hook("complete_and_exit", &__match_nullify_path_hook, NULL);
654 add_function_hook("__module_put_and_exit", &__match_nullify_path_hook, NULL);
655 add_function_hook("do_group_exit", &__match_nullify_path_hook, NULL);
656 #endif