diff --git a/include/plssvm/mpi/communicator.hpp b/include/plssvm/mpi/communicator.hpp index 851516ed3..ba4afbb60 100644 --- a/include/plssvm/mpi/communicator.hpp +++ b/include/plssvm/mpi/communicator.hpp @@ -14,12 +14,18 @@ #define PLSSVM_MPI_COMMUNICATOR_HPP_ #pragma once +#include "plssvm/mpi/detail/mpi_datatype.hpp" // plssvm::mpi::detail::mpi_datatype +#include "plssvm/mpi/detail/utility.hpp" // PLSSVM_MPI_ERROR_CHECK + #if defined(PLSSVM_HAS_MPI_ENABLED) - #include "mpi.h" // MPI_Comm, MPI_COMM_WORLD + #include "mpi.h" // MPI_Comm, MPI_COMM_WORLD, MPI_Gather #endif -#include // std::size_t -#include // std::invoke +#include // std::transform +#include // std::size_t +#include // std::invoke +#include // std::is_enum_v, std::underlying_type_t +#include // std::vector namespace plssvm::mpi { @@ -95,6 +101,24 @@ class communicator { } } + /** + * @brief Gather the @p value from each MPI rank on the `communicator::main_rank()`. + * @details If `PLSSVM_HAS_MPI_ENABLED` is undefined, returns the provided @p value wrapped in a `std::vector`. + * @tparam T the type of the values to gather + * @param value the value to gather at the main MPI rank + * @return a `std::vector` containing all gathered values (`[[nodiscard]]`) + */ + template + [[nodiscard]] std::vector gather(T value) { +#if defined(PLSSVM_HAS_MPI_ENABLED) + std::vector result(this->size()); + PLSSVM_MPI_ERROR_CHECK(MPI_Gather(&value, 1, detail::mpi_datatype(), result.data(), 1, detail::mpi_datatype(), communicator::main_rank(), comm_)); + return result; +#else + return { value }; +#endif + } + #if defined(PLSSVM_HAS_MPI_ENABLED) /** * @brief Add implicit conversion operator back to a native MPI communicator.