Merge branch 'master' of https://github.com/roastduck/FreeTensor into cutlass
roastduck committed Mar 15, 2024
2 parents 0f4d8fd + 87bfe21 commit 54c0ab3
Showing 18 changed files with 195 additions and 24 deletions.
34 changes: 28 additions & 6 deletions .github/workflows/main.yml
@@ -37,9 +37,7 @@ jobs:
source /opt/spack/share/spack/setup-env.sh
spack load [email protected]%[email protected] [email protected]/ehz25ml [email protected]/uopt2y4 [email protected] java@11 [email protected]
source ci-script/prepare-python-environment.sh
# Set OMP_PROC_BIND to make OpenMP happy for 30.schedule/test_auto_fission_fuse.py::test_tune_fission
# Setting OMP_NUM_THREADS=256 seems to work around the conflict of PyTorch
OMP_NUM_THREADS=256 OMP_PROC_BIND=true srun --exclusive -N 1 -p ja --gres=gpu:v100:1 pytest --color=yes test
srun -N 1 -c 64 -p octave --gres=gpu:a100:1 pytest --color=yes -m "not performance_sensitive" test
build-and-test-gcc-minimal-run_in_tree:
runs-on: self-hosted
if: github.event.pull_request.draft == false
@@ -65,8 +63,7 @@ jobs:
source /opt/spack/share/spack/setup-env.sh
spack load [email protected]%[email protected] java@11 [email protected]
source ci-script/prepare-python-environment.sh
# Set OMP_PROC_BIND to make OpenMP happy for 30.schedule/test_auto_fission_fuse.py::test_tune_fission
OMP_PROC_BIND=true PYTHONPATH=build:python:$PYTHONPATH srun --exclusive -N 1 -p ja pytest --color=yes test
PYTHONPATH=build:python:$PYTHONPATH srun -N 1 -c 64 -p ja pytest --color=yes -m "not performance_sensitive" test
build-and-test-clang-run-in-tree:
runs-on: self-hosted
if: github.event.pull_request.draft == false
@@ -92,5 +89,30 @@ jobs:
source /opt/spack/share/spack/setup-env.sh
spack load [email protected]%[email protected] java@11 llvm@16%gcc@12
source ci-script/prepare-python-environment.sh
PYTHONPATH=build:python:$PYTHONPATH srun -N 1 -c 64 -p ja pytest --color=yes -m "not performance_sensitive" test
build-and-test-gcc-cuda-mkl-exclusively:
runs-on: self-hosted
if: github.event.pull_request.draft == false
steps:
- uses: roastduck/checkout@main
with:
ssh-key: ${{ secrets.CI }}
submodules: true
fetch-depth: 0
- name: Build ffi module in Release
run: |
git submodule foreach --recursive git clean -ffdx
git submodule foreach --recursive git reset --hard
source /opt/spack/share/spack/setup-env.sh
spack load [email protected]%[email protected] [email protected]/ehz25ml [email protected]/uopt2y4 [email protected] java@11 [email protected]
source ci-script/prepare-python-environment.sh
# -C requires a new enough pip
pip3 install --upgrade pip
pip3 install . -C--local=with-cuda.toml -C--local=ci-script/with-spack-mkl.toml
- name: Run PyTest
run: |
source /opt/spack/share/spack/setup-env.sh
spack load [email protected]%[email protected] [email protected]/ehz25ml [email protected]/uopt2y4 [email protected] java@11 [email protected]
source ci-script/prepare-python-environment.sh
# Set OMP_PROC_BIND to make OpenMP happy for 30.schedule/test_auto_fission_fuse.py::test_tune_fission
OMP_PROC_BIND=true PYTHONPATH=build:python:$PYTHONPATH srun --exclusive -N 1 -p ja pytest --color=yes test
OMP_PROC_BIND=true srun --exclusive=user -N 1 -c 256 -p ja --gres=gpu:v100:1 pytest --color=yes -m "performance_sensitive" test
6 changes: 3 additions & 3 deletions docs/guide/schedules.md
@@ -79,9 +79,9 @@ In the example above, we label a loop `Li` and apply schedules on it. It is stra
2. (For debugging only) A numerical ID is also a selector. E.g., `#31`.
3. A node type surrounded in angle brackets (`<>`) is also a selector. E.g., `<For>` matches for-loop statements.
4. A selector can be extended to match a new statement produced by a previous schedule. E.g., `$split.0{Li}` matches the outer loop split from the loop `Li`. This is useful when return values from schedules are hard to track. Please refer the [API document](../../api/#freetensor.core.schedule.Schedule) for detailed grammar.
5. Selectors can be combined to match a statement by nesting order. `A<-B` matches a statement `A` DIRECTLY NESTED IN another statement `B`. `A<<-B` matches a statement DIRECTLY or INDIRECTLY nested in another statement `B`. `A<-(B<-)*C` matches a statement `A` DIRECTLY or INDIRECTLY nested in another statement `C` with intermedaite nesting statements satisfying the condition in `B`. `B->A` matches a statement `B` directly OUT OF another statement `A`. `B->>A` and `C->(B->)*A` are alike. (`A`, `B`, `C` can be nested selectors.) Use `<-|` for the root node, and `->|` for a leaf node.
6. Selectors can be combined to match a statement by DFS order. `A<:B` matches a statement `A` DIRECTLY BEFORE another statement `B`. `A<<:B` matches a statement `A` DIRECTLY or INDIRECTLY before another statement `B`. `B:>A` matches a statment `B` directly AFTER another statement `A`. `B:>>A` matches a statement `B` directly or indirectly after another statement `A`. A statement's descents (other statements nested in) or ancestors (other statements nesting it) are neither considered before or after the statement in the DFS order. "Directly" means there is no other statement between the comparing statements, so a statment can "directly before" both another statement and its parent.
7. Selectors can be combined to match a statement in a function call. `A<~B` matches a statement `A` DIRECTLY called by a call site `B`. `A<<~B` matches a statement DIRECTLY or INDIRECTLY called by a call site `B`. `A<~(B<~)*C` matches a statement `A` DIRECTLY or INDIRECTLY called by a call site `C` with intermediate call sites satisfying the condition in `B`. (`A`, `B`, `C` can be nested selectors.) Use `<~|` for the root function.
5. Selectors can be combined to match a statement by nesting order. `A<-B` matches a statement `A` DIRECTLY NESTED IN another statement `B`. `A<<-B` matches a statement `A` DIRECTLY or INDIRECTLY nested in another statement `B`. `A<-(B<-)*C` matches a statement `A` DIRECTLY or INDIRECTLY nested in another statement `C` with intermedaite nesting statements satisfying the condition in `B`. `B->A` matches a statement `B` directly OUT OF another statement `A`. `B->>A` and `C->(B->)*A` are alike. (`A`, `B`, `C` can be nested selectors.) Use `<-|` for the root node, and `->|` for a leaf node.
6. Selectors can be combined to match a statement by program order. `A<:B` matches a statement `A` DIRECTLY BEFORE another statement `B`. `A<<:B` matches a statement `A` DIRECTLY or INDIRECTLY before another statement `B`. `A<:(B<:)*C` matches a statement `A` DIRECTLY or INDIRECTLY before another statement `C` with intermediate statements satisfying the condition in `B`. `B:>A` matches a statment `B` directly AFTER another statement `A`. `B:>>A` and `C:>(B:>)*A` are alike. A statement's descents (other statements nested in) or ancestors (other statements nesting it) are neither considered before or after the statement in the program order. "Directly" means there is no other statement between the comparing statements, so a statment can "directly before" both another statement and its parent.
7. Selectors can be combined to match a statement in a function call. `A<~B` matches a statement `A` DIRECTLY called by a call site `B`. `A<<~B` matches a statement `A` DIRECTLY or INDIRECTLY called by a call site `B`. `A<~(B<~)*C` matches a statement `A` DIRECTLY or INDIRECTLY called by a call site `C` with intermediate call sites satisfying the condition in `B`. (`A`, `B`, `C` can be nested selectors.) Use `<~|` for the root function.
8. All the arrow-like selectors (`<-`, `<~`, `<:`, etc.) are right-associated. For example, `A<-B<-C` matches `A` nested in `B`, where `B` is nested in `C`.
9. All the arrow-like selectors can be used with the first argument omitted. For example, `<-B` matches ALL statements nested in `B`.
10. Selectors can be combined with logical "and" (`&`), "or" (`|`), "not" (`!`) and parentheses. E.g., `Li|Lj` matches a statement labeled `Li` OR `Lj`. `Li&Lj` matches a statement labeled `Li&Lj`.
6 changes: 6 additions & 0 deletions grammar/selector_parser.g
@@ -158,9 +158,15 @@ selectorImplicitAnd returns[Ref<Selector> s]
| BeforeArrow following=selectorFactor {
$s = Ref<BeforeSelector>::make($following.s);
}
| DirectBeforeArrow LeftParen middle=selectorFactor DirectBeforeArrow RightParen Star following=selectorFactor {
$s = Ref<BeforeSelector>::make($following.s, $middle.s);
}
| DirectAfterArrow leading=selectorFactor {
$s = Ref<DirectAfterSelector>::make($leading.s);
}
| AfterArrow leading=selectorFactor {
$s = Ref<AfterSelector>::make($leading.s);
}
| DirectAfterArrow LeftParen middle=selectorFactor DirectAfterArrow RightParen Star leading=selectorFactor {
$s = Ref<AfterSelector>::make($leading.s, $middle.s);
};
2 changes: 2 additions & 0 deletions include/autograd/propagate_defs_need_grad.h
@@ -37,6 +37,7 @@ class PropagateRequires : public SymbolTable<Visitor> {

protected:
using BaseClass::visit;
void visitExpr(const Expr &e) override;
void visit(const Load &op) override;
void visit(const Store &op) override;
void visit(const ReduceTo &op) override;
@@ -71,6 +72,7 @@ class PropagateProvides : public SymbolTable<Visitor> {

protected:
using BaseClass::visit;
void visitExpr(const Expr &e) override;
void visit(const Load &op) override;
void visit(const Store &op) override;
void visit(const ReduceTo &op) override;
12 changes: 8 additions & 4 deletions include/selector.h
@@ -124,13 +124,15 @@ class DirectBeforeSelector : public Selector {
};

class BeforeSelector : public Selector {
Ref<Selector> following_;
Ref<Selector> following_, middle_;

protected:
bool matchImpl(const Stmt &stmt) override;

public:
BeforeSelector(const Ref<Selector> &following) : following_(following) {}
BeforeSelector(const Ref<Selector> &following,
const Ref<Selector> &middle = nullptr)
: following_(following), middle_(middle) {}
};

class DirectAfterSelector : public Selector {
@@ -144,13 +146,15 @@ class DirectAfterSelector : public Selector {
};

class AfterSelector : public Selector {
Ref<Selector> leading_;
Ref<Selector> leading_, middle_;

protected:
bool matchImpl(const Stmt &stmt) override;

public:
AfterSelector(const Ref<Selector> &leading) : leading_(leading) {}
AfterSelector(const Ref<Selector> &leading,
const Ref<Selector> &middle = nullptr)
: leading_(leading), middle_(middle) {}
};

class RootNodeSelector : public Selector {
3 changes: 3 additions & 0 deletions pytest.ini
@@ -0,0 +1,3 @@
[pytest]
markers =
performance_sensitive: These tests should be run on a exlusively dedicated node
10 changes: 10 additions & 0 deletions python/freetensor/core/frontend.py
@@ -3,6 +3,7 @@
'''

import sys
import builtins
import numpy as np
import inspect
import traceback
@@ -552,3 +553,12 @@ def foreach(self, name, body: Callable[[Any], None]) -> None:


static_range = range


def reversed(rng):
    if isinstance(rng, dynamic_range):
        return dynamic_range(
            rng.start + rng.step * ((rng.stop - rng.start - 1) // rng.step),
            rng.start - rng.step, -rng.step)
    else:
        return builtins.reversed(rng)
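
The index arithmetic in the new `reversed` overload can be sanity-checked on plain Python ranges. The sketch below is illustrative only: `reversed_bounds` and `check` are hypothetical helpers, not part of FreeTensor. They apply the same formula to built-in `range` objects and compare the result with `builtins.reversed`. Only positive steps are exercised here; negative steps are deliberately left out of the check.

```python
import builtins


def reversed_bounds(start, stop, step):
    """Return (start, stop, step) of the reversed range.

    Hypothetical helper mirroring the arithmetic of the `reversed`
    overload in this commit; checked here for positive steps only.
    """
    # Last value actually produced by range(start, stop, step)
    last = start + step * ((stop - start - 1) // step)
    # Walk back down to (and including) `start`
    return last, start - step, -step


def check(start, stop, step):
    """Compare the formula against Python's built-in reversed()."""
    expected = list(builtins.reversed(range(start, stop, step)))
    actual = list(range(*reversed_bounds(start, stop, step)))
    return expected == actual


if __name__ == "__main__":
    for args in [(0, 10, 1), (0, 10, 3), (2, 11, 4), (5, 5, 1), (0, 1, 3)]:
        assert check(*args), args
    print(list(range(*reversed_bounds(0, 10, 3))))
```

For `range(0, 10, 3)` the last produced value is 9, so the reversed bounds are `(9, -3, -3)`, which also covers the empty-range case because the computed start then falls at or below the new stop.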
6 changes: 5 additions & 1 deletion python/freetensor/core/transform.py
@@ -10,6 +10,7 @@
from .expr import UndeclaredParam
from .stmt import VarRef
from .func import Func
from . import frontend
from .frontend import lang_overload, staged_callable, LifetimeScope, dynamic_range
from .context import pop_ast_and_user_grads, ctx_stack
from .staging import StagingError, TransformError
@@ -19,7 +20,10 @@


def _prepare_extra_locals(default_dynamic_range):
extra_locals = {'__ft__': sys.modules['freetensor']}
extra_locals = {
'__ft__': sys.modules['freetensor'],
'reversed': frontend.reversed,
}
if default_dynamic_range:
extra_locals['range'] = dynamic_range
return extra_locals
1 change: 0 additions & 1 deletion requirements.txt
@@ -25,7 +25,6 @@ numpy==1.24.3
packaging==23.1
pluggy==1.0.0
py_build_cmake==0.1.8
pybind11-stubgen==0.13.0
Pygments==2.15.1
pymdown-extensions==10.0
pytest==7.3.1
3 changes: 3 additions & 0 deletions src/autograd/derivative.cc
@@ -116,6 +116,9 @@ void Derivative::setPartial(const Expr &expr, const Expr &partial) {
}

void Derivative::visitExpr(const Expr &expr) {
if (!isFloat(expr->dtype())) {
return;
}
if (!rootExpr_.isValid()) {
rootExpr_ = StmtOrExprID{expr, expr->parentStmt()};
setPartial(expr, makeIntConst(1));
17 changes: 14 additions & 3 deletions src/autograd/propagate_defs_need_grad.cc
@@ -3,9 +3,14 @@

namespace freetensor {

void PropagateRequires::visitExpr(const Expr &e) {
if (isFloat(e->dtype())) {
BaseClass::visitExpr(e);
}
}

void PropagateRequires::visit(const Load &op) {
if (isFloat(op->dtype()) && curTarget_.isValid() &&
affectedDefs_.count(def(op->var_)->id())) {
if (curTarget_.isValid() && affectedDefs_.count(def(op->var_)->id())) {
affectedDefs_.insert(curTarget_);
// No need to recurse deeper
}
@@ -71,8 +76,14 @@ std::unordered_set<ID> PropagateRequires::propagateUntilConverge(
return propagator.affectedDefs();
}

void PropagateProvides::visitExpr(const Expr &e) {
if (isFloat(e->dtype())) {
BaseClass::visitExpr(e);
}
}

void PropagateProvides::visit(const Load &op) {
if (isFloat(op->dtype()) && curTarget_.isValid() &&
if (curTarget_.isValid() &&
buffer(op->var_)->atype() == AccessType::Cache) {
affectedDefs_.insert(def(op->var_)->id());
// No need to recurse deeper
2 changes: 1 addition & 1 deletion src/schedule/auto_parallelize.cc
@@ -570,7 +570,7 @@ void Schedule::autoParallelize(const Ref<Target> &target) {
});

// III b. Reduction
if (!needParRed) {
if (localParaAll.size() == localParaNoRed.size() || !needParRed) {
commitTransaction();
} else {
abortTransaction();
6 changes: 6 additions & 0 deletions src/selector.cc
@@ -89,6 +89,9 @@ bool BeforeSelector::matchImpl(const Stmt &stmt) {
following_->match(next)) {
return true;
}
if (middle_.isValid() && !middle_->match(next)) {
return false;
}
}
return false;
}
@@ -112,6 +115,9 @@ bool AfterSelector::matchImpl(const Stmt &stmt) {
leading_->match(prev)) {
return true;
}
if (middle_.isValid() && !middle_->match(prev)) {
return false;
}
}
return false;
}
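
The early-exit rule added to `BeforeSelector::matchImpl` and `AfterSelector::matchImpl` can be modeled on a flat list of statements. This is a simplified sketch under stated assumptions: `before_with_middle` is a hypothetical function, statements are reduced to label strings in program order with no nesting, and the real loop condition (not visible in this hunk) is ignored.

```python
def before_with_middle(labels, idx, following, middle=None):
    """Simplified model of the selector `A<:(B<:)*C` on a flat list.

    labels:    statement labels in program order (assumed flat, no nesting)
    idx:       index of the candidate statement A
    following: predicate standing in for selector C
    middle:    optional predicate standing in for selector B
    """
    for nxt in labels[idx + 1:]:
        if following(nxt):
            return True  # reached a statement matching C
        if middle is not None and not middle(nxt):
            return False  # an intermediate statement violates B
    return False


labels = ["S1", "S2", "S3", "S4"]
# Model of "<:(!S2<:)*S4": before S4, with no S2 in between
matched = [
    lbl for i, lbl in enumerate(labels)
    if before_with_middle(labels, i, lambda s: s == "S4", lambda s: s != "S2")
]
print(matched)  # ['S2', 'S3']
```

`S1` is rejected because `S2` lies between it and `S4`. With `middle=None` the function degrades to the plain `A<<:C` behavior, where `S1` also matches.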
42 changes: 42 additions & 0 deletions test/10.analyze/test_find_stmt.py
@@ -193,6 +193,27 @@ def test_select_before():
assert sorted_ids(results) == sorted_ids(results_by_label)


def test_select_before_with_middle():
with ft.VarDef("x", (), "int32", "inout", "cpu") as x:
with ft.If(x[...] < 0):
ft.MarkLabel("S1")
x[...] += 1
with ft.If(x[...] < 0):
ft.MarkLabel("S2")
x[...] += 1
with ft.If(x[...] < 0):
ft.MarkLabel("S3")
x[...] += 1
with ft.If(x[...] < 0):
ft.MarkLabel("S4")
x[...] += 1
ast = ft.pop_ast(verbose=True)

results = ft.find_all_stmt(ast, "<ReduceTo><:(!S2<:)*S4")
results_by_label = ft.find_all_stmt(ast, "S2|S3")
assert sorted_ids(results) == sorted_ids(results_by_label)


def test_select_direct_after():
with ft.VarDef("x", (), "int32", "inout", "cpu") as x:
ft.MarkLabel("I1")
@@ -239,6 +260,27 @@ def test_select_after():
assert sorted_ids(results) == sorted_ids(results_by_label)


def test_select_after_with_middle():
with ft.VarDef("x", (), "int32", "inout", "cpu") as x:
with ft.If(x[...] < 0):
ft.MarkLabel("S1")
x[...] += 1
with ft.If(x[...] < 0):
ft.MarkLabel("S2")
x[...] += 1
with ft.If(x[...] < 0):
ft.MarkLabel("S3")
x[...] += 1
with ft.If(x[...] < 0):
ft.MarkLabel("S4")
x[...] += 1
ast = ft.pop_ast(verbose=True)

results = ft.find_all_stmt(ast, "<ReduceTo>:>(!S3:>)*S1")
results_by_label = ft.find_all_stmt(ast, "S2|S3")
assert sorted_ids(results) == sorted_ids(results_by_label)


def test_select_before_a_scope():
with ft.VarDef("x", (), "int32", "inout", "cpu") as x:
ft.MarkLabel("I1")
18 changes: 18 additions & 0 deletions test/21.autograd/test_grad.py
@@ -967,6 +967,24 @@ def f(a, b):
return ft.libop.matmul(a, b)


def test_no_grad_integer():
# Should not report error on intrinsic because it is integer
with ft.VarDef([("x", (), "float32", "input", "cpu"),
("y", (), "float32", "output", "cpu")]) as (x, y):
y[...] = ft.intrinsic("(%)", ft.cast(x[...], "int32"), ret_type="int32")
ast = ft.pop_ast(verbose=True)
_, ast, _, _, _ = ft.grad_body(ast, ["x"], ["y"],
set(),
reset_provided_grad=False)
print(ast)

with ft.VarDef("d_x", (), "float32", "output", "cpu") as d_x:
d_x[...] = 0
std = ft.pop_ast()

assert std.match(ast)


def test_error_input_not_found():

@ft.transform
3 changes: 3 additions & 0 deletions test/31.auto_schedule/test_auto_fission_fuse.py
@@ -90,6 +90,7 @@ def test_stmt_in_between_2():
assert logs == ["swap(L2, S1)", "fuse(L1, L2, true)"]


@pytest.mark.performance_sensitive
def test_tune_fuse():
# We may fuse these loops. But fusing them will make it impossible to parallelize.
# After tuning, we will end up in not fusing them
@@ -133,6 +134,7 @@ def test_tune_fuse():
assert "fuse" not in log


@pytest.mark.performance_sensitive
def test_tune_fission():
# The reverse schedule of `test_tune_fuse`

@@ -178,6 +180,7 @@ def test_tune_fission():
assert "fission" in ", ".join(logs)


@pytest.mark.performance_sensitive
@pytest.mark.skipif(not ft.with_cuda(), reason="requires CUDA")
def test_tune_with_cond():
# Fuse loops that can parallelize. Don't fuse loops that can't
10 changes: 5 additions & 5 deletions test/31.auto_schedule/test_auto_parallelize.py
@@ -51,10 +51,10 @@ def test_3_levels():

@pytest.mark.skipif(not ft.with_cuda(), reason="requires CUDA")
def test_gpu_basic_static_small():
with ft.VarDef([("x", (10, 10, 2), "int32", "input", "cpu"),
("y", (10, 10, 2), "int32", "output", "cpu")]) as (x, y):
with ft.For("i", 0, 10, label="Li") as i:
with ft.For("j", 0, 10, label="Lj") as j:
with ft.VarDef([("x", (20, 20, 2), "int32", "input", "cpu"),
("y", (20, 20, 2), "int32", "output", "cpu")]) as (x, y):
with ft.For("i", 0, 20, label="Li") as i:
with ft.For("j", 0, 20, label="Lj") as j:
y[i, j, 0] = x[i, j, 0] + 1

device = ft.GPU()
@@ -68,7 +68,7 @@ def test_gpu_basic_static_small():
logs = list(map(str, s.logs()))
print(logs)
assert fnmatch_list(logs, [
f'split(Lj, -1, {num_sm // 10}, 0)', 'merge(Li, $split.0{Lj})',
f'split(Lj, -1, {num_sm // 20}, 0)', 'merge(Li, $split.0{Lj})',
'parallelize($merge{Li, $split.0{Lj}}, blockIdx.x, *)',
'parallelize($split.1{Lj}, threadIdx.y, *)'
])