1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2023 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
31 #if defined (HAVE_GFC_COMPLEX_8)
33 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
34 passed to us by the front-end, in which case we call it for large
37 typedef void (*blas_call
)(const char *, const char *, const int *, const int *,
38 const int *, const GFC_COMPLEX_8
*, const GFC_COMPLEX_8
*,
39 const int *, const GFC_COMPLEX_8
*, const int *,
40 const GFC_COMPLEX_8
*, GFC_COMPLEX_8
*, const int *,
43 /* The order of loops is different in the case of plain matrix
44 multiplication C=MATMUL(A,B), and in the frequent special case where
45 the argument A is the temporary result of a TRANSPOSE intrinsic:
46 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
47 looking at their strides.
49 The equivalent Fortran pseudo-code is:
51 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
52 IF (.NOT.IS_TRANSPOSED(A)) THEN
57 C(I,J) = C(I,J)+A(I,K)*B(K,J)
68 /* If try_blas is set to a nonzero value, then the matmul function will
69 see if there is a way to perform the matrix multiplication by a call
70 to the BLAS gemm function. */
72 extern void matmul_c8 (gfc_array_c8
* const restrict retarray
,
73 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
74 int blas_limit
, blas_call gemm
);
75 export_proto(matmul_c8
);
77 /* Put an exhaustive list of possible architectures here, ORed together. */
79 #if defined(HAVE_AVX) || defined(HAVE_AVX2) || defined(HAVE_AVX512F)
83 matmul_c8_avx (gfc_array_c8
* const restrict retarray
,
84 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
85 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx")));
87 matmul_c8_avx (gfc_array_c8
* const restrict retarray
,
88 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
89 int blas_limit
, blas_call gemm
)
91 const GFC_COMPLEX_8
* restrict abase
;
92 const GFC_COMPLEX_8
* restrict bbase
;
93 GFC_COMPLEX_8
* restrict dest
;
95 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
96 index_type x
, y
, n
, count
, xcount
, ycount
;
98 assert (GFC_DESCRIPTOR_RANK (a
) == 2
99 || GFC_DESCRIPTOR_RANK (b
) == 2);
101 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
103 Either A or B (but not both) can be rank 1:
105 o One-dimensional argument A is implicitly treated as a row matrix
106 dimensioned [1,count], so xcount=1.
108 o One-dimensional argument B is implicitly treated as a column matrix
109 dimensioned [count, 1], so ycount=1.
112 if (retarray
->base_addr
== NULL
)
114 if (GFC_DESCRIPTOR_RANK (a
) == 1)
116 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
117 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
119 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
121 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
122 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
126 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
127 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
129 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
130 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
131 GFC_DESCRIPTOR_EXTENT(retarray
,0));
135 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_COMPLEX_8
));
136 retarray
->offset
= 0;
138 else if (unlikely (compile_options
.bounds_check
))
140 index_type ret_extent
, arg_extent
;
142 if (GFC_DESCRIPTOR_RANK (a
) == 1)
144 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
145 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
146 if (arg_extent
!= ret_extent
)
147 runtime_error ("Array bound mismatch for dimension 1 of "
149 (long int) ret_extent
, (long int) arg_extent
);
151 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
153 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
154 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
155 if (arg_extent
!= ret_extent
)
156 runtime_error ("Array bound mismatch for dimension 1 of "
158 (long int) ret_extent
, (long int) arg_extent
);
162 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
163 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
164 if (arg_extent
!= ret_extent
)
165 runtime_error ("Array bound mismatch for dimension 1 of "
167 (long int) ret_extent
, (long int) arg_extent
);
169 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
170 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
171 if (arg_extent
!= ret_extent
)
172 runtime_error ("Array bound mismatch for dimension 2 of "
174 (long int) ret_extent
, (long int) arg_extent
);
179 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
181 /* One-dimensional result may be addressed in the code below
182 either as a row or a column matrix. We want both cases to
184 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
188 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
189 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
193 if (GFC_DESCRIPTOR_RANK (a
) == 1)
195 /* Treat it as a a row matrix A[1,count]. */
196 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
200 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
204 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
205 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
207 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
208 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
211 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
213 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
214 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
215 "in dimension 1: is %ld, should be %ld",
216 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
219 if (GFC_DESCRIPTOR_RANK (b
) == 1)
221 /* Treat it as a column matrix B[count,1] */
222 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
224 /* bystride should never be used for 1-dimensional b.
225 The value is only used for calculation of the
226 memory by the buffer. */
232 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
233 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
234 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
237 abase
= a
->base_addr
;
238 bbase
= b
->base_addr
;
239 dest
= retarray
->base_addr
;
241 /* Now that everything is set up, we perform the multiplication
244 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
245 #define min(a,b) ((a) <= (b) ? (a) : (b))
246 #define max(a,b) ((a) >= (b) ? (a) : (b))
248 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
249 && (bxstride
== 1 || bystride
== 1)
250 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
253 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
254 const GFC_COMPLEX_8 one
= 1, zero
= 0;
255 const int lda
= (axstride
== 1) ? aystride
: axstride
,
256 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
258 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
260 assert (gemm
!= NULL
);
261 const char *transa
, *transb
;
265 transa
= axstride
== 1 ? "N" : "T";
270 transb
= bxstride
== 1 ? "N" : "T";
272 gemm (transa
, transb
, &m
,
273 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
279 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1
280 && GFC_DESCRIPTOR_RANK (b
) != 1)
282 /* This block of code implements a tuned matmul, derived from
283 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
285 Bo Kagstrom and Per Ling
286 Department of Computing Science
288 S-901 87 Umea, Sweden
290 from netlib.org, translated to C, and modified for matmul.m4. */
292 const GFC_COMPLEX_8
*a
, *b
;
294 const index_type m
= xcount
, n
= ycount
, k
= count
;
296 /* System generated locals */
297 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
298 i1
, i2
, i3
, i4
, i5
, i6
;
300 /* Local variables */
301 GFC_COMPLEX_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
302 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
303 index_type i
, j
, l
, ii
, jj
, ll
;
304 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
309 c
= retarray
->base_addr
;
311 /* Parameter adjustments */
313 c_offset
= 1 + c_dim1
;
316 a_offset
= 1 + a_dim1
;
319 b_offset
= 1 + b_dim1
;
325 c
[i
+ j
* c_dim1
] = (GFC_COMPLEX_8
)0;
327 /* Early exit if possible */
328 if (m
== 0 || n
== 0 || k
== 0)
331 /* Adjust size of t1 to what is needed. */
332 index_type t1_dim
, a_sz
;
338 t1_dim
= a_sz
* 256 + b_dim1
;
342 t1
= malloc (t1_dim
* sizeof(GFC_COMPLEX_8
));
344 /* Start turning the crank. */
346 for (jj
= 1; jj
<= i1
; jj
+= 512)
352 ujsec
= jsec
- jsec
% 4;
354 for (ll
= 1; ll
<= i2
; ll
+= 256)
360 ulsec
= lsec
- lsec
% 2;
363 for (ii
= 1; ii
<= i3
; ii
+= 256)
369 uisec
= isec
- isec
% 2;
371 for (l
= ll
; l
<= i4
; l
+= 2)
374 for (i
= ii
; i
<= i5
; i
+= 2)
376 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
378 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
379 a
[i
+ (l
+ 1) * a_dim1
];
380 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
381 a
[i
+ 1 + l
* a_dim1
];
382 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
383 a
[i
+ 1 + (l
+ 1) * a_dim1
];
387 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
388 a
[ii
+ isec
- 1 + l
* a_dim1
];
389 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
390 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
396 for (i
= ii
; i
<= i4
; ++i
)
398 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
399 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
403 uisec
= isec
- isec
% 4;
405 for (j
= jj
; j
<= i4
; j
+= 4)
408 for (i
= ii
; i
<= i5
; i
+= 4)
410 f11
= c
[i
+ j
* c_dim1
];
411 f21
= c
[i
+ 1 + j
* c_dim1
];
412 f12
= c
[i
+ (j
+ 1) * c_dim1
];
413 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
414 f13
= c
[i
+ (j
+ 2) * c_dim1
];
415 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
416 f14
= c
[i
+ (j
+ 3) * c_dim1
];
417 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
418 f31
= c
[i
+ 2 + j
* c_dim1
];
419 f41
= c
[i
+ 3 + j
* c_dim1
];
420 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
421 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
422 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
423 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
424 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
425 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
427 for (l
= ll
; l
<= i6
; ++l
)
429 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
431 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
433 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
434 * b
[l
+ (j
+ 1) * b_dim1
];
435 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
436 * b
[l
+ (j
+ 1) * b_dim1
];
437 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
438 * b
[l
+ (j
+ 2) * b_dim1
];
439 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
440 * b
[l
+ (j
+ 2) * b_dim1
];
441 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
442 * b
[l
+ (j
+ 3) * b_dim1
];
443 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
444 * b
[l
+ (j
+ 3) * b_dim1
];
445 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
447 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
449 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
450 * b
[l
+ (j
+ 1) * b_dim1
];
451 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
452 * b
[l
+ (j
+ 1) * b_dim1
];
453 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
454 * b
[l
+ (j
+ 2) * b_dim1
];
455 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
456 * b
[l
+ (j
+ 2) * b_dim1
];
457 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
458 * b
[l
+ (j
+ 3) * b_dim1
];
459 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
460 * b
[l
+ (j
+ 3) * b_dim1
];
462 c
[i
+ j
* c_dim1
] = f11
;
463 c
[i
+ 1 + j
* c_dim1
] = f21
;
464 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
465 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
466 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
467 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
468 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
469 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
470 c
[i
+ 2 + j
* c_dim1
] = f31
;
471 c
[i
+ 3 + j
* c_dim1
] = f41
;
472 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
473 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
474 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
475 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
476 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
477 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
482 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
484 f11
= c
[i
+ j
* c_dim1
];
485 f12
= c
[i
+ (j
+ 1) * c_dim1
];
486 f13
= c
[i
+ (j
+ 2) * c_dim1
];
487 f14
= c
[i
+ (j
+ 3) * c_dim1
];
489 for (l
= ll
; l
<= i6
; ++l
)
491 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
492 257] * b
[l
+ j
* b_dim1
];
493 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
494 257] * b
[l
+ (j
+ 1) * b_dim1
];
495 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
496 257] * b
[l
+ (j
+ 2) * b_dim1
];
497 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
498 257] * b
[l
+ (j
+ 3) * b_dim1
];
500 c
[i
+ j
* c_dim1
] = f11
;
501 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
502 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
503 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
510 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
513 for (i
= ii
; i
<= i5
; i
+= 4)
515 f11
= c
[i
+ j
* c_dim1
];
516 f21
= c
[i
+ 1 + j
* c_dim1
];
517 f31
= c
[i
+ 2 + j
* c_dim1
];
518 f41
= c
[i
+ 3 + j
* c_dim1
];
520 for (l
= ll
; l
<= i6
; ++l
)
522 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
523 257] * b
[l
+ j
* b_dim1
];
524 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
525 257] * b
[l
+ j
* b_dim1
];
526 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
527 257] * b
[l
+ j
* b_dim1
];
528 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
529 257] * b
[l
+ j
* b_dim1
];
531 c
[i
+ j
* c_dim1
] = f11
;
532 c
[i
+ 1 + j
* c_dim1
] = f21
;
533 c
[i
+ 2 + j
* c_dim1
] = f31
;
534 c
[i
+ 3 + j
* c_dim1
] = f41
;
537 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
539 f11
= c
[i
+ j
* c_dim1
];
541 for (l
= ll
; l
<= i6
; ++l
)
543 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
544 257] * b
[l
+ j
* b_dim1
];
546 c
[i
+ j
* c_dim1
] = f11
;
556 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
558 if (GFC_DESCRIPTOR_RANK (a
) != 1)
560 const GFC_COMPLEX_8
*restrict abase_x
;
561 const GFC_COMPLEX_8
*restrict bbase_y
;
562 GFC_COMPLEX_8
*restrict dest_y
;
565 for (y
= 0; y
< ycount
; y
++)
567 bbase_y
= &bbase
[y
*bystride
];
568 dest_y
= &dest
[y
*rystride
];
569 for (x
= 0; x
< xcount
; x
++)
571 abase_x
= &abase
[x
*axstride
];
572 s
= (GFC_COMPLEX_8
) 0;
573 for (n
= 0; n
< count
; n
++)
574 s
+= abase_x
[n
] * bbase_y
[n
];
581 const GFC_COMPLEX_8
*restrict bbase_y
;
584 for (y
= 0; y
< ycount
; y
++)
586 bbase_y
= &bbase
[y
*bystride
];
587 s
= (GFC_COMPLEX_8
) 0;
588 for (n
= 0; n
< count
; n
++)
589 s
+= abase
[n
*axstride
] * bbase_y
[n
];
590 dest
[y
*rystride
] = s
;
594 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
596 const GFC_COMPLEX_8
*restrict bbase_y
;
599 for (y
= 0; y
< ycount
; y
++)
601 bbase_y
= &bbase
[y
*bystride
];
602 s
= (GFC_COMPLEX_8
) 0;
603 for (n
= 0; n
< count
; n
++)
604 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
605 dest
[y
*rxstride
] = s
;
608 else if (axstride
< aystride
)
610 for (y
= 0; y
< ycount
; y
++)
611 for (x
= 0; x
< xcount
; x
++)
612 dest
[x
*rxstride
+ y
*rystride
] = (GFC_COMPLEX_8
)0;
614 for (y
= 0; y
< ycount
; y
++)
615 for (n
= 0; n
< count
; n
++)
616 for (x
= 0; x
< xcount
; x
++)
617 /* dest[x,y] += a[x,n] * b[n,y] */
618 dest
[x
*rxstride
+ y
*rystride
] +=
619 abase
[x
*axstride
+ n
*aystride
] *
620 bbase
[n
*bxstride
+ y
*bystride
];
624 const GFC_COMPLEX_8
*restrict abase_x
;
625 const GFC_COMPLEX_8
*restrict bbase_y
;
626 GFC_COMPLEX_8
*restrict dest_y
;
629 for (y
= 0; y
< ycount
; y
++)
631 bbase_y
= &bbase
[y
*bystride
];
632 dest_y
= &dest
[y
*rystride
];
633 for (x
= 0; x
< xcount
; x
++)
635 abase_x
= &abase
[x
*axstride
];
636 s
= (GFC_COMPLEX_8
) 0;
637 for (n
= 0; n
< count
; n
++)
638 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
639 dest_y
[x
*rxstride
] = s
;
648 #endif /* HAVE_AVX */
652 matmul_c8_avx2 (gfc_array_c8
* const restrict retarray
,
653 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
654 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx2,fma")));
656 matmul_c8_avx2 (gfc_array_c8
* const restrict retarray
,
657 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
658 int blas_limit
, blas_call gemm
)
660 const GFC_COMPLEX_8
* restrict abase
;
661 const GFC_COMPLEX_8
* restrict bbase
;
662 GFC_COMPLEX_8
* restrict dest
;
664 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
665 index_type x
, y
, n
, count
, xcount
, ycount
;
667 assert (GFC_DESCRIPTOR_RANK (a
) == 2
668 || GFC_DESCRIPTOR_RANK (b
) == 2);
670 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
672 Either A or B (but not both) can be rank 1:
674 o One-dimensional argument A is implicitly treated as a row matrix
675 dimensioned [1,count], so xcount=1.
677 o One-dimensional argument B is implicitly treated as a column matrix
678 dimensioned [count, 1], so ycount=1.
681 if (retarray
->base_addr
== NULL
)
683 if (GFC_DESCRIPTOR_RANK (a
) == 1)
685 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
686 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
688 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
690 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
691 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
695 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
696 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
698 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
699 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
700 GFC_DESCRIPTOR_EXTENT(retarray
,0));
704 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_COMPLEX_8
));
705 retarray
->offset
= 0;
707 else if (unlikely (compile_options
.bounds_check
))
709 index_type ret_extent
, arg_extent
;
711 if (GFC_DESCRIPTOR_RANK (a
) == 1)
713 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
714 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
715 if (arg_extent
!= ret_extent
)
716 runtime_error ("Array bound mismatch for dimension 1 of "
718 (long int) ret_extent
, (long int) arg_extent
);
720 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
722 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
723 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
724 if (arg_extent
!= ret_extent
)
725 runtime_error ("Array bound mismatch for dimension 1 of "
727 (long int) ret_extent
, (long int) arg_extent
);
731 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
732 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
733 if (arg_extent
!= ret_extent
)
734 runtime_error ("Array bound mismatch for dimension 1 of "
736 (long int) ret_extent
, (long int) arg_extent
);
738 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
739 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
740 if (arg_extent
!= ret_extent
)
741 runtime_error ("Array bound mismatch for dimension 2 of "
743 (long int) ret_extent
, (long int) arg_extent
);
748 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
750 /* One-dimensional result may be addressed in the code below
751 either as a row or a column matrix. We want both cases to
753 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
757 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
758 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
762 if (GFC_DESCRIPTOR_RANK (a
) == 1)
764 /* Treat it as a a row matrix A[1,count]. */
765 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
769 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
773 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
774 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
776 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
777 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
780 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
782 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
783 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
784 "in dimension 1: is %ld, should be %ld",
785 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
788 if (GFC_DESCRIPTOR_RANK (b
) == 1)
790 /* Treat it as a column matrix B[count,1] */
791 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
793 /* bystride should never be used for 1-dimensional b.
794 The value is only used for calculation of the
795 memory by the buffer. */
801 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
802 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
803 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
806 abase
= a
->base_addr
;
807 bbase
= b
->base_addr
;
808 dest
= retarray
->base_addr
;
810 /* Now that everything is set up, we perform the multiplication
813 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
814 #define min(a,b) ((a) <= (b) ? (a) : (b))
815 #define max(a,b) ((a) >= (b) ? (a) : (b))
817 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
818 && (bxstride
== 1 || bystride
== 1)
819 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
822 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
823 const GFC_COMPLEX_8 one
= 1, zero
= 0;
824 const int lda
= (axstride
== 1) ? aystride
: axstride
,
825 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
827 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
829 assert (gemm
!= NULL
);
830 const char *transa
, *transb
;
834 transa
= axstride
== 1 ? "N" : "T";
839 transb
= bxstride
== 1 ? "N" : "T";
841 gemm (transa
, transb
, &m
,
842 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
848 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1
849 && GFC_DESCRIPTOR_RANK (b
) != 1)
851 /* This block of code implements a tuned matmul, derived from
852 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
854 Bo Kagstrom and Per Ling
855 Department of Computing Science
857 S-901 87 Umea, Sweden
859 from netlib.org, translated to C, and modified for matmul.m4. */
861 const GFC_COMPLEX_8
*a
, *b
;
863 const index_type m
= xcount
, n
= ycount
, k
= count
;
865 /* System generated locals */
866 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
867 i1
, i2
, i3
, i4
, i5
, i6
;
869 /* Local variables */
870 GFC_COMPLEX_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
871 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
872 index_type i
, j
, l
, ii
, jj
, ll
;
873 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
878 c
= retarray
->base_addr
;
880 /* Parameter adjustments */
882 c_offset
= 1 + c_dim1
;
885 a_offset
= 1 + a_dim1
;
888 b_offset
= 1 + b_dim1
;
894 c
[i
+ j
* c_dim1
] = (GFC_COMPLEX_8
)0;
896 /* Early exit if possible */
897 if (m
== 0 || n
== 0 || k
== 0)
900 /* Adjust size of t1 to what is needed. */
901 index_type t1_dim
, a_sz
;
907 t1_dim
= a_sz
* 256 + b_dim1
;
911 t1
= malloc (t1_dim
* sizeof(GFC_COMPLEX_8
));
913 /* Start turning the crank. */
915 for (jj
= 1; jj
<= i1
; jj
+= 512)
921 ujsec
= jsec
- jsec
% 4;
923 for (ll
= 1; ll
<= i2
; ll
+= 256)
929 ulsec
= lsec
- lsec
% 2;
932 for (ii
= 1; ii
<= i3
; ii
+= 256)
938 uisec
= isec
- isec
% 2;
940 for (l
= ll
; l
<= i4
; l
+= 2)
943 for (i
= ii
; i
<= i5
; i
+= 2)
945 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
947 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
948 a
[i
+ (l
+ 1) * a_dim1
];
949 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
950 a
[i
+ 1 + l
* a_dim1
];
951 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
952 a
[i
+ 1 + (l
+ 1) * a_dim1
];
956 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
957 a
[ii
+ isec
- 1 + l
* a_dim1
];
958 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
959 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
965 for (i
= ii
; i
<= i4
; ++i
)
967 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
968 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
972 uisec
= isec
- isec
% 4;
974 for (j
= jj
; j
<= i4
; j
+= 4)
977 for (i
= ii
; i
<= i5
; i
+= 4)
979 f11
= c
[i
+ j
* c_dim1
];
980 f21
= c
[i
+ 1 + j
* c_dim1
];
981 f12
= c
[i
+ (j
+ 1) * c_dim1
];
982 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
983 f13
= c
[i
+ (j
+ 2) * c_dim1
];
984 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
985 f14
= c
[i
+ (j
+ 3) * c_dim1
];
986 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
987 f31
= c
[i
+ 2 + j
* c_dim1
];
988 f41
= c
[i
+ 3 + j
* c_dim1
];
989 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
990 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
991 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
992 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
993 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
994 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
996 for (l
= ll
; l
<= i6
; ++l
)
998 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1000 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1001 * b
[l
+ j
* b_dim1
];
1002 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1003 * b
[l
+ (j
+ 1) * b_dim1
];
1004 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1005 * b
[l
+ (j
+ 1) * b_dim1
];
1006 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1007 * b
[l
+ (j
+ 2) * b_dim1
];
1008 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1009 * b
[l
+ (j
+ 2) * b_dim1
];
1010 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1011 * b
[l
+ (j
+ 3) * b_dim1
];
1012 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1013 * b
[l
+ (j
+ 3) * b_dim1
];
1014 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1015 * b
[l
+ j
* b_dim1
];
1016 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1017 * b
[l
+ j
* b_dim1
];
1018 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1019 * b
[l
+ (j
+ 1) * b_dim1
];
1020 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1021 * b
[l
+ (j
+ 1) * b_dim1
];
1022 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1023 * b
[l
+ (j
+ 2) * b_dim1
];
1024 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1025 * b
[l
+ (j
+ 2) * b_dim1
];
1026 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1027 * b
[l
+ (j
+ 3) * b_dim1
];
1028 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1029 * b
[l
+ (j
+ 3) * b_dim1
];
1031 c
[i
+ j
* c_dim1
] = f11
;
1032 c
[i
+ 1 + j
* c_dim1
] = f21
;
1033 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1034 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
1035 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1036 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
1037 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1038 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
1039 c
[i
+ 2 + j
* c_dim1
] = f31
;
1040 c
[i
+ 3 + j
* c_dim1
] = f41
;
1041 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
1042 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
1043 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
1044 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
1045 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
1046 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
1051 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1053 f11
= c
[i
+ j
* c_dim1
];
1054 f12
= c
[i
+ (j
+ 1) * c_dim1
];
1055 f13
= c
[i
+ (j
+ 2) * c_dim1
];
1056 f14
= c
[i
+ (j
+ 3) * c_dim1
];
1058 for (l
= ll
; l
<= i6
; ++l
)
1060 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1061 257] * b
[l
+ j
* b_dim1
];
1062 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1063 257] * b
[l
+ (j
+ 1) * b_dim1
];
1064 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1065 257] * b
[l
+ (j
+ 2) * b_dim1
];
1066 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1067 257] * b
[l
+ (j
+ 3) * b_dim1
];
1069 c
[i
+ j
* c_dim1
] = f11
;
1070 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1071 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1072 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1079 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
1081 i5
= ii
+ uisec
- 1;
1082 for (i
= ii
; i
<= i5
; i
+= 4)
1084 f11
= c
[i
+ j
* c_dim1
];
1085 f21
= c
[i
+ 1 + j
* c_dim1
];
1086 f31
= c
[i
+ 2 + j
* c_dim1
];
1087 f41
= c
[i
+ 3 + j
* c_dim1
];
1089 for (l
= ll
; l
<= i6
; ++l
)
1091 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1092 257] * b
[l
+ j
* b_dim1
];
1093 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
1094 257] * b
[l
+ j
* b_dim1
];
1095 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
1096 257] * b
[l
+ j
* b_dim1
];
1097 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
1098 257] * b
[l
+ j
* b_dim1
];
1100 c
[i
+ j
* c_dim1
] = f11
;
1101 c
[i
+ 1 + j
* c_dim1
] = f21
;
1102 c
[i
+ 2 + j
* c_dim1
] = f31
;
1103 c
[i
+ 3 + j
* c_dim1
] = f41
;
1106 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1108 f11
= c
[i
+ j
* c_dim1
];
1110 for (l
= ll
; l
<= i6
; ++l
)
1112 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1113 257] * b
[l
+ j
* b_dim1
];
1115 c
[i
+ j
* c_dim1
] = f11
;
1125 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
1127 if (GFC_DESCRIPTOR_RANK (a
) != 1)
1129 const GFC_COMPLEX_8
*restrict abase_x
;
1130 const GFC_COMPLEX_8
*restrict bbase_y
;
1131 GFC_COMPLEX_8
*restrict dest_y
;
1134 for (y
= 0; y
< ycount
; y
++)
1136 bbase_y
= &bbase
[y
*bystride
];
1137 dest_y
= &dest
[y
*rystride
];
1138 for (x
= 0; x
< xcount
; x
++)
1140 abase_x
= &abase
[x
*axstride
];
1141 s
= (GFC_COMPLEX_8
) 0;
1142 for (n
= 0; n
< count
; n
++)
1143 s
+= abase_x
[n
] * bbase_y
[n
];
1150 const GFC_COMPLEX_8
*restrict bbase_y
;
1153 for (y
= 0; y
< ycount
; y
++)
1155 bbase_y
= &bbase
[y
*bystride
];
1156 s
= (GFC_COMPLEX_8
) 0;
1157 for (n
= 0; n
< count
; n
++)
1158 s
+= abase
[n
*axstride
] * bbase_y
[n
];
1159 dest
[y
*rystride
] = s
;
1163 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
1165 const GFC_COMPLEX_8
*restrict bbase_y
;
1168 for (y
= 0; y
< ycount
; y
++)
1170 bbase_y
= &bbase
[y
*bystride
];
1171 s
= (GFC_COMPLEX_8
) 0;
1172 for (n
= 0; n
< count
; n
++)
1173 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
1174 dest
[y
*rxstride
] = s
;
1177 else if (axstride
< aystride
)
1179 for (y
= 0; y
< ycount
; y
++)
1180 for (x
= 0; x
< xcount
; x
++)
1181 dest
[x
*rxstride
+ y
*rystride
] = (GFC_COMPLEX_8
)0;
1183 for (y
= 0; y
< ycount
; y
++)
1184 for (n
= 0; n
< count
; n
++)
1185 for (x
= 0; x
< xcount
; x
++)
1186 /* dest[x,y] += a[x,n] * b[n,y] */
1187 dest
[x
*rxstride
+ y
*rystride
] +=
1188 abase
[x
*axstride
+ n
*aystride
] *
1189 bbase
[n
*bxstride
+ y
*bystride
];
1193 const GFC_COMPLEX_8
*restrict abase_x
;
1194 const GFC_COMPLEX_8
*restrict bbase_y
;
1195 GFC_COMPLEX_8
*restrict dest_y
;
1198 for (y
= 0; y
< ycount
; y
++)
1200 bbase_y
= &bbase
[y
*bystride
];
1201 dest_y
= &dest
[y
*rystride
];
1202 for (x
= 0; x
< xcount
; x
++)
1204 abase_x
= &abase
[x
*axstride
];
1205 s
= (GFC_COMPLEX_8
) 0;
1206 for (n
= 0; n
< count
; n
++)
1207 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
1208 dest_y
[x
*rxstride
] = s
;
1217 #endif /* HAVE_AVX2 */
1221 matmul_c8_avx512f (gfc_array_c8
* const restrict retarray
,
1222 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
1223 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx512f")));
1225 matmul_c8_avx512f (gfc_array_c8
* const restrict retarray
,
1226 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
1227 int blas_limit
, blas_call gemm
)
1229 const GFC_COMPLEX_8
* restrict abase
;
1230 const GFC_COMPLEX_8
* restrict bbase
;
1231 GFC_COMPLEX_8
* restrict dest
;
1233 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
1234 index_type x
, y
, n
, count
, xcount
, ycount
;
1236 assert (GFC_DESCRIPTOR_RANK (a
) == 2
1237 || GFC_DESCRIPTOR_RANK (b
) == 2);
1239 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
1241 Either A or B (but not both) can be rank 1:
1243 o One-dimensional argument A is implicitly treated as a row matrix
1244 dimensioned [1,count], so xcount=1.
1246 o One-dimensional argument B is implicitly treated as a column matrix
1247 dimensioned [count, 1], so ycount=1.
1250 if (retarray
->base_addr
== NULL
)
1252 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1254 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1255 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
1257 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
1259 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1260 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
1264 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1265 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
1267 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
1268 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
1269 GFC_DESCRIPTOR_EXTENT(retarray
,0));
1273 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_COMPLEX_8
));
1274 retarray
->offset
= 0;
1276 else if (unlikely (compile_options
.bounds_check
))
1278 index_type ret_extent
, arg_extent
;
1280 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1282 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
1283 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1284 if (arg_extent
!= ret_extent
)
1285 runtime_error ("Array bound mismatch for dimension 1 of "
1287 (long int) ret_extent
, (long int) arg_extent
);
1289 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
1291 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
1292 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1293 if (arg_extent
!= ret_extent
)
1294 runtime_error ("Array bound mismatch for dimension 1 of "
1296 (long int) ret_extent
, (long int) arg_extent
);
1300 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
1301 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1302 if (arg_extent
!= ret_extent
)
1303 runtime_error ("Array bound mismatch for dimension 1 of "
1305 (long int) ret_extent
, (long int) arg_extent
);
1307 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
1308 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
1309 if (arg_extent
!= ret_extent
)
1310 runtime_error ("Array bound mismatch for dimension 2 of "
1312 (long int) ret_extent
, (long int) arg_extent
);
1317 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
1319 /* One-dimensional result may be addressed in the code below
1320 either as a row or a column matrix. We want both cases to
1322 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
1326 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
1327 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
1331 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1333 /* Treat it as a a row matrix A[1,count]. */
1334 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
1338 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
1342 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
1343 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
1345 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
1346 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
1349 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
1351 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
1352 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
1353 "in dimension 1: is %ld, should be %ld",
1354 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
1357 if (GFC_DESCRIPTOR_RANK (b
) == 1)
1359 /* Treat it as a column matrix B[count,1] */
1360 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
1362 /* bystride should never be used for 1-dimensional b.
1363 The value is only used for calculation of the
1364 memory by the buffer. */
1370 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
1371 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
1372 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
1375 abase
= a
->base_addr
;
1376 bbase
= b
->base_addr
;
1377 dest
= retarray
->base_addr
;
1379 /* Now that everything is set up, we perform the multiplication
1382 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
1383 #define min(a,b) ((a) <= (b) ? (a) : (b))
1384 #define max(a,b) ((a) >= (b) ? (a) : (b))
1386 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
1387 && (bxstride
== 1 || bystride
== 1)
1388 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
1389 > POW3(blas_limit
)))
1391 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
1392 const GFC_COMPLEX_8 one
= 1, zero
= 0;
1393 const int lda
= (axstride
== 1) ? aystride
: axstride
,
1394 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
1396 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
1398 assert (gemm
!= NULL
);
1399 const char *transa
, *transb
;
1403 transa
= axstride
== 1 ? "N" : "T";
1408 transb
= bxstride
== 1 ? "N" : "T";
1410 gemm (transa
, transb
, &m
,
1411 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
1417 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1
1418 && GFC_DESCRIPTOR_RANK (b
) != 1)
1420 /* This block of code implements a tuned matmul, derived from
1421 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
1423 Bo Kagstrom and Per Ling
1424 Department of Computing Science
1426 S-901 87 Umea, Sweden
1428 from netlib.org, translated to C, and modified for matmul.m4. */
1430 const GFC_COMPLEX_8
*a
, *b
;
1432 const index_type m
= xcount
, n
= ycount
, k
= count
;
1434 /* System generated locals */
1435 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
1436 i1
, i2
, i3
, i4
, i5
, i6
;
1438 /* Local variables */
1439 GFC_COMPLEX_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
1440 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
1441 index_type i
, j
, l
, ii
, jj
, ll
;
1442 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
1447 c
= retarray
->base_addr
;
1449 /* Parameter adjustments */
1451 c_offset
= 1 + c_dim1
;
1454 a_offset
= 1 + a_dim1
;
1457 b_offset
= 1 + b_dim1
;
1460 /* Empty c first. */
1461 for (j
=1; j
<=n
; j
++)
1462 for (i
=1; i
<=m
; i
++)
1463 c
[i
+ j
* c_dim1
] = (GFC_COMPLEX_8
)0;
1465 /* Early exit if possible */
1466 if (m
== 0 || n
== 0 || k
== 0)
1469 /* Adjust size of t1 to what is needed. */
1470 index_type t1_dim
, a_sz
;
1476 t1_dim
= a_sz
* 256 + b_dim1
;
1480 t1
= malloc (t1_dim
* sizeof(GFC_COMPLEX_8
));
1482 /* Start turning the crank. */
1484 for (jj
= 1; jj
<= i1
; jj
+= 512)
1490 ujsec
= jsec
- jsec
% 4;
1492 for (ll
= 1; ll
<= i2
; ll
+= 256)
1498 ulsec
= lsec
- lsec
% 2;
1501 for (ii
= 1; ii
<= i3
; ii
+= 256)
1507 uisec
= isec
- isec
% 2;
1508 i4
= ll
+ ulsec
- 1;
1509 for (l
= ll
; l
<= i4
; l
+= 2)
1511 i5
= ii
+ uisec
- 1;
1512 for (i
= ii
; i
<= i5
; i
+= 2)
1514 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
1516 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
1517 a
[i
+ (l
+ 1) * a_dim1
];
1518 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
1519 a
[i
+ 1 + l
* a_dim1
];
1520 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
1521 a
[i
+ 1 + (l
+ 1) * a_dim1
];
1525 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
1526 a
[ii
+ isec
- 1 + l
* a_dim1
];
1527 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
1528 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
1534 for (i
= ii
; i
<= i4
; ++i
)
1536 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
1537 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
1541 uisec
= isec
- isec
% 4;
1542 i4
= jj
+ ujsec
- 1;
1543 for (j
= jj
; j
<= i4
; j
+= 4)
1545 i5
= ii
+ uisec
- 1;
1546 for (i
= ii
; i
<= i5
; i
+= 4)
1548 f11
= c
[i
+ j
* c_dim1
];
1549 f21
= c
[i
+ 1 + j
* c_dim1
];
1550 f12
= c
[i
+ (j
+ 1) * c_dim1
];
1551 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
1552 f13
= c
[i
+ (j
+ 2) * c_dim1
];
1553 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
1554 f14
= c
[i
+ (j
+ 3) * c_dim1
];
1555 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
1556 f31
= c
[i
+ 2 + j
* c_dim1
];
1557 f41
= c
[i
+ 3 + j
* c_dim1
];
1558 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
1559 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
1560 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
1561 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
1562 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
1563 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
1565 for (l
= ll
; l
<= i6
; ++l
)
1567 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1568 * b
[l
+ j
* b_dim1
];
1569 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1570 * b
[l
+ j
* b_dim1
];
1571 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1572 * b
[l
+ (j
+ 1) * b_dim1
];
1573 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1574 * b
[l
+ (j
+ 1) * b_dim1
];
1575 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1576 * b
[l
+ (j
+ 2) * b_dim1
];
1577 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1578 * b
[l
+ (j
+ 2) * b_dim1
];
1579 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1580 * b
[l
+ (j
+ 3) * b_dim1
];
1581 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1582 * b
[l
+ (j
+ 3) * b_dim1
];
1583 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1584 * b
[l
+ j
* b_dim1
];
1585 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1586 * b
[l
+ j
* b_dim1
];
1587 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1588 * b
[l
+ (j
+ 1) * b_dim1
];
1589 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1590 * b
[l
+ (j
+ 1) * b_dim1
];
1591 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1592 * b
[l
+ (j
+ 2) * b_dim1
];
1593 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1594 * b
[l
+ (j
+ 2) * b_dim1
];
1595 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1596 * b
[l
+ (j
+ 3) * b_dim1
];
1597 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1598 * b
[l
+ (j
+ 3) * b_dim1
];
1600 c
[i
+ j
* c_dim1
] = f11
;
1601 c
[i
+ 1 + j
* c_dim1
] = f21
;
1602 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1603 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
1604 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1605 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
1606 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1607 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
1608 c
[i
+ 2 + j
* c_dim1
] = f31
;
1609 c
[i
+ 3 + j
* c_dim1
] = f41
;
1610 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
1611 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
1612 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
1613 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
1614 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
1615 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
1620 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1622 f11
= c
[i
+ j
* c_dim1
];
1623 f12
= c
[i
+ (j
+ 1) * c_dim1
];
1624 f13
= c
[i
+ (j
+ 2) * c_dim1
];
1625 f14
= c
[i
+ (j
+ 3) * c_dim1
];
1627 for (l
= ll
; l
<= i6
; ++l
)
1629 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1630 257] * b
[l
+ j
* b_dim1
];
1631 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1632 257] * b
[l
+ (j
+ 1) * b_dim1
];
1633 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1634 257] * b
[l
+ (j
+ 2) * b_dim1
];
1635 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1636 257] * b
[l
+ (j
+ 3) * b_dim1
];
1638 c
[i
+ j
* c_dim1
] = f11
;
1639 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1640 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1641 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1648 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
1650 i5
= ii
+ uisec
- 1;
1651 for (i
= ii
; i
<= i5
; i
+= 4)
1653 f11
= c
[i
+ j
* c_dim1
];
1654 f21
= c
[i
+ 1 + j
* c_dim1
];
1655 f31
= c
[i
+ 2 + j
* c_dim1
];
1656 f41
= c
[i
+ 3 + j
* c_dim1
];
1658 for (l
= ll
; l
<= i6
; ++l
)
1660 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1661 257] * b
[l
+ j
* b_dim1
];
1662 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
1663 257] * b
[l
+ j
* b_dim1
];
1664 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
1665 257] * b
[l
+ j
* b_dim1
];
1666 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
1667 257] * b
[l
+ j
* b_dim1
];
1669 c
[i
+ j
* c_dim1
] = f11
;
1670 c
[i
+ 1 + j
* c_dim1
] = f21
;
1671 c
[i
+ 2 + j
* c_dim1
] = f31
;
1672 c
[i
+ 3 + j
* c_dim1
] = f41
;
1675 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1677 f11
= c
[i
+ j
* c_dim1
];
1679 for (l
= ll
; l
<= i6
; ++l
)
1681 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1682 257] * b
[l
+ j
* b_dim1
];
1684 c
[i
+ j
* c_dim1
] = f11
;
1694 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
1696 if (GFC_DESCRIPTOR_RANK (a
) != 1)
1698 const GFC_COMPLEX_8
*restrict abase_x
;
1699 const GFC_COMPLEX_8
*restrict bbase_y
;
1700 GFC_COMPLEX_8
*restrict dest_y
;
1703 for (y
= 0; y
< ycount
; y
++)
1705 bbase_y
= &bbase
[y
*bystride
];
1706 dest_y
= &dest
[y
*rystride
];
1707 for (x
= 0; x
< xcount
; x
++)
1709 abase_x
= &abase
[x
*axstride
];
1710 s
= (GFC_COMPLEX_8
) 0;
1711 for (n
= 0; n
< count
; n
++)
1712 s
+= abase_x
[n
] * bbase_y
[n
];
1719 const GFC_COMPLEX_8
*restrict bbase_y
;
1722 for (y
= 0; y
< ycount
; y
++)
1724 bbase_y
= &bbase
[y
*bystride
];
1725 s
= (GFC_COMPLEX_8
) 0;
1726 for (n
= 0; n
< count
; n
++)
1727 s
+= abase
[n
*axstride
] * bbase_y
[n
];
1728 dest
[y
*rystride
] = s
;
1732 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
1734 const GFC_COMPLEX_8
*restrict bbase_y
;
1737 for (y
= 0; y
< ycount
; y
++)
1739 bbase_y
= &bbase
[y
*bystride
];
1740 s
= (GFC_COMPLEX_8
) 0;
1741 for (n
= 0; n
< count
; n
++)
1742 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
1743 dest
[y
*rxstride
] = s
;
1746 else if (axstride
< aystride
)
1748 for (y
= 0; y
< ycount
; y
++)
1749 for (x
= 0; x
< xcount
; x
++)
1750 dest
[x
*rxstride
+ y
*rystride
] = (GFC_COMPLEX_8
)0;
1752 for (y
= 0; y
< ycount
; y
++)
1753 for (n
= 0; n
< count
; n
++)
1754 for (x
= 0; x
< xcount
; x
++)
1755 /* dest[x,y] += a[x,n] * b[n,y] */
1756 dest
[x
*rxstride
+ y
*rystride
] +=
1757 abase
[x
*axstride
+ n
*aystride
] *
1758 bbase
[n
*bxstride
+ y
*bystride
];
1762 const GFC_COMPLEX_8
*restrict abase_x
;
1763 const GFC_COMPLEX_8
*restrict bbase_y
;
1764 GFC_COMPLEX_8
*restrict dest_y
;
1767 for (y
= 0; y
< ycount
; y
++)
1769 bbase_y
= &bbase
[y
*bystride
];
1770 dest_y
= &dest
[y
*rystride
];
1771 for (x
= 0; x
< xcount
; x
++)
1773 abase_x
= &abase
[x
*axstride
];
1774 s
= (GFC_COMPLEX_8
) 0;
1775 for (n
= 0; n
< count
; n
++)
1776 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
1777 dest_y
[x
*rxstride
] = s
;
1786 #endif /* HAVE_AVX512F */
1788 /* AMD-specifix funtions with AVX128 and FMA3/FMA4. */
1790 #if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
1792 matmul_c8_avx128_fma3 (gfc_array_c8
* const restrict retarray
,
1793 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
1794 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx,fma")));
1795 internal_proto(matmul_c8_avx128_fma3
);
1798 #if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
1800 matmul_c8_avx128_fma4 (gfc_array_c8
* const restrict retarray
,
1801 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
1802 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx,fma4")));
1803 internal_proto(matmul_c8_avx128_fma4
);
1806 /* Function to fall back to if there is no special processor-specific version. */
1808 matmul_c8_vanilla (gfc_array_c8
* const restrict retarray
,
1809 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
1810 int blas_limit
, blas_call gemm
)
1812 const GFC_COMPLEX_8
* restrict abase
;
1813 const GFC_COMPLEX_8
* restrict bbase
;
1814 GFC_COMPLEX_8
* restrict dest
;
1816 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
1817 index_type x
, y
, n
, count
, xcount
, ycount
;
1819 assert (GFC_DESCRIPTOR_RANK (a
) == 2
1820 || GFC_DESCRIPTOR_RANK (b
) == 2);
1822 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
1824 Either A or B (but not both) can be rank 1:
1826 o One-dimensional argument A is implicitly treated as a row matrix
1827 dimensioned [1,count], so xcount=1.
1829 o One-dimensional argument B is implicitly treated as a column matrix
1830 dimensioned [count, 1], so ycount=1.
1833 if (retarray
->base_addr
== NULL
)
1835 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1837 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1838 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
1840 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
1842 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1843 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
1847 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1848 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
1850 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
1851 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
1852 GFC_DESCRIPTOR_EXTENT(retarray
,0));
1856 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_COMPLEX_8
));
1857 retarray
->offset
= 0;
1859 else if (unlikely (compile_options
.bounds_check
))
1861 index_type ret_extent
, arg_extent
;
1863 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1865 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
1866 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1867 if (arg_extent
!= ret_extent
)
1868 runtime_error ("Array bound mismatch for dimension 1 of "
1870 (long int) ret_extent
, (long int) arg_extent
);
1872 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
1874 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
1875 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1876 if (arg_extent
!= ret_extent
)
1877 runtime_error ("Array bound mismatch for dimension 1 of "
1879 (long int) ret_extent
, (long int) arg_extent
);
1883 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
1884 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1885 if (arg_extent
!= ret_extent
)
1886 runtime_error ("Array bound mismatch for dimension 1 of "
1888 (long int) ret_extent
, (long int) arg_extent
);
1890 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
1891 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
1892 if (arg_extent
!= ret_extent
)
1893 runtime_error ("Array bound mismatch for dimension 2 of "
1895 (long int) ret_extent
, (long int) arg_extent
);
1900 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
1902 /* One-dimensional result may be addressed in the code below
1903 either as a row or a column matrix. We want both cases to
1905 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
1909 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
1910 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
1914 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1916 /* Treat it as a a row matrix A[1,count]. */
1917 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
1921 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
1925 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
1926 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
1928 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
1929 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
1932 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
1934 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
1935 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
1936 "in dimension 1: is %ld, should be %ld",
1937 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
1940 if (GFC_DESCRIPTOR_RANK (b
) == 1)
1942 /* Treat it as a column matrix B[count,1] */
1943 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
1945 /* bystride should never be used for 1-dimensional b.
1946 The value is only used for calculation of the
1947 memory by the buffer. */
1953 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
1954 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
1955 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
1958 abase
= a
->base_addr
;
1959 bbase
= b
->base_addr
;
1960 dest
= retarray
->base_addr
;
1962 /* Now that everything is set up, we perform the multiplication
1965 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
1966 #define min(a,b) ((a) <= (b) ? (a) : (b))
1967 #define max(a,b) ((a) >= (b) ? (a) : (b))
1969 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
1970 && (bxstride
== 1 || bystride
== 1)
1971 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
1972 > POW3(blas_limit
)))
1974 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
1975 const GFC_COMPLEX_8 one
= 1, zero
= 0;
1976 const int lda
= (axstride
== 1) ? aystride
: axstride
,
1977 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
1979 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
1981 assert (gemm
!= NULL
);
1982 const char *transa
, *transb
;
1986 transa
= axstride
== 1 ? "N" : "T";
1991 transb
= bxstride
== 1 ? "N" : "T";
1993 gemm (transa
, transb
, &m
,
1994 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
2000 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1
2001 && GFC_DESCRIPTOR_RANK (b
) != 1)
2003 /* This block of code implements a tuned matmul, derived from
2004 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
2006 Bo Kagstrom and Per Ling
2007 Department of Computing Science
2009 S-901 87 Umea, Sweden
2011 from netlib.org, translated to C, and modified for matmul.m4. */
2013 const GFC_COMPLEX_8
*a
, *b
;
2015 const index_type m
= xcount
, n
= ycount
, k
= count
;
2017 /* System generated locals */
2018 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
2019 i1
, i2
, i3
, i4
, i5
, i6
;
2021 /* Local variables */
2022 GFC_COMPLEX_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
2023 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
2024 index_type i
, j
, l
, ii
, jj
, ll
;
2025 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
2030 c
= retarray
->base_addr
;
2032 /* Parameter adjustments */
2034 c_offset
= 1 + c_dim1
;
2037 a_offset
= 1 + a_dim1
;
2040 b_offset
= 1 + b_dim1
;
2043 /* Empty c first. */
2044 for (j
=1; j
<=n
; j
++)
2045 for (i
=1; i
<=m
; i
++)
2046 c
[i
+ j
* c_dim1
] = (GFC_COMPLEX_8
)0;
2048 /* Early exit if possible */
2049 if (m
== 0 || n
== 0 || k
== 0)
2052 /* Adjust size of t1 to what is needed. */
2053 index_type t1_dim
, a_sz
;
2059 t1_dim
= a_sz
* 256 + b_dim1
;
2063 t1
= malloc (t1_dim
* sizeof(GFC_COMPLEX_8
));
2065 /* Start turning the crank. */
2067 for (jj
= 1; jj
<= i1
; jj
+= 512)
2073 ujsec
= jsec
- jsec
% 4;
2075 for (ll
= 1; ll
<= i2
; ll
+= 256)
2081 ulsec
= lsec
- lsec
% 2;
2084 for (ii
= 1; ii
<= i3
; ii
+= 256)
2090 uisec
= isec
- isec
% 2;
2091 i4
= ll
+ ulsec
- 1;
2092 for (l
= ll
; l
<= i4
; l
+= 2)
2094 i5
= ii
+ uisec
- 1;
2095 for (i
= ii
; i
<= i5
; i
+= 2)
2097 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
2099 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
2100 a
[i
+ (l
+ 1) * a_dim1
];
2101 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
2102 a
[i
+ 1 + l
* a_dim1
];
2103 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
2104 a
[i
+ 1 + (l
+ 1) * a_dim1
];
2108 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
2109 a
[ii
+ isec
- 1 + l
* a_dim1
];
2110 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
2111 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
2117 for (i
= ii
; i
<= i4
; ++i
)
2119 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
2120 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
2124 uisec
= isec
- isec
% 4;
2125 i4
= jj
+ ujsec
- 1;
2126 for (j
= jj
; j
<= i4
; j
+= 4)
2128 i5
= ii
+ uisec
- 1;
2129 for (i
= ii
; i
<= i5
; i
+= 4)
2131 f11
= c
[i
+ j
* c_dim1
];
2132 f21
= c
[i
+ 1 + j
* c_dim1
];
2133 f12
= c
[i
+ (j
+ 1) * c_dim1
];
2134 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
2135 f13
= c
[i
+ (j
+ 2) * c_dim1
];
2136 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
2137 f14
= c
[i
+ (j
+ 3) * c_dim1
];
2138 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
2139 f31
= c
[i
+ 2 + j
* c_dim1
];
2140 f41
= c
[i
+ 3 + j
* c_dim1
];
2141 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
2142 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
2143 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
2144 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
2145 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
2146 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
2148 for (l
= ll
; l
<= i6
; ++l
)
2150 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2151 * b
[l
+ j
* b_dim1
];
2152 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2153 * b
[l
+ j
* b_dim1
];
2154 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2155 * b
[l
+ (j
+ 1) * b_dim1
];
2156 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2157 * b
[l
+ (j
+ 1) * b_dim1
];
2158 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2159 * b
[l
+ (j
+ 2) * b_dim1
];
2160 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2161 * b
[l
+ (j
+ 2) * b_dim1
];
2162 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2163 * b
[l
+ (j
+ 3) * b_dim1
];
2164 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2165 * b
[l
+ (j
+ 3) * b_dim1
];
2166 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2167 * b
[l
+ j
* b_dim1
];
2168 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2169 * b
[l
+ j
* b_dim1
];
2170 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2171 * b
[l
+ (j
+ 1) * b_dim1
];
2172 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2173 * b
[l
+ (j
+ 1) * b_dim1
];
2174 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2175 * b
[l
+ (j
+ 2) * b_dim1
];
2176 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2177 * b
[l
+ (j
+ 2) * b_dim1
];
2178 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2179 * b
[l
+ (j
+ 3) * b_dim1
];
2180 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2181 * b
[l
+ (j
+ 3) * b_dim1
];
2183 c
[i
+ j
* c_dim1
] = f11
;
2184 c
[i
+ 1 + j
* c_dim1
] = f21
;
2185 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
2186 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
2187 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
2188 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
2189 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
2190 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
2191 c
[i
+ 2 + j
* c_dim1
] = f31
;
2192 c
[i
+ 3 + j
* c_dim1
] = f41
;
2193 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
2194 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
2195 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
2196 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
2197 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
2198 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
2203 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
2205 f11
= c
[i
+ j
* c_dim1
];
2206 f12
= c
[i
+ (j
+ 1) * c_dim1
];
2207 f13
= c
[i
+ (j
+ 2) * c_dim1
];
2208 f14
= c
[i
+ (j
+ 3) * c_dim1
];
2210 for (l
= ll
; l
<= i6
; ++l
)
2212 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2213 257] * b
[l
+ j
* b_dim1
];
2214 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2215 257] * b
[l
+ (j
+ 1) * b_dim1
];
2216 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2217 257] * b
[l
+ (j
+ 2) * b_dim1
];
2218 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2219 257] * b
[l
+ (j
+ 3) * b_dim1
];
2221 c
[i
+ j
* c_dim1
] = f11
;
2222 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
2223 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
2224 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
2231 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
2233 i5
= ii
+ uisec
- 1;
2234 for (i
= ii
; i
<= i5
; i
+= 4)
2236 f11
= c
[i
+ j
* c_dim1
];
2237 f21
= c
[i
+ 1 + j
* c_dim1
];
2238 f31
= c
[i
+ 2 + j
* c_dim1
];
2239 f41
= c
[i
+ 3 + j
* c_dim1
];
2241 for (l
= ll
; l
<= i6
; ++l
)
2243 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2244 257] * b
[l
+ j
* b_dim1
];
2245 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
2246 257] * b
[l
+ j
* b_dim1
];
2247 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
2248 257] * b
[l
+ j
* b_dim1
];
2249 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
2250 257] * b
[l
+ j
* b_dim1
];
2252 c
[i
+ j
* c_dim1
] = f11
;
2253 c
[i
+ 1 + j
* c_dim1
] = f21
;
2254 c
[i
+ 2 + j
* c_dim1
] = f31
;
2255 c
[i
+ 3 + j
* c_dim1
] = f41
;
2258 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
2260 f11
= c
[i
+ j
* c_dim1
];
2262 for (l
= ll
; l
<= i6
; ++l
)
2264 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2265 257] * b
[l
+ j
* b_dim1
];
2267 c
[i
+ j
* c_dim1
] = f11
;
2277 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
2279 if (GFC_DESCRIPTOR_RANK (a
) != 1)
2281 const GFC_COMPLEX_8
*restrict abase_x
;
2282 const GFC_COMPLEX_8
*restrict bbase_y
;
2283 GFC_COMPLEX_8
*restrict dest_y
;
2286 for (y
= 0; y
< ycount
; y
++)
2288 bbase_y
= &bbase
[y
*bystride
];
2289 dest_y
= &dest
[y
*rystride
];
2290 for (x
= 0; x
< xcount
; x
++)
2292 abase_x
= &abase
[x
*axstride
];
2293 s
= (GFC_COMPLEX_8
) 0;
2294 for (n
= 0; n
< count
; n
++)
2295 s
+= abase_x
[n
] * bbase_y
[n
];
2302 const GFC_COMPLEX_8
*restrict bbase_y
;
2305 for (y
= 0; y
< ycount
; y
++)
2307 bbase_y
= &bbase
[y
*bystride
];
2308 s
= (GFC_COMPLEX_8
) 0;
2309 for (n
= 0; n
< count
; n
++)
2310 s
+= abase
[n
*axstride
] * bbase_y
[n
];
2311 dest
[y
*rystride
] = s
;
2315 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
2317 const GFC_COMPLEX_8
*restrict bbase_y
;
2320 for (y
= 0; y
< ycount
; y
++)
2322 bbase_y
= &bbase
[y
*bystride
];
2323 s
= (GFC_COMPLEX_8
) 0;
2324 for (n
= 0; n
< count
; n
++)
2325 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
2326 dest
[y
*rxstride
] = s
;
2329 else if (axstride
< aystride
)
2331 for (y
= 0; y
< ycount
; y
++)
2332 for (x
= 0; x
< xcount
; x
++)
2333 dest
[x
*rxstride
+ y
*rystride
] = (GFC_COMPLEX_8
)0;
2335 for (y
= 0; y
< ycount
; y
++)
2336 for (n
= 0; n
< count
; n
++)
2337 for (x
= 0; x
< xcount
; x
++)
2338 /* dest[x,y] += a[x,n] * b[n,y] */
2339 dest
[x
*rxstride
+ y
*rystride
] +=
2340 abase
[x
*axstride
+ n
*aystride
] *
2341 bbase
[n
*bxstride
+ y
*bystride
];
2345 const GFC_COMPLEX_8
*restrict abase_x
;
2346 const GFC_COMPLEX_8
*restrict bbase_y
;
2347 GFC_COMPLEX_8
*restrict dest_y
;
2350 for (y
= 0; y
< ycount
; y
++)
2352 bbase_y
= &bbase
[y
*bystride
];
2353 dest_y
= &dest
[y
*rystride
];
2354 for (x
= 0; x
< xcount
; x
++)
2356 abase_x
= &abase
[x
*axstride
];
2357 s
= (GFC_COMPLEX_8
) 0;
2358 for (n
= 0; n
< count
; n
++)
2359 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
2360 dest_y
[x
*rxstride
] = s
;
2370 /* Compiling main function, with selection code for the processor. */
2372 /* Currently, this is i386 only. Adjust for other architectures. */
2374 void matmul_c8 (gfc_array_c8
* const restrict retarray
,
2375 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
2376 int blas_limit
, blas_call gemm
)
2378 static void (*matmul_p
) (gfc_array_c8
* const restrict retarray
,
2379 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
2380 int blas_limit
, blas_call gemm
);
2382 void (*matmul_fn
) (gfc_array_c8
* const restrict retarray
,
2383 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
2384 int blas_limit
, blas_call gemm
);
2386 matmul_fn
= __atomic_load_n (&matmul_p
, __ATOMIC_RELAXED
);
2387 if (matmul_fn
== NULL
)
2389 matmul_fn
= matmul_c8_vanilla
;
2390 if (__builtin_cpu_is ("intel"))
2392 /* Run down the available processors in order of preference. */
2394 if (__builtin_cpu_supports ("avx512f"))
2396 matmul_fn
= matmul_c8_avx512f
;
2400 #endif /* HAVE_AVX512F */
2403 if (__builtin_cpu_supports ("avx2")
2404 && __builtin_cpu_supports ("fma"))
2406 matmul_fn
= matmul_c8_avx2
;
2413 if (__builtin_cpu_supports ("avx"))
2415 matmul_fn
= matmul_c8_avx
;
2418 #endif /* HAVE_AVX */
2420 else if (__builtin_cpu_is ("amd"))
2422 #if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
2423 if (__builtin_cpu_supports ("avx")
2424 && __builtin_cpu_supports ("fma"))
2426 matmul_fn
= matmul_c8_avx128_fma3
;
2430 #if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
2431 if (__builtin_cpu_supports ("avx")
2432 && __builtin_cpu_supports ("fma4"))
2434 matmul_fn
= matmul_c8_avx128_fma4
;
2441 __atomic_store_n (&matmul_p
, matmul_fn
, __ATOMIC_RELAXED
);
2444 (*matmul_fn
) (retarray
, a
, b
, try_blas
, blas_limit
, gemm
);
2447 #else /* Just the vanilla function. */
2450 matmul_c8 (gfc_array_c8
* const restrict retarray
,
2451 gfc_array_c8
* const restrict a
, gfc_array_c8
* const restrict b
, int try_blas
,
2452 int blas_limit
, blas_call gemm
)
2454 const GFC_COMPLEX_8
* restrict abase
;
2455 const GFC_COMPLEX_8
* restrict bbase
;
2456 GFC_COMPLEX_8
* restrict dest
;
2458 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
2459 index_type x
, y
, n
, count
, xcount
, ycount
;
2461 assert (GFC_DESCRIPTOR_RANK (a
) == 2
2462 || GFC_DESCRIPTOR_RANK (b
) == 2);
2464 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
2466 Either A or B (but not both) can be rank 1:
2468 o One-dimensional argument A is implicitly treated as a row matrix
2469 dimensioned [1,count], so xcount=1.
2471 o One-dimensional argument B is implicitly treated as a column matrix
2472 dimensioned [count, 1], so ycount=1.
2475 if (retarray
->base_addr
== NULL
)
2477 if (GFC_DESCRIPTOR_RANK (a
) == 1)
2479 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
2480 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
2482 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
2484 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
2485 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
2489 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
2490 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
2492 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
2493 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
2494 GFC_DESCRIPTOR_EXTENT(retarray
,0));
2498 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_COMPLEX_8
));
2499 retarray
->offset
= 0;
2501 else if (unlikely (compile_options
.bounds_check
))
2503 index_type ret_extent
, arg_extent
;
2505 if (GFC_DESCRIPTOR_RANK (a
) == 1)
2507 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
2508 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
2509 if (arg_extent
!= ret_extent
)
2510 runtime_error ("Array bound mismatch for dimension 1 of "
2512 (long int) ret_extent
, (long int) arg_extent
);
2514 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
2516 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
2517 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
2518 if (arg_extent
!= ret_extent
)
2519 runtime_error ("Array bound mismatch for dimension 1 of "
2521 (long int) ret_extent
, (long int) arg_extent
);
2525 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
2526 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
2527 if (arg_extent
!= ret_extent
)
2528 runtime_error ("Array bound mismatch for dimension 1 of "
2530 (long int) ret_extent
, (long int) arg_extent
);
2532 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
2533 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
2534 if (arg_extent
!= ret_extent
)
2535 runtime_error ("Array bound mismatch for dimension 2 of "
2537 (long int) ret_extent
, (long int) arg_extent
);
2542 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
2544 /* One-dimensional result may be addressed in the code below
2545 either as a row or a column matrix. We want both cases to
2547 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
2551 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
2552 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
2556 if (GFC_DESCRIPTOR_RANK (a
) == 1)
2558 /* Treat it as a a row matrix A[1,count]. */
2559 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
2563 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
2567 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
2568 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
2570 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
2571 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
2574 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
2576 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
2577 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
2578 "in dimension 1: is %ld, should be %ld",
2579 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
2582 if (GFC_DESCRIPTOR_RANK (b
) == 1)
2584 /* Treat it as a column matrix B[count,1] */
2585 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
2587 /* bystride should never be used for 1-dimensional b.
2588 The value is only used for calculation of the
2589 memory by the buffer. */
2595 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
2596 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
2597 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
2600 abase
= a
->base_addr
;
2601 bbase
= b
->base_addr
;
2602 dest
= retarray
->base_addr
;
2604 /* Now that everything is set up, we perform the multiplication
2607 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
2608 #define min(a,b) ((a) <= (b) ? (a) : (b))
2609 #define max(a,b) ((a) >= (b) ? (a) : (b))
2611 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
2612 && (bxstride
== 1 || bystride
== 1)
2613 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
2614 > POW3(blas_limit
)))
2616 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
2617 const GFC_COMPLEX_8 one
= 1, zero
= 0;
2618 const int lda
= (axstride
== 1) ? aystride
: axstride
,
2619 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
2621 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
2623 assert (gemm
!= NULL
);
2624 const char *transa
, *transb
;
2628 transa
= axstride
== 1 ? "N" : "T";
2633 transb
= bxstride
== 1 ? "N" : "T";
2635 gemm (transa
, transb
, &m
,
2636 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
2642 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1
2643 && GFC_DESCRIPTOR_RANK (b
) != 1)
2645 /* This block of code implements a tuned matmul, derived from
2646 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
2648 Bo Kagstrom and Per Ling
2649 Department of Computing Science
2651 S-901 87 Umea, Sweden
2653 from netlib.org, translated to C, and modified for matmul.m4. */
2655 const GFC_COMPLEX_8
*a
, *b
;
2657 const index_type m
= xcount
, n
= ycount
, k
= count
;
2659 /* System generated locals */
2660 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
2661 i1
, i2
, i3
, i4
, i5
, i6
;
2663 /* Local variables */
2664 GFC_COMPLEX_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
2665 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
2666 index_type i
, j
, l
, ii
, jj
, ll
;
2667 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
2672 c
= retarray
->base_addr
;
2674 /* Parameter adjustments */
2676 c_offset
= 1 + c_dim1
;
2679 a_offset
= 1 + a_dim1
;
2682 b_offset
= 1 + b_dim1
;
2685 /* Empty c first. */
2686 for (j
=1; j
<=n
; j
++)
2687 for (i
=1; i
<=m
; i
++)
2688 c
[i
+ j
* c_dim1
] = (GFC_COMPLEX_8
)0;
2690 /* Early exit if possible */
2691 if (m
== 0 || n
== 0 || k
== 0)
2694 /* Adjust size of t1 to what is needed. */
2695 index_type t1_dim
, a_sz
;
2701 t1_dim
= a_sz
* 256 + b_dim1
;
2705 t1
= malloc (t1_dim
* sizeof(GFC_COMPLEX_8
));
2707 /* Start turning the crank. */
2709 for (jj
= 1; jj
<= i1
; jj
+= 512)
2715 ujsec
= jsec
- jsec
% 4;
2717 for (ll
= 1; ll
<= i2
; ll
+= 256)
2723 ulsec
= lsec
- lsec
% 2;
2726 for (ii
= 1; ii
<= i3
; ii
+= 256)
2732 uisec
= isec
- isec
% 2;
2733 i4
= ll
+ ulsec
- 1;
2734 for (l
= ll
; l
<= i4
; l
+= 2)
2736 i5
= ii
+ uisec
- 1;
2737 for (i
= ii
; i
<= i5
; i
+= 2)
2739 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
2741 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
2742 a
[i
+ (l
+ 1) * a_dim1
];
2743 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
2744 a
[i
+ 1 + l
* a_dim1
];
2745 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
2746 a
[i
+ 1 + (l
+ 1) * a_dim1
];
2750 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
2751 a
[ii
+ isec
- 1 + l
* a_dim1
];
2752 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
2753 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
2759 for (i
= ii
; i
<= i4
; ++i
)
2761 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
2762 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
2766 uisec
= isec
- isec
% 4;
2767 i4
= jj
+ ujsec
- 1;
2768 for (j
= jj
; j
<= i4
; j
+= 4)
2770 i5
= ii
+ uisec
- 1;
2771 for (i
= ii
; i
<= i5
; i
+= 4)
2773 f11
= c
[i
+ j
* c_dim1
];
2774 f21
= c
[i
+ 1 + j
* c_dim1
];
2775 f12
= c
[i
+ (j
+ 1) * c_dim1
];
2776 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
2777 f13
= c
[i
+ (j
+ 2) * c_dim1
];
2778 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
2779 f14
= c
[i
+ (j
+ 3) * c_dim1
];
2780 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
2781 f31
= c
[i
+ 2 + j
* c_dim1
];
2782 f41
= c
[i
+ 3 + j
* c_dim1
];
2783 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
2784 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
2785 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
2786 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
2787 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
2788 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
2790 for (l
= ll
; l
<= i6
; ++l
)
2792 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2793 * b
[l
+ j
* b_dim1
];
2794 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2795 * b
[l
+ j
* b_dim1
];
2796 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2797 * b
[l
+ (j
+ 1) * b_dim1
];
2798 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2799 * b
[l
+ (j
+ 1) * b_dim1
];
2800 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2801 * b
[l
+ (j
+ 2) * b_dim1
];
2802 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2803 * b
[l
+ (j
+ 2) * b_dim1
];
2804 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
2805 * b
[l
+ (j
+ 3) * b_dim1
];
2806 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
2807 * b
[l
+ (j
+ 3) * b_dim1
];
2808 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2809 * b
[l
+ j
* b_dim1
];
2810 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2811 * b
[l
+ j
* b_dim1
];
2812 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2813 * b
[l
+ (j
+ 1) * b_dim1
];
2814 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2815 * b
[l
+ (j
+ 1) * b_dim1
];
2816 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2817 * b
[l
+ (j
+ 2) * b_dim1
];
2818 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2819 * b
[l
+ (j
+ 2) * b_dim1
];
2820 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
2821 * b
[l
+ (j
+ 3) * b_dim1
];
2822 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
2823 * b
[l
+ (j
+ 3) * b_dim1
];
2825 c
[i
+ j
* c_dim1
] = f11
;
2826 c
[i
+ 1 + j
* c_dim1
] = f21
;
2827 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
2828 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
2829 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
2830 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
2831 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
2832 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
2833 c
[i
+ 2 + j
* c_dim1
] = f31
;
2834 c
[i
+ 3 + j
* c_dim1
] = f41
;
2835 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
2836 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
2837 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
2838 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
2839 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
2840 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
2845 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
2847 f11
= c
[i
+ j
* c_dim1
];
2848 f12
= c
[i
+ (j
+ 1) * c_dim1
];
2849 f13
= c
[i
+ (j
+ 2) * c_dim1
];
2850 f14
= c
[i
+ (j
+ 3) * c_dim1
];
2852 for (l
= ll
; l
<= i6
; ++l
)
2854 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2855 257] * b
[l
+ j
* b_dim1
];
2856 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2857 257] * b
[l
+ (j
+ 1) * b_dim1
];
2858 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2859 257] * b
[l
+ (j
+ 2) * b_dim1
];
2860 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2861 257] * b
[l
+ (j
+ 3) * b_dim1
];
2863 c
[i
+ j
* c_dim1
] = f11
;
2864 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
2865 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
2866 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
2873 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
2875 i5
= ii
+ uisec
- 1;
2876 for (i
= ii
; i
<= i5
; i
+= 4)
2878 f11
= c
[i
+ j
* c_dim1
];
2879 f21
= c
[i
+ 1 + j
* c_dim1
];
2880 f31
= c
[i
+ 2 + j
* c_dim1
];
2881 f41
= c
[i
+ 3 + j
* c_dim1
];
2883 for (l
= ll
; l
<= i6
; ++l
)
2885 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2886 257] * b
[l
+ j
* b_dim1
];
2887 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
2888 257] * b
[l
+ j
* b_dim1
];
2889 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
2890 257] * b
[l
+ j
* b_dim1
];
2891 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
2892 257] * b
[l
+ j
* b_dim1
];
2894 c
[i
+ j
* c_dim1
] = f11
;
2895 c
[i
+ 1 + j
* c_dim1
] = f21
;
2896 c
[i
+ 2 + j
* c_dim1
] = f31
;
2897 c
[i
+ 3 + j
* c_dim1
] = f41
;
2900 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
2902 f11
= c
[i
+ j
* c_dim1
];
2904 for (l
= ll
; l
<= i6
; ++l
)
2906 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
2907 257] * b
[l
+ j
* b_dim1
];
2909 c
[i
+ j
* c_dim1
] = f11
;
2919 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
2921 if (GFC_DESCRIPTOR_RANK (a
) != 1)
2923 const GFC_COMPLEX_8
*restrict abase_x
;
2924 const GFC_COMPLEX_8
*restrict bbase_y
;
2925 GFC_COMPLEX_8
*restrict dest_y
;
2928 for (y
= 0; y
< ycount
; y
++)
2930 bbase_y
= &bbase
[y
*bystride
];
2931 dest_y
= &dest
[y
*rystride
];
2932 for (x
= 0; x
< xcount
; x
++)
2934 abase_x
= &abase
[x
*axstride
];
2935 s
= (GFC_COMPLEX_8
) 0;
2936 for (n
= 0; n
< count
; n
++)
2937 s
+= abase_x
[n
] * bbase_y
[n
];
2944 const GFC_COMPLEX_8
*restrict bbase_y
;
2947 for (y
= 0; y
< ycount
; y
++)
2949 bbase_y
= &bbase
[y
*bystride
];
2950 s
= (GFC_COMPLEX_8
) 0;
2951 for (n
= 0; n
< count
; n
++)
2952 s
+= abase
[n
*axstride
] * bbase_y
[n
];
2953 dest
[y
*rystride
] = s
;
2957 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
2959 const GFC_COMPLEX_8
*restrict bbase_y
;
2962 for (y
= 0; y
< ycount
; y
++)
2964 bbase_y
= &bbase
[y
*bystride
];
2965 s
= (GFC_COMPLEX_8
) 0;
2966 for (n
= 0; n
< count
; n
++)
2967 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
2968 dest
[y
*rxstride
] = s
;
2971 else if (axstride
< aystride
)
2973 for (y
= 0; y
< ycount
; y
++)
2974 for (x
= 0; x
< xcount
; x
++)
2975 dest
[x
*rxstride
+ y
*rystride
] = (GFC_COMPLEX_8
)0;
2977 for (y
= 0; y
< ycount
; y
++)
2978 for (n
= 0; n
< count
; n
++)
2979 for (x
= 0; x
< xcount
; x
++)
2980 /* dest[x,y] += a[x,n] * b[n,y] */
2981 dest
[x
*rxstride
+ y
*rystride
] +=
2982 abase
[x
*axstride
+ n
*aystride
] *
2983 bbase
[n
*bxstride
+ y
*bystride
];
2987 const GFC_COMPLEX_8
*restrict abase_x
;
2988 const GFC_COMPLEX_8
*restrict bbase_y
;
2989 GFC_COMPLEX_8
*restrict dest_y
;
2992 for (y
= 0; y
< ycount
; y
++)
2994 bbase_y
= &bbase
[y
*bystride
];
2995 dest_y
= &dest
[y
*rystride
];
2996 for (x
= 0; x
< xcount
; x
++)
2998 abase_x
= &abase
[x
*axstride
];
2999 s
= (GFC_COMPLEX_8
) 0;
3000 for (n
= 0; n
< count
; n
++)
3001 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
3002 dest_y
[x
*rxstride
] = s
;