From 46e3f92fc4c3f8f12d771eb33f866c748722f63b Mon Sep 17 00:00:00 2001 From: Sam White Date: Fri, 3 Feb 2017 14:41:02 -0600 Subject: [PATCH] AMPI: enforce message ordering of broadcasts Change-Id: I3e1aef9a2ca6f59b9f88a44ac500a7957372d61a --- src/libs/ck-libs/ampi/ampi.C | 24 ++++++++++++++++++------ src/libs/ck-libs/ampi/ampiimpl.h | 20 +++++++++++++++++++- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/libs/ck-libs/ampi/ampi.C b/src/libs/ck-libs/ampi/ampi.C index 7152772e35..a2a4b46dc3 100644 --- a/src/libs/ck-libs/ampi/ampi.C +++ b/src/libs/ck-libs/ampi/ampi.C @@ -2796,9 +2796,9 @@ void ampi::sendraw(int t, int sRank, void* buf, int len, CkArrayID aid, int idx) } CMK_REFNUM_TYPE ampi::getSeqNo(int destRank, MPI_Comm destcomm, int tag) { - int seqIdx = destRank; + int seqIdx = (destRank == AMPI_COLL_DEST) ? COLL_SEQ_IDX : destRank; CMK_REFNUM_TYPE seq = 0; - if (destRank>=0 && destcomm<=MPI_COMM_WORLD && tag<=MPI_ATA_SEQ_TAG) { //Not cross-module: set seqno + if (destcomm<=MPI_COMM_WORLD && tag<=MPI_BCAST_TAG) { //Not cross-module: set seqno seq = oorder.nextOutgoing(seqIdx); } return seq; @@ -3194,7 +3194,10 @@ void ampi::bcast(int root, void* buf, int count, MPI_Datatype type, MPI_Comm des #if (defined(_FAULT_MLOG_) || defined(_FAULT_CAUSAL_)) CpvAccess(_currentObj) = this; #endif - thisProxy.generic(makeAmpiMsg(0, MPI_BCAST_TAG, root, buf, count, type, destcomm)); + thisProxy.generic(makeAmpiMsg(AMPI_COLL_DEST, MPI_BCAST_TAG, root, buf, count, type, destcomm)); + } + else { // Non-root ranks need to increment the outgoing sequence number for collectives + oorder.incCollSeqOutgoing(); } if (-1==recv(MPI_BCAST_TAG, root, buf, count, type, destcomm)) CkAbort("AMPI> Error in broadcast"); @@ -3206,7 +3209,10 @@ int ampi::intercomm_bcast(int root, void* buf, int count, MPI_Datatype type, MPI #if (defined(_FAULT_MLOG_) || defined(_FAULT_CAUSAL_)) CpvAccess(_currentObj) = this; #endif - remoteProxy.generic(makeAmpiMsg(0, MPI_BCAST_TAG, getRank(), buf, count, type, intercomm)); + remoteProxy.generic(makeAmpiMsg(AMPI_COLL_DEST, MPI_BCAST_TAG, getRank(), buf, count, type, intercomm)); + } + else { // Non-root ranks need to increment the outgoing sequence number for collectives + oorder.incCollSeqOutgoing(); } if (root!=MPI_PROC_NULL && root!=MPI_ROOT) { @@ -3222,7 +3228,10 @@ void ampi::ibcast(int root, void* buf, int count, MPI_Datatype type, MPI_Comm de #if (defined(_FAULT_MLOG_) || defined(_FAULT_CAUSAL_)) CpvAccess(_currentObj) = this; #endif - thisProxy.generic(makeAmpiMsg(0, MPI_BCAST_TAG, root, buf, count, type, destcomm)); + thisProxy.generic(makeAmpiMsg(AMPI_COLL_DEST, MPI_BCAST_TAG, root, buf, count, type, destcomm)); + } + else { // Non-root ranks need to increment the outgoing sequence number for collectives + oorder.incCollSeqOutgoing(); } // call irecv to post an IReq and check for any pending messages @@ -3235,7 +3244,10 @@ int ampi::intercomm_ibcast(int root, void* buf, int count, MPI_Datatype type, MP #if (defined(_FAULT_MLOG_) || defined(_FAULT_CAUSAL_)) CpvAccess(_currentObj) = this; #endif - remoteProxy.generic(makeAmpiMsg(0, MPI_BCAST_TAG, getRank(), buf, count, type, intercomm)); + remoteProxy.generic(makeAmpiMsg(AMPI_COLL_DEST, MPI_BCAST_TAG, getRank(), buf, count, type, intercomm)); + } + else { // Non-root ranks need to increment the outgoing sequence number for collectives + oorder.incCollSeqOutgoing(); } if (root!=MPI_PROC_NULL && root!=MPI_ROOT) { diff --git a/src/libs/ck-libs/ampi/ampiimpl.h b/src/libs/ck-libs/ampi/ampiimpl.h index d14c1ec9e0..7d829845c8 100644 --- a/src/libs/ck-libs/ampi/ampiimpl.h +++ b/src/libs/ck-libs/ampi/ampiimpl.h @@ -885,6 +885,7 @@ extern int _mpi_nworlds; #define MPI_EPOCH_END_TAG MPI_TAG_UB_VALUE+12 #define AMPI_COLL_SOURCE 0 +#define AMPI_COLL_DEST -1 #define AMPI_COLL_COMM MPI_COMM_WORLD #define MPI_I_REQ 1 @@ -1347,6 +1348,8 @@ inline void pupFromBuf(const void *data,T &t) { PUP::fromMem p(data); p|t; } +#define COLL_SEQ_IDX -1 + class AmpiMsg : public CMessage_AmpiMsg { private: int ssendReq; //Index to the sender's request @@ -1371,7 +1374,15 @@ class AmpiMsg : public CMessage_AmpiMsg { { CkSetRefNum(this, _s); } inline CMK_REFNUM_TYPE getSeq(void) const { return UsrToEnv(this)->getRef(); } inline int getSsendReq(void) const { return ssendReq; } - inline int getSeqIdx(void) const { return srcRank; } + inline int getSeqIdx(void) const { + // seqIdx is srcRank, unless this message was part of a collective + if (tag >= MPI_BCAST_TAG && tag <= MPI_ATA_TAG) { + return COLL_SEQ_IDX; + } + else { + return srcRank; + } + } inline int getSrcRank(void) const { return srcRank; } inline int getLength(void) const { return length; } inline char* getData(void) const { return data; } @@ -1475,6 +1486,13 @@ public: /// Stash an out-of-order message void putOutOfOrder(int seqIdx, AmpiMsg *msg); + /// Increment the outgoing sequence number. + inline void incCollSeqOutgoing(void) { + CMK_REFNUM_TYPE& seqOutgoing = elements[COLL_SEQ_IDX].seqOutgoing; + seqOutgoing++; + if (seqOutgoing == 0) seqOutgoing++; + } + /// Return the next outgoing sequence number, and increment it. /// Handle wrap around of unsigned type CMK_REFNUM_TYPE. inline CMK_REFNUM_TYPE nextOutgoing(int destRank) { -- 2.11.4.GIT