comparison: create expr_equal/lte_to_param() functions
[smatch.git] / smatch_comparison.c
blob8535f91e5a7932b86bcb06b1d772f427ac4367cb
1 /*
2 * smatch/smatch_comparison.c
4 * Copyright (C) 2012 Oracle.
6 * Licensed under the Open Software License version 1.1
8 */
11 * The point here is to store the relationships between two variables,
12 * for example y > x.
13 * To do that we create a state with the two variable names in strcmp() order:
14 * ->name = "x vs y" and the state would be "<" (meaning x < y). On the false
15 * path the state would be ">=".
17 * Part of the trick of it is that if x or y is modified then we need to reset
18 * the state. We need to keep a list of all the states which depend on x and
19 * all the states which depend on y. The link_id code handles this.
21 * Future work: If we know that x is greater than y and y is greater than z
22 * then we know that x is greater than z.
25 #include "smatch.h"
26 #include "smatch_extra.h"
27 #include "smatch_slist.h"
29 static int compare_id;
30 static int link_id;
32 struct compare_data {
33 const char *var1;
34 struct symbol *sym1;
35 const char *var2;
36 struct symbol *sym2;
37 int comparison;
39 ALLOCATOR(compare_data, "compare data");
41 static struct smatch_state *alloc_compare_state(
42 const char *var1, struct symbol *sym1,
43 const char *var2, struct symbol *sym2,
44 int comparison)
46 struct smatch_state *state;
47 struct compare_data *data;
49 state = __alloc_smatch_state(0);
50 state->name = alloc_sname(show_special(comparison));
51 data = __alloc_compare_data(0);
52 data->var1 = alloc_sname(var1);
53 data->sym1 = sym1;
54 data->var2 = alloc_sname(var2);
55 data->sym2 = sym2;
56 data->comparison = comparison;
57 state->data = data;
58 return state;
61 static int state_to_comparison(struct smatch_state *state)
63 if (!state || !state->data)
64 return 0;
65 return ((struct compare_data *)state->data)->comparison;
69 * flip_op() reverses the op left and right. So "x >= y" becomes "y <= x".
71 static int flip_op(int op)
73 switch (op) {
74 case 0:
75 return 0;
76 case '<':
77 return '>';
78 case SPECIAL_UNSIGNED_LT:
79 return SPECIAL_UNSIGNED_GT;
80 case SPECIAL_LTE:
81 return SPECIAL_GTE;
82 case SPECIAL_UNSIGNED_LTE:
83 return SPECIAL_UNSIGNED_GTE;
84 case SPECIAL_EQUAL:
85 return SPECIAL_EQUAL;
86 case SPECIAL_NOTEQUAL:
87 return SPECIAL_NOTEQUAL;
88 case SPECIAL_GTE:
89 return SPECIAL_LTE;
90 case SPECIAL_UNSIGNED_GTE:
91 return SPECIAL_UNSIGNED_LTE;
92 case '>':
93 return '<';
94 case SPECIAL_UNSIGNED_GT:
95 return SPECIAL_UNSIGNED_LT;
96 default:
97 sm_msg("internal smatch bug. unhandled comparison %d", op);
98 return op;
102 static int falsify_op(int op)
104 switch (op) {
105 case 0:
106 return 0;
107 case '<':
108 return SPECIAL_GTE;
109 case SPECIAL_UNSIGNED_LT:
110 return SPECIAL_UNSIGNED_GTE;
111 case SPECIAL_LTE:
112 return '>';
113 case SPECIAL_UNSIGNED_LTE:
114 return SPECIAL_UNSIGNED_GT;
115 case SPECIAL_EQUAL:
116 return SPECIAL_NOTEQUAL;
117 case SPECIAL_NOTEQUAL:
118 return SPECIAL_EQUAL;
119 case SPECIAL_GTE:
120 return '<';
121 case SPECIAL_UNSIGNED_GTE:
122 return SPECIAL_UNSIGNED_LT;
123 case '>':
124 return SPECIAL_LTE;
125 case SPECIAL_UNSIGNED_GT:
126 return SPECIAL_UNSIGNED_LTE;
127 default:
128 sm_msg("internal smatch bug. unhandled comparison %d", op);
129 return op;
133 static int rl_comparison(struct range_list *left_rl, struct range_list *right_rl)
135 sval_t left_min, left_max, right_min, right_max;
137 if (!left_rl || !right_rl)
138 return 0;
140 left_min = rl_min(left_rl);
141 left_max = rl_max(left_rl);
142 right_min = rl_min(right_rl);
143 right_max = rl_max(right_rl);
145 if (left_min.value == left_max.value &&
146 right_min.value == right_max.value &&
147 left_min.value == right_min.value)
148 return SPECIAL_EQUAL;
150 if (sval_cmp(left_max, right_min) < 0)
151 return '<';
152 if (sval_cmp(left_max, right_min) == 0)
153 return SPECIAL_LTE;
154 if (sval_cmp(left_min, right_max) > 0)
155 return '>';
156 if (sval_cmp(left_min, right_max) == 0)
157 return SPECIAL_GTE;
159 return 0;
162 static struct smatch_state *unmatched_comparison(struct sm_state *sm)
164 struct compare_data *data = sm->state->data;
165 struct range_list *left_rl, *right_rl;
166 int op;
168 if (!data)
169 return &undefined;
171 if (!get_implied_rl_var_sym(data->var1, data->sym1, &left_rl))
172 return &undefined;
173 if (!get_implied_rl_var_sym(data->var2, data->sym2, &right_rl))
174 return &undefined;
176 op = rl_comparison(left_rl, right_rl);
177 if (op)
178 return alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, op);
180 return &undefined;
183 /* remove_unsigned_from_comparison() is obviously a hack. */
184 static int remove_unsigned_from_comparison(int op)
186 switch (op) {
187 case SPECIAL_UNSIGNED_LT:
188 return '<';
189 case SPECIAL_UNSIGNED_LTE:
190 return SPECIAL_LTE;
191 case SPECIAL_UNSIGNED_GTE:
192 return SPECIAL_GTE;
193 case SPECIAL_UNSIGNED_GT:
194 return '>';
195 default:
196 return op;
200 static int merge_comparisons(int one, int two)
202 int LT, EQ, GT;
204 one = remove_unsigned_from_comparison(one);
205 two = remove_unsigned_from_comparison(two);
207 LT = EQ = GT = 0;
209 switch (one) {
210 case '<':
211 LT = 1;
212 break;
213 case SPECIAL_LTE:
214 LT = 1;
215 EQ = 1;
216 break;
217 case SPECIAL_EQUAL:
218 EQ = 1;
219 break;
220 case SPECIAL_GTE:
221 GT = 1;
222 EQ = 1;
223 break;
224 case '>':
225 GT = 1;
228 switch (two) {
229 case '<':
230 LT = 1;
231 break;
232 case SPECIAL_LTE:
233 LT = 1;
234 EQ = 1;
235 break;
236 case SPECIAL_EQUAL:
237 EQ = 1;
238 break;
239 case SPECIAL_GTE:
240 GT = 1;
241 EQ = 1;
242 break;
243 case '>':
244 GT = 1;
247 if (LT && EQ && GT)
248 return 0;
249 if (LT && EQ)
250 return SPECIAL_LTE;
251 if (LT && GT)
252 return SPECIAL_NOTEQUAL;
253 if (LT)
254 return '<';
255 if (EQ && GT)
256 return SPECIAL_GTE;
257 if (GT)
258 return '>';
259 return 0;
262 static struct smatch_state *merge_compare_states(struct smatch_state *s1, struct smatch_state *s2)
264 struct compare_data *data = s1->data;
265 int op;
267 op = merge_comparisons(state_to_comparison(s1), state_to_comparison(s2));
268 if (op)
269 return alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, op);
270 return &undefined;
273 struct smatch_state *alloc_link_state(struct string_list *links)
275 struct smatch_state *state;
276 static char buf[256];
277 char *tmp;
278 int i;
280 state = __alloc_smatch_state(0);
282 i = 0;
283 FOR_EACH_PTR(links, tmp) {
284 if (!i++)
285 snprintf(buf, sizeof(buf), "%s", tmp);
286 else
287 snprintf(buf, sizeof(buf), "%s, %s", buf, tmp);
288 } END_FOR_EACH_PTR(tmp);
290 state->name = alloc_sname(buf);
291 state->data = links;
292 return state;
295 static void save_start_states(struct statement *stmt)
297 struct symbol *param;
298 char orig[64];
299 char state_name[128];
300 struct smatch_state *state;
301 struct string_list *links;
302 char *link;
304 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, param) {
305 if (!param->ident)
306 continue;
307 snprintf(orig, sizeof(orig), "%s orig", param->ident->name);
308 snprintf(state_name, sizeof(state_name), "%s vs %s", param->ident->name, orig);
309 state = alloc_compare_state(param->ident->name, param, alloc_sname(orig), NULL, SPECIAL_EQUAL);
310 set_state(compare_id, state_name, NULL, state);
312 link = alloc_sname(state_name);
313 links = NULL;
314 insert_string(&links, link);
315 state = alloc_link_state(links);
316 set_state(link_id, param->ident->name, param, state);
317 } END_FOR_EACH_PTR(param);
320 static struct smatch_state *merge_links(struct smatch_state *s1, struct smatch_state *s2)
322 struct smatch_state *ret;
323 struct string_list *links;
325 links = combine_string_lists(s1->data, s2->data);
326 ret = alloc_link_state(links);
327 return ret;
330 static void save_link(struct expression *expr, char *link)
332 struct smatch_state *old_state, *new_state;
333 struct string_list *links;
334 char *new;
336 old_state = get_state_expr(link_id, expr);
337 if (old_state)
338 links = clone_str_list(old_state->data);
339 else
340 links = NULL;
342 new = alloc_sname(link);
343 insert_string(&links, new);
345 new_state = alloc_link_state(links);
346 set_state_expr(link_id, expr, new_state);
349 static void match_inc(struct sm_state *sm)
351 struct string_list *links;
352 struct smatch_state *state;
353 char *tmp;
355 links = sm->state->data;
357 FOR_EACH_PTR(links, tmp) {
358 state = get_state(compare_id, tmp, NULL);
360 switch (state_to_comparison(state)) {
361 case SPECIAL_EQUAL:
362 case SPECIAL_GTE:
363 case SPECIAL_UNSIGNED_GTE:
364 case '>':
365 case SPECIAL_UNSIGNED_GT: {
366 struct compare_data *data = state->data;
367 struct smatch_state *new;
369 new = alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, '>');
370 set_state(compare_id, tmp, NULL, new);
371 break;
373 default:
374 set_state(compare_id, tmp, NULL, &undefined);
376 } END_FOR_EACH_PTR(tmp);
379 static void match_dec(struct sm_state *sm)
381 struct string_list *links;
382 struct smatch_state *state;
383 char *tmp;
385 links = sm->state->data;
387 FOR_EACH_PTR(links, tmp) {
388 state = get_state(compare_id, tmp, NULL);
390 switch (state_to_comparison(state)) {
391 case SPECIAL_EQUAL:
392 case SPECIAL_LTE:
393 case SPECIAL_UNSIGNED_LTE:
394 case '<':
395 case SPECIAL_UNSIGNED_LT: {
396 struct compare_data *data = state->data;
397 struct smatch_state *new;
399 new = alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, '<');
400 set_state(compare_id, tmp, NULL, new);
401 break;
403 default:
404 set_state(compare_id, tmp, NULL, &undefined);
406 } END_FOR_EACH_PTR(tmp);
409 static int match_inc_dec(struct sm_state *sm, struct expression *mod_expr)
411 if (!mod_expr)
412 return 0;
413 if (mod_expr->type != EXPR_PREOP && mod_expr->type != EXPR_POSTOP)
414 return 0;
416 if (mod_expr->op == SPECIAL_INCREMENT) {
417 match_inc(sm);
418 return 1;
420 if (mod_expr->op == SPECIAL_DECREMENT) {
421 match_dec(sm);
422 return 1;
424 return 0;
427 static void match_modify(struct sm_state *sm, struct expression *mod_expr)
429 struct string_list *links;
430 char *tmp;
432 if (match_inc_dec(sm, mod_expr))
433 return;
435 links = sm->state->data;
437 FOR_EACH_PTR(links, tmp) {
438 set_state(compare_id, tmp, NULL, &undefined);
439 } END_FOR_EACH_PTR(tmp);
440 set_state(link_id, sm->name, sm->sym, &undefined);
443 static void match_logic(struct expression *expr)
445 char *left = NULL;
446 char *right = NULL;
447 struct symbol *left_sym, *right_sym;
448 int op, false_op;
449 struct smatch_state *true_state, *false_state;
450 char state_name[256];
452 if (expr->type != EXPR_COMPARE)
453 return;
454 left = expr_to_var_sym(expr->left, &left_sym);
455 if (!left || !left_sym)
456 goto free;
457 right = expr_to_var_sym(expr->right, &right_sym);
458 if (!right || !right_sym)
459 goto free;
461 if (strcmp(left, right) > 0) {
462 struct symbol *tmp_sym = left_sym;
463 char *tmp_name = left;
465 left = right;
466 left_sym = right_sym;
467 right = tmp_name;
468 right_sym = tmp_sym;
469 op = flip_op(expr->op);
470 } else {
471 op = expr->op;
473 false_op = falsify_op(op);
474 snprintf(state_name, sizeof(state_name), "%s vs %s", left, right);
475 true_state = alloc_compare_state(left, left_sym, right, right_sym, op);
476 false_state = alloc_compare_state(left, left_sym, right, right_sym, false_op);
478 set_true_false_states(compare_id, state_name, NULL, true_state, false_state);
479 save_link(expr->left, state_name);
480 save_link(expr->right, state_name);
481 free:
482 free_string(left);
483 free_string(right);
486 static void add_comparison(struct expression *left, int comparison, struct expression *right)
488 char *left_name = NULL;
489 char *right_name = NULL;
490 struct symbol *left_sym, *right_sym;
491 struct smatch_state *state;
492 char state_name[256];
494 left_name = expr_to_var_sym(left, &left_sym);
495 if (!left_name || !left_sym)
496 goto free;
497 right_name = expr_to_var_sym(right, &right_sym);
498 if (!right_name || !right_sym)
499 goto free;
501 if (strcmp(left_name, right_name) > 0) {
502 struct symbol *tmp_sym = left_sym;
503 char *tmp_name = left_name;
505 left_name = right_name;
506 left_sym = right_sym;
507 right_name = tmp_name;
508 right_sym = tmp_sym;
509 comparison = flip_op(comparison);
511 snprintf(state_name, sizeof(state_name), "%s vs %s", left_name, right_name);
512 state = alloc_compare_state(left_name, left_sym, right_name, right_sym, comparison);
514 set_state(compare_id, state_name, NULL, state);
515 save_link(left, state_name);
516 save_link(right, state_name);
517 free:
518 free_string(left_name);
519 free_string(right_name);
523 static void match_assign_add(struct expression *expr)
525 struct expression *right;
526 struct expression *r_left, *r_right;
527 sval_t left_tmp, right_tmp;
529 right = strip_expr(expr->right);
530 r_left = strip_expr(right->left);
531 r_right = strip_expr(right->right);
533 if (!is_capped(expr->left)) {
534 get_absolute_max(r_left, &left_tmp);
535 get_absolute_max(r_right, &right_tmp);
536 if (sval_binop_overflows(left_tmp, '+', right_tmp))
537 return;
540 get_absolute_min(r_left, &left_tmp);
541 if (sval_is_negative(left_tmp))
542 return;
543 get_absolute_min(r_right, &right_tmp);
544 if (sval_is_negative(right_tmp))
545 return;
547 if (left_tmp.value == 0)
548 add_comparison(expr->left, SPECIAL_GTE, r_right);
549 else
550 add_comparison(expr->left, '>', r_right);
552 if (right_tmp.value == 0)
553 add_comparison(expr->left, SPECIAL_GTE, r_left);
554 else
555 add_comparison(expr->left, '>', r_left);
558 static void match_assign_sub(struct expression *expr)
560 struct expression *right;
561 struct expression *r_left, *r_right;
562 int comparison;
563 sval_t min;
565 right = strip_expr(expr->right);
566 r_left = strip_expr(right->left);
567 r_right = strip_expr(right->right);
569 if (get_absolute_min(r_right, &min) && sval_is_negative(min))
570 return;
572 comparison = get_comparison(r_left, r_right);
574 switch (comparison) {
575 case '>':
576 case SPECIAL_GTE:
577 if (implied_not_equal(r_right, 0))
578 add_comparison(expr->left, '>', r_left);
579 else
580 add_comparison(expr->left, SPECIAL_GTE, r_left);
581 return;
585 static void match_binop_assign(struct expression *expr)
587 struct expression *right;
589 right = strip_expr(expr->right);
590 if (right->op == '+')
591 match_assign_add(expr);
592 if (right->op == '-')
593 match_assign_sub(expr);
596 static void match_normal_assign(struct expression *expr)
598 add_comparison(expr->left, SPECIAL_EQUAL, expr->right);
601 static void match_assign(struct expression *expr)
603 struct expression *right;
605 right = strip_expr(expr->right);
606 if (right->type == EXPR_BINOP)
607 match_binop_assign(expr);
608 else
609 match_normal_assign(expr);
612 static int get_comparison_strings(char *one, char *two)
614 char buf[256];
615 struct smatch_state *state;
616 int invert = 0;
617 int ret = 0;
619 if (strcmp(one, two) > 0) {
620 char *tmp = one;
622 one = two;
623 two = tmp;
624 invert = 1;
627 snprintf(buf, sizeof(buf), "%s vs %s", one, two);
628 state = get_state(compare_id, buf, NULL);
629 if (state)
630 ret = state_to_comparison(state);
632 if (invert)
633 ret = flip_op(ret);
635 return ret;
638 int get_comparison(struct expression *a, struct expression *b)
640 char *one = NULL;
641 char *two = NULL;
642 int ret = 0;
644 one = expr_to_var(a);
645 if (!one)
646 goto free;
647 two = expr_to_var(b);
648 if (!two)
649 goto free;
651 ret = get_comparison_strings(one, two);
652 free:
653 free_string(one);
654 free_string(two);
655 return ret;
658 void __add_comparison_info(struct expression *expr, struct expression *call, const char *range)
660 struct expression *arg;
661 int comparison;
662 const char *c = range;
664 if (!str_to_comparison_arg(c, call, &comparison, &arg, NULL))
665 return;
666 add_comparison(expr, SPECIAL_LTE, arg);
669 static char *range_comparison_to_param_helper(struct expression *expr, char starts_with)
671 struct symbol *param;
672 char *var = NULL;
673 char buf[256];
674 char *ret_str = NULL;
675 int compare;
676 int i;
678 var = expr_to_var(expr);
679 if (!var)
680 goto free;
682 i = -1;
683 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, param) {
684 i++;
685 if (!param->ident)
686 continue;
687 snprintf(buf, sizeof(buf), "%s orig", param->ident->name);
688 compare = get_comparison_strings(var, buf);
689 if (!compare)
690 continue;
691 if (show_special(compare)[0] != starts_with)
692 continue;
693 snprintf(buf, sizeof(buf), "%sp%d", show_special(compare), i);
694 ret_str = alloc_sname(buf);
695 break;
696 } END_FOR_EACH_PTR(param);
698 free:
699 free_string(var);
700 return ret_str;
/* "==p<N>" if @expr equals parameter N's entry value, otherwise NULL. */
char *expr_equal_to_param(struct expression *expr)
{
	return range_comparison_to_param_helper(expr, '=');
}

/*
 * "<p<N>" or "<=p<N>" if @expr is below, or at most, parameter N's entry
 * value; otherwise NULL.
 */
char *expr_lte_to_param(struct expression *expr)
{
	return range_comparison_to_param_helper(expr, '<');
}
713 char *range_comparison_to_param(struct expression *expr)
715 char *comparison_str;
716 char buf[256];
717 sval_t min;
719 comparison_str = expr_equal_to_param(expr);
720 if (comparison_str) {
721 snprintf(buf, sizeof(buf), "[%s]", comparison_str);
722 return alloc_sname(buf);
725 comparison_str = expr_lte_to_param(expr);
726 if (!comparison_str)
727 return NULL;
728 get_absolute_min(expr, &min);
729 snprintf(buf, sizeof(buf), "%s-[%s]", sval_to_str(min), comparison_str);
730 return alloc_sname(buf);
733 static void free_data(struct symbol *sym)
735 if (__inline_fn)
736 return;
737 clear_compare_data_alloc();
740 void register_comparison(int id)
742 compare_id = id;
743 add_hook(&match_logic, CONDITION_HOOK);
744 add_hook(&match_assign, ASSIGNMENT_HOOK);
745 add_hook(&save_start_states, AFTER_DEF_HOOK);
746 add_unmatched_state_hook(compare_id, unmatched_comparison);
747 add_merge_hook(compare_id, &merge_compare_states);
748 add_hook(&free_data, AFTER_FUNC_HOOK);
751 void register_comparison_links(int id)
753 link_id = id;
754 add_merge_hook(link_id, &merge_links);
755 add_modification_hook(link_id, &match_modify);