From d97bd4ef7f1e86b707259ca6a792cb300b406228 Mon Sep 17 00:00:00 2001
From: Shizhi Tang
Date: Wed, 28 Feb 2024 15:47:29 +0800
Subject: [PATCH] Exclude integer expressions earlier from autograd

---
 include/autograd/propagate_defs_need_grad.h |  2 ++
 src/autograd/derivative.cc                  |  3 +++
 src/autograd/propagate_defs_need_grad.cc    | 17 ++++++++++++++---
 test/21.autograd/test_grad.py               | 18 ++++++++++++++++++
 4 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/include/autograd/propagate_defs_need_grad.h b/include/autograd/propagate_defs_need_grad.h
index 4e8164b6c..56a980c0d 100644
--- a/include/autograd/propagate_defs_need_grad.h
+++ b/include/autograd/propagate_defs_need_grad.h
@@ -37,6 +37,7 @@ class PropagateRequires : public SymbolTable<Visitor> {
   protected:
     using BaseClass::visit;

+    void visitExpr(const Expr &e) override;
     void visit(const Load &op) override;
     void visit(const Store &op) override;
     void visit(const ReduceTo &op) override;
@@ -71,6 +72,7 @@ class PropagateProvides : public SymbolTable<Visitor> {
   protected:
     using BaseClass::visit;

+    void visitExpr(const Expr &e) override;
     void visit(const Load &op) override;
     void visit(const Store &op) override;
     void visit(const ReduceTo &op) override;
diff --git a/src/autograd/derivative.cc b/src/autograd/derivative.cc
index d4220ffba..085aaee33 100644
--- a/src/autograd/derivative.cc
+++ b/src/autograd/derivative.cc
@@ -116,6 +116,9 @@ void Derivative::setPartial(const Expr &expr, const Expr &partial) {
 }

 void Derivative::visitExpr(const Expr &expr) {
+    if (!isFloat(expr->dtype())) {
+        return;
+    }
     if (!rootExpr_.isValid()) {
         rootExpr_ = StmtOrExprID{expr, expr->parentStmt()};
         setPartial(expr, makeIntConst(1));
diff --git a/src/autograd/propagate_defs_need_grad.cc b/src/autograd/propagate_defs_need_grad.cc
index b39052b70..b01775757 100644
--- a/src/autograd/propagate_defs_need_grad.cc
+++ b/src/autograd/propagate_defs_need_grad.cc
@@ -3,9 +3,14 @@

 namespace freetensor {

+void PropagateRequires::visitExpr(const Expr &e) {
+    if (isFloat(e->dtype())) {
+        BaseClass::visitExpr(e);
+    }
+}
+
 void PropagateRequires::visit(const Load &op) {
-    if (isFloat(op->dtype()) && curTarget_.isValid() &&
-        affectedDefs_.count(def(op->var_)->id())) {
+    if (curTarget_.isValid() && affectedDefs_.count(def(op->var_)->id())) {
         affectedDefs_.insert(curTarget_);
         // No need to recurse deeper
     }
@@ -71,8 +76,14 @@ std::unordered_set<ID> PropagateRequires::propagateUntilConverge(
     return propagator.affectedDefs();
 }

+void PropagateProvides::visitExpr(const Expr &e) {
+    if (isFloat(e->dtype())) {
+        BaseClass::visitExpr(e);
+    }
+}
+
 void PropagateProvides::visit(const Load &op) {
-    if (isFloat(op->dtype()) && curTarget_.isValid() &&
+    if (curTarget_.isValid() &&
         buffer(op->var_)->atype() == AccessType::Cache) {
         affectedDefs_.insert(def(op->var_)->id());
         // No need to recurse deeper
diff --git a/test/21.autograd/test_grad.py b/test/21.autograd/test_grad.py
index f1dc24fa6..c1a7338a1 100644
--- a/test/21.autograd/test_grad.py
+++ b/test/21.autograd/test_grad.py
@@ -967,6 +967,24 @@ def f(a, b):
         return ft.libop.matmul(a, b)


+def test_no_grad_integer():
+    # Should not report error on intrinsic because it is integer
+    with ft.VarDef([("x", (), "float32", "input", "cpu"),
+                    ("y", (), "float32", "output", "cpu")]) as (x, y):
+        y[...] = ft.intrinsic("(%)", ft.cast(x[...], "int32"), ret_type="int32")
+    ast = ft.pop_ast(verbose=True)
+    _, ast, _, _, _ = ft.grad_body(ast, ["x"], ["y"],
+                                   set(),
+                                   reset_provided_grad=False)
+    print(ast)
+
+    with ft.VarDef("d_x", (), "float32", "output", "cpu") as d_x:
+        d_x[...] = 0
+    std = ft.pop_ast()
+
+    assert std.match(ast)
+
+
 def test_error_input_not_found():

     @ft.transform
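
Note (reviewer sketch, not part of the patch): the snippet below replays the
new test outside pytest to show the user-visible behavior. It uses only the
APIs the test itself exercises (ft.VarDef, ft.pop_ast, ft.grad_body); the
variable names fwd and bwd are hypothetical. Per the test's comment, an
integer intrinsic previously triggered an error during differentiation;
with this change, non-float expressions are skipped up front in both the
propagation passes and Derivative, so the backward body reduces to zeroing
the gradient.

    import freetensor as ft

    # Forward pass: y is float32, but it is produced through an integer
    # intrinsic, so nothing differentiable connects y back to x.
    with ft.VarDef([("x", (), "float32", "input", "cpu"),
                    ("y", (), "float32", "output", "cpu")]) as (x, y):
        y[...] = ft.intrinsic("(%)", ft.cast(x[...], "int32"),
                              ret_type="int32")
    fwd = ft.pop_ast()

    # grad_body no longer reports an error on the integer intrinsic; the
    # resulting backward body simply sets d_x[...] = 0, as asserted by the
    # new test above.
    _, bwd, _, _, _ = ft.grad_body(fwd, ["x"], ["y"], set(),
                                   reset_provided_grad=False)
    print(bwd)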