update isl for change in lexicographic optimization
[isa.git] / mem.cc
blob a42b71ec707e89af54695f3e6ff21d20f715905d
#include <stdio.h>
#include <vector>

#include <isl/ctx.h>
#include <isl/id.h>
#include <isl/space.h>
#include <isl/aff.h>
#include <isl/map.h>
#include <isl/union_map.h>
#include <isl/polynomial.h>
#include <isl/ast.h>
#include <isl/ast_build.h>
#include <isl/printer.h>
#include <barvinok/isl.h>
#include <isa/yaml.h>
#include <isa/pdg.h>
#include "da.h"
#include "mem_options.h"
/* Allocate (uninitialized) storage for a single object of type "type".
 * The whole expansion is parenthesized so that the macro can be used
 * directly as an operand, e.g., ALLOC(int)[0] or ALLOC(T)->field.
 */
#define ALLOC(type) ((type *) malloc(sizeof(type)))
19 using pdg::PDG;
20 using namespace da;
22 static void dump_program(PDG *pdg, int *output, struct options *options);
24 int main(int argc, char * argv[])
26 PDG *pdg;
27 isl_ctx *ctx;
28 struct options *options = options_new_with_defaults();
30 argc = options_parse(options, argc, argv, ISL_ARG_ALL);
32 ctx = isl_ctx_alloc_with_options(&options_args, options);
34 pdg = PDG::Load(stdin, ctx);
35 int output[pdg->arrays.size()];
37 for (int i = 0; i < pdg->arrays.size(); ++i) {
38 int n = pdg->dependences.size();
39 find_deps(pdg, pdg->arrays[i], data_reuse);
40 output[i] = n == pdg->dependences.size();
43 dump_program(pdg, output, options);
44 pdg->free();
45 delete pdg;
46 isl_ctx_free(ctx);
48 return 0;
51 static __isl_give isl_map *schedule(__isl_take isl_set *domain,
52 pdg::node *node, pdg::access *ac, int dep, int maxdim, int write)
54 int fulldim = node->prefix.size() + 1;
55 isl_map *schedule;
56 int dim;
58 schedule = isl_map_from_domain(domain);
59 schedule = isl_map_add_dims(schedule, isl_dim_out, 2);
60 schedule = isl_map_fix_si(schedule, isl_dim_out, 0, ac->nr);
61 schedule = isl_map_fix_si(schedule, isl_dim_out, 1, dep);
62 schedule = isl_set_identity(isl_map_wrap(schedule));
63 schedule = isl_map_flatten_range(schedule);
65 for (int i = 0; i < node->prefix.size(); ++i) {
66 if (node->prefix[i] == -1)
67 continue;
68 schedule = isl_map_insert_dims(schedule, isl_dim_out, i, 1);
69 schedule = isl_map_fix_si(schedule, isl_dim_out, i, node->prefix[i]);
72 dim = isl_map_dim(schedule, isl_dim_out);
73 schedule = isl_map_project_out(schedule, isl_dim_out, dim - 1, 1);
74 dim--;
75 schedule = isl_map_add_dims(schedule, isl_dim_out, maxdim - dim);
76 for (int i = dim; i < maxdim; ++i)
77 schedule = isl_map_fix_si(schedule, isl_dim_out, i, 0);
79 if (write)
80 schedule = isl_map_set_tuple_name(schedule, isl_dim_in, "write");
81 else
82 schedule = isl_map_set_tuple_name(schedule, isl_dim_in, "read");
84 return schedule;
87 /* Compute the number of array elements that are read but not
88 * (previously) written. These are assumed to be the input.
90 static __isl_give isl_pw_qpolynomial *compute_input_size(PDG *pdg,
91 struct options *options)
93 unsigned nparam = pdg->params.size();
94 isl_space *space;
95 isl_pw_qpolynomial *inputsize;
97 space = isl_space_set_alloc(pdg->get_isl_ctx(), nparam, 1);
98 isl_dim_set_parameter_names(space, pdg->params);
99 inputsize = isl_pw_qpolynomial_zero(space);
101 if (!options->with_input)
102 return inputsize;
104 for (int i = 0; i < pdg->dependences.size(); ++i) {
105 isl_map *map;
106 isl_pw_qpolynomial *size;
107 pdg::dependence *dep = pdg->dependences[i];
109 if (dep->type != pdg::dependence::uninitialized)
110 continue;
112 map = dep->relation->get_isl_map(pdg->get_isl_ctx());
113 size = isl_set_card(isl_map_range(map));
114 inputsize = isl_pw_qpolynomial_add(inputsize, size);
117 return inputsize;
120 static __isl_give isl_printer *print_user(__isl_take isl_printer *p,
121 __isl_take isl_ast_print_options *print_options,
122 __isl_keep isl_ast_node *node, void *user)
124 isl_id *id;
125 isl_ast_expr *expr, *arg;
126 const char *name;
128 expr = isl_ast_node_user_get_expr(node);
129 arg = isl_ast_expr_get_op_arg(expr, 0);
130 id = isl_ast_expr_get_id(arg);
131 name = isl_id_get_name(id);
132 if (!strcmp(name, "write")) {
133 p = isl_printer_start_line(p);
134 p = isl_printer_print_str(p, "if (++count > max)");
135 p = isl_printer_end_line(p);
136 p = isl_printer_indent(p, 4);
137 p = isl_printer_start_line(p);
138 p = isl_printer_print_str(p, "max = count;");
139 p = isl_printer_end_line(p);
140 p = isl_printer_indent(p, -4);
141 } else {
142 p = isl_printer_start_line(p);
143 p = isl_printer_print_str(p, "--count;");
144 p = isl_printer_end_line(p);
146 isl_id_free(id);
147 isl_ast_expr_free(arg);
148 isl_ast_expr_free(expr);
149 isl_ast_print_options_free(print_options);
150 return p;
153 static void print(__isl_keep isl_ast_node *tree, const char *inputsize)
155 isl_ctx *ctx;
156 FILE *out = stdout;
157 isl_printer *p;
158 isl_ast_print_options *print_options;
160 ctx = isl_ast_node_get_ctx(tree);
162 p = isl_printer_to_file(ctx, out);
163 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
165 p = isl_printer_start_line(p);
166 p = isl_printer_print_str(p, "#include <stdio.h>");
167 p = isl_printer_end_line(p);
169 p = isl_ast_node_print_macros(tree, p);
171 p = isl_printer_start_line(p);
172 p = isl_printer_print_str(p, "int main() {");
173 p = isl_printer_end_line(p);
175 p = isl_printer_indent(p, 4);
177 p = isl_printer_start_line(p);
178 p = isl_printer_print_str(p, "long count = ");
179 p = isl_printer_print_str(p, inputsize);
180 p = isl_printer_print_str(p, ";");
181 p = isl_printer_end_line(p);
183 p = isl_printer_start_line(p);
184 p = isl_printer_print_str(p, "long max = count;");
185 p = isl_printer_end_line(p);
187 print_options = isl_ast_print_options_alloc(ctx);
188 print_options = isl_ast_print_options_set_print_user(print_options,
189 &print_user, NULL);
190 p = isl_ast_node_print(tree, p, print_options);
192 p = isl_printer_start_line(p);
193 p = isl_printer_print_str(p, "printf(\"final: %ld\\n\", count);");
194 p = isl_printer_end_line(p);
196 p = isl_printer_start_line(p);
197 p = isl_printer_print_str(p, "printf(\"max: %ld\\n\", max);");
198 p = isl_printer_end_line(p);
200 p = isl_printer_start_line(p);
201 p = isl_printer_print_str(p, "return 0;");
202 p = isl_printer_end_line(p);
204 p = isl_printer_indent(p, -4);
206 p = isl_printer_start_line(p);
207 p = isl_printer_print_str(p, "}");
208 p = isl_printer_end_line(p);
210 isl_printer_free(p);
213 static void dump_program(PDG *pdg, int *output, struct options *options)
215 int nparam = pdg->params.size();
216 isl_ctx *ctx = pdg->get_isl_ctx();
217 isl_set *context = pdg->get_context_isl_set();
218 isl_map *sched_i;
219 isl_union_map *sched;
220 isl_ast_build *build;
221 isl_ast_node *tree;
222 int maxdim = 0;
223 isl_pw_qpolynomial *inputsize;
224 isl_printer *p;
225 char *s_inputsize;
227 inputsize = compute_input_size(pdg, options);
228 inputsize = isl_pw_qpolynomial_gist(inputsize, isl_set_copy(context));
229 p = isl_printer_to_str(pdg->get_isl_ctx());
230 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
231 p = isl_printer_print_pw_qpolynomial(p, inputsize);
232 s_inputsize = isl_printer_get_str(p);
233 isl_printer_free(p);
234 isl_pw_qpolynomial_free(inputsize);
236 sched = isl_union_map_empty(isl_set_get_space(context));
238 for (int i = 0; i < pdg->nodes.size(); ++i) {
239 pdg::node *node = pdg->nodes[i];
240 if (node->prefix.size() + 1 > maxdim)
241 maxdim = node->prefix.size() + 1;
244 for (int i = 0; i < pdg->nodes.size(); ++i) {
245 pdg::node *node = pdg->nodes[i];
246 pdg::statement *s = pdg->nodes[i]->statement;
248 for (int j = 0; j < pdg->dependences.size(); ++j) {
249 pdg::dependence *dep = pdg->dependences[j];
250 if (dep->to != pdg->nodes[i] && dep->from != pdg->nodes[i])
251 continue;
252 if (!options->with_input &&
253 dep->type == pdg::dependence::uninitialized)
254 continue;
255 isl_map *rel = dep->relation->get_isl_map(pdg->get_isl_ctx());
256 if (dep->to == pdg->nodes[i]) {
257 isl_set *read = isl_map_range(isl_map_copy(rel));
258 sched_i = schedule(read, dep->to, dep->to_access, j,
259 maxdim, 0);
260 sched = isl_union_map_add_map(sched, sched_i);
262 if (dep->from == pdg->nodes[i]) {
263 isl_set *write = isl_map_domain(isl_map_copy(rel));
264 sched_i = schedule(write, dep->from, dep->from_access, j,
265 maxdim, 1);
266 sched = isl_union_map_add_map(sched, sched_i);
268 isl_map_free(rel);
270 for (int j = 0; j < s->accesses.size(); ++j) {
271 pdg::access *access = s->accesses[j];
272 if (access->type == pdg::access::read)
273 continue;
274 int k;
275 for (k = 0; k < pdg->arrays.size(); ++k)
276 if (pdg->arrays[k] == access->array)
277 break;
278 if (!output[k])
279 continue;
280 isl_set *source = node->source->get_isl_set(pdg->get_isl_ctx());
281 sched_i = schedule(source, node, access, -(j + 1), maxdim, 1);
282 sched = isl_union_map_add_map(sched, sched_i);
286 build = isl_ast_build_from_context(context);
287 tree = isl_ast_build_ast_from_schedule(build, sched);
288 isl_ast_build_free(build);
290 print(tree, s_inputsize);
292 isl_ast_node_free(tree);
294 free(s_inputsize);