missing quote in coerce
[woropt.git] / cuda-fft.lisp
blob8fb383de0b7037d2d65ce7cb9de39cc348c68c4d
1 ;; Calling NVIDIA's cuFFT from SBCL
2 ;; I use debian5 64bit
3 ;; the graphics card is:
4 ;; 02:00.0 VGA compatible controller: nVidia Corporation G94 [GeForce 9600 GT] (rev a1)
5 ;; and the Cuda Toolkit 3.1
6 ;; http://developer.nvidia.com/object/cuda_3_1_downloads.html
7 ;; 2010-08-12 kielhorn.martin@googlemail.com
;; Package for the cuFFT bindings.  Only the transform entry points are
;; exported; the alien routines and helpers defined in this file are not
;; part of the export list.
(defpackage #:cuda-fft
  (:use #:cl #:sb-alien #:sb-c-call)
  (:export
   ;; two-dimensional transforms
   #:ft2
   #:ift2
   #:ft2-csf
   #:ift2-csf
   ;; three-dimensional transforms
   #:ft3
   #:ift3
   #:ft3-csf
   #:ift3-csf))
(in-package #:cuda-fft)

;; Favor debuggability and full run-time checks over raw speed.
(declaim (optimize (debug 3) (safety 3) (speed 2)))

;; Load the shared libraries that provide the foreign symbols referenced
;; by the alien routines below.  The runtime library is loaded first, then
;; cuFFT, exactly in this order.
(dolist (library '("/usr/local/cuda/lib64/libcudart.so"
                   "/usr/local/cuda/lib64/libcufft.so"))
  (load-shared-object library))
;; Alien type aliases for the CUDA runtime / cuFFT C interfaces.

(define-alien-type cufft-handle unsigned-int)

;; A device address is returned by cudaMalloc through a `void **', i.e. the
;; callee stores a full C pointer.  The slot receiving it must therefore be
;; pointer-sized: `unsigned-int' is only 32 bits wide, so on a 64-bit host
;; the address would be truncated and the store would overrun the slot.
;; `unsigned-long' has pointer width on both 32- and 64-bit Linux and the
;; resulting integer can still be handed to SB-SYS:INT-SAP.
(define-alien-type cuda-device-ptr unsigned-long)

(define-alien-type cuda-error int)   ; enum, 0 is success
(define-alien-type cufft-result int) ; enum, 0 is success
(define-alien-type cufft-type int)   ; enum, c2c is #x29

;; Copy direction argument of CUDA-MEMCPY.  The members are numbered from
;; 0 in the order listed, so callers pass one of these symbols.
(define-alien-type cuda-memcpy-kind
  (enum nil
        host->host
        host->device
        device->host
        device->device))
;; Integer values of the C enumerators / macros used by the bindings.

(defconstant +host->device+ 1
  "Integer value of the HOST->DEVICE member of CUDA-MEMCPY-KIND.")

(defconstant +device->host+ 2
  "Integer value of the DEVICE->HOST member of CUDA-MEMCPY-KIND.")

(defconstant +cufft-c2c+ #x29
  "CUFFT-TYPE code for complex single-float transforms.")

(defconstant +cufft-z2z+ #x69
  "CUFFT-TYPE code for complex double-float transforms.")

(defconstant +cufft-forward+ -1
  "Direction argument of CUFFT-EXEC-C2C selecting the forward transform.")

(defconstant +cufft-inverse+ 1
  "Direction argument of CUFFT-EXEC-C2C selecting the inverse transform.")

(define-alien-type size-t unsigned-long)

;; sb-alien has no complex type, so a pointer to complex data is declared
;; as a pointer to single-float.
(define-alien-type cufft-complex single-float)
;; Binding for the C function "cudaMalloc".  DEVICE-POINTER is an :out
;; argument, so the Lisp function CUDA-MALLOC takes only SIZE and returns
;; two values: the CUDA-ERROR status and the CUDA-DEVICE-PTR stored by the
;; callee.
(define-alien-routine ("cudaMalloc" cuda-malloc) cuda-error
  (device-pointer cuda-device-ptr :out)
  (size size-t))
;; Binding for the C function "cufftPlan2d".  PLAN is an :out argument:
;; CUFFT-PLAN-2D is called with NX, NY and TYPE and returns the
;; CUFFT-RESULT status followed by the new CUFFT-HANDLE.
(define-alien-routine ("cufftPlan2d" cufft-plan-2d) cufft-result
  (plan cufft-handle :out)
  (nx int)
  (ny int)
  (type cufft-type))
;; Binding for the C function "cufftPlan3d".  PLAN is an :out argument:
;; CUFFT-PLAN-3D is called with NX, NY, NZ and TYPE and returns the
;; CUFFT-RESULT status followed by the new CUFFT-HANDLE.
(define-alien-routine ("cufftPlan3d" cufft-plan-3d) cufft-result
  (plan cufft-handle :out)
  (nx int)
  (ny int)
  (nz int)
  (type cufft-type))
;; Binding for the C function "cufftExecC2C".  Takes a plan handle, input
;; and output pointers to CUFFT-COMPLEX data and an integer DIRECTION
;; (callers in this file pass +CUFFT-FORWARD+ or +CUFFT-INVERSE+); returns
;; the CUFFT-RESULT status.
(define-alien-routine ("cufftExecC2C" cufft-exec-c2c) cufft-result
  (plan cufft-handle)
  (in-data (* cufft-complex))
  (out-data (* cufft-complex))
  (direction int))
;; Binding for the C function "cufftDestroy".  Takes the plan handle by
;; value and returns the CUFFT-RESULT status.
(define-alien-routine ("cufftDestroy" cufft-destroy) cufft-result
  (plan cufft-handle))
(defun cu-plan (x y &optional z)
  "Create a cuFFT plan for complex single-float (C2C) transforms and
return its handle.

When Z is supplied a plan is made with CUFFT-PLAN-3D, otherwise with
CUFFT-PLAN-2D.  X, Y and Z are handed to those routines unchanged and in
that order, i.e. as their nx, ny and nz arguments.

Signals an ERROR when the status returned by the plan routine is not 0."
  (multiple-value-bind (result plan)
      (if z
          (cufft-plan-3d x y z +cufft-c2c+)
          (cufft-plan-2d x y +cufft-c2c+))
    ;; Compare numerically: EQ on integers is implementation-dependent.
    (unless (zerop result)
      (error "cu-plan error: ~a" result))
    plan))
(defun cu-malloc-csf (n)
  "Allocate room for N complex single-float elements with CUDA-MALLOC and
return the resulting device pointer (an integer).

Signals an ERROR when the status returned by CUDA-MALLOC is not 0."
  (declare (fixnum n))
  ;; One element is two 4-byte single-floats (real and imaginary part).
  (let ((complex-single-float-size (* 4 2)))
    (multiple-value-bind (result device-ptr)
        (cuda-malloc (* complex-single-float-size n))
      ;; Compare numerically: EQ on integers is implementation-dependent.
      (unless (zerop result)
        (error "cuda-malloc error: ~a" result))
      device-ptr)))
;; Binding for the C function "cudaFree", which receives the device
;; pointer itself (`void *devPtr').  The argument must therefore use the
;; default :in style: with :copy sb-alien would store the value on the
;; stack and pass the address of that copy, so the callee would be asked
;; to free a host stack address instead of the device allocation.
;; Returns the CUDA-ERROR status.
(define-alien-routine ("cudaFree" cuda-free) cuda-error
  (device-pointer cuda-device-ptr))
;; Binding for the C function "cudaMemcpy".  DST and SRC are untyped
;; pointers, COUNT is a SIZE-T and KIND is a CUDA-MEMCPY-KIND, so the
;; direction is given as one of that enum's symbols (e.g. 'host->device).
;; Returns the CUDA-ERROR status.
(define-alien-routine ("cudaMemcpy" cuda-memcpy) cuda-error
  (dst (* t))
  (src (* t))
  (count size-t)
  (kind cuda-memcpy-kind))
;; ft{2,3}-csf: same semantics as the ft3 wrapper to fftw3 -- the input
;; array isn't modified, the transform is returned in a fresh array.

(defun %check-status (status operation)
  "Return STATUS when it is 0; otherwise signal an ERROR naming OPERATION.
0 is the success value of both the CUDA-ERROR and CUFFT-RESULT enums."
  (unless (zerop status)
    (error "~a error: ~a" operation status))
  status)

(defun %ft-csf-on-device (in1 out1 dims forward)
  "Transform the data in the storage vector IN1 on the device and store
the result in the storage vector OUT1; IN1 is only read.

DIMS is the dimension list of the (row-major) array the vectors belong
to.  FORWARD selects the forward transform, whose result is additionally
scaled by 1/N with N the number of elements; otherwise the unscaled
inverse transform is computed.  Returns OUT1.

Signals an ERROR when allocation, planning, a copy or the transform
reports a non-zero status.  The device buffer and the plan are released
even when such an error unwinds the stack."
  (declare (type (simple-array (complex single-float) (*)) in1 out1)
           (type list dims))
  (let* ((n (length in1))
         (complex-single-float-size (* 4 2))
         (nbytes (* n complex-single-float-size)))
    ;; Nothing to transform for an empty array; don't touch the device.
    (when (plusp n)
      ;; allocate array on device
      (let ((device (cu-malloc-csf n)))
        (unwind-protect
             ;; VECTOR-SAP is only valid while the vector cannot be moved
             ;; by the garbage collector.
             (sb-sys:with-pinned-objects (in1 out1)
               ;; copy data to device
               (%check-status
                (cuda-memcpy (sb-sys:int-sap device)
                             (sb-sys:vector-sap in1)
                             nbytes
                             'host->device)
                "cuda-memcpy host->device")
               ;; Plan and execute the in-place transform on the device.
               ;; cuFFT expects the sizes ordered from the slowest to the
               ;; fastest varying dimension, which is exactly the order
               ;; of the dimensions of a row-major Lisp array.
               (let ((plan (apply #'cu-plan dims)))
                 (unwind-protect
                      (%check-status
                       (cufft-exec-c2c plan
                                       (sb-sys:int-sap device)
                                       (sb-sys:int-sap device)
                                       (if forward
                                           +cufft-forward+
                                           +cufft-inverse+))
                       "cufft-exec-c2c")
                   ;; Best-effort cleanup: the status is deliberately not
                   ;; checked so that it cannot mask a pending error.
                   (cufft-destroy plan)))
               ;; copy result back into the output storage
               (%check-status
                (cuda-memcpy (sb-sys:vector-sap out1)
                             (sb-sys:int-sap device)
                             nbytes
                             'device->host)
                "cuda-memcpy device->host"))
          ;; deallocate array on device (best effort, see above)
          (cuda-free device)))
      ;; normalize if forward
      (when forward
        (let ((1/n (/ 1s0 n)))
          (dotimes (i n)
            (setf (aref out1 i) (* 1/n (aref out1 i)))))))
    out1))

(defmacro def-ft?-csf (rank)
  "Define the function FT<RANK>-CSF and the macro IFT<RANK>-CSF for
complex single-float arrays of rank RANK, which must be 2 or 3.

FT<RANK>-CSF takes the array IN and the keyword FORWARD (default T) and
returns a new array of the same dimensions holding the transform.
IFT<RANK>-CSF expands into a call of FT<RANK>-CSF with :FORWARD NIL."
  (check-type rank (member 2 3))
  (let ((ift (intern (format nil "IFT~d-CSF" rank)))
        (ft (intern (format nil "FT~d-CSF" rank))))
    `(progn
       (defun ,ft (in &key (forward t))
         "Return the Fourier transform of IN in a newly allocated array;
IN itself is not modified.  See %FT-CSF-ON-DEVICE for the meaning of
FORWARD."
         (declare ((simple-array (complex single-float) ,rank) in)
                  (boolean forward)
                  (values (simple-array (complex single-float) ,rank)
                          &optional))
         (let* ((dims (array-dimensions in))
                (out (make-array dims
                                 :element-type '(complex single-float))))
           (%ft-csf-on-device (sb-ext:array-storage-vector in)
                              (sb-ext:array-storage-vector out)
                              dims
                              forward)
           out))
       (defmacro ,ift (in)
         `(,',ft ,in :forward nil)))))

(def-ft?-csf 2)
(def-ft?-csf 3)
(defmacro def-ft? (rank)
  "Define the functions FT<RANK> and IFT<RANK> as thin wrappers around
FT<RANK>-CSF and IFT<RANK>-CSF.  FT<RANK> forwards IN and its FORWARD
keyword (default T); IFT<RANK> forwards IN only."
  (flet ((name (control)
           ;; Symbol in the current package named by CONTROL with RANK
           ;; substituted.
           (intern (format nil control rank))))
    (let ((plain-ift (name "IFT~d"))
          (plain-ft (name "FT~d"))
          (csf-ift (name "IFT~d-CSF"))
          (csf-ft (name "FT~d-CSF")))
      `(progn
         (defun ,plain-ft (in &key (forward t))
           (,csf-ft in :forward forward))
         (defun ,plain-ift (in)
           (,csf-ift in))))))

(def-ft? 2)
(def-ft? 3)
;; Disabled example (skipped by the reader because of #+nil): draw a
;; sphere into a 256^3 volume with the VOL package, transform it with
;; FT3-CSF and write an xz cross section of the result to "cufft.pgm",
;; timing the whole run.
#+nil
(progn
  #.(require :vol)
  (time
   (let* ((size 256)
          (sphere (vol:draw-sphere-ub8 20d0 size size size))
          (a (vol:convert3-ub8/csf-complex sphere)))
     (vol:write-pgm "cufft.pgm"
                    (vol:normalize2-csf/ub8-abs
                     (vol:cross-section-xz-csf (ft3-csf a)))))))