comparison: if "a < b" and "b < c" then "a < c"
author Dan Carpenter <dan.carpenter@oracle.com>
Thu, 27 Jun 2013 22:18:56 +0000 (Fri, 28 Jun 2013 01:18:56 +0300)
committer Dan Carpenter <dan.carpenter@oracle.com>
Thu, 27 Jun 2013 22:18:56 +0000 (Fri, 28 Jun 2013 01:18:56 +0300)
Comparisons can be inherited.  Besides the transitive case in the subject
line, it's probably more common that we say: if "a == b" and "a < c", then "b < c".

Signed-off-by: Dan Carpenter <dan.carpenter@oracle.com>
smatch_comparison.c
validation/sm_compare10.c [new file with mode: 0644]

index e62af42..f8f5ea4 100644 (file)
@@ -220,6 +220,10 @@ static int remove_unsigned_from_comparison(int op)
        }
 }
 
+/*
+ * This is for when you merge states "a < b" and "a == b", the result is that
+ * we can say for sure, "a <= b" after the merge.
+ */
 static int merge_comparisons(int one, int two)
 {
        int LT, EQ, GT;
@@ -282,6 +286,74 @@ static int merge_comparisons(int one, int two)
        return 0;
 }
 
+/*
+ * This is for if you have "a < b" and "b <= c".  Or in other words,
+ * "a < b <= c".  You would call this like combine_comparisons('<', '<=').
+ * The return comparison would be '<'.
+ *
+ * Both ops are first normalized with remove_unsigned_from_comparison().
+ * If either op is '==', the other (normalized) op is returned unchanged.
+ * Otherwise the result is 0 ("nothing can be concluded about a vs c")
+ * unless both ops point in the same direction ('<'/'<=' or '>'/'>=').
+ * The result allows equality ('<=' / '>=') only when both ops allow it.
+ *
+ * This function is different from merge_comparisons(), for example:
+ * merge_comparisons('<', '==') returns '<='
+ * combine_comparisons('<', '==') returns '<'
+ */
+static int combine_comparisons(int left_compare, int right_compare)
+{
+       int LT, EQ, GT;
+
+       left_compare = remove_unsigned_from_comparison(left_compare);
+       right_compare = remove_unsigned_from_comparison(right_compare);
+
+       LT = EQ = GT = 0;
+
+       switch (left_compare) {
+       case '<':
+               LT++;
+               break;
+       case SPECIAL_LTE:
+               LT++;
+               EQ++;
+               break;
+       case SPECIAL_EQUAL:
+               /* "a == b", so "a vs c" is exactly "b vs c" */
+               return right_compare;
+       case SPECIAL_GTE:
+               GT++;
+               EQ++;
+               break;
+       case '>':
+               GT++;
+               break;
+       default:
+               /* unsupported op: adds nothing, so the result ends up 0 */
+               break;
+       }
+
+       switch (right_compare) {
+       case '<':
+               LT++;
+               break;
+       case SPECIAL_LTE:
+               LT++;
+               EQ++;
+               break;
+       case SPECIAL_EQUAL:
+               /* "b == c", so "a vs c" is exactly "a vs b" */
+               return left_compare;
+       case SPECIAL_GTE:
+               GT++;
+               EQ++;
+               break;
+       case '>':
+               GT++;
+               break;
+       default:
+               /* unsupported op: adds nothing, so the result ends up 0 */
+               break;
+       }
+
+       /* both ops must point the same way for anything to carry over */
+       if (LT == 2) {
+               if (EQ == 2)
+                       return SPECIAL_LTE;
+               return '<';
+       }
+
+       if (GT == 2) {
+               if (EQ == 2)
+                       return SPECIAL_GTE;
+               return '>';
+       }
+       return 0;
+}
+
 static struct smatch_state *merge_compare_states(struct smatch_state *s1, struct smatch_state *s2)
 {
        struct compare_data *data = s1->data;
@@ -477,6 +549,82 @@ static void match_modify(struct sm_state *sm, struct expression *mod_expr)
        set_state(link_id, sm->name, sm->sym, &undefined);
 }
 
+/*
+ * The caller is about to split into true/false paths on
+ * "left_var <left_comparison> mid_var".  "links" lists the names of the
+ * compare states already recorded for mid_var.  For each such
+ * "mid_var vs right_var" comparison, work out what the true path
+ * (left_comparison) and the false path (falsify_op(left_comparison)) imply
+ * about "left_var vs right_var", and set those derived comparisons as
+ * true/false states.
+ *
+ * Derived states are named "X vs Y" with the strcmp()-smaller name first,
+ * and the op is flipped to match when the pair is reordered.
+ */
+static void update_tf_links(const char *left_var, struct symbol *left_sym,
+                           int left_comparison,
+                           const char *mid_var, struct symbol *mid_sym,
+                           struct string_list *links)
+{
+       struct smatch_state *state;
+       struct smatch_state *true_state, *false_state;
+       struct compare_data *data;
+       const char *right_var;
+       struct symbol *right_sym;
+       const char *var1, *var2;
+       struct symbol *sym1, *sym2;
+       int right_comparison;
+       int true_comparison;
+       int false_comparison;
+       char *tmp;
+       char state_name[256];
+
+       FOR_EACH_PTR(links, tmp) {
+               state = get_state(compare_id, tmp, NULL);
+               if (!state || !state->data)
+                       continue;
+               data = state->data;
+
+               /* orient the stored comparison as "mid_var <op> right_var" */
+               right_comparison = data->comparison;
+               right_var = data->var2;
+               right_sym = data->sym2;
+               if (var_sym_eq(mid_var, mid_sym, right_var, right_sym)) {
+                       right_var = data->var1;
+                       right_sym = data->sym1;
+                       right_comparison = flip_op(right_comparison);
+               }
+
+               /* "left_var vs left_var" carries no information */
+               if (var_sym_eq(left_var, left_sym, right_var, right_sym))
+                       continue;
+
+               true_comparison = combine_comparisons(left_comparison, right_comparison);
+               false_comparison = combine_comparisons(falsify_op(left_comparison), right_comparison);
+
+               /*
+                * Put the pair in state-name order.  Use per-iteration copies:
+                * swapping left_var/left_sym themselves would corrupt them
+                * for every following link.
+                */
+               var1 = left_var;
+               sym1 = left_sym;
+               var2 = right_var;
+               sym2 = right_sym;
+               if (strcmp(var1, var2) > 0) {
+                       var1 = right_var;
+                       sym1 = right_sym;
+                       var2 = left_var;
+                       sym2 = left_sym;
+                       true_comparison = flip_op(true_comparison);
+                       false_comparison = flip_op(false_comparison);
+               }
+
+               if (!true_comparison && !false_comparison)
+                       continue;
+
+               true_state = NULL;
+               false_state = NULL;
+               if (true_comparison)
+                       true_state = alloc_compare_state(var1, sym1, true_comparison, var2, sym2);
+               if (false_comparison)
+                       false_state = alloc_compare_state(var1, sym1, false_comparison, var2, sym2);
+
+               snprintf(state_name, sizeof(state_name), "%s vs %s", var1, var2);
+               set_true_false_states(compare_id, state_name, NULL, true_state, false_state);
+               save_link_var_sym(var1, sym1, state_name);
+               save_link_var_sym(var2, sym2, state_name);
+       } END_FOR_EACH_PTR(tmp);
+}
+
+/*
+ * Propagate the comparison "var1 <comparison> var2" (about to be split into
+ * true/false paths) through the comparisons already linked to each side:
+ * first the links of var2, then the links of var1, with the op flipped so
+ * that var1 plays the "middle" role.
+ */
+static void update_tf_data(struct compare_data *tdata)
+{
+       struct smatch_state *links_state;
+
+       links_state = get_state(link_id, tdata->var2, tdata->sym2);
+       if (links_state != NULL)
+               update_tf_links(tdata->var1, tdata->sym1, tdata->comparison,
+                               tdata->var2, tdata->sym2, links_state->data);
+
+       /* re-read: the call above may have added links for var1 */
+       links_state = get_state(link_id, tdata->var1, tdata->sym1);
+       if (links_state != NULL)
+               update_tf_links(tdata->var2, tdata->sym2, flip_op(tdata->comparison),
+                               tdata->var1, tdata->sym1, links_state->data);
+}
+
 static void match_compare(struct expression *expr)
 {
        char *left = NULL;
@@ -512,6 +660,8 @@ static void match_compare(struct expression *expr)
        true_state = alloc_compare_state(left, left_sym, op, right, right_sym);
        false_state = alloc_compare_state(left, left_sym, false_op, right, right_sym);
 
+       update_tf_data(true_state->data);
+
        set_true_false_states(compare_id, state_name, NULL, true_state, false_state);
        save_link(expr->left, state_name);
        save_link(expr->right, state_name);
diff --git a/validation/sm_compare10.c b/validation/sm_compare10.c
new file mode 100644 (file)
index 0000000..410e16c
--- /dev/null
@@ -0,0 +1,20 @@
+#include "check_debug.h"
+
+int a, b, c;
+static void options_write(void)
+{
+       if (c <= b)     /* past this: b < c */
+               return;
+       if (a >= b)     /* past this: a < b, so a < c */
+               return;
+       __smatch_compare(a, c);
+}
+
+/*
+ * check-name: smatch compare #10
+ * check-command: smatch -I.. sm_compare10.c
+ *
+ * check-output-start
+sm_compare10.c:10 options_write() a < c
+ * check-output-end
+ */