/* The following code is adapted from alltoall.c in mpich2-1.0.3.
   Licensing details should be addressed, since the original code is copyrighted. */
#include "ampiEvents.h"      /* for trace generation for projector */
#include "ampiProjections.h"
/* This is the default implementation of alltoall. The algorithm is:

   Algorithm: MPI_Alltoall

   We use four algorithms for alltoall. For short messages (and
   comm_size >= 8), we use the algorithm by Jehoshua Bruck et al.,
   IEEE TPDS, Nov. 1997. It is a store-and-forward algorithm that
   takes lg(p) steps. Because of the extra communication, the bandwidth
   requirement is (n/2).lg(p).beta.

   Cost = lg(p).alpha + (n/2).lg(p).beta

   where n is the total amount of data a process needs to send to all
   other processes.

   For medium-size messages (and short messages when comm_size < 8), we
   use an algorithm that posts all irecvs and isends and then does a
   waitall. We scatter the order of sources and destinations among the
   processes, so that all processes don't try to send/recv to/from the
   same process at the same time.

   For long messages and a power-of-two number of processes, we use a
   pairwise exchange algorithm, which takes p-1 steps. We
   calculate the pairs by using an exclusive-or algorithm:
       for (i=1; i<comm_size; i++)
           dst = rank ^ i;
   This algorithm doesn't work if the number of processes is not a
   power of two. For a non-power-of-two number of processes, we use an
   algorithm in which, in step i, each process receives from (rank-i)
   and sends to (rank+i).

   Cost = (p-1).alpha + n.beta

   where n is the total amount of data a process needs to send to all
   other processes.

   Possible improvements:

   End Algorithm: MPI_Alltoall

   A sketch of how the three helpers below might be dispatched by message
   size follows the divider. */
/////////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////////
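/* Editorial sketch (not part of the original file): one way the three helper
 * routines defined below could be dispatched by message size, mirroring the
 * selection logic described in the comment above. The 256-byte and 32768-byte
 * cutoffs are assumptions borrowed from MPICH-style defaults, not values taken
 * from this file. Kept under #if 0 so it does not affect compilation. */
#if 0
int AMPI_Alltoall_dispatch_sketch(void *sendbuf, int sendcount, MPI_Datatype sendtype,
                                  void *recvbuf, int recvcount, MPI_Datatype recvtype,
                                  MPI_Comm comm)
{
    int comm_size, sendtype_size, nbytes;
    MPI_Comm_size(comm, &comm_size);
    MPID_Datatype_get_size_macro(sendtype, sendtype_size);
    nbytes = sendtype_size * sendcount;        /* data each process sends to one peer */

    if (nbytes <= 256 && comm_size >= 8)       /* short: Bruck's algorithm */
        return AMPI_Alltoall_short(sendbuf, sendcount, sendtype,
                                   recvbuf, recvcount, recvtype, comm);
    else if (nbytes <= 32768)                  /* medium: posted irecv/isend + waitall */
        return AMPI_Alltoall_medium(sendbuf, sendcount, sendtype,
                                    recvbuf, recvcount, recvtype, comm);
    else                                       /* long: pairwise exchange */
        return AMPI_Alltoall_long(sendbuf, sendcount, sendtype,
                                  recvbuf, recvcount, recvtype, comm);
}
#endif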
int MAX(int a, int b){
    return (a > b) ? a : b;
}
/* Simplified MPI_Pack_size: the packed size of incount elements of the given type */
int MPI_Pack_size(int incount, MPI_Datatype type, MPI_Comm comm, int *size)
{
    CkDDT_DataType *ddt = getAmpiInstance(comm)->getDDT()->getType(type);
    int typesize = ddt->getSize();
    *size = incount * typesize;
    return MPI_SUCCESS;
}
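/* Usage example (editorial): MPI_Pack_size(10, MPI_INT, comm, &sz) sets sz to
 * 40 on platforms where an int occupies 4 bytes. */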
// A simplified version of the mpich MPICH_Localcopy function
// TODO: This should do a memcpy when data is contiguous (see original)

void MPICH_Localcopy(void *sendbuf, int sendcount, MPI_Datatype sendtype,
                     void *recvbuf, int recvcount, MPI_Datatype recvtype)
{
    int rank;
    AMPI_Comm_rank(MPI_COMM_WORLD, &rank);
    /* copy by sending to and receiving from ourselves; the tag used here is
       assumed to be MPI_ATA_TAG, matching the alltoall routines below */
    getAmpiInstance(MPI_COMM_WORLD)->sendrecv(sendbuf, sendcount, sendtype,
                                              rank, MPI_ATA_TAG,
                                              recvbuf, recvcount, recvtype,
                                              rank, MPI_ATA_TAG,
                                              MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}
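/* Editorial sketch addressing the TODO above (not from the original source):
 * when both datatypes are contiguous, a plain memcpy avoids the self-sendrecv.
 * The "size == extent" test is an assumed heuristic for contiguity, and
 * <string.h> is assumed to be pulled in by the headers above. */
#if 0
void MPICH_Localcopy_contig_sketch(void *sendbuf, int sendcount, MPI_Datatype sendtype,
                                   void *recvbuf, int recvcount, MPI_Datatype recvtype)
{
    CkDDT_DataType *sddt = getAmpiInstance(MPI_COMM_WORLD)->getDDT()->getType(sendtype);
    CkDDT_DataType *rddt = getAmpiInstance(MPI_COMM_WORLD)->getDDT()->getType(recvtype);
    int sbytes = sddt->getSize() * sendcount;
    int rbytes = rddt->getSize() * recvcount;
    /* treat a type as contiguous when its size equals its extent (assumption) */
    if (sddt->getSize() == sddt->getExtent() &&
        rddt->getSize() == rddt->getExtent() &&
        sbytes == rbytes) {
        memcpy(recvbuf, sendbuf, sbytes);
    } else {
        MPICH_Localcopy(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype);
    }
}
#endif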
inline void MPID_Datatype_get_extent_macro(MPI_Datatype &type, MPI_Aint &extent){
    CkDDT_DataType *ddt = getAmpiInstance(MPI_COMM_WORLD)->getDDT()->getType(type);
    extent = ddt->getExtent();
}

inline void MPID_Datatype_get_size_macro(MPI_Datatype &type, int &size){
    CkDDT_DataType *ddt = getAmpiInstance(MPI_COMM_WORLD)->getDDT()->getType(type);
    size = ddt->getSize();
}
/////////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////////
/* Long message. If comm_size is a power-of-two, do a pairwise
   exchange using exclusive-or to create pairs. Else send to
   rank+i, receive from rank-i. */
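/* Illustration (editorial): for comm_size = 4, the exclusive-or pairing gives
 * step i=1: (0,1) (2,3);  i=2: (0,2) (1,3);  i=3: (0,3) (1,2),
 * so every process exchanges with every other process exactly once. */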
int AMPI_Alltoall_long(void *sendbuf, int sendcount, MPI_Datatype sendtype,
                       void *recvbuf, int recvcount, MPI_Datatype recvtype,
                       MPI_Comm comm)
{
    int comm_size, i, pof2;
    MPI_Aint sendtype_extent, recvtype_extent;
    int sendtype_size;
    int src, dst, rank, nbytes;
    MPI_Status status;
    if (sendcount == 0) return MPI_SUCCESS;

    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &comm_size);

    /* Get extent of send and recv types */
    MPID_Datatype_get_extent_macro(recvtype, recvtype_extent);
    MPID_Datatype_get_extent_macro(sendtype, sendtype_extent);
    MPID_Datatype_get_size_macro(sendtype, sendtype_size);
    nbytes = sendtype_size * sendcount;
    /* Make local copy first */
    MPICH_Localcopy(((char *)sendbuf + rank*sendcount*sendtype_extent),
                    sendcount, sendtype,
                    ((char *)recvbuf + rank*recvcount*recvtype_extent),
                    recvcount, recvtype);
    /* Is comm_size a power-of-two? */
    i = 1;
    while (i < comm_size)
        i *= 2;
    pof2 = (i == comm_size) ? 1 : 0;

    /* Do the pairwise exchanges */
    for (i=1; i<comm_size; i++) {
        if (pof2) {
            /* use exclusive-or algorithm */
            src = dst = rank ^ i;
        }
        else {
            src = (rank - i + comm_size) % comm_size;
            dst = (rank + i) % comm_size;
        }

        getAmpiInstance(comm)->sendrecv(((char *)sendbuf +
                                         dst*sendcount*sendtype_extent),
                                        sendcount, sendtype, dst, MPI_ATA_TAG,
                                        ((char *)recvbuf +
                                         src*recvcount*recvtype_extent),
                                        recvcount, recvtype, src,
                                        MPI_ATA_TAG, comm, &status);
    }

    return MPI_SUCCESS;
}
/////////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////////
int AMPI_Alltoall_short(void *sendbuf, int sendcount, MPI_Datatype sendtype,
                        void *recvbuf, int recvcount, MPI_Datatype recvtype,
                        MPI_Comm comm)
{
    int comm_size, i, pof2;
    MPI_Aint sendtype_extent, recvtype_extent;
    int mpi_errno=MPI_SUCCESS, src, dst, rank, nbytes;
    int sendtype_size, pack_size, block, position, *displs, count;
    void *tmp_buf;
    MPI_Datatype newtype;
    MPI_Aint recvtype_true_extent, recvbuf_extent, recvtype_true_lb;
    MPI_Status status;
    if (sendcount == 0) return MPI_SUCCESS;

    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &comm_size);

    /* Get extent of send and recv types */
    MPID_Datatype_get_extent_macro(recvtype, recvtype_extent);
    MPID_Datatype_get_extent_macro(sendtype, sendtype_extent);
    MPID_Datatype_get_size_macro(sendtype, sendtype_size);
    nbytes = sendtype_size * sendcount;
    /* use the indexing algorithm by Jehoshua Bruck et al.,
     * IEEE TPDS, Nov. 1997 */

    /* allocate temporary buffer */
    MPI_Pack_size(recvcount*comm_size, recvtype, comm, &pack_size);
    tmp_buf = malloc(pack_size);
    /* Do Phase 1 of the algorithm. Shift the data blocks on process i
     * upwards by a distance of i blocks. Store the result in recvbuf. */
    MPICH_Localcopy((char *) sendbuf + rank*sendcount*sendtype_extent,
                    (comm_size - rank)*sendcount, sendtype, recvbuf,
                    (comm_size - rank)*recvcount, recvtype);
    MPICH_Localcopy(sendbuf, rank*sendcount, sendtype,
                    (char *) recvbuf + (comm_size-rank)*recvcount*recvtype_extent,
                    rank*recvcount, recvtype);
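    /* Illustration (editorial): with comm_size = 4 and rank = 1, the send
     * blocks [0 1 2 3] end up in recvbuf as [1 2 3 0], i.e. shifted up by
     * rank blocks. */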
    /* Input data is now stored in recvbuf with datatype recvtype */
    /* Now do Phase 2, the communication phase. It takes
       ceiling(lg p) steps. In each step i, each process sends to rank+2^i
       and receives from rank-2^i, and exchanges all data blocks
       whose ith bit is 1. */
    /* allocate displacements array for the indexed datatype used in
       each communication step */
    displs = (int*)malloc(comm_size * sizeof(int));

    pof2 = 1;
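    /* Illustration (editorial): for comm_size = 8, the step with pof2=1
     * exchanges blocks {1,3,5,7}, pof2=2 exchanges {2,3,6,7}, and pof2=4
     * exchanges {4,5,6,7} -- exactly the blocks whose corresponding bit of
     * the block index is set. */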
    while (pof2 < comm_size) {
        dst = (rank + pof2) % comm_size;
        src = (rank - pof2 + comm_size) % comm_size;

        /* Exchange all data blocks whose ith bit is 1;
           create an indexed datatype for the purpose */
        count = 0;
        for (block=1; block<comm_size; block++) {
            if (block & pof2) {
                displs[count] = block * recvcount;
                count++;
            }
        }
        mpi_errno = MPI_Type_create_indexed_block(count, recvcount, displs, recvtype, &newtype);
        mpi_errno = MPI_Type_commit(&newtype);

        position = 0;
        mpi_errno = MPI_Pack(recvbuf, 1, newtype, tmp_buf, pack_size,
                             &position, comm);

        getAmpiInstance(comm)->sendrecv(tmp_buf, position, MPI_PACKED, dst,
                                        MPI_ATA_TAG, recvbuf, 1, newtype,
                                        src, MPI_ATA_TAG, comm, &status);

        mpi_errno = MPI_Type_free(&newtype);
        pof2 *= 2;
    }
    free(displs);
    free(tmp_buf);
    /* Rotate blocks in recvbuf upwards by (rank + 1) blocks. Need
     * a temporary buffer of the same size as recvbuf. */

    /* get true extent of recvtype */
    mpi_errno = MPI_Type_get_true_extent(recvtype, &recvtype_true_lb,
                                         &recvtype_true_extent);

    recvbuf_extent = recvcount * comm_size *
                     (MAX(recvtype_true_extent, recvtype_extent));
    tmp_buf = malloc(recvbuf_extent);

    /* adjust for potential negative lower bound in datatype */
    tmp_buf = (void *)((char*)tmp_buf - recvtype_true_lb);

    MPICH_Localcopy((char *) recvbuf + (rank+1)*recvcount*recvtype_extent,
                    (comm_size - rank - 1)*recvcount, recvtype, tmp_buf,
                    (comm_size - rank - 1)*recvcount, recvtype);
    MPICH_Localcopy(recvbuf, (rank+1)*recvcount, recvtype,
                    (char *) tmp_buf + (comm_size-rank-1)*recvcount*recvtype_extent,
                    (rank+1)*recvcount, recvtype);
    /* Blocks are in the reverse order now (comm_size-1 to 0).
     * Reorder them to (0 to comm_size-1) and store them in recvbuf. */
    for (i=0; i<comm_size; i++)
        MPICH_Localcopy((char *) tmp_buf + i*recvcount*recvtype_extent,
                        recvcount, recvtype,
                        (char *) recvbuf + (comm_size-i-1)*recvcount*recvtype_extent,
                        recvcount, recvtype);

    free((char*)tmp_buf + recvtype_true_lb);

    return MPI_SUCCESS;
}
/////////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////////
int AMPI_Alltoall_medium(void *sendbuf, int sendcount, MPI_Datatype sendtype,
                         void *recvbuf, int recvcount, MPI_Datatype recvtype,
                         MPI_Comm comm)
{
    int comm_size, i;
    MPI_Aint sendtype_extent, recvtype_extent;
    int mpi_errno=MPI_SUCCESS, dst, rank, nbytes;
    int sendtype_size;
    MPI_Request *reqarray;
    MPI_Status *starray;
    if (sendcount == 0) return MPI_SUCCESS;

    MPI_Comm_rank(comm, &rank);
    MPI_Comm_size(comm, &comm_size);

    /* Get extent of send and recv types */
    MPID_Datatype_get_extent_macro(recvtype, recvtype_extent);
    MPID_Datatype_get_extent_macro(sendtype, sendtype_extent);
    MPID_Datatype_get_size_macro(sendtype, sendtype_size);
    nbytes = sendtype_size * sendcount;
    /* Medium-size message. Use isend/irecv with scattered destinations */
    reqarray = (MPI_Request *) malloc(2*comm_size*sizeof(MPI_Request));
    if (!reqarray)
        return MPI_ERR_OTHER;

    starray = (MPI_Status *) malloc(2*comm_size*sizeof(MPI_Status));
    if (!starray) {
        free(reqarray);
        return MPI_ERR_OTHER;
    }
    /* do the communication -- post all sends and receives: */
    ampi *ptr = getAmpiInstance(comm);
    for ( i=0; i<comm_size; i++ ) {
        dst = (rank+i) % comm_size;
        ptr->irecv((char *)recvbuf + dst*recvcount*recvtype_extent, recvcount, recvtype, dst,
                   MPI_ATA_TAG, comm, &reqarray[i]);
    }
    for ( i=0; i<comm_size; i++ ) {
        dst = (rank+i) % comm_size;
        /*mpi_errno = AMPI_Isend((char *)sendbuf + dst*sendcount*sendtype_extent,
                                 sendcount, sendtype, dst, MPI_ATA_TAG, comm, &reqarray[i+comm_size]);*/
        ptr->send(MPI_ATA_TAG, getAmpiInstance(comm)->getRank(comm),
                  (char *)sendbuf + dst*sendcount*sendtype_extent,
                  sendcount, sendtype, dst, comm);
        reqarray[i+comm_size] = MPI_REQUEST_NULL;
    }
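    /* Note (editorial): the blocking ptr->send() above stands in for the
     * commented-out AMPI_Isend; its slot in reqarray is set to
     * MPI_REQUEST_NULL so the Waitall below effectively waits only on the
     * posted irecvs. */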
    /* ... then wait for *all* of them to finish: */
    mpi_errno = AMPI_Waitall(2*comm_size, reqarray, starray);

    /* --BEGIN ERROR HANDLING-- */
    // if (mpi_errno == MPI_ERR_IN_STATUS) {
    //     for (int j=0; j<2*comm_size; j++) {
    //         if (starray[j].MPI_ERROR != MPI_SUCCESS)
    //             mpi_errno = starray[j].MPI_ERROR;
    //     }
    // }
    /* --END ERROR HANDLING-- */

    free(starray);
    free(reqarray);
    return mpi_errno;
}
/////////////////////////////////////////////////////////////////////////////////////////////////////
// MPICH OLD VERSION -- coming soon, once I figure out how it worked
/////////////////////////////////////////////////////////////////////////////////////////////////////