2008-10-01 Kai Tietz <kai.tietz@onevision.com>
[official-gcc.git] / libgfortran / m4 / matmul.m4
blobd8621fa2b35224d96153f166a6cab5ee923178b9
1 `/* Implementation of the MATMUL intrinsic
2    Copyright 2002, 2005, 2006, 2007 Free Software Foundation, Inc.
3    Contributed by Paul Brook <paul@nowt.org>
5 This file is part of the GNU Fortran 95 runtime library (libgfortran).
7 Libgfortran is free software; you can redistribute it and/or
8 modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either
10 version 2 of the License, or (at your option) any later version.
12 In addition to the permissions in the GNU General Public License, the
13 Free Software Foundation gives you unlimited permission to link the
14 compiled version of this file into combinations with other programs,
15 and to distribute those combinations without any restriction coming
16 from the use of this file.  (The General Public License restrictions
17 do apply in other respects; for example, they cover modification of
18 the file, and distribution when not linked into a combine
19 executable.)
21 Libgfortran is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
24 GNU General Public License for more details.
26 You should have received a copy of the GNU General Public
27 License along with libgfortran; see the file COPYING.  If not,
28 write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
29 Boston, MA 02110-1301, USA.  */
31 #include "libgfortran.h"
32 #include <stdlib.h>
33 #include <string.h>
34 #include <assert.h>'
36 include(iparm.m4)dnl
38 `#if defined (HAVE_'rtype_name`)
40 /* Prototype for the BLAS ?gemm subroutine, a pointer to which can be
41    passed to us by the front-end, in which case we''`ll call it for large
42    matrices.  */
44 typedef void (*blas_call)(const char *, const char *, const int *, const int *,
45                           const int *, const 'rtype_name` *, const 'rtype_name` *,
46                           const int *, const 'rtype_name` *, const int *,
47                           const 'rtype_name` *, 'rtype_name` *, const int *,
48                           int, int);
50 /* The order of loops is different in the case of plain matrix
51    multiplication C=MATMUL(A,B), and in the frequent special case where
52    the argument A is the temporary result of a TRANSPOSE intrinsic:
53    C=MATMUL(TRANSPOSE(A),B).  Transposed temporaries are detected by
54    looking at their strides.
56    The equivalent Fortran pseudo-code is:
58    DIMENSION A(M,COUNT), B(COUNT,N), C(M,N)
59    IF (.NOT.IS_TRANSPOSED(A)) THEN
60      C = 0
61      DO J=1,N
62        DO K=1,COUNT
63          DO I=1,M
64            C(I,J) = C(I,J)+A(I,K)*B(K,J)
65    ELSE
66      DO J=1,N
67        DO I=1,M
68          S = 0
69          DO K=1,COUNT
70            S = S+A(I,K)*B(K,J)
71          C(I,J) = S
72    ENDIF
75 /* If try_blas is set to a nonzero value, then the matmul function will
76    see if there is a way to perform the matrix multiplication by a call
77    to the BLAS gemm function.  */
79 extern void matmul_'rtype_code` ('rtype` * const restrict retarray, 
80         'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
81         int blas_limit, blas_call gemm);
82 export_proto(matmul_'rtype_code`);
84 void
85 matmul_'rtype_code` ('rtype` * const restrict retarray, 
86         'rtype` * const restrict a, 'rtype` * const restrict b, int try_blas,
87         int blas_limit, blas_call gemm)
89   const 'rtype_name` * restrict abase;
90   const 'rtype_name` * restrict bbase;
91   'rtype_name` * restrict dest;
93   index_type rxstride, rystride, axstride, aystride, bxstride, bystride;
94   index_type x, y, n, count, xcount, ycount;
96   assert (GFC_DESCRIPTOR_RANK (a) == 2
97           || GFC_DESCRIPTOR_RANK (b) == 2);
99 /* C[xcount,ycount] = A[xcount, count] * B[count,ycount]
101    Either A or B (but not both) can be rank 1:
103    o One-dimensional argument A is implicitly treated as a row matrix
104      dimensioned [1,count], so xcount=1.
106    o One-dimensional argument B is implicitly treated as a column matrix
107      dimensioned [count, 1], so ycount=1.
108   */
110   if (retarray->data == NULL)
111     {
112       if (GFC_DESCRIPTOR_RANK (a) == 1)
113         {
114           retarray->dim[0].lbound = 0;
115           retarray->dim[0].ubound = b->dim[1].ubound - b->dim[1].lbound;
116           retarray->dim[0].stride = 1;
117         }
118       else if (GFC_DESCRIPTOR_RANK (b) == 1)
119         {
120           retarray->dim[0].lbound = 0;
121           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
122           retarray->dim[0].stride = 1;
123         }
124       else
125         {
126           retarray->dim[0].lbound = 0;
127           retarray->dim[0].ubound = a->dim[0].ubound - a->dim[0].lbound;
128           retarray->dim[0].stride = 1;
130           retarray->dim[1].lbound = 0;
131           retarray->dim[1].ubound = b->dim[1].ubound - b->dim[1].lbound;
132           retarray->dim[1].stride = retarray->dim[0].ubound+1;
133         }
135       retarray->data
136         = internal_malloc_size (sizeof ('rtype_name`) * size0 ((array_t *) retarray));
137       retarray->offset = 0;
138     }
139     else if (unlikely (compile_options.bounds_check))
140       {
141         index_type ret_extent, arg_extent;
143         if (GFC_DESCRIPTOR_RANK (a) == 1)
144           {
145             arg_extent = b->dim[1].ubound + 1 - b->dim[1].lbound;
146             ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
147             if (arg_extent != ret_extent)
148               runtime_error ("Incorrect extent in return array in"
149                              " MATMUL intrinsic: is %ld, should be %ld",
150                              (long int) ret_extent, (long int) arg_extent);
151           }
152         else if (GFC_DESCRIPTOR_RANK (b) == 1)
153           {
154             arg_extent = a->dim[0].ubound + 1 - a->dim[0].lbound;
155             ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
156             if (arg_extent != ret_extent)
157               runtime_error ("Incorrect extent in return array in"
158                              " MATMUL intrinsic: is %ld, should be %ld",
159                              (long int) ret_extent, (long int) arg_extent);         
160           }
161         else
162           {
163             arg_extent = a->dim[0].ubound + 1 - a->dim[0].lbound;
164             ret_extent = retarray->dim[0].ubound + 1 - retarray->dim[0].lbound;
165             if (arg_extent != ret_extent)
166               runtime_error ("Incorrect extent in return array in"
167                              " MATMUL intrinsic for dimension 1:"
168                              " is %ld, should be %ld",
169                              (long int) ret_extent, (long int) arg_extent);
171             arg_extent = b->dim[1].ubound + 1 - b->dim[1].lbound;
172             ret_extent = retarray->dim[1].ubound + 1 - retarray->dim[1].lbound;
173             if (arg_extent != ret_extent)
174               runtime_error ("Incorrect extent in return array in"
175                              " MATMUL intrinsic for dimension 2:"
176                              " is %ld, should be %ld",
177                              (long int) ret_extent, (long int) arg_extent);
178           }
179       }
181 sinclude(`matmul_asm_'rtype_code`.m4')dnl
183   if (GFC_DESCRIPTOR_RANK (retarray) == 1)
184     {
185       /* One-dimensional result may be addressed in the code below
186          either as a row or a column matrix. We want both cases to
187          work. */
188       rxstride = rystride = retarray->dim[0].stride;
189     }
190   else
191     {
192       rxstride = retarray->dim[0].stride;
193       rystride = retarray->dim[1].stride;
194     }
197   if (GFC_DESCRIPTOR_RANK (a) == 1)
198     {
199       /* Treat it as a a row matrix A[1,count]. */
200       axstride = a->dim[0].stride;
201       aystride = 1;
203       xcount = 1;
204       count = a->dim[0].ubound + 1 - a->dim[0].lbound;
205     }
206   else
207     {
208       axstride = a->dim[0].stride;
209       aystride = a->dim[1].stride;
211       count = a->dim[1].ubound + 1 - a->dim[1].lbound;
212       xcount = a->dim[0].ubound + 1 - a->dim[0].lbound;
213     }
215   if (count != b->dim[0].ubound + 1 - b->dim[0].lbound)
216     {
217       if (count > 0 || b->dim[0].ubound + 1 - b->dim[0].lbound > 0)
218         runtime_error ("dimension of array B incorrect in MATMUL intrinsic");
219     }
221   if (GFC_DESCRIPTOR_RANK (b) == 1)
222     {
223       /* Treat it as a column matrix B[count,1] */
224       bxstride = b->dim[0].stride;
226       /* bystride should never be used for 1-dimensional b.
227          in case it is we want it to cause a segfault, rather than
228          an incorrect result. */
229       bystride = 0xDEADBEEF;
230       ycount = 1;
231     }
232   else
233     {
234       bxstride = b->dim[0].stride;
235       bystride = b->dim[1].stride;
236       ycount = b->dim[1].ubound + 1 - b->dim[1].lbound;
237     }
239   abase = a->data;
240   bbase = b->data;
241   dest = retarray->data;
244   /* Now that everything is set up, we''`re performing the multiplication
245      itself.  */
247 #define POW3(x) (((float) (x)) * ((float) (x)) * ((float) (x)))
249   if (try_blas && rxstride == 1 && (axstride == 1 || aystride == 1)
250       && (bxstride == 1 || bystride == 1)
251       && (((float) xcount) * ((float) ycount) * ((float) count)
252           > POW3(blas_limit)))
253   {
254     const int m = xcount, n = ycount, k = count, ldc = rystride;
255     const 'rtype_name` one = 1, zero = 0;
256     const int lda = (axstride == 1) ? aystride : axstride,
257               ldb = (bxstride == 1) ? bystride : bxstride;
259     if (lda > 0 && ldb > 0 && ldc > 0 && m > 1 && n > 1 && k > 1)
260       {
261         assert (gemm != NULL);
262         gemm (axstride == 1 ? "N" : "T", bxstride == 1 ? "N" : "T", &m, &n, &k,
263               &one, abase, &lda, bbase, &ldb, &zero, dest, &ldc, 1, 1);
264         return;
265       }
266   }
268   if (rxstride == 1 && axstride == 1 && bxstride == 1)
269     {
270       const 'rtype_name` * restrict bbase_y;
271       'rtype_name` * restrict dest_y;
272       const 'rtype_name` * restrict abase_n;
273       'rtype_name` bbase_yn;
275       if (rystride == xcount)
276         memset (dest, 0, (sizeof ('rtype_name`) * xcount * ycount));
277       else
278         {
279           for (y = 0; y < ycount; y++)
280             for (x = 0; x < xcount; x++)
281               dest[x + y*rystride] = ('rtype_name`)0;
282         }
284       for (y = 0; y < ycount; y++)
285         {
286           bbase_y = bbase + y*bystride;
287           dest_y = dest + y*rystride;
288           for (n = 0; n < count; n++)
289             {
290               abase_n = abase + n*aystride;
291               bbase_yn = bbase_y[n];
292               for (x = 0; x < xcount; x++)
293                 {
294                   dest_y[x] += abase_n[x] * bbase_yn;
295                 }
296             }
297         }
298     }
299   else if (rxstride == 1 && aystride == 1 && bxstride == 1)
300     {
301       if (GFC_DESCRIPTOR_RANK (a) != 1)
302         {
303           const 'rtype_name` *restrict abase_x;
304           const 'rtype_name` *restrict bbase_y;
305           'rtype_name` *restrict dest_y;
306           'rtype_name` s;
308           for (y = 0; y < ycount; y++)
309             {
310               bbase_y = &bbase[y*bystride];
311               dest_y = &dest[y*rystride];
312               for (x = 0; x < xcount; x++)
313                 {
314                   abase_x = &abase[x*axstride];
315                   s = ('rtype_name`) 0;
316                   for (n = 0; n < count; n++)
317                     s += abase_x[n] * bbase_y[n];
318                   dest_y[x] = s;
319                 }
320             }
321         }
322       else
323         {
324           const 'rtype_name` *restrict bbase_y;
325           'rtype_name` s;
327           for (y = 0; y < ycount; y++)
328             {
329               bbase_y = &bbase[y*bystride];
330               s = ('rtype_name`) 0;
331               for (n = 0; n < count; n++)
332                 s += abase[n*axstride] * bbase_y[n];
333               dest[y*rystride] = s;
334             }
335         }
336     }
337   else if (axstride < aystride)
338     {
339       for (y = 0; y < ycount; y++)
340         for (x = 0; x < xcount; x++)
341           dest[x*rxstride + y*rystride] = ('rtype_name`)0;
343       for (y = 0; y < ycount; y++)
344         for (n = 0; n < count; n++)
345           for (x = 0; x < xcount; x++)
346             /* dest[x,y] += a[x,n] * b[n,y] */
347             dest[x*rxstride + y*rystride] += abase[x*axstride + n*aystride] * bbase[n*bxstride + y*bystride];
348     }
349   else if (GFC_DESCRIPTOR_RANK (a) == 1)
350     {
351       const 'rtype_name` *restrict bbase_y;
352       'rtype_name` s;
354       for (y = 0; y < ycount; y++)
355         {
356           bbase_y = &bbase[y*bystride];
357           s = ('rtype_name`) 0;
358           for (n = 0; n < count; n++)
359             s += abase[n*axstride] * bbase_y[n*bxstride];
360           dest[y*rxstride] = s;
361         }
362     }
363   else
364     {
365       const 'rtype_name` *restrict abase_x;
366       const 'rtype_name` *restrict bbase_y;
367       'rtype_name` *restrict dest_y;
368       'rtype_name` s;
370       for (y = 0; y < ycount; y++)
371         {
372           bbase_y = &bbase[y*bystride];
373           dest_y = &dest[y*rystride];
374           for (x = 0; x < xcount; x++)
375             {
376               abase_x = &abase[x*axstride];
377               s = ('rtype_name`) 0;
378               for (n = 0; n < count; n++)
379                 s += abase_x[n*aystride] * bbase_y[n*bxstride];
380               dest_y[x*rxstride] = s;
381             }
382         }
383     }
386 #endif'