lower-bitint: Fix lowering of non-_BitInt to _BitInt cast merged with some wider...
[official-gcc.git] / libgfortran / generated / matmulavx128_c16.c
blob200ff4708f85e404309d91cc2d43f89c7b0de044
1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2023 Free Software Foundation, Inc.
3 Contributed by Thomas Koenig <tkoenig@gcc.gnu.org>.
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
27 #include <string.h>
28 #include <assert.h>
31 /* These are the specific versions of matmul with -mprefer-avx128. */
33 #if defined (HAVE_GFC_COMPLEX_16)
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we call it for large
37 matrices. */
39 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
40 const int *, const GFC_COMPLEX_16 *, const GFC_COMPLEX_16 *,
41 const int *, const GFC_COMPLEX_16 *, const int *,
42 const GFC_COMPLEX_16 *, GFC_COMPLEX_16 *, const int *,
43 int, int);
45 #if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
46 void
47 matmul_c16_avx128_fma3 (gfc_array_c16 * const restrict retarray,
48 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
49 int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma")));
50 internal_proto(matmul_c16_avx128_fma3);
51 void
52 matmul_c16_avx128_fma3 (gfc_array_c16 * const restrict retarray,
53 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
54 int blas_limit, blas_call gemm)
56 const GFC_COMPLEX_16 * restrict abase;
57 const GFC_COMPLEX_16 * restrict bbase;
58 GFC_COMPLEX_16 * restrict dest;
60 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
61 index_type x, y, n, count, xcount, ycount;
63 assert (GFC_DESCRIPTOR_RANK (a) == 2
64 || GFC_DESCRIPTOR_RANK (b) == 2);
66 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
68 Either A or B (but not both) can be rank 1:
70 o One-dimensional argument A is implicitly treated as a row matrix
71 dimensioned [1,count], so xcount=1.
73 o One-dimensional argument B is implicitly treated as a column matrix
74 dimensioned [count, 1], so ycount=1.
77 if (retarray->base_addr == NULL)
79 if (GFC_DESCRIPTOR_RANK (a) == 1)
81 GFC_DIMENSION_SET(retarray->dim[0], 0,
82 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
84 else if (GFC_DESCRIPTOR_RANK (b) == 1)
86 GFC_DIMENSION_SET(retarray->dim[0], 0,
87 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
89 else
91 GFC_DIMENSION_SET(retarray->dim[0], 0,
92 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
94 GFC_DIMENSION_SET(retarray->dim[1], 0,
95 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
96 GFC_DESCRIPTOR_EXTENT(retarray,0));
99 retarray->base_addr
100 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_16));
101 retarray->offset = 0;
103 else if (unlikely (compile_options.bounds_check))
105 index_type ret_extent, arg_extent;
107 if (GFC_DESCRIPTOR_RANK (a) == 1)
109 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
110 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
111 if (arg_extent != ret_extent)
112 runtime_error ("Array bound mismatch for dimension 1 of "
113 "array (%ld/%ld) ",
114 (long int) ret_extent, (long int) arg_extent);
116 else if (GFC_DESCRIPTOR_RANK (b) == 1)
118 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
119 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
120 if (arg_extent != ret_extent)
121 runtime_error ("Array bound mismatch for dimension 1 of "
122 "array (%ld/%ld) ",
123 (long int) ret_extent, (long int) arg_extent);
125 else
127 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
128 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
129 if (arg_extent != ret_extent)
130 runtime_error ("Array bound mismatch for dimension 1 of "
131 "array (%ld/%ld) ",
132 (long int) ret_extent, (long int) arg_extent);
134 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
135 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
136 if (arg_extent != ret_extent)
137 runtime_error ("Array bound mismatch for dimension 2 of "
138 "array (%ld/%ld) ",
139 (long int) ret_extent, (long int) arg_extent);
144 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
146 /* One-dimensional result may be addressed in the code below
147 either as a row or a column matrix. We want both cases to
148 work. */
149 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
151 else
153 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
154 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
158 if (GFC_DESCRIPTOR_RANK (a) == 1)
160 /* Treat it as a a row matrix A[1,count]. */
161 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
162 aystride = 1;
164 xcount = 1;
165 count = GFC_DESCRIPTOR_EXTENT(a,0);
167 else
169 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
170 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
172 count = GFC_DESCRIPTOR_EXTENT(a,1);
173 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
176 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
178 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
179 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
180 "in dimension 1: is %ld, should be %ld",
181 (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
184 if (GFC_DESCRIPTOR_RANK (b) == 1)
186 /* Treat it as a column matrix B[count,1] */
187 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
189 /* bystride should never be used for 1-dimensional b.
190 The value is only used for calculation of the
191 memory by the buffer. */
192 bystride = 256;
193 ycount = 1;
195 else
197 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
198 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
199 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
202 abase = a->base_addr;
203 bbase = b->base_addr;
204 dest = retarray->base_addr;
206 /* Now that everything is set up, we perform the multiplication
207 itself. */
209 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
210 #define min(a,b) ((a) <= (b) ? (a) : (b))
211 #define max(a,b) ((a) >= (b) ? (a) : (b))
213 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
214 && (bxstride == 1 || bystride == 1)
215 && (((float) xcount) * ((float) ycount) * ((float) count)
216 > POW3(blas_limit)))
218 const int m = xcount, n = ycount, k = count, ldc = rystride;
219 const GFC_COMPLEX_16 one = 1, zero = 0;
220 const int lda = (axstride == 1) ? aystride : axstride,
221 ldb = (bxstride == 1) ? bystride : bxstride;
223 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
225 assert (gemm != NULL);
226 const char *transa, *transb;
227 if (try_blas & 2)
228 transa = "C";
229 else
230 transa = axstride == 1 ? "N" : "T";
232 if (try_blas & 4)
233 transb = "C";
234 else
235 transb = bxstride == 1 ? "N" : "T";
237 gemm (transa, transb , &m,
238 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
239 &ldc, 1, 1);
240 return;
244 if (rxstride == 1 && axstride == 1 && bxstride == 1
245 && GFC_DESCRIPTOR_RANK (b) != 1)
247 /* This block of code implements a tuned matmul, derived from
248 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
250 Bo Kagstrom and Per Ling
251 Department of Computing Science
252 Umea University
253 S-901 87 Umea, Sweden
255 from netlib.org, translated to C, and modified for matmul.m4. */
257 const GFC_COMPLEX_16 *a, *b;
258 GFC_COMPLEX_16 *c;
259 const index_type m = xcount, n = ycount, k = count;
261 /* System generated locals */
262 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
263 i1, i2, i3, i4, i5, i6;
265 /* Local variables */
266 GFC_COMPLEX_16 f11, f12, f21, f22, f31, f32, f41, f42,
267 f13, f14, f23, f24, f33, f34, f43, f44;
268 index_type i, j, l, ii, jj, ll;
269 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
270 GFC_COMPLEX_16 *t1;
272 a = abase;
273 b = bbase;
274 c = retarray->base_addr;
276 /* Parameter adjustments */
277 c_dim1 = rystride;
278 c_offset = 1 + c_dim1;
279 c -= c_offset;
280 a_dim1 = aystride;
281 a_offset = 1 + a_dim1;
282 a -= a_offset;
283 b_dim1 = bystride;
284 b_offset = 1 + b_dim1;
285 b -= b_offset;
287 /* Empty c first. */
288 for (j=1; j<=n; j++)
289 for (i=1; i<=m; i++)
290 c[i + j * c_dim1] = (GFC_COMPLEX_16)0;
292 /* Early exit if possible */
293 if (m == 0 || n == 0 || k == 0)
294 return;
296 /* Adjust size of t1 to what is needed. */
297 index_type t1_dim, a_sz;
298 if (aystride == 1)
299 a_sz = rystride;
300 else
301 a_sz = a_dim1;
303 t1_dim = a_sz * 256 + b_dim1;
304 if (t1_dim > 65536)
305 t1_dim = 65536;
307 t1 = malloc (t1_dim * sizeof(GFC_COMPLEX_16));
309 /* Start turning the crank. */
310 i1 = n;
311 for (jj = 1; jj <= i1; jj += 512)
313 /* Computing MIN */
314 i2 = 512;
315 i3 = n - jj + 1;
316 jsec = min(i2,i3);
317 ujsec = jsec - jsec % 4;
318 i2 = k;
319 for (ll = 1; ll <= i2; ll += 256)
321 /* Computing MIN */
322 i3 = 256;
323 i4 = k - ll + 1;
324 lsec = min(i3,i4);
325 ulsec = lsec - lsec % 2;
327 i3 = m;
328 for (ii = 1; ii <= i3; ii += 256)
330 /* Computing MIN */
331 i4 = 256;
332 i5 = m - ii + 1;
333 isec = min(i4,i5);
334 uisec = isec - isec % 2;
335 i4 = ll + ulsec - 1;
336 for (l = ll; l <= i4; l += 2)
338 i5 = ii + uisec - 1;
339 for (i = ii; i <= i5; i += 2)
341 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
342 a[i + l * a_dim1];
343 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
344 a[i + (l + 1) * a_dim1];
345 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
346 a[i + 1 + l * a_dim1];
347 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
348 a[i + 1 + (l + 1) * a_dim1];
350 if (uisec < isec)
352 t1[l - ll + 1 + (isec << 8) - 257] =
353 a[ii + isec - 1 + l * a_dim1];
354 t1[l - ll + 2 + (isec << 8) - 257] =
355 a[ii + isec - 1 + (l + 1) * a_dim1];
358 if (ulsec < lsec)
360 i4 = ii + isec - 1;
361 for (i = ii; i<= i4; ++i)
363 t1[lsec + ((i - ii + 1) << 8) - 257] =
364 a[i + (ll + lsec - 1) * a_dim1];
368 uisec = isec - isec % 4;
369 i4 = jj + ujsec - 1;
370 for (j = jj; j <= i4; j += 4)
372 i5 = ii + uisec - 1;
373 for (i = ii; i <= i5; i += 4)
375 f11 = c[i + j * c_dim1];
376 f21 = c[i + 1 + j * c_dim1];
377 f12 = c[i + (j + 1) * c_dim1];
378 f22 = c[i + 1 + (j + 1) * c_dim1];
379 f13 = c[i + (j + 2) * c_dim1];
380 f23 = c[i + 1 + (j + 2) * c_dim1];
381 f14 = c[i + (j + 3) * c_dim1];
382 f24 = c[i + 1 + (j + 3) * c_dim1];
383 f31 = c[i + 2 + j * c_dim1];
384 f41 = c[i + 3 + j * c_dim1];
385 f32 = c[i + 2 + (j + 1) * c_dim1];
386 f42 = c[i + 3 + (j + 1) * c_dim1];
387 f33 = c[i + 2 + (j + 2) * c_dim1];
388 f43 = c[i + 3 + (j + 2) * c_dim1];
389 f34 = c[i + 2 + (j + 3) * c_dim1];
390 f44 = c[i + 3 + (j + 3) * c_dim1];
391 i6 = ll + lsec - 1;
392 for (l = ll; l <= i6; ++l)
394 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
395 * b[l + j * b_dim1];
396 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
397 * b[l + j * b_dim1];
398 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
399 * b[l + (j + 1) * b_dim1];
400 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
401 * b[l + (j + 1) * b_dim1];
402 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
403 * b[l + (j + 2) * b_dim1];
404 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
405 * b[l + (j + 2) * b_dim1];
406 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
407 * b[l + (j + 3) * b_dim1];
408 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
409 * b[l + (j + 3) * b_dim1];
410 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
411 * b[l + j * b_dim1];
412 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
413 * b[l + j * b_dim1];
414 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
415 * b[l + (j + 1) * b_dim1];
416 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
417 * b[l + (j + 1) * b_dim1];
418 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
419 * b[l + (j + 2) * b_dim1];
420 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
421 * b[l + (j + 2) * b_dim1];
422 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
423 * b[l + (j + 3) * b_dim1];
424 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
425 * b[l + (j + 3) * b_dim1];
427 c[i + j * c_dim1] = f11;
428 c[i + 1 + j * c_dim1] = f21;
429 c[i + (j + 1) * c_dim1] = f12;
430 c[i + 1 + (j + 1) * c_dim1] = f22;
431 c[i + (j + 2) * c_dim1] = f13;
432 c[i + 1 + (j + 2) * c_dim1] = f23;
433 c[i + (j + 3) * c_dim1] = f14;
434 c[i + 1 + (j + 3) * c_dim1] = f24;
435 c[i + 2 + j * c_dim1] = f31;
436 c[i + 3 + j * c_dim1] = f41;
437 c[i + 2 + (j + 1) * c_dim1] = f32;
438 c[i + 3 + (j + 1) * c_dim1] = f42;
439 c[i + 2 + (j + 2) * c_dim1] = f33;
440 c[i + 3 + (j + 2) * c_dim1] = f43;
441 c[i + 2 + (j + 3) * c_dim1] = f34;
442 c[i + 3 + (j + 3) * c_dim1] = f44;
444 if (uisec < isec)
446 i5 = ii + isec - 1;
447 for (i = ii + uisec; i <= i5; ++i)
449 f11 = c[i + j * c_dim1];
450 f12 = c[i + (j + 1) * c_dim1];
451 f13 = c[i + (j + 2) * c_dim1];
452 f14 = c[i + (j + 3) * c_dim1];
453 i6 = ll + lsec - 1;
454 for (l = ll; l <= i6; ++l)
456 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
457 257] * b[l + j * b_dim1];
458 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
459 257] * b[l + (j + 1) * b_dim1];
460 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
461 257] * b[l + (j + 2) * b_dim1];
462 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
463 257] * b[l + (j + 3) * b_dim1];
465 c[i + j * c_dim1] = f11;
466 c[i + (j + 1) * c_dim1] = f12;
467 c[i + (j + 2) * c_dim1] = f13;
468 c[i + (j + 3) * c_dim1] = f14;
472 if (ujsec < jsec)
474 i4 = jj + jsec - 1;
475 for (j = jj + ujsec; j <= i4; ++j)
477 i5 = ii + uisec - 1;
478 for (i = ii; i <= i5; i += 4)
480 f11 = c[i + j * c_dim1];
481 f21 = c[i + 1 + j * c_dim1];
482 f31 = c[i + 2 + j * c_dim1];
483 f41 = c[i + 3 + j * c_dim1];
484 i6 = ll + lsec - 1;
485 for (l = ll; l <= i6; ++l)
487 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
488 257] * b[l + j * b_dim1];
489 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
490 257] * b[l + j * b_dim1];
491 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
492 257] * b[l + j * b_dim1];
493 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
494 257] * b[l + j * b_dim1];
496 c[i + j * c_dim1] = f11;
497 c[i + 1 + j * c_dim1] = f21;
498 c[i + 2 + j * c_dim1] = f31;
499 c[i + 3 + j * c_dim1] = f41;
501 i5 = ii + isec - 1;
502 for (i = ii + uisec; i <= i5; ++i)
504 f11 = c[i + j * c_dim1];
505 i6 = ll + lsec - 1;
506 for (l = ll; l <= i6; ++l)
508 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
509 257] * b[l + j * b_dim1];
511 c[i + j * c_dim1] = f11;
518 free(t1);
519 return;
521 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
523 if (GFC_DESCRIPTOR_RANK (a) != 1)
525 const GFC_COMPLEX_16 *restrict abase_x;
526 const GFC_COMPLEX_16 *restrict bbase_y;
527 GFC_COMPLEX_16 *restrict dest_y;
528 GFC_COMPLEX_16 s;
530 for (y = 0; y < ycount; y++)
532 bbase_y = &bbase[y*bystride];
533 dest_y = &dest[y*rystride];
534 for (x = 0; x < xcount; x++)
536 abase_x = &abase[x*axstride];
537 s = (GFC_COMPLEX_16) 0;
538 for (n = 0; n < count; n++)
539 s += abase_x[n] * bbase_y[n];
540 dest_y[x] = s;
544 else
546 const GFC_COMPLEX_16 *restrict bbase_y;
547 GFC_COMPLEX_16 s;
549 for (y = 0; y < ycount; y++)
551 bbase_y = &bbase[y*bystride];
552 s = (GFC_COMPLEX_16) 0;
553 for (n = 0; n < count; n++)
554 s += abase[n*axstride] * bbase_y[n];
555 dest[y*rystride] = s;
559 else if (GFC_DESCRIPTOR_RANK (a) == 1)
561 const GFC_COMPLEX_16 *restrict bbase_y;
562 GFC_COMPLEX_16 s;
564 for (y = 0; y < ycount; y++)
566 bbase_y = &bbase[y*bystride];
567 s = (GFC_COMPLEX_16) 0;
568 for (n = 0; n < count; n++)
569 s += abase[n*axstride] * bbase_y[n*bxstride];
570 dest[y*rxstride] = s;
573 else if (axstride < aystride)
575 for (y = 0; y < ycount; y++)
576 for (x = 0; x < xcount; x++)
577 dest[x*rxstride + y*rystride] = (GFC_COMPLEX_16)0;
579 for (y = 0; y < ycount; y++)
580 for (n = 0; n < count; n++)
581 for (x = 0; x < xcount; x++)
582 /* dest[x,y] += a[x,n] * b[n,y] */
583 dest[x*rxstride + y*rystride] +=
584 abase[x*axstride + n*aystride] *
585 bbase[n*bxstride + y*bystride];
587 else
589 const GFC_COMPLEX_16 *restrict abase_x;
590 const GFC_COMPLEX_16 *restrict bbase_y;
591 GFC_COMPLEX_16 *restrict dest_y;
592 GFC_COMPLEX_16 s;
594 for (y = 0; y < ycount; y++)
596 bbase_y = &bbase[y*bystride];
597 dest_y = &dest[y*rystride];
598 for (x = 0; x < xcount; x++)
600 abase_x = &abase[x*axstride];
601 s = (GFC_COMPLEX_16) 0;
602 for (n = 0; n < count; n++)
603 s += abase_x[n*aystride] * bbase_y[n*bxstride];
604 dest_y[x*rxstride] = s;
609 #undef POW3
610 #undef min
611 #undef max
613 #endif
615 #if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
616 void
617 matmul_c16_avx128_fma4 (gfc_array_c16 * const restrict retarray,
618 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
619 int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma4")));
620 internal_proto(matmul_c16_avx128_fma4);
621 void
622 matmul_c16_avx128_fma4 (gfc_array_c16 * const restrict retarray,
623 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
624 int blas_limit, blas_call gemm)
626 const GFC_COMPLEX_16 * restrict abase;
627 const GFC_COMPLEX_16 * restrict bbase;
628 GFC_COMPLEX_16 * restrict dest;
630 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
631 index_type x, y, n, count, xcount, ycount;
633 assert (GFC_DESCRIPTOR_RANK (a) == 2
634 || GFC_DESCRIPTOR_RANK (b) == 2);
636 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
638 Either A or B (but not both) can be rank 1:
640 o One-dimensional argument A is implicitly treated as a row matrix
641 dimensioned [1,count], so xcount=1.
643 o One-dimensional argument B is implicitly treated as a column matrix
644 dimensioned [count, 1], so ycount=1.
647 if (retarray->base_addr == NULL)
649 if (GFC_DESCRIPTOR_RANK (a) == 1)
651 GFC_DIMENSION_SET(retarray->dim[0], 0,
652 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
654 else if (GFC_DESCRIPTOR_RANK (b) == 1)
656 GFC_DIMENSION_SET(retarray->dim[0], 0,
657 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
659 else
661 GFC_DIMENSION_SET(retarray->dim[0], 0,
662 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
664 GFC_DIMENSION_SET(retarray->dim[1], 0,
665 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
666 GFC_DESCRIPTOR_EXTENT(retarray,0));
669 retarray->base_addr
670 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_16));
671 retarray->offset = 0;
673 else if (unlikely (compile_options.bounds_check))
675 index_type ret_extent, arg_extent;
677 if (GFC_DESCRIPTOR_RANK (a) == 1)
679 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
680 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
681 if (arg_extent != ret_extent)
682 runtime_error ("Array bound mismatch for dimension 1 of "
683 "array (%ld/%ld) ",
684 (long int) ret_extent, (long int) arg_extent);
686 else if (GFC_DESCRIPTOR_RANK (b) == 1)
688 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
689 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
690 if (arg_extent != ret_extent)
691 runtime_error ("Array bound mismatch for dimension 1 of "
692 "array (%ld/%ld) ",
693 (long int) ret_extent, (long int) arg_extent);
695 else
697 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
698 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
699 if (arg_extent != ret_extent)
700 runtime_error ("Array bound mismatch for dimension 1 of "
701 "array (%ld/%ld) ",
702 (long int) ret_extent, (long int) arg_extent);
704 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
705 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
706 if (arg_extent != ret_extent)
707 runtime_error ("Array bound mismatch for dimension 2 of "
708 "array (%ld/%ld) ",
709 (long int) ret_extent, (long int) arg_extent);
714 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
716 /* One-dimensional result may be addressed in the code below
717 either as a row or a column matrix. We want both cases to
718 work. */
719 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
721 else
723 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
724 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
728 if (GFC_DESCRIPTOR_RANK (a) == 1)
730 /* Treat it as a a row matrix A[1,count]. */
731 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
732 aystride = 1;
734 xcount = 1;
735 count = GFC_DESCRIPTOR_EXTENT(a,0);
737 else
739 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
740 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
742 count = GFC_DESCRIPTOR_EXTENT(a,1);
743 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
746 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
748 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
749 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
750 "in dimension 1: is %ld, should be %ld",
751 (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
754 if (GFC_DESCRIPTOR_RANK (b) == 1)
756 /* Treat it as a column matrix B[count,1] */
757 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
759 /* bystride should never be used for 1-dimensional b.
760 The value is only used for calculation of the
761 memory by the buffer. */
762 bystride = 256;
763 ycount = 1;
765 else
767 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
768 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
769 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
772 abase = a->base_addr;
773 bbase = b->base_addr;
774 dest = retarray->base_addr;
776 /* Now that everything is set up, we perform the multiplication
777 itself. */
779 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
780 #define min(a,b) ((a) <= (b) ? (a) : (b))
781 #define max(a,b) ((a) >= (b) ? (a) : (b))
783 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
784 && (bxstride == 1 || bystride == 1)
785 && (((float) xcount) * ((float) ycount) * ((float) count)
786 > POW3(blas_limit)))
788 const int m = xcount, n = ycount, k = count, ldc = rystride;
789 const GFC_COMPLEX_16 one = 1, zero = 0;
790 const int lda = (axstride == 1) ? aystride : axstride,
791 ldb = (bxstride == 1) ? bystride : bxstride;
793 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
795 assert (gemm != NULL);
796 const char *transa, *transb;
797 if (try_blas & 2)
798 transa = "C";
799 else
800 transa = axstride == 1 ? "N" : "T";
802 if (try_blas & 4)
803 transb = "C";
804 else
805 transb = bxstride == 1 ? "N" : "T";
807 gemm (transa, transb , &m,
808 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
809 &ldc, 1, 1);
810 return;
814 if (rxstride == 1 && axstride == 1 && bxstride == 1
815 && GFC_DESCRIPTOR_RANK (b) != 1)
817 /* This block of code implements a tuned matmul, derived from
818 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
820 Bo Kagstrom and Per Ling
821 Department of Computing Science
822 Umea University
823 S-901 87 Umea, Sweden
825 from netlib.org, translated to C, and modified for matmul.m4. */
827 const GFC_COMPLEX_16 *a, *b;
828 GFC_COMPLEX_16 *c;
829 const index_type m = xcount, n = ycount, k = count;
831 /* System generated locals */
832 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
833 i1, i2, i3, i4, i5, i6;
835 /* Local variables */
836 GFC_COMPLEX_16 f11, f12, f21, f22, f31, f32, f41, f42,
837 f13, f14, f23, f24, f33, f34, f43, f44;
838 index_type i, j, l, ii, jj, ll;
839 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
840 GFC_COMPLEX_16 *t1;
842 a = abase;
843 b = bbase;
844 c = retarray->base_addr;
846 /* Parameter adjustments */
847 c_dim1 = rystride;
848 c_offset = 1 + c_dim1;
849 c -= c_offset;
850 a_dim1 = aystride;
851 a_offset = 1 + a_dim1;
852 a -= a_offset;
853 b_dim1 = bystride;
854 b_offset = 1 + b_dim1;
855 b -= b_offset;
857 /* Empty c first. */
858 for (j=1; j<=n; j++)
859 for (i=1; i<=m; i++)
860 c[i + j * c_dim1] = (GFC_COMPLEX_16)0;
862 /* Early exit if possible */
863 if (m == 0 || n == 0 || k == 0)
864 return;
866 /* Adjust size of t1 to what is needed. */
867 index_type t1_dim, a_sz;
868 if (aystride == 1)
869 a_sz = rystride;
870 else
871 a_sz = a_dim1;
873 t1_dim = a_sz * 256 + b_dim1;
874 if (t1_dim > 65536)
875 t1_dim = 65536;
877 t1 = malloc (t1_dim * sizeof(GFC_COMPLEX_16));
879 /* Start turning the crank. */
880 i1 = n;
881 for (jj = 1; jj <= i1; jj += 512)
883 /* Computing MIN */
884 i2 = 512;
885 i3 = n - jj + 1;
886 jsec = min(i2,i3);
887 ujsec = jsec - jsec % 4;
888 i2 = k;
889 for (ll = 1; ll <= i2; ll += 256)
891 /* Computing MIN */
892 i3 = 256;
893 i4 = k - ll + 1;
894 lsec = min(i3,i4);
895 ulsec = lsec - lsec % 2;
897 i3 = m;
898 for (ii = 1; ii <= i3; ii += 256)
900 /* Computing MIN */
901 i4 = 256;
902 i5 = m - ii + 1;
903 isec = min(i4,i5);
904 uisec = isec - isec % 2;
905 i4 = ll + ulsec - 1;
906 for (l = ll; l <= i4; l += 2)
908 i5 = ii + uisec - 1;
909 for (i = ii; i <= i5; i += 2)
911 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
912 a[i + l * a_dim1];
913 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
914 a[i + (l + 1) * a_dim1];
915 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
916 a[i + 1 + l * a_dim1];
917 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
918 a[i + 1 + (l + 1) * a_dim1];
920 if (uisec < isec)
922 t1[l - ll + 1 + (isec << 8) - 257] =
923 a[ii + isec - 1 + l * a_dim1];
924 t1[l - ll + 2 + (isec << 8) - 257] =
925 a[ii + isec - 1 + (l + 1) * a_dim1];
928 if (ulsec < lsec)
930 i4 = ii + isec - 1;
931 for (i = ii; i<= i4; ++i)
933 t1[lsec + ((i - ii + 1) << 8) - 257] =
934 a[i + (ll + lsec - 1) * a_dim1];
938 uisec = isec - isec % 4;
939 i4 = jj + ujsec - 1;
940 for (j = jj; j <= i4; j += 4)
942 i5 = ii + uisec - 1;
943 for (i = ii; i <= i5; i += 4)
945 f11 = c[i + j * c_dim1];
946 f21 = c[i + 1 + j * c_dim1];
947 f12 = c[i + (j + 1) * c_dim1];
948 f22 = c[i + 1 + (j + 1) * c_dim1];
949 f13 = c[i + (j + 2) * c_dim1];
950 f23 = c[i + 1 + (j + 2) * c_dim1];
951 f14 = c[i + (j + 3) * c_dim1];
952 f24 = c[i + 1 + (j + 3) * c_dim1];
953 f31 = c[i + 2 + j * c_dim1];
954 f41 = c[i + 3 + j * c_dim1];
955 f32 = c[i + 2 + (j + 1) * c_dim1];
956 f42 = c[i + 3 + (j + 1) * c_dim1];
957 f33 = c[i + 2 + (j + 2) * c_dim1];
958 f43 = c[i + 3 + (j + 2) * c_dim1];
959 f34 = c[i + 2 + (j + 3) * c_dim1];
960 f44 = c[i + 3 + (j + 3) * c_dim1];
961 i6 = ll + lsec - 1;
962 for (l = ll; l <= i6; ++l)
964 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
965 * b[l + j * b_dim1];
966 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
967 * b[l + j * b_dim1];
968 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
969 * b[l + (j + 1) * b_dim1];
970 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
971 * b[l + (j + 1) * b_dim1];
972 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
973 * b[l + (j + 2) * b_dim1];
974 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
975 * b[l + (j + 2) * b_dim1];
976 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
977 * b[l + (j + 3) * b_dim1];
978 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
979 * b[l + (j + 3) * b_dim1];
980 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
981 * b[l + j * b_dim1];
982 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
983 * b[l + j * b_dim1];
984 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
985 * b[l + (j + 1) * b_dim1];
986 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
987 * b[l + (j + 1) * b_dim1];
988 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
989 * b[l + (j + 2) * b_dim1];
990 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
991 * b[l + (j + 2) * b_dim1];
992 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
993 * b[l + (j + 3) * b_dim1];
994 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
995 * b[l + (j + 3) * b_dim1];
997 c[i + j * c_dim1] = f11;
998 c[i + 1 + j * c_dim1] = f21;
999 c[i + (j + 1) * c_dim1] = f12;
1000 c[i + 1 + (j + 1) * c_dim1] = f22;
1001 c[i + (j + 2) * c_dim1] = f13;
1002 c[i + 1 + (j + 2) * c_dim1] = f23;
1003 c[i + (j + 3) * c_dim1] = f14;
1004 c[i + 1 + (j + 3) * c_dim1] = f24;
1005 c[i + 2 + j * c_dim1] = f31;
1006 c[i + 3 + j * c_dim1] = f41;
1007 c[i + 2 + (j + 1) * c_dim1] = f32;
1008 c[i + 3 + (j + 1) * c_dim1] = f42;
1009 c[i + 2 + (j + 2) * c_dim1] = f33;
1010 c[i + 3 + (j + 2) * c_dim1] = f43;
1011 c[i + 2 + (j + 3) * c_dim1] = f34;
1012 c[i + 3 + (j + 3) * c_dim1] = f44;
1014 if (uisec < isec)
1016 i5 = ii + isec - 1;
1017 for (i = ii + uisec; i <= i5; ++i)
1019 f11 = c[i + j * c_dim1];
1020 f12 = c[i + (j + 1) * c_dim1];
1021 f13 = c[i + (j + 2) * c_dim1];
1022 f14 = c[i + (j + 3) * c_dim1];
1023 i6 = ll + lsec - 1;
1024 for (l = ll; l <= i6; ++l)
1026 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1027 257] * b[l + j * b_dim1];
1028 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1029 257] * b[l + (j + 1) * b_dim1];
1030 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1031 257] * b[l + (j + 2) * b_dim1];
1032 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1033 257] * b[l + (j + 3) * b_dim1];
1035 c[i + j * c_dim1] = f11;
1036 c[i + (j + 1) * c_dim1] = f12;
1037 c[i + (j + 2) * c_dim1] = f13;
1038 c[i + (j + 3) * c_dim1] = f14;
1042 if (ujsec < jsec)
1044 i4 = jj + jsec - 1;
1045 for (j = jj + ujsec; j <= i4; ++j)
1047 i5 = ii + uisec - 1;
1048 for (i = ii; i <= i5; i += 4)
1050 f11 = c[i + j * c_dim1];
1051 f21 = c[i + 1 + j * c_dim1];
1052 f31 = c[i + 2 + j * c_dim1];
1053 f41 = c[i + 3 + j * c_dim1];
1054 i6 = ll + lsec - 1;
1055 for (l = ll; l <= i6; ++l)
1057 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1058 257] * b[l + j * b_dim1];
1059 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
1060 257] * b[l + j * b_dim1];
1061 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
1062 257] * b[l + j * b_dim1];
1063 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
1064 257] * b[l + j * b_dim1];
1066 c[i + j * c_dim1] = f11;
1067 c[i + 1 + j * c_dim1] = f21;
1068 c[i + 2 + j * c_dim1] = f31;
1069 c[i + 3 + j * c_dim1] = f41;
1071 i5 = ii + isec - 1;
1072 for (i = ii + uisec; i <= i5; ++i)
1074 f11 = c[i + j * c_dim1];
1075 i6 = ll + lsec - 1;
1076 for (l = ll; l <= i6; ++l)
1078 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1079 257] * b[l + j * b_dim1];
1081 c[i + j * c_dim1] = f11;
1088 free(t1);
1089 return;
1091 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
1093 if (GFC_DESCRIPTOR_RANK (a) != 1)
1095 const GFC_COMPLEX_16 *restrict abase_x;
1096 const GFC_COMPLEX_16 *restrict bbase_y;
1097 GFC_COMPLEX_16 *restrict dest_y;
1098 GFC_COMPLEX_16 s;
1100 for (y = 0; y < ycount; y++)
1102 bbase_y = &bbase[y*bystride];
1103 dest_y = &dest[y*rystride];
1104 for (x = 0; x < xcount; x++)
1106 abase_x = &abase[x*axstride];
1107 s = (GFC_COMPLEX_16) 0;
1108 for (n = 0; n < count; n++)
1109 s += abase_x[n] * bbase_y[n];
1110 dest_y[x] = s;
1114 else
1116 const GFC_COMPLEX_16 *restrict bbase_y;
1117 GFC_COMPLEX_16 s;
1119 for (y = 0; y < ycount; y++)
1121 bbase_y = &bbase[y*bystride];
1122 s = (GFC_COMPLEX_16) 0;
1123 for (n = 0; n < count; n++)
1124 s += abase[n*axstride] * bbase_y[n];
1125 dest[y*rystride] = s;
1129 else if (GFC_DESCRIPTOR_RANK (a) == 1)
1131 const GFC_COMPLEX_16 *restrict bbase_y;
1132 GFC_COMPLEX_16 s;
1134 for (y = 0; y < ycount; y++)
1136 bbase_y = &bbase[y*bystride];
1137 s = (GFC_COMPLEX_16) 0;
1138 for (n = 0; n < count; n++)
1139 s += abase[n*axstride] * bbase_y[n*bxstride];
1140 dest[y*rxstride] = s;
1143 else if (axstride < aystride)
1145 for (y = 0; y < ycount; y++)
1146 for (x = 0; x < xcount; x++)
1147 dest[x*rxstride + y*rystride] = (GFC_COMPLEX_16)0;
1149 for (y = 0; y < ycount; y++)
1150 for (n = 0; n < count; n++)
1151 for (x = 0; x < xcount; x++)
1152 /* dest[x,y] += a[x,n] * b[n,y] */
1153 dest[x*rxstride + y*rystride] +=
1154 abase[x*axstride + n*aystride] *
1155 bbase[n*bxstride + y*bystride];
1157 else
1159 const GFC_COMPLEX_16 *restrict abase_x;
1160 const GFC_COMPLEX_16 *restrict bbase_y;
1161 GFC_COMPLEX_16 *restrict dest_y;
1162 GFC_COMPLEX_16 s;
1164 for (y = 0; y < ycount; y++)
1166 bbase_y = &bbase[y*bystride];
1167 dest_y = &dest[y*rystride];
1168 for (x = 0; x < xcount; x++)
1170 abase_x = &abase[x*axstride];
1171 s = (GFC_COMPLEX_16) 0;
1172 for (n = 0; n < count; n++)
1173 s += abase_x[n*aystride] * bbase_y[n*bxstride];
1174 dest_y[x*rxstride] = s;
1179 #undef POW3
1180 #undef min
1181 #undef max
1183 #endif
1185 #endif