update isl for isl_map_affine_hull optimization
[ppcg.git] / cpu.c
blob6e02d7ae9b92273b47eba7628259d2ab535fd005
/*
 * Copyright 2012 INRIA Paris-Rocquencourt
 *
 * Use of this software is governed by the GNU LGPLv2.1 license
 *
 * Written by Tobias Grosser, INRIA Paris-Rocquencourt,
 * Domaine de Voluceau, Rocquencourt, B.P. 105,
 * 78153 Le Chesnay Cedex France
 */
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <isl/aff.h>
#include <isl/ctx.h>
#include <isl/map.h>
#include <isl/ast_build.h>
#include <pet.h>

#include "ppcg.h"
#include "ppcg_options.h"
#include "cpu.h"
#include "pet_printer.h"
#include "print.h"
#include "rewrite.h"
27 /* Representation of a statement inside a generated AST.
29 * "stmt" refers to the original statement.
30 * "n_access" is the number of accesses in the statement.
31 * "access" is the list of accesses transformed to refer to the iterators
32 * in the generated AST.
34 struct ppcg_stmt {
35 struct pet_stmt *stmt;
37 int n_access;
38 isl_ast_expr_list **access;
41 static void ppcg_stmt_free(void *user)
43 struct ppcg_stmt *stmt = user;
44 int i;
46 if (!stmt)
47 return;
49 for (i = 0; i < stmt->n_access; ++i)
50 isl_ast_expr_list_free(stmt->access[i]);
52 free(stmt->access);
53 free(stmt);
/* PATH_MAX is not guaranteed to be provided by <limits.h>;
 * fall back to a conventional value when it is missing.
 */
#ifndef PATH_MAX
#define PATH_MAX 4096
#endif

/* Open the file the generated code is written to.
 *
 * "input" is the entire path of the input file.
 * "output", if not NULL, is the name of the file to open; it is used
 * as is and no name is derived from "input".
 *
 * Otherwise the output name is the file name part of "input" (any
 * directory prefix is dropped) with ".ppcg" inserted in front of the
 * last extension, so file.c becomes file.ppcg.c.  A name without
 * any extension simply gets ".ppcg" appended.
 *
 * Return the stream opened for writing, or NULL if the file cannot be
 * opened or the derived name does not fit in PATH_MAX bytes.
 */
static FILE *get_output_file(const char *input, const char *output)
{
	char name[PATH_MAX];
	const char *base;
	const char *ext;
	int len;
	int n;

	if (output)
		return fopen(output, "w");

	base = strrchr(input, '/');
	base = base ? base + 1 : input;

	/* Without an extension, treat the end of the name as an empty one. */
	ext = strrchr(base, '.');
	if (!ext)
		ext = base + strlen(base);
	len = ext - base;

	n = snprintf(name, sizeof(name), "%.*s.ppcg%s", len, base, ext);
	if (n < 0 || (size_t) n >= sizeof(name))
		return NULL;

	return fopen(name, "w");
}
/* Data attached as an annotation to the for nodes of the AST.
 */
struct ast_node_userinfo {
	/* Set if the for node is an openmp parallel for node. */
	int is_openmp;
};
/* State shared by the callbacks that run while the AST is being built.
 */
struct ast_build_userinfo {
	/* The ppcg scop currently being processed. */
	struct ppcg_scop *scop;

	/* Set while the build is inside a parallel for loop. */
	int in_parallel_for;
};
106 /* Check if the current scheduling dimension is parallel.
108 * We check for parallelism by verifying that the loop does not carry any
109 * dependences.
111 * Parallelism test: if the distance is zero in all outer dimensions, then it
112 * has to be zero in the current dimension as well.
113 * Implementation: first, translate dependences into time space, then force
114 * outer dimensions to be equal. If the distance is zero in the current
115 * dimension, then the loop is parallel.
116 * The distance is zero in the current dimension if it is a subset of a map
117 * with equal values for the current dimension.
119 static int ast_schedule_dim_is_parallel(__isl_keep isl_ast_build *build,
120 struct ppcg_scop *scop)
122 isl_union_map *schedule_node, *schedule, *deps;
123 isl_map *schedule_deps, *test;
124 isl_space *schedule_space;
125 unsigned i, dimension, is_parallel;
127 schedule = isl_ast_build_get_schedule(build);
128 schedule_space = isl_ast_build_get_schedule_space(build);
130 dimension = isl_space_dim(schedule_space, isl_dim_out) - 1;
132 deps = isl_union_map_copy(scop->dep_flow);
133 deps = isl_union_map_union(deps, isl_union_map_copy(scop->dep_false));
134 deps = isl_union_map_apply_range(deps, isl_union_map_copy(schedule));
135 deps = isl_union_map_apply_domain(deps, schedule);
137 if (isl_union_map_is_empty(deps)) {
138 isl_union_map_free(deps);
139 isl_space_free(schedule_space);
140 return 1;
143 schedule_deps = isl_map_from_union_map(deps);
145 for (i = 0; i < dimension; i++)
146 schedule_deps = isl_map_equate(schedule_deps, isl_dim_out, i,
147 isl_dim_in, i);
149 test = isl_map_universe(isl_map_get_space(schedule_deps));
150 test = isl_map_equate(test, isl_dim_out, dimension, isl_dim_in,
151 dimension);
152 is_parallel = isl_map_is_subset(schedule_deps, test);
154 isl_space_free(schedule_space);
155 isl_map_free(test);
156 isl_map_free(schedule_deps);
158 return is_parallel;
161 /* Mark a for node openmp parallel, if it is the outermost parallel for node.
163 static void mark_openmp_parallel(__isl_keep isl_ast_build *build,
164 struct ast_build_userinfo *build_info,
165 struct ast_node_userinfo *node_info)
167 if (build_info->in_parallel_for)
168 return;
170 if (ast_schedule_dim_is_parallel(build, build_info->scop)) {
171 build_info->in_parallel_for = 1;
172 node_info->is_openmp = 1;
176 /* Allocate an ast_node_info structure and initialize it with default values.
178 static struct ast_node_userinfo *allocate_ast_node_userinfo()
180 struct ast_node_userinfo *node_info;
181 node_info = (struct ast_node_userinfo *)
182 malloc(sizeof(struct ast_node_userinfo));
183 node_info->is_openmp = 0;
184 return node_info;
/* Free an ast_node_userinfo structure.
 *
 * "ptr" is passed as void * so that the function can be installed
 * with isl_id_set_free_user.
 */
static void free_ast_node_userinfo(void *ptr)
{
	free(ptr);
}
196 /* This method is executed before the construction of a for node. It creates
197 * an isl_id that is used to annotate the subsequently generated ast for nodes.
199 * In this function we also run the following analyses:
201 * - Detection of openmp parallel loops
203 static __isl_give isl_id *ast_build_before_for(
204 __isl_keep isl_ast_build *build, void *user)
206 isl_id *id;
207 struct ast_build_userinfo *build_info;
208 struct ast_node_userinfo *node_info;
210 build_info = (struct ast_build_userinfo *) user;
211 node_info = allocate_ast_node_userinfo();
212 id = isl_id_alloc(isl_ast_build_get_ctx(build), "", node_info);
213 id = isl_id_set_free_user(id, free_ast_node_userinfo);
215 mark_openmp_parallel(build, build_info, node_info);
217 return id;
220 /* This method is executed after the construction of a for node.
222 * It performs the following actions:
224 * - Reset the 'in_parallel_for' flag, as soon as we leave a for node,
225 * that is marked as openmp parallel.
228 static __isl_give isl_ast_node *ast_build_after_for(__isl_take isl_ast_node *node,
229 __isl_keep isl_ast_build *build, void *user) {
230 isl_id *id;
231 struct ast_build_userinfo *build_info;
232 struct ast_node_userinfo *info;
234 id = isl_ast_node_get_annotation(node);
235 info = isl_id_get_user(id);
237 if (info && info->is_openmp) {
238 build_info = (struct ast_build_userinfo *) user;
239 build_info->in_parallel_for = 0;
242 isl_id_free(id);
244 return node;
247 /* Print a memory access 'access' to the printer 'p'.
249 * "expr" refers to the original access.
250 * "access" is the list of index expressions transformed to refer
251 * to the iterators of the generated AST.
253 * In case the original access is unnamed (and presumably single-dimensional),
254 * we assume this is not a memory access, but just an expression.
256 static __isl_give isl_printer *print_access(__isl_take isl_printer *p,
257 struct pet_expr *expr, __isl_keep isl_ast_expr_list *access)
259 int i;
260 const char *name;
261 unsigned n_index;
263 n_index = isl_ast_expr_list_n_ast_expr(access);
264 name = isl_map_get_tuple_name(expr->acc.access, isl_dim_out);
266 if (name == NULL) {
267 isl_ast_expr *index;
268 index = isl_ast_expr_list_get_ast_expr(access, 0);
269 p = isl_printer_print_str(p, "(");
270 p = isl_printer_print_ast_expr(p, index);
271 p = isl_printer_print_str(p, ")");
272 isl_ast_expr_free(index);
273 return p;
276 p = isl_printer_print_str(p, name);
278 for (i = 0; i < n_index; ++i) {
279 isl_ast_expr *index;
281 index = isl_ast_expr_list_get_ast_expr(access, i);
283 p = isl_printer_print_str(p, "[");
284 p = isl_printer_print_ast_expr(p, index);
285 p = isl_printer_print_str(p, "]");
286 isl_ast_expr_free(index);
289 return p;
292 /* Find the element in scop->stmts that has the given "id".
294 static struct pet_stmt *find_stmt(struct ppcg_scop *scop, __isl_keep isl_id *id)
296 int i;
298 for (i = 0; i < scop->n_stmt; ++i) {
299 struct pet_stmt *stmt = scop->stmts[i];
300 isl_id *id_i;
302 id_i = isl_set_get_tuple_id(stmt->domain);
303 isl_id_free(id_i);
305 if (id_i == id)
306 return stmt;
309 isl_die(isl_id_get_ctx(id), isl_error_internal,
310 "statement not found", return NULL);
313 /* To print the transformed accesses we walk the list of transformed accesses
314 * simultaneously with the pet printer. This means that whenever
315 * the pet printer prints a pet access expression we have
316 * the corresponding transformed access available for printing.
318 static __isl_give isl_printer *print_access_expr(__isl_take isl_printer *p,
319 struct pet_expr *expr, void *user)
321 isl_ast_expr_list ***access = user;
323 p = print_access(p, expr, **access);
324 (*access)++;
326 return p;
329 /* Print a user statement in the generated AST.
330 * The ppcg_stmt has been attached to the node in at_each_domain.
332 static __isl_give isl_printer *print_user(__isl_take isl_printer *p,
333 __isl_take isl_ast_print_options *print_options,
334 __isl_keep isl_ast_node *node, void *user)
336 struct ppcg_stmt *stmt;
337 isl_ast_expr_list **access;
338 isl_id *id;
340 id = isl_ast_node_get_annotation(node);
341 stmt = isl_id_get_user(id);
342 isl_id_free(id);
344 access = stmt->access;
346 p = isl_printer_start_line(p);
347 p = print_pet_expr(p, stmt->stmt->body, &print_access_expr, &access);
348 p = isl_printer_print_str(p, ";");
349 p = isl_printer_end_line(p);
351 isl_ast_print_options_free(print_options);
353 return p;
357 /* Print a for loop node as an openmp parallel loop.
359 * To print an openmp parallel loop we print a normal for loop, but add
360 * "#pragma openmp parallel for" in front.
362 * Variables that are declared within the body of this for loop are
363 * automatically openmp 'private'. Iterators declared outside of the
364 * for loop are automatically openmp 'shared'. As ppcg declares all iterators
365 * at the position where they are assigned, there is no need to explicitly mark
366 * variables. Their automatically assigned type is already correct.
368 * This function only generates valid OpenMP code, if the ast was generated
369 * with the 'atomic-bounds' option enabled.
372 static __isl_give isl_printer *print_for_with_openmp(
373 __isl_keep isl_ast_node *node, __isl_take isl_printer *p,
374 __isl_take isl_ast_print_options *print_options)
376 p = isl_printer_start_line(p);
377 p = isl_printer_print_str(p, "#pragma omp parallel for");
378 p = isl_printer_end_line(p);
380 p = isl_ast_node_for_print(node, p, print_options);
382 return p;
385 /* Print a for node.
387 * Depending on how the node is annotated, we either print a normal
388 * for node or an openmp parallel for node.
390 static __isl_give isl_printer *print_for(__isl_take isl_printer *p,
391 __isl_take isl_ast_print_options *print_options,
392 __isl_keep isl_ast_node *node, void *user)
394 struct ppcg_print_info *print_info;
395 isl_id *id;
396 int openmp;
398 openmp = 0;
399 id = isl_ast_node_get_annotation(node);
401 if (id) {
402 struct ast_node_userinfo *info;
404 info = (struct ast_node_userinfo *) isl_id_get_user(id);
405 if (info && info->is_openmp)
406 openmp = 1;
409 if (openmp)
410 p = print_for_with_openmp(node, p, print_options);
411 else
412 p = isl_ast_node_for_print(node, p, print_options);
414 isl_id_free(id);
416 return p;
419 /* Call "fn" on each access expression in "expr".
421 static int foreach_access_expr(struct pet_expr *expr,
422 int (*fn)(struct pet_expr *expr, void *user), void *user)
424 int i;
426 if (!expr)
427 return -1;
429 if (expr->type == pet_expr_access)
430 return fn(expr, user);
432 for (i = 0; i < expr->n_arg; ++i)
433 if (foreach_access_expr(expr->args[i], fn, user) < 0)
434 return -1;
436 return 0;
439 static int inc_n_access(struct pet_expr *expr, void *user)
441 struct ppcg_stmt *stmt = user;
442 stmt->n_access++;
443 return 0;
446 /* Internal data for add_access.
448 * "stmt" is the statement to which an access needs to be added.
449 * "build" is the current AST build.
450 * "map" maps the AST loop iterators to the iteration domain of the statement.
452 struct ppcg_add_access_data {
453 struct ppcg_stmt *stmt;
454 isl_ast_build *build;
455 isl_map *map;
458 /* Given an access expression, add it to data->stmt after
459 * transforming it to refer to the AST loop iterators.
461 static int add_access(struct pet_expr *expr, void *user)
463 int i, n;
464 isl_ctx *ctx;
465 isl_map *access;
466 isl_pw_multi_aff *pma;
467 struct ppcg_add_access_data *data = user;
468 isl_ast_expr_list *index;
470 ctx = isl_map_get_ctx(expr->acc.access);
471 n = isl_map_dim(expr->acc.access, isl_dim_out);
472 access = isl_map_copy(expr->acc.access);
473 access = isl_map_apply_range(isl_map_copy(data->map), access);
474 pma = isl_pw_multi_aff_from_map(access);
475 pma = isl_pw_multi_aff_coalesce(pma);
477 index = isl_ast_expr_list_alloc(ctx, n);
478 for (i = 0; i < n; ++i) {
479 isl_pw_aff *pa;
480 isl_ast_expr *expr;
482 pa = isl_pw_multi_aff_get_pw_aff(pma, i);
483 expr = isl_ast_build_expr_from_pw_aff(data->build, pa);
484 index = isl_ast_expr_list_add(index, expr);
486 isl_pw_multi_aff_free(pma);
488 data->stmt->access[data->stmt->n_access] = index;
489 data->stmt->n_access++;
490 return 0;
493 /* Transform the accesses in the statement associated to the domain
494 * called by "node" to refer to the AST loop iterators,
495 * collect them in a ppcg_stmt and annotate the node with the ppcg_stmt.
497 static __isl_give isl_ast_node *at_each_domain(__isl_take isl_ast_node *node,
498 __isl_keep isl_ast_build *build, void *user)
500 struct ppcg_scop *scop = user;
501 isl_ast_expr *expr, *arg;
502 isl_ctx *ctx;
503 isl_id *id;
504 isl_map *map;
505 struct ppcg_stmt *stmt;
506 struct ppcg_add_access_data data;
508 ctx = isl_ast_node_get_ctx(node);
509 stmt = isl_calloc_type(ctx, struct ppcg_stmt);
510 if (!stmt)
511 goto error;
513 expr = isl_ast_node_user_get_expr(node);
514 arg = isl_ast_expr_get_op_arg(expr, 0);
515 isl_ast_expr_free(expr);
516 id = isl_ast_expr_get_id(arg);
517 isl_ast_expr_free(arg);
518 stmt->stmt = find_stmt(scop, id);
519 isl_id_free(id);
520 if (!stmt->stmt)
521 goto error;
523 stmt->n_access = 0;
524 if (foreach_access_expr(stmt->stmt->body, &inc_n_access, stmt) < 0)
525 goto error;
527 stmt->access = isl_calloc_array(ctx, isl_ast_expr_list *,
528 stmt->n_access);
529 if (!stmt->access)
530 goto error;
532 map = isl_map_from_union_map(isl_ast_build_get_schedule(build));
533 map = isl_map_reverse(map);
535 stmt->n_access = 0;
536 data.stmt = stmt;
537 data.build = build;
538 data.map = map;
539 if (foreach_access_expr(stmt->stmt->body, &add_access, &data) < 0)
540 node = isl_ast_node_free(node);
542 isl_map_free(map);
544 id = isl_id_alloc(isl_ast_node_get_ctx(node), NULL, stmt);
545 id = isl_id_set_free_user(id, &ppcg_stmt_free);
546 return isl_ast_node_set_annotation(node, id);
547 error:
548 ppcg_stmt_free(stmt);
549 return isl_ast_node_free(node);
552 /* Code generate the scop 'scop' and print the corresponding C code to 'p'.
554 static __isl_give isl_printer *print_scop(isl_ctx *ctx, struct ppcg_scop *scop,
555 __isl_take isl_printer *p, struct ppcg_options *options)
557 isl_set *context;
558 isl_union_set *domain_set;
559 isl_union_map *schedule_map;
560 isl_ast_build *build;
561 isl_ast_print_options *print_options;
562 isl_ast_node *tree;
563 struct ast_build_userinfo build_info;
565 context = isl_set_copy(scop->context);
566 domain_set = isl_union_set_copy(scop->domain);
567 schedule_map = isl_union_map_copy(scop->schedule);
568 schedule_map = isl_union_map_intersect_domain(schedule_map, domain_set);
570 build = isl_ast_build_from_context(context);
571 build = isl_ast_build_set_at_each_domain(build, &at_each_domain, scop);
573 if (options->openmp) {
574 build_info.scop = scop;
575 build_info.in_parallel_for = 0;
577 build = isl_ast_build_set_before_each_for(build,
578 &ast_build_before_for,
579 &build_info);
580 build = isl_ast_build_set_after_each_for(build,
581 &ast_build_after_for,
582 &build_info);
585 tree = isl_ast_build_ast_from_schedule(build, schedule_map);
586 isl_ast_build_free(build);
588 print_options = isl_ast_print_options_alloc(ctx);
589 print_options = isl_ast_print_options_set_print_user(print_options,
590 &print_user, NULL);
592 print_options = isl_ast_print_options_set_print_for(print_options,
593 &print_for, NULL);
595 p = isl_ast_node_print_macros(tree, p);
596 p = isl_ast_node_print(tree, p, print_options);
598 isl_ast_node_free(tree);
600 return p;
603 /* Does "scop" refer to any arrays that are declared, but not
604 * exposed to the code after the scop?
606 static int any_hidden_declarations(struct ppcg_scop *scop)
608 int i;
610 if (!scop)
611 return 0;
613 for (i = 0; i < scop->n_array; ++i)
614 if (scop->arrays[i]->declared && !scop->arrays[i]->exposed)
615 return 1;
617 return 0;
620 int generate_cpu(isl_ctx *ctx, struct ppcg_scop *ps,
621 struct ppcg_options *options, const char *input, const char *output)
623 FILE *input_file;
624 FILE *output_file;
625 isl_printer *p;
626 int hidden;
628 if (!ps)
629 return -1;
631 input_file = fopen(input, "r");
632 output_file = get_output_file(input, output);
634 copy(input_file, output_file, 0, ps->start);
635 fprintf(output_file, "/* ppcg generated CPU code */\n\n");
636 p = isl_printer_to_file(ctx, output_file);
637 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
638 p = ppcg_print_exposed_declarations(p, ps);
639 hidden = any_hidden_declarations(ps);
640 if (hidden) {
641 p = ppcg_start_block(p);
642 p = ppcg_print_hidden_declarations(p, ps);
644 p = print_scop(ctx, ps, p, options);
645 if (hidden)
646 p = ppcg_end_block(p);
647 isl_printer_free(p);
648 copy(input_file, output_file, ps->end, -1);
650 fclose(output_file);
651 fclose(input_file);
653 return 0;