Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix vardef choosing in schedule/parallelize_as #589

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/schedule/parallelize_as.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ PBMap projectOntoOneOutputDim(const PBMap &map, int dim) {
class AddParScopes : public TrackStmt<SymbolTable<Mutator>> {
typedef TrackStmt<SymbolTable<Mutator>> BaseClass;

ID nest_;
ID nest_, defId_;
const PBCtx &presburger_;
const std::vector<For> &orderedScopes_;
const std::unordered_map<ID, PBMap> &scope2Idx2Iter_;
Expand All @@ -55,18 +55,18 @@ class AddParScopes : public TrackStmt<SymbolTable<Mutator>> {
std::unordered_map<ID, std::vector<std::vector<Expr>>> threadGuard_;

public:
AddParScopes(const ID &nest, const PBCtx &presburger,
AddParScopes(const ID &nest, const ID &defId, const PBCtx &presburger,
const std::vector<For> &orderedScopes,
const std::unordered_map<ID, PBMap> &scope2Idx2Iter)
: nest_(nest), presburger_(presburger), orderedScopes_(orderedScopes),
scope2Idx2Iter_(scope2Idx2Iter) {}
: nest_(nest), defId_(defId), presburger_(presburger),
orderedScopes_(orderedScopes), scope2Idx2Iter_(scope2Idx2Iter) {}

const auto &newScopeIds() const { return newScopeIds_; }
const auto &newNestId() const { return newNestId_; }

private:
template <typename T> auto visitAcc(const T &op) {
if (inside_) {
if (inside_ && def(op->var_)->id() == defId_) {
std::vector<Expr> thisThreadGuard;
thisThreadGuard.reserve(orderedScopes_.size());
for (auto &&[scope, newIterName] :
Expand Down Expand Up @@ -261,7 +261,7 @@ Stmt parallelizeAs(const Stmt &_ast, const ID &nest, const ID &reference,
}
}

AddParScopes adder{nest, presburger, orderedScopes, scope2Idx2Iter};
AddParScopes adder{nest, defId, presburger, orderedScopes, scope2Idx2Iter};
ast = adder(ast);

// Shrink original loops in `nest` according to the guards with just add. If
Expand Down
32 changes: 32 additions & 0 deletions test/30.schedule/test_parallelize_as.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,38 @@ def test_reference_after_nest():
assert std.match(ast)


def test_choosing_vardef():
    # Regression test for schedule/parallelize_as (#589): when more than one
    # variable is accessed inside the target nest, the schedule must key its
    # thread guards on the VarDef explicitly passed by ID ("Vb" here), not on
    # an arbitrary accessed variable.

    # Build the input program: "b" (VarDef labeled "Vb") is produced by a
    # 4x2 loop nest L1 and consumed by a flat 8-iteration loop L2.
    with ft.VarDef([("a", (8,), "int32", "input", "cpu"),
                    ("c", (8,), "int32", "output", "cpu")]) as (a, c):
        ft.MarkLabel("Vb")
        with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b:
            with ft.For("i", 0, 4, label="L1") as i:
                with ft.For("j", 0, 2) as j:
                    b[i * 2 + j] = a[i * 2 + j] * 2
            with ft.For("i", 0, 8, label="L2") as i:
                c[i] = b[i] + a[0]
    ast = ft.pop_ast(verbose=True)
    s = ft.Schedule(ast, verbose=1)
    s.parallelize("L1", "openmp")
    # Reshape L2 to follow L1's parallelization, selecting VarDef "Vb".
    s.parallelize_as("L2", "L1", "Vb")
    ast = s.ast()
    # The transformed L2 must carry the same OpenMP binding as L1.
    assert ft.find_stmt(ast, "<For>->L2").property.parallel == "openmp"

    # Expected program: L2 has been rewritten into the same 4x2 structure
    # as L1, with matching index mapping i*2+j.
    with ft.VarDef([("a", (8,), "int32", "input", "cpu"),
                    ("c", (8,), "int32", "output", "cpu")]) as (a, c):
        ft.MarkLabel("Vb")
        with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b:
            with ft.For("i", 0, 4, label="L1") as i:
                with ft.For("j", 0, 2) as j:
                    b[i * 2 + j] = a[i * 2 + j] * 2
            with ft.For("i", 0, 4, label="L1") as i:
                with ft.For("j", 0, 2) as j:
                    c[i * 2 + j] = b[i * 2 + j] + a[0]
    std = ft.pop_ast()

    assert std.match(ast)


@pytest.mark.skipif(not ft.with_cuda(), reason="requires CUDA")
def test_multiple_levels():
with ft.VarDef([("a", (128, 128), "int32", "input", "cpu"),
Expand Down
Loading