1#ifndef __IMPL_MPI_OP_HPP__ 
    2#define __IMPL_MPI_OP_HPP__ 
    4#include "common/common.hpp" 
    5#include <mpi_w/message_t.hpp> 
    6#include <mpi_w/mpi_types.hpp> 
    8#include <common/execinfo.hpp> 
   42  template <POD_t DataType>
 
   46                                        size_t tag = 0) noexcept;
 
   60  template <
POD_t DataType>
 
   61  int send(DataType data, 
size_t dest, 
size_t tag = 0);
 
   77  template <
POD_t DataType>
 
   78  int send_v(std::span<const DataType> data,
 
   81             bool send_size = true) noexcept;
 
  100  template <
POD_t DataType>
 
  101  std::optional<DataType>
 
  102  recv(
size_t src, MPI_Status* status = 
nullptr, 
size_t tag = 0) noexcept;
 
  119  template <
POD_t DataType>
 
  122                MPI_Status* status = 
nullptr,
 
  123                size_t tag = 0) noexcept;
 
  139  template <
POD_t DataType>
 
  140  DataType 
try_recv(
size_t src, MPI_Status* status = 
nullptr, 
size_t tag = 0);
 
  158  std::optional<std::vector<T>>
 
  159  recv_v(
size_t source, MPI_Status* status = 
nullptr, 
size_t tag = 0) noexcept;
 
  177  try_recv_v(
size_t src, MPI_Status* status = 
nullptr, 
size_t tag = 0);
 
  211  template <
POD_t DataType> 
int broadcast(DataType& data, 
size_t root) noexcept;
 
  239  template <
POD_t... Args>
 
  262  _gather_unsafe(T* src_data, 
size_t size, 
size_t n_rank, 
size_t root = 0);
 
  268                               size_t root = 0) noexcept;
 
  285  gather(std::span<T> local_data, 
size_t n_rank, 
size_t root = 0);
 
  289                   std::span<const T> local_data,
 
  294  gather(std::span<const T> local_data, 
size_t n_rank, 
size_t root = 0);
 
  308  template <NumberType T> T 
gather_reduce(T data, 
size_t root = 0);
 
  314        &data, &global_sum, 1, 
get_type<T>(), MPI_SUM, MPI_COMM_WORLD);
 
 
  333  gather_v(
const std::vector<T>& local_data, 
size_t n_rank, 
size_t root = 0);
 
  338  template <POD_t DataType>
 
  340  _send_unsafe(DataType* buf, 
size_t buf_size, 
size_t dest, 
size_t tag) 
noexcept 
 
  346  template <POD_t DataType>
 
  347  DataType 
try_recv(
size_t src, MPI_Status* status, 
size_t tag)
 
  350    if (!opt_data.has_value())
 
  357      return opt_data.value();
 
 
  362  std::optional<std::vector<T>>
 
  363  recv_v(
size_t source, MPI_Status* status, 
size_t tag) 
noexcept 
  369    if (!opt_size.has_value())
 
  373    size_t buf_size = opt_size.value();
 
  376    buf.resize(buf_size);
 
  381    int recv_status = MPI_Recv(buf.data(),
 
  382                               static_cast<int>(buf_size),
 
  389    if (recv_status != MPI_SUCCESS)
 
 
  397  template <POD_t DataType>
 
  403    return MPI_Recv(buf.data(),
 
 
  413  std::vector<T> 
try_recv_v(
size_t src, MPI_Status* status, 
size_t tag)
 
  416    if (!opt_data.has_value())
 
  422      return opt_data.value();
 
 
  426  template <POD_t DataType> 
int broadcast(DataType& data, 
size_t root) 
noexcept 
 
  442  int broadcast(std::vector<T>& data, 
size_t root, 
size_t current_rank)
 
  445    size_t data_size = 0;
 
  446    if (current_rank == root)
 
  448      data_size = data.size();
 
  451    if (current_rank != root)
 
  453      data.resize(data_size);
 
  456    return MPI_Bcast(data.data(),
 
  459                     static_cast<int>(root),
 
 
  468      throw std::invalid_argument(
"Data pointer is null");
 
  470    if (_size == 0 || _size > std::numeric_limits<size_t>::max())
 
  472      throw std::invalid_argument(
"Error size");
 
  476    MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
 
  477    if (root >= 
static_cast<size_t>(comm_size))
 
  479      throw std::invalid_argument(
"Root process rank is out of range");
 
  484        data, _size, 
get_type<T>(), 
static_cast<int>(root), MPI_COMM_WORLD);
 
 
  496    int src_size = 
static_cast<int>(size);
 
  497    std::vector<T> total_data(src_size * n_rank);
 
  498    T* dest_data = total_data.data();
 
  511      throw std::runtime_error(
"MPI_Gather failed");
 
 
  521                               size_t root) 
noexcept 
  523    int src_size = 
static_cast<int>(size);
 
  526    return MPI_Gather(src_data,
 
 
  538  gather_span(std::span<T> dest, std::span<const T> local_data, 
size_t root)
 
  540    T* dest_data = dest.data();
 
  543    MPI_Comm_size(MPI_COMM_WORLD, &size);
 
  544    assert(dest.size() == local_data.size() * size);
 
  547        dest_data, 
const_cast<T*
>(local_data.data()), local_data.size(), root);
 
 
  551  std::vector<T> 
gather(std::span<T> local_data, 
size_t n_rank, 
size_t root)
 
  553    return _gather_unsafe(local_data.data(), local_data.size(), n_rank, root);
 
 
  558  gather(std::span<const T> local_data, 
size_t n_rank, 
size_t root)
 
  561        const_cast<T*
>(local_data.data()), local_data.size(), n_rank, root);
 
  582  gather_v(
const std::vector<T>& local_data, 
size_t n_rank, 
size_t root)
 
  584    return _gather_unsafe(local_data.data(), local_data.size(), n_rank, root);
 
 
  587  template <POD_t... Args>
 
  591    for (
int j = 1; j < static_cast<int>(info.
n_rank); ++j)
 
  595        MPI_Send(&sign, 
sizeof(sign), MPI_CHAR, j, 0, MPI_COMM_WORLD);
 
  603            if constexpr (std::is_same_v<std::decay_t<T>, std::span<double>>)
 
  607              MPI_Send(&s, 1, MPI_UNSIGNED_LONG, j, 0, MPI_COMM_WORLD);
 
  616                buf, 
static_cast<int>(s), MPI_DOUBLE, j, 0, MPI_COMM_WORLD);
 
  617          }(std::forward<Args>(args)),
 
 
  622  template <POD_t DataType> 
int send(DataType data, 
size_t dest, 
size_t tag)
 
 
  627  template <POD_t DataType>
 
  628  int send_v(std::span<const DataType> data,
 
  631             bool send_size) 
noexcept 
  633    int send_status = MPI_SUCCESS;
 
  640    if (send_status == MPI_SUCCESS)
 
  642      send_status = 
_send_unsafe(data.data(), data.size(), dest, tag);
 
 
  648  template <POD_t DataType>
 
  649  std::optional<DataType>
 
  650  recv(
size_t src, MPI_Status* status, 
size_t tag) 
noexcept 
  654    int recv_status = MPI_Recv(
 
  655        &buf, 
sizeof(DataType), MPI_BYTE, src, tag, MPI_COMM_WORLD, status);
 
  656    if (recv_status != MPI_SUCCESS)
 
 
Definition message_t.hpp:21
 
Namespace to correclty wrap MPI C API for modern C++.
Definition impl_async.hpp:18
 
int broadcast_span(std::span< T > data, size_t root)
Broadcasts data stored in a span to all processes.
Definition impl_op.hpp:487
 
std::vector< T > try_recv_v(size_t src, MPI_Status *status=nullptr, size_t tag=0)
Attempts to receive a vector of data items from a source.
Definition impl_op.hpp:413
 
int send_v(std::span< const DataType > data, size_t dest, size_t tag=0, bool send_size=true) noexcept
Sends a vector of data to a destination.
Definition impl_op.hpp:628
 
int broadcast(DataType &data, size_t root) noexcept
Broadcasts a single data item to all processes.
Definition impl_op.hpp:426
 
std::vector< T > gather(std::span< T > local_data, size_t n_rank, size_t root=0)
Gathers data from all processes.
Definition impl_op.hpp:551
 
std::vector< T > gather_v(const std::vector< T > &local_data, size_t n_rank, size_t root=0)
Gathers a vector of data from all processes.
Definition impl_op.hpp:582
 
SIGNALS
Definition message_t.hpp:13
 
@ NOP
Definition message_t.hpp:16
 
int recv_span(std::span< DataType > buf, size_t src, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives data into a span buffer.
Definition impl_op.hpp:398
 
T all_reduce(T data)
Definition impl_op.hpp:310
 
int send(DataType data, size_t dest, size_t tag=0)
Sends a single instance of data to a destination.
Definition impl_op.hpp:622
 
T gather_reduce(T data, size_t root=0)
Gathers and reduces data to a single value.
Definition impl_op.hpp:564
 
int _broadcast_unsafe(T *data, size_t _size, size_t root)
Broadcasts raw data to all processes in an unsafe manner.
Definition impl_op.hpp:463
 
int critical_error() noexcept
Definition mpi_wrap.cpp:23
 
void host_dispatch(const ExecInfo &info, SIGNALS sign, Args &&... args)
Dispatches a task to the host with signal and arguments.
Definition impl_op.hpp:588
 
int _gather_unsafe_to_buffer(T *dest, T *src_data, size_t size, size_t root=0) noexcept
Definition impl_op.hpp:518
 
std::optional< DataType > recv(size_t src, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives a single data item from a source.
Definition impl_op.hpp:650
 
constexpr MPI_Datatype get_type() noexcept
Definition mpi_types.hpp:11
 
std::optional< std::vector< T > > recv_v(size_t source, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives a vector of data items from a source.
Definition impl_op.hpp:363
 
DataType try_recv(size_t src, MPI_Status *status=nullptr, size_t tag=0)
Attempts to receive a single data item from a source.
Definition impl_op.hpp:347
 
std::vector< T > _gather_unsafe(T *src_data, size_t size, size_t n_rank, size_t root=0)
Gathers raw data from all processes in an unsafe manner.
Definition impl_op.hpp:494
 
void gather_span(std::span< T > dest, std::span< const T > local_data, size_t root=0)
Definition impl_op.hpp:538
 
static int _send_unsafe(DataType *buf, size_t buf_size, size_t dest, size_t tag=0) noexcept
Sends raw data to a destination in an unsafe manner.
Definition impl_op.hpp:340
 
Definition execinfo.hpp:12
 
uint32_t n_rank
Definition execinfo.hpp:14