gpu.c: fix typo in comment
[ppcg.git] / grouping.c
blobc2628d3c8d09def01de61a6a69d3905f1f595707
1 /*
2 * Copyright 2016 Sven Verdoolaege
4 * Use of this software is governed by the MIT license
6 * Written by Sven Verdoolaege.
7 */
9 #include <isl/ctx.h>
10 #include <isl/id.h>
11 #include <isl/val.h>
12 #include <isl/space.h>
13 #include <isl/aff.h>
14 #include <isl/set.h>
15 #include <isl/map.h>
16 #include <isl/union_set.h>
17 #include <isl/union_map.h>
18 #include <isl/schedule.h>
19 #include <isl/schedule_node.h>
21 #include "ppcg.h"
23 /* Internal data structure for use during the detection of statements
24 * that can be grouped.
26 * "sc" contains the original schedule constraints (not a copy).
27 * "dep" contains the intersection of the validity and the proximity
28 * constraints in "sc". It may be NULL if it has not been computed yet.
29 * "group_id" is the identifier for the next group that is extracted.
31 * "domain" is the set of statement instances that belong to any of the groups.
32 * "contraction" maps the elements of "domain" to the corresponding group
33 * instances.
34 * "schedule" schedules the statements in each group relatively to each other.
35 * These last three fields are NULL if no groups have been found so far.
37 struct ppcg_grouping {
38 isl_schedule_constraints *sc;
40 isl_union_map *dep;
41 int group_id;
43 isl_union_set *domain;
44 isl_union_pw_multi_aff *contraction;
45 isl_schedule *schedule;
48 /* Clear all memory allocated by "grouping".
50 static void ppcg_grouping_clear(struct ppcg_grouping *grouping)
52 isl_union_map_free(grouping->dep);
53 isl_union_set_free(grouping->domain);
54 isl_union_pw_multi_aff_free(grouping->contraction);
55 isl_schedule_free(grouping->schedule);
58 /* Compute the intersection of the proximity and validity dependences
59 * in grouping->sc and store the result in grouping->dep, unless
60 * this intersection has been computed before.
62 static isl_stat ppcg_grouping_compute_dep(struct ppcg_grouping *grouping)
64 isl_union_map *validity, *proximity;
66 if (grouping->dep)
67 return isl_stat_ok;
69 validity = isl_schedule_constraints_get_validity(grouping->sc);
70 proximity = isl_schedule_constraints_get_proximity(grouping->sc);
71 grouping->dep = isl_union_map_intersect(validity, proximity);
73 if (!grouping->dep)
74 return isl_stat_error;
76 return isl_stat_ok;
79 /* Information extracted from one or more consecutive leaves
80 * in the input schedule.
82 * "list" contains the sets of statement instances in the leaves,
83 * one element in the list for each original leaf.
84 * "domain" contains the union of the sets in "list".
85 * "prefix" contains the prefix schedule of these elements.
87 struct ppcg_grouping_leaf {
88 isl_union_set *domain;
89 isl_union_set_list *list;
90 isl_multi_union_pw_aff *prefix;
93 /* Free all memory allocated for "leaves".
95 static void ppcg_grouping_leaf_free(int n, struct ppcg_grouping_leaf leaves[n])
97 int i;
99 if (!leaves)
100 return;
102 for (i = 0; i < n; ++i) {
103 isl_union_set_free(leaves[i].domain);
104 isl_union_set_list_free(leaves[i].list);
105 isl_multi_union_pw_aff_free(leaves[i].prefix);
108 free(leaves);
111 /* Short-hand for retrieving the prefix schedule at "node"
112 * in the form of an isl_multi_union_pw_aff.
114 static __isl_give isl_multi_union_pw_aff *get_prefix(
115 __isl_keep isl_schedule_node *node)
117 return isl_schedule_node_get_prefix_schedule_multi_union_pw_aff(node);
120 /* Return an array of "n" elements with information extracted from
121 * the "n" children of "node" starting at "first", all of which
122 * are known to be filtered leaves.
124 struct ppcg_grouping_leaf *extract_leaves(__isl_keep isl_schedule_node *node,
125 int first, int n)
127 int i;
128 isl_ctx *ctx;
129 struct ppcg_grouping_leaf *leaves;
131 if (!node)
132 return NULL;
134 ctx = isl_schedule_node_get_ctx(node);
135 leaves = isl_calloc_array(ctx, struct ppcg_grouping_leaf, n);
136 if (!leaves)
137 return NULL;
139 for (i = 0; i < n; ++i) {
140 isl_schedule_node *child;
141 isl_union_set *domain;
143 child = isl_schedule_node_get_child(node, first + i);
144 child = isl_schedule_node_child(child, 0);
145 domain = isl_schedule_node_get_domain(child);
146 leaves[i].domain = isl_union_set_copy(domain);
147 leaves[i].list = isl_union_set_list_from_union_set(domain);
148 leaves[i].prefix = get_prefix(child);
149 isl_schedule_node_free(child);
152 return leaves;
/* Data shared between merge_leaves and its check_merge callback.
 *
 * "src" and "dst" point to the two consecutive leaves currently considered
 * for merging, with "src" being the earlier of the two.
 * "merge" starts out as 0 and is set to 1 as soon as merging
 * the two leaves turns out to be worthwhile.
 */
struct ppcg_merge_leaves_data {
	int merge;
	struct ppcg_grouping_leaf *src;
	struct ppcg_grouping_leaf *dst;
};
168 /* Given a relation "map" between instances of two statements A and B,
169 * does it relate every instance of A (according to the domain of "src")
170 * to every instance of B (according to the domain of "dst")?
172 static isl_bool covers_src_and_dst(__isl_keep isl_map *map,
173 struct ppcg_grouping_leaf *src, struct ppcg_grouping_leaf *dst)
175 isl_space *space;
176 isl_set *set1, *set2;
177 isl_bool is_subset;
179 space = isl_space_domain(isl_map_get_space(map));
180 set1 = isl_union_set_extract_set(src->domain, space);
181 set2 = isl_map_domain(isl_map_copy(map));
182 is_subset = isl_set_is_subset(set1, set2);
183 isl_set_free(set1);
184 isl_set_free(set2);
185 if (is_subset < 0 || !is_subset)
186 return is_subset;
188 space = isl_space_range(isl_map_get_space(map));
189 set1 = isl_union_set_extract_set(dst->domain, space);
190 set2 = isl_map_range(isl_map_copy(map));
191 is_subset = isl_set_is_subset(set1, set2);
192 isl_set_free(set1);
193 isl_set_free(set2);
195 return is_subset;
198 /* Given a relation "map" between instances of two statements A and B,
199 * are pairs of related instances executed together in the input schedule?
200 * That is, is each pair of instances assigned the same value
201 * by the corresponding prefix schedules?
203 * In particular, select the subset of "map" that has pairs of elements
204 * with the same value for the prefix schedules and then check
205 * if "map" is still a subset of the result.
207 static isl_bool matches_prefix(__isl_keep isl_map *map,
208 struct ppcg_grouping_leaf *src, struct ppcg_grouping_leaf *dst)
210 isl_union_map *umap, *equal;
211 isl_multi_union_pw_aff *src_prefix, *dst_prefix, *prefix;
212 isl_bool is_subset;
214 src_prefix = isl_multi_union_pw_aff_copy(src->prefix);
215 dst_prefix = isl_multi_union_pw_aff_copy(dst->prefix);
216 prefix = isl_multi_union_pw_aff_union_add(src_prefix, dst_prefix);
218 umap = isl_union_map_from_map(isl_map_copy(map));
219 equal = isl_union_map_copy(umap);
220 equal = isl_union_map_eq_at_multi_union_pw_aff(equal, prefix);
222 is_subset = isl_union_map_is_subset(umap, equal);
224 isl_union_map_free(umap);
225 isl_union_map_free(equal);
227 return is_subset;
230 /* Given a set of validity and proximity schedule constraints "map"
231 * between statements in consecutive leaves in a valid schedule,
232 * should the two leaves be merged into one?
234 * In particular, the two are merged if the constraints form
235 * a bijection between every instance of the first statement and
236 * every instance of the second statement. Moreover, each
237 * pair of such dependent instances needs to be executed consecutively
238 * in the input schedule. That is, they need to be assigned
239 * the same value by their prefix schedules.
241 * What this means is that for each instance of the first statement
242 * there is exactly one instance of the second statement that
243 * is executed immediately after the instance of the first statement and
244 * that, moreover, both depends on this statement instance and
245 * should be brought as close as possible to this statement instance.
246 * In other words, it is both possible to execute the two instances
247 * together (according to the input schedule) and desirable to do so
248 * (according to the validity and proximity schedule constraints).
250 static isl_stat check_merge(__isl_take isl_map *map, void *user)
252 struct ppcg_merge_leaves_data *data = user;
253 isl_bool ok;
255 ok = covers_src_and_dst(map, data->src, data->dst);
256 if (ok >= 0 && ok)
257 ok = isl_map_is_bijective(map);
258 if (ok >= 0 && ok)
259 ok = matches_prefix(map, data->src, data->dst);
261 isl_map_free(map);
263 if (ok < 0)
264 return isl_stat_error;
265 if (!ok)
266 return isl_stat_ok;
268 data->merge = 1;
269 return isl_stat_error;
272 /* Merge the leaves at position "pos" and "pos + 1" in "leaves".
274 static isl_stat merge_pair(int n, struct ppcg_grouping_leaf leaves[n], int pos)
276 int i;
278 leaves[pos].domain = isl_union_set_union(leaves[pos].domain,
279 leaves[pos + 1].domain);
280 leaves[pos].list = isl_union_set_list_concat(leaves[pos].list,
281 leaves[pos + 1].list);
282 leaves[pos].prefix = isl_multi_union_pw_aff_union_add(
283 leaves[pos].prefix, leaves[pos + 1].prefix);
284 for (i = pos + 1; i + 1 < n; ++i)
285 leaves[i] = leaves[i + 1];
286 leaves[n - 1].domain = NULL;
287 leaves[n - 1].list = NULL;
288 leaves[n - 1].prefix = NULL;
290 if (!leaves[pos].domain || !leaves[pos].list || !leaves[pos].prefix)
291 return isl_stat_error;
293 return isl_stat_ok;
296 /* Merge pairs of consecutive leaves in "leaves" taking into account
297 * the intersection of validity and proximity schedule constraints "dep".
299 * If a leaf has been merged with the next leaf, then the combination
300 * is checked again for merging with the next leaf.
301 * That is, if the leaves are A, B and C, then B may not have been
302 * merged with C, but after merging A and B, it could still be useful
303 * to merge the combination AB with C.
305 * Two leaves A and B are merged if there are instances of at least
306 * one pair of statements, one statement in A and one B, such that
307 * the validity and proximity schedule constraints between them
308 * make them suitable for merging according to check_merge.
310 * Return the final number of leaves in the sequence, or -1 on error.
312 static int merge_leaves(int n, struct ppcg_grouping_leaf leaves[n],
313 __isl_keep isl_union_map *dep)
315 int i;
316 struct ppcg_merge_leaves_data data;
318 for (i = n - 1; i >= 0; --i) {
319 isl_union_map *dep_i;
320 isl_stat ok;
322 if (i + 1 >= n)
323 continue;
325 dep_i = isl_union_map_copy(dep);
326 dep_i = isl_union_map_intersect_domain(dep_i,
327 isl_union_set_copy(leaves[i].domain));
328 dep_i = isl_union_map_intersect_range(dep_i,
329 isl_union_set_copy(leaves[i + 1].domain));
330 data.merge = 0;
331 data.src = &leaves[i];
332 data.dst = &leaves[i + 1];
333 ok = isl_union_map_foreach_map(dep_i, &check_merge, &data);
334 isl_union_map_free(dep_i);
335 if (ok < 0 && !data.merge)
336 return -1;
337 if (!data.merge)
338 continue;
339 if (merge_pair(n, leaves, i) < 0)
340 return -1;
341 --n;
342 ++i;
345 return n;
348 /* Construct a schedule with "domain" as domain, that executes
349 * the elements of "list" in order (as a sequence).
351 static __isl_give isl_schedule *schedule_from_domain_and_list(
352 __isl_keep isl_union_set *domain, __isl_keep isl_union_set_list *list)
354 isl_schedule *schedule;
355 isl_schedule_node *node;
357 schedule = isl_schedule_from_domain(isl_union_set_copy(domain));
358 node = isl_schedule_get_root(schedule);
359 isl_schedule_free(schedule);
360 node = isl_schedule_node_child(node, 0);
361 list = isl_union_set_list_copy(list);
362 node = isl_schedule_node_insert_sequence(node, list);
363 schedule = isl_schedule_node_get_schedule(node);
364 isl_schedule_node_free(node);
366 return schedule;
369 /* Construct a unique identifier for a group in "grouping".
371 * The name is of the form G_n, with n the first value starting at
372 * grouping->group_id that does not result in an identifier
373 * that is already in use in the domain of the original schedule
374 * constraints.
376 static isl_id *construct_group_id(struct ppcg_grouping *grouping,
377 __isl_take isl_space *space)
379 isl_ctx *ctx;
380 isl_id *id;
381 isl_bool empty;
382 isl_union_set *domain;
384 if (!space)
385 return NULL;
387 ctx = isl_space_get_ctx(space);
388 domain = isl_schedule_constraints_get_domain(grouping->sc);
390 do {
391 char buffer[20];
392 isl_id *id;
393 isl_set *set;
395 snprintf(buffer, sizeof(buffer), "G_%d", grouping->group_id);
396 grouping->group_id++;
397 id = isl_id_alloc(ctx, buffer, NULL);
398 space = isl_space_set_tuple_id(space, isl_dim_set, id);
399 set = isl_union_set_extract_set(domain, isl_space_copy(space));
400 empty = isl_set_plain_is_empty(set);
401 isl_set_free(set);
402 } while (empty >= 0 && !empty);
404 if (empty < 0)
405 space = isl_space_free(space);
407 id = isl_space_get_tuple_id(space, isl_dim_set);
409 isl_space_free(space);
410 isl_union_set_free(domain);
412 return id;
415 /* Construct a contraction from "prefix" and "domain" for a new group
416 * in "grouping".
418 * The values of the prefix schedule "prefix" are used as instances
419 * of the new group. The identifier of the group is constructed
420 * in such a way that it does not conflict with those of earlier
421 * groups nor with statements in the domain of the original
422 * schedule constraints.
423 * The isl_multi_union_pw_aff "prefix" then simply needs to be
424 * converted to an isl_union_pw_multi_aff. However, this is not
425 * possible if "prefix" is zero-dimensional, so in this case,
426 * a contraction is constructed from "domain" instead.
428 static isl_union_pw_multi_aff *group_contraction_from_prefix_and_domain(
429 struct ppcg_grouping *grouping,
430 __isl_keep isl_multi_union_pw_aff *prefix,
431 __isl_keep isl_union_set *domain)
433 isl_id *id;
434 isl_space *space;
435 int dim;
437 space = isl_multi_union_pw_aff_get_space(prefix);
438 if (!space)
439 return NULL;
440 dim = isl_space_dim(space, isl_dim_set);
441 id = construct_group_id(grouping, space);
442 if (dim == 0) {
443 isl_multi_val *mv;
445 space = isl_multi_union_pw_aff_get_space(prefix);
446 space = isl_space_set_tuple_id(space, isl_dim_set, id);
447 mv = isl_multi_val_zero(space);
448 domain = isl_union_set_copy(domain);
449 return isl_union_pw_multi_aff_multi_val_on_domain(domain, mv);
451 prefix = isl_multi_union_pw_aff_copy(prefix);
452 prefix = isl_multi_union_pw_aff_set_tuple_id(prefix, isl_dim_out, id);
453 return isl_union_pw_multi_aff_from_multi_union_pw_aff(prefix);
456 /* Extend "grouping" with groups corresponding to merged
457 * leaves in the list of potentially merged leaves "leaves".
459 * The "list" field of each element in "leaves" contains a list
460 * of the instances sets of the original leaves that have been
461 * merged into this element. If at least two of the original leaves
462 * have been merged into a given element, then add the corresponding
463 * group to "grouping".
464 * In particular, the domain is extended with the statement instances
465 * of the merged leaves, the contraction is extended with a mapping
466 * of these statement instances to instances of a new group and
467 * the schedule is extended with a schedule that executes
468 * the statement instances according to the order of the leaves
469 * in which they appear.
470 * Since the instances of the groups should already be scheduled apart
471 * in the schedule into which this schedule will be plugged in,
472 * the schedules of the individual groups are combined independently
473 * of each other (as a set).
475 static isl_stat add_groups(struct ppcg_grouping *grouping,
476 int n, struct ppcg_grouping_leaf leaves[n])
478 int i;
480 for (i = 0; i < n; ++i) {
481 int n_leaf;
482 isl_schedule *schedule;
483 isl_union_set *domain;
484 isl_union_pw_multi_aff *upma;
486 n_leaf = isl_union_set_list_n_union_set(leaves[i].list);
487 if (n_leaf < 0)
488 return isl_stat_error;
489 if (n_leaf <= 1)
490 continue;
491 schedule = schedule_from_domain_and_list(leaves[i].domain,
492 leaves[i].list);
493 upma = group_contraction_from_prefix_and_domain(grouping,
494 leaves[i].prefix, leaves[i].domain);
496 domain = isl_union_set_copy(leaves[i].domain);
497 if (grouping->domain) {
498 domain = isl_union_set_union(domain, grouping->domain);
499 upma = isl_union_pw_multi_aff_union_add(upma,
500 grouping->contraction);
501 schedule = isl_schedule_set(schedule,
502 grouping->schedule);
504 grouping->domain = domain;
505 grouping->contraction = upma;
506 grouping->schedule = schedule;
508 if (!grouping->domain || !grouping->contraction ||
509 !grouping->schedule)
510 return isl_stat_error;
513 return isl_stat_ok;
516 /* Look for any pairs of consecutive leaves among the "n" children of "node"
517 * starting at "first" that should be merged together.
518 * Store the results in "grouping".
520 * First make sure the intersection of validity and proximity
521 * schedule constraints is available and extract the required
522 * information from the "n" leaves.
523 * Then try and merge consecutive leaves based on the validity
524 * and proximity constraints.
525 * If any pairs were successfully merged, then add groups
526 * corresponding to the merged leaves to "grouping".
528 static isl_stat group_subsequence(__isl_keep isl_schedule_node *node,
529 int first, int n, struct ppcg_grouping *grouping)
531 int n_merge;
532 struct ppcg_grouping_leaf *leaves;
534 if (ppcg_grouping_compute_dep(grouping) < 0)
535 return isl_stat_error;
537 leaves = extract_leaves(node, first, n);
538 if (!leaves)
539 return isl_stat_error;
541 n_merge = merge_leaves(n, leaves, grouping->dep);
542 if (n_merge >= 0 && n_merge < n &&
543 add_groups(grouping, n_merge, leaves) < 0)
544 return isl_stat_error;
546 ppcg_grouping_leaf_free(n, leaves);
548 return isl_stat_ok;
551 /* If "node" is a sequence, then check if it has any consecutive
552 * leaves that should be merged together and store the results
553 * in "grouping".
555 * In particular, call group_subsequence on each consecutive
556 * sequence of (filtered) leaves among the children of "node".
558 static isl_bool detect_groups(__isl_keep isl_schedule_node *node, void *user)
560 int i, n, first;
561 struct ppcg_grouping *grouping = user;
563 if (isl_schedule_node_get_type(node) != isl_schedule_node_sequence)
564 return isl_bool_true;
566 n = isl_schedule_node_n_children(node);
567 if (n < 0)
568 return isl_bool_error;
570 first = -1;
571 for (i = 0; i < n; ++i) {
572 isl_schedule_node *child;
573 enum isl_schedule_node_type type;
575 child = isl_schedule_node_get_child(node, i);
576 child = isl_schedule_node_child(child, 0);
577 type = isl_schedule_node_get_type(child);
578 isl_schedule_node_free(child);
580 if (first >= 0 && type != isl_schedule_node_leaf) {
581 if (group_subsequence(node, first, i - first,
582 grouping) < 0)
583 return isl_bool_error;
584 first = -1;
586 if (first < 0 && type == isl_schedule_node_leaf)
587 first = i;
589 if (first >= 0) {
590 if (group_subsequence(node, first, n - first, grouping) < 0)
591 return isl_bool_error;
594 return isl_bool_true;
597 /* Complete "grouping" to cover all statement instances in the domain
598 * of grouping->sc.
600 * In particular, grouping->domain is set to the full set of statement
601 * instances; group->contraction is extended with an identity
602 * contraction on the additional instances and group->schedule
603 * is extended with an independent schedule on those additional instances.
604 * In the extension of group->contraction, the additional instances
605 * are split into those belong to different statements and those
606 * that belong to some of the same statements. The first group
607 * is replaced by its universe in order to simplify the contraction extension.
609 static void complete_grouping(struct ppcg_grouping *grouping)
611 isl_union_set *domain, *left, *overlap;
612 isl_union_pw_multi_aff *upma;
613 isl_schedule *schedule;
615 domain = isl_schedule_constraints_get_domain(grouping->sc);
616 left = isl_union_set_subtract(isl_union_set_copy(domain),
617 isl_union_set_copy(grouping->domain));
618 schedule = isl_schedule_from_domain(isl_union_set_copy(left));
619 schedule = isl_schedule_set(schedule, grouping->schedule);
620 grouping->schedule = schedule;
622 overlap = isl_union_set_universe(grouping->domain);
623 grouping->domain = domain;
624 overlap = isl_union_set_intersect(isl_union_set_copy(left), overlap);
625 left = isl_union_set_subtract(left, isl_union_set_copy(overlap));
626 left = isl_union_set_universe(left);
627 left = isl_union_set_union(left, overlap);
628 upma = isl_union_set_identity_union_pw_multi_aff(left);
629 upma = isl_union_pw_multi_aff_union_add(upma, grouping->contraction);
630 grouping->contraction = upma;
633 /* Compute a schedule on the domain of "sc" that respects the schedule
634 * constraints in "sc".
636 * "schedule" is a known correct schedule that is used to combine
637 * groups of statements if options->group_chains is set.
638 * In particular, statements that are executed consecutively in a sequence
639 * in this schedule and where all instances of the second depend on
640 * the instance of the first that is executed in the same iteration
641 * of outer band nodes are grouped together into a single statement.
642 * The schedule constraints are then mapped to these groups of statements
643 * and the resulting schedule is expanded again to refer to the original
644 * statements.
646 __isl_give isl_schedule *ppcg_compute_schedule(
647 __isl_take isl_schedule_constraints *sc,
648 __isl_keep isl_schedule *schedule, struct ppcg_options *options)
650 struct ppcg_grouping grouping = { sc };
651 isl_union_pw_multi_aff *contraction;
652 isl_union_map *umap;
653 isl_schedule *res, *expansion;
655 if (!options->group_chains)
656 return isl_schedule_constraints_compute_schedule(sc);
658 grouping.group_id = 0;
659 if (isl_schedule_foreach_schedule_node_top_down(schedule,
660 &detect_groups, &grouping) < 0)
661 goto error;
662 if (!grouping.contraction) {
663 ppcg_grouping_clear(&grouping);
664 return isl_schedule_constraints_compute_schedule(sc);
666 complete_grouping(&grouping);
667 contraction = isl_union_pw_multi_aff_copy(grouping.contraction);
668 umap = isl_union_map_from_union_pw_multi_aff(contraction);
670 sc = isl_schedule_constraints_apply(sc, umap);
672 res = isl_schedule_constraints_compute_schedule(sc);
674 contraction = isl_union_pw_multi_aff_copy(grouping.contraction);
675 expansion = isl_schedule_copy(grouping.schedule);
676 res = isl_schedule_expand(res, contraction, expansion);
678 ppcg_grouping_clear(&grouping);
679 return res;
680 error:
681 ppcg_grouping_clear(&grouping);
682 isl_schedule_constraints_free(sc);
683 return NULL;