Skip to content

Commit

Permalink
Remove statement checking from CompUniqueBoundsPB since we have now r…
Browse files Browse the repository at this point in the history
…equried to instantiate one CompUniqueBoundsPB object per statement
  • Loading branch information
roastduck committed Jan 14, 2024
1 parent 1f9e7f8 commit 5bcadd8
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 59 deletions.
11 changes: 0 additions & 11 deletions include/analyze/comp_transient_bounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class CompTransientBoundsInterface {
public:
virtual TransientBound transient(const Expr &op) const = 0;
virtual const std::vector<Expr> &conds() const = 0;
virtual const Stmt &currentStmt() const = 0;
};

/**
Expand Down Expand Up @@ -54,9 +53,6 @@ class CompTransientBounds : public BaseClass,
// Original bounds
std::vector<Expr> conds_;

// Currently visited statement
Stmt currentStmt_;

public:
TransientBound transient(const Expr &op) const override {
if (transients_.count(op)) {
Expand All @@ -67,8 +63,6 @@ class CompTransientBounds : public BaseClass,

const std::vector<Expr> &conds() const override { return conds_; }

const Stmt &currentStmt() const override { return currentStmt_; };

private:
void applyCond(const Expr &_cond,
const std::unordered_set<std::string> &bodyAllWrites) {
Expand Down Expand Up @@ -246,11 +240,6 @@ class CompTransientBounds : public BaseClass,
op->id(), op->debugBlame());
}
}

typename BaseClass::StmtRetType visitStmt(const Stmt &op) override {
currentStmt_ = op;
return BaseClass::visitStmt(op);
}
};

} // namespace freetensor
Expand Down
5 changes: 1 addition & 4 deletions include/analyze/comp_unique_bounds_pb.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class CompUniqueBoundsPB : public CompUniqueBounds {
GenPBExpr genPBExpr_;
Ref<PBCtx> ctx_;

Stmt cachedPlace_;
PBSet cachedConds_;
Ref<std::unordered_map<std::string, Expr>> cachedFreeVars_;
std::unordered_map<Expr, Ref<Bound>> cachedValues_;
Expand All @@ -68,9 +67,7 @@ class CompUniqueBoundsPB : public CompUniqueBounds {
unionBoundsAsBound(const std::vector<Ref<CompUniqueBounds::Bound>> &bounds);

public:
CompUniqueBoundsPB(const CompTransientBoundsInterface &transients)
: CompUniqueBounds(transients), transients_(transients),
ctx_(Ref<PBCtx>::make()) {}
CompUniqueBoundsPB(const CompTransientBoundsInterface &transients);

Ref<CompUniqueBounds::Bound> getBound(const Expr &op) override;
bool alwaysLE(const Expr &lhs, const Expr &rhs) override;
Expand Down
84 changes: 40 additions & 44 deletions src/analyze/comp_unique_bounds_pb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,60 +137,56 @@ Expr CompUniqueBoundsPB::Bound::simplestExpr(
return isSimplier ? resultExpr : nullptr;
}

Ref<CompUniqueBounds::Bound> CompUniqueBoundsPB::getBound(const Expr &op) {
if (!isInt(op->dtype()))
return nullptr;
CompUniqueBoundsPB::CompUniqueBoundsPB(
const CompTransientBoundsInterface &transients)
: CompUniqueBounds(transients), transients_(transients),
ctx_(Ref<PBCtx>::make()) {

// construct full condition
Expr fullCond = makeBoolConst(true);
for (auto &&cond : transients_.conds())
fullCond = makeLAnd(fullCond, cond);

// check if the cache is valid
if (auto place = transients_.currentStmt(); place != cachedPlace_) {
// invalid, refresh it with the new transients condition
cachedPlace_ = place;

// construct full condition
Expr fullCond = makeBoolConst(true);
for (auto &&cond : transients_.conds())
fullCond = makeLAnd(fullCond, cond);

// generate PB condition
std::string str;
GenPBExpr::VarMap varMap;
for (auto &&[subExpr, cond] : normalizeConditionalExpr(fullCond)) {
auto [subStr, subVarMap] = genPBExpr_.gen(subExpr);
subStr = "[unique_bounded_var] : " + subStr;
for (auto &&[k, v] : subVarMap) {
// generate PB condition
std::string str;
GenPBExpr::VarMap varMap;
for (auto &&[subExpr, cond] : normalizeConditionalExpr(fullCond)) {
auto [subStr, subVarMap] = genPBExpr_.gen(subExpr);
subStr = "[unique_bounded_var] : " + subStr;
for (auto &&[k, v] : subVarMap) {
if (auto it = varMap.find(k); it != varMap.end()) {
ASSERT(it->second == v);
} else {
varMap[k] = v;
}
}
if (cond.isValid()) {
auto [condStr, condVarMap] = genPBExpr_.gen(cond);
subStr += " and " + condStr;
for (auto &&[k, v] : condVarMap) {
if (auto it = varMap.find(k); it != varMap.end()) {
ASSERT(it->second == v);
} else {
varMap[k] = v;
}
}
if (cond.isValid()) {
auto [condStr, condVarMap] = genPBExpr_.gen(cond);
subStr += " and " + condStr;
for (auto &&[k, v] : condVarMap) {
if (auto it = varMap.find(k); it != varMap.end()) {
ASSERT(it->second == v);
} else {
varMap[k] = v;
}
}
}
str += str.empty() ? subStr : "; " + subStr;
}
cachedConds_ =
PBSet(*ctx_, "[" + (varMap | views::values | join(", ")) +
"] -> {" + str + "}");

// initialize known demangle map
cachedFreeVars_ = decltype(cachedFreeVars_)::make();
for (auto &&[expr, pbVar] : varMap) {
ASSERT(!cachedFreeVars_->contains(pbVar));
(*cachedFreeVars_)[pbVar] = expr;
}
str += str.empty() ? subStr : "; " + subStr;
}
cachedConds_ = PBSet(*ctx_, "[" + (varMap | views::values | join(", ")) +
"] -> {" + str + "}");

// clear cached query results
cachedValues_.clear();
// initialize known demangle map
cachedFreeVars_ = decltype(cachedFreeVars_)::make();
for (auto &&[expr, pbVar] : varMap) {
ASSERT(!cachedFreeVars_->contains(pbVar));
(*cachedFreeVars_)[pbVar] = expr;
}
}

Ref<CompUniqueBounds::Bound> CompUniqueBoundsPB::getBound(const Expr &op) {
if (!isInt(op->dtype()))
return nullptr;

// find in cached results
if (auto it = cachedValues_.find(op); it != cachedValues_.end())
Expand Down

0 comments on commit 5bcadd8

Please sign in to comment.