AMPI: add early return to AMPI_Testall when a req is incomplete
[charm.git] / src / libs / ck-libs / ampi / mpich-alltoall.C
blob874541277afbd339f3a0492f88b0679c4099fb90
1 /* 
3 The following code is adapted from alltoall.c in mpich2-1.0.3 
5 Licensing details should be addresssed, since this is copyrighted. 
7 */
9 #include "ampiimpl.h"
10 #include "tcharm.h"
11 #include "ampiEvents.h" /*** for trace generation for projector *****/
12 #include "ampiProjections.h"
15 /* This is the default implementation of alltoall. The algorithm is:
16    
17    Algorithm: MPI_Alltoall
19    We use four algorithms for alltoall. For short messages and
20    (comm_size >= 8), we use the algorithm by Jehoshua Bruck et al,
21    IEEE TPDS, Nov. 1997. It is a store-and-forward algorithm that
22    takes lgp steps. Because of the extra communication, the bandwidth
23    requirement is (n/2).lgp.beta.
25    Cost = lgp.alpha + (n/2).lgp.beta
27    where n is the total amount of data a process needs to send to all
28    other processes.
30    For medium size messages and (short messages for comm_size < 8), we
31    use an algorithm that posts all irecvs and isends and then does a
32    waitall. We scatter the order of sources and destinations among the
33    processes, so that all processes don't try to send/recv to/from the
34    same process at the same time.
36    For long messages and power-of-two number of processes, we use a
37    pairwise exchange algorithm, which takes p-1 steps. We
38    calculate the pairs by using an exclusive-or algorithm:
39            for (i=1; i<comm_size; i++)
40                dest = rank ^ i;
41    This algorithm doesn't work if the number of processes is not a power of
42    two. For a non-power-of-two number of processes, we use an
43    algorithm in which, in step i, each process  receives from (rank-i)
44    and sends to (rank+i). 
46    Cost = (p-1).alpha + n.beta
48    where n is the total amount of data a process needs to send to all
49    other processes.
51    Possible improvements: 
53    End Algorithm: MPI_Alltoall
59 /////////////////////////////////////////////////////////////////////////////////////////////////////
60 //   HELPER FUNCTIONS:
61 /////////////////////////////////////////////////////////////////////////////////////////////////////
63 #ifndef MAX
64 int MAX(int a, int b){
65   if(a>b)
66         return a;
67   else
68         return b;
70 #endif
72 #if 0
73 int MPI_Pack_size(int incount, MPI_Datatype type, MPI_Comm comm, int *size)
75   CkDDT_DataType *ddt = getAmpiInstance(comm)->getDDT()->getType(type);
76   int typesize = ddt->getSize();
77   *size = incount * typesize;
78   return MPI_SUCCESS;
80 #endif
82 // A simplified version of the mpich MPICH_Localcopy function
83 // TODO: This should do a memcpy when data is contiguous (see original)
85 void MPICH_Localcopy(void *sendbuf, int sendcount, MPI_Datatype sendtype,
86                                          void *recvbuf, int recvcount, MPI_Datatype recvtype)
88   int rank;
90   AMPI_Comm_rank (MPI_COMM_WORLD, &rank);
91   getAmpiInstance(MPI_COMM_WORLD)->sendrecv ( sendbuf, sendcount, sendtype,
92                                   rank, MPI_ATA_TAG, 
93                                   recvbuf, recvcount, recvtype,
94                                   rank, MPI_ATA_TAG,
95                                   MPI_COMM_WORLD, MPI_STATUS_IGNORE);
99 inline void MPID_Datatype_get_extent_macro(MPI_Datatype &type, MPI_Aint &extent){
100   CkDDT_DataType *ddt = getAmpiInstance(MPI_COMM_WORLD)->getDDT()->getType(type);
101   extent = ddt->getExtent();
104 inline void MPID_Datatype_get_size_macro(MPI_Datatype &type, int &size){
105   CkDDT_DataType *ddt = getAmpiInstance(MPI_COMM_WORLD)->getDDT()->getType(type);
106   size = ddt->getSize();
110 /////////////////////////////////////////////////////////////////////////////////////////////////////
111 //   LONG MESSAGES
112 /////////////////////////////////////////////////////////////////////////////////////////////////////
115 /* Long message. If comm_size is a power-of-two, do a pairwise
116    exchange using exclusive-or to create pairs. Else send to
117    rank+i, receive from rank-i. */
119 int AMPI_Alltoall_long(
120                                                 void *sendbuf, 
121                                                 int sendcount, 
122                                                 MPI_Datatype sendtype, 
123                                                 void *recvbuf, 
124                                                 int recvcount, 
125                                                 MPI_Datatype recvtype, 
126                                                 MPI_Comm comm )
129   int          comm_size, i, pof2;
130   MPI_Aint     sendtype_extent, recvtype_extent;
132   int src, dst, rank, nbytes;
133   MPI_Status status;
134   int sendtype_size;
136   if (sendcount == 0) return MPI_SUCCESS;
137   
138   MPI_Comm_rank (MPI_COMM_WORLD, &rank);
139   MPI_Comm_size (MPI_COMM_WORLD, &comm_size);
141     
142   /* Get extent of send and recv types */
143   MPID_Datatype_get_extent_macro(recvtype, recvtype_extent);
144   MPID_Datatype_get_extent_macro(sendtype, sendtype_extent);
146   MPID_Datatype_get_size_macro(sendtype, sendtype_size);
147   nbytes = sendtype_size * sendcount;
148   
150   /* Make local copy first */
151   MPICH_Localcopy(((char *)sendbuf + 
152                                    rank*sendcount*sendtype_extent), 
153                                   sendcount, sendtype, 
154                                   ((char *)recvbuf +
155                                    rank*recvcount*recvtype_extent),
156                                   recvcount, recvtype);
157   
159   /* Is comm_size a power-of-two? */
160   i = 1;
161   while (i < comm_size)
162         i *= 2;
163   if (i == comm_size)
164         pof2 = 1;
165   else 
166         pof2 = 0;
168   /* Do the pairwise exchanges */
169   for (i=1; i<comm_size; i++) {
170         if (pof2 == 1) {
171           /* use exclusive-or algorithm */
172           src = dst = rank ^ i;
173         }
174         else {
175           src = (rank - i + comm_size) % comm_size;
176           dst = (rank + i) % comm_size;
177         }
179         getAmpiInstance(comm)->sendrecv(((char *)sendbuf +
180                                                            dst*sendcount*sendtype_extent), 
181                                                           sendcount, sendtype, dst,
182                                                           MPI_ATA_TAG, 
183                                                           ((char *)recvbuf +
184                                                            src*recvcount*recvtype_extent),
185                                                           recvcount, recvtype, src,
186                                                           MPI_ATA_TAG, comm, &status);
187   }
189   return MPI_SUCCESS;
193 /////////////////////////////////////////////////////////////////////////////////////////////////////
194 // SHORT MESSAGES
195 /////////////////////////////////////////////////////////////////////////////////////////////////////
197 #if 0
198 int AMPI_Alltoall_short(
199                                                  void *sendbuf, 
200                                                  int sendcount, 
201                                                  MPI_Datatype sendtype, 
202                                                  void *recvbuf, 
203                                                  int recvcount, 
204                                                  MPI_Datatype recvtype, 
205                                                  MPI_Comm comm )
208   int          comm_size, i, pof2;
209   MPI_Aint     sendtype_extent, recvtype_extent;
211   int mpi_errno=MPI_SUCCESS, src, dst, rank, nbytes;
212   MPI_Status status;
213   void *tmp_buf;
214   int sendtype_size, pack_size, block, position, *displs, count;
216   MPI_Datatype newtype;
217   MPI_Aint recvtype_true_extent, recvbuf_extent, recvtype_true_lb;
220   if (sendcount == 0) return MPI_SUCCESS;
221   
222   MPI_Comm_rank (MPI_COMM_WORLD, &rank);
223   MPI_Comm_size (MPI_COMM_WORLD, &comm_size);
224     
225   /* Get extent of send and recv types */
226   MPID_Datatype_get_extent_macro(recvtype, recvtype_extent);
227   MPID_Datatype_get_extent_macro(sendtype, sendtype_extent);
229   MPID_Datatype_get_size_macro(sendtype, sendtype_size);
230   nbytes = sendtype_size * sendcount;
231     
232   /* use the indexing algorithm by Jehoshua Bruck et al,
233    * IEEE TPDS, Nov. 97 */ 
235   /* allocate temporary buffer */
236   MPI_Pack_size(recvcount*comm_size, recvtype, comm, &pack_size);
237   tmp_buf = malloc(pack_size);
238   CkAssert(tmp_buf);
240   /* Do Phase 1 of the algorithim. Shift the data blocks on process i
241    * upwards by a distance of i blocks. Store the result in recvbuf. */
242   MPICH_Localcopy((char *) sendbuf + rank*sendcount*sendtype_extent, 
243                                   (comm_size - rank)*sendcount, sendtype, recvbuf, 
244                                   (comm_size - rank)*recvcount, recvtype);
245             
246   MPICH_Localcopy(sendbuf, rank*sendcount, sendtype, 
247                                   (char *) recvbuf + (comm_size-rank)*recvcount*recvtype_extent, 
248                                   rank*recvcount, recvtype);
249                                 
250   /* Input data is now stored in recvbuf with datatype recvtype */
252   /* Now do Phase 2, the communication phase. It takes
253          ceiling(lg p) steps. In each step i, each process sends to rank+2^i
254          and receives from rank-2^i, and exchanges all data blocks
255          whose ith bit is 1. */
257   /* allocate displacements array for indexed datatype used in
258          communication */
260   displs = (int*)malloc(comm_size * sizeof(int));
261   CkAssert(displs);
264   pof2 = 1;
265   while (pof2 < comm_size) {
266         dst = (rank + pof2) % comm_size;
267         src = (rank - pof2 + comm_size) % comm_size;
269         /* Exchange all data blocks whose ith bit is 1 */
270         /* Create an indexed datatype for the purpose */
272         count = 0;
273         for (block=1; block<comm_size; block++) {
274           if (block & pof2) {
275                 displs[count] = block * recvcount;
276                 count++;
277           }
278         }
280         mpi_errno = MPI_Type_create_indexed_block(count, recvcount, displs, recvtype, &newtype);
282         if (mpi_errno)
283           return mpi_errno;
285         mpi_errno = MPI_Type_commit(&newtype);
287         if (mpi_errno)
288           return mpi_errno;
289             
290         position = 0;
291         mpi_errno = MPI_Pack(recvbuf, 1, newtype, tmp_buf, pack_size, 
292                                                   &position, comm);
294         getAmpiInstance(comm)->sendrecv(tmp_buf, position, MPI_PACKED, dst,
295                                                           MPI_ATA_TAG, recvbuf, 1, newtype,
296                                                           src, MPI_ATA_TAG, comm,
297                                                           MPI_STATUS_IGNORE);
298             
299         if (mpi_errno)
300           return mpi_errno;
301             
303         mpi_errno = MPI_Type_free(&newtype);
304            
305         if (mpi_errno)
306           return mpi_errno;
308         pof2 *= 2;
309   }
311   free(displs);
312   free(tmp_buf);
314   /* Rotate blocks in recvbuf upwards by (rank + 1) blocks. Need
315    * a temporary buffer of the same size as recvbuf. */
316         
317   /* get true extent of recvtype */
318   mpi_errno = MPI_Type_get_true_extent(recvtype, &recvtype_true_lb,
319                                                                                 &recvtype_true_extent);  
321   if (mpi_errno)
322         return mpi_errno;
324   recvbuf_extent = recvcount * comm_size *
325         (MAX(recvtype_true_extent, recvtype_extent));
326   tmp_buf = malloc(recvbuf_extent);
327   CkAssert(tmp_buf);
329   /* adjust for potential negative lower bound in datatype */
330   tmp_buf = (void *)((char*)tmp_buf - recvtype_true_lb);
332   MPICH_Localcopy((char *) recvbuf + (rank+1)*recvcount*recvtype_extent, 
333                                   (comm_size - rank - 1)*recvcount, recvtype, tmp_buf, 
334                                   (comm_size - rank - 1)*recvcount, recvtype);
335                         
336   MPICH_Localcopy(recvbuf, (rank+1)*recvcount, recvtype, 
337                                   (char *) tmp_buf + (comm_size-rank-1)*recvcount*recvtype_extent, 
338                                   (rank+1)*recvcount, recvtype);
339         
340         
341   /* Blocks are in the reverse order now (comm_size-1 to 0). 
342    * Reorder them to (0 to comm_size-1) and store them in recvbuf. */
344   for (i=0; i<comm_size; i++) 
345         MPICH_Localcopy((char *) tmp_buf + i*recvcount*recvtype_extent,
346                                         recvcount, recvtype, 
347                                         (char *) recvbuf + (comm_size-i-1)*recvcount*recvtype_extent, 
348                                         recvcount, recvtype); 
350   free((char*)tmp_buf + recvtype_true_lb);
353 #endif
355 /////////////////////////////////////////////////////////////////////////////////////////////////////
356 // MEDIUM MESSAGES
357 /////////////////////////////////////////////////////////////////////////////////////////////////////
359 int AMPI_Alltoall_medium(
360                                                   void *sendbuf, 
361                                                   int sendcount, 
362                                                   MPI_Datatype sendtype, 
363                                                   void *recvbuf, 
364                                                   int recvcount, 
365                                                   MPI_Datatype recvtype, 
366                                                   MPI_Comm comm )
369   int          comm_size, i;
370   MPI_Aint     sendtype_extent, recvtype_extent;
372   int mpi_errno=MPI_SUCCESS, dst, rank, nbytes;
373   int sendtype_size;
375   MPI_Request *reqarray;
376   MPI_Status *starray;
378   if (sendcount == 0) return MPI_SUCCESS;
379   
380   MPI_Comm_rank (MPI_COMM_WORLD, &rank);
381   MPI_Comm_size (MPI_COMM_WORLD, &comm_size);
382     
383   /* Get extent of send and recv types */
384   MPID_Datatype_get_extent_macro(recvtype, recvtype_extent);
385   MPID_Datatype_get_extent_macro(sendtype, sendtype_extent);
387   MPID_Datatype_get_size_macro(sendtype, sendtype_size);
388   nbytes = sendtype_size * sendcount;
389     
390   /* Medium-size message. Use isend/irecv with scattered destinations */
392   reqarray = (MPI_Request *) malloc(2*comm_size*sizeof(MPI_Request));
393         
394   if (!reqarray) 
395         return MPI_ERR_OTHER;
396         
397   starray = (MPI_Status *) malloc(2*comm_size*sizeof(MPI_Status));
398   if (!starray) {
399         free(reqarray);
400         return MPI_ERR_OTHER;
401   }
402         
403   /* do the communication -- post all sends and receives: */
404   ampi *ptr = getAmpiInstance(comm);
405   for ( i=0; i<comm_size; i++ ) { 
406         dst = (rank+i) % comm_size;
407     ptr->irecv((char *)recvbuf + dst*recvcount*recvtype_extent, recvcount, recvtype, dst,
408                MPI_ATA_TAG, comm, &reqarray[i]);
409   }
411   for ( i=0; i<comm_size; i++ ) { 
412         dst = (rank+i) % comm_size;
413         /*mpi_errno = AMPI_Isend((char *)sendbuf + dst*sendcount*sendtype_extent,
414             sendcount, sendtype, dst, MPI_ATA_TAG, comm, &reqarray[i+comm_size]);*/
415         ptr->send(MPI_ATA_TAG, getAmpiInstance(comm)->getRank(comm),
416               (char *)sendbuf + dst*sendcount*sendtype_extent,
417                       sendcount, sendtype, dst, comm);
418         reqarray[i+comm_size] = MPI_REQUEST_NULL;
419   }
421   /* ... then wait for *all* of them to finish: */
422   mpi_errno = AMPI_Waitall(2*comm_size,reqarray,starray);
424   /* --BEGIN ERROR HANDLING-- */
425 //   if (mpi_errno == MPI_ERR_IN_STATUS) {
426 //      for (int j=0; j<2*comm_size; j++) {
427 //        if (starray[j] != MPI_SUCCESS) 
428 //              mpi_errno = starray[j];
429 //      }
430 //   }
431   /* --END ERROR HANDLING-- */
433   free(starray);
434   free(reqarray);
435   
436   return mpi_errno;
441 /////////////////////////////////////////////////////////////////////////////////////////////////////
442 // MPICH OLD VERSION -- coming soon, once I figure out how it worked
443 /////////////////////////////////////////////////////////////////////////////////////////////////////