Skip to content

Commit

Permalink
Add schedule/parallelize_as
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck committed Jan 10, 2024
1 parent ebefa09 commit b84d74f
Show file tree
Hide file tree
Showing 18 changed files with 638 additions and 61 deletions.
41 changes: 31 additions & 10 deletions ffi/parallel_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,40 @@ void init_ffi_parallel_scope(py::module_ &m) {
.def(py::init<>())
.def("__str__",
[](const SerialScope &scope) { return toString(scope); })
.def("__eq__", [](const SerialScope &lhs, const SerialScope &rhs) {
return lhs == rhs;
});
.def("__eq__", [](const SerialScope &lhs,
const SerialScope &rhs) { return lhs == rhs; })
.def("__eq__",
[](const SerialScope &lhs, const std::string &rhs) {
return ParallelScope{lhs} == parseParallelScope(rhs);
})
.def("__eq__",
[](const SerialScope &lhs, py::object rhs) { return false; });

py::class_<OpenMPScope>(m, "OpenMPScope")
.def(py::init<>())
.def("__str__",
[](const OpenMPScope &scope) { return toString(scope); })
.def("__eq__", [](const OpenMPScope &lhs, const OpenMPScope &rhs) {
return lhs == rhs;
});
.def("__eq__", [](const OpenMPScope &lhs,
const OpenMPScope &rhs) { return lhs == rhs; })
.def("__eq__",
[](const OpenMPScope &lhs, const std::string &rhs) {
return ParallelScope{lhs} == parseParallelScope(rhs);
})
.def("__eq__",
[](const OpenMPScope &lhs, py::object rhs) { return false; });

py::class_<CUDAStreamScope>(m, "CUDAStreamScope")
.def(py::init<>())
.def("__str__",
[](const CUDAStreamScope &scope) { return toString(scope); })
.def("__eq__", [](const CUDAStreamScope &lhs,
const CUDAStreamScope &rhs) { return lhs == rhs; });
const CUDAStreamScope &rhs) { return lhs == rhs; })
.def("__eq__",
[](const CUDAStreamScope &lhs, const std::string &rhs) {
return ParallelScope{lhs} == parseParallelScope(rhs);
})
.def("__eq__",
[](const CUDAStreamScope &lhs, py::object rhs) { return false; });

py::enum_<CUDAScope::Level>(m, "CUDAScopeLevel")
.value("Block", CUDAScope::Level::Block)
Expand All @@ -40,9 +56,14 @@ void init_ffi_parallel_scope(py::module_ &m) {
return CUDAScope{level, dim};
}))
.def("__str__", [](const CUDAScope &scope) { return toString(scope); })
.def("__eq__", [](const CUDAScope &lhs, const CUDAScope &rhs) {
return lhs == rhs;
});
.def("__eq__", [](const CUDAScope &lhs,
const CUDAScope &rhs) { return lhs == rhs; })
.def("__eq__",
[](const CUDAScope &lhs, const std::string &rhs) {
return ParallelScope{lhs} == parseParallelScope(rhs);
})
.def("__eq__",
[](const CUDAScope &lhs, py::object rhs) { return false; });

// Factory function, used as a class
m.def("ParallelScope", &parseParallelScope);
Expand Down
8 changes: 4 additions & 4 deletions ffi/pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@ void init_ffi_pass(py::module_ &m) {
"stmt"_a);

m.def("shrink_for",
static_cast<Func (*)(const Func &, const Stmt &, const bool &)>(
static_cast<Func (*)(const Func &, const ID &, const bool &)>(
&shrinkFor),
"func"_a, "sub_ast"_a = nullptr, "do_simplify"_a = true);
"func"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true);
m.def("shrink_for",
static_cast<Stmt (*)(const Stmt &, const Stmt &, bool)>(&shrinkFor),
"stmt"_a, "sub_ast"_a = nullptr, "do_simplify"_a = true);
static_cast<Stmt (*)(const Stmt &, const ID &, bool)>(&shrinkFor),
"stmt"_a, py::arg_v("sub_ast", ID(), "ID()"), "do_simplify"_a = true);

m.def("merge_and_hoist_if",
static_cast<Func (*)(const Func &)>(&mergeAndHoistIf), "func"_a);
Expand Down
2 changes: 2 additions & 0 deletions ffi/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ void init_ffi_schedule(py::module_ &m) {
.def("inline", &Schedule::inlining, "vardef"_a)
.def("parallelize", &Schedule::parallelize, "loop"_a, "parallel"_a,
"allow_reduction"_a = true)
.def("parallelize_as", &Schedule::parallelizeAs, "nest"_a,
"reference"_a, "def_id"_a)
.def("unroll", &Schedule::unroll, "loop"_a, "immedate"_a = false)
.def("vectorize", &Schedule::vectorize, "loop"_a)
.def("separate_tail", &Schedule::separateTail,
Expand Down
12 changes: 11 additions & 1 deletion include/analyze/deps.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,12 +368,22 @@ class AnalyzeDeps {
static std::string makeCond(GenPBExpr &genPBExpr, RelaxMode relax,
GenPBExpr::VarMap &externals,
bool eraseOutsideVarDef, const AccessPoint &ap);
static PBMap makeAccMapStatic(PBCtx &presburger, const AccessPoint &p,
int iterDim, int accDim, RelaxMode relax,
const std::string &extSuffix,
GenPBExpr::VarMap &externals,
const ASTHashSet<Expr> &noNeedToBeVars,
bool eraseOutsideVarDef);

private:
    // Member-function counterpart of makeAccMapStatic: forwards all arguments
    // unchanged and supplies this instance's eraseOutsideVarDef_ flag.
    PBMap makeAccMap(PBCtx &presburger, const AccessPoint &p, int iterDim,
                     int accDim, RelaxMode relax, const std::string &extSuffix,
                     GenPBExpr::VarMap &externals,
                     const ASTHashSet<Expr> &noNeedToBeVars) {
        return makeAccMapStatic(presburger, p, iterDim, accDim, relax,
                                extSuffix, externals, noNeedToBeVars,
                                eraseOutsideVarDef_);
    }

PBMap makeEqForBothOps(PBCtx &presburger,
const std::vector<std::pair<int, int>> &coord,
Expand Down
14 changes: 14 additions & 0 deletions include/get_new_name.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#ifndef FREE_TENSOR_GET_NEW_NAME_H
#define FREE_TENSOR_GET_NEW_NAME_H

#include <string>
#include <unordered_set>

namespace freetensor {

/**
 * Generate a fresh name derived from `oldName` that does not collide with any
 * name in `used`
 *
 * Candidates are tried as "<oldName>.1", "<oldName>.2", ... and the first one
 * absent from `used` is returned.
 *
 * @param oldName : Base name to derive the new name from.
 * @param used : Set of names already taken.
 * @return : A name guaranteed not to be contained in `used`.
 */
std::string getNewName(const std::string &oldName,
                       const std::unordered_set<std::string> &used);

}

#endif // FREE_TENSOR_GET_NEW_NAME_H
10 changes: 8 additions & 2 deletions include/pass/shrink_for.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,15 @@ class ShrinkFor : public CompTransientBounds<SymbolTable<Mutator>> {
/**
* Increase the begin and decrease the end index, to remove redundant iterations
* from For loops
*
* @{
*/
Stmt shrinkFor(const Stmt &op, const Stmt &subAST = nullptr,
bool doSimplify = true);
// Primary overload: identifies the sub-AST to shrink by ID; an invalid
// (default) ID means the whole tree is shrunk.
Stmt shrinkFor(const Stmt &op, const ID &subAST = ID(), bool doSimplify = true);
// Convenience overload: accepts the sub-AST node itself and forwards its ID
// (an invalid node maps to an invalid ID, i.e. shrink the whole tree).
inline Stmt shrinkFor(const Stmt &op, const Stmt &subAST,
                      bool doSimplify = true) {
    return shrinkFor(op, subAST.isValid() ? subAST->id() : ID(), doSimplify);
}
/** @} */

DEFINE_PASS_FOR_FUNC(shrinkFor)

Expand Down
15 changes: 15 additions & 0 deletions include/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,21 @@ class Schedule {
void parallelize(const ID &loop, const ParallelScope &parallel,
bool allowReduction = true);

/**
* Parallelize a loop nest according to another loop nest to keep a tensor
* thread-local
*
* @param nest : ID of the loop nest to be parallelized. The ID can be of
* any statement type, and all statements it contains will be parallelized.
* @param reference: ID of the loop nest to be referenced. The ID can be of
* any statement type, and all statements it contains will be referenced.
* @param defId : ID of the VarDef statement of the tensor to be kept
* thread-local.
* @throw InvalidSchedule if any of the ID is not found, or the reference
* loop nest is already thread-non-local.
*/
void parallelizeAs(const ID &nest, const ID &reference, const ID &defId);

/**
* Unroll a loop
*
Expand Down
13 changes: 13 additions & 0 deletions include/schedule/parallelize_as.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#ifndef FREE_TENSOR_PARALLELIZE_AS_H
#define FREE_TENSOR_PARALLELIZE_AS_H

#include <stmt.h>

namespace freetensor {

/**
 * Parallelize a loop nest according to another loop nest to keep a tensor
 * thread-local
 *
 * @param ast : The AST to transform.
 * @param nest : ID of the loop nest to be parallelized. The ID can be of any
 * statement type, and all statements it contains will be parallelized.
 * @param reference : ID of the loop nest to be referenced. The ID can be of
 * any statement type, and all statements it contains will be referenced.
 * @param defId : ID of the VarDef statement of the tensor to be kept
 * thread-local.
 * @throw InvalidSchedule if any of the ID is not found, or the reference loop
 * nest is already thread-non-local.
 */
Stmt parallelizeAs(const Stmt &ast, const ID &nest, const ID &reference,
                   const ID &defId);

} // namespace freetensor

#endif // FREE_TENSOR_PARALLELIZE_AS_H
17 changes: 9 additions & 8 deletions include/schedule/schedule_log.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ enum class ScheduleType : int {
VarReorder,
Inline,
Parallelize,
ParallelizeAs,
Unroll,
Vectorize,
SeparateTail,
Expand All @@ -42,14 +43,14 @@ enum class ScheduleType : int {
};

constexpr std::array scheduleTypeNames = {
"split", "reorder", "merge",
"fission", "fuse", "swap",
"blend", "cache", "cache_reduction",
"set_mem_type", "var_split", "var_merge",
"var_reorder", "inline", "parallelize",
"unroll", "vectorize", "separate_tail",
"as_matmul", "permute", "pluto_fuse",
"pluto_permute",
"split", "reorder", "merge",
"fission", "fuse", "swap",
"blend", "cache", "cache_reduction",
"set_mem_type", "var_split", "var_merge",
"var_reorder", "inline", "parallelize",
"parallelize_as", "unroll", "vectorize",
"separate_tail", "as_matmul", "permute",
"pluto_fuse", "pluto_permute",
};
static_assert(scheduleTypeNames.size() == (size_t)ScheduleType::NumTypes);

Expand Down
25 changes: 25 additions & 0 deletions python/freetensor/core/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,31 @@ def parallelize(self, loop, parallel):
"""
super().parallelize(self._lookup(loop), ParallelScope(parallel))

def parallelize_as(self, nest, reference, vardef):
    '''
    Parallelize a loop nest according to another loop nest to keep a tensor
    thread-local

    Parameters
    ----------
    nest : str, ID or Stmt
        The loop nest to be parallelized. The ID can be of any statement type,
        and all statements it contains will be parallelized.
    reference : str, ID or Stmt
        The loop nest to be referenced. The ID can be of any statement type,
        and all statements it contains will be referenced.
    vardef : str, ID or Stmt
        The VarDef statement of the tensor to be kept thread-local.

    Raises
    ------
    InvalidSchedule
        if any of the ID is not found, or the reference loop nest is already
        thread-non-local.
    '''
    super().parallelize_as(self._lookup(nest), self._lookup(reference),
                           self._lookup(vardef))

def unroll(self, loop, immediate=False):
"""
Unroll a loop
Expand Down
14 changes: 7 additions & 7 deletions src/analyze/deps.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,16 +417,16 @@ std::string AnalyzeDeps::makeCond(GenPBExpr &genPBExpr, RelaxMode relax,
return ret;
}

PBMap AnalyzeDeps::makeAccMap(PBCtx &presburger, const AccessPoint &p,
int iterDim, int accDim, RelaxMode relax,
const std::string &extSuffix,
GenPBExpr::VarMap &externals,
const ASTHashSet<Expr> &noNeedToBeVars) {
PBMap AnalyzeDeps::makeAccMapStatic(PBCtx &presburger, const AccessPoint &p,
int iterDim, int accDim, RelaxMode relax,
const std::string &extSuffix,
GenPBExpr::VarMap &externals,
const ASTHashSet<Expr> &noNeedToBeVars,
bool eraseOutsideVarDef) {
GenPBExpr genPBExpr(extSuffix, noNeedToBeVars);
auto ret = makeIterList(p.iter_, iterDim) + " -> " +
makeAccList(genPBExpr, p.access_, relax, externals);
if (auto str =
makeCond(genPBExpr, relax, externals, eraseOutsideVarDef_, p);
if (auto str = makeCond(genPBExpr, relax, externals, eraseOutsideVarDef, p);
!str.empty()) {
ret += ": " + str;
}
Expand Down
10 changes: 1 addition & 9 deletions src/frontend/inlined_invoke.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,12 @@
#include <analyze/all_uses.h>
#include <container_utils.h>
#include <frontend/inlined_invoke.h>
#include <get_new_name.h>
#include <pass/hoist_return_vars.h>
#include <pass/rename_var.h>

namespace freetensor {

static std::string getNewName(const std::string &oldName,
const std::unordered_set<std::string> &used) {
for (int i = 1;; i++) {
if (auto name = oldName + "." + std::to_string(i); !used.count(name)) {
return name;
}
}
}

Stmt StripReturns::visit(const VarDef &op) {
if (auto it = std::find_if(
returns_.begin(), returns_.end(),
Expand Down
14 changes: 14 additions & 0 deletions src/get_new_name.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include <get_new_name.h>

namespace freetensor {

// Derive a fresh name from `oldName` by appending ".1", ".2", ... and return
// the first candidate that is not already present in `used`.
std::string getNewName(const std::string &oldName,
                       const std::unordered_set<std::string> &used) {
    int suffix = 0;
    while (true) {
        std::string candidate = oldName + "." + std::to_string(++suffix);
        if (used.find(candidate) == used.end()) {
            return candidate;
        }
    }
}

} // namespace freetensor
10 changes: 1 addition & 9 deletions src/pass/gpu/normalize_var_in_kernel.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <analyze/comp_unique_bounds_combination.h>
#include <container_utils.h>
#include <get_new_name.h>
#include <pass/gpu/normalize_var_in_kernel.h>
#include <pass/rename_var.h>
#include <pass/simplify.h>
Expand Down Expand Up @@ -35,15 +36,6 @@ std::unordered_map<std::string, int> countNames(const Stmt &s) {
return visitor.nameCnt();
}

std::string getNewName(const std::string &oldName,
const std::unordered_set<std::string> &used) {
for (int i = 1;; i++) {
if (auto name = oldName + "." + std::to_string(i); !used.count(name)) {
return name;
}
}
}

} // Anonymous namespace

Stmt NormalizeVarInKernel::visit(const VarDef &_op) {
Expand Down
5 changes: 3 additions & 2 deletions src/pass/shrink_for.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <analyze/comp_unique_bounds_pb.h>
#include <analyze/find_stmt.h>
#include <math/min_max.h>
#include <math/parse_pb_expr.h>
#include <pass/replace_iter.h>
Expand Down Expand Up @@ -213,7 +214,7 @@ void ShrinkFor::setSubAST(const Stmt &subAST) {
subASTAncestors_.insert(s);
}

Stmt shrinkFor(const Stmt &_op, const Stmt &subAST, bool doSimplify) {
Stmt shrinkFor(const Stmt &_op, const ID &subAST, bool doSimplify) {
auto op = _op;

// DO NOT CALL normalizeLoops HERE! Since we often use (-INT_MAX, INT_MAX)
Expand All @@ -225,7 +226,7 @@ Stmt shrinkFor(const Stmt &_op, const Stmt &subAST, bool doSimplify) {

ShrinkFor shrinker;
if (subAST.isValid())
shrinker.setSubAST(subAST);
shrinker.setSubAST(findStmt(op, subAST));
op = shrinker(op);

if (doSimplify) // Make new ranges simple + remove redundant branches
Expand Down
10 changes: 1 addition & 9 deletions src/schedule/hoist_selected_var.cc
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
#include <analyze/all_uses.h>
#include <analyze/find_stmt.h>
#include <container_utils.h>
#include <get_new_name.h>
#include <pass/rename_var.h>
#include <schedule/hoist_selected_var.h>

namespace freetensor {

static std::string getNewName(const std::string &oldName,
const std::unordered_set<std::string> &used) {
for (int i = 1;; i++) {
if (auto name = oldName + "." + std::to_string(i); !used.count(name)) {
return name;
}
}
}

Stmt HoistSelectedVar::visit(const For &op) {
if (toHoist_.count(op->body_->id())) {
ASSERT(op->body_->nodeType() == ASTNodeType::VarDef);
Expand Down
Loading

0 comments on commit b84d74f

Please sign in to comment.