cuda.c: localize_bounds: don't apply destructive operation on "keep" set
[ppcg.git] / clast_printer.c
blob64daf763a6c3dd38b46aa217c2a45fe9f6c9511c
1 /*
2 * Copyright 2010-2011 INRIA Saclay
4 * Use of this software is governed by the GNU LGPLv2.1 license
6 * Written by Sven Verdoolaege, INRIA Saclay - Ile-de-France,
7 * Parc Club Orsay Universite, ZAC des vignes, 4 rue Jacques Monod,
8 * 91893 Orsay, France
9 */
11 #include "clast_printer.h"
/* Emit the helper macros that generated code relies on:
 * floord/ceild (rounding integer division, as printed by print_bin for
 * clast_bin_fdiv/clast_bin_cdiv) and max/min (as printed by print_red).
 * Each macro definition is written on its own line.
 */
void print_cloog_macros(FILE *dst)
{
	static const char *const macro_defs[] = {
		"#define floord(n,d) "
			"(((n)<0) ? -((-(n)+(d)-1)/(d)) : (n)/(d))\n",
		"#define ceild(n,d) "
			"(((n)<0) ? -((-(n))/(d)) : ((n)+(d)-1)/(d))\n",
		"#define max(x,y) "
			"((x) > (y) ? (x) : (y))\n",
		"#define min(x,y) "
			"((x) < (y) ? (x) : (y))\n",
	};
	size_t k;

	for (k = 0; k < sizeof(macro_defs) / sizeof(macro_defs[0]); ++k)
		fputs(macro_defs[k], dst);
}
/* Forward declarations: the statement printers below call print_stmt
 * recursively (e.g. for guard and loop bodies), and the expression
 * printers call print_expr recursively (e.g. for term variables and
 * reduction elements), before either function is defined.
 */
static void print_stmt(struct clast_printer_info *info, struct clast_stmt *s);
static void print_expr(struct clast_expr *e, FILE *dst);
/* Write `indent` spaces to dst.
 *
 * This works by printing an empty string padded to a field width of
 * `indent`.  Per the C standard, a negative width passed via `*` is
 * treated as a left-justified field of |indent|, so |indent| spaces are
 * written in that case.
 */
void print_indent(FILE *dst, int indent)
{
	(void) fprintf(dst, "%*s", indent, "");
}
33 static void print_name(struct clast_name *n, FILE *dst)
35 fprintf(dst, "%s", n->name);
38 static void print_term(struct clast_term *t, FILE *dst)
40 if (!t->var) {
41 cloog_int_print(dst, t->val);
42 } else {
43 if (!cloog_int_is_one(t->val)) {
44 cloog_int_print(dst, t->val);
45 fprintf(dst, "*");
47 if (t->var->type == clast_expr_red)
48 fprintf(dst, "(");
49 print_expr(t->var, dst);
50 if (t->var->type == clast_expr_red)
51 fprintf(dst, ")");
55 static void print_bin(struct clast_binary *b, FILE *dst)
57 const char *s1, *s2, *s3;
58 switch (b->type) {
59 case clast_bin_mod:
60 s1 = "(", s2 = ")%", s3 = "";
61 break;
62 case clast_bin_div:
63 s1 = "(", s2 = ")/(", s3 = ")";
64 break;
65 case clast_bin_cdiv:
66 s1 = "ceild(", s2 = ", ", s3 = ")";
67 break;
68 case clast_bin_fdiv:
69 s1 = "floord(", s2 = ", ", s3 = ")";
70 break;
71 default:
72 assert(0);
74 fprintf(dst, "%s", s1);
75 print_expr(b->LHS, dst);
76 fprintf(dst, "%s", s2);
77 cloog_int_print(dst, b->RHS);
78 fprintf(dst, "%s", s3);
81 static void print_red(struct clast_reduction *r, FILE *dst)
83 int i;
84 const char *s1, *s2, *s3;
86 if (r->n == 1) {
87 print_expr(r->elts[0], dst);
88 return;
91 switch (r->type) {
92 case clast_red_sum:
93 s1 = "", s2 = " + ", s3 = "";
94 break;
95 case clast_red_max:
96 s1 = "max(", s2 = ", ", s3 = ")";
97 break;
98 case clast_red_min:
99 s1 = "min(", s2 = ", ", s3 = ")";
100 break;
101 default:
102 assert(0);
105 for (i = 1; i < r->n; ++i)
106 fprintf(dst, "%s", s1);
107 print_expr(r->elts[0], dst);
108 for (i = 1; i < r->n; ++i) {
109 if (r->type == clast_red_sum &&
110 r->elts[i]->type == clast_expr_term &&
111 cloog_int_is_neg(((struct clast_term *) r->elts[i])->val)) {
112 struct clast_term *t = (struct clast_term *) r->elts[i];
113 cloog_int_neg(t->val, t->val);
114 fprintf(dst, " - ");
115 print_expr(r->elts[i], dst);
116 cloog_int_neg(t->val, t->val);
117 } else {
118 fprintf(dst, "%s", s2);
119 print_expr(r->elts[i], dst);
121 fprintf(dst, "%s", s3);
125 static void print_expr(struct clast_expr *e, FILE *dst)
127 switch (e->type) {
128 case clast_expr_name:
129 print_name((struct clast_name*) e, dst);
130 break;
131 case clast_expr_term:
132 print_term((struct clast_term*) e, dst);
133 break;
134 case clast_expr_red:
135 print_red((struct clast_reduction*) e, dst);
136 break;
137 case clast_expr_bin:
138 print_bin((struct clast_binary*) e, dst);
139 break;
140 default:
141 assert(0);
145 static void print_ass(struct clast_assignment *a, FILE *dst, int indent,
146 int first_ass)
148 print_indent(dst, indent);
149 if (first_ass)
150 fprintf(dst, "int ");
151 fprintf(dst, "%s = ", a->LHS);
152 print_expr(a->RHS, dst);
153 fprintf(dst, ";\n");
156 static void print_guard(struct clast_printer_info *info, struct clast_guard *g)
158 int i;
159 int n = g->n;
161 print_indent(info->dst, info->indent);
162 fprintf(info->dst, "if (");
163 for (i = 0; i < n; ++i) {
164 if (i > 0)
165 fprintf(info->dst," && ");
166 if (n > 1)
167 fprintf(info->dst,"(");
168 print_expr(g->eq[i].LHS, info->dst);
169 if (g->eq[i].sign == 0)
170 fprintf(info->dst," == ");
171 else if (g->eq[i].sign > 0)
172 fprintf(info->dst," >= ");
173 else
174 fprintf(info->dst," <= ");
175 print_expr(g->eq[i].RHS, info->dst);
176 if (n > 1)
177 fprintf(info->dst,")");
179 fprintf(info->dst, ") {\n");
180 info->indent += 4;
181 print_stmt(info, g->then);
182 info->indent -= 4;
183 print_indent(info->dst, info->indent);
184 fprintf(info->dst, "}\n");
187 static void print_for(struct clast_printer_info *info, struct clast_for *f)
189 assert(f->LB && f->UB);
190 print_indent(info->dst, info->indent);
191 fprintf(info->dst, "for (int %s = ", f->iterator);
192 print_expr(f->LB, info->dst);
193 fprintf(info->dst, "; %s <= ", f->iterator);
194 print_expr(f->UB, info->dst);
195 fprintf(info->dst, "; %s", f->iterator);
196 if (cloog_int_is_one(f->stride))
197 fprintf(info->dst, "++");
198 else {
199 fprintf(info->dst, " += ");
200 cloog_int_print(info->dst, f->stride);
202 fprintf(info->dst, ") {\n");
203 info->indent += 4;
204 if (info->print_for_head)
205 info->print_for_head(info, f);
206 print_stmt(info, f->body);
207 if (info->print_for_foot)
208 info->print_for_foot(info, f);
209 info->indent -= 4;
210 print_indent(info->dst, info->indent);
211 fprintf(info->dst, "}\n");
214 static void print_user_stmt(struct clast_user_stmt *u, FILE *dst, int indent)
216 struct clast_stmt *t;
218 print_indent(dst, indent);
219 fprintf(dst, "%s", u->statement->name);
220 fprintf(dst, "(");
221 for (t = u->substitutions; t; t = t->next) {
222 assert(CLAST_STMT_IS_A(t, stmt_ass));
223 print_expr(((struct clast_assignment *) t)->RHS, dst);
224 if (t->next)
225 fprintf(dst, ",");
227 fprintf(dst, ");\n");
230 static void print_stmt(struct clast_printer_info *info, struct clast_stmt *s)
232 int first_ass = 1;
234 for ( ; s; s = s->next) {
235 if (CLAST_STMT_IS_A(s, stmt_root))
236 continue;
237 if (CLAST_STMT_IS_A(s, stmt_ass)) {
238 print_ass((struct clast_assignment *) s, info->dst,
239 info->indent, first_ass);
240 first_ass = 0;
241 } else if (CLAST_STMT_IS_A(s, stmt_user)) {
242 struct clast_user_stmt *user_stmt;
243 user_stmt = (struct clast_user_stmt *) s;
245 if (info->print_user_stmt_list) {
246 info->print_user_stmt_list(info, user_stmt);
247 return;
248 } else if (info->print_user_stmt)
249 info->print_user_stmt(info, user_stmt);
250 else
251 print_user_stmt(user_stmt, info->dst,
252 info->indent);
253 } else if (CLAST_STMT_IS_A(s, stmt_for)) {
254 print_for(info, (struct clast_for *) s);
255 } else if (CLAST_STMT_IS_A(s, stmt_guard)) {
256 print_guard(info, (struct clast_guard *) s);
257 } else {
258 assert(0);
/* Public entry point: print the clast statement list s using the output
 * stream, indentation and optional callbacks in info.  Root nodes in the
 * list are skipped by the statement printer, so the list returned by
 * CLooG can be passed directly.
 */
void print_clast(struct clast_printer_info *info, struct clast_stmt *s)
{
	print_stmt(info, s);
}
268 __isl_give isl_set *extract_host_domain(struct clast_user_stmt *u)
270 return isl_set_from_cloog_domain(cloog_domain_copy(u->domain));
273 /* Extract the set of scattering dimension values for which the given
274 * statement is executed, where the statement may be either a user statement
275 * or a guard containing a sequence of (possibly guarded) user statements.
277 static __isl_give isl_set *extract_nested_host_domain(struct clast_stmt *s)
279 if (CLAST_STMT_IS_A(s, stmt_user)) {
280 struct clast_user_stmt *u = (struct clast_user_stmt *) s;
281 return extract_host_domain(u);
282 } else if (CLAST_STMT_IS_A(s, stmt_guard)) {
283 struct clast_guard *g = (struct clast_guard *) s;
284 return extract_entire_host_domain(g->then);
285 } else
286 assert(0);
289 /* Extract the set of scattering dimension values for which the given
290 * sequence of user statements is executed.
291 * Some of the user statements in the sequence may be guarded
292 * so we return the union of this set over all user statements.
294 __isl_give isl_set *extract_entire_host_domain(struct clast_stmt *s)
296 isl_set *host_domain = NULL;
298 for (; s; s = s->next) {
299 isl_set *set_i;
301 set_i = extract_nested_host_domain(s);
303 if (!host_domain)
304 host_domain = set_i;
305 else
306 host_domain = isl_set_union(host_domain, set_i);
307 assert(host_domain);
310 return isl_set_coalesce(host_domain);