1#ifndef __IMPL_MPI_OP_HPP__
2#define __IMPL_MPI_OP_HPP__
4#include <common/execinfo.hpp>
5#include <common/traits.hpp>
10#include <mpi_w/impl_async.hpp>
11#include <mpi_w/message_t.hpp>
12#include <mpi_w/mpi_types.hpp>
60 template <POD_t DataType>
61 int send(DataType data,
size_t dest,
size_t tag = 0);
78 template <POD_t DataType>
79 int send_v(std::span<const DataType> data,
82 bool send_size =
true) noexcept;
101 template <
POD_t DataType>
102 std::optional<DataType>
103 recv(
size_t src, MPI_Status* status =
nullptr,
size_t tag = 0) noexcept;
120 template <
POD_t DataType>
123 MPI_Status* status =
nullptr,
124 size_t tag = 0) noexcept;
140 template <
POD_t DataType>
141 DataType
try_recv(
size_t src, MPI_Status* status =
nullptr,
size_t tag = 0);
159 std::optional<std::vector<T> >
160 recv_v(
size_t source, MPI_Status* status =
nullptr,
size_t tag = 0) noexcept;
178 try_recv_v(
size_t src, MPI_Status* status =
nullptr,
size_t tag = 0);
212 template <
POD_t DataType>
int broadcast(DataType& data,
size_t root) noexcept;
240 template <
POD_t... Args>
263 _gather_unsafe(T* src_data,
size_t size,
size_t n_rank,
size_t root = 0);
269 size_t root = 0) noexcept;
286 gather(std::span<T> local_data,
size_t n_rank,
size_t root = 0);
290 std::span<const T> local_data,
295 gather(std::span<const T> local_data,
size_t n_rank,
size_t root = 0);
309 template <NumberType T> T
gather_reduce(T data,
size_t root = 0);
311 template <NumberType T>
317 &data, &global_sum, 1,
get_type<T>(), MPI_SUM, MPI_COMM_WORLD);
336 gather_v(
const std::vector<T>& local_data,
size_t n_rank,
size_t root = 0);
342 template <POD_t DataType>
344 try_recv(
size_t src, MPI_Status* status,
size_t tag)
347 if (!opt_data.has_value())
354 return opt_data.value();
359 std::optional<std::vector<T> >
360 recv_v(
size_t source, MPI_Status* status,
size_t tag)
noexcept
366 if (!opt_size.has_value())
370 size_t buf_size = opt_size.value();
373 buf.resize(buf_size);
378 int recv_status = MPI_Recv(buf.data(),
379 static_cast<int>(buf_size),
386 if (recv_status != MPI_SUCCESS)
394 template <POD_t DataType>
401 return MPI_Recv(buf.data(),
415 if (!opt_data.has_value())
421 return opt_data.value();
425 template <POD_t DataType>
444 broadcast(std::vector<T>& data,
size_t root,
size_t current_rank)
447 size_t data_size = 0;
448 if (current_rank == root)
450 data_size = data.size();
453 if (current_rank != root)
455 data.resize(data_size);
458 return MPI_Bcast(data.data(),
461 static_cast<int>(root),
472 throw std::invalid_argument(
"Data pointer is null");
474 if (_size == 0 || _size > std::numeric_limits<size_t>::max())
476 throw std::invalid_argument(
"Error size");
480 MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
481 if (root >=
static_cast<size_t>(comm_size))
483 throw std::invalid_argument(
"Root process rank is out of range");
488 data, _size,
get_type<T>(),
static_cast<int>(root), MPI_COMM_WORLD);
502 int src_size =
static_cast<int>(size);
503 std::vector<T> total_data(src_size * n_rank);
504 T* dest_data = total_data.data();
517 throw std::runtime_error(
"MPI_Gather failed");
528 size_t root)
noexcept
530 int src_size =
static_cast<int>(size);
533 return MPI_Gather(src_data,
545 gather_span(std::span<T> dest, std::span<const T> local_data,
size_t root)
547 T* dest_data = dest.data();
550 MPI_Comm_size(MPI_COMM_WORLD, &size);
551 assert(dest.size() == local_data.size() * size);
554 dest_data,
const_cast<T*
>(local_data.data()), local_data.size(), root);
559 gather(std::span<T> local_data,
size_t n_rank,
size_t root)
561 return _gather_unsafe(local_data.data(), local_data.size(), n_rank, root);
566 gather(std::span<const T> local_data,
size_t n_rank,
size_t root)
569 const_cast<T*
>(local_data.data()), local_data.size(), n_rank, root);
572 template <NumberType T>
592 gather_v(
const std::vector<T>& local_data,
size_t n_rank,
size_t root)
594 return _gather_unsafe(local_data.data(), local_data.size(), n_rank, root);
597 template <POD_t... Args>
602 for (
int j = 1; j < static_cast<int>(info.
n_rank); ++j)
606 MPI_Send(&sign,
sizeof(sign), MPI_CHAR, j, 0, MPI_COMM_WORLD);
614 if constexpr (std::is_same_v<std::decay_t<T>, std::span<double> >)
618 MPI_Send(&s, 1, MPI_UNSIGNED_LONG, j, 0, MPI_COMM_WORLD);
627 buf,
static_cast<int>(s), MPI_DOUBLE, j, 0, MPI_COMM_WORLD);
628 }(std::forward<Args>(args)),
633 template <POD_t DataType>
635 send(DataType data,
size_t dest,
size_t tag)
638 auto res = WrapMPI::Async::_send_unsafe<DataType>(req, &data, 1, dest, tag);
643 template <POD_t DataType>
648 bool send_size)
noexcept
650 int send_status = MPI_SUCCESS;
657 if (send_status == MPI_SUCCESS)
660 auto res = WrapMPI::Async::_send_unsafe<DataType>(
661 req, data.data(), data.size(), dest, tag);
670 template <POD_t DataType>
671 std::optional<DataType>
672 recv(
size_t src, MPI_Status* status,
size_t tag)
noexcept
676 int recv_status = MPI_Recv(
677 &buf,
sizeof(DataType), MPI_BYTE, src, tag, MPI_COMM_WORLD, status);
678 if (recv_status != MPI_SUCCESS)
Definition message_t.hpp:23
MPI_Status wait(MPI_Request &request)
Definition impl_async.hpp:37
Namespace to correctly wrap the MPI C API for modern C++.
Definition impl_async.hpp:14
int broadcast_span(std::span< T > data, size_t root)
Broadcasts data stored in a span to all processes.
Definition impl_op.hpp:493
std::vector< T > try_recv_v(size_t src, MPI_Status *status=nullptr, size_t tag=0)
Attempts to receive a vector of data items from a source.
Definition impl_op.hpp:412
int send_v(std::span< const DataType > data, size_t dest, size_t tag=0, bool send_size=true) noexcept
Sends a vector of data to a destination.
Definition impl_op.hpp:645
int broadcast(DataType &data, size_t root) noexcept
Broadcasts a single data item to all processes.
Definition impl_op.hpp:427
std::vector< T > gather(std::span< T > local_data, size_t n_rank, size_t root=0)
Gathers data from all processes.
Definition impl_op.hpp:559
std::vector< T > gather_v(const std::vector< T > &local_data, size_t n_rank, size_t root=0)
Gathers a vector of data from all processes.
Definition impl_op.hpp:592
SIGNALS
Definition message_t.hpp:13
@ NOP
Definition message_t.hpp:17
int recv_span(std::span< DataType > buf, size_t src, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives data into a span buffer.
Definition impl_op.hpp:396
T all_reduce(T data)
Definition impl_op.hpp:313
int send(DataType data, size_t dest, size_t tag=0)
Sends raw data to a destination in an unsafe manner.
Definition impl_op.hpp:635
T gather_reduce(T data, size_t root=0)
Gathers and reduces data to a single value.
Definition impl_op.hpp:574
int _broadcast_unsafe(T *data, size_t _size, size_t root)
Broadcasts raw data to all processes in an unsafe manner.
Definition impl_op.hpp:467
int critical_error() noexcept
Definition mpi_wrap.cpp:26
void host_dispatch(const ExecInfo &info, SIGNALS sign, Args &&... args)
Dispatches a task to the host with signal and arguments.
Definition impl_op.hpp:599
int _gather_unsafe_to_buffer(T *dest, T *src_data, size_t size, size_t root=0) noexcept
Definition impl_op.hpp:525
std::optional< DataType > recv(size_t src, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives a single data item from a source.
Definition impl_op.hpp:672
constexpr MPI_Datatype get_type() noexcept
Definition mpi_types.hpp:13
std::optional< std::vector< T > > recv_v(size_t source, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives a vector of data items from a source.
Definition impl_op.hpp:360
DataType try_recv(size_t src, MPI_Status *status=nullptr, size_t tag=0)
Attempts to receive a single data item from a source.
Definition impl_op.hpp:344
std::vector< T > _gather_unsafe(T *src_data, size_t size, size_t n_rank, size_t root=0)
Gathers raw data from all processes in an unsafe manner.
Definition impl_op.hpp:500
void gather_span(std::span< T > dest, std::span< const T > local_data, size_t root=0)
Definition impl_op.hpp:545
Definition execinfo.hpp:12
uint32_t n_rank
Definition execinfo.hpp:14