diff --git a/ffi/parallel_scope.cc b/ffi/parallel_scope.cc index ce4532150..bf6cb5838 100644 --- a/ffi/parallel_scope.cc +++ b/ffi/parallel_scope.cc @@ -8,24 +8,40 @@ void init_ffi_parallel_scope(py::module_ &m) { .def(py::init<>()) .def("__str__", [](const SerialScope &scope) { return toString(scope); }) - .def("__eq__", [](const SerialScope &lhs, const SerialScope &rhs) { - return lhs == rhs; - }); + .def("__eq__", [](const SerialScope &lhs, + const SerialScope &rhs) { return lhs == rhs; }) + .def("__eq__", + [](const SerialScope &lhs, const std::string &rhs) { + return ParallelScope{lhs} == parseParallelScope(rhs); + }) + .def("__eq__", + [](const SerialScope &lhs, py::object rhs) { return false; }); py::class_(m, "OpenMPScope") .def(py::init<>()) .def("__str__", [](const OpenMPScope &scope) { return toString(scope); }) - .def("__eq__", [](const OpenMPScope &lhs, const OpenMPScope &rhs) { - return lhs == rhs; - }); + .def("__eq__", [](const OpenMPScope &lhs, + const OpenMPScope &rhs) { return lhs == rhs; }) + .def("__eq__", + [](const OpenMPScope &lhs, const std::string &rhs) { + return ParallelScope{lhs} == parseParallelScope(rhs); + }) + .def("__eq__", + [](const OpenMPScope &lhs, py::object rhs) { return false; }); py::class_(m, "CUDAStreamScope") .def(py::init<>()) .def("__str__", [](const CUDAStreamScope &scope) { return toString(scope); }) .def("__eq__", [](const CUDAStreamScope &lhs, - const CUDAStreamScope &rhs) { return lhs == rhs; }); + const CUDAStreamScope &rhs) { return lhs == rhs; }) + .def("__eq__", + [](const CUDAStreamScope &lhs, const std::string &rhs) { + return ParallelScope{lhs} == parseParallelScope(rhs); + }) + .def("__eq__", + [](const CUDAStreamScope &lhs, py::object rhs) { return false; }); py::enum_(m, "CUDAScopeLevel") .value("Block", CUDAScope::Level::Block) @@ -40,9 +56,14 @@ void init_ffi_parallel_scope(py::module_ &m) { return CUDAScope{level, dim}; })) .def("__str__", [](const CUDAScope &scope) { return toString(scope); }) - .def("__eq__", [](const CUDAScope &lhs, const CUDAScope &rhs) { - return lhs == rhs; - }); + .def("__eq__", [](const CUDAScope &lhs, + const CUDAScope &rhs) { return lhs == rhs; }) + .def("__eq__", + [](const CUDAScope &lhs, const std::string &rhs) { + return ParallelScope{lhs} == parseParallelScope(rhs); + }) + .def("__eq__", + [](const CUDAScope &lhs, py::object rhs) { return false; }); // Factory function, used as a class m.def("ParallelScope", &parseParallelScope); diff --git a/ffi/pass.cc b/ffi/pass.cc index 2658732c3..5cff0d44b 100644 --- a/ffi/pass.cc +++ b/ffi/pass.cc @@ -85,12 +85,12 @@ void init_ffi_pass(py::module_ &m) { "stmt"_a); m.def("shrink_for", - static_cast( + static_cast( &shrinkFor), - "func"_a, "sub_ast"_a = nullptr, "do_simplify"_a = true); + "func"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true); m.def("shrink_for", - static_cast(&shrinkFor), - "stmt"_a, "sub_ast"_a = nullptr, "do_simplify"_a = true); + static_cast(&shrinkFor), + "stmt"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true); m.def("merge_and_hoist_if", static_cast(&mergeAndHoistIf), "func"_a); diff --git a/ffi/schedule.cc b/ffi/schedule.cc index 87bf5d684..224a0a3ce 100644 --- a/ffi/schedule.cc +++ b/ffi/schedule.cc @@ -131,6 +131,8 @@ void init_ffi_schedule(py::module_ &m) { .def("inline", &Schedule::inlining, "vardef"_a) .def("parallelize", &Schedule::parallelize, "loop"_a, "parallel"_a, "allow_reduction"_a = true) + .def("parallelize_as", &Schedule::parallelizeAs, "nest"_a, + "reference"_a, "def_id"_a) .def("unroll", &Schedule::unroll, "loop"_a, "immedate"_a = false) .def("vectorize", &Schedule::vectorize, "loop"_a) .def("separate_tail", &Schedule::separateTail, diff --git a/include/analyze/deps.h b/include/analyze/deps.h index 7a4bceda6..9a3fdd490 100644 --- a/include/analyze/deps.h +++ b/include/analyze/deps.h @@ -368,12 +368,22 @@ class AnalyzeDeps { static std::string makeCond(GenPBExpr &genPBExpr, RelaxMode relax, GenPBExpr::VarMap &externals, bool eraseOutsideVarDef, const AccessPoint &ap); + static PBMap makeAccMapStatic(PBCtx &presburger, const AccessPoint &p, + int iterDim, int accDim, RelaxMode relax, + const std::string &extSuffix, + GenPBExpr::VarMap &externals, + const ASTHashSet &noNeedToBeVars, + bool eraseOutsideVarDef); private: PBMap makeAccMap(PBCtx &presburger, const AccessPoint &p, int iterDim, int accDim, RelaxMode relax, const std::string &extSuffix, GenPBExpr::VarMap &externals, - const ASTHashSet &noNeedToBeVars); + const ASTHashSet &noNeedToBeVars) { + return makeAccMapStatic(presburger, p, iterDim, accDim, relax, + extSuffix, externals, noNeedToBeVars, + eraseOutsideVarDef_); + } PBMap makeEqForBothOps(PBCtx &presburger, const std::vector> &coord, diff --git a/include/get_new_name.h b/include/get_new_name.h new file mode 100644 index 000000000..e9d7554ec --- /dev/null +++ b/include/get_new_name.h @@ -0,0 +1,14 @@ +#ifndef FREE_TENSOR_GET_NEW_NAME_H +#define FREE_TENSOR_GET_NEW_NAME_H + +#include +#include + +namespace freetensor { + +std::string getNewName(const std::string &oldName, + const std::unordered_set &used); + +} + +#endif // FREE_TENSOR_GET_NEW_NAME_H diff --git a/include/pass/shrink_for.h b/include/pass/shrink_for.h index ac33a1db1..97ef9c446 100644 --- a/include/pass/shrink_for.h +++ b/include/pass/shrink_for.h @@ -56,9 +56,15 @@ class ShrinkFor : public CompTransientBounds> { /** * Increase the begin and decrease the end index, to remove redundant iterations * from For loops + * + * @{ */ -Stmt shrinkFor(const Stmt &op, const Stmt &subAST = nullptr, - bool doSimplify = true); +Stmt shrinkFor(const Stmt &op, const ID &subAST = ID(), bool doSimplify = true); +inline Stmt shrinkFor(const Stmt &op, const Stmt &subAST, + bool doSimplify = true) { + return shrinkFor(op, subAST.isValid() ? subAST->id() : ID(), doSimplify); +} +/** @} */ DEFINE_PASS_FOR_FUNC(shrinkFor) diff --git a/include/schedule.h b/include/schedule.h index b672bd9a2..c0c421260 100644 --- a/include/schedule.h +++ b/include/schedule.h @@ -645,6 +645,21 @@ class Schedule { void parallelize(const ID &loop, const ParallelScope ¶llel, bool allowReduction = true); + /** + * Parallelize a loop nest according to another loop nest to keep a tensor + * thread-local + * + * @param nest : ID of the loop nest to be parallelized. The ID can be of + * any statement type, and all statements it contains will be parallelized. + * @param reference: ID of the loop nest to be referenced. The ID can be of + * any statement type, and all statements it contains will be referenced. + * @param defId : ID of the VarDef statement of the tensor to be kept + * thread-local. + * @throw InvalidSchedule if any of the ID is not found, or the reference + * loop nest is already thread-non-local. + */ + void parallelizeAs(const ID &nest, const ID &reference, const ID &defId); + /** * Unroll a loop * diff --git a/include/schedule/parallelize_as.h b/include/schedule/parallelize_as.h new file mode 100644 index 000000000..697f49681 --- /dev/null +++ b/include/schedule/parallelize_as.h @@ -0,0 +1,13 @@ +#ifndef FREE_TENSOR_PARALLELIZE_AS_H +#define FREE_TENSOR_PARALLELIZE_AS_H + +#include + +namespace freetensor { + +Stmt parallelizeAs(const Stmt &ast, const ID &nest, const ID &reference, + const ID &defId); + +} // namespace freetensor + +#endif // FREE_TENSOR_PARALLELIZE_AS_H diff --git a/include/schedule/schedule_log.h b/include/schedule/schedule_log.h index 46ee6b573..902e6f2dc 100644 --- a/include/schedule/schedule_log.h +++ b/include/schedule/schedule_log.h @@ -30,6 +30,7 @@ enum class ScheduleType : int { VarReorder, Inline, Parallelize, + ParallelizeAs, Unroll, Vectorize, SeparateTail, @@ -42,14 +43,14 @@ enum class ScheduleType : int { }; constexpr std::array scheduleTypeNames = { - "split", "reorder", "merge", - "fission", "fuse", "swap", - "blend", "cache", "cache_reduction", - "set_mem_type", "var_split", "var_merge", - "var_reorder", "inline", "parallelize", - "unroll", "vectorize", "separate_tail", - "as_matmul", "permute", "pluto_fuse", - "pluto_permute", + "split", "reorder", "merge", + "fission", "fuse", "swap", + "blend", "cache", "cache_reduction", + "set_mem_type", "var_split", "var_merge", + "var_reorder", "inline", "parallelize", + "parallelize_as", "unroll", "vectorize", + "separate_tail", "as_matmul", "permute", + "pluto_fuse", "pluto_permute", }; static_assert(scheduleTypeNames.size() == (size_t)ScheduleType::NumTypes); diff --git a/python/freetensor/core/schedule.py b/python/freetensor/core/schedule.py index 012089c21..f2a8a85d0 100644 --- a/python/freetensor/core/schedule.py +++ b/python/freetensor/core/schedule.py @@ -692,6 +692,31 @@ def parallelize(self, loop, parallel): """ super().parallelize(self._lookup(loop), ParallelScope(parallel)) + def parallelize_as(self, nest, reference, vardef): + ''' + Parallelize a loop nest according to another loop nest to keep a tensor + thread-local + + Parameters + ---------- + nest : str, ID or Stmt + The loop nest to be parallelized. The ID can be of any statement type, + and all statements it contains will be parallelized. + reference: str, ID or Stmt + The loop nest to be referenced. The ID can be of any statement type, + and all statements it contains will be referenced. + vardef : str, ID or Stmt + The VarDef statement of the tensor to be kept thread-local. + + Raises + ------ + InvalidSchedule + if any of the ID is not found, or the reference loop nest is already + thread-non-local. + ''' + super().parallelize_as(self._lookup(nest), self._lookup(reference), + self._lookup(vardef)) + def unroll(self, loop, immediate=False): """ Unroll a loop diff --git a/src/analyze/deps.cc b/src/analyze/deps.cc index 62c298611..6ff12196d 100644 --- a/src/analyze/deps.cc +++ b/src/analyze/deps.cc @@ -417,16 +417,16 @@ std::string AnalyzeDeps::makeCond(GenPBExpr &genPBExpr, RelaxMode relax, return ret; } -PBMap AnalyzeDeps::makeAccMap(PBCtx &presburger, const AccessPoint &p, - int iterDim, int accDim, RelaxMode relax, - const std::string &extSuffix, - GenPBExpr::VarMap &externals, - const ASTHashSet &noNeedToBeVars) { +PBMap AnalyzeDeps::makeAccMapStatic(PBCtx &presburger, const AccessPoint &p, + int iterDim, int accDim, RelaxMode relax, + const std::string &extSuffix, + GenPBExpr::VarMap &externals, + const ASTHashSet &noNeedToBeVars, + bool eraseOutsideVarDef) { GenPBExpr genPBExpr(extSuffix, noNeedToBeVars); auto ret = makeIterList(p.iter_, iterDim) + " -> " + makeAccList(genPBExpr, p.access_, relax, externals); - if (auto str = - makeCond(genPBExpr, relax, externals, eraseOutsideVarDef_, p); + if (auto str = makeCond(genPBExpr, relax, externals, eraseOutsideVarDef, p); !str.empty()) { ret += ": " + str; } diff --git a/src/frontend/inlined_invoke.cc b/src/frontend/inlined_invoke.cc index 4225df9f7..e5947b287 100644 --- a/src/frontend/inlined_invoke.cc +++ b/src/frontend/inlined_invoke.cc @@ -3,20 +3,12 @@ #include #include #include +#include #include #include namespace freetensor { -static std::string getNewName(const std::string &oldName, - const std::unordered_set &used) { - for (int i = 1;; i++) { - if (auto name = oldName + "." + std::to_string(i); !used.count(name)) { - return name; - } - } -} - Stmt StripReturns::visit(const VarDef &op) { if (auto it = std::find_if( returns_.begin(), returns_.end(), diff --git a/src/get_new_name.cc b/src/get_new_name.cc new file mode 100644 index 000000000..7db742595 --- /dev/null +++ b/src/get_new_name.cc @@ -0,0 +1,14 @@ +#include + +namespace freetensor { + +std::string getNewName(const std::string &oldName, + const std::unordered_set &used) { + for (int i = 1;; i++) { + if (auto name = oldName + "." + std::to_string(i); !used.count(name)) { + return name; + } + } +} + +} // namespace freetensor diff --git a/src/pass/gpu/normalize_var_in_kernel.cc b/src/pass/gpu/normalize_var_in_kernel.cc index 4ede14e44..0dffff755 100644 --- a/src/pass/gpu/normalize_var_in_kernel.cc +++ b/src/pass/gpu/normalize_var_in_kernel.cc @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -35,15 +36,6 @@ std::unordered_map countNames(const Stmt &s) { return visitor.nameCnt(); } -std::string getNewName(const std::string &oldName, - const std::unordered_set &used) { - for (int i = 1;; i++) { - if (auto name = oldName + "." + std::to_string(i); !used.count(name)) { - return name; - } - } -} - } // Anonymous namespace Stmt NormalizeVarInKernel::visit(const VarDef &_op) { diff --git a/src/pass/shrink_for.cc b/src/pass/shrink_for.cc index e58594929..5902af9dd 100644 --- a/src/pass/shrink_for.cc +++ b/src/pass/shrink_for.cc @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -213,7 +214,7 @@ void ShrinkFor::setSubAST(const Stmt &subAST) { subASTAncestors_.insert(s); } -Stmt shrinkFor(const Stmt &_op, const Stmt &subAST, bool doSimplify) { +Stmt shrinkFor(const Stmt &_op, const ID &subAST, bool doSimplify) { auto op = _op; // DO NOT CALL normalizeLoops HERE! Since we often use (-INT_MAX, INT_MAX) @@ -225,7 +226,7 @@ Stmt shrinkFor(const Stmt &_op, const Stmt &subAST, bool doSimplify) { ShrinkFor shrinker; if (subAST.isValid()) - shrinker.setSubAST(subAST); + shrinker.setSubAST(findStmt(op, subAST)); op = shrinker(op); if (doSimplify) // Make new ranges simple + remove redundant branches diff --git a/src/schedule/hoist_selected_var.cc b/src/schedule/hoist_selected_var.cc index 1c1d86343..e0746948d 100644 --- a/src/schedule/hoist_selected_var.cc +++ b/src/schedule/hoist_selected_var.cc @@ -1,20 +1,12 @@ #include #include #include +#include #include #include namespace freetensor { -static std::string getNewName(const std::string &oldName, - const std::unordered_set &used) { - for (int i = 1;; i++) { - if (auto name = oldName + "." + std::to_string(i); !used.count(name)) { - return name; - } - } -} - Stmt HoistSelectedVar::visit(const For &op) { if (toHoist_.count(op->body_->id())) { ASSERT(op->body_->nodeType() == ASTNodeType::VarDef); diff --git a/src/schedule/parallelize_as.cc b/src/schedule/parallelize_as.cc new file mode 100644 index 000000000..f0e7a9928 --- /dev/null +++ b/src/schedule/parallelize_as.cc @@ -0,0 +1,287 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace freetensor { + +namespace { + +PBMap projectOntoOneOutputDim(const PBMap &map, int dim) { + auto ret = map; + if (dim < (int)map.nOutDims() - 1) { + ret = projectOutOutputDims(std::move(ret), dim + 1, + map.nOutDims() - dim - 1); + } + if (dim > 0) { + ret = projectOutOutputDims(std::move(ret), 0, dim); + } + ASSERT(ret.nOutDims() == 1); + return ret; +} + +class AddParScopes : public TrackStmt> { + typedef TrackStmt> BaseClass; + + ID nest_; + const std::vector &orderedScopes_; + const std::unordered_map &scope2Idx2Iter; + + ID newNestId_; + std::vector newIterNames_; + std::vector newScopeIds_; + + bool inside_ = false; + + // Itmes of inner vectors are from scopes, which are combined by LAnd. Items + // of the outer vector are from access sites, which are combined by LOr, so + // they can be checked by subsequent `parallelize` schedules. + std::unordered_map>> threadGuard_; + + public: + AddParScopes(const ID &nest, const std::vector &orderedScopes, + const std::unordered_map &scope2Idx2Iter) + : nest_(nest), orderedScopes_(orderedScopes), + scope2Idx2Iter(scope2Idx2Iter) {} + + const auto &newScopeIds() const { return newScopeIds_; } + const auto &newNestId() const { return newNestId_; } + + private: + template auto visitAcc(const T &op) { + if (inside_) { + std::vector thisThreadGuard; + thisThreadGuard.reserve(orderedScopes_.size()); + for (auto &&[scope, newIterName] : + views::zip(orderedScopes_, newIterNames_)) { + auto &&idx2iter = scope2Idx2Iter.at(scope->id()); + SimplePBFuncAST f; + try { + f = parseSimplePBFunc(toString(PBFunc(idx2iter))); + } catch (const ParserError &e) { + throw InvalidSchedule( + std::string( + "Thread mapping is not a simple function: ") + + e.what()); + } + ASSERT(f.args_.size() == op->indices_.size()); + ASSERT(f.values_.size() == 1); + std::unordered_map replace; + replace.reserve(op->indices_.size()); + for (auto &&[arg, idx] : views::zip(f.args_, op->indices_)) { + replace[arg] = idx; + } + thisThreadGuard.emplace_back( + makeEQ(makeVar(newIterName), + ReplaceIter{replace}(f.values_.front()))); + } + for (Stmt s = curStmt(); s.isValid(); s = s->parentStmt()) { + threadGuard_[s->id()].emplace_back(std::move(thisThreadGuard)); + if (s->id() == nest_) { + break; + } + } + } + return BaseClass::visit(op); + } + + Stmt doVisitStmt(const Stmt &s) { + auto ret = BaseClass::visitStmt(s); + if (inside_) { + if (auto it = threadGuard_.find(s->id()); + it != threadGuard_.end()) { + ret = makeIf(makeLOrLAnd(it->second), std::move(ret)); + } + } + return ret; + } + + protected: + using BaseClass::visit; + + Stmt visitStmt(const Stmt &s) override { + if (s->id() == nest_) { + auto usedNames = uni(names(), allNames(s)); + for (auto &&scope : views::reverse(orderedScopes_)) { + auto newIterName = getNewName(scope->iter_, usedNames); + usedNames.emplace(newIterName); + newIterNames_.emplace_back(newIterName); + } + + ASSERT(!inside_); + inside_ = true; + auto ret = doVisitStmt(s); + inside_ = false; + + for (auto &&[scope, newIterName] : + views::reverse(views::zip(orderedScopes_, newIterNames_))) { + // Make `For`s with empty parallel_ property, and subsequent + // `parallelize` schedules will parallelize them and check the + // legality. + ret = makeFor(newIterName, scope->begin_, scope->end_, + scope->step_, scope->len_, + Ref::make(), std::move(ret)); + newScopeIds_.emplace(newScopeIds_.begin(), ret->id()); + } + + newNestId_ = ret->id(); + return ret; + } else { + return doVisitStmt(s); + } + } + + Expr visit(const Load &op) override { return visitAcc(op); } + Stmt visit(const Store &op) override { return visitAcc(op); } + Stmt visit(const ReduceTo &op) override { return visitAcc(op); } +}; + +} // Anonymous namespace + +Stmt parallelizeAs(const Stmt &_ast, const ID &nest, const ID &reference, + const ID &defId) { + Stmt ast = _ast; + + bool referenceIsBeforeNest = + findStmt(ast, reference)->isBefore(findStmt(ast, nest)); + auto isSafeToMove = [&](const Expr &expr) { + // To be conservative, check all the range enclosing reference and nest + if (referenceIsBeforeNest) { + return checkNotModified(ast, expr, CheckNotModifiedSide::Before, + reference, CheckNotModifiedSide::After, + nest); + } else { + return checkNotModified(ast, expr, CheckNotModifiedSide::Before, + nest, CheckNotModifiedSide::After, + reference); + } + }; + auto checkSafeToMoveOrThrow = [&](const Expr &expr) { + if (!isSafeToMove(expr)) { + throw InvalidSchedule( + toString(expr) + + " in a reference nest's loop range is not supported"); + } + }; + + FindAccessPoint finder{ + defId, DEP_ALL, syncFunc([&](const Access &acc) { + return acc.stmt_->ancestorById(reference).isValid(); + })}; + finder.doFind(ast); + + PBCtx presburger; + std::unordered_map scope2Idx2Iter; + for (const Ref &acc : + views::concat(finder.reads(), finder.writes())) { + GenPBExpr::VarMap externals; + auto iter2idx = AnalyzeDeps::makeAccMapStatic( + presburger, *acc, acc->iter_.size(), acc->access_.size(), + RelaxMode::Possible, "", externals, {}, true); + if (!externals.empty()) { + throw InvalidSchedule( + "Indirect thread mapping in reference loop nest " + + toString(reference) + " is not supported"); + } + + std::unordered_map iter2Scope; + for (Stmt s = acc->stmt_; s.isValid(); s = s->parentStmt()) { + if (s->nodeType() == ASTNodeType::For) { + if (auto &&loop = s.as(); + loop->property_->parallel_ != serialScope) { + iter2Scope[loop->iter_] = loop; + } + } + if (s->id() == reference) { + break; + } + } + + for (auto &&[i, iterAxis] : views::enumerate(acc->iter_)) { + if (iterAxis.iter_->nodeType() == ASTNodeType::Var) { + if (auto it = + iter2Scope.find(iterAxis.iter_.as()->name_); + it != iter2Scope.end()) { + auto &&id = it->second->id(); + auto thisIdx2Iter = + projectOntoOneOutputDim(reverse(iter2idx), i); + scope2Idx2Iter[id] = + scope2Idx2Iter.count(id) + ? uni(scope2Idx2Iter[id], thisIdx2Iter) + : thisIdx2Iter; + } + } + } + } + + std::vector orderedScopes; + for (auto &&s : findAllStmt(ast, "(<<-" + toString(reference) + ")|" + + toString(reference))) { + ASSERT(s->nodeType() == ASTNodeType::For); + auto &&loop = s.as(); + checkSafeToMoveOrThrow(loop->begin_); + checkSafeToMoveOrThrow(loop->end_); + checkSafeToMoveOrThrow(loop->step_); + if (std::ranges::find_if(orderedScopes, [&](const For &f) { + return f->property_->parallel_ == loop->property_->parallel_; + }) != orderedScopes.end()) { + throw InvalidSchedule( + "Multiple loops bound to the same parallel scope " + + toString(loop->property_->parallel_) + + " in the reference loop nest " + toString(reference) + + " is not supported yet"); + } + if (auto it = scope2Idx2Iter.find(loop->id()); + it != scope2Idx2Iter.end()) { + if (!it->second.isSingleValued()) { + throw InvalidSchedule("Reference loop nest " + + toString(reference) + + " is not thread-local"); + } + orderedScopes.emplace_back(s.as()); + } + } + + AddParScopes adder{nest, orderedScopes, scope2Idx2Iter}; + ast = adder(ast); + + // Shrink original loops in `nest` according to the gaurds with just add + ast = shrinkFor(ast, adder.newNestId()); + + for (auto &&[id, scope] : views::zip(adder.newScopeIds(), orderedScopes)) { + ast = parallelize(ast, id, scope->property_->parallel_, true); + } + + return ast; +} + +void Schedule::parallelizeAs(const ID &nest, const ID &reference, + const ID &defId) { + beginTransaction(); + auto log = appendLog(MAKE_SCHEDULE_LOG( + ParallelizeAs, freetensor::parallelizeAs, nest, reference, defId)); + try { + applyLog(log); + commitTransaction(); + } catch (const InvalidSchedule &e) { + abortTransaction(); + throw InvalidSchedule(log, ast(), e.what()); + } +} + +} // namespace freetensor diff --git a/test/30.schedule/test_parallelize_as.py b/test/30.schedule/test_parallelize_as.py new file mode 100644 index 000000000..3a53c511c --- /dev/null +++ b/test/30.schedule/test_parallelize_as.py @@ -0,0 +1,192 @@ +import freetensor as ft +import pytest + + +def test_partitioned_by_tile(): + with ft.VarDef([("a", (8,), "int32", "input", "cpu"), + ("c", (8,), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", 0, 2) as j: + b[i * 2 + j] = a[i * 2 + j] * 2 + with ft.For("i", 0, 8, label="L2") as i: + c[i] = b[i] + 1 + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast, verbose=1) + s.parallelize("L1", "openmp") + s.parallelize_as("L2", "L1", "Vb") + ast = s.ast() + assert ft.find_stmt(ast, "->L2").property.parallel == "openmp" + + with ft.VarDef([("a", (8,), "int32", "input", "cpu"), + ("c", (8,), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", 0, 2) as j: + b[i * 2 + j] = a[i * 2 + j] * 2 + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", i * 2, i * 2 + 2) as j: + c[j] = b[j] + 1 + std = ft.pop_ast() + + assert std.match(ast) + + +def test_partitioned_by_stride(): + with ft.VarDef([("a", (8,), "int32", "input", "cpu"), + ("c", (8,), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", 0, 2) as j: + b[j * 4 + i] = a[j * 4 + i] * 2 + with ft.For("i", 0, 8, label="L2") as i: + c[i] = b[i] + 1 + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast, verbose=1) + s.parallelize("L1", "openmp") + s.parallelize_as("L2", "L1", "Vb") + ast = s.ast() + assert ft.find_stmt(ast, "->L2").property.parallel == "openmp" + + with ft.VarDef([("a", (8,), "int32", "input", "cpu"), + ("c", (8,), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", 0, 2) as j: + b[j * 4 + i] = a[j * 4 + i] * 2 + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", i, i + 5, 4) as j: + c[j] = b[j] + 1 + std = ft.pop_ast() + + assert std.match(ast) + + +def test_reference_after_nest(): + with ft.VarDef([("a", (8,), "int32", "input", "cpu"), + ("c", (8,), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 8, label="L2") as i: + b[i] = a[i] + 1 + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", 0, 2) as j: + c[i * 2 + j] = b[i * 2 + j] * 2 + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast, verbose=1) + s.parallelize("L1", "openmp") + s.parallelize_as("L2", "L1", "Vb") + ast = s.ast() + assert ft.find_stmt(ast, "->L2").property.parallel == "openmp" + + with ft.VarDef([("a", (8,), "int32", "input", "cpu"), + ("c", (8,), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", i * 2, i * 2 + 2) as j: + b[j] = a[j] + 1 + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", 0, 2) as j: + c[i * 2 + j] = b[i * 2 + j] * 2 + std = ft.pop_ast() + + assert std.match(ast) + + +@pytest.mark.skipif(not ft.with_cuda(), reason="requires CUDA") +def test_multiple_levels(): + with ft.VarDef([("a", (128, 128), "int32", "input", "cpu"), + ("c", (128, 128), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (128, 128), "int32", "cache", "cpu") as b: + with ft.For("i0", 0, 8, label="L1i") as i0: + with ft.For("j0", 0, 8, label="L1j") as j0: + with ft.For("i", 16 * i0, 16 * i0 + 16) as i: + with ft.For("j", 16 * j0, 16 * j0 + 16) as j: + b[i, j] = a[i, j] * 2 + with ft.For("i", 0, 128, label="L2") as i: + with ft.For("j", 0, 128) as j: + c[i, j] = b[i, j] + 1 + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast, verbose=1) + s.parallelize("L1i", "blockIdx.x") + s.parallelize("L1j", "threadIdx.x") + s.parallelize_as("L2", "L1i", "Vb") + ast = s.ast() + assert ft.find_stmt(ast, "->L2").property.parallel == "threadIdx.x" + assert ft.find_stmt(ast, + "->->L2").property.parallel == "blockIdx.x" + + with ft.VarDef([("a", (128, 128), "int32", "input", "cpu"), + ("c", (128, 128), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (128, 128), "int32", "cache", "cpu") as b: + with ft.For("i0", 0, 8, label="L1i") as i0: + with ft.For("j0", 0, 8, label="L1j") as j0: + with ft.For("i", 16 * i0, 16 * i0 + 16) as i: + with ft.For("j", 16 * j0, 16 * j0 + 16) as j: + b[i, j] = a[i, j] * 2 + with ft.For("i0", 0, 8, label="L1i") as i0: + with ft.For("j0", 0, 8, label="L1j") as j0: + with ft.For("i", 16 * i0, 16 * i0 + 16) as i: + with ft.For("j", 16 * j0, 16 * j0 + 16) as j: + c[i, j] = b[i, j] + 1 + std = ft.pop_ast() + + assert std.match(ast) + + +def test_reject_thread_non_local_reference_shared_by_loop(): + with ft.VarDef([("a", (8,), "int32", "input", "cpu"), + ("c", (2,), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (2,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", 0, 2) as j: + b[j] += a[i * 2 + j] * 2 + with ft.For("i", 0, 2, label="L2") as i: + c[i] = b[i] + 1 + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast, verbose=1) + s.parallelize("L1", "openmp") + with pytest.raises(ft.InvalidSchedule): + s.parallelize_as("L2", "L1", "Vb") + + +def test_reject_thread_non_local_reference_shared_by_multiple_access_sites(): + with ft.VarDef([("a", (8,), "int32", "input", "cpu"), + ("c", (8,), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 8, label="L2") as i: + b[i] = a[i] + 1 + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", 0, 2) as j: + c[i * 2 + j] = b[i * 2 + j] + b[(i * 2 + j + 1) % 8] + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast, verbose=1) + s.parallelize("L1", "openmp") + with pytest.raises(ft.InvalidSchedule): + s.parallelize_as("L2", "L1", "Vb") + + +def test_reject_thread_non_local_destination_nest(): + with ft.VarDef([("a", (8,), "int32", "input", "cpu"), + ("c", (8,), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", 0, 2) as j: + b[i * 2 + j] = a[i * 2 + j] * 2 + with ft.For("i", 0, 8, label="L2") as i: + c[i] = b[i] + b[(i + 1) % 8] + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast, verbose=1) + s.parallelize("L1", "openmp") + with pytest.raises(ft.InvalidSchedule): + s.parallelize_as("L2", "L1", "Vb")