Graybat  1.1
Graph Approach for Highly Generic Communication Schemes Based on Adaptive Topologies
Base.hpp
1 #pragma once
2 
3 #include <graybat/communicationPolicy/Traits.hpp>
4 
5 namespace graybat {
6 
7  namespace communicationPolicy {
8 
19  template <typename T_CommunicationPolicy>
20  struct Base {
21 
22  using CommunicationPolicy = T_CommunicationPolicy;
23  using VAddr = typename graybat::communicationPolicy::VAddr<CommunicationPolicy>;
24  using Tag = typename graybat::communicationPolicy::Tag<CommunicationPolicy>;
25  using Context = typename graybat::communicationPolicy::Context<CommunicationPolicy>;
26  using Event = typename graybat::communicationPolicy::Event<CommunicationPolicy>;
27 
28  // TODO
29  // ====
30  //
31  // Is there a way to prevent a lot of functions for
32  // slightly different functionality regarding the
33  // following options:
34  //
35  // * Blocking / Non Blocking
36  // * Var / Non Var
37  // * All / Single Receive
38  //
39 
40  /***********************************************************************
41  * Interface
42  ***********************************************************************/
43 
44  /***********************************************************************/
51  template <typename T_Send>
52  void send(const VAddr destVAddr, const Tag tag, const Context context, const T_Send& sendData) = delete;
53 
54  template <typename T_Send>
55  Event asyncSend(const VAddr destVAddr, const Tag tag, const Context context, const T_Send& sendData) = delete;
56 
57  template <typename T_Recv>
58  void recv(const VAddr srcVAddr, const Tag tag, const Context context, T_Recv& recvData) = delete;
59 
60  template <typename T_Recv>
61  Event recv(const Context context, T_Recv& recvData) = delete;
62 
63  template <typename T_Recv>
64  Event asyncRecv(const VAddr srcVAddr, const Tag tag, const Context context, T_Recv& recvData) = delete;
67  /************************************************************************/
88  template <typename T_Send, typename T_Recv>
89  void gather(const VAddr rootVAddr, const Context context, const T_Send& sendData, T_Recv& recvData);
90 
105  template <typename T_Send, typename T_Recv>
106  void gatherVar(const VAddr rootVAddr, const Context context, const T_Send& sendData, T_Recv& recvData, std::vector<unsigned>& recvCount);
107 
117  template <typename T_Send, typename T_Recv>
118  void allGather(Context context, const T_Send& sendData, T_Recv& recvData);
119 
130  template <typename T_Send, typename T_Recv>
131  void allGatherVar(const Context context, const T_Send& sendData, T_Recv& recvData, std::vector<unsigned>& recvCount);
132 
145  template <typename T_Send, typename T_Recv>
146  void scatter(const VAddr rootVAddr, const Context context, const T_Send& sendData, T_Recv& recvData);
147 
160  template <typename T_Send, typename T_Recv>
161  void allScatter(const Context context, const T_Send& sendData, T_Recv& recvData);
162 
179  template <typename T_Send, typename T_Recv, typename T_Op>
180  void reduce(const VAddr rootVAddr, const Context context, const T_Op op, const T_Send& sendData, T_Recv& recvData);
181 
194  template <typename T_Send, typename T_Recv, typename T_Op>
195  void allReduce(const Context context, T_Op op, const T_Send& sendData, T_Recv& recvData);
196 
208  template <typename T_SendRecv>
209  void broadcast(const VAddr rootVAddr, const Context context, T_SendRecv& data);
210 
216  void synchronize(const Context context);
219  };
220 
221  /***********************************************************************
222  * Implementation
223  ***********************************************************************/
224  template <typename T_CommunicationPolicy>
225  template <typename T_Send, typename T_Recv>
226  void Base<T_CommunicationPolicy>::gather(const VAddr rootVAddr, const Context context, const T_Send& sendData, T_Recv& recvData){
227  using RecvValueType = typename T_Recv::value_type;
228  using CommunicationPolicy = T_CommunicationPolicy;
229  using Event = Base<CommunicationPolicy>::Event;
230 
231  Event e = static_cast<CommunicationPolicy*>(this)->asyncSend(rootVAddr, 0, context, sendData);
232 
233  if(rootVAddr == context.getVAddr()){
234  for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
235  size_t recvOffset = vAddr * sendData.size();
236  std::vector<RecvValueType> tmpData(sendData.size());
237  static_cast<CommunicationPolicy*>(this)->recv(vAddr, 0, context, tmpData);
238  std::copy(tmpData.begin(), tmpData.end(), recvData.begin() + recvOffset);
239 
240  }
241 
242  }
243 
244  }
245 
246  template <typename T_CommunicationPolicy>
247  template <typename T_Send, typename T_Recv>
248  void Base<T_CommunicationPolicy>::gatherVar(const VAddr rootVAddr, const Context context, const T_Send& sendData, T_Recv& recvData, std::vector<unsigned>& recvCount){
249  using RecvValueType = typename T_Recv::value_type;
250  using CommunicationPolicy = T_CommunicationPolicy;
251  using Event = Base<CommunicationPolicy>::Event;
252 
253  std::array<unsigned, 1> nElements{{(unsigned)sendData.size()}};
254  recvCount.resize(context.size());
255  static_cast<CommunicationPolicy*>(this)->allGather(context, nElements, recvCount);
256  recvData.resize(std::accumulate(recvCount.begin(), recvCount.end(), 0U));
257 
258  Event e = static_cast<CommunicationPolicy*>(this)->asyncSend(rootVAddr, 0, context, sendData);
259 
260  if(rootVAddr == context.getVAddr()){
261  size_t recvOffset = 0;
262  for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
263  std::vector<RecvValueType> tmpData(recvCount.at(vAddr));
264  static_cast<CommunicationPolicy*>(this)->recv(vAddr, 0, context, tmpData);
265  std::copy(tmpData.begin(), tmpData.end(), recvData.begin() + recvOffset);
266  recvOffset += recvCount.at(vAddr);
267 
268  }
269 
270  }
271 
272  }
273 
274 
275 
276  template <typename T_CommunicationPolicy>
277  template <typename T_Send, typename T_Recv>
278  void Base<T_CommunicationPolicy>::allGather(Context context, const T_Send& sendData, T_Recv& recvData){
279  using RecvValueType = typename T_Recv::value_type;
280  using CommunicationPolicy = T_CommunicationPolicy;
281  using Event = Base<CommunicationPolicy>::Event;
282 
283  for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
284  Event e = static_cast<CommunicationPolicy*>(this)->asyncSend(vAddr, 0, context, sendData);
285 
286  }
287 
288  for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
289  size_t recvOffset = vAddr * sendData.size();
290  std::vector<RecvValueType> tmpData(sendData.size());
291  static_cast<CommunicationPolicy*>(this)->recv(vAddr, 0, context, tmpData);
292  std::copy(tmpData.begin(), tmpData.end(), recvData.begin() + recvOffset);
293 
294  }
295 
296  }
297 
298  template <typename T_CommunicationPolicy>
299  template <typename T_Send, typename T_Recv>
300  void Base<T_CommunicationPolicy>::allGatherVar(const Context context, const T_Send& sendData, T_Recv& recvData, std::vector<unsigned>& recvCount){
301  using RecvValueType = typename T_Recv::value_type;
302  using CommunicationPolicy = T_CommunicationPolicy;
303  using Event = Base<CommunicationPolicy>::Event;
304 
305  std::array<unsigned, 1> nElements{{(unsigned)sendData.size()}};
306  recvCount.resize(context.size());
307  static_cast<CommunicationPolicy*>(this)->allGather(context, nElements, recvCount);
308  recvData.resize(std::accumulate(recvCount.begin(), recvCount.end(), 0U));
309 
310  for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
311  Event e = static_cast<CommunicationPolicy*>(this)->asyncSend(vAddr, 0, context, sendData);
312  }
313 
314  size_t recvOffset = 0;
315  for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
316  std::vector<RecvValueType> tmpData(recvCount.at(vAddr));
317  static_cast<CommunicationPolicy*>(this)->recv(vAddr, 0, context, tmpData);
318  std::copy(tmpData.begin(), tmpData.end(), recvData.begin() + recvOffset);
319  recvOffset += recvCount.at(vAddr);
320 
321  }
322 
323 
324  }
325 
326  template <typename T_CommunicationPolicy>
327  template <typename T_Send, typename T_Recv>
328  void Base<T_CommunicationPolicy>::scatter(const VAddr rootVAddr, const Context context, const T_Send& sendData, T_Recv& recvData){
329  using SendValueType = typename T_Recv::value_type;
330  using CommunicationPolicy = T_CommunicationPolicy;
331  using Event = Base<CommunicationPolicy>::Event;
332 
333  if(rootVAddr == context.getVAddr()){
334  for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
335  size_t sendOffset = vAddr * recvData.size();
336  std::vector<SendValueType> tmpData(sendData.begin() + sendOffset,
337  sendData.begin() + sendOffset + recvData.size());
338  Event e = static_cast<CommunicationPolicy*>(this)->asyncSend(vAddr, 0, context, tmpData);
339 
340  }
341 
342  }
343 
344  static_cast<CommunicationPolicy*>(this)->recv(rootVAddr, 0, context, recvData);
345 
346  }
347 
348  template <typename T_CommunicationPolicy>
349  template <typename T_Send, typename T_Recv>
350  void Base<T_CommunicationPolicy>::allScatter(const Context context, const T_Send& sendData, T_Recv& recvData){
351  using SendValueType = typename T_Recv::value_type;
352  using CommunicationPolicy = T_CommunicationPolicy;
353  using Event = Base<CommunicationPolicy>::Event;
354 
355  size_t nElementsPerPeer = static_cast<size_t>(recvData.size() / context.size());
356 
357  for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
358  size_t sendOffset = vAddr * nElementsPerPeer;
359  std::vector<SendValueType> tmpData(sendData.begin() + sendOffset,
360  sendData.begin() + sendOffset + nElementsPerPeer);
361  Event e = static_cast<CommunicationPolicy*>(this)->asyncSend(vAddr, 0, context, tmpData);
362 
363  }
364 
365  for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
366  size_t recvOffset = vAddr * nElementsPerPeer;
367  std::vector<SendValueType> tmpData(nElementsPerPeer);
368  static_cast<CommunicationPolicy*>(this)->recv(vAddr, 0, context, tmpData);
369  std::copy(tmpData.begin(), tmpData.end(), recvData.begin() + recvOffset);
370 
371  }
372 
373  }
374 
375 
376  template <typename T_CommunicationPolicy>
377  template <typename T_Send, typename T_Recv, typename T_Op>
378  void Base<T_CommunicationPolicy>::reduce(const VAddr rootVAddr, const Context context, const T_Op op, const T_Send& sendData, T_Recv& recvData){
379  using RecvValueType = typename T_Recv::value_type;
380  using CommunicationPolicy = T_CommunicationPolicy;
381  using Event = Base<CommunicationPolicy>::Event;
382 
383  Event e = static_cast<CommunicationPolicy*>(this)->asyncSend(rootVAddr, 0, context, sendData);
384 
385  if(rootVAddr == context.getVAddr()){
386  static_cast<CommunicationPolicy*>(this)->recv(0, 0, context, recvData);
387 
388  for(VAddr vAddr = 1; vAddr < context.size(); vAddr++){
389  std::vector<RecvValueType> tmpData(recvData.size());
390  static_cast<CommunicationPolicy*>(this)->recv(vAddr, 0, context, tmpData);
391 
392  for(size_t i = 0; i < recvData.size(); ++i){
393  recvData[i] = op(recvData[i], tmpData[i]);
394  }
395 
396  }
397 
398  }
399 
400  }
401 
402 
403  template <typename T_CommunicationPolicy>
404  template <typename T_Send, typename T_Recv, typename T_Op>
405  void Base<T_CommunicationPolicy>::allReduce(const Context context, T_Op op, const T_Send& sendData, T_Recv& recvData){
406  using RecvValueType = typename T_Recv::value_type;
407  using CommunicationPolicy = T_CommunicationPolicy;
408  using Event = Base<CommunicationPolicy>::Event;
409 
410  for(VAddr vAddr = 1; vAddr < context.size(); vAddr++){
411  Event e = static_cast<CommunicationPolicy*>(this)->asyncSend(vAddr, 0, context, sendData);
412  }
413 
414  static_cast<CommunicationPolicy*>(this)->recv(0, 0, context, recvData);
415 
416  for(VAddr vAddr = 1; vAddr < context.size(); vAddr++){
417  std::vector<RecvValueType> tmpData(recvData.size());
418  static_cast<CommunicationPolicy*>(this)->recv(vAddr, 0, context, tmpData);
419 
420  for(size_t i = 0; i < recvData.size(); ++i){
421  recvData[i] = op(recvData[i], tmpData[i]);
422  }
423 
424  }
425 
426  }
427 
428 
429  template <typename T_CommunicationPolicy>
430  template <typename T_SendRecv>
431  void Base<T_CommunicationPolicy>::broadcast(const VAddr rootVAddr, const Context context, T_SendRecv& data){
432  using CommunicationPolicy = T_CommunicationPolicy;
433 
434  if(rootVAddr == context.getVAddr()){
435  for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
436  static_cast<CommunicationPolicy*>(this)->asyncSend(vAddr, 0, context, data);
437 
438  }
439 
440  }
441 
442  static_cast<CommunicationPolicy*>(this)->recv(rootVAddr, 0, context, data);
443 
444  }
445 
446 
447  template <typename T_CommunicationPolicy>
448  void Base<T_CommunicationPolicy>::synchronize(const Context context){
449  std::array<char, 0> null;
450 
451  if(context.getVAddr() == 0){
452  for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
453  static_cast<CommunicationPolicy*>(this)->recv(vAddr, 0, context, null);
454  }
455  for(VAddr vAddr = 0; vAddr < context.size(); vAddr++){
456  static_cast<CommunicationPolicy*>(this)->send(vAddr, 0, context, null);
457  }
458 
459  }
460  else {
461  static_cast<CommunicationPolicy*>(this)->send(0, 0, context, null);
462  static_cast<CommunicationPolicy*>(this)->recv(0, 0, context, null);
463  }
464 
465  }
466 
467  } // namespace communicationPolicy
468 
469 } // namespace graybat
470 
Definition: chain.cpp:31
void allScatter(const Context context, const T_Send &sendData, T_Recv &recvData)
Distributes sendData of all peers in the context to all peers in the context. Every peer will receive ...
Definition: Base.hpp:350
void allGatherVar(const Context context, const T_Send &sendData, T_Recv &recvData, std::vector< unsigned > &recvCount)
Collects sendData from all peers of the context. The size of sendData may differ between peers. The data is received by every peer in the context.
Definition: Base.hpp:300
void gather(const VAddr rootVAddr, const Context context, const T_Send &sendData, T_Recv &recvData)
Collects sendData from all peers of the context and transmits it as a list to the peer with rootVAddr...
Definition: Base.hpp:226
void gatherVar(const VAddr rootVAddr, const Context context, const T_Send &sendData, T_Recv &recvData, std::vector< unsigned > &recvCount)
Collects sendData from all members of the context with varying size and transmits it as a list to pee...
Definition: Base.hpp:248
Definition: Base.hpp:20
void synchronize(const Context context)
Synchronizes all peers within context to the same point in the program execution (barrier)...
Definition: Base.hpp:448
void scatter(const VAddr rootVAddr, const Context context, const T_Send &sendData, T_Recv &recvData)
Distributes sendData from peer rootVAddr to all peers in context. Every peer will receive different d...
Definition: Base.hpp:328
void allGather(Context context, const T_Send &sendData, T_Recv &recvData)
Collects sendData from all members of the context and transmits it as a list to every peer in the con...
Definition: Base.hpp:278
void reduce(const VAddr rootVAddr, const Context context, const T_Op op, const T_Send &sendData, T_Recv &recvData)
Performs a reduction with a binary operator op on all sendData elements from all peers within the co...
Definition: Base.hpp:378
void broadcast(const VAddr rootVAddr, const Context context, T_SendRecv &data)
Sends data from peer rootVAddr to all peers in context. Every peer will receive the same data...
Definition: Base.hpp:431
void allReduce(const Context context, T_Op op, const T_Send &sendData, T_Recv &recvData)
Performs a reduction with a binary operator op on all sendData elements from all peers within the co...
Definition: Base.hpp:405
Definition: BiStar.hpp:8