/* Implementation of the MATMUL intrinsic
   Copyright (C) 2002-2018 Free Software Foundation, Inc.
   Contributed by Paul Brook <paul@nowt.org>

This file is part of the GNU Fortran runtime library (libgfortran).

Libgfortran is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.

Libgfortran is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

Under Section 7 of GPL version 3, you are granted additional
permissions described in the GCC Runtime Library Exception, version
3.1, as published by the Free Software Foundation.

You should have received a copy of the GNU General Public License and
a copy of the GCC Runtime Library Exception along with this program;
see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
<http://www.gnu.org/licenses/>.  */
#include "libgfortran.h"
#include <string.h>
#include <assert.h>

#if defined (HAVE_GFC_INTEGER_4)
/* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
   passed to us by the front-end, in which case we call it for large
   matrices.  */

typedef void (*blas_call)(const char *, const char *, const int *, const int *,
			  const int *, const GFC_INTEGER_4 *, const GFC_INTEGER_4 *,
			  const int *, const GFC_INTEGER_4 *, const int *,
			  const GFC_INTEGER_4 *, GFC_INTEGER_4 *, const int *,
			  int, int);
/* The order of loops is different in the case of plain matrix
   multiplication C=MATMUL(A,B), and in the frequent special case where
   the argument A is the temporary result of a TRANSPOSE intrinsic:
   C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
   looking at their strides.

   The equivalent Fortran pseudo-code is:

   DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
   IF (.NOT.IS_TRANSPOSED(A)) THEN
     C = 0
     DO J=1,N
       DO K=1,COUNT
         DO I=1,M
           C(I,J) = C(I,J)+A(I,K)*B(K,J)
   ELSE
     DO J=1,N
       DO I=1,M
         S = 0
         DO K=1,COUNT
           S = S+A(I,K)*B(K,J)
         C(I,J) = S
   ENDIF
*/
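/* Illustration only (not part of libgfortran): a plain reference version of
   the strided kernel that the fallback loops in the functions below
   implement.  Every name in this sketch is invented; the real code also
   chooses its loop order based on the stride comparison described above.  */

__attribute__((__unused__)) static void
ref_matmul_strided (GFC_INTEGER_4 *dest, const GFC_INTEGER_4 *a,
		    const GFC_INTEGER_4 *b,
		    index_type xcount, index_type ycount, index_type count,
		    index_type rxstride, index_type rystride,
		    index_type axstride, index_type aystride,
		    index_type bxstride, index_type bystride)
{
  for (index_type y = 0; y < ycount; y++)
    for (index_type x = 0; x < xcount; x++)
      {
	GFC_INTEGER_4 s = 0;
	/* dest[x,y] = sum over n of a[x,n] * b[n,y], every element addressed
	   through explicit strides, exactly as the library loops do.  */
	for (index_type n = 0; n < count; n++)
	  s += a[x*axstride + n*aystride] * b[n*bxstride + y*bystride];
	dest[x*rxstride + y*rystride] = s;
      }
}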
/* If try_blas is set to a nonzero value, then the matmul function will
   see if there is a way to perform the matrix multiplication by a call
   to the BLAS gemm function.  */

extern void matmul_i4 (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm);
export_proto(matmul_i4);
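/* Usage illustration for the reference sketch above, with made-up extents:
   a rank-1 A of extent 3 times a 3x4 B, i.e. the [1,count] x [count,ycount]
   mapping that the functions below apply to one-dimensional arguments.
   This helper is not part of the library.  */

__attribute__((__unused__)) static void
ref_matmul_rank1_example (void)
{
  GFC_INTEGER_4 a1[3] = {1, 2, 3};	/* rank-1 A, extent 3            */
  GFC_INTEGER_4 b1[3 * 4];		/* rank-2 B, 3x4, column-major   */
  GFC_INTEGER_4 c1[4];			/* rank-1 result, extent 4       */

  for (int idx = 0; idx < 3 * 4; idx++)
    b1[idx] = idx;

  /* xcount = 1, count = 3, ycount = 4; unit strides everywhere except the
     column stride of B, which equals its leading dimension (3).  */
  ref_matmul_strided (c1, a1, b1, 1, 4, 3,
		      1, 1,	/* result strides */
		      1, 1,	/* A strides      */
		      1, 3);	/* B strides      */
}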
/* Put exhaustive list of possible architectures here, ORed together.  */

#if defined(HAVE_AVX) || defined(HAVE_AVX2) || defined(HAVE_AVX512F)

#ifdef HAVE_AVX
static void
matmul_i4_avx (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm) __attribute__((__target__("avx")));
static void
matmul_i4_avx (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm)
{
  const GFC_INTEGER_4 * restrict abase;
  const GFC_INTEGER_4 * restrict bbase;
  GFC_INTEGER_4 * restrict dest;

  index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  index_type x, y, n, count, xcount, ycount;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
	  || GFC_DESCRIPTOR_RANK (b) == 2);
  /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]

     Either A or B (but not both) can be rank 1:

     o One-dimensional argument A is implicitly treated as a row matrix
       dimensioned [1,count], so xcount=1.

     o One-dimensional argument B is implicitly treated as a column matrix
       dimensioned [count, 1], so ycount=1.
  */
  if (retarray->base_addr == NULL)
    {
      if (GFC_DESCRIPTOR_RANK (a) == 1)
	{
	  GFC_DIMENSION_SET(retarray->dim[0], 0,
			    GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
	}
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
	{
	  GFC_DIMENSION_SET(retarray->dim[0], 0,
			    GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
	}
      else
	{
	  GFC_DIMENSION_SET(retarray->dim[0], 0,
			    GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

	  GFC_DIMENSION_SET(retarray->dim[1], 0,
			    GFC_DESCRIPTOR_EXTENT(b,1) - 1,
			    GFC_DESCRIPTOR_EXTENT(retarray,0));
	}

      retarray->base_addr
	= xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_4));
      retarray->offset = 0;
    }
  else if (unlikely (compile_options.bounds_check))
    {
      index_type ret_extent, arg_extent;

      if (GFC_DESCRIPTOR_RANK (a) == 1)
	{
	  arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
	  ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
	  if (arg_extent != ret_extent)
	    runtime_error ("Array bound mismatch for dimension 1 of "
			   "array (%ld/%ld) ",
			   (long int) ret_extent, (long int) arg_extent);
	}
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
	{
	  arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
	  ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
	  if (arg_extent != ret_extent)
	    runtime_error ("Array bound mismatch for dimension 1 of "
			   "array (%ld/%ld) ",
			   (long int) ret_extent, (long int) arg_extent);
	}
      else
	{
	  arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
	  ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
	  if (arg_extent != ret_extent)
	    runtime_error ("Array bound mismatch for dimension 1 of "
			   "array (%ld/%ld) ",
			   (long int) ret_extent, (long int) arg_extent);

	  arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
	  ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
	  if (arg_extent != ret_extent)
	    runtime_error ("Array bound mismatch for dimension 2 of "
			   "array (%ld/%ld) ",
			   (long int) ret_extent, (long int) arg_extent);
	}
    }
  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      /* One-dimensional result may be addressed in the code below
	 either as a row or a column matrix.  We want both cases to
	 work.  */
      rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }

  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Treat it as a row matrix A[1,count].  */
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = 1;

      xcount = 1;
      count = GFC_DESCRIPTOR_EXTENT(a,0);
    }
  else
    {
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = GFC_DESCRIPTOR_STRIDE(a,1);

      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  if (count != GFC_DESCRIPTOR_EXTENT(b,0))
    {
      if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
	runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
		       "in dimension 1: is %ld, should be %ld",
		       (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
    }

  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      /* Treat it as a column matrix B[count,1].  */
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);

      /* bystride should never be used for 1-dimensional b.
	 The value is only used when calculating the memory
	 needed for the buffer.  */
      bystride = 256;
      ycount = 1;
    }
  else
    {
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
      bystride = GFC_DESCRIPTOR_STRIDE(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  abase = a->base_addr;
  bbase = b->base_addr;
  dest = retarray->base_addr;
  /* Now that everything is set up, we perform the multiplication
     itself.  */

#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
#define min(a,b) ((a) <= (b) ? (a) : (b))
#define max(a,b) ((a) >= (b) ? (a) : (b))
  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
      && (bxstride == 1 || bystride == 1)
      && (((float) xcount) * ((float) ycount) * ((float) count)
	  > POW3(blas_limit)))
    {
      const int m = xcount, n = ycount, k = count, ldc = rystride;
      const GFC_INTEGER_4 one = 1, zero = 0;
      const int lda = (axstride == 1) ? aystride : axstride,
		ldb = (bxstride == 1) ? bystride : bxstride;

      if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
	{
	  assert (gemm != NULL);
	  const char *transa, *transb;

	  transa = axstride == 1 ? "N" : "T";
	  transb = bxstride == 1 ? "N" : "T";

	  gemm (transa, transb, &m, &n, &k, &one, abase, &lda, bbase, &ldb,
		&zero, dest, &ldc, 1, 1);
	  return;
	}
    }
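  /* Note: the test above hands the work to the BLAS *gemm routine only when
     xcount*ycount*count exceeds blas_limit cubed, i.e. when the matrices are
     large enough to amortize the call.  lda/ldb pick the leading dimension of
     whichever index is contiguous, and transa/transb tell gemm whether A or B
     has to be treated as transposed.  */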
  if (rxstride == 1 && axstride == 1 && bxstride == 1)
    {
      /* This block of code implements a tuned matmul, derived from
	 Superscalar GEMM-based level 3 BLAS,  Beta version 0.1

	       Bo Kagstrom and Per Ling
	       Department of Computing Science
	       S-901 87 Umea, Sweden

	 from netlib.org, translated to C, and modified for matmul.m4.  */
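      /* The kernel below works on tiles: the j loop advances in sections of
	 512 columns, the l and i loops in sections of 256, and a block of A
	 is copied into the contiguous scratch buffer t1 so that the
	 innermost products read unit-stride data.  Within a tile, a 4x4
	 block of C is kept in the scalar accumulators f11..f44.  */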
      const GFC_INTEGER_4 *a, *b;
      GFC_INTEGER_4 *c;
      const index_type m = xcount, n = ycount, k = count;

      /* System generated locals */
      index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
		 i1, i2, i3, i4, i5, i6;

      /* Local variables */
      GFC_INTEGER_4 f11, f12, f21, f22, f31, f32, f41, f42,
		 f13, f14, f23, f24, f33, f34, f43, f44;
      index_type i, j, l, ii, jj, ll;
      index_type isec, jsec, lsec, uisec, ujsec, ulsec;
      GFC_INTEGER_4 *t1;
      a = abase;
      b = bbase;
      c = retarray->base_addr;

      /* Parameter adjustments */
      c_offset = 1 + c_dim1;
      a_offset = 1 + a_dim1;
      b_offset = 1 + b_dim1;

      /* Empty c first.  */
      for (j=1; j<=n; j++)
	for (i=1; i<=m; i++)
	  c[i + j * c_dim1] = (GFC_INTEGER_4)0;

      /* Early exit if possible */
      if (m == 0 || n == 0 || k == 0)
	return;

      /* Adjust size of t1 to what is needed.  */
      index_type t1_dim, a_sz;

      t1_dim = a_sz * 256 + b_dim1;

      t1 = malloc (t1_dim * sizeof(GFC_INTEGER_4));
343 /* Start turning the crank. */
345 for (jj
= 1; jj
<= i1
; jj
+= 512)
351 ujsec
= jsec
- jsec
% 4;
353 for (ll
= 1; ll
<= i2
; ll
+= 256)
359 ulsec
= lsec
- lsec
% 2;
362 for (ii
= 1; ii
<= i3
; ii
+= 256)
368 uisec
= isec
- isec
% 2;
370 for (l
= ll
; l
<= i4
; l
+= 2)
373 for (i
= ii
; i
<= i5
; i
+= 2)
375 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
377 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
378 a
[i
+ (l
+ 1) * a_dim1
];
379 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
380 a
[i
+ 1 + l
* a_dim1
];
381 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
382 a
[i
+ 1 + (l
+ 1) * a_dim1
];
386 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
387 a
[ii
+ isec
- 1 + l
* a_dim1
];
388 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
389 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
395 for (i
= ii
; i
<= i4
; ++i
)
397 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
398 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
402 uisec
= isec
- isec
% 4;
404 for (j
= jj
; j
<= i4
; j
+= 4)
407 for (i
= ii
; i
<= i5
; i
+= 4)
409 f11
= c
[i
+ j
* c_dim1
];
410 f21
= c
[i
+ 1 + j
* c_dim1
];
411 f12
= c
[i
+ (j
+ 1) * c_dim1
];
412 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
413 f13
= c
[i
+ (j
+ 2) * c_dim1
];
414 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
415 f14
= c
[i
+ (j
+ 3) * c_dim1
];
416 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
417 f31
= c
[i
+ 2 + j
* c_dim1
];
418 f41
= c
[i
+ 3 + j
* c_dim1
];
419 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
420 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
421 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
422 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
423 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
424 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
426 for (l
= ll
; l
<= i6
; ++l
)
428 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
430 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
432 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
433 * b
[l
+ (j
+ 1) * b_dim1
];
434 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
435 * b
[l
+ (j
+ 1) * b_dim1
];
436 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
437 * b
[l
+ (j
+ 2) * b_dim1
];
438 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
439 * b
[l
+ (j
+ 2) * b_dim1
];
440 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
441 * b
[l
+ (j
+ 3) * b_dim1
];
442 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
443 * b
[l
+ (j
+ 3) * b_dim1
];
444 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
446 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
448 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
449 * b
[l
+ (j
+ 1) * b_dim1
];
450 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
451 * b
[l
+ (j
+ 1) * b_dim1
];
452 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
453 * b
[l
+ (j
+ 2) * b_dim1
];
454 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
455 * b
[l
+ (j
+ 2) * b_dim1
];
456 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
457 * b
[l
+ (j
+ 3) * b_dim1
];
458 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
459 * b
[l
+ (j
+ 3) * b_dim1
];
461 c
[i
+ j
* c_dim1
] = f11
;
462 c
[i
+ 1 + j
* c_dim1
] = f21
;
463 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
464 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
465 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
466 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
467 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
468 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
469 c
[i
+ 2 + j
* c_dim1
] = f31
;
470 c
[i
+ 3 + j
* c_dim1
] = f41
;
471 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
472 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
473 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
474 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
475 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
476 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
481 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
483 f11
= c
[i
+ j
* c_dim1
];
484 f12
= c
[i
+ (j
+ 1) * c_dim1
];
485 f13
= c
[i
+ (j
+ 2) * c_dim1
];
486 f14
= c
[i
+ (j
+ 3) * c_dim1
];
488 for (l
= ll
; l
<= i6
; ++l
)
490 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
491 257] * b
[l
+ j
* b_dim1
];
492 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
493 257] * b
[l
+ (j
+ 1) * b_dim1
];
494 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
495 257] * b
[l
+ (j
+ 2) * b_dim1
];
496 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
497 257] * b
[l
+ (j
+ 3) * b_dim1
];
499 c
[i
+ j
* c_dim1
] = f11
;
500 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
501 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
502 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
509 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
512 for (i
= ii
; i
<= i5
; i
+= 4)
514 f11
= c
[i
+ j
* c_dim1
];
515 f21
= c
[i
+ 1 + j
* c_dim1
];
516 f31
= c
[i
+ 2 + j
* c_dim1
];
517 f41
= c
[i
+ 3 + j
* c_dim1
];
519 for (l
= ll
; l
<= i6
; ++l
)
521 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
522 257] * b
[l
+ j
* b_dim1
];
523 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
524 257] * b
[l
+ j
* b_dim1
];
525 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
526 257] * b
[l
+ j
* b_dim1
];
527 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
528 257] * b
[l
+ j
* b_dim1
];
530 c
[i
+ j
* c_dim1
] = f11
;
531 c
[i
+ 1 + j
* c_dim1
] = f21
;
532 c
[i
+ 2 + j
* c_dim1
] = f31
;
533 c
[i
+ 3 + j
* c_dim1
] = f41
;
536 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
538 f11
= c
[i
+ j
* c_dim1
];
540 for (l
= ll
; l
<= i6
; ++l
)
542 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
543 257] * b
[l
+ j
* b_dim1
];
545 c
[i
+ j
* c_dim1
] = f11
;
555 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
557 if (GFC_DESCRIPTOR_RANK (a
) != 1)
559 const GFC_INTEGER_4
*restrict abase_x
;
560 const GFC_INTEGER_4
*restrict bbase_y
;
561 GFC_INTEGER_4
*restrict dest_y
;
564 for (y
= 0; y
< ycount
; y
++)
566 bbase_y
= &bbase
[y
*bystride
];
567 dest_y
= &dest
[y
*rystride
];
568 for (x
= 0; x
< xcount
; x
++)
570 abase_x
= &abase
[x
*axstride
];
571 s
= (GFC_INTEGER_4
) 0;
572 for (n
= 0; n
< count
; n
++)
573 s
+= abase_x
[n
] * bbase_y
[n
];
580 const GFC_INTEGER_4
*restrict bbase_y
;
583 for (y
= 0; y
< ycount
; y
++)
585 bbase_y
= &bbase
[y
*bystride
];
586 s
= (GFC_INTEGER_4
) 0;
587 for (n
= 0; n
< count
; n
++)
588 s
+= abase
[n
*axstride
] * bbase_y
[n
];
589 dest
[y
*rystride
] = s
;
593 else if (axstride
< aystride
)
595 for (y
= 0; y
< ycount
; y
++)
596 for (x
= 0; x
< xcount
; x
++)
597 dest
[x
*rxstride
+ y
*rystride
] = (GFC_INTEGER_4
)0;
599 for (y
= 0; y
< ycount
; y
++)
600 for (n
= 0; n
< count
; n
++)
601 for (x
= 0; x
< xcount
; x
++)
602 /* dest[x,y] += a[x,n] * b[n,y] */
603 dest
[x
*rxstride
+ y
*rystride
] +=
604 abase
[x
*axstride
+ n
*aystride
] *
605 bbase
[n
*bxstride
+ y
*bystride
];
607 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
609 const GFC_INTEGER_4
*restrict bbase_y
;
612 for (y
= 0; y
< ycount
; y
++)
614 bbase_y
= &bbase
[y
*bystride
];
615 s
= (GFC_INTEGER_4
) 0;
616 for (n
= 0; n
< count
; n
++)
617 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
618 dest
[y
*rxstride
] = s
;
623 const GFC_INTEGER_4
*restrict abase_x
;
624 const GFC_INTEGER_4
*restrict bbase_y
;
625 GFC_INTEGER_4
*restrict dest_y
;
628 for (y
= 0; y
< ycount
; y
++)
630 bbase_y
= &bbase
[y
*bystride
];
631 dest_y
= &dest
[y
*rystride
];
632 for (x
= 0; x
< xcount
; x
++)
634 abase_x
= &abase
[x
*axstride
];
635 s
= (GFC_INTEGER_4
) 0;
636 for (n
= 0; n
< count
; n
++)
637 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
638 dest_y
[x
*rxstride
] = s
;
#endif  /* HAVE_AVX */
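/* Illustration only (helper name invented, not part of the library): the
   size test that every variant above applies before calling the BLAS gemm.
   blas_limit is the threshold supplied by the caller; the product of the
   extents is compared against its cube.  */

__attribute__((__unused__)) static int
matmul_i4_worth_blas (index_type xcount, index_type ycount, index_type count,
		      int blas_limit)
{
  float limit = (float) blas_limit;
  return ((float) xcount) * ((float) ycount) * ((float) count)
	 > limit * limit * limit;
}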
#ifdef HAVE_AVX2
static void
matmul_i4_avx2 (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm) __attribute__((__target__("avx2,fma")));
static void
matmul_i4_avx2 (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm)
{
659 const GFC_INTEGER_4
* restrict abase
;
660 const GFC_INTEGER_4
* restrict bbase
;
661 GFC_INTEGER_4
* restrict dest
;
663 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
664 index_type x
, y
, n
, count
, xcount
, ycount
;
666 assert (GFC_DESCRIPTOR_RANK (a
) == 2
667 || GFC_DESCRIPTOR_RANK (b
) == 2);
669 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
671 Either A or B (but not both) can be rank 1:
673 o One-dimensional argument A is implicitly treated as a row matrix
674 dimensioned [1,count], so xcount=1.
676 o One-dimensional argument B is implicitly treated as a column matrix
677 dimensioned [count, 1], so ycount=1.
680 if (retarray
->base_addr
== NULL
)
682 if (GFC_DESCRIPTOR_RANK (a
) == 1)
684 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
685 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
687 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
689 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
690 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
694 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
695 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
697 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
698 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
699 GFC_DESCRIPTOR_EXTENT(retarray
,0));
703 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_INTEGER_4
));
704 retarray
->offset
= 0;
706 else if (unlikely (compile_options
.bounds_check
))
708 index_type ret_extent
, arg_extent
;
710 if (GFC_DESCRIPTOR_RANK (a
) == 1)
712 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
713 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
714 if (arg_extent
!= ret_extent
)
715 runtime_error ("Array bound mismatch for dimension 1 of "
717 (long int) ret_extent
, (long int) arg_extent
);
719 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
721 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
722 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
723 if (arg_extent
!= ret_extent
)
724 runtime_error ("Array bound mismatch for dimension 1 of "
726 (long int) ret_extent
, (long int) arg_extent
);
730 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
731 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
732 if (arg_extent
!= ret_extent
)
733 runtime_error ("Array bound mismatch for dimension 1 of "
735 (long int) ret_extent
, (long int) arg_extent
);
737 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
738 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
739 if (arg_extent
!= ret_extent
)
740 runtime_error ("Array bound mismatch for dimension 2 of "
742 (long int) ret_extent
, (long int) arg_extent
);
747 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
749 /* One-dimensional result may be addressed in the code below
750 either as a row or a column matrix. We want both cases to
752 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
756 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
757 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
761 if (GFC_DESCRIPTOR_RANK (a
) == 1)
763 /* Treat it as a a row matrix A[1,count]. */
764 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
768 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
772 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
773 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
775 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
776 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
779 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
781 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
782 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
783 "in dimension 1: is %ld, should be %ld",
784 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
787 if (GFC_DESCRIPTOR_RANK (b
) == 1)
789 /* Treat it as a column matrix B[count,1] */
790 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
792 /* bystride should never be used for 1-dimensional b.
793 The value is only used for calculation of the
794 memory by the buffer. */
800 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
801 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
802 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
805 abase
= a
->base_addr
;
806 bbase
= b
->base_addr
;
807 dest
= retarray
->base_addr
;
809 /* Now that everything is set up, we perform the multiplication
812 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
813 #define min(a,b) ((a) <= (b) ? (a) : (b))
814 #define max(a,b) ((a) >= (b) ? (a) : (b))
816 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
817 && (bxstride
== 1 || bystride
== 1)
818 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
821 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
822 const GFC_INTEGER_4 one
= 1, zero
= 0;
823 const int lda
= (axstride
== 1) ? aystride
: axstride
,
824 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
826 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
828 assert (gemm
!= NULL
);
829 const char *transa
, *transb
;
833 transa
= axstride
== 1 ? "N" : "T";
838 transb
= bxstride
== 1 ? "N" : "T";
840 gemm (transa
, transb
, &m
,
841 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
847 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1)
849 /* This block of code implements a tuned matmul, derived from
850 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
852 Bo Kagstrom and Per Ling
853 Department of Computing Science
855 S-901 87 Umea, Sweden
857 from netlib.org, translated to C, and modified for matmul.m4. */
859 const GFC_INTEGER_4
*a
, *b
;
861 const index_type m
= xcount
, n
= ycount
, k
= count
;
863 /* System generated locals */
864 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
865 i1
, i2
, i3
, i4
, i5
, i6
;
867 /* Local variables */
868 GFC_INTEGER_4 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
869 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
870 index_type i
, j
, l
, ii
, jj
, ll
;
871 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
876 c
= retarray
->base_addr
;
878 /* Parameter adjustments */
880 c_offset
= 1 + c_dim1
;
883 a_offset
= 1 + a_dim1
;
886 b_offset
= 1 + b_dim1
;
892 c
[i
+ j
* c_dim1
] = (GFC_INTEGER_4
)0;
894 /* Early exit if possible */
895 if (m
== 0 || n
== 0 || k
== 0)
898 /* Adjust size of t1 to what is needed. */
899 index_type t1_dim
, a_sz
;
905 t1_dim
= a_sz
* 256 + b_dim1
;
909 t1
= malloc (t1_dim
* sizeof(GFC_INTEGER_4
));
911 /* Start turning the crank. */
913 for (jj
= 1; jj
<= i1
; jj
+= 512)
919 ujsec
= jsec
- jsec
% 4;
921 for (ll
= 1; ll
<= i2
; ll
+= 256)
927 ulsec
= lsec
- lsec
% 2;
930 for (ii
= 1; ii
<= i3
; ii
+= 256)
936 uisec
= isec
- isec
% 2;
938 for (l
= ll
; l
<= i4
; l
+= 2)
941 for (i
= ii
; i
<= i5
; i
+= 2)
943 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
945 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
946 a
[i
+ (l
+ 1) * a_dim1
];
947 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
948 a
[i
+ 1 + l
* a_dim1
];
949 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
950 a
[i
+ 1 + (l
+ 1) * a_dim1
];
954 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
955 a
[ii
+ isec
- 1 + l
* a_dim1
];
956 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
957 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
963 for (i
= ii
; i
<= i4
; ++i
)
965 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
966 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
970 uisec
= isec
- isec
% 4;
972 for (j
= jj
; j
<= i4
; j
+= 4)
975 for (i
= ii
; i
<= i5
; i
+= 4)
977 f11
= c
[i
+ j
* c_dim1
];
978 f21
= c
[i
+ 1 + j
* c_dim1
];
979 f12
= c
[i
+ (j
+ 1) * c_dim1
];
980 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
981 f13
= c
[i
+ (j
+ 2) * c_dim1
];
982 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
983 f14
= c
[i
+ (j
+ 3) * c_dim1
];
984 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
985 f31
= c
[i
+ 2 + j
* c_dim1
];
986 f41
= c
[i
+ 3 + j
* c_dim1
];
987 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
988 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
989 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
990 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
991 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
992 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
994 for (l
= ll
; l
<= i6
; ++l
)
996 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
998 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1000 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1001 * b
[l
+ (j
+ 1) * b_dim1
];
1002 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1003 * b
[l
+ (j
+ 1) * b_dim1
];
1004 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1005 * b
[l
+ (j
+ 2) * b_dim1
];
1006 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1007 * b
[l
+ (j
+ 2) * b_dim1
];
1008 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1009 * b
[l
+ (j
+ 3) * b_dim1
];
1010 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1011 * b
[l
+ (j
+ 3) * b_dim1
];
1012 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1013 * b
[l
+ j
* b_dim1
];
1014 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1015 * b
[l
+ j
* b_dim1
];
1016 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1017 * b
[l
+ (j
+ 1) * b_dim1
];
1018 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1019 * b
[l
+ (j
+ 1) * b_dim1
];
1020 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1021 * b
[l
+ (j
+ 2) * b_dim1
];
1022 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1023 * b
[l
+ (j
+ 2) * b_dim1
];
1024 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1025 * b
[l
+ (j
+ 3) * b_dim1
];
1026 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1027 * b
[l
+ (j
+ 3) * b_dim1
];
1029 c
[i
+ j
* c_dim1
] = f11
;
1030 c
[i
+ 1 + j
* c_dim1
] = f21
;
1031 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1032 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
1033 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1034 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
1035 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1036 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
1037 c
[i
+ 2 + j
* c_dim1
] = f31
;
1038 c
[i
+ 3 + j
* c_dim1
] = f41
;
1039 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
1040 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
1041 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
1042 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
1043 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
1044 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
1049 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1051 f11
= c
[i
+ j
* c_dim1
];
1052 f12
= c
[i
+ (j
+ 1) * c_dim1
];
1053 f13
= c
[i
+ (j
+ 2) * c_dim1
];
1054 f14
= c
[i
+ (j
+ 3) * c_dim1
];
1056 for (l
= ll
; l
<= i6
; ++l
)
1058 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1059 257] * b
[l
+ j
* b_dim1
];
1060 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1061 257] * b
[l
+ (j
+ 1) * b_dim1
];
1062 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1063 257] * b
[l
+ (j
+ 2) * b_dim1
];
1064 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1065 257] * b
[l
+ (j
+ 3) * b_dim1
];
1067 c
[i
+ j
* c_dim1
] = f11
;
1068 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1069 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1070 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1077 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
1079 i5
= ii
+ uisec
- 1;
1080 for (i
= ii
; i
<= i5
; i
+= 4)
1082 f11
= c
[i
+ j
* c_dim1
];
1083 f21
= c
[i
+ 1 + j
* c_dim1
];
1084 f31
= c
[i
+ 2 + j
* c_dim1
];
1085 f41
= c
[i
+ 3 + j
* c_dim1
];
1087 for (l
= ll
; l
<= i6
; ++l
)
1089 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1090 257] * b
[l
+ j
* b_dim1
];
1091 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
1092 257] * b
[l
+ j
* b_dim1
];
1093 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
1094 257] * b
[l
+ j
* b_dim1
];
1095 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
1096 257] * b
[l
+ j
* b_dim1
];
1098 c
[i
+ j
* c_dim1
] = f11
;
1099 c
[i
+ 1 + j
* c_dim1
] = f21
;
1100 c
[i
+ 2 + j
* c_dim1
] = f31
;
1101 c
[i
+ 3 + j
* c_dim1
] = f41
;
1104 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1106 f11
= c
[i
+ j
* c_dim1
];
1108 for (l
= ll
; l
<= i6
; ++l
)
1110 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1111 257] * b
[l
+ j
* b_dim1
];
1113 c
[i
+ j
* c_dim1
] = f11
;
1123 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
1125 if (GFC_DESCRIPTOR_RANK (a
) != 1)
1127 const GFC_INTEGER_4
*restrict abase_x
;
1128 const GFC_INTEGER_4
*restrict bbase_y
;
1129 GFC_INTEGER_4
*restrict dest_y
;
1132 for (y
= 0; y
< ycount
; y
++)
1134 bbase_y
= &bbase
[y
*bystride
];
1135 dest_y
= &dest
[y
*rystride
];
1136 for (x
= 0; x
< xcount
; x
++)
1138 abase_x
= &abase
[x
*axstride
];
1139 s
= (GFC_INTEGER_4
) 0;
1140 for (n
= 0; n
< count
; n
++)
1141 s
+= abase_x
[n
] * bbase_y
[n
];
1148 const GFC_INTEGER_4
*restrict bbase_y
;
1151 for (y
= 0; y
< ycount
; y
++)
1153 bbase_y
= &bbase
[y
*bystride
];
1154 s
= (GFC_INTEGER_4
) 0;
1155 for (n
= 0; n
< count
; n
++)
1156 s
+= abase
[n
*axstride
] * bbase_y
[n
];
1157 dest
[y
*rystride
] = s
;
1161 else if (axstride
< aystride
)
1163 for (y
= 0; y
< ycount
; y
++)
1164 for (x
= 0; x
< xcount
; x
++)
1165 dest
[x
*rxstride
+ y
*rystride
] = (GFC_INTEGER_4
)0;
1167 for (y
= 0; y
< ycount
; y
++)
1168 for (n
= 0; n
< count
; n
++)
1169 for (x
= 0; x
< xcount
; x
++)
1170 /* dest[x,y] += a[x,n] * b[n,y] */
1171 dest
[x
*rxstride
+ y
*rystride
] +=
1172 abase
[x
*axstride
+ n
*aystride
] *
1173 bbase
[n
*bxstride
+ y
*bystride
];
1175 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
1177 const GFC_INTEGER_4
*restrict bbase_y
;
1180 for (y
= 0; y
< ycount
; y
++)
1182 bbase_y
= &bbase
[y
*bystride
];
1183 s
= (GFC_INTEGER_4
) 0;
1184 for (n
= 0; n
< count
; n
++)
1185 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
1186 dest
[y
*rxstride
] = s
;
1191 const GFC_INTEGER_4
*restrict abase_x
;
1192 const GFC_INTEGER_4
*restrict bbase_y
;
1193 GFC_INTEGER_4
*restrict dest_y
;
1196 for (y
= 0; y
< ycount
; y
++)
1198 bbase_y
= &bbase
[y
*bystride
];
1199 dest_y
= &dest
[y
*rystride
];
1200 for (x
= 0; x
< xcount
; x
++)
1202 abase_x
= &abase
[x
*axstride
];
1203 s
= (GFC_INTEGER_4
) 0;
1204 for (n
= 0; n
< count
; n
++)
1205 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
1206 dest_y
[x
*rxstride
] = s
;
#endif  /* HAVE_AVX2 */

#ifdef HAVE_AVX512F
static void
matmul_i4_avx512f (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm) __attribute__((__target__("avx512f")));
static void
matmul_i4_avx512f (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm)
{
1227 const GFC_INTEGER_4
* restrict abase
;
1228 const GFC_INTEGER_4
* restrict bbase
;
1229 GFC_INTEGER_4
* restrict dest
;
1231 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
1232 index_type x
, y
, n
, count
, xcount
, ycount
;
1234 assert (GFC_DESCRIPTOR_RANK (a
) == 2
1235 || GFC_DESCRIPTOR_RANK (b
) == 2);
1237 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
1239 Either A or B (but not both) can be rank 1:
1241 o One-dimensional argument A is implicitly treated as a row matrix
1242 dimensioned [1,count], so xcount=1.
1244 o One-dimensional argument B is implicitly treated as a column matrix
1245 dimensioned [count, 1], so ycount=1.
1248 if (retarray
->base_addr
== NULL
)
1250 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1252 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1253 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
1255 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
1257 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1258 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
1262 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1263 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
1265 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
1266 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
1267 GFC_DESCRIPTOR_EXTENT(retarray
,0));
1271 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_INTEGER_4
));
1272 retarray
->offset
= 0;
1274 else if (unlikely (compile_options
.bounds_check
))
1276 index_type ret_extent
, arg_extent
;
1278 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1280 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
1281 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1282 if (arg_extent
!= ret_extent
)
1283 runtime_error ("Array bound mismatch for dimension 1 of "
1285 (long int) ret_extent
, (long int) arg_extent
);
1287 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
1289 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
1290 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1291 if (arg_extent
!= ret_extent
)
1292 runtime_error ("Array bound mismatch for dimension 1 of "
1294 (long int) ret_extent
, (long int) arg_extent
);
1298 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
1299 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1300 if (arg_extent
!= ret_extent
)
1301 runtime_error ("Array bound mismatch for dimension 1 of "
1303 (long int) ret_extent
, (long int) arg_extent
);
1305 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
1306 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
1307 if (arg_extent
!= ret_extent
)
1308 runtime_error ("Array bound mismatch for dimension 2 of "
1310 (long int) ret_extent
, (long int) arg_extent
);
1315 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
1317 /* One-dimensional result may be addressed in the code below
1318 either as a row or a column matrix. We want both cases to
1320 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
1324 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
1325 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
1329 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1331 /* Treat it as a a row matrix A[1,count]. */
1332 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
1336 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
1340 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
1341 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
1343 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
1344 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
1347 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
1349 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
1350 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
1351 "in dimension 1: is %ld, should be %ld",
1352 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
1355 if (GFC_DESCRIPTOR_RANK (b
) == 1)
1357 /* Treat it as a column matrix B[count,1] */
1358 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
1360 /* bystride should never be used for 1-dimensional b.
1361 The value is only used for calculation of the
1362 memory by the buffer. */
1368 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
1369 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
1370 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
1373 abase
= a
->base_addr
;
1374 bbase
= b
->base_addr
;
1375 dest
= retarray
->base_addr
;
1377 /* Now that everything is set up, we perform the multiplication
1380 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
1381 #define min(a,b) ((a) <= (b) ? (a) : (b))
1382 #define max(a,b) ((a) >= (b) ? (a) : (b))
1384 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
1385 && (bxstride
== 1 || bystride
== 1)
1386 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
1387 > POW3(blas_limit
)))
1389 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
1390 const GFC_INTEGER_4 one
= 1, zero
= 0;
1391 const int lda
= (axstride
== 1) ? aystride
: axstride
,
1392 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
1394 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
1396 assert (gemm
!= NULL
);
1397 const char *transa
, *transb
;
1401 transa
= axstride
== 1 ? "N" : "T";
1406 transb
= bxstride
== 1 ? "N" : "T";
1408 gemm (transa
, transb
, &m
,
1409 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
1415 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1)
1417 /* This block of code implements a tuned matmul, derived from
1418 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
1420 Bo Kagstrom and Per Ling
1421 Department of Computing Science
1423 S-901 87 Umea, Sweden
1425 from netlib.org, translated to C, and modified for matmul.m4. */
1427 const GFC_INTEGER_4
*a
, *b
;
1429 const index_type m
= xcount
, n
= ycount
, k
= count
;
1431 /* System generated locals */
1432 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
1433 i1
, i2
, i3
, i4
, i5
, i6
;
1435 /* Local variables */
1436 GFC_INTEGER_4 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
1437 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
1438 index_type i
, j
, l
, ii
, jj
, ll
;
1439 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
1444 c
= retarray
->base_addr
;
1446 /* Parameter adjustments */
1448 c_offset
= 1 + c_dim1
;
1451 a_offset
= 1 + a_dim1
;
1454 b_offset
= 1 + b_dim1
;
1457 /* Empty c first. */
1458 for (j
=1; j
<=n
; j
++)
1459 for (i
=1; i
<=m
; i
++)
1460 c
[i
+ j
* c_dim1
] = (GFC_INTEGER_4
)0;
1462 /* Early exit if possible */
1463 if (m
== 0 || n
== 0 || k
== 0)
1466 /* Adjust size of t1 to what is needed. */
1467 index_type t1_dim
, a_sz
;
1473 t1_dim
= a_sz
* 256 + b_dim1
;
1477 t1
= malloc (t1_dim
* sizeof(GFC_INTEGER_4
));
1479 /* Start turning the crank. */
1481 for (jj
= 1; jj
<= i1
; jj
+= 512)
1487 ujsec
= jsec
- jsec
% 4;
1489 for (ll
= 1; ll
<= i2
; ll
+= 256)
1495 ulsec
= lsec
- lsec
% 2;
1498 for (ii
= 1; ii
<= i3
; ii
+= 256)
1504 uisec
= isec
- isec
% 2;
1505 i4
= ll
+ ulsec
- 1;
1506 for (l
= ll
; l
<= i4
; l
+= 2)
1508 i5
= ii
+ uisec
- 1;
1509 for (i
= ii
; i
<= i5
; i
+= 2)
1511 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
1513 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
1514 a
[i
+ (l
+ 1) * a_dim1
];
1515 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
1516 a
[i
+ 1 + l
* a_dim1
];
1517 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
1518 a
[i
+ 1 + (l
+ 1) * a_dim1
];
1522 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
1523 a
[ii
+ isec
- 1 + l
* a_dim1
];
1524 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
1525 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
1531 for (i
= ii
; i
<= i4
; ++i
)
1533 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
1534 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
1538 uisec
= isec
- isec
% 4;
1539 i4
= jj
+ ujsec
- 1;
1540 for (j
= jj
; j
<= i4
; j
+= 4)
1542 i5
= ii
+ uisec
- 1;
1543 for (i
= ii
; i
<= i5
; i
+= 4)
1545 f11
= c
[i
+ j
* c_dim1
];
1546 f21
= c
[i
+ 1 + j
* c_dim1
];
1547 f12
= c
[i
+ (j
+ 1) * c_dim1
];
1548 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
1549 f13
= c
[i
+ (j
+ 2) * c_dim1
];
1550 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
1551 f14
= c
[i
+ (j
+ 3) * c_dim1
];
1552 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
1553 f31
= c
[i
+ 2 + j
* c_dim1
];
1554 f41
= c
[i
+ 3 + j
* c_dim1
];
1555 f32
= c
[i
+ 2 + (j
+ 1) * c_dim1
];
1556 f42
= c
[i
+ 3 + (j
+ 1) * c_dim1
];
1557 f33
= c
[i
+ 2 + (j
+ 2) * c_dim1
];
1558 f43
= c
[i
+ 3 + (j
+ 2) * c_dim1
];
1559 f34
= c
[i
+ 2 + (j
+ 3) * c_dim1
];
1560 f44
= c
[i
+ 3 + (j
+ 3) * c_dim1
];
1562 for (l
= ll
; l
<= i6
; ++l
)
1564 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1565 * b
[l
+ j
* b_dim1
];
1566 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1567 * b
[l
+ j
* b_dim1
];
1568 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1569 * b
[l
+ (j
+ 1) * b_dim1
];
1570 f22
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1571 * b
[l
+ (j
+ 1) * b_dim1
];
1572 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1573 * b
[l
+ (j
+ 2) * b_dim1
];
1574 f23
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1575 * b
[l
+ (j
+ 2) * b_dim1
];
1576 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257]
1577 * b
[l
+ (j
+ 3) * b_dim1
];
1578 f24
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257]
1579 * b
[l
+ (j
+ 3) * b_dim1
];
1580 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1581 * b
[l
+ j
* b_dim1
];
1582 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1583 * b
[l
+ j
* b_dim1
];
1584 f32
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1585 * b
[l
+ (j
+ 1) * b_dim1
];
1586 f42
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1587 * b
[l
+ (j
+ 1) * b_dim1
];
1588 f33
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1589 * b
[l
+ (j
+ 2) * b_dim1
];
1590 f43
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1591 * b
[l
+ (j
+ 2) * b_dim1
];
1592 f34
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) - 257]
1593 * b
[l
+ (j
+ 3) * b_dim1
];
1594 f44
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) - 257]
1595 * b
[l
+ (j
+ 3) * b_dim1
];
1597 c
[i
+ j
* c_dim1
] = f11
;
1598 c
[i
+ 1 + j
* c_dim1
] = f21
;
1599 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1600 c
[i
+ 1 + (j
+ 1) * c_dim1
] = f22
;
1601 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1602 c
[i
+ 1 + (j
+ 2) * c_dim1
] = f23
;
1603 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1604 c
[i
+ 1 + (j
+ 3) * c_dim1
] = f24
;
1605 c
[i
+ 2 + j
* c_dim1
] = f31
;
1606 c
[i
+ 3 + j
* c_dim1
] = f41
;
1607 c
[i
+ 2 + (j
+ 1) * c_dim1
] = f32
;
1608 c
[i
+ 3 + (j
+ 1) * c_dim1
] = f42
;
1609 c
[i
+ 2 + (j
+ 2) * c_dim1
] = f33
;
1610 c
[i
+ 3 + (j
+ 2) * c_dim1
] = f43
;
1611 c
[i
+ 2 + (j
+ 3) * c_dim1
] = f34
;
1612 c
[i
+ 3 + (j
+ 3) * c_dim1
] = f44
;
1617 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1619 f11
= c
[i
+ j
* c_dim1
];
1620 f12
= c
[i
+ (j
+ 1) * c_dim1
];
1621 f13
= c
[i
+ (j
+ 2) * c_dim1
];
1622 f14
= c
[i
+ (j
+ 3) * c_dim1
];
1624 for (l
= ll
; l
<= i6
; ++l
)
1626 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1627 257] * b
[l
+ j
* b_dim1
];
1628 f12
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1629 257] * b
[l
+ (j
+ 1) * b_dim1
];
1630 f13
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1631 257] * b
[l
+ (j
+ 2) * b_dim1
];
1632 f14
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1633 257] * b
[l
+ (j
+ 3) * b_dim1
];
1635 c
[i
+ j
* c_dim1
] = f11
;
1636 c
[i
+ (j
+ 1) * c_dim1
] = f12
;
1637 c
[i
+ (j
+ 2) * c_dim1
] = f13
;
1638 c
[i
+ (j
+ 3) * c_dim1
] = f14
;
1645 for (j
= jj
+ ujsec
; j
<= i4
; ++j
)
1647 i5
= ii
+ uisec
- 1;
1648 for (i
= ii
; i
<= i5
; i
+= 4)
1650 f11
= c
[i
+ j
* c_dim1
];
1651 f21
= c
[i
+ 1 + j
* c_dim1
];
1652 f31
= c
[i
+ 2 + j
* c_dim1
];
1653 f41
= c
[i
+ 3 + j
* c_dim1
];
1655 for (l
= ll
; l
<= i6
; ++l
)
1657 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1658 257] * b
[l
+ j
* b_dim1
];
1659 f21
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) -
1660 257] * b
[l
+ j
* b_dim1
];
1661 f31
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 3) << 8) -
1662 257] * b
[l
+ j
* b_dim1
];
1663 f41
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 4) << 8) -
1664 257] * b
[l
+ j
* b_dim1
];
1666 c
[i
+ j
* c_dim1
] = f11
;
1667 c
[i
+ 1 + j
* c_dim1
] = f21
;
1668 c
[i
+ 2 + j
* c_dim1
] = f31
;
1669 c
[i
+ 3 + j
* c_dim1
] = f41
;
1672 for (i
= ii
+ uisec
; i
<= i5
; ++i
)
1674 f11
= c
[i
+ j
* c_dim1
];
1676 for (l
= ll
; l
<= i6
; ++l
)
1678 f11
+= t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) -
1679 257] * b
[l
+ j
* b_dim1
];
1681 c
[i
+ j
* c_dim1
] = f11
;
1691 else if (rxstride
== 1 && aystride
== 1 && bxstride
== 1)
1693 if (GFC_DESCRIPTOR_RANK (a
) != 1)
1695 const GFC_INTEGER_4
*restrict abase_x
;
1696 const GFC_INTEGER_4
*restrict bbase_y
;
1697 GFC_INTEGER_4
*restrict dest_y
;
1700 for (y
= 0; y
< ycount
; y
++)
1702 bbase_y
= &bbase
[y
*bystride
];
1703 dest_y
= &dest
[y
*rystride
];
1704 for (x
= 0; x
< xcount
; x
++)
1706 abase_x
= &abase
[x
*axstride
];
1707 s
= (GFC_INTEGER_4
) 0;
1708 for (n
= 0; n
< count
; n
++)
1709 s
+= abase_x
[n
] * bbase_y
[n
];
1716 const GFC_INTEGER_4
*restrict bbase_y
;
1719 for (y
= 0; y
< ycount
; y
++)
1721 bbase_y
= &bbase
[y
*bystride
];
1722 s
= (GFC_INTEGER_4
) 0;
1723 for (n
= 0; n
< count
; n
++)
1724 s
+= abase
[n
*axstride
] * bbase_y
[n
];
1725 dest
[y
*rystride
] = s
;
1729 else if (axstride
< aystride
)
1731 for (y
= 0; y
< ycount
; y
++)
1732 for (x
= 0; x
< xcount
; x
++)
1733 dest
[x
*rxstride
+ y
*rystride
] = (GFC_INTEGER_4
)0;
1735 for (y
= 0; y
< ycount
; y
++)
1736 for (n
= 0; n
< count
; n
++)
1737 for (x
= 0; x
< xcount
; x
++)
1738 /* dest[x,y] += a[x,n] * b[n,y] */
1739 dest
[x
*rxstride
+ y
*rystride
] +=
1740 abase
[x
*axstride
+ n
*aystride
] *
1741 bbase
[n
*bxstride
+ y
*bystride
];
1743 else if (GFC_DESCRIPTOR_RANK (a
) == 1)
1745 const GFC_INTEGER_4
*restrict bbase_y
;
1748 for (y
= 0; y
< ycount
; y
++)
1750 bbase_y
= &bbase
[y
*bystride
];
1751 s
= (GFC_INTEGER_4
) 0;
1752 for (n
= 0; n
< count
; n
++)
1753 s
+= abase
[n
*axstride
] * bbase_y
[n
*bxstride
];
1754 dest
[y
*rxstride
] = s
;
1759 const GFC_INTEGER_4
*restrict abase_x
;
1760 const GFC_INTEGER_4
*restrict bbase_y
;
1761 GFC_INTEGER_4
*restrict dest_y
;
1764 for (y
= 0; y
< ycount
; y
++)
1766 bbase_y
= &bbase
[y
*bystride
];
1767 dest_y
= &dest
[y
*rystride
];
1768 for (x
= 0; x
< xcount
; x
++)
1770 abase_x
= &abase
[x
*axstride
];
1771 s
= (GFC_INTEGER_4
) 0;
1772 for (n
= 0; n
< count
; n
++)
1773 s
+= abase_x
[n
*aystride
] * bbase_y
[n
*bxstride
];
1774 dest_y
[x
*rxstride
] = s
;
#endif  /* HAVE_AVX512F */

/* AMD-specific functions with AVX128 and FMA3/FMA4.  */

#if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
static void
matmul_i4_avx128_fma3 (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma")));
internal_proto(matmul_i4_avx128_fma3);
#endif
#if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
static void
matmul_i4_avx128_fma4 (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma4")));
internal_proto(matmul_i4_avx128_fma4);
#endif
/* Function to fall back to if there is no special processor-specific version.  */
static void
matmul_i4_vanilla (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm)
{
1809 const GFC_INTEGER_4
* restrict abase
;
1810 const GFC_INTEGER_4
* restrict bbase
;
1811 GFC_INTEGER_4
* restrict dest
;
1813 index_type rxstride
, rystride
, axstride
, aystride
, bxstride
, bystride
;
1814 index_type x
, y
, n
, count
, xcount
, ycount
;
1816 assert (GFC_DESCRIPTOR_RANK (a
) == 2
1817 || GFC_DESCRIPTOR_RANK (b
) == 2);
1819 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
1821 Either A or B (but not both) can be rank 1:
1823 o One-dimensional argument A is implicitly treated as a row matrix
1824 dimensioned [1,count], so xcount=1.
1826 o One-dimensional argument B is implicitly treated as a column matrix
1827 dimensioned [count, 1], so ycount=1.
1830 if (retarray
->base_addr
== NULL
)
1832 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1834 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1835 GFC_DESCRIPTOR_EXTENT(b
,1) - 1, 1);
1837 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
1839 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1840 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
1844 GFC_DIMENSION_SET(retarray
->dim
[0], 0,
1845 GFC_DESCRIPTOR_EXTENT(a
,0) - 1, 1);
1847 GFC_DIMENSION_SET(retarray
->dim
[1], 0,
1848 GFC_DESCRIPTOR_EXTENT(b
,1) - 1,
1849 GFC_DESCRIPTOR_EXTENT(retarray
,0));
1853 = xmallocarray (size0 ((array_t
*) retarray
), sizeof (GFC_INTEGER_4
));
1854 retarray
->offset
= 0;
1856 else if (unlikely (compile_options
.bounds_check
))
1858 index_type ret_extent
, arg_extent
;
1860 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1862 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
1863 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1864 if (arg_extent
!= ret_extent
)
1865 runtime_error ("Array bound mismatch for dimension 1 of "
1867 (long int) ret_extent
, (long int) arg_extent
);
1869 else if (GFC_DESCRIPTOR_RANK (b
) == 1)
1871 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
1872 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1873 if (arg_extent
!= ret_extent
)
1874 runtime_error ("Array bound mismatch for dimension 1 of "
1876 (long int) ret_extent
, (long int) arg_extent
);
1880 arg_extent
= GFC_DESCRIPTOR_EXTENT(a
,0);
1881 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,0);
1882 if (arg_extent
!= ret_extent
)
1883 runtime_error ("Array bound mismatch for dimension 1 of "
1885 (long int) ret_extent
, (long int) arg_extent
);
1887 arg_extent
= GFC_DESCRIPTOR_EXTENT(b
,1);
1888 ret_extent
= GFC_DESCRIPTOR_EXTENT(retarray
,1);
1889 if (arg_extent
!= ret_extent
)
1890 runtime_error ("Array bound mismatch for dimension 2 of "
1892 (long int) ret_extent
, (long int) arg_extent
);
1897 if (GFC_DESCRIPTOR_RANK (retarray
) == 1)
1899 /* One-dimensional result may be addressed in the code below
1900 either as a row or a column matrix. We want both cases to
1902 rxstride
= rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
1906 rxstride
= GFC_DESCRIPTOR_STRIDE(retarray
,0);
1907 rystride
= GFC_DESCRIPTOR_STRIDE(retarray
,1);
1911 if (GFC_DESCRIPTOR_RANK (a
) == 1)
1913 /* Treat it as a a row matrix A[1,count]. */
1914 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
1918 count
= GFC_DESCRIPTOR_EXTENT(a
,0);
1922 axstride
= GFC_DESCRIPTOR_STRIDE(a
,0);
1923 aystride
= GFC_DESCRIPTOR_STRIDE(a
,1);
1925 count
= GFC_DESCRIPTOR_EXTENT(a
,1);
1926 xcount
= GFC_DESCRIPTOR_EXTENT(a
,0);
1929 if (count
!= GFC_DESCRIPTOR_EXTENT(b
,0))
1931 if (count
> 0 || GFC_DESCRIPTOR_EXTENT(b
,0) > 0)
1932 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
1933 "in dimension 1: is %ld, should be %ld",
1934 (long int) GFC_DESCRIPTOR_EXTENT(b
,0), (long int) count
);
1937 if (GFC_DESCRIPTOR_RANK (b
) == 1)
1939 /* Treat it as a column matrix B[count,1] */
1940 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
1942 /* bystride should never be used for 1-dimensional b.
1943 The value is only used for calculation of the
1944 memory by the buffer. */
1950 bxstride
= GFC_DESCRIPTOR_STRIDE(b
,0);
1951 bystride
= GFC_DESCRIPTOR_STRIDE(b
,1);
1952 ycount
= GFC_DESCRIPTOR_EXTENT(b
,1);
1955 abase
= a
->base_addr
;
1956 bbase
= b
->base_addr
;
1957 dest
= retarray
->base_addr
;
1959 /* Now that everything is set up, we perform the multiplication
1962 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
1963 #define min(a,b) ((a) <= (b) ? (a) : (b))
1964 #define max(a,b) ((a) >= (b) ? (a) : (b))
1966 if (try_blas
&& rxstride
== 1 && (axstride
== 1 || aystride
== 1)
1967 && (bxstride
== 1 || bystride
== 1)
1968 && (((float) xcount
) * ((float) ycount
) * ((float) count
)
1969 > POW3(blas_limit
)))
1971 const int m
= xcount
, n
= ycount
, k
= count
, ldc
= rystride
;
1972 const GFC_INTEGER_4 one
= 1, zero
= 0;
1973 const int lda
= (axstride
== 1) ? aystride
: axstride
,
1974 ldb
= (bxstride
== 1) ? bystride
: bxstride
;
1976 if (lda
> 0 && ldb
> 0 && ldc
> 0 && m
> 1 && n
> 1 && k
> 1)
1978 assert (gemm
!= NULL
);
1979 const char *transa
, *transb
;
1983 transa
= axstride
== 1 ? "N" : "T";
1988 transb
= bxstride
== 1 ? "N" : "T";
1990 gemm (transa
, transb
, &m
,
1991 &n
, &k
, &one
, abase
, &lda
, bbase
, &ldb
, &zero
, dest
,
1997 if (rxstride
== 1 && axstride
== 1 && bxstride
== 1)
1999 /* This block of code implements a tuned matmul, derived from
2000 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
2002 Bo Kagstrom and Per Ling
2003 Department of Computing Science
2005 S-901 87 Umea, Sweden
2007 from netlib.org, translated to C, and modified for matmul.m4. */
2009 const GFC_INTEGER_4
*a
, *b
;
2011 const index_type m
= xcount
, n
= ycount
, k
= count
;
2013 /* System generated locals */
2014 index_type a_dim1
, a_offset
, b_dim1
, b_offset
, c_dim1
, c_offset
,
2015 i1
, i2
, i3
, i4
, i5
, i6
;
2017 /* Local variables */
2018 GFC_INTEGER_4 f11
, f12
, f21
, f22
, f31
, f32
, f41
, f42
,
2019 f13
, f14
, f23
, f24
, f33
, f34
, f43
, f44
;
2020 index_type i
, j
, l
, ii
, jj
, ll
;
2021 index_type isec
, jsec
, lsec
, uisec
, ujsec
, ulsec
;
2026 c
= retarray
->base_addr
;
2028 /* Parameter adjustments */
2030 c_offset
= 1 + c_dim1
;
2033 a_offset
= 1 + a_dim1
;
2036 b_offset
= 1 + b_dim1
;
2039 /* Empty c first. */
2040 for (j
=1; j
<=n
; j
++)
2041 for (i
=1; i
<=m
; i
++)
2042 c
[i
+ j
* c_dim1
] = (GFC_INTEGER_4
)0;
2044 /* Early exit if possible */
2045 if (m
== 0 || n
== 0 || k
== 0)
2048 /* Adjust size of t1 to what is needed. */
2049 index_type t1_dim
, a_sz
;
2055 t1_dim
= a_sz
* 256 + b_dim1
;
2059 t1
= malloc (t1_dim
* sizeof(GFC_INTEGER_4
));
2061 /* Start turning the crank. */
2063 for (jj
= 1; jj
<= i1
; jj
+= 512)
2069 ujsec
= jsec
- jsec
% 4;
2071 for (ll
= 1; ll
<= i2
; ll
+= 256)
2077 ulsec
= lsec
- lsec
% 2;
2080 for (ii
= 1; ii
<= i3
; ii
+= 256)
2086 uisec
= isec
- isec
% 2;
2087 i4
= ll
+ ulsec
- 1;
2088 for (l
= ll
; l
<= i4
; l
+= 2)
2090 i5
= ii
+ uisec
- 1;
2091 for (i
= ii
; i
<= i5
; i
+= 2)
2093 t1
[l
- ll
+ 1 + ((i
- ii
+ 1) << 8) - 257] =
2095 t1
[l
- ll
+ 2 + ((i
- ii
+ 1) << 8) - 257] =
2096 a
[i
+ (l
+ 1) * a_dim1
];
2097 t1
[l
- ll
+ 1 + ((i
- ii
+ 2) << 8) - 257] =
2098 a
[i
+ 1 + l
* a_dim1
];
2099 t1
[l
- ll
+ 2 + ((i
- ii
+ 2) << 8) - 257] =
2100 a
[i
+ 1 + (l
+ 1) * a_dim1
];
2104 t1
[l
- ll
+ 1 + (isec
<< 8) - 257] =
2105 a
[ii
+ isec
- 1 + l
* a_dim1
];
2106 t1
[l
- ll
+ 2 + (isec
<< 8) - 257] =
2107 a
[ii
+ isec
- 1 + (l
+ 1) * a_dim1
];
2113 for (i
= ii
; i
<= i4
; ++i
)
2115 t1
[lsec
+ ((i
- ii
+ 1) << 8) - 257] =
2116 a
[i
+ (ll
+ lsec
- 1) * a_dim1
];
2120 uisec
= isec
- isec
% 4;
2121 i4
= jj
+ ujsec
- 1;
2122 for (j
= jj
; j
<= i4
; j
+= 4)
2124 i5
= ii
+ uisec
- 1;
2125 for (i
= ii
; i
<= i5
; i
+= 4)
2127 f11
= c
[i
+ j
* c_dim1
];
2128 f21
= c
[i
+ 1 + j
* c_dim1
];
2129 f12
= c
[i
+ (j
+ 1) * c_dim1
];
2130 f22
= c
[i
+ 1 + (j
+ 1) * c_dim1
];
2131 f13
= c
[i
+ (j
+ 2) * c_dim1
];
2132 f23
= c
[i
+ 1 + (j
+ 2) * c_dim1
];
2133 f14
= c
[i
+ (j
+ 3) * c_dim1
];
2134 f24
= c
[i
+ 1 + (j
+ 3) * c_dim1
];
2135 f31
              = c[i + 2 + j * c_dim1];
              f41 = c[i + 3 + j * c_dim1];
              f32 = c[i + 2 + (j + 1) * c_dim1];
              f42 = c[i + 3 + (j + 1) * c_dim1];
              f33 = c[i + 2 + (j + 2) * c_dim1];
              f43 = c[i + 3 + (j + 2) * c_dim1];
              f34 = c[i + 2 + (j + 3) * c_dim1];
              f44 = c[i + 3 + (j + 3) * c_dim1];
              i6 = ll + lsec - 1;
              for (l = ll; l <= i6; ++l)
                {
                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                          * b[l + j * b_dim1];
                  f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                          * b[l + j * b_dim1];
                  f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                          * b[l + (j + 1) * b_dim1];
                  f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                          * b[l + (j + 1) * b_dim1];
                  f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                          * b[l + (j + 2) * b_dim1];
                  f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                          * b[l + (j + 2) * b_dim1];
                  f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                          * b[l + (j + 3) * b_dim1];
                  f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                          * b[l + (j + 3) * b_dim1];
                  f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                          * b[l + j * b_dim1];
                  f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                          * b[l + j * b_dim1];
                  f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                          * b[l + (j + 1) * b_dim1];
                  f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                          * b[l + (j + 1) * b_dim1];
                  f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                          * b[l + (j + 2) * b_dim1];
                  f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                          * b[l + (j + 2) * b_dim1];
                  f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                          * b[l + (j + 3) * b_dim1];
                  f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                          * b[l + (j + 3) * b_dim1];
                }
              c[i + j * c_dim1] = f11;
              c[i + 1 + j * c_dim1] = f21;
              c[i + (j + 1) * c_dim1] = f12;
              c[i + 1 + (j + 1) * c_dim1] = f22;
              c[i + (j + 2) * c_dim1] = f13;
              c[i + 1 + (j + 2) * c_dim1] = f23;
              c[i + (j + 3) * c_dim1] = f14;
              c[i + 1 + (j + 3) * c_dim1] = f24;
              c[i + 2 + j * c_dim1] = f31;
              c[i + 3 + j * c_dim1] = f41;
              c[i + 2 + (j + 1) * c_dim1] = f32;
              c[i + 3 + (j + 1) * c_dim1] = f42;
              c[i + 2 + (j + 2) * c_dim1] = f33;
              c[i + 3 + (j + 2) * c_dim1] = f43;
              c[i + 2 + (j + 3) * c_dim1] = f34;
              c[i + 3 + (j + 3) * c_dim1] = f44;
            }
          if (uisec < isec)
            {
              i5 = ii + isec - 1;
              for (i = ii + uisec; i <= i5; ++i)
                {
                  f11 = c[i + j * c_dim1];
                  f12 = c[i + (j + 1) * c_dim1];
                  f13 = c[i + (j + 2) * c_dim1];
                  f14 = c[i + (j + 3) * c_dim1];
                  i6 = ll + lsec - 1;
                  for (l = ll; l <= i6; ++l)
                    {
                      f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                257] * b[l + j * b_dim1];
                      f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                257] * b[l + (j + 1) * b_dim1];
                      f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                257] * b[l + (j + 2) * b_dim1];
                      f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                257] * b[l + (j + 3) * b_dim1];
                    }
                  c[i + j * c_dim1] = f11;
                  c[i + (j + 1) * c_dim1] = f12;
                  c[i + (j + 2) * c_dim1] = f13;
                  c[i + (j + 3) * c_dim1] = f14;
                }
            }
        }
      if (ujsec < jsec)
        {
          i4 = jj + jsec - 1;
          for (j = jj + ujsec; j <= i4; ++j)
            {
              i5 = ii + uisec - 1;
              for (i = ii; i <= i5; i += 4)
                {
                  f11 = c[i + j * c_dim1];
                  f21 = c[i + 1 + j * c_dim1];
                  f31 = c[i + 2 + j * c_dim1];
                  f41 = c[i + 3 + j * c_dim1];
                  i6 = ll + lsec - 1;
                  for (l = ll; l <= i6; ++l)
                    {
                      f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                257] * b[l + j * b_dim1];
                      f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
                                257] * b[l + j * b_dim1];
                      f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
                                257] * b[l + j * b_dim1];
                      f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
                                257] * b[l + j * b_dim1];
                    }
                  c[i + j * c_dim1] = f11;
                  c[i + 1 + j * c_dim1] = f21;
                  c[i + 2 + j * c_dim1] = f31;
                  c[i + 3 + j * c_dim1] = f41;
                }
              for (i = ii + uisec; i <= i5; ++i)
                {
                  f11 = c[i + j * c_dim1];
                  i6 = ll + lsec - 1;
                  for (l = ll; l <= i6; ++l)
                    {
                      f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                257] * b[l + j * b_dim1];
                    }
                  c[i + j * c_dim1] = f11;
                }
            }
        }
    }
        }
      }
    }
  free (t1);
  return;
  }
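  /* The remaining branches handle operand layouts the blocked kernel above
     cannot use: A stored with unit stride along its second dimension, the
     case axstride < aystride, a rank-1 A, and finally fully general
     strides.  They are all plain dot-product or triple-nested loops.  */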
  else if (rxstride == 1 && aystride == 1 && bxstride == 1)
    {
      if (GFC_DESCRIPTOR_RANK (a) != 1)
        {
          const GFC_INTEGER_4 *restrict abase_x;
          const GFC_INTEGER_4 *restrict bbase_y;
          GFC_INTEGER_4 *restrict dest_y;
          GFC_INTEGER_4 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              dest_y = &dest[y*rystride];
              for (x = 0; x < xcount; x++)
                {
                  abase_x = &abase[x*axstride];
                  s = (GFC_INTEGER_4) 0;
                  for (n = 0; n < count; n++)
                    s += abase_x[n] * bbase_y[n];
                  dest_y[x] = s;
                }
            }
        }
      else
        {
          const GFC_INTEGER_4 *restrict bbase_y;
          GFC_INTEGER_4 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              s = (GFC_INTEGER_4) 0;
              for (n = 0; n < count; n++)
                s += abase[n*axstride] * bbase_y[n];
              dest[y*rystride] = s;
            }
        }
    }
  else if (axstride < aystride)
    {
      for (y = 0; y < ycount; y++)
        for (x = 0; x < xcount; x++)
          dest[x*rxstride + y*rystride] = (GFC_INTEGER_4)0;

      for (y = 0; y < ycount; y++)
        for (n = 0; n < count; n++)
          for (x = 0; x < xcount; x++)
            /* dest[x,y] += a[x,n] * b[n,y] */
            dest[x*rxstride + y*rystride] +=
              abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
    }
  else if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      const GFC_INTEGER_4 *restrict bbase_y;
      GFC_INTEGER_4 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          s = (GFC_INTEGER_4) 0;
          for (n = 0; n < count; n++)
            s += abase[n*axstride] * bbase_y[n*bxstride];
          dest[y*rxstride] = s;
        }
    }
  else
    {
      const GFC_INTEGER_4 *restrict abase_x;
      const GFC_INTEGER_4 *restrict bbase_y;
      GFC_INTEGER_4 *restrict dest_y;
      GFC_INTEGER_4 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          dest_y = &dest[y*rystride];
          for (x = 0; x < xcount; x++)
            {
              abase_x = &abase[x*axstride];
              s = (GFC_INTEGER_4) 0;
              for (n = 0; n < count; n++)
                s += abase_x[n*aystride] * bbase_y[n*bxstride];
              dest_y[x*rxstride] = s;
            }
        }
    }
}
/* Compiling main function, with selection code for the processor.  */

/* Currently, this is i386 only.  Adjust for other architectures.  */

#include <config/i386/cpuinfo.h>
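/* Note: matmul_i4 below picks a kernel once per process.  The chosen
   function pointer is cached in the static variable matmul_p and re-read
   with a relaxed atomic load; concurrent first calls may each probe the
   CPU, but they all store the same value, so the race is harmless.  */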
void matmul_i4 (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm)
{
  static void (*matmul_p) (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm);

  void (*matmul_fn) (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm);

  matmul_fn = __atomic_load_n (&matmul_p, __ATOMIC_RELAXED);
  if (matmul_fn == NULL)
    {
      matmul_fn = matmul_i4_vanilla;
      if (__cpu_model.__cpu_vendor == VENDOR_INTEL)
        {
          /* Run down the available processors in order of preference.  */
#ifdef HAVE_AVX512F
          if (__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX512F))
            {
              matmul_fn = matmul_i4_avx512f;
              goto store;
            }
#endif  /* HAVE_AVX512F */

#ifdef HAVE_AVX2
          if ((__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX2))
              && (__cpu_model.__cpu_features[0] & (1 << FEATURE_FMA)))
            {
              matmul_fn = matmul_i4_avx2;
              goto store;
            }
#endif

#ifdef HAVE_AVX
          if (__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX))
            {
              matmul_fn = matmul_i4_avx;
              goto store;
            }
#endif  /* HAVE_AVX */
        }
      else if (__cpu_model.__cpu_vendor == VENDOR_AMD)
        {
#if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
          if ((__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX))
              && (__cpu_model.__cpu_features[0] & (1 << FEATURE_FMA)))
            {
              matmul_fn = matmul_i4_avx128_fma3;
              goto store;
            }
#endif
#if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
          if ((__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX))
              && (__cpu_model.__cpu_features[0] & (1 << FEATURE_FMA4)))
            {
              matmul_fn = matmul_i4_avx128_fma4;
              goto store;
            }
#endif
        }
    store:
      __atomic_store_n (&matmul_p, matmul_fn, __ATOMIC_RELAXED);
    }

  (*matmul_fn) (retarray, a, b, try_blas, blas_limit, gemm);
}
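/* The matmul_i4_avx, _avx2 and _avx512f variants selected above share the
   same C body as matmul_i4_vanilla; they differ only in the
   __attribute__((__target__(...))) they are compiled with, which lets the
   compiler vectorize the loops for the corresponding instruction set.  */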
#else  /* Just the vanilla function.  */

void
matmul_i4 (gfc_array_i4 * const restrict retarray,
	gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
	int blas_limit, blas_call gemm)
{
  const GFC_INTEGER_4 * restrict abase;
  const GFC_INTEGER_4 * restrict bbase;
  GFC_INTEGER_4 * restrict dest;

  index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  index_type x, y, n, count, xcount, ycount;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
	  || GFC_DESCRIPTOR_RANK (b) == 2);

/* C[xcount,ycount] = A[xcount, count] * B[count,ycount]

   Either A or B (but not both) can be rank 1:

   o One-dimensional argument A is implicitly treated as a row matrix
     dimensioned [1,count], so xcount=1.

   o One-dimensional argument B is implicitly treated as a column matrix
     dimensioned [count, 1], so ycount=1.
  */
  if (retarray->base_addr == NULL)
    {
      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
        }
      else
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

          GFC_DIMENSION_SET(retarray->dim[1], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1,
                            GFC_DESCRIPTOR_EXTENT(retarray,0));
        }

      retarray->base_addr
        = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_4));
      retarray->offset = 0;
    }
  else if (unlikely (compile_options.bounds_check))
    {
      index_type ret_extent, arg_extent;

      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 1 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);

          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
          if (arg_extent != ret_extent)
            runtime_error ("Array bound mismatch for dimension 2 of "
                           "array (%ld/%ld) ",
                           (long int) ret_extent, (long int) arg_extent);
        }
    }
  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      /* One-dimensional result may be addressed in the code below
         either as a row or a column matrix.  We want both cases to
         work.  */
      rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }

  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Treat it as a row matrix A[1,count].  */
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = 1;

      xcount = 1;
      count = GFC_DESCRIPTOR_EXTENT(a,0);
    }
  else
    {
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = GFC_DESCRIPTOR_STRIDE(a,1);

      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  if (count != GFC_DESCRIPTOR_EXTENT(b,0))
    {
      if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
        runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
                       "in dimension 1: is %ld, should be %ld",
                       (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
    }

  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      /* Treat it as a column matrix B[count,1].  */
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);

      /* bystride should never be used for 1-dimensional b.
         The value is only used for calculation of the
         memory needed for the buffer.  */
      bystride = 256;
      ycount = 1;
    }
  else
    {
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
      bystride = GFC_DESCRIPTOR_STRIDE(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  abase = a->base_addr;
  bbase = b->base_addr;
  dest = retarray->base_addr;

  /* Now that everything is set up, we perform the multiplication
     itself.  */

#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
#define min(a,b) ((a) <= (b) ? (a) : (b))
#define max(a,b) ((a) >= (b) ? (a) : (b))
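  /* The BLAS path below only pays off for large problems: the test
     xcount * ycount * count > blas_limit**3 hands the product to *gemm
     roughly when every dimension is of the order of blas_limit.  For
     example, with blas_limit == 30, a 40x40 by 40x40 product gives
     40*40*40 = 64000 > 27000 and is passed to the BLAS routine, while a
     20x20 case (8000) is kept in this file.  */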
  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
      && (bxstride == 1 || bystride == 1)
      && (((float) xcount) * ((float) ycount) * ((float) count)
          > POW3(blas_limit)))
    {
      const int m = xcount, n = ycount, k = count, ldc = rystride;
      const GFC_INTEGER_4 one = 1, zero = 0;
      const int lda = (axstride == 1) ? aystride : axstride,
                ldb = (bxstride == 1) ? bystride : bxstride;

      if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
        {
          assert (gemm != NULL);
          const char *transa, *transb;
          if (try_blas & 2)
            transa = "C";
          else
            transa = axstride == 1 ? "N" : "T";

          if (try_blas & 4)
            transb = "C";
          else
            transb = bxstride == 1 ? "N" : "T";

          gemm (transa, transb, &m,
                &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
                &ldc, 1, 1);
          return;
        }
    }
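  /* No BLAS handoff: choose an inline strategy from the strides.  The
     fully contiguous case (unit stride in the result, in the first
     dimension of A and in the first dimension of B) is handled by the
     blocked kernel below; all other layouts fall through to the simpler
     loops further down.  */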
  if (rxstride == 1 && axstride == 1 && bxstride == 1)
    {
      /* This block of code implements a tuned matmul, derived from
         Superscalar GEMM-based level 3 BLAS,  Beta version 0.1

               Bo Kagstrom and Per Ling
               Department of Computing Science
               Umea University
               S-901 87 Umea, Sweden

         from netlib.org, translated to C, and modified for matmul.m4.  */
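      /* In outline: columns of C are processed in blocks of 512 (jj), the
         inner product dimension in blocks of 256 (ll), and rows of C in
         blocks of 256 (ii).  For each (ll, ii) block a panel of A is
         copied into the scratch buffer t1, and the block of C is then
         updated with a 4x4 register kernel (f11..f44); the loops after
         the unrolled kernel pick up the leftover rows and columns.  */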
      const GFC_INTEGER_4 *a, *b;
      GFC_INTEGER_4 *c;
      const index_type m = xcount, n = ycount, k = count;

      /* System generated locals */
      index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
                 i1, i2, i3, i4, i5, i6;

      /* Local variables */
      GFC_INTEGER_4 f11, f12, f21, f22, f31, f32, f41, f42,
                 f13, f14, f23, f24, f33, f34, f43, f44;
      index_type i, j, l, ii, jj, ll;
      index_type isec, jsec, lsec, uisec, ujsec, ulsec;
      GFC_INTEGER_4 *t1;

      a = abase;
      b = bbase;
      c = retarray->base_addr;

      /* Parameter adjustments */
      c_dim1 = rystride;
      c_offset = 1 + c_dim1;
      c -= c_offset;
      a_dim1 = aystride;
      a_offset = 1 + a_dim1;
      a -= a_offset;
      b_dim1 = bystride;
      b_offset = 1 + b_dim1;
      b -= b_offset;

      /* Empty c first.  */
      for (j=1; j<=n; j++)
        for (i=1; i<=m; i++)
          c[i + j * c_dim1] = (GFC_INTEGER_4)0;

      /* Early exit if possible */
      if (m == 0 || n == 0 || k == 0)
        return;

      /* Adjust size of t1 to what is needed.  */
      index_type t1_dim, a_sz;
      if (aystride == 1)
        a_sz = rystride;
      else
        a_sz = a_dim1;

      t1_dim = a_sz * 256 + b_dim1;
      if (t1_dim > 65536)
        t1_dim = 65536;

      t1 = malloc (t1_dim * sizeof(GFC_INTEGER_4));
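      /* t1 holds the current panel of A packed in 256-element columns:
         t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] is simply
         (l - ll) + (i - ii) * 256, i.e. element (l - ll, i - ii) of a
         256 x 256 tile, written with the 1-based offsets left over from
         the Fortran original.  */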
      /* Start turning the crank.  */
      i1 = n;
      for (jj = 1; jj <= i1; jj += 512)
        {
          i2 = min(jj + 511, n);
          jsec = i2 - jj + 1;
          ujsec = jsec - jsec % 4;
          i2 = k;
          for (ll = 1; ll <= i2; ll += 256)
            {
              i3 = min(ll + 255, k);
              lsec = i3 - ll + 1;
              ulsec = lsec - lsec % 2;

              i3 = m;
              for (ii = 1; ii <= i3; ii += 256)
                {
                  i4 = min(ii + 255, m);
                  isec = i4 - ii + 1;
                  uisec = isec - isec % 2;
                  i4 = ll + ulsec - 1;
                  for (l = ll; l <= i4; l += 2)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 2)
                        {
                          t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
                                        a[i + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
                                        a[i + (l + 1) * a_dim1];
                          t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
                                        a[i + 1 + (l + 1) * a_dim1];
                        }
                      if (uisec < isec)
                        {
                          t1[l - ll + 1 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + l * a_dim1];
                          t1[l - ll + 2 + (isec << 8) - 257] =
                                    a[ii + isec - 1 + (l + 1) * a_dim1];
                        }
                    }
                  if (ulsec < lsec)
                    {
                      i4 = ii + isec - 1;
                      for (i = ii; i <= i4; ++i)
                        {
                          t1[lsec + ((i - ii + 1) << 8) - 257] =
                                    a[i + (ll + lsec - 1) * a_dim1];
                        }
                    }

                  uisec = isec - isec % 4;
                  i4 = jj + ujsec - 1;
                  for (j = jj; j <= i4; j += 4)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 4)
                        {
                          f11 = c[i + j * c_dim1];
                          f21 = c[i + 1 + j * c_dim1];
                          f12 = c[i + (j + 1) * c_dim1];
                          f22 = c[i + 1 + (j + 1) * c_dim1];
                          f13 = c[i + (j + 2) * c_dim1];
                          f23 = c[i + 1 + (j + 2) * c_dim1];
                          f14 = c[i + (j + 3) * c_dim1];
                          f24 = c[i + 1 + (j + 3) * c_dim1];
                          f31 = c[i + 2 + j * c_dim1];
                          f41 = c[i + 3 + j * c_dim1];
                          f32 = c[i + 2 + (j + 1) * c_dim1];
                          f42 = c[i + 3 + (j + 1) * c_dim1];
                          f33 = c[i + 2 + (j + 2) * c_dim1];
                          f43 = c[i + 3 + (j + 2) * c_dim1];
                          f34 = c[i + 2 + (j + 3) * c_dim1];
                          f44 = c[i + 3 + (j + 3) * c_dim1];
                          i6 = ll + lsec - 1;
                          for (l = ll; l <= i6; ++l)
                            {
                              f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + j * b_dim1];
                              f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 1) * b_dim1];
                              f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 2) * b_dim1];
                              f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                              f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                      * b[l + (j + 3) * b_dim1];
                            }
                          c[i + j * c_dim1] = f11;
                          c[i + 1 + j * c_dim1] = f21;
                          c[i + (j + 1) * c_dim1] = f12;
                          c[i + 1 + (j + 1) * c_dim1] = f22;
                          c[i + (j + 2) * c_dim1] = f13;
                          c[i + 1 + (j + 2) * c_dim1] = f23;
                          c[i + (j + 3) * c_dim1] = f14;
                          c[i + 1 + (j + 3) * c_dim1] = f24;
                          c[i + 2 + j * c_dim1] = f31;
                          c[i + 3 + j * c_dim1] = f41;
                          c[i + 2 + (j + 1) * c_dim1] = f32;
                          c[i + 3 + (j + 1) * c_dim1] = f42;
                          c[i + 2 + (j + 2) * c_dim1] = f33;
                          c[i + 3 + (j + 2) * c_dim1] = f43;
                          c[i + 2 + (j + 3) * c_dim1] = f34;
                          c[i + 3 + (j + 3) * c_dim1] = f44;
                        }
                      if (uisec < isec)
                        {
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              f12 = c[i + (j + 1) * c_dim1];
                              f13 = c[i + (j + 2) * c_dim1];
                              f14 = c[i + (j + 3) * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + (j + 1) * b_dim1];
                                  f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + (j + 2) * b_dim1];
                                  f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + (j + 3) * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + (j + 1) * c_dim1] = f12;
                              c[i + (j + 2) * c_dim1] = f13;
                              c[i + (j + 3) * c_dim1] = f14;
                            }
                        }
                    }
                  if (ujsec < jsec)
                    {
                      i4 = jj + jsec - 1;
                      for (j = jj + ujsec; j <= i4; ++j)
                        {
                          i5 = ii + uisec - 1;
                          for (i = ii; i <= i5; i += 4)
                            {
                              f11 = c[i + j * c_dim1];
                              f21 = c[i + 1 + j * c_dim1];
                              f31 = c[i + 2 + j * c_dim1];
                              f41 = c[i + 3 + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
                                            257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + 1 + j * c_dim1] = f21;
                              c[i + 2 + j * c_dim1] = f31;
                              c[i + 3 + j * c_dim1] = f41;
                            }
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                            }
                        }
                    }
                }
            }
        }
      free (t1);
      return;
    }
  else if (rxstride == 1 && aystride == 1 && bxstride == 1)
    {
      if (GFC_DESCRIPTOR_RANK (a) != 1)
        {
          const GFC_INTEGER_4 *restrict abase_x;
          const GFC_INTEGER_4 *restrict bbase_y;
          GFC_INTEGER_4 *restrict dest_y;
          GFC_INTEGER_4 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              dest_y = &dest[y*rystride];
              for (x = 0; x < xcount; x++)
                {
                  abase_x = &abase[x*axstride];
                  s = (GFC_INTEGER_4) 0;
                  for (n = 0; n < count; n++)
                    s += abase_x[n] * bbase_y[n];
                  dest_y[x] = s;
                }
            }
        }
      else
        {
          const GFC_INTEGER_4 *restrict bbase_y;
          GFC_INTEGER_4 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              s = (GFC_INTEGER_4) 0;
              for (n = 0; n < count; n++)
                s += abase[n*axstride] * bbase_y[n];
              dest[y*rystride] = s;
            }
        }
    }
  else if (axstride < aystride)
    {
      for (y = 0; y < ycount; y++)
        for (x = 0; x < xcount; x++)
          dest[x*rxstride + y*rystride] = (GFC_INTEGER_4)0;

      for (y = 0; y < ycount; y++)
        for (n = 0; n < count; n++)
          for (x = 0; x < xcount; x++)
            /* dest[x,y] += a[x,n] * b[n,y] */
            dest[x*rxstride + y*rystride] +=
              abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
    }
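  /* In the branch just above (axstride < aystride) the x loop is kept
     innermost so that both dest and abase are walked with their smaller
     stride in the inner loop; the b element does not depend on x and is
     simply re-read.  */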
  else if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      const GFC_INTEGER_4 *restrict bbase_y;
      GFC_INTEGER_4 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          s = (GFC_INTEGER_4) 0;
          for (n = 0; n < count; n++)
            s += abase[n*axstride] * bbase_y[n*bxstride];
          dest[y*rxstride] = s;
        }
    }
  else
    {
      const GFC_INTEGER_4 *restrict abase_x;
      const GFC_INTEGER_4 *restrict bbase_y;
      GFC_INTEGER_4 *restrict dest_y;
      GFC_INTEGER_4 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          dest_y = &dest[y*rystride];
          for (x = 0; x < xcount; x++)
            {
              abase_x = &abase[x*axstride];
              s = (GFC_INTEGER_4) 0;
              for (n = 0; n < count; n++)
                s += abase_x[n*aystride] * bbase_y[n*bxstride];
              dest_y[x*rxstride] = s;