Skip to content

Commit

Permalink
Remove redundant analysis in pass/pb_simplify (#584)
Browse files Browse the repository at this point in the history
* Remove redundant analysis in pass/pb_simplify

* Fix broken test
  • Loading branch information
roastduck authored Jan 15, 2024
1 parent 03beb6d commit 7696f46
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 54 deletions.
82 changes: 47 additions & 35 deletions include/pass/simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
#include <func.h>
#include <math/bounds.h>
#include <mutator.h>
#include <pass/annotate_conds.h>
#include <pass/const_fold.h>
#include <pass/flatten_stmt_seq.h>
#include <visitor.h>

namespace freetensor {
Expand All @@ -41,9 +39,27 @@ class FindInnerMostScope : public Visitor {
int findInnerMostScope(const std::unordered_map<std::string, int> &varScope,
const Expr &op);

// NOTE: We use ConstFold because we cannot rely the bound analysis for constant
// propagation. E.g f(x) + 0, where f(x) is a complex expression and it does not
// have a bound. The "+ 0" cannot be removed by bound analysis
/**
* Base class for integer simplify passes
*
* Simplification from 3 mechanisms:
*
* - Bound analysis from a specific subclass of `CompUniqueBounds`. If there is
* only one integer between an expression's lower bound and upper bound, then
* the expression can be replaced by the integer.
* - Constant folding from `ConstFold`.
* - Simplification rules from `SimplifyPass`. This is a complement of the bound
* analysis. E.g. to simplify `x + 0` as `x`, which cannot be simplified by
* bound analysis if `x` has no bound.
*
* @param compUniqueBoundsFactor : A factory function creating a specific
* `CompUniqueBounds` instance for bound analysis.
* @param leafFirstBoundAnalysis : Whether to simplify sub-expressions with
* bound analysis before simplifying their parents. This is useful when the
* simplification of a sub-expression helps analyzing its parent, but will only
* waste time if the analysis is based on a unified representation and does not
* depend on the specific form of the sub-expression.
*/
class SimplifyPass : public CompTransientBounds<SymbolTable<ConstFold>> {
typedef CompTransientBounds<SymbolTable<ConstFold>> BaseClass;

Expand All @@ -54,12 +70,15 @@ class SimplifyPass : public CompTransientBounds<SymbolTable<ConstFold>> {
Ref<CompUniqueBounds> unique_;
std::function<Ref<CompUniqueBounds>(const CompTransientBoundsInterface &)>
compUniqueBoundsFactory_;
bool leafFirstBoundAnalysis_;

public:
SimplifyPass(std::function<
Ref<CompUniqueBounds>(const CompTransientBoundsInterface &)>
compUniqueBoundsFactory)
: compUniqueBoundsFactory_(compUniqueBoundsFactory) {}
SimplifyPass(std::function<Ref<CompUniqueBounds>(
const CompTransientBoundsInterface &)>
compUniqueBoundsFactory,
bool leafFirstBoundAnalysis)
: compUniqueBoundsFactory_(compUniqueBoundsFactory),
leafFirstBoundAnalysis_(leafFirstBoundAnalysis) {}

private:
template <class T> bool equals(const Expr &op, T &&val) const {
Expand Down Expand Up @@ -105,48 +124,41 @@ class SimplifyPass : public CompTransientBounds<SymbolTable<ConstFold>> {
class BuiltinSimplify : public SimplifyPass {
public:
BuiltinSimplify()
: SimplifyPass([](const CompTransientBoundsInterface &tr) {
return Ref<CompUniqueBoundsCombination>::make(tr);
}) {}
: SimplifyPass(
[](const CompTransientBoundsInterface &tr) {
return Ref<CompUniqueBoundsCombination>::make(tr);
},
true) {}
};

class PBSimplify : public SimplifyPass {
public:
PBSimplify()
: SimplifyPass([](const CompTransientBoundsInterface &tr) {
return Ref<CompUniqueBoundsPB>::make(tr);
}) {}
: SimplifyPass(
[](const CompTransientBoundsInterface &tr) {
return Ref<CompUniqueBoundsPB>::make(tr);
},
false) {}
};

/**
* Simplify a program and compute bounds of each expressions
* Simplify integer expressions in a program
*
* `builtinSimplify` and `simplify` uses `CompUniqueBoundsCombination` to
* simplify the program. `pbSimplify` uses `CompUniqueBoundsPB` to simplify the
* program.
*
* This pass can only be applied on a complete program, instead of a single
* expression, because it examines VarDef nodes of each Var
*
* @return : {simplified, lower, upper}
* @return : The simplified AST
*
* @{
*/
template <class Simplifier> Stmt simplifyImpl(const Stmt &_op) {
auto op = _op;

for (int i = 0;; i++) {
auto newOp = annotateConds(op);
newOp = Simplifier()(newOp);
newOp = flattenStmtSeq(newOp);
if (HashComparator()(newOp, op) || i > 100) {
if (i > 100) {
WARNING("pass/simplify iterates over 100 rounds. Maybe there "
"is a bug");
}
return newOp;
}
op = newOp;
}
}

Stmt builtinSimplify(const Stmt &op);
Stmt pbSimplify(const Stmt &op);
Stmt simplify(const Stmt &op);
/** @} */

DEFINE_PASS_FOR_FUNC(builtinSimplify)
DEFINE_PASS_FOR_FUNC(pbSimplify)
Expand Down
8 changes: 6 additions & 2 deletions src/analyze/comp_unique_bounds_pb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,12 @@ Expr CompUniqueBoundsPB::Bound::simplestExpr(

// remove one axis at a time, try until it's not single valued
auto restrictedBound = bound_;
int minScopeLevel = INT_MAX;
int minScopeLevel = INT_MAX,
oldScopeLevel = countScope(reference, orderedScope);
for (auto &&[axis, scopeLevel] : axesScopeLevel) {
if (scopeLevel > oldScopeLevel) {
continue;
}
auto newRestrictedBound =
projectOutParamById(std::move(restrictedBound), axis);
if (!newRestrictedBound.isSingleValued())
Expand All @@ -132,7 +136,7 @@ Expr CompUniqueBoundsPB::Bound::simplestExpr(
if (!resultExpr.isValid()) {
return nullptr;
}
auto isSimplier = minScopeLevel < countScope(reference, orderedScope) ||
auto isSimplier = minScopeLevel < oldScopeLevel ||
countHeavyOps(resultExpr) < countHeavyOps(reference);
return isSimplier ? resultExpr : nullptr;
}
Expand Down
141 changes: 124 additions & 17 deletions src/pass/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <except.h>
#include <math/min_max.h>
#include <math/utils.h>
#include <pass/annotate_conds.h>
#include <pass/flatten_stmt_seq.h>
#include <pass/replace_iter.h>
#include <pass/simplify.h>
Expand Down Expand Up @@ -153,24 +154,35 @@ Stmt SimplifyPass::visitStmt(const Stmt &op) {
}

Expr SimplifyPass::visitExpr(const Expr &_op) {
auto op = BaseClass::visitExpr(_op);

// To avoid divergence
if (!HashComparator()(op, _op)) {
// E.g.
// (1) a[0 - 0] -> a[0]
// (2) (1 + 1) * a[0] -> 2 * a[0 - 0], because of the old bound
return op;
}

if (auto bound = unique_->getBound(op); bound.isValid()) {
Expr best = bound->simplestExpr(op, varScope_);
if (best.isValid() && !HashComparator()(best, op)) {
return best;
if (isInt(_op->dtype())) {
if (leafFirstBoundAnalysis_) {
auto op = BaseClass::visitExpr(_op);
if (!HashComparator()(op, _op)) {
// To avoid divergence
// E.g.
// (1) a[0 - 0] -> a[0]
// (2) (1 + 1) * a[0] -> 2 * a[0 - 0], because of the old bound
return op;
}
if (auto bound = unique_->getBound(op); bound.isValid()) {
Expr best = bound->simplestExpr(op, varScope_);
if (best.isValid() && !HashComparator()(best, op)) {
return best;
}
}
return op;
} else {
if (auto bound = unique_->getBound(_op); bound.isValid()) {
Expr best = bound->simplestExpr(_op, varScope_);
if (best.isValid() && !HashComparator()(best, _op)) {
return best;
}
}
return BaseClass::visitExpr(_op);
}
}

if (_op->dtype() == DataType::Bool) {
} else if (isBool(_op->dtype())) {
auto op = BaseClass::visitExpr(_op);
if (auto p = _op->parentExpr();
!p.isValid() || p->dtype() != DataType::Bool) {
// this is base bool expr
Expand All @@ -179,8 +191,11 @@ Expr SimplifyPass::visitExpr(const Expr &_op) {
op = normalized;
}
}
return op;

} else {
return BaseClass::visitExpr(_op);
}
return op;
}

Expr SimplifyPass::visit(const Add &_op) {
Expand Down Expand Up @@ -711,13 +726,87 @@ Expr SimplifyPass::visit(const IfExpr &_op) {
}
ASSERT(__op->nodeType() == ASTNodeType::IfExpr);
auto op = __op.as<IfExprNode>();

if (op->cond_->nodeType() == ASTNodeType::BoolConst) {
if (op->cond_.as<BoolConstNode>()->val_) {
return op->thenCase_;
} else {
return op->elseCase_;
}
}

if (HashComparator{}(op->thenCase_, op->elseCase_)) {
return op->thenCase_;
}

if (op->thenCase_->nodeType() == op->elseCase_->nodeType()) {
if (op->thenCase_->isUnary()) {
return makeUnary(
op->thenCase_->nodeType(),
makeIfExpr(op->cond_, op->thenCase_.as<UnaryExprNode>()->expr_,
op->elseCase_.as<UnaryExprNode>()->expr_));
} else if (op->thenCase_->isBinary()) {
auto &&thenCase = op->thenCase_.as<BinaryExprNode>();
auto &&elseCase = op->elseCase_.as<BinaryExprNode>();
if (HashComparator{}(thenCase->lhs_, elseCase->lhs_)) {
return makeBinary(
thenCase->nodeType(), thenCase->lhs_,
makeIfExpr(op->cond_, thenCase->rhs_, elseCase->rhs_));
} else if (HashComparator{}(thenCase->rhs_, elseCase->rhs_)) {
return makeBinary(
thenCase->nodeType(),
makeIfExpr(op->cond_, thenCase->lhs_, elseCase->lhs_),
thenCase->rhs_);
} else if (thenCase->isCommutative()) {
if (HashComparator{}(thenCase->lhs_, elseCase->rhs_)) {
return makeBinary(
thenCase->nodeType(), thenCase->lhs_,
makeIfExpr(op->cond_, thenCase->rhs_, elseCase->lhs_));
} else if (HashComparator{}(thenCase->rhs_, elseCase->lhs_)) {
return makeBinary(
thenCase->nodeType(), thenCase->rhs_,
makeIfExpr(op->cond_, thenCase->lhs_, elseCase->rhs_));
}
}
} else if (op->thenCase_->nodeType() == ASTNodeType::Cast) {
auto &&thenCase = op->thenCase_.as<CastNode>();
auto &&elseCase = op->elseCase_.as<CastNode>();
if (thenCase->destType_ == elseCase->destType_) {
return makeCast(
makeIfExpr(op->cond_, thenCase->expr_, elseCase->expr_),
thenCase->destType_);
}
} else if (op->thenCase_->nodeType() == ASTNodeType::Load) {
auto &&thenCase = op->thenCase_.as<LoadNode>();
auto &&elseCase = op->elseCase_.as<LoadNode>();
if (thenCase->var_ == elseCase->var_) {
// Since `var_` is the same, these must be the same
ASSERT(thenCase->indices_.size() == elseCase->indices_.size());
ASSERT(thenCase->loadType_ == elseCase->loadType_);
int diffCnt = 0;
std::vector<Expr> indices;
indices.reserve(thenCase->indices_.size());
for (auto &&[thenItem, elseItem] :
views::zip(thenCase->indices_, elseCase->indices_)) {
if (HashComparator{}(thenItem, elseItem)) {
indices.emplace_back(thenItem);
} else {
diffCnt++;
indices.emplace_back(
makeIfExpr(op->cond_, thenItem, elseItem));
}
}
if (diffCnt <= 1) {
return makeLoad(thenCase->var_, std::move(indices),
thenCase->loadType_);
}
}
}
// TODO: We can also handle `Intrinsic`, but we must properly deal with
// `hasSideEffect_`, and check for data type in case of `... ?
// intrinsic(..., type_a) : intrinsic(..., type_b)`.
}

return op;
}

Expand Down Expand Up @@ -846,6 +935,24 @@ Stmt SimplifyPass::visit(const Assert &_op) {
return op;
}

template <class Simplifier> static Stmt simplifyImpl(const Stmt &_op) {
auto op = _op;

for (int i = 0;; i++) {
auto newOp = annotateConds(op);
newOp = Simplifier()(newOp);
newOp = flattenStmtSeq(newOp);
if (HashComparator()(newOp, op) || i > 100) {
if (i > 100) {
WARNING("pass/simplify iterates over 100 rounds. Maybe there "
"is a bug");
}
return newOp;
}
op = newOp;
}
}

Stmt builtinSimplify(const Stmt &op) {
return flattenStmtSeq(simplifyImpl<BuiltinSimplify>(op));
}
Expand Down

0 comments on commit 7696f46

Please sign in to comment.