Skip to content

Commit

Permalink
Add new gather function for std::chrono::milliseconds.
Browse files Browse the repository at this point in the history
  • Loading branch information
breyerml committed Dec 12, 2024
1 parent 4d10527 commit 1eaccb8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
9 changes: 9 additions & 0 deletions include/plssvm/mpi/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mpi.h" // MPI_Comm, MPI_COMM_WORLD, MPI_Gather
#endif

#include <chrono> // std::chrono::milliseconds
#include <cstddef> // std::size_t
#include <functional> // std::invoke
#include <string> // std::string
Expand Down Expand Up @@ -126,6 +127,14 @@ class communicator {
*/
[[nodiscard]] std::vector<std::string> gather(const std::string &str) const;

/**
* @brief Gather the `std::chrono::milliseconds` @p duration from each MPI rank on the `communicator::main_rank()`.
* @details If `PLSSVM_HAS_MPI_ENABLED` is undefined, returns the provided @p duration wrapped in a `std::vector`.
* @param[in] duration the duration to gather at the main MPI rank
* @return a `std::vector` containing all gathered durations (`[[nodiscard]]`)
*/
[[nodiscard]] std::vector<std::chrono::milliseconds> gather(const std::chrono::milliseconds &duration) const;

#if defined(PLSSVM_HAS_MPI_ENABLED)
/**
* @brief Add implicit conversion operator back to a native MPI communicator.
Expand Down
25 changes: 22 additions & 3 deletions src/plssvm/mpi/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
#include "mpi.h"
#endif

#include <cstddef> // std::size_t
#include <string> // std::string
#include <vector> // std::vector
#include <algorithm> // std::transform
#include <chrono> // std::chrono::milliseconds
#include <cstddef> // std::size_t
#include <cstdint> // std::int64_t
#include <string> // std::string
#include <vector> // std::vector

namespace plssvm::mpi {

Expand Down Expand Up @@ -95,4 +98,20 @@ std::vector<std::string> communicator::gather(const std::string &str) const {
#endif
}

std::vector<std::chrono::milliseconds> communicator::gather(const std::chrono::milliseconds &duration) const {
#if defined(PLSSVM_HAS_MPI_ENABLED)
// convert the duration to an integer
const std::int64_t intermediate_dur = duration.count();
std::vector<std::int64_t> intermediate_result(this->size());
// gather the integer values from each MPI rank
PLSSVM_MPI_ERROR_CHECK(MPI_Gather(&intermediate_dur, 1, detail::mpi_datatype<std::int64_t>(), intermediate_result.data(), 1, detail::mpi_datatype<std::int64_t>(), communicator::main_rank(), comm_));
// cast integers back to durations
std::vector<std::chrono::milliseconds> result(this->size());
std::transform(intermediate_result.cbegin(), intermediate_result.cend(), result.begin(), [](const std::int64_t dur) { return static_cast<std::chrono::milliseconds>(dur); });
return result;
#else
return { duration };
#endif
}

} // namespace plssvm::mpi

0 comments on commit 1eaccb8

Please sign in to comment.