[official-gcc.git] / libgfortran / generated / matmul_i8.c
1 /* Implementation of the MATMUL intrinsic
2 Copyright (C) 2002-2016 Free Software Foundation, Inc.
3 Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24 <http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
27 #include <stdlib.h>
28 #include <string.h>
29 #include <assert.h>
32 #if defined (HAVE_GFC_INTEGER_8)
34 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
35 passed to us by the front-end, in which case we call it for large
36 matrices. */
38 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
39 const int *, const GFC_INTEGER_8 *, const GFC_INTEGER_8 *,
40 const int *, const GFC_INTEGER_8 *, const int *,
41 const GFC_INTEGER_8 *, GFC_INTEGER_8 *, const int *,
42 int, int);
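/* Editorial note, not part of the original file: the pointer type above
   follows the standard Fortran ?gemm argument order,

       gemm (transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)

   computing C = alpha*op(A)*op(B) + beta*C; the two trailing ints are
   presumably the hidden character-length arguments for TRANSA and TRANSB.
   Alpha, beta and the matrix elements are GFC_INTEGER_8 because this file
   is the INTEGER(KIND=8) instantiation generated from matmul.m4.  */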
44 /* The order of loops is different in the case of plain matrix
45 multiplication C=MATMUL(A,B), and in the frequent special case where
46 the argument A is the temporary result of a TRANSPOSE intrinsic:
47 C=MATMUL(TRANSPOSE(A),B). Transposed temporaries are detected by
48 looking at their strides.
50 The equivalent Fortran pseudo-code is:
52 DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
53 IF (.NOT.IS_TRANSPOSED(A)) THEN
54 C = 0
55 DO J=1,N
56 DO K=1,COUNT
57 DO I=1,M
58 C(I,J) = C(I,J)+A(I,K)*B(K,J)
59 ELSE
60 DO J=1,N
61 DO I=1,M
62 S = 0
63 DO K=1,COUNT
64 S = S+A(I,K)*B(K,J)
65 C(I,J) = S
66 ENDIF
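/* Editorial note, not part of the original file: in the C code below the
   two cases are told apart purely by A's strides.  Plain C=MATMUL(A,B)
   reaches the blocked kernel guarded by
   "rxstride == 1 && axstride == 1 && bxstride == 1", while a transposed
   temporary, whose descriptor has unit stride along its second dimension
   (aystride == 1), falls into the branch that accumulates each C(I,J) in
   a scalar S, mirroring the ELSE part of the pseudo-code above.  */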
69 /* If try_blas is set to a nonzero value, then the matmul function will
70 see if there is a way to perform the matrix multiplication by a call
71 to the BLAS gemm function. */
73 extern void matmul_i8 (gfc_array_i8 * const restrict retarray,
74 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
75 int blas_limit, blas_call gemm);
76 export_proto(matmul_i8);
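/* Editorial note (assumption, not stated in this excerpt): each variant
   below is the same routine compiled with a different
   __attribute__((__target__(...))) so the compiler can vectorize it for
   that ISA; the exported matmul_i8 is expected to select one of them at
   run time based on the host CPU, but that dispatch code lies beyond the
   portion of the file shown here.  */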
81 /* Put the exhaustive list of possible architectures here, ORed together. */
83 #if defined(HAVE_AVX) || defined(HAVE_AVX2) || defined(HAVE_AVX512F)
85 #ifdef HAVE_AVX
86 static void
87 matmul_i8_avx (gfc_array_i8 * const restrict retarray,
88 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
89 int blas_limit, blas_call gemm) __attribute__((__target__("avx")));
90 static void
91 matmul_i8_avx (gfc_array_i8 * const restrict retarray,
92 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
93 int blas_limit, blas_call gemm)
95 const GFC_INTEGER_8 * restrict abase;
96 const GFC_INTEGER_8 * restrict bbase;
97 GFC_INTEGER_8 * restrict dest;
99 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
100 index_type x, y, n, count, xcount, ycount;
102 assert (GFC_DESCRIPTOR_RANK (a) == 2
103 || GFC_DESCRIPTOR_RANK (b) == 2);
105 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
107 Either A or B (but not both) can be rank 1:
109 o One-dimensional argument A is implicitly treated as a row matrix
110 dimensioned [1,count], so xcount=1.
112 o One-dimensional argument B is implicitly treated as a column matrix
113 dimensioned [count, 1], so ycount=1.
116 if (retarray->base_addr == NULL)
118 if (GFC_DESCRIPTOR_RANK (a) == 1)
120 GFC_DIMENSION_SET(retarray->dim[0], 0,
121 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
123 else if (GFC_DESCRIPTOR_RANK (b) == 1)
125 GFC_DIMENSION_SET(retarray->dim[0], 0,
126 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
128 else
130 GFC_DIMENSION_SET(retarray->dim[0], 0,
131 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
133 GFC_DIMENSION_SET(retarray->dim[1], 0,
134 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
135 GFC_DESCRIPTOR_EXTENT(retarray,0));
138 retarray->base_addr
139 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_8));
140 retarray->offset = 0;
142 else if (unlikely (compile_options.bounds_check))
144 index_type ret_extent, arg_extent;
146 if (GFC_DESCRIPTOR_RANK (a) == 1)
148 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
149 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
150 if (arg_extent != ret_extent)
151 runtime_error ("Incorrect extent in return array in"
152 " MATMUL intrinsic: is %ld, should be %ld",
153 (long int) ret_extent, (long int) arg_extent);
155 else if (GFC_DESCRIPTOR_RANK (b) == 1)
157 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
158 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
159 if (arg_extent != ret_extent)
160 runtime_error ("Incorrect extent in return array in"
161 " MATMUL intrinsic: is %ld, should be %ld",
162 (long int) ret_extent, (long int) arg_extent);
164 else
166 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
167 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
168 if (arg_extent != ret_extent)
169 runtime_error ("Incorrect extent in return array in"
170 " MATMUL intrinsic for dimension 1:"
171 " is %ld, should be %ld",
172 (long int) ret_extent, (long int) arg_extent);
174 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
175 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
176 if (arg_extent != ret_extent)
177 runtime_error ("Incorrect extent in return array in"
178 " MATMUL intrinsic for dimension 2:"
179 " is %ld, should be %ld",
180 (long int) ret_extent, (long int) arg_extent);
185 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
187 /* One-dimensional result may be addressed in the code below
188 either as a row or a column matrix. We want both cases to
189 work. */
190 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
192 else
194 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
195 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
199 if (GFC_DESCRIPTOR_RANK (a) == 1)
201        /* Treat it as a row matrix A[1,count]. */
202 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
203 aystride = 1;
205 xcount = 1;
206 count = GFC_DESCRIPTOR_EXTENT(a,0);
208 else
210 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
211 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
213 count = GFC_DESCRIPTOR_EXTENT(a,1);
214 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
217 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
219 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
220 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
223 if (GFC_DESCRIPTOR_RANK (b) == 1)
225      /* Treat it as a column matrix B[count,1] */
226 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
228      /* bystride should never be used for 1-dimensional b.
229         In case it is, we want it to cause a segfault rather than
230         an incorrect result. */
231 bystride = 0xDEADBEEF;
232 ycount = 1;
234 else
236 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
237 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
238 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
241 abase = a->base_addr;
242 bbase = b->base_addr;
243 dest = retarray->base_addr;
245 /* Now that everything is set up, we perform the multiplication
246 itself. */
248 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
249 #define min(a,b) ((a) <= (b) ? (a) : (b))
250 #define max(a,b) ((a) >= (b) ? (a) : (b))
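  /* Editorial note, not part of the original file: the test below hands the
     work to the external gemm only when the unit-stride requirements hold
     and the product of the three extents exceeds blas_limit cubed,

         xcount * ycount * count > blas_limit**3

     The float arithmetic merely keeps that product from overflowing the
     integer index type for very large extents.  */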
252 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
253 && (bxstride == 1 || bystride == 1)
254 && (((float) xcount) * ((float) ycount) * ((float) count)
255 > POW3(blas_limit)))
257 const int m = xcount, n = ycount, k = count, ldc = rystride;
258 const GFC_INTEGER_8 one = 1, zero = 0;
259 const int lda = (axstride == 1) ? aystride : axstride,
260 ldb = (bxstride == 1) ? bystride : bxstride;
262 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
264 assert (gemm != NULL);
265 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m,
266 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
267 &ldc, 1, 1);
268 return;
272 if (rxstride == 1 && axstride == 1 && bxstride == 1)
274 /* This block of code implements a tuned matmul, derived from
275 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
277 Bo Kagstrom and Per Ling
278 Department of Computing Science
279 Umea University
280 S-901 87 Umea, Sweden
282 from netlib.org, translated to C, and modified for matmul.m4. */
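      /* Editorial sketch of the blocking scheme, not part of the original
         file: columns of C are processed in panels of 512 (the jj loop),
         the shared dimension in panels of 256 (ll), and rows of C in
         panels of 256 (ii).  Each panel of A is copied into the local
         buffer t1 (hence "was [256][256]"), and the innermost loops update
         C in 4x4 register tiles f11..f44, with clean-up loops handling the
         leftover rows and columns.  */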
284 const GFC_INTEGER_8 *a, *b;
285 GFC_INTEGER_8 *c;
286 const index_type m = xcount, n = ycount, k = count;
288 /* System generated locals */
289 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
290 i1, i2, i3, i4, i5, i6;
292 /* Local variables */
293 GFC_INTEGER_8 t1[65536], /* was [256][256] */
294 f11, f12, f21, f22, f31, f32, f41, f42,
295 f13, f14, f23, f24, f33, f34, f43, f44;
296 index_type i, j, l, ii, jj, ll;
297 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
299 a = abase;
300 b = bbase;
301 c = retarray->base_addr;
303 /* Parameter adjustments */
304 c_dim1 = rystride;
305 c_offset = 1 + c_dim1;
306 c -= c_offset;
307 a_dim1 = aystride;
308 a_offset = 1 + a_dim1;
309 a -= a_offset;
310 b_dim1 = bystride;
311 b_offset = 1 + b_dim1;
312 b -= b_offset;
314 /* Early exit if possible */
315 if (m == 0 || n == 0 || k == 0)
316 return;
318 /* Empty c first. */
319 for (j=1; j<=n; j++)
320 for (i=1; i<=m; i++)
321 c[i + j * c_dim1] = (GFC_INTEGER_8)0;
323 /* Start turning the crank. */
324 i1 = n;
325 for (jj = 1; jj <= i1; jj += 512)
327 /* Computing MIN */
328 i2 = 512;
329 i3 = n - jj + 1;
330 jsec = min(i2,i3);
331 ujsec = jsec - jsec % 4;
332 i2 = k;
333 for (ll = 1; ll <= i2; ll += 256)
335 /* Computing MIN */
336 i3 = 256;
337 i4 = k - ll + 1;
338 lsec = min(i3,i4);
339 ulsec = lsec - lsec % 2;
341 i3 = m;
342 for (ii = 1; ii <= i3; ii += 256)
344 /* Computing MIN */
345 i4 = 256;
346 i5 = m - ii + 1;
347 isec = min(i4,i5);
348 uisec = isec - isec % 2;
349 i4 = ll + ulsec - 1;
350 for (l = ll; l <= i4; l += 2)
352 i5 = ii + uisec - 1;
353 for (i = ii; i <= i5; i += 2)
355 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
356 a[i + l * a_dim1];
357 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
358 a[i + (l + 1) * a_dim1];
359 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
360 a[i + 1 + l * a_dim1];
361 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
362 a[i + 1 + (l + 1) * a_dim1];
364 if (uisec < isec)
366 t1[l - ll + 1 + (isec << 8) - 257] =
367 a[ii + isec - 1 + l * a_dim1];
368 t1[l - ll + 2 + (isec << 8) - 257] =
369 a[ii + isec - 1 + (l + 1) * a_dim1];
372 if (ulsec < lsec)
374 i4 = ii + isec - 1;
375 for (i = ii; i<= i4; ++i)
377 t1[lsec + ((i - ii + 1) << 8) - 257] =
378 a[i + (ll + lsec - 1) * a_dim1];
382 uisec = isec - isec % 4;
383 i4 = jj + ujsec - 1;
384 for (j = jj; j <= i4; j += 4)
386 i5 = ii + uisec - 1;
387 for (i = ii; i <= i5; i += 4)
389 f11 = c[i + j * c_dim1];
390 f21 = c[i + 1 + j * c_dim1];
391 f12 = c[i + (j + 1) * c_dim1];
392 f22 = c[i + 1 + (j + 1) * c_dim1];
393 f13 = c[i + (j + 2) * c_dim1];
394 f23 = c[i + 1 + (j + 2) * c_dim1];
395 f14 = c[i + (j + 3) * c_dim1];
396 f24 = c[i + 1 + (j + 3) * c_dim1];
397 f31 = c[i + 2 + j * c_dim1];
398 f41 = c[i + 3 + j * c_dim1];
399 f32 = c[i + 2 + (j + 1) * c_dim1];
400 f42 = c[i + 3 + (j + 1) * c_dim1];
401 f33 = c[i + 2 + (j + 2) * c_dim1];
402 f43 = c[i + 3 + (j + 2) * c_dim1];
403 f34 = c[i + 2 + (j + 3) * c_dim1];
404 f44 = c[i + 3 + (j + 3) * c_dim1];
405 i6 = ll + lsec - 1;
406 for (l = ll; l <= i6; ++l)
408 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
409 * b[l + j * b_dim1];
410 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
411 * b[l + j * b_dim1];
412 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
413 * b[l + (j + 1) * b_dim1];
414 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
415 * b[l + (j + 1) * b_dim1];
416 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
417 * b[l + (j + 2) * b_dim1];
418 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
419 * b[l + (j + 2) * b_dim1];
420 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
421 * b[l + (j + 3) * b_dim1];
422 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
423 * b[l + (j + 3) * b_dim1];
424 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
425 * b[l + j * b_dim1];
426 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
427 * b[l + j * b_dim1];
428 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
429 * b[l + (j + 1) * b_dim1];
430 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
431 * b[l + (j + 1) * b_dim1];
432 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
433 * b[l + (j + 2) * b_dim1];
434 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
435 * b[l + (j + 2) * b_dim1];
436 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
437 * b[l + (j + 3) * b_dim1];
438 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
439 * b[l + (j + 3) * b_dim1];
441 c[i + j * c_dim1] = f11;
442 c[i + 1 + j * c_dim1] = f21;
443 c[i + (j + 1) * c_dim1] = f12;
444 c[i + 1 + (j + 1) * c_dim1] = f22;
445 c[i + (j + 2) * c_dim1] = f13;
446 c[i + 1 + (j + 2) * c_dim1] = f23;
447 c[i + (j + 3) * c_dim1] = f14;
448 c[i + 1 + (j + 3) * c_dim1] = f24;
449 c[i + 2 + j * c_dim1] = f31;
450 c[i + 3 + j * c_dim1] = f41;
451 c[i + 2 + (j + 1) * c_dim1] = f32;
452 c[i + 3 + (j + 1) * c_dim1] = f42;
453 c[i + 2 + (j + 2) * c_dim1] = f33;
454 c[i + 3 + (j + 2) * c_dim1] = f43;
455 c[i + 2 + (j + 3) * c_dim1] = f34;
456 c[i + 3 + (j + 3) * c_dim1] = f44;
458 if (uisec < isec)
460 i5 = ii + isec - 1;
461 for (i = ii + uisec; i <= i5; ++i)
463 f11 = c[i + j * c_dim1];
464 f12 = c[i + (j + 1) * c_dim1];
465 f13 = c[i + (j + 2) * c_dim1];
466 f14 = c[i + (j + 3) * c_dim1];
467 i6 = ll + lsec - 1;
468 for (l = ll; l <= i6; ++l)
470 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
471 257] * b[l + j * b_dim1];
472 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
473 257] * b[l + (j + 1) * b_dim1];
474 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
475 257] * b[l + (j + 2) * b_dim1];
476 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
477 257] * b[l + (j + 3) * b_dim1];
479 c[i + j * c_dim1] = f11;
480 c[i + (j + 1) * c_dim1] = f12;
481 c[i + (j + 2) * c_dim1] = f13;
482 c[i + (j + 3) * c_dim1] = f14;
486 if (ujsec < jsec)
488 i4 = jj + jsec - 1;
489 for (j = jj + ujsec; j <= i4; ++j)
491 i5 = ii + uisec - 1;
492 for (i = ii; i <= i5; i += 4)
494 f11 = c[i + j * c_dim1];
495 f21 = c[i + 1 + j * c_dim1];
496 f31 = c[i + 2 + j * c_dim1];
497 f41 = c[i + 3 + j * c_dim1];
498 i6 = ll + lsec - 1;
499 for (l = ll; l <= i6; ++l)
501 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
502 257] * b[l + j * b_dim1];
503 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
504 257] * b[l + j * b_dim1];
505 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
506 257] * b[l + j * b_dim1];
507 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
508 257] * b[l + j * b_dim1];
510 c[i + j * c_dim1] = f11;
511 c[i + 1 + j * c_dim1] = f21;
512 c[i + 2 + j * c_dim1] = f31;
513 c[i + 3 + j * c_dim1] = f41;
515 i5 = ii + isec - 1;
516 for (i = ii + uisec; i <= i5; ++i)
518 f11 = c[i + j * c_dim1];
519 i6 = ll + lsec - 1;
520 for (l = ll; l <= i6; ++l)
522 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
523 257] * b[l + j * b_dim1];
525 c[i + j * c_dim1] = f11;
532 return;
534 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
536 if (GFC_DESCRIPTOR_RANK (a) != 1)
538 const GFC_INTEGER_8 *restrict abase_x;
539 const GFC_INTEGER_8 *restrict bbase_y;
540 GFC_INTEGER_8 *restrict dest_y;
541 GFC_INTEGER_8 s;
543 for (y = 0; y < ycount; y++)
545 bbase_y = &bbase[y*bystride];
546 dest_y = &dest[y*rystride];
547 for (x = 0; x < xcount; x++)
549 abase_x = &abase[x*axstride];
550 s = (GFC_INTEGER_8) 0;
551 for (n = 0; n < count; n++)
552 s += abase_x[n] * bbase_y[n];
553 dest_y[x] = s;
557 else
559 const GFC_INTEGER_8 *restrict bbase_y;
560 GFC_INTEGER_8 s;
562 for (y = 0; y < ycount; y++)
564 bbase_y = &bbase[y*bystride];
565 s = (GFC_INTEGER_8) 0;
566 for (n = 0; n < count; n++)
567 s += abase[n*axstride] * bbase_y[n];
568 dest[y*rystride] = s;
572 else if (axstride < aystride)
574 for (y = 0; y < ycount; y++)
575 for (x = 0; x < xcount; x++)
576 dest[x*rxstride + y*rystride] = (GFC_INTEGER_8)0;
578 for (y = 0; y < ycount; y++)
579 for (n = 0; n < count; n++)
580 for (x = 0; x < xcount; x++)
581 /* dest[x,y] += a[x,n] * b[n,y] */
582 dest[x*rxstride + y*rystride] +=
583 abase[x*axstride + n*aystride] *
584 bbase[n*bxstride + y*bystride];
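  /* Editorial note, not part of the original file: this branch handles the
     case where A's first-dimension stride is the smaller one but the fast
     kernels above do not apply; the y,n,x loop order keeps x innermost so
     that the accesses to A walk its smaller stride.  */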
586 else if (GFC_DESCRIPTOR_RANK (a) == 1)
588 const GFC_INTEGER_8 *restrict bbase_y;
589 GFC_INTEGER_8 s;
591 for (y = 0; y < ycount; y++)
593 bbase_y = &bbase[y*bystride];
594 s = (GFC_INTEGER_8) 0;
595 for (n = 0; n < count; n++)
596 s += abase[n*axstride] * bbase_y[n*bxstride];
597 dest[y*rxstride] = s;
600 else
602 const GFC_INTEGER_8 *restrict abase_x;
603 const GFC_INTEGER_8 *restrict bbase_y;
604 GFC_INTEGER_8 *restrict dest_y;
605 GFC_INTEGER_8 s;
607 for (y = 0; y < ycount; y++)
609 bbase_y = &bbase[y*bystride];
610 dest_y = &dest[y*rystride];
611 for (x = 0; x < xcount; x++)
613 abase_x = &abase[x*axstride];
614 s = (GFC_INTEGER_8) 0;
615 for (n = 0; n < count; n++)
616 s += abase_x[n*aystride] * bbase_y[n*bxstride];
617 dest_y[x*rxstride] = s;
622 #undef POW3
623 #undef min
624 #undef max
626 #endif /* HAVE_AVX */
628 #ifdef HAVE_AVX2
629 static void
630 matmul_i8_avx2 (gfc_array_i8 * const restrict retarray,
631 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
632 int blas_limit, blas_call gemm) __attribute__((__target__("avx2")));
633 static void
634 matmul_i8_avx2 (gfc_array_i8 * const restrict retarray,
635 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
636 int blas_limit, blas_call gemm)
638 const GFC_INTEGER_8 * restrict abase;
639 const GFC_INTEGER_8 * restrict bbase;
640 GFC_INTEGER_8 * restrict dest;
642 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
643 index_type x, y, n, count, xcount, ycount;
645 assert (GFC_DESCRIPTOR_RANK (a) == 2
646 || GFC_DESCRIPTOR_RANK (b) == 2);
648 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
650 Either A or B (but not both) can be rank 1:
652 o One-dimensional argument A is implicitly treated as a row matrix
653 dimensioned [1,count], so xcount=1.
655 o One-dimensional argument B is implicitly treated as a column matrix
656 dimensioned [count, 1], so ycount=1.
659 if (retarray->base_addr == NULL)
661 if (GFC_DESCRIPTOR_RANK (a) == 1)
663 GFC_DIMENSION_SET(retarray->dim[0], 0,
664 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
666 else if (GFC_DESCRIPTOR_RANK (b) == 1)
668 GFC_DIMENSION_SET(retarray->dim[0], 0,
669 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
671 else
673 GFC_DIMENSION_SET(retarray->dim[0], 0,
674 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
676 GFC_DIMENSION_SET(retarray->dim[1], 0,
677 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
678 GFC_DESCRIPTOR_EXTENT(retarray,0));
681 retarray->base_addr
682 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_8));
683 retarray->offset = 0;
685 else if (unlikely (compile_options.bounds_check))
687 index_type ret_extent, arg_extent;
689 if (GFC_DESCRIPTOR_RANK (a) == 1)
691 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
692 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
693 if (arg_extent != ret_extent)
694 runtime_error ("Incorrect extent in return array in"
695 " MATMUL intrinsic: is %ld, should be %ld",
696 (long int) ret_extent, (long int) arg_extent);
698 else if (GFC_DESCRIPTOR_RANK (b) == 1)
700 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
701 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
702 if (arg_extent != ret_extent)
703 runtime_error ("Incorrect extent in return array in"
704 " MATMUL intrinsic: is %ld, should be %ld",
705 (long int) ret_extent, (long int) arg_extent);
707 else
709 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
710 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
711 if (arg_extent != ret_extent)
712 runtime_error ("Incorrect extent in return array in"
713 " MATMUL intrinsic for dimension 1:"
714 " is %ld, should be %ld",
715 (long int) ret_extent, (long int) arg_extent);
717 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
718 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
719 if (arg_extent != ret_extent)
720 runtime_error ("Incorrect extent in return array in"
721 " MATMUL intrinsic for dimension 2:"
722 " is %ld, should be %ld",
723 (long int) ret_extent, (long int) arg_extent);
728 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
730 /* One-dimensional result may be addressed in the code below
731 either as a row or a column matrix. We want both cases to
732 work. */
733 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
735 else
737 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
738 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
742 if (GFC_DESCRIPTOR_RANK (a) == 1)
744        /* Treat it as a row matrix A[1,count]. */
745 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
746 aystride = 1;
748 xcount = 1;
749 count = GFC_DESCRIPTOR_EXTENT(a,0);
751 else
753 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
754 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
756 count = GFC_DESCRIPTOR_EXTENT(a,1);
757 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
760 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
762 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
763 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
766 if (GFC_DESCRIPTOR_RANK (b) == 1)
768      /* Treat it as a column matrix B[count,1] */
769 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
771      /* bystride should never be used for 1-dimensional b.
772         In case it is, we want it to cause a segfault rather than
773         an incorrect result. */
774 bystride = 0xDEADBEEF;
775 ycount = 1;
777 else
779 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
780 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
781 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
784 abase = a->base_addr;
785 bbase = b->base_addr;
786 dest = retarray->base_addr;
788 /* Now that everything is set up, we perform the multiplication
789 itself. */
791 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
792 #define min(a,b) ((a) <= (b) ? (a) : (b))
793 #define max(a,b) ((a) >= (b) ? (a) : (b))
795 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
796 && (bxstride == 1 || bystride == 1)
797 && (((float) xcount) * ((float) ycount) * ((float) count)
798 > POW3(blas_limit)))
800 const int m = xcount, n = ycount, k = count, ldc = rystride;
801 const GFC_INTEGER_8 one = 1, zero = 0;
802 const int lda = (axstride == 1) ? aystride : axstride,
803 ldb = (bxstride == 1) ? bystride : bxstride;
805 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
807 assert (gemm != NULL);
808 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m,
809 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
810 &ldc, 1, 1);
811 return;
815 if (rxstride == 1 && axstride == 1 && bxstride == 1)
817 /* This block of code implements a tuned matmul, derived from
818 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
820 Bo Kagstrom and Per Ling
821 Department of Computing Science
822 Umea University
823 S-901 87 Umea, Sweden
825 from netlib.org, translated to C, and modified for matmul.m4. */
827 const GFC_INTEGER_8 *a, *b;
828 GFC_INTEGER_8 *c;
829 const index_type m = xcount, n = ycount, k = count;
831 /* System generated locals */
832 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
833 i1, i2, i3, i4, i5, i6;
835 /* Local variables */
836 GFC_INTEGER_8 t1[65536], /* was [256][256] */
837 f11, f12, f21, f22, f31, f32, f41, f42,
838 f13, f14, f23, f24, f33, f34, f43, f44;
839 index_type i, j, l, ii, jj, ll;
840 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
842 a = abase;
843 b = bbase;
844 c = retarray->base_addr;
846 /* Parameter adjustments */
847 c_dim1 = rystride;
848 c_offset = 1 + c_dim1;
849 c -= c_offset;
850 a_dim1 = aystride;
851 a_offset = 1 + a_dim1;
852 a -= a_offset;
853 b_dim1 = bystride;
854 b_offset = 1 + b_dim1;
855 b -= b_offset;
857 /* Early exit if possible */
858 if (m == 0 || n == 0 || k == 0)
859 return;
861 /* Empty c first. */
862 for (j=1; j<=n; j++)
863 for (i=1; i<=m; i++)
864 c[i + j * c_dim1] = (GFC_INTEGER_8)0;
866 /* Start turning the crank. */
867 i1 = n;
868 for (jj = 1; jj <= i1; jj += 512)
870 /* Computing MIN */
871 i2 = 512;
872 i3 = n - jj + 1;
873 jsec = min(i2,i3);
874 ujsec = jsec - jsec % 4;
875 i2 = k;
876 for (ll = 1; ll <= i2; ll += 256)
878 /* Computing MIN */
879 i3 = 256;
880 i4 = k - ll + 1;
881 lsec = min(i3,i4);
882 ulsec = lsec - lsec % 2;
884 i3 = m;
885 for (ii = 1; ii <= i3; ii += 256)
887 /* Computing MIN */
888 i4 = 256;
889 i5 = m - ii + 1;
890 isec = min(i4,i5);
891 uisec = isec - isec % 2;
892 i4 = ll + ulsec - 1;
893 for (l = ll; l <= i4; l += 2)
895 i5 = ii + uisec - 1;
896 for (i = ii; i <= i5; i += 2)
898 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
899 a[i + l * a_dim1];
900 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
901 a[i + (l + 1) * a_dim1];
902 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
903 a[i + 1 + l * a_dim1];
904 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
905 a[i + 1 + (l + 1) * a_dim1];
907 if (uisec < isec)
909 t1[l - ll + 1 + (isec << 8) - 257] =
910 a[ii + isec - 1 + l * a_dim1];
911 t1[l - ll + 2 + (isec << 8) - 257] =
912 a[ii + isec - 1 + (l + 1) * a_dim1];
915 if (ulsec < lsec)
917 i4 = ii + isec - 1;
918 for (i = ii; i<= i4; ++i)
920 t1[lsec + ((i - ii + 1) << 8) - 257] =
921 a[i + (ll + lsec - 1) * a_dim1];
925 uisec = isec - isec % 4;
926 i4 = jj + ujsec - 1;
927 for (j = jj; j <= i4; j += 4)
929 i5 = ii + uisec - 1;
930 for (i = ii; i <= i5; i += 4)
932 f11 = c[i + j * c_dim1];
933 f21 = c[i + 1 + j * c_dim1];
934 f12 = c[i + (j + 1) * c_dim1];
935 f22 = c[i + 1 + (j + 1) * c_dim1];
936 f13 = c[i + (j + 2) * c_dim1];
937 f23 = c[i + 1 + (j + 2) * c_dim1];
938 f14 = c[i + (j + 3) * c_dim1];
939 f24 = c[i + 1 + (j + 3) * c_dim1];
940 f31 = c[i + 2 + j * c_dim1];
941 f41 = c[i + 3 + j * c_dim1];
942 f32 = c[i + 2 + (j + 1) * c_dim1];
943 f42 = c[i + 3 + (j + 1) * c_dim1];
944 f33 = c[i + 2 + (j + 2) * c_dim1];
945 f43 = c[i + 3 + (j + 2) * c_dim1];
946 f34 = c[i + 2 + (j + 3) * c_dim1];
947 f44 = c[i + 3 + (j + 3) * c_dim1];
948 i6 = ll + lsec - 1;
949 for (l = ll; l <= i6; ++l)
951 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
952 * b[l + j * b_dim1];
953 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
954 * b[l + j * b_dim1];
955 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
956 * b[l + (j + 1) * b_dim1];
957 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
958 * b[l + (j + 1) * b_dim1];
959 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
960 * b[l + (j + 2) * b_dim1];
961 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
962 * b[l + (j + 2) * b_dim1];
963 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
964 * b[l + (j + 3) * b_dim1];
965 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
966 * b[l + (j + 3) * b_dim1];
967 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
968 * b[l + j * b_dim1];
969 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
970 * b[l + j * b_dim1];
971 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
972 * b[l + (j + 1) * b_dim1];
973 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
974 * b[l + (j + 1) * b_dim1];
975 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
976 * b[l + (j + 2) * b_dim1];
977 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
978 * b[l + (j + 2) * b_dim1];
979 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
980 * b[l + (j + 3) * b_dim1];
981 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
982 * b[l + (j + 3) * b_dim1];
984 c[i + j * c_dim1] = f11;
985 c[i + 1 + j * c_dim1] = f21;
986 c[i + (j + 1) * c_dim1] = f12;
987 c[i + 1 + (j + 1) * c_dim1] = f22;
988 c[i + (j + 2) * c_dim1] = f13;
989 c[i + 1 + (j + 2) * c_dim1] = f23;
990 c[i + (j + 3) * c_dim1] = f14;
991 c[i + 1 + (j + 3) * c_dim1] = f24;
992 c[i + 2 + j * c_dim1] = f31;
993 c[i + 3 + j * c_dim1] = f41;
994 c[i + 2 + (j + 1) * c_dim1] = f32;
995 c[i + 3 + (j + 1) * c_dim1] = f42;
996 c[i + 2 + (j + 2) * c_dim1] = f33;
997 c[i + 3 + (j + 2) * c_dim1] = f43;
998 c[i + 2 + (j + 3) * c_dim1] = f34;
999 c[i + 3 + (j + 3) * c_dim1] = f44;
1001 if (uisec < isec)
1003 i5 = ii + isec - 1;
1004 for (i = ii + uisec; i <= i5; ++i)
1006 f11 = c[i + j * c_dim1];
1007 f12 = c[i + (j + 1) * c_dim1];
1008 f13 = c[i + (j + 2) * c_dim1];
1009 f14 = c[i + (j + 3) * c_dim1];
1010 i6 = ll + lsec - 1;
1011 for (l = ll; l <= i6; ++l)
1013 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1014 257] * b[l + j * b_dim1];
1015 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1016 257] * b[l + (j + 1) * b_dim1];
1017 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1018 257] * b[l + (j + 2) * b_dim1];
1019 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1020 257] * b[l + (j + 3) * b_dim1];
1022 c[i + j * c_dim1] = f11;
1023 c[i + (j + 1) * c_dim1] = f12;
1024 c[i + (j + 2) * c_dim1] = f13;
1025 c[i + (j + 3) * c_dim1] = f14;
1029 if (ujsec < jsec)
1031 i4 = jj + jsec - 1;
1032 for (j = jj + ujsec; j <= i4; ++j)
1034 i5 = ii + uisec - 1;
1035 for (i = ii; i <= i5; i += 4)
1037 f11 = c[i + j * c_dim1];
1038 f21 = c[i + 1 + j * c_dim1];
1039 f31 = c[i + 2 + j * c_dim1];
1040 f41 = c[i + 3 + j * c_dim1];
1041 i6 = ll + lsec - 1;
1042 for (l = ll; l <= i6; ++l)
1044 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1045 257] * b[l + j * b_dim1];
1046 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
1047 257] * b[l + j * b_dim1];
1048 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
1049 257] * b[l + j * b_dim1];
1050 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
1051 257] * b[l + j * b_dim1];
1053 c[i + j * c_dim1] = f11;
1054 c[i + 1 + j * c_dim1] = f21;
1055 c[i + 2 + j * c_dim1] = f31;
1056 c[i + 3 + j * c_dim1] = f41;
1058 i5 = ii + isec - 1;
1059 for (i = ii + uisec; i <= i5; ++i)
1061 f11 = c[i + j * c_dim1];
1062 i6 = ll + lsec - 1;
1063 for (l = ll; l <= i6; ++l)
1065 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1066 257] * b[l + j * b_dim1];
1068 c[i + j * c_dim1] = f11;
1075 return;
1077 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
1079 if (GFC_DESCRIPTOR_RANK (a) != 1)
1081 const GFC_INTEGER_8 *restrict abase_x;
1082 const GFC_INTEGER_8 *restrict bbase_y;
1083 GFC_INTEGER_8 *restrict dest_y;
1084 GFC_INTEGER_8 s;
1086 for (y = 0; y < ycount; y++)
1088 bbase_y = &bbase[y*bystride];
1089 dest_y = &dest[y*rystride];
1090 for (x = 0; x < xcount; x++)
1092 abase_x = &abase[x*axstride];
1093 s = (GFC_INTEGER_8) 0;
1094 for (n = 0; n < count; n++)
1095 s += abase_x[n] * bbase_y[n];
1096 dest_y[x] = s;
1100 else
1102 const GFC_INTEGER_8 *restrict bbase_y;
1103 GFC_INTEGER_8 s;
1105 for (y = 0; y < ycount; y++)
1107 bbase_y = &bbase[y*bystride];
1108 s = (GFC_INTEGER_8) 0;
1109 for (n = 0; n < count; n++)
1110 s += abase[n*axstride] * bbase_y[n];
1111 dest[y*rystride] = s;
1115 else if (axstride < aystride)
1117 for (y = 0; y < ycount; y++)
1118 for (x = 0; x < xcount; x++)
1119 dest[x*rxstride + y*rystride] = (GFC_INTEGER_8)0;
1121 for (y = 0; y < ycount; y++)
1122 for (n = 0; n < count; n++)
1123 for (x = 0; x < xcount; x++)
1124 /* dest[x,y] += a[x,n] * b[n,y] */
1125 dest[x*rxstride + y*rystride] +=
1126 abase[x*axstride + n*aystride] *
1127 bbase[n*bxstride + y*bystride];
1129 else if (GFC_DESCRIPTOR_RANK (a) == 1)
1131 const GFC_INTEGER_8 *restrict bbase_y;
1132 GFC_INTEGER_8 s;
1134 for (y = 0; y < ycount; y++)
1136 bbase_y = &bbase[y*bystride];
1137 s = (GFC_INTEGER_8) 0;
1138 for (n = 0; n < count; n++)
1139 s += abase[n*axstride] * bbase_y[n*bxstride];
1140 dest[y*rxstride] = s;
1143 else
1145 const GFC_INTEGER_8 *restrict abase_x;
1146 const GFC_INTEGER_8 *restrict bbase_y;
1147 GFC_INTEGER_8 *restrict dest_y;
1148 GFC_INTEGER_8 s;
1150 for (y = 0; y < ycount; y++)
1152 bbase_y = &bbase[y*bystride];
1153 dest_y = &dest[y*rystride];
1154 for (x = 0; x < xcount; x++)
1156 abase_x = &abase[x*axstride];
1157 s = (GFC_INTEGER_8) 0;
1158 for (n = 0; n < count; n++)
1159 s += abase_x[n*aystride] * bbase_y[n*bxstride];
1160 dest_y[x*rxstride] = s;
1165 #undef POW3
1166 #undef min
1167 #undef max
1169 #endif /* HAVE_AVX2 */
1171 #ifdef HAVE_AVX512F
1172 static void
1173 matmul_i8_avx512f (gfc_array_i8 * const restrict retarray,
1174 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
1175 int blas_limit, blas_call gemm) __attribute__((__target__("avx512f")));
1176 static void
1177 matmul_i8_avx512f (gfc_array_i8 * const restrict retarray,
1178 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
1179 int blas_limit, blas_call gemm)
1181 const GFC_INTEGER_8 * restrict abase;
1182 const GFC_INTEGER_8 * restrict bbase;
1183 GFC_INTEGER_8 * restrict dest;
1185 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
1186 index_type x, y, n, count, xcount, ycount;
1188 assert (GFC_DESCRIPTOR_RANK (a) == 2
1189 || GFC_DESCRIPTOR_RANK (b) == 2);
1191 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
1193 Either A or B (but not both) can be rank 1:
1195 o One-dimensional argument A is implicitly treated as a row matrix
1196 dimensioned [1,count], so xcount=1.
1198 o One-dimensional argument B is implicitly treated as a column matrix
1199 dimensioned [count, 1], so ycount=1.
1202 if (retarray->base_addr == NULL)
1204 if (GFC_DESCRIPTOR_RANK (a) == 1)
1206 GFC_DIMENSION_SET(retarray->dim[0], 0,
1207 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
1209 else if (GFC_DESCRIPTOR_RANK (b) == 1)
1211 GFC_DIMENSION_SET(retarray->dim[0], 0,
1212 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
1214 else
1216 GFC_DIMENSION_SET(retarray->dim[0], 0,
1217 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
1219 GFC_DIMENSION_SET(retarray->dim[1], 0,
1220 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
1221 GFC_DESCRIPTOR_EXTENT(retarray,0));
1224 retarray->base_addr
1225 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_8));
1226 retarray->offset = 0;
1228 else if (unlikely (compile_options.bounds_check))
1230 index_type ret_extent, arg_extent;
1232 if (GFC_DESCRIPTOR_RANK (a) == 1)
1234 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
1235 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1236 if (arg_extent != ret_extent)
1237 runtime_error ("Incorrect extent in return array in"
1238 " MATMUL intrinsic: is %ld, should be %ld",
1239 (long int) ret_extent, (long int) arg_extent);
1241 else if (GFC_DESCRIPTOR_RANK (b) == 1)
1243 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
1244 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1245 if (arg_extent != ret_extent)
1246 runtime_error ("Incorrect extent in return array in"
1247 " MATMUL intrinsic: is %ld, should be %ld",
1248 (long int) ret_extent, (long int) arg_extent);
1250 else
1252 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
1253 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1254 if (arg_extent != ret_extent)
1255 runtime_error ("Incorrect extent in return array in"
1256 " MATMUL intrinsic for dimension 1:"
1257 " is %ld, should be %ld",
1258 (long int) ret_extent, (long int) arg_extent);
1260 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
1261 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
1262 if (arg_extent != ret_extent)
1263 runtime_error ("Incorrect extent in return array in"
1264 " MATMUL intrinsic for dimension 2:"
1265 " is %ld, should be %ld",
1266 (long int) ret_extent, (long int) arg_extent);
1271 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
1273 /* One-dimensional result may be addressed in the code below
1274 either as a row or a column matrix. We want both cases to
1275 work. */
1276 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
1278 else
1280 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
1281 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
1285 if (GFC_DESCRIPTOR_RANK (a) == 1)
1287        /* Treat it as a row matrix A[1,count]. */
1288 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
1289 aystride = 1;
1291 xcount = 1;
1292 count = GFC_DESCRIPTOR_EXTENT(a,0);
1294 else
1296 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
1297 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
1299 count = GFC_DESCRIPTOR_EXTENT(a,1);
1300 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
1303 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
1305 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
1306 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
1309 if (GFC_DESCRIPTOR_RANK (b) == 1)
1311      /* Treat it as a column matrix B[count,1] */
1312 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
1314      /* bystride should never be used for 1-dimensional b.
1315         In case it is, we want it to cause a segfault rather than
1316         an incorrect result. */
1317 bystride = 0xDEADBEEF;
1318 ycount = 1;
1320 else
1322 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
1323 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
1324 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
1327 abase = a->base_addr;
1328 bbase = b->base_addr;
1329 dest = retarray->base_addr;
1331 /* Now that everything is set up, we perform the multiplication
1332 itself. */
1334 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
1335 #define min(a,b) ((a) <= (b) ? (a) : (b))
1336 #define max(a,b) ((a) >= (b) ? (a) : (b))
1338 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
1339 && (bxstride == 1 || bystride == 1)
1340 && (((float) xcount) * ((float) ycount) * ((float) count)
1341 > POW3(blas_limit)))
1343 const int m = xcount, n = ycount, k = count, ldc = rystride;
1344 const GFC_INTEGER_8 one = 1, zero = 0;
1345 const int lda = (axstride == 1) ? aystride : axstride,
1346 ldb = (bxstride == 1) ? bystride : bxstride;
1348 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
1350 assert (gemm != NULL);
1351 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m,
1352 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
1353 &ldc, 1, 1);
1354 return;
1358 if (rxstride == 1 && axstride == 1 && bxstride == 1)
1360 /* This block of code implements a tuned matmul, derived from
1361 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
1363 Bo Kagstrom and Per Ling
1364 Department of Computing Science
1365 Umea University
1366 S-901 87 Umea, Sweden
1368 from netlib.org, translated to C, and modified for matmul.m4. */
1370 const GFC_INTEGER_8 *a, *b;
1371 GFC_INTEGER_8 *c;
1372 const index_type m = xcount, n = ycount, k = count;
1374 /* System generated locals */
1375 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
1376 i1, i2, i3, i4, i5, i6;
1378 /* Local variables */
1379 GFC_INTEGER_8 t1[65536], /* was [256][256] */
1380 f11, f12, f21, f22, f31, f32, f41, f42,
1381 f13, f14, f23, f24, f33, f34, f43, f44;
1382 index_type i, j, l, ii, jj, ll;
1383 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
1385 a = abase;
1386 b = bbase;
1387 c = retarray->base_addr;
1389 /* Parameter adjustments */
1390 c_dim1 = rystride;
1391 c_offset = 1 + c_dim1;
1392 c -= c_offset;
1393 a_dim1 = aystride;
1394 a_offset = 1 + a_dim1;
1395 a -= a_offset;
1396 b_dim1 = bystride;
1397 b_offset = 1 + b_dim1;
1398 b -= b_offset;
1400 /* Early exit if possible */
1401 if (m == 0 || n == 0 || k == 0)
1402 return;
1404 /* Empty c first. */
1405 for (j=1; j<=n; j++)
1406 for (i=1; i<=m; i++)
1407 c[i + j * c_dim1] = (GFC_INTEGER_8)0;
1409 /* Start turning the crank. */
1410 i1 = n;
1411 for (jj = 1; jj <= i1; jj += 512)
1413 /* Computing MIN */
1414 i2 = 512;
1415 i3 = n - jj + 1;
1416 jsec = min(i2,i3);
1417 ujsec = jsec - jsec % 4;
1418 i2 = k;
1419 for (ll = 1; ll <= i2; ll += 256)
1421 /* Computing MIN */
1422 i3 = 256;
1423 i4 = k - ll + 1;
1424 lsec = min(i3,i4);
1425 ulsec = lsec - lsec % 2;
1427 i3 = m;
1428 for (ii = 1; ii <= i3; ii += 256)
1430 /* Computing MIN */
1431 i4 = 256;
1432 i5 = m - ii + 1;
1433 isec = min(i4,i5);
1434 uisec = isec - isec % 2;
1435 i4 = ll + ulsec - 1;
1436 for (l = ll; l <= i4; l += 2)
1438 i5 = ii + uisec - 1;
1439 for (i = ii; i <= i5; i += 2)
1441 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
1442 a[i + l * a_dim1];
1443 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
1444 a[i + (l + 1) * a_dim1];
1445 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
1446 a[i + 1 + l * a_dim1];
1447 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
1448 a[i + 1 + (l + 1) * a_dim1];
1450 if (uisec < isec)
1452 t1[l - ll + 1 + (isec << 8) - 257] =
1453 a[ii + isec - 1 + l * a_dim1];
1454 t1[l - ll + 2 + (isec << 8) - 257] =
1455 a[ii + isec - 1 + (l + 1) * a_dim1];
1458 if (ulsec < lsec)
1460 i4 = ii + isec - 1;
1461 for (i = ii; i<= i4; ++i)
1463 t1[lsec + ((i - ii + 1) << 8) - 257] =
1464 a[i + (ll + lsec - 1) * a_dim1];
1468 uisec = isec - isec % 4;
1469 i4 = jj + ujsec - 1;
1470 for (j = jj; j <= i4; j += 4)
1472 i5 = ii + uisec - 1;
1473 for (i = ii; i <= i5; i += 4)
1475 f11 = c[i + j * c_dim1];
1476 f21 = c[i + 1 + j * c_dim1];
1477 f12 = c[i + (j + 1) * c_dim1];
1478 f22 = c[i + 1 + (j + 1) * c_dim1];
1479 f13 = c[i + (j + 2) * c_dim1];
1480 f23 = c[i + 1 + (j + 2) * c_dim1];
1481 f14 = c[i + (j + 3) * c_dim1];
1482 f24 = c[i + 1 + (j + 3) * c_dim1];
1483 f31 = c[i + 2 + j * c_dim1];
1484 f41 = c[i + 3 + j * c_dim1];
1485 f32 = c[i + 2 + (j + 1) * c_dim1];
1486 f42 = c[i + 3 + (j + 1) * c_dim1];
1487 f33 = c[i + 2 + (j + 2) * c_dim1];
1488 f43 = c[i + 3 + (j + 2) * c_dim1];
1489 f34 = c[i + 2 + (j + 3) * c_dim1];
1490 f44 = c[i + 3 + (j + 3) * c_dim1];
1491 i6 = ll + lsec - 1;
1492 for (l = ll; l <= i6; ++l)
1494 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1495 * b[l + j * b_dim1];
1496 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1497 * b[l + j * b_dim1];
1498 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1499 * b[l + (j + 1) * b_dim1];
1500 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1501 * b[l + (j + 1) * b_dim1];
1502 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1503 * b[l + (j + 2) * b_dim1];
1504 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1505 * b[l + (j + 2) * b_dim1];
1506 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
1507 * b[l + (j + 3) * b_dim1];
1508 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
1509 * b[l + (j + 3) * b_dim1];
1510 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1511 * b[l + j * b_dim1];
1512 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1513 * b[l + j * b_dim1];
1514 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1515 * b[l + (j + 1) * b_dim1];
1516 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1517 * b[l + (j + 1) * b_dim1];
1518 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1519 * b[l + (j + 2) * b_dim1];
1520 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1521 * b[l + (j + 2) * b_dim1];
1522 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
1523 * b[l + (j + 3) * b_dim1];
1524 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
1525 * b[l + (j + 3) * b_dim1];
1527 c[i + j * c_dim1] = f11;
1528 c[i + 1 + j * c_dim1] = f21;
1529 c[i + (j + 1) * c_dim1] = f12;
1530 c[i + 1 + (j + 1) * c_dim1] = f22;
1531 c[i + (j + 2) * c_dim1] = f13;
1532 c[i + 1 + (j + 2) * c_dim1] = f23;
1533 c[i + (j + 3) * c_dim1] = f14;
1534 c[i + 1 + (j + 3) * c_dim1] = f24;
1535 c[i + 2 + j * c_dim1] = f31;
1536 c[i + 3 + j * c_dim1] = f41;
1537 c[i + 2 + (j + 1) * c_dim1] = f32;
1538 c[i + 3 + (j + 1) * c_dim1] = f42;
1539 c[i + 2 + (j + 2) * c_dim1] = f33;
1540 c[i + 3 + (j + 2) * c_dim1] = f43;
1541 c[i + 2 + (j + 3) * c_dim1] = f34;
1542 c[i + 3 + (j + 3) * c_dim1] = f44;
1544 if (uisec < isec)
1546 i5 = ii + isec - 1;
1547 for (i = ii + uisec; i <= i5; ++i)
1549 f11 = c[i + j * c_dim1];
1550 f12 = c[i + (j + 1) * c_dim1];
1551 f13 = c[i + (j + 2) * c_dim1];
1552 f14 = c[i + (j + 3) * c_dim1];
1553 i6 = ll + lsec - 1;
1554 for (l = ll; l <= i6; ++l)
1556 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1557 257] * b[l + j * b_dim1];
1558 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1559 257] * b[l + (j + 1) * b_dim1];
1560 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1561 257] * b[l + (j + 2) * b_dim1];
1562 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1563 257] * b[l + (j + 3) * b_dim1];
1565 c[i + j * c_dim1] = f11;
1566 c[i + (j + 1) * c_dim1] = f12;
1567 c[i + (j + 2) * c_dim1] = f13;
1568 c[i + (j + 3) * c_dim1] = f14;
1572 if (ujsec < jsec)
1574 i4 = jj + jsec - 1;
1575 for (j = jj + ujsec; j <= i4; ++j)
1577 i5 = ii + uisec - 1;
1578 for (i = ii; i <= i5; i += 4)
1580 f11 = c[i + j * c_dim1];
1581 f21 = c[i + 1 + j * c_dim1];
1582 f31 = c[i + 2 + j * c_dim1];
1583 f41 = c[i + 3 + j * c_dim1];
1584 i6 = ll + lsec - 1;
1585 for (l = ll; l <= i6; ++l)
1587 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1588 257] * b[l + j * b_dim1];
1589 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
1590 257] * b[l + j * b_dim1];
1591 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
1592 257] * b[l + j * b_dim1];
1593 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
1594 257] * b[l + j * b_dim1];
1596 c[i + j * c_dim1] = f11;
1597 c[i + 1 + j * c_dim1] = f21;
1598 c[i + 2 + j * c_dim1] = f31;
1599 c[i + 3 + j * c_dim1] = f41;
1601 i5 = ii + isec - 1;
1602 for (i = ii + uisec; i <= i5; ++i)
1604 f11 = c[i + j * c_dim1];
1605 i6 = ll + lsec - 1;
1606 for (l = ll; l <= i6; ++l)
1608 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1609 257] * b[l + j * b_dim1];
1611 c[i + j * c_dim1] = f11;
1618 return;
1620 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
1622 if (GFC_DESCRIPTOR_RANK (a) != 1)
1624 const GFC_INTEGER_8 *restrict abase_x;
1625 const GFC_INTEGER_8 *restrict bbase_y;
1626 GFC_INTEGER_8 *restrict dest_y;
1627 GFC_INTEGER_8 s;
1629 for (y = 0; y < ycount; y++)
1631 bbase_y = &bbase[y*bystride];
1632 dest_y = &dest[y*rystride];
1633 for (x = 0; x < xcount; x++)
1635 abase_x = &abase[x*axstride];
1636 s = (GFC_INTEGER_8) 0;
1637 for (n = 0; n < count; n++)
1638 s += abase_x[n] * bbase_y[n];
1639 dest_y[x] = s;
1643 else
1645 const GFC_INTEGER_8 *restrict bbase_y;
1646 GFC_INTEGER_8 s;
1648 for (y = 0; y < ycount; y++)
1650 bbase_y = &bbase[y*bystride];
1651 s = (GFC_INTEGER_8) 0;
1652 for (n = 0; n < count; n++)
1653 s += abase[n*axstride] * bbase_y[n];
1654 dest[y*rystride] = s;
1658 else if (axstride < aystride)
1660 for (y = 0; y < ycount; y++)
1661 for (x = 0; x < xcount; x++)
1662 dest[x*rxstride + y*rystride] = (GFC_INTEGER_8)0;
1664 for (y = 0; y < ycount; y++)
1665 for (n = 0; n < count; n++)
1666 for (x = 0; x < xcount; x++)
1667 /* dest[x,y] += a[x,n] * b[n,y] */
1668 dest[x*rxstride + y*rystride] +=
1669 abase[x*axstride + n*aystride] *
1670 bbase[n*bxstride + y*bystride];
1672 else if (GFC_DESCRIPTOR_RANK (a) == 1)
1674 const GFC_INTEGER_8 *restrict bbase_y;
1675 GFC_INTEGER_8 s;
1677 for (y = 0; y < ycount; y++)
1679 bbase_y = &bbase[y*bystride];
1680 s = (GFC_INTEGER_8) 0;
1681 for (n = 0; n < count; n++)
1682 s += abase[n*axstride] * bbase_y[n*bxstride];
1683 dest[y*rxstride] = s;
1686 else
1688 const GFC_INTEGER_8 *restrict abase_x;
1689 const GFC_INTEGER_8 *restrict bbase_y;
1690 GFC_INTEGER_8 *restrict dest_y;
1691 GFC_INTEGER_8 s;
1693 for (y = 0; y < ycount; y++)
1695 bbase_y = &bbase[y*bystride];
1696 dest_y = &dest[y*rystride];
1697 for (x = 0; x < xcount; x++)
1699 abase_x = &abase[x*axstride];
1700 s = (GFC_INTEGER_8) 0;
1701 for (n = 0; n < count; n++)
1702 s += abase_x[n*aystride] * bbase_y[n*bxstride];
1703 dest_y[x*rxstride] = s;
1708 #undef POW3
1709 #undef min
1710 #undef max
1712 #endif /* HAVE_AVX512F */
1714 /* Function to fall back to if there is no special processor-specific version. */
1715 static void
1716 matmul_i8_vanilla (gfc_array_i8 * const restrict retarray,
1717 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
1718 int blas_limit, blas_call gemm)
1720 const GFC_INTEGER_8 * restrict abase;
1721 const GFC_INTEGER_8 * restrict bbase;
1722 GFC_INTEGER_8 * restrict dest;
1724 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
1725 index_type x, y, n, count, xcount, ycount;
1727 assert (GFC_DESCRIPTOR_RANK (a) == 2
1728 || GFC_DESCRIPTOR_RANK (b) == 2);
1730 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
1732 Either A or B (but not both) can be rank 1:
1734 o One-dimensional argument A is implicitly treated as a row matrix
1735 dimensioned [1,count], so xcount=1.
1737 o One-dimensional argument B is implicitly treated as a column matrix
1738 dimensioned [count, 1], so ycount=1.
1741 if (retarray->base_addr == NULL)
1743 if (GFC_DESCRIPTOR_RANK (a) == 1)
1745 GFC_DIMENSION_SET(retarray->dim[0], 0,
1746 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
1748 else if (GFC_DESCRIPTOR_RANK (b) == 1)
1750 GFC_DIMENSION_SET(retarray->dim[0], 0,
1751 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
1753 else
1755 GFC_DIMENSION_SET(retarray->dim[0], 0,
1756 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
1758 GFC_DIMENSION_SET(retarray->dim[1], 0,
1759 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
1760 GFC_DESCRIPTOR_EXTENT(retarray,0));
1763 retarray->base_addr
1764 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_8));
1765 retarray->offset = 0;
1767 else if (unlikely (compile_options.bounds_check))
1769 index_type ret_extent, arg_extent;
1771 if (GFC_DESCRIPTOR_RANK (a) == 1)
1773 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
1774 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1775 if (arg_extent != ret_extent)
1776 runtime_error ("Incorrect extent in return array in"
1777 " MATMUL intrinsic: is %ld, should be %ld",
1778 (long int) ret_extent, (long int) arg_extent);
1780 else if (GFC_DESCRIPTOR_RANK (b) == 1)
1782 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
1783 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1784 if (arg_extent != ret_extent)
1785 runtime_error ("Incorrect extent in return array in"
1786 " MATMUL intrinsic: is %ld, should be %ld",
1787 (long int) ret_extent, (long int) arg_extent);
1789 else
1791 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
1792 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
1793 if (arg_extent != ret_extent)
1794 runtime_error ("Incorrect extent in return array in"
1795 " MATMUL intrinsic for dimension 1:"
1796 " is %ld, should be %ld",
1797 (long int) ret_extent, (long int) arg_extent);
1799 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
1800 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
1801 if (arg_extent != ret_extent)
1802 runtime_error ("Incorrect extent in return array in"
1803 " MATMUL intrinsic for dimension 2:"
1804 " is %ld, should be %ld",
1805 (long int) ret_extent, (long int) arg_extent);
1810 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
1812 /* One-dimensional result may be addressed in the code below
1813 either as a row or a column matrix. We want both cases to
1814 work. */
1815 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
1817 else
1819 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
1820 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
1824 if (GFC_DESCRIPTOR_RANK (a) == 1)
1826        /* Treat it as a row matrix A[1,count]. */
1827 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
1828 aystride = 1;
1830 xcount = 1;
1831 count = GFC_DESCRIPTOR_EXTENT(a,0);
1833 else
1835 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
1836 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
1838 count = GFC_DESCRIPTOR_EXTENT(a,1);
1839 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
1842 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
1844 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
1845 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
1848 if (GFC_DESCRIPTOR_RANK (b) == 1)
1850      /* Treat it as a column matrix B[count,1] */
1851 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
1853      /* bystride should never be used for 1-dimensional b.
1854         In case it is, we want it to cause a segfault rather than
1855         an incorrect result. */
1856 bystride = 0xDEADBEEF;
1857 ycount = 1;
1859 else
1861 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
1862 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
1863 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
1866 abase = a->base_addr;
1867 bbase = b->base_addr;
1868 dest = retarray->base_addr;
1870 /* Now that everything is set up, we perform the multiplication
1871 itself. */
1873 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
1874 #define min(a,b) ((a) <= (b) ? (a) : (b))
1875 #define max(a,b) ((a) >= (b) ? (a) : (b))
1877 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
1878 && (bxstride == 1 || bystride == 1)
1879 && (((float) xcount) * ((float) ycount) * ((float) count)
1880 > POW3(blas_limit)))
1882 const int m = xcount, n = ycount, k = count, ldc = rystride;
1883 const GFC_INTEGER_8 one = 1, zero = 0;
1884 const int lda = (axstride == 1) ? aystride : axstride,
1885 ldb = (bxstride == 1) ? bystride : bxstride;
1887 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
1889 assert (gemm != NULL);
1890 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m,
1891 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
1892 &ldc, 1, 1);
1893 return;
1897 if (rxstride == 1 && axstride == 1 && bxstride == 1)
1899 /* This block of code implements a tuned matmul, derived from
1900 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
1902 Bo Kagstrom and Per Ling
1903 Department of Computing Science
1904 Umea University
1905 S-901 87 Umea, Sweden
1907 from netlib.org, translated to C, and modified for matmul.m4. */
1909 const GFC_INTEGER_8 *a, *b;
1910 GFC_INTEGER_8 *c;
1911 const index_type m = xcount, n = ycount, k = count;
1913 /* System generated locals */
1914 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
1915 i1, i2, i3, i4, i5, i6;
1917 /* Local variables */
1918 GFC_INTEGER_8 t1[65536], /* was [256][256] */
1919 f11, f12, f21, f22, f31, f32, f41, f42,
1920 f13, f14, f23, f24, f33, f34, f43, f44;
1921 index_type i, j, l, ii, jj, ll;
1922 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
1924 a = abase;
1925 b = bbase;
1926 c = retarray->base_addr;
1928 /* Parameter adjustments */
1929 c_dim1 = rystride;
1930 c_offset = 1 + c_dim1;
1931 c -= c_offset;
1932 a_dim1 = aystride;
1933 a_offset = 1 + a_dim1;
1934 a -= a_offset;
1935 b_dim1 = bystride;
1936 b_offset = 1 + b_dim1;
1937 b -= b_offset;
1939 /* Early exit if possible */
1940 if (m == 0 || n == 0 || k == 0)
1941 return;
1943 /* Empty c first. */
1944 for (j=1; j<=n; j++)
1945 for (i=1; i<=m; i++)
1946 c[i + j * c_dim1] = (GFC_INTEGER_8)0;
1948 /* Start turning the crank. */
1949 i1 = n;
1950 for (jj = 1; jj <= i1; jj += 512)
1952 /* Computing MIN */
1953 i2 = 512;
1954 i3 = n - jj + 1;
1955 jsec = min(i2,i3);
1956 ujsec = jsec - jsec % 4;
1957 i2 = k;
1958 for (ll = 1; ll <= i2; ll += 256)
1960 /* Computing MIN */
1961 i3 = 256;
1962 i4 = k - ll + 1;
1963 lsec = min(i3,i4);
1964 ulsec = lsec - lsec % 2;
1966 i3 = m;
1967 for (ii = 1; ii <= i3; ii += 256)
1969 /* Computing MIN */
1970 i4 = 256;
1971 i5 = m - ii + 1;
1972 isec = min(i4,i5);
1973 uisec = isec - isec % 2;
1974 i4 = ll + ulsec - 1;
1975 for (l = ll; l <= i4; l += 2)
1977 i5 = ii + uisec - 1;
1978 for (i = ii; i <= i5; i += 2)
1980 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
1981 a[i + l * a_dim1];
1982 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
1983 a[i + (l + 1) * a_dim1];
1984 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
1985 a[i + 1 + l * a_dim1];
1986 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
1987 a[i + 1 + (l + 1) * a_dim1];
1989 if (uisec < isec)
1991 t1[l - ll + 1 + (isec << 8) - 257] =
1992 a[ii + isec - 1 + l * a_dim1];
1993 t1[l - ll + 2 + (isec << 8) - 257] =
1994 a[ii + isec - 1 + (l + 1) * a_dim1];
1997 if (ulsec < lsec)
1999 i4 = ii + isec - 1;
2000 for (i = ii; i<= i4; ++i)
2002 t1[lsec + ((i - ii + 1) << 8) - 257] =
2003 a[i + (ll + lsec - 1) * a_dim1];
2007 uisec = isec - isec % 4;
2008 i4 = jj + ujsec - 1;
2009 for (j = jj; j <= i4; j += 4)
2011 i5 = ii + uisec - 1;
2012 for (i = ii; i <= i5; i += 4)
2014 f11 = c[i + j * c_dim1];
2015 f21 = c[i + 1 + j * c_dim1];
2016 f12 = c[i + (j + 1) * c_dim1];
2017 f22 = c[i + 1 + (j + 1) * c_dim1];
2018 f13 = c[i + (j + 2) * c_dim1];
2019 f23 = c[i + 1 + (j + 2) * c_dim1];
2020 f14 = c[i + (j + 3) * c_dim1];
2021 f24 = c[i + 1 + (j + 3) * c_dim1];
2022 f31 = c[i + 2 + j * c_dim1];
2023 f41 = c[i + 3 + j * c_dim1];
2024 f32 = c[i + 2 + (j + 1) * c_dim1];
2025 f42 = c[i + 3 + (j + 1) * c_dim1];
2026 f33 = c[i + 2 + (j + 2) * c_dim1];
2027 f43 = c[i + 3 + (j + 2) * c_dim1];
2028 f34 = c[i + 2 + (j + 3) * c_dim1];
2029 f44 = c[i + 3 + (j + 3) * c_dim1];
2030 i6 = ll + lsec - 1;
2031 for (l = ll; l <= i6; ++l)
2033 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2034 * b[l + j * b_dim1];
2035 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2036 * b[l + j * b_dim1];
2037 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2038 * b[l + (j + 1) * b_dim1];
2039 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2040 * b[l + (j + 1) * b_dim1];
2041 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2042 * b[l + (j + 2) * b_dim1];
2043 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2044 * b[l + (j + 2) * b_dim1];
2045 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2046 * b[l + (j + 3) * b_dim1];
2047 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2048 * b[l + (j + 3) * b_dim1];
2049 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2050 * b[l + j * b_dim1];
2051 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2052 * b[l + j * b_dim1];
2053 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2054 * b[l + (j + 1) * b_dim1];
2055 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2056 * b[l + (j + 1) * b_dim1];
2057 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2058 * b[l + (j + 2) * b_dim1];
2059 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2060 * b[l + (j + 2) * b_dim1];
2061 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2062 * b[l + (j + 3) * b_dim1];
2063 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2064 * b[l + (j + 3) * b_dim1];
2066 c[i + j * c_dim1] = f11;
2067 c[i + 1 + j * c_dim1] = f21;
2068 c[i + (j + 1) * c_dim1] = f12;
2069 c[i + 1 + (j + 1) * c_dim1] = f22;
2070 c[i + (j + 2) * c_dim1] = f13;
2071 c[i + 1 + (j + 2) * c_dim1] = f23;
2072 c[i + (j + 3) * c_dim1] = f14;
2073 c[i + 1 + (j + 3) * c_dim1] = f24;
2074 c[i + 2 + j * c_dim1] = f31;
2075 c[i + 3 + j * c_dim1] = f41;
2076 c[i + 2 + (j + 1) * c_dim1] = f32;
2077 c[i + 3 + (j + 1) * c_dim1] = f42;
2078 c[i + 2 + (j + 2) * c_dim1] = f33;
2079 c[i + 3 + (j + 2) * c_dim1] = f43;
2080 c[i + 2 + (j + 3) * c_dim1] = f34;
2081 c[i + 3 + (j + 3) * c_dim1] = f44;
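/* Added note (not in the original source): the sixteen scalars f11..f44
   hold a 4x4 tile of C for the whole l loop above, so each element of
   the tile is loaded and stored only once per (ll) block while the
   multiply-adds accumulate in registers.  */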
2083 if (uisec < isec)
2085 i5 = ii + isec - 1;
2086 for (i = ii + uisec; i <= i5; ++i)
2088 f11 = c[i + j * c_dim1];
2089 f12 = c[i + (j + 1) * c_dim1];
2090 f13 = c[i + (j + 2) * c_dim1];
2091 f14 = c[i + (j + 3) * c_dim1];
2092 i6 = ll + lsec - 1;
2093 for (l = ll; l <= i6; ++l)
2095 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2096 257] * b[l + j * b_dim1];
2097 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2098 257] * b[l + (j + 1) * b_dim1];
2099 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2100 257] * b[l + (j + 2) * b_dim1];
2101 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2102 257] * b[l + (j + 3) * b_dim1];
2104 c[i + j * c_dim1] = f11;
2105 c[i + (j + 1) * c_dim1] = f12;
2106 c[i + (j + 2) * c_dim1] = f13;
2107 c[i + (j + 3) * c_dim1] = f14;
2111 if (ujsec < jsec)
2113 i4 = jj + jsec - 1;
2114 for (j = jj + ujsec; j <= i4; ++j)
2116 i5 = ii + uisec - 1;
2117 for (i = ii; i <= i5; i += 4)
2119 f11 = c[i + j * c_dim1];
2120 f21 = c[i + 1 + j * c_dim1];
2121 f31 = c[i + 2 + j * c_dim1];
2122 f41 = c[i + 3 + j * c_dim1];
2123 i6 = ll + lsec - 1;
2124 for (l = ll; l <= i6; ++l)
2126 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2127 257] * b[l + j * b_dim1];
2128 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
2129 257] * b[l + j * b_dim1];
2130 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
2131 257] * b[l + j * b_dim1];
2132 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
2133 257] * b[l + j * b_dim1];
2135 c[i + j * c_dim1] = f11;
2136 c[i + 1 + j * c_dim1] = f21;
2137 c[i + 2 + j * c_dim1] = f31;
2138 c[i + 3 + j * c_dim1] = f41;
2140 i5 = ii + isec - 1;
2141 for (i = ii + uisec; i <= i5; ++i)
2143 f11 = c[i + j * c_dim1];
2144 i6 = ll + lsec - 1;
2145 for (l = ll; l <= i6; ++l)
2147 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2148 257] * b[l + j * b_dim1];
2150 c[i + j * c_dim1] = f11;
2157 return;
2159 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
2161 if (GFC_DESCRIPTOR_RANK (a) != 1)
2163 const GFC_INTEGER_8 *restrict abase_x;
2164 const GFC_INTEGER_8 *restrict bbase_y;
2165 GFC_INTEGER_8 *restrict dest_y;
2166 GFC_INTEGER_8 s;
2168 for (y = 0; y < ycount; y++)
2170 bbase_y = &bbase[y*bystride];
2171 dest_y = &dest[y*rystride];
2172 for (x = 0; x < xcount; x++)
2174 abase_x = &abase[x*axstride];
2175 s = (GFC_INTEGER_8) 0;
2176 for (n = 0; n < count; n++)
2177 s += abase_x[n] * bbase_y[n];
2178 dest_y[x] = s;
2182 else
2184 const GFC_INTEGER_8 *restrict bbase_y;
2185 GFC_INTEGER_8 s;
2187 for (y = 0; y < ycount; y++)
2189 bbase_y = &bbase[y*bystride];
2190 s = (GFC_INTEGER_8) 0;
2191 for (n = 0; n < count; n++)
2192 s += abase[n*axstride] * bbase_y[n];
2193 dest[y*rystride] = s;
2197 else if (axstride < aystride)
2199 for (y = 0; y < ycount; y++)
2200 for (x = 0; x < xcount; x++)
2201 dest[x*rxstride + y*rystride] = (GFC_INTEGER_8)0;
2203 for (y = 0; y < ycount; y++)
2204 for (n = 0; n < count; n++)
2205 for (x = 0; x < xcount; x++)
2206 /* dest[x,y] += a[x,n] * b[n,y] */
2207 dest[x*rxstride + y*rystride] +=
2208 abase[x*axstride + n*aystride] *
2209 bbase[n*bxstride + y*bystride];
2211 else if (GFC_DESCRIPTOR_RANK (a) == 1)
2213 const GFC_INTEGER_8 *restrict bbase_y;
2214 GFC_INTEGER_8 s;
2216 for (y = 0; y < ycount; y++)
2218 bbase_y = &bbase[y*bystride];
2219 s = (GFC_INTEGER_8) 0;
2220 for (n = 0; n < count; n++)
2221 s += abase[n*axstride] * bbase_y[n*bxstride];
2222 dest[y*rxstride] = s;
2225 else
2227 const GFC_INTEGER_8 *restrict abase_x;
2228 const GFC_INTEGER_8 *restrict bbase_y;
2229 GFC_INTEGER_8 *restrict dest_y;
2230 GFC_INTEGER_8 s;
2232 for (y = 0; y < ycount; y++)
2234 bbase_y = &bbase[y*bystride];
2235 dest_y = &dest[y*rystride];
2236 for (x = 0; x < xcount; x++)
2238 abase_x = &abase[x*axstride];
2239 s = (GFC_INTEGER_8) 0;
2240 for (n = 0; n < count; n++)
2241 s += abase_x[n*aystride] * bbase_y[n*bxstride];
2242 dest_y[x*rxstride] = s;
2247 #undef POW3
2248 #undef min
2249 #undef max
2252 /* Compiling main function, with selection code for the processor. */
2254 /* Currently, this is i386 only. Adjust for other architectures. */
2256 #include <config/i386/cpuinfo.h>
2257 void matmul_i8 (gfc_array_i8 * const restrict retarray,
2258 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
2259 int blas_limit, blas_call gemm)
2261 static void (*matmul_p) (gfc_array_i8 * const restrict retarray,
2262 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
2263 int blas_limit, blas_call gemm) = NULL;
2265 if (matmul_p == NULL)
2267 matmul_p = matmul_i8_vanilla;
2268 if (__cpu_model.__cpu_vendor == VENDOR_INTEL)
2270 /* Run down the available processors in order of preference. */
2271 #ifdef HAVE_AVX512F
2272 if (__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX512F))
2274 matmul_p = matmul_i8_avx512f;
2275 goto tailcall;
2278 #endif /* HAVE_AVX512F */
2280 #ifdef HAVE_AVX2
2281 if (__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX2))
2283 matmul_p = matmul_i8_avx2;
2284 goto tailcall;
2287 #endif
2289 #ifdef HAVE_AVX
2290 if (__cpu_model.__cpu_features[0] & (1 << FEATURE_AVX))
2292 matmul_p = matmul_i8_avx;
2293 goto tailcall;
2295 #endif /* HAVE_AVX */
2299 tailcall:
2300 (*matmul_p) (retarray, a, b, try_blas, blas_limit, gemm);
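/* Added note (not in the original source): the CPU detection above runs
   only on the first call.  matmul_p starts out NULL, is set to the most
   capable variant the processor supports (AVX512F, then AVX2, then AVX,
   considered only for Intel vendors, otherwise the vanilla version), and
   every later call goes straight through the cached pointer.  */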
2303 #else /* Just the vanilla function. */
2305 void
2306 matmul_i8 (gfc_array_i8 * const restrict retarray,
2307 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
2308 int blas_limit, blas_call gemm)
2310 const GFC_INTEGER_8 * restrict abase;
2311 const GFC_INTEGER_8 * restrict bbase;
2312 GFC_INTEGER_8 * restrict dest;
2314 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
2315 index_type x, y, n, count, xcount, ycount;
2317 assert (GFC_DESCRIPTOR_RANK (a) == 2
2318 || GFC_DESCRIPTOR_RANK (b) == 2);
2320 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
2322 Either A or B (but not both) can be rank 1:
2324 o One-dimensional argument A is implicitly treated as a row matrix
2325 dimensioned [1,count], so xcount=1.
2327 o One-dimensional argument B is implicitly treated as a column matrix
2328 dimensioned [count, 1], so ycount=1.
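/* Added illustrative example (not in the original source): for A(3) and
   B(3,2), A is treated as a 1x3 row matrix, so the result has rank 1 and
   extent 2 (the number of columns of B); for A(3,5) and B(5) the result
   has rank 1 and extent 3.  */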
2331 if (retarray->base_addr == NULL)
2333 if (GFC_DESCRIPTOR_RANK (a) == 1)
2335 GFC_DIMENSION_SET(retarray->dim[0], 0,
2336 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
2338 else if (GFC_DESCRIPTOR_RANK (b) == 1)
2340 GFC_DIMENSION_SET(retarray->dim[0], 0,
2341 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
2343 else
2345 GFC_DIMENSION_SET(retarray->dim[0], 0,
2346 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
2348 GFC_DIMENSION_SET(retarray->dim[1], 0,
2349 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
2350 GFC_DESCRIPTOR_EXTENT(retarray,0));
2353 retarray->base_addr
2354 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_8));
2355 retarray->offset = 0;
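/* Added note (not in the original source): when the result descriptor
   arrives unallocated, the code above gives it zero-based bounds taken
   from the extents of A and B, makes the stride of the second dimension
   equal to the extent of the first (column-major storage), and then
   allocates base_addr with xmallocarray.  */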
2357 else if (unlikely (compile_options.bounds_check))
2359 index_type ret_extent, arg_extent;
2361 if (GFC_DESCRIPTOR_RANK (a) == 1)
2363 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
2364 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
2365 if (arg_extent != ret_extent)
2366 runtime_error ("Incorrect extent in return array in"
2367 " MATMUL intrinsic: is %ld, should be %ld",
2368 (long int) ret_extent, (long int) arg_extent);
2370 else if (GFC_DESCRIPTOR_RANK (b) == 1)
2372 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
2373 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
2374 if (arg_extent != ret_extent)
2375 runtime_error ("Incorrect extent in return array in"
2376 " MATMUL intrinsic: is %ld, should be %ld",
2377 (long int) ret_extent, (long int) arg_extent);
2379 else
2381 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
2382 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
2383 if (arg_extent != ret_extent)
2384 runtime_error ("Incorrect extent in return array in"
2385 " MATMUL intrinsic for dimension 1:"
2386 " is %ld, should be %ld",
2387 (long int) ret_extent, (long int) arg_extent);
2389 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
2390 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
2391 if (arg_extent != ret_extent)
2392 runtime_error ("Incorrect extent in return array in"
2393 " MATMUL intrinsic for dimension 2:"
2394 " is %ld, should be %ld",
2395 (long int) ret_extent, (long int) arg_extent);
2400 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
2402 /* One-dimensional result may be addressed in the code below
2403 either as a row or a column matrix. We want both cases to
2404 work. */
2405 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
2407 else
2409 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
2410 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
2414 if (GFC_DESCRIPTOR_RANK (a) == 1)
2416       /* Treat it as a row matrix A[1,count].  */
2417 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
2418 aystride = 1;
2420 xcount = 1;
2421 count = GFC_DESCRIPTOR_EXTENT(a,0);
2423 else
2425 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
2426 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
2428 count = GFC_DESCRIPTOR_EXTENT(a,1);
2429 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
2432 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
2434 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
2435 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
2438 if (GFC_DESCRIPTOR_RANK (b) == 1)
2440 /* Treat it as a column matrix B[count,1] */
2441 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
2443       /* bystride should never be used for 1-dimensional b.
2444          In case it is, we want it to cause a segfault rather than
2445          an incorrect result.  */
2446 bystride = 0xDEADBEEF;
2447 ycount = 1;
2449 else
2451 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
2452 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
2453 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
2456 abase = a->base_addr;
2457 bbase = b->base_addr;
2458 dest = retarray->base_addr;
2460 /* Now that everything is set up, we perform the multiplication
2461 itself. */
2463 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
2464 #define min(a,b) ((a) <= (b) ? (a) : (b))
2465 #define max(a,b) ((a) >= (b) ? (a) : (b))
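/* Added note (not in the original source): POW3 cubes blas_limit in
   single precision, so the test below compares xcount*ycount*count
   against blas_limit**3 in floating point, presumably to avoid overflow
   for very large extents; only when the product is large enough (and the
   strides permit it) is the external gemm routine called.  */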
2467 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
2468 && (bxstride == 1 || bystride == 1)
2469 && (((float) xcount) * ((float) ycount) * ((float) count)
2470 > POW3(blas_limit)))
2472 const int m = xcount, n = ycount, k = count, ldc = rystride;
2473 const GFC_INTEGER_8 one = 1, zero = 0;
2474 const int lda = (axstride == 1) ? aystride : axstride,
2475 ldb = (bxstride == 1) ? bystride : bxstride;
2477 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
2479 assert (gemm != NULL);
2480 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m,
2481 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
2482 &ldc, 1, 1);
2483 return;
2487 if (rxstride == 1 && axstride == 1 && bxstride == 1)
2489 /* This block of code implements a tuned matmul, derived from
2490 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
2492 Bo Kagstrom and Per Ling
2493 Department of Computing Science
2494 Umea University
2495 S-901 87 Umea, Sweden
2497 from netlib.org, translated to C, and modified for matmul.m4. */
2499 const GFC_INTEGER_8 *a, *b;
2500 GFC_INTEGER_8 *c;
2501 const index_type m = xcount, n = ycount, k = count;
2503 /* System generated locals */
2504 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
2505 i1, i2, i3, i4, i5, i6;
2507 /* Local variables */
2508 GFC_INTEGER_8 t1[65536], /* was [256][256] */
2509 f11, f12, f21, f22, f31, f32, f41, f42,
2510 f13, f14, f23, f24, f33, f34, f43, f44;
2511 index_type i, j, l, ii, jj, ll;
2512 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
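/* Added note (not in the original source): t1 is a 256x256 packing
   buffer (65536 * 8 bytes = 512 KiB on the stack) holding the current
   block of A; f11..f44 are the 4x4 tile of accumulators; isec, jsec and
   lsec are the current block sizes, and uisec, ujsec and ulsec are those
   sizes rounded down to the unroll factor.  */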
2514 a = abase;
2515 b = bbase;
2516 c = retarray->base_addr;
2518 /* Parameter adjustments */
2519 c_dim1 = rystride;
2520 c_offset = 1 + c_dim1;
2521 c -= c_offset;
2522 a_dim1 = aystride;
2523 a_offset = 1 + a_dim1;
2524 a -= a_offset;
2525 b_dim1 = bystride;
2526 b_offset = 1 + b_dim1;
2527 b -= b_offset;
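/* Added note (not in the original source): the adjustments above are the
   usual f2c-style idiom for 1-based, column-major indexing.  With
   c_offset = 1 + c_dim1 and c -= c_offset, the first element of the
   result is reached as

       c[1 + 1 * c_dim1]  ==  retarray->base_addr[0]

   so the loops below can use Fortran-style subscripts starting at 1.  */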
2529 /* Early exit if possible */
2530 if (m == 0 || n == 0 || k == 0)
2531 return;
2533 /* Empty c first. */
2534 for (j=1; j<=n; j++)
2535 for (i=1; i<=m; i++)
2536 c[i + j * c_dim1] = (GFC_INTEGER_8)0;
2538 /* Start turning the crank. */
2539 i1 = n;
2540 for (jj = 1; jj <= i1; jj += 512)
2542 /* Computing MIN */
2543 i2 = 512;
2544 i3 = n - jj + 1;
2545 jsec = min(i2,i3);
2546 ujsec = jsec - jsec % 4;
2547 i2 = k;
2548 for (ll = 1; ll <= i2; ll += 256)
2550 /* Computing MIN */
2551 i3 = 256;
2552 i4 = k - ll + 1;
2553 lsec = min(i3,i4);
2554 ulsec = lsec - lsec % 2;
2556 i3 = m;
2557 for (ii = 1; ii <= i3; ii += 256)
2559 /* Computing MIN */
2560 i4 = 256;
2561 i5 = m - ii + 1;
2562 isec = min(i4,i5);
2563 uisec = isec - isec % 2;
2564 i4 = ll + ulsec - 1;
2565 for (l = ll; l <= i4; l += 2)
2567 i5 = ii + uisec - 1;
2568 for (i = ii; i <= i5; i += 2)
2570 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
2571 a[i + l * a_dim1];
2572 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
2573 a[i + (l + 1) * a_dim1];
2574 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
2575 a[i + 1 + l * a_dim1];
2576 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
2577 a[i + 1 + (l + 1) * a_dim1];
2579 if (uisec < isec)
2581 t1[l - ll + 1 + (isec << 8) - 257] =
2582 a[ii + isec - 1 + l * a_dim1];
2583 t1[l - ll + 2 + (isec << 8) - 257] =
2584 a[ii + isec - 1 + (l + 1) * a_dim1];
2587 if (ulsec < lsec)
2589 i4 = ii + isec - 1;
2590                       for (i = ii; i <= i4; ++i)
2592 t1[lsec + ((i - ii + 1) << 8) - 257] =
2593 a[i + (ll + lsec - 1) * a_dim1];
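/* Added note (not in the original source): the two conditionals above
   copy the leftover row of A (when isec is odd) and the leftover column
   (when lsec is odd) into t1, since the main packing loop moves data in
   2x2 groups.  */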
2597 uisec = isec - isec % 4;
2598 i4 = jj + ujsec - 1;
2599 for (j = jj; j <= i4; j += 4)
2601 i5 = ii + uisec - 1;
2602 for (i = ii; i <= i5; i += 4)
2604 f11 = c[i + j * c_dim1];
2605 f21 = c[i + 1 + j * c_dim1];
2606 f12 = c[i + (j + 1) * c_dim1];
2607 f22 = c[i + 1 + (j + 1) * c_dim1];
2608 f13 = c[i + (j + 2) * c_dim1];
2609 f23 = c[i + 1 + (j + 2) * c_dim1];
2610 f14 = c[i + (j + 3) * c_dim1];
2611 f24 = c[i + 1 + (j + 3) * c_dim1];
2612 f31 = c[i + 2 + j * c_dim1];
2613 f41 = c[i + 3 + j * c_dim1];
2614 f32 = c[i + 2 + (j + 1) * c_dim1];
2615 f42 = c[i + 3 + (j + 1) * c_dim1];
2616 f33 = c[i + 2 + (j + 2) * c_dim1];
2617 f43 = c[i + 3 + (j + 2) * c_dim1];
2618 f34 = c[i + 2 + (j + 3) * c_dim1];
2619 f44 = c[i + 3 + (j + 3) * c_dim1];
2620 i6 = ll + lsec - 1;
2621 for (l = ll; l <= i6; ++l)
2623 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2624 * b[l + j * b_dim1];
2625 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2626 * b[l + j * b_dim1];
2627 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2628 * b[l + (j + 1) * b_dim1];
2629 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2630 * b[l + (j + 1) * b_dim1];
2631 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2632 * b[l + (j + 2) * b_dim1];
2633 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2634 * b[l + (j + 2) * b_dim1];
2635 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
2636 * b[l + (j + 3) * b_dim1];
2637 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
2638 * b[l + (j + 3) * b_dim1];
2639 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2640 * b[l + j * b_dim1];
2641 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2642 * b[l + j * b_dim1];
2643 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2644 * b[l + (j + 1) * b_dim1];
2645 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2646 * b[l + (j + 1) * b_dim1];
2647 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2648 * b[l + (j + 2) * b_dim1];
2649 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2650 * b[l + (j + 2) * b_dim1];
2651 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
2652 * b[l + (j + 3) * b_dim1];
2653 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
2654 * b[l + (j + 3) * b_dim1];
2656 c[i + j * c_dim1] = f11;
2657 c[i + 1 + j * c_dim1] = f21;
2658 c[i + (j + 1) * c_dim1] = f12;
2659 c[i + 1 + (j + 1) * c_dim1] = f22;
2660 c[i + (j + 2) * c_dim1] = f13;
2661 c[i + 1 + (j + 2) * c_dim1] = f23;
2662 c[i + (j + 3) * c_dim1] = f14;
2663 c[i + 1 + (j + 3) * c_dim1] = f24;
2664 c[i + 2 + j * c_dim1] = f31;
2665 c[i + 3 + j * c_dim1] = f41;
2666 c[i + 2 + (j + 1) * c_dim1] = f32;
2667 c[i + 3 + (j + 1) * c_dim1] = f42;
2668 c[i + 2 + (j + 2) * c_dim1] = f33;
2669 c[i + 3 + (j + 2) * c_dim1] = f43;
2670 c[i + 2 + (j + 3) * c_dim1] = f34;
2671 c[i + 3 + (j + 3) * c_dim1] = f44;
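/* Added note (not in the original source): each pass of the l loop above
   performs 16 multiply-adds while loading only four elements of t1 and
   four of b, so the packed block of A and the 4x4 tile of C are reused
   heavily relative to the memory traffic.  */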
2673 if (uisec < isec)
2675 i5 = ii + isec - 1;
2676 for (i = ii + uisec; i <= i5; ++i)
2678 f11 = c[i + j * c_dim1];
2679 f12 = c[i + (j + 1) * c_dim1];
2680 f13 = c[i + (j + 2) * c_dim1];
2681 f14 = c[i + (j + 3) * c_dim1];
2682 i6 = ll + lsec - 1;
2683 for (l = ll; l <= i6; ++l)
2685 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2686 257] * b[l + j * b_dim1];
2687 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2688 257] * b[l + (j + 1) * b_dim1];
2689 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2690 257] * b[l + (j + 2) * b_dim1];
2691 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2692 257] * b[l + (j + 3) * b_dim1];
2694 c[i + j * c_dim1] = f11;
2695 c[i + (j + 1) * c_dim1] = f12;
2696 c[i + (j + 2) * c_dim1] = f13;
2697 c[i + (j + 3) * c_dim1] = f14;
2701 if (ujsec < jsec)
2703 i4 = jj + jsec - 1;
2704 for (j = jj + ujsec; j <= i4; ++j)
2706 i5 = ii + uisec - 1;
2707 for (i = ii; i <= i5; i += 4)
2709 f11 = c[i + j * c_dim1];
2710 f21 = c[i + 1 + j * c_dim1];
2711 f31 = c[i + 2 + j * c_dim1];
2712 f41 = c[i + 3 + j * c_dim1];
2713 i6 = ll + lsec - 1;
2714 for (l = ll; l <= i6; ++l)
2716 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2717 257] * b[l + j * b_dim1];
2718 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
2719 257] * b[l + j * b_dim1];
2720 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
2721 257] * b[l + j * b_dim1];
2722 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
2723 257] * b[l + j * b_dim1];
2725 c[i + j * c_dim1] = f11;
2726 c[i + 1 + j * c_dim1] = f21;
2727 c[i + 2 + j * c_dim1] = f31;
2728 c[i + 3 + j * c_dim1] = f41;
2730 i5 = ii + isec - 1;
2731 for (i = ii + uisec; i <= i5; ++i)
2733 f11 = c[i + j * c_dim1];
2734 i6 = ll + lsec - 1;
2735 for (l = ll; l <= i6; ++l)
2737 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
2738 257] * b[l + j * b_dim1];
2740 c[i + j * c_dim1] = f11;
2747 return;
2749 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
2751 if (GFC_DESCRIPTOR_RANK (a) != 1)
2753 const GFC_INTEGER_8 *restrict abase_x;
2754 const GFC_INTEGER_8 *restrict bbase_y;
2755 GFC_INTEGER_8 *restrict dest_y;
2756 GFC_INTEGER_8 s;
2758 for (y = 0; y < ycount; y++)
2760 bbase_y = &bbase[y*bystride];
2761 dest_y = &dest[y*rystride];
2762 for (x = 0; x < xcount; x++)
2764 abase_x = &abase[x*axstride];
2765 s = (GFC_INTEGER_8) 0;
2766 for (n = 0; n < count; n++)
2767 s += abase_x[n] * bbase_y[n];
2768 dest_y[x] = s;
2772 else
2774 const GFC_INTEGER_8 *restrict bbase_y;
2775 GFC_INTEGER_8 s;
2777 for (y = 0; y < ycount; y++)
2779 bbase_y = &bbase[y*bystride];
2780 s = (GFC_INTEGER_8) 0;
2781 for (n = 0; n < count; n++)
2782 s += abase[n*axstride] * bbase_y[n];
2783 dest[y*rystride] = s;
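/* Added note (not in the original source): in the branch above every
   result element is formed as a single dot product; because aystride and
   bxstride are both 1, the n loop walks both operands with unit stride,
   so each accumulation runs over contiguous memory.  */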
2787 else if (axstride < aystride)
2789 for (y = 0; y < ycount; y++)
2790 for (x = 0; x < xcount; x++)
2791 dest[x*rxstride + y*rystride] = (GFC_INTEGER_8)0;
2793 for (y = 0; y < ycount; y++)
2794 for (n = 0; n < count; n++)
2795 for (x = 0; x < xcount; x++)
2796 /* dest[x,y] += a[x,n] * b[n,y] */
2797 dest[x*rxstride + y*rystride] +=
2798 abase[x*axstride + n*aystride] *
2799 bbase[n*bxstride + y*bystride];
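/* Added note (not in the original source): when axstride < aystride the
   x loop is kept innermost so that dest and abase are traversed with
   their smallest strides; the result has to be zeroed beforehand because
   the n loop sits outside the accumulation.  */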
2801 else if (GFC_DESCRIPTOR_RANK (a) == 1)
2803 const GFC_INTEGER_8 *restrict bbase_y;
2804 GFC_INTEGER_8 s;
2806 for (y = 0; y < ycount; y++)
2808 bbase_y = &bbase[y*bystride];
2809 s = (GFC_INTEGER_8) 0;
2810 for (n = 0; n < count; n++)
2811 s += abase[n*axstride] * bbase_y[n*bxstride];
2812 dest[y*rxstride] = s;
2815 else
2817 const GFC_INTEGER_8 *restrict abase_x;
2818 const GFC_INTEGER_8 *restrict bbase_y;
2819 GFC_INTEGER_8 *restrict dest_y;
2820 GFC_INTEGER_8 s;
2822 for (y = 0; y < ycount; y++)
2824 bbase_y = &bbase[y*bystride];
2825 dest_y = &dest[y*rystride];
2826 for (x = 0; x < xcount; x++)
2828 abase_x = &abase[x*axstride];
2829 s = (GFC_INTEGER_8) 0;
2830 for (n = 0; n < count; n++)
2831 s += abase_x[n*aystride] * bbase_y[n*bxstride];
2832 dest_y[x*rxstride] = s;
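/* Added note (not in the original source): this final branch is the
   fully general fallback; it makes no assumptions about the strides
   (strided array sections, for example, can end up here) and simply
   forms each element of the result as an explicit dot product.  */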
2837 #undef POW3
2838 #undef min
2839 #undef max
2841 #endif
2842 #endif