update pet for pet_transform_C_source
[ppcg.git] / cpu.c
blobe401e6ae89b32c5907f809da88ab43b4c268262e
/*
 * Copyright 2012 INRIA Paris-Rocquencourt
 *
 * Use of this software is governed by the MIT license
 *
 * Written by Tobias Grosser, INRIA Paris-Rocquencourt,
 * Domaine de Voluceau, Rocquencourt, B.P. 105,
 * 78153 Le Chesnay Cedex France
 */
11 #include <limits.h>
12 #include <stdio.h>
13 #include <string.h>
15 #include <isl/aff.h>
16 #include <isl/ctx.h>
17 #include <isl/map.h>
18 #include <isl/ast_build.h>
19 #include <pet.h>
21 #include "ppcg.h"
22 #include "ppcg_options.h"
23 #include "cpu.h"
24 #include "pet_printer.h"
25 #include "print.h"
26 #include "rewrite.h"
28 /* Representation of a statement inside a generated AST.
30 * "stmt" refers to the original statement.
31 * "n_access" is the number of accesses in the statement.
32 * "access" is the list of accesses transformed to refer to the iterators
33 * in the generated AST.
35 struct ppcg_stmt {
36 struct pet_stmt *stmt;
38 int n_access;
39 isl_ast_expr_list **access;
42 static void ppcg_stmt_free(void *user)
44 struct ppcg_stmt *stmt = user;
45 int i;
47 if (!stmt)
48 return;
50 for (i = 0; i < stmt->n_access; ++i)
51 isl_ast_expr_list_free(stmt->access[i]);
53 free(stmt->access);
54 free(stmt);
57 /* Derive the output file name from the input file name.
58 * 'input' is the entire path of the input file. The output
59 * is the file name plus the additional extension.
61 * We will basically replace everything after the last point
62 * with '.ppcg.c'. This means file.c becomes file.ppcg.c
64 static FILE *get_output_file(const char *input, const char *output)
66 char name[PATH_MAX];
67 const char *ext;
68 const char ppcg_marker[] = ".ppcg";
69 int len;
71 len = ppcg_extract_base_name(name, input);
73 strcpy(name + len, ppcg_marker);
74 ext = strrchr(input, '.');
75 strcpy(name + len + sizeof(ppcg_marker) - 1, ext ? ext : ".c");
77 if (!output)
78 output = name;
80 return fopen(output, "w");
/* Annotation attached to the for nodes in the ast.
 *
 * "is_openmp" is set if the for node is an openmp parallel for node.
 */
struct ast_node_userinfo {
	int is_openmp;
};
/* State that is carried along while the ast is being built.
 *
 * "scop" is the current ppcg scop.
 * "in_parallel_for" is set while we are inside a parallel for loop.
 */
struct ast_build_userinfo {
	struct ppcg_scop *scop;

	int in_parallel_for;
};
100 /* Check if the current scheduling dimension is parallel.
102 * We check for parallelism by verifying that the loop does not carry any
103 * dependences.
105 * Parallelism test: if the distance is zero in all outer dimensions, then it
106 * has to be zero in the current dimension as well.
107 * Implementation: first, translate dependences into time space, then force
108 * outer dimensions to be equal. If the distance is zero in the current
109 * dimension, then the loop is parallel.
110 * The distance is zero in the current dimension if it is a subset of a map
111 * with equal values for the current dimension.
113 static int ast_schedule_dim_is_parallel(__isl_keep isl_ast_build *build,
114 struct ppcg_scop *scop)
116 isl_union_map *schedule_node, *schedule, *deps;
117 isl_map *schedule_deps, *test;
118 isl_space *schedule_space;
119 unsigned i, dimension, is_parallel;
121 schedule = isl_ast_build_get_schedule(build);
122 schedule_space = isl_ast_build_get_schedule_space(build);
124 dimension = isl_space_dim(schedule_space, isl_dim_out) - 1;
126 deps = isl_union_map_copy(scop->dep_flow);
127 deps = isl_union_map_union(deps, isl_union_map_copy(scop->dep_false));
128 deps = isl_union_map_apply_range(deps, isl_union_map_copy(schedule));
129 deps = isl_union_map_apply_domain(deps, schedule);
131 if (isl_union_map_is_empty(deps)) {
132 isl_union_map_free(deps);
133 isl_space_free(schedule_space);
134 return 1;
137 schedule_deps = isl_map_from_union_map(deps);
139 for (i = 0; i < dimension; i++)
140 schedule_deps = isl_map_equate(schedule_deps, isl_dim_out, i,
141 isl_dim_in, i);
143 test = isl_map_universe(isl_map_get_space(schedule_deps));
144 test = isl_map_equate(test, isl_dim_out, dimension, isl_dim_in,
145 dimension);
146 is_parallel = isl_map_is_subset(schedule_deps, test);
148 isl_space_free(schedule_space);
149 isl_map_free(test);
150 isl_map_free(schedule_deps);
152 return is_parallel;
155 /* Mark a for node openmp parallel, if it is the outermost parallel for node.
157 static void mark_openmp_parallel(__isl_keep isl_ast_build *build,
158 struct ast_build_userinfo *build_info,
159 struct ast_node_userinfo *node_info)
161 if (build_info->in_parallel_for)
162 return;
164 if (ast_schedule_dim_is_parallel(build, build_info->scop)) {
165 build_info->in_parallel_for = 1;
166 node_info->is_openmp = 1;
170 /* Allocate an ast_node_info structure and initialize it with default values.
172 static struct ast_node_userinfo *allocate_ast_node_userinfo()
174 struct ast_node_userinfo *node_info;
175 node_info = (struct ast_node_userinfo *)
176 malloc(sizeof(struct ast_node_userinfo));
177 node_info->is_openmp = 0;
178 return node_info;
/* Release an ast_node_info structure.
 *
 * The structure is received as an untyped pointer and handed
 * directly to free().
 */
static void free_ast_node_userinfo(void *ptr)
{
	free(ptr);
}
190 /* This method is executed before the construction of a for node. It creates
191 * an isl_id that is used to annotate the subsequently generated ast for nodes.
193 * In this function we also run the following analyses:
195 * - Detection of openmp parallel loops
197 static __isl_give isl_id *ast_build_before_for(
198 __isl_keep isl_ast_build *build, void *user)
200 isl_id *id;
201 struct ast_build_userinfo *build_info;
202 struct ast_node_userinfo *node_info;
204 build_info = (struct ast_build_userinfo *) user;
205 node_info = allocate_ast_node_userinfo();
206 id = isl_id_alloc(isl_ast_build_get_ctx(build), "", node_info);
207 id = isl_id_set_free_user(id, free_ast_node_userinfo);
209 mark_openmp_parallel(build, build_info, node_info);
211 return id;
214 /* This method is executed after the construction of a for node.
216 * It performs the following actions:
218 * - Reset the 'in_parallel_for' flag, as soon as we leave a for node,
219 * that is marked as openmp parallel.
222 static __isl_give isl_ast_node *ast_build_after_for(__isl_take isl_ast_node *node,
223 __isl_keep isl_ast_build *build, void *user) {
224 isl_id *id;
225 struct ast_build_userinfo *build_info;
226 struct ast_node_userinfo *info;
228 id = isl_ast_node_get_annotation(node);
229 info = isl_id_get_user(id);
231 if (info && info->is_openmp) {
232 build_info = (struct ast_build_userinfo *) user;
233 build_info->in_parallel_for = 0;
236 isl_id_free(id);
238 return node;
241 /* Print a memory access 'access' to the printer 'p'.
243 * "expr" refers to the original access.
244 * "access" is the list of index expressions transformed to refer
245 * to the iterators of the generated AST.
247 * In case the original access is unnamed (and presumably single-dimensional),
248 * we assume this is not a memory access, but just an expression.
250 static __isl_give isl_printer *print_access(__isl_take isl_printer *p,
251 struct pet_expr *expr, __isl_keep isl_ast_expr_list *access)
253 int i;
254 const char *name;
255 unsigned n_index;
257 n_index = isl_ast_expr_list_n_ast_expr(access);
258 name = isl_map_get_tuple_name(expr->acc.access, isl_dim_out);
260 if (name == NULL) {
261 isl_ast_expr *index;
262 index = isl_ast_expr_list_get_ast_expr(access, 0);
263 p = isl_printer_print_str(p, "(");
264 p = isl_printer_print_ast_expr(p, index);
265 p = isl_printer_print_str(p, ")");
266 isl_ast_expr_free(index);
267 return p;
270 p = isl_printer_print_str(p, name);
272 for (i = 0; i < n_index; ++i) {
273 isl_ast_expr *index;
275 index = isl_ast_expr_list_get_ast_expr(access, i);
277 p = isl_printer_print_str(p, "[");
278 p = isl_printer_print_ast_expr(p, index);
279 p = isl_printer_print_str(p, "]");
280 isl_ast_expr_free(index);
283 return p;
286 /* Find the element in scop->stmts that has the given "id".
288 static struct pet_stmt *find_stmt(struct ppcg_scop *scop, __isl_keep isl_id *id)
290 int i;
292 for (i = 0; i < scop->n_stmt; ++i) {
293 struct pet_stmt *stmt = scop->stmts[i];
294 isl_id *id_i;
296 id_i = isl_set_get_tuple_id(stmt->domain);
297 isl_id_free(id_i);
299 if (id_i == id)
300 return stmt;
303 isl_die(isl_id_get_ctx(id), isl_error_internal,
304 "statement not found", return NULL);
307 /* To print the transformed accesses we walk the list of transformed accesses
308 * simultaneously with the pet printer. This means that whenever
309 * the pet printer prints a pet access expression we have
310 * the corresponding transformed access available for printing.
312 static __isl_give isl_printer *print_access_expr(__isl_take isl_printer *p,
313 struct pet_expr *expr, void *user)
315 isl_ast_expr_list ***access = user;
317 p = print_access(p, expr, **access);
318 (*access)++;
320 return p;
323 /* Print a user statement in the generated AST.
324 * The ppcg_stmt has been attached to the node in at_each_domain.
326 static __isl_give isl_printer *print_user(__isl_take isl_printer *p,
327 __isl_take isl_ast_print_options *print_options,
328 __isl_keep isl_ast_node *node, void *user)
330 struct ppcg_stmt *stmt;
331 isl_ast_expr_list **access;
332 isl_id *id;
334 id = isl_ast_node_get_annotation(node);
335 stmt = isl_id_get_user(id);
336 isl_id_free(id);
338 access = stmt->access;
340 p = isl_printer_start_line(p);
341 p = print_pet_expr(p, stmt->stmt->body, &print_access_expr, &access);
342 p = isl_printer_print_str(p, ";");
343 p = isl_printer_end_line(p);
345 isl_ast_print_options_free(print_options);
347 return p;
351 /* Print a for loop node as an openmp parallel loop.
353 * To print an openmp parallel loop we print a normal for loop, but add
354 * "#pragma openmp parallel for" in front.
356 * Variables that are declared within the body of this for loop are
357 * automatically openmp 'private'. Iterators declared outside of the
358 * for loop are automatically openmp 'shared'. As ppcg declares all iterators
359 * at the position where they are assigned, there is no need to explicitly mark
360 * variables. Their automatically assigned type is already correct.
362 * This function only generates valid OpenMP code, if the ast was generated
363 * with the 'atomic-bounds' option enabled.
366 static __isl_give isl_printer *print_for_with_openmp(
367 __isl_keep isl_ast_node *node, __isl_take isl_printer *p,
368 __isl_take isl_ast_print_options *print_options)
370 p = isl_printer_start_line(p);
371 p = isl_printer_print_str(p, "#pragma omp parallel for");
372 p = isl_printer_end_line(p);
374 p = isl_ast_node_for_print(node, p, print_options);
376 return p;
379 /* Print a for node.
381 * Depending on how the node is annotated, we either print a normal
382 * for node or an openmp parallel for node.
384 static __isl_give isl_printer *print_for(__isl_take isl_printer *p,
385 __isl_take isl_ast_print_options *print_options,
386 __isl_keep isl_ast_node *node, void *user)
388 struct ppcg_print_info *print_info;
389 isl_id *id;
390 int openmp;
392 openmp = 0;
393 id = isl_ast_node_get_annotation(node);
395 if (id) {
396 struct ast_node_userinfo *info;
398 info = (struct ast_node_userinfo *) isl_id_get_user(id);
399 if (info && info->is_openmp)
400 openmp = 1;
403 if (openmp)
404 p = print_for_with_openmp(node, p, print_options);
405 else
406 p = isl_ast_node_for_print(node, p, print_options);
408 isl_id_free(id);
410 return p;
413 /* Call "fn" on each access expression in "expr".
415 static int foreach_access_expr(struct pet_expr *expr,
416 int (*fn)(struct pet_expr *expr, void *user), void *user)
418 int i;
420 if (!expr)
421 return -1;
423 if (expr->type == pet_expr_access)
424 return fn(expr, user);
426 for (i = 0; i < expr->n_arg; ++i)
427 if (foreach_access_expr(expr->args[i], fn, user) < 0)
428 return -1;
430 return 0;
433 static int inc_n_access(struct pet_expr *expr, void *user)
435 struct ppcg_stmt *stmt = user;
436 stmt->n_access++;
437 return 0;
440 /* Internal data for add_access.
442 * "stmt" is the statement to which an access needs to be added.
443 * "build" is the current AST build.
444 * "map" maps the AST loop iterators to the iteration domain of the statement.
446 struct ppcg_add_access_data {
447 struct ppcg_stmt *stmt;
448 isl_ast_build *build;
449 isl_map *map;
452 /* Given an access expression, add it to data->stmt after
453 * transforming it to refer to the AST loop iterators.
455 static int add_access(struct pet_expr *expr, void *user)
457 int i, n;
458 isl_ctx *ctx;
459 isl_map *access;
460 isl_pw_multi_aff *pma;
461 struct ppcg_add_access_data *data = user;
462 isl_ast_expr_list *index;
464 ctx = isl_map_get_ctx(expr->acc.access);
465 n = isl_map_dim(expr->acc.access, isl_dim_out);
466 access = isl_map_copy(expr->acc.access);
467 access = isl_map_apply_range(isl_map_copy(data->map), access);
468 pma = isl_pw_multi_aff_from_map(access);
469 pma = isl_pw_multi_aff_coalesce(pma);
471 index = isl_ast_expr_list_alloc(ctx, n);
472 for (i = 0; i < n; ++i) {
473 isl_pw_aff *pa;
474 isl_ast_expr *expr;
476 pa = isl_pw_multi_aff_get_pw_aff(pma, i);
477 expr = isl_ast_build_expr_from_pw_aff(data->build, pa);
478 index = isl_ast_expr_list_add(index, expr);
480 isl_pw_multi_aff_free(pma);
482 data->stmt->access[data->stmt->n_access] = index;
483 data->stmt->n_access++;
484 return 0;
487 /* Transform the accesses in the statement associated to the domain
488 * called by "node" to refer to the AST loop iterators,
489 * collect them in a ppcg_stmt and annotate the node with the ppcg_stmt.
491 static __isl_give isl_ast_node *at_each_domain(__isl_take isl_ast_node *node,
492 __isl_keep isl_ast_build *build, void *user)
494 struct ppcg_scop *scop = user;
495 isl_ast_expr *expr, *arg;
496 isl_ctx *ctx;
497 isl_id *id;
498 isl_map *map;
499 struct ppcg_stmt *stmt;
500 struct ppcg_add_access_data data;
502 ctx = isl_ast_node_get_ctx(node);
503 stmt = isl_calloc_type(ctx, struct ppcg_stmt);
504 if (!stmt)
505 goto error;
507 expr = isl_ast_node_user_get_expr(node);
508 arg = isl_ast_expr_get_op_arg(expr, 0);
509 isl_ast_expr_free(expr);
510 id = isl_ast_expr_get_id(arg);
511 isl_ast_expr_free(arg);
512 stmt->stmt = find_stmt(scop, id);
513 isl_id_free(id);
514 if (!stmt->stmt)
515 goto error;
517 stmt->n_access = 0;
518 if (foreach_access_expr(stmt->stmt->body, &inc_n_access, stmt) < 0)
519 goto error;
521 stmt->access = isl_calloc_array(ctx, isl_ast_expr_list *,
522 stmt->n_access);
523 if (!stmt->access)
524 goto error;
526 map = isl_map_from_union_map(isl_ast_build_get_schedule(build));
527 map = isl_map_reverse(map);
529 stmt->n_access = 0;
530 data.stmt = stmt;
531 data.build = build;
532 data.map = map;
533 if (foreach_access_expr(stmt->stmt->body, &add_access, &data) < 0)
534 node = isl_ast_node_free(node);
536 isl_map_free(map);
538 id = isl_id_alloc(isl_ast_node_get_ctx(node), NULL, stmt);
539 id = isl_id_set_free_user(id, &ppcg_stmt_free);
540 return isl_ast_node_set_annotation(node, id);
541 error:
542 ppcg_stmt_free(stmt);
543 return isl_ast_node_free(node);
546 /* Code generate the scop 'scop' and print the corresponding C code to 'p'.
548 static __isl_give isl_printer *print_scop(struct ppcg_scop *scop,
549 __isl_take isl_printer *p, struct ppcg_options *options)
551 isl_ctx *ctx = isl_printer_get_ctx(p);
552 isl_set *context;
553 isl_union_set *domain_set;
554 isl_union_map *schedule_map;
555 isl_ast_build *build;
556 isl_ast_print_options *print_options;
557 isl_ast_node *tree;
558 struct ast_build_userinfo build_info;
560 context = isl_set_copy(scop->context);
561 domain_set = isl_union_set_copy(scop->domain);
562 schedule_map = isl_union_map_copy(scop->schedule);
563 schedule_map = isl_union_map_intersect_domain(schedule_map, domain_set);
565 build = isl_ast_build_from_context(context);
566 build = isl_ast_build_set_at_each_domain(build, &at_each_domain, scop);
568 if (options->openmp) {
569 build_info.scop = scop;
570 build_info.in_parallel_for = 0;
572 build = isl_ast_build_set_before_each_for(build,
573 &ast_build_before_for,
574 &build_info);
575 build = isl_ast_build_set_after_each_for(build,
576 &ast_build_after_for,
577 &build_info);
580 tree = isl_ast_build_ast_from_schedule(build, schedule_map);
581 isl_ast_build_free(build);
583 print_options = isl_ast_print_options_alloc(ctx);
584 print_options = isl_ast_print_options_set_print_user(print_options,
585 &print_user, NULL);
587 print_options = isl_ast_print_options_set_print_for(print_options,
588 &print_for, NULL);
590 p = isl_ast_node_print_macros(tree, p);
591 p = isl_ast_node_print(tree, p, print_options);
593 isl_ast_node_free(tree);
595 return p;
598 /* Does "scop" refer to any arrays that are declared, but not
599 * exposed to the code after the scop?
601 static int any_hidden_declarations(struct ppcg_scop *scop)
603 int i;
605 if (!scop)
606 return 0;
608 for (i = 0; i < scop->n_array; ++i)
609 if (scop->arrays[i]->declared && !scop->arrays[i]->exposed)
610 return 1;
612 return 0;
615 /* Generate CPU code for the scop "ps" and print the corresponding C code
616 * to "p", including variable declarations.
618 __isl_give isl_printer *print_cpu(__isl_take isl_printer *p,
619 struct ppcg_scop *ps, struct ppcg_options *options)
621 int hidden;
623 p = isl_printer_start_line(p);
624 p = isl_printer_print_str(p, "/* ppcg generated CPU code */");
625 p = isl_printer_end_line(p);
627 p = isl_printer_start_line(p);
628 p = isl_printer_end_line(p);
630 p = ppcg_print_exposed_declarations(p, ps);
631 hidden = any_hidden_declarations(ps);
632 if (hidden) {
633 p = ppcg_start_block(p);
634 p = ppcg_print_hidden_declarations(p, ps);
636 p = print_scop(ps, p, options);
637 if (hidden)
638 p = ppcg_end_block(p);
640 return p;
643 int generate_cpu(isl_ctx *ctx, struct ppcg_scop *ps,
644 struct ppcg_options *options, const char *input, const char *output)
646 FILE *input_file;
647 FILE *output_file;
648 isl_printer *p;
650 if (!ps)
651 return -1;
653 input_file = fopen(input, "r");
654 output_file = get_output_file(input, output);
656 copy(input_file, output_file, 0, ps->start);
657 p = isl_printer_to_file(ctx, output_file);
658 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
659 p = print_cpu(p, ps, options);
660 isl_printer_free(p);
661 copy(input_file, output_file, ps->end, -1);
663 fclose(output_file);
664 fclose(input_file);
666 return 0;