1 `/* Implementation of the MATMUL intrinsic
2 Copyright 2002, 2005 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file. (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine executable.)
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 GNU General Public License for more details.
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING. If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA. */
35 #include "libgfortran.h"'
38 `#if defined (HAVE_'rtype_name`)'
40 /* The order of loops is different in the case of plain matrix
41 multiplication C=MATMUL(A,B), and in the frequent special case where
42 the argument A is the temporary result of a TRANSPOSE intrinsic:
43 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
44 looking at their strides.
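46 The equivalent Fortran pseudo-code is:
48   DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
49   IF (.NOT.IS_TRANSPOSED(A)) THEN
50     C = 0
51     DO J=1,N
52       DO K=1,COUNT
53         DO I=1,M
54           C(I,J) = C(I,J)+A(I,K)*B(K,J)
55   ELSE
56     DO J=1,N
57       DO I=1,M
58         S = 0
59         DO K=1,COUNT
60           S = S+A(I,K)*B(K,J)
61         C(I,J) = S
62   ENDIF
63 */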
/* Prototype of the generated MATMUL routine.  The suffix of the function
   name and the descriptor type of the three array arguments are filled in
   by m4 substitution when this template is instantiated.  The routine
   returns nothing: the product of a and b is written through retarray.
   The symbol name is then handed to export_proto (presumably to set up
   the exported symbol; confirm against libgfortran.h).  */
65 extern void matmul_`'rtype_code (rtype * const restrict retarray,
66 rtype * const restrict a, rtype * const restrict b);
67 export_proto(matmul_`'rtype_code);
70 matmul_`'rtype_code (rtype * const restrict retarray,
71 rtype * const restrict a, rtype * const restrict b)
73 const rtype_name * restrict abase;
74 const rtype_name * restrict bbase;
75 rtype_name * restrict dest;
77 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
78 index_type x, y, n, count, xcount, ycount;
80 assert (GFC_DESCRIPTOR_RANK (a) == 2
81 || GFC_DESCRIPTOR_RANK (b) == 2);
83 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
85 Either A or B (but not both) can be rank 1:
87 o One-dimensional argument A is implicitly treated as a row matrix
88 dimensioned [1,count], so xcount=1.
90 o One-dimensional argument B is implicitly treated as a column matrix
91 dimensioned [count, 1], so ycount=1. */
94 if (retarray->data == NULL)
96 if (GFC_DESCRIPTOR_RANK (a) == 1)
98 retarray->dim[0].lbound = 0;
99 retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
100 retarray->dim[0].stride = 1;
102 else if (GFC_DESCRIPTOR_RANK (b) == 1)
104 retarray->dim[0].lbound = 0;
105 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
106 retarray->dim[0].stride = 1;
110 retarray->dim[0].lbound = 0;
111 retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
112 retarray->dim[0].stride = 1;
114 retarray->dim[1].lbound = 0;
115 retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
116 retarray->dim[1].stride = retarray->dim[0].ubound+1;
120 = internal_malloc_size (sizeof (rtype_name) * size0 ((array_t *) retarray));
121 retarray->offset = 0;
124 if (retarray->dim[0].stride == 0)
125 retarray->dim[0].stride = 1;
127 /* This prevents constifying the input arguments. */
128 if (a->dim[0].stride == 0)
129 a->dim[0].stride = 1;
130 if (b->dim[0].stride == 0)
131 b->dim[0].stride = 1;
133 sinclude(`matmul_asm_'rtype_code`.m4')dnl
135 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
137 /* One-dimensional result may be addressed in the code below
138 either as a row or a column matrix. We want both cases to work. */
140 rxstride = rystride = retarray->dim[0].stride;
144 rxstride = retarray->dim[0].stride;
145 rystride = retarray->dim[1].stride;
149 if (GFC_DESCRIPTOR_RANK (a) == 1)
151 /* Treat it as a row matrix A[1,count]. */
152 axstride = a->dim[0].stride;
156 count = a->dim[0].ubound + 1 - a->dim[0].lbound;
160 axstride = a->dim[0].stride;
161 aystride = a->dim[1].stride;
163 count = a->dim[1].ubound + 1 - a->dim[1].lbound;
164 xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
167 assert(count == b->dim[0].ubound + 1 - b->dim[0].lbound);
169 if (GFC_DESCRIPTOR_RANK (b) == 1)
171 /* Treat it as a column matrix B[count,1] */
172 bxstride = b->dim[0].stride;
174 /* bystride should never be used for 1-dimensional b.
175 In case it is, we want it to cause a segfault rather than
176 an incorrect result. */
177 bystride = 0xDEADBEEF;
182 bxstride = b->dim[0].stride;
183 bystride = b->dim[1].stride;
184 ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
189 dest = retarray->data;
191 if (rxstride == 1 && axstride == 1 && bxstride == 1)
193 const rtype_name * restrict bbase_y;
194 rtype_name * restrict dest_y;
195 const rtype_name * restrict abase_n;
198 if (rystride == ycount)
199 memset (dest, 0, (sizeof (rtype_name) * size0((array_t *) retarray)));
202 for (y = 0; y < ycount; y++)
203 for (x = 0; x < xcount; x++)
204 dest[x + y*rystride] = (rtype_name)0;
207 for (y = 0; y < ycount; y++)
209 bbase_y = bbase + y*bystride;
210 dest_y = dest + y*rystride;
211 for (n = 0; n < count; n++)
213 abase_n = abase + n*aystride;
214 bbase_yn = bbase_y[n];
215 for (x = 0; x < xcount; x++)
217 dest_y[x] += abase_n[x] * bbase_yn;
222 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
224 const rtype_name *restrict abase_x;
225 const rtype_name *restrict bbase_y;
226 rtype_name *restrict dest_y;
229 for (y = 0; y < ycount; y++)
231 bbase_y = &bbase[y*bystride];
232 dest_y = &dest[y*rystride];
233 for (x = 0; x < xcount; x++)
235 abase_x = &abase[x*axstride];
237 for (n = 0; n < count; n++)
238 s += abase_x[n] * bbase_y[n];
243 else if (axstride < aystride)
245 for (y = 0; y < ycount; y++)
246 for (x = 0; x < xcount; x++)
247 dest[x*rxstride + y*rystride] = (rtype_name)0;
249 for (y = 0; y < ycount; y++)
250 for (n = 0; n < count; n++)
251 for (x = 0; x < xcount; x++)
252 /* dest[x,y] += a[x,n] * b[n,y] */
253 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
257 const rtype_name *restrict abase_x;
258 const rtype_name *restrict bbase_y;
259 rtype_name *restrict dest_y;
262 for (y = 0; y < ycount; y++)
264 bbase_y = &bbase[y*bystride];
265 dest_y = &dest[y*rystride];
266 for (x = 0; x < xcount; x++)
268 abase_x = &abase[x*axstride];
270 for (n = 0; n < count; n++)
271 s += abase_x[n*aystride] * bbase_y[n*bxstride];
272 dest_y[x*rxstride] = s;