Skip to content

Commit

Permalink
Recognize IfExpr in indices in analyze/deps (#587)
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck authored Jan 17, 2024
1 parent db1bc89 commit 03c26a2
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 25 deletions.
8 changes: 4 additions & 4 deletions include/analyze/deps.h
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,10 @@ class AnalyzeDeps {
static std::string makeIterList(const std::vector<IterAxis> &list, int n);
static std::string makeNegIterMap(const std::vector<IterAxis> &list, int n);
static std::string makeNdList(const std::string &name, int n);
static std::string makeAccList(GenPBExpr &genPBExpr,
const std::vector<Expr> &list,
RelaxMode relax,
GenPBExpr::VarMap &externals);
static std::vector<
std::pair<std::string /* list */, std::string /* cond */>>
makeAccList(GenPBExpr &genPBExpr, const std::vector<Expr> &list,
RelaxMode relax, GenPBExpr::VarMap &externals);
static std::string makeCond(GenPBExpr &genPBExpr, RelaxMode relax,
GenPBExpr::VarMap &externals,
bool eraseOutsideVarDef, const AccessPoint &ap);
Expand Down
16 changes: 16 additions & 0 deletions include/analyze/normalize_conditional_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@ namespace freetensor {
std::vector<std::pair<Expr /* value */, Expr /* condition, maybe null */>>
normalizeConditionalExpr(const Expr &expr);

/**
* Break a list of expressions into several conditional parts.
*
* This function is used for analyzing expressions with `IfExpr` inside. The
* result will be several parts with conditions, where each part is no longer
* with `IfExpr`.
*
* @param exprs : The list of expressions to be analyzed.
* @return : A vector of pairs, where the first element is the value of the
* expression list, and the second element is the condition of the expression.
* The condition may be null, which means the expression is always true.
*/
std::vector<
std::pair<std::vector<Expr> /* values */, Expr /* condition, maybe null */>>
normalizeConditionalExprList(const std::vector<Expr> &exprs);

} // namespace freetensor

#endif // FREE_TENSOR_NORMALIZE_CONDITIONAL_EXPR_H
9 changes: 9 additions & 0 deletions include/pass/z3_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ namespace freetensor {
* x - x to x)
* - It can deal with some more complex expressions, such as Mod
* - It may take some more time
*
* Z3Simplify can work on a root-less sub-AST, but in this case, it will not
* benefit from context from the missing ancestor nodes.
*
* Z3Simplify is conflict with SymbolTable. If you want to use these two classes
* together, please use `Z3SimplifyWithSymbolTable`.
*/
class Z3Simplify : public Mutator {
typedef Mutator BaseClass;
Expand Down Expand Up @@ -103,6 +109,9 @@ class Z3Simplify : public Mutator {
Stmt visit(const For &op) override;
};

/**
* Compatible inheritence of both Z3Simplify and SymbolTable
*/
class Z3SimplifyWithSymbolTable : public Z3Simplify,
public SymbolTableInterface {
SymbolTableData symbols_;
Expand Down
62 changes: 42 additions & 20 deletions src/analyze/deps.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <analyze/all_uses.h>
#include <analyze/deps.h>
#include <analyze/find_stmt.h>
#include <analyze/normalize_conditional_expr.h>
#include <container_utils.h>
#include <disjoint_set.h>
#include <except.h>
Expand Down Expand Up @@ -316,24 +317,36 @@ std::string AnalyzeDeps::makeNegIterMap(const std::vector<IterAxis> &list,
return "{[" + lhs + "] -> [" + rhs + "]}";
}

std::string AnalyzeDeps::makeAccList(GenPBExpr &genPBExpr,
const std::vector<Expr> &list,
RelaxMode relax,
GenPBExpr::VarMap &externals) {
std::string ret;
for (int i = 0, iEnd = list.size(); i < iEnd; i++) {
auto &&[linstr, vars] = genPBExpr.gen(list[i]);
ret += linstr;
for (auto &&[expr, str] : vars) {
if (expr->nodeType() != ASTNodeType::Var) {
externals[expr] = str;
std::vector<std::pair<std::string /* list */, std::string /* cond */>>
AnalyzeDeps::makeAccList(GenPBExpr &genPBExpr, const std::vector<Expr> &list,
RelaxMode relax, GenPBExpr::VarMap &externals) {
std::vector<std::pair<std::string, std::string>> ret;
for (auto &&[l, c] : normalizeConditionalExprList(list)) {
std::ostringstream os;
os << "[";
for (auto &&[i, item] : views::enumerate(l)) {
auto &&[linstr, vars] = genPBExpr.gen(item);
os << (i > 0 ? ", " : "") << linstr;
for (auto &&[expr, str] : vars) {
if (expr->nodeType() != ASTNodeType::Var) {
externals[expr] = str;
}
}
}
if (i < iEnd - 1) {
ret += ", ";
os << "]";
std::string condStr;
if (c.isValid()) {
GenPBExpr::VarMap condVars;
std::tie(condStr, condVars) = genPBExpr.gen(c);
for (auto &&[expr, str] : condVars) {
if (expr->nodeType() != ASTNodeType::Var) {
externals[expr] = str;
}
}
}
ret.emplace_back(os.str(), condStr);
}
return "[" + ret + "]";
return ret;
}

std::string AnalyzeDeps::makeCond(GenPBExpr &genPBExpr, RelaxMode relax,
Expand Down Expand Up @@ -424,11 +437,20 @@ PBMap AnalyzeDeps::makeAccMapStatic(PBCtx &presburger, const AccessPoint &p,
const ASTHashSet<Expr> &noNeedToBeVars,
bool eraseOutsideVarDef) {
GenPBExpr genPBExpr(extSuffix, noNeedToBeVars);
auto ret = makeIterList(p.iter_, iterDim) + " -> " +
makeAccList(genPBExpr, p.access_, relax, externals);
if (auto str = makeCond(genPBExpr, relax, externals, eraseOutsideVarDef, p);
!str.empty()) {
ret += ": " + str;
auto iterList = makeIterList(p.iter_, iterDim);
auto condStr = makeCond(genPBExpr, relax, externals, eraseOutsideVarDef, p);
auto accListFactors = makeAccList(genPBExpr, p.access_, relax, externals);
std::ostringstream os;
for (auto &&[i, factor] : views::enumerate(accListFactors)) {
auto &&[accList, accCond] = factor;
os << (i > 0 ? "; " : "") << iterList << " -> " << accList;
auto cond = !condStr.empty() && !accCond.empty()
? condStr + " and " + accCond
: !condStr.empty() ? condStr
: accCond;
if (!cond.empty()) {
os << ": " << cond;
}
}
std::string ext;
if (!externals.empty()) {
Expand All @@ -439,7 +461,7 @@ PBMap AnalyzeDeps::makeAccMapStatic(PBCtx &presburger, const AccessPoint &p,
}
ext = "[" + ext + "] -> ";
}
ret = ext + "{" + ret + "}";
auto ret = ext + "{" + os.str() + "}";
auto unordered = PBMap(presburger, ret);
auto negIterMap = PBMap(presburger, makeNegIterMap(p.iter_, iterDim));
auto ordered = applyDomain(std::move(unordered), std::move(negIterMap));
Expand Down
23 changes: 23 additions & 0 deletions src/analyze/normalize_conditional_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,27 @@ normalizeConditionalExpr(const Expr &expr) {
return ret;
}

std::vector<
std::pair<std::vector<Expr> /* values */, Expr /* condition, maybe null */>>
normalizeConditionalExprList(const std::vector<Expr> &exprs) {
std::vector<std::pair<std::vector<Expr>, Expr>> result;
std::vector<Expr> items, conds;
std::function<void()> recurse = [&]() {
if (items.size() == exprs.size()) {
result.emplace_back(items, combineCond(conds));
} else {
for (const auto &kv :
normalizeConditionalExpr(exprs[items.size()])) {
items.emplace_back(kv.first);
conds.emplace_back(kv.second);
recurse();
items.pop_back();
conds.pop_back();
}
}
};
recurse();
return result;
}

} // namespace freetensor
2 changes: 1 addition & 1 deletion src/schedule/parallelize_as.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class AddParScopes : public TrackStmt<SymbolTable<Mutator>> {
Stmt visitStmt(const Stmt &s) override {
if (s->id() == nest_) {
auto usedNames = uni(names(), allNames(s));
for (auto &&scope : views::reverse(orderedScopes_)) {
for (auto &&scope : orderedScopes_) {
auto newIterName = getNewName(scope->iter_, usedNames);
usedNames.emplace(newIterName);
newIterNames_.emplace_back(newIterName);
Expand Down
18 changes: 18 additions & 0 deletions test/30.schedule/test_reorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,3 +477,21 @@ def test_no_merge_if_outer_iter_var_is_used_in_inner():
s = ft.Schedule(ast, verbose=2)
with pytest.raises(ft.InvalidSchedule):
s.reorder(["L2", "L1"])


def test_if_expr():
with ft.VarDef("y", (32,), "int32", "output", "cpu") as y:
with ft.For("i", 0, 4, label="L1") as i:
with ft.For("j", 0, 8, label="L2") as j:
y[i * 8 + j + ft.if_then_else(i <= 1, 16, -16)] = i + j
ast = ft.pop_ast(verbose=True)
ast = ft.schedule(ast, lambda s: s.reorder(["L2", "L1"]), verbose=1)
ast = ft.lower(ast, verbose=1)

with ft.VarDef("y", (32,), "int32", "output", "cpu") as y:
with ft.For("j", 0, 8, label="L2") as j:
with ft.For("i", 0, 4, label="L1") as i:
y[i * 8 + j + ft.if_then_else(i <= 1, 16, -16)] = i + j
std = ft.pop_ast()

assert std.match(ast)

0 comments on commit 03c26a2

Please sign in to comment.