BioCMAMC-ST
impl_op.hpp
1#ifndef __IMPL_MPI_OP_HPP__
2#define __IMPL_MPI_OP_HPP__
3
#include "common/common.hpp"
#include <mpi_w/message_t.hpp>
#include <mpi_w/mpi_types.hpp>

#include <common/execinfo.hpp>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <math.h>
#include <mpi.h>
#include <optional>
#include <span>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>
17
21namespace WrapMPI
22{
23
24 // SENDING
25
42 template <POD_t DataType>
43 [[nodiscard]] static int _send_unsafe(DataType* buf,
44 size_t buf_size,
45 size_t dest,
46 size_t tag = 0) noexcept;
47
60 template <POD_t DataType>
61 int send(DataType data, size_t dest, size_t tag = 0);
62
77 template <POD_t DataType>
78 int send_v(std::span<const DataType> data,
79 size_t dest,
80 size_t tag = 0,
81 bool send_size = true) noexcept;
82
83 // RECEIVE
84
100 template <POD_t DataType>
101 std::optional<DataType>
102 recv(size_t src, MPI_Status* status = nullptr, size_t tag = 0) noexcept;
103
119 template <POD_t DataType>
120 int recv_span(std::span<DataType> buf,
121 size_t src,
122 MPI_Status* status = nullptr,
123 size_t tag = 0) noexcept;
124
139 template <POD_t DataType>
140 DataType try_recv(size_t src, MPI_Status* status = nullptr, size_t tag = 0);
141
157 template <POD_t T>
158 std::optional<std::vector<T>>
159 recv_v(size_t source, MPI_Status* status = nullptr, size_t tag = 0) noexcept;
160
175 template <POD_t T>
176 std::vector<T>
177 try_recv_v(size_t src, MPI_Status* status = nullptr, size_t tag = 0);
178
179 // BROADCASTING
180
196 template <POD_t T>
197 [[nodiscard]] int _broadcast_unsafe(T* data, size_t _size, size_t root);
198
211 template <POD_t DataType> int broadcast(DataType& data, size_t root) noexcept;
212
225 template <POD_t T> int broadcast_span(std::span<T> data, size_t root);
226
239 template <POD_t... Args>
240 void host_dispatch(const ExecInfo& info, SIGNALS sign, Args&&... args);
241
260 template <POD_t T>
261 std::vector<T>
262 _gather_unsafe(T* src_data, size_t size, size_t n_rank, size_t root = 0);
263
264 template <POD_t T>
265 int _gather_unsafe_to_buffer(T* dest,
266 T* src_data,
267 size_t size,
268 size_t root = 0) noexcept;
269
283 template <POD_t T>
284 std::vector<T>
285 gather(std::span<T> local_data, size_t n_rank, size_t root = 0);
286
287 template <POD_t T>
288 void gather_span(std::span<T> dest,
289 std::span<const T> local_data,
290 size_t root = 0);
291
292 template <POD_t T>
293 std::vector<T>
294 gather(std::span<const T> local_data, size_t n_rank, size_t root = 0);
295
308 template <NumberType T> T gather_reduce(T data, size_t root = 0);
309
310 template <NumberType T> T all_reduce(T data)
311 {
312 T global_sum{}; // global sum across all ran
313 MPI_Allreduce(
314 &data, &global_sum, 1, get_type<T>(), MPI_SUM, MPI_COMM_WORLD);
315 return global_sum;
316 }
317
331 template <POD_t T>
332 std::vector<T>
333 gather_v(const std::vector<T>& local_data, size_t n_rank, size_t root = 0);
334
335 //**
336 // IMPL
337 //**
338 template <POD_t DataType>
339 static int
340 _send_unsafe(DataType* buf, size_t buf_size, size_t dest, size_t tag) noexcept
341 {
342 return MPI_Send(
343 buf, buf_size, get_type<DataType>(), dest, tag, MPI_COMM_WORLD);
344 }
345
346 template <POD_t DataType>
347 DataType try_recv(size_t src, MPI_Status* status, size_t tag)
348 {
349 auto opt_data = WrapMPI::recv<DataType>(src, status, tag);
350 if (!opt_data.has_value())
351 {
353 exit(-1); // critical_error should exit before reaching this statement
354 }
355 else
356 {
357 return opt_data.value();
358 }
359 }
360
361 template <POD_t T>
362 std::optional<std::vector<T>>
363 recv_v(size_t source, MPI_Status* status, size_t tag) noexcept
364 {
365 std::vector<T> buf;
366
367 // Receive the size of the vector
368 auto opt_size = recv<size_t>(source, status, tag);
369 if (!opt_size.has_value())
370 {
371 return std::nullopt; // Return early if size reception fails
372 }
373 size_t buf_size = opt_size.value();
374
375 // Resize the buffer
376 buf.resize(buf_size);
377
378 MPI_Datatype datatype = get_type<T>();
379
380 // Receive the vector data
381 int recv_status = MPI_Recv(buf.data(),
382 static_cast<int>(buf_size),
383 datatype,
384 source,
385 tag,
386 MPI_COMM_WORLD,
387 status);
388
389 if (recv_status != MPI_SUCCESS)
390 {
391 return std::nullopt; // Return early if MPI_Recv fails
392 }
393
394 return buf; // Return the received vector
395 }
396
397 template <POD_t DataType>
398 int recv_span(std::span<DataType> buf,
399 size_t src,
400 MPI_Status* status,
401 size_t tag) noexcept
402 {
403 return MPI_Recv(buf.data(),
404 buf.size(),
406 src,
407 tag,
408 MPI_COMM_WORLD,
409 status);
410 }
411
412 template <POD_t T>
413 std::vector<T> try_recv_v(size_t src, MPI_Status* status, size_t tag)
414 {
415 auto opt_data = WrapMPI::recv_v<T>(src, status, tag);
416 if (!opt_data.has_value())
417 {
419 }
420 else
421 {
422 return opt_data.value();
423 }
424 }
425
426 template <POD_t DataType> int broadcast(DataType& data, size_t root) noexcept
427 {
428 return MPI_Bcast(
429 &data, 1, get_type<DataType>(), static_cast<int>(root), MPI_COMM_WORLD);
430 }
431
432 // template <> int broadcast(size_t &data, size_t root)
433 // {
434 // return MPI_Bcast(&data,
435 // sizeof(size_t),
436 // MPI_UNSIGNED_LONG,
437 // static_cast<int>(root),
438 // MPI_COMM_WORLD);
439 // }
440
441 template <POD_t T>
442 int broadcast(std::vector<T>& data, size_t root, size_t current_rank)
443 {
444
445 size_t data_size = 0;
446 if (current_rank == root)
447 {
448 data_size = data.size();
449 }
450 broadcast(data_size, root);
451 if (current_rank != root)
452 {
453 data.resize(data_size);
454 }
455
456 return MPI_Bcast(data.data(),
457 data_size,
458 get_type<T>(),
459 static_cast<int>(root),
460 MPI_COMM_WORLD);
461 }
462
463 template <POD_t T> int _broadcast_unsafe(T* data, size_t _size, size_t root)
464 {
465
466 if (data == nullptr)
467 {
468 throw std::invalid_argument("Data pointer is null");
469 }
470 if (_size == 0 || _size > std::numeric_limits<size_t>::max())
471 {
472 throw std::invalid_argument("Error size");
473 }
474
475 int comm_size = 0;
476 MPI_Comm_size(MPI_COMM_WORLD, &comm_size);
477 if (root >= static_cast<size_t>(comm_size))
478 {
479 throw std::invalid_argument("Root process rank is out of range");
480 }
481
482 // Broadcast operation
483 return MPI_Bcast(
484 data, _size, get_type<T>(), static_cast<int>(root), MPI_COMM_WORLD);
485 }
486
487 template <POD_t T> int broadcast_span(std::span<T> data, size_t root)
488 {
489 return _broadcast_unsafe(data.data(), data.size(), root);
490 }
491
492 template <POD_t T>
493 std::vector<T>
494 _gather_unsafe(T* src_data, size_t size, size_t n_rank, size_t root)
495 {
496 int src_size = static_cast<int>(size);
497 std::vector<T> total_data(src_size * n_rank);
498 T* dest_data = total_data.data();
499 // auto mpi_type = get_type<T>();
500
501 // int gather_result = MPI_Gather(
502 // src_data, src_size, mpi_type, dest_data, src_size, mpi_type, root,
503 // MPI_COMM_WORLD);
504 // if (gather_result != MPI_SUCCESS)
505 // {
506 // throw std::runtime_error("MPI_Gather failed");
507 // }
508 if (_gather_unsafe_to_buffer(dest_data, src_data, size, root) !=
509 MPI_SUCCESS)
510 {
511 throw std::runtime_error("MPI_Gather failed");
512 }
513
514 return total_data;
515 }
516
517 template <POD_t T>
518 int _gather_unsafe_to_buffer(T* const dest,
519 T* src_data,
520 size_t size,
521 size_t root) noexcept
522 {
523 int src_size = static_cast<int>(size);
524 auto mpi_type = get_type<T>();
525
526 return MPI_Gather(src_data,
527 src_size,
528 mpi_type,
529 dest,
530 src_size,
531 mpi_type,
532 root,
533 MPI_COMM_WORLD);
534 }
535
536 template <POD_t T>
537 void
538 gather_span(std::span<T> dest, std::span<const T> local_data, size_t root)
539 {
540 T* dest_data = dest.data();
541#ifndef NDEBUG
542 int size{};
543 MPI_Comm_size(MPI_COMM_WORLD, &size);
544 assert(dest.size() == local_data.size() * size);
545#endif
547 dest_data, const_cast<T*>(local_data.data()), local_data.size(), root);
548 }
549
550 template <POD_t T>
551 std::vector<T> gather(std::span<T> local_data, size_t n_rank, size_t root)
552 {
553 return _gather_unsafe(local_data.data(), local_data.size(), n_rank, root);
554 }
555 // FIXME
556 template <POD_t T>
557 std::vector<T>
558 gather(std::span<const T> local_data, size_t n_rank, size_t root)
559 {
560 return _gather_unsafe(
561 const_cast<T*>(local_data.data()), local_data.size(), n_rank, root);
562 }
563
564 template <NumberType T> T gather_reduce(T data, size_t root)
565 {
566
567 T result{};
568
569 MPI_Reduce(&data,
570 &result,
571 1,
573 MPI_SUM,
574 root,
575 MPI_COMM_WORLD);
576
577 return result;
578 }
579
580 template <POD_t T>
581 std::vector<T>
582 gather_v(const std::vector<T>& local_data, size_t n_rank, size_t root)
583 {
584 return _gather_unsafe(local_data.data(), local_data.size(), n_rank, root);
585 }
586
587 template <POD_t... Args>
588 void host_dispatch(const ExecInfo& info, SIGNALS sign, Args&&... args)
589 {
590
591 for (int j = 1; j < static_cast<int>(info.n_rank); ++j)
592 {
593 if (sign != WrapMPI::SIGNALS::NOP)
594 {
595 MPI_Send(&sign, sizeof(sign), MPI_CHAR, j, 0, MPI_COMM_WORLD);
596 }
597
598 (
599 [&]<POD_t T>(T& arg)
600 {
601 size_t s = 0;
602 void* buf = nullptr;
603 if constexpr (std::is_same_v<std::decay_t<T>, std::span<double>>)
604 {
605 s = arg.size();
606 buf = arg.data();
607 MPI_Send(&s, 1, MPI_UNSIGNED_LONG, j, 0, MPI_COMM_WORLD);
608 }
609 else
610 {
611 s = sizeof(T);
612 buf = &arg;
613 }
614
615 MPI_Send(
616 buf, static_cast<int>(s), MPI_DOUBLE, j, 0, MPI_COMM_WORLD);
617 }(std::forward<Args>(args)),
618 ...);
619 }
620 }
621
622 template <POD_t DataType> int send(DataType data, size_t dest, size_t tag)
623 {
624 return _send_unsafe<DataType>(&data, 1, dest, tag);
625 }
626
627 template <POD_t DataType>
628 int send_v(std::span<const DataType> data,
629 size_t dest,
630 size_t tag,
631 bool send_size) noexcept
632 {
633 int send_status = MPI_SUCCESS;
634
635 if (send_size)
636 {
637 send_status = send<size_t>(data.size(), dest, tag);
638 }
639
640 if (send_status == MPI_SUCCESS)
641 {
642 send_status = _send_unsafe(data.data(), data.size(), dest, tag);
643 }
644
645 return send_status;
646 }
647
648 template <POD_t DataType>
649 std::optional<DataType>
650 recv(size_t src, MPI_Status* status, size_t tag) noexcept
651 {
652 DataType buf;
653
654 int recv_status = MPI_Recv(
655 &buf, sizeof(DataType), MPI_BYTE, src, tag, MPI_COMM_WORLD, status);
656 if (recv_status != MPI_SUCCESS)
657 {
658 return std::nullopt;
659 }
660 return buf;
661 }
662
663} // namespace WrapMPI
664
665#endif //__IMPL_MPI_OP_HPP__
Definition traits.hpp:33
Definition message_t.hpp:21
Namespace to correctly wrap the MPI C API for modern C++.
Definition impl_async.hpp:18
int broadcast_span(std::span< T > data, size_t root)
Broadcasts data stored in a span to all processes.
Definition impl_op.hpp:487
std::vector< T > try_recv_v(size_t src, MPI_Status *status=nullptr, size_t tag=0)
Attempts to receive a vector of data items from a source.
Definition impl_op.hpp:413
int send_v(std::span< const DataType > data, size_t dest, size_t tag=0, bool send_size=true) noexcept
Sends a vector of data to a destination.
Definition impl_op.hpp:628
int broadcast(DataType &data, size_t root) noexcept
Broadcasts a single data item to all processes.
Definition impl_op.hpp:426
std::vector< T > gather(std::span< T > local_data, size_t n_rank, size_t root=0)
Gathers data from all processes.
Definition impl_op.hpp:551
std::vector< T > gather_v(const std::vector< T > &local_data, size_t n_rank, size_t root=0)
Gathers a vector of data from all processes.
Definition impl_op.hpp:582
SIGNALS
Definition message_t.hpp:13
@ NOP
Definition message_t.hpp:16
int recv_span(std::span< DataType > buf, size_t src, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives data into a span buffer.
Definition impl_op.hpp:398
T all_reduce(T data)
Definition impl_op.hpp:310
int send(DataType data, size_t dest, size_t tag=0)
Sends a single instance of data to a destination.
Definition impl_op.hpp:622
T gather_reduce(T data, size_t root=0)
Gathers and reduces data to a single value.
Definition impl_op.hpp:564
int _broadcast_unsafe(T *data, size_t _size, size_t root)
Broadcasts raw data to all processes in an unsafe manner.
Definition impl_op.hpp:463
int critical_error() noexcept
Definition mpi_wrap.cpp:23
void host_dispatch(const ExecInfo &info, SIGNALS sign, Args &&... args)
Dispatches a task to the host with signal and arguments.
Definition impl_op.hpp:588
int _gather_unsafe_to_buffer(T *dest, T *src_data, size_t size, size_t root=0) noexcept
Definition impl_op.hpp:518
std::optional< DataType > recv(size_t src, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives a single data item from a source.
Definition impl_op.hpp:650
constexpr MPI_Datatype get_type() noexcept
Definition mpi_types.hpp:11
std::optional< std::vector< T > > recv_v(size_t source, MPI_Status *status=nullptr, size_t tag=0) noexcept
Receives a vector of data items from a source.
Definition impl_op.hpp:363
DataType try_recv(size_t src, MPI_Status *status=nullptr, size_t tag=0)
Attempts to receive a single data item from a source.
Definition impl_op.hpp:347
std::vector< T > _gather_unsafe(T *src_data, size_t size, size_t n_rank, size_t root=0)
Gathers raw data from all processes in an unsafe manner.
Definition impl_op.hpp:494
void gather_span(std::span< T > dest, std::span< const T > local_data, size_t root=0)
Definition impl_op.hpp:538
static int _send_unsafe(DataType *buf, size_t buf_size, size_t dest, size_t tag=0) noexcept
Sends raw data to a destination in an unsafe manner.
Definition impl_op.hpp:340
Definition execinfo.hpp:12
uint32_t n_rank
Definition execinfo.hpp:14