comparison: copy all the comparisons when you assign a variable to another
[smatch.git] / smatch_comparison.c
blob5c72acc1e1f9df03f7c3e403aeaf788717f4f51a
1 /*
2 * smatch/smatch_comparison.c
4 * Copyright (C) 2012 Oracle.
6 * Licensed under the Open Software License version 1.1
8 */
/*
 * The point here is to store the relationships between two variables.
 * I.e.:  y > x.
 * To do that we create a state with the two variables in alphabetical order:
 * ->name = "x vs y" and the state would be "<".  On the false path the state
 * would be ">=".
 *
 * Part of the trick of it is that if x or y is modified then we need to reset
 * the state.  We need to keep a list of all the states which depend on x and
 * all the states which depend on y.  The link_id code handles this.
 *
 * Future work:  If we know that x is greater than y and y is greater than z
 * then we know that x is greater than z.
 */
25 #include "smatch.h"
26 #include "smatch_extra.h"
27 #include "smatch_slist.h"
/* Owner id of the "<var1> vs <var2>" comparison states; set by register_comparison(). */
static int compare_id;
/* Owner id of the per-variable link states; set by register_comparison_links(). */
static int link_id;
32 struct compare_data {
33 const char *var1;
34 struct symbol *sym1;
35 const char *var2;
36 struct symbol *sym2;
37 int comparison;
39 ALLOCATOR(compare_data, "compare data");
41 static struct smatch_state *alloc_compare_state(
42 const char *var1, struct symbol *sym1,
43 const char *var2, struct symbol *sym2,
44 int comparison)
46 struct smatch_state *state;
47 struct compare_data *data;
49 state = __alloc_smatch_state(0);
50 state->name = alloc_sname(show_special(comparison));
51 data = __alloc_compare_data(0);
52 data->var1 = alloc_sname(var1);
53 data->sym1 = sym1;
54 data->var2 = alloc_sname(var2);
55 data->sym2 = sym2;
56 data->comparison = comparison;
57 state->data = data;
58 return state;
61 static int state_to_comparison(struct smatch_state *state)
63 if (!state || !state->data)
64 return 0;
65 return ((struct compare_data *)state->data)->comparison;
69 * flip_op() reverses the op left and right. So "x >= y" becomes "y <= x".
71 static int flip_op(int op)
73 switch (op) {
74 case 0:
75 return 0;
76 case '<':
77 return '>';
78 case SPECIAL_UNSIGNED_LT:
79 return SPECIAL_UNSIGNED_GT;
80 case SPECIAL_LTE:
81 return SPECIAL_GTE;
82 case SPECIAL_UNSIGNED_LTE:
83 return SPECIAL_UNSIGNED_GTE;
84 case SPECIAL_EQUAL:
85 return SPECIAL_EQUAL;
86 case SPECIAL_NOTEQUAL:
87 return SPECIAL_NOTEQUAL;
88 case SPECIAL_GTE:
89 return SPECIAL_LTE;
90 case SPECIAL_UNSIGNED_GTE:
91 return SPECIAL_UNSIGNED_LTE;
92 case '>':
93 return '<';
94 case SPECIAL_UNSIGNED_GT:
95 return SPECIAL_UNSIGNED_LT;
96 default:
97 sm_msg("internal smatch bug. unhandled comparison %d", op);
98 return op;
102 static int falsify_op(int op)
104 switch (op) {
105 case 0:
106 return 0;
107 case '<':
108 return SPECIAL_GTE;
109 case SPECIAL_UNSIGNED_LT:
110 return SPECIAL_UNSIGNED_GTE;
111 case SPECIAL_LTE:
112 return '>';
113 case SPECIAL_UNSIGNED_LTE:
114 return SPECIAL_UNSIGNED_GT;
115 case SPECIAL_EQUAL:
116 return SPECIAL_NOTEQUAL;
117 case SPECIAL_NOTEQUAL:
118 return SPECIAL_EQUAL;
119 case SPECIAL_GTE:
120 return '<';
121 case SPECIAL_UNSIGNED_GTE:
122 return SPECIAL_UNSIGNED_LT;
123 case '>':
124 return SPECIAL_LTE;
125 case SPECIAL_UNSIGNED_GT:
126 return SPECIAL_UNSIGNED_LTE;
127 default:
128 sm_msg("internal smatch bug. unhandled comparison %d", op);
129 return op;
133 static int rl_comparison(struct range_list *left_rl, struct range_list *right_rl)
135 sval_t left_min, left_max, right_min, right_max;
137 if (!left_rl || !right_rl)
138 return 0;
140 left_min = rl_min(left_rl);
141 left_max = rl_max(left_rl);
142 right_min = rl_min(right_rl);
143 right_max = rl_max(right_rl);
145 if (left_min.value == left_max.value &&
146 right_min.value == right_max.value &&
147 left_min.value == right_min.value)
148 return SPECIAL_EQUAL;
150 if (sval_cmp(left_max, right_min) < 0)
151 return '<';
152 if (sval_cmp(left_max, right_min) == 0)
153 return SPECIAL_LTE;
154 if (sval_cmp(left_min, right_max) > 0)
155 return '>';
156 if (sval_cmp(left_min, right_max) == 0)
157 return SPECIAL_GTE;
159 return 0;
162 static struct smatch_state *unmatched_comparison(struct sm_state *sm)
164 struct compare_data *data = sm->state->data;
165 struct range_list *left_rl, *right_rl;
166 int op;
168 if (!data)
169 return &undefined;
171 if (!get_implied_rl_var_sym(data->var1, data->sym1, &left_rl))
172 return &undefined;
173 if (!get_implied_rl_var_sym(data->var2, data->sym2, &right_rl))
174 return &undefined;
176 op = rl_comparison(left_rl, right_rl);
177 if (op)
178 return alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, op);
180 return &undefined;
183 /* remove_unsigned_from_comparison() is obviously a hack. */
184 static int remove_unsigned_from_comparison(int op)
186 switch (op) {
187 case SPECIAL_UNSIGNED_LT:
188 return '<';
189 case SPECIAL_UNSIGNED_LTE:
190 return SPECIAL_LTE;
191 case SPECIAL_UNSIGNED_GTE:
192 return SPECIAL_GTE;
193 case SPECIAL_UNSIGNED_GT:
194 return '>';
195 default:
196 return op;
200 static int merge_comparisons(int one, int two)
202 int LT, EQ, GT;
204 one = remove_unsigned_from_comparison(one);
205 two = remove_unsigned_from_comparison(two);
207 LT = EQ = GT = 0;
209 switch (one) {
210 case '<':
211 LT = 1;
212 break;
213 case SPECIAL_LTE:
214 LT = 1;
215 EQ = 1;
216 break;
217 case SPECIAL_EQUAL:
218 EQ = 1;
219 break;
220 case SPECIAL_GTE:
221 GT = 1;
222 EQ = 1;
223 break;
224 case '>':
225 GT = 1;
228 switch (two) {
229 case '<':
230 LT = 1;
231 break;
232 case SPECIAL_LTE:
233 LT = 1;
234 EQ = 1;
235 break;
236 case SPECIAL_EQUAL:
237 EQ = 1;
238 break;
239 case SPECIAL_GTE:
240 GT = 1;
241 EQ = 1;
242 break;
243 case '>':
244 GT = 1;
247 if (LT && EQ && GT)
248 return 0;
249 if (LT && EQ)
250 return SPECIAL_LTE;
251 if (LT && GT)
252 return SPECIAL_NOTEQUAL;
253 if (LT)
254 return '<';
255 if (EQ && GT)
256 return SPECIAL_GTE;
257 if (GT)
258 return '>';
259 return 0;
262 static struct smatch_state *merge_compare_states(struct smatch_state *s1, struct smatch_state *s2)
264 struct compare_data *data = s1->data;
265 int op;
267 op = merge_comparisons(state_to_comparison(s1), state_to_comparison(s2));
268 if (op)
269 return alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, op);
270 return &undefined;
273 struct smatch_state *alloc_link_state(struct string_list *links)
275 struct smatch_state *state;
276 static char buf[256];
277 char *tmp;
278 int i;
280 state = __alloc_smatch_state(0);
282 i = 0;
283 FOR_EACH_PTR(links, tmp) {
284 if (!i++)
285 snprintf(buf, sizeof(buf), "%s", tmp);
286 else
287 snprintf(buf, sizeof(buf), "%s, %s", buf, tmp);
288 } END_FOR_EACH_PTR(tmp);
290 state->name = alloc_sname(buf);
291 state->data = links;
292 return state;
295 static void save_start_states(struct statement *stmt)
297 struct symbol *param;
298 char orig[64];
299 char state_name[128];
300 struct smatch_state *state;
301 struct string_list *links;
302 char *link;
304 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, param) {
305 if (!param->ident)
306 continue;
307 snprintf(orig, sizeof(orig), "%s orig", param->ident->name);
308 snprintf(state_name, sizeof(state_name), "%s vs %s", param->ident->name, orig);
309 state = alloc_compare_state(param->ident->name, param, alloc_sname(orig), NULL, SPECIAL_EQUAL);
310 set_state(compare_id, state_name, NULL, state);
312 link = alloc_sname(state_name);
313 links = NULL;
314 insert_string(&links, link);
315 state = alloc_link_state(links);
316 set_state(link_id, param->ident->name, param, state);
317 } END_FOR_EACH_PTR(param);
320 static struct smatch_state *merge_links(struct smatch_state *s1, struct smatch_state *s2)
322 struct smatch_state *ret;
323 struct string_list *links;
325 links = combine_string_lists(s1->data, s2->data);
326 ret = alloc_link_state(links);
327 return ret;
330 static void save_link_var_sym(const char *var, struct symbol *sym, char *link)
332 struct smatch_state *old_state, *new_state;
333 struct string_list *links;
334 char *new;
336 old_state = get_state(link_id, var, sym);
337 if (old_state)
338 links = clone_str_list(old_state->data);
339 else
340 links = NULL;
342 new = alloc_sname(link);
343 insert_string(&links, new);
345 new_state = alloc_link_state(links);
346 set_state(link_id, var, sym, new_state);
/*
 * save_link() is save_link_var_sym() for an expression.  Nothing is recorded
 * when the expression cannot be turned into a name plus symbol.
 */
static void save_link(struct expression *expr, char *link)
{
	struct symbol *sym;
	char *name;

	name = expr_to_var_sym(expr, &sym);
	if (name && sym)
		save_link_var_sym(name, sym, link);
	free_string(name);
}
363 static void match_inc(struct sm_state *sm)
365 struct string_list *links;
366 struct smatch_state *state;
367 char *tmp;
369 links = sm->state->data;
371 FOR_EACH_PTR(links, tmp) {
372 state = get_state(compare_id, tmp, NULL);
374 switch (state_to_comparison(state)) {
375 case SPECIAL_EQUAL:
376 case SPECIAL_GTE:
377 case SPECIAL_UNSIGNED_GTE:
378 case '>':
379 case SPECIAL_UNSIGNED_GT: {
380 struct compare_data *data = state->data;
381 struct smatch_state *new;
383 new = alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, '>');
384 set_state(compare_id, tmp, NULL, new);
385 break;
387 default:
388 set_state(compare_id, tmp, NULL, &undefined);
390 } END_FOR_EACH_PTR(tmp);
393 static void match_dec(struct sm_state *sm)
395 struct string_list *links;
396 struct smatch_state *state;
397 char *tmp;
399 links = sm->state->data;
401 FOR_EACH_PTR(links, tmp) {
402 state = get_state(compare_id, tmp, NULL);
404 switch (state_to_comparison(state)) {
405 case SPECIAL_EQUAL:
406 case SPECIAL_LTE:
407 case SPECIAL_UNSIGNED_LTE:
408 case '<':
409 case SPECIAL_UNSIGNED_LT: {
410 struct compare_data *data = state->data;
411 struct smatch_state *new;
413 new = alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, '<');
414 set_state(compare_id, tmp, NULL, new);
415 break;
417 default:
418 set_state(compare_id, tmp, NULL, &undefined);
420 } END_FOR_EACH_PTR(tmp);
423 static int match_inc_dec(struct sm_state *sm, struct expression *mod_expr)
425 if (!mod_expr)
426 return 0;
427 if (mod_expr->type != EXPR_PREOP && mod_expr->type != EXPR_POSTOP)
428 return 0;
430 if (mod_expr->op == SPECIAL_INCREMENT) {
431 match_inc(sm);
432 return 1;
434 if (mod_expr->op == SPECIAL_DECREMENT) {
435 match_dec(sm);
436 return 1;
438 return 0;
441 static void match_modify(struct sm_state *sm, struct expression *mod_expr)
443 struct string_list *links;
444 char *tmp;
446 if (match_inc_dec(sm, mod_expr))
447 return;
449 links = sm->state->data;
451 FOR_EACH_PTR(links, tmp) {
452 set_state(compare_id, tmp, NULL, &undefined);
453 } END_FOR_EACH_PTR(tmp);
454 set_state(link_id, sm->name, sm->sym, &undefined);
457 static void match_logic(struct expression *expr)
459 char *left = NULL;
460 char *right = NULL;
461 struct symbol *left_sym, *right_sym;
462 int op, false_op;
463 struct smatch_state *true_state, *false_state;
464 char state_name[256];
466 if (expr->type != EXPR_COMPARE)
467 return;
468 left = expr_to_var_sym(expr->left, &left_sym);
469 if (!left || !left_sym)
470 goto free;
471 right = expr_to_var_sym(expr->right, &right_sym);
472 if (!right || !right_sym)
473 goto free;
475 if (strcmp(left, right) > 0) {
476 struct symbol *tmp_sym = left_sym;
477 char *tmp_name = left;
479 left = right;
480 left_sym = right_sym;
481 right = tmp_name;
482 right_sym = tmp_sym;
483 op = flip_op(expr->op);
484 } else {
485 op = expr->op;
487 false_op = falsify_op(op);
488 snprintf(state_name, sizeof(state_name), "%s vs %s", left, right);
489 true_state = alloc_compare_state(left, left_sym, right, right_sym, op);
490 false_state = alloc_compare_state(left, left_sym, right, right_sym, false_op);
492 set_true_false_states(compare_id, state_name, NULL, true_state, false_state);
493 save_link(expr->left, state_name);
494 save_link(expr->right, state_name);
495 free:
496 free_string(left);
497 free_string(right);
500 static void add_comparison_var_sym(const char *left_name, struct symbol *left_sym, int comparison, const char *right_name, struct symbol *right_sym)
502 struct smatch_state *state;
503 char state_name[256];
505 if (strcmp(left_name, right_name) > 0) {
506 struct symbol *tmp_sym = left_sym;
507 const char *tmp_name = left_name;
509 left_name = right_name;
510 left_sym = right_sym;
511 right_name = tmp_name;
512 right_sym = tmp_sym;
513 comparison = flip_op(comparison);
515 snprintf(state_name, sizeof(state_name), "%s vs %s", left_name, right_name);
516 state = alloc_compare_state(left_name, left_sym, right_name, right_sym, comparison);
518 set_state(compare_id, state_name, NULL, state);
519 save_link_var_sym(left_name, left_sym, state_name);
520 save_link_var_sym(right_name, right_sym, state_name);
523 static void add_comparison(struct expression *left, int comparison, struct expression *right)
525 char *left_name = NULL;
526 char *right_name = NULL;
527 struct symbol *left_sym, *right_sym;
529 left_name = expr_to_var_sym(left, &left_sym);
530 if (!left_name || !left_sym)
531 goto free;
532 right_name = expr_to_var_sym(right, &right_sym);
533 if (!right_name || !right_sym)
534 goto free;
536 add_comparison_var_sym(left_name, left_sym, comparison, right_name, right_sym);
538 free:
539 free_string(left_name);
540 free_string(right_name);
543 static void match_assign_add(struct expression *expr)
545 struct expression *right;
546 struct expression *r_left, *r_right;
547 sval_t left_tmp, right_tmp;
549 right = strip_expr(expr->right);
550 r_left = strip_expr(right->left);
551 r_right = strip_expr(right->right);
553 if (!is_capped(expr->left)) {
554 get_absolute_max(r_left, &left_tmp);
555 get_absolute_max(r_right, &right_tmp);
556 if (sval_binop_overflows(left_tmp, '+', right_tmp))
557 return;
560 get_absolute_min(r_left, &left_tmp);
561 if (sval_is_negative(left_tmp))
562 return;
563 get_absolute_min(r_right, &right_tmp);
564 if (sval_is_negative(right_tmp))
565 return;
567 if (left_tmp.value == 0)
568 add_comparison(expr->left, SPECIAL_GTE, r_right);
569 else
570 add_comparison(expr->left, '>', r_right);
572 if (right_tmp.value == 0)
573 add_comparison(expr->left, SPECIAL_GTE, r_left);
574 else
575 add_comparison(expr->left, '>', r_left);
578 static void match_assign_sub(struct expression *expr)
580 struct expression *right;
581 struct expression *r_left, *r_right;
582 int comparison;
583 sval_t min;
585 right = strip_expr(expr->right);
586 r_left = strip_expr(right->left);
587 r_right = strip_expr(right->right);
589 if (get_absolute_min(r_right, &min) && sval_is_negative(min))
590 return;
592 comparison = get_comparison(r_left, r_right);
594 switch (comparison) {
595 case '>':
596 case SPECIAL_GTE:
597 if (implied_not_equal(r_right, 0))
598 add_comparison(expr->left, '>', r_left);
599 else
600 add_comparison(expr->left, SPECIAL_GTE, r_left);
601 return;
605 static void match_binop_assign(struct expression *expr)
607 struct expression *right;
609 right = strip_expr(expr->right);
610 if (right->op == '+')
611 match_assign_add(expr);
612 if (right->op == '-')
613 match_assign_sub(expr);
616 static void copy_comparisons(struct expression *left, struct expression *right)
618 struct string_list *links;
619 struct smatch_state *state;
620 struct compare_data *data;
621 struct symbol *left_sym, *right_sym;
622 char *left_var = NULL;
623 char *right_var = NULL;
624 const char *var;
625 struct symbol *sym;
626 int comparison;
627 char *tmp;
629 left_var = expr_to_var_sym(left, &left_sym);
630 if (!left_var || !left_sym)
631 goto done;
632 right_var = expr_to_var_sym(right, &right_sym);
633 if (!right_var || !right_sym)
634 goto done;
636 state = get_state_expr(link_id, right);
637 if (!state)
638 return;
639 links = state->data;
641 FOR_EACH_PTR(links, tmp) {
642 state = get_state(compare_id, tmp, NULL);
643 if (!state->data)
644 continue;
645 data = state->data;
646 comparison = data->comparison;
647 var = data->var1;
648 sym = data->sym1;
649 if (sym == right_sym && strcmp(var, right_var) == 0) {
650 var = data->var2;
651 sym = data->sym2;
652 comparison = flip_op(comparison);
654 add_comparison_var_sym(left_var, left_sym, comparison, var, sym);
655 } END_FOR_EACH_PTR(tmp);
657 done:
658 free_string(right_var);
661 static void match_normal_assign(struct expression *expr)
663 copy_comparisons(expr->left, expr->right);
664 add_comparison(expr->left, SPECIAL_EQUAL, expr->right);
667 static void match_assign(struct expression *expr)
669 struct expression *right;
671 right = strip_expr(expr->right);
672 if (right->type == EXPR_BINOP)
673 match_binop_assign(expr);
674 else
675 match_normal_assign(expr);
678 static int get_comparison_strings(char *one, char *two)
680 char buf[256];
681 struct smatch_state *state;
682 int invert = 0;
683 int ret = 0;
685 if (strcmp(one, two) > 0) {
686 char *tmp = one;
688 one = two;
689 two = tmp;
690 invert = 1;
693 snprintf(buf, sizeof(buf), "%s vs %s", one, two);
694 state = get_state(compare_id, buf, NULL);
695 if (state)
696 ret = state_to_comparison(state);
698 if (invert)
699 ret = flip_op(ret);
701 return ret;
704 int get_comparison(struct expression *a, struct expression *b)
706 char *one = NULL;
707 char *two = NULL;
708 int ret = 0;
710 one = expr_to_var(a);
711 if (!one)
712 goto free;
713 two = expr_to_var(b);
714 if (!two)
715 goto free;
717 ret = get_comparison_strings(one, two);
718 free:
719 free_string(one);
720 free_string(two);
721 return ret;
724 void __add_comparison_info(struct expression *expr, struct expression *call, const char *range)
726 struct expression *arg;
727 int comparison;
728 const char *c = range;
730 if (!str_to_comparison_arg(c, call, &comparison, &arg))
731 return;
732 add_comparison(expr, SPECIAL_LTE, arg);
735 static char *range_comparison_to_param_helper(struct expression *expr, char starts_with)
737 struct symbol *param;
738 char *var = NULL;
739 char buf[256];
740 char *ret_str = NULL;
741 int compare;
742 int i;
744 var = expr_to_var(expr);
745 if (!var)
746 goto free;
748 i = -1;
749 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, param) {
750 i++;
751 if (!param->ident)
752 continue;
753 snprintf(buf, sizeof(buf), "%s orig", param->ident->name);
754 compare = get_comparison_strings(var, buf);
755 if (!compare)
756 continue;
757 if (show_special(compare)[0] != starts_with)
758 continue;
759 snprintf(buf, sizeof(buf), "[%sp%d]", show_special(compare), i);
760 ret_str = alloc_sname(buf);
761 break;
762 } END_FOR_EACH_PTR(param);
764 free:
765 free_string(var);
766 return ret_str;
/*
 * expr_equal_to_param() returns a "[<op>p<index>]" string for a parameter
 * whose recorded comparison with @expr prints starting with '=', or NULL.
 */
char *expr_equal_to_param(struct expression *expr)
{
	char *param_str;

	param_str = range_comparison_to_param_helper(expr, '=');
	return param_str;
}
/*
 * expr_lte_to_param() returns a "[<op>p<index>]" string for a parameter
 * whose recorded comparison with @expr prints starting with '<', or NULL.
 */
char *expr_lte_to_param(struct expression *expr)
{
	char *param_str;

	param_str = range_comparison_to_param_helper(expr, '<');
	return param_str;
}
779 static void free_data(struct symbol *sym)
781 if (__inline_fn)
782 return;
783 clear_compare_data_alloc();
786 void register_comparison(int id)
788 compare_id = id;
789 add_hook(&match_logic, CONDITION_HOOK);
790 add_hook(&match_assign, ASSIGNMENT_HOOK);
791 add_hook(&save_start_states, AFTER_DEF_HOOK);
792 add_unmatched_state_hook(compare_id, unmatched_comparison);
793 add_merge_hook(compare_id, &merge_compare_states);
794 add_hook(&free_data, AFTER_FUNC_HOOK);
797 void register_comparison_links(int id)
799 link_id = id;
800 add_merge_hook(link_id, &merge_links);
801 add_modification_hook(link_id, &match_modify);