function_hooks: function comparisons can imply a parameter value
[smatch.git] / smatch_parse_call_math.c
blob a66e597d3bb8c18764a9f794258c0bae7d0dda7b
1 /*
2 * Copyright (C) 2012 Oracle.
4 * This program is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU General Public License
6 * as published by the Free Software Foundation; either version 2
7 * of the License, or (at your option) any later version.
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
14 * You should have received a copy of the GNU General Public License
15 * along with this program; if not, see http://www.gnu.org/copyleft/gpl.txt
18 #include "smatch.h"
19 #include "smatch_slist.h"
20 #include "smatch_extra.h"
static int my_id;

/*
 * Allocation functions whose buffer size is taken straight from one of their
 * arguments.  @param selects that argument, using the same numbering as the
 * "$N" placeholders in size recipes.
 *
 * The table is only read from this file, so give it internal linkage (a
 * generic global name like this invites link-time clashes) and make it const.
 */
static const struct {
	const char *func;
	int param;
} alloc_functions[] = {
	{"kmalloc", 0},
	{"__kmalloc", 0},
	{"vmalloc", 0},
	{"__vmalloc", 0},
	{"__vmalloc_node", 0},
};
/*
 * Operator and operand stacks used by parse_call_math().  Entries are
 * malloc()ed by push_op()/push_val() and freed again when popped.
 */
DECLARE_PTR_LIST(sval_list, sval_t);

static struct string_list *op_list;
static struct sval_list *num_list;
40 static void push_val(sval_t sval)
42 sval_t *p;
44 p = malloc(sizeof(*p));
45 *p = sval;
46 add_ptr_list(&num_list, p);
49 static sval_t pop_val(void)
51 sval_t *p;
52 sval_t ret;
54 if (!num_list) {
55 sm_msg("internal bug: %s popping empty list", __func__);
56 ret.type = &llong_ctype;
57 ret.value = 0;
58 return ret;
60 p = last_ptr_list((struct ptr_list *)num_list);
61 delete_ptr_list_last((struct ptr_list **)&num_list);
62 ret = *p;
63 free(p);
65 return ret;
68 static void push_op(char c)
70 char *p;
72 p = malloc(1);
73 p[0] = c;
74 add_ptr_list(&op_list, p);
77 static char pop_op(void)
79 char *p;
80 char c;
82 if (!op_list) {
83 sm_msg("internal smatch error %s", __func__);
84 return '\0';
87 p = last_ptr_list((struct ptr_list *)op_list);
89 delete_ptr_list_last((struct ptr_list **)&op_list);
90 c = p[0];
91 free(p);
93 return c;
/*
 * Binding strength of a binary operator: '*' and '/' (2) bind tighter than
 * '+' and '-' (1).  Any other character, including '\0', gives 0.
 */
static int op_precedence(char c)
{
	if (c == '*' || c == '/')
		return 2;
	if (c == '+' || c == '-')
		return 1;
	return 0;
}
110 static int top_op_precedence(void)
112 char *p;
114 if (!op_list)
115 return 0;
117 p = last_ptr_list((struct ptr_list *)op_list);
118 return op_precedence(p[0]);
121 static void pop_until(char c)
123 char op;
124 sval_t left, right;
125 sval_t res;
127 while (top_op_precedence() && op_precedence(c) <= top_op_precedence()) {
128 op = pop_op();
129 right = pop_val();
130 left = pop_val();
131 res = sval_binop(left, op, right);
132 push_val(res);
136 static void discard_stacks(void)
138 while (op_list)
139 pop_op();
140 while (num_list)
141 pop_val();
144 static int get_implied_param(struct expression *call, int param, sval_t *sval)
146 struct expression *arg;
148 arg = get_argument_from_call_expr(call->args, param);
149 return get_implied_value(arg, sval);
152 static int read_number(struct expression *call, char *p, char **end, sval_t *sval)
154 long param;
156 while (*p == ' ')
157 p++;
159 if (*p == '$') {
160 p++;
161 param = strtol(p, &p, 10);
162 if (!get_implied_param(call, param, sval))
163 return 0;
164 *end = p;
165 } else {
166 sval->type = &llong_ctype;
167 sval->value = strtoll(p, end, 10);
168 if (*end == p)
169 return 0;
171 return 1;
/*
 * Skip leading spaces and return a pointer to the operator found there, or
 * NULL when the next character is not one of '+', '-', '*' or '/'.
 */
static char *read_op(char *p)
{
	while (*p == ' ')
		p++;

	if (*p == '+' || *p == '-' || *p == '*' || *p == '/')
		return p;
	return NULL;
}
190 int parse_call_math(struct expression *call, char *math, sval_t *sval)
192 sval_t tmp;
193 char *c;
195 /* try to implement shunting yard algorithm. */
197 c = (char *)math;
198 while (1) {
199 if (option_debug)
200 sm_msg("parsing %s", c);
202 /* read a number and push it onto the number stack */
203 if (!read_number(call, c, &c, &tmp))
204 goto fail;
205 push_val(tmp);
207 if (option_debug)
208 sm_msg("val = %s remaining = %s", sval_to_str(tmp), c);
210 if (!*c)
211 break;
212 if (*c == ']' && *(c + 1) == '\0')
213 break;
215 c = read_op(c);
216 if (!c)
217 goto fail;
219 if (option_debug)
220 sm_msg("op = %c remaining = %s", *c, c);
222 pop_until(*c);
223 push_op(*c);
224 c++;
227 pop_until(0);
228 *sval = pop_val();
229 return 1;
230 fail:
231 discard_stacks();
232 return 0;
235 int parse_call_math_rl(struct expression *call, char *math, struct range_list **rl)
237 struct expression *arg;
238 sval_t sval;
239 char *c = math;
240 int param;
242 if (parse_call_math(call, math, &sval)) {
243 *rl = alloc_rl(sval, sval);
244 return 1;
247 if (*c != '$')
248 return 0;
249 c++;
250 param = strtoll(c, &c, 10);
251 if (*c != ']')
252 return 0;
253 arg = get_argument_from_call_expr(call->args, param);
254 return get_implied_rl(arg, rl);
257 static struct smatch_state *alloc_state_sname(char *sname)
259 struct smatch_state *state;
261 state = __alloc_smatch_state(0);
262 state->name = sname;
263 state->data = INT_PTR(1);
264 return state;
267 static int get_arg_number(struct expression *expr)
269 struct symbol *sym;
270 struct symbol *arg;
271 int i;
273 expr = strip_expr(expr);
274 if (expr->type != EXPR_SYMBOL)
275 return -1;
276 sym = expr->symbol;
278 i = 0;
279 FOR_EACH_PTR(cur_func_sym->ctype.base_type->arguments, arg) {
280 if (arg == sym)
281 return i;
282 i++;
283 } END_FOR_EACH_PTR(arg);
285 return -1;
288 static int format_expr_helper(char *buf, int remaining, struct expression *expr)
290 int arg;
291 sval_t sval;
292 int ret;
293 char *cur;
295 if (!expr)
296 return 0;
298 cur = buf;
300 if (expr->type == EXPR_BINOP) {
301 ret = format_expr_helper(cur, remaining, expr->left);
302 if (ret == 0)
303 return 0;
304 remaining -= ret;
305 if (remaining <= 0)
306 return 0;
307 cur += ret;
309 ret = snprintf(cur, remaining, " %s ", show_special(expr->op));
310 remaining -= ret;
311 if (remaining <= 0)
312 return 0;
313 cur += ret;
315 ret = format_expr_helper(cur, remaining, expr->right);
316 if (ret == 0)
317 return 0;
318 remaining -= ret;
319 if (remaining <= 0)
320 return 0;
321 cur += ret;
322 return cur - buf;
325 arg = get_arg_number(expr);
326 if (arg >= 0) {
327 ret = snprintf(cur, remaining, "$%d", arg);
328 remaining -= ret;
329 if (remaining <= 0)
330 return 0;
331 return ret;
334 if (get_implied_value(expr, &sval)) {
335 ret = snprintf(cur, remaining, "%s", sval_to_str(sval));
336 remaining -= ret;
337 if (remaining <= 0)
338 return 0;
339 return ret;
342 return 0;
345 static char *format_expr(struct expression *expr)
347 char buf[256];
348 int ret;
350 ret = format_expr_helper(buf, sizeof(buf), expr);
351 if (ret == 0)
352 return NULL;
354 return alloc_sname(buf);
/*
 * Describe @expr as a recipe in terms of the current function's parameters.
 * When get_assigned_expr() knows what @expr was assigned from, that
 * expression is formatted instead.  Returns NULL if it cannot be expressed.
 */
char *get_value_in_terms_of_parameter_math(struct expression *expr)
{
	struct expression *assigned = get_assigned_expr(expr);

	return format_expr(assigned ? assigned : expr);
}
/*
 * Like get_value_in_terms_of_parameter_math(), but looks up the assigned
 * expression by variable name/symbol.  Returns NULL when no assignment is
 * known or it cannot be expressed as a recipe.
 */
char *get_value_in_terms_of_parameter_math_var_sym(const char *name, struct symbol *sym)
{
	struct expression *assigned;

	assigned = get_assigned_expr_name_sym(name, sym);
	return assigned ? format_expr(assigned) : NULL;
}
392 static void match_alloc(const char *fn, struct expression *expr, void *_size_arg)
394 int size_arg = PTR_INT(_size_arg);
395 struct expression *right;
396 struct expression *size_expr;
397 char *sname;
399 right = strip_expr(expr->right);
400 size_expr = get_argument_from_call_expr(right->args, size_arg);
402 sname = format_expr(size_expr);
403 if (!sname)
404 return;
405 set_state_expr(my_id, expr->left, alloc_state_sname(sname));
408 static char *swap_format(struct expression *call, char *format)
410 char buf[256];
411 sval_t sval;
412 long param;
413 struct expression *arg;
414 char *p;
415 char *out;
416 int ret;
418 if (format[0] == '$' && format[2] == '\0') {
419 param = strtol(format + 1, NULL, 10);
420 arg = get_argument_from_call_expr(call->args, param);
421 if (!arg)
422 return NULL;
423 return format_expr(arg);
426 buf[0] = '\0';
427 p = format;
428 out = buf;
429 while (*p) {
430 if (*p == '<') {
431 p++;
432 param = strtol(p, &p, 10);
433 if (*p != '>')
434 return NULL;
435 p++;
436 arg = get_argument_from_call_expr(call->args, param);
437 if (!arg)
438 return NULL;
439 param = get_arg_number(arg);
440 if (param >= 0) {
441 ret = snprintf(out, buf + sizeof(buf) - out, "$%ld", param);
442 out += ret;
443 if (out >= buf + sizeof(buf))
444 return NULL;
445 } else if (get_implied_value(arg, &sval)) {
446 ret = snprintf(out, buf + sizeof(buf) - out, "%s", sval_to_str(sval));
447 out += ret;
448 if (out >= buf + sizeof(buf))
449 return NULL;
450 } else {
451 return NULL;
454 *out = *p;
455 p++;
456 out++;
458 if (buf[0] == '\0')
459 return NULL;
460 return alloc_sname(buf);
/*
 * Recipe gathered by db_buf_size_callback().  If it is NULL when rows
 * arrive, the first row's value is stored.  Any later row with a different
 * value turns it into "invalid".
 */
static char *buf_size_recipe;
static int db_buf_size_callback(void *unused, int argc, char **argv, char **azColName)
{
	if (argc == 1) {
		if (buf_size_recipe == NULL)
			buf_size_recipe = alloc_sname(argv[0]);
		else if (strcmp(buf_size_recipe, argv[0]))
			buf_size_recipe = alloc_sname("invalid");
	}
	return 0;
}
476 static char *get_allocation_recipe_from_call(struct expression *expr)
478 struct symbol *sym;
479 static char sql_filter[1024];
480 int i;
482 if (is_fake_call(expr))
483 return NULL;
484 expr = strip_expr(expr);
485 if (expr->fn->type != EXPR_SYMBOL)
486 return NULL;
487 sym = expr->fn->symbol;
488 if (!sym)
489 return NULL;
491 for (i = 0; i < ARRAY_SIZE(alloc_functions); i++) {
492 if (strcmp(sym->ident->name, alloc_functions[i].func) == 0) {
493 char buf[32];
495 snprintf(buf, sizeof(buf), "$%d", alloc_functions[i].param);
496 buf_size_recipe = alloc_sname(buf);
497 return swap_format(expr, buf_size_recipe);
501 if (sym->ctype.modifiers & MOD_STATIC) {
502 snprintf(sql_filter, 1024, "file = '%s' and function = '%s';",
503 get_filename(), sym->ident->name);
504 } else {
505 snprintf(sql_filter, 1024, "function = '%s' and static = 0;",
506 sym->ident->name);
509 buf_size_recipe = NULL;
510 run_sql(db_buf_size_callback, NULL,
511 "select value from return_states where type=%d and %s",
512 BUF_SIZE, sql_filter);
513 if (!buf_size_recipe || strcmp(buf_size_recipe, "invalid") == 0)
514 return NULL;
515 return swap_format(expr, buf_size_recipe);
518 static void match_call_assignment(struct expression *expr)
520 char *sname;
522 sname = get_allocation_recipe_from_call(expr->right);
523 if (!sname)
524 return;
525 set_state_expr(my_id, expr->left, alloc_state_sname(sname));
528 static void match_returns_call(int return_id, char *return_ranges, struct expression *call)
530 char *sname;
532 sname = get_allocation_recipe_from_call(call);
533 if (option_debug)
534 sm_msg("sname = %s", sname);
535 if (!sname)
536 return;
538 sql_insert_return_states(return_id, return_ranges, BUF_SIZE, -1, "",
539 sname);
542 static void print_returned_allocations(int return_id, char *return_ranges, struct expression *expr)
544 struct smatch_state *state;
545 struct symbol *sym;
546 char *name;
548 expr = strip_expr(expr);
549 if (!expr)
550 return;
552 if (expr->type == EXPR_CALL) {
553 match_returns_call(return_id, return_ranges, expr);
554 return;
557 name = expr_to_var_sym(expr, &sym);
558 if (!name || !sym)
559 goto free;
561 state = get_state(my_id, name, sym);
562 if (!state || !state->data)
563 goto free;
565 sql_insert_return_states(return_id, return_ranges, BUF_SIZE, -1, "",
566 state->name);
567 free:
568 free_string(name);
571 void register_parse_call_math(int id)
573 int i;
575 my_id = id;
577 for (i = 0; i < ARRAY_SIZE(alloc_functions); i++)
578 add_function_assign_hook(alloc_functions[i].func, &match_alloc,
579 INT_PTR(alloc_functions[i].param));
580 add_hook(&match_call_assignment, CALL_ASSIGNMENT_HOOK);
581 add_split_return_callback(print_returned_allocations);