Skip to content

Commit

Permalink
Refactor helper code for enum types
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck committed Jan 19, 2024
1 parent 01989c0 commit baab930
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 32 deletions.
6 changes: 3 additions & 3 deletions include/analyze/all_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <unordered_set>

#include <analyze/find_stmt.h>
#include <container_utils.h>

namespace freetensor {

Expand All @@ -12,9 +13,8 @@ namespace freetensor {
*/
inline std::vector<std::pair<ID, std::string>>
allDefs(const Stmt &op,
const std::unordered_set<AccessType> &atypes = {
AccessType::Input, AccessType::Bypass, AccessType::Output,
AccessType::InOut, AccessType::InputMutable, AccessType::Cache}) {
const std::unordered_set<AccessType> &atypes =
allAccessTypes | ranges::to<std::unordered_set>() {
std::vector<std::pair<ID, std::string>> ret;
for (auto &&node : findAllStmt(op, [&](const Stmt &s) {
return s->nodeType() == ASTNodeType::VarDef &&
Expand Down
12 changes: 5 additions & 7 deletions include/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include <ast.h>
#include <buffer.h>
#include <container_utils.h>
#include <except.h>
#include <for_property.h>
#include <reduce_op.h>

Expand Down Expand Up @@ -485,13 +487,9 @@ inline MatMulBackend parseMatMulBackend(const std::string &_str) {
return (MatMulBackend)i;
}
}
std::string msg = "Unrecognized MatMul backend \"" + _str +
"\". Candidates are (case-insensitive): ";
for (auto &&[i, s] : views::enumerate(matMulBackendNames)) {
msg += (i > 0 ? ", " : "");
msg += s;
}
ERROR(msg);
ERROR(FT_MSG << "Unrecognized MatMul backend \"" << _str
<< "\". Candidates are (case-insensitive): "
<< (matMulBackendNames | join(", ")));
}

/**
Expand Down
23 changes: 16 additions & 7 deletions include/type/access_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <string>

#include <container_utils.h>
#include <except.h>
#include <serialize/to_string.h>

namespace freetensor {
Expand Down Expand Up @@ -55,6 +56,18 @@ constexpr std::array accessTypeNames = {
};
static_assert(accessTypeNames.size() == (size_t)AccessType::NumTypes);

namespace detail {

template <typename T, T... i>
constexpr auto createAllAccessTypes(std::integer_sequence<T, i...>) {
return std::array{(AccessType)i...};
}

} // namespace detail

constexpr auto allAccessTypes = detail::createAllAccessTypes(
std::make_index_sequence<(size_t)AccessType::NumTypes>{});

inline std::ostream &operator<<(std::ostream &os, AccessType atype) {
return os << accessTypeNames.at((size_t)atype);
}
Expand All @@ -66,13 +79,9 @@ inline AccessType parseAType(const std::string &_str) {
return (AccessType)i;
}
}
std::string msg = "Unrecognized access type \"" + _str +
"\". Candidates are (case-insensitive): ";
for (auto &&[i, s] : views::enumerate(accessTypeNames)) {
msg += (i > 0 ? ", " : "");
msg += s;
}
ERROR(msg);
ERROR(FT_MSG << "Unrecognized access type \"" << _str
<< "\". Candidates are (case-insensitive): "
<< (accessTypeNames | join(", ")));
}

inline bool isWritable(AccessType atype) {
Expand Down
34 changes: 27 additions & 7 deletions include/type/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@ constexpr std::array baseDataTypeNames = {
};
static_assert(baseDataTypeNames.size() == (size_t)BaseDataType::NumTypes);

namespace detail {

template <typename T, T... i>
constexpr auto createAllBaseDataTypes(std::integer_sequence<T, i...>) {
return std::array{(BaseDataType)i...};
}

} // namespace detail

constexpr auto allBaseDataTypes = detail::createAllBaseDataTypes(
std::make_index_sequence<(size_t)BaseDataType::NumTypes>{});

inline std::ostream &operator<<(std::ostream &os, BaseDataType dtype) {
return os << baseDataTypeNames.at((size_t)dtype);
}
Expand All @@ -41,13 +53,9 @@ inline BaseDataType parseBaseDataType(const std::string &_str) {
return (BaseDataType)i;
}
}
std::string msg = "Unrecognized base data type \"" + _str +
"\". Candidates are (case-insensitive): ";
for (auto &&[i, s] : views::enumerate(baseDataTypeNames)) {
msg += (i > 0 ? ", " : "");
msg += s;
}
ERROR(msg);
ERROR(FT_MSG << "Unrecognized base data type \"" << _str
<< "\". Candidates are (case-insensitive): "
<< (baseDataTypeNames | join(", ")));
}

enum class SignDataType : size_t {
Expand All @@ -69,6 +77,18 @@ constexpr std::array signDataTypeNames = {
};
static_assert(signDataTypeNames.size() == (size_t)SignDataType::NumTypes);

namespace detail {

template <typename T, T... i>
constexpr auto createAllSignDataTypes(std::integer_sequence<T, i...>) {
return std::array{(SignDataType)i...};
}

} // namespace detail

constexpr auto allSignDataTypes = detail::createAllSignDataTypes(
std::make_index_sequence<(size_t)SignDataType::NumTypes>{});

inline std::ostream &operator<<(std::ostream &os, SignDataType dtype) {
return os << signDataTypeNames.at((size_t)dtype);
}
Expand Down
23 changes: 16 additions & 7 deletions include/type/mem_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <string>

#include <container_utils.h>
#include <except.h>
#include <serialize/to_string.h>

namespace freetensor {
Expand All @@ -31,6 +32,18 @@ constexpr std::array memTypeNames = {
};
static_assert(memTypeNames.size() == (size_t)MemType::NumTypes);

namespace detail {

template <typename T, T... i>
constexpr auto createAllMemTypes(std::integer_sequence<T, i...>) {
return std::array{(MemType)i...};
}

} // namespace detail

constexpr auto allMemTypes = detail::createAllMemTypes(
std::make_index_sequence<(size_t)MemType::NumTypes>{});

inline std::ostream &operator<<(std::ostream &os, MemType mtype) {
return os << memTypeNames.at((size_t)mtype);
}
Expand All @@ -42,13 +55,9 @@ inline MemType parseMType(const std::string &_str) {
return (MemType)i;
}
}
std::string msg = "Unrecognized memory type \"" + _str +
"\". Candidates are (case-insensitive): ";
for (auto &&[i, s] : views::enumerate(memTypeNames)) {
msg += (i > 0 ? ", " : "");
msg += s;
}
ERROR(msg);
ERROR(FT_MSG << "Unrecognized memory type \"" << _str
<< "\". Candidates are (case-insensitive): "
<< (memTypeNames | join(", ")));
}

} // namespace freetensor
Expand Down
5 changes: 5 additions & 0 deletions src/math/parse_pb_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,11 @@ PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set) {
// implemented by a statement in multiple branches. We can recover Expr from
// the statement and the branches' conditions.

if (set.empty()) {
// It will result in empty block node in isl, which we cannot parse
return {};
}

ASSERT(set.isSingleValued());

std::vector<std::string> params =
Expand Down
2 changes: 1 addition & 1 deletion test/40.codegen/gpu/test_gpu_normalize_buffer_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def test(x, y, z):
t[i, j] = x[b, i, j] * 2
for j in range(10):
t[i, j] += t[i, i]
# The last dimension can be removed although accessed with i
# The last dimension can not be removed although accessed with i
#! label: L1
for i in range(10):
for j in range(10):
Expand Down

0 comments on commit baab930

Please sign in to comment.