BioCMAMC-ST
impl_op.hpp
1#ifndef __IMPL_MPI_OP_HPP__
2#define __IMPL_MPI_OP_HPP__
3
4#include <common/execinfo.hpp>
5#include <common/traits.hpp>
6#include <cstddef>
7#include <limits>
8#include <math.h>
9#include <mpi.h>
10#include <mpi_w/impl_async.hpp>
11#include <mpi_w/message_t.hpp>
12#include <mpi_w/mpi_types.hpp>
13#include <optional>
14#include <span>
15#include <stdexcept>
16#include <vector>
17
21namespace WrapMPI
22{
23
24 // SENDING
25
26 // /**
27 // * @brief Sends raw data to a destination in an unsafe manner.
28 // *
29 // * This function sends raw data of type `DataType` to the specified
30 // * destination. It assumes the provided buffer and its size are valid
31 // *
32 // * @tparam DataType A type satisfying the `POD` concept.
33 // * @param buf Pointer to the buffer containing data to send.
34 // * @param buf_size The size of the buffer in bytes.
35 // * @param dest The destination identifier for the data.
36 // * @param tag Optional tag to identify the message (default is 0).
37 // * @return An integer indicating success or failure of the operation.
38 // *
39 // * @note Use this function with caution as it performs no validation on the
40 // * input.
41 // */
42 // template <POD_t DataType>
43 // [[nodiscard]] static int _send_unsafe(DataType* buf,
44 // size_t buf_size,
45 // size_t dest,
46 // size_t tag = 0) noexcept;
47
60 template <POD_t DataType>
61 int send(DataType data, size_t dest, size_t tag = 0);
62
78 template <POD_t DataType>
79 int send_v(std::span<const DataType> data,
80 size_t dest,
81 size_t tag = 0,
82 bool send_size = true) noexcept;
83
84 // RECEIVE
85
101 template <POD_t DataType>
102 std::optional<DataType>
103 recv(size_t src, MPI_Status* status = nullptr, size_t tag = 0) noexcept;
104
120 template <POD_t DataType>
121 int recv_span(std::span<DataType> buf,
122 size_t src,
123 MPI_Status* status = nullptr,
124 size_t tag = 0) noexcept;
125
140 template <POD_t DataType>
141 DataType try_recv(size_t src, MPI_Status* status = nullptr, size_t tag = 0);
142
158 template <POD_t T>
159 std::optional<std::vector<T> >
160 recv_v(size_t source, MPI_Status* status = nullptr, size_t tag = 0) noexcept;
161
176 template <POD_t T>
177 std::vector<T>
178 try_recv_v(size_t src, MPI_Status* status = nullptr, size_t tag = 0);
179
180 // BROADCASTING
181
197 template <POD_t T>
198 [[nodiscard]] int _broadcast_unsafe(T* data, size_t _size, size_t root);
199
212 template <POD_t DataType> int broadcast(DataType& data, size_t root) noexcept;
213
226 template <POD_t T> int broadcast_span(std::span<T> data, size_t root);
227
240 template <POD_t... Args>
241 void host_dispatch(const ExecInfo& info, SIGNALS sign, Args&&... args);
242
261 template <POD_t T>
262 std::vector<T>
263 _gather_unsafe(T* src_data, size_t size, size_t n_rank, size_t root = 0);
264
265 template <POD_t T>
266 int _gather_unsafe_to_buffer(T* dest,
267 T* src_data,
268 size_t size,
269 size_t root = 0) noexcept;
270
284 template <POD_t T>
285 std::vector<T>
286 gather(std::span<T> local_data, size_t n_rank, size_t root = 0);
287
288 template <POD_t T>
289 void gather_span(std::span<T> dest,
290 std::span<const T> local_data,
291 size_t root = 0);
292
293 template <POD_t T>
294 std::vector<T>
295 gather(std::span<const T> local_data, size_t n_rank, size_t root = 0);
296
309 template <NumberType T> T gather_reduce(T data, size_t root = 0);
310
311 template <NumberType T>
312 T
314 {
315 T global_sum{}; // global sum across all ran
316 MPI_Allreduce(
317 &data, &global_sum, 1, get_type<T>(), MPI_SUM, MPI_COMM_WORLD);
318 return global_sum;
319 }
320
334 template <POD_t T>
335 std::vector<T>
336 gather_v(const std::vector<T>& local_data, size_t n_rank, size_t root = 0);
337
338 //**
339 // IMPL
340 //**
341
342 template <POD_t DataType>
343 DataType
344 try_recv(size_t src, MPI_Status* status, size_t tag)
345 {
346 auto opt_data = WrapMPI::recv<DataType>(src, status, tag);
347 if (!opt_data.has_value())
348 {
350 exit(-1); // critical_error should exit before reaching this statement
351 }
352 else
353 {
354 return opt_data.value();
355 }
356 }
357
358 template <POD_t T>
359 std::optional<std::vector<T> >
360 recv_v(size_t source, MPI_Status* status, size_t tag) noexcept
361 {
362 std::vector<T> buf;
363
364 // Receive the size of the vector
365 auto opt_size = recv<size_t>(source, status, tag);
366 if (!opt_size.has_value())
367 {
368 return std::nullopt; // Return early if size reception fails
369 }
370 size_t buf_size = opt_size.value();
371
372 // Resize the buffer
373 buf.resize(buf_size);
374
375 MPI_Datatype datatype = get_type<T>();
376
377 // Receive the vector data
378 int recv_status = MPI_Recv(buf.data(),
379 static_cast<int>(buf_size),
380 datatype,
381 source,
382 tag,
383 MPI_COMM_WORLD,
384 status);
385
386 if (recv_status != MPI_SUCCESS)
387 {
388 return std::nullopt; // Return early if MPI_Recv fails
389 }
390
391 return buf; // Return the received vector
392 }
393
394 template <POD_t DataType>
395 int
396 recv_span(std::span<DataType> buf,
397 size_t src,
398 MPI_Status* status,
399 size_t tag) noexcept
400 {
401 return MPI_Recv(buf.data(),
402 buf.size(),
404 src,
405 tag,
406 MPI_COMM_WORLD,
407 status);
408 }
409
410 template <POD_t T>
411 std::vector<T>
412 try_recv_v(size_t src, MPI_Status* status, size_t tag)
413 {
414 auto opt_data = WrapMPI::recv_v<T>(src, status, tag);
415 if (!opt_data.has_value())
416 {
418 }
419 else
420 {
421 return opt_data.value();
422 }
423 }
424
425 template <POD_t DataType>
426 int
427 broadcast(DataType& data, size_t root) noexcept
428 {
429 return MPI_Bcast(
430 &data, 1, get_type<DataType>(), static_cast<int>(root), MPI_COMM_WORLD);
431 }
432
433 // template <> int broadcast(size_t &data, size_t root)
434 // {
435 // return MPI_Bcast(&data,
436 // sizeof(size_t),
437 // MPI_UNSIGNED_LONG,
438 // static_cast<int>(root),
439 // MPI_COMM_WORLD);
440 // }
441
442 template <POD_t T>
443 int
444 broadcast(std::vector<T>& data, size_t root, size_t current_rank)
445 {
446
447 size_t data_size = 0;
448 if (current_rank == root)
449 {
450 data_size = data.size();
451 }
452 broadcast(data_size, root);
453 if (current_rank != root)
454 {
455 data.resize(data_size);
456 }
457
458 return MPI_Bcast(data.data(),
459 data_size,
460 get_type<T>(),
461 static_cast<int>(root),
462 MPI_COMM_WORLD);
463 }
464
465 template <POD_t T>
466 int
467 _broadcast_unsafe(T* data, size_t _size, size_t root)
468 {
469
470 if (data == nullptr)
471 {
472 throw std::invalid_argument("Data pointer is null");
473 }
474 if (_size == 0 || _size > std::numeric_limits<size_t>::max())
475 {
476 throw std::invalid_argument("Error size");
477 }
478
479 int comm_size = 0;
480 MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
481 if (root >= static_cast<size_t>(comm_size))
482 {
483 throw std::invalid_argument("Root process rank is out of range");
484 }
485
486 // Broadcast operation
487 return MPI_Bcast(
488 data, _size, get_type<T>(), static_cast<int>(root), MPI_COMM_WORLD);
489 }
490
491 template <POD_t T>
492 int
493 broadcast_span(std::span<T> data, size_t root)
494 {
495 return _broadcast_unsafe(data.data(), data.size(), root);
496 }
497
498 template <POD_t T>
499 std::vector<T>
500 _gather_unsafe(T* src_data, size_t size, size_t n_rank, size_t root)
501 {
502 int src_size = static_cast<int>(size);
503 std::vector<T> total_data(src_size * n_rank);
504 T* dest_data = total_data.data();
505 // auto mpi_type = get_type<T>();
506
507 // int gather_result = MPI_Gather(
508 // src_data, src_size, mpi_type, dest_data, src_size, mpi_type, root,
509 // MPI_COMM_WORLD);
510 // if (gather_result != MPI_SUCCESS)
511 // {
512 // throw std::runtime_error("MPI_Gather failed");
513 // }
514 if (_gather_unsafe_to_buffer(dest_data, src_data, size, root)
515 != MPI_SUCCESS)
516 {
517 throw std::runtime_error("MPI_Gather failed");
518 }
519
520 return total_data;
521 }
522
523 template <POD_t T>
524 int
526 T* src_data,
527 size_t size,
528 size_t root) noexcept
529 {
530 int src_size = static_cast<int>(size);
531 auto mpi_type = get_type<T>();
532
533 return MPI_Gather(src_data,
534 src_size,
535 mpi_type,
536 dest,
537 src_size,
538 mpi_type,
539 root,
540 MPI_COMM_WORLD);
541 }
542
543 template <POD_t T>
544 void
545 gather_span(std::span<T> dest, std::span<const T> local_data, size_t root)
546 {
547 T* dest_data = dest.data();
548#ifndef NDEBUG
549 int size{};
550 MPI_Comm_size(MPI_COMM_WORLD, &size);
551 assert(dest.size() == local_data.size() * size);
552#endif
554 dest_data, const_cast<T*>(local_data.data()), local_data.size(), root);
555 }
556
557 template <POD_t T>
558 std::vector<T>
559 gather(std::span<T> local_data, size_t n_rank, size_t root)
560 {
561 return _gather_unsafe(local_data.data(), local_data.size(), n_rank, root);
562 }
563 // FIXME
564 template <POD_t T>
565 std::vector<T>
566 gather(std::span<const T> local_data, size_t n_rank, size_t root)
567 {
568 return _gather_unsafe(
569 const_cast<T*>(local_data.data()), local_data.size(), n_rank, root);
570 }
571
572 template <NumberType T>
573 T
574 gather_reduce(T data, size_t root)
575 {
576
577 T result{};
578
579 MPI_Reduce(&data,
580 &result,
581 1,
583 MPI_SUM,
584 root,
585 MPI_COMM_WORLD);
586
587 return result;
588 }
589
590 template <POD_t T>
591 std::vector<T>
592 gather_v(const std::vector<T>& local_data, size_t n_rank, size_t root)
593 {
594 return _gather_unsafe(local_data.data(), local_data.size(), n_rank, root);
595 }
596
597 template <POD_t... Args>
598 void
599 host_dispatch(const ExecInfo& info, SIGNALS sign, Args&&... args)
600 {
601
602 for (int j = 1; j < static_cast<int>(info.n_rank); ++j)
603 {
604 if (sign != WrapMPI::SIGNALS::NOP)
605 {
606 MPI_Send(&sign, sizeof(sign), MPI_CHAR, j, 0, MPI_COMM_WORLD);
607 }
608
609 (
610 [&]<POD_t T>(T& arg)
611 {
612 size_t s = 0;
613 void* buf = nullptr;
614 if constexpr (std::is_same_v<std::decay_t<T>, std::span<double> >)
615 {
616 s = arg.size();
617 buf = arg.data();
618 MPI_Send(&s, 1, MPI_UNSIGNED_LONG, j, 0, MPI_COMM_WORLD);
619 }
620 else
621 {
622 s = sizeof(T);
623 buf = &arg;
624 }
625
626 MPI_Send(
627 buf, static_cast<int>(s), MPI_DOUBLE, j, 0, MPI_COMM_WORLD);
628 }(std::forward<Args>(args)),
629 ...);
630 }
631 }
632
633 template <POD_t DataType>
634 int
635 send(DataType data, size_t dest, size_t tag)
636 {
637 MPI_Request req;
638 auto res = WrapMPI::Async::_send_unsafe<DataType>(req, &data, 1, dest, tag);
640 return res;
641 }
642
643 template <POD_t DataType>
644 int
645 send_v(std::span<const DataType> data,
646 size_t dest,
647 size_t tag,
648 bool send_size) noexcept
649 {
650 int send_status = MPI_SUCCESS;
651
652 if (send_size)
653 {
654 send_status = send<size_t>(data.size(), dest, tag);
655 }
656
657 if (send_status == MPI_SUCCESS)
658 {
659 MPI_Request req{};
660 auto res = WrapMPI::Async::_send_unsafe<DataType>(
661 req, data.data(), data.size(), dest, tag);
663
664 send_status = res;
665 }
666
667 return send_status;
668 }
669
670 template <POD_t DataType>
671 std::optional<DataType>
672 recv(size_t src, MPI_Status* status, size_t tag) noexcept
673 {
674 DataType buf;
675
676 int recv_status = MPI_Recv(
677 &buf, sizeof(DataType), MPI_BYTE, src, tag, MPI_COMM_WORLD, status);
678 if (recv_status != MPI_SUCCESS)
679 {
680 return std::nullopt;
681 }
682 return buf;
683 }
684
685} // namespace WrapMPI
686
687#endif //__IMPL_MPI_OP_HPP__
Definition message_t.hpp:23
MPI_Status wait(MPI_Request &request)
Definition impl_async.hpp:37
Namespace to correctly wrap the MPI C API for modern C++.
Definition impl_async.hpp:14
int broadcast_span(std::span< T > data, size_t root)
Broadcasts data stored in a span to all processes.
Definition impl_op.hpp:493
std::vector< T > try_recv_v(size_t src, MPI_Status *status=nullptr, size_t tag=0)
Attempts to receive a vector of data items from a source.
Definition impl_op.hpp:412
int send_v(std::span< const DataType > data, size_t dest, size_t tag=0, bool send_size=true) noexcept
Sends a vector of data to a destination.
Definition impl_op.hpp:645
int broadcast(DataType &data, size_t root) noexcept
Broadcasts a single data item to all processes.
Definition impl_op.hpp:427
std::vector< T > gather(std::span< T > local_data, size_t n_rank, size_t root=0)
Gathers data from all processes.
Definition impl_op.hpp:559
std::vector< T > gather_v(const std::vector< T > &local_data, size_t n_rank, size_t root=0)
Gathers a vector of data from all processes.
Definition impl_op.hpp:592
SIGNALS
Definition message_t.hpp:13
@ NOP
Definition message_t.hpp:17
int recv_span(std::span< DataType > buf, size_t src, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives data into a span buffer.
Definition impl_op.hpp:396
T all_reduce(T data)
Definition impl_op.hpp:313
int send(DataType data, size_t dest, size_t tag=0)
Sends raw data to a destination in an unsafe manner.
Definition impl_op.hpp:635
T gather_reduce(T data, size_t root=0)
Gathers and reduces data to a single value.
Definition impl_op.hpp:574
int _broadcast_unsafe(T *data, size_t _size, size_t root)
Broadcasts raw data to all processes in an unsafe manner.
Definition impl_op.hpp:467
int critical_error() noexcept
Definition mpi_wrap.cpp:26
void host_dispatch(const ExecInfo &info, SIGNALS sign, Args &&... args)
Dispatches a task to the host with signal and arguments.
Definition impl_op.hpp:599
int _gather_unsafe_to_buffer(T *dest, T *src_data, size_t size, size_t root=0) noexcept
Definition impl_op.hpp:525
std::optional< DataType > recv(size_t src, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives a single data item from a source.
Definition impl_op.hpp:672
constexpr MPI_Datatype get_type() noexcept
Definition mpi_types.hpp:13
std::optional< std::vector< T > > recv_v(size_t source, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives a vector of data items from a source.
Definition impl_op.hpp:360
DataType try_recv(size_t src, MPI_Status *status=nullptr, size_t tag=0)
Attempts to receive a single data item from a source.
Definition impl_op.hpp:344
std::vector< T > _gather_unsafe(T *src_data, size_t size, size_t n_rank, size_t root=0)
Gathers raw data from all processes in an unsafe manner.
Definition impl_op.hpp:500
void gather_span(std::span< T > dest, std::span< const T > local_data, size_t root=0)
Definition impl_op.hpp:545
Definition execinfo.hpp:12
uint32_t n_rank
Definition execinfo.hpp:14