diff --git a/ffi/pass.cc b/ffi/pass.cc
index 193afbd8e..cd6dc1233 100644
--- a/ffi/pass.cc
+++ b/ffi/pass.cc
@@ -85,12 +85,15 @@ void init_ffi_pass(py::module_ &m) {
           "stmt"_a);
     m.def("shrink_for",
-          static_cast<Func (*)(const Func &, const ID &, bool)>(
-              &shrinkFor),
-          "func"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true);
-    m.def("shrink_for",
-          static_cast<Stmt (*)(const Stmt &, const ID &, bool)>(&shrinkFor),
-          "stmt"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true);
+          static_cast<Func (*)(const Func &, const ID &, bool, bool)>(
+              &shrinkFor),
+          "func"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true,
+          "unordered"_a = false);
+    m.def(
+        "shrink_for",
+        static_cast<Stmt (*)(const Stmt &, const ID &, bool, bool)>(&shrinkFor),
+        "stmt"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true,
+        "unordered"_a = false);
     m.def("merge_and_hoist_if",
           static_cast<Func (*)(const Func &)>(&mergeAndHoistIf),
           "func"_a);
diff --git a/include/math/parse_pb_expr.h b/include/math/parse_pb_expr.h
index 6a88cf8e6..2189b3359 100644
--- a/include/math/parse_pb_expr.h
+++ b/include/math/parse_pb_expr.h
@@ -29,16 +29,37 @@ typedef std::vector<SimplePBFuncAST> PBFuncAST;
  */
 PBFuncAST parsePBFunc(const std::string &str);
 
-/**
- * Parse a PBFunc to be ASTs, but only restricted to one contiguous factor
- */
-SimplePBFuncAST parseSimplePBFunc(const std::string &str);
-
 /**
  * Construct AST from PBSet while preserving min and max with a special hack to
  * ISL
+ *
+ * @{
  */
 PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set);
+PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBMap &map);
+/** @} */
+
+/**
+ * Parse a PBFunc into ASTs, restricted to one contiguous factor
+ *
+ * @{
+ */
+inline SimplePBFuncAST parseSimplePBFunc(const std::string &str) {
+    auto ret = parsePBFunc(str);
+    if (ret.size() != 1) {
+        throw ParserError(str + " is not a simple PBFunc");
+    }
+    return ret.front();
+}
+inline SimplePBFuncAST parseSimplePBFuncReconstructMinMax(const PBCtx &ctx,
+                                                          const auto &f) {
+    auto ret = parsePBFuncReconstructMinMax(ctx, f);
+    if (ret.size() != 1) {
+        throw ParserError(FT_MSG << f << " is not a simple PBFunc");
+    }
+    return ret.front();
+}
+/** @} */
 
 } // namespace freetensor
diff --git a/include/math/presburger.h b/include/math/presburger.h
index 31f41e167..60b6a0115 100644
--- a/include/math/presburger.h
+++ b/include/math/presburger.h
@@ -605,6 +605,13 @@ template <PBMapRef T> PBMap upperBoundOutputDim(T &&map, unsigned pos, int x) {
     return isl_map_upper_bound_si(PBRefTake(map), isl_dim_out, pos, x);
 }
 
+template <PBSetRef T> PBMap newDomainOnlyMap(T &&set) {
+    return isl_map_from_domain(PBRefTake(set));
+}
+template <PBSetRef T> PBMap newRangeOnlyMap(T &&set) {
+    return isl_map_from_range(PBRefTake(set));
+}
+
 template <PBMapRef T>
 PBMap moveDimsInputToOutput(T &&map, unsigned first, unsigned n,
                             unsigned target) {
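For intuition, `newDomainOnlyMap` and `newRangeOnlyMap` wrap `isl_map_from_domain` and `isl_map_from_range`, which turn a set into a map whose range (respectively domain) is the empty tuple. A quick sketch of the semantics using the islpy bindings; islpy is used here only as an illustration and is not a dependency of this patch:

import islpy as isl

s = isl.Set("{ [i] : 0 <= i < 4 }")
# Analogue of newDomainOnlyMap: the set becomes the domain, range is empty
print(isl.Map.from_domain(s))  # roughly: { [i] -> [] : 0 <= i <= 3 }
# Analogue of newRangeOnlyMap: the set becomes the range, domain is empty
print(isl.Map.from_range(s))   # roughly: { [] -> [i] : 0 <= i <= 3 }
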
diff --git a/include/pass/shrink_for.h b/include/pass/shrink_for.h
index 91ff8c487..1872749e2 100644
--- a/include/pass/shrink_for.h
+++ b/include/pass/shrink_for.h
@@ -27,6 +27,8 @@ class CheckSideEffect : public Visitor {
 class ShrinkFor : public CompTransientBounds<SymbolTable<Mutator>> {
     typedef CompTransientBounds<SymbolTable<Mutator>> BaseClass;
 
+    bool unordered_;
+
     ASTHashMap<Var, std::vector<Ref<CompUniqueBounds::Bound>>> newRange_;
     std::vector<Var> iterStack_;
     std::vector<std::unordered_set<std::string>> namesStack_;
@@ -36,6 +38,8 @@ class ShrinkFor : public CompTransientBounds<SymbolTable<Mutator>> {
     bool inSubAST_ = false;
 
   public:
+    ShrinkFor(bool unordered = false) : unordered_(unordered) {}
+
     void setSubAST(const Stmt &subAST);
 
   protected:
@@ -63,13 +67,19 @@
 * @param doSimplify : If true, run simplify before and after the transformation.
 * Transformations are required to ensure the effectiveness of the shrinking.
 * Please do your own simplification if you want to set it to false.
+ * @param unordered : If true, shrink the loops aggressively, which does NOT
+ * preserve the iteration order. The caller should guarantee correctness. This
+ * is effective when the pattern of redundant iterations is complex. It may
+ * also split one loop into multiple loops.
 *
 * @{
 */
-Stmt shrinkFor(const Stmt &op, const ID &subAST = ID(), bool doSimplify = true);
+Stmt shrinkFor(const Stmt &op, const ID &subAST = ID(), bool doSimplify = true,
+               bool unordered = false);
 inline Stmt shrinkFor(const Stmt &op, const Stmt &subAST,
-                      bool doSimplify = true) {
-    return shrinkFor(op, subAST.isValid() ? subAST->id() : ID(), doSimplify);
+                      bool doSimplify = true, bool unordered = false) {
+    return shrinkFor(op, subAST.isValid() ? subAST->id() : ID(), doSimplify,
+                     unordered);
 }
 /** @} */
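At the Python level the new flag surfaces through the `shrink_for` binding registered in ffi/pass.cc above. A minimal usage sketch, assuming the binding is re-exported at the package root as the other FreeTensor passes are:

import freetensor as ft

# `ast` is any Stmt or Func. With unordered=True the pass may reorder or
# split loops, so it should only be used when the caller has already
# proven (e.g. by dependence analysis) that the order does not matter.
shrunk = ft.shrink_for(ast, unordered=True)
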
{ isl_ast_expr_free(expr); throw; @@ -324,6 +312,13 @@ PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set) { ASSERT(set.isSingleValued()); + std::vector params = + views::ints(0, set.nParamDims()) | + views::transform([&](int i) -> std::string { + return isl_set_get_dim_name(set.get(), isl_dim_param, i); + }) | + ranges::to(); + isl_options_set_ast_build_detect_min_max(ctx.get(), 1); PBFuncAST ret; @@ -333,7 +328,9 @@ PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set) { isl_schedule_from_domain(isl_union_set_from_set(set.copy())); isl_ast_node *ast = isl_ast_build_node_from_schedule(build /* keep */, s /* take */); - ret = isl2Func(ast /* take */); + for (auto &&[vals, cond] : isl2Func(ast /* take */)) { + ret.emplace_back(params, vals, cond); + } } catch (...) { isl_ast_build_free(build); throw; @@ -343,4 +340,36 @@ PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set) { return ret; } +namespace { + +template PBMap moveAllInputDimsToParam(const PBCtx &ctx, T &&map) { + // A name is required for the parameter, so we can't simply use + // isl_map_move_dims. We constuct a map to apply on the set to move the + // dimension. Example map: [i1, i2] -> {[i1, i2] -> []}. The parameters are + // assigned with temporary names. + + int nInDims = map.nInDims(); + std::ostringstream os; + os << "[" + << (views::ints(0, nInDims) | views::transform([](int i) { + return "ft_unnamed_in_dim_" + std::to_string(i); + }) | + join(",")) + << "] -> {[" + << (views::ints(0, nInDims) | views::transform([](int i) { + return "ft_unnamed_in_dim_" + std::to_string(i); + }) | + join(",")) + << "] -> []}"; + PBMap moving(ctx, os.str()); + return applyDomain(std::forward(map), std::move(moving)); +} + +} // Anonymous namespace + +PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBMap &map) { + return parsePBFuncReconstructMinMax( + ctx, range(moveAllInputDimsToParam(ctx, map))); +} + } // namespace freetensor diff --git a/src/pass/shrink_for.cc b/src/pass/shrink_for.cc index 87281b268..c64fa22ff 100644 --- a/src/pass/shrink_for.cc +++ b/src/pass/shrink_for.cc @@ -1,7 +1,15 @@ +#include +#include +#include + +#include #include #include +#include +#include #include #include +#include #include #include #include @@ -11,7 +19,53 @@ namespace freetensor { namespace { +template +PBSet moveDimToNamedParam(const PBCtx &ctx, T &&set, int dim, + const std::string ¶m) { + // A name is required for the parameter, so we can't simply use + // isl_set_move_dims. We constuct a map to apply on the set to move the + // dimension. Example map: [p] -> {[_1, _2, p] -> [_1, _2]} + + int nDims = set.nDims(); + std::ostringstream os; + os << "[" << param << "] -> {[" + << (views::ints(0, nDims) | views::transform([&](int i) { + return i == dim ? 
diff --git a/src/pass/shrink_for.cc b/src/pass/shrink_for.cc
index 87281b268..c64fa22ff 100644
--- a/src/pass/shrink_for.cc
+++ b/src/pass/shrink_for.cc
@@ -1,7 +1,15 @@
+#include
+#include
+#include
+
+#include
 #include
 #include
+#include
+#include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -11,7 +19,53 @@ namespace freetensor {
 
 namespace {
 
+template <PBSetRef T>
+PBSet moveDimToNamedParam(const PBCtx &ctx, T &&set, int dim,
+                          const std::string &param) {
+    // A name is required for the parameter, so we can't simply use
+    // isl_set_move_dims. We construct a map to apply on the set to move the
+    // dimension. Example map: [p] -> {[_0, _1, p] -> [_0, _1]}
+
+    int nDims = set.nDims();
+    std::ostringstream os;
+    os << "[" << param << "] -> {["
+       << (views::ints(0, nDims) | views::transform([&](int i) {
+               return i == dim ? param : "_" + std::to_string(i);
+           }) |
+           join(","))
+       << "] -> ["
+       << (views::ints(0, nDims) |
+           views::filter([&](int i) { return i != dim; }) |
+           views::transform([](int i) { return "_" + std::to_string(i); }) |
+           join(","))
+       << "]}";
+    PBMap map(ctx, os.str());
+    return apply(std::forward<T>(set), std::move(map));
+}
+
 class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB {
+  private:
+    std::pair<int64_t, Expr>
+    getStride(const Ref<CompUniqueBoundsPB::Bound> &bound, bool requireConst) {
+        isl_stride_info *info = isl_set_get_stride_info(bound->bound_.get(), 0);
+        auto stride = PBVal(isl_stride_info_get_stride(info));
+        auto offset = PBSingleFunc(isl_stride_info_get_offset(info));
+        isl_stride_info_free(info);
+        ASSERT(stride.denSi() == 1);
+        auto strideInt = stride.numSi();
+        ReplaceIter demangler(*bound->demangleMap_);
+        auto offsetSimpleFunc = parseSimplePBFunc(toString(offset));
+        // offsetSimpleFunc.args_ should be a dummy variable equal to `bound`'s
+        // value. Leave it.
+        ASSERT(offsetSimpleFunc.values_.size() == 1);
+        auto offsetExpr = demangler(offsetSimpleFunc.values_[0]);
+        if (requireConst && !HashComparator{}(offsetExpr, makeIntConst(0))) {
+            strideInt = 1;
+            offsetExpr = makeIntConst(0);
+        }
+        return {strideInt, offsetExpr};
+    }
+
   public:
     CompUniqueBoundsPBWithStride(const CompTransientBoundsInterface &transients)
         : CompUniqueBoundsPB(transients) {}
@@ -19,7 +73,8 @@ class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB {
     std::tuple<Expr, Expr, int64_t, Expr>
     unionBoundsAndGetStride(
-        const std::vector<Ref<CompUniqueBounds::Bound>> &bounds) {
+        const std::vector<Ref<CompUniqueBounds::Bound>> &bounds,
+        bool requireConst) {
         auto bound = unionBoundsAsBound(bounds);
 
         // if no bound is present, return an empty range
@@ -28,25 +83,77 @@ class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB {
         }
 
         // translate the lower and upper bounds back to expression
-        auto l = bound->lowerExpr();
-        auto u = bound->upperExpr();
+        auto l =
+            requireConst ? makeIntConst(bound->lowerInt()) : bound->lowerExpr();
+        auto u =
+            requireConst ? makeIntConst(bound->upperInt()) : bound->upperExpr();
 
         // Additional detection for strides
-        isl_stride_info *info = isl_set_get_stride_info(bound->bound_.get(), 0);
-        auto stride = PBVal(isl_stride_info_get_stride(info));
-        auto offset = PBSingleFunc(isl_stride_info_get_offset(info));
-        isl_stride_info_free(info);
-        ASSERT(stride.denSi() == 1);
-        auto strideInt = stride.numSi();
-        ReplaceIter demangler(*bound->demangleMap_);
-        auto offsetSimpleFunc = parseSimplePBFunc(toString(offset));
-        // offsetSimpleFunc.args_ should be a dummy variable equals to `bound`'s
-        // value. Leave it.
-        ASSERT(offsetSimpleFunc.values_.size() == 1);
-        auto offsetExpr = demangler(offsetSimpleFunc.values_[0]);
+        auto [strideInt, offsetExpr] = getStride(bound, requireConst);
 
         return {l, u, strideInt, offsetExpr};
     }
+
+    std::vector<std::tuple<Expr, Expr, int64_t, Expr>>
+    unionBoundsAndGetHighOrderStride(
+        const std::vector<Ref<CompUniqueBounds::Bound>> &bounds,
+        bool requireConst) {
+        auto bound = unionBoundsAsBound(bounds);
+
+        // if no bound is present, return an empty loop nest
+        if (!bound.isValid()) {
+            return {};
+        }
+
+        PBSet set = bound->bound_;
+
+        // Reveal local dimensions
+        set = isl_set_lift(set.move());
+
+        // Put local dimensions at the front, so we can represent the target
+        // dimension by local dimensions, instead of representing local
+        // dimensions by the target dimension. The set returned by isl_set_lift
+        // is a wrapped set, so we can simply unwrap it and then reverse it.
+        set = isl_set_flatten(
+            isl_map_wrap(isl_map_reverse(isl_set_unwrap(set.move()))));
+
+        ASSERT(set.nDims() >= 1);
+        std::vector<std::tuple<Expr, Expr, int64_t, Expr>> ret;
+        ret.reserve(set.nDims());
+        auto demangleMap = *bound->demangleMap_;
+        for (int i = 0;; i++) {
+            // Project onto the loop we are checking
+            PBSet thisLoopSet = projectOutDims(set, 1, set.nDims() - 1);
+
+            auto thisLoopBound = Ref<CompUniqueBoundsPB::Bound>::make(
+                bound->ctx_,
+                Ref<std::unordered_map<std::string, Expr>>::make(demangleMap),
+                thisLoopSet);
+            auto l = requireConst ? makeIntConst(thisLoopBound->lowerInt())
+                                  : thisLoopBound->lowerExpr();
+            auto u = requireConst ? makeIntConst(thisLoopBound->upperInt())
+                                  : thisLoopBound->upperExpr();
+            auto [strideInt, offsetExpr] =
+                getStride(thisLoopBound, requireConst);
+            ret.emplace_back(l, u, strideInt, offsetExpr);
+
+            if (set.nDims() == 1) {
+                break;
+            } else {
+                // As we go from outer loops to inner loops, we will move range
+                // dimensions to parameter dimensions, so inner loops will be
+                // represented by outer loops. The parameter name used here is
+                // temporary, and will be replaced later.
+                auto paramName = "ft_shrink_for_tmp_" + std::to_string(i);
+                set = moveDimToNamedParam(*bound->ctx_, std::move(set), 0,
+                                          paramName);
+                demangleMap[paramName] = makeVar(paramName);
+            }
+        }
+
+        return ret;
+    }
 };
 
 } // Anonymous namespace
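The key trick in `unionBoundsAndGetHighOrderStride` is `isl_set_lift`, which turns existentially quantified (local) dimensions into explicit dimensions of a wrapped set, so each of them can be bounded and strided separately. A small illustration with islpy (again not part of the patch; the printed form may differ):

import islpy as isl

# Only the points i = 8j + 2k with 0 <= j < 8, 0 <= k < 2 are touched
s = isl.Set("{ [i] : exists j, k : i = 8j + 2k "
            "and 0 <= j < 8 and 0 <= k < 2 }")
print(s.lift())
# roughly: { [[i] -> [...]] : ... } -- the existential dims become explicit
# dims of a wrapped set, which the pass unwraps, reverses, and peels off
# one dimension at a time
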
@@ -144,68 +251,111 @@ Stmt ShrinkFor::visit(const For &_op) {
 
     // PBCompBounds requires one instance per Stmt
     CompUniqueBoundsPBWithStride bound(*this);
-    auto [lower, upper, stride, offset] =
-        bound.unionBoundsAndGetStride(newRange_[var]);
-
-    if (op->property_->unroll_) {
-        // Backends do not support these loops to be of variable lengths
-        lower = makeIntConst(bound.getIntLower(lower));
-        upper = makeIntConst(bound.getIntUpper(upper));
-        if (!HashComparator{}(offset, makeIntConst(0))) {
-            stride = 1;
-            offset = makeIntConst(0);
-        }
-    }
-
-    // Since we can't normalize the loops (see the comment in shrinkFor), we
-    // have to handle step_ here.
-    if (op->step_->nodeType() == ASTNodeType::IntConst) {
-        auto step = op->step_.as<IntConstNode>()->val_;
-        ASSERT(stride % step == 0);
-        if (step > 0) {
-            if (lower.isValid()) {
-                if (stride > 1) {
-                    // Find the lowest integer after `lower` that remains
-                    // `offset` modulo `stride`: lowerOnOffset = lower +
-                    // ((offset - lower) % stride + stride) % stride
-                    op->begin_ = makeAdd(
-                        lower, makeMod(makeAdd(makeMod(makeSub(offset, lower),
-                                                       makeIntConst(stride)),
-                                               makeIntConst(stride)),
-                                       makeIntConst(stride)));
-                } else {
-                    op->begin_ = lower;
-                }
-            }
-            if (upper.isValid()) {
-                op->end_ = makeAdd(upper, makeIntConst(1));
+    // Backends do not support these loops to be of variable lengths
+    bool requireConst = op->property_->unroll_;
+
+    if (unordered_ && op->step_->nodeType() == ASTNodeType::IntConst &&
+        op->property_->parallel_ == serialScope) {
+        auto info = bound.unionBoundsAndGetHighOrderStride(newRange_[var],
+                                                           requireConst);
+        std::unordered_set<std::string> usedNames = uni(names(), allNames(op));
+        std::unordered_map<std::string, Expr> replace;
+        Stmt ret = op->body_;
+        for (auto &&[i, item] : views::reverse(views::enumerate(info))) {
+            auto &&[lower, upper, stride, offset] = item;
+
+            // The last (first before we reverse it) iter is the original iter.
+            // Keep its name. The others are renamed.
+            auto thisIterName = op->iter_;
+            if (i != info.size() - 1) {
+                thisIterName = getNewName(op->iter_, usedNames);
+                usedNames.emplace(thisIterName);
             }
-            op->step_ = makeIntConst(stride);
-            op->len_ = makeCeilDiv(makeSub(op->end_, op->begin_), op->step_);
-        } else if (step < 0) {
-            if (upper.isValid()) {
-                if (stride < -1) {
-                    // Find the highest integer before `upper` that remains
-                    // `offset` modulo `stride`: upperOnOffset = upper -
-                    // ((upper - offset) % stride + stride) % stride
-                    op->begin_ = makeSub(
-                        upper, makeMod(makeAdd(makeMod(makeSub(upper, offset),
-                                                       makeIntConst(stride)),
-                                               makeIntConst(stride)),
-                                       makeIntConst(stride)));
-                } else {
-                    op->begin_ = upper;
+            replace["ft_shrink_for_tmp_" + std::to_string(i)] =
+                makeVar(thisIterName);
+
+            // Find the lowest integer after `lower` that remains `offset`
+            // modulo `stride`: lowerOnOffset = lower + ((offset - lower) %
+            // stride + stride) % stride
+            auto begin =
+                makeAdd(lower, makeMod(makeAdd(makeMod(makeSub(offset, lower),
+                                                       makeIntConst(stride)),
+                                               makeIntConst(stride)),
+                                       makeIntConst(stride)));
+            auto end = makeAdd(upper, makeIntConst(1));
+            auto step = makeIntConst(stride);
+            auto len = makeCeilDiv(makeSub(end, begin), step);
+
+            ret = makeFor(thisIterName, std::move(begin), std::move(end),
+                          std::move(step), std::move(len), op->property_,
+                          std::move(ret));
+        }
+        ret = ReplaceIter{replace}(ret);
+
+        // Assign the old ID and metadata to the outer-most new loop
+        ret->setId(op->id());
+        ret->metadata() = op->metadata();
+
+        return ret;
+
+    } else {
+        auto [lower, upper, stride, offset] =
+            bound.unionBoundsAndGetStride(newRange_[var], requireConst);
+
+        // Since we can't normalize the loops (see the comment in shrinkFor),
+        // we have to handle step_ here.
+        if (op->step_->nodeType() == ASTNodeType::IntConst) {
+            auto step = op->step_.as<IntConstNode>()->val_;
+            ASSERT(stride % step == 0);
+            if (step > 0) {
+                if (lower.isValid()) {
+                    if (stride > 1) {
+                        // Find the lowest integer after `lower` that remains
+                        // `offset` modulo `stride`: lowerOnOffset = lower +
+                        // ((offset - lower) % stride + stride) % stride
+                        op->begin_ = makeAdd(
+                            lower,
+                            makeMod(makeAdd(makeMod(makeSub(offset, lower),
+                                                    makeIntConst(stride)),
+                                            makeIntConst(stride)),
+                                    makeIntConst(stride)));
+                    } else {
+                        op->begin_ = lower;
+                    }
                 }
+                if (upper.isValid()) {
+                    op->end_ = makeAdd(upper, makeIntConst(1));
+                }
+                op->step_ = makeIntConst(stride);
+                op->len_ =
+                    makeCeilDiv(makeSub(op->end_, op->begin_), op->step_);
+            } else if (step < 0) {
+                if (upper.isValid()) {
+                    if (stride < -1) {
+                        // Find the highest integer before `upper` that remains
+                        // `offset` modulo `stride`: upperOnOffset = upper -
+                        // ((upper - offset) % stride + stride) % stride
+                        op->begin_ = makeSub(
+                            upper,
+                            makeMod(makeAdd(makeMod(makeSub(upper, offset),
+                                                    makeIntConst(stride)),
+                                            makeIntConst(stride)),
+                                    makeIntConst(stride)));
+                    } else {
+                        op->begin_ = upper;
+                    }
+                }
+                if (lower.isValid()) {
+                    op->end_ = makeAdd(lower, makeIntConst(-1));
+                }
+                op->step_ = makeIntConst(-stride);
+                op->len_ =
+                    makeCeilDiv(makeSub(op->end_, op->begin_), op->step_);
             }
-            if (lower.isValid()) {
-                op->end_ = makeAdd(lower, makeIntConst(-1));
-            }
-            op->step_ = makeIntConst(-stride);
-            op->len_ = makeCeilDiv(makeSub(op->end_, op->begin_), op->step_);
         }
-    }
 
-    return op;
+        return op;
+    }
 }
@@ -214,7 +364,8 @@ void ShrinkFor::setSubAST(const Stmt &subAST) {
         subASTAncestors_.insert(s);
 }
 
-Stmt shrinkFor(const Stmt &_op, const ID &_subAST, bool doSimplify) {
+Stmt shrinkFor(const Stmt &_op, const ID &_subAST, bool doSimplify,
+               bool unordered) {
     auto op = _op;
     auto subAST = _subAST;
 
@@ -241,13 +392,21 @@ Stmt shrinkFor(const Stmt &_op, const ID &_subAST, bool doSimplify) {
         subAST = newSubAST.isValid() ? newSubAST->id() : ID();
     }
 
-    ShrinkFor shrinker;
+    ShrinkFor shrinker{unordered};
     if (subAST.isValid())
         shrinker.setSubAST(findStmt(op, subAST));
     op = shrinker(op);
 
+    // Ranges from lifting are often quite strange. We'd better normalize them
+    if (unordered) {
+        op = normalizeLoops(op, [&](const For &loop) {
+            return subAST.isValid() ? loop->ancestorById(subAST).isValid()
+                                    : true;
+        });
+    }
+
     if (doSimplify) // Make new ranges simple + remove redundant branches
-        op = simplify(z3Simplify(op));
+        op = simplify(pbSimplify(z3Simplify(op)));
 
     return op;
 }
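The `lowerOnOffset` formula used in both branches above can be sanity-checked with plain integers; e.g. with lower = 3, stride = 4, offset = 1, the first aligned iteration is 5 (hypothetical numbers, just to check the arithmetic):

lower, stride, offset = 3, 4, 1
begin = lower + ((offset - lower) % stride + stride) % stride
assert begin == 5 and begin % stride == offset % stride
# The double-% form keeps the result non-negative even where the host
# language's % can return negative values, as C++'s does
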
+    std::vector<FindDepsDir> dirs;
+    for (auto &&s : findAllStmt(ast, "(<<-" + toString(nest) + ")|" +
+                                         toString(nest))) {
+        dirs.push_back({{s->id(), DepDirection::Normal}});
+    }
+    bool unordered = !FindDeps().direction(dirs).filterSubAST(nest).exists(ast);
+    ast = shrinkFor(ast, nest, true, unordered);
 
     for (auto &&[id, scope] : views::zip(adder.newScopeIds(), orderedScopes)) {
         ast = parallelize(ast, id, scope->property_->parallel_, true);
diff --git a/test/30.schedule/test_parallelize_as.py b/test/30.schedule/test_parallelize_as.py
index 3a53c511c..9a5c4462a 100644
--- a/test/30.schedule/test_parallelize_as.py
+++ b/test/30.schedule/test_parallelize_as.py
@@ -27,8 +27,8 @@ def test_partitioned_by_tile():
             with ft.For("j", 0, 2) as j:
                 b[i * 2 + j] = a[i * 2 + j] * 2
         with ft.For("i", 0, 4, label="L1") as i:
-            with ft.For("j", i * 2, i * 2 + 2) as j:
-                c[j] = b[j] + 1
+            with ft.For("j", 0, 2) as j:
+                c[i * 2 + j] = b[i * 2 + j] + 1
     std = ft.pop_ast()
 
     assert std.match(ast)
@@ -59,8 +59,43 @@ def test_partitioned_by_stride():
             with ft.For("j", 0, 2) as j:
                 b[j * 4 + i] = a[j * 4 + i] * 2
         with ft.For("i", 0, 4, label="L1") as i:
-            with ft.For("j", i, i + 5, 4) as j:
-                c[j] = b[j] + 1
+            with ft.For("j", 0, 2) as j:
+                c[j * 4 + i] = b[j * 4 + i] + 1
     std = ft.pop_ast()
 
     assert std.match(ast)
 
 
+def test_partitioned_by_high_order_stride():
+    with ft.VarDef([("a", (64,), "int32", "input", "cpu"),
+                    ("c", (64,), "int32", "output", "cpu")]) as (a, c):
+        ft.MarkLabel("Vb")
+        with ft.VarDef("b", (64,), "int32", "cache", "cpu") as b:
+            with ft.For("i", 0, 4, label="L1") as i:
+                with ft.For("j", 0, 8) as j:
+                    with ft.For("k", 0, 2) as k:
+                        b[j * 8 + i * 2 + k] = a[j * 8 + i * 2 + k] * 2
+            with ft.For("i", 0, 64, label="L2") as i:
+                c[i] = b[i] + 1
+    ast = ft.pop_ast(verbose=True)
+    s = ft.Schedule(ast, verbose=1)
+    s.parallelize("L1", "openmp")
+    s.parallelize_as("L2", "L1", "Vb")
+    ast = s.ast()
+    assert ft.find_stmt(ast, "->L2").property.parallel == "openmp"
+
+    with ft.VarDef([("a", (64,), "int32", "input", "cpu"),
+                    ("c", (64,), "int32", "output", "cpu")]) as (a, c):
+        ft.MarkLabel("Vb")
+        with ft.VarDef("b", (64,), "int32", "cache", "cpu") as b:
+            with ft.For("i", 0, 4, label="L1") as i:
+                with ft.For("j", 0, 8) as j:
+                    with ft.For("k", 0, 2) as k:
+                        b[j * 8 + i * 2 + k] = a[j * 8 + i * 2 + k] * 2
+            with ft.For("i", 0, 4, label="L2") as i:
+                with ft.For("j", 0, 8) as j:
+                    with ft.For("k", 0, 2) as k:
+                        c[j * 8 + i * 2 + k] = b[j * 8 + i * 2 + k] + 1
+    std = ft.pop_ast()
+
+    assert std.match(ast)
+
+
@@ -88,8 +123,8 @@ def test_reference_after_nest():
     ft.MarkLabel("Vb")
     with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b:
         with ft.For("i", 0, 4, label="L1") as i:
-            with ft.For("j", i * 2, i * 2 + 2) as j:
-                b[j] = a[j] + 1
+            with ft.For("j", 0, 2) as j:
+                b[i * 2 + j] = a[i * 2 + j] + 1
         with ft.For("i", 0, 4, label="L1") as i:
             with ft.For("j", 0, 2) as j:
                 c[i * 2 + j] = b[i * 2 + j] * 2
@@ -133,9 +168,10 @@ def test_multiple_levels():
                     b[i, j] = a[i, j] * 2
     with ft.For("i0", 0, 8, label="L1i") as i0:
         with ft.For("j0", 0, 8, label="L1j") as j0:
-            with ft.For("i", 16 * i0, 16 * i0 + 16) as i:
-                with ft.For("j", 16 * j0, 16 * j0 + 16) as j:
-                    c[i, j] = b[i, j] + 1
+            with ft.For("i", 0, 16) as i1:
+                with ft.For("j", 0, 16) as j1:
+                    c[i0 * 16 + i1,
+                      j0 * 16 + j1] = b[i0 * 16 + i1, j0 * 16 + j1] + 1
     std = ft.pop_ast()
 
     assert std.match(ast)
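For completeness, a sketch of what unordered shrinking enables on its own, in the style of the tests above (hypothetical, and it assumes `shrink_for` is exposed to Python as in the ffi binding): a loop that only touches a strided subset of its range can be rewritten to iterate exactly over that subset, and possibly be split, since the iteration order need not be preserved.

def sketch_shrink_for_unordered():
    with ft.VarDef([("a", (16,), "int32", "input", "cpu"),
                    ("b", (16,), "int32", "output", "cpu")]) as (a, b):
        with ft.For("i", 0, 16) as i:
            with ft.If(i % 4 == 1):
                b[i] = a[i] + 1
    ast = ft.pop_ast()
    # May become e.g. `for i in 1, 5, 9, 13` (begin 1, step 4)
    ast = ft.shrink_for(ast, unordered=True)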