Rebase.
[official-gcc.git] / libgfortran / generated / matmul_i4.c
blobf62cb56aa1be25810679691c1952d6569eecfb5e
1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2014 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
27 #include <stdlib.h>
28 #include <string.h>
29 #include <assert.h>
32 #if defined (HAVE_GFC_INTEGER_4)
34 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
35 passed to us by the front-end, in which case we'll call it for large
36 matrices. */
38 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
39 const int *, const GFC_INTEGER_4 *, const GFC_INTEGER_4 *,
40 const int *, const GFC_INTEGER_4 *, const int *,
41 const GFC_INTEGER_4 *, GFC_INTEGER_4 *, const int *,
42 int, int);
/* The order of loops is different in the case of plain matrix
   multiplication C=MATMUL(A,B), and in the frequent special case where
   the argument A is the temporary result of a TRANSPOSE intrinsic:
   C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
   looking at their strides.

   The equivalent Fortran pseudo-code is:

   DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
   IF (.NOT.IS_TRANSPOSED(A)) THEN
     C = 0
     DO J=1,N
       DO K=1,COUNT
         DO I=1,M
           C(I,J) = C(I,J)+A(I,K)*B(K,J)
   ELSE
     DO J=1,N
       DO I=1,M
         S = 0
         DO K=1,COUNT
           S = S+A(I,K)*B(K,J)
         C(I,J) = S
   ENDIF
*/
69 /* If try_blas is set to a nonzero value, then the matmul function will
70 see if there is a way to perform the matrix multiplication by a call
71 to the BLAS gemm function. */
73 extern void matmul_i4 (gfc_array_i4 * const restrict retarray,
74 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
75 int blas_limit, blas_call gemm);
76 export_proto(matmul_i4);
78 void
79 matmul_i4 (gfc_array_i4 * const restrict retarray,
80 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
81 int blas_limit, blas_call gemm)
83 const GFC_INTEGER_4 * restrict abase;
84 const GFC_INTEGER_4 * restrict bbase;
85 GFC_INTEGER_4 * restrict dest;
87 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
88 index_type x, y, n, count, xcount, ycount;
90 assert (GFC_DESCRIPTOR_RANK (a) == 2
91 || GFC_DESCRIPTOR_RANK (b) == 2);
93 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
95 Either A or B (but not both) can be rank 1:
97 o One-dimensional argument A is implicitly treated as a row matrix
98 dimensioned [1,count], so xcount=1.
100 o One-dimensional argument B is implicitly treated as a column matrix
101 dimensioned [count, 1], so ycount=1.
104 if (retarray->base_addr == NULL)
106 if (GFC_DESCRIPTOR_RANK (a) == 1)
108 GFC_DIMENSION_SET(retarray->dim[0], 0,
109 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
111 else if (GFC_DESCRIPTOR_RANK (b) == 1)
113 GFC_DIMENSION_SET(retarray->dim[0], 0,
114 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
116 else
118 GFC_DIMENSION_SET(retarray->dim[0], 0,
119 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
121 GFC_DIMENSION_SET(retarray->dim[1], 0,
122 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
123 GFC_DESCRIPTOR_EXTENT(retarray,0));
126 retarray->base_addr
127 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_4));
128 retarray->offset = 0;
130 else if (unlikely (compile_options.bounds_check))
132 index_type ret_extent, arg_extent;
134 if (GFC_DESCRIPTOR_RANK (a) == 1)
136 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
137 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
138 if (arg_extent != ret_extent)
139 runtime_error ("Incorrect extent in return array in"
140 " MATMUL intrinsic: is %ld, should be %ld",
141 (long int) ret_extent, (long int) arg_extent);
143 else if (GFC_DESCRIPTOR_RANK (b) == 1)
145 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
146 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
147 if (arg_extent != ret_extent)
148 runtime_error ("Incorrect extent in return array in"
149 " MATMUL intrinsic: is %ld, should be %ld",
150 (long int) ret_extent, (long int) arg_extent);
152 else
154 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
155 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
156 if (arg_extent != ret_extent)
157 runtime_error ("Incorrect extent in return array in"
158 " MATMUL intrinsic for dimension 1:"
159 " is %ld, should be %ld",
160 (long int) ret_extent, (long int) arg_extent);
162 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
163 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
164 if (arg_extent != ret_extent)
165 runtime_error ("Incorrect extent in return array in"
166 " MATMUL intrinsic for dimension 2:"
167 " is %ld, should be %ld",
168 (long int) ret_extent, (long int) arg_extent);
173 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
175 /* One-dimensional result may be addressed in the code below
176 either as a row or a column matrix. We want both cases to
177 work. */
178 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
180 else
182 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
183 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
187 if (GFC_DESCRIPTOR_RANK (a) == 1)
189 /* Treat it as a a row matrix A[1,count]. */
190 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
191 aystride = 1;
193 xcount = 1;
194 count = GFC_DESCRIPTOR_EXTENT(a,0);
196 else
198 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
199 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
201 count = GFC_DESCRIPTOR_EXTENT(a,1);
202 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
205 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
207 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
208 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
211 if (GFC_DESCRIPTOR_RANK (b) == 1)
213 /* Treat it as a column matrix B[count,1] */
214 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
216 /* bystride should never be used for 1-dimensional b.
217 in case it is we want it to cause a segfault, rather than
218 an incorrect result. */
219 bystride = 0xDEADBEEF;
220 ycount = 1;
222 else
224 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
225 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
226 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
229 abase = a->base_addr;
230 bbase = b->base_addr;
231 dest = retarray->base_addr;
234 /* Now that everything is set up, we're performing the multiplication
235 itself. */
237 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
239 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
240 && (bxstride == 1 || bystride == 1)
241 && (((float) xcount) * ((float) ycount) * ((float) count)
242 > POW3(blas_limit)))
244 const int m = xcount, n = ycount, k = count, ldc = rystride;
245 const GFC_INTEGER_4 one = 1, zero = 0;
246 const int lda = (axstride == 1) ? aystride : axstride,
247 ldb = (bxstride == 1) ? bystride : bxstride;
249 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
251 assert (gemm != NULL);
252 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
253 &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
254 return;
258 if (rxstride == 1 && axstride == 1 && bxstride == 1)
260 const GFC_INTEGER_4 * restrict bbase_y;
261 GFC_INTEGER_4 * restrict dest_y;
262 const GFC_INTEGER_4 * restrict abase_n;
263 GFC_INTEGER_4 bbase_yn;
265 if (rystride == xcount)
266 memset (dest, 0, (sizeof (GFC_INTEGER_4) * xcount * ycount));
267 else
269 for (y = 0; y < ycount; y++)
270 for (x = 0; x < xcount; x++)
271 dest[x + y*rystride] = (GFC_INTEGER_4)0;
274 for (y = 0; y < ycount; y++)
276 bbase_y = bbase + y*bystride;
277 dest_y = dest + y*rystride;
278 for (n = 0; n < count; n++)
280 abase_n = abase + n*aystride;
281 bbase_yn = bbase_y[n];
282 for (x = 0; x < xcount; x++)
284 dest_y[x] += abase_n[x] * bbase_yn;
289 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
291 if (GFC_DESCRIPTOR_RANK (a) != 1)
293 const GFC_INTEGER_4 *restrict abase_x;
294 const GFC_INTEGER_4 *restrict bbase_y;
295 GFC_INTEGER_4 *restrict dest_y;
296 GFC_INTEGER_4 s;
298 for (y = 0; y < ycount; y++)
300 bbase_y = &bbase[y*bystride];
301 dest_y = &dest[y*rystride];
302 for (x = 0; x < xcount; x++)
304 abase_x = &abase[x*axstride];
305 s = (GFC_INTEGER_4) 0;
306 for (n = 0; n < count; n++)
307 s += abase_x[n] * bbase_y[n];
308 dest_y[x] = s;
312 else
314 const GFC_INTEGER_4 *restrict bbase_y;
315 GFC_INTEGER_4 s;
317 for (y = 0; y < ycount; y++)
319 bbase_y = &bbase[y*bystride];
320 s = (GFC_INTEGER_4) 0;
321 for (n = 0; n < count; n++)
322 s += abase[n*axstride] * bbase_y[n];
323 dest[y*rystride] = s;
327 else if (axstride < aystride)
329 for (y = 0; y < ycount; y++)
330 for (x = 0; x < xcount; x++)
331 dest[x*rxstride + y*rystride] = (GFC_INTEGER_4)0;
333 for (y = 0; y < ycount; y++)
334 for (n = 0; n < count; n++)
335 for (x = 0; x < xcount; x++)
336 /* dest[x,y] += a[x,n] * b[n,y] */
337 dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
339 else if (GFC_DESCRIPTOR_RANK (a) == 1)
341 const GFC_INTEGER_4 *restrict bbase_y;
342 GFC_INTEGER_4 s;
344 for (y = 0; y < ycount; y++)
346 bbase_y = &bbase[y*bystride];
347 s = (GFC_INTEGER_4) 0;
348 for (n = 0; n < count; n++)
349 s += abase[n*axstride] * bbase_y[n*bxstride];
350 dest[y*rxstride] = s;
353 else
355 const GFC_INTEGER_4 *restrict abase_x;
356 const GFC_INTEGER_4 *restrict bbase_y;
357 GFC_INTEGER_4 *restrict dest_y;
358 GFC_INTEGER_4 s;
360 for (y = 0; y < ycount; y++)
362 bbase_y = &bbase[y*bystride];
363 dest_y = &dest[y*rystride];
364 for (x = 0; x < xcount; x++)
366 abase_x = &abase[x*axstride];
367 s = (GFC_INTEGER_4) 0;
368 for (n = 0; n < count; n++)
369 s += abase_x[n*aystride] * bbase_y[n*bxstride];
370 dest_y[x*rxstride] = s;
376 #endif