cuda.c: print_kernel_vars: take isl_printer instead of FILE
[ppcg.git] / cuda.c
blob2457d5e97e024d3b6bcd03d10e5598e603c00a05
1 /*
2 * Copyright 2012 Ecole Normale Superieure
4 * Use of this software is governed by the GNU LGPLv2.1 license
6 * Written by Sven Verdoolaege,
7 * Ecole Normale Superieure, 45 rue d’Ulm, 75230 Paris, France
8 */
10 #include <isl/aff.h>
11 #include <isl/ast.h>
13 #include "cuda_common.h"
14 #include "cuda.h"
15 #include "gpu.h"
16 #include "pet_printer.h"
17 #include "print.h"
18 #include "schedule.h"
20 static __isl_give isl_printer *print_cuda_macros(__isl_take isl_printer *p)
22 const char *macros =
23 "#define cudaCheckReturn(ret) \\\n"
24 " do { \\\n"
25 " cudaError_t cudaCheckReturn_e = (ret); \\\n"
26 " if (cudaCheckReturn_e != cudaSuccess) { \\\n"
27 " fprintf(stderr, \"CUDA error: %s\\n\", "
28 "cudaGetErrorString(cudaCheckReturn_e)); \\\n"
29 " fflush(stderr); \\\n"
30 " } \\\n"
31 " assert(cudaCheckReturn_e == cudaSuccess); \\\n"
32 " } while(0)\n"
33 "#define cudaCheckKernel() \\\n"
34 " do { \\\n"
35 " cudaCheckReturn(cudaGetLastError()); \\\n"
36 " } while(0)\n\n";
38 p = isl_printer_print_str(p, macros);
39 return p;
42 static __isl_give isl_printer *print_array_size(__isl_take isl_printer *prn,
43 struct gpu_array_info *array)
45 int i;
47 for (i = 0; i < array->n_index; ++i) {
48 prn = isl_printer_print_str(prn, "(");
49 prn = isl_printer_print_pw_aff(prn, array->bound[i]);
50 prn = isl_printer_print_str(prn, ") * ");
52 prn = isl_printer_print_str(prn, "sizeof(");
53 prn = isl_printer_print_str(prn, array->type);
54 prn = isl_printer_print_str(prn, ")");
56 return prn;
59 static __isl_give isl_printer *declare_device_arrays(__isl_take isl_printer *p,
60 struct gpu_prog *prog)
62 int i;
64 for (i = 0; i < prog->n_array; ++i) {
65 if (gpu_array_is_read_only_scalar(&prog->array[i]))
66 continue;
67 p = isl_printer_start_line(p);
68 p = isl_printer_print_str(p, prog->array[i].type);
69 p = isl_printer_print_str(p, " *dev_");
70 p = isl_printer_print_str(p, prog->array[i].name);
71 p = isl_printer_print_str(p, ";");
72 p = isl_printer_end_line(p);
74 p = isl_printer_start_line(p);
75 p = isl_printer_end_line(p);
76 return p;
79 static __isl_give isl_printer *allocate_device_arrays(
80 __isl_take isl_printer *p, struct gpu_prog *prog)
82 int i;
84 for (i = 0; i < prog->n_array; ++i) {
85 if (gpu_array_is_read_only_scalar(&prog->array[i]))
86 continue;
87 p = isl_printer_start_line(p);
88 p = isl_printer_print_str(p,
89 "cudaCheckReturn(cudaMalloc((void **) &dev_");
90 p = isl_printer_print_str(p, prog->array[i].name);
91 p = isl_printer_print_str(p, ", ");
92 p = print_array_size(p, &prog->array[i]);
93 p = isl_printer_print_str(p, "));");
94 p = isl_printer_end_line(p);
96 p = isl_printer_start_line(p);
97 p = isl_printer_end_line(p);
98 return p;
101 static __isl_give isl_printer *copy_arrays_to_device(__isl_take isl_printer *p,
102 struct gpu_prog *prog)
104 int i;
106 for (i = 0; i < prog->n_array; ++i) {
107 isl_space *dim;
108 isl_set *read_i;
109 int empty;
111 if (gpu_array_is_read_only_scalar(&prog->array[i]))
112 continue;
114 dim = isl_space_copy(prog->array[i].dim);
115 read_i = isl_union_set_extract_set(prog->copy_in, dim);
116 empty = isl_set_fast_is_empty(read_i);
117 isl_set_free(read_i);
118 if (empty)
119 continue;
121 p = isl_printer_print_str(p, "cudaCheckReturn(cudaMemcpy(dev_");
122 p = isl_printer_print_str(p, prog->array[i].name);
123 p = isl_printer_print_str(p, ", ");
125 if (gpu_array_is_scalar(&prog->array[i]))
126 p = isl_printer_print_str(p, "&");
127 p = isl_printer_print_str(p, prog->array[i].name);
128 p = isl_printer_print_str(p, ", ");
130 p = print_array_size(p, &prog->array[i]);
131 p = isl_printer_print_str(p, ", cudaMemcpyHostToDevice));");
132 p = isl_printer_end_line(p);
134 p = isl_printer_start_line(p);
135 p = isl_printer_end_line(p);
136 return p;
/* Print the "len" elements of "list" to "out" in reverse order,
 * as a parenthesized, comma-separated list, e.g., "(3, 2, 1)"
 * for the list 1, 2, 3.  Nothing is printed if "len" is zero.
 */
static void print_reverse_list(FILE *out, int len, int *list)
{
	int pos;
	const char *sep = "";

	if (len == 0)
		return;

	fputc('(', out);
	for (pos = len - 1; pos >= 0; --pos) {
		fprintf(out, "%s%d", sep, list[pos]);
		sep = ", ";
	}
	fputc(')', out);
}
155 /* Print the effective grid size as a list of the sizes in each
156 * dimension, from innermost to outermost.
158 static __isl_give isl_printer *print_grid_size(__isl_take isl_printer *p,
159 struct ppcg_kernel *kernel)
161 int i;
162 int dim;
164 dim = isl_multi_pw_aff_dim(kernel->grid_size, isl_dim_set);
165 if (dim == 0)
166 return p;
168 p = isl_printer_print_str(p, "(");
169 for (i = dim - 1; i >= 0; --i) {
170 isl_pw_aff *bound;
172 bound = isl_multi_pw_aff_get_pw_aff(kernel->grid_size, i);
173 p = isl_printer_print_pw_aff(p, bound);
174 isl_pw_aff_free(bound);
176 if (i > 0)
177 p = isl_printer_print_str(p, ", ");
180 p = isl_printer_print_str(p, ")");
182 return p;
185 /* Print the grid definition.
187 static __isl_give isl_printer *print_grid(__isl_take isl_printer *p,
188 struct ppcg_kernel *kernel)
190 p = isl_printer_start_line(p);
191 p = isl_printer_print_str(p, "dim3 k");
192 p = isl_printer_print_int(p, kernel->id);
193 p = isl_printer_print_str(p, "_dimGrid");
194 p = print_grid_size(p, kernel);
195 p = isl_printer_print_str(p, ";");
196 p = isl_printer_end_line(p);
198 return p;
201 /* Print the arguments to a kernel declaration or call. If "types" is set,
202 * then print a declaration (including the types of the arguments).
204 * The arguments are printed in the following order
205 * - the arrays accessed by the kernel
206 * - the parameters
207 * - the host loop iterators
209 static __isl_give isl_printer *print_kernel_arguments(__isl_take isl_printer *p,
210 struct gpu_prog *prog, struct ppcg_kernel *kernel, int types)
212 int i, n;
213 int first = 1;
214 unsigned nparam;
215 isl_space *space;
216 const char *type;
218 for (i = 0; i < prog->n_array; ++i) {
219 isl_set *arr;
220 int empty;
222 space = isl_space_copy(prog->array[i].dim);
223 arr = isl_union_set_extract_set(kernel->arrays, space);
224 empty = isl_set_fast_is_empty(arr);
225 isl_set_free(arr);
226 if (empty)
227 continue;
229 if (!first)
230 p = isl_printer_print_str(p, ", ");
232 if (types) {
233 p = isl_printer_print_str(p, prog->array[i].type);
234 p = isl_printer_print_str(p, " ");
237 if (gpu_array_is_read_only_scalar(&prog->array[i])) {
238 p = isl_printer_print_str(p, prog->array[i].name);
239 } else {
240 if (types)
241 p = isl_printer_print_str(p, "*");
242 else
243 p = isl_printer_print_str(p, "dev_");
244 p = isl_printer_print_str(p, prog->array[i].name);
247 first = 0;
250 space = isl_union_set_get_space(kernel->arrays);
251 nparam = isl_space_dim(space, isl_dim_param);
252 for (i = 0; i < nparam; ++i) {
253 const char *name;
255 name = isl_space_get_dim_name(space, isl_dim_param, i);
257 if (!first)
258 p = isl_printer_print_str(p, ", ");
259 if (types)
260 p = isl_printer_print_str(p, "int ");
261 p = isl_printer_print_str(p, name);
263 first = 0;
265 isl_space_free(space);
267 n = isl_space_dim(kernel->space, isl_dim_set);
268 type = isl_options_get_ast_iterator_type(prog->ctx);
269 for (i = 0; i < n; ++i) {
270 const char *name;
271 isl_id *id;
273 if (!first)
274 p = isl_printer_print_str(p, ", ");
275 name = isl_space_get_dim_name(kernel->space, isl_dim_set, i);
276 if (types) {
277 p = isl_printer_print_str(p, type);
278 p = isl_printer_print_str(p, " ");
280 p = isl_printer_print_str(p, name);
282 first = 0;
285 return p;
288 /* Print the header of the given kernel.
290 static __isl_give isl_printer *print_kernel_header(__isl_take isl_printer *p,
291 struct gpu_prog *prog, struct ppcg_kernel *kernel)
293 p = isl_printer_start_line(p);
294 p = isl_printer_print_str(p, "__global__ void kernel");
295 p = isl_printer_print_int(p, kernel->id);
296 p = isl_printer_print_str(p, "(");
297 p = print_kernel_arguments(p, prog, kernel, 1);
298 p = isl_printer_print_str(p, ")");
300 return p;
303 /* Print the header of the given kernel to both gen->cuda.kernel_h
304 * and gen->cuda.kernel_c.
306 static void print_kernel_headers(struct gpu_prog *prog,
307 struct ppcg_kernel *kernel, struct cuda_info *cuda)
309 isl_printer *p;
311 p = isl_printer_to_file(prog->ctx, cuda->kernel_h);
312 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
313 p = print_kernel_header(p, prog, kernel);
314 p = isl_printer_print_str(p, ";");
315 p = isl_printer_end_line(p);
316 isl_printer_free(p);
318 p = isl_printer_to_file(prog->ctx, cuda->kernel_c);
319 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
320 p = print_kernel_header(p, prog, kernel);
321 p = isl_printer_end_line(p);
322 isl_printer_free(p);
/* Print "indent" spaces to "dst".
 * A negative "indent" prints as many spaces as its magnitude,
 * just like a negative "*" field width in fprintf.
 */
static void print_indent(FILE *dst, int indent)
{
	unsigned width;

	width = indent < 0 ? 0u - (unsigned) indent : (unsigned) indent;
	while (width-- > 0)
		putc(' ', dst);
}
330 static void print_kernel_iterators(FILE *out, struct ppcg_kernel *kernel)
332 int i;
333 isl_ctx *ctx = isl_ast_node_get_ctx(kernel->tree);
334 const char *type;
335 const char *block_dims[] = { "blockIdx.x", "blockIdx.y" };
336 const char *thread_dims[] = { "threadIdx.x", "threadIdx.y",
337 "threadIdx.z" };
339 type = isl_options_get_ast_iterator_type(ctx);
341 if (kernel->n_grid > 0) {
342 print_indent(out, 4);
343 fprintf(out, "%s ", type);
344 for (i = 0; i < kernel->n_grid; ++i) {
345 if (i)
346 fprintf(out, ", ");
347 fprintf(out, "b%d = %s",
348 i, block_dims[kernel->n_grid - 1 - i]);
350 fprintf(out, ";\n");
353 if (kernel->n_block > 0) {
354 print_indent(out, 4);
355 fprintf(out, "%s ", type);
356 for (i = 0; i < kernel->n_block; ++i) {
357 if (i)
358 fprintf(out, ", ");
359 fprintf(out, "t%d = %s",
360 i, thread_dims[kernel->n_block - 1 - i]);
362 fprintf(out, ";\n");
366 static __isl_give isl_printer *print_kernel_var(__isl_take isl_printer *p,
367 struct ppcg_kernel_var *var)
369 int j;
370 isl_int v;
372 p = isl_printer_start_line(p);
373 if (var->type == ppcg_access_shared)
374 p = isl_printer_print_str(p, "__shared__ ");
375 p = isl_printer_print_str(p, var->array->type);
376 p = isl_printer_print_str(p, " ");
377 p = isl_printer_print_str(p, var->name);
378 isl_int_init(v);
379 for (j = 0; j < var->array->n_index; ++j) {
380 p = isl_printer_print_str(p, "[");
381 isl_vec_get_element(var->size, j, &v);
382 p = isl_printer_print_isl_int(p, v);
383 p = isl_printer_print_str(p, "]");
385 isl_int_clear(v);
386 p = isl_printer_print_str(p, ";");
387 p = isl_printer_end_line(p);
389 return p;
392 static __isl_give isl_printer *print_kernel_vars(__isl_take isl_printer *p,
393 struct ppcg_kernel *kernel)
395 int i;
397 for (i = 0; i < kernel->n_var; ++i)
398 p = print_kernel_var(p, &kernel->var[i]);
400 return p;
403 /* Print an access to the element in the private/shared memory copy
404 * described by "stmt". The index of the copy is recorded in
405 * stmt->local_index as a "call" to the array.
407 static __isl_give isl_printer *stmt_print_local_index(__isl_take isl_printer *p,
408 struct ppcg_kernel_stmt *stmt)
410 int i;
411 isl_ast_expr *expr;
412 struct gpu_array_info *array = stmt->u.c.array;
414 expr = isl_ast_expr_get_op_arg(stmt->u.c.local_index, 0);
415 p = isl_printer_print_ast_expr(p, expr);
416 isl_ast_expr_free(expr);
418 for (i = 0; i < array->n_index; ++i) {
419 expr = isl_ast_expr_get_op_arg(stmt->u.c.local_index, 1 + i);
421 p = isl_printer_print_str(p, "[");
422 p = isl_printer_print_ast_expr(p, expr);
423 p = isl_printer_print_str(p, "]");
425 isl_ast_expr_free(expr);
428 return p;
431 /* Print an access to the element in the global memory copy
432 * described by "stmt". The index of the copy is recorded in
433 * stmt->index as a "call" to the array.
435 * The copy in global memory has been linearized, so we need to take
436 * the array size into account.
438 static __isl_give isl_printer *stmt_print_global_index(
439 __isl_take isl_printer *p, struct ppcg_kernel_stmt *stmt)
441 int i;
442 struct gpu_array_info *array = stmt->u.c.array;
443 isl_pw_aff_list *bound = stmt->u.c.local_array->bound;
445 if (gpu_array_is_scalar(array)) {
446 if (!array->read_only)
447 p = isl_printer_print_str(p, "*");
448 p = isl_printer_print_str(p, array->name);
449 return p;
452 p = isl_printer_print_str(p, array->name);
453 p = isl_printer_print_str(p, "[");
454 for (i = 0; i + 1 < array->n_index; ++i)
455 p = isl_printer_print_str(p, "(");
456 for (i = 0; i < array->n_index; ++i) {
457 isl_ast_expr *expr;
458 expr = isl_ast_expr_get_op_arg(stmt->u.c.index, 1 + i);
459 if (i) {
460 isl_pw_aff *bound_i;
461 bound_i = isl_pw_aff_list_get_pw_aff(bound, i);
462 p = isl_printer_print_str(p, ") * (");
463 p = isl_printer_print_pw_aff(p, bound_i);
464 p = isl_printer_print_str(p, ") + (");
465 isl_pw_aff_free(bound_i);
467 p = isl_printer_print_ast_expr(p, expr);
468 if (i)
469 p = isl_printer_print_str(p, ")");
470 isl_ast_expr_free(expr);
472 p = isl_printer_print_str(p, "]");
474 return p;
477 /* Print a copy statement.
479 * A read copy statement is printed as
481 * local = global;
483 * while a write copy statement is printed as
485 * global = local;
487 static __isl_give isl_printer *print_copy(__isl_take isl_printer *p,
488 struct ppcg_kernel_stmt *stmt)
490 p = isl_printer_start_line(p);
491 if (stmt->u.c.read) {
492 p = stmt_print_local_index(p, stmt);
493 p = isl_printer_print_str(p, " = ");
494 p = stmt_print_global_index(p, stmt);
495 } else {
496 p = stmt_print_global_index(p, stmt);
497 p = isl_printer_print_str(p, " = ");
498 p = stmt_print_local_index(p, stmt);
500 p = isl_printer_print_str(p, ";");
501 p = isl_printer_end_line(p);
503 return p;
506 /* Print a sync statement.
508 static __isl_give isl_printer *print_sync(__isl_take isl_printer *p,
509 struct ppcg_kernel_stmt *stmt)
511 p = isl_printer_start_line(p);
512 p = isl_printer_print_str(p, "__syncthreads();");
513 p = isl_printer_end_line(p);
515 return p;
518 /* Print an access based on the information in "access".
519 * If this an access to global memory, then the index expression
520 * is linearized.
522 * If access->array is NULL, then we are
523 * accessing an iterator in the original program.
525 static __isl_give isl_printer *print_access(__isl_take isl_printer *p,
526 struct ppcg_kernel_access *access)
528 int i;
529 unsigned n_index;
530 struct gpu_array_info *array;
531 isl_pw_aff_list *bound;
533 array = access->array;
534 bound = array ? access->local_array->bound : NULL;
535 if (!array)
536 p = isl_printer_print_str(p, "(");
537 else {
538 if (access->type == ppcg_access_global &&
539 gpu_array_is_scalar(array) && !array->read_only)
540 p = isl_printer_print_str(p, "*");
541 p = isl_printer_print_str(p, access->local_name);
542 if (gpu_array_is_scalar(array))
543 return p;
544 p = isl_printer_print_str(p, "[");
547 n_index = isl_ast_expr_list_n_ast_expr(access->index);
548 if (access->type == ppcg_access_global)
549 for (i = 0; i + 1 < n_index; ++i)
550 p = isl_printer_print_str(p, "(");
552 for (i = 0; i < n_index; ++i) {
553 isl_ast_expr *index;
555 index = isl_ast_expr_list_get_ast_expr(access->index, i);
556 if (array && i) {
557 if (access->type == ppcg_access_global) {
558 isl_pw_aff *bound_i;
559 bound_i = isl_pw_aff_list_get_pw_aff(bound, i);
560 p = isl_printer_print_str(p, ") * (");
561 p = isl_printer_print_pw_aff(p, bound_i);
562 p = isl_printer_print_str(p, ") + ");
563 isl_pw_aff_free(bound_i);
564 } else
565 p = isl_printer_print_str(p, "][");
567 p = isl_printer_print_ast_expr(p, index);
568 isl_ast_expr_free(index);
570 if (!array)
571 p = isl_printer_print_str(p, ")");
572 else
573 p = isl_printer_print_str(p, "]");
575 return p;
/* State shared with print_cuda_access while print_pet_expr walks
 * the body of a kernel domain statement.
 */
struct cuda_access_print_info {
	/* Position in stmt->u.d.access of the next access to print. */
	int i;
	/* The kernel statement whose body is being printed. */
	struct ppcg_kernel_stmt *stmt;
};
583 /* To print the cuda accesses we walk the list of cuda accesses simultaneously
584 * with the pet printer. This means that whenever the pet printer prints a
585 * pet access expression we have the corresponding cuda access available and can
586 * print the modified access.
588 static __isl_give isl_printer *print_cuda_access(__isl_take isl_printer *p,
589 struct pet_expr *expr, void *usr)
591 struct cuda_access_print_info *info =
592 (struct cuda_access_print_info *) usr;
594 p = print_access(p, &info->stmt->u.d.access[info->i]);
595 info->i++;
597 return p;
600 static __isl_give isl_printer *print_stmt_body(__isl_take isl_printer *p,
601 struct ppcg_kernel_stmt *stmt)
603 struct cuda_access_print_info info;
605 info.i = 0;
606 info.stmt = stmt;
608 p = isl_printer_start_line(p);
609 p = print_pet_expr(p, stmt->u.d.stmt->body, &print_cuda_access, &info);
610 p = isl_printer_print_str(p, ";");
611 p = isl_printer_end_line(p);
613 return p;
616 /* This function is called for each user statement in the AST,
617 * i.e., for each kernel body statement, copy statement or sync statement.
619 static __isl_give isl_printer *print_kernel_stmt(__isl_take isl_printer *p,
620 __isl_take isl_ast_print_options *print_options,
621 __isl_keep isl_ast_node *node, void *user)
623 isl_id *id;
624 struct ppcg_kernel_stmt *stmt;
626 id = isl_ast_node_get_annotation(node);
627 stmt = isl_id_get_user(id);
628 isl_id_free(id);
630 isl_ast_print_options_free(print_options);
632 switch (stmt->type) {
633 case ppcg_kernel_copy:
634 return print_copy(p, stmt);
635 case ppcg_kernel_sync:
636 return print_sync(p, stmt);
637 case ppcg_kernel_domain:
638 return print_stmt_body(p, stmt);
641 return p;
644 static int print_macro(enum isl_ast_op_type type, void *user)
646 isl_printer **p = user;
648 if (type == isl_ast_op_fdiv_q)
649 return 0;
651 *p = isl_ast_op_type_print_macro(type, *p);
653 return 0;
656 /* Print the required macros for "node", including one for floord.
657 * We always print a macro for floord as it may also appear in the statements.
659 static __isl_give isl_printer *print_macros(
660 __isl_keep isl_ast_node *node, __isl_take isl_printer *p)
662 p = isl_ast_op_type_print_macro(isl_ast_op_fdiv_q, p);
663 if (isl_ast_node_foreach_ast_op_type(node, &print_macro, &p) < 0)
664 return isl_printer_free(p);
665 return p;
668 static void print_kernel(struct gpu_prog *prog, struct ppcg_kernel *kernel,
669 struct cuda_info *cuda)
671 isl_ctx *ctx = isl_ast_node_get_ctx(kernel->tree);
672 isl_ast_print_options *print_options;
673 isl_printer *p;
675 print_kernel_headers(prog, kernel, cuda);
676 fprintf(cuda->kernel_c, "{\n");
677 print_kernel_iterators(cuda->kernel_c, kernel);
679 p = isl_printer_to_file(ctx, cuda->kernel_c);
680 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
681 p = isl_printer_indent(p, 4);
683 p = print_kernel_vars(p, kernel);
684 p = isl_printer_end_line(p);
685 p = print_macros(kernel->tree, p);
687 print_options = isl_ast_print_options_alloc(ctx);
688 print_options = isl_ast_print_options_set_print_user(print_options,
689 &print_kernel_stmt, NULL);
690 p = isl_ast_node_print(kernel->tree, p, print_options);
691 isl_printer_free(p);
693 fprintf(cuda->kernel_c, "}\n");
/* User data passed to print_host_user while printing the host code. */
struct print_host_user_data {
	/* Destination files for the kernels printed by print_kernel. */
	struct cuda_info *cuda;
	/* The program whose host code is being printed. */
	struct gpu_prog *prog;
};
701 /* Print the user statement of the host code to "p".
703 * In particular, print a block of statements that defines the grid
704 * and the block and then launches the kernel.
706 static __isl_give isl_printer *print_host_user(__isl_take isl_printer *p,
707 __isl_take isl_ast_print_options *print_options,
708 __isl_keep isl_ast_node *node, void *user)
710 isl_id *id;
711 struct ppcg_kernel *kernel;
712 struct print_host_user_data *data;
714 id = isl_ast_node_get_annotation(node);
715 kernel = isl_id_get_user(id);
716 isl_id_free(id);
718 data = (struct print_host_user_data *) user;
720 p = isl_printer_start_line(p);
721 p = isl_printer_print_str(p, "{");
722 p = isl_printer_end_line(p);
723 p = isl_printer_indent(p, 2);
725 p = isl_printer_start_line(p);
726 p = isl_printer_print_str(p, "dim3 k");
727 p = isl_printer_print_int(p, kernel->id);
728 p = isl_printer_print_str(p, "_dimBlock");
729 print_reverse_list(isl_printer_get_file(p),
730 kernel->n_block, kernel->block_dim);
731 p = isl_printer_print_str(p, ";");
732 p = isl_printer_end_line(p);
734 p = print_grid(p, kernel);
736 p = isl_printer_start_line(p);
737 p = isl_printer_print_str(p, "kernel");
738 p = isl_printer_print_int(p, kernel->id);
739 p = isl_printer_print_str(p, " <<<k");
740 p = isl_printer_print_int(p, kernel->id);
741 p = isl_printer_print_str(p, "_dimGrid, k");
742 p = isl_printer_print_int(p, kernel->id);
743 p = isl_printer_print_str(p, "_dimBlock>>> (");
744 p = print_kernel_arguments(p, data->prog, kernel, 0);
745 p = isl_printer_print_str(p, ");");
746 p = isl_printer_end_line(p);
748 p = isl_printer_start_line(p);
749 p = isl_printer_print_str(p, "cudaCheckKernel();");
750 p = isl_printer_end_line(p);
752 p = isl_printer_indent(p, -2);
753 p = isl_printer_start_line(p);
754 p = isl_printer_print_str(p, "}");
755 p = isl_printer_end_line(p);
757 p = isl_printer_start_line(p);
758 p = isl_printer_end_line(p);
760 print_kernel(data->prog, kernel, data->cuda);
762 isl_ast_print_options_free(print_options);
764 return p;
767 static __isl_give isl_printer *print_host_code(__isl_take isl_printer *p,
768 struct gpu_prog *prog, __isl_keep isl_ast_node *tree,
769 struct cuda_info *cuda)
771 isl_ast_print_options *print_options;
772 isl_ctx *ctx = isl_ast_node_get_ctx(tree);
773 struct print_host_user_data data = { cuda, prog };
775 print_options = isl_ast_print_options_alloc(ctx);
776 print_options = isl_ast_print_options_set_print_user(print_options,
777 &print_host_user, &data);
779 p = print_macros(tree, p);
780 p = isl_ast_node_print(tree, p, print_options);
782 return p;
785 /* For each array that needs to be copied out (based on prog->copy_out),
786 * copy the contents back from the GPU to the host.
788 * If any element of a given array appears in prog->copy_out, then its
789 * entire extent is in prog->copy_out. The bounds on this extent have
790 * been precomputed in extract_array_info and are used in print_array_size.
792 static __isl_give isl_printer *copy_arrays_from_device(
793 __isl_take isl_printer *p, struct gpu_prog *prog)
795 int i;
796 isl_union_set *copy_out;
797 copy_out = isl_union_set_copy(prog->copy_out);
799 for (i = 0; i < prog->n_array; ++i) {
800 isl_space *dim;
801 isl_set *copy_out_i;
802 int empty;
804 dim = isl_space_copy(prog->array[i].dim);
805 copy_out_i = isl_union_set_extract_set(copy_out, dim);
806 empty = isl_set_fast_is_empty(copy_out_i);
807 isl_set_free(copy_out_i);
808 if (empty)
809 continue;
811 p = isl_printer_print_str(p, "cudaCheckReturn(cudaMemcpy(");
812 if (gpu_array_is_scalar(&prog->array[i]))
813 p = isl_printer_print_str(p, "&");
814 p = isl_printer_print_str(p, prog->array[i].name);
815 p = isl_printer_print_str(p, ", dev_");
816 p = isl_printer_print_str(p, prog->array[i].name);
817 p = isl_printer_print_str(p, ", ");
818 p = print_array_size(p, &prog->array[i]);
819 p = isl_printer_print_str(p, ", cudaMemcpyDeviceToHost));");
820 p = isl_printer_end_line(p);
823 isl_union_set_free(copy_out);
824 p = isl_printer_start_line(p);
825 p = isl_printer_end_line(p);
826 return p;
829 static __isl_give isl_printer *free_device_arrays(__isl_take isl_printer *p,
830 struct gpu_prog *prog)
832 int i;
834 for (i = 0; i < prog->n_array; ++i) {
835 if (gpu_array_is_read_only_scalar(&prog->array[i]))
836 continue;
837 p = isl_printer_print_str(p, "cudaCheckReturn(cudaFree(dev_");
838 p = isl_printer_print_str(p, prog->array[i].name);
839 p = isl_printer_print_str(p, "));");
840 p = isl_printer_end_line(p);
843 return p;
846 int generate_cuda(isl_ctx *ctx, struct ppcg_scop *scop,
847 struct ppcg_options *options, const char *input)
849 struct cuda_info cuda;
850 struct gpu_prog *prog;
851 isl_ast_node *tree;
852 isl_printer *p;
854 if (!scop)
855 return -1;
857 prog = gpu_prog_alloc(ctx, scop);
859 tree = generate_gpu(ctx, prog, options);
861 cuda.start = scop->start;
862 cuda.end = scop->end;
863 cuda_open_files(&cuda, input);
865 p = isl_printer_to_file(ctx, cuda.host_c);
866 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
867 p = ppcg_print_exposed_declarations(p, scop);
868 p = ppcg_start_block(p);
870 p = print_cuda_macros(p);
872 p = declare_device_arrays(p, prog);
873 p = allocate_device_arrays(p, prog);
874 p = copy_arrays_to_device(p, prog);
876 p = print_host_code(p, prog, tree, &cuda);
877 isl_ast_node_free(tree);
879 p = copy_arrays_from_device(p, prog);
880 p = free_device_arrays(p, prog);
882 p = ppcg_end_block(p);
883 isl_printer_free(p);
885 cuda_close_files(&cuda);
887 gpu_prog_free(prog);
889 return 0;