ranges: simplify and robustify str_to_rl_helper() a bit
[smatch.git] / smatch_comparison.c
blobc5a67583a86b1fa81bf93377584c290fd041c004
1 /*
2 * Copyright (C) 2012 Oracle.
4 * This program is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU General Public License
6 * as published by the Free Software Foundation; either version 2
7 * of the License, or (at your option) any later version.
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, see http://www.gnu.org/copyleft/gpl.txt
/*
 * The point here is to store the relationships between two variables,
 * i.e. "y > x".
 * To do that we create a state with the two variables in alphabetical order:
 * ->name = "x vs y" and the state would be "<".  On the false path the state
 * would be ">=".
 *
 * Part of the trick of it is that if x or y is modified then we need to reset
 * the state.  We need to keep a list of all the states which depend on x and
 * all the states which depend on y.  The link_id code handles this.
 */
31 #include "smatch.h"
32 #include "smatch_extra.h"
33 #include "smatch_slist.h"
/* Owner id used for the "x vs y" comparison states (set in register_comparison()). */
static int compare_id;
/* Owner id used for the per-variable link lists (set in register_comparison_links()). */
static int link_id;
38 struct compare_data {
39 const char *var1;
40 struct var_sym_list *vsl1;
41 int comparison;
42 const char *var2;
43 struct var_sym_list *vsl2;
45 ALLOCATOR(compare_data, "compare data");
/*
 * chunk_vsl_eq() - do two chunks name the same thing?
 *
 * Only the names are compared; the var_sym lists are accepted but not
 * looked at.  Returns 1 when the names match and 0 otherwise.
 */
int chunk_vsl_eq(const char *a, struct var_sym_list *a_vsl, const char *b, struct var_sym_list *b_vsl)
{
	return strcmp(a, b) == 0;
}
54 static struct symbol *vsl_to_sym(struct var_sym_list *vsl)
56 struct var_sym *vs;
58 if (!vsl)
59 return NULL;
60 if (ptr_list_size((struct ptr_list *)vsl) != 1)
61 return NULL;
62 vs = first_ptr_list((struct ptr_list *)vsl);
63 return vs->sym;
66 static struct smatch_state *alloc_compare_state(
67 const char *var1, struct var_sym_list *vsl1,
68 int comparison,
69 const char *var2, struct var_sym_list *vsl2)
71 struct smatch_state *state;
72 struct compare_data *data;
74 state = __alloc_smatch_state(0);
75 state->name = alloc_sname(show_special(comparison));
76 data = __alloc_compare_data(0);
77 data->var1 = alloc_sname(var1);
78 data->vsl1 = clone_var_sym_list(vsl1);
79 data->comparison = comparison;
80 data->var2 = alloc_sname(var2);
81 data->vsl2 = clone_var_sym_list(vsl2);
82 state->data = data;
83 return state;
86 static int state_to_comparison(struct smatch_state *state)
88 if (!state || !state->data)
89 return 0;
90 return ((struct compare_data *)state->data)->comparison;
94 * flip_op() reverses the op left and right. So "x >= y" becomes "y <= x".
96 static int flip_op(int op)
98 switch (op) {
99 case 0:
100 return 0;
101 case '<':
102 return '>';
103 case SPECIAL_UNSIGNED_LT:
104 return SPECIAL_UNSIGNED_GT;
105 case SPECIAL_LTE:
106 return SPECIAL_GTE;
107 case SPECIAL_UNSIGNED_LTE:
108 return SPECIAL_UNSIGNED_GTE;
109 case SPECIAL_EQUAL:
110 return SPECIAL_EQUAL;
111 case SPECIAL_NOTEQUAL:
112 return SPECIAL_NOTEQUAL;
113 case SPECIAL_GTE:
114 return SPECIAL_LTE;
115 case SPECIAL_UNSIGNED_GTE:
116 return SPECIAL_UNSIGNED_LTE;
117 case '>':
118 return '<';
119 case SPECIAL_UNSIGNED_GT:
120 return SPECIAL_UNSIGNED_LT;
121 default:
122 sm_msg("internal smatch bug. unhandled comparison %d", op);
123 return op;
127 static int falsify_op(int op)
129 switch (op) {
130 case 0:
131 return 0;
132 case '<':
133 return SPECIAL_GTE;
134 case SPECIAL_UNSIGNED_LT:
135 return SPECIAL_UNSIGNED_GTE;
136 case SPECIAL_LTE:
137 return '>';
138 case SPECIAL_UNSIGNED_LTE:
139 return SPECIAL_UNSIGNED_GT;
140 case SPECIAL_EQUAL:
141 return SPECIAL_NOTEQUAL;
142 case SPECIAL_NOTEQUAL:
143 return SPECIAL_EQUAL;
144 case SPECIAL_GTE:
145 return '<';
146 case SPECIAL_UNSIGNED_GTE:
147 return SPECIAL_UNSIGNED_LT;
148 case '>':
149 return SPECIAL_LTE;
150 case SPECIAL_UNSIGNED_GT:
151 return SPECIAL_UNSIGNED_LTE;
152 default:
153 sm_msg("internal smatch bug. unhandled comparison %d", op);
154 return op;
158 static int rl_comparison(struct range_list *left_rl, struct range_list *right_rl)
160 sval_t left_min, left_max, right_min, right_max;
162 if (!left_rl || !right_rl)
163 return 0;
165 left_min = rl_min(left_rl);
166 left_max = rl_max(left_rl);
167 right_min = rl_min(right_rl);
168 right_max = rl_max(right_rl);
170 if (left_min.value == left_max.value &&
171 right_min.value == right_max.value &&
172 left_min.value == right_min.value)
173 return SPECIAL_EQUAL;
175 if (sval_cmp(left_max, right_min) < 0)
176 return '<';
177 if (sval_cmp(left_max, right_min) == 0)
178 return SPECIAL_LTE;
179 if (sval_cmp(left_min, right_max) > 0)
180 return '>';
181 if (sval_cmp(left_min, right_max) == 0)
182 return SPECIAL_GTE;
184 return 0;
187 static struct range_list *get_orig_rl(struct var_sym_list *vsl)
189 struct symbol *sym;
190 struct smatch_state *state;
192 if (!vsl)
193 return NULL;
194 sym = vsl_to_sym(vsl);
195 if (!sym || !sym->ident)
196 return NULL;
197 state = get_orig_estate(sym->ident->name, sym);
198 return estate_rl(state);
201 static struct smatch_state *unmatched_comparison(struct sm_state *sm)
203 struct compare_data *data = sm->state->data;
204 struct range_list *left_rl, *right_rl;
205 int op;
207 if (!data)
208 return &undefined;
210 if (strstr(data->var1, " orig"))
211 left_rl = get_orig_rl(data->vsl1);
212 else if (!get_implied_rl_var_sym(data->var1, vsl_to_sym(data->vsl1), &left_rl))
213 return &undefined;
214 if (strstr(data->var2, " orig"))
215 right_rl = get_orig_rl(data->vsl2);
216 else if (!get_implied_rl_var_sym(data->var2, vsl_to_sym(data->vsl2), &right_rl))
217 return &undefined;
220 op = rl_comparison(left_rl, right_rl);
221 if (op)
222 return alloc_compare_state(data->var1, data->vsl1, op, data->var2, data->vsl2);
224 return &undefined;
227 /* remove_unsigned_from_comparison() is obviously a hack. */
228 static int remove_unsigned_from_comparison(int op)
230 switch (op) {
231 case SPECIAL_UNSIGNED_LT:
232 return '<';
233 case SPECIAL_UNSIGNED_LTE:
234 return SPECIAL_LTE;
235 case SPECIAL_UNSIGNED_GTE:
236 return SPECIAL_GTE;
237 case SPECIAL_UNSIGNED_GT:
238 return '>';
239 default:
240 return op;
245 * This is for when you merge states "a < b" and "a == b", the result is that
246 * we can say for sure, "a <= b" after the merge.
248 static int merge_comparisons(int one, int two)
250 int LT, EQ, GT;
252 one = remove_unsigned_from_comparison(one);
253 two = remove_unsigned_from_comparison(two);
255 LT = EQ = GT = 0;
257 switch (one) {
258 case '<':
259 LT = 1;
260 break;
261 case SPECIAL_LTE:
262 LT = 1;
263 EQ = 1;
264 break;
265 case SPECIAL_EQUAL:
266 EQ = 1;
267 break;
268 case SPECIAL_GTE:
269 GT = 1;
270 EQ = 1;
271 break;
272 case '>':
273 GT = 1;
276 switch (two) {
277 case '<':
278 LT = 1;
279 break;
280 case SPECIAL_LTE:
281 LT = 1;
282 EQ = 1;
283 break;
284 case SPECIAL_EQUAL:
285 EQ = 1;
286 break;
287 case SPECIAL_GTE:
288 GT = 1;
289 EQ = 1;
290 break;
291 case '>':
292 GT = 1;
295 if (LT && EQ && GT)
296 return 0;
297 if (LT && EQ)
298 return SPECIAL_LTE;
299 if (LT && GT)
300 return SPECIAL_NOTEQUAL;
301 if (LT)
302 return '<';
303 if (EQ && GT)
304 return SPECIAL_GTE;
305 if (GT)
306 return '>';
307 return 0;
311 * This is for if you have "a < b" and "b <= c". Or in other words,
312 * "a < b <= c". You would call this like get_combined_comparison('<', '<=').
313 * The return comparison would be '<'.
315 * This function is different from merge_comparisons(), for example:
316 * merge_comparison('<', '==') returns '<='
317 * get_combined_comparison('<', '==') returns '<'
319 static int combine_comparisons(int left_compare, int right_compare)
321 int LT, EQ, GT;
323 left_compare = remove_unsigned_from_comparison(left_compare);
324 right_compare = remove_unsigned_from_comparison(right_compare);
326 LT = EQ = GT = 0;
328 switch (left_compare) {
329 case '<':
330 LT++;
331 break;
332 case SPECIAL_LTE:
333 LT++;
334 EQ++;
335 break;
336 case SPECIAL_EQUAL:
337 return right_compare;
338 case SPECIAL_GTE:
339 GT++;
340 EQ++;
341 break;
342 case '>':
343 GT++;
346 switch (right_compare) {
347 case '<':
348 LT++;
349 break;
350 case SPECIAL_LTE:
351 LT++;
352 EQ++;
353 break;
354 case SPECIAL_EQUAL:
355 return left_compare;
356 case SPECIAL_GTE:
357 GT++;
358 EQ++;
359 break;
360 case '>':
361 GT++;
364 if (LT == 2) {
365 if (EQ == 2)
366 return SPECIAL_LTE;
367 return '<';
370 if (GT == 2) {
371 if (EQ == 2)
372 return SPECIAL_GTE;
373 return '>';
375 return 0;
378 static int filter_comparison(int orig, int op)
380 switch (orig) {
381 case 0:
382 return op;
383 case '<':
384 switch (op) {
385 case '<':
386 return op;
387 case SPECIAL_NOTEQUAL:
388 return '<';
390 return 0;
391 case SPECIAL_LTE:
392 switch (op) {
393 case '<':
394 case SPECIAL_LTE:
395 case SPECIAL_EQUAL:
396 return op;
397 case SPECIAL_NOTEQUAL:
398 return '<';
400 return 0;
401 case SPECIAL_EQUAL:
402 switch (op) {
403 case SPECIAL_LTE:
404 case SPECIAL_EQUAL:
405 case SPECIAL_GTE:
406 return SPECIAL_EQUAL;
408 return 0;
409 case SPECIAL_NOTEQUAL:
410 switch (op) {
411 case '<':
412 case SPECIAL_LTE:
413 return '<';
414 case SPECIAL_NOTEQUAL:
415 return op;
416 case '>':
417 case SPECIAL_GTE:
418 return '>';
420 return 0;
421 case SPECIAL_GTE:
422 switch (op) {
423 case '>':
424 case SPECIAL_GTE:
425 case SPECIAL_EQUAL:
426 return op;
427 case SPECIAL_NOTEQUAL:
428 return '>';
430 return 0;
431 case '>':
432 switch (op) {
433 case '>':
434 case SPECIAL_GTE:
435 return '>';
436 case SPECIAL_NOTEQUAL:
437 return '>';
439 return 0;
441 sm_msg("Internal: what did I forget? orig = %d op = '%s'", orig, show_special(op));
442 return 0;
445 static struct smatch_state *merge_compare_states(struct smatch_state *s1, struct smatch_state *s2)
447 struct compare_data *data = s1->data;
448 int op;
450 op = merge_comparisons(state_to_comparison(s1), state_to_comparison(s2));
451 if (op)
452 return alloc_compare_state(data->var1, data->vsl1, op, data->var2, data->vsl2);
453 return &undefined;
456 static struct smatch_state *alloc_link_state(struct string_list *links)
458 struct smatch_state *state;
459 static char buf[256];
460 char *tmp;
461 int i;
463 state = __alloc_smatch_state(0);
465 i = 0;
466 FOR_EACH_PTR(links, tmp) {
467 if (!i++) {
468 snprintf(buf, sizeof(buf), "%s", tmp);
469 } else {
470 append(buf, ", ", sizeof(buf));
471 append(buf, tmp, sizeof(buf));
473 } END_FOR_EACH_PTR(tmp);
475 state->name = alloc_sname(buf);
476 state->data = links;
477 return state;
480 static void save_start_states(struct statement *stmt)
482 struct symbol *param;
483 char orig[64];
484 char state_name[128];
485 struct smatch_state *state;
486 struct string_list *links;
487 char *link;
489 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, param) {
490 struct var_sym_list *vsl1 = NULL;
491 struct var_sym_list *vsl2 = NULL;
493 if (!param->ident)
494 continue;
495 snprintf(orig, sizeof(orig), "%s orig", param->ident->name);
496 snprintf(state_name, sizeof(state_name), "%s vs %s", param->ident->name, orig);
497 add_var_sym(&vsl1, param->ident->name, param);
498 add_var_sym(&vsl2, orig, param);
499 state = alloc_compare_state(param->ident->name, vsl1, SPECIAL_EQUAL, alloc_sname(orig), vsl2);
500 set_state(compare_id, state_name, NULL, state);
502 link = alloc_sname(state_name);
503 links = NULL;
504 insert_string(&links, link);
505 state = alloc_link_state(links);
506 set_state(link_id, param->ident->name, param, state);
507 } END_FOR_EACH_PTR(param);
510 static struct smatch_state *merge_links(struct smatch_state *s1, struct smatch_state *s2)
512 struct smatch_state *ret;
513 struct string_list *links;
515 links = combine_string_lists(s1->data, s2->data);
516 ret = alloc_link_state(links);
517 return ret;
520 static void save_link_var_sym(const char *var, struct symbol *sym, const char *link)
522 struct smatch_state *old_state, *new_state;
523 struct string_list *links;
524 char *new;
526 old_state = get_state(link_id, var, sym);
527 if (old_state)
528 links = clone_str_list(old_state->data);
529 else
530 links = NULL;
532 new = alloc_sname(link);
533 insert_string(&links, new);
535 new_state = alloc_link_state(links);
536 set_state(link_id, var, sym, new_state);
539 static void match_inc(struct sm_state *sm)
541 struct string_list *links;
542 struct smatch_state *state;
543 char *tmp;
545 links = sm->state->data;
547 FOR_EACH_PTR(links, tmp) {
548 state = get_state(compare_id, tmp, NULL);
550 switch (state_to_comparison(state)) {
551 case SPECIAL_EQUAL:
552 case SPECIAL_GTE:
553 case SPECIAL_UNSIGNED_GTE:
554 case '>':
555 case SPECIAL_UNSIGNED_GT: {
556 struct compare_data *data = state->data;
557 struct smatch_state *new;
559 new = alloc_compare_state(data->var1, data->vsl1, '>', data->var2, data->vsl2);
560 set_state(compare_id, tmp, NULL, new);
561 break;
563 default:
564 set_state(compare_id, tmp, NULL, &undefined);
566 } END_FOR_EACH_PTR(tmp);
569 static void match_dec(struct sm_state *sm)
571 struct string_list *links;
572 struct smatch_state *state;
573 char *tmp;
575 links = sm->state->data;
577 FOR_EACH_PTR(links, tmp) {
578 state = get_state(compare_id, tmp, NULL);
580 switch (state_to_comparison(state)) {
581 case SPECIAL_EQUAL:
582 case SPECIAL_LTE:
583 case SPECIAL_UNSIGNED_LTE:
584 case '<':
585 case SPECIAL_UNSIGNED_LT: {
586 struct compare_data *data = state->data;
587 struct smatch_state *new;
589 new = alloc_compare_state(data->var1, data->vsl1, '<', data->var2, data->vsl2);
590 set_state(compare_id, tmp, NULL, new);
591 break;
593 default:
594 set_state(compare_id, tmp, NULL, &undefined);
596 } END_FOR_EACH_PTR(tmp);
599 static int match_inc_dec(struct sm_state *sm, struct expression *mod_expr)
601 if (!mod_expr)
602 return 0;
603 if (mod_expr->type != EXPR_PREOP && mod_expr->type != EXPR_POSTOP)
604 return 0;
606 if (mod_expr->op == SPECIAL_INCREMENT) {
607 match_inc(sm);
608 return 1;
610 if (mod_expr->op == SPECIAL_DECREMENT) {
611 match_dec(sm);
612 return 1;
614 return 0;
617 static void match_modify(struct sm_state *sm, struct expression *mod_expr)
619 struct string_list *links;
620 char *tmp;
622 /* Huh??? This needs a comment! */
623 if (match_inc_dec(sm, mod_expr))
624 return;
626 links = sm->state->data;
628 FOR_EACH_PTR(links, tmp) {
629 set_state(compare_id, tmp, NULL, &undefined);
630 } END_FOR_EACH_PTR(tmp);
631 set_state(link_id, sm->name, sm->sym, &undefined);
634 static char *chunk_to_var_sym(struct expression *expr, struct symbol **sym)
636 char *name, *left_name, *right_name;
637 struct symbol *tmp;
638 char buf[128];
640 expr = strip_expr(expr);
641 if (!expr)
642 return NULL;
643 if (sym)
644 *sym = NULL;
646 name = expr_to_var_sym(expr, &tmp);
647 if (name && tmp) {
648 if (sym)
649 *sym = tmp;
650 return name;
652 if (name)
653 free_string(name);
655 if (expr->type != EXPR_BINOP)
656 return NULL;
657 if (expr->op != '-' && expr->op != '+')
658 return NULL;
660 left_name = expr_to_var(expr->left);
661 if (!left_name)
662 return NULL;
663 right_name = expr_to_var(expr->right);
664 if (!right_name) {
665 free_string(left_name);
666 return NULL;
668 snprintf(buf, sizeof(buf), "%s %s %s", left_name, show_special(expr->op), right_name);
669 free_string(left_name);
670 free_string(right_name);
671 return alloc_string(buf);
674 static char *chunk_to_var(struct expression *expr)
676 return chunk_to_var_sym(expr, NULL);
679 static void save_link(struct expression *expr, char *link)
681 char *var;
682 struct symbol *sym;
684 expr = strip_expr(expr);
685 if (expr->type == EXPR_BINOP) {
686 char *chunk;
688 chunk = chunk_to_var(expr);
689 if (!chunk)
690 return;
692 save_link(expr->left, link);
693 save_link(expr->right, link);
694 save_link_var_sym(chunk, NULL, link);
695 return;
698 var = expr_to_var_sym(expr, &sym);
699 if (!var || !sym)
700 goto done;
702 save_link_var_sym(var, sym, link);
703 done:
704 free_string(var);
707 static void update_tf_links(struct stree *pre_stree,
708 const char *left_var, struct var_sym_list *left_vsl,
709 int left_comparison,
710 const char *mid_var, struct var_sym_list *mid_vsl,
711 struct string_list *links)
713 struct smatch_state *state;
714 struct smatch_state *true_state, *false_state;
715 struct compare_data *data;
716 const char *right_var;
717 struct var_sym_list *right_vsl;
718 int right_comparison;
719 int true_comparison;
720 int false_comparison;
721 char *tmp;
722 char state_name[256];
723 struct var_sym *vs;
725 FOR_EACH_PTR(links, tmp) {
726 state = get_state_stree(pre_stree, compare_id, tmp, NULL);
727 if (!state || !state->data)
728 continue;
729 data = state->data;
730 right_comparison = data->comparison;
731 right_var = data->var2;
732 right_vsl = data->vsl2;
733 if (chunk_vsl_eq(mid_var, mid_vsl, right_var, right_vsl)) {
734 right_var = data->var1;
735 right_vsl = data->vsl1;
736 right_comparison = flip_op(right_comparison);
738 true_comparison = combine_comparisons(left_comparison, right_comparison);
739 false_comparison = combine_comparisons(falsify_op(left_comparison), right_comparison);
741 if (strcmp(left_var, right_var) > 0) {
742 const char *tmp_var = left_var;
743 struct var_sym_list *tmp_vsl = left_vsl;
745 left_var = right_var;
746 left_vsl = right_vsl;
747 right_var = tmp_var;
748 right_vsl = tmp_vsl;
749 true_comparison = flip_op(true_comparison);
750 false_comparison = flip_op(false_comparison);
753 if (!true_comparison && !false_comparison)
754 continue;
756 if (true_comparison)
757 true_state = alloc_compare_state(left_var, left_vsl, true_comparison, right_var, right_vsl);
758 else
759 true_state = NULL;
760 if (false_comparison)
761 false_state = alloc_compare_state(left_var, left_vsl, false_comparison, right_var, right_vsl);
762 else
763 false_state = NULL;
765 snprintf(state_name, sizeof(state_name), "%s vs %s", left_var, right_var);
766 set_true_false_states(compare_id, state_name, NULL, true_state, false_state);
767 FOR_EACH_PTR(left_vsl, vs) {
768 save_link_var_sym(vs->var, vs->sym, state_name);
769 } END_FOR_EACH_PTR(vs);
770 FOR_EACH_PTR(right_vsl, vs) {
771 save_link_var_sym(vs->var, vs->sym, state_name);
772 } END_FOR_EACH_PTR(vs);
773 if (!vsl_to_sym(left_vsl))
774 save_link_var_sym(left_var, NULL, state_name);
775 if (!vsl_to_sym(right_vsl))
776 save_link_var_sym(right_var, NULL, state_name);
777 } END_FOR_EACH_PTR(tmp);
780 static void update_tf_data(struct stree *pre_stree,
781 const char *left_name, struct symbol *left_sym,
782 const char *right_name, struct symbol *right_sym,
783 struct compare_data *tdata)
785 struct smatch_state *state;
787 state = get_state_stree(pre_stree, link_id, tdata->var2, vsl_to_sym(tdata->vsl2));
788 if (state)
789 update_tf_links(pre_stree, tdata->var1, tdata->vsl1, tdata->comparison, tdata->var2, tdata->vsl2, state->data);
791 state = get_state_stree(pre_stree, link_id, tdata->var1, vsl_to_sym(tdata->vsl1));
792 if (state)
793 update_tf_links(pre_stree, tdata->var2, tdata->vsl2, flip_op(tdata->comparison), tdata->var1, tdata->vsl1, state->data);
796 static void match_compare(struct expression *expr)
798 char *left = NULL;
799 char *right = NULL;
800 struct symbol *left_sym, *right_sym;
801 struct var_sym_list *left_vsl, *right_vsl;
802 int op, false_op;
803 int orig_comparison;
804 struct smatch_state *true_state, *false_state;
805 char state_name[256];
806 struct stree *pre_stree;
808 if (expr->type != EXPR_COMPARE)
809 return;
810 left = chunk_to_var_sym(expr->left, &left_sym);
811 if (!left)
812 goto free;
813 left_vsl = expr_to_vsl(expr->left);
814 right = chunk_to_var_sym(expr->right, &right_sym);
815 if (!right)
816 goto free;
817 right_vsl = expr_to_vsl(expr->right);
819 if (strcmp(left, right) > 0) {
820 struct symbol *tmp_sym = left_sym;
821 char *tmp_name = left;
822 struct var_sym_list *tmp_vsl = left_vsl;
824 left = right;
825 left_sym = right_sym;
826 left_vsl = right_vsl;
827 right = tmp_name;
828 right_sym = tmp_sym;
829 right_vsl = tmp_vsl;
830 op = flip_op(expr->op);
831 } else {
832 op = expr->op;
834 false_op = falsify_op(op);
836 orig_comparison = get_comparison_strings(left, right);
837 op = filter_comparison(orig_comparison, op);
838 false_op = filter_comparison(orig_comparison, false_op);
840 snprintf(state_name, sizeof(state_name), "%s vs %s", left, right);
841 true_state = alloc_compare_state(left, left_vsl, op, right, right_vsl);
842 false_state = alloc_compare_state(left, left_vsl, false_op, right, right_vsl);
844 pre_stree = clone_stree(__get_cur_stree());
845 update_tf_data(pre_stree, left, left_sym, right, right_sym, true_state->data);
846 free_stree(&pre_stree);
848 set_true_false_states(compare_id, state_name, NULL, true_state, false_state);
849 save_link(expr->left, state_name);
850 save_link(expr->right, state_name);
851 free:
852 free_string(left);
853 free_string(right);
856 static void add_comparison_var_sym(const char *left_name,
857 struct var_sym_list *left_vsl,
858 int comparison,
859 const char *right_name, struct var_sym_list *right_vsl)
861 struct smatch_state *state;
862 struct var_sym *vs;
863 char state_name[256];
865 if (strcmp(left_name, right_name) > 0) {
866 const char *tmp_name = left_name;
867 struct var_sym_list *tmp_vsl = left_vsl;
869 left_name = right_name;
870 left_vsl = right_vsl;
871 right_name = tmp_name;
872 right_vsl = tmp_vsl;
873 comparison = flip_op(comparison);
875 snprintf(state_name, sizeof(state_name), "%s vs %s", left_name, right_name);
876 state = alloc_compare_state(left_name, left_vsl, comparison, right_name, right_vsl);
878 set_state(compare_id, state_name, NULL, state);
880 FOR_EACH_PTR(left_vsl, vs) {
881 save_link_var_sym(vs->var, vs->sym, state_name);
882 } END_FOR_EACH_PTR(vs);
883 FOR_EACH_PTR(right_vsl, vs) {
884 save_link_var_sym(vs->var, vs->sym, state_name);
885 } END_FOR_EACH_PTR(vs);
888 static void add_comparison(struct expression *left, int comparison, struct expression *right)
890 char *left_name = NULL;
891 char *right_name = NULL;
892 struct symbol *left_sym, *right_sym;
893 struct var_sym_list *left_vsl, *right_vsl;
894 struct smatch_state *state;
895 char state_name[256];
897 left_name = chunk_to_var_sym(left, &left_sym);
898 if (!left_name)
899 goto free;
900 left_vsl = expr_to_vsl(left);
901 right_name = chunk_to_var_sym(right, &right_sym);
902 if (!right_name)
903 goto free;
904 right_vsl = expr_to_vsl(right);
906 if (strcmp(left_name, right_name) > 0) {
907 struct symbol *tmp_sym = left_sym;
908 char *tmp_name = left_name;
909 struct var_sym_list *tmp_vsl = left_vsl;
911 left_name = right_name;
912 left_sym = right_sym;
913 left_vsl = right_vsl;
914 right_name = tmp_name;
915 right_sym = tmp_sym;
916 right_vsl = tmp_vsl;
917 comparison = flip_op(comparison);
919 snprintf(state_name, sizeof(state_name), "%s vs %s", left_name, right_name);
920 state = alloc_compare_state(left_name, left_vsl, comparison, right_name, right_vsl);
922 set_state(compare_id, state_name, NULL, state);
923 save_link(left, state_name);
924 save_link(right, state_name);
926 free:
927 free_string(left_name);
928 free_string(right_name);
931 static void match_assign_add(struct expression *expr)
933 struct expression *right;
934 struct expression *r_left, *r_right;
935 sval_t left_tmp, right_tmp;
937 right = strip_expr(expr->right);
938 r_left = strip_expr(right->left);
939 r_right = strip_expr(right->right);
941 get_absolute_min(r_left, &left_tmp);
942 get_absolute_min(r_right, &right_tmp);
944 if (left_tmp.value > 0)
945 add_comparison(expr->left, '>', r_right);
946 else if (left_tmp.value == 0)
947 add_comparison(expr->left, SPECIAL_GTE, r_right);
949 if (right_tmp.value > 0)
950 add_comparison(expr->left, '>', r_left);
951 else if (right_tmp.value == 0)
952 add_comparison(expr->left, SPECIAL_GTE, r_left);
955 static void match_assign_sub(struct expression *expr)
957 struct expression *right;
958 struct expression *r_left, *r_right;
959 int comparison;
960 sval_t min;
962 right = strip_expr(expr->right);
963 r_left = strip_expr(right->left);
964 r_right = strip_expr(right->right);
966 if (get_absolute_min(r_right, &min) && sval_is_negative(min))
967 return;
969 comparison = get_comparison(r_left, r_right);
971 switch (comparison) {
972 case '>':
973 case SPECIAL_GTE:
974 if (implied_not_equal(r_right, 0))
975 add_comparison(expr->left, '>', r_left);
976 else
977 add_comparison(expr->left, SPECIAL_GTE, r_left);
978 return;
982 static void match_assign_divide(struct expression *expr)
984 struct expression *right;
985 struct expression *r_left, *r_right;
986 sval_t min;
988 right = strip_expr(expr->right);
989 r_left = strip_expr(right->left);
990 r_right = strip_expr(right->right);
991 if (!get_implied_min(r_right, &min) || min.value <= 1)
992 return;
994 add_comparison(expr->left, '<', r_left);
997 static void match_binop_assign(struct expression *expr)
999 struct expression *right;
1001 right = strip_expr(expr->right);
1002 if (right->op == '+')
1003 match_assign_add(expr);
1004 if (right->op == '-')
1005 match_assign_sub(expr);
1006 if (right->op == '/')
1007 match_assign_divide(expr);
1010 static void copy_comparisons(struct expression *left, struct expression *right)
1012 struct string_list *links;
1013 struct smatch_state *state;
1014 struct compare_data *data;
1015 struct symbol *left_sym, *right_sym;
1016 char *left_var = NULL;
1017 char *right_var = NULL;
1018 struct var_sym_list *left_vsl;
1019 const char *var;
1020 struct var_sym_list *vsl;
1021 int comparison;
1022 char *tmp;
1024 left_var = chunk_to_var_sym(left, &left_sym);
1025 if (!left_var)
1026 goto done;
1027 left_vsl = expr_to_vsl(left);
1028 right_var = chunk_to_var_sym(right, &right_sym);
1029 if (!right_var)
1030 goto done;
1032 state = get_state(link_id, right_var, right_sym);
1033 if (!state)
1034 return;
1035 links = state->data;
1037 FOR_EACH_PTR(links, tmp) {
1038 state = get_state(compare_id, tmp, NULL);
1039 if (!state || !state->data)
1040 continue;
1041 data = state->data;
1042 comparison = data->comparison;
1043 var = data->var2;
1044 vsl = data->vsl2;
1045 if (chunk_vsl_eq(var, vsl, right_var, NULL)) {
1046 var = data->var1;
1047 vsl = data->vsl1;
1048 comparison = flip_op(comparison);
1050 add_comparison_var_sym(left_var, left_vsl, comparison, var, vsl);
1051 } END_FOR_EACH_PTR(tmp);
1053 done:
1054 free_string(right_var);
1057 static void match_assign(struct expression *expr)
1059 struct expression *right;
1061 if (expr->op != '=')
1062 return;
1064 if (is_struct(expr->left))
1065 return;
1067 copy_comparisons(expr->left, expr->right);
1068 add_comparison(expr->left, SPECIAL_EQUAL, expr->right);
1070 right = strip_expr(expr->right);
1071 if (right->type == EXPR_BINOP)
1072 match_binop_assign(expr);
1075 int get_comparison_strings(const char *one, const char *two)
1077 char buf[256];
1078 struct smatch_state *state;
1079 int invert = 0;
1080 int ret = 0;
1082 if (strcmp(one, two) == 0)
1083 return SPECIAL_EQUAL;
1085 if (strcmp(one, two) > 0) {
1086 const char *tmp = one;
1088 one = two;
1089 two = tmp;
1090 invert = 1;
1093 snprintf(buf, sizeof(buf), "%s vs %s", one, two);
1094 state = get_state(compare_id, buf, NULL);
1095 if (state)
1096 ret = state_to_comparison(state);
1098 if (invert)
1099 ret = flip_op(ret);
1101 return ret;
1104 int get_comparison(struct expression *a, struct expression *b)
1106 char *one = NULL;
1107 char *two = NULL;
1108 int ret = 0;
1110 one = chunk_to_var(a);
1111 if (!one)
1112 goto free;
1113 two = chunk_to_var(b);
1114 if (!two)
1115 goto free;
1117 ret = get_comparison_strings(one, two);
1118 free:
1119 free_string(one);
1120 free_string(two);
1121 return ret;
1124 int possible_comparison(struct expression *a, int comparison, struct expression *b)
1126 char *one = NULL;
1127 char *two = NULL;
1128 int ret = 0;
1129 char buf[256];
1130 struct sm_state *sm;
1131 int saved;
1133 one = chunk_to_var(a);
1134 if (!one)
1135 goto free;
1136 two = chunk_to_var(b);
1137 if (!two)
1138 goto free;
1141 if (strcmp(one, two) == 0 && comparison == SPECIAL_EQUAL) {
1142 ret = 1;
1143 goto free;
1146 if (strcmp(one, two) > 0) {
1147 char *tmp = one;
1149 one = two;
1150 two = tmp;
1151 comparison = flip_op(comparison);
1154 snprintf(buf, sizeof(buf), "%s vs %s", one, two);
1155 sm = get_sm_state(compare_id, buf, NULL);
1156 if (!sm)
1157 goto free;
1159 FOR_EACH_PTR(sm->possible, sm) {
1160 if (!sm->state->data)
1161 continue;
1162 saved = ((struct compare_data *)sm->state->data)->comparison;
1163 if (saved == comparison)
1164 ret = 1;
1165 if (comparison == SPECIAL_EQUAL &&
1166 (saved == SPECIAL_LTE ||
1167 saved == SPECIAL_GTE ||
1168 saved == SPECIAL_UNSIGNED_LTE ||
1169 saved == SPECIAL_UNSIGNED_GTE))
1170 ret = 1;
1171 if (ret == 1)
1172 goto free;
1173 } END_FOR_EACH_PTR(sm);
1175 return ret;
1176 free:
1177 free_string(one);
1178 free_string(two);
1179 return ret;
1182 static void update_links_from_call(struct expression *left,
1183 int left_compare,
1184 struct expression *right)
1186 struct string_list *links;
1187 struct smatch_state *state;
1188 struct compare_data *data;
1189 struct symbol *left_sym, *right_sym;
1190 char *left_var = NULL;
1191 char *right_var = NULL;
1192 struct var_sym_list *left_vsl;
1193 const char *var;
1194 struct var_sym_list *vsl;
1195 int comparison;
1196 char *tmp;
1198 left_var = chunk_to_var_sym(left, &left_sym);
1199 if (!left_var)
1200 goto done;
1201 left_vsl = expr_to_vsl(left);
1202 right_var = chunk_to_var_sym(right, &right_sym);
1203 if (!right_var)
1204 goto done;
1206 state = get_state(link_id, right_var, right_sym);
1207 if (!state)
1208 return;
1209 links = state->data;
1211 FOR_EACH_PTR(links, tmp) {
1212 state = get_state(compare_id, tmp, NULL);
1213 if (!state || !state->data)
1214 continue;
1215 data = state->data;
1216 comparison = data->comparison;
1217 var = data->var2;
1218 vsl = data->vsl2;
1219 if (chunk_vsl_eq(var, vsl, right_var, NULL)) {
1220 var = data->var1;
1221 vsl = data->vsl1;
1222 comparison = flip_op(comparison);
1224 comparison = combine_comparisons(left_compare, comparison);
1225 if (!comparison)
1226 continue;
1227 add_comparison_var_sym(left_var, left_vsl, comparison, var, vsl);
1228 } END_FOR_EACH_PTR(tmp);
1230 done:
1231 free_string(right_var);
/*
 * __add_comparison_info() - apply a comparison described by @range.
 *
 * @range is parsed by str_to_comparison_arg() into a comparison and an
 * argument of @call.  On success @expr is recorded as
 * "<comparison> that argument", after first extending the argument's
 * existing comparisons to @expr.
 */
void __add_comparison_info(struct expression *expr, struct expression *call, const char *range)
{
	struct expression *arg;
	int comparison;

	if (!str_to_comparison_arg(range, call, &comparison, &arg))
		return;

	update_links_from_call(expr, comparison, arg);
	add_comparison(expr, comparison, arg);
}
1246 static char *range_comparison_to_param_helper(struct expression *expr, char starts_with, int ignore)
1248 struct symbol *param;
1249 char *var = NULL;
1250 char buf[256];
1251 char *ret_str = NULL;
1252 int compare;
1253 int i;
1255 var = chunk_to_var(expr);
1256 if (!var)
1257 goto free;
1259 i = -1;
1260 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, param) {
1261 i++;
1262 if (i == ignore)
1263 continue;
1264 if (!param->ident)
1265 continue;
1266 snprintf(buf, sizeof(buf), "%s orig", param->ident->name);
1267 compare = get_comparison_strings(var, buf);
1268 if (!compare)
1269 continue;
1270 if (show_special(compare)[0] != starts_with)
1271 continue;
1272 snprintf(buf, sizeof(buf), "[%s$%d]", show_special(compare), i);
1273 ret_str = alloc_sname(buf);
1274 break;
1275 } END_FOR_EACH_PTR(param);
1277 free:
1278 free_string(var);
1279 return ret_str;
/*
 * expr_equal_to_param() - describe @expr via a parameter comparison
 * whose printable form starts with '='.  See
 * range_comparison_to_param_helper() for the return value.
 */
char *expr_equal_to_param(struct expression *expr, int ignore)
{
	const char first = '=';

	return range_comparison_to_param_helper(expr, first, ignore);
}
/*
 * expr_lte_to_param() - describe @expr via a parameter comparison
 * whose printable form starts with '<'.  See
 * range_comparison_to_param_helper() for the return value.
 */
char *expr_lte_to_param(struct expression *expr, int ignore)
{
	const char first = '<';

	return range_comparison_to_param_helper(expr, first, ignore);
}
1292 static void free_data(struct symbol *sym)
1294 if (__inline_fn)
1295 return;
1296 clear_compare_data_alloc();
1299 void register_comparison(int id)
1301 compare_id = id;
1302 add_hook(&match_compare, CONDITION_HOOK);
1303 add_hook(&match_assign, ASSIGNMENT_HOOK);
1304 add_hook(&save_start_states, AFTER_DEF_HOOK);
1305 add_unmatched_state_hook(compare_id, unmatched_comparison);
1306 add_merge_hook(compare_id, &merge_compare_states);
1307 add_hook(&free_data, AFTER_FUNC_HOOK);
1310 void register_comparison_links(int id)
1312 link_id = id;
1313 add_merge_hook(link_id, &merge_links);
1314 add_modification_hook(link_id, &match_modify);