From 72c8448208fc3e9b303acedc45b047efb9d3ea0b Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Sun, 18 Feb 2024 09:11:05 +0800 Subject: [PATCH 1/4] Support reversed(dynamic_range) (#597) --- python/freetensor/core/frontend.py | 10 ++++++ python/freetensor/core/transform.py | 6 +++- test/50.frontend/test_transformer_basic.py | 38 ++++++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/python/freetensor/core/frontend.py b/python/freetensor/core/frontend.py index ffdfb7ef9..e73ac1280 100644 --- a/python/freetensor/core/frontend.py +++ b/python/freetensor/core/frontend.py @@ -3,6 +3,7 @@ ''' import sys +import builtins import numpy as np import inspect import traceback @@ -552,3 +553,12 @@ def foreach(self, name, body: Callable[[Any], None]) -> None: static_range = range + + +def reversed(rng): + if isinstance(rng, dynamic_range): + return dynamic_range( + rng.start + rng.step * ((rng.stop - rng.start - 1) // rng.step), + rng.start - rng.step, -rng.step) + else: + return builtins.reversed(rng) diff --git a/python/freetensor/core/transform.py b/python/freetensor/core/transform.py index 95b8d4462..c916776a3 100644 --- a/python/freetensor/core/transform.py +++ b/python/freetensor/core/transform.py @@ -10,6 +10,7 @@ from .expr import UndeclaredParam from .stmt import VarRef from .func import Func +from . import frontend from .frontend import lang_overload, staged_callable, LifetimeScope, dynamic_range from .context import pop_ast_and_user_grads, ctx_stack from .staging import StagingError, TransformError @@ -19,7 +20,10 @@ def _prepare_extra_locals(default_dynamic_range): - extra_locals = {'__ft__': sys.modules['freetensor']} + extra_locals = { + '__ft__': sys.modules['freetensor'], + 'reversed': frontend.reversed, + } if default_dynamic_range: extra_locals['range'] = dynamic_range return extra_locals diff --git a/test/50.frontend/test_transformer_basic.py b/test/50.frontend/test_transformer_basic.py index 47ab450f8..ea9c9ef1a 100644 --- a/test/50.frontend/test_transformer_basic.py +++ b/test/50.frontend/test_transformer_basic.py @@ -547,6 +547,44 @@ def test_expected(x: ft.Var[(), 'float32']): assert test.body.match(test_expected.body) +def test_reversed_dynamic_range_1(): + + @ft.transform(verbose=2) + def test(x, y): + x: ft.Var[(16,), "int32", "input"] + y: ft.Var[(16,), "int32", "output"] + for i in reversed(range(0, 16, 2)): + y[i] = x[i] + 1 + + @ft.transform + def test_expected(x, y): + x: ft.Var[(16,), "int32", "input"] + y: ft.Var[(16,), "int32", "output"] + for i in range(14, -2, -2): + y[i] = x[i] + 1 + + assert test.body.match(test_expected.body) + + +def test_reversed_dynamic_range_2(): + + @ft.transform(verbose=2) + def test(x, y): + x: ft.Var[(16,), "int32", "input"] + y: ft.Var[(16,), "int32", "output"] + for i in reversed(range(0, 15, 2)): + y[i] = x[i] + 1 + + @ft.transform + def test_expected(x, y): + x: ft.Var[(16,), "int32", "input"] + y: ft.Var[(16,), "int32", "output"] + for i in range(14, -2, -2): + y[i] = x[i] + 1 + + assert test.body.match(test_expected.body) + + @dataclass class DummyAssigned: attr = None From a3736dac58775aff59f39c8bc2998d5a1143912f Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Mon, 19 Feb 2024 12:54:36 +0800 Subject: [PATCH 2/4] Add `A<:(B<:)*C` and `C:>(B:>)*A` selectors (#598) --- docs/guide/schedules.md | 6 ++--- grammar/selector_parser.g | 6 +++++ include/selector.h | 12 ++++++--- src/selector.cc | 6 +++++ test/10.analyze/test_find_stmt.py | 42 +++++++++++++++++++++++++++++++ 5 files changed, 65 insertions(+), 7 deletions(-) diff --git a/docs/guide/schedules.md b/docs/guide/schedules.md index 4361644c8..d6aa137cc 100644 --- a/docs/guide/schedules.md +++ b/docs/guide/schedules.md @@ -79,9 +79,9 @@ In the example above, we label a loop `Li` and apply schedules on it. It is stra 2. (For debugging only) A numerical ID is also a selector. E.g., `#31`. 3. A node type surrounded in angle brackets (`<>`) is also a selector. E.g., `` matches for-loop statements. 4. A selector can be extended to match a new statement produced by a previous schedule. E.g., `$split.0{Li}` matches the outer loop split from the loop `Li`. This is useful when return values from schedules are hard to track. Please refer the [API document](../../api/#freetensor.core.schedule.Schedule) for detailed grammar. -5. Selectors can be combined to match a statement by nesting order. `A<-B` matches a statement `A` DIRECTLY NESTED IN another statement `B`. `A<<-B` matches a statement DIRECTLY or INDIRECTLY nested in another statement `B`. `A<-(B<-)*C` matches a statement `A` DIRECTLY or INDIRECTLY nested in another statement `C` with intermedaite nesting statements satisfying the condition in `B`. `B->A` matches a statement `B` directly OUT OF another statement `A`. `B->>A` and `C->(B->)*A` are alike. (`A`, `B`, `C` can be nested selectors.) Use `<-|` for the root node, and `->|` for a leaf node. -6. Selectors can be combined to match a statement by DFS order. `A<:B` matches a statement `A` DIRECTLY BEFORE another statement `B`. `A<<:B` matches a statement `A` DIRECTLY or INDIRECTLY before another statement `B`. `B:>A` matches a statment `B` directly AFTER another statement `A`. `B:>>A` matches a statement `B` directly or indirectly after another statement `A`. A statement's descents (other statements nested in) or ancestors (other statements nesting it) are neither considered before or after the statement in the DFS order. "Directly" means there is no other statement between the comparing statements, so a statment can "directly before" both another statement and its parent. -7. Selectors can be combined to match a statement in a function call. `A<~B` matches a statement `A` DIRECTLY called by a call site `B`. `A<<~B` matches a statement DIRECTLY or INDIRECTLY called by a call site `B`. `A<~(B<~)*C` matches a statement `A` DIRECTLY or INDIRECTLY called by a call site `C` with intermediate call sites satisfying the condition in `B`. (`A`, `B`, `C` can be nested selectors.) Use `<~|` for the root function. +5. Selectors can be combined to match a statement by nesting order. `A<-B` matches a statement `A` DIRECTLY NESTED IN another statement `B`. `A<<-B` matches a statement `A` DIRECTLY or INDIRECTLY nested in another statement `B`. `A<-(B<-)*C` matches a statement `A` DIRECTLY or INDIRECTLY nested in another statement `C` with intermedaite nesting statements satisfying the condition in `B`. `B->A` matches a statement `B` directly OUT OF another statement `A`. `B->>A` and `C->(B->)*A` are alike. (`A`, `B`, `C` can be nested selectors.) Use `<-|` for the root node, and `->|` for a leaf node. +6. Selectors can be combined to match a statement by program order. `A<:B` matches a statement `A` DIRECTLY BEFORE another statement `B`. `A<<:B` matches a statement `A` DIRECTLY or INDIRECTLY before another statement `B`. `A<:(B<:)*C` matches a statement `A` DIRECTLY or INDIRECTLY before another statement `C` with intermediate statements satisfying the condition in `B`. `B:>A` matches a statment `B` directly AFTER another statement `A`. `B:>>A` and `C:>(B:>)*A` are alike. A statement's descents (other statements nested in) or ancestors (other statements nesting it) are neither considered before or after the statement in the program order. "Directly" means there is no other statement between the comparing statements, so a statment can "directly before" both another statement and its parent. +7. Selectors can be combined to match a statement in a function call. `A<~B` matches a statement `A` DIRECTLY called by a call site `B`. `A<<~B` matches a statement `A` DIRECTLY or INDIRECTLY called by a call site `B`. `A<~(B<~)*C` matches a statement `A` DIRECTLY or INDIRECTLY called by a call site `C` with intermediate call sites satisfying the condition in `B`. (`A`, `B`, `C` can be nested selectors.) Use `<~|` for the root function. 8. All the arrow-like selectors (`<-`, `<~`, `<:`, etc.) are right-associated. For example, `A<-B<-C` matches `A` nested in `B`, where `B` is nested in `C`. 9. All the arrow-like selectors can be used with the first argument omitted. For example, `<-B` matches ALL statements nested in `B`. 10. Selectors can be combined with logical "and" (`&`), "or" (`|`), "not" (`!`) and parentheses. E.g., `Li|Lj` matches a statement labeled `Li` OR `Lj`. `Li&Lj` matches a statement labeled `Li&Lj`. diff --git a/grammar/selector_parser.g b/grammar/selector_parser.g index b1a05cde5..b71de10a3 100644 --- a/grammar/selector_parser.g +++ b/grammar/selector_parser.g @@ -158,9 +158,15 @@ selectorImplicitAnd returns[Ref s] | BeforeArrow following=selectorFactor { $s = Ref::make($following.s); } + | DirectBeforeArrow LeftParen middle=selectorFactor DirectBeforeArrow RightParen Star following=selectorFactor { + $s = Ref::make($following.s, $middle.s); + } | DirectAfterArrow leading=selectorFactor { $s = Ref::make($leading.s); } | AfterArrow leading=selectorFactor { $s = Ref::make($leading.s); + } + | DirectAfterArrow LeftParen middle=selectorFactor DirectAfterArrow RightParen Star leading=selectorFactor { + $s = Ref::make($leading.s, $middle.s); }; diff --git a/include/selector.h b/include/selector.h index ee3328cbc..4263a9b80 100644 --- a/include/selector.h +++ b/include/selector.h @@ -124,13 +124,15 @@ class DirectBeforeSelector : public Selector { }; class BeforeSelector : public Selector { - Ref following_; + Ref following_, middle_; protected: bool matchImpl(const Stmt &stmt) override; public: - BeforeSelector(const Ref &following) : following_(following) {} + BeforeSelector(const Ref &following, + const Ref &middle = nullptr) + : following_(following), middle_(middle) {} }; class DirectAfterSelector : public Selector { @@ -144,13 +146,15 @@ class DirectAfterSelector : public Selector { }; class AfterSelector : public Selector { - Ref leading_; + Ref leading_, middle_; protected: bool matchImpl(const Stmt &stmt) override; public: - AfterSelector(const Ref &leading) : leading_(leading) {} + AfterSelector(const Ref &leading, + const Ref &middle = nullptr) + : leading_(leading), middle_(middle) {} }; class RootNodeSelector : public Selector { diff --git a/src/selector.cc b/src/selector.cc index 2a286065b..76b7a897e 100644 --- a/src/selector.cc +++ b/src/selector.cc @@ -89,6 +89,9 @@ bool BeforeSelector::matchImpl(const Stmt &stmt) { following_->match(next)) { return true; } + if (middle_.isValid() && !middle_->match(next)) { + return false; + } } return false; } @@ -112,6 +115,9 @@ bool AfterSelector::matchImpl(const Stmt &stmt) { leading_->match(prev)) { return true; } + if (middle_.isValid() && !middle_->match(prev)) { + return false; + } } return false; } diff --git a/test/10.analyze/test_find_stmt.py b/test/10.analyze/test_find_stmt.py index e6b5341ce..704baa9a1 100644 --- a/test/10.analyze/test_find_stmt.py +++ b/test/10.analyze/test_find_stmt.py @@ -193,6 +193,27 @@ def test_select_before(): assert sorted_ids(results) == sorted_ids(results_by_label) +def test_select_before_with_middle(): + with ft.VarDef("x", (), "int32", "inout", "cpu") as x: + with ft.If(x[...] < 0): + ft.MarkLabel("S1") + x[...] += 1 + with ft.If(x[...] < 0): + ft.MarkLabel("S2") + x[...] += 1 + with ft.If(x[...] < 0): + ft.MarkLabel("S3") + x[...] += 1 + with ft.If(x[...] < 0): + ft.MarkLabel("S4") + x[...] += 1 + ast = ft.pop_ast(verbose=True) + + results = ft.find_all_stmt(ast, "<:(!S2<:)*S4") + results_by_label = ft.find_all_stmt(ast, "S2|S3") + assert sorted_ids(results) == sorted_ids(results_by_label) + + def test_select_direct_after(): with ft.VarDef("x", (), "int32", "inout", "cpu") as x: ft.MarkLabel("I1") @@ -239,6 +260,27 @@ def test_select_after(): assert sorted_ids(results) == sorted_ids(results_by_label) +def test_select_after_with_middle(): + with ft.VarDef("x", (), "int32", "inout", "cpu") as x: + with ft.If(x[...] < 0): + ft.MarkLabel("S1") + x[...] += 1 + with ft.If(x[...] < 0): + ft.MarkLabel("S2") + x[...] += 1 + with ft.If(x[...] < 0): + ft.MarkLabel("S3") + x[...] += 1 + with ft.If(x[...] < 0): + ft.MarkLabel("S4") + x[...] += 1 + ast = ft.pop_ast(verbose=True) + + results = ft.find_all_stmt(ast, ":>(!S3:>)*S1") + results_by_label = ft.find_all_stmt(ast, "S2|S3") + assert sorted_ids(results) == sorted_ids(results_by_label) + + def test_select_before_a_scope(): with ft.VarDef("x", (), "int32", "inout", "cpu") as x: ft.MarkLabel("I1") From 2e285e97d0a44da74b775970201fb9931c504fe7 Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Wed, 28 Feb 2024 18:07:21 +0800 Subject: [PATCH 3/4] Exclude integer expressions earlier from autograd (#599) --- include/autograd/propagate_defs_need_grad.h | 2 ++ src/autograd/derivative.cc | 3 +++ src/autograd/propagate_defs_need_grad.cc | 17 ++++++++++++++--- test/21.autograd/test_grad.py | 18 ++++++++++++++++++ 4 files changed, 37 insertions(+), 3 deletions(-) diff --git a/include/autograd/propagate_defs_need_grad.h b/include/autograd/propagate_defs_need_grad.h index 4e8164b6c..56a980c0d 100644 --- a/include/autograd/propagate_defs_need_grad.h +++ b/include/autograd/propagate_defs_need_grad.h @@ -37,6 +37,7 @@ class PropagateRequires : public SymbolTable { protected: using BaseClass::visit; + void visitExpr(const Expr &e) override; void visit(const Load &op) override; void visit(const Store &op) override; void visit(const ReduceTo &op) override; @@ -71,6 +72,7 @@ class PropagateProvides : public SymbolTable { protected: using BaseClass::visit; + void visitExpr(const Expr &e) override; void visit(const Load &op) override; void visit(const Store &op) override; void visit(const ReduceTo &op) override; diff --git a/src/autograd/derivative.cc b/src/autograd/derivative.cc index d4220ffba..085aaee33 100644 --- a/src/autograd/derivative.cc +++ b/src/autograd/derivative.cc @@ -116,6 +116,9 @@ void Derivative::setPartial(const Expr &expr, const Expr &partial) { } void Derivative::visitExpr(const Expr &expr) { + if (!isFloat(expr->dtype())) { + return; + } if (!rootExpr_.isValid()) { rootExpr_ = StmtOrExprID{expr, expr->parentStmt()}; setPartial(expr, makeIntConst(1)); diff --git a/src/autograd/propagate_defs_need_grad.cc b/src/autograd/propagate_defs_need_grad.cc index b39052b70..b01775757 100644 --- a/src/autograd/propagate_defs_need_grad.cc +++ b/src/autograd/propagate_defs_need_grad.cc @@ -3,9 +3,14 @@ namespace freetensor { +void PropagateRequires::visitExpr(const Expr &e) { + if (isFloat(e->dtype())) { + BaseClass::visitExpr(e); + } +} + void PropagateRequires::visit(const Load &op) { - if (isFloat(op->dtype()) && curTarget_.isValid() && - affectedDefs_.count(def(op->var_)->id())) { + if (curTarget_.isValid() && affectedDefs_.count(def(op->var_)->id())) { affectedDefs_.insert(curTarget_); // No need to recurse deeper } @@ -71,8 +76,14 @@ std::unordered_set PropagateRequires::propagateUntilConverge( return propagator.affectedDefs(); } +void PropagateProvides::visitExpr(const Expr &e) { + if (isFloat(e->dtype())) { + BaseClass::visitExpr(e); + } +} + void PropagateProvides::visit(const Load &op) { - if (isFloat(op->dtype()) && curTarget_.isValid() && + if (curTarget_.isValid() && buffer(op->var_)->atype() == AccessType::Cache) { affectedDefs_.insert(def(op->var_)->id()); // No need to recurse deeper diff --git a/test/21.autograd/test_grad.py b/test/21.autograd/test_grad.py index f1dc24fa6..c1a7338a1 100644 --- a/test/21.autograd/test_grad.py +++ b/test/21.autograd/test_grad.py @@ -967,6 +967,24 @@ def f(a, b): return ft.libop.matmul(a, b) +def test_no_grad_integer(): + # Should not report error on intrinsic because it is integer + with ft.VarDef([("x", (), "float32", "input", "cpu"), + ("y", (), "float32", "output", "cpu")]) as (x, y): + y[...] = ft.intrinsic("(%)", ft.cast(x[...], "int32"), ret_type="int32") + ast = ft.pop_ast(verbose=True) + _, ast, _, _, _ = ft.grad_body(ast, ["x"], ["y"], + set(), + reset_provided_grad=False) + print(ast) + + with ft.VarDef("d_x", (), "float32", "output", "cpu") as d_x: + d_x[...] = 0 + std = ft.pop_ast() + + assert std.match(ast) + + def test_error_input_not_found(): @ft.transform From 87bfe21616cb822f2d01eea5d6a70b4a3d1c6047 Mon Sep 17 00:00:00 2001 From: Shizhi Tang Date: Thu, 14 Mar 2024 22:12:16 +0800 Subject: [PATCH 4/4] [CI] Adjust how the tests are run in Slurm (#600) * Adjust how the tests are run in CI * Fix schedule/auto_parallelize and adapt its test to A100 * Remove pybind11-stubgen from CI environment because it's conflict against pyproject unless --no-build-isolation --- .github/workflows/main.yml | 34 +++++++++++++++---- pytest.ini | 3 ++ requirements.txt | 1 - src/schedule/auto_parallelize.cc | 2 +- .../test_auto_fission_fuse.py | 3 ++ .../31.auto_schedule/test_auto_parallelize.py | 10 +++--- 6 files changed, 40 insertions(+), 13 deletions(-) create mode 100644 pytest.ini diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7df9fb269..0d9d51553 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -37,9 +37,7 @@ jobs: source /opt/spack/share/spack/setup-env.sh spack load python~debug@3.9.12%gcc@10.2.1 cuda@11.8.0/ehz25ml cudnn@8.7.0.84-11.8/uopt2y4 intel-mkl@2020.4.304 java@11 gcc@11.3.0 source ci-script/prepare-python-environment.sh - # Set OMP_PROC_BIND to make OpenMP happy for 30.schedule/test_auto_fission_fuse.py::test_tune_fission - # Setting OMP_NUM_THREADS=256 seems to work around the conflict of PyTorch - OMP_NUM_THREADS=256 OMP_PROC_BIND=true srun --exclusive -N 1 -p ja --gres=gpu:v100:1 pytest --color=yes test + srun -N 1 -c 64 -p octave --gres=gpu:a100:1 pytest --color=yes -m "not performance_sensitive" test build-and-test-gcc-minimal-run_in_tree: runs-on: self-hosted if: github.event.pull_request.draft == false @@ -65,8 +63,7 @@ jobs: source /opt/spack/share/spack/setup-env.sh spack load python~debug@3.9.12%gcc@10.2.1 java@11 gcc@12.1.0 source ci-script/prepare-python-environment.sh - # Set OMP_PROC_BIND to make OpenMP happy for 30.schedule/test_auto_fission_fuse.py::test_tune_fission - OMP_PROC_BIND=true PYTHONPATH=build:python:$PYTHONPATH srun --exclusive -N 1 -p ja pytest --color=yes test + PYTHONPATH=build:python:$PYTHONPATH srun -N 1 -c 64 -p ja pytest --color=yes -m "not performance_sensitive" test build-and-test-clang-run-in-tree: runs-on: self-hosted if: github.event.pull_request.draft == false @@ -92,5 +89,30 @@ jobs: source /opt/spack/share/spack/setup-env.sh spack load python~debug@3.9.12%gcc@10.2.1 java@11 llvm@16%gcc@12 source ci-script/prepare-python-environment.sh + PYTHONPATH=build:python:$PYTHONPATH srun -N 1 -c 64 -p ja pytest --color=yes -m "not performance_sensitive" test + build-and-test-gcc-cuda-mkl-exclusively: + runs-on: self-hosted + if: github.event.pull_request.draft == false + steps: + - uses: roastduck/checkout@main + with: + ssh-key: ${{ secrets.CI }} + submodules: true + fetch-depth: 0 + - name: Build ffi module in Release + run: | + git submodule foreach --recursive git clean -ffdx + git submodule foreach --recursive git reset --hard + source /opt/spack/share/spack/setup-env.sh + spack load python~debug@3.9.12%gcc@10.2.1 cuda@11.8.0/ehz25ml cudnn@8.7.0.84-11.8/uopt2y4 intel-mkl@2020.4.304 java@11 gcc@11.3.0 + source ci-script/prepare-python-environment.sh + # -C requires a new enough pip + pip3 install --upgrade pip + pip3 install . -C--local=with-cuda.toml -C--local=ci-script/with-spack-mkl.toml + - name: Run PyTest + run: | + source /opt/spack/share/spack/setup-env.sh + spack load python~debug@3.9.12%gcc@10.2.1 cuda@11.8.0/ehz25ml cudnn@8.7.0.84-11.8/uopt2y4 intel-mkl@2020.4.304 java@11 gcc@11.3.0 + source ci-script/prepare-python-environment.sh # Set OMP_PROC_BIND to make OpenMP happy for 30.schedule/test_auto_fission_fuse.py::test_tune_fission - OMP_PROC_BIND=true PYTHONPATH=build:python:$PYTHONPATH srun --exclusive -N 1 -p ja pytest --color=yes test + OMP_PROC_BIND=true srun --exclusive=user -N 1 -c 256 -p ja --gres=gpu:v100:1 pytest --color=yes -m "performance_sensitive" test diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..bc24e9761 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + performance_sensitive: These tests should be run on a exlusively dedicated node diff --git a/requirements.txt b/requirements.txt index fb4d0b7c5..c079cebcb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,6 @@ numpy==1.24.3 packaging==23.1 pluggy==1.0.0 py_build_cmake==0.1.8 -pybind11-stubgen==0.13.0 Pygments==2.15.1 pymdown-extensions==10.0 pytest==7.3.1 diff --git a/src/schedule/auto_parallelize.cc b/src/schedule/auto_parallelize.cc index 32597601b..b56a41762 100644 --- a/src/schedule/auto_parallelize.cc +++ b/src/schedule/auto_parallelize.cc @@ -570,7 +570,7 @@ void Schedule::autoParallelize(const Ref &target) { }); // III b. Reduction - if (!needParRed) { + if (localParaAll.size() == localParaNoRed.size() || !needParRed) { commitTransaction(); } else { abortTransaction(); diff --git a/test/31.auto_schedule/test_auto_fission_fuse.py b/test/31.auto_schedule/test_auto_fission_fuse.py index 6c074b0c9..270cceb54 100644 --- a/test/31.auto_schedule/test_auto_fission_fuse.py +++ b/test/31.auto_schedule/test_auto_fission_fuse.py @@ -90,6 +90,7 @@ def test_stmt_in_between_2(): assert logs == ["swap(L2, S1)", "fuse(L1, L2, true)"] +@pytest.mark.performance_sensitive def test_tune_fuse(): # We may fuse these loops. But fusing them will make it impossible to parallelize. # After tuning, we will end up in not fusing them @@ -133,6 +134,7 @@ def test_tune_fuse(): assert "fuse" not in log +@pytest.mark.performance_sensitive def test_tune_fission(): # The reverse schedule of `test_tune_fuse` @@ -178,6 +180,7 @@ def test_tune_fission(): assert "fission" in ", ".join(logs) +@pytest.mark.performance_sensitive @pytest.mark.skipif(not ft.with_cuda(), reason="requires CUDA") def test_tune_with_cond(): # Fuse loops that can parallelize. Don't fuse loops that can't diff --git a/test/31.auto_schedule/test_auto_parallelize.py b/test/31.auto_schedule/test_auto_parallelize.py index 9d67221f1..6dd268f60 100644 --- a/test/31.auto_schedule/test_auto_parallelize.py +++ b/test/31.auto_schedule/test_auto_parallelize.py @@ -51,10 +51,10 @@ def test_3_levels(): @pytest.mark.skipif(not ft.with_cuda(), reason="requires CUDA") def test_gpu_basic_static_small(): - with ft.VarDef([("x", (10, 10, 2), "int32", "input", "cpu"), - ("y", (10, 10, 2), "int32", "output", "cpu")]) as (x, y): - with ft.For("i", 0, 10, label="Li") as i: - with ft.For("j", 0, 10, label="Lj") as j: + with ft.VarDef([("x", (20, 20, 2), "int32", "input", "cpu"), + ("y", (20, 20, 2), "int32", "output", "cpu")]) as (x, y): + with ft.For("i", 0, 20, label="Li") as i: + with ft.For("j", 0, 20, label="Lj") as j: y[i, j, 0] = x[i, j, 0] + 1 device = ft.GPU() @@ -68,7 +68,7 @@ def test_gpu_basic_static_small(): logs = list(map(str, s.logs())) print(logs) assert fnmatch_list(logs, [ - f'split(Lj, -1, {num_sm // 10}, 0)', 'merge(Li, $split.0{Lj})', + f'split(Lj, -1, {num_sm // 20}, 0)', 'merge(Li, $split.0{Lj})', 'parallelize($merge{Li, $split.0{Lj}}, blockIdx.x, *)', 'parallelize($split.1{Lj}, threadIdx.y, *)' ])