Skip to content

Commit

Permalink
Miscellaneous fixes (#586)
Browse files Browse the repository at this point in the history
* Miscellaneous fixes

- Check data type in pass/simplify to avoid mixing different data types
  in one IfExpr node.
- Fix negative steps when detecting strides in pass/shrink_for
- Reduce redundant analysis in pass/shrink_for

* Minor fix

* Minor fix
  • Loading branch information
roastduck authored Jan 16, 2024
1 parent 7696f46 commit e768fb2
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 35 deletions.
6 changes: 6 additions & 0 deletions include/analyze/comp_unique_bounds_pb.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ class CompUniqueBoundsPB : public CompUniqueBounds {
Expr lowerExpr() const override;
Expr upperExpr() const override;

/**
* Fused function returning `lowerExpr()`, `upperExpr()` and
* `upperExpr() - lowerExpr()`, with less redundant computation
*/
std::tuple<Expr, Expr, Expr> lowerUpperDiffExpr() const;

Ref<CompUniqueBounds::Bound> restrictScope(
const std::unordered_set<std::string> &scope) const override;

Expand Down
4 changes: 4 additions & 0 deletions include/math/presburger.h
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,10 @@ template <PBMapRef T> PBMap coalesce(T &&map) {
return isl_map_coalesce(PBRefTake<T>(map));
}

template <PBSetRef T, PBSetRef U> PBSet cartesianProduct(T &&lhs, U &&rhs) {
return isl_set_flat_product(PBRefTake<T>(lhs), PBRefTake<U>(rhs));
}

template <PBSetRef T> PBVal dimMaxVal(T &&set, int pos) {
return isl_set_dim_max_val(PBRefTake<T>(set), pos);
}
Expand Down
13 changes: 13 additions & 0 deletions src/analyze/comp_unique_bounds_pb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ Expr CompUniqueBoundsPB::Bound::upperExpr() const {
? translateBoundFunc(*ctx_, lexmax(bound_), *demangleMap_)
: nullptr;
}
std::tuple<Expr, Expr, Expr>
CompUniqueBoundsPB::Bound::lowerUpperDiffExpr() const {
PBSet l = bound_.hasLowerBound(0) ? lexmin(bound_) : PBSet();
PBSet u = bound_.hasUpperBound(0) ? lexmax(bound_) : PBSet();
PBSet diff =
l.isValid() && u.isValid()
? apply(cartesianProduct(u, l), PBMap(*ctx_, "{[u, l] -> [u - l]}"))
: PBSet();
return {l.isValid() ? translateBoundFunc(*ctx_, l, *demangleMap_) : nullptr,
u.isValid() ? translateBoundFunc(*ctx_, u, *demangleMap_) : nullptr,
diff.isValid() ? translateBoundFunc(*ctx_, diff, *demangleMap_)
: nullptr};
}

Ref<CompUniqueBounds::Bound> CompUniqueBoundsPB::Bound::restrictScope(
const std::unordered_set<std::string> &scope) const {
Expand Down
100 changes: 66 additions & 34 deletions src/pass/shrink_for.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,32 +70,40 @@ class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB {
CompUniqueBoundsPBWithStride(const CompTransientBoundsInterface &transients)
: CompUniqueBoundsPB(transients) {}

std::tuple<Expr /* lower */, Expr /* upper */, int64_t /* modulo */,
Expr /* offset */>
std::tuple<Expr /* lower */, Expr /* upper */, Expr /* upper - lower */,
int64_t /* modulo */, Expr /* offset */>
unionBoundsAndGetStride(
const std::vector<Ref<CompUniqueBounds::Bound>> &bounds,
bool requireConst) {
auto bound = unionBoundsAsBound(bounds);

// if no bound presented, return an empty range
if (!bound.isValid()) {
return {makeIntConst(0), makeIntConst(-1), 1, makeIntConst(0)};
return {makeIntConst(0), makeIntConst(-1), makeIntConst(0), 1,
makeIntConst(0)};
}

// translate the lower and upper bounds back to expression
auto l =
requireConst ? makeIntConst(bound->lowerInt()) : bound->lowerExpr();
auto u =
requireConst ? makeIntConst(bound->upperInt()) : bound->upperExpr();
Expr l, u, diff;
if (requireConst) {
auto ll = bound->lowerInt();
auto uu = bound->upperInt();
l = makeIntConst(ll);
u = makeIntConst(uu);
diff = makeIntConst(uu - ll);
} else {
std::tie(l, u, diff) = bound->lowerUpperDiffExpr();
}

// Addition detction for strides
auto [strideInt, offsetExpr] = getStride(bound, requireConst);

return {l, u, strideInt, offsetExpr};
return {l, u, diff, strideInt, offsetExpr};
}

std::vector<std::tuple<Expr /* lower */, Expr /* upper */,
int64_t /* modulo */, Expr /* offset */>>
std::vector<
std::tuple<Expr /* lower */, Expr /* upper */, Expr /* upper - lower */,
int64_t /* modulo */, Expr /* offset */>>
unionBoundsAndGetHighOrderStride(
const std::vector<Ref<CompUniqueBounds::Bound>> &bounds,
bool requireConst) {
Expand All @@ -119,24 +127,37 @@ class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB {
isl_map_wrap(isl_map_reverse(isl_set_unwrap(set.move()))));

ASSERT(set.nDims() >= 1);
std::vector<std::tuple<Expr, Expr, int64_t, Expr>> ret;
std::vector<std::tuple<Expr, Expr, Expr, int64_t, Expr>> ret;
ret.reserve(set.nDims());
auto demangleMap = *bound->demangleMap_;
for (int i = 0;; i++) {
int i = 0;
while (true) {
// Project onto the loop we are checking
PBSet thisLoopSet = projectOutDims(set, 1, set.nDims() - 1);
if (thisLoopSet.isSingleValued() && set.nDims() > 1) {
// This dimension has no contribution. But the last dim in `set`
// must no be skipped, because it is the target loop
set = projectOutDims(std::move(set), 0, 1);
continue;
}

auto thisLoopBound = Ref<CompUniqueBoundsPB::Bound>::make(
bound->ctx_,
Ref<std::unordered_map<std::string, Expr>>::make(demangleMap),
thisLoopSet);
auto l = requireConst ? makeIntConst(bound->lowerInt())
: thisLoopBound->lowerExpr();
auto u = requireConst ? makeIntConst(bound->upperInt())
: thisLoopBound->upperExpr();
Expr l, u, diff;
if (requireConst) {
auto ll = thisLoopBound->lowerInt();
auto uu = thisLoopBound->upperInt();
l = makeIntConst(ll);
u = makeIntConst(uu);
diff = makeIntConst(uu - ll);
} else {
std::tie(l, u, diff) = thisLoopBound->lowerUpperDiffExpr();
}
auto [strideInt, offsetExpr] =
getStride(thisLoopBound, requireConst);
ret.emplace_back(l, u, strideInt, offsetExpr);
ret.emplace_back(l, u, diff, strideInt, offsetExpr);

if (set.nDims() == 1) {
break;
Expand All @@ -145,7 +166,7 @@ class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB {
// dimensions to parameter dimensions, so inner loops will be
// represented by outer loops. The parameter name used here is
// temporary, and will be replaced later.
auto paramName = "ft_shrink_for_tmp_" + std::to_string(i);
auto paramName = "ft_shrink_for_tmp_" + std::to_string(i++);
set = moveDimToNamedParam(*bound->ctx_, std::move(set), 0,
paramName);
demangleMap[paramName] = makeVar(paramName);
Expand Down Expand Up @@ -262,7 +283,8 @@ Stmt ShrinkFor::visit(const For &_op) {
std::unordered_map<std::string, Expr> replace;
Stmt ret = op->body_;
for (auto &&[i, item] : views::reverse(views::enumerate(info))) {
auto &&[lower, upper, stride, offset] = item;
auto &&[lower, upper, diff, stride, offset] = item;
ASSERT(stride > 0);

// The last (first before we reverse it) iter is the original iter.
// Keep its name. The others are renamed.
Expand All @@ -274,17 +296,22 @@ Stmt ShrinkFor::visit(const For &_op) {
replace["ft_shrink_for_tmp_" + std::to_string(i)] =
makeVar(thisIterName);

// Find the lowest integer after `lower` that remains `offset`
// modulo `stride`: lowerOnOffset = lower + ((offset - lower) %
// stride + stride) % stride
auto begin =
makeAdd(lower, makeMod(makeAdd(makeMod(makeSub(offset, lower),
makeIntConst(stride)),
makeIntConst(stride)),
makeIntConst(stride)));
auto begin = lower;
auto end = makeAdd(upper, makeIntConst(1));
auto len = makeAdd(diff, makeIntConst(1));
if (stride > 1) {
// Find the lowest integer after `lower` that remains `offset`
// modulo `stride`: lowerOnOffset = lower + ((offset - lower) %
// stride + stride) % stride
begin = makeAdd(lower,
makeMod(makeAdd(makeMod(makeSub(offset, lower),
makeIntConst(stride)),
makeIntConst(stride)),
makeIntConst(stride)));
len = makeAdd(makeFloorDiv(diff, makeIntConst(stride)),
makeIntConst(1));
}
auto step = makeIntConst(stride);
auto len = makeCeilDiv(makeSub(end, begin), step);

ret = makeFor(thisIterName, std::move(begin), std::move(end),
std::move(step), std::move(len), op->property_,
Expand All @@ -299,8 +326,9 @@ Stmt ShrinkFor::visit(const For &_op) {
return ret;

} else {
auto [lower, upper, stride, offset] =
auto [lower, upper, diff, stride, offset] =
bound.unionBoundsAndGetStride(newRange_[var], requireConst);
ASSERT(stride > 0);

// Since we can't normalize the loops (see the comment in shrinkFor), we
// have to handle step_ here.
Expand All @@ -319,19 +347,21 @@ Stmt ShrinkFor::visit(const For &_op) {
makeIntConst(stride)),
makeIntConst(stride)),
makeIntConst(stride)));
op->len_ =
makeAdd(makeFloorDiv(diff, makeIntConst(stride)),
makeIntConst(1));
} else {
op->begin_ = lower;
op->len_ = makeAdd(diff, makeIntConst(1));
}
}
if (upper.isValid()) {
op->end_ = makeAdd(upper, makeIntConst(1));
}
op->step_ = makeIntConst(stride);
op->len_ =
makeCeilDiv(makeSub(op->end_, op->begin_), op->step_);
} else if (step < 0) {
if (upper.isValid()) {
if (stride < -1) {
if (stride > 1) {
// Find the highest integer before `upper` that remains
// `offset` modulo `stride`: upperOnOffset = upper -
// ((upper - offset) % stride + stride) % stride
Expand All @@ -341,16 +371,18 @@ Stmt ShrinkFor::visit(const For &_op) {
makeIntConst(stride)),
makeIntConst(stride)),
makeIntConst(stride)));
op->len_ =
makeAdd(makeFloorDiv(diff, makeIntConst(stride)),
makeIntConst(1));
} else {
op->begin_ = upper;
op->len_ = makeAdd(diff, makeIntConst(1));
}
}
if (lower.isValid()) {
op->end_ = makeAdd(lower, makeIntConst(-1));
}
op->step_ = makeIntConst(-stride);
op->len_ =
makeCeilDiv(makeSub(op->end_, op->begin_), op->step_);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/pass/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,8 @@ Expr SimplifyPass::visit(const IfExpr &_op) {
} else if (op->thenCase_->nodeType() == ASTNodeType::Cast) {
auto &&thenCase = op->thenCase_.as<CastNode>();
auto &&elseCase = op->elseCase_.as<CastNode>();
if (thenCase->destType_ == elseCase->destType_) {
if (thenCase->destType_ == elseCase->destType_ &&
thenCase->expr_->dtype() == elseCase->expr_->dtype()) {
return makeCast(
makeIfExpr(op->cond_, thenCase->expr_, elseCase->expr_),
thenCase->destType_);
Expand Down

0 comments on commit e768fb2

Please sign in to comment.