1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2018 Free Software Foundation, Inc.
3 Contributed by Thomas Koenig <tkoenig@gcc.gnu.org>.
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
31 /* These are the specific versions of matmul with -mprefer-avx128. */
33 #if defined (HAVE_GFC_REAL_16)
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we call it for large
39 typedef void (*blas_call
)(const char *, const char *, const int *, const int *,
40 const int *, const GFC_REAL_16
*, const GFC_REAL_16
*,
41 const int *, const GFC_REAL_16
*, const int *,
42 const GFC_REAL_16
*, GFC_REAL_16
*, const int *,
45 #if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
47 matmul_r16_avx128_fma3 (gfc_array_r16
* const restrict retarray
,
48 gfc_array_r16
* const restrict a
, gfc_array_r16
* const restrict b
, int try_blas
,
49 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx,fma")));
50 internal_proto(matmul_r16_avx128_fma3
);
52 matmul_r16_avx128_fma3 (gfc_array_r16
* const restrict retarray
,
53 gfc_array_r16
* const restrict a
, gfc_array_r16
* const restrict b
, int try_blas
,
54 int blas_limit
, blas_call gemm
)
56 const GFC_REAL_16
* restrict abase
;
57 const GFC_REAL_16
* restrict bbase
;
58 GFC_REAL_16
* restrict dest
;
60 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
61 index_type x
, y
, n
, count
, xcount
, ycount
;
63 assert (GFC_DESCRIPTOR_RANK (a
) == 2
64 || GFC_DESCRIPTOR_RANK (b
) == 2);
66 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
68 Either A or B (but not both) can be rank 1:
70 o One-dimensional argument A is implicitly treated as a row matrix
71 dimensioned [1,count], so xcount=1.
73 o One-dimensional argument B is implicitly treated as a column matrix
74 dimensioned [count, 1], so ycount=1.
77 if (retarray
->base_addr
== NULL
)
79 if (GFC_DESCRIPTOR_RANK (a
) == 1)
81 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
82 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
84 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
86 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
87 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
91 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
92 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
94 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
95 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
96 GFC_DESCRIPTOR_EXTENT(retarray
,0));
100 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_REAL_16
));
101 retarray
->offset
= 0;
103 else if (unlikely (compile_options
.bounds_check
))
105 index_type ret_extent
, arg_extent
;
107 if (GFC_DESCRIPTOR_RANK (a
) == 1)
109 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
110 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
111 if (arg_extent
!= ret_extent
)
112 runtime_error ("Incorrect extent in return array in"
113 " MATMUL intrinsic: is %ld, should be %ld",
114 (long int) ret_extent
, (long int) arg_extent
);
116 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
118 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
119 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
120 if (arg_extent
!= ret_extent
)
121 runtime_error ("Incorrect extent in return array in"
122 " MATMUL intrinsic: is %ld, should be %ld",
123 (long int) ret_extent
, (long int) arg_extent
);
127 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
128 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
129 if (arg_extent
!= ret_extent
)
130 runtime_error ("Incorrect extent in return array in"
131 " MATMUL intrinsic for dimension 1:"
132 " is %ld, should be %ld",
133 (long int) ret_extent
, (long int) arg_extent
);
135 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
136 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
137 if (arg_extent
!= ret_extent
)
138 runtime_error ("Incorrect extent in return array in"
139 " MATMUL intrinsic for dimension 2:"
140 " is %ld, should be %ld",
141 (long int) ret_extent
, (long int) arg_extent
);
146 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
148 /* One-dimensional result may be addressed in the code below
149 either as a row or a column matrix. We want both cases to
151 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
155 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
156 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
160 if (GFC_DESCRIPTOR_RANK (a
) == 1)
162 /* Treat it as a a row matrix A[1,count]. */
163 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
167 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
171 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
172 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
174 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
175 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
178 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
180 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
181 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
184 if (GFC_DESCRIPTOR_RANK (b
) == 1)
186 /* Treat it as a column matrix B[count,1] */
187 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
189 /* bystride should never be used for 1-dimensional b.
190 The value is only used for calculation of the
191 memory by the buffer. */
197 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
198 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
199 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
202 abase
= a
->base_addr
;
203 bbase
= b
->base_addr
;
204 dest
= retarray
->base_addr
;
206 /* Now that everything is set up, we perform the multiplication
209 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
210 #define min(a,b) ((a) <= (b) ? (a) : (b))
211 #define max(a,b) ((a) >= (b) ? (a) : (b))
213 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
214 && (bxstride
== 1 || bystride
== 1)
215 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
218 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
219 const GFC_REAL_16 one
= 1, zero
= 0;
220 const int lda
= (axstride
== 1) ? aystride
: axstride
,
221 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
223 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
225 assert (gemm
!= NULL
);
226 gemm (axstride
== 1 ? "N" : "T", bxstride
== 1 ? "N" : "T", &m
,
227 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
233 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1)
235 /* This block of code implements a tuned matmul, derived from
236 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
238 Bo Kagstrom and Per Ling
239 Department of Computing Science
241 S-901 87 Umea, Sweden
243 from netlib.org, translated to C, and modified for matmul.m4. */
245 const GFC_REAL_16
*a
, *b
;
247 const index_type m
= xcount
, n
= ycount
, k
= count
;
249 /* System generated locals */
250 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
251 i1
, i2
, i3
, i4
, i5
, i6
;
253 /* Local variables */
254 GFC_REAL_16 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
255 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
256 index_type i
, j
, l
, ii
, jj
, ll
;
257 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
262 c
= retarray
->base_addr
;
264 /* Parameter adjustments */
266 c_offset
= 1 + c_dim1
;
269 a_offset
= 1 + a_dim1
;
272 b_offset
= 1 + b_dim1
;
278 c
[i
+ j
* c_dim1
] = (GFC_REAL_16
)0;
280 /* Early exit if possible */
281 if (m
== 0 || n
== 0 || k
== 0)
284 /* Adjust size of t1 to what is needed. */
285 index_type t1_dim
, a_sz
;
291 t1_dim
= a_sz
* 256 + b_dim1
;
295 t1
= malloc (t1_dim
* sizeof(GFC_REAL_16
));
297 /* Start turning the crank. */
299 for (jj
= 1; jj
<= i1
; jj
+= 512)
305 ujsec
= jsec
- jsec
% 4;
307 for (ll
= 1; ll
<= i2
; ll
+= 256)
313 ulsec
= lsec
- lsec
% 2;
316 for (ii
= 1; ii
<= i3
; ii
+= 256)
322 uisec
= isec
- isec
% 2;
324 for (l
= ll
; l
<= i4
; l
+= 2)
327 for (i
= ii
; i
<= i5
; i
+= 2)
329 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
331 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
332 a
[i
+ (l
+ 1) * a_dim1
];
333 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
334 a
[i
+ 1 + l
* a_dim1
];
335 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
336 a
[i
+ 1 + (l
+ 1) * a_dim1
];
340 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
341 a
[ii
+ isec
- 1 + l
* a_dim1
];
342 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
343 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
349 for (i
= ii
; i
<= i4
; ++i
)
351 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
352 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
356 uisec
= isec
- isec
% 4;
358 for (j
= jj
; j
<= i4
; j
+= 4)
361 for (i
= ii
; i
<= i5
; i
+= 4)
363 f11
= c
[i
+ j
* c_dim1
];
364 f21
= c
[i
+ 1 + j
* c_dim1
];
365 f12
= c
[i
+ (j
+ 1) * c_dim1
];
366 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
367 f13
= c
[i
+ (j
+ 2) * c_dim1
];
368 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
369 f14
= c
[i
+ (j
+ 3) * c_dim1
];
370 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
371 f31
= c
[i
+ 2 + j
* c_dim1
];
372 f41
= c
[i
+ 3 + j
* c_dim1
];
373 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
374 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
375 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
376 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
377 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
378 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
380 for (l
= ll
; l
<= i6
; ++l
)
382 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
384 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
386 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
387 * b
[l
+ (j
+ 1) * b_dim1
];
388 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
389 * b
[l
+ (j
+ 1) * b_dim1
];
390 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
391 * b
[l
+ (j
+ 2) * b_dim1
];
392 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
393 * b
[l
+ (j
+ 2) * b_dim1
];
394 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
395 * b
[l
+ (j
+ 3) * b_dim1
];
396 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
397 * b
[l
+ (j
+ 3) * b_dim1
];
398 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
400 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
402 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
403 * b
[l
+ (j
+ 1) * b_dim1
];
404 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
405 * b
[l
+ (j
+ 1) * b_dim1
];
406 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
407 * b
[l
+ (j
+ 2) * b_dim1
];
408 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
409 * b
[l
+ (j
+ 2) * b_dim1
];
410 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
411 * b
[l
+ (j
+ 3) * b_dim1
];
412 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
413 * b
[l
+ (j
+ 3) * b_dim1
];
415 c
[i
+ j
* c_dim1
] = f11
;
416 c
[i
+ 1 + j
* c_dim1
] = f21
;
417 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
418 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
419 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
420 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
421 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
422 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
423 c
[i
+ 2 + j
* c_dim1
] = f31
;
424 c
[i
+ 3 + j
* c_dim1
] = f41
;
425 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
426 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
427 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
428 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
429 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
430 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
435 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
437 f11
= c
[i
+ j
* c_dim1
];
438 f12
= c
[i
+ (j
+ 1) * c_dim1
];
439 f13
= c
[i
+ (j
+ 2) * c_dim1
];
440 f14
= c
[i
+ (j
+ 3) * c_dim1
];
442 for (l
= ll
; l
<= i6
; ++l
)
444 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
445 257] * b
[l
+ j
* b_dim1
];
446 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
447 257] * b
[l
+ (j
+ 1) * b_dim1
];
448 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
449 257] * b
[l
+ (j
+ 2) * b_dim1
];
450 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
451 257] * b
[l
+ (j
+ 3) * b_dim1
];
453 c
[i
+ j
* c_dim1
] = f11
;
454 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
455 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
456 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
463 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
466 for (i
= ii
; i
<= i5
; i
+= 4)
468 f11
= c
[i
+ j
* c_dim1
];
469 f21
= c
[i
+ 1 + j
* c_dim1
];
470 f31
= c
[i
+ 2 + j
* c_dim1
];
471 f41
= c
[i
+ 3 + j
* c_dim1
];
473 for (l
= ll
; l
<= i6
; ++l
)
475 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
476 257] * b
[l
+ j
* b_dim1
];
477 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
478 257] * b
[l
+ j
* b_dim1
];
479 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
480 257] * b
[l
+ j
* b_dim1
];
481 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
482 257] * b
[l
+ j
* b_dim1
];
484 c
[i
+ j
* c_dim1
] = f11
;
485 c
[i
+ 1 + j
* c_dim1
] = f21
;
486 c
[i
+ 2 + j
* c_dim1
] = f31
;
487 c
[i
+ 3 + j
* c_dim1
] = f41
;
490 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
492 f11
= c
[i
+ j
* c_dim1
];
494 for (l
= ll
; l
<= i6
; ++l
)
496 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
497 257] * b
[l
+ j
* b_dim1
];
499 c
[i
+ j
* c_dim1
] = f11
;
509 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
511 if (GFC_DESCRIPTOR_RANK (a
) != 1)
513 const GFC_REAL_16
*restrict abase_x
;
514 const GFC_REAL_16
*restrict bbase_y
;
515 GFC_REAL_16
*restrict dest_y
;
518 for (y
= 0; y
< ycount
; y
++)
520 bbase_y
= &bbase
[y
*bystride
];
521 dest_y
= &dest
[y
*rystride
];
522 for (x
= 0; x
< xcount
; x
++)
524 abase_x
= &abase
[x
*axstride
];
526 for (n
= 0; n
< count
; n
++)
527 s
+= abase_x
[n
] * bbase_y
[n
];
534 const GFC_REAL_16
*restrict bbase_y
;
537 for (y
= 0; y
< ycount
; y
++)
539 bbase_y
= &bbase
[y
*bystride
];
541 for (n
= 0; n
< count
; n
++)
542 s
+= abase
[n
*axstride
] * bbase_y
[n
];
543 dest
[y
*rystride
] = s
;
547 else if (axstride
< aystride
)
549 for (y
= 0; y
< ycount
; y
++)
550 for (x
= 0; x
< xcount
; x
++)
551 dest
[x
*rxstride
+ y
*rystride
] = (GFC_REAL_16
)0;
553 for (y
= 0; y
< ycount
; y
++)
554 for (n
= 0; n
< count
; n
++)
555 for (x
= 0; x
< xcount
; x
++)
556 /* dest[x,y] += a[x,n] * b[n,y] */
557 dest
[x
*rxstride
+ y
*rystride
] +=
558 abase
[x
*axstride
+ n
*aystride
] *
559 bbase
[n
*bxstride
+ y
*bystride
];
561 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
563 const GFC_REAL_16
*restrict bbase_y
;
566 for (y
= 0; y
< ycount
; y
++)
568 bbase_y
= &bbase
[y
*bystride
];
570 for (n
= 0; n
< count
; n
++)
571 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
572 dest
[y
*rxstride
] = s
;
577 const GFC_REAL_16
*restrict abase_x
;
578 const GFC_REAL_16
*restrict bbase_y
;
579 GFC_REAL_16
*restrict dest_y
;
582 for (y
= 0; y
< ycount
; y
++)
584 bbase_y
= &bbase
[y
*bystride
];
585 dest_y
= &dest
[y
*rystride
];
586 for (x
= 0; x
< xcount
; x
++)
588 abase_x
= &abase
[x
*axstride
];
590 for (n
= 0; n
< count
; n
++)
591 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
592 dest_y
[x
*rxstride
] = s
;
603 #if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
605 matmul_r16_avx128_fma4 (gfc_array_r16
* const restrict retarray
,
606 gfc_array_r16
* const restrict a
, gfc_array_r16
* const restrict b
, int try_blas
,
607 int blas_limit
, blas_call gemm
) __attribute__((__target__("avx,fma4")));
608 internal_proto(matmul_r16_avx128_fma4
);
610 matmul_r16_avx128_fma4 (gfc_array_r16
* const restrict retarray
,
611 gfc_array_r16
* const restrict a
, gfc_array_r16
* const restrict b
, int try_blas
,
612 int blas_limit
, blas_call gemm
)
614 const GFC_REAL_16
* restrict abase
;
615 const GFC_REAL_16
* restrict bbase
;
616 GFC_REAL_16
* restrict dest
;
618 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
619 index_type x
, y
, n
, count
, xcount
, ycount
;
621 assert (GFC_DESCRIPTOR_RANK (a
) == 2
622 || GFC_DESCRIPTOR_RANK (b
) == 2);
624 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
626 Either A or B (but not both) can be rank 1:
628 o One-dimensional argument A is implicitly treated as a row matrix
629 dimensioned [1,count], so xcount=1.
631 o One-dimensional argument B is implicitly treated as a column matrix
632 dimensioned [count, 1], so ycount=1.
635 if (retarray
->base_addr
== NULL
)
637 if (GFC_DESCRIPTOR_RANK (a
) == 1)
639 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
640 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
642 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
644 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
645 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
649 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
650 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
652 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
653 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
654 GFC_DESCRIPTOR_EXTENT(retarray
,0));
658 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_REAL_16
));
659 retarray
->offset
= 0;
661 else if (unlikely (compile_options
.bounds_check
))
663 index_type ret_extent
, arg_extent
;
665 if (GFC_DESCRIPTOR_RANK (a
) == 1)
667 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
668 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
669 if (arg_extent
!= ret_extent
)
670 runtime_error ("Incorrect extent in return array in"
671 " MATMUL intrinsic: is %ld, should be %ld",
672 (long int) ret_extent
, (long int) arg_extent
);
674 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
676 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
677 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
678 if (arg_extent
!= ret_extent
)
679 runtime_error ("Incorrect extent in return array in"
680 " MATMUL intrinsic: is %ld, should be %ld",
681 (long int) ret_extent
, (long int) arg_extent
);
685 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
686 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
687 if (arg_extent
!= ret_extent
)
688 runtime_error ("Incorrect extent in return array in"
689 " MATMUL intrinsic for dimension 1:"
690 " is %ld, should be %ld",
691 (long int) ret_extent
, (long int) arg_extent
);
693 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
694 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
695 if (arg_extent
!= ret_extent
)
696 runtime_error ("Incorrect extent in return array in"
697 " MATMUL intrinsic for dimension 2:"
698 " is %ld, should be %ld",
699 (long int) ret_extent
, (long int) arg_extent
);
704 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
706 /* One-dimensional result may be addressed in the code below
707 either as a row or a column matrix. We want both cases to
709 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
713 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
714 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
718 if (GFC_DESCRIPTOR_RANK (a
) == 1)
720 /* Treat it as a a row matrix A[1,count]. */
721 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
725 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
729 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
730 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
732 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
733 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
736 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
738 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
739 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
742 if (GFC_DESCRIPTOR_RANK (b
) == 1)
744 /* Treat it as a column matrix B[count,1] */
745 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
747 /* bystride should never be used for 1-dimensional b.
748 The value is only used for calculation of the
749 memory by the buffer. */
755 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
756 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
757 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
760 abase
= a
->base_addr
;
761 bbase
= b
->base_addr
;
762 dest
= retarray
->base_addr
;
764 /* Now that everything is set up, we perform the multiplication
767 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
768 #define min(a,b) ((a) <= (b) ? (a) : (b))
769 #define max(a,b) ((a) >= (b) ? (a) : (b))
771 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
772 && (bxstride
== 1 || bystride
== 1)
773 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
776 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
777 const GFC_REAL_16 one
= 1, zero
= 0;
778 const int lda
= (axstride
== 1) ? aystride
: axstride
,
779 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
781 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
783 assert (gemm
!= NULL
);
784 gemm (axstride
== 1 ? "N" : "T", bxstride
== 1 ? "N" : "T", &m
,
785 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
791 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1)
793 /* This block of code implements a tuned matmul, derived from
794 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
796 Bo Kagstrom and Per Ling
797 Department of Computing Science
799 S-901 87 Umea, Sweden
801 from netlib.org, translated to C, and modified for matmul.m4. */
803 const GFC_REAL_16
*a
, *b
;
805 const index_type m
= xcount
, n
= ycount
, k
= count
;
807 /* System generated locals */
808 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
809 i1
, i2
, i3
, i4
, i5
, i6
;
811 /* Local variables */
812 GFC_REAL_16 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
813 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
814 index_type i
, j
, l
, ii
, jj
, ll
;
815 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
820 c
= retarray
->base_addr
;
822 /* Parameter adjustments */
824 c_offset
= 1 + c_dim1
;
827 a_offset
= 1 + a_dim1
;
830 b_offset
= 1 + b_dim1
;
836 c
[i
+ j
* c_dim1
] = (GFC_REAL_16
)0;
838 /* Early exit if possible */
839 if (m
== 0 || n
== 0 || k
== 0)
842 /* Adjust size of t1 to what is needed. */
843 index_type t1_dim
, a_sz
;
849 t1_dim
= a_sz
* 256 + b_dim1
;
853 t1
= malloc (t1_dim
* sizeof(GFC_REAL_16
));
855 /* Start turning the crank. */
857 for (jj
= 1; jj
<= i1
; jj
+= 512)
863 ujsec
= jsec
- jsec
% 4;
865 for (ll
= 1; ll
<= i2
; ll
+= 256)
871 ulsec
= lsec
- lsec
% 2;
874 for (ii
= 1; ii
<= i3
; ii
+= 256)
880 uisec
= isec
- isec
% 2;
882 for (l
= ll
; l
<= i4
; l
+= 2)
885 for (i
= ii
; i
<= i5
; i
+= 2)
887 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
889 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
890 a
[i
+ (l
+ 1) * a_dim1
];
891 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
892 a
[i
+ 1 + l
* a_dim1
];
893 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
894 a
[i
+ 1 + (l
+ 1) * a_dim1
];
898 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
899 a
[ii
+ isec
- 1 + l
* a_dim1
];
900 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
901 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
907 for (i
= ii
; i
<= i4
; ++i
)
909 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
910 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
914 uisec
= isec
- isec
% 4;
916 for (j
= jj
; j
<= i4
; j
+= 4)
919 for (i
= ii
; i
<= i5
; i
+= 4)
921 f11
= c
[i
+ j
* c_dim1
];
922 f21
= c
[i
+ 1 + j
* c_dim1
];
923 f12
= c
[i
+ (j
+ 1) * c_dim1
];
924 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
925 f13
= c
[i
+ (j
+ 2) * c_dim1
];
926 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
927 f14
= c
[i
+ (j
+ 3) * c_dim1
];
928 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
929 f31
= c
[i
+ 2 + j
* c_dim1
];
930 f41
= c
[i
+ 3 + j
* c_dim1
];
931 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
932 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
933 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
934 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
935 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
936 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
938 for (l
= ll
; l
<= i6
; ++l
)
940 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
942 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
944 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
945 * b
[l
+ (j
+ 1) * b_dim1
];
946 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
947 * b
[l
+ (j
+ 1) * b_dim1
];
948 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
949 * b
[l
+ (j
+ 2) * b_dim1
];
950 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
951 * b
[l
+ (j
+ 2) * b_dim1
];
952 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
953 * b
[l
+ (j
+ 3) * b_dim1
];
954 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
955 * b
[l
+ (j
+ 3) * b_dim1
];
956 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
958 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
960 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
961 * b
[l
+ (j
+ 1) * b_dim1
];
962 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
963 * b
[l
+ (j
+ 1) * b_dim1
];
964 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
965 * b
[l
+ (j
+ 2) * b_dim1
];
966 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
967 * b
[l
+ (j
+ 2) * b_dim1
];
968 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
969 * b
[l
+ (j
+ 3) * b_dim1
];
970 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
971 * b
[l
+ (j
+ 3) * b_dim1
];
973 c
[i
+ j
* c_dim1
] = f11
;
974 c
[i
+ 1 + j
* c_dim1
] = f21
;
975 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
976 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
977 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
978 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
979 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
980 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
981 c
[i
+ 2 + j
* c_dim1
] = f31
;
982 c
[i
+ 3 + j
* c_dim1
] = f41
;
983 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
984 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
985 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
986 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
987 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
988 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
993 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
995 f11
= c
[i
+ j
* c_dim1
];
996 f12
= c
[i
+ (j
+ 1) * c_dim1
];
997 f13
= c
[i
+ (j
+ 2) * c_dim1
];
998 f14
= c
[i
+ (j
+ 3) * c_dim1
];
1000 for (l
= ll
; l
<= i6
; ++l
)
1002 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1003 257] * b
[l
+ j
* b_dim1
];
1004 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1005 257] * b
[l
+ (j
+ 1) * b_dim1
];
1006 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1007 257] * b
[l
+ (j
+ 2) * b_dim1
];
1008 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1009 257] * b
[l
+ (j
+ 3) * b_dim1
];
1011 c
[i
+ j
* c_dim1
] = f11
;
1012 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1013 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1014 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1021 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
1023 i5
= ii
+ uisec
- 1;
1024 for (i
= ii
; i
<= i5
; i
+= 4)
1026 f11
= c
[i
+ j
* c_dim1
];
1027 f21
= c
[i
+ 1 + j
* c_dim1
];
1028 f31
= c
[i
+ 2 + j
* c_dim1
];
1029 f41
= c
[i
+ 3 + j
* c_dim1
];
1031 for (l
= ll
; l
<= i6
; ++l
)
1033 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1034 257] * b
[l
+ j
* b_dim1
];
1035 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
1036 257] * b
[l
+ j
* b_dim1
];
1037 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
1038 257] * b
[l
+ j
* b_dim1
];
1039 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
1040 257] * b
[l
+ j
* b_dim1
];
1042 c
[i
+ j
* c_dim1
] = f11
;
1043 c
[i
+ 1 + j
* c_dim1
] = f21
;
1044 c
[i
+ 2 + j
* c_dim1
] = f31
;
1045 c
[i
+ 3 + j
* c_dim1
] = f41
;
1048 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1050 f11
= c
[i
+ j
* c_dim1
];
1052 for (l
= ll
; l
<= i6
; ++l
)
1054 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1055 257] * b
[l
+ j
* b_dim1
];
1057 c
[i
+ j
* c_dim1
] = f11
;
1067 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
1069 if (GFC_DESCRIPTOR_RANK (a
) != 1)
1071 const GFC_REAL_16
*restrict abase_x
;
1072 const GFC_REAL_16
*restrict bbase_y
;
1073 GFC_REAL_16
*restrict dest_y
;
1076 for (y
= 0; y
< ycount
; y
++)
1078 bbase_y
= &bbase
[y
*bystride
];
1079 dest_y
= &dest
[y
*rystride
];
1080 for (x
= 0; x
< xcount
; x
++)
1082 abase_x
= &abase
[x
*axstride
];
1083 s
= (GFC_REAL_16
) 0;
1084 for (n
= 0; n
< count
; n
++)
1085 s
+= abase_x
[n
] * bbase_y
[n
];
1092 const GFC_REAL_16
*restrict bbase_y
;
1095 for (y
= 0; y
< ycount
; y
++)
1097 bbase_y
= &bbase
[y
*bystride
];
1098 s
= (GFC_REAL_16
) 0;
1099 for (n
= 0; n
< count
; n
++)
1100 s
+= abase
[n
*axstride
] * bbase_y
[n
];
1101 dest
[y
*rystride
] = s
;
1105 else if (axstride
< aystride
)
1107 for (y
= 0; y
< ycount
; y
++)
1108 for (x
= 0; x
< xcount
; x
++)
1109 dest
[x
*rxstride
+ y
*rystride
] = (GFC_REAL_16
)0;
1111 for (y
= 0; y
< ycount
; y
++)
1112 for (n
= 0; n
< count
; n
++)
1113 for (x
= 0; x
< xcount
; x
++)
1114 /* dest[x,y] += a[x,n] * b[n,y] */
1115 dest
[x
*rxstride
+ y
*rystride
] +=
1116 abase
[x
*axstride
+ n
*aystride
] *
1117 bbase
[n
*bxstride
+ y
*bystride
];
1119 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
1121 const GFC_REAL_16
*restrict bbase_y
;
1124 for (y
= 0; y
< ycount
; y
++)
1126 bbase_y
= &bbase
[y
*bystride
];
1127 s
= (GFC_REAL_16
) 0;
1128 for (n
= 0; n
< count
; n
++)
1129 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
1130 dest
[y
*rxstride
] = s
;
1135 const GFC_REAL_16
*restrict abase_x
;
1136 const GFC_REAL_16
*restrict bbase_y
;
1137 GFC_REAL_16
*restrict dest_y
;
1140 for (y
= 0; y
< ycount
; y
++)
1142 bbase_y
= &bbase
[y
*bystride
];
1143 dest_y
= &dest
[y
*rystride
];
1144 for (x
= 0; x
< xcount
; x
++)
1146 abase_x
= &abase
[x
*axstride
];
1147 s
= (GFC_REAL_16
) 0;
1148 for (n
= 0; n
< count
; n
++)
1149 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
1150 dest_y
[x
*rxstride
] = s
;