1 ;;; -*- Mode: lisp; outline-regexp: ";;;;;*"; indent-tabs-mode: nil -*-;;;
3 ;;; file: defmatrix-mult.cl
4 ;;; author: cyrus harmon
(defun displace-to-1d-array (matrix)
  "Return the result of SB-C::%ARRAY-DATA-VECTOR applied to the value
storage of MATRIX (as read by CLEM::MATRIX-VALS).

SB-C::%ARRAY-DATA-VECTOR is an internal, unexported SBCL symbol, so this
function is only usable on SBCL."
  (let ((vals (clem::matrix-vals matrix)))
    (sb-c::%array-data-vector vals)))
(defgeneric mat-mult3 (m n p)
  (:documentation
   "Three-argument matrix multiply: M and N are the operands and P is the
destination matrix.  The methods generated by DEF-MATRIX-MULT add the
products of M and N into P and signal MATRIX-ARGUMENT-ERROR when the
dimensions of M, N and P are not compatible."))
;;; DEF-MATRIX-MULT type-1 type-2 accumulator-type &key suffix
;;;
;;; Emits the multiplication methods for one combination of matrix classes.
;;;
;;;   type-1, type-2     class names of the left and right operands
;;;   accumulator-type   class name of the result matrix
;;;   suffix             string appended to "mat-mult3" / "mat-mult" before
;;;                      the name is handed to ch-util:make-intern
;;;
;;; At macroexpansion time the element type of each class is looked up with
;;; (element-type (find-class ...)), so the three classes must already exist
;;; when a use of this macro is expanded.
;;;
;;; Methods visible in the template below:
;;;   * mat-mult3 on (type-1 type-2 accumulator-type): operates on the
;;;     vectors returned by displace-to-1d-array using explicit flat indices.
;;;   * mat-mult3 on (type-1 type-1 type-1): operates on clem::matrix-vals
;;;     with two-subscript AREF.
;;;   * mat-mult on (type-1 type-2): allocates the result with
;;;     (make-instance 'accumulator-type :rows mr :cols nc).
;;; Both mat-mult3 methods add each product into P with INCF (nothing visible
;;; here clears P first) and signal MATRIX-ARGUMENT-ERROR when the dimension
;;; check fails.
;;;
;;; NOTE(review): several forms look incomplete in this copy of the file --
;;; the DO/DO* loops show no end-test clause, AIND is declared but its
;;; binding is not visible, MR/NC are declared in mat-mult without a visible
;;; binding, and the backquote that opens the template is not visible.
;;; Confirm against the canonical source before relying on this text.
15 (defmacro def-matrix-mult
(type-1 type-2 accumulator-type
&key suffix
)
;; Element types are resolved once, while the macro is being expanded.
16 (let ((element-type-1 (element-type (find-class `,type-1
)))
17 (element-type-2 (element-type (find-class `,type-2
)))
18 (accumulator-element-type (element-type (find-class `,accumulator-type
))))
;; --- mat-mult3 over the flat data vectors -------------------------------
22 (defmethod ,(ch-util:make-intern
(concatenate 'string
"mat-mult3" suffix
))
23 ((m ,type-1
) (n ,type-2
) (p ,accumulator-type
))
;; safety 0: the type declarations below are trusted, not checked.
24 (declare (optimize (speed 3) (safety 0)))
25 (let ((a (displace-to-1d-array m
))
26 (b (displace-to-1d-array n
))
27 (c (displace-to-1d-array p
))
;; ATEMP caches the current element of A across the innermost loop.
28 (atemp (coerce 0 ',accumulator-element-type
)))
29 (declare (type (simple-array ,element-type-1
(*)) a
)
30 (type (simple-array ,element-type-2
(*)) b
)
31 (type (simple-array ,accumulator-element-type
(*)) c
)
32 (type ,accumulator-element-type atemp
))
33 (let ((mr (rows m
)) (mc (cols m
))
34 (nr (rows n
)) (nc (cols n
))
35 (pr (rows p
)) (pc (cols p
)))
36 (declare (type fixnum mr mc nr nc pr pc
))
;; Require (mr x mc) * (nr x nc) => (pr x pc) to be conformable.
37 (if (and (= mc nr
) (= mr pr
) (= nc pc
))
38 (do ((k 0 (the fixnum
(1+ k
))))
40 (declare (type fixnum k
))
41 (do* ((i 0 (the fixnum
(1+ i
)))
44 (declare (type fixnum i aind
))
45 (setf atemp
(aref a aind
))
;; BIND walks B from flat index k*nc and CIND walks C from flat index
;; i*nc, each advancing by one per step of J.
46 (do ((j 0 (the fixnum
(1+ j
)))
47 (bind (* k nc
) (the fixnum
(1+ bind
)))
48 (cind (* i nc
) (the fixnum
(1+ cind
))))
50 (declare (type fixnum j bind cind
))
51 (incf (aref c cind
) (the ,accumulator-element-type
(* atemp
(aref b bind
)))))))
52 (error 'matrix-argument-error
54 "Incompatible matrix dimensions in mat-mult3 (~S x ~S) * (~S x ~S) => (~S x ~S)."
55 :format-arguments
(list mr mc nr nc pr pc
))
;; --- mat-mult3 over the value arrays with two-subscript AREF -----------
;; NOTE(review): all three arguments are specialized on type-1 here, yet B
;; and C are declared with element-type-2 and the accumulator element type.
;; With (safety 0) that mismatch is unchecked when the types differ --
;; confirm whether (type-1 type-2 accumulator-type) was intended.
60 (defmethod ,(ch-util:make-intern
(concatenate 'string
"mat-mult3" suffix
))
61 ((m ,type-1
) (n ,type-1
) (p ,type-1
))
62 (declare (optimize (speed 3) (safety 0)))
63 (let ((a (clem::matrix-vals m
))
64 (b (clem::matrix-vals n
))
65 (c (clem::matrix-vals p
))
66 (atemp (coerce 0 ',accumulator-element-type
)))
67 (declare (type (simple-array ,element-type-1
*) a
)
68 (type (simple-array ,element-type-2
*) b
)
69 (type (simple-array ,accumulator-element-type
*) c
)
70 (type ,accumulator-element-type atemp
))
71 (let ((mr (rows m
)) (mc (cols m
))
72 (nr (rows n
)) (nc (cols n
))
;; Result dimensions are read from P so that the (= mr pr) / (= nc pc)
;; checks validate the destination matrix and the error message reports
;; the real "=> (pr x pc)" shape.
73 (pr (rows p
)) (pc (cols p
)))
74 (declare (type fixnum mr mc nr nc pr pc
))
75 (if (and (= mc nr
) (= mr pr
) (= nc pc
))
76 (do ((k 0 (the fixnum
(1+ k
))))
78 (declare (type fixnum k
))
79 (do ((i 0 (the fixnum
(1+ i
))))
81 (declare (type fixnum i
))
82 (setf atemp
(aref a i k
))
83 (do ((j 0 (the fixnum
(1+ j
))))
85 (declare (type fixnum j
))
;; c[i,j] += a[i,k] * b[k,j]
86 (incf (aref c i j
) (the ,accumulator-element-type
(* atemp
(aref b k j
)))))))
87 (error 'matrix-argument-error
89 "Incompatible matrix dimensions in mat-mult3 (~S x ~S) * (~S x ~S) => (~S x ~S)."
90 :format-arguments
(list mr mc nr nc pr pc
)))))
;; --- mat-mult: allocate the result matrix ------------------------------
93 (defmethod ,(ch-util:make-intern
(concatenate 'string
"mat-mult" suffix
))
94 ((m ,type-1
) (n ,type-2
))
95 (declare (optimize (speed 3) (safety 0)))
98 (declare (type fixnum mr nc
))
;; Only the inner dimensions need checking; P is created to fit.
99 (if (= (cols m
) (rows n
))
100 (let ((p (make-instance ',accumulator-type
101 :rows
(the fixnum mr
)
102 :cols
(the fixnum nc
))))
104 (error 'matrix-argument-error
106 "Incompatible matrix dimensions in mat-mult (~S x ~S) * (~S x ~S)."
107 :format-arguments
(list (rows m
) (cols m
) (rows n
) (cols n
)))))))))
110 ;;; need to think about which mat-mult type combinations are needed
111 ;;; here. add more as appropriate.
;;; Instantiate the multiplication methods for every supported combination
;;; of (left operand class, right operand class, result class).
(macrolet ((define-mults (&rest combinations)
             ;; Expands into one DEF-MATRIX-MULT form per triple, in the
             ;; order listed, each with :suffix nil.
             `(progn
                ,@(mapcar (lambda (combination)
                            (destructuring-bind (left right result) combination
                              `(def-matrix-mult ,left ,right ,result
                                 :suffix nil)))
                          combinations))))
  (define-mults
   (double-float-matrix double-float-matrix double-float-matrix)
   (single-float-matrix single-float-matrix single-float-matrix)
   (ub8-matrix ub8-matrix ub8-matrix)
   (ub16-matrix ub16-matrix ub16-matrix)
   (ub32-matrix ub32-matrix ub32-matrix)
   (sb8-matrix sb8-matrix sb32-matrix)
   (sb16-matrix sb16-matrix sb32-matrix)
   (sb32-matrix sb32-matrix sb32-matrix)
   (fixnum-matrix fixnum-matrix fixnum-matrix)
   (bit-matrix bit-matrix bit-matrix)))