1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2018 Free Software Foundation, Inc.
3 Contributed by Thomas Koenig <tkoenig@gcc.gnu.org>.
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
31 /* These are the specific versions of matmul with -mprefer-avx128. */
33 #if defined (HAVE_GFC_INTEGER_8)
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we call it for large
39 typedef void (*blas_call
)(const char *, const char *, const int *, const int *,
40 const int *, const GFC_INTEGER_8
*, const GFC_INTEGER_8
*,
41 const int *, const GFC_INTEGER_8
*, const int *,
42 const GFC_INTEGER_8
*, GFC_INTEGER_8
*, const int *,
45 #if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
47 matmul_i8_avx128_fma3 (gfc_array_i8
* const restrict retarray
,
48 gfc_array_i8
* const restrict a
, gfc_array_i8
* const restrict b
, int try_blas
,
49 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx,fma")));
50 internal_proto(matmul_i8_avx128_fma3
);
52 matmul_i8_avx128_fma3 (gfc_array_i8
* const restrict retarray
,
53 gfc_array_i8
* const restrict a
, gfc_array_i8
* const restrict b
, int try_blas
,
54 int blas_limit
, blas_call gemm
)
56 const GFC_INTEGER_8
* restrict abase
;
57 const GFC_INTEGER_8
* restrict bbase
;
58 GFC_INTEGER_8
* restrict dest
;
60 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
61 index_type x
, y
, n
, count
, xcount
, ycount
;
63 assert (GFC_DESCRIPTOR_RANK (a
) == 2
64 || GFC_DESCRIPTOR_RANK (b
) == 2);
66 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
68 Either A or B (but not both) can be rank 1:
70 o One-dimensional argument A is implicitly treated as a row matrix
71 dimensioned [1,count], so xcount=1.
73 o One-dimensional argument B is implicitly treated as a column matrix
74 dimensioned [count, 1], so ycount=1.
77 if (retarray
->base_addr
== NULL
)
79 if (GFC_DESCRIPTOR_RANK (a
) == 1)
81 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
82 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
84 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
86 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
87 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
91 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
92 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
94 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
95 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
96 GFC_DESCRIPTOR_EXTENT(retarray
,0));
100 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_INTEGER_8
));
101 retarray
->offset
= 0;
103 else if (unlikely (compile_options
.bounds_check
))
105 index_type ret_extent
, arg_extent
;
107 if (GFC_DESCRIPTOR_RANK (a
) == 1)
109 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
110 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
111 if (arg_extent
!= ret_extent
)
112 runtime_error ("Incorrect extent in return array in"
113 " MATMUL intrinsic: is %ld, should be %ld",
114 (long int) ret_extent
, (long int) arg_extent
);
116 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
118 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
119 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
120 if (arg_extent
!= ret_extent
)
121 runtime_error ("Incorrect extent in return array in"
122 " MATMUL intrinsic: is %ld, should be %ld",
123 (long int) ret_extent
, (long int) arg_extent
);
127 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
128 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
129 if (arg_extent
!= ret_extent
)
130 runtime_error ("Incorrect extent in return array in"
131 " MATMUL intrinsic for dimension 1:"
132 " is %ld, should be %ld",
133 (long int) ret_extent
, (long int) arg_extent
);
135 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
136 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
137 if (arg_extent
!= ret_extent
)
138 runtime_error ("Incorrect extent in return array in"
139 " MATMUL intrinsic for dimension 2:"
140 " is %ld, should be %ld",
141 (long int) ret_extent
, (long int) arg_extent
);
146 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
148 /* One-dimensional result may be addressed in the code below
149 either as a row or a column matrix. We want both cases to
151 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
155 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
156 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
160 if (GFC_DESCRIPTOR_RANK (a
) == 1)
162 /* Treat it as a row matrix A[1,count]. */
163 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
167 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
171 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
172 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
174 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
175 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
178 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
180 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
181 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
184 if (GFC_DESCRIPTOR_RANK (b
) == 1)
186 /* Treat it as a column matrix B[count,1] */
187 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
189 /* bystride should never be used for 1-dimensional b.
190 The value is only used to calculate the amount of
191 memory needed for the buffer. */
197 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
198 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
199 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
202 abase
= a
->base_addr
;
203 bbase
= b
->base_addr
;
204 dest
= retarray
->base_addr
;
206 /* Now that everything is set up, we perform the multiplication
209 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
210 #define min(a,b) ((a) <= (b) ? (a) : (b))
211 #define max(a,b) ((a) >= (b) ? (a) : (b))
213 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
214 && (bxstride
== 1 || bystride
== 1)
215 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
218 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
219 const GFC_INTEGER_8 one
= 1, zero
= 0;
220 const int lda
= (axstride
== 1) ? aystride
: axstride
,
221 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
223 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
225 assert (gemm
!= NULL
);
226 gemm (axstride
== 1 ? "N" : "T", bxstride
== 1 ? "N" : "T", &m
,
227 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
233 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1)
235 /* This block of code implements a tuned matmul, derived from
236 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
238 Bo Kagstrom and Per Ling
239 Department of Computing Science
241 S-901 87 Umea, Sweden
243 from netlib.org, translated to C, and modified for matmul.m4. */
245 const GFC_INTEGER_8
*a
, *b
;
247 const index_type m
= xcount
, n
= ycount
, k
= count
;
249 /* System generated locals */
250 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
251 i1
, i2
, i3
, i4
, i5
, i6
;
253 /* Local variables */
254 GFC_INTEGER_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
255 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
256 index_type i
, j
, l
, ii
, jj
, ll
;
257 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
262 c
= retarray
->base_addr
;
264 /* Parameter adjustments */
266 c_offset
= 1 + c_dim1
;
269 a_offset
= 1 + a_dim1
;
272 b_offset
= 1 + b_dim1
;
278 c
[i
+ j
* c_dim1
] = (GFC_INTEGER_8
)0;
280 /* Early exit if possible */
281 if (m
== 0 || n
== 0 || k
== 0)
284 /* Adjust size of t1 to what is needed. */
286 t1_dim
= (a_dim1
- (ycount
> 1)) * 256 + b_dim1
;
290 t1
= malloc (t1_dim
* sizeof(GFC_INTEGER_8
));
292 /* Start turning the crank. */
294 for (jj
= 1; jj
<= i1
; jj
+= 512)
300 ujsec
= jsec
- jsec
% 4;
302 for (ll
= 1; ll
<= i2
; ll
+= 256)
308 ulsec
= lsec
- lsec
% 2;
311 for (ii
= 1; ii
<= i3
; ii
+= 256)
317 uisec
= isec
- isec
% 2;
319 for (l
= ll
; l
<= i4
; l
+= 2)
322 for (i
= ii
; i
<= i5
; i
+= 2)
324 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
326 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
327 a
[i
+ (l
+ 1) * a_dim1
];
328 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
329 a
[i
+ 1 + l
* a_dim1
];
330 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
331 a
[i
+ 1 + (l
+ 1) * a_dim1
];
335 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
336 a
[ii
+ isec
- 1 + l
* a_dim1
];
337 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
338 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
344 for (i
= ii
; i
<= i4
; ++i
)
346 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
347 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
351 uisec
= isec
- isec
% 4;
353 for (j
= jj
; j
<= i4
; j
+= 4)
356 for (i
= ii
; i
<= i5
; i
+= 4)
358 f11
= c
[i
+ j
* c_dim1
];
359 f21
= c
[i
+ 1 + j
* c_dim1
];
360 f12
= c
[i
+ (j
+ 1) * c_dim1
];
361 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
362 f13
= c
[i
+ (j
+ 2) * c_dim1
];
363 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
364 f14
= c
[i
+ (j
+ 3) * c_dim1
];
365 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
366 f31
= c
[i
+ 2 + j
* c_dim1
];
367 f41
= c
[i
+ 3 + j
* c_dim1
];
368 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
369 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
370 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
371 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
372 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
373 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
375 for (l
= ll
; l
<= i6
; ++l
)
377 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
379 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
381 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
382 * b
[l
+ (j
+ 1) * b_dim1
];
383 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
384 * b
[l
+ (j
+ 1) * b_dim1
];
385 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
386 * b
[l
+ (j
+ 2) * b_dim1
];
387 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
388 * b
[l
+ (j
+ 2) * b_dim1
];
389 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
390 * b
[l
+ (j
+ 3) * b_dim1
];
391 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
392 * b
[l
+ (j
+ 3) * b_dim1
];
393 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
395 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
397 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
398 * b
[l
+ (j
+ 1) * b_dim1
];
399 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
400 * b
[l
+ (j
+ 1) * b_dim1
];
401 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
402 * b
[l
+ (j
+ 2) * b_dim1
];
403 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
404 * b
[l
+ (j
+ 2) * b_dim1
];
405 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
406 * b
[l
+ (j
+ 3) * b_dim1
];
407 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
408 * b
[l
+ (j
+ 3) * b_dim1
];
410 c
[i
+ j
* c_dim1
] = f11
;
411 c
[i
+ 1 + j
* c_dim1
] = f21
;
412 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
413 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
414 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
415 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
416 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
417 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
418 c
[i
+ 2 + j
* c_dim1
] = f31
;
419 c
[i
+ 3 + j
* c_dim1
] = f41
;
420 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
421 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
422 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
423 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
424 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
425 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
430 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
432 f11
= c
[i
+ j
* c_dim1
];
433 f12
= c
[i
+ (j
+ 1) * c_dim1
];
434 f13
= c
[i
+ (j
+ 2) * c_dim1
];
435 f14
= c
[i
+ (j
+ 3) * c_dim1
];
437 for (l
= ll
; l
<= i6
; ++l
)
439 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
440 257] * b
[l
+ j
* b_dim1
];
441 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
442 257] * b
[l
+ (j
+ 1) * b_dim1
];
443 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
444 257] * b
[l
+ (j
+ 2) * b_dim1
];
445 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
446 257] * b
[l
+ (j
+ 3) * b_dim1
];
448 c
[i
+ j
* c_dim1
] = f11
;
449 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
450 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
451 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
458 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
461 for (i
= ii
; i
<= i5
; i
+= 4)
463 f11
= c
[i
+ j
* c_dim1
];
464 f21
= c
[i
+ 1 + j
* c_dim1
];
465 f31
= c
[i
+ 2 + j
* c_dim1
];
466 f41
= c
[i
+ 3 + j
* c_dim1
];
468 for (l
= ll
; l
<= i6
; ++l
)
470 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
471 257] * b
[l
+ j
* b_dim1
];
472 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
473 257] * b
[l
+ j
* b_dim1
];
474 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
475 257] * b
[l
+ j
* b_dim1
];
476 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
477 257] * b
[l
+ j
* b_dim1
];
479 c
[i
+ j
* c_dim1
] = f11
;
480 c
[i
+ 1 + j
* c_dim1
] = f21
;
481 c
[i
+ 2 + j
* c_dim1
] = f31
;
482 c
[i
+ 3 + j
* c_dim1
] = f41
;
485 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
487 f11
= c
[i
+ j
* c_dim1
];
489 for (l
= ll
; l
<= i6
; ++l
)
491 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
492 257] * b
[l
+ j
* b_dim1
];
494 c
[i
+ j
* c_dim1
] = f11
;
504 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
506 if (GFC_DESCRIPTOR_RANK (a
) != 1)
508 const GFC_INTEGER_8
*restrict abase_x
;
509 const GFC_INTEGER_8
*restrict bbase_y
;
510 GFC_INTEGER_8
*restrict dest_y
;
513 for (y
= 0; y
< ycount
; y
++)
515 bbase_y
= &bbase
[y
*bystride
];
516 dest_y
= &dest
[y
*rystride
];
517 for (x
= 0; x
< xcount
; x
++)
519 abase_x
= &abase
[x
*axstride
];
520 s
= (GFC_INTEGER_8
) 0;
521 for (n
= 0; n
< count
; n
++)
522 s
+= abase_x
[n
] * bbase_y
[n
];
529 const GFC_INTEGER_8
*restrict bbase_y
;
532 for (y
= 0; y
< ycount
; y
++)
534 bbase_y
= &bbase
[y
*bystride
];
535 s
= (GFC_INTEGER_8
) 0;
536 for (n
= 0; n
< count
; n
++)
537 s
+= abase
[n
*axstride
] * bbase_y
[n
];
538 dest
[y
*rystride
] = s
;
542 else if (axstride
< aystride
)
544 for (y
= 0; y
< ycount
; y
++)
545 for (x
= 0; x
< xcount
; x
++)
546 dest
[x
*rxstride
+ y
*rystride
] = (GFC_INTEGER_8
)0;
548 for (y
= 0; y
< ycount
; y
++)
549 for (n
= 0; n
< count
; n
++)
550 for (x
= 0; x
< xcount
; x
++)
551 /* dest[x,y] += a[x,n] * b[n,y] */
552 dest
[x
*rxstride
+ y
*rystride
] +=
553 abase
[x
*axstride
+ n
*aystride
] *
554 bbase
[n
*bxstride
+ y
*bystride
];
556 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
558 const GFC_INTEGER_8
*restrict bbase_y
;
561 for (y
= 0; y
< ycount
; y
++)
563 bbase_y
= &bbase
[y
*bystride
];
564 s
= (GFC_INTEGER_8
) 0;
565 for (n
= 0; n
< count
; n
++)
566 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
567 dest
[y
*rxstride
] = s
;
572 const GFC_INTEGER_8
*restrict abase_x
;
573 const GFC_INTEGER_8
*restrict bbase_y
;
574 GFC_INTEGER_8
*restrict dest_y
;
577 for (y
= 0; y
< ycount
; y
++)
579 bbase_y
= &bbase
[y
*bystride
];
580 dest_y
= &dest
[y
*rystride
];
581 for (x
= 0; x
< xcount
; x
++)
583 abase_x
= &abase
[x
*axstride
];
584 s
= (GFC_INTEGER_8
) 0;
585 for (n
= 0; n
< count
; n
++)
586 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
587 dest_y
[x
*rxstride
] = s
;
598 #if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
600 matmul_i8_avx128_fma4 (gfc_array_i8
* const restrict retarray
,
601 gfc_array_i8
* const restrict a
, gfc_array_i8
* const restrict b
, int try_blas
,
602 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx,fma4")));
603 internal_proto(matmul_i8_avx128_fma4
);
605 matmul_i8_avx128_fma4 (gfc_array_i8
* const restrict retarray
,
606 gfc_array_i8
* const restrict a
, gfc_array_i8
* const restrict b
, int try_blas
,
607 int blas_limit
, blas_call gemm
)
609 const GFC_INTEGER_8
* restrict abase
;
610 const GFC_INTEGER_8
* restrict bbase
;
611 GFC_INTEGER_8
* restrict dest
;
613 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
614 index_type x
, y
, n
, count
, xcount
, ycount
;
616 assert (GFC_DESCRIPTOR_RANK (a
) == 2
617 || GFC_DESCRIPTOR_RANK (b
) == 2);
619 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
621 Either A or B (but not both) can be rank 1:
623 o One-dimensional argument A is implicitly treated as a row matrix
624 dimensioned [1,count], so xcount=1.
626 o One-dimensional argument B is implicitly treated as a column matrix
627 dimensioned [count, 1], so ycount=1.
630 if (retarray
->base_addr
== NULL
)
632 if (GFC_DESCRIPTOR_RANK (a
) == 1)
634 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
635 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
637 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
639 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
640 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
644 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
645 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
647 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
648 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
649 GFC_DESCRIPTOR_EXTENT(retarray
,0));
653 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_INTEGER_8
));
654 retarray
->offset
= 0;
656 else if (unlikely (compile_options
.bounds_check
))
658 index_type ret_extent
, arg_extent
;
660 if (GFC_DESCRIPTOR_RANK (a
) == 1)
662 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
663 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
664 if (arg_extent
!= ret_extent
)
665 runtime_error ("Incorrect extent in return array in"
666 " MATMUL intrinsic: is %ld, should be %ld",
667 (long int) ret_extent
, (long int) arg_extent
);
669 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
671 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
672 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
673 if (arg_extent
!= ret_extent
)
674 runtime_error ("Incorrect extent in return array in"
675 " MATMUL intrinsic: is %ld, should be %ld",
676 (long int) ret_extent
, (long int) arg_extent
);
680 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
681 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
682 if (arg_extent
!= ret_extent
)
683 runtime_error ("Incorrect extent in return array in"
684 " MATMUL intrinsic for dimension 1:"
685 " is %ld, should be %ld",
686 (long int) ret_extent
, (long int) arg_extent
);
688 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
689 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
690 if (arg_extent
!= ret_extent
)
691 runtime_error ("Incorrect extent in return array in"
692 " MATMUL intrinsic for dimension 2:"
693 " is %ld, should be %ld",
694 (long int) ret_extent
, (long int) arg_extent
);
699 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
701 /* One-dimensional result may be addressed in the code below
702 either as a row or a column matrix. We want both cases to
704 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
708 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
709 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
713 if (GFC_DESCRIPTOR_RANK (a
) == 1)
715 /* Treat it as a row matrix A[1,count]. */
716 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
720 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
724 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
725 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
727 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
728 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
731 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
733 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
734 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
737 if (GFC_DESCRIPTOR_RANK (b
) == 1)
739 /* Treat it as a column matrix B[count,1] */
740 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
742 /* bystride should never be used for 1-dimensional b.
743 The value is only used to calculate the amount of
744 memory needed for the buffer. */
750 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
751 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
752 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
755 abase
= a
->base_addr
;
756 bbase
= b
->base_addr
;
757 dest
= retarray
->base_addr
;
759 /* Now that everything is set up, we perform the multiplication
762 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
763 #define min(a,b) ((a) <= (b) ? (a) : (b))
764 #define max(a,b) ((a) >= (b) ? (a) : (b))
766 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
767 && (bxstride
== 1 || bystride
== 1)
768 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
771 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
772 const GFC_INTEGER_8 one
= 1, zero
= 0;
773 const int lda
= (axstride
== 1) ? aystride
: axstride
,
774 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
776 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
778 assert (gemm
!= NULL
);
779 gemm (axstride
== 1 ? "N" : "T", bxstride
== 1 ? "N" : "T", &m
,
780 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
786 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1)
788 /* This block of code implements a tuned matmul, derived from
789 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
791 Bo Kagstrom and Per Ling
792 Department of Computing Science
794 S-901 87 Umea, Sweden
796 from netlib.org, translated to C, and modified for matmul.m4. */
798 const GFC_INTEGER_8
*a
, *b
;
800 const index_type m
= xcount
, n
= ycount
, k
= count
;
802 /* System generated locals */
803 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
804 i1
, i2
, i3
, i4
, i5
, i6
;
806 /* Local variables */
807 GFC_INTEGER_8 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
808 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
809 index_type i
, j
, l
, ii
, jj
, ll
;
810 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
815 c
= retarray
->base_addr
;
817 /* Parameter adjustments */
819 c_offset
= 1 + c_dim1
;
822 a_offset
= 1 + a_dim1
;
825 b_offset
= 1 + b_dim1
;
831 c
[i
+ j
* c_dim1
] = (GFC_INTEGER_8
)0;
833 /* Early exit if possible */
834 if (m
== 0 || n
== 0 || k
== 0)
837 /* Adjust size of t1 to what is needed. */
839 t1_dim
= (a_dim1
- (ycount
> 1)) * 256 + b_dim1
;
843 t1
= malloc (t1_dim
* sizeof(GFC_INTEGER_8
));
845 /* Start turning the crank. */
847 for (jj
= 1; jj
<= i1
; jj
+= 512)
853 ujsec
= jsec
- jsec
% 4;
855 for (ll
= 1; ll
<= i2
; ll
+= 256)
861 ulsec
= lsec
- lsec
% 2;
864 for (ii
= 1; ii
<= i3
; ii
+= 256)
870 uisec
= isec
- isec
% 2;
872 for (l
= ll
; l
<= i4
; l
+= 2)
875 for (i
= ii
; i
<= i5
; i
+= 2)
877 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
879 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
880 a
[i
+ (l
+ 1) * a_dim1
];
881 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
882 a
[i
+ 1 + l
* a_dim1
];
883 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
884 a
[i
+ 1 + (l
+ 1) * a_dim1
];
888 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
889 a
[ii
+ isec
- 1 + l
* a_dim1
];
890 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
891 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
897 for (i
= ii
; i
<= i4
; ++i
)
899 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
900 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
904 uisec
= isec
- isec
% 4;
906 for (j
= jj
; j
<= i4
; j
+= 4)
909 for (i
= ii
; i
<= i5
; i
+= 4)
911 f11
= c
[i
+ j
* c_dim1
];
912 f21
= c
[i
+ 1 + j
* c_dim1
];
913 f12
= c
[i
+ (j
+ 1) * c_dim1
];
914 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
915 f13
= c
[i
+ (j
+ 2) * c_dim1
];
916 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
917 f14
= c
[i
+ (j
+ 3) * c_dim1
];
918 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
919 f31
= c
[i
+ 2 + j
* c_dim1
];
920 f41
= c
[i
+ 3 + j
* c_dim1
];
921 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
922 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
923 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
924 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
925 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
926 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
928 for (l
= ll
; l
<= i6
; ++l
)
930 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
932 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
934 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
935 * b
[l
+ (j
+ 1) * b_dim1
];
936 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
937 * b
[l
+ (j
+ 1) * b_dim1
];
938 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
939 * b
[l
+ (j
+ 2) * b_dim1
];
940 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
941 * b
[l
+ (j
+ 2) * b_dim1
];
942 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
943 * b
[l
+ (j
+ 3) * b_dim1
];
944 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
945 * b
[l
+ (j
+ 3) * b_dim1
];
946 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
948 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
950 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
951 * b
[l
+ (j
+ 1) * b_dim1
];
952 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
953 * b
[l
+ (j
+ 1) * b_dim1
];
954 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
955 * b
[l
+ (j
+ 2) * b_dim1
];
956 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
957 * b
[l
+ (j
+ 2) * b_dim1
];
958 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
959 * b
[l
+ (j
+ 3) * b_dim1
];
960 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
961 * b
[l
+ (j
+ 3) * b_dim1
];
963 c
[i
+ j
* c_dim1
] = f11
;
964 c
[i
+ 1 + j
* c_dim1
] = f21
;
965 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
966 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
967 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
968 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
969 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
970 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
971 c
[i
+ 2 + j
* c_dim1
] = f31
;
972 c
[i
+ 3 + j
* c_dim1
] = f41
;
973 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
974 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
975 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
976 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
977 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
978 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
983 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
985 f11
= c
[i
+ j
* c_dim1
];
986 f12
= c
[i
+ (j
+ 1) * c_dim1
];
987 f13
= c
[i
+ (j
+ 2) * c_dim1
];
988 f14
= c
[i
+ (j
+ 3) * c_dim1
];
990 for (l
= ll
; l
<= i6
; ++l
)
992 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
993 257] * b
[l
+ j
* b_dim1
];
994 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
995 257] * b
[l
+ (j
+ 1) * b_dim1
];
996 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
997 257] * b
[l
+ (j
+ 2) * b_dim1
];
998 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
999 257] * b
[l
+ (j
+ 3) * b_dim1
];
1001 c
[i
+ j
* c_dim1
] = f11
;
1002 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1003 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1004 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1011 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
1013 i5
= ii
+ uisec
- 1;
1014 for (i
= ii
; i
<= i5
; i
+= 4)
1016 f11
= c
[i
+ j
* c_dim1
];
1017 f21
= c
[i
+ 1 + j
* c_dim1
];
1018 f31
= c
[i
+ 2 + j
* c_dim1
];
1019 f41
= c
[i
+ 3 + j
* c_dim1
];
1021 for (l
= ll
; l
<= i6
; ++l
)
1023 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1024 257] * b
[l
+ j
* b_dim1
];
1025 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
1026 257] * b
[l
+ j
* b_dim1
];
1027 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
1028 257] * b
[l
+ j
* b_dim1
];
1029 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
1030 257] * b
[l
+ j
* b_dim1
];
1032 c
[i
+ j
* c_dim1
] = f11
;
1033 c
[i
+ 1 + j
* c_dim1
] = f21
;
1034 c
[i
+ 2 + j
* c_dim1
] = f31
;
1035 c
[i
+ 3 + j
* c_dim1
] = f41
;
1038 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1040 f11
= c
[i
+ j
* c_dim1
];
1042 for (l
= ll
; l
<= i6
; ++l
)
1044 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1045 257] * b
[l
+ j
* b_dim1
];
1047 c
[i
+ j
* c_dim1
] = f11
;
1057 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
1059 if (GFC_DESCRIPTOR_RANK (a
) != 1)
1061 const GFC_INTEGER_8
*restrict abase_x
;
1062 const GFC_INTEGER_8
*restrict bbase_y
;
1063 GFC_INTEGER_8
*restrict dest_y
;
1066 for (y
= 0; y
< ycount
; y
++)
1068 bbase_y
= &bbase
[y
*bystride
];
1069 dest_y
= &dest
[y
*rystride
];
1070 for (x
= 0; x
< xcount
; x
++)
1072 abase_x
= &abase
[x
*axstride
];
1073 s
= (GFC_INTEGER_8
) 0;
1074 for (n
= 0; n
< count
; n
++)
1075 s
+= abase_x
[n
] * bbase_y
[n
];
1082 const GFC_INTEGER_8
*restrict bbase_y
;
1085 for (y
= 0; y
< ycount
; y
++)
1087 bbase_y
= &bbase
[y
*bystride
];
1088 s
= (GFC_INTEGER_8
) 0;
1089 for (n
= 0; n
< count
; n
++)
1090 s
+= abase
[n
*axstride
] * bbase_y
[n
];
1091 dest
[y
*rystride
] = s
;
1095 else if (axstride
< aystride
)
1097 for (y
= 0; y
< ycount
; y
++)
1098 for (x
= 0; x
< xcount
; x
++)
1099 dest
[x
*rxstride
+ y
*rystride
] = (GFC_INTEGER_8
)0;
1101 for (y
= 0; y
< ycount
; y
++)
1102 for (n
= 0; n
< count
; n
++)
1103 for (x
= 0; x
< xcount
; x
++)
1104 /* dest[x,y] += a[x,n] * b[n,y] */
1105 dest
[x
*rxstride
+ y
*rystride
] +=
1106 abase
[x
*axstride
+ n
*aystride
] *
1107 bbase
[n
*bxstride
+ y
*bystride
];
1109 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
1111 const GFC_INTEGER_8
*restrict bbase_y
;
1114 for (y
= 0; y
< ycount
; y
++)
1116 bbase_y
= &bbase
[y
*bystride
];
1117 s
= (GFC_INTEGER_8
) 0;
1118 for (n
= 0; n
< count
; n
++)
1119 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
1120 dest
[y
*rxstride
] = s
;
1125 const GFC_INTEGER_8
*restrict abase_x
;
1126 const GFC_INTEGER_8
*restrict bbase_y
;
1127 GFC_INTEGER_8
*restrict dest_y
;
1130 for (y
= 0; y
< ycount
; y
++)
1132 bbase_y
= &bbase
[y
*bystride
];
1133 dest_y
= &dest
[y
*rystride
];
1134 for (x
= 0; x
< xcount
; x
++)
1136 abase_x
= &abase
[x
*axstride
];
1137 s
= (GFC_INTEGER_8
) 0;
1138 for (n
= 0; n
< count
; n
++)
1139 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
1140 dest_y
[x
*rxstride
] = s
;