From 0448668998c9ab22595956ac634a5d2742f4b44d Mon Sep 17 00:00:00 2001 From: Karthik Senthil Date: Fri, 11 Nov 2016 14:47:06 -0600 Subject: [PATCH] Bug #961: AMPI support for cancelling recv requests Change-Id: I5afcb4a576fa658533234a6760417d7822fa0a54 --- src/libs/ck-libs/ampi/ampi.C | 68 ++++++++++++++++++++++++++++++++-------- src/libs/ck-libs/ampi/ampi.h | 4 +-- src/libs/ck-libs/ampi/ampif.h | 1 + src/libs/ck-libs/ampi/ampiimpl.h | 20 +++++++++++- 4 files changed, 77 insertions(+), 16 deletions(-) diff --git a/src/libs/ck-libs/ampi/ampi.C b/src/libs/ck-libs/ampi/ampi.C index 9ee3bbc255..e238beacbe 100644 --- a/src/libs/ck-libs/ampi/ampi.C +++ b/src/libs/ck-libs/ampi/ampi.C @@ -2582,6 +2582,7 @@ int ampi::recv(int t, int s, void* buf, int count, MPI_Datatype type, MPI_Comm c sts->MPI_TAG = MPI_ANY_TAG; sts->MPI_COMM = comm; sts->MPI_LENGTH = 0; + sts->MPI_CANCEL = 0; return 0; } #if CMK_TRACE_ENABLED && CMK_PROJECTOR @@ -2623,6 +2624,7 @@ int ampi::recv(int t, int s, void* buf, int count, MPI_Datatype type, MPI_Comm c if (sts) { sts->MPI_COMM = msg->getComm(comm); sts->MPI_LENGTH = msg->getLength(); + sts->MPI_CANCEL = 0; } dis->processAmpiMsg(msg, buf, type, count); @@ -2665,6 +2667,7 @@ void ampi::probe(int t, int s, MPI_Comm comm, MPI_Status *sts) if (sts) { sts->MPI_COMM = msg->getComm(comm); sts->MPI_LENGTH = msg->getLength(); + sts->MPI_CANCEL = 0; } #if CMK_BIGSIM_CHARM @@ -2682,6 +2685,7 @@ int ampi::iprobe(int t, int s, MPI_Comm comm, MPI_Status *sts) if (sts) { sts->MPI_COMM = msg->getComm(comm); sts->MPI_LENGTH = msg->getLength(); + sts->MPI_CANCEL = 0; } return 1; } @@ -4312,6 +4316,13 @@ int IReq::wait(MPI_Status *sts){ dis->block(); dis = getAmpiInstance(comm); + if (cancelled) { + sts->MPI_CANCEL = 1; + statusIreq = true; + dis->resumeOnRecv = false; + return 0; + } + #if CMK_BIGSIM_CHARM //Because of the out-of-core emulation, this pointer is changed after in-out //memory operation. So we need to return from this function and do the while loop @@ -4330,6 +4341,7 @@ int IReq::wait(MPI_Status *sts){ sts->MPI_SOURCE = src; sts->MPI_COMM = comm; sts->MPI_LENGTH = length; + sts->MPI_CANCEL = 0; } return 0; @@ -4364,6 +4376,7 @@ int RednReq::wait(MPI_Status *sts){ sts->MPI_TAG = tag; sts->MPI_SOURCE = src; sts->MPI_COMM = comm; + sts->MPI_CANCEL = 0; } return 0; } @@ -4397,6 +4410,7 @@ int GatherReq::wait(MPI_Status *sts){ sts->MPI_TAG = tag; sts->MPI_SOURCE = src; sts->MPI_COMM = comm; + sts->MPI_CANCEL = 0; } return 0; } @@ -4430,6 +4444,7 @@ int GathervReq::wait(MPI_Status *sts){ sts->MPI_TAG = tag; sts->MPI_SOURCE = src; sts->MPI_COMM = comm; + sts->MPI_CANCEL = 0; } return 0; } @@ -4446,6 +4461,7 @@ int SendReq::wait(MPI_Status *sts){ AMPI_DEBUG("SendReq::wait has resumed\n"); if (sts) { sts->MPI_COMM = comm; + sts->MPI_CANCEL = 0; } return 0; } @@ -4458,6 +4474,7 @@ int SsendReq::wait(MPI_Status *sts){ } if (sts) { sts->MPI_COMM = comm; + sts->MPI_CANCEL = 0; } return 0; } @@ -4735,20 +4752,34 @@ bool PersReq::itest(MPI_Status *sts){ } bool IReq::test(MPI_Status *sts){ - if (statusIreq && sts) { - sts->MPI_COMM = comm; - sts->MPI_LENGTH = length; - } - else { - getAmpiInstance(comm)->yield(); + if (sts) { + if (cancelled) { + sts->MPI_CANCEL = 1; + statusIreq = true; + } + else if (statusIreq) { + sts->MPI_COMM = comm; + sts->MPI_LENGTH = length; + sts->MPI_CANCEL = 0; + } + else { + getAmpiInstance(comm)->yield(); + } } return statusIreq; } bool IReq::itest(MPI_Status *sts){ - if (statusIreq && sts) { - sts->MPI_COMM = comm; - sts->MPI_LENGTH = length; + if (sts) { + if (cancelled) { + sts->MPI_CANCEL = 1; + statusIreq = true; + } + else if (statusIreq) { + sts->MPI_COMM = comm; + sts->MPI_LENGTH = length; + sts->MPI_CANCEL = 0; + } } return statusIreq; } @@ -5044,21 +5075,32 @@ int AMPI_Request_free(MPI_Request *request){ CDECL int AMPI_Cancel(MPI_Request *request){ AMPIAPI("AMPI_Cancel"); - return AMPI_Request_free(request); + if(*request == MPI_REQUEST_NULL) return MPI_SUCCESS; + checkRequest(*request); + AmpiRequestList* reqs = getReqs(); + AmpiRequest& req = *(*reqs)[*request]; + if(req.getType() == MPI_I_REQ) { + req.cancel(); + return MPI_SUCCESS; + } + else { + return ampiErrhandler("AMPI_Cancel", MPI_ERR_REQUEST); + } } CDECL int AMPI_Test_cancelled(MPI_Status* status, int* flag) { AMPIAPI("AMPI_Test_cancelled"); - /* FIXME: always returns success */ - *flag = 1; + // NOTE : current implementation requires AMPI_{Wait,Test}{any,some,all} + // to be invoked before AMPI_Test_cancelled + *flag = status->MPI_CANCEL; return MPI_SUCCESS; } CDECL int AMPI_Status_set_cancelled(MPI_Status *status, int flag){ AMPIAPI("AMPI_Status_set_cancelled"); - /* AMPI_Test_cancelled always returns true */ + status->MPI_CANCEL = flag; return MPI_SUCCESS; } diff --git a/src/libs/ck-libs/ampi/ampi.h b/src/libs/ck-libs/ampi/ampi.h index 089750cbde..583d90cba9 100644 --- a/src/libs/ck-libs/ampi/ampi.h +++ b/src/libs/ck-libs/ampi/ampi.h @@ -256,11 +256,11 @@ extern MPI_Comm MPI_COMM_UNIVERSE[MPI_MAX_COMM_WORLDS]; struct AmpiMsg; typedef int MPI_Request; typedef struct { - int MPI_TAG, MPI_SOURCE, MPI_COMM, MPI_LENGTH, MPI_ERROR; /* FIXME: MPI_ERROR is never used */ + int MPI_TAG, MPI_SOURCE, MPI_COMM, MPI_LENGTH, MPI_ERROR, MPI_CANCEL; /* FIXME: MPI_ERROR is never used */ struct AmpiMsg *msg; } MPI_Status; -#define stsempty(sts) (sts).MPI_TAG=(sts).MPI_SOURCE=(sts).MPI_COMM=(sts).MPI_LENGTH=0 +#define stsempty(sts) (sts).MPI_TAG=(sts).MPI_SOURCE=(sts).MPI_COMM=(sts).MPI_LENGTH=(sts).MPI_CANCEL=0 #define MPI_STATUS_IGNORE (MPI_Status *)0 #define MPI_STATUSES_IGNORE (MPI_Status *)0 diff --git a/src/libs/ck-libs/ampi/ampif.h b/src/libs/ck-libs/ampi/ampif.h index f36c365c74..a2ab49e1f4 100644 --- a/src/libs/ck-libs/ampi/ampif.h +++ b/src/libs/ck-libs/ampi/ampif.h @@ -164,6 +164,7 @@ integer, parameter :: MPI_TAG = 1 integer, parameter :: MPI_SOURCE = 2 integer, parameter :: MPI_COMM = 3 + integer, parameter :: MPI_ERROR = 5 integer, dimension(MPI_STATUS_SIZE) :: MPI_STATUS_IGNORE integer, dimension(MPI_STATUS_SIZE) :: MPI_STATUSES_IGNORE diff --git a/src/libs/ck-libs/ampi/ampiimpl.h b/src/libs/ck-libs/ampi/ampiimpl.h index 01a9876913..563a3fb215 100644 --- a/src/libs/ck-libs/ampi/ampiimpl.h +++ b/src/libs/ck-libs/ampi/ampiimpl.h @@ -618,6 +618,10 @@ class AmpiRequest { /// returning a valid MPI error code. virtual int wait(MPI_Status *sts) =0; + /// Mark this request for cancellation. + /// Supported only for IReq requests + virtual void cancel() =0; + /// Receive an AmpiMsg virtual void receive(ampi *ptr, AmpiMsg *msg) = 0; @@ -668,6 +672,7 @@ class PersReq : public AmpiRequest { bool itest(MPI_Status *sts); void complete(MPI_Status *sts); int wait(MPI_Status *sts); + void cancel() {} void receive(ampi *ptr, AmpiMsg *msg) {} void receive(ampi *ptr, CkReductionMsg *msg) {} inline int getType(void) const { return MPI_PERS_REQ; } @@ -681,9 +686,10 @@ class PersReq : public AmpiRequest { class IReq : public AmpiRequest { public: int length; // recv'ed length + bool cancelled; // track if request is cancelled IReq(void *buf_, int count_, MPI_Datatype type_, int src_, int tag_, MPI_Comm comm_){ buf=buf_; count=count_; type=type_; src=src_; tag=tag_; - comm=comm_; isvalid=true; length=0; + comm=comm_; isvalid=true; length=0; cancelled=false; } IReq(){} ~IReq(){} @@ -691,6 +697,11 @@ class IReq : public AmpiRequest { bool itest(MPI_Status *sts); void complete(MPI_Status *sts); int wait(MPI_Status *sts); + void cancel() { + if (!statusIreq) { + cancelled = true; + } + } inline int getType(void) const { return MPI_I_REQ; } void receive(ampi *ptr, AmpiMsg *msg); void receive(ampi *ptr, CkReductionMsg *msg) {} @@ -714,6 +725,7 @@ class RednReq : public AmpiRequest { bool itest(MPI_Status *sts); void complete(MPI_Status *sts); int wait(MPI_Status *sts); + void cancel() {} inline int getType(void) const { return MPI_REDN_REQ; } void receive(ampi *ptr, AmpiMsg *msg) {} void receive(ampi *ptr, CkReductionMsg *msg); @@ -736,6 +748,7 @@ class GatherReq : public AmpiRequest { bool itest(MPI_Status *sts); void complete(MPI_Status *sts); int wait(MPI_Status *sts); + void cancel() {} inline int getType(void) const { return MPI_GATHER_REQ; } void receive(ampi *ptr, AmpiMsg *msg) {} void receive(ampi *ptr, CkReductionMsg *msg); @@ -763,6 +776,7 @@ class GathervReq : public AmpiRequest { bool itest(MPI_Status *sts); void complete(MPI_Status *sts); int wait(MPI_Status *sts); + void cancel() {} inline int getType(void) const { return MPI_GATHERV_REQ; } void receive(ampi *ptr, AmpiMsg *msg) {} void receive(ampi *ptr, CkReductionMsg *msg); @@ -784,6 +798,7 @@ class SendReq : public AmpiRequest { bool itest(MPI_Status *sts); void complete(MPI_Status *sts); int wait(MPI_Status *sts); + void cancel() {} void receive(ampi *ptr, AmpiMsg *msg) {} void receive(ampi *ptr, CkReductionMsg *msg) {} inline int getType(void) const { return MPI_SEND_REQ; } @@ -804,6 +819,7 @@ class SsendReq : public AmpiRequest { bool itest(MPI_Status *sts); void complete(MPI_Status *sts); int wait(MPI_Status *sts); + void cancel() {} void receive(ampi *ptr, AmpiMsg *msg) {} void receive(ampi *ptr, CkReductionMsg *msg) {} inline int getType(void) const { return MPI_SSEND_REQ; } @@ -821,6 +837,7 @@ class GPUReq : public AmpiRequest { bool itest(MPI_Status *sts); void complete(MPI_Status *sts); int wait(MPI_Status *sts); + void cancel() {} void receive(ampi *ptr, AmpiMsg *msg); void receive(ampi *ptr, CkReductionMsg *msg) {} void setComplete(); @@ -844,6 +861,7 @@ class IATAReq : public AmpiRequest { bool itest(MPI_Status *sts); void complete(MPI_Status *sts); int wait(MPI_Status *sts); + void cancel() {} void receive(ampi *ptr, AmpiMsg *msg) {} void receive(ampi *ptr, CkReductionMsg *msg) {} inline int getCount(void) const { return elmcount; } -- 2.11.4.GIT