Skip to content

Commit

Permalink
Improve check functions (#125)
Browse files Browse the repository at this point in the history
* Add a trait to extract the base value type from a container

* extend check functions to work on std::array or std::vector

* Avoid throw from a check function

* remove uunnecessary check from are_valid_axes and add overlaped axes tests

* rename: key to value for find function

* fix: remove a variable rank which is no longer used

---------

Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Aug 2, 2024
1 parent 0b1ab3d commit e3454ae
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 36 deletions.
25 changes: 25 additions & 0 deletions common/src/KokkosFFT_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,31 @@ inline constexpr bool are_operatable_views_v =

// Other traits

template <typename ContainerType>
struct base_container_value;

template <template <typename, typename...> class ContainerType,
typename ValueType, typename... Args>
struct base_container_value<ContainerType<ValueType, Args...>> {
using value_type = ValueType;
};

// Specialization for std::array
template <typename ValueType, std::size_t N>
struct base_container_value<std::array<ValueType, N>> {
using value_type = ValueType;
};

// Specialization for Kokkos::Array
template <typename ValueType, std::size_t N>
struct base_container_value<Kokkos::Array<ValueType, N>> {
using value_type = ValueType;
};

/// \brief Helper to extract the base value type from a container
template <typename T>
using base_container_value_type = typename base_container_value<T>::value_type;

/// \brief Helper to define a managable View type from the original view type
template <typename T>
struct managable_view_type {
Expand Down
72 changes: 58 additions & 14 deletions common/src/KokkosFFT_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <algorithm>
#include <numeric>
#include "KokkosFFT_traits.hpp"
#include "KokkosFFT_common_types.hpp"

#if defined(KOKKOS_ENABLE_CXX17)
#include <cstdlib>
Expand Down Expand Up @@ -85,43 +86,86 @@ auto convert_negative_shift(const ViewType& view, int _shift, int _axis) {
return std::tuple<int, int, int>({shift0, shift1, shift2});
}

template <typename T>
bool is_found(std::vector<T>& values, const T& key) {
return std::find(values.begin(), values.end(), key) != values.end();
template <typename ContainerType, typename ValueType>
bool is_found(ContainerType& values, const ValueType& value) {
using value_type = KokkosFFT::Impl::base_container_value_type<ContainerType>;
static_assert(std::is_same_v<value_type, ValueType>,
"Container value type must match ValueType");
return std::find(values.begin(), values.end(), value) != values.end();
}

template <typename T>
bool has_duplicate_values(const std::vector<T>& values) {
std::set<T> set_values(values.begin(), values.end());
template <typename ContainerType>
bool has_duplicate_values(const ContainerType& values) {
using value_type = KokkosFFT::Impl::base_container_value_type<ContainerType>;
std::set<value_type> set_values(values.begin(), values.end());
return set_values.size() < values.size();
}

template <typename IntType, std::enable_if_t<std::is_integral_v<IntType>,
std::nullptr_t> = nullptr>
bool is_out_of_range_value_included(const std::vector<IntType>& values,
IntType max) {
template <
typename ContainerType, typename IntType,
std::enable_if_t<std::is_integral_v<IntType>, std::nullptr_t> = nullptr>
bool is_out_of_range_value_included(const ContainerType& values, IntType max) {
using value_type = KokkosFFT::Impl::base_container_value_type<ContainerType>;
static_assert(std::is_same_v<value_type, IntType>,
"Container value type must match IntType");
bool is_included = false;
for (auto value : values) {
is_included = value >= max;
}
return is_included;
}

template <
typename ViewType, template <typename, std::size_t> class ArrayType,
typename IntType, std::size_t DIM = 1,
std::enable_if_t<std::is_integral_v<IntType>, std::nullptr_t> = nullptr>
bool are_valid_axes(const ViewType& view, const ArrayType<IntType, DIM>& axes) {
static_assert(
DIM >= 1 && DIM <= KokkosFFT::MAX_FFT_DIM,
"are_valid_axes: the Rank of FFT axes must be between 1 and MAX_FFT_DIM");
static_assert(ViewType::rank() >= DIM,
"are_valid_axes: View rank must be larger than or equal to the "
"Rank of FFT axes");

// Convert the input axes to be in the range of [0, rank-1]
// int type is choosen for consistency with the rest of the code
// the axes are defined with int type
std::array<int, DIM> non_negative_axes;

// In case axis is out of range, 'convert_negative_axis' will throw an
// runtime_error and we will return false. Without runtime_error, it is
// ensured that the 'non_negative_axes' are in the range of [0, rank-1]
try {
for (std::size_t i = 0; i < DIM; i++) {
int axis = KokkosFFT::Impl::convert_negative_axis(view, axes[i]);
non_negative_axes[i] = axis;
}
} catch (std::runtime_error& e) {
return false;
}

bool is_valid = !KokkosFFT::Impl::has_duplicate_values(non_negative_axes);
return is_valid;
}

template <std::size_t DIM = 1>
bool is_transpose_needed(std::array<int, DIM> map) {
std::array<int, DIM> contiguous_map;
std::iota(contiguous_map.begin(), contiguous_map.end(), 0);
return map != contiguous_map;
}

template <typename T>
std::size_t get_index(std::vector<T>& values, const T& key) {
auto it = find(values.begin(), values.end(), key);
template <typename ContainerType, typename ValueType>
std::size_t get_index(ContainerType& values, const ValueType& value) {
using value_type = KokkosFFT::Impl::base_container_value_type<ContainerType>;
static_assert(std::is_same_v<value_type, ValueType>,
"Container value type must match ValueType");
auto it = std::find(values.begin(), values.end(), value);
std::size_t index = 0;
if (it != values.end()) {
index = it - values.begin();
} else {
throw std::runtime_error("key is not included in values");
throw std::runtime_error("value is not included in values");
}

return index;
Expand Down
49 changes: 49 additions & 0 deletions common/unit_test/Test_Traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
// All the tests in this file are compile time tests, so we skip all the tests
// by GTEST_SKIP(). gtest is used for type parameterization.

// Int like types
using base_int_types = ::testing::Types<int, std::size_t>;

// Define the types to combine
using base_real_types = std::tuple<float, double, long double>;

Expand Down Expand Up @@ -39,6 +42,19 @@ using paired_view_types =
tuple_to_types_t<cartesian_product_t<base_real_types, base_layout_types,
base_real_types, base_layout_types>>;

template <typename T>
struct ContainerTypes : public ::testing::Test {
static constexpr std::size_t rank = 3;
using value_type = T;
using vector_type = std::vector<T>;
using std_array_type = std::array<T, rank>;
using Kokkos_array_type = Kokkos::Array<T, rank>;

virtual void SetUp() {
GTEST_SKIP() << "Skipping all tests for this fixture";
}
};

template <typename T>
struct RealAndComplexTypes : public ::testing::Test {
using real_type = T;
Expand Down Expand Up @@ -91,12 +107,45 @@ struct PairedViewTypes : public ::testing::Test {
}
};

TYPED_TEST_SUITE(ContainerTypes, base_int_types);
TYPED_TEST_SUITE(RealAndComplexTypes, real_types);
TYPED_TEST_SUITE(RealAndComplexViewTypes, view_types);
TYPED_TEST_SUITE(PairedValueTypes, paired_value_types);
TYPED_TEST_SUITE(PairedLayoutTypes, paired_layout_types);
TYPED_TEST_SUITE(PairedViewTypes, paired_view_types);

// Tests for base value type deduction
template <typename ValueType, typename ContainerType>
void test_get_container_value_type() {
using value_type_ContainerType =
KokkosFFT::Impl::base_container_value_type<ContainerType>;

// base value type of ContainerType is ValueType
static_assert(std::is_same_v<value_type_ContainerType, ValueType>,
"Value type not deduced correctly from ContainerType");
}

TYPED_TEST(ContainerTypes, get_value_type_from_vector) {
using value_type = typename TestFixture::value_type;
using container_type = typename TestFixture::vector_type;

test_get_container_value_type<value_type, container_type>();
}

TYPED_TEST(ContainerTypes, get_value_type_from_std_array) {
using value_type = typename TestFixture::value_type;
using container_type = typename TestFixture::std_array_type;

test_get_container_value_type<value_type, container_type>();
}

TYPED_TEST(ContainerTypes, get_value_type_from_kokkos_array) {
using value_type = typename TestFixture::value_type;
using container_type = typename TestFixture::Kokkos_array_type;

test_get_container_value_type<value_type, container_type>();
}

// Tests for real type deduction
template <typename RealType, typename ComplexType>
void test_get_real_type() {
Expand Down
Loading

0 comments on commit e3454ae

Please sign in to comment.