Daily bump.
[official-gcc.git] / libgfortran / generated / matmul_c16.c
blobd344a7a51344227a3a7727a0419b0715f6770acb
1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2023 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
27 #include <string.h>
28 #include <assert.h>
31 #if defined (HAVE_GFC_COMPLEX_16)
33 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
34 passed to us by the front-end, in which case we call it for large
35 matrices. */
/* Argument order, as used by the gemm calls in this file: transa, transb,
   pointers to the int extents m, n and k, a pointer to the scale factor
   applied to the product (callers pass one), the A data, its leading
   dimension, the B data, its leading dimension, a pointer to the scale
   factor for the destination (callers pass zero), the destination data and
   its leading dimension, followed by two plain int values.  The callers
   pass 1, 1 for those last two -- presumably the hidden lengths of the two
   character arguments; confirm against the front end.  */
37 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
38 const int *, const GFC_COMPLEX_16 *, const GFC_COMPLEX_16 *,
39 const int *, const GFC_COMPLEX_16 *, const int *,
40 const GFC_COMPLEX_16 *, GFC_COMPLEX_16 *, const int *,
41 int, int);
43 /* The order of loops is different in the case of plain matrix
44 multiplication C=MATMUL(A,B), and in the frequent special case where
45 the argument A is the temporary result of a TRANSPOSE intrinsic:
46 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
47 looking at their strides.
49 The equivalent Fortran pseudo-code is:
51 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
52 IF (.NOT.IS_TRANSPOSED(A)) THEN
53 C = 0
54 DO J=1,N
55 DO K=1,COUNT
56 DO I=1,M
57 C(I,J) = C(I,J)+A(I,K)*B(K,J)
58 ELSE
59 DO J=1,N
60 DO I=1,M
61 S = 0
62 DO K=1,COUNT
63 S = S+A(I,K)*B(K,J)
64 C(I,J) = S
65 ENDIF
68 /* If try_blas is set to a nonzero value, then the matmul function will
69 see if there is a way to perform the matrix multiplication by a call
70 to the BLAS gemm function. */
/* Declaration of the externally visible matmul_c16 routine; export_proto is
   applied to the symbol.  Its definition lies outside the lines shown here.
   The parameter list is the same one the target-specific static variants
   below take: result descriptor, the two operand descriptors, the try_blas
   flags, the blas_limit threshold and the gemm function pointer.  */
72 extern void matmul_c16 (gfc_array_c16 * const restrict retarray,
73 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
74 int blas_limit, blas_call gemm);
75 export_proto(matmul_c16);
77 /* Put exhaustive list of possible architectures here, ORed together. */
79 #if defined(HAVE_AVX) || defined(HAVE_AVX2) || defined(HAVE_AVX512F)
81 #ifdef HAVE_AVX
/* matmul_c16_avx: file-local MATMUL kernel for GFC_COMPLEX_16 arrays.  The
   forward declaration attaches __attribute__((__target__("avx"))); the
   definition follows immediately.

   retarray    result descriptor.  When its base_addr is NULL the bounds are
               filled in from A and B and storage comes from xmallocarray;
               otherwise the extents are only verified when
               compile_options.bounds_check is set.
   a, b        operand descriptors; at least one must have rank 2 (asserted).
   try_blas    nonzero enables the gemm path; mask value 2 selects "C" for
               transa and mask value 4 selects "C" for transb.
   blas_limit  the gemm path is considered only when xcount*ycount*count
               (computed in float) exceeds blas_limit cubed.
   gemm        BLAS routine; asserted non-NULL right before it is called.  */
82 static void
83 matmul_c16_avx (gfc_array_c16 * const restrict retarray,
84 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
85 int blas_limit, blas_call gemm) __attribute__((__target__("avx")));
86 static void
87 matmul_c16_avx (gfc_array_c16 * const restrict retarray,
88 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
89 int blas_limit, blas_call gemm)
91 const GFC_COMPLEX_16 * restrict abase;
92 const GFC_COMPLEX_16 * restrict bbase;
93 GFC_COMPLEX_16 * restrict dest;
95 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
96 index_type x, y, n, count, xcount, ycount;
98 assert (GFC_DESCRIPTOR_RANK (a) == 2
99 || GFC_DESCRIPTOR_RANK (b) == 2);
101 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
103 Either A or B (but not both) can be rank 1:
105 o One-dimensional argument A is implicitly treated as a row matrix
106 dimensioned [1,count], so xcount=1.
108 o One-dimensional argument B is implicitly treated as a column matrix
109 dimensioned [count, 1], so ycount=1.
112 if (retarray->base_addr == NULL)
114 if (GFC_DESCRIPTOR_RANK (a) == 1)
116 GFC_DIMENSION_SET(retarray->dim[0], 0,
117 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
119 else if (GFC_DESCRIPTOR_RANK (b) == 1)
121 GFC_DIMENSION_SET(retarray->dim[0], 0,
122 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
124 else
126 GFC_DIMENSION_SET(retarray->dim[0], 0,
127 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
129 GFC_DIMENSION_SET(retarray->dim[1], 0,
130 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
131 GFC_DESCRIPTOR_EXTENT(retarray,0));
134 retarray->base_addr
135 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_16));
136 retarray->offset = 0;
138 else if (unlikely (compile_options.bounds_check))
140 index_type ret_extent, arg_extent;
142 if (GFC_DESCRIPTOR_RANK (a) == 1)
144 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
145 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
146 if (arg_extent != ret_extent)
147 runtime_error ("Array bound mismatch for dimension 1 of "
148 "array (%ld/%ld) ",
149 (long int) ret_extent, (long int) arg_extent);
151 else if (GFC_DESCRIPTOR_RANK (b) == 1)
153 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
154 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
155 if (arg_extent != ret_extent)
156 runtime_error ("Array bound mismatch for dimension 1 of "
157 "array (%ld/%ld) ",
158 (long int) ret_extent, (long int) arg_extent);
160 else
162 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
163 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
164 if (arg_extent != ret_extent)
165 runtime_error ("Array bound mismatch for dimension 1 of "
166 "array (%ld/%ld) ",
167 (long int) ret_extent, (long int) arg_extent);
169 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
170 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
171 if (arg_extent != ret_extent)
172 runtime_error ("Array bound mismatch for dimension 2 of "
173 "array (%ld/%ld) ",
174 (long int) ret_extent, (long int) arg_extent);
179 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
181 /* One-dimensional result may be addressed in the code below
182 either as a row or a column matrix. We want both cases to
183 work. */
184 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
186 else
188 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
189 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
/* Strides of A plus the row count (xcount) and the summation length
   (count).  */
193 if (GFC_DESCRIPTOR_RANK (a) == 1)
195 /* Treat it as a row matrix A[1,count]. */
196 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
197 aystride = 1;
199 xcount = 1;
200 count = GFC_DESCRIPTOR_EXTENT(a,0);
202 else
204 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
205 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
207 count = GFC_DESCRIPTOR_EXTENT(a,1);
208 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
/* The first extent of B has to equal count.  A mismatch is reported unless
   neither of the two values is positive.  */
211 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
213 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
214 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
215 "in dimension 1: is %ld, should be %ld",
216 (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
219 if (GFC_DESCRIPTOR_RANK (b) == 1)
221 /* Treat it as a column matrix B[count,1] */
222 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
224 /* bystride should never be used for 1-dimensional b.
225 The value is only used for calculation of the
226 memory by the buffer. */
227 bystride = 256;
228 ycount = 1;
230 else
232 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
233 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
234 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
237 abase = a->base_addr;
238 bbase = b->base_addr;
239 dest = retarray->base_addr;
241 /* Now that everything is set up, we perform the multiplication
242 itself. */
/* Helper macros local to this function (undefined again after it).  POW3
   cubes its argument in float arithmetic.  */
244 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
245 #define min(a,b) ((a) <= (b) ? (a) : (b))
246 #define max(a,b) ((a) >= (b) ? (a) : (b))
/* gemm path: needs a unit first stride in the result, a unit stride in one
   dimension of each operand, and a float product of the three extents above
   blas_limit cubed.  */
248 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
249 && (bxstride == 1 || bystride == 1)
250 && (((float) xcount) * ((float) ycount) * ((float) count)
251 > POW3(blas_limit)))
/* The extents and leading dimensions are narrowed to int here because the
   blas_call signature takes const int pointers.  */
253 const int m = xcount, n = ycount, k = count, ldc = rystride;
254 const GFC_COMPLEX_16 one = 1, zero = 0;
255 const int lda = (axstride == 1) ? aystride : axstride,
256 ldb = (bxstride == 1) ? bystride : bxstride;
/* gemm is only invoked for positive leading dimensions and extents greater
   than 1.  */
258 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
260 assert (gemm != NULL);
261 const char *transa, *transb;
262 if (try_blas & 2)
263 transa = "C";
264 else
265 transa = axstride == 1 ? "N" : "T";
267 if (try_blas & 4)
268 transb = "C";
269 else
270 transb = bxstride == 1 ? "N" : "T";
272 gemm (transa, transb , &m,
273 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
274 &ldc, 1, 1);
275 return;
/* Blocked kernel: taken when the result and both operands have unit first
   stride and B is not rank 1.  */
279 if (rxstride == 1 && axstride == 1 && bxstride == 1
280 && GFC_DESCRIPTOR_RANK (b) != 1)
282 /* This block of code implements a tuned matmul, derived from
283 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
285 Bo Kagstrom and Per Ling
286 Department of Computing Science
287 Umea University
288 S-901 87 Umea, Sweden
290 from netlib.org, translated to C, and modified for matmul.m4. */
/* These locals shadow the descriptor parameters a and b for the remainder
   of this block.  */
292 const GFC_COMPLEX_16 *a, *b;
293 GFC_COMPLEX_16 *c;
294 const index_type m = xcount, n = ycount, k = count;
296 /* System generated locals */
297 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
298 i1, i2, i3, i4, i5, i6;
300 /* Local variables */
301 GFC_COMPLEX_16 f11, f12, f21, f22, f31, f32, f41, f42,
302 f13, f14, f23, f24, f33, f34, f43, f44;
303 index_type i, j, l, ii, jj, ll;
304 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
305 GFC_COMPLEX_16 *t1;
307 a = abase;
308 b = bbase;
309 c = retarray->base_addr;
311 /* Parameter adjustments */
/* Each base pointer is moved back by 1 + dim1 elements so that the 1-based
   index expression i + j * dim1 used below addresses element (i,j).  */
312 c_dim1 = rystride;
313 c_offset = 1 + c_dim1;
314 c -= c_offset;
315 a_dim1 = aystride;
316 a_offset = 1 + a_dim1;
317 a -= a_offset;
318 b_dim1 = bystride;
319 b_offset = 1 + b_dim1;
320 b -= b_offset;
322 /* Empty c first. */
323 for (j=1; j<=n; j++)
324 for (i=1; i<=m; i++)
325 c[i + j * c_dim1] = (GFC_COMPLEX_16)0;
327 /* Early exit if possible */
328 if (m == 0 || n == 0 || k == 0)
329 return;
331 /* Adjust size of t1 to what is needed. */
332 index_type t1_dim, a_sz;
333 if (aystride == 1)
334 a_sz = rystride;
335 else
336 a_sz = a_dim1;
/* The element count of the scratch buffer is capped at 65536.  */
338 t1_dim = a_sz * 256 + b_dim1;
339 if (t1_dim > 65536)
340 t1_dim = 65536;
/* NOTE(review): the malloc result is written through below without a NULL
   check, whereas the result storage above is obtained via xmallocarray.
   Consider using the same allocator here.  */
342 t1 = malloc (t1_dim * sizeof(GFC_COMPLEX_16));
344 /* Start turning the crank. */
/* Three nested blocking loops: jj walks the columns of C in steps of 512,
   ll walks the summation index in steps of 256, and ii walks the rows of C
   in steps of 256.  jsec, lsec and isec are the actual sizes of the current
   block in each direction.  */
345 i1 = n;
346 for (jj = 1; jj <= i1; jj += 512)
348 /* Computing MIN */
349 i2 = 512;
350 i3 = n - jj + 1;
351 jsec = min(i2,i3);
352 ujsec = jsec - jsec % 4;
353 i2 = k;
354 for (ll = 1; ll <= i2; ll += 256)
356 /* Computing MIN */
357 i3 = 256;
358 i4 = k - ll + 1;
359 lsec = min(i3,i4);
360 ulsec = lsec - lsec % 2;
362 i3 = m;
363 for (ii = 1; ii <= i3; ii += 256)
365 /* Computing MIN */
366 i4 = 256;
367 i5 = m - ii + 1;
368 isec = min(i4,i5);
369 uisec = isec - isec % 2;
/* Copy the current isec x lsec block of A into t1, two rows and two columns
   at a time.  The t1 index expression reduces to (l - ll) + (i - ii) * 256,
   so the summation index is the contiguous one inside t1.  */
370 i4 = ll + ulsec - 1;
371 for (l = ll; l <= i4; l += 2)
373 i5 = ii + uisec - 1;
374 for (i = ii; i <= i5; i += 2)
376 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
377 a[i + l * a_dim1];
378 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
379 a[i + (l + 1) * a_dim1];
380 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
381 a[i + 1 + l * a_dim1];
382 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
383 a[i + 1 + (l + 1) * a_dim1];
/* Odd isec: copy the last row of the block for this pair of columns.  */
385 if (uisec < isec)
387 t1[l - ll + 1 + (isec << 8) - 257] =
388 a[ii + isec - 1 + l * a_dim1];
389 t1[l - ll + 2 + (isec << 8) - 257] =
390 a[ii + isec - 1 + (l + 1) * a_dim1];
/* Odd lsec: copy the last column of the block for every row.  */
393 if (ulsec < lsec)
395 i4 = ii + isec - 1;
396 for (i = ii; i<= i4; ++i)
398 t1[lsec + ((i - ii + 1) << 8) - 257] =
399 a[i + (ll + lsec - 1) * a_dim1];
/* From here on uisec is isec rounded down to a multiple of 4.  The loops
   below update C in 4x4 tiles: sixteen locals f11..f44 are loaded from C,
   accumulate t1 * b over the lsec summation steps, and are stored back.  */
403 uisec = isec - isec % 4;
404 i4 = jj + ujsec - 1;
405 for (j = jj; j <= i4; j += 4)
407 i5 = ii + uisec - 1;
408 for (i = ii; i <= i5; i += 4)
410 f11 = c[i + j * c_dim1];
411 f21 = c[i + 1 + j * c_dim1];
412 f12 = c[i + (j + 1) * c_dim1];
413 f22 = c[i + 1 + (j + 1) * c_dim1];
414 f13 = c[i + (j + 2) * c_dim1];
415 f23 = c[i + 1 + (j + 2) * c_dim1];
416 f14 = c[i + (j + 3) * c_dim1];
417 f24 = c[i + 1 + (j + 3) * c_dim1];
418 f31 = c[i + 2 + j * c_dim1];
419 f41 = c[i + 3 + j * c_dim1];
420 f32 = c[i + 2 + (j + 1) * c_dim1];
421 f42 = c[i + 3 + (j + 1) * c_dim1];
422 f33 = c[i + 2 + (j + 2) * c_dim1];
423 f43 = c[i + 3 + (j + 2) * c_dim1];
424 f34 = c[i + 2 + (j + 3) * c_dim1];
425 f44 = c[i + 3 + (j + 3) * c_dim1];
426 i6 = ll + lsec - 1;
427 for (l = ll; l <= i6; ++l)
429 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
430 * b[l + j * b_dim1];
431 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
432 * b[l + j * b_dim1];
433 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
434 * b[l + (j + 1) * b_dim1];
435 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
436 * b[l + (j + 1) * b_dim1];
437 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
438 * b[l + (j + 2) * b_dim1];
439 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
440 * b[l + (j + 2) * b_dim1];
441 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
442 * b[l + (j + 3) * b_dim1];
443 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
444 * b[l + (j + 3) * b_dim1];
445 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
446 * b[l + j * b_dim1];
447 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
448 * b[l + j * b_dim1];
449 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
450 * b[l + (j + 1) * b_dim1];
451 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
452 * b[l + (j + 1) * b_dim1];
453 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
454 * b[l + (j + 2) * b_dim1];
455 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
456 * b[l + (j + 2) * b_dim1];
457 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
458 * b[l + (j + 3) * b_dim1];
459 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
460 * b[l + (j + 3) * b_dim1];
462 c[i + j * c_dim1] = f11;
463 c[i + 1 + j * c_dim1] = f21;
464 c[i + (j + 1) * c_dim1] = f12;
465 c[i + 1 + (j + 1) * c_dim1] = f22;
466 c[i + (j + 2) * c_dim1] = f13;
467 c[i + 1 + (j + 2) * c_dim1] = f23;
468 c[i + (j + 3) * c_dim1] = f14;
469 c[i + 1 + (j + 3) * c_dim1] = f24;
470 c[i + 2 + j * c_dim1] = f31;
471 c[i + 3 + j * c_dim1] = f41;
472 c[i + 2 + (j + 1) * c_dim1] = f32;
473 c[i + 3 + (j + 1) * c_dim1] = f42;
474 c[i + 2 + (j + 2) * c_dim1] = f33;
475 c[i + 3 + (j + 2) * c_dim1] = f43;
476 c[i + 2 + (j + 3) * c_dim1] = f34;
477 c[i + 3 + (j + 3) * c_dim1] = f44;
/* Rows left over after the multiple-of-4 part: one row against the same
   four columns.  */
479 if (uisec < isec)
481 i5 = ii + isec - 1;
482 for (i = ii + uisec; i <= i5; ++i)
484 f11 = c[i + j * c_dim1];
485 f12 = c[i + (j + 1) * c_dim1];
486 f13 = c[i + (j + 2) * c_dim1];
487 f14 = c[i + (j + 3) * c_dim1];
488 i6 = ll + lsec - 1;
489 for (l = ll; l <= i6; ++l)
491 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
492 257] * b[l + j * b_dim1];
493 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
494 257] * b[l + (j + 1) * b_dim1];
495 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
496 257] * b[l + (j + 2) * b_dim1];
497 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
498 257] * b[l + (j + 3) * b_dim1];
500 c[i + j * c_dim1] = f11;
501 c[i + (j + 1) * c_dim1] = f12;
502 c[i + (j + 2) * c_dim1] = f13;
503 c[i + (j + 3) * c_dim1] = f14;
/* Columns left over after the multiple-of-4 part: handled one column at a
   time, four rows per step and then the remaining rows singly.  */
507 if (ujsec < jsec)
509 i4 = jj + jsec - 1;
510 for (j = jj + ujsec; j <= i4; ++j)
512 i5 = ii + uisec - 1;
513 for (i = ii; i <= i5; i += 4)
515 f11 = c[i + j * c_dim1];
516 f21 = c[i + 1 + j * c_dim1];
517 f31 = c[i + 2 + j * c_dim1];
518 f41 = c[i + 3 + j * c_dim1];
519 i6 = ll + lsec - 1;
520 for (l = ll; l <= i6; ++l)
522 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
523 257] * b[l + j * b_dim1];
524 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
525 257] * b[l + j * b_dim1];
526 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
527 257] * b[l + j * b_dim1];
528 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
529 257] * b[l + j * b_dim1];
531 c[i + j * c_dim1] = f11;
532 c[i + 1 + j * c_dim1] = f21;
533 c[i + 2 + j * c_dim1] = f31;
534 c[i + 3 + j * c_dim1] = f41;
536 i5 = ii + isec - 1;
537 for (i = ii + uisec; i <= i5; ++i)
539 f11 = c[i + j * c_dim1];
540 i6 = ll + lsec - 1;
541 for (l = ll; l <= i6; ++l)
543 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
544 257] * b[l + j * b_dim1];
546 c[i + j * c_dim1] = f11;
/* Release the scratch buffer allocated for this kernel.  */
553 free(t1);
554 return;
/* A is unit-stride along the summation index (aystride == 1): every result
   element is a dot product over contiguous b data.  */
556 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
558 if (GFC_DESCRIPTOR_RANK (a) != 1)
560 const GFC_COMPLEX_16 *restrict abase_x;
561 const GFC_COMPLEX_16 *restrict bbase_y;
562 GFC_COMPLEX_16 *restrict dest_y;
563 GFC_COMPLEX_16 s;
565 for (y = 0; y < ycount; y++)
567 bbase_y = &bbase[y*bystride];
568 dest_y = &dest[y*rystride];
569 for (x = 0; x < xcount; x++)
571 abase_x = &abase[x*axstride];
572 s = (GFC_COMPLEX_16) 0;
573 for (n = 0; n < count; n++)
574 s += abase_x[n] * bbase_y[n];
575 dest_y[x] = s;
579 else
581 const GFC_COMPLEX_16 *restrict bbase_y;
582 GFC_COMPLEX_16 s;
584 for (y = 0; y < ycount; y++)
586 bbase_y = &bbase[y*bystride];
587 s = (GFC_COMPLEX_16) 0;
588 for (n = 0; n < count; n++)
589 s += abase[n*axstride] * bbase_y[n];
590 dest[y*rystride] = s;
/* Rank-1 A with arbitrary strides: one dot product per column of B.  */
594 else if (GFC_DESCRIPTOR_RANK (a) == 1)
596 const GFC_COMPLEX_16 *restrict bbase_y;
597 GFC_COMPLEX_16 s;
599 for (y = 0; y < ycount; y++)
601 bbase_y = &bbase[y*bystride];
602 s = (GFC_COMPLEX_16) 0;
603 for (n = 0; n < count; n++)
604 s += abase[n*axstride] * bbase_y[n*bxstride];
605 dest[y*rxstride] = s;
/* General strides with axstride < aystride: clear the result, then
   accumulate with x as the innermost loop index.  */
608 else if (axstride < aystride)
610 for (y = 0; y < ycount; y++)
611 for (x = 0; x < xcount; x++)
612 dest[x*rxstride + y*rystride] = (GFC_COMPLEX_16)0;
614 for (y = 0; y < ycount; y++)
615 for (n = 0; n < count; n++)
616 for (x = 0; x < xcount; x++)
617 /* dest[x,y] += a[x,n] * b[n,y] */
618 dest[x*rxstride + y*rystride] +=
619 abase[x*axstride + n*aystride] *
620 bbase[n*bxstride + y*bystride];
/* Remaining case: dot-product form with explicit strides on every access.  */
622 else
624 const GFC_COMPLEX_16 *restrict abase_x;
625 const GFC_COMPLEX_16 *restrict bbase_y;
626 GFC_COMPLEX_16 *restrict dest_y;
627 GFC_COMPLEX_16 s;
629 for (y = 0; y < ycount; y++)
631 bbase_y = &bbase[y*bystride];
632 dest_y = &dest[y*rystride];
633 for (x = 0; x < xcount; x++)
635 abase_x = &abase[x*axstride];
636 s = (GFC_COMPLEX_16) 0;
637 for (n = 0; n < count; n++)
638 s += abase_x[n*aystride] * bbase_y[n*bxstride];
639 dest_y[x*rxstride] = s;
644 #undef POW3
645 #undef min
646 #undef max
648 #endif /* HAVE_AVX */
650 #ifdef HAVE_AVX2
/* matmul_c16_avx2: file-local MATMUL kernel for GFC_COMPLEX_16 arrays.  The
   forward declaration attaches __attribute__((__target__("avx2,fma"))); the
   definition follows immediately.

   retarray    result descriptor.  When its base_addr is NULL the bounds are
               filled in from A and B and storage comes from xmallocarray;
               otherwise the extents are only verified when
               compile_options.bounds_check is set.
   a, b        operand descriptors; at least one must have rank 2 (asserted).
   try_blas    nonzero enables the gemm path; mask value 2 selects "C" for
               transa and mask value 4 selects "C" for transb.
   blas_limit  the gemm path is considered only when xcount*ycount*count
               (computed in float) exceeds blas_limit cubed.
   gemm        BLAS routine; asserted non-NULL right before it is called.  */
651 static void
652 matmul_c16_avx2 (gfc_array_c16 * const restrict retarray,
653 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
654 int blas_limit, blas_call gemm) __attribute__((__target__("avx2,fma")));
655 static void
656 matmul_c16_avx2 (gfc_array_c16 * const restrict retarray,
657 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
658 int blas_limit, blas_call gemm)
660 const GFC_COMPLEX_16 * restrict abase;
661 const GFC_COMPLEX_16 * restrict bbase;
662 GFC_COMPLEX_16 * restrict dest;
664 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
665 index_type x, y, n, count, xcount, ycount;
667 assert (GFC_DESCRIPTOR_RANK (a) == 2
668 || GFC_DESCRIPTOR_RANK (b) == 2);
670 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
672 Either A or B (but not both) can be rank 1:
674 o One-dimensional argument A is implicitly treated as a row matrix
675 dimensioned [1,count], so xcount=1.
677 o One-dimensional argument B is implicitly treated as a column matrix
678 dimensioned [count, 1], so ycount=1.
681 if (retarray->base_addr == NULL)
683 if (GFC_DESCRIPTOR_RANK (a) == 1)
685 GFC_DIMENSION_SET(retarray->dim[0], 0,
686 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
688 else if (GFC_DESCRIPTOR_RANK (b) == 1)
690 GFC_DIMENSION_SET(retarray->dim[0], 0,
691 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
693 else
695 GFC_DIMENSION_SET(retarray->dim[0], 0,
696 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
698 GFC_DIMENSION_SET(retarray->dim[1], 0,
699 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
700 GFC_DESCRIPTOR_EXTENT(retarray,0));
703 retarray->base_addr
704 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_16));
705 retarray->offset = 0;
707 else if (unlikely (compile_options.bounds_check))
709 index_type ret_extent, arg_extent;
711 if (GFC_DESCRIPTOR_RANK (a) == 1)
713 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
714 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
715 if (arg_extent != ret_extent)
716 runtime_error ("Array bound mismatch for dimension 1 of "
717 "array (%ld/%ld) ",
718 (long int) ret_extent, (long int) arg_extent);
720 else if (GFC_DESCRIPTOR_RANK (b) == 1)
722 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
723 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
724 if (arg_extent != ret_extent)
725 runtime_error ("Array bound mismatch for dimension 1 of "
726 "array (%ld/%ld) ",
727 (long int) ret_extent, (long int) arg_extent);
729 else
731 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
732 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
733 if (arg_extent != ret_extent)
734 runtime_error ("Array bound mismatch for dimension 1 of "
735 "array (%ld/%ld) ",
736 (long int) ret_extent, (long int) arg_extent);
738 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
739 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
740 if (arg_extent != ret_extent)
741 runtime_error ("Array bound mismatch for dimension 2 of "
742 "array (%ld/%ld) ",
743 (long int) ret_extent, (long int) arg_extent);
748 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
750 /* One-dimensional result may be addressed in the code below
751 either as a row or a column matrix. We want both cases to
752 work. */
753 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
755 else
757 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
758 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
/* Strides of A plus the row count (xcount) and the summation length
   (count).  */
762 if (GFC_DESCRIPTOR_RANK (a) == 1)
764 /* Treat it as a row matrix A[1,count]. */
765 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
766 aystride = 1;
768 xcount = 1;
769 count = GFC_DESCRIPTOR_EXTENT(a,0);
771 else
773 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
774 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
776 count = GFC_DESCRIPTOR_EXTENT(a,1);
777 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
/* The first extent of B has to equal count.  A mismatch is reported unless
   neither of the two values is positive.  */
780 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
782 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
783 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
784 "in dimension 1: is %ld, should be %ld",
785 (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
788 if (GFC_DESCRIPTOR_RANK (b) == 1)
790 /* Treat it as a column matrix B[count,1] */
791 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
793 /* bystride should never be used for 1-dimensional b.
794 The value is only used for calculation of the
795 memory by the buffer. */
796 bystride = 256;
797 ycount = 1;
799 else
801 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
802 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
803 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
806 abase = a->base_addr;
807 bbase = b->base_addr;
808 dest = retarray->base_addr;
810 /* Now that everything is set up, we perform the multiplication
811 itself. */
/* Helper macros local to this function (undefined again after it).  POW3
   cubes its argument in float arithmetic.  */
813 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
814 #define min(a,b) ((a) <= (b) ? (a) : (b))
815 #define max(a,b) ((a) >= (b) ? (a) : (b))
/* gemm path: needs a unit first stride in the result, a unit stride in one
   dimension of each operand, and a float product of the three extents above
   blas_limit cubed.  */
817 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
818 && (bxstride == 1 || bystride == 1)
819 && (((float) xcount) * ((float) ycount) * ((float) count)
820 > POW3(blas_limit)))
/* The extents and leading dimensions are narrowed to int here because the
   blas_call signature takes const int pointers.  */
822 const int m = xcount, n = ycount, k = count, ldc = rystride;
823 const GFC_COMPLEX_16 one = 1, zero = 0;
824 const int lda = (axstride == 1) ? aystride : axstride,
825 ldb = (bxstride == 1) ? bystride : bxstride;
/* gemm is only invoked for positive leading dimensions and extents greater
   than 1.  */
827 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
829 assert (gemm != NULL);
830 const char *transa, *transb;
831 if (try_blas & 2)
832 transa = "C";
833 else
834 transa = axstride == 1 ? "N" : "T";
836 if (try_blas & 4)
837 transb = "C";
838 else
839 transb = bxstride == 1 ? "N" : "T";
841 gemm (transa, transb , &m,
842 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
843 &ldc, 1, 1);
844 return;
/* Blocked kernel: taken when the result and both operands have unit first
   stride and B is not rank 1.  */
848 if (rxstride == 1 && axstride == 1 && bxstride == 1
849 && GFC_DESCRIPTOR_RANK (b) != 1)
851 /* This block of code implements a tuned matmul, derived from
852 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
854 Bo Kagstrom and Per Ling
855 Department of Computing Science
856 Umea University
857 S-901 87 Umea, Sweden
859 from netlib.org, translated to C, and modified for matmul.m4. */
/* These locals shadow the descriptor parameters a and b for the remainder
   of this block.  */
861 const GFC_COMPLEX_16 *a, *b;
862 GFC_COMPLEX_16 *c;
863 const index_type m = xcount, n = ycount, k = count;
865 /* System generated locals */
866 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
867 i1, i2, i3, i4, i5, i6;
869 /* Local variables */
870 GFC_COMPLEX_16 f11, f12, f21, f22, f31, f32, f41, f42,
871 f13, f14, f23, f24, f33, f34, f43, f44;
872 index_type i, j, l, ii, jj, ll;
873 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
874 GFC_COMPLEX_16 *t1;
876 a = abase;
877 b = bbase;
878 c = retarray->base_addr;
880 /* Parameter adjustments */
/* Each base pointer is moved back by 1 + dim1 elements so that the 1-based
   index expression i + j * dim1 used below addresses element (i,j).  */
881 c_dim1 = rystride;
882 c_offset = 1 + c_dim1;
883 c -= c_offset;
884 a_dim1 = aystride;
885 a_offset = 1 + a_dim1;
886 a -= a_offset;
887 b_dim1 = bystride;
888 b_offset = 1 + b_dim1;
889 b -= b_offset;
891 /* Empty c first. */
892 for (j=1; j<=n; j++)
893 for (i=1; i<=m; i++)
894 c[i + j * c_dim1] = (GFC_COMPLEX_16)0;
896 /* Early exit if possible */
897 if (m == 0 || n == 0 || k == 0)
898 return;
900 /* Adjust size of t1 to what is needed. */
901 index_type t1_dim, a_sz;
902 if (aystride == 1)
903 a_sz = rystride;
904 else
905 a_sz = a_dim1;
/* The element count of the scratch buffer is capped at 65536.  */
907 t1_dim = a_sz * 256 + b_dim1;
908 if (t1_dim > 65536)
909 t1_dim = 65536;
/* NOTE(review): the malloc result is written through below without a NULL
   check, whereas the result storage above is obtained via xmallocarray.
   Consider using the same allocator here.  */
911 t1 = malloc (t1_dim * sizeof(GFC_COMPLEX_16));
913 /* Start turning the crank. */
/* Three nested blocking loops: jj walks the columns of C in steps of 512,
   ll walks the summation index in steps of 256, and ii walks the rows of C
   in steps of 256.  jsec, lsec and isec are the actual sizes of the current
   block in each direction.  */
914 i1 = n;
915 for (jj = 1; jj <= i1; jj += 512)
917 /* Computing MIN */
918 i2 = 512;
919 i3 = n - jj + 1;
920 jsec = min(i2,i3);
921 ujsec = jsec - jsec % 4;
922 i2 = k;
923 for (ll = 1; ll <= i2; ll += 256)
925 /* Computing MIN */
926 i3 = 256;
927 i4 = k - ll + 1;
928 lsec = min(i3,i4);
929 ulsec = lsec - lsec % 2;
931 i3 = m;
932 for (ii = 1; ii <= i3; ii += 256)
934 /* Computing MIN */
935 i4 = 256;
936 i5 = m - ii + 1;
937 isec = min(i4,i5);
938 uisec = isec - isec % 2;
/* Copy the current isec x lsec block of A into t1, two rows and two columns
   at a time.  The t1 index expression reduces to (l - ll) + (i - ii) * 256,
   so the summation index is the contiguous one inside t1.  */
939 i4 = ll + ulsec - 1;
940 for (l = ll; l <= i4; l += 2)
942 i5 = ii + uisec - 1;
943 for (i = ii; i <= i5; i += 2)
945 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
946 a[i + l * a_dim1];
947 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
948 a[i + (l + 1) * a_dim1];
949 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
950 a[i + 1 + l * a_dim1];
951 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
952 a[i + 1 + (l + 1) * a_dim1];
/* Odd isec: copy the last row of the block for this pair of columns.  */
954 if (uisec < isec)
956 t1[l - ll + 1 + (isec << 8) - 257] =
957 a[ii + isec - 1 + l * a_dim1];
958 t1[l - ll + 2 + (isec << 8) - 257] =
959 a[ii + isec - 1 + (l + 1) * a_dim1];
/* Odd lsec: copy the last column of the block for every row.  */
962 if (ulsec < lsec)
964 i4 = ii + isec - 1;
965 for (i = ii; i<= i4; ++i)
967 t1[lsec + ((i - ii + 1) << 8) - 257] =
968 a[i + (ll + lsec - 1) * a_dim1];
/* From here on uisec is isec rounded down to a multiple of 4.  The loops
   below update C in 4x4 tiles: sixteen locals f11..f44 are loaded from C,
   accumulate t1 * b over the lsec summation steps, and are stored back.  */
972 uisec = isec - isec % 4;
973 i4 = jj + ujsec - 1;
974 for (j = jj; j <= i4; j += 4)
976 i5 = ii + uisec - 1;
977 for (i = ii; i <= i5; i += 4)
979 f11 = c[i + j * c_dim1];
980 f21 = c[i + 1 + j * c_dim1];
981 f12 = c[i + (j + 1) * c_dim1];
982 f22 = c[i + 1 + (j + 1) * c_dim1];
983 f13 = c[i + (j + 2) * c_dim1];
984 f23 = c[i + 1 + (j + 2) * c_dim1];
985 f14 = c[i + (j + 3) * c_dim1];
986 f24 = c[i + 1 + (j + 3) * c_dim1];
987 f31 = c[i + 2 + j * c_dim1];
988 f41 = c[i + 3 + j * c_dim1];
989 f32 = c[i + 2 + (j + 1) * c_dim1];
990 f42 = c[i + 3 + (j + 1) * c_dim1];
991 f33 = c[i + 2 + (j + 2) * c_dim1];
992 f43 = c[i + 3 + (j + 2) * c_dim1];
993 f34 = c[i + 2 + (j + 3) * c_dim1];
994 f44 = c[i + 3 + (j + 3) * c_dim1];
995 i6 = ll + lsec - 1;
996 for (l = ll; l <= i6; ++l)
998 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
999 * b[l + j * b_dim1];
1000 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1001 * b[l + j * b_dim1];
1002 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1003 * b[l + (j + 1) * b_dim1];
1004 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1005 * b[l + (j + 1) * b_dim1];
1006 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1007 * b[l + (j + 2) * b_dim1];
1008 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1009 * b[l + (j + 2) * b_dim1];
1010 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1011 * b[l + (j + 3) * b_dim1];
1012 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1013 * b[l + (j + 3) * b_dim1];
1014 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1015 * b[l + j * b_dim1];
1016 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1017 * b[l + j * b_dim1];
1018 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1019 * b[l + (j + 1) * b_dim1];
1020 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1021 * b[l + (j + 1) * b_dim1];
1022 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1023 * b[l + (j + 2) * b_dim1];
1024 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1025 * b[l + (j + 2) * b_dim1];
1026 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1027 * b[l + (j + 3) * b_dim1];
1028 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1029 * b[l + (j + 3) * b_dim1];
1031 c[i + j * c_dim1] = f11;
1032 c[i + 1 + j * c_dim1] = f21;
1033 c[i + (j + 1) * c_dim1] = f12;
1034 c[i + 1 + (j + 1) * c_dim1] = f22;
1035 c[i + (j + 2) * c_dim1] = f13;
1036 c[i + 1 + (j + 2) * c_dim1] = f23;
1037 c[i + (j + 3) * c_dim1] = f14;
1038 c[i + 1 + (j + 3) * c_dim1] = f24;
1039 c[i + 2 + j * c_dim1] = f31;
1040 c[i + 3 + j * c_dim1] = f41;
1041 c[i + 2 + (j + 1) * c_dim1] = f32;
1042 c[i + 3 + (j + 1) * c_dim1] = f42;
1043 c[i + 2 + (j + 2) * c_dim1] = f33;
1044 c[i + 3 + (j + 2) * c_dim1] = f43;
1045 c[i + 2 + (j + 3) * c_dim1] = f34;
1046 c[i + 3 + (j + 3) * c_dim1] = f44;
/* Rows left over after the multiple-of-4 part: one row against the same
   four columns.  */
1048 if (uisec < isec)
1050 i5 = ii + isec - 1;
1051 for (i = ii + uisec; i <= i5; ++i)
1053 f11 = c[i + j * c_dim1];
1054 f12 = c[i + (j + 1) * c_dim1];
1055 f13 = c[i + (j + 2) * c_dim1];
1056 f14 = c[i + (j + 3) * c_dim1];
1057 i6 = ll + lsec - 1;
1058 for (l = ll; l <= i6; ++l)
1060 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1061 257] * b[l + j * b_dim1];
1062 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1063 257] * b[l + (j + 1) * b_dim1];
1064 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1065 257] * b[l + (j + 2) * b_dim1];
1066 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1067 257] * b[l + (j + 3) * b_dim1];
1069 c[i + j * c_dim1] = f11;
1070 c[i + (j + 1) * c_dim1] = f12;
1071 c[i + (j + 2) * c_dim1] = f13;
1072 c[i + (j + 3) * c_dim1] = f14;
/* Columns left over after the multiple-of-4 part: handled one column at a
   time, four rows per step and then the remaining rows singly.  */
1076 if (ujsec < jsec)
1078 i4 = jj + jsec - 1;
1079 for (j = jj + ujsec; j <= i4; ++j)
1081 i5 = ii + uisec - 1;
1082 for (i = ii; i <= i5; i += 4)
1084 f11 = c[i + j * c_dim1];
1085 f21 = c[i + 1 + j * c_dim1];
1086 f31 = c[i + 2 + j * c_dim1];
1087 f41 = c[i + 3 + j * c_dim1];
1088 i6 = ll + lsec - 1;
1089 for (l = ll; l <= i6; ++l)
1091 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1092 257] * b[l + j * b_dim1];
1093 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
1094 257] * b[l + j * b_dim1];
1095 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
1096 257] * b[l + j * b_dim1];
1097 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
1098 257] * b[l + j * b_dim1];
1100 c[i + j * c_dim1] = f11;
1101 c[i + 1 + j * c_dim1] = f21;
1102 c[i + 2 + j * c_dim1] = f31;
1103 c[i + 3 + j * c_dim1] = f41;
1105 i5 = ii + isec - 1;
1106 for (i = ii + uisec; i <= i5; ++i)
1108 f11 = c[i + j * c_dim1];
1109 i6 = ll + lsec - 1;
1110 for (l = ll; l <= i6; ++l)
1112 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1113 257] * b[l + j * b_dim1];
1115 c[i + j * c_dim1] = f11;
/* Release the scratch buffer allocated for this kernel.  */
1122 free(t1);
1123 return;
/* A is unit-stride along the summation index (aystride == 1): every result
   element is a dot product over contiguous b data.  */
1125 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
1127 if (GFC_DESCRIPTOR_RANK (a) != 1)
1129 const GFC_COMPLEX_16 *restrict abase_x;
1130 const GFC_COMPLEX_16 *restrict bbase_y;
1131 GFC_COMPLEX_16 *restrict dest_y;
1132 GFC_COMPLEX_16 s;
1134 for (y = 0; y < ycount; y++)
1136 bbase_y = &bbase[y*bystride];
1137 dest_y = &dest[y*rystride];
1138 for (x = 0; x < xcount; x++)
1140 abase_x = &abase[x*axstride];
1141 s = (GFC_COMPLEX_16) 0;
1142 for (n = 0; n < count; n++)
1143 s += abase_x[n] * bbase_y[n];
1144 dest_y[x] = s;
1148 else
1150 const GFC_COMPLEX_16 *restrict bbase_y;
1151 GFC_COMPLEX_16 s;
1153 for (y = 0; y < ycount; y++)
1155 bbase_y = &bbase[y*bystride];
1156 s = (GFC_COMPLEX_16) 0;
1157 for (n = 0; n < count; n++)
1158 s += abase[n*axstride] * bbase_y[n];
1159 dest[y*rystride] = s;
/* Rank-1 A with arbitrary strides: one dot product per column of B.  */
1163 else if (GFC_DESCRIPTOR_RANK (a) == 1)
1165 const GFC_COMPLEX_16 *restrict bbase_y;
1166 GFC_COMPLEX_16 s;
1168 for (y = 0; y < ycount; y++)
1170 bbase_y = &bbase[y*bystride];
1171 s = (GFC_COMPLEX_16) 0;
1172 for (n = 0; n < count; n++)
1173 s += abase[n*axstride] * bbase_y[n*bxstride];
1174 dest[y*rxstride] = s;
/* General strides with axstride < aystride: clear the result, then
   accumulate with x as the innermost loop index.  */
1177 else if (axstride < aystride)
1179 for (y = 0; y < ycount; y++)
1180 for (x = 0; x < xcount; x++)
1181 dest[x*rxstride + y*rystride] = (GFC_COMPLEX_16)0;
1183 for (y = 0; y < ycount; y++)
1184 for (n = 0; n < count; n++)
1185 for (x = 0; x < xcount; x++)
1186 /* dest[x,y] += a[x,n] * b[n,y] */
1187 dest[x*rxstride + y*rystride] +=
1188 abase[x*axstride + n*aystride] *
1189 bbase[n*bxstride + y*bystride];
/* Remaining case: dot-product form with explicit strides on every access.  */
1191 else
1193 const GFC_COMPLEX_16 *restrict abase_x;
1194 const GFC_COMPLEX_16 *restrict bbase_y;
1195 GFC_COMPLEX_16 *restrict dest_y;
1196 GFC_COMPLEX_16 s;
1198 for (y = 0; y < ycount; y++)
1200 bbase_y = &bbase[y*bystride];
1201 dest_y = &dest[y*rystride];
1202 for (x = 0; x < xcount; x++)
1204 abase_x = &abase[x*axstride];
1205 s = (GFC_COMPLEX_16) 0;
1206 for (n = 0; n < count; n++)
1207 s += abase_x[n*aystride] * bbase_y[n*bxstride];
1208 dest_y[x*rxstride] = s;
1213 #undef POW3
1214 #undef min
1215 #undef max
1217 #endif /* HAVE_AVX2 */
#ifdef HAVE_AVX512F
/* MATMUL for GFC_COMPLEX_16 arrays, compiled for the "avx512f" target.

   Computes RETARRAY = MATMUL (A, B), where at least one of A and B has
   rank 2.  If RETARRAY has no storage yet (base_addr == NULL) its
   bounds are derived from A and B and the storage is allocated here;
   otherwise, when compile_options.bounds_check is set, its extents are
   checked against the arguments and a runtime error is raised on
   mismatch.

   TRY_BLAS, BLAS_LIMIT and GEMM control delegation to an external
   ?gemm routine: GEMM is called when TRY_BLAS is nonzero, the strides
   allow it and xcount * ycount * count exceeds BLAS_LIMIT cubed.
   (TRY_BLAS & 2) and (TRY_BLAS & 4) make the first / second transpose
   argument passed to GEMM "C".  */
static void
matmul_c16_avx512f (gfc_array_c16 * const restrict retarray,
        gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
        int blas_limit, blas_call gemm) __attribute__((__target__("avx512f")));
static void
matmul_c16_avx512f (gfc_array_c16 * const restrict retarray,
        gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
        int blas_limit, blas_call gemm)
{
  const GFC_COMPLEX_16 * restrict abase;
  const GFC_COMPLEX_16 * restrict bbase;
  GFC_COMPLEX_16 * restrict dest;

  index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  index_type x, y, n, count, xcount, ycount;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
          || GFC_DESCRIPTOR_RANK (b) == 2);

/* C[xcount,ycount] = A[xcount, count] * B[count,ycount]

   Either A or B (but not both) can be rank 1:

   o One-dimensional argument A is implicitly treated as a row matrix
     dimensioned [1,count], so xcount=1.

   o One-dimensional argument B is implicitly treated as a column matrix
     dimensioned [count, 1], so ycount=1.
*/

  if (retarray->base_addr == NULL)
    {
      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
        }
      else
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

          GFC_DIMENSION_SET(retarray->dim[1], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1,
                            GFC_DESCRIPTOR_EXTENT(retarray,0));
        }

      retarray->base_addr
        = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_16));
      retarray->offset = 0;
    }
  else if (unlikely (compile_options.bounds_check))
    {
      index_type ret_extent, arg_extent;

      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);

          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 2 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
    }


  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      /* One-dimensional result may be addressed in the code below
         either as a row or a column matrix. We want both cases to
         work. */
      rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }


  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Treat it as a row matrix A[1,count]. */
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = 1;

      xcount = 1;
      count = GFC_DESCRIPTOR_EXTENT(a,0);
    }
  else
    {
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = GFC_DESCRIPTOR_STRIDE(a,1);

      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  if (count != GFC_DESCRIPTOR_EXTENT(b,0))
    {
      /* Two differing non-positive extents both describe an empty
         dimension, so only complain if at least one is positive.  */
      if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
        runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
                       "in dimension 1: is %ld, should be %ld",
                       (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
    }

  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      /* Treat it as a column matrix B[count,1] */
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);

      /* bystride should never be used for 1-dimensional b.
         The value is only used for calculation of the
         memory by the buffer.  */
      bystride = 256;
      ycount = 1;
    }
  else
    {
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
      bystride = GFC_DESCRIPTOR_STRIDE(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  abase = a->base_addr;
  bbase = b->base_addr;
  dest = retarray->base_addr;

  /* Now that everything is set up, we perform the multiplication
     itself.  */

#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
#define min(a,b) ((a) <= (b) ? (a) : (b))
#define max(a,b) ((a) >= (b) ? (a) : (b))

  /* The problem size is compared in float so that the product of the
     three extents is not computed in integer arithmetic.  */
  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
      && (bxstride == 1 || bystride == 1)
      && (((float) xcount) * ((float) ycount) * ((float) count)
          > POW3(blas_limit)))
    {
      const int m = xcount, n = ycount, k = count, ldc = rystride;
      const GFC_COMPLEX_16 one = 1, zero = 0;
      const int lda = (axstride == 1) ? aystride : axstride,
                ldb = (bxstride == 1) ? bystride : bxstride;

      if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
        {
          assert (gemm != NULL);
          const char *transa, *transb;
          if (try_blas & 2)
            transa = "C";
          else
            transa = axstride == 1 ? "N" : "T";

          if (try_blas & 4)
            transb = "C";
          else
            transb = bxstride == 1 ? "N" : "T";

          gemm (transa, transb , &m,
                &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
                &ldc, 1, 1);
          return;
        }
    }

  if (rxstride == 1 && axstride == 1 && bxstride == 1
      && GFC_DESCRIPTOR_RANK (b) != 1)
    {
      /* This block of code implements a tuned matmul, derived from
         Superscalar GEMM-based level 3 BLAS,  Beta version 0.1

               Bo Kagstrom and Per Ling
               Department of Computing Science
               Umea University
               S-901 87 Umea, Sweden

         from netlib.org, translated to C, and modified for matmul.m4.  */

      const GFC_COMPLEX_16 *a, *b;
      GFC_COMPLEX_16 *c;
      const index_type m = xcount, n = ycount, k = count;

      /* System generated locals */
      index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
                 i1, i2, i3, i4, i5, i6;

      /* Local variables */
      GFC_COMPLEX_16 f11, f12, f21, f22, f31, f32, f41, f42,
                 f13, f14, f23, f24, f33, f34, f43, f44;
      index_type i, j, l, ii, jj, ll;
      index_type isec, jsec, lsec, uisec, ujsec, ulsec;
      GFC_COMPLEX_16 *t1;

      a = abase;
      b = bbase;
      c = retarray->base_addr;

      /* Parameter adjustments: shift each base pointer by one row and
         one column so that x[i + j * x_dim1] with i and j counted from
         1 addresses the first element of the matrix.  */
      c_dim1 = rystride;
      c_offset = 1 + c_dim1;
      c -= c_offset;
      a_dim1 = aystride;
      a_offset = 1 + a_dim1;
      a -= a_offset;
      b_dim1 = bystride;
      b_offset = 1 + b_dim1;
      b -= b_offset;

      /* Empty c first.  */
      for (j=1; j<=n; j++)
        for (i=1; i<=m; i++)
          c[i + j * c_dim1] = (GFC_COMPLEX_16)0;

      /* Early exit if possible */
      if (m == 0 || n == 0 || k == 0)
        return;

      /* Adjust size of t1 to what is needed.  */
      index_type t1_dim, a_sz;
      if (aystride == 1)
        a_sz = rystride;
      else
        a_sz = a_dim1;

      t1_dim = a_sz * 256 + b_dim1;
      if (t1_dim > 65536)
        t1_dim = 65536;

      /* Allocate through xmallocarray, as is done for RETARRAY, rather
         than plain malloc: an unchecked malloc returning NULL would be
         dereferenced by the packing loop below.  */
      t1 = xmallocarray (t1_dim, sizeof (GFC_COMPLEX_16));

      /* Start turning the crank.  The index expression
         l - ll + 1 + ((i - ii + 1) << 8) - 257 used throughout equals
         (l - ll) + 256 * (i - ii), i.e. t1 holds the current block of A
         with the l index varying fastest.  */
      i1 = n;
      for (jj = 1; jj <= i1; jj += 512)
        {
          /* Computing MIN */
          i2 = 512;
          i3 = n - jj + 1;
          jsec = min(i2,i3);
          ujsec = jsec - jsec % 4;
          i2 = k;
          for (ll = 1; ll <= i2; ll += 256)
            {
              /* Computing MIN */
              i3 = 256;
              i4 = k - ll + 1;
              lsec = min(i3,i4);
              ulsec = lsec - lsec % 2;

              i3 = m;
              for (ii = 1; ii <= i3; ii += 256)
                {
                  /* Computing MIN */
                  i4 = 256;
                  i5 = m - ii + 1;
                  isec = min(i4,i5);
                  uisec = isec - isec % 2;

                  /* Copy the current block of A into t1, two rows and
                     two columns at a time, then the odd leftovers.  */
                  i4 = ll + ulsec - 1;
                  for (l = ll; l <= i4; l += 2)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 2)
                        {
                          t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
                                        a[i + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
                                        a[i + (l + 1) * a_dim1];
                          t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + (l + 1) * a_dim1];
                        }
                      if (uisec < isec)
                        {
                          t1[l - ll + 1 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + l * a_dim1];
                          t1[l - ll + 2 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + (l + 1) * a_dim1];
                        }
                    }
                  if (ulsec < lsec)
                    {
                      i4 = ii + isec - 1;
                      for (i = ii; i<= i4; ++i)
                        {
                          t1[lsec + ((i - ii + 1) << 8) - 257] =
                                    a[i + (ll + lsec - 1) * a_dim1];
                        }
                    }

                  /* Accumulate into c in 4x4 tiles, followed by the
                     leftover rows and then the leftover columns.  */
                  uisec = isec - isec % 4;
                  i4 = jj + ujsec - 1;
                  for (j = jj; j <= i4; j += 4)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 4)
                        {
                          f11 = c[i + j * c_dim1];
                          f21 = c[i + 1 + j * c_dim1];
                          f12 = c[i + (j + 1) * c_dim1];
                          f22 = c[i + 1 + (j + 1) * c_dim1];
                          f13 = c[i + (j + 2) * c_dim1];
                          f23 = c[i + 1 + (j + 2) * c_dim1];
                          f14 = c[i + (j + 3) * c_dim1];
                          f24 = c[i + 1 + (j + 3) * c_dim1];
                          f31 = c[i + 2 + j * c_dim1];
                          f41 = c[i + 3 + j * c_dim1];
                          f32 = c[i + 2 + (j + 1) * c_dim1];
                          f42 = c[i + 3 + (j + 1) * c_dim1];
                          f33 = c[i + 2 + (j + 2) * c_dim1];
                          f43 = c[i + 3 + (j + 2) * c_dim1];
                          f34 = c[i + 2 + (j + 3) * c_dim1];
                          f44 = c[i + 3 + (j + 3) * c_dim1];
                          i6 = ll + lsec - 1;
                          for (l = ll; l <= i6; ++l)
                            {
                              f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                            }
                          c[i + j * c_dim1] = f11;
                          c[i + 1 + j * c_dim1] = f21;
                          c[i + (j + 1) * c_dim1] = f12;
                          c[i + 1 + (j + 1) * c_dim1] = f22;
                          c[i + (j + 2) * c_dim1] = f13;
                          c[i + 1 + (j + 2) * c_dim1] = f23;
                          c[i + (j + 3) * c_dim1] = f14;
                          c[i + 1 + (j + 3) * c_dim1] = f24;
                          c[i + 2 + j * c_dim1] = f31;
                          c[i + 3 + j * c_dim1] = f41;
                          c[i + 2 + (j + 1) * c_dim1] = f32;
                          c[i + 3 + (j + 1) * c_dim1] = f42;
                          c[i + 2 + (j + 2) * c_dim1] = f33;
                          c[i + 3 + (j + 2) * c_dim1] = f43;
                          c[i + 2 + (j + 3) * c_dim1] = f34;
                          c[i + 3 + (j + 3) * c_dim1] = f44;
                        }
                      if (uisec < isec)
                        {
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              f12 = c[i + (j + 1) * c_dim1];
                              f13 = c[i + (j + 2) * c_dim1];
                              f14 = c[i + (j + 3) * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 1) * b_dim1];
                                  f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 2) * b_dim1];
                                  f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + (j + 3) * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + (j + 1) * c_dim1] = f12;
                              c[i + (j + 2) * c_dim1] = f13;
                              c[i + (j + 3) * c_dim1] = f14;
                            }
                        }
                    }
                  if (ujsec < jsec)
                    {
                      i4 = jj + jsec - 1;
                      for (j = jj + ujsec; j <= i4; ++j)
                        {
                          i5 = ii + uisec - 1;
                          for (i = ii; i <= i5; i += 4)
                            {
                              f11 = c[i + j * c_dim1];
                              f21 = c[i + 1 + j * c_dim1];
                              f31 = c[i + 2 + j * c_dim1];
                              f41 = c[i + 3 + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
                                          257] * b[l + j * b_dim1];
                                  f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + 1 + j * c_dim1] = f21;
                              c[i + 2 + j * c_dim1] = f31;
                              c[i + 3 + j * c_dim1] = f41;
                            }
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                          257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                            }
                        }
                    }
                }
            }
        }
      free(t1);
      return;
    }
  else if (rxstride == 1 && aystride == 1 && bxstride == 1)
    {
      if (GFC_DESCRIPTOR_RANK (a) != 1)
        {
          const GFC_COMPLEX_16 *restrict abase_x;
          const GFC_COMPLEX_16 *restrict bbase_y;
          GFC_COMPLEX_16 *restrict dest_y;
          GFC_COMPLEX_16 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              dest_y = &dest[y*rystride];
              for (x = 0; x < xcount; x++)
                {
                  abase_x = &abase[x*axstride];
                  s = (GFC_COMPLEX_16) 0;
                  for (n = 0; n < count; n++)
                    s += abase_x[n] * bbase_y[n];
                  dest_y[x] = s;
                }
            }
        }
      else
        {
          const GFC_COMPLEX_16 *restrict bbase_y;
          GFC_COMPLEX_16 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              s = (GFC_COMPLEX_16) 0;
              for (n = 0; n < count; n++)
                s += abase[n*axstride] * bbase_y[n];
              dest[y*rystride] = s;
            }
        }
    }
  else if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      const GFC_COMPLEX_16 *restrict bbase_y;
      GFC_COMPLEX_16 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          s = (GFC_COMPLEX_16) 0;
          for (n = 0; n < count; n++)
            s += abase[n*axstride] * bbase_y[n*bxstride];
          dest[y*rxstride] = s;
        }
    }
  else if (axstride < aystride)
    {
      for (y = 0; y < ycount; y++)
        for (x = 0; x < xcount; x++)
          dest[x*rxstride + y*rystride] = (GFC_COMPLEX_16)0;

      for (y = 0; y < ycount; y++)
        for (n = 0; n < count; n++)
          for (x = 0; x < xcount; x++)
            /* dest[x,y] += a[x,n] * b[n,y] */
            dest[x*rxstride + y*rystride] +=
                                        abase[x*axstride + n*aystride] *
                                        bbase[n*bxstride + y*bystride];
    }
  else
    {
      const GFC_COMPLEX_16 *restrict abase_x;
      const GFC_COMPLEX_16 *restrict bbase_y;
      GFC_COMPLEX_16 *restrict dest_y;
      GFC_COMPLEX_16 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          dest_y = &dest[y*rystride];
          for (x = 0; x < xcount; x++)
            {
              abase_x = &abase[x*axstride];
              s = (GFC_COMPLEX_16) 0;
              for (n = 0; n < count; n++)
                s += abase_x[n*aystride] * bbase_y[n*bxstride];
              dest_y[x*rxstride] = s;
            }
        }
    }
}
#undef POW3
#undef min
#undef max

#endif  /* HAVE_AVX512F */
1788 /* AMD-specific functions with AVX128 and FMA3/FMA4. */
#if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
/* Declaration only: the AVX128 + FMA3 variant of matmul_c16 is not
   defined in this part of the file.  It takes the same arguments as
   the other matmul_c16 variants and is declared for the "avx,fma"
   target.  */
void
matmul_c16_avx128_fma3 (gfc_array_c16 * const restrict retarray,
                        gfc_array_c16 * const restrict a,
                        gfc_array_c16 * const restrict b,
                        int try_blas,
                        int blas_limit,
                        blas_call gemm)
  __attribute__((__target__("avx,fma")));
internal_proto(matmul_c16_avx128_fma3);
#endif
#if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
/* Declaration only: the AVX128 + FMA4 variant of matmul_c16 is not
   defined in this part of the file.  It takes the same arguments as
   the other matmul_c16 variants and is declared for the "avx,fma4"
   target.  */
void
matmul_c16_avx128_fma4 (gfc_array_c16 * const restrict retarray,
                        gfc_array_c16 * const restrict a,
                        gfc_array_c16 * const restrict b,
                        int try_blas,
                        int blas_limit,
                        blas_call gemm)
  __attribute__((__target__("avx,fma4")));
internal_proto(matmul_c16_avx128_fma4);
#endif
1806 /* Function to fall back to if there is no special processor-specific version. */
1807 static void
1808 matmul_c16_vanilla (gfc_array_c16 * const restrict retarray,
1809 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
1810 int blas_limit, blas_call gemm)
1812 const GFC_COMPLEX_16 * restrict abase;
1813 const GFC_COMPLEX_16 * restrict bbase;
1814 GFC_COMPLEX_16 * restrict dest;
1816 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
1817 index_type x, y, n, count, xcount, ycount;
1819 assert (GFC_DESCRIPTOR_RANK (a) == 2
1820 || GFC_DESCRIPTOR_RANK (b) == 2);
1822 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
1824 Either A or B (but not both) can be rank 1:
1826 o One-dimensional argument A is implicitly treated as a row matrix
1827 dimensioned [1,count], so xcount=1.
1829 o One-dimensional argument B is implicitly treated as a column matrix
1830 dimensioned [count, 1], so ycount=1.
1833 if (retarray->base_addr == NULL)
1835 if (GFC_DESCRIPTOR_RANK (a) == 1)
1837 GFC_DIMENSION_SET(retarray->dim[0], 0,
1838 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
1840 else if (GFC_DESCRIPTOR_RANK (b) == 1)
1842 GFC_DIMENSION_SET(retarray->dim[0], 0,
1843 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
1845 else
1847 GFC_DIMENSION_SET(retarray->dim[0], 0,
1848 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
1850 GFC_DIMENSION_SET(retarray->dim[1], 0,
1851 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
1852 GFC_DESCRIPTOR_EXTENT(retarray,0));
1855 retarray->base_addr
1856 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_16));
1857 retarray->offset = 0;
1859 else if (unlikely (compile_options.bounds_check))
1861 index_type ret_extent, arg_extent;
1863 if (GFC_DESCRIPTOR_RANK (a) == 1)
1865 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
1866 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1867 if (arg_extent != ret_extent)
1868 runtime_error ("Array bound mismatch for dimension 1 of "
1869 "array (%ld/%ld) ",
1870 (long int) ret_extent, (long int) arg_extent);
1872 else if (GFC_DESCRIPTOR_RANK (b) == 1)
1874 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
1875 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1876 if (arg_extent != ret_extent)
1877 runtime_error ("Array bound mismatch for dimension 1 of "
1878 "array (%ld/%ld) ",
1879 (long int) ret_extent, (long int) arg_extent);
1881 else
1883 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
1884 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1885 if (arg_extent != ret_extent)
1886 runtime_error ("Array bound mismatch for dimension 1 of "
1887 "array (%ld/%ld) ",
1888 (long int) ret_extent, (long int) arg_extent);
1890 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
1891 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
1892 if (arg_extent != ret_extent)
1893 runtime_error ("Array bound mismatch for dimension 2 of "
1894 "array (%ld/%ld) ",
1895 (long int) ret_extent, (long int) arg_extent);
1900 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
1902 /* One-dimensional result may be addressed in the code below
1903 either as a row or a column matrix. We want both cases to
1904 work. */
1905 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
1907 else
1909 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
1910 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
1914 if (GFC_DESCRIPTOR_RANK (a) == 1)
1916 /* Treat it as a a row matrix A[1,count]. */
1917 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
1918 aystride = 1;
1920 xcount = 1;
1921 count = GFC_DESCRIPTOR_EXTENT(a,0);
1923 else
1925 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
1926 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
1928 count = GFC_DESCRIPTOR_EXTENT(a,1);
1929 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
1932 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
1934 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
1935 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
1936 "in dimension 1: is %ld, should be %ld",
1937 (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
1940 if (GFC_DESCRIPTOR_RANK (b) == 1)
1942 /* Treat it as a column matrix B[count,1] */
1943 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
1945 /* bystride should never be used for 1-dimensional b.
1946 The value is only used for calculation of the
1947 memory by the buffer. */
1948 bystride = 256;
1949 ycount = 1;
1951 else
1953 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
1954 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
1955 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
1958 abase = a->base_addr;
1959 bbase = b->base_addr;
1960 dest = retarray->base_addr;
1962 /* Now that everything is set up, we perform the multiplication
1963 itself. */
1965 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
1966 #define min(a,b) ((a) <= (b) ? (a) : (b))
1967 #define max(a,b) ((a) >= (b) ? (a) : (b))
1969 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
1970 && (bxstride == 1 || bystride == 1)
1971 && (((float) xcount) * ((float) ycount) * ((float) count)
1972 > POW3(blas_limit)))
1974 const int m = xcount, n = ycount, k = count, ldc = rystride;
1975 const GFC_COMPLEX_16 one = 1, zero = 0;
1976 const int lda = (axstride == 1) ? aystride : axstride,
1977 ldb = (bxstride == 1) ? bystride : bxstride;
1979 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
1981 assert (gemm != NULL);
1982 const char *transa, *transb;
1983 if (try_blas & 2)
1984 transa = "C";
1985 else
1986 transa = axstride == 1 ? "N" : "T";
1988 if (try_blas & 4)
1989 transb = "C";
1990 else
1991 transb = bxstride == 1 ? "N" : "T";
1993 gemm (transa, transb , &m,
1994 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
1995 &ldc, 1, 1);
1996 return;
2000 if (rxstride == 1 && axstride == 1 && bxstride == 1
2001 && GFC_DESCRIPTOR_RANK (b) != 1)
2003 /* This block of code implements a tuned matmul, derived from
2004 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
2006 Bo Kagstrom and Per Ling
2007 Department of Computing Science
2008 Umea University
2009 S-901 87 Umea, Sweden
2011 from netlib.org, translated to C, and modified for matmul.m4. */
2013 const GFC_COMPLEX_16 *a, *b;
2014 GFC_COMPLEX_16 *c;
2015 const index_type m = xcount, n = ycount, k = count;
2017 /* System generated locals */
2018 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
2019 i1, i2, i3, i4, i5, i6;
2021 /* Local variables */
2022 GFC_COMPLEX_16 f11, f12, f21, f22, f31, f32, f41, f42,
2023 f13, f14, f23, f24, f33, f34, f43, f44;
2024 index_type i, j, l, ii, jj, ll;
2025 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
2026 GFC_COMPLEX_16 *t1;
2028 a = abase;
2029 b = bbase;
2030 c = retarray->base_addr;
2032 /* Parameter adjustments */
2033 c_dim1 = rystride;
2034 c_offset = 1 + c_dim1;
2035 c -= c_offset;
2036 a_dim1 = aystride;
2037 a_offset = 1 + a_dim1;
2038 a -= a_offset;
2039 b_dim1 = bystride;
2040 b_offset = 1 + b_dim1;
2041 b -= b_offset;
2043 /* Empty c first. */
2044 for (j=1; j<=n; j++)
2045 for (i=1; i<=m; i++)
2046 c[i + j * c_dim1] = (GFC_COMPLEX_16)0;
2048 /* Early exit if possible */
2049 if (m == 0 || n == 0 || k == 0)
2050 return;
2052 /* Adjust size of t1 to what is needed. */
2053 index_type t1_dim, a_sz;
2054 if (aystride == 1)
2055 a_sz = rystride;
2056 else
2057 a_sz = a_dim1;
2059 t1_dim = a_sz * 256 + b_dim1;
2060 if (t1_dim > 65536)
2061 t1_dim = 65536;
2063 t1 = malloc (t1_dim * sizeof(GFC_COMPLEX_16));
2065 /* Start turning the crank. */
2066 i1 = n;
2067 for (jj = 1; jj <= i1; jj += 512)
2069 /* Computing MIN */
2070 i2 = 512;
2071 i3 = n - jj + 1;
2072 jsec = min(i2,i3);
2073 ujsec = jsec - jsec % 4;
2074 i2 = k;
2075 for (ll = 1; ll <= i2; ll += 256)
2077 /* Computing MIN */
2078 i3 = 256;
2079 i4 = k - ll + 1;
2080 lsec = min(i3,i4);
2081 ulsec = lsec - lsec % 2;
2083 i3 = m;
2084 for (ii = 1; ii <= i3; ii += 256)
2086 /* Computing MIN */
2087 i4 = 256;
2088 i5 = m - ii + 1;
2089 isec = min(i4,i5);
2090 uisec = isec - isec % 2;
2091 i4 = ll + ulsec - 1;
2092 for (l = ll; l <= i4; l += 2)
2094 i5 = ii + uisec - 1;
2095 for (i = ii; i <= i5; i += 2)
2097 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
2098 a[i + l * a_dim1];
2099 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
2100 a[i + (l + 1) * a_dim1];
2101 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
2102 a[i + 1 + l * a_dim1];
2103 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
2104 a[i + 1 + (l + 1) * a_dim1];
2106 if (uisec < isec)
2108 t1[l - ll + 1 + (isec << 8) - 257] =
2109 a[ii + isec - 1 + l * a_dim1];
2110 t1[l - ll + 2 + (isec << 8) - 257] =
2111 a[ii + isec - 1 + (l + 1) * a_dim1];
2114 if (ulsec < lsec)
2116 i4 = ii + isec - 1;
2117 for (i = ii; i<= i4; ++i)
2119 t1[lsec + ((i - ii + 1) << 8) - 257] =
2120 a[i + (ll + lsec - 1) * a_dim1];
2124 uisec = isec - isec % 4;
2125 i4 = jj + ujsec - 1;
2126 for (j = jj; j <= i4; j += 4)
2128 i5 = ii + uisec - 1;
2129 for (i = ii; i <= i5; i += 4)
2131 f11 = c[i + j * c_dim1];
2132 f21 = c[i + 1 + j * c_dim1];
2133 f12 = c[i + (j + 1) * c_dim1];
2134 f22 = c[i + 1 + (j + 1) * c_dim1];
2135 f13 = c[i + (j + 2) * c_dim1];
2136 f23 = c[i + 1 + (j + 2) * c_dim1];
2137 f14 = c[i + (j + 3) * c_dim1];
2138 f24 = c[i + 1 + (j + 3) * c_dim1];
2139 f31 = c[i + 2 + j * c_dim1];
2140 f41 = c[i + 3 + j * c_dim1];
2141 f32 = c[i + 2 + (j + 1) * c_dim1];
2142 f42 = c[i + 3 + (j + 1) * c_dim1];
2143 f33 = c[i + 2 + (j + 2) * c_dim1];
2144 f43 = c[i + 3 + (j + 2) * c_dim1];
2145 f34 = c[i + 2 + (j + 3) * c_dim1];
2146 f44 = c[i + 3 + (j + 3) * c_dim1];
2147 i6 = ll + lsec - 1;
2148 for (l = ll; l <= i6; ++l)
2150 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2151 * b[l + j * b_dim1];
2152 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2153 * b[l + j * b_dim1];
2154 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2155 * b[l + (j + 1) * b_dim1];
2156 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2157 * b[l + (j + 1) * b_dim1];
2158 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2159 * b[l + (j + 2) * b_dim1];
2160 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2161 * b[l + (j + 2) * b_dim1];
2162 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2163 * b[l + (j + 3) * b_dim1];
2164 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2165 * b[l + (j + 3) * b_dim1];
2166 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2167 * b[l + j * b_dim1];
2168 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2169 * b[l + j * b_dim1];
2170 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2171 * b[l + (j + 1) * b_dim1];
2172 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2173 * b[l + (j + 1) * b_dim1];
2174 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2175 * b[l + (j + 2) * b_dim1];
2176 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2177 * b[l + (j + 2) * b_dim1];
2178 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2179 * b[l + (j + 3) * b_dim1];
2180 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2181 * b[l + (j + 3) * b_dim1];
2183 c[i + j * c_dim1] = f11;
2184 c[i + 1 + j * c_dim1] = f21;
2185 c[i + (j + 1) * c_dim1] = f12;
2186 c[i + 1 + (j + 1) * c_dim1] = f22;
2187 c[i + (j + 2) * c_dim1] = f13;
2188 c[i + 1 + (j + 2) * c_dim1] = f23;
2189 c[i + (j + 3) * c_dim1] = f14;
2190 c[i + 1 + (j + 3) * c_dim1] = f24;
2191 c[i + 2 + j * c_dim1] = f31;
2192 c[i + 3 + j * c_dim1] = f41;
2193 c[i + 2 + (j + 1) * c_dim1] = f32;
2194 c[i + 3 + (j + 1) * c_dim1] = f42;
2195 c[i + 2 + (j + 2) * c_dim1] = f33;
2196 c[i + 3 + (j + 2) * c_dim1] = f43;
2197 c[i + 2 + (j + 3) * c_dim1] = f34;
2198 c[i + 3 + (j + 3) * c_dim1] = f44;
2200 if (uisec < isec)
2202 i5 = ii + isec - 1;
2203 for (i = ii + uisec; i <= i5; ++i)
2205 f11 = c[i + j * c_dim1];
2206 f12 = c[i + (j + 1) * c_dim1];
2207 f13 = c[i + (j + 2) * c_dim1];
2208 f14 = c[i + (j + 3) * c_dim1];
2209 i6 = ll + lsec - 1;
2210 for (l = ll; l <= i6; ++l)
2212 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2213 257] * b[l + j * b_dim1];
2214 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2215 257] * b[l + (j + 1) * b_dim1];
2216 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2217 257] * b[l + (j + 2) * b_dim1];
2218 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2219 257] * b[l + (j + 3) * b_dim1];
2221 c[i + j * c_dim1] = f11;
2222 c[i + (j + 1) * c_dim1] = f12;
2223 c[i + (j + 2) * c_dim1] = f13;
2224 c[i + (j + 3) * c_dim1] = f14;
2228 if (ujsec < jsec)
2230 i4 = jj + jsec - 1;
2231 for (j = jj + ujsec; j <= i4; ++j)
2233 i5 = ii + uisec - 1;
2234 for (i = ii; i <= i5; i += 4)
2236 f11 = c[i + j * c_dim1];
2237 f21 = c[i + 1 + j * c_dim1];
2238 f31 = c[i + 2 + j * c_dim1];
2239 f41 = c[i + 3 + j * c_dim1];
2240 i6 = ll + lsec - 1;
2241 for (l = ll; l <= i6; ++l)
2243 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2244 257] * b[l + j * b_dim1];
2245 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
2246 257] * b[l + j * b_dim1];
2247 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
2248 257] * b[l + j * b_dim1];
2249 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
2250 257] * b[l + j * b_dim1];
2252 c[i + j * c_dim1] = f11;
2253 c[i + 1 + j * c_dim1] = f21;
2254 c[i + 2 + j * c_dim1] = f31;
2255 c[i + 3 + j * c_dim1] = f41;
2257 i5 = ii + isec - 1;
2258 for (i = ii + uisec; i <= i5; ++i)
2260 f11 = c[i + j * c_dim1];
2261 i6 = ll + lsec - 1;
2262 for (l = ll; l <= i6; ++l)
2264 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2265 257] * b[l + j * b_dim1];
2267 c[i + j * c_dim1] = f11;
2274 free(t1);
2275 return;
2277 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
2279 if (GFC_DESCRIPTOR_RANK (a) != 1)
2281 const GFC_COMPLEX_16 *restrict abase_x;
2282 const GFC_COMPLEX_16 *restrict bbase_y;
2283 GFC_COMPLEX_16 *restrict dest_y;
2284 GFC_COMPLEX_16 s;
2286 for (y = 0; y < ycount; y++)
2288 bbase_y = &bbase[y*bystride];
2289 dest_y = &dest[y*rystride];
2290 for (x = 0; x < xcount; x++)
2292 abase_x = &abase[x*axstride];
2293 s = (GFC_COMPLEX_16) 0;
2294 for (n = 0; n < count; n++)
2295 s += abase_x[n] * bbase_y[n];
2296 dest_y[x] = s;
2300 else
2302 const GFC_COMPLEX_16 *restrict bbase_y;
2303 GFC_COMPLEX_16 s;
2305 for (y = 0; y < ycount; y++)
2307 bbase_y = &bbase[y*bystride];
2308 s = (GFC_COMPLEX_16) 0;
2309 for (n = 0; n < count; n++)
2310 s += abase[n*axstride] * bbase_y[n];
2311 dest[y*rystride] = s;
2315 else if (GFC_DESCRIPTOR_RANK (a) == 1)
2317 const GFC_COMPLEX_16 *restrict bbase_y;
2318 GFC_COMPLEX_16 s;
2320 for (y = 0; y < ycount; y++)
2322 bbase_y = &bbase[y*bystride];
2323 s = (GFC_COMPLEX_16) 0;
2324 for (n = 0; n < count; n++)
2325 s += abase[n*axstride] * bbase_y[n*bxstride];
2326 dest[y*rxstride] = s;
2329 else if (axstride < aystride)
2331 for (y = 0; y < ycount; y++)
2332 for (x = 0; x < xcount; x++)
2333 dest[x*rxstride + y*rystride] = (GFC_COMPLEX_16)0;
2335 for (y = 0; y < ycount; y++)
2336 for (n = 0; n < count; n++)
2337 for (x = 0; x < xcount; x++)
2338 /* dest[x,y] += a[x,n] * b[n,y] */
2339 dest[x*rxstride + y*rystride] +=
2340 abase[x*axstride + n*aystride] *
2341 bbase[n*bxstride + y*bystride];
2343 else
2345 const GFC_COMPLEX_16 *restrict abase_x;
2346 const GFC_COMPLEX_16 *restrict bbase_y;
2347 GFC_COMPLEX_16 *restrict dest_y;
2348 GFC_COMPLEX_16 s;
2350 for (y = 0; y < ycount; y++)
2352 bbase_y = &bbase[y*bystride];
2353 dest_y = &dest[y*rystride];
2354 for (x = 0; x < xcount; x++)
2356 abase_x = &abase[x*axstride];
2357 s = (GFC_COMPLEX_16) 0;
2358 for (n = 0; n < count; n++)
2359 s += abase_x[n*aystride] * bbase_y[n*bxstride];
2360 dest_y[x*rxstride] = s;
2365 #undef POW3
2366 #undef min
2367 #undef max
2370 /* Compiling main function, with selection code for the processor. */
2372 /* Currently, this is i386 only. Adjust for other architectures. */
2374 void matmul_c16 (gfc_array_c16 * const restrict retarray,
2375 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
2376 int blas_limit, blas_call gemm)
2378 static void (*matmul_p) (gfc_array_c16 * const restrict retarray,
2379 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
2380 int blas_limit, blas_call gemm);
2382 void (*matmul_fn) (gfc_array_c16 * const restrict retarray,
2383 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
2384 int blas_limit, blas_call gemm);
2386 matmul_fn = __atomic_load_n (&matmul_p, __ATOMIC_RELAXED);
2387 if (matmul_fn == NULL)
2389 matmul_fn = matmul_c16_vanilla;
2390 if (__builtin_cpu_is ("intel"))
2392 /* Run down the available processors in order of preference. */
2393 #ifdef HAVE_AVX512F
2394 if (__builtin_cpu_supports ("avx512f"))
2396 matmul_fn = matmul_c16_avx512f;
2397 goto store;
2400 #endif /* HAVE_AVX512F */
2402 #ifdef HAVE_AVX2
2403 if (__builtin_cpu_supports ("avx2")
2404 && __builtin_cpu_supports ("fma"))
2406 matmul_fn = matmul_c16_avx2;
2407 goto store;
2410 #endif
2412 #ifdef HAVE_AVX
2413 if (__builtin_cpu_supports ("avx"))
2415 matmul_fn = matmul_c16_avx;
2416 goto store;
2418 #endif /* HAVE_AVX */
2420 else if (__builtin_cpu_is ("amd"))
2422 #if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
2423 if (__builtin_cpu_supports ("avx")
2424 && __builtin_cpu_supports ("fma"))
2426 matmul_fn = matmul_c16_avx128_fma3;
2427 goto store;
2429 #endif
2430 #if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
2431 if (__builtin_cpu_supports ("avx")
2432 && __builtin_cpu_supports ("fma4"))
2434 matmul_fn = matmul_c16_avx128_fma4;
2435 goto store;
2437 #endif
2440 store:
2441 __atomic_store_n (&matmul_p, matmul_fn, __ATOMIC_RELAXED);
2444 (*matmul_fn) (retarray, a, b, try_blas, blas_limit, gemm);
2447 #else /* Just the vanilla function. */
2449 void
2450 matmul_c16 (gfc_array_c16 * const restrict retarray,
2451 gfc_array_c16 * const restrict a, gfc_array_c16 * const restrict b, int try_blas,
2452 int blas_limit, blas_call gemm)
2454 const GFC_COMPLEX_16 * restrict abase;
2455 const GFC_COMPLEX_16 * restrict bbase;
2456 GFC_COMPLEX_16 * restrict dest;
2458 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
2459 index_type x, y, n, count, xcount, ycount;
2461 assert (GFC_DESCRIPTOR_RANK (a) == 2
2462 || GFC_DESCRIPTOR_RANK (b) == 2);
2464 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
2466 Either A or B (but not both) can be rank 1:
2468 o One-dimensional argument A is implicitly treated as a row matrix
2469 dimensioned [1,count], so xcount=1.
2471 o One-dimensional argument B is implicitly treated as a column matrix
2472 dimensioned [count, 1], so ycount=1.
2475 if (retarray->base_addr == NULL)
2477 if (GFC_DESCRIPTOR_RANK (a) == 1)
2479 GFC_DIMENSION_SET(retarray->dim[0], 0,
2480 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
2482 else if (GFC_DESCRIPTOR_RANK (b) == 1)
2484 GFC_DIMENSION_SET(retarray->dim[0], 0,
2485 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
2487 else
2489 GFC_DIMENSION_SET(retarray->dim[0], 0,
2490 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
2492 GFC_DIMENSION_SET(retarray->dim[1], 0,
2493 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
2494 GFC_DESCRIPTOR_EXTENT(retarray,0));
2497 retarray->base_addr
2498 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_16));
2499 retarray->offset = 0;
2501 else if (unlikely (compile_options.bounds_check))
2503 index_type ret_extent, arg_extent;
2505 if (GFC_DESCRIPTOR_RANK (a) == 1)
2507 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
2508 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
2509 if (arg_extent != ret_extent)
2510 runtime_error ("Array bound mismatch for dimension 1 of "
2511 "array (%ld/%ld) ",
2512 (long int) ret_extent, (long int) arg_extent);
2514 else if (GFC_DESCRIPTOR_RANK (b) == 1)
2516 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
2517 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
2518 if (arg_extent != ret_extent)
2519 runtime_error ("Array bound mismatch for dimension 1 of "
2520 "array (%ld/%ld) ",
2521 (long int) ret_extent, (long int) arg_extent);
2523 else
2525 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
2526 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
2527 if (arg_extent != ret_extent)
2528 runtime_error ("Array bound mismatch for dimension 1 of "
2529 "array (%ld/%ld) ",
2530 (long int) ret_extent, (long int) arg_extent);
2532 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
2533 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
2534 if (arg_extent != ret_extent)
2535 runtime_error ("Array bound mismatch for dimension 2 of "
2536 "array (%ld/%ld) ",
2537 (long int) ret_extent, (long int) arg_extent);
2542 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
2544 /* One-dimensional result may be addressed in the code below
2545 either as a row or a column matrix. We want both cases to
2546 work. */
2547 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
2549 else
2551 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
2552 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
2556 if (GFC_DESCRIPTOR_RANK (a) == 1)
2558 /* Treat it as a a row matrix A[1,count]. */
2559 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
2560 aystride = 1;
2562 xcount = 1;
2563 count = GFC_DESCRIPTOR_EXTENT(a,0);
2565 else
2567 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
2568 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
2570 count = GFC_DESCRIPTOR_EXTENT(a,1);
2571 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
2574 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
2576 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
2577 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
2578 "in dimension 1: is %ld, should be %ld",
2579 (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
2582 if (GFC_DESCRIPTOR_RANK (b) == 1)
2584 /* Treat it as a column matrix B[count,1] */
2585 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
2587 /* bystride should never be used for 1-dimensional b.
2588 The value is only used for calculation of the
2589 memory by the buffer. */
2590 bystride = 256;
2591 ycount = 1;
2593 else
2595 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
2596 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
2597 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
2600 abase = a->base_addr;
2601 bbase = b->base_addr;
2602 dest = retarray->base_addr;
2604 /* Now that everything is set up, we perform the multiplication
2605 itself. */
2607 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
2608 #define min(a,b) ((a) <= (b) ? (a) : (b))
2609 #define max(a,b) ((a) >= (b) ? (a) : (b))
2611 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
2612 && (bxstride == 1 || bystride == 1)
2613 && (((float) xcount) * ((float) ycount) * ((float) count)
2614 > POW3(blas_limit)))
2616 const int m = xcount, n = ycount, k = count, ldc = rystride;
2617 const GFC_COMPLEX_16 one = 1, zero = 0;
2618 const int lda = (axstride == 1) ? aystride : axstride,
2619 ldb = (bxstride == 1) ? bystride : bxstride;
2621 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
2623 assert (gemm != NULL);
2624 const char *transa, *transb;
2625 if (try_blas & 2)
2626 transa = "C";
2627 else
2628 transa = axstride == 1 ? "N" : "T";
2630 if (try_blas & 4)
2631 transb = "C";
2632 else
2633 transb = bxstride == 1 ? "N" : "T";
2635 gemm (transa, transb , &m,
2636 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
2637 &ldc, 1, 1);
2638 return;
2642 if (rxstride == 1 && axstride == 1 && bxstride == 1
2643 && GFC_DESCRIPTOR_RANK (b) != 1)
2645 /* This block of code implements a tuned matmul, derived from
2646 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
2648 Bo Kagstrom and Per Ling
2649 Department of Computing Science
2650 Umea University
2651 S-901 87 Umea, Sweden
2653 from netlib.org, translated to C, and modified for matmul.m4. */
2655 const GFC_COMPLEX_16 *a, *b;
2656 GFC_COMPLEX_16 *c;
2657 const index_type m = xcount, n = ycount, k = count;
2659 /* System generated locals */
2660 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
2661 i1, i2, i3, i4, i5, i6;
2663 /* Local variables */
2664 GFC_COMPLEX_16 f11, f12, f21, f22, f31, f32, f41, f42,
2665 f13, f14, f23, f24, f33, f34, f43, f44;
2666 index_type i, j, l, ii, jj, ll;
2667 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
2668 GFC_COMPLEX_16 *t1;
2670 a = abase;
2671 b = bbase;
2672 c = retarray->base_addr;
2674 /* Parameter adjustments */
2675 c_dim1 = rystride;
2676 c_offset = 1 + c_dim1;
2677 c -= c_offset;
2678 a_dim1 = aystride;
2679 a_offset = 1 + a_dim1;
2680 a -= a_offset;
2681 b_dim1 = bystride;
2682 b_offset = 1 + b_dim1;
2683 b -= b_offset;
2685 /* Empty c first. */
2686 for (j=1; j<=n; j++)
2687 for (i=1; i<=m; i++)
2688 c[i + j * c_dim1] = (GFC_COMPLEX_16)0;
2690 /* Early exit if possible */
2691 if (m == 0 || n == 0 || k == 0)
2692 return;
2694 /* Adjust size of t1 to what is needed. */
2695 index_type t1_dim, a_sz;
2696 if (aystride == 1)
2697 a_sz = rystride;
2698 else
2699 a_sz = a_dim1;
2701 t1_dim = a_sz * 256 + b_dim1;
2702 if (t1_dim > 65536)
2703 t1_dim = 65536;
2705 t1 = malloc (t1_dim * sizeof(GFC_COMPLEX_16));
2707 /* Start turning the crank. */
2708 i1 = n;
2709 for (jj = 1; jj <= i1; jj += 512)
2711 /* Computing MIN */
2712 i2 = 512;
2713 i3 = n - jj + 1;
2714 jsec = min(i2,i3);
2715 ujsec = jsec - jsec % 4;
2716 i2 = k;
2717 for (ll = 1; ll <= i2; ll += 256)
2719 /* Computing MIN */
2720 i3 = 256;
2721 i4 = k - ll + 1;
2722 lsec = min(i3,i4);
2723 ulsec = lsec - lsec % 2;
2725 i3 = m;
2726 for (ii = 1; ii <= i3; ii += 256)
2728 /* Computing MIN */
2729 i4 = 256;
2730 i5 = m - ii + 1;
2731 isec = min(i4,i5);
2732 uisec = isec - isec % 2;
2733 i4 = ll + ulsec - 1;
2734 for (l = ll; l <= i4; l += 2)
2736 i5 = ii + uisec - 1;
2737 for (i = ii; i <= i5; i += 2)
2739 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
2740 a[i + l * a_dim1];
2741 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
2742 a[i + (l + 1) * a_dim1];
2743 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
2744 a[i + 1 + l * a_dim1];
2745 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
2746 a[i + 1 + (l + 1) * a_dim1];
2748 if (uisec < isec)
2750 t1[l - ll + 1 + (isec << 8) - 257] =
2751 a[ii + isec - 1 + l * a_dim1];
2752 t1[l - ll + 2 + (isec << 8) - 257] =
2753 a[ii + isec - 1 + (l + 1) * a_dim1];
2756 if (ulsec < lsec)
2758 i4 = ii + isec - 1;
2759 for (i = ii; i<= i4; ++i)
2761 t1[lsec + ((i - ii + 1) << 8) - 257] =
2762 a[i + (ll + lsec - 1) * a_dim1];
2766 uisec = isec - isec % 4;
2767 i4 = jj + ujsec - 1;
2768 for (j = jj; j <= i4; j += 4)
2770 i5 = ii + uisec - 1;
2771 for (i = ii; i <= i5; i += 4)
2773 f11 = c[i + j * c_dim1];
2774 f21 = c[i + 1 + j * c_dim1];
2775 f12 = c[i + (j + 1) * c_dim1];
2776 f22 = c[i + 1 + (j + 1) * c_dim1];
2777 f13 = c[i + (j + 2) * c_dim1];
2778 f23 = c[i + 1 + (j + 2) * c_dim1];
2779 f14 = c[i + (j + 3) * c_dim1];
2780 f24 = c[i + 1 + (j + 3) * c_dim1];
2781 f31 = c[i + 2 + j * c_dim1];
2782 f41 = c[i + 3 + j * c_dim1];
2783 f32 = c[i + 2 + (j + 1) * c_dim1];
2784 f42 = c[i + 3 + (j + 1) * c_dim1];
2785 f33 = c[i + 2 + (j + 2) * c_dim1];
2786 f43 = c[i + 3 + (j + 2) * c_dim1];
2787 f34 = c[i + 2 + (j + 3) * c_dim1];
2788 f44 = c[i + 3 + (j + 3) * c_dim1];
2789 i6 = ll + lsec - 1;
2790 for (l = ll; l <= i6; ++l)
2792 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2793 * b[l + j * b_dim1];
2794 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2795 * b[l + j * b_dim1];
2796 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2797 * b[l + (j + 1) * b_dim1];
2798 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2799 * b[l + (j + 1) * b_dim1];
2800 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2801 * b[l + (j + 2) * b_dim1];
2802 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2803 * b[l + (j + 2) * b_dim1];
2804 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2805 * b[l + (j + 3) * b_dim1];
2806 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2807 * b[l + (j + 3) * b_dim1];
2808 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2809 * b[l + j * b_dim1];
2810 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2811 * b[l + j * b_dim1];
2812 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2813 * b[l + (j + 1) * b_dim1];
2814 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2815 * b[l + (j + 1) * b_dim1];
2816 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2817 * b[l + (j + 2) * b_dim1];
2818 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2819 * b[l + (j + 2) * b_dim1];
2820 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2821 * b[l + (j + 3) * b_dim1];
2822 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2823 * b[l + (j + 3) * b_dim1];
2825 c[i + j * c_dim1] = f11;
2826 c[i + 1 + j * c_dim1] = f21;
2827 c[i + (j + 1) * c_dim1] = f12;
2828 c[i + 1 + (j + 1) * c_dim1] = f22;
2829 c[i + (j + 2) * c_dim1] = f13;
2830 c[i + 1 + (j + 2) * c_dim1] = f23;
2831 c[i + (j + 3) * c_dim1] = f14;
2832 c[i + 1 + (j + 3) * c_dim1] = f24;
2833 c[i + 2 + j * c_dim1] = f31;
2834 c[i + 3 + j * c_dim1] = f41;
2835 c[i + 2 + (j + 1) * c_dim1] = f32;
2836 c[i + 3 + (j + 1) * c_dim1] = f42;
2837 c[i + 2 + (j + 2) * c_dim1] = f33;
2838 c[i + 3 + (j + 2) * c_dim1] = f43;
2839 c[i + 2 + (j + 3) * c_dim1] = f34;
2840 c[i + 3 + (j + 3) * c_dim1] = f44;
2842 if (uisec < isec)
2844 i5 = ii + isec - 1;
2845 for (i = ii + uisec; i <= i5; ++i)
2847 f11 = c[i + j * c_dim1];
2848 f12 = c[i + (j + 1) * c_dim1];
2849 f13 = c[i + (j + 2) * c_dim1];
2850 f14 = c[i + (j + 3) * c_dim1];
2851 i6 = ll + lsec - 1;
2852 for (l = ll; l <= i6; ++l)
2854 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2855 257] * b[l + j * b_dim1];
2856 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2857 257] * b[l + (j + 1) * b_dim1];
2858 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2859 257] * b[l + (j + 2) * b_dim1];
2860 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2861 257] * b[l + (j + 3) * b_dim1];
2863 c[i + j * c_dim1] = f11;
2864 c[i + (j + 1) * c_dim1] = f12;
2865 c[i + (j + 2) * c_dim1] = f13;
2866 c[i + (j + 3) * c_dim1] = f14;
2870 if (ujsec < jsec)
2872 i4 = jj + jsec - 1;
2873 for (j = jj + ujsec; j <= i4; ++j)
2875 i5 = ii + uisec - 1;
2876 for (i = ii; i <= i5; i += 4)
2878 f11 = c[i + j * c_dim1];
2879 f21 = c[i + 1 + j * c_dim1];
2880 f31 = c[i + 2 + j * c_dim1];
2881 f41 = c[i + 3 + j * c_dim1];
2882 i6 = ll + lsec - 1;
2883 for (l = ll; l <= i6; ++l)
2885 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2886 257] * b[l + j * b_dim1];
2887 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
2888 257] * b[l + j * b_dim1];
2889 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
2890 257] * b[l + j * b_dim1];
2891 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
2892 257] * b[l + j * b_dim1];
2894 c[i + j * c_dim1] = f11;
2895 c[i + 1 + j * c_dim1] = f21;
2896 c[i + 2 + j * c_dim1] = f31;
2897 c[i + 3 + j * c_dim1] = f41;
2899 i5 = ii + isec - 1;
2900 for (i = ii + uisec; i <= i5; ++i)
2902 f11 = c[i + j * c_dim1];
2903 i6 = ll + lsec - 1;
2904 for (l = ll; l <= i6; ++l)
2906 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2907 257] * b[l + j * b_dim1];
2909 c[i + j * c_dim1] = f11;
2916 free(t1);
2917 return;
2919 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
2921 if (GFC_DESCRIPTOR_RANK (a) != 1)
2923 const GFC_COMPLEX_16 *restrict abase_x;
2924 const GFC_COMPLEX_16 *restrict bbase_y;
2925 GFC_COMPLEX_16 *restrict dest_y;
2926 GFC_COMPLEX_16 s;
2928 for (y = 0; y < ycount; y++)
2930 bbase_y = &bbase[y*bystride];
2931 dest_y = &dest[y*rystride];
2932 for (x = 0; x < xcount; x++)
2934 abase_x = &abase[x*axstride];
2935 s = (GFC_COMPLEX_16) 0;
2936 for (n = 0; n < count; n++)
2937 s += abase_x[n] * bbase_y[n];
2938 dest_y[x] = s;
2942 else
2944 const GFC_COMPLEX_16 *restrict bbase_y;
2945 GFC_COMPLEX_16 s;
2947 for (y = 0; y < ycount; y++)
2949 bbase_y = &bbase[y*bystride];
2950 s = (GFC_COMPLEX_16) 0;
2951 for (n = 0; n < count; n++)
2952 s += abase[n*axstride] * bbase_y[n];
2953 dest[y*rystride] = s;
2957 else if (GFC_DESCRIPTOR_RANK (a) == 1)
2959 const GFC_COMPLEX_16 *restrict bbase_y;
2960 GFC_COMPLEX_16 s;
2962 for (y = 0; y < ycount; y++)
2964 bbase_y = &bbase[y*bystride];
2965 s = (GFC_COMPLEX_16) 0;
2966 for (n = 0; n < count; n++)
2967 s += abase[n*axstride] * bbase_y[n*bxstride];
2968 dest[y*rxstride] = s;
2971 else if (axstride < aystride)
2973 for (y = 0; y < ycount; y++)
2974 for (x = 0; x < xcount; x++)
2975 dest[x*rxstride + y*rystride] = (GFC_COMPLEX_16)0;
2977 for (y = 0; y < ycount; y++)
2978 for (n = 0; n < count; n++)
2979 for (x = 0; x < xcount; x++)
2980 /* dest[x,y] += a[x,n] * b[n,y] */
2981 dest[x*rxstride + y*rystride] +=
2982 abase[x*axstride + n*aystride] *
2983 bbase[n*bxstride + y*bystride];
2985 else
2987 const GFC_COMPLEX_16 *restrict abase_x;
2988 const GFC_COMPLEX_16 *restrict bbase_y;
2989 GFC_COMPLEX_16 *restrict dest_y;
2990 GFC_COMPLEX_16 s;
2992 for (y = 0; y < ycount; y++)
2994 bbase_y = &bbase[y*bystride];
2995 dest_y = &dest[y*rystride];
2996 for (x = 0; x < xcount; x++)
2998 abase_x = &abase[x*axstride];
2999 s = (GFC_COMPLEX_16) 0;
3000 for (n = 0; n < count; n++)
3001 s += abase_x[n*aystride] * bbase_y[n*bxstride];
3002 dest_y[x*rxstride] = s;
3007 #undef POW3
3008 #undef min
3009 #undef max
3011 #endif
3012 #endif