2018-10-09 Richard Biener <rguenther@suse.de>
[official-gcc.git] / libgfortran / generated / matmulavx128_c10.c
blob03914715d5c0c1e237496784b542a62d8bdb224a
1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2018 Free Software Foundation, Inc.
3 Contributed by Thomas Koenig <tkoenig@gcc.gnu.org>.
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
27 #include <string.h>
28 #include <assert.h>
31 /* These are the specific versions of matmul with -mprefer-avx128. */
33 #if defined (HAVE_GFC_COMPLEX_10)
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we call it for large
37 matrices. */
39 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
40 const int *, const GFC_COMPLEX_10 *, const GFC_COMPLEX_10 *,
41 const int *, const GFC_COMPLEX_10 *, const int *,
42 const GFC_COMPLEX_10 *, GFC_COMPLEX_10 *, const int *,
43 int, int);
#if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
/* Compute RETARRAY = MATMUL (A, B) for GFC_COMPLEX_10 elements.  This
   variant is compiled with __target__("avx,fma").

   RETARRAY    Result descriptor.  If its base_addr is NULL, its bounds are
               derived from A and B and storage is obtained with
               xmallocarray.  Otherwise, when compile_options.bounds_check
               is set, its extents are validated against A and B.
   A, B        Operand descriptors; at least one must have rank 2 (asserted).
   TRY_BLAS    Nonzero allows handing the product to GEMM.  Bit value 2
               selects "C" as the transposition flag for A, bit value 4
               selects "C" for B.
   BLAS_LIMIT  GEMM is only considered when xcount * ycount * count exceeds
               BLAS_LIMIT cubed (compared in float).
   GEMM        BLAS routine; must be non-NULL when the BLAS path is taken
               (asserted).

   Shape and bound violations are reported through runtime_error.  */
void
matmul_c10_avx128_fma3 (gfc_array_c10 * const restrict retarray,
        gfc_array_c10 * const restrict a, gfc_array_c10 * const restrict b, int try_blas,
        int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma")));
internal_proto(matmul_c10_avx128_fma3);

void
matmul_c10_avx128_fma3 (gfc_array_c10 * const restrict retarray,
        gfc_array_c10 * const restrict a, gfc_array_c10 * const restrict b, int try_blas,
        int blas_limit, blas_call gemm)
{
  const GFC_COMPLEX_10 * restrict abase;
  const GFC_COMPLEX_10 * restrict bbase;
  GFC_COMPLEX_10 * restrict dest;

  index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  index_type x, y, n, count, xcount, ycount;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
          || GFC_DESCRIPTOR_RANK (b) == 2);

/* C[xcount,ycount] = A[xcount, count] * B[count,ycount]

   Either A or B (but not both) can be rank 1:

   o One-dimensional argument A is implicitly treated as a row matrix
     dimensioned [1,count], so xcount=1.

   o One-dimensional argument B is implicitly treated as a column matrix
     dimensioned [count, 1], so ycount=1.
*/

  if (retarray->base_addr == NULL)
    {
      /* Unallocated result: set its bounds from the operands, then
         allocate it.  */
      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
        }
      else
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

          GFC_DIMENSION_SET(retarray->dim[1], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1,
                            GFC_DESCRIPTOR_EXTENT(retarray,0));
        }

      retarray->base_addr
        = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_10));
      retarray->offset = 0;
    }
  else if (unlikely (compile_options.bounds_check))
    {
      /* Caller-provided result: check that its extents match what the
         operands produce.  */
      index_type ret_extent, arg_extent;

      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);

          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 2 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
    }


  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      /* One-dimensional result may be addressed in the code below
         either as a row or a column matrix.  We want both cases to
         work.  */
      rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }


  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Treat it as a row matrix A[1,count].  */
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = 1;

      xcount = 1;
      count = GFC_DESCRIPTOR_EXTENT(a,0);
    }
  else
    {
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = GFC_DESCRIPTOR_STRIDE(a,1);

      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  if (count != GFC_DESCRIPTOR_EXTENT(b,0))
    {
      /* A mismatch is tolerated only when neither extent is positive.  */
      if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
        runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
                       "in dimension 1: is %ld, should be %ld",
                       (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
    }

  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      /* Treat it as a column matrix B[count,1] */
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);

      /* bystride should never be used for 1-dimensional b.
         The value is only used for calculation of the
         memory by the buffer.  */
      bystride = 256;
      ycount = 1;
    }
  else
    {
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
      bystride = GFC_DESCRIPTOR_STRIDE(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  abase = a->base_addr;
  bbase = b->base_addr;
  dest = retarray->base_addr;

  /* Now that everything is set up, we perform the multiplication
     itself.  */

#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
#define min(a,b) ((a) <= (b) ? (a) : (b))
#define max(a,b) ((a) >= (b) ? (a) : (b))

  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
      && (bxstride == 1 || bystride == 1)
      && (((float) xcount) * ((float) ycount) * ((float) count)
          > POW3(blas_limit)))
    {
      /* Large, unit-stride problem: try to delegate to the external
         ?gemm routine.  */
      const int m = xcount, n = ycount, k = count, ldc = rystride;
      const GFC_COMPLEX_10 one = 1, zero = 0;
      const int lda = (axstride == 1) ? aystride : axstride,
                ldb = (bxstride == 1) ? bystride : bxstride;

      if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
        {
          assert (gemm != NULL);
          const char *transa, *transb;
          if (try_blas & 2)
            transa = "C";
          else
            transa = axstride == 1 ? "N" : "T";

          if (try_blas & 4)
            transb = "C";
          else
            transb = bxstride == 1 ? "N" : "T";

          gemm (transa, transb , &m,
                &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
                &ldc, 1, 1);
          return;
        }
    }

  if (rxstride == 1 && axstride == 1 && bxstride == 1)
    {
      /* This block of code implements a tuned matmul, derived from
         Superscalar GEMM-based level 3 BLAS,  Beta version 0.1

               Bo Kagstrom and Per Ling
               Department of Computing Science
               Umea University
               S-901 87 Umea, Sweden

         from netlib.org, translated to C, and modified for matmul.m4.  */

      const GFC_COMPLEX_10 *a, *b;
      GFC_COMPLEX_10 *c;
      const index_type m = xcount, n = ycount, k = count;

      /* System generated locals */
      index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
                 i1, i2, i3, i4, i5, i6;

      /* Local variables */
      GFC_COMPLEX_10 f11, f12, f21, f22, f31, f32, f41, f42,
                     f13, f14, f23, f24, f33, f34, f43, f44;
      index_type i, j, l, ii, jj, ll;
      index_type isec, jsec, lsec, uisec, ujsec, ulsec;
      GFC_COMPLEX_10 *t1;

      a = abase;
      b = bbase;
      c = retarray->base_addr;

      /* Parameter adjustments: bias each pointer by (1 + leading dimension)
         so that x[i + j * x_dim1] with 1-based i, j addresses element
         (i-1, j-1).  */
      c_dim1 = rystride;
      c_offset = 1 + c_dim1;
      c -= c_offset;
      a_dim1 = aystride;
      a_offset = 1 + a_dim1;
      a -= a_offset;
      b_dim1 = bystride;
      b_offset = 1 + b_dim1;
      b -= b_offset;

      /* Empty c first.  */
      for (j=1; j<=n; j++)
        for (i=1; i<=m; i++)
          c[i + j * c_dim1] = (GFC_COMPLEX_10)0;

      /* Early exit if possible */
      if (m == 0 || n == 0 || k == 0)
        return;

      /* Adjust size of t1 to what is needed.  */
      index_type t1_dim, a_sz;
      if (aystride == 1)
        a_sz = rystride;
      else
        a_sz = a_dim1;

      t1_dim = a_sz * 256 + b_dim1;
      if (t1_dim > 65536)
        t1_dim = 65536;

      /* Use the library's checked allocator (the same one used for the
         result array above) instead of bare malloc, whose NULL return on
         failure was previously written through unchecked.  */
      t1 = xmallocarray (t1_dim, sizeof (GFC_COMPLEX_10));

      /* Start turning the crank.  The loops tile the problem into blocks
         of at most 512 columns (jj), 256 inner-dimension elements (ll) and
         256 rows (ii).  */
      i1 = n;
      for (jj = 1; jj <= i1; jj += 512)
        {
          /* Computing MIN */
          i2 = 512;
          i3 = n - jj + 1;
          jsec = min(i2,i3);
          ujsec = jsec - jsec % 4;
          i2 = k;
          for (ll = 1; ll <= i2; ll += 256)
            {
              /* Computing MIN */
              i3 = 256;
              i4 = k - ll + 1;
              lsec = min(i3,i4);
              ulsec = lsec - lsec % 2;

              i3 = m;
              for (ii = 1; ii <= i3; ii += 256)
                {
                  /* Computing MIN */
                  i4 = 256;
                  i5 = m - ii + 1;
                  isec = min(i4,i5);
                  uisec = isec - isec % 2;

                  /* Copy the current block of A into t1, transposed: the
                     index expression reduces to (l - ll) + (i - ii) * 256.
                     Copy 2x2 at a time, then the odd row / odd column.  */
                  i4 = ll + ulsec - 1;
                  for (l = ll; l <= i4; l += 2)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 2)
                        {
                          t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
                                        a[i + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
                                        a[i + (l + 1) * a_dim1];
                          t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + (l + 1) * a_dim1];
                        }
                      if (uisec < isec)
                        {
                          t1[l - ll + 1 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + l * a_dim1];
                          t1[l - ll + 2 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + (l + 1) * a_dim1];
                        }
                    }
                  if (ulsec < lsec)
                    {
                      i4 = ii + isec - 1;
                      for (i = ii; i<= i4; ++i)
                        {
                          t1[lsec + ((i - ii + 1) << 8) - 257] =
                                    a[i + (ll + lsec - 1) * a_dim1];
                        }
                    }

                  /* Main kernel: 4 rows x 4 columns of C per iteration.  */
                  uisec = isec - isec % 4;
                  i4 = jj + ujsec - 1;
                  for (j = jj; j <= i4; j += 4)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 4)
                        {
                          f11 = c[i + j * c_dim1];
                          f21 = c[i + 1 + j * c_dim1];
                          f12 = c[i + (j + 1) * c_dim1];
                          f22 = c[i + 1 + (j + 1) * c_dim1];
                          f13 = c[i + (j + 2) * c_dim1];
                          f23 = c[i + 1 + (j + 2) * c_dim1];
                          f14 = c[i + (j + 3) * c_dim1];
                          f24 = c[i + 1 + (j + 3) * c_dim1];
                          f31 = c[i + 2 + j * c_dim1];
                          f41 = c[i + 3 + j * c_dim1];
                          f32 = c[i + 2 + (j + 1) * c_dim1];
                          f42 = c[i + 3 + (j + 1) * c_dim1];
                          f33 = c[i + 2 + (j + 2) * c_dim1];
                          f43 = c[i + 3 + (j + 2) * c_dim1];
                          f34 = c[i + 2 + (j + 3) * c_dim1];
                          f44 = c[i + 3 + (j + 3) * c_dim1];
                          i6 = ll + lsec - 1;
                          for (l = ll; l <= i6; ++l)
                            {
                              f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                            }
                          c[i + j * c_dim1] = f11;
                          c[i + 1 + j * c_dim1] = f21;
                          c[i + (j + 1) * c_dim1] = f12;
                          c[i + 1 + (j + 1) * c_dim1] = f22;
                          c[i + (j + 2) * c_dim1] = f13;
                          c[i + 1 + (j + 2) * c_dim1] = f23;
                          c[i + (j + 3) * c_dim1] = f14;
                          c[i + 1 + (j + 3) * c_dim1] = f24;
                          c[i + 2 + j * c_dim1] = f31;
                          c[i + 3 + j * c_dim1] = f41;
                          c[i + 2 + (j + 1) * c_dim1] = f32;
                          c[i + 3 + (j + 1) * c_dim1] = f42;
                          c[i + 2 + (j + 2) * c_dim1] = f33;
                          c[i + 3 + (j + 2) * c_dim1] = f43;
                          c[i + 2 + (j + 3) * c_dim1] = f34;
                          c[i + 3 + (j + 3) * c_dim1] = f44;
                        }
                      if (uisec < isec)
                        {
                          /* Leftover rows (fewer than 4), 4 columns.  */
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              f12 = c[i + (j + 1) * c_dim1];
                              f13 = c[i + (j + 2) * c_dim1];
                              f14 = c[i + (j + 3) * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 1) * b_dim1];
                                  f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 2) * b_dim1];
                                  f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 3) * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + (j + 1) * c_dim1] = f12;
                              c[i + (j + 2) * c_dim1] = f13;
                              c[i + (j + 3) * c_dim1] = f14;
                            }
                        }
                    }
                  if (ujsec < jsec)
                    {
                      /* Leftover columns (fewer than 4), one at a time.  */
                      i4 = jj + jsec - 1;
                      for (j = jj + ujsec; j <= i4; ++j)
                        {
                          i5 = ii + uisec - 1;
                          for (i = ii; i <= i5; i += 4)
                            {
                              f11 = c[i + j * c_dim1];
                              f21 = c[i + 1 + j * c_dim1];
                              f31 = c[i + 2 + j * c_dim1];
                              f41 = c[i + 3 + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + 1 + j * c_dim1] = f21;
                              c[i + 2 + j * c_dim1] = f31;
                              c[i + 3 + j * c_dim1] = f41;
                            }
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                            }
                        }
                    }
                }
            }
        }
      free(t1);
      return;
    }
  else if (rxstride == 1 && aystride == 1 && bxstride == 1)
    {
      if (GFC_DESCRIPTOR_RANK (a) != 1)
        {
          const GFC_COMPLEX_10 *restrict abase_x;
          const GFC_COMPLEX_10 *restrict bbase_y;
          GFC_COMPLEX_10 *restrict dest_y;
          GFC_COMPLEX_10 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              dest_y = &dest[y*rystride];
              for (x = 0; x < xcount; x++)
                {
                  abase_x = &abase[x*axstride];
                  s = (GFC_COMPLEX_10) 0;
                  for (n = 0; n < count; n++)
                    s += abase_x[n] * bbase_y[n];
                  dest_y[x] = s;
                }
            }
        }
      else
        {
          const GFC_COMPLEX_10 *restrict bbase_y;
          GFC_COMPLEX_10 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              s = (GFC_COMPLEX_10) 0;
              for (n = 0; n < count; n++)
                s += abase[n*axstride] * bbase_y[n];
              dest[y*rystride] = s;
            }
        }
    }
  else if (axstride < aystride)
    {
      for (y = 0; y < ycount; y++)
        for (x = 0; x < xcount; x++)
          dest[x*rxstride + y*rystride] = (GFC_COMPLEX_10)0;

      for (y = 0; y < ycount; y++)
        for (n = 0; n < count; n++)
          for (x = 0; x < xcount; x++)
            /* dest[x,y] += a[x,n] * b[n,y] */
            dest[x*rxstride + y*rystride] +=
                                        abase[x*axstride + n*aystride] *
                                        bbase[n*bxstride + y*bystride];
    }
  else if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      const GFC_COMPLEX_10 *restrict bbase_y;
      GFC_COMPLEX_10 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          s = (GFC_COMPLEX_10) 0;
          for (n = 0; n < count; n++)
            s += abase[n*axstride] * bbase_y[n*bxstride];
          dest[y*rxstride] = s;
        }
    }
  else
    {
      const GFC_COMPLEX_10 *restrict abase_x;
      const GFC_COMPLEX_10 *restrict bbase_y;
      GFC_COMPLEX_10 *restrict dest_y;
      GFC_COMPLEX_10 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          dest_y = &dest[y*rystride];
          for (x = 0; x < xcount; x++)
            {
              abase_x = &abase[x*axstride];
              s = (GFC_COMPLEX_10) 0;
              for (n = 0; n < count; n++)
                s += abase_x[n*aystride] * bbase_y[n*bxstride];
              dest_y[x*rxstride] = s;
            }
        }
    }
}
#undef POW3
#undef min
#undef max

#endif
#if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
/* Compute RETARRAY = MATMUL (A, B) for GFC_COMPLEX_10 elements.  This
   variant is compiled with __target__("avx,fma4").

   RETARRAY    Result descriptor.  If its base_addr is NULL, its bounds are
               derived from A and B and storage is obtained with
               xmallocarray.  Otherwise, when compile_options.bounds_check
               is set, its extents are validated against A and B.
   A, B        Operand descriptors; at least one must have rank 2 (asserted).
   TRY_BLAS    Nonzero allows handing the product to GEMM.  Bit value 2
               selects "C" as the transposition flag for A, bit value 4
               selects "C" for B.
   BLAS_LIMIT  GEMM is only considered when xcount * ycount * count exceeds
               BLAS_LIMIT cubed (compared in float).
   GEMM        BLAS routine; must be non-NULL when the BLAS path is taken
               (asserted).

   Shape and bound violations are reported through runtime_error.  */
void
matmul_c10_avx128_fma4 (gfc_array_c10 * const restrict retarray,
        gfc_array_c10 * const restrict a, gfc_array_c10 * const restrict b, int try_blas,
        int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma4")));
internal_proto(matmul_c10_avx128_fma4);

void
matmul_c10_avx128_fma4 (gfc_array_c10 * const restrict retarray,
        gfc_array_c10 * const restrict a, gfc_array_c10 * const restrict b, int try_blas,
        int blas_limit, blas_call gemm)
{
  const GFC_COMPLEX_10 * restrict abase;
  const GFC_COMPLEX_10 * restrict bbase;
  GFC_COMPLEX_10 * restrict dest;

  index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  index_type x, y, n, count, xcount, ycount;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
          || GFC_DESCRIPTOR_RANK (b) == 2);

/* C[xcount,ycount] = A[xcount, count] * B[count,ycount]

   Either A or B (but not both) can be rank 1:

   o One-dimensional argument A is implicitly treated as a row matrix
     dimensioned [1,count], so xcount=1.

   o One-dimensional argument B is implicitly treated as a column matrix
     dimensioned [count, 1], so ycount=1.
*/

  if (retarray->base_addr == NULL)
    {
      /* Unallocated result: set its bounds from the operands, then
         allocate it.  */
      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
        }
      else
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

          GFC_DIMENSION_SET(retarray->dim[1], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1,
                            GFC_DESCRIPTOR_EXTENT(retarray,0));
        }

      retarray->base_addr
        = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_10));
      retarray->offset = 0;
    }
  else if (unlikely (compile_options.bounds_check))
    {
      /* Caller-provided result: check that its extents match what the
         operands produce.  */
      index_type ret_extent, arg_extent;

      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);

          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 2 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
    }


  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      /* One-dimensional result may be addressed in the code below
         either as a row or a column matrix.  We want both cases to
         work.  */
      rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }


  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Treat it as a row matrix A[1,count].  */
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = 1;

      xcount = 1;
      count = GFC_DESCRIPTOR_EXTENT(a,0);
    }
  else
    {
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = GFC_DESCRIPTOR_STRIDE(a,1);

      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  if (count != GFC_DESCRIPTOR_EXTENT(b,0))
    {
      /* A mismatch is tolerated only when neither extent is positive.  */
      if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
        runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
                       "in dimension 1: is %ld, should be %ld",
                       (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
    }

  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      /* Treat it as a column matrix B[count,1] */
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);

      /* bystride should never be used for 1-dimensional b.
         The value is only used for calculation of the
         memory by the buffer.  */
      bystride = 256;
      ycount = 1;
    }
  else
    {
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
      bystride = GFC_DESCRIPTOR_STRIDE(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  abase = a->base_addr;
  bbase = b->base_addr;
  dest = retarray->base_addr;

  /* Now that everything is set up, we perform the multiplication
     itself.  */

#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
#define min(a,b) ((a) <= (b) ? (a) : (b))
#define max(a,b) ((a) >= (b) ? (a) : (b))

  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
      && (bxstride == 1 || bystride == 1)
      && (((float) xcount) * ((float) ycount) * ((float) count)
          > POW3(blas_limit)))
    {
      /* Large, unit-stride problem: try to delegate to the external
         ?gemm routine.  */
      const int m = xcount, n = ycount, k = count, ldc = rystride;
      const GFC_COMPLEX_10 one = 1, zero = 0;
      const int lda = (axstride == 1) ? aystride : axstride,
                ldb = (bxstride == 1) ? bystride : bxstride;

      if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
        {
          assert (gemm != NULL);
          const char *transa, *transb;
          if (try_blas & 2)
            transa = "C";
          else
            transa = axstride == 1 ? "N" : "T";

          if (try_blas & 4)
            transb = "C";
          else
            transb = bxstride == 1 ? "N" : "T";

          gemm (transa, transb , &m,
                &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
                &ldc, 1, 1);
          return;
        }
    }

  if (rxstride == 1 && axstride == 1 && bxstride == 1)
    {
      /* This block of code implements a tuned matmul, derived from
         Superscalar GEMM-based level 3 BLAS,  Beta version 0.1

               Bo Kagstrom and Per Ling
               Department of Computing Science
               Umea University
               S-901 87 Umea, Sweden

         from netlib.org, translated to C, and modified for matmul.m4.  */

      const GFC_COMPLEX_10 *a, *b;
      GFC_COMPLEX_10 *c;
      const index_type m = xcount, n = ycount, k = count;

      /* System generated locals */
      index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
                 i1, i2, i3, i4, i5, i6;

      /* Local variables */
      GFC_COMPLEX_10 f11, f12, f21, f22, f31, f32, f41, f42,
                     f13, f14, f23, f24, f33, f34, f43, f44;
      index_type i, j, l, ii, jj, ll;
      index_type isec, jsec, lsec, uisec, ujsec, ulsec;
      GFC_COMPLEX_10 *t1;

      a = abase;
      b = bbase;
      c = retarray->base_addr;

      /* Parameter adjustments: bias each pointer by (1 + leading dimension)
         so that x[i + j * x_dim1] with 1-based i, j addresses element
         (i-1, j-1).  */
      c_dim1 = rystride;
      c_offset = 1 + c_dim1;
      c -= c_offset;
      a_dim1 = aystride;
      a_offset = 1 + a_dim1;
      a -= a_offset;
      b_dim1 = bystride;
      b_offset = 1 + b_dim1;
      b -= b_offset;

      /* Empty c first.  */
      for (j=1; j<=n; j++)
        for (i=1; i<=m; i++)
          c[i + j * c_dim1] = (GFC_COMPLEX_10)0;

      /* Early exit if possible */
      if (m == 0 || n == 0 || k == 0)
        return;

      /* Adjust size of t1 to what is needed.  */
      index_type t1_dim, a_sz;
      if (aystride == 1)
        a_sz = rystride;
      else
        a_sz = a_dim1;

      t1_dim = a_sz * 256 + b_dim1;
      if (t1_dim > 65536)
        t1_dim = 65536;

      /* Use the library's checked allocator (the same one used for the
         result array above) instead of bare malloc, whose NULL return on
         failure was previously written through unchecked.  */
      t1 = xmallocarray (t1_dim, sizeof (GFC_COMPLEX_10));

      /* Start turning the crank.  The loops tile the problem into blocks
         of at most 512 columns (jj), 256 inner-dimension elements (ll) and
         256 rows (ii).  */
      i1 = n;
      for (jj = 1; jj <= i1; jj += 512)
        {
          /* Computing MIN */
          i2 = 512;
          i3 = n - jj + 1;
          jsec = min(i2,i3);
          ujsec = jsec - jsec % 4;
          i2 = k;
          for (ll = 1; ll <= i2; ll += 256)
            {
              /* Computing MIN */
              i3 = 256;
              i4 = k - ll + 1;
              lsec = min(i3,i4);
              ulsec = lsec - lsec % 2;

              i3 = m;
              for (ii = 1; ii <= i3; ii += 256)
                {
                  /* Computing MIN */
                  i4 = 256;
                  i5 = m - ii + 1;
                  isec = min(i4,i5);
                  uisec = isec - isec % 2;

                  /* Copy the current block of A into t1, transposed: the
                     index expression reduces to (l - ll) + (i - ii) * 256.
                     Copy 2x2 at a time, then the odd row / odd column.  */
                  i4 = ll + ulsec - 1;
                  for (l = ll; l <= i4; l += 2)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 2)
                        {
                          t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
                                        a[i + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
                                        a[i + (l + 1) * a_dim1];
                          t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + (l + 1) * a_dim1];
                        }
                      if (uisec < isec)
                        {
                          t1[l - ll + 1 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + l * a_dim1];
                          t1[l - ll + 2 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + (l + 1) * a_dim1];
                        }
                    }
                  if (ulsec < lsec)
                    {
                      i4 = ii + isec - 1;
                      for (i = ii; i<= i4; ++i)
                        {
                          t1[lsec + ((i - ii + 1) << 8) - 257] =
                                    a[i + (ll + lsec - 1) * a_dim1];
                        }
                    }

                  /* Main kernel: 4 rows x 4 columns of C per iteration.  */
                  uisec = isec - isec % 4;
                  i4 = jj + ujsec - 1;
                  for (j = jj; j <= i4; j += 4)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 4)
                        {
                          f11 = c[i + j * c_dim1];
                          f21 = c[i + 1 + j * c_dim1];
                          f12 = c[i + (j + 1) * c_dim1];
                          f22 = c[i + 1 + (j + 1) * c_dim1];
                          f13 = c[i + (j + 2) * c_dim1];
                          f23 = c[i + 1 + (j + 2) * c_dim1];
                          f14 = c[i + (j + 3) * c_dim1];
                          f24 = c[i + 1 + (j + 3) * c_dim1];
                          f31 = c[i + 2 + j * c_dim1];
                          f41 = c[i + 3 + j * c_dim1];
                          f32 = c[i + 2 + (j + 1) * c_dim1];
                          f42 = c[i + 3 + (j + 1) * c_dim1];
                          f33 = c[i + 2 + (j + 2) * c_dim1];
                          f43 = c[i + 3 + (j + 2) * c_dim1];
                          f34 = c[i + 2 + (j + 3) * c_dim1];
                          f44 = c[i + 3 + (j + 3) * c_dim1];
                          i6 = ll + lsec - 1;
                          for (l = ll; l <= i6; ++l)
                            {
                              f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                            }
                          c[i + j * c_dim1] = f11;
                          c[i + 1 + j * c_dim1] = f21;
                          c[i + (j + 1) * c_dim1] = f12;
                          c[i + 1 + (j + 1) * c_dim1] = f22;
                          c[i + (j + 2) * c_dim1] = f13;
                          c[i + 1 + (j + 2) * c_dim1] = f23;
                          c[i + (j + 3) * c_dim1] = f14;
                          c[i + 1 + (j + 3) * c_dim1] = f24;
                          c[i + 2 + j * c_dim1] = f31;
                          c[i + 3 + j * c_dim1] = f41;
                          c[i + 2 + (j + 1) * c_dim1] = f32;
                          c[i + 3 + (j + 1) * c_dim1] = f42;
                          c[i + 2 + (j + 2) * c_dim1] = f33;
                          c[i + 3 + (j + 2) * c_dim1] = f43;
                          c[i + 2 + (j + 3) * c_dim1] = f34;
                          c[i + 3 + (j + 3) * c_dim1] = f44;
                        }
                      if (uisec < isec)
                        {
                          /* Leftover rows (fewer than 4), 4 columns.  */
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              f12 = c[i + (j + 1) * c_dim1];
                              f13 = c[i + (j + 2) * c_dim1];
                              f14 = c[i + (j + 3) * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 1) * b_dim1];
                                  f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 2) * b_dim1];
                                  f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 3) * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + (j + 1) * c_dim1] = f12;
                              c[i + (j + 2) * c_dim1] = f13;
                              c[i + (j + 3) * c_dim1] = f14;
                            }
                        }
                    }
                  if (ujsec < jsec)
                    {
                      /* Leftover columns (fewer than 4), one at a time.  */
                      i4 = jj + jsec - 1;
                      for (j = jj + ujsec; j <= i4; ++j)
                        {
                          i5 = ii + uisec - 1;
                          for (i = ii; i <= i5; i += 4)
                            {
                              f11 = c[i + j * c_dim1];
                              f21 = c[i + 1 + j * c_dim1];
                              f31 = c[i + 2 + j * c_dim1];
                              f41 = c[i + 3 + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + 1 + j * c_dim1] = f21;
                              c[i + 2 + j * c_dim1] = f31;
                              c[i + 3 + j * c_dim1] = f41;
                            }
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                            }
                        }
                    }
                }
            }
        }
      free(t1);
      return;
    }
  else if (rxstride == 1 && aystride == 1 && bxstride == 1)
    {
      if (GFC_DESCRIPTOR_RANK (a) != 1)
        {
          const GFC_COMPLEX_10 *restrict abase_x;
          const GFC_COMPLEX_10 *restrict bbase_y;
          GFC_COMPLEX_10 *restrict dest_y;
          GFC_COMPLEX_10 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              dest_y = &dest[y*rystride];
              for (x = 0; x < xcount; x++)
                {
                  abase_x = &abase[x*axstride];
                  s = (GFC_COMPLEX_10) 0;
                  for (n = 0; n < count; n++)
                    s += abase_x[n] * bbase_y[n];
                  dest_y[x] = s;
                }
            }
        }
      else
        {
          const GFC_COMPLEX_10 *restrict bbase_y;
          GFC_COMPLEX_10 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              s = (GFC_COMPLEX_10) 0;
              for (n = 0; n < count; n++)
                s += abase[n*axstride] * bbase_y[n];
              dest[y*rystride] = s;
            }
        }
    }
  else if (axstride < aystride)
    {
      for (y = 0; y < ycount; y++)
        for (x = 0; x < xcount; x++)
          dest[x*rxstride + y*rystride] = (GFC_COMPLEX_10)0;

      for (y = 0; y < ycount; y++)
        for (n = 0; n < count; n++)
          for (x = 0; x < xcount; x++)
            /* dest[x,y] += a[x,n] * b[n,y] */
            dest[x*rxstride + y*rystride] +=
                                        abase[x*axstride + n*aystride] *
                                        bbase[n*bxstride + y*bystride];
    }
  else if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      const GFC_COMPLEX_10 *restrict bbase_y;
      GFC_COMPLEX_10 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          s = (GFC_COMPLEX_10) 0;
          for (n = 0; n < count; n++)
            s += abase[n*axstride] * bbase_y[n*bxstride];
          dest[y*rxstride] = s;
        }
    }
  else
    {
      const GFC_COMPLEX_10 *restrict abase_x;
      const GFC_COMPLEX_10 *restrict bbase_y;
      GFC_COMPLEX_10 *restrict dest_y;
      GFC_COMPLEX_10 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          dest_y = &dest[y*rystride];
          for (x = 0; x < xcount; x++)
            {
              abase_x = &abase[x*axstride];
              s = (GFC_COMPLEX_10) 0;
              for (n = 0; n < count; n++)
                s += abase_x[n*aystride] * bbase_y[n*bxstride];
              dest_y[x*rxstride] = s;
            }
        }
    }
}
#undef POW3
#undef min
#undef max

#endif
1183 #endif