From bd6313bdc48d021c2a8b3dc7d6d3c62fe0cdd28d Mon Sep 17 00:00:00 2001 From: Dan Carpenter Date: Wed, 12 Jun 2013 23:20:40 +0300 Subject: [PATCH] comparison: handle merging comparisons This is a pretty involved patch. Theoretically I could have broken it up into two parts, but when I'm coding my own stuff then I have no standards so it's all going in as one big patch. The code we care about looks like this: if (x < 10) y = x; else y = 10; __smatch_compare(x, y); On the first branch y == x. On the second branch we wouldn't normally store the relationship between x and y so we have an unmatched state. We can determine the relationship though. Then we merge the states together to get the final relationship. The first thing to say is that smatch_compare.c didn't used to store the name and symbols of the variables being compared. So I had to create a dynamically allocated smatch_state so that information could be stored in the ->data. The next part of the change is if there is an unmatched state. We can get the implied range list for both variables and determine the relationship from that. y is 10. x is 10-max. That means y <= x. Then when we merge it it's simple enough y == x and y <= x combine to say that y <= x. Signed-off-by: Dan Carpenter --- smatch_comparison.c | 302 ++++++++++++++++++++++++++++++++++++----------- validation/sm_compare6.c | 23 ++++ 2 files changed, 259 insertions(+), 66 deletions(-) create mode 100644 validation/sm_compare6.c diff --git a/smatch_comparison.c b/smatch_comparison.c index 28447771..9784874d 100644 --- a/smatch_comparison.c +++ b/smatch_comparison.c @@ -29,49 +29,45 @@ static int compare_id; static int link_id; -static struct smatch_state compare_states[] = { - ['<'] = { - .name = "<", - .data = (void *)'<', - }, - [SPECIAL_UNSIGNED_LT] = { - .name = "<", - .data = (void *)SPECIAL_UNSIGNED_LT, - }, - [SPECIAL_LTE] = { - .name = "<=", - .data = (void *)SPECIAL_LTE, - }, - [SPECIAL_UNSIGNED_LTE] = { - .name = "<=", - .data = (void *)SPECIAL_UNSIGNED_LTE, - }, - [SPECIAL_EQUAL] = { - .name = "==", - .data = (void *)SPECIAL_EQUAL, - }, - [SPECIAL_NOTEQUAL] = { - .name = "!=", - .data = (void *)SPECIAL_NOTEQUAL, - }, - [SPECIAL_GTE] = { - .name = ">=", - .data = (void *)SPECIAL_GTE, - }, - [SPECIAL_UNSIGNED_GTE] = { - .name = ">=", - .data = (void *)SPECIAL_UNSIGNED_GTE, - }, - ['>'] = { - .name = ">", - .data = (void *)'>', - }, - [SPECIAL_UNSIGNED_GT] = { - .name = ">", - .data = (void *)SPECIAL_UNSIGNED_GT, - }, +struct compare_data { + const char *var1; + struct symbol *sym1; + const char *var2; + struct symbol *sym2; + int comparison; }; +ALLOCATOR(compare_data, "compare data"); + +static struct smatch_state *alloc_compare_state( + const char *var1, struct symbol *sym1, + const char *var2, struct symbol *sym2, + int comparison) +{ + struct smatch_state *state; + struct compare_data *data; + + state = __alloc_smatch_state(0); + state->name = alloc_sname(show_special(comparison)); + data = __alloc_compare_data(0); + data->var1 = alloc_sname(var1); + data->sym1 = sym1; + data->var2 = alloc_sname(var2); + data->sym2 = sym2; + data->comparison = comparison; + state->data = data; + return state; +} + +static int state_to_comparison(struct smatch_state *state) +{ + if (!state || !state->data) + return 0; + return ((struct compare_data *)state->data)->comparison; +} +/* + * flip_op() reverses the op left and right. So "x >= y" becomes "y <= x". + */ static int flip_op(int op) { switch (op) { @@ -134,6 +130,146 @@ static int falsify_op(int op) } } +static int rl_comparison(struct range_list *left_rl, struct range_list *right_rl) +{ + sval_t left_min, left_max, right_min, right_max; + + if (!left_rl || !right_rl) + return 0; + + left_min = rl_min(left_rl); + left_max = rl_max(left_rl); + right_min = rl_min(right_rl); + right_max = rl_max(right_rl); + + if (left_min.value == left_max.value && + right_min.value == right_max.value && + left_min.value == right_min.value) + return SPECIAL_EQUAL; + + if (sval_cmp(left_max, right_min) < 0) + return '<'; + if (sval_cmp(left_max, right_min) == 0) + return SPECIAL_LTE; + if (sval_cmp(left_min, right_max) > 0) + return '>'; + if (sval_cmp(left_min, right_max) == 0) + return SPECIAL_GTE; + + return 0; +} + +static struct smatch_state *unmatched_comparison(struct sm_state *sm) +{ + struct compare_data *data = sm->state->data; + struct range_list *left_rl, *right_rl; + int op; + + if (!data) + return &undefined; + + if (!get_implied_rl_var_sym(data->var1, data->sym1, &left_rl)) + return &undefined; + if (!get_implied_rl_var_sym(data->var2, data->sym2, &right_rl)) + return &undefined; + + op = rl_comparison(left_rl, right_rl); + if (op) + return alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, op); + + return &undefined; +} + +/* remove_unsigned_from_comparison() is obviously a hack. */ +static int remove_unsigned_from_comparison(int op) +{ + switch (op) { + case SPECIAL_UNSIGNED_LT: + return '<'; + case SPECIAL_UNSIGNED_LTE: + return SPECIAL_LTE; + case SPECIAL_UNSIGNED_GTE: + return SPECIAL_GTE; + case SPECIAL_UNSIGNED_GT: + return '>'; + default: + return op; + } +} + +static int merge_comparisons(int one, int two) +{ + int LT, EQ, GT; + + one = remove_unsigned_from_comparison(one); + two = remove_unsigned_from_comparison(two); + + LT = EQ = GT = 0; + + switch (one) { + case '<': + LT = 1; + break; + case SPECIAL_LTE: + LT = 1; + EQ = 1; + break; + case SPECIAL_EQUAL: + EQ = 1; + break; + case SPECIAL_GTE: + GT = 1; + EQ = 1; + break; + case '>': + GT = 1; + } + + switch (two) { + case '<': + LT = 1; + break; + case SPECIAL_LTE: + LT = 1; + EQ = 1; + break; + case SPECIAL_EQUAL: + EQ = 1; + break; + case SPECIAL_GTE: + GT = 1; + EQ = 1; + break; + case '>': + GT = 1; + } + + if (LT && EQ && GT) + return 0; + if (LT && EQ) + return SPECIAL_LTE; + if (LT && GT) + return SPECIAL_NOTEQUAL; + if (LT) + return '<'; + if (EQ && GT) + return SPECIAL_GTE; + if (GT) + return '>'; + return 0; +} + +static struct smatch_state *merge_compare_states(struct smatch_state *s1, struct smatch_state *s2) +{ + struct compare_data *data = s1->data; + int op; + + op = merge_comparisons(state_to_comparison(s1), state_to_comparison(s2)); + if (op) + return alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, op); + return &undefined; +} + struct smatch_state *alloc_link_state(struct string_list *links) { struct smatch_state *state; @@ -170,7 +306,7 @@ static void save_start_states(struct statement *stmt) continue; snprintf(orig, sizeof(orig), "%s orig", param->ident->name); snprintf(state_name, sizeof(state_name), "%s vs %s", param->ident->name, orig); - state = &compare_states[SPECIAL_EQUAL]; + state = alloc_compare_state(param->ident->name, param, alloc_sname(orig), NULL, SPECIAL_EQUAL); set_state(compare_id, state_name, NULL, state); link = alloc_sname(state_name); @@ -181,7 +317,7 @@ static void save_start_states(struct statement *stmt) } END_FOR_EACH_PTR(param); } -static struct smatch_state *merge_func(struct smatch_state *s1, struct smatch_state *s2) +static struct smatch_state *merge_links(struct smatch_state *s1, struct smatch_state *s2) { struct smatch_state *ret; struct string_list *links; @@ -220,13 +356,21 @@ static void match_inc(struct sm_state *sm) FOR_EACH_PTR(links, tmp) { state = get_state(compare_id, tmp, NULL); - if (state == &compare_states[SPECIAL_EQUAL] || - state == &compare_states[SPECIAL_GTE] || - state == &compare_states[SPECIAL_UNSIGNED_GTE] || - state == &compare_states['>'] || - state == &compare_states[SPECIAL_UNSIGNED_GT]) { - set_state(compare_id, tmp, NULL, &compare_states['>']); - } else { + + switch (state_to_comparison(state)) { + case SPECIAL_EQUAL: + case SPECIAL_GTE: + case SPECIAL_UNSIGNED_GTE: + case '>': + case SPECIAL_UNSIGNED_GT: { + struct compare_data *data = state->data; + struct smatch_state *new; + + new = alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, '>'); + set_state(compare_id, tmp, NULL, new); + break; + } + default: set_state(compare_id, tmp, NULL, &undefined); } } END_FOR_EACH_PTR(tmp); @@ -242,13 +386,21 @@ static void match_dec(struct sm_state *sm) FOR_EACH_PTR(links, tmp) { state = get_state(compare_id, tmp, NULL); - if (state == &compare_states[SPECIAL_EQUAL] || - state == &compare_states[SPECIAL_LTE] || - state == &compare_states[SPECIAL_UNSIGNED_LTE] || - state == &compare_states['<'] || - state == &compare_states[SPECIAL_UNSIGNED_LT]) { - set_state(compare_id, tmp, NULL, &compare_states['<']); - } else { + + switch (state_to_comparison(state)) { + case SPECIAL_EQUAL: + case SPECIAL_LTE: + case SPECIAL_UNSIGNED_LTE: + case '<': + case SPECIAL_UNSIGNED_LT: { + struct compare_data *data = state->data; + struct smatch_state *new; + + new = alloc_compare_state(data->var1, data->sym1, data->var2, data->sym2, '<'); + set_state(compare_id, tmp, NULL, new); + break; + } + default: set_state(compare_id, tmp, NULL, &undefined); } } END_FOR_EACH_PTR(tmp); @@ -307,17 +459,21 @@ static void match_logic(struct expression *expr) goto free; if (strcmp(left, right) > 0) { - char *tmp = left; + struct symbol *tmp_sym = left_sym; + char *tmp_name = left; + left = right; - right = tmp; + left_sym = right_sym; + right = tmp_name; + right_sym = tmp_sym; op = flip_op(expr->op); } else { op = expr->op; } false_op = falsify_op(op); snprintf(state_name, sizeof(state_name), "%s vs %s", left, right); - true_state = &compare_states[op]; - false_state = &compare_states[false_op]; + true_state = alloc_compare_state(left, left_sym, right, right_sym, op); + false_state = alloc_compare_state(left, left_sym, right, right_sym, false_op); set_true_false_states(compare_id, state_name, NULL, true_state, false_state); save_link(expr->left, state_name); @@ -343,13 +499,17 @@ static void add_comparison(struct expression *left, int comparison, struct expre goto free; if (strcmp(left_name, right_name) > 0) { - char *tmp = left_name; + struct symbol *tmp_sym = left_sym; + char *tmp_name = left_name; + left_name = right_name; - right_name = tmp; + left_sym = right_sym; + right_name = tmp_name; + right_sym = tmp_sym; comparison = flip_op(comparison); } snprintf(state_name, sizeof(state_name), "%s vs %s", left_name, right_name); - state = &compare_states[comparison]; + state = alloc_compare_state(left_name, left_sym, right_name, right_sym, comparison); set_state(compare_id, state_name, NULL, state); save_link(left, state_name); @@ -467,7 +627,7 @@ static int get_comparison_strings(char *one, char *two) snprintf(buf, sizeof(buf), "%s vs %s", one, two); state = get_state(compare_id, buf, NULL); if (state) - ret = PTR_INT(state->data); + ret = state_to_comparison(state); if (invert) ret = flip_op(ret); @@ -546,17 +706,27 @@ free: return ret_str; } +static void free_data(struct symbol *sym) +{ + if (__inline_fn) + return; + clear_compare_data_alloc(); +} + void register_comparison(int id) { compare_id = id; add_hook(&match_logic, CONDITION_HOOK); add_hook(&match_assign, ASSIGNMENT_HOOK); add_hook(&save_start_states, AFTER_DEF_HOOK); + add_unmatched_state_hook(compare_id, unmatched_comparison); + add_merge_hook(compare_id, &merge_compare_states); + add_hook(&free_data, AFTER_FUNC_HOOK); } void register_comparison_links(int id) { link_id = id; - add_merge_hook(link_id, &merge_func); + add_merge_hook(link_id, &merge_links); add_modification_hook(link_id, &match_modify); } diff --git a/validation/sm_compare6.c b/validation/sm_compare6.c new file mode 100644 index 00000000..60b8da73 --- /dev/null +++ b/validation/sm_compare6.c @@ -0,0 +1,23 @@ +#include "check_debug.h" + +int returns_less(int x) +{ + int y; + + if (x > 10) + y = 10; + else + y = x; + + __smatch_compare(x, y); + return y; +} + +/* + * check-name: smatch compare #6 + * check-command: smatch -I.. sm_compare6.c + * + * check-output-start +sm_compare6.c:12 returns_less() x >= y + * check-output-end + */ -- 2.11.4.GIT