3 #include <graybat/communicationPolicy/Traits.hpp>
7 namespace communicationPolicy {
/// Base class template for communication policies. It is parameterized on the
/// concrete policy type: the member definitions of this class dispatch to that
/// type through static_cast<CommunicationPolicy*>(this) (CRTP-style).
3 template <
typename T_CommunicationPolicy>
// The concrete policy type this base dispatches to.
22 using CommunicationPolicy = T_CommunicationPolicy;
// Address, tag, context and event types of the policy, taken from the
// graybat::communicationPolicy alias templates (presumably declared in the
// included graybat/communicationPolicy/Traits.hpp).
23 using VAddr =
typename graybat::communicationPolicy::VAddr<CommunicationPolicy>;
24 using Tag =
typename graybat::communicationPolicy::Tag<CommunicationPolicy>;
25 using Context =
typename graybat::communicationPolicy::Context<CommunicationPolicy>;
26 using Event =
typename graybat::communicationPolicy::Event<CommunicationPolicy>;
/// Point-to-point send of sendData to peer destVAddr with the given tag within
/// context. Returns nothing, whereas asyncSend returns an Event.
/// Deleted in this base. The member definitions of this class (e.g.
/// synchronize) call send through static_cast<CommunicationPolicy*>(this), so
/// the concrete policy presumably provides the real overload — confirm against
/// the policy classes.
51 template <
typename T_Send>
52 void send(
const VAddr destVAddr,
const Tag tag,
const Context context,
const T_Send& sendData) =
delete;
/// Event-returning variant of send: posts sendData to peer destVAddr with the
/// given tag within context and returns an Event for the transfer.
/// NOTE(review): the completion/wait semantics of Event are policy-defined and
/// not visible here.
/// Deleted in this base. The collectives below call it through
/// static_cast<CommunicationPolicy*>(this).
54 template <
typename T_Send>
55 Event asyncSend(
const VAddr destVAddr,
const Tag tag,
const Context context,
const T_Send& sendData) =
delete;
/// Receives into recvData a message from peer srcVAddr with the given tag
/// within context.
/// Deleted in this base. The collectives below call it through
/// static_cast<CommunicationPolicy*>(this), always with a buffer they have
/// already sized. NOTE(review): this suggests the buffer size determines how
/// much is received — confirm against the policy.
57 template <
typename T_Recv>
58 void recv(
const VAddr srcVAddr,
const Tag tag,
const Context context, T_Recv& recvData) =
delete;
/// recv overload that takes only a context and a buffer and returns an Event.
/// NOTE(review): which source peer and tag it matches is not visible here
/// (presumably any peer) — confirm against the policy implementation.
/// Deleted in this base.
60 template <
typename T_Recv>
61 Event recv(
const Context context, T_Recv& recvData) =
delete;
/// Event-returning variant of the addressed recv: receives into recvData a
/// message from peer srcVAddr with the given tag within context.
/// Deleted in this base; the concrete policy presumably provides it.
63 template <
typename T_Recv>
64 Event asyncRecv(
const VAddr srcVAddr,
const Tag tag,
const Context context, T_Recv& recvData) =
delete;
/// Collects sendData from all peers of the context and transmits it as a list
/// to the peer with rootVAddr.
/// The generic definition stores peer vAddr's data on the root at offset
/// vAddr * sendData.size() of recvData. It does not resize recvData.
/// NOTE(review): the root presumably has to pre-size recvData to
/// context.size() * sendData.size().
88 template <
typename T_Send,
typename T_Recv>
89 void gather(
const VAddr rootVAddr,
const Context context,
const T_Send& sendData, T_Recv& recvData);
/// Like gather, but the size of sendData may differ between peers.
/// recvCount is resized to context.size() and filled with each peer's element
/// count. recvData is resized to the sum of those counts, and on the root it
/// receives the peers' data back to back in vAddr order.
105 template <
typename T_Send,
typename T_Recv>
106 void gatherVar(
const VAddr rootVAddr,
const Context context,
const T_Send& sendData, T_Recv& recvData, std::vector<unsigned>& recvCount);
/// Collects sendData from all members of the context and transmits it as a
/// list to every peer in the context. Peer vAddr's data lands at offset
/// vAddr * sendData.size() of recvData.
/// NOTE(review): the generic definition does not resize recvData, so callers
/// presumably pre-size it. Unlike its siblings, this declaration takes
/// context by non-const value.
117 template <
typename T_Send,
typename T_Recv>
118 void allGather(Context context,
const T_Send& sendData, T_Recv& recvData);
/// Collects sendData, whose size may differ between peers, from all peers of
/// the context. The concatenated data is received by every peer.
/// recvCount is resized to context.size() and filled with each peer's element
/// count. recvData is resized to the sum of those counts.
130 template <
typename T_Send,
typename T_Recv>
131 void allGatherVar(
const Context context,
const T_Send& sendData, T_Recv& recvData, std::vector<unsigned>& recvCount);
/// Distributes sendData from peer rootVAddr to all peers in context; every
/// peer receives a different slice.
/// The generic definition cuts sendData into slices of the root's
/// recvData.size() elements and sends slice vAddr to peer vAddr.
/// NOTE(review): all peers therefore presumably need equally sized recvData.
145 template <
typename T_Send,
typename T_Recv>
146 void scatter(
const VAddr rootVAddr,
const Context context,
const T_Send& sendData, T_Recv& recvData);
/// Each peer splits its sendData into slices of
/// recvData.size() / context.size() elements and sends slice vAddr to peer
/// vAddr. Received slices are stored in recvData in vAddr order.
160 template <
typename T_Send,
typename T_Recv>
161 void allScatter(
const Context context,
const T_Send& sendData, T_Recv& recvData);
/// Performs an element-wise reduction with the binary operator op over the
/// sendData of all peers within the context. The result is stored in recvData
/// on rootVAddr.
/// The generic definition folds contributions in vAddr order as
/// recvData[i] = op(recvData[i], next[i]).
179 template <
typename T_Send,
typename T_Recv,
typename T_Op>
180 void reduce(
const VAddr rootVAddr,
const Context context,
const T_Op op,
const T_Send& sendData, T_Recv& recvData);
/// Like reduce, but every peer of the context ends up with the element-wise
/// reduction (via op, folded in vAddr order) in its recvData.
194 template <
typename T_Send,
typename T_Recv,
typename T_Op>
195 void allReduce(
const Context context, T_Op op,
const T_Send& sendData, T_Recv& recvData);
/// Sends data from peer rootVAddr to all peers in context. Every peer receives
/// the same data into `data`, which is both the send buffer on the root and
/// the receive buffer everywhere.
208 template <
typename T_SendRecv>
209 void broadcast(
const VAddr rootVAddr,
const Context context, T_SendRecv& data);
// gather: every peer, the root included, posts an asyncSend of its sendData
// to rootVAddr. The root then receives sendData.size() elements from each
// vAddr in order and copies them to recvData at offset vAddr * sendData.size().
// NOTE(review): recvData is not resized here, so on the root it presumably
// must already hold context.size() * sendData.size() elements — confirm.
// NOTE(review): no wait on `e` is visible in this listing — confirm whether
// one exists on a line not shown here.
224 template <
typename T_CommunicationPolicy>
225 template <
typename T_Send,
typename T_Recv>
227 using RecvValueType =
typename T_Recv::value_type;
228 using CommunicationPolicy = T_CommunicationPolicy;
229 using Event = Base<CommunicationPolicy>::Event;
// Presumably asynchronous so the root is not blocked sending to itself before
// its receive loop (which starts at vAddr 0) collects this contribution.
231 Event e =
static_cast<CommunicationPolicy*
>(
this)->asyncSend(rootVAddr, 0, context, sendData);
233 if(rootVAddr == context.getVAddr()){
234 for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
// Every peer's block is assumed to be as long as the root's own sendData.
235 size_t recvOffset = vAddr * sendData.size();
236 std::vector<RecvValueType> tmpData(sendData.size());
237 static_cast<CommunicationPolicy*
>(
this)->recv(vAddr, 0, context, tmpData);
238 std::copy(tmpData.begin(), tmpData.end(), recvData.begin() + recvOffset);
// gatherVar: the peers first exchange their element counts through allGather,
// which fills recvCount on every peer. recvData is then resized to the total,
// every peer posts an asyncSend of sendData to the root, and the root receives
// recvCount[vAddr] elements from each vAddr, appending them to recvData back
// to back.
// NOTE(review): recvData is resized on every peer, not only on the root.
246 template <
typename T_CommunicationPolicy>
247 template <
typename T_Send,
typename T_Recv>
249 using RecvValueType =
typename T_Recv::value_type;
250 using CommunicationPolicy = T_CommunicationPolicy;
251 using Event = Base<CommunicationPolicy>::Event;
// One-element buffer holding this peer's element count for the allGather.
253 std::array<unsigned, 1> nElements{{(unsigned)sendData.size()}};
254 recvCount.resize(context.size());
255 static_cast<CommunicationPolicy*
>(
this)->allGather(context, nElements, recvCount);
256 recvData.resize(std::accumulate(recvCount.begin(), recvCount.end(), 0U));
258 Event e =
static_cast<CommunicationPolicy*
>(
this)->asyncSend(rootVAddr, 0, context, sendData);
260 if(rootVAddr == context.getVAddr()){
// Running write position in recvData; the blocks are variable-sized.
261 size_t recvOffset = 0;
262 for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
263 std::vector<RecvValueType> tmpData(recvCount.at(vAddr));
264 static_cast<CommunicationPolicy*
>(
this)->recv(vAddr, 0, context, tmpData);
265 std::copy(tmpData.begin(), tmpData.end(), recvData.begin() + recvOffset);
266 recvOffset += recvCount.at(vAddr);
// allGather: every peer posts an asyncSend of its sendData to every vAddr of
// the context, itself included. It then receives sendData.size() elements
// from each vAddr and copies them to recvData at offset vAddr * sendData.size().
// NOTE(review): recvData is not resized here, so callers presumably pre-size
// it to context.size() * sendData.size().
276 template <
typename T_CommunicationPolicy>
277 template <
typename T_Send,
typename T_Recv>
279 using RecvValueType =
typename T_Recv::value_type;
280 using CommunicationPolicy = T_CommunicationPolicy;
281 using Event = Base<CommunicationPolicy>::Event;
// NOTE(review): each Event is declared inside the loop body and so goes out
// of scope every iteration with no visible wait — confirm what the Event
// type's destructor does.
283 for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
284 Event e =
static_cast<CommunicationPolicy*
>(
this)->asyncSend(vAddr, 0, context, sendData);
288 for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
289 size_t recvOffset = vAddr * sendData.size();
290 std::vector<RecvValueType> tmpData(sendData.size());
291 static_cast<CommunicationPolicy*
>(
this)->recv(vAddr, 0, context, tmpData);
292 std::copy(tmpData.begin(), tmpData.end(), recvData.begin() + recvOffset);
// allGatherVar: exchanges the per-peer element counts through allGather, which
// fills recvCount, and resizes recvData to their sum. Every peer then sends its
// sendData to every vAddr, itself included, and receives recvCount[vAddr]
// elements from each vAddr, appending them to recvData back to back.
298 template <
typename T_CommunicationPolicy>
299 template <
typename T_Send,
typename T_Recv>
301 using RecvValueType =
typename T_Recv::value_type;
302 using CommunicationPolicy = T_CommunicationPolicy;
303 using Event = Base<CommunicationPolicy>::Event;
// One-element buffer holding this peer's element count for the allGather.
305 std::array<unsigned, 1> nElements{{(unsigned)sendData.size()}};
306 recvCount.resize(context.size());
307 static_cast<CommunicationPolicy*
>(
this)->allGather(context, nElements, recvCount);
308 recvData.resize(std::accumulate(recvCount.begin(), recvCount.end(), 0U));
// NOTE(review): as in allGather, each Event is scoped to one loop iteration
// with no visible wait.
310 for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
311 Event e =
static_cast<CommunicationPolicy*
>(
this)->asyncSend(vAddr, 0, context, sendData);
// Running write position in recvData; the blocks are variable-sized.
314 size_t recvOffset = 0;
315 for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
316 std::vector<RecvValueType> tmpData(recvCount.at(vAddr));
317 static_cast<CommunicationPolicy*
>(
this)->recv(vAddr, 0, context, tmpData);
318 std::copy(tmpData.begin(), tmpData.end(), recvData.begin() + recvOffset);
319 recvOffset += recvCount.at(vAddr);
// scatter: the root cuts sendData into slices of its own recvData.size()
// elements and posts an asyncSend of slice vAddr to each peer vAddr, itself
// included. Presumably after the root branch closes (the intervening lines are
// not shown), every peer receives its slice from rootVAddr into recvData.
// NOTE(review): SendValueType aliases T_Recv::value_type, not
// T_Send::value_type, so slices are converted to the receiver's element type
// before sending. Confirm this is intended; the name is misleading.
326 template <
typename T_CommunicationPolicy>
327 template <
typename T_Send,
typename T_Recv>
329 using SendValueType =
typename T_Recv::value_type;
330 using CommunicationPolicy = T_CommunicationPolicy;
331 using Event = Base<CommunicationPolicy>::Event;
333 if(rootVAddr == context.getVAddr()){
334 for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
335 size_t sendOffset = vAddr * recvData.size();
336 std::vector<SendValueType> tmpData(sendData.begin() + sendOffset,
337 sendData.begin() + sendOffset + recvData.size());
// NOTE(review): tmpData and e are destroyed at the end of each iteration.
// Confirm that asyncSend copies its buffer or that the transfer completes
// before tmpData goes out of scope.
338 Event e =
static_cast<CommunicationPolicy*
>(
this)->asyncSend(vAddr, 0, context, tmpData);
344 static_cast<CommunicationPolicy*
>(
this)->recv(rootVAddr, 0, context, recvData);
// allScatter: every peer cuts its sendData into context.size() slices of
// nElementsPerPeer elements and sends slice vAddr to peer vAddr. It then
// receives one slice from every vAddr into recvData at offset
// vAddr * nElementsPerPeer.
// NOTE(review): SendValueType aliases T_Recv::value_type (see scatter). It is
// used both for the outgoing slices and for the receive buffer.
348 template <
typename T_CommunicationPolicy>
349 template <
typename T_Send,
typename T_Recv>
351 using SendValueType =
typename T_Recv::value_type;
352 using CommunicationPolicy = T_CommunicationPolicy;
353 using Event = Base<CommunicationPolicy>::Event;
// Integer division: any remainder of recvData.size() / context.size() is
// neither sent nor received.
355 size_t nElementsPerPeer =
static_cast<size_t>(recvData.size() / context.size());
357 for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
358 size_t sendOffset = vAddr * nElementsPerPeer;
359 std::vector<SendValueType> tmpData(sendData.begin() + sendOffset,
360 sendData.begin() + sendOffset + nElementsPerPeer);
// NOTE(review): the loop-local tmpData and e die every iteration. Confirm
// that the buffer is not referenced after this point by the policy.
361 Event e =
static_cast<CommunicationPolicy*
>(
this)->asyncSend(vAddr, 0, context, tmpData);
365 for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
366 size_t recvOffset = vAddr * nElementsPerPeer;
367 std::vector<SendValueType> tmpData(nElementsPerPeer);
368 static_cast<CommunicationPolicy*
>(
this)->recv(vAddr, 0, context, tmpData);
369 std::copy(tmpData.begin(), tmpData.end(), recvData.begin() + recvOffset);
// reduce: every peer, the root included, posts an asyncSend of sendData to
// rootVAddr. The root receives vAddr 0's contribution directly into recvData,
// then folds in vAddrs 1..size()-1 element-wise with
// recvData[i] = op(recvData[i], tmpData[i]).
// NOTE(review): recvData is received into directly, so it presumably must
// already be sized like sendData. The code also assumes the context's vAddrs
// are 0..size()-1.
376 template <
typename T_CommunicationPolicy>
377 template <
typename T_Send,
typename T_Recv,
typename T_Op>
379 using RecvValueType =
typename T_Recv::value_type;
380 using CommunicationPolicy = T_CommunicationPolicy;
381 using Event = Base<CommunicationPolicy>::Event;
383 Event e =
static_cast<CommunicationPolicy*
>(
this)->asyncSend(rootVAddr, 0, context, sendData);
385 if(rootVAddr == context.getVAddr()){
// Seed the accumulator with vAddr 0's data.
386 static_cast<CommunicationPolicy*
>(
this)->recv(0, 0, context, recvData);
388 for(VAddr vAddr = 1; vAddr < context.size(); vAddr++){
389 std::vector<RecvValueType> tmpData(recvData.size());
390 static_cast<CommunicationPolicy*
>(
this)->recv(vAddr, 0, context, tmpData);
392 for(
size_t i = 0; i < recvData.size(); ++i){
393 recvData[i] = op(recvData[i], tmpData[i]);
// allReduce: every peer sends its sendData to every peer of the context, then
// receives all contributions and folds them element-wise with op. The fold
// starts from vAddr 0's data, so every peer ends with the same reduction in
// recvData.
// NOTE(review): recvData is received into directly, so it presumably must
// already be sized like sendData — confirm.
403 template <
typename T_CommunicationPolicy>
404 template <
typename T_Send,
typename T_Recv,
typename T_Op>
406 using RecvValueType =
typename T_Recv::value_type;
407 using CommunicationPolicy = T_CommunicationPolicy;
408 using Event = Base<CommunicationPolicy>::Event;
// Send to every vAddr, including 0 and this peer itself. The receive phase
// below reads from vAddr 0 and then from 1..size()-1, so each peer needs one
// message from every peer. Starting this loop at 1 would leave vAddr 0
// without any incoming message, and its recv calls below would never complete.
410 for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
411 Event e =
static_cast<CommunicationPolicy*
>(
this)->asyncSend(vAddr, 0, context, sendData);
// Seed the result with vAddr 0's contribution...
414 static_cast<CommunicationPolicy*
>(
this)->recv(0, 0, context, recvData);
// ...then fold in the remaining peers element-wise.
416 for(VAddr vAddr = 1; vAddr < context.size(); vAddr++){
417 std::vector<RecvValueType> tmpData(recvData.size());
418 static_cast<CommunicationPolicy*
>(
this)->recv(vAddr, 0, context, tmpData);
420 for(
size_t i = 0; i < recvData.size(); ++i){
421 recvData[i] = op(recvData[i], tmpData[i]);
// broadcast: the root posts an asyncSend of `data` to every vAddr of the
// context, itself included, and discards the returned Events. Presumably after
// the root branch closes (the intervening lines are not shown), every peer
// receives from rootVAddr into `data`.
// NOTE(review): on the root, `data` is received into while its own asyncSends
// of the same buffer may still be in flight, and no Event is kept to wait on.
// Confirm this is safe for the policies in use.
429 template <
typename T_CommunicationPolicy>
430 template <
typename T_SendRecv>
432 using CommunicationPolicy = T_CommunicationPolicy;
434 if(rootVAddr == context.getVAddr()){
435 for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
436 static_cast<CommunicationPolicy*
>(
this)->asyncSend(vAddr, 0, context, data);
442 static_cast<CommunicationPolicy*
>(
this)->recv(rootVAddr, 0, context, data);
// synchronize: barrier over `context` built from zero-length messages
// (std::array<char, 0>). vAddr 0 coordinates: it collects one message from
// every other peer and then releases each of them. Every other peer sends to
// vAddr 0 and waits for the reply.
// The coordinator skips itself in both loops. Nothing sends from vAddr 0 to
// vAddr 0 before these receives, so a recv(0, ...) on the coordinator could
// never complete, and an unmatched self-send would leave a stray tag-0 message
// for a later recv to consume.
447 template <
typename T_CommunicationPolicy>
449 std::array<char, 0> null;
451 if(context.getVAddr() == 0){
// Wait for an arrival message from every other peer.
452 for(VAddr vAddr = 1; vAddr < context.size(); vAddr++){
453 static_cast<CommunicationPolicy*
>(
this)->recv(vAddr, 0, context, null);
// All peers have arrived: release them.
455 for(VAddr vAddr = 1; vAddr < context.size(); vAddr++){
456 static_cast<CommunicationPolicy*
>(
this)->send(vAddr, 0, context, null);
// Non-coordinator peers announce their arrival, then wait for the release.
// NOTE(review): the branch keyword for this part sits on a line not shown in
// this listing.
461 static_cast<CommunicationPolicy*
>(
this)->send(0, 0, context, null);
462 static_cast<CommunicationPolicy*
>(
this)->recv(0, 0, context, null);
void allScatter(const Context context, const T_Send &sendData, T_Recv &recvData)
Distributes the sendData of all peers in the context to all peers in the context. Every peer will receive ...
Definition: Base.hpp:350
void allGatherVar(const Context context, const T_Send &sendData, T_Recv &recvData, std::vector< unsigned > &recvCount)
Collects sendData from all peers of the context; the size of sendData may differ between peers. The data is received by every peer in the context.
Definition: Base.hpp:300
void gather(const VAddr rootVAddr, const Context context, const T_Send &sendData, T_Recv &recvData)
Collects sendData from all peers of the context and transmits it as a list to the peer with rootVAddr...
Definition: Base.hpp:226
void gatherVar(const VAddr rootVAddr, const Context context, const T_Send &sendData, T_Recv &recvData, std::vector< unsigned > &recvCount)
Collects sendData from all members of the context with varying size and transmits it as a list to pee...
Definition: Base.hpp:248
void synchronize(const Context context)
Synchronizes all peers within context to the same point in the program execution (barrier)...
Definition: Base.hpp:448
void scatter(const VAddr rootVAddr, const Context context, const T_Send &sendData, T_Recv &recvData)
Distributes sendData from peer rootVAddr to all peers in context. Every peer will receive different d...
Definition: Base.hpp:328
void allGather(Context context, const T_Send &sendData, T_Recv &recvData)
Collects sendData from all members of the context and transmits it as a list to every peer in the con...
Definition: Base.hpp:278
void reduce(const VAddr rootVAddr, const Context context, const T_Op op, const T_Send &sendData, T_Recv &recvData)
Performs a reduction with a binary operator op on all sendData elements from all peers within the co...
Definition: Base.hpp:378
void broadcast(const VAddr rootVAddr, const Context context, T_SendRecv &data)
Sends sendData from peer rootVAddr to all peers in context. Every peer will receive the same data...
Definition: Base.hpp:431
void allReduce(const Context context, T_Op op, const T_Send &sendData, T_Recv &recvData)
Performs a reduction with a binary operator op on all sendData elements from all peers within the co...
Definition: Base.hpp:405