1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2023 Free Software Foundation, Inc.
3 Contributed by Thomas Koenig <tkoenig@gcc.gnu.org>.
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
31 /* These are the specific versions of matmul with -mprefer-avx128. */
33 #if defined (HAVE_GFC_INTEGER_8)
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we call it for large matrices. */
39 typedef void (*blas_call
)(const char *, const char *, const int *, const int *,
40 const int *, const GFC_INTEGER_8
*, const GFC_INTEGER_8
*,
41 const int *, const GFC_INTEGER_8
*, const int *,
42 const GFC_INTEGER_8
*, GFC_INTEGER_8
*, const int *,
45 #if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
47 matmul_i8_avx128_fma3 (gfc_array_i8
* const restrict retarray
,
48 gfc_array_i8
* const restrict a
, gfc_array_i8
* const restrict b
, int try_blas
,
49 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx,fma")));
50 internal_proto(matmul_i8_avx128_fma3
);
52 matmul_i8_avx128_fma3 (gfc_array_i8
* const restrict retarray
,
53 gfc_array_i8
* const restrict a
, gfc_array_i8
* const restrict b
, int try_blas
,
54 int blas_limit
, blas_call gemm
)
56 const GFC_INTEGER_8
* restrict abase
;
57 const GFC_INTEGER_8
* restrict bbase
;
58 GFC_INTEGER_8
* restrict dest
;
60 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
61 index_type x
, y
, n
, count
, xcount
, ycount
;
63 assert (GFC_DESCRIPTOR_RANK (a
) == 2
64 || GFC_DESCRIPTOR_RANK (b
) == 2);
66 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
68 Either A or B (but not both) can be rank 1:
70 o One-dimensional argument A is implicitly treated as a row matrix
71 dimensioned [1,count], so xcount=1.
73 o One-dimensional argument B is implicitly treated as a column matrix
74 dimensioned [count, 1], so ycount=1. */
77 if (retarray
->base_addr
== NULL
)
79 if (GFC_DESCRIPTOR_RANK (a
) == 1)
81 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
82 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
84 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
86 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
87 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
91 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
92 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
94 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
95 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
96 GFC_DESCRIPTOR_EXTENT(retarray
,0));
100 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_INTEGER_8
));
101 retarray
->offset
= 0;
103 else if (unlikely (compile_options
.bounds_check
))
105 index_type ret_extent
, arg_extent
;
107 if (GFC_DESCRIPTOR_RANK (a
) == 1)
109 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
110 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
111 if (arg_extent
!= ret_extent
)
112 runtime_error ("Array bound mismatch for dimension 1 of "
114 (long int) ret_extent
, (long int) arg_extent
);
116 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
118 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
119 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
120 if (arg_extent
!= ret_extent
)
121 runtime_error ("Array bound mismatch for dimension 1 of "
123 (long int) ret_extent
, (long int) arg_extent
);
127 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
128 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
129 if (arg_extent
!= ret_extent
)
130 runtime_error ("Array bound mismatch for dimension 1 of "
132 (long int) ret_extent
, (long int) arg_extent
);
134 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
135 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
136 if (arg_extent
!= ret_extent
)
137 runtime_error ("Array bound mismatch for dimension 2 of "
139 (long int) ret_extent
, (long int) arg_extent
);
144 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
146 /* One-dimensional result may be addressed in the code below
147 either as a row or a column matrix. We want both cases to work. */
149 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
153 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
154 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
158 if (GFC_DESCRIPTOR_RANK (a
) == 1)
160 /* Treat it as a row matrix A[1,count]. */
161 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
165 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
169 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
170 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
172 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
173 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
176 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
178 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
179 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
180 "in dimension 1: is %ld, should be %ld",
181 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
184 if (GFC_DESCRIPTOR_RANK (b
) == 1)
186 /* Treat it as a column matrix B[count,1] */
187 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
189 /* bystride should never be used for 1-dimensional b.
190 The value is only used for calculation of the
191 memory needed by the buffer. */
197 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
198 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
199 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
202 abase
= a
->base_addr
;
203 bbase
= b
->base_addr
;
204 dest
= retarray
->base_addr
;
206 /* Now that everything is set up, we perform the multiplication itself. */
209 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
210 #define min(a,b) ((a) <= (b) ? (a) : (b))
211 #define max(a,b) ((a) >= (b) ? (a) : (b))
213 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
214 && (bxstride
== 1 || bystride
== 1)
215 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
218 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
219 const GFC_INTEGER_8 one
= 1, zero
= 0;
220 const int lda
= (axstride
== 1) ? aystride
: axstride
,
221 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
223 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
225 assert (gemm
!= NULL
);
226 const char *transa
, *transb
;
230 transa
= axstride
== 1 ? "N" : "T";
235 transb
= bxstride
== 1 ? "N" : "T";
237 gemm (transa
, transb
, &m
,
238 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
244 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1
245 && GFC_DESCRIPTOR_RANK (b
) != 1)
247 /* This block of code implements a tuned matmul, derived from
248 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
250 Bo Kagstrom and Per Ling
251 Department of Computing Science
253 S-901 87 Umea, Sweden
255 from netlib.org, translated to C, and modified for matmul.m4. */
257 const GFC_INTEGER_8
*a
, *b
;
259 const index_type m
= xcount
, n
= ycount
, k
= count
;
261 /* System generated locals */
262 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
263 i1
, i2
, i3
, i4
, i5
, i6
;
265 /* Local variables */
266 GFC_INTEGER_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
267 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
268 index_type i
, j
, l
, ii
, jj
, ll
;
269 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
274 c
= retarray
->base_addr
;
276 /* Parameter adjustments */
278 c_offset
= 1 + c_dim1
;
281 a_offset
= 1 + a_dim1
;
284 b_offset
= 1 + b_dim1
;
290 c
[i
+ j
* c_dim1
] = (GFC_INTEGER_8
)0;
292 /* Early exit if possible */
293 if (m
== 0 || n
== 0 || k
== 0)
296 /* Adjust size of t1 to what is needed. */
297 index_type t1_dim
, a_sz
;
303 t1_dim
= a_sz
* 256 + b_dim1
;
307 t1
= malloc (t1_dim
* sizeof(GFC_INTEGER_8
));
309 /* Start turning the crank. */
311 for (jj
= 1; jj
<= i1
; jj
+= 512)
317 ujsec
= jsec
- jsec
% 4;
319 for (ll
= 1; ll
<= i2
; ll
+= 256)
325 ulsec
= lsec
- lsec
% 2;
328 for (ii
= 1; ii
<= i3
; ii
+= 256)
334 uisec
= isec
- isec
% 2;
336 for (l
= ll
; l
<= i4
; l
+= 2)
339 for (i
= ii
; i
<= i5
; i
+= 2)
341 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
343 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
344 a
[i
+ (l
+ 1) * a_dim1
];
345 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
346 a
[i
+ 1 + l
* a_dim1
];
347 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
348 a
[i
+ 1 + (l
+ 1) * a_dim1
];
352 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
353 a
[ii
+ isec
- 1 + l
* a_dim1
];
354 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
355 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
361 for (i
= ii
; i
<= i4
; ++i
)
363 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
364 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
368 uisec
= isec
- isec
% 4;
370 for (j
= jj
; j
<= i4
; j
+= 4)
373 for (i
= ii
; i
<= i5
; i
+= 4)
375 f11
= c
[i
+ j
* c_dim1
];
376 f21
= c
[i
+ 1 + j
* c_dim1
];
377 f12
= c
[i
+ (j
+ 1) * c_dim1
];
378 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
379 f13
= c
[i
+ (j
+ 2) * c_dim1
];
380 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
381 f14
= c
[i
+ (j
+ 3) * c_dim1
];
382 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
383 f31
= c
[i
+ 2 + j
* c_dim1
];
384 f41
= c
[i
+ 3 + j
* c_dim1
];
385 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
386 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
387 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
388 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
389 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
390 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
392 for (l
= ll
; l
<= i6
; ++l
)
394 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
396 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
398 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
399 * b
[l
+ (j
+ 1) * b_dim1
];
400 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
401 * b
[l
+ (j
+ 1) * b_dim1
];
402 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
403 * b
[l
+ (j
+ 2) * b_dim1
];
404 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
405 * b
[l
+ (j
+ 2) * b_dim1
];
406 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
407 * b
[l
+ (j
+ 3) * b_dim1
];
408 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
409 * b
[l
+ (j
+ 3) * b_dim1
];
410 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
412 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
414 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
415 * b
[l
+ (j
+ 1) * b_dim1
];
416 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
417 * b
[l
+ (j
+ 1) * b_dim1
];
418 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
419 * b
[l
+ (j
+ 2) * b_dim1
];
420 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
421 * b
[l
+ (j
+ 2) * b_dim1
];
422 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
423 * b
[l
+ (j
+ 3) * b_dim1
];
424 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
425 * b
[l
+ (j
+ 3) * b_dim1
];
427 c
[i
+ j
* c_dim1
] = f11
;
428 c
[i
+ 1 + j
* c_dim1
] = f21
;
429 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
430 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
431 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
432 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
433 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
434 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
435 c
[i
+ 2 + j
* c_dim1
] = f31
;
436 c
[i
+ 3 + j
* c_dim1
] = f41
;
437 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
438 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
439 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
440 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
441 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
442 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
447 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
449 f11
= c
[i
+ j
* c_dim1
];
450 f12
= c
[i
+ (j
+ 1) * c_dim1
];
451 f13
= c
[i
+ (j
+ 2) * c_dim1
];
452 f14
= c
[i
+ (j
+ 3) * c_dim1
];
454 for (l
= ll
; l
<= i6
; ++l
)
456 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
457 257] * b
[l
+ j
* b_dim1
];
458 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
459 257] * b
[l
+ (j
+ 1) * b_dim1
];
460 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
461 257] * b
[l
+ (j
+ 2) * b_dim1
];
462 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
463 257] * b
[l
+ (j
+ 3) * b_dim1
];
465 c
[i
+ j
* c_dim1
] = f11
;
466 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
467 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
468 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
475 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
478 for (i
= ii
; i
<= i5
; i
+= 4)
480 f11
= c
[i
+ j
* c_dim1
];
481 f21
= c
[i
+ 1 + j
* c_dim1
];
482 f31
= c
[i
+ 2 + j
* c_dim1
];
483 f41
= c
[i
+ 3 + j
* c_dim1
];
485 for (l
= ll
; l
<= i6
; ++l
)
487 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
488 257] * b
[l
+ j
* b_dim1
];
489 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
490 257] * b
[l
+ j
* b_dim1
];
491 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
492 257] * b
[l
+ j
* b_dim1
];
493 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
494 257] * b
[l
+ j
* b_dim1
];
496 c
[i
+ j
* c_dim1
] = f11
;
497 c
[i
+ 1 + j
* c_dim1
] = f21
;
498 c
[i
+ 2 + j
* c_dim1
] = f31
;
499 c
[i
+ 3 + j
* c_dim1
] = f41
;
502 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
504 f11
= c
[i
+ j
* c_dim1
];
506 for (l
= ll
; l
<= i6
; ++l
)
508 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
509 257] * b
[l
+ j
* b_dim1
];
511 c
[i
+ j
* c_dim1
] = f11
;
521 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
523 if (GFC_DESCRIPTOR_RANK (a
) != 1)
525 const GFC_INTEGER_8
*restrict abase_x
;
526 const GFC_INTEGER_8
*restrict bbase_y
;
527 GFC_INTEGER_8
*restrict dest_y
;
530 for (y
= 0; y
< ycount
; y
++)
532 bbase_y
= &bbase
[y
*bystride
];
533 dest_y
= &dest
[y
*rystride
];
534 for (x
= 0; x
< xcount
; x
++)
536 abase_x
= &abase
[x
*axstride
];
537 s
= (GFC_INTEGER_8
) 0;
538 for (n
= 0; n
< count
; n
++)
539 s
+= abase_x
[n
] * bbase_y
[n
];
546 const GFC_INTEGER_8
*restrict bbase_y
;
549 for (y
= 0; y
< ycount
; y
++)
551 bbase_y
= &bbase
[y
*bystride
];
552 s
= (GFC_INTEGER_8
) 0;
553 for (n
= 0; n
< count
; n
++)
554 s
+= abase
[n
*axstride
] * bbase_y
[n
];
555 dest
[y
*rystride
] = s
;
559 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
561 const GFC_INTEGER_8
*restrict bbase_y
;
564 for (y
= 0; y
< ycount
; y
++)
566 bbase_y
= &bbase
[y
*bystride
];
567 s
= (GFC_INTEGER_8
) 0;
568 for (n
= 0; n
< count
; n
++)
569 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
570 dest
[y
*rxstride
] = s
;
573 else if (axstride
< aystride
)
575 for (y
= 0; y
< ycount
; y
++)
576 for (x
= 0; x
< xcount
; x
++)
577 dest
[x
*rxstride
+ y
*rystride
] = (GFC_INTEGER_8
)0;
579 for (y
= 0; y
< ycount
; y
++)
580 for (n
= 0; n
< count
; n
++)
581 for (x
= 0; x
< xcount
; x
++)
582 /* dest[x,y] += a[x,n] * b[n,y] */
583 dest
[x
*rxstride
+ y
*rystride
] +=
584 abase
[x
*axstride
+ n
*aystride
] *
585 bbase
[n
*bxstride
+ y
*bystride
];
589 const GFC_INTEGER_8
*restrict abase_x
;
590 const GFC_INTEGER_8
*restrict bbase_y
;
591 GFC_INTEGER_8
*restrict dest_y
;
594 for (y
= 0; y
< ycount
; y
++)
596 bbase_y
= &bbase
[y
*bystride
];
597 dest_y
= &dest
[y
*rystride
];
598 for (x
= 0; x
< xcount
; x
++)
600 abase_x
= &abase
[x
*axstride
];
601 s
= (GFC_INTEGER_8
) 0;
602 for (n
= 0; n
< count
; n
++)
603 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
604 dest_y
[x
*rxstride
] = s
;
615 #if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
617 matmul_i8_avx128_fma4 (gfc_array_i8
* const restrict retarray
,
618 gfc_array_i8
* const restrict a
, gfc_array_i8
* const restrict b
, int try_blas
,
619 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx,fma4")));
620 internal_proto(matmul_i8_avx128_fma4
);
622 matmul_i8_avx128_fma4 (gfc_array_i8
* const restrict retarray
,
623 gfc_array_i8
* const restrict a
, gfc_array_i8
* const restrict b
, int try_blas
,
624 int blas_limit
, blas_call gemm
)
626 const GFC_INTEGER_8
* restrict abase
;
627 const GFC_INTEGER_8
* restrict bbase
;
628 GFC_INTEGER_8
* restrict dest
;
630 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
631 index_type x
, y
, n
, count
, xcount
, ycount
;
633 assert (GFC_DESCRIPTOR_RANK (a
) == 2
634 || GFC_DESCRIPTOR_RANK (b
) == 2);
636 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
638 Either A or B (but not both) can be rank 1:
640 o One-dimensional argument A is implicitly treated as a row matrix
641 dimensioned [1,count], so xcount=1.
643 o One-dimensional argument B is implicitly treated as a column matrix
644 dimensioned [count, 1], so ycount=1. */
647 if (retarray
->base_addr
== NULL
)
649 if (GFC_DESCRIPTOR_RANK (a
) == 1)
651 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
652 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
654 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
656 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
657 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
661 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
662 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
664 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
665 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
666 GFC_DESCRIPTOR_EXTENT(retarray
,0));
670 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_INTEGER_8
));
671 retarray
->offset
= 0;
673 else if (unlikely (compile_options
.bounds_check
))
675 index_type ret_extent
, arg_extent
;
677 if (GFC_DESCRIPTOR_RANK (a
) == 1)
679 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
680 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
681 if (arg_extent
!= ret_extent
)
682 runtime_error ("Array bound mismatch for dimension 1 of "
684 (long int) ret_extent
, (long int) arg_extent
);
686 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
688 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
689 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
690 if (arg_extent
!= ret_extent
)
691 runtime_error ("Array bound mismatch for dimension 1 of "
693 (long int) ret_extent
, (long int) arg_extent
);
697 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
698 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
699 if (arg_extent
!= ret_extent
)
700 runtime_error ("Array bound mismatch for dimension 1 of "
702 (long int) ret_extent
, (long int) arg_extent
);
704 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
705 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
706 if (arg_extent
!= ret_extent
)
707 runtime_error ("Array bound mismatch for dimension 2 of "
709 (long int) ret_extent
, (long int) arg_extent
);
714 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
716 /* One-dimensional result may be addressed in the code below
717 either as a row or a column matrix. We want both cases to work. */
719 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
723 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
724 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
728 if (GFC_DESCRIPTOR_RANK (a
) == 1)
730 /* Treat it as a row matrix A[1,count]. */
731 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
735 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
739 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
740 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
742 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
743 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
746 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
748 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
749 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
750 "in dimension 1: is %ld, should be %ld",
751 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
754 if (GFC_DESCRIPTOR_RANK (b
) == 1)
756 /* Treat it as a column matrix B[count,1] */
757 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
759 /* bystride should never be used for 1-dimensional b.
760 The value is only used for calculation of the
761 memory needed by the buffer. */
767 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
768 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
769 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
772 abase
= a
->base_addr
;
773 bbase
= b
->base_addr
;
774 dest
= retarray
->base_addr
;
776 /* Now that everything is set up, we perform the multiplication itself. */
779 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
780 #define min(a,b) ((a) <= (b) ? (a) : (b))
781 #define max(a,b) ((a) >= (b) ? (a) : (b))
783 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
784 && (bxstride
== 1 || bystride
== 1)
785 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
788 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
789 const GFC_INTEGER_8 one
= 1, zero
= 0;
790 const int lda
= (axstride
== 1) ? aystride
: axstride
,
791 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
793 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
795 assert (gemm
!= NULL
);
796 const char *transa
, *transb
;
800 transa
= axstride
== 1 ? "N" : "T";
805 transb
= bxstride
== 1 ? "N" : "T";
807 gemm (transa
, transb
, &m
,
808 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
814 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1
815 && GFC_DESCRIPTOR_RANK (b
) != 1)
817 /* This block of code implements a tuned matmul, derived from
818 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
820 Bo Kagstrom and Per Ling
821 Department of Computing Science
823 S-901 87 Umea, Sweden
825 from netlib.org, translated to C, and modified for matmul.m4. */
827 const GFC_INTEGER_8
*a
, *b
;
829 const index_type m
= xcount
, n
= ycount
, k
= count
;
831 /* System generated locals */
832 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
833 i1
, i2
, i3
, i4
, i5
, i6
;
835 /* Local variables */
836 GFC_INTEGER_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
837 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
838 index_type i
, j
, l
, ii
, jj
, ll
;
839 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
844 c
= retarray
->base_addr
;
846 /* Parameter adjustments */
848 c_offset
= 1 + c_dim1
;
851 a_offset
= 1 + a_dim1
;
854 b_offset
= 1 + b_dim1
;
860 c
[i
+ j
* c_dim1
] = (GFC_INTEGER_8
)0;
862 /* Early exit if possible */
863 if (m
== 0 || n
== 0 || k
== 0)
866 /* Adjust size of t1 to what is needed. */
867 index_type t1_dim
, a_sz
;
873 t1_dim
= a_sz
* 256 + b_dim1
;
877 t1
= malloc (t1_dim
* sizeof(GFC_INTEGER_8
));
879 /* Start turning the crank. */
881 for (jj
= 1; jj
<= i1
; jj
+= 512)
887 ujsec
= jsec
- jsec
% 4;
889 for (ll
= 1; ll
<= i2
; ll
+= 256)
895 ulsec
= lsec
- lsec
% 2;
898 for (ii
= 1; ii
<= i3
; ii
+= 256)
904 uisec
= isec
- isec
% 2;
906 for (l
= ll
; l
<= i4
; l
+= 2)
909 for (i
= ii
; i
<= i5
; i
+= 2)
911 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
913 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
914 a
[i
+ (l
+ 1) * a_dim1
];
915 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
916 a
[i
+ 1 + l
* a_dim1
];
917 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
918 a
[i
+ 1 + (l
+ 1) * a_dim1
];
922 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
923 a
[ii
+ isec
- 1 + l
* a_dim1
];
924 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
925 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
931 for (i
= ii
; i
<= i4
; ++i
)
933 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
934 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
938 uisec
= isec
- isec
% 4;
940 for (j
= jj
; j
<= i4
; j
+= 4)
943 for (i
= ii
; i
<= i5
; i
+= 4)
945 f11
= c
[i
+ j
* c_dim1
];
946 f21
= c
[i
+ 1 + j
* c_dim1
];
947 f12
= c
[i
+ (j
+ 1) * c_dim1
];
948 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
949 f13
= c
[i
+ (j
+ 2) * c_dim1
];
950 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
951 f14
= c
[i
+ (j
+ 3) * c_dim1
];
952 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
953 f31
= c
[i
+ 2 + j
* c_dim1
];
954 f41
= c
[i
+ 3 + j
* c_dim1
];
955 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
956 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
957 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
958 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
959 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
960 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
962 for (l
= ll
; l
<= i6
; ++l
)
964 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
966 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
968 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
969 * b
[l
+ (j
+ 1) * b_dim1
];
970 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
971 * b
[l
+ (j
+ 1) * b_dim1
];
972 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
973 * b
[l
+ (j
+ 2) * b_dim1
];
974 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
975 * b
[l
+ (j
+ 2) * b_dim1
];
976 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
977 * b
[l
+ (j
+ 3) * b_dim1
];
978 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
979 * b
[l
+ (j
+ 3) * b_dim1
];
980 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
982 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
984 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
985 * b
[l
+ (j
+ 1) * b_dim1
];
986 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
987 * b
[l
+ (j
+ 1) * b_dim1
];
988 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
989 * b
[l
+ (j
+ 2) * b_dim1
];
990 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
991 * b
[l
+ (j
+ 2) * b_dim1
];
992 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
993 * b
[l
+ (j
+ 3) * b_dim1
];
994 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
995 * b
[l
+ (j
+ 3) * b_dim1
];
997 c
[i
+ j
* c_dim1
] = f11
;
998 c
[i
+ 1 + j
* c_dim1
] = f21
;
999 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1000 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
1001 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1002 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
1003 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1004 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
1005 c
[i
+ 2 + j
* c_dim1
] = f31
;
1006 c
[i
+ 3 + j
* c_dim1
] = f41
;
1007 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
1008 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
1009 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
1010 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
1011 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
1012 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
1017 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1019 f11
= c
[i
+ j
* c_dim1
];
1020 f12
= c
[i
+ (j
+ 1) * c_dim1
];
1021 f13
= c
[i
+ (j
+ 2) * c_dim1
];
1022 f14
= c
[i
+ (j
+ 3) * c_dim1
];
1024 for (l
= ll
; l
<= i6
; ++l
)
1026 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1027 257] * b
[l
+ j
* b_dim1
];
1028 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1029 257] * b
[l
+ (j
+ 1) * b_dim1
];
1030 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1031 257] * b
[l
+ (j
+ 2) * b_dim1
];
1032 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1033 257] * b
[l
+ (j
+ 3) * b_dim1
];
1035 c
[i
+ j
* c_dim1
] = f11
;
1036 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1037 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1038 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1045 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
1047 i5
= ii
+ uisec
- 1;
1048 for (i
= ii
; i
<= i5
; i
+= 4)
1050 f11
= c
[i
+ j
* c_dim1
];
1051 f21
= c
[i
+ 1 + j
* c_dim1
];
1052 f31
= c
[i
+ 2 + j
* c_dim1
];
1053 f41
= c
[i
+ 3 + j
* c_dim1
];
1055 for (l
= ll
; l
<= i6
; ++l
)
1057 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1058 257] * b
[l
+ j
* b_dim1
];
1059 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
1060 257] * b
[l
+ j
* b_dim1
];
1061 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
1062 257] * b
[l
+ j
* b_dim1
];
1063 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
1064 257] * b
[l
+ j
* b_dim1
];
1066 c
[i
+ j
* c_dim1
] = f11
;
1067 c
[i
+ 1 + j
* c_dim1
] = f21
;
1068 c
[i
+ 2 + j
* c_dim1
] = f31
;
1069 c
[i
+ 3 + j
* c_dim1
] = f41
;
1072 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1074 f11
= c
[i
+ j
* c_dim1
];
1076 for (l
= ll
; l
<= i6
; ++l
)
1078 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1079 257] * b
[l
+ j
* b_dim1
];
1081 c
[i
+ j
* c_dim1
] = f11
;
1091 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
1093 if (GFC_DESCRIPTOR_RANK (a
) != 1)
1095 const GFC_INTEGER_8
*restrict abase_x
;
1096 const GFC_INTEGER_8
*restrict bbase_y
;
1097 GFC_INTEGER_8
*restrict dest_y
;
1100 for (y
= 0; y
< ycount
; y
++)
1102 bbase_y
= &bbase
[y
*bystride
];
1103 dest_y
= &dest
[y
*rystride
];
1104 for (x
= 0; x
< xcount
; x
++)
1106 abase_x
= &abase
[x
*axstride
];
1107 s
= (GFC_INTEGER_8
) 0;
1108 for (n
= 0; n
< count
; n
++)
1109 s
+= abase_x
[n
] * bbase_y
[n
];
1116 const GFC_INTEGER_8
*restrict bbase_y
;
1119 for (y
= 0; y
< ycount
; y
++)
1121 bbase_y
= &bbase
[y
*bystride
];
1122 s
= (GFC_INTEGER_8
) 0;
1123 for (n
= 0; n
< count
; n
++)
1124 s
+= abase
[n
*axstride
] * bbase_y
[n
];
1125 dest
[y
*rystride
] = s
;
1129 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
1131 const GFC_INTEGER_8
*restrict bbase_y
;
1134 for (y
= 0; y
< ycount
; y
++)
1136 bbase_y
= &bbase
[y
*bystride
];
1137 s
= (GFC_INTEGER_8
) 0;
1138 for (n
= 0; n
< count
; n
++)
1139 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
1140 dest
[y
*rxstride
] = s
;
1143 else if (axstride
< aystride
)
1145 for (y
= 0; y
< ycount
; y
++)
1146 for (x
= 0; x
< xcount
; x
++)
1147 dest
[x
*rxstride
+ y
*rystride
] = (GFC_INTEGER_8
)0;
1149 for (y
= 0; y
< ycount
; y
++)
1150 for (n
= 0; n
< count
; n
++)
1151 for (x
= 0; x
< xcount
; x
++)
1152 /* dest[x,y] += a[x,n] * b[n,y] */
1153 dest
[x
*rxstride
+ y
*rystride
] +=
1154 abase
[x
*axstride
+ n
*aystride
] *
1155 bbase
[n
*bxstride
+ y
*bystride
];
1159 const GFC_INTEGER_8
*restrict abase_x
;
1160 const GFC_INTEGER_8
*restrict bbase_y
;
1161 GFC_INTEGER_8
*restrict dest_y
;
1164 for (y
= 0; y
< ycount
; y
++)
1166 bbase_y
= &bbase
[y
*bystride
];
1167 dest_y
= &dest
[y
*rystride
];
1168 for (x
= 0; x
< xcount
; x
++)
1170 abase_x
= &abase
[x
*axstride
];
1171 s
= (GFC_INTEGER_8
) 0;
1172 for (n
= 0; n
< count
; n
++)
1173 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
1174 dest_y
[x
*rxstride
] = s
;