diff --git a/src/pass/shrink_for.cc b/src/pass/shrink_for.cc index 194114268..502a3f9ce 100644 --- a/src/pass/shrink_for.cc +++ b/src/pass/shrink_for.cc @@ -228,7 +228,7 @@ Stmt ShrinkFor::visitStmt(const Stmt &stmt) { default:; } if (checker.hasSideEffect()) { - for (auto &&[_var, _names] : views::zip(iterStack_, namesStack_)) { + for (auto &&[var, _names] : views::zip(iterStack_, namesStack_)) { auto &&names = filterNames(_names); // We need linear programming from PBCompBounds, because the @@ -239,8 +239,6 @@ Stmt ShrinkFor::visitStmt(const Stmt &stmt) { // PBCompBounds requires one instance per Stmt CompUniqueBoundsPBWithStride bound(*this); - // Trigger recomputing in analyze/comp_unique_bounds - auto var = deepCopy(_var).as(); newRange_[var].emplace_back( bound.getBound(var)->restrictScope(names)); } @@ -253,16 +251,17 @@ Stmt ShrinkFor::visit(const For &_op) { auto var = makeVar(_op->iter_).as(); newRange_.erase(var); - iterStack_.emplace_back(var); - namesStack_.emplace_back(names()); - auto __op = BaseClass::visit(_op); - ASSERT(__op->nodeType() == ASTNodeType::For); - auto op = __op.as(); - namesStack_.pop_back(); - iterStack_.pop_back(); - - if ((subAST_.isValid() && !inSubAST_) || !filterLoop(op)) { - return op; + For op; + if ((subAST_.isValid() && !inSubAST_) || !filterLoop(_op)) { + return BaseClass::visit(_op); + } else { + iterStack_.emplace_back(var); + namesStack_.emplace_back(names()); + auto __op = BaseClass::visit(_op); + ASSERT(__op->nodeType() == ASTNodeType::For); + op = __op.as(); + namesStack_.pop_back(); + iterStack_.pop_back(); } if (!newRange_.count(var)) {