Skip to content

Commit

Permalink
Fix vardef choosing in schedule/parallelize_as
Browse files Browse the repository at this point in the history
  • Loading branch information
roastduck committed Jan 17, 2024
1 parent 8222fe3 commit a3f6320
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/schedule/parallelize_as.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ PBMap projectOntoOneOutputDim(const PBMap &map, int dim) {
class AddParScopes : public TrackStmt<SymbolTable<Mutator>> {
typedef TrackStmt<SymbolTable<Mutator>> BaseClass;

ID nest_;
ID nest_, defId_;
const PBCtx &presburger_;
const std::vector<For> &orderedScopes_;
const std::unordered_map<ID, PBMap> &scope2Idx2Iter_;
Expand All @@ -55,18 +55,18 @@ class AddParScopes : public TrackStmt<SymbolTable<Mutator>> {
std::unordered_map<ID, std::vector<std::vector<Expr>>> threadGuard_;

public:
AddParScopes(const ID &nest, const PBCtx &presburger,
AddParScopes(const ID &nest, const ID &defId, const PBCtx &presburger,
const std::vector<For> &orderedScopes,
const std::unordered_map<ID, PBMap> &scope2Idx2Iter)
: nest_(nest), presburger_(presburger), orderedScopes_(orderedScopes),
scope2Idx2Iter_(scope2Idx2Iter) {}
: nest_(nest), defId_(defId), presburger_(presburger),
orderedScopes_(orderedScopes), scope2Idx2Iter_(scope2Idx2Iter) {}

const auto &newScopeIds() const { return newScopeIds_; }
const auto &newNestId() const { return newNestId_; }

private:
template <typename T> auto visitAcc(const T &op) {
if (inside_) {
if (inside_ && def(op->var_)->id() == defId_) {
std::vector<Expr> thisThreadGuard;
thisThreadGuard.reserve(orderedScopes_.size());
for (auto &&[scope, newIterName] :
Expand Down Expand Up @@ -261,7 +261,7 @@ Stmt parallelizeAs(const Stmt &_ast, const ID &nest, const ID &reference,
}
}

AddParScopes adder{nest, presburger, orderedScopes, scope2Idx2Iter};
AddParScopes adder{nest, defId, presburger, orderedScopes, scope2Idx2Iter};
ast = adder(ast);

// Shrink original loops in `nest` according to the gaurds with just add. If
Expand Down
32 changes: 32 additions & 0 deletions test/30.schedule/test_parallelize_as.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,38 @@ def test_reference_after_nest():
assert std.match(ast)


def test_choosing_vardef():
with ft.VarDef([("a", (8,), "int32", "input", "cpu"),
("c", (8,), "int32", "output", "cpu")]) as (a, c):
ft.MarkLabel("Vb")
with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b:
with ft.For("i", 0, 4, label="L1") as i:
with ft.For("j", 0, 2) as j:
b[i * 2 + j] = a[i * 2 + j] * 2
with ft.For("i", 0, 8, label="L2") as i:
c[i] = b[i] + a[0]
ast = ft.pop_ast(verbose=True)
s = ft.Schedule(ast, verbose=1)
s.parallelize("L1", "openmp")
s.parallelize_as("L2", "L1", "Vb")
ast = s.ast()
assert ft.find_stmt(ast, "<For>->L2").property.parallel == "openmp"

with ft.VarDef([("a", (8,), "int32", "input", "cpu"),
("c", (8,), "int32", "output", "cpu")]) as (a, c):
ft.MarkLabel("Vb")
with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b:
with ft.For("i", 0, 4, label="L1") as i:
with ft.For("j", 0, 2) as j:
b[i * 2 + j] = a[i * 2 + j] * 2
with ft.For("i", 0, 4, label="L1") as i:
with ft.For("j", 0, 2) as j:
c[i * 2 + j] = b[i * 2 + j] + a[0]
std = ft.pop_ast()

assert std.match(ast)


@pytest.mark.skipif(not ft.with_cuda(), reason="requires CUDA")
def test_multiple_levels():
with ft.VarDef([("a", (128, 128), "int32", "input", "cpu"),
Expand Down

0 comments on commit a3f6320

Please sign in to comment.