function_hooks: function comparisons can imply a parameter value
[smatch.git] / smatch_comparison.c
blob722fd31f9acecda3b31e79e48de41df5aafaae08
1 /*
2 * Copyright (C) 2012 Oracle.
4 * This program is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU General Public License
6 * as published by the Free Software Foundation; either version 2
7 * of the License, or (at your option) any later version.
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, see http://www.gnu.org/copyleft/gpl.txt
19 * The point here is to store the relationships between two variables.
20 * Ie: y > x.
21 * To do that we create a state with the two variables in alphabetical order:
22 * ->name = "x vs y" and the state would be "<". On the false path the state
23 * would be ">=".
25 * Part of the trick of it is that if x or y is modified then we need to reset
26 * the state. We need to keep a list of all the states which depend on x and
27 * all the states which depend on y. The link_id code handles this.
31 #include "smatch.h"
32 #include "smatch_extra.h"
33 #include "smatch_slist.h"
35 static int compare_id;
36 static int link_id;
38 struct compare_data {
39 const char *var1;
40 struct var_sym_list *vsl1;
41 int comparison;
42 const char *var2;
43 struct var_sym_list *vsl2;
45 ALLOCATOR(compare_data, "compare data");
47 static struct symbol *vsl_to_sym(struct var_sym_list *vsl)
49 struct var_sym *vs;
51 if (!vsl)
52 return NULL;
53 if (ptr_list_size((struct ptr_list *)vsl) != 1)
54 return NULL;
55 vs = first_ptr_list((struct ptr_list *)vsl);
56 return vs->sym;
59 static struct smatch_state *alloc_compare_state(
60 const char *var1, struct var_sym_list *vsl1,
61 int comparison,
62 const char *var2, struct var_sym_list *vsl2)
64 struct smatch_state *state;
65 struct compare_data *data;
67 state = __alloc_smatch_state(0);
68 state->name = alloc_sname(show_special(comparison));
69 data = __alloc_compare_data(0);
70 data->var1 = alloc_sname(var1);
71 data->vsl1 = clone_var_sym_list(vsl1);
72 data->comparison = comparison;
73 data->var2 = alloc_sname(var2);
74 data->vsl2 = clone_var_sym_list(vsl2);
75 state->data = data;
76 return state;
79 static int state_to_comparison(struct smatch_state *state)
81 if (!state || !state->data)
82 return 0;
83 return ((struct compare_data *)state->data)->comparison;
87 * flip_comparison() reverses the op left and right. So "x >= y" becomes "y <= x".
89 int flip_comparison(int op)
91 switch (op) {
92 case 0:
93 return 0;
94 case '<':
95 return '>';
96 case SPECIAL_UNSIGNED_LT:
97 return SPECIAL_UNSIGNED_GT;
98 case SPECIAL_LTE:
99 return SPECIAL_GTE;
100 case SPECIAL_UNSIGNED_LTE:
101 return SPECIAL_UNSIGNED_GTE;
102 case SPECIAL_EQUAL:
103 return SPECIAL_EQUAL;
104 case SPECIAL_NOTEQUAL:
105 return SPECIAL_NOTEQUAL;
106 case SPECIAL_GTE:
107 return SPECIAL_LTE;
108 case SPECIAL_UNSIGNED_GTE:
109 return SPECIAL_UNSIGNED_LTE;
110 case '>':
111 return '<';
112 case SPECIAL_UNSIGNED_GT:
113 return SPECIAL_UNSIGNED_LT;
114 default:
115 sm_msg("internal smatch bug. unhandled comparison %d", op);
116 return op;
120 int negate_comparison(int op)
122 switch (op) {
123 case 0:
124 return 0;
125 case '<':
126 return SPECIAL_GTE;
127 case SPECIAL_UNSIGNED_LT:
128 return SPECIAL_UNSIGNED_GTE;
129 case SPECIAL_LTE:
130 return '>';
131 case SPECIAL_UNSIGNED_LTE:
132 return SPECIAL_UNSIGNED_GT;
133 case SPECIAL_EQUAL:
134 return SPECIAL_NOTEQUAL;
135 case SPECIAL_NOTEQUAL:
136 return SPECIAL_EQUAL;
137 case SPECIAL_GTE:
138 return '<';
139 case SPECIAL_UNSIGNED_GTE:
140 return SPECIAL_UNSIGNED_LT;
141 case '>':
142 return SPECIAL_LTE;
143 case SPECIAL_UNSIGNED_GT:
144 return SPECIAL_UNSIGNED_LTE;
145 default:
146 sm_msg("internal smatch bug. unhandled comparison %d", op);
147 return op;
151 static int rl_comparison(struct range_list *left_rl, struct range_list *right_rl)
153 sval_t left_min, left_max, right_min, right_max;
155 if (!left_rl || !right_rl)
156 return 0;
158 left_min = rl_min(left_rl);
159 left_max = rl_max(left_rl);
160 right_min = rl_min(right_rl);
161 right_max = rl_max(right_rl);
163 if (left_min.value == left_max.value &&
164 right_min.value == right_max.value &&
165 left_min.value == right_min.value)
166 return SPECIAL_EQUAL;
168 if (sval_cmp(left_max, right_min) < 0)
169 return '<';
170 if (sval_cmp(left_max, right_min) == 0)
171 return SPECIAL_LTE;
172 if (sval_cmp(left_min, right_max) > 0)
173 return '>';
174 if (sval_cmp(left_min, right_max) == 0)
175 return SPECIAL_GTE;
177 return 0;
180 static struct range_list *get_orig_rl(struct var_sym_list *vsl)
182 struct symbol *sym;
183 struct smatch_state *state;
185 if (!vsl)
186 return NULL;
187 sym = vsl_to_sym(vsl);
188 if (!sym || !sym->ident)
189 return NULL;
190 state = get_orig_estate(sym->ident->name, sym);
191 return estate_rl(state);
194 static struct smatch_state *unmatched_comparison(struct sm_state *sm)
196 struct compare_data *data = sm->state->data;
197 struct range_list *left_rl, *right_rl;
198 int op;
200 if (!data)
201 return &undefined;
203 if (strstr(data->var1, " orig"))
204 left_rl = get_orig_rl(data->vsl1);
205 else if (!get_implied_rl_var_sym(data->var1, vsl_to_sym(data->vsl1), &left_rl))
206 return &undefined;
207 if (strstr(data->var2, " orig"))
208 right_rl = get_orig_rl(data->vsl2);
209 else if (!get_implied_rl_var_sym(data->var2, vsl_to_sym(data->vsl2), &right_rl))
210 return &undefined;
213 op = rl_comparison(left_rl, right_rl);
214 if (op)
215 return alloc_compare_state(data->var1, data->vsl1, op, data->var2, data->vsl2);
217 return &undefined;
220 /* remove_unsigned_from_comparison() is obviously a hack. */
221 static int remove_unsigned_from_comparison(int op)
223 switch (op) {
224 case SPECIAL_UNSIGNED_LT:
225 return '<';
226 case SPECIAL_UNSIGNED_LTE:
227 return SPECIAL_LTE;
228 case SPECIAL_UNSIGNED_GTE:
229 return SPECIAL_GTE;
230 case SPECIAL_UNSIGNED_GT:
231 return '>';
232 default:
233 return op;
238 * This is for when you merge states "a < b" and "a == b", the result is that
239 * we can say for sure, "a <= b" after the merge.
241 static int merge_comparisons(int one, int two)
243 int LT, EQ, GT;
245 one = remove_unsigned_from_comparison(one);
246 two = remove_unsigned_from_comparison(two);
248 LT = EQ = GT = 0;
250 switch (one) {
251 case '<':
252 LT = 1;
253 break;
254 case SPECIAL_LTE:
255 LT = 1;
256 EQ = 1;
257 break;
258 case SPECIAL_EQUAL:
259 EQ = 1;
260 break;
261 case SPECIAL_GTE:
262 GT = 1;
263 EQ = 1;
264 break;
265 case '>':
266 GT = 1;
269 switch (two) {
270 case '<':
271 LT = 1;
272 break;
273 case SPECIAL_LTE:
274 LT = 1;
275 EQ = 1;
276 break;
277 case SPECIAL_EQUAL:
278 EQ = 1;
279 break;
280 case SPECIAL_GTE:
281 GT = 1;
282 EQ = 1;
283 break;
284 case '>':
285 GT = 1;
288 if (LT && EQ && GT)
289 return 0;
290 if (LT && EQ)
291 return SPECIAL_LTE;
292 if (LT && GT)
293 return SPECIAL_NOTEQUAL;
294 if (LT)
295 return '<';
296 if (EQ && GT)
297 return SPECIAL_GTE;
298 if (GT)
299 return '>';
300 return 0;
304 * This is for if you have "a < b" and "b <= c". Or in other words,
305 * "a < b <= c". You would call this like get_combined_comparison('<', '<=').
306 * The return comparison would be '<'.
308 * This function is different from merge_comparisons(), for example:
309 * merge_comparison('<', '==') returns '<='
310 * get_combined_comparison('<', '==') returns '<'
312 static int combine_comparisons(int left_compare, int right_compare)
314 int LT, EQ, GT;
316 left_compare = remove_unsigned_from_comparison(left_compare);
317 right_compare = remove_unsigned_from_comparison(right_compare);
319 LT = EQ = GT = 0;
321 switch (left_compare) {
322 case '<':
323 LT++;
324 break;
325 case SPECIAL_LTE:
326 LT++;
327 EQ++;
328 break;
329 case SPECIAL_EQUAL:
330 return right_compare;
331 case SPECIAL_GTE:
332 GT++;
333 EQ++;
334 break;
335 case '>':
336 GT++;
339 switch (right_compare) {
340 case '<':
341 LT++;
342 break;
343 case SPECIAL_LTE:
344 LT++;
345 EQ++;
346 break;
347 case SPECIAL_EQUAL:
348 return left_compare;
349 case SPECIAL_GTE:
350 GT++;
351 EQ++;
352 break;
353 case '>':
354 GT++;
357 if (LT == 2) {
358 if (EQ == 2)
359 return SPECIAL_LTE;
360 return '<';
363 if (GT == 2) {
364 if (EQ == 2)
365 return SPECIAL_GTE;
366 return '>';
368 return 0;
371 static int filter_comparison(int orig, int op)
373 if (orig == op)
374 return orig;
376 switch (orig) {
377 case 0:
378 return op;
379 case '<':
380 switch (op) {
381 case '<':
382 case SPECIAL_LTE:
383 case SPECIAL_NOTEQUAL:
384 return '<';
386 return 0;
387 case SPECIAL_LTE:
388 switch (op) {
389 case '<':
390 case SPECIAL_LTE:
391 case SPECIAL_EQUAL:
392 return op;
393 case SPECIAL_NOTEQUAL:
394 return '<';
396 return 0;
397 case SPECIAL_EQUAL:
398 switch (op) {
399 case SPECIAL_LTE:
400 case SPECIAL_EQUAL:
401 case SPECIAL_GTE:
402 case SPECIAL_UNSIGNED_LTE:
403 case SPECIAL_UNSIGNED_GTE:
404 return SPECIAL_EQUAL;
406 return 0;
407 case SPECIAL_NOTEQUAL:
408 switch (op) {
409 case '<':
410 case SPECIAL_LTE:
411 return '<';
412 case SPECIAL_UNSIGNED_LT:
413 case SPECIAL_UNSIGNED_LTE:
414 return SPECIAL_UNSIGNED_LT;
415 case SPECIAL_NOTEQUAL:
416 return op;
417 case '>':
418 case SPECIAL_GTE:
419 return '>';
420 case SPECIAL_UNSIGNED_GT:
421 case SPECIAL_UNSIGNED_GTE:
422 return SPECIAL_UNSIGNED_GT;
424 return 0;
425 case SPECIAL_GTE:
426 switch (op) {
427 case SPECIAL_LTE:
428 return SPECIAL_EQUAL;
429 case '>':
430 case SPECIAL_GTE:
431 case SPECIAL_EQUAL:
432 return op;
433 case SPECIAL_NOTEQUAL:
434 return '>';
436 return 0;
437 case '>':
438 switch (op) {
439 case '>':
440 case SPECIAL_GTE:
441 case SPECIAL_NOTEQUAL:
442 return '>';
444 return 0;
445 case SPECIAL_UNSIGNED_LT:
446 switch (op) {
447 case SPECIAL_UNSIGNED_LT:
448 case SPECIAL_UNSIGNED_LTE:
449 case SPECIAL_NOTEQUAL:
450 return SPECIAL_UNSIGNED_LT;
452 return 0;
453 case SPECIAL_UNSIGNED_LTE:
454 switch (op) {
455 case SPECIAL_UNSIGNED_LT:
456 case SPECIAL_UNSIGNED_LTE:
457 case SPECIAL_EQUAL:
458 return op;
459 case SPECIAL_NOTEQUAL:
460 return SPECIAL_UNSIGNED_LT;
461 case SPECIAL_UNSIGNED_GTE:
462 return SPECIAL_EQUAL;
464 return 0;
465 case SPECIAL_UNSIGNED_GTE:
466 switch (op) {
467 case SPECIAL_UNSIGNED_LTE:
468 return SPECIAL_EQUAL;
469 case SPECIAL_NOTEQUAL:
470 return SPECIAL_UNSIGNED_GT;
471 case SPECIAL_EQUAL:
472 case SPECIAL_UNSIGNED_GTE:
473 case SPECIAL_UNSIGNED_GT:
474 return op;
476 return 0;
477 case SPECIAL_UNSIGNED_GT:
478 switch (op) {
479 case SPECIAL_UNSIGNED_GT:
480 case SPECIAL_UNSIGNED_GTE:
481 case SPECIAL_NOTEQUAL:
482 return SPECIAL_UNSIGNED_GT;
484 return 0;
486 return 0;
489 static struct smatch_state *merge_compare_states(struct smatch_state *s1, struct smatch_state *s2)
491 struct compare_data *data = s1->data;
492 int op;
494 op = merge_comparisons(state_to_comparison(s1), state_to_comparison(s2));
495 if (op)
496 return alloc_compare_state(data->var1, data->vsl1, op, data->var2, data->vsl2);
497 return &undefined;
500 static struct smatch_state *alloc_link_state(struct string_list *links)
502 struct smatch_state *state;
503 static char buf[256];
504 char *tmp;
505 int i;
507 state = __alloc_smatch_state(0);
509 i = 0;
510 FOR_EACH_PTR(links, tmp) {
511 if (!i++) {
512 snprintf(buf, sizeof(buf), "%s", tmp);
513 } else {
514 append(buf, ", ", sizeof(buf));
515 append(buf, tmp, sizeof(buf));
517 } END_FOR_EACH_PTR(tmp);
519 state->name = alloc_sname(buf);
520 state->data = links;
521 return state;
524 static void save_start_states(struct statement *stmt)
526 struct symbol *param;
527 char orig[64];
528 char state_name[128];
529 struct smatch_state *state;
530 struct string_list *links;
531 char *link;
533 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, param) {
534 struct var_sym_list *vsl1 = NULL;
535 struct var_sym_list *vsl2 = NULL;
537 if (!param->ident)
538 continue;
539 snprintf(orig, sizeof(orig), "%s orig", param->ident->name);
540 snprintf(state_name, sizeof(state_name), "%s vs %s", param->ident->name, orig);
541 add_var_sym(&vsl1, param->ident->name, param);
542 add_var_sym(&vsl2, orig, param);
543 state = alloc_compare_state(param->ident->name, vsl1, SPECIAL_EQUAL, alloc_sname(orig), vsl2);
544 set_state(compare_id, state_name, NULL, state);
546 link = alloc_sname(state_name);
547 links = NULL;
548 insert_string(&links, link);
549 state = alloc_link_state(links);
550 set_state(link_id, param->ident->name, param, state);
551 } END_FOR_EACH_PTR(param);
554 static struct smatch_state *merge_links(struct smatch_state *s1, struct smatch_state *s2)
556 struct smatch_state *ret;
557 struct string_list *links;
559 links = combine_string_lists(s1->data, s2->data);
560 ret = alloc_link_state(links);
561 return ret;
564 static void save_link_var_sym(const char *var, struct symbol *sym, const char *link)
566 struct smatch_state *old_state, *new_state;
567 struct string_list *links;
568 char *new;
570 old_state = get_state(link_id, var, sym);
571 if (old_state)
572 links = clone_str_list(old_state->data);
573 else
574 links = NULL;
576 new = alloc_sname(link);
577 insert_string(&links, new);
579 new_state = alloc_link_state(links);
580 set_state(link_id, var, sym, new_state);
583 static void match_inc(struct sm_state *sm)
585 struct string_list *links;
586 struct smatch_state *state;
587 char *tmp;
589 links = sm->state->data;
591 FOR_EACH_PTR(links, tmp) {
592 state = get_state(compare_id, tmp, NULL);
594 switch (state_to_comparison(state)) {
595 case SPECIAL_EQUAL:
596 case SPECIAL_GTE:
597 case SPECIAL_UNSIGNED_GTE:
598 case '>':
599 case SPECIAL_UNSIGNED_GT: {
600 struct compare_data *data = state->data;
601 struct smatch_state *new;
603 new = alloc_compare_state(data->var1, data->vsl1, '>', data->var2, data->vsl2);
604 set_state(compare_id, tmp, NULL, new);
605 break;
607 default:
608 set_state(compare_id, tmp, NULL, &undefined);
610 } END_FOR_EACH_PTR(tmp);
613 static void match_dec(struct sm_state *sm)
615 struct string_list *links;
616 struct smatch_state *state;
617 char *tmp;
619 links = sm->state->data;
621 FOR_EACH_PTR(links, tmp) {
622 state = get_state(compare_id, tmp, NULL);
624 switch (state_to_comparison(state)) {
625 case SPECIAL_EQUAL:
626 case SPECIAL_LTE:
627 case SPECIAL_UNSIGNED_LTE:
628 case '<':
629 case SPECIAL_UNSIGNED_LT: {
630 struct compare_data *data = state->data;
631 struct smatch_state *new;
633 new = alloc_compare_state(data->var1, data->vsl1, '<', data->var2, data->vsl2);
634 set_state(compare_id, tmp, NULL, new);
635 break;
637 default:
638 set_state(compare_id, tmp, NULL, &undefined);
640 } END_FOR_EACH_PTR(tmp);
643 static int match_inc_dec(struct sm_state *sm, struct expression *mod_expr)
645 if (!mod_expr)
646 return 0;
647 if (mod_expr->type != EXPR_PREOP && mod_expr->type != EXPR_POSTOP)
648 return 0;
650 if (mod_expr->op == SPECIAL_INCREMENT) {
651 match_inc(sm);
652 return 1;
654 if (mod_expr->op == SPECIAL_DECREMENT) {
655 match_dec(sm);
656 return 1;
658 return 0;
661 static void match_modify(struct sm_state *sm, struct expression *mod_expr)
663 struct string_list *links;
664 char *tmp;
666 /* Huh??? This needs a comment! */
667 if (match_inc_dec(sm, mod_expr))
668 return;
670 links = sm->state->data;
672 FOR_EACH_PTR(links, tmp) {
673 set_state(compare_id, tmp, NULL, &undefined);
674 } END_FOR_EACH_PTR(tmp);
675 set_state(link_id, sm->name, sm->sym, &undefined);
678 static char *chunk_to_var_sym(struct expression *expr, struct symbol **sym)
680 char *name, *left_name, *right_name;
681 struct symbol *tmp;
682 char buf[128];
684 expr = strip_expr(expr);
685 if (!expr)
686 return NULL;
687 if (sym)
688 *sym = NULL;
690 name = expr_to_var_sym(expr, &tmp);
691 if (name && tmp) {
692 if (sym)
693 *sym = tmp;
694 return name;
696 if (name)
697 free_string(name);
699 if (expr->type != EXPR_BINOP)
700 return NULL;
701 if (expr->op != '-' && expr->op != '+')
702 return NULL;
704 left_name = expr_to_var(expr->left);
705 if (!left_name)
706 return NULL;
707 right_name = expr_to_var(expr->right);
708 if (!right_name) {
709 free_string(left_name);
710 return NULL;
712 snprintf(buf, sizeof(buf), "%s %s %s", left_name, show_special(expr->op), right_name);
713 free_string(left_name);
714 free_string(right_name);
715 return alloc_string(buf);
718 static char *chunk_to_var(struct expression *expr)
720 return chunk_to_var_sym(expr, NULL);
723 static void save_link(struct expression *expr, char *link)
725 char *var;
726 struct symbol *sym;
728 expr = strip_expr(expr);
729 if (expr->type == EXPR_BINOP) {
730 char *chunk;
732 chunk = chunk_to_var(expr);
733 if (!chunk)
734 return;
736 save_link(expr->left, link);
737 save_link(expr->right, link);
738 save_link_var_sym(chunk, NULL, link);
739 return;
742 var = expr_to_var_sym(expr, &sym);
743 if (!var || !sym)
744 goto done;
746 save_link_var_sym(var, sym, link);
747 done:
748 free_string(var);
752 * The idea here is that we take a comparison "a < b" and then we look at all
753 * the things which "b" is compared against "b < c" and we say that that implies
754 * a relationship "a < c".
756 * The names here about because the comparisons are organized like this
757 * "a < b < c".
760 static void update_tf_links(struct stree *pre_stree,
761 const char *left_var, struct var_sym_list *left_vsl,
762 int left_comparison, int left_false_comparison,
763 const char *mid_var, struct var_sym_list *mid_vsl,
764 struct string_list *links)
766 struct smatch_state *state;
767 struct smatch_state *true_state, *false_state;
768 struct compare_data *data;
769 const char *right_var;
770 struct var_sym_list *right_vsl;
771 int right_comparison;
772 int true_comparison;
773 int false_comparison;
774 char *tmp;
775 char state_name[256];
776 struct var_sym *vs;
778 FOR_EACH_PTR(links, tmp) {
779 state = get_state_stree(pre_stree, compare_id, tmp, NULL);
780 if (!state || !state->data)
781 continue;
782 data = state->data;
783 right_comparison = data->comparison;
784 right_var = data->var2;
785 right_vsl = data->vsl2;
786 if (strcmp(mid_var, right_var) == 0) {
787 right_var = data->var1;
788 right_vsl = data->vsl1;
789 right_comparison = flip_comparison(right_comparison);
791 if (strcmp(left_var, right_var) == 0)
792 continue;
794 true_comparison = combine_comparisons(left_comparison, right_comparison);
795 false_comparison = combine_comparisons(left_false_comparison, right_comparison);
797 if (strcmp(left_var, right_var) > 0) {
798 const char *tmp_var = left_var;
799 struct var_sym_list *tmp_vsl = left_vsl;
801 left_var = right_var;
802 left_vsl = right_vsl;
803 right_var = tmp_var;
804 right_vsl = tmp_vsl;
805 true_comparison = flip_comparison(true_comparison);
806 false_comparison = flip_comparison(false_comparison);
809 if (!true_comparison && !false_comparison)
810 continue;
812 if (true_comparison)
813 true_state = alloc_compare_state(left_var, left_vsl, true_comparison, right_var, right_vsl);
814 else
815 true_state = NULL;
816 if (false_comparison)
817 false_state = alloc_compare_state(left_var, left_vsl, false_comparison, right_var, right_vsl);
818 else
819 false_state = NULL;
821 snprintf(state_name, sizeof(state_name), "%s vs %s", left_var, right_var);
822 set_true_false_states(compare_id, state_name, NULL, true_state, false_state);
823 FOR_EACH_PTR(left_vsl, vs) {
824 save_link_var_sym(vs->var, vs->sym, state_name);
825 } END_FOR_EACH_PTR(vs);
826 FOR_EACH_PTR(right_vsl, vs) {
827 save_link_var_sym(vs->var, vs->sym, state_name);
828 } END_FOR_EACH_PTR(vs);
829 if (!vsl_to_sym(left_vsl))
830 save_link_var_sym(left_var, NULL, state_name);
831 if (!vsl_to_sym(right_vsl))
832 save_link_var_sym(right_var, NULL, state_name);
833 } END_FOR_EACH_PTR(tmp);
836 static void update_tf_data(struct stree *pre_stree,
837 const char *left_name, struct var_sym_list *left_vsl,
838 const char *right_name, struct var_sym_list *right_vsl,
839 int true_comparison, int false_comparison)
841 struct smatch_state *state;
843 state = get_state_stree(pre_stree, link_id, right_name, vsl_to_sym(right_vsl));
844 if (state)
845 update_tf_links(pre_stree, left_name, left_vsl, true_comparison, false_comparison, right_name, right_vsl, state->data);
847 state = get_state_stree(pre_stree, link_id, left_name, vsl_to_sym(left_vsl));
848 if (state)
849 update_tf_links(pre_stree, right_name, right_vsl, flip_comparison(true_comparison), flip_comparison(false_comparison), left_name, left_vsl, state->data);
852 static void match_compare(struct expression *expr)
854 char *left = NULL;
855 char *right = NULL;
856 struct symbol *left_sym, *right_sym;
857 struct var_sym_list *left_vsl, *right_vsl;
858 int op, false_op;
859 int orig_comparison;
860 struct smatch_state *true_state, *false_state;
861 char state_name[256];
862 struct stree *pre_stree;
864 if (expr->type != EXPR_COMPARE)
865 return;
866 left = chunk_to_var_sym(expr->left, &left_sym);
867 if (!left)
868 goto free;
869 left_vsl = expr_to_vsl(expr->left);
870 right = chunk_to_var_sym(expr->right, &right_sym);
871 if (!right)
872 goto free;
873 right_vsl = expr_to_vsl(expr->right);
875 if (strcmp(left, right) > 0) {
876 struct symbol *tmp_sym = left_sym;
877 char *tmp_name = left;
878 struct var_sym_list *tmp_vsl = left_vsl;
880 left = right;
881 left_sym = right_sym;
882 left_vsl = right_vsl;
883 right = tmp_name;
884 right_sym = tmp_sym;
885 right_vsl = tmp_vsl;
886 op = flip_comparison(expr->op);
887 } else {
888 op = expr->op;
890 false_op = negate_comparison(op);
892 orig_comparison = get_comparison_strings(left, right);
893 op = filter_comparison(orig_comparison, op);
894 false_op = filter_comparison(orig_comparison, false_op);
896 snprintf(state_name, sizeof(state_name), "%s vs %s", left, right);
897 true_state = alloc_compare_state(left, left_vsl, op, right, right_vsl);
898 false_state = alloc_compare_state(left, left_vsl, false_op, right, right_vsl);
900 pre_stree = clone_stree(__get_cur_stree());
901 update_tf_data(pre_stree, left, left_vsl, right, right_vsl, op, false_op);
902 free_stree(&pre_stree);
904 set_true_false_states(compare_id, state_name, NULL, true_state, false_state);
905 save_link(expr->left, state_name);
906 save_link(expr->right, state_name);
907 free:
908 free_string(left);
909 free_string(right);
912 static void add_comparison_var_sym(const char *left_name,
913 struct var_sym_list *left_vsl,
914 int comparison,
915 const char *right_name, struct var_sym_list *right_vsl)
917 struct smatch_state *state;
918 struct var_sym *vs;
919 char state_name[256];
921 if (strcmp(left_name, right_name) > 0) {
922 const char *tmp_name = left_name;
923 struct var_sym_list *tmp_vsl = left_vsl;
925 left_name = right_name;
926 left_vsl = right_vsl;
927 right_name = tmp_name;
928 right_vsl = tmp_vsl;
929 comparison = flip_comparison(comparison);
931 snprintf(state_name, sizeof(state_name), "%s vs %s", left_name, right_name);
932 state = alloc_compare_state(left_name, left_vsl, comparison, right_name, right_vsl);
934 set_state(compare_id, state_name, NULL, state);
936 FOR_EACH_PTR(left_vsl, vs) {
937 save_link_var_sym(vs->var, vs->sym, state_name);
938 } END_FOR_EACH_PTR(vs);
939 FOR_EACH_PTR(right_vsl, vs) {
940 save_link_var_sym(vs->var, vs->sym, state_name);
941 } END_FOR_EACH_PTR(vs);
944 static void add_comparison(struct expression *left, int comparison, struct expression *right)
946 char *left_name = NULL;
947 char *right_name = NULL;
948 struct symbol *left_sym, *right_sym;
949 struct var_sym_list *left_vsl, *right_vsl;
950 struct smatch_state *state;
951 char state_name[256];
953 left_name = chunk_to_var_sym(left, &left_sym);
954 if (!left_name)
955 goto free;
956 left_vsl = expr_to_vsl(left);
957 right_name = chunk_to_var_sym(right, &right_sym);
958 if (!right_name)
959 goto free;
960 right_vsl = expr_to_vsl(right);
962 if (strcmp(left_name, right_name) > 0) {
963 struct symbol *tmp_sym = left_sym;
964 char *tmp_name = left_name;
965 struct var_sym_list *tmp_vsl = left_vsl;
967 left_name = right_name;
968 left_sym = right_sym;
969 left_vsl = right_vsl;
970 right_name = tmp_name;
971 right_sym = tmp_sym;
972 right_vsl = tmp_vsl;
973 comparison = flip_comparison(comparison);
975 snprintf(state_name, sizeof(state_name), "%s vs %s", left_name, right_name);
976 state = alloc_compare_state(left_name, left_vsl, comparison, right_name, right_vsl);
978 set_state(compare_id, state_name, NULL, state);
979 save_link(left, state_name);
980 save_link(right, state_name);
982 free:
983 free_string(left_name);
984 free_string(right_name);
987 static void match_assign_add(struct expression *expr)
989 struct expression *right;
990 struct expression *r_left, *r_right;
991 sval_t left_tmp, right_tmp;
993 right = strip_expr(expr->right);
994 r_left = strip_expr(right->left);
995 r_right = strip_expr(right->right);
997 get_absolute_min(r_left, &left_tmp);
998 get_absolute_min(r_right, &right_tmp);
1000 if (left_tmp.value > 0)
1001 add_comparison(expr->left, '>', r_right);
1002 else if (left_tmp.value == 0)
1003 add_comparison(expr->left, SPECIAL_GTE, r_right);
1005 if (right_tmp.value > 0)
1006 add_comparison(expr->left, '>', r_left);
1007 else if (right_tmp.value == 0)
1008 add_comparison(expr->left, SPECIAL_GTE, r_left);
1011 static void match_assign_sub(struct expression *expr)
1013 struct expression *right;
1014 struct expression *r_left, *r_right;
1015 int comparison;
1016 sval_t min;
1018 right = strip_expr(expr->right);
1019 r_left = strip_expr(right->left);
1020 r_right = strip_expr(right->right);
1022 if (get_absolute_min(r_right, &min) && sval_is_negative(min))
1023 return;
1025 comparison = get_comparison(r_left, r_right);
1027 switch (comparison) {
1028 case '>':
1029 case SPECIAL_GTE:
1030 if (implied_not_equal(r_right, 0))
1031 add_comparison(expr->left, '>', r_left);
1032 else
1033 add_comparison(expr->left, SPECIAL_GTE, r_left);
1034 return;
1038 static void match_assign_divide(struct expression *expr)
1040 struct expression *right;
1041 struct expression *r_left, *r_right;
1042 sval_t min;
1044 right = strip_expr(expr->right);
1045 r_left = strip_expr(right->left);
1046 r_right = strip_expr(right->right);
1047 if (!get_implied_min(r_right, &min) || min.value <= 1)
1048 return;
1050 add_comparison(expr->left, '<', r_left);
1053 static void match_binop_assign(struct expression *expr)
1055 struct expression *right;
1057 right = strip_expr(expr->right);
1058 if (right->op == '+')
1059 match_assign_add(expr);
1060 if (right->op == '-')
1061 match_assign_sub(expr);
1062 if (right->op == '/')
1063 match_assign_divide(expr);
1066 static void copy_comparisons(struct expression *left, struct expression *right)
1068 struct string_list *links;
1069 struct smatch_state *state;
1070 struct compare_data *data;
1071 struct symbol *left_sym, *right_sym;
1072 char *left_var = NULL;
1073 char *right_var = NULL;
1074 struct var_sym_list *left_vsl;
1075 const char *var;
1076 struct var_sym_list *vsl;
1077 int comparison;
1078 char *tmp;
1080 left_var = chunk_to_var_sym(left, &left_sym);
1081 if (!left_var)
1082 goto done;
1083 left_vsl = expr_to_vsl(left);
1084 right_var = chunk_to_var_sym(right, &right_sym);
1085 if (!right_var)
1086 goto done;
1088 state = get_state(link_id, right_var, right_sym);
1089 if (!state)
1090 return;
1091 links = state->data;
1093 FOR_EACH_PTR(links, tmp) {
1094 state = get_state(compare_id, tmp, NULL);
1095 if (!state || !state->data)
1096 continue;
1097 data = state->data;
1098 comparison = data->comparison;
1099 var = data->var2;
1100 vsl = data->vsl2;
1101 if (strcmp(var, right_var) == 0) {
1102 var = data->var1;
1103 vsl = data->vsl1;
1104 comparison = flip_comparison(comparison);
1106 add_comparison_var_sym(left_var, left_vsl, comparison, var, vsl);
1107 } END_FOR_EACH_PTR(tmp);
1109 done:
1110 free_string(right_var);
1113 static void match_assign(struct expression *expr)
1115 struct expression *right;
1117 if (expr->op != '=')
1118 return;
1120 if (is_struct(expr->left))
1121 return;
1123 copy_comparisons(expr->left, expr->right);
1124 add_comparison(expr->left, SPECIAL_EQUAL, expr->right);
1126 right = strip_expr(expr->right);
1127 if (right->type == EXPR_BINOP)
1128 match_binop_assign(expr);
1131 int get_comparison_strings(const char *one, const char *two)
1133 char buf[256];
1134 struct smatch_state *state;
1135 int invert = 0;
1136 int ret = 0;
1138 if (strcmp(one, two) == 0)
1139 return SPECIAL_EQUAL;
1141 if (strcmp(one, two) > 0) {
1142 const char *tmp = one;
1144 one = two;
1145 two = tmp;
1146 invert = 1;
1149 snprintf(buf, sizeof(buf), "%s vs %s", one, two);
1150 state = get_state(compare_id, buf, NULL);
1151 if (state)
1152 ret = state_to_comparison(state);
1154 if (invert)
1155 ret = flip_comparison(ret);
1157 return ret;
1160 int get_comparison(struct expression *a, struct expression *b)
1162 char *one = NULL;
1163 char *two = NULL;
1164 int ret = 0;
1166 one = chunk_to_var(a);
1167 if (!one)
1168 goto free;
1169 two = chunk_to_var(b);
1170 if (!two)
1171 goto free;
1173 ret = get_comparison_strings(one, two);
1174 free:
1175 free_string(one);
1176 free_string(two);
1177 return ret;
1180 int possible_comparison(struct expression *a, int comparison, struct expression *b)
1182 char *one = NULL;
1183 char *two = NULL;
1184 int ret = 0;
1185 char buf[256];
1186 struct sm_state *sm;
1187 int saved;
1189 one = chunk_to_var(a);
1190 if (!one)
1191 goto free;
1192 two = chunk_to_var(b);
1193 if (!two)
1194 goto free;
1197 if (strcmp(one, two) == 0 && comparison == SPECIAL_EQUAL) {
1198 ret = 1;
1199 goto free;
1202 if (strcmp(one, two) > 0) {
1203 char *tmp = one;
1205 one = two;
1206 two = tmp;
1207 comparison = flip_comparison(comparison);
1210 snprintf(buf, sizeof(buf), "%s vs %s", one, two);
1211 sm = get_sm_state(compare_id, buf, NULL);
1212 if (!sm)
1213 goto free;
1215 FOR_EACH_PTR(sm->possible, sm) {
1216 if (!sm->state->data)
1217 continue;
1218 saved = ((struct compare_data *)sm->state->data)->comparison;
1219 if (saved == comparison)
1220 ret = 1;
1221 if (comparison == SPECIAL_EQUAL &&
1222 (saved == SPECIAL_LTE ||
1223 saved == SPECIAL_GTE ||
1224 saved == SPECIAL_UNSIGNED_LTE ||
1225 saved == SPECIAL_UNSIGNED_GTE))
1226 ret = 1;
1227 if (ret == 1)
1228 goto free;
1229 } END_FOR_EACH_PTR(sm);
1231 return ret;
1232 free:
1233 free_string(one);
1234 free_string(two);
1235 return ret;
1238 static void update_links_from_call(struct expression *left,
1239 int left_compare,
1240 struct expression *right)
1242 struct string_list *links;
1243 struct smatch_state *state;
1244 struct compare_data *data;
1245 struct symbol *left_sym, *right_sym;
1246 char *left_var = NULL;
1247 char *right_var = NULL;
1248 struct var_sym_list *left_vsl;
1249 const char *var;
1250 struct var_sym_list *vsl;
1251 int comparison;
1252 char *tmp;
1254 left_var = chunk_to_var_sym(left, &left_sym);
1255 if (!left_var)
1256 goto done;
1257 left_vsl = expr_to_vsl(left);
1258 right_var = chunk_to_var_sym(right, &right_sym);
1259 if (!right_var)
1260 goto done;
1262 state = get_state(link_id, right_var, right_sym);
1263 if (!state)
1264 return;
1265 links = state->data;
1267 FOR_EACH_PTR(links, tmp) {
1268 state = get_state(compare_id, tmp, NULL);
1269 if (!state || !state->data)
1270 continue;
1271 data = state->data;
1272 comparison = data->comparison;
1273 var = data->var2;
1274 vsl = data->vsl2;
1275 if (strcmp(var, right_var) == 0) {
1276 var = data->var1;
1277 vsl = data->vsl1;
1278 comparison = flip_comparison(comparison);
1280 comparison = combine_comparisons(left_compare, comparison);
1281 if (!comparison)
1282 continue;
1283 add_comparison_var_sym(left_var, left_vsl, comparison, var, vsl);
1284 } END_FOR_EACH_PTR(tmp);
1286 done:
1287 free_string(right_var);
/*
 * @range is parsed by str_to_comparison_arg() against @call.  When it
 * yields a comparison and an argument expression (presumably one of the
 * call's arguments), record "expr <comparison> arg" together with the
 * comparisons it implies through arg's existing links.
 */
void __add_comparison_info(struct expression *expr, struct expression *call, const char *range)
{
	struct expression *arg;
	int comparison;

	if (!str_to_comparison_arg(range, call, &comparison, &arg))
		return;

	update_links_from_call(expr, comparison, arg);
	add_comparison(expr, comparison, arg);
}
1302 static char *range_comparison_to_param_helper(struct expression *expr, char starts_with, int ignore)
1304 struct symbol *param;
1305 char *var = NULL;
1306 char buf[256];
1307 char *ret_str = NULL;
1308 int compare;
1309 int i;
1311 var = chunk_to_var(expr);
1312 if (!var)
1313 goto free;
1315 i = -1;
1316 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, param) {
1317 i++;
1318 if (i == ignore)
1319 continue;
1320 if (!param->ident)
1321 continue;
1322 snprintf(buf, sizeof(buf), "%s orig", param->ident->name);
1323 compare = get_comparison_strings(var, buf);
1324 if (!compare)
1325 continue;
1326 if (show_special(compare)[0] != starts_with)
1327 continue;
1328 snprintf(buf, sizeof(buf), "[%s$%d]", show_special(compare), i);
1329 ret_str = alloc_sname(buf);
1330 break;
1331 } END_FOR_EACH_PTR(param);
1333 free:
1334 free_string(var);
1335 return ret_str;
/*
 * "[==$N]" when @expr is recorded as equal to parameter N's original value
 * (parameter @ignore is skipped); NULL otherwise.
 */
char *expr_equal_to_param(struct expression *expr, int ignore)
{
	const char want = '=';

	return range_comparison_to_param_helper(expr, want, ignore);
}

/*
 * Same as expr_equal_to_param(), but for recorded comparisons whose operator
 * starts with '<' ("<", "<=").
 */
char *expr_lte_to_param(struct expression *expr, int ignore)
{
	const char want = '<';

	return range_comparison_to_param_helper(expr, want, ignore);
}
1348 static void free_data(struct symbol *sym)
1350 if (__inline_fn)
1351 return;
1352 clear_compare_data_alloc();
1355 void register_comparison(int id)
1357 compare_id = id;
1358 add_hook(&match_compare, CONDITION_HOOK);
1359 add_hook(&match_assign, ASSIGNMENT_HOOK);
1360 add_hook(&save_start_states, AFTER_DEF_HOOK);
1361 add_unmatched_state_hook(compare_id, unmatched_comparison);
1362 add_merge_hook(compare_id, &merge_compare_states);
1363 add_hook(&free_data, AFTER_FUNC_HOOK);
1366 void register_comparison_links(int id)
1368 link_id = id;
1369 add_merge_hook(link_id, &merge_links);
1370 add_modification_hook(link_id, &match_modify);