From a3f63208c197c11db6296bdf614232390382c41a Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Wed, 17 Jan 2024 17:37:48 +0800 Subject: [PATCH] Fix vardef choosing in schedule/parallelize_as --- src/schedule/parallelize_as.cc | 12 +++++----- test/30.schedule/test_parallelize_as.py | 32 +++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/src/schedule/parallelize_as.cc b/src/schedule/parallelize_as.cc index 7e0f550de..86d9fe27a 100644 --- a/src/schedule/parallelize_as.cc +++ b/src/schedule/parallelize_as.cc @@ -38,7 +38,7 @@ PBMap projectOntoOneOutputDim(const PBMap &map, int dim) { class AddParScopes : public TrackStmt> { typedef TrackStmt> BaseClass; - ID nest_; + ID nest_, defId_; const PBCtx &presburger_; const std::vector &orderedScopes_; const std::unordered_map &scope2Idx2Iter_; @@ -55,18 +55,18 @@ class AddParScopes : public TrackStmt> { std::unordered_map>> threadGuard_; public: - AddParScopes(const ID &nest, const PBCtx &presburger, + AddParScopes(const ID &nest, const ID &defId, const PBCtx &presburger, const std::vector &orderedScopes, const std::unordered_map &scope2Idx2Iter) - : nest_(nest), presburger_(presburger), orderedScopes_(orderedScopes), - scope2Idx2Iter_(scope2Idx2Iter) {} + : nest_(nest), defId_(defId), presburger_(presburger), + orderedScopes_(orderedScopes), scope2Idx2Iter_(scope2Idx2Iter) {} const auto &newScopeIds() const { return newScopeIds_; } const auto &newNestId() const { return newNestId_; } private: template auto visitAcc(const T &op) { - if (inside_) { + if (inside_ && def(op->var_)->id() == defId_) { std::vector thisThreadGuard; thisThreadGuard.reserve(orderedScopes_.size()); for (auto &&[scope, newIterName] : @@ -261,7 +261,7 @@ Stmt parallelizeAs(const Stmt &_ast, const ID &nest, const ID &reference, } } - AddParScopes adder{nest, presburger, orderedScopes, scope2Idx2Iter}; + AddParScopes adder{nest, defId, presburger, orderedScopes, scope2Idx2Iter}; ast = adder(ast); // Shrink original loops in `nest` according to the gaurds with just add. If diff --git a/test/30.schedule/test_parallelize_as.py b/test/30.schedule/test_parallelize_as.py index 9a5c4462a..a01d6f83a 100644 --- a/test/30.schedule/test_parallelize_as.py +++ b/test/30.schedule/test_parallelize_as.py @@ -133,6 +133,38 @@ def test_reference_after_nest(): assert std.match(ast) +def test_choosing_vardef(): + with ft.VarDef([("a", (8,), "int32", "input", "cpu"), + ("c", (8,), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", 0, 2) as j: + b[i * 2 + j] = a[i * 2 + j] * 2 + with ft.For("i", 0, 8, label="L2") as i: + c[i] = b[i] + a[0] + ast = ft.pop_ast(verbose=True) + s = ft.Schedule(ast, verbose=1) + s.parallelize("L1", "openmp") + s.parallelize_as("L2", "L1", "Vb") + ast = s.ast() + assert ft.find_stmt(ast, "->L2").property.parallel == "openmp" + + with ft.VarDef([("a", (8,), "int32", "input", "cpu"), + ("c", (8,), "int32", "output", "cpu")]) as (a, c): + ft.MarkLabel("Vb") + with ft.VarDef("b", (8,), "int32", "cache", "cpu") as b: + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", 0, 2) as j: + b[i * 2 + j] = a[i * 2 + j] * 2 + with ft.For("i", 0, 4, label="L1") as i: + with ft.For("j", 0, 2) as j: + c[i * 2 + j] = b[i * 2 + j] + a[0] + std = ft.pop_ast() + + assert std.match(ast) + + @pytest.mark.skipif(not ft.with_cuda(), reason="requires CUDA") def test_multiple_levels(): with ft.VarDef([("a", (128, 128), "int32", "input", "cpu"),