Skip to content

Commit

Permalink
Detect high-order strides in schedule/parallelize_as
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck committed Jan 15, 2024
1 parent 051ab9e commit 0f19728
Show file tree
Hide file tree
Showing 8 changed files with 409 additions and 133 deletions.
15 changes: 9 additions & 6 deletions ffi/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,15 @@ void init_ffi_pass(py::module_ &m) {
"stmt"_a);

m.def("shrink_for",
static_cast<Func (*)(const Func &, const ID &, const bool &)>(
&shrinkFor),
"func"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true);
m.def("shrink_for",
static_cast<Stmt (*)(const Stmt &, const ID &, bool)>(&shrinkFor),
"stmt"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true);
static_cast<Func (*)(const Func &, const ID &, const bool &,
const bool &)>(&shrinkFor),
"func"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true,
"unordered"_a = false);
m.def(
"shrink_for",
static_cast<Stmt (*)(const Stmt &, const ID &, bool, bool)>(&shrinkFor),
"stmt"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true,
"unordered"_a = false);

m.def("merge_and_hoist_if",
static_cast<Func (*)(const Func &)>(&mergeAndHoistIf), "func"_a);
Expand Down
31 changes: 26 additions & 5 deletions include/math/parse_pb_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,37 @@ typedef std::vector<SimplePBFuncAST> PBFuncAST;
*/
PBFuncAST parsePBFunc(const std::string &str);

/**
* Parse a PBFunc to be ASTs, but only restricted to one contiguous factor
*/
SimplePBFuncAST parseSimplePBFunc(const std::string &str);

/**
* Construct AST from PBSet while preserving min and max with a special hack to
* ISL
*
* @{
*/
PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set);
PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBMap &map);
/** @} */

/**
* Parse a PBFunc to be ASTs, but only restricted to one contiguous factor
*
* @{
*/
inline SimplePBFuncAST parseSimplePBFunc(const std::string &str) {
auto ret = parsePBFunc(str);
if (ret.size() != 1) {
throw ParserError(str + " is not a simple PBFunc");
}
return ret.front();
}
inline SimplePBFuncAST parseSimplePBFuncReconstructMinMax(const PBCtx &ctx,
const auto &f) {
auto ret = parsePBFuncReconstructMinMax(ctx, f);
if (ret.size() != 1) {
throw ParserError(FT_MSG << f << " is not a simple PBFunc");
}
return ret.front();
}
/** @} */

} // namespace freetensor

Expand Down
7 changes: 7 additions & 0 deletions include/math/presburger.h
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,13 @@ template <PBMapRef T> PBMap upperBoundOutputDim(T &&map, unsigned pos, int x) {
return isl_map_upper_bound_si(PBRefTake<T>(map), isl_dim_out, pos, x);
}

template <PBSetRef T> PBMap newDomainOnlyMap(T &&set) {
return isl_map_from_domain(PBRefTake(set));
}
template <PBSetRef T> PBMap newRangeOnlyMap(T &&set) {
return isl_map_from_range(PBRefTake(set));
}

template <PBMapRef T>
PBMap moveDimsInputToOutput(T &&map, unsigned first, unsigned n,
unsigned target) {
Expand Down
16 changes: 13 additions & 3 deletions include/pass/shrink_for.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class CheckSideEffect : public Visitor {
class ShrinkFor : public CompTransientBounds<SymbolTable<Mutator>> {
typedef CompTransientBounds<SymbolTable<Mutator>> BaseClass;

bool unordered_;

ASTHashMap<Var, std::vector<Ref<CompUniqueBounds::Bound>>> newRange_;
std::vector<Var> iterStack_;
std::vector<std::unordered_set<std::string>> namesStack_;
Expand All @@ -36,6 +38,8 @@ class ShrinkFor : public CompTransientBounds<SymbolTable<Mutator>> {
bool inSubAST_ = false;

public:
ShrinkFor(bool unordered = false) : unordered_(unordered) {}

void setSubAST(const Stmt &subAST);

protected:
Expand Down Expand Up @@ -63,13 +67,19 @@ class ShrinkFor : public CompTransientBounds<SymbolTable<Mutator>> {
* @param doSimplify : If true, run simplify before and after the tranformation.
* Transformations are required to ensure the effectiveness of the shrinking.
* Please do your own simplification if you want to set it to false.
* @param unordered : If true, shrink the loops aggressively which does NOT
* preserve the iterating order. The caller should guarantee the correctness.
* This is effective when the pattern of redundant iterations is complex. This
* may also split one loop into multiple loops.
*
* @{
*/
Stmt shrinkFor(const Stmt &op, const ID &subAST = ID(), bool doSimplify = true);
Stmt shrinkFor(const Stmt &op, const ID &subAST = ID(), bool doSimplify = true,
bool unordered = false);
inline Stmt shrinkFor(const Stmt &op, const Stmt &subAST,
bool doSimplify = true) {
return shrinkFor(op, subAST.isValid() ? subAST->id() : ID(), doSimplify);
bool doSimplify = true, bool unordered = false) {
return shrinkFor(op, subAST.isValid() ? subAST->id() : ID(), doSimplify,
unordered);
}
/** @} */

Expand Down
87 changes: 58 additions & 29 deletions src/math/parse_pb_expr.cc
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#include <sstream>

#include <antlr4-runtime.h>
#include <isl/ast.h>
#include <isl/ast_build.h>
#include <isl/ast_type.h>
#include <isl/union_set.h>

#include <analyze/all_uses.h>
#include <container_utils.h>
#include <debug.h>
#include <math/parse_pb_expr.h>
#include <mutator.h>
Expand Down Expand Up @@ -125,14 +127,6 @@ PBFuncAST parsePBFunc(const std::string &str) {
}
}

SimplePBFuncAST parseSimplePBFunc(const std::string &str) {
auto ret = parsePBFunc(str);
if (ret.size() != 1) {
throw ParserError(str + " is not a simple PBFunc");
}
return ret.front();
}

namespace {

Expr isl2Expr(__isl_take isl_ast_expr *e) {
Expand Down Expand Up @@ -254,24 +248,25 @@ Expr isl2Expr(__isl_take isl_ast_expr *e) {
return res;
}

PBFuncAST isl2Func(__isl_take isl_ast_node *node) {
PBFuncAST ret;
std::vector<std::pair<std::vector<Expr> /* values */, Expr /* cond */>>
isl2Func(__isl_take isl_ast_node *node) {
std::vector<std::pair<std::vector<Expr>, Expr>> ret;
try {
if (isl_ast_node_get_type(node) == isl_ast_node_if) {
auto cond = isl2Expr(isl_ast_node_if_get_cond(node));
for (auto &&[thenNames, thenFT, thenCond] :
for (auto &&[thenFT, thenCond] :
isl2Func(isl_ast_node_if_get_then(node))) {
ret.push_back(SimplePBFuncAST{
thenNames, thenFT,
thenCond.isValid() ? makeLAnd(cond, thenCond) : cond});
ret.emplace_back(thenFT, thenCond.isValid()
? makeLAnd(cond, thenCond)
: cond);
}
if (isl_ast_node_if_has_else(node)) {
for (auto &&[elseNames, elseFT, elseCond] :
for (auto &&[elseFT, elseCond] :
isl2Func(isl_ast_node_if_get_else(node))) {
ret.push_back(SimplePBFuncAST{
elseNames, elseFT,
elseCond.isValid() ? makeLAnd(makeLNot(cond), elseCond)
: makeLNot(cond)});
ret.emplace_back(elseFT,
elseCond.isValid()
? makeLAnd(makeLNot(cond), elseCond)
: makeLNot(cond));
}
}

Expand All @@ -292,14 +287,7 @@ PBFuncAST isl2Func(__isl_take isl_ast_node *node) {
}) |
ranges::to_vector;

std::unordered_set<std::string> names;
for (auto &&item : vals) {
for (auto &&name : allNames(item)) {
names.insert(name);
}
}
ret = {SimplePBFuncAST{ranges::to<std::vector>(names), vals,
nullptr}};
ret = {{vals, nullptr}};
} catch (...) {
isl_ast_expr_free(expr);
throw;
Expand All @@ -324,6 +312,13 @@ PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set) {

ASSERT(set.isSingleValued());

std::vector<std::string> params =
views::ints(0, set.nParamDims()) |
views::transform([&](int i) -> std::string {
return isl_set_get_dim_name(set.get(), isl_dim_param, i);
}) |
ranges::to<std::vector>();

isl_options_set_ast_build_detect_min_max(ctx.get(), 1);

PBFuncAST ret;
Expand All @@ -333,7 +328,9 @@ PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set) {
isl_schedule_from_domain(isl_union_set_from_set(set.copy()));
isl_ast_node *ast =
isl_ast_build_node_from_schedule(build /* keep */, s /* take */);
ret = isl2Func(ast /* take */);
for (auto &&[vals, cond] : isl2Func(ast /* take */)) {
ret.emplace_back(params, vals, cond);
}
} catch (...) {
isl_ast_build_free(build);
throw;
Expand All @@ -343,4 +340,36 @@ PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBSet &set) {
return ret;
}

namespace {

template <PBMapRef T> PBMap moveAllInputDimsToParam(const PBCtx &ctx, T &&map) {
// A name is required for the parameter, so we can't simply use
// isl_map_move_dims. We constuct a map to apply on the set to move the
// dimension. Example map: [i1, i2] -> {[i1, i2] -> []}. The parameters are
// assigned with temporary names.

int nInDims = map.nInDims();
std::ostringstream os;
os << "["
<< (views::ints(0, nInDims) | views::transform([](int i) {
return "ft_unnamed_in_dim_" + std::to_string(i);
}) |
join(","))
<< "] -> {["
<< (views::ints(0, nInDims) | views::transform([](int i) {
return "ft_unnamed_in_dim_" + std::to_string(i);
}) |
join(","))
<< "] -> []}";
PBMap moving(ctx, os.str());
return applyDomain(std::forward<T>(map), std::move(moving));
}

} // Anonymous namespace

PBFuncAST parsePBFuncReconstructMinMax(const PBCtx &ctx, const PBMap &map) {
return parsePBFuncReconstructMinMax(
ctx, range(moveAllInputDimsToParam(ctx, map)));
}

} // namespace freetensor
Loading

0 comments on commit 0f19728

Please sign in to comment.