From e78a01891e2adb63cd82f2d91d5fdadd315aff11 Mon Sep 17 00:00:00 2001 From: Sam White Date: Wed, 9 Nov 2016 10:26:58 -0600 Subject: [PATCH] AMPI: drop unnecessary srcIdx field from the message envelope We now use the srcRank to do all message matching. This means that the sequencing queue can be proportional to the size of the comm, rather than always being based on the size of MPI_COMM_WORLD. Change-Id: I2bf444d2643a9c985e78de262fe222cb4f460e3c --- src/libs/ck-libs/ampi/ampi.C | 64 ++++++++++++++++-------------------- src/libs/ck-libs/ampi/ampiOneSided.C | 6 ++-- src/libs/ck-libs/ampi/ampiimpl.h | 33 +++++++++---------- 3 files changed, 47 insertions(+), 56 deletions(-) diff --git a/src/libs/ck-libs/ampi/ampi.C b/src/libs/ck-libs/ampi/ampi.C index 968873ca53..8a5e08058d 100644 --- a/src/libs/ck-libs/ampi/ampi.C +++ b/src/libs/ck-libs/ampi/ampi.C @@ -1704,8 +1704,7 @@ ampi::ampi(CkArrayID parent_,const ampiCommStruct &s):parentProxy(parent_) msgs = AmmNew(); posted_ireqs = AmmNew(); - seqEntries=parent->ckGetArraySize(); - oorder.init (seqEntries); + oorder.init(myComm.getSize()); } ampi::ampi(CkMigrateMessage *msg):CBase_ampi(msg) @@ -2312,8 +2311,8 @@ void ampi::ssend_ack(int sreq_idx){ void ampi::generic(AmpiMsg* msg) { MSG_ORDER_DEBUG( - CkPrintf("AMPI vp %d arrival: tag=%d, src=%d, comm=%d (from %d, seq %d) resumeOnRecv %d\n", - thisIndex, msg->getTag(), msg->getSrcRank(), msg->getComm(this->getComm()), msg->getSrcIdx(), msg->getSeq(), resumeOnRecv); + CkPrintf("AMPI vp %d arrival: tag=%d, src=%d, comm=%d (seq %d) resumeOnRecv %d\n", + thisIndex, msg->getTag(), msg->getSrcRank(), msg->getComm(this->getComm()), msg->getSeq(), resumeOnRecv); ) #if CMK_BIGSIM_CHARM TRACE_BG_ADD_TAG("AMPI_generic"); @@ -2322,15 +2321,16 @@ void ampi::generic(AmpiMsg* msg) int sync = UsrToEnv(msg)->getRef(); int srcIdx; - if (sync) srcIdx = msg->getSrcIdx(); + if (sync) srcIdx = getIndexForRank(msg->getSrcRank()); if(msg->getSeq() != -1) { - int srcIdx=msg->getSrcIdx(); - int 
n=oorder.put(srcIdx,msg); + // If message was sent over MPI_COMM_SELF, srcRank needs to be this rank in MPI_COMM_WORLD: + int srcRank = (msg->getComm(this->getComm()) == MPI_COMM_SELF) ? this->getRank(MPI_COMM_WORLD) : msg->getSrcRank(); + int n=oorder.put(srcRank,msg); if (n>0) { // This message was in-order inorder(msg); if (n>1) { // It enables other, previously out-of-order messages - while((msg=oorder.getOutOfOrder(srcIdx))!=0) { + while((msg=oorder.getOutOfOrder(srcRank))!=0) { inorder(msg); } } @@ -2355,8 +2355,8 @@ inline static AmpiRequestList *getReqs(void); void ampi::inorder(AmpiMsg* msg) { MSG_ORDER_DEBUG( - CkPrintf("AMPI vp %d inorder: tag=%d, src=%d, comm=%d (from %d, seq %d)\n", - thisIndex, msg->getTag(), msg->getSrcRank(), msg->getComm(this->getComm()), msg->getSrcIdx(), msg->getSeq()); + CkPrintf("AMPI vp %d inorder: tag=%d, src=%d, comm=%d (seq %d)\n", + thisIndex, msg->getTag(), msg->getSrcRank(), msg->getComm(this->getComm()), msg->getSeq()); ) // check posted recvs @@ -2420,7 +2420,7 @@ AmpiMsg *ampi::makeAmpiMsg(int destIdx,int t,int sRank,const void *buf,int count int seq = -1; if (destIdx>=0 && destcomm<=MPI_COMM_WORLD && t<=MPI_ATA_SEQ_TAG) //Not cross-module: set seqno seq = oorder.nextOutgoing(destIdx); - AmpiMsg *msg = new (len, 0) AmpiMsg(seq, t, sIdx, sRank, len, destcomm); + AmpiMsg *msg = new (len, 0) AmpiMsg(seq, t, sRank, len, destcomm); if (sync) UsrToEnv(msg)->setRef(sync); ddt->serialize((char*)buf, msg->getData(), count, 1); return msg; @@ -2455,7 +2455,7 @@ void ampi::send(int t, int sRank, const void* buf, int count, MPI_Datatype type, void ampi::sendraw(int t, int sRank, void* buf, int len, CkArrayID aid, int idx) { - AmpiMsg *msg = new (len, 0) AmpiMsg(-1, t, -1, sRank, len); + AmpiMsg *msg = new (len, 0) AmpiMsg(-1, t, sRank, len); memcpy(msg->getData(), buf, len); CProxy_ampi pa(aid); pa[idx].generic(msg); @@ -2701,40 +2701,34 @@ int ampi::iprobe(int t, int s, MPI_Comm comm, MPI_Status *sts) return 0; } - -const int 
MPI_BCAST_COMM=MPI_COMM_WORLD+1000; - void ampi::bcast(int root, void* buf, int count, MPI_Datatype type, MPI_Comm destcomm) { - const ampiCommStruct &dest=comm2CommStruct(destcomm); - int rootIdx=dest.getIndexForRank(root); - if(rootIdx==thisIndex) { + if (root==getRank(destcomm)) { #if (defined(_FAULT_MLOG_) || defined(_FAULT_CAUSAL_)) CpvAccess(_currentObj) = this; #endif - thisProxy.generic(makeAmpiMsg(-1,MPI_BCAST_TAG,0, buf,count,type,destcomm)); + thisProxy.generic(makeAmpiMsg(-1, MPI_BCAST_TAG, root, buf, count, type, destcomm)); } - if(-1==recv(MPI_BCAST_TAG,0, buf,count,type, MPI_BCAST_COMM)) CkAbort("AMPI> Error in broadcast"); + + if (-1==recv(MPI_BCAST_TAG, root, buf, count, type, destcomm)) CkAbort("AMPI> Error in broadcast"); } void ampi::ibcast(int root, void* buf, int count, MPI_Datatype type, MPI_Comm destcomm, MPI_Request* request) { - const ampiCommStruct &dest=comm2CommStruct(destcomm); - int rootIdx=dest.getIndexForRank(root); - if(rootIdx==thisIndex){ + if (root==getRank(destcomm)) { #if (defined(_FAULT_MLOG_) || defined(_FAULT_CAUSAL_)) CpvAccess(_currentObj) = this; #endif - thisProxy.generic(makeAmpiMsg(-1, MPI_BCAST_TAG, 0, buf, count, type, destcomm)); + thisProxy.generic(makeAmpiMsg(-1, MPI_BCAST_TAG, root, buf, count, type, destcomm)); } // use an IReq to non-block the caller and get a request ptr - *request = postReq(new IReq(buf, count, type, rootIdx, MPI_BCAST_TAG, MPI_BCAST_COMM)); + *request = postReq(new IReq(buf, count, type, root, MPI_BCAST_TAG, destcomm)); } void ampi::bcastraw(void* buf, int len, CkArrayID aid) { - AmpiMsg *msg = new (len, 0) AmpiMsg(-1, MPI_BCAST_TAG, -1, 0, len); + AmpiMsg *msg = new (len, 0) AmpiMsg(-1, MPI_BCAST_TAG, 0, len); memcpy(msg->getData(), buf, len); CProxy_ampi pa(aid); pa.generic(msg); @@ -2748,7 +2742,7 @@ AmpiMsg* ampi::Alltoall_RemoteIget(MPI_Aint disp, int cnt, MPI_Datatype type, in unit = ddt->getSize(1); int totalsize = unit*cnt; - AmpiMsg *msg = new (totalsize, 0) AmpiMsg(-1, 
MPI_ATA_TAG, -1, thisIndex,totalsize); + AmpiMsg *msg = new (totalsize, 0) AmpiMsg(-1, MPI_ATA_TAG, thisIndex,totalsize); char* addr = (char*)Alltoallbuff+disp*unit; ddt->serialize(msg->getData(), addr, cnt, (-1)); return msg; @@ -2788,9 +2782,9 @@ int MPI_type_null_delete_fn(MPI_Datatype type, int keyval, void *attr, void *ext return (MPI_SUCCESS); } -void AmpiSeqQ::init(int numP) +void AmpiSeqQ::init(int commSize) { - elements.init(numP); + elements.init(commSize); } AmpiSeqQ::~AmpiSeqQ () { @@ -2801,9 +2795,9 @@ void AmpiSeqQ::pup(PUP::er &p) { p|elements; } -void AmpiSeqQ::putOutOfOrder(int srcIdx, AmpiMsg *msg) +void AmpiSeqQ::putOutOfOrder(int srcRank, AmpiMsg *msg) { - AmpiOtherElement &el=elements[srcIdx]; + AmpiOtherElement &el=elements[srcRank]; #if CMK_ERROR_CHECKING if (msg->getSeq() < el.seqIncoming) CkAbort("AMPI Logic error: received late out-of-order message!\n"); @@ -2812,14 +2806,14 @@ void AmpiSeqQ::putOutOfOrder(int srcIdx, AmpiMsg *msg) el.nOut++; // We have another message in the out-of-order queue } -AmpiMsg *AmpiSeqQ::getOutOfOrder(int srcIdx) +AmpiMsg *AmpiSeqQ::getOutOfOrder(int srcRank) { - AmpiOtherElement &el=elements[srcIdx]; + AmpiOtherElement &el=elements[srcRank]; if (el.nOut==0) return 0; // No more out-of-order left. 
// Walk through our out-of-order queue, searching for our next message: for (int i=0;i<out.length();i++) { AmpiMsg *msg=out.deq(); - if (msg->getSrcIdx()==srcIdx && msg->getSeq()==el.seqIncoming) { + if (msg->getSrcRank()==srcRank && msg->getSeq()==el.seqIncoming) { el.seqIncoming++; el.nOut--; // We have one less message out-of-order return msg; @@ -3605,7 +3599,7 @@ int AMPI_Ibcast(void *buf, int count, MPI_Datatype type, int root, ampi* ptr = getAmpiInstance(comm); if(comm==MPI_COMM_SELF){ - *request = ptr->postReq(new IReq(buf, count, type, root, MPI_BCAST_TAG, MPI_BCAST_COMM), + *request = ptr->postReq(new IReq(buf, count, type, root, MPI_BCAST_TAG, comm), AMPI_REQ_COMPLETED); return MPI_SUCCESS; } diff --git a/src/libs/ck-libs/ampi/ampiOneSided.C b/src/libs/ck-libs/ampi/ampiOneSided.C index 94a7abcc8d..2561b11499 100644 --- a/src/libs/ck-libs/ampi/ampiOneSided.C +++ b/src/libs/ck-libs/ampi/ampiOneSided.C @@ -269,7 +269,7 @@ AmpiMsg* ampi::winRemoteGet(int orgcnt, MPI_Datatype orgtype, MPI_Aint targdisp, winobj->get(targaddr, orgcnt, orgunit, targdisp, targcnt, targunit); AMPI_DEBUG(" Rank[%d] get win [%d] \n", thisIndex, *(int*)(targaddr)); - AmpiMsg *msg = new (targtotalsize, 0) AmpiMsg(-1, MPI_RMA_TAG, -1, thisIndex, targtotalsize); + AmpiMsg *msg = new (targtotalsize, 0) AmpiMsg(-1, MPI_RMA_TAG, thisIndex, targtotalsize); tddt->serialize(targaddr, msg->getData(), targcnt, 1); return msg; } @@ -297,7 +297,7 @@ AmpiMsg* ampi::winRemoteIget(MPI_Aint orgdisp, int orgcnt, MPI_Datatype orgtype, AMPI_DEBUG(" Rank[%d] iget win [%d] \n", thisIndex, *(int*)(targaddr)); - AmpiMsg *msg = new (targtotalsize, 0) AmpiMsg(-1, MPI_RMA_TAG, -1, thisIndex, targtotalsize); + AmpiMsg *msg = new (targtotalsize, 0) AmpiMsg(-1, MPI_RMA_TAG, thisIndex, targtotalsize); char* targaddr = (char*)(winobj->baseAddr) + targdisp*targunit; tddt->serialize(targaddr, msg->getData(), targcnt, 1); @@ -412,7 +412,7 @@ AmpiMsg* ampi::winRemoteCompareAndSwap(int size, char* sorgaddr, char* compaddr, CkDDT_DataType *ddt = getDDT()->getType(type); char* 
targaddr = ((char*)(winobj->baseAddr)) + ddt->getSize(targdisp); - AmpiMsg *msg = new (size, 0) AmpiMsg(-1, MPI_RMA_TAG, -1, thisIndex, size); + AmpiMsg *msg = new (size, 0) AmpiMsg(-1, MPI_RMA_TAG, thisIndex, size); ddt->serialize(targaddr, msg->getData(), 1, 1); if (*targaddr == *compaddr) { diff --git a/src/libs/ck-libs/ampi/ampiimpl.h b/src/libs/ck-libs/ampi/ampiimpl.h index d1658bab48..3db5d466ad 100644 --- a/src/libs/ck-libs/ampi/ampiimpl.h +++ b/src/libs/ck-libs/ampi/ampiimpl.h @@ -979,7 +979,6 @@ class AmpiMsg : public CMessage_AmpiMsg { private: int seq; //Sequence number (for message ordering) int tag; //MPI tag - int srcIdx; //Array index of source int srcRank; //Communicator rank for source int length; //Number of bytes in this message public: @@ -992,10 +991,10 @@ class AmpiMsg : public CMessage_AmpiMsg { public: AmpiMsg(void) { data = NULL; } - AmpiMsg(int _s, int t, int sIdx,int sRank, int l) : - seq(_s), tag(t), srcIdx(sIdx), srcRank(sRank), length(l) {} - AmpiMsg(int _s, int t, int sIdx,int sRank, int l, MPI_Comm comm) : - seq(_s), tag(t), srcIdx(sIdx), srcRank(sRank), length(l) + AmpiMsg(int _s, int t, int sRank, int l) : + seq(_s), tag(t), srcRank(sRank), length(l) {} + AmpiMsg(int _s, int t, int sRank, int l, MPI_Comm comm) : + seq(_s), tag(t), srcRank(sRank), length(l) { //We do not store comm, since it can be gotten from the ampi instance. 
//The exception is messages for MPI_COMM_SELF: // We make tag negative if the message is for MPI_COMM_SELF, because @@ -1004,7 +1003,6 @@ class AmpiMsg : public CMessage_AmpiMsg { if (comm == MPI_COMM_SELF) tag *= (-1); } inline int getSeq(void) const { return seq; } - inline int getSrcIdx(void) const { return srcIdx; } inline int getSrcRank(void) const { return srcRank; } inline int getLength(void) const { return length; } inline char* getData(void) const { return data; } @@ -1022,17 +1020,16 @@ class AmpiMsg : public CMessage_AmpiMsg { } static AmpiMsg* pup(PUP::er &p, AmpiMsg *m) { - int seq, length, tag, srcIdx, srcRank; + int seq, length, tag, srcRank; if(p.isPacking() || p.isSizing()) { seq = m->seq; tag = m->tag; - srcIdx = m->srcIdx; srcRank = m->srcRank; length = m->length; } - p(seq); p(tag); p(srcIdx); p(srcRank); p(length); + p(seq); p(tag); p(srcRank); p(length); if(p.isUnpacking()) { - m = new (length, 0) AmpiMsg(seq, tag, srcIdx, srcRank, length); + m = new (length, 0) AmpiMsg(seq, tag, srcRank, length); } p(m->data, length); if(p.isDeleting()) { @@ -1072,11 +1069,11 @@ class AmpiSeqQ : private CkNoncopyable { CkMsgQ<AmpiMsg> out; // all out of order messages CkPagedVector<AmpiOtherElement> elements; // element info - void putOutOfOrder(int srcIdx, AmpiMsg *msg); + void putOutOfOrder(int srcRank, AmpiMsg *msg); public: AmpiSeqQ() {} - void init(int numP); + void init(int commSize); ~AmpiSeqQ (); void pup(PUP::er &p); @@ -1086,25 +1083,25 @@ public: /// If 1, this message can be immediately processed. /// If >1, this message can be immediately processed, /// and you should call "getOutOfOrder" repeatedly. 
- inline int put(int srcIdx, AmpiMsg *msg) { - AmpiOtherElement &el=elements[srcIdx]; + inline int put(int srcRank, AmpiMsg *msg) { + AmpiOtherElement &el=elements[srcRank]; if (msg->getSeq()==el.seqIncoming) { // In order: el.seqIncoming++; return 1+el.nOut; } else { // Out of order: stash message - putOutOfOrder(srcIdx, msg); + putOutOfOrder(srcRank, msg); return 0; } } /// Get an out-of-order message from the table. /// (in-order messages never go into the table) - AmpiMsg *getOutOfOrder(int p); + AmpiMsg *getOutOfOrder(int srcRank); /// Return the next outgoing sequence number, and increment it. - int nextOutgoing(int p) { - return elements[p].seqOutgoing++; + int nextOutgoing(int srcRank) { + return elements[srcRank].seqOutgoing++; } }; PUPmarshall(AmpiSeqQ) -- 2.11.4.GIT