Skip to content

Commit

Permalink
Shrink overly relaxed linear indices in pass/shrink_var
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck committed Jan 18, 2024
1 parent 3babd01 commit c182842
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 2 deletions.
10 changes: 8 additions & 2 deletions include/analyze/comp_transient_bounds.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#ifndef FREE_TENSOR_COMP_TRANSIENT_BOUNDS_H
#define FREE_TENSOR_COMP_TRANSIENT_BOUNDS_H

#include <type_traits>
#include <unordered_set>

#include <analyze/all_uses.h>
#include <analyze/analyze_linear.h>
#include <analyze/as_dnf.h>
#include <analyze/symbol_table.h>
#include <container_utils.h>
#include <hash.h>
#include <math/bounds.h>
Expand Down Expand Up @@ -148,9 +150,13 @@ class CompTransientBounds : public BaseClass,
conds_.emplace_back(makeEQ(var, op->begin_));
}
}
this->pushFor(op);
if constexpr (std::is_base_of_v<SymbolTableInterface, BaseClass>) {
this->pushFor(op);
}
MAYBE_VOID(body, (*this)(op->body_));
this->popFor(op);
if constexpr (std::is_base_of_v<SymbolTableInterface, BaseClass>) {
this->popFor(op);
}
conds_.resize(oldCondsSize);
transients_.erase(var);

Expand Down
23 changes: 23 additions & 0 deletions include/pass/shrink_linear_indices.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef FREE_TENSOR_SHRINK_LINEAR_INDICES_H
#define FREE_TENSOR_SHRINK_LINEAR_INDICES_H

#include <stmt.h>

namespace freetensor {

/**
* Mutator for shrinking linear indices in variables
*
* If a variable is consistently accessed with a linear expression, e.g., `a[8i
* + 2j]`, and `2j` as a integer bound no larger than 8, e.g., `0 <= 2j < 4`,
* then we can shrink the expression to be `a[4i + 2j]`.
*
* @{
*/
Stmt shrinkLinearIndices(const Stmt &ast, const ID &vardef);
Stmt shrinkLinearIndices(const Stmt &ast);
/** @} */

} // namespace freetensor

#endif // FREE_TENSOR_SHRINK_LINEAR_INDICES_H
7 changes: 7 additions & 0 deletions include/pass/shrink_var.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@

namespace freetensor {

/**
* Main mutator for shrinking variables
*
* This mutator modifies the shape of each variable to be the upper bound
* expression minus the lower bound expression plus one, with respect to each
* access of the variable.
*/
class ShrinkVar : public Mutator {
// Bound considering the old shape. Used for preventing make the shape even
// larger after shrinking
Expand Down
198 changes: 198 additions & 0 deletions src/pass/shrink_linear_indices.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
#include <algorithm>
#include <unordered_map>

#include <analyze/all_defs.h>
#include <analyze/analyze_linear.h>
#include <analyze/comp_transient_bounds.h>
#include <analyze/comp_unique_bounds_combination.h>
#include <container_utils.h>
#include <math/utils.h>
#include <mutator.h>
#include <pass/shrink_linear_indices.h>
#include <visitor.h>

namespace freetensor {

namespace {

struct IntBound {
int64_t lower_, upper_;
};

class GatherLinearIndices : public CompTransientBounds<Visitor> {
typedef CompTransientBounds<Visitor> BaseClass;

ID vardef_;
std::string var_;

std::vector<std::unordered_map<int64_t /* coeff */, IntBound>> bounds_;

Ref<CompUniqueBounds> unique_;

public:
GatherLinearIndices(const ID &vardef) : vardef_(vardef) {}

const auto &bounds() const { return bounds_; }

private:
template <typename T> void visitAcc(const T &op) {
BaseClass::visit(op);
if (op->var_ == var_) {
ASSERT(bounds_.size() == op->indices_.size());
for (auto &&[idx, bound] : views::zip(op->indices_, bounds_)) {
auto lin = linear(idx);
for (auto &&[_k, a] : lin.coeff_) {
int k = _k;
auto l = unique_->getIntLower(a);
auto u = unique_->getIntUpper(a);
if (k < 0) {
k = -k;
l = -l;
u = -u;
std::swap(l, u);
}
if (!bound.count(k)) {
bound[k] = {l, u};
} else {
bound[k].lower_ = std::min(bound[k].lower_, l);
bound[k].upper_ = std::max(bound[k].upper_, u);
}
}
}
}
}

protected:
using BaseClass::visit;

void visitStmt(const Stmt &s) override {
// CompUniqueBounds requires one instance per Stmt
auto uniqueOfOuterStmt = unique_;
unique_ = Ref<CompUniqueBoundsCombination>::make(*this);
BaseClass::visitStmt(s);
unique_ = uniqueOfOuterStmt;
}

void visit(const VarDef &op) override {
if (op->id() == vardef_) {
var_ = op->name_;
bounds_.resize(op->buffer_->tensor()->shape().size());
BaseClass::visit(op);
var_.clear();
} else {
BaseClass::visit(op);
}
}

void visit(const Load &op) override { visitAcc(op); }
void visit(const Store &op) override { visitAcc(op); }
void visit(const ReduceTo &op) override { visitAcc(op); }
};

class ReplaceLinearIndices : public Mutator {
ID vardef_;
std::string var_;

const std::vector<std::unordered_map<int64_t, int64_t>> &replace_;

public:
ReplaceLinearIndices(
const ID &vardef,
const std::vector<std::unordered_map<int64_t, int64_t>> &replace)
: vardef_(vardef), replace_(replace) {}

private:
template <typename T> auto visitAcc(const T &_op) {
auto __op = Mutator::visit(_op);
ASSERT(__op->nodeType() == _op->nodeType());
auto op = __op.template as<typename T::Object>();
if (op->var_ == var_) {
for (auto &&[idx, rep] : views::zip(op->indices_, replace_)) {
auto lin = linear(idx);
for (auto &[k, a] : lin.coeff_) {
k = rep.at(k);
}
idx = lin2expr(lin);
}
}
return op;
}

protected:
Stmt visit(const VarDef &op) override {
if (op->id() == vardef_) {
var_ = op->name_;
auto ret = Mutator::visit(op);
var_.clear();
return ret;
} else {
return Mutator::visit(op);
}
}

Expr visit(const Load &op) override { return visitAcc(op); }
Stmt visit(const Store &op) override { return visitAcc(op); }
Stmt visit(const ReduceTo &op) override { return visitAcc(op); }
};

} // Anonymous namespace

Stmt shrinkLinearIndices(const Stmt &_ast, const ID &vardef) {
Stmt ast = _ast;

GatherLinearIndices gather{vardef};
gather(ast);
auto &&bounds = gather.bounds();

bool needMutation = false;
std::vector<std::unordered_map<int64_t, int64_t>> replaceCoeff;
for (auto &&_bound : bounds) {
auto bound =
_bound | ranges::to<std::vector<std::pair<int64_t, IntBound>>>();
std::sort(bound.begin(), bound.end(),
[](const auto &lhs, const auto &rhs) {
return lhs.first > rhs.first;
}); // Sort k from high to low
std::vector<int64_t> newCoeff =
bound | views::keys | ranges::to<std::vector>();
for (size_t n = bound.size(), i = n - 1; ~i; i--) {
int g = newCoeff[0];
for (size_t j = 1; j <= i; j++) {
g = gcd(g, newCoeff[j]);
}
int64_t l = LLONG_MAX, u = LLONG_MIN;
if (i + 1 < n) {
for (size_t j = i + 1; j < n; j++) {
l = std::min(l, newCoeff[j] * bound[j].second.lower_);
u = std::max(u, newCoeff[j] * bound[j].second.upper_);
}
} else {
l = u = 0;
}
if (u - l + 1 < g) {
for (size_t j = 0; j <= i; j++) {
newCoeff[j] = newCoeff[j] / g * (u - l + 1);
}
needMutation = true;
}
}
replaceCoeff.emplace_back(views::zip(bound | views::keys, newCoeff) |
ranges::to<std::unordered_map>());
}

if (needMutation) {
ast = ReplaceLinearIndices{vardef, replaceCoeff}(ast);
}

return ast;
}

Stmt shrinkLinearIndices(const Stmt &_ast) {
Stmt ast = _ast;
for (auto &&[varDefId, name] : allDefs(ast, {AccessType::Cache})) {
ast = shrinkLinearIndices(ast, varDefId);
}
return ast;
}

} // namespace freetensor
5 changes: 5 additions & 0 deletions src/pass/shrink_var.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <analyze/find_stmt.h>
#include <container_utils.h>
#include <pass/remove_dead_var.h>
#include <pass/shrink_linear_indices.h>
#include <pass/shrink_var.h>
#include <pass/simplify.h>
#include <pass/z3_simplify.h>
Expand Down Expand Up @@ -101,6 +102,8 @@ Stmt ShrinkVar::visit(const ReduceTo &_op) {
Stmt shrinkVar(const Stmt &_op) {
auto op = removeDeadVar(_op);

op = shrinkLinearIndices(op);

// Algorithm:
// (1) Represent the bounds of each vars with min / max expressions
// (2) Modify var definitions
Expand All @@ -125,6 +128,8 @@ Stmt shrinkVar(const Stmt &_op) {
Stmt shrinkSingleVar(const Stmt &_op, const ID &varDefId) {
auto op = removeDeadVar(_op);

op = shrinkLinearIndices(op, varDefId);

// (1)
std::unordered_map<ID, AccessBound> boundsWithShape, boundsWithoutShape;
boundsWithShape[varDefId] =
Expand Down
31 changes: 31 additions & 0 deletions test/20.pass/test_shrink_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,34 @@ def test_const_in_branch_2():
std = ft.pop_ast()

assert std.match(ast)


def test_over_relaxed_linear():
with ft.VarDef([("x", (12,), "int32", "input", "cpu"),
("y1", (12,), "int32", "output", "cpu"),
("y2", (12,), "int32", "output", "cpu")]) as (x, y1, y2):
with ft.VarDef("b", (1000,), "int32", "cache", "cpu") as b:
with ft.For("i", 0, 3) as i:
with ft.For("j", 0, 4) as j:
b[i * 100 + j * 10] = x[i * 4 + j]
with ft.For("i", 0, 3) as i:
with ft.For("j", 0, 4) as j:
y1[i * 4 + j] = b[i * 100 + j * 10] * i
y2[i * 4 + j] = b[i * 100 + j * 10] + i
ast = ft.pop_ast(verbose=True)
ast = ft.lower(ast, verbose=1)

with ft.VarDef([("x", (12,), "int32", "input", "cpu"),
("y1", (12,), "int32", "output", "cpu"),
("y2", (12,), "int32", "output", "cpu")]) as (x, y1, y2):
with ft.VarDef("b", (12,), "int32", "cache", "cpu") as b:
with ft.For("i", 0, 3) as i:
with ft.For("j", 0, 4) as j:
b[i * 4 + j] = x[i * 4 + j]
with ft.For("i", 0, 3) as i:
with ft.For("j", 0, 4) as j:
y1[i * 4 + j] = b[i * 4 + j] * i
y2[i * 4 + j] = b[i * 4 + j] + i
std = ft.pop_ast()

assert std.match(ast)

0 comments on commit c182842

Please sign in to comment.