From c182842201a0eeb8628215dc903e9d812a143d24 Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Thu, 18 Jan 2024 15:13:56 +0800 Subject: [PATCH] Shrink overly relaxed linear indices in pass/shrink_var --- include/analyze/comp_transient_bounds.h | 10 +- include/pass/shrink_linear_indices.h | 23 +++ include/pass/shrink_var.h | 7 + src/pass/shrink_linear_indices.cc | 198 ++++++++++++++++++++++++ src/pass/shrink_var.cc | 5 + test/20.pass/test_shrink_var.py | 31 ++++ 6 files changed, 272 insertions(+), 2 deletions(-) create mode 100644 include/pass/shrink_linear_indices.h create mode 100644 src/pass/shrink_linear_indices.cc diff --git a/include/analyze/comp_transient_bounds.h b/include/analyze/comp_transient_bounds.h index cf4929e91..7749a18b0 100644 --- a/include/analyze/comp_transient_bounds.h +++ b/include/analyze/comp_transient_bounds.h @@ -1,11 +1,13 @@ #ifndef FREE_TENSOR_COMP_TRANSIENT_BOUNDS_H #define FREE_TENSOR_COMP_TRANSIENT_BOUNDS_H +#include #include #include #include #include +#include #include #include #include @@ -148,9 +150,13 @@ class CompTransientBounds : public BaseClass, conds_.emplace_back(makeEQ(var, op->begin_)); } } - this->pushFor(op); + if constexpr (std::is_base_of_v) { + this->pushFor(op); + } MAYBE_VOID(body, (*this)(op->body_)); - this->popFor(op); + if constexpr (std::is_base_of_v) { + this->popFor(op); + } conds_.resize(oldCondsSize); transients_.erase(var); diff --git a/include/pass/shrink_linear_indices.h b/include/pass/shrink_linear_indices.h new file mode 100644 index 000000000..35790058c --- /dev/null +++ b/include/pass/shrink_linear_indices.h @@ -0,0 +1,23 @@ +#ifndef FREE_TENSOR_SHRINK_LINEAR_INDICES_H +#define FREE_TENSOR_SHRINK_LINEAR_INDICES_H + +#include + +namespace freetensor { + +/** + * Mutator for shrinking linear indices in variables + * + * If a variable is consistently accessed with a linear expression, e.g., `a[8i + * + 2j]`, and `2j` has an integer bound no larger than 8, e.g., `0 <= 2j < 4`, + * then we can shrink the 
expression to be `a[4i + 2j]`. + * + * @{ + */ +Stmt shrinkLinearIndices(const Stmt &ast, const ID &vardef); +Stmt shrinkLinearIndices(const Stmt &ast); +/** @} */ + +} // namespace freetensor + +#endif // FREE_TENSOR_SHRINK_LINEAR_INDICES_H diff --git a/include/pass/shrink_var.h b/include/pass/shrink_var.h index be86216a7..fa65b5507 100644 --- a/include/pass/shrink_var.h +++ b/include/pass/shrink_var.h @@ -10,6 +10,13 @@ namespace freetensor { +/** + * Main mutator for shrinking variables + * + * This mutator modifies the shape of each variable to be the upper bound + * expression minus the lower bound expression plus one, with respect to each + * access of the variable. + */ class ShrinkVar : public Mutator { // Bound considering the old shape. Used for preventing make the shape even // larger after shrinking diff --git a/src/pass/shrink_linear_indices.cc b/src/pass/shrink_linear_indices.cc new file mode 100644 index 000000000..dfaff92f0 --- /dev/null +++ b/src/pass/shrink_linear_indices.cc @@ -0,0 +1,198 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace freetensor { + +namespace { + +struct IntBound { + int64_t lower_, upper_; +}; + +class GatherLinearIndices : public CompTransientBounds { + typedef CompTransientBounds BaseClass; + + ID vardef_; + std::string var_; + + std::vector> bounds_; + + Ref unique_; + + public: + GatherLinearIndices(const ID &vardef) : vardef_(vardef) {} + + const auto &bounds() const { return bounds_; } + + private: + template void visitAcc(const T &op) { + BaseClass::visit(op); + if (op->var_ == var_) { + ASSERT(bounds_.size() == op->indices_.size()); + for (auto &&[idx, bound] : views::zip(op->indices_, bounds_)) { + auto lin = linear(idx); + for (auto &&[_k, a] : lin.coeff_) { + int k = _k; + auto l = unique_->getIntLower(a); + auto u = unique_->getIntUpper(a); + if (k < 0) { + k = -k; + l = -l; + u = -u; + std::swap(l, u); + } + if (!bound.count(k)) { + 
bound[k] = {l, u}; + } else { + bound[k].lower_ = std::min(bound[k].lower_, l); + bound[k].upper_ = std::max(bound[k].upper_, u); + } + } + } + } + } + + protected: + using BaseClass::visit; + + void visitStmt(const Stmt &s) override { + // CompUniqueBounds requires one instance per Stmt + auto uniqueOfOuterStmt = unique_; + unique_ = Ref::make(*this); + BaseClass::visitStmt(s); + unique_ = uniqueOfOuterStmt; + } + + void visit(const VarDef &op) override { + if (op->id() == vardef_) { + var_ = op->name_; + bounds_.resize(op->buffer_->tensor()->shape().size()); + BaseClass::visit(op); + var_.clear(); + } else { + BaseClass::visit(op); + } + } + + void visit(const Load &op) override { visitAcc(op); } + void visit(const Store &op) override { visitAcc(op); } + void visit(const ReduceTo &op) override { visitAcc(op); } +}; +// Rewrites every access to `vardef`, mapping each linear coefficient through `replace` (one old-to-new map per dimension) +class ReplaceLinearIndices : public Mutator { + ID vardef_; + std::string var_; + + const std::vector> &replace_; + + public: + ReplaceLinearIndices( + const ID &vardef, + const std::vector> &replace) + : vardef_(vardef), replace_(replace) {} + + private: + template auto visitAcc(const T &_op) { + auto __op = Mutator::visit(_op); + ASSERT(__op->nodeType() == _op->nodeType()); + auto op = __op.template as(); + if (op->var_ == var_) { + for (auto &&[idx, rep] : views::zip(op->indices_, replace_)) { + auto lin = linear(idx); + for (auto &[k, a] : lin.coeff_) { + k = k < 0 ? -rep.at(-k) : rep.at(k); // Maps are keyed by |k|; keep the sign + } + idx = lin2expr(lin); + } + } + return op; + } + + protected: + Stmt visit(const VarDef &op) override { + if (op->id() == vardef_) { + var_ = op->name_; + auto ret = Mutator::visit(op); + var_.clear(); + return ret; + } else { + return Mutator::visit(op); + } + } + + Expr visit(const Load &op) override { return visitAcc(op); } + Stmt visit(const Store &op) override { return visitAcc(op); } + Stmt visit(const ReduceTo &op) override { return visitAcc(op); } +}; + +} // Anonymous namespace + +Stmt shrinkLinearIndices(const Stmt &_ast, const ID &vardef) { + Stmt ast = _ast; + + 
GatherLinearIndices gather{vardef}; + gather(ast); + auto &&bounds = gather.bounds(); + // Per dimension, compact the coefficients from the lowest-order term up + bool needMutation = false; + std::vector> replaceCoeff; + for (auto &&_bound : bounds) { + auto bound = + _bound | ranges::to>>(); + std::sort(bound.begin(), bound.end(), + [](const auto &lhs, const auto &rhs) { + return lhs.first > rhs.first; + }); // Sort k from high to low + std::vector newCoeff = + bound | views::keys | ranges::to(); + for (size_t n = bound.size(), i = n - 1; ~i; i--) { + int g = newCoeff[0]; + for (size_t j = 1; j <= i; j++) { + g = gcd(g, newCoeff[j]); + } + // The lower-order terms are added together in the index, so + // the range they span is the sum of the per-term ranges, not + // the min / max over them. With no lower-order term the range + // is [0, 0] + int64_t l = 0, u = 0; + for (size_t j = i + 1; j < n; j++) { + l += newCoeff[j] * bound[j].second.lower_; + u += newCoeff[j] * bound[j].second.upper_; + } + if (u - l + 1 < g) { + for (size_t j = 0; j <= i; j++) { + newCoeff[j] = newCoeff[j] / g * (u - l + 1); + } + needMutation = true; + } + } + replaceCoeff.emplace_back(views::zip(bound | views::keys, newCoeff) | + ranges::to()); + } + + if (needMutation) { + ast = ReplaceLinearIndices{vardef, replaceCoeff}(ast); + } + + return ast; +} + +Stmt shrinkLinearIndices(const Stmt &_ast) { + Stmt ast = _ast; + for (auto &&[varDefId, name] : allDefs(ast, {AccessType::Cache})) { + ast = shrinkLinearIndices(ast, varDefId); + } + return ast; +} + +} // namespace freetensor diff --git a/src/pass/shrink_var.cc b/src/pass/shrink_var.cc index 7eb649b52..c69ffc738 100644 --- a/src/pass/shrink_var.cc +++ b/src/pass/shrink_var.cc @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -101,6 +102,8 @@ Stmt ShrinkVar::visit(const ReduceTo &_op) { Stmt shrinkVar(const Stmt &_op) { auto op = removeDeadVar(_op); + op = shrinkLinearIndices(op); + // Algorithm: // (1) Represent the bounds of each vars with min / max expressions // (2) Modify var definitions @@ -125,6 +128,8 @@ Stmt shrinkVar(const Stmt &_op) { Stmt shrinkSingleVar(const Stmt &_op, const ID 
&varDefId) { auto op = removeDeadVar(_op); + op = shrinkLinearIndices(op, varDefId); + // (1) std::unordered_map boundsWithShape, boundsWithoutShape; boundsWithShape[varDefId] = diff --git a/test/20.pass/test_shrink_var.py b/test/20.pass/test_shrink_var.py index 632936b0d..a110f22f3 100644 --- a/test/20.pass/test_shrink_var.py +++ b/test/20.pass/test_shrink_var.py @@ -252,3 +252,34 @@ def test_const_in_branch_2(): std = ft.pop_ast() assert std.match(ast) + + +def test_over_relaxed_linear(): + with ft.VarDef([("x", (12,), "int32", "input", "cpu"), + ("y1", (12,), "int32", "output", "cpu"), + ("y2", (12,), "int32", "output", "cpu")]) as (x, y1, y2): + with ft.VarDef("b", (1000,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 3) as i: + with ft.For("j", 0, 4) as j: + b[i * 100 + j * 10] = x[i * 4 + j] + with ft.For("i", 0, 3) as i: + with ft.For("j", 0, 4) as j: + y1[i * 4 + j] = b[i * 100 + j * 10] * i + y2[i * 4 + j] = b[i * 100 + j * 10] + i + ast = ft.pop_ast(verbose=True) + ast = ft.lower(ast, verbose=1) + + with ft.VarDef([("x", (12,), "int32", "input", "cpu"), + ("y1", (12,), "int32", "output", "cpu"), + ("y2", (12,), "int32", "output", "cpu")]) as (x, y1, y2): + with ft.VarDef("b", (12,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 3) as i: + with ft.For("j", 0, 4) as j: + b[i * 4 + j] = x[i * 4 + j] + with ft.For("i", 0, 3) as i: + with ft.For("j", 0, 4) as j: + y1[i * 4 + j] = b[i * 4 + j] * i + y2[i * 4 + j] = b[i * 4 + j] + i + std = ft.pop_ast() + + assert std.match(ast)