cuda.c: stmt_print_global_index: add extra parentheses around index expression
[ppcg.git] / cuda.c
blobb23e8404bca31f0e5112fb8adb4fba814f8cd59f
1 /*
2 * Copyright 2012 Ecole Normale Superieure
4 * Use of this software is governed by the GNU LGPLv2.1 license
6 * Written by Sven Verdoolaege,
7 * Ecole Normale Superieure, 45 rue d’Ulm, 75230 Paris, France
8 */
10 #include <isl/aff.h>
11 #include <isl/ast.h>
13 #include "cuda_common.h"
14 #include "cuda.h"
15 #include "gpu.h"
16 #include "pet_printer.h"
17 #include "print.h"
18 #include "schedule.h"
20 static __isl_give isl_printer *print_cuda_macros(__isl_take isl_printer *p)
22 const char *macros =
23 "#define cudaCheckReturn(ret) \\\n"
24 " do { \\\n"
25 " cudaError_t cudaCheckReturn_e = (ret); \\\n"
26 " if (cudaCheckReturn_e != cudaSuccess) { \\\n"
27 " fprintf(stderr, \"CUDA error: %s\\n\", "
28 "cudaGetErrorString(cudaCheckReturn_e)); \\\n"
29 " fflush(stderr); \\\n"
30 " } \\\n"
31 " assert(cudaCheckReturn_e == cudaSuccess); \\\n"
32 " } while(0)\n"
33 "#define cudaCheckKernel() \\\n"
34 " do { \\\n"
35 " cudaCheckReturn(cudaGetLastError()); \\\n"
36 " } while(0)\n\n";
38 p = isl_printer_print_str(p, macros);
39 return p;
42 static __isl_give isl_printer *print_array_size(__isl_take isl_printer *prn,
43 struct gpu_array_info *array)
45 int i;
47 for (i = 0; i < array->n_index; ++i) {
48 prn = isl_printer_print_str(prn, "(");
49 prn = isl_printer_print_pw_aff(prn, array->bound[i]);
50 prn = isl_printer_print_str(prn, ") * ");
52 prn = isl_printer_print_str(prn, "sizeof(");
53 prn = isl_printer_print_str(prn, array->type);
54 prn = isl_printer_print_str(prn, ")");
56 return prn;
59 static __isl_give isl_printer *declare_device_arrays(__isl_take isl_printer *p,
60 struct gpu_prog *prog)
62 int i;
64 for (i = 0; i < prog->n_array; ++i) {
65 if (gpu_array_is_read_only_scalar(&prog->array[i]))
66 continue;
67 p = isl_printer_start_line(p);
68 p = isl_printer_print_str(p, prog->array[i].type);
69 p = isl_printer_print_str(p, " *dev_");
70 p = isl_printer_print_str(p, prog->array[i].name);
71 p = isl_printer_print_str(p, ";");
72 p = isl_printer_end_line(p);
74 p = isl_printer_start_line(p);
75 p = isl_printer_end_line(p);
76 return p;
79 static __isl_give isl_printer *allocate_device_arrays(
80 __isl_take isl_printer *p, struct gpu_prog *prog)
82 int i;
84 for (i = 0; i < prog->n_array; ++i) {
85 if (gpu_array_is_read_only_scalar(&prog->array[i]))
86 continue;
87 p = isl_printer_start_line(p);
88 p = isl_printer_print_str(p,
89 "cudaCheckReturn(cudaMalloc((void **) &dev_");
90 p = isl_printer_print_str(p, prog->array[i].name);
91 p = isl_printer_print_str(p, ", ");
92 p = print_array_size(p, &prog->array[i]);
93 p = isl_printer_print_str(p, "));");
94 p = isl_printer_end_line(p);
96 p = isl_printer_start_line(p);
97 p = isl_printer_end_line(p);
98 return p;
101 static __isl_give isl_printer *copy_arrays_to_device(__isl_take isl_printer *p,
102 struct gpu_prog *prog)
104 int i;
106 for (i = 0; i < prog->n_array; ++i) {
107 isl_space *dim;
108 isl_set *read_i;
109 int empty;
111 if (gpu_array_is_read_only_scalar(&prog->array[i]))
112 continue;
114 dim = isl_space_copy(prog->array[i].dim);
115 read_i = isl_union_set_extract_set(prog->copy_in, dim);
116 empty = isl_set_fast_is_empty(read_i);
117 isl_set_free(read_i);
118 if (empty)
119 continue;
121 p = isl_printer_print_str(p, "cudaCheckReturn(cudaMemcpy(dev_");
122 p = isl_printer_print_str(p, prog->array[i].name);
123 p = isl_printer_print_str(p, ", ");
125 if (gpu_array_is_scalar(&prog->array[i]))
126 p = isl_printer_print_str(p, "&");
127 p = isl_printer_print_str(p, prog->array[i].name);
128 p = isl_printer_print_str(p, ", ");
130 p = print_array_size(p, &prog->array[i]);
131 p = isl_printer_print_str(p, ", cudaMemcpyHostToDevice));");
132 p = isl_printer_end_line(p);
134 p = isl_printer_start_line(p);
135 p = isl_printer_end_line(p);
136 return p;
/* Print the "len" elements of "list" to "out" in reverse order,
 * as a parenthesized, comma-separated list, e.g., "(3, 2, 1)"
 * for the list { 1, 2, 3 }.  Nothing is printed if "len" is zero.
 */
static void print_reverse_list(FILE *out, int len, int *list)
{
	int pos;

	if (len == 0)
		return;

	fputc('(', out);
	for (pos = len - 1; pos >= 0; --pos) {
		if (pos != len - 1)
			fputs(", ", out);
		fprintf(out, "%d", list[pos]);
	}
	fputc(')', out);
}
155 /* Print the effective grid size as a list of the sizes in each
156 * dimension, from innermost to outermost.
158 static __isl_give isl_printer *print_grid_size(__isl_take isl_printer *p,
159 struct ppcg_kernel *kernel)
161 int i;
162 int dim;
164 dim = isl_multi_pw_aff_dim(kernel->grid_size, isl_dim_set);
165 if (dim == 0)
166 return p;
168 p = isl_printer_print_str(p, "(");
169 for (i = dim - 1; i >= 0; --i) {
170 isl_pw_aff *bound;
172 bound = isl_multi_pw_aff_get_pw_aff(kernel->grid_size, i);
173 p = isl_printer_print_pw_aff(p, bound);
174 isl_pw_aff_free(bound);
176 if (i > 0)
177 p = isl_printer_print_str(p, ", ");
180 p = isl_printer_print_str(p, ")");
182 return p;
185 /* Print the grid definition.
187 static __isl_give isl_printer *print_grid(__isl_take isl_printer *p,
188 struct ppcg_kernel *kernel)
190 p = isl_printer_start_line(p);
191 p = isl_printer_print_str(p, "dim3 k");
192 p = isl_printer_print_int(p, kernel->id);
193 p = isl_printer_print_str(p, "_dimGrid");
194 p = print_grid_size(p, kernel);
195 p = isl_printer_print_str(p, ";");
196 p = isl_printer_end_line(p);
198 return p;
201 /* Print the arguments to a kernel declaration or call. If "types" is set,
202 * then print a declaration (including the types of the arguments).
204 * The arguments are printed in the following order
205 * - the arrays accessed by the kernel
206 * - the parameters
207 * - the host loop iterators
209 static __isl_give isl_printer *print_kernel_arguments(__isl_take isl_printer *p,
210 struct gpu_prog *prog, struct ppcg_kernel *kernel, int types)
212 int i, n;
213 int first = 1;
214 unsigned nparam;
215 isl_space *space;
216 const char *type;
218 for (i = 0; i < prog->n_array; ++i) {
219 isl_set *arr;
220 int empty;
222 space = isl_space_copy(prog->array[i].dim);
223 arr = isl_union_set_extract_set(kernel->arrays, space);
224 empty = isl_set_fast_is_empty(arr);
225 isl_set_free(arr);
226 if (empty)
227 continue;
229 if (!first)
230 p = isl_printer_print_str(p, ", ");
232 if (types) {
233 p = isl_printer_print_str(p, prog->array[i].type);
234 p = isl_printer_print_str(p, " ");
237 if (gpu_array_is_read_only_scalar(&prog->array[i])) {
238 p = isl_printer_print_str(p, prog->array[i].name);
239 } else {
240 if (types)
241 p = isl_printer_print_str(p, "*");
242 else
243 p = isl_printer_print_str(p, "dev_");
244 p = isl_printer_print_str(p, prog->array[i].name);
247 first = 0;
250 space = isl_union_set_get_space(kernel->arrays);
251 nparam = isl_space_dim(space, isl_dim_param);
252 for (i = 0; i < nparam; ++i) {
253 const char *name;
255 name = isl_space_get_dim_name(space, isl_dim_param, i);
257 if (!first)
258 p = isl_printer_print_str(p, ", ");
259 if (types)
260 p = isl_printer_print_str(p, "int ");
261 p = isl_printer_print_str(p, name);
263 first = 0;
265 isl_space_free(space);
267 n = isl_space_dim(kernel->space, isl_dim_set);
268 type = isl_options_get_ast_iterator_type(prog->ctx);
269 for (i = 0; i < n; ++i) {
270 const char *name;
271 isl_id *id;
273 if (!first)
274 p = isl_printer_print_str(p, ", ");
275 name = isl_space_get_dim_name(kernel->space, isl_dim_set, i);
276 if (types) {
277 p = isl_printer_print_str(p, type);
278 p = isl_printer_print_str(p, " ");
280 p = isl_printer_print_str(p, name);
282 first = 0;
285 return p;
288 /* Print the header of the given kernel.
290 static __isl_give isl_printer *print_kernel_header(__isl_take isl_printer *p,
291 struct gpu_prog *prog, struct ppcg_kernel *kernel)
293 p = isl_printer_start_line(p);
294 p = isl_printer_print_str(p, "__global__ void kernel");
295 p = isl_printer_print_int(p, kernel->id);
296 p = isl_printer_print_str(p, "(");
297 p = print_kernel_arguments(p, prog, kernel, 1);
298 p = isl_printer_print_str(p, ")");
300 return p;
303 /* Print the header of the given kernel to both gen->cuda.kernel_h
304 * and gen->cuda.kernel_c.
306 static void print_kernel_headers(struct gpu_prog *prog,
307 struct ppcg_kernel *kernel, struct cuda_info *cuda)
309 isl_printer *p;
311 p = isl_printer_to_file(prog->ctx, cuda->kernel_h);
312 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
313 p = print_kernel_header(p, prog, kernel);
314 p = isl_printer_print_str(p, ";");
315 p = isl_printer_end_line(p);
316 isl_printer_free(p);
318 p = isl_printer_to_file(prog->ctx, cuda->kernel_c);
319 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
320 p = print_kernel_header(p, prog, kernel);
321 p = isl_printer_end_line(p);
322 isl_printer_free(p);
/* Write "indent" spaces to "dst" by printing an empty string
 * in a field of width "indent".
 */
static void print_indent(FILE *dst, int indent)
{
	static const char empty[] = "";

	fprintf(dst, "%*s", indent, empty);
}
330 static void print_kernel_iterators(FILE *out, struct ppcg_kernel *kernel)
332 int i;
333 isl_ctx *ctx = isl_ast_node_get_ctx(kernel->tree);
334 const char *type;
335 const char *block_dims[] = { "blockIdx.x", "blockIdx.y" };
336 const char *thread_dims[] = { "threadIdx.x", "threadIdx.y",
337 "threadIdx.z" };
339 type = isl_options_get_ast_iterator_type(ctx);
341 if (kernel->n_grid > 0) {
342 print_indent(out, 4);
343 fprintf(out, "%s ", type);
344 for (i = 0; i < kernel->n_grid; ++i) {
345 if (i)
346 fprintf(out, ", ");
347 fprintf(out, "b%d = %s",
348 i, block_dims[kernel->n_grid - 1 - i]);
350 fprintf(out, ";\n");
353 if (kernel->n_block > 0) {
354 print_indent(out, 4);
355 fprintf(out, "%s ", type);
356 for (i = 0; i < kernel->n_block; ++i) {
357 if (i)
358 fprintf(out, ", ");
359 fprintf(out, "t%d = %s",
360 i, thread_dims[kernel->n_block - 1 - i]);
362 fprintf(out, ";\n");
366 static void print_kernel_var(FILE *out, struct ppcg_kernel_var *var)
368 int j;
369 isl_int v;
371 print_indent(out, 4);
372 if (var->type == ppcg_access_shared)
373 fprintf(out, "__shared__ ");
374 fprintf(out, "%s %s", var->array->type, var->name);
375 isl_int_init(v);
376 for (j = 0; j < var->array->n_index; ++j) {
377 fprintf(out, "[");
378 isl_vec_get_element(var->size, j, &v);
379 isl_int_print(out, v, 0);
380 fprintf(out, "]");
382 isl_int_clear(v);
383 fprintf(out, ";\n");
386 static void print_kernel_vars(FILE *out, struct ppcg_kernel *kernel)
388 int i;
390 for (i = 0; i < kernel->n_var; ++i)
391 print_kernel_var(out, &kernel->var[i]);
394 /* Print an access to the element in the private/shared memory copy
395 * described by "stmt". The index of the copy is recorded in
396 * stmt->local_index as a "call" to the array.
398 static __isl_give isl_printer *stmt_print_local_index(__isl_take isl_printer *p,
399 struct ppcg_kernel_stmt *stmt)
401 int i;
402 isl_ast_expr *expr;
403 struct gpu_array_info *array = stmt->u.c.array;
405 expr = isl_ast_expr_get_op_arg(stmt->u.c.local_index, 0);
406 p = isl_printer_print_ast_expr(p, expr);
407 isl_ast_expr_free(expr);
409 for (i = 0; i < array->n_index; ++i) {
410 expr = isl_ast_expr_get_op_arg(stmt->u.c.local_index, 1 + i);
412 p = isl_printer_print_str(p, "[");
413 p = isl_printer_print_ast_expr(p, expr);
414 p = isl_printer_print_str(p, "]");
416 isl_ast_expr_free(expr);
419 return p;
422 /* Print an access to the element in the global memory copy
423 * described by "stmt". The index of the copy is recorded in
424 * stmt->index as a "call" to the array.
426 * The copy in global memory has been linearized, so we need to take
427 * the array size into account.
429 static __isl_give isl_printer *stmt_print_global_index(
430 __isl_take isl_printer *p, struct ppcg_kernel_stmt *stmt)
432 int i;
433 struct gpu_array_info *array = stmt->u.c.array;
434 isl_pw_aff_list *bound = stmt->u.c.local_array->bound;
436 if (gpu_array_is_scalar(array)) {
437 if (!array->read_only)
438 p = isl_printer_print_str(p, "*");
439 p = isl_printer_print_str(p, array->name);
440 return p;
443 p = isl_printer_print_str(p, array->name);
444 p = isl_printer_print_str(p, "[");
445 for (i = 0; i + 1 < array->n_index; ++i)
446 p = isl_printer_print_str(p, "(");
447 for (i = 0; i < array->n_index; ++i) {
448 isl_ast_expr *expr;
449 expr = isl_ast_expr_get_op_arg(stmt->u.c.index, 1 + i);
450 if (i) {
451 isl_pw_aff *bound_i;
452 bound_i = isl_pw_aff_list_get_pw_aff(bound, i);
453 p = isl_printer_print_str(p, ") * (");
454 p = isl_printer_print_pw_aff(p, bound_i);
455 p = isl_printer_print_str(p, ") + (");
456 isl_pw_aff_free(bound_i);
458 p = isl_printer_print_ast_expr(p, expr);
459 if (i)
460 p = isl_printer_print_str(p, ")");
461 isl_ast_expr_free(expr);
463 p = isl_printer_print_str(p, "]");
465 return p;
468 /* Print a copy statement.
470 * A read copy statement is printed as
472 * local = global;
474 * while a write copy statement is printed as
476 * global = local;
478 static __isl_give isl_printer *print_copy(__isl_take isl_printer *p,
479 struct ppcg_kernel_stmt *stmt)
481 p = isl_printer_start_line(p);
482 if (stmt->u.c.read) {
483 p = stmt_print_local_index(p, stmt);
484 p = isl_printer_print_str(p, " = ");
485 p = stmt_print_global_index(p, stmt);
486 } else {
487 p = stmt_print_global_index(p, stmt);
488 p = isl_printer_print_str(p, " = ");
489 p = stmt_print_local_index(p, stmt);
491 p = isl_printer_print_str(p, ";");
492 p = isl_printer_end_line(p);
494 return p;
497 /* Print a sync statement.
499 static __isl_give isl_printer *print_sync(__isl_take isl_printer *p,
500 struct ppcg_kernel_stmt *stmt)
502 p = isl_printer_start_line(p);
503 p = isl_printer_print_str(p, "__syncthreads();");
504 p = isl_printer_end_line(p);
506 return p;
509 /* Print an access based on the information in "access".
510 * If this an access to global memory, then the index expression
511 * is linearized.
513 * If access->array is NULL, then we are
514 * accessing an iterator in the original program.
516 static __isl_give isl_printer *print_access(__isl_take isl_printer *p,
517 struct ppcg_kernel_access *access)
519 int i;
520 unsigned n_index;
521 struct gpu_array_info *array;
522 isl_pw_aff_list *bound;
524 array = access->array;
525 bound = array ? access->local_array->bound : NULL;
526 if (!array)
527 p = isl_printer_print_str(p, "(");
528 else {
529 if (access->type == ppcg_access_global &&
530 gpu_array_is_scalar(array) && !array->read_only)
531 p = isl_printer_print_str(p, "*");
532 p = isl_printer_print_str(p, access->local_name);
533 if (gpu_array_is_scalar(array))
534 return p;
535 p = isl_printer_print_str(p, "[");
538 n_index = isl_ast_expr_list_n_ast_expr(access->index);
539 if (access->type == ppcg_access_global)
540 for (i = 0; i + 1 < n_index; ++i)
541 p = isl_printer_print_str(p, "(");
543 for (i = 0; i < n_index; ++i) {
544 isl_ast_expr *index;
546 index = isl_ast_expr_list_get_ast_expr(access->index, i);
547 if (array && i) {
548 if (access->type == ppcg_access_global) {
549 isl_pw_aff *bound_i;
550 bound_i = isl_pw_aff_list_get_pw_aff(bound, i);
551 p = isl_printer_print_str(p, ") * (");
552 p = isl_printer_print_pw_aff(p, bound_i);
553 p = isl_printer_print_str(p, ") + ");
554 isl_pw_aff_free(bound_i);
555 } else
556 p = isl_printer_print_str(p, "][");
558 p = isl_printer_print_ast_expr(p, index);
559 isl_ast_expr_free(index);
561 if (!array)
562 p = isl_printer_print_str(p, ")");
563 else
564 p = isl_printer_print_str(p, "]");
566 return p;
569 struct cuda_access_print_info {
570 int i;
571 struct ppcg_kernel_stmt *stmt;
574 /* To print the cuda accesses we walk the list of cuda accesses simultaneously
575 * with the pet printer. This means that whenever the pet printer prints a
576 * pet access expression we have the corresponding cuda access available and can
577 * print the modified access.
579 static __isl_give isl_printer *print_cuda_access(__isl_take isl_printer *p,
580 struct pet_expr *expr, void *usr)
582 struct cuda_access_print_info *info =
583 (struct cuda_access_print_info *) usr;
585 p = print_access(p, &info->stmt->u.d.access[info->i]);
586 info->i++;
588 return p;
591 static __isl_give isl_printer *print_stmt_body(__isl_take isl_printer *p,
592 struct ppcg_kernel_stmt *stmt)
594 struct cuda_access_print_info info;
596 info.i = 0;
597 info.stmt = stmt;
599 p = isl_printer_start_line(p);
600 p = print_pet_expr(p, stmt->u.d.stmt->body, &print_cuda_access, &info);
601 p = isl_printer_print_str(p, ";");
602 p = isl_printer_end_line(p);
604 return p;
607 /* This function is called for each user statement in the AST,
608 * i.e., for each kernel body statement, copy statement or sync statement.
610 static __isl_give isl_printer *print_kernel_stmt(__isl_take isl_printer *p,
611 __isl_take isl_ast_print_options *print_options,
612 __isl_keep isl_ast_node *node, void *user)
614 isl_id *id;
615 struct ppcg_kernel_stmt *stmt;
617 id = isl_ast_node_get_annotation(node);
618 stmt = isl_id_get_user(id);
619 isl_id_free(id);
621 isl_ast_print_options_free(print_options);
623 switch (stmt->type) {
624 case ppcg_kernel_copy:
625 return print_copy(p, stmt);
626 case ppcg_kernel_sync:
627 return print_sync(p, stmt);
628 case ppcg_kernel_domain:
629 return print_stmt_body(p, stmt);
632 return p;
635 static int print_macro(enum isl_ast_op_type type, void *user)
637 isl_printer **p = user;
639 if (type == isl_ast_op_fdiv_q)
640 return 0;
642 *p = isl_ast_op_type_print_macro(type, *p);
644 return 0;
647 /* Print the required macros for "node", including one for floord.
648 * We always print a macro for floord as it may also appear in the statements.
650 static __isl_give isl_printer *print_macros(
651 __isl_keep isl_ast_node *node, __isl_take isl_printer *p)
653 p = isl_ast_op_type_print_macro(isl_ast_op_fdiv_q, p);
654 if (isl_ast_node_foreach_ast_op_type(node, &print_macro, &p) < 0)
655 return isl_printer_free(p);
656 return p;
659 static void print_kernel(struct gpu_prog *prog, struct ppcg_kernel *kernel,
660 struct cuda_info *cuda)
662 isl_ctx *ctx = isl_ast_node_get_ctx(kernel->tree);
663 isl_ast_print_options *print_options;
664 isl_printer *p;
666 print_kernel_headers(prog, kernel, cuda);
667 fprintf(cuda->kernel_c, "{\n");
668 print_kernel_iterators(cuda->kernel_c, kernel);
669 print_kernel_vars(cuda->kernel_c, kernel);
670 fprintf(cuda->kernel_c, "\n");
672 print_options = isl_ast_print_options_alloc(ctx);
673 print_options = isl_ast_print_options_set_print_user(print_options,
674 &print_kernel_stmt, NULL);
676 p = isl_printer_to_file(ctx, cuda->kernel_c);
677 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
678 p = isl_printer_indent(p, 4);
679 p = print_macros(kernel->tree, p);
680 p = isl_ast_node_print(kernel->tree, p, print_options);
681 isl_printer_free(p);
683 fprintf(cuda->kernel_c, "}\n");
/* User data passed to print_host_user by the isl AST printer. */
struct print_host_user_data {
	/* Output files to which the kernels are printed. */
	struct cuda_info *cuda;
	/* The program whose host code is being printed. */
	struct gpu_prog *prog;
};
691 /* Print the user statement of the host code to "p".
693 * In particular, print a block of statements that defines the grid
694 * and the block and then launches the kernel.
696 static __isl_give isl_printer *print_host_user(__isl_take isl_printer *p,
697 __isl_take isl_ast_print_options *print_options,
698 __isl_keep isl_ast_node *node, void *user)
700 isl_id *id;
701 struct ppcg_kernel *kernel;
702 struct print_host_user_data *data;
704 id = isl_ast_node_get_annotation(node);
705 kernel = isl_id_get_user(id);
706 isl_id_free(id);
708 data = (struct print_host_user_data *) user;
710 p = isl_printer_start_line(p);
711 p = isl_printer_print_str(p, "{");
712 p = isl_printer_end_line(p);
713 p = isl_printer_indent(p, 2);
715 p = isl_printer_start_line(p);
716 p = isl_printer_print_str(p, "dim3 k");
717 p = isl_printer_print_int(p, kernel->id);
718 p = isl_printer_print_str(p, "_dimBlock");
719 print_reverse_list(isl_printer_get_file(p),
720 kernel->n_block, kernel->block_dim);
721 p = isl_printer_print_str(p, ";");
722 p = isl_printer_end_line(p);
724 p = print_grid(p, kernel);
726 p = isl_printer_start_line(p);
727 p = isl_printer_print_str(p, "kernel");
728 p = isl_printer_print_int(p, kernel->id);
729 p = isl_printer_print_str(p, " <<<k");
730 p = isl_printer_print_int(p, kernel->id);
731 p = isl_printer_print_str(p, "_dimGrid, k");
732 p = isl_printer_print_int(p, kernel->id);
733 p = isl_printer_print_str(p, "_dimBlock>>> (");
734 p = print_kernel_arguments(p, data->prog, kernel, 0);
735 p = isl_printer_print_str(p, ");");
736 p = isl_printer_end_line(p);
738 p = isl_printer_start_line(p);
739 p = isl_printer_print_str(p, "cudaCheckKernel();");
740 p = isl_printer_end_line(p);
742 p = isl_printer_indent(p, -2);
743 p = isl_printer_start_line(p);
744 p = isl_printer_print_str(p, "}");
745 p = isl_printer_end_line(p);
747 p = isl_printer_start_line(p);
748 p = isl_printer_end_line(p);
750 print_kernel(data->prog, kernel, data->cuda);
752 isl_ast_print_options_free(print_options);
754 return p;
757 static __isl_give isl_printer *print_host_code(__isl_take isl_printer *p,
758 struct gpu_prog *prog, __isl_keep isl_ast_node *tree,
759 struct cuda_info *cuda)
761 isl_ast_print_options *print_options;
762 isl_ctx *ctx = isl_ast_node_get_ctx(tree);
763 struct print_host_user_data data = { cuda, prog };
765 print_options = isl_ast_print_options_alloc(ctx);
766 print_options = isl_ast_print_options_set_print_user(print_options,
767 &print_host_user, &data);
769 p = print_macros(tree, p);
770 p = isl_ast_node_print(tree, p, print_options);
772 return p;
775 /* For each array that is written anywhere in the gpu_prog,
776 * copy the contents back from the GPU to the host.
778 * Arrays that are not visible outside the corresponding scop
779 * do not need to be copied back.
781 static __isl_give isl_printer *copy_arrays_from_device(
782 __isl_take isl_printer *p, struct gpu_prog *prog)
784 int i;
785 isl_union_set *write;
786 write = isl_union_map_range(isl_union_map_copy(prog->write));
788 for (i = 0; i < prog->n_array; ++i) {
789 isl_space *dim;
790 isl_set *write_i;
791 int empty;
793 if (prog->array[i].local)
794 continue;
796 dim = isl_space_copy(prog->array[i].dim);
797 write_i = isl_union_set_extract_set(write, dim);
798 empty = isl_set_fast_is_empty(write_i);
799 isl_set_free(write_i);
800 if (empty)
801 continue;
803 p = isl_printer_print_str(p, "cudaCheckReturn(cudaMemcpy(");
804 if (gpu_array_is_scalar(&prog->array[i]))
805 p = isl_printer_print_str(p, "&");
806 p = isl_printer_print_str(p, prog->array[i].name);
807 p = isl_printer_print_str(p, ", dev_");
808 p = isl_printer_print_str(p, prog->array[i].name);
809 p = isl_printer_print_str(p, ", ");
810 p = print_array_size(p, &prog->array[i]);
811 p = isl_printer_print_str(p, ", cudaMemcpyDeviceToHost));");
812 p = isl_printer_end_line(p);
815 isl_union_set_free(write);
816 p = isl_printer_start_line(p);
817 p = isl_printer_end_line(p);
818 return p;
821 static __isl_give isl_printer *free_device_arrays(__isl_take isl_printer *p,
822 struct gpu_prog *prog)
824 int i;
826 for (i = 0; i < prog->n_array; ++i) {
827 if (gpu_array_is_read_only_scalar(&prog->array[i]))
828 continue;
829 p = isl_printer_print_str(p, "cudaCheckReturn(cudaFree(dev_");
830 p = isl_printer_print_str(p, prog->array[i].name);
831 p = isl_printer_print_str(p, "));");
832 p = isl_printer_end_line(p);
835 return p;
838 int generate_cuda(isl_ctx *ctx, struct ppcg_scop *scop,
839 struct ppcg_options *options, const char *input)
841 struct cuda_info cuda;
842 struct gpu_prog *prog;
843 isl_ast_node *tree;
844 isl_printer *p;
846 if (!scop)
847 return -1;
849 prog = gpu_prog_alloc(ctx, scop);
851 tree = generate_gpu(ctx, prog, options);
853 cuda.start = scop->start;
854 cuda.end = scop->end;
855 cuda_open_files(&cuda, input);
857 p = isl_printer_to_file(ctx, cuda.host_c);
858 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
859 p = ppcg_print_exposed_declarations(p, scop);
860 p = ppcg_start_block(p);
862 p = print_cuda_macros(p);
864 p = declare_device_arrays(p, prog);
865 p = allocate_device_arrays(p, prog);
866 p = copy_arrays_to_device(p, prog);
868 p = print_host_code(p, prog, tree, &cuda);
869 isl_ast_node_free(tree);
871 p = copy_arrays_from_device(p, prog);
872 p = free_device_arrays(p, prog);
874 p = ppcg_end_block(p);
875 isl_printer_free(p);
877 cuda_close_files(&cuda);
879 gpu_prog_free(prog);
881 return 0;