1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2018 Free Software Foundation, Inc.
3 Contributed by Thomas Koenig <tkoenig@gcc.gnu.org>.
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
#include "libgfortran.h"
#include <string.h>
#include <assert.h>
31 /* These are the specific versions of matmul with -mprefer-avx128. */
33 #if defined (HAVE_GFC_COMPLEX_10)
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we call it for large
39 typedef void (*blas_call
)(const char *, const char *, const int *, const int *,
40 const int *, const GFC_COMPLEX_10
*, const GFC_COMPLEX_10
*,
41 const int *, const GFC_COMPLEX_10
*, const int *,
42 const GFC_COMPLEX_10
*, GFC_COMPLEX_10
*, const int *,
#if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
/* MATMUL for GFC_COMPLEX_10 arrays, compiled for the "avx,fma" target.

   Computes C[xcount,ycount] = A[xcount,count] * B[count,ycount] and
   stores the product through RETARRAY.

   retarray    Result descriptor.  When its base_addr is NULL the
               dimensions are filled in from A and B and the storage is
               obtained with xmallocarray; otherwise, if bounds checking
               is enabled, its extents are checked against A and B.
   a, b        Operand descriptors; at least one must be rank 2 (asserted).
   try_blas    Non-zero permits handing the product to GEMM.  Bit value 2
               selects "C" as the first GEMM operand mode and bit value 4
               selects "C" as the second one.
   blas_limit  GEMM is only considered when xcount*ycount*count exceeds
               blas_limit cubed.
   gemm        BLAS ?gemm entry point; must be non-NULL whenever the BLAS
               path is actually taken (asserted).  */
void
matmul_c10_avx128_fma3 (gfc_array_c10 * const restrict retarray,
	gfc_array_c10 * const restrict a, gfc_array_c10 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma")));
internal_proto(matmul_c10_avx128_fma3);

void
matmul_c10_avx128_fma3 (gfc_array_c10 * const restrict retarray,
	gfc_array_c10 * const restrict a, gfc_array_c10 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm)
{
  const GFC_COMPLEX_10 * restrict abase;
  const GFC_COMPLEX_10 * restrict bbase;
  GFC_COMPLEX_10 * restrict dest;

  index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  index_type x, y, n, count, xcount, ycount;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
          || GFC_DESCRIPTOR_RANK (b) == 2);

/* C[xcount,ycount] = A[xcount, count] * B[count,ycount]

   Either A or B (but not both) can be rank 1:

   o One-dimensional argument A is implicitly treated as a row matrix
     dimensioned [1,count], so xcount=1.

   o One-dimensional argument B is implicitly treated as a column matrix
     dimensioned [count, 1], so ycount=1.
*/

  if (retarray->base_addr == NULL)
    {
      /* Unallocated result: derive its shape from the operands and
         allocate it.  */
      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
        }
      else
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

          GFC_DIMENSION_SET(retarray->dim[1], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1,
                            GFC_DESCRIPTOR_EXTENT(retarray,0));
        }

      retarray->base_addr
        = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_10));
      retarray->offset = 0;
    }
  else if (unlikely (compile_options.bounds_check))
    {
      /* Caller-supplied result: verify that its extents agree with the
         shape implied by the operands.  */
      index_type ret_extent, arg_extent;

      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);

          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 2 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
    }


  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      /* One-dimensional result may be addressed in the code below
         either as a row or a column matrix. We want both cases to
         work. */
      rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }


  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Treat it as a row matrix A[1,count]. */
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = 1;

      xcount = 1;
      count = GFC_DESCRIPTOR_EXTENT(a,0);
    }
  else
    {
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = GFC_DESCRIPTOR_STRIDE(a,1);

      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  if (count != GFC_DESCRIPTOR_EXTENT(b,0))
    {
      /* Two zero-or-negative extents are tolerated as "both empty".  */
      if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
        runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
                       "in dimension 1: is %ld, should be %ld",
                       (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
    }

  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      /* Treat it as a column matrix B[count,1] */
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);

      /* bystride should never be used for 1-dimensional b.
         The value is only used for calculation of the
         memory by the buffer.  */
      bystride = 256;
      ycount = 1;
    }
  else
    {
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
      bystride = GFC_DESCRIPTOR_STRIDE(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  abase = a->base_addr;
  bbase = b->base_addr;
  dest = retarray->base_addr;

  /* Now that everything is set up, we perform the multiplication
     itself.  */

#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
#define min(a,b) ((a) <= (b) ? (a) : (b))
#define max(a,b) ((a) >= (b) ? (a) : (b))

  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
      && (bxstride == 1 || bystride == 1)
      && (((float) xcount) * ((float) ycount) * ((float) count)
          > POW3(blas_limit)))
    {
      /* Large product with unit-stride operands: try the external GEMM.
         The sizes are narrowed to int because that is what the blas_call
         interface takes.  */
      const int m = xcount, n = ycount, k = count, ldc = rystride;
      const GFC_COMPLEX_10 one = 1, zero = 0;
      const int lda = (axstride == 1) ? aystride : axstride,
                ldb = (bxstride == 1) ? bystride : bxstride;

      if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
        {
          assert (gemm != NULL);
          const char *transa, *transb;
          if (try_blas & 2)
            transa = "C";
          else
            transa = axstride == 1 ? "N" : "T";

          if (try_blas & 4)
            transb = "C";
          else
            transb = bxstride == 1 ? "N" : "T";

          gemm (transa, transb , &m,
                &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
                &ldc, 1, 1);
          return;
        }
    }

  if (rxstride == 1 && axstride == 1 && bxstride == 1)
    {
      /* This block of code implements a tuned matmul, derived from
         Superscalar GEMM-based level 3 BLAS,  Beta version 0.1

               Bo Kagstrom and Per Ling
               Department of Computing Science
               Umea University
               S-901 87 Umea, Sweden

         from netlib.org, translated to C, and modified for matmul.m4.  */

      const GFC_COMPLEX_10 *a, *b;
      GFC_COMPLEX_10 *c;
      const index_type m = xcount, n = ycount, k = count;

      /* System generated locals */
      index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
                 i1, i2, i3, i4, i5, i6;

      /* Local variables */
      GFC_COMPLEX_10 f11, f12, f21, f22, f31, f32, f41, f42,
                 f13, f14, f23, f24, f33, f34, f43, f44;
      index_type i, j, l, ii, jj, ll;
      index_type isec, jsec, lsec, uisec, ujsec, ulsec;
      GFC_COMPLEX_10 *t1;

      a = abase;
      b = bbase;
      c = retarray->base_addr;

      /* Parameter adjustments: bias the base pointers so that the
         1-based indices used below address the right elements.  */
      c_dim1 = rystride;
      c_offset = 1 + c_dim1;
      c -= c_offset;
      a_dim1 = aystride;
      a_offset = 1 + a_dim1;
      a -= a_offset;
      b_dim1 = bystride;
      b_offset = 1 + b_dim1;
      b -= b_offset;

      /* Empty c first.  */
      for (j=1; j<=n; j++)
        for (i=1; i<=m; i++)
          c[i + j * c_dim1] = (GFC_COMPLEX_10)0;

      /* Early exit if possible */
      if (m == 0 || n == 0 || k == 0)
        return;

      /* Adjust size of t1 to what is needed.  */
      index_type t1_dim, a_sz;
      if (aystride == 1)
        a_sz = rystride;
      else
        a_sz = a_dim1;

      t1_dim = a_sz * 256 + b_dim1;
      if (t1_dim > 65536)
        t1_dim = 65536;

      /* Use the checked allocator (as for the result array above) so an
         allocation failure is reported instead of dereferencing NULL.  */
      t1 = xmallocarray (t1_dim, sizeof (GFC_COMPLEX_10));

      /* Start turning the crank.  C is processed in panels of up to 512
         columns (jj), the inner dimension in slices of up to 256 (ll) and
         the rows of A in slices of up to 256 (ii).  */
      i1 = n;
      for (jj = 1; jj <= i1; jj += 512)
        {
          /* Computing MIN */
          i2 = 512;
          i3 = n - jj + 1;
          jsec = min(i2,i3);
          ujsec = jsec - jsec % 4;
          i2 = k;
          for (ll = 1; ll <= i2; ll += 256)
            {
              /* Computing MIN */
              i3 = 256;
              i4 = k - ll + 1;
              lsec = min(i3,i4);
              ulsec = lsec - lsec % 2;

              i3 = m;
              for (ii = 1; ii <= i3; ii += 256)
                {
                  /* Computing MIN */
                  i4 = 256;
                  i5 = m - ii + 1;
                  isec = min(i4,i5);
                  uisec = isec - isec % 2;

                  /* Copy the current block of A into t1 so that element
                     (i, l) lands at t1[(i - ii) * 256 + (l - ll)].  The
                     bulk is copied 2x2 at a time; odd leftovers follow.  */
                  i4 = ll + ulsec - 1;
                  for (l = ll; l <= i4; l += 2)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 2)
                        {
                          t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
                                        a[i + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
                                        a[i + (l + 1) * a_dim1];
                          t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + (l + 1) * a_dim1];
                        }
                      if (uisec < isec)
                        {
                          t1[l - ll + 1 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + l * a_dim1];
                          t1[l - ll + 2 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + (l + 1) * a_dim1];
                        }
                    }
                  if (ulsec < lsec)
                    {
                      i4 = ii + isec - 1;
                      for (i = ii; i<= i4; ++i)
                        {
                          t1[lsec + ((i - ii + 1) << 8) - 257] =
                                    a[i + (ll + lsec - 1) * a_dim1];
                        }
                    }

                  /* Multiply the packed block by B: 4x4 tiles of C first,
                     then the leftover rows and columns.  */
                  uisec = isec - isec % 4;
                  i4 = jj + ujsec - 1;
                  for (j = jj; j <= i4; j += 4)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 4)
                        {
                          f11 = c[i + j * c_dim1];
                          f21 = c[i + 1 + j * c_dim1];
                          f12 = c[i + (j + 1) * c_dim1];
                          f22 = c[i + 1 + (j + 1) * c_dim1];
                          f13 = c[i + (j + 2) * c_dim1];
                          f23 = c[i + 1 + (j + 2) * c_dim1];
                          f14 = c[i + (j + 3) * c_dim1];
                          f24 = c[i + 1 + (j + 3) * c_dim1];
                          f31 = c[i + 2 + j * c_dim1];
                          f41 = c[i + 3 + j * c_dim1];
                          f32 = c[i + 2 + (j + 1) * c_dim1];
                          f42 = c[i + 3 + (j + 1) * c_dim1];
                          f33 = c[i + 2 + (j + 2) * c_dim1];
                          f43 = c[i + 3 + (j + 2) * c_dim1];
                          f34 = c[i + 2 + (j + 3) * c_dim1];
                          f44 = c[i + 3 + (j + 3) * c_dim1];
                          i6 = ll + lsec - 1;
                          for (l = ll; l <= i6; ++l)
                            {
                              f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                            }
                          c[i + j * c_dim1] = f11;
                          c[i + 1 + j * c_dim1] = f21;
                          c[i + (j + 1) * c_dim1] = f12;
                          c[i + 1 + (j + 1) * c_dim1] = f22;
                          c[i + (j + 2) * c_dim1] = f13;
                          c[i + 1 + (j + 2) * c_dim1] = f23;
                          c[i + (j + 3) * c_dim1] = f14;
                          c[i + 1 + (j + 3) * c_dim1] = f24;
                          c[i + 2 + j * c_dim1] = f31;
                          c[i + 3 + j * c_dim1] = f41;
                          c[i + 2 + (j + 1) * c_dim1] = f32;
                          c[i + 3 + (j + 1) * c_dim1] = f42;
                          c[i + 2 + (j + 2) * c_dim1] = f33;
                          c[i + 3 + (j + 2) * c_dim1] = f43;
                          c[i + 2 + (j + 3) * c_dim1] = f34;
                          c[i + 3 + (j + 3) * c_dim1] = f44;
                        }
                      if (uisec < isec)
                        {
                          /* Rows left over after the 4-row tiles.  */
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              f12 = c[i + (j + 1) * c_dim1];
                              f13 = c[i + (j + 2) * c_dim1];
                              f14 = c[i + (j + 3) * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + (j + 1) * b_dim1];
                                  f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + (j + 2) * b_dim1];
                                  f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + (j + 3) * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + (j + 1) * c_dim1] = f12;
                              c[i + (j + 2) * c_dim1] = f13;
                              c[i + (j + 3) * c_dim1] = f14;
                            }
                        }
                    }
                  if (ujsec < jsec)
                    {
                      /* Columns left over after the 4-column tiles.  */
                      i4 = jj + jsec - 1;
                      for (j = jj + ujsec; j <= i4; ++j)
                        {
                          i5 = ii + uisec - 1;
                          for (i = ii; i <= i5; i += 4)
                            {
                              f11 = c[i + j * c_dim1];
                              f21 = c[i + 1 + j * c_dim1];
                              f31 = c[i + 2 + j * c_dim1];
                              f41 = c[i + 3 + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
                                            257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + 1 + j * c_dim1] = f21;
                              c[i + 2 + j * c_dim1] = f31;
                              c[i + 3 + j * c_dim1] = f41;
                            }
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                            }
                        }
                    }
                }
            }
        }
      free(t1);
      return;
    }
  else if (rxstride == 1 && aystride == 1 && bxstride == 1)
    {
      /* A is walked with unit stride along the summed dimension, so each
         result element is a plain dot product.  */
      if (GFC_DESCRIPTOR_RANK (a) != 1)
        {
          const GFC_COMPLEX_10 *restrict abase_x;
          const GFC_COMPLEX_10 *restrict bbase_y;
          GFC_COMPLEX_10 *restrict dest_y;
          GFC_COMPLEX_10 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              dest_y = &dest[y*rystride];
              for (x = 0; x < xcount; x++)
                {
                  abase_x = &abase[x*axstride];
                  s = (GFC_COMPLEX_10) 0;
                  for (n = 0; n < count; n++)
                    s += abase_x[n] * bbase_y[n];
                  dest_y[x] = s;
                }
            }
        }
      else
        {
          const GFC_COMPLEX_10 *restrict bbase_y;
          GFC_COMPLEX_10 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              s = (GFC_COMPLEX_10) 0;
              for (n = 0; n < count; n++)
                s += abase[n*axstride] * bbase_y[n];
              dest[y*rystride] = s;
            }
        }
    }
  else if (axstride < aystride)
    {
      /* General strides with x as the faster dimension of A: clear the
         result, then accumulate with x innermost.  */
      for (y = 0; y < ycount; y++)
        for (x = 0; x < xcount; x++)
          dest[x*rxstride + y*rystride] = (GFC_COMPLEX_10)0;

      for (y = 0; y < ycount; y++)
        for (n = 0; n < count; n++)
          for (x = 0; x < xcount; x++)
            /* dest[x,y] += a[x,n] * b[n,y] */
            dest[x*rxstride + y*rystride] +=
                                        abase[x*axstride + n*aystride] *
                                        bbase[n*bxstride + y*bystride];
    }
  else if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      const GFC_COMPLEX_10 *restrict bbase_y;
      GFC_COMPLEX_10 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          s = (GFC_COMPLEX_10) 0;
          for (n = 0; n < count; n++)
            s += abase[n*axstride] * bbase_y[n*bxstride];
          dest[y*rxstride] = s;
        }
    }
  else
    {
      const GFC_COMPLEX_10 *restrict abase_x;
      const GFC_COMPLEX_10 *restrict bbase_y;
      GFC_COMPLEX_10 *restrict dest_y;
      GFC_COMPLEX_10 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          dest_y = &dest[y*rystride];
          for (x = 0; x < xcount; x++)
            {
              abase_x = &abase[x*axstride];
              s = (GFC_COMPLEX_10) 0;
              for (n = 0; n < count; n++)
                s += abase_x[n*aystride] * bbase_y[n*bxstride];
              dest_y[x*rxstride] = s;
            }
        }
    }
}
#undef POW3
#undef min
#undef max

#endif
#if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
/* MATMUL for GFC_COMPLEX_10 arrays, compiled for the "avx,fma4" target.

   Computes C[xcount,ycount] = A[xcount,count] * B[count,ycount] and
   stores the product through RETARRAY.

   retarray    Result descriptor.  When its base_addr is NULL the
               dimensions are filled in from A and B and the storage is
               obtained with xmallocarray; otherwise, if bounds checking
               is enabled, its extents are checked against A and B.
   a, b        Operand descriptors; at least one must be rank 2 (asserted).
   try_blas    Non-zero permits handing the product to GEMM.  Bit value 2
               selects "C" as the first GEMM operand mode and bit value 4
               selects "C" as the second one.
   blas_limit  GEMM is only considered when xcount*ycount*count exceeds
               blas_limit cubed.
   gemm        BLAS ?gemm entry point; must be non-NULL whenever the BLAS
               path is actually taken (asserted).  */
void
matmul_c10_avx128_fma4 (gfc_array_c10 * const restrict retarray,
	gfc_array_c10 * const restrict a, gfc_array_c10 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma4")));
internal_proto(matmul_c10_avx128_fma4);

void
matmul_c10_avx128_fma4 (gfc_array_c10 * const restrict retarray,
	gfc_array_c10 * const restrict a, gfc_array_c10 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm)
{
  const GFC_COMPLEX_10 * restrict abase;
  const GFC_COMPLEX_10 * restrict bbase;
  GFC_COMPLEX_10 * restrict dest;

  index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  index_type x, y, n, count, xcount, ycount;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
          || GFC_DESCRIPTOR_RANK (b) == 2);

/* C[xcount,ycount] = A[xcount, count] * B[count,ycount]

   Either A or B (but not both) can be rank 1:

   o One-dimensional argument A is implicitly treated as a row matrix
     dimensioned [1,count], so xcount=1.

   o One-dimensional argument B is implicitly treated as a column matrix
     dimensioned [count, 1], so ycount=1.
*/

  if (retarray->base_addr == NULL)
    {
      /* Unallocated result: derive its shape from the operands and
         allocate it.  */
      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
        }
      else
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

          GFC_DIMENSION_SET(retarray->dim[1], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1,
                            GFC_DESCRIPTOR_EXTENT(retarray,0));
        }

      retarray->base_addr
        = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_COMPLEX_10));
      retarray->offset = 0;
    }
  else if (unlikely (compile_options.bounds_check))
    {
      /* Caller-supplied result: verify that its extents agree with the
         shape implied by the operands.  */
      index_type ret_extent, arg_extent;

      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);

          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 2 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
    }


  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      /* One-dimensional result may be addressed in the code below
         either as a row or a column matrix. We want both cases to
         work. */
      rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }


  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Treat it as a row matrix A[1,count]. */
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = 1;

      xcount = 1;
      count = GFC_DESCRIPTOR_EXTENT(a,0);
    }
  else
    {
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = GFC_DESCRIPTOR_STRIDE(a,1);

      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  if (count != GFC_DESCRIPTOR_EXTENT(b,0))
    {
      /* Two zero-or-negative extents are tolerated as "both empty".  */
      if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
        runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
                       "in dimension 1: is %ld, should be %ld",
                       (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
    }

  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      /* Treat it as a column matrix B[count,1] */
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);

      /* bystride should never be used for 1-dimensional b.
         The value is only used for calculation of the
         memory by the buffer.  */
      bystride = 256;
      ycount = 1;
    }
  else
    {
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
      bystride = GFC_DESCRIPTOR_STRIDE(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  abase = a->base_addr;
  bbase = b->base_addr;
  dest = retarray->base_addr;

  /* Now that everything is set up, we perform the multiplication
     itself.  */

#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
#define min(a,b) ((a) <= (b) ? (a) : (b))
#define max(a,b) ((a) >= (b) ? (a) : (b))

  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
      && (bxstride == 1 || bystride == 1)
      && (((float) xcount) * ((float) ycount) * ((float) count)
          > POW3(blas_limit)))
    {
      /* Large product with unit-stride operands: try the external GEMM.
         The sizes are narrowed to int because that is what the blas_call
         interface takes.  */
      const int m = xcount, n = ycount, k = count, ldc = rystride;
      const GFC_COMPLEX_10 one = 1, zero = 0;
      const int lda = (axstride == 1) ? aystride : axstride,
                ldb = (bxstride == 1) ? bystride : bxstride;

      if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
        {
          assert (gemm != NULL);
          const char *transa, *transb;
          if (try_blas & 2)
            transa = "C";
          else
            transa = axstride == 1 ? "N" : "T";

          if (try_blas & 4)
            transb = "C";
          else
            transb = bxstride == 1 ? "N" : "T";

          gemm (transa, transb , &m,
                &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
                &ldc, 1, 1);
          return;
        }
    }

  if (rxstride == 1 && axstride == 1 && bxstride == 1)
    {
      /* This block of code implements a tuned matmul, derived from
         Superscalar GEMM-based level 3 BLAS,  Beta version 0.1

               Bo Kagstrom and Per Ling
               Department of Computing Science
               Umea University
               S-901 87 Umea, Sweden

         from netlib.org, translated to C, and modified for matmul.m4.  */

      const GFC_COMPLEX_10 *a, *b;
      GFC_COMPLEX_10 *c;
      const index_type m = xcount, n = ycount, k = count;

      /* System generated locals */
      index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
                 i1, i2, i3, i4, i5, i6;

      /* Local variables */
      GFC_COMPLEX_10 f11, f12, f21, f22, f31, f32, f41, f42,
                 f13, f14, f23, f24, f33, f34, f43, f44;
      index_type i, j, l, ii, jj, ll;
      index_type isec, jsec, lsec, uisec, ujsec, ulsec;
      GFC_COMPLEX_10 *t1;

      a = abase;
      b = bbase;
      c = retarray->base_addr;

      /* Parameter adjustments: bias the base pointers so that the
         1-based indices used below address the right elements.  */
      c_dim1 = rystride;
      c_offset = 1 + c_dim1;
      c -= c_offset;
      a_dim1 = aystride;
      a_offset = 1 + a_dim1;
      a -= a_offset;
      b_dim1 = bystride;
      b_offset = 1 + b_dim1;
      b -= b_offset;

      /* Empty c first.  */
      for (j=1; j<=n; j++)
        for (i=1; i<=m; i++)
          c[i + j * c_dim1] = (GFC_COMPLEX_10)0;

      /* Early exit if possible */
      if (m == 0 || n == 0 || k == 0)
        return;

      /* Adjust size of t1 to what is needed.  */
      index_type t1_dim, a_sz;
      if (aystride == 1)
        a_sz = rystride;
      else
        a_sz = a_dim1;

      t1_dim = a_sz * 256 + b_dim1;
      if (t1_dim > 65536)
        t1_dim = 65536;

      /* Use the checked allocator (as for the result array above) so an
         allocation failure is reported instead of dereferencing NULL.  */
      t1 = xmallocarray (t1_dim, sizeof (GFC_COMPLEX_10));

      /* Start turning the crank.  C is processed in panels of up to 512
         columns (jj), the inner dimension in slices of up to 256 (ll) and
         the rows of A in slices of up to 256 (ii).  */
      i1 = n;
      for (jj = 1; jj <= i1; jj += 512)
        {
          /* Computing MIN */
          i2 = 512;
          i3 = n - jj + 1;
          jsec = min(i2,i3);
          ujsec = jsec - jsec % 4;
          i2 = k;
          for (ll = 1; ll <= i2; ll += 256)
            {
              /* Computing MIN */
              i3 = 256;
              i4 = k - ll + 1;
              lsec = min(i3,i4);
              ulsec = lsec - lsec % 2;

              i3 = m;
              for (ii = 1; ii <= i3; ii += 256)
                {
                  /* Computing MIN */
                  i4 = 256;
                  i5 = m - ii + 1;
                  isec = min(i4,i5);
                  uisec = isec - isec % 2;

                  /* Copy the current block of A into t1 so that element
                     (i, l) lands at t1[(i - ii) * 256 + (l - ll)].  The
                     bulk is copied 2x2 at a time; odd leftovers follow.  */
                  i4 = ll + ulsec - 1;
                  for (l = ll; l <= i4; l += 2)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 2)
                        {
                          t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
                                        a[i + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
                                        a[i + (l + 1) * a_dim1];
                          t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + (l + 1) * a_dim1];
                        }
                      if (uisec < isec)
                        {
                          t1[l - ll + 1 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + l * a_dim1];
                          t1[l - ll + 2 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + (l + 1) * a_dim1];
                        }
                    }
                  if (ulsec < lsec)
                    {
                      i4 = ii + isec - 1;
                      for (i = ii; i<= i4; ++i)
                        {
                          t1[lsec + ((i - ii + 1) << 8) - 257] =
                                    a[i + (ll + lsec - 1) * a_dim1];
                        }
                    }

                  /* Multiply the packed block by B: 4x4 tiles of C first,
                     then the leftover rows and columns.  */
                  uisec = isec - isec % 4;
                  i4 = jj + ujsec - 1;
                  for (j = jj; j <= i4; j += 4)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 4)
                        {
                          f11 = c[i + j * c_dim1];
                          f21 = c[i + 1 + j * c_dim1];
                          f12 = c[i + (j + 1) * c_dim1];
                          f22 = c[i + 1 + (j + 1) * c_dim1];
                          f13 = c[i + (j + 2) * c_dim1];
                          f23 = c[i + 1 + (j + 2) * c_dim1];
                          f14 = c[i + (j + 3) * c_dim1];
                          f24 = c[i + 1 + (j + 3) * c_dim1];
                          f31 = c[i + 2 + j * c_dim1];
                          f41 = c[i + 3 + j * c_dim1];
                          f32 = c[i + 2 + (j + 1) * c_dim1];
                          f42 = c[i + 3 + (j + 1) * c_dim1];
                          f33 = c[i + 2 + (j + 2) * c_dim1];
                          f43 = c[i + 3 + (j + 2) * c_dim1];
                          f34 = c[i + 2 + (j + 3) * c_dim1];
                          f44 = c[i + 3 + (j + 3) * c_dim1];
                          i6 = ll + lsec - 1;
                          for (l = ll; l <= i6; ++l)
                            {
                              f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                            }
                          c[i + j * c_dim1] = f11;
                          c[i + 1 + j * c_dim1] = f21;
                          c[i + (j + 1) * c_dim1] = f12;
                          c[i + 1 + (j + 1) * c_dim1] = f22;
                          c[i + (j + 2) * c_dim1] = f13;
                          c[i + 1 + (j + 2) * c_dim1] = f23;
                          c[i + (j + 3) * c_dim1] = f14;
                          c[i + 1 + (j + 3) * c_dim1] = f24;
                          c[i + 2 + j * c_dim1] = f31;
                          c[i + 3 + j * c_dim1] = f41;
                          c[i + 2 + (j + 1) * c_dim1] = f32;
                          c[i + 3 + (j + 1) * c_dim1] = f42;
                          c[i + 2 + (j + 2) * c_dim1] = f33;
                          c[i + 3 + (j + 2) * c_dim1] = f43;
                          c[i + 2 + (j + 3) * c_dim1] = f34;
                          c[i + 3 + (j + 3) * c_dim1] = f44;
                        }
                      if (uisec < isec)
                        {
                          /* Rows left over after the 4-row tiles.  */
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              f12 = c[i + (j + 1) * c_dim1];
                              f13 = c[i + (j + 2) * c_dim1];
                              f14 = c[i + (j + 3) * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + (j + 1) * b_dim1];
                                  f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + (j + 2) * b_dim1];
                                  f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + (j + 3) * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + (j + 1) * c_dim1] = f12;
                              c[i + (j + 2) * c_dim1] = f13;
                              c[i + (j + 3) * c_dim1] = f14;
                            }
                        }
                    }
                  if (ujsec < jsec)
                    {
                      /* Columns left over after the 4-column tiles.  */
                      i4 = jj + jsec - 1;
                      for (j = jj + ujsec; j <= i4; ++j)
                        {
                          i5 = ii + uisec - 1;
                          for (i = ii; i <= i5; i += 4)
                            {
                              f11 = c[i + j * c_dim1];
                              f21 = c[i + 1 + j * c_dim1];
                              f31 = c[i + 2 + j * c_dim1];
                              f41 = c[i + 3 + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
                                            257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + 1 + j * c_dim1] = f21;
                              c[i + 2 + j * c_dim1] = f31;
                              c[i + 3 + j * c_dim1] = f41;
                            }
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                            }
                        }
                    }
                }
            }
        }
      free(t1);
      return;
    }
  else if (rxstride == 1 && aystride == 1 && bxstride == 1)
    {
      /* A is walked with unit stride along the summed dimension, so each
         result element is a plain dot product.  */
      if (GFC_DESCRIPTOR_RANK (a) != 1)
        {
          const GFC_COMPLEX_10 *restrict abase_x;
          const GFC_COMPLEX_10 *restrict bbase_y;
          GFC_COMPLEX_10 *restrict dest_y;
          GFC_COMPLEX_10 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              dest_y = &dest[y*rystride];
              for (x = 0; x < xcount; x++)
                {
                  abase_x = &abase[x*axstride];
                  s = (GFC_COMPLEX_10) 0;
                  for (n = 0; n < count; n++)
                    s += abase_x[n] * bbase_y[n];
                  dest_y[x] = s;
                }
            }
        }
      else
        {
          const GFC_COMPLEX_10 *restrict bbase_y;
          GFC_COMPLEX_10 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              s = (GFC_COMPLEX_10) 0;
              for (n = 0; n < count; n++)
                s += abase[n*axstride] * bbase_y[n];
              dest[y*rystride] = s;
            }
        }
    }
  else if (axstride < aystride)
    {
      /* General strides with x as the faster dimension of A: clear the
         result, then accumulate with x innermost.  */
      for (y = 0; y < ycount; y++)
        for (x = 0; x < xcount; x++)
          dest[x*rxstride + y*rystride] = (GFC_COMPLEX_10)0;

      for (y = 0; y < ycount; y++)
        for (n = 0; n < count; n++)
          for (x = 0; x < xcount; x++)
            /* dest[x,y] += a[x,n] * b[n,y] */
            dest[x*rxstride + y*rystride] +=
                                        abase[x*axstride + n*aystride] *
                                        bbase[n*bxstride + y*bystride];
    }
  else if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      const GFC_COMPLEX_10 *restrict bbase_y;
      GFC_COMPLEX_10 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          s = (GFC_COMPLEX_10) 0;
          for (n = 0; n < count; n++)
            s += abase[n*axstride] * bbase_y[n*bxstride];
          dest[y*rxstride] = s;
        }
    }
  else
    {
      const GFC_COMPLEX_10 *restrict abase_x;
      const GFC_COMPLEX_10 *restrict bbase_y;
      GFC_COMPLEX_10 *restrict dest_y;
      GFC_COMPLEX_10 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          dest_y = &dest[y*rystride];
          for (x = 0; x < xcount; x++)
            {
              abase_x = &abase[x*axstride];
              s = (GFC_COMPLEX_10) 0;
              for (n = 0; n < count; n++)
                s += abase_x[n*aystride] * bbase_y[n*bxstride];
              dest_y[x*rxstride] = s;
            }
        }
    }
}
#undef POW3
#undef min
#undef max

#endif