2008-05-30 Vladimir Makarov <vmakarov@redhat.com>
[official-gcc.git] / libgfortran / m4 / matmul.m4
blob181efa3b654b85c3c03b0581520549dd587881d1
1 `/* Implementation of the MATMUL intrinsic
2    Copyright 2002, 2005, 2006, 2007 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file.  (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
19 executable.)
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
24 GNU General Public License for more details.
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING.  If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA.  */
31 #include "libgfortran.h"
32 #include <stdlib.h>
33 #include <string.h>
34 #include <assert.h>'
36 include(iparm.m4)dnl
38 `#if defined (HAVE_'rtype_name`)
40 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
41    passed to us by the front-end, in which case we''`ll call it for large
42    matrices.  */
44 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
45                           const int *, const 'rtype_name` *, const 'rtype_name` *,
46                           const int *, const 'rtype_name` *, const int *,
47                           const 'rtype_name` *, 'rtype_name` *, const int *,
48                           int, int);
50 /* The order of loops is different in the case of plain matrix
51    multiplication C=MATMUL(A,B), and in the frequent special case where
52    the argument A is the temporary result of a TRANSPOSE intrinsic:
53    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
54    looking at their strides.
56    The equivalent Fortran pseudo-code is:
58    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
59    IF (.NOT.IS_TRANSPOSED(A)) THEN
60      C = 0
61      DO J=1,N
62        DO K=1,COUNT
63          DO I=1,M
64            C(I,J) = C(I,J)+A(I,K)*B(K,J)
65    ELSE
66      DO J=1,N
67        DO I=1,M
68          S = 0
69          DO K=1,COUNT
70            S = S+A(I,K)*B(K,J)
71          C(I,J) = S
72    ENDIF
75 /* If try_blas is set to a nonzero value, then the matmul function will
76    see if there is a way to perform the matrix multiplication by a call
77    to the BLAS gemm function.  */
79 extern void matmul_'rtype_code` ('rtype` * const restrict retarray, 
80         'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
81         int blas_limit, blas_call gemm);
82 export_proto(matmul_'rtype_code`);
84 void
85 matmul_'rtype_code` ('rtype` * const restrict retarray, 
86         'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
87         int blas_limit, blas_call gemm)
89   const 'rtype_name` * restrict abase;
90   const 'rtype_name` * restrict bbase;
91   'rtype_name` * restrict dest;
93   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
94   index_type x, y, n, count, xcount, ycount;
96   assert (GFC_DESCRIPTOR_RANK (a) == 2
97           || GFC_DESCRIPTOR_RANK (b) == 2);
99 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
101    Either A or B (but not both) can be rank 1:
103    o One-dimensional argument A is implicitly treated as a row matrix
104      dimensioned [1,count], so xcount=1.
106    o One-dimensional argument B is implicitly treated as a column matrix
107      dimensioned [count, 1], so ycount=1.
108   */
110   if (retarray->data == NULL)
111     {
112       if (GFC_DESCRIPTOR_RANK (a) == 1)
113         {
114           retarray->dim[0].lbound = 0;
115           retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
116           retarray->dim[0].stride = 1;
117         }
118       else if (GFC_DESCRIPTOR_RANK (b) == 1)
119         {
120           retarray->dim[0].lbound = 0;
121           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
122           retarray->dim[0].stride = 1;
123         }
124       else
125         {
126           retarray->dim[0].lbound = 0;
127           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
128           retarray->dim[0].stride = 1;
130           retarray->dim[1].lbound = 0;
131           retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
132           retarray->dim[1].stride = retarray->dim[0].ubound+1;
133         }
135       retarray->data
136         = internal_malloc_size (sizeof ('rtype_name`) * size0 ((array_t *) retarray));
137       retarray->offset = 0;
138     }
140 sinclude(`matmul_asm_'rtype_code`.m4')dnl
142   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
143     {
144       /* One-dimensional result may be addressed in the code below
145          either as a row or a column matrix. We want both cases to
146          work. */
147       rxstride = rystride = retarray->dim[0].stride;
148     }
149   else
150     {
151       rxstride = retarray->dim[0].stride;
152       rystride = retarray->dim[1].stride;
153     }
156   if (GFC_DESCRIPTOR_RANK (a) == 1)
157     {
158       /* Treat it as a a row matrix A[1,count]. */
159       axstride = a->dim[0].stride;
160       aystride = 1;
162       xcount = 1;
163       count = a->dim[0].ubound + 1 - a->dim[0].lbound;
164     }
165   else
166     {
167       axstride = a->dim[0].stride;
168       aystride = a->dim[1].stride;
170       count = a->dim[1].ubound + 1 - a->dim[1].lbound;
171       xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
172     }
174   if (count != b->dim[0].ubound + 1 - b->dim[0].lbound)
175     {
176       if (count > 0 || b->dim[0].ubound + 1 - b->dim[0].lbound > 0)
177         runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
178     }
180   if (GFC_DESCRIPTOR_RANK (b) == 1)
181     {
182       /* Treat it as a column matrix B[count,1] */
183       bxstride = b->dim[0].stride;
185       /* bystride should never be used for 1-dimensional b.
186          in case it is we want it to cause a segfault, rather than
187          an incorrect result. */
188       bystride = 0xDEADBEEF;
189       ycount = 1;
190     }
191   else
192     {
193       bxstride = b->dim[0].stride;
194       bystride = b->dim[1].stride;
195       ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
196     }
198   abase = a->data;
199   bbase = b->data;
200   dest = retarray->data;
203   /* Now that everything is set up, we''`re performing the multiplication
204      itself.  */
206 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
208   if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
209       && (bxstride == 1 || bystride == 1)
210       && (((float) xcount) * ((float) ycount) * ((float) count)
211           > POW3(blas_limit)))
212   {
213     const int m = xcount, n = ycount, k = count, ldc = rystride;
214     const 'rtype_name` one = 1, zero = 0;
215     const int lda = (axstride == 1) ? aystride : axstride,
216               ldb = (bxstride == 1) ? bystride : bxstride;
218     if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
219       {
220         assert (gemm != NULL);
221         gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
222               &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
223         return;
224       }
225   }
227   if (rxstride == 1 && axstride == 1 && bxstride == 1)
228     {
229       const 'rtype_name` * restrict bbase_y;
230       'rtype_name` * restrict dest_y;
231       const 'rtype_name` * restrict abase_n;
232       'rtype_name` bbase_yn;
234       if (rystride == xcount)
235         memset (dest, 0, (sizeof ('rtype_name`) * xcount * ycount));
236       else
237         {
238           for (y = 0; y < ycount; y++)
239             for (x = 0; x < xcount; x++)
240               dest[x + y*rystride] = ('rtype_name`)0;
241         }
243       for (y = 0; y < ycount; y++)
244         {
245           bbase_y = bbase + y*bystride;
246           dest_y = dest + y*rystride;
247           for (n = 0; n < count; n++)
248             {
249               abase_n = abase + n*aystride;
250               bbase_yn = bbase_y[n];
251               for (x = 0; x < xcount; x++)
252                 {
253                   dest_y[x] += abase_n[x] * bbase_yn;
254                 }
255             }
256         }
257     }
258   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
259     {
260       if (GFC_DESCRIPTOR_RANK (a) != 1)
261         {
262           const 'rtype_name` *restrict abase_x;
263           const 'rtype_name` *restrict bbase_y;
264           'rtype_name` *restrict dest_y;
265           'rtype_name` s;
267           for (y = 0; y < ycount; y++)
268             {
269               bbase_y = &bbase[y*bystride];
270               dest_y = &dest[y*rystride];
271               for (x = 0; x < xcount; x++)
272                 {
273                   abase_x = &abase[x*axstride];
274                   s = ('rtype_name`) 0;
275                   for (n = 0; n < count; n++)
276                     s += abase_x[n] * bbase_y[n];
277                   dest_y[x] = s;
278                 }
279             }
280         }
281       else
282         {
283           const 'rtype_name` *restrict bbase_y;
284           'rtype_name` s;
286           for (y = 0; y < ycount; y++)
287             {
288               bbase_y = &bbase[y*bystride];
289               s = ('rtype_name`) 0;
290               for (n = 0; n < count; n++)
291                 s += abase[n*axstride] * bbase_y[n];
292               dest[y*rystride] = s;
293             }
294         }
295     }
296   else if (axstride < aystride)
297     {
298       for (y = 0; y < ycount; y++)
299         for (x = 0; x < xcount; x++)
300           dest[x*rxstride + y*rystride] = ('rtype_name`)0;
302       for (y = 0; y < ycount; y++)
303         for (n = 0; n < count; n++)
304           for (x = 0; x < xcount; x++)
305             /* dest[x,y] += a[x,n] * b[n,y] */
306             dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
307     }
308   else if (GFC_DESCRIPTOR_RANK (a) == 1)
309     {
310       const 'rtype_name` *restrict bbase_y;
311       'rtype_name` s;
313       for (y = 0; y < ycount; y++)
314         {
315           bbase_y = &bbase[y*bystride];
316           s = ('rtype_name`) 0;
317           for (n = 0; n < count; n++)
318             s += abase[n*axstride] * bbase_y[n*bxstride];
319           dest[y*rxstride] = s;
320         }
321     }
322   else
323     {
324       const 'rtype_name` *restrict abase_x;
325       const 'rtype_name` *restrict bbase_y;
326       'rtype_name` *restrict dest_y;
327       'rtype_name` s;
329       for (y = 0; y < ycount; y++)
330         {
331           bbase_y = &bbase[y*bystride];
332           dest_y = &dest[y*rystride];
333           for (x = 0; x < xcount; x++)
334             {
335               abase_x = &abase[x*axstride];
336               s = ('rtype_name`) 0;
337               for (n = 0; n < count; n++)
338                 s += abase_x[n*aystride] * bbase_y[n*bxstride];
339               dest_y[x*rxstride] = s;
340             }
341         }
342     }
345 #endif'