Initial commit, 3-52-19 alpha
[cls.git] / src / lsp / linalg.lsp
blob9838f4a063cdac9c80f6c0bea3edcdd0ae322eeb
1 ;;;;
2 ;;;; linalg.lsp XLISP-STAT linear algebra functions
3 ;;;; XLISP-STAT 2.1 Copyright (c) 1990-1995, by Luke Tierney
4 ;;;; Additions to Xlisp 2.1, Copyright (c) 1989 by David Michael Betz
5 ;;;; You may give out copies of this software; for conditions see the file
6 ;;;; COPYING included with this distribution.
7 ;;;;
9 (in-package "XLISP")
11 (provide "linalg")
14 ;;;;
15 ;;;; Basic Matrix Operations
16 ;;;;
(export '(matmult %*
          inner-product cross-product outer-product
          identity-matrix))

;;; MULTIPLY-MATRIX-MATRIX -- returns the matrix product X Y as a generic
;;; (array t).  Both operands are coerced to C-DOUBLE arrays, or to
;;; C-DCOMPLEX arrays when any element of X or Y is complex, and the product
;;; is formed by the BLAS xGEMM wrapper.
;;; Assuming the wrapper follows reference BLAS (column-major) conventions,
;;; the row-major data of X and Y is seen by GEMM as X^T and Y^T, so asking
;;; for Y^T X^T = (X Y)^T leaves X Y in row-major order in the result.
;;; Signals "dimensions do not match" if X's columns differ from Y's rows.
(defun multiply-matrix-matrix (x y)
  (let* ((complex-p (any-complex-elements x y))
         (gemm (if complex-p #'blas-zgemm #'blas-dgemm))
         (elt-type (if complex-p 'c-dcomplex 'c-double))
         (lhs (coerce x (list 'array elt-type)))
         (rhs (coerce y (list 'array elt-type)))
         (rows (array-dimension lhs 0))
         (inner (array-dimension lhs 1))
         (cols (array-dimension rhs 1)))
    (when (/= (array-dimension rhs 0) inner)
      (error "dimensions do not match"))
    (let ((result (make-array (list rows cols) :element-type elt-type)))
      ;; Y^T (cols x inner) times X^T (inner x rows), alpha 1, beta 0.
      (funcall gemm "n" "n" cols rows inner 1 rhs 0 cols lhs 0 inner 0 result 0 cols)
      (coerce result '(array t)))))
;;; MULTIPLY-MATRIX-VECTOR -- returns X y for a matrix X and a sequence Y.
;;; The result is a list when Y is a list and a generic (array t) vector
;;; otherwise.  Elements are computed in C-DOUBLE, or C-DCOMPLEX when any
;;; element is complex, by the BLAS xGEMV wrapper.  Assuming column-major
;;; BLAS conventions, the m x k row-major X is seen as a k x m array and is
;;; applied transposed ("t") to obtain X y.
;;; Signals "dimensions do not match" if (length Y) differs from X's columns.
(defun multiply-matrix-vector (x y)
  (let* ((result-type (if (listp y) 'list '(array t)))
         (complex-p (any-complex-elements x y))
         (gemv (if complex-p #'blas-zgemv #'blas-dgemv))
         (elt-type (if complex-p 'c-dcomplex 'c-double))
         (mat (coerce x (list 'array elt-type)))
         (vec (coerce y (list 'array elt-type)))
         (nrow (array-dimension mat 0))
         (ncol (array-dimension mat 1)))
    (unless (= (length vec) ncol)
      (error "dimensions do not match"))
    (let ((out (make-array nrow :element-type elt-type)))
      (funcall gemv "t" ncol nrow 1 mat 0 ncol vec 0 1 0 out 0 1)
      (coerce out result-type))))
;;; MULTIPLY-VECTOR-MATRIX -- returns x Y for a sequence X and a matrix Y.
;;; The result is a list when X is a list and a generic (array t) vector
;;; otherwise.  Elements are computed in C-DOUBLE, or C-DCOMPLEX when any
;;; element is complex, by the BLAS xGEMV wrapper.  Assuming column-major
;;; BLAS conventions, the k x n row-major Y is seen as an n x k array (Y^T),
;;; and applying it untransposed ("n") gives Y^T x = x Y.
;;; Signals "dimensions do not match" if (length X) differs from Y's rows.
(defun multiply-vector-matrix (x y)
  (let* ((result-type (if (listp x) 'list '(array t)))
         (complex-p (any-complex-elements x y))
         (gemv (if complex-p #'blas-zgemv #'blas-dgemv))
         (elt-type (if complex-p 'c-dcomplex 'c-double))
         (vec (coerce x (list 'array elt-type)))
         (mat (coerce y (list 'array elt-type)))
         (nvec (length vec))
         (ncol (array-dimension mat 1)))
    (unless (= (array-dimension mat 0) nvec)
      (error "dimensions do not match"))
    (let ((out (make-array ncol :element-type elt-type)))
      (funcall gemv "n" ncol nvec 1 mat 0 ncol vec 0 1 0 out 0 1)
      (coerce out result-type))))
;;; INNER-PRODUCT -- inner product of the sequences X and Y.
;;; Real data uses BLAS DDOT.  When either argument has complex elements,
;;; CONJUGATE (default T) selects ZDOTC, otherwise ZDOTU.  Y is passed as
;;; the routine's first vector.  Under the reference-BLAS definition of
;;; ZDOTC (conjugates its first vector), the conjugated result would be
;;; sum conj(y_i) x_i.  NOTE(review): confirm against the wrapper.
;;; Signals "dimensions do not match" when the lengths differ.
(defun inner-product (x y &optional (conjugate t))
  (let* ((complex-p (any-complex-elements x y))
         (dot (cond ((not complex-p) #'blas-ddot)
                    (conjugate #'blas-zdotc)
                    (t #'blas-zdotu)))
         (array-type (if complex-p '(array c-dcomplex) '(array c-double)))
         (u (coerce x array-type))
         (w (coerce y array-type))
         (len (length u)))
    (if (= (length w) len)
        (funcall dot len w 0 1 u 0 1)
        (error "dimensions do not match"))))
;;; BINARY-MATMULT -- multiplies two arguments according to their shapes:
;;;   matrix   . matrix   -> MULTIPLY-MATRIX-MATRIX
;;;   matrix   . sequence -> MULTIPLY-MATRIX-VECTOR
;;;   sequence . matrix   -> MULTIPLY-VECTOR-MATRIX
;;;   sequence . sequence -> unconjugated INNER-PRODUCT
;;; Every other combination falls back to ordinary (*).
;;; A matrix test always takes precedence over a sequence test.
(defun binary-matmult (x y)
  (flet ((shape (a)
           (cond ((matrixp a) :matrix)
                 ((sequencep a) :sequence)
                 (t :other))))
    (let ((sx (shape x))
          (sy (shape y)))
      (cond ((and (eq sx :matrix) (eq sy :matrix)) (multiply-matrix-matrix x y))
            ((and (eq sx :matrix) (eq sy :sequence)) (multiply-matrix-vector x y))
            ((and (eq sx :sequence) (eq sy :matrix)) (multiply-vector-matrix x y))
            ((and (eq sx :sequence) (eq sy :sequence)) (inner-product x y nil))
            (t (* x y))))))
;;; MATMULT -- folds BINARY-MATMULT over its arguments from the left:
;;; (matmult a b c) = (binary-matmult (binary-matmult a b) c).
;;; With a single argument that argument is returned unchanged.
(defun matmult (x &rest more)
  (let ((acc x))
    (dolist (next more acc)
      (setf acc (binary-matmult acc next)))))

;; %* is an alias for MATMULT.
(setf (symbol-function '%*) #'matmult)
;;; CROSS-PRODUCT -- for a matrix X returns the generic (array t) matrix
;;; X^T X.  When X has complex elements and CONJUGATE is true (the default),
;;; the second GEMM operand is conjugate-transposed ("c"), which
;;; (presuming reference-BLAS semantics) yields X^H X.
;;; For a sequence X the result is (INNER-PRODUCT X X CONJUGATE).
(defun cross-product (x &optional (conjugate t))
  (cond
    ((sequencep x)
     (inner-product x x conjugate))
    (t
     (let* ((complex-p (any-complex-elements x))
            (gemm (if complex-p #'blas-zgemm #'blas-dgemm))
            (elt-type (if complex-p 'c-dcomplex 'c-double))
            (op (if (and complex-p conjugate) "c" "t"))
            (a (coerce x (list 'array elt-type)))
            (nrow (array-dimension a 0))
            (ncol (array-dimension a 1))
            (result (make-array (list ncol ncol) :element-type elt-type)))
       ;; X's row-major storage doubles as both GEMM operands.
       (funcall gemm "n" op ncol ncol nrow 1 a 0 ncol a 0 ncol 0 result 0 ncol)
       (coerce result '(array t))))))
;;; OUTER-PRODUCT -- returns the m x n generic matrix whose (i, j) element
;;; is (F x_i y_j), where F defaults to *.
;;; A non-compound X or Y is treated as a one-element vector.  For compound
;;; arguments, the elements of their COMPOUND-DATA-SEQ are used.
;;; Fix: the previous body ended in DOTIMES (which returns NIL), so the
;;; filled matrix V was never returned.
(defun outer-product (x y &optional f)
  (let* ((fun (or f #'*))
         (xs (coerce (compound-data-seq (if (compound-data-p x) x (vector x)))
                     'vector))
         (ys (coerce (compound-data-seq (if (compound-data-p y) y (vector y)))
                     'vector))
         (m (length xs))
         (n (length ys))
         (v (make-array (list m n))))
    ;; DOTIMES's result form returns the filled matrix.
    (dotimes (i m v)
      (let ((xi (aref xs i)))
        (dotimes (j n)
          (setf (aref v i j) (funcall fun xi (aref ys j))))))))
;;; IDENTITY-MATRIX -- returns the N x N identity matrix, built by DIAGONAL
;;; from a list of N ones.
(defun identity-matrix (n)
  (let ((ones (make-list n :initial-element 1)))
    (diagonal ones)))
145 ;;;;
146 ;;;; TRANSPOSE
147 ;;;;
(export 'transpose)

;;; TRANSPOSE -- returns the transpose of X.
;;; A matrix is transposed by swapping its two axes (PERMUTE-ARRAY).  A
;;; non-empty list is handled by TRANSPOSE-LIST; presumably it is a list of
;;; equal-length rows (confirm).  Any other argument signals an error.
;;; Fix: corrected the "argumant" typo in the error message.
(defun transpose (x)
  (cond
    ((matrixp x) (permute-array x '(1 0)))
    ((consp x) (transpose-list x))
    (t (error "bad argument type - ~s" x))))
158 ;;;;
159 ;;;; SWEEP Operator
160 ;;;;
(export '(make-sweep-matrix sweep-operator))

;;; MAKE-SWEEP-MATRIX -- builds a (p+2) x (p+2) sweep matrix from the
;;; n x p matrix X, the sequence Y and the optional weights W.  Missing
;;; weights default to all 1.0.  X is presumably the regression design
;;; matrix and Y the response (confirm).
;;; The work is done by BASE-MAKE-SWEEP-MATRIX on C-DOUBLE copies.  The
;;; length-p MEANS buffer it receives is not returned.
;;; Returns the result as a generic (array t).
(defun make-sweep-matrix (x y &optional w)
  (let* ((nobs (array-dimension x 0))
         (nvar (array-dimension x 1))
         (size (+ nvar 2))
         (xd (coerce x '(array c-double)))
         (yd (coerce y '(vector c-double)))
         (wd (if (null w)
                 (make-array nobs :element-type 'c-double :initial-element 1.0)
                 (coerce w '(vector c-double))))
         (result (make-array (list size size) :element-type 'c-double))
         (means (make-array nvar :element-type 'c-double)))
    (base-make-sweep-matrix nobs nvar xd yd wd result means)
    (coerce result '(array t))))
;;; SWEEP-OPERATOR -- sweeps matrix A on each column index in COLS, in order.
;;; TOL is either a single number used for every column or a list of
;;; per-column tolerances.  Sweeping stops when COLS or the tolerance list
;;; runs out.  A is copied into a flat C-DOUBLE buffer, which
;;; SWEEP-IN-PLACE modifies.
;;; Returns (list SWEPT-MATRIX SWEPT-COLUMNS).  SWEPT-COLUMNS holds the
;;; indices for which SWEEP-IN-PLACE returned true, most recent first.
(defun sweep-operator (a cols &optional (tol .000001))
  (let* ((nrow (array-dimension a 0))
         (ncol (array-dimension a 1))
         (work (make-array (* nrow ncol)
                           :element-type 'c-double
                           :initial-contents (compound-data-seq a)))
         (swept '()))
    (do ((ks cols (rest ks))
         (ts (if (numberp tol) (repeat tol (length cols)) tol) (rest ts)))
        ((or (null ks) (null ts)))
      (when (sweep-in-place nrow ncol work (first ks) (first ts))
        (push (first ks) swept)))
    (list (make-array (list nrow ncol) :displaced-to (coerce work '(array t)))
          swept)))
193 ;;;;
194 ;;;; Utilities for LINPACK Interface
195 ;;;;
;;; GENERIC-TO-LINALG -- copies the M x N data of X into a fresh flat array
;;; of M*N elements of element type TYPE, for the LINPACK/BLAS wrappers.
;;; Without TRANS the elements of X's COMPOUND-DATA-SEQ are copied in order.
;;; With TRANS the data is written transposed by TRANSPOSE-INTO, presumably
;;; giving the column-major layout Fortran-style routines expect.
;;; Fix: the TRANS branch now returns the filled buffer XV.  Previously its
;;; LET ended without returning XV and was left unclosed.
(defun generic-to-linalg (x m n type &optional trans)
  (if trans
      (let ((xv (make-array (* m n) :element-type type)))
        (transpose-into x m n xv)
        xv)
      (make-array (* m n)
                  :element-type type
                  :initial-contents (compound-data-seq x))))
;;; LINALG-TO-GENERIC -- counterpart of GENERIC-TO-LINALG.  Returns a fresh
;;; generic array with dimensions DIM (a list) filled from the flat data X.
;;; With TRANS the data is read transposed via TRANSPOSE-INTO.  Otherwise X
;;; is copied straight into the new array's data sequence.
(defun linalg-to-generic (x dim &optional trans)
  (let ((result (make-array dim)))
    (cond (trans
           (transpose-into x (second dim) (first dim) result))
          (t
           (replace (compound-data-seq result) x)))
    result))
;;; SQUARE-MATRIX-P -- true when X is a matrix with as many rows as columns,
;;; NIL otherwise.
(defun square-matrix-p (x)
  (when (matrixp x)
    (= (array-dimension x 0) (array-dimension x 1))))
;;; CHECK-SQUARE-MATRIX -- signals "not a square matrix" unless the value
;;; of the form X is a square matrix.  Returns NIL when the check passes.
;;; Fix: X is now evaluated exactly once.  Previously the form was expanded
;;; twice, once for the test and once for the error message.
(defmacro check-square-matrix (x)
  (let ((val (gensym)))
    `(let ((,val ,x))
       (unless (square-matrix-p ,val)
         (error "not a square matrix -- ~s" ,val)))))
;;; CHECK-MATRIX -- signals "not a matrix" unless the value of the form X is
;;; a matrix.  Returns NIL when the check passes.
;;; Fix: X is now evaluated exactly once.  Previously it was evaluated
;;; again for the error message.
(defmacro check-matrix (x)
  (let ((val (gensym)))
    `(let ((,val ,x))
       (unless (matrixp ,val)
         (error "not a matrix -- ~s" ,val)))))
225 ;;;;
226 ;;;; LU Decomposition, Determinant, and Inverse
227 ;;;;
(export '(lu-decomp rcondest determinant inverse lu-solve))

;;; LU-DECOMP -- LU decomposition with partial pivoting of the square matrix
;;; X, via the LINPACK xGEFA wrapper.  Returns (LU PIVOTS SIGN SINGULAR):
;;;   LU        the factored matrix, in the same row-major layout as X;
;;;   PIVOTS    generic vector of 0-based pivot row indices (LU-SOLVE adds
;;;             1 to restore LINPACK's 1-based convention);
;;;   SIGN      -1.0 if the row interchanges form an odd permutation, else 1.0;
;;;   SINGULAR  true when xGEFA's INFO is nonzero (a zero pivot was found).
;;; Fix: the pivot vector was missing from the result.  LU-SOLVE therefore
;;; read the SIGN as its pivots.
(defun lu-decomp (x)
  (check-square-matrix x)
  (multiple-value-bind
      (fun type)
      (if (any-complex-elements x)
          (values #'linpack-zgefa 'c-dcomplex)
          (values #'linpack-dgefa 'c-double))
    (let* ((n (array-dimension x 0))
           (xv (generic-to-linalg x n n type t))
           (ipvt (make-array n :element-type 'c-int))
           (info (funcall fun xv 0 n n ipvt))
           (odd nil)
           (im1 (1- ipvt)))               ; LINPACK pivots are 1-based
      ;; Every pivot that differs from its own row index is one interchange.
      (dotimes (i n)
        (unless (= i (aref im1 i))
          (setf odd (not odd))))
      (list (linalg-to-generic xv (list n n) t)
            (coerce im1 '(vector t))
            (if odd -1.0 1.0)
            (/= info 0.0)))))
;;; RCONDEST -- returns whatever the LINPACK xGECO wrapper returns for the
;;; square matrix X.  Judging by the name this is an estimate of the
;;; reciprocal condition number (confirm against the wrapper).  X is passed
;;; in transposed (column-major) form with fresh pivot and work buffers.
(defun rcondest (x)
  (check-square-matrix x)
  (let* ((complex-p (any-complex-elements x))
         (geco (if complex-p #'linpack-zgeco #'linpack-dgeco))
         (elt-type (if complex-p 'c-dcomplex 'c-double))
         (n (array-dimension x 0))
         (a (generic-to-linalg x n n elt-type t))
         (pivots (make-array n :element-type 'c-int))
         (scratch (make-array n :element-type elt-type)))
    (funcall geco a 0 n n pivots scratch)))
;;; DETERMINANT -- determinant of the square matrix X.
;;; The matrix is factored with LINPACK xGEFA, then passed to xGEDI with
;;; job 10, which by the LINPACK convention means "determinant only".
;;; xGEDI leaves a mantissa/exponent pair in DET, and the result is
;;; DET[0] * 10^DET[1].
(defun determinant (x)
  (check-square-matrix x)
  (let* ((complex-p (any-complex-elements x))
         (factor (if complex-p #'linpack-zgefa #'linpack-dgefa))
         (finish (if complex-p #'linpack-zgedi #'linpack-dgedi))
         (elt-type (if complex-p 'c-dcomplex 'c-double))
         (n (array-dimension x 0))
         (a (generic-to-linalg x n n elt-type t))
         (pivots (make-array n :element-type 'c-int))
         (det (make-array 2 :element-type elt-type))
         (scratch (make-array n :element-type elt-type)))
    (funcall factor a 0 n n pivots)
    (funcall finish a 0 n n pivots det scratch 10)
    (let ((mantissa (aref det 0))
          (exponent (aref det 1)))
      (* mantissa (^ 10 exponent)))))
;;; INVERSE -- inverse of the square matrix X as a generic matrix.
;;; The matrix is factored with LINPACK xGEFA, then inverted with xGEDI
;;; using job 1 ("inverse only" by the LINPACK convention).
;;; Fix: a nonzero INFO from xGEFA (an exactly zero pivot) now signals an
;;; error.  Previously xGEDI went on to divide by that zero pivot and
;;; silently returned a matrix of infinities/NaNs.
(defun inverse (x)
  (check-square-matrix x)
  (multiple-value-bind
      (fun1 fun2 type)
      (if (any-complex-elements x)
          (values #'linpack-zgefa #'linpack-zgedi 'c-dcomplex)
          (values #'linpack-dgefa #'linpack-dgedi 'c-double))
    (let* ((n (array-dimension x 0))
           (xv (generic-to-linalg x n n type t))
           (ipvt (make-array n :element-type 'c-int))
           (work (make-array n :element-type type))
           (info (funcall fun1 xv 0 n n ipvt)))
      (when (/= info 0)
        (error "matrix is singular -- ~s" x))
      (funcall fun2 xv 0 n n ipvt nil work 1)
      (linalg-to-generic xv (list n n) t))))
;;; LU-SOLVE -- solves A x = B, where LU is the list returned by LU-DECOMP
;;; for A.  Its first element is the factored matrix and its second the
;;; 0-based pivot vector.  Uses LINPACK xGESL with job 0 (untransposed
;;; solve).  Returns a generic vector when B is a vector, a list otherwise.
(defun lu-solve (lu b)
  (let ((factors (first lu))
        (pivots (+ (second lu) 1)))     ; back to LINPACK's 1-based pivots
    (check-square-matrix factors)
    (let* ((complex-p (any-complex-elements factors b))
           (gesl (if complex-p #'linpack-zgesl #'linpack-dgesl))
           (elt-type (if complex-p 'c-dcomplex 'c-double))
           (n (array-dimension factors 0))
           (a (generic-to-linalg factors n n elt-type t))
           (ipvt (generic-to-linalg pivots n 1 'c-int))
           (rhs (generic-to-linalg b n 1 elt-type)))
      ;; xGESL overwrites RHS with the solution.
      (funcall gesl a 0 n n ipvt rhs 0)
      (coerce rhs (if (vectorp b) '(vector t) 'list)))))
311 ;;;;
312 ;;;; QR and SV Decompositions
313 ;;;;
(export '(qr-decomp sv-decomp))

;;; QR-DECOMP -- QR decomposition of the n x p matrix X via the LINPACK
;;; xQRDC wrapper, which also fills the Q (n x p) and R (p x p) arrays.
;;; Returns (Q R) as generic matrices.  When PIVOT is true, column pivoting
;;; is requested (job 1) and the result is (Q R PIVOTS), where PIVOTS is a
;;; generic vector of 0-based column indices.
(defun qr-decomp (x &optional pivot)
  (check-matrix x)
  (let* ((complex-p (any-complex-elements x))
         (qrdc (if complex-p #'linpack-zqrdc #'linpack-dqrdc))
         (elt-type (if complex-p 'c-dcomplex 'c-double))
         (n (array-dimension x 0))
         (p (array-dimension x 1))
         (xv (generic-to-linalg x n p elt-type t))
         (qraux (make-array p :element-type elt-type))
         (r (make-array (list p p) :element-type elt-type))
         (q (make-array (list n p) :element-type elt-type))
         (jpvt (and pivot
                    (make-array p :element-type 'c-int :initial-element 0)))
         (work (and pivot (make-array p :element-type elt-type))))
    (funcall qrdc xv 0 n n p qraux jpvt work (if pivot 1 0) r q)
    (let ((q-out (coerce q '(array t)))
          (r-out (coerce r '(array t))))
      (if pivot
          (list q-out r-out (coerce (1- jpvt) '(vector t)))
          (list q-out r-out)))))
;;; SV-DECOMP -- singular value decomposition of the n x p matrix X, which
;;; must have n >= p; otherwise "more columns than rows" is signalled.
;;; Uses LINPACK xSVDC with job 21.  By the LINPACK convention, job 21
;;; returns the first min(n,p) left singular vectors in U and the right
;;; singular vectors in V.
;;; Returns (U S V FLAG): U is n x p, S a generic vector of length p,
;;; V is p x p, and FLAG is T exactly when the wrapper's INFO is NIL.
;;; NOTE(review): FLAG is only meaningful if the wrapper returns NIL on
;;; success.  A numeric INFO (0 on success) would always give NIL -- confirm.
(defun sv-decomp (x)
  (check-matrix x)
  (let* ((complex-p (any-complex-elements x))
         (svdc (if complex-p #'linpack-zsvdc #'linpack-dsvdc))
         (elt-type (if complex-p 'c-dcomplex 'c-double))
         (n (array-dimension x 0))
         (p (array-dimension x 1)))
    (when (> p n)
      (error "more columns than rows - ~s" x))
    (let* ((xv (generic-to-linalg x n p elt-type t))
           (s (make-array p :element-type elt-type))
           (e (make-array p :element-type elt-type))
           (u (make-array (* n p) :element-type elt-type))
           (v (make-array (* p p) :element-type elt-type))
           (work (make-array n :element-type elt-type))
           (info (funcall svdc xv 0 n n p s e u 0 n v 0 p work 21)))
      (list (linalg-to-generic u (list n p) t)
            (coerce s '(vector t))
            (linalg-to-generic v (list p p) t)
            (null info)))))
363 ;;;;
364 ;;;; Eigenvalues and Eigenvectors
365 ;;;;
(export 'eigen)

;;; EIGEN -- eigenvalues and eigenvectors of the square matrix X.
;;; It uses the EISPACK CH wrapper when X has complex elements and RS
;;; otherwise.  These are presumably the complex-Hermitian and
;;; real-symmetric drivers, so X is assumed to have that structure (confirm).
;;; Returns (EVALS EVECS ERR):
;;;   EVALS  generic vector of eigenvalues;
;;;   EVECS  list of eigenvectors, one per row of the EISPACK Z matrix;
;;;   ERR    NIL when the wrapper's IERR is NIL, otherwise (- n IERR).
;;; EVALS and EVECS are both reversed relative to EISPACK's output order.
;;; If EISPACK returns ascending eigenvalues, this gives descending order.
(defun eigen (x)
  (check-square-matrix x)
  (let ((n (array-dimension x 0)))
    (flet ((dvec (len)
             (make-array len :element-type 'c-double))
           (dmat ()
             (make-array (list n n) :element-type 'c-double))
           (pack (w z ierr)
             (list (nreverse (coerce w '(vector t)))
                   (nreverse (row-list (coerce z '(array t))))
                   (if ierr (- n ierr) nil))))
      (if (any-complex-elements x)
          (let* ((xr (generic-to-linalg (realpart x) n n 'c-double t))
                 (xi (generic-to-linalg (imagpart x) n n 'c-double t))
                 (w (dvec n))
                 (zr (dmat))
                 (zi (dmat))
                 (ierr (eispack-ch n n xr xi w 1 zr zi
                                   (dvec n) (dvec n) (dvec (* 2 n)))))
            (pack w (complex zr zi) ierr))
          (let* ((a (generic-to-linalg x n n 'c-double t))
                 (w (dvec n))
                 (z (dmat))
                 (ierr (eispack-rs n n a w 1 z (dvec n) (dvec n))))
            (pack w z ierr))))))