gcc/
[official-gcc.git] / libgfortran / m4 / matmul.m4
blobfae6c38d4765d6f05dfbae426471e6ae274d6ff6
1 `/* Implementation of the MATMUL intrinsic
2    Copyright (C) 2002-2015 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 3 of the License, or (at your option) any later version.
12 Libgfortran is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 GNU General Public License for more details.
17 Under Section 7 of GPL version 3, you are granted additional
18 permissions described in the GCC Runtime Library Exception, version
19 3.1, as published by the Free Software Foundation.
21 You should have received a copy of the GNU General Public License and
22 a copy of the GCC Runtime Library Exception along with this program;
23 see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
24 <http://www.gnu.org/licenses/>.  */
26 #include "libgfortran.h"
27 #include <stdlib.h>
28 #include <string.h>
29 #include <assert.h>'
31 include(iparm.m4)dnl
33 `#if defined (HAVE_'rtype_name`)
35 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
36    passed to us by the front-end, in which case we''`ll call it for large
37    matrices.  */
39 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
40                           const int *, const 'rtype_name` *, const 'rtype_name` *,
41                           const int *, const 'rtype_name` *, const int *,
42                           const 'rtype_name` *, 'rtype_name` *, const int *,
43                           int, int);
45 /* The order of loops is different in the case of plain matrix
46    multiplication C=MATMUL(A,B), and in the frequent special case where
47    the argument A is the temporary result of a TRANSPOSE intrinsic:
48    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
49    looking at their strides.
51    The equivalent Fortran pseudo-code is:
53    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
54    IF (.NOT.IS_TRANSPOSED(A)) THEN
55      C = 0
56      DO J=1,N
57        DO K=1,COUNT
58          DO I=1,M
59            C(I,J) = C(I,J)+A(I,K)*B(K,J)
60    ELSE
61      DO J=1,N
62        DO I=1,M
63          S = 0
64          DO K=1,COUNT
65            S = S+A(I,K)*B(K,J)
66          C(I,J) = S
67    ENDIF
70 /* If try_blas is set to a nonzero value, then the matmul function will
71    see if there is a way to perform the matrix multiplication by a call
72    to the BLAS gemm function.  */
74 extern void matmul_'rtype_code` ('rtype` * const restrict retarray, 
75         'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
76         int blas_limit, blas_call gemm);
77 export_proto(matmul_'rtype_code`);
79 void
80 matmul_'rtype_code` ('rtype` * const restrict retarray, 
81         'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
82         int blas_limit, blas_call gemm)
84   const 'rtype_name` * restrict abase;
85   const 'rtype_name` * restrict bbase;
86   'rtype_name` * restrict dest;
88   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
89   index_type x, y, n, count, xcount, ycount;
91   assert (GFC_DESCRIPTOR_RANK (a) == 2
92           || GFC_DESCRIPTOR_RANK (b) == 2);
94 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
96    Either A or B (but not both) can be rank 1:
98    o One-dimensional argument A is implicitly treated as a row matrix
99      dimensioned [1,count], so xcount=1.
101    o One-dimensional argument B is implicitly treated as a column matrix
102      dimensioned [count, 1], so ycount=1.
103   */
105   if (retarray->base_addr == NULL)
106     {
107       if (GFC_DESCRIPTOR_RANK (a) == 1)
108         {
109           GFC_DIMENSION_SET(retarray->dim[0], 0,
110                             GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
111         }
112       else if (GFC_DESCRIPTOR_RANK (b) == 1)
113         {
114           GFC_DIMENSION_SET(retarray->dim[0], 0,
115                             GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
116         }
117       else
118         {
119           GFC_DIMENSION_SET(retarray->dim[0], 0,
120                             GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
122           GFC_DIMENSION_SET(retarray->dim[1], 0,
123                             GFC_DESCRIPTOR_EXTENT(b,1) - 1,
124                             GFC_DESCRIPTOR_EXTENT(retarray,0));
125         }
127       retarray->base_addr
128         = xmallocarray (size0 ((array_t *) retarray), sizeof ('rtype_name`));
129       retarray->offset = 0;
130     }
131     else if (unlikely (compile_options.bounds_check))
132       {
133         index_type ret_extent, arg_extent;
135         if (GFC_DESCRIPTOR_RANK (a) == 1)
136           {
137             arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
138             ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
139             if (arg_extent != ret_extent)
140               runtime_error ("Incorrect extent in return array in"
141                              " MATMUL intrinsic: is %ld, should be %ld",
142                              (long int) ret_extent, (long int) arg_extent);
143           }
144         else if (GFC_DESCRIPTOR_RANK (b) == 1)
145           {
146             arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
147             ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
148             if (arg_extent != ret_extent)
149               runtime_error ("Incorrect extent in return array in"
150                              " MATMUL intrinsic: is %ld, should be %ld",
151                              (long int) ret_extent, (long int) arg_extent);         
152           }
153         else
154           {
155             arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
156             ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
157             if (arg_extent != ret_extent)
158               runtime_error ("Incorrect extent in return array in"
159                              " MATMUL intrinsic for dimension 1:"
160                              " is %ld, should be %ld",
161                              (long int) ret_extent, (long int) arg_extent);
163             arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
164             ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
165             if (arg_extent != ret_extent)
166               runtime_error ("Incorrect extent in return array in"
167                              " MATMUL intrinsic for dimension 2:"
168                              " is %ld, should be %ld",
169                              (long int) ret_extent, (long int) arg_extent);
170           }
171       }
173 sinclude(`matmul_asm_'rtype_code`.m4')dnl
175   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
176     {
177       /* One-dimensional result may be addressed in the code below
178          either as a row or a column matrix. We want both cases to
179          work. */
180       rxstride = rystride = GFC_DESCRIPTOR_STRIDE(retarray,0);
181     }
182   else
183     {
184       rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
185       rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
186     }
189   if (GFC_DESCRIPTOR_RANK (a) == 1)
190     {
191       /* Treat it as a a row matrix A[1,count]. */
192       axstride = GFC_DESCRIPTOR_STRIDE(a,0);
193       aystride = 1;
195       xcount = 1;
196       count = GFC_DESCRIPTOR_EXTENT(a,0);
197     }
198   else
199     {
200       axstride = GFC_DESCRIPTOR_STRIDE(a,0);
201       aystride = GFC_DESCRIPTOR_STRIDE(a,1);
203       count = GFC_DESCRIPTOR_EXTENT(a,1);
204       xcount = GFC_DESCRIPTOR_EXTENT(a,0);
205     }
207   if (count != GFC_DESCRIPTOR_EXTENT(b,0))
208     {
209       if (count > 0 || GFC_DESCRIPTOR_EXTENT(b,0) > 0)
210         runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
211     }
213   if (GFC_DESCRIPTOR_RANK (b) == 1)
214     {
215       /* Treat it as a column matrix B[count,1] */
216       bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
218       /* bystride should never be used for 1-dimensional b.
219          in case it is we want it to cause a segfault, rather than
220          an incorrect result. */
221       bystride = 0xDEADBEEF;
222       ycount = 1;
223     }
224   else
225     {
226       bxstride = GFC_DESCRIPTOR_STRIDE(b,0);
227       bystride = GFC_DESCRIPTOR_STRIDE(b,1);
228       ycount = GFC_DESCRIPTOR_EXTENT(b,1);
229     }
231   abase = a->base_addr;
232   bbase = b->base_addr;
233   dest = retarray->base_addr;
236   /* Now that everything is set up, we''`re performing the multiplication
237      itself.  */
239 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
241   if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
242       && (bxstride == 1 || bystride == 1)
243       && (((float) xcount) * ((float) ycount) * ((float) count)
244           > POW3(blas_limit)))
245   {
246     const int m = xcount, n = ycount, k = count, ldc = rystride;
247     const 'rtype_name` one = 1, zero = 0;
248     const int lda = (axstride == 1) ? aystride : axstride,
249               ldb = (bxstride == 1) ? bystride : bxstride;
251     if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
252       {
253         assert (gemm != NULL);
254         gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
255               &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
256         return;
257       }
258   }
260   if (rxstride == 1 && axstride == 1 && bxstride == 1)
261     {
262       const 'rtype_name` * restrict bbase_y;
263       'rtype_name` * restrict dest_y;
264       const 'rtype_name` * restrict abase_n;
265       'rtype_name` bbase_yn;
267       if (rystride == xcount)
268         memset (dest, 0, (sizeof ('rtype_name`) * xcount * ycount));
269       else
270         {
271           for (y = 0; y < ycount; y++)
272             for (x = 0; x < xcount; x++)
273               dest[x + y*rystride] = ('rtype_name`)0;
274         }
276       for (y = 0; y < ycount; y++)
277         {
278           bbase_y = bbase + y*bystride;
279           dest_y = dest + y*rystride;
280           for (n = 0; n < count; n++)
281             {
282               abase_n = abase + n*aystride;
283               bbase_yn = bbase_y[n];
284               for (x = 0; x < xcount; x++)
285                 {
286                   dest_y[x] += abase_n[x] * bbase_yn;
287                 }
288             }
289         }
290     }
291   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
292     {
293       if (GFC_DESCRIPTOR_RANK (a) != 1)
294         {
295           const 'rtype_name` *restrict abase_x;
296           const 'rtype_name` *restrict bbase_y;
297           'rtype_name` *restrict dest_y;
298           'rtype_name` s;
300           for (y = 0; y < ycount; y++)
301             {
302               bbase_y = &bbase[y*bystride];
303               dest_y = &dest[y*rystride];
304               for (x = 0; x < xcount; x++)
305                 {
306                   abase_x = &abase[x*axstride];
307                   s = ('rtype_name`) 0;
308                   for (n = 0; n < count; n++)
309                     s += abase_x[n] * bbase_y[n];
310                   dest_y[x] = s;
311                 }
312             }
313         }
314       else
315         {
316           const 'rtype_name` *restrict bbase_y;
317           'rtype_name` s;
319           for (y = 0; y < ycount; y++)
320             {
321               bbase_y = &bbase[y*bystride];
322               s = ('rtype_name`) 0;
323               for (n = 0; n < count; n++)
324                 s += abase[n*axstride] * bbase_y[n];
325               dest[y*rystride] = s;
326             }
327         }
328     }
329   else if (axstride < aystride)
330     {
331       for (y = 0; y < ycount; y++)
332         for (x = 0; x < xcount; x++)
333           dest[x*rxstride + y*rystride] = ('rtype_name`)0;
335       for (y = 0; y < ycount; y++)
336         for (n = 0; n < count; n++)
337           for (x = 0; x < xcount; x++)
338             /* dest[x,y] += a[x,n] * b[n,y] */
339             dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
340     }
341   else if (GFC_DESCRIPTOR_RANK (a) == 1)
342     {
343       const 'rtype_name` *restrict bbase_y;
344       'rtype_name` s;
346       for (y = 0; y < ycount; y++)
347         {
348           bbase_y = &bbase[y*bystride];
349           s = ('rtype_name`) 0;
350           for (n = 0; n < count; n++)
351             s += abase[n*axstride] * bbase_y[n*bxstride];
352           dest[y*rxstride] = s;
353         }
354     }
355   else
356     {
357       const 'rtype_name` *restrict abase_x;
358       const 'rtype_name` *restrict bbase_y;
359       'rtype_name` *restrict dest_y;
360       'rtype_name` s;
362       for (y = 0; y < ycount; y++)
363         {
364           bbase_y = &bbase[y*bystride];
365           dest_y = &dest[y*rystride];
366           for (x = 0; x < xcount; x++)
367             {
368               abase_x = &abase[x*axstride];
369               s = ('rtype_name`) 0;
370               for (n = 0; n < count; n++)
371                 s += abase_x[n*aystride] * bbase_y[n*bxstride];
372               dest_y[x*rxstride] = s;
373             }
374         }
375     }
378 #endif'