libgfortran/generated/matmul_i4.c
1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2018 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
27 #include <string.h>
28 #include <assert.h>
31 #if defined (HAVE_GFC_INTEGER_4)
33 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
34 passed to us by the front-end, in which case we call it for large
35 matrices. */
37 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
38 const int *, const GFC_INTEGER_4 *, const GFC_INTEGER_4 *,
39 const int *, const GFC_INTEGER_4 *, const int *,
40 const GFC_INTEGER_4 *, GFC_INTEGER_4 *, const int *,
41 int, int);
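/* The argument order matches the standard Fortran ?GEMM interface:
   (TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC),
   with the two trailing ints being the hidden character lengths of
   TRANSA and TRANSB that the Fortran calling convention appends.  */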
43 /* The order of loops is different in the case of plain matrix
44 multiplication C=MATMUL(A,B), and in the frequent special case where
45 the argument A is the temporary result of a TRANSPOSE intrinsic:
46 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
47 looking at their strides.
49 The equivalent Fortran pseudo-code is:
51 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
52 IF (.NOT.IS_TRANSPOSED(A)) THEN
53 C = 0
54 DO J=1,N
55 DO K=1,COUNT
56 DO I=1,M
57 C(I,J) = C(I,J)+A(I,K)*B(K,J)
58 ELSE
59 DO J=1,N
60 DO I=1,M
61 S = 0
62 DO K=1,COUNT
63 S = S+A(I,K)*B(K,J)
64 C(I,J) = S
65 ENDIF
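/* Both variants perform the same M*N*COUNT multiply-adds; the first
   accumulates directly into C (which therefore has to be cleared first),
   while the second accumulates each element in the scalar S and stores
   it once.  */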
68 /* If try_blas is set to a nonzero value, then the matmul function will
69 see if there is a way to perform the matrix multiplication by a call
70 to the BLAS gemm function. */
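/* As the code below shows, try_blas is in fact a bit mask: any nonzero
   value permits the BLAS path, bit 2 requests the conjugate-transpose
   form for A (TRANSA = "C") and bit 4 the conjugate-transpose form for
   B (TRANSB = "C").  */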
72 extern void matmul_i4 (gfc_array_i4 * const restrict retarray,
73 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
74 int blas_limit, blas_call gemm);
75 export_proto(matmul_i4);
77 /* Put an exhaustive list of possible architectures here, ORed together. */
79 #if defined(HAVE_AVX) || defined(HAVE_AVX2) || defined(HAVE_AVX512F)
81 #ifdef HAVE_AVX
82 static void
83 matmul_i4_avx (gfc_array_i4 * const restrict retarray,
84 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
85 int blas_limit, blas_call gemm) __attribute__((__target__("avx")));
86 static void
87 matmul_i4_avx (gfc_array_i4 * const restrict retarray,
88 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
89 int blas_limit, blas_call gemm)
91 const GFC_INTEGER_4 * restrict abase;
92 const GFC_INTEGER_4 * restrict bbase;
93 GFC_INTEGER_4 * restrict dest;
95 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
96 index_type x, y, n, count, xcount, ycount;
98 assert (GFC_DESCRIPTOR_RANK (a) == 2
99 || GFC_DESCRIPTOR_RANK (b) == 2);
101 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
103 Either A or B (but not both) can be rank 1:
105 o One-dimensional argument A is implicitly treated as a row matrix
106 dimensioned [1,count], so xcount=1.
108 o One-dimensional argument B is implicitly treated as a column matrix
109 dimensioned [count, 1], so ycount=1.
112 if (retarray->base_addr == NULL)
114 if (GFC_DESCRIPTOR_RANK (a) == 1)
116 GFC_DIMENSION_SET(retarray->dim[0], 0,
117 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
119 else if (GFC_DESCRIPTOR_RANK (b) == 1)
121 GFC_DIMENSION_SET(retarray->dim[0], 0,
122 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
124 else
126 GFC_DIMENSION_SET(retarray->dim[0], 0,
127 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
129 GFC_DIMENSION_SET(retarray->dim[1], 0,
130 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
131 GFC_DESCRIPTOR_EXTENT(retarray,0));
134 retarray->base_addr
135 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_4));
136 retarray->offset = 0;
138 else if (unlikely (compile_options.bounds_check))
140 index_type ret_extent, arg_extent;
142 if (GFC_DESCRIPTOR_RANK (a) == 1)
144 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
145 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
146 if (arg_extent != ret_extent)
147 runtime_error ("Array bound mismatch for dimension 1 of "
148 "array (%ld/%ld) ",
149 (long int) ret_extent, (long int) arg_extent);
151 else if (GFC_DESCRIPTOR_RANK (b) == 1)
153 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
154 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
155 if (arg_extent != ret_extent)
156 runtime_error ("Array bound mismatch for dimension 1 of "
157 "array (%ld/%ld) ",
158 (long int) ret_extent, (long int) arg_extent);
160 else
162 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
163 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
164 if (arg_extent != ret_extent)
165 runtime_error ("Array bound mismatch for dimension 1 of "
166 "array (%ld/%ld) ",
167 (long int) ret_extent, (long int) arg_extent);
169 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
170 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
171 if (arg_extent != ret_extent)
172 runtime_error ("Array bound mismatch for dimension 2 of "
173 "array (%ld/%ld) ",
174 (long int) ret_extent, (long int) arg_extent);
179 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
181 /* One-dimensional result may be addressed in the code below
182 either as a row or a column matrix. We want both cases to
183 work. */
184 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
186 else
188 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
189 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
193 if (GFC_DESCRIPTOR_RANK (a) == 1)
195 /* Treat it as a row matrix A[1,count]. */
196 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
197 aystride = 1;
199 xcount = 1;
200 count = GFC_DESCRIPTOR_EXTENT(a,0);
202 else
204 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
205 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
207 count = GFC_DESCRIPTOR_EXTENT(a,1);
208 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
211 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
213 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
214 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
215 "in dimension 1: is %ld, should be %ld",
216 (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
219 if (GFC_DESCRIPTOR_RANK (b) == 1)
221 /* Treat it as a column matrix B[count,1] */
222 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
224 /* bystride should never be used for 1-dimensional b.
225 The value is only used for calculation of the
226 memory needed by the buffer. */
227 bystride = 256;
228 ycount = 1;
230 else
232 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
233 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
234 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
237 abase = a->base_addr;
238 bbase = b->base_addr;
239 dest = retarray->base_addr;
241 /* Now that everything is set up, we perform the multiplication
242 itself. */
244 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
245 #define min(a,b) ((a) <= (b) ? (a) : (b))
246 #define max(a,b) ((a) >= (b) ? (a) : (b))
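/* Call the external GEMM only when it is likely to pay off: the result
   must have unit row stride, each operand must be contiguous along one
   dimension (handled via TRANSA/TRANSB below), and the operation count
   xcount*ycount*count must exceed blas_limit**3.  */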
248 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
249 && (bxstride == 1 || bystride == 1)
250 && (((float) xcount) * ((float) ycount) * ((float) count)
251 > POW3(blas_limit)))
253 const int m = xcount, n = ycount, k = count, ldc = rystride;
254 const GFC_INTEGER_4 one = 1, zero = 0;
255 const int lda = (axstride == 1) ? aystride : axstride,
256 ldb = (bxstride == 1) ? bystride : bxstride;
258 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
260 assert (gemm != NULL);
261 const char *transa, *transb;
262 if (try_blas & 2)
263 transa = "C";
264 else
265 transa = axstride == 1 ? "N" : "T";
267 if (try_blas & 4)
268 transb = "C";
269 else
270 transb = bxstride == 1 ? "N" : "T";
272 gemm (transa, transb , &m,
273 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
274 &ldc, 1, 1);
275 return;
279 if (rxstride == 1 && axstride == 1 && bxstride == 1)
281 /* This block of code implements a tuned matmul, derived from
282 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
284 Bo Kagstrom and Per Ling
285 Department of Computing Science
286 Umea University
287 S-901 87 Umea, Sweden
289 from netlib.org, translated to C, and modified for matmul.m4. */
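/* The three blocking loops below tile the computation: jj steps through
   the columns of C in blocks of up to 512, ll through the inner-product
   index in blocks of up to 256, and ii through the rows in blocks of up
   to 256.  Each A panel is copied into the t1 scratch buffer so that the
   innermost loops read it with unit stride.  */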
291 const GFC_INTEGER_4 *a, *b;
292 GFC_INTEGER_4 *c;
293 const index_type m = xcount, n = ycount, k = count;
295 /* System generated locals */
296 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
297 i1, i2, i3, i4, i5, i6;
299 /* Local variables */
300 GFC_INTEGER_4 f11, f12, f21, f22, f31, f32, f41, f42,
301 f13, f14, f23, f24, f33, f34, f43, f44;
302 index_type i, j, l, ii, jj, ll;
303 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
304 GFC_INTEGER_4 *t1;
306 a = abase;
307 b = bbase;
308 c = retarray->base_addr;
310 /* Parameter adjustments */
311 c_dim1 = rystride;
312 c_offset = 1 + c_dim1;
313 c -= c_offset;
314 a_dim1 = aystride;
315 a_offset = 1 + a_dim1;
316 a -= a_offset;
317 b_dim1 = bystride;
318 b_offset = 1 + b_dim1;
319 b -= b_offset;
321 /* Empty c first. */
322 for (j=1; j<=n; j++)
323 for (i=1; i<=m; i++)
324 c[i + j * c_dim1] = (GFC_INTEGER_4)0;
326 /* Early exit if possible */
327 if (m == 0 || n == 0 || k == 0)
328 return;
330 /* Adjust size of t1 to what is needed. */
331 index_type t1_dim, a_sz;
332 if (aystride == 1)
333 a_sz = rystride;
334 else
335 a_sz = a_dim1;
337 t1_dim = a_sz * 256 + b_dim1;
338 if (t1_dim > 65536)
339 t1_dim = 65536;
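/* t1 holds one packed panel of A: up to 256 rows of up to 256
   inner-product entries each, i.e. at most 65536 elements.  */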
341 t1 = malloc (t1_dim * sizeof(GFC_INTEGER_4));
343 /* Start turning the crank. */
344 i1 = n;
345 for (jj = 1; jj <= i1; jj += 512)
347 /* Computing MIN */
348 i2 = 512;
349 i3 = n - jj + 1;
350 jsec = min(i2,i3);
351 ujsec = jsec - jsec % 4;
352 i2 = k;
353 for (ll = 1; ll <= i2; ll += 256)
355 /* Computing MIN */
356 i3 = 256;
357 i4 = k - ll + 1;
358 lsec = min(i3,i4);
359 ulsec = lsec - lsec % 2;
361 i3 = m;
362 for (ii = 1; ii <= i3; ii += 256)
364 /* Computing MIN */
365 i4 = 256;
366 i5 = m - ii + 1;
367 isec = min(i4,i5);
368 uisec = isec - isec % 2;
369 i4 = ll + ulsec - 1;
370 for (l = ll; l <= i4; l += 2)
372 i5 = ii + uisec - 1;
373 for (i = ii; i <= i5; i += 2)
375 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
376 a[i + l * a_dim1];
377 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
378 a[i + (l + 1) * a_dim1];
379 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
380 a[i + 1 + l * a_dim1];
381 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
382 a[i + 1 + (l + 1) * a_dim1];
384 if (uisec < isec)
386 t1[l - ll + 1 + (isec << 8) - 257] =
387 a[ii + isec - 1 + l * a_dim1];
388 t1[l - ll + 2 + (isec << 8) - 257] =
389 a[ii + isec - 1 + (l + 1) * a_dim1];
392 if (ulsec < lsec)
394 i4 = ii + isec - 1;
395 for (i = ii; i<= i4; ++i)
397 t1[lsec + ((i - ii + 1) << 8) - 257] =
398 a[i + (ll + lsec - 1) * a_dim1];
402 uisec = isec - isec % 4;
403 i4 = jj + ujsec - 1;
404 for (j = jj; j <= i4; j += 4)
406 i5 = ii + uisec - 1;
407 for (i = ii; i <= i5; i += 4)
409 f11 = c[i + j * c_dim1];
410 f21 = c[i + 1 + j * c_dim1];
411 f12 = c[i + (j + 1) * c_dim1];
412 f22 = c[i + 1 + (j + 1) * c_dim1];
413 f13 = c[i + (j + 2) * c_dim1];
414 f23 = c[i + 1 + (j + 2) * c_dim1];
415 f14 = c[i + (j + 3) * c_dim1];
416 f24 = c[i + 1 + (j + 3) * c_dim1];
417 f31 = c[i + 2 + j * c_dim1];
418 f41 = c[i + 3 + j * c_dim1];
419 f32 = c[i + 2 + (j + 1) * c_dim1];
420 f42 = c[i + 3 + (j + 1) * c_dim1];
421 f33 = c[i + 2 + (j + 2) * c_dim1];
422 f43 = c[i + 3 + (j + 2) * c_dim1];
423 f34 = c[i + 2 + (j + 3) * c_dim1];
424 f44 = c[i + 3 + (j + 3) * c_dim1];
425 i6 = ll + lsec - 1;
426 for (l = ll; l <= i6; ++l)
428 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
429 * b[l + j * b_dim1];
430 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
431 * b[l + j * b_dim1];
432 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
433 * b[l + (j + 1) * b_dim1];
434 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
435 * b[l + (j + 1) * b_dim1];
436 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
437 * b[l + (j + 2) * b_dim1];
438 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
439 * b[l + (j + 2) * b_dim1];
440 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
441 * b[l + (j + 3) * b_dim1];
442 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
443 * b[l + (j + 3) * b_dim1];
444 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
445 * b[l + j * b_dim1];
446 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
447 * b[l + j * b_dim1];
448 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
449 * b[l + (j + 1) * b_dim1];
450 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
451 * b[l + (j + 1) * b_dim1];
452 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
453 * b[l + (j + 2) * b_dim1];
454 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
455 * b[l + (j + 2) * b_dim1];
456 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
457 * b[l + (j + 3) * b_dim1];
458 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
459 * b[l + (j + 3) * b_dim1];
461 c[i + j * c_dim1] = f11;
462 c[i + 1 + j * c_dim1] = f21;
463 c[i + (j + 1) * c_dim1] = f12;
464 c[i + 1 + (j + 1) * c_dim1] = f22;
465 c[i + (j + 2) * c_dim1] = f13;
466 c[i + 1 + (j + 2) * c_dim1] = f23;
467 c[i + (j + 3) * c_dim1] = f14;
468 c[i + 1 + (j + 3) * c_dim1] = f24;
469 c[i + 2 + j * c_dim1] = f31;
470 c[i + 3 + j * c_dim1] = f41;
471 c[i + 2 + (j + 1) * c_dim1] = f32;
472 c[i + 3 + (j + 1) * c_dim1] = f42;
473 c[i + 2 + (j + 2) * c_dim1] = f33;
474 c[i + 3 + (j + 2) * c_dim1] = f43;
475 c[i + 2 + (j + 3) * c_dim1] = f34;
476 c[i + 3 + (j + 3) * c_dim1] = f44;
478 if (uisec < isec)
480 i5 = ii + isec - 1;
481 for (i = ii + uisec; i <= i5; ++i)
483 f11 = c[i + j * c_dim1];
484 f12 = c[i + (j + 1) * c_dim1];
485 f13 = c[i + (j + 2) * c_dim1];
486 f14 = c[i + (j + 3) * c_dim1];
487 i6 = ll + lsec - 1;
488 for (l = ll; l <= i6; ++l)
490 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
491 257] * b[l + j * b_dim1];
492 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
493 257] * b[l + (j + 1) * b_dim1];
494 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
495 257] * b[l + (j + 2) * b_dim1];
496 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
497 257] * b[l + (j + 3) * b_dim1];
499 c[i + j * c_dim1] = f11;
500 c[i + (j + 1) * c_dim1] = f12;
501 c[i + (j + 2) * c_dim1] = f13;
502 c[i + (j + 3) * c_dim1] = f14;
506 if (ujsec < jsec)
508 i4 = jj + jsec - 1;
509 for (j = jj + ujsec; j <= i4; ++j)
511 i5 = ii + uisec - 1;
512 for (i = ii; i <= i5; i += 4)
514 f11 = c[i + j * c_dim1];
515 f21 = c[i + 1 + j * c_dim1];
516 f31 = c[i + 2 + j * c_dim1];
517 f41 = c[i + 3 + j * c_dim1];
518 i6 = ll + lsec - 1;
519 for (l = ll; l <= i6; ++l)
521 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
522 257] * b[l + j * b_dim1];
523 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
524 257] * b[l + j * b_dim1];
525 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
526 257] * b[l + j * b_dim1];
527 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
528 257] * b[l + j * b_dim1];
530 c[i + j * c_dim1] = f11;
531 c[i + 1 + j * c_dim1] = f21;
532 c[i + 2 + j * c_dim1] = f31;
533 c[i + 3 + j * c_dim1] = f41;
535 i5 = ii + isec - 1;
536 for (i = ii + uisec; i <= i5; ++i)
538 f11 = c[i + j * c_dim1];
539 i6 = ll + lsec - 1;
540 for (l = ll; l <= i6; ++l)
542 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
543 257] * b[l + j * b_dim1];
545 c[i + j * c_dim1] = f11;
552 free(t1);
553 return;
555 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
557 if (GFC_DESCRIPTOR_RANK (a) != 1)
559 const GFC_INTEGER_4 *restrict abase_x;
560 const GFC_INTEGER_4 *restrict bbase_y;
561 GFC_INTEGER_4 *restrict dest_y;
562 GFC_INTEGER_4 s;
564 for (y = 0; y < ycount; y++)
566 bbase_y = &bbase[y*bystride];
567 dest_y = &dest[y*rystride];
568 for (x = 0; x < xcount; x++)
570 abase_x = &abase[x*axstride];
571 s = (GFC_INTEGER_4) 0;
572 for (n = 0; n < count; n++)
573 s += abase_x[n] * bbase_y[n];
574 dest_y[x] = s;
578 else
580 const GFC_INTEGER_4 *restrict bbase_y;
581 GFC_INTEGER_4 s;
583 for (y = 0; y < ycount; y++)
585 bbase_y = &bbase[y*bystride];
586 s = (GFC_INTEGER_4) 0;
587 for (n = 0; n < count; n++)
588 s += abase[n*axstride] * bbase_y[n];
589 dest[y*rystride] = s;
593 else if (axstride < aystride)
595 for (y = 0; y < ycount; y++)
596 for (x = 0; x < xcount; x++)
597 dest[x*rxstride + y*rystride] = (GFC_INTEGER_4)0;
599 for (y = 0; y < ycount; y++)
600 for (n = 0; n < count; n++)
601 for (x = 0; x < xcount; x++)
602 /* dest[x,y] += a[x,n] * b[n,y] */
603 dest[x*rxstride + y*rystride] +=
604 abase[x*axstride + n*aystride] *
605 bbase[n*bxstride + y*bystride];
607 else if (GFC_DESCRIPTOR_RANK (a) == 1)
609 const GFC_INTEGER_4 *restrict bbase_y;
610 GFC_INTEGER_4 s;
612 for (y = 0; y < ycount; y++)
614 bbase_y = &bbase[y*bystride];
615 s = (GFC_INTEGER_4) 0;
616 for (n = 0; n < count; n++)
617 s += abase[n*axstride] * bbase_y[n*bxstride];
618 dest[y*rxstride] = s;
621 else
623 const GFC_INTEGER_4 *restrict abase_x;
624 const GFC_INTEGER_4 *restrict bbase_y;
625 GFC_INTEGER_4 *restrict dest_y;
626 GFC_INTEGER_4 s;
628 for (y = 0; y < ycount; y++)
630 bbase_y = &bbase[y*bystride];
631 dest_y = &dest[y*rystride];
632 for (x = 0; x < xcount; x++)
634 abase_x = &abase[x*axstride];
635 s = (GFC_INTEGER_4) 0;
636 for (n = 0; n < count; n++)
637 s += abase_x[n*aystride] * bbase_y[n*bxstride];
638 dest_y[x*rxstride] = s;
643 #undef POW3
644 #undef min
645 #undef max
647 #endif /* HAVE_AVX */
649 #ifdef HAVE_AVX2
650 static void
651 matmul_i4_avx2 (gfc_array_i4 * const restrict retarray,
652 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
653 int blas_limit, blas_call gemm) __attribute__((__target__("avx2,fma")));
654 static void
655 matmul_i4_avx2 (gfc_array_i4 * const restrict retarray,
656 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
657 int blas_limit, blas_call gemm)
659 const GFC_INTEGER_4 * restrict abase;
660 const GFC_INTEGER_4 * restrict bbase;
661 GFC_INTEGER_4 * restrict dest;
663 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
664 index_type x, y, n, count, xcount, ycount;
666 assert (GFC_DESCRIPTOR_RANK (a) == 2
667 || GFC_DESCRIPTOR_RANK (b) == 2);
669 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
671 Either A or B (but not both) can be rank 1:
673 o One-dimensional argument A is implicitly treated as a row matrix
674 dimensioned [1,count], so xcount=1.
676 o One-dimensional argument B is implicitly treated as a column matrix
677 dimensioned [count, 1], so ycount=1.
680 if (retarray->base_addr == NULL)
682 if (GFC_DESCRIPTOR_RANK (a) == 1)
684 GFC_DIMENSION_SET(retarray->dim[0], 0,
685 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
687 else if (GFC_DESCRIPTOR_RANK (b) == 1)
689 GFC_DIMENSION_SET(retarray->dim[0], 0,
690 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
692 else
694 GFC_DIMENSION_SET(retarray->dim[0], 0,
695 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
697 GFC_DIMENSION_SET(retarray->dim[1], 0,
698 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
699 GFC_DESCRIPTOR_EXTENT(retarray,0));
702 retarray->base_addr
703 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_4));
704 retarray->offset = 0;
706 else if (unlikely (compile_options.bounds_check))
708 index_type ret_extent, arg_extent;
710 if (GFC_DESCRIPTOR_RANK (a) == 1)
712 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
713 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
714 if (arg_extent != ret_extent)
715 runtime_error ("Array bound mismatch for dimension 1 of "
716 "array (%ld/%ld) ",
717 (long int) ret_extent, (long int) arg_extent);
719 else if (GFC_DESCRIPTOR_RANK (b) == 1)
721 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
722 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
723 if (arg_extent != ret_extent)
724 runtime_error ("Array bound mismatch for dimension 1 of "
725 "array (%ld/%ld) ",
726 (long int) ret_extent, (long int) arg_extent);
728 else
730 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
731 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
732 if (arg_extent != ret_extent)
733 runtime_error ("Array bound mismatch for dimension 1 of "
734 "array (%ld/%ld) ",
735 (long int) ret_extent, (long int) arg_extent);
737 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
738 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
739 if (arg_extent != ret_extent)
740 runtime_error ("Array bound mismatch for dimension 2 of "
741 "array (%ld/%ld) ",
742 (long int) ret_extent, (long int) arg_extent);
747 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
749 /* One-dimensional result may be addressed in the code below
750 either as a row or a column matrix. We want both cases to
751 work. */
752 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
754 else
756 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
757 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
761 if (GFC_DESCRIPTOR_RANK (a) == 1)
763 /* Treat it as a row matrix A[1,count]. */
764 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
765 aystride = 1;
767 xcount = 1;
768 count = GFC_DESCRIPTOR_EXTENT(a,0);
770 else
772 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
773 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
775 count = GFC_DESCRIPTOR_EXTENT(a,1);
776 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
779 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
781 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
782 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
783 "in dimension 1: is %ld, should be %ld",
784 (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
787 if (GFC_DESCRIPTOR_RANK (b) == 1)
789 /* Treat it as a column matrix B[count,1] */
790 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
792 /* bystride should never be used for 1-dimensional b.
793 The value is only used for calculation of the
794 memory needed by the buffer. */
795 bystride = 256;
796 ycount = 1;
798 else
800 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
801 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
802 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
805 abase = a->base_addr;
806 bbase = b->base_addr;
807 dest = retarray->base_addr;
809 /* Now that everything is set up, we perform the multiplication
810 itself. */
812 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
813 #define min(a,b) ((a) <= (b) ? (a) : (b))
814 #define max(a,b) ((a) >= (b) ? (a) : (b))
816 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
817 && (bxstride == 1 || bystride == 1)
818 && (((float) xcount) * ((float) ycount) * ((float) count)
819 > POW3(blas_limit)))
821 const int m = xcount, n = ycount, k = count, ldc = rystride;
822 const GFC_INTEGER_4 one = 1, zero = 0;
823 const int lda = (axstride == 1) ? aystride : axstride,
824 ldb = (bxstride == 1) ? bystride : bxstride;
826 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
828 assert (gemm != NULL);
829 const char *transa, *transb;
830 if (try_blas & 2)
831 transa = "C";
832 else
833 transa = axstride == 1 ? "N" : "T";
835 if (try_blas & 4)
836 transb = "C";
837 else
838 transb = bxstride == 1 ? "N" : "T";
840 gemm (transa, transb , &m,
841 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
842 &ldc, 1, 1);
843 return;
847 if (rxstride == 1 && axstride == 1 && bxstride == 1)
849 /* This block of code implements a tuned matmul, derived from
850 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
852 Bo Kagstrom and Per Ling
853 Department of Computing Science
854 Umea University
855 S-901 87 Umea, Sweden
857 from netlib.org, translated to C, and modified for matmul.m4. */
859 const GFC_INTEGER_4 *a, *b;
860 GFC_INTEGER_4 *c;
861 const index_type m = xcount, n = ycount, k = count;
863 /* System generated locals */
864 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
865 i1, i2, i3, i4, i5, i6;
867 /* Local variables */
868 GFC_INTEGER_4 f11, f12, f21, f22, f31, f32, f41, f42,
869 f13, f14, f23, f24, f33, f34, f43, f44;
870 index_type i, j, l, ii, jj, ll;
871 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
872 GFC_INTEGER_4 *t1;
874 a = abase;
875 b = bbase;
876 c = retarray->base_addr;
878 /* Parameter adjustments */
879 c_dim1 = rystride;
880 c_offset = 1 + c_dim1;
881 c -= c_offset;
882 a_dim1 = aystride;
883 a_offset = 1 + a_dim1;
884 a -= a_offset;
885 b_dim1 = bystride;
886 b_offset = 1 + b_dim1;
887 b -= b_offset;
889 /* Empty c first. */
890 for (j=1; j<=n; j++)
891 for (i=1; i<=m; i++)
892 c[i + j * c_dim1] = (GFC_INTEGER_4)0;
894 /* Early exit if possible */
895 if (m == 0 || n == 0 || k == 0)
896 return;
898 /* Adjust size of t1 to what is needed. */
899 index_type t1_dim, a_sz;
900 if (aystride == 1)
901 a_sz = rystride;
902 else
903 a_sz = a_dim1;
905 t1_dim = a_sz * 256 + b_dim1;
906 if (t1_dim > 65536)
907 t1_dim = 65536;
909 t1 = malloc (t1_dim * sizeof(GFC_INTEGER_4));
911 /* Start turning the crank. */
912 i1 = n;
913 for (jj = 1; jj <= i1; jj += 512)
915 /* Computing MIN */
916 i2 = 512;
917 i3 = n - jj + 1;
918 jsec = min(i2,i3);
919 ujsec = jsec - jsec % 4;
920 i2 = k;
921 for (ll = 1; ll <= i2; ll += 256)
923 /* Computing MIN */
924 i3 = 256;
925 i4 = k - ll + 1;
926 lsec = min(i3,i4);
927 ulsec = lsec - lsec % 2;
929 i3 = m;
930 for (ii = 1; ii <= i3; ii += 256)
932 /* Computing MIN */
933 i4 = 256;
934 i5 = m - ii + 1;
935 isec = min(i4,i5);
936 uisec = isec - isec % 2;
937 i4 = ll + ulsec - 1;
938 for (l = ll; l <= i4; l += 2)
940 i5 = ii + uisec - 1;
941 for (i = ii; i <= i5; i += 2)
943 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
944 a[i + l * a_dim1];
945 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
946 a[i + (l + 1) * a_dim1];
947 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
948 a[i + 1 + l * a_dim1];
949 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
950 a[i + 1 + (l + 1) * a_dim1];
952 if (uisec < isec)
954 t1[l - ll + 1 + (isec << 8) - 257] =
955 a[ii + isec - 1 + l * a_dim1];
956 t1[l - ll + 2 + (isec << 8) - 257] =
957 a[ii + isec - 1 + (l + 1) * a_dim1];
960 if (ulsec < lsec)
962 i4 = ii + isec - 1;
963 for (i = ii; i<= i4; ++i)
965 t1[lsec + ((i - ii + 1) << 8) - 257] =
966 a[i + (ll + lsec - 1) * a_dim1];
970 uisec = isec - isec % 4;
971 i4 = jj + ujsec - 1;
972 for (j = jj; j <= i4; j += 4)
974 i5 = ii + uisec - 1;
975 for (i = ii; i <= i5; i += 4)
977 f11 = c[i + j * c_dim1];
978 f21 = c[i + 1 + j * c_dim1];
979 f12 = c[i + (j + 1) * c_dim1];
980 f22 = c[i + 1 + (j + 1) * c_dim1];
981 f13 = c[i + (j + 2) * c_dim1];
982 f23 = c[i + 1 + (j + 2) * c_dim1];
983 f14 = c[i + (j + 3) * c_dim1];
984 f24 = c[i + 1 + (j + 3) * c_dim1];
985 f31 = c[i + 2 + j * c_dim1];
986 f41 = c[i + 3 + j * c_dim1];
987 f32 = c[i + 2 + (j + 1) * c_dim1];
988 f42 = c[i + 3 + (j + 1) * c_dim1];
989 f33 = c[i + 2 + (j + 2) * c_dim1];
990 f43 = c[i + 3 + (j + 2) * c_dim1];
991 f34 = c[i + 2 + (j + 3) * c_dim1];
992 f44 = c[i + 3 + (j + 3) * c_dim1];
993 i6 = ll + lsec - 1;
994 for (l = ll; l <= i6; ++l)
996 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
997 * b[l + j * b_dim1];
998 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
999 * b[l + j * b_dim1];
1000 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1001 * b[l + (j + 1) * b_dim1];
1002 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1003 * b[l + (j + 1) * b_dim1];
1004 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1005 * b[l + (j + 2) * b_dim1];
1006 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1007 * b[l + (j + 2) * b_dim1];
1008 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1009 * b[l + (j + 3) * b_dim1];
1010 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1011 * b[l + (j + 3) * b_dim1];
1012 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1013 * b[l + j * b_dim1];
1014 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1015 * b[l + j * b_dim1];
1016 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1017 * b[l + (j + 1) * b_dim1];
1018 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1019 * b[l + (j + 1) * b_dim1];
1020 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1021 * b[l + (j + 2) * b_dim1];
1022 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1023 * b[l + (j + 2) * b_dim1];
1024 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1025 * b[l + (j + 3) * b_dim1];
1026 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1027 * b[l + (j + 3) * b_dim1];
1029 c[i + j * c_dim1] = f11;
1030 c[i + 1 + j * c_dim1] = f21;
1031 c[i + (j + 1) * c_dim1] = f12;
1032 c[i + 1 + (j + 1) * c_dim1] = f22;
1033 c[i + (j + 2) * c_dim1] = f13;
1034 c[i + 1 + (j + 2) * c_dim1] = f23;
1035 c[i + (j + 3) * c_dim1] = f14;
1036 c[i + 1 + (j + 3) * c_dim1] = f24;
1037 c[i + 2 + j * c_dim1] = f31;
1038 c[i + 3 + j * c_dim1] = f41;
1039 c[i + 2 + (j + 1) * c_dim1] = f32;
1040 c[i + 3 + (j + 1) * c_dim1] = f42;
1041 c[i + 2 + (j + 2) * c_dim1] = f33;
1042 c[i + 3 + (j + 2) * c_dim1] = f43;
1043 c[i + 2 + (j + 3) * c_dim1] = f34;
1044 c[i + 3 + (j + 3) * c_dim1] = f44;
1046 if (uisec < isec)
1048 i5 = ii + isec - 1;
1049 for (i = ii + uisec; i <= i5; ++i)
1051 f11 = c[i + j * c_dim1];
1052 f12 = c[i + (j + 1) * c_dim1];
1053 f13 = c[i + (j + 2) * c_dim1];
1054 f14 = c[i + (j + 3) * c_dim1];
1055 i6 = ll + lsec - 1;
1056 for (l = ll; l <= i6; ++l)
1058 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1059 257] * b[l + j * b_dim1];
1060 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1061 257] * b[l + (j + 1) * b_dim1];
1062 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1063 257] * b[l + (j + 2) * b_dim1];
1064 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1065 257] * b[l + (j + 3) * b_dim1];
1067 c[i + j * c_dim1] = f11;
1068 c[i + (j + 1) * c_dim1] = f12;
1069 c[i + (j + 2) * c_dim1] = f13;
1070 c[i + (j + 3) * c_dim1] = f14;
1074 if (ujsec < jsec)
1076 i4 = jj + jsec - 1;
1077 for (j = jj + ujsec; j <= i4; ++j)
1079 i5 = ii + uisec - 1;
1080 for (i = ii; i <= i5; i += 4)
1082 f11 = c[i + j * c_dim1];
1083 f21 = c[i + 1 + j * c_dim1];
1084 f31 = c[i + 2 + j * c_dim1];
1085 f41 = c[i + 3 + j * c_dim1];
1086 i6 = ll + lsec - 1;
1087 for (l = ll; l <= i6; ++l)
1089 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1090 257] * b[l + j * b_dim1];
1091 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
1092 257] * b[l + j * b_dim1];
1093 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
1094 257] * b[l + j * b_dim1];
1095 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
1096 257] * b[l + j * b_dim1];
1098 c[i + j * c_dim1] = f11;
1099 c[i + 1 + j * c_dim1] = f21;
1100 c[i + 2 + j * c_dim1] = f31;
1101 c[i + 3 + j * c_dim1] = f41;
1103 i5 = ii + isec - 1;
1104 for (i = ii + uisec; i <= i5; ++i)
1106 f11 = c[i + j * c_dim1];
1107 i6 = ll + lsec - 1;
1108 for (l = ll; l <= i6; ++l)
1110 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1111 257] * b[l + j * b_dim1];
1113 c[i + j * c_dim1] = f11;
1120 free(t1);
1121 return;
1123 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
1125 if (GFC_DESCRIPTOR_RANK (a) != 1)
1127 const GFC_INTEGER_4 *restrict abase_x;
1128 const GFC_INTEGER_4 *restrict bbase_y;
1129 GFC_INTEGER_4 *restrict dest_y;
1130 GFC_INTEGER_4 s;
1132 for (y = 0; y < ycount; y++)
1134 bbase_y = &bbase[y*bystride];
1135 dest_y = &dest[y*rystride];
1136 for (x = 0; x < xcount; x++)
1138 abase_x = &abase[x*axstride];
1139 s = (GFC_INTEGER_4) 0;
1140 for (n = 0; n < count; n++)
1141 s += abase_x[n] * bbase_y[n];
1142 dest_y[x] = s;
1146 else
1148 const GFC_INTEGER_4 *restrict bbase_y;
1149 GFC_INTEGER_4 s;
1151 for (y = 0; y < ycount; y++)
1153 bbase_y = &bbase[y*bystride];
1154 s = (GFC_INTEGER_4) 0;
1155 for (n = 0; n < count; n++)
1156 s += abase[n*axstride] * bbase_y[n];
1157 dest[y*rystride] = s;
1161 else if (axstride < aystride)
1163 for (y = 0; y < ycount; y++)
1164 for (x = 0; x < xcount; x++)
1165 dest[x*rxstride + y*rystride] = (GFC_INTEGER_4)0;
1167 for (y = 0; y < ycount; y++)
1168 for (n = 0; n < count; n++)
1169 for (x = 0; x < xcount; x++)
1170 /* dest[x,y] += a[x,n] * b[n,y] */
1171 dest[x*rxstride + y*rystride] +=
1172 abase[x*axstride + n*aystride] *
1173 bbase[n*bxstride + y*bystride];
1175 else if (GFC_DESCRIPTOR_RANK (a) == 1)
1177 const GFC_INTEGER_4 *restrict bbase_y;
1178 GFC_INTEGER_4 s;
1180 for (y = 0; y < ycount; y++)
1182 bbase_y = &bbase[y*bystride];
1183 s = (GFC_INTEGER_4) 0;
1184 for (n = 0; n < count; n++)
1185 s += abase[n*axstride] * bbase_y[n*bxstride];
1186 dest[y*rxstride] = s;
1189 else
1191 const GFC_INTEGER_4 *restrict abase_x;
1192 const GFC_INTEGER_4 *restrict bbase_y;
1193 GFC_INTEGER_4 *restrict dest_y;
1194 GFC_INTEGER_4 s;
1196 for (y = 0; y < ycount; y++)
1198 bbase_y = &bbase[y*bystride];
1199 dest_y = &dest[y*rystride];
1200 for (x = 0; x < xcount; x++)
1202 abase_x = &abase[x*axstride];
1203 s = (GFC_INTEGER_4) 0;
1204 for (n = 0; n < count; n++)
1205 s += abase_x[n*aystride] * bbase_y[n*bxstride];
1206 dest_y[x*rxstride] = s;
1211 #undef POW3
1212 #undef min
1213 #undef max
1215 #endif /* HAVE_AVX2 */
1217 #ifdef HAVE_AVX512F
1218 static void
1219 matmul_i4_avx512f (gfc_array_i4 * const restrict retarray,
1220 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
1221 int blas_limit, blas_call gemm) __attribute__((__target__("avx512f")));
1222 static void
1223 matmul_i4_avx512f (gfc_array_i4 * const restrict retarray,
1224 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
1225 int blas_limit, blas_call gemm)
1227 const GFC_INTEGER_4 * restrict abase;
1228 const GFC_INTEGER_4 * restrict bbase;
1229 GFC_INTEGER_4 * restrict dest;
1231 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
1232 index_type x, y, n, count, xcount, ycount;
1234 assert (GFC_DESCRIPTOR_RANK (a) == 2
1235 || GFC_DESCRIPTOR_RANK (b) == 2);
1237 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
1239 Either A or B (but not both) can be rank 1:
1241 o One-dimensional argument A is implicitly treated as a row matrix
1242 dimensioned [1,count], so xcount=1.
1244 o One-dimensional argument B is implicitly treated as a column matrix
1245 dimensioned [count, 1], so ycount=1.
1248 if (retarray->base_addr == NULL)
1250 if (GFC_DESCRIPTOR_RANK (a) == 1)
1252 GFC_DIMENSION_SET(retarray->dim[0], 0,
1253 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
1255 else if (GFC_DESCRIPTOR_RANK (b) == 1)
1257 GFC_DIMENSION_SET(retarray->dim[0], 0,
1258 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
1260 else
1262 GFC_DIMENSION_SET(retarray->dim[0], 0,
1263 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
1265 GFC_DIMENSION_SET(retarray->dim[1], 0,
1266 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
1267 GFC_DESCRIPTOR_EXTENT(retarray,0));
1270 retarray->base_addr
1271 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_4));
1272 retarray->offset = 0;
1274 else if (unlikely (compile_options.bounds_check))
1276 index_type ret_extent, arg_extent;
1278 if (GFC_DESCRIPTOR_RANK (a) == 1)
1280 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
1281 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1282 if (arg_extent != ret_extent)
1283 runtime_error ("Array bound mismatch for dimension 1 of "
1284 "array (%ld/%ld) ",
1285 (long int) ret_extent, (long int) arg_extent);
1287 else if (GFC_DESCRIPTOR_RANK (b) == 1)
1289 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
1290 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1291 if (arg_extent != ret_extent)
1292 runtime_error ("Array bound mismatch for dimension 1 of "
1293 "array (%ld/%ld) ",
1294 (long int) ret_extent, (long int) arg_extent);
1296 else
1298 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
1299 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1300 if (arg_extent != ret_extent)
1301 runtime_error ("Array bound mismatch for dimension 1 of "
1302 "array (%ld/%ld) ",
1303 (long int) ret_extent, (long int) arg_extent);
1305 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
1306 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
1307 if (arg_extent != ret_extent)
1308 runtime_error ("Array bound mismatch for dimension 2 of "
1309 "array (%ld/%ld) ",
1310 (long int) ret_extent, (long int) arg_extent);
1315 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
1317 /* One-dimensional result may be addressed in the code below
1318 either as a row or a column matrix. We want both cases to
1319 work. */
1320 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
1322 else
1324 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
1325 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
1329 if (GFC_DESCRIPTOR_RANK (a) == 1)
1331 /* Treat it as a row matrix A[1,count]. */
1332 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
1333 aystride = 1;
1335 xcount = 1;
1336 count = GFC_DESCRIPTOR_EXTENT(a,0);
1338 else
1340 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
1341 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
1343 count = GFC_DESCRIPTOR_EXTENT(a,1);
1344 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
1347 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
1349 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
1350 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
1351 "in dimension 1: is %ld, should be %ld",
1352 (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
1355 if (GFC_DESCRIPTOR_RANK (b) == 1)
1357 /* Treat it as a column matrix B[count,1] */
1358 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
1360 /* bystride should never be used for 1-dimensional b.
1361 The value is only used for calculation of the
1362 memory needed by the buffer. */
1363 bystride = 256;
1364 ycount = 1;
1366 else
1368 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
1369 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
1370 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
1373 abase = a->base_addr;
1374 bbase = b->base_addr;
1375 dest = retarray->base_addr;
1377 /* Now that everything is set up, we perform the multiplication
1378 itself. */
1380 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
1381 #define min(a,b) ((a) <= (b) ? (a) : (b))
1382 #define max(a,b) ((a) >= (b) ? (a) : (b))
1384 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
1385 && (bxstride == 1 || bystride == 1)
1386 && (((float) xcount) * ((float) ycount) * ((float) count)
1387 > POW3(blas_limit)))
1389 const int m = xcount, n = ycount, k = count, ldc = rystride;
1390 const GFC_INTEGER_4 one = 1, zero = 0;
1391 const int lda = (axstride == 1) ? aystride : axstride,
1392 ldb = (bxstride == 1) ? bystride : bxstride;
1394 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
1396 assert (gemm != NULL);
1397 const char *transa, *transb;
1398 if (try_blas & 2)
1399 transa = "C";
1400 else
1401 transa = axstride == 1 ? "N" : "T";
1403 if (try_blas & 4)
1404 transb = "C";
1405 else
1406 transb = bxstride == 1 ? "N" : "T";
1408 gemm (transa, transb , &m,
1409 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
1410 &ldc, 1, 1);
1411 return;
1415 if (rxstride == 1 && axstride == 1 && bxstride == 1)
1417 /* This block of code implements a tuned matmul, derived from
1418 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
1420 Bo Kagstrom and Per Ling
1421 Department of Computing Science
1422 Umea University
1423 S-901 87 Umea, Sweden
1425 from netlib.org, translated to C, and modified for matmul.m4. */
1427 const GFC_INTEGER_4 *a, *b;
1428 GFC_INTEGER_4 *c;
1429 const index_type m = xcount, n = ycount, k = count;
1431 /* System generated locals */
1432 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
1433 i1, i2, i3, i4, i5, i6;
1435 /* Local variables */
1436 GFC_INTEGER_4 f11, f12, f21, f22, f31, f32, f41, f42,
1437 f13, f14, f23, f24, f33, f34, f43, f44;
1438 index_type i, j, l, ii, jj, ll;
1439 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
1440 GFC_INTEGER_4 *t1;
1442 a = abase;
1443 b = bbase;
1444 c = retarray->base_addr;
1446 /* Parameter adjustments */
1447 c_dim1 = rystride;
1448 c_offset = 1 + c_dim1;
1449 c -= c_offset;
1450 a_dim1 = aystride;
1451 a_offset = 1 + a_dim1;
1452 a -= a_offset;
1453 b_dim1 = bystride;
1454 b_offset = 1 + b_dim1;
1455 b -= b_offset;
1457 /* Empty c first. */
1458 for (j=1; j<=n; j++)
1459 for (i=1; i<=m; i++)
1460 c[i + j * c_dim1] = (GFC_INTEGER_4)0;
1462 /* Early exit if possible */
1463 if (m == 0 || n == 0 || k == 0)
1464 return;
1466 /* Adjust size of t1 to what is needed. */
1467 index_type t1_dim, a_sz;
1468 if (aystride == 1)
1469 a_sz = rystride;
1470 else
1471 a_sz = a_dim1;
1473 t1_dim = a_sz * 256 + b_dim1;
1474 if (t1_dim > 65536)
1475 t1_dim = 65536;
1477 t1 = malloc (t1_dim * sizeof(GFC_INTEGER_4));
1479 /* Start turning the crank. */
1480 i1 = n;
1481 for (jj = 1; jj <= i1; jj += 512)
1483 /* Computing MIN */
1484 i2 = 512;
1485 i3 = n - jj + 1;
1486 jsec = min(i2,i3);
1487 ujsec = jsec - jsec % 4;
1488 i2 = k;
1489 for (ll = 1; ll <= i2; ll += 256)
1491 /* Computing MIN */
1492 i3 = 256;
1493 i4 = k - ll + 1;
1494 lsec = min(i3,i4);
1495 ulsec = lsec - lsec % 2;
1497 i3 = m;
1498 for (ii = 1; ii <= i3; ii += 256)
1500 /* Computing MIN */
1501 i4 = 256;
1502 i5 = m - ii + 1;
1503 isec = min(i4,i5);
1504 uisec = isec - isec % 2;
1505 i4 = ll + ulsec - 1;
1506 for (l = ll; l <= i4; l += 2)
1508 i5 = ii + uisec - 1;
1509 for (i = ii; i <= i5; i += 2)
1511 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
1512 a[i + l * a_dim1];
1513 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
1514 a[i + (l + 1) * a_dim1];
1515 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
1516 a[i + 1 + l * a_dim1];
1517 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
1518 a[i + 1 + (l + 1) * a_dim1];
1520 if (uisec < isec)
1522 t1[l - ll + 1 + (isec << 8) - 257] =
1523 a[ii + isec - 1 + l * a_dim1];
1524 t1[l - ll + 2 + (isec << 8) - 257] =
1525 a[ii + isec - 1 + (l + 1) * a_dim1];
1528 if (ulsec < lsec)
1530 i4 = ii + isec - 1;
1531 for (i = ii; i<= i4; ++i)
1533 t1[lsec + ((i - ii + 1) << 8) - 257] =
1534 a[i + (ll + lsec - 1) * a_dim1];
1538 uisec = isec - isec % 4;
1539 i4 = jj + ujsec - 1;
1540 for (j = jj; j <= i4; j += 4)
1542 i5 = ii + uisec - 1;
1543 for (i = ii; i <= i5; i += 4)
1545 f11 = c[i + j * c_dim1];
1546 f21 = c[i + 1 + j * c_dim1];
1547 f12 = c[i + (j + 1) * c_dim1];
1548 f22 = c[i + 1 + (j + 1) * c_dim1];
1549 f13 = c[i + (j + 2) * c_dim1];
1550 f23 = c[i + 1 + (j + 2) * c_dim1];
1551 f14 = c[i + (j + 3) * c_dim1];
1552 f24 = c[i + 1 + (j + 3) * c_dim1];
1553 f31 = c[i + 2 + j * c_dim1];
1554 f41 = c[i + 3 + j * c_dim1];
1555 f32 = c[i + 2 + (j + 1) * c_dim1];
1556 f42 = c[i + 3 + (j + 1) * c_dim1];
1557 f33 = c[i + 2 + (j + 2) * c_dim1];
1558 f43 = c[i + 3 + (j + 2) * c_dim1];
1559 f34 = c[i + 2 + (j + 3) * c_dim1];
1560 f44 = c[i + 3 + (j + 3) * c_dim1];
1561 i6 = ll + lsec - 1;
1562 for (l = ll; l <= i6; ++l)
1564 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1565 * b[l + j * b_dim1];
1566 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1567 * b[l + j * b_dim1];
1568 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1569 * b[l + (j + 1) * b_dim1];
1570 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1571 * b[l + (j + 1) * b_dim1];
1572 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1573 * b[l + (j + 2) * b_dim1];
1574 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1575 * b[l + (j + 2) * b_dim1];
1576 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1577 * b[l + (j + 3) * b_dim1];
1578 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1579 * b[l + (j + 3) * b_dim1];
1580 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1581 * b[l + j * b_dim1];
1582 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1583 * b[l + j * b_dim1];
1584 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1585 * b[l + (j + 1) * b_dim1];
1586 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1587 * b[l + (j + 1) * b_dim1];
1588 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1589 * b[l + (j + 2) * b_dim1];
1590 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1591 * b[l + (j + 2) * b_dim1];
1592 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1593 * b[l + (j + 3) * b_dim1];
1594 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1595 * b[l + (j + 3) * b_dim1];
1597 c[i + j * c_dim1] = f11;
1598 c[i + 1 + j * c_dim1] = f21;
1599 c[i + (j + 1) * c_dim1] = f12;
1600 c[i + 1 + (j + 1) * c_dim1] = f22;
1601 c[i + (j + 2) * c_dim1] = f13;
1602 c[i + 1 + (j + 2) * c_dim1] = f23;
1603 c[i + (j + 3) * c_dim1] = f14;
1604 c[i + 1 + (j + 3) * c_dim1] = f24;
1605 c[i + 2 + j * c_dim1] = f31;
1606 c[i + 3 + j * c_dim1] = f41;
1607 c[i + 2 + (j + 1) * c_dim1] = f32;
1608 c[i + 3 + (j + 1) * c_dim1] = f42;
1609 c[i + 2 + (j + 2) * c_dim1] = f33;
1610 c[i + 3 + (j + 2) * c_dim1] = f43;
1611 c[i + 2 + (j + 3) * c_dim1] = f34;
1612 c[i + 3 + (j + 3) * c_dim1] = f44;
1614 if (uisec < isec)
1616 i5 = ii + isec - 1;
1617 for (i = ii + uisec; i <= i5; ++i)
1619 f11 = c[i + j * c_dim1];
1620 f12 = c[i + (j + 1) * c_dim1];
1621 f13 = c[i + (j + 2) * c_dim1];
1622 f14 = c[i + (j + 3) * c_dim1];
1623 i6 = ll + lsec - 1;
1624 for (l = ll; l <= i6; ++l)
1626 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1627 257] * b[l + j * b_dim1];
1628 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1629 257] * b[l + (j + 1) * b_dim1];
1630 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1631 257] * b[l + (j + 2) * b_dim1];
1632 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1633 257] * b[l + (j + 3) * b_dim1];
1635 c[i + j * c_dim1] = f11;
1636 c[i + (j + 1) * c_dim1] = f12;
1637 c[i + (j + 2) * c_dim1] = f13;
1638 c[i + (j + 3) * c_dim1] = f14;
1642 if (ujsec < jsec)
1644 i4 = jj + jsec - 1;
1645 for (j = jj + ujsec; j <= i4; ++j)
1647 i5 = ii + uisec - 1;
1648 for (i = ii; i <= i5; i += 4)
1650 f11 = c[i + j * c_dim1];
1651 f21 = c[i + 1 + j * c_dim1];
1652 f31 = c[i + 2 + j * c_dim1];
1653 f41 = c[i + 3 + j * c_dim1];
1654 i6 = ll + lsec - 1;
1655 for (l = ll; l <= i6; ++l)
1657 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1658 257] * b[l + j * b_dim1];
1659 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
1660 257] * b[l + j * b_dim1];
1661 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
1662 257] * b[l + j * b_dim1];
1663 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
1664 257] * b[l + j * b_dim1];
1666 c[i + j * c_dim1] = f11;
1667 c[i + 1 + j * c_dim1] = f21;
1668 c[i + 2 + j * c_dim1] = f31;
1669 c[i + 3 + j * c_dim1] = f41;
1671 i5 = ii + isec - 1;
1672 for (i = ii + uisec; i <= i5; ++i)
1674 f11 = c[i + j * c_dim1];
1675 i6 = ll + lsec - 1;
1676 for (l = ll; l <= i6; ++l)
1678 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1679 257] * b[l + j * b_dim1];
1681 c[i + j * c_dim1] = f11;
1688 free(t1);
1689 return;
1691 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
1693 if (GFC_DESCRIPTOR_RANK (a) != 1)
1695 const GFC_INTEGER_4 *restrict abase_x;
1696 const GFC_INTEGER_4 *restrict bbase_y;
1697 GFC_INTEGER_4 *restrict dest_y;
1698 GFC_INTEGER_4 s;
1700 for (y = 0; y < ycount; y++)
1702 bbase_y = &bbase[y*bystride];
1703 dest_y = &dest[y*rystride];
1704 for (x = 0; x < xcount; x++)
1706 abase_x = &abase[x*axstride];
1707 s = (GFC_INTEGER_4) 0;
1708 for (n = 0; n < count; n++)
1709 s += abase_x[n] * bbase_y[n];
1710 dest_y[x] = s;
1714 else
1716 const GFC_INTEGER_4 *restrict bbase_y;
1717 GFC_INTEGER_4 s;
1719 for (y = 0; y < ycount; y++)
1721 bbase_y = &bbase[y*bystride];
1722 s = (GFC_INTEGER_4) 0;
1723 for (n = 0; n < count; n++)
1724 s += abase[n*axstride] * bbase_y[n];
1725 dest[y*rystride] = s;
1729 else if (axstride < aystride)
1731 for (y = 0; y < ycount; y++)
1732 for (x = 0; x < xcount; x++)
1733 dest[x*rxstride + y*rystride] = (GFC_INTEGER_4)0;
1735 for (y = 0; y < ycount; y++)
1736 for (n = 0; n < count; n++)
1737 for (x = 0; x < xcount; x++)
1738 /* dest[x,y] += a[x,n] * b[n,y] */
1739 dest[x*rxstride + y*rystride] +=
1740 abase[x*axstride + n*aystride] *
1741 bbase[n*bxstride + y*bystride];
1743 else if (GFC_DESCRIPTOR_RANK (a) == 1)
1745 const GFC_INTEGER_4 *restrict bbase_y;
1746 GFC_INTEGER_4 s;
1748 for (y = 0; y < ycount; y++)
1750 bbase_y = &bbase[y*bystride];
1751 s = (GFC_INTEGER_4) 0;
1752 for (n = 0; n < count; n++)
1753 s += abase[n*axstride] * bbase_y[n*bxstride];
1754 dest[y*rxstride] = s;
1757 else
1759 const GFC_INTEGER_4 *restrict abase_x;
1760 const GFC_INTEGER_4 *restrict bbase_y;
1761 GFC_INTEGER_4 *restrict dest_y;
1762 GFC_INTEGER_4 s;
1764 for (y = 0; y < ycount; y++)
1766 bbase_y = &bbase[y*bystride];
1767 dest_y = &dest[y*rystride];
1768 for (x = 0; x < xcount; x++)
1770 abase_x = &abase[x*axstride];
1771 s = (GFC_INTEGER_4) 0;
1772 for (n = 0; n < count; n++)
1773 s += abase_x[n*aystride] * bbase_y[n*bxstride];
1774 dest_y[x*rxstride] = s;
1779 #undef POW3
1780 #undef min
1781 #undef max
1783 #endif /* HAVE_AVX512F */
1785 /* AMD-specific functions with AVX128 and FMA3/FMA4. */
1787 #if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
1788 void
1789 matmul_i4_avx128_fma3 (gfc_array_i4 * const restrict retarray,
1790 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
1791 int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma")));
1792 internal_proto(matmul_i4_avx128_fma3);
1793 #endif
1795 #if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
1796 void
1797 matmul_i4_avx128_fma4 (gfc_array_i4 * const restrict retarray,
1798 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
1799 int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma4")));
1800 internal_proto(matmul_i4_avx128_fma4);
1801 #endif
1803 /* Function to fall back to if there is no special processor-specific version. */
1804 static void
1805 matmul_i4_vanilla (gfc_array_i4 * const restrict retarray,
1806 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
1807 int blas_limit, blas_call gemm)
1809 const GFC_INTEGER_4 * restrict abase;
1810 const GFC_INTEGER_4 * restrict bbase;
1811 GFC_INTEGER_4 * restrict dest;
1813 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
1814 index_type x, y, n, count, xcount, ycount;
1816 assert (GFC_DESCRIPTOR_RANK (a) == 2
1817 || GFC_DESCRIPTOR_RANK (b) == 2);
1819 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
1821 Either A or B (but not both) can be rank 1:
1823 o One-dimensional argument A is implicitly treated as a row matrix
1824 dimensioned [1,count], so xcount=1.
1826 o One-dimensional argument B is implicitly treated as a column matrix
1827 dimensioned [count, 1], so ycount=1.
1830 if (retarray->base_addr == NULL)
1832 if (GFC_DESCRIPTOR_RANK (a) == 1)
1834 GFC_DIMENSION_SET(retarray->dim[0], 0,
1835 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
1837 else if (GFC_DESCRIPTOR_RANK (b) == 1)
1839 GFC_DIMENSION_SET(retarray->dim[0], 0,
1840 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
1842 else
1844 GFC_DIMENSION_SET(retarray->dim[0], 0,
1845 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
1847 GFC_DIMENSION_SET(retarray->dim[1], 0,
1848 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
1849 GFC_DESCRIPTOR_EXTENT(retarray,0));
1852 retarray->base_addr
1853 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_4));
1854 retarray->offset = 0;
1856 else if (unlikely (compile_options.bounds_check))
1858 index_type ret_extent, arg_extent;
1860 if (GFC_DESCRIPTOR_RANK (a) == 1)
1862 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
1863 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1864 if (arg_extent != ret_extent)
1865 runtime_error ("Array bound mismatch for dimension 1 of "
1866 "array (%ld/%ld) ",
1867 (long int) ret_extent, (long int) arg_extent);
1869 else if (GFC_DESCRIPTOR_RANK (b) == 1)
1871 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
1872 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1873 if (arg_extent != ret_extent)
1874 runtime_error ("Array bound mismatch for dimension 1 of "
1875 "array (%ld/%ld) ",
1876 (long int) ret_extent, (long int) arg_extent);
1878 else
1880 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
1881 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1882 if (arg_extent != ret_extent)
1883 runtime_error ("Array bound mismatch for dimension 1 of "
1884 "array (%ld/%ld) ",
1885 (long int) ret_extent, (long int) arg_extent);
1887 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
1888 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
1889 if (arg_extent != ret_extent)
1890 runtime_error ("Array bound mismatch for dimension 2 of "
1891 "array (%ld/%ld) ",
1892 (long int) ret_extent, (long int) arg_extent);
1897 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
1899 /* One-dimensional result may be addressed in the code below
1900 either as a row or a column matrix. We want both cases to
1901 work. */
1902 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
1904 else
1906 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
1907 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
1911 if (GFC_DESCRIPTOR_RANK (a) == 1)
1913 /* Treat it as a row matrix A[1,count]. */
1914 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
1915 aystride = 1;
1917 xcount = 1;
1918 count = GFC_DESCRIPTOR_EXTENT(a,0);
1920 else
1922 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
1923 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
1925 count = GFC_DESCRIPTOR_EXTENT(a,1);
1926 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
1929 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
1931 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
1932 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
1933 "in dimension 1: is %ld, should be %ld",
1934 (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
1937 if (GFC_DESCRIPTOR_RANK (b) == 1)
1939 /* Treat it as a column matrix B[count,1] */
1940 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
1942 /* bystride should never be used for 1-dimensional b.
1943 The value is only used for calculation of the
1944 memory needed by the buffer. */
1945 bystride = 256;
1946 ycount = 1;
1948 else
1950 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
1951 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
1952 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
1955 abase = a->base_addr;
1956 bbase = b->base_addr;
1957 dest = retarray->base_addr;
1959 /* Now that everything is set up, we perform the multiplication
1960 itself. */
1962 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
1963 #define min(a,b) ((a) <= (b) ? (a) : (b))
1964 #define max(a,b) ((a) >= (b) ? (a) : (b))
1966 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
1967 && (bxstride == 1 || bystride == 1)
1968 && (((float) xcount) * ((float) ycount) * ((float) count)
1969 > POW3(blas_limit)))
1971 const int m = xcount, n = ycount, k = count, ldc = rystride;
1972 const GFC_INTEGER_4 one = 1, zero = 0;
1973 const int lda = (axstride == 1) ? aystride : axstride,
1974 ldb = (bxstride == 1) ? bystride : bxstride;
1976 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
1978 assert (gemm != NULL);
1979 const char *transa, *transb;
1980 if (try_blas & 2)
1981 transa = "C";
1982 else
1983 transa = axstride == 1 ? "N" : "T";
1985 if (try_blas & 4)
1986 transb = "C";
1987 else
1988 transb = bxstride == 1 ? "N" : "T";
1990 gemm (transa, transb , &m,
1991 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
1992 &ldc, 1, 1);
1993 return;
1997 if (rxstride == 1 && axstride == 1 && bxstride == 1)
1999 /* This block of code implements a tuned matmul, derived from
2000 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
2002 Bo Kagstrom and Per Ling
2003 Department of Computing Science
2004 Umea University
2005 S-901 87 Umea, Sweden
2007 from netlib.org, translated to C, and modified for matmul.m4. */
2009 const GFC_INTEGER_4 *a, *b;
2010 GFC_INTEGER_4 *c;
2011 const index_type m = xcount, n = ycount, k = count;
2013 /* System generated locals */
2014 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
2015 i1, i2, i3, i4, i5, i6;
2017 /* Local variables */
2018 GFC_INTEGER_4 f11, f12, f21, f22, f31, f32, f41, f42,
2019 f13, f14, f23, f24, f33, f34, f43, f44;
2020 index_type i, j, l, ii, jj, ll;
2021 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
2022 GFC_INTEGER_4 *t1;
2024 a = abase;
2025 b = bbase;
2026 c = retarray->base_addr;
2028 /* Parameter adjustments */
2029 c_dim1 = rystride;
2030 c_offset = 1 + c_dim1;
2031 c -= c_offset;
2032 a_dim1 = aystride;
2033 a_offset = 1 + a_dim1;
2034 a -= a_offset;
2035 b_dim1 = bystride;
2036 b_offset = 1 + b_dim1;
2037 b -= b_offset;
2039 /* Empty c first. */
2040 for (j=1; j<=n; j++)
2041 for (i=1; i<=m; i++)
2042 c[i + j * c_dim1] = (GFC_INTEGER_4)0;
2044 /* Early exit if possible */
2045 if (m == 0 || n == 0 || k == 0)
2046 return;
2048 /* Adjust size of t1 to what is needed. */
2049 index_type t1_dim, a_sz;
2050 if (aystride == 1)
2051 a_sz = rystride;
2052 else
2053 a_sz = a_dim1;
2055 t1_dim = a_sz * 256 + b_dim1;
2056 if (t1_dim > 65536)
2057 t1_dim = 65536;
2059 t1 = malloc (t1_dim * sizeof(GFC_INTEGER_4));
2061 /* Start turning the crank. */
2062 i1 = n;
2063 for (jj = 1; jj <= i1; jj += 512)
2065 /* Computing MIN */
2066 i2 = 512;
2067 i3 = n - jj + 1;
2068 jsec = min(i2,i3);
2069 ujsec = jsec - jsec % 4;
2070 i2 = k;
2071 for (ll = 1; ll <= i2; ll += 256)
2073 /* Computing MIN */
2074 i3 = 256;
2075 i4 = k - ll + 1;
2076 lsec = min(i3,i4);
2077 ulsec = lsec - lsec % 2;
2079 i3 = m;
2080 for (ii = 1; ii <= i3; ii += 256)
2082 /* Computing MIN */
2083 i4 = 256;
2084 i5 = m - ii + 1;
2085 isec = min(i4,i5);
2086 uisec = isec - isec % 2;
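/* Copy the isec x lsec sub-block of A into the packed buffer t1, two rows
   and two columns at a time. t1 uses a fixed row stride of 256, so
   ((i - ii + 1) << 8) selects the row and (l - ll + 1) the column; odd
   leftovers are handled just below. */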
2087 i4 = ll + ulsec - 1;
2088 for (l = ll; l <= i4; l += 2)
2090 i5 = ii + uisec - 1;
2091 for (i = ii; i <= i5; i += 2)
2093 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
2094 a[i + l * a_dim1];
2095 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
2096 a[i + (l + 1) * a_dim1];
2097 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
2098 a[i + 1 + l * a_dim1];
2099 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
2100 a[i + 1 + (l + 1) * a_dim1];
2102 if (uisec < isec)
2104 t1[l - ll + 1 + (isec << 8) - 257] =
2105 a[ii + isec - 1 + l * a_dim1];
2106 t1[l - ll + 2 + (isec << 8) - 257] =
2107 a[ii + isec - 1 + (l + 1) * a_dim1];
2110 if (ulsec < lsec)
2112 i4 = ii + isec - 1;
2113 for (i = ii; i<= i4; ++i)
2115 t1[lsec + ((i - ii + 1) << 8) - 257] =
2116 a[i + (ll + lsec - 1) * a_dim1];
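/* Multiply the packed block of A against B, accumulating 4x4 tiles of C
   in the f registers; uisec is recomputed for the 4-wide row blocking
   used here. */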
2120 uisec = isec - isec % 4;
2121 i4 = jj + ujsec - 1;
2122 for (j = jj; j <= i4; j += 4)
2124 i5 = ii + uisec - 1;
2125 for (i = ii; i <= i5; i += 4)
2127 f11 = c[i + j * c_dim1];
2128 f21 = c[i + 1 + j * c_dim1];
2129 f12 = c[i + (j + 1) * c_dim1];
2130 f22 = c[i + 1 + (j + 1) * c_dim1];
2131 f13 = c[i + (j + 2) * c_dim1];
2132 f23 = c[i + 1 + (j + 2) * c_dim1];
2133 f14 = c[i + (j + 3) * c_dim1];
2134 f24 = c[i + 1 + (j + 3) * c_dim1];
2135 f31 = c[i + 2 + j * c_dim1];
2136 f41 = c[i + 3 + j * c_dim1];
2137 f32 = c[i + 2 + (j + 1) * c_dim1];
2138 f42 = c[i + 3 + (j + 1) * c_dim1];
2139 f33 = c[i + 2 + (j + 2) * c_dim1];
2140 f43 = c[i + 3 + (j + 2) * c_dim1];
2141 f34 = c[i + 2 + (j + 3) * c_dim1];
2142 f44 = c[i + 3 + (j + 3) * c_dim1];
2143 i6 = ll + lsec - 1;
2144 for (l = ll; l <= i6; ++l)
2146 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2147 * b[l + j * b_dim1];
2148 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2149 * b[l + j * b_dim1];
2150 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2151 * b[l + (j + 1) * b_dim1];
2152 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2153 * b[l + (j + 1) * b_dim1];
2154 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2155 * b[l + (j + 2) * b_dim1];
2156 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2157 * b[l + (j + 2) * b_dim1];
2158 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2159 * b[l + (j + 3) * b_dim1];
2160 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2161 * b[l + (j + 3) * b_dim1];
2162 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2163 * b[l + j * b_dim1];
2164 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2165 * b[l + j * b_dim1];
2166 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2167 * b[l + (j + 1) * b_dim1];
2168 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2169 * b[l + (j + 1) * b_dim1];
2170 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2171 * b[l + (j + 2) * b_dim1];
2172 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2173 * b[l + (j + 2) * b_dim1];
2174 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2175 * b[l + (j + 3) * b_dim1];
2176 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2177 * b[l + (j + 3) * b_dim1];
2179 c[i + j * c_dim1] = f11;
2180 c[i + 1 + j * c_dim1] = f21;
2181 c[i + (j + 1) * c_dim1] = f12;
2182 c[i + 1 + (j + 1) * c_dim1] = f22;
2183 c[i + (j + 2) * c_dim1] = f13;
2184 c[i + 1 + (j + 2) * c_dim1] = f23;
2185 c[i + (j + 3) * c_dim1] = f14;
2186 c[i + 1 + (j + 3) * c_dim1] = f24;
2187 c[i + 2 + j * c_dim1] = f31;
2188 c[i + 3 + j * c_dim1] = f41;
2189 c[i + 2 + (j + 1) * c_dim1] = f32;
2190 c[i + 3 + (j + 1) * c_dim1] = f42;
2191 c[i + 2 + (j + 2) * c_dim1] = f33;
2192 c[i + 3 + (j + 2) * c_dim1] = f43;
2193 c[i + 2 + (j + 3) * c_dim1] = f34;
2194 c[i + 3 + (j + 3) * c_dim1] = f44;
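/* Clean up the rows left over when isec is not a multiple of 4. */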
2196 if (uisec < isec)
2198 i5 = ii + isec - 1;
2199 for (i = ii + uisec; i <= i5; ++i)
2201 f11 = c[i + j * c_dim1];
2202 f12 = c[i + (j + 1) * c_dim1];
2203 f13 = c[i + (j + 2) * c_dim1];
2204 f14 = c[i + (j + 3) * c_dim1];
2205 i6 = ll + lsec - 1;
2206 for (l = ll; l <= i6; ++l)
2208 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2209 257] * b[l + j * b_dim1];
2210 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2211 257] * b[l + (j + 1) * b_dim1];
2212 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2213 257] * b[l + (j + 2) * b_dim1];
2214 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2215 257] * b[l + (j + 3) * b_dim1];
2217 c[i + j * c_dim1] = f11;
2218 c[i + (j + 1) * c_dim1] = f12;
2219 c[i + (j + 2) * c_dim1] = f13;
2220 c[i + (j + 3) * c_dim1] = f14;
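/* Clean up the columns left over when jsec is not a multiple of 4. */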
2224 if (ujsec < jsec)
2226 i4 = jj + jsec - 1;
2227 for (j = jj + ujsec; j <= i4; ++j)
2229 i5 = ii + uisec - 1;
2230 for (i = ii; i <= i5; i += 4)
2232 f11 = c[i + j * c_dim1];
2233 f21 = c[i + 1 + j * c_dim1];
2234 f31 = c[i + 2 + j * c_dim1];
2235 f41 = c[i + 3 + j * c_dim1];
2236 i6 = ll + lsec - 1;
2237 for (l = ll; l <= i6; ++l)
2239 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2240 257] * b[l + j * b_dim1];
2241 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
2242 257] * b[l + j * b_dim1];
2243 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
2244 257] * b[l + j * b_dim1];
2245 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
2246 257] * b[l + j * b_dim1];
2248 c[i + j * c_dim1] = f11;
2249 c[i + 1 + j * c_dim1] = f21;
2250 c[i + 2 + j * c_dim1] = f31;
2251 c[i + 3 + j * c_dim1] = f41;
2253 i5 = ii + isec - 1;
2254 for (i = ii + uisec; i <= i5; ++i)
2256 f11 = c[i + j * c_dim1];
2257 i6 = ll + lsec - 1;
2258 for (l = ll; l <= i6; ++l)
2260 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2261 257] * b[l + j * b_dim1];
2263 c[i + j * c_dim1] = f11;
2270 free(t1);
2271 return;
2273 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
2275 if (GFC_DESCRIPTOR_RANK (a) != 1)
2277 const GFC_INTEGER_4 *restrict abase_x;
2278 const GFC_INTEGER_4 *restrict bbase_y;
2279 GFC_INTEGER_4 *restrict dest_y;
2280 GFC_INTEGER_4 s;
2282 for (y = 0; y < ycount; y++)
2284 bbase_y = &bbase[y*bystride];
2285 dest_y = &dest[y*rystride];
2286 for (x = 0; x < xcount; x++)
2288 abase_x = &abase[x*axstride];
2289 s = (GFC_INTEGER_4) 0;
2290 for (n = 0; n < count; n++)
2291 s += abase_x[n] * bbase_y[n];
2292 dest_y[x] = s;
2296 else
2298 const GFC_INTEGER_4 *restrict bbase_y;
2299 GFC_INTEGER_4 s;
2301 for (y = 0; y < ycount; y++)
2303 bbase_y = &bbase[y*bystride];
2304 s = (GFC_INTEGER_4) 0;
2305 for (n = 0; n < count; n++)
2306 s += abase[n*axstride] * bbase_y[n];
2307 dest[y*rystride] = s;
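/* A's x stride is the smaller one, so zero dest and then accumulate with
   the x loop innermost. */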
2311 else if (axstride < aystride)
2313 for (y = 0; y < ycount; y++)
2314 for (x = 0; x < xcount; x++)
2315 dest[x*rxstride + y*rystride] = (GFC_INTEGER_4)0;
2317 for (y = 0; y < ycount; y++)
2318 for (n = 0; n < count; n++)
2319 for (x = 0; x < xcount; x++)
2320 /* dest[x,y] += a[x,n] * b[n,y] */
2321 dest[x*rxstride + y*rystride] +=
2322 abase[x*axstride + n*aystride] *
2323 bbase[n*bxstride + y*bystride];
2325 else if (GFC_DESCRIPTOR_RANK (a) == 1)
2327 const GFC_INTEGER_4 *restrict bbase_y;
2328 GFC_INTEGER_4 s;
2330 for (y = 0; y < ycount; y++)
2332 bbase_y = &bbase[y*bystride];
2333 s = (GFC_INTEGER_4) 0;
2334 for (n = 0; n < count; n++)
2335 s += abase[n*axstride] * bbase_y[n*bxstride];
2336 dest[y*rxstride] = s;
2339 else
2341 const GFC_INTEGER_4 *restrict abase_x;
2342 const GFC_INTEGER_4 *restrict bbase_y;
2343 GFC_INTEGER_4 *restrict dest_y;
2344 GFC_INTEGER_4 s;
2346 for (y = 0; y < ycount; y++)
2348 bbase_y = &bbase[y*bystride];
2349 dest_y = &dest[y*rystride];
2350 for (x = 0; x < xcount; x++)
2352 abase_x = &abase[x*axstride];
2353 s = (GFC_INTEGER_4) 0;
2354 for (n = 0; n < count; n++)
2355 s += abase_x[n*aystride] * bbase_y[n*bxstride];
2356 dest_y[x*rxstride] = s;
2361 #undef POW3
2362 #undef min
2363 #undef max
2366 /* The main matmul_i4 function, with selection code for the processor. */
2368 /* Currently, this is i386 only. Adjust for other architectures. */
2370 #include <config/i386/cpuinfo.h>
2371 void matmul_i4 (gfc_array_i4 * const restrict retarray,
2372 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
2373 int blas_limit, blas_call gemm)
2375 static void (*matmul_p) (gfc_array_i4 * const restrict retarray,
2376 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
2377 int blas_limit, blas_call gemm);
2379 void (*matmul_fn) (gfc_array_i4 * const restrict retarray,
2380 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
2381 int blas_limit, blas_call gemm);
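/* The selected kernel is cached in the static pointer matmul_p. On the
   first call we start from the vanilla kernel, upgrade the choice according
   to __cpu_model's vendor and feature bits, and publish it with a relaxed
   atomic store; later calls simply reuse the cached pointer. */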
2383 matmul_fn = __atomic_load_n (&matmul_p, __ATOMIC_RELAXED);
2384 if (matmul_fn == NULL)
2386 matmul_fn = matmul_i4_vanilla;
2387 if (__cpu_model.__cpu_vendor == VENDOR_INTEL)
2389 /* Run down the available processors in order of preference. */
2390 #ifdef HAVE_AVX512F
2391 if (__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX512F))
2393 matmul_fn = matmul_i4_avx512f;
2394 goto store;
2397 #endif /* HAVE_AVX512F */
2399 #ifdef HAVE_AVX2
2400 if ((__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX2))
2401 && (__cpu_model.__cpu_features[0] & (1 << FEATURE_FMA)))
2403 matmul_fn = matmul_i4_avx2;
2404 goto store;
2407 #endif
2409 #ifdef HAVE_AVX
2410 if (__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX))
2412 matmul_fn = matmul_i4_avx;
2413 goto store;
2415 #endif /* HAVE_AVX */
2417 else if (__cpu_model.__cpu_vendor == VENDOR_AMD)
2419 #if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)
2420 if ((__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX))
2421 && (__cpu_model.__cpu_features[0] & (1 << FEATURE_FMA)))
2423 matmul_fn = matmul_i4_avx128_fma3;
2424 goto store;
2426 #endif
2427 #if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
2428 if ((__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX))
2429 && (__cpu_model.__cpu_features[0] & (1 << FEATURE_FMA4)))
2431 matmul_fn = matmul_i4_avx128_fma4;
2432 goto store;
2434 #endif
2437 store:
2438 __atomic_store_n (&matmul_p, matmul_fn, __ATOMIC_RELAXED);
2441 (*matmul_fn) (retarray, a, b, try_blas, blas_limit, gemm);
2444 #else /* Just the vanilla function. */
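/* Without any of the AVX variants above, matmul_i4 itself is the plain
   implementation; the body below follows the same structure as the
   target-specific variants: BLAS hand-off, blocked kernel, and
   stride-specific fallback loops. */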
2446 void
2447 matmul_i4 (gfc_array_i4 * const restrict retarray,
2448 gfc_array_i4 * const restrict a, gfc_array_i4 * const restrict b, int try_blas,
2449 int blas_limit, blas_call gemm)
2451 const GFC_INTEGER_4 * restrict abase;
2452 const GFC_INTEGER_4 * restrict bbase;
2453 GFC_INTEGER_4 * restrict dest;
2455 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
2456 index_type x, y, n, count, xcount, ycount;
2458 assert (GFC_DESCRIPTOR_RANK (a) == 2
2459 || GFC_DESCRIPTOR_RANK (b) == 2);
2461 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
2463 Either A or B (but not both) can be rank 1:
2465 o One-dimensional argument A is implicitly treated as a row matrix
2466 dimensioned [1,count], so xcount=1.
2468 o One-dimensional argument B is implicitly treated as a column matrix
2469 dimensioned [count, 1], so ycount=1. */
2472 if (retarray->base_addr == NULL)
2474 if (GFC_DESCRIPTOR_RANK (a) == 1)
2476 GFC_DIMENSION_SET(retarray->dim[0], 0,
2477 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
2479 else if (GFC_DESCRIPTOR_RANK (b) == 1)
2481 GFC_DIMENSION_SET(retarray->dim[0], 0,
2482 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
2484 else
2486 GFC_DIMENSION_SET(retarray->dim[0], 0,
2487 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
2489 GFC_DIMENSION_SET(retarray->dim[1], 0,
2490 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
2491 GFC_DESCRIPTOR_EXTENT(retarray,0));
2494 retarray->base_addr
2495 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_4));
2496 retarray->offset = 0;
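/* If the caller supplied a result array and bounds checking is enabled,
   verify that its extents match those implied by A and B. */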
2498 else if (unlikely (compile_options.bounds_check))
2500 index_type ret_extent, arg_extent;
2502 if (GFC_DESCRIPTOR_RANK (a) == 1)
2504 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
2505 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
2506 if (arg_extent != ret_extent)
2507 runtime_error ("Array bound mismatch for dimension 1 of "
2508 "array (%ld/%ld) ",
2509 (long int) ret_extent, (long int) arg_extent);
2511 else if (GFC_DESCRIPTOR_RANK (b) == 1)
2513 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
2514 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
2515 if (arg_extent != ret_extent)
2516 runtime_error ("Array bound mismatch for dimension 1 of "
2517 "array (%ld/%ld) ",
2518 (long int) ret_extent, (long int) arg_extent);
2520 else
2522 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
2523 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
2524 if (arg_extent != ret_extent)
2525 runtime_error ("Array bound mismatch for dimension 1 of "
2526 "array (%ld/%ld) ",
2527 (long int) ret_extent, (long int) arg_extent);
2529 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
2530 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
2531 if (arg_extent != ret_extent)
2532 runtime_error ("Array bound mismatch for dimension 2 of "
2533 "array (%ld/%ld) ",
2534 (long int) ret_extent, (long int) arg_extent);
2539 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
2541 /* A one-dimensional result may be addressed in the code below
2542 either as a row or a column matrix. We want both cases
2543 to work. */
2544 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
2546 else
2548 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
2549 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
2553 if (GFC_DESCRIPTOR_RANK (a) == 1)
2555 /* Treat it as a row matrix A[1,count]. */
2556 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
2557 aystride = 1;
2559 xcount = 1;
2560 count = GFC_DESCRIPTOR_EXTENT(a,0);
2562 else
2564 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
2565 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
2567 count = GFC_DESCRIPTOR_EXTENT(a,1);
2568 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
2571 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
2573 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
2574 runtime_error ("Incorrect extent in argument B in MATMUL intrinsic "
2575 "in dimension 1: is %ld, should be %ld",
2576 (long int) GFC_DESCRIPTOR_EXTENT(b,0), (long int) count);
2579 if (GFC_DESCRIPTOR_RANK (b) == 1)
2581 /* Treat it as a column matrix B[count,1] */
2582 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
2584 /* bystride should never be used for 1-dimensional b.
2585 The value is only used in the calculation of the
2586 memory needed for the buffer. */
2587 bystride = 256;
2588 ycount = 1;
2590 else
2592 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
2593 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
2594 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
2597 abase = a->base_addr;
2598 bbase = b->base_addr;
2599 dest = retarray->base_addr;
2601 /* Now that everything is set up, we perform the multiplication
2602 itself. */
2604 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
2605 #define min(a,b) ((a) <= (b) ? (a) : (b))
2606 #define max(a,b) ((a) >= (b) ? (a) : (b))
2608 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
2609 && (bxstride == 1 || bystride == 1)
2610 && (((float) xcount) * ((float) ycount) * ((float) count)
2611 > POW3(blas_limit)))
2613 const int m = xcount, n = ycount, k = count, ldc = rystride;
2614 const GFC_INTEGER_4 one = 1, zero = 0;
2615 const int lda = (axstride == 1) ? aystride : axstride,
2616 ldb = (bxstride == 1) ? bystride : bxstride;
2618 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
2620 assert (gemm != NULL);
2621 const char *transa, *transb;
2622 if (try_blas & 2)
2623 transa = "C";
2624 else
2625 transa = axstride == 1 ? "N" : "T";
2627 if (try_blas & 4)
2628 transb = "C";
2629 else
2630 transb = bxstride == 1 ? "N" : "T";
2632 gemm (transa, transb , &m,
2633 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
2634 &ldc, 1, 1);
2635 return;
2639 if (rxstride == 1 && axstride == 1 && bxstride == 1)
2641 /* This block of code implements a tuned matmul, derived from
2642 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
2644 Bo Kagstrom and Per Ling
2645 Department of Computing Science
2646 Umea University
2647 S-901 87 Umea, Sweden
2649 from netlib.org, translated to C, and modified for matmul.m4. */
2651 const GFC_INTEGER_4 *a, *b;
2652 GFC_INTEGER_4 *c;
2653 const index_type m = xcount, n = ycount, k = count;
2655 /* System generated locals */
2656 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
2657 i1, i2, i3, i4, i5, i6;
2659 /* Local variables */
2660 GFC_INTEGER_4 f11, f12, f21, f22, f31, f32, f41, f42,
2661 f13, f14, f23, f24, f33, f34, f43, f44;
2662 index_type i, j, l, ii, jj, ll;
2663 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
2664 GFC_INTEGER_4 *t1;
2666 a = abase;
2667 b = bbase;
2668 c = retarray->base_addr;
2670 /* Parameter adjustments */
2671 c_dim1 = rystride;
2672 c_offset = 1 + c_dim1;
2673 c -= c_offset;
2674 a_dim1 = aystride;
2675 a_offset = 1 + a_dim1;
2676 a -= a_offset;
2677 b_dim1 = bystride;
2678 b_offset = 1 + b_dim1;
2679 b -= b_offset;
2681 /* Empty c first. */
2682 for (j=1; j<=n; j++)
2683 for (i=1; i<=m; i++)
2684 c[i + j * c_dim1] = (GFC_INTEGER_4)0;
2686 /* Early exit if possible */
2687 if (m == 0 || n == 0 || k == 0)
2688 return;
2690 /* Adjust size of t1 to what is needed. */
2691 index_type t1_dim, a_sz;
2692 if (aystride == 1)
2693 a_sz = rystride;
2694 else
2695 a_sz = a_dim1;
2697 t1_dim = a_sz * 256 + b_dim1;
2698 if (t1_dim > 65536)
2699 t1_dim = 65536;
2701 t1 = malloc (t1_dim * sizeof(GFC_INTEGER_4));
2703 /* Start turning the crank. */
2704 i1 = n;
2705 for (jj = 1; jj <= i1; jj += 512)
2707 /* Computing MIN */
2708 i2 = 512;
2709 i3 = n - jj + 1;
2710 jsec = min(i2,i3);
2711 ujsec = jsec - jsec % 4;
2712 i2 = k;
2713 for (ll = 1; ll <= i2; ll += 256)
2715 /* Computing MIN */
2716 i3 = 256;
2717 i4 = k - ll + 1;
2718 lsec = min(i3,i4);
2719 ulsec = lsec - lsec % 2;
2721 i3 = m;
2722 for (ii = 1; ii <= i3; ii += 256)
2724 /* Computing MIN */
2725 i4 = 256;
2726 i5 = m - ii + 1;
2727 isec = min(i4,i5);
2728 uisec = isec - isec % 2;
2729 i4 = ll + ulsec - 1;
2730 for (l = ll; l <= i4; l += 2)
2732 i5 = ii + uisec - 1;
2733 for (i = ii; i <= i5; i += 2)
2735 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
2736 a[i + l * a_dim1];
2737 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
2738 a[i + (l + 1) * a_dim1];
2739 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
2740 a[i + 1 + l * a_dim1];
2741 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
2742 a[i + 1 + (l + 1) * a_dim1];
2744 if (uisec < isec)
2746 t1[l - ll + 1 + (isec << 8) - 257] =
2747 a[ii + isec - 1 + l * a_dim1];
2748 t1[l - ll + 2 + (isec << 8) - 257] =
2749 a[ii + isec - 1 + (l + 1) * a_dim1];
2752 if (ulsec < lsec)
2754 i4 = ii + isec - 1;
2755 for (i = ii; i<= i4; ++i)
2757 t1[lsec + ((i - ii + 1) << 8) - 257] =
2758 a[i + (ll + lsec - 1) * a_dim1];
2762 uisec = isec - isec % 4;
2763 i4 = jj + ujsec - 1;
2764 for (j = jj; j <= i4; j += 4)
2766 i5 = ii + uisec - 1;
2767 for (i = ii; i <= i5; i += 4)
2769 f11 = c[i + j * c_dim1];
2770 f21 = c[i + 1 + j * c_dim1];
2771 f12 = c[i + (j + 1) * c_dim1];
2772 f22 = c[i + 1 + (j + 1) * c_dim1];
2773 f13 = c[i + (j + 2) * c_dim1];
2774 f23 = c[i + 1 + (j + 2) * c_dim1];
2775 f14 = c[i + (j + 3) * c_dim1];
2776 f24 = c[i + 1 + (j + 3) * c_dim1];
2777 f31 = c[i + 2 + j * c_dim1];
2778 f41 = c[i + 3 + j * c_dim1];
2779 f32 = c[i + 2 + (j + 1) * c_dim1];
2780 f42 = c[i + 3 + (j + 1) * c_dim1];
2781 f33 = c[i + 2 + (j + 2) * c_dim1];
2782 f43 = c[i + 3 + (j + 2) * c_dim1];
2783 f34 = c[i + 2 + (j + 3) * c_dim1];
2784 f44 = c[i + 3 + (j + 3) * c_dim1];
2785 i6 = ll + lsec - 1;
2786 for (l = ll; l <= i6; ++l)
2788 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2789 * b[l + j * b_dim1];
2790 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2791 * b[l + j * b_dim1];
2792 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2793 * b[l + (j + 1) * b_dim1];
2794 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2795 * b[l + (j + 1) * b_dim1];
2796 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2797 * b[l + (j + 2) * b_dim1];
2798 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2799 * b[l + (j + 2) * b_dim1];
2800 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2801 * b[l + (j + 3) * b_dim1];
2802 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2803 * b[l + (j + 3) * b_dim1];
2804 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2805 * b[l + j * b_dim1];
2806 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2807 * b[l + j * b_dim1];
2808 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2809 * b[l + (j + 1) * b_dim1];
2810 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2811 * b[l + (j + 1) * b_dim1];
2812 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2813 * b[l + (j + 2) * b_dim1];
2814 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2815 * b[l + (j + 2) * b_dim1];
2816 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2817 * b[l + (j + 3) * b_dim1];
2818 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2819 * b[l + (j + 3) * b_dim1];
2821 c[i + j * c_dim1] = f11;
2822 c[i + 1 + j * c_dim1] = f21;
2823 c[i + (j + 1) * c_dim1] = f12;
2824 c[i + 1 + (j + 1) * c_dim1] = f22;
2825 c[i + (j + 2) * c_dim1] = f13;
2826 c[i + 1 + (j + 2) * c_dim1] = f23;
2827 c[i + (j + 3) * c_dim1] = f14;
2828 c[i + 1 + (j + 3) * c_dim1] = f24;
2829 c[i + 2 + j * c_dim1] = f31;
2830 c[i + 3 + j * c_dim1] = f41;
2831 c[i + 2 + (j + 1) * c_dim1] = f32;
2832 c[i + 3 + (j + 1) * c_dim1] = f42;
2833 c[i + 2 + (j + 2) * c_dim1] = f33;
2834 c[i + 3 + (j + 2) * c_dim1] = f43;
2835 c[i + 2 + (j + 3) * c_dim1] = f34;
2836 c[i + 3 + (j + 3) * c_dim1] = f44;
2838 if (uisec < isec)
2840 i5 = ii + isec - 1;
2841 for (i = ii + uisec; i <= i5; ++i)
2843 f11 = c[i + j * c_dim1];
2844 f12 = c[i + (j + 1) * c_dim1];
2845 f13 = c[i + (j + 2) * c_dim1];
2846 f14 = c[i + (j + 3) * c_dim1];
2847 i6 = ll + lsec - 1;
2848 for (l = ll; l <= i6; ++l)
2850 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2851 257] * b[l + j * b_dim1];
2852 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2853 257] * b[l + (j + 1) * b_dim1];
2854 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2855 257] * b[l + (j + 2) * b_dim1];
2856 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2857 257] * b[l + (j + 3) * b_dim1];
2859 c[i + j * c_dim1] = f11;
2860 c[i + (j + 1) * c_dim1] = f12;
2861 c[i + (j + 2) * c_dim1] = f13;
2862 c[i + (j + 3) * c_dim1] = f14;
2866 if (ujsec < jsec)
2868 i4 = jj + jsec - 1;
2869 for (j = jj + ujsec; j <= i4; ++j)
2871 i5 = ii + uisec - 1;
2872 for (i = ii; i <= i5; i += 4)
2874 f11 = c[i + j * c_dim1];
2875 f21 = c[i + 1 + j * c_dim1];
2876 f31 = c[i + 2 + j * c_dim1];
2877 f41 = c[i + 3 + j * c_dim1];
2878 i6 = ll + lsec - 1;
2879 for (l = ll; l <= i6; ++l)
2881 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2882 257] * b[l + j * b_dim1];
2883 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
2884 257] * b[l + j * b_dim1];
2885 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
2886 257] * b[l + j * b_dim1];
2887 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
2888 257] * b[l + j * b_dim1];
2890 c[i + j * c_dim1] = f11;
2891 c[i + 1 + j * c_dim1] = f21;
2892 c[i + 2 + j * c_dim1] = f31;
2893 c[i + 3 + j * c_dim1] = f41;
2895 i5 = ii + isec - 1;
2896 for (i = ii + uisec; i <= i5; ++i)
2898 f11 = c[i + j * c_dim1];
2899 i6 = ll + lsec - 1;
2900 for (l = ll; l <= i6; ++l)
2902 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2903 257] * b[l + j * b_dim1];
2905 c[i + j * c_dim1] = f11;
2912 free(t1);
2913 return;
2915 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
2917 if (GFC_DESCRIPTOR_RANK (a) != 1)
2919 const GFC_INTEGER_4 *restrict abase_x;
2920 const GFC_INTEGER_4 *restrict bbase_y;
2921 GFC_INTEGER_4 *restrict dest_y;
2922 GFC_INTEGER_4 s;
2924 for (y = 0; y < ycount; y++)
2926 bbase_y = &bbase[y*bystride];
2927 dest_y = &dest[y*rystride];
2928 for (x = 0; x < xcount; x++)
2930 abase_x = &abase[x*axstride];
2931 s = (GFC_INTEGER_4) 0;
2932 for (n = 0; n < count; n++)
2933 s += abase_x[n] * bbase_y[n];
2934 dest_y[x] = s;
2938 else
2940 const GFC_INTEGER_4 *restrict bbase_y;
2941 GFC_INTEGER_4 s;
2943 for (y = 0; y < ycount; y++)
2945 bbase_y = &bbase[y*bystride];
2946 s = (GFC_INTEGER_4) 0;
2947 for (n = 0; n < count; n++)
2948 s += abase[n*axstride] * bbase_y[n];
2949 dest[y*rystride] = s;
2953 else if (axstride < aystride)
2955 for (y = 0; y < ycount; y++)
2956 for (x = 0; x < xcount; x++)
2957 dest[x*rxstride + y*rystride] = (GFC_INTEGER_4)0;
2959 for (y = 0; y < ycount; y++)
2960 for (n = 0; n < count; n++)
2961 for (x = 0; x < xcount; x++)
2962 /* dest[x,y] += a[x,n] * b[n,y] */
2963 dest[x*rxstride + y*rystride] +=
2964 abase[x*axstride + n*aystride] *
2965 bbase[n*bxstride + y*bystride];
2967 else if (GFC_DESCRIPTOR_RANK (a) == 1)
2969 const GFC_INTEGER_4 *restrict bbase_y;
2970 GFC_INTEGER_4 s;
2972 for (y = 0; y < ycount; y++)
2974 bbase_y = &bbase[y*bystride];
2975 s = (GFC_INTEGER_4) 0;
2976 for (n = 0; n < count; n++)
2977 s += abase[n*axstride] * bbase_y[n*bxstride];
2978 dest[y*rxstride] = s;
2981 else
2983 const GFC_INTEGER_4 *restrict abase_x;
2984 const GFC_INTEGER_4 *restrict bbase_y;
2985 GFC_INTEGER_4 *restrict dest_y;
2986 GFC_INTEGER_4 s;
2988 for (y = 0; y < ycount; y++)
2990 bbase_y = &bbase[y*bystride];
2991 dest_y = &dest[y*rystride];
2992 for (x = 0; x < xcount; x++)
2994 abase_x = &abase[x*axstride];
2995 s = (GFC_INTEGER_4) 0;
2996 for (n = 0; n < count; n++)
2997 s += abase_x[n*aystride] * bbase_y[n*bxstride];
2998 dest_y[x*rxstride] = s;
3003 #undef POW3
3004 #undef min
3005 #undef max
3007 #endif
3008 #endif