gpu: add nodes for initializing and clearing the device to the schedule tree
[ppcg.git] / gpu_tree.c
blob2f9c118331ce82822ba53526ec525e92d568fbd4
1 /*
2 * Copyright 2013 Ecole Normale Superieure
4 * Use of this software is governed by the MIT license
6 * Written by Sven Verdoolaege,
7 * Ecole Normale Superieure, 45 rue d'Ulm, 75230 Paris, France
8 */
10 #include <string.h>
12 #include <isl/set.h>
13 #include <isl/union_set.h>
15 #include "gpu_tree.h"
17 /* The functions in this file are used to navigate part of a schedule tree
18 * that is mapped to blocks. Initially, this part consists of a linear
19 * branch segment with a mark node with name "kernel" on the outer end
20 * and a mark node with name "thread" on the inner end.
21 * During the mapping to blocks, branching may be introduced, but only
22 * one of the elements in each sequence contains the "thread" mark.
23 * The filter of this element (and only this filter) contains
24 * domain elements identified by the "core" argument of the functions
25 * that move down this tree.
27 * Synchronization statements have a name that starts with "sync" and
28 * a user pointer pointing to the kernel that contains the synchronization.
29 * The functions inserting or detecting synchronizations take a ppcg_kernel
30 * argument to be able to create or identify such statements.
31 * They may also use two fields in this structure, the "core" field
32 * to move around in the tree and the "n_sync" field to make sure that
33 * each synchronization has a different name (within the kernel).
36 /* Is "node" a mark node with an identifier called "name"?
38 static int is_marked(__isl_keep isl_schedule_node *node, const char *name)
40 isl_id *mark;
41 int has_name;
43 if (!node)
44 return -1;
46 if (isl_schedule_node_get_type(node) != isl_schedule_node_mark)
47 return 0;
49 mark = isl_schedule_node_mark_get_id(node);
50 if (!mark)
51 return -1;
53 has_name = !strcmp(isl_id_get_name(mark), name);
54 isl_id_free(mark);
56 return has_name;
59 /* Is "node" a mark node with an identifier called "kernel"?
61 int gpu_tree_node_is_kernel(__isl_keep isl_schedule_node *node)
63 return is_marked(node, "kernel");
66 /* Is "node" a mark node with an identifier called "thread"?
68 static int node_is_thread(__isl_keep isl_schedule_node *node)
70 return is_marked(node, "thread");
73 /* Assuming "node" is a filter node, does it correspond to the branch
74 * that contains the "thread" mark, i.e., does it contain any elements
75 * in "core"?
77 static int node_is_core(__isl_keep isl_schedule_node *node,
78 __isl_keep isl_union_set *core)
80 int disjoint;
81 isl_union_set *filter;
83 filter = isl_schedule_node_filter_get_filter(node);
84 disjoint = isl_union_set_is_disjoint(filter, core);
85 isl_union_set_free(filter);
86 if (disjoint < 0)
87 return -1;
89 return !disjoint;
92 /* Move to the only child of "node" that has the "thread" mark as descendant,
93 * where the branch containing this mark is identified by the domain elements
94 * in "core".
96 * If "node" is not a sequence, then it only has one child and we move
97 * to that single child.
98 * Otherwise, we check each of the filters in the children, pick
99 * the one that corresponds to "core" and return a pointer to the child
100 * of the filter node.
102 static __isl_give isl_schedule_node *core_child(
103 __isl_take isl_schedule_node *node, __isl_keep isl_union_set *core)
105 int i, n;
107 if (isl_schedule_node_get_type(node) != isl_schedule_node_sequence)
108 return isl_schedule_node_child(node, 0);
110 n = isl_schedule_node_n_children(node);
111 for (i = 0; i < n; ++i) {
112 int is_core;
114 node = isl_schedule_node_child(node, i);
115 is_core = node_is_core(node, core);
117 if (is_core < 0)
118 return isl_schedule_node_free(node);
119 if (is_core)
120 return isl_schedule_node_child(node, 0);
122 node = isl_schedule_node_parent(node);
125 isl_die(isl_schedule_node_get_ctx(node), isl_error_internal,
126 "core child not found", return isl_schedule_node_free(node));
129 /* Move down the branch between "kernel" and "thread" until
130 * the "thread" mark is reached, where the branch containing the "thread"
131 * mark is identified by the domain elements in "core".
133 __isl_give isl_schedule_node *gpu_tree_move_down_to_thread(
134 __isl_take isl_schedule_node *node, __isl_keep isl_union_set *core)
136 int is_thread;
138 while ((is_thread = node_is_thread(node)) == 0)
139 node = core_child(node, core);
140 if (is_thread < 0)
141 node = isl_schedule_node_free(node);
143 return node;
146 /* Move up the tree underneath the "thread" mark until
147 * the "thread" mark is reached.
149 __isl_give isl_schedule_node *gpu_tree_move_up_to_thread(
150 __isl_take isl_schedule_node *node)
152 int is_thread;
154 while ((is_thread = node_is_thread(node)) == 0)
155 node = isl_schedule_node_parent(node);
156 if (is_thread < 0)
157 node = isl_schedule_node_free(node);
159 return node;
162 /* Move up the tree underneath the "kernel" mark until
163 * the "kernel" mark is reached.
165 __isl_give isl_schedule_node *gpu_tree_move_up_to_kernel(
166 __isl_take isl_schedule_node *node)
168 int is_kernel;
170 while ((is_kernel = gpu_tree_node_is_kernel(node)) == 0)
171 node = isl_schedule_node_parent(node);
172 if (is_kernel < 0)
173 node = isl_schedule_node_free(node);
175 return node;
178 /* Move down from the "kernel" mark (or at least a node with schedule
179 * depth smaller than or equal to "depth") to a band node at schedule
180 * depth "depth". The "thread" mark is assumed to have a schedule
181 * depth greater than or equal to "depth". The branch containing the
182 * "thread" mark is identified by the domain elements in "core".
184 * If the desired schedule depth is in the middle of band node,
185 * then the band node is split into two pieces, the second piece
186 * at the desired schedule depth.
188 __isl_give isl_schedule_node *gpu_tree_move_down_to_depth(
189 __isl_take isl_schedule_node *node, int depth,
190 __isl_keep isl_union_set *core)
192 int is_thread;
194 while (node && isl_schedule_node_get_schedule_depth(node) < depth) {
195 if (isl_schedule_node_get_type(node) ==
196 isl_schedule_node_band) {
197 int node_depth, node_dim;
198 node_depth = isl_schedule_node_get_schedule_depth(node);
199 node_dim = isl_schedule_node_band_n_member(node);
200 if (node_depth + node_dim > depth)
201 node = isl_schedule_node_band_split(node,
202 depth - node_depth);
204 node = core_child(node, core);
206 while ((is_thread = node_is_thread(node)) == 0 &&
207 isl_schedule_node_get_type(node) != isl_schedule_node_band)
208 node = core_child(node, core);
209 if (is_thread < 0)
210 node = isl_schedule_node_free(node);
212 return node;
215 /* Create a union set containing a single set with a tuple identifier
216 * called "syncX" and user pointer equal to "kernel".
218 static __isl_give isl_union_set *create_sync_domain(struct ppcg_kernel *kernel)
220 isl_space *space;
221 isl_id *id;
222 char name[40];
224 space = isl_space_set_alloc(kernel->ctx, 0, 0);
225 snprintf(name, sizeof(name), "sync%d", kernel->n_sync++);
226 id = isl_id_alloc(kernel->ctx, name, kernel);
227 space = isl_space_set_tuple_id(space, isl_dim_set, id);
228 return isl_union_set_from_set(isl_set_universe(space));
231 /* Is "id" the identifier of a synchronization statement inside "kernel"?
232 * That is, does its name start with "sync" and does it point to "kernel"?
234 int gpu_tree_id_is_sync(__isl_keep isl_id *id, struct ppcg_kernel *kernel)
236 const char *name;
238 name = isl_id_get_name(id);
239 if (!name)
240 return 0;
241 else if (strncmp(name, "sync", 4))
242 return 0;
243 return isl_id_get_user(id) == kernel;
246 /* Does "domain" consist of a single set with a tuple identifier
247 * corresponding to a synchronization for "kernel"?
249 static int domain_is_sync(__isl_keep isl_union_set *domain,
250 struct ppcg_kernel *kernel)
252 int is_sync;
253 isl_id *id;
254 isl_set *set;
256 if (isl_union_set_n_set(domain) != 1)
257 return 0;
258 set = isl_set_from_union_set(isl_union_set_copy(domain));
259 id = isl_set_get_tuple_id(set);
260 is_sync = gpu_tree_id_is_sync(id, kernel);
261 isl_id_free(id);
262 isl_set_free(set);
264 return is_sync;
267 /* Does "node" point to a filter selecting a synchronization statement
268 * for "kernel"?
270 static int node_is_sync_filter(__isl_keep isl_schedule_node *node,
271 struct ppcg_kernel *kernel)
273 int is_sync;
274 enum isl_schedule_node_type type;
275 isl_union_set *domain;
277 if (!node)
278 return -1;
279 type = isl_schedule_node_get_type(node);
280 if (type != isl_schedule_node_filter)
281 return 0;
282 domain = isl_schedule_node_filter_get_filter(node);
283 is_sync = domain_is_sync(domain, kernel);
284 isl_union_set_free(domain);
286 return is_sync;
289 /* Is "node" part of a sequence with a previous synchronization statement
290 * for "kernel"?
291 * That is, is the parent of "node" a filter such that there is
292 * a previous filter that picks out exactly such a synchronization statement?
294 static int has_preceding_sync(__isl_keep isl_schedule_node *node,
295 struct ppcg_kernel *kernel)
297 int found = 0;
299 node = isl_schedule_node_copy(node);
300 node = isl_schedule_node_parent(node);
301 while (!found && isl_schedule_node_has_previous_sibling(node)) {
302 node = isl_schedule_node_previous_sibling(node);
303 if (!node)
304 break;
305 found = node_is_sync_filter(node, kernel);
307 if (!node)
308 found = -1;
309 isl_schedule_node_free(node);
311 return found;
314 /* Is "node" part of a sequence with a subsequent synchronization statement
315 * for "kernel"?
316 * That is, is the parent of "node" a filter such that there is
317 * a subsequent filter that picks out exactly such a synchronization statement?
319 static int has_following_sync(__isl_keep isl_schedule_node *node,
320 struct ppcg_kernel *kernel)
322 int found = 0;
324 node = isl_schedule_node_copy(node);
325 node = isl_schedule_node_parent(node);
326 while (!found && isl_schedule_node_has_next_sibling(node)) {
327 node = isl_schedule_node_next_sibling(node);
328 if (!node)
329 break;
330 found = node_is_sync_filter(node, kernel);
332 if (!node)
333 found = -1;
334 isl_schedule_node_free(node);
336 return found;
339 /* Does the subtree rooted at "node" (which is a band node) contain
340 * any synchronization statement for "kernel" that precedes
341 * the core computation of "kernel" (identified by the elements
342 * in kernel->core)?
344 static int has_sync_before_core(__isl_keep isl_schedule_node *node,
345 struct ppcg_kernel *kernel)
347 int has_sync = 0;
348 int is_thread;
350 node = isl_schedule_node_copy(node);
351 while ((is_thread = node_is_thread(node)) == 0) {
352 node = core_child(node, kernel->core);
353 has_sync = has_preceding_sync(node, kernel);
354 if (has_sync < 0 || has_sync)
355 break;
357 if (is_thread < 0 || !node)
358 has_sync = -1;
359 isl_schedule_node_free(node);
361 return has_sync;
364 /* Does the subtree rooted at "node" (which is a band node) contain
365 * any synchronization statement for "kernel" that follows
366 * the core computation of "kernel" (identified by the elements
367 * in kernel->core)?
369 static int has_sync_after_core(__isl_keep isl_schedule_node *node,
370 struct ppcg_kernel *kernel)
372 int has_sync = 0;
373 int is_thread;
375 node = isl_schedule_node_copy(node);
376 while ((is_thread = node_is_thread(node)) == 0) {
377 node = core_child(node, kernel->core);
378 has_sync = has_following_sync(node, kernel);
379 if (has_sync < 0 || has_sync)
380 break;
382 if (is_thread < 0 || !node)
383 has_sync = -1;
384 isl_schedule_node_free(node);
386 return has_sync;
389 /* Insert (or extend) an extension on top of "node" that puts
390 * a synchronization node for "kernel" before "node".
391 * Return a pointer to the original node in the updated schedule tree.
393 static __isl_give isl_schedule_node *insert_sync_before(
394 __isl_take isl_schedule_node *node, struct ppcg_kernel *kernel)
396 isl_union_set *domain;
397 isl_schedule_node *graft;
399 if (!node)
400 return NULL;
402 domain = create_sync_domain(kernel);
403 graft = isl_schedule_node_from_domain(domain);
404 node = isl_schedule_node_graft_before(node, graft);
406 return node;
409 /* Insert (or extend) an extension on top of "node" that puts
410 * a synchronization node for "kernel" afater "node".
411 * Return a pointer to the original node in the updated schedule tree.
413 static __isl_give isl_schedule_node *insert_sync_after(
414 __isl_take isl_schedule_node *node, struct ppcg_kernel *kernel)
416 isl_union_set *domain;
417 isl_schedule_node *graft;
419 if (!node)
420 return NULL;
422 domain = create_sync_domain(kernel);
423 graft = isl_schedule_node_from_domain(domain);
424 node = isl_schedule_node_graft_after(node, graft);
426 return node;
429 /* Insert an extension on top of "node" that puts a synchronization node
430 * for "kernel" before "node" unless there already is
431 * such a synchronization node.
433 __isl_give isl_schedule_node *gpu_tree_ensure_preceding_sync(
434 __isl_take isl_schedule_node *node, struct ppcg_kernel *kernel)
436 int has_sync;
438 has_sync = has_preceding_sync(node, kernel);
439 if (has_sync < 0)
440 return isl_schedule_node_free(node);
441 if (has_sync)
442 return node;
443 return insert_sync_before(node, kernel);
446 /* Insert an extension on top of "node" that puts a synchronization node
447 * for "kernel" after "node" unless there already is
448 * such a synchronization node.
450 __isl_give isl_schedule_node *gpu_tree_ensure_following_sync(
451 __isl_take isl_schedule_node *node, struct ppcg_kernel *kernel)
453 int has_sync;
455 has_sync = has_following_sync(node, kernel);
456 if (has_sync < 0)
457 return isl_schedule_node_free(node);
458 if (has_sync)
459 return node;
460 return insert_sync_after(node, kernel);
463 /* Insert an extension on top of "node" that puts a synchronization node
464 * for "kernel" after "node" unless there already is such a sync node or
465 * "node" itself already * contains a synchronization node following
466 * the core computation of "kernel".
468 __isl_give isl_schedule_node *gpu_tree_ensure_sync_after_core(
469 __isl_take isl_schedule_node *node, struct ppcg_kernel *kernel)
471 int has_sync;
473 has_sync = has_sync_after_core(node, kernel);
474 if (has_sync < 0)
475 return isl_schedule_node_free(node);
476 if (has_sync)
477 return node;
478 has_sync = has_following_sync(node, kernel);
479 if (has_sync < 0)
480 return isl_schedule_node_free(node);
481 if (has_sync)
482 return node;
483 return insert_sync_after(node, kernel);
486 /* Move left in the sequence on top of "node" to a synchronization node
487 * for "kernel".
488 * If "node" itself contains a synchronization node preceding
489 * the core computation of "kernel", then return "node" itself.
490 * Otherwise, if "node" does not have a preceding synchronization node,
491 * then create one first.
493 __isl_give isl_schedule_node *gpu_tree_move_left_to_sync(
494 __isl_take isl_schedule_node *node, struct ppcg_kernel *kernel)
496 int has_sync;
497 int is_sync;
499 has_sync = has_sync_before_core(node, kernel);
500 if (has_sync < 0)
501 return isl_schedule_node_free(node);
502 if (has_sync)
503 return node;
504 node = gpu_tree_ensure_preceding_sync(node, kernel);
505 node = isl_schedule_node_parent(node);
506 while ((is_sync = node_is_sync_filter(node, kernel)) == 0)
507 node = isl_schedule_node_previous_sibling(node);
508 if (is_sync < 0)
509 node = isl_schedule_node_free(node);
510 node = isl_schedule_node_child(node, 0);
512 return node;
515 /* Move right in the sequence on top of "node" to a synchronization node
516 * for "kernel".
517 * If "node" itself contains a synchronization node following
518 * the core computation of "kernel", then return "node" itself.
519 * Otherwise, if "node" does not have a following synchronization node,
520 * then create one first.
522 __isl_give isl_schedule_node *gpu_tree_move_right_to_sync(
523 __isl_take isl_schedule_node *node, struct ppcg_kernel *kernel)
525 int has_sync;
526 int is_sync;
528 has_sync = has_sync_after_core(node, kernel);
529 if (has_sync < 0)
530 return isl_schedule_node_free(node);
531 if (has_sync)
532 return node;
533 node = gpu_tree_ensure_following_sync(node, kernel);
534 node = isl_schedule_node_parent(node);
535 while ((is_sync = node_is_sync_filter(node, kernel)) == 0)
536 node = isl_schedule_node_next_sibling(node);
537 if (is_sync < 0)
538 node = isl_schedule_node_free(node);
539 node = isl_schedule_node_child(node, 0);
541 return node;