BioCMAMC-ST
impl_op.hpp
1#ifndef __IMPL_MPI_OP_HPP__
2#define __IMPL_MPI_OP_HPP__
3
#include "common/common.hpp"
#include <mpi_w/message_t.hpp>
#include <mpi_w/mpi_types.hpp>

#include <common/execinfo.hpp>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <math.h>
#include <mpi.h>
#include <optional>
#include <span>
#include <stdexcept>
#include <vector>
17
21namespace WrapMPI
22{
23
24
25
// SENDING

/**
 * @brief Sends raw data to a destination in an unsafe manner.
 *
 * Thin wrapper over MPI_Send on MPI_COMM_WORLD; no validation of the pointer,
 * count, rank or tag is performed.
 * @param buf Pointer to the first element to send.
 * @param buf_size Number of elements (of get_type<DataType>()), not bytes.
 * @param dest Destination rank.
 * @param tag Message tag.
 * @return The MPI error code returned by MPI_Send.
 */
template <POD_t DataType>
[[nodiscard]] static int
_send_unsafe(DataType* buf, size_t buf_size, size_t dest, size_t tag = 0) noexcept;

/**
 * @brief Sends a single instance of data to a destination.
 * @param data Value to send (one element of get_type<DataType>()).
 * @param dest Destination rank.
 * @param tag Message tag.
 * @return MPI error code (MPI_SUCCESS on success).
 */
template <POD_t DataType> int send(DataType data, size_t dest, size_t tag = 0);

/**
 * @brief Sends a vector of data to a destination.
 *
 * When @p send_size is true, data.size() is first sent as a size_t message with
 * the same tag, then the elements follow; this is the layout recv_v expects.
 * @param data Elements to send.
 * @param dest Destination rank.
 * @param tag Tag used for both the size and the payload messages.
 * @param send_size Whether to prefix the payload with its element count.
 * @return MPI error code of the first failing send, or MPI_SUCCESS.
 */
template <POD_t DataType>
int send_v(std::span<const DataType> data,
           size_t dest,
           size_t tag = 0,
           bool send_size = true) noexcept;
78
// RECEIVE

/**
 * @brief Receives a single data item from a source.
 *
 * The item is received as sizeof(DataType) bytes of MPI_BYTE.
 * @param src Source rank.
 * @param status Optional MPI_Status out-parameter for the receive.
 * @param tag Message tag.
 * @return The received value, or std::nullopt if MPI_Recv does not succeed.
 */
template <POD_t DataType>
std::optional<DataType> recv(size_t src, MPI_Status* status = nullptr, size_t tag = 0) noexcept;

/**
 * @brief Receives data into a span buffer.
 *
 * Receives exactly buf.size() elements of get_type<DataType>().
 * @return The MPI error code returned by MPI_Recv.
 */
template <POD_t DataType>
int recv_span(std::span<DataType> buf,
              size_t src,
              MPI_Status* status = nullptr,
              size_t tag = 0) noexcept;

/**
 * @brief Attempts to receive a single data item from a source.
 *
 * Same as recv(), but on failure the process is terminated instead of
 * returning an empty value.
 */
template <POD_t DataType>
DataType try_recv(size_t src, MPI_Status* status = nullptr, size_t tag = 0);

/**
 * @brief Receives a vector of data items from a source.
 *
 * Expects a size_t element count followed by that many elements of
 * get_type<T>() on the same tag (the layout produced by send_v with
 * send_size = true).
 * @return The received elements, or std::nullopt if a receive fails.
 */
template <POD_t T>
std::optional<std::vector<T>>
recv_v(size_t source, MPI_Status* status = nullptr, size_t tag = 0) noexcept;

/**
 * @brief Attempts to receive a vector of data items from a source.
 *
 * Wraps recv_v(); the failure path is meant to abort via critical_error().
 * NOTE(review): confirm the failure branch never returns to the caller.
 */
template <POD_t T>
std::vector<T> try_recv_v(size_t src, MPI_Status* status = nullptr, size_t tag = 0);
169
// BROADCASTING

/**
 * @brief Broadcasts raw data to all processes in an unsafe manner.
 *
 * Calls MPI_Bcast on MPI_COMM_WORLD with @p _size elements of get_type<T>().
 * @throws std::invalid_argument if @p data is null, @p _size is rejected
 *         (0 is always rejected), or @p root is not a rank of MPI_COMM_WORLD.
 * @return The MPI error code returned by MPI_Bcast.
 */
template <POD_t T> [[nodiscard]] int _broadcast_unsafe(T* data, size_t _size, size_t root);

/**
 * @brief Broadcasts a single data item to all processes.
 * @param data Value sent by @p root and overwritten on every other rank.
 * @param root Rank that owns the value.
 * @return MPI error code from MPI_Bcast.
 */
template <POD_t DataType> int broadcast(DataType& data, size_t root) noexcept;

/**
 * @brief Broadcasts data stored in a span to all processes.
 *
 * Forwards to _broadcast_unsafe, so an empty span throws. Presumably every rank
 * passes a span of the same length (MPI_Bcast counts must match).
 */
template <POD_t T> int broadcast_span(std::span<T> data, size_t root);

/**
 * @brief Dispatches a task to the host with signal and arguments.
 *
 * For every rank 1..info.n_rank-1, sends @p sign (unless it is SIGNALS::NOP)
 * and then each argument in order, all on tag 0.
 */
template <POD_t... Args> void host_dispatch(const ExecInfo& info, SIGNALS sign, Args&&... args);
225
/**
 * @brief Gathers raw data from all processes in an unsafe manner.
 *
 * Every rank contributes @p size elements; the returned buffer holds
 * size * n_rank elements (filled on @p root by MPI_Gather).
 * @throws std::runtime_error if MPI_Gather fails.
 */
template <POD_t T>
std::vector<T> _gather_unsafe(T* src_data, size_t size, size_t n_rank, size_t root = 0);

/**
 * @brief Gathers @p size elements from every rank into @p dest on @p root.
 * @return The MPI error code returned by MPI_Gather.
 */
template <POD_t T>
int _gather_unsafe_to_buffer(T* dest, T* src_data, size_t size, size_t root = 0) noexcept;

/**
 * @brief Gathers data from all processes (see _gather_unsafe).
 */
template <POD_t T> std::vector<T> gather(std::span<T> local_data, size_t n_rank, size_t root = 0);

/**
 * @brief Gathers @p local_data from every rank into the caller-provided @p dest.
 *
 * Debug builds assert that dest.size() == local_data.size() * communicator size.
 */
template <POD_t T>
void gather_span(std::span<T> dest, std::span<const T> local_data, size_t root = 0);

/**
 * @brief Gathers read-only local data from all processes (see _gather_unsafe).
 */
template <POD_t T>
std::vector<T> gather(std::span<const T> local_data, size_t n_rank, size_t root = 0);

/**
 * @brief Gathers and reduces data to a single value.
 *
 * Sums one value per rank onto @p root (MPI_Reduce with MPI_SUM).
 */
template <NumberType T> T gather_reduce(T data, size_t root = 0);
280
281 template <NumberType T> T all_reduce(T data)
282 {
283 T global_sum{}; // global sum across all ran
284 MPI_Allreduce(&data, &global_sum, 1, get_type<T>(), MPI_SUM, MPI_COMM_WORLD);
285 return global_sum;
286 }
287
/**
 * @brief Gathers a vector of data from all processes.
 *
 * Every rank contributes local_data.size() elements; the result holds
 * local_data.size() * n_rank elements.
 */
template <POD_t T>
std::vector<T> gather_v(const std::vector<T>& local_data, size_t n_rank, size_t root = 0);
302
303 //**
304 // IMPL
305 //**
306 template <POD_t DataType>
307 static int _send_unsafe(DataType* buf, size_t buf_size, size_t dest, size_t tag) noexcept
308 {
309 return MPI_Send(buf, buf_size, get_type<DataType>(), dest, tag, MPI_COMM_WORLD);
310 }
311
312 template <POD_t DataType> DataType try_recv(size_t src, MPI_Status* status, size_t tag)
313 {
314 auto opt_data = WrapMPI::recv<DataType>(src, status, tag);
315 if (!opt_data.has_value())
316 {
318 exit(-1); // critical_error should exit before reaching this statement
319 }
320 else
321 {
322 return opt_data.value();
323 }
324 }
325
326 template <POD_t T>
327 std::optional<std::vector<T>> recv_v(size_t source, MPI_Status* status, size_t tag) noexcept
328 {
329 std::vector<T> buf;
330
331 // Receive the size of the vector
332 auto opt_size = recv<size_t>(source, status, tag);
333 if (!opt_size.has_value())
334 {
335 return std::nullopt; // Return early if size reception fails
336 }
337 size_t buf_size = opt_size.value();
338
339 // Resize the buffer
340 buf.resize(buf_size);
341
342 MPI_Datatype datatype = get_type<T>();
343
344 // Receive the vector data
345 int recv_status = MPI_Recv(
346 buf.data(), static_cast<int>(buf_size), datatype, source, tag, MPI_COMM_WORLD, status);
347
348 if (recv_status != MPI_SUCCESS)
349 {
350 return std::nullopt; // Return early if MPI_Recv fails
351 }
352
353 return buf; // Return the received vector
354 }
355
356 template <POD_t DataType>
357 int recv_span(std::span<DataType> buf, size_t src, MPI_Status* status, size_t tag) noexcept
358 {
359 return MPI_Recv(buf.data(), buf.size(), get_type<DataType>(), src, tag, MPI_COMM_WORLD, status);
360 }
361
362 template <POD_t T> std::vector<T> try_recv_v(size_t src, MPI_Status* status, size_t tag)
363 {
364 auto opt_data = WrapMPI::recv_v<T>(src, status, tag);
365 if (!opt_data.has_value())
366 {
368 }
369 else
370 {
371 return opt_data.value();
372 }
373 }
374
375 template <POD_t DataType> int broadcast(DataType& data, size_t root) noexcept
376 {
377 return MPI_Bcast(&data, 1, get_type<DataType>(), static_cast<int>(root), MPI_COMM_WORLD);
378 }
379
380 // template <> int broadcast(size_t &data, size_t root)
381 // {
382 // return MPI_Bcast(&data,
383 // sizeof(size_t),
384 // MPI_UNSIGNED_LONG,
385 // static_cast<int>(root),
386 // MPI_COMM_WORLD);
387 // }
388
389 template <POD_t T> int broadcast(std::vector<T>& data, size_t root, size_t current_rank)
390 {
391
392 size_t data_size = 0;
393 if (current_rank == root)
394 {
395 data_size = data.size();
396 }
397 broadcast(data_size, root);
398 if (current_rank != root)
399 {
400 data.resize(data_size);
401 }
402
403 return MPI_Bcast(data.data(), data_size, get_type<T>(), static_cast<int>(root), MPI_COMM_WORLD);
404 }
405
406 template <POD_t T> int _broadcast_unsafe(T* data, size_t _size, size_t root)
407 {
408
409 if (data == nullptr)
410 {
411 throw std::invalid_argument("Data pointer is null");
412 }
413 if (_size == 0 || _size > std::numeric_limits<size_t>::max())
414 {
415 throw std::invalid_argument("Error size");
416 }
417
418 int comm_size = 0;
419 MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
420 if (root >= static_cast<size_t>(comm_size))
421 {
422 throw std::invalid_argument("Root process rank is out of range");
423 }
424
425 // Broadcast operation
426 return MPI_Bcast(data, _size, get_type<T>(), static_cast<int>(root), MPI_COMM_WORLD);
427 }
428
429 template <POD_t T> int broadcast_span(std::span<T> data, size_t root)
430 {
431 return _broadcast_unsafe(data.data(), data.size(), root);
432 }
433
434 template <POD_t T>
435 std::vector<T> _gather_unsafe(T* src_data, size_t size, size_t n_rank, size_t root)
436 {
437 int src_size = static_cast<int>(size);
438 std::vector<T> total_data(src_size * n_rank);
439 T* dest_data = total_data.data();
440 // auto mpi_type = get_type<T>();
441
442 // int gather_result = MPI_Gather(
443 // src_data, src_size, mpi_type, dest_data, src_size, mpi_type, root, MPI_COMM_WORLD);
444 // if (gather_result != MPI_SUCCESS)
445 // {
446 // throw std::runtime_error("MPI_Gather failed");
447 // }
448 if (_gather_unsafe_to_buffer(dest_data, src_data, size, root) != MPI_SUCCESS)
449 {
450 throw std::runtime_error("MPI_Gather failed");
451 }
452
453 return total_data;
454 }
455
456 template <POD_t T>
457 int _gather_unsafe_to_buffer(T* const dest, T* src_data, size_t size, size_t root) noexcept
458 {
459 int src_size = static_cast<int>(size);
460 auto mpi_type = get_type<T>();
461
462 return MPI_Gather(src_data, src_size, mpi_type, dest, src_size, mpi_type, root, MPI_COMM_WORLD);
463 }
464
465 template <POD_t T> void gather_span(std::span<T> dest, std::span<const T> local_data, size_t root)
466 {
467 T* dest_data = dest.data();
468#ifndef NDEBUG
469 int size{};
470 MPI_Comm_size(MPI_COMM_WORLD, &size);
471 assert(dest.size() == local_data.size() * size);
472#endif
473 _gather_unsafe_to_buffer(dest_data, const_cast<T*>(local_data.data()), local_data.size(), root);
474 }
475
476 template <POD_t T> std::vector<T> gather(std::span<T> local_data, size_t n_rank, size_t root)
477 {
478 return _gather_unsafe(local_data.data(), local_data.size(), n_rank, root);
479 }
480 // FIXME
481 template <POD_t T>
482 std::vector<T> gather(std::span<const T> local_data, size_t n_rank, size_t root)
483 {
484 return _gather_unsafe(const_cast<T*>(local_data.data()), local_data.size(), n_rank, root);
485 }
486
487 template <NumberType T> T gather_reduce(T data, size_t root)
488 {
489
490 T result{};
491
492 MPI_Reduce(&data, &result, 1, WrapMPI::get_type<T>(), MPI_SUM, root, MPI_COMM_WORLD);
493
494 return result;
495 }
496
497 template <POD_t T>
498 std::vector<T> gather_v(const std::vector<T>& local_data, size_t n_rank, size_t root)
499 {
500 return _gather_unsafe(local_data.data(), local_data.size(), n_rank, root);
501 }
502
503 template <POD_t... Args> void host_dispatch(const ExecInfo& info, SIGNALS sign, Args&&... args)
504 {
505
506 for (int j = 1; j < static_cast<int>(info.n_rank); ++j)
507 {
508 if (sign != WrapMPI::SIGNALS::NOP)
509 {
510 MPI_Send(&sign, sizeof(sign), MPI_CHAR, j, 0, MPI_COMM_WORLD);
511 }
512
513 (
514 [&]<POD_t T>(T& arg)
515 {
516 size_t s = 0;
517 void* buf = nullptr;
518 if constexpr (std::is_same_v<std::decay_t<T>, std::span<double>>)
519 {
520 s = arg.size();
521 buf = arg.data();
522 MPI_Send(&s, 1, MPI_UNSIGNED_LONG, j, 0, MPI_COMM_WORLD);
523 }
524 else
525 {
526 s = sizeof(T);
527 buf = &arg;
528 }
529
530 MPI_Send(buf, static_cast<int>(s), MPI_DOUBLE, j, 0, MPI_COMM_WORLD);
531 }(std::forward<Args>(args)),
532 ...);
533 }
534 }
535
536 template <POD_t DataType> int send(DataType data, size_t dest, size_t tag)
537 {
538 return _send_unsafe<DataType>(&data, 1, dest, tag);
539 }
540
541 template <POD_t DataType>
542 int send_v(std::span<const DataType> data, size_t dest, size_t tag, bool send_size) noexcept
543 {
544 int send_status = MPI_SUCCESS;
545
546 if (send_size)
547 {
548 send_status = send<size_t>(data.size(), dest, tag);
549 }
550
551 if (send_status == MPI_SUCCESS)
552 {
553 send_status = _send_unsafe(data.data(), data.size(), dest, tag);
554 }
555
556 return send_status;
557 }
558
559 template <POD_t DataType>
560 std::optional<DataType> recv(size_t src, MPI_Status* status, size_t tag) noexcept
561 {
562 DataType buf;
563
564 int recv_status = MPI_Recv(&buf, sizeof(DataType), MPI_BYTE, src, tag, MPI_COMM_WORLD, status);
565 if (recv_status != MPI_SUCCESS)
566 {
567 return std::nullopt;
568 }
569 return buf;
570 }
571
572} // namespace WrapMPI
573
574#endif //__IMPL_MPI_OP_HPP__
Definition traits.hpp:33
Definition message_t.hpp:21
Namespace to correctly wrap the MPI C API for modern C++.
Definition impl_async.hpp:19
int broadcast_span(std::span< T > data, size_t root)
Broadcasts data stored in a span to all processes.
Definition impl_op.hpp:429
std::vector< T > try_recv_v(size_t src, MPI_Status *status=nullptr, size_t tag=0)
Attempts to receive a vector of data items from a source.
Definition impl_op.hpp:362
int send_v(std::span< const DataType > data, size_t dest, size_t tag=0, bool send_size=true) noexcept
Sends a vector of data to a destination.
Definition impl_op.hpp:542
int broadcast(DataType &data, size_t root) noexcept
Broadcasts a single data item to all processes.
Definition impl_op.hpp:375
std::vector< T > gather(std::span< T > local_data, size_t n_rank, size_t root=0)
Gathers data from all processes.
Definition impl_op.hpp:476
std::vector< T > gather_v(const std::vector< T > &local_data, size_t n_rank, size_t root=0)
Gathers a vector of data from all processes.
Definition impl_op.hpp:498
SIGNALS
Definition message_t.hpp:13
int recv_span(std::span< DataType > buf, size_t src, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives data into a span buffer.
Definition impl_op.hpp:357
T all_reduce(T data)
Definition impl_op.hpp:281
int send(DataType data, size_t dest, size_t tag=0)
Sends a single instance of data to a destination.
Definition impl_op.hpp:536
T gather_reduce(T data, size_t root=0)
Gathers and reduces data to a single value.
Definition impl_op.hpp:487
int _broadcast_unsafe(T *data, size_t _size, size_t root)
Broadcasts raw data to all processes in an unsafe manner.
Definition impl_op.hpp:406
int critical_error() noexcept
Definition mpi_wrap.cpp:23
void host_dispatch(const ExecInfo &info, SIGNALS sign, Args &&... args)
Dispatches a task to the host with signal and arguments.
Definition impl_op.hpp:503
int _gather_unsafe_to_buffer(T *dest, T *src_data, size_t size, size_t root=0) noexcept
Definition impl_op.hpp:457
std::optional< DataType > recv(size_t src, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives a single data item from a source.
Definition impl_op.hpp:560
constexpr MPI_Datatype get_type() noexcept
Definition mpi_types.hpp:11
std::optional< std::vector< T > > recv_v(size_t source, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives a vector of data items from a source.
Definition impl_op.hpp:327
DataType try_recv(size_t src, MPI_Status *status=nullptr, size_t tag=0)
Attempts to receive a single data item from a source.
Definition impl_op.hpp:312
std::vector< T > _gather_unsafe(T *src_data, size_t size, size_t n_rank, size_t root=0)
Gathers raw data from all processes in an unsafe manner.
Definition impl_op.hpp:435
void gather_span(std::span< T > dest, std::span< const T > local_data, size_t root=0)
Definition impl_op.hpp:465
static int _send_unsafe(DataType *buf, size_t buf_size, size_t dest, size_t tag=0) noexcept
Sends raw data to a destination in an unsafe manner.
Definition impl_op.hpp:307
Definition execinfo.hpp:12
uint32_t n_rank
Definition execinfo.hpp:14