extract generic isl_*_list_foreach_scc from isl_ast_codegen.c
authorSven Verdoolaege <skimo@kotnet.org>
Mon, 1 Apr 2013 09:47:03 +0000 (1 11:47 +0200)
committerSven Verdoolaege <skimo@kotnet.org>
Mon, 22 Apr 2013 07:38:20 +0000 (22 09:38 +0200)
isl_ast_codegen.c contained two implementations of this function
that have now been replaced by two calls to the extracted functions.
This makes isl_ast_codegen.c less dependent on the internals of
isl_basic_set_list.  We should also be able to later reuse the
extracted function in other contexts.  Once all users of isl_tarjan
have been converted to this interface, we probably want to remove
the old interface and perform the calls to callback as soon as
each SCC is found rather than waiting until all of them have been found.

Signed-off-by: Sven Verdoolaege <skimo@kotnet.org>
doc/user.pod
include/isl/list.h
isl_ast_codegen.c
isl_list_templ.c

index c8bd428..49d12f4 100644 (file)
@@ -3133,6 +3133,18 @@ Lists can be inspected using the following functions.
        int isl_set_list_foreach(__isl_keep isl_set_list *list,
                int (*fn)(__isl_take isl_set *el, void *user),
                void *user);
+       int isl_set_list_foreach_scc(__isl_keep isl_set_list *list,
+               int (*follows)(__isl_keep isl_set *a,
+                       __isl_keep isl_set *b, void *user),
+               void *follows_user
+               int (*fn)(__isl_take isl_set *el, void *user),
+               void *fn_user);
+
+The function C<isl_set_list_foreach_scc> calls C<fn> on each of the
+strongly connected components of the graph with as vertices the elements
+of C<list> and a directed edge from vertex C<b> to vertex C<a>
+iff C<follows(a, b)> returns C<1>.  The callbacks C<follows> and C<fn>
+should return C<-1> on error.
 
 Lists can be printed using
 
index d13987a..3e953af 100644 (file)
@@ -53,6 +53,12 @@ __isl_give isl_##EL##_list *isl_##EL##_list_sort(                    \
        int (*cmp)(__isl_keep struct isl_##EL *a,                       \
                __isl_keep struct isl_##EL *b,                          \
                void *user), void *user);                               \
+int isl_##EL##_list_foreach_scc(__isl_keep isl_##EL##_list *list,      \
+       int (*follows)(__isl_keep struct isl_##EL *a,                   \
+                       __isl_keep struct isl_##EL *b, void *user),     \
+       void *follows_user,                                             \
+       int (*fn)(__isl_take isl_##EL##_list *scc, void *user),         \
+       void *fn_user);                                                 \
 __isl_give isl_printer *isl_printer_print_##EL##_list(                 \
        __isl_take isl_printer *p, __isl_keep isl_##EL##_list *list);   \
 void isl_##EL##_list_dump(__isl_keep isl_##EL##_list *list);
index c09b2e4..58ba539 100644 (file)
@@ -1628,30 +1628,24 @@ done:
        return list;
 }
 
-struct isl_domain_follows_at_depth_data {
-       int depth;
-       isl_basic_set **piece;
-};
-
 /* Does any element of i follow or coincide with any element of j
- * at the current depth (data->depth) for equal values of the outer
- * dimensions?
+ * at the current depth for equal values of the outer dimensions?
  */
-static int domain_follows_at_depth(int i, int j, void *user)
+static int domain_follows_at_depth(__isl_keep isl_basic_set *i,
+       __isl_keep isl_basic_set *j, void *user)
 {
-       struct isl_domain_follows_at_depth_data *data = user;
+       int depth = *(int *) user;
        isl_basic_map *test;
        int empty;
        int l;
 
-       test = isl_basic_map_from_domain_and_range(
-                       isl_basic_set_copy(data->piece[i]),
-                       isl_basic_set_copy(data->piece[j]));
-       for (l = 0; l < data->depth; ++l)
+       test = isl_basic_map_from_domain_and_range(isl_basic_set_copy(i),
+                                                   isl_basic_set_copy(j));
+       for (l = 0; l < depth; ++l)
                test = isl_basic_map_equate(test, isl_dim_in, l,
                                                isl_dim_out, l);
-       test = isl_basic_map_order_ge(test, isl_dim_in, data->depth,
-                                       isl_dim_out, data->depth);
+       test = isl_basic_map_order_ge(test, isl_dim_in, depth,
+                                       isl_dim_out, depth);
        empty = isl_basic_map_is_empty(test);
        isl_basic_map_free(test);
 
@@ -1663,14 +1657,26 @@ static __isl_give isl_ast_graft_list *generate_sorted_domains(
        __isl_keep isl_union_map *executed,
        __isl_keep isl_ast_build *build);
 
-/* Generate code for the "n" schedule domains in "domain_list"
- * with positions specified by the entries of the "pos" array
+/* Internal data structure for add_nodes.
+ *
+ * "executed" and "build" are extra arguments to be passed to add_node.
+ * "list" collects the results.
+ */
+struct isl_add_nodes_data {
+       isl_union_map *executed;
+       isl_ast_build *build;
+
+       isl_ast_graft_list *list;
+};
+
+/* Generate code for the schedule domains in "scc"
  * and add the results to "list".
  *
- * The "n" domains form a strongly connected component in the ordering.
- * If n is larger than 1, then this means that we cannot determine a valid
- * ordering for the n domains in the component.  This should be fairly
- * rare because the individual domains have been made disjoint first.
+ * The domains in "scc" form a strongly connected component in the ordering.
+ * If the number of domains in "scc" is larger than 1, then this means
+ * that we cannot determine a valid ordering for the domains in the component.
+ * This should be fairly rare because the individual domains
+ * have been made disjoint first.
  * The problem is that the domains may be integrally disjoint but not
  * rationally disjoint.  For example, we may have domains
  *
@@ -1697,40 +1703,41 @@ static __isl_give isl_ast_graft_list *generate_sorted_domains(
  * convex combination in terms of a and b and in terms of c and d.
  * Taking the same combination of i and j gives a point in the intersection.
  */
-static __isl_give isl_ast_graft_list *add_nodes(
-       __isl_take isl_ast_graft_list *list, int *pos, int n,
-       __isl_keep isl_basic_set_list *domain_list,
-       __isl_keep isl_union_map *executed,
-       __isl_keep isl_ast_build *build)
+static int add_nodes(__isl_take isl_basic_set_list *scc, void *user)
 {
-       int i;
+       struct isl_add_nodes_data *data = user;
+       int i, n;
        isl_basic_set *bset;
        isl_set *set;
 
-       bset = isl_basic_set_list_get_basic_set(domain_list, pos[0]);
-       if (n == 1)
-               return add_node(list, isl_union_map_copy(executed), bset,
-                               isl_ast_build_copy(build));
+       n = isl_basic_set_list_n_basic_set(scc);
+       bset = isl_basic_set_list_get_basic_set(scc, 0);
+       if (n == 1) {
+               isl_basic_set_list_free(scc);
+               data->list = add_node(data->list,
+                               isl_union_map_copy(data->executed), bset,
+                               isl_ast_build_copy(data->build));
+               return data->list ? 0 : -1;
+       }
 
        set = isl_set_from_basic_set(bset);
        for (i = 1; i < n; ++i) {
-               bset = isl_basic_set_list_get_basic_set(domain_list, pos[i]);
+               bset = isl_basic_set_list_get_basic_set(scc, i);
                set = isl_set_union(set, isl_set_from_basic_set(bset));
        }
 
        set = isl_set_make_disjoint(set);
        if (isl_set_n_basic_set(set) == n)
-               isl_die(isl_ast_graft_list_get_ctx(list), isl_error_internal,
-                       "unable to separate loop parts", goto error);
-       domain_list = isl_basic_set_list_from_set(set);
-       list = isl_ast_graft_list_concat(list,
-                   generate_sorted_domains(domain_list, executed, build));
-       isl_basic_set_list_free(domain_list);
+               isl_die(isl_basic_set_list_get_ctx(scc), isl_error_internal,
+                       "unable to separate loop parts",
+                       set = isl_set_free(set));
+       isl_basic_set_list_free(scc);
+       scc = isl_basic_set_list_from_set(set);
+       data->list = isl_ast_graft_list_concat(data->list,
+                   generate_sorted_domains(scc, data->executed, data->build));
+       isl_basic_set_list_free(scc);
 
-       return list;
-error:
-       isl_set_free(set);
-       return isl_ast_graft_list_free(list);
+       return data->list ? 0 : -1;
 }
 
 /* Sort the domains in "domain_list" according to the execution order
@@ -1751,71 +1758,47 @@ static __isl_give isl_ast_graft_list *generate_sorted_domains(
        __isl_keep isl_union_map *executed, __isl_keep isl_ast_build *build)
 {
        isl_ctx *ctx;
-       isl_ast_graft_list *list;
-       struct isl_domain_follows_at_depth_data data;
-       struct isl_tarjan_graph *g;
-       int i, n;
+       struct isl_add_nodes_data data;
+       int depth;
+       int n;
 
        if (!domain_list)
                return NULL;
 
        ctx = isl_basic_set_list_get_ctx(domain_list);
        n = isl_basic_set_list_n_basic_set(domain_list);
-       list = isl_ast_graft_list_alloc(ctx, n);
+       data.list = isl_ast_graft_list_alloc(ctx, n);
        if (n == 0)
-               return list;
+               return data.list;
        if (n == 1)
-               return add_node(list, isl_union_map_copy(executed),
+               return add_node(data.list, isl_union_map_copy(executed),
                        isl_basic_set_list_get_basic_set(domain_list, 0),
                        isl_ast_build_copy(build));
 
-       data.depth = isl_ast_build_get_depth(build);
-       data.piece = domain_list->p;
-       g = isl_tarjan_graph_init(ctx, n, &domain_follows_at_depth, &data);
-       if (!g)
-               goto error;
-
-       i = 0;
-       while (list && n) {
-               int first;
-
-               if (g->order[i] == -1)
-                       isl_die(ctx, isl_error_internal, "cannot happen",
-                               goto error);
-               first = i;
-               while (g->order[i] != -1) {
-                       ++i; --n;
-               }
-               list = add_nodes(list, g->order + first, i - first,
-                                       domain_list, executed, build);
-               ++i;
-       }
-
-       if (0)
-error:         list = isl_ast_graft_list_free(list);
-       isl_tarjan_graph_free(g);
+       depth = isl_ast_build_get_depth(build);
+       data.executed = executed;
+       data.build = build;
+       if (isl_basic_set_list_foreach_scc(domain_list,
+                                       &domain_follows_at_depth, &depth,
+                                       &add_nodes, &data) < 0)
+               data.list = isl_ast_graft_list_free(data.list);
 
-       return list;
+       return data.list;
 }
 
-struct isl_shared_outer_data {
-       int depth;
-       isl_basic_set **piece;
-};
-
-/* Do elements i and j share any values for the outer dimensions?
+/* Do i and j share any values for the outer dimensions?
  */
-static int shared_outer(int i, int j, void *user)
+static int shared_outer(__isl_keep isl_basic_set *i,
+       __isl_keep isl_basic_set *j, void *user)
 {
-       struct isl_shared_outer_data *data = user;
+       int depth = *(int *) user;
        isl_basic_map *test;
        int empty;
        int l;
 
-       test = isl_basic_map_from_domain_and_range(
-                       isl_basic_set_copy(data->piece[i]),
-                       isl_basic_set_copy(data->piece[j]));
-       for (l = 0; l < data->depth; ++l)
+       test = isl_basic_map_from_domain_and_range(isl_basic_set_copy(i),
+                                                   isl_basic_set_copy(j));
+       for (l = 0; l < depth; ++l)
                test = isl_basic_map_equate(test, isl_dim_in, l,
                                                isl_dim_out, l);
        empty = isl_basic_map_is_empty(test);
@@ -1824,32 +1807,54 @@ static int shared_outer(int i, int j, void *user)
        return empty < 0 ? -1 : !empty;
 }
 
-/* Call generate_sorted_domains on a list containing the elements
- * of "domain_list indexed by the first "n" elements of "pos".
+/* Internal data structure for generate_sorted_domains_wrap.
+ *
+ * "n" is the total number of basic sets
+ * "executed" and "build" are extra arguments to be passed
+ *     to generate_sorted_domains.
+ *
+ * "single" is set to 1 by generate_sorted_domains_wrap if there
+ * is only a single component.
+ * "list" collects the results.
  */
-static __isl_give isl_ast_graft_list *generate_sorted_domains_part(
-       __isl_keep isl_basic_set_list *domain_list, int *pos, int n,
-       __isl_keep isl_union_map *executed,
-       __isl_keep isl_ast_build *build)
-{
-       int i;
-       isl_ctx *ctx;
-       isl_basic_set_list *slice;
+struct isl_ast_generate_parallel_domains_data {
+       int n;
+       isl_union_map *executed;
+       isl_ast_build *build;
+
+       int single;
        isl_ast_graft_list *list;
+};
 
-       ctx = isl_ast_build_get_ctx(build);
-       slice = isl_basic_set_list_alloc(ctx, n);
-       for (i = 0; i < n; ++i) {
-               isl_basic_set *bset;
+/* Call generate_sorted_domains on "scc", fuse the result into a list
+ * with either zero or one graft and collect the these single element
+ * lists into data->list.
+ *
+ * If there is only one component, i.e., if the number of basic sets
+ * in the current component is equal to the total number of basic sets,
+ * then data->single is set to 1 and the result of generate_sorted_domains
+ * is not fused.
+ */
+static int generate_sorted_domains_wrap(__isl_take isl_basic_set_list *scc,
+       void *user)
+{
+       struct isl_ast_generate_parallel_domains_data *data = user;
+       isl_ast_graft_list *list;
 
-               bset = isl_basic_set_copy(domain_list->p[pos[i]]);
-               slice = isl_basic_set_list_add(slice, bset);
-       }
+       list = generate_sorted_domains(scc, data->executed, data->build);
+       data->single = isl_basic_set_list_n_basic_set(scc) == data->n;
+       if (!data->single)
+               list = isl_ast_graft_list_fuse(list, data->build);
+       if (!data->list)
+               data->list = list;
+       else
+               data->list = isl_ast_graft_list_concat(data->list, list);
 
-       list = generate_sorted_domains(slice, executed, build);
-       isl_basic_set_list_free(slice);
+       isl_basic_set_list_free(scc);
+       if (!data->list)
+               return -1;
 
-       return list;
+       return 0;
 }
 
 /* Look for any (weakly connected) components in the "domain_list"
@@ -1860,8 +1865,9 @@ static __isl_give isl_ast_graft_list *generate_sorted_domains_part(
  * Within each of the components, we sort the domains according
  * to the execution order at the current depth.
  *
- * We fuse the result of each call to generate_sorted_domains_part
- * into a list with either zero or one graft and collect these (at most)
+ * If there is more than one component, then generate_sorted_domains_wrap
+ * fuses the result of each call to generate_sorted_domains
+ * into a list with either zero or one graft and collects these (at most)
  * single element lists into a bigger list. This means that the elements of the
  * final list can be freely reordered.  In particular, we sort them
  * according to an arbitrary but fixed ordering to ease merging of
@@ -1871,62 +1877,30 @@ static __isl_give isl_ast_graft_list *generate_parallel_domains(
        __isl_keep isl_basic_set_list *domain_list,
        __isl_keep isl_union_map *executed, __isl_keep isl_ast_build *build)
 {
-       int i, n;
-       isl_ctx *ctx;
-       isl_ast_graft_list *list;
-       struct isl_shared_outer_data data;
-       struct isl_tarjan_graph *g;
+       int depth;
+       struct isl_ast_generate_parallel_domains_data data;
 
        if (!domain_list)
                return NULL;
 
-       n = isl_basic_set_list_n_basic_set(domain_list);
-       if (n <= 1)
+       data.n = isl_basic_set_list_n_basic_set(domain_list);
+       if (data.n <= 1)
                return generate_sorted_domains(domain_list, executed, build);
 
-       ctx = isl_basic_set_list_get_ctx(domain_list);
-
-       data.depth = isl_ast_build_get_depth(build);
-       data.piece = domain_list->p;
-       g = isl_tarjan_graph_init(ctx, n, &shared_outer, &data);
-       if (!g)
-               return NULL;
-
-       i = 0;
-       do {
-               int first;
-               isl_ast_graft_list *list_c;
-
-               if (g->order[i] == -1)
-                       isl_die(ctx, isl_error_internal, "cannot happen",
-                               break);
-               first = i;
-               while (g->order[i] != -1) {
-                       ++i; --n;
-               }
-               if (first == 0 && n == 0) {
-                       isl_tarjan_graph_free(g);
-                       return generate_sorted_domains(domain_list,
-                                                       executed, build);
-               }
-               list_c = generate_sorted_domains_part(domain_list,
-                               g->order + first, i - first, executed, build);
-               list_c = isl_ast_graft_list_fuse(list_c, build);
-               if (first == 0)
-                       list = list_c;
-               else
-                       list = isl_ast_graft_list_concat(list, list_c);
-               ++i;
-       } while (list && n);
-
-       if (n > 0)
-               list = isl_ast_graft_list_free(list);
-
-       list = isl_ast_graft_list_sort_guard(list);
+       depth = isl_ast_build_get_depth(build);
+       data.list = NULL;
+       data.executed = executed;
+       data.build = build;
+       data.single = 0;
+       if (isl_basic_set_list_foreach_scc(domain_list, &shared_outer, &depth,
+                                           &generate_sorted_domains_wrap,
+                                           &data) < 0)
+               data.list = isl_ast_graft_list_free(data.list);
 
-       isl_tarjan_graph_free(g);
+       if (!data.single)
+               data.list = isl_ast_graft_list_sort_guard(data.list);
 
-       return list;
+       return data.list;
 }
 
 /* Internal data for separate_domain.
index 5e861b1..71024ce 100644 (file)
@@ -13,6 +13,7 @@
  */
 
 #include <isl_sort.h>
+#include <isl_tarjan.h>
 
 #define xCAT(A,B) A ## B
 #define CAT(A,B) xCAT(A,B)
@@ -332,6 +333,112 @@ __isl_give LIST(EL) *FN(LIST(EL),sort)(__isl_take LIST(EL) *list,
        return list;
 }
 
+/* Internal data structure for isl_*_list_foreach_scc.
+ *
+ * "list" is the original list.
+ * "follows" is the user provided callback that defines the edges of the graph.
+ */
+S(LIST(EL),foreach_scc_data) {
+       LIST(EL) *list;
+       int (*follows)(__isl_keep EL *a, __isl_keep EL *b, void *user);
+       void *follows_user;
+};
+
+/* Does element i of data->list follow element j?
+ *
+ * Use the user provided callback to find out.
+ */
+static int FN(LIST(EL),follows)(int i, int j, void *user)
+{
+       S(LIST(EL),foreach_scc_data) *data = user;
+
+       return data->follows(data->list->p[i], data->list->p[j],
+                               data->follows_user);
+}
+
+/* Call "fn" on the sublist of "list" that consists of the elements
+ * with indices specified by the "n" elements of "pos".
+ */
+static int FN(LIST(EL),call_on_scc)(__isl_keep LIST(EL) *list, int *pos, int n,
+       int (*fn)(__isl_take LIST(EL) *scc, void *user), void *user)
+{
+       int i;
+       isl_ctx *ctx;
+       LIST(EL) *slice;
+
+       ctx = FN(LIST(EL),get_ctx)(list);
+       slice = FN(LIST(EL),alloc)(ctx, n);
+       for (i = 0; i < n; ++i) {
+               EL *el;
+
+               el = FN(EL,copy)(list->p[pos[i]]);
+               slice = FN(LIST(EL),add)(slice, el);
+       }
+
+       return fn(slice, user);
+}
+
+/* Call "fn" on each of the strongly connected components (SCCs) of
+ * the graph with as vertices the elements of "list" and
+ * a directed edge from node b to node a iff follows(a, b)
+ * returns 1.  follows should return -1 on error.
+ *
+ * If SCC a contains a node i that follows a node j in another SCC b
+ * (i.e., follows(i, j, user) returns 1), then fn will be called on SCC a
+ * after being called on SCC b.
+ *
+ * We simply call isl_tarjan_graph_init, extract the SCCs from the result and
+ * call fn on each of them.
+ */
+int FN(LIST(EL),foreach_scc)(__isl_keep LIST(EL) *list,
+       int (*follows)(__isl_keep EL *a, __isl_keep EL *b, void *user),
+       void *follows_user,
+       int (*fn)(__isl_take LIST(EL) *scc, void *user), void *fn_user)
+{
+       S(LIST(EL),foreach_scc_data) data = { list, follows, follows_user };
+       int i, n;
+       isl_ctx *ctx;
+       struct isl_tarjan_graph *g;
+
+       if (!list)
+               return -1;
+       if (list->n == 0)
+               return 0;
+       if (list->n == 1)
+               return fn(FN(LIST(EL),copy)(list), fn_user);
+
+       ctx = FN(LIST(EL),get_ctx)(list);
+       n = list->n;
+       g = isl_tarjan_graph_init(ctx, n, &FN(LIST(EL),follows), &data);
+       if (!g)
+               return -1;
+
+       i = 0;
+       do {
+               int first;
+
+               if (g->order[i] == -1)
+                       isl_die(ctx, isl_error_internal, "cannot happen",
+                               break);
+               first = i;
+               while (g->order[i] != -1) {
+                       ++i; --n;
+               }
+               if (first == 0 && n == 0) {
+                       isl_tarjan_graph_free(g);
+                       return fn(FN(LIST(EL),copy)(list), fn_user);
+               }
+               if (FN(LIST(EL),call_on_scc)(list, g->order + first, i - first,
+                                           fn, fn_user) < 0)
+                       break;
+               ++i;
+       } while (n);
+
+       isl_tarjan_graph_free(g);
+
+       return n > 0 ? -1 : 0;
+}
+
 __isl_give LIST(EL) *FN(FN(LIST(EL),from),BASE)(__isl_take EL *el)
 {
        isl_ctx *ctx;