gpu.c: set_last_shared: also initialize group->last_shared in absence of tile
[ppcg.git] / cuda.c
blobea035643249b7361783fd1e561b36a5eefb606f3
1 /*
2 * Copyright 2012 Ecole Normale Superieure
4 * Use of this software is governed by the GNU LGPLv2.1 license
6 * Written by Sven Verdoolaege,
7 * Ecole Normale Superieure, 45 rue d’Ulm, 75230 Paris, France
8 */
10 #include <isl/aff.h>
11 #include <isl/ast.h>
13 #include "cuda_common.h"
14 #include "cuda.h"
15 #include "gpu.h"
16 #include "pet_printer.h"
17 #include "print.h"
18 #include "schedule.h"
20 static __isl_give isl_printer *print_cuda_macros(__isl_take isl_printer *p)
22 const char *macros =
23 "#define cudaCheckReturn(ret) assert((ret) == cudaSuccess)\n"
24 "#define cudaCheckKernel()"
25 " assert(cudaGetLastError() == cudaSuccess)\n\n";
26 p = isl_printer_print_str(p, macros);
27 return p;
30 static __isl_give isl_printer *print_array_size(__isl_take isl_printer *prn,
31 struct gpu_array_info *array)
33 int i;
35 for (i = 0; i < array->n_index; ++i) {
36 prn = isl_printer_print_str(prn, "(");
37 prn = isl_printer_print_pw_aff(prn, array->bound[i]);
38 prn = isl_printer_print_str(prn, ") * ");
40 prn = isl_printer_print_str(prn, "sizeof(");
41 prn = isl_printer_print_str(prn, array->type);
42 prn = isl_printer_print_str(prn, ")");
44 return prn;
47 static __isl_give isl_printer *declare_device_arrays(__isl_take isl_printer *p,
48 struct gpu_prog *prog)
50 int i;
52 for (i = 0; i < prog->n_array; ++i) {
53 if (gpu_array_is_read_only_scalar(&prog->array[i]))
54 continue;
55 p = isl_printer_start_line(p);
56 p = isl_printer_print_str(p, prog->array[i].type);
57 p = isl_printer_print_str(p, " *dev_");
58 p = isl_printer_print_str(p, prog->array[i].name);
59 p = isl_printer_print_str(p, ";");
60 p = isl_printer_end_line(p);
62 p = isl_printer_start_line(p);
63 p = isl_printer_end_line(p);
64 return p;
67 static __isl_give isl_printer *allocate_device_arrays(
68 __isl_take isl_printer *p, struct gpu_prog *prog)
70 int i;
72 for (i = 0; i < prog->n_array; ++i) {
73 if (gpu_array_is_read_only_scalar(&prog->array[i]))
74 continue;
75 p = isl_printer_start_line(p);
76 p = isl_printer_print_str(p,
77 "cudaCheckReturn(cudaMalloc((void **) &dev_");
78 p = isl_printer_print_str(p, prog->array[i].name);
79 p = isl_printer_print_str(p, ", ");
80 p = print_array_size(p, &prog->array[i]);
81 p = isl_printer_print_str(p, "));");
82 p = isl_printer_end_line(p);
84 p = isl_printer_start_line(p);
85 p = isl_printer_end_line(p);
86 return p;
89 static __isl_give isl_printer *copy_arrays_to_device(__isl_take isl_printer *p,
90 struct gpu_prog *prog)
92 int i;
94 for (i = 0; i < prog->n_array; ++i) {
95 isl_space *dim;
96 isl_set *read_i;
97 int empty;
99 if (gpu_array_is_read_only_scalar(&prog->array[i]))
100 continue;
102 dim = isl_space_copy(prog->array[i].dim);
103 read_i = isl_union_set_extract_set(prog->copy_in, dim);
104 empty = isl_set_fast_is_empty(read_i);
105 isl_set_free(read_i);
106 if (empty)
107 continue;
109 p = isl_printer_print_str(p, "cudaCheckReturn(cudaMemcpy(dev_");
110 p = isl_printer_print_str(p, prog->array[i].name);
111 p = isl_printer_print_str(p, ", ");
113 if (gpu_array_is_scalar(&prog->array[i]))
114 p = isl_printer_print_str(p, "&");
115 p = isl_printer_print_str(p, prog->array[i].name);
116 p = isl_printer_print_str(p, ", ");
118 p = print_array_size(p, &prog->array[i]);
119 p = isl_printer_print_str(p, ", cudaMemcpyHostToDevice));");
120 p = isl_printer_end_line(p);
122 p = isl_printer_start_line(p);
123 p = isl_printer_end_line(p);
124 return p;
/* Print the "len" elements of "list" to "out" in reverse order,
 * separated by ", " and enclosed in parentheses.
 * Nothing is printed if "len" is zero.
 */
static void print_reverse_list(FILE *out, int len, int *list)
{
	int pos;

	if (len == 0)
		return;

	fputs("(", out);
	for (pos = len - 1; pos >= 0; --pos) {
		fprintf(out, "%d", list[pos]);
		if (pos > 0)
			fputs(", ", out);
	}
	fputs(")", out);
}
143 /* Print the effective grid size as a list of the sizes in each
144 * dimension, from innermost to outermost.
146 static __isl_give isl_printer *print_grid_size(__isl_take isl_printer *p,
147 struct ppcg_kernel *kernel)
149 int i;
150 int dim;
152 dim = isl_multi_pw_aff_dim(kernel->grid_size, isl_dim_set);
153 if (dim == 0)
154 return p;
156 p = isl_printer_print_str(p, "(");
157 for (i = dim - 1; i >= 0; --i) {
158 isl_pw_aff *bound;
160 bound = isl_multi_pw_aff_get_pw_aff(kernel->grid_size, i);
161 p = isl_printer_print_pw_aff(p, bound);
162 isl_pw_aff_free(bound);
164 if (i > 0)
165 p = isl_printer_print_str(p, ", ");
168 p = isl_printer_print_str(p, ")");
170 return p;
173 /* Print the grid definition.
175 static __isl_give isl_printer *print_grid(__isl_take isl_printer *p,
176 struct ppcg_kernel *kernel)
178 p = isl_printer_start_line(p);
179 p = isl_printer_print_str(p, "dim3 k");
180 p = isl_printer_print_int(p, kernel->id);
181 p = isl_printer_print_str(p, "_dimGrid");
182 p = print_grid_size(p, kernel);
183 p = isl_printer_print_str(p, ";");
184 p = isl_printer_end_line(p);
186 return p;
189 /* Print the arguments to a kernel declaration or call. If "types" is set,
190 * then print a declaration (including the types of the arguments).
192 * The arguments are printed in the following order
193 * - the arrays accessed by the kernel
194 * - the parameters
195 * - the host loop iterators
197 static __isl_give isl_printer *print_kernel_arguments(__isl_take isl_printer *p,
198 struct gpu_prog *prog, struct ppcg_kernel *kernel, int types)
200 int i, n;
201 int first = 1;
202 unsigned nparam;
203 isl_space *space;
204 const char *type;
206 for (i = 0; i < prog->n_array; ++i) {
207 isl_set *arr;
208 int empty;
210 space = isl_space_copy(prog->array[i].dim);
211 arr = isl_union_set_extract_set(kernel->arrays, space);
212 empty = isl_set_fast_is_empty(arr);
213 isl_set_free(arr);
214 if (empty)
215 continue;
217 if (!first)
218 p = isl_printer_print_str(p, ", ");
220 if (types) {
221 p = isl_printer_print_str(p, prog->array[i].type);
222 p = isl_printer_print_str(p, " ");
225 if (gpu_array_is_read_only_scalar(&prog->array[i])) {
226 p = isl_printer_print_str(p, prog->array[i].name);
227 } else {
228 if (types)
229 p = isl_printer_print_str(p, "*");
230 else
231 p = isl_printer_print_str(p, "dev_");
232 p = isl_printer_print_str(p, prog->array[i].name);
235 first = 0;
238 space = isl_union_set_get_space(kernel->arrays);
239 nparam = isl_space_dim(space, isl_dim_param);
240 for (i = 0; i < nparam; ++i) {
241 const char *name;
243 name = isl_space_get_dim_name(space, isl_dim_param, i);
245 if (!first)
246 p = isl_printer_print_str(p, ", ");
247 if (types)
248 p = isl_printer_print_str(p, "int ");
249 p = isl_printer_print_str(p, name);
251 first = 0;
253 isl_space_free(space);
255 n = isl_space_dim(kernel->space, isl_dim_set);
256 type = isl_options_get_ast_iterator_type(prog->ctx);
257 for (i = 0; i < n; ++i) {
258 const char *name;
259 isl_id *id;
261 if (!first)
262 p = isl_printer_print_str(p, ", ");
263 name = isl_space_get_dim_name(kernel->space, isl_dim_set, i);
264 if (types) {
265 p = isl_printer_print_str(p, type);
266 p = isl_printer_print_str(p, " ");
268 p = isl_printer_print_str(p, name);
270 first = 0;
273 return p;
276 /* Print the header of the given kernel.
278 static __isl_give isl_printer *print_kernel_header(__isl_take isl_printer *p,
279 struct gpu_prog *prog, struct ppcg_kernel *kernel)
281 p = isl_printer_start_line(p);
282 p = isl_printer_print_str(p, "__global__ void kernel");
283 p = isl_printer_print_int(p, kernel->id);
284 p = isl_printer_print_str(p, "(");
285 p = print_kernel_arguments(p, prog, kernel, 1);
286 p = isl_printer_print_str(p, ")");
288 return p;
291 /* Print the header of the given kernel to both gen->cuda.kernel_h
292 * and gen->cuda.kernel_c.
294 static void print_kernel_headers(struct gpu_prog *prog,
295 struct ppcg_kernel *kernel, struct cuda_info *cuda)
297 isl_printer *p;
299 p = isl_printer_to_file(prog->ctx, cuda->kernel_h);
300 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
301 p = print_kernel_header(p, prog, kernel);
302 p = isl_printer_print_str(p, ";");
303 p = isl_printer_end_line(p);
304 isl_printer_free(p);
306 p = isl_printer_to_file(prog->ctx, cuda->kernel_c);
307 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
308 p = print_kernel_header(p, prog, kernel);
309 p = isl_printer_end_line(p);
310 isl_printer_free(p);
/* Print an empty string padded to a field width of "indent" to "dst".
 */
static void print_indent(FILE *dst, int indent)
{
	const char *nothing = "";

	fprintf(dst, "%*s", indent, nothing);
}
318 static void print_kernel_iterators(FILE *out, struct ppcg_kernel *kernel)
320 int i;
321 const char *block_dims[] = { "blockIdx.x", "blockIdx.y" };
322 const char *thread_dims[] = { "threadIdx.x", "threadIdx.y",
323 "threadIdx.z" };
325 if (kernel->n_grid > 0) {
326 print_indent(out, 4);
327 fprintf(out, "int ");
328 for (i = 0; i < kernel->n_grid; ++i) {
329 if (i)
330 fprintf(out, ", ");
331 fprintf(out, "b%d = %s",
332 i, block_dims[kernel->n_grid - 1 - i]);
334 fprintf(out, ";\n");
337 if (kernel->n_block > 0) {
338 print_indent(out, 4);
339 fprintf(out, "int ");
340 for (i = 0; i < kernel->n_block; ++i) {
341 if (i)
342 fprintf(out, ", ");
343 fprintf(out, "t%d = %s",
344 i, thread_dims[kernel->n_block - 1 - i]);
346 fprintf(out, ";\n");
350 static void print_kernel_var(FILE *out, struct ppcg_kernel_var *var)
352 int j;
353 isl_int v;
355 print_indent(out, 4);
356 if (var->type == ppcg_access_shared)
357 fprintf(out, "__shared__ ");
358 fprintf(out, "%s %s", var->array->type, var->name);
359 isl_int_init(v);
360 for (j = 0; j < var->array->n_index; ++j) {
361 fprintf(out, "[");
362 isl_vec_get_element(var->size, j, &v);
363 isl_int_print(out, v, 0);
364 fprintf(out, "]");
366 isl_int_clear(v);
367 fprintf(out, ";\n");
370 static void print_kernel_vars(FILE *out, struct ppcg_kernel *kernel)
372 int i;
374 for (i = 0; i < kernel->n_var; ++i)
375 print_kernel_var(out, &kernel->var[i]);
378 /* Print an access to the element in the private/shared memory copy
379 * described by "stmt". The index of the copy is recorded in
380 * stmt->local_index as a "call" to the array.
382 static __isl_give isl_printer *stmt_print_local_index(__isl_take isl_printer *p,
383 struct ppcg_kernel_stmt *stmt)
385 int i;
386 isl_ast_expr *expr;
387 struct gpu_array_info *array = stmt->u.c.array;
389 expr = isl_ast_expr_get_op_arg(stmt->u.c.local_index, 0);
390 p = isl_printer_print_ast_expr(p, expr);
391 isl_ast_expr_free(expr);
393 for (i = 0; i < array->n_index; ++i) {
394 expr = isl_ast_expr_get_op_arg(stmt->u.c.local_index, 1 + i);
396 p = isl_printer_print_str(p, "[");
397 p = isl_printer_print_ast_expr(p, expr);
398 p = isl_printer_print_str(p, "]");
400 isl_ast_expr_free(expr);
403 return p;
406 /* Print an access to the element in the global memory copy
407 * described by "stmt". The index of the copy is recorded in
408 * stmt->index as a "call" to the array.
410 * The copy in global memory has been linearized, so we need to take
411 * the array size into account.
413 static __isl_give isl_printer *stmt_print_global_index(
414 __isl_take isl_printer *p, struct ppcg_kernel_stmt *stmt)
416 int i;
417 struct gpu_array_info *array = stmt->u.c.array;
418 isl_pw_aff_list *bound = stmt->u.c.local_array->bound;
420 if (gpu_array_is_scalar(array)) {
421 if (!array->read_only)
422 p = isl_printer_print_str(p, "*");
423 p = isl_printer_print_str(p, array->name);
424 return p;
427 p = isl_printer_print_str(p, array->name);
428 p = isl_printer_print_str(p, "[");
429 for (i = 0; i + 1 < array->n_index; ++i)
430 p = isl_printer_print_str(p, "(");
431 for (i = 0; i < array->n_index; ++i) {
432 isl_ast_expr *expr;
433 expr = isl_ast_expr_get_op_arg(stmt->u.c.index, 1 + i);
434 if (i) {
435 isl_pw_aff *bound_i;
436 bound_i = isl_pw_aff_list_get_pw_aff(bound, i);
437 p = isl_printer_print_str(p, ") * (");
438 p = isl_printer_print_pw_aff(p, bound_i);
439 p = isl_printer_print_str(p, ") + ");
440 isl_pw_aff_free(bound_i);
442 p = isl_printer_print_ast_expr(p, expr);
443 isl_ast_expr_free(expr);
445 p = isl_printer_print_str(p, "]");
447 return p;
450 /* Print a copy statement.
452 * A read copy statement is printed as
454 * local = global;
456 * while a write copy statement is printed as
458 * global = local;
460 static __isl_give isl_printer *print_copy(__isl_take isl_printer *p,
461 struct ppcg_kernel_stmt *stmt)
463 p = isl_printer_start_line(p);
464 if (stmt->u.c.read) {
465 p = stmt_print_local_index(p, stmt);
466 p = isl_printer_print_str(p, " = ");
467 p = stmt_print_global_index(p, stmt);
468 } else {
469 p = stmt_print_global_index(p, stmt);
470 p = isl_printer_print_str(p, " = ");
471 p = stmt_print_local_index(p, stmt);
473 p = isl_printer_print_str(p, ";");
474 p = isl_printer_end_line(p);
476 return p;
479 /* Print a sync statement.
481 static __isl_give isl_printer *print_sync(__isl_take isl_printer *p,
482 struct ppcg_kernel_stmt *stmt)
484 p = isl_printer_start_line(p);
485 p = isl_printer_print_str(p, "__syncthreads();");
486 p = isl_printer_end_line(p);
488 return p;
491 /* Print an access based on the information in "access".
492 * If this an access to global memory, then the index expression
493 * is linearized.
495 * If access->array is NULL, then we are
496 * accessing an iterator in the original program.
498 static __isl_give isl_printer *print_access(__isl_take isl_printer *p,
499 struct ppcg_kernel_access *access)
501 int i;
502 unsigned n_index;
503 struct gpu_array_info *array;
504 isl_pw_aff_list *bound;
506 array = access->array;
507 bound = array ? access->local_array->bound : NULL;
508 if (!array)
509 p = isl_printer_print_str(p, "(");
510 else {
511 if (access->type == ppcg_access_global &&
512 gpu_array_is_scalar(array) && !array->read_only)
513 p = isl_printer_print_str(p, "*");
514 p = isl_printer_print_str(p, access->local_name);
515 if (gpu_array_is_scalar(array))
516 return p;
517 p = isl_printer_print_str(p, "[");
520 n_index = isl_ast_expr_list_n_ast_expr(access->index);
521 if (access->type == ppcg_access_global)
522 for (i = 0; i + 1 < n_index; ++i)
523 p = isl_printer_print_str(p, "(");
525 for (i = 0; i < n_index; ++i) {
526 isl_ast_expr *index;
528 index = isl_ast_expr_list_get_ast_expr(access->index, i);
529 if (array && i) {
530 if (access->type == ppcg_access_global) {
531 isl_pw_aff *bound_i;
532 bound_i = isl_pw_aff_list_get_pw_aff(bound, i);
533 p = isl_printer_print_str(p, ") * (");
534 p = isl_printer_print_pw_aff(p, bound_i);
535 p = isl_printer_print_str(p, ") + ");
536 isl_pw_aff_free(bound_i);
537 } else
538 p = isl_printer_print_str(p, "][");
540 p = isl_printer_print_ast_expr(p, index);
541 isl_ast_expr_free(index);
543 if (!array)
544 p = isl_printer_print_str(p, ")");
545 else
546 p = isl_printer_print_str(p, "]");
548 return p;
551 struct cuda_access_print_info {
552 int i;
553 struct ppcg_kernel_stmt *stmt;
556 /* To print the cuda accesses we walk the list of cuda accesses simultaneously
557 * with the pet printer. This means that whenever the pet printer prints a
558 * pet access expression we have the corresponding cuda access available and can
559 * print the modified access.
561 static __isl_give isl_printer *print_cuda_access(__isl_take isl_printer *p,
562 struct pet_expr *expr, void *usr)
564 struct cuda_access_print_info *info =
565 (struct cuda_access_print_info *) usr;
567 p = print_access(p, &info->stmt->u.d.access[info->i]);
568 info->i++;
570 return p;
573 static __isl_give isl_printer *print_stmt_body(__isl_take isl_printer *p,
574 struct ppcg_kernel_stmt *stmt)
576 struct cuda_access_print_info info;
578 info.i = 0;
579 info.stmt = stmt;
581 p = isl_printer_start_line(p);
582 p = print_pet_expr(p, stmt->u.d.stmt->body, &print_cuda_access, &info);
583 p = isl_printer_print_str(p, ";");
584 p = isl_printer_end_line(p);
586 return p;
589 /* This function is called for each user statement in the AST,
590 * i.e., for each kernel body statement, copy statement or sync statement.
592 static __isl_give isl_printer *print_kernel_stmt(__isl_take isl_printer *p,
593 __isl_take isl_ast_print_options *print_options,
594 __isl_keep isl_ast_node *node, void *user)
596 isl_id *id;
597 struct ppcg_kernel_stmt *stmt;
599 id = isl_ast_node_get_annotation(node);
600 stmt = isl_id_get_user(id);
601 isl_id_free(id);
603 isl_ast_print_options_free(print_options);
605 switch (stmt->type) {
606 case ppcg_kernel_copy:
607 return print_copy(p, stmt);
608 case ppcg_kernel_sync:
609 return print_sync(p, stmt);
610 case ppcg_kernel_domain:
611 return print_stmt_body(p, stmt);
614 return p;
617 static int print_macro(enum isl_ast_op_type type, void *user)
619 isl_printer **p = user;
621 if (type == isl_ast_op_fdiv_q)
622 return 0;
624 *p = isl_ast_op_type_print_macro(type, *p);
626 return 0;
629 /* Print the required macros for "node", including one for floord.
630 * We always print a macro for floord as it may also appear in the statements.
632 static __isl_give isl_printer *print_macros(
633 __isl_keep isl_ast_node *node, __isl_take isl_printer *p)
635 p = isl_ast_op_type_print_macro(isl_ast_op_fdiv_q, p);
636 if (isl_ast_node_foreach_ast_op_type(node, &print_macro, &p) < 0)
637 return isl_printer_free(p);
638 return p;
641 static void print_kernel(struct gpu_prog *prog, struct ppcg_kernel *kernel,
642 struct cuda_info *cuda)
644 isl_ctx *ctx = isl_ast_node_get_ctx(kernel->tree);
645 isl_ast_print_options *print_options;
646 isl_printer *p;
648 print_kernel_headers(prog, kernel, cuda);
649 fprintf(cuda->kernel_c, "{\n");
650 print_kernel_iterators(cuda->kernel_c, kernel);
651 print_kernel_vars(cuda->kernel_c, kernel);
652 fprintf(cuda->kernel_c, "\n");
654 print_options = isl_ast_print_options_alloc(ctx);
655 print_options = isl_ast_print_options_set_print_user(print_options,
656 &print_kernel_stmt, NULL);
658 p = isl_printer_to_file(ctx, cuda->kernel_c);
659 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
660 p = isl_printer_indent(p, 4);
661 p = print_macros(kernel->tree, p);
662 p = isl_ast_node_print(kernel->tree, p, print_options);
663 isl_printer_free(p);
665 fprintf(cuda->kernel_c, "}\n");
668 struct print_host_user_data {
669 struct cuda_info *cuda;
670 struct gpu_prog *prog;
673 /* Print the user statement of the host code to "p".
675 * In particular, print a block of statements that defines the grid
676 * and the block and then launches the kernel.
678 static __isl_give isl_printer *print_host_user(__isl_take isl_printer *p,
679 __isl_take isl_ast_print_options *print_options,
680 __isl_keep isl_ast_node *node, void *user)
682 isl_id *id;
683 struct ppcg_kernel *kernel;
684 struct print_host_user_data *data;
686 id = isl_ast_node_get_annotation(node);
687 kernel = isl_id_get_user(id);
688 isl_id_free(id);
690 data = (struct print_host_user_data *) user;
692 p = isl_printer_start_line(p);
693 p = isl_printer_print_str(p, "{");
694 p = isl_printer_end_line(p);
695 p = isl_printer_indent(p, 2);
697 p = isl_printer_start_line(p);
698 p = isl_printer_print_str(p, "dim3 k");
699 p = isl_printer_print_int(p, kernel->id);
700 p = isl_printer_print_str(p, "_dimBlock");
701 print_reverse_list(isl_printer_get_file(p),
702 kernel->n_block, kernel->block_dim);
703 p = isl_printer_print_str(p, ";");
704 p = isl_printer_end_line(p);
706 p = print_grid(p, kernel);
708 p = isl_printer_start_line(p);
709 p = isl_printer_print_str(p, "kernel");
710 p = isl_printer_print_int(p, kernel->id);
711 p = isl_printer_print_str(p, " <<<k");
712 p = isl_printer_print_int(p, kernel->id);
713 p = isl_printer_print_str(p, "_dimGrid, k");
714 p = isl_printer_print_int(p, kernel->id);
715 p = isl_printer_print_str(p, "_dimBlock>>> (");
716 p = print_kernel_arguments(p, data->prog, kernel, 0);
717 p = isl_printer_print_str(p, ");");
718 p = isl_printer_end_line(p);
720 p = isl_printer_start_line(p);
721 p = isl_printer_print_str(p, "cudaCheckKernel();");
722 p = isl_printer_end_line(p);
724 p = isl_printer_indent(p, -2);
725 p = isl_printer_start_line(p);
726 p = isl_printer_print_str(p, "}");
727 p = isl_printer_end_line(p);
729 p = isl_printer_start_line(p);
730 p = isl_printer_end_line(p);
732 print_kernel(data->prog, kernel, data->cuda);
734 isl_ast_print_options_free(print_options);
736 return p;
739 static __isl_give isl_printer *print_host_code(__isl_take isl_printer *p,
740 struct gpu_prog *prog, __isl_keep isl_ast_node *tree,
741 struct cuda_info *cuda)
743 isl_ast_print_options *print_options;
744 isl_ctx *ctx = isl_ast_node_get_ctx(tree);
745 struct print_host_user_data data = { cuda, prog };
747 print_options = isl_ast_print_options_alloc(ctx);
748 print_options = isl_ast_print_options_set_print_user(print_options,
749 &print_host_user, &data);
751 p = print_macros(tree, p);
752 p = isl_ast_node_print(tree, p, print_options);
754 return p;
757 /* For each array that is written anywhere in the gpu_prog,
758 * copy the contents back from the GPU to the host.
760 * Arrays that are not visible outside the corresponding scop
761 * do not need to be copied back.
763 static __isl_give isl_printer *copy_arrays_from_device(
764 __isl_take isl_printer *p, struct gpu_prog *prog)
766 int i;
767 isl_union_set *write;
768 write = isl_union_map_range(isl_union_map_copy(prog->write));
770 for (i = 0; i < prog->n_array; ++i) {
771 isl_space *dim;
772 isl_set *write_i;
773 int empty;
775 if (prog->array[i].local)
776 continue;
778 dim = isl_space_copy(prog->array[i].dim);
779 write_i = isl_union_set_extract_set(write, dim);
780 empty = isl_set_fast_is_empty(write_i);
781 isl_set_free(write_i);
782 if (empty)
783 continue;
785 p = isl_printer_print_str(p, "cudaCheckReturn(cudaMemcpy(");
786 if (gpu_array_is_scalar(&prog->array[i]))
787 p = isl_printer_print_str(p, "&");
788 p = isl_printer_print_str(p, prog->array[i].name);
789 p = isl_printer_print_str(p, ", dev_");
790 p = isl_printer_print_str(p, prog->array[i].name);
791 p = isl_printer_print_str(p, ", ");
792 p = print_array_size(p, &prog->array[i]);
793 p = isl_printer_print_str(p, ", cudaMemcpyDeviceToHost));");
794 p = isl_printer_end_line(p);
797 isl_union_set_free(write);
798 p = isl_printer_start_line(p);
799 p = isl_printer_end_line(p);
800 return p;
803 static __isl_give isl_printer *free_device_arrays(__isl_take isl_printer *p,
804 struct gpu_prog *prog)
806 int i;
808 for (i = 0; i < prog->n_array; ++i) {
809 if (gpu_array_is_read_only_scalar(&prog->array[i]))
810 continue;
811 p = isl_printer_print_str(p, "cudaCheckReturn(cudaFree(dev_");
812 p = isl_printer_print_str(p, prog->array[i].name);
813 p = isl_printer_print_str(p, "));");
814 p = isl_printer_end_line(p);
817 return p;
820 int generate_cuda(isl_ctx *ctx, struct ppcg_scop *scop,
821 struct ppcg_options *options, const char *input)
823 struct cuda_info cuda;
824 struct gpu_prog *prog;
825 isl_ast_node *tree;
826 isl_printer *p;
828 if (!scop)
829 return -1;
831 scop->context = add_context_from_str(scop->context, options->ctx);
833 prog = gpu_prog_alloc(ctx, scop);
835 tree = generate_gpu(ctx, prog, options);
837 cuda_open_files(&cuda, input);
839 p = isl_printer_to_file(ctx, cuda.host_c);
840 p = isl_printer_set_output_format(p, ISL_FORMAT_C);
841 p = ppcg_print_exposed_declarations(p, scop);
842 p = ppcg_start_block(p);
844 p = print_cuda_macros(p);
846 p = declare_device_arrays(p, prog);
847 p = allocate_device_arrays(p, prog);
848 p = copy_arrays_to_device(p, prog);
850 p = print_host_code(p, prog, tree, &cuda);
851 isl_ast_node_free(tree);
853 p = copy_arrays_from_device(p, prog);
854 p = free_device_arrays(p, prog);
856 p = ppcg_end_block(p);
857 isl_printer_free(p);
859 cuda_close_files(&cuda);
861 gpu_prog_free(prog);
863 return 0;