gcc:
[official-gcc.git] / libgfortran / generated / matmulavx128_c8.c
blobaccc69c4d1a20556569fc2bb00ec28f9a96a61e8
1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2018 Free Software Foundation, Inc.
3 Contributed by Thomas Koenig <tkoenig@gcc.gnu.org>.
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
27 #include <string.h>
28 #include <assert.h>
31 /* These are the specific versions of matmul with -mprefer-avx128. */
33 #if defined (HAVE_GFC_COMPLEX_8)
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we call it for large
37 matrices. */
39 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
40 const int *, const GFC_COMPLEX_8 *, const GFC_COMPLEX_8 *,
41 const int *, const GFC_COMPLEX_8 *, const int *,
42 const GFC_COMPLEX_8 *, GFC_COMPLEX_8 *, const int *,
43 int, int);
#if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
/* MATMUL for COMPLEX(KIND=8) operands, compiled with
   __target__("avx,fma"), so it must only be selected on hardware that
   supports AVX and FMA3.

   Computes RETARRAY = MATMUL (A, B).

   - If RETARRAY->base_addr is NULL, the result descriptor is given
     lower bound 0 and the shape implied by A and B, and its storage is
     allocated here.  Otherwise, when bounds checking is enabled, the
     extents of RETARRAY are checked and runtime_error is raised on a
     mismatch.
   - A runtime_error is raised if the inner dimensions of A and B do not
     agree (unless both are empty).
   - If TRY_BLAS is nonzero, the strides allow a BLAS call, and
     xcount*ycount*count exceeds BLAS_LIMIT**3, the product is delegated
     to GEMM (which must then be non-NULL), provided the leading
     dimensions are positive and all sizes exceed 1.
   - Otherwise a cache-blocked kernel (all unit leading strides) or a
     plain loop nest computes the product.  */
void
matmul_c8_avx128_fma3 (gfc_array_c8 * const restrict retarray,
        gfc_array_c8 * const restrict a, gfc_array_c8 * const restrict b, int try_blas,
        int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma")));
internal_proto(matmul_c8_avx128_fma3);

void
matmul_c8_avx128_fma3 (gfc_array_c8 * const restrict retarray,
        gfc_array_c8 * const restrict a, gfc_array_c8 * const restrict b, int try_blas,
        int blas_limit, blas_call gemm)
{
  const GFC_COMPLEX_8 * restrict abase;
  const GFC_COMPLEX_8 * restrict bbase;
  GFC_COMPLEX_8 * restrict dest;

  index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  index_type x, y, n, count, xcount, ycount;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
          || GFC_DESCRIPTOR_RANK (b) == 2);

  /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]

     Either A or B (but not both) can be rank 1:

     o One-dimensional argument A is implicitly treated as a row matrix
       dimensioned [1,count], so xcount=1.

     o One-dimensional argument B is implicitly treated as a column matrix
       dimensioned [count, 1], so ycount=1.  */

  if (retarray->base_addr == NULL)
    {
      /* Result not allocated yet: derive its shape from A and B, using
         column-major strides, then allocate it.  */
      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
        }
      else
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

          GFC_DIMENSION_SET(retarray->dim[1], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1,
                            GFC_DESCRIPTOR_EXTENT(retarray,0));
        }

      retarray->base_addr
        = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_8));
      retarray->offset = 0;
    }
  else if (unlikely (compile_options.bounds_check))
    {
      index_type ret_extent, arg_extent;

      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Incorrect extent in return array in"
                           " MATMUL intrinsic: is %ld, should be %ld",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Incorrect extent in return array in"
                           " MATMUL intrinsic: is %ld, should be %ld",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Incorrect extent in return array in"
                           " MATMUL intrinsic for dimension 1:"
                           " is %ld, should be %ld",
                           (long int) ret_extent, (long int) arg_extent);

          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
          if (arg_extent != ret_extent)
            runtime_error ("Incorrect extent in return array in"
                           " MATMUL intrinsic for dimension 2:"
                           " is %ld, should be %ld",
                           (long int) ret_extent, (long int) arg_extent);
        }
    }

  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      /* One-dimensional result may be addressed in the code below
         either as a row or a column matrix. We want both cases to
         work. */
      rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }

  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Treat it as a row matrix A[1,count].  Note that the real stride
         between consecutive elements of A is kept in axstride; aystride
         is only a placeholder.  */
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = 1;

      xcount = 1;
      count = GFC_DESCRIPTOR_EXTENT(a,0);
    }
  else
    {
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = GFC_DESCRIPTOR_STRIDE(a,1);

      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  if (count != GFC_DESCRIPTOR_EXTENT(b,0))
    {
      if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
        runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
    }

  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      /* Treat it as a column matrix B[count,1] */
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);

      /* ycount is 1, so bystride is never used to step between columns
         of B.  The placeholder value only enters the size computation of
         the t1 buffer in the blocked kernel below.  */
      bystride = 256;
      ycount = 1;
    }
  else
    {
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
      bystride = GFC_DESCRIPTOR_STRIDE(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  abase = a->base_addr;
  bbase = b->base_addr;
  dest = retarray->base_addr;

  /* Now that everything is set up, we perform the multiplication
     itself.  */

#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
#define min(a,b) ((a) <= (b) ? (a) : (b))
#define max(a,b) ((a) >= (b) ? (a) : (b))

  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
      && (bxstride == 1 || bystride == 1)
      && (((float) xcount) * ((float) ycount) * ((float) count)
          > POW3(blas_limit)))
    {
      const int m = xcount, n = ycount, k = count, ldc = rystride;
      const GFC_COMPLEX_8 one = 1, zero = 0;
      const int lda = (axstride == 1) ? aystride : axstride,
                ldb = (bxstride == 1) ? bystride : bxstride;

      if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
        {
          assert (gemm != NULL);
          gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m,
                &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
                &ldc, 1, 1);
          return;
        }
    }

  if (rxstride == 1 && axstride == 1 && bxstride == 1)
    {
      /* This block of code implements a tuned matmul, derived from
         Superscalar GEMM-based level 3 BLAS,  Beta version 0.1

               Bo Kagstrom and Per Ling
               Department of Computing Science
               Umea University
               S-901 87 Umea, Sweden

         from netlib.org, translated to C, and modified for matmul.m4.  */

      /* These locals deliberately shadow the descriptor parameters A and B
         so the translated Fortran indexing below reads naturally.  */
      const GFC_COMPLEX_8 *a, *b;
      GFC_COMPLEX_8 *c;
      const index_type m = xcount, n = ycount, k = count;

      /* System generated locals */
      index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
                 i1, i2, i3, i4, i5, i6;

      /* Local variables */
      GFC_COMPLEX_8 f11, f12, f21, f22, f31, f32, f41, f42,
                    f13, f14, f23, f24, f33, f34, f43, f44;
      index_type i, j, l, ii, jj, ll;
      index_type isec, jsec, lsec, uisec, ujsec, ulsec;
      GFC_COMPLEX_8 *t1;

      a = abase;
      b = bbase;
      c = retarray->base_addr;

      /* Parameter adjustments: shift the base pointers so that the
         1-based, column-major indexing x[i + j * x_dim1] of the original
         Fortran code can be used unchanged.  */
      c_dim1 = rystride;
      c_offset = 1 + c_dim1;
      c -= c_offset;
      a_dim1 = aystride;
      a_offset = 1 + a_dim1;
      a -= a_offset;
      b_dim1 = bystride;
      b_offset = 1 + b_dim1;
      b -= b_offset;

      /* Empty c first.  */
      for (j = 1; j <= n; j++)
        for (i = 1; i <= m; i++)
          c[i + j * c_dim1] = (GFC_COMPLEX_8)0;

      /* Early exit if possible */
      if (m == 0 || n == 0 || k == 0)
        return;

      /* Adjust size of t1 to what is needed.  t1 holds a transposed copy
         of the current (at most 256 x 256) block of A: element (i, l) is
         stored at (i - ii) * 256 + (l - ll).  No index therefore exceeds
         256 * 256 - 1, which is why the size is capped at 65536.  */
      index_type t1_dim, a_sz;
      if (aystride == 1)
        a_sz = rystride;
      else
        a_sz = a_dim1;

      t1_dim = a_sz * 256 + b_dim1;
      if (t1_dim > 65536)
        t1_dim = 65536;

      /* Use the same checked allocator as for RETARRAY above instead of
         a bare malloc whose NULL result would be dereferenced below.  */
      t1 = xmallocarray (t1_dim, sizeof (GFC_COMPLEX_8));

      /* Start turning the crank. */
      i1 = n;
      for (jj = 1; jj <= i1; jj += 512)
        {
          /* Computing MIN */
          i2 = 512;
          i3 = n - jj + 1;
          jsec = min(i2,i3);
          ujsec = jsec - jsec % 4;
          i2 = k;
          for (ll = 1; ll <= i2; ll += 256)
            {
              /* Computing MIN */
              i3 = 256;
              i4 = k - ll + 1;
              lsec = min(i3,i4);
              ulsec = lsec - lsec % 2;

              i3 = m;
              for (ii = 1; ii <= i3; ii += 256)
                {
                  /* Computing MIN */
                  i4 = 256;
                  i5 = m - ii + 1;
                  isec = min(i4,i5);
                  uisec = isec - isec % 2;

                  /* Copy the A block into t1, two rows and two columns
                     at a time, then the odd row and odd column left
                     over.  */
                  i4 = ll + ulsec - 1;
                  for (l = ll; l <= i4; l += 2)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 2)
                        {
                          t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
                                        a[i + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
                                        a[i + (l + 1) * a_dim1];
                          t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + (l + 1) * a_dim1];
                        }
                      if (uisec < isec)
                        {
                          t1[l - ll + 1 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + l * a_dim1];
                          t1[l - ll + 2 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + (l + 1) * a_dim1];
                        }
                    }
                  if (ulsec < lsec)
                    {
                      i4 = ii + isec - 1;
                      for (i = ii; i <= i4; ++i)
                        {
                          t1[lsec + ((i - ii + 1) << 8) - 257] =
                                    a[i + (ll + lsec - 1) * a_dim1];
                        }
                    }

                  /* Multiply the block in 4 x 4 tiles of C, with
                     clean-up loops for the remaining rows/columns.  */
                  uisec = isec - isec % 4;
                  i4 = jj + ujsec - 1;
                  for (j = jj; j <= i4; j += 4)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 4)
                        {
                          f11 = c[i + j * c_dim1];
                          f21 = c[i + 1 + j * c_dim1];
                          f12 = c[i + (j + 1) * c_dim1];
                          f22 = c[i + 1 + (j + 1) * c_dim1];
                          f13 = c[i + (j + 2) * c_dim1];
                          f23 = c[i + 1 + (j + 2) * c_dim1];
                          f14 = c[i + (j + 3) * c_dim1];
                          f24 = c[i + 1 + (j + 3) * c_dim1];
                          f31 = c[i + 2 + j * c_dim1];
                          f41 = c[i + 3 + j * c_dim1];
                          f32 = c[i + 2 + (j + 1) * c_dim1];
                          f42 = c[i + 3 + (j + 1) * c_dim1];
                          f33 = c[i + 2 + (j + 2) * c_dim1];
                          f43 = c[i + 3 + (j + 2) * c_dim1];
                          f34 = c[i + 2 + (j + 3) * c_dim1];
                          f44 = c[i + 3 + (j + 3) * c_dim1];
                          i6 = ll + lsec - 1;
                          for (l = ll; l <= i6; ++l)
                            {
                              f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                     * b[l + j * b_dim1];
                              f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                     * b[l + j * b_dim1];
                              f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                     * b[l + (j + 1) * b_dim1];
                              f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                     * b[l + (j + 1) * b_dim1];
                              f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                     * b[l + (j + 2) * b_dim1];
                              f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                     * b[l + (j + 2) * b_dim1];
                              f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                     * b[l + (j + 3) * b_dim1];
                              f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                     * b[l + (j + 3) * b_dim1];
                              f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                     * b[l + j * b_dim1];
                              f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                     * b[l + j * b_dim1];
                              f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                     * b[l + (j + 1) * b_dim1];
                              f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                     * b[l + (j + 1) * b_dim1];
                              f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                     * b[l + (j + 2) * b_dim1];
                              f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                     * b[l + (j + 2) * b_dim1];
                              f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                     * b[l + (j + 3) * b_dim1];
                              f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                     * b[l + (j + 3) * b_dim1];
                            }
                          c[i + j * c_dim1] = f11;
                          c[i + 1 + j * c_dim1] = f21;
                          c[i + (j + 1) * c_dim1] = f12;
                          c[i + 1 + (j + 1) * c_dim1] = f22;
                          c[i + (j + 2) * c_dim1] = f13;
                          c[i + 1 + (j + 2) * c_dim1] = f23;
                          c[i + (j + 3) * c_dim1] = f14;
                          c[i + 1 + (j + 3) * c_dim1] = f24;
                          c[i + 2 + j * c_dim1] = f31;
                          c[i + 3 + j * c_dim1] = f41;
                          c[i + 2 + (j + 1) * c_dim1] = f32;
                          c[i + 3 + (j + 1) * c_dim1] = f42;
                          c[i + 2 + (j + 2) * c_dim1] = f33;
                          c[i + 3 + (j + 2) * c_dim1] = f43;
                          c[i + 2 + (j + 3) * c_dim1] = f34;
                          c[i + 3 + (j + 3) * c_dim1] = f44;
                        }
                      if (uisec < isec)
                        {
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              f12 = c[i + (j + 1) * c_dim1];
                              f13 = c[i + (j + 2) * c_dim1];
                              f14 = c[i + (j + 3) * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 1) * b_dim1];
                                  f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 2) * b_dim1];
                                  f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 3) * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + (j + 1) * c_dim1] = f12;
                              c[i + (j + 2) * c_dim1] = f13;
                              c[i + (j + 3) * c_dim1] = f14;
                            }
                        }
                    }
                  if (ujsec < jsec)
                    {
                      i4 = jj + jsec - 1;
                      for (j = jj + ujsec; j <= i4; ++j)
                        {
                          i5 = ii + uisec - 1;
                          for (i = ii; i <= i5; i += 4)
                            {
                              f11 = c[i + j * c_dim1];
                              f21 = c[i + 1 + j * c_dim1];
                              f31 = c[i + 2 + j * c_dim1];
                              f41 = c[i + 3 + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + 1 + j * c_dim1] = f21;
                              c[i + 2 + j * c_dim1] = f31;
                              c[i + 3 + j * c_dim1] = f41;
                            }
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                            }
                        }
                    }
                }
            }
        }
      free(t1);
      return;
    }
  else if (rxstride == 1 && aystride == 1 && bxstride == 1)
    {
      if (GFC_DESCRIPTOR_RANK (a) != 1)
        {
          const GFC_COMPLEX_8 *restrict abase_x;
          const GFC_COMPLEX_8 *restrict bbase_y;
          GFC_COMPLEX_8 *restrict dest_y;
          GFC_COMPLEX_8 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              dest_y = &dest[y*rystride];
              for (x = 0; x < xcount; x++)
                {
                  abase_x = &abase[x*axstride];
                  s = (GFC_COMPLEX_8) 0;
                  for (n = 0; n < count; n++)
                    s += abase_x[n] * bbase_y[n];
                  dest_y[x] = s;
                }
            }
        }
      else
        {
          const GFC_COMPLEX_8 *restrict bbase_y;
          GFC_COMPLEX_8 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              s = (GFC_COMPLEX_8) 0;
              for (n = 0; n < count; n++)
                s += abase[n*axstride] * bbase_y[n];
              dest[y*rystride] = s;
            }
        }
    }
  else if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Rank-1 A must be handled before the generic axstride < aystride
         kernel: that kernel steps through A with aystride, which is only
         a placeholder (1) for rank-1 A, and it would be selected whenever
         axstride <= 0 (e.g. a negative-stride section).  */
      const GFC_COMPLEX_8 *restrict bbase_y;
      GFC_COMPLEX_8 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          s = (GFC_COMPLEX_8) 0;
          for (n = 0; n < count; n++)
            s += abase[n*axstride] * bbase_y[n*bxstride];
          dest[y*rxstride] = s;
        }
    }
  else if (axstride < aystride)
    {
      for (y = 0; y < ycount; y++)
        for (x = 0; x < xcount; x++)
          dest[x*rxstride + y*rystride] = (GFC_COMPLEX_8)0;

      for (y = 0; y < ycount; y++)
        for (n = 0; n < count; n++)
          for (x = 0; x < xcount; x++)
            /* dest[x,y] += a[x,n] * b[n,y] */
            dest[x*rxstride + y*rystride] +=
                                        abase[x*axstride + n*aystride] *
                                        bbase[n*bxstride + y*bystride];
    }
  else
    {
      const GFC_COMPLEX_8 *restrict abase_x;
      const GFC_COMPLEX_8 *restrict bbase_y;
      GFC_COMPLEX_8 *restrict dest_y;
      GFC_COMPLEX_8 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          dest_y = &dest[y*rystride];
          for (x = 0; x < xcount; x++)
            {
              abase_x = &abase[x*axstride];
              s = (GFC_COMPLEX_8) 0;
              for (n = 0; n < count; n++)
                s += abase_x[n*aystride] * bbase_y[n*bxstride];
              dest_y[x*rxstride] = s;
            }
        }
    }
}
#undef POW3
#undef min
#undef max

#endif
#if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
/* MATMUL for COMPLEX(KIND=8) operands, compiled with
   __target__("avx,fma4"), so it must only be selected on hardware that
   supports AVX and FMA4.

   Computes RETARRAY = MATMUL (A, B).

   - If RETARRAY->base_addr is NULL, the result descriptor is given
     lower bound 0 and the shape implied by A and B, and its storage is
     allocated here.  Otherwise, when bounds checking is enabled, the
     extents of RETARRAY are checked and runtime_error is raised on a
     mismatch.
   - A runtime_error is raised if the inner dimensions of A and B do not
     agree (unless both are empty).
   - If TRY_BLAS is nonzero, the strides allow a BLAS call, and
     xcount*ycount*count exceeds BLAS_LIMIT**3, the product is delegated
     to GEMM (which must then be non-NULL), provided the leading
     dimensions are positive and all sizes exceed 1.
   - Otherwise a cache-blocked kernel (all unit leading strides) or a
     plain loop nest computes the product.  */
void
matmul_c8_avx128_fma4 (gfc_array_c8 * const restrict retarray,
        gfc_array_c8 * const restrict a, gfc_array_c8 * const restrict b, int try_blas,
        int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma4")));
internal_proto(matmul_c8_avx128_fma4);

void
matmul_c8_avx128_fma4 (gfc_array_c8 * const restrict retarray,
        gfc_array_c8 * const restrict a, gfc_array_c8 * const restrict b, int try_blas,
        int blas_limit, blas_call gemm)
{
  const GFC_COMPLEX_8 * restrict abase;
  const GFC_COMPLEX_8 * restrict bbase;
  GFC_COMPLEX_8 * restrict dest;

  index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  index_type x, y, n, count, xcount, ycount;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
          || GFC_DESCRIPTOR_RANK (b) == 2);

  /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]

     Either A or B (but not both) can be rank 1:

     o One-dimensional argument A is implicitly treated as a row matrix
       dimensioned [1,count], so xcount=1.

     o One-dimensional argument B is implicitly treated as a column matrix
       dimensioned [count, 1], so ycount=1.  */

  if (retarray->base_addr == NULL)
    {
      /* Result not allocated yet: derive its shape from A and B, using
         column-major strides, then allocate it.  */
      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
        }
      else
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

          GFC_DIMENSION_SET(retarray->dim[1], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1,
                            GFC_DESCRIPTOR_EXTENT(retarray,0));
        }

      retarray->base_addr
        = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_8));
      retarray->offset = 0;
    }
  else if (unlikely (compile_options.bounds_check))
    {
      index_type ret_extent, arg_extent;

      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Incorrect extent in return array in"
                           " MATMUL intrinsic: is %ld, should be %ld",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Incorrect extent in return array in"
                           " MATMUL intrinsic: is %ld, should be %ld",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Incorrect extent in return array in"
                           " MATMUL intrinsic for dimension 1:"
                           " is %ld, should be %ld",
                           (long int) ret_extent, (long int) arg_extent);

          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
          if (arg_extent != ret_extent)
            runtime_error ("Incorrect extent in return array in"
                           " MATMUL intrinsic for dimension 2:"
                           " is %ld, should be %ld",
                           (long int) ret_extent, (long int) arg_extent);
        }
    }

  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      /* One-dimensional result may be addressed in the code below
         either as a row or a column matrix. We want both cases to
         work. */
      rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }

  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Treat it as a row matrix A[1,count].  Note that the real stride
         between consecutive elements of A is kept in axstride; aystride
         is only a placeholder.  */
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = 1;

      xcount = 1;
      count = GFC_DESCRIPTOR_EXTENT(a,0);
    }
  else
    {
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = GFC_DESCRIPTOR_STRIDE(a,1);

      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  if (count != GFC_DESCRIPTOR_EXTENT(b,0))
    {
      if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
        runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
    }

  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      /* Treat it as a column matrix B[count,1] */
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);

      /* ycount is 1, so bystride is never used to step between columns
         of B.  The placeholder value only enters the size computation of
         the t1 buffer in the blocked kernel below.  */
      bystride = 256;
      ycount = 1;
    }
  else
    {
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
      bystride = GFC_DESCRIPTOR_STRIDE(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  abase = a->base_addr;
  bbase = b->base_addr;
  dest = retarray->base_addr;

  /* Now that everything is set up, we perform the multiplication
     itself.  */

#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
#define min(a,b) ((a) <= (b) ? (a) : (b))
#define max(a,b) ((a) >= (b) ? (a) : (b))

  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
      && (bxstride == 1 || bystride == 1)
      && (((float) xcount) * ((float) ycount) * ((float) count)
          > POW3(blas_limit)))
    {
      const int m = xcount, n = ycount, k = count, ldc = rystride;
      const GFC_COMPLEX_8 one = 1, zero = 0;
      const int lda = (axstride == 1) ? aystride : axstride,
                ldb = (bxstride == 1) ? bystride : bxstride;

      if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
        {
          assert (gemm != NULL);
          gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m,
                &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
                &ldc, 1, 1);
          return;
        }
    }

  if (rxstride == 1 && axstride == 1 && bxstride == 1)
    {
      /* This block of code implements a tuned matmul, derived from
         Superscalar GEMM-based level 3 BLAS,  Beta version 0.1

               Bo Kagstrom and Per Ling
               Department of Computing Science
               Umea University
               S-901 87 Umea, Sweden

         from netlib.org, translated to C, and modified for matmul.m4.  */

      /* These locals deliberately shadow the descriptor parameters A and B
         so the translated Fortran indexing below reads naturally.  */
      const GFC_COMPLEX_8 *a, *b;
      GFC_COMPLEX_8 *c;
      const index_type m = xcount, n = ycount, k = count;

      /* System generated locals */
      index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
                 i1, i2, i3, i4, i5, i6;

      /* Local variables */
      GFC_COMPLEX_8 f11, f12, f21, f22, f31, f32, f41, f42,
                    f13, f14, f23, f24, f33, f34, f43, f44;
      index_type i, j, l, ii, jj, ll;
      index_type isec, jsec, lsec, uisec, ujsec, ulsec;
      GFC_COMPLEX_8 *t1;

      a = abase;
      b = bbase;
      c = retarray->base_addr;

      /* Parameter adjustments: shift the base pointers so that the
         1-based, column-major indexing x[i + j * x_dim1] of the original
         Fortran code can be used unchanged.  */
      c_dim1 = rystride;
      c_offset = 1 + c_dim1;
      c -= c_offset;
      a_dim1 = aystride;
      a_offset = 1 + a_dim1;
      a -= a_offset;
      b_dim1 = bystride;
      b_offset = 1 + b_dim1;
      b -= b_offset;

      /* Empty c first.  */
      for (j = 1; j <= n; j++)
        for (i = 1; i <= m; i++)
          c[i + j * c_dim1] = (GFC_COMPLEX_8)0;

      /* Early exit if possible */
      if (m == 0 || n == 0 || k == 0)
        return;

      /* Adjust size of t1 to what is needed.  t1 holds a transposed copy
         of the current (at most 256 x 256) block of A: element (i, l) is
         stored at (i - ii) * 256 + (l - ll).  No index therefore exceeds
         256 * 256 - 1, which is why the size is capped at 65536.  */
      index_type t1_dim, a_sz;
      if (aystride == 1)
        a_sz = rystride;
      else
        a_sz = a_dim1;

      t1_dim = a_sz * 256 + b_dim1;
      if (t1_dim > 65536)
        t1_dim = 65536;

      /* Use the same checked allocator as for RETARRAY above instead of
         a bare malloc whose NULL result would be dereferenced below.  */
      t1 = xmallocarray (t1_dim, sizeof (GFC_COMPLEX_8));

      /* Start turning the crank. */
      i1 = n;
      for (jj = 1; jj <= i1; jj += 512)
        {
          /* Computing MIN */
          i2 = 512;
          i3 = n - jj + 1;
          jsec = min(i2,i3);
          ujsec = jsec - jsec % 4;
          i2 = k;
          for (ll = 1; ll <= i2; ll += 256)
            {
              /* Computing MIN */
              i3 = 256;
              i4 = k - ll + 1;
              lsec = min(i3,i4);
              ulsec = lsec - lsec % 2;

              i3 = m;
              for (ii = 1; ii <= i3; ii += 256)
                {
                  /* Computing MIN */
                  i4 = 256;
                  i5 = m - ii + 1;
                  isec = min(i4,i5);
                  uisec = isec - isec % 2;

                  /* Copy the A block into t1, two rows and two columns
                     at a time, then the odd row and odd column left
                     over.  */
                  i4 = ll + ulsec - 1;
                  for (l = ll; l <= i4; l += 2)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 2)
                        {
                          t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
                                        a[i + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
                                        a[i + (l + 1) * a_dim1];
                          t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + (l + 1) * a_dim1];
                        }
                      if (uisec < isec)
                        {
                          t1[l - ll + 1 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + l * a_dim1];
                          t1[l - ll + 2 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + (l + 1) * a_dim1];
                        }
                    }
                  if (ulsec < lsec)
                    {
                      i4 = ii + isec - 1;
                      for (i = ii; i <= i4; ++i)
                        {
                          t1[lsec + ((i - ii + 1) << 8) - 257] =
                                    a[i + (ll + lsec - 1) * a_dim1];
                        }
                    }

                  /* Multiply the block in 4 x 4 tiles of C, with
                     clean-up loops for the remaining rows/columns.  */
                  uisec = isec - isec % 4;
                  i4 = jj + ujsec - 1;
                  for (j = jj; j <= i4; j += 4)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 4)
                        {
                          f11 = c[i + j * c_dim1];
                          f21 = c[i + 1 + j * c_dim1];
                          f12 = c[i + (j + 1) * c_dim1];
                          f22 = c[i + 1 + (j + 1) * c_dim1];
                          f13 = c[i + (j + 2) * c_dim1];
                          f23 = c[i + 1 + (j + 2) * c_dim1];
                          f14 = c[i + (j + 3) * c_dim1];
                          f24 = c[i + 1 + (j + 3) * c_dim1];
                          f31 = c[i + 2 + j * c_dim1];
                          f41 = c[i + 3 + j * c_dim1];
                          f32 = c[i + 2 + (j + 1) * c_dim1];
                          f42 = c[i + 3 + (j + 1) * c_dim1];
                          f33 = c[i + 2 + (j + 2) * c_dim1];
                          f43 = c[i + 3 + (j + 2) * c_dim1];
                          f34 = c[i + 2 + (j + 3) * c_dim1];
                          f44 = c[i + 3 + (j + 3) * c_dim1];
                          i6 = ll + lsec - 1;
                          for (l = ll; l <= i6; ++l)
                            {
                              f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                     * b[l + j * b_dim1];
                              f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                     * b[l + j * b_dim1];
                              f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                     * b[l + (j + 1) * b_dim1];
                              f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                     * b[l + (j + 1) * b_dim1];
                              f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                     * b[l + (j + 2) * b_dim1];
                              f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                     * b[l + (j + 2) * b_dim1];
                              f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                     * b[l + (j + 3) * b_dim1];
                              f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                     * b[l + (j + 3) * b_dim1];
                              f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                     * b[l + j * b_dim1];
                              f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                     * b[l + j * b_dim1];
                              f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                     * b[l + (j + 1) * b_dim1];
                              f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                     * b[l + (j + 1) * b_dim1];
                              f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                     * b[l + (j + 2) * b_dim1];
                              f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                     * b[l + (j + 2) * b_dim1];
                              f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                     * b[l + (j + 3) * b_dim1];
                              f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                     * b[l + (j + 3) * b_dim1];
                            }
                          c[i + j * c_dim1] = f11;
                          c[i + 1 + j * c_dim1] = f21;
                          c[i + (j + 1) * c_dim1] = f12;
                          c[i + 1 + (j + 1) * c_dim1] = f22;
                          c[i + (j + 2) * c_dim1] = f13;
                          c[i + 1 + (j + 2) * c_dim1] = f23;
                          c[i + (j + 3) * c_dim1] = f14;
                          c[i + 1 + (j + 3) * c_dim1] = f24;
                          c[i + 2 + j * c_dim1] = f31;
                          c[i + 3 + j * c_dim1] = f41;
                          c[i + 2 + (j + 1) * c_dim1] = f32;
                          c[i + 3 + (j + 1) * c_dim1] = f42;
                          c[i + 2 + (j + 2) * c_dim1] = f33;
                          c[i + 3 + (j + 2) * c_dim1] = f43;
                          c[i + 2 + (j + 3) * c_dim1] = f34;
                          c[i + 3 + (j + 3) * c_dim1] = f44;
                        }
                      if (uisec < isec)
                        {
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              f12 = c[i + (j + 1) * c_dim1];
                              f13 = c[i + (j + 2) * c_dim1];
                              f14 = c[i + (j + 3) * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 1) * b_dim1];
                                  f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 2) * b_dim1];
                                  f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 3) * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + (j + 1) * c_dim1] = f12;
                              c[i + (j + 2) * c_dim1] = f13;
                              c[i + (j + 3) * c_dim1] = f14;
                            }
                        }
                    }
                  if (ujsec < jsec)
                    {
                      i4 = jj + jsec - 1;
                      for (j = jj + ujsec; j <= i4; ++j)
                        {
                          i5 = ii + uisec - 1;
                          for (i = ii; i <= i5; i += 4)
                            {
                              f11 = c[i + j * c_dim1];
                              f21 = c[i + 1 + j * c_dim1];
                              f31 = c[i + 2 + j * c_dim1];
                              f41 = c[i + 3 + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + 1 + j * c_dim1] = f21;
                              c[i + 2 + j * c_dim1] = f31;
                              c[i + 3 + j * c_dim1] = f41;
                            }
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                            }
                        }
                    }
                }
            }
        }
      free(t1);
      return;
    }
  else if (rxstride == 1 && aystride == 1 && bxstride == 1)
    {
      if (GFC_DESCRIPTOR_RANK (a) != 1)
        {
          const GFC_COMPLEX_8 *restrict abase_x;
          const GFC_COMPLEX_8 *restrict bbase_y;
          GFC_COMPLEX_8 *restrict dest_y;
          GFC_COMPLEX_8 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              dest_y = &dest[y*rystride];
              for (x = 0; x < xcount; x++)
                {
                  abase_x = &abase[x*axstride];
                  s = (GFC_COMPLEX_8) 0;
                  for (n = 0; n < count; n++)
                    s += abase_x[n] * bbase_y[n];
                  dest_y[x] = s;
                }
            }
        }
      else
        {
          const GFC_COMPLEX_8 *restrict bbase_y;
          GFC_COMPLEX_8 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              s = (GFC_COMPLEX_8) 0;
              for (n = 0; n < count; n++)
                s += abase[n*axstride] * bbase_y[n];
              dest[y*rystride] = s;
            }
        }
    }
  else if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Rank-1 A must be handled before the generic axstride < aystride
         kernel: that kernel steps through A with aystride, which is only
         a placeholder (1) for rank-1 A, and it would be selected whenever
         axstride <= 0 (e.g. a negative-stride section).  */
      const GFC_COMPLEX_8 *restrict bbase_y;
      GFC_COMPLEX_8 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          s = (GFC_COMPLEX_8) 0;
          for (n = 0; n < count; n++)
            s += abase[n*axstride] * bbase_y[n*bxstride];
          dest[y*rxstride] = s;
        }
    }
  else if (axstride < aystride)
    {
      for (y = 0; y < ycount; y++)
        for (x = 0; x < xcount; x++)
          dest[x*rxstride + y*rystride] = (GFC_COMPLEX_8)0;

      for (y = 0; y < ycount; y++)
        for (n = 0; n < count; n++)
          for (x = 0; x < xcount; x++)
            /* dest[x,y] += a[x,n] * b[n,y] */
            dest[x*rxstride + y*rystride] +=
                                        abase[x*axstride + n*aystride] *
                                        bbase[n*bxstride + y*bystride];
    }
  else
    {
      const GFC_COMPLEX_8 *restrict abase_x;
      const GFC_COMPLEX_8 *restrict bbase_y;
      GFC_COMPLEX_8 *restrict dest_y;
      GFC_COMPLEX_8 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          dest_y = &dest[y*rystride];
          for (x = 0; x < xcount; x++)
            {
              abase_x = &abase[x*axstride];
              s = (GFC_COMPLEX_8) 0;
              for (n = 0; n < count; n++)
                s += abase_x[n*aystride] * bbase_y[n*bxstride];
              dest_y[x*rxstride] = s;
            }
        }
    }
}
#undef POW3
#undef min
#undef max

#endif
1161 #endif