e62af423c0fd2dacfd40c8cf064821fa75cac87d
[smatch.git] / smatch_comparison.c
blobe62af423c0fd2dacfd40c8cf064821fa75cac87d
1 /*
2 * smatch/smatch_comparison.c
4 * Copyright (C) 2012 Oracle.
6 * Licensed under the Open Software License version 1.1
8 */
11 * The point here is to store the relationships between two variables.
12 * Ie: y > x.
13 * To do that we create a state with the two variables in alphabetical order:
14 * ->name = "x vs y" and the state would be "<". On the false path the state
15 * would be ">=".
17 * Part of the trick of it is that if x or y is modified then we need to reset
18 * the state. We need to keep a list of all the states which depend on x and
19 * all the states which depend on y. The link_id code handles this.
21 * Future work: If we know that x is greater than y and y is greater than z
22 * then we know that x is greater than z.
25 #include "smatch.h"
26 #include "smatch_extra.h"
27 #include "smatch_slist.h"
/* Check id under which the "a vs b" comparison states are stored; assigned in register_comparison(). */
static int compare_id;
/* Check id under which the per-variable link states are stored; assigned in register_comparison_links(). */
static int link_id;
/*
 * Payload hung off a comparison state's ->data pointer.  It describes the
 * relationship "var1 <comparison> var2".
 */
struct compare_data {
	const char *var1;	/* name of the left-hand variable */
	struct symbol *sym1;	/* symbol that goes with var1 */
	int comparison;		/* operator token such as '<' or SPECIAL_GTE; 0 is treated as "unknown" */
	const char *var2;	/* name of the right-hand variable */
	struct symbol *sym2;	/* symbol that goes with var2 */
};
/*
 * This file calls __alloc_compare_data() and clear_compare_data_alloc();
 * presumably both are generated by this macro -- confirm against its definition.
 */
ALLOCATOR(compare_data, "compare data");
/*
 * var_sym_eq() - do two (name, symbol) pairs describe the same variable?
 *
 * Returns 1 when the symbol pointers are identical and the names compare
 * equal with strcmp(), otherwise 0.  The names are only compared when the
 * symbols already match.
 */
int var_sym_eq(const char *a, struct symbol *a_sym, const char *b, struct symbol *b_sym)
{
	return a_sym == b_sym && strcmp(a, b) == 0;
}
50 static struct smatch_state *alloc_compare_state(
51 const char *var1, struct symbol *sym1,
52 int comparison,
53 const char *var2, struct symbol *sym2)
55 struct smatch_state *state;
56 struct compare_data *data;
58 state = __alloc_smatch_state(0);
59 state->name = alloc_sname(show_special(comparison));
60 data = __alloc_compare_data(0);
61 data->var1 = alloc_sname(var1);
62 data->sym1 = sym1;
63 data->comparison = comparison;
64 data->var2 = alloc_sname(var2);
65 data->sym2 = sym2;
66 state->data = data;
67 return state;
70 static int state_to_comparison(struct smatch_state *state)
72 if (!state || !state->data)
73 return 0;
74 return ((struct compare_data *)state->data)->comparison;
78 * flip_op() reverses the op left and right. So "x >= y" becomes "y <= x".
80 static int flip_op(int op)
82 switch (op) {
83 case 0:
84 return 0;
85 case '<':
86 return '>';
87 case SPECIAL_UNSIGNED_LT:
88 return SPECIAL_UNSIGNED_GT;
89 case SPECIAL_LTE:
90 return SPECIAL_GTE;
91 case SPECIAL_UNSIGNED_LTE:
92 return SPECIAL_UNSIGNED_GTE;
93 case SPECIAL_EQUAL:
94 return SPECIAL_EQUAL;
95 case SPECIAL_NOTEQUAL:
96 return SPECIAL_NOTEQUAL;
97 case SPECIAL_GTE:
98 return SPECIAL_LTE;
99 case SPECIAL_UNSIGNED_GTE:
100 return SPECIAL_UNSIGNED_LTE;
101 case '>':
102 return '<';
103 case SPECIAL_UNSIGNED_GT:
104 return SPECIAL_UNSIGNED_LT;
105 default:
106 sm_msg("internal smatch bug. unhandled comparison %d", op);
107 return op;
111 static int falsify_op(int op)
113 switch (op) {
114 case 0:
115 return 0;
116 case '<':
117 return SPECIAL_GTE;
118 case SPECIAL_UNSIGNED_LT:
119 return SPECIAL_UNSIGNED_GTE;
120 case SPECIAL_LTE:
121 return '>';
122 case SPECIAL_UNSIGNED_LTE:
123 return SPECIAL_UNSIGNED_GT;
124 case SPECIAL_EQUAL:
125 return SPECIAL_NOTEQUAL;
126 case SPECIAL_NOTEQUAL:
127 return SPECIAL_EQUAL;
128 case SPECIAL_GTE:
129 return '<';
130 case SPECIAL_UNSIGNED_GTE:
131 return SPECIAL_UNSIGNED_LT;
132 case '>':
133 return SPECIAL_LTE;
134 case SPECIAL_UNSIGNED_GT:
135 return SPECIAL_UNSIGNED_LTE;
136 default:
137 sm_msg("internal smatch bug. unhandled comparison %d", op);
138 return op;
142 static int rl_comparison(struct range_list *left_rl, struct range_list *right_rl)
144 sval_t left_min, left_max, right_min, right_max;
146 if (!left_rl || !right_rl)
147 return 0;
149 left_min = rl_min(left_rl);
150 left_max = rl_max(left_rl);
151 right_min = rl_min(right_rl);
152 right_max = rl_max(right_rl);
154 if (left_min.value == left_max.value &&
155 right_min.value == right_max.value &&
156 left_min.value == right_min.value)
157 return SPECIAL_EQUAL;
159 if (sval_cmp(left_max, right_min) < 0)
160 return '<';
161 if (sval_cmp(left_max, right_min) == 0)
162 return SPECIAL_LTE;
163 if (sval_cmp(left_min, right_max) > 0)
164 return '>';
165 if (sval_cmp(left_min, right_max) == 0)
166 return SPECIAL_GTE;
168 return 0;
171 static struct range_list *get_orig_rl(struct symbol *sym)
173 struct smatch_state *state;
175 if (!sym || !sym->ident)
176 return NULL;
177 state = get_orig_estate(sym->ident->name, sym);
178 return estate_rl(state);
181 static struct smatch_state *unmatched_comparison(struct sm_state *sm)
183 struct compare_data *data = sm->state->data;
184 struct range_list *left_rl, *right_rl;
185 int op;
187 if (!data)
188 return &undefined;
190 if (strstr(data->var1, " orig"))
191 left_rl = get_orig_rl(data->sym1);
192 else if (!get_implied_rl_var_sym(data->var1, data->sym1, &left_rl))
193 return &undefined;
194 if (strstr(data->var2, " orig"))
195 right_rl = get_orig_rl(data->sym2);
196 else if (!get_implied_rl_var_sym(data->var2, data->sym2, &right_rl))
197 return &undefined;
199 op = rl_comparison(left_rl, right_rl);
200 if (op)
201 return alloc_compare_state(data->var1, data->sym1, op, data->var2, data->sym2);
203 return &undefined;
206 /* remove_unsigned_from_comparison() is obviously a hack. */
207 static int remove_unsigned_from_comparison(int op)
209 switch (op) {
210 case SPECIAL_UNSIGNED_LT:
211 return '<';
212 case SPECIAL_UNSIGNED_LTE:
213 return SPECIAL_LTE;
214 case SPECIAL_UNSIGNED_GTE:
215 return SPECIAL_GTE;
216 case SPECIAL_UNSIGNED_GT:
217 return '>';
218 default:
219 return op;
223 static int merge_comparisons(int one, int two)
225 int LT, EQ, GT;
227 one = remove_unsigned_from_comparison(one);
228 two = remove_unsigned_from_comparison(two);
230 LT = EQ = GT = 0;
232 switch (one) {
233 case '<':
234 LT = 1;
235 break;
236 case SPECIAL_LTE:
237 LT = 1;
238 EQ = 1;
239 break;
240 case SPECIAL_EQUAL:
241 EQ = 1;
242 break;
243 case SPECIAL_GTE:
244 GT = 1;
245 EQ = 1;
246 break;
247 case '>':
248 GT = 1;
251 switch (two) {
252 case '<':
253 LT = 1;
254 break;
255 case SPECIAL_LTE:
256 LT = 1;
257 EQ = 1;
258 break;
259 case SPECIAL_EQUAL:
260 EQ = 1;
261 break;
262 case SPECIAL_GTE:
263 GT = 1;
264 EQ = 1;
265 break;
266 case '>':
267 GT = 1;
270 if (LT && EQ && GT)
271 return 0;
272 if (LT && EQ)
273 return SPECIAL_LTE;
274 if (LT && GT)
275 return SPECIAL_NOTEQUAL;
276 if (LT)
277 return '<';
278 if (EQ && GT)
279 return SPECIAL_GTE;
280 if (GT)
281 return '>';
282 return 0;
285 static struct smatch_state *merge_compare_states(struct smatch_state *s1, struct smatch_state *s2)
287 struct compare_data *data = s1->data;
288 int op;
290 op = merge_comparisons(state_to_comparison(s1), state_to_comparison(s2));
291 if (op)
292 return alloc_compare_state(data->var1, data->sym1, op, data->var2, data->sym2);
293 return &undefined;
296 struct smatch_state *alloc_link_state(struct string_list *links)
298 struct smatch_state *state;
299 static char buf[256];
300 char *tmp;
301 int i;
303 state = __alloc_smatch_state(0);
305 i = 0;
306 FOR_EACH_PTR(links, tmp) {
307 if (!i++)
308 snprintf(buf, sizeof(buf), "%s", tmp);
309 else
310 snprintf(buf, sizeof(buf), "%s, %s", buf, tmp);
311 } END_FOR_EACH_PTR(tmp);
313 state->name = alloc_sname(buf);
314 state->data = links;
315 return state;
318 static void save_start_states(struct statement *stmt)
320 struct symbol *param;
321 char orig[64];
322 char state_name[128];
323 struct smatch_state *state;
324 struct string_list *links;
325 char *link;
327 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, param) {
328 if (!param->ident)
329 continue;
330 snprintf(orig, sizeof(orig), "%s orig", param->ident->name);
331 snprintf(state_name, sizeof(state_name), "%s vs %s", param->ident->name, orig);
332 state = alloc_compare_state(param->ident->name, param, SPECIAL_EQUAL, alloc_sname(orig), param);
333 set_state(compare_id, state_name, NULL, state);
335 link = alloc_sname(state_name);
336 links = NULL;
337 insert_string(&links, link);
338 state = alloc_link_state(links);
339 set_state(link_id, param->ident->name, param, state);
340 } END_FOR_EACH_PTR(param);
343 static struct smatch_state *merge_links(struct smatch_state *s1, struct smatch_state *s2)
345 struct smatch_state *ret;
346 struct string_list *links;
348 links = combine_string_lists(s1->data, s2->data);
349 ret = alloc_link_state(links);
350 return ret;
353 static void save_link_var_sym(const char *var, struct symbol *sym, char *link)
355 struct smatch_state *old_state, *new_state;
356 struct string_list *links;
357 char *new;
359 old_state = get_state(link_id, var, sym);
360 if (old_state)
361 links = clone_str_list(old_state->data);
362 else
363 links = NULL;
365 new = alloc_sname(link);
366 insert_string(&links, new);
368 new_state = alloc_link_state(links);
369 set_state(link_id, var, sym, new_state);
/*
 * save_link() - expression flavour of save_link_var_sym().
 *
 * Does nothing when @expr cannot be turned into a variable name and symbol.
 */
static void save_link(struct expression *expr, char *link)
{
	struct symbol *sym;
	char *var = expr_to_var_sym(expr, &sym);

	if (var && sym)
		save_link_var_sym(var, sym, link);

	free_string(var);
}
386 static void match_inc(struct sm_state *sm)
388 struct string_list *links;
389 struct smatch_state *state;
390 char *tmp;
392 links = sm->state->data;
394 FOR_EACH_PTR(links, tmp) {
395 state = get_state(compare_id, tmp, NULL);
397 switch (state_to_comparison(state)) {
398 case SPECIAL_EQUAL:
399 case SPECIAL_GTE:
400 case SPECIAL_UNSIGNED_GTE:
401 case '>':
402 case SPECIAL_UNSIGNED_GT: {
403 struct compare_data *data = state->data;
404 struct smatch_state *new;
406 new = alloc_compare_state(data->var1, data->sym1, '>', data->var2, data->sym2);
407 set_state(compare_id, tmp, NULL, new);
408 break;
410 default:
411 set_state(compare_id, tmp, NULL, &undefined);
413 } END_FOR_EACH_PTR(tmp);
416 static void match_dec(struct sm_state *sm)
418 struct string_list *links;
419 struct smatch_state *state;
420 char *tmp;
422 links = sm->state->data;
424 FOR_EACH_PTR(links, tmp) {
425 state = get_state(compare_id, tmp, NULL);
427 switch (state_to_comparison(state)) {
428 case SPECIAL_EQUAL:
429 case SPECIAL_LTE:
430 case SPECIAL_UNSIGNED_LTE:
431 case '<':
432 case SPECIAL_UNSIGNED_LT: {
433 struct compare_data *data = state->data;
434 struct smatch_state *new;
436 new = alloc_compare_state(data->var1, data->sym1, '<', data->var2, data->sym2);
437 set_state(compare_id, tmp, NULL, new);
438 break;
440 default:
441 set_state(compare_id, tmp, NULL, &undefined);
443 } END_FOR_EACH_PTR(tmp);
446 static int match_inc_dec(struct sm_state *sm, struct expression *mod_expr)
448 if (!mod_expr)
449 return 0;
450 if (mod_expr->type != EXPR_PREOP && mod_expr->type != EXPR_POSTOP)
451 return 0;
453 if (mod_expr->op == SPECIAL_INCREMENT) {
454 match_inc(sm);
455 return 1;
457 if (mod_expr->op == SPECIAL_DECREMENT) {
458 match_dec(sm);
459 return 1;
461 return 0;
464 static void match_modify(struct sm_state *sm, struct expression *mod_expr)
466 struct string_list *links;
467 char *tmp;
469 if (match_inc_dec(sm, mod_expr))
470 return;
472 links = sm->state->data;
474 FOR_EACH_PTR(links, tmp) {
475 set_state(compare_id, tmp, NULL, &undefined);
476 } END_FOR_EACH_PTR(tmp);
477 set_state(link_id, sm->name, sm->sym, &undefined);
480 static void match_compare(struct expression *expr)
482 char *left = NULL;
483 char *right = NULL;
484 struct symbol *left_sym, *right_sym;
485 int op, false_op;
486 struct smatch_state *true_state, *false_state;
487 char state_name[256];
489 if (expr->type != EXPR_COMPARE)
490 return;
491 left = expr_to_var_sym(expr->left, &left_sym);
492 if (!left || !left_sym)
493 goto free;
494 right = expr_to_var_sym(expr->right, &right_sym);
495 if (!right || !right_sym)
496 goto free;
498 if (strcmp(left, right) > 0) {
499 struct symbol *tmp_sym = left_sym;
500 char *tmp_name = left;
502 left = right;
503 left_sym = right_sym;
504 right = tmp_name;
505 right_sym = tmp_sym;
506 op = flip_op(expr->op);
507 } else {
508 op = expr->op;
510 false_op = falsify_op(op);
511 snprintf(state_name, sizeof(state_name), "%s vs %s", left, right);
512 true_state = alloc_compare_state(left, left_sym, op, right, right_sym);
513 false_state = alloc_compare_state(left, left_sym, false_op, right, right_sym);
515 set_true_false_states(compare_id, state_name, NULL, true_state, false_state);
516 save_link(expr->left, state_name);
517 save_link(expr->right, state_name);
518 free:
519 free_string(left);
520 free_string(right);
523 static void add_comparison_var_sym(const char *left_name, struct symbol *left_sym, int comparison, const char *right_name, struct symbol *right_sym)
525 struct smatch_state *state;
526 char state_name[256];
528 if (strcmp(left_name, right_name) > 0) {
529 struct symbol *tmp_sym = left_sym;
530 const char *tmp_name = left_name;
532 left_name = right_name;
533 left_sym = right_sym;
534 right_name = tmp_name;
535 right_sym = tmp_sym;
536 comparison = flip_op(comparison);
538 snprintf(state_name, sizeof(state_name), "%s vs %s", left_name, right_name);
539 state = alloc_compare_state(left_name, left_sym, comparison, right_name, right_sym);
541 set_state(compare_id, state_name, NULL, state);
542 save_link_var_sym(left_name, left_sym, state_name);
543 save_link_var_sym(right_name, right_sym, state_name);
546 static void add_comparison(struct expression *left, int comparison, struct expression *right)
548 char *left_name = NULL;
549 char *right_name = NULL;
550 struct symbol *left_sym, *right_sym;
552 left_name = expr_to_var_sym(left, &left_sym);
553 if (!left_name || !left_sym)
554 goto free;
555 right_name = expr_to_var_sym(right, &right_sym);
556 if (!right_name || !right_sym)
557 goto free;
559 add_comparison_var_sym(left_name, left_sym, comparison, right_name, right_sym);
561 free:
562 free_string(left_name);
563 free_string(right_name);
566 static void match_assign_add(struct expression *expr)
568 struct expression *right;
569 struct expression *r_left, *r_right;
570 sval_t left_tmp, right_tmp;
572 right = strip_expr(expr->right);
573 r_left = strip_expr(right->left);
574 r_right = strip_expr(right->right);
576 get_absolute_min(r_left, &left_tmp);
577 get_absolute_min(r_right, &right_tmp);
579 if (left_tmp.value > 0)
580 add_comparison(expr->left, '>', r_right);
581 else if (left_tmp.value == 0)
582 add_comparison(expr->left, SPECIAL_GTE, r_right);
584 if (right_tmp.value > 0)
585 add_comparison(expr->left, '>', r_left);
586 else if (right_tmp.value == 0)
587 add_comparison(expr->left, SPECIAL_GTE, r_left);
590 static void match_assign_sub(struct expression *expr)
592 struct expression *right;
593 struct expression *r_left, *r_right;
594 int comparison;
595 sval_t min;
597 right = strip_expr(expr->right);
598 r_left = strip_expr(right->left);
599 r_right = strip_expr(right->right);
601 if (get_absolute_min(r_right, &min) && sval_is_negative(min))
602 return;
604 comparison = get_comparison(r_left, r_right);
606 switch (comparison) {
607 case '>':
608 case SPECIAL_GTE:
609 if (implied_not_equal(r_right, 0))
610 add_comparison(expr->left, '>', r_left);
611 else
612 add_comparison(expr->left, SPECIAL_GTE, r_left);
613 return;
617 static void match_assign_divide(struct expression *expr)
619 struct expression *right;
620 struct expression *r_left, *r_right;
621 sval_t min;
623 right = strip_expr(expr->right);
624 r_left = strip_expr(right->left);
625 r_right = strip_expr(right->right);
626 if (!get_implied_min(r_right, &min) || min.value <= 1)
627 return;
629 add_comparison(expr->left, '<', r_left);
632 static void match_binop_assign(struct expression *expr)
634 struct expression *right;
636 right = strip_expr(expr->right);
637 if (right->op == '+')
638 match_assign_add(expr);
639 if (right->op == '-')
640 match_assign_sub(expr);
641 if (right->op == '/')
642 match_assign_divide(expr);
645 static void copy_comparisons(struct expression *left, struct expression *right)
647 struct string_list *links;
648 struct smatch_state *state;
649 struct compare_data *data;
650 struct symbol *left_sym, *right_sym;
651 char *left_var = NULL;
652 char *right_var = NULL;
653 const char *var;
654 struct symbol *sym;
655 int comparison;
656 char *tmp;
658 left_var = expr_to_var_sym(left, &left_sym);
659 if (!left_var || !left_sym)
660 goto done;
661 right_var = expr_to_var_sym(right, &right_sym);
662 if (!right_var || !right_sym)
663 goto done;
665 state = get_state_expr(link_id, right);
666 if (!state)
667 return;
668 links = state->data;
670 FOR_EACH_PTR(links, tmp) {
671 state = get_state(compare_id, tmp, NULL);
672 if (!state->data)
673 continue;
674 data = state->data;
675 comparison = data->comparison;
676 var = data->var2;
677 sym = data->sym2;
678 if (var_sym_eq(var, sym, right_var, right_sym)) {
679 var = data->var1;
680 sym = data->sym1;
681 comparison = flip_op(comparison);
683 add_comparison_var_sym(left_var, left_sym, comparison, var, sym);
684 } END_FOR_EACH_PTR(tmp);
686 done:
687 free_string(right_var);
690 static void match_normal_assign(struct expression *expr)
692 copy_comparisons(expr->left, expr->right);
693 add_comparison(expr->left, SPECIAL_EQUAL, expr->right);
696 static void match_assign(struct expression *expr)
698 struct expression *right;
700 right = strip_expr(expr->right);
701 if (right->type == EXPR_BINOP)
702 match_binop_assign(expr);
703 else
704 match_normal_assign(expr);
707 static int get_comparison_strings(char *one, char *two)
709 char buf[256];
710 struct smatch_state *state;
711 int invert = 0;
712 int ret = 0;
714 if (strcmp(one, two) > 0) {
715 char *tmp = one;
717 one = two;
718 two = tmp;
719 invert = 1;
722 snprintf(buf, sizeof(buf), "%s vs %s", one, two);
723 state = get_state(compare_id, buf, NULL);
724 if (state)
725 ret = state_to_comparison(state);
727 if (invert)
728 ret = flip_op(ret);
730 return ret;
733 int get_comparison(struct expression *a, struct expression *b)
735 char *one = NULL;
736 char *two = NULL;
737 int ret = 0;
739 one = expr_to_var(a);
740 if (!one)
741 goto free;
742 two = expr_to_var(b);
743 if (!two)
744 goto free;
746 ret = get_comparison_strings(one, two);
747 free:
748 free_string(one);
749 free_string(two);
750 return ret;
/*
 * __add_comparison_info() - record a comparison described by a range string.
 *
 * @range is handed to str_to_comparison_arg() together with @call; when that
 * succeeds, the comparison it reports between @expr and the argument it
 * returns is recorded.  Otherwise nothing happens.
 */
void __add_comparison_info(struct expression *expr, struct expression *call, const char *range)
{
	struct expression *arg;
	int comparison;

	if (str_to_comparison_arg(range, call, &comparison, &arg))
		add_comparison(expr, comparison, arg);
}
764 static char *range_comparison_to_param_helper(struct expression *expr, char starts_with)
766 struct symbol *param;
767 char *var = NULL;
768 char buf[256];
769 char *ret_str = NULL;
770 int compare;
771 int i;
773 var = expr_to_var(expr);
774 if (!var)
775 goto free;
777 i = -1;
778 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, param) {
779 i++;
780 if (!param->ident)
781 continue;
782 snprintf(buf, sizeof(buf), "%s orig", param->ident->name);
783 compare = get_comparison_strings(var, buf);
784 if (!compare)
785 continue;
786 if (show_special(compare)[0] != starts_with)
787 continue;
788 snprintf(buf, sizeof(buf), "[%sp%d]", show_special(compare), i);
789 ret_str = alloc_sname(buf);
790 break;
791 } END_FOR_EACH_PTR(param);
793 free:
794 free_string(var);
795 return ret_str;
/*
 * expr_equal_to_param() - find a parameter whose comparison with @expr has
 * operator text starting with '='.  Returns the "[<op>p<index>]" string from
 * range_comparison_to_param_helper(), or NULL.
 */
char *expr_equal_to_param(struct expression *expr)
{
	char *desc = range_comparison_to_param_helper(expr, '=');

	return desc;
}
/*
 * expr_lte_to_param() - find a parameter whose comparison with @expr has
 * operator text starting with '<'.  Returns the "[<op>p<index>]" string from
 * range_comparison_to_param_helper(), or NULL.
 */
char *expr_lte_to_param(struct expression *expr)
{
	char *desc = range_comparison_to_param_helper(expr, '<');

	return desc;
}
808 static void free_data(struct symbol *sym)
810 if (__inline_fn)
811 return;
812 clear_compare_data_alloc();
815 void register_comparison(int id)
817 compare_id = id;
818 add_hook(&match_compare, CONDITION_HOOK);
819 add_hook(&match_assign, ASSIGNMENT_HOOK);
820 add_hook(&save_start_states, AFTER_DEF_HOOK);
821 add_unmatched_state_hook(compare_id, unmatched_comparison);
822 add_merge_hook(compare_id, &merge_compare_states);
823 add_hook(&free_data, AFTER_FUNC_HOOK);
826 void register_comparison_links(int id)
828 link_id = id;
829 add_merge_hook(link_id, &merge_links);
830 add_modification_hook(link_id, &match_modify);