From d2db3350a070557ef616c4dc86607dac79935d7f Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Fri, 19 Jan 2024 11:50:19 +0800 Subject: [PATCH] Remove RelaxMode from analyze/deps Since we have been representing all non-affine expressions as external variables, RelaxMode is no longer used. --- include/analyze/deps.h | 30 +++++++++--------------------- include/math/presburger.h | 20 +++++++++++++++++++- src/analyze/deps.cc | 18 +++++++++--------- src/pass/shrink_for.cc | 3 +-- src/schedule/parallelize_as.cc | 4 ++-- 5 files changed, 40 insertions(+), 35 deletions(-) diff --git a/include/analyze/deps.h b/include/analyze/deps.h index a46b3fab4..cd26556b7 100644 --- a/include/analyze/deps.h +++ b/include/analyze/deps.h @@ -101,8 +101,7 @@ class FindAccessPoint : public SymbolTable> { std::vector> conds_; // FIXME: There may be out-dated conditions, and we must check // allReads(cond) against allWrites(body) for each If or For - // nodes. See pass/simplify. If the condition violates, we may - // need to push a null condition according to RelaxMode + // nodes. See pass/simplify. 
std::vector> reads_, writes_; // For or StmtSeq -> coordinate in iteration space @@ -268,7 +267,6 @@ struct Dependence { typedef SyncFunc FindDepsCallback; -enum class RelaxMode : int { Possible, Necessary }; enum class FindDepsMode : int { Dep, // Dependence may happen between `earlier` and `later` KillEarlier, // At any point in the space of `earlier`, it is dependent by @@ -297,7 +295,6 @@ class AnalyzeDeps { const FindDepsFilter &filter_; const FindDepsMode mode_; - const RelaxMode earlierRelax_, laterRelax_; const DepType depType_; const bool ignoreReductionWAW_; const bool eraseOutsideVarDef_; @@ -321,16 +318,8 @@ class AnalyzeDeps { : scope2coord_(scope2coord), noDepsLists_(noDepsLists), variantExpr_(variantExpr), direction_(direction), found_(found), earlierFilter_(earlierFilter), laterFilter_(laterFilter), - filter_(filter), mode_(mode), - earlierRelax_(mode_ == FindDepsMode::KillLater || - mode_ == FindDepsMode::KillBoth - ? RelaxMode::Necessary - : RelaxMode::Possible), - laterRelax_(mode_ == FindDepsMode::KillEarlier || - mode_ == FindDepsMode::KillBoth - ? 
RelaxMode::Necessary - : RelaxMode::Possible), - depType_(depType), ignoreReductionWAW_(ignoreReductionWAW), + filter_(filter), mode_(mode), depType_(depType), + ignoreReductionWAW_(ignoreReductionWAW), eraseOutsideVarDef_(eraseOutsideVarDef), noProjectOutPrivateAxis_(noProjectOutPrivateAxis) { readsAsEarlier_ = @@ -364,12 +353,12 @@ class AnalyzeDeps { static std::vector< std::pair> makeAccList(GenPBExpr &genPBExpr, const std::vector &list, - RelaxMode relax, GenPBExpr::VarMap &externals); - static std::string makeCond(GenPBExpr &genPBExpr, RelaxMode relax, + GenPBExpr::VarMap &externals); + static std::string makeCond(GenPBExpr &genPBExpr, GenPBExpr::VarMap &externals, bool eraseOutsideVarDef, const AccessPoint &ap); static PBMap makeAccMapStatic(PBCtx &presburger, const AccessPoint &p, - int iterDim, int accDim, RelaxMode relax, + int iterDim, int accDim, const std::string &extSuffix, GenPBExpr::VarMap &externals, const ASTHashSet &noNeedToBeVars, @@ -377,12 +366,11 @@ class AnalyzeDeps { private: PBMap makeAccMap(PBCtx &presburger, const AccessPoint &p, int iterDim, - int accDim, RelaxMode relax, const std::string &extSuffix, + int accDim, const std::string &extSuffix, GenPBExpr::VarMap &externals, const ASTHashSet &noNeedToBeVars) { - return makeAccMapStatic(presburger, p, iterDim, accDim, relax, - extSuffix, externals, noNeedToBeVars, - eraseOutsideVarDef_); + return makeAccMapStatic(presburger, p, iterDim, accDim, extSuffix, + externals, noNeedToBeVars, eraseOutsideVarDef_); } PBMap makeEqForBothOps(PBCtx &presburger, diff --git a/include/math/presburger.h b/include/math/presburger.h index 4543ca52e..060b70e9a 100644 --- a/include/math/presburger.h +++ b/include/math/presburger.h @@ -912,8 +912,26 @@ template PBSpace spaceMapFromSet(T &&space) { return isl_space_map_from_set(PBRefTake(space)); } +template PBSet wrap(T &&map) { + return isl_map_wrap(PBRefTake(map)); +} + +template PBMap unwrap(T &&set) { + return isl_set_unwrap(PBRefTake(set)); +} + +template 
PBSet flatten(T &&set) { + return isl_set_flatten(PBRefTake(set)); +} +template PBMap flattenDomain(T &&map) { + return isl_map_flatten_domain(PBRefTake(map)); +} +template PBMap flattenRange(T &&map) { + return isl_map_flatten_range(PBRefTake(map)); +} + template PBSet flattenMapToSet(T &&map) { - return isl_set_flatten(isl_map_wrap(PBRefTake(map))); + return flatten(wrap(std::forward(map))); } template PBPoint sample(T &&set) { diff --git a/src/analyze/deps.cc b/src/analyze/deps.cc index efaa04ea5..ff2423bd4 100644 --- a/src/analyze/deps.cc +++ b/src/analyze/deps.cc @@ -319,7 +319,7 @@ std::string AnalyzeDeps::makeNegIterMap(const std::vector &list, std::vector> AnalyzeDeps::makeAccList(GenPBExpr &genPBExpr, const std::vector &list, - RelaxMode relax, GenPBExpr::VarMap &externals) { + GenPBExpr::VarMap &externals) { std::vector> ret; for (auto &&[l, c] : normalizeConditionalExprList(list)) { std::ostringstream os; @@ -349,7 +349,7 @@ AnalyzeDeps::makeAccList(GenPBExpr &genPBExpr, const std::vector &list, return ret; } -std::string AnalyzeDeps::makeCond(GenPBExpr &genPBExpr, RelaxMode relax, +std::string AnalyzeDeps::makeCond(GenPBExpr &genPBExpr, GenPBExpr::VarMap &externals, bool eraseOutsideVarDef, const AccessPoint &ap) { @@ -431,15 +431,15 @@ std::string AnalyzeDeps::makeCond(GenPBExpr &genPBExpr, RelaxMode relax, } PBMap AnalyzeDeps::makeAccMapStatic(PBCtx &presburger, const AccessPoint &p, - int iterDim, int accDim, RelaxMode relax, + int iterDim, int accDim, const std::string &extSuffix, GenPBExpr::VarMap &externals, const ASTHashSet &noNeedToBeVars, bool eraseOutsideVarDef) { GenPBExpr genPBExpr(extSuffix, noNeedToBeVars); auto iterList = makeIterList(p.iter_, iterDim); - auto condStr = makeCond(genPBExpr, relax, externals, eraseOutsideVarDef, p); - auto accListFactors = makeAccList(genPBExpr, p.access_, relax, externals); + auto condStr = makeCond(genPBExpr, externals, eraseOutsideVarDef, p); + auto accListFactors = makeAccList(genPBExpr, p.access_, 
externals); std::ostringstream os; for (auto &&[i, factor] : views::enumerate(accListFactors)) { auto &&[accList, accCond] = factor; @@ -1053,7 +1053,7 @@ void AnalyzeDeps::checkDepLatestEarlierImpl( GenPBExpr::VarMap laterExternals; PBMap laterMap = - makeAccMap(presburger, *later, iterDim, accDim, laterRelax_, + makeAccMap(presburger, *later, iterDim, accDim, "later" + std::to_string((uint64_t)later->stmt_->id()), laterExternals, noNeedToBeVars); if (laterMap.empty()) { @@ -1070,7 +1070,7 @@ void AnalyzeDeps::checkDepLatestEarlierImpl( views::zip(views::ints(0, ranges::unreachable), earlierList, earlierMapList, earlierExternalsList)) { earlierMap = makeAccMap( - presburger, *earlier, iterDim, accDim, earlierRelax_, + presburger, *earlier, iterDim, accDim, "earlier" + std::to_string((uint64_t)earlier->stmt_->id()), earlierExternals, noNeedToBeVars); } @@ -1165,7 +1165,7 @@ void AnalyzeDeps::checkDepEarliestLaterImpl( GenPBExpr::VarMap earlierExternals; PBMap earlierMap = - makeAccMap(presburger, *earlier, iterDim, accDim, earlierRelax_, + makeAccMap(presburger, *earlier, iterDim, accDim, "earlier" + std::to_string((uint64_t)earlier->stmt_->id()), earlierExternals, noNeedToBeVars); if (earlierMap.empty()) { @@ -1181,7 +1181,7 @@ void AnalyzeDeps::checkDepEarliestLaterImpl( views::zip(views::ints(0, ranges::unreachable), laterList, laterMapList, laterExternalsList)) { laterMap = - makeAccMap(presburger, *later, iterDim, accDim, laterRelax_, + makeAccMap(presburger, *later, iterDim, accDim, "later" + std::to_string((uint64_t)later->stmt_->id()), laterExternals, noNeedToBeVars); } diff --git a/src/pass/shrink_for.cc b/src/pass/shrink_for.cc index 502a3f9ce..dfc0b5cfd 100644 --- a/src/pass/shrink_for.cc +++ b/src/pass/shrink_for.cc @@ -123,8 +123,7 @@ class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB { // dimension by local dimensions, instead of representing local // dimensions by the target dimension. 
The set returned by isl_set_lift // is a wrapped set, so we can simply unwrap it and then reverse it. - set = isl_set_flatten( - isl_map_wrap(isl_map_reverse(isl_set_unwrap(set.move())))); + set = flatten(wrap(reverse(unwrap(std::move(set))))); ASSERT(set.nDims() >= 1); std::vector> ret; diff --git a/src/schedule/parallelize_as.cc b/src/schedule/parallelize_as.cc index 86d9fe27a..6dd570196 100644 --- a/src/schedule/parallelize_as.cc +++ b/src/schedule/parallelize_as.cc @@ -193,8 +193,8 @@ Stmt parallelizeAs(const Stmt &_ast, const ID &nest, const ID &reference, views::concat(finder.reads(), finder.writes())) { GenPBExpr::VarMap externals; auto iter2idx = AnalyzeDeps::makeAccMapStatic( - presburger, *acc, acc->iter_.size(), acc->access_.size(), - RelaxMode::Possible, "", externals, {}, true); + presburger, *acc, acc->iter_.size(), acc->access_.size(), "", + externals, {}, true); if (!externals.empty()) { throw InvalidSchedule( FT_MSG << "Indirect thread mapping in reference loop nest "