2018-04-06 Thomas Koenig <tkoenig@gcc.gnu.org>
[official-gcc.git] / libgfortran / generated / matmulavx128_i8.c
blob731e55d2f630463cc4eceea7f679a2647a7cbf0a
/* Implementation of the MATMUL intrinsic
   Copyright (C) 2002-2018 Free Software Foundation, Inc.
   Contributed by Thomas Koenig <tkoenig@gcc.gnu.org>.

This file is part of the GNU Fortran runtime library (libgfortran).

Libgfortran is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public
License as published by the Free Software Foundation; either
version 3 of the License, or (at your option) any later version.

Libgfortran is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

Under Section 7 of GPL version 3, you are granted additional
permissions described in the GCC Runtime Library Exception, version
3.1, as published by the Free Software Foundation.

You should have received a copy of the GNU General Public License and
a copy of the GCC Runtime Library Exception along with this program;
see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
<http://www.gnu.org/licenses/>. */
26 #include "libgfortran.h"
27 #include <string.h>
28 #include <assert.h>
31 /* These are the specific versions of matmul with -mprefer-avx128. */
33 #if defined (HAVE_GFC_INTEGER_8)
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36 passed to us by the front-end, in which case we call it for large
37 matrices. */
39 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
40 const int *, const GFC_INTEGER_8 *, const GFC_INTEGER_8 *,
41 const int *, const GFC_INTEGER_8 *, const int *,
42 const GFC_INTEGER_8 *, GFC_INTEGER_8 *, const int *,
43 int, int);
#if defined(HAVE_AVX) && defined(HAVE_FMA3) && defined(HAVE_AVX128)

/* MATMUL for GFC_INTEGER_8 operands, compiled with
   __target__ ("avx,fma").

   Computes C = A * B with C[xcount,ycount], A[xcount,count] and
   B[count,ycount]; either A or B (but not both) may be rank 1.

   RETARRAY    Result descriptor.  If its base_addr is NULL, its bounds are
               set and its data is allocated here; otherwise, when bounds
               checking is enabled, its extents are checked against A and B.
   A, B        Operand descriptors.
   TRY_BLAS    Nonzero to allow delegating the product to GEMM when the
               strides permit it and xcount * ycount * count exceeds
               BLAS_LIMIT cubed.
   BLAS_LIMIT  Size threshold for the GEMM path (see TRY_BLAS).
   GEMM        BLAS ?gemm routine; asserted non-NULL before it is called.

   A mismatch between the columns of A and the rows of B, and (with bounds
   checking) a wrongly sized RETARRAY, are reported via runtime_error.  */
void
matmul_i8_avx128_fma3 (gfc_array_i8 * const restrict retarray,
                       gfc_array_i8 * const restrict a,
                       gfc_array_i8 * const restrict b, int try_blas,
                       int blas_limit, blas_call gemm)
  __attribute__((__target__("avx,fma")));
internal_proto(matmul_i8_avx128_fma3);

void
matmul_i8_avx128_fma3 (gfc_array_i8 * const restrict retarray,
                       gfc_array_i8 * const restrict a,
                       gfc_array_i8 * const restrict b, int try_blas,
                       int blas_limit, blas_call gemm)
{
  const GFC_INTEGER_8 * restrict abase;
  const GFC_INTEGER_8 * restrict bbase;
  GFC_INTEGER_8 * restrict dest;

  index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
  index_type x, y, n, count, xcount, ycount;

  assert (GFC_DESCRIPTOR_RANK (a) == 2
          || GFC_DESCRIPTOR_RANK (b) == 2);

  /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]

     Either A or B (but not both) can be rank 1:

     o One-dimensional argument A is implicitly treated as a row matrix
       dimensioned [1,count], so xcount=1.

     o One-dimensional argument B is implicitly treated as a column matrix
       dimensioned [count, 1], so ycount=1.  */

  if (retarray->base_addr == NULL)
    {
      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
        }
      else
        {
          GFC_DIMENSION_SET(retarray->dim[0], 0,
                            GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);

          GFC_DIMENSION_SET(retarray->dim[1], 0,
                            GFC_DESCRIPTOR_EXTENT(b,1) - 1,
                            GFC_DESCRIPTOR_EXTENT(retarray,0));
        }

      retarray->base_addr
        = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_8));
      retarray->offset = 0;
    }
  else if (unlikely (compile_options.bounds_check))
    {
      index_type ret_extent, arg_extent;

      if (GFC_DESCRIPTOR_RANK (a) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Incorrect extent in return array in"
                           " MATMUL intrinsic: is %ld, should be %ld",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else if (GFC_DESCRIPTOR_RANK (b) == 1)
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Incorrect extent in return array in"
                           " MATMUL intrinsic: is %ld, should be %ld",
                           (long int) ret_extent, (long int) arg_extent);
        }
      else
        {
          arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
          if (arg_extent != ret_extent)
            runtime_error ("Incorrect extent in return array in"
                           " MATMUL intrinsic for dimension 1:"
                           " is %ld, should be %ld",
                           (long int) ret_extent, (long int) arg_extent);

          arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
          ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
          if (arg_extent != ret_extent)
            runtime_error ("Incorrect extent in return array in"
                           " MATMUL intrinsic for dimension 2:"
                           " is %ld, should be %ld",
                           (long int) ret_extent, (long int) arg_extent);
        }
    }

  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
    {
      /* One-dimensional result may be addressed in the code below
         either as a row or a column matrix.  We want both cases to
         work.  */
      rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
    }
  else
    {
      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
    }

  if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      /* Treat it as a row matrix A[1,count].  */
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = 1;

      xcount = 1;
      count = GFC_DESCRIPTOR_EXTENT(a,0);
    }
  else
    {
      axstride = GFC_DESCRIPTOR_STRIDE(a,0);
      aystride = GFC_DESCRIPTOR_STRIDE(a,1);

      count = GFC_DESCRIPTOR_EXTENT(a,1);
      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
    }

  if (count != GFC_DESCRIPTOR_EXTENT(b,0))
    {
      if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
        runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
    }

  if (GFC_DESCRIPTOR_RANK (b) == 1)
    {
      /* Treat it as a column matrix B[count,1].  */
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);

      /* B has a single column (ycount == 1), so bystride never selects a
         different column below and its value does not affect the
         result.  */
      bystride = 256;
      ycount = 1;
    }
  else
    {
      bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
      bystride = GFC_DESCRIPTOR_STRIDE(b,1);
      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
    }

  abase = a->base_addr;
  bbase = b->base_addr;
  dest = retarray->base_addr;

  /* Now that everything is set up, we perform the multiplication
     itself.  */

#define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
#define min(a,b) ((a) <= (b) ? (a) : (b))

  if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
      && (bxstride == 1 || bystride == 1)
      && (((float) xcount) * ((float) ycount) * ((float) count)
          > POW3(blas_limit)))
    {
      const int m = xcount, n = ycount, k = count, ldc = rystride;
      const GFC_INTEGER_8 one = 1, zero = 0;
      const int lda = (axstride == 1) ? aystride : axstride,
                ldb = (bxstride == 1) ? bystride : bxstride;

      if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
        {
          assert (gemm != NULL);
          gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m,
                &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
                &ldc, 1, 1);
          return;
        }
    }

  if (rxstride == 1 && axstride == 1 && bxstride == 1)
    {
      /* This block of code implements a tuned matmul, derived from
         Superscalar GEMM-based level 3 BLAS, Beta version 0.1

                Bo Kagstrom and Per Ling
                Department of Computing Science
                Umea University
                S-901 87 Umea, Sweden

         from netlib.org, translated to C, and modified for matmul.m4.  */

      const GFC_INTEGER_8 *a, *b;
      GFC_INTEGER_8 *c;
      const index_type m = xcount, n = ycount, k = count;

      /* System generated locals */
      index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
                 i1, i2, i3, i4, i5, i6;

      /* Local variables */
      GFC_INTEGER_8 f11, f12, f21, f22, f31, f32, f41, f42,
                    f13, f14, f23, f24, f33, f34, f43, f44;
      index_type i, j, l, ii, jj, ll;
      index_type isec, jsec, lsec, uisec, ujsec, ulsec;
      GFC_INTEGER_8 *t1;
      index_type t1_dim;

      a = abase;
      b = bbase;
      c = retarray->base_addr;

      /* Parameter adjustments: shift the base pointers so that the
         1-based expression x[i + j * x_dim1] below addresses element
         (i, j).  */
      c_dim1 = rystride;
      c_offset = 1 + c_dim1;
      c -= c_offset;
      a_dim1 = aystride;
      a_offset = 1 + a_dim1;
      a -= a_offset;
      b_dim1 = bystride;
      b_offset = 1 + b_dim1;
      b -= b_offset;

      /* Empty c first.  */
      for (j = 1; j <= n; j++)
        for (i = 1; i <= m; i++)
          c[i + j * c_dim1] = (GFC_INTEGER_8)0;

      /* Early exit if possible */
      if (m == 0 || n == 0 || k == 0)
        return;

      /* t1 holds one tile of A, A(ii:ii+isec-1, ll:ll+lsec-1), stored as
         t1[(l - ll) + (i - ii) * 256], with isec <= min (m, 256) and
         lsec <= min (k, 256).  The highest index touched for a tile is
         lsec + (isec << 8) - 257, so (isec - 1) * 256 + lsec elements
         suffice.  Computing the size from the extents rather than from
         the strides of A and B keeps it correct when those strides are
         negative (reversed sections also take this path).  */
      t1_dim = (min (m, 256) - 1) * 256 + min (k, 256);

      /* Allocate with xmallocarray, as is done for RETARRAY above, rather
         than with an unchecked malloc.  NOTE(review): relies on
         xmallocarray not returning NULL on failure.  */
      t1 = xmallocarray (t1_dim, sizeof (GFC_INTEGER_8));

      /* Start turning the crank. */
      i1 = n;
      for (jj = 1; jj <= i1; jj += 512)
        {
          /* Computing MIN */
          i2 = 512;
          i3 = n - jj + 1;
          jsec = min(i2,i3);
          ujsec = jsec - jsec % 4;
          i2 = k;
          for (ll = 1; ll <= i2; ll += 256)
            {
              /* Computing MIN */
              i3 = 256;
              i4 = k - ll + 1;
              lsec = min(i3,i4);
              ulsec = lsec - lsec % 2;

              i3 = m;
              for (ii = 1; ii <= i3; ii += 256)
                {
                  /* Computing MIN */
                  i4 = 256;
                  i5 = m - ii + 1;
                  isec = min(i4,i5);
                  uisec = isec - isec % 2;

                  /* Pack the current tile of A into t1, two rows and two
                     columns at a time, then the odd row / column.  */
                  i4 = ll + ulsec - 1;
                  for (l = ll; l <= i4; l += 2)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 2)
                        {
                          t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
                                  a[i + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
                                  a[i + (l + 1) * a_dim1];
                          t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
                                  a[i + 1 + l * a_dim1];
                          t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
                                  a[i + 1 + (l + 1) * a_dim1];
                        }
                      if (uisec < isec)
                        {
                          t1[l - ll + 1 + (isec << 8) - 257] =
                                  a[ii + isec - 1 + l * a_dim1];
                          t1[l - ll + 2 + (isec << 8) - 257] =
                                  a[ii + isec - 1 + (l + 1) * a_dim1];
                        }
                    }
                  if (ulsec < lsec)
                    {
                      i4 = ii + isec - 1;
                      for (i = ii; i <= i4; ++i)
                        {
                          t1[lsec + ((i - ii + 1) << 8) - 257] =
                                  a[i + (ll + lsec - 1) * a_dim1];
                        }
                    }

                  /* 4x4 blocks of C, then the leftover rows.  */
                  uisec = isec - isec % 4;
                  i4 = jj + ujsec - 1;
                  for (j = jj; j <= i4; j += 4)
                    {
                      i5 = ii + uisec - 1;
                      for (i = ii; i <= i5; i += 4)
                        {
                          f11 = c[i + j * c_dim1];
                          f21 = c[i + 1 + j * c_dim1];
                          f12 = c[i + (j + 1) * c_dim1];
                          f22 = c[i + 1 + (j + 1) * c_dim1];
                          f13 = c[i + (j + 2) * c_dim1];
                          f23 = c[i + 1 + (j + 2) * c_dim1];
                          f14 = c[i + (j + 3) * c_dim1];
                          f24 = c[i + 1 + (j + 3) * c_dim1];
                          f31 = c[i + 2 + j * c_dim1];
                          f41 = c[i + 3 + j * c_dim1];
                          f32 = c[i + 2 + (j + 1) * c_dim1];
                          f42 = c[i + 3 + (j + 1) * c_dim1];
                          f33 = c[i + 2 + (j + 2) * c_dim1];
                          f43 = c[i + 3 + (j + 2) * c_dim1];
                          f34 = c[i + 2 + (j + 3) * c_dim1];
                          f44 = c[i + 3 + (j + 3) * c_dim1];
                          i6 = ll + lsec - 1;
                          for (l = ll; l <= i6; ++l)
                            {
                              f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                  * b[l + j * b_dim1];
                              f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                  * b[l + j * b_dim1];
                              f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                  * b[l + (j + 1) * b_dim1];
                              f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                  * b[l + (j + 1) * b_dim1];
                              f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                  * b[l + (j + 2) * b_dim1];
                              f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                  * b[l + (j + 2) * b_dim1];
                              f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
                                  * b[l + (j + 3) * b_dim1];
                              f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
                                  * b[l + (j + 3) * b_dim1];
                              f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                  * b[l + j * b_dim1];
                              f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                  * b[l + j * b_dim1];
                              f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                  * b[l + (j + 1) * b_dim1];
                              f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                  * b[l + (j + 1) * b_dim1];
                              f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                  * b[l + (j + 2) * b_dim1];
                              f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                  * b[l + (j + 2) * b_dim1];
                              f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
                                  * b[l + (j + 3) * b_dim1];
                              f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
                                  * b[l + (j + 3) * b_dim1];
                            }
                          c[i + j * c_dim1] = f11;
                          c[i + 1 + j * c_dim1] = f21;
                          c[i + (j + 1) * c_dim1] = f12;
                          c[i + 1 + (j + 1) * c_dim1] = f22;
                          c[i + (j + 2) * c_dim1] = f13;
                          c[i + 1 + (j + 2) * c_dim1] = f23;
                          c[i + (j + 3) * c_dim1] = f14;
                          c[i + 1 + (j + 3) * c_dim1] = f24;
                          c[i + 2 + j * c_dim1] = f31;
                          c[i + 3 + j * c_dim1] = f41;
                          c[i + 2 + (j + 1) * c_dim1] = f32;
                          c[i + 3 + (j + 1) * c_dim1] = f42;
                          c[i + 2 + (j + 2) * c_dim1] = f33;
                          c[i + 3 + (j + 2) * c_dim1] = f43;
                          c[i + 2 + (j + 3) * c_dim1] = f34;
                          c[i + 3 + (j + 3) * c_dim1] = f44;
                        }
                      if (uisec < isec)
                        {
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              f12 = c[i + (j + 1) * c_dim1];
                              f13 = c[i + (j + 2) * c_dim1];
                              f14 = c[i + (j + 3) * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + (j + 1) * b_dim1];
                                  f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + (j + 2) * b_dim1];
                                  f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + (j + 3) * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + (j + 1) * c_dim1] = f12;
                              c[i + (j + 2) * c_dim1] = f13;
                              c[i + (j + 3) * c_dim1] = f14;
                            }
                        }
                    }

                  /* Leftover columns of C (jsec not a multiple of 4).  */
                  if (ujsec < jsec)
                    {
                      i4 = jj + jsec - 1;
                      for (j = jj + ujsec; j <= i4; ++j)
                        {
                          i5 = ii + uisec - 1;
                          for (i = ii; i <= i5; i += 4)
                            {
                              f11 = c[i + j * c_dim1];
                              f21 = c[i + 1 + j * c_dim1];
                              f31 = c[i + 2 + j * c_dim1];
                              f41 = c[i + 3 + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
                                            257] * b[l + j * b_dim1];
                                  f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
                                            257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                              c[i + 1 + j * c_dim1] = f21;
                              c[i + 2 + j * c_dim1] = f31;
                              c[i + 3 + j * c_dim1] = f41;
                            }
                          i5 = ii + isec - 1;
                          for (i = ii + uisec; i <= i5; ++i)
                            {
                              f11 = c[i + j * c_dim1];
                              i6 = ll + lsec - 1;
                              for (l = ll; l <= i6; ++l)
                                {
                                  f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
                                            257] * b[l + j * b_dim1];
                                }
                              c[i + j * c_dim1] = f11;
                            }
                        }
                    }
                }
            }
        }
      free (t1);
      return;
    }
  else if (rxstride == 1 && aystride == 1 && bxstride == 1)
    {
      if (GFC_DESCRIPTOR_RANK (a) != 1)
        {
          const GFC_INTEGER_8 *restrict abase_x;
          const GFC_INTEGER_8 *restrict bbase_y;
          GFC_INTEGER_8 *restrict dest_y;
          GFC_INTEGER_8 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              dest_y = &dest[y*rystride];
              for (x = 0; x < xcount; x++)
                {
                  abase_x = &abase[x*axstride];
                  s = (GFC_INTEGER_8) 0;
                  for (n = 0; n < count; n++)
                    s += abase_x[n] * bbase_y[n];
                  dest_y[x] = s;
                }
            }
        }
      else
        {
          const GFC_INTEGER_8 *restrict bbase_y;
          GFC_INTEGER_8 s;

          for (y = 0; y < ycount; y++)
            {
              bbase_y = &bbase[y*bystride];
              s = (GFC_INTEGER_8) 0;
              for (n = 0; n < count; n++)
                s += abase[n*axstride] * bbase_y[n];
              dest[y*rystride] = s;
            }
        }
    }
  else if (axstride < aystride)
    {
      for (y = 0; y < ycount; y++)
        for (x = 0; x < xcount; x++)
          dest[x*rxstride + y*rystride] = (GFC_INTEGER_8)0;

      for (y = 0; y < ycount; y++)
        for (n = 0; n < count; n++)
          for (x = 0; x < xcount; x++)
            /* dest[x,y] += a[x,n] * b[n,y] */
            dest[x*rxstride + y*rystride] +=
                abase[x*axstride + n*aystride] *
                bbase[n*bxstride + y*bystride];
    }
  else if (GFC_DESCRIPTOR_RANK (a) == 1)
    {
      const GFC_INTEGER_8 *restrict bbase_y;
      GFC_INTEGER_8 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          s = (GFC_INTEGER_8) 0;
          for (n = 0; n < count; n++)
            s += abase[n*axstride] * bbase_y[n*bxstride];
          dest[y*rxstride] = s;
        }
    }
  else
    {
      const GFC_INTEGER_8 *restrict abase_x;
      const GFC_INTEGER_8 *restrict bbase_y;
      GFC_INTEGER_8 *restrict dest_y;
      GFC_INTEGER_8 s;

      for (y = 0; y < ycount; y++)
        {
          bbase_y = &bbase[y*bystride];
          dest_y = &dest[y*rystride];
          for (x = 0; x < xcount; x++)
            {
              abase_x = &abase[x*axstride];
              s = (GFC_INTEGER_8) 0;
              for (n = 0; n < count; n++)
                s += abase_x[n*aystride] * bbase_y[n*bxstride];
              dest_y[x*rxstride] = s;
            }
        }
    }
}

#undef POW3
#undef min

#endif
598 #if defined(HAVE_AVX) && defined(HAVE_FMA4) && defined(HAVE_AVX128)
599 void
600 matmul_i8_avx128_fma4 (gfc_array_i8 * const restrict retarray,
601 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
602 int blas_limit, blas_call gemm) __attribute__((__target__("avx,fma4")));
603 internal_proto(matmul_i8_avx128_fma4);
604 void
605 matmul_i8_avx128_fma4 (gfc_array_i8 * const restrict retarray,
606 gfc_array_i8 * const restrict a, gfc_array_i8 * const restrict b, int try_blas,
607 int blas_limit, blas_call gemm)
609 const GFC_INTEGER_8 * restrict abase;
610 const GFC_INTEGER_8 * restrict bbase;
611 GFC_INTEGER_8 * restrict dest;
613 index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
614 index_type x, y, n, count, xcount, ycount;
616 assert (GFC_DESCRIPTOR_RANK (a) == 2
617 || GFC_DESCRIPTOR_RANK (b) == 2);
619 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
621 Either A or B (but not both) can be rank 1:
623 o One-dimensional argument A is implicitly treated as a row matrix
624 dimensioned [1,count], so xcount=1.
626 o One-dimensional argument B is implicitly treated as a column matrix
627 dimensioned [count, 1], so ycount=1.
630 if (retarray->base_addr == NULL)
632 if (GFC_DESCRIPTOR_RANK (a) == 1)
634 GFC_DIMENSION_SET(retarray->dim[0], 0,
635 GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
637 else if (GFC_DESCRIPTOR_RANK (b) == 1)
639 GFC_DIMENSION_SET(retarray->dim[0], 0,
640 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
642 else
644 GFC_DIMENSION_SET(retarray->dim[0], 0,
645 GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
647 GFC_DIMENSION_SET(retarray->dim[1], 0,
648 GFC_DESCRIPTOR_EXTENT(b,1) - 1,
649 GFC_DESCRIPTOR_EXTENT(retarray,0));
652 retarray->base_addr
653 = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_INTEGER_8));
654 retarray->offset = 0;
656 else if (unlikely (compile_options.bounds_check))
658 index_type ret_extent, arg_extent;
660 if (GFC_DESCRIPTOR_RANK (a) == 1)
662 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
663 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
664 if (arg_extent != ret_extent)
665 runtime_error ("Incorrect extent in return array in"
666 " MATMUL intrinsic: is %ld, should be %ld",
667 (long int) ret_extent, (long int) arg_extent);
669 else if (GFC_DESCRIPTOR_RANK (b) == 1)
671 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
672 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
673 if (arg_extent != ret_extent)
674 runtime_error ("Incorrect extent in return array in"
675 " MATMUL intrinsic: is %ld, should be %ld",
676 (long int) ret_extent, (long int) arg_extent);
678 else
680 arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
681 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
682 if (arg_extent != ret_extent)
683 runtime_error ("Incorrect extent in return array in"
684 " MATMUL intrinsic for dimension 1:"
685 " is %ld, should be %ld",
686 (long int) ret_extent, (long int) arg_extent);
688 arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
689 ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
690 if (arg_extent != ret_extent)
691 runtime_error ("Incorrect extent in return array in"
692 " MATMUL intrinsic for dimension 2:"
693 " is %ld, should be %ld",
694 (long int) ret_extent, (long int) arg_extent);
699 if (GFC_DESCRIPTOR_RANK (retarray) == 1)
701 /* One-dimensional result may be addressed in the code below
702 either as a row or a column matrix. We want both cases to
703 work. */
704 rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
706 else
708 rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
709 rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
713 if (GFC_DESCRIPTOR_RANK (a) == 1)
715 /* Treat it as a a row matrix A[1,count]. */
716 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
717 aystride = 1;
719 xcount = 1;
720 count = GFC_DESCRIPTOR_EXTENT(a,0);
722 else
724 axstride = GFC_DESCRIPTOR_STRIDE(a,0);
725 aystride = GFC_DESCRIPTOR_STRIDE(a,1);
727 count = GFC_DESCRIPTOR_EXTENT(a,1);
728 xcount = GFC_DESCRIPTOR_EXTENT(a,0);
731 if (count != GFC_DESCRIPTOR_EXTENT(b,0))
733 if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
734 runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
737 if (GFC_DESCRIPTOR_RANK (b) == 1)
739 /* Treat it as a column matrix B[count,1] */
740 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
742 /* bystride should never be used for 1-dimensional b.
743 The value is only used for calculation of the
744 memory by the buffer. */
745 bystride = 256;
746 ycount = 1;
748 else
750 bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
751 bystride = GFC_DESCRIPTOR_STRIDE(b,1);
752 ycount = GFC_DESCRIPTOR_EXTENT(b,1);
755 abase = a->base_addr;
756 bbase = b->base_addr;
757 dest = retarray->base_addr;
759 /* Now that everything is set up, we perform the multiplication
760 itself. */
762 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
763 #define min(a,b) ((a) <= (b) ? (a) : (b))
764 #define max(a,b) ((a) >= (b) ? (a) : (b))
766 if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
767 && (bxstride == 1 || bystride == 1)
768 && (((float) xcount) * ((float) ycount) * ((float) count)
769 > POW3(blas_limit)))
771 const int m = xcount, n = ycount, k = count, ldc = rystride;
772 const GFC_INTEGER_8 one = 1, zero = 0;
773 const int lda = (axstride == 1) ? aystride : axstride,
774 ldb = (bxstride == 1) ? bystride : bxstride;
776 if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
778 assert (gemm != NULL);
779 gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m,
780 &n, &k, &one, abase, &lda, bbase, &ldb, &zero, dest,
781 &ldc, 1, 1);
782 return;
786 if (rxstride == 1 && axstride == 1 && bxstride == 1)
788 /* This block of code implements a tuned matmul, derived from
789 Superscalar GEMM-based level 3 BLAS, Beta version 0.1
791 Bo Kagstrom and Per Ling
792 Department of Computing Science
793 Umea University
794 S-901 87 Umea, Sweden
796 from netlib.org, translated to C, and modified for matmul.m4. */
798 const GFC_INTEGER_8 *a, *b;
799 GFC_INTEGER_8 *c;
800 const index_type m = xcount, n = ycount, k = count;
802 /* System generated locals */
803 index_type a_dim1, a_offset, b_dim1, b_offset, c_dim1, c_offset,
804 i1, i2, i3, i4, i5, i6;
806 /* Local variables */
807 GFC_INTEGER_8 f11, f12, f21, f22, f31, f32, f41, f42,
808 f13, f14, f23, f24, f33, f34, f43, f44;
809 index_type i, j, l, ii, jj, ll;
810 index_type isec, jsec, lsec, uisec, ujsec, ulsec;
811 GFC_INTEGER_8 *t1;
813 a = abase;
814 b = bbase;
815 c = retarray->base_addr;
817 /* Parameter adjustments */
818 c_dim1 = rystride;
819 c_offset = 1 + c_dim1;
820 c -= c_offset;
821 a_dim1 = aystride;
822 a_offset = 1 + a_dim1;
823 a -= a_offset;
824 b_dim1 = bystride;
825 b_offset = 1 + b_dim1;
826 b -= b_offset;
828 /* Empty c first. */
829 for (j=1; j<=n; j++)
830 for (i=1; i<=m; i++)
831 c[i + j * c_dim1] = (GFC_INTEGER_8)0;
833 /* Early exit if possible */
834 if (m == 0 || n == 0 || k == 0)
835 return;
837 /* Adjust size of t1 to what is needed. */
838 index_type t1_dim;
839 t1_dim = (a_dim1 - (ycount > 1)) * 256 + b_dim1;
840 if (t1_dim > 65536)
841 t1_dim = 65536;
843 t1 = malloc (t1_dim * sizeof(GFC_INTEGER_8));
845 /* Start turning the crank. */
846 i1 = n;
847 for (jj = 1; jj <= i1; jj += 512)
849 /* Computing MIN */
850 i2 = 512;
851 i3 = n - jj + 1;
852 jsec = min(i2,i3);
853 ujsec = jsec - jsec % 4;
854 i2 = k;
855 for (ll = 1; ll <= i2; ll += 256)
857 /* Computing MIN */
858 i3 = 256;
859 i4 = k - ll + 1;
860 lsec = min(i3,i4);
861 ulsec = lsec - lsec % 2;
863 i3 = m;
864 for (ii = 1; ii <= i3; ii += 256)
866 /* Computing MIN */
867 i4 = 256;
868 i5 = m - ii + 1;
869 isec = min(i4,i5);
870 uisec = isec - isec % 2;
871 i4 = ll + ulsec - 1;
872 for (l = ll; l <= i4; l += 2)
874 i5 = ii + uisec - 1;
875 for (i = ii; i <= i5; i += 2)
877 t1[l - ll + 1 + ((i - ii + 1) << 8) - 257] =
878 a[i + l * a_dim1];
879 t1[l - ll + 2 + ((i - ii + 1) << 8) - 257] =
880 a[i + (l + 1) * a_dim1];
881 t1[l - ll + 1 + ((i - ii + 2) << 8) - 257] =
882 a[i + 1 + l * a_dim1];
883 t1[l - ll + 2 + ((i - ii + 2) << 8) - 257] =
884 a[i + 1 + (l + 1) * a_dim1];
886 if (uisec < isec)
888 t1[l - ll + 1 + (isec << 8) - 257] =
889 a[ii + isec - 1 + l * a_dim1];
890 t1[l - ll + 2 + (isec << 8) - 257] =
891 a[ii + isec - 1 + (l + 1) * a_dim1];
894 if (ulsec < lsec)
896 i4 = ii + isec - 1;
897 for (i = ii; i<= i4; ++i)
899 t1[lsec + ((i - ii + 1) << 8) - 257] =
900 a[i + (ll + lsec - 1) * a_dim1];
904 uisec = isec - isec % 4;
905 i4 = jj + ujsec - 1;
906 for (j = jj; j <= i4; j += 4)
908 i5 = ii + uisec - 1;
909 for (i = ii; i <= i5; i += 4)
911 f11 = c[i + j * c_dim1];
912 f21 = c[i + 1 + j * c_dim1];
913 f12 = c[i + (j + 1) * c_dim1];
914 f22 = c[i + 1 + (j + 1) * c_dim1];
915 f13 = c[i + (j + 2) * c_dim1];
916 f23 = c[i + 1 + (j + 2) * c_dim1];
917 f14 = c[i + (j + 3) * c_dim1];
918 f24 = c[i + 1 + (j + 3) * c_dim1];
919 f31 = c[i + 2 + j * c_dim1];
920 f41 = c[i + 3 + j * c_dim1];
921 f32 = c[i + 2 + (j + 1) * c_dim1];
922 f42 = c[i + 3 + (j + 1) * c_dim1];
923 f33 = c[i + 2 + (j + 2) * c_dim1];
924 f43 = c[i + 3 + (j + 2) * c_dim1];
925 f34 = c[i + 2 + (j + 3) * c_dim1];
926 f44 = c[i + 3 + (j + 3) * c_dim1];
927 i6 = ll + lsec - 1;
928 for (l = ll; l <= i6; ++l)
930 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
931 * b[l + j * b_dim1];
932 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
933 * b[l + j * b_dim1];
934 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
935 * b[l + (j + 1) * b_dim1];
936 f22 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
937 * b[l + (j + 1) * b_dim1];
938 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
939 * b[l + (j + 2) * b_dim1];
940 f23 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
941 * b[l + (j + 2) * b_dim1];
942 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) - 257]
943 * b[l + (j + 3) * b_dim1];
944 f24 += t1[l - ll + 1 + ((i - ii + 2) << 8) - 257]
945 * b[l + (j + 3) * b_dim1];
946 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
947 * b[l + j * b_dim1];
948 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
949 * b[l + j * b_dim1];
950 f32 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
951 * b[l + (j + 1) * b_dim1];
952 f42 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
953 * b[l + (j + 1) * b_dim1];
954 f33 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
955 * b[l + (j + 2) * b_dim1];
956 f43 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
957 * b[l + (j + 2) * b_dim1];
958 f34 += t1[l - ll + 1 + ((i - ii + 3) << 8) - 257]
959 * b[l + (j + 3) * b_dim1];
960 f44 += t1[l - ll + 1 + ((i - ii + 4) << 8) - 257]
961 * b[l + (j + 3) * b_dim1];
963 c[i + j * c_dim1] = f11;
964 c[i + 1 + j * c_dim1] = f21;
965 c[i + (j + 1) * c_dim1] = f12;
966 c[i + 1 + (j + 1) * c_dim1] = f22;
967 c[i + (j + 2) * c_dim1] = f13;
968 c[i + 1 + (j + 2) * c_dim1] = f23;
969 c[i + (j + 3) * c_dim1] = f14;
970 c[i + 1 + (j + 3) * c_dim1] = f24;
971 c[i + 2 + j * c_dim1] = f31;
972 c[i + 3 + j * c_dim1] = f41;
973 c[i + 2 + (j + 1) * c_dim1] = f32;
974 c[i + 3 + (j + 1) * c_dim1] = f42;
975 c[i + 2 + (j + 2) * c_dim1] = f33;
976 c[i + 3 + (j + 2) * c_dim1] = f43;
977 c[i + 2 + (j + 3) * c_dim1] = f34;
978 c[i + 3 + (j + 3) * c_dim1] = f44;
980 if (uisec < isec)
982 i5 = ii + isec - 1;
983 for (i = ii + uisec; i <= i5; ++i)
985 f11 = c[i + j * c_dim1];
986 f12 = c[i + (j + 1) * c_dim1];
987 f13 = c[i + (j + 2) * c_dim1];
988 f14 = c[i + (j + 3) * c_dim1];
989 i6 = ll + lsec - 1;
990 for (l = ll; l <= i6; ++l)
992 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
993 257] * b[l + j * b_dim1];
994 f12 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
995 257] * b[l + (j + 1) * b_dim1];
996 f13 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
997 257] * b[l + (j + 2) * b_dim1];
998 f14 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
999 257] * b[l + (j + 3) * b_dim1];
1001 c[i + j * c_dim1] = f11;
1002 c[i + (j + 1) * c_dim1] = f12;
1003 c[i + (j + 2) * c_dim1] = f13;
1004 c[i + (j + 3) * c_dim1] = f14;
1008 if (ujsec < jsec)
1010 i4 = jj + jsec - 1;
1011 for (j = jj + ujsec; j <= i4; ++j)
1013 i5 = ii + uisec - 1;
1014 for (i = ii; i <= i5; i += 4)
1016 f11 = c[i + j * c_dim1];
1017 f21 = c[i + 1 + j * c_dim1];
1018 f31 = c[i + 2 + j * c_dim1];
1019 f41 = c[i + 3 + j * c_dim1];
1020 i6 = ll + lsec - 1;
1021 for (l = ll; l <= i6; ++l)
1023 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1024 257] * b[l + j * b_dim1];
1025 f21 += t1[l - ll + 1 + ((i - ii + 2) << 8) -
1026 257] * b[l + j * b_dim1];
1027 f31 += t1[l - ll + 1 + ((i - ii + 3) << 8) -
1028 257] * b[l + j * b_dim1];
1029 f41 += t1[l - ll + 1 + ((i - ii + 4) << 8) -
1030 257] * b[l + j * b_dim1];
1032 c[i + j * c_dim1] = f11;
1033 c[i + 1 + j * c_dim1] = f21;
1034 c[i + 2 + j * c_dim1] = f31;
1035 c[i + 3 + j * c_dim1] = f41;
1037 i5 = ii + isec - 1;
1038 for (i = ii + uisec; i <= i5; ++i)
1040 f11 = c[i + j * c_dim1];
1041 i6 = ll + lsec - 1;
1042 for (l = ll; l <= i6; ++l)
1044 f11 += t1[l - ll + 1 + ((i - ii + 1) << 8) -
1045 257] * b[l + j * b_dim1];
1047 c[i + j * c_dim1] = f11;
1054 free(t1);
1055 return;
1057 else if (rxstride == 1 && aystride == 1 && bxstride == 1)
1059 if (GFC_DESCRIPTOR_RANK (a) != 1)
1061 const GFC_INTEGER_8 *restrict abase_x;
1062 const GFC_INTEGER_8 *restrict bbase_y;
1063 GFC_INTEGER_8 *restrict dest_y;
1064 GFC_INTEGER_8 s;
1066 for (y = 0; y < ycount; y++)
1068 bbase_y = &bbase[y*bystride];
1069 dest_y = &dest[y*rystride];
1070 for (x = 0; x < xcount; x++)
1072 abase_x = &abase[x*axstride];
1073 s = (GFC_INTEGER_8) 0;
1074 for (n = 0; n < count; n++)
1075 s += abase_x[n] * bbase_y[n];
1076 dest_y[x] = s;
1080 else
1082 const GFC_INTEGER_8 *restrict bbase_y;
1083 GFC_INTEGER_8 s;
1085 for (y = 0; y < ycount; y++)
1087 bbase_y = &bbase[y*bystride];
1088 s = (GFC_INTEGER_8) 0;
1089 for (n = 0; n < count; n++)
1090 s += abase[n*axstride] * bbase_y[n];
1091 dest[y*rystride] = s;
1095 else if (axstride < aystride)
1097 for (y = 0; y < ycount; y++)
1098 for (x = 0; x < xcount; x++)
1099 dest[x*rxstride + y*rystride] = (GFC_INTEGER_8)0;
1101 for (y = 0; y < ycount; y++)
1102 for (n = 0; n < count; n++)
1103 for (x = 0; x < xcount; x++)
1104 /* dest[x,y] += a[x,n] * b[n,y] */
1105 dest[x*rxstride + y*rystride] +=
1106 abase[x*axstride + n*aystride] *
1107 bbase[n*bxstride + y*bystride];
1109 else if (GFC_DESCRIPTOR_RANK (a) == 1)
1111 const GFC_INTEGER_8 *restrict bbase_y;
1112 GFC_INTEGER_8 s;
1114 for (y = 0; y < ycount; y++)
1116 bbase_y = &bbase[y*bystride];
1117 s = (GFC_INTEGER_8) 0;
1118 for (n = 0; n < count; n++)
1119 s += abase[n*axstride] * bbase_y[n*bxstride];
1120 dest[y*rxstride] = s;
1123 else
1125 const GFC_INTEGER_8 *restrict abase_x;
1126 const GFC_INTEGER_8 *restrict bbase_y;
1127 GFC_INTEGER_8 *restrict dest_y;
1128 GFC_INTEGER_8 s;
1130 for (y = 0; y < ycount; y++)
1132 bbase_y = &bbase[y*bystride];
1133 dest_y = &dest[y*rystride];
1134 for (x = 0; x < xcount; x++)
1136 abase_x = &abase[x*axstride];
1137 s = (GFC_INTEGER_8) 0;
1138 for (n = 0; n < count; n++)
1139 s += abase_x[n*aystride] * bbase_y[n*bxstride];
1140 dest_y[x*rxstride] = s;
1145 #undef POW3
1146 #undef min
1147 #undef max
1149 #endif
1151 #endif