1 `/* Implementation of the MATMUL intrinsic
2 Copyright 2002, 2005, 2006, 2007, 2009 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
33 `#if defined (HAVE_'rtype_name`)
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we''`ll call it for large
39 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
40 const int *, const 'rtype_name` *, const 'rtype_name` *,
41 const int *, const 'rtype_name` *, const int *,
42 const 'rtype_name` *, 'rtype_name` *, const int *,
45 /* The order of loops is different in the case of plain matrix
46 multiplication C=MATMUL(A,B), and in the frequent special case where
47 the argument A is the temporary result of a TRANSPOSE intrinsic:
48 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
49 looking at their strides.
51 The equivalent Fortran pseudo-code is:
53 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
54 IF (.NOT.IS_TRANSPOSED(A)) THEN
59 C(I,J) = C(I,J)+A(I,K)*B(K,J)
70 /* If try_blas is set to a nonzero value, then the matmul function will
71 see if there is a way to perform the matrix multiplication by a call
72 to the BLAS gemm function.

   Parameters (as used by the definition that follows):
     retarray   Result descriptor.  If its data pointer is NULL, the
                result storage is allocated by matmul with extents taken
                from A and B.  Otherwise, when bounds checking is enabled,
                its extents are checked against A and B and a mismatch is
                reported with runtime_error.
     a, b       Operands.  At least one of them must have rank 2 (this is
                asserted).  A rank-1 A is treated as a [1,count] row
                matrix and a rank-1 B as a [count,1] column matrix.  A
                mismatch between the inner extents of A and B is reported
                with runtime_error unless neither extent is positive.
     try_blas   Nonzero permits the gemm path.  That path is taken only
                when the result has unit stride in its first dimension,
                A and B each have a unit stride in one dimension, the
                problem is large enough (see blas_limit), the leading
                dimensions passed to gemm are positive, and m, n and k are
                all greater than 1.
     blas_limit Size threshold for the gemm path.  NOTE(review): the
                product xcount*ycount*count is presumably compared against
                POW3(blas_limit); the comparison itself is not visible
                here -- confirm.
     gemm       Pointer to the BLAS ?gemm routine.  It is asserted to be
                non-NULL when the gemm path is taken.  */
74 extern void matmul_'rtype_code` ('rtype` * const restrict retarray,
75 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
76 int blas_limit, blas_call gemm);
77 export_proto(matmul_'rtype_code`);
80 matmul_'rtype_code` ('rtype` * const restrict retarray,
81 'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
82 int blas_limit, blas_call gemm)
84 const 'rtype_name` * restrict abase;
85 const 'rtype_name` * restrict bbase;
86 'rtype_name` * restrict dest;
88 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
89 index_type x, y, n, count, xcount, ycount;
91 assert (GFC_DESCRIPTOR_RANK (a) == 2
92 || GFC_DESCRIPTOR_RANK (b) == 2);
94 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
96 Either A or B (but not both) can be rank 1:
98 o One-dimensional argument A is implicitly treated as a row matrix
99 dimensioned [1,count], so xcount=1.
101 o One-dimensional argument B is implicitly treated as a column matrix
102 dimensioned [count, 1], so ycount=1.
105 if (retarray->data == NULL)
107 if (GFC_DESCRIPTOR_RANK (a) == 1)
109 GFC_DIMENSION_SET(retarray->dim[0], 0,
110 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
112 else if (GFC_DESCRIPTOR_RANK (b) == 1)
114 GFC_DIMENSION_SET(retarray->dim[0], 0,
115 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
119 GFC_DIMENSION_SET(retarray->dim[0], 0,
120 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
122 GFC_DIMENSION_SET(retarray->dim[1], 0,
123 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
124 GFC_DESCRIPTOR_EXTENT(retarray,0));
128 = internal_malloc_size (sizeof ('rtype_name`) * size0 ((array_t *) retarray));
129 retarray->offset = 0;
131 else if (unlikely (compile_options.bounds_check))
133 index_type ret_extent, arg_extent;
135 if (GFC_DESCRIPTOR_RANK (a) == 1)
137 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
138 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
139 if (arg_extent != ret_extent)
140 runtime_error ("Incorrect extent in return array in"
141 " MATMUL intrinsic: is %ld, should be %ld",
142 (long int) ret_extent, (long int) arg_extent);
144 else if (GFC_DESCRIPTOR_RANK (b) == 1)
146 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
147 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
148 if (arg_extent != ret_extent)
149 runtime_error ("Incorrect extent in return array in"
150 " MATMUL intrinsic: is %ld, should be %ld",
151 (long int) ret_extent, (long int) arg_extent);
155 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
156 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
157 if (arg_extent != ret_extent)
158 runtime_error ("Incorrect extent in return array in"
159 " MATMUL intrinsic for dimension 1:"
160 " is %ld, should be %ld",
161 (long int) ret_extent, (long int) arg_extent);
163 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
164 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
165 if (arg_extent != ret_extent)
166 runtime_error ("Incorrect extent in return array in"
167 " MATMUL intrinsic for dimension 2:"
168 " is %ld, should be %ld",
169 (long int) ret_extent, (long int) arg_extent);
173 sinclude(`matmul_asm_'rtype_code`.m4')dnl
175 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
177 /* One-dimensional result may be addressed in the code below
178 either as a row or a column matrix. We want both cases to
180 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
184 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
185 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
189 if (GFC_DESCRIPTOR_RANK (a) == 1)
191 /* Treat it as a row matrix A[1,count]. */
192 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
196 count = GFC_DESCRIPTOR_EXTENT(a,0);
200 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
201 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
203 count = GFC_DESCRIPTOR_EXTENT(a,1);
204 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
207 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
209 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
210 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
213 if (GFC_DESCRIPTOR_RANK (b) == 1)
215 /* Treat it as a column matrix B[count,1] */
216 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
218 /* bystride should never be used for a 1-dimensional b.
219 In case it is, the bogus stride is meant to cause a segfault rather
220 than a silently incorrect result. */
221 bystride = 0xDEADBEEF;
226 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
227 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
228 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
233 dest = retarray->data;
236 /* Now that everything is set up, we''`re performing the multiplication
239 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
241 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
242 && (bxstride == 1 || bystride == 1)
243 && (((float) xcount) * ((float) ycount) * ((float) count)
246 const int m = xcount, n = ycount, k = count, ldc = rystride;
247 const 'rtype_name` one = 1, zero = 0;
248 const int lda = (axstride == 1) ? aystride : axstride,
249 ldb = (bxstride == 1) ? bystride : bxstride;
251 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
253 assert (gemm != NULL);
254 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
255 &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
260 if (rxstride == 1 && axstride == 1 && bxstride == 1)
262 const 'rtype_name` * restrict bbase_y;
263 'rtype_name` * restrict dest_y;
264 const 'rtype_name` * restrict abase_n;
265 'rtype_name` bbase_yn;
267 if (rystride == xcount)
268 memset (dest, 0, (sizeof ('rtype_name`) * xcount * ycount));
271 for (y = 0; y < ycount; y++)
272 for (x = 0; x < xcount; x++)
273 dest[x + y*rystride] = ('rtype_name`)0;
276 for (y = 0; y < ycount; y++)
278 bbase_y = bbase + y*bystride;
279 dest_y = dest + y*rystride;
280 for (n = 0; n < count; n++)
282 abase_n = abase + n*aystride;
283 bbase_yn = bbase_y[n];
284 for (x = 0; x < xcount; x++)
286 dest_y[x] += abase_n[x] * bbase_yn;
291 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
293 if (GFC_DESCRIPTOR_RANK (a) != 1)
295 const 'rtype_name` *restrict abase_x;
296 const 'rtype_name` *restrict bbase_y;
297 'rtype_name` *restrict dest_y;
300 for (y = 0; y < ycount; y++)
302 bbase_y = &bbase[y*bystride];
303 dest_y = &dest[y*rystride];
304 for (x = 0; x < xcount; x++)
306 abase_x = &abase[x*axstride];
307 s = ('rtype_name`) 0;
308 for (n = 0; n < count; n++)
309 s += abase_x[n] * bbase_y[n];
316 const 'rtype_name` *restrict bbase_y;
319 for (y = 0; y < ycount; y++)
321 bbase_y = &bbase[y*bystride];
322 s = ('rtype_name`) 0;
323 for (n = 0; n < count; n++)
324 s += abase[n*axstride] * bbase_y[n];
325 dest[y*rystride] = s;
329 else if (axstride < aystride)
331 for (y = 0; y < ycount; y++)
332 for (x = 0; x < xcount; x++)
333 dest[x*rxstride + y*rystride] = ('rtype_name`)0;
335 for (y = 0; y < ycount; y++)
336 for (n = 0; n < count; n++)
337 for (x = 0; x < xcount; x++)
338 /* dest[x,y] += a[x,n] * b[n,y] */
339 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
341 else if (GFC_DESCRIPTOR_RANK (a) == 1)
343 const 'rtype_name` *restrict bbase_y;
346 for (y = 0; y < ycount; y++)
348 bbase_y = &bbase[y*bystride];
349 s = ('rtype_name`) 0;
350 for (n = 0; n < count; n++)
351 s += abase[n*axstride] * bbase_y[n*bxstride];
352 dest[y*rxstride] = s;
357 const 'rtype_name` *restrict abase_x;
358 const 'rtype_name` *restrict bbase_y;
359 'rtype_name` *restrict dest_y;
362 for (y = 0; y < ycount; y++)
364 bbase_y = &bbase[y*bystride];
365 dest_y = &dest[y*rystride];
366 for (x = 0; x < xcount; x++)
368 abase_x = &abase[x*axstride];
369 s = ('rtype_name`) 0;
370 for (n = 0; n < count; n++)
371 s += abase_x[n*aystride] * bbase_y[n*bxstride];
372 dest_y[x*rxstride] = s;