Skip to content

Commit

Permalink
Recognize IfExpr in CompUniqueBoundsPB (#581)
Browse files Browse the repository at this point in the history
* Recognize IfExpr in CompUniqueBoundsPB

* Remove statement checking from CompUniqueBoundsPB since we have now requried to instantiate one CompUniqueBoundsPB object per statement
  • Loading branch information
roastduck authored Jan 14, 2024
1 parent 240bf3a commit 051ab9e
Show file tree
Hide file tree
Showing 11 changed files with 401 additions and 83 deletions.
11 changes: 0 additions & 11 deletions include/analyze/comp_transient_bounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class CompTransientBoundsInterface {
public:
virtual TransientBound transient(const Expr &op) const = 0;
virtual const std::vector<Expr> &conds() const = 0;
virtual const Stmt &currentStmt() const = 0;
};

/**
Expand Down Expand Up @@ -54,9 +53,6 @@ class CompTransientBounds : public BaseClass,
// Original bounds
std::vector<Expr> conds_;

// Currently visited statement
Stmt currentStmt_;

public:
TransientBound transient(const Expr &op) const override {
if (transients_.count(op)) {
Expand All @@ -67,8 +63,6 @@ class CompTransientBounds : public BaseClass,

const std::vector<Expr> &conds() const override { return conds_; }

const Stmt &currentStmt() const override { return currentStmt_; };

private:
void applyCond(const Expr &_cond,
const std::unordered_set<std::string> &bodyAllWrites) {
Expand Down Expand Up @@ -246,11 +240,6 @@ class CompTransientBounds : public BaseClass,
op->id(), op->debugBlame());
}
}

typename BaseClass::StmtRetType visitStmt(const Stmt &op) override {
currentStmt_ = op;
return BaseClass::visitStmt(op);
}
};

} // namespace freetensor
Expand Down
7 changes: 7 additions & 0 deletions include/analyze/comp_unique_bounds.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ class CompUniqueBounds {
enum class BoundType { Combination, Presburger };

class Bound {
protected:
static int
countScope(const Expr &op,
const std::unordered_map<std::string, int> &orderedScope);
static int countHeavyOps(const Expr &op);

public:
virtual ~Bound() {}

Expand Down Expand Up @@ -48,6 +54,7 @@ class CompUniqueBounds {
restrictScope(const std::unordered_set<std::string> &scope) const = 0;

virtual Expr simplestExpr(
const Expr &reference,
const std::unordered_map<std::string, int> &orderedScope) const = 0;
};

Expand Down
3 changes: 2 additions & 1 deletion include/analyze/comp_unique_bounds_combination.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class CompUniqueBoundsCombination : public CompUniqueBounds, public Visitor {
Ref<CompUniqueBounds::Bound> restrictScope(
const std::unordered_set<std::string> &scope) const override;

Expr simplestExpr(const std::unordered_map<std::string, int>
Expr simplestExpr(const Expr &reference,
const std::unordered_map<std::string, int>
&orderedScope) const override;
};

Expand Down
8 changes: 3 additions & 5 deletions include/analyze/comp_unique_bounds_pb.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class CompUniqueBoundsPB : public CompUniqueBounds {
Ref<CompUniqueBounds::Bound> restrictScope(
const std::unordered_set<std::string> &scope) const override;

Expr simplestExpr(const std::unordered_map<std::string, int>
Expr simplestExpr(const Expr &reference,
const std::unordered_map<std::string, int>
&orderedScope) const override;
};

Expand All @@ -57,7 +58,6 @@ class CompUniqueBoundsPB : public CompUniqueBounds {
GenPBExpr genPBExpr_;
Ref<PBCtx> ctx_;

Stmt cachedPlace_;
PBSet cachedConds_;
Ref<std::unordered_map<std::string, Expr>> cachedFreeVars_;
std::unordered_map<Expr, Ref<Bound>> cachedValues_;
Expand All @@ -67,9 +67,7 @@ class CompUniqueBoundsPB : public CompUniqueBounds {
unionBoundsAsBound(const std::vector<Ref<CompUniqueBounds::Bound>> &bounds);

public:
CompUniqueBoundsPB(const CompTransientBoundsInterface &transients)
: CompUniqueBounds(transients), transients_(transients),
ctx_(Ref<PBCtx>::make()) {}
CompUniqueBoundsPB(const CompTransientBoundsInterface &transients);

Ref<CompUniqueBounds::Bound> getBound(const Expr &op) override;
bool alwaysLE(const Expr &lhs, const Expr &rhs) override;
Expand Down
27 changes: 27 additions & 0 deletions include/analyze/normalize_conditional_expr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef FREE_TENSOR_NORMALIZE_CONDITIONAL_EXPR_H
#define FREE_TENSOR_NORMALIZE_CONDITIONAL_EXPR_H

#include <vector>

#include <expr.h>

namespace freetensor {

/**
* Break a expression into several conditional parts.
*
* This function is used for analyzing expressions with `IfExpr` inside. The
* result will be several parts with conditions, where each part is no longer
* with `IfExpr`.
*
* @param expr : The expression to be analyzed.
* @return : A vector of pairs, where the first element is the value of the
* expression, and the second element is the condition of the expression. The
* condition may be null, which means the expression is always true.
*/
std::vector<std::pair<Expr /* value */, Expr /* condition, maybe null */>>
normalizeConditionalExpr(const Expr &expr);

} // namespace freetensor

#endif // FREE_TENSOR_NORMALIZE_CONDITIONAL_EXPR_H
42 changes: 42 additions & 0 deletions src/analyze/comp_unique_bounds.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include <analyze/all_uses.h>
#include <analyze/comp_unique_bounds.h>

namespace freetensor {

namespace {

class CountHeavyOps : public Visitor {
int cnt_ = 0;

public:
int cnt() const { return cnt_; }

protected:
void visitExpr(const Expr &op) {
Visitor::visitExpr(op);
if (!op->isConst() && op->nodeType() != ASTNodeType::Add &&
op->nodeType() != ASTNodeType::Sub &&
op->nodeType() != ASTNodeType::Mul) {
cnt_++;
}
}
};

} // Anonymous namespace

int CompUniqueBounds::Bound::countHeavyOps(const Expr &op) {
CountHeavyOps visitor;
visitor(op);
return visitor.cnt();
}

int CompUniqueBounds::Bound::countScope(
const Expr &expr,
const std::unordered_map<std::string, int> &orderedScope) {
int scope = 0;
for (auto &&use : allUses(expr))
scope = std::max(scope, orderedScope.at(use));
return scope;
}

} // namespace freetensor
32 changes: 2 additions & 30 deletions src/analyze/comp_unique_bounds_combination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,6 @@

namespace freetensor {

namespace {

class CountHeavyOps : public Visitor {
int cnt_ = 0;

public:
int cnt() const { return cnt_; }

protected:
void visitExpr(const Expr &op) {
Visitor::visitExpr(op);
if (!op->isConst() && op->nodeType() != ASTNodeType::Add &&
op->nodeType() != ASTNodeType::Sub &&
op->nodeType() != ASTNodeType::Mul) {
cnt_++;
}
}
};

static int countHeavyOps(const Expr &op) {
CountHeavyOps visitor;
visitor(op);
return visitor.cnt();
}

} // namespace

int64_t CompUniqueBoundsCombination::Bound::lowerInt() const {
int64_t ret = LLONG_MIN;
for (auto &&b : lowerBounds_) {
Expand Down Expand Up @@ -96,6 +69,7 @@ Ref<CompUniqueBounds::Bound> CompUniqueBoundsCombination::Bound::restrictScope(
}

Expr CompUniqueBoundsCombination::Bound::simplestExpr(
const Expr &reference,
const std::unordered_map<std::string, int> &orderedScope) const {
Expr best = nullptr;
auto bestScope = -1, bestHeavyOps = -1;
Expand All @@ -115,9 +89,7 @@ Expr CompUniqueBoundsCombination::Bound::simplestExpr(
expr = upper.expr();
}
// firstly choose outermost innermost scope
int scope = 0;
for (auto &&use : allUses(expr))
scope = std::max(scope, orderedScope.at(use));
int scope = countScope(expr, orderedScope);
// secondly choose the one with least heavy operations
auto heavyOps = countHeavyOps(expr);
if (!best.isValid() || scope < bestScope ||
Expand Down
122 changes: 87 additions & 35 deletions src/analyze/comp_unique_bounds_pb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <analyze/all_uses.h>
#include <analyze/comp_unique_bounds_pb.h>
#include <analyze/normalize_conditional_expr.h>
#include <container_utils.h>
#include <expr.h>
#include <math/parse_pb_expr.h>
Expand Down Expand Up @@ -99,6 +100,7 @@ Ref<CompUniqueBounds::Bound> CompUniqueBoundsPB::Bound::restrictScope(
}

Expr CompUniqueBoundsPB::Bound::simplestExpr(
const Expr &reference,
const std::unordered_map<std::string, int> &orderedScope) const {

// first test the original map to be single valued
Expand All @@ -108,67 +110,117 @@ Expr CompUniqueBoundsPB::Bound::simplestExpr(
std::vector<std::pair<std::string, int>> axesScopeLevel;
for (int i = 0; i < bound_.nParamDims(); ++i) {
auto name = bound_.nameParamDim(i);
int scopeLevel = 0;
for (auto &&used : allUses(demangleMap_->at(name)))
scopeLevel = std::max(scopeLevel, orderedScope.at(used));
axesScopeLevel.emplace_back(name, scopeLevel);
axesScopeLevel.emplace_back(
name, countScope(demangleMap_->at(name), orderedScope));
}
// sort to innermost first, we will try remove them one by one
std::sort(axesScopeLevel.begin(), axesScopeLevel.end(),
[](auto &&a, auto &&b) { return a.second > b.second; });

// remove one axis at a time, try until it's not single valued
auto restrictedBound = bound_;
for (auto &&[axis, _] : axesScopeLevel) {
int minScopeLevel = INT_MAX;
for (auto &&[axis, scopeLevel] : axesScopeLevel) {
auto newRestrictedBound =
projectOutParamById(std::move(restrictedBound), axis);
if (!newRestrictedBound.isSingleValued())
break;
restrictedBound = std::move(newRestrictedBound);
minScopeLevel = scopeLevel;
}
return translateBoundFunc(*ctx_, restrictedBound, *demangleMap_);
}

Ref<CompUniqueBounds::Bound> CompUniqueBoundsPB::getBound(const Expr &op) {
if (!isInt(op->dtype()))
auto resultExpr = translateBoundFunc(*ctx_, restrictedBound, *demangleMap_);
if (!resultExpr.isValid()) {
return nullptr;
}
auto isSimplier = minScopeLevel < countScope(reference, orderedScope) ||
countHeavyOps(resultExpr) < countHeavyOps(reference);
return isSimplier ? resultExpr : nullptr;
}

// check if the cache is valid
if (auto place = transients_.currentStmt(); place != cachedPlace_) {
// invalid, refresh it with the new transients condition
cachedPlace_ = place;

// construct full condition
Expr fullCond = makeBoolConst(true);
for (auto &&cond : transients_.conds())
fullCond = makeLAnd(fullCond, cond);

// generate PB condition
auto [str, varMap] = genPBExpr_.gen(fullCond);
cachedConds_ =
PBSet(*ctx_, "[" + (varMap | views::values | join(", ")) +
"] -> { [unique_bounded_var]: " + str + " }");

// initialize known demangle map
cachedFreeVars_ = decltype(cachedFreeVars_)::make();
for (auto &&[expr, pbVar] : varMap) {
ASSERT(!cachedFreeVars_->contains(pbVar));
(*cachedFreeVars_)[pbVar] = expr;
CompUniqueBoundsPB::CompUniqueBoundsPB(
const CompTransientBoundsInterface &transients)
: CompUniqueBounds(transients), transients_(transients),
ctx_(Ref<PBCtx>::make()) {

// construct full condition
Expr fullCond = makeBoolConst(true);
for (auto &&cond : transients_.conds())
fullCond = makeLAnd(fullCond, cond);

// generate PB condition
std::string str;
GenPBExpr::VarMap varMap;
for (auto &&[subExpr, cond] : normalizeConditionalExpr(fullCond)) {
auto [subStr, subVarMap] = genPBExpr_.gen(subExpr);
subStr = "[unique_bounded_var] : " + subStr;
for (auto &&[k, v] : subVarMap) {
if (auto it = varMap.find(k); it != varMap.end()) {
ASSERT(it->second == v);
} else {
varMap[k] = v;
}
}
if (cond.isValid()) {
auto [condStr, condVarMap] = genPBExpr_.gen(cond);
subStr += " and " + condStr;
for (auto &&[k, v] : condVarMap) {
if (auto it = varMap.find(k); it != varMap.end()) {
ASSERT(it->second == v);
} else {
varMap[k] = v;
}
}
}
str += str.empty() ? subStr : "; " + subStr;
}
cachedConds_ = PBSet(*ctx_, "[" + (varMap | views::values | join(", ")) +
"] -> {" + str + "}");

// clear cached query results
cachedValues_.clear();
// initialize known demangle map
cachedFreeVars_ = decltype(cachedFreeVars_)::make();
for (auto &&[expr, pbVar] : varMap) {
ASSERT(!cachedFreeVars_->contains(pbVar));
(*cachedFreeVars_)[pbVar] = expr;
}
}

Ref<CompUniqueBounds::Bound> CompUniqueBoundsPB::getBound(const Expr &op) {
if (!isInt(op->dtype()))
return nullptr;

// find in cached results
if (auto it = cachedValues_.find(op); it != cachedValues_.end())
return it->second;

// not previously queried, construct the bound
auto [str, varMap] = genPBExpr_.gen(op);
std::string str;
GenPBExpr::VarMap varMap;
for (auto &&[subExpr, cond] : normalizeConditionalExpr(op)) {
auto [subStr, subVarMap] = genPBExpr_.gen(subExpr);
subStr = "[" + subStr + "]";
for (auto &&[k, v] : subVarMap) {
if (auto it = varMap.find(k); it != varMap.end()) {
ASSERT(it->second == v);
} else {
varMap[k] = v;
}
}
if (cond.isValid()) {
auto [condStr, condVarMap] = genPBExpr_.gen(cond);
subStr += " : " + condStr;
for (auto &&[k, v] : condVarMap) {
if (auto it = varMap.find(k); it != varMap.end()) {
ASSERT(it->second == v);
} else {
varMap[k] = v;
}
}
}
str += str.empty() ? subStr : "; " + subStr;
}
auto bound =
(intersect(PBSet(*ctx_, "[" + (varMap | views::values | join(", ")) +
"] -> { [" + str + "] }"),
"] -> {" + str + "}"),
cachedConds_));
// update free variables
for (auto &&[expr, pbVar] : varMap) {
Expand Down
Loading

0 comments on commit 051ab9e

Please sign in to comment.