Skip to content

Commit

Permalink
Make inter-PBCtx transfer type-safe
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck committed Apr 20, 2024
1 parent 4772cd5 commit d79c212
Show file tree
Hide file tree
Showing 12 changed files with 141 additions and 57 deletions.
12 changes: 8 additions & 4 deletions include/math/parse_pb_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ typedef std::vector<SimplePBFuncAST> PBFuncAST;

/**
* Parse a PBFunc to be ASTs
*
* @{
*/
PBFuncAST parsePBFunc(const std::string &str);
PBFuncAST parsePBFunc(const PBFunc::Serialized &f);
PBFuncAST parsePBFunc(const PBSingleFunc::Serialized &f);
/** @} */

/**
* Construct AST from PBSet while preserving min and max with a special hack to
Expand All @@ -44,10 +48,10 @@ PBFuncAST parsePBFuncReconstructMinMax(const PBMap &map);
*
* @{
*/
inline SimplePBFuncAST parseSimplePBFunc(const std::string &str) {
auto ret = parsePBFunc(str);
inline SimplePBFuncAST parseSimplePBFunc(const auto &f) {
auto ret = parsePBFunc(f);
if (ret.size() != 1) {
throw ParserError(str + " is not a simple PBFunc");
throw ParserError(FT_MSG << f << " is not a simple PBFunc");
}
return ret.front();
}
Expand Down
90 changes: 84 additions & 6 deletions include/math/presburger.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ template <class T> T *MOVE_ISL_PTR(T *&ptr) {
return ret;
}

/**
* Context for presburger operation
*
* - All operands of a presburger operation should be on the same context.
* - Operations in the same context is NOT thread-safe. Explitly transfer to a
* different context if you want to use in multiple threads.
*/
class PBCtx {
isl_ctx *ctx_ = nullptr;

Expand Down Expand Up @@ -113,9 +120,21 @@ class PBMap {
isl_map *copy() const { return COPY_ISL_PTR(map_, map); }
isl_map *move() { return MOVE_ISL_PTR(map_); }

PBMap to(const Ref<PBCtx> &ctx) const {
return {ctx, isl_map_to_str(get())};
}
class Serialized {
std::string data_;

public:
Serialized() {}
Serialized(const std::string &data) : data_(data) {}
PBMap to(const Ref<PBCtx> &ctx) const { return {ctx, data_}; }
bool isValid() const { return !data_.empty(); }
const auto &data() const { return data_; }
friend std::ostream &operator<<(std::ostream &os, const Serialized &s) {
return os << s.data_;
}
};
Serialized toSerialized() const { return {isl_map_to_str(get())}; }
PBMap to(const Ref<PBCtx> &ctx) const { return toSerialized().to(ctx); }

bool empty() const {
DEBUG_PROFILE("empty");
Expand Down Expand Up @@ -249,9 +268,21 @@ class PBSet {
isl_set *copy() const { return COPY_ISL_PTR(set_, set); }
isl_set *move() { return MOVE_ISL_PTR(set_); }

PBSet to(const Ref<PBCtx> &ctx) const {
return {ctx, isl_set_to_str(get())};
}
class Serialized {
std::string data_;

public:
Serialized() {}
Serialized(const std::string &data) : data_(data) {}
PBSet to(const Ref<PBCtx> &ctx) const { return {ctx, data_}; }
bool isValid() const { return !data_.empty(); }
const auto &data() const { return data_; }
friend std::ostream &operator<<(std::ostream &os, const Serialized &s) {
return os << s.data_;
}
};
Serialized toSerialized() const { return {isl_set_to_str(get())}; }
PBSet to(const Ref<PBCtx> &ctx) const { return toSerialized().to(ctx); }

bool empty() const {
DEBUG_PROFILE("empty");
Expand Down Expand Up @@ -351,6 +382,12 @@ class PBSingleFunc {
PBSingleFunc() {}
PBSingleFunc(const Ref<PBCtx> &ctx, isl_pw_aff *func)
: ctx_(ctx), func_(func) {}
PBSingleFunc(const Ref<PBCtx> &ctx, const std::string &str)
: ctx_(ctx), func_(isl_pw_aff_read_from_str(ctx->get(), str.c_str())) {
if (func_ == nullptr) {
ERROR("Unable to construct an PBSingleFunc from " + str);
}
}
explicit PBSingleFunc(const Ref<PBCtx> &ctx, isl_aff *func)
: ctx_(ctx), func_(isl_pw_aff_from_aff(func)) {}

Expand Down Expand Up @@ -391,6 +428,24 @@ class PBSingleFunc {
isl_pw_aff *copy() const { return COPY_ISL_PTR(func_, pw_aff); }
isl_pw_aff *move() { return MOVE_ISL_PTR(func_); }

class Serialized {
std::string data_;

public:
Serialized() {}
Serialized(const std::string &data) : data_(data) {}
PBSingleFunc to(const Ref<PBCtx> &ctx) const { return {ctx, data_}; }
bool isValid() const { return !data_.empty(); }
const auto &data() const { return data_; }
friend std::ostream &operator<<(std::ostream &os, const Serialized &s) {
return os << s.data_;
}
};
Serialized toSerialized() const { return {isl_pw_aff_to_str(get())}; }
PBSingleFunc to(const Ref<PBCtx> &ctx) const {
return toSerialized().to(ctx);
}

isl_size nInDims() const { return isl_pw_aff_dim(get(), isl_dim_in); }

std::vector<std::pair<PBSet, PBSingleFunc>> pieces() const {
Expand Down Expand Up @@ -424,6 +479,13 @@ class PBFunc {
PBFunc() {}
PBFunc(const Ref<PBCtx> &ctx, isl_pw_multi_aff *func)
: ctx_(ctx), func_(func) {}
PBFunc(const Ref<PBCtx> &ctx, const std::string &str)
: ctx_(ctx),
func_(isl_pw_multi_aff_read_from_str(ctx->get(), str.c_str())) {
if (func_ == nullptr) {
ERROR("Unable to construct an PBFunc from " + str);
}
}

PBFunc(const PBSingleFunc &singleFunc)
: ctx_(singleFunc.ctx()),
Expand Down Expand Up @@ -479,6 +541,22 @@ class PBFunc {
isl_pw_multi_aff *copy() const { return COPY_ISL_PTR(func_, pw_multi_aff); }
isl_pw_multi_aff *move() { return MOVE_ISL_PTR(func_); }

class Serialized {
std::string data_;

public:
Serialized() {}
Serialized(const std::string &data) : data_(data) {}
PBFunc to(const Ref<PBCtx> &ctx) const { return {ctx, data_}; }
bool isValid() const { return !data_.empty(); }
const auto &data() const { return data_; }
friend std::ostream &operator<<(std::ostream &os, const Serialized &s) {
return os << s.data_;
}
};
Serialized toSerialized() const { return {isl_pw_multi_aff_to_str(get())}; }
PBFunc to(const Ref<PBCtx> &ctx) const { return toSerialized().to(ctx); }

isl_size nInDims() const { return isl_pw_multi_aff_dim(get(), isl_dim_in); }
isl_size nOutDims() const {
return isl_pw_multi_aff_dim(get(), isl_dim_out);
Expand Down
12 changes: 4 additions & 8 deletions src/analyze/check_not_modified.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,13 @@ bool checkNotModified(const Stmt &op, const Expr &s0Expr, const Expr &s1Expr,
return ret;
};

// write -> serialized PBSet
std::unordered_map<Stmt, std::string> writesWAR;
std::unordered_map<Stmt /* write */, PBSet::Serialized> writesWAR;
std::mutex m;
auto foundWAR = [&](const Dependence &dep) {
// Serialize WAR map because it is from a random PBCtx
auto strWAR =
toString(apply(domain(dep.later2EarlierIter_), dep.laterIter2Idx_));
auto strWAR = apply(domain(dep.later2EarlierIter_), dep.laterIter2Idx_);
// only lock for writing the map
std::lock_guard l(m);
writesWAR[dep.later_.stmt_] = strWAR;
writesWAR[dep.later_.stmt_] = strWAR.toSerialized();
};
FindDeps()
.direction({dir})
Expand All @@ -200,9 +197,8 @@ bool checkNotModified(const Stmt &op, const Expr &s0Expr, const Expr &s1Expr,
.noProjectOutPrivateAxis(true)(tmpOp, unsyncFunc(foundWAR));

auto foundRAW = [&](const Dependence &dep) {
// re-construct WAR map from stored string in current PBCtx
auto w0 =
PBSet(dep.later2EarlierIter_.ctx(), writesWAR[dep.earlier_.stmt_]);
writesWAR[dep.earlier_.stmt_].to(dep.later2EarlierIter_.ctx());
auto w1 = apply(range(dep.later2EarlierIter_), dep.earlierIter2Idx_);
if (!intersect(std::move(w0), std::move(w1)).empty())
throw ModifiedException{};
Expand Down
2 changes: 1 addition & 1 deletion src/autograd/invert_stmts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ void genCondExpr(const Ref<PBCtx> &presburger, CondInfo *info) {
PBMap indicator = intersectDomain(
anythingTo1(presburger, info->when_.nDims()), info->when_);
for (auto &&[args, _, factorRange] :
parsePBFunc(toString(PBFunc(indicator)))) {
parsePBFunc(PBFunc(indicator).toSerialized())) {
if (!allReads(factorRange).empty()) {
throw ParserError("External variable in recomputing condition "
"is not yet supported");
Expand Down
13 changes: 10 additions & 3 deletions src/math/parse_pb_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,7 @@ class RecoverBoolVars : public Mutator {
}
};

} // Anonymous namespace

PBFuncAST parsePBFunc(const std::string &str) {
PBFuncAST parsePBFuncImpl(const std::string &str) {
try {
antlr4::ANTLRInputStream charStream(str);
pb_lexer lexer(&charStream);
Expand All @@ -127,6 +125,15 @@ PBFuncAST parsePBFunc(const std::string &str) {
}
}

} // Anonymous namespace

PBFuncAST parsePBFunc(const PBFunc::Serialized &f) {
return parsePBFuncImpl(f.data());
}
PBFuncAST parsePBFunc(const PBSingleFunc::Serialized &f) {
return parsePBFuncImpl(f.data());
}

namespace {

Expr isl2Expr(__isl_take isl_ast_expr *e) {
Expand Down
16 changes: 8 additions & 8 deletions src/pass/prop_one_time_use.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace {

struct ReplaceInfo {
std::vector<IterAxis> earlierIters_, laterIters_;
std::string funcStr_;
PBFunc::Serialized func_;
};

std::vector<std::pair<AST, std::pair<Stmt, ReplaceInfo>>>
Expand Down Expand Up @@ -69,8 +69,8 @@ Stmt propOneTimeUse(const Stmt &_op, const ID &subAST) {
r2wCandidates;
std::unordered_map<AST, std::vector<Stmt>> r2wMay;
std::unordered_set<Stmt> wCandidates;
std::unordered_map<Stmt,
std::vector<std::pair<AST, std::string /* writeIter */>>>
std::unordered_map<
Stmt, std::vector<std::pair<AST, PBSet::Serialized /* writeIter */>>>
w2rMay;
std::unordered_map<AST, Stmt> stmts;
std::mutex lock;
Expand Down Expand Up @@ -124,7 +124,7 @@ Stmt propOneTimeUse(const Stmt &_op, const ID &subAST) {
// which may propagate
d.earlier().as<StmtNode>(),
ReplaceInfo{d.earlier_.iter_, d.later_.iter_,
toString(*f)});
f->toSerialized()});
wCandidates.emplace(d.earlier().as<StmtNode>());
stmts[d.later()] = d.later_.stmt_;
}
Expand Down Expand Up @@ -153,7 +153,7 @@ Stmt propOneTimeUse(const Stmt &_op, const ID &subAST) {
writeIter.nDims() - d.earlier_.iter_.size());
r2wMay[d.later()].emplace_back(d.earlier().as<StmtNode>());
w2rMay[d.earlier().as<StmtNode>()].emplace_back(
d.later(), toString(writeIter));
d.later(), writeIter.toSerialized());
});

// Filter single-valued and one-time-used
Expand All @@ -175,8 +175,8 @@ Stmt propOneTimeUse(const Stmt &_op, const ID &subAST) {
ASSERT(w2rMay.count(write.first));
auto ctx = Ref<PBCtx>::make();
PBSet writeIterUnion;
for (auto &&[read, writeIterStr] : w2rMay.at(write.first)) {
PBSet writeIter = PBSet(ctx, writeIterStr);
for (auto &&[read, _writeIter] : w2rMay.at(write.first)) {
PBSet writeIter = _writeIter.to(ctx);
if (writeIterUnion.isValid()) {
if (!intersect(writeIterUnion, writeIter).empty()) {
goto failure; // Not one-time-used
Expand Down Expand Up @@ -212,7 +212,7 @@ Stmt propOneTimeUse(const Stmt &_op, const ID &subAST) {
if (!allIters(toProp).empty()) {
try {
auto &&[args, values, cond] =
parseSimplePBFunc(repInfo.funcStr_); // later -> earlier
parseSimplePBFunc(repInfo.func_); // later -> earlier
ASSERT(repInfo.earlierIters_.size() <=
values.size()); // maybe padded
ASSERT(repInfo.laterIters_.size() <= args.size());
Expand Down
19 changes: 9 additions & 10 deletions src/pass/remove_writes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace {

struct ReplaceInfo {
std::vector<IterAxis> earlierIters_, laterIters_;
std::string funcStr_;
PBFunc::Serialized func_;
};

} // Anonymous namespace
Expand Down Expand Up @@ -229,15 +229,15 @@ Stmt removeWrites(const Stmt &_op, const ID &singleDefId) {
auto earlier = d.earlier().as<StmtNode>();
auto later = d.later().as<StmtNode>();
if (!kill.count(earlier)) {
kill[earlier] = domain(d.earlierIter2Idx_).to(presburger);
kill[earlier] = domain(d.earlierIter2Idx_.to(presburger));
}
auto extConstraint = range(d.extConstraint_).to(presburger);
auto extConstraint = range(d.extConstraint_.to(presburger));
std::tie(kill[earlier], extConstraint) =
padToSameDims(std::move(kill[earlier]), std::move(extConstraint));
kill[earlier] =
intersect(std::move(kill[earlier]), std::move(extConstraint));
overwrites.emplace_back(later, earlier,
range(d.later2EarlierIter_).to(presburger),
range(d.later2EarlierIter_.to(presburger)),
ReplaceInfo{});
suspect.insert(d.def());
};
Expand All @@ -255,15 +255,14 @@ Stmt removeWrites(const Stmt &_op, const ID &singleDefId) {
auto earlier = d.earlier().as<StmtNode>();
auto later = d.later().as<StmtNode>();
if (!kill.count(earlier)) {
kill[earlier] = PBSet(
presburger, toString(domain(d.earlierIter2Idx_)));
kill[earlier] =
domain(d.earlierIter2Idx_.to(presburger));
}
overwrites.emplace_back(
later, earlier,
PBSet(presburger,
toString(range(d.later2EarlierIter_))),
range(d.later2EarlierIter_.to(presburger)),
ReplaceInfo{d.earlier_.iter_, d.later_.iter_,
toString(*f)});
f->toSerialized()});
suspect.insert(d.def());
}
}
Expand Down Expand Up @@ -427,7 +426,7 @@ Stmt removeWrites(const Stmt &_op, const ID &singleDefId) {
if (!allIters(expr).empty()) {
try {
auto &&[args, values, cond] =
parseSimplePBFunc(repInfo.funcStr_); // later -> earlier
parseSimplePBFunc(repInfo.func_); // later -> earlier
ASSERT(repInfo.earlierIters_.size() <=
values.size()); // maybe padded
ASSERT(repInfo.laterIters_.size() <= args.size());
Expand Down
2 changes: 1 addition & 1 deletion src/pass/shrink_for.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB {
ASSERT(stride.denSi() == 1);
auto strideInt = stride.numSi();
ReplaceIter demangler(*bound->demangleMap_);
auto offsetSimpleFunc = parseSimplePBFunc(toString(offset));
auto offsetSimpleFunc = parseSimplePBFunc(offset.toSerialized());
// offsetSimpleFunc.args_ should be a dummy variable equals to `bound`'s
// value. Leave it.
ASSERT(offsetSimpleFunc.values_.size() == 1);
Expand Down
6 changes: 3 additions & 3 deletions src/pass/tensor_prop_const.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace {

struct ReplaceInfo {
std::vector<IterAxis> earlierIters_, laterIters_;
std::string funcStr_;
PBFunc::Serialized func_;
};

} // namespace
Expand Down Expand Up @@ -118,7 +118,7 @@ Stmt tensorPropConst(const Stmt &_op, const ID &bothInSubAST,
r2w[d.later()].emplace_back(
d.earlier().as<StmtNode>(),
ReplaceInfo{d.earlier_.iter_, d.later_.iter_,
toString(*f)});
f->toSerialized()});
}
}
}));
Expand Down Expand Up @@ -155,7 +155,7 @@ Stmt tensorPropConst(const Stmt &_op, const ID &bothInSubAST,
if (!allIters(store->expr_).empty()) {
try {
auto &&[args, values, cond] =
parseSimplePBFunc(repInfo.funcStr_); // later -> earlier
parseSimplePBFunc(repInfo.func_); // later -> earlier
ASSERT(repInfo.earlierIters_.size() <=
values.size()); // maybe padded
ASSERT(repInfo.laterIters_.size() <= args.size());
Expand Down
2 changes: 1 addition & 1 deletion src/schedule/inlining.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ Stmt inlining(const Stmt &_ast, const ID &def) {
throw ParserError("ISL map is not single-valued");
}
auto &&[args, values, cond] = parseSimplePBFunc(
toString(PBFunc(dep.later2EarlierIter_)));
PBFunc(dep.later2EarlierIter_).toSerialized());
ASSERT(dep.earlier_.iter_.size() <=
values.size()); // maybe padded
ASSERT(dep.later_.iter_.size() <= args.size());
Expand Down
Loading

0 comments on commit d79c212

Please sign in to comment.