AMPI: store predefined ops and types per-process rather than per-rank
[charm.git] / src / libs / ck-libs / ampi / ampiimpl.h
blobf97ecf4923dcea4ed99388684220184120ea9bbf
1 #ifndef _AMPIIMPL_H
2 #define _AMPIIMPL_H
4 #include <string.h> /* for strlen */
5 #include <algorithm>
6 #include <numeric>
7 #include <forward_list>
8 #include <bitset>
10 #include "ampi.h"
11 #include "ddt.h"
12 #include "charm++.h"
14 using std::vector;
16 //Uncomment for debug print statements
17 #define AMPI_DEBUG(...) //CkPrintf(__VA_ARGS__)
/*
 * All MPI_* routines must be defined using the AMPI_API_IMPL macro.
 * All calls inside AMPI to MPI_* routines must use MPI_* as the name.
 * There are two reasons for this:
 *
 * 1. AMPI supports the PMPI interface only on Linux.
 *
 * 2. When AMPI is built on top of MPI, we rename the user's MPI_* calls as AMPI_*.
 */
28 #define STRINGIFY(a) #a
30 #if defined(__linux__)
31 #if CMK_CONVERSE_MPI
32 #define AMPI_API_IMPL(ret, name, ...) \
33 _Pragma(STRINGIFY(weak A##name)) \
34 _Pragma(STRINGIFY(weak AP##name = A##name)) \
35 CLINKAGE \
36 ret A##name(__VA_ARGS__)
37 #else
38 #define AMPI_API_IMPL(ret, name, ...) \
39 _Pragma(STRINGIFY(weak name)) \
40 _Pragma(STRINGIFY(weak P##name = name)) \
41 CLINKAGE \
42 ret name(__VA_ARGS__)
43 #endif
44 #else // not Linux (no PMPI support):
45 #if CMK_CONVERSE_MPI
46 #define AMPI_API_IMPL(ret, name, ...) \
47 CLINKAGE \
48 ret A##name(__VA_ARGS__)
49 #else
50 #define AMPI_API_IMPL(ret, name, ...) \
51 CLINKAGE \
52 ret name(__VA_ARGS__)
53 #endif
54 #endif
56 #if AMPIMSGLOG
57 #include "ckliststring.h"
58 static CkListString msgLogRanks;
59 static int msgLogWrite;
60 static int msgLogRead;
61 static char *msgLogFilename;
63 #if CMK_USE_ZLIB && 0
64 #include <zlib.h>
65 namespace PUP{
66 class zdisk : public er {
67 protected:
68 gzFile F;//Disk file to read from/write to
69 zdisk(unsigned int type,gzFile f):er(type),F(f) {}
70 zdisk(const zdisk &p); //You don't want to copy
71 void operator=(const zdisk &p); // You don't want to copy
73 //For seeking (pack/unpack in different orders)
74 virtual void impl_startSeek(seekBlock &s); /*Begin a seeking block*/
75 virtual int impl_tell(seekBlock &s); /*Give the current offset*/
76 virtual void impl_seek(seekBlock &s,int off); /*Seek to the given offset*/
79 //For packing to a disk file
80 class tozDisk : public zdisk {
81 protected:
82 //Generic bottleneck: pack n items of size itemSize from p.
83 virtual void bytes(void *p,int n,size_t itemSize,dataType t);
84 public:
85 //Write data to the given file pointer
86 // (must be opened for binary write)
87 // You must close the file yourself when done.
88 tozDisk(gzFile f):zdisk(IS_PACKING,f) {}
91 //For unpacking from a disk file
92 class fromzDisk : public zdisk {
93 protected:
94 //Generic bottleneck: unpack n items of size itemSize from p.
95 virtual void bytes(void *p,int n,size_t itemSize,dataType t);
96 public:
97 //Write data to the given file pointer
98 // (must be opened for binary read)
99 // You must close the file yourself when done.
100 fromzDisk(gzFile f):zdisk(IS_UNPACKING,f) {}
102 }; // namespace PUP
103 #endif
104 #endif // AMPIMSGLOG
106 /* AMPI sends messages inline to PE-local destination VPs if: BigSim is not being used and
107 * if tracing is not being used (see bug #1640 for more details on the latter). */
108 #ifndef AMPI_LOCAL_IMPL
109 #define AMPI_LOCAL_IMPL ( !CMK_BIGSIM_CHARM && !CMK_TRACE_ENABLED )
110 #endif
112 /* AMPI uses RDMA sends if BigSim is not being used and the underlying comm
113 * layer supports it (except for GNI, which has experimental RDMA support). */
114 #ifndef AMPI_RDMA_IMPL
115 #define AMPI_RDMA_IMPL ( !CMK_BIGSIM_CHARM && CMK_ONESIDED_IMPL && !CMK_CONVERSE_UGNI )
116 #endif
118 /* contiguous messages larger than or equal to this threshold are sent via RDMA */
119 #ifndef AMPI_RDMA_THRESHOLD_DEFAULT
120 #if CMK_USE_IBVERBS || CMK_OFI || CMK_CONVERSE_UGNI
121 #define AMPI_RDMA_THRESHOLD_DEFAULT 65536
122 #else
123 #define AMPI_RDMA_THRESHOLD_DEFAULT 32768
124 #endif
125 #endif
127 /* contiguous messages larger than or equal to this threshold that are being sent
128 * within a process are sent via RDMA. */
129 #ifndef AMPI_SMP_RDMA_THRESHOLD_DEFAULT
130 #define AMPI_SMP_RDMA_THRESHOLD_DEFAULT 16384
131 #endif
133 extern int AMPI_RDMA_THRESHOLD;
134 extern int AMPI_SMP_RDMA_THRESHOLD;
136 #define AMPI_ALLTOALL_THROTTLE 64
137 #define AMPI_ALLTOALL_SHORT_MSG 256
138 #if CMK_BIGSIM_CHARM
139 #define AMPI_ALLTOALL_LONG_MSG 4194304
140 #else
141 #define AMPI_ALLTOALL_LONG_MSG 32768
142 #endif
144 typedef void (*MPI_MigrateFn)(void);
/*
 * AMPI Message Matching (Amm) Interface:
 * messages are matched on 2 ints: [tag, src]
 */
150 #define AMM_TAG 0
151 #define AMM_SRC 1
152 #define AMM_NTAGS 2
154 // Number of AmmEntry<T>'s in AmmEntryPool
155 #ifndef AMPI_AMM_POOL_SIZE
156 #define AMPI_AMM_POOL_SIZE 32
157 #endif
159 class AmpiRequestList;
161 typedef void (*AmmPupMessageFn)(PUP::er& p, void **msg);
163 template <class T>
164 class AmmEntry {
165 public:
166 int tags[AMM_NTAGS]; // [tag, src]
167 AmmEntry<T>* next;
168 T msg; // T is either an AmpiRequest* or an AmpiMsg*
169 AmmEntry(T m) noexcept { tags[AMM_TAG] = m->getTag(); tags[AMM_SRC] = m->getSrcRank(); next = NULL; msg = m; }
170 AmmEntry(int tag, int src, T m) noexcept { tags[AMM_TAG] = tag; tags[AMM_SRC] = src; next = NULL; msg = m; }
171 AmmEntry() = default;
172 ~AmmEntry() = default;
175 template <class T>
176 class Amm {
177 public:
178 AmmEntry<T>* first;
179 AmmEntry<T>** lasth;
181 private:
182 int startIdx;
183 std::bitset<AMPI_AMM_POOL_SIZE> validEntries;
184 std::array<AmmEntry<T>, AMPI_AMM_POOL_SIZE> entryPool;
186 public:
187 Amm() noexcept : first(NULL), lasth(&first), startIdx(0) { validEntries.reset(); }
188 ~Amm() = default;
189 inline AmmEntry<T>* newEntry(int tag, int src, T msg) noexcept {
190 if (validEntries.all()) {
191 return new AmmEntry<T>(tag, src, msg);
192 } else {
193 for (int i=startIdx; i<validEntries.size(); i++) {
194 if (!validEntries[i]) {
195 validEntries[i] = 1;
196 AmmEntry<T>* ent = new (&entryPool[i]) AmmEntry<T>(tag, src, msg);
197 startIdx = i+1;
198 return ent;
201 CkAbort("AMPI> failed to find a free entry in pool!");
202 return NULL;
205 inline AmmEntry<T>* newEntry(T msg) noexcept {
206 if (validEntries.all()) {
207 return new AmmEntry<T>(msg);
208 } else {
209 for (int i=startIdx; i<validEntries.size(); i++) {
210 if (!validEntries[i]) {
211 validEntries[i] = 1;
212 AmmEntry<T>* ent = new (&entryPool[i]) AmmEntry<T>(msg);
213 startIdx = i+1;
214 return ent;
217 CkAbort("AMPI> failed to find a free entry in pool!");
218 return NULL;
221 inline void deleteEntry(AmmEntry<T> *ent) noexcept {
222 if (ent >= &entryPool.front() && ent <= &entryPool.back()) {
223 int idx = (int)((intptr_t)ent - (intptr_t)&entryPool.front()) / sizeof(AmmEntry<T>);
224 validEntries[idx] = 0;
225 startIdx = std::min(idx, startIdx);
226 } else {
227 delete ent;
230 void freeAll() noexcept;
231 void flushMsgs() noexcept;
232 inline bool match(const int tags1[AMM_NTAGS], const int tags2[AMM_NTAGS]) const noexcept;
233 inline void put(T msg) noexcept;
234 inline void put(int tag, int src, T msg) noexcept;
235 inline T get(int tag, int src, int* rtags=NULL) noexcept;
236 inline T probe(int tag, int src, int* rtags) noexcept;
237 inline int size() const noexcept;
238 void pup(PUP::er& p, AmmPupMessageFn msgpup) noexcept;
241 PUPfunctionpointer(MPI_User_function*)
244 * OpStruct's are used to lookup an MPI_User_function* and check its commutativity.
245 * They are also used to create AmpiOpHeader's, which are transmitted in reductions
246 * that are user-defined or else lack an equivalent Charm++ reducer type.
248 class OpStruct {
249 public:
250 MPI_User_function* func;
251 bool isCommutative;
252 private:
253 bool isValid;
255 public:
256 OpStruct() = default;
257 OpStruct(MPI_User_function* f) noexcept : func(f), isCommutative(true), isValid(true) {}
258 OpStruct(MPI_User_function* f, bool c) noexcept : func(f), isCommutative(c), isValid(true) {}
259 void init(MPI_User_function* f, bool c) noexcept {
260 func = f;
261 isCommutative = c;
262 isValid = true;
264 bool isFree() const noexcept { return !isValid; }
265 void free() noexcept { isValid = false; }
266 void pup(PUP::er &p) {
267 p|func; p|isCommutative; p|isValid;
271 class AmpiOpHeader {
272 public:
273 MPI_User_function* func;
274 MPI_Datatype dtype;
275 int len;
276 int szdata;
277 AmpiOpHeader(MPI_User_function* f,MPI_Datatype d,int l,int szd) noexcept :
278 func(f),dtype(d),len(l),szdata(szd) { }
281 //------------------- added by YAN for one-sided communication -----------
282 /* the index is unique within a communicator */
283 class WinStruct{
284 public:
285 MPI_Comm comm;
286 int index;
288 private:
289 bool areRecvsPosted;
290 bool inEpoch;
291 vector<int> exposureRankList;
292 vector<int> accessRankList;
293 vector<MPI_Request> requestList;
295 public:
296 WinStruct() noexcept : comm(MPI_COMM_NULL), index(-1), areRecvsPosted(false), inEpoch(false) {
297 exposureRankList.clear(); accessRankList.clear(); requestList.clear();
299 WinStruct(MPI_Comm comm_, int index_) noexcept : comm(comm_), index(index_), areRecvsPosted(false), inEpoch(false) {
300 exposureRankList.clear(); accessRankList.clear(); requestList.clear();
302 void pup(PUP::er &p) noexcept {
303 p|comm; p|index; p|areRecvsPosted; p|inEpoch; p|exposureRankList; p|accessRankList; p|requestList;
305 void clearEpochAccess() noexcept {
306 accessRankList.clear(); inEpoch = false;
308 void clearEpochExposure() noexcept {
309 exposureRankList.clear(); areRecvsPosted = false; requestList.clear(); inEpoch=false;
311 vector<int>& getExposureRankList() noexcept {return exposureRankList;}
312 vector<int>& getAccessRankList() noexcept {return accessRankList;}
313 void setExposureRankList(vector<int> &tmpExposureRankList) noexcept {exposureRankList = tmpExposureRankList;}
314 void setAccessRankList(vector<int> &tmpAccessRankList) noexcept {accessRankList = tmpAccessRankList;}
315 vector<int>& getRequestList() noexcept {return requestList;}
316 bool AreRecvsPosted() const noexcept {return areRecvsPosted;}
317 void setAreRecvsPosted(bool setR) noexcept {areRecvsPosted = setR;}
318 bool isInEpoch() const noexcept {return inEpoch;}
319 void setInEpoch(bool arg) noexcept {inEpoch = arg;}
// One pending lock request: which rank asked, and for what kind of lock.
class lockQueueEntry {
 public:
  int requestRank;
  int lock_type;

  lockQueueEntry() = default;

  lockQueueEntry (int _requestRank, int _lock_type) noexcept
    : requestRank(_requestRank),
      lock_type(_lock_type)
  {}
};
331 typedef CkQ<lockQueueEntry *> LockQueue;
333 class ampiParent;
335 class win_obj {
336 public:
337 void *baseAddr;
338 MPI_Aint winSize;
339 int disp_unit;
340 MPI_Comm comm;
342 int owner; // Rank of owner of the lock, -1 if not locked
343 LockQueue lockQueue; // queue of waiting processors for the lock
344 // top of queue is the one holding the lock
345 // queue is empty if lock is not applied
346 std::string winName;
347 bool initflag;
349 vector<int> keyvals; // list of keyval attributes
351 void setName(const char *src) noexcept;
352 void getName(char *src,int *len) noexcept;
354 public:
355 void pup(PUP::er &p) noexcept;
357 win_obj() noexcept;
358 win_obj(const char *name, void *base, MPI_Aint size, int disp_unit, MPI_Comm comm) noexcept;
359 ~win_obj() noexcept;
361 int create(const char *name, void *base, MPI_Aint size, int disp_unit,
362 MPI_Comm comm) noexcept;
363 int free() noexcept;
365 vector<int>& getKeyvals() { return keyvals; }
367 int put(void *orgaddr, int orgcnt, int orgunit,
368 MPI_Aint targdisp, int targcnt, int targunit) noexcept;
370 int get(void *orgaddr, int orgcnt, int orgunit,
371 MPI_Aint targdisp, int targcnt, int targunit) noexcept;
372 int accumulate(void *orgaddr, int count, MPI_Aint targdisp, MPI_Datatype targtype,
373 MPI_Op op, ampiParent* pptr) noexcept;
375 int iget(int orgcnt, MPI_Datatype orgtype,
376 MPI_Aint targdisp, int targcnt, MPI_Datatype targtype) noexcept;
377 int igetWait(MPI_Request *req, MPI_Status *status) noexcept;
378 int igetFree(MPI_Request *req, MPI_Status *status) noexcept;
380 int fence() noexcept;
382 int lock(int requestRank, int lock_type) noexcept;
383 int unlock(int requestRank) noexcept;
385 int wait() noexcept;
386 int post() noexcept;
387 int start() noexcept;
388 int complete() noexcept;
390 void lockTopQueue() noexcept;
391 void enqueue(int requestRank, int lock_type) noexcept;
392 void dequeue() noexcept;
393 bool emptyQueue() noexcept;
395 //-----------------------End of code by YAN ----------------------
397 class KeyvalPair{
398 protected:
399 std::string key;
400 std::string val;
401 public:
402 KeyvalPair() = default;
403 KeyvalPair(const char* k, const char* v) noexcept;
404 ~KeyvalPair() = default;
405 void pup(PUP::er& p) noexcept {
406 p|key;
407 p|val;
409 friend class InfoStruct;
412 class InfoStruct{
413 CkPupPtrVec<KeyvalPair> nodes;
414 bool valid;
415 public:
416 InfoStruct() noexcept : valid(true) { }
417 void setvalid(bool valid_) noexcept { valid = valid_; }
418 bool getvalid() const noexcept { return valid; }
419 int set(const char* k, const char* v) noexcept;
420 int dup(InfoStruct& src) noexcept;
421 int get(const char* k, int vl, char*& v, int *flag) const noexcept;
422 int deletek(const char* k) noexcept;
423 int get_valuelen(const char* k, int* vl, int *flag) const noexcept;
424 int get_nkeys(int *nkeys) const noexcept;
425 int get_nthkey(int n,char* k) const noexcept;
426 void myfree() noexcept;
427 void pup(PUP::er& p) noexcept;
430 class CProxy_ampi;
431 class CProxyElement_ampi;
433 //Virtual class describing a virtual topology: Cart, Graph, DistGraph
434 class ampiTopology {
435 private:
436 vector<int> v; // dummy variable for const& returns from virtual functions
438 public:
439 virtual ~ampiTopology() noexcept {};
440 virtual void pup(PUP::er &p) noexcept =0;
441 virtual int getType() const noexcept =0;
442 virtual void dup(ampiTopology* topo) noexcept =0;
443 virtual const vector<int> &getnbors() const noexcept =0;
444 virtual void setnbors(const vector<int> &nbors_) noexcept =0;
446 virtual const vector<int> &getdims() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return v;}
447 virtual const vector<int> &getperiods() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return v;}
448 virtual int getndims() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return -1;}
449 virtual void setdims(const vector<int> &dims_) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
450 virtual void setperiods(const vector<int> &periods_) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
451 virtual void setndims(int ndims_) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
453 virtual int getnvertices() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return -1;}
454 virtual const vector<int> &getindex() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return v;}
455 virtual const vector<int> &getedges() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return v;}
456 virtual void setnvertices(int nvertices_) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
457 virtual void setindex(const vector<int> &index_) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
458 virtual void setedges(const vector<int> &edges_) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
460 virtual int getInDegree() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return -1;}
461 virtual const vector<int> &getSources() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return v;}
462 virtual const vector<int> &getSourceWeights() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return v;}
463 virtual int getOutDegree() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return -1;}
464 virtual const vector<int> &getDestinations() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return v;}
465 virtual const vector<int> &getDestWeights() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return v;}
466 virtual bool areSourcesWeighted() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return false;}
467 virtual bool areDestsWeighted() const noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class."); return false;}
468 virtual void setAreSourcesWeighted(bool val) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
469 virtual void setAreDestsWeighted(bool val) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
470 virtual void setInDegree(int degree) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
471 virtual void setSources(const vector<int> &sources) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
472 virtual void setSourceWeights(const vector<int> &sourceWeights) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
473 virtual void setOutDegree(int degree) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
474 virtual void setDestinations(const vector<int> &destinations) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
475 virtual void setDestWeights(const vector<int> &destWeights) noexcept {CkAbort("AMPI: instance of invalid Virtual Topology class.");}
478 class ampiCartTopology final : public ampiTopology {
479 private:
480 int ndims;
481 vector<int> dims, periods, nbors;
483 public:
484 ampiCartTopology() noexcept : ndims(-1) {}
486 void pup(PUP::er &p) noexcept {
487 p|ndims;
488 p|dims;
489 p|periods;
490 p|nbors;
493 inline int getType() const noexcept {return MPI_CART;}
494 inline void dup(ampiTopology* topo) noexcept {
495 CkAssert(topo->getType() == MPI_CART);
496 setndims(topo->getndims());
497 setdims(topo->getdims());
498 setperiods(topo->getperiods());
499 setnbors(topo->getnbors());
502 inline const vector<int> &getdims() const noexcept {return dims;}
503 inline const vector<int> &getperiods() const noexcept {return periods;}
504 inline int getndims() const noexcept {return ndims;}
505 inline const vector<int> &getnbors() const noexcept {return nbors;}
507 inline void setdims(const vector<int> &d) noexcept {dims = d; dims.shrink_to_fit();}
508 inline void setperiods(const vector<int> &p) noexcept {periods = p; periods.shrink_to_fit();}
509 inline void setndims(int nd) noexcept {ndims = nd;}
510 inline void setnbors(const vector<int> &n) noexcept {nbors = n; nbors.shrink_to_fit();}
513 class ampiGraphTopology final : public ampiTopology {
514 private:
515 int nvertices;
516 vector<int> index, edges, nbors;
518 public:
519 ampiGraphTopology() noexcept : nvertices(-1) {}
521 void pup(PUP::er &p) noexcept {
522 p|nvertices;
523 p|index;
524 p|edges;
525 p|nbors;
528 inline int getType() const noexcept {return MPI_GRAPH;}
529 inline void dup(ampiTopology* topo) noexcept {
530 CkAssert(topo->getType() == MPI_GRAPH);
531 setnvertices(topo->getnvertices());
532 setindex(topo->getindex());
533 setedges(topo->getedges());
534 setnbors(topo->getnbors());
537 inline int getnvertices() const noexcept {return nvertices;}
538 inline const vector<int> &getindex() const noexcept {return index;}
539 inline const vector<int> &getedges() const noexcept {return edges;}
540 inline const vector<int> &getnbors() const noexcept {return nbors;}
542 inline void setnvertices(int nv) noexcept {nvertices = nv;}
543 inline void setindex(const vector<int> &i) noexcept {index = i; index.shrink_to_fit();}
544 inline void setedges(const vector<int> &e) noexcept {edges = e; edges.shrink_to_fit();}
545 inline void setnbors(const vector<int> &n) noexcept {nbors = n; nbors.shrink_to_fit();}
548 class ampiDistGraphTopology final : public ampiTopology {
549 private:
550 int inDegree, outDegree;
551 bool sourcesWeighted, destsWeighted;
552 vector<int> sources, sourceWeights, destinations, destWeights, nbors;
554 public:
555 ampiDistGraphTopology() noexcept : inDegree(-1), outDegree(-1), sourcesWeighted(false), destsWeighted(false) {}
557 void pup(PUP::er &p) noexcept {
558 p|inDegree;
559 p|outDegree;
560 p|sourcesWeighted;
561 p|destsWeighted;
562 p|sources;
563 p|sourceWeights;
564 p|destinations;
565 p|destWeights;
566 p|nbors;
569 inline int getType() const noexcept {return MPI_DIST_GRAPH;}
570 inline void dup(ampiTopology* topo) noexcept {
571 CkAssert(topo->getType() == MPI_DIST_GRAPH);
572 setAreSourcesWeighted(topo->areSourcesWeighted());
573 setAreDestsWeighted(topo->areDestsWeighted());
574 setInDegree(topo->getInDegree());
575 setSources(topo->getSources());
576 setSourceWeights(topo->getSourceWeights());
577 setOutDegree(topo->getOutDegree());
578 setDestinations(topo->getDestinations());
579 setDestWeights(topo->getDestWeights());
580 setnbors(topo->getnbors());
583 inline int getInDegree() const noexcept {return inDegree;}
584 inline const vector<int> &getSources() const noexcept {return sources;}
585 inline const vector<int> &getSourceWeights() const noexcept {return sourceWeights;}
586 inline int getOutDegree() const noexcept {return outDegree;}
587 inline const vector<int> &getDestinations() const noexcept {return destinations;}
588 inline const vector<int> &getDestWeights() const noexcept {return destWeights;}
589 inline bool areSourcesWeighted() const noexcept {return sourcesWeighted;}
590 inline bool areDestsWeighted() const noexcept {return destsWeighted;}
591 inline const vector<int> &getnbors() const noexcept {return nbors;}
593 inline void setAreSourcesWeighted(bool v) noexcept {sourcesWeighted = v ? 1 : 0;}
594 inline void setAreDestsWeighted(bool v) noexcept {destsWeighted = v ? 1 : 0;}
595 inline void setInDegree(int d) noexcept {inDegree = d;}
596 inline void setSources(const vector<int> &s) noexcept {sources = s; sources.shrink_to_fit();}
597 inline void setSourceWeights(const vector<int> &sw) noexcept {sourceWeights = sw; sourceWeights.shrink_to_fit();}
598 inline void setOutDegree(int d) noexcept {outDegree = d;}
599 inline void setDestinations(const vector<int> &d) noexcept {destinations = d; destinations.shrink_to_fit();}
600 inline void setDestWeights(const vector<int> &dw) noexcept {destWeights = dw; destWeights.shrink_to_fit();}
601 inline void setnbors(const vector<int> &nbors_) noexcept {nbors = nbors_; nbors.shrink_to_fit();}
604 /* KeyValue class for attribute caching */
605 class KeyvalNode {
606 public:
607 void *val;
608 MPI_Copy_function *copy_fn;
609 MPI_Delete_function *delete_fn;
610 void *extra_state;
611 int refCount;
612 bool isValSet;
614 KeyvalNode() : val(NULL), copy_fn(NULL), delete_fn(NULL), extra_state(NULL), refCount(1), isValSet(false) { }
615 KeyvalNode(MPI_Copy_function *cf, MPI_Delete_function *df, void* es) :
616 val(NULL), copy_fn(cf), delete_fn(df), extra_state(es), refCount(1), isValSet(false) { }
617 bool hasVal() const { return isValSet; }
618 void clearVal() { isValSet = false; }
619 void setVal(void *v) { val = v; isValSet = true; }
620 void* getVal() const { return val; }
621 void incRefCount() { refCount++; }
622 int decRefCount() { CkAssert(refCount > 0); refCount--; return refCount; }
623 void pup(PUP::er& p) {
624 p((char *)val, sizeof(void *));
625 p((char *)copy_fn, sizeof(void *));
626 p((char *)delete_fn, sizeof(void *));
627 p((char *)extra_state, sizeof(void *));
628 p|refCount;
629 p|isValSet;
// Kind of communicator, stored in a single byte.
enum AmpiCommType : uint8_t {
  WORLD = 0,
  INTRA = 1,
  INTER = 2
};
639 //Describes an AMPI communicator
640 class ampiCommStruct {
641 private:
642 MPI_Comm comm; //Communicator
643 CkArrayID ampiID; //ID of corresponding ampi array
644 int size; //Number of processes in communicator
645 AmpiCommType commType; //COMM_WORLD, intracomm, intercomm?
646 vector<int> indices; //indices[r] gives the array index for rank r
647 vector<int> remoteIndices; // remote group for inter-communicator
649 ampiTopology *ampiTopo; // Virtual topology
650 int topoType; // Type of virtual topology: MPI_CART, MPI_GRAPH, MPI_DIST_GRAPH, or MPI_UNDEFINED
652 // For communicator attributes (MPI_*_get_attr): indexed by keyval
653 vector<int> keyvals;
655 // For communicator names
656 std::string commName;
658 // Lazily fill world communicator indices
659 void makeWorldIndices() const noexcept {
660 vector<int> &ind = const_cast<vector<int> &>(indices);
661 ind.resize(size);
662 std::iota(ind.begin(), ind.end(), 0);
665 public:
666 ampiCommStruct(int ignored=0) noexcept : size(-1), commType(INTRA), ampiTopo(NULL), topoType(MPI_UNDEFINED) {}
667 ampiCommStruct(MPI_Comm comm_,const CkArrayID &id_,int size_) noexcept
668 :comm(comm_), ampiID(id_),size(size_), commType(WORLD), ampiTopo(NULL), topoType(MPI_UNDEFINED) {}
669 ampiCommStruct(MPI_Comm comm_,const CkArrayID &id_,
670 int size_,const vector<int> &indices_) noexcept
671 :comm(comm_), ampiID(id_), size(size_),
672 commType(INTRA), indices(indices_),
673 ampiTopo(NULL), topoType(MPI_UNDEFINED) {}
674 ampiCommStruct(MPI_Comm comm_,const CkArrayID &id_,
675 int size_,const vector<int> &indices_,
676 const vector<int> &remoteIndices_) noexcept
677 :comm(comm_),ampiID(id_),size(size_),commType(INTER),
678 indices(indices_),remoteIndices(remoteIndices_),
679 ampiTopo(NULL), topoType(MPI_UNDEFINED) {}
681 ~ampiCommStruct() noexcept {
682 if (ampiTopo != NULL)
683 delete ampiTopo;
686 // Overloaded copy constructor. Used when creating virtual topologies.
687 ampiCommStruct(const ampiCommStruct &obj, int topoNumber=MPI_UNDEFINED) noexcept {
688 switch (topoNumber) {
689 case MPI_CART:
690 ampiTopo = new ampiCartTopology();
691 break;
692 case MPI_GRAPH:
693 ampiTopo = new ampiGraphTopology();
694 break;
695 case MPI_DIST_GRAPH:
696 ampiTopo = new ampiDistGraphTopology();
697 break;
698 default:
699 ampiTopo = NULL;
700 break;
702 topoType = topoNumber;
703 comm = obj.comm;
704 ampiID = obj.ampiID;
705 size = obj.size;
706 commType = obj.commType;
707 indices = obj.indices;
708 remoteIndices = obj.remoteIndices;
709 keyvals = obj.keyvals;
710 commName = obj.commName;
713 ampiCommStruct &operator=(const ampiCommStruct &obj) noexcept {
714 if (this == &obj) {
715 return *this;
717 switch (obj.topoType) {
718 case MPI_CART:
719 ampiTopo = new ampiCartTopology(*(static_cast<ampiCartTopology*>(obj.ampiTopo)));
720 break;
721 case MPI_GRAPH:
722 ampiTopo = new ampiGraphTopology(*(static_cast<ampiGraphTopology*>(obj.ampiTopo)));
723 break;
724 case MPI_DIST_GRAPH:
725 ampiTopo = new ampiDistGraphTopology(*(static_cast<ampiDistGraphTopology*>(obj.ampiTopo)));
726 break;
727 default:
728 ampiTopo = NULL;
729 break;
731 topoType = obj.topoType;
732 comm = obj.comm;
733 ampiID = obj.ampiID;
734 size = obj.size;
735 commType = obj.commType;
736 indices = obj.indices;
737 remoteIndices = obj.remoteIndices;
738 keyvals = obj.keyvals;
739 commName = obj.commName;
740 return *this;
743 const ampiTopology* getTopologyforNeighbors() const noexcept {
744 return ampiTopo;
747 ampiTopology* getTopology() noexcept {
748 return ampiTopo;
751 inline bool isinter() const noexcept {return commType==INTER;}
752 void setArrayID(const CkArrayID &nID) noexcept {ampiID=nID;}
754 MPI_Comm getComm() const noexcept {return comm;}
755 inline const vector<int> &getIndices() const noexcept {
756 if (commType==WORLD && indices.size()!=size) makeWorldIndices();
757 return indices;
759 const vector<int> &getRemoteIndices() const noexcept {return remoteIndices;}
760 vector<int> &getKeyvals() noexcept {return keyvals;}
762 void setName(const char *src) noexcept {
763 CkDDT_SetName(commName, src);
766 void getName(char *name, int *len) const noexcept {
767 int length = *len = commName.size();
768 memcpy(name, commName.data(), length);
769 name[length] = '\0';
772 //Get the proxy for the entire array
773 CProxy_ampi getProxy() const noexcept;
775 //Get the array index for rank r in this communicator
776 int getIndexForRank(int r) const noexcept {
777 #if CMK_ERROR_CHECKING
778 if (r>=size) CkAbort("AMPI> You passed in an out-of-bounds process rank!");
779 #endif
780 if (commType == WORLD) return r;
781 else return indices[r];
783 int getIndexForRemoteRank(int r) const noexcept {
784 #if CMK_ERROR_CHECKING
785 if (r>=remoteIndices.size()) CkAbort("AMPI> You passed in an out-of-bounds process rank!");
786 #endif
787 if (commType==WORLD) return r;
788 else return remoteIndices[r];
790 //Get the rank for this array index (Warning: linear time)
791 int getRankForIndex(int i) const noexcept {
792 if (commType==WORLD) return i;
793 else {
794 for (int r=0;r<indices.size();r++)
795 if (indices[r]==i) return r;
796 return -1; /*That index isn't in this communicator*/
800 int getSize() const noexcept {return size;}
802 void pup(PUP::er &p) noexcept {
803 p|comm;
804 p|ampiID;
805 p|size;
806 p|commType;
807 p|indices;
808 p|remoteIndices;
809 p|keyvals;
810 p|commName;
811 p|topoType;
812 if (topoType != MPI_UNDEFINED) {
813 if (p.isUnpacking()) {
814 switch (topoType) {
815 case MPI_CART:
816 ampiTopo = new ampiCartTopology();
817 break;
818 case MPI_GRAPH:
819 ampiTopo = new ampiGraphTopology();
820 break;
821 case MPI_DIST_GRAPH:
822 ampiTopo = new ampiDistGraphTopology();
823 break;
824 default:
825 CkAbort("AMPI> Communicator has an invalid topology!");
826 break;
829 ampiTopo->pup(p);
830 } else {
831 ampiTopo = NULL;
833 if (p.isDeleting()) {
834 delete ampiTopo; ampiTopo = NULL;
838 PUPmarshall(ampiCommStruct)
840 class mpi_comm_worlds{
841 ampiCommStruct comms[MPI_MAX_COMM_WORLDS];
842 public:
843 ampiCommStruct &operator[](int i) noexcept {return comms[i];}
844 void pup(PUP::er &p) noexcept {
845 for (int i=0;i<MPI_MAX_COMM_WORLDS;i++)
846 comms[i].pup(p);
850 typedef vector<int> groupStruct;
852 // groupStructure operations
853 inline void outputOp(groupStruct vec) noexcept {
854 if (vec.size() > 50) {
855 CkPrintf("vector too large to output!\n");
856 return;
858 CkPrintf("output vector: size=%d {",vec.size());
859 for (int i=0; i<vec.size(); i++) {
860 CkPrintf(" %d ", vec[i]);
862 CkPrintf("}\n");
865 inline int getPosOp(int idx, groupStruct vec) noexcept {
866 for (int r=0; r<vec.size(); r++) {
867 if (vec[r] == idx) {
868 return r;
871 return MPI_UNDEFINED;
874 inline groupStruct unionOp(groupStruct vec1, groupStruct vec2) noexcept {
875 groupStruct newvec(vec1);
876 for (int i=0; i<vec2.size(); i++) {
877 if (getPosOp(vec2[i], vec1) == MPI_UNDEFINED) {
878 newvec.push_back(vec2[i]);
881 return newvec;
884 inline groupStruct intersectOp(groupStruct vec1, groupStruct vec2) noexcept {
885 groupStruct newvec;
886 for (int i=0; i<vec1.size(); i++) {
887 if (getPosOp(vec1[i], vec2) != MPI_UNDEFINED) {
888 newvec.push_back(vec1[i]);
891 return newvec;
894 inline groupStruct diffOp(groupStruct vec1, groupStruct vec2) noexcept {
895 groupStruct newvec;
896 for (int i=0; i<vec1.size(); i++) {
897 if (getPosOp(vec1[i], vec2) == MPI_UNDEFINED) {
898 newvec.push_back(vec1[i]);
901 return newvec;
904 inline int* translateRanksOp(int n, groupStruct vec1, const int* ranks1, groupStruct vec2, int *ret) noexcept {
905 for (int i=0; i<n; i++) {
906 ret[i] = (ranks1[i] == MPI_PROC_NULL) ? MPI_PROC_NULL : getPosOp(vec1[ranks1[i]], vec2);
908 return ret;
911 inline int compareVecOp(groupStruct vec1, groupStruct vec2) noexcept {
912 int pos, ret = MPI_IDENT;
913 if (vec1.size() != vec2.size()) {
914 return MPI_UNEQUAL;
916 for (int i=0; i<vec1.size(); i++) {
917 pos = getPosOp(vec1[i], vec2);
918 if (pos == MPI_UNDEFINED) {
919 return MPI_UNEQUAL;
921 else if (pos != i) {
922 ret = MPI_SIMILAR;
925 return ret;
928 inline groupStruct inclOp(int n, const int* ranks, groupStruct vec) noexcept {
929 groupStruct retvec(n);
930 for (int i=0; i<n; i++) {
931 retvec[i] = vec[ranks[i]];
933 return retvec;
936 inline groupStruct exclOp(int n, const int* ranks, groupStruct vec) noexcept {
937 groupStruct retvec;
938 bool add = true;
939 for (int j=0; j<vec.size(); j++) {
940 for (int i=0; i<n; i++) {
941 if (j == ranks[i]) {
942 add = false;
943 break;
946 if (add) {
947 retvec.push_back(vec[j]);
949 else {
950 add = true;
953 return retvec;
956 inline groupStruct rangeInclOp(int n, int ranges[][3], groupStruct vec, int *flag) noexcept {
957 groupStruct retvec;
958 int first, last, stride;
959 for (int i=0; i<n; i++) {
960 first = ranges[i][0];
961 last = ranges[i][1];
962 stride = ranges[i][2];
963 if (stride != 0) {
964 for (int j=0; j<=(last-first)/stride; j++) {
965 retvec.push_back(vec[first+stride*j]);
968 else {
969 *flag = MPI_ERR_ARG;
970 return groupStruct();
973 *flag = MPI_SUCCESS;
974 return retvec;
977 inline groupStruct rangeExclOp(int n, int ranges[][3], groupStruct vec, int *flag) noexcept {
978 vector<int> ranks;
979 int first, last, stride;
980 for (int i=0; i<n; i++) {
981 first = ranges[i][0];
982 last = ranges[i][1];
983 stride = ranges[i][2];
984 if (stride != 0) {
985 for (int j=0; j<=(last-first)/stride; j++) {
986 ranks.push_back(first+stride*j);
989 else {
990 *flag = MPI_ERR_ARG;
991 return groupStruct();
994 *flag = MPI_SUCCESS;
995 return exclOp(ranks.size(), &ranks[0], vec);
998 #include "tcharm.h"
999 #include "tcharmc.h"
1001 #include "ampi.decl.h"
1002 #include "charm-api.h"
1003 #include <sys/stat.h> // for mkdir
extern int _mpi_nworlds;

// Internal tags, all placed above the user-visible tag range.
// MPI_ANY_TAG is defined in ampi.h to MPI_TAG_UB_VALUE+1.
// Every expansion is parenthesized so that the macros stay a single value
// when used inside larger expressions (e.g. multiplication or negation).
#define MPI_ATA_SEQ_TAG     (MPI_TAG_UB_VALUE+2)
#define MPI_BCAST_TAG       (MPI_TAG_UB_VALUE+3)
#define MPI_REDN_TAG        (MPI_TAG_UB_VALUE+4)
#define MPI_SCATTER_TAG     (MPI_TAG_UB_VALUE+5)
#define MPI_SCAN_TAG        (MPI_TAG_UB_VALUE+6)
#define MPI_EXSCAN_TAG      (MPI_TAG_UB_VALUE+7)
#define MPI_ATA_TAG         (MPI_TAG_UB_VALUE+8)
#define MPI_NBOR_TAG        (MPI_TAG_UB_VALUE+9)
#define MPI_RMA_TAG         (MPI_TAG_UB_VALUE+10)
#define MPI_EPOCH_START_TAG (MPI_TAG_UB_VALUE+11)
#define MPI_EPOCH_END_TAG   (MPI_TAG_UB_VALUE+12)

// Source rank, destination rank and communicator recorded for collectives.
#define AMPI_COLL_SOURCE 0
#define AMPI_COLL_DEST   (-1)
#define AMPI_COLL_COMM   MPI_COMM_WORLD
/// Discriminator returned by AmpiRequest::getType().
/// Stored in a single byte; the numeric values are fixed explicitly.
enum AmpiReqType : uint8_t {
  AMPI_INVALID_REQ = 0,
  AMPI_I_REQ       = 1,
  AMPI_ATA_REQ     = 2,
  AMPI_SEND_REQ    = 3,
  AMPI_SSEND_REQ   = 4,
  AMPI_REDN_REQ    = 5,
  AMPI_GATHER_REQ  = 6,
  AMPI_GATHERV_REQ = 7,
  AMPI_G_REQ       = 8,
#if CMK_CUDA
  AMPI_GPU_REQ     = 9
#endif
};
1039 inline void operator|(PUP::er &p, AmpiReqType &r) {
1040 pup_bytes(&p, (void *)&r, sizeof(AmpiReqType));
/// Initial status handed to the request constructors (see
/// AMPI_REQUEST_COMMON_INIT, which derives 'complete' and 'blocked' from it).
enum AmpiReqSts : char {
  AMPI_REQ_PENDING   = 0,
  AMPI_REQ_BLOCKED   = 1,
  AMPI_REQ_COMPLETED = 2
};
/// Distinguishes a blocking send from an immediate (I_SEND) one.
enum AmpiSendType : bool {
  BLOCKING_SEND = false,
  I_SEND        = true
};
// Round x up to the next multiple of 8 (values already aligned are unchanged).
#define MyAlign8(x) (((x) + 7) & (~7))
1057 Represents an MPI request that has been initiated
1058 using Isend, Irecv, Ialltoall, Send_init, etc.
1060 class AmpiRequest {
1061 public:
1062 void *buf = nullptr;
1063 int count = 0;
1064 MPI_Datatype type = MPI_DATATYPE_NULL;
1065 int tag = MPI_ANY_TAG; // the order must match MPI_Status
1066 int src = MPI_ANY_SOURCE;
1067 MPI_Comm comm = MPI_COMM_NULL;
1068 MPI_Request reqIdx = MPI_REQUEST_NULL;
1069 bool complete = false;
1070 bool blocked = false; // this req is currently blocked on
1072 #if CMK_BIGSIM_CHARM
1073 public:
1074 void *event = nullptr; // the event point that corresponds to this message
1075 int eventPe = -1; // the PE that the event is located on
1076 #endif
1078 public:
1079 AmpiRequest() noexcept {}
1080 /// Close this request (used by free and cancel)
1081 virtual ~AmpiRequest() noexcept {}
1083 /// Activate this persistent request.
1084 /// Only meaningful for persistent Ireq, SendReq, and SsendReq requests.
1085 virtual void start(MPI_Request reqIdx) noexcept {}
1087 /// Used by AmmEntry's constructor
1088 virtual int getTag() const noexcept { return tag; }
1089 virtual int getSrcRank() const noexcept { return src; }
1091 /// Return true if this request is finished (progress):
1092 virtual bool test(MPI_Status *sts=MPI_STATUS_IGNORE) noexcept =0;
1094 /// Block until this request is finished,
1095 /// returning a valid MPI error code.
1096 virtual int wait(MPI_Status *sts) noexcept =0;
1098 /// Mark this request for cancellation.
1099 /// Supported only for IReq requests
1100 virtual void cancel() noexcept {}
1102 /// Mark this request persistent.
1103 /// Supported only for IReq, SendReq, and SsendReq requests
1104 virtual void setPersistent(bool p) noexcept {}
1105 virtual bool isPersistent() const noexcept { return false; }
1107 /// Receive an AmpiMsg
1108 virtual void receive(ampi *ptr, AmpiMsg *msg) noexcept =0;
1110 /// Receive a CkReductionMsg
1111 virtual void receive(ampi *ptr, CkReductionMsg *msg) noexcept =0;
1113 /// Receive an Rdma message
1114 virtual void receiveRdma(ampi *ptr, char *sbuf, int slength, int ssendReq,
1115 int srcRank, MPI_Comm scomm) noexcept { }
1117 /// Set the request's index into AmpiRequestList
1118 void setReqIdx(MPI_Request idx) noexcept { reqIdx = idx; }
1119 MPI_Request getReqIdx() const noexcept { return reqIdx; }
1121 /// Free the request's datatype
1122 void free(CkDDT* ddt) noexcept {
1123 if (type != MPI_DATATYPE_NULL) ddt->freeType(type);
1126 /// Set whether the request is currently blocked on
1127 void setBlocked(bool b) noexcept { blocked = b; }
1128 bool isBlocked() const noexcept { return blocked; }
1130 /// Returns the type of request:
1131 /// AMPI_I_REQ, AMPI_ATA_REQ, AMPI_SEND_REQ, AMPI_SSEND_REQ,
1132 /// AMPI_REDN_REQ, AMPI_GATHER_REQ, AMPI_GATHERV_REQ, AMPI_G_REQ
1133 virtual AmpiReqType getType() const noexcept =0;
1135 /// Returns whether this request will need to be matched.
1136 /// It is used to determine whether this request should be inserted into postedReqs.
1137 /// AMPI_SEND_REQ, AMPI_SSEND_REQ, and AMPI_ATA_REQ should not be posted.
1138 virtual bool isUnmatched() const noexcept =0;
1140 /// Returns whether this type is pooled or not:
1141 /// Only AMPI_I_REQ, AMPI_SEND_REQ, and AMPI_SSEND_REQs are pooled.
1142 virtual bool isPooledType() const noexcept { return false; }
1144 /// Return the actual number of bytes that were received.
1145 virtual int getNumReceivedBytes(CkDDT *ddt) const noexcept {
1146 // by default, return number of bytes requested
1147 return count * ddt->getSize(type);
1150 virtual void pup(PUP::er &p) noexcept {
1151 p((char *)&buf, sizeof(void *)); //supposed to work only with Isomalloc
1152 p(count);
1153 p(type);
1154 p(tag);
1155 p(src);
1156 p(comm);
1157 p(reqIdx);
1158 p(complete);
1159 p(blocked);
1160 #if CMK_BIGSIM_CHARM
1161 //needed for bigsim out-of-core emulation
1162 //as the "log" is not moved from memory, this pointer is safe
1163 //to be reused
1164 p((char *)&event, sizeof(void *));
1165 p(eventPe);
1166 #endif
1169 virtual void print() const noexcept =0;
// Shared tail of the AmpiRequest subclass constructors below. It expects the
// enclosing constructor to have parameters named exactly:
//   MPI_Datatype type_, CkDDT* ddt_, AmpiReqSts sts_
// It derives 'complete' and 'blocked' from sts_ and takes a reference on the
// datatype unless it is MPI_DATATYPE_NULL.
// The expansion is a braced block, so it is used without a trailing ';'.
#define AMPI_REQUEST_COMMON_INIT                 \
  {                                              \
    complete = (sts_ == AMPI_REQ_COMPLETED);     \
    blocked  = (sts_ == AMPI_REQ_BLOCKED);       \
    if (type_ != MPI_DATATYPE_NULL) {            \
      ddt_->getType(type_)->incRefCount();       \
    }                                            \
  }
1183 class IReq final : public AmpiRequest {
1184 public:
1185 bool cancelled = false; // track if request is cancelled
1186 bool persistent = false; // Is this a persistent recv request?
1187 int length = 0; // recv'ed length in bytes
1189 IReq(void *buf_, int count_, MPI_Datatype type_, int src_, int tag_,
1190 MPI_Comm comm_, CkDDT *ddt_, AmpiReqSts sts_=AMPI_REQ_PENDING) noexcept
1192 buf = buf_;
1193 count = count_;
1194 type = type_;
1195 src = src_;
1196 tag = tag_;
1197 comm = comm_;
1198 AMPI_REQUEST_COMMON_INIT
1200 IReq() noexcept {}
1201 ~IReq() noexcept {}
1202 bool test(MPI_Status *sts=MPI_STATUS_IGNORE) noexcept override;
1203 int wait(MPI_Status *sts) noexcept override;
1204 void cancel() noexcept override { if (!complete) cancelled = true; }
1205 AmpiReqType getType() const noexcept override { return AMPI_I_REQ; }
1206 bool isUnmatched() const noexcept override { return !complete; }
1207 bool isPooledType() const noexcept override { return true; }
1208 void setPersistent(bool p) noexcept override { persistent = p; }
1209 bool isPersistent() const noexcept override { return persistent; }
1210 void start(MPI_Request reqIdx) noexcept override;
1211 void receive(ampi *ptr, AmpiMsg *msg) noexcept override;
1212 void receive(ampi *ptr, CkReductionMsg *msg) noexcept override {}
1213 void receiveRdma(ampi *ptr, char *sbuf, int slength, int ssendReq, int srcRank, MPI_Comm scomm) noexcept override;
1214 int getNumReceivedBytes(CkDDT *ptr) const noexcept override {
1215 return length;
1217 void pup(PUP::er &p) noexcept override {
1218 AmpiRequest::pup(p);
1219 p|cancelled;
1220 p|persistent;
1221 p|length;
1223 void print() const noexcept override;
1226 class RednReq final : public AmpiRequest {
1227 public:
1228 MPI_Op op = MPI_OP_NULL;
1230 RednReq(void *buf_, int count_, MPI_Datatype type_, MPI_Comm comm_,
1231 MPI_Op op_, CkDDT* ddt_, AmpiReqSts sts_=AMPI_REQ_PENDING) noexcept
1233 buf = buf_;
1234 count = count_;
1235 type = type_;
1236 src = AMPI_COLL_SOURCE;
1237 tag = MPI_REDN_TAG;
1238 comm = comm_;
1239 op = op_;
1240 AMPI_REQUEST_COMMON_INIT
1242 RednReq() noexcept {}
1243 ~RednReq() noexcept {}
1244 bool test(MPI_Status *sts=MPI_STATUS_IGNORE) noexcept override;
1245 int wait(MPI_Status *sts) noexcept override;
1246 void cancel() noexcept override {}
1247 AmpiReqType getType() const noexcept override { return AMPI_REDN_REQ; }
1248 bool isUnmatched() const noexcept override { return !complete; }
1249 void receive(ampi *ptr, AmpiMsg *msg) noexcept override {}
1250 void receive(ampi *ptr, CkReductionMsg *msg) noexcept override;
1251 void pup(PUP::er &p) noexcept override {
1252 AmpiRequest::pup(p);
1253 p|op;
1255 void print() const noexcept override;
1258 class GatherReq final : public AmpiRequest {
1259 public:
1260 GatherReq(void *buf_, int count_, MPI_Datatype type_, MPI_Comm comm_,
1261 CkDDT *ddt_, AmpiReqSts sts_=AMPI_REQ_PENDING) noexcept
1263 buf = buf_;
1264 count = count_;
1265 type = type_;
1266 src = AMPI_COLL_SOURCE;
1267 tag = MPI_REDN_TAG;
1268 comm = comm_;
1269 AMPI_REQUEST_COMMON_INIT
1271 GatherReq() noexcept {}
1272 ~GatherReq() noexcept {}
1273 bool test(MPI_Status *sts=MPI_STATUS_IGNORE) noexcept override;
1274 int wait(MPI_Status *sts) noexcept override;
1275 void cancel() noexcept override {}
1276 AmpiReqType getType() const noexcept override { return AMPI_GATHER_REQ; }
1277 bool isUnmatched() const noexcept override { return !complete; }
1278 void receive(ampi *ptr, AmpiMsg *msg) noexcept override {}
1279 void receive(ampi *ptr, CkReductionMsg *msg) noexcept override;
1280 void pup(PUP::er &p) noexcept override {
1281 AmpiRequest::pup(p);
1283 void print() const noexcept override;
1286 class GathervReq final : public AmpiRequest {
1287 public:
1288 vector<int> recvCounts;
1289 vector<int> displs;
1291 GathervReq(void *buf_, int count_, MPI_Datatype type_, MPI_Comm comm_, const int *rc,
1292 const int *d, CkDDT* ddt_, AmpiReqSts sts_=AMPI_REQ_PENDING) noexcept
1294 buf = buf_;
1295 count = count_;
1296 type = type_;
1297 src = AMPI_COLL_SOURCE;
1298 tag = MPI_REDN_TAG;
1299 comm = comm_;
1300 recvCounts.assign(rc, rc+count);
1301 displs.assign(d, d+count);
1302 AMPI_REQUEST_COMMON_INIT
1304 GathervReq() noexcept {}
1305 ~GathervReq() noexcept {}
1306 bool test(MPI_Status *sts=MPI_STATUS_IGNORE) noexcept override;
1307 int wait(MPI_Status *sts) noexcept override;
1308 AmpiReqType getType() const noexcept override { return AMPI_GATHERV_REQ; }
1309 bool isUnmatched() const noexcept override { return !complete; }
1310 void receive(ampi *ptr, AmpiMsg *msg) noexcept override {}
1311 void receive(ampi *ptr, CkReductionMsg *msg) noexcept override;
1312 void pup(PUP::er &p) noexcept override {
1313 AmpiRequest::pup(p);
1314 p|recvCounts;
1315 p|displs;
1317 void print() const noexcept override;
1320 class SendReq final : public AmpiRequest {
1321 bool persistent = false; // is this a persistent send request?
1323 public:
1324 SendReq(MPI_Datatype type_, MPI_Comm comm_, CkDDT* ddt_, AmpiReqSts sts_=AMPI_REQ_PENDING) noexcept
1326 type = type_;
1327 comm = comm_;
1328 AMPI_REQUEST_COMMON_INIT
1330 SendReq(void* buf_, int count_, MPI_Datatype type_, int dest_, int tag_,
1331 MPI_Comm comm_, CkDDT* ddt_, AmpiReqSts sts_=AMPI_REQ_PENDING) noexcept
1333 buf = buf_;
1334 count = count_;
1335 type = type_;
1336 src = dest_;
1337 tag = tag_;
1338 comm = comm_;
1339 AMPI_REQUEST_COMMON_INIT
1341 SendReq() noexcept {}
1342 ~SendReq() noexcept {}
1343 bool test(MPI_Status *sts=MPI_STATUS_IGNORE) noexcept override;
1344 int wait(MPI_Status *sts) noexcept override;
1345 void setPersistent(bool p) noexcept override { persistent = p; }
1346 bool isPersistent() const noexcept override { return persistent; }
1347 void start(MPI_Request reqIdx) noexcept override;
1348 void receive(ampi *ptr, AmpiMsg *msg) noexcept override {}
1349 void receive(ampi *ptr, CkReductionMsg *msg) noexcept override {}
1350 AmpiReqType getType() const noexcept override { return AMPI_SEND_REQ; }
1351 bool isUnmatched() const noexcept override { return false; }
1352 bool isPooledType() const noexcept override { return true; }
1353 void pup(PUP::er &p) noexcept override {
1354 AmpiRequest::pup(p);
1355 p|persistent;
1357 void print() const noexcept override;
1360 class SsendReq final : public AmpiRequest {
1361 private:
1362 bool persistent = false; // is this a persistent Ssend request?
1364 public:
1365 SsendReq(MPI_Datatype type_, MPI_Comm comm_, CkDDT* ddt_, AmpiReqSts sts_=AMPI_REQ_PENDING) noexcept
1367 type = type_;
1368 comm = comm_;
1369 AMPI_REQUEST_COMMON_INIT
1371 SsendReq(void* buf_, int count_, MPI_Datatype type_, int dest_, int tag_, MPI_Comm comm_,
1372 CkDDT* ddt_, AmpiReqSts sts_=AMPI_REQ_PENDING) noexcept
1374 buf = buf_;
1375 count = count_;
1376 type = type_;
1377 src = dest_;
1378 tag = tag_;
1379 comm = comm_;
1380 AMPI_REQUEST_COMMON_INIT
1382 SsendReq(void* buf_, int count_, MPI_Datatype type_, int dest_, int tag_, MPI_Comm comm_,
1383 int src_, CkDDT* ddt_, AmpiReqSts sts_=AMPI_REQ_PENDING) noexcept
1385 buf = buf_;
1386 count = count_;
1387 type = type_;
1388 src = dest_;
1389 tag = tag_;
1390 comm = comm_;
1391 AMPI_REQUEST_COMMON_INIT
1393 SsendReq() noexcept {}
1394 ~SsendReq() noexcept {}
1395 bool test(MPI_Status *sts=MPI_STATUS_IGNORE) noexcept override;
1396 int wait(MPI_Status *sts) noexcept override;
1397 void setPersistent(bool p) noexcept override { persistent = p; }
1398 bool isPersistent() const noexcept override { return persistent; }
1399 void start(MPI_Request reqIdx) noexcept override;
1400 void receive(ampi *ptr, AmpiMsg *msg) noexcept override {}
1401 void receive(ampi *ptr, CkReductionMsg *msg) noexcept override {}
1402 AmpiReqType getType() const noexcept override { return AMPI_SSEND_REQ; }
1403 bool isUnmatched() const noexcept override { return false; }
1404 bool isPooledType() const noexcept override { return true; }
1405 void pup(PUP::er &p) noexcept override {
1406 AmpiRequest::pup(p);
1407 p|persistent;
1409 void print() const noexcept override;
#if CMK_CUDA
/// GPU request (type AMPI_GPU_REQ); only built when CMK_CUDA is enabled.
/// Never posted for matching. All non-trivial members are defined elsewhere.
class GPUReq : public AmpiRequest {
 public:
  GPUReq() noexcept;
  ~GPUReq() noexcept {}

  bool test(MPI_Status *sts=MPI_STATUS_IGNORE) noexcept override;
  int wait(MPI_Status *sts) noexcept override;
  void receive(ampi *ptr, AmpiMsg *msg) noexcept override;
  void receive(ampi *ptr, CkReductionMsg *msg) noexcept override;
  AmpiReqType getType() const noexcept override { return AMPI_GPU_REQ; }
  bool isUnmatched() const noexcept override { return false; }
  void setComplete() noexcept;
  void print() const noexcept override;
};
#endif
1428 class ATAReq final : public AmpiRequest {
1429 public:
1430 vector<MPI_Request> reqs;
1432 ATAReq(int numReqs_) noexcept : reqs(numReqs_) {}
1433 ATAReq() noexcept {}
1434 ~ATAReq() noexcept {}
1435 bool test(MPI_Status *sts=MPI_STATUS_IGNORE) noexcept override;
1436 int wait(MPI_Status *sts) noexcept override;
1437 void receive(ampi *ptr, AmpiMsg *msg) noexcept override {}
1438 void receive(ampi *ptr, CkReductionMsg *msg) noexcept override {}
1439 int getCount() const noexcept { return reqs.size(); }
1440 AmpiReqType getType() const noexcept override { return AMPI_ATA_REQ; }
1441 bool isUnmatched() const noexcept override { return false; }
1442 void pup(PUP::er &p) noexcept override {
1443 AmpiRequest::pup(p);
1444 p|reqs;
1446 void print() const noexcept override;
1449 class GReq final : public AmpiRequest {
1450 private:
1451 MPI_Grequest_query_function* queryFn;
1452 MPI_Grequest_free_function* freeFn;
1453 MPI_Grequest_cancel_function* cancelFn;
1454 MPIX_Grequest_poll_function* pollFn;
1455 MPIX_Grequest_wait_function* waitFn;
1456 void* extraState;
1458 public:
1459 GReq(MPI_Grequest_query_function* q, MPI_Grequest_free_function* f, MPI_Grequest_cancel_function* c, void* es) noexcept
1460 : queryFn(q), freeFn(f), cancelFn(c), pollFn(nullptr), waitFn(nullptr), extraState(es) {}
1461 GReq(MPI_Grequest_query_function *q, MPI_Grequest_free_function* f, MPI_Grequest_cancel_function* c, MPIX_Grequest_poll_function* p, void* es) noexcept
1462 : queryFn(q), freeFn(f), cancelFn(c), pollFn(p), waitFn(nullptr), extraState(es) {}
1463 GReq(MPI_Grequest_query_function *q, MPI_Grequest_free_function* f, MPI_Grequest_cancel_function* c, MPIX_Grequest_poll_function* p, MPIX_Grequest_wait_function* w, void* es) noexcept
1464 : queryFn(q), freeFn(f), cancelFn(c), pollFn(p), waitFn(w), extraState(es) {}
1465 GReq() noexcept {}
1466 ~GReq() noexcept { (*freeFn)(extraState); }
1467 bool test(MPI_Status *sts=MPI_STATUS_IGNORE) noexcept override;
1468 int wait(MPI_Status *sts) noexcept override;
1469 void receive(ampi *ptr, AmpiMsg *msg) noexcept override {}
1470 void receive(ampi *ptr, CkReductionMsg *msg) noexcept override {}
1471 void cancel() noexcept override { (*cancelFn)(extraState, complete); }
1472 AmpiReqType getType() const noexcept override { return AMPI_G_REQ; }
1473 bool isUnmatched() const noexcept override { return false; }
1474 void pup(PUP::er &p) noexcept override {
1475 AmpiRequest::pup(p);
1476 p((char *)queryFn, sizeof(void *));
1477 p((char *)freeFn, sizeof(void *));
1478 p((char *)cancelFn, sizeof(void *));
1479 p((char *)pollFn, sizeof(void *));
1480 p((char *)waitFn, sizeof(void *));
1481 p((char *)extraState, sizeof(void *));
1483 void print() const noexcept override;
1486 class AmpiRequestPool;
1488 class AmpiRequestList {
1489 private:
1490 vector<AmpiRequest*> reqs; // indexed by MPI_Request
1491 int startIdx; // start next search from this index
1492 AmpiRequestPool* reqPool;
1493 public:
1494 AmpiRequestList() noexcept : startIdx(0) {}
1495 AmpiRequestList(int size, AmpiRequestPool* reqPoolPtr) noexcept
1496 : reqs(size), startIdx(0), reqPool(reqPoolPtr) {}
1497 ~AmpiRequestList() noexcept {}
1499 inline AmpiRequest* operator[](int n) noexcept {
1500 #if CMK_ERROR_CHECKING
1501 return reqs.at(n);
1502 #else
1503 return reqs[n];
1504 #endif
1506 void free(AmpiRequestPool& reqPool, int idx, CkDDT *ddt) noexcept;
1507 void freeNonPersReq(int &idx) noexcept;
1508 inline int insert(AmpiRequest* req) noexcept {
1509 for (int i=startIdx; i<reqs.size(); i++) {
1510 if (reqs[i] == NULL) {
1511 req->setReqIdx(i);
1512 reqs[i] = req;
1513 startIdx = i+1;
1514 return i;
1517 reqs.push_back(req);
1518 int idx = reqs.size()-1;
1519 req->setReqIdx(idx);
1520 startIdx = idx+1;
1521 return idx;
1524 inline void checkRequest(MPI_Request idx) const noexcept {
1525 if (idx != MPI_REQUEST_NULL && (idx < 0 || idx >= reqs.size()))
1526 CkAbort("Invalid MPI_Request\n");
1529 inline void unblockReqs(MPI_Request *requests, int numReqs) noexcept {
1530 for (int i=0; i<numReqs; i++) {
1531 if (requests[i] != MPI_REQUEST_NULL) {
1532 reqs[requests[i]]->setBlocked(false);
1537 void pup(PUP::er &p, AmpiRequestPool* reqPool) noexcept;
1539 void print() const noexcept {
1540 for (int i=0; i<reqs.size(); i++) {
1541 if (reqs[i] == NULL) continue;
1542 CkPrintf("AmpiRequestList Element %d [%p]: \n", i+1, reqs[i]);
1543 reqs[i]->print();
1548 //A simple memory buffer
1549 class memBuf {
1550 CkVec<char> buf;
1551 public:
1552 memBuf() =default;
1553 memBuf(int size) noexcept : buf(size) {}
1554 void setSize(int s) noexcept {buf.resize(s);}
1555 int getSize() const noexcept {return buf.size();}
1556 const void *getData() const noexcept {return (const void *)&buf[0];}
1557 void *getData() noexcept {return (void *)&buf[0];}
1560 template <class T>
1561 inline void pupIntoBuf(memBuf &b,T &t) noexcept {
1562 PUP::sizer ps;ps|t;
1563 b.setSize(ps.size());
1564 PUP::toMem pm(b.getData()); pm|t;
1567 template <class T>
1568 inline void pupFromBuf(const void *data,T &t) noexcept {
1569 PUP::fromMem p(data); p|t;
1572 #define COLL_SEQ_IDX -1
1574 class AmpiMsg final : public CMessage_AmpiMsg {
1575 private:
1576 int ssendReq; //Index to the sender's request
1577 int tag; //MPI tag
1578 int srcRank; //Communicator rank for source
1579 int length; //Number of bytes in this message
1580 MPI_Comm comm; // Communicator
1581 public:
1582 char *data; //Payload
1583 #if CMK_BIGSIM_CHARM
1584 public:
1585 void *event;
1586 int eventPe; // the PE that the event is located
1587 #endif
1589 public:
1590 AmpiMsg() noexcept { data = NULL; }
1591 AmpiMsg(int sreq, int t, int sRank, int l) noexcept :
1592 ssendReq(sreq), tag(t), srcRank(sRank), length(l)
1593 { /* only called from AmpiMsg::pup() since the refnum (seq) will get pup'ed by the runtime */ }
1594 AmpiMsg(CMK_REFNUM_TYPE seq, int sreq, int t, int sRank, int l) noexcept :
1595 ssendReq(sreq), tag(t), srcRank(sRank), length(l)
1596 { CkSetRefNum(this, seq); }
1597 inline void setSsendReq(int s) noexcept { CkAssert(s >= 0); ssendReq = s; }
1598 inline void setSeq(CMK_REFNUM_TYPE s) noexcept { CkAssert(s >= 0); UsrToEnv(this)->setRef(s); }
1599 inline void setSrcRank(int sr) noexcept { srcRank = sr; }
1600 inline void setLength(int l) noexcept { length = l; }
1601 inline void setTag(int t) noexcept { tag = t; }
1602 inline void setComm(MPI_Comm c) noexcept { comm = c; }
1603 inline CMK_REFNUM_TYPE getSeq() const noexcept { return UsrToEnv(this)->getRef(); }
1604 inline int getSsendReq() const noexcept { return ssendReq; }
1605 inline int getSeqIdx() const noexcept {
1606 // seqIdx is srcRank, unless this message was part of a collective
1607 if (tag >= MPI_BCAST_TAG && tag <= MPI_ATA_TAG) {
1608 return COLL_SEQ_IDX;
1610 else {
1611 return srcRank;
1614 inline int getSrcRank() const noexcept { return srcRank; }
1615 inline int getLength() const noexcept { return length; }
1616 inline char* getData() const noexcept { return data; }
1617 inline int getTag() const noexcept { return tag; }
1618 inline MPI_Comm getComm() const noexcept { return comm; }
1619 static AmpiMsg* pup(PUP::er &p, AmpiMsg *m) noexcept
1621 int ssendReq, length, tag, srcRank, comm;
1622 if(p.isPacking() || p.isSizing()) {
1623 ssendReq = m->ssendReq;
1624 tag = m->tag;
1625 srcRank = m->srcRank;
1626 length = m->length;
1627 comm = m->comm;
1629 p(ssendReq); p(tag); p(srcRank); p(length); p(comm);
1630 if(p.isUnpacking()) {
1631 m = new (length, 0) AmpiMsg(ssendReq, tag, srcRank, length, comm);
1633 p(m->data, length);
1634 if(p.isDeleting()) {
1635 delete m;
1636 m = 0;
1638 return m;
1642 #define AMPI_MSG_POOL_SIZE 32 // Max # of AmpiMsg's allowed in the pool
1643 #define AMPI_POOLED_MSG_SIZE 64 // Max # of Bytes in pooled msgs' payload
1645 class AmpiMsgPool {
1646 private:
1647 std::forward_list<AmpiMsg *> msgs; // list of free msgs
1648 int msgLength; // AmpiMsg::length of messages in the pool
1649 int msgUsersize; // usersize of message envelopes in the pool
1650 int maxMsgs; // max # of msgs in the pool
1651 int currMsgs; // current # of msgs in the pool
1653 public:
1654 AmpiMsgPool() noexcept : msgLength(0), msgUsersize(0), maxMsgs(0), currMsgs(0) {}
1655 AmpiMsgPool(int _numMsgs, int _msgLength) noexcept {
1656 msgLength = _msgLength;
1657 maxMsgs = _numMsgs;
1658 if (maxMsgs > 0 && msgLength > 0) {
1659 /* Construct an AmpiMsg to find the usersize (and add it to the pool while it's here).
1660 * The rest of the pool can be filled lazily. */
1661 AmpiMsg* msg = new (msgLength, 0) AmpiMsg(0, 0, 0, 0, 0);
1662 msgs.push_front(msg);
1663 currMsgs = 1;
1664 /* Usersize is the true size of the message envelope, not the length member
1665 * of the AmpiMsg. AmpiMsg::length is used by Ssend msgs to convey the real
1666 * msg payload's length, and is not the length of the Ssend msg itself, so
1667 * it cannot be trusted when returning msgs to the pool. */
1668 msgUsersize = UsrToEnv(msgs.front())->getUsersize();
1670 else {
1671 currMsgs = 0;
1672 msgUsersize = 0;
1675 ~AmpiMsgPool() =default;
1676 inline void clear() noexcept {
1677 while (!msgs.empty()) {
1678 delete msgs.front();
1679 msgs.pop_front();
1681 currMsgs = 0;
1683 inline AmpiMsg* newAmpiMsg(CMK_REFNUM_TYPE seq, int ssendReq, int tag, int srcRank, int len) noexcept {
1684 if (len > msgLength || msgs.empty()) {
1685 return new (len, 0) AmpiMsg(seq, ssendReq, tag, srcRank, len);
1686 } else {
1687 AmpiMsg* msg = msgs.front();
1688 CkAssert(msg != NULL);
1689 msgs.pop_front();
1690 currMsgs--;
1691 msg->setSeq(seq);
1692 msg->setSsendReq(ssendReq);
1693 msg->setTag(tag);
1694 msg->setSrcRank(srcRank);
1695 msg->setLength(len);
1696 return msg;
1699 inline void deleteAmpiMsg(AmpiMsg* msg) noexcept {
1700 if (currMsgs != maxMsgs && UsrToEnv(msg)->getUsersize() == msgUsersize) {
1701 CkAssert(msg != NULL);
1702 msgs.push_front(msg);
1703 currMsgs++;
1704 } else {
1705 delete msg;
1708 void pup(PUP::er& p) {
1709 p|msgLength;
1710 p|msgUsersize;
1711 p|maxMsgs;
1712 // Don't PUP the msgs in the free list or currMsgs, let the pool fill lazily
1716 // Number of requests in the pool
1717 #ifndef AMPI_REQ_POOL_SIZE
1718 #define AMPI_REQ_POOL_SIZE 64
1719 #endif
1721 // Helper macro for pool size and alignment calculations
1722 #define DefinePooledReqX(name, func) \
1723 static const size_t ireq##name = func(IReq); \
1724 static const size_t sreq##name = func(SendReq); \
1725 static const size_t ssreq##name = func(SsendReq); \
1726 static const size_t pooledReq##name = (ireq##name >= sreq##name && ireq##name >= ssreq##name) ? ireq##name : \
1727 (sreq##name >= ireq##name && sreq##name >= ssreq##name) ? sreq##name : \
1728 (ssreq##name);
1730 // This defines 'static const size_t pooledReqSize = ... ;'
1731 DefinePooledReqX(Size, sizeof)
1733 // This defines 'static const size_t pooledReqAlign = ... ;'
1734 DefinePooledReqX(Align, alignof)
1736 // Pool of IReq, SendReq, and SsendReq objects:
1737 // These are different sizes, but we use a single pool for them so
1738 // that iteration over these objects is fast, as in AMPI_Waitall.
1739 // We also try to always allocate new requests from the start to the end
1740 // of the pool, so that forward iteration over requests is fast.
1741 class AmpiRequestPool {
1742 private:
1743 std::bitset<AMPI_REQ_POOL_SIZE> validReqs; // reqs in the pool are either valid (being used by a real req) or invalid
1744 int startIdx; // start next search from this index
1745 alignas(pooledReqAlign) std::array<char, AMPI_REQ_POOL_SIZE*pooledReqSize> reqs; // pool of memory for requests
1747 public:
1748 AmpiRequestPool() noexcept : startIdx(0) {}
1749 ~AmpiRequestPool() =default;
1750 inline IReq* newIReq() noexcept {
1751 if (validReqs.all()) {
1752 return new IReq();
1753 } else {
1754 for (int i=startIdx; i<validReqs.size(); i++) {
1755 if (!validReqs[i]) {
1756 validReqs[i] = 1;
1757 IReq* ireq = new (&reqs[i*pooledReqSize]) IReq();
1758 startIdx = i+1;
1759 return ireq;
1762 CkAbort("AMPI> failed to find a free request in pool!");
1763 return NULL;
1766 inline IReq* newIReq(void* buf, int count, MPI_Datatype type, int src, int tag,
1767 MPI_Comm comm, CkDDT* ddt, AmpiReqSts sts=AMPI_REQ_PENDING) noexcept {
1768 if (validReqs.all()) {
1769 return new IReq(buf, count, type, src, tag, comm, ddt, sts);
1770 } else {
1771 for (int i=startIdx; i<validReqs.size(); i++) {
1772 if (!validReqs[i]) {
1773 validReqs[i] = 1;
1774 IReq* ireq = new (&reqs[i*pooledReqSize]) IReq(buf, count, type, src, tag, comm, ddt, sts);
1775 startIdx = i+1;
1776 return ireq;
1779 CkAbort("AMPI> failed to find a free request in pool!");
1780 return NULL;
1783 inline SendReq* newSendReq() noexcept {
1784 if (validReqs.all()) {
1785 return new SendReq();
1786 } else {
1787 for (int i=startIdx; i<validReqs.size(); i++) {
1788 if (!validReqs[i]) {
1789 validReqs[i] = 1;
1790 SendReq* sreq = new (&reqs[i*pooledReqSize]) SendReq();
1791 startIdx = i+1;
1792 return sreq;
1795 CkAbort("AMPI> failed to find a free request in pool!");
1796 return NULL;
1799 inline SendReq* newSendReq(MPI_Datatype type, MPI_Comm comm, CkDDT* ddt, AmpiReqSts sts=AMPI_REQ_PENDING) noexcept {
1800 if (validReqs.all()) {
1801 return new SendReq(type, comm, ddt, sts);
1802 } else {
1803 for (int i=startIdx; i<validReqs.size(); i++) {
1804 if (!validReqs[i]) {
1805 validReqs[i] = 1;
1806 SendReq* sreq = new (&reqs[i*pooledReqSize]) SendReq(type, comm, ddt, sts);
1807 startIdx = i+1;
1808 return sreq;
1811 CkAbort("AMPI> failed to find a free request in pool!");
1812 return NULL;
1815 inline SendReq* newSendReq(void* buf, int count, MPI_Datatype type, int destRank, int tag,
1816 MPI_Comm comm, CkDDT* ddt, AmpiReqSts sts=AMPI_REQ_PENDING) noexcept {
1817 if (validReqs.all()) {
1818 return new SendReq(buf, count, type, destRank, tag, comm, ddt, sts);
1819 } else {
1820 for (int i=startIdx; i<validReqs.size(); i++) {
1821 if (!validReqs[i]) {
1822 validReqs[i] = 1;
1823 SendReq* sreq = new (&reqs[i*pooledReqSize]) SendReq(buf, count, type, destRank, tag, comm, ddt, sts);
1824 startIdx = i+1;
1825 return sreq;
1828 CkAbort("AMPI> failed to find a free request in pool!");
1829 return NULL;
1832 inline SsendReq* newSsendReq() noexcept {
1833 if (validReqs.all()) {
1834 return new SsendReq();
1835 } else {
1836 for (int i=startIdx; i<validReqs.size(); i++) {
1837 if (!validReqs[i]) {
1838 validReqs[i] = 1;
1839 SsendReq* sreq = new (&reqs[i*pooledReqSize]) SsendReq();
1840 startIdx = i+1;
1841 return sreq;
1844 CkAbort("AMPI> failed to find a free request in pool!");
1845 return NULL;
1848 inline SsendReq* newSsendReq(MPI_Datatype type, MPI_Comm comm, CkDDT* ddt, AmpiReqSts sts=AMPI_REQ_PENDING) noexcept {
1849 if (validReqs.all()) {
1850 return new SsendReq(type, comm, ddt, sts);
1851 } else {
1852 for (int i=startIdx; i<validReqs.size(); i++) {
1853 if (!validReqs[i]) {
1854 validReqs[i] = 1;
1855 SsendReq* sreq = new (&reqs[i*pooledReqSize]) SsendReq(type, comm, ddt, sts);
1856 startIdx = i+1;
1857 return sreq;
1860 CkAbort("AMPI> failed to find a free request in pool!");
1861 return NULL;
1864 inline SsendReq* newSsendReq(void* buf, int count, MPI_Datatype type, int dest, int tag,
1865 MPI_Comm comm, int src, CkDDT* ddt, AmpiReqSts sts=AMPI_REQ_PENDING) noexcept {
1866 if (validReqs.all()) {
1867 return new SsendReq(buf, count, type, dest, tag, comm, src, ddt, sts);
1868 } else {
1869 for (int i=startIdx; i<validReqs.size(); i++) {
1870 if (!validReqs[i]) {
1871 validReqs[i] = 1;
1872 SsendReq* sreq = new (&reqs[i*pooledReqSize]) SsendReq(buf, count, type, dest, tag, comm, src, ddt, sts);
1873 startIdx = i+1;
1874 return sreq;
1877 CkAbort("AMPI> failed to find a free request in pool!");
1878 return NULL;
1881 inline void deleteAmpiRequest(AmpiRequest* req) noexcept {
1882 if (req->isPooledType() &&
1883 ((char*)req >= &reqs.front() && (char*)req <= &reqs.back()))
1885 int idx = (int)((intptr_t)req - (intptr_t)&reqs[0]) / pooledReqSize;
1886 validReqs[idx] = 0;
1887 startIdx = std::min(idx, startIdx);
1888 } else {
1889 delete req;
// PUP routine for the request pool. Intentionally has no statements: `p` is unused here.
1892 void pup(PUP::er& p) noexcept {
1893 // Nothing to do here, because AmpiRequestList::pup will be the
1894 // one to actually PUP the AmpiRequest objects to/from the pool
1899 Our local representation of another AMPI
1900 array element. Used to keep track of incoming
1901 and outgoing message sequence numbers, and
1902 the out-of-order message list.
1904 class AmpiOtherElement {
1905 private:
1906 /// Next incoming and outgoing message sequence number
1907 CMK_REFNUM_TYPE seqIncoming, seqOutgoing;
1909 /// Number of messages in out-of-order queue (normally 0)
1910 uint16_t numOutOfOrder;
1912 public:
1913 /// seqIncoming starts from 1, b/c 0 means unsequenced
1914 /// seqOutgoing starts from 0, b/c this will be incremented for the first real seq #
1915 AmpiOtherElement() noexcept : seqIncoming(1), seqOutgoing(0), numOutOfOrder(0) {}
1917 /// Handle wrap around of unsigned type CMK_REFNUM_TYPE
1918 inline void incSeqIncoming() noexcept { seqIncoming++; if (seqIncoming==0) seqIncoming=1; }
1919 inline CMK_REFNUM_TYPE getSeqIncoming() const noexcept { return seqIncoming; }
1921 inline void incSeqOutgoing() noexcept { seqOutgoing++; if (seqOutgoing==0) seqOutgoing=1; }
1922 inline CMK_REFNUM_TYPE getSeqOutgoing() const noexcept { return seqOutgoing; }
1924 inline void incNumOutOfOrder() noexcept { numOutOfOrder++; }
1925 inline void decNumOutOfOrder() noexcept { numOutOfOrder--; }
1926 inline uint16_t getNumOutOfOrder() const noexcept { return numOutOfOrder; }
1928 PUPbytes(AmpiOtherElement)
1930 class AmpiSeqQ : private CkNoncopyable {
1931 CkMsgQ<AmpiMsg> out; // all out of order messages
1932 std::unordered_map<int, AmpiOtherElement> elements; // element info: indexed by seqIdx (comm rank)
1934 public:
1935 AmpiSeqQ() =default;
1936 AmpiSeqQ(int commSize) noexcept {
1937 elements.reserve(std::min(commSize, 64));
1939 ~AmpiSeqQ() =default;
1940 void pup(PUP::er &p) noexcept;
1942 /// Insert this message in the table. Returns the number
1943 /// of messages now available for the element.
1944 /// If 0, the message was out-of-order and is buffered.
1945 /// If 1, this message can be immediately processed.
1946 /// If >1, this message can be immediately processed,
1947 /// and you should call "getOutOfOrder" repeatedly.
1948 inline int put(int seqIdx, AmpiMsg *msg) noexcept {
1949 AmpiOtherElement &el = elements[seqIdx];
1950 if (msg->getSeq() == el.getSeqIncoming()) { // In order:
1951 el.incSeqIncoming();
1952 return 1+el.getNumOutOfOrder();
1954 else { // Out of order: stash message
1955 putOutOfOrder(seqIdx, msg);
1956 return 0;
1960 /// Is this message in order (return >0) or not (return 0)?
1961 /// Same as put() except we don't call putOutOfOrder() here,
1962 /// so the caller should do that separately
1963 inline int isInOrder(int srcRank, CMK_REFNUM_TYPE seq) noexcept {
1964 AmpiOtherElement &el = elements[srcRank];
1965 if (seq == el.getSeqIncoming()) { // In order:
1966 el.incSeqIncoming();
1967 return 1+el.getNumOutOfOrder();
1969 else { // Out of order: caller should stash message
1970 return 0;
1974 /// Get an out-of-order message from the table.
1975 /// (in-order messages never go into the table)
1976 AmpiMsg *getOutOfOrder(int seqIdx) noexcept;
1978 /// Stash an out-of-order message
1979 void putOutOfOrder(int seqIdx, AmpiMsg *msg) noexcept;
1981 /// Increment the outgoing sequence number.
1982 inline void incCollSeqOutgoing() noexcept {
1983 elements[COLL_SEQ_IDX].incSeqOutgoing();
1986 /// Return the next outgoing sequence number, and increment it.
1987 inline CMK_REFNUM_TYPE nextOutgoing(int destRank) noexcept {
1988 AmpiOtherElement &el = elements[destRank];
1989 el.incSeqOutgoing();
1990 return el.getSeqOutgoing();
1993 PUPmarshall(AmpiSeqQ)
1996 inline CProxy_ampi ampiCommStruct::getProxy() const noexcept {return ampiID;}
1997 const ampiCommStruct &universeComm2CommStruct(MPI_Comm universeNo) noexcept;
1999 // Max value of a predefined MPI_Op (values defined in ampi.h)
2000 #define AMPI_MAX_PREDEFINED_OP 13
2003 An ampiParent holds all the communicators and the TCharm thread
2004 for its children, which are bound to it.
2006 class ampiParent final : public CBase_ampiParent {
2007 private:
2008 TCharm *thread;
2009 CProxy_TCharm threads;
2011 public: // Communication state:
2012 int numBlockedReqs; // number of requests currently blocked on
2013 bool resumeOnRecv, resumeOnColl;
2014 AmpiRequestList ampiReqs;
2015 AmpiRequestPool reqPool;
2016 CkDDT myDDT;
2018 private:
2019 MPI_Comm worldNo; //My MPI_COMM_WORLD
2020 ampi *worldPtr; //AMPI element corresponding to MPI_COMM_WORLD
2021 ampiCommStruct worldStruct;
2023 CkPupPtrVec<ampiCommStruct> splitComm; //Communicators from MPI_Comm_split
2024 CkPupPtrVec<ampiCommStruct> groupComm; //Communicators from MPI_Comm_group
2025 CkPupPtrVec<ampiCommStruct> cartComm; //Communicators from MPI_Cart_create
2026 CkPupPtrVec<ampiCommStruct> graphComm; //Communicators from MPI_Graph_create
2027 CkPupPtrVec<ampiCommStruct> distGraphComm; //Communicators from MPI_Dist_graph_create
2028 CkPupPtrVec<ampiCommStruct> interComm; //Communicators from MPI_Intercomm_create
2029 CkPupPtrVec<ampiCommStruct> intraComm; //Communicators from MPI_Intercomm_merge
2031 CkPupPtrVec<groupStruct> groups; // "Wild" groups that don't have a communicator
2032 CkPupPtrVec<WinStruct> winStructList; //List of windows for one-sided communication
2033 CkPupPtrVec<InfoStruct> infos; // list of all MPI_Infos
2034 const std::array<MPI_User_function*, AMPI_MAX_PREDEFINED_OP+1>& predefinedOps; // owned by ampiNodeMgr
2035 vector<OpStruct> userOps; // list of any user-defined MPI_Ops
2036 vector<AmpiMsg *> matchedMsgs; // for use with MPI_Mprobe and MPI_Mrecv
2038 /* MPI_*_get_attr C binding returns a *pointer* to an integer,
2039 * so there needs to be some storage somewhere to point to.
2040 * All builtin keyvals are ints, except for MPI_WIN_BASE, which
2041 * is a pointer, and MPI_WIN_SIZE, which is an MPI_Aint. */
2042 int* kv_builtin_storage;
2043 MPI_Aint* win_size_storage;
2044 void** win_base_storage;
2045 CkPupPtrVec<KeyvalNode> kvlist;
2046 void* bsendBuffer; // NOTE: we don't actually use this for buffering of MPI_Bsend's,
2047 int bsendBufferSize; // we only keep track of it to return it from MPI_Buffer_detach
2049 // Intercommunicator creation:
2050 bool isTmpRProxySet;
2051 CProxy_ampi tmpRProxy;
2053 MPI_MigrateFn userAboutToMigrateFn, userJustMigratedFn;
2055 public:
2056 bool ampiInitCallDone;
2058 private:
2059 bool kv_set_builtin(int keyval, void* attribute_val) noexcept;
2060 bool kv_get_builtin(int keyval) noexcept;
2062 public:
2063 void prepareCtv() noexcept;
2065 MPI_Message putMatchedMsg(AmpiMsg* msg) noexcept {
2066 // Search thru matchedMsgs for any NULL ones first:
2067 for (int i=0; i<matchedMsgs.size(); i++) {
2068 if (matchedMsgs[i] == NULL) {
2069 matchedMsgs[i] = msg;
2070 return i;
2073 // No NULL entries, so create a new one:
2074 matchedMsgs.push_back(msg);
2075 return matchedMsgs.size() - 1;
2077 AmpiMsg* getMatchedMsg(MPI_Message message) noexcept {
2078 if (message == MPI_MESSAGE_NO_PROC || message == MPI_MESSAGE_NULL) {
2079 return NULL;
2081 CkAssert(message >= 0 && message < matchedMsgs.size());
2082 AmpiMsg* msg = matchedMsgs[message];
2083 // Mark this matchedMsg index NULL and free from back of vector:
2084 matchedMsgs[message] = NULL;
2085 while (matchedMsgs.back() == NULL) {
2086 matchedMsgs.pop_back();
2088 return msg;
2091 inline void attachBuffer(void *buffer, int size) noexcept {
2092 bsendBuffer = buffer;
2093 bsendBufferSize = size;
2095 inline void detachBuffer(void *buffer, int *size) noexcept {
2096 *(void **)buffer = bsendBuffer;
2097 *size = bsendBufferSize;
2099 inline bool isSplit(MPI_Comm comm) const noexcept {
2100 return (comm>=MPI_COMM_FIRST_SPLIT && comm<MPI_COMM_FIRST_GROUP);
2102 const ampiCommStruct &getSplit(MPI_Comm comm) const noexcept {
2103 int idx=comm-MPI_COMM_FIRST_SPLIT;
2104 if (idx>=splitComm.size()) CkAbort("Bad split communicator used");
2105 return *splitComm[idx];
2107 void splitChildRegister(const ampiCommStruct &s) noexcept;
2109 inline bool isGroup(MPI_Comm comm) const noexcept {
2110 return (comm>=MPI_COMM_FIRST_GROUP && comm<MPI_COMM_FIRST_CART);
2112 const ampiCommStruct &getGroup(MPI_Comm comm) const noexcept {
2113 int idx=comm-MPI_COMM_FIRST_GROUP;
2114 if (idx>=groupComm.size()) CkAbort("Bad group communicator used");
2115 return *groupComm[idx];
2117 void groupChildRegister(const ampiCommStruct &s) noexcept;
2118 inline bool isInGroups(MPI_Group group) const noexcept {
2119 return (group>=0 && group<groups.size());
2122 void cartChildRegister(const ampiCommStruct &s) noexcept;
2123 void graphChildRegister(const ampiCommStruct &s) noexcept;
2124 void distGraphChildRegister(const ampiCommStruct &s) noexcept;
2125 void interChildRegister(const ampiCommStruct &s) noexcept;
2126 void intraChildRegister(const ampiCommStruct &s) noexcept;
2128 public:
2129 ampiParent(MPI_Comm worldNo_,CProxy_TCharm threads_,int nRanks_) noexcept;
2130 ampiParent(CkMigrateMessage *msg) noexcept;
2131 void ckAboutToMigrate() noexcept;
2132 void ckJustMigrated() noexcept;
2133 void ckJustRestored() noexcept;
2134 void setUserAboutToMigrateFn(MPI_MigrateFn f) noexcept;
2135 void setUserJustMigratedFn(MPI_MigrateFn f) noexcept;
2136 ~ampiParent() noexcept;
2138 //Children call this when they are first created, or just migrated
2139 TCharm *registerAmpi(ampi *ptr,ampiCommStruct s,bool forMigration) noexcept;
2141 // exchange proxy info between two ampi proxies
2142 void ExchangeProxy(CProxy_ampi rproxy) noexcept {
2143 if(!isTmpRProxySet){ tmpRProxy=rproxy; isTmpRProxySet=true; }
2144 else{ tmpRProxy.setRemoteProxy(rproxy); rproxy.setRemoteProxy(tmpRProxy); isTmpRProxySet=false; }
2147 //Grab the next available split/group communicator
// Each getter returns the first handle of its communicator kind plus the
// number of communicators of that kind currently stored, i.e. the handle
// one past the last stored entry.
2148 MPI_Comm getNextSplit() const noexcept {return MPI_COMM_FIRST_SPLIT+splitComm.size();}
2149 MPI_Comm getNextGroup() const noexcept {return MPI_COMM_FIRST_GROUP+groupComm.size();}
2150 MPI_Comm getNextCart() const noexcept {return MPI_COMM_FIRST_CART+cartComm.size();}
2151 MPI_Comm getNextGraph() const noexcept {return MPI_COMM_FIRST_GRAPH+graphComm.size();}
2152 MPI_Comm getNextDistGraph() const noexcept {return MPI_COMM_FIRST_DIST_GRAPH+distGraphComm.size();}
2153 MPI_Comm getNextInter() const noexcept {return MPI_COMM_FIRST_INTER+interComm.size();}
2154 MPI_Comm getNextIntra() const noexcept {return MPI_COMM_FIRST_INTRA+intraComm.size();}
2156 inline bool isCart(MPI_Comm comm) const noexcept {
2157 return (comm>=MPI_COMM_FIRST_CART && comm<MPI_COMM_FIRST_GRAPH);
2159 ampiCommStruct &getCart(MPI_Comm comm) const noexcept {
2160 int idx=comm-MPI_COMM_FIRST_CART;
2161 if (idx>=cartComm.size()) CkAbort("AMPI> Bad cartesian communicator used!\n");
2162 return *cartComm[idx];
2164 inline bool isGraph(MPI_Comm comm) const noexcept {
2165 return (comm>=MPI_COMM_FIRST_GRAPH && comm<MPI_COMM_FIRST_DIST_GRAPH);
2167 ampiCommStruct &getGraph(MPI_Comm comm) const noexcept {
2168 int idx=comm-MPI_COMM_FIRST_GRAPH;
2169 if (idx>=graphComm.size()) CkAbort("AMPI> Bad graph communicator used!\n");
2170 return *graphComm[idx];
2172 inline bool isDistGraph(MPI_Comm comm) const noexcept {
2173 return (comm >= MPI_COMM_FIRST_DIST_GRAPH && comm < MPI_COMM_FIRST_INTER);
2175 ampiCommStruct &getDistGraph(MPI_Comm comm) const noexcept {
2176 int idx = comm-MPI_COMM_FIRST_DIST_GRAPH;
2177 if (idx>=distGraphComm.size()) CkAbort("Bad distributed graph communicator used");
2178 return *distGraphComm[idx];
2180 inline bool isInter(MPI_Comm comm) const noexcept {
2181 return (comm>=MPI_COMM_FIRST_INTER && comm<MPI_COMM_FIRST_INTRA);
2183 const ampiCommStruct &getInter(MPI_Comm comm) const noexcept {
2184 int idx=comm-MPI_COMM_FIRST_INTER;
2185 if (idx>=interComm.size()) CkAbort("AMPI> Bad inter-communicator used!\n");
2186 return *interComm[idx];
2188 inline bool isIntra(MPI_Comm comm) const noexcept {
2189 return (comm>=MPI_COMM_FIRST_INTRA && comm<MPI_COMM_FIRST_RESVD);
2191 const ampiCommStruct &getIntra(MPI_Comm comm) const noexcept {
2192 int idx=comm-MPI_COMM_FIRST_INTRA;
2193 if (idx>=intraComm.size()) CkAbort("Bad intra-communicator used");
2194 return *intraComm[idx];
2197 void pup(PUP::er &p) noexcept;
2199 void startCheckpoint(const char* dname) noexcept;
2200 void Checkpoint(int len, const char* dname) noexcept;
2201 void ResumeThread() noexcept;
2202 TCharm* getTCharmThread() const noexcept {return thread;}
2203 inline ampiParent* blockOnRecv() noexcept;
2204 inline CkDDT* getDDT() noexcept { return &myDDT; }
2206 #if CMK_LBDB_ON
2207 void setMigratable(bool mig) noexcept {
2208 thread->setMigratable(mig);
2210 #endif
2212 inline const ampiCommStruct &comm2CommStruct(MPI_Comm comm) const noexcept {
2213 if (comm==MPI_COMM_WORLD) return worldStruct;
2214 if (comm==worldNo) return worldStruct;
2215 if (isSplit(comm)) return getSplit(comm);
2216 if (isGroup(comm)) return getGroup(comm);
2217 if (isCart(comm)) return getCart(comm);
2218 if (isGraph(comm)) return getGraph(comm);
2219 if (isDistGraph(comm)) return getDistGraph(comm);
2220 if (isInter(comm)) return getInter(comm);
2221 if (isIntra(comm)) return getIntra(comm);
2222 return universeComm2CommStruct(comm);
2225 inline vector<int>& getKeyvals(MPI_Comm comm) noexcept {
2226 ampiCommStruct &cs = *(ampiCommStruct *)&comm2CommStruct(comm);
2227 return cs.getKeyvals();
2230 inline ampi *comm2ampi(MPI_Comm comm) const noexcept {
2231 if (comm==MPI_COMM_WORLD) return worldPtr;
2232 if (comm==worldNo) return worldPtr;
2233 if (isSplit(comm)) {
2234 const ampiCommStruct &st=getSplit(comm);
2235 return st.getProxy()[thisIndex].ckLocal();
2237 if (isGroup(comm)) {
2238 const ampiCommStruct &st=getGroup(comm);
2239 return st.getProxy()[thisIndex].ckLocal();
2241 if (isCart(comm)) {
2242 const ampiCommStruct &st = getCart(comm);
2243 return st.getProxy()[thisIndex].ckLocal();
2245 if (isGraph(comm)) {
2246 const ampiCommStruct &st = getGraph(comm);
2247 return st.getProxy()[thisIndex].ckLocal();
2249 if (isDistGraph(comm)) {
2250 const ampiCommStruct &st = getDistGraph(comm);
2251 return st.getProxy()[thisIndex].ckLocal();
2253 if (isInter(comm)) {
2254 const ampiCommStruct &st=getInter(comm);
2255 return st.getProxy()[thisIndex].ckLocal();
2257 if (isIntra(comm)) {
2258 const ampiCommStruct &st=getIntra(comm);
2259 return st.getProxy()[thisIndex].ckLocal();
2261 if (comm>MPI_COMM_WORLD) return worldPtr; //Use MPI_WORLD ampi for cross-world messages:
2262 CkAbort("Invalid communicator used!");
2263 return NULL;
2266 inline bool hasComm(const MPI_Group group) const noexcept {
2267 MPI_Comm comm = (MPI_Comm)group;
2268 return ( comm==MPI_COMM_WORLD || comm==worldNo || isSplit(comm) || isGroup(comm) ||
2269 isCart(comm) || isGraph(comm) || isDistGraph(comm) || isIntra(comm) );
2270 //isInter omitted because its comm number != its group number
2272 inline const groupStruct group2vec(MPI_Group group) const noexcept {
2273 if(group == MPI_GROUP_NULL || group == MPI_GROUP_EMPTY)
2274 return groupStruct();
2275 if(hasComm(group))
2276 return comm2CommStruct((MPI_Comm)group).getIndices();
2277 if(isInGroups(group))
2278 return *groups[group];
2279 CkAbort("ampiParent::group2vec: Invalid group id!");
2280 return *groups[0]; //meaningless return
2282 inline MPI_Group saveGroupStruct(groupStruct vec) noexcept {
2283 if (vec.empty()) return MPI_GROUP_EMPTY;
2284 int idx = groups.size();
2285 groups.resize(idx+1);
2286 groups[idx]=new groupStruct(vec);
2287 return (MPI_Group)idx;
2289 inline int getRank(const MPI_Group group) const noexcept {
2290 groupStruct vec = group2vec(group);
2291 return getPosOp(thisIndex,vec);
// Mutable access to this parent's request list (ampiReqs).
2293 inline AmpiRequestList &getReqs() noexcept { return ampiReqs; }
// Returns CkMyPe() at the time of the call.
2294 inline int getMyPe() const noexcept {
2295 return CkMyPe();
// True once worldPtr is non-NULL.
2297 inline bool hasWorld() const noexcept {
2298 return worldPtr!=NULL;
// Validate an MPI_Comm handle; calls CkAbort("Invalid MPI_Comm\n") when any
// of the OR'ed clauses below is true. Clauses 2-8 reject a handle that falls
// in a per-kind range but whose index is at or past the end of that kind's list.
// NOTE(review): as written, the first clause is true for every handle other
// than MPI_COMM_SELF and MPI_COMM_WORLD, so it appears to reject all
// split/group/cart/graph/dist-graph/inter/intra handles regardless of the
// size checks that follow. Confirm the intended condition.
2301 inline void checkComm(MPI_Comm comm) const noexcept {
2302 if ((comm != MPI_COMM_SELF && comm != MPI_COMM_WORLD)
2303 || (isSplit(comm) && comm-MPI_COMM_FIRST_SPLIT >= splitComm.size())
2304 || (isGroup(comm) && comm-MPI_COMM_FIRST_GROUP >= groupComm.size())
2305 || (isCart(comm) && comm-MPI_COMM_FIRST_CART >= cartComm.size())
2306 || (isGraph(comm) && comm-MPI_COMM_FIRST_GRAPH >= graphComm.size())
2307 || (isDistGraph(comm) && comm-MPI_COMM_FIRST_DIST_GRAPH >= distGraphComm.size())
2308 || (isInter(comm) && comm-MPI_COMM_FIRST_INTER >= interComm.size())
2309 || (isIntra(comm) && comm-MPI_COMM_FIRST_INTRA >= intraComm.size()) )
2310 CkAbort("Invalid MPI_Comm\n");
2313 /// if intra-communicator, return comm, otherwise return null group
2314 inline MPI_Group comm2group(const MPI_Comm comm) const noexcept {
2315 if(isInter(comm)) return MPI_GROUP_NULL; // we don't support inter-communicator in such functions
2316 ampiCommStruct s = comm2CommStruct(comm);
2317 if(comm!=MPI_COMM_WORLD && comm!=s.getComm()) CkAbort("Error in ampiParent::comm2group()");
2318 return (MPI_Group)(s.getComm());
2321 inline int getRemoteSize(const MPI_Comm comm) const noexcept {
2322 if(isInter(comm)) return getInter(comm).getRemoteIndices().size();
2323 else return -1;
2325 inline MPI_Group getRemoteGroup(const MPI_Comm comm) noexcept {
2326 if(isInter(comm)) return saveGroupStruct(getInter(comm).getRemoteIndices());
2327 else return MPI_GROUP_NULL;
2330 int createKeyval(MPI_Copy_function *copy_fn, MPI_Delete_function *delete_fn,
2331 int *keyval, void* extra_state) noexcept;
2332 bool getBuiltinKeyval(int keyval, void *attribute_val) noexcept;
2333 int setUserKeyval(MPI_Comm comm, int keyval, void *attribute_val) noexcept;
2334 bool getUserKeyval(MPI_Comm comm, vector<int>& keyvals, int keyval, void *attribute_val, int *flag) noexcept;
2335 int dupUserKeyvals(MPI_Comm old_comm, MPI_Comm new_comm) noexcept;
2336 int freeUserKeyval(int context, vector<int>& keyvals, int *keyval) noexcept;
2337 int freeUserKeyvals(int context, vector<int>& keyvals) noexcept;
2339 int setAttr(MPI_Comm comm, vector<int>& keyvals, int keyval, void *attribute_val) noexcept;
2340 int getAttr(MPI_Comm comm, vector<int>& keyvals, int keyval, void *attribute_val, int *flag) noexcept;
2341 int deleteAttr(MPI_Comm comm, vector<int>& keyvals, int keyval) noexcept;
2343 int addWinStruct(WinStruct *win) noexcept;
2344 WinStruct *getWinStruct(MPI_Win win) const noexcept;
2345 void removeWinStruct(WinStruct *win) noexcept;
2347 int createInfo(MPI_Info *newinfo) noexcept;
2348 int dupInfo(MPI_Info info, MPI_Info *newinfo) noexcept;
2349 int setInfo(MPI_Info info, const char *key, const char *value) noexcept;
2350 int deleteInfo(MPI_Info info, const char *key) noexcept;
2351 int getInfo(MPI_Info info, const char *key, int valuelen, char *value, int *flag) const noexcept;
2352 int getInfoValuelen(MPI_Info info, const char *key, int *valuelen, int *flag) const noexcept;
2353 int getInfoNkeys(MPI_Info info, int *nkeys) const noexcept;
2354 int getInfoNthkey(MPI_Info info, int n, char *key) const noexcept;
2355 int freeInfo(MPI_Info info) noexcept;
2356 void defineInfoEnv(int nRanks_) noexcept;
2357 void defineInfoMigration() noexcept;
2359 // An 'MPI_Op' is an integer that indexes into either:
2360 // A) an array of predefined ops owned by ampiNodeMgr, or
2361 // B) a vector of user-defined ops owned by ampiParent
2362 // The MPI_Op is compared to AMPI_MAX_PREDEFINED_OP to disambiguate.
2363 inline int createOp(MPI_User_function *fn, bool isCommutative) noexcept {
2364 // Search thru non-predefined op's for any invalidated ones:
2365 for (int i=0; i<userOps.size(); i++) {
2366 if (userOps[i].isFree()) {
2367 userOps[i].init(fn, isCommutative);
2368 return AMPI_MAX_PREDEFINED_OP + 1 + i;
2371 // No invalid entries, so create a new one:
2372 userOps.emplace_back(fn, isCommutative);
2373 return AMPI_MAX_PREDEFINED_OP + userOps.size();
2375 inline void freeOp(MPI_Op op) noexcept {
2376 // Don't free predefined op's:
2377 if (!opIsPredefined(op)) {
2378 // Invalidate op, then free all invalid op's from the back of the userOp's vector
2379 int opIdx = op - 1 - AMPI_MAX_PREDEFINED_OP;
2380 CkAssert(opIdx < userOps.size());
2381 userOps[opIdx].free();
2382 while (!userOps.empty() && userOps.back().isFree()) {
2383 userOps.pop_back();
2387 inline bool opIsPredefined(MPI_Op op) const noexcept {
2388 return (op <= AMPI_MAX_PREDEFINED_OP);
2390 inline bool opIsCommutative(MPI_Op op) const noexcept {
2391 if (opIsPredefined(op)) {
2392 return true; // all predefined ops are commutative
2394 else {
2395 int opIdx = op - 1 - AMPI_MAX_PREDEFINED_OP;
2396 CkAssert(opIdx < userOps.size());
2397 return userOps[opIdx].isCommutative;
2400 inline MPI_User_function* op2User_function(MPI_Op op) const noexcept {
2401 if (opIsPredefined(op)) {
2402 return predefinedOps[op];
2404 else {
2405 int opIdx = op - 1 - AMPI_MAX_PREDEFINED_OP;
2406 CkAssert(opIdx < userOps.size());
2407 return userOps[opIdx].func;
2410 inline AmpiOpHeader op2AmpiOpHeader(MPI_Op op, MPI_Datatype type, int count) const noexcept {
2411 if (opIsPredefined(op)) {
2412 int size = myDDT.getType(type)->getSize(count);
2413 return AmpiOpHeader(predefinedOps[op], type, count, size);
2415 else {
2416 int opIdx = op - 1 - AMPI_MAX_PREDEFINED_OP;
2417 CkAssert(opIdx < userOps.size());
2418 int size = myDDT.getType(type)->getSize(count);
2419 return AmpiOpHeader(userOps[opIdx].func, type, count, size);
2422 inline void applyOp(MPI_Datatype datatype, MPI_Op op, int count, const void* invec, void* inoutvec) const noexcept {
2423 // inoutvec[i] = invec[i] op inoutvec[i]
2424 MPI_User_function *func = op2User_function(op);
2425 (func)((void*)invec, inoutvec, &count, &datatype);
2428 void init() noexcept;
2429 void finalize() noexcept;
2430 void block() noexcept;
2431 void yield() noexcept;
2433 #if AMPI_PRINT_MSG_SIZES
2434 // Map of AMPI routine names to message sizes and number of messages:
2435 // ["AMPI_Routine"][ [msg_size][num_msgs] ]
2436 std::unordered_map<std::string, std::map<int, int> > msgSizes;
2437 inline bool isRankRecordingMsgSizes() noexcept;
2438 inline void recordMsgSize(const char* func, int msgSize) noexcept;
2439 void printMsgSizes() noexcept;
2440 #endif
2442 #if AMPIMSGLOG
2443 /* message logging */
2444 int pupBytes;
2445 #if CMK_USE_ZLIB && 0
2446 gzFile fMsgLog;
2447 PUP::tozDisk *toPUPer;
2448 PUP::fromzDisk *fromPUPer;
2449 #else
2450 FILE* fMsgLog;
2451 PUP::toDisk *toPUPer;
2452 PUP::fromDisk *fromPUPer;
2453 #endif
2454 #endif
2457 // Store a generalized request class created by MPIX_Grequest_class_create
2458 class greq_class_desc {
2459 public:
2460 MPI_Grequest_query_function *query_fn;
2461 MPI_Grequest_free_function *free_fn;
2462 MPI_Grequest_cancel_function *cancel_fn;
2463 MPIX_Grequest_poll_function *poll_fn;
2464 MPIX_Grequest_wait_function *wait_fn;
2466 void pup(PUP::er &p) noexcept {
2467 p((char *)query_fn, sizeof(void *));
2468 p((char *)free_fn, sizeof(void *));
2469 p((char *)cancel_fn, sizeof(void *));
2470 p((char *)poll_fn, sizeof(void *));
2471 p((char *)wait_fn, sizeof(void *));
2476 An ampi manages the communication of one thread over
2477 one MPI communicator.
2479 class ampi final : public CBase_ampi {
2480 private:
2481 friend class IReq; // for checking resumeOnRecv
2482 friend class SendReq;
2483 friend class SsendReq;
2484 friend class RednReq;
2485 friend class GatherReq;
2486 friend class GathervReq;
2488 ampiParent *parent;
2489 CProxy_ampiParent parentProxy;
2490 TCharm *thread;
2492 AmpiRequest *blockingReq;
2493 int myRank;
2494 AmpiSeqQ oorder;
2496 public:
2498 * AMPI Message Matching (Amm) queues are indexed by the tag and sender.
2499 * Since ampi objects are per-communicator, there are separate Amm's per communicator.
2501 Amm<AmpiRequest *> postedReqs;
2502 Amm<AmpiMsg *> unexpectedMsgs;
2504 // Store generalized request classes created by MPIX_Grequest_class_create
2505 vector<greq_class_desc> greq_classes;
2507 private:
2508 ampiCommStruct myComm;
2509 groupStruct tmpVec; // stores temp group info
2510 CProxy_ampi remoteProxy; // valid only for intercommunicator
2511 CkPupPtrVec<win_obj> winObjects;
2513 private:
2514 void inorder(AmpiMsg *msg) noexcept;
2515 void inorderRdma(char* buf, int size, CMK_REFNUM_TYPE seq, int tag, int srcRank,
2516 MPI_Comm comm, int ssendReq) noexcept;
2518 void init() noexcept;
2519 void findParent(bool forMigration) noexcept;
2521 public: // entry methods
2522 ampi() noexcept;
2523 ampi(CkArrayID parent_,const ampiCommStruct &s) noexcept;
2524 ampi(CkMigrateMessage *msg) noexcept;
2525 void ckJustMigrated() noexcept;
2526 void ckJustRestored() noexcept;
2527 ~ampi() noexcept;
2529 void pup(PUP::er &p) noexcept;
2531 void allInitDone() noexcept;
2532 void setInitDoneFlag() noexcept;
2534 void unblock() noexcept;
2535 void generic(AmpiMsg *) noexcept;
2536 void genericRdma(char* buf, int size, CMK_REFNUM_TYPE seq, int tag, int srcRank,
2537 MPI_Comm destcomm, int ssendReq) noexcept;
2538 void completedRdmaSend(CkDataMsg *msg) noexcept;
2539 void ssend_ack(int sreq) noexcept;
2540 void barrierResult() noexcept;
2541 void ibarrierResult() noexcept;
2542 void rednResult(CkReductionMsg *msg) noexcept;
2543 void irednResult(CkReductionMsg *msg) noexcept;
2545 void splitPhase1(CkReductionMsg *msg) noexcept;
2546 void splitPhaseInter(CkReductionMsg *msg) noexcept;
2547 void commCreatePhase1(MPI_Comm nextGroupComm) noexcept;
2548 void intercommCreatePhase1(MPI_Comm nextInterComm) noexcept;
2549 void intercommMergePhase1(MPI_Comm nextIntraComm) noexcept;
2551 private: // Used by the above entry methods that create new MPI_Comm objects
2552 CProxy_ampi createNewChildAmpiSync() noexcept;
2553 void insertNewChildAmpiElements(MPI_Comm newComm, CProxy_ampi newAmpi) noexcept;
2555 inline void handleBlockedReq(AmpiRequest* req) noexcept {
2556 if (req->isBlocked() && parent->numBlockedReqs != 0) {
2557 parent->numBlockedReqs--;
2560 inline void resumeThreadIfReady() noexcept {
2561 if (parent->resumeOnRecv && parent->numBlockedReqs == 0) {
2562 thread->resume();
2566 public: // to be used by MPI_* functions
2567 inline const ampiCommStruct &comm2CommStruct(MPI_Comm comm) const noexcept {
2568 return parent->comm2CommStruct(comm);
2571 inline ampi* blockOnRecv() noexcept;
2572 inline ampi* blockOnColl() noexcept;
2573 inline ampi* blockOnRedn(AmpiRequest *req) noexcept;
2574 MPI_Request postReq(AmpiRequest* newreq) noexcept;
2576 inline CMK_REFNUM_TYPE getSeqNo(int destRank, MPI_Comm destcomm, int tag) noexcept;
2577 AmpiMsg *makeBcastMsg(const void *buf,int count,MPI_Datatype type,MPI_Comm destcomm) noexcept;
2578 AmpiMsg *makeAmpiMsg(int destRank,int t,int sRank,const void *buf,int count,
2579 MPI_Datatype type,MPI_Comm destcomm, int ssendReq=0) noexcept;
2581 MPI_Request send(int t, int s, const void* buf, int count, MPI_Datatype type, int rank,
2582 MPI_Comm destcomm, int ssendReq=0, AmpiSendType sendType=BLOCKING_SEND) noexcept;
2583 static void sendraw(int t, int s, void* buf, int len, CkArrayID aid, int idx) noexcept;
2584 inline MPI_Request sendLocalMsg(int t, int sRank, const void* buf, int size, MPI_Datatype type, int destRank,
2585 MPI_Comm destcomm, ampi* destPtr, int ssendReq, AmpiSendType sendType) noexcept;
2586 inline MPI_Request sendRdmaMsg(int t, int sRank, const void* buf, int size, MPI_Datatype type, int destIdx,
2587 int destRank, MPI_Comm destcomm, CProxy_ampi arrProxy, int ssendReq) noexcept;
2588 inline bool destLikelyWithinProcess(CProxy_ampi arrProxy, int destIdx) const noexcept {
2589 CkArray* localBranch = arrProxy.ckLocalBranch();
2590 int destPe = localBranch->lastKnown(CkArrayIndex1D(destIdx));
2591 return (CkNodeOf(destPe) == CkMyNode());
2593 MPI_Request delesend(int t, int s, const void* buf, int count, MPI_Datatype type, int rank,
2594 MPI_Comm destcomm, CProxy_ampi arrproxy, int ssend, AmpiSendType sendType) noexcept;
2595 inline void processAmpiMsg(AmpiMsg *msg, void* buf, MPI_Datatype type, int count) noexcept;
2596 inline void processRdmaMsg(const void *sbuf, int slength, int ssendReq, int srank, void* rbuf,
2597 int rcount, MPI_Datatype rtype, MPI_Comm comm) noexcept;
2598 inline void processRednMsg(CkReductionMsg *msg, void* buf, MPI_Datatype type, int count) noexcept;
2599 inline void processNoncommutativeRednMsg(CkReductionMsg *msg, void* buf, MPI_Datatype type, int count,
2600 MPI_User_function* func) noexcept;
2601 inline void processGatherMsg(CkReductionMsg *msg, void* buf, MPI_Datatype type, int recvCount) noexcept;
2602 inline void processGathervMsg(CkReductionMsg *msg, void* buf, MPI_Datatype type,
2603 int* recvCounts, int* displs) noexcept;
2604 inline AmpiMsg * getMessage(int t, int s, MPI_Comm comm, int *sts) const noexcept;
2605 int recv(int t,int s,void* buf,int count,MPI_Datatype type,MPI_Comm comm,MPI_Status *sts=NULL) noexcept;
2606 void irecv(void *buf, int count, MPI_Datatype type, int src,
2607 int tag, MPI_Comm comm, MPI_Request *request) noexcept;
2608 void mrecv(int tag, int src, void* buf, int count, MPI_Datatype datatype, MPI_Comm comm,
2609 MPI_Status* status, MPI_Message* message) noexcept;
2610 void imrecv(void* buf, int count, MPI_Datatype datatype, int src, int tag, MPI_Comm comm,
2611 MPI_Request* request, MPI_Message* message) noexcept;
2612 void sendrecv(const void *sbuf, int scount, MPI_Datatype stype, int dest, int stag,
2613 void *rbuf, int rcount, MPI_Datatype rtype, int src, int rtag,
2614 MPI_Comm comm, MPI_Status *sts) noexcept;
2615 void sendrecv_replace(void* buf, int count, MPI_Datatype datatype,
2616 int dest, int sendtag, int source, int recvtag,
2617 MPI_Comm comm, MPI_Status *status) noexcept;
2618 void probe(int t,int s,MPI_Comm comm,MPI_Status *sts) noexcept;
2619 void mprobe(int t, int s, MPI_Comm comm, MPI_Status *sts, MPI_Message *message) noexcept;
2620 int iprobe(int t,int s,MPI_Comm comm,MPI_Status *sts) noexcept;
2621 int improbe(int t, int s, MPI_Comm comm, MPI_Status *sts, MPI_Message *message) noexcept;
2622 void barrier() noexcept;
2623 void ibarrier(MPI_Request *request) noexcept;
2624 void bcast(int root, void* buf, int count, MPI_Datatype type, MPI_Comm comm) noexcept;
2625 int intercomm_bcast(int root, void* buf, int count, MPI_Datatype type, MPI_Comm intercomm) noexcept;
2626 void ibcast(int root, void* buf, int count, MPI_Datatype type, MPI_Comm comm, MPI_Request* request) noexcept;
2627 int intercomm_ibcast(int root, void* buf, int count, MPI_Datatype type, MPI_Comm intercomm, MPI_Request *request) noexcept;
2628 static void bcastraw(void* buf, int len, CkArrayID aid) noexcept;
2629 void split(int color,int key,MPI_Comm *dest, int type) noexcept;
2630 void commCreate(const groupStruct vec,MPI_Comm *newcomm) noexcept;
2631 MPI_Comm cartCreate0D() noexcept;
2632 MPI_Comm cartCreate(groupStruct vec, int ndims, const int* dims) noexcept;
2633 void graphCreate(const groupStruct vec, MPI_Comm *newcomm) noexcept;
2634 void distGraphCreate(const groupStruct vec, MPI_Comm *newcomm) noexcept;
2635 void intercommCreate(const groupStruct rvec, int root, MPI_Comm tcomm, MPI_Comm *ncomm) noexcept;
2637 inline bool isInter() const noexcept { return myComm.isinter(); }
2638 void intercommMerge(int first, MPI_Comm *ncomm) noexcept;
// Index of the parent element (parent->thisIndex).
2640 inline int getWorldRank() const noexcept {return parent->thisIndex;}
2641 /// Return our rank in this communicator
2642 inline int getRank() const noexcept {return myRank;}
// The following accessors forward to myComm, this element's ampiCommStruct.
2643 inline int getSize() const noexcept {return myComm.getSize();}
2644 inline MPI_Comm getComm() const noexcept {return myComm.getComm();}
2645 inline void setCommName(const char *name) noexcept {myComm.setName(name);}
2646 inline void getCommName(char *name, int *len) const noexcept {myComm.getName(name,len);}
// Both index getters return a vector by value.
2647 inline vector<int> getIndices() const noexcept { return myComm.getIndices(); }
2648 inline vector<int> getRemoteIndices() const noexcept { return myComm.getRemoteIndices(); }
2649 inline const CProxy_ampi &getProxy() const noexcept {return thisProxy;}
2650 inline const CProxy_ampi &getRemoteProxy() const noexcept {return remoteProxy;}
// Besides storing the proxy, this also resumes the thread.
2651 inline void setRemoteProxy(CProxy_ampi rproxy) noexcept { remoteProxy = rproxy; thread->resume(); }
2652 inline int getIndexForRank(int r) const noexcept {return myComm.getIndexForRank(r);}
2653 inline int getIndexForRemoteRank(int r) const noexcept {return myComm.getIndexForRemoteRank(r);}
// Declaration only. const member taking the neighbor vector by non-const reference
// (NOTE(review): presumably an out-parameter filled for `rank` in `comm` -- confirm at the definition).
2654 void findNeighbors(MPI_Comm comm, int rank, vector<int>& neighbors) const noexcept;
// Returns, by const reference, the list obtained from myComm.getTopologyforNeighbors()->getnbors().
2655 inline const vector<int>& getNeighbors() const noexcept { return myComm.getTopologyforNeighbors()->getnbors(); }
// Operator queries are delegated unchanged to the parent object.
2656 inline bool opIsCommutative(MPI_Op op) const noexcept { return parent->opIsCommutative(op); }
2657 inline MPI_User_function* op2User_function(MPI_Op op) const noexcept { return parent->op2User_function(op); }
// Declaration only; takes a topology type, a rank, a source communicator and an MPI_Comm* for the result.
2658 void topoDup(int topoType, int rank, MPI_Comm comm, MPI_Comm *newcomm) noexcept;
// Returns a mutable reference to the parent's ampiReqs member.
2660 inline AmpiRequestList& getReqs() noexcept { return parent->ampiReqs; }
// Returns the address of the parent's myDDT member.
2661 CkDDT *getDDT() noexcept {return &parent->myDDT;}
// Forwards to thread->getThread().
2662 CthThread getThread() const noexcept { return thread->getThread(); }
2664 public:
// Window (MPI_Win) support: creation/deletion and put/get declarations; bodies are not in this header chunk.
// Naming pattern visible in the signatures:
//  - win<Op>(..., int rank, ..., WinStruct *win) variants take a rank and a WinStruct pointer;
//  - winRemote<Op> / winLocal<Op> variants take an int winIndex instead.
// NOTE(review): presumably the former are the origin-side calls and the latter run where the
// window memory lives -- confirm against the definitions.
2665 MPI_Win createWinInstance(void *base, MPI_Aint size, int disp_unit, MPI_Info info) noexcept;
2666 int deleteWinInstance(MPI_Win win) noexcept;
2667 int winGetGroup(WinStruct *win, MPI_Group *group) const noexcept;
2668 int winPut(const void *orgaddr, int orgcnt, MPI_Datatype orgtype, int rank,
2669 MPI_Aint targdisp, int targcnt, MPI_Datatype targtype, WinStruct *win) noexcept;
2670 int winGet(void *orgaddr, int orgcnt, MPI_Datatype orgtype, int rank,
2671 MPI_Aint targdisp, int targcnt, MPI_Datatype targtype, WinStruct *win) noexcept;
// Request-based get: takes an origin displacement (not an address) and an MPI_Request* for the caller;
// winIgetWait / winIgetFree take that request together with an MPI_Status*.
2672 int winIget(MPI_Aint orgdisp, int orgcnt, MPI_Datatype orgtype, int rank,
2673 MPI_Aint targdisp, int targcnt, MPI_Datatype targtype, WinStruct *win,
2674 MPI_Request *req) noexcept;
2675 int winIgetWait(MPI_Request *request, MPI_Status *status) noexcept;
2676 int winIgetFree(MPI_Request *request, MPI_Status *status) noexcept;
// winIndex-addressed counterparts of the put/get operations above.
2677 void winRemotePut(int orgtotalsize, char* orgaddr, int orgcnt, MPI_Datatype orgtype,
2678 MPI_Aint targdisp, int targcnt, MPI_Datatype targtype, int winIndex) noexcept;
2679 char* winLocalGet(int orgcnt, MPI_Datatype orgtype, MPI_Aint targdisp, int targcnt,
2680 MPI_Datatype targtype, int winIndex) noexcept;
2681 AmpiMsg* winRemoteGet(int orgcnt, MPI_Datatype orgtype, MPI_Aint targdisp,
2682 int targcnt, MPI_Datatype targtype, int winIndex) noexcept;
2683 AmpiMsg* winRemoteIget(MPI_Aint orgdisp, int orgcnt, MPI_Datatype orgtype, MPI_Aint targdisp,
2684 int targcnt, MPI_Datatype targtype, int winIndex) noexcept;
// Window locking declarations: rank + WinStruct* variants, and winRemote* variants that take
// a winIndex and the requesting rank.
2685 int winLock(int lock_type, int rank, WinStruct *win) noexcept;
2686 int winUnlock(int rank, WinStruct *win) noexcept;
2687 void winRemoteLock(int lock_type, int winIndex, int requestRank) noexcept;
2688 void winRemoteUnlock(int winIndex, int requestRank) noexcept;
// Accumulate family: each takes an MPI_Op in addition to the origin/target type descriptions.
2689 int winAccumulate(const void *orgaddr, int orgcnt, MPI_Datatype orgtype, int rank,
2690 MPI_Aint targdisp, int targcnt, MPI_Datatype targtype,
2691 MPI_Op op, WinStruct *win) noexcept;
2692 void winRemoteAccumulate(int orgtotalsize, char* orgaddr, int orgcnt, MPI_Datatype orgtype,
2693 MPI_Aint targdisp, int targcnt, MPI_Datatype targtype,
2694 MPI_Op op, int winIndex) noexcept;
// Get-accumulate variants additionally carry a result buffer (resaddr) or return an AmpiMsg*.
2695 int winGetAccumulate(const void *orgaddr, int orgcnt, MPI_Datatype orgtype, void *resaddr,
2696 int rescnt, MPI_Datatype restype, int rank, MPI_Aint targdisp,
2697 int targcnt, MPI_Datatype targtype, MPI_Op op, WinStruct *win) noexcept;
2698 void winLocalGetAccumulate(int orgtotalsize, char* sorgaddr, int orgcnt, MPI_Datatype orgtype,
2699 MPI_Aint targdisp, int targcnt, MPI_Datatype targtype, MPI_Op op,
2700 char *resaddr, int winIndex) noexcept;
2701 AmpiMsg* winRemoteGetAccumulate(int orgtotalsize, char* sorgaddr, int orgcnt, MPI_Datatype orgtype,
2702 MPI_Aint targdisp, int targcnt, MPI_Datatype targtype, MPI_Op op,
2703 int winIndex) noexcept;
// Compare-and-swap variants: a single MPI_Datatype describes the origin, compare and result operands.
2704 int winCompareAndSwap(const void *orgaddr, const void *compaddr, void *resaddr, MPI_Datatype type,
2705 int rank, MPI_Aint targdisp, WinStruct *win) noexcept;
2706 char* winLocalCompareAndSwap(int size, char* sorgaddr, char* compaddr, MPI_Datatype type,
2707 MPI_Aint targdisp, int winIndex) noexcept;
2708 AmpiMsg* winRemoteCompareAndSwap(int size, char *sorgaddr, char *compaddr, MPI_Datatype type,
2709 MPI_Aint targdisp, int winIndex) noexcept;
// Window naming and lookup declarations.
// winGetName is const and takes a caller-provided name buffer plus an int* length.
2710 void winSetName(WinStruct *win, const char *name) noexcept;
2711 void winGetName(WinStruct *win, char *name, int *length) const noexcept;
2712 win_obj* getWinObjInstance(WinStruct *win) const noexcept;
// Non-const: NOTE(review): presumably advances an internal id counter -- confirm at the definition.
2713 int getNewSemaId() noexcept;
// Scatter over an intercommunicator (declarations only). Four variants:
//  - intercomm_scatter / intercomm_scatterv: no request parameter;
//  - intercomm_iscatter / intercomm_iscatterv: same arguments plus an MPI_Request*.
// The *v forms take per-destination sendcounts and displs arrays instead of a single sendcount.
2715 int intercomm_scatter(int root, const void *sendbuf, int sendcount, MPI_Datatype sendtype,
2716 void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm intercomm) noexcept;
2717 int intercomm_iscatter(int root, const void *sendbuf, int sendcount, MPI_Datatype sendtype,
2718 void *recvbuf, int recvcount, MPI_Datatype recvtype,
2719 MPI_Comm intercomm, MPI_Request *request) noexcept;
2720 int intercomm_scatterv(int root, const void* sendbuf, const int* sendcounts, const int* displs,
2721 MPI_Datatype sendtype, void* recvbuf, int recvcount,
2722 MPI_Datatype recvtype, MPI_Comm intercomm) noexcept;
2723 int intercomm_iscatterv(int root, const void* sendbuf, const int* sendcounts, const int* displs,
2724 MPI_Datatype sendtype, void* recvbuf, int recvcount,
2725 MPI_Datatype recvtype, MPI_Comm intercomm, MPI_Request* request) noexcept;
// Helper declarations used by the MPI_* entry points (bodies are not in this header chunk).
// isAmpiThread() is the predicate the AMPI_API macro tests before tracing a routine.
2728 ampiParent *getAmpiParent() noexcept;
2729 bool isAmpiThread() noexcept;
2730 ampi *getAmpiInstance(MPI_Comm comm) noexcept;
2731 void checkComm(MPI_Comm comm) noexcept;
2732 void checkRequest(MPI_Request req) noexcept;
// Buffers are taken by reference to pointer, so the callee can rewrite the caller's pointer
// (NOTE(review): presumably to substitute the real address when MPI_BOTTOM is passed -- confirm).
// The second overload handles two buffer/type pairs in one call.
2733 void handle_MPI_BOTTOM(void* &buf, MPI_Datatype type) noexcept;
2734 void handle_MPI_BOTTOM(void* &buf1, MPI_Datatype type1, void* &buf2, MPI_Datatype type2) noexcept;
// Error-handler hook.
//  - AMPI_ERROR_CHECKING on:  a real function receiving the calling routine's name and the error code.
//  - AMPI_ERROR_CHECKING off: a macro that evaluates to errcode only; the func argument is
//    discarded by the preprocessor and never evaluated.
2736 #if AMPI_ERROR_CHECKING
2737 int ampiErrhandler(const char* func, int errcode) noexcept;
2738 #else
2739 #define ampiErrhandler(func, errcode) (errcode)
2740 #endif
// Table of routine names for tracing; compiled only when CMK_TRACE_ENABLED is set.
// Being `static` in a header, each translation unit that includes this file gets its own copy.
// NOTE(review): the element pointers themselves are not const (const char *, not const char * const);
// tightening that may affect users of the array outside this chunk -- confirm before changing.
// NOTE(review): presumably an entry's position in this array is significant to the trace
// registration code -- confirm before reordering or inserting names.
2743 #if CMK_TRACE_ENABLED
2745 // List of AMPI functions to trace:
2746 static const char *funclist[] = {"AMPI_Abort", "AMPI_Add_error_class", "AMPI_Add_error_code", "AMPI_Add_error_string",
2747 "AMPI_Address", "AMPI_Allgather", "AMPI_Allgatherv", "AMPI_Allreduce", "AMPI_Alltoall",
2748 "AMPI_Alltoallv", "AMPI_Alltoallw", "AMPI_Attr_delete", "AMPI_Attr_get",
2749 "AMPI_Attr_put", "AMPI_Barrier", "AMPI_Bcast", "AMPI_Bsend", "AMPI_Cancel",
2750 "AMPI_Cart_coords", "AMPI_Cart_create", "AMPI_Cart_get", "AMPI_Cart_map",
2751 "AMPI_Cart_rank", "AMPI_Cart_shift", "AMPI_Cart_sub", "AMPI_Cartdim_get",
2752 "AMPI_Comm_call_errhandler", "AMPI_Comm_compare", "AMPI_Comm_create", "AMPI_Comm_create_group",
2753 "AMPI_Comm_create_errhandler", "AMPI_Comm_create_keyval", "AMPI_Comm_delete_attr",
2754 "AMPI_Comm_dup", "AMPI_Comm_dup_with_info", "AMPI_Comm_free",
2755 "AMPI_Comm_free_errhandler", "AMPI_Comm_free_keyval", "AMPI_Comm_get_attr",
2756 "AMPI_Comm_get_errhandler", "AMPI_Comm_get_info", "AMPI_Comm_get_name",
2757 "AMPI_Comm_group", "AMPI_Comm_rank", "AMPI_Comm_remote_group", "AMPI_Comm_remote_size",
2758 "AMPI_Comm_set_attr", "AMPI_Comm_set_errhandler", "AMPI_Comm_set_info", "AMPI_Comm_set_name",
2759 "AMPI_Comm_size", "AMPI_Comm_split", "AMPI_Comm_split_type", "AMPI_Comm_test_inter",
2760 "AMPI_Dims_create", "AMPI_Dist_graph_create", "AMPI_Dist_graph_create_adjacent",
2761 "AMPI_Dist_graph_neighbors", "AMPI_Dist_graph_neighbors_count",
2762 "AMPI_Errhandler_create", "AMPI_Errhandler_free", "AMPI_Errhandler_get",
2763 "AMPI_Errhandler_set", "AMPI_Error_class", "AMPI_Error_string", "AMPI_Exscan", "AMPI_Finalize",
2764 "AMPI_Finalized", "AMPI_Gather", "AMPI_Gatherv", "AMPI_Get_address", "AMPI_Get_count",
2765 "AMPI_Get_elements", "AMPI_Get_library_version", "AMPI_Get_processor_name", "AMPI_Get_version",
2766 "AMPI_Graph_create", "AMPI_Graph_get", "AMPI_Graph_map", "AMPI_Graph_neighbors",
2767 "AMPI_Graph_neighbors_count", "AMPI_Graphdims_get", "AMPI_Group_compare", "AMPI_Group_difference",
2768 "AMPI_Group_excl", "AMPI_Group_free", "AMPI_Group_incl", "AMPI_Group_intersection",
2769 "AMPI_Group_range_excl", "AMPI_Group_range_incl", "AMPI_Group_rank", "AMPI_Group_size",
2770 "AMPI_Group_translate_ranks", "AMPI_Group_union", "AMPI_Iallgather", "AMPI_Iallgatherv",
2771 "AMPI_Iallreduce", "AMPI_Ialltoall", "AMPI_Ialltoallv", "AMPI_Ialltoallw", "AMPI_Ibarrier",
2772 "AMPI_Ibcast", "AMPI_Iexscan", "AMPI_Igather", "AMPI_Igatherv", "AMPI_Ineighbor_allgather",
2773 "AMPI_Ineighbor_allgatherv", "AMPI_Ineighbor_alltoall", "AMPI_Ineighbor_alltoallv",
2774 "AMPI_Ineighbor_alltoallw", "AMPI_Init", "AMPI_Init_thread", "AMPI_Initialized", "AMPI_Intercomm_create",
2775 "AMPI_Intercomm_merge", "AMPI_Iprobe", "AMPI_Irecv", "AMPI_Ireduce", "AMPI_Ireduce_scatter",
2776 "AMPI_Ireduce_scatter_block", "AMPI_Is_thread_main", "AMPI_Iscan", "AMPI_Iscatter", "AMPI_Iscatterv",
2777 "AMPI_Isend", "AMPI_Issend", "AMPI_Keyval_create", "AMPI_Keyval_free", "AMPI_Neighbor_allgather",
2778 "AMPI_Neighbor_allgatherv", "AMPI_Neighbor_alltoall", "AMPI_Neighbor_alltoallv", "AMPI_Neighbor_alltoallw",
2779 "AMPI_Op_commutative", "AMPI_Op_create", "AMPI_Op_free", "AMPI_Pack", "AMPI_Pack_size",
2780 "AMPI_Pcontrol", "AMPI_Probe", "AMPI_Query_thread", "AMPI_Recv", "AMPI_Recv_init", "AMPI_Reduce",
2781 "AMPI_Reduce_local", "AMPI_Reduce_scatter", "AMPI_Reduce_scatter_block", "AMPI_Request_free",
2782 "AMPI_Request_get_status", "AMPI_Rsend", "AMPI_Scan", "AMPI_Scatter", "AMPI_Scatterv", "AMPI_Send",
2783 "AMPI_Send_init", "AMPI_Sendrecv", "AMPI_Sendrecv_replace", "AMPI_Ssend", "AMPI_Ssend_init",
2784 "AMPI_Start", "AMPI_Startall", "AMPI_Status_set_cancelled", "AMPI_Status_set_elements", "AMPI_Test",
2785 "AMPI_Test_cancelled", "AMPI_Testall", "AMPI_Testany", "AMPI_Testsome", "AMPI_Topo_test",
2786 "AMPI_Type_commit", "AMPI_Type_contiguous", "AMPI_Type_create_hindexed",
2787 "AMPI_Type_create_hindexed_block", "AMPI_Type_create_hvector", "AMPI_Type_create_indexed_block",
2788 "AMPI_Type_create_keyval", "AMPI_Type_create_resized", "AMPI_Type_create_struct",
2789 "AMPI_Type_delete_attr", "AMPI_Type_dup", "AMPI_Type_extent", "AMPI_Type_free",
2790 "AMPI_Type_free_keyval", "AMPI_Type_get_attr", "AMPI_Type_get_contents", "AMPI_Type_get_envelope",
2791 "AMPI_Type_get_extent", "AMPI_Type_get_name", "AMPI_Type_get_true_extent", "AMPI_Type_hindexed",
2792 "AMPI_Type_hvector", "AMPI_Type_indexed", "AMPI_Type_lb", "AMPI_Type_set_attr",
2793 "AMPI_Type_set_name", "AMPI_Type_size", "AMPI_Type_struct", "AMPI_Type_ub", "AMPI_Type_vector",
2794 "AMPI_Unpack", "AMPI_Wait", "AMPI_Waitall", "AMPI_Waitany", "AMPI_Waitsome", "AMPI_Wtick", "AMPI_Wtime",
2795 "AMPI_Accumulate", "AMPI_Compare_and_swap", "AMPI_Fetch_and_op", "AMPI_Get", "AMPI_Get_accumulate",
2796 "AMPI_Info_create", "AMPI_Info_delete", "AMPI_Info_dup", "AMPI_Info_free", "AMPI_Info_get",
2797 "AMPI_Info_get_nkeys", "AMPI_Info_get_nthkey", "AMPI_Info_get_valuelen",
2798 "AMPI_Info_set", "AMPI_Put", "AMPI_Raccumulate", "AMPI_Rget", "AMPI_Rget_accumulate",
2799 "AMPI_Rput", "AMPI_Win_complete", "AMPI_Win_create", "AMPI_Win_create_errhandler",
2800 "AMPI_Win_create_keyval", "AMPI_Win_delete_attr", "AMPI_Win_fence", "AMPI_Win_free",
2801 "AMPI_Win_free_keyval", "AMPI_Win_get_attr", "AMPI_Win_get_errhandler",
2802 "AMPI_Win_get_group", "AMPI_Win_get_info", "AMPI_Win_get_name", "AMPI_Win_lock",
2803 "AMPI_Win_post", "AMPI_Win_set_attr", "AMPI_Win_set_errhandler", "AMPI_Win_set_info",
2804 "AMPI_Win_set_name", "AMPI_Win_start", "AMPI_Win_test", "AMPI_Win_unlock",
2805 "AMPI_Win_wait", "AMPI_Exit" /*AMPI extensions:*/, "AMPI_Migrate",
2806 "AMPI_Load_start_measure", "AMPI_Load_stop_measure",
2807 "AMPI_Load_set_value", "AMPI_Migrate_to_pe", "AMPI_Set_migratable",
2808 "AMPI_Register_pup", "AMPI_Get_pup_data", "AMPI_Register_main",
2809 "AMPI_Register_about_to_migrate", "AMPI_Register_just_migrated",
2810 "AMPI_Iget", "AMPI_Iget_wait", "AMPI_Iget_free", "AMPI_Iget_data",
2811 "AMPI_Type_is_contiguous", "AMPI_Yield", "AMPI_Suspend",
2812 "AMPI_Resume", "AMPI_Print", "AMPI_Alltoall_medium",
2813 "AMPI_Alltoall_long", "AMPI_System"};
2815 // not traced: AMPI_Trace_begin, AMPI_Trace_end
2817 #endif // CMK_TRACE_ENABLED
//Use this to mark the start of AMPI interface routines that can only be called on AMPI threads.
// Usage: write it as a statement and supply the terminating semicolon yourself, e.g.
//   AMPI_API("AMPI_Send");
// With CMK_ERROR_CHECKING enabled the macro first aborts (via CkAbort) when isAmpiThread()
// is false, then expands to the same TCHARM_API_TRACE(routineName, "ampi") as the
// non-checking variant.
// Neither variant ends in its own ';'. Previously only the checking variant did, so
// `AMPI_API("x");` left a stray empty statement in error-checking builds and the two
// configurations expanded differently at the call site.
// The expansion is deliberately not wrapped in a block: whatever TCHARM_API_TRACE expands
// to is emitted directly into the caller's scope in both variants.
#if CMK_ERROR_CHECKING
#define AMPI_API(routineName) \
  if (!isAmpiThread()) { CkAbort("AMPI> cannot call MPI routines from non-AMPI threads!"); } \
  TCHARM_API_TRACE(routineName, "ampi")
#else
#define AMPI_API(routineName) TCHARM_API_TRACE(routineName, "ampi")
#endif
// Unlike AMPI_API, this variant performs no isAmpiThread() check in any build configuration.
2828 //Use this for MPI_Init and routines that can be called before AMPI threads have been initialized:
2829 #define AMPI_API_INIT(routineName) TCHARM_API_TRACE(routineName, "ampi")
2831 #endif // _AMPIIMPL_H