gpu.c: extract false dependence computation
[ppcg.git] / cpu.c
blob7b39e39aa98a9c07811f00316accfcf954c7fc07
1 /*
2 * Copyright 2012 INRIA Paris-Rocquencourt
4 * Use of this software is governed by the GNU LGPLv2.1 license
6 * Written by Tobias Grosser, INRIA Paris-Rocquencourt,
8 * Domaine de Voluceau, Rocquencourt, B.P. 105,
8 * 78153 Le Chesnay Cedex France
9 */
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <isl/aff.h>
#include <isl/ctx.h>
#include <isl/map.h>
#include <isl/ast_build.h>
#include <pet.h>

#include "ppcg.h"
#include "cpu.h"
#include "pet_printer.h"
#include "print.h"
#include "rewrite.h"
26 /* Representation of a statement inside a generated AST.
28 * "stmt" refers to the original statement.
29 * "n_access" is the number of accesses in the statement.
30 * "access" is the list of accesses transformed to refer to the iterators
31 * in the generated AST.
33 struct ppcg_stmt {
34 struct pet_stmt *stmt;
36 int n_access;
37 isl_ast_expr_list **access;
40 static void ppcg_stmt_free(void *user)
42 struct ppcg_stmt *stmt = user;
43 int i;
45 if (!stmt)
46 return;
48 for (i = 0; i < stmt->n_access; ++i)
49 isl_ast_expr_list_free(stmt->access[i]);
51 free(stmt->access);
52 free(stmt);
/* Open the file to which the generated code should be written.
 *
 * If "output" is not NULL, it is used as the name of the output file.
 * Otherwise, the name is derived from "input", the path of the input file.
 * The directory part is dropped and ".ppcg" is inserted in front of the
 * extension (everything from the last '.' of the base name onwards).
 * For example, "dir/file.c" becomes "file.ppcg.c", created in the current
 * directory. A base name without an extension simply gets ".ppcg" appended.
 *
 * The derived name is built in a buffer of exactly the required size,
 * so arbitrarily long input names cannot overflow it.
 *
 * Return the file opened for writing, or NULL if "input" is NULL,
 * memory allocation fails or the file cannot be opened.
 */
static FILE *get_output_file(const char *input, const char *output)
{
	static const char ppcg_marker[] = ".ppcg";
	const size_t marker_len = sizeof(ppcg_marker) - 1;
	const char *base;
	const char *ext;
	size_t len;
	size_t ext_len;
	char *name;
	FILE *file;

	if (output)
		return fopen(output, "w");
	if (!input)
		return NULL;

	base = strrchr(input, '/');
	base = base ? base + 1 : input;
	ext = strrchr(base, '.');
	len = ext ? (size_t) (ext - base) : strlen(base);
	/* No extension: append the marker to the whole base name. */
	if (!ext)
		ext = "";
	ext_len = strlen(ext);

	name = malloc(len + marker_len + ext_len + 1);
	if (!name)
		return NULL;
	memcpy(name, base, len);
	memcpy(name + len, ppcg_marker, marker_len);
	/* Copy the extension including its terminating NUL. */
	memcpy(name + len + marker_len, ext, ext_len + 1);

	file = fopen(name, "w");
	free(name);

	return file;
}
88 /* Print a memory access 'access' to the printer 'p'.
90 * "expr" refers to the original access.
91 * "access" is the list of index expressions transformed to refer
92 * to the iterators of the generated AST.
94 * In case the original access is unnamed (and presumably single-dimensional),
95 * we assume this is not a memory access, but just an expression.
97 static __isl_give isl_printer *print_access(__isl_take isl_printer *p,
98 struct pet_expr *expr, __isl_keep isl_ast_expr_list *access)
100 int i;
101 const char *name;
102 unsigned n_index;
104 n_index = isl_ast_expr_list_n_ast_expr(access);
105 name = isl_map_get_tuple_name(expr->acc.access, isl_dim_out);
107 if (name == NULL) {
108 isl_ast_expr *index;
109 index = isl_ast_expr_list_get_ast_expr(access, 0);
110 p = isl_printer_print_str(p, "(");
111 p = isl_printer_print_ast_expr(p, index);
112 p = isl_printer_print_str(p, ")");
113 isl_ast_expr_free(index);
114 return p;
117 p = isl_printer_print_str(p, name);
119 for (i = 0; i < n_index; ++i) {
120 isl_ast_expr *index;
122 index = isl_ast_expr_list_get_ast_expr(access, i);
124 p = isl_printer_print_str(p, "[");
125 p = isl_printer_print_ast_expr(p, index);
126 p = isl_printer_print_str(p, "]");
127 isl_ast_expr_free(index);
130 return p;
133 /* Find the element in scop->stmts that has the given "id".
135 static struct pet_stmt *find_stmt(struct ppcg_scop *scop, __isl_keep isl_id *id)
137 int i;
139 for (i = 0; i < scop->n_stmt; ++i) {
140 struct pet_stmt *stmt = scop->stmts[i];
141 isl_id *id_i;
143 id_i = isl_set_get_tuple_id(stmt->domain);
144 isl_id_free(id_i);
146 if (id_i == id)
147 return stmt;
150 isl_die(isl_id_get_ctx(id), isl_error_internal,
151 "statement not found", return NULL);
154 /* To print the transformed accesses we walk the list of transformed accesses
155 * simultaneously with the pet printer. This means that whenever
156 * the pet printer prints a pet access expression we have
157 * the corresponding transformed access available for printing.
159 static __isl_give isl_printer *print_access_expr(__isl_take isl_printer *p,
160 struct pet_expr *expr, void *user)
162 isl_ast_expr_list ***access = user;
164 p = print_access(p, expr, **access);
165 (*access)++;
167 return p;
170 /* Print a user statement in the generated AST.
171 * The ppcg_stmt has been attached to the node in at_each_domain.
173 static __isl_give isl_printer *print_user(__isl_take isl_printer *p,
174 __isl_take isl_ast_print_options *print_options,
175 __isl_keep isl_ast_node *node, void *user)
177 struct ppcg_stmt *stmt;
178 isl_ast_expr_list **access;
179 isl_id *id;
181 id = isl_ast_node_get_annotation(node);
182 stmt = isl_id_get_user(id);
183 isl_id_free(id);
185 access = stmt->access;
187 p = isl_printer_start_line(p);
188 p = print_pet_expr(p, stmt->stmt->body, &print_access_expr, &access);
189 p = isl_printer_print_str(p, ";");
190 p = isl_printer_end_line(p);
192 isl_ast_print_options_free(print_options);
194 return p;
197 /* Call "fn" on each access expression in "expr".
199 static int foreach_access_expr(struct pet_expr *expr,
200 int (*fn)(struct pet_expr *expr, void *user), void *user)
202 int i;
204 if (!expr)
205 return -1;
207 if (expr->type == pet_expr_access)
208 return fn(expr, user);
210 for (i = 0; i < expr->n_arg; ++i)
211 if (foreach_access_expr(expr->args[i], fn, user) < 0)
212 return -1;
214 return 0;
217 static int inc_n_access(struct pet_expr *expr, void *user)
219 struct ppcg_stmt *stmt = user;
220 stmt->n_access++;
221 return 0;
224 /* Internal data for add_access.
226 * "stmt" is the statement to which an access needs to be added.
227 * "build" is the current AST build.
228 * "map" maps the AST loop iterators to the iteration domain of the statement.
230 struct ppcg_add_access_data {
231 struct ppcg_stmt *stmt;
232 isl_ast_build *build;
233 isl_map *map;
236 /* Given an access expression, add it to data->stmt after
237 * transforming it to refer to the AST loop iterators.
239 static int add_access(struct pet_expr *expr, void *user)
241 int i, n;
242 isl_ctx *ctx;
243 isl_map *access;
244 isl_pw_multi_aff *pma;
245 struct ppcg_add_access_data *data = user;
246 isl_ast_expr_list *index;
248 ctx = isl_map_get_ctx(expr->acc.access);
249 n = isl_map_dim(expr->acc.access, isl_dim_out);
250 access = isl_map_copy(expr->acc.access);
251 access = isl_map_apply_range(isl_map_copy(data->map), access);
252 pma = isl_pw_multi_aff_from_map(access);
253 pma = isl_pw_multi_aff_coalesce(pma);
255 index = isl_ast_expr_list_alloc(ctx, n);
256 for (i = 0; i < n; ++i) {
257 isl_pw_aff *pa;
258 isl_ast_expr *expr;
260 pa = isl_pw_multi_aff_get_pw_aff(pma, i);
261 expr = isl_ast_build_expr_from_pw_aff(data->build, pa);
262 index = isl_ast_expr_list_add(index, expr);
264 isl_pw_multi_aff_free(pma);
266 data->stmt->access[data->stmt->n_access] = index;
267 data->stmt->n_access++;
268 return 0;
271 /* Transform the accesses in the statement associated to the domain
272 * called by "node" to refer to the AST loop iterators,
273 * collect them in a ppcg_stmt and annotate the node with the ppcg_stmt.
275 static __isl_give isl_ast_node *at_each_domain(__isl_take isl_ast_node *node,
276 __isl_keep isl_ast_build *build, void *user)
278 struct ppcg_scop *scop = user;
279 isl_ast_expr *expr, *arg;
280 isl_ctx *ctx;
281 isl_id *id;
282 isl_map *map;
283 struct ppcg_stmt *stmt;
284 struct ppcg_add_access_data data;
286 ctx = isl_ast_node_get_ctx(node);
287 stmt = isl_calloc_type(ctx, struct ppcg_stmt);
288 if (!stmt)
289 goto error;
291 expr = isl_ast_node_user_get_expr(node);
292 arg = isl_ast_expr_get_op_arg(expr, 0);
293 isl_ast_expr_free(expr);
294 id = isl_ast_expr_get_id(arg);
295 isl_ast_expr_free(arg);
296 stmt->stmt = find_stmt(scop, id);
297 isl_id_free(id);
298 if (!stmt->stmt)
299 goto error;
301 stmt->n_access = 0;
302 if (foreach_access_expr(stmt->stmt->body, &inc_n_access, stmt) < 0)
303 goto error;
305 stmt->access = isl_calloc_array(ctx, isl_ast_expr_list *,
306 stmt->n_access);
307 if (!stmt->access)
308 goto error;
310 map = isl_map_from_union_map(isl_ast_build_get_schedule(build));
311 map = isl_map_reverse(map);
313 stmt->n_access = 0;
314 data.stmt = stmt;
315 data.build = build;
316 data.map = map;
317 if (foreach_access_expr(stmt->stmt->body, &add_access, &data) < 0)
318 node = isl_ast_node_free(node);
320 isl_map_free(map);
322 id = isl_id_alloc(isl_ast_node_get_ctx(node), NULL, stmt);
323 id = isl_id_set_free_user(id, &ppcg_stmt_free);
324 return isl_ast_node_set_annotation(node, id);
325 error:
326 ppcg_stmt_free(stmt);
327 return isl_ast_node_free(node);
330 /* Code generate the scop 'scop' and print the corresponding C code to 'p'.
332 static __isl_give isl_printer *print_scop(isl_ctx *ctx, struct ppcg_scop *scop,
333 __isl_take isl_printer *p)
335 isl_set *context;
336 isl_union_set *domain_set;
337 isl_union_map *schedule_map;
338 isl_ast_build *build;
339 isl_ast_print_options *print_options;
340 isl_ast_node *tree;
342 context = isl_set_copy(scop->context);
343 domain_set = isl_union_set_copy(scop->domain);
344 schedule_map = isl_union_map_copy(scop->schedule);
345 schedule_map = isl_union_map_intersect_domain(schedule_map, domain_set);
347 build = isl_ast_build_from_context(context);
348 build = isl_ast_build_set_at_each_domain(build, &at_each_domain, scop);
349 tree = isl_ast_build_ast_from_schedule(build, schedule_map);
350 isl_ast_build_free(build);
352 print_options = isl_ast_print_options_alloc(ctx);
353 print_options = isl_ast_print_options_set_print_user(print_options,
354 &print_user, NULL);
356 p = isl_ast_node_print_macros(tree, p);
357 p = isl_ast_node_print(tree, p, print_options);
359 isl_ast_node_free(tree);
361 return p;
364 /* Does "scop" refer to any arrays that are declared, but not
365 * exposed to the code after the scop?
367 static int any_hidden_declarations(struct ppcg_scop *scop)
369 int i;
371 if (!scop)
372 return 0;
374 for (i = 0; i < scop->n_array; ++i)
375 if (scop->arrays[i]->declared && !scop->arrays[i]->exposed)
376 return 1;
378 return 0;
381 int generate_cpu(isl_ctx *ctx, struct ppcg_scop *ps,
382 struct ppcg_options *options, const char *input, const char *output)
384 FILE *input_file;
385 FILE *output_file;
386 isl_printer *p;
387 int hidden;
389 if (!ps)
390 return -1;
392 input_file = fopen(input, "r");
393 output_file = get_output_file(input, output);
395 copy_before_scop(input_file, output_file);
396 fprintf(output_file, "/* ppcg generated CPU code */\n\n");
397 p = isl_printer_to_file(ctx, output_file);
398 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
399 p = ppcg_print_exposed_declarations(p, ps);
400 hidden = any_hidden_declarations(ps);
401 if (hidden) {
402 p = ppcg_start_block(p);
403 p = ppcg_print_hidden_declarations(p, ps);
405 p = print_scop(ctx, ps, p);
406 if (hidden)
407 p = ppcg_end_block(p);
408 isl_printer_free(p);
409 copy_after_scop(input_file, output_file);
411 fclose(output_file);
412 fclose(input_file);
414 return 0;