6c136e51f0978f2d15ea7e51650914c2d6ef3ba3
[smatch.git] / smatch_comparison.c
blob6c136e51f0978f2d15ea7e51650914c2d6ef3ba3
/*
 * smatch/smatch_comparison.c
 *
 * Copyright (C) 2012 Oracle.
 *
 * Licensed under the Open Software License version 1.1
 *
 */

/*
 * The point here is to store the relationships between two variables.
 * I.e.:  y > x.
 * To do that we create a state with the two variables in alphabetical order:
 * ->name = "x vs y" and the state would be "<".  On the false path the state
 * would be ">=".
 *
 * Part of the trick of it is that if x or y is modified then we need to reset
 * the state.  We need to keep a list of all the states which depend on x and
 * all the states which depend on y.  The link_id code handles this.
 *
 * Future work:  If we know that x is greater than y and y is greater than z
 * then we know that x is greater than z.
 */
25 #include "smatch.h"
26 #include "smatch_extra.h"
27 #include "smatch_slist.h"
29 static int compare_id;
30 static int link_id;
32 struct compare_data {
33 const char *var1;
34 struct symbol *sym1;
35 const char *var2;
36 struct symbol *sym2;
37 int comparison;
39 ALLOCATOR(compare_data, "compare data");
41 static struct smatch_state *alloc_compare_state(
42 const char *var1, struct symbol *sym1,
43 const char *var2, struct symbol *sym2,
44 int comparison)
46 struct smatch_state *state;
47 struct compare_data *data;
49 state = __alloc_smatch_state(0);
50 state->name = alloc_sname(show_special(comparison));
51 data = __alloc_compare_data(0);
52 data->var1 = alloc_sname(var1);
53 data->sym1 = sym1;
54 data->var2 = alloc_sname(var2);
55 data->sym2 = sym2;
56 data->comparison = comparison;
57 state->data = data;
58 return state;
61 static int state_to_comparison(struct smatch_state *state)
63 if (!state || !state->data)
64 return 0;
65 return ((struct compare_data *)state->data)->comparison;
69 * flip_op() reverses the op left and right. So "x >= y" becomes "y <= x".
71 static int flip_op(int op)
73 switch (op) {
74 case 0:
75 return 0;
76 case '<':
77 return '>';
78 case SPECIAL_UNSIGNED_LT:
79 return SPECIAL_UNSIGNED_GT;
80 case SPECIAL_LTE:
81 return SPECIAL_GTE;
82 case SPECIAL_UNSIGNED_LTE:
83 return SPECIAL_UNSIGNED_GTE;
84 case SPECIAL_EQUAL:
85 return SPECIAL_EQUAL;
86 case SPECIAL_NOTEQUAL:
87 return SPECIAL_NOTEQUAL;
88 case SPECIAL_GTE:
89 return SPECIAL_LTE;
90 case SPECIAL_UNSIGNED_GTE:
91 return SPECIAL_UNSIGNED_LTE;
92 case '>':
93 return '<';
94 case SPECIAL_UNSIGNED_GT:
95 return SPECIAL_UNSIGNED_LT;
96 default:
97 sm_msg("internal smatch bug. unhandled comparison %d", op);
98 return op;
102 static int falsify_op(int op)
104 switch (op) {
105 case 0:
106 return 0;
107 case '<':
108 return SPECIAL_GTE;
109 case SPECIAL_UNSIGNED_LT:
110 return SPECIAL_UNSIGNED_GTE;
111 case SPECIAL_LTE:
112 return '>';
113 case SPECIAL_UNSIGNED_LTE:
114 return SPECIAL_UNSIGNED_GT;
115 case SPECIAL_EQUAL:
116 return SPECIAL_NOTEQUAL;
117 case SPECIAL_NOTEQUAL:
118 return SPECIAL_EQUAL;
119 case SPECIAL_GTE:
120 return '<';
121 case SPECIAL_UNSIGNED_GTE:
122 return SPECIAL_UNSIGNED_LT;
123 case '>':
124 return SPECIAL_LTE;
125 case SPECIAL_UNSIGNED_GT:
126 return SPECIAL_UNSIGNED_LTE;
127 default:
128 sm_msg("internal smatch bug. unhandled comparison %d", op);
129 return op;
133 static int rl_comparison(struct range_list *left_rl, struct range_list *right_rl)
135 sval_t left_min, left_max, right_min, right_max;
137 if (!left_rl || !right_rl)
138 return 0;
140 left_min = rl_min(left_rl);
141 left_max = rl_max(left_rl);
142 right_min = rl_min(right_rl);
143 right_max = rl_max(right_rl);
145 if (left_min.value == left_max.value &&
146 right_min.value == right_max.value &&
147 left_min.value == right_min.value)
148 return SPECIAL_EQUAL;
150 if (sval_cmp(left_max, right_min) < 0)
151 return '<';
152 if (sval_cmp(left_max, right_min) == 0)
153 return SPECIAL_LTE;
154 if (sval_cmp(left_min, right_max) > 0)
155 return '>';
156 if (sval_cmp(left_min, right_max) == 0)
157 return SPECIAL_GTE;
159 return 0;
162 static struct range_list *get_orig_rl(struct symbol *sym)
164 struct smatch_state *state;
166 if (!sym || !sym->ident)
167 return NULL;
168 state = get_orig_estate(sym->ident->name, sym);
169 return estate_rl(state);
172 static struct smatch_state *unmatched_comparison(struct sm_state *sm)
174 struct compare_data *data = sm->state->data;
175 struct range_list *left_rl, *right_rl;
176 int op;
178 if (!data)
179 return &undefined;
181 if (strstr(data->var1, " orig"))
182 left_rl = get_orig_rl(data->sym1);
183 else if (!get_implied_rl_var_sym(data->var1, data->sym1, &left_rl))
184 return &undefined;
185 if (strstr(data->var2, " orig"))
186 right_rl = get_orig_rl(data->sym2);
187 else if (!get_implied_rl_var_sym(data->var2, data->sym2, &right_rl))
188 return &undefined;
190 op = rl_comparison(left_rl, right_rl);
191 if (op)
192 return alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, op);
194 return &undefined;
197 /* remove_unsigned_from_comparison() is obviously a hack. */
198 static int remove_unsigned_from_comparison(int op)
200 switch (op) {
201 case SPECIAL_UNSIGNED_LT:
202 return '<';
203 case SPECIAL_UNSIGNED_LTE:
204 return SPECIAL_LTE;
205 case SPECIAL_UNSIGNED_GTE:
206 return SPECIAL_GTE;
207 case SPECIAL_UNSIGNED_GT:
208 return '>';
209 default:
210 return op;
214 static int merge_comparisons(int one, int two)
216 int LT, EQ, GT;
218 one = remove_unsigned_from_comparison(one);
219 two = remove_unsigned_from_comparison(two);
221 LT = EQ = GT = 0;
223 switch (one) {
224 case '<':
225 LT = 1;
226 break;
227 case SPECIAL_LTE:
228 LT = 1;
229 EQ = 1;
230 break;
231 case SPECIAL_EQUAL:
232 EQ = 1;
233 break;
234 case SPECIAL_GTE:
235 GT = 1;
236 EQ = 1;
237 break;
238 case '>':
239 GT = 1;
242 switch (two) {
243 case '<':
244 LT = 1;
245 break;
246 case SPECIAL_LTE:
247 LT = 1;
248 EQ = 1;
249 break;
250 case SPECIAL_EQUAL:
251 EQ = 1;
252 break;
253 case SPECIAL_GTE:
254 GT = 1;
255 EQ = 1;
256 break;
257 case '>':
258 GT = 1;
261 if (LT && EQ && GT)
262 return 0;
263 if (LT && EQ)
264 return SPECIAL_LTE;
265 if (LT && GT)
266 return SPECIAL_NOTEQUAL;
267 if (LT)
268 return '<';
269 if (EQ && GT)
270 return SPECIAL_GTE;
271 if (GT)
272 return '>';
273 return 0;
276 static struct smatch_state *merge_compare_states(struct smatch_state *s1, struct smatch_state *s2)
278 struct compare_data *data = s1->data;
279 int op;
281 op = merge_comparisons(state_to_comparison(s1), state_to_comparison(s2));
282 if (op)
283 return alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, op);
284 return &undefined;
287 struct smatch_state *alloc_link_state(struct string_list *links)
289 struct smatch_state *state;
290 static char buf[256];
291 char *tmp;
292 int i;
294 state = __alloc_smatch_state(0);
296 i = 0;
297 FOR_EACH_PTR(links, tmp) {
298 if (!i++)
299 snprintf(buf, sizeof(buf), "%s", tmp);
300 else
301 snprintf(buf, sizeof(buf), "%s, %s", buf, tmp);
302 } END_FOR_EACH_PTR(tmp);
304 state->name = alloc_sname(buf);
305 state->data = links;
306 return state;
309 static void save_start_states(struct statement *stmt)
311 struct symbol *param;
312 char orig[64];
313 char state_name[128];
314 struct smatch_state *state;
315 struct string_list *links;
316 char *link;
318 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, param) {
319 if (!param->ident)
320 continue;
321 snprintf(orig, sizeof(orig), "%s orig", param->ident->name);
322 snprintf(state_name, sizeof(state_name), "%s vs %s", param->ident->name, orig);
323 state = alloc_compare_state(param->ident->name, param, alloc_sname(orig), param, SPECIAL_EQUAL);
324 set_state(compare_id, state_name, NULL, state);
326 link = alloc_sname(state_name);
327 links = NULL;
328 insert_string(&links, link);
329 state = alloc_link_state(links);
330 set_state(link_id, param->ident->name, param, state);
331 } END_FOR_EACH_PTR(param);
334 static struct smatch_state *merge_links(struct smatch_state *s1, struct smatch_state *s2)
336 struct smatch_state *ret;
337 struct string_list *links;
339 links = combine_string_lists(s1->data, s2->data);
340 ret = alloc_link_state(links);
341 return ret;
344 static void save_link_var_sym(const char *var, struct symbol *sym, char *link)
346 struct smatch_state *old_state, *new_state;
347 struct string_list *links;
348 char *new;
350 old_state = get_state(link_id, var, sym);
351 if (old_state)
352 links = clone_str_list(old_state->data);
353 else
354 links = NULL;
356 new = alloc_sname(link);
357 insert_string(&links, new);
359 new_state = alloc_link_state(links);
360 set_state(link_id, var, sym, new_state);
/*
 * Expression flavour of save_link_var_sym().  Does nothing when @expr
 * cannot be resolved to both a variable name and a symbol.
 */
static void save_link(struct expression *expr, char *link)
{
	struct symbol *sym;
	char *name;

	name = expr_to_var_sym(expr, &sym);
	if (name && sym)
		save_link_var_sym(name, sym, link);
	free_string(name);
}
377 static void match_inc(struct sm_state *sm)
379 struct string_list *links;
380 struct smatch_state *state;
381 char *tmp;
383 links = sm->state->data;
385 FOR_EACH_PTR(links, tmp) {
386 state = get_state(compare_id, tmp, NULL);
388 switch (state_to_comparison(state)) {
389 case SPECIAL_EQUAL:
390 case SPECIAL_GTE:
391 case SPECIAL_UNSIGNED_GTE:
392 case '>':
393 case SPECIAL_UNSIGNED_GT: {
394 struct compare_data *data = state->data;
395 struct smatch_state *new;
397 new = alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, '>');
398 set_state(compare_id, tmp, NULL, new);
399 break;
401 default:
402 set_state(compare_id, tmp, NULL, &undefined);
404 } END_FOR_EACH_PTR(tmp);
407 static void match_dec(struct sm_state *sm)
409 struct string_list *links;
410 struct smatch_state *state;
411 char *tmp;
413 links = sm->state->data;
415 FOR_EACH_PTR(links, tmp) {
416 state = get_state(compare_id, tmp, NULL);
418 switch (state_to_comparison(state)) {
419 case SPECIAL_EQUAL:
420 case SPECIAL_LTE:
421 case SPECIAL_UNSIGNED_LTE:
422 case '<':
423 case SPECIAL_UNSIGNED_LT: {
424 struct compare_data *data = state->data;
425 struct smatch_state *new;
427 new = alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, '<');
428 set_state(compare_id, tmp, NULL, new);
429 break;
431 default:
432 set_state(compare_id, tmp, NULL, &undefined);
434 } END_FOR_EACH_PTR(tmp);
437 static int match_inc_dec(struct sm_state *sm, struct expression *mod_expr)
439 if (!mod_expr)
440 return 0;
441 if (mod_expr->type != EXPR_PREOP && mod_expr->type != EXPR_POSTOP)
442 return 0;
444 if (mod_expr->op == SPECIAL_INCREMENT) {
445 match_inc(sm);
446 return 1;
448 if (mod_expr->op == SPECIAL_DECREMENT) {
449 match_dec(sm);
450 return 1;
452 return 0;
455 static void match_modify(struct sm_state *sm, struct expression *mod_expr)
457 struct string_list *links;
458 char *tmp;
460 if (match_inc_dec(sm, mod_expr))
461 return;
463 links = sm->state->data;
465 FOR_EACH_PTR(links, tmp) {
466 set_state(compare_id, tmp, NULL, &undefined);
467 } END_FOR_EACH_PTR(tmp);
468 set_state(link_id, sm->name, sm->sym, &undefined);
471 static void match_logic(struct expression *expr)
473 char *left = NULL;
474 char *right = NULL;
475 struct symbol *left_sym, *right_sym;
476 int op, false_op;
477 struct smatch_state *true_state, *false_state;
478 char state_name[256];
480 if (expr->type != EXPR_COMPARE)
481 return;
482 left = expr_to_var_sym(expr->left, &left_sym);
483 if (!left || !left_sym)
484 goto free;
485 right = expr_to_var_sym(expr->right, &right_sym);
486 if (!right || !right_sym)
487 goto free;
489 if (strcmp(left, right) > 0) {
490 struct symbol *tmp_sym = left_sym;
491 char *tmp_name = left;
493 left = right;
494 left_sym = right_sym;
495 right = tmp_name;
496 right_sym = tmp_sym;
497 op = flip_op(expr->op);
498 } else {
499 op = expr->op;
501 false_op = falsify_op(op);
502 snprintf(state_name, sizeof(state_name), "%s vs %s", left, right);
503 true_state = alloc_compare_state(left, left_sym, right, right_sym, op);
504 false_state = alloc_compare_state(left, left_sym, right, right_sym, false_op);
506 set_true_false_states(compare_id, state_name, NULL, true_state, false_state);
507 save_link(expr->left, state_name);
508 save_link(expr->right, state_name);
509 free:
510 free_string(left);
511 free_string(right);
514 static void add_comparison_var_sym(const char *left_name, struct symbol *left_sym, int comparison, const char *right_name, struct symbol *right_sym)
516 struct smatch_state *state;
517 char state_name[256];
519 if (strcmp(left_name, right_name) > 0) {
520 struct symbol *tmp_sym = left_sym;
521 const char *tmp_name = left_name;
523 left_name = right_name;
524 left_sym = right_sym;
525 right_name = tmp_name;
526 right_sym = tmp_sym;
527 comparison = flip_op(comparison);
529 snprintf(state_name, sizeof(state_name), "%s vs %s", left_name, right_name);
530 state = alloc_compare_state(left_name, left_sym, right_name, right_sym, comparison);
532 set_state(compare_id, state_name, NULL, state);
533 save_link_var_sym(left_name, left_sym, state_name);
534 save_link_var_sym(right_name, right_sym, state_name);
537 static void add_comparison(struct expression *left, int comparison, struct expression *right)
539 char *left_name = NULL;
540 char *right_name = NULL;
541 struct symbol *left_sym, *right_sym;
543 left_name = expr_to_var_sym(left, &left_sym);
544 if (!left_name || !left_sym)
545 goto free;
546 right_name = expr_to_var_sym(right, &right_sym);
547 if (!right_name || !right_sym)
548 goto free;
550 add_comparison_var_sym(left_name, left_sym, comparison, right_name, right_sym);
552 free:
553 free_string(left_name);
554 free_string(right_name);
557 static void match_assign_add(struct expression *expr)
559 struct expression *right;
560 struct expression *r_left, *r_right;
561 sval_t left_tmp, right_tmp;
563 right = strip_expr(expr->right);
564 r_left = strip_expr(right->left);
565 r_right = strip_expr(right->right);
567 get_absolute_min(r_left, &left_tmp);
568 get_absolute_min(r_right, &right_tmp);
570 if (left_tmp.value > 0)
571 add_comparison(expr->left, '>', r_right);
572 else if (left_tmp.value == 0)
573 add_comparison(expr->left, SPECIAL_GTE, r_right);
575 if (right_tmp.value > 0)
576 add_comparison(expr->left, '>', r_left);
577 else if (right_tmp.value == 0)
578 add_comparison(expr->left, SPECIAL_GTE, r_left);
581 static void match_assign_sub(struct expression *expr)
583 struct expression *right;
584 struct expression *r_left, *r_right;
585 int comparison;
586 sval_t min;
588 right = strip_expr(expr->right);
589 r_left = strip_expr(right->left);
590 r_right = strip_expr(right->right);
592 if (get_absolute_min(r_right, &min) && sval_is_negative(min))
593 return;
595 comparison = get_comparison(r_left, r_right);
597 switch (comparison) {
598 case '>':
599 case SPECIAL_GTE:
600 if (implied_not_equal(r_right, 0))
601 add_comparison(expr->left, '>', r_left);
602 else
603 add_comparison(expr->left, SPECIAL_GTE, r_left);
604 return;
608 static void match_binop_assign(struct expression *expr)
610 struct expression *right;
612 right = strip_expr(expr->right);
613 if (right->op == '+')
614 match_assign_add(expr);
615 if (right->op == '-')
616 match_assign_sub(expr);
619 static void copy_comparisons(struct expression *left, struct expression *right)
621 struct string_list *links;
622 struct smatch_state *state;
623 struct compare_data *data;
624 struct symbol *left_sym, *right_sym;
625 char *left_var = NULL;
626 char *right_var = NULL;
627 const char *var;
628 struct symbol *sym;
629 int comparison;
630 char *tmp;
632 left_var = expr_to_var_sym(left, &left_sym);
633 if (!left_var || !left_sym)
634 goto done;
635 right_var = expr_to_var_sym(right, &right_sym);
636 if (!right_var || !right_sym)
637 goto done;
639 state = get_state_expr(link_id, right);
640 if (!state)
641 return;
642 links = state->data;
644 FOR_EACH_PTR(links, tmp) {
645 state = get_state(compare_id, tmp, NULL);
646 if (!state->data)
647 continue;
648 data = state->data;
649 comparison = data->comparison;
650 var = data->var2;
651 sym = data->sym2;
652 if (sym == right_sym && strcmp(var, right_var) == 0) {
653 var = data->var1;
654 sym = data->sym1;
655 comparison = flip_op(comparison);
657 add_comparison_var_sym(left_var, left_sym, comparison, var, sym);
658 } END_FOR_EACH_PTR(tmp);
660 done:
661 free_string(right_var);
664 static void match_normal_assign(struct expression *expr)
666 copy_comparisons(expr->left, expr->right);
667 add_comparison(expr->left, SPECIAL_EQUAL, expr->right);
670 static void match_assign(struct expression *expr)
672 struct expression *right;
674 right = strip_expr(expr->right);
675 if (right->type == EXPR_BINOP)
676 match_binop_assign(expr);
677 else
678 match_normal_assign(expr);
681 static int get_comparison_strings(char *one, char *two)
683 char buf[256];
684 struct smatch_state *state;
685 int invert = 0;
686 int ret = 0;
688 if (strcmp(one, two) > 0) {
689 char *tmp = one;
691 one = two;
692 two = tmp;
693 invert = 1;
696 snprintf(buf, sizeof(buf), "%s vs %s", one, two);
697 state = get_state(compare_id, buf, NULL);
698 if (state)
699 ret = state_to_comparison(state);
701 if (invert)
702 ret = flip_op(ret);
704 return ret;
707 int get_comparison(struct expression *a, struct expression *b)
709 char *one = NULL;
710 char *two = NULL;
711 int ret = 0;
713 one = expr_to_var(a);
714 if (!one)
715 goto free;
716 two = expr_to_var(b);
717 if (!two)
718 goto free;
720 ret = get_comparison_strings(one, two);
721 free:
722 free_string(one);
723 free_string(two);
724 return ret;
/*
 * Parse @range with str_to_comparison_arg() in the context of @call and,
 * on success, record the resulting comparison between @expr and the
 * argument it names.
 */
void __add_comparison_info(struct expression *expr, struct expression *call, const char *range)
{
	struct expression *arg;
	int comparison;

	if (str_to_comparison_arg(range, call, &comparison, &arg))
		add_comparison(expr, comparison, arg);
}
738 static char *range_comparison_to_param_helper(struct expression *expr, char starts_with)
740 struct symbol *param;
741 char *var = NULL;
742 char buf[256];
743 char *ret_str = NULL;
744 int compare;
745 int i;
747 var = expr_to_var(expr);
748 if (!var)
749 goto free;
751 i = -1;
752 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, param) {
753 i++;
754 if (!param->ident)
755 continue;
756 snprintf(buf, sizeof(buf), "%s orig", param->ident->name);
757 compare = get_comparison_strings(var, buf);
758 if (!compare)
759 continue;
760 if (show_special(compare)[0] != starts_with)
761 continue;
762 snprintf(buf, sizeof(buf), "[%sp%d]", show_special(compare), i);
763 ret_str = alloc_sname(buf);
764 break;
765 } END_FOR_EACH_PTR(param);
767 free:
768 free_string(var);
769 return ret_str;
/*
 * Describe @expr as a parameter comparison whose operator text starts with
 * '=' (see range_comparison_to_param_helper()), or return NULL.
 */
char *expr_equal_to_param(struct expression *expr)
{
	char *desc = range_comparison_to_param_helper(expr, '=');

	return desc;
}
/*
 * Describe @expr as a parameter comparison whose operator text starts with
 * '<' (see range_comparison_to_param_helper()), or return NULL.
 */
char *expr_lte_to_param(struct expression *expr)
{
	char *desc = range_comparison_to_param_helper(expr, '<');

	return desc;
}
782 static void free_data(struct symbol *sym)
784 if (__inline_fn)
785 return;
786 clear_compare_data_alloc();
789 void register_comparison(int id)
791 compare_id = id;
792 add_hook(&match_logic, CONDITION_HOOK);
793 add_hook(&match_assign, ASSIGNMENT_HOOK);
794 add_hook(&save_start_states, AFTER_DEF_HOOK);
795 add_unmatched_state_hook(compare_id, unmatched_comparison);
796 add_merge_hook(compare_id, &merge_compare_states);
797 add_hook(&free_data, AFTER_FUNC_HOOK);
800 void register_comparison_links(int id)
802 link_id = id;
803 add_merge_hook(link_id, &merge_links);
804 add_modification_hook(link_id, &match_modify);