gpu.c: compute_sched_to_shared: change interface to work on isl_pw_multi_affs
[ppcg.git] / cpu.c
blobf6675953b53ce2dab2430cfb602cbe9828cf62d3
1 /*
2 * Copyright 2012 INRIA Paris-Rocquencourt
4 * Use of this software is governed by the MIT license
6 * Written by Tobias Grosser, INRIA Paris-Rocquencourt,
7 * Domaine de Voluceau, Rocquenqourt, B.P. 105,
8 * 78153 Le Chesnay Cedex France
9 */
11 #include <limits.h>
12 #include <stdio.h>
13 #include <string.h>
15 #include <isl/aff.h>
16 #include <isl/ctx.h>
17 #include <isl/map.h>
18 #include <isl/ast_build.h>
19 #include <pet.h>
21 #include "ppcg.h"
22 #include "ppcg_options.h"
23 #include "cpu.h"
24 #include "print.h"
25 #include "rewrite.h"
27 /* Representation of a statement inside a generated AST.
29 * "stmt" refers to the original statement.
30 * "ref2expr" maps the reference identifier of each access in
31 * the statement to an AST expression that should be printed
32 * at the place of the access.
34 struct ppcg_stmt {
35 struct pet_stmt *stmt;
37 isl_id_to_ast_expr *ref2expr;
40 static void ppcg_stmt_free(void *user)
42 struct ppcg_stmt *stmt = user;
43 int i;
45 if (!stmt)
46 return;
48 isl_id_to_ast_expr_free(stmt->ref2expr);
50 free(stmt);
53 /* Derive the output file name from the input file name.
54 * 'input' is the entire path of the input file. The output
55 * is the file name plus the additional extension.
57 * We will basically replace everything after the last point
58 * with '.ppcg.c'. This means file.c becomes file.ppcg.c
60 static FILE *get_output_file(const char *input, const char *output)
62 char name[PATH_MAX];
63 const char *ext;
64 const char ppcg_marker[] = ".ppcg";
65 int len;
67 len = ppcg_extract_base_name(name, input);
69 strcpy(name + len, ppcg_marker);
70 ext = strrchr(input, '.');
71 strcpy(name + len + sizeof(ppcg_marker) - 1, ext ? ext : ".c");
73 if (!output)
74 output = name;
76 return fopen(output, "w");
/* Per-node data attached to the for nodes of the AST.
 */
struct ast_node_userinfo {
	/* Nonzero if this for node is to be printed as an OpenMP
	 * parallel for loop.
	 */
	int is_openmp;
};
/* State shared by the callbacks that run while the AST is built.
 */
struct ast_build_userinfo {
	/* The ppcg scop for which the AST is being built. */
	struct ppcg_scop *scop;

	/* Nonzero while the build is inside a for loop that has been
	 * marked OpenMP parallel.
	 */
	int in_parallel_for;
};
96 /* Check if the current scheduling dimension is parallel.
98 * We check for parallelism by verifying that the loop does not carry any
99 * dependences.
101 * Parallelism test: if the distance is zero in all outer dimensions, then it
102 * has to be zero in the current dimension as well.
103 * Implementation: first, translate dependences into time space, then force
104 * outer dimensions to be equal. If the distance is zero in the current
105 * dimension, then the loop is parallel.
106 * The distance is zero in the current dimension if it is a subset of a map
107 * with equal values for the current dimension.
109 static int ast_schedule_dim_is_parallel(__isl_keep isl_ast_build *build,
110 struct ppcg_scop *scop)
112 isl_union_map *schedule_node, *schedule, *deps;
113 isl_map *schedule_deps, *test;
114 isl_space *schedule_space;
115 unsigned i, dimension, is_parallel;
117 schedule = isl_ast_build_get_schedule(build);
118 schedule_space = isl_ast_build_get_schedule_space(build);
120 dimension = isl_space_dim(schedule_space, isl_dim_out) - 1;
122 deps = isl_union_map_copy(scop->dep_flow);
123 deps = isl_union_map_union(deps, isl_union_map_copy(scop->dep_false));
124 deps = isl_union_map_apply_range(deps, isl_union_map_copy(schedule));
125 deps = isl_union_map_apply_domain(deps, schedule);
127 if (isl_union_map_is_empty(deps)) {
128 isl_union_map_free(deps);
129 isl_space_free(schedule_space);
130 return 1;
133 schedule_deps = isl_map_from_union_map(deps);
135 for (i = 0; i < dimension; i++)
136 schedule_deps = isl_map_equate(schedule_deps, isl_dim_out, i,
137 isl_dim_in, i);
139 test = isl_map_universe(isl_map_get_space(schedule_deps));
140 test = isl_map_equate(test, isl_dim_out, dimension, isl_dim_in,
141 dimension);
142 is_parallel = isl_map_is_subset(schedule_deps, test);
144 isl_space_free(schedule_space);
145 isl_map_free(test);
146 isl_map_free(schedule_deps);
148 return is_parallel;
151 /* Mark a for node openmp parallel, if it is the outermost parallel for node.
153 static void mark_openmp_parallel(__isl_keep isl_ast_build *build,
154 struct ast_build_userinfo *build_info,
155 struct ast_node_userinfo *node_info)
157 if (build_info->in_parallel_for)
158 return;
160 if (ast_schedule_dim_is_parallel(build, build_info->scop)) {
161 build_info->in_parallel_for = 1;
162 node_info->is_openmp = 1;
166 /* Allocate an ast_node_info structure and initialize it with default values.
168 static struct ast_node_userinfo *allocate_ast_node_userinfo()
170 struct ast_node_userinfo *node_info;
171 node_info = (struct ast_node_userinfo *)
172 malloc(sizeof(struct ast_node_userinfo));
173 node_info->is_openmp = 0;
174 return node_info;
/* Release an ast_node_userinfo structure.
 * This is installed as the free_user callback of the isl_id that
 * annotates a for node. "ptr" may be NULL.
 */
static void free_ast_node_userinfo(void *ptr)
{
	free(ptr);
}
186 /* This method is executed before the construction of a for node. It creates
187 * an isl_id that is used to annotate the subsequently generated ast for nodes.
189 * In this function we also run the following analyses:
191 * - Detection of openmp parallel loops
193 static __isl_give isl_id *ast_build_before_for(
194 __isl_keep isl_ast_build *build, void *user)
196 isl_id *id;
197 struct ast_build_userinfo *build_info;
198 struct ast_node_userinfo *node_info;
200 build_info = (struct ast_build_userinfo *) user;
201 node_info = allocate_ast_node_userinfo();
202 id = isl_id_alloc(isl_ast_build_get_ctx(build), "", node_info);
203 id = isl_id_set_free_user(id, free_ast_node_userinfo);
205 mark_openmp_parallel(build, build_info, node_info);
207 return id;
210 /* This method is executed after the construction of a for node.
212 * It performs the following actions:
214 * - Reset the 'in_parallel_for' flag, as soon as we leave a for node,
215 * that is marked as openmp parallel.
218 static __isl_give isl_ast_node *ast_build_after_for(__isl_take isl_ast_node *node,
219 __isl_keep isl_ast_build *build, void *user) {
220 isl_id *id;
221 struct ast_build_userinfo *build_info;
222 struct ast_node_userinfo *info;
224 id = isl_ast_node_get_annotation(node);
225 info = isl_id_get_user(id);
227 if (info && info->is_openmp) {
228 build_info = (struct ast_build_userinfo *) user;
229 build_info->in_parallel_for = 0;
232 isl_id_free(id);
234 return node;
237 /* Find the element in scop->stmts that has the given "id".
239 static struct pet_stmt *find_stmt(struct ppcg_scop *scop, __isl_keep isl_id *id)
241 int i;
243 for (i = 0; i < scop->n_stmt; ++i) {
244 struct pet_stmt *stmt = scop->stmts[i];
245 isl_id *id_i;
247 id_i = isl_set_get_tuple_id(stmt->domain);
248 isl_id_free(id_i);
250 if (id_i == id)
251 return stmt;
254 isl_die(isl_id_get_ctx(id), isl_error_internal,
255 "statement not found", return NULL);
258 /* Print a user statement in the generated AST.
259 * The ppcg_stmt has been attached to the node in at_each_domain.
261 static __isl_give isl_printer *print_user(__isl_take isl_printer *p,
262 __isl_take isl_ast_print_options *print_options,
263 __isl_keep isl_ast_node *node, void *user)
265 struct ppcg_stmt *stmt;
266 isl_id *id;
268 id = isl_ast_node_get_annotation(node);
269 stmt = isl_id_get_user(id);
270 isl_id_free(id);
272 p = pet_stmt_print_body(stmt->stmt, p, stmt->ref2expr);
274 isl_ast_print_options_free(print_options);
276 return p;
280 /* Print a for loop node as an openmp parallel loop.
282 * To print an openmp parallel loop we print a normal for loop, but add
283 * "#pragma openmp parallel for" in front.
285 * Variables that are declared within the body of this for loop are
286 * automatically openmp 'private'. Iterators declared outside of the
287 * for loop are automatically openmp 'shared'. As ppcg declares all iterators
288 * at the position where they are assigned, there is no need to explicitly mark
289 * variables. Their automatically assigned type is already correct.
291 * This function only generates valid OpenMP code, if the ast was generated
292 * with the 'atomic-bounds' option enabled.
295 static __isl_give isl_printer *print_for_with_openmp(
296 __isl_keep isl_ast_node *node, __isl_take isl_printer *p,
297 __isl_take isl_ast_print_options *print_options)
299 p = isl_printer_start_line(p);
300 p = isl_printer_print_str(p, "#pragma omp parallel for");
301 p = isl_printer_end_line(p);
303 p = isl_ast_node_for_print(node, p, print_options);
305 return p;
308 /* Print a for node.
310 * Depending on how the node is annotated, we either print a normal
311 * for node or an openmp parallel for node.
313 static __isl_give isl_printer *print_for(__isl_take isl_printer *p,
314 __isl_take isl_ast_print_options *print_options,
315 __isl_keep isl_ast_node *node, void *user)
317 struct ppcg_print_info *print_info;
318 isl_id *id;
319 int openmp;
321 openmp = 0;
322 id = isl_ast_node_get_annotation(node);
324 if (id) {
325 struct ast_node_userinfo *info;
327 info = (struct ast_node_userinfo *) isl_id_get_user(id);
328 if (info && info->is_openmp)
329 openmp = 1;
332 if (openmp)
333 p = print_for_with_openmp(node, p, print_options);
334 else
335 p = isl_ast_node_for_print(node, p, print_options);
337 isl_id_free(id);
339 return p;
342 /* Index transformation callback for pet_stmt_build_ast_exprs.
344 * "index" expresses the array indices in terms of statement iterators
345 * "iterator_map" expresses the statement iterators in terms of
346 * AST loop iterators.
348 * The result expresses the array indices in terms of
349 * AST loop iterators.
351 static __isl_give isl_multi_pw_aff *pullback_index(
352 __isl_take isl_multi_pw_aff *index, __isl_keep isl_id *id, void *user)
354 isl_pw_multi_aff *iterator_map = user;
356 iterator_map = isl_pw_multi_aff_copy(iterator_map);
357 return isl_multi_pw_aff_pullback_pw_multi_aff(index, iterator_map);
360 /* Transform the accesses in the statement associated to the domain
361 * called by "node" to refer to the AST loop iterators, construct
362 * corresponding AST expressions using "build",
363 * collect them in a ppcg_stmt and annotate the node with the ppcg_stmt.
365 static __isl_give isl_ast_node *at_each_domain(__isl_take isl_ast_node *node,
366 __isl_keep isl_ast_build *build, void *user)
368 struct ppcg_scop *scop = user;
369 isl_ast_expr *expr, *arg;
370 isl_ctx *ctx;
371 isl_id *id;
372 isl_map *map;
373 isl_pw_multi_aff *iterator_map;
374 struct ppcg_stmt *stmt;
376 ctx = isl_ast_node_get_ctx(node);
377 stmt = isl_calloc_type(ctx, struct ppcg_stmt);
378 if (!stmt)
379 goto error;
381 expr = isl_ast_node_user_get_expr(node);
382 arg = isl_ast_expr_get_op_arg(expr, 0);
383 isl_ast_expr_free(expr);
384 id = isl_ast_expr_get_id(arg);
385 isl_ast_expr_free(arg);
386 stmt->stmt = find_stmt(scop, id);
387 isl_id_free(id);
388 if (!stmt->stmt)
389 goto error;
391 map = isl_map_from_union_map(isl_ast_build_get_schedule(build));
392 map = isl_map_reverse(map);
393 iterator_map = isl_pw_multi_aff_from_map(map);
394 stmt->ref2expr = pet_stmt_build_ast_exprs(stmt->stmt, build,
395 &pullback_index, iterator_map, NULL, NULL);
396 isl_pw_multi_aff_free(iterator_map);
398 id = isl_id_alloc(isl_ast_node_get_ctx(node), NULL, stmt);
399 id = isl_id_set_free_user(id, &ppcg_stmt_free);
400 return isl_ast_node_set_annotation(node, id);
401 error:
402 ppcg_stmt_free(stmt);
403 return isl_ast_node_free(node);
406 /* Code generate the scop 'scop' and print the corresponding C code to 'p'.
408 static __isl_give isl_printer *print_scop(struct ppcg_scop *scop,
409 __isl_take isl_printer *p, struct ppcg_options *options)
411 isl_ctx *ctx = isl_printer_get_ctx(p);
412 isl_set *context;
413 isl_union_set *domain_set;
414 isl_union_map *schedule_map;
415 isl_ast_build *build;
416 isl_ast_print_options *print_options;
417 isl_ast_node *tree;
418 struct ast_build_userinfo build_info;
420 context = isl_set_copy(scop->context);
421 domain_set = isl_union_set_copy(scop->domain);
422 schedule_map = isl_union_map_copy(scop->schedule);
423 schedule_map = isl_union_map_intersect_domain(schedule_map, domain_set);
425 build = isl_ast_build_from_context(context);
426 build = isl_ast_build_set_at_each_domain(build, &at_each_domain, scop);
428 if (options->openmp) {
429 build_info.scop = scop;
430 build_info.in_parallel_for = 0;
432 build = isl_ast_build_set_before_each_for(build,
433 &ast_build_before_for,
434 &build_info);
435 build = isl_ast_build_set_after_each_for(build,
436 &ast_build_after_for,
437 &build_info);
440 tree = isl_ast_build_ast_from_schedule(build, schedule_map);
441 isl_ast_build_free(build);
443 print_options = isl_ast_print_options_alloc(ctx);
444 print_options = isl_ast_print_options_set_print_user(print_options,
445 &print_user, NULL);
447 print_options = isl_ast_print_options_set_print_for(print_options,
448 &print_for, NULL);
450 p = isl_ast_node_print_macros(tree, p);
451 p = isl_ast_node_print(tree, p, print_options);
453 isl_ast_node_free(tree);
455 return p;
458 /* Does "scop" refer to any arrays that are declared, but not
459 * exposed to the code after the scop?
461 static int any_hidden_declarations(struct ppcg_scop *scop)
463 int i;
465 if (!scop)
466 return 0;
468 for (i = 0; i < scop->n_array; ++i)
469 if (scop->arrays[i]->declared && !scop->arrays[i]->exposed)
470 return 1;
472 return 0;
475 /* Generate CPU code for the scop "ps" and print the corresponding C code
476 * to "p", including variable declarations.
478 __isl_give isl_printer *print_cpu(__isl_take isl_printer *p,
479 struct ppcg_scop *ps, struct ppcg_options *options)
481 int hidden;
483 p = isl_printer_start_line(p);
484 p = isl_printer_print_str(p, "/* ppcg generated CPU code */");
485 p = isl_printer_end_line(p);
487 p = isl_printer_start_line(p);
488 p = isl_printer_end_line(p);
490 p = ppcg_print_exposed_declarations(p, ps);
491 hidden = any_hidden_declarations(ps);
492 if (hidden) {
493 p = ppcg_start_block(p);
494 p = ppcg_print_hidden_declarations(p, ps);
496 p = print_scop(ps, p, options);
497 if (hidden)
498 p = ppcg_end_block(p);
500 return p;
503 int generate_cpu(isl_ctx *ctx, struct ppcg_scop *ps,
504 struct ppcg_options *options, const char *input, const char *output)
506 FILE *input_file;
507 FILE *output_file;
508 isl_printer *p;
510 if (!ps)
511 return -1;
513 input_file = fopen(input, "r");
514 output_file = get_output_file(input, output);
516 copy(input_file, output_file, 0, ps->start);
517 p = isl_printer_to_file(ctx, output_file);
518 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
519 p = print_cpu(p, ps, options);
520 isl_printer_free(p);
521 copy(input_file, output_file, ps->end, -1);
523 fclose(output_file);
524 fclose(input_file);
526 return 0;