Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove RelaxMode from analyze/deps #591

Merged
merged 1 commit into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 9 additions & 21 deletions include/analyze/deps.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ class FindAccessPoint : public SymbolTable<TrackStmt<Visitor>> {
std::vector<std::pair<Expr, ID>>
conds_; // FIXME: There may be out-dated conditions, and we must check
// allReads(cond) against allWrites(body) for each If or For
// nodes. See pass/simplify. If the condition violates, we may
// need to push a null condition according to RelaxMode
// nodes. See pass/simplify.
std::vector<Ref<AccessPoint>> reads_, writes_;

// For or StmtSeq -> coordinate in iteration space
Expand Down Expand Up @@ -268,7 +267,6 @@ struct Dependence {

typedef SyncFunc<void(const Dependence &)> FindDepsCallback;

enum class RelaxMode : int { Possible, Necessary };
enum class FindDepsMode : int {
Dep, // Dependence may happen between `earlier` and `later`
KillEarlier, // At any point in the space of `earlier`, it is dependent by
Expand Down Expand Up @@ -297,7 +295,6 @@ class AnalyzeDeps {
const FindDepsFilter &filter_;

const FindDepsMode mode_;
const RelaxMode earlierRelax_, laterRelax_;
const DepType depType_;
const bool ignoreReductionWAW_;
const bool eraseOutsideVarDef_;
Expand All @@ -321,16 +318,8 @@ class AnalyzeDeps {
: scope2coord_(scope2coord), noDepsLists_(noDepsLists),
variantExpr_(variantExpr), direction_(direction), found_(found),
earlierFilter_(earlierFilter), laterFilter_(laterFilter),
filter_(filter), mode_(mode),
earlierRelax_(mode_ == FindDepsMode::KillLater ||
mode_ == FindDepsMode::KillBoth
? RelaxMode::Necessary
: RelaxMode::Possible),
laterRelax_(mode_ == FindDepsMode::KillEarlier ||
mode_ == FindDepsMode::KillBoth
? RelaxMode::Necessary
: RelaxMode::Possible),
depType_(depType), ignoreReductionWAW_(ignoreReductionWAW),
filter_(filter), mode_(mode), depType_(depType),
ignoreReductionWAW_(ignoreReductionWAW),
eraseOutsideVarDef_(eraseOutsideVarDef),
noProjectOutPrivateAxis_(noProjectOutPrivateAxis) {
readsAsEarlier_ =
Expand Down Expand Up @@ -364,25 +353,24 @@ class AnalyzeDeps {
static std::vector<
std::pair<std::string /* list */, std::string /* cond */>>
makeAccList(GenPBExpr &genPBExpr, const std::vector<Expr> &list,
RelaxMode relax, GenPBExpr::VarMap &externals);
static std::string makeCond(GenPBExpr &genPBExpr, RelaxMode relax,
GenPBExpr::VarMap &externals);
static std::string makeCond(GenPBExpr &genPBExpr,
GenPBExpr::VarMap &externals,
bool eraseOutsideVarDef, const AccessPoint &ap);
static PBMap makeAccMapStatic(PBCtx &presburger, const AccessPoint &p,
int iterDim, int accDim, RelaxMode relax,
int iterDim, int accDim,
const std::string &extSuffix,
GenPBExpr::VarMap &externals,
const ASTHashSet<Expr> &noNeedToBeVars,
bool eraseOutsideVarDef);

private:
PBMap makeAccMap(PBCtx &presburger, const AccessPoint &p, int iterDim,
int accDim, RelaxMode relax, const std::string &extSuffix,
int accDim, const std::string &extSuffix,
GenPBExpr::VarMap &externals,
const ASTHashSet<Expr> &noNeedToBeVars) {
return makeAccMapStatic(presburger, p, iterDim, accDim, relax,
extSuffix, externals, noNeedToBeVars,
eraseOutsideVarDef_);
return makeAccMapStatic(presburger, p, iterDim, accDim, extSuffix,
externals, noNeedToBeVars, eraseOutsideVarDef_);
}

PBMap makeEqForBothOps(PBCtx &presburger,
Expand Down
20 changes: 19 additions & 1 deletion include/math/presburger.h
Original file line number Diff line number Diff line change
Expand Up @@ -912,8 +912,26 @@ template <PBSpaceRef T> PBSpace spaceMapFromSet(T &&space) {
return isl_space_map_from_set(PBRefTake<T>(space));
}

template <PBMapRef T> PBSet wrap(T &&map) {
return isl_map_wrap(PBRefTake<T>(map));
}

template <PBSetRef T> PBMap unwrap(T &&set) {
return isl_set_unwrap(PBRefTake<T>(set));
}

template <PBSetRef T> PBSet flatten(T &&set) {
return isl_set_flatten(PBRefTake<T>(set));
}
template <PBMapRef T> PBMap flattenDomain(T &&map) {
return isl_map_flatten_domain(PBRefTake<T>(map));
}
template <PBMapRef T> PBMap flattenRange(T &&map) {
return isl_map_flatten_range(PBRefTake<T>(map));
}

template <PBMapRef T> PBSet flattenMapToSet(T &&map) {
return isl_set_flatten(isl_map_wrap(PBRefTake<T>(map)));
return flatten(wrap(std::forward<T>(map)));
}

template <PBSetRef T> PBPoint sample(T &&set) {
Expand Down
18 changes: 9 additions & 9 deletions src/analyze/deps.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ std::string AnalyzeDeps::makeNegIterMap(const std::vector<IterAxis> &list,

std::vector<std::pair<std::string /* list */, std::string /* cond */>>
AnalyzeDeps::makeAccList(GenPBExpr &genPBExpr, const std::vector<Expr> &list,
RelaxMode relax, GenPBExpr::VarMap &externals) {
GenPBExpr::VarMap &externals) {
std::vector<std::pair<std::string, std::string>> ret;
for (auto &&[l, c] : normalizeConditionalExprList(list)) {
std::ostringstream os;
Expand Down Expand Up @@ -349,7 +349,7 @@ AnalyzeDeps::makeAccList(GenPBExpr &genPBExpr, const std::vector<Expr> &list,
return ret;
}

std::string AnalyzeDeps::makeCond(GenPBExpr &genPBExpr, RelaxMode relax,
std::string AnalyzeDeps::makeCond(GenPBExpr &genPBExpr,
GenPBExpr::VarMap &externals,
bool eraseOutsideVarDef,
const AccessPoint &ap) {
Expand Down Expand Up @@ -431,15 +431,15 @@ std::string AnalyzeDeps::makeCond(GenPBExpr &genPBExpr, RelaxMode relax,
}

PBMap AnalyzeDeps::makeAccMapStatic(PBCtx &presburger, const AccessPoint &p,
int iterDim, int accDim, RelaxMode relax,
int iterDim, int accDim,
const std::string &extSuffix,
GenPBExpr::VarMap &externals,
const ASTHashSet<Expr> &noNeedToBeVars,
bool eraseOutsideVarDef) {
GenPBExpr genPBExpr(extSuffix, noNeedToBeVars);
auto iterList = makeIterList(p.iter_, iterDim);
auto condStr = makeCond(genPBExpr, relax, externals, eraseOutsideVarDef, p);
auto accListFactors = makeAccList(genPBExpr, p.access_, relax, externals);
auto condStr = makeCond(genPBExpr, externals, eraseOutsideVarDef, p);
auto accListFactors = makeAccList(genPBExpr, p.access_, externals);
std::ostringstream os;
for (auto &&[i, factor] : views::enumerate(accListFactors)) {
auto &&[accList, accCond] = factor;
Expand Down Expand Up @@ -1053,7 +1053,7 @@ void AnalyzeDeps::checkDepLatestEarlierImpl(

GenPBExpr::VarMap laterExternals;
PBMap laterMap =
makeAccMap(presburger, *later, iterDim, accDim, laterRelax_,
makeAccMap(presburger, *later, iterDim, accDim,
"later" + std::to_string((uint64_t)later->stmt_->id()),
laterExternals, noNeedToBeVars);
if (laterMap.empty()) {
Expand All @@ -1070,7 +1070,7 @@ void AnalyzeDeps::checkDepLatestEarlierImpl(
views::zip(views::ints(0, ranges::unreachable), earlierList,
earlierMapList, earlierExternalsList)) {
earlierMap = makeAccMap(
presburger, *earlier, iterDim, accDim, earlierRelax_,
presburger, *earlier, iterDim, accDim,
"earlier" + std::to_string((uint64_t)earlier->stmt_->id()),
earlierExternals, noNeedToBeVars);
}
Expand Down Expand Up @@ -1165,7 +1165,7 @@ void AnalyzeDeps::checkDepEarliestLaterImpl(

GenPBExpr::VarMap earlierExternals;
PBMap earlierMap =
makeAccMap(presburger, *earlier, iterDim, accDim, earlierRelax_,
makeAccMap(presburger, *earlier, iterDim, accDim,
"earlier" + std::to_string((uint64_t)earlier->stmt_->id()),
earlierExternals, noNeedToBeVars);
if (earlierMap.empty()) {
Expand All @@ -1181,7 +1181,7 @@ void AnalyzeDeps::checkDepEarliestLaterImpl(
views::zip(views::ints(0, ranges::unreachable), laterList,
laterMapList, laterExternalsList)) {
laterMap =
makeAccMap(presburger, *later, iterDim, accDim, laterRelax_,
makeAccMap(presburger, *later, iterDim, accDim,
"later" + std::to_string((uint64_t)later->stmt_->id()),
laterExternals, noNeedToBeVars);
}
Expand Down
3 changes: 1 addition & 2 deletions src/pass/shrink_for.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ class CompUniqueBoundsPBWithStride : public CompUniqueBoundsPB {
// dimension by local dimensions, instead of representing local
// dimensions by the target dimension. The set returned by isl_set_lift
// is a wrapped set, so we can simply unwrap it and then reverse it.
set = isl_set_flatten(
isl_map_wrap(isl_map_reverse(isl_set_unwrap(set.move()))));
set = flatten(wrap(reverse(unwrap(std::move(set)))));

ASSERT(set.nDims() >= 1);
std::vector<std::tuple<Expr, Expr, Expr, int64_t, Expr>> ret;
Expand Down
4 changes: 2 additions & 2 deletions src/schedule/parallelize_as.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ Stmt parallelizeAs(const Stmt &_ast, const ID &nest, const ID &reference,
views::concat(finder.reads(), finder.writes())) {
GenPBExpr::VarMap externals;
auto iter2idx = AnalyzeDeps::makeAccMapStatic(
presburger, *acc, acc->iter_.size(), acc->access_.size(),
RelaxMode::Possible, "", externals, {}, true);
presburger, *acc, acc->iter_.size(), acc->access_.size(), "",
externals, {}, true);
if (!externals.empty()) {
throw InvalidSchedule(
FT_MSG << "Indirect thread mapping in reference loop nest "
Expand Down
Loading