From baab930119c156cba350e85de8bfbfe2056fc5f1 Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Fri, 19 Jan 2024 17:21:36 +0800 Subject: [PATCH] Refactor helper code for enum types --- include/analyze/all_defs.h | 6 ++-- include/stmt.h | 12 +++---- include/type/access_type.h | 23 +++++++++---- include/type/data_type.h | 34 +++++++++++++++---- include/type/mem_type.h | 23 +++++++++---- src/math/parse_pb_expr.cc | 5 +++ .../gpu/test_gpu_normalize_buffer_shapes.py | 2 +- 7 files changed, 73 insertions(+), 32 deletions(-) diff --git a/include/analyze/all_defs.h b/include/analyze/all_defs.h index 3b5450c0f..22c4cd8d1 100644 --- a/include/analyze/all_defs.h +++ b/include/analyze/all_defs.h @@ -4,6 +4,7 @@ #include #include +#include namespace freetensor { @@ -12,9 +13,8 @@ namespace freetensor { */ inline std::vector> allDefs(const Stmt &op, - const std::unordered_set &atypes = { - AccessType::Input, AccessType::Bypass, AccessType::Output, - AccessType::InOut, AccessType::InputMutable, AccessType::Cache}) { + const std::unordered_set &atypes = + allAccessTypes | ranges::to() { std::vector> ret; for (auto &&node : findAllStmt(op, [&](const Stmt &s) { return s->nodeType() == ASTNodeType::VarDef && diff --git a/include/stmt.h b/include/stmt.h index 55722cda3..f320f8e0b 100644 --- a/include/stmt.h +++ b/include/stmt.h @@ -6,6 +6,8 @@ #include #include +#include +#include #include #include @@ -485,13 +487,9 @@ inline MatMulBackend parseMatMulBackend(const std::string &_str) { return (MatMulBackend)i; } } - std::string msg = "Unrecognized MatMul backend \"" + _str + - "\". Candidates are (case-insensitive): "; - for (auto &&[i, s] : views::enumerate(matMulBackendNames)) { - msg += (i > 0 ? ", " : ""); - msg += s; - } - ERROR(msg); + ERROR(FT_MSG << "Unrecognized MatMul backend \"" << _str + << "\". Candidates are (case-insensitive): " + << (matMulBackendNames | join(", "))); } /** diff --git a/include/type/access_type.h b/include/type/access_type.h index ebb6a9b76..6ed2035d2 100644 --- a/include/type/access_type.h +++ b/include/type/access_type.h @@ -6,6 +6,7 @@ #include #include +#include #include namespace freetensor { @@ -55,6 +56,18 @@ constexpr std::array accessTypeNames = { }; static_assert(accessTypeNames.size() == (size_t)AccessType::NumTypes); +namespace detail { + +template +constexpr auto createAllAccessTypes(std::integer_sequence) { + return std::array{(AccessType)i...}; +} + +} // namespace detail + +constexpr auto allAccessTypes = detail::createAllAccessTypes( + std::make_index_sequence<(size_t)AccessType::NumTypes>{}); + inline std::ostream &operator<<(std::ostream &os, AccessType atype) { return os << accessTypeNames.at((size_t)atype); } @@ -66,13 +79,9 @@ inline AccessType parseAType(const std::string &_str) { return (AccessType)i; } } - std::string msg = "Unrecognized access type \"" + _str + - "\". Candidates are (case-insensitive): "; - for (auto &&[i, s] : views::enumerate(accessTypeNames)) { - msg += (i > 0 ? ", " : ""); - msg += s; - } - ERROR(msg); + ERROR(FT_MSG << "Unrecognized access type \"" << _str + << "\". Candidates are (case-insensitive): " + << (accessTypeNames | join(", "))); } inline bool isWritable(AccessType atype) { diff --git a/include/type/data_type.h b/include/type/data_type.h index fdccb8304..c23a38ef6 100644 --- a/include/type/data_type.h +++ b/include/type/data_type.h @@ -30,6 +30,18 @@ constexpr std::array baseDataTypeNames = { }; static_assert(baseDataTypeNames.size() == (size_t)BaseDataType::NumTypes); +namespace detail { + +template +constexpr auto createAllBaseDataTypes(std::integer_sequence) { + return std::array{(BaseDataType)i...}; +} + +} // namespace detail + +constexpr auto allBaseDataTypes = detail::createAllBaseDataTypes( + std::make_index_sequence<(size_t)BaseDataType::NumTypes>{}); + inline std::ostream &operator<<(std::ostream &os, BaseDataType dtype) { return os << baseDataTypeNames.at((size_t)dtype); } @@ -41,13 +53,9 @@ inline BaseDataType parseBaseDataType(const std::string &_str) { return (BaseDataType)i; } } - std::string msg = "Unrecognized base data type \"" + _str + - "\". Candidates are (case-insensitive): "; - for (auto &&[i, s] : views::enumerate(baseDataTypeNames)) { - msg += (i > 0 ? ", " : ""); - msg += s; - } - ERROR(msg); + ERROR(FT_MSG << "Unrecognized base data type \"" << _str + << "\". Candidates are (case-insensitive): " + << (baseDataTypeNames | join(", "))); } enum class SignDataType : size_t { @@ -69,6 +77,18 @@ constexpr std::array signDataTypeNames = { }; static_assert(signDataTypeNames.size() == (size_t)SignDataType::NumTypes); +namespace detail { + +template +constexpr auto createAllSignDataTypes(std::integer_sequence) { + return std::array{(SignDataType)i...}; +} + +} // namespace detail + +constexpr auto allSignDataTypes = detail::createAllSignDataTypes( + std::make_index_sequence<(size_t)SignDataType::NumTypes>{}); + inline std::ostream &operator<<(std::ostream &os, SignDataType dtype) { return os << signDataTypeNames.at((size_t)dtype); } diff --git a/include/type/mem_type.h b/include/type/mem_type.h index 30e490061..e538661b6 100644 --- a/include/type/mem_type.h +++ b/include/type/mem_type.h @@ -6,6 +6,7 @@ #include #include +#include #include namespace freetensor { @@ -31,6 +32,18 @@ constexpr std::array memTypeNames = { }; static_assert(memTypeNames.size() == (size_t)MemType::NumTypes); +namespace detail { + +template +constexpr auto createAllMemTypes(std::integer_sequence) { + return std::array{(MemType)i...}; +} + +} // namespace detail + +constexpr auto allMemTypes = detail::createAllMemTypes( + std::make_index_sequence<(size_t)MemType::NumTypes>{}); + inline std::ostream &operator<<(std::ostream &os, MemType mtype) { return os << memTypeNames.at((size_t)mtype); } @@ -42,13 +55,9 @@ inline MemType parseMType(const std::string &_str) { return (MemType)i; } } - std::string msg = "Unrecognized memory type \"" + _str + - "\". Candidates are (case-insensitive): "; - for (auto &&[i, s] : views::enumerate(memTypeNames)) { - msg += (i > 0 ? ", " : ""); - msg += s; - } - ERROR(msg); + ERROR(FT_MSG << "Unrecognized memory type \"" << _str + << "\". Candidates are (case-insensitive): " + << (memTypeNames | join(", "))); } } // namespace freetensor diff --git a/src/math/parse_pb_expr.cc b/src/math/parse_pb_expr.cc index e06a574a2..9fba14c91 100644 --- a/src/math/parse_pb_expr.cc +++ b/src/math/parse_pb_expr.cc @@ -310,6 +310,11 @@ PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set) { // implemented by a statement in multiple branches. We can recover Expr from // the statement and the branches' conditions. + if (set.empty()) { + // It will result in empty block node in isl, which we cannot parse + return {}; + } + ASSERT(set.isSingleValued()); std::vector params = diff --git a/test/40.codegen/gpu/test_gpu_normalize_buffer_shapes.py b/test/40.codegen/gpu/test_gpu_normalize_buffer_shapes.py index e08a18020..3f32d1321 100644 --- a/test/40.codegen/gpu/test_gpu_normalize_buffer_shapes.py +++ b/test/40.codegen/gpu/test_gpu_normalize_buffer_shapes.py @@ -429,7 +429,7 @@ def test(x, y, z): t[i, j] = x[b, i, j] * 2 for j in range(10): t[i, j] += t[i, i] - # The last dimension can be removed although accessed with i + # The last dimension can not be removed although accessed with i #! label: L1 for i in range(10): for j in range(10):