1#ifndef __IMPL_MPI_OP_HPP__
2#define __IMPL_MPI_OP_HPP__
4#include "common/common.hpp"
5#include <mpi_w/message_t.hpp>
6#include <mpi_w/mpi_types.hpp>
8#include <common/execinfo.hpp>
43 template <POD_t DataType>
44 [[nodiscard]]
static int
45 _send_unsafe(DataType* buf,
size_t buf_size,
size_t dest,
size_t tag = 0) noexcept;
58 template <POD_t DataType>
int send(DataType data,
size_t dest,
size_t tag = 0);
73 template <POD_t DataType>
74 int send_v(std::span<const DataType> data,
77 bool send_size = true) noexcept;
95 template <POD_t DataType>
96 std::optional<DataType>
recv(
size_t src, MPI_Status* status =
nullptr,
size_t tag = 0) noexcept;
112 template <POD_t DataType>
115 MPI_Status* status =
nullptr,
116 size_t tag = 0) noexcept;
132 template <POD_t DataType>
133 DataType
try_recv(
size_t src, MPI_Status* status =
nullptr,
size_t tag = 0);
150 std::optional<std::vector<T>>
151 recv_v(
size_t source, MPI_Status* status =
nullptr,
size_t tag = 0) noexcept;
168 std::vector<T>
try_recv_v(
size_t src, MPI_Status* status =
nullptr,
size_t tag = 0);
186 template <POD_t T> [[nodiscard]]
int _broadcast_unsafe(T* data,
size_t _size,
size_t root);
199 template <POD_t DataType>
int broadcast(DataType& data,
size_t root) noexcept;
211 template <POD_t T>
int broadcast_span(std::span<T> data,
size_t root);
242 std::vector<T>
_gather_unsafe(T* src_data,
size_t size,
size_t n_rank,
size_t root = 0);
259 template <POD_t T> std::vector<T>
gather(std::span<T> local_data,
size_t n_rank,
size_t root = 0);
262 void gather_span(std::span<T> dest, std::span<const T> local_data,
size_t root = 0);
265 std::vector<T>
gather(std::span<const T> local_data,
size_t n_rank,
size_t root = 0);
284 MPI_Allreduce(&data, &global_sum, 1,
get_type<T>(), MPI_SUM, MPI_COMM_WORLD);
301 std::vector<T>
gather_v(
const std::vector<T>& local_data,
size_t n_rank,
size_t root = 0);
306 template <POD_t DataType>
307 static int _send_unsafe(DataType* buf,
size_t buf_size,
size_t dest,
size_t tag)
noexcept
312 template <POD_t DataType> DataType
try_recv(
size_t src, MPI_Status* status,
size_t tag)
315 if (!opt_data.has_value())
322 return opt_data.value();
327 std::optional<std::vector<T>>
recv_v(
size_t source, MPI_Status* status,
size_t tag)
noexcept
333 if (!opt_size.has_value())
337 size_t buf_size = opt_size.value();
340 buf.resize(buf_size);
345 int recv_status = MPI_Recv(
346 buf.data(),
static_cast<int>(buf_size), datatype, source, tag, MPI_COMM_WORLD, status);
348 if (recv_status != MPI_SUCCESS)
356 template <POD_t DataType>
357 int recv_span(std::span<DataType> buf,
size_t src, MPI_Status* status,
size_t tag)
noexcept
359 return MPI_Recv(buf.data(), buf.size(),
get_type<DataType>(), src, tag, MPI_COMM_WORLD, status);
362 template <POD_t T> std::vector<T>
try_recv_v(
size_t src, MPI_Status* status,
size_t tag)
365 if (!opt_data.has_value())
371 return opt_data.value();
375 template <POD_t DataType>
int broadcast(DataType& data,
size_t root)
noexcept
377 return MPI_Bcast(&data, 1,
get_type<DataType>(),
static_cast<int>(root), MPI_COMM_WORLD);
389 template <POD_t T>
int broadcast(std::vector<T>& data,
size_t root,
size_t current_rank)
392 size_t data_size = 0;
393 if (current_rank == root)
395 data_size = data.size();
398 if (current_rank != root)
400 data.resize(data_size);
403 return MPI_Bcast(data.data(), data_size,
get_type<T>(),
static_cast<int>(root), MPI_COMM_WORLD);
411 throw std::invalid_argument(
"Data pointer is null");
413 if (_size == 0 || _size > std::numeric_limits<size_t>::max())
415 throw std::invalid_argument(
"Error size");
419 MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
420 if (root >=
static_cast<size_t>(comm_size))
422 throw std::invalid_argument(
"Root process rank is out of range");
426 return MPI_Bcast(data, _size,
get_type<T>(),
static_cast<int>(root), MPI_COMM_WORLD);
435 std::vector<T>
_gather_unsafe(T* src_data,
size_t size,
size_t n_rank,
size_t root)
437 int src_size =
static_cast<int>(size);
438 std::vector<T> total_data(src_size * n_rank);
439 T* dest_data = total_data.data();
450 throw std::runtime_error(
"MPI_Gather failed");
459 int src_size =
static_cast<int>(size);
462 return MPI_Gather(src_data, src_size, mpi_type, dest, src_size, mpi_type, root, MPI_COMM_WORLD);
465 template <POD_t T>
void gather_span(std::span<T> dest, std::span<const T> local_data,
size_t root)
467 T* dest_data = dest.data();
470 MPI_Comm_size(MPI_COMM_WORLD, &size);
471 assert(dest.size() == local_data.size() * size);
476 template <POD_t T> std::vector<T>
gather(std::span<T> local_data,
size_t n_rank,
size_t root)
478 return _gather_unsafe(local_data.data(), local_data.size(), n_rank, root);
482 std::vector<T>
gather(std::span<const T> local_data,
size_t n_rank,
size_t root)
484 return _gather_unsafe(
const_cast<T*
>(local_data.data()), local_data.size(), n_rank, root);
498 std::vector<T>
gather_v(
const std::vector<T>& local_data,
size_t n_rank,
size_t root)
500 return _gather_unsafe(local_data.data(), local_data.size(), n_rank, root);
506 for (
int j = 1; j < static_cast<int>(info.
n_rank); ++j)
510 MPI_Send(&sign,
sizeof(sign), MPI_CHAR, j, 0, MPI_COMM_WORLD);
518 if constexpr (std::is_same_v<std::decay_t<T>, std::span<double>>)
522 MPI_Send(&s, 1, MPI_UNSIGNED_LONG, j, 0, MPI_COMM_WORLD);
530 MPI_Send(buf,
static_cast<int>(s), MPI_DOUBLE, j, 0, MPI_COMM_WORLD);
531 }(std::forward<Args>(args)),
536 template <POD_t DataType>
int send(DataType data,
size_t dest,
size_t tag)
541 template <POD_t DataType>
542 int send_v(std::span<const DataType> data,
size_t dest,
size_t tag,
bool send_size)
noexcept
544 int send_status = MPI_SUCCESS;
551 if (send_status == MPI_SUCCESS)
553 send_status =
_send_unsafe(data.data(), data.size(), dest, tag);
559 template <POD_t DataType>
560 std::optional<DataType>
recv(
size_t src, MPI_Status* status,
size_t tag)
noexcept
564 int recv_status = MPI_Recv(&buf,
sizeof(DataType), MPI_BYTE, src, tag, MPI_COMM_WORLD, status);
565 if (recv_status != MPI_SUCCESS)
Definition message_t.hpp:21
Namespace that correctly wraps the MPI C API for modern C++.
Definition impl_async.hpp:19
int broadcast_span(std::span< T > data, size_t root)
Broadcasts data stored in a span to all processes.
Definition impl_op.hpp:429
std::vector< T > try_recv_v(size_t src, MPI_Status *status=nullptr, size_t tag=0)
Attempts to receive a vector of data items from a source.
Definition impl_op.hpp:362
int send_v(std::span< const DataType > data, size_t dest, size_t tag=0, bool send_size=true) noexcept
Sends a vector of data to a destination.
Definition impl_op.hpp:542
int broadcast(DataType &data, size_t root) noexcept
Broadcasts a single data item to all processes.
Definition impl_op.hpp:375
std::vector< T > gather(std::span< T > local_data, size_t n_rank, size_t root=0)
Gathers data from all processes.
Definition impl_op.hpp:476
std::vector< T > gather_v(const std::vector< T > &local_data, size_t n_rank, size_t root=0)
Gathers a vector of data from all processes.
Definition impl_op.hpp:498
SIGNALS
Definition message_t.hpp:13
int recv_span(std::span< DataType > buf, size_t src, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives data into a span buffer.
Definition impl_op.hpp:357
T all_reduce(T data)
Definition impl_op.hpp:281
int send(DataType data, size_t dest, size_t tag=0)
Sends a single instance of data to a destination.
Definition impl_op.hpp:536
T gather_reduce(T data, size_t root=0)
Gathers and reduces data to a single value.
Definition impl_op.hpp:487
int _broadcast_unsafe(T *data, size_t _size, size_t root)
Broadcasts raw data to all processes in an unsafe manner.
Definition impl_op.hpp:406
int critical_error() noexcept
Definition mpi_wrap.cpp:23
void host_dispatch(const ExecInfo &info, SIGNALS sign, Args &&... args)
Dispatches a task to the host with signal and arguments.
Definition impl_op.hpp:503
int _gather_unsafe_to_buffer(T *dest, T *src_data, size_t size, size_t root=0) noexcept
Definition impl_op.hpp:457
std::optional< DataType > recv(size_t src, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives a single data item from a source.
Definition impl_op.hpp:560
constexpr MPI_Datatype get_type() noexcept
Definition mpi_types.hpp:11
std::optional< std::vector< T > > recv_v(size_t source, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives a vector of data items from a source.
Definition impl_op.hpp:327
DataType try_recv(size_t src, MPI_Status *status=nullptr, size_t tag=0)
Attempts to receive a single data item from a source.
Definition impl_op.hpp:312
std::vector< T > _gather_unsafe(T *src_data, size_t size, size_t n_rank, size_t root=0)
Gathers raw data from all processes in an unsafe manner.
Definition impl_op.hpp:435
void gather_span(std::span< T > dest, std::span< const T > local_data, size_t root=0)
Definition impl_op.hpp:465
static int _send_unsafe(DataType *buf, size_t buf_size, size_t dest, size_t tag=0) noexcept
Sends raw data to a destination in an unsafe manner.
Definition impl_op.hpp:307
Definition execinfo.hpp:12
uint32_t n_rank
Definition execinfo.hpp:14