From 10e0baba956daa8c6577692e6e70447cf1d44f98 Mon Sep 17 00:00:00 2001 From: Sven Verdoolaege Date: Tue, 20 Nov 2012 20:22:35 +0100 Subject: [PATCH] add before_each_for/after_each_for callbacks --- doc/user.pod | 25 ++++++++ include/isl/ast_build.h | 4 +- isl_ast_build.c | 44 ++++++++++++++ isl_ast_build_private.h | 14 +++++ isl_ast_codegen.c | 40 +++++++++++++ isl_test.c | 154 ++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 279 insertions(+), 2 deletions(-) diff --git a/doc/user.pod b/doc/user.pod index 78fa57e0..e8969c68 100644 --- a/doc/user.pod +++ b/doc/user.pod @@ -5975,9 +5975,34 @@ user defined node created using the following function. __isl_take isl_ast_node *node, __isl_keep isl_ast_build *build, void *user), void *user); + __isl_give isl_ast_build * + isl_ast_build_set_before_each_for( + __isl_take isl_ast_build *build, + __isl_give isl_id *(*fn)( + __isl_keep isl_ast_build *build, + void *user), void *user); + __isl_give isl_ast_build * + isl_ast_build_set_after_each_for( + __isl_take isl_ast_build *build, + __isl_give isl_ast_node *(*fn)( + __isl_take isl_ast_node *node, + __isl_keep isl_ast_build *build, + void *user), void *user); The callback set by C will be called for each domain AST node. +The callbacks set by C +and C will be called +for each for AST node. The first will be called in depth-first +pre-order, while the second will be called in depth-first post-order. +Since C is called before the for +node is actually constructed, it is only passed an C. +The returned C will be added as an annotation (using +C) to the constructed for node. +In particular, if the user has also specified an C +callback, then the annotation can be retrieved from the node passed to +that callback using C. +All callbacks should C on failure. The given C can be used to create new C objects using C or C. diff --git a/include/isl/ast_build.h b/include/isl/ast_build.h index daf78dc5..7294ed7c 100644 --- a/include/isl/ast_build.h +++ b/include/isl/ast_build.h @@ -64,8 +64,8 @@ __isl_give isl_ast_build *isl_ast_build_set_at_each_domain( __isl_keep isl_ast_build *build, void *user), void *user); __isl_give isl_ast_build *isl_ast_build_set_before_each_for( __isl_take isl_ast_build *build, - __isl_give isl_ast_node *(*fn)(__isl_take isl_ast_node *node, - __isl_keep isl_ast_build *build, void *user), void *user); + __isl_give isl_id *(*fn)(__isl_keep isl_ast_build *build, + void *user), void *user); __isl_give isl_ast_build *isl_ast_build_set_after_each_for( __isl_take isl_ast_build *build, __isl_give isl_ast_node *(*fn)(__isl_take isl_ast_node *node, diff --git a/isl_ast_build.c b/isl_ast_build.c index 0673726b..26aabab5 100644 --- a/isl_ast_build.c +++ b/isl_ast_build.c @@ -180,6 +180,10 @@ __isl_give isl_ast_build *isl_ast_build_dup(__isl_keep isl_ast_build *build) dup->options = isl_union_map_copy(build->options); dup->at_each_domain = build->at_each_domain; dup->at_each_domain_user = build->at_each_domain_user; + dup->before_each_for = build->before_each_for; + dup->before_each_for_user = build->before_each_for_user; + dup->after_each_for = build->after_each_for; + dup->after_each_for_user = build->after_each_for_user; dup->create_leaf = build->create_leaf; dup->create_leaf_user = build->create_leaf_user; @@ -338,6 +342,42 @@ __isl_give isl_ast_build *isl_ast_build_set_at_each_domain( return build; } +/* Set the "before_each_for" callback of "build" to "fn". + */ +__isl_give isl_ast_build *isl_ast_build_set_before_each_for( + __isl_take isl_ast_build *build, + __isl_give isl_id *(*fn)(__isl_keep isl_ast_build *build, + void *user), void *user) +{ + build = isl_ast_build_cow(build); + + if (!build) + return NULL; + + build->before_each_for = fn; + build->before_each_for_user = user; + + return build; +} + +/* Set the "after_each_for" callback of "build" to "fn". + */ +__isl_give isl_ast_build *isl_ast_build_set_after_each_for( + __isl_take isl_ast_build *build, + __isl_give isl_ast_node *(*fn)(__isl_take isl_ast_node *node, + __isl_keep isl_ast_build *build, void *user), void *user) +{ + build = isl_ast_build_cow(build); + + if (!build) + return NULL; + + build->after_each_for = fn; + build->after_each_for_user = user; + + return build; +} + /* Set the "create_leaf" callback of "build" to "fn". */ __isl_give isl_ast_build *isl_ast_build_set_create_leaf( @@ -374,6 +414,10 @@ __isl_give isl_ast_build *isl_ast_build_clear_local_info( build->at_each_domain = NULL; build->at_each_domain_user = NULL; + build->before_each_for = NULL; + build->before_each_for_user = NULL; + build->after_each_for = NULL; + build->after_each_for_user = NULL; build->create_leaf = NULL; build->create_leaf_user = NULL; diff --git a/isl_ast_build_private.h b/isl_ast_build_private.h index 0ee88c51..5881efff 100644 --- a/isl_ast_build_private.h +++ b/isl_ast_build_private.h @@ -98,6 +98,12 @@ enum isl_ast_build_domain_type { * an element of the domain. Each of these nodes is a user node * with as expression a call expression. * + * The "before_each_for" callback is called on each for node before + * its children have been created. + * + * The "after_each_for" callback is called on each for node after + * its children have been created. + * * "executed" contains the inverse schedule at this point * of the AST generation. * It is currently only used in isl_ast_build_get_schedule, which is @@ -131,6 +137,14 @@ struct isl_ast_build { __isl_keep isl_ast_build *build, void *user); void *at_each_domain_user; + __isl_give isl_id *(*before_each_for)( + __isl_keep isl_ast_build *context, void *user); + void *before_each_for_user; + __isl_give isl_ast_node *(*after_each_for)( + __isl_take isl_ast_node *node, + __isl_keep isl_ast_build *context, void *user); + void *after_each_for_user; + __isl_give isl_ast_node *(*create_leaf)( __isl_take isl_ast_build *build, void *user); void *create_leaf_user; diff --git a/isl_ast_codegen.c b/isl_ast_codegen.c index 89bc512c..e0e217dc 100644 --- a/isl_ast_codegen.c +++ b/isl_ast_codegen.c @@ -269,6 +269,38 @@ error: data.list = NULL; return data.list; } +/* Call the before_each_for callback, if requested by the user. + */ +static __isl_give isl_ast_node *before_each_for(__isl_take isl_ast_node *node, + __isl_keep isl_ast_build *build) +{ + isl_id *id; + + if (!node || !build) + return isl_ast_node_free(node); + if (!build->before_each_for) + return node; + id = build->before_each_for(build, build->before_each_for_user); + node = isl_ast_node_set_annotation(node, id); + return node; +} + +/* Call the after_each_for callback, if requested by the user. + */ +static __isl_give isl_ast_graft *after_each_for(__isl_keep isl_ast_graft *graft, + __isl_keep isl_ast_build *build) +{ + if (!graft || !build) + isl_ast_graft_free(graft); + if (!build->after_each_for) + return graft; + graft->node = build->after_each_for(graft->node, build, + build->after_each_for_user); + if (!graft->node) + return isl_ast_graft_free(graft); + return graft; +} + /* Eliminate the schedule dimension "pos" from "executed" and return * the result. */ @@ -1177,6 +1209,9 @@ static __isl_give isl_ast_node *create_for(__isl_keep isl_ast_build *build, * we performed separation with explicit bounds. * The very first step is then to copy these constraints to "bounds". * + * Since we may be calling before_each_for and after_each_for + * callbacks, we record the current inverse schedule in the build. + * * We consider three builds, * "build" is the one in which the current level is created, * "body_build" is the build in which the next level is created, @@ -1230,6 +1265,7 @@ static __isl_give isl_ast_graft *create_node_scaled( domain = isl_set_detect_equalities(domain); hull = isl_set_unshifted_simple_hull(isl_set_copy(domain)); bounds = isl_basic_set_intersect(bounds, hull); + build = isl_ast_build_set_executed(build, isl_union_map_copy(executed)); depth = isl_ast_build_get_depth(build); sub_build = isl_ast_build_copy(build); @@ -1247,6 +1283,8 @@ static __isl_give isl_ast_graft *create_node_scaled( body_build = isl_ast_build_copy(sub_build); body_build = isl_ast_build_increase_depth(body_build); + if (!eliminated) + node = before_each_for(node, body_build); children = generate_next_level(executed, isl_ast_build_copy(body_build)); @@ -1259,6 +1297,8 @@ static __isl_give isl_ast_graft *create_node_scaled( graft = refine_degenerate(graft, bounds, build, sub_build); else graft = refine_generic(graft, bounds, domain, build); + if (!eliminated) + graft = after_each_for(graft, body_build); isl_ast_build_free(body_build); isl_ast_build_free(sub_build); diff --git a/isl_test.c b/isl_test.c index 09c88f9f..b39c6829 100644 --- a/isl_test.c +++ b/isl_test.c @@ -3438,6 +3438,158 @@ static int test_ast(isl_ctx *ctx) return 0; } +/* Internal data structure for before_for and after_for callbacks. + * + * depth is the current depth + * before is the number of times before_for has been called + * after is the number of times after_for has been called + */ +struct isl_test_codegen_data { + int depth; + int before; + int after; +}; + +/* This function is called before each for loop in the AST generated + * from test_ast_gen1. + * + * Increment the number of calls and the depth. + * Check that the space returned by isl_ast_build_get_schedule_space + * matches the target space of the schedule returned by + * isl_ast_build_get_schedule. + * Return an isl_id that is checked by the corresponding call + * to after_for. + */ +static __isl_give isl_id *before_for(__isl_keep isl_ast_build *build, + void *user) +{ + struct isl_test_codegen_data *data = user; + isl_ctx *ctx; + isl_space *space; + isl_union_map *schedule; + isl_union_set *uset; + isl_set *set; + int empty; + char name[] = "d0"; + + ctx = isl_ast_build_get_ctx(build); + + if (data->before >= 3) + isl_die(ctx, isl_error_unknown, + "unexpected number of for nodes", return NULL); + if (data->depth >= 2) + isl_die(ctx, isl_error_unknown, + "unexpected depth", return NULL); + + snprintf(name, sizeof(name), "d%d", data->depth); + data->before++; + data->depth++; + + schedule = isl_ast_build_get_schedule(build); + uset = isl_union_map_range(schedule); + if (isl_union_set_n_set(uset) != 1) { + isl_union_set_free(uset); + isl_die(ctx, isl_error_unknown, + "expecting single range space", return NULL); + } + + space = isl_ast_build_get_schedule_space(build); + set = isl_union_set_extract_set(uset, space); + isl_union_set_free(uset); + empty = isl_set_is_empty(set); + isl_set_free(set); + + if (empty < 0) + return NULL; + if (empty) + isl_die(ctx, isl_error_unknown, + "spaces don't match", return NULL); + + return isl_id_alloc(ctx, name, NULL); +} + +/* This function is called after each for loop in the AST generated + * from test_ast_gen1. + * + * Increment the number of calls and decrement the depth. + * Check that the annotation attached to the node matches + * the isl_id returned by the corresponding call to before_for. + */ +static __isl_give isl_ast_node *after_for(__isl_take isl_ast_node *node, + __isl_keep isl_ast_build *build, void *user) +{ + struct isl_test_codegen_data *data = user; + isl_id *id; + const char *name; + int valid; + + data->after++; + data->depth--; + + if (data->after > data->before) + isl_die(isl_ast_node_get_ctx(node), isl_error_unknown, + "mismatch in number of for nodes", + return isl_ast_node_free(node)); + + id = isl_ast_node_get_annotation(node); + if (!id) + isl_die(isl_ast_node_get_ctx(node), isl_error_unknown, + "missing annotation", return isl_ast_node_free(node)); + + name = isl_id_get_name(id); + valid = name && atoi(name + 1) == data->depth; + isl_id_free(id); + + if (!valid) + isl_die(isl_ast_node_get_ctx(node), isl_error_unknown, + "wrong annotation", return isl_ast_node_free(node)); + + return node; +} + +/* Check that the before_each_for and after_each_for callbacks + * are called for each for loop in the generated code, + * that they are called in the right order and that the isl_id + * returned from the before_each_for callback is attached to + * the isl_ast_node passed to the corresponding after_each_for call. + */ +static int test_ast_gen1(isl_ctx *ctx) +{ + const char *str; + isl_set *set; + isl_union_map *schedule; + isl_ast_build *build; + isl_ast_node *tree; + struct isl_test_codegen_data data; + + str = "[N] -> { : N >= 10 }"; + set = isl_set_read_from_str(ctx, str); + str = "[N] -> { A[i,j] -> S[8,i,3,j] : 0 <= i,j <= N; " + "B[i,j] -> S[8,j,9,i] : 0 <= i,j <= N }"; + schedule = isl_union_map_read_from_str(ctx, str); + + data.before = 0; + data.after = 0; + data.depth = 0; + build = isl_ast_build_from_context(set); + build = isl_ast_build_set_before_each_for(build, + &before_for, &data); + build = isl_ast_build_set_after_each_for(build, + &after_for, &data); + tree = isl_ast_build_ast_from_schedule(build, schedule); + isl_ast_build_free(build); + if (!tree) + return -1; + + isl_ast_node_free(tree); + + if (data.before != 3 || data.after != 3) + isl_die(ctx, isl_error_unknown, + "unexpected number of for nodes", return -1); + + return 0; +} + /* Check that the AST generator handles domains that are integrally disjoint * but not ratinoally disjoint. */ @@ -3576,6 +3728,8 @@ static int test_ast_gen4(isl_ctx *ctx) static int test_ast_gen(isl_ctx *ctx) { + if (test_ast_gen1(ctx) < 0) + return -1; if (test_ast_gen2(ctx) < 0) return -1; if (test_ast_gen3(ctx) < 0) -- 2.11.4.GIT